Updated On : Mar-25,2022 Time Investment : ~45 mins

Eli5.lime: Explain Flax (JAX) Text Classifier Predictions Using LIME Algorithm

Interpreting the predictions of a deep neural network is a good way to check whether the model has generalized. It helps us understand which parts of the input data the network uses to make predictions. For example, if we have trained a classifier that separates cats from dogs, we need to confirm that it recognizes a cat from the cat's pixels in the image and not from background pixels. Nowadays, it is easy to reach high accuracy by training deep neural networks with many layers. Deep networks often give better results than traditional ML models (decision trees, random forests, gradient boosting, etc.), but they are less interpretable and are generally referred to as black-box models. LIME (Local Interpretable Model-Agnostic Explanations) is an algorithm developed to interpret the predictions of such models. For a given input, LIME generates fake (perturbed) samples around it, records the network's predictions on them, and trains a simple local ML model that mimics the behavior of the network in the neighborhood of that input. It then uses this local model to explain the prediction. If you are interested in learning how LIME works internally, please check the tutorial linked below, where we have explained the steps it performs to interpret predictions.
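To make the idea concrete, below is a minimal LIME-style sketch in NumPy. It is not eli5's implementation: the random word masks, the exponential proximity kernel on the fraction of removed words, the `kernel_width` value, and the helper names `lime_text_sketch` and `toy_predict_proba` are all our own assumptions for illustration. The black box here is a hypothetical toy function rather than a trained network.

```python
import numpy as np

def lime_text_sketch(words, predict_proba, target, n_samples=500, kernel_width=0.25, seed=0):
    """Explain one prediction with a LIME-style local linear surrogate (illustrative sketch)."""
    rng = np.random.default_rng(seed)
    # 1. Perturb: each row is a 0/1 mask saying which words of the document are kept.
    masks = rng.integers(0, 2, size=(n_samples, len(words)))
    masks[0] = 1  # the first sample is the original, unperturbed document
    texts = [" ".join(w for w, keep in zip(words, row) if keep) for row in masks]
    # 2. Ask the black-box model for the target-class probability of every sample.
    y = predict_proba(texts)[:, target]
    # 3. Weight samples by proximity: fewer removed words -> larger weight (assumed kernel).
    distance = 1.0 - masks.mean(axis=1)
    weights = np.exp(-(distance ** 2) / kernel_width ** 2)
    # 4. Weighted least squares fit of y ~ intercept + masks @ coef.
    design = np.hstack([np.ones((n_samples, 1)), masks])
    sqrt_w = np.sqrt(weights)
    coef, *_ = np.linalg.lstsq(design * sqrt_w[:, None], y * sqrt_w, rcond=None)
    return dict(zip(words, coef[1:]))  # word -> contribution to the target class


def toy_predict_proba(texts):
    # Hypothetical black box: class 1 becomes likely when the word "hockey" is present.
    p1 = np.array([0.9 if "hockey" in t.split() else 0.1 for t in texts])
    return np.column_stack([1 - p1, p1])


contributions = lime_text_sketch("the hockey game was great".split(), toy_predict_proba, target=1)
print(max(contributions, key=contributions.get))  # → hockey
```

Because the toy black box depends only on the word "hockey", the surrogate assigns it a weight of about 0.8 (the jump from 0.1 to 0.9) and roughly zero to every other word.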

As a part of this tutorial, we'll explain how we can use the LIME implementation available through the eli5 library to explain the predictions of a text classification network designed using Flax. Flax is a high-level deep learning library built on top of JAX, a deep learning research library. We have trained a simple Flax network on the 20 newsgroups text dataset available from scikit-learn to classify text documents. Then, we have used the LIME implementation available from the lime sub-module of eli5 to explain the model's predictions, trying various parameter combinations of the explainer.

We assume that the reader has a little background with the eli5 library. Please feel free to check the link below if you are looking for an introduction. That said, this tutorial can be followed even without a deep understanding of eli5.

Below, we have listed important sections of the tutorial to give an overview of the material covered.

Important Sections Of Tutorial

  1. Load Data
  2. Vectorize Text Data
  3. Define Neural Network
  4. Define Loss
  5. Train Network
  6. Evaluate Network Performance
  7. Explain Network Predictions Using Eli5.lime.TextExplainer
    • 7.1 Explain Correct Predictions
    • 7.2 Explain Incorrect Prediction
    • 7.3 Try Different Classifier With TextExplainer
    • 7.4 Provide Our Vectorizer to TextExplainer
    • 7.5 Provide Tokenizer Pattern to TextExplainer
    • 7.6 TextExplainer With All Default Arguments
    • 7.7 Provide Multiple TextSamplers to TextExplainer
    • 7.8 Ask TextExplainer To Take Word Position Into Consideration

Below, we have imported the necessary libraries and printed the versions that we have used in our tutorial.

import eli5

print("Eli5 Version : {}".format(eli5.__version__))
Eli5 Version : 0.11.0
import jax

print("JAX Version : {}".format(jax.__version__))
JAX Version : 0.3.1
import flax

print("FLAX Version : {}".format(flax.__version__))
FLAX Version : 0.4.0
import optax

print("OPTAX Version : {}".format(optax.__version__))
OPTAX Version : 0.1.1

1. Load Data

In this section, we have loaded the newsgroups dataset available from scikit-learn. The dataset has ~18k text documents for 20 different categories of news. For our tutorial, we have selected only 3 categories to keep things simple.

import numpy as np
from sklearn import datasets
import gc

all_categories = ['alt.atheism','comp.graphics','comp.os.ms-windows.misc','comp.sys.ibm.pc.hardware',
                  'comp.sys.mac.hardware','comp.windows.x', 'misc.forsale','rec.autos','rec.motorcycles',
                  'rec.sport.baseball','rec.sport.hockey','sci.crypt','sci.electronics','sci.med',
                  'sci.space','soc.religion.christian','talk.politics.guns','talk.politics.mideast',
                  'talk.politics.misc','talk.religion.misc']

selected_categories = ['comp.sys.ibm.pc.hardware','rec.sport.hockey','talk.politics.mideast']

X_train_text, Y_train = datasets.fetch_20newsgroups(subset="train", categories=selected_categories, return_X_y=True)
X_test_text , Y_test  = datasets.fetch_20newsgroups(subset="test", categories=selected_categories, return_X_y=True)

X_train_text = np.array(X_train_text)
X_test_text = np.array(X_test_text)

classes = np.unique(Y_train)
mapping = dict(zip(classes, selected_categories))

len(X_train_text), len(X_test_text), classes, mapping
(1754,
 1167,
 array([0, 1, 2]),
 {0: 'comp.sys.ibm.pc.hardware',
  1: 'rec.sport.hockey',
  2: 'talk.politics.mideast'})

2. Vectorize Text Data

Below, we have vectorized our text data. We need to vectorize text data before giving it to a neural network because networks work on floats, not on strings.

Here, we have vectorized the text data using the Tf-Idf (Term Frequency - Inverse Document Frequency) approach. Tf-Idf assigns a float value to each word of a document such that words that appear in many documents get low values and words that appear in only a few documents get high values. This helps with text classification because words that are distinctive to a few documents are the ones that separate categories. In order to assign a Tf-Idf value to each word, we first need to tokenize the text, i.e., split each sample into words.

To vectorize text data using the Tf-Idf approach, we have used the TfidfVectorizer class available from scikit-learn. It tokenizes the text and then assigns a Tf-Idf value to each word of each sample.

Here, we have not covered how Tf-Idf works in detail. Please feel free to check the below link if you are interested in learning about it in-depth.
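As a small worked example, the sketch below computes Tf-Idf values by hand in NumPy for three made-up documents. It assumes the formula used by scikit-learn's TfidfVectorizer with its defaults (smooth_idf=True, norm="l2", raw term counts), but it uses plain whitespace tokenization instead of the vectorizer's regex tokenizer and lowercasing, so it is an illustration rather than a drop-in replacement.

```python
import numpy as np
from collections import Counter

docs = ["the puck hit the goalie", "the cpu is hot", "the goalie saved the puck"]
tokenized = [doc.split() for doc in docs]  # simple whitespace tokenization for illustration
vocab = sorted({word for doc in tokenized for word in doc})

# Term frequency: raw count of each vocabulary word in each document.
tf = np.array([[Counter(doc)[word] for word in vocab] for doc in tokenized], dtype=float)

# Smoothed inverse document frequency: idf(t) = ln((1 + n) / (1 + df(t))) + 1
n_docs = len(docs)
df = (tf > 0).sum(axis=0)  # number of documents containing each word
idf = np.log((1 + n_docs) / (1 + df)) + 1

# Tf-Idf, then L2-normalize each document vector.
tfidf = tf * idf
tfidf /= np.linalg.norm(tfidf, axis=1, keepdims=True)

for word in ("the", "goalie", "cpu"):
    print(word, round(float(idf[vocab.index(word)]), 3))
# prints one line each: the 1.0 / goalie 1.288 / cpu 1.693
```

The word "the" appears in every document and gets the lowest possible idf of 1.0, while "cpu", which appears in only one document, gets the highest.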

import sklearn
from jax import numpy as jnp
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer

vectorizer = TfidfVectorizer()

vectorizer.fit(np.concatenate((X_train_text, X_test_text)))
X_train = vectorizer.transform(X_train_text)
X_test = vectorizer.transform(X_test_text)

X_train = jnp.array(X_train.toarray(), dtype=jnp.float16)
X_test  = jnp.array(X_test.toarray(), dtype=jnp.float16)

X_train.shape, X_test.shape
((1754, 39120), (1167, 39120))

3. Define Neural Network

In this section, we have defined a simple neural network of 3 dense layers using Flax library. The three layers have 128, 64, and 3 output units respectively. We have applied relu activation after the first two layers.

Here, we have not covered in detail various Flax modules and how they are used to design networks. Please feel free to check the below link where we have explained how to design a network using Flax.

After defining the network, we have also initialized it and printed the shapes of the weights and biases of each layer. We have also made a sample forward pass through the network using a few training samples for verification purposes.
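To show how shapes flow through the three layers (Dense 128, relu, Dense 64, relu, Dense 3), here is a NumPy sketch of the same forward pass. It uses a made-up vocabulary size of 1000 instead of 39120 and random normal weights rather than Flax's default initializer, so only the shapes, not the values, correspond to the Flax model.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, n_classes = 1000, 3  # assumption: small vocabulary instead of the real 39120 terms

def dense(n_in, n_out):
    # Kernel of shape (in, out) and bias of shape (out,), the same layout linen.Dense uses.
    return rng.normal(scale=0.01, size=(n_in, n_out)), np.zeros(n_out)

layers = [dense(vocab_size, 128), dense(128, 64), dense(64, n_classes)]

def forward(x):
    for i, (w, b) in enumerate(layers):
        x = x @ w + b
        if i < len(layers) - 1:  # relu after the first two layers only
            x = np.maximum(x, 0.0)
    return x  # raw logits; softmax is applied later

logits = forward(rng.random((5, vocab_size)))
print(logits.shape)  # → (5, 3)
```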

from flax import linen
from jax import random

class TextClassifier(linen.Module):
    def setup(self):
        self.linear1 = linen.Dense(features=128, name="DENSE1")
        self.linear2 = linen.Dense(features=64, name="DENSE2")
        self.linear3 = linen.Dense(len(classes), name="DENSE3")

    def __call__(self, inputs):
        x = linen.relu(self.linear1(inputs))
        x = linen.relu(self.linear2(x))
        logits = self.linear3(x)

        return logits ## Raw logits; softmax is applied later (in the loss and prediction functions)

seed = jax.random.PRNGKey(0)

model = TextClassifier()
params = model.init(seed, X_train[:5])

for layer_params in params["params"].items():
    print("Layer Name : {}".format(layer_params[0]))
    weights, biases = layer_params[1]["kernel"], layer_params[1]["bias"]
    print("\tLayer Weights : {}, Biases : {}".format(weights.shape, biases.shape))
Layer Name : DENSE1
	Layer Weights : (39120, 128), Biases : (128,)
Layer Name : DENSE2
	Layer Weights : (128, 64), Biases : (64,)
Layer Name : DENSE3
	Layer Weights : (64, 3), Biases : (3,)
preds = model.apply(params, X_train[:5])

preds.shape
(5, 3)

4. Define Loss

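In this section, we have defined the loss function for our text classification task. We have designed a cross-entropy loss using the softmax_cross_entropy() function available from the Optax library. Note that the function sums the per-sample losses of a batch instead of averaging them, so the reported loss values scale with batch size.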
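For intuition on the magnitude of the printed losses, the NumPy sketch below mirrors what optax.softmax_cross_entropy computes per example (the negative sum of one-hot labels times the log-softmax of the logits) and then sums over the batch as CrossEntropyLoss does. The all-zero logits stand in for an untrained model and are an assumption of the example.

```python
import numpy as np

def log_softmax(logits):
    # Subtract the row-wise max before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def softmax_cross_entropy(logits, one_hot_labels):
    # Per-example loss: -sum(labels * log_softmax(logits)), one value per row.
    return -(one_hot_labels * log_softmax(logits)).sum(axis=-1)

batch_size, n_classes = 256, 3
logits = np.zeros((batch_size, n_classes))                     # untrained model: uniform predictions
labels = np.eye(n_classes)[np.arange(batch_size) % n_classes]  # arbitrary one-hot targets

per_example = softmax_cross_entropy(logits, labels)
print(per_example.mean())  # ln(3), about 1.0986 per sample
print(per_example.sum())   # summed like CrossEntropyLoss: 256 * ln(3), about 281.2
```

With uniform predictions each sample costs ln(3), so a summed loss over a 256-sample batch starts around 281, in line with the first-epoch value of 268.084 in the training log.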

def CrossEntropyLoss(weights, input_data, actual):
    logits_preds = model.apply(weights, input_data)
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(classes))
    return optax.softmax_cross_entropy(logits=logits_preds, labels=one_hot_actual).sum()

5. Train Network

Now, we'll train our network. To do so, we have defined a small function that performs the training process. The function takes training data (X, Y), validation data (X_val, Y_val), the number of epochs, network parameters, optimizer state, and batch size as input. It then runs the training loop for the given number of epochs. In each epoch, it loops through the whole training data in batches. For each batch, it performs a forward pass to make predictions, calculates the loss and gradients, and updates the network parameters. It records the loss of each batch and prints the average batch loss for the epoch. At the end of each epoch, it also makes predictions on the validation data and prints the validation accuracy. At last, it returns the updated network parameters, which we can use to make predictions.
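Splitting data into batches is easy to get subtly wrong: computing the count as `n // batch_size + 1` produces an empty final batch whenever n is a multiple of batch_size. The helper below (`batch_bounds`, a hypothetical name) rounds up instead and returns the (start, end) slice bounds of each batch.

```python
import math

def batch_bounds(n_samples, batch_size):
    # Round up so a smaller last batch is kept, but an empty one is never produced.
    n_batches = math.ceil(n_samples / batch_size)
    return [(i * batch_size, min((i + 1) * batch_size, n_samples)) for i in range(n_batches)]

bounds = batch_bounds(1754, 256)
print(len(bounds), bounds[-1])  # → 7 (1536, 1754)
```

For our 1754 training samples and a batch size of 256, this gives 6 full batches plus one batch of 218 samples, matching the 7 iterations per epoch shown in the progress bars.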

from jax import value_and_grad
from tqdm import tqdm
from sklearn.metrics import accuracy_score

def TrainModelInBatches(X, Y, X_val, Y_val, epochs, weights, optimizer_state, batch_size=32):
    for i in range(1, epochs+1):
        num_batches = int(np.ceil(X.shape[0] / batch_size)) ### Round up: keep the smaller last batch, never an empty one

        losses = [] ## Record loss of each batch
        for batch in tqdm(range(num_batches)):
            start, end = batch*batch_size, (batch+1)*batch_size ## Slicing past the last row is safe

            X_batch, Y_batch = X[start:end], Y[start:end] ## Single batch of data

            loss, gradients = value_and_grad(CrossEntropyLoss)(weights, X_batch,Y_batch)

            ## Update Weights
            updates, optimizer_state = optimizer.update(gradients, optimizer_state)
            weights = optax.apply_updates(weights, updates)

            losses.append(loss) ## Record Loss

        print("CrossEntropyLoss : {:.3f}".format(jnp.array(losses).mean()))

        Y_val_preds = model.apply(weights, X_val)
        val_acc = accuracy_score(Y_val, jnp.argmax(Y_val_preds, axis=1))
        print("Validation  Accuracy : {:.3f}".format(val_acc))

    return weights

Below, we have trained our network by calling the function designed in the previous cell. First, we have set the batch size to 256, the number of epochs to 8, and the learning rate to 0.001. Then, we have initialized the network, its parameters, and the Adam optimizer. At last, we have called our training routine with the necessary parameters. We can notice from the loss and validation accuracy printed after each epoch that our model is doing a good job.
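optax.adam applies the Adam update rule. As a rough illustration, the NumPy sketch below performs one Adam step for a parameter vector using optax's default hyperparameters (b1=0.9, b2=0.999, eps=1e-8) and our learning rate of 0.001. The function name `adam_step` and the toy values are our own.

```python
import numpy as np

def adam_step(param, grad, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    # Exponential moving averages of the gradient and of its square.
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad ** 2
    # Bias correction for the zero-initialized moments (t starts at 1).
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    return param - lr * m_hat / (np.sqrt(v_hat) + eps), m, v

param = np.array([0.5, -0.2])
grad = np.array([4.0, -0.01])  # gradients of very different scales
new_param, m, v = adam_step(param, grad, np.zeros(2), np.zeros(2), t=1)
print(new_param - param)  # about [-0.001, 0.001]
```

On the first step the bias-corrected ratio m_hat / sqrt(v_hat) is roughly the sign of the gradient, so each parameter moves by about the learning rate regardless of the gradient's scale.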

seed = random.PRNGKey(0)
batch_size=256
epochs=8
learning_rate = jnp.array(1/1e3)

model = TextClassifier()
weights = model.init(seed, X_train[:5])

optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(weights)

final_weights = TrainModelInBatches(X_train, Y_train, X_test, Y_test, epochs, weights, optimizer_state, batch_size=batch_size)
100%|██████████| 7/7 [00:05<00:00,  1.27it/s]
CrossEntropyLoss : 268.084
Validation  Accuracy : 0.970
100%|██████████| 7/7 [00:01<00:00,  4.42it/s]
CrossEntropyLoss : 229.332
Validation  Accuracy : 0.984
100%|██████████| 7/7 [00:01<00:00,  4.42it/s]
CrossEntropyLoss : 166.841
Validation  Accuracy : 0.988
100%|██████████| 7/7 [00:01<00:00,  4.64it/s]
CrossEntropyLoss : 98.061
Validation  Accuracy : 0.991
100%|██████████| 7/7 [00:01<00:00,  4.49it/s]
CrossEntropyLoss : 46.757
Validation  Accuracy : 0.990
100%|██████████| 7/7 [00:01<00:00,  4.32it/s]
CrossEntropyLoss : 20.094
Validation  Accuracy : 0.991
100%|██████████| 7/7 [00:01<00:00,  4.46it/s]
CrossEntropyLoss : 9.050
Validation  Accuracy : 0.991
100%|██████████| 7/7 [00:01<00:00,  4.46it/s]
CrossEntropyLoss : 4.694
Validation  Accuracy : 0.991

6. Evaluate Network Performance

In this section, we have evaluated the performance of our network by calculating accuracy, a classification report, and a confusion matrix on test predictions. We can notice from the results that our model is doing a good job at the text classification task. It still misclassifies a few samples, but the overall performance is good.

We have calculated these metrics using various scikit-learn functions. Please feel free to check the link below if you are interested in learning about the ML metrics available through sklearn. It covers the majority of them in depth.

After printing the metrics, we have also plotted the confusion matrix using scikit-plot in a separate cell. The scikit-plot library provides plotting functionality for many ML metrics. Please feel free to check the link below if you want to learn about it.

from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

train_preds = model.apply(final_weights, X_train)
test_preds = model.apply(final_weights, X_test)

print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, np.argmax(train_preds, axis=1))))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=selected_categories))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_test, np.argmax(test_preds, axis=1)))
Train Accuracy : 1.000
Test  Accuracy : 0.991

Classification Report :
                          precision    recall  f1-score   support

comp.sys.ibm.pc.hardware       0.99      1.00      0.99       392
        rec.sport.hockey       0.99      0.98      0.99       399
   talk.politics.mideast       0.99      0.99      0.99       376

                accuracy                           0.99      1167
               macro avg       0.99      0.99      0.99      1167
            weighted avg       0.99      0.99      0.99      1167


Confusion Matrix :
[[391   1   0]
 [  1 393   5]
 [  3   1 372]]
from sklearn.metrics import confusion_matrix
import scikitplot as skplt
import matplotlib.pyplot as plt

skplt.metrics.plot_confusion_matrix([selected_categories[i] for i in Y_test], [selected_categories[i] for i in np.argmax(test_preds, axis=1)],
                                    normalize=True,
                                    title="Confusion Matrix",
                                    cmap="Purples",
                                    hide_zeros=True,
                                    figsize=(5,5)
                                    );
plt.xticks(rotation=90);
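The printed metrics can be recomputed by hand from the confusion matrix, which helps when reading the classification report. The NumPy sketch below uses the confusion matrix printed above (rows are actual classes, columns are predicted classes, following scikit-learn's convention).

```python
import numpy as np

# Confusion matrix printed above (rows = actual class, columns = predicted class).
cm = np.array([[391,   1,   0],
               [  1, 393,   5],
               [  3,   1, 372]])

accuracy = np.trace(cm) / cm.sum()        # correct predictions / all predictions
precision = np.diag(cm) / cm.sum(axis=0)  # per predicted class (column sums)
recall = np.diag(cm) / cm.sum(axis=1)     # per actual class (row sums)
f1 = 2 * precision * recall / (precision + recall)
support = cm.sum(axis=1)                  # number of actual samples per class

print(round(float(accuracy), 3))  # → 0.991
print(precision.round(2), recall.round(2), support)
```

The rounded values agree with the classification report: precision of 0.99 for every class, recall of 1.00, 0.98 and 0.99, and supports of 392, 399 and 376.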

7. Explain Network Predictions Using Eli5.lime.TextExplainer

Now, we'll explain predictions made by our network using LIME implementation available through eli5.

In order to explain predictions using LIME, we need to follow the below steps.

  1. Create a TextExplainer instance.
  2. Call the fit() method on the TextExplainer instance with the selected text sample and a function that takes a list of text samples as input and returns prediction probabilities for them.
  3. Call the show_prediction() method on the TextExplainer instance to create a visualization highlighting the words that contribute positively/negatively to each category.
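The function passed to fit() in step 2 only has to satisfy a simple contract: take a list of strings and return an array of shape (n_docs, n_classes) whose rows are class probabilities. The sketch below demonstrates that contract with a hypothetical keyword-counting model (`KEYWORDS` and `toy_predict_proba` are made up for illustration) in place of the Flax network, converting scores to probabilities with a numerically stable softmax.

```python
import numpy as np

# Hypothetical keyword lists, one per class (hardware, hockey, mideast politics).
KEYWORDS = [("cpu", "drive"), ("hockey", "puck"), ("israel", "arab")]

def softmax(scores):
    # Subtract the row-wise max before exponentiating for numerical stability.
    exp = np.exp(scores - scores.max(axis=1, keepdims=True))
    return exp / exp.sum(axis=1, keepdims=True)

def toy_predict_proba(docs):
    # list of n strings -> (n, n_classes) array of probabilities (rows sum to 1)
    scores = np.array([[float(sum(doc.lower().split().count(k) for k in kws)) for kws in KEYWORDS]
                       for doc in docs])
    return softmax(scores)

probs = toy_predict_proba(["The puck crossed the line", "My cpu and drive failed"])
print(probs.shape)           # → (2, 3)
print(probs.argmax(axis=1))  # → [1 0]
```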

7.1 Explain Correct Predictions

In this section, we have explained the first text sample from the test data using eli5. To do that, we have first created an instance of TextExplainer below.

Please make a NOTE that the TextExplainer() constructor has a few parameters that modify the default behavior of the explainer. We have covered the majority of them with examples below. Please check this link if you are looking for documentation on the parameters.

from eli5.lime import TextExplainer
from eli5.lime.samplers import MaskingTextSampler, MaskingTextSamplers

explainer = TextExplainer(random_state=123, sampler=MaskingTextSampler())

Below, we have first defined a simple function that takes a list of text documents as input and returns their predicted probabilities. The function first vectorizes the data and then makes predictions using the trained network. It then applies softmax activation to the output of the network (logits) to convert it to probabilities and returns them.

Then, we have selected the first sample from our test samples and made predictions on it using our trained network. We have printed the actual and predicted labels as well as the probability of the predicted label.

At last, we have called the fit() method on the TextExplainer instance with the first text sample and the prediction function. This step internally executes the LIME algorithm and generates an explanation for the sample.

import numpy as np

def predict_proba(X_docs):
    X_vect = vectorizer.transform(X_docs)                    ## Sparse Tf-Idf matrix
    X_vect = jnp.array(X_vect.toarray(), dtype=jnp.float32)  ## Dense JAX array
    logits = model.apply(final_weights, X_vect)
    preds_probs = jax.nn.softmax(logits)                     ## Logits -> probabilities (rows sum to 1)
    return np.asarray(preds_probs)                           ## NumPy array of shape (n_docs, n_classes)

print("Actual    Label : {}".format(selected_categories[Y_test[0]]))
pred_probs = predict_proba(X_test_text[:1])
print("Predicted Label : {}".format(selected_categories[pred_probs.argmax(axis=1)[0]]))
print("Probability     : {:.3f}".format(pred_probs.max(axis=1)[0]))

explainer.fit(X_test_text[0], predict_proba)
Actual    Label : talk.politics.mideast
Predicted Label : talk.politics.mideast
Probability     : 0.968
TextExplainer(char_based=False,
              clf=SGDClassifier(alpha=0.001, loss='log', penalty='elasticnet',
                                random_state=RandomState(MT19937) at 0x7F95B5E405A0),
              random_state=123,
              sampler=MaskingTextSampler(token_pattern='(?u)\\b\\w+\\b'),
              token_pattern='(?u)\\b\\w+\\b',
              vec=CountVectorizer(ngram_range=(1, 2),
                                  token_pattern='(?u)\\b\\w+\\b'))

In the cell below, we have called the show_prediction() method on the TextExplainer instance to create a visualization showing which words contributed to the prediction.

We can notice from the visualization that words like 'politics', 'society', 'truth', 'organization', etc. contribute to predicting the category talk.politics.mideast, which makes sense. The word 'homosexual', however, seems to contribute to the 'rec.sport.hockey' category, which is a little odd; we would have expected it to contribute to 'talk.politics.mideast'.

Please make a NOTE that, by default, the show_prediction() function shows output for all categories. We can restrict the categories by setting the targets parameter, which accepts a list of the categories that we want to visualize. In the second cell below, we have created a visualization for only the predicted category.

explainer.show_prediction(target_names=selected_categories)

pred_cat = selected_categories[pred_probs.argmax(axis=1)[0]]

explainer.show_prediction(target_names=selected_categories, targets=[pred_cat])

7.2 Explain Incorrect Prediction

In this section, we have explained one incorrect prediction. To do that, we have first made predictions on all test samples and found the indexes of wrongly predicted samples. Then, we have selected the first of those samples.

We have followed the same process as earlier to create a visualization explaining the wrong prediction. The actual category of the selected sample is talk.politics.mideast, but our model predicted comp.sys.ibm.pc.hardware with a probability of 0.636, so it is not very confident about this category. The second-highest probability belongs to the talk.politics.mideast category.
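Finding misclassified samples and checking how confident the model was comes down to a couple of NumPy calls. The labels and probability rows below are hypothetical toy values, not outputs of our network.

```python
import numpy as np

# Hypothetical labels and predictions (0, 1, 2 are class indices).
y_true = np.array([2, 0, 1, 2, 1])
y_pred = np.array([2, 0, 1, 0, 2])

wrong = np.argwhere(y_pred != y_true).flatten()  # indexes of misclassified samples
print(wrong)                                     # → [3 4]
print(np.flatnonzero(y_pred != y_true))          # same result in one call → [3 4]

# Hypothetical probability rows: rank classes from most to least likely.
probs = np.array([[0.10, 0.05, 0.85],
                  [0.60, 0.05, 0.35]])
top2 = np.argsort(probs, axis=1)[:, ::-1][:, :2]  # best and second-best class per row
top2_probs = np.take_along_axis(probs, top2, axis=1)
print(top2[1])  # → [0 2]
```

For a 1-D mask, np.flatnonzero(mask) is equivalent to np.argwhere(mask).flatten(); the gap between the top two probabilities is a simple way to judge how close the model was to choosing a different class.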

X_test_vect = vectorizer.transform(X_test_text).toarray()
Y_logits = model.apply(final_weights, jnp.array(X_test_vect, dtype=jnp.float32))
Y_probs = linen.softmax(Y_logits)
Y_preds = np.asarray(Y_probs.argmax(axis=-1))
del X_test_vect
gc.collect()

wrong_preds = np.argwhere(Y_preds != Y_test)
idx = wrong_preds.flatten()[0]

print("Actual    Label : {}".format(selected_categories[Y_test[idx]]))
pred_probs = predict_proba(X_test_text[idx:idx+1])
print("Predicted Label : {}".format(selected_categories[pred_probs.argmax(axis=1)[0]]))
print("Probability     : {:.3f}".format(pred_probs.max(axis=1)[0]))
Actual    Label : talk.politics.mideast
Predicted Label : comp.sys.ibm.pc.hardware
Probability     : 0.636
explainer.fit(X_test_text[idx], predict_proba)
explainer.show_prediction(target_names=selected_categories)

7.3 Try Different Classifier With TextExplainer

By default, TextExplainer internally trains an SGDClassifier from scikit-learn as its local model to explain predictions. We can replace this local model with another one using the clf parameter.

In this example, we have used the LogisticRegression classifier from scikit-learn instead to check how it affects the explanations. The replacement should be a white-box classifier that follows the scikit-learn interface, is supported by eli5, and provides a predict_proba() method (for example, a linear model or a decision tree).

We can notice from the visualization that the results are almost the same as in the earlier section. Words like 'politics', 'society', 'yard', 'california', 'genuine', etc. contribute to predicting the category talk.politics.mideast.

from eli5.lime import TextExplainer
from eli5.lime.samplers import MaskingTextSampler, MaskingTextSamplers
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier

explainer = TextExplainer(random_state=123, sampler=MaskingTextSampler(), clf=LogisticRegression(max_iter=10000))

print("Actual    Label : {}".format(selected_categories[Y_test[0]]))
pred_probs = predict_proba(X_test_text[:1])
print("Predicted Label : {}".format(selected_categories[pred_probs.argmax(axis=1)[0]]))
print("Probability     : {:.3f}".format(pred_probs.max(axis=1)[0]))
Actual    Label : talk.politics.mideast
Predicted Label : talk.politics.mideast
Probability     : 0.968
explainer.fit(X_test_text[0], predict_proba)
explainer.show_prediction(target_names=selected_categories)

We can also call the show_weights() method on TextExplainer if we want a tabular version of the visualization showing word contributions.

explainer.show_weights(target_names=selected_categories, targets=selected_categories)

7.4 Provide Our Vectorizer to TextExplainer

TextExplainer lets us provide our own vectorizer through the vec parameter if we are not happy with the default vectorizer it uses.

Below, we have given our Tf-Idf vectorizer to the vec parameter of TextExplainer and then generated a visualization explaining the first test sample. We have also set top=(8,8) in show_prediction() to display the 8 features with the highest positive weights and the 8 with the highest negative weights.

We can notice from the visualization that the results do not look as good as in the earlier sections.

from eli5.lime import TextExplainer
from eli5.lime.samplers import MaskingTextSampler, MaskingTextSamplers

explainer = TextExplainer(random_state=123, sampler=MaskingTextSampler(), vec=vectorizer)

print("Actual    Label : {}".format(selected_categories[Y_test[0]]))
pred_probs = predict_proba(X_test_text[:1])
print("Predicted Label : {}".format(selected_categories[pred_probs.argmax(axis=1)[0]]))
print("Probability     : {:.3f}".format(pred_probs.max(axis=1)[0]))
Actual    Label : talk.politics.mideast
Predicted Label : talk.politics.mideast
Probability     : 0.968
explainer.fit(X_test_text[0], predict_proba)
explainer.show_prediction(target_names=selected_categories, top=(8,8))

7.5 Provide Tokenizer Pattern to TextExplainer

TextExplainer lets us provide our own tokenization pattern if we do not like the default one. We can do that by passing a regular expression to the token_pattern parameter.

Below, we have created an explainer object with our own tokenization pattern and then generated a visualization with it. In our run, the results looked slightly better with our own pattern.
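The token pattern mostly matters at the edges: short tokens, apostrophes, and punctuation. The snippet below uses Python's re module to compare scikit-learn's default vectorizer pattern (which also applies to the TfidfVectorizer we passed in section 7.4), the default TextExplainer pattern shown in the fit() output of section 7.1, and the pattern used in this section, on a made-up sentence.

```python
import re

text = "I'd buy a 486 PC for $1,200 in the U.S."

sklearn_default = r"(?u)\b\w\w+\b"  # CountVectorizer/TfidfVectorizer default: 2+ word characters
eli5_default = r"(?u)\b\w+\b"       # TextExplainer default shown in the fit() output
simple = r"\w+"                     # pattern used in this section

for pattern in (sklearn_default, eli5_default, simple):
    print(pattern, re.findall(pattern, text))
```

Since \w+ already matches maximal runs of word characters, it finds exactly the same tokens as (?u)\b\w+\b; the scikit-learn default additionally drops single-character tokens such as "a", "I" and "1".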

from eli5.lime import TextExplainer
from eli5.lime.samplers import MaskingTextSampler, MaskingTextSamplers

explainer = TextExplainer(random_state=123, sampler=MaskingTextSampler(), token_pattern=r"\w+")
explainer.fit(X_test_text[0], predict_proba)
explainer.show_prediction(target_names=selected_categories)

7.6 TextExplainer With All Default Arguments

In this section, we have generated visualization using TextExplainer object created with default arguments.

from eli5.lime import TextExplainer
from eli5.lime.samplers import MaskingTextSampler, MaskingTextSamplers

explainer = TextExplainer(random_state=123)
explainer.fit(X_test_text[0], predict_proba)
explainer.show_prediction(target_names=selected_categories)

7.7 Provide Multiple TextSamplers to TextExplainer

TextExplainer lets us provide multiple samplers that will be used to generate the samples for training the local ML model of the LIME algorithm. We can do that by creating an instance of MaskingTextSamplers and giving it to the sampler parameter of the TextExplainer constructor. A MaskingTextSamplers instance holds a list of MaskingTextSampler instances. We create it from a list of dictionaries, where each dictionary specifies the parameters of one MaskingTextSampler instance.

Below, we have explained with an example how to create a TextExplainer instance using multiple samplers and then use it to interpret the network's predictions.

Please make a NOTE that 'MaskingTextSamplers' (plural) and 'MaskingTextSampler' are different classes.
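To make the bow option concrete, below is a plain-Python sketch of the two masking modes as we understand them from the eli5 docs: with bow=True every occurrence of a chosen word is removed together (bag-of-words sampling), while with bow=False only a single occurrence is removed. The helper `mask_tokens` is hypothetical and is not eli5's implementation, which has more options.

```python
import random
import re

def mask_tokens(text, bow, n_remove=1, seed=0, token_pattern=r"\w+"):
    """Hypothetical sketch of word masking; not eli5's actual implementation."""
    rng = random.Random(seed)
    tokens = re.findall(token_pattern, text)
    if bow:
        # Bag-of-words mode: choose distinct words and drop every occurrence of each.
        removed = set(rng.sample(sorted(set(tokens)), n_remove))
        kept = [t for t in tokens if t not in removed]
    else:
        # Position mode: choose individual token positions and drop only those.
        positions = set(rng.sample(range(len(tokens)), n_remove))
        kept = [t for i, t in enumerate(tokens) if i not in positions]
    return " ".join(kept)

sentence = "the puck and the goal"
print(mask_tokens(sentence, bow=True, seed=1))   # one word removed everywhere it occurs
print(mask_tokens(sentence, bow=False, seed=1))  # exactly one token occurrence removed
```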

from eli5.lime import TextExplainer
from eli5.lime.samplers import MaskingTextSampler, MaskingTextSamplers

explainer = TextExplainer(random_state=123,
                          sampler=MaskingTextSamplers([{"bow": True,  "token_pattern": r"\w+"},
                                                       {"bow": False, "token_pattern": r"\w+"}]))
explainer.fit(X_test_text[0], predict_proba)
explainer.show_prediction(target_names=selected_categories)

7.8 Ask TextExplainer To Take Word Position Into Consideration

We can ask a TextExplainer instance to take the position of words into consideration by setting the position_dependent parameter to True. In that case, eli5 uses a special vectorizer that takes each token's position into account, which can help when the contribution of a word depends on where it appears in the text.

Below, we have created a TextExplainer instance with position_dependent set to True and used it to interpret the network's prediction. The result is almost the same as the previous results.
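The difference between position-independent and position-dependent features can be sketched in plain Python. The two helpers below are hypothetical and only illustrate the idea; eli5 uses its own vectorizer and feature naming when position_dependent=True.

```python
import re
from collections import Counter

def bag_of_words(text, token_pattern=r"\w+"):
    # Position-independent: every occurrence of a word maps to the same feature.
    return Counter(re.findall(token_pattern, text.lower()))

def positional_features(text, token_pattern=r"\w+"):
    # Hypothetical position-aware features: the word plus its token index.
    tokens = re.findall(token_pattern, text.lower())
    return Counter(f"{tok}[{i}]" for i, tok in enumerate(tokens))

text = "Not good, not bad"
print(bag_of_words(text))         # Counter({'not': 2, 'good': 1, 'bad': 1})
print(positional_features(text))  # Counter({'not[0]': 1, 'good[1]': 1, 'not[2]': 1, 'bad[3]': 1})
```

In the bag-of-words version both occurrences of "not" collapse into one feature, so an explanation can only assign them a single weight; with positional features, each occurrence can get its own weight.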

from eli5.lime import TextExplainer
from eli5.lime.samplers import MaskingTextSampler, MaskingTextSamplers

explainer = TextExplainer(random_state=123, sampler=MaskingTextSampler(), position_dependent=True)
explainer.fit(X_test_text[0], predict_proba)
explainer.show_prediction(target_names=selected_categories)

This ends our small tutorial explaining how we can use LIME implementation available through eli5.lime module to explain the predictions generated by text classification network created using Flax (JAX). Please feel free to let us know your views in the comments section.

Sunny Solanki
