Updated On : Apr-03,2022 Time Investment : ~30 mins

PyTorch: Simple Guide To Text Classification Tasks

PyTorch is one of the most widely used Python libraries for designing neural networks today. It has evolved a lot over time to give researchers and developers the tools they need to simplify their work and run more experiments. The PyTorch ecosystem also includes separate domain libraries for vision (torchvision), recommendation (torchrec), audio (torchaudio) and text (torchtext) tasks. These libraries provide utilities that are commonly needed in those fields, which makes developers' lives easier.

As a part of this tutorial, we'll design a simple network to classify text documents. We'll be using the AG NEWS dataset available from torchtext. Apart from this, we'll also use the vocabulary-building and other utilities available from torchtext. The main aim of the tutorial is to show how to design text classification networks using PyTorch and torchtext.

Below, we have listed the important sections of the tutorial to give an overview of the material covered.

Important Sections Of Tutorial

  1. Prepare Dataset
    • 1.1 Load Dataset
    • 1.2 Build Vocabulary Of Tokens
    • 1.3 Create Data Loaders (Vectorize Text Data)
  2. Define Network
  3. Train Network
  4. Evaluate Network Performance
  5. Explain Network Predictions Using LIME Algorithm

Below, we have imported the necessary libraries and printed the versions that we have used in our tutorial.

import torch

print("PyTorch Version : {}".format(torch.__version__))
PyTorch Version : 1.9.1+cpu
import torchtext

print("Torch Text Version : {}".format(torchtext.__version__))
Torch Text Version : 0.10.1

1. Prepare Dataset

In this section, we have vectorized our text dataset step by step to prepare it for the neural network. The network expects each text document as a vector of real numbers, so we have first populated a vocabulary with the tokens of the data and then created data loaders that return vectorized batches each time we iterate over them. We have used the word frequency (bag-of-words) approach, in which each text example is represented by the count of each vocabulary word/token in it.

1.1 Load Dataset

In this section, we have simply loaded the AG NEWS dataset available from torchtext. The dataset has text documents for 4 different news categories, as listed in the table below. The dataset is already split into train and test sets.

Index Category
1 World
2 Sports
3 Business
4 Sci/Tec
from torch.utils.data import DataLoader

train_dataset, test_dataset  = torchtext.datasets.AG_NEWS()

target_classes = ["World", "Sports", "Business", "Sci/Tec"]
train.csv: 29.5MB [00:00, 86.2MB/s]
test.csv: 1.86MB [00:00, 37.8MB/s]

1.2 Build Vocabulary Of Tokens

In this section, we have populated the vocabulary with tokens from the train and test datasets. We first initialized a tokenizer using the get_tokenizer() function available from the torchtext.data module. The "basic_english" tokenizer lowercases text and separates words and punctuation marks.

After initializing the tokenizer, we populated the vocabulary using the build_vocab_from_iterator() function available from the torchtext.vocab module. The function takes as input an iterator that yields a list of tokens for each text example. We created this iterator as a simple generator function that takes a list of datasets as input, loops through each dataset and its text examples, and yields the list of tokens produced by the tokenizer for each example. We have also added <UNK> as a special token and set it as the vocabulary's default index, so that tokens not present in the vocabulary are mapped to it.

At last, we have printed the length of the vocabulary.

from torchtext.data import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

tokenizer = get_tokenizer("basic_english")

def build_vocab(datasets):
    for dataset in datasets:
        for _, text in dataset:
            yield tokenizer(text)

vocab = build_vocab_from_iterator(build_vocab([train_dataset, test_dataset]), specials=["<UNK>"])
vocab.set_default_index(vocab["<UNK>"])  ## Look-ups of unseen tokens return the index of <UNK>

print("Vocabulary Size : {}".format(len(vocab)))
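To make the vocabulary step concrete, below is a minimal pure-Python sketch of a frequency-ordered vocabulary with an unknown-token fallback. It is not torchtext's implementation: the helper names (`simple_tokenize`, `build_simple_vocab`, `lookup`) are hypothetical, and the regular-expression tokenizer is only a simplified stand-in for `basic_english`. As far as we know, torchtext similarly places the special tokens first and orders the remaining tokens by descending frequency.

```python
import re
from collections import Counter

def simple_tokenize(text):
    # Simplified stand-in for "basic_english": lowercase, then keep runs of
    # word characters and single punctuation marks as separate tokens.
    return re.findall(r"\w+|[^\w\s]", text.lower())

def build_simple_vocab(texts, specials=("<UNK>",)):
    counter = Counter()
    for text in texts:
        counter.update(simple_tokenize(text))
    # Most frequent tokens first; ties broken alphabetically.
    ordered = sorted(counter.items(), key=lambda kv: (-kv[1], kv[0]))
    itos = list(specials) + [tok for tok, _ in ordered if tok not in specials]
    stoi = {tok: i for i, tok in enumerate(itos)}
    return itos, stoi

def lookup(stoi, token, default="<UNK>"):
    # Tokens missing from the vocabulary fall back to the unknown token's index.
    return stoi.get(token, stoi[default])

docs = ["Stocks rally as oil falls.", "Oil prices rise; stocks fall."]
itos, stoi = build_simple_vocab(docs)
print(itos)
print(lookup(stoi, "stocks"), lookup(stoi, "football"))  # 3 0
```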

1.3 Create Data Loaders (Vectorize Text Data)

In this section, we have created the train and test data loaders used during training. Both data loaders use a batch size of 256 and take a callable through the collate_fn parameter. This function is responsible for vectorizing a batch of text documents: it receives a single batch (a list of 256 (label, text) pairs), vectorizes the text documents using a CountVectorizer object, and returns them as torch tensors along with the target labels.
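The unzip-and-shift part of the collate function can be illustrated without torch. This is a minimal sketch, assuming (as with AG_NEWS) that each dataset item is a `(label, text)` tuple with labels from 1 to 4; the helper name `unzip_batch` and the sample batch are made up for illustration.

```python
def unzip_batch(batch):
    # batch is a list of (label, text) pairs, the form a DataLoader passes to collate_fn.
    labels, texts = zip(*batch)
    # AG_NEWS labels run 1..4; CrossEntropyLoss expects class indices 0..3.
    return list(texts), [label - 1 for label in labels]

batch = [(3, "Airline seeks concessions"), (2, "Team wins final"), (4, "New chip unveiled")]
texts, targets = unzip_batch(batch)
print(texts)    # ['Airline seeks concessions', 'Team wins final', 'New chip unveiled']
print(targets)  # [2, 1, 3]
```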

We have created the CountVectorizer object with our vocabulary and tokenizer. CountVectorizer is available from scikit-learn and uses the word frequency approach to vectorize text data. For each text example, it returns a vector with one entry per vocabulary token. All entries are zero except at the indexes of tokens that appear in the example, where the entry holds how many times that token occurs in it.
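The counting itself is simple enough to write out by hand. The sketch below is a pure-Python illustration of what CountVectorizer computes for one already-tokenized document when given a fixed vocabulary; the function name `count_vectorize` and the toy vocabulary are hypothetical. Like CountVectorizer with a fixed vocabulary, it silently ignores tokens that are not in the vocabulary, so in this pipeline the `<UNK>` column is expected to stay zero.

```python
def count_vectorize(tokens, vocabulary):
    # One slot per vocabulary entry; tokens missing from the vocabulary are ignored.
    index = {tok: i for i, tok in enumerate(vocabulary)}
    vector = [0] * len(vocabulary)
    for tok in tokens:
        if tok in index:
            vector[index[tok]] += 1
    return vector

vocabulary = ["<UNK>", "the", "oil", "stocks", "fall", "rise"]
tokens = ["the", "oil", "price", "and", "the", "stocks", "fall"]
print(count_vectorize(tokens, vocabulary))  # [0, 2, 1, 1, 1, 0]
```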

from sklearn.feature_extraction.text import CountVectorizer
from torch.utils.data import DataLoader
from torchtext.data.functional import to_map_style_dataset

vectorizer = CountVectorizer(vocabulary=vocab.get_itos(), tokenizer=tokenizer)

def vectorize_batch(batch):
    Y, X = list(zip(*batch))
    X = vectorizer.transform(X).todense()
    return torch.tensor(X, dtype=torch.float32), torch.tensor(Y) - 1 ## Subtract 1 from target labels to move them from [1,2,3,4] to [0,1,2,3]

## Reload the datasets: the iterable datasets were exhausted while building the vocabulary.
train_dataset, test_dataset  = torchtext.datasets.AG_NEWS()
train_dataset, test_dataset = to_map_style_dataset(train_dataset), to_map_style_dataset(test_dataset)

train_loader = DataLoader(train_dataset, batch_size=256, collate_fn=vectorize_batch)
test_loader  = DataLoader(test_dataset, batch_size=256, collate_fn=vectorize_batch)

for X, Y in train_loader:
    print(X.shape, Y.shape)
    break
torch.Size([256, 98635]) torch.Size([256])

2. Define Network

In this section, we have designed a simple neural network of linear layers using PyTorch that we'll use to classify our text documents. This network will take vectorized data as input and return predictions.

The network has 3 linear layers with 128, 64, and 4 output units, and we have applied ReLU activation to the outputs of the first two linear layers. The last layer returns one unnormalized score (logit) per category. The network is designed using PyTorch's nn.Sequential container, which works much like the Sequential API of Keras.
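A quick parameter count shows where the network's capacity sits. The sketch below assumes a vocabulary of 98,635 tokens, taken from the batch shape `[256, 98635]` printed in section 1.3; the helper name `linear_stack_params` is hypothetical. ReLU layers have no parameters, so only the three linear layers contribute.

```python
def linear_stack_params(sizes):
    # A Linear(n_in, n_out) layer holds n_in * n_out weights plus n_out biases.
    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes, sizes[1:]))

# Layer widths: vocabulary size -> 128 -> 64 -> 4.
sizes = [98635, 128, 64, 4]
print(linear_stack_params(sizes))      # 12633924
print(linear_stack_params(sizes[:2]))  # 12625408 (first layer alone)
```

Almost all of the roughly 12.6 million parameters live in the first layer, because every vocabulary token gets its own weight per hidden unit. In PyTorch, `sum(p.numel() for p in text_classifier.parameters())` should report the same total.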

from torch import nn
from torch.nn import functional as F

class TextClassifier(nn.Module):
    def __init__(self):
        super(TextClassifier, self).__init__()
        self.seq = nn.Sequential(
            nn.Linear(len(vocab), 128),
            nn.ReLU(),

            nn.Linear(128, 64),
            nn.ReLU(),

            nn.Linear(64, 4),  ## One output (logit) per news category
        )

    def forward(self, X_batch):
        return self.seq(X_batch)

text_classifier = TextClassifier()

## Sanity check: run a single batch through the untrained network.
for X, Y in train_loader:
    Y_preds = text_classifier(X)
    print(Y_preds.shape)
    break
torch.Size([256, 4])

3. Train Network

In this section, we train the network defined in the previous section. To do so, we have designed a simple function that performs training when called. It takes the model, loss function, optimizer, train data loader, validation data loader, and number of epochs as input, and runs the training loop for the given number of epochs. In each epoch, it loops through the training data in batches using the train data loader, which returns vectorized data and labels for each batch. For each batch, we perform a forward pass through the network to make predictions, calculate the loss from the predictions and actual target labels, calculate gradients, and update the network parameters. The function records the loss of each batch and prints the average training loss at the end of each epoch. We have also created a helper function that takes the model, loss function, and validation data loader as input and calculates the validation loss and accuracy.
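The number of iterations per epoch follows directly from the dataset size and batch size. AG NEWS has 120,000 training and 7,600 test articles, and by default a DataLoader also yields a final, smaller batch, so the count is a ceiling division. The helper name `batches_per_epoch` is hypothetical.

```python
import math

def batches_per_epoch(num_examples, batch_size, drop_last=False):
    # DataLoader yields a final smaller batch unless drop_last=True.
    if drop_last:
        return num_examples // batch_size
    return math.ceil(num_examples / batch_size)

print(batches_per_epoch(120_000, 256))                  # 469
print(batches_per_epoch(7_600, 256))                    # 30
print(batches_per_epoch(120_000, 256, drop_last=True))  # 468
```

The 469 matches the `469/469` shown by the tqdm progress bars in the training output.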

from tqdm import tqdm
from sklearn.metrics import accuracy_score

def CalcValLossAndAccuracy(model, loss_fn, val_loader):
    with torch.no_grad():
        Y_shuffled, Y_preds, losses = [],[],[]
        for X, Y in val_loader:
            preds = model(X)
            loss = loss_fn(preds, Y)
            losses.append(loss.item())

            Y_shuffled.append(Y)
            Y_preds.append(preds.argmax(dim=-1))  ## Predicted class = index of the largest logit

        Y_shuffled = torch.cat(Y_shuffled)
        Y_preds = torch.cat(Y_preds)

        print("Valid Loss : {:.3f}".format(torch.tensor(losses).mean()))
        print("Valid Acc  : {:.3f}".format(accuracy_score(Y_shuffled.detach().numpy(), Y_preds.detach().numpy())))

def TrainModel(model, loss_fn, optimizer, train_loader, val_loader, epochs=10):
    for i in range(1, epochs+1):
        losses = []
        for X, Y in tqdm(train_loader):
            Y_preds = model(X)               ## Forward pass

            loss = loss_fn(Y_preds, Y)       ## Loss from logits and target labels
            losses.append(loss.item())

            optimizer.zero_grad()            ## Clear gradients from the previous batch
            loss.backward()                  ## Calculate gradients
            optimizer.step()                 ## Update network parameters

        print("Train Loss : {:.3f}".format(torch.tensor(losses).mean()))
        CalcValLossAndAccuracy(model, loss_fn, val_loader)

Below, we have trained our network by calling the training function from the previous cell with the necessary parameters. We have set the number of epochs to 8 and the learning rate to 0.0001, and initialized the cross entropy loss, our text classifier network, and the Adam optimizer. Note that we pass the test loader as the validation loader, so the "Valid" numbers below are computed on the test set. The loss and accuracy values show that the model does a decent job on this task: validation accuracy peaks at 0.919 in epoch 4. Validation loss is lowest at epoch 5 (0.244) and creeps up afterwards while training loss keeps falling, which suggests the network starts to overfit. Hyperparameter tuning (for example, fewer epochs or a different learning rate) could further improve the performance of the network.
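nn.CrossEntropyLoss takes raw logits and applies log-softmax internally (it combines LogSoftmax and NLLLoss), which is why TextClassifier ends with a plain Linear layer and no softmax. The pure-Python sketch below computes the same quantity, mean negative log-probability of the target class, for illustration only; the function name `cross_entropy` is our own.

```python
import math

def cross_entropy(logits_batch, targets):
    # Mean of (log-sum-exp of the logits) minus (logit of the target class).
    # Subtracting the row maximum keeps exp() from overflowing.
    total = 0.0
    for logits, target in zip(logits_batch, targets):
        m = max(logits)
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_sum_exp - logits[target]
    return total / len(targets)

# Uniform logits give log(4), about 1.386: the loss of a model guessing evenly over 4 classes.
print(cross_entropy([[0.0, 0.0, 0.0, 0.0]], [2]))
# A confident, correct prediction gives a small loss; a confident, wrong one a large loss.
print(cross_entropy([[4.0, 0.0, 0.0, 0.0]], [0]))
print(cross_entropy([[4.0, 0.0, 0.0, 0.0]], [1]))
```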

from torch.optim import Adam

epochs = 8
learning_rate = 1e-4

loss_fn = nn.CrossEntropyLoss()
text_classifier = TextClassifier()
optimizer = Adam(text_classifier.parameters(), lr=learning_rate)

TrainModel(text_classifier, loss_fn, optimizer, train_loader, test_loader, epochs)
100%|██████████| 469/469 [02:35<00:00,  3.02it/s]
Train Loss : 0.812
Valid Loss : 0.367
Valid Acc  : 0.894
100%|██████████| 469/469 [02:28<00:00,  3.15it/s]
Train Loss : 0.291
Valid Loss : 0.282
Valid Acc  : 0.912
100%|██████████| 469/469 [02:27<00:00,  3.17it/s]
Train Loss : 0.218
Valid Loss : 0.258
Valid Acc  : 0.916
100%|██████████| 469/469 [02:29<00:00,  3.14it/s]
Train Loss : 0.175
Valid Loss : 0.247
Valid Acc  : 0.919
100%|██████████| 469/469 [02:31<00:00,  3.10it/s]
Train Loss : 0.145
Valid Loss : 0.244
Valid Acc  : 0.918
100%|██████████| 469/469 [02:35<00:00,  3.02it/s]
Train Loss : 0.121
Valid Loss : 0.246
Valid Acc  : 0.918
100%|██████████| 469/469 [02:32<00:00,  3.07it/s]
Train Loss : 0.100
Valid Loss : 0.251
Valid Acc  : 0.917
100%|██████████| 469/469 [02:33<00:00,  3.06it/s]
Train Loss : 0.083
Valid Loss : 0.260
Valid Acc  : 0.916
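Because validation loss bottoms out before the last epoch, a common remedy is early stopping: keep the weights from the epoch with the lowest validation loss and stop once it has not improved for a few epochs. Below is a minimal sketch of that rule applied to the eight validation losses printed above; the function name and the patience of 2 are our own arbitrary choices, and in practice the selection should use a held-out validation split rather than the test set that this tutorial reuses for validation.

```python
def early_stopping_epoch(val_losses, patience=2):
    # Returns (best_epoch, stop_epoch), both 1-indexed. Training stops once the
    # validation loss has failed to improve for `patience` consecutive epochs.
    best_loss, best_epoch = float("inf"), 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
        elif epoch - best_epoch >= patience:
            return best_epoch, epoch
    return best_epoch, len(val_losses)

# Validation losses copied from the training output above.
val_losses = [0.367, 0.282, 0.258, 0.247, 0.244, 0.246, 0.251, 0.260]
print(early_stopping_epoch(val_losses))  # (5, 7)
```

In a PyTorch loop, "keeping the weights" would typically mean saving a copy of `model.state_dict()` whenever the validation loss improves.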

4. Evaluate Network Performance

In this section, we have evaluated the performance of our network by calculating accuracy, the classification report (precision, recall, and F1-score per target category), and the confusion matrix on test predictions.

We have created a small function that takes the model and data loader as input and returns predictions. Using this function, we have made predictions on the test dataset.

We have calculated these metrics using functions available from scikit-learn.

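We have also plotted the confusion matrix using the scikit-plot library in the cell after the metrics. The confusion matrix shows that our network does better on the Sports and World categories than on Business and Sci/Tec. Most of its errors come from Business and Sci/Tec being confused with each other: 119 Business articles are predicted as Sci/Tec and 157 Sci/Tec articles as Business.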

def MakePredictions(model, loader):
    Y_shuffled, Y_preds = [], []
    with torch.no_grad():
        for X, Y in loader:
            preds = model(X)
            Y_preds.append(preds)
            Y_shuffled.append(Y)
    Y_preds, Y_shuffled = torch.cat(Y_preds), torch.cat(Y_shuffled)

    return Y_shuffled.detach().numpy(), F.softmax(Y_preds, dim=-1).argmax(dim=-1).detach().numpy()

Y_actual, Y_preds = MakePredictions(text_classifier, test_loader)
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

print("Test Accuracy : {}".format(accuracy_score(Y_actual, Y_preds)))
print("\nClassification Report : ")
print(classification_report(Y_actual, Y_preds, target_names=target_classes))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_actual, Y_preds))
Test Accuracy : 0.9160526315789473

Classification Report :
              precision    recall  f1-score   support

       World       0.92      0.92      0.92      1900
      Sports       0.96      0.98      0.97      1900
    Business       0.88      0.89      0.88      1900
     Sci/Tec       0.91      0.88      0.89      1900

    accuracy                           0.92      7600
   macro avg       0.92      0.92      0.92      7600
weighted avg       0.92      0.92      0.92      7600

Confusion Matrix :
[[1747   46   66   41]
 [  25 1855   13    7]
 [  73   13 1695  119]
 [  63   15  157 1665]]
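The per-category numbers in the classification report can be recomputed by hand from the confusion matrix, where rows are actual categories and columns are predicted ones (scikit-learn's convention). The pure-Python sketch below uses the matrix printed above; precision divides the diagonal entry by its column sum, recall by its row sum.

```python
classes = ["World", "Sports", "Business", "Sci/Tec"]
cm = [[1747,   46,   66,   41],
      [  25, 1855,   13,    7],
      [  73,   13, 1695,  119],
      [  63,   15,  157, 1665]]

total = sum(sum(row) for row in cm)
accuracy = sum(cm[i][i] for i in range(len(cm))) / total

report = {}
for i, name in enumerate(classes):
    tp = cm[i][i]
    precision = tp / sum(row[i] for row in cm)  # column i: everything predicted as this class
    recall = tp / sum(cm[i])                     # row i: everything that actually is this class
    f1 = 2 * precision * recall / (precision + recall)
    report[name] = (precision, recall, f1)
    print(f"{name:<10} precision={precision:.2f} recall={recall:.2f} f1={f1:.2f}")
print(f"accuracy={accuracy:.4f}")
```

The rounded values agree with the classification report, and the accuracy, 6962 / 7600, matches the printed test accuracy.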
import scikitplot as skplt
import matplotlib.pyplot as plt

skplt.metrics.plot_confusion_matrix([target_classes[i] for i in Y_actual], [target_classes[i] for i in Y_preds],
                                    title="Confusion Matrix",
                                    figsize=(6, 6))
plt.show()

5. Explain Network Predictions Using LIME

In this section, we have explained predictions made by our text classification network using the LIME (Local Interpretable Model-Agnostic Explanations) algorithm. The lime library provides an implementation of the LIME algorithm, which we'll use for this purpose.

In order to explain a prediction, we first need to create an instance of LimeTextExplainer. Then, we call the explain_instance() method on it to generate an Explanation object. Finally, we can call the show_in_notebook() method on the Explanation instance to generate a visualization that highlights the words contributing to predicting a particular category.
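Roughly, LIME explains one prediction by (1) creating many perturbed copies of the text with random words removed, (2) getting the model's predictions for those copies through the prediction function, (3) weighting each copy by its similarity to the original, and (4) fitting a simple weighted linear model whose coefficients serve as word importances. The sketch below covers only step 1 and is a simplified illustration, not the lime library's code; the function name `perturb_text` is hypothetical.

```python
import random

def perturb_text(text, num_samples, seed=0):
    # Simplified LIME-style neighbourhood sampling for text: each perturbed
    # sample drops a random, non-empty subset of the document's distinct words.
    rng = random.Random(seed)
    words = text.split()
    distinct = sorted(set(words))
    if len(distinct) < 2:
        raise ValueError("need at least two distinct words to perturb")
    samples = [text]                 # the first sample is the unmodified text
    masks = [[1] * len(distinct)]    # 1 = word kept, 0 = word removed
    for _ in range(num_samples - 1):
        n_remove = rng.randint(1, len(distinct) - 1)
        removed = set(rng.sample(distinct, n_remove))
        masks.append([0 if w in removed else 1 for w in distinct])
        samples.append(" ".join(w for w in words if w not in removed))
    return distinct, masks, samples

distinct, masks, samples = perturb_text("airline seeks labor concessions before bankruptcy deal", num_samples=4)
for mask, sample in zip(masks, samples):
    print(mask, "->", sample)
```

The binary masks are what the linear model in step 4 is fitted on, with the classifier's probabilities for the perturbed texts as targets.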

Below, we have first initialized an instance of LimeTextExplainer using the LimeTextExplainer() constructor available from the lime_text module of the lime library. We have also passed the target class names to it.

from lime import lime_text

explainer = lime_text.LimeTextExplainer(class_names=target_classes, verbose=True)

Below, we have first retrieved all test text documents and their target labels. Then, we have designed a function that takes a list of text documents as input and returns predictions for them. The function vectorizes the input text documents, performs a forward pass through the network with this vectorized data, and converts the outputs to probabilities using the softmax function before returning them.

After defining the function, we have randomly selected a text example from the test dataset and made a prediction for it using our trained network. We have printed the actual and predicted labels of the selected example; our network correctly predicts its target label as 'Business'.

import numpy as np

## Retrieve test documents.
train_dataset, test_dataset  = torchtext.datasets.AG_NEWS()
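X_test_text, Y_test = [], []
for Y, X in test_dataset:
    X_test_text.append(X)
    Y_test.append(Y - 1)  ## Shift labels from [1,2,3,4] to [0,1,2,3] to index target_classes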

## Function to make prediction from text data
def make_predictions(X_batch_text):
    X_batch_vect = vectorizer.transform(X_batch_text).todense()
    logits = text_classifier(torch.tensor(X_batch_vect, dtype=torch.float32))
    preds = F.softmax(logits, dim=-1)
    return preds.detach().numpy()

## Randomly select test example for explanation
rng = np.random.RandomState(1)
idx = rng.randint(1, len(X_test_text))

X_batch_vect = vectorizer.transform(X_test_text[idx:idx+1]).todense()
logits = text_classifier(torch.tensor(X_batch_vect, dtype=torch.float32))
preds = F.softmax(logits, dim=-1)

print("Prediction : ", target_classes[preds.argmax(axis=-1)[0]])
print("Actual :     ", target_classes[Y_test[idx]])
Prediction :  Business
Actual :      Business
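explain_instance() only needs a prediction function with a simple contract: a list of n raw strings goes in, and n rows of class probabilities (one column per class) come out. make_predictions meets it by vectorizing, running the network, and applying softmax. The toy stand-in below shows the same contract in pure Python; the keyword lists and the name `toy_classifier_fn` are hypothetical. It returns nested lists to stay dependency-free, whereas the function handed to lime should return a NumPy array, as make_predictions does.

```python
import math

CLASSES = ["World", "Sports", "Business", "Sci/Tec"]
# Hypothetical keyword lists standing in for the trained network.
KEYWORDS = [{"election", "minister"}, {"match", "team"}, {"bankruptcy", "pensions"}, {"software", "chip"}]

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def toy_classifier_fn(texts):
    # Same contract as make_predictions: n strings in, n rows of probabilities out.
    rows = []
    for text in texts:
        words = text.lower().split()
        logits = [float(sum(w in kw for w in words)) for kw in KEYWORDS]
        rows.append(softmax(logits))
    return rows

probs = toy_classifier_fn(["Airline pensions cut in bankruptcy deal", "Team wins the match"])
for row in probs:
    print(CLASSES[row.index(max(row))], [round(p, 3) for p in row])
```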

Below, we have called the explain_instance() method on the LimeTextExplainer instance to generate an Explanation instance. We have given it the selected text example, the prediction function, and the target label.

At last, we have called the show_in_notebook() method on the Explanation instance to generate the visualization. It shows that words like 'concessions', 'financing', 'airlines', 'pensions', 'employees', 'bankruptcy', 'labor', 'cuts', etc. contribute to predicting the target category 'Business'. This makes sense, as these words are common in business news.

explanation = explainer.explain_instance(X_test_text[idx], classifier_fn=make_predictions,
                                         labels=Y_test[idx:idx+1], num_features=15)
explanation.show_in_notebook()

This ends our small tutorial explaining how to perform text classification using the functionality of PyTorch and torchtext. Please feel free to let us know your views in the comments section.


Sunny Solanki
