机器学习 - 非线性激活（PyTorch 环境）
一个例子：用 Sigmoid 对 CIFAR10 测试集图像做非线性激活，并在 TensorBoard 中对比激活前后的图像。
# Demo: apply an element-wise non-linear activation (Sigmoid) to CIFAR10 test
# images and log the inputs and the activated outputs to TensorBoard.
import torch
import torchvision.datasets
import torchvision.transforms
from torch import nn
from torch.nn import Conv2d, MaxPool2d, ReLU, Sigmoid
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# ToTensor converts each PIL image to a float tensor laid out as (C, H, W)
# with values scaled to [0.0, 1.0].
dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])

# CIFAR10 test split (train=False); downloaded into ./dataset if missing.
dataset = torchvision.datasets.CIFAR10(
    root='./dataset',
    train=False,
    transform=dataset_transform,
    download=True,
)

# batch_size=1: every TensorBoard step below logs exactly one image.
dataLoader = DataLoader(
    dataset=dataset,
    batch_size=1,
    shuffle=True,
    num_workers=0,
    drop_last=False,
)


class TuDui(nn.Module):
    """Parameter-free module that applies an element-wise activation."""

    def __init__(self):
        super().__init__()
        # Both activations are registered; forward() currently uses Sigmoid.
        # Since ToTensor output is already in [0, 1], ReLU would leave these
        # images unchanged, whereas Sigmoid maps [0, 1] to about [0.5, 0.73].
        self.relu = ReLU()
        self.sigmoid = Sigmoid()

    def forward(self, input):
        """Return ``sigmoid(input)`` element-wise; the shape is unchanged."""
        output = self.sigmoid(input)
        return output


tudui = TuDui()

# The context manager closes (and flushes) the writer even if the loop raises.
# NOTE(review): the log dir is named "ReLU" but the logged outputs come from
# Sigmoid; rename it if that is not intended.
# no_grad: this is pure inference, so no autograd graph is needed.
with SummaryWriter("ReLU") as writer, torch.no_grad():
    # DataLoader yields (images, labels); the labels are not needed here.
    # Image batches are (batch, C, H, W), matching add_images' default NCHW.
    for step, (imgs, _targets) in enumerate(dataLoader):
        output = tudui(imgs)
        writer.add_images("input", imgs, step)
        writer.add_images("output", output, step)