yolo parse_model 函数详解

def parse_model(d, ch, verbose=True):
    """
    Parse a YOLO model.yaml dictionary into a PyTorch model.

    Args:
        d (dict): Model dictionary.
        ch (int): Input channels.
        verbose (bool): Whether to print model details.

    Returns:
        model (torch.nn.Sequential): PyTorch model.
        save (list): Sorted list of output layers.
    """
    import ast  # 用来把字符串字面量安全地解析成Python对象(如 "1" -> 1 或 "[1,2]" -> [1,2])

    # ----------------- 基本参数与兼容性设置 -----------------
    legacy = True  # 兼容旧版本模型的标志(v3/v5/v8/v9 等)。部分模块会根据这个标志改变行为。
    max_channels = float("inf")  # 最大通道数(后面可能被 scales 限制)
    # 从模型字典 d 中获取常用字段(如果不存在就返回 None)
    nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
    # depth_multiple, width_multiple, kpt_shape 是常见的缩放因子,若不存在则默认 1.0
    depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
    scale = d.get("scale")  # 可选的特定 scale(例如 's','m','l','x' 等)

    # 如果模型在 yaml 里有按 scale 定义多个配置(scales),并且没有手动传入 scale,则默认选第一个 scale 并给出警告
    if scales:
        if not scale:
            scale = tuple(scales.keys())[0]
            LOGGER.warning(f"no model scale passed. Assuming scale='{scale}'.")
        # scales[scale] 通常是 (depth, width, max_channels)
        depth, width, max_channels = scales[scale]

    # activation 字段存在时,改变全局 Conv 的默认激活函数(通过 eval 将字符串转为可执行的类/函数)
    if act:
        Conv.default_act = eval(act)  # 例如 act='torch.nn.SiLU()' 会把默认激活改成 SiLU
        if verbose:
            LOGGER.info(f"{colorstr('activation:')} {act}")  # 打印激活函数信息

    # 打印模型表头(可读性输出)
    if verbose:
        LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10}  {'module':<45}{'arguments':<30}")

    # ch 最开始是一个整数输入通道数,把它改成列表以便后续记录每一层的输出通道
    ch = [ch]

    # layers 保存构建好的模块,save 保存需要输出/保留的层索引,c2 记录最新层的输出通道
    layers, save, c2 = [], [], ch[-1]

    # base_modules: 被认为是“基础模块”的集合(这些模块的构造参数通常是 (c1, c2, ... ))
    base_modules = frozenset(
        {
            Classify,
            Conv,
            ConvTranspose,
            GhostConv,
            Bottleneck,
            GhostBottleneck,
            SPP,
            SPPF,
            C2fPSA,
            C2PSA,
            DWConv,
            Focus,
            BottleneckCSP,
            C1,
            C2,
            C2f,
            C3k2,
            RepNCSPELAN4,
            ELAN1,
            ADown,
            AConv,
            SPPELAN,
            C2fAttn,
            C3,
            C3TR,
            C3Ghost,
            torch.nn.ConvTranspose2d,
            DWConvTranspose2d,
            C3x,
            RepC3,
            PSA,
            SCDown,
            C2fCIB,
            A2C2f,
        }
    )
    # repeat_modules: 这些模块通常支持“重复次数(repeat)”参数(即模块内部会按 n 重复)
    repeat_modules = frozenset(
        {
            BottleneckCSP,
            C1,
            C2,
            C2f,
            C3k2,
            C2fAttn,
            C3,
            C3TR,
            C3Ghost,
            C3x,
            RepC3,
            C2fPSA,
            C2fCIB,
            C2PSA,
            A2C2f,
        }
    )

    # ----------------- 遍历 backbone 和 head 中定义的每一层 -----------------
    # d["backbone"] 和 d["head"] 都是类似 [(from, number, module_name, args), ...] 的条目列表
    for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]):  # f: from, n: repeat count, m: module str, args: 参数列表
        # 将字符串模块名转换成实际的类/构造函数:
        # - 如果是 "nn." 前缀的,使用 torch.nn 中的对应类
        # - 如果是 "torchvision.ops." 前缀的,使用 torchvision.ops 中的类
        # - 否则从 globals()(即当前模块中)获取
        m = (
            getattr(torch.nn, m[3:])  # 当 m 的前缀是 "nn." 时
            if "nn." in m
            else getattr(__import__("torchvision").ops, m[16:])  # 当 m 的前缀是 "torchvision.ops." 时
            if "torchvision.ops." in m
            else globals()[m]  # 否则从当前全局命名空间中取
        )

        # args 中可能存在字符串表示的数值或变量名(yaml 保存时常会是字符串)
        # 这里遍历 args,将能被 ast.literal_eval 解析为Python字面量的字符串转换为对应对象
        for j, a in enumerate(args):
            if isinstance(a, str):
                with contextlib.suppress(ValueError):
                    # 优先查找 locals() 里是否有同名变量(例如先前定义的 depth/width 等)
                    args[j] = locals()[a] if a in locals() else ast.literal_eval(a)

        # 深度缩放:如果 n>1 则应用 depth 缩放并四舍五入(n_ 用来打印原始的 n)
        n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain

        # ----------------- 对不同类型模块进行参数/通道推导 -----------------
        if m in base_modules:
            # 对于基础模块,常见的参数结构是 args[0] 表示输出通道数(c2)
            c1, c2 = ch[f], args[0]  # c1 = 来自索引 f 的输入通道数, c2 = 用户在 yaml 中指定的输出通道数
            # 如果 c2 不是类别数(nc),则对 c2 做 max_channels 和 width 缩放,并确保能被 8 整除(make_divisible)
            if c2 != nc:  # Classify() 常把 c2 设为 nc(分类头)
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            # C2fAttn 需要额外处理 embed channels 和 num heads(把 args[1], args[2] 按规则缩放)
            if m is C2fAttn:  # set 1) embed channels and 2) num heads
                args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8)
                args[2] = int(max(round(min(args[2], max_channels // 2 // 32)) * width, 1) if args[2] > 1 else args[2])

            # 将 c1, c2 插入到 args 前面(很多模块的第一个两个参数要求是 in_channels/out_channels)
            args = [c1, c2, *args[1:]]
            # 如果该模块支持 repeat,插入 repeat 次数(位置为 args 的第3项),然后把 n 设回 1(因为重复由模块内部处理)
            if m in repeat_modules:
                args.insert(2, n)  # number of repeats
                n = 1
            # 针对某些模块根据 model scale 做特殊设置
            if m is C3k2:  # for M/L/X sizes
                legacy = False
                if scale in "mlx":
                    args[3] = True
            if m is A2C2f:
                legacy = False
                if scale in "lx":  # for L/X sizes
                    args.extend((True, 1.2))
            if m is C2fCIB:
                legacy = False

        elif m is AIFI:
            # AIFI 模块的构造参数形式不同:需要把 ch[f](输入通道)作为第一个参数
            args = [ch[f], *args]

        elif m in frozenset({HGStem, HGBlock}):
            # HGStem/HGBlock 是类似 Hourglass 的模块:第一个两个 args 是中间通道数和输出通道数,构造参数需要为 (c1, cm, c2, ...)
            c1, cm, c2 = ch[f], args[0], args[1]
            args = [c1, cm, c2, *args[2:]]
            if m is HGBlock:
                # HGBlock 支持重复次数,插入重复次数并将 n 设为 1(重复在模块内部处理)
                args.insert(4, n)  # number of repeats
                n = 1

        elif m is ResNetLayer:
            # ResNetLayer 的 c2(输出通道)要根据 bottleneck 是否扩展来设定
            c2 = args[1] if args[3] else args[1] * 4

        elif m is torch.nn.BatchNorm2d:
            # BatchNorm2d 只需要输入通道数作为构造参数
            args = [ch[f]]

        elif m is Concat:
            # Concat 模块的输出通道是所有被连接层通道之和
            c2 = sum(ch[x] for x in f)

        elif m in frozenset(
            {Detect, WorldDetect, YOLOEDetect, Segment, YOLOESegment, Pose, OBB, ImagePoolingAttn, v10Detect}
        ):
            # 检测/分割/姿态等头部模块需要知道它们连接来自哪些前面层,因此把这些层的输出通道列表追加到 args
            args.append([ch[x] for x in f])
            # 如果是分割头,需要对 args[2] 做通道缩放与对齐
            if m is Segment or m is YOLOESegment:
                args[2] = make_divisible(min(args[2], max_channels) * width, 8)
            # 对于这类检测/分割/姿态模块,记下 legacy 标志用于兼容性
            if m in {Detect, YOLOEDetect, Segment, YOLOESegment, Pose, OBB}:
                m.legacy = legacy

        elif m is RTDETRDecoder:
            # RTDETRDecoder 要求 channels 作为索引 1 的参数,所以把前面层的通道列表插入到 args 的索引 1
            args.insert(1, [ch[x] for x in f])

        elif m is CBLinear:
            # CBLinear 的第一个 arg 是输出通道数(c2),构造时需要把输入通道 c1 放到前面
            c2 = args[0]
            c1 = ch[f]
            args = [c1, c2, *args[1:]]

        elif m is CBFuse:
            # CBFuse 的输出通道等于被 fuse 的最后一个分支的通道
            c2 = ch[f[-1]]

        # ==================== 新增代码开始(你原始代码加入的 LocalAttention 特殊处理) ====================
        elif m is LocalAttention:
            # LocalAttention 是一个需要保证输入输出通道一致的注意力模块(你在代码中新增的处理)
            c1 = ch[f]                  # 输入通道
            # 关键点1: args[0](例如 yaml 中写的 16)是模块内部的某个参数(如窗口大小、内层特征维等),不是输出通道
            # 关键点2: 强制输出通道等于输入通道(c2 = c1),即该模块不改变通道数
            c2 = c1
            # 关键点3: 保持原来的 args[0](作为内部参数)并把 c1 作为第一个参数传入模块构造
            args = [c1, args[0]]
        # ==================== 新增代码结束 ====================

        elif m in frozenset({TorchVision, Index}):
            # 这些模块的构造参数有特殊安排:通常 args[0] 是输出通道,c1 来自 ch[f]
            c2 = args[0]
            c1 = ch[f]
            args = [*args[1:]]  # 剥离第一个参数(因为已用 c2/c1 替代)
        else:
            # 默认分支:输出通道等于来自 f 的层的通道(c2 = ch[f])
            c2 = ch[f]

        # ----------------- 用解析好的 args 构建模块实例 -----------------
        # 如果 n > 1(重复次数),用 torch.nn.Sequential 将同一模块重复 n 次;否则直接实例化
        m_ = torch.nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module

        # t 保存模块类型的字符串(用于打印)
        t = str(m)[8:-2].replace("__main__.", "")  # module type

        # m_.np 保存该模块的参数数量(用于报表)
        m_.np = sum(x.numel() for x in m_.parameters())  # number params

        # 给模块实例附加一些元信息:索引 i,来源 f,以及类型字符串 t
        m_.i, m_.f, m_.type = i, f, t  # attach index, 'from' index, type

        # 根据 verbose 设置打印每一层的摘要信息:索引、来源、重复次数、参数量、模块名、构造参数
        if verbose:
            LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m_.np:10.0f}  {t:<45}{str(args):<30}")  # print

        # 把需要保存的层索引加入 save 列表:
        # - 如果 f 是 int,则保存 f % i(支持引用相对索引)
        # - 如果 f 是 list,则对每个元素做同样操作(常见于多输入模块)
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist

        # 把构造好的模块放进层列表
        layers.append(m_)

        # 如果 i==0(第一层)执行清空 ch(因为 ch 初始只含输入通道,这里开始记录每层输出)
        if i == 0:
            ch = []
        # 把当前层的输出通道 c2 添加到 ch 列表,以便后续层引用
        ch.append(c2)

    # 循环结束后,把所有模块打包为一个 nn.Sequential,并返回 savelist 的排序版本
    return torch.nn.Sequential(*layers), sorted(save)
 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值