class MLA(layers.Layer):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.n_heads = args.n_heads
        self.q_lora_rank = args.q_lora_rank
        self.kv_lora_rank = args.kv_lora_rank
        self.qk_nope_head_dim = args.qk_nope_head_dim
        self.qk_rope_head_dim = args.qk_rope_head_dim
        self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
        self.v_head_dim = args.v_head_dim
        # Projection layers: the query path is optionally low-rank
        # (wq_a -> RMSNorm -> wq_b); the key/value path is always compressed
        # through the kv_lora_rank latent (wkv_a -> RMSNorm -> wkv_b)
        if self.q_lora_rank == 0:
            self.wq = layers.Dense(self.n_heads * self.qk_head_dim)
        else:
            self.wq_a = layers.Dense(self.q_lora_rank)
            self.q_norm = RMSNorm(self.q_lora_rank)
            self.wq_b = layers.Dense(self.n_heads * self.qk_head_dim)
        self.wkv_a = layers.Dense(self.kv_lora_rank + self.qk_rope_head_dim)
        self.kv_norm = RMSNorm(self.kv_lora_rank)
        self.wkv_b = layers.Dense(self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
        self.wo = layers.Dense(self.dim)
        self.softmax_scale = self.qk_head_dim ** -0.5
        # Extended-context correction: rescale attention when max_seq_len
        # exceeds the original training length
        if args.max_seq_len > args.original_seq_len:
            mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
            self.softmax_scale *= mscale * mscale
        # Non-trainable key/value caches for incremental decoding
        self.k_cache = tf.Variable(
            tf.zeros((args.max_batch_size, args.max_seq_len, self.n_heads, self.qk_head_dim)),
            trainable=False)
        self.v_cache = tf.Variable(
            tf.zeros((args.max_batch_size, args.max_seq_len, self.n_heads, self.v_head_dim)),
            trainable=False)
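        # Note: this straightforward port caches the full per-head keys and values
        # ([max_batch_size, max_seq_len, n_heads, head_dim] each). The memory saving
        # usually attributed to MLA comes from caching only the compressed latent plus
        # the shared rotary key instead; that variant is not implemented here.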
    def call(self, x, start_pos, freqs_cis, mask=None):
        bsz = tf.shape(x)[0]
        seqlen = tf.shape(x)[1]
        end_pos = start_pos + seqlen
        # Query projection (optionally through the low-rank bottleneck)
        if self.q_lora_rank == 0:
            q = self.wq(x)
        else:
            q = self.wq_b(self.q_norm(self.wq_a(x)))
        q = tf.reshape(q, [bsz, seqlen, self.n_heads, self.qk_head_dim])
        # Split each query head into a non-positional part and a rotary part
        q_nope, q_pe = tf.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
        q_pe = apply_rotary_emb(q_pe, freqs_cis)
        # Key/value projection: compress into the latent, plus one shared rotary key
        kv = self.wkv_a(x)
        kv, k_pe = tf.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
        k_pe = apply_rotary_emb(tf.expand_dims(k_pe, 2), freqs_cis)
        kv = self.wkv_b(self.kv_norm(kv))
        kv = tf.reshape(kv, [bsz, seqlen, self.n_heads, self.qk_nope_head_dim + self.v_head_dim])
        k_nope, v = tf.split(kv, [self.qk_nope_head_dim, self.v_head_dim], axis=-1)
        # The rotary key is shared across heads, so tile it before concatenating
        k = tf.concat([k_nope, tf.tile(k_pe, [1, 1, self.n_heads, 1])], axis=-1)
        # Update the caches: tensor_scatter_nd_update needs (batch, position)
        # index pairs, so build them for the first bsz sequences
        batch_idx = tf.range(bsz)
        pos_idx = tf.range(start_pos, end_pos)
        indices = tf.stack(tf.meshgrid(batch_idx, pos_idx, indexing="ij"), axis=-1)  # [bsz, seqlen, 2]
        self.k_cache.assign(tf.tensor_scatter_nd_update(self.k_cache, indices, k))
        self.v_cache.assign(tf.tensor_scatter_nd_update(self.v_cache, indices, v))
        # Attention over all cached positions up to end_pos
        q = tf.concat([q_nope, q_pe], axis=-1)
        scores = tf.einsum("bshd,bthd->bhst", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
        if mask is not None:
            # mask is assumed to be [seqlen, end_pos]; add batch and head axes to broadcast
            scores += mask[None, None, :, :]
        scores = tf.nn.softmax(scores, axis=-1)
        x = tf.einsum("bhst,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
        return self.wo(tf.reshape(x, [bsz, seqlen, -1]))
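A minimal sketch of how the layer could be exercised, assuming the ModelArgs defaults defined earlier, a precompute_freqs_cis helper compatible with apply_rotary_emb, and a standard additive causal mask (the helper name and the mask construction below are assumptions, not part of the code above):

args = ModelArgs()
mla = MLA(args)

bsz, seqlen, start_pos = 2, 16, 0
x = tf.random.normal([bsz, seqlen, args.dim])
freqs_cis = precompute_freqs_cis(args)[start_pos:start_pos + seqlen]  # assumed helper
# Causal mask: 0 on and below the diagonal, -inf above it
mask = tf.where(tf.linalg.band_part(tf.ones([seqlen, seqlen]), -1, 0) > 0,
                0.0, float("-inf"))
out = mla(x, start_pos, freqs_cis, mask)  # [bsz, seqlen, args.dim]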
Remove the cache.