Implementing and Visualizing Transformer Attention from Scratch in PyTorch (Additive Attention, Scaled Dot-Product Attention, Multi-Head Attention)

Table of Contents

Additive Attention

Scaled Dot-Product Attention

Multi-Head Attention

PyTorch Implementation


Additive Attention
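
For reference, the standard additive-attention scoring function (the d2l.ai formulation that the implementation below mirrors) combines a query \( \mathbf{q} \in \mathbb{R}^{q} \) and a key \( \mathbf{k} \in \mathbb{R}^{k} \) with a one-hidden-layer MLP:

\[ a(\mathbf{q}, \mathbf{k}) = \mathbf{w}_v^\top \tanh(\mathbf{W}_q \mathbf{q} + \mathbf{W}_k \mathbf{k}) \in \mathbb{R}, \]

where \( \mathbf{W}_q \in \mathbb{R}^{h \times q} \), \( \mathbf{W}_k \in \mathbb{R}^{h \times k} \) and \( \mathbf{w}_v \in \mathbb{R}^{h} \) are learnable parameters (\( h \) = num_hiddens). The attention output is the weighted sum of the values, \( \sum_i \mathrm{softmax}\big(a(\mathbf{q}, \mathbf{k}_i)\big)\,\mathbf{v}_i \).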


Scaled Dot-Product Attention
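
For reference, with a query matrix \( \mathbf{Q} \in \mathbb{R}^{n \times d} \), a key matrix \( \mathbf{K} \in \mathbb{R}^{m \times d} \) and a value matrix \( \mathbf{V} \in \mathbb{R}^{m \times v} \), scaled dot-product attention computes

\[ \mathrm{softmax}\!\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d}}\right)\mathbf{V} \in \mathbb{R}^{n \times v}. \]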

The formula above does not show batched operation; in practice each matrix gains a leading batch_size dimension, e.g. Q has shape (batch_size, n, d).


Multi-Head Attention
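
For reference, the standard multi-head formulation (from "Attention Is All You Need") projects the queries, keys and values \( h \) times with separate learned matrices, runs attention on every head in parallel, and concatenates the results:

\[ \mathrm{head}_i = \mathrm{Attention}(\mathbf{Q}\mathbf{W}_i^Q, \mathbf{K}\mathbf{W}_i^K, \mathbf{V}\mathbf{W}_i^V), \qquad \mathrm{MultiHead}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\,\mathbf{W}^O. \]

In the implementation below, the per-head projections are fused into single W_q, W_k and W_v linear layers of width num_hiddens, and transpose_qkv splits their outputs into the individual heads.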


PyTorch Implementation

from torch import nn
import torch
import math
import matplotlib.pyplot as plt


def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(6, 6),
                  cmap='Reds'):
    """
    显示矩阵热图.
    :param matrices: 4阶张量. (显示的行数,显示的列数,查询的数量,键的数量)
        表示有 行数 * 列数 个矩阵.
    :param xlabel: x轴标签.
    :param ylabel: y轴标签.
    :param titles: 每个图的名称,是一个list.
    :param figsize: 图大小.
    :param cmap: 热图颜色.
    :return:
    """
    # 显示的行数,显示的列数. fig上的子图数  num_rows * num_cols.
    num_rows, num_cols = matrices.shape[0], matrices.shape[1]
    fig, axes = plt.subplots(num_rows, num_cols, figsize=figsize, sharex=True, sharey=True, squeeze=False)
    # 绘制第i行第j列个子图.
    for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):  # 遍历matrices的第一个维度:行数.
        for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):  # 遍历matrices的第二个维度:列数.
            pcm = ax.imshow(matrix.detach().numpy(), cmap=cmap)  # 绘制特例图.
            ax.set_xlabel(xlabel)  # 为每个子图绘制x轴标签.
            ax.set_ylabel(ylabel)  # 为每个子图绘制y轴标签.
            if titles:
                ax.set_title(titles[j])  # 为每个子图绘制标题.
            fig.colorbar(pcm, ax=ax, shrink=0.6)  # 绘制彩色条.
    plt.show()
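
# A minimal usage sketch added for illustration (demo_show_heatmaps is not part of the
# original code): an identity matrix reshaped to the expected 4-D layout should render
# as a single diagonal heatmap, which is a quick way to verify the plotting helper.
def demo_show_heatmaps():
    attention_weights = torch.eye(10).reshape((1, 1, 10, 10))
    show_heatmaps(attention_weights, xlabel='Keys', ylabel='Queries')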


def masked_softmax(X, valid_lens):
    """Perform a softmax operation while masking elements along the last axis."""
    # X: 3-D tensor, valid_lens: 1-D or 2-D tensor.
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # Masked positions on the last axis are replaced with a very large negative
        # value, so that their softmax output becomes 0.
        X = X.reshape(-1, shape[-1])

        max_len = X.size(1)
        mask = torch.arange((max_len), dtype=torch.float32,
                            device=X.device)[None, :] < valid_lens[:, None]
        X[~mask] = -1e6

        return nn.functional.softmax(X.reshape(shape), dim=-1)
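
# An added sketch (not in the original code) illustrating masked_softmax: with a 1-D
# valid_lens of [2, 3], every row of the first batch element keeps 2 non-zero weights
# and every row of the second keeps 3; the masked positions come out as exactly 0.
def demo_masked_softmax():
    X = torch.rand(2, 2, 4)
    print(masked_softmax(X, torch.tensor([2, 3])))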


class AdditiveAttention(nn.Module):
    """加性注意力"""

    def __init__(self, query_size, key_size, num_hiddens, dropout, **kwargs):
        """
            Define the network structure.
        :param query_size: dimension of the queries.
        :param key_size: dimension of the keys.
        :param num_hiddens: number of hidden units.
        :param dropout: dropout probability.
        :param kwargs:
        """
        super(AdditiveAttention, self).__init__(**kwargs)
        # The previous layer has query_size units and this layer has num_hiddens units,
        # so the weight matrix has shape W_q.shape = (num_hiddens, query_size).
        # For an input A, the layer computes A * W_q.T = B.
        # Shapes: (batch_size, query_size) * (num_hiddens, query_size).T = (batch_size, num_hiddens)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)  # Note: W_k is a network layer here, not a raw weight matrix.
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        """
            前馈网络计算.
        :param queries: 查询.
        :param keys: 键.
        :param values: 值.
        :param valid_lens: 有效长度.
        :return:
        """

        # W_q.shape = (h, q).
        # W_k.shape = (h, q).
        # queries.shape = (batch_size, T1, q).
        # keys.shape = (batch_size, T2, k).
        # 矩阵计算如下:queries * W_q.T,keys * W_k.T .
        # 计算后的queries维度 (batch_size, T1, q) * (h,q).T = (batch_size, T1, h).
        # keys的维度 (batch_size, T2, k) * (h,k).T = (batch_size, T2, h).
        queries, keys = self.W_q(queries), self.W_k(keys)

        # unsqueeze(n)表示第dim=n维度进行扩展,n从0开始.
        # queries.shape = (batch_size, T1, h),在dim=2扩展后 shape = (batch_size, T1, 1, h)
        # keys.shape = (batch_size, T2, h),在dim=1扩展后,shape = (batch_size, 1, T2, h)
        queries = queries.unsqueeze(2)
        keys = keys.unsqueeze(1)
        features = queries + keys  # 广播求和,features.shape = (batch_size, T1, T2, h)
        features = torch.tanh(features)

        # w_v的形状(1,h),矩阵计算:features * w_v.T .
        # 计算后的维度:(batch_size, T1, T2, h) * (1,h).T = (batch_size, T1, T2, 1)
        # 最后维度为1,压缩,得到scores.shape = (batch_size, T1, T2).
        scores = self.w_v(features).squeeze(-1)

        # softmax不会改变维度,attention_weights.shape = (batch_size, T1, T2).
        # 如果不掩码,可以使用如下代码直接计算softmax:
        # self.attention_weights = torch.softmax(scores, dim=-1)
        # 计算结果为:
        # tensor([[[0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
        #           0.1000, 0.1000]],
        #
        #         [[0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
        #           0.1000, 0.1000]]], grad_fn= < SoftmaxBackward0 >)

        # 掩码操作scores.shape = (batch_size, T1, T2),对T1,T2两个维度进行掩码,
        # 如果valid_lens是一阶张量,对T1这个维度进行掩码.
        # 如果valid_lens是二阶张量,对(T1,T2)这两个维度进掩码.
        # 例如valid_lens = [2,6],为一阶张量,对T1进行掩码,T1=0时,保留有效长度为2,T1=1时,保留有效长度为6.
        # 掩码前:
        # tensor([[[-0.5469, -0.5469, -0.5469, -0.5469, -0.5469, -0.5469, -0.5469,
        #           -0.5469, -0.5469, -0.5469]],
        #
        #         [[ 0.0430,  0.0430,  0.0430,  0.0430,  0.0430,  0.0430,  0.0430,
        #            0.0430,  0.0430,  0.0430]]], grad_fn=<SqueezeBackward1>)
        # 掩码后:
        # tensor([[[0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        #           0.0000, 0.0000]],
        #
        #         [[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000,
        #           0.0000, 0.0000]]], grad_fn=<SoftmaxBackward0>)
        self.attention_weights = masked_softmax(scores, valid_lens)

        # bmm为矩阵乘法的批量操作,两个张量的batch_size要相等.
        # attention_weights.shape = (batch_size, T1, T2)
        # values.shape = (batch_size, T2, v).
        # 首先保持batch_size这个不变,剩下的维度按照矩阵乘法计算,计算后维度为(T1, T2) * (T2, v) = (T1, v)
        # 经过bmm计算后,形状shape=(batch_size, T1, v).
        return torch.bmm(self.dropout(self.attention_weights), values)


class DotProductAttention(nn.Module):
    """缩放点积注意力"""

    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    # queries的形状:(batch_size,查询的个数,d)
    # keys的形状:(batch_size,“键-值”对的个数,d)
    # values的形状:(batch_size,“键-值”对的个数,值的维度)
    # valid_lens的形状:(batch_size,)或者(batch_size,查询的个数)
    def forward(self, queries, keys, values, valid_lens=None):
        """

        :param queries: (batch_size,n,d)
        :param keys: (batch_size,m,d)
        :param values: (batch_size,m,v)
        :param valid_lens:
        :return:
        """
        d = queries.shape[-1]
        # (batch_size,n,m)
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        # (batch_size,n,m)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # (batch_size,n,m) * (batch_size,m,v) = (batch_size,n,v)
        return torch.bmm(self.dropout(self.attention_weights), values)
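
# Optional cross-check added here as a sketch (assumes PyTorch >= 2.0, where
# torch.nn.functional.scaled_dot_product_attention is available): with no masking and
# dropout disabled, DotProductAttention should match the built-in implementation.
def demo_dot_product_attention_check():
    q, k, v = torch.rand(2, 3, 4), torch.rand(2, 5, 4), torch.rand(2, 5, 6)
    attention = DotProductAttention(dropout=0.0)
    attention.eval()
    ours = attention(q, k, v, valid_lens=None)
    builtin = nn.functional.scaled_dot_product_attention(q, k, v)
    print(torch.allclose(ours, builtin, atol=1e-6))  # Expected: True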


class MultiHeadAttention(nn.Module):
    """多头注意力"""

    def __init__(self, query_size, key_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads  # 多头注意力,头的个数.
        self.attention = DotProductAttention(dropout)  # 初始化点乘注意力.
        # 输入维度:(batch_size,query_size),输出维度:(batch_size,num_hiddens),
        # 权重矩阵维度:(num_hiddens, query_size)
        # 例如 A * W.T = B.
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)  # 查询query的线性投射.
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)  # 键key的线性投射.
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)  # 值value的线性投射.
        self.W_o = nn.Linear(num_hiddens, value_size, bias=bias)  # 多头head连接后的线性投射.

    def forward(self, queries, keys, values, valid_lens):
        # 为了多头也可以并行计算,对queries、keys、values进行维度变换.
        # queries.shape=(batch_size, T1, q).
        # W_q.shape = (num_hiddens, q).
        # self.W_q(queries)的维度=(batch_size, T1, num_hiddens).
        # 变换后queries.shape=(batch_size * h, T1, num_hiddens/h).
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        # keys.shape=(batch_size, T2, k).
        # W_k.shape = (num_hiddens, k).
        # self.W_k(keys)的维度=(batch_size, T2, num_hiddens).
        # 变换后 keys.shape=(batch_size * h, T2, num_hiddens/h).
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        # values.shape=(batch_size, T2, v).
        # W_v.shape = (num_hiddens, v).
        # self.W_v(values)的维度=(batch_size, T2, num_hiddens).
        # 变换后values.shape=(batch_size * h, T2, num_hiddens/h).
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # 在轴0,将第一项(标量或者矢量)复制num_heads次,
            # 然后如此复制第二项,然后诸如此类.
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # 点乘注意力.
        # output的形状:(batch_size * h, T1, num_hiddens/h).
        output = self.attention(queries, keys, values, valid_lens)

        # 多头并行计算结束后,再将维度变换回来.
        # 经过变换后output_concat.shape = (batch_size, T1, num_hiddens).
        output_concat = transpose_output(output, self.num_heads)

        # output_concat.shape = (batch_size, T1, num_hiddens).
        # W_o.shape = (batch_size, v, num_hiddens).
        # 经过矩阵计算output_concat * W_o.T
        # 最终返回的维度:(batch_size, T1, v)
        return self.W_o(output_concat)
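
# An added sketch (not in the original code): num_hiddens must be divisible by num_heads,
# because transpose_qkv splits the hidden dimension into num_heads chunks of size
# num_hiddens // num_heads. A quick self-attention shape check:
def demo_multi_head_shapes():
    mha = MultiHeadAttention(query_size=16, key_size=16, value_size=16,
                             num_hiddens=32, num_heads=4, dropout=0.0)
    X = torch.rand(2, 5, 16)
    print(mha(X, X, X, valid_lens=None).shape)  # Expected: torch.Size([2, 5, 16])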


# @save
def transpose_qkv(X, num_heads):
    """
    为了多注意力头的并行计算而变换形状
    :param X: (batch_size, T, d)
    :param num_heads: h
    :return: X,X.shape = (batch_size * h, T, d/h)
    """
    # 如果X的形状为:(batch_size, T, d)
    # 变换后X的形状为:(batch_size, T, h, d/h)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)

    # 安装permute函数中索引的值,设置X的维度.
    # (batch_size, T, h, d/h) 变为:(batch_size, h, T, d/h)
    X = X.permute(0, 2, 1, 3)

    # 最终输出的形状:(batch_size * h, T, d/h)
    return X.reshape(-1, X.shape[2], X.shape[3])


def transpose_output(X, num_heads):
    """
    逆转transpose_qkv函数的操作.
    :param X: X.shape = (batch_size * h, T, d/h)
    :param num_heads: h
    :return: X,X.shape = (batch_size, T, d)
    """
    # 如果X.shape = (batch_size * h, T, d/h).
    # 变换后X.shape = (batch_size, h, T, d/h).
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    # 维度互换.
    # 变换后X.shape = (batch_size, T, h, d/h).
    X = X.permute(0, 2, 1, 3)

    # 最终X.shape = (batch_size, T, d).
    return X.reshape(X.shape[0], X.shape[1], -1)
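
# A shape sanity check added as a sketch (not in the original code): transpose_output
# inverts transpose_qkv, so the round trip restores the original (batch_size, T, d) tensor.
def demo_transpose_round_trip():
    X = torch.rand(2, 7, 12)               # d = 12 is divisible by num_heads = 4.
    Y = transpose_qkv(X, num_heads=4)      # Shape: (2 * 4, 7, 3).
    Z = transpose_output(Y, num_heads=4)   # Shape: (2, 7, 12).
    print(Y.shape, Z.shape, torch.equal(X, Z))  # Expected equality: True.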


def test_additive_attention():
    # queries.shape = (batch_size, T1, q), where T1 = 1.
    # Meaning: the query sequence has length T1 (T1 tokens), each token has feature
    # dimension q, and batch_size sequences are fed at once.
    # keys.shape = (batch_size, T2, k), where T2 = 10.
    # Meaning: the key sequence has length T2 (T2 tokens), each token has feature
    # dimension k, and batch_size sequences are fed at once.
    queries, keys = torch.normal(0, 1, (2, 1, 20)), torch.ones((2, 10, 2))
    print('queries=', queries)
    print('keys=', keys)

    # values.shape = (batch_size, T2, v), since keys and values form pairs.
    values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
    print('values=', values)

    # Initialize the additive attention module.
    attention = AdditiveAttention(query_size=queries.shape[2],
                                  key_size=keys.shape[2],
                                  num_hiddens=8,
                                  dropout=0.1)
    # Print the structure of the attention module:
    val = attention.eval()
    print('val=', val)

    valid_lens = torch.tensor([2, 6])
    # result.shape = (batch_size, T1, v).
    attention_val = attention(queries, keys, values, valid_lens)
    print('attention_val=', attention_val)
    show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
                  xlabel='Keys', ylabel='Queries', titles=['AdditiveAttention'])


def test_dot_product_attention():
    queries, keys = torch.normal(0, 1, (2, 1, 2)), torch.ones((2, 10, 2))
    print('queries=', queries)
    print('keys=', keys)

    # values.shape = (batch_size, T2, v), since keys and values form pairs.
    values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
    print('values=', values)

    attention = DotProductAttention(dropout=0.5)
    val = attention.eval()
    print('val=', val)

    valid_lens = torch.tensor([2, 6])
    attention_val = attention(queries, keys, values, valid_lens)
    print('attention_val=', attention_val)

    show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
                  xlabel='Keys', ylabel='Queries', titles=['DotProductAttention'])


def test_multi_head_attention():
    # queries.shape = (batch_size, T1, q).
    batch_size, num_queries = 2, 4
    query_size = 20  # q = 20.
    queries = torch.ones((batch_size, num_queries, query_size))

    # keys.shape = (batch_size, T2, k).
    num_kv_pairs = 6
    key_size = 20  # k = 20.
    keys = torch.ones((batch_size, num_kv_pairs, key_size))
    # values.shape = (batch_size, T2, v).
    value_size = 30  # v = 30.
    values = torch.ones((batch_size, num_kv_pairs, value_size))

    valid_lens = torch.tensor([3, 2])
    num_hiddens, num_heads = 100, 5
    attention = MultiHeadAttention(query_size=query_size,
                                   key_size=key_size,
                                   value_size=value_size,
                                   num_hiddens=num_hiddens,
                                   num_heads=num_heads,
                                   dropout=0.5)
    val = attention.eval()
    print('val=', val)
    # The final returned shape is (batch_size, T1, v), since W_o projects back to value_size.
    attention_val = attention(queries, keys, values, valid_lens)
    # [batch_size, T1, v]
    print('attention_val.shape=', attention_val.shape)


if __name__ == '__main__':
    # Additive attention test.
    print('+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++')
    test_additive_attention()

    # Scaled dot-product attention test.
    print('+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++')
    test_dot_product_attention()

    # Multi-head attention test.
    print('+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++')
    test_multi_head_attention()

Program output:

+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
queries= tensor([[[-0.8907, -1.3623,  0.1130, -1.2687, -1.4991, -0.3159,  0.3349,
          -0.7298, -1.8507, -2.5918, -0.2319, -1.3374,  0.5941,  0.9423,
          -0.5563,  0.4634, -0.1660, -1.4720, -1.6535,  0.2234]],

        [[ 0.5545, -1.1212, -1.6501, -0.5919,  1.4009,  0.9828,  0.4245,
          -1.4334,  0.1970, -0.4550, -2.1928,  0.1660, -1.1109, -1.5174,
          -1.3909,  1.8204, -0.8928, -0.6160, -0.2877, -0.0853]]])
keys= tensor([[[1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.]]])
values= tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.],
         [12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.],
         [24., 25., 26., 27.],
         [28., 29., 30., 31.],
         [32., 33., 34., 35.],
         [36., 37., 38., 39.]],

        [[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.],
         [12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.],
         [24., 25., 26., 27.],
         [28., 29., 30., 31.],
         [32., 33., 34., 35.],
         [36., 37., 38., 39.]]])
val= AdditiveAttention(
  (W_q): Linear(in_features=20, out_features=8, bias=False)
  (W_k): Linear(in_features=2, out_features=8, bias=False)
  (w_v): Linear(in_features=8, out_features=1, bias=False)
  (dropout): Dropout(p=0.1, inplace=False)
)
attention_val= tensor([[[ 2.0000,  3.0000,  4.0000,  5.0000]],

        [[10.0000, 11.0000, 12.0000, 13.0000]]], grad_fn=<BmmBackward0>)
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
queries= tensor([[[ 0.1392, -0.7757]],

        [[-1.8050, -0.4379]]])
keys= tensor([[[1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.]]])
values= tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.],
         [12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.],
         [24., 25., 26., 27.],
         [28., 29., 30., 31.],
         [32., 33., 34., 35.],
         [36., 37., 38., 39.]],

        [[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.],
         [12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.],
         [24., 25., 26., 27.],
         [28., 29., 30., 31.],
         [32., 33., 34., 35.],
         [36., 37., 38., 39.]]])
val= DotProductAttention(
  (dropout): Dropout(p=0.5, inplace=False)
)
attention_val= tensor([[[ 2.0000,  3.0000,  4.0000,  5.0000]],

        [[10.0000, 11.0000, 12.0000, 13.0000]]])
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
val= MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (W_q): Linear(in_features=20, out_features=100, bias=False)
  (W_k): Linear(in_features=20, out_features=100, bias=False)
  (W_v): Linear(in_features=30, out_features=100, bias=False)
  (W_o): Linear(in_features=100, out_features=30, bias=False)
)
attention_val.shape= torch.Size([2, 4, 30])

Visualization: the attention-weight heatmaps rendered by show_heatmaps for the additive attention and scaled dot-product attention tests.
