使用不同函数打印torch.nn模型——print(model),named_children(),named_modules():

创建模型

创建一个具有三级嵌套的模型,结构如图:
image

import torch
import torch.nn as nn

# 定义子子模块
class SubSubModule(nn.Module):
    def __init__(self):
        super(SubSubModule, self).__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

# 定义子模块
class SubModule(nn.Module):
    def __init__(self):
        super(SubModule, self).__init__()
        self.sub_sub_module = SubSubModule()  # 实例化子子模块
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        x = self.sub_sub_module(x)  # 使用子子模块
        x = torch.relu(x)
        x = self.pool(x)
        return x

# 定义主模块
class MainModule(nn.Module):
    def __init__(self):
        super(MainModule, self).__init__()
        self.sub_module = SubModule()  # 实例化子模块
        self.fc = nn.Linear(3 * 16 * 16, 10)  # 假设输入图像大小为 32x32

    def forward(self, x):
        x = self.sub_module(x)  # 使用子模块
        x = x.view(x.size(0), -1)  # 展平特征图
        x = self.fc(x)
        return x

# 实例化主模块
model = MainModule()

# 打印模型结构
print(model)

使用print直接打印

直接使用print函数打印,会以整个模型为单位打印

# 实例化主模块
model = MainModule()

# 打印模型结构
print(model)

image

使用named_children()函数打印模型的子模块

named_children()只会打印children,也就是子模块,至于孙子,曾孙子...一律不打印,即 子子模块及以下的都都不会打印

#打印模型的子模块
for name, module in model.named_children():
    print(name, module)

image
image

使用named_modules函数打印模型的子模块

named_modules从命名就可以看出,会遍历模型中的所有模块(与named_children()恰恰相反),从主模块到子模块到子子模块到子子...子模块,每一个模块都会打印出来

#打印模型的所有模块
for name, module in model.named_modules():
    print(name, module)

image

使用named_parameters()函数打印模型的可学习参数

#打印模型的可学习参数
for name, param in model.named_parameters():
    print(name, param.size())

image

posted @ 2024-07-01 16:34  seekwhale13  阅读(43)  评论(0编辑  收藏  举报