Updated On : Feb-14,2022 Tags shap-values, text-classi…
SHAP Values for Text Classification Tasks (Keras NLP)

SHAP (SHapley Additive exPlanations) is a Python library that uses a game-theoretic approach to compute SHAP values, which can be used to explain the predictions made by our machine learning models. SHAP can explain predictions for tasks in fields like computer vision, natural language processing, structured (tabular) data ML, etc. We have already covered in a separate, detailed tutorial how it can be used with tabular datasets.

As a part of this tutorial, we'll be concentrating on text classification, an NLP task. We'll be using the 20 Newsgroups text dataset available from scikit-learn. We have trained a simple neural network built with Keras on this dataset. Then, we have explained correct and incorrect predictions made by the model using SHAP values and created various visualizations showing which words contributed to the predictions.

We have one more tutorial explaining how to use SHAP values to explain text classification models, in which we have used scikit-learn vectorizers to vectorize the text data. Please feel free to check the below link if you want to have a look at that tutorial.

Below, we have listed the important sections of the tutorial to give an overview of the material covered.

Important Sections Of Tutorial

  1. Load Data
  2. Vectorize Text Data (Word Frequency)
  3. Create And Train Network
  4. Evaluate Model Performance
  5. Explain Model Predictions Using SHAP Partition Explainer

Below, we have imported the important libraries of our tutorial and printed the versions that we have used.

In [1]:
import tensorflow
from tensorflow import keras

print("Keras Version : {}".format(keras.__version__))
Keras Version : 2.6.0
In [2]:
import shap

print("SHAP Version : {}".format(shap.__version__))
SHAP Version : 0.40.0

1. Load Data

In this section, we have loaded the 20 Newsgroups dataset available from scikit-learn. The dataset has around 18k text documents belonging to 20 different categories. We have limited our tutorial to only 5 of those categories, which still makes this a multi-class text classification problem. The dataset is loaded using the fetch_20newsgroups() function available from the datasets sub-module of scikit-learn, which lets us load the train and test sets separately.

In [3]:
import numpy as np
from sklearn import datasets

all_categories = ['alt.atheism','comp.graphics','comp.os.ms-windows.misc','comp.sys.ibm.pc.hardware',
                  'comp.sys.mac.hardware','comp.windows.x', 'misc.forsale','rec.autos','rec.motorcycles',
                  'rec.sport.baseball','rec.sport.hockey','sci.crypt','sci.electronics','sci.med',
                  'sci.space','soc.religion.christian','talk.politics.guns','talk.politics.mideast',
                  'talk.politics.misc','talk.religion.misc']

selected_categories = ['alt.atheism','comp.graphics','rec.sport.hockey','sci.space','talk.politics.misc']

X_train, Y_train = datasets.fetch_20newsgroups(subset="train", categories=selected_categories, return_X_y=True)
X_test , Y_test  = datasets.fetch_20newsgroups(subset="test", categories=selected_categories, return_X_y=True)

X_train = np.array(X_train)
X_test = np.array(X_test)

classes = np.unique(Y_train)
mapping = dict(zip(classes, selected_categories))

len(X_train), len(X_test), classes, mapping
Out[3]:
(2722,
 1811,
 array([0, 1, 2, 3, 4]),
 {0: 'alt.atheism',
  1: 'comp.graphics',
  2: 'rec.sport.hockey',
  3: 'sci.space',
  4: 'talk.politics.misc'})
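
The mapping above pairs the integer labels with selected_categories purely by position. As far as we understand fetch_20newsgroups(), the selected categories are renumbered 0..k-1 in the alphabetical order of their names (the order of the full category list), regardless of the order in which we pass them. Our selected_categories list is already alphabetical, so the pairing holds. The sketch below reproduces that renumbering in plain Python; relabel is a hypothetical helper, not part of scikit-learn.

```python
def relabel(all_names, selected):
    """Return {category_name: new_label} for the selected categories.

    Assumption: selected categories receive contiguous labels ordered by their
    position in the full, alphabetically sorted category list.
    """
    ordered = sorted(selected, key=all_names.index)
    return {name: label for label, name in enumerate(ordered)}


all_names = sorted(['alt.atheism', 'comp.graphics', 'rec.sport.hockey',
                    'sci.space', 'talk.politics.misc', 'misc.forsale'])
# Passing the categories in a shuffled order still yields alphabetical labels.
print(relabel(all_names, ['sci.space', 'alt.atheism', 'comp.graphics']))
# {'alt.atheism': 0, 'comp.graphics': 1, 'sci.space': 2}
```

A safer option is to read the names from scikit-learn itself: calling fetch_20newsgroups() without return_X_y=True returns a Bunch whose target_names attribute lists the categories in label order.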

2. Vectorize Text Data (Word Frequency)

In this section, we have created a Keras layer that is responsible for vectorizing our text data into a list of floats. We have created the text vectorization layer using the TextVectorization() constructor available from Keras. After creating the layer, we have called its adapt() method on the train and test datasets combined to populate its vocabulary (adapting on X_train alone would keep test-only words out of the vocabulary, which is the stricter choice). Later on, we have used this adapted layer inside our network so that we can provide raw text as input; the text is vectorized by this layer before it is fed to the other layers. Because we set output_mode="count", each document becomes a vector of length 50,000 whose i-th entry is the number of times the i-th vocabulary word appears in the document.

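We have printed a few words of the vocabulary as well as the vocabulary size below. We have limited the vocabulary to a maximum of 50,000 tokens (max_tokens=50000); the layer keeps the most frequent ones, and index 0 is reserved for the out-of-vocabulary token '[UNK]'.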
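
To make the count output mode concrete, here is a small pure-Python sketch of what the layer computes with our settings: lowercase, strip punctuation, split on whitespace, build a frequency-ranked vocabulary with '[UNK]' at index 0, and count occurrences into a fixed-length vector. It is an approximation for illustration only (Keras' exact punctuation set and tie-breaking between equally frequent words may differ), and build_vocab/count_vectorize are hypothetical helpers.

```python
import re
import string
from collections import Counter

# Approximation of standardize="lower_and_strip_punctuation" (assumption: Keras'
# exact punctuation set may differ slightly from string.punctuation).
PUNCT_RE = re.compile("[" + re.escape(string.punctuation) + "]")


def standardize(text):
    return PUNCT_RE.sub("", text.lower())


def build_vocab(docs, max_tokens):
    # Most frequent tokens first; index 0 is reserved for the OOV token "[UNK]".
    counts = Counter(tok for doc in docs for tok in standardize(doc).split())
    ranked = [tok for tok, _ in counts.most_common(max_tokens - 1)]
    return ["[UNK]"] + ranked


def count_vectorize(doc, vocab, max_tokens):
    # output_mode="count" with pad_to_max_tokens=True: always max_tokens slots.
    index = {tok: i for i, tok in enumerate(vocab)}
    vec = [0] * max_tokens
    for tok in standardize(doc).split():
        vec[index.get(tok, 0)] += 1  # unknown words are counted in slot 0
    return vec


docs = ["The cat sat on the mat.", "The dog ate the cat's food!"]
vocab = build_vocab(docs, max_tokens=6)
print(vocab)                                            # ['[UNK]', 'the', 'cat', 'sat', 'on', 'mat']
print(count_vectorize("The cat, the bird.", vocab, 6))  # [1, 2, 1, 0, 0, 0]
```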

Please make a NOTE that we have not discussed text vectorization in detail in this tutorial as we are assuming that the reader has knowledge of it already. Please feel free to check the below tutorials if you want to refresh your knowledge of the topic as we have covered them in detail over there.

In [4]:
text_vectorizer = keras.layers.TextVectorization(max_tokens=50000, standardize="lower_and_strip_punctuation",
                                                 split="whitespace", output_mode="count", pad_to_max_tokens=True)

text_vectorizer.adapt(np.concatenate((X_train, X_test)), batch_size=512)

vocab = text_vectorizer.get_vocabulary()
print("Vocab : {}".format(vocab[:10]))
print("Vocab Size : {}".format(text_vectorizer.vocabulary_size()))

out = text_vectorizer(X_train[:5])
print("Output Shape : {}".format(out.shape))
Vocab : ['[UNK]', 'the', 'to', 'of', 'a', 'and', 'in', 'is', 'that', 'i']
Vocab Size : 50000
Output Shape : (5, 50000)
In [5]:
import gc

gc.collect()
Out[5]:
775

3. Create And Train Network

In this section, we have created a simple neural network and trained it. Our network consists of the text vectorization layer as the first layer, followed by two dense layers with 64 and 5 units respectively (ReLU and softmax activations). After creating the network, we have compiled it to use the Adam optimizer, sparse categorical cross-entropy loss (our targets are integer labels), and the accuracy metric.
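
The parameter counts that model.summary() reports follow from the dense-layer formula: inputs × units weights plus one bias per unit (the vectorization layer itself has no trainable parameters). A quick check of the arithmetic, with dense_params as an illustrative helper:

```python
def dense_params(n_inputs, n_units):
    # Weight matrix (n_inputs x n_units) plus one bias per unit.
    return n_inputs * n_units + n_units


hidden = dense_params(50000, 64)  # 50,000-dim count vector -> 64 ReLU units
output = dense_params(64, 5)      # 64 units -> 5 softmax outputs
print(hidden, output, hidden + output)  # 3200064 325 3200389
```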

At last, we have trained the network for 5 epochs, providing both training and validation datasets. After 5 epochs, the model reaches about 0.999 training accuracy and 0.95 validation accuracy, which is good enough for us to move on to explaining its predictions using SHAP values.

In [6]:
from tensorflow.keras.models import Sequential
from tensorflow.keras import layers

def create_model(text_vectorizer):
    return Sequential([
                        layers.Input(shape=(1,), dtype="string"),
                        text_vectorizer,
                        layers.Dense(64, activation="relu"),
                        layers.Dense(len(classes), activation="softmax"),
                    ])

model = create_model(text_vectorizer)

model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
text_vectorization (TextVect (None, 50000)             0
_________________________________________________________________
dense (Dense)                (None, 64)                3200064
_________________________________________________________________
dense_1 (Dense)              (None, 5)                 325
=================================================================
Total params: 3,200,389
Trainable params: 3,200,389
Non-trainable params: 0
_________________________________________________________________
In [7]:
model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])
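
The "sparse_categorical_crossentropy" loss expects integer class labels (our Y_train holds 0 to 4) rather than one-hot vectors. For a single sample it is the negative log of the probability that the softmax output assigns to the true class, averaged over the batch. A minimal sketch of that computation (softmax and sparse_cce are our own helpers, not Keras functions; as far as we know Keras also clips probabilities slightly away from 0 and 1 for numerical safety, which is omitted here):

```python
import math


def softmax(logits):
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_cce(probs_batch, labels):
    # Mean of -log(p[true_label]) over the batch.
    losses = [-math.log(p[y]) for p, y in zip(probs_batch, labels)]
    return sum(losses) / len(losses)


probs = [softmax([2.0, 0.5, 0.1, 0.1, 0.3]),  # fairly confident in class 0
         softmax([0.0, 0.0, 0.0, 0.0, 0.0])]  # uniform over 5 classes
print(sparse_cce(probs, [0, 3]))
```

For the uniform row the loss is log(5), about 1.61, which is roughly where an untrained 5-class model starts.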
In [8]:
history = model.fit(X_train, Y_train, batch_size=256, epochs=5, validation_data=(X_test, Y_test))
gc.collect()
Epoch 1/5
11/11 [==============================] - 2s 140ms/step - loss: 1.0532 - accuracy: 0.7682 - val_loss: 0.6370 - val_accuracy: 0.9232
Epoch 2/5
11/11 [==============================] - 1s 114ms/step - loss: 0.2730 - accuracy: 0.9871 - val_loss: 0.3784 - val_accuracy: 0.9409
Epoch 3/5
11/11 [==============================] - 1s 125ms/step - loss: 0.1086 - accuracy: 0.9934 - val_loss: 0.2987 - val_accuracy: 0.9475
Epoch 4/5
11/11 [==============================] - 1s 119ms/step - loss: 0.0595 - accuracy: 0.9974 - val_loss: 0.2706 - val_accuracy: 0.9464
Epoch 5/5
11/11 [==============================] - 1s 115ms/step - loss: 0.0386 - accuracy: 0.9989 - val_loss: 0.2556 - val_accuracy: 0.9498
Out[8]:
1615

4. Evaluate Model Performance

In this section, we have evaluated the performance of the network by calculating accuracy, a classification report, and a confusion matrix. We can notice from the results that the model does a good job for all categories (test accuracy 0.95). talk.politics.misc has the lowest recall (0.91), i.e. the model misses slightly more of its posts than for the other categories.
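
The per-class numbers in the classification report are simple ratios: precision for class c is the number of correct predictions of c divided by all predictions of c, and recall is the number of correct predictions of c divided by all true samples of c. A small stdlib sketch on toy labels (not our data; per_class_precision_recall is an illustrative helper):

```python
from collections import Counter


def per_class_precision_recall(y_true, y_pred, n_classes):
    tp = Counter(t for t, p in zip(y_true, y_pred) if t == p)  # correct per class
    pred_counts = Counter(y_pred)                             # predicted per class
    true_counts = Counter(y_true)                             # actual per class
    report = {}
    for c in range(n_classes):
        precision = tp[c] / pred_counts[c] if pred_counts[c] else 0.0
        recall = tp[c] / true_counts[c] if true_counts[c] else 0.0
        report[c] = (precision, recall)
    return report


y_true = [0, 0, 1, 1, 1, 2]
y_pred = [0, 1, 1, 1, 2, 2]
for c, (p, r) in per_class_precision_recall(y_true, y_pred, 3).items():
    print(c, round(p, 2), round(r, 2))
```

Read this way, the 0.91 recall for talk.politics.misc means roughly 9% of its 310 test posts were assigned to another category.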

We have used various functions available from scikit-learn to calculate metrics. Please feel free to check the below tutorial if you want to know about metrics available from sklearn in detail as the majority of them are covered in it.

Below, we have used the scikit-plot Python library to plot the confusion matrix. Please feel free to check the below link if you want to know more about it and the other ML metric visualizations it provides.

In [9]:
from sklearn.metrics import accuracy_score, classification_report

train_preds = model.predict(X_train)
test_preds = model.predict(X_test)

print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, np.argmax(train_preds, axis=1))))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=selected_categories))
Train Accuracy : 0.999
Test  Accuracy : 0.950

Classification Report :
                    precision    recall  f1-score   support

       alt.atheism       0.94      0.93      0.94       319
     comp.graphics       0.93      0.97      0.95       389
  rec.sport.hockey       0.97      0.97      0.97       399
         sci.space       0.95      0.95      0.95       394
talk.politics.misc       0.96      0.91      0.94       310

          accuracy                           0.95      1811
         macro avg       0.95      0.95      0.95      1811
      weighted avg       0.95      0.95      0.95      1811

In [ ]:
from sklearn.metrics import confusion_matrix
import scikitplot as skplt
import matplotlib.pyplot as plt

skplt.metrics.plot_confusion_matrix([selected_categories[i] for i in Y_test], [selected_categories[i] for i in np.argmax(test_preds, axis=1)],
                                    normalize=True,
                                    title="Confusion Matrix",
                                    cmap="Purples",
                                    hide_zeros=True,
                                    figsize=(5,5)
                                    );
plt.xticks(rotation=90);

5. Explain Model Predictions Using SHAP Partition Explainer

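In this section, we have explained the predictions made by our model by generating SHAP values with the Partition explainer available from SHAP. The Partition explainer organizes the features (here, the tokens of the text) into a hierarchy of groups and computes SHAP values recursively through that hierarchy, masking and unmasking whole groups of tokens at a time. This yields Owen values, a hierarchy-aware form of Shapley values, with far fewer model evaluations than trying every combination of tokens. We'll explain both correct and incorrect predictions made by our model using this explainer.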
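
SHAP values are built on Shapley values from cooperative game theory. For a handful of features they can be computed exactly by averaging each feature's marginal contribution over all coalitions of the other features; since that enumeration grows as 2^n, it is infeasible for documents with hundreds of tokens, which is why SHAP uses approximations such as the Partition explainer. The sketch below computes exact Shapley values for a hypothetical three-word "model" (toy_model is a hand-written set function standing in for a predicted probability when only some words are visible).

```python
from itertools import combinations
from math import factorial


def shapley_values(players, value):
    """Exact Shapley values; value(frozenset_of_players) -> model output."""
    n = len(players)
    phi = {}
    for i in players:
        others = [p for p in players if p != i]
        total = 0.0
        for k in range(n):
            # Weight of a coalition of size k that does not contain player i.
            weight = factorial(k) * factorial(n - k - 1) / factorial(n)
            for coalition in combinations(others, k):
                s = frozenset(coalition)
                total += weight * (value(s | {i}) - value(s))
        phi[i] = total
    return phi


def toy_model(words):
    # Hypothetical "P(sci.space)" when only the words in `words` are visible.
    score = 0.2  # output with every word masked (the base value)
    if "space" in words:
        score += 0.4
    if "orbit" in words:
        score += 0.2
    if "space" in words and "orbit" in words:
        score += 0.1  # interaction, split evenly between the two words
    return score


players = ["space", "orbit", "hockey"]
phi = shapley_values(players, toy_model)
print(phi)
# Efficiency: base value + sum of Shapley values equals the full-input output.
print(toy_model(frozenset()) + sum(phi.values()), toy_model(frozenset(players)))
```

"space" gets 0.4 plus half of the 0.1 interaction, "orbit" gets 0.2 plus the other half, and "hockey" gets 0 because it never changes the output.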

The SHAP library lets us generate plots using either its JavaScript backend or matplotlib. We have used the JavaScript backend as it creates interactive charts that display values when the mouse hovers over them and let us switch between output categories. To render these JavaScript plots in a notebook, we first need to call the initjs() function, which loads the required JavaScript code.

In [ ]:
shap.initjs()

We can create a Partition explainer instance using the Explainer() constructor available from SHAP. We need to provide our model, a masker, and the names of the output categories. A masker is needed for text or image data: it defines how an input is split into features (tokens, i.e. words in our case) and how those features are hidden (masked) when the explainer evaluates the model on partial inputs, which is what lets SHAP map each SHAP value back to a token. Given our model and a text masker, Explainer() selects the Partition algorithm, as the printed object below confirms. We can also create the explainer directly with the shap.explainers.Partition() constructor, which takes the same main arguments.

We have created the masker using the Text() constructor by providing the regular expression pattern r"\W+". When we provide a pattern like this, the constructor internally creates a tokenizer that splits the text wherever the pattern matches. We can also write our own tokenizer function and provide it to this masker. The tokenizer should return a dictionary with two keys.

  • 'input_ids' - A list of tokens. With a split pattern like r"\W+", these are the pieces of text between the separators (punctuation, spaces, etc.), i.e. the words.
  • 'offset_mapping' - A list of tuples of length two, giving the start and end character indexes of each token of 'input_ids' within the original string.

When we create a masker with the regular expression r"\W+", it internally creates a tokenizer that returns a dictionary with the above keys. As the SHAP data printed further below shows, each token is displayed together with the separator characters that follow it (e.g. 'from: ').

Please make a NOTE that even if we don't provide a pattern when creating a masker object, it'll still use the same r"\W+" pattern.
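
To make the tokenizer contract concrete, here is a stand-alone sketch of a regex tokenizer that returns the two keys described above. It is our own code written in the same spirit as SHAP's pattern-based tokenizer, not SHAP's implementation; the function name and the return_offsets_mapping keyword are assumptions about what a custom tokenizer should accept, and unlike SHAP it simply skips empty tokens.

```python
import re


def regex_tokenizer(text, return_offsets_mapping=True, split_pattern=r"\W+"):
    """Split text on separator matches; return tokens and their character spans."""
    input_ids, offsets = [], []
    pos = 0
    for m in re.finditer(split_pattern, text):
        start, end = m.span()
        if start > pos:  # skip empty tokens (e.g. a leading separator)
            input_ids.append(text[pos:start])
            offsets.append((pos, start))
        pos = end
    if pos < len(text):  # trailing token after the last separator
        input_ids.append(text[pos:])
        offsets.append((pos, len(text)))
    out = {"input_ids": input_ids}
    if return_offsets_mapping:
        out["offset_mapping"] = offsets
    return out


print(regex_tokenizer("From: stk1203@vax003.stockton.edu"))
# {'input_ids': ['From', 'stk1203', 'vax003', 'stockton', 'edu'],
#  'offset_mapping': [(0, 4), (6, 13), (14, 20), (21, 29), (30, 33)]}
```

Hypothetically, such a function could then be passed as shap.maskers.Text(tokenizer=regex_tokenizer); please check the SHAP documentation of your version before relying on that.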

In [12]:
masker = shap.maskers.Text(tokenizer=r"\W+")
explainer = shap.Explainer(model, masker=masker, output_names=selected_categories)

explainer
Out[12]:
<shap.explainers._partition.Partition at 0x7fb60c071710>

Visualize SHAP Values For Correct Predictions

In this section, we have explained a few correct predictions made by our model using SHAP.

We have selected 2 samples from our test dataset and printed their contents first, in tokenized form. Then, we have made predictions on both samples with our model and printed the actual target category, the predicted target category, and the probability of each prediction.

Then, we have generated the SHAP values of these samples using the Partition explainer, by calling the explainer instance with a list of text samples. We'll be generating various visualizations using these SHAP values.

In the next cell, we have printed the shape of the SHAP values, the base values, and the data stored inside the SHAP values object. The base values are the model's outputs for a fully masked input; the SHAP values of a sample are added to them to reproduce its predictions. As we have two samples and five target categories, the base values have shape (2, 5), and both rows are identical because a fully masked input looks the same whatever the original text was. The SHAP values have shape (2, None, 5): the middle dimension is None because each sample has a different number of tokens.
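
A useful sanity check on an explanation is additivity: for each output category, the base value plus the sum of the tokens' SHAP values should reproduce the predicted probability. The hypothetical helper below measures the gap per class in plain Python; the toy numbers are made up and merely stand in for shap_values.base_values[i], shap_values.values[i] (tokens × classes) and preds_proba[i].

```python
def additivity_gaps(base_values, token_values, model_output):
    """Per class: base value + sum of token SHAP values - model output."""
    gaps = []
    for k, (base, out) in enumerate(zip(base_values, model_output)):
        total = base + sum(row[k] for row in token_values)
        gaps.append(total - out)
    return gaps


# Toy example: 3 tokens, 2 classes (made-up numbers, not from our model).
base = [0.5, 0.5]
values = [[0.2, -0.2],
          [0.1, -0.1],
          [-0.05, 0.05]]
output = [0.75, 0.25]
print(additivity_gaps(base, values, output))  # both gaps are ~0
```

With the notebook's objects one might call additivity_gaps(shap_values.base_values[0], shap_values.values[0], preds_proba[0]); gaps close to zero indicate that the explanation adds up to the prediction.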

In [13]:
import re

X_batch = X_test[3:5]

print("Samples : ")
for text in X_batch:
    print(re.split(r"\W+", text))
    print()

preds_proba = model.predict(X_batch)
preds = preds_proba.argmax(axis=1)

print("Actual    Target Values : {}".format([selected_categories[target] for target in Y_test[3:5]]))
print("Predicted Target Values : {}".format([selected_categories[target] for target in preds]))
print("Predicted Probabilities : {}".format(preds_proba.max(axis=1)))

shap_values = explainer([text.lower() for text in X_batch])  # Generate SHAP values
Samples :
['From', 'STK1203', 'VAX003', 'STOCKTON', 'EDU', 'Subject', 'Internet', 'resources', 'X', 'Added', 'Forwarded', 'by', 'Space', 'Digest', 'Organization', 'via', 'International', 'Space', 'University', 'Original', 'Sender', 'isu', 'VACATION', 'VENARI', 'CS', 'CMU', 'EDU', 'Distribution', 'sci', 'Lines', '4', 'exit', '']

['From', 'kane', 'buast7', 'bu', 'edu', 'Hot', 'Young', 'Star', 'Subject', 'Re', 'New', 'Study', 'Out', 'On', 'Gay', 'Percentage', 'Organization', 'Astronomy', 'Department', 'Boston', 'University', 'Boston', 'MA', 'USA', 'Lines', '24', 'In', 'article', '15427', 'optilink', 'COM', 'cramer', 'optilink', 'COM', 'Clayton', 'Cramer', 'writes', 'Homosexuals', 'lie', 'about', 'the', '10', 'number', 'to', 'hide', 'the', 'disproportionate', 'involvement', 'of', 'homosexuals', 'in', 'child', 'molestation', 'They', 'also', 'lie', 'about', '10', 'to', 'keep', 'politicians', 'scared', '1', 'You', 'haven', 't', 'shown', 'any', 'disproportionate', 'involvement', '2', 'The', 'Janus', 'Report', 'which', 'came', 'out', 'recently', 'gives', '9', 'as', 'the', 'percentage', 'of', 'exclusively', 'or', 'predominantly', 'gay', 'men', '3', 'No', 'one', 'is', 'presumably', 'going', 'to', 'say', 'they', 're', 'gay', 'if', 'they', 're', 'not', 'But', 'some', 'no', 'doubt', 'are', 'going', 'to', 'hide', 'their', 'homosexuality', 'in', 'surveys', 'Thus', 'the', '1', '2', 'is', 'a', 'lower', 'limit', 'I', 'still', 'say', 'that', 'weighing', 'all', 'the', 'evidence', 'gives', 'a', 'most', 'likely', 'percentage', 'between', '5', 'and', '7', 'Brian', 'kane', 'buast7', 'astro', 'bu', 'edu', 'Hot', 'Young', 'Star', 'Astronomy', 'Dept', 'Boston', 'University', 'Boston', 'MA', '02215', 'True', 'personal', 'salvation', 'is', 'achieved', 'by', 'absolute', 'faith', 'in', 'ones', 'true', 'self', '']

Actual    Target Values : ['sci.space', 'talk.politics.misc']
Predicted Target Values : ['sci.space', 'talk.politics.misc']
Predicted Probabilities : [0.6578798  0.97715056]
In [14]:
print("SHAP Values Shape : {}".format(shap_values.shape))
print("SHAP Base Values  : {}".format(shap_values.base_values))
print("SHAP Data : ")
print(shap_values.data[0])
print(shap_values.data[1])
SHAP Values Shape : (2, None, 5)
SHAP Base Values  : [[0.19148417 0.21390614 0.20331518 0.20023791 0.19105661]
 [0.19148417 0.21390614 0.20331518 0.20023791 0.19105661]]
SHAP Data :
['from: ' 'stk1203@' 'vax003.' 'stockton.' 'edu\n' 'subject: ' 'internet '
 'resources\n' 'x-' 'added: ' 'forwarded ' 'by ' 'space ' 'digest\n'
 'organization: [' 'via ' 'international ' 'space ' 'university]\n'
 'original-' 'sender: ' 'isu@' 'vacation.' 'venari.' 'cs.' 'cmu.' 'edu\n'
 'distribution: ' 'sci\n' 'lines: ' '4\n\n' 'exit']
['from: ' 'kane@' 'buast7.' 'bu.' 'edu (' 'hot ' 'young ' 'star)\n'
 'subject: ' 're: ' 'new ' 'study ' 'out ' 'on ' 'gay ' 'percentage\n'
 'organization: ' 'astronomy ' 'department, ' 'boston ' 'university, '
 'boston, ' 'ma, ' 'usa\n' 'lines: ' '24\n\n' 'in ' 'article <' '15427@'
 'optilink.' 'com> ' 'cramer@' 'optilink.' 'com (' 'clayton ' 'cramer) '
 'writes:\n\n>' 'homosexuals ' 'lie ' 'about ' 'the ' '10% ' 'number '
 'to ' 'hide ' 'the ' 'disproportionate\n>' 'involvement ' 'of '
 'homosexuals ' 'in ' 'child ' 'molestation.  ' 'they ' 'also ' 'lie\n>'
 'about "' '10%" ' 'to ' 'keep ' 'politicians ' 'scared.\n\n' '1. ' 'you '
 "haven'" 't ' 'shown ' 'any ' 'disproportionate ' 'involvement.\n\n'
 '2. ' 'the ' 'janus ' 'report, ' 'which ' 'came ' 'out ' 'recently, '
 'gives ' '9% ' 'as ' 'the ' 'percentage\n' 'of ' 'exclusively ' 'or '
 'predominantly ' 'gay ' 'men.\n\n' '3. ' 'no ' 'one ' 'is ' 'presumably '
 'going ' 'to ' 'say ' "they'" 're ' 'gay ' 'if ' "they'" 're ' 'not. '
 'but\n' 'some ' 'no ' 'doubt ' 'are ' 'going ' 'to ' 'hide ' 'their '
 'homosexuality ' 'in ' 'surveys. ' 'thus\n' 'the ' '1-' '2% ' 'is ' 'a '
 'lower ' 'limit.\n\n' 'i ' 'still ' 'say ' 'that ' 'weighing ' 'all '
 'the ' 'evidence ' 'gives ' 'a ' 'most ' 'likely ' 'percentage\n'
 'between ' '5 ' 'and ' '7%.\n\n'
 'brian\n------------------------------------------------------------------------------\n'
 'kane@{' 'buast7,' 'astro}.' 'bu.' 'edu (' 'hot ' 'young ' 'star) '
 'astronomy ' 'dept, ' 'boston ' 'university,\n' 'boston, ' 'ma '
 '02215. ' 'true ' 'personal ' 'salvation ' 'is ' 'achieved ' 'by '
 'absolute ' 'faith ' 'in\n' 'ones ' 'true ' 'self']

Text Plots

In this section, we have generated a text plot using the SHAP values. We can easily create a text plot by calling SHAP's text_plot() function with the SHAP values.

The text plot shows the actual text of the document and how much each word contributed, positively or negatively, to a particular prediction. The chart is interactive: we can click on different output categories to see their results, and hovering over a word shows the SHAP value with which it contributed to the prediction. The SHAP values of all words are added to the base value to produce the prediction.

In our case, the first sample is predicted as 'sci.space' with probability 0.658 and the second as 'talk.politics.misc' with probability 0.977. These two categories are highlighted in the chart, and we can click on the other categories to check their results as well. The visualizations highlight words that contributed positively in shades of red and words that contributed negatively in shades of blue. We can notice that both samples contain words that pushed the prediction towards the predicted category.

In [ ]:
shap.text_plot(shap_values)

Bar Charts

In this section, we have created bar charts that can be analyzed to see which words contributed the most to the prediction.

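Below, we have created the first bar chart using the 'sci.space' SHAP values averaged over both explained samples. It helps us see which words contributed the most towards that category. We can notice from the results that expected words like 'space', 'sci', 'international', etc. contributed towards that category's prediction. Please take a look at how we have selected the SHAP values: shap_values[:, :, 'sci.space'] picks that output category by name (possible because we passed output_names to the explainer), and .mean(axis=0) averages it over the samples. We have sorted the bars from highest to lowest SHAP value and shown only the 15 most important ones; the remaining words are summed into a single bar.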
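
The order argument controls how bars are ranked. By default, shap.plots.bar ranks by absolute SHAP value, so a strongly negative word can appear near the top; shap.Explanation.argsort.flip ranks by the raw value from largest to smallest, putting the words that push hardest towards the category first. The plain-Python sketch below mimics both orderings on made-up token values; top_k is a hypothetical helper, not a SHAP function.

```python
def top_k(tokens, values, k, by_abs=False):
    """Return the k (token, value) pairs ranked the way the bar chart would."""
    key = (lambda tv: abs(tv[1])) if by_abs else (lambda tv: tv[1])
    return sorted(zip(tokens, values), key=key, reverse=True)[:k]


tokens = ["space ", "sci\n", "hockey ", "international ", "the "]
values = [0.12, 0.08, -0.15, 0.05, 0.001]  # made-up SHAP values
print(top_k(tokens, values, 3))               # like order=shap.Explanation.argsort.flip
print(top_k(tokens, values, 3, by_abs=True))  # like the default absolute-value order
```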

In [ ]:
shap.plots.bar(shap_values[:,:, selected_categories[preds[0]]].mean(axis=0), max_display=15,
               order=shap.Explanation.argsort.flip)

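Below, we have created a bar chart showing which words and word combinations contributed to predicting category 'sci.space' for the first sample. Because the Partition explainer groups tokens hierarchically, the chart also shows how the words that contributed most towards the prediction are clustered together.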

In [ ]:
shap.plots.bar(shap_values[0,:, selected_categories[preds[0]]], max_display=15,
               order=shap.Explanation.argsort.flip)

Below, we have created a bar chart showing which words contributed the most, on average over the two samples, towards predicting category 'talk.politics.misc'. We can notice that words like 'study', 'molestation', 'homosexuals', etc. contribute more, plausibly because such words are common in political speeches and in discussions of societal problems.

In [ ]:
shap.plots.bar(shap_values[:,:, selected_categories[preds[1]]].mean(axis=0), max_display=15,
               order=shap.Explanation.argsort.flip)

Below, we have created a bar chart showing the combinations of words that contributed to predicting category 'talk.politics.misc' for the second sample.

In [ ]:
shap.plots.bar(shap_values[1,:, selected_categories[preds[1]]], max_display=15,
               order=shap.Explanation.argsort.flip)

Waterfall Chart

In this section, we have created waterfall charts that show how adding the SHAP values of the words to the base value, one at a time, produces the predicted probability.

Below, we have created a waterfall chart for the first sample. It shows how the SHAP values of its words move the output from the base value to the final prediction probability.
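
The waterfall chart can be read as a running sum. As far as we know, SHAP's waterfall plot sorts features by absolute SHAP value, draws the largest max_display - 1 individually when there are more features than max_display, and lumps the rest into a single "other features" bar; each bar starts where the previous one ended, going from the base value to the prediction. A plain-Python sketch of that arithmetic with made-up values (waterfall_steps is a hypothetical helper):

```python
def waterfall_steps(base_value, tokens, values, max_display):
    """Return (label, start, end) bars from the base value, plus the final output."""
    order = sorted(range(len(values)), key=lambda i: abs(values[i]), reverse=True)
    if len(values) <= max_display:
        shown, rest = order, []
    else:
        shown, rest = order[:max_display - 1], order[max_display - 1:]
    bars, current = [], base_value
    for i in shown:
        bars.append((tokens[i], current, current + values[i]))
        current += values[i]
    if rest:  # remaining features collapsed into one bar
        rest_sum = sum(values[i] for i in rest)
        bars.append((f"{len(rest)} other features", current, current + rest_sum))
        current += rest_sum
    return bars, current


tokens = ["space ", "sci\n", "digest\n", "the ", "by "]
values = [0.25, 0.15, 0.05, -0.01, 0.002]  # made-up SHAP values
bars, final = waterfall_steps(0.2, tokens, values, max_display=3)
for bar in bars:
    print(bar)
print("f(x) =", round(final, 3))
```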

In [ ]:
shap.waterfall_plot(shap_values[0][:, selected_categories[preds[0]]], max_display=15)

Below, we have created a waterfall chart for the second sample of data.

In [ ]:
shap.waterfall_plot(shap_values[1][:, selected_categories[preds[1]]], max_display=15)

Force Plot

In this section, we have created a force plot that shows shap values in an additive force layout.

In [22]:
print(selected_categories)
['alt.atheism', 'comp.graphics', 'rec.sport.hockey', 'sci.space', 'talk.politics.misc']

Below, we have created a force plot for the first sample's prediction. To create a force plot, we need to provide the base value, the SHAP values, and the feature names (the words in our case). It is an interactive plot that can be used to analyze which words contributed the most towards the prediction.
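
The force-plot cells label each SHAP value with a word from re.split(r"\W+", ...) and drop the last element with [:-1]. That slice removes the empty string re.split() returns when a document ends with a separator such as a newline, which keeps the labels aligned with SHAP's tokens for these samples. For a document ending in a word character, [:-1] would instead drop a real word. The hypothetical helper below removes the trailing empty string only when it is present and checks the label count against the number of SHAP tokens.

```python
import re


def force_plot_labels(text, n_shap_tokens):
    """Words for force-plot labels, checked against the number of SHAP tokens."""
    words = re.split(r"\W+", text.lower())
    if words and words[-1] == "":  # text ended with a separator (e.g. "\n")
        words = words[:-1]
    if len(words) != n_shap_tokens:
        raise ValueError(f"{len(words)} words vs {n_shap_tokens} SHAP tokens")
    return words


print(force_plot_labels("From: a@b.edu\nexit\n", 5))  # ['from', 'a', 'b', 'edu', 'exit']
```

Called as force_plot_labels(X_batch[0], len(shap_values.data[0])), it would fail loudly instead of silently mislabelling the bars.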

In [ ]:
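import re

# Raw string avoids an invalid-escape warning; [:-1] drops the trailing empty
# string that re.split() returns because the document ends with a separator.
tokens = re.split(r"\W+", X_batch[0].lower())

shap.force_plot(shap_values.base_values[0][preds[0]], shap_values[0][:, preds[0]].values,
                feature_names=tokens[:-1], out_names=selected_categories[preds[0]])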

Below, we have created a force plot for the second sample.

In [ ]:
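import re

# Raw string avoids an invalid-escape warning; [:-1] drops the trailing empty
# string that re.split() returns because the document ends with a separator.
tokens = re.split(r"\W+", X_batch[1].lower())

shap.force_plot(shap_values.base_values[1][preds[1]], shap_values[1][:, preds[1]].values,
                feature_names=tokens[:-1], out_names=selected_categories[preds[1]])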

Visualize SHAP Values For Incorrect Predictions

In this section, we have explained the wrong predictions made by our model. We have created various visualizations that show which words contributed to predicting a particular category.

In order to find our wrong predictions, we have first made predictions for all test samples. Then, we have compared them with actual test target values and retrieved indexes of samples that are predicted wrong. We have then used the first two indexes from those wrong predictions indexes for explanation purposes.

We have made predictions using our model for those two samples and printed the predicted categories along with the original target categories for comparison. We have also printed the probabilities with which our model predicted the wrong category (i.e. how confident the model was). For the first sample, the actual category was 'sci.space' but our model predicted 'comp.graphics', and for the second sample, the actual category was 'comp.graphics' but our model predicted 'sci.space'.

Then, we have generated shap values using Partition explainer by giving data samples to it.

In [25]:
import re

Y_test_preds = model.predict(X_test)
Y_test_preds = Y_test_preds.argmax(axis=1)

wrong_preds = np.argwhere(Y_test!=Y_test_preds)
X_batch = X_test[wrong_preds.flatten()[:2]]

print("Samples : ")
for text in X_batch:
    print(re.split(r"\W+", text))
    print()

preds_proba = model.predict(X_batch)
preds = preds_proba.argmax(axis=1)

print("Actual    Target Values : {}".format([selected_categories[target] for target in Y_test[wrong_preds.flatten()[:2]]]))
print("Predicted Target Values : {}".format([selected_categories[target] for target in preds]))
print("Predicted Probabilities : {}".format(preds_proba.max(axis=1)))

shap_values = explainer([text.lower() for text in X_batch])

shap_values.shape
Samples :
['From', 'rousself', 'cicb', 'fr', 'Frank', 'ROUSSEL', 'Subject', 'FTP', 'images', 'ASTRO', 'server', 'Keywords', 'ftp', 'astronomy', 'images', 'gif', 'Organization', 'CICB', 'Universite', 'de', 'Rennes', 'I', 'FR', 'Lines', '23', 'I', 'commend', 'everybody', 'to', 'look', 'at', 'the', 'FTP', 'site', 'ftp', 'cicb', 'fr', 'Ethernet', 'address', '129', '20', '128', '2', 'in', 'the', 'directory', 'pub', 'Images', 'ASTRO', 'there', 'are', 'lots', 'of', 'images', 'all', 'of', 'kinds', 'in', 'astronomy', 'subject', 'especially', 'in', 'GIF', 'format', 'and', 'a', 'NEW', 'directory', 'of', 'some', 'JPL', 'animations', 'For', 'your', 'comfort', 'README', 'files', 'in', 'all', 'subdirectories', 'give', 'size', 'and', 'description', 'of', 'each', 'image', 'and', 'a', '7', 'days', 'newer', 'images', 'list', 'is', 'in', 'READMENEW', 'Note', 'you', 'can', 'connect', 'it', 'as', 'anonymous', 'or', 'ftp', 'user', 'then', 'the', 'quota', 'for', 'each', 'is', '8', 'users', 'connected', 'in', 'the', 'same', 'time', 'So', 'if', 'the', 'server', 'responds', 'you', 'connection', 'refused', 'be', 'patient', '2nd', 'note', 'this', 'site', 'is', 'reachable', 'by', 'Gopher', 'at', 'roland', 'cicb', 'fr', 'Ethernet', 'address', '129', '20', '128', '27', 'in', 'Divers', 'serveurs', 'Ftp', 'Le', 'serveur', 'ftp', 'du', 'CRI', 'CICB', 'Images', 'ASTRO', 'If', 'you', 'have', 'any', 'comments', 'suggestions', 'problems', 'then', 'you', 'can', 'contact', 'me', 'at', 'E', 'mail', 'rousself', 'univ', 'rennes1', 'fr', 'Hope', 'you', 'enjoy', 'it', '']

['From', 'jackson', 'sandman', 'ece', 'clarkson', 'edu', 'Peter', 'Jackson', 'CH237A', 'Subject', 'Re', 'Where', 'did', 'the', 'hacker', 'ethic', 'go', 'Nntp', 'Posting', 'Host', 'sandman', 'ece', 'clarkson', 'edu', 'Organization', 'Clarkson', 'University', 'Lines', '31', 'From', 'article', '1993May1', '092058', '1', 'aurora', 'alaska', 'edu', 'by', 'pstlb', 'aurora', 'alaska', 'edu', 'I', 'put', 'it', 'to', 'you', 'thus', 'Where', 'HAS', 'the', 'hacker', 'ethic', 'gone', 'If', 'it', 'still', 'exists', 'where', 'And', 'if', 'it', 'DOES', 'exist', 'why', 'are', 'those', 'who', 'call', 'themselves', 'hackers', 'allowing', 'this', 'to', 'perpetuate', 'itself', 'Why', 'are', 'they', 'not', 'creating', 'new', 'innovative', 'interesting', 'ideas', 'to', 'stop', 'the', 'SOS', 'from', 'maintaining', 'its', 'choke', 'hold', 'on', 'the', 'computer', 'industry', 'Since', 'this', 'was', 'posted', 'on', 'comp', 'ai', 'I', 'assume', 'there', 'is', 'an', 'AI', 'angle', 'to', 'this', 'Hacking', 'is', 'what', 'AI', 'students', 'do', 'when', 'they', 're', 'really', 'supposed', 'to', 'be', 'doing', 'something', 'else', 'e', 'g', 'thesis', 'research', 'write', 'up', 'getting', 'their', 'supervisors', 'pet', 'programs', 'to', 'run', 'properly', 'etc', 'No', 'one', 'gets', 'much', 'glory', 'for', 'hacking', 'and', 'no', 'one', 'gets', 'any', 'money', 'out', 'of', 'it', 'Producing', 'good', 'free', 'software', 'requires', 'an', 'enormous', 'investment', 'of', 'time', 'resources', 'that', 'not', 'many', 'people', 'can', 'or', 'want', 'to', 'afford', 'particularly', 'during', 'a', 'recession', 'In', 'addition', 'over', 'the', 'last', '10', 'years', 'I', 'think', 'there', 'has', 'been', 'a', 'de', 'emphasis', 'on', 'producing', 'running', 'programs', 'in', 'AI', 'research', 'and', 'a', 'greater', 'emphasis', 'on', 'more', 'formal', 'approaches', 'to', 'problem', 'solving', 'Students', 'have', 'been', 'proving', 'theorems', 'instead', 'of', 'writing', 'programs', 'At', 'a', 'conference', 
'a', 'year', 'or', 'two', 'ago', 'Johann', 'de', 'Kleer', 'suggested', 'that', 'everyone', 'should', 'Get', 'back', 'to', 'the', 'keyboard', 'and', 'write', 'more', 'programs', 'that', 'demonstrate', 'their', 'ideas', 'and', 'I', 'have', 'to', 'say', 'I', 'm', 'inclined', 'to', 'agree', 'I', 'don', 't', 'claim', 'to', 'be', 'a', 'superhacker', 'but', 'I', 'don', 't', 'think', 'that', 'invalidates', 'my', 'remarks', 'And', 'I', 'm', 'sure', 'this', 'isn', 't', 'the', 'whole', 'story', 'Peter', 'Jackson', 'Dept', 'of', 'Electrical', 'Computer', 'Eng', 'Clarkson', 'University', 'Opinions', 'expressed', 'are', 'not', 'those', 'of', 'my', 'employer', 'or', 'any', 'other', 'organization', 'Second', 'Violin', 'Fiddling', 'Firefighters', 'Ensemble', 'Rome', 'Branch', '']

Actual    Target Values : ['sci.space', 'comp.graphics']
Predicted Target Values : ['comp.graphics', 'sci.space']
Predicted Probabilities : [0.9390678  0.40895548]
Out[25]:
(2, None, 5)
In [26]:
print("SHAP Values Shape : {}".format(shap_values.shape))
print("SHAP Base Values  : {}".format(shap_values.base_values))
SHAP Values Shape : (2, None, 5)
SHAP Base Values  : [[0.19148417 0.21390614 0.20331518 0.20023791 0.19105661]
 [0.19148417 0.21390614 0.20331518 0.20023791 0.19105661]]

Text Plot

Below, we have plotted a text plot showing how the SHAP values of the words contributed to the predicted category. We can notice from the results that the first sample has words like 'ftp', 'ethernet', 'gif', 'images', 'directory', etc. that could have pushed the model towards predicting 'comp.graphics'. For the second sample, the model is not very confident about its prediction: it assigns probability 0.409 to 'sci.space', followed by 0.285 for 'comp.graphics'.
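
One simple way to quantify "not very confident" is the margin between the two largest predicted probabilities; small margins flag close calls that are worth explaining first. A sketch with hypothetical helpers (top2_margin, least_confident) on made-up probability rows:

```python
def top2_margin(probs):
    """Difference between the largest and second-largest class probability."""
    first, second = sorted(probs, reverse=True)[:2]
    return first - second


def least_confident(prob_rows, n):
    """Indices of the n rows with the smallest top-2 margin."""
    return sorted(range(len(prob_rows)), key=lambda i: top2_margin(prob_rows[i]))[:n]


rows = [[0.05, 0.90, 0.02, 0.02, 0.01],  # confident
        [0.10, 0.28, 0.05, 0.41, 0.16],  # close call between classes 3 and 1
        [0.60, 0.30, 0.05, 0.03, 0.02]]  # made-up probabilities
print([round(top2_margin(r), 2) for r in rows])  # [0.85, 0.13, 0.3]
print(least_confident(rows, 1))                  # [1]
```

With the notebook's arrays, least_confident(model.predict(X_test), 10) would list the indices of the ten closest calls in the test set.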

In [ ]:
shap.text_plot(shap_values)

Bar Charts

In this section, we have created various bar charts showing words that contributed to predictions.

Below, we have created a bar chart showing word combinations that contributed towards predicting category 'comp.graphics' for the first sample.

In [ ]:
shap.plots.bar(shap_values[0,:, selected_categories[preds[0]]], max_display=15,
               order=shap.Explanation.argsort.flip)

Below, we have created a bar chart showing words that contributed towards predicting category 'comp.graphics', averaged over both misclassified samples.

In [ ]:
shap.plots.bar(shap_values[:,:, selected_categories[preds[0]]].mean(axis=0), max_display=15,
               order=shap.Explanation.argsort.flip)

Below, we have created a bar chart showing word combinations that contributed towards predicting category 'sci.space' for the second sample.

In [ ]:
shap.plots.bar(shap_values[1,:, selected_categories[preds[1]]], max_display=15,
               order=shap.Explanation.argsort.flip)

Below, we have created a bar chart showing words that contributed towards predicting category 'sci.space', averaged over both misclassified samples.

In [ ]:
shap.plots.bar(shap_values[:,:, selected_categories[preds[1]]].mean(axis=0), max_display=15,
               order=shap.Explanation.argsort.flip)

Force Plots

In this section, we have created force plots showing words' contribution to prediction according to their shap values. We have created a force plot for both samples one after another.

In [32]:
print(selected_categories)
['alt.atheism', 'comp.graphics', 'rec.sport.hockey', 'sci.space', 'talk.politics.misc']
In [ ]:
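import re

# Raw string avoids an invalid-escape warning; [:-1] drops the trailing empty
# string that re.split() returns because the document ends with a separator.
tokens = re.split(r"\W+", X_batch[0].lower())

shap.force_plot(shap_values.base_values[0][preds[0]], shap_values[0][:, preds[0]].values,
                feature_names=tokens[:-1], out_names=selected_categories[preds[0]])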

In [ ]:
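import re

# Raw string avoids an invalid-escape warning; [:-1] drops the trailing empty
# string that re.split() returns because the document ends with a separator.
tokens = re.split(r"\W+", X_batch[1].lower())

shap.force_plot(shap_values.base_values[1][preds[1]], shap_values[1][:, preds[1]].values,
                feature_names=tokens[:-1], out_names=selected_categories[preds[1]])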

This ends our small tutorial explaining how we can use the Python library SHAP to explain predictions made by text classification models created using Keras. Please feel free to let us know your views in the comments section.

Sunny Solanki
