Updated On : Feb-21,2022 Tags shap-values, text-classi…
Explain Flax (JAX) Text Classification Networks using SHAP Values


Deep neural networks nowadays achieve excellent results on many tasks in computer vision, NLP, and other fields. However, these networks are deep and complicated, which makes it hard to understand how they arrive at a particular prediction. We need to know whether our network has generalized and is relying on the parts of the data that should actually drive its predictions. For example, in a simple text classification task such as labeling emails as spam or not spam, we want the model to rely on words that commonly appear in spam emails (free, win, lottery, etc.) when it classifies an email as spam. Interpreting model predictions is therefore becoming a common way to understand models better.

As a part of this tutorial, we have explained how to create a neural network using Flax to solve a text classification task. Then, we have used the SHAP Python library to explain the predictions made by the model using various visualizations. Flax is a high-level framework built on top of JAX that simplifies creating neural networks with JAX. SHAP is a Python library that computes SHAP values using a game-theoretic approach; these values can be visualized to explain the predictions of our models. We assume that the reader has knowledge of neural networks and Flax (JAX). Please feel free to check the below links if you do not have a background in JAX, Flax, and SHAP or want to refresh your knowledge.

Below, we have listed the important sections of the tutorial to give an overview of the material covered.

Important Sections Of Tutorial

  1. Load Data
  2. Vectorize Text Data
  3. Define Model
  4. Define Loss
  5. Train Model
  6. Evaluate Model Performance
  7. Explain Predictions Using SHAP Partition Explainer
    • Visualize SHAP Values For Correct Predictions
      • Text Plot
      • Bar Charts
      • Waterfall Plots
      • Force Plots
    • Visualize SHAP Values For Incorrect Predictions
  8. Explain Predictions Using SHAP Permutation Explainer
    • Visualize SHAP Values For Correct Predictions
    • Visualize SHAP Values For Incorrect Predictions

Below, we have imported the necessary libraries and printed the versions that we have used in our tutorial.

In [1]:
import jax

print("JAX Version : {}".format(jax.__version__))
JAX Version : 0.2.27
In [2]:
import flax

print("FLAX Version : {}".format(flax.__version__))
FLAX Version : 0.3.6
In [3]:
import optax

print("OPTAX Version : {}".format(optax.__version__))
OPTAX Version : 0.1.0
In [4]:
import shap

print("SHAP Version : {}".format(shap.__version__))
SHAP Version : 0.40.0

1. Load Data

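In this section, we have loaded the 20 newsgroups dataset available from scikit-learn. The data has 20 categories in total, but we have selected 5 for this tutorial. We have loaded the dataset using the fetch_20newsgroups() function of the datasets sub-module of scikit-learn. The function lets us load the train and test sets separately. The loader numbers the selected categories in alphabetical order, and our list of selected categories is already sorted alphabetically, so zipping the class indices with selected_categories gives the correct index-to-name mapping.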

In [5]:
import numpy as np
from sklearn import datasets
from jax import numpy as jnp

all_categories = ['alt.atheism','comp.graphics','comp.os.ms-windows.misc','comp.sys.ibm.pc.hardware',
                  'comp.sys.mac.hardware','comp.windows.x', 'misc.forsale','rec.autos','rec.motorcycles',
                  'rec.sport.baseball','rec.sport.hockey','sci.crypt','sci.electronics','sci.med',
                  'sci.space','soc.religion.christian','talk.politics.guns','talk.politics.mideast',
                  'talk.politics.misc','talk.religion.misc']

selected_categories = ['alt.atheism','comp.graphics','rec.sport.baseball','sci.med','talk.politics.guns']

X_train_text, Y_train = datasets.fetch_20newsgroups(subset="train", categories=selected_categories, return_X_y=True)
X_test_text , Y_test  = datasets.fetch_20newsgroups(subset="test", categories=selected_categories, return_X_y=True)

X_train_text = np.array(X_train_text)
X_test_text = np.array(X_test_text)

classes = np.unique(Y_train)
mapping = dict(zip(classes, selected_categories))

len(X_train_text), len(X_test_text), classes, mapping
Out[5]:
(2801,
 1865,
 array([0, 1, 2, 3, 4]),
 {0: 'alt.atheism',
  1: 'comp.graphics',
  2: 'rec.sport.baseball',
  3: 'sci.med',
  4: 'talk.politics.guns'})

2. Vectorize Text Data

In this section, we have vectorized the dataset using the Tf-IDF (Term Frequency - Inverse Document Frequency) approach. The approach splits each text into words and assigns a float value to each word in each document. Words that appear in many documents receive lower values, while words that appear in only a few documents receive higher values.
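To make the weighting concrete, here is a minimal NumPy sketch of the formula that scikit-learn's TfidfVectorizer uses with its default settings (raw term counts, smooth_idf=True, norm='l2'). The smoothed inverse document frequency is idf(t) = ln((1 + n) / (1 + df(t))) + 1, where n is the number of documents and df(t) is the number of documents containing t. Each document's row of tf × idf is then scaled to unit L2 norm. The helper name tfidf_matrix and the toy documents are hypothetical, and splitting on whitespace is a simplification of scikit-learn's default token pattern.

```python
import numpy as np

def tfidf_matrix(docs):
    """Hypothetical helper: TF-IDF with scikit-learn's default (smoothed, L2-normalized) formula."""
    tokenized = [doc.lower().split() for doc in docs]   # simplified whitespace tokenization
    vocab = sorted({tok for toks in tokenized for tok in toks})
    index = {tok: i for i, tok in enumerate(vocab)}

    # Raw term counts (tf): one row per document, one column per word
    tf = np.zeros((len(docs), len(vocab)))
    for row, toks in enumerate(tokenized):
        for tok in toks:
            tf[row, index[tok]] += 1

    n = len(docs)
    df = (tf > 0).sum(axis=0)                    # number of documents containing each word
    idf = np.log((1 + n) / (1 + df)) + 1         # rarer words get larger weights
    weights = tf * idf
    weights /= np.linalg.norm(weights, axis=1, keepdims=True)  # unit L2 norm per document
    return vocab, weights

vocab, weights = tfidf_matrix(["free lottery win", "win the game", "the game tonight"])
print(vocab)
print(weights.round(3))
```

In the first document, "free" and "lottery" appear in only one document while "win" appears in two. As a result, "win" receives a smaller weight than the other two words. If scikit-learn is installed, TfidfVectorizer(token_pattern=r"\S+") applied to the same documents should produce the same matrix.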

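Below, we have used the TfidfVectorizer() estimator available from scikit-learn to vectorize our data. We have limited the vocabulary to the 50,000 most frequent words (max_features=50000). Note that we fit the vectorizer on the train and test texts together, so the vocabulary and IDF weights also reflect the test set. We then convert the resulting sparse matrices to dense JAX arrays of type float16 to keep memory usage down.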
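The max_features cap keeps only the most frequent words. The scikit-learn documentation describes this as keeping the top max_features terms ordered by term frequency across the corpus. Below is a rough sketch of that selection using collections.Counter. The helper name top_k_vocabulary and the toy corpus are hypothetical, and tie-breaking between equally frequent words may differ from scikit-learn.

```python
from collections import Counter

def top_k_vocabulary(docs, k):
    """Hypothetical helper: keep the k words with the highest total count in the corpus."""
    counts = Counter(tok for doc in docs for tok in doc.lower().split())
    kept = [word for word, _ in counts.most_common(k)]   # highest total frequency first
    return sorted(kept)                                   # columns are ordered alphabetically

docs = ["the game the score", "the team won the game", "a new graphics card"]
print(top_k_vocabulary(docs, k=2))   # ['game', 'the']
```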

If you want to know how text vectorization works internally then please feel free to check the below link where we have discussed Tf-IDF and other approaches.

If you don't have a background with text classification then we suggest that you go through the below link as it'll help you with this tutorial.

In [6]:
import sklearn
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer

vectorizer = TfidfVectorizer(max_features=50000)

vectorizer.fit(np.concatenate((X_train_text, X_test_text)))
X_train = vectorizer.transform(X_train_text)
X_test = vectorizer.transform(X_test_text)

X_train = jnp.array(X_train.toarray(), dtype=jnp.float16)
X_test  = jnp.array(X_test.toarray(), dtype=jnp.float16)

X_train.shape, X_test.shape
Out[6]:
((2801, 50000), (1865, 50000))
In [7]:
import gc

gc.collect()
Out[7]:
21

3. Define Model

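In this section, we have created a network that we'll use to classify text data. The network has 3 linear (Dense) layers with 128, 64, and 5 (the number of classes/categories) output units. We have applied relu (rectified linear unit) activation after the first two layers. The network returns raw logits; softmax is applied later, inside the loss function and when computing probabilities.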
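The computation performed by the three Dense layers can be written out directly. Each layer computes x @ W + b, and ReLU clamps negative values to zero. Below is a NumPy sketch with randomly initialized weights (not the trained Flax parameters). It uses a hypothetical input width of 1,000 instead of 50,000 to show how the shapes flow from a batch of TF-IDF rows to 5 logits per sample.

```python
import numpy as np

rng = np.random.default_rng(0)
n_features, n_classes = 1000, 5          # hypothetical; the tutorial uses 50000 features

# Same layer widths as TextClassifier: n_features -> 128 -> 64 -> n_classes
layer_sizes = [n_features, 128, 64, n_classes]
params = [(rng.normal(0, 0.01, (d_in, d_out)), np.zeros(d_out))
          for d_in, d_out in zip(layer_sizes[:-1], layer_sizes[1:])]

def forward(x, params):
    """Dense -> ReLU -> Dense -> ReLU -> Dense, returning raw logits."""
    for i, (w, b) in enumerate(params):
        x = x @ w + b
        if i < len(params) - 1:          # no activation after the last layer
            x = np.maximum(x, 0.0)
    return x

batch = rng.random((5, n_features))      # 5 fake TF-IDF rows
logits = forward(batch, params)
print(logits.shape)                      # (5, 5)
```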

In the next cell, we have initialized the network and its parameters. Then, we have printed the shapes of the weights and biases of each layer. In the cell after that, we have performed a forward pass through the network to make predictions as well.

Please feel free to go through the below tutorial if you don't have a background on Flax and want to learn it to design neural networks.

In [8]:
from flax import linen
from jax import random

class TextClassifier(linen.Module):
    def setup(self):
        self.linear1 = linen.Dense(features=128, name="DENSE1")
        self.linear2 = linen.Dense(features=64, name="DENSE2")
        self.linear3 = linen.Dense(len(classes), name="DENSE3")

    def __call__(self, inputs):
        x = linen.relu(self.linear1(inputs))
        x = linen.relu(self.linear2(x))
        logits = self.linear3(x)

        return logits ## Raw logits; softmax is applied in the loss and prediction code
In [9]:
seed = jax.random.PRNGKey(0)

model = TextClassifier()
params = model.init(seed, X_train[:5])

for layer_params in params["params"].items():
    print("Layer Name : {}".format(layer_params[0]))
    weights, biases = layer_params[1]["kernel"], layer_params[1]["bias"]
    print("\tLayer Weights : {}, Biases : {}".format(weights.shape, biases.shape))
Layer Name : DENSE1
	Layer Weights : (50000, 128), Biases : (128,)
Layer Name : DENSE2
	Layer Weights : (128, 64), Biases : (64,)
Layer Name : DENSE3
	Layer Weights : (64, 5), Biases : (5,)
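The printed shapes also show how large the model is. Each Dense layer has in_features × out_features weights plus out_features biases. The short calculation below copies the shapes printed above into a dictionary and adds up the parameters. Almost all of the roughly 6.4 million parameters sit in the first layer because of the 50,000-word vocabulary.

```python
# Kernel shapes copied from the printed output above
layer_shapes = {"DENSE1": (50000, 128), "DENSE2": (128, 64), "DENSE3": (64, 5)}

total = 0
for name, (n_in, n_out) in layer_shapes.items():
    n_params = n_in * n_out + n_out      # weights + biases
    total += n_params
    print(f"{name}: {n_params:,}")
print(f"Total: {total:,}")               # Total: 6,408,709
```

With the real Flax parameters, sum(p.size for p in jax.tree_util.tree_leaves(params)) should give the same total.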
In [10]:
preds = model.apply(params, X_train[:5])

preds.shape
Out[10]:
(5, 5)

4. Define Loss

In this section, we have defined the cross-entropy loss function that we'll use during training. The function takes network parameters, input features, and target values as input. It performs a forward pass through the network using the parameters and input features to make predictions (logits). Then, it one-hot encodes the actual target values and calculates the loss using the softmax_cross_entropy() function available from the Optax library. softmax_cross_entropy() returns one loss value per sample, and we sum these values over the batch.
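The per-sample quantity is -Σ labels · log_softmax(logits), and CrossEntropyLoss sums it over the batch. The NumPy sketch below computes the same thing; the helper name is hypothetical. It also explains why the training losses printed later are in the hundreds: each is a sum over a batch of up to 256 samples, not a per-sample average.

```python
import numpy as np

def summed_softmax_cross_entropy(logits, targets, num_classes):
    """Hypothetical NumPy equivalent of CrossEntropyLoss: per-sample loss, summed over the batch."""
    one_hot = np.eye(num_classes)[targets]                       # like jax.nn.one_hot
    shifted = logits - logits.max(axis=1, keepdims=True)         # subtract max for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))  # log_softmax
    per_sample = -(one_hot * log_probs).sum(axis=1)
    return per_sample.sum()

logits = np.zeros((4, 5))       # uninformative logits: every class gets probability 1/5
targets = np.array([0, 1, 2, 3])
print(summed_softmax_cross_entropy(logits, targets, num_classes=5))   # 4 * ln(5), about 6.438
```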

In [11]:
def CrossEntropyLoss(weights, input_data, actual):
    logits_preds = model.apply(weights, input_data)
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(classes))
    return optax.softmax_cross_entropy(logits=logits_preds, labels=one_hot_actual).sum()

5. Train Model

In this section, we have trained our network. To do so, we have created the function below. It takes training data (X, Y), validation data (X_val, Y_val), the number of epochs, model parameters, optimizer state, and batch size as input. For each epoch, it loops through the whole training data in batches. For each batch, it calculates the loss and gradients and updates the network parameters using the gradients. It records the loss of each batch and prints the mean at the end of the epoch, together with the accuracy on the validation data. The function uses the global optimizer object that we create before calling it.
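The batching step only needs the start and end index of each slice. The sketch below shows that bookkeeping using a hypothetical helper, batch_bounds. With 2,801 training samples and a batch size of 256, it yields 11 batches, the last of which holds the remaining 241 samples. This matches the 11 iterations per epoch in the training log. Computing the number of batches by ceiling division also means no empty trailing batch is produced when the sample count is an exact multiple of the batch size.

```python
def batch_bounds(n_samples, batch_size):
    """Hypothetical helper: (start, end) slice bounds covering all samples in order."""
    return [(start, min(start + batch_size, n_samples))
            for start in range(0, n_samples, batch_size)]

bounds = batch_bounds(2801, 256)
print(len(bounds), bounds[-1])            # 11 (2560, 2801)
print(len(batch_bounds(512, 256)))        # 2, with no empty trailing batch
```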

In [12]:
from jax import value_and_grad
from tqdm import tqdm
from sklearn.metrics import accuracy_score

def TrainModelInBatches(X, Y, X_val, Y_val, epochs, weights, optimizer_state, batch_size=32):
    ## Relies on the global `optimizer` created before this function is called
    n_batches = int(np.ceil(X.shape[0] / batch_size)) ## Ceiling division: no empty trailing batch

    for i in range(1, epochs+1):
        losses = [] ## Record loss of each batch
        for batch in tqdm(range(n_batches)):
            start, end = batch*batch_size, (batch+1)*batch_size ## Last slice may be shorter
            X_batch, Y_batch = X[start:end], Y[start:end] ## Single batch of data

            loss, gradients = value_and_grad(CrossEntropyLoss)(weights, X_batch, Y_batch)

            ## Update Weights
            updates, optimizer_state = optimizer.update(gradients, optimizer_state)
            weights = optax.apply_updates(weights, updates)

            losses.append(loss) ## Record Loss

        print("CrossEntropyLoss : {:.3f}".format(jnp.array(losses).mean()))

        ## Evaluate on the validation data at the end of each epoch
        Y_val_preds = model.apply(weights, X_val)
        val_acc = accuracy_score(Y_val, jnp.argmax(Y_val_preds, axis=1))
        print("Validation  Accuracy : {:.3f}".format(val_acc))

    return weights

In the below cell, we have actually trained our network. We have set the batch size to 256, the number of epochs to 8, and the learning rate to 0.001. Then, we have initialized the classifier network and its parameters. Next, we have initialized the Adam optimizer and called our training function with the data and the required parameters. Note that the test set doubles as the validation set here. The printed results show that the model learns well: the loss falls from 397.082 after the first epoch to 4.686 after the eighth, while validation accuracy rises from 0.827 to 0.955.
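The NumPy sketch below shows what the Adam optimizer does with each gradient. It implements one step of the standard Adam update rule, using the defaults documented for optax.adam() (b1=0.9, b2=0.999, eps=1e-8). The function name adam_step is hypothetical; optax itself expresses Adam as a chain of gradient transformations whose updates are then applied with optax.apply_updates(). On the very first step, the bias-corrected update is close to learning_rate × sign(gradient). As a result, each weight moves by roughly 0.001 regardless of its gradient's magnitude.

```python
import numpy as np

def adam_step(param, grad, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """Hypothetical helper: one Adam update; t is the 1-based step count."""
    m = b1 * m + (1 - b1) * grad            # running mean of gradients
    v = b2 * v + (1 - b2) * grad ** 2       # running mean of squared gradients
    m_hat = m / (1 - b1 ** t)               # bias correction for zero-initialized moments
    v_hat = v / (1 - b2 ** t)
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v

param = np.array([0.5, -0.2, 1.0])
grad = np.array([3.0, -0.01, 0.0])
new_param, m, v = adam_step(param, grad, np.zeros(3), np.zeros(3), t=1)
print(new_param - param)                 # approximately [-0.001, 0.001, 0.0]
```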

In [13]:
seed = random.PRNGKey(0)
batch_size=256
epochs=8
learning_rate = jnp.array(1/1e3)

model = TextClassifier()
weights = model.init(seed, X_train[:5])

optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(weights)

final_weights = TrainModelInBatches(X_train, Y_train, X_test, Y_test, epochs, weights, optimizer_state, batch_size=batch_size)
100%|██████████| 11/11 [00:06<00:00,  1.77it/s]
CrossEntropyLoss : 397.082
Validation  Accuracy : 0.827
100%|██████████| 11/11 [00:03<00:00,  3.48it/s]
CrossEntropyLoss : 331.718
Validation  Accuracy : 0.875
100%|██████████| 11/11 [00:03<00:00,  3.37it/s]
CrossEntropyLoss : 225.360
Validation  Accuracy : 0.923
100%|██████████| 11/11 [00:03<00:00,  3.50it/s]
CrossEntropyLoss : 111.657
Validation  Accuracy : 0.944
100%|██████████| 11/11 [00:03<00:00,  3.38it/s]
CrossEntropyLoss : 42.300
Validation  Accuracy : 0.949
100%|██████████| 11/11 [00:03<00:00,  3.55it/s]
CrossEntropyLoss : 16.382
Validation  Accuracy : 0.951
100%|██████████| 11/11 [00:03<00:00,  3.42it/s]
CrossEntropyLoss : 7.878
Validation  Accuracy : 0.953
100%|██████████| 11/11 [00:03<00:00,  3.53it/s]
CrossEntropyLoss : 4.686
Validation  Accuracy : 0.955

6. Evaluate Model Performance

In this section, we have evaluated the performance of our network by calculating the accuracy, the classification report (precision, recall, and f1-score per target class/category), and the confusion matrix on the test predictions. The classification report shows that all target categories are predicted well; 'alt.atheism' has the lowest recall (0.92). According to the confusion matrix chart created in the cell after the next one, a few 'alt.atheism' samples are confused with the 'sci.med' category.

We have used scikit-learn functions to calculate various metrics below. If you want to know about them in detail then please check the below link that covers the majority of ML metrics available from scikit-learn.

The confusion matrix chart is created using the Python library scikit-plot. It lets us create charts for many other ML metrics as well. Please check the below link if you want to learn about it.
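The accuracy and the per-class precision and recall can all be derived from the confusion matrix. The NumPy sketch below builds the matrix from integer labels; the helper name and the toy labels are hypothetical. It then reads accuracy off the diagonal, recall off the rows, and precision off the columns. Finally, it row-normalizes the matrix, so that each row shows how one true category's samples are distributed over the predicted categories.

```python
import numpy as np

def confusion(y_true, y_pred, n_classes):
    """Hypothetical helper: rows are true classes, columns are predicted classes."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)      # count each (true, predicted) pair
    return cm

y_true = np.array([0, 0, 0, 1, 1, 2, 2, 2])
y_pred = np.array([0, 0, 2, 1, 1, 2, 2, 0])
cm = confusion(y_true, y_pred, n_classes=3)

accuracy = np.trace(cm) / cm.sum()                 # correct predictions lie on the diagonal
recall = np.diag(cm) / cm.sum(axis=1)              # per true class (row)
precision = np.diag(cm) / cm.sum(axis=0)           # per predicted class (column)
normalized = cm / cm.sum(axis=1, keepdims=True)    # each row sums to 1
print(cm)
print(accuracy, recall, precision)
```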

In [14]:
from sklearn.metrics import accuracy_score, classification_report

train_preds = model.apply(final_weights, X_train)
test_preds = model.apply(final_weights, X_test)

print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, np.argmax(train_preds, axis=1))))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=selected_categories))
Train Accuracy : 1.000
Test  Accuracy : 0.955

Classification Report :
                    precision    recall  f1-score   support

       alt.atheism       0.97      0.92      0.94       319
     comp.graphics       0.93      0.96      0.95       389
rec.sport.baseball       0.97      0.99      0.98       397
           sci.med       0.95      0.94      0.94       396
talk.politics.guns       0.97      0.97      0.97       364

          accuracy                           0.96      1865
         macro avg       0.96      0.95      0.96      1865
      weighted avg       0.96      0.96      0.96      1865

In [ ]:
from sklearn.metrics import confusion_matrix
import scikitplot as skplt
import matplotlib.pyplot as plt

skplt.metrics.plot_confusion_matrix([selected_categories[i] for i in Y_test], [selected_categories[i] for i in np.argmax(test_preds, axis=1)],
                                    normalize=True,
                                    title="Confusion Matrix",
                                    cmap="Purples",
                                    hide_zeros=True,
                                    figsize=(5,5)
                                    );
plt.xticks(rotation=90);


7. SHAP Partition Explainer

In this section, we have used the Partition explainer available from the SHAP library to explain the predictions made by our model. The Partition explainer calculates SHAP values by recursively evaluating groups of features defined by a hierarchy; for text, the masker builds this hierarchy over the sequence of tokens. We have created text plots highlighting words that contribute positively/negatively to predictions.
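To make the underlying quantity concrete: a SHAP value is a Shapley value, the average marginal contribution of a feature over all orderings in which the features can be revealed. For a handful of words, this can be computed exactly by brute force, as in the sketch below. The toy scoring function and its word weights are hypothetical. The Partition explainer does not enumerate all orderings like this; it approximates the attributions using its hierarchy of token groups, which keeps the cost manageable for long documents. In both cases, the values add up to the difference between the model output on the full input and on the fully masked input.

```python
from itertools import permutations
from math import factorial

def shapley_values(features, value_fn):
    """Exact Shapley values by averaging marginal contributions over all orderings."""
    contrib = {f: 0.0 for f in features}
    for order in permutations(features):
        present = set()
        for f in order:
            before = value_fn(present)
            present = present | {f}
            contrib[f] += value_fn(present) - before
    n_orders = factorial(len(features))
    return {f: c / n_orders for f, c in contrib.items()}

# Hypothetical "model": a spam score computed from the words that are not masked,
# with an interaction bonus when "free" and "win" appear together.
word_weights = {"free": 0.3, "win": 0.2, "meeting": -0.1}
def spam_score(present):
    score = sum(word_weights[w] for w in present)
    if {"free", "win"} <= present:
        score += 0.2
    return score

words = ["free", "win", "meeting"]
phi = shapley_values(words, spam_score)
print(phi)   # free about 0.4, win about 0.3, meeting about -0.1
```

The 0.2 interaction bonus is split equally between "free" and "win", and the three values sum to spam_score of all words minus spam_score of no words (0.6).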

To render SHAP's interactive JavaScript-based plots in a notebook, we first need to call the initjs() function of the library.

Then, we have created an explainer object. Calling it on text samples later returns Explanation objects holding the SHAP values for those predictions. To create an explainer, we need to give it a model/function, a masker, and output category names. In our case, we have designed a function that takes a few samples of text data as input and returns predicted probabilities for them. The function vectorizes the data and performs a forward pass through the network to make predictions. The masker is needed so that, once the text is tokenized into a list of words, SHAP knows which parts of the text to hide and how to map SHAP values back to words. We have created the explainer using the Explainer() constructor, which selected the Partition explainer for our text masker (see the output below). We could also use the Partition() constructor directly, as it takes the same parameters.

The masker object is created with the regular expression r"\W+", which matches runs of non-word characters (spaces, punctuation, etc.). Internally, the masker builds a tokenizer from this pattern that treats each match as a separator. The tokenizer returns a dictionary with the below keys and values.

  • 'input_ids' - List of tokens, i.e. the strings found between separators (the words). These are the units that get masked and receive SHAP values.
  • 'offset_mapping' - List of tuples of 2 values specifying the start and end character indexes of each string in 'input_ids'.

We can also create our own tokenizer and provide it to the masker.
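The sketch below shows a regex tokenizer that produces the dictionary format described above. The function name regex_tokenize is hypothetical, and this is not SHAP's own code. The pattern matches the separators, and the tokens are the pieces of text between them, each with its (start, end) character offsets. We have not checked which other methods shap.maskers.Text expects from a custom tokenizer object, so consult the SHAP documentation before passing one in.

```python
import re

def regex_tokenize(text, split_pattern=r"\W+"):
    """Hypothetical tokenizer: split on separator matches, return tokens and their offsets."""
    input_ids, offset_mapping = [], []
    pos = 0
    for match in re.finditer(split_pattern, text):
        start, end = match.span()
        if start > pos:                          # skip empty tokens (e.g. leading separators)
            input_ids.append(text[pos:start])
            offset_mapping.append((pos, start))
        pos = end
    if pos < len(text):                          # trailing token with no separator after it
        input_ids.append(text[pos:])
        offset_mapping.append((pos, len(text)))
    return {"input_ids": input_ids, "offset_mapping": offset_mapping}

print(regex_tokenize("Subject: Re: ATF BURNS!"))
```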

In [ ]:
shap.initjs()


In [17]:
def make_predictions(X_batch_text):
    X_batch = jnp.array(vectorizer.transform(X_batch_text).toarray())
    logits_preds = model.apply(final_weights, X_batch)
    return linen.softmax(logits_preds)

masker = shap.maskers.Text(tokenizer=r"\W+")
explainer = shap.Explainer(make_predictions, masker=masker, output_names=selected_categories)

explainer
Out[17]:
<shap.explainers._partition.Partition at 0x7fcb12afc2d0>

Visualize SHAP Values For Correct Predictions

In this section, we have created SHAP values for a few correct predictions and visualized them to better understand which words contributed to predictions.

We have first selected 2 data samples and printed their content. Then, we have made predictions on them using our trained network and printed the original categories, predicted categories, and probabilities of the predictions. The two selected samples belong to the 'talk.politics.guns' and 'comp.graphics' categories.

Finally, we have calculated SHAP values by calling the explainer object with the raw text samples.

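In the next cell, we have printed the shape of the SHAP values, the base values, and the tokens (words) stored in the object. The base values are the model's predicted probabilities for a fully masked input. They are therefore the same for both samples and sum to 1 across the 5 categories.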

In [18]:
import re

X_batch_text = X_test_text[3:5]
X_batch = X_test[3:5]

print("Samples : ")
for text in X_batch_text:
    print(re.split(r"\W+", text))
    print()

logits = model.apply(final_weights, X_batch)
preds_proba = linen.softmax(logits)
preds = preds_proba.argmax(axis=1)

print("Actual    Target Values : {}".format([selected_categories[target] for target in Y_test[3:5]]))
print("Predicted Target Values : {}".format([selected_categories[target] for target in preds]))
print("Predicted Probabilities : {}".format(preds_proba.max(axis=1)))

shap_values = explainer(X_batch_text)
Samples :
['From', 'cdt', 'sw', 'stratus', 'com', 'C', 'D', 'Tavares', 'Subject', 'Re', 'ATF', 'BURNS', 'DIVIDIAN', 'RANCH', 'NO', 'SURVIVORS', 'Organization', 'Stratus', 'Computer', 'Inc', 'Lines', '61', 'Distribution', 'world', 'NNTP', 'Posting', 'Host', 'rocket', 'sw', 'stratus', 'com', 'Keywords', 'Nata', 'thing', 'In', 'article', '1993Apr21', '053035', '29591', 'mcs', 'kent', 'edu', 'mhamilto', 'Nimitz', 'mcs', 'kent', 'edu', 'The', 'Lawnmowerman', 'writes', 'Oh', 'then', 'I', 'guess', 'that', 'shooting', 'THOSE', 'kind', 'of', 'babies', 'is', 'all', 'right', 'You', 'sick', 'bastard', 'Why', 'thanks', 'for', 'your', 'reply', 'to', 'my', 'post', 'By', 'the', 'way', 'I', 'never', 'never', 'ever', 'said', 'that', 'it', 'was', 'right', 'to', 'shoot', 'THOSE', 'kind', 'of', 'babies', 'Let', 's', 'go', 'to', 'the', 'videotape', 'Here', 's', 'exactly', 'what', 'you', 'posted', 'THIS', 'IS', 'MURDER', 'ATF', 'MURDERERS', 'BUTCHERS', 'Flame', 'on', 'Is', 'this', 'guy', 'serious', 'If', 'he', 'would', 'ever', 'really', 'pay', 'attention', 'to', 'the', 'news', 'oops', 'I', 'forgot', 'that', 'the', 'media', 'for', 'the', 'most', 'part', 'loves', 'to', 'jump', 'right', 'on', 'top', 'of', 'a', 'story', 'before', 'all', 'the', 'facts', 'are', 'known', 'as', 'well', 'as', 'to', 'manipulate', 'what', 'we', 'see', 'and', 'thus', 'what', 'we', 'believe', 'Any', 'ways', 'one', 'of', 'Koresh', 's', 'DEVOTED', 'followers', 'that', 'DID', 'I', 'REPEAT', 'DID', 'survive', 'this', 'GENOCIDAL', 'MASS', 'SLAUGHTER', 'OF', 'INNOCENT', 'PEOPLE', 'Besides', 'there', 'are', 'nine', 'survivors', 'in', 'the', 'burn', 'unit', 'of', 'the', 'local', 'hospital', 'and', 'was', 'reported', 'that', 'David', 'was', 'in', 'one', 'of', 'the', 'towers', 'when', 'the', 'shit', 'hit', 'the', 'fan', 'Besides', 'a', 'majority', 'of', 'these', 'children', 'were', 'children', 'that', 'he', 'was', 'supposed', 'to', 'have', 'been', 'the', 'father', 'of', 'this', 'then', 'makes', 'them', 'bastard', 'children', 'to', 'a', 
'sacraligious', 'zeloit', 'sp', 'Also', 'someone', 'should', 'have', 'told', 'David', 'and', 'his', 'followers', 'that', 'if', 'they', 'can', 't', 'the', 'heat', 'then', 'they', 'should', 'stay', 'out', 'of', 'the', 'kitchen', 'pun', 'intended', 'So', 'if', 'you', 'weren', 't', 'saying', 'that', 'the', 'deaths', 'of', 'bastard', 'children', 'of', 'a', 'sacreligious', 'zealot', 'was', 'no', 'big', 'thing', 'and', 'I', 'know', 'I', 'am', 'not', 'the', 'only', 'one', 'who', 'read', 'it', 'that', 'way', 'just', 'what', 'the', 'hell', 'WERE', 'you', 'saying', 'However', 'it', 'was', 'the', 'Branch', 'Davidian', 'people', 'in', 'there', 'that', 'insisted', 'on', 'staying', 'there', 'with', 'their', 'savior', 'yeah', 'right', 'budy', 'boy', 'because', 'he', 'had', 'brain', 'washed', 'them', 'into', 'believing', 'that', 'what', 'ever', 'he', 'says', 'is', 'the', 'truth', 'even', 'if', 'means', 'that', 'they', 'are', 'to', 'give', 'up', 'their', 'lives', 'for', 'HIS', 'cause', 'Therefore', 'it', 'is', 'Davids', 'fault', 'and', 'not', 'the', 'ATF', 's', 'who', 'gave', 'them', '50', 'to', '51', 'days', 'to', 'get', 'out', 'this', 'was', '50', 'days', 'to', 'many', 'for', 'me', 'and', 'for', 'many', 'of', 'the', 'rest', 'of', 'the', 'U', 'S', 'So', 'who', 'brainwashed', 'YOU', 'into', 'believing', 'that', 'whatever', 'the', 'government', 'says', 'it', 'the', 'truth', 'Or', 'that', 'Koresh', 'was', 'any', 'actual', 'threat', 'to', 'you', 'and', 'the', 'rest', 'of', 'the', 'US', 'Haven', 't', 'you', 'ever', 'stopped', 'to', 'wonder', 'WHY', 'the', 'government', 'raided', 'this', 'farm', 'This', 'raid', 'was', 'NOT', 'ABOUT', 'RELIGION', 'SEX', 'OR', 'CHILD', 'ABUSE', 'To', 'the', 'best', 'available', 'evidence', 'at', 'the', 'moment', 'this', 'raid', 'psychological', 'torture', 'and', 'group', 'death', 'was', 'about', 'Koresh', 's', 'failure', 'to', 'PAY', 'A', '200', 'TAX', 'on', 'ONE', 'grenade', 'launcher', 'that', 'he', 'may', 'or', 'may', 'not', 'even', 'have', 'had', 
'That', 's', 'why', 'it', 'was', 'the', 'BATF', 'doing', 'the', 'raiding', 'in', 'the', 'first', 'place', 'Do', 'you', 'think', 'a', '200', 'tax', 'evasion', 'justifies', 'an', 'armed', 'assault', 'by', 'the', 'government', 'with', '100', 'armed', 'men', 'in', 'trailers', 'orchestrated', 'character', 'assassination', 'noise', 'torture', 'and', 'a', 'holocaust', 'I', 'am', 'however', 'sad', 'to', 'hear', 'of', 'the', 'death', 'of', 'any', 'child', 'unlike', 'the', 'sick', 'bastard', 'I', 'supposedly', 'am', 'Then', 'think', 'before', 'you', 'post', 'cdt', 'rocket', 'sw', 'stratus', 'com', 'If', 'you', 'believe', 'that', 'I', 'speak', 'for', 'my', 'company', 'OR', 'cdt', 'vos', 'stratus', 'com', 'write', 'today', 'for', 'my', 'special', 'Investors', 'Packet', '']

['From', 'iharkest', 'Lise', 'Unit', 'NO', 'Inge', 'Harkestad', 'Subject', 'Tangent', 'vectors', 'of', 'Kochanek', 'Bartels', 'splines', 'Organization', 'University', 'of', 'Trondheim', 'Lines', '33', 'I', 'm', 'working', 'on', 'a', 'system', 'which', 'uses', 'a', 'given', 'set', 'of', '3D', 'key', 'frame', 'positions', 'x', 'y', 'z', 'to', 'control', 'an', 'imaginary', 'camera', 'movement', 'I', 'm', 'using', 'Kochanek', 'Bartels', 'splines', 'as', 'described', 'in', 'the', 'SIGGRAPH', '84', 'proceedings', 'to', 'create', 'a', 'variable', 'number', 'of', 'inbetweens', 'between', 'the', 'key', 'frames', 'I', 'want', 'the', 'inbetweens', 'to', 'be', 'given', 'in', 'the', 'form', 'x', 'y', 'z', 'dx', 'dy', 'dz', 'where', 'the', 'last', 'three', 'argumentsa', 'are', 'the', 'x', 'y', 'and', 'z', 'component', 'of', 'the', 'viewing', 'direction', 'vector', 'of', 'the', 'camera', 'when', 'positioned', 'at', 'x', 'y', 'z', 'The', 'method', 'presented', 'by', 'Kochanek', 'and', 'Bartels', 'only', 'deals', 'with', 'the', 'positions', 'of', 'the', 'inbetween', 'view', 'points', 'to', 'be', 'generated', 'I', 've', 'tried', 'to', 'set', 'the', 'viewing', 'direction', 'at', 'a', 'view', 'point', 'equal', 'to', 'the', 'chord', 'between', 'the', 'two', 'adjacent', 'view', 'points', 'which', 'in', 'general', 'are', 'not', 'key', 'frames', 'but', 'this', 'causes', 'a', 'sligt', 'discontinuity', 'of', 'the', 'viewing', 'direction', 'vector', 'at', 'the', 'key', 'frame', 'positions', 'although', 'the', 'spatial', 'movement', 'seems', 'to', 'work', 'fine', 'and', 'I', 'm', 'quite', 'certain', 'I', 'think', 'that', 'I', 've', 'not', 'simply', 'made', 'an', 'implementation', 'error', 'Now', 'I', 'wonder', 'if', 'anyone', 'out', 'there', 'has', 'used', 'this', 'spline', 'form', 'for', 'similar', 'purposes', 'and', 'how', 'they', 'decided', 'the', 'viewing', 'vectors', 'I', 'd', 'appreciate', 'replies', 'to', 'be', 'emailed', 'to', 'me', 'at', 'iharkest', 'lise', 'unit', 'no', 'Anyone', 
'else', 'interested', 'in', 'the', 'answer', 'will', 'be', 'sent', 'a', 'summary', 'of', 'the', 'replies', 'if', 'they', 'contact', 'me', '_________', '__________________', 'N', 'G', 'E', 'A', 'R', 'K', 'E', 'S', 'T', 'A', 'D', 'iharkest', 'lise', 'unit', 'no', 'Comp', 'Sc', '________', '_______________', 'NTH', 'Norwegian', 'Institute', 'of', 'Technology', 'Some', 'people', 'play', 'hard', 'to', 'get', 'I', 'play', 'hard', 'to', 'want', 'Ford', 'Fairlane', '']

Actual    Target Values : ['talk.politics.guns', 'comp.graphics']
Predicted Target Values : ['talk.politics.guns', 'comp.graphics']
Predicted Probabilities : [0.9969908 0.9471911]
Partition explainer: 3it [00:12, 12.14s/it]
In [19]:
print("SHAP Values Shape : {}".format(shap_values.shape))
print("SHAP Base Values  : {}".format(shap_values.base_values))
print("SHAP Data : ")
print(shap_values.data[0])
print(shap_values.data[1])
SHAP Values Shape : (2, None, 5)
SHAP Base Values  : [[0.14802793 0.25942031 0.21729104 0.22140868 0.15385203]
 [0.14802793 0.25942031 0.21729104 0.22140868 0.15385203]]
SHAP Data :
['From: ' 'cdt@' 'sw.' 'stratus.' 'com (' 'C. ' 'D. ' 'Tavares)\n'
 'Subject: ' 'Re: ' 'ATF ' 'BURNS ' 'DIVIDIAN ' 'RANCH! ' 'NO '
 'SURVIVORS!!!\n' 'Organization: ' 'Stratus ' 'Computer, ' 'Inc.\n'
 'Lines: ' '61\n' 'Distribution: ' 'world\n' 'NNTP-' 'Posting-' 'Host: '
 'rocket.' 'sw.' 'stratus.' 'com\n' 'Keywords: ' 'Nata ' 'thing !!\n\n'
 'In ' 'article <' '1993Apr21.' '053035.' '29591@' 'mcs.' 'kent.' 'edu>, '
 'mhamilto@' 'Nimitz.' 'mcs.' 'kent.' 'edu (' 'The ' 'Lawnmowerman) '
 'writes:\n\n> > ' 'Oh, ' 'then, ' 'I ' 'guess ' 'that ' 'shooting '
 'THOSE ' 'kind ' 'of ' 'babies ' 'is ' 'all ' 'right.\n> > ' 'You '
 'sick ' 'bastard.\n\n> ' 'Why ' 'thanks ' 'for ' 'your ' 'reply ' 'to '
 'my ' 'post.  ' 'By ' 'the ' 'way, ' 'I ' 'never, ' 'never ' 'ever '
 'said \n> ' 'that ' 'it ' 'was ' 'right ' 'to ' 'shoot "' 'THOSE '
 'kind" ' 'of ' 'babies.  \n\n' "Let'" 's ' 'go ' 'to ' 'the '
 'videotape.  ' "Here'" 's ' 'exactly ' 'what ' 'you ' 'posted:\n\n> > '
 'THIS ' 'IS ' 'MURDER!\n> > ' 'ATF ' 'MURDERERS!  ' 'BUTCHERS!!\n\n> '
 'Flame ' 'on!!\n> ' 'Is ' 'this ' 'guy ' 'serious????\n> ' 'If ' 'he '
 'would ' 'ever ' 'really ' 'pay ' 'attention ' 'to ' 'the ' 'news ('
 'oops ' 'I ' 'forgot ' 'that ' 'the ' 'media\n>    ' 'for ' 'the '
 'most ' 'part ' 'loves ' 'to ' 'jump ' 'right ' 'on ' 'top ' 'of ' 'a '
 'story ' 'before ' 'all ' 'the ' 'facts \n>    ' 'are ' 'known, ' 'as '
 'well ' 'as ' 'to ' 'manipulate ' 'what ' 'we ' 'see ' 'and ' 'thus '
 'what ' 'we ' 'believe). \n>    ' 'Any ' 'ways ' 'one ' 'of ' "Koresh'"
 's ' 'DEVOTED ' 'followers ' 'that ' 'DID ' 'I ' 'REPEAT ' 'DID '
 'survive ' 'this\n>    "' 'GENOCIDAL ' 'MASS-' 'SLAUGHTER ' 'OF '
 'INNOCENT ' 'PEOPLE".  ' 'Besides ' 'there ' 'are ' 'nine \n>    '
 'survivors ' 'in ' 'the ' 'burn-' 'unit ' 'of ' 'the ' 'local '
 'hospital ' 'and ' 'was ' 'reported ' 'that ' 'David\n> -> ' 'was ' 'in '
 'one ' 'of ' 'the ' 'towers ' 'when ' 'the ' 'shit ' 'hit ' 'the '
 'fan.  ' 'Besides, ' 'a ' 'majority ' 'of \n> -> ' 'these ' 'children '
 'were ' 'children ' 'that ' 'he ' 'was ' 'supposed ' 'to ' 'have '
 'been ' 'the ' 'father ' 'of,\n> -> ' 'this ' 'then ' 'makes ' 'them '
 'bastard ' 'children ' 'to ' 'a ' 'sacraligious ' 'zeloit (' 'sp).  '
 'Also\n>    ' 'someone ' 'should ' 'have ' 'told ' 'David ' 'and ' 'his '
 'followers ' 'that ' 'if ' 'they ' "can'" 't ' 'the ' 'heat\n>    '
 'then ' 'they ' 'should ' 'stay ' 'out ' 'of ' 'the ' 'kitchen!! ('
 'pun ' 'intended)\n\n' 'So ' 'if ' 'you ' "weren'" 't ' 'saying ' 'that '
 'the ' 'deaths ' 'of "' 'bastard ' 'children ' 'of ' 'a '
 'sacreligious\n' 'zealot" ' 'was ' 'no ' 'big ' 'thing (' 'and ' 'I '
 'know ' 'I ' 'am ' 'not ' 'the ' 'only ' 'one ' 'who ' 'read ' 'it '
 'that\n' 'way), ' 'just ' 'what ' 'the ' 'hell ' 'WERE ' 'you '
 'saying?\n\n> ' 'However ' 'it ' 'was ' 'the ' 'Branch\n> ' 'Davidian '
 'people ' 'in ' 'there ' 'that ' 'insisted ' 'on ' 'staying ' 'there '
 'with ' 'their "' 'savior" \n> (' 'yeah ' 'right ' 'budy ' 'boy) '
 'because ' 'he ' 'had ' 'brain-' 'washed ' 'them ' 'into ' 'believing '
 'that \n> ' 'what ' 'ever ' 'he ' 'says ' 'is ' 'the ' 'truth, ' 'even '
 'if ' 'means ' 'that ' 'they ' 'are ' 'to ' 'give ' 'up ' 'their\n> '
 'lives ' 'for <<<<' 'HIS>>>> ' 'cause.  ' 'Therefore ' 'it ' 'is '
 'Davids ' 'fault ' 'and ' 'not ' 'the ' "ATF'" 's\n> ' 'who ' 'gave '
 'them ' '50 ' 'to ' '51 ' 'days ' 'to ' 'get ' 'out, ' 'this ' 'was '
 '50 ' 'days ' 'to ' 'many ' 'for ' 'me ' 'and\n> ' 'for ' 'many ' 'of '
 'the ' 'rest ' 'of ' 'the ' 'U.' 'S. \n\n' 'So ' 'who ' 'brainwashed '
 'YOU ' 'into ' 'believing ' 'that ' 'whatever ' 'the ' 'government '
 'says ' 'it\n' 'the ' 'truth?  ' 'Or ' 'that ' 'Koresh ' 'was ' 'any '
 'actual ' 'threat ' 'to ' 'you ' 'and ' 'the ' 'rest ' 'of ' 'the \n'
 'US?  ' "Haven'" 't ' 'you ' 'ever ' 'stopped ' 'to ' 'wonder ' 'WHY '
 'the ' 'government ' 'raided ' 'this ' 'farm?\n' 'This ' 'raid ' 'was '
 'NOT ' 'ABOUT ' 'RELIGION, ' 'SEX, ' 'OR ' 'CHILD ' 'ABUSE.  ' 'To '
 'the ' 'best ' 'available\n' 'evidence ' 'at ' 'the ' 'moment, ' 'this '
 'raid, ' 'psychological ' 'torture, ' 'and ' 'group ' 'death ' 'was \n'
 'about ' "Koresh'" 's ' 'failure ' 'to ' 'PAY ' 'A $' '200 ' 'TAX ' 'on '
 'ONE ' 'grenade ' 'launcher ' 'that ' 'he ' 'may \n' 'or ' 'may ' 'not '
 'even ' 'have ' 'had!  ' "That'" 's ' 'why ' 'it ' 'was ' 'the ' 'BATF '
 'doing ' 'the ' 'raiding ' 'in ' 'the\n' 'first ' 'place.\n\n' 'Do '
 'you ' 'think ' 'a $' '200 ' 'tax ' 'evasion ' 'justifies ' 'an '
 'armed ' 'assault ' 'by ' 'the ' 'government\n' 'with ' '100 ' 'armed '
 'men ' 'in ' 'trailers, ' 'orchestrated ' 'character ' 'assassination, '
 'noise\n' 'torture, ' 'and ' 'a ' 'holocaust?\n\n> ' 'I ' 'am '
 'however ' 'sad ' 'to ' 'hear ' 'of ' 'the ' 'death ' 'of ' 'any \n> '
 'child ' 'unlike ' 'the ' 'sick ' 'bastard ' 'I ' 'supposedly ' 'am.\n\n'
 'Then ' 'think ' 'before ' 'you ' 'post.\n-- \n\n' 'cdt@' 'rocket.' 'sw.'
 'stratus.' 'com   --' 'If ' 'you ' 'believe ' 'that ' 'I ' 'speak '
 'for ' 'my ' 'company,\n' 'OR ' 'cdt@' 'vos.' 'stratus.' 'com        '
 'write ' 'today ' 'for ' 'my ' 'special ' "Investors' " 'Packet']
['From: ' 'iharkest@' 'Lise.' 'Unit.' 'NO (' 'Inge ' 'Harkestad)\n'
 'Subject: ' 'Tangent ' 'vectors ' 'of ' 'Kochanek-' 'Bartels '
 'splines\n' 'Organization: ' 'University ' 'of ' 'Trondheim\n' 'Lines: '
 '33\n\n' "I'" 'm ' 'working ' 'on ' 'a ' 'system ' 'which ' 'uses ' 'a '
 'given ' 'set ' 'of ' '3D ' 'key ' 'frame\n' 'positions (' 'x,' 'y,'
 'z) ' 'to ' 'control ' 'an ' 'imaginary ' 'camera ' 'movement.  ' "I'"
 'm\n' 'using ' 'Kochanek-' 'Bartels ' 'splines (' 'as ' 'described '
 'in ' 'the ' "SIGGRAPH '" '84\n' 'proceedings) ' 'to ' 'create ' 'a '
 'variable ' 'number ' 'of ' 'inbetweens ' 'between\n' 'the ' 'key '
 'frames.  ' 'I ' 'want ' 'the ' 'inbetweens ' 'to ' 'be ' 'given ' 'in '
 'the ' 'form\n(' 'x,' 'y,' 'z,' 'dx,' 'dy,' 'dz) ' 'where ' 'the '
 'last ' 'three ' 'argumentsa ' 'are ' 'the ' 'x, ' 'y ' 'and\n' 'z '
 'component ' 'of ' 'the ' 'viewing ' 'direction ' 'vector ' 'of ' 'the '
 'camera ' 'when\n' 'positioned ' 'at (' 'x,' 'y,' 'z).\n\n' 'The '
 'method ' 'presented ' 'by ' 'Kochanek ' 'and ' 'Bartels ' 'only '
 'deals ' 'with ' 'the\n' 'positions ' 'of ' 'the ' 'inbetween ' 'view '
 'points ' 'to ' 'be ' 'generated.  ' "I'" 've\n' 'tried ' 'to ' 'set '
 'the ' 'viewing ' 'direction ' 'at ' 'a ' 'view ' 'point ' 'equal ' 'to '
 'the\n' 'chord ' 'between ' 'the ' 'two ' 'adjacent ' 'view ' 'points ('
 'which ' 'in ' 'general ' 'are\n' 'not ' 'key ' 'frames), ' 'but '
 'this ' 'causes ' 'a ' 'sligt ' 'discontinuity ' 'of ' 'the\n' 'viewing '
 'direction ' 'vector ' 'at ' 'the ' 'key ' 'frame ' 'positions ('
 'although\n' 'the ' 'spatial ' 'movement ' 'seems ' 'to ' 'work '
 'fine; ' 'and ' "I'" 'm ' 'quite ' 'certain\n- ' 'I ' 'think :) - '
 'that ' "I'" 've ' 'not ' 'simply ' 'made ' 'an ' 'implementation\n'
 'error...)\n\n' 'Now ' 'I ' 'wonder ' 'if ' 'anyone ' 'out ' 'there '
 'has ' 'used ' 'this ' 'spline ' 'form ' 'for\n' 'similar ' 'purposes '
 'and ' 'how ' 'they ' 'decided ' 'the ' 'viewing ' 'vectors.\n' "I'" 'd '
 'appreciate ' 'replies ' 'to ' 'be ' 'emailed ' 'to ' 'me ' 'at\n  '
 'iharkest@' 'lise.' 'unit.' 'no\n' 'Anyone ' 'else ' 'interested ' 'in '
 'the ' 'answer ' 'will ' 'be ' 'sent ' 'a ' 'summary ' 'of\n' 'the '
 'replies ' 'if ' 'they ' 'contact ' 'me.\n\n-- \n ' '_________     '
 '__________________\n\\\\        \\\\  \\\\\n \\\\ ' 'N ' 'G '
 'E  \\\\==\\\\ ' 'A ' 'R ' 'K ' 'E ' 'S ' 'T ' 'A ' 'D     ' 'iharkest@'
 'lise.' 'unit.' 'no         ' 'Comp. ' 'Sc.\n  \\\\' '________\\\\  \\\\'
 '_______________     ' 'NTH (' 'Norwegian ' 'Institute ' 'of '
 'Technology)\n\n          "' 'Some ' 'people ' 'play ' 'hard ' 'to '
 'get, ' 'I ' 'play ' 'hard ' 'to ' 'want" (' 'Ford ' 'Fairlane']

Text Plot

In this section, we have created a text plot that shows each text sample and highlights the words that contributed positively (shades of red) or negatively (shades of blue) to the prediction. We can create a text plot by passing shap values to the text_plot() function available from the shap library. We can notice from the results that words like 'murder', 'assault', 'death', etc. contributed to predicting category 'talk.politics.guns' for the first sample, and words like 'vector', '3D', 'frame', 'technology', etc. contributed to predicting category 'comp.graphics' for the second sample.

The visualization is interactive and we can hover over words to see their shap values.
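To make the colour coding concrete, here is a minimal plain-text sketch of the idea behind the text plot: tokens with positive shap values (red) push the output towards the selected class, and tokens with negative values (blue) push it away. The helper render_contributions(), the fixed threshold and the token values are hypothetical and invented for illustration; shap.text_plot() itself uses shades of colour rather than markers.

```python
def render_contributions(tokens, values, threshold=0.01):
    """Mark tokens whose |shap value| >= threshold with their sign (hypothetical helper)."""
    if len(tokens) != len(values):
        raise ValueError("tokens and values must have the same length")
    pieces = []
    for token, value in zip(tokens, values):
        if value >= threshold:
            pieces.append(f"[+{token}]")   # would be shaded red in the text plot
        elif value <= -threshold:
            pieces.append(f"[-{token}]")   # would be shaded blue in the text plot
        else:
            pieces.append(token)           # near-zero contribution: left plain
    return " ".join(pieces)


# Invented shap values for the class 'talk.politics.guns', for illustration only.
tokens = ["the", "government", "raided", "this", "farm"]
values = [0.001, 0.08, 0.12, -0.002, -0.03]
print(render_contributions(tokens, values))
```

The threshold only decides what gets a marker here; it is a simplification of the continuous shading in the real plot.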

In [ ]:
shap.text_plot(shap_values)

Bar Charts

In this section, we have created various bar charts that show the contribution of words in predicting a particular category.

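Below, we have created a bar chart that shows the words that contributed to predicting category 'talk.politics.guns', averaged over both samples. We can create bar charts using the bar() function available from shap.plots. We have sorted the words by shap value in descending order (order=shap.Explanation.argsort.flip), so the words contributing most positively appear at the top. Please pay close attention to how we have provided shap values to the bar chart creation function: shap_values[:, :, category] selects every sample, every token and a single output class, and .mean(axis=0) then aggregates the token contributions across samples.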

We can notice that words like 'assault', 'raided', etc. have contributed to predicting category 'talk.politics.guns'.
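As a rough illustration of what the averaged bar chart ranks, the sketch below groups per-token shap values from several samples by token string, averages each group, sorts the result from most positive to most negative and keeps the top entries, much like order=shap.Explanation.argsort.flip combined with max_display. We are assuming that .mean(axis=0) on these variable-length text explanations aggregates by token string in this way; the helper name and the values are hypothetical.

```python
from collections import defaultdict


def mean_token_contributions(samples, top_k=15):
    """Average shap values per token string across samples (hypothetical helper).

    samples: list of (tokens, values) pairs for a single output class.
    Returns up to top_k (token, mean_value) pairs, most positive first.
    """
    grouped = defaultdict(list)
    for tokens, values in samples:
        for token, value in zip(tokens, values):
            grouped[token].append(value)
    means = {token: sum(vals) / len(vals) for token, vals in grouped.items()}
    # Descending by value, so positive contributors come first (like argsort.flip).
    ranked = sorted(means.items(), key=lambda item: item[1], reverse=True)
    return ranked[:top_k]


# Invented values for two short samples.
samples = [
    (["armed", "assault", "the"], [0.10, 0.20, 0.00]),
    (["vector", "the"], [-0.05, 0.02]),
]
for token, mean_value in mean_token_contributions(samples, top_k=3):
    print(f"{token:>8}  {mean_value:+.3f}")
```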

In [ ]:
shap.plots.bar(shap_values[:,:, selected_categories[preds[0]]].mean(axis=0), max_display=15,
               order=shap.Explanation.argsort.flip)

Below, we have created a bar chart showing which words from the first sample contributed to predicting category 'talk.politics.guns'.

In [ ]:
shap.plots.bar(shap_values[0,:, selected_categories[preds[0]]], max_display=15,
               order=shap.Explanation.argsort.flip)

Below, we have created a bar chart showing words that contributed to predicting category 'comp.graphics', averaged over both samples.

In [ ]:
shap.plots.bar(shap_values[:,:, selected_categories[preds[1]]].mean(axis=0), max_display=15,
               order=shap.Explanation.argsort.flip)

Below, we have created a bar chart showing words from the second sample that contributed to predicting category 'comp.graphics'.

In [ ]:
shap.plots.bar(shap_values[1,:, selected_categories[preds[1]]], max_display=15,
               order=shap.Explanation.argsort.flip)

Waterfall Plots

In this section, we have created waterfall charts that show how adding the shap values of individual words to the base value arrives at a particular prediction. We can create waterfall charts using the waterfall_plot() function available from the shap library.
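The additivity that a waterfall chart relies on can be checked on a toy problem. The sketch below computes exact Shapley values by enumerating every coalition of three "words" and confirms that the base value (the output with every word masked) plus the sum of the shap values equals the full output. The toy model and its numbers are invented; exact enumeration needs 2^n coalitions, which is why SHAP's explainers estimate these values for real texts with hundreds of tokens.

```python
from itertools import combinations
from math import factorial


def exact_shapley(value_fn, players):
    """Exact Shapley values by brute-force enumeration of all coalitions."""
    n = len(players)
    phi = {}
    for p in players:
        others = [q for q in players if q != p]
        total = 0.0
        for size in range(n):
            # Shapley weight |S|! (n - |S| - 1)! / n! for coalitions of this size.
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for coalition in combinations(others, size):
                s = frozenset(coalition)
                total += weight * (value_fn(s | {p}) - value_fn(s))
        phi[p] = total
    return phi


def toy_model(present):
    """Invented 'probability of talk.politics.guns' given the unmasked words."""
    score = 0.15                       # output with every word masked
    if "assault" in present:
        score += 0.30
    if "armed" in present:
        score += 0.20
    if "assault" in present and "armed" in present:
        score += 0.10                  # pairwise interaction
    if "vector" in present:
        score -= 0.05
    return score


words = ["armed", "assault", "vector"]
phi = exact_shapley(toy_model, words)
base = toy_model(frozenset())
full = toy_model(frozenset(words))
print(phi)
print(base + sum(phi.values()), full)   # equal up to floating-point rounding
```

The interaction term is split evenly between 'armed' and 'assault', which is how Shapley values share credit for effects that need both words.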

Below, we have created a waterfall chart for the first prediction followed by a waterfall chart for the second prediction.

In [ ]:
shap.waterfall_plot(shap_values[0][:, selected_categories[preds[0]]], max_display=15)

In [ ]:
shap.waterfall_plot(shap_values[1][:, selected_categories[preds[1]]], max_display=15)

Force Plots

In this section, we have created force plots that show shap values in an additive force layout. We can create force plots using the force_plot() function. We need to provide the base value and the shap values as input to the function to create a visualization. It shows how adding the shap values of individual words to the base value arrives at a particular prediction probability.

Below, we have created force plots for both of our predictions. Please pay close attention to how we have provided base values and shap values to the function.
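For a single sample and class, force_plot() needs one scalar base value and one value per token, which is why the code indexes base_values[i][class] and shap_values[i][:, class].values. The sketch below uses a hypothetical helper and invented per-token values; only the base value 0.15385203 for 'talk.politics.guns' is taken from the base values printed in this tutorial. It separates the tokens that push the output up (red) from those that push it down (blue) and adds everything to the base value, which is the quantity a force plot displays.

```python
def force_breakdown(base_value, tokens, values, top_n=3):
    """Summarise a force plot: strongest pushes up/down and the final output (hypothetical helper)."""
    pairs = list(zip(tokens, values))
    up = sorted((p for p in pairs if p[1] > 0), key=lambda p: p[1], reverse=True)
    down = sorted((p for p in pairs if p[1] < 0), key=lambda p: p[1])
    return {
        "base": base_value,
        "up": up[:top_n],                          # red forces, largest first
        "down": down[:top_n],                      # blue forces, most negative first
        "output": base_value + sum(values),        # base value + all shap values
    }


summary = force_breakdown(
    0.15385203,                                    # printed base value for talk.politics.guns
    ["armed", "assault", "the", "vector"],
    [0.25, 0.35, 0.01, -0.05],                     # invented per-token shap values
)
print(summary)
```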

In [27]:
print(selected_categories)
['alt.atheism', 'comp.graphics', 'rec.sport.baseball', 'sci.med', 'talk.politics.guns']
In [ ]:
import re

tokens = re.split(r"\W+", X_batch_text[0].lower())

shap.force_plot(shap_values.base_values[0][preds[0].item()], shap_values[0][:, preds[0].item()].values,
                feature_names = tokens[:-1], out_names=selected_categories[preds[0]])
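The feature_names passed above come from re.split(r"\W+", ...), which returns a trailing empty string whenever the text ends with a non-word character such as a newline; that is what tokens[:-1] removes. When a text ends with a letter or digit there is no empty string, so an unconditional [:-1] would drop a real word. The hypothetical helper below strips the empty string only when it is present. We have not checked that this matches SHAP's own tokenization in every edge case (for example, text starting with punctuation yields a leading empty string), so comparing the number of feature names with the number of shap values is a sensible sanity check.

```python
import re


def feature_names_for(text):
    """Lower-cased word tokens for force_plot() labels (hypothetical helper).

    Splits on runs of non-word characters like the tutorial does, but removes
    the trailing empty string only when the split actually produced one.
    """
    tokens = re.split(r"\W+", text.lower())
    if tokens and tokens[-1] == "":
        tokens = tokens[:-1]
    return tokens


ends_with_newline = "Subject: Tangent vectors\nLines: 33\n"
ends_with_word = "Lines: 33"

print(re.split(r"\W+", ends_with_newline.lower()))    # trailing '' is present
print(re.split(r"\W+", ends_with_word.lower())[:-1])  # [:-1] drops the real token '33'
print(feature_names_for(ends_with_word))              # keeps '33'
```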

In [ ]:
import re

tokens = re.split(r"\W+", X_batch_text[1].lower())

shap.force_plot(shap_values.base_values[1][preds[1].item()], shap_values[1][:, preds[1].item()].values,
                feature_names = tokens[:-1], out_names=selected_categories[preds[1]])

Visualize SHAP Values For Incorrect Predictions

In this section, we have visualized shap values for wrong predictions. Below, we have first made predictions for all test samples. Then, we have found the indexes of samples that were predicted wrong and used them to select the first two misclassified samples. For the first sample, the actual category was 'alt.atheism' but the model predicted 'talk.politics.guns'. For the second sample, the actual category was 'sci.med' but the model predicted 'comp.graphics'. Now, we'll try to understand which words could have contributed to these wrong predictions. The model predicted category 'talk.politics.guns' with a probability of 0.74 and category 'comp.graphics' with a probability of 0.42.
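The selection step described above (take the argmax of the class scores, compare it with the true labels, keep the first two mismatches) can be sketched in plain Python. The notebook does this with model.apply(), argmax(axis=1) and np.argwhere(...).flatten()[:2], which likewise returns the mismatching indexes in ascending order; taking the argmax of raw logits or of softmax probabilities gives the same labels because softmax preserves order. The helper names and the toy labels and scores below are hypothetical.

```python
def argmax(row):
    """Index of the largest score in a row (the first one wins on ties)."""
    return max(range(len(row)), key=row.__getitem__)


def first_misclassified(y_true, scores, k=2):
    """Return (index, true_label, predicted_label) for the first k misclassified samples."""
    y_pred = [argmax(row) for row in scores]
    wrong = [i for i, (t, p) in enumerate(zip(y_true, y_pred)) if t != p]
    return [(i, y_true[i], y_pred[i]) for i in wrong[:k]]


# Invented labels and logits; class indexes follow selected_categories.
y_true = [4, 0, 1, 3, 2]
scores = [
    [0.1, 0.0, 0.0, 0.1, 2.0],   # predicted 4 (correct)
    [0.2, 0.1, 0.0, 0.0, 1.5],   # predicted 4, true 0 (wrong)
    [0.0, 3.0, 0.1, 0.2, 0.0],   # predicted 1 (correct)
    [0.1, 0.9, 0.0, 0.6, 0.0],   # predicted 1, true 3 (wrong)
    [0.0, 0.0, 1.0, 0.0, 0.0],   # predicted 2 (correct)
]
print(first_misclassified(y_true, scores))
```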

We have calculated shap values for the wrong predictions using the explainer object created earlier. We have also printed the shape of the shap values and the base values.
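The base values printed in the next cells can be sanity-checked by hand. Our understanding is that with a Text masker they are the model's softmax output for the fully masked text, which would explain why both rows are identical and why each row sums to 1. Combined with additivity, the first sample's shap values for 'talk.politics.guns' should then sum to its predicted probability minus that class's base value. The numbers below are copied from the printed output of this tutorial.

```python
# Values copied from the printed output; columns follow selected_categories.
selected_categories = ["alt.atheism", "comp.graphics", "rec.sport.baseball",
                       "sci.med", "talk.politics.guns"]
base_row = [0.14802793, 0.25942031, 0.21729104, 0.22140868, 0.15385203]

# Softmax outputs sum to 1 (up to rounding in the printed values).
total = sum(base_row)

# Base value of the class predicted for the first misclassified sample.
guns = selected_categories.index("talk.politics.guns")
base_guns = base_row[guns]

# By additivity, that sample's shap values for this class should sum to
# (predicted probability - base value).
predicted_proba = 0.74430966
expected_shap_sum = predicted_proba - base_guns

print(f"sum of base values          : {total:.6f}")
print(f"base value (guns)           : {base_guns}")
print(f"expected sum of shap values : {expected_shap_sum:.8f}")
```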

In [30]:
import re

Y_test_preds = model.apply(final_weights, X_test)
Y_test_preds = Y_test_preds.argmax(axis=1)

wrong_preds = np.argwhere(Y_test!=Y_test_preds)
X_batch_text = X_test_text[wrong_preds.flatten()[:2]]
X_batch = X_test[wrong_preds.flatten()[:2]]

print("Samples : ")
for text in X_batch_text:
    print(re.split(r"\W+", text))
    print()

preds_proba = linen.softmax(model.apply(final_weights, X_batch))
preds = preds_proba.argmax(axis=1)

print("Actual    Target Values : {}".format([selected_categories[target] for target in Y_test[wrong_preds.flatten()[:2]]]))
print("Predicted Target Values : {}".format([selected_categories[target] for target in preds]))
print("Predicted Probabilities : {}".format(preds_proba.max(axis=1)))

shap_values = explainer(X_batch_text)

shap_values.shape
Samples :
['From', 'aaron', 'minster', 'york', 'ac', 'uk', 'Subject', 'Re', 'Death', 'Penalty', 'Gulf', 'War', 'long', 'Distribution', 'world', 'Organization', 'Department', 'of', 'Computer', 'Science', 'University', 'of', 'York', 'England', 'Lines', '22', 'Shamim', 'Zvonko', 'Mohamed', 'sham', 'cs', 'arizona', 'edu', 'wrote', 'BULLSHIT', 'In', 'the', 'Gulf', 'Massacre', '7', 'of', 'all', 'ordnance', 'used', 'was', 'smart', 'The', 'rest', 'that', 's', '93', 'was', 'just', 'regular', 'dumb', 'ol', 'iron', 'bombs', 'and', 'stuff', 'Have', 'you', 'forgotten', 'that', 'the', 'Pentagon', 'definition', 'of', 'a', 'successful', 'Patriot', 'launch', 'was', 'when', 'the', 'missile', 'cleared', 'the', 'launching', 'tube', 'with', 'no', 'damage', 'Or', 'that', 'a', 'successful', 'interception', 'of', 'a', 'Scud', 'was', 'defined', 'as', 'the', 'Patriot', 'and', 'Scud', 'passed', 'each', 'other', 'in', 'the', 'same', 'area', 'of', 'the', 'sky', 'And', 'of', 'the', '7', 'that', 'was', 'the', 'smart', 'stuff', '35', 'hit', 'Again', 'try', 'to', 'follow', 'me', 'here', 'that', 'means', '65', 'of', 'this', 'smart', 'arsenal', 'missed', 'I', 'used', 'to', 'have', 'full', 'figures', 'on', 'this', 'including', 'the', 'tons', 'of', 'bombs', 'dropped', 'and', 'the', 'number', 'of', 'cluster', 'bomblet', 'munitions', 'used', 'I', 'had', 'heard', 'the', '90', 'of', 'the', 'laser', 'guided', 'weapons', 'hit', 'which', 'is', 'an', 'unprecedented', 'rate', 'of', 'success', '25', 'of', 'the', 'iron', 'weapons', 'hit', 'again', 'unprecedented', 'The', 'following', 'is', 'a', 'rough', 'estimate', 'but', 'this', 'means', 'of', 'the', '80', '000', 'tons', 'of', 'bombs', 'dropped', 'by', 'US', 'aircraft', 'around', '56', '000', 'tons', 'missed', 'I', 'm', 'not', 'sure', 'what', 'proportion', 'of', 'this', 'was', 'dropped', 'of', 'Baghdad', 'rather', 'than', 'troop', 'concentrations', 'in', 'Iraq', 'and', 'Kuwait', 'Much', 'of', 'the', 'tonnage', 'dropped', 'was', 'cluster', 'munitions', 'as', 'were', 
'all', 'the', 'MRLS', 'rounds', 'and', 'many', 'of', 'the', 'artillery', 'rounds', 'Napalm', 'and', 'fuel', 'air', 'explosives', 'were', 'also', 'used', 'Remember', 'how', 'we', 'were', 'told', 'that', 'weapons', 'of', 'mass', 'destruction', 'such', 'as', 'FAE', 'were', 'very', 'naughty', 'indeed', '']

['From', 'per', 'arne', 'melander', 'histocel', 'umu', 'se', 'Per', 'Arne', 'Melander', 'Subject', 'HELP', 'purification', 'of', 'neutrophils', 'from', 'mice', 'Organization', 'University', 'of', 'Ume', 'Sweden', 'Lines', '16', 'Hello', 'I', 'need', 'a', 'technique', 'for', 'separation', 'of', 'polymorphonuclear', 'neutrophils', 'PMN', 'from', 'the', 'peripheral', 'blood', 'of', 'mice', 'Because', 'of', 'the', 'low', 'PMN', 'Lymphocyte', 'ratio', 'approx', '20', '80', 'it', 's', 'not', 'just', 'as', 'easy', 'as', 'the', 'corresponding', 'technique', 'used', 'with', 'human', 'blood', 'Yours', 'Per', 'Arne', 'Melander', 'Ps', 'My', 'E', 'mail', 'address', 'is', 'per', 'arne', 'melander', 'histocel', 'umu', 'se', 'DS', '']

Actual    Target Values : ['alt.atheism', 'sci.med']
Predicted Target Values : ['talk.politics.guns', 'comp.graphics']
Predicted Probabilities : [0.74430966 0.4276343 ]
Out[30]:
(2, None, 5)
In [31]:
print("SHAP Values Shape : {}".format(shap_values.shape))
print("SHAP Base Values  : {}".format(shap_values.base_values))
SHAP Values Shape : (2, None, 5)
SHAP Base Values  : [[0.14802793 0.25942031 0.21729104 0.22140868 0.15385203]
 [0.14802793 0.25942031 0.21729104 0.22140868 0.15385203]]

Text Plot

In this section, we have created a text plot of shap values for the wrong predictions. We can notice from the visualization that for the first sample, words like 'massacre', 'arsenal', 'bombs', 'weapons', etc. could have contributed to predicting category 'talk.politics.guns'. For the second sample, our model is quite uncertain between 'comp.graphics' and 'sci.med', with probabilities of 0.42 and 0.39 respectively. Words like 'technique', 'mail', etc. contribute towards predicting 'comp.graphics', whereas words like 'blood', 'mice', 'lymphocyte', 'peripheral', etc. contribute towards the 'sci.med' category. We can select each category in turn in the chart to see the contributions of words towards it.
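One simple way to quantify how uncertain the model is on the second sample is the margin between its two most probable classes. The helper below is hypothetical; it uses the printed probability 0.4276343 for 'comp.graphics' and the rounded 0.39 quoted for 'sci.med'.

```python
def top_two_margin(probs, names):
    """Return the two most probable class names and the gap between their probabilities."""
    if len(probs) < 2 or len(probs) != len(names):
        raise ValueError("need at least two classes and one name per probability")
    ranked = sorted(zip(names, probs), key=lambda pair: pair[1], reverse=True)
    (first, p_first), (second, p_second) = ranked[0], ranked[1]
    return first, second, p_first - p_second


first, second, margin = top_two_margin([0.4276343, 0.39], ["comp.graphics", "sci.med"])
print(f"{first} vs {second}: margin {margin:.3f}")
```

A margin of only a few hundredths is consistent with the text plot showing strong evidence for both categories.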

In [ ]:
shap.text_plot(shap_values)

Bar Charts

In this section, we have created various bar charts explaining words that contributed to wrong predictions.

Below, we have created a bar chart displaying words from the first sample that contributed to predicting category 'talk.politics.guns'.

In [ ]:
shap.plots.bar(shap_values[0,:, selected_categories[preds[0]]], max_display=15,
               order=shap.Explanation.argsort.flip)

Below, we have created a bar chart displaying words that contributed to predicting category 'talk.politics.guns', averaged over both samples.

In [ ]:
shap.plots.bar(shap_values[:,:, selected_categories[preds[0]]].mean(axis=0), max_display=15,
               order=shap.Explanation.argsort.flip)

Below, we have created a bar chart displaying words from the second sample that contributed to predicting category 'comp.graphics'.

In [ ]:
shap.plots.bar(shap_values[1,:, selected_categories[preds[1]]], max_display=15,
               order=shap.Explanation.argsort.flip)

Below, we have created a bar chart displaying words that contributed to predicting category 'comp.graphics', averaged over both samples.

In [ ]:
shap.plots.bar(shap_values[:,:, selected_categories[preds[1]]].mean(axis=0), max_display=15,
               order=shap.Explanation.argsort.flip)

Force Plots

In this section, we have created force plots explaining the first and second sample predictions.

In [37]:
print(selected_categories)
['alt.atheism', 'comp.graphics', 'rec.sport.baseball', 'sci.med', 'talk.politics.guns']
In [ ]:
import re

tokens = re.split(r"\W+", X_batch_text[0].lower())

shap.force_plot(shap_values.base_values[0][preds[0].item()], shap_values[0][:, preds[0].item()].values,
                feature_names = tokens[:-1], out_names=selected_categories[preds[0]])

In [ ]:
import re

tokens = re.split(r"\W+", X_batch_text[1].lower())

shap.force_plot(shap_values.base_values[1][preds[1].item()], shap_values[1][:, preds[1].item()].values,
                feature_names = tokens[:-1], out_names=selected_categories[preds[1]])

8. SHAP Permutation Explainer

In this section, we have created a permutation explainer and will try to explain the predictions made by our model using it. We have tried to explain both correct and incorrect predictions like earlier. The permutation explainer iterates through random permutations of the features, in both forward and reverse order, and averages each feature's marginal contribution to estimate its shap value.

We have created a permutation explainer using the PermutationExplainer() constructor available from shap. We have given it exactly the same parameters that we had given to the partition explainer.
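The sketch below estimates shap values for a toy three-word model by walking random permutations forward and then backward, crediting each word with the change in output when it is added. It is a simplified illustration of antithetic permutation sampling, not SHAP's implementation; the toy model, its numbers and the helper name are invented. Because the toy model only has a pairwise interaction, each forward/reverse pair splits that interaction evenly and the estimate matches the exact Shapley values (0.25, 0.35, -0.05); with higher-order interactions it would only approximate them. In this sketch each walk costs one model evaluation per word plus one, which gives a feel for why the explainer calls below take a max_evals budget.

```python
import random


def toy_model(present):
    """Invented 'probability of talk.politics.guns' given the unmasked words."""
    score = 0.15
    if "assault" in present:
        score += 0.30
    if "armed" in present:
        score += 0.20
    if "assault" in present and "armed" in present:
        score += 0.10                  # pairwise interaction
    if "vector" in present:
        score -= 0.05
    return score


def permutation_shapley(value_fn, players, n_permutations=10, seed=0):
    """Estimate Shapley values with forward and reverse walks over random permutations."""
    rng = random.Random(seed)
    players = list(players)
    totals = {p: 0.0 for p in players}
    n_walks = 0
    for _ in range(n_permutations):
        order = players[:]
        rng.shuffle(order)
        for walk in (order, order[::-1]):          # forward, then reverse
            present = frozenset()
            previous = value_fn(present)
            for p in walk:
                present = present | {p}
                current = value_fn(present)
                totals[p] += current - previous    # marginal contribution of p
                previous = current
            n_walks += 1
    return {p: totals[p] / n_walks for p in players}


estimates = permutation_shapley(toy_model, ["armed", "assault", "vector"], n_permutations=5)
print(estimates)
```

Every complete walk telescopes from the fully masked output to the full output, so the estimates always add up to f(x) minus the base value, regardless of how many permutations are used.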

In [40]:
def make_predictions(X_batch_text):
    X_batch = jnp.array(vectorizer.transform(X_batch_text).toarray())
    logits_preds = model.apply(final_weights, X_batch)
    return linen.softmax(logits_preds)

masker = shap.maskers.Text(tokenizer=r"\W+")
explainer = shap.PermutationExplainer(make_predictions, masker=masker, output_names=selected_categories)

explainer
Out[40]:
<shap.explainers._permutation.Permutation.__init__.<locals>.Permutation at 0x7fcb100b2a90>

Visualize SHAP Values For Correct Predictions

In this section, we'll use a permutation explainer to explain correct predictions made by our model.

Below, we have selected two samples from our test dataset and made predictions on them. For the first sample, our model correctly predicted category 'talk.politics.guns' with 0.996 probability, and for the second sample, it correctly predicted category 'comp.graphics' with probability 0.947.

Then, we have generated shap values for both samples using the permutation explainer. In the next cell, we have also printed the base values and the shape of the shap values.

In [41]:
import re

X_batch_text = X_test_text[3:5]
X_batch = X_test[3:5]

print("Samples : ")
for text in X_batch_text:
    print(re.split(r"\W+", text))
    print()

logits = model.apply(final_weights, X_batch)
preds_proba = linen.softmax(logits)
preds = preds_proba.argmax(axis=1)

print("Actual    Target Values : {}".format([selected_categories[target] for target in Y_test[3:5]]))
print("Predicted Target Values : {}".format([selected_categories[target] for target in preds]))
print("Predicted Probabilities : {}".format(preds_proba.max(axis=1)))

shap_values = explainer(X_batch_text, max_evals=1200)
Samples :
['From', 'cdt', 'sw', 'stratus', 'com', 'C', 'D', 'Tavares', 'Subject', 'Re', 'ATF', 'BURNS', 'DIVIDIAN', 'RANCH', 'NO', 'SURVIVORS', 'Organization', 'Stratus', 'Computer', 'Inc', 'Lines', '61', 'Distribution', 'world', 'NNTP', 'Posting', 'Host', 'rocket', 'sw', 'stratus', 'com', 'Keywords', 'Nata', 'thing', 'In', 'article', '1993Apr21', '053035', '29591', 'mcs', 'kent', 'edu', 'mhamilto', 'Nimitz', 'mcs', 'kent', 'edu', 'The', 'Lawnmowerman', 'writes', 'Oh', 'then', 'I', 'guess', 'that', 'shooting', 'THOSE', 'kind', 'of', 'babies', 'is', 'all', 'right', 'You', 'sick', 'bastard', 'Why', 'thanks', 'for', 'your', 'reply', 'to', 'my', 'post', 'By', 'the', 'way', 'I', 'never', 'never', 'ever', 'said', 'that', 'it', 'was', 'right', 'to', 'shoot', 'THOSE', 'kind', 'of', 'babies', 'Let', 's', 'go', 'to', 'the', 'videotape', 'Here', 's', 'exactly', 'what', 'you', 'posted', 'THIS', 'IS', 'MURDER', 'ATF', 'MURDERERS', 'BUTCHERS', 'Flame', 'on', 'Is', 'this', 'guy', 'serious', 'If', 'he', 'would', 'ever', 'really', 'pay', 'attention', 'to', 'the', 'news', 'oops', 'I', 'forgot', 'that', 'the', 'media', 'for', 'the', 'most', 'part', 'loves', 'to', 'jump', 'right', 'on', 'top', 'of', 'a', 'story', 'before', 'all', 'the', 'facts', 'are', 'known', 'as', 'well', 'as', 'to', 'manipulate', 'what', 'we', 'see', 'and', 'thus', 'what', 'we', 'believe', 'Any', 'ways', 'one', 'of', 'Koresh', 's', 'DEVOTED', 'followers', 'that', 'DID', 'I', 'REPEAT', 'DID', 'survive', 'this', 'GENOCIDAL', 'MASS', 'SLAUGHTER', 'OF', 'INNOCENT', 'PEOPLE', 'Besides', 'there', 'are', 'nine', 'survivors', 'in', 'the', 'burn', 'unit', 'of', 'the', 'local', 'hospital', 'and', 'was', 'reported', 'that', 'David', 'was', 'in', 'one', 'of', 'the', 'towers', 'when', 'the', 'shit', 'hit', 'the', 'fan', 'Besides', 'a', 'majority', 'of', 'these', 'children', 'were', 'children', 'that', 'he', 'was', 'supposed', 'to', 'have', 'been', 'the', 'father', 'of', 'this', 'then', 'makes', 'them', 'bastard', 'children', 'to', 'a', 
'sacraligious', 'zeloit', 'sp', 'Also', 'someone', 'should', 'have', 'told', 'David', 'and', 'his', 'followers', 'that', 'if', 'they', 'can', 't', 'the', 'heat', 'then', 'they', 'should', 'stay', 'out', 'of', 'the', 'kitchen', 'pun', 'intended', 'So', 'if', 'you', 'weren', 't', 'saying', 'that', 'the', 'deaths', 'of', 'bastard', 'children', 'of', 'a', 'sacreligious', 'zealot', 'was', 'no', 'big', 'thing', 'and', 'I', 'know', 'I', 'am', 'not', 'the', 'only', 'one', 'who', 'read', 'it', 'that', 'way', 'just', 'what', 'the', 'hell', 'WERE', 'you', 'saying', 'However', 'it', 'was', 'the', 'Branch', 'Davidian', 'people', 'in', 'there', 'that', 'insisted', 'on', 'staying', 'there', 'with', 'their', 'savior', 'yeah', 'right', 'budy', 'boy', 'because', 'he', 'had', 'brain', 'washed', 'them', 'into', 'believing', 'that', 'what', 'ever', 'he', 'says', 'is', 'the', 'truth', 'even', 'if', 'means', 'that', 'they', 'are', 'to', 'give', 'up', 'their', 'lives', 'for', 'HIS', 'cause', 'Therefore', 'it', 'is', 'Davids', 'fault', 'and', 'not', 'the', 'ATF', 's', 'who', 'gave', 'them', '50', 'to', '51', 'days', 'to', 'get', 'out', 'this', 'was', '50', 'days', 'to', 'many', 'for', 'me', 'and', 'for', 'many', 'of', 'the', 'rest', 'of', 'the', 'U', 'S', 'So', 'who', 'brainwashed', 'YOU', 'into', 'believing', 'that', 'whatever', 'the', 'government', 'says', 'it', 'the', 'truth', 'Or', 'that', 'Koresh', 'was', 'any', 'actual', 'threat', 'to', 'you', 'and', 'the', 'rest', 'of', 'the', 'US', 'Haven', 't', 'you', 'ever', 'stopped', 'to', 'wonder', 'WHY', 'the', 'government', 'raided', 'this', 'farm', 'This', 'raid', 'was', 'NOT', 'ABOUT', 'RELIGION', 'SEX', 'OR', 'CHILD', 'ABUSE', 'To', 'the', 'best', 'available', 'evidence', 'at', 'the', 'moment', 'this', 'raid', 'psychological', 'torture', 'and', 'group', 'death', 'was', 'about', 'Koresh', 's', 'failure', 'to', 'PAY', 'A', '200', 'TAX', 'on', 'ONE', 'grenade', 'launcher', 'that', 'he', 'may', 'or', 'may', 'not', 'even', 'have', 'had', 
'That', 's', 'why', 'it', 'was', 'the', 'BATF', 'doing', 'the', 'raiding', 'in', 'the', 'first', 'place', 'Do', 'you', 'think', 'a', '200', 'tax', 'evasion', 'justifies', 'an', 'armed', 'assault', 'by', 'the', 'government', 'with', '100', 'armed', 'men', 'in', 'trailers', 'orchestrated', 'character', 'assassination', 'noise', 'torture', 'and', 'a', 'holocaust', 'I', 'am', 'however', 'sad', 'to', 'hear', 'of', 'the', 'death', 'of', 'any', 'child', 'unlike', 'the', 'sick', 'bastard', 'I', 'supposedly', 'am', 'Then', 'think', 'before', 'you', 'post', 'cdt', 'rocket', 'sw', 'stratus', 'com', 'If', 'you', 'believe', 'that', 'I', 'speak', 'for', 'my', 'company', 'OR', 'cdt', 'vos', 'stratus', 'com', 'write', 'today', 'for', 'my', 'special', 'Investors', 'Packet', '']

['From', 'iharkest', 'Lise', 'Unit', 'NO', 'Inge', 'Harkestad', 'Subject', 'Tangent', 'vectors', 'of', 'Kochanek', 'Bartels', 'splines', 'Organization', 'University', 'of', 'Trondheim', 'Lines', '33', 'I', 'm', 'working', 'on', 'a', 'system', 'which', 'uses', 'a', 'given', 'set', 'of', '3D', 'key', 'frame', 'positions', 'x', 'y', 'z', 'to', 'control', 'an', 'imaginary', 'camera', 'movement', 'I', 'm', 'using', 'Kochanek', 'Bartels', 'splines', 'as', 'described', 'in', 'the', 'SIGGRAPH', '84', 'proceedings', 'to', 'create', 'a', 'variable', 'number', 'of', 'inbetweens', 'between', 'the', 'key', 'frames', 'I', 'want', 'the', 'inbetweens', 'to', 'be', 'given', 'in', 'the', 'form', 'x', 'y', 'z', 'dx', 'dy', 'dz', 'where', 'the', 'last', 'three', 'argumentsa', 'are', 'the', 'x', 'y', 'and', 'z', 'component', 'of', 'the', 'viewing', 'direction', 'vector', 'of', 'the', 'camera', 'when', 'positioned', 'at', 'x', 'y', 'z', 'The', 'method', 'presented', 'by', 'Kochanek', 'and', 'Bartels', 'only', 'deals', 'with', 'the', 'positions', 'of', 'the', 'inbetween', 'view', 'points', 'to', 'be', 'generated', 'I', 've', 'tried', 'to', 'set', 'the', 'viewing', 'direction', 'at', 'a', 'view', 'point', 'equal', 'to', 'the', 'chord', 'between', 'the', 'two', 'adjacent', 'view', 'points', 'which', 'in', 'general', 'are', 'not', 'key', 'frames', 'but', 'this', 'causes', 'a', 'sligt', 'discontinuity', 'of', 'the', 'viewing', 'direction', 'vector', 'at', 'the', 'key', 'frame', 'positions', 'although', 'the', 'spatial', 'movement', 'seems', 'to', 'work', 'fine', 'and', 'I', 'm', 'quite', 'certain', 'I', 'think', 'that', 'I', 've', 'not', 'simply', 'made', 'an', 'implementation', 'error', 'Now', 'I', 'wonder', 'if', 'anyone', 'out', 'there', 'has', 'used', 'this', 'spline', 'form', 'for', 'similar', 'purposes', 'and', 'how', 'they', 'decided', 'the', 'viewing', 'vectors', 'I', 'd', 'appreciate', 'replies', 'to', 'be', 'emailed', 'to', 'me', 'at', 'iharkest', 'lise', 'unit', 'no', 'Anyone', 
'else', 'interested', 'in', 'the', 'answer', 'will', 'be', 'sent', 'a', 'summary', 'of', 'the', 'replies', 'if', 'they', 'contact', 'me', '_________', '__________________', 'N', 'G', 'E', 'A', 'R', 'K', 'E', 'S', 'T', 'A', 'D', 'iharkest', 'lise', 'unit', 'no', 'Comp', 'Sc', '________', '_______________', 'NTH', 'Norwegian', 'Institute', 'of', 'Technology', 'Some', 'people', 'play', 'hard', 'to', 'get', 'I', 'play', 'hard', 'to', 'want', 'Ford', 'Fairlane', '']

Actual    Target Values : ['talk.politics.guns', 'comp.graphics']
Predicted Target Values : ['talk.politics.guns', 'comp.graphics']
Predicted Probabilities : [0.9969908 0.9471911]
Permutation explainer: 3it [00:57, 28.71s/it]
In [42]:
print("SHAP Values Shape : {}".format(shap_values.shape))
print("SHAP Base Values  : {}".format(shap_values.base_values))
SHAP Values Shape : (2, None, 5)
SHAP Base Values  : [[0.14802793 0.25942031 0.21729104 0.22140868 0.15385202]
 [0.14802793 0.25942031 0.21729104 0.22140868 0.15385202]]

Text Plot

In this section, we have created a text plot using the shap values of both samples.

We can notice from the visualization that for the first sample, words like 'shooting', 'threat', 'raid', 'government', etc. contributed to predicting category 'talk.politics.guns'. For the second sample, words like 'vector', 'spatial', 'implementation', 'frame', '3D', etc. contributed to predicting category 'comp.graphics'.

In [43]:
print(selected_categories)
['alt.atheism', 'comp.graphics', 'rec.sport.baseball', 'sci.med', 'talk.politics.guns']
In [ ]:
shap.text_plot(shap_values)

Force Plots

In this section, we have created force plots for both samples, showing how adding shap values to the base value produces the prediction probability.

In [45]:
print(selected_categories)
['alt.atheism', 'comp.graphics', 'rec.sport.baseball', 'sci.med', 'talk.politics.guns']
In [ ]:
import re

tokens = re.split(r"\W+", X_batch_text[0].lower())

shap.force_plot(shap_values.base_values[0][preds[0].item()], shap_values[0][:, preds[0].item()].values,
                feature_names = tokens[:-1], out_names=selected_categories[preds[0]])

In [ ]:
import re

tokens = re.split(r"\W+", X_batch_text[1].lower())

shap.force_plot(shap_values.base_values[1][preds[1].item()], shap_values[1][:, preds[1].item()].values,
                feature_names = tokens[:-1], out_names=selected_categories[preds[1]])

Visualize SHAP Values For Incorrect Predictions

In this section, we have tried to explain the wrong predictions made by our model.

Like earlier, we have first found the indexes of all wrong predictions on the test dataset. Then, we have selected two samples that were predicted wrong by our model. For the first sample, the actual category is 'alt.atheism', but our model predicted 'talk.politics.guns' with a probability of 0.74. For the second sample, the actual category is 'sci.med', but our model predicted 'comp.graphics' with a probability of 0.42.

Then, we have generated shap values for both samples using our permutation explainer.

In [48]:
import re

Y_test_preds = model.apply(final_weights, X_test)
Y_test_preds = Y_test_preds.argmax(axis=1)

wrong_preds = np.argwhere(Y_test!=Y_test_preds)
X_batch_text = X_test_text[wrong_preds.flatten()[:2]]
X_batch = X_test[wrong_preds.flatten()[:2]]

print("Samples : ")
for text in X_batch_text:
    print(re.split(r"\W+", text))
    print()

preds_proba = linen.softmax(model.apply(final_weights, X_batch))
preds = preds_proba.argmax(axis=1)

print("Actual    Target Values : {}".format([selected_categories[target] for target in Y_test[wrong_preds.flatten()[:2]]]))
print("Predicted Target Values : {}".format([selected_categories[target] for target in preds]))
print("Predicted Probabilities : {}".format(preds_proba.max(axis=1)))

shap_values = explainer(X_batch_text, max_evals=600)

shap_values.shape
Samples :
['From', 'aaron', 'minster', 'york', 'ac', 'uk', 'Subject', 'Re', 'Death', 'Penalty', 'Gulf', 'War', 'long', 'Distribution', 'world', 'Organization', 'Department', 'of', 'Computer', 'Science', 'University', 'of', 'York', 'England', 'Lines', '22', 'Shamim', 'Zvonko', 'Mohamed', 'sham', 'cs', 'arizona', 'edu', 'wrote', 'BULLSHIT', 'In', 'the', 'Gulf', 'Massacre', '7', 'of', 'all', 'ordnance', 'used', 'was', 'smart', 'The', 'rest', 'that', 's', '93', 'was', 'just', 'regular', 'dumb', 'ol', 'iron', 'bombs', 'and', 'stuff', 'Have', 'you', 'forgotten', 'that', 'the', 'Pentagon', 'definition', 'of', 'a', 'successful', 'Patriot', 'launch', 'was', 'when', 'the', 'missile', 'cleared', 'the', 'launching', 'tube', 'with', 'no', 'damage', 'Or', 'that', 'a', 'successful', 'interception', 'of', 'a', 'Scud', 'was', 'defined', 'as', 'the', 'Patriot', 'and', 'Scud', 'passed', 'each', 'other', 'in', 'the', 'same', 'area', 'of', 'the', 'sky', 'And', 'of', 'the', '7', 'that', 'was', 'the', 'smart', 'stuff', '35', 'hit', 'Again', 'try', 'to', 'follow', 'me', 'here', 'that', 'means', '65', 'of', 'this', 'smart', 'arsenal', 'missed', 'I', 'used', 'to', 'have', 'full', 'figures', 'on', 'this', 'including', 'the', 'tons', 'of', 'bombs', 'dropped', 'and', 'the', 'number', 'of', 'cluster', 'bomblet', 'munitions', 'used', 'I', 'had', 'heard', 'the', '90', 'of', 'the', 'laser', 'guided', 'weapons', 'hit', 'which', 'is', 'an', 'unprecedented', 'rate', 'of', 'success', '25', 'of', 'the', 'iron', 'weapons', 'hit', 'again', 'unprecedented', 'The', 'following', 'is', 'a', 'rough', 'estimate', 'but', 'this', 'means', 'of', 'the', '80', '000', 'tons', 'of', 'bombs', 'dropped', 'by', 'US', 'aircraft', 'around', '56', '000', 'tons', 'missed', 'I', 'm', 'not', 'sure', 'what', 'proportion', 'of', 'this', 'was', 'dropped', 'of', 'Baghdad', 'rather', 'than', 'troop', 'concentrations', 'in', 'Iraq', 'and', 'Kuwait', 'Much', 'of', 'the', 'tonnage', 'dropped', 'was', 'cluster', 'munitions', 'as', 'were', 
'all', 'the', 'MRLS', 'rounds', 'and', 'many', 'of', 'the', 'artillery', 'rounds', 'Napalm', 'and', 'fuel', 'air', 'explosives', 'were', 'also', 'used', 'Remember', 'how', 'we', 'were', 'told', 'that', 'weapons', 'of', 'mass', 'destruction', 'such', 'as', 'FAE', 'were', 'very', 'naughty', 'indeed', '']

['From', 'per', 'arne', 'melander', 'histocel', 'umu', 'se', 'Per', 'Arne', 'Melander', 'Subject', 'HELP', 'purification', 'of', 'neutrophils', 'from', 'mice', 'Organization', 'University', 'of', 'Ume', 'Sweden', 'Lines', '16', 'Hello', 'I', 'need', 'a', 'technique', 'for', 'separation', 'of', 'polymorphonuclear', 'neutrophils', 'PMN', 'from', 'the', 'peripheral', 'blood', 'of', 'mice', 'Because', 'of', 'the', 'low', 'PMN', 'Lymphocyte', 'ratio', 'approx', '20', '80', 'it', 's', 'not', 'just', 'as', 'easy', 'as', 'the', 'corresponding', 'technique', 'used', 'with', 'human', 'blood', 'Yours', 'Per', 'Arne', 'Melander', 'Ps', 'My', 'E', 'mail', 'address', 'is', 'per', 'arne', 'melander', 'histocel', 'umu', 'se', 'DS', '']

Actual    Target Values : ['alt.atheism', 'sci.med']
Predicted Target Values : ['talk.politics.guns', 'comp.graphics']
Predicted Probabilities : [0.74430966 0.4276343 ]
Permutation explainer: 3it [00:23, 11.81s/it]
Out[48]:
(2, None, 5)
In [49]:
print("SHAP Values Shape : {}".format(shap_values.shape))
print("SHAP Base Values  : {}".format(shap_values.base_values))
SHAP Values Shape : (2, None, 5)
SHAP Base Values  : [[0.14802793 0.25942031 0.21729104 0.22140868 0.15385202]
 [0.14802793 0.25942031 0.21729104 0.22140868 0.15385202]]

Text Plot

In this section, we have generated a text plot showing words that contributed to wrong predictions.

We can notice from the visualization that for the first sample, words like 'weapons', 'explosives', 'troop', 'destruction', etc. have contributed to predicting category 'talk.politics.guns'. For the second sample, the model is uncertain between 'sci.med' and 'comp.graphics', with probabilities of 0.39 and 0.42 respectively.

In [50]:
print(selected_categories)
['alt.atheism', 'comp.graphics', 'rec.sport.baseball', 'sci.med', 'talk.politics.guns']
In [ ]:
shap.text_plot(shap_values)

Force Plots

In this section, we have created force plots explaining the two wrong predictions.
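A force plot starts from the base value of one category and shows each token either pushing the output higher (positive SHAP value) or lower (negative SHAP value), ending at the model's output for that category. The hypothetical helper below (`force_breakdown` is an illustrative name, not a SHAP function) reproduces that breakdown as plain lists, which can be handy for checking a plot numerically. It assumes a single category's SHAP values for one sample, e.g. `shap_values[0][:, c].values` with `c = preds[0].item()`; the toy inputs are made up.

```python
import numpy as np


def force_breakdown(base_value, contributions, tokens):
    """Split one sample's SHAP values for a single category like a force plot.

    Returns (output, pushes_up, pushes_down), where output is
    base_value + sum(contributions) and each push list holds (token, value) pairs.
    """
    contributions = np.asarray(contributions, dtype=float)
    pairs = [(t, float(c)) for t, c in zip(tokens, contributions)]
    # Tokens raising the output, strongest first
    pushes_up = sorted([p for p in pairs if p[1] > 0], key=lambda p: p[1], reverse=True)
    # Tokens lowering the output, strongest (most negative) first
    pushes_down = sorted([p for p in pairs if p[1] < 0], key=lambda p: p[1])
    output = float(base_value + contributions.sum())
    return output, pushes_up, pushes_down


# Made-up toy inputs for illustration
out, up, down = force_breakdown(0.2, [0.35, 0.0, -0.05], ["weapons", "of", "game"])
print(round(out, 6), up, down)
```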

In [52]:
print(selected_categories)
['alt.atheism', 'comp.graphics', 'rec.sport.baseball', 'sci.med', 'talk.politics.guns']
In [ ]:
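import re

# Split the raw text into lowercase word tokens (a raw string avoids the invalid "\W" escape warning).
# The last element is dropped because re.split() returns a trailing '' when the text ends with a non-word character.
tokens = re.split(r"\W+", X_batch_text[0].lower())

pred = preds[0].item()  # predicted category index for the first sample
shap.force_plot(shap_values.base_values[0][pred], shap_values[0][:, pred].values,
                feature_names=tokens[:-1], out_names=selected_categories[pred])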
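The `tokens[:-1]` slice in these cells relies on how `re.split()` behaves at the edges of a string: it yields an empty string at the end when the text ends with a non-word character (and at the start when it begins with one). The sketch below, with a hypothetical helper `split_words`, demonstrates that behavior and drops only the empty edge strings instead of always dropping the last element. Whichever variant is used, `feature_names` still has to line up one-to-one with the SHAP values, which depends on how the text masker tokenized the input.

```python
import re

# re.split leaves an empty string at the end when the text ends with punctuation
print(re.split(r"\W+", "Hello, world."))        # trailing '' present
print(re.split(r"\W+", "no trailing punct"))    # no trailing ''


def split_words(text):
    """Lowercase and split on runs of non-word characters, dropping empty edge strings."""
    tokens = re.split(r"\W+", text.lower())
    if tokens and tokens[0] == "":
        tokens = tokens[1:]
    if tokens and tokens[-1] == "":
        tokens = tokens[:-1]
    return tokens


print(split_words("Hello, world."))
print(split_words("no trailing punct"))
```

Unlike a fixed `[:-1]` slice, this variant does not drop a real word when the text happens to end with a letter or digit.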

In [ ]:
import re

# Same tokenization as above, this time for the second sample.
tokens = re.split(r"\W+", X_batch_text[1].lower())

pred = preds[1].item()  # predicted category index for the second sample
shap.force_plot(shap_values.base_values[1][pred], shap_values[1][:, pred].values,
                feature_names=tokens[:-1], out_names=selected_categories[pred])

This ends our small tutorial explaining how we can generate SHAP values for Flax (JAX) text classification networks and visualize them to better understand model predictions. Please feel free to let us know your views in the comments section.

Sunny Solanki