绘制自注意力:
import os
import matplotlib.pyplot as plt
import seaborn as sns
def save_attention_heatmaps(attn_output_weights, save_path, bs, swt=False):
    """Save one self-attention heatmap PNG per (batch element, head).

    Args:
        attn_output_weights: attention tensor of shape
            (bs * num_heads, tgt_len, src_len); ``bs`` must evenly divide
            the first dimension.
        save_path: root directory; per-batch subfolders are created below it.
        bs: batch size.
        swt: switch flag; when False the function is a no-op.

    Returns:
        None. Side effect: writes
        ``save_path/batch_<i>/attention_head_<j>.png`` files.
    """
    if not swt:
        return
    # Recover the number of heads from the flattened leading dimension.
    bs_nh, tgt_len, src_len = attn_output_weights.shape
    nh = bs_nh // bs
    attn_output_weights = attn_output_weights.reshape(bs, nh, tgt_len, src_len)
    # Move off GPU and drop the autograd graph before converting to numpy.
    attn_output_weights = attn_output_weights.cpu().detach()
    # One folder per batch element.
    for batch_id in range(bs):
        batch_folder = os.path.join(save_path, f'batch_{batch_id}')
        # exist_ok avoids the check-then-create race of exists()+makedirs().
        os.makedirs(batch_folder, exist_ok=True)
        # One heatmap per attention head.
        for head_id in range(nh):
            plt.figure(figsize=(10, 8))
            sns.heatmap(attn_output_weights[batch_id, head_id].numpy(),
                        annot=False, cmap='viridis')
            plt.xlabel('Source Sequence')
            plt.ylabel('Target Sequence')
            plt.title(f'Batch {batch_id} - Attention Weights for Head {head_id}')
            plt.savefig(os.path.join(batch_folder, f'attention_head_{head_id}.png'))
            # Close the figure to avoid accumulating open figures in long loops.
            plt.close()
绘制对memory的交叉注意力(只绘制一个object query的,代码中为第50个):
def save_attention_heatmaps_for_img(attn_output_weights, save_path, bs,
                                    swt=False, hw=None, query_id=50):
    """Save cross-attention heatmaps over the image feature map for one query.

    For each (batch element, head), the attention row of a single object
    query is reshaped to the 2-D feature-map grid and saved as a PNG.

    Args:
        attn_output_weights: attention tensor of shape
            (bs * num_heads, tgt_len, src_len); ``bs`` must evenly divide
            the first dimension.
        save_path: root directory; per-batch subfolders are created below it.
        bs: batch size.
        swt: switch flag; when False the function is a no-op.
        hw: (h, w) of the feature map, with h * w == src_len. Required
            when ``swt`` is True.
        query_id: index of the object query whose attention map is drawn
            (default 50, matching the original hard-coded behavior).

    Returns:
        None. Side effect: writes
        ``save_path/batch_<i>/attention_head_<j>.png`` files.
    """
    if not swt:
        return
    # Recover the number of heads from the flattened leading dimension.
    hs_nh, tgt_len, src_len = attn_output_weights.shape
    nh = hs_nh // bs
    h, w = hw
    # Unflatten the spatial dimension so each query's row becomes an h*w grid.
    attn_output_weights = attn_output_weights.reshape(bs, nh, tgt_len, h, w)
    # Move off GPU and drop the autograd graph before converting to numpy.
    attn_output_weights = attn_output_weights.cpu().detach()
    # One folder per batch element.
    for batch_id in range(bs):
        batch_folder = os.path.join(save_path, f'batch_{batch_id}')
        # exist_ok avoids the check-then-create race of exists()+makedirs().
        os.makedirs(batch_folder, exist_ok=True)
        # One heatmap per attention head, for the selected object query.
        for head_id in range(nh):
            plt.figure(figsize=(10, 8))
            sns.heatmap(attn_output_weights[batch_id, head_id, query_id].numpy(),
                        annot=False, cmap='viridis')
            plt.xlabel('W')
            plt.ylabel('H')
            plt.title(f'Batch {batch_id} - Attention Weights for Head {head_id}')
            plt.savefig(os.path.join(batch_folder, f'attention_head_{head_id}.png'))
            # Close the figure to avoid accumulating open figures in long loops.
            plt.close()