def parse_model(d, ch, verbose=True):
"""
Parse a YOLO model.yaml dictionary into a PyTorch model.
Args:
d (dict): Model dictionary.
ch (int): Input channels.
verbose (bool): Whether to print model details.
Returns:
model (torch.nn.Sequential): PyTorch model.
save (list): Sorted list of output layers.
"""
import ast # 用来把字符串字面量安全地解析成Python对象(如 "1" -> 1 或 "[1,2]" -> [1,2])
# ----------------- 基本参数与兼容性设置 -----------------
legacy = True # 兼容旧版本模型的标志(v3/v5/v8/v9 等)。部分模块会根据这个标志改变行为。
max_channels = float("inf") # 最大通道数(后面可能被 scales 限制)
# 从模型字典 d 中获取常用字段(如果不存在就返回 None)
nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
# depth_multiple, width_multiple, kpt_shape 是常见的缩放因子,若不存在则默认 1.0
depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
scale = d.get("scale") # 可选的特定 scale(例如 's','m','l','x' 等)
# 如果模型在 yaml 里有按 scale 定义多个配置(scales),并且没有手动传入 scale,则默认选第一个 scale 并给出警告
if scales:
if not scale:
scale = tuple(scales.keys())[0]
LOGGER.warning(f"no model scale passed. Assuming scale='{scale}'.")
# scales[scale] 通常是 (depth, width, max_channels)
depth, width, max_channels = scales[scale]
# activation 字段存在时,改变全局 Conv 的默认激活函数(通过 eval 将字符串转为可执行的类/函数)
if act:
Conv.default_act = eval(act) # 例如 act='torch.nn.SiLU()' 会把默认激活改成 SiLU
if verbose:
LOGGER.info(f"{colorstr('activation:')} {act}") # 打印激活函数信息
# 打印模型表头(可读性输出)
if verbose:
LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
# ch 最开始是一个整数输入通道数,把它改成列表以便后续记录每一层的输出通道
ch = [ch]
# layers 保存构建好的模块,save 保存需要输出/保留的层索引,c2 记录最新层的输出通道
layers, save, c2 = [], [], ch[-1]
# base_modules: 被认为是“基础模块”的集合(这些模块的构造参数通常是 (c1, c2, ... ))
base_modules = frozenset(
{
Classify,
Conv,
ConvTranspose,
GhostConv,
Bottleneck,
GhostBottleneck,
SPP,
SPPF,
C2fPSA,
C2PSA,
DWConv,
Focus,
BottleneckCSP,
C1,
C2,
C2f,
C3k2,
RepNCSPELAN4,
ELAN1,
ADown,
AConv,
SPPELAN,
C2fAttn,
C3,
C3TR,
C3Ghost,
torch.nn.ConvTranspose2d,
DWConvTranspose2d,
C3x,
RepC3,
PSA,
SCDown,
C2fCIB,
A2C2f,
}
)
# repeat_modules: 这些模块通常支持“重复次数(repeat)”参数(即模块内部会按 n 重复)
repeat_modules = frozenset(
{
BottleneckCSP,
C1,
C2,
C2f,
C3k2,
C2fAttn,
C3,
C3TR,
C3Ghost,
C3x,
RepC3,
C2fPSA,
C2fCIB,
C2PSA,
A2C2f,
}
)
# ----------------- 遍历 backbone 和 head 中定义的每一层 -----------------
# d["backbone"] 和 d["head"] 都是类似 [(from, number, module_name, args), ...] 的条目列表
for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # f: from, n: repeat count, m: module str, args: 参数列表
# 将字符串模块名转换成实际的类/构造函数:
# - 如果是 "nn." 前缀的,使用 torch.nn 中的对应类
# - 如果是 "torchvision.ops." 前缀的,使用 torchvision.ops 中的类
# - 否则从 globals()(即当前模块中)获取
m = (
getattr(torch.nn, m[3:]) # 当 m 的前缀是 "nn." 时
if "nn." in m
else getattr(__import__("torchvision").ops, m[16:]) # 当 m 的前缀是 "torchvision.ops." 时
if "torchvision.ops." in m
else globals()[m] # 否则从当前全局命名空间中取
)
# args 中可能存在字符串表示的数值或变量名(yaml 保存时常会是字符串)
# 这里遍历 args,将能被 ast.literal_eval 解析为Python字面量的字符串转换为对应对象
for j, a in enumerate(args):
if isinstance(a, str):
with contextlib.suppress(ValueError):
# 优先查找 locals() 里是否有同名变量(例如先前定义的 depth/width 等)
args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
# 深度缩放:如果 n>1 则应用 depth 缩放并四舍五入(n_ 用来打印原始的 n)
n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain
# ----------------- 对不同类型模块进行参数/通道推导 -----------------
if m in base_modules:
# 对于基础模块,常见的参数结构是 args[0] 表示输出通道数(c2)
c1, c2 = ch[f], args[0] # c1 = 来自索引 f 的输入通道数, c2 = 用户在 yaml 中指定的输出通道数
# 如果 c2 不是类别数(nc),则对 c2 做 max_channels 和 width 缩放,并确保能被 8 整除(make_divisible)
if c2 != nc: # Classify() 常把 c2 设为 nc(分类头)
c2 = make_divisible(min(c2, max_channels) * width, 8)
# C2fAttn 需要额外处理 embed channels 和 num heads(把 args[1], args[2] 按规则缩放)
if m is C2fAttn: # set 1) embed channels and 2) num heads
args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8)
args[2] = int(max(round(min(args[2], max_channels // 2 // 32)) * width, 1) if args[2] > 1 else args[2])
# 将 c1, c2 插入到 args 前面(很多模块的第一个两个参数要求是 in_channels/out_channels)
args = [c1, c2, *args[1:]]
# 如果该模块支持 repeat,插入 repeat 次数(位置为 args 的第3项),然后把 n 设回 1(因为重复由模块内部处理)
if m in repeat_modules:
args.insert(2, n) # number of repeats
n = 1
# 针对某些模块根据 model scale 做特殊设置
if m is C3k2: # for M/L/X sizes
legacy = False
if scale in "mlx":
args[3] = True
if m is A2C2f:
legacy = False
if scale in "lx": # for L/X sizes
args.extend((True, 1.2))
if m is C2fCIB:
legacy = False
elif m is AIFI:
# AIFI 模块的构造参数形式不同:需要把 ch[f](输入通道)作为第一个参数
args = [ch[f], *args]
elif m in frozenset({HGStem, HGBlock}):
# HGStem/HGBlock 是类似 Hourglass 的模块:第一个两个 args 是中间通道数和输出通道数,构造参数需要为 (c1, cm, c2, ...)
c1, cm, c2 = ch[f], args[0], args[1]
args = [c1, cm, c2, *args[2:]]
if m is HGBlock:
# HGBlock 支持重复次数,插入重复次数并将 n 设为 1(重复在模块内部处理)
args.insert(4, n) # number of repeats
n = 1
elif m is ResNetLayer:
# ResNetLayer 的 c2(输出通道)要根据 bottleneck 是否扩展来设定
c2 = args[1] if args[3] else args[1] * 4
elif m is torch.nn.BatchNorm2d:
# BatchNorm2d 只需要输入通道数作为构造参数
args = [ch[f]]
elif m is Concat:
# Concat 模块的输出通道是所有被连接层通道之和
c2 = sum(ch[x] for x in f)
elif m in frozenset(
{Detect, WorldDetect, YOLOEDetect, Segment, YOLOESegment, Pose, OBB, ImagePoolingAttn, v10Detect}
):
# 检测/分割/姿态等头部模块需要知道它们连接来自哪些前面层,因此把这些层的输出通道列表追加到 args
args.append([ch[x] for x in f])
# 如果是分割头,需要对 args[2] 做通道缩放与对齐
if m is Segment or m is YOLOESegment:
args[2] = make_divisible(min(args[2], max_channels) * width, 8)
# 对于这类检测/分割/姿态模块,记下 legacy 标志用于兼容性
if m in {Detect, YOLOEDetect, Segment, YOLOESegment, Pose, OBB}:
m.legacy = legacy
elif m is RTDETRDecoder:
# RTDETRDecoder 要求 channels 作为索引 1 的参数,所以把前面层的通道列表插入到 args 的索引 1
args.insert(1, [ch[x] for x in f])
elif m is CBLinear:
# CBLinear 的第一个 arg 是输出通道数(c2),构造时需要把输入通道 c1 放到前面
c2 = args[0]
c1 = ch[f]
args = [c1, c2, *args[1:]]
elif m is CBFuse:
# CBFuse 的输出通道等于被 fuse 的最后一个分支的通道
c2 = ch[f[-1]]
# ==================== 新增代码开始(你原始代码加入的 LocalAttention 特殊处理) ====================
elif m is LocalAttention:
# LocalAttention 是一个需要保证输入输出通道一致的注意力模块(你在代码中新增的处理)
c1 = ch[f] # 输入通道
# 关键点1: args[0](例如 yaml 中写的 16)是模块内部的某个参数(如窗口大小、内层特征维等),不是输出通道
# 关键点2: 强制输出通道等于输入通道(c2 = c1),即该模块不改变通道数
c2 = c1
# 关键点3: 保持原来的 args[0](作为内部参数)并把 c1 作为第一个参数传入模块构造
args = [c1, args[0]]
# ==================== 新增代码结束 ====================
elif m in frozenset({TorchVision, Index}):
# 这些模块的构造参数有特殊安排:通常 args[0] 是输出通道,c1 来自 ch[f]
c2 = args[0]
c1 = ch[f]
args = [*args[1:]] # 剥离第一个参数(因为已用 c2/c1 替代)
else:
# 默认分支:输出通道等于来自 f 的层的通道(c2 = ch[f])
c2 = ch[f]
# ----------------- 用解析好的 args 构建模块实例 -----------------
# 如果 n > 1(重复次数),用 torch.nn.Sequential 将同一模块重复 n 次;否则直接实例化
m_ = torch.nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
# t 保存模块类型的字符串(用于打印)
t = str(m)[8:-2].replace("__main__.", "") # module type
# m_.np 保存该模块的参数数量(用于报表)
m_.np = sum(x.numel() for x in m_.parameters()) # number params
# 给模块实例附加一些元信息:索引 i,来源 f,以及类型字符串 t
m_.i, m_.f, m_.type = i, f, t # attach index, 'from' index, type
# 根据 verbose 设置打印每一层的摘要信息:索引、来源、重复次数、参数量、模块名、构造参数
if verbose:
LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m_.np:10.0f} {t:<45}{str(args):<30}") # print
# 把需要保存的层索引加入 save 列表:
# - 如果 f 是 int,则保存 f % i(支持引用相对索引)
# - 如果 f 是 list,则对每个元素做同样操作(常见于多输入模块)
save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
# 把构造好的模块放进层列表
layers.append(m_)
# 如果 i==0(第一层)执行清空 ch(因为 ch 初始只含输入通道,这里开始记录每层输出)
if i == 0:
ch = []
# 把当前层的输出通道 c2 添加到 ch 列表,以便后续层引用
ch.append(c2)
# 循环结束后,把所有模块打包为一个 nn.Sequential,并返回 savelist 的排序版本
return torch.nn.Sequential(*layers), sorted(save)

3350

被折叠的 条评论
为什么被折叠?



