PyTorch张量:从基础到高级应用
1. 张量类型转换
在操作中混合输入类型时,输入会自动转换为更大的类型。如果我们需要进行32位计算,就必须确保所有输入最多为32位。例如:
points_64 = torch.rand(5, dtype=torch.double)
points_short = points_64.to(torch.short)
points_64 * points_short
这里, to 方法会检查是否需要转换,如果需要则进行转换。像 float 这样的以 dtype 命名的转换方法是 to 的简写,但 to 方法可以接受额外的参数。
2. 张量API概述
PyTorch提供了丰富的张量操作,大多数操作既可以通过 torch 模块调用,也可以作为张量对象的方法调用。例如转置操作:
# 方法一:通过torch模块调用
a = torch.ones(3, 2)
a_t = torch.transpose(a, 0, 1)
print(a.shape, a_t.shape)
# 方法二:作为张量对象的方法调用
a = torch.ones(3, 2)
a_t = a.transpose(0, 1)
print(a.shape, a_t.shape)
<
超级会员免费看
订阅专栏 解锁全文
35

被折叠的 条评论
为什么被折叠?



