yolov10中,tasks.py的parse_model方法详解

在yolov10改进的时候,经常可以看到需要修改parse_model方法,但是相信很多东西都不知道这个方法是干嘛的,以及流程方式,所以今天给大家详细介绍一下这些变量的含义和作用,方便大家理解原理。

yolov10中,tasks.py的parse_model方法详解_目标检测

源代码

def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
    """Parse a YOLO model.yaml dictionary into a PyTorch model."""
    import ast

    # Args
    max_channels = float("inf")
    nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
    depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
    if scales:
        scale = d.get("scale")
        if not scale:
            scale = tuple(scales.keys())[0]
            LOGGER.warning(f"WARNING ⚠️ no model scale passed. Assuming scale='{scale}'.")
        depth, width, max_channels = scales[scale]

    if act:
        Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = nn.SiLU()
        if verbose:
            LOGGER.info(f"{colorstr('activation:')} {act}")  # print

    if verbose:
        LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10}  {'module':<45}{'arguments':<30}")
    ch = [ch]
    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]):  # from, number, module, args
        m = getattr(torch.nn, m[3:]) if "nn." in m else globals()[m]  # get module
        for j, a in enumerate(args):
            if isinstance(a, str):
                with contextlib.suppress(ValueError):
                    args[j] = locals()[a] if a in locals() else ast.literal_eval(a)

        n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain
        if m in {
            Classify,
            Conv,
            ConvTranspose,
            GhostConv,
            Bottleneck,
            GhostBottleneck,
            SPP,
            SPPF,
            DWConv,
            Focus,
            BottleneckCSP,
            C1,
            C2,
            C2f,
            RepNCSPELAN4,
            ADown,
            SPPELAN,
            C2fAttn,
            C3,
            C3TR,
            C3Ghost,
            nn.ConvTranspose2d,
            DWConvTranspose2d,
            C3x,
            RepC3,
            PSA,
            SCDown,
            C2fCIB
        }:
            c1, c2 = ch[f], args[0]
            if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            if m is C2fAttn:
                args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8)  # embed channels
                args[2] = int(
                    max(round(min(args[2], max_channels // 2 // 32)) * width, 1) if args[2] > 1 else args[2]
                )  # num heads

            args = [c1, c2, *args[1:]]
            if m in (BottleneckCSP, C1, C2, C2f, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3, C2fCIB):
                args.insert(2, n)  # number of repeats
                n = 1
        elif m in {CARAFE}:
            c2 = ch[f]
            args = [c2,*args]
        elif m is AIFI:
            args = [ch[f], *args]
        elif m in {HGStem, HGBlock}:
            c1, cm, c2 = ch[f], args[0], args[1]
            args = [c1, cm, c2, *args[2:]]
            if m is HGBlock:
                args.insert(4, n)  # number of repeats
                n = 1
        elif m is ResNetLayer:
            c2 = args[1] if args[3] else args[1] * 4
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[x] for x in f)
        elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, v10Detect}:
            args.append([ch[x] for x in f])
            if m is Segment:
                args[2] = make_divisible(min(args[2], max_channels) * width, 8)
        elif m is RTDETRDecoder:  # special case, channels arg must be passed in index 1
            args.insert(1, [ch[x] for x in f])
        elif m is CBLinear:
            c2 = args[0]
            c1 = ch[f]
            args = [c1, c2, *args[1:]]
        elif m is CBFuse:
            c2 = ch[f[-1]]
        else:
            c2 = ch[f]

        m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace("__main__.", "")  # module type
        m.np = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type = i, f, t  # attach index, 'from' index, type
        if verbose:
            LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f}  {t:<45}{str(args):<30}")  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.
  • 16.
  • 17.
  • 18.
  • 19.
  • 20.
  • 21.
  • 22.
  • 23.
  • 24.
  • 25.
  • 26.
  • 27.
  • 28.
  • 29.
  • 30.
  • 31.
  • 32.
  • 33.
  • 34.
  • 35.
  • 36.
  • 37.
  • 38.
  • 39.
  • 40.
  • 41.
  • 42.
  • 43.
  • 44.
  • 45.
  • 46.
  • 47.
  • 48.
  • 49.
  • 50.
  • 51.
  • 52.
  • 53.
  • 54.
  • 55.
  • 56.
  • 57.
  • 58.
  • 59.
  • 60.
  • 61.
  • 62.
  • 63.
  • 64.
  • 65.
  • 66.
  • 67.
  • 68.
  • 69.
  • 70.
  • 71.
  • 72.
  • 73.
  • 74.
  • 75.
  • 76.
  • 77.
  • 78.
  • 79.
  • 80.
  • 81.
  • 82.
  • 83.
  • 84.
  • 85.
  • 86.
  • 87.
  • 88.
  • 89.
  • 90.
  • 91.
  • 92.
  • 93.
  • 94.
  • 95.
  • 96.
  • 97.
  • 98.
  • 99.
  • 100.
  • 101.
  • 102.
  • 103.
  • 104.
  • 105.
  • 106.
  • 107.
  • 108.
  • 109.
  • 110.
  • 111.
  • 112.
  • 113.
  • 114.
  • 115.
  • 116.
  • 117.
  • 118.
  • 119.


代码解析

parse_model 函数的作用是将 YOLO 模型的配置(通常在 YAML 文件中定义)解析并构建成 PyTorch 模型。这个函数接收一个模型配置字典 d,输入通道数 ch,以及一个可选的布尔值 verbose 来控制是否打印详细的构建信息。下面是对这个函数的逐行解释:

  1. 导入模块
import ast
  • 1.

导入 ast 模块,用于将字符串形式的 Python 表达式转换为 Python 对象。

  1. 提取配置参数
max_channels = float("inf")
nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
  • 1.
  • 2.
  • 3.

从模型配置字典 d 中提取必要的参数,包括类别数 nc、激活函数 act、尺度 scales、深度倍数 depth、宽度倍数 width 和关键点形状 kpt_shape。如果 scales 存在,则进一步提取模型的尺度参数。

  1. 设置默认激活函数
if act:
    Conv.default_act = eval(act)
    if verbose:
        LOGGER.info(f"{colorstr('activation:')} {act}")
  • 1.
  • 2.
  • 3.
  • 4.

如果配置中指定了激活函数,则设置为 Conv 类的默认激活函数,并在 verbose 模式下打印。

  1. 初始化日志信息
if verbose:
    LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10}  {'module':<45}{'arguments':<30}")
  • 1.
  • 2.

如果 verbose 为 True,初始化日志信息的格式。

  1. 初始化构建参数
ch = [ch]
layers, save, c2 = [], [], ch[-1]
  • 1.
  • 2.

初始化通道列表 ch,层列表 layers,保存列表 save 和当前输出通道数 c2

  1. 遍历模型配置并构建层
for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]):
  • 1.

遍历模型的 backbonehead 配置,f 表示输入来源,n 表示重复次数,m 表示模块类型,args 表示模块参数。

  1. 获取模块
m = getattr(torch.nn, m[3:]) if "nn." in m else globals()[m]
  • 1.

根据模块类型 m 获取对应的 PyTorch 模块。

  1. 处理字符串参数
for j, a in enumerate(args):
    if isinstance(a, str):
        with contextlib.suppress(ValueError):
            args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
  • 1.
  • 2.
  • 3.
  • 4.

args 中的字符串参数转换为实际的 Python 对象。

  1. 调整深度和宽度
n = n_ = max(round(n * depth), 1) if n > 1 else n
  • 1.

根据深度倍数调整重复次数。

  1. 构建特定模块
  • 对于不同的模块类型,如 ClassifyConvConvTranspose 等,根据输入通道 ch[f]、输出通道 args[0] 和其他参数构建模块。
  • 对于 C2fAttn 模块,特别处理嵌入通道数和头数。
  1. 处理特殊模块
  • 对于 CARAFEAIFIHGStemHGBlockResNetLayernn.BatchNorm2dConcatDetectRTDETRDecoderCBLinear 和 CBFuse 等特殊模块类型,进行特定的参数处理。
  1. 构建模块并附加信息
m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)
t = str(m)[8:-2].replace("__main__.", "")
m.np = sum(x.numel() for x in m_.parameters())
m_.i, m_.f, m_.type = i, f, t
if verbose:
    LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f}  {t:<45}{str(args):<30}")
save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)
layers.append(m_)
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.

构建模块 m_,计算参数数量 m.np,并附加索引、来源和类型信息。如果 verbose 为 True,打印模块信息,并更新保存列表。

  1. 更新通道列表
if i == 0:
    ch = []
ch.append(c2)
  • 1.
  • 2.
  • 3.

更新通道列表 ch,为下一层的构建做准备。

  1. 返回构建的模型和保存列表
return nn.Sequential(*layers), sorted(save)
  • 1.

返回构建的 PyTorch 模型和需要保存的层的索引列表。

总结:parse_model 函数的作用是根据 YOLO 模型的配置字典,构建并返回一个 PyTorch 模型和需要保存的层的索引列表。这个函数处理了多种模块类型和参数,能够灵活地构建复杂的 YOLO 模型架构。

posted @   微智启软件工作室  阅读(129)  评论(0编辑  收藏  举报  
相关博文:
阅读排行:
· 全程不用写代码,我用AI程序员写了一个飞机大战
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!
点击右上角即可分享
微信分享提示