本篇博客的目的是:
将下图这样的输入(每个tensor表示一个句子,01为句子标签):

转化为下图所示的输出(batch_size=2)
元组的第一个元素为填充后的句子向量,第二个元素为句子长度,第三个元素为句子的label。

为什么需要这样的处理?
如果需要使用RNN模型处理序列数据,肯定不能将变长的序列直接输入模型,所以需要在输入前对其进行填充。这里需要注意的是,在有些情况下,输入数据不仅需要填充,并且需要在数据传送过程中记录句子的原始长度,例如在RNN中,如果句子长度差别较大,例如最大长度是50,但大多数句子长度<10,这样会导致很多句子中有很多填充的0,这会导致最后得到的hn是相同的。
第一步:建立Dataset
class SentenceDataSet(Dataset):
def __init__(self, sent, sent_label):
self.sent = sent
self.sent_label = sent_label
def __getitem__(self,

本文介绍了如何使用PyTorch的DataLoader处理不定长序列,特别是对于RNN模型输入前的填充需求。通过自定义Dataset和collate_fn函数,确保不同长度的句子可以被正确地组织成批次并记录原始长度。
最低0.47元/天 解锁文章
579





