给部分元素赋值的高级方法

在pytorch的中,对Tensor进行赋值操作时,经常会有只修改其中部分值的需求,可以分为以下几类:

给不同行、不同列的元素赋相同/不同的值给不同行、
相同列的元素赋相同/不同的值给相同行、
不同列的元素赋相同/不同的值

在比较简单的需求下,将Tensor中每行的第3、4列赋值为0,我们可以用最简单语法实现:

a[:, 3:5] = 0

但是一下稍微复杂的需求就麻烦了,比如“将每行的第3、6列赋值为0”,或者“将下标为(1,2)的元素赋值为1,下标为(2,3)的元素赋值为2”,都是无法用上面的方法一步实现的。

本文介绍几种给部分元素赋值的高级方法。
1.scatter方法

scatter()方法按照指定的轴方向(dim)和对应的位置关系(index)逐个填充对应的值

使用场景

对Tensor各行的不同列赋值(可以是相同的值,也可以是多个值)

参数
dim:指定轴方向,以二维Tensor为例:
dim=0表示逐列进行行填充dim=1表示逐列进行行填充index:LongTensor,按照轴方向,在源Tensor中需要填充的位置src:用来进行填充的值:
src为一个数时,用这个数替换所有index位置上的值src为一个Tensor时,其shape必须与index一致,src中的元素会按顺序填充至对应index的位置上

实例

初始化一个2×5的全零矩阵,将其第0行的第0、2个元素 与 第1行的第1、3个元素全部替换为1

a = torch.zeros((2, 5))
index = torch.LongTensor([
    [0, 2],       # 对应下标(0, 0)、(0, 2)
    [1, 3]        # 对应下标(1, 1)、(1, 3)
])

a.scatter(1, index, 1)
 tensor([[1., 0., 1., 0., 0.],
            [0., 1., 0., 1., 0.]])

将其第0行的第0、2个元素 与 第1行的第1、3个元素分别替换为[1,2,3,4]

new_value = torch.Tensor([
    [1, 2], [3, 4]
])

a.scatter(1, index, new_value)
>>> tensor([[1., 0., 2., 0., 0.],
            [0., 3., 0., 4., 0.]])

2.index_fill方法

index_fill()方法用于在Tensor的指定维度上填充指定的值。

使用场景

对Tensor各行的相同列赋相同的值

参数
dim:给定维度
dim=0表示逐列进行行填充dim=1表示逐列进行行填充index:LongTensor。给定维度下的指定下标val:待填充值

实例

初始化一个2×5的全零矩阵,将其每一行的第0、2个元素用1填充

a = torch.zeros((2, 5))
index = torch.LongTensor([0, 2])

a.index_fill(1, index, 1)
>>> tensor([[1., 0., 1., 0., 0.],
            [1., 0., 1., 0., 0.]])

3.index_put方法

使用场景

对二维Tensor不同行、不同列赋值

参数
indices:tuple(Tensor)格式,要填充的索引。注意这个tuple的size==tensor的维度:
第一个Tensor是所有待赋值元素的横坐标第二个Tensor是所有待赋值元素的纵坐标value:Tensor格式,要填充的值,与indices一一对应

实例

初始化一个2×5的全零矩阵,将其第0行的第0、2个元素 与 第1行的第1、3个元素分别替换为[1,2,3,4]

a = torch.zeros((2, 5))
index = (
    torch.LongTensor([0, 0, 1, 1]), 
    torch.LongTensor([0, 2, 1, 3]),
)
new_value = torch.Tensor([1, 2, 3, 4])

a.index_put(index, new_value)
>>> tensor([[1., 0., 2., 0., 0.],
            [0., 3., 0., 4., 0.]])
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值