将其中的缓存机制去掉,代码变成了
class MLA(layers.Layer):
def __init__(self, args: ModelArgs):
super().__init__()
self.dim = args.dim
self.n_heads = args.n_heads
self.q_lora_rank = args.q_lora_rank
self.kv_lora_rank = args.kv_lora_rank
self.qk_nope_head_dim = args.qk_nope_head_dim
self.qk_rope_head_dim = args.qk_rope_head_dim
self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
self.v_head_dim = args.v_head_dim
# 初始化投影层
if self.q_lora_rank == 0:
self.wq = layers.Dense(self.n_heads * self.qk_head_dim)
else:
self.wq_a = layers.Dense(self.q_lora_rank)
self.q_norm = RMSNorm(self.q_lora_rank)
self.wq_b = layers.Dense(self.n_heads * self.qk_head_dim)
self.wkv_a = layers.Dense(self.kv_lora_rank + self.qk_rope_head_dim)
self.kv_norm = RMSNorm(self.kv_lora_rank)
self.wkv_b = layers.Dense(self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
self.wo = layers.Dense(self.dim)
self.softmax_scale = self.qk_head_dim ** -0.5
if args.max_seq_len > args.original_seq_len:
mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
self.softmax_scale *= mscale * mscale
def call(self, x, start_pos, freqs_cis, mask=None):
bsz = tf.shape(x)[0]
seqlen = tf.shape(x)[1]
end_pos = start_pos + seqlen
# 查询投影
if self.q_lora_rank == 0:
q = self.wq(x)
else:
q = self.wq_b(self.q_norm(self.wq_a(x)))
q = tf.reshape(q, [bsz, seqlen, self.n_heads, self.qk_head_dim])
q_nope, q_pe = tf.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
q_pe = apply_rotary_emb(q_pe, freqs_cis)
# 键值投影
kv = self.wkv_a(x)
kv, k_pe = tf.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
k_pe = apply_rotary_emb(tf.expand_dims(k_pe, 2), freqs_cis)
kv = self.wkv_b(self.kv_norm(kv))
kv = tf.reshape(kv, [bsz, seqlen, self.n_heads, self.qk_nope_head_dim + self.v_head_dim])
k_nope, v = tf.split(kv, [self.qk_nope_head_dim, self.v_head_dim], axis=-1)
k = tf.concat([k_nope, tf.tile(k_pe, [1, 1, self.n_heads, 1])], axis=-1)
# 注意力计算
q = tf.concat([q_nope, q_pe], axis=-1)
scores = tf.einsum("bqhd,bkhd->bhqk", q, k) * self.softmax_scale # 维度调整为qk交互
print(scores.shape) # 此处为(2, 16, 128, 128)
if mask is not None:
print(mask.shape) # 此处为(128, 128)
scores += mask[:, None, :, :]
scores += mask[None, None, :, :]
scores = tf.nn.softmax(scores, axis=-1)
x = tf.einsum("bhqk,bkhd->bqhd", scores, v) # 维度调整
return self.wo(tf.reshape(x, [bsz, seqlen, -1]))
存在以下问题025-03-12 17:01:42.362209: W tensorflow/core/framework/op_kernel.cc:1830] OP_REQUIRES failed at strided_slice_op.cc:111 : INVALID_ARGUMENT: Index out of range using input dim 2; input has only 2 dims
File "E:\算法模型\DeepSeek-V3-main\inference\model_tf.py", line 248, in call
x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "E:\算法模型\DeepSeek-V3-main\inference\model_tf.py", line 159, in call
scores += mask[:, None, :, :]
~~~~^^^^^^^^^^^^^^^
tensorflow.python.framework.errors_impl.InvalidArgumentError: Exception encountered when calling layer 'mla' (type MLA).
{{function_node __wrapped__StridedSlice_device_/job:localhost/replica:0/task:0/device:CPU:0}} Index out of range using input dim 2; input has only 2 dims [Op:StridedSlice] name: transformer/block/mla/strided_slice/
Call arguments received by layer 'mla' (type MLA):
• x=tf.Tensor(shape=(2, 128, 10), dtype=float32)
• start_pos=0
• freqs_cis=tf.Tensor(shape=(128, 32), dtype=float32)
• mask=tf.Tensor(shape=(128, 128), dtype=float32)