Lightning设计理念

Lightning使用以下原则构建PyTorch代码:

Lightning强制您的代码采用以下结构,从而使其可重用和可共享:

  • 研究代码(LightningModule)
  • 工程代码(您删除,并由培训师处理)
  • 非必要的研究代码(日志记录等,这在回调中使用)
  • 数据(使用PyTorch DataLoaders或将其组织到LightningDataModule中)

一旦这样做了,您就可以在多GPU、TPU、CPU上进行培训,甚至可以在不更改代码的情况下进行16位精度的培训!

开始使用我们的2 step guide


持续集成

Lightning在多个GPU、TPU CPU上进行了严格的测试,并针对主要的Python和PyTorch版本进行了测试

当前生成状态
系统/火炬版本 1.4(最低请求。) 1.5 1.6 1.7 1.8(LTS) 1.9(最新)
Conda py3.7[Linux] PyTorch & Conda PyTorch & Conda PyTorch & Conda PyTorch & Conda PyTorch & Conda PyTorch & Conda
Linux py3.7[GPU**] Build Status
Linux py3.{6,7}[TPU*] TPU tests TPU tests
Linux py3.{6,7,8,9} CI complete testing CI complete testing
OSX py3.{6,7,8,9} CI complete testing CI complete testing
Windows py3.{6,7,8,9} CI complete testing CI complete testing
  • **测试在两个NVIDIA P100上运行
  • *测试在Google GKE TPUv2/3上运行
  • TPU py3.7意味着我们支持Colab和Kaggle环境

如何使用

步骤0:安装

从PyPI轻松安装

pip install pytorch-lightning
其他安装选项

使用可选依赖项安装

pip install pytorch-lightning['extra']

孔达

conda install pytorch-lightning -c conda-forge

安装稳定版1.3.x

1.3[稳定]的实际状态如下:

CI base testing
CI complete testing
PyTorch & Conda
TPU tests
Docs check

从源安装未来版本

pip install git+https://github.com/PytorchLightning/pytorch-lightning.git@release/1.3.x --upgrade

安装尖端技术-未来1.4

夜间从源安装(不保证)

pip install https://github.com/PyTorchLightning/pytorch-lightning/archive/master.zip

或通过测试PyPI

pip install -iU https://test.pypi.org/simple/ pytorch-lightning

步骤1:添加这些导入

import os import torch from torch import nn import torch.nn.functional as F from torchvision.datasets import MNIST from torch.utils.data import DataLoader, random_split from torchvision import transforms import pytorch_lightning as pl

步骤2:定义LightningModule(nn.Module子类)

LightningModule定义完整的系统(即:GAN、自动编码器、BERT或简单图像分类器)

class LitAutoEncoder(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions embedding = self.encoder(x)
        return embedding def training_step(self, batch, batch_idx):
        # training_step defines the train loop. It is independent of forward x, y = batch x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log('train_loss', loss)
        return loss def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

注:Training_Step定义训练循环。转发定义了LightningModule在推理/预测期间的行为方式

第三步:训练!

dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train, val = random_split(dataset, [55000, 5000])

autoencoder = LitAutoEncoder()
trainer = pl.Trainer()
trainer.fit(autoencoder, DataLoader(train), DataLoader(val))

高级功能

闪电已经过去了40+ advanced features专为专业的大规模人工智能研究而设计

以下是一些示例:

突出显示的功能代码片段
# 8 GPUs # no code changes needed trainer = Trainer(max_epochs=1, gpus=8)

# 256 GPUs trainer = Trainer(max_epochs=1, gpus=8, num_nodes=32)
在不更改代码的TPU上进行培训
# no code changes needed trainer = Trainer(tpu_cores=8)
16位精度
# no code changes needed trainer = Trainer(precision=16)
实验管理员
from pytorch_lightning import loggers # tensorboard trainer = Trainer(logger=TensorBoardLogger('logs/'))

# weights and biases trainer = Trainer(logger=loggers.WandbLogger())

# comet trainer = Trainer(logger=loggers.CometLogger())

# mlflow trainer = Trainer(logger=loggers.MLFlowLogger())

# neptune trainer = Trainer(logger=loggers.NeptuneLogger())

# ... and dozens more
提前停止
es = EarlyStopping(monitor='val_loss')
trainer = Trainer(callbacks=[es])
检查点设置
checkpointing = ModelCheckpoint(monitor='val_loss')
trainer = Trainer(callbacks=[checkpointing])
导出为Torchscript(JIT)(生产用途)
# torchscript autoencoder = LitAutoEncoder()
torch.jit.save(autoencoder.to_torchscript(), "model.pt")
导出到ONNX(生产用途)
# onnx with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmpfile:
    autoencoder = LitAutoEncoder()
    input_sample = torch.randn((1, 64))
    autoencoder.to_onnx(tmpfile.name, input_sample, export_params=True)
    os.path.isfile(tmpfile.name)

培训回路的高级控制(高级用户)

对于复杂/专业级别的工作,您可以选择完全控制培训循环和优化器

class LitAutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False def training_step(self, batch, batch_idx):
        # access your optimizers with use_pl_optimizer=False. Default is True opt_a, opt_b = self.optimizers(use_pl_optimizer=True)

        loss_a = ...
        self.manual_backward(loss_a, opt_a)
        opt_a.step()
        opt_a.zero_grad()

        loss_b = ...
        self.manual_backward(loss_b, opt_b, retain_graph=True)
        self.manual_backward(loss_b, opt_b)
        opt_b.step()
        opt_b.zero_grad()

相对于非结构化PyTorch的优势

  • 型号变得与硬件无关
  • 代码清晰易读,因为工程代码是抽象出来的
  • 更容易复制
  • 犯的错误更少,因为闪电处理了棘手的工程
  • 保持所有的灵活性(LightningModules仍然是PyTorch模块),但删除了大量的样板文件
  • Lightning与流行的机器学习工具进行了数十次集成
  • Tested rigorously with every new PR我们测试PyTorch和Python支持的版本、每个操作系统、多个GPU甚至TPU的每种组合
  • 最小运行速度开销(与纯PyTorch相比,每历元约300毫秒)

示例

你好,世界
对比学习
NLP
强化学习
愿景
经典ML

社区

闪电社区由

  • 10+ core contributors他们都是来自顶尖人工智能实验室的专业工程师、研究科学家和博士生的混合体
  • 480+活跃的社区贡献者

想要帮助我们构建Lightning并为数千名研究人员减少样板吗?Learn how to make your first contribution here

闪电也是PyTorch ecosystem这要求项目有可靠的测试、文档和支持

寻求帮助

如果您有任何问题,请:

  1. Read the docs
  2. Search through existing Discussions,或add a new question
  3. Join our slack

资金来源

We’re venture funded为确保我们能够提供全天候支持,请聘请全职员工,参加会议,并通过实施您要求的功能加快行动速度


网格AI

网格AI是我们在云上大规模训练模型的平台!

注册我们的免费社区层here

要使用GRID,请使用您的常规命令:

python my_model.py --learning_rate 1e-6 --layers 2 --gpus 4

并将其更改为使用GRID TRAIN命令:

grid train --grid_gpus 4 my_model.py --learning_rate 'uniform(1e-6, 1e-1, 20)' --layers '[2, 4, 8, 16]'

上面的命令将启动(20*4)个实验,每个实验在4个GPU(320个GPU!)上运行-不对代码进行任何更改


牌照

请遵守此存储库中列出的Apache 2.0许可证。此外,Lightning框架正在申请专利

BibTeX

如果您想引用该框架,请随意使用(但只有在您喜欢它的情况下😊)或zenodo

@article{falcon2019pytorch,
  title={PyTorch Lightning},
  author={Falcon, WA, et al.},
  journal={GitHub. Note: https://github.com/PyTorchLightning/pytorch-lightning},
  volume={3},
  year={2019}
}
声明:本站所有文章,如无特殊说明或标注,均为本站原创发布。任何个人或组织,在未征得本站同意时,禁止复制、盗用、采集、发布本站内容到任何网站、书籍等各类媒体平台。如若本站内容侵犯了原著者的合法权益,可联系我们进行处理。