This post answers the most frequent question we get: why do you need Lightning if you’re already using PyTorch?
PyTorch makes it extremely easy to build complex AI models. But once the research gets complicated and things like multi-GPU training, 16-bit precision, and TPU training get mixed in, users are likely to introduce bugs.
PyTorch Lightning solves exactly this problem. Lightning structures your PyTorch code so it can abstract the details of training. This makes AI research scalable and fast to iterate on.
PyTorch Lightning was created for professional researchers and PhD students working on AI research.
Lightning was born out of my Ph.D. AI research at NYU CILVR and Facebook AI Research. As a result, the framework is designed to be extremely extensible while making state of the art AI research techniques (like TPU training) trivial.
Now the core contributors are all pushing the state of the art in AI using Lightning and continue to add new cool features.
At the same time, the simple interface gives professional production teams and newcomers access to the latest state-of-the-art techniques developed by the PyTorch and PyTorch Lightning community.
Lightning has over 320 contributors and a core team of 11 research scientists, PhD students, and professional deep learning engineers, and it is rigorously tested.
This tutorial will walk you through building a simple MNIST classifier showing PyTorch and PyTorch Lightning code side-by-side. While Lightning can build any arbitrarily complicated system, we use MNIST to illustrate how to refactor PyTorch code into PyTorch Lightning.
The full code is available at this Colab Notebook.
In a research project, we normally want to identify the following key components: the model, the data, the loss, and the optimizer.
Let’s design a 3-layer fully-connected neural network that takes as input an image that is 28x28 and outputs a probability distribution over 10 possible labels.
First, let’s define the model in PyTorch
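Here’s a minimal sketch of such a model (the hidden sizes of 128 and 256 are arbitrary choices, not prescribed anywhere):

    import torch
    from torch import nn
    from torch.nn import functional as F

    class MNISTClassifier(nn.Module):
        def __init__(self):
            super().__init__()
            # flatten 28x28 images to 784-dim vectors, then two hidden layers
            self.layer_1 = nn.Linear(28 * 28, 128)
            self.layer_2 = nn.Linear(128, 256)
            self.layer_3 = nn.Linear(256, 10)

        def forward(self, x):
            x = x.view(x.size(0), -1)        # (b, 1, 28, 28) -> (b, 784)
            x = F.relu(self.layer_1(x))
            x = F.relu(self.layer_2(x))
            x = self.layer_3(x)
            return F.log_softmax(x, dim=1)   # log-probabilities over the 10 digits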
This model defines the computational graph to take as input an MNIST image and convert it to a probability distribution over 10 classes for digits 0–9.
To convert this model to PyTorch Lightning we simply replace the nn.Module with the pl.LightningModule
The new PyTorch Lightning class is EXACTLY the same as the PyTorch one, except that the LightningModule provides a structure for the research code.
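Concretely, the sketch above might become something like this (same layers, same forward pass, only the base class changes):

    import pytorch_lightning as pl
    from torch import nn
    from torch.nn import functional as F

    class LitMNISTClassifier(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer_1 = nn.Linear(28 * 28, 128)
            self.layer_2 = nn.Linear(128, 256)
            self.layer_3 = nn.Linear(256, 10)

        def forward(self, x):
            x = x.view(x.size(0), -1)
            x = F.relu(self.layer_1(x))
            x = F.relu(self.layer_2(x))
            return F.log_softmax(self.layer_3(x), dim=1)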
Lightning provides structure to PyTorch code
See? The code is EXACTLY the same for both!
This means you can use a LightningModule exactly as you would a PyTorch module, for example to make predictions.
Or use it as a pretrained model
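For example, using the hypothetical LitMNISTClassifier from the sketch above (the checkpoint path is made up):

    import torch

    model = LitMNISTClassifier()
    x = torch.randn(1, 1, 28, 28)   # a batch with one fake MNIST image
    log_probs = model(x)            # a plain forward pass, just like any nn.Module

    # or load saved weights and use it as a pretrained model
    pretrained = LitMNISTClassifier.load_from_checkpoint('path/to/checkpoint.ckpt')
    pretrained.freeze()
    log_probs = pretrained(x)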
For this tutorial we’re using MNIST.
Let’s generate three splits of MNIST, a training, validation and test split.
This, again, is the same code in PyTorch as it is in Lightning.
The dataset is wrapped in a DataLoader, which handles the loading, shuffling, and batching of the data.
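A sketch of what that looks like (the 55,000/5,000 split and the batch size of 64 are assumptions; the normalization constants are the standard MNIST ones):

    from torch.utils.data import DataLoader, random_split
    from torchvision import datasets, transforms

    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.1307,), (0.3081,))])

    # download the official training set and carve out a validation split
    mnist_full = datasets.MNIST('data', train=True, download=True, transform=transform)
    mnist_train, mnist_val = random_split(mnist_full, [55000, 5000])
    mnist_test = datasets.MNIST('data', train=False, download=True, transform=transform)

    train_loader = DataLoader(mnist_train, batch_size=64, shuffle=True)
    val_loader = DataLoader(mnist_val, batch_size=64)
    test_loader = DataLoader(mnist_test, batch_size=64)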
In short, data preparation has 4 steps: download the images, apply transforms, generate the training, validation, and test splits, and wrap each split in a DataLoader.
Again, the code is exactly the same except that we’ve organized the PyTorch code into 4 functions:
prepare_data
This function handles downloads and any data processing, and it makes sure that when you use multiple GPUs you don’t download the dataset multiple times or apply the same manipulations to the data twice.
This matters because each GPU will execute the same PyTorch code, which would otherwise cause duplication. Lightning makes sure the critical parts are called from ONLY one GPU.
train_dataloader, val_dataloader, test_dataloader
Each of these is responsible for returning the appropriate data split. Lightning structures it this way so that it is VERY clear HOW the data are being manipulated. If you ever read random GitHub code written in PyTorch, it’s nearly impossible to see how the data are manipulated.
Lightning even allows multiple dataloaders for testing or validating.
This code is organized under what we call a DataModule. Although this is 100% optional and Lightning can use DataLoaders directly, a DataModule makes your data reusable and easy to share.
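A rough sketch of such a DataModule, reusing the data code from above (the optional setup hook, which assigns the splits on every process, is an extra detail not among the four functions listed earlier):

    from torch.utils.data import DataLoader, random_split
    from torchvision import datasets, transforms
    import pytorch_lightning as pl

    class MNISTDataModule(pl.LightningDataModule):
        def __init__(self, data_dir='data', batch_size=64):
            super().__init__()
            self.data_dir = data_dir
            self.batch_size = batch_size
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))])

        def prepare_data(self):
            # download once, from a single process
            datasets.MNIST(self.data_dir, train=True, download=True)
            datasets.MNIST(self.data_dir, train=False, download=True)

        def setup(self, stage=None):
            # assign the splits (this runs on every process)
            mnist_full = datasets.MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
            self.mnist_test = datasets.MNIST(self.data_dir, train=False, transform=self.transform)

        def train_dataloader(self):
            return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

        def val_dataloader(self):
            return DataLoader(self.mnist_val, batch_size=self.batch_size)

        def test_dataloader(self):
            return DataLoader(self.mnist_test, batch_size=self.batch_size)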
Now we choose how we’re going to do the optimization. We’ll use Adam instead of SGD because it is a good default in most DL research.
Again, this is exactly the same in both, except in Lightning it is organized into the configure_optimizers function.
Lightning is extremely extensible. For instance, if you wanted to use multiple optimizers (e.g., for a GAN), you could just return both here.
You’ll also notice that in Lightning we pass in self.parameters() and not a model because the LightningModule IS the model.
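A minimal sketch (the learning rate of 1e-3 is just a common default, not prescribed here):

    from torch.optim import Adam

    class LitMNISTClassifier(pl.LightningModule):
        ...

        def configure_optimizers(self):
            # the LightningModule IS the model, so we optimize self.parameters()
            return Adam(self.parameters(), lr=1e-3)
            # for multiple optimizers (e.g. a GAN) you could return a list instead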
For n-way classification we want to compute the cross-entropy loss. Cross-entropy is the same as applying the negative log-likelihood loss to log_softmax outputs, which is what we’ll use instead.
Again… code is exactly the same!
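You can convince yourself of the equivalence with a quick check:

    import torch
    from torch.nn import functional as F

    logits = torch.randn(8, 10)             # fake model outputs for a batch of 8
    targets = torch.randint(0, 10, (8,))    # fake labels

    loss_a = F.cross_entropy(logits, targets)                   # cross-entropy on raw logits
    loss_b = F.nll_loss(F.log_softmax(logits, dim=1), targets)  # NLL on log_softmax outputs
    assert torch.allclose(loss_a, loss_b)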
We assembled all the key ingredients needed for training: the model, the data, the loss, and the optimizer.
Now we implement a full training routine which does the following:
- iterate over many epochs (an epoch is one full pass over the dataset)
- within each epoch, iterate over the dataset in mini-batches
- perform a forward pass to get the model outputs
- compute the loss
- clear the old gradients and compute new ones with a backward pass
- apply the optimizer step to update the weights
In both PyTorch and Lightning the pseudocode looks like this
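Roughly, in raw PyTorch (num_epochs and the model/loader/optimizer names refer to the earlier sketches and are assumptions):

    num_epochs = 5
    for epoch in range(num_epochs):
        # ---- training loop ----
        model.train()
        for x, y in train_loader:
            logits = model(x)                # forward pass
            loss = F.nll_loss(logits, y)     # compute the loss
            optimizer.zero_grad()            # clear the old gradients
            loss.backward()                  # backward pass (compute gradients)
            optimizer.step()                 # update the weights

        # ---- validation loop ----
        model.eval()
        with torch.no_grad():
            for x, y in val_loader:
                val_loss = F.nll_loss(model(x), y)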
This is where Lightning differs, though. In PyTorch, you write the for loop yourself, which means you have to remember to call the correct things in the right order; this leaves a lot of room for bugs.
Even if your model is simple, it won’t be once you start doing more advanced things like using multiple GPUs, gradient clipping, early stopping, checkpointing, TPU training, 16-bit precision, etc… Your code complexity will quickly explode.
Even if your model is simple, it won’t be once you start doing more advanced things
Here are the training and validation loops for both PyTorch and Lightning.
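The raw PyTorch side is essentially the loop sketched above; on the Lightning side, a minimal sketch moves the loop body into training_step and validation_step like this (the exact return conventions vary slightly between Lightning versions):

    class LitMNISTClassifier(pl.LightningModule):
        ...

        def training_step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)                 # forward pass
            loss = F.nll_loss(logits, y)     # compute the loss
            return loss                      # Lightning handles zero_grad/backward/step

        def validation_step(self, batch, batch_idx):
            x, y = batch
            return F.nll_loss(self(x), y)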
This is the beauty of Lightning. It abstracts the boilerplate but leaves everything else unchanged. This means you are STILL writing PyTorch, except your code has been structured nicely.
This increases readability which helps with reproducibility!
The trainer is how we abstract the boilerplate code.
Again, this is possible because ALL you had to do was organize your PyTorch code into a LightningModule
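A minimal sketch of putting it all together (the class names and max_epochs follow the earlier sketches and are assumptions):

    model = LitMNISTClassifier()
    dm = MNISTDataModule()

    trainer = pl.Trainer(max_epochs=5)
    trainer.fit(model, datamodule=dm)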
The full MNIST example written in PyTorch is as follows:
The Lightning version is EXACTLY the same, except that the core ingredients are organized by the LightningModule and the training loop is abstracted by the Trainer.
This version does not use the DataModule; instead, it keeps the dataloaders defined freely (outside a DataModule).
And here is the same code but the data has been grouped under the DataModule and made more reusable.
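For reference, here is a compact sketch of what that grouped version might look like (the MNISTDataModule is the one sketched earlier; the hyperparameters are assumptions):

    import torch
    from torch import nn
    from torch.nn import functional as F
    from torch.optim import Adam
    import pytorch_lightning as pl

    class LitMNISTClassifier(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer_1 = nn.Linear(28 * 28, 128)
            self.layer_2 = nn.Linear(128, 256)
            self.layer_3 = nn.Linear(256, 10)

        def forward(self, x):
            x = x.view(x.size(0), -1)
            x = F.relu(self.layer_1(x))
            x = F.relu(self.layer_2(x))
            return F.log_softmax(self.layer_3(x), dim=1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return F.nll_loss(self(x), y)

        def validation_step(self, batch, batch_idx):
            x, y = batch
            return F.nll_loss(self(x), y)

        def configure_optimizers(self):
            return Adam(self.parameters(), lr=1e-3)

    if __name__ == '__main__':
        # the data is grouped in the MNISTDataModule defined earlier
        trainer = pl.Trainer(max_epochs=5)
        trainer.fit(LitMNISTClassifier(), datamodule=MNISTDataModule())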
Let’s call out a few key points
In Lightning you also get a bunch of freebies, such as a sick progress bar.
You also get a beautiful weights summary,
TensorBoard logs (yup! you had to do nothing to get this),
and free checkpointing and early stopping.
All for free!
But Lightning is best known for its out-of-the-box goodies, such as TPU training.
In Lightning, you can train your model on CPUs, GPUs, Multiple GPUs, or TPUs without changing a single line of your PyTorch code.
You can also do 16-bit precision training
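In practice this comes down to Trainer flags, roughly like the following (exact flag names have shifted a bit between Lightning versions):

    import pytorch_lightning as pl

    trainer = pl.Trainer(gpus=1)                  # single GPU
    trainer = pl.Trainer(gpus=8)                  # multiple GPUs on one machine
    trainer = pl.Trainer(tpu_cores=8)             # TPUs
    trainer = pl.Trainer(gpus=1, precision=16)    # 16-bit precision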
Log using 5 other alternatives to Tensorboard
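For example, swapping in the Weights & Biases logger might look like this (the project name is made up):

    from pytorch_lightning.loggers import WandbLogger   # Comet, MLflow, Neptune, etc. also exist

    trainer = pl.Trainer(logger=WandbLogger(project='mnist'))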
We even have a built-in profiler that can tell you where the bottlenecks are in your training.
Turning the profiler flag on prints a summary report at the end of training.
Or you can get a more advanced, per-function report if you want.
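Roughly (the accepted values for the profiler flag vary by Lightning version):

    trainer = pl.Trainer(profiler=True)          # simple report printed at the end of training
    trainer = pl.Trainer(profiler='advanced')    # more detailed, per-function profiling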
We can also train on multiple GPUs at once without you doing any work (you still have to submit a SLURM job)
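For example, 32 GPUs across 4 nodes might look like this (the name of the distributed flag has changed across versions, e.g. distributed_backend vs. accelerator/strategy):

    trainer = pl.Trainer(gpus=8, num_nodes=4, distributed_backend='ddp')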
And there are about 40 other features it supports which you can read about in the documentation.
You’re probably wondering: how is it possible for Lightning to do all of this for you and yet still give you full control over everything?
Unlike Keras or other high-level frameworks, Lightning does not hide any of the necessary details. But if you do find the need to modify every aspect of training yourself, then you have two main options.
The first is extensibility by overriding hooks. Here’s a non-exhaustive list:
These overrides happen in the LightningModule
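For example, you might take control of the backward pass or hook into epoch boundaries like this (the exact hook signatures vary a bit by Lightning version):

    class LitMNISTClassifier(pl.LightningModule):
        ...

        def backward(self, loss, *args, **kwargs):
            # take full control of how the backward pass is done
            loss.backward()

        def on_epoch_start(self):
            # runs at the start of every epoch
            print('a new epoch is starting')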
A callback is a piece of code that you’d like to have executed at various parts of training. In Lightning, callbacks are reserved for non-essential code, such as logging or anything else not related to the research code. This keeps the research code super clean and organized.
Let’s say you wanted to print something or save something at various parts of training. Here’s what the callback would look like:
PyTorch Lightning Callback
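A minimal sketch (the class name and the prints are just for illustration):

    from pytorch_lightning.callbacks import Callback

    class PrintingCallback(Callback):
        def on_train_start(self, trainer, pl_module):
            print('Training is starting')

        def on_train_end(self, trainer, pl_module):
            print('Training is ending')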
Now you pass this into the Trainer, and this code will be called at the points in training that you hooked into.
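Using the callback class sketched above:

    trainer = pl.Trainer(callbacks=[PrintingCallback()])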
This paradigm keeps your research code organized into three different buckets: research code (the LightningModule), engineering code (the Trainer), and non-essential code (Callbacks).
Hopefully this guide showed you exactly how to get started. The easiest way is to run the Colab notebook with the MNIST example here.
Or install Lightning
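    pip install pytorch-lightning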
Or check out the GitHub page.