pad_sequence和nn.utils.rnn.pack_padded_sequence和nn.utils.rnn.pad_packed_sequence

本文介绍PyTorch中处理可变长度序列的核心方法,包括pad_sequence、pack_padded_sequence及pad_packed_sequence的使用技巧与示例。通过这些方法,可以有效地管理和优化序列数据的处理流程。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

pad_sequence

填充可变长度张量列表

例子

>>> from torch.nn.utils.rnn import pad_sequence
>>> a = torch.ones(25, 300)
>>> b = torch.ones(22, 300)
>>> c = torch.ones(15, 300)
>>> pad_sequence([a, b, c]).size()
torch.Size([25, 3, 300])

参数

  • 序列list [ Tensor ] ) – 可变长度序列的列表。

  • batch_first ( bool , optional ) –张量的位置.如果为真,则输出B x T x *,否则为T x B x *.

  • padding_value ( float , optional ) -- 填充元素的值。默认值:0。

nn.utils.rnn.pack_padded_sequence

torch.nn.utils.rnn.pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True)

打包一个包含可变长度填充序列的张量。

input可以是大小[TB*],其中T是最长序列的长度(等于),是批量大小,并且*是任意数量的维度(包括 0)。

对于未排序的序列,使用enforce_sorted = False。如果enforce_sorted是 True,则序列应按长度降序排序,即 input[:,0]应该是最长的序列,input[:,B-1]也是最短的序列。enforce_sorted = True仅用于 ONNX 导出。

例子

tensor([[1, 2, 3, 4],
        [9, 0, 0, 0]])

# 使用候变成下面这样
PackedSequence(data=tensor([1, 9, 2, 3, 4]), 
               batch_sizes=tensor([2, 1, 1, 1]), 
               sorted_indices=None, unsorted_indices=None)

参数

  • 输入Tensor ) – 填充的可变长度序列批次。

  • lengths ( Tensor or list ( int ) ) – 每个批处理元素的序列长度列表(如果作为张量提供,则必须在 CPU 上)。

  • batch_first ( bool , optional ) – 同上

  • enforce_sorted ( bool , optional ) – 如果True是,则输入应包含按长度降序排序的序列。如果 False,输入将无条件排序。默认值:True

nn.utils.rnn.pad_packed_sequence

填充一组打包的可变长度序列。nn.utils.rnn.pack_padded_sequence的逆向操作.

例子

>>> from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
>>> seq = torch.tensor([[1,2,0], [3,0,0], [4,5,6]])
>>> lens = [2, 1, 3]
>>> packed = pack_padded_sequence(seq, lens, batch_first=True, enforce_sorted=False)
>>> packed
PackedSequence(data=tensor([4, 1, 3, 5, 2, 6]), batch_sizes=tensor([3, 2, 1]),
               sorted_indices=tensor([2, 0, 1]), unsorted_indices=tensor([1, 2, 0]))
>>> seq_unpacked, lens_unpacked = pad_packed_sequence(packed, batch_first=True)
>>> seq_unpacked
tensor([[1, 2, 0],
        [3, 0, 0],
        [4, 5, 6]])
>>> lens_unpacked
tensor([2, 1, 3])

参数 

  • 序列PackedSequence ) – 批处理到填充

  • batch_first ( bool , optional ) – 同上

  • padding_value ( float , optional ) -- 填充元素的值。

  • total_length ( int , optional ) – 如果不是None,输出将被填充为具有长度total_length

官方文档:torch.nn.utils.rnn.pad_packed_sequence — PyTorch 1.10.1 documentationicon-default.png?t=LBL2https://pytorch.org/docs/stable/generated/torch.nn.utils.rnn.pad_packed_sequence.html?highlight=nn%20utils%20rnn%20pad_packed_sequence#torch.nn.utils.rnn.pad_packed_sequence 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值