一. torch.cat()函数解析
1. 函数说明
1.1 官网:torch.cat(),函数定义及参数说明如下图所示:

1.2 函数功能
函数将两个张量(tensor)按指定维度拼接在一起,注意:除拼接维数dim数值可不同外其余维数数值需相同,方能对齐,如下面例子所示。torch.cat()函数不会新增维度,而torch.stack()函数会新增一个维度,相同的是两个都是对张量进行拼接
2. 代码举例
2.1 输入两个二维张量(dim=0):dim=0对行进行拼接
a = torch.randn(2,3)
b = torch.randn(3,3)
c = torch.cat((a,b),dim=0)
a,b,c
输出结果如下:
(tensor([[-0.90, -0.37, 1.96],
[-2.65, -0.60, 0.05]]),
tensor([[ 1.30, 0.24, 0.27],
[-1.99, -1.09, 1.67],
[-1.62, 1.54, -0.14]]),
tensor([[-0.90, -0.37, 1.96],
[-2.65, -0.60, 0.05],
[ 1.30, 0.24, 0.27],
[-1.99, -1.09, 1.67],
[-1.62, 1.54, -0.14]]))
2.2 输入两个二维张量(dim=1): dim=1对列进行拼接
a = torch.randn(2,3)
b = torch.randn(2,4)
c = torch.cat((a,b),dim=1)
a,b,c
输出结果如下:
(tensor([[-0.55, -0.84, -1.60],
[ 0.39, -0.96, 1.02]]),
tensor([

最低0.47元/天 解锁文章
3181

被折叠的 条评论
为什么被折叠?



