Interpreting the predictions of a deep neural network is a good way to check whether the model has generalized. It helps us understand which parts of the data the network uses to make predictions. For example, if we have trained a classifier to separate cats from dogs, we want it to recognize a cat from the pixels of the cat and not from background pixels. Nowadays, it's easy to reach high accuracy by training deep neural networks with many layers. Though deep networks often give better results than traditional, more interpretable ML models (decision trees, random forest, gradient boosting, etc.), they are harder to interpret and are generally referred to as black-box models. To interpret the predictions of such models, an algorithm named LIME (Local Interpretable Model-Agnostic Explanations) was developed. LIME generates perturbed ("fake") samples from an input, records the network's predictions on them, and trains a simple local ML model that mimics the behavior of the network around that input. It then uses this local model to interpret the prediction. If you are interested in learning how LIME works internally, please check the below tutorial where we have explained the steps it performs to interpret predictions.
As a part of this tutorial, we'll explain how to use the LIME implementation available through the eli5 library to explain the predictions made by a text classification network designed using Flax. Flax is a high-level deep learning library built on top of JAX, another library used for deep learning research. We have used the newsgroups text dataset available from scikit-learn and trained a simple Flax network on it to classify text documents. Then, we have used the LIME implementation available from the lime sub-module of eli5 to explain predictions made by the model, trying various parameter combinations of the explainer.
We have assumed that the reader has a little background on the eli5 library. Please feel free to check the below link if you are looking for some background, though this tutorial can be followed even without a deep understanding of eli5.
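The idea of perturbing an input and fitting a local model can be sketched in a few lines of NumPy. This is only an illustrative sketch and not the eli5 implementation: toy_black_box, lime_sketch, the masking scheme, and the similarity weighting are all assumptions made for this example, and a weighted least-squares fit stands in for the local ML model.

```python
import numpy as np

def toy_black_box(docs):
    """Hypothetical stand-in for a trained model: P(class) is high when 'hockey' is present."""
    return np.array([0.9 if "hockey" in d.split() else 0.1 for d in docs])

def lime_sketch(text, predict_fn, n_samples=500, seed=0):
    rng = np.random.default_rng(seed)
    tokens = text.split()
    # Binary masks: 1 keeps a token, 0 removes it. Row 0 is the unperturbed text.
    masks = rng.integers(0, 2, size=(n_samples, len(tokens)))
    masks[0] = 1
    docs = [" ".join(t for t, keep in zip(tokens, m) if keep) for m in masks]
    y = predict_fn(docs)                      # black-box predictions on fake samples
    # Assumed similarity weight: fraction of the original tokens that were kept.
    sample_w = masks.mean(axis=1)
    # Weighted least squares with an intercept column plays the role of the local model.
    X = np.hstack([masks, np.ones((n_samples, 1))])
    sw = np.sqrt(sample_w)[:, None]
    coef, *_ = np.linalg.lstsq(X * sw, y * sw[:, 0], rcond=None)
    return dict(zip(tokens, coef[:-1]))       # one weight per token, intercept dropped

weights = lime_sketch("the hockey game went to overtime", toy_black_box)
top_word = max(weights, key=weights.get)
print(top_word, {w: round(float(v), 3) for w, v in weights.items()})
```

Because the toy black box depends only on the word 'hockey', the local linear model puts essentially all of its weight on that word. A real explainer works the same way in spirit, but with a trained network as the black box.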
Below, we have listed important sections of the tutorial to give an overview of the material covered.
Below, we have imported the necessary libraries and printed the versions that we have used in our tutorial.
import eli5
print("Eli5 Version : {}".format(eli5.__version__))
import jax
print("JAX Version : {}".format(jax.__version__))
import flax
print("FLAX Version : {}".format(flax.__version__))
import optax
print("OPTAX Version : {}".format(optax.__version__))
In this section, we have loaded the newsgroups dataset available from scikit-learn. The dataset has ~18k text documents for 20 different categories of news. For our tutorial, we have selected only 3 categories to keep things simple.
import numpy as np
from sklearn import datasets
import gc
all_categories = ['alt.atheism','comp.graphics','comp.os.ms-windows.misc','comp.sys.ibm.pc.hardware',
'comp.sys.mac.hardware','comp.windows.x', 'misc.forsale','rec.autos','rec.motorcycles',
'rec.sport.baseball','rec.sport.hockey','sci.crypt','sci.electronics','sci.med',
'sci.space','soc.religion.christian','talk.politics.guns','talk.politics.mideast',
'talk.politics.misc','talk.religion.misc']
selected_categories = ['comp.sys.ibm.pc.hardware','rec.sport.hockey','talk.politics.mideast']
X_train_text, Y_train = datasets.fetch_20newsgroups(subset="train", categories=selected_categories, return_X_y=True)
X_test_text , Y_test = datasets.fetch_20newsgroups(subset="test", categories=selected_categories, return_X_y=True)
X_train_text = np.array(X_train_text)
X_test_text = np.array(X_test_text)
classes = np.unique(Y_train)
mapping = dict(zip(classes, selected_categories))
len(X_train_text), len(X_test_text), classes, mapping
Below, we have vectorized our text data. We need to vectorize text data before giving it to neural networks because networks work on floats and not on string data.
Here, we have vectorized text data using the Tf-Idf (Term Frequency - Inverse Document Frequency) approach. Tf-Idf assigns a float value to each word of a document in such a way that words appearing across many documents get low values and words appearing in only a few documents get high values. This helps with our text classification task because we want distinctive words to separate documents. In order to assign a Tf-Idf value to each word, the text first needs to be tokenized, i.e. split into the individual words of each sample.
In order to vectorize text data using the Tf-Idf approach, we have used the TfidfVectorizer available from scikit-learn. It tokenizes each sample and then assigns a Tf-Idf value to each of its words.
Here, we have not covered how Tf-Idf works in detail. Please feel free to check the below link if you are interested in learning about it in-depth.
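As a small worked example of the weighting described above, the sketch below computes Tf-Idf values by hand for three made-up documents. The documents and the helper tfidf() are assumptions for illustration; the smoothed IDF formula and L2 normalization follow what scikit-learn documents for its default settings, but this is not the library's code.

```python
import math
from collections import Counter

# Three tiny made-up documents (illustrative only)
docs = ["the puck hit the net", "the disk drive failed", "the peace talks failed"]
tokenized = [d.split() for d in docs]
n_docs = len(docs)

# Document frequency: in how many documents does each word appear?
df = Counter(tok for doc in tokenized for tok in set(doc))

def tfidf(doc_tokens):
    tf = Counter(doc_tokens)
    # Smoothed IDF: ln((1 + n) / (1 + df)) + 1
    raw = {t: tf[t] * (math.log((1 + n_docs) / (1 + df[t])) + 1) for t in tf}
    # L2-normalize so every document vector has unit length
    norm = math.sqrt(sum(v * v for v in raw.values()))
    return {t: v / norm for t, v in raw.items()}

scores = tfidf(tokenized[1])  # "the disk drive failed"
for word, value in sorted(scores.items(), key=lambda kv: kv[1]):
    print("{:8s} {:.3f}".format(word, value))
```

The word 'the' appears in every document and gets the lowest value, 'failed' appears in two documents and sits in the middle, and 'disk' and 'drive' appear in only one document and get the highest values.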
import sklearn
from jax import numpy as jnp
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
vectorizer = TfidfVectorizer()
vectorizer.fit(np.concatenate((X_train_text, X_test_text)))
X_train = vectorizer.transform(X_train_text)
X_test = vectorizer.transform(X_test_text)
X_train = jnp.array(X_train.toarray(), dtype=jnp.float16)
X_test = jnp.array(X_test.toarray(), dtype=jnp.float16)
X_train.shape, X_test.shape
In this section, we have defined a simple neural network of 3 dense layers using Flax library. The three layers have 128, 64, and 3 output units respectively. We have applied relu activation after the first two layers.
Here, we have not covered in detail various Flax modules and how they are used to design networks. Please feel free to check the below link where we have explained how to design a network using Flax.
After defining the network, we have also initialized it and printed the shapes of the weights and biases of each layer. We have also made a sample forward pass through the network using a few train text samples for verification purposes.
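To make the shapes concrete, here is a plain NumPy sketch of the same three-layer structure (128, 64, and 3 units with relu after the first two layers). It is a stand-in for illustration and not Flax code: the vocabulary size of 1000, the random initialization, and the forward() helper are assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, n_classes = 1000, 3              # assumed sizes for illustration
layer_sizes = [vocab_size, 128, 64, n_classes]

# One (kernel, bias) pair per dense layer
params = [(rng.normal(scale=0.01, size=(n_in, n_out)), np.zeros(n_out))
          for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])]

def forward(x, params):
    # Dense + relu for all layers except the last one
    for kernel, bias in params[:-1]:
        x = np.maximum(x @ kernel + bias, 0.0)
    kernel, bias = params[-1]
    return x @ kernel + bias                 # raw logits, no softmax

batch = rng.random((5, vocab_size))          # 5 fake Tf-Idf rows
logits = forward(batch, params)
kernel_shapes = [k.shape for k, _ in params]
print(kernel_shapes, logits.shape)
```

Each kernel has shape (inputs, outputs), so a batch of 5 samples yields a (5, 3) array of logits, one score per class.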
from flax import linen
from jax import random
class TextClassifier(linen.Module):
    def setup(self):
        self.linear1 = linen.Dense(features=128, name="DENSE1")
        self.linear2 = linen.Dense(features=64, name="DENSE2")
        self.linear3 = linen.Dense(len(classes), name="DENSE3")

    def __call__(self, inputs):
        x = linen.relu(self.linear1(inputs))
        x = linen.relu(self.linear2(x))
        logits = self.linear3(x)
        return logits #linen.softmax(x)
seed = jax.random.PRNGKey(0)
model = TextClassifier()
params = model.init(seed, X_train[:5])
for layer_params in params["params"].items():
    print("Layer Name : {}".format(layer_params[0]))
    weights, biases = layer_params[1]["kernel"], layer_params[1]["bias"]
    print("\tLayer Weights : {}, Biases : {}".format(weights.shape, biases.shape))
preds = model.apply(params, X_train[:5])
preds.shape
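In this section, we have defined a function that we'll use as a loss function for our text classification task. We have designed a cross entropy loss using the softmax_cross_entropy() function available from the Optax library. The function converts the integer labels to one-hot vectors and sums the per-sample losses of the batch.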
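For readers who want to see the arithmetic, the following NumPy sketch computes a summed softmax cross entropy by hand on two made-up samples. The helper name softmax_cross_entropy_sum and the example logits are assumptions for illustration; it is not the Optax implementation.

```python
import math
import numpy as np

def softmax_cross_entropy_sum(logits, labels, num_classes):
    one_hot = np.eye(num_classes)[labels]            # integer labels -> one-hot rows
    # Stable log-softmax: subtract the row-wise maximum before exponentiating
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    per_sample = -(one_hot * log_probs).sum(axis=1)  # loss of each sample
    return per_sample.sum()                          # summed over the batch

logits = np.array([[2.0, 0.0, 0.0],                  # fairly confident in class 0
                   [0.0, 0.0, 0.0]])                 # completely undecided
labels = np.array([0, 2])
loss = softmax_cross_entropy_sum(logits, labels, num_classes=3)

# Closed form for these two samples: log(1 + 2e^-2) + log(3)
expected = math.log(1 + 2 * math.exp(-2)) + math.log(3)
print(round(float(loss), 4), round(expected, 4))
```

Note that a summed loss grows with the batch size, so loss values printed during training are only comparable between runs that use the same batch size.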
def CrossEntropyLoss(weights, input_data, actual):
    logits_preds = model.apply(weights, input_data)
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(classes))
    return optax.softmax_cross_entropy(logits=logits_preds, labels=one_hot_actual).sum()
Now, we train our network. To do so, we have defined a small function that performs the training process. The function takes training data (X, Y), validation data (X_val, Y_val), number of epochs, network parameters, optimizer state, and batch size as input; the optimizer itself is read from the global scope. It then runs the training loop for the given number of epochs. For each epoch, it loops through the whole training data in batches. For each batch, it performs a forward pass to make predictions, calculates the loss, calculates gradients, and updates network parameters. It also records the loss of each batch and prints the average of the batch losses at the end of the epoch. The function also makes predictions on validation data and prints validation accuracy at the end of each epoch. At last, it returns the updated network parameters, which we can use to make predictions.
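The batching scheme used by the training function can be looked at in isolation. The sketch below reproduces only the index arithmetic with a hypothetical helper batch_slices(), and adds a ceil-based variant for comparison; both helpers are assumptions written for this example.

```python
import math

def batch_slices(n_samples, batch_size):
    """Slices produced by 'n // batch_size + 1' batches, the last one open-ended."""
    n_batches = n_samples // batch_size + 1
    slices = []
    for b in range(n_batches):
        start = b * batch_size
        end = start + batch_size if b != n_batches - 1 else None
        slices.append((start, end))
    return slices

def ceil_batch_slices(n_samples, batch_size):
    """Alternative: exactly ceil(n / batch_size) batches, never an empty one."""
    n_batches = math.ceil(n_samples / batch_size)
    return [(b * batch_size, min((b + 1) * batch_size, n_samples)) for b in range(n_batches)]

data = list(range(8))
first = batch_slices(10, 4)            # 10 samples: last batch holds the 2 leftovers
second = batch_slices(8, 4)            # 8 samples: last slice starts at the end
last_batch = data[second[-1][0]:second[-1][1]]
print(first, second, last_batch, ceil_batch_slices(8, 4))
```

When the number of samples is an exact multiple of the batch size, the first scheme yields one extra, empty batch at the end. With a summed loss this contributes a loss of zero, which slightly lowers the printed average; the ceil-based variant avoids that case.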
from jax import value_and_grad
from tqdm import tqdm
from sklearn.metrics import accuracy_score
def TrainModelInBatches(X, Y, X_val, Y_val, epochs, weights, optimizer_state, batch_size=32):
    for i in range(1, epochs+1):
        batches = jnp.arange((X.shape[0]//batch_size)+1) ### Batch Indices

        losses = [] ## Record loss of each batch
        for batch in tqdm(batches):
            if batch != batches[-1]:
                start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
            else:
                start, end = int(batch*batch_size), None

            X_batch, Y_batch = X[start:end], Y[start:end] ## Single batch of data

            loss, gradients = value_and_grad(CrossEntropyLoss)(weights, X_batch,Y_batch)

            ## Update Weights
            updates, optimizer_state = optimizer.update(gradients, optimizer_state)
            weights = optax.apply_updates(weights, updates)

            losses.append(loss) ## Record Loss

        print("CrossEntropyLoss : {:.3f}".format(jnp.array(losses).mean()))

        Y_val_preds = model.apply(weights, X_val)
        val_acc = accuracy_score(Y_val, jnp.argmax(Y_val_preds, axis=1))
        print("Validation Accuracy : {:.3f}".format(val_acc))

    return weights
Below, we have trained our network by calling the function designed in the previous cell. First, we have set the batch size to 256, the number of epochs to 8, and the learning rate to 0.001. Then, we have initialized the network, its parameters, and the Adam optimizer. At last, we have called our training routine with the necessary parameters, passing the test set as validation data. We can notice from the loss and validation accuracy printed after each epoch that our model is doing a good job.
seed = random.PRNGKey(0)
batch_size=256
epochs=8
learning_rate = jnp.array(1/1e3)
model = TextClassifier()
weights = model.init(seed, X_train[:5])
optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(weights)
final_weights = TrainModelInBatches(X_train, Y_train, X_test, Y_test, epochs, weights, optimizer_state, batch_size=batch_size)
In this section, we have evaluated the performance of our network by calculating accuracy, a classification report, and a confusion matrix on test predictions. We can notice from the results that our model seems to be doing a good job at the text classification task. It still confuses some samples, but overall performance is acceptable.
We have calculated these metrics using various scikit-learn functions. Please feel free to check the below link if you are interested in learning about the ML metrics available through sklearn. It covers the majority of them in-depth.
Then, in the cell after the next one, we have also plotted the confusion matrix using scikit-plot. The scikit-plot library provides plotting functionality for many ML metrics. Please feel free to check the below link if you want to learn it.
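As a reminder of how a confusion matrix is read, the sketch below builds one by hand for six made-up predictions. The labels and the confusion() helper are assumptions for illustration; rows are true classes and columns are predicted classes, the same layout scikit-learn uses.

```python
import numpy as np

def confusion(y_true, y_pred, n_classes):
    cm = np.zeros((n_classes, n_classes), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1              # row = true class, column = predicted class
    return cm

y_true = [0, 0, 1, 1, 2, 2]        # made-up labels
y_pred = [0, 1, 1, 1, 2, 0]        # made-up predictions
cm = confusion(y_true, y_pred, n_classes=3)
accuracy = np.trace(cm) / cm.sum() # correct predictions lie on the diagonal
print(cm)
print("Accuracy : {:.3f}".format(accuracy))
```

Off-diagonal entries show which classes get mixed up: here one sample of class 0 was predicted as class 1, and one sample of class 2 was predicted as class 0.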
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
train_preds = model.apply(final_weights, X_train)
test_preds = model.apply(final_weights, X_test)
print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, np.argmax(train_preds, axis=1))))
print("Test Accuracy : {:.3f}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=selected_categories))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_test, np.argmax(test_preds, axis=1)))
from sklearn.metrics import confusion_matrix
import scikitplot as skplt
import matplotlib.pyplot as plt
skplt.metrics.plot_confusion_matrix([selected_categories[i] for i in Y_test], [selected_categories[i] for i in np.argmax(test_preds, axis=1)],
normalize=True,
title="Confusion Matrix",
cmap="Purples",
hide_zeros=True,
figsize=(5,5)
);
plt.xticks(rotation=90);
Now, we'll explain predictions made by our network using LIME implementation available through eli5.
In order to explain predictions using LIME, we need to follow the below steps.
In this section, we have explained the first text sample from test data using eli5. In order to do that we have first created an instance of TextExplainer below.
Please make a NOTE that the TextExplainer() constructor has a few parameters that modify the default behavior of the explainer. We have covered the majority of them with examples below. Please check this link if you are looking for documentation of the parameters.
from eli5.lime import TextExplainer
from eli5.lime.samplers import MaskingTextSampler, MaskingTextSamplers
explainer = TextExplainer(random_state=123, sampler=MaskingTextSampler())
Below, we have first defined a simple function that takes a list of text documents as input and returns predicted probabilities for them. The function first vectorizes the data and then makes predictions using the trained network. It then applies softmax activation to the output of the network (logits) to convert them to probabilities and returns them as a NumPy array.
Then, we have selected the first sample from our test samples and made predictions on it using our trained network. We have printed original and predicted labels as well as the probability of the model.
At last, we have called fit() method on TextExplainer instance by giving the first text sample and predictions function. This step will internally execute LIME algorithm and generate an explanation for the sample.
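The explainer only needs a function that maps a list of strings to an array of class probabilities. The sketch below shows that contract with a made-up keyword scorer in place of the trained network; toy_predict_proba, its keyword list, and the sample texts are assumptions for illustration.

```python
import numpy as np

def softmax(logits):
    # Subtract the row-wise maximum for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=1, keepdims=True)

def toy_predict_proba(docs):
    """Hypothetical scorer: one keyword per class, keyword counts act as logits."""
    keywords = ["disk", "puck", "peace"]
    logits = np.array([[float(d.lower().split().count(k)) for k in keywords]
                       for d in docs])
    return softmax(logits)         # shape (n_docs, n_classes), rows sum to 1

probs = toy_predict_proba(["The puck hit the post",
                           "peace talks and more peace",
                           ""])    # perturbed texts may be nearly or fully empty
print(probs.round(3))
```

Any function with this input and output shape can be passed to the explainer. Since LIME calls it on many perturbed versions of a text, it should also cope with documents that have most of their words removed.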
import numpy as np
def predict_proba(X_docs):
    X_vect = vectorizer.transform(X_docs)
    X_vect = jnp.array(X_vect.toarray(), dtype=jnp.float32)
    logits = model.apply(final_weights, X_vect)
    preds_probs = jax.nn.softmax(logits)
    #preds_probs = np.array([probs / probs.sum() for probs in preds_probs])
    return np.asarray(preds_probs) ## Convert JAX array to NumPy (to_py() is not available in newer JAX versions)
print("Actual Label : {}".format(selected_categories[Y_test[0]]))
pred_probs = predict_proba(X_test_text[:1])
print("Predicted Label : {}".format(selected_categories[pred_probs.argmax(axis=1)[0]]))
print("Probability : {:.3f}".format(pred_probs.max(axis=1)[0]))
explainer.fit(X_test_text[0], predict_proba)
In the below cell, we have called the show_prediction() method on the TextExplainer instance to create a visualization showing which words contributed to the prediction.
We can notice from the visualization that words like 'politics', 'society', 'truth', 'organization', etc. are contributing to predicting the category talk.politics.mideast. This makes sense, as these are words we would expect in that category. We would have expected the word 'homosexual' to contribute to 'talk.politics.mideast' as well, but it seems to contribute to 'rec.sport.hockey' instead, which is a little odd.
Please make a NOTE that by default, show_prediction() shows output for all categories. We can restrict the categories by setting the targets parameter, which accepts a list of categories that we want to visualize. In the cell after the next one, we have demonstrated it by creating a visualization for only the predicted category.
explainer.show_prediction(target_names=selected_categories)
pred_cat = selected_categories[pred_probs.argmax(axis=1)[0]]
explainer.show_prediction(target_names=selected_categories, targets=[pred_cat])
In this section, we have explained one wrong prediction. In order to do that, we have first made predictions on all test samples and found the indexes of the samples that were predicted wrong. Then, we have selected one sample from those wrong predictions.
We have followed the same process as earlier to create a visualization explaining the wrong prediction. The original category of the selected sample is talk.politics.mideast but our model predicted comp.sys.ibm.pc.hardware. We can notice that the model has predicted comp.sys.ibm.pc.hardware with a probability of 0.63, so it is not very sure about the category. The second highest probability belongs to the talk.politics.mideast category.
X_test_vect = vectorizer.transform(X_test_text).todense()
Y_logits = model.apply(final_weights, jnp.array(X_test_vect, dtype=jnp.float32))
Y_probs = linen.softmax(Y_logits)
Y_preds = np.asarray(Y_probs.argmax(axis=-1)) ## to_py() is not available in newer JAX versions
del X_test_vect
gc.collect()
wrong_preds = np.argwhere(Y_preds != Y_test)
idx = wrong_preds.flatten()[0]
print("Actual Label : {}".format(selected_categories[Y_test[idx]]))
pred_probs = predict_proba(X_test_text[idx:idx+1])
print("Predicted Label : {}".format(selected_categories[pred_probs.argmax(axis=1)[0]]))
print("Probability : {:.3f}".format(pred_probs.max(axis=1)[0]))
explainer.fit(X_test_text[idx], predict_proba)
explainer.show_prediction(target_names=selected_categories)
By default, TextExplainer internally trains an SGDClassifier available from scikit-learn to explain predictions. We can replace this local ML model with some other model as well.
In this example, we have used the LogisticRegression classifier from scikit-learn instead to check how it helps with creating explanations. The replacement should be a classifier that can predict probabilities (predict_proba()) and that eli5 is able to explain.
We can notice from the visualization that the results are almost the same as in our earlier section. Words like 'politics', 'society', 'yard', 'california', 'genuine', etc. are contributing to predicting the category talk.politics.mideast.
from eli5.lime import TextExplainer
from eli5.lime.samplers import MaskingTextSampler, MaskingTextSamplers
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
explainer = TextExplainer(random_state=123, sampler=MaskingTextSampler(), clf=LogisticRegression(max_iter=10000))
print("Actual Label : {}".format(selected_categories[Y_test[0]]))
pred_probs = predict_proba(X_test_text[:1])
print("Predicted Label : {}".format(selected_categories[pred_probs.argmax(axis=1)[0]]))
print("Probability : {:.3f}".format(pred_probs.max(axis=1)[0]))
explainer.fit(X_test_text[0], predict_proba)
explainer.show_prediction(target_names=selected_categories)
We can also call the show_weights() method on TextExplainer if we want a tabular version of the visualization showing the contributions of words.
explainer.show_weights(target_names=selected_categories, targets=selected_categories)
The TextExplainer lets us provide our own vectorizer if we are not happy with the performance of the vectorizer used by it. We can provide our own vectorizer using vec parameter.
Below, we have given our own vectorizer to vec parameter of TextExplainer and then generated visualization explaining the first test sample.
We can notice from the visualization that the results do not seem to be that good.
from eli5.lime import TextExplainer
from eli5.lime.samplers import MaskingTextSampler, MaskingTextSamplers
explainer = TextExplainer(random_state=123, sampler=MaskingTextSampler(), vec=vectorizer)
print("Actual Label : {}".format(selected_categories[Y_test[0]]))
pred_probs = predict_proba(X_test_text[:1])
print("Predicted Label : {}".format(selected_categories[pred_probs.argmax(axis=1)[0]]))
print("Probability : {:.3f}".format(pred_probs.max(axis=1)[0]))
explainer.fit(X_test_text[0], predict_proba)
explainer.show_prediction(target_names=selected_categories, top=(8,8))
The TextExplainer lets us provide our own tokenizing pattern if we do not like the internal pattern used by it. We can provide our own pattern by giving a regular expression pattern to token_pattern parameter.
Below, we have shown how to provide our own tokenization pattern when creating an explainer object. We have then generated a visualization using that explainer. The results seem a little better with our own token pattern.
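To get a feel for what a token pattern does, we can apply it to a sample string with Python's re module. The sample text and the second pattern below are assumptions chosen for illustration; only the r"\w+" pattern comes from this tutorial.

```python
import re

text = "Re: IDE vs. SCSI on a 486DX-33, isn't it?"

# Pattern used in this tutorial: every run of letters, digits, or underscores
word_tokens = re.findall(r"\w+", text)

# A stricter, hypothetical alternative: alphabetic runs of at least 2 letters
alpha_tokens = re.findall(r"[A-Za-z]{2,}", text)

print(word_tokens)
print(alpha_tokens)
```

The first pattern keeps numbers and single characters such as 'a', '33', and the 't' split off from "isn't", while the second one drops them. The tokens produced by the pattern are the units that get masked when fake samples are generated, so the choice affects which words can show up in an explanation.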
from eli5.lime import TextExplainer
from eli5.lime.samplers import MaskingTextSampler, MaskingTextSamplers
explainer = TextExplainer(random_state=123, sampler=MaskingTextSampler(), token_pattern=r"\w+")
explainer.fit(X_test_text[0], predict_proba)
explainer.show_prediction(target_names=selected_categories)
In this section, we have generated visualization using TextExplainer object created with default arguments.
from eli5.lime import TextExplainer
from eli5.lime.samplers import MaskingTextSampler, MaskingTextSamplers
explainer = TextExplainer(random_state=123)
explainer.fit(X_test_text[0], predict_proba)
explainer.show_prediction(target_names=selected_categories)
The TextExplainer lets us provide multiple samplers that will be used to generate samples for training the local ML model of the LIME algorithm. We can do that by creating an instance of MaskingTextSamplers and providing it to the sampler parameter of the TextExplainer constructor. A MaskingTextSamplers instance combines a list of MaskingTextSampler instances. We can create it from a list of dictionaries where each dictionary specifies the parameters of one MaskingTextSampler instance.
Below, we have explained with example how we can create TextExplainer instance using multiple samplers and then use it to interpret predictions of the network.
Please make a NOTE that 'MaskingTextSamplers' and 'MaskingTextSampler' are different.
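The difference between the two bow settings can be illustrated with a rough pure-Python sketch. It assumes that bag-of-words masking treats all occurrences of a word together while the other mode treats each position independently; mask_tokens() is a hypothetical helper and not the eli5 sampler.

```python
import random

def mask_tokens(tokens, bow, rng, p=0.5):
    """Randomly drop tokens; a rough illustration, not eli5's implementation."""
    if bow:
        # Bag-of-words style: decide once per distinct word, so all of its
        # occurrences are kept or dropped together.
        dropped = {w for w in sorted(set(tokens)) if rng.random() < p}
        return [t for t in tokens if t not in dropped]
    # Position-wise style: decide independently for every occurrence.
    return [t for t in tokens if rng.random() >= p]

tokens = "the puck hit the post".split()
bow_samples = [mask_tokens(tokens, True, random.Random(s)) for s in range(200)]
pos_samples = [mask_tokens(tokens, False, random.Random(s)) for s in range(200)]

print(" ".join(bow_samples[0]), "|", " ".join(pos_samples[0]))
```

In the bag-of-words samples the word 'the' is always present either twice or not at all, while the position-wise samples can keep just one of the two occurrences. Mixing both kinds of samplers gives the local model a more varied set of fake samples to learn from.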
from eli5.lime import TextExplainer
from eli5.lime.samplers import MaskingTextSampler, MaskingTextSamplers
explainer = TextExplainer(random_state=123, sampler=MaskingTextSamplers([{"bow":True, "token_pattern": r"\w+"}, {"bow": False, "token_pattern": r"\w+"}]))
explainer.fit(X_test_text[0], predict_proba)
explainer.show_prediction(target_names=selected_categories)
We can ask the TextExplainer instance to take the position of words into consideration by setting the position_dependent parameter to True. This can help in situations where the meaning of a word depends on the words around it.
Below, we have created a TextExplainer instance with the position_dependent parameter set to True. Then, we have used it to interpret the network's prediction. The result is almost the same as the previous results.
from eli5.lime import TextExplainer
from eli5.lime.samplers import MaskingTextSampler, MaskingTextSamplers
explainer = TextExplainer(random_state=123, sampler=MaskingTextSampler(), position_dependent=True)
explainer.fit(X_test_text[0], predict_proba)
explainer.show_prediction(target_names=selected_categories)
This ends our small tutorial explaining how we can use LIME implementation available through eli5.lime module to explain the predictions generated by text classification network created using Flax (JAX). Please feel free to let us know your views in the comments section.