随笔 - 480  文章 - 0 评论 - 45 阅读 - 73万
< 2025年3月 >
23 24 25 26 27 28 1
2 3 4 5 6 7 8
9 10 11 12 13 14 15
16 17 18 19 20 21 22
23 24 25 26 27 28 29
30 31 1 2 3 4 5

基本的卷积神经网络

复制代码
from torch import nn

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        layer1 = nn.Sequential() # 将网络模型进行添加
        layer1.add_module('conv1', nn.Conv2d(3, 32, 3, 1, padding=1)) # nn.Conv
        layer1.add_module('relu1', nn.ReLU(True))
        layer1.add_module('pool1', nn.MaxPool2d(2, 2))
        self.layer1 = layer1

        layer2 = nn.Sequential()
        layer2.add_module('conv2', nn.Conv2d(32, 64, 3, 1, padding=1))
        layer2.add_module('relu2', nn.ReLU(True))
        layer2.add_module('pool2', nn.MaxPool2d(2, 2))
        self.layer2 = layer2

        layer3 = nn.Sequential()
        layer3.add_module('conv3', nn.Conv2d(64, 128, 3, 1, padding=1))
        layer3.add_module('relu3', nn.ReLU(True))
        layer3.add_module('pool3', nn.MaxPool2d(2, 2))
        self.layer3 = layer3

        layer4 = nn.Sequential()
        layer4.add_module('fc1', nn.Linear(2048, 512))
        layer4.add_module('fc_relu1', nn.ReLU(True))
        layer4.add_module('fc2', nn.Linear(512, 64))
        layer4.add_module('fc_relu2', nn.ReLU(True))
        layer4.add_module('fc3', nn.Linear(64, 10))
        self.layer4 = layer4

    def forward(self, x):
        conv1 = self.layer1(x)
        conv2 = self.layer2(conv1)
        conv3 = self.layer3(conv2)
        fc_input = conv3.view(conv3.size(0), -1)
        fc_out = self.layer4(fc_input)

        return fc_out

model = SimpleCNN()
# print(model) # 打印输出网络结构
复制代码

提取前两层网络结构 

new_model = nn.Sequential(*list(model.children())[:2])  # 提取前两层的网络结构, 构造nn.Sequential网络串接, * 表示将里面的内容一个个传进去

提取所有的卷积层网络

conv_model = nn.Sequential()
# 提取所有的卷积层操作
for name, layer in model.named_modules():
    if isinstance(layer, nn.Conv2d):
        name = name.replace('.', '_')
        conv_model.add_module(name, layer)
print(conv_model)

打印卷积层的网络名字

for param in model.named_parameters():
    print(param)

对权重参数进行初始化操作

复制代码
from torch.nn import init
# 对权重参数进行初始化操作
for m in model.modules():
    if isinstance(m, nn.Conv2d):
        init.normal(m.weight.data)
        init.xavier_normal(m.weight.data)
        init.kaiming_normal(m.weight.data)
    elif isinstance(m, nn.Linear):
        m.weight.data.normal_()
复制代码

 

posted on   python我的最爱  阅读(2309)  评论(2编辑  收藏  举报
编辑推荐:
· 10年+ .NET Coder 心语,封装的思维:从隐藏、稳定开始理解其本质意义
· .NET Core 中如何实现缓存的预热?
· 从 HTTP 原因短语缺失研究 HTTP/2 和 HTTP/3 的设计差异
· AI与.NET技术实操系列:向量存储与相似性搜索在 .NET 中的实现
· 基于Microsoft.Extensions.AI核心库实现RAG应用
阅读排行:
· 10年+ .NET Coder 心语 ── 封装的思维:从隐藏、稳定开始理解其本质意义
· 地球OL攻略 —— 某应届生求职总结
· 提示词工程——AI应用必不可少的技术
· Open-Sora 2.0 重磅开源!
· 周边上新:园子的第一款马克杯温暖上架
点击右上角即可分享
微信分享提示