随笔分类 - PyTorch
摘要:方法一 网络模型、数据(输入、标注)以及损失函数.cuda() 点击查看代码 import torch import torchvision from torch import nn from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, S
阅读全文
摘要:准备数据集 点击查看代码 train_data = torchvision.datasets.CIFAR10("./dataset1", train=True, download=True, transform=torchvision.transforms.ToTensor()) test_data
阅读全文
摘要:保存 点击查看代码 import torch import torchvision from torch import nn from torch.nn import Conv2d vgg16_false = torchvision.models.vgg16(pretrained=False) #
阅读全文
摘要:点击查看代码 import torch import torchvision from torch import nn vgg16_false = torchvision.models.vgg16(pretrained=False) # pretrained=True 下载训练好的模型 # vgg1
阅读全文
摘要:lr : learning rate 学习速率 点击查看代码 import torch import torchvision from torch import nn from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequentia
阅读全文
摘要:损失函数 1.计算实际输出与目标之间的差距 2.为更新输出提供一定的依据(反向传播)--grad 损失函数(L1Loss、MSELoss、CrossEntropyLoss) import torch from torch.nn import L1Loss, MSELoss, CrossEntropy
阅读全文
摘要:点击查看代码 import torch from torch import nn from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential from torch.utils.tensorboard import Summa
阅读全文
摘要:点击查看代码 import torch import torchvision from torch import nn from torch.nn import Linear from torch.utils.data import DataLoader dataset = torchvision.
阅读全文
摘要:非线性激活:引入非线性特征 点击查看代码 import torch import torchvision from torch import nn from torch.nn import ReLU, Sigmoid # input = torch.tensor([[1, -0.5], # [-1,
阅读全文
摘要:最大池化目的:保留输入特征,减小数据量 点击查看代码 import torch import torchvision from torch import nn from torch.nn import MaxPool2d # 最大池化目的:保留输入特征,减小数据量 from torch.utils.
阅读全文
摘要:点击查看代码 import torch import torchvision from torch import nn from torch.nn import Conv2d from torch.utils.data import DataLoader from torch.utils.tenso
阅读全文
摘要:点击查看代码 import torch import torch.nn.functional as F input = torch.tensor([[1, 2, 0, 3, 1], [0, 1, 2, 3, 1], [1, 2, 1, 0, 0], [5, 2, 3, 1, 1], [2, 1, 0
阅读全文
摘要:点击查看代码 import torch from torch import nn class Test(nn.Module): def __init__(self): super().__init__() def forward(self, input): output = input + 1 re
阅读全文
摘要:点击查看代码 import torchvision from torch.utils.data import DataLoader # 测试集 from torch.utils.tensorboard import SummaryWriter test_set = torchvision.datas
阅读全文