1. PackedSequence
torch.nn.utils.rnn.PackedSequence
这个类的实例不能手动创建。它们只能被pack_padded_sequence() 实例化。
2. pack_padded_sequence
torch.nn.utils.rnn.pack_padded_sequence()**
输入:
input: [seq_length x batch_size x input_size] 或 [batch_size x seq_length x input_size],input中的seq要按照长度递减的方式排列。
lengths: seq的长度列表,是一个递减的列表,与input里的seq长度对应。ie. [5,4,1]
batch_first: bool变量,当它为True时,表示input为这种输入形式[batch_size x seq_length x input_size],否则为另一种。
输出:
一个PackedSequence对象,包含一个Var

本文总结了PyTorch中处理RNN序列数据的工具,包括PackedSequence类和pack_padded_sequence、pad_packed_sequence函数的使用,详细解释了如何通过这些工具进行序列填充和压缩。
最低0.47元/天 解锁文章
1141

被折叠的 条评论
为什么被折叠?



