# gate_shape: the gate logits tensor, shape [bs, task_num, total_expert]
# self_weight_task: per-task weights for that task's own experts, shape [task_num, expert_per_group]
# Step 1: build a global mask that locates each task's own experts
mask = tf.reshape(tf.one_hot(tf.range(task_num * expert_per_group), depth=gate_shape.shape[-1]),
                  [task_num, expert_per_group, gate_shape.shape[-1]])  # [task_num, expert_per_group, total_expert]
# Step 2: expand self_weight_task so it aligns with the mask
# [task_num, expert_per_group] -> [1, task_num, expert_per_group, 1]
weight_expanded = self_weight_task[None, :, :, None]
# Step 3: multiply by the mask to get a sparse expansion
# [1, task_num, expert_per_group, 1] * [1, task_num, expert_per_group, total_expert]
# -> [1, task_num, expert_per_group, total_expert]
updates = weight_expanded * mask[None, :, :, :]
# Step 4: sum out the expert_per_group axis -> [1, task_num, total_expert]
updates = tf.reduce_sum(updates, axis=2)
# Step 5: broadcast over the batch dimension and add to the gate logits
gate_shape = gate_shape + updates
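# ----------------------------------------------------------------------
# Minimal runnable sketch (not part of the original snippet): it wires the
# steps above into a small function plus a toy test. The function name
# add_task_expert_weight, the toy sizes (task_num=2, expert_per_group=3,
# total_expert=8) and the random inputs are illustrative assumptions.
# Like the code above, it assumes each task's own experts occupy the first
# task_num * expert_per_group slots of the expert axis, which is exactly
# what the one_hot/reshape mask encodes.
import tensorflow as tf

def add_task_expert_weight(gate_logits, self_weight_task, task_num, expert_per_group):
    # gate_logits: [bs, task_num, total_expert]
    # self_weight_task: [task_num, expert_per_group]
    total_expert = gate_logits.shape[-1]
    # [task_num, expert_per_group, total_expert] one-hot mask selecting
    # each task's own expert columns
    mask = tf.reshape(
        tf.one_hot(tf.range(task_num * expert_per_group), depth=total_expert),
        [task_num, expert_per_group, total_expert])
    # Scatter the per-task weights onto the total_expert axis, then sum out
    # the expert_per_group axis -> [1, task_num, total_expert]
    updates = tf.reduce_sum(self_weight_task[None, :, :, None] * mask[None, :, :, :], axis=2)
    # Broadcast over the batch dimension and add to the gate logits
    return gate_logits + updates

bs, task_num, expert_per_group, total_expert = 4, 2, 3, 8
gate_logits = tf.random.normal([bs, task_num, total_expert])
self_weight_task = tf.random.normal([task_num, expert_per_group])
out = add_task_expert_weight(gate_logits, self_weight_task, task_num, expert_per_group)
print(out.shape)  # (4, 2, 8)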