### 三、BERT模型注意力机制分析
print("\n#### BERT模型注意力模式分析")
# 安全处理token ID的函数(使用纯TensorFlow操作)
def safe_token_ids(input_ids, vocab_size):
"""确保所有token ID在有效范围内(纯TensorFlow实现)"""
return tf.clip_by_value(input_ids, 0, vocab_size - 1)
# 获取注意力权重的函数(带安全检查)
def get_attention_weights(model, input_text, tokenizer, layer_idx=-1):
"""获取BERT模型指定层的注意力权重"""
# 编码文本
inputs = tokenizer(
input_text,
return_tensors="tf",
max_length=100,
padding="max_length",
truncation=True
)
# 获取vocab_size用于安全检查
vocab_size = tokenizer.vocab_size
# 安全处理input_ids(使用纯TensorFlow操作)
if 'input_ids' in inputs:
inputs['input_ids'] = tf.map_fn(
lambda x: safe_token_ids(x, vocab_size),
inputs['input_ids'],
dtype=tf.int32
)
# 运行模型获取注意力权重
outputs = model(**inputs, output_attentions=True)
attentions = outputs.attentions[layer_idx] # 取最后一层注意力
attentions = tf.squeeze(attentions, axis=0) # 移除批次维度
attentions = tf.reduce_mean(attentions, axis=0) # 平均所有头的注意力
# 获取输入tokens
input_ids = inputs["input_ids"][0].numpy()
tokens = tokenizer.convert_ids_to_tokens(input_ids)
return tokens, attentions.numpy()
# 随机选取测试集中的样本进行分析
sample_idx = np.random.randint(0, len(X_test))
sample_text = X_test.iloc[sample_idx]
true_label = "抑郁" if y_test.iloc[sample_idx] == 1 else "正常"
# 获取预测结果
predictions = bert_model.predict(test_ds.take(1))
logits = predictions.logits
pred_prob = tf.nn.sigmoid(logits)[0, 0].numpy()
pred_label = "抑郁" if pred_prob > 0.5 else "正常"
print(f"\n分析样本: {'抑郁' if true_label == '抑郁' else '正常'}类文本")
print(f"文本内容: {sample_text[:100]}...")
print(f"真实标签: {true_label}, 预测标签: {pred_label}, 概率: {pred_prob:.4f}")
# 获取注意力权重(带安全检查)
tokens, attentions = get_attention_weights(
bert_model,
sample_text,
tokenizer,
layer_idx=-1 # 最后一层注意力
)
# 过滤掉特殊tokens
valid_tokens = []
valid_attentions = []
for token, attn in zip(tokens, attentions):
if token not in ['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]']:
valid_tokens.append(token)
valid_attentions.append(attn)
# 确保有足够的有效tokens用于绘图
if len(valid_tokens) > 1:
# 绘制注意力热力图(修正为二维输入)
plt.figure(figsize=(15, 3))
sns.heatmap([valid_attentions],
yticklabels=['注意力权重'],
xticklabels=valid_tokens,
cmap='YlOrRd',
cbar_kws={'label': '注意力强度'},
annot=False) # 避免文本过多
plt.title('BERT模型注意力权重分布')
plt.xticks(rotation=90, fontsize=8) # 旋转x轴标签并减小字体
plt.tight_layout()
plt.show()
# 绘制注意力条形图(作为备选可视化)
plt.figure(figsize=(15, 5))
top_n = min(30, len(valid_tokens)) # 只显示前30个token
top_indices = np.argsort(valid_attentions)[-top_n:]
top_tokens = [valid_tokens[i] for i in top_indices]
top_attentions = [valid_attentions[i] for i in top_indices]
sns.barplot(x=top_tokens, y=top_attentions, palette='YlOrRd')
plt.title('BERT模型注意力权重最高的词汇')
plt.xlabel('词汇')
plt.ylabel('注意力权重')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()
else:
print("有效tokens数量不足,无法绘制注意力热力图")
问题为ValueError Traceback (most recent call last)
Cell In[32], line 174
171 if len(valid_tokens) > 1:
172 # 绘制注意力热力图(修正为二维输入)
173 plt.figure(figsize=(15, 3))
--> 174 sns.heatmap([valid_attentions],
175 yticklabels=['注意力权重'],
176 xticklabels=valid_tokens,
177 cmap='YlOrRd',
178 cbar_kws={'label': '注意力强度'},
179 annot=False) # 避免文本过多
180 plt.title('BERT模型注意力权重分布')
181 plt.xticks(rotation=90, fontsize=8) # 旋转x轴标签并减小字体
File D:\Anaconda\envs\pytorch1\lib\site-packages\seaborn\matrix.py:446, in heatmap(data, vmin, vmax, cmap, center, robust, annot, fmt, annot_kws, linewidths, linecolor, cbar, cbar_kws, cbar_ax, square, xticklabels, yticklabels, mask, ax, **kwargs)
365 """Plot rectangular data as a color-encoded matrix.
366
367 This is an Axes-level function and will draw the heatmap into the
(...)
443
444 """
445 # Initialize the plotter object
--> 446 plotter = _HeatMapper(data, vmin, vmax, cmap, center, robust, annot, fmt,
447 annot_kws, cbar, cbar_kws, xticklabels,
448 yticklabels, mask)
450 # Add the pcolormesh kwargs here
451 kwargs["linewidths"] = linewidths
File D:\Anaconda\envs\pytorch1\lib\site-packages\seaborn\matrix.py:110, in _HeatMapper.__init__(self, data, vmin, vmax, cmap, center, robust, annot, fmt, annot_kws, cbar, cbar_kws, xticklabels, yticklabels, mask)
108 else:
109 plot_data = np.asarray(data)
--> 110 data = pd.DataFrame(plot_data)
112 # Validate the mask and convert to DataFrame
113 mask = _matrix_mask(data, mask)
File D:\Anaconda\envs\pytorch1\lib\site-packages\pandas\core\frame.py:672, in DataFrame.__init__(self, data, index, columns, dtype, copy)
662 mgr = dict_to_mgr(
663 # error: Item "ndarray" of "Union[ndarray, Series, Index]" has no
664 # attribute "name"
(...)
669 typ=manager,
670 )
671 else:
--> 672 mgr = ndarray_to_mgr(
673 data,
674 index,
675 columns,
676 dtype=dtype,
677 copy=copy,
678 typ=manager,
679 )
681 # For data is list-like, or Iterable (will consume into list)
682 elif is_list_like(data):
File D:\Anaconda\envs\pytorch1\lib\site-packages\pandas\core\internals\construction.py:304, in ndarray_to_mgr(values, index, columns, dtype, copy, typ)
299 values = values.reshape(-1, 1)
301 else:
302 # by definition an array here
303 # the dtypes will be coerced to a single dtype
--> 304 values = _prep_ndarray(values, copy=copy)
306 if dtype is not None and not is_dtype_equal(values.dtype, dtype):
307 shape = values.shape
File D:\Anaconda\envs\pytorch1\lib\site-packages\pandas\core\internals\construction.py:555, in _prep_ndarray(values, copy)
553 values = values.reshape((values.shape[0], 1))
554 elif values.ndim != 2:
--> 555 raise ValueError(f"Must pass 2-d input. shape={values.shape}")
557 return values
ValueError: Must pass 2-d input. shape=(1, 98, 100)
<Figure size 1500x300 with 0 Axes>
怎么解决,标出错误的地方
最新发布