值域范围tf.clip_by_value(a,min,max)

本文介绍如何使用TensorFlow的tf.clip_by_value函数来限制张量中的值域。通过一个具体的代码示例,展示了如何将张量中的值限定在指定的最小值和最大值之间,所有低于最小值的元素被替换为最小值,所有超过最大值的元素被替换为最大值。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

import tensorflow as tf

a = tf.Variable(tf.constant([[1,2,3],[4,5,6]]))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(a))
    print(sess.run(tf.clip_by_value(a,2,3)))

值域划分,tf.clip_by_value(a,min,max)

根据英语单词意思即可理解。

输出一个张量,使得张量a中的小于min值的都改为min,大于max值的都改为max.

### 三、BERT模型注意力机制分析 print("\n#### BERT模型注意力模式分析") # 安全处理token ID的函数(使用纯TensorFlow操作) def safe_token_ids(input_ids, vocab_size): """确保所有token ID在有效范围内(纯TensorFlow实现)""" return tf.clip_by_value(input_ids, 0, vocab_size - 1) # 获取注意力权重的函数(带安全检查) def get_attention_weights(model, input_text, tokenizer, layer_idx=-1): """获取BERT模型指定层的注意力权重""" # 编码文本 inputs = tokenizer( input_text, return_tensors="tf", max_length=100, padding="max_length", truncation=True ) # 获取vocab_size用于安全检查 vocab_size = tokenizer.vocab_size # 安全处理input_ids(使用纯TensorFlow操作) if 'input_ids' in inputs: inputs['input_ids'] = tf.map_fn( lambda x: safe_token_ids(x, vocab_size), inputs['input_ids'], dtype=tf.int32 ) # 运行模型获取注意力权重 outputs = model(**inputs, output_attentions=True) attentions = outputs.attentions[layer_idx] # 取最后一层注意力 attentions = tf.squeeze(attentions, axis=0) # 移除批次维度 attentions = tf.reduce_mean(attentions, axis=0) # 平均所有头的注意力 # 获取输入tokens input_ids = inputs["input_ids"][0].numpy() tokens = tokenizer.convert_ids_to_tokens(input_ids) return tokens, attentions.numpy() # 随机选取测试集中的样本进行分析 sample_idx = np.random.randint(0, len(X_test)) sample_text = X_test.iloc[sample_idx] true_label = "抑郁" if y_test.iloc[sample_idx] == 1 else "正常" # 获取预测结果 predictions = bert_model.predict(test_ds.take(1)) logits = predictions.logits pred_prob = tf.nn.sigmoid(logits)[0, 0].numpy() pred_label = "抑郁" if pred_prob > 0.5 else "正常" print(f"\n分析样本: {'抑郁' if true_label == '抑郁' else '正常'}类文本") print(f"文本内容: {sample_text[:100]}...") print(f"真实标签: {true_label}, 预测标签: {pred_label}, 概率: {pred_prob:.4f}") # 获取注意力权重(带安全检查) tokens, attentions = get_attention_weights( bert_model, sample_text, tokenizer, layer_idx=-1 # 最后一层注意力 ) # 过滤掉特殊tokens valid_tokens = [] valid_attentions = [] for token, attn in zip(tokens, attentions): if token not in ['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]']: valid_tokens.append(token) valid_attentions.append(attn) # 确保有足够的有效tokens用于绘图 if len(valid_tokens) > 1: # 绘制注意力热力图(修正为二维输入) plt.figure(figsize=(15, 3)) sns.heatmap([valid_attentions], yticklabels=['注意力权重'], xticklabels=valid_tokens, cmap='YlOrRd', cbar_kws={'label': '注意力强度'}, annot=False) # 避免文本过多 plt.title('BERT模型注意力权重分布') plt.xticks(rotation=90, fontsize=8) # 旋转x轴标签并减小字体 plt.tight_layout() plt.show() # 绘制注意力条形图(作为备选可视化) plt.figure(figsize=(15, 5)) top_n = min(30, len(valid_tokens)) # 只显示前30个token top_indices = np.argsort(valid_attentions)[-top_n:] top_tokens = [valid_tokens[i] for i in top_indices] top_attentions = [valid_attentions[i] for i in top_indices] sns.barplot(x=top_tokens, y=top_attentions, palette='YlOrRd') plt.title('BERT模型注意力权重最高的词汇') plt.xlabel('词汇') plt.ylabel('注意力权重') plt.xticks(rotation=45, ha='right') plt.tight_layout() plt.show() else: print("有效tokens数量不足,无法绘制注意力热力图") 问题为ValueError Traceback (most recent call last) Cell In[32], line 174 171 if len(valid_tokens) > 1: 172 # 绘制注意力热力图(修正为二维输入) 173 plt.figure(figsize=(15, 3)) --> 174 sns.heatmap([valid_attentions], 175 yticklabels=['注意力权重'], 176 xticklabels=valid_tokens, 177 cmap='YlOrRd', 178 cbar_kws={'label': '注意力强度'}, 179 annot=False) # 避免文本过多 180 plt.title('BERT模型注意力权重分布') 181 plt.xticks(rotation=90, fontsize=8) # 旋转x轴标签并减小字体 File D:\Anaconda\envs\pytorch1\lib\site-packages\seaborn\matrix.py:446, in heatmap(data, vmin, vmax, cmap, center, robust, annot, fmt, annot_kws, linewidths, linecolor, cbar, cbar_kws, cbar_ax, square, xticklabels, yticklabels, mask, ax, **kwargs) 365 """Plot rectangular data as a color-encoded matrix. 366 367 This is an Axes-level function and will draw the heatmap into the (...) 443 444 """ 445 # Initialize the plotter object --> 446 plotter = _HeatMapper(data, vmin, vmax, cmap, center, robust, annot, fmt, 447 annot_kws, cbar, cbar_kws, xticklabels, 448 yticklabels, mask) 450 # Add the pcolormesh kwargs here 451 kwargs["linewidths"] = linewidths File D:\Anaconda\envs\pytorch1\lib\site-packages\seaborn\matrix.py:110, in _HeatMapper.__init__(self, data, vmin, vmax, cmap, center, robust, annot, fmt, annot_kws, cbar, cbar_kws, xticklabels, yticklabels, mask) 108 else: 109 plot_data = np.asarray(data) --> 110 data = pd.DataFrame(plot_data) 112 # Validate the mask and convert to DataFrame 113 mask = _matrix_mask(data, mask) File D:\Anaconda\envs\pytorch1\lib\site-packages\pandas\core\frame.py:672, in DataFrame.__init__(self, data, index, columns, dtype, copy) 662 mgr = dict_to_mgr( 663 # error: Item "ndarray" of "Union[ndarray, Series, Index]" has no 664 # attribute "name" (...) 669 typ=manager, 670 ) 671 else: --> 672 mgr = ndarray_to_mgr( 673 data, 674 index, 675 columns, 676 dtype=dtype, 677 copy=copy, 678 typ=manager, 679 ) 681 # For data is list-like, or Iterable (will consume into list) 682 elif is_list_like(data): File D:\Anaconda\envs\pytorch1\lib\site-packages\pandas\core\internals\construction.py:304, in ndarray_to_mgr(values, index, columns, dtype, copy, typ) 299 values = values.reshape(-1, 1) 301 else: 302 # by definition an array here 303 # the dtypes will be coerced to a single dtype --> 304 values = _prep_ndarray(values, copy=copy) 306 if dtype is not None and not is_dtype_equal(values.dtype, dtype): 307 shape = values.shape File D:\Anaconda\envs\pytorch1\lib\site-packages\pandas\core\internals\construction.py:555, in _prep_ndarray(values, copy) 553 values = values.reshape((values.shape[0], 1)) 554 elif values.ndim != 2: --> 555 raise ValueError(f"Must pass 2-d input. shape={values.shape}") 557 return values ValueError: Must pass 2-d input. shape=(1, 98, 100) <Figure size 1500x300 with 0 Axes> ​ 怎么解决,标出错误的地方
最新发布
06-24
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值