进阶torch:torch.stack和torch.cat + one-hot报错

本文详细介绍了PyTorch中两种重要的张量操作:torch.cat用于拼接多个tensor成一个tensor;torch.stack用于在指定维度连接两个张量。同时,针对one-hot编码时出现的错误提供了有效的解决方法。

torch.cat函数参数解析

作用:torch.cat将多个tensor拼接而成一个tensor,实现了在不同维度上的tensor拼接。

参数:inputs:  待拼接的两个张量序列,要拼接两个单独的张量,可以采用inputs=[tensor_1,tensor_2]的方法。

           dim: 拼接的维度,下面将区分两种拼接维度的区别,一看就懂,易上手。

示例代码:

import torch
tensor_a = torch.tensor([[1,2],[2,3],[3,4],[4,5]])
tensor_b = torch.tensor([[0,1],[-1,0],[-2,-1],[-3,-2]])
tensor_c1 = torch.cat((tensor_a,tensor_b),dim=0)
tensor_c2 = torch.cat((tensor_a,tensor_b),dim=1)
print(tensor_c1)
print(tensor_c2)

运行结果:

 dim=0的情况下的拼接运算:

 dim=1的情况下的拼接运算:

torch.stack

作用:连接张量,同torch.cat不同的是,连接不是拼接,拼接讲的是在两个张量进行张量内的拼接,而连接指的是两个张量之间进行连接。

参数:inputs:  待连接的两个张量序列,要在张量间连接两个单独的张量,可以采用inputs=[tensor_1,tensor_2]的方法。

           dim: 拼接的维度,下面将区分两种拼接维度的区别,一看就懂,易上手。

 示例代码:

import torch
tensor_a = torch.tensor([[1,2],[2,3],[3,4],[4,5]])
tensor_b = torch.tensor([[0,1],[-1,0],[-2,-1],[-3,-2]])
tensor_c1 = torch.stack((tensor_a,tensor_b),dim=0)
tensor_c2 = torch.stack((tensor_a,tensor_b),dim=1)
print(tensor_c1)
print(tensor_c2)

dim=0时:

 

dim=1时:

 one-hot独热编码报错:one-hot is only applicable to index tensor.

解决办法:

检查传入参数是不是torch.int32类型的,如果是,请改为torch.int64。

参考博客:pytorch 独热编码报错的解决办法

                torch.stack()的官方解释,详解以及例子

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小鬼缠身、

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值