PyTorch张量:深入解析与操作指南
1. 张量类型转换
在操作中混合输入类型时,输入会自动转换为更大的类型。若要进行32位计算,需确保所有输入最多为32位。示例代码如下:
points_64 = torch.rand(5, dtype=torch.double)
points_short = points_64.to(torch.short)
points_64 * points_short
# works from PyTorch 1.3 onwards
# 输出结果
# tensor([0., 0., 0., 0., 0.], dtype=torch.float64)
2. 张量API概述
大部分张量操作可在 torch 模块中找到,也能作为张量对象的方法调用。例如 transpose 函数:
# 从torch模块调用
a = torch.ones(3, 2)
a_t = torch.transpose(a, 0, 1)
print(a.shape, a_t.shape)
# 输出:(torch.Size([3, 2]), torch.Size([2, 3]))
# 作为张量对象的方法调用
a = torch.ones(3, 2)
a_t = a.transpose(0, 1)
print(a.shape, a_t.shape)
# 输出:(torch.Size([3, 2]), torch.Size([2, 3]))
超级会员免费看
订阅专栏 解锁全文
2259

被折叠的 条评论
为什么被折叠?



