Imagine that one day you have an amazing idea for your machine learning project. You write down all the details on a piece of paper: the model architecture, the optimizer, the dataset. Now you just have to code it up and do some hyperparameter tuning to put it into practice.
So you fire up your machine and start coding. But then it hits you: you need to go through the grunt work of creating batches out of the data, writing loops to iterate over batches and epochs, debugging any issues that arise along the way, repeating the same for the validation set, and the list goes on. The project turns into a headache before it has even started.
But not anymore. PyTorch Lightning is here to save the day. Not only does it automatically do the hard work for you, but it also structures your code to make it more scalable. It comes packed with features that will enhance your machine learning experience. Beginners should definitely give it a go.
Throughout this blog, we will learn how Lightning can be used along with PyTorch to make development easy and reproducible.
With this blog post, I aim to help people get to know PyTorch Lightning. From now on I will be referring to PyTorch Lightning as Lightning.
I will begin with a brief introduction to the new library and its underlying principles so that you can build research-friendly neural network models from scratch.
This tutorial assumes prior knowledge of how a neural network works, as well as some familiarity with the PyTorch framework. Even if you are not familiar with PyTorch, you will be fine. For PyTorch users, this tutorial may serve as encouragement to include Lightning in their PyTorch code.
Let us start with a basic introduction.
PyTorch is an open-source machine learning library based on the Torch library. PyTorch is imperative, which means computations run immediately: the user need not wait to write the full code before checking whether it works. We can efficiently run a part of the code and inspect it in real time. The library is Python-based and built to provide flexibility as a deep learning development platform.
PyTorch is extremely “pythonic” in nature. It can be thought of as a NumPy substitute that takes advantage of the computational power of GPUs.
PyTorch supports dynamic computational graphs, which allow us to change the network on the fly.
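As a tiny illustration (the function and sizes here are mine, not from any library), the graph is rebuilt on every call, so ordinary Python control flow can change the computation at run time:

import torch

def forward(x, use_extra_layer):
    # the graph is built as this code executes, so this branch
    # can differ from one forward pass to the next
    x = torch.relu(x @ torch.randn(4, 4))
    if use_extra_layer:
        x = torch.relu(x @ torch.randn(4, 4))
    return x

x = torch.randn(2, 4)
print(forward(x, True).shape)   # torch.Size([2, 4])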
PyTorch is an excellent framework, great for researchers. But after a certain point, it involves more engineering than researching.
As I mentioned in the introduction, the hard work starts taking over the research work. The focus shifts from training and tuning the model to correctly implementing features such as creating batches, writing training and validation loops, saving the model, and logging.
Even though these may be simple to implement, we would still end up losing precious time, and we risk making a mistake while coding them up, leading to more time wasted in debugging.
Consider an example. We are training a model, and we want it to stop after 100 epochs and save the trained model to a .pth file. But we made a mistake in the model-saving code. The thing about Python is that it does not report such an error until it runs into one. So after 10 hours of training, we hit the error, and our model did not save. Just like that, the 10 hours go down the drain. How frustrating would that be?
Lightning is a very lightweight wrapper on PyTorch. This means you don’t have to learn a new library. It defers the core training and validation logic to you and automates the rest. It guarantees tested and correct code with the best modern practices for the automated parts.
So we can actually save those 10 hours by carefully organizing our code in Lightning modules.
As the name suggests, Lightning is closely related to PyTorch: not only do they share their roots at Facebook, but Lightning is also a wrapper around PyTorch itself. In fact, the core foundation of PyTorch Lightning is built upon PyTorch.
In its true sense, Lightning is a structuring tool for your PyTorch code. You just have to provide the bare minimum of details (e.g., number of epochs, optimizer); the rest is automated by Lightning.
By using Lightning, you make sure that all the tricky pieces of code work for you and you can focus on the real research:
Lightning ensures that when your network becomes complex, your code doesn’t.
It ensures that you focus on the real deal and don’t worry about how to run your model on multiple GPUs or how to speed up the code; Lightning will handle that for you.
But what does this mean for you? It means that this framework is designed to be extremely extensible while making state-of-the-art AI research techniques (like multi-GPU training) trivial.
I will show you exactly how to build an MNIST classifier using Lightning: a very small network that reaches 99.4% accuracy on the MNIST validation set with <8k trainable parameters. I re-implemented the code using PyTorch Lightning and added my own intuitions and explanations.
We shall do this as quickly as possible so that we can move on to the even more interesting details of Lightning.
The basic and essential chunks of a neural network in Lightning fall into two categories: restructuring and abstraction.
Restructuring refers to keeping code in its respective place in the LightningModule. The code is simply arranged into the LightningModule’s methods, known as callbacks. These have a special meaning to Lightning because they tell it what each piece of code is for.
Note that the PyTorch code itself does not change during this restructuring.
The boilerplate code is abstracted by the Lightning trainer. It automates most of the code for us.
Now there is no need to write separate code for saving your model or iterating over batches; it is abstracted into the Trainer.
Lightning provides us with the following methods of its pl.LightningModule class, which help in structuring the code and which Lightning refers to as callbacks:
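As a rough sketch of how these methods fit together (the method names are Lightning’s; the class name and the ellipses are placeholders):

class LitModel(pl.LightningModule):
    def __init__(self): ...                         # define the layers
    def forward(self, x): ...                       # the forward pass, as in nn.Module
    def prepare_data(self): ...                     # one-time work such as downloads
    def train_dataloader(self): ...                 # return the training DataLoader
    def val_dataloader(self): ...                   # OPTIONAL: validation DataLoader
    def test_dataloader(self): ...                  # OPTIONAL: test DataLoader
    def training_step(self, batch, batch_idx): ...  # compute and return the loss
    def configure_optimizers(self): ...             # return the optimizer(s)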
Now let’s dive right into coding so that we can get hands-on experience with Lightning.
Run the following to install Lightning on Google Colab
!pip install pytorch_lightning
You will have to restart the runtime for the changes to take effect.
Do not forget to select the GPU: in your Google Colab notebook, go to Edit -> Notebook Settings -> Hardware Accelerator and select GPU.
import os  # needed for os.getcwd() in prepare_data below

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
We will define our own class called smallAndSmartModel, inheriting from Lightning’s pl.LightningModule.
Let’s start building the model
class smallAndSmartModel(pl.LightningModule):
    def __init__(self):
        super(smallAndSmartModel, self).__init__()
        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(1, 28, kernel_size=5),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2))
        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(28, 10, kernel_size=2),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2))
        self.dropout1 = torch.nn.Dropout(0.25)
        self.fc1 = torch.nn.Linear(250, 18)
        self.dropout2 = torch.nn.Dropout(0.08)
        self.fc2 = torch.nn.Linear(18, 10)
class smallAndSmartModel(pl.LightningModule):

    # Manipulation of the data that needs to happen only once,
    # such as downloading it
    def prepare_data(self):
        MNIST(os.getcwd(), train=True, download=True)
        MNIST(os.getcwd(), train=False, download=True)

    def train_dataloader(self):
        # This is an essential function and needs to be included in the code.
        # download is set to False here because the data was already
        # downloaded in prepare_data
        mnist_train = MNIST(os.getcwd(), train=True, download=False,
                            transform=transforms.ToTensor())

        # Dividing into training and validation sets
        self.train_set, self.val_set = random_split(mnist_train, [55000, 5000])

        return DataLoader(self.train_set, batch_size=128)

    def val_dataloader(self):
        # OPTIONAL
        return DataLoader(self.val_set, batch_size=128)

    def test_dataloader(self):
        # OPTIONAL
        return DataLoader(MNIST(os.getcwd(), train=False, download=False,
                                transform=transforms.ToTensor()),
                          batch_size=128)
The train_dataloader, test_dataloader and val_dataloader are reserved functions in pl.LightningModule. We use them as wrappers for loading our data.
It is necessary to put the code in these functions because they have a special meaning to Lightning, just like forward has in nn.Module.
Each of these is responsible for returning the appropriate data split, and Lightning structures them so that it is very clear how the data is being manipulated. When you read code that isn’t structured like this (as is the case in many GitHub repositories), it can be hard to figure out how the data was manipulated.
Lightning even allows multiple data loaders for testing or validating.
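For instance, a sketch (the two split attributes here are hypothetical): returning a list makes Lightning pass a dataloader_idx into the corresponding step so you can tell the loaders apart.

class smallAndSmartModel(pl.LightningModule):
    def val_dataloader(self):
        # returning a list gives Lightning several validation loaders;
        # validation_step then receives an extra dataloader_idx argument
        return [DataLoader(self.val_set_easy, batch_size=128),   # hypothetical split
                DataLoader(self.val_set_hard, batch_size=128)]   # hypothetical split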
class smallAndSmartModel(pl.LightningModule):
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.dropout1(x)
        x = torch.relu(self.fc1(x.view(x.size(0), -1)))
        x = F.leaky_relu(self.dropout2(x))

        # log_softmax rather than softmax, because the training step
        # uses F.nll_loss, which expects log-probabilities
        return F.log_softmax(self.fc2(x), dim=1)
This is the forward pass, where the computation takes place and we generate the output values from the input data. PyTorch users will notice that the implementation itself is unchanged.
class smallAndSmartModel(pl.LightningModule):
    def configure_optimizers(self):
        # Essential function
        # We are using the Adam optimizer for our model
        return torch.optim.Adam(self.parameters())
This required function returns the optimizer we want to use. Interestingly, the configure_optimizers wrapper even allows us to return multiple optimizers with ease (for example, in GANs).
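For example, a sketch of what that could look like for a GAN (self.generator and self.discriminator are hypothetical submodules):

class myGAN(pl.LightningModule):
    def configure_optimizers(self):
        # one optimizer per sub-network; Lightning then passes an
        # optimizer_idx into training_step so we know which one is active
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=2e-4)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=2e-4)
        return [opt_g, opt_d]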
class smallAndSmartModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):

        # extracting input and output from the batch
        x, labels = batch

        # doing a forward pass
        pred = self.forward(x)

        # calculating the loss
        loss = F.nll_loss(pred, labels)

        # logs
        logs = {"train_loss": loss}

        output = {
            # REQUIRED: we must return "loss"
            "loss": loss,
            # optional, for logging purposes
            "log": logs
        }

        return output
This step is called for every batch in our dataset. The key operations in this function are extracting the inputs and labels from the batch, doing a forward pass, calculating the loss, and logging it.
It is essential for training_step to return a dictionary containing the loss; any other data returned is optional.
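Since we defined val_dataloader earlier, we would typically also add a validation_step. A minimal sketch that mirrors the training step (the exact return and aggregation conventions depend on your Lightning version):

class smallAndSmartModel(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        # mirrors training_step, but runs on the validation loader
        x, labels = batch
        pred = self.forward(x)
        loss = F.nll_loss(pred, labels)
        return {"val_loss": loss}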
Obviously, there is no magic. But once I tell you what the Lightning Trainer is capable of, you won’t refrain from claiming that it is indeed charming and exquisite.
# abstracts the training, validation and test loops

# using the one GPU Google Colab gives us, for at most 100 epochs
# (older Lightning versions called this flag max_nb_epochs)
myTrainer = pl.Trainer(gpus=1, max_epochs=100)

model = smallAndSmartModel()
myTrainer.fit(model)
The Trainer is the heart of PyTorch Lightning. This is where all the abstraction takes place: it takes over the most obvious pieces of code, such as the training, validation and test loops, iterating over batches, and saving the model.
Now you don’t have to worry about engineering these steps; the Trainer does that for you. You just have to make sure that your code is well structured, as explained in the sections above.
The Trainer provides some very helpful flags. We can assign values to these flags to configure our classifier’s behavior, as sketched below.
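For instance, here is a sketch of a few commonly used flags (exact flag availability varies a little between Lightning versions):

myTrainer = pl.Trainer(
    gpus=1,                     # train on one GPU
    max_epochs=100,             # stop after 100 epochs
    precision=16,               # enable 16-bit precision
    fast_dev_run=False,         # set True to smoke-test a single batch first
    check_val_every_n_epoch=1,  # run validation after every epoch
)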
By using the Trainer, you automatically get tools and features such as TensorBoard logging, model checkpointing, 16-bit precision, and multi-GPU training.
That’s the question you should be asking me after everything I have told you about PyTorch Lightning. I will answer it by letting you in on my love for Lightning.
When I look at how the code is structured in Lightning, it feels almost natural and intuitive to put each piece where it goes. The structure gives me a step-by-step strategy for developing my classifier from scratch, and that makes me more confident in developing my models.
The steps to build a machine learning solution are now very simple and intuitive. To come up with a solution using Lightning, I know I need to proceed by preparing the data, adding the optimizers, adding the training step, and so on. This helps me move along with the flow of ideas in my mind.
The best thing about Lightning is that each process is separated from the others in the LightningModule. That’s the benefit of structuring.
training_step contains information about the training step only: not about the validation step, and not about the optimizer. That makes things much clearer for me.
Since Lightning is a wrapper around PyTorch, I did not have to learn a new framework. And if I want to write very complex training steps, I can easily do that without compromising the flexibility of PyTorch.
Those who are familiar with PyTorch will find the transition to be extremely smooth.
The Trainer just wins it all. It automates most of the complex tasks for me.
In the case of GPUs, I don’t have to worry about moving my tensors with tensor.to(device). Lightning figures out the details automatically; I just have to set a few flags. With those, I can even enable 16-bit precision, automatic cluster saving, the automatic learning-rate finder, TensorBoard visualization, and more.
By using the Trainer, I’m not only getting some very neat algorithms but I am also getting the guarantee that they will work correctly. Now that’s one less thing for me to worry about. And I can focus on my real research.
My personal favorites are TensorBoard logging and resuming training from where I left off.
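Both are only a flag away. A sketch (the checkpoint path is hypothetical, and newer Lightning versions pass ckpt_path to trainer.fit instead of this flag):

# Lightning's default logger writes to ./lightning_logs;
# view the curves with: tensorboard --logdir lightning_logs
# To resume, point the Trainer at a saved checkpoint:
myTrainer = pl.Trainer(
    gpus=1,
    max_epochs=100,
    resume_from_checkpoint="lightning_logs/version_0/checkpoints/last.ckpt",  # hypothetical path
)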
Lightning is best for scholars and researchers who are working on developing the best strategies to tackle a problem. It takes the unnecessary engineering away from them and provides a clean environment to perform the relevant research.
I also believe that early PyTorch users should start using Lightning so that their thinking process becomes structured and more intuitive. They may also find it amazing to have so many perks at their disposal, ready to be used.
Now that you are acquainted with PyTorch Lightning, I hope you will start using Lightning (especially if you are a researcher) and fall in love with its amazing features.
That’s all from me. If you liked my little introduction to Lightning, do share your feedback.
Keep learning and have fun!!