pad_sequence
填充可变长度张量列表
例子
>>> from torch.nn.utils.rnn import pad_sequence
>>> a = torch.ones(25, 300)
>>> b = torch.ones(22, 300)
>>> c = torch.ones(15, 300)
>>> pad_sequence([a, b, c]).size()
torch.Size([25, 3, 300])
参数
-
batch_first ( bool , optional ) –张量的位置.如果为真,则输出
B x T x *
,否则为T x B x *.
-
padding_value ( float , optional ) -- 填充元素的值。默认值:0。
nn.utils.rnn.pack_padded_sequence
torch.nn.utils.rnn.pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True)
打包一个包含可变长度填充序列的张量。
input
可以是大小[T, B, *
],其中T是最长序列的长度(等于),是批量大小,并且*是任意数量的维度(包括 0)。
对于未排序的序列,使用enforce_sorted = False。如果enforce_sorted
是 True
,则序列应按长度降序排序,即 input[:,0]
应该是最长的序列,input[:,B-1]
也是最短的序列。enforce_sorted = True仅用于 ONNX 导出。
例子
tensor([[1, 2, 3, 4],
[9, 0, 0, 0]])
# 使用候变成下面这样
PackedSequence(data=tensor([1, 9, 2, 3, 4]),
batch_sizes=tensor([2, 1, 1, 1]),
sorted_indices=None, unsorted_indices=None)
参数
-
输入( Tensor ) – 填充的可变长度序列批次。
-
lengths ( Tensor or list ( int ) ) – 每个批处理元素的序列长度列表(如果作为张量提供,则必须在 CPU 上)。
-
batch_first ( bool , optional ) – 同上
-
enforce_sorted ( bool , optional ) – 如果
True
是,则输入应包含按长度降序排序的序列。如果False
,输入将无条件排序。默认值:True
。
nn.utils.rnn.pad_packed_sequence
填充一组打包的可变长度序列。nn.utils.rnn.pack_padded_sequence的逆向操作.
例子
>>> from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
>>> seq = torch.tensor([[1,2,0], [3,0,0], [4,5,6]])
>>> lens = [2, 1, 3]
>>> packed = pack_padded_sequence(seq, lens, batch_first=True, enforce_sorted=False)
>>> packed
PackedSequence(data=tensor([4, 1, 3, 5, 2, 6]), batch_sizes=tensor([3, 2, 1]),
sorted_indices=tensor([2, 0, 1]), unsorted_indices=tensor([1, 2, 0]))
>>> seq_unpacked, lens_unpacked = pad_packed_sequence(packed, batch_first=True)
>>> seq_unpacked
tensor([[1, 2, 0],
[3, 0, 0],
[4, 5, 6]])
>>> lens_unpacked
tensor([2, 1, 3])
参数
-
序列( PackedSequence ) – 批处理到填充
-
batch_first ( bool , optional ) – 同上
-
padding_value ( float , optional ) -- 填充元素的值。
-
total_length ( int , optional ) – 如果不是
None
,输出将被填充为具有长度total_length
。