在pytorch的中,对Tensor进行赋值操作时,经常会有只修改其中部分值的需求,可以分为以下几类:
给不同行、不同列的元素赋相同/不同的值给不同行、
相同列的元素赋相同/不同的值给相同行、
不同列的元素赋相同/不同的值
在比较简单的需求下,将Tensor中每行的第3、4列赋值为0,我们可以用最简单语法实现:
a[:, 3:5] = 0
但是一下稍微复杂的需求就麻烦了,比如“将每行的第3、6列赋值为0”,或者“将下标为(1,2)的元素赋值为1,下标为(2,3)的元素赋值为2”,都是无法用上面的方法一步实现的。
本文介绍几种给部分元素赋值的高级方法。
1.scatter方法
scatter()方法按照指定的轴方向(dim)和对应的位置关系(index)逐个填充对应的值
使用场景
对Tensor各行的不同列赋值(可以是相同的值,也可以是多个值)
参数
dim:指定轴方向,以二维Tensor为例:
dim=0表示逐列进行行填充dim=1表示逐列进行行填充index:LongTensor,按照轴方向,在源Tensor中需要填充的位置src:用来进行填充的值:
src为一个数时,用这个数替换所有index位置上的值src为一个Tensor时,其shape必须与index一致,src中的元素会按顺序填充至对应index的位置上
实例
初始化一个2×5的全零矩阵,将其第0行的第0、2个元素 与 第1行的第1、3个元素全部替换为1
a = torch.zeros((2, 5))
index = torch.LongTensor([
[0, 2], # 对应下标(0, 0)、(0, 2)
[1, 3] # 对应下标(1, 1)、(1, 3)
])
a.scatter(1, index, 1)
tensor([[1., 0., 1., 0., 0.],
[0., 1., 0., 1., 0.]])
将其第0行的第0、2个元素 与 第1行的第1、3个元素分别替换为[1,2,3,4]
new_value = torch.Tensor([
[1, 2], [3, 4]
])
a.scatter(1, index, new_value)
>>> tensor([[1., 0., 2., 0., 0.],
[0., 3., 0., 4., 0.]])
2.index_fill方法
index_fill()方法用于在Tensor的指定维度上填充指定的值。
使用场景
对Tensor各行的相同列赋相同的值
参数
dim:给定维度
dim=0表示逐列进行行填充dim=1表示逐列进行行填充index:LongTensor。给定维度下的指定下标val:待填充值
实例
初始化一个2×5的全零矩阵,将其每一行的第0、2个元素用1填充
a = torch.zeros((2, 5))
index = torch.LongTensor([0, 2])
a.index_fill(1, index, 1)
>>> tensor([[1., 0., 1., 0., 0.],
[1., 0., 1., 0., 0.]])
3.index_put方法
使用场景
对二维Tensor不同行、不同列赋值
参数
indices:tuple(Tensor)格式,要填充的索引。注意这个tuple的size==tensor的维度:
第一个Tensor是所有待赋值元素的横坐标第二个Tensor是所有待赋值元素的纵坐标value:Tensor格式,要填充的值,与indices一一对应
实例
初始化一个2×5的全零矩阵,将其第0行的第0、2个元素 与 第1行的第1、3个元素分别替换为[1,2,3,4]
a = torch.zeros((2, 5))
index = (
torch.LongTensor([0, 0, 1, 1]),
torch.LongTensor([0, 2, 1, 3]),
)
new_value = torch.Tensor([1, 2, 3, 4])
a.index_put(index, new_value)
>>> tensor([[1., 0., 2., 0., 0.],
[0., 3., 0., 4., 0.]])