PyTorch张量:从基础到高级应用
1. 张量类型转换与运算
在操作中混合输入类型时,输入会自动转换为更大的类型。若要进行32位计算,需确保所有输入最多为32位。例如:
points_64 = torch.rand(5, dtype=torch.double)
points_short = points_64.to(torch.short)
points_64 * points_short
这里, to 方法会检查是否需要转换,若需要则进行转换。像 float 这类以 dtype 命名的转换方法是 to 的简写,但 to 方法可以接受额外参数。
2. 张量API概述
PyTorch的张量操作大多可在 torch 模块中找到,也能作为张量对象的方法调用。例如转置操作:
# 使用torch模块
a = torch.ones(3, 2)
a_t = torch.transpose(a, 0, 1)
a.shape, a_t.shape
# 使用张量对象方法
a = torch.ones(3, 2)
a_t = a.transpose(0, 1)
a.shape, a_t.shape
这两种形式效果相同,可互换使用。PyTorch的在线文档(http://pytorc
超级会员免费看
订阅专栏 解锁全文
1425

被折叠的 条评论
为什么被折叠?



