Updated On : Mar-05,2022 Tags lime, image-classificati…
LIME: Explain Keras Image Classification Network (CNN) Predictions


LIME (Local Interpretable Model-Agnostic Explanations) is one of the most commonly used algorithms for explaining the predictions of black-box models. With ML models like linear regression, decision trees, random forests, and gradient boosting trees, we can generate feature importances directly from the model. With deep neural networks that have many layers, however, it is hard to say how much each feature contributed to a particular prediction. In those situations, LIME can help. It generates perturbed (fake) samples around the input sample, gets the black-box model's (deep neural network's) predictions for them, and fits a simple interpretable model (by default a linear model, Ridge regression in the lime library) to those predictions, weighting each sample by how close it is to the original input. The simple model's coefficients then explain the black-box model's prediction for that sample, which gives us local feature importances. We cover how LIME works in detail in a separate tutorial.
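To make the idea concrete, below is a minimal NumPy sketch of a local surrogate. It is our own illustration, not code from the lime library: the `black_box` function, the Gaussian sampling scale, and the kernel width are assumptions, and plain weighted least squares stands in for lime's Ridge regressor. We perturb an input, query the black box, weight each perturbed sample by its proximity to the input, and fit a weighted linear model whose coefficients act as local feature importances.

```python
import numpy as np

def black_box(X):
    """Hypothetical non-linear model standing in for the neural network."""
    return X[:, 0] ** 2 + 3.0 * X[:, 1]

def local_surrogate(f, x, num_samples=500, scale=0.1, kernel_width=0.25, seed=0):
    """Fit a proximity-weighted linear model to f in the neighbourhood of x."""
    rng = np.random.default_rng(seed)
    # 1. Perturb the input: Gaussian noise around x
    X = x + rng.normal(0.0, scale, size=(num_samples, x.size))
    y = f(X)  # 2. Query the black box on every perturbed sample
    # 3. Weight samples by proximity with an exponential kernel on Euclidean distance
    d = np.linalg.norm(X - x, axis=1)
    w = np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2))
    # 4. Weighted least squares (rows scaled by sqrt(w)), with an intercept column
    A = np.column_stack([np.ones(num_samples), X])
    sw = np.sqrt(w)
    coef, *_ = np.linalg.lstsq(A * sw[:, None], y * sw, rcond=None)
    return coef[1:]  # drop the intercept; one importance per feature

importances = local_surrogate(black_box, np.array([1.0, 0.0]))
print(importances)  # approximately [2, 3], the local slopes of black_box at (1, 0)
```

The surrogate is only valid near the explained point: around a different input the same black box would get different importances, which is exactly the "local" in LIME.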

As a part of this tutorial, we'll primarily concentrate on image classification. We have created a Keras convolutional neural network and trained it on the Fashion MNIST dataset. Later on, we have used LIME to explain the predictions made by the model and plotted visualizations showing which parts of the images contribute to the predictions.

Below, we have listed important sections of our tutorial to give an overview of the material covered.

Important Sections Of Tutorial

  1. Load Data
  2. Create And Train Model
  3. Evaluate Network Performance
  4. Explain Predictions Using Lime Image Explainer

Below, we have imported the necessary libraries and printed the versions that we have used in our tutorial.

In [1]:
import tensorflow
from tensorflow import keras

print("Keras Version : {}".format(keras.__version__))
Keras Version : 2.6.0
In [2]:
import lime

from importlib.metadata import version  # standard library in Python 3.8+

print("LIME Version : {}".format(version("lime")))

1. Load Data

In this section, we have loaded the Fashion MNIST dataset available from Keras. The dataset has 28x28 pixel grayscale images of 10 different fashion items and is already divided into train (60k images) and test (10k images) sets. We add a single channel dimension to the images and scale pixel values to the [0, 1] range. Below, we have included a table that maps each label index to its item name.

Label  Description
0      T-shirt/top
1      Trouser
2      Pullover
3      Dress
4      Coat
5      Sandal
6      Shirt
7      Sneaker
8      Bag
9      Ankle boot
In [3]:
from tensorflow import keras
import numpy as np

# Load the Fashion MNIST train and test splits (uint8 images, pixel values 0-255)
(X_train, Y_train), (X_test, Y_test) = keras.datasets.fashion_mnist.load_data()

# Add a trailing channel axis for the Conv2D layers: (N, 28, 28) -> (N, 28, 28, 1)
X_train, X_test = X_train.reshape(-1,28,28,1), X_test.reshape(-1,28,28,1)

# Scale pixel values to the [0, 1] range
X_train, X_test = X_train/255.0, X_test/255.0

classes = np.unique(Y_train)
class_labels = ["T-shirt/top","Trouser","Pullover","Dress","Coat","Sandal","Shirt","Sneaker","Bag","Ankle boot"]
mapping = dict(zip(classes, class_labels))

X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
32768/29515 [=================================] - 0s 0us/step
40960/29515 [=========================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26427392/26421880 [==============================] - 0s 0us/step
26435584/26421880 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
16384/5148 [===============================================================================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4423680/4422102 [==============================] - 0s 0us/step
4431872/4422102 [==============================] - 0s 0us/step
Out[3]:
((60000, 28, 28, 1), (10000, 28, 28, 1), (60000,), (10000,))

2. Create And Train Model

In this section, we have created a convolutional neural network for the image classification task and trained it. The network consists of 3 convolution layers and 1 dense layer. The three convolution layers have 48, 32, and 16 filters respectively, each with a (3,3) kernel, "same" padding (so the 28x28 spatial size is preserved), and the relu activation function. The output of the third convolution layer is flattened and fed to the dense layer, which has 10 output units (one per class) and a softmax activation function.

After creating and initializing the network, we have compiled it with the Adam optimizer, sparse categorical cross-entropy loss (our labels are integer class indices rather than one-hot vectors), and the accuracy metric.
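Because `Y_train` holds integer class indices (0 to 9) rather than one-hot vectors, the sparse variant of cross-entropy fits here. The NumPy sketch below, with made-up probabilities for illustration, shows that sparse categorical cross-entropy is the mean of -log(probability assigned to the true class), which is the same value as ordinary categorical cross-entropy computed on one-hot labels. It ignores details such as any numerical clipping Keras may apply internally.

```python
import numpy as np

# Made-up softmax outputs for 3 samples and 4 classes (each row sums to 1)
probs = np.array([[0.70, 0.10, 0.10, 0.10],
                  [0.05, 0.80, 0.10, 0.05],
                  [0.25, 0.25, 0.25, 0.25]])
labels = np.array([0, 1, 3])  # integer class indices, as in Fashion MNIST

# Sparse form: pick the probability of the true class for each sample
sparse_ce = np.mean(-np.log(probs[np.arange(len(labels)), labels]))

# Dense form: the same loss computed against one-hot encoded labels
one_hot = np.eye(probs.shape[1])[labels]
dense_ce = np.mean(-np.sum(one_hot * np.log(probs), axis=1))

print(sparse_ce, dense_ce)  # both about 0.6554
```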

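At last, we have trained the network for 10 epochs with a batch size of 256, which gives ceil(60000 / 256) = 235 steps per epoch. Note that the test set is passed as validation_data, so the validation metrics below are computed on the same images we evaluate in the next section. Training accuracy rises from 81.8% after the first epoch to 95.6% after the tenth, and validation accuracy reaches about 91%. Validation loss is lowest at epoch 7 (0.2520) and rises slightly afterwards while training loss keeps falling, a hint of mild overfitting.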

In [4]:
from tensorflow.keras.models import Sequential
from tensorflow.keras import layers

model = Sequential([
    layers.Input(shape=X_train.shape[1:]),
    layers.Conv2D(filters=48, kernel_size=(3,3), padding="same", activation="relu"),
    layers.Conv2D(filters=32, kernel_size=(3,3), padding="same", activation="relu"),
    layers.Conv2D(filters=16, kernel_size=(3,3), padding="same", activation="relu"),

    layers.Flatten(),
    layers.Dense(len(classes), activation="softmax")
])

model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d (Conv2D)              (None, 28, 28, 48)        480
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 28, 28, 32)        13856
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 28, 28, 16)        4624
_________________________________________________________________
flatten (Flatten)            (None, 12544)             0
_________________________________________________________________
dense (Dense)                (None, 10)                125450
=================================================================
Total params: 144,410
Trainable params: 144,410
Non-trainable params: 0
_________________________________________________________________
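The parameter counts in the summary can be checked by hand. A Conv2D layer has (kernel height x kernel width x input channels + 1) x filters parameters, the +1 being the bias of each filter, and a Dense layer has (inputs + 1) x units. The small Python sketch below, with helper names of our own choosing, reproduces the numbers above.

```python
def conv2d_params(kernel, in_channels, filters):
    """Weights of a kernel x kernel convolution plus one bias per filter."""
    return (kernel * kernel * in_channels + 1) * filters

def dense_params(inputs, units):
    """Weight matrix plus one bias per unit."""
    return (inputs + 1) * units

conv1 = conv2d_params(3, 1, 48)    # grayscale input has 1 channel
conv2 = conv2d_params(3, 48, 32)
conv3 = conv2d_params(3, 32, 16)
flat = 28 * 28 * 16                # "same" padding keeps the 28x28 spatial size
dense = dense_params(flat, 10)

print(conv1, conv2, conv3, flat, dense)  # 480 13856 4624 12544 125450
print(conv1 + conv2 + conv3 + dense)     # 144410
```

The dense layer holds almost 87% of the weights, because flattening keeps every one of the 28x28x16 activations.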
In [5]:
model.compile("adam", "sparse_categorical_crossentropy", ["accuracy"])
In [6]:
model.fit(X_train, Y_train, batch_size=256, epochs=10, validation_data=(X_test, Y_test))
Epoch 1/10
235/235 [==============================] - 62s 261ms/step - loss: 0.5122 - accuracy: 0.8177 - val_loss: 0.3762 - val_accuracy: 0.8709
Epoch 2/10
235/235 [==============================] - 60s 257ms/step - loss: 0.3192 - accuracy: 0.8862 - val_loss: 0.3237 - val_accuracy: 0.8839
Epoch 3/10
235/235 [==============================] - 61s 259ms/step - loss: 0.2714 - accuracy: 0.9030 - val_loss: 0.2903 - val_accuracy: 0.8958
Epoch 4/10
235/235 [==============================] - 61s 258ms/step - loss: 0.2363 - accuracy: 0.9157 - val_loss: 0.2925 - val_accuracy: 0.8982
Epoch 5/10
235/235 [==============================] - 61s 260ms/step - loss: 0.2119 - accuracy: 0.9241 - val_loss: 0.2648 - val_accuracy: 0.9076
Epoch 6/10
235/235 [==============================] - 60s 255ms/step - loss: 0.1890 - accuracy: 0.9315 - val_loss: 0.2750 - val_accuracy: 0.9030
Epoch 7/10
235/235 [==============================] - 61s 260ms/step - loss: 0.1697 - accuracy: 0.9389 - val_loss: 0.2520 - val_accuracy: 0.9128
Epoch 8/10
235/235 [==============================] - 62s 263ms/step - loss: 0.1532 - accuracy: 0.9450 - val_loss: 0.2666 - val_accuracy: 0.9098
Epoch 9/10
235/235 [==============================] - 61s 261ms/step - loss: 0.1385 - accuracy: 0.9509 - val_loss: 0.2698 - val_accuracy: 0.9136
Epoch 10/10
235/235 [==============================] - 61s 260ms/step - loss: 0.1226 - accuracy: 0.9559 - val_loss: 0.2883 - val_accuracy: 0.9104
Out[6]:
<keras.callbacks.History at 0x7f038eaefc10>

3. Evaluate Network Performance

In this section, we have evaluated the performance of the network by calculating accuracy, the confusion matrix, and the classification report on test predictions. The test accuracy of 91.04% shows that the model does a good job overall. A closer look at the classification report and confusion matrix shows that the model mostly confuses the upper-body garments T-shirt/top, Pullover, Coat, and Shirt. Coat (0.80), Shirt (0.81), and T-shirt/top (0.83) have the lowest recall, Shirt has the lowest precision (0.71), and the most frequent single mistake is predicting Shirt for a T-shirt/top (130 of 1000 test images).

We have used functions from scikit-learn to calculate these ML metrics. We have a separate tutorial that covers most of them in detail.

In [7]:
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

Y_test_preds = model.predict(X_test)
Y_test_preds = np.argmax(Y_test_preds, axis=1)

print("Test Accuracy : {}".format(accuracy_score(Y_test, Y_test_preds)))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_test, Y_test_preds))
print("\nClassification Report :")
print(classification_report(Y_test, Y_test_preds, target_names=class_labels))
Test Accuracy : 0.9104

Confusion Matrix :
[[830   1  13  20   2   0 130   0   4   0]
 [  1 981   1   9   3   0   3   0   2   0]
 [ 19   1 860  13  34   0  71   0   2   0]
 [ 17   1   7 938  14   0  23   0   0   0]
 [  3   3  66  30 797   0 101   0   0   0]
 [  0   0   0   0   1 983   1  10   0   5]
 [ 84   2  38  29  35   0 805   0   7   0]
 [  0   0   0   0   0   6   0 986   0   8]
 [  9   0   4   2   1   2   7   4 971   0]
 [  1   0   0   0   0   7   0  39   0 953]]

Classification Report :
              precision    recall  f1-score   support

 T-shirt/top       0.86      0.83      0.85      1000
     Trouser       0.99      0.98      0.99      1000
    Pullover       0.87      0.86      0.86      1000
       Dress       0.90      0.94      0.92      1000
        Coat       0.90      0.80      0.84      1000
      Sandal       0.98      0.98      0.98      1000
       Shirt       0.71      0.81      0.75      1000
     Sneaker       0.95      0.99      0.97      1000
         Bag       0.98      0.97      0.98      1000
  Ankle boot       0.99      0.95      0.97      1000

    accuracy                           0.91     10000
   macro avg       0.91      0.91      0.91     10000
weighted avg       0.91      0.91      0.91     10000
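The per-class numbers in the report follow directly from the confusion matrix (rows are actual classes, columns are predicted classes): recall divides each diagonal entry by its row sum, and precision divides it by its column sum. The NumPy sketch below recomputes them from the matrix printed above and also finds the largest off-diagonal entry, which is the most frequent confusion.

```python
import numpy as np

class_labels = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
                "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]

# Confusion matrix copied from the output above (rows: actual, columns: predicted)
cm = np.array([
    [830,   1,  13,  20,   2,   0, 130,   0,   4,   0],
    [  1, 981,   1,   9,   3,   0,   3,   0,   2,   0],
    [ 19,   1, 860,  13,  34,   0,  71,   0,   2,   0],
    [ 17,   1,   7, 938,  14,   0,  23,   0,   0,   0],
    [  3,   3,  66,  30, 797,   0, 101,   0,   0,   0],
    [  0,   0,   0,   0,   1, 983,   1,  10,   0,   5],
    [ 84,   2,  38,  29,  35,   0, 805,   0,   7,   0],
    [  0,   0,   0,   0,   0,   6,   0, 986,   0,   8],
    [  9,   0,   4,   2,   1,   2,   7,   4, 971,   0],
    [  1,   0,   0,   0,   0,   7,   0,  39,   0, 953],
])

correct = np.diag(cm)
recall = correct / cm.sum(axis=1)      # per actual class (row sums)
precision = correct / cm.sum(axis=0)   # per predicted class (column sums)
accuracy = correct.sum() / cm.sum()

# Most frequent mistake: the largest entry once the diagonal is zeroed out
off_diag = cm.copy()
np.fill_diagonal(off_diag, 0)
actual, predicted = np.unravel_index(off_diag.argmax(), off_diag.shape)

print("Accuracy :", accuracy)  # 0.9104
for name, p, r in zip(class_labels, precision, recall):
    print(f"{name:12s} precision={p:.2f} recall={r:.2f}")
print("Most common confusion:", class_labels[actual], "->", class_labels[predicted],
      off_diag[actual, predicted])
```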

4. Explain Predictions Using Lime Image Explainer

In this section, we have explained predictions made by our model using the image explainer from the lime Python library. To explain a prediction with lime, we first create an instance of LimeImageExplainer. Then, we call its explain_instance() method, which returns an ImageExplanation instance. The explanation stores a weight for each image segment (superpixel) describing how much it contributed to the prediction of each explained label. We can call the get_image_and_mask() method on the explanation to generate an image and a mask that can be plotted to see which segments contributed to the prediction.
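To make these steps concrete, below is a heavily simplified NumPy sketch of what an image explainer does internally. It is our own illustration, not lime's source code: the 6x6 toy image, its four square segments, the `toy_classifier` function, and the plain weighted least-squares fit (lime uses Ridge regression by default) are all assumptions. Each perturbed sample is a binary vector that switches segments on or off; switched-off segments are filled with their mean value (our understanding of lime's behaviour when hide_color is None), the classifier is queried, each sample is weighted by an exponential kernel on its cosine distance to the unperturbed image, and a weighted linear fit yields one weight per segment.

```python
import numpy as np

def toy_classifier(batch):
    """Hypothetical black box: P(class 1) is the brightest pixel of the top-left 3x3 block."""
    p1 = batch[:, :3, :3].max(axis=(1, 2))
    return np.stack([1.0 - p1, p1], axis=1)

def explain_segments(image, segments, classifier_fn, label,
                     num_samples=200, kernel_width=0.25, seed=0):
    """Simplified LIME for a 2-D image: returns (segment id, weight) sorted by |weight|."""
    rng = np.random.default_rng(seed)
    seg_ids = np.unique(segments)
    # Replacement image: every segment filled with its own mean value
    fudged = image.copy()
    for s in seg_ids:
        fudged[segments == s] = image[segments == s].mean()
    # Binary perturbations: 1 keeps a segment, 0 hides it; row 0 is the unperturbed image
    data = rng.integers(0, 2, size=(num_samples, seg_ids.size))
    data[0, :] = 1
    perturbed = np.array([np.where(np.isin(segments, seg_ids[row == 0]), fudged, image)
                          for row in data])
    preds = classifier_fn(perturbed)[:, label]
    # Cosine distance of each binary vector to the all-ones vector (all-zero rows get distance 1)
    ones = np.ones(seg_ids.size)
    norms = np.linalg.norm(data, axis=1) * np.linalg.norm(ones)
    cos_sim = np.divide(data @ ones, norms, out=np.zeros(num_samples), where=norms > 0)
    weights = np.sqrt(np.exp(-((1.0 - cos_sim) ** 2) / kernel_width ** 2))
    # Weighted least squares with an intercept; the coefficients are the segment weights
    A = np.column_stack([np.ones(num_samples), data])
    sw = np.sqrt(weights)
    coef, *_ = np.linalg.lstsq(A * sw[:, None], preds * sw, rcond=None)
    return sorted(zip(seg_ids.tolist(), coef[1:].tolist()), key=lambda t: -abs(t[1]))

# 6x6 toy image with one bright pixel, split into four 3x3 quadrant segments (ids 0-3)
image = np.zeros((6, 6))
image[0, 0] = 1.0
block = np.arange(6) // 3
segments = block[:, None] * 2 + block[None, :]

result = explain_segments(image, segments, toy_classifier, label=1)
print(result)  # segment 0 (top-left) gets a weight of about 0.89; the others stay near 0
```

Because the toy classifier only looks at the top-left block, the fit assigns essentially all the weight to segment 0. A sorted per-segment weight list like this is what get_image_and_mask() later turns into a mask.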

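Below, we have first created an instance of LimeImageExplainer using the LimeImageExplainer() constructor from the lime_image sub-module of lime, setting random_state=123 for reproducibility. We have not covered the constructor's parameters in detail here; one worth knowing is kernel_width (default 0.25), which controls how quickly a perturbed sample's weight decays with its distance from the original image. Please refer to the lime documentation if you want to tweak the default settings.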

In [8]:
from lime import lime_image
In [9]:
explainer = lime_image.LimeImageExplainer(random_state=123)

explainer
Out[9]:
<lime.lime_image.LimeImageExplainer at 0x7f0311846510>

Create Prediction Function

In this section, we have created a simple function that takes images as input and makes predictions. This function will be required by explain_instance() method later. It simply returns 10 probabilities for each sample.

Please NOTE that our function takes a batch of color (RGB) images as input and converts them to grayscale before calling the model. The reason is that explain_instance() accepts a 2D grayscale image, but converts it to RGB with gray2rgb before generating perturbed samples, so the prediction function receives batches of shape (N, 28, 28, 3).
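The gray-to-RGB-to-gray round trip is lossless here because gray2rgb copies the single channel into all three channels and rgb2gray's luminance weights (0.2125, 0.7154, 0.0721, as documented by scikit-image) sum to 1. The NumPy sketch below mimics both conversions with those weights and passes a batch through a hypothetical stand-in model, to show the shapes the prediction function must handle: a batch of shape (N, 28, 28, 3) in and (N, 10) probabilities out. The helpers `to_rgb`, `to_gray`, and `stand_in_model` are illustrative, not scikit-image or Keras functions.

```python
import numpy as np

def to_rgb(gray):
    """Mimics gray2rgb: repeat the single channel three times along a new last axis."""
    return np.repeat(gray[..., np.newaxis], 3, axis=-1)

def to_gray(rgb):
    """Mimics rgb2gray with scikit-image's luminance weights."""
    return rgb @ np.array([0.2125, 0.7154, 0.0721])

def stand_in_model(x):
    """Hypothetical model: uniform probabilities over 10 classes for each sample."""
    return np.full((x.shape[0], 10), 0.1)

def make_prediction(color_imgs):
    gray = to_gray(color_imgs).reshape(-1, 28, 28, 1)  # (N, 28, 28, 3) -> (N, 28, 28, 1)
    return stand_in_model(gray)

rng = np.random.default_rng(0)
batch_gray = rng.random((10, 28, 28))   # LIME predicts in batches (default batch_size=10)
batch_rgb = to_rgb(batch_gray)          # what the prediction function receives
print(batch_rgb.shape)                  # (10, 28, 28, 3)
print(np.allclose(to_gray(batch_rgb), batch_gray))  # True: the round trip is lossless
print(make_prediction(batch_rgb).shape)  # (10, 10)
```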

In [10]:
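from skimage.color import gray2rgb, rgb2gray

def make_prediction(color_img):
    # LIME passes a batch of RGB images of shape (N, 28, 28, 3);
    # convert back to grayscale and add the channel axis the model expects
    gray_img = rgb2gray(color_img).reshape(-1,28,28,1)
    # Returns an (N, 10) array of class probabilities
    preds = model.predict(gray_img)
    return preds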
In [11]:
colored_image = gray2rgb(X_test[0].squeeze())

preds = make_prediction(colored_image)

preds.shape
Out[11]:
(1, 10)

4.1 Explain True Prediction

In this section, we have explained the correct prediction made by our model. We have created separate visualizations showing pixels that contribute positively and negatively to prediction.

Below, we have first randomly selected a sample from the test data and printed its actual and predicted labels. Then, we have called explain_instance() with the test sample and the prediction function. With the default settings, lime segments the image with the quickshift algorithm, generates 1000 perturbed samples, and explains the top 5 predicted labels. The call returns an ImageExplanation instance, which we'll use next to generate an image and a mask for visualization.

In [12]:
rng = np.random.RandomState(42)
idx = rng.choice(range(len(X_test)))

print("Actual Target Value     : {}".format(mapping[Y_test[idx]]))
pred = model.predict(X_test[idx:idx+1]).argmax(axis=1)[0]
print("Predicted Target Values : {}".format(mapping[pred]))

explanation = explainer.explain_instance(X_test[idx].squeeze(), make_prediction, random_seed=123)

explanation
Actual Target Value     : T-shirt/top
Predicted Target Values : T-shirt/top
Out[12]:
<lime.lime_image.ImageExplanation at 0x7f031186e7d0>

Pixels Contributing Positively to Prediction

In this section, we have generated an image and mask that highlight the pixels contributing positively to the prediction. We pass the actual label to the get_image_and_mask() method of the explanation, along with positive_only=True (keep only segments with positive weights) and hide_rest=True (black out everything else). By default, at most 5 segments are kept (num_features=5).

Then, in the next cell, we have displayed the original image, the image returned by get_image_and_mask(), the mask, and the image with the mask boundaries drawn on it. This lets us compare them and see which pixels contributed to the prediction. We can notice that the top and bottom pixels of the T-shirt contribute positively to the prediction.
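To clarify what positive_only, negative_only, hide_rest, and num_features control, here is a simplified NumPy re-implementation of the mask-building step. It is a sketch based on our reading of lime's behaviour, not its actual code: `image_and_mask` is a hypothetical helper, the 4x4 segment map and the segment weights are made up, and lime's min_weight threshold and its mixed mode (both flags False) are omitted. The `exp` list plays the role of the sorted (segment, weight) pairs an explanation stores per label.

```python
import numpy as np

def image_and_mask(image, segments, exp, positive_only=True, negative_only=False,
                   hide_rest=False, num_features=5):
    """Hypothetical, simplified stand-in for get_image_and_mask().

    exp is a list of (segment id, weight) pairs sorted by decreasing |weight|.
    """
    if positive_only == negative_only:
        raise ValueError("this sketch needs exactly one of positive_only / negative_only")
    sign = 1 if positive_only else -1
    # Take the strongest segments whose weight has the requested sign
    chosen = [s for s, w in exp if sign * w > 0][:num_features]
    # hide_rest=True starts from a black canvas; otherwise from the full image
    out = np.zeros_like(image) if hide_rest else image.copy()
    mask = np.zeros(segments.shape, dtype=int)
    for s in chosen:
        out[segments == s] = image[segments == s]  # copy the chosen segment's pixels
        mask[segments == s] = 1
    return out, mask

# 4x4 toy image split into four 2x2 segments: ids 0 1 / 2 3
segments = np.repeat(np.repeat(np.array([[0, 1], [2, 3]]), 2, axis=0), 2, axis=1)
image = np.linspace(0.1, 1.0, 16).reshape(4, 4)
exp = [(2, 0.4), (0, -0.3), (3, 0.1), (1, -0.05)]  # made-up weights

pos_img, pos_mask = image_and_mask(image, segments, exp, hide_rest=True, num_features=1)
neg_img, neg_mask = image_and_mask(image, segments, exp, positive_only=False,
                                   negative_only=True, hide_rest=True)
print(pos_mask)  # only segment 2 (bottom-left block) is marked
print(neg_mask)  # segments 0 and 1 (top row of blocks) are marked
```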

In [13]:
img, mask = explanation.get_image_and_mask(Y_test[idx], positive_only=True, hide_rest=True)

img.shape, mask.shape
Out[13]:
((28, 28, 3), (28, 28))
In [ ]:
from skimage.segmentation import mark_boundaries
import matplotlib.pyplot as plt

def plot_comparison(main_image, img, mask):
    """Show the original image, LIME's image, the mask, and the mask outline side by side."""
    fig, axes = plt.subplots(1, 4, figsize=(15,5))

    axes[0].imshow(main_image.squeeze(), cmap="gray")  # drop the trailing channel axis
    axes[0].set_title("Original Image")
    axes[1].imshow(img)
    axes[1].set_title("Image")
    axes[2].imshow(mask)
    axes[2].set_title("Mask")
    # Outline the selected segments in green on top of LIME's image
    axes[3].imshow(mark_boundaries(img, mask, color=(0,1,0)))
    axes[3].set_title("Image+Mask Combined");

plot_comparison(X_test[idx], img, mask)


Pixels Contributing Negatively to Prediction

In this section, we have generated a visualization showing pixels that contribute negatively to the prediction. We have set the negative_only parameter of get_image_and_mask() to True and positive_only to False (positive_only defaults to True, and lime raises an error if both are True). We can notice that some pixels in the middle of the image contribute negatively, meaning they push the prediction away from T-shirt/top; they may instead support some other category.

In [15]:
img, mask = explanation.get_image_and_mask(Y_test[idx], positive_only=False, negative_only=True, hide_rest=True)

img.shape, mask.shape
Out[15]:
((28, 28, 3), (28, 28))
In [ ]:
plot_comparison(X_test[idx], img, mask)


4.2 Explain True Predictions With Segmentation Method

In this section, we have again explained a correct prediction, but this time we have passed a different segmentation method to explain_instance() through its segmentation_fn parameter to see whether it makes a difference. By default, lime segments images with the quickshift algorithm; here we use the felzenszwalb segmentation method from the scikit-image library. Because we reuse the same random seed (42), the same test sample as before is selected, so the two explanations can be compared directly.
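segmentation_fn accepts any callable that takes the image (already converted to RGB by lime) and returns a 2-D integer array assigning each pixel to a segment; felzenszwalb satisfies this with its default arguments. As an illustration, below is a hypothetical grid segmentation written in NumPy, not part of lime or scikit-image, that splits the image into square patches of a fixed size.

```python
import numpy as np

def grid_segmentation(image, patch=4):
    """Hypothetical segmentation_fn: give each patch x patch block its own integer id."""
    h, w = image.shape[:2]
    rows = np.arange(h) // patch   # block row of each pixel row
    cols = np.arange(w) // patch   # block column of each pixel column
    n_cols = -(-w // patch)        # ceil(w / patch) blocks per row
    return rows[:, None] * n_cols + cols[None, :]

segments = grid_segmentation(np.zeros((28, 28, 3)))
print(segments.shape, len(np.unique(segments)))  # (28, 28) 49
```

Such a function could be passed as segmentation_fn=grid_segmentation in the explain_instance() call. Square patches ignore the garment's shape, which makes them a simple point of comparison for the felzenszwalb and quickshift segments.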

In [17]:
from skimage.segmentation import felzenszwalb

rng = np.random.RandomState(42)
idx = rng.choice(range(len(X_test)))

print("Actual Target Value     : {}".format(mapping[Y_test[idx]]))
pred = model.predict(X_test[idx:idx+1]).argmax(axis=1)[0]
print("Predicted Target Values : {}".format(mapping[pred]))

explanation = explainer.explain_instance(X_test[idx].squeeze(), make_prediction,
                                         segmentation_fn=felzenszwalb, random_seed=123)

explanation
Actual Target Value     : T-shirt/top
Predicted Target Values : T-shirt/top
Out[17]:
<lime.lime_image.ImageExplanation at 0x7f02ec1c5250>

Pixels Contributing Positively to Prediction

In this section, we have created a visualization showing pixels that contribute positively to the prediction, using the explanation generated with felzenszwalb segmentation.

In [ ]:
img, mask = explanation.get_image_and_mask(Y_test[idx], positive_only=True, hide_rest=True)

plot_comparison(X_test[idx], img, mask)


Pixels Contributing Negatively to Prediction

In this section, we have created a visualization showing pixels that contribute negatively to the prediction, using the explanation generated with felzenszwalb segmentation.

In [ ]:
img, mask = explanation.get_image_and_mask(Y_test[idx], positive_only=False, negative_only=True, hide_rest=True)

plot_comparison(X_test[idx], img, mask)


4.3 Explain Wrong Prediction

In this section, we have explained a wrong prediction made by our model. This will help us better understand which pixels are contributing to the prediction of the wrong category.

Below, we have first retrieved the indices of samples that our model misclassifies and then randomly selected one of them. The actual category of the selected sample is Pullover, but our model predicts T-shirt/top. We have created an explanation as usual.
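One caveat when explaining a wrong prediction: by default explain_instance() only builds explanations for the top_labels=5 classes with the highest predicted probability, and, to our knowledge, get_image_and_mask() raises a KeyError for a label outside that set. The NumPy sketch below uses made-up probabilities in place of model.predict() output. It finds misclassified samples the same way as the cell below (np.flatnonzero gives the same result as np.argwhere(...).flatten() for a 1-D mask) and checks where each true label ranks; `true_label_rank` is a hypothetical helper.

```python
import numpy as np

# Made-up predicted probabilities for 3 samples over 10 classes (each row sums to 1)
probs = np.array([
    [0.05, 0.05, 0.60, 0.05, 0.05, 0.05, 0.05, 0.04, 0.03, 0.03],
    [0.30, 0.02, 0.25, 0.03, 0.10, 0.05, 0.15, 0.04, 0.03, 0.03],
    [0.03, 0.01, 0.03, 0.02, 0.02, 0.02, 0.02, 0.80, 0.03, 0.02],
])
y_true = np.array([2, 2, 1])
y_pred = probs.argmax(axis=1)

# Indices of misclassified samples
wrong = np.flatnonzero(y_true != y_pred)

def true_label_rank(p, label):
    """0-based position of `label` when classes are sorted by descending probability."""
    order = np.argsort(-p, kind="stable")
    return int(np.flatnonzero(order == label)[0])

for i in wrong:
    rank = true_label_rank(probs[i], y_true[i])
    print(i, "true label rank:", rank, "| in default top 5:", rank < 5)
```

If the true label falls outside the top five, one option is to raise top_labels in the explain_instance() call; with 10 classes, top_labels=10 explains every class.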

In [20]:
from skimage.segmentation import felzenszwalb

rng = np.random.RandomState(42)
idx = rng.choice(np.argwhere(Y_test!=Y_test_preds).flatten())

print("Actual Target Value     : {}".format(mapping[Y_test[idx]]))
pred = model.predict(X_test[idx:idx+1]).argmax(axis=1)[0]
print("Predicted Target Values : {}".format(mapping[pred]))

explanation = explainer.explain_instance(X_test[idx].squeeze(), make_prediction,
                                         segmentation_fn=felzenszwalb, random_seed=123)

explanation
Actual Target Value     : Pullover
Predicted Target Values : T-shirt/top
Out[20]:
<lime.lime_image.ImageExplanation at 0x7f035a4e0b90>

Pixels Contributing Positively to Prediction

In this section, we have created a visualization showing pixels that contribute positively to the actual category (Pullover) of the sample; note that we pass the actual label, not the predicted one, to get_image_and_mask(). We can notice that the highlighted segments leave out important middle pixels that could have contributed to predicting the category Pullover.

In [ ]:
img, mask = explanation.get_image_and_mask(Y_test[idx], positive_only=True, hide_rest=True)

plot_comparison(X_test[idx], img, mask)


Pixels Contributing Negatively to Prediction

In this section, we have created a visualization showing pixels that contribute negatively to predicting the actual category of the sample. We can notice that middle pixels, which we would expect to support Pullover, are contributing negatively. This suggests that the model needs further improvement so that it can pick up these kinds of patterns.

In [ ]:
img, mask = explanation.get_image_and_mask(Y_test[idx], positive_only=False, negative_only=True, hide_rest=True)

plot_comparison(X_test[idx], img, mask)


This ends our small tutorial explaining how we can interpret the predictions made by a Keras image classification network using LIME.

Sunny Solanki
