Pytorch基础 - 8. scatter() / scatter_() 函数

本文介绍了PyTorch中的scatter()和scatter_()函数,这两个函数用于根据索引映射创建或修改Tensor。文章详细解释了函数的参数,提供了2维示例,并阐述了其在one-hot编码中的应用。通过实例,读者可以理解如何沿着特定维度进行索引操作,以及当源数据为标量时的情况。
部署运行你感兴趣的模型镜像

目录

1. scatter() 定义和参数说明

2. 示例和详细解释

3. scatter() 常见用途


1. scatter() 定义和参数说明

scatter() 或 scatter_() 常用来返回根据index映射关系映射后的新的tensor。其中,scatter() 不会直接修改原来的 Tensor,而 scatter_() 直接在原tensor上修改。

官方文档:torch.Tensor.scatter_ — PyTorch 2.0 documentation

 参数定义:

  • dim:沿着哪个维度进行索引
  • index:索引值
  • src:数据源,可以是张量,也可以是标量

简言之 scatter() 是通过 src 来修改另一个张量,修改的元素值和位置由 dim 和 index 决定

2. 示例和详细解释

在官方文档中,给出了3维tensor的具体操作说明,看起来很蒙,没关系继续往下看

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

 接下来的示例,我们以2维为例,那上面的公式简化为如下,

self[index[i][j]][j] = src[i][j]  # if dim == 0
self[i][index[i][j]] = src[i][j]  # if dim == 1

示例:将全零的张量,根据index和scr进行值的变化

src = torch.arange(1, 11).reshape((2, 5))
# src: tensor([[0.8351, 0.2974, 0.9028, 0.4250, 0.0370],
#              [0.4564, 0.6832, 0.6854, 0.6056, 0.7118]])
    
index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 4, 2, 3]])
dist = torch.zeros(2, 5, dtype=src.dtype).scatter(1, index, src)
    
# dist: tensor([[0.0370, 0.2974, 0.9028, 0.0000, 0.0000],
#               [0.4564, 0.6832, 0.6056, 0.7118, 0.6854]])

将上述张量使用表格表示: 

当 dim = 1时,dist[i[index[i][j]] = src[i][j],所以具体的计算如下

  • i=0, j=0时,dist[0][index[0][0]] = src[0][0], dist[0][1] = 0.8351
  • i=0, j=1时,dist[0][index[0][1]] = src[0][1], dist[0][0] = 0.2974
  • i=0, j=2时,dist[0][index[0][2]] = src[0][2], dist[0][2] = 0.9028
  • i=0, j=3时,dist[0][index[0][3]] = src[0][3], dist[0][4] = 0.4250
  • i=0, j=4时,dist[0][index[0][4]] = src[0][4], dist[0][3] = 0.0370
  • i=1, j=0时,dist[1][index[1][0]] = src[1][0], dist[1][0] = 0.4564
  • i=1, j=1时,dist[1][index[1][1]] = src[1][1], dist[1][1] = 0.6832
  • i=1, j=2时,dist[1][index[1][2]] = src[1][2], dist[1][4] = 0.6854
  • i=1, j=3时,dist[1][index[1][3]] = src[1][3], dist[1][2] = 0.6056
  • i=1, j=4时,dist[1][index[1][4]] = src[1][4], dist[1][3] = 0.7118

当 dim = 0时,赋值方式不同,即dist[index[i][j]][j] = src[i][j],计算的过程与上述同理,不再赘述

当 src 为标量时,更为简单。因为全部都用该标量值进行赋值即可 

3. scatter() 常见用途

scatter() 常用来对标签进行 one-hot 编码

注意:label的类型是torch.LongTensor

class_num, batch_size = 10, 4
label = torch.tensor([[6], [0], [3], [2]], dtype=torch.long)
one_hot = torch.zeros(batch_size, class_num).scatter(1, label, 1)
print(one_hot)
# tensor([[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
#        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
#        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
#        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]])

这个程序比较好理解,dim =1, label = index, src = 1 ,所以

  • dist[0][label[0][0]] = dist[0][6] = 1
  • dist[1][label[1][0]] = dist[1][0] = 1
  • dist[2][label[2][0]] = dist[2][3] = 1
  • dist[3][label[3][0]] = dist[3][2] = 1

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

<think>用户现在遇到的问题是运行Mask3D的实例分割代码时,出现了OSError,提示在加载torch_scatter的CUDA版本时找不到符号_ZN3c1017RegisterOperatorsD1Ev。这看起来像是PyTorch的C++ API中的符号未找到,通常是由于PyTorch版本和扩展库(如torch_scatter)之间的版本不兼容导致的。 首先,我需要分析错误信息。错误信息中提到的符号是_ZN3c1017RegisterOperatorsD1Ev,这看起来是c10命名空间下的一个析构函数。c10是PyTorch的核心C++库,所以这个符号应该属于PyTorch的某个版本。如果torch_scatter编译时使用的PyTorch版本与当前环境中的PyTorch版本不一致,就可能导致符号解析失败。 接下来,用户的环境是conda环境mask3d,Python 3.10,安装了torch_scatter。需要确认torch_scatter的版本是否与PyTorch版本兼容。例如,PyTorch可能有多个版本(如1.10、1.11等),而torch_scatter需要针对特定版本编译。 可能的解决方案包括: 1. 重新安装与当前PyTorch版本兼容的torch_scatter。 2. 检查PyTorch的版本,并确保所有扩展库都使用相同的版本。 3. 更新PyTorch和相关的CUDA工具包,确保版本匹配。 4. 检查是否有多个PyTorch安装冲突,例如同时存在conda安装和pip安装的版本。 5. 确认CUDA和cuDNN的版本是否与PyTorch版本兼容。 此外,用户之前遇到的问题涉及cuDNN符号链接和PyTorch路径问题,可能环境配置存在不一致。需要确保所有库的路径正确,并且没有版本冲突。 现在,我需要整理这些思路,给出具体的解决步骤,并确保回答后提出相关的问题。</think>该错误由PyTorch C++扩展符号不匹配引发,需从版本兼容性、编译环境、符号解析三方面解决: 1. **验证PyTorch与扩展版本匹配** 检查PyTorch版本并安装对应torch_scatter: ```bash python -c "import torch; print(torch.__version__)" # 假设输出为2.0.1,则安装: pip install torch-scatter -f https://data.pyg.org/whl/torch-2.0.1+cu118.html ``` 2. **重建C++扩展符号表** 强制重新编译PyTorch扩展: ```bash cd /home/yc/workspace/Mask3D-main rm -rf build dist *.egg-info python setup.py clean --all python setup.py develop ``` 3. **符号解析深度检查** 使用nm工具验证符号存在性: ```bash nm -D /home/yc/anaconda3/envs/mask3d/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so | grep _ZN3c1017RegisterOperatorsD1Ev ``` 若未找到符号,说明PyTorch安装包损坏,需重新安装。 4. **环境隔离修复** 创建新的conda环境确保纯净: ```bash conda create -n mask3d_new python=3.10 conda activate mask3d_new conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia pip install torch-scatter -f https://data.pyg.org/whl/torch-2.0.1+cu118.html ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值