文章目录
1.PyTorch的基本数据类型–张量
Tensor的基本数据类型有五种:
(1)32位浮点型:torch.FloatTensor (注:pytorch.Tensor()默认的就是这种类型。)
(2)64位整型:torch.LongTensor
(3)32位整型:torch.IntTensor
(4)16位整型:torch.ShortTensor
(5)64位浮点型:torch.DoubleTensor
2.数据类型的判断
注意:
numpy和Tensor的最大区别就在于对GPU的支持上,Tensor可以通过调用cuda()函数将其转化为能在GPU上运行的类型,同一个Tensor部署在CPU和GPU上面的数据类型是不一样的。
3.不同维度的张量
(1)Dim 0(最简单的数据类型)
Dimension为0的张量,等价于一个标量,通常用于损失函数Loss
(2)Dim 1
Dimension为1的张量,等价于一个向量,通常用于模型的偏置Bias、神经网络的输入输出Linear Input/Output
注意:
.tensor()接收的是数据的内容,即数据本身;.FloatTensor()接收的是数据的shape
(3)Dim 2
Dimension为2的张量,等价于一个矩阵,常用于带批量大小的神经网络的输入输出,即[batch,linear_input]
(4)Dim 3
Dimension为3的张量,常用于RNN模型的输入信息,即[batch,num_word,word_embedding]
(5)Dim 4
Dim为4的张量,常用于CNN模型的输入信息,即[batch,channel,height,weight],适合表达图片数据类型