pytorch4-pytorch lightning

1,pytorch和pytorch lightning的区别

https://pytorch-lightning.readthedocs.io/en/latest/starter/new-project.html

  • PyTorch Lightning 为您提供构建模型、数据集等所需的 API。 PyTorch 拥有训练模型所需的一切; 然而,深度学习不仅仅是附加层。 在实际训练方面,您需要编写大量样板代码,如果您需要在多台设备/机器上扩展您的训练/推理,则可能需要进行另一组集成。
  • PyTorch Lightning 为您解决了这些问题。 您所需要的只是重组一些现有代码,设置某些标志,然后就完成了。 现在,您可以在 GPU/TPU/IPU 等不同的加速器上训练您的模型,使用最先进的分布式训练机制在多个机器/节点上进行分布式训练,而无需更改代码。
  • 代码组织是 Lightning 的核心。 它将研究逻辑留给您,并使其余部分自动化。

2,pytorch lightning应用的一个例子

  • pytorch lightning构建的是一个系统,而不仅仅是一个模型
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl

'''Step 1: Define LightningModule'''
class LitAutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step defined the train loop.
        # It is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        # Logging to TensorBoard by default
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

''' Step 2: Fit with Lightning Trainer'''
dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train_loader = DataLoader(dataset)

# init model
autoencoder = LitAutoEncoder()

# most basic trainer, uses good defaults (auto-tensorboard, checkpoints, logs, and more)
# trainer = pl.Trainer(gpus=8) (if you have GPUs)
trainer = pl.Trainer()
trainer.fit(model=autoencoder, train_dataloaders=train_loader)

''' Step 3: Predict or Deploy, three options'''
# Option 1: Sub-models
# to use as embedding extractor
autoencoder = LitAutoEncoder.load_from_checkpoint("path/to/checkpoint_file.ckpt")
encoder_model = autoencoder.encoder
encoder_model.eval()
# to use as image generator
decoder_model = autoencoder.decoder
decoder_model.eval()

# Option 2: Forward
# using the AE to extract embeddings
class LitAutoEncoder(LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 64))
    def forward(self, x):
        embedding = self.encoder(x)
        return embedding
autoencoder = LitAutoEncoder()
embedding = autoencoder(torch.rand(1, 28 * 28))
# using the AE to generate images
class LitAutoEncoder(LightningModule):
    def __init__(self):
        super().__init__()
        self.decoder = nn.Sequential(nn.Linear(64, 28 * 28))
    def forward(self):
        z = torch.rand(1, 64)
        image = self.decoder(z)
        image = image.view(1, 1, 28, 28)
        return image
autoencoder = LitAutoEncoder()
image_sample = autoencoder()

# Option 3: Production

3,另一个例子

https://www.kaggle.com/somesh88/pytorch-lightning-submission-nbme/notebook

posted @   tensor_zhang  阅读(103)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律
点击右上角即可分享
微信分享提示

目录传送