环境:CentOS 7.9.2009,Python 3.8.12,torch 1.10.1+cu113
先来看一段代码:
# 初始化一个lstm模型,和一个输入,以及手动初始化lstm初始状态h0和c0
rnn = torch.nn.LSTM(640, 512, 1, batch_first=True)
xs = torch.randn([1, 80, 640])
h0, c0 = torch.zeros([1,1,512]), torch.zeros([1,1,512])
# 每次手动输入lstm的初始状态,再用计算之后得到的新的隐含状态更新h0和c0
while True:
_, (h0, c0) = rnn(xs, (h0, c0))
但是上述代码会导致内存不断增长,直到耗尽所有内存,linux主动将其杀掉。
分析(分析没啥用,可以直接看结论)
问题就出在h0和c0,可能我们会觉得,执行rnn的前向过程也只是相当于调用一个函数,用函数的返回值来更新传入的参数是很常规的操作,而且函数接受的实参和返回的值一般会执行拷贝构造,也就是说函数里面的只是一个副本,传出来的也只是一个副本,而且出了作用域,资源就应该释放了,那为什么还会出现这个问题呢?
与C++不同,在python中引用、指针的概念被隐藏了,我们在用一个tensor赋值另一个变量的时候,实际上拿的是它的引用,例如如下例子:
a = torch.Tensor([0,1])