Updated On: Mar-02, 2022 | Time Investment: ~30 mins

LIME: Interpret Predictions Of Keras Text Classification Networks

Interpreting the predictions of our ML model helps us better understand whether the model has generalized or not. It also shows which tasks the model handles well and where it gets confused, which helps us make informed changes that can further improve performance. We can easily generate feature importances for simple ML models like linear regression, decision trees, random forests, etc. But when it comes to deep neural networks, it becomes hard to generate feature importances that specify which features contribute to a prediction. LIME (Local Interpretable Model-Agnostic Explanations) is an algorithm that helps us solve this problem. It explains a prediction of our deep network by training a simple ML model (like a decision tree or linear regression) on perturbed data generated from the input sample; this simple model tries to mimic the predictions of our network around that sample. We have covered in detail how LIME works internally in a separate tutorial. Please feel free to check the below link for the steps of its internal workings.
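To make the idea a bit more concrete, below is a minimal, illustrative sketch of the core LIME procedure for text, not the lime library's actual implementation: perturb the sample by randomly dropping words, query the black-box model on the perturbations, weight each perturbation by its similarity to the original sample, and fit a weighted linear surrogate whose coefficients act as word importances. The function name lime_sketch, the exponential kernel, and the Ridge surrogate are our illustrative choices here, not the library's exact defaults.

import numpy as np
from sklearn.linear_model import Ridge

def lime_sketch(text, predict_proba, target_class, n_samples=1000, kernel_width=0.25):
    # predict_proba is any black-box function mapping a list of strings to an
    # array of class probabilities (e.g. a wrapper around model.predict).
    words = text.split()
    rng = np.random.RandomState(0)

    # Binary masks over word positions: 1 = keep the word, 0 = drop it.
    masks = rng.randint(0, 2, size=(n_samples, len(words)))
    masks[0, :] = 1  # keep the original sample as the first row

    perturbed = [" ".join(w for w, keep in zip(words, mask) if keep) for mask in masks]
    probs = predict_proba(perturbed)[:, target_class]

    # Weight perturbations by how close they are to the original sample.
    distances = 1.0 - masks.mean(axis=1)
    weights = np.exp(-(distances ** 2) / kernel_width ** 2)

    # Interpretable surrogate: coefficients approximate per-word contributions.
    # (The real library, by default, treats repeated words as a single feature.)
    surrogate = Ridge(alpha=1.0).fit(masks, probs, sample_weight=weights)
    return sorted(zip(words, surrogate.coef_), key=lambda pair: -abs(pair[1]))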

As a part of this tutorial, we have used LIME to explain the predictions made by our keras text classification neural network. We have used the newsgroups text dataset available from scikit-learn for our purpose. The tutorial trains models on data vectorized with different text vectorization approaches to see which one does a better job. We recommend that readers go through the below link if they do not have a background in text classification using keras networks.

Below, we have listed the important sections of the tutorial to give an overview of the material covered.

Important Sections Of Tutorial

  1. Network Using Keras Text Vectorization (Word Frequency)
    • Load Data
    • Vectorize Text Data
    • Create And Train Network
    • Evaluate Network Performance
    • Explain Network Predictions Using Lime
      • Explain Correct Predictions
      • Explain Incorrect Predictions
  2. Network Using Scikit-Learn Text Vectorizer (Word Frequency)
    • Vectorize Text Data
    • Create And Train Model
    • Explain Network Predictions
      • Explain Correct Predictions
      • Explain Incorrect Predictions
  3. Network Using Scikit-Learn Text Vectorizer (Word Frequency + Stop Words Removed)
  4. Network Using Scikit-learn Text Vectorizer (Tf-Idf + Stop Words Removed)

Below, we have imported the necessary libraries and printed the version of Keras that we have used in our tutorial.

import tensorflow
from tensorflow import keras

print("Keras Version : {}".format(keras.__version__))
Keras Version : 2.6.0

1. Network Using Keras Text Vectorization (Word Frequency)

In this section, we have vectorized our input text data using the word frequency approach and then trained a network on it. After training the network, we have evaluated its performance by calculating various ML metrics and explained predictions made by the network. We have used the text vectorization layer available from keras to vectorize data.

Load Data

In this section, we have loaded the 20 newsgroups dataset that we'll be using throughout our tutorial. The dataset is available from scikit-learn and has ~18k text documents spread across 20 different categories. As a part of our example, we have selected the 5 categories listed in the code below. We have loaded the train and test datasets using the fetch_20newsgroups() function available from scikit-learn.

import numpy as np
from sklearn import datasets
import gc

all_categories = ['alt.atheism','comp.graphics','comp.os.ms-windows.misc','comp.sys.ibm.pc.hardware',
                  'comp.sys.mac.hardware','comp.windows.x', 'misc.forsale','rec.autos','rec.motorcycles',
                  'rec.sport.baseball','rec.sport.hockey','sci.crypt','sci.electronics','sci.med',
                  'sci.space','soc.religion.christian','talk.politics.guns','talk.politics.mideast',
                  'talk.politics.misc','talk.religion.misc']

selected_categories = ['misc.forsale','rec.autos','rec.sport.baseball','sci.electronics','soc.religion.christian']

X_train, Y_train = datasets.fetch_20newsgroups(subset="train", categories=selected_categories, return_X_y=True)
X_test , Y_test  = datasets.fetch_20newsgroups(subset="test", categories=selected_categories, return_X_y=True)

X_train = np.array(X_train)
X_test = np.array(X_test)

classes = np.unique(Y_train)
mapping = dict(zip(classes, selected_categories))

len(X_train), len(X_test), classes, mapping
(2966,
 1974,
 array([0, 1, 2, 3, 4]),
 {0: 'misc.forsale',
  1: 'rec.autos',
  2: 'rec.sport.baseball',
  3: 'sci.electronics',
  4: 'soc.religion.christian'})

Vectorize Text Data

In this section, we have adapted (trained) the text vectorization layer available from keras on our dataset. Adapting the layer populates its vocabulary. We'll later use this layer as a part of our keras network.

text_vectorizer = keras.layers.TextVectorization(max_tokens=50000, standardize="lower_and_strip_punctuation",
                                                 split="whitespace", output_mode="count", pad_to_max_tokens=True)

text_vectorizer.adapt(np.concatenate((X_train, X_test)), batch_size=512)

vocab = text_vectorizer.get_vocabulary()
print("Vocab : {}".format(vocab[:10]))
print("Vocab Size : {}".format(text_vectorizer.vocabulary_size()))
out = text_vectorizer(X_train[:5])
print("Output Shape : {}".format(out.shape))
Vocab : ['[UNK]', 'the', 'to', 'of', 'a', 'and', 'in', 'i', 'is', 'that']
Vocab Size : 50000
Output Shape : (5, 50000)
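To illustrate what output_mode='count' produces, the small optional snippet below adapts a fresh TextVectorization layer on a single hypothetical toy sentence (not part of the tutorial dataset); each document then becomes a vector of per-token counts over the adapted vocabulary. The exact vocabulary ordering may vary a bit, hence the 'e.g.' in the comments.

toy_texts = np.array(["the cat sat on the mat"])

toy_vectorizer = keras.layers.TextVectorization(output_mode="count")
toy_vectorizer.adapt(toy_texts)

print(toy_vectorizer.get_vocabulary())   # e.g. ['[UNK]', 'the', 'sat', 'on', 'mat', 'cat'] - most frequent first
print(toy_vectorizer(toy_texts).numpy()) # one count vector per document; 'the' is counted twice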

Create And Train Network

In this section, we have created a keras network which we'll use to classify text documents. The network consists of the text vectorization layer which we trained in the previous section to populate its vocabulary, followed by 3 dense layers. The three dense layers have 128, 64, and 5 units respectively. The first two dense layers have relu activation and the last dense layer has softmax activation. When we perform a forward pass through the network, the text vectorization layer first vectorizes the text data and then passes the vectorized data to the next dense layer. After creating the network, we have also summarized it to show the network parameters.

Then, we have compiled the network to use Adam optimizer, cross entropy loss, and the accuracy metric.

At last, we have called fit() method on the model to train it with train data for 8 epochs with a batch size of 256. We have also provided validation data to check the accuracy of the model on it.

from tensorflow.keras.models import Sequential
from tensorflow.keras import layers

def create_model(text_vectorizer):
    return Sequential([
                        layers.Input(shape=(1,), dtype="string"),
                        text_vectorizer,
                        #layers.Dense(256, activation="relu"),
                        layers.Dense(128, activation="relu"),
                        layers.Dense(64, activation="relu"),
                        layers.Dense(len(classes), activation="softmax"),
                    ])

model = create_model(text_vectorizer)

model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
text_vectorization (TextVect (None, 50000)             0
_________________________________________________________________
dense (Dense)                (None, 128)               6400128
_________________________________________________________________
dense_1 (Dense)              (None, 64)                8256
_________________________________________________________________
dense_2 (Dense)              (None, 5)                 325
=================================================================
Total params: 6,408,709
Trainable params: 6,408,709
Non-trainable params: 0
_________________________________________________________________
model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])
history = model.fit(X_train, Y_train, batch_size=256, epochs=8, validation_data=(X_test, Y_test))
gc.collect()
Epoch 1/8
12/12 [==============================] - 3s 180ms/step - loss: 1.0541 - accuracy: 0.7384 - val_loss: 0.5561 - val_accuracy: 0.9245
Epoch 2/8
12/12 [==============================] - 2s 146ms/step - loss: 0.2360 - accuracy: 0.9777 - val_loss: 0.2652 - val_accuracy: 0.9493
Epoch 3/8
12/12 [==============================] - 2s 162ms/step - loss: 0.0611 - accuracy: 0.9956 - val_loss: 0.1993 - val_accuracy: 0.9514
Epoch 4/8
12/12 [==============================] - 2s 215ms/step - loss: 0.0232 - accuracy: 0.9987 - val_loss: 0.1731 - val_accuracy: 0.9529
Epoch 5/8
12/12 [==============================] - 2s 172ms/step - loss: 0.0101 - accuracy: 0.9997 - val_loss: 0.1691 - val_accuracy: 0.9509
Epoch 6/8
12/12 [==============================] - 2s 205ms/step - loss: 0.0060 - accuracy: 0.9997 - val_loss: 0.1675 - val_accuracy: 0.9514
Epoch 7/8
12/12 [==============================] - 3s 221ms/step - loss: 0.0050 - accuracy: 0.9993 - val_loss: 0.1690 - val_accuracy: 0.9488
Epoch 8/8
12/12 [==============================] - 2s 146ms/step - loss: 0.0031 - accuracy: 0.9997 - val_loss: 0.1654 - val_accuracy: 0.9529
2949

Evaluate Network Performance

In this section, we have evaluated the performance of the network by calculating accuracy, the classification report, and the confusion matrix on the test dataset. We can notice from the results that our model seems to be doing a good job as per the accuracy metric. As per the confusion matrix, the model is confusing some samples of the categories misc.forsale, rec.autos, and sci.electronics.

We have calculated these ML metrics using functions available from scikit-learn. Please feel free to check the below link if you want to learn about them and the other metrics scikit-learn provides; it covers them in detail.

In the cell after the next one, we have also plotted the confusion matrix using the scikit-plot python library, which provides visualizations for many ML metrics. Please feel free to check the below link if you want to learn about it.

from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

train_preds = model.predict(X_train)
test_preds = model.predict(X_test)

print("Train Accuracy : {}".format(accuracy_score(Y_train, np.argmax(train_preds, axis=1))))
print("Test  Accuracy : {}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=selected_categories))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_test, np.argmax(test_preds, axis=1)))
Train Accuracy : 0.9996628455832771
Test  Accuracy : 0.952887537993921

Classification Report :
                        precision    recall  f1-score   support

          misc.forsale       0.93      0.94      0.94       390
             rec.autos       0.95      0.94      0.95       396
    rec.sport.baseball       0.99      0.96      0.97       397
       sci.electronics       0.92      0.94      0.93       393
soc.religion.christian       0.98      0.97      0.98       398

              accuracy                           0.95      1974
             macro avg       0.95      0.95      0.95      1974
          weighted avg       0.95      0.95      0.95      1974


Confusion Matrix :
[[368   9   2  11   0]
 [  7 373   1  15   0]
 [  2   4 381   3   7]
 [ 13   6   1 371   2]
 [  4   1   1   4 388]]
import scikitplot as skplt
import matplotlib.pyplot as plt

skplt.metrics.plot_confusion_matrix([selected_categories[i] for i in Y_test], [selected_categories[i] for i in np.argmax(test_preds, axis=1)],
                                    normalize=True,
                                    title="Confusion Matrix",
                                    cmap="Purples",
                                    hide_zeros=True,
                                    figsize=(5,5)
                                    );
plt.xticks(rotation=90);

[Normalized Confusion Matrix Plot]
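As an optional extra (a small helper of our own, not part of the original scikit-learn or scikit-plot output), the snippet below lists the largest off-diagonal entries of the confusion matrix, which is a quick way to see which category pairs the model confuses most.

cm = confusion_matrix(Y_test, np.argmax(test_preds, axis=1))

# Largest off-diagonal entries = most frequent confusions (actual -> predicted).
confusions = [(selected_categories[i], selected_categories[j], cm[i, j])
              for i in range(cm.shape[0]) for j in range(cm.shape[1]) if i != j]
for actual, predicted, count in sorted(confusions, key=lambda t: -t[2])[:5]:
    print("{:<25s} -> {:<25s} : {}".format(actual, predicted, count))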

Explain Network Predictions Using Lime

In this section, we have explained the predictions made by our model using the LIME algorithm. The lime library provides LimeTextExplainer, which lets us explain a prediction by creating a visualization that highlights the words contributing to the predicted category.

In order to explain a prediction using lime, we first need to create an instance of LimeTextExplainer. Then, we need to call the explain_instance() method on it with the sample to explain. The method returns an Explanation instance on which we can call the show_in_notebook() method to create a visualization explaining the prediction.

Below, we have first created an instance of LimeTextExplainer that we'll use to explain predictions. We have given it labels of all the output categories of our model.

We recommend that readers go through our other simple tutorial on lime where we have covered the arguments of the LimeTextExplainer() constructor.

from lime import lime_text

explainer = lime_text.LimeTextExplainer(class_names=selected_categories, verbose=True)

explainer
<lime.lime_text.LimeTextExplainer at 0x7f46ac879750>

Explain Correct Predictions

In this section, we have explained the correct prediction made by our model. We have first randomly selected one sample from the test data set. The category of the selected sample is soc.religion.christian and the same is predicted by our model.

Then, we have called the explain_instance() method on the LimeTextExplainer instance, giving it the text sample, a reference to the predict() function of the model, and the actual label of the sample. This returned an Explanation instance which holds the explanation details for the sample.

At last, we have called the show_in_notebook() method on the explanation instance to show the explanation of the sample. The visualization has a bar chart showing which words contributed positively to the prediction categories, as well as the original text of the sample with words highlighted to indicate whether they contributed positively or negatively to the predicted label. The shade of each highlight reflects the size of the word's contribution. We can notice from the visualization that words like 'church', 'homosexual', 'churches', 'group', etc. are contributing to predicting the category soc.religion.christian.

import numpy as np

rng = np.random.RandomState(42)
idx = rng.randint(1, len(X_test))

print("Prediction : ", selected_categories[model.predict(X_test[idx:idx+1]).argmax(axis=-1)[0]])
print("Actual :     ", selected_categories[Y_test[idx]])
Prediction :  soc.religion.christian
Actual :      soc.religion.christian
explanation = explainer.explain_instance(X_test[idx], classifier_fn=model.predict, labels=Y_test[idx:idx+1])

explanation
Intercept 0.689602685771859
Prediction_local [1.03962346]
Right: 0.99935097
<lime.explanation.Explanation at 0x7f46ae38ed90>
explanation.show_in_notebook()

[LIME Explanation Visualization: Correct Prediction (soc.religion.christian)]
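If you prefer to inspect the contributions programmatically instead of (or in addition to) the visualization, the Explanation object also provides an as_list() method that returns (word, weight) pairs for a given label; positive weights push the prediction toward that label and negative weights push it away. Below is a small optional snippet illustrating this.

label = int(Y_test[idx])  # the label we asked LIME to explain

for word, weight in explanation.as_list(label=label):
    print("{:>15s} : {:+.4f}".format(word, weight))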

Explain Incorrect Predictions

In this section, we have explained a wrong prediction using our explainer object. We have first found the indexes of all test samples that our model predicted incorrectly and then randomly selected one of them. The actual category of the selected sample is sci.electronics whereas the model predicted rec.autos. This time, we have generated two explanation objects (one with respect to the actual label and one with respect to the predicted label) so that we can see which words contribute to each category.

We can notice from the visualizations that words like 'batteries', 'old', etc contributed to predicting category sci.electronics and words like 'car', 'oil', 'engine', 'Honda', etc contributed to predicting category rec.autos.

import numpy as np

Y_test_preds = np.argmax(test_preds, axis=-1)
wrong_preds = np.argwhere(Y_test!=Y_test_preds)

rng = np.random.RandomState(123)
idx = rng.choice(wrong_preds.flatten())

print("Actual :     ", selected_categories[Y_test[idx]])
print("Prediction : ", selected_categories[model.predict(X_test[idx:idx+1]).argmax(axis=-1)[0]])

explanation_actual = explainer.explain_instance(X_test[idx], classifier_fn=model.predict,
                                                labels=Y_test[idx:idx+1])
explanation_pred = explainer.explain_instance(X_test[idx], classifier_fn=model.predict,
                                              labels=Y_test_preds[idx:idx+1])
Actual :      sci.electronics
Prediction :  rec.autos
Intercept 0.627446482690789
Prediction_local [0.28472881]
Right: 0.20287848
Intercept 0.12331060583892323
Prediction_local [0.73731424]
Right: 0.7962232
explanation_actual.show_in_notebook()

[LIME Explanation Visualization w.r.t. Actual Label (sci.electronics)]

explanation_pred.show_in_notebook()

[LIME Explanation Visualization w.r.t. Predicted Label (rec.autos)]

Explain Correct Predictions

In this section, we have again explained a correct prediction, but this time we have set the feature_selection parameter of LimeTextExplainer to 'none' so that LIME does not perform any feature selection when fitting the local surrogate model.

from lime import lime_text

explainer = lime_text.LimeTextExplainer(class_names=selected_categories, verbose=True, feature_selection="none")

rng = np.random.RandomState(42)
idx = rng.randint(1, len(X_test))

print("Actual :     ", selected_categories[Y_test[idx]])
print("Prediction : ", selected_categories[model.predict(X_test[idx:idx+1]).argmax(axis=-1)[0]])

explanation = explainer.explain_instance(X_test[idx], classifier_fn=model.predict, num_features=10,
                                         labels=Y_test[idx:idx+1])

explanation.show_in_notebook()

[LIME Explanation Visualization: Correct Prediction (feature_selection='none')]

2. Network Using Scikit-Learn Text Vectorizer (Word Frequency)

In this section, we have again used the word frequency text vectorization approach, but this time we have vectorized the text data using CountVectorizer available from scikit-learn. Our keras network will work directly on the vectorized data this time.

Please NOTE that we have not covered the text vectorization approaches used in this tutorial in detail, as we expect the reader to have some background on them. Please feel free to check the below link if you want to learn about them in detail.

Vectorize Text Data

Below, we have vectorized our text data using CountVectorizer available from scikit-learn.

import sklearn
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer

vectorizer = CountVectorizer(max_features=50000)

vectorizer.fit(np.concatenate((X_train, X_test)))
X_train_vect = vectorizer.transform(X_train)
X_test_vect = vectorizer.transform(X_test)

X_train_vect, X_test_vect = X_train_vect.toarray(), X_test_vect.toarray()

X_train_vect.shape, X_test_vect.shape
((2966, 47953), (1974, 47953))

Create And Train Model

In this section, we have first created a keras network that works on vectorized data. It has 3 dense layers with units 128, 64, and 5 respectively.

After creating a network, we have compiled it to use Adam optimizer, cross entropy loss, and accuracy metrics.

At last, we have trained the network by calling fit() function for 8 epochs with a batch size of 256. We can notice from the loss and accuracy getting printed after each epoch that the model is doing a good job.

from tensorflow.keras.models import Sequential
from tensorflow.keras import layers

def create_model(input_shape):
    return Sequential([
                            layers.Input(shape=input_shape),
                            layers.Dense(128, activation="relu"),
                            layers.Dense(64, activation="relu"),
                            layers.Dense(len(selected_categories), activation="softmax"),
                        ])
model = create_model(X_train_vect.shape[1:])

model.summary()
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense_3 (Dense)              (None, 128)               6138112
_________________________________________________________________
dense_4 (Dense)              (None, 64)                8256
_________________________________________________________________
dense_5 (Dense)              (None, 5)                 325
=================================================================
Total params: 6,146,693
Trainable params: 6,146,693
Non-trainable params: 0
_________________________________________________________________
model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])
history = model.fit(X_train_vect, Y_train, batch_size=256, epochs=8, validation_data=(X_test_vect, Y_test))
Epoch 1/8
12/12 [==============================] - 4s 248ms/step - loss: 0.9927 - accuracy: 0.7245 - val_loss: 0.4695 - val_accuracy: 0.9357
Epoch 2/8
12/12 [==============================] - 2s 140ms/step - loss: 0.1804 - accuracy: 0.9821 - val_loss: 0.2231 - val_accuracy: 0.9504
Epoch 3/8
12/12 [==============================] - 2s 140ms/step - loss: 0.0427 - accuracy: 0.9963 - val_loss: 0.1756 - val_accuracy: 0.9524
Epoch 4/8
12/12 [==============================] - 2s 140ms/step - loss: 0.0147 - accuracy: 0.9997 - val_loss: 0.1779 - val_accuracy: 0.9519
Epoch 5/8
12/12 [==============================] - 2s 139ms/step - loss: 0.0087 - accuracy: 0.9997 - val_loss: 0.1673 - val_accuracy: 0.9534
Epoch 6/8
12/12 [==============================] - 2s 146ms/step - loss: 0.0048 - accuracy: 0.9997 - val_loss: 0.1614 - val_accuracy: 0.9514
Epoch 7/8
12/12 [==============================] - 2s 143ms/step - loss: 0.0039 - accuracy: 0.9993 - val_loss: 0.1605 - val_accuracy: 0.9504
Epoch 8/8
12/12 [==============================] - 2s 136ms/step - loss: 0.0032 - accuracy: 0.9997 - val_loss: 0.1638 - val_accuracy: 0.9529

Explain Network Predictions

Explain Correct Predictions

In this section, we have explained the correct prediction made by our model. We have randomly selected one sample from the test set. The actual and predicted category of that sample is soc.religion.christian.

This time, we have created a small function that takes text samples as input and returns prediction probabilities for them. This is required by the classifier_fn parameter of the explain_instance() method. In our previous example, we did not need such a function because our model took text samples as input and made predictions directly, but in this example the model works on vectorized data, hence we need the wrapper.

We can notice from the visualization that the words like 'church', 'homosexual', 'group', 'topic', etc are contributing to predicting category soc.religion.christian.

from lime import lime_text

def make_predictions(X_batch):
    X_batch_text = vectorizer.transform(X_batch)
    preds = model.predict(X_batch_text.toarray())
    return preds

explainer = lime_text.LimeTextExplainer(class_names=selected_categories)

rng = np.random.RandomState(42)
idx = rng.randint(1, len(X_test))

print("Prediction : ", selected_categories[model.predict(X_test_vect[idx:idx+1]).argmax(axis=-1)[0]])
print("Actual :     ", selected_categories[Y_test[idx]])

explanation = explainer.explain_instance(X_test[idx], classifier_fn=make_predictions, num_features=10,
                                         labels=Y_test[idx:idx+1])

explanation.show_in_notebook()

[LIME Explanation Visualization: Correct Prediction (soc.religion.christian)]

Explain Incorrect Predictions

In this section, we have explained one wrong prediction. The actual category of our selected sample is sci.electronics whereas our model predicted misc.forsale. As usual, we have created two visualizations to see which words are contributing to actual and predicted labels.

We can notice that words like 'sells', 'com', etc. are contributing to the category sci.electronics whereas words like 'tools', 'Mac', 'Macintosh', etc. are contributing to the category misc.forsale. This does not make much sense, as words that should point to the sci.electronics category are instead contributing to the misc.forsale category. This hints that our model has not generalized all that well.

Y_test_preds = model.predict(X_test_vect).argmax(axis=-1)
wrong_preds = np.argwhere(Y_test!=Y_test_preds)

rng = np.random.RandomState(123)
idx = rng.choice(wrong_preds.flatten())

print("Actual :     ", selected_categories[Y_test[idx]])
print("Prediction : ", selected_categories[model.predict(X_test_vect[idx:idx+1]).argmax(axis=-1)[0]])

explanation_actual = explainer.explain_instance(X_test[idx], classifier_fn=make_predictions,
                                                num_features=10, labels=Y_test[idx:idx+1])
explanation_pred = explainer.explain_instance(X_test[idx], classifier_fn=make_predictions,
                                              num_features=10, labels=Y_test_preds[idx:idx+1])

explanation_actual.show_in_notebook()
explanation_pred.show_in_notebook()

[LIME Explanation Visualizations w.r.t. Actual And Predicted Labels]

3. Network Using Scikit-Learn Text Vectorizer (Word Frequency + Stop Words Removed)

In this example, we have again used the word frequency approach, but this time we have removed commonly appearing English words (words like 'the', 'a', 'an', 'then', etc.) that do not contribute much to classification because they appear in almost every text document. The majority of the code in this section is repeated from the previous section, hence we have not explained the repeated parts in detail.
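For reference, the stop-word list that scikit-learn applies when we pass stop_words='english' is available as a frozenset, and the short optional snippet below peeks at it. The exact number of words may differ slightly across scikit-learn versions.

from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS

print(len(ENGLISH_STOP_WORDS))          # roughly 300+ common English words
print(sorted(ENGLISH_STOP_WORDS)[:10])  # e.g. ['a', 'about', 'above', 'across', ...]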

Vectorize Text Data

In this section, we have vectorized data using scikit-learn CountVectorizer. We have set stop_words parameter to 'english' asking it to remove stop words from the vocabulary.

import sklearn
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer

vectorizer = CountVectorizer(max_features=50000, stop_words="english")

vectorizer.fit(np.concatenate((X_train, X_test)))
X_train_vect = vectorizer.transform(X_train)
X_test_vect = vectorizer.transform(X_test)

X_train_vect, X_test_vect = X_train_vect.toarray(), X_test_vect.toarray()

X_train_vect.shape, X_test_vect.shape
((2966, 47645), (1974, 47645))

Create And Train Model

Here, we have created a model, compiled it, and trained it using vectorized data. We can notice from the loss and accuracy getting printed at the end of all epochs that the model is performing well.

model = create_model(X_train_vect.shape[1:])

model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])
history = model.fit(X_train_vect, Y_train, batch_size=256, epochs=8, validation_data=(X_test_vect, Y_test))
Epoch 1/8
12/12 [==============================] - 3s 182ms/step - loss: 1.0352 - accuracy: 0.7519 - val_loss: 0.5140 - val_accuracy: 0.9428
Epoch 2/8
12/12 [==============================] - 2s 143ms/step - loss: 0.1949 - accuracy: 0.9885 - val_loss: 0.2230 - val_accuracy: 0.9539
Epoch 3/8
12/12 [==============================] - 2s 140ms/step - loss: 0.0390 - accuracy: 0.9976 - val_loss: 0.1700 - val_accuracy: 0.9564
Epoch 4/8
12/12 [==============================] - 2s 140ms/step - loss: 0.0116 - accuracy: 0.9990 - val_loss: 0.1621 - val_accuracy: 0.9549
Epoch 5/8
12/12 [==============================] - 2s 180ms/step - loss: 0.0061 - accuracy: 0.9990 - val_loss: 0.1586 - val_accuracy: 0.9549
Epoch 6/8
12/12 [==============================] - 2s 139ms/step - loss: 0.0036 - accuracy: 0.9993 - val_loss: 0.1572 - val_accuracy: 0.9539
Epoch 7/8
12/12 [==============================] - 2s 144ms/step - loss: 0.0039 - accuracy: 0.9990 - val_loss: 0.1545 - val_accuracy: 0.9554
Epoch 8/8
12/12 [==============================] - 2s 134ms/step - loss: 0.0044 - accuracy: 0.9990 - val_loss: 0.1551 - val_accuracy: 0.9559

Explain Network Predictions

Explain Correct Predictions

In this section, we have explained the correct prediction made by our model. The actual category of the selected sample is soc.religion.christian and the same is predicted by our model. The words like 'church', 'group', 'homosexual', 'presbyterian', etc are contributing to the prediction.

from lime import lime_text

def make_predictions(X_batch):
    X_batch_text = vectorizer.transform(X_batch)
    preds = model.predict(X_batch_text.toarray())
    return preds

explainer = lime_text.LimeTextExplainer(class_names=selected_categories)

rng = np.random.RandomState(42)
idx = rng.randint(1, len(X_test))

print("Prediction : ", selected_categories[model.predict(X_test_vect[idx:idx+1]).argmax(axis=-1)[0]])
print("Actual :     ", selected_categories[Y_test[idx]])

explanation = explainer.explain_instance(X_test[idx], classifier_fn=make_predictions, num_features=10, labels=Y_test[idx:idx+1])

explanation.show_in_notebook()

[LIME Explanation Visualization: Correct Prediction (soc.religion.christian)]

Explain Incorrect Predictions

In this section, we have explained one incorrect prediction. The actual category of the sample is 'misc.forsale' whereas our model predicted 'sci.electronics'. We have generated an explanation with respect to both actual and predicted labels.

Y_test_preds = model.predict(X_test_vect).argmax(axis=-1)
wrong_preds = np.argwhere(Y_test!=Y_test_preds)

rng = np.random.RandomState(123)
idx = rng.choice(wrong_preds.flatten())

print("Actual :     ", selected_categories[Y_test[idx]])
print("Prediction : ", selected_categories[model.predict(X_test_vect[idx:idx+1]).argmax(axis=-1)[0]])

explanation_actual = explainer.explain_instance(X_test[idx], classifier_fn=make_predictions, num_features=10, labels=Y_test[idx:idx+1])
explanation_pred = explainer.explain_instance(X_test[idx], classifier_fn=make_predictions, num_features=10, labels=Y_test_preds[idx:idx+1])

explanation_actual.show_in_notebook()
explanation_pred.show_in_notebook()

[LIME Explanation Visualizations w.r.t. Actual And Predicted Labels]

4. Network Using Scikit-learn Text Vectorizer (Tf-Idf + Stop Words Removed)

In this section, we have vectorized our data using Tf-Idf (Term Frequency-Inverse Document Frequency). It assigns a float value to each word in such a way that words appearing commonly across many documents get low values and words appearing rarely get high values.

Please feel free to check the below link if you want to learn about Tf-IDF in detail.
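To make the weighting a little more concrete, below is a minimal illustration on a hypothetical toy corpus of scikit-learn's default Tf-Idf scheme as we understand it (smooth_idf=True, norm='l2'): idf(t) = ln((1 + n) / (1 + df(t))) + 1, the tf-idf value is the raw count times idf(t), and each document vector is then L2-normalized. A word like 'the' that appears in every toy document gets the minimum idf of 1.0, while a word appearing in a single document gets a noticeably larger idf.

import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer

toy_corpus = ["the car engine", "the baseball game", "the church choir"]

toy_vectorizer = TfidfVectorizer()  # defaults: smooth_idf=True, norm="l2"
X_toy = toy_vectorizer.fit_transform(toy_corpus).toarray()

n_docs = len(toy_corpus)
idf_the = np.log((1 + n_docs) / (1 + 3)) + 1  # 'the' appears in all 3 docs -> idf = 1.0
idf_car = np.log((1 + n_docs) / (1 + 1)) + 1  # 'car' appears in 1 doc      -> idf ~ 1.69

print("idf('the') = {:.2f}, idf('car') = {:.2f}".format(idf_the, idf_car))
print({word: round(X_toy[0][col], 3) for word, col in toy_vectorizer.vocabulary_.items()})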

Vectorize Text Data

In this section, we have vectorized data using TfidfVectorizer available from scikit-learn. The TfidfVectorizer is an implementation of Tf-IDF concept.

import sklearn
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer

vectorizer = TfidfVectorizer(max_features=50000, stop_words="english")

vectorizer.fit(np.concatenate((X_train, X_test)))
X_train_vect = vectorizer.transform(X_train)
X_test_vect = vectorizer.transform(X_test)

X_train_vect, X_test_vect = X_train_vect.toarray(), X_test_vect.toarray()

X_train_vect.shape, X_test_vect.shape
((2966, 47645), (1974, 47645))

Create And Train Model

In this section, we have created and trained our model as usual. We can notice from the loss and accuracy that the model has done a good job at prediction.

model = create_model(X_train_vect.shape[1:])

model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])
history = model.fit(X_train_vect, Y_train, batch_size=256, epochs=8, validation_data=(X_test_vect, Y_test))
Epoch 1/8
12/12 [==============================] - 2s 172ms/step - loss: 1.5553 - accuracy: 0.7283 - val_loss: 1.4553 - val_accuracy: 0.8886
Epoch 2/8
12/12 [==============================] - 1s 118ms/step - loss: 1.2273 - accuracy: 0.9788 - val_loss: 1.1326 - val_accuracy: 0.9068
Epoch 3/8
12/12 [==============================] - 1s 120ms/step - loss: 0.7272 - accuracy: 0.9922 - val_loss: 0.7467 - val_accuracy: 0.9352
Epoch 4/8
12/12 [==============================] - 1s 119ms/step - loss: 0.3110 - accuracy: 0.9963 - val_loss: 0.4733 - val_accuracy: 0.9438
Epoch 5/8
12/12 [==============================] - 1s 121ms/step - loss: 0.1153 - accuracy: 0.9980 - val_loss: 0.3367 - val_accuracy: 0.9493
Epoch 6/8
12/12 [==============================] - 1s 123ms/step - loss: 0.0496 - accuracy: 0.9993 - val_loss: 0.2781 - val_accuracy: 0.9488
Epoch 7/8
12/12 [==============================] - 1s 121ms/step - loss: 0.0260 - accuracy: 0.9997 - val_loss: 0.2506 - val_accuracy: 0.9483
Epoch 8/8
12/12 [==============================] - 1s 116ms/step - loss: 0.0169 - accuracy: 0.9993 - val_loss: 0.2358 - val_accuracy: 0.9488

Explain Network Predictions

Explain Correct Predictions

In this section, we have explained correct prediction. The model has predicted correct category soc.religion.christian. The words like 'homosexual', 'church', 'passages', 'group', 'presbyterian', etc contributed to prediction.

from lime import lime_text

def make_predictions(X_batch):
    X_batch_text = vectorizer.transform(X_batch)
    preds = model.predict(X_batch_text.toarray())
    return preds

explainer = lime_text.LimeTextExplainer(class_names=selected_categories)

rng = np.random.RandomState(42)
idx = rng.randint(1, len(X_test))

print("Prediction : ", selected_categories[model.predict(X_test_vect[idx:idx+1]).argmax(axis=-1)[0]])
print("Actual :     ", selected_categories[Y_test[idx]])

explanation = explainer.explain_instance(X_test[idx], classifier_fn=make_predictions, num_features=10, labels=Y_test[idx:idx+1])

explanation.show_in_notebook()

[LIME Explanation Visualization: Correct Prediction (soc.religion.christian)]

Explain Incorrect Predictions

In this section, we have explained one incorrect prediction. The actual category of the selected sample is rec.autos whereas our model predicted rec.sport.baseball. We can check from visualizations which words contributed to a particular category.

Y_test_preds = model.predict(X_test_vect).argmax(axis=-1)
wrong_preds = np.argwhere(Y_test!=Y_test_preds)

rng = np.random.RandomState(123)
idx = rng.choice(wrong_preds.flatten())

print("Actual :     ", selected_categories[Y_test[idx]])
print("Prediction : ", selected_categories[model.predict(X_test_vect[idx:idx+1]).argmax(axis=-1)[0]])

explanation_actual = explainer.explain_instance(X_test[idx], classifier_fn=make_predictions, num_features=10, labels=Y_test[idx:idx+1])
explanation_pred = explainer.explain_instance(X_test[idx], classifier_fn=make_predictions, num_features=10, labels=Y_test_preds[idx:idx+1])

explanation_actual.show_in_notebook()
explanation_pred.show_in_notebook()

[LIME Explanation Visualizations w.r.t. Actual And Predicted Labels]

This ends our small tutorial explaining how we can use lime to explain predictions made by our Keras text classification networks. Please feel free to let us know your views in the comments section.
