要理解专家并行（Expert Parallelism），需要先了解 torch.einsum 函数。下面用一个最小例子演示如何用 einsum 把 token 分发到各个专家的槽位：
import torch
# Problem sizes for the demo.
s = 3  # number of tokens
e = 2  # number of experts
c = 2  # capacity: slots per expert
m = 4  # embedding dimension

# 1. Token embeddings, shape (s, m): token i is a constant row of value i + 1.
reshaped_input = torch.tensor(
    [
        [1.0] * m,  # token 0
        [2.0] * m,  # token 1
        [3.0] * m,  # token 2
    ]
)

# 2. dispatch_mask, shape (s, e, c): a 1 at [token, expert, slot] places that
#    token into that slot of that expert. Start from an all-zero float32 mask
#    (same dtype as the embeddings, so einsum can combine them) and switch on
#    each route.
dispatch_mask = torch.zeros(s, e, c, dtype=torch.float32)
for token, expert, slot in (
    (0, 0, 0),  # token 0 -> expert 0, slot 0
    (1, 1, 0),  # token 1 -> expert 1, slot 0
    (2, 0, 1),  # token 2 -> expert 0, slot 1
):
    dispatch_mask[token, expert, slot] = 1.0

# 3. Dispatch: out[e, c, :] = sum over s of mask[s, e, c] * input[s, :].
#    Each (expert, slot) row receives the embedding of the token routed there,
#    or zeros if the slot is empty.
dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask, reshaped_input)

# 4. Show the result.
print("Dispatched Input shape:", dispatched_input.shape)
print("\nDispatched Input Tensor:")
print(dispatched_input)
运行结果如下：
Dispatched Input shape: torch.Size([2, 2, 4])

Dispatched Input Tensor:
tensor([[[1., 1., 1., 1.],
         [3., 3., 3., 3.]],

        [[2., 2., 2., 2.],
         [0., 0., 0., 0.]]])
接下来增加画图功能，把 token 到专家槽位（expert slot）的路由关系可视化：
import torch
import matplotlib.pyplot as plt
# Token embeddings, shape (3, 4): token i is a constant row of value i + 1 (float32).
reshaped_input = torch.tensor([[float(i + 1)] * 4 for i in range(3)])

# Routing mask, shape (3, 2, 2), indexed [token, expert, slot].
# An integer literal would give an int64 tensor, which is not compatible with
# the float32 embeddings in einsum, so the mask is created as float32 directly.
dispatch_mask = torch.tensor(
    [
        [[1, 0], [0, 0]],  # token 0 -> expert 0, slot 0
        [[0, 0], [1, 0]],  # token 1 -> expert 1, slot 0
        [[0, 1], [0, 0]],  # token 2 -> expert 0, slot 1
    ],
    dtype=torch.float32,
)

# Scatter tokens into their (expert, slot) buffers -> shape (2, 2, 4).
dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask, reshaped_input)

print("Dispatched input shape:", dispatched_input.shape)
print(dispatched_input)
def visualize_dispatch(dispatch_mask):
    """Plot which expert slot each token is dispatched to.

    Tokens are labelled on the left (``T{token}``, token 0 at the top) and
    every expert slot on the right (``E{expert}-S{slot}``). A line joins a
    token to a slot for each entry of ``dispatch_mask`` that is > 0. Every
    label is drawn exactly once, and tokens/slots that take part in no route
    are drawn in gray, so dropped tokens and empty capacity stay visible.

    Args:
        dispatch_mask: 3-D array indexed as ``[token, expert, slot]`` whose
            ``.shape`` unpacks to ``(tokens, experts, capacity)``.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    import itertools

    s, e, c = dispatch_mask.shape  # tokens, experts, capacity
    x_token, x_expert = 0, 4

    # Vertical positions: token 0 and slot (expert 0, slot 0) sit at the top.
    token_y = {token: s - token for token in range(s)}
    slot_y = {
        (expert, slot): e * c - (expert * c + slot)
        for expert, slot in itertools.product(range(e), range(c))
    }

    # All (token, expert, slot) triples that carry a route.
    routes = [
        (token, expert, slot)
        for token, expert, slot in itertools.product(range(s), range(e), range(c))
        if dispatch_mask[token, expert, slot] > 0
    ]
    used_tokens = {token for token, _, _ in routes}
    used_slots = {(expert, slot) for _, expert, slot in routes}

    plt.figure(figsize=(6, 4))

    # One connecting line per route.
    for token, expert, slot in routes:
        plt.plot([x_token, x_expert], [token_y[token], slot_y[expert, slot]], 'k-', lw=1)

    # Label every token once, including tokens that are routed nowhere.
    for token, y in token_y.items():
        plt.text(
            x_token - 0.2, y, f"T{token}", va='center', ha='right', fontsize=10,
            color='black' if token in used_tokens else 'gray',
        )

    # Label every expert slot once, including slots that stay empty.
    for (expert, slot), y in slot_y.items():
        plt.text(
            x_expert + 0.2, y, f"E{expert}-S{slot}", va='center', ha='left', fontsize=10,
            color='black' if (expert, slot) in used_slots else 'gray',
        )

    # Figure styling.
    plt.xlim(-1, 6)
    plt.ylim(0, max(s, e * c) + 1)
    plt.axis('off')
    plt.title("Token → Expert-Slot Routing")
    plt.show()
# Draw the token -> expert-slot routing diagram for the mask built above.
visualize_dispatch(dispatch_mask=dispatch_mask)