最近写算法的时候发现网上关于BiLSTM加Attention的实现方式五花八门,其中很多是错的,自己基于PyTorch框架实现了一版,主要用到了LSTM处理变长序列和masked softmax两个技巧。代码如下:
1、attention_utils.py
from typing import Dict, Optional
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
def create_src_lengths_mask(
batch_size: int, src_lengths: Tensor, max_src_len: Optional[int] = None
):
"""
Generate boolean mask to prevent attention beyond the end of source
Inputs:
batch_size : int
src_lengths : [batch_size] of sentence lengths
max_src_len: Optionally override max_src_len for the mask
Outputs:
[batch_size, max_src_len]
"""
if max_src_len is None:
max_src_len = int(src_lengths.max())
src_indices = torch.arange(0, max_src_len).unsqueeze(0).type_as(src_lengths)
src_indices = src_indices.expand(batch_size, max_src_len)
src_lengths = src_lengths.unsqueeze(dim=1).expand(batch_size, max_src_len)
# returns [batch_size, max_seq_len]
return (src_indices < src_lengths).int().detach()
def masked_softmax(scores, src_lengths, src_length_masking=True):
"""Apply source length masking then softmax.
Input and output have shape bsz x src_len"""
if src_length_masking:
bsz, max_src_len = scores.size()
# compute masks
src_mask = create_src_lengths_mask(bsz, src_le

本文介绍了如何在PyTorch中实现一个带有Attention机制的BiLSTM模型,用于处理变长序列。代码包括了创建源长度掩码、应用 masked softmax 以及 MLP 注意力网络的实现。测试结果显示了注意力分布和预测概率。
最低0.47元/天 解锁文章
3431

被折叠的 条评论
为什么被折叠?



