在使用Pytorch赋值过程中有一些小细节需要注意。
先出示实验结果
1. 初始化一个值为1的tensor
torch.Tensor([1])
2. 正确的赋值方式:
w.data[0]=a
以下展示实验记录
In [1]: import torch
# 初始化参数 w,a ,b,
In [2]: w = torch.nn.Parameter(torch.zeros(4))
In [3]: w
Out[3]:
Parameter containing:
tensor([0., 0., 0., 0.], requires_grad=True)
In [4]: a = torch.Tensor([1])
In [5]: b = torch.Tensor(1)
In [6]: a
Out[6]: tensor([1.])
In [7]: b
Out[7]: tensor([0.])
In [8]: w[0].data=a
# 赋值失败
In [9]: w
Out[9]:
Parameter containing:
tensor([0., 0., 0., 0.], requires_grad=True)
In [10]: w.data[0]=a
# 赋值成功
In [11]: w
Out[11]:
Parameter containing:
tensor([1., 0., 0., 0.], requires_grad=True)