pytorch DataLoader处理不定长序列

本文介绍了如何使用PyTorch的DataLoader处理不定长序列,特别是对于RNN模型输入前的填充需求。通过自定义Dataset和collate_fn函数,确保不同长度的句子可以被正确地组织成批次并记录原始长度。

本篇博客的目的是:
将下图这样的输入(每个tensor表示一个句子,01为句子标签):
在这里插入图片描述
转化为下图所示的输出(batch_size=2)
元组的第一个元素为填充后的句子向量,第二个元素为句子长度,第三个元素为句子的label。
在这里插入图片描述
为什么需要这样的处理?
    如果需要使用RNN模型处理序列数据,肯定不能将变长的序列直接输入模型,所以需要在输入前对其进行填充。这里需要注意的是,在有些情况下,输入数据不仅需要填充,并且需要在数据传送过程中记录句子的原始长度,例如在RNN中,如果句子长度差别较大,例如最大长度是50,但大多数句子长度<10,这样会导致很多句子中有很多填充的0,这会导致最后得到的hn是相同的。

第一步:建立Dataset
class SentenceDataSet(Dataset):
    def __init__(self, sent, sent_label):
        self.sent = sent
        self.sent_label = sent_label

    def __getitem__(self,
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值