Fine-tuning BERT with 5 lines of code!

Manoj Gadde · Published in Geek Culture · Jul 8, 2021 · 3 min read

In this blog, let's look at how to fine-tune a pre-trained model (BERT) on a Kaggle competition, 'Disaster Tweets Prediction'.
By the end of this article, you will be able to use BERT on your own dataset.

So let’s get started!


Data Understanding:
In this dataset we have tweets and their corresponding labels, 0 and 1. If a tweet is about a real disaster it is labeled 1, otherwise 0. After training, our model should be able to predict whether a tweet is about a disaster or not.

[Screenshot: a sample of the data]

But before training, let's also do the usual preprocessing on the text, such as removing stop words and punctuation.

Text Preprocessing:

[Code screenshot: text preprocessing]

Here I split the two columns into X and y, then applied preprocessing steps to the text: removing stop words, removing punctuation, and lemmatization. After preprocessing, I did a train/test split.

[Code screenshot: train/test split]
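A minimal sketch of what this step could look like, assuming NLTK for stop-word removal and lemmatization, scikit-learn for the split, and the Kaggle column names 'text' and 'target' (not the exact code from the screenshot):

```python
import string
import pandas as pd
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from sklearn.model_selection import train_test_split

# nltk.download('stopwords'); nltk.download('wordnet')  # run once if missing
stop_words = set(stopwords.words('english'))
lemmatizer = WordNetLemmatizer()

def clean_text(text):
    # lowercase, strip punctuation, drop stop words, lemmatize what remains
    text = text.lower().translate(str.maketrans('', '', string.punctuation))
    tokens = [lemmatizer.lemmatize(w) for w in text.split() if w not in stop_words]
    return ' '.join(tokens)

df = pd.read_csv('train.csv')        # Kaggle 'Disaster Tweets' training file
X = df['text'].apply(clean_text)     # cleaned tweets
y = df['target']                     # 1 = disaster, 0 = not a disaster

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y)
```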

Using Pre-Trained Model Tokenizer:

This is the most important step. Until now we haven't converted our raw text into numerical values that the model can understand. To do that, we need the pre-trained model's tokenizer, and remember that the tokenizer differs from model to model; here we use 'DistilBertTokenizerFast' for the 'distilbert-base-uncased' pre-trained model.

[Code screenshot: tokenizer]

After creating the tokenizer object, we just need to pass our raw text to the tokenizer object.

Note: the tokenizer only accepts input text sequences as a Python list.
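A sketch of this step with the Hugging Face transformers API; note that the pandas Series are converted to plain Python lists before being passed to the tokenizer:

```python
from transformers import DistilBertTokenizerFast

# tokenizer matching the 'distilbert-base-uncased' checkpoint
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

# the tokenizer expects lists of strings; pad/truncate to a common length
train_encodings = tokenizer(list(X_train), truncation=True, padding=True)
test_encodings = tokenizer(list(X_test), truncation=True, padding=True)
```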

Importing the Pre-Trained Model:

Importing the model is simple: we just need to import 'TFDistilBertForSequenceClassification' (since our task is sequence classification) from the transformers library and load the pre-trained weights using from_pretrained.

[Code screenshot: loading the model]
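A sketch of this step, with num_labels=2 because the task is binary classification:

```python
from transformers import TFDistilBertForSequenceClassification

# pre-trained DistilBERT body with a freshly initialised classification head
model = TFDistilBertForSequenceClassification.from_pretrained(
    'distilbert-base-uncased', num_labels=2)
```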

Model Training:

[Code screenshot: compiling and training]

Finally, we compile our model with a loss function to compute the loss and an optimizer to update the weights. After that, we pass our training data and the corresponding labels to .fit(), specifying the number of epochs and the batch size.
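A sketch of the compile-and-fit step, assuming the encodings and labels from earlier are wrapped into tf.data Datasets; the learning rate, epoch count, and batch size below are illustrative placeholders, not the exact values behind the 0.82 score:

```python
import tensorflow as tf

# pair the tokenized inputs with their labels
train_dataset = tf.data.Dataset.from_tensor_slices(
    (dict(train_encodings), list(y_train)))
test_dataset = tf.data.Dataset.from_tensor_slices(
    (dict(test_encodings), list(y_test)))

# loss measures the error on the logits, optimizer updates the weights
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=5e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])

model.fit(train_dataset.shuffle(1000).batch(16),
          validation_data=test_dataset.batch(16),
          epochs=3)
```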

By using this approach, I was able to reach rank 564 out of 3,000 in this competition with a score of 0.82.

Before BERT, I tried other architectures like LSTM and GRU, but those models only reached about 70% validation accuracy.

That’s it for this blog. Thank you!

My other blog: ‘How I Helped My Village Farmer Using Deep Learning’
