How to fine-tune BERT with pytorch-lightning
What’s up world!
I hope you are enjoying fine-tuning transformer-based language models on tasks that interest you and achieving cool results.
I assume many of you use the amazing transformers library from huggingface to fine-tune pre-trained language models. This library lets you use state-of-the-art general-purpose (pre-trained) language models in PyTorch and TensorFlow. It makes downloading pre-trained models very easy, and it also provides a set of Python scripts so you can fine-tune the models on the task you’re interested in.
Running the provided scripts is very easy. However, modifying those scripts can get a bit tricky.
But don’t worry: luckily, there are other amazing libraries out there that help you write clean, readable fine-tuning scripts.
pytorch-lightning is a lightweight PyTorch wrapper that frees you from writing boring training loops. We will look at the minimal set of functions we need later in this tutorial; for the details, I refer you to its documentation.
For the data pipeline, we will use tofunlp/lineflow, a dataloader library for deep learning frameworks. It provides a bunch of functions that ease data handling for NLP datasets, and it also comes with downloaders for common NLP datasets. So we don’t need to download datasets ourselves anymore!
Using these tools, we will go through the following items in this tutorial.
So let’s get started then!
If you don’t have time to read this article through, you can go directly to my GitHub repository, clone it, set it up, and run it.
First, let’s take a look at the task we are tackling today: the Microsoft Research Paraphrase Corpus (MRPC). Given two sentences, a model is asked to predict whether they have the same meaning. For instance, the two sentences “It is an excellent day for a picnic!” and “In a day like this, I want to go for a picnic!” differ on the surface but conceptually mean the same thing, so we want our model to predict “TRUE” for this pair.
To use this dataset, all we have to do is use the lineflow library I mentioned above.
Calling the MsrParaphrase class in the lineflow.datasets module downloads the data from the web and gives you an iterator. In the sample above, you can see the two sentences, “sentence1” and “sentence2”, and the quality (i.e., the label). When quality is “1”, the pair is a paraphrase; when it’s “0”, the pair isn’t.
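lineflow takes care of downloading and parsing here. For a rough idea of what that saves you, the following minimal sketch parses MRPC-style tab-separated text with only the standard library. It assumes the raw MRPC files use the columns Quality, #1 ID, #2 ID, #1 String and #2 String; the `read_msrp` helper, its output keys, and the two example rows are hypothetical, and this is not how lineflow itself is implemented.

```python
import csv
import io
from typing import Dict, List

# Tiny stand-in for a raw MRPC file. The column layout is an assumption based on
# the public release; the rows are made-up examples, not real corpus entries.
RAW_TSV = (
    "Quality\t#1 ID\t#2 ID\t#1 String\t#2 String\n"
    "1\t101\t102\tIt is an excellent day for a picnic!\tIn a day like this, I want to go for a picnic!\n"
    "0\t103\t104\tThe meeting starts at noon.\tShe bought a new \"red\" bicycle.\n"
)


def read_msrp(text: str) -> List[Dict]:
    """Parse MRPC-style TSV text into dicts holding both sentences and an int label."""
    # QUOTE_NONE keeps double quotes inside sentences verbatim instead of
    # treating them as CSV quoting characters.
    reader = csv.DictReader(io.StringIO(text), delimiter="\t", quoting=csv.QUOTE_NONE)
    examples = []
    for row in reader:
        examples.append({
            "sentence1": row["#1 String"],
            "sentence2": row["#2 String"],
            # "1" means the pair is a paraphrase, "0" means it is not.
            "label": int(row["Quality"]),
        })
    return examples


examples = read_msrp(RAW_TSV)
```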
After getting this raw dataset, we want to convert it into a format that BERT can process. Since BERT expects text to be tokenized with WordPiece, we need to use the same tokenizer that was used when BERT was pre-trained. But don’t worry, transformers also provides this tokenizer through a simple interface.
from transformers import BertTokenizer
MAX_LEN = 128  # maximum sequence length (example value)
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
text = "Hello NLP lovers!"
inputs = tokenizer.encode_plus(text, add_special_tokens=True, max_length=MAX_LEN)
input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
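BERT’s tokenizer works in two stages: a basic split on whitespace and punctuation (plus lowercasing for uncased models), then WordPiece, which breaks each word into the longest pieces found in the vocabulary and marks word-internal pieces with `##`. The sketch below shows only the greedy longest-match-first WordPiece step, on a hypothetical toy vocabulary; the real `BertTokenizer` uses a vocabulary of roughly 30,000 entries and handles more edge cases.

```python
from typing import List, Set


def wordpiece(word: str, vocab: Set[str], unk: str = "[UNK]") -> List[str]:
    """Greedy longest-match-first WordPiece split of a single word."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Try the longest remaining substring first, shrinking it until it is in the vocab.
        while start < end:
            piece = word[start:end]
            if start > 0:
                piece = "##" + piece  # continuation pieces carry the "##" prefix
            if piece in vocab:
                match = piece
                break
            end -= 1
        if match is None:
            return [unk]  # no valid split: the whole word becomes [UNK]
        pieces.append(match)
        start = end
    return pieces


# Hypothetical toy vocabulary, far smaller than BERT's real one.
VOCAB = {"hello", "nlp", "love", "##rs", "##r"}

tokens = []
for word in "hello nlp lovers".split():
    tokens.extend(wordpiece(word, VOCAB))
```

Because the longest match wins, “lovers” becomes `love` + `##rs` rather than `love` + `##r` + something else.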
Using the tokenizer’s encode_plus function, we can 1) tokenize a raw text, 2) replace the tokens with their corresponding ids, and 3) insert BERT’s special tokens. Cool! We can also pass this function a pair of texts, and it will convert them into exactly the format we need for our task, paraphrase identification.
sent1 = "It is an excellent day for a picnic!"
sent2 = "In a day like this, I want to go for a picnic!"
# Passing two texts encodes them as one pair: [CLS] sent1 [SEP] sent2 [SEP]
inputs = tokenizer.encode_plus(sent1, sent2, add_special_tokens=True, max_length=MAX_LEN)
input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
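For a pair, encode_plus lays the tokens out as [CLS] sentence1 [SEP] sentence2 [SEP] and uses token_type_ids to tell the two segments apart. The following minimal sketch reproduces just that layout on already-tokenized toy inputs; `build_pair_input` is a hypothetical helper, and it skips the id lookup and truncation that encode_plus also performs.

```python
from typing import Dict, List

CLS, SEP = "[CLS]", "[SEP]"


def build_pair_input(tokens_a: List[str], tokens_b: List[str]) -> Dict[str, List]:
    """Lay out a sentence pair the way BERT expects: [CLS] A [SEP] B [SEP]."""
    tokens = [CLS] + tokens_a + [SEP] + tokens_b + [SEP]
    # Segment A, including [CLS] and its closing [SEP], gets type 0;
    # segment B and the final [SEP] get type 1.
    token_type_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    return {"tokens": tokens, "token_type_ids": token_type_ids}


pair = build_pair_input(["excellent", "day"], ["picnic", "day"])
```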
So now we know how to encode input strings into a BERT-ready format. Next, let’s look at the actual code I wrote for this article.
It has two main functions. First, preprocess takes a data instance, encodes it into the BERT format, and pads the sequences. Second, get_dataloader applies preprocess to every instance in the dataset and builds a PyTorch DataLoader. The gist is a bit long, but only because I added some comment lines.
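The padding part of preprocess can be sketched as follows. This is a minimal sketch, not the function from the repository: it assumes each example has already been encoded (and, if needed, truncated) by encode_plus, that the padding id is 0 as in bert-base-uncased’s vocabulary, and an illustrative MAX_LEN of 8; `pad_example` is a hypothetical name. It also builds an attention mask, a common companion to padding that tells BERT which positions to ignore.

```python
from typing import Dict, List

MAX_LEN = 8  # illustrative only; real runs typically use 128 or more
PAD_ID = 0   # assumption: [PAD] has id 0, as in bert-base-uncased's vocabulary


def pad_example(input_ids: List[int], token_type_ids: List[int], label: int,
                max_len: int = MAX_LEN) -> Dict[str, object]:
    """Right-pad one encoded example to max_len and build its attention mask."""
    if len(input_ids) > max_len:
        raise ValueError("sequence longer than max_len; truncate it when encoding")
    n_pad = max_len - len(input_ids)
    return {
        "input_ids": input_ids + [PAD_ID] * n_pad,
        "token_type_ids": token_type_ids + [0] * n_pad,
        # 1 = real token the model should attend to, 0 = padding to ignore.
        "attention_mask": [1] * len(input_ids) + [0] * n_pad,
        "label": label,
    }


# Made-up ids for a short encoded pair.
example = pad_example([101, 11, 12, 102, 13, 102], [0, 0, 0, 0, 1, 1], label=1)
```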
Getting BERT ready is very easy with transformers. You just need to choose which transformer-based language model you want.
from transformers import BertForSequenceClassification
NUM_LABELS = 2 # For paraphrase identification, labels are binary, "paraphrase" or "not paraphrase".
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=NUM_LABELS)
transformers provides BertModel, which is just the pre-trained BERT, but here we use BertForSequenceClassification instead. This is a PyTorch nn.Module that contains the pre-trained BERT plus a randomly initialized classification layer on top.
Yeah, this is it! Very easy, isn’t it?
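Under the hood, BertForSequenceClassification applies dropout and a single linear layer to BERT’s pooled [CLS] representation, producing one logit per label. The pure-Python sketch below shows only that last step, using a toy 3-dimensional vector (BERT-base’s is 768-dimensional) and made-up weights; dropout is omitted because it is inactive at inference time, and the helper names are hypothetical.

```python
import math
from typing import List


def linear(x: List[float], weight: List[List[float]], bias: List[float]) -> List[float]:
    """y = W x + b, with one row of W per output label."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Toy pooled [CLS] vector and a 2-label head with made-up (untrained) weights.
pooled = [0.5, -1.0, 2.0]
weight = [[0.1, 0.2, -0.3],   # row for label 0: "not paraphrase"
          [0.4, -0.1, 0.5]]   # row for label 1: "paraphrase"
bias = [0.0, 0.0]

logits = linear(pooled, weight, bias)
probs = softmax(logits)
prediction = probs.index(max(probs))  # index of the most likely label
```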
Now that we have the data and the model prepared, let’s put them together in pytorch-lightning’s format so that we can run the fine-tuning process easily.
As shown in the official documentation, there are at least three methods you need to implement to use pytorch-lightning’s LightningModule class: 1) train_dataloader, 2) training_step, and 3) configure_optimizers. Let’s check how to write each of these methods for fine-tuning, one by one.
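To see why these three methods are enough, here is a toy, pure-Python picture of the loop a trainer runs with them: get the optimizer settings, iterate over batches, call the step, and update the parameters. `ToyModule` and `fit` are hypothetical and fit a one-parameter model y = w * x; this is not how pytorch-lightning is implemented (it uses autograd, handles devices, logging and checkpointing), and the `grad` entry exists only because this toy has no autograd.

```python
from typing import Dict, List


class ToyModule:
    """Hypothetical stand-in exposing the three hooks named above."""

    def __init__(self) -> None:
        self.w = 0.0  # single parameter of the model y = w * x

    def train_dataloader(self) -> List[List[Dict[str, float]]]:
        # Two tiny batches drawn from y = 2x.
        return [[{"x": 1.0, "y": 2.0}, {"x": 2.0, "y": 4.0}],
                [{"x": 3.0, "y": 6.0}]]

    def training_step(self, batch: List[Dict[str, float]], batch_idx: int) -> Dict[str, float]:
        # Mean squared error and its hand-derived gradient with respect to w.
        n = len(batch)
        loss = sum((self.w * ex["x"] - ex["y"]) ** 2 for ex in batch) / n
        grad = sum(2 * (self.w * ex["x"] - ex["y"]) * ex["x"] for ex in batch) / n
        return {"loss": loss, "grad": grad}

    def configure_optimizers(self) -> Dict[str, float]:
        return {"lr": 0.05}


def fit(module: ToyModule, max_epochs: int) -> List[float]:
    """Roughly the loop a trainer runs for you: fetch batches, step, update."""
    optimizer = module.configure_optimizers()
    losses = []
    for _ in range(max_epochs):
        for batch_idx, batch in enumerate(module.train_dataloader()):
            out = module.training_step(batch, batch_idx)
            # Plain gradient descent; a real trainer calls backward() and optimizer.step().
            module.w -= optimizer["lr"] * out["grad"]
            losses.append(out["loss"])
    return losses


module = ToyModule()
losses = fit(module, max_epochs=20)
```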
In train_dataloader, we simply return the PyTorch DataLoader we implemented in the preprocessing section. I will skip the code here; check the repository if anything is unclear.
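What train_dataloader hands back is essentially an iterable of shuffled, collated batches. As a rough pure-Python picture of what a DataLoader created with batch_size and shuffle=True provides, the sketch below shuffles example indices and groups them; `collate` and `batches` are hypothetical helpers, and a real DataLoader additionally stacks the values into tensors and reshuffles every epoch.

```python
import random
from typing import Dict, Iterator, List


def collate(examples: List[Dict[str, object]]) -> Dict[str, List[object]]:
    """Turn a list of per-example dicts into one dict of per-key lists."""
    return {key: [ex[key] for ex in examples] for key in examples[0]}


def batches(dataset: List[Dict[str, object]], batch_size: int,
            shuffle: bool = True, seed: int = 0) -> Iterator[Dict[str, List[object]]]:
    """Yield collated mini-batches over the dataset, optionally in shuffled order."""
    order = list(range(len(dataset)))
    if shuffle:
        # Fixed seed for reproducibility here; a DataLoader draws a new order each epoch.
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        yield collate([dataset[i] for i in order[start:start + batch_size]])


# Five toy examples that have already been padded to the same length.
dataset = [{"input_ids": [i, i, 0], "label": i % 2} for i in range(5)]
all_batches = list(batches(dataset, batch_size=2))
```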
As shown in the pytorch-lightning docs, training_step takes a batch generated by the dataloader we implemented and passes the inputs in that batch to the BertForSequenceClassification instance. Since we pass the correct labels along with the inputs, the model returns the loss value directly, so we don’t even need to calculate the loss ourselves. After obtaining the loss from the model, we just follow the pytorch-lightning format: return a dictionary containing the loss, which is then used to update the model’s parameters.
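When labels are passed in, BertForSequenceClassification computes the loss itself; with more than one label, that loss is the standard softmax cross-entropy. The sketch below spells that computation out in pure Python and wraps it in a training_step-like function that returns the loss in a dictionary. The `model` callable (mapping a batch to logits) and the `labels` key are hypothetical stand-ins, not the real batch layout or the real model.

```python
import math
from typing import Callable, Dict, List


def cross_entropy(logits: List[List[float]], labels: List[int]) -> float:
    """Mean softmax cross-entropy over a batch of logit rows and integer labels."""
    total = 0.0
    for row, label in zip(logits, labels):
        m = max(row)  # subtract the max before exponentiating, for stability
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in row))
        total += log_sum_exp - row[label]  # equals -log(softmax(row)[label])
    return total / len(labels)


def training_step(model: Callable[[Dict], List[List[float]]],
                  batch: Dict, batch_idx: int) -> Dict[str, float]:
    """Hypothetical step: compute logits, then return the loss in a dictionary."""
    logits = model(batch)
    return {"loss": cross_entropy(logits, batch["labels"])}


def fake_model(batch: Dict) -> List[List[float]]:
    # Stand-in "model" returning fixed logits for a two-example batch.
    return [[2.0, 0.0], [0.0, 2.0]]


out = training_step(fake_model, {"labels": [0, 1]}, batch_idx=0)
```

Both toy examples get their correct label scored 2 points above the other, so the loss is log(1 + e^-2), about 0.127.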
configure_optimizers is pretty straightforward to implement too. As the official documentation shows, we just need to initialize a PyTorch optimizer and return it. For the optimizer’s configuration, we simply reuse the one from huggingface/transformers’s sample script; the full method is in the repository.
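Fine-tuning scripts in the huggingface/transformers examples (such as run_glue.py) commonly exclude biases and LayerNorm weights from weight decay by splitting the parameters into two optimizer groups. Here is a minimal sketch of that grouping on parameter names; the name substrings and the 0.01 decay are assumed example values, `group_parameters` is a hypothetical helper, and strings stand in for the parameter tensors.

```python
from typing import Any, Dict, Iterable, List, Tuple

# Name substrings excluded from weight decay (assumed, following the common example-script choice).
NO_DECAY = ("bias", "LayerNorm.weight")


def group_parameters(named_params: Iterable[Tuple[str, Any]],
                     weight_decay: float = 0.01) -> List[Dict[str, Any]]:
    """Split parameters into a decayed group and a non-decayed group."""
    decay, no_decay = [], []
    for name, param in named_params:
        if any(nd in name for nd in NO_DECAY):
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


named = [
    ("bert.embeddings.word_embeddings.weight", "p0"),
    ("bert.embeddings.LayerNorm.weight", "p1"),
    ("bert.encoder.layer.0.attention.self.query.bias", "p2"),
    ("classifier.weight", "p3"),
]
groups = group_parameters(named, weight_decay=0.01)
```

Inside configure_optimizers, such groups could be handed straight to a PyTorch optimizer, which accepts a list of dicts as parameter groups, e.g. `torch.optim.AdamW(group_parameters(self.model.named_parameters()), lr=2e-5)`, where `self.model` is assumed to hold the BertForSequenceClassification instance and the learning rate is only an example.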
Now we have everything prepared. For the complete code, check my GitHub repository. It actually contains two scripts: one for the paraphrase detection task we just went through, and another for CommonsenseQA.
In this article, we saw how to use lineflow to download and preprocess a dataset, and then used pytorch-lightning to fine-tune the pre-trained BERT provided by transformers. I hope you enjoyed reading this article, and that you’ll actually try running the code!