机器学习-线性层(PyTorch 环境)
一个例子:
import torch
import torchvision.datasets
from torch import nn
from torch.nn import Conv2d, Linear, MaxPool2d, ReLU, Sigmoid
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# Convert each dataset image to a tensor.
dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])

# CIFAR-10 test split (train=False), downloaded to ./dataset if missing.
dataset = torchvision.datasets.CIFAR10(
    root='./dataset',
    train=False,
    transform=dataset_transform,
    download=True,
)
dataLoader = DataLoader(
    dataset=dataset,
    batch_size=64,
    shuffle=True,
    num_workers=0,
    drop_last=True,
)


class TuDui(nn.Module):
    """Minimal module wrapping a single fully connected (linear) layer.

    Args:
        in_features: Size of the last dimension of the input. Defaults to
            196608, the value originally hard-coded here (a whole batch of
            64 images flattened into one vector).
        out_features: Size of the last dimension of the output. Defaults
            to 10.
    """

    def __init__(self, in_features=196608, out_features=10):
        super().__init__()
        self.linear = Linear(in_features, out_features)

    def forward(self, input):
        """Apply the linear layer to ``input`` and return the result."""
        return self.linear(input)


# Size the layer from one sample rather than from a whole batch, so the
# model does not depend on batch_size (or on drop_last=True).
features_per_image = dataset[0][0].numel()
tudui = TuDui(in_features=features_per_image)

for data in dataLoader:
    imgs, targets = data
    print(imgs.shape)
    # Flatten everything except the batch dimension. A bare
    # torch.flatten(imgs) would merge the whole batch into one vector and
    # yield a single prediction for all 64 images.
    output = torch.flatten(imgs, start_dim=1)
    print(output.shape)
    output = tudui(output)
    print(output.shape)
class Linear(Module):
    r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`

    This module supports :ref:`TensorFloat32<tf32_on_ampere>`.

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input: :math:`(N, *, H_{in})` where :math:`*` means any number of
          additional dimensions and :math:`H_{in} = \text{in\_features}`
        - Output: :math:`(N, *, H_{out})` where all but the last dimension
          are the same shape as the input and :math:`H_{out} = \text{out\_features}`.

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`(\text{out\_features}, \text{in\_features})`. The values are
            initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{in\_features}}`
        bias:   the learnable bias of the module of shape :math:`(\text{out\_features})`.
                If :attr:`bias` is ``True``, the values are initialized from
                :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                :math:`k = \frac{1}{\text{in\_features}}`

    Examples::

        >>> m = nn.Linear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """