torch.cat函数参数解析
作用:torch.cat将多个tensor拼接而成一个tensor,实现了在不同维度上的tensor拼接。
参数:inputs: 待拼接的两个张量序列,要拼接两个单独的张量,可以采用inputs=[tensor_1,tensor_2]的方法。
dim: 拼接的维度,下面将区分两种拼接维度的区别,一看就懂,易上手。
示例代码:
import torch
tensor_a = torch.tensor([[1,2],[2,3],[3,4],[4,5]])
tensor_b = torch.tensor([[0,1],[-1,0],[-2,-1],[-3,-2]])
tensor_c1 = torch.cat((tensor_a,tensor_b),dim=0)
tensor_c2 = torch.cat((tensor_a,tensor_b),dim=1)
print(tensor_c1)
print(tensor_c2)
运行结果:

dim=0的情况下的拼接运算:

dim=1的情况下的拼接运算:

torch.stack
作用:连接张量,同torch.cat不同的是,连接不是拼接,拼接讲的是在两个张量进行张量内的拼接,而连接指的是两个张量之间进行连接。
参数:inputs: 待连接的两个张量序列,要在张量间连接两个单独的张量,可以采用inputs=[tensor_1,tensor_2]的方法。
dim: 拼接的维度,下面将区分两种拼接维度的区别,一看就懂,易上手。
示例代码:
import torch
tensor_a = torch.tensor([[1,2],[2,3],[3,4],[4,5]])
tensor_b = torch.tensor([[0,1],[-1,0],[-2,-1],[-3,-2]])
tensor_c1 = torch.stack((tensor_a,tensor_b),dim=0)
tensor_c2 = torch.stack((tensor_a,tensor_b),dim=1)
print(tensor_c1)
print(tensor_c2)
dim=0时:

dim=1时:

one-hot独热编码报错:one-hot is only applicable to index tensor.
解决办法:
检查传入参数是不是torch.int32类型的,如果是,请改为torch.int64。
参考博客:pytorch 独热编码报错的解决办法
本文详细介绍了PyTorch中两种重要的张量操作:torch.cat用于拼接多个tensor成一个tensor;torch.stack用于在指定维度连接两个张量。同时,针对one-hot编码时出现的错误提供了有效的解决方法。
2878

被折叠的 条评论
为什么被折叠?



