Fixing the "Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol" error when copying a Tensor or model with copy.deepcopy()

This article looks at a problem you can hit when using copy.deepcopy() in PyTorch: non-leaf Tensors with requires_grad=True cannot be deep-copied. It explains the cause and gives a corresponding workaround.

During model training you often need to run validation alongside training, or test the model once training has finished. The usual approach is, of course, to first create a model instance, call load_state_dict() to load the trained weights into it, and then call model.eval() to switch the model to evaluation mode. That works well for a dedicated test run after training, but for validation during training it means writing a pile of extra code; simply deep-copying the model being trained with copy.deepcopy() and using the copy for validation is clearly the more concise way to do it. Because of a limitation in copy.deepcopy(), however, if the model code was written without this in mind, calling copy.deepcopy(model) may fail with: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment. The detailed error output looks like this (a minimal reproduction and a sketch of the two approaches follow the traceback):

 File "/usr/local/lib/python3.6/site-packages/prc/framework/model/validation.py", line 147, in init_val_model
    val_model = copy.deepcopy(model)
  File "/usr/lib64/python3.6/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib64/python3.6/copy.py", line 280, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib64/python3.6/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/usr/lib64/python3.6/copy.py", line 240, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib64/python3.6/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib64/python3.6/copy.py", line 306, in _reconstruct
    value = deepcopy(value, memo)
  File "/usr/lib64/python3.6/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib64/python3.6/copy.py", line 280, in _reconstruct
    state = deepcopy(state, memo)
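To make the failure concrete, here is a minimal sketch of the situation that typically triggers the error, plus one common way around it. ToyModel, FixedModel and the cached attribute are hypothetical names used only for illustration, not code from any real project; the point is that an attribute computed from a parameter is a non-leaf tensor that still carries autograd history, and PyTorch currently only lets deepcopy handle leaf tensors.

import copy
import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        # The result of an operation on a parameter: requires_grad=True,
        # has a grad_fn, and is therefore NOT a graph leaf.
        self.cached = self.linear.weight * 2.0

model = ToyModel()
try:
    copy.deepcopy(model)
except RuntimeError as err:
    # "Only Tensors created explicitly by the user (graph leaves) support
    # the deepcopy protocol at the moment"
    print(err)

# One common fix: detach the tensor before storing it on the module, so the
# attribute becomes an ordinary leaf tensor without gradient history.
class FixedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.cached = (self.linear.weight * 2.0).detach()

val_model = copy.deepcopy(FixedModel())  # no longer raises
val_model.eval()

Creating such attributes inside a torch.no_grad() block has the same effect, because operations executed there produce leaf tensors with no gradient history.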
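For comparison, here is a sketch of the two ways of building a validation model described above: the conventional rebuild-and-load_state_dict() route versus a direct copy.deepcopy() of the model being trained. Net and the checkpoint file name are placeholders, not code from the article.

import copy
import torch
import torch.nn as nn

class Net(nn.Module):  # stand-in for the real model
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc(x)

train_model = Net()  # stands in for the model currently being trained

# Conventional approach: save the weights, build a fresh instance, reload.
torch.save(train_model.state_dict(), "checkpoint.pth")
val_model = Net()
val_model.load_state_dict(torch.load("checkpoint.pth"))
val_model.eval()

# deepcopy approach: one call, no re-instantiation or file round trip.
# It works as long as the module keeps no non-leaf tensors as attributes
# (see the reproduction above).
val_model = copy.deepcopy(train_model)
val_model.eval()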