torch里面的函数

这篇博客介绍了PyTorch中的一些关键函数,包括torch.scatter的用法,如何在reduce为mean时进行计算,以及torch.addmm在矩阵乘法中的应用。还探讨了torch.where实现的条件选择,数组切片的实现方式,以及卷积操作和数据约束。此外,提到了torch.clamp用于限制数据范围,并通过错误示例解释了参数传递的注意事项。文章最后讨论了张量的切分和拼接操作。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

import torch
from torch_scatter import scatter
src = torch.range(1,24).view(2,6,2)
index = torch.tensor([0, 1, 0, 1, 2, 1])
# Broadcasting in the first and last dim.
out = scatter(src, index, dim=1, reduce="sum")
print(src)
print(out)

Pytorch学习 (二十六)---- torch.scatter的使用

tensor([[[ 1.,  2.],
         [ 3.,  4.],
         [ 5.,  6.],
         [ 7.,  8.],
         [ 9., 10.],
         [11., 12.]],

        [[13., 14.],
         [15., 16.],
         [17., 18.],
         [19., 20.],
         [21., 22.],
         [23., 24.]]])
tensor([[[ 6.,  8.],
         [21., 24.],
         [ 9., 10.]],

        [[30., 32.],
         [57., 60.],
         [21., 22.]]])

这里参考了链接,但是ta只是介绍关于index里面的0的情况,同理为1时,将为1时的情况同时相加,就拿21为例,是3+7+11得到,其他数字同理.

当reduce改为mean时,同样是这样计算,不过是再相加后,除以个数

torch.addmm函数的应用

import torch
a =torch.range(1,4).view(2,2)
bias = torch.zeros(2,2)
b = torch.addmm(bias,a,a) 
print(b)

tensor([[ 7., 10.],
[15., 22.]])

这个函数最重要就是偏置必须设置,如果偏置为0,可以考虑使用torch.matmul函数

idx_train = torch.where(torch.from_numpy(train_mask)==True)

torch.where(条件)
返回数据的具体位置,如上代码,返回了train_mask里面的为True的索引,如果是维度为一,就返回一个位置;如果为二维则返回一个坐标

out_col = out[col]

实现数组切片

from torch_scatter import scatter_add
scatter_add(out_col, row, dim=0, dim_size=x.size(0))

第一个元素为操作对象,第二个是限定操作对象的哪些元素进行相加,第三个元素指的是对操作对象的第几个维度数据进行操作,第四个元素是控制操作的那个维度的大小,具体如下:

from torch_scatter import scatter_add
import torch
from torch_scatter import scatter_add
x = torch.arange(0,48)
y = x.view(6,8)
print(y)
y = torch.tensor(y)
index = [0,1,1,1,2,2]
variablex = torch.tensor(index)
out = scatter_add(y, variablex, dim=0, dim_size=y.size(0))
print(out)
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7],
        [ 8,  9, 10, 11, 12, 13, 14, 15],
        [16, 17, 18, 19, 20, 21, 22, 23],
        [24, 25, 26, 27, 28, 29, 30, 31],
        [32, 33, 34, 35, 36, 37, 38, 39],
        [40, 41, 42, 43, 44, 45, 46, 47]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7],
        [48, 51, 54, 57, 60, 63, 66, 69],
        [72, 74, 76, 78, 80, 82, 84, 86],
        [ 0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0,  0,  0]])
import torch

a=torch.range(1,9)
b=torch.range(1,9)
a = a.view(3,3)
b = b.view(3,3)
print(a,b)
print(a@b)

a@b实现的是卷积操作

b = b.unsqueeze(0) if b.dim() == 2 else b
#如果b的维度为2,则进行升维

将2x3升维为1x2x3

torch.matmul()#用于矩阵乘法
b.transpose(1, 2)

如果b是一个四维的数据,比如torch.Size([2, 3, 4, 2])
进行以上操作后得到torch.Size([2, 4, 3, 2])

clamp指的是torch里面的函数,即将数据限制在一个区间内
参考
pytorch:torch.clamp()

adj.numel()

返回adj里面元素的个数

Exception has occurred: TypeError forward() takes from 3 to 4 positional arguments but 5 were given

找到出错的位置,借鉴TypeError: forward() takes 2 positional arguments but 3 were given

对forward函数里面传递的参数进行查找,最后删除掉了一个参数

r, i, j, k = torch.split(kernel, [dim, dim, dim, dim], dim=1)

对tensor kernel在第二个维度进行切分,其中分成4部分,按照dim:dim:dim:dim的比例进行划分

r2 = torch.cat([r, i, j, k], dim=0)

把张量r,i,j,k在第一个维度下进行拼接,如果要在其他维度下,只需要修改dim的值(这个只能够对张量进行操作)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值