September 24, 2020

Finding a good learning rate for your neural nets using PyTorch Lightning

mtszkw

Among all hyperparameters used in machine learning, the learning rate is probably the very first one you hear about. It may also be the one you start tuning first. You can find the right value with a bit of hyperparameter optimization, running tons of training sessions, or you can let tools do it much faster. Nowadays, many libraries implement an LR Finder or “LR Range Test”.

Why do we care about learning rate?

The shortest explanation of the learning rate is that it controls how fast your network learns. If you recall how supervised learning works, you can imagine a neural network adapting to the problem based on the supervisor’s response (e.g. wrong class, output value too low/high) so that it can give a more accurate answer next time. The learning rate then controls how much to change the model in response to recent errors.
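To make that concrete, here is a minimal sketch (plain PyTorch, not taken from the post’s project) of a single gradient descent step, where lr scales how far the weights move against the gradient:

import torch

# One hand-rolled gradient descent step: the learning rate scales the weight update.
w = torch.randn(10, requires_grad=True)    # model parameters
x, y = torch.randn(10), torch.tensor(1.0)  # a single training example
lr = 1e-2                                  # learning rate

loss = (w @ x - y) ** 2                    # squared error for this sample
loss.backward()                            # gradient of the loss w.r.t. w

with torch.no_grad():
    w -= lr * w.grad   # small lr: tiny, cautious step; large lr: huge, possibly unstable step
    w.grad.zero_()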

If there is a “good” learning rate, what does a “bad” learning rate mean then? If you understand the concept, you may imagine that a smaller LR value makes your model adapt more slowly. It is a slow learner and needs more iterations to give accurate answers; sometimes you may decide to stop training before it ever produces the right output.

You can also pick a value that is too large. What then? It may cause the neural network to change its mind too quickly (and too often): every new sample will have a huge impact on the network’s beliefs, and training becomes highly unstable. It is no longer a slow learner, but it may be even worse off: the model may end up not learning anything useful at all.

Finding optimal LR “automatically”

There is an important (and yet relatively simple) paper by Leslie N. Smith that everybody mentions in the context of finding an optimal learning rate. The main purpose of that paper is to introduce “cyclical learning rates” for neural networks, but after reading it you will also understand how to find a good learning rate (or a range of good learning rates) for training.

So here is how we search for an optimal LR: we run a short pre-training in which the learning rate is increased (linearly or exponentially) between two boundaries, min_lr and max_lr. At the beginning, with a small learning rate, the network slowly starts to converge, which results in loss values getting lower and lower. At some point the learning rate gets too large and causes the network to diverge.
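A minimal sketch of such a range test (assuming a hypothetical model, train_loader and loss_fn already exist, and using an exponential schedule between min_lr and max_lr) could look like this:

import torch

# Hypothetical model, train_loader and loss_fn are assumed to be defined elsewhere.
min_lr, max_lr, num_steps = 1e-6, 1e-1, 100
optimizer = torch.optim.Adam(model.parameters(), lr=min_lr)
gamma = (max_lr / min_lr) ** (1 / num_steps)   # exponential growth factor per step

lrs, losses = [], []
for step, (x, y) in zip(range(num_steps), train_loader):
    lr = min_lr * gamma ** step
    for group in optimizer.param_groups:
        group["lr"] = lr                       # raise the learning rate every batch

    loss = loss_fn(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    lrs.append(lr)
    losses.append(loss.item())
    if losses[-1] > 4 * min(losses):           # stop once the loss clearly diverges
        break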

Figure 1. Learning rate suggested by the lr_find method

Then, if you plot the loss metric vs. the tested learning rate values (Figure 1), you usually find the best learning rates somewhere around the middle of the steepest descending part of the loss curve. In Figure 1 the loss starts decreasing significantly between LR 10⁻³ and 10⁻¹, and the red dot indicates the optimal value chosen by PyTorch Lightning.
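The suggestion is typically the point where the (smoothed) loss drops fastest. A rough sketch of that rule, reusing the lrs and losses lists from the range-test sketch above, might be:

import numpy as np

# Pick the LR with the steepest downward slope, i.e. the most negative loss gradient.
lrs, losses = np.array(lrs), np.array(losses)
suggested_lr = lrs[np.argmin(np.gradient(losses))]
print(f"Suggested learning rate: {suggested_lr:.2e}")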

Finding optimal LR in PyTorch Lightning

Recently, PyTorch Lightning became my tool of choice for short machine learning projects. I used it for the first time a couple of months ago and I have kept using it ever since. Apart from all the other cool stuff it has, it also provides a Learning Rate Finder class that will help us find a good learning rate.

When you build a model with Lightning, the easiest way to enable the LR Finder is shown below:

from torch.optim import Adam
from pytorch_lightning import LightningModule, Trainer

class LitModel(LightningModule):

   def __init__(self, learning_rate):
       super().__init__()
       self.learning_rate = learning_rate  # attribute the LR finder will overwrite

   def configure_optimizers(self):
       return Adam(self.parameters(), lr=self.learning_rate)

trainer = Trainer(auto_lr_find=True) # by default it's False

Now when you call the trainer.fit method, it performs the LR range test, finds a good initial learning rate and then actually trains (fits) your model. So basically it all happens automatically within the fit call and you have absolutely nothing to worry about.
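For completeness, a minimal usage sketch might look like the following (train_loader is a hypothetical DataLoader; depending on your Lightning version, the search runs inside trainer.fit itself or needs an explicit trainer.tune call first):

model = LitModel(learning_rate=1e-3)   # initial guess, may be overwritten by the finder
trainer = Trainer(auto_lr_find=True, max_epochs=10)

trainer.tune(model, train_loader)      # runs the LR range test (newer versions)
trainer.fit(model, train_loader)       # trains with the discovered learning rate
print(model.learning_rate)             # the value picked by the finder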

As stated in the documentation, there is another approach that allows you to execute the LR finder manually and inspect its results. This time you create the Trainer object with the default value of auto_lr_find (False) and call the lr_find method yourself:

trainer = Trainer()  # auto_lr_find left at its default (False)

lr_finder = trainer.tuner.lr_find(model)  # Run learning rate finder

fig = lr_finder.plot(suggest=True)  # Plot loss vs. learning rate, marking the suggested point
fig.show()

model.hparams.lr = lr_finder.suggestion()  # Store the suggested value in the model

trainer.fit(model)  # Fit model using the new learning rate

And that’s it. The result should be the same; the main advantage of this approach, however, is that you can take a closer look at the lr_finder.plot output, which shows exactly which value was chosen. I used this method in my toy project to compare how the LR Finder can help me come up with a better model. In that project I used a simple LeNet to classify Fashion MNIST images. Take a look.

Example: LR find for Fashion MNIST classification

Basically, I wanted to train a fairly simple convolutional neural network (LeNet) on an uncomplicated dataset (Fashion MNIST). I ran four separate experiments that differed only in the initial learning rate: 10⁻⁵, 10⁻⁴, 10⁻¹ and one selected by the Learning Rate Finder (conceptually the setup boils down to the loop sketched below). I won’t describe the whole implementation and the other parameters, as you can read it all by yourself here. Let me just show you the findings of the lr_find method.
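A hypothetical sketch of those four runs (fashion_mnist_dm stands for a Fashion MNIST LightningDataModule that is not shown here) could look like this:

for lr in [1e-5, 1e-4, 1e-1, None]:                        # None -> let the LR finder decide
    model = LitModel(learning_rate=lr if lr is not None else 1e-3)
    trainer = Trainer(auto_lr_find=(lr is None), max_epochs=10)
    if lr is None:
        trainer.tune(model, datamodule=fashion_mnist_dm)   # runs lr_find and sets the LR
    trainer.fit(model, datamodule=fashion_mnist_dm)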

It took around 12 seconds to find the best initial learning rate, which turned out to be 0.0363.

Looking at the loss/LR plot (Figure 1), I was surprised, because the suggested point is not exactly “halfway down the steepest descending slope”. However, I couldn’t tell whether that was good or bad until I trained the model. For logging and visualization I used TensorBoard to record loss and accuracy during the training and validation steps (a minimal sketch of that setup follows). Below it you can see the metrics history for each of the four experiments.
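Wiring TensorBoard into Lightning takes only a few lines; this sketch uses hypothetical metric names, while TensorBoardLogger and self.log are the standard Lightning APIs:

from pytorch_lightning.loggers import TensorBoardLogger

# Inside the LightningModule, e.g. in validation_step:
#     self.log("val_loss", loss)
#     self.log("val_acc", acc)

logger = TensorBoardLogger("tb_logs", name="lr_finder_experiment")
trainer = Trainer(logger=logger, max_epochs=10)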

Figure 2. Training and validation accuracy for the 4 experiments

The learning rate suggested by Lightning (light blue) seems to outperform the other values in both training and validation. In the end it reached 88.85% accuracy on the validation set, which is the highest score of all the experiments (Figure 2). The loss values were also the best for the “find_lr” experiment: in the last validation step it reached a loss of 0.3091, the lowest value among all the curves (Figure 3).

Figure 3. Training and validation loss for the 4 experiments

Conclusion

In this case, the Learning Rate Finder outperformed my own choices of learning rate. Of course, I could have picked 0.0363 as my initial guess, but the whole point of the LR Finder is to minimize guesswork. Using the Learning Rate Finder doesn’t require much additional code; in PyTorch Lightning you can enable it with just one flag.

I think this feature is genuinely useful, as Leslie N. Smith writes in his publication:

Whenever one is starting with a new architecture or dataset, a single LR range test provides both a good LR value and a good range. Then one should compare runs with a fixed LR versus CLR with this range. Whichever wins can be used with confidence for the rest of one’s experiments.

If you don’t want to perform a hyperparameter search over different LR values, which can take ages, you have two options left: pick an initial LR value at random (which may leave you with terribly bad performance and convergence) or use a learning rate finder included in your machine learning framework of choice.

Which one would you pick?