bert模型取last_hidden_state[:, 0]

文章解释了在处理模型输出时,如何从(32,100,768)的形状转换到(32,768),涉及到batch_size、max_length和hidden_size的概念。通过一个numpy数组的例子,展示了last_hidden_state[:,0]如何提取每个样本的第一行,即(1,768)形状的数据,最终得到(32,768)的形状,这通常与BERT模型中的CLStoken相关联。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

问题

看代码时,第一行跑完target_pred的last_hidden_state的shape为(32,100,768),第二行跑完target_pred的shape为(32,768) ,不理解。
注:32为batch_size,100为max_length,768为hidden_size。

target_pred = model(target_input_ids, target_attention_mask, target_token_type_ids, output_hidden_states=True, return_dict=True)
target_pred = target_pred.last_hidden_state[:, 0]

解决

涉及的知识点是三维数组切片,测试代码如下:

a=np.arange(0, 24, 1).reshape(3,2,4)
print(a)
# 输出
[[[ 0  1  2  3]
  [ 4  5  6  7]]

 [[ 8  9 10 11]
  [12 13 14 15]]

 [[16 17 18 19]
  [20 21 22 23]]]
  
b=a[:,0]
print(b)
print(b.shape)
# 输出
[[ 0  1  2  3]
 [ 8  9 10 11]
 [16 17 18 19]]
(3, 4)

以上面的例子类推,last_hidden_state[:, 0]的含义:每个batch有32个样本,每个样本的shape为(100,768)的二维数组,对每个样本取第一行,即(1,768),因此last_hidden_state[:, 0]的shape为(32,768),相当于降维。并且调用了cls。

扩展阅读

1.【bert】: 在eval时pooler、last_hiddent_state、cls的区分
2.索引与切片,玩转数组之七十二变

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值