pytroch 掌握深度模型构建精髓

pytorch几十行代码搞清楚模型的构建和训练

复制代码
import torch
import torch.nn as nn

N, D_in, H, D_out = 64, 1000, 100, 10
# data
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# mdoel define
class TwoLayerNet(nn.Module):
    def __init__(self, D_in, H, D_out):
        # main layers
        super(TwoLayerNet, self).__init__()
        self.linear1 = nn.Linear(D_in, H)
        self.linear2 = nn.Linear(H, D_out)
        
    def forward(self, x):
        y_pred = self.linear2(self.linear1(x).clamp(min=0))
        return y_pred
    
# init model
loss_fn = nn.MSELoss(reduction='sum')
model = TwoLayerNet(D_in, H, D_out)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# training
for i in range(500):
    # 1.forward pass
    y_pred = model(x)
    
    # 2.compute loss
    loss = loss_fn(y_pred, y)
    print(i, loss.item())
    
    optimizer.zero_grad()
    # 3.backward pass
    loss.backward()
    
    # 4.weights update
    optimizer.step()
    
复制代码

 

posted @   今夜无风  阅读(166)  评论(0编辑  收藏  举报
编辑推荐:
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 没有源码,如何修改代码逻辑?
阅读排行:
· 分享4款.NET开源、免费、实用的商城系统
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
· 上周热点回顾(2.24-3.2)
点击右上角即可分享
微信分享提示