Updated On : May-23,2022 Time Investment : ~30 mins

Captum: Interpret Predictions Of PyTorch Text Classification Networks

Due to the tremendous success of deep learning, the majority of practitioners have shifted their attention towards designing deep neural networks consisting of different kinds of layers to solve machine learning tasks involving unstructured data. The usage of traditional white-box ML models (decision trees, random forests, gradient boosting machines, etc.) has decreased over time. Though black-box deep neural networks give quite good accuracy for tasks based on unstructured datasets, it's quite hard to interpret their predictions. Over the years, many algorithms have been developed to interpret predictions of deep neural networks, and many Python libraries (LIME, SHAP, Eli5, Captum, etc.) have been designed that provide implementations of these algorithms.

As a part of this tutorial, we have concentrated on one such Python library named Captum. Captum is designed specifically to interpret predictions of networks built with PyTorch. It has implementations of many interpretation algorithms that were invented over the years. Here, we'll use Captum to explain predictions of a neural network that solves a text classification task. We have listed all available algorithms from Captum in the link below, which also explains how to use Captum for tabular data.

If you are looking for a guide on using Captum for image classification tasks, then please check the link below.

Below, we have listed the important sections of the tutorial to give an overview of the material covered.

Important Sections Of Tutorial

  1. Prepare Data
    • 1.1 Load Dataset
    • 1.2 Populate Vocabulary
    • 1.3 Create Data Loaders
  2. Define Network
  3. Train Network
  4. Evaluate Network Performance
  5. Explain Predictions using CAPTUM
    • Integrated Gradients
    • Neuron Integrated Gradients

Below, we have imported the necessary Python libraries and printed the versions that we have used in this tutorial.

import captum

print("CAPTUM Version : {}".format(captum.__version__))
CAPTUM Version : 0.5.0
import torch

print("PyTorch Version : {}".format(torch.__version__))
PyTorch Version : 1.9.1+cpu
import torchtext

print("TorchText Version : {}".format(torchtext.__version__))
TorchText Version : 0.10.1

1. Prepare Data

In this section, we are preparing data for the neural network. We'll be using a word frequency (bag of words) approach for encoding text examples. This approach maintains the frequency of words per text example in a vector. We have followed the below steps for encoding text data and readying it for the neural network.

  1. Load the dataset.
  2. Loop through each text example, tokenize it, and populate a vocabulary of unique tokens (words). The vocabulary is a simple mapping from token to integer index. Each token is assigned a unique integer index starting from 0.
  3. Vectorize each text example using the populated vocabulary. For each text example, tokenize it and retrieve the frequency of each token. A vector of the same size as the vocabulary will be created (per text example), holding the frequency of each word at its integer index location as per the vocabulary. All tokens present in the vocabulary but not present in the text example will have 0 as their frequency.

It is okay if the above steps are not clear right now. They'll become clear when we actually perform them below.

1.1 Load Dataset

In this section, we have loaded the text classification dataset that we'll use for our task. We'll be using the AG NEWS dataset available from the datasets module of the torchtext Python library. The dataset is already divided into train and test sets. It has text examples for 4 different news categories (["World", "Sports", "Business", "Sci/Tec"]).

train_dataset, test_dataset = torchtext.datasets.AG_NEWS()

target_classes = ["World", "Sports", "Business", "Sci/Tec"]
train.csv: 29.5MB [00:00, 93.4MB/s]
test.csv: 1.86MB [00:00, 55.6MB/s]
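
Below, we have also peeked at one raw example to understand the data format. This is just an illustrative sketch (we load a fresh copy of the dataset so that peeking does not consume examples from the iterators used later); each entry is a (label, text) tuple where labels are integers in the range 1-4.

## Load a fresh copy so that peeking doesn't consume examples from the
## iterable datasets used below for vocabulary building.
sample_train, _ = torchtext.datasets.AG_NEWS()
label, text = next(iter(sample_train))

print("Label : {} ({})".format(label, target_classes[label - 1]))
print("Text  : {}".format(text))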

1.2 Populate Vocabulary

In this section, we have populated a vocabulary of unique words using various helper functions available from the torchtext library.

First, we have initialized a simple tokenizer using the get_tokenizer() function available from torchtext. The tokenizer is a function that takes a text example as input and breaks it into a list of words.

text = "Hello, How are you? Where are you planning to go?"
tokenized = ["hello", ",", "how", "are", "you", "?", "where",
            "are", "you", "planning", "to", "go", "?"]

After initializing the tokenizer, we have populated a vocabulary using the build_vocab_from_iterator() function from the vocab module. The function requires an iterator as input which returns a list of tokens per call. We have created a simple function named build_vocabulary() which works as this iterator. It takes a list of datasets as input. It then loops through each text example, yielding the list of tokens for that example, which is used to populate the vocabulary. The vocabulary is a simple mapping from word to integer index.

vocab = {
    'hello': 0,
    'how': 1,
    'are': 2,
    'you': 3,
    'where': 4,
    'planning': 5,
    'to': 6,
    'go': 7
    ...
}

After populating the vocabulary, we have also printed its size, which is the number of unique tokens (words) in our datasets.

from torchtext.data import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

tokenizer = get_tokenizer("basic_english")

def build_vocabulary(datasets):
    for dataset in datasets:
        for _, text in dataset:
            yield tokenizer(text)

vocab = build_vocab_from_iterator(build_vocabulary([train_dataset, test_dataset]), min_freq=1, specials=["<UNK>",])
vocab.set_default_index(vocab["<UNK>"])
print("Vocabulary Size : {}".format(len(vocab)))
Vocabulary Size : 98635
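
Below, we have also shown a small illustrative lookup on the populated vocabulary. The exact integer indexes will vary between runs; 'xyzunknown' is just a made-up token used to show that out-of-vocabulary words fall back to the index of the '<UNK>' token set above.

print("Indexes of known tokens  : {}".format(vocab(["pakistani", "leader", "minister"])))
print("Index of a made-up token : {}".format(vocab(["xyzunknown"]))) ## Maps to the default index
print("Index of '<UNK>' token   : {}".format(vocab["<UNK>"]))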

1.3 Create Data Loaders

In this section, we have created data loaders that will be used to loop through data examples in batches during training process.

In order to vectorize each text example using the vocabulary we populated earlier, we have used the CountVectorizer class available from scikit-learn. We have created an instance of CountVectorizer using our vocabulary and tokenizer. This vectorizer takes a batch of text examples as input and returns their vectorized versions. The input to CountVectorizer will be of shape (batch_size,) and the output will be of shape (batch_size, vocab_len). As we said earlier, for each token of a text example, its frequency will be present at its integer index as per the vocabulary.

After defining the vectorizer, we have created a simple function that takes a batch of data (text examples and their target labels) as input and returns vectorized batch data (vectorized examples and target labels). This function will be used by the data loaders to vectorize each batch of data.

At last, we have created train and test data loaders using the respective datasets. We have set a batch size of 256, which will let us loop through the data in batches of 256 examples and their labels. We have provided our vectorization function to the collate_fn parameter when creating the data loaders. This function will be called on each batch of data, and its output will be given to the network for training.

Below, we have explained with one simple example how vectorization happens.

text = "Hello, How are you? Where are you planning to go?"

vocab = {
    'hello': 0,
    'bye': 1,
    'how': 2,
    'the': 3,
    'welcome': 4,
    'are': 5,
    'you': 6,
    'to': 7
}

vector = [1, 0, 1, 0, 0, 2, 2, 1]
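
The toy example above can be reproduced with scikit-learn's CountVectorizer, which we'll use for the real dataset in the next cell. Below is a minimal, self-contained sketch; the simple regex tokenizer used here is only for illustration, whereas the actual code uses the torchtext tokenizer and the full vocabulary.

import re
from sklearn.feature_extraction.text import CountVectorizer

toy_vocab = {'hello': 0, 'bye': 1, 'how': 2, 'the': 3, 'welcome': 4, 'are': 5, 'you': 6, 'to': 7}
## Simple regex tokenizer that lower-cases text and keeps only word characters.
toy_vectorizer = CountVectorizer(vocabulary=toy_vocab, tokenizer=lambda t: re.findall(r"\w+", t.lower()))

vector = toy_vectorizer.transform(["Hello, How are you? Where are you planning to go?"]).toarray()
print(vector) ## Should match the frequency vector shown above: [[1 0 1 0 0 2 2 1]]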

If you are someone who is not aware of the word frequency (bag of words) approach to encoding text data and CountVectorizer, then we recommend that you go through the link below, which covers it in detail. It'll help you easily sail through this tutorial.

from sklearn.feature_extraction.text import CountVectorizer
from torch.utils.data import DataLoader
from torchtext.data.functional import to_map_style_dataset

vectorizer = CountVectorizer(vocabulary=vocab.get_itos(), tokenizer=tokenizer)

def vectorize_batch(batch):
    Y, X = list(zip(*batch))
    X = vectorizer.transform(X).todense()
    return torch.tensor(X, dtype=torch.float32), torch.tensor(Y) - 1 ## Subtract 1 from target labels to map them from [1,2,3,4] to [0,1,2,3]

train_dataset, test_dataset  = torchtext.datasets.AG_NEWS()
train_dataset, test_dataset = to_map_style_dataset(train_dataset), to_map_style_dataset(test_dataset)

train_loader = DataLoader(train_dataset, batch_size=256, collate_fn=vectorize_batch)
test_loader  = DataLoader(test_dataset, batch_size=256, collate_fn=vectorize_batch)
for X, Y in train_loader:
    print(X.shape, Y.shape)
    break
torch.Size([256, 98635]) torch.Size([256])
import gc

gc.collect()
42

2. Define Network

In this section, we have defined the network that we'll use for our text classification task. The network consists of 3 linear layers with 128, 64, and 4 output units respectively. The last layer has 4 output units, which is the same as the number of target classes. The first layer takes input of shape (batch_size, vocab_len) and outputs processed data of shape (batch_size, 128), which is given to the second layer. The second layer outputs data of shape (batch_size, 64), which is given to the third layer. The third layer outputs data of shape (batch_size, 4), which is the prediction of the network. We have applied relu activation to the output of the first two linear layers. We have created the network using the Sequential API of PyTorch, which is almost the same as the Keras Sequential API.

After defining network, we have initialized it and performed forward pass through it for verification purposes.

If you are someone who is new to PyTorch and wants to learn how to design neural networks using it, then please check the link below. It'll get you started with the library.

from torch import nn
from torch.nn import functional as F

class TextClassifier(nn.Module):
    def __init__(self):
        super(TextClassifier, self).__init__()
        self.seq = nn.Sequential(
            nn.Linear(len(vocab), 128),
            nn.ReLU(),

            nn.Linear(128, 64),
            nn.ReLU(),

            nn.Linear(64, 4),
        )

    def forward(self, X_batch):
        return self.seq(X_batch)
text_classifier = TextClassifier()
for X, Y in train_loader:
    Y_preds = text_classifier(X)
    print(Y_preds.shape)
    break
torch.Size([256, 4])
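
Because the first linear layer takes the full vocabulary-length vector as input, it holds almost all of the network's weights. The short sketch below counts the trainable parameters using standard PyTorch calls to make this concrete.

## Count trainable parameters. The first layer alone has 98635 * 128 weights (+ 128 biases),
## i.e. roughly 12.6 million of the ~12.63 million total parameters.
total_params = sum(p.numel() for p in text_classifier.parameters() if p.requires_grad)

print("Total Trainable Parameters : {:,}".format(total_params))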

3. Train Network

In this section, we have trained our network. In order to train it, we have designed a function that performs the training process. The function takes the model, loss function, optimizer, train data loader, validation data loader, and number of epochs as input. It then executes the training loop for the given number of epochs. For each epoch, it loops through the training data in batches. For each batch, it performs a forward pass to make predictions, calculates the loss, calculates gradients, and updates the network parameters using those gradients. It also prints the average loss across all batches at the end of each epoch. We have also created a helper function to make predictions on the validation data and report validation accuracy at the end of each epoch.

from torch.optim import Adam
from tqdm import tqdm
from sklearn.metrics import accuracy_score

def MakePredictions(model, loader):
    with torch.no_grad():
        Y_shuffled, Y_preds = [], []
        for X, Y in loader:
            preds = model(X)
            Y_preds.append(preds)
            Y_shuffled.append(Y)
        gc.collect()
        Y_preds, Y_shuffled = torch.cat(Y_preds), torch.cat(Y_shuffled)

        return Y_shuffled.detach().numpy(), F.softmax(Y_preds, dim=-1).argmax(dim=-1).detach().numpy()

def TrainModel(model, loss_fn, optimizer, train_loader, val_loader, epochs=10):
    for i in range(1, epochs+1):
        losses = []
        for X, Y in tqdm(train_loader):
            Y_preds = model(X)

            loss = loss_fn(Y_preds, Y)
            losses.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print("Train Loss          : {:.3f}".format(torch.tensor(losses).mean()))
        Y_shuffled, Y_preds = MakePredictions(model, val_loader)
        print("Validation Accuracy : {:.3f}".format(accuracy_score(Y_shuffled, Y_preds)))

Below, we have actually trained our network using the function designed in the previous cell. We have initialized the number of epochs to 5 and the learning rate to 0.001. Then, we have initialized our text classification network, the Adam optimizer, and the cross entropy loss function. At last, we have called our training routine with the necessary parameters to perform the training process. By looking at the loss and accuracy values printed after each epoch, we can notice that our network is doing a good job at the text classification task.

epochs=5
learning_rate=1e-3

classifier = TextClassifier()
optimizer = Adam(classifier.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

TrainModel(classifier, loss_fn, optimizer, train_loader, test_loader, epochs=epochs)
100%|██████████| 469/469 [02:42<00:00,  2.88it/s]
Train Loss          : 0.377
Validation Accuracy : 0.915
100%|██████████| 469/469 [02:57<00:00,  2.65it/s]
Train Loss          : 0.161
Validation Accuracy : 0.914
100%|██████████| 469/469 [02:52<00:00,  2.72it/s]
Train Loss          : 0.088
Validation Accuracy : 0.914
100%|██████████| 469/469 [02:56<00:00,  2.66it/s]
Train Loss          : 0.048
Validation Accuracy : 0.914
100%|██████████| 469/469 [03:14<00:00,  2.41it/s]
Train Loss          : 0.035
Validation Accuracy : 0.913

4. Evaluate Network Performance

In this section, we have evaluated the performance of our trained network by calculating the accuracy score, classification report (precision, recall, and f1-score per target class), and confusion matrix on test predictions. We can notice from the test accuracy that our model has done a good job at the given task. We have calculated these metrics using functions available from scikit-learn.

Please feel free to check the link below if you are interested in learning about the various ML metrics available from sklearn.

Apart from the calculations, we have also created a visualization of the confusion matrix using the Python library scikit-plot. The visualization shows that our network is quite good at classifying text documents of the Sports category compared to the others.

Scikit-plot provides visualizations for many other ML metrics. Please check the link below in your free time if you want to learn about them.

from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

Y_actual, Y_preds = MakePredictions(classifier, test_loader)

print("Test Accuracy : {}".format(accuracy_score(Y_actual, Y_preds)))
print("\nClassification Report : ")
print(classification_report(Y_actual, Y_preds, target_names=target_classes))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_actual, Y_preds))
Test Accuracy : 0.9130263157894737

Classification Report :
              precision    recall  f1-score   support

       World       0.92      0.91      0.91      1900
      Sports       0.97      0.98      0.97      1900
    Business       0.88      0.87      0.88      1900
     Sci/Tec       0.88      0.89      0.89      1900

    accuracy                           0.91      7600
   macro avg       0.91      0.91      0.91      7600
weighted avg       0.91      0.91      0.91      7600


Confusion Matrix :
[[1731   38   76   55]
 [  22 1853   18    7]
 [  70   13 1657  160]
 [  67   13  122 1698]]
from sklearn.metrics import confusion_matrix
import scikitplot as skplt
import matplotlib.pyplot as plt
import numpy as np

skplt.metrics.plot_confusion_matrix([target_classes[i] for i in Y_actual], [target_classes[i] for i in Y_preds],
                                    normalize=True,
                                    title="Confusion Matrix",
                                    cmap="Reds",
                                    hide_zeros=True,
                                    figsize=(5,5)
                                    );
plt.xticks(rotation=90);

[Figure: Normalized confusion matrix plot of test predictions]

5. Explain Predictions using CAPTUM

In this section, we have explained how we can use Captum to interpret the predictions made by our network. We'll create visualizations that show which words contributed positively/negatively to a prediction. Captum provides many algorithms for explaining predictions, but we'll cover only a few of them here. We suggest that the reader explores the other algorithms as well. All algorithms are available from the 'attr' sub-module of captum.

Below, we have simply loaded the test text examples and their actual labels. We'll randomly select text examples from this dataset and then explain the predictions made by our network on the selected examples.

train_dataset, test_dataset  = torchtext.datasets.AG_NEWS()
X_test_text, Y_test = [], []
for Y, X in test_dataset:
    X_test_text.append(X)
    Y_test.append(Y-1)

Integrated Gradients

In this section, we have explained an algorithm named integrated gradients. This algorithm lets us find out the contribution of individual features to the final prediction. The individual features in our case are the tokens (words) of a text example.

Below, we have first retrieved a random text example from the test dataset and made a prediction on it using our trained network. We have printed the actual text of the example, the tokenized text, the actual label, the predicted label, and the prediction probabilities. We can notice that our model correctly predicts the target label as 'World' for the selected text example. Next, we'll explain this prediction using the integrated gradients algorithm.

import torch.nn.functional as F

idx = np.random.choice(range(len(X_test_text)))
X_text_vec = torch.tensor(vectorizer.transform(X_test_text[idx:idx+1]).todense(), dtype=torch.float32)
probs = F.softmax(classifier(X_text_vec), dim=-1)

print("============= Actual Text =================== ")
print(X_test_text[idx])
print("============================================= ")
print("============= Tokenized Text ================ ")
tokenized_text = tokenizer(X_test_text[idx])
print(tokenized_text)
print("============================================= ")

print("Actual    Label         : {}".format(target_classes[Y_test[idx]]))
print("Predicted Label         : {}".format(target_classes[probs.argmax().item()]))
print("Categories              : {}".format(target_classes))
print("Predicted Probabilities : {}".format(probs.detach().numpy()))
============= Actual Text ===================
Pakistani leader arrives for talks Pakistani President General Pervez Musharraf has arrived in Britain for a visit which will include talks with Prime Minister Tony Blair.
=============================================
============= Tokenized Text ================
['pakistani', 'leader', 'arrives', 'for', 'talks', 'pakistani', 'president', 'general', 'pervez', 'musharraf', 'has', 'arrived', 'in', 'britain', 'for', 'a', 'visit', 'which', 'will', 'include', 'talks', 'with', 'prime', 'minister', 'tony', 'blair', '.']
=============================================
Actual    Label         : World
Predicted Label         : World
Categories              : ['World', 'Sports', 'Business', 'Sci/Tec']
Predicted Probabilities : [[9.9999821e-01 1.7125021e-09 1.3057952e-06 4.4319830e-07]]

In order to explain the prediction, we have first created an instance of the algorithm using the IntegratedGradients() constructor available from captum. We have provided the constructor with a function that takes a batch of vectorized text examples as input and returns their predictions (softmax probabilities). After initializing the algorithm, we have called its attribute() method, providing the vectorized text example and the target label. It returns the contribution of each feature.

As we have used the bag of words (word frequency) approach to vectorize our text examples, each vectorized text example has the same length as the vocabulary. It holds the word frequency at the index location of each word from the text example and 0 for all other words not present in the example. Due to this, the feature contributions returned by the attribute() method also have the same length as the vocabulary. We have then separately retrieved the contributions of only those words that are present in the text example. We can notice the difference between the two shapes getting printed.

In order to visualize these feature contributions, we need to call the visualize_text() method (available from the visualization module of captum) with a list of VisualizationDataRecord objects. Each VisualizationDataRecord object represents the explanation of one example. We have first created a visualization record with the feature contributions, prediction probability, predicted and actual classes, total attribution score, tokenized text, and convergence delta. Then, we have called the visualize_text() method with the record to visualize the explanation.

We can notice from the visualization that words like 'pakistani', 'leader', 'minister', 'tony', 'blair', 'britain', etc. contribute positively to predicting the target label 'World'.

from captum.attr import IntegratedGradients
from captum.attr import visualization

def predict(X_batch):
    preds = classifier(X_batch)
    return F.softmax(preds, dim=-1)

ig = IntegratedGradients(predict)
attributions, delta = ig.attribute(X_text_vec, target=Y_test[idx:idx+1], return_convergence_delta=True)
print("Delta                           : {}".format(delta.item()))

attributions = attributions.flatten() ## These are the attributions for all vocabulary tokens.
print("Attributions Actual Shape      : {}".format(attributions.shape[0]))
attributions_of_text = attributions[vocab(tokenized_text)] ## Here, we are retrieving attributions of tokens of text only. All other vocab tokens are removed.
print("Attributions Reorganized Shape : {}".format(attributions_of_text.shape[0]))

viz_record = visualization.VisualizationDataRecord(attributions_of_text,
                                                   probs.max().item(),
                                                   target_classes[probs.argmax().item()],
                                                   target_classes[Y_test[idx]],
                                                   target_classes[probs.argmax().item()],
                                                   attributions.sum(),
                                                   tokenized_text,
                                                   delta)
Delta                           : -0.0028483162243068127
Attributions Actual Shape      : 98635
Attributions Reorganized Shape : 27
visualization.visualize_text([viz_record]);

[Figure: Word-importance visualization generated by visualize_text() for the selected example]
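
Besides the colored visualization, we can also rank tokens by their attribution scores directly. The sketch below (reusing tokenized_text and attributions_of_text from the cells above) prints the words with the highest and lowest contributions to the predicted label.

## Pair each token of the example with its attribution score and sort by score.
token_scores = sorted(zip(tokenized_text, attributions_of_text.detach().numpy()),
                      key=lambda pair: pair[1], reverse=True)

print("Most positive contributors : {}".format(token_scores[:5]))
print("Most negative contributors : {}".format(token_scores[-5:]))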

Below, we have retrieved 5 other text examples from the test dataset and made predictions on them using our trained model. We can notice that our model correctly predicts the target label for all of them. We have also printed the predicted probability for each example. Next, we'll create a visualization explaining all of them using the integrated gradients algorithm.

X_text_vec = torch.tensor(vectorizer.transform(X_test_text[100:105]).todense(), dtype=torch.float32)
probs = F.softmax(classifier(X_text_vec), dim=-1)

actual_labels = [target_classes[Y_test[idx]] for idx in range(100,105)]
predicted_labels = [target_classes[idx] for idx in probs.argmax(dim=-1).numpy()]
predicted_probs = probs.max(dim=-1).values.detach().numpy()

print("Actual    Labels         : {}".format(actual_labels))
print("Predicted Labels         : {}".format(predicted_labels))
print("Predicted Probabilities  : {}".format(predicted_probs))
Actual    Labels         : ['Sports', 'Business', 'Sci/Tec', 'Business', 'World']
Predicted Labels         : ['Sports', 'Business', 'Sci/Tec', 'Business', 'World']
Predicted Probabilities  : [0.9999784  0.99999857 0.99820256 0.8728054  0.9999995 ]

Below, we have generated feature contributions for all five text examples and then created visualization data records for them. At last, we have called the visualize_text() method with these records to create a visualization showing the contribution of words per text example towards the predicted target label. We can notice from the visualization how words highlighted in green contribute positively towards the prediction and words highlighted in red contribute negatively.

from captum.attr import IntegratedGradients
from captum.attr import visualization

ig = IntegratedGradients(predict)
attributions, delta = ig.attribute(X_text_vec, target=Y_test[100:105], return_convergence_delta=True)

viz_records = []
for i in range(5):
    tokenized_text = tokenizer(X_test_text[100+i]) ## Tokenize Text
    token_indexes = vocab(tokenized_text) ## Retrieve token indexes
    attributions_of_text = attributions[i][token_indexes] ## Retrieve attributions for tokens of text.
    viz_record = visualization.VisualizationDataRecord(attributions_of_text,
                                                       predicted_probs[i],
                                                       predicted_labels[i],
                                                       actual_labels[i],
                                                       predicted_labels[i],
                                                       attributions[i].sum(),
                                                       tokenized_text,
                                                       delta[i])
    viz_records.append(viz_record)
visualization.visualize_text(viz_records);

[Figure: Word-importance visualizations for the five selected test examples]

Neuron Integrated Gradients

In this section, we have explained the usage of another algorithm named neuron integrated gradients. This algorithm lets us understand the contribution of features toward a particular neuron of a selected layer. It belongs to the Neuron Attribution category of algorithms (please check this link to know about the algorithm categories in captum).

We have created an instance of the algorithm using the NeuronIntegratedGradients() constructor. Apart from the prediction function, this time we have also provided a reference to the second linear layer. When calling attribute() on the algorithm, we have provided the neuron_selector parameter with a value of 0. This generates feature contributions towards activating the 0th neuron of the second linear layer. After generating the contributions, we have visualized them by creating visualization data records, just like before.

layers = list(list(classifier.children())[0].children())

layers
[Linear(in_features=98635, out_features=128, bias=True),
 ReLU(),
 Linear(in_features=128, out_features=64, bias=True),
 ReLU(),
 Linear(in_features=64, out_features=4, bias=True)]
from captum.attr import NeuronIntegratedGradients
from captum.attr import visualization

ig = NeuronIntegratedGradients(predict, layers[2])
attributions = ig.attribute(X_text_vec, neuron_selector=0)

viz_records = []
for i in range(5):
    tokenized_text = tokenizer(X_test_text[100+i]) ## Tokenize Text
    token_indexes = vocab(tokenized_text) ## Retrieve token indexes
    attributions_of_text = attributions[i][token_indexes] ## Retrieve attributions for tokens of text.

    viz_record = visualization.VisualizationDataRecord(attributions_of_text,
                                                       predicted_probs[i],
                                                       predicted_labels[i],
                                                       actual_labels[i],
                                                       predicted_labels[i],
                                                       attributions[i].sum(),
                                                       tokenized_text,
                                                       0)
    viz_records.append(viz_record)
visualization.visualize_text(viz_records);

[Figure: Word-importance visualizations of contributions toward neuron 0 of the second linear layer]
