BERT (Devlin et al., 2018) is perhaps the most popular NLP approach to transfer learning. The implementation by Huggingface offers a lot of nice features and abstracts away details behind a beautiful API.
PyTorch Lightning is a lightweight framework (really more like refactoring your PyTorch code) that lets anyone using PyTorch, such as students, researchers, and production teams, scale deep learning code easily while keeping it reproducible. It also provides 42+ advanced research features via trainer flags.
Lightning does not add abstractions on top of PyTorch, which means it plays nicely with other great packages like Huggingface! In this tutorial we'll use their implementation of BERT to do a finetuning task in Lightning.
In this tutorial we’ll do transfer learning for NLP in 3 steps:
If you'd rather see this in actual code, copy this Colab notebook!
If you’re a researcher trying to improve on the NYU GLUE benchmark, or a data scientist trying to understand product reviews to recommend new content, you’re looking for a way to extract a representation of a piece of text so you can solve a different task.
For transfer learning you generally have two steps. First, you use dataset A to pretrain your model. Then you use that pretrained model to carry that knowledge into solving dataset B. In this case, BERT has been pretrained on BookCorpus and English Wikipedia [1]. The downstream task is what you actually care about, such as solving a GLUE task or classifying product reviews.
The benefit of pretraining is that we don't need much labeled data in the downstream task to get strong results.
In general, we can finetune with PyTorch Lightning using the following abstract approach:
For transfer learning we define two core parts inside the LightningModule: the pretrained model and the finetune model.
You can think of the pretrained model as a feature extractor. It lets you represent objects or inputs much more richly than, say, a boolean flag or a hand-built tabular mapping.
For instance, if you have a collection of documents, you could run each one through the pretrained model and use the output vectors to compare the documents to each other.
The finetune model can be arbitrarily complex. It could be a deep network, or it could be a simple linear model or an SVM.
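To make the document-comparison idea concrete, here is a sketch that embeds pre-tokenized documents with BERT and compares them by cosine similarity. It assumes the Huggingface `transformers` package. Several details are illustrative assumptions: the tiny randomly initialized config, the token ids, and the choice of mean pooling (rather than, say, the [CLS] vector). Because the weights are random, the similarity values themselves are meaningless here.

```python
import torch
import torch.nn.functional as F
from transformers import BertConfig, BertModel

# Tiny, randomly initialized BERT so the sketch runs offline. In practice you
# would use something like BertModel.from_pretrained("bert-base-cased").
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64)
bert = BertModel(config).eval()


def embed(input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Mean-pool the last hidden states over non-padding tokens."""
    with torch.no_grad():
        hidden = bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
    mask = attention_mask.unsqueeze(-1).float()          # (batch, seq, 1)
    return (hidden * mask).sum(dim=1) / mask.sum(dim=1)  # (batch, hidden)


# Three hypothetical, already-tokenized "documents" (0 = padding).
input_ids = torch.tensor([[1, 5, 7, 9, 2],
                          [1, 5, 7, 2, 0],
                          [1, 42, 17, 33, 2]])
attention_mask = (input_ids != 0).long()

doc_vectors = embed(input_ids, attention_mask)
# Pairwise cosine similarity between every pair of documents.
normed = F.normalize(doc_vectors, dim=-1)
similarity = normed @ normed.T
print(similarity.shape)  # torch.Size([3, 3])
```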
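For example, both of the following illustrative heads map 768-dimensional features (the hidden size of BERT-base) to three classes. The MLP's layer sizes and dropout rate are hypothetical choices.

```python
import torch
import torch.nn as nn

FEATURE_DIM = 768  # hidden size of BERT-base
NUM_CLASSES = 3

# Option 1: a simple linear classifier on top of the features.
linear_head = nn.Linear(FEATURE_DIM, NUM_CLASSES)

# Option 2: a deeper (hypothetical) MLP head.
mlp_head = nn.Sequential(
    nn.Linear(FEATURE_DIM, 256),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(256, NUM_CLASSES),
)

features = torch.randn(8, FEATURE_DIM)  # stand-in for pretrained-model outputs
print(linear_head(features).shape, mlp_head(features).shape)
```

An SVM could instead be fit on the same feature vectors with a separate library. That option is not shown here.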
Here we'll use a pretrained BERT to finetune on a task called MNLI (Multi-Genre Natural Language Inference). Each example is a pair of sentences, a premise and a hypothesis, and the goal is to classify the pair into one of three categories: entailment, neutral, or contradiction. Here's the LightningModule:
In this case we're using the pretrained BERT from the Huggingface library and adding our own simple linear classifier to classify a given premise/hypothesis pair into one of the three classes.
However, we still need to define the validation loop, which calculates our validation accuracy, and the test loop, which calculates our test accuracy.
Finally, we define the optimizer and the dataset we'll operate on. This dataset should be the downstream dataset for the task you're trying to solve.
The full LightningModule looks like this.
Here we learned to use the Huggingface BERT as a feature extractor inside a LightningModule. This approach means you can leverage a really strong text representation to do things like:
You also saw how well PyTorch Lightning plays with other libraries including Huggingface!