torch.cat((tensor1,tensor2), dim)
将两个tensor连接起来,具体如何连接见下面例子
x = torch.rand((2,2,3))
y = torch.rand((2,2,3))
print("x:",x)
print("y:",y)
print("dim=0:", torch.cat((x,y),dim=0))
print("dim=1:", torch.cat((x,y), dim=1))
print("dim=2:", torch.cat((x, y), dim=2))
输出:
x: tensor([[[0.2571, 0.9011, 0.7935],
[0.9308, 0.3267, 0.3290]],
[[0.6155, 0.4739, 0.7251],
[0.8025, 0.0424, 0.8101]]])
y: tensor([[[0.8813, 0.1149, 0.7757],
[0.4733, 0.9003, 0.3300]],
[[0.2597, 0.5810, 0.2507],
[0.1220, 0.2260, 0.5620]]])
dim=0: tensor([[[0.2571, 0.9011, 0.7935],
[0.9308, 0.3267, 0.3290]],
[[0.6155, 0.4739, 0.7251],
[0.8025, 0.0424, 0.8101]],
[[0.8813, 0.1149, 0.7757],
[0.4733, 0.9003, 0.3300]],
[[0.2597, 0.5810, 0.2507],
[0.1220, 0.2260, 0.5620]]])
dim=1: tensor([[[0.2571, 0.9011, 0.7935],
[0.9308, 0.3267, 0.3290],
[0.8813, 0.1149, 0.7757],
[0.4733, 0.9003, 0.3300]],
[[0.6155, 0.4739, 0.7251],
[0.8025, 0.0424, 0.8101],
[0.2597, 0.5810, 0.2507],
[0.1220, 0.2260, 0.5620]]])
dim=2: tensor([[[0.2571, 0.9011, 0.7935, 0.8813, 0.1149, 0.7757],
[0.9308, 0.3267, 0.3290, 0.4733, 0.9003, 0.3300]],
[[0.6155, 0.4739, 0.7251, 0.2597, 0.5810, 0.2507],
[0.8025, 0.0424, 0.8101, 0.1220, 0.2260, 0.5620]]])
[Finished in 2.1s]
torch.stack((tensor1, tensor2), dim)
x = torch.rand((2,2,3))
y = torch.rand((2,2,3))
print("x:",x)
print("y:",y)
print("dim=0:", torch.stack((x,y),dim=0))
print("dim=1:", torch.stack((x,y), dim=1))
print("dim=2:", torch.stack((x, y), dim=2))
print("dim=3", torch.stack((x, y), dim=3))
输出:
x: tensor([[[0.5099, 0.3434, 0.3731],
[0.8523, 0.4672, 0.4163]],
[[0.3364, 0.4910, 0.2302],
[0.7896, 0.8119, 0.3978]]])
y: tensor([[[0.3843, 0.7627, 0.9757],
[0.0065, 0.5462, 0.2765]],
[[0.1890, 0.1698, 0.4486],
[0.3459, 0.5552, 0.1908]]])
dim=0: tensor([[[[0.5099, 0.3434, 0.3731],
[0.8523, 0.4672, 0.4163]],
[[0.3364, 0.4910, 0.2302],
[0.7896, 0.8119, 0.3978]]],
[[[0.3843, 0.7627, 0.9757],
[0.0065, 0.5462, 0.2765]],
[[0.1890, 0.1698, 0.4486],
[0.3459, 0.5552, 0.1908]]]])
dim=1: tensor([[[[0.5099, 0.3434, 0.3731],
[0.8523, 0.4672, 0.4163]],
[[0.3843, 0.7627, 0.9757],
[0.0065, 0.5462, 0.2765]]],
[[[0.3364, 0.4910, 0.2302],
[0.7896, 0.8119, 0.3978]],
[[0.1890, 0.1698, 0.4486],
[0.3459, 0.5552, 0.1908]]]])
dim=2: tensor([[[[0.5099, 0.3434, 0.3731],
[0.3843, 0.7627, 0.9757]],
[[0.8523, 0.4672, 0.4163],
[0.0065, 0.5462, 0.2765]]],
[[[0.3364, 0.4910, 0.2302],
[0.1890, 0.1698, 0.4486]],
[[0.7896, 0.8119, 0.3978],
[0.3459, 0.5552, 0.1908]]]])
dim=3 tensor([[[[0.5099, 0.3843],
[0.3434, 0.7627],
[0.3731, 0.9757]],
[[0.8523, 0.0065],
[0.4672, 0.5462],
[0.4163, 0.2765]]],
[[[0.3364, 0.1890],
[0.4910, 0.1698],
[0.2302, 0.4486]],
[[0.7896, 0.3459],
[0.8119, 0.5552],
[0.3978, 0.1908]]]])
[Finished in 2.2s]
注意stack和cat的区别
- stack操作后会在原来的基础上再增加一维,比如原来两个tensor的维度都是3维,经过stack后的结果为4维tensor; 而cat操作其结果和原来的tensor保持一致
- 具体stack和cat如何连接两个tensor见上方例子
本文详细介绍了PyTorch中两种重要的张量操作:torch.cat和torch.stack。通过具体的代码示例,展示了如何使用这两种函数来连接张量,以及它们在维度处理上的不同之处。torch.cat用于沿指定维度拼接张量,而torch.stack则在原维度基础上新增一维进行堆叠。
1574

被折叠的 条评论
为什么被折叠?



