【PyTorch】Tensor.masked_fill() 方法:根据布尔掩码(mask)将张量中部分元素替换为指定值

Tensor.masked_fill() 方法详解

在 PyTorch 中,masked_fill() 是一个用于 根据布尔掩码(mask)将张量中部分元素替换为指定值 的方法。它常用于:

  • 遮蔽(mask)掉某些无效元素(如 attention mask)
  • 对填充值进行处理
  • 实现数值屏蔽(如用 -inf 屏蔽 softmax 中的无效位置)

1. 函数原型

Tensor.masked_fill(mask, value) → Tensor
参数说明
mask一个与原张量 shape 可广播的 布尔类型张量True 的位置将被替换
value替换值,数值标量(如 0, -1e9, -inf 等)

2. 基本示例

import torch

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

mask = x > 3

# 将大于 3 的位置替换为 0
result = x.masked_fill(mask, 0)

print(result)

输出:

tensor([[1, 2, 3],
        [0, 0, 0]])

3. 示例:填充为负无穷(-inf)用于屏蔽 attention

scores = torch.tensor([[0.1, 0.3, 0.6],
                       [0.4, 0.2, 0.4]])

mask = torch.tensor([[False, True, False],
                     [False, False, True]])

masked_scores = scores.masked_fill(mask, float('-inf'))
print(masked_scores)

输出:

tensor([[ 0.1000,   -inf,  0.6000],
        [ 0.4000,  0.2000,   -inf]])

4. 广播兼容的 mask

x = torch.ones(2, 3, 4)
mask = torch.tensor([True, False, True]).view(1, 3, 1)

x_masked = x.masked_fill(mask, -1)
print(x_masked.shape)  # torch.Size([2, 3, 4])

只对中间维度为 True 的部分填充为 -1,其余不变。


5. 常见应用场景

场景用法
注意力机制中屏蔽无效位置scores.masked_fill(mask, -inf)
掩盖 padding 部分的输入x.masked_fill(pad_mask, 0)
实现 logits 屏蔽logits.masked_fill(mask, -1e9)
在 loss 中屏蔽无效标签位置loss.masked_fill(ignore_mask, 0)

6. 与其他方法的对比

方法功能是否返回新张量
masked_fill()替换为某个值
masked_select()筛选出满足条件的元素,返回 1D 张量
masked_scatter()将指定位置替换为另一个张量的值是(更复杂)

7. 注意事项

  • mask 必须为 torch.bool 类型,或者是布尔表达式的结果。
  • value 是标量,不能是张量(使用 masked_scatter 可实现张量赋值)。
  • 会返回一个 新的张量,原张量不变。

8. 总结

特性说明
功能将满足条件的位置替换为指定值
输入mask(布尔张量),value(标量)
返回新的张量,原张量不变
常用场景attention mask、padding mask、loss 屏蔽

masked_fill() 是在构建 Transformer、seq2seq、文本处理等深度学习模型中非常常用的函数,它是高效地实现“条件赋值”的利器。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

彬彬侠

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值