加减乘除
add, sub,mul,div
[+,-,*,/]
矩阵相乘
element-wise 相乘 “*”
matrix mul matmul和@含义一致;mm只适用于2d数据相乘;
对于高维矩阵的运算:
a = torch.rand(3, 3, 224, 126)
b = torch.rand(3, 3, 224, 126)
c = torch.matmul(a, b.t())
特别:
b = torch.rand(3, 1, 126, 256)
c = torch.matmul(a, b)
c.shape = [3, 3, 224, 256]
先boardcasting后矩阵后两个维度相乘
approximation
.floor(), ceil()
向上取整和向下取整
clamp
clamp(min)
clamp(min, max)
裁剪特征或者权重,防止梯度爆炸和弥散;
高阶
where
where(cond, a, b)
在条件cond下,如成立则a,不成立则b;
cond,a,b的尺寸都要保持一致;
可用于gpu高效运算判定;
gather-收集
对于文本类别,如何将索引列表和对应索引label进行高效索引,使用gather。
gather公式看起来很复杂,简而言之,对于index使用gather可将对应数字序号转化为标签,从而转换标签;
torch.gather(input, dim. index, out)
out[i][j][k] = input[i][j][index[i][j][k]]