Pytorch LSTM内存泄漏问题

在使用PyTorch的LSTM时,发现内存不断增长直至耗尽。问题源于LSTM中h0和c0的引用,而非C++中的拷贝构造。在Python中,tensor赋值实际上是引用传递,导致旧tensor未被释放。虽然不存在循环引用,但LSTM的动态计算图在while循环中不断扩展,使得内存线性增长。解决方案是通过detach()方法使tensor脱离计算图,避免内存累积。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

环境:CentOS 7.9.2009,Python 3.8.12,torch 1.10.1+cu113

先来看一段代码:

# 初始化一个lstm模型,和一个输入,以及手动初始化lstm初始状态h0和c0
rnn = torch.nn.LSTM(640, 512, 1, batch_first=True)
xs = torch.randn([1, 80, 640])
h0, c0 = torch.zeros([1,1,512]), torch.zeros([1,1,512])
​
# 每次手动输入lstm的初始状态,再用计算之后得到的新的隐含状态更新h0和c0
while True:
    _, (h0, c0) = rnn(xs, (h0, c0))

但是上述代码会导致内存不断增长,直到耗尽所有内存,linux主动将其杀掉。

分析(分析没啥用,可以直接看结论)

问题就出在h0和c0,可能我们会觉得,执行rnn的前向过程也只是相当于调用一个函数,用函数的返回值来更新传入的参数是很常规的操作,而且函数接受的实参和返回的值一般会执行拷贝构造,也就是说函数里面的只是一个副本,传出来的也只是一个副本,而且出了作用域,资源就应该释放了,那为什么还会出现这个问题呢?

与C++不同,在python中引用、指针的概念被隐藏了,我们在用一个tensor赋值另一个变量的时候,实际上拿的是它的引用,例如如下例子:

a = torch.Tensor([0,1])
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值