import torch
from torch_scatter import scatter
src = torch.range(1,24).view(2,6,2)
index = torch.tensor([0, 1, 0, 1, 2, 1])
# Broadcasting in the first and last dim.
out = scatter(src, index, dim=1, reduce="sum")
print(src)
print(out)
Pytorch学习 (二十六)---- torch.scatter的使用
tensor([[[ 1., 2.],
[ 3., 4.],
[ 5., 6.],
[ 7., 8.],
[ 9., 10.],
[11., 12.]],
[[13., 14.],
[15., 16.],
[17., 18.],
[19., 20.],
[21., 22.],
[23., 24.]]])
tensor([[[ 6., 8.],
[21., 24.],
[ 9., 10.]],
[[30., 32.],
[57., 60.],
[21., 22.]]])
这里参考了链接,但是ta只是介绍关于index里面的0的情况,同理为1时,将为1时的情况同时相加,就拿21为例,是3+7+11得到,其他数字同理.
当reduce改为mean时,同样是这样计算,不过是再相加后,除以个数
torch.addmm函数的应用
import torch
a =torch.range(1,4).view(2,2)
bias = torch.zeros(2,2)
b = torch.addmm(bias,a,a)
print(b)
tensor([[ 7., 10.],
[15., 22.]])
这个函数最重要就是偏置必须设置,如果偏置为0,可以考虑使用torch.matmul函数
idx_train = torch.where(torch.from_numpy(train_mask)==True)
torch.where(条件)
返回数据的具体位置,如上代码,返回了train_mask里面的为True的索引,如果是维度为一,就返回一个位置;如果为二维则返回一个坐标
out_col = out[col]
实现数组切片
from torch_scatter import scatter_add
scatter_add(out_col, row, dim=0, dim_size=x.size(0))
第一个元素为操作对象,第二个是限定操作对象的哪些元素进行相加,第三个元素指的是对操作对象的第几个维度数据进行操作,第四个元素是控制操作的那个维度的大小,具体如下:
from torch_scatter import scatter_add
import torch
from torch_scatter import scatter_add
x = torch.arange(0,48)
y = x.view(6,8)
print(y)
y = torch.tensor(y)
index = [0,1,1,1,2,2]
variablex = torch.tensor(index)
out = scatter_add(y, variablex, dim=0, dim_size=y.size(0))
print(out)
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 8, 9, 10, 11, 12, 13, 14, 15],
[16, 17, 18, 19, 20, 21, 22, 23],
[24, 25, 26, 27, 28, 29, 30, 31],
[32, 33, 34, 35, 36, 37, 38, 39],
[40, 41, 42, 43, 44, 45, 46, 47]])
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7],
[48, 51, 54, 57, 60, 63, 66, 69],
[72, 74, 76, 78, 80, 82, 84, 86],
[ 0, 0, 0, 0, 0, 0, 0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0]])
import torch
a=torch.range(1,9)
b=torch.range(1,9)
a = a.view(3,3)
b = b.view(3,3)
print(a,b)
print(a@b)
a@b实现的是卷积操作
b = b.unsqueeze(0) if b.dim() == 2 else b
#如果b的维度为2,则进行升维
将2x3升维为1x2x3
torch.matmul()#用于矩阵乘法
b.transpose(1, 2)
如果b是一个四维的数据,比如torch.Size([2, 3, 4, 2])
进行以上操作后得到torch.Size([2, 4, 3, 2])
clamp指的是torch里面的函数,即将数据限制在一个区间内
参考
pytorch:torch.clamp()
adj.numel()
返回adj里面元素的个数
Exception has occurred: TypeError forward() takes from 3 to 4 positional arguments but 5 were given
找到出错的位置,借鉴TypeError: forward() takes 2 positional arguments but 3 were given
对forward函数里面传递的参数进行查找,最后删除掉了一个参数
r, i, j, k = torch.split(kernel, [dim, dim, dim, dim], dim=1)
对tensor kernel在第二个维度进行切分,其中分成4部分,按照dim:dim:dim:dim的比例进行划分
r2 = torch.cat([r, i, j, k], dim=0)
把张量r,i,j,k在第一个维度下进行拼接,如果要在其他维度下,只需要修改dim的值(这个只能够对张量进行操作)