2023.5.17·
今天学习了使用 PyTorch 对图片进行识别训练。
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

# Peek at a single batch to show the tensor shapes, then stop iterating.
for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

# Get cpu, gpu or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")


# Define model
class NeuralNetwork(nn.Module):
    """Fully connected classifier: flatten, then three linear layers with ReLU.

    The first linear layer takes 28*28 input features and the last one
    produces 10 output values (raw logits, no softmax applied).
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        """Flatten ``x`` and return the logits produced by the linear stack."""
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


model = NeuralNetwork().to(device)
print(model)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)


def train(dataloader, model, loss_fn, optimizer):
    """Run one pass over ``dataloader``, updating ``model`` after every batch.

    Args:
        dataloader: Iterable of ``(X, y)`` batches; must expose ``.dataset``.
        model: Module to optimize; switched to training mode here.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar loss tensor.
        optimizer: Optimizer holding ``model``'s parameters.

    Batches are moved to the module-level ``device``. The current loss and
    the number of samples seen so far are printed every 100 batches.
    """
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        # Clear gradients so they do not accumulate into the next batch.
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")


def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` and print accuracy and mean loss.

    Args:
        dataloader: Iterable of ``(X, y)`` batches; must expose ``.dataset``.
        model: Module to evaluate; switched to evaluation mode here.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar loss tensor.

    The loss is averaged over the number of batches, and accuracy is the
    count of correct argmax predictions divided by the dataset size.
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")

torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")