其实在DeepSeek-R1爆火之前,DeepSeek V2在我们行业就已经妇孺皆知了,它独特的MOE结构值得研究一下。这篇文章是基于 ZOMI酱 的2个视频写的,这2个视频讲的很好,建议大家都学习一下:《MOE终于迎来可视化解读!傻瓜都能看懂MoE核心原理!》和《使用昇腾NPU手撕MoE单机版代码!没想到如此简单!》。
这篇文章是把我自己的理解梳理一下,加强自己的理解和记忆。
MOE结构概述
我们可以从zomi酱视频里面的这张图开始:

MOE是mixture of experts 的缩写,简单来说,就是把传统transformer结构中decoder层里面的单个线性层替换层多个并列的线性层。在这些线性层前面还有一个Router,Router会选择并列线性层里面的一部分进行计算。这样的话,既能让模型学习更多的知识(多个“专家”),又能减少推理计算量(选择部分“专家”进行计算)。
MOE计算代码
接下来我们参考zomi酱提供的代码来详细看一下MOE的计算过程是怎样的:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_npu
from torch_npu.contrib import transfer_to_npu
class Expert(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, output_dim))
def forward(self, x):
return self.net(x)
class MoE(nn.Module):
def __init__(self, input_dim, num_experts, top_k, expert_capacity, hidden_dim, output_dim):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k
self.expert_capacity = expert_capacity
# 路由网络
self.gate = nn.Linear(input_dim, num_experts)
# 专家集合
self.experts = nn.ModuleList(
[Expert(input_dim, hidden_dim, output_dim) for _ in range(num_experts)])
def forward(self, x):
batch_size, input_dim = x.shape
device = x.device
# 路由计算
logits = self.gate(x)
probs = torch.softmax(logits, dim=-1)
print("probs: ", probs)
topk_probs, topk_indices = torch.topk(probs, self.top_k, dim=-1)
print("topk_probs: ", topk_probs)
print("topk_indices: ", topk_indices)
# 辅助损失计算
if

最低0.47元/天 解锁文章
1278

被折叠的 条评论
为什么被折叠?



