绘制注意力热力图

绘制自注意力:

import os
import matplotlib.pyplot as plt
import seaborn as sns

def save_attention_heatmaps(attn_output_weights, save_path, bs, swt=False):
    """Save one self-attention heatmap PNG per (batch item, head).

    Args:
        attn_output_weights: Attention weights of shape
            (bs * num_heads, tgt_len, src_len). Dim 0 is split into
            (bs, num_heads) by ``reshape``, so the batch index must vary
            slowest along it.
        save_path: Root directory; a ``batch_{i}`` sub-folder is created for
            each batch item and ``attention_head_{j}.png`` is written into it.
        bs: Batch size.
        swt: Plotting switch. When False the function returns immediately
            without touching the tensor or the filesystem.

    Raises:
        ValueError: If ``bs`` is not positive or does not evenly divide the
            first dimension of ``attn_output_weights``.
    """
    if not swt:
        return

    # Recover the number of heads from the fused (bs * num_heads) dimension.
    bs_nh, tgt_len, src_len = attn_output_weights.shape
    if bs <= 0 or bs_nh % bs != 0:
        raise ValueError(
            f"first dim {bs_nh} of attn_output_weights is not a multiple of bs={bs}"
        )
    nh = bs_nh // bs

    # Detach from autograd and move to host once, up front. The float32 cast
    # lets .numpy() succeed for bfloat16 weights, which NumPy cannot represent.
    weights = (
        attn_output_weights.detach()
        .reshape(bs, nh, tgt_len, src_len)
        .cpu()
        .float()
    )

    for batch_id in range(bs):
        batch_folder = os.path.join(save_path, f'batch_{batch_id}')
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(batch_folder, exist_ok=True)

        for head_id in range(nh):
            fig, ax = plt.subplots(figsize=(10, 8))
            try:
                sns.heatmap(weights[batch_id, head_id].numpy(), annot=False, cmap='viridis', ax=ax)
                ax.set_xlabel('Source Sequence')
                ax.set_ylabel('Target Sequence')
                ax.set_title(f'Batch {batch_id} - Attention Weights for Head {head_id}')
                fig.savefig(os.path.join(batch_folder, f'attention_head_{head_id}.png'))
            finally:
                # Release the figure even if plotting/saving fails, so repeated
                # calls do not accumulate open figures.
                plt.close(fig)

绘制对 memory 的交叉注意力(只绘制一个 object query 的注意力图;代码中使用索引 50 的 query,即按 0 起始计数的第 51 个):

def save_attention_heatmaps_for_img(attn_output_weights, save_path, bs, swt=False, hw=None, query_idx=50):
    """Save one spatial cross-attention heatmap PNG per (batch item, head) for a single query.

    Each query's attention over ``src_len`` memory positions is reshaped to an
    (h, w) grid and plotted as an image-shaped heatmap.

    Args:
        attn_output_weights: Attention weights of shape
            (bs * num_heads, tgt_len, src_len). Dim 0 is split into
            (bs, num_heads) by ``reshape``, so the batch index must vary
            slowest along it.
        save_path: Root directory; a ``batch_{i}`` sub-folder is created for
            each batch item and ``attention_head_{j}.png`` is written into it.
            These are the same file names ``save_attention_heatmaps`` writes,
            so use a different ``save_path`` for the two to avoid overwrites.
        bs: Batch size.
        swt: Plotting switch. When False the function returns immediately
            without touching the tensor or the filesystem.
        hw: (h, w) of the feature map; ``h * w`` must equal ``src_len``.
            Required when ``swt`` is True.
        query_idx: Index along ``tgt_len`` of the query to plot. Defaults to
            50 (the value that was previously hard-coded); negative indices
            count from the end.

    Raises:
        ValueError: If ``bs`` is not positive or does not divide dim 0, if
            ``hw`` is missing, or if ``h * w != src_len``.
        IndexError: If ``query_idx`` is out of range for ``tgt_len``.
    """
    if not swt:
        return

    # Recover the number of heads from the fused (bs * num_heads) dimension.
    bs_nh, tgt_len, src_len = attn_output_weights.shape
    if bs <= 0 or bs_nh % bs != 0:
        raise ValueError(
            f"first dim {bs_nh} of attn_output_weights is not a multiple of bs={bs}"
        )
    nh = bs_nh // bs

    if hw is None:
        raise ValueError("hw=(h, w) is required to reshape src_len into a feature map")
    h, w = hw
    if h * w != src_len:
        raise ValueError(f"hw={tuple(hw)} gives h*w={h * w}, expected src_len={src_len}")
    if not -tgt_len <= query_idx < tgt_len:
        raise IndexError(f"query_idx={query_idx} out of range for tgt_len={tgt_len}")

    # Detach from autograd and move to host once, up front. The float32 cast
    # lets .numpy() succeed for bfloat16 weights, which NumPy cannot represent.
    weights = (
        attn_output_weights.detach()
        .reshape(bs, nh, tgt_len, h, w)
        .cpu()
        .float()
    )

    for batch_id in range(bs):
        batch_folder = os.path.join(save_path, f'batch_{batch_id}')
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(batch_folder, exist_ok=True)

        for head_id in range(nh):
            fig, ax = plt.subplots(figsize=(10, 8))
            try:
                # (h, w) attention map of the selected query for this head.
                sns.heatmap(weights[batch_id, head_id, query_idx].numpy(), annot=False, cmap='viridis', ax=ax)
                ax.set_xlabel('W')
                ax.set_ylabel('H')
                ax.set_title(f'Batch {batch_id} - Attention Weights for Head {head_id}')
                fig.savefig(os.path.join(batch_folder, f'attention_head_{head_id}.png'))
            finally:
                # Release the figure even if plotting/saving fails, so repeated
                # calls do not accumulate open figures.
                plt.close(fig)

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值