Pytorchd的scatter函数详解

部署运行你感兴趣的模型镜像

y = y.scatter(dim, index, src)

target.scatter(dim, index, src)

target:即目标张量,将在该张量上进行映射

dim 沿着哪个维度进行索引 指定轴方向,定义了填充方式。

index:用来scatter的元素索引,在target张量中需要填充的位置

src:即源张量,将把该张量上的元素逐个映射到目标张量上

当为二维时:

在这里插入代码片

y [ index[i][j] ] [j] = src[i][j] 定列对行

import torch
a = torch.arange(15).reshape(3,5).float()
b = torch.zeros(3,5)
la =torch.LongTensor([[1, 2, 1, 1, 2], [2, 0, 2, 1, 0],[0, 1, 0, 2, 1]])
b_= b.scatter(dim=0, index=la,src=a)
print(‘index\n’,la)
print(‘src\n’,a)
print(‘b_zeors\n’,b)
print(‘b_out\n’,b_)
print(“dim=1********\n”)

下面为输出结果:

index
tensor([[1, 2, 1, 1, 2],
[2, 0, 2, 1, 0],
[0, 1, 0, 2, 1]])
src
tensor([[ 0., 1., 2., 3., 4.],
[ 5., 6., 7., 8., 9.],
[10., 11., 12., 13., 14.]])
b_zeors
tensor([[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]])
b_out
tensor([[10., 6., 12., 0., 9.],
[ 0., 11., 2., 8., 14.],
[ 5., 1., 7., 13., 4.]])

y[i] [ index[i][j] ] = src[i][j] dim=1 定行找列

import torch
mini_batch = 4 #行
out_planes = 6 #列
out_put = torch.rand(mini_batch, out_planes) #随机生成一个4x6的矩阵
softmax = torch.nn.Softmax(dim=1) #设置softmax=1
out_put = softmax(out_put) #softmax为行列和都为1的矩阵
print(‘sf_out_put\n’,out_put)
label = torch.tensor([1,3,3,5]).unsqueeze(1)
print(‘index\n’,label)
print(‘one_hot_label_zero\n’,torch.zeros(mini_batch, out_planes))
one_hot_label = torch.zeros(mini_batch, out_planes).scatter_(1,label,1) # 将out_put置零向量按照label给的位置把0替换为1 unsqueeze(1)用来升维
#torch.zeros(mini_batch, out_planes) 将out_put矩阵置0
print(‘one_hot_label\n’,one_hot_label)

*下面为输出结果:

sf_out_put
tensor([[0.1231, 0.1056, 0.2420, 0.2162, 0.2070, 0.1061],
[0.2469, 0.2268, 0.1362, 0.1493, 0.1223, 0.1185],
[0.1040, 0.2340, 0.2137, 0.0924, 0.1851, 0.1707],
[0.1931, 0.1484, 0.2069, 0.0881, 0.1892, 0.1742]])
index
tensor([[1],
[3],
[3],
[5]])
one_hot_label_zero
tensor([[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.]])
one_hot_label
tensor([[0., 1., 0., 0., 0., 0.],
[0., 0., 0., 1., 0., 0.],
[0., 0., 0., 1., 0., 0.],
[0., 0., 0., 0., 0., 1.]])
引用文章:
https://blog.youkuaiyun.com/t20134297/article/details/105755817
https://www.cnblogs.com/vvzhang/p/14152210.html

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值