Tensorflow张量不能修改

部署运行你感兴趣的模型镜像

# gate_shape: [bs, task_num, total_expert]
# self_weight_task: [task_num, expert_per_group]

# Step1: 构造一个全局 mask,用来定位每个 task 的 experts
mask = tf.reshape(tf.one_hot(tf.range(task_num * expert_per_group), depth=gate_shape.shape[-1]),
                  [task_num, expert_per_group, gate_shape.shape[-1]])  # [task_num, expert_per_group, total_expert]

# Step2: 把 self_weight_task 扩展,跟 mask 对齐
# [task_num, expert_per_group] -> [1, task_num, expert_per_group, 1]
weight_expanded = self_weight_task[None, :, :, None]

# Step3: 乘 mask 得到稀疏展开
# [1, task_num, expert_per_group, 1] * [task_num, expert_per_group, total_expert]
# -> [1, task_num, expert_per_group, total_expert]
updates = weight_expanded * mask[None, :, :, :]

# Step4: sum 掉 expert_per_group 维度 -> [1, task_num, total_expert]
updates = tf.reduce_sum(updates, axis=2)

# Step5: broadcast 到 batch,直接加上去
gate_shape = gate_shape + updates

您可能感兴趣的与本文相关的镜像

TensorFlow-v2.15

TensorFlow-v2.15

TensorFlow

TensorFlow 是由Google Brain 团队开发的开源机器学习框架,广泛应用于深度学习研究和生产环境。 它提供了一个灵活的平台,用于构建和训练各种机器学习模型

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值