import torch
obs=[2,4,2,6]
obs_tensor = torch.as_tensor(obs, dtype=torch.float32)
print((obs_tensor))
print((obs_tensor).shape)
obs_tensor.unsqueeze(0) #在0的位置加上一维
print(obs_tensor.unsqueeze(0))
print(obs_tensor.unsqueeze(0).shape)
output
tensor([2., 4., 2., 6.])
torch.Size([4])
tensor([[2., 4., 2., 6.]])
torch.Size([1, 4])
torch.Size括号中有几个数字就是几维,具体参考——torch.size: link
**unsqueeze()**这个函数主要是对数据维度进行扩充。
给指定位置加上维数为一的维度,比如原本有个四行的数据(4),unsqueeze(0)后就会在0的位置加了一维就变成一行四列(1,4)。参考链接: link 和 link