tf.einsum

tf.einsum(equation, *inputs)
#例子:
tf.einsum(‘ij,jk->ik’, ts1,ts2) #矩阵乘法
tf.einsum(‘ij->ji’,ts1) #矩阵转置

class MLA(layers.Layer): def __init__(self, args: ModelArgs): super().__init__() self.dim = args.dim self.n_heads = args.n_heads self.q_lora_rank = args.q_lora_rank self.kv_lora_rank = args.kv_lora_rank self.qk_nope_head_dim = args.qk_nope_head_dim self.qk_rope_head_dim = args.qk_rope_head_dim self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim self.v_head_dim = args.v_head_dim # 初始化投影层 if self.q_lora_rank == 0: self.wq = layers.Dense(self.n_heads * self.qk_head_dim) else: self.wq_a = layers.Dense(self.q_lora_rank) self.q_norm = RMSNorm(self.q_lora_rank) self.wq_b = layers.Dense(self.n_heads * self.qk_head_dim) self.wkv_a = layers.Dense(self.kv_lora_rank + self.qk_rope_head_dim) self.kv_norm = RMSNorm(self.kv_lora_rank) self.wkv_b = layers.Dense(self.n_heads * (self.qk_nope_head_dim + self.v_head_dim)) self.wo = layers.Dense(self.dim) self.softmax_scale = self.qk_head_dim ** -0.5 if args.max_seq_len > args.original_seq_len: mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0 self.softmax_scale *= mscale * mscale # 初始化缓存 self.k_cache = tf.Variable(tf.zeros((args.max_batch_size, args.max_seq_len,self.n_heads, self.qk_head_dim)),trainable=False) self.v_cache = tf.Variable(tf.zeros((args.max_batch_size, args.max_seq_len,self.n_heads, self.v_head_dim)),trainable=False) def call(self, x, start_pos, freqs_cis, mask=None): bsz = tf.shape(x)[0] seqlen = tf.shape(x)[1] end_pos = start_pos + seqlen # 查询投影 if self.q_lora_rank == 0: q = self.wq(x) else: q = self.wq_b(self.q_norm(self.wq_a(x))) q = tf.reshape(q, [bsz, seqlen, self.n_heads, self.qk_head_dim]) q_nope, q_pe = tf.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1) q_pe = apply_rotary_emb(q_pe, freqs_cis) # 键值投影 kv = self.wkv_a(x) kv, k_pe = tf.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], axis=-1) k_pe = apply_rotary_emb(tf.expand_dims(k_pe, 2), freqs_cis) kv = self.wkv_b(self.kv_norm(kv)) kv = tf.reshape(kv, [bsz, seqlen, self.n_heads, self.qk_nope_head_dim + self.v_head_dim]) k_nope, v = tf.split(kv, [self.qk_nope_head_dim, self.v_head_dim], axis=-1) k = tf.concat([k_nope, tf.tile(k_pe, [1, 1, self.n_heads, 1])], axis=-1) # 更新缓存 updates_range = tf.range(start_pos, end_pos) self.k_cache.assign(tf.tensor_scatter_nd_update(self.k_cache,updates_range[:, None],k)) self.v_cache.assign(tf.tensor_scatter_nd_update(self.v_cache,updates_range[:, None],v)) # 注意力计算 q = tf.concat([q_nope, q_pe], axis=-1) scores = tf.einsum("bshd,bthd->bhst", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale if mask is not None: scores += mask[:, None, :, :] scores = tf.nn.softmax(scores, axis=-1) x = tf.einsum("bhst,bthd->bshd", scores, self.v_cache[:bsz, :end_pos]) return self.wo(tf.reshape(x, [bsz, seqlen, -1])) 将缓存去掉
03-13
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值