HeatMap attention highlighting

This article shows how to visualize attention matrices in MATLAB, covering both text-attention heatmaps and highlighting the attended regions of an image. Concrete code demonstrates how to load an attention matrix, rescale the attention weights, and overlay them on an image, ultimately visualizing how much attention a specific word receives.


% load attention matrices h_img, img_h, p_h, h_p (9824*49*49)
clear;
load('attention.mat');
% parameters
example_index = 2734;  % MATLAB array indices start from 1
h_word_index = 3;      % index of the key word in the hypothesis
end_index = 11;        % length of the premise sentence

text_attention = h_p(example_index,h_word_index,1:end_index);  % attention from one hypothesis word over the premise
text_attention = reshape(text_attention,1,end_index);          % 1 x end_index row vector
text_attention = mapminmax(text_attention, -0.4, 1);           % rescale to [-0.4, 1] for display contrast
h = HeatMap(text_attention,'Colormap',gray);
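% Optional: label each heatmap column with its premise word so the plot is
% readable. The word list below is a hypothetical placeholder (substitute the
% actual premise tokens for example_index); HeatMap accepts 'ColumnLabels':
% premise_words = {'a','man','in','a','red','shirt','rides','a','bike','down','hill'};  % 11 hypothetical tokens
% h = HeatMap(text_attention, 'ColumnLabels', premise_words, 'Colormap', gray);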


image_name = '4455622662.jpg';

image_example = imread(image_name);
imshow(image_example);

image_size = size(image_example);

height = image_size(1);  % size() returns [rows, cols, channels]
width = image_size(2);

% attention resize
% image_attention = h_img(example_index,h_word_index,:);        % keyword attention, 1*1*49
% image_attention = reshape(image_attention,7,7)';              % 7*7
% image_attention_norm_ori = mapminmax(image_attention, 0, 1);


load('img_norm_light.mat');  % provides image_attention_norm_ori

image_attention_norm = imresize(image_attention_norm_ori, [height width]);  % imresize expects [numrows numcols]
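% imresize uses bicubic interpolation by default, which smooths the 7x7 map
% into soft gradients; pass 'nearest' instead to keep the blocky region
% structure visible:
% image_attention_norm = imresize(image_attention_norm_ori, [height width], 'nearest');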


% FlattenedData = image_attention(:)';                               % flatten the matrix into a single row
% MappedFlattened = mapminmax(FlattenedData, 0, 1);                  % normalize to [0, 1]
% image_attention_norm_ori = reshape(MappedFlattened, size(image_attention));
% image_attention_norm = imresize(image_attention_norm_ori, [height width]);

% heatmap of the resized attention map
% imshow(image_attention_norm);
hmap = HeatMap(image_attention_norm);

% replicate the attention map into the three RGB channels
im_cam_tri = zeros(height, width, 3);

im_cam_tri(:, :, 1) = image_attention_norm * 255;
im_cam_tri(:, :, 2) = image_attention_norm * 255;
im_cam_tri(:, :, 3) = image_attention_norm * 255;
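% The three channel assignments above are equivalent to a single repmat call:
% im_cam_tri = repmat(image_attention_norm * 255, [1 1 3]);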

% for k = 1:3
%     im_cam_tri(:,:,k) = flipud(im_cam_tri(:,:,k));  % flip vertically
% end


new_image = 0.05 * double(image_example) + 0.95 * im_cam_tri;  % blend 5% original image with 95% attention map

figure;
imshow(new_image/255);  % scale to [0, 1] for imshow on doubles

% alternative: keep only the strongly attended pixels via a hard threshold
% new_image = double(image_example) .* (im_cam_tri > 200);
% imshow(new_image/255);
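% Optional colored overlay: map the attention through the jet colormap instead
% of the grayscale blend above (a sketch, assuming image_attention_norm is
% roughly in [0, 1]; indices are clamped to the colormap range to be safe).
cmap = jet(256);
idx = min(max(round(image_attention_norm * 255) + 1, 1), 256);
heat_rgb = ind2rgb(idx, cmap);  % height x width x 3 RGB image with values in [0, 1]
overlay = 0.5 * double(image_example) / 255 + 0.5 * heat_rgb;
figure;
imshow(overlay);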