机器学习：非线性激活（PyTorch 环境）

官方文档：PyTorch `torch.nn` 中的非线性激活函数（Non-linear Activations）部分

一个例子：对 CIFAR-10 测试集图像应用 Sigmoid 激活，并在 TensorBoard 中对比输入与输出

import torch
import torchvision.datasets
from torch import nn
from torch.nn import ReLU, Sigmoid
from torch.nn import Conv2d, MaxPool2d, ReLU
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# Pipeline applied to every dataset sample: convert the image to a tensor.
dataset_transform = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor()]
)

# CIFAR-10 test split (train=False), downloaded into ./dataset if missing.
dataset = torchvision.datasets.CIFAR10(
    root='./dataset',
    train=False,
    transform=dataset_transform,
    download=True,
)

# TensorBoard event files go under the "ReLU" directory.
# NOTE(review): TuDui.forward applies Sigmoid, not ReLU, so this run name
# may be misleading — consider renaming once nothing depends on the path.
writer = SummaryWriter("ReLU")

# One image per batch in shuffled order, loaded in the main process;
# drop_last=False keeps any incomplete final batch.
dataLoader = DataLoader(
    dataset=dataset,
    batch_size=1,
    shuffle=True,
    num_workers=0,
    drop_last=False,
)

class TuDui(nn.Module):
    """Tiny module that applies a single element-wise non-linear activation.

    Args:
        activation: Which activation ``forward`` applies: ``"sigmoid"``
            (the default, matching the original behaviour) or ``"relu"``.

    Raises:
        ValueError: If ``activation`` is not a supported name.
    """

    _SUPPORTED = ("relu", "sigmoid")

    def __init__(self, *, activation="sigmoid"):
        super().__init__()
        # Both submodules are always created so existing code that reads
        # ``tudui.relu`` / ``tudui.sigmoid`` keeps working.
        self.relu = ReLU()
        self.sigmoid = Sigmoid()
        if activation not in self._SUPPORTED:
            raise ValueError(
                f"activation must be one of {self._SUPPORTED}, got {activation!r}"
            )
        self.activation = activation

    def forward(self, input):
        """Apply the selected activation element-wise to ``input``."""
        if self.activation == "relu":
            return self.relu(input)
        return self.sigmoid(input)

# Push every batch through the activation module and log input/output
# images side by side in TensorBoard, one step per batch.
tudui = TuDui()
step = 0
try:
    # Inference only: nothing is trained, so skip autograd bookkeeping.
    with torch.no_grad():
        for imgs, targets in dataLoader:
            output = tudui(imgs)
            writer.add_images("input", imgs, step)
            writer.add_images("output", output, step)
            step += 1
finally:
    # Close the writer even if the loop is interrupted (e.g. Ctrl-C), so
    # already-logged events are not left unflushed.
    writer.close()

 

posted @ 2021-08-14 17:40  EA2218764AB  阅读(53)  评论(0)  编辑  收藏  举报