神经网络训练中断后,想要load之前训练保存的模型权重文件继续训练(pytorch)
报错:
"If capturable=False, state_steps should not be CUDA tensors." AssertionError: If capturable=False, state_steps should not be CUDA tensors.
说是我的adam优化器里面 capturable 参数设置为 False,这个导致load进之前保存的模型参数后,进行优化器的更新.step( )时,导致失败,因为之前训练的时候为了训练速度快一点,默认设置capturable参数为False,函数定义如下:
但是load进之前的权重参数后,为了把优化器的数据放到gpu上,需要capture in a CUDA graph,所以就需要设置优化器的capturable参数为True.
具体修改:
在load 进之前的.pth文件以及更新了各个权重参数后的位置,添加以下代码手动设置优化器的caturable参数为True: