mlp

import torch
from d2l import torch as d2l
from torch import nn
# Mini-batch size passed to the d2l Fashion-MNIST loader.
batch_size = 100
# The loader's result is unpacked into a training and a test iterator.
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
# Layer widths of the one-hidden-layer MLP.
input_size = 784  # presumably 28 * 28 flattened Fashion-MNIST pixels -- confirm
hidden_size = 300
output_size = 10  # presumably one score per class -- confirm


def _small_normal(*shape):
    """Return a trainable parameter of the given shape, drawn as randn * 0.01.

    nn.Parameter marks the wrapped tensor as requiring gradients by default,
    so the source tensor is created without its own ``requires_grad`` flag.
    """
    return nn.Parameter(torch.randn(*shape) * 0.01)


# Created in the order W1, b1, W2, b2 so the random draws keep their sequence.
W1 = _small_normal(input_size, hidden_size)
b1 = _small_normal(1, hidden_size)

W2 = _small_normal(hidden_size, output_size)
b2 = _small_normal(1, output_size)

# Flat list of every trainable tensor; this is what the optimizer receives.
params = [W1, b1, W2, b2]
# Notebook-style echo of the four parameter shapes.
W1.shape, b1.shape, W2.shape, b2.shape
(torch.Size([784, 300]),
 torch.Size([1, 300]),
 torch.Size([300, 10]),
 torch.Size([1, 10]))
def relu(X):
    """Element-wise rectifier: keep the larger of each entry of X and zero."""
    # The zero tensor is built with zeros_like so it matches X entry for entry.
    return torch.maximum(X, torch.zeros_like(X))


# Quick check on a random 1x2 input.
relu(torch.randn(1, 2))
tensor([[2.3051, 0.0000]])
def net(X):
    """Forward pass: flatten X, apply the ReLU hidden layer, then the output layer."""
    # Keep the first axis and collapse all remaining axes into one.
    flat = X.reshape(X.shape[0], -1)
    hidden = relu(flat @ W1 + b1)
    return hidden @ W2 + b2
# reduction="none" leaves the cross-entropy values unreduced.
loss = nn.CrossEntropyLoss(reduction="none")
# Learning rate for plain SGD over the hand-built parameter list.
lr = 0.1
trainer = torch.optim.SGD(params, lr=lr)
# Print the signature and docstring of the d2l training helper.
help(d2l.train_ch3)
Help on function train_ch3 in module d2l.torch:

train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)
    Train a model (defined in Chapter 3).
    
    Defined in :numref:`sec_softmax_scratch`

# Epoch count handed to the training helper.
num_epoch = 10
# Keyword names follow the signature printed by help(d2l.train_ch3).
d2l.train_ch3(
    net=net,
    train_iter=train_iter,
    test_iter=test_iter,
    loss=loss,
    num_epochs=num_epoch,
    updater=trainer,
)

重点函数

  • torch.zeros_like(x):创建一个与 x 形状(shape)相同、元素全为 0 的张量(数据类型和所在设备也与 x 一致)

简洁版

import torch
from torch import nn
from d2l import torch as d2l

# Mini-batch size passed to the d2l Fashion-MNIST loader.
batch_size = 100
# The loader's result is unpacked into a training and a test iterator.
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)
# The MLP built from library layers: flatten, 784 -> 500 linear, ReLU,
# 500 -> 10 linear.
net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(in_features=784, out_features=500),
    nn.ReLU(),
    nn.Linear(in_features=500, out_features=10),
)
# Learning rate for plain SGD over the module's own parameters.
lr = 0.1
# reduction="none" leaves the cross-entropy values unreduced.
loss = nn.CrossEntropyLoss(reduction="none")
trainer = torch.optim.SGD(net.parameters(), lr=lr)
# Print the signature and docstring of the d2l training helper.
help(d2l.train_ch3)
Help on function train_ch3 in module d2l.torch:

train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)
    Train a model (defined in Chapter 3).
    
    Defined in :numref:`sec_softmax_scratch`

# Epoch count handed to the training helper.
num_epoch = 10
# Keyword names follow the signature printed by help(d2l.train_ch3).
d2l.train_ch3(
    net=net,
    train_iter=train_iter,
    test_iter=test_iter,
    loss=loss,
    num_epochs=num_epoch,
    updater=trainer,
)


posted @ 2024-06-24 13:31  Mr小明同学  阅读(20)  评论(0)  编辑  收藏  举报