pytorch4-pytorch lightning
1,pytorch和pytorch lightning的区别
https://pytorch-lightning.readthedocs.io/en/latest/starter/new-project.html
- PyTorch Lightning 为您提供构建模型、数据集等所需的 API。 PyTorch 拥有训练模型所需的一切; 然而,深度学习不仅仅是附加层。 在实际训练方面,您需要编写大量样板代码,如果您需要在多台设备/机器上扩展您的训练/推理,则可能需要进行另一组集成。
- PyTorch Lightning 为您解决了这些问题。 您所需要的只是重组一些现有代码,设置某些标志,然后就完成了。 现在,您可以在 GPU/TPU/IPU 等不同的加速器上训练您的模型,使用最先进的分布式训练机制在多个机器/节点上进行分布式训练,而无需更改代码。
- 代码组织是 Lightning 的核心。 它将研究逻辑留给您,并使其余部分自动化。
2,pytorch lightning应用的一个例子
- pytorch lightning构建的是一个系统,而不仅仅是一个模型
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl
'''Step 1: Define LightningModule'''
class LitAutoEncoder(pl.LightningModule):
def __init__(self):
super().__init__()
self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
self.decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
def forward(self, x):
# in lightning, forward defines the prediction/inference actions
embedding = self.encoder(x)
return embedding
def training_step(self, batch, batch_idx):
# training_step defined the train loop.
# It is independent of forward
x, y = batch
x = x.view(x.size(0), -1)
z = self.encoder(x)
x_hat = self.decoder(z)
loss = F.mse_loss(x_hat, x)
# Logging to TensorBoard by default
self.log("train_loss", loss)
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return optimizer
''' Step 2: Fit with Lightning Trainer'''
dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train_loader = DataLoader(dataset)
# init model
autoencoder = LitAutoEncoder()
# most basic trainer, uses good defaults (auto-tensorboard, checkpoints, logs, and more)
# trainer = pl.Trainer(gpus=8) (if you have GPUs)
trainer = pl.Trainer()
trainer.fit(model=autoencoder, train_dataloaders=train_loader)
''' Step 3: Predict or Deploy, three options'''
# Option 1: Sub-models
# to use as embedding extractor
autoencoder = LitAutoEncoder.load_from_checkpoint("path/to/checkpoint_file.ckpt")
encoder_model = autoencoder.encoder
encoder_model.eval()
# to use as image generator
decoder_model = autoencoder.decoder
decoder_model.eval()
# Option 2: Forward
# using the AE to extract embeddings
class LitAutoEncoder(LightningModule):
def __init__(self):
super().__init__()
self.encoder = nn.Sequential(nn.Linear(28 * 28, 64))
def forward(self, x):
embedding = self.encoder(x)
return embedding
autoencoder = LitAutoEncoder()
embedding = autoencoder(torch.rand(1, 28 * 28))
# using the AE to generate images
class LitAutoEncoder(LightningModule):
def __init__(self):
super().__init__()
self.decoder = nn.Sequential(nn.Linear(64, 28 * 28))
def forward(self):
z = torch.rand(1, 64)
image = self.decoder(z)
image = image.view(1, 1, 28, 28)
return image
autoencoder = LitAutoEncoder()
image_sample = autoencoder()
# Option 3: Production
3,另一个例子
https://www.kaggle.com/somesh88/pytorch-lightning-submission-nbme/notebook
行动是治愈恐惧的良药,而犹豫拖延将不断滋养恐惧。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律