import torch
obs=[2,4,2,6]
obs_tensor = torch.as_tensor(obs, dtype=torch.float32)
print((obs_tensor))
print((obs_tensor).shape)
obs_tensor.unsqueeze(0) #在0的位置加上一维
print(obs_tensor.unsqueeze(0))
print(obs_tensor.unsqueeze(0).shape)
output
tensor([2., 4., 2., 6.])
torch.Size([4])
tensor([[2., 4., 2., 6.]])
torch.Size([1, 4])
torch.Size括号中有几个数字就是几维,具体参考——torch.size: link
**unsqueeze()**这个函数主要是对数据维度进行扩充。
给指定位置加上维数为一的维度,比如原本有个四行的数据(4),unsqueeze(0)后就会在0的位置加了一维就变成一行四列(1,4)。参考链接: link 和 link
本文介绍了PyTorch中unsqueeze()函数的作用,通过示例展示了如何使用该函数在指定位置增加一个维度。文章指出,unsqueeze()能够帮助将一维数据转换成二维数据,这对于处理神经网络输入尤其有用。在示例中,一个四元素的向量通过unsqueeze(0)变为了一维大小为1的矩阵。理解unsqueeze()对于理解和操作张量的维度至关重要。
8449

被折叠的 条评论
为什么被折叠?



