Updated On : Jul-17,2022 Time Investment : ~10 mins

Haiku (JAX): Glove Embeddings for Text Classification Tasks

Word embeddings are currently one of the most commonly used methods for encoding text data. An embedding represents each token (word) as a real-valued vector, which gives tokens more representational power. We can train our own word embeddings if we have enough data. When we don't, we can use pre-trained embeddings such as GloVe (Global Vectors). GloVe is an unsupervised algorithm for generating word embeddings. Researchers at Stanford have released several versions of GloVe embeddings, each produced by training the algorithm on a different large dataset. These pre-trained embeddings capture the meaning of tokens (words) well, so we can use them when our own data is too small to train good embeddings.

In this tutorial, we explain how to create neural networks that use GloVe word embeddings to solve text classification tasks with the Python deep learning library Haiku. Haiku is a high-level deep learning library built on top of the low-level library JAX. We try different ways of processing the embeddings to get better results. After training the networks, we evaluate their performance by calculating various ML metrics. We then check performance further by explaining predictions on individual text examples with the LIME algorithm.

Below, we have listed important sections of the tutorial to give an overview of the material covered.

Important Sections Of Tutorial

  1. Prepare Data
    • 1.1 Load Glove Embeddings (840B.300d)
    • 1.2 Load Dataset
    • 1.3 Populate Vocabulary and Vectorize Data
    • 1.4 Retrieve GloVe Embeddings for Vectorized Data
  2. Approach 1: Flattened Glove Embeddings
    • 2.1 Define Network
    • 2.2 Define Loss
    • 2.3 Train Network
    • 2.4 Evaluate Network Performance
    • 2.5 Explain Predictions using LIME Algorithm
  3. Approach 2: Averaged Glove Embeddings
  4. Approach 3: Summed Glove Embeddings
  5. Results Summary
  6. Further Suggestions

Below, we have imported the necessary Python libraries and printed the versions that we have used in our tutorial.

Haiku Installation

  • !pip install -U dm-haiku
import haiku as hk

print("Haiku Version :{}".format(hk.__version__))
Haiku Version :0.0.6
import jax

print("JAX Version : {}".format(jax.__version__))
JAX Version : 0.3.13
import optax

print("Optax Version : {}".format(optax.__version__))
Optax Version : 0.1.2
import torchtext

print("Torchtext Version : {}".format(torchtext.__version__))
Torchtext Version : 0.10.1

1. Prepare Data

In this section, we prepare the data that will be given to the neural network. As mentioned earlier, we encode text data using word embeddings, specifically GloVe word embeddings. We follow the steps below to prepare data for the network.

  1. Load GloVe Embeddings in memory as a dictionary whose keys are words and values are trained embeddings.
  2. Load Dataset.
  3. Tokenize text examples and populate the vocabulary. The vocabulary is a simple mapping from a token (word) to an integer index. Each token (word) is assigned a unique index starting from 1; index 0 is reserved for padding.
  4. Vectorize text examples. Tokenize each text example and retrieve the integer index of tokens (words) using populated vocabulary.
  5. Create a matrix of GloVe word embeddings. This matrix will be set as a weight matrix of the embedding layer of the neural network because we want to use these trained embeddings.

The output of step 4 will be given to the network for training. Don't worry if the steps are not fully clear yet; they will become clear when we implement them below.

1.1 Load GloVe Embeddings (840B.300d)

In this section, we first download GloVe embeddings from the Stanford NLP website. We use the '840B.300d' version, which has a vocabulary of 2.2M words and an embedding length of 300. For more details on the available GloVe embeddings, please check the GloVe page on the Stanford NLP website.

After downloading the zip file, we unzip it and load the embeddings into memory as a dictionary whose keys are tokens (words) and whose values are real-valued vectors (embeddings) of length 300.

!wget https://nlp.stanford.edu/data/glove.840B.300d.zip
--2022-06-08 05:53:54--  https://nlp.stanford.edu/data/glove.840B.300d.zip
Resolving nlp.stanford.edu (nlp.stanford.edu)... 171.64.67.140
Connecting to nlp.stanford.edu (nlp.stanford.edu)|171.64.67.140|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: http://downloads.cs.stanford.edu/nlp/data/glove.840B.300d.zip [following]
--2022-06-08 05:53:54--  http://downloads.cs.stanford.edu/nlp/data/glove.840B.300d.zip
Resolving downloads.cs.stanford.edu (downloads.cs.stanford.edu)... 171.64.64.22
Connecting to downloads.cs.stanford.edu (downloads.cs.stanford.edu)|171.64.64.22|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2176768927 (2.0G) [application/zip]
Saving to: ‘glove.840B.300d.zip’

glove.840B.300d.zip 100%[===================>]   2.03G  5.04MB/s    in 6m 51s

2022-06-08 06:00:46 (5.05 MB/s) - ‘glove.840B.300d.zip’ saved [2176768927/2176768927]

!unzip glove.840B.300d.zip
Archive:  glove.840B.300d.zip
  inflating: glove.840B.300d.txt
import gc
gc.collect()
75
%%time

import numpy as np
import gc

glove_embeddings = {}
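with open("glove.840B.300d.txt", encoding="utf-8") as f:
    for line in f:
        parts = line.split()
        try:
            glove_embeddings[parts[0]] = np.array(parts[1:], dtype=np.float32)
        except (ValueError, IndexError):
            ## Skip empty lines and lines whose token contains spaces (float conversion fails)
            continue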
CPU times: user 2min 1s, sys: 4.78 s, total: 2min 6s
Wall time: 2min 6s
embeddings = glove_embeddings["the"]

embeddings.shape, embeddings.dtype
((300,), dtype('float32'))
gc.collect()
63
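
Each line of the GloVe text file holds a token followed by its vector components, separated by single spaces, and the loader above skips any line that does not parse cleanly. As a minimal sketch of an alternative, assuming the last `embed_len` fields of every line are the vector, splitting from the right keeps tokens that themselves contain spaces. The helper name `parse_glove_line` and the toy 3-dimensional lines are illustrative and not part of the tutorial.

```python
import io

import numpy as np


def parse_glove_line(line, embed_len):
    """Split one GloVe text line into (token, vector).

    Splitting from the right means a token that itself contains spaces
    is kept whole instead of breaking the float conversion.
    """
    parts = line.rstrip("\n").rsplit(" ", embed_len)
    token, values = parts[0], parts[1:]
    return token, np.asarray(values, dtype=np.float32)


# Toy stand-in for glove.840B.300d.txt with 3-dimensional vectors.
toy_file = io.StringIO(
    "the 0.1 0.2 0.3\n"
    ". . . 0.4 0.5 0.6\n"
)

toy_embeddings = {}
for raw_line in toy_file:
    token, vector = parse_glove_line(raw_line, embed_len=3)
    toy_embeddings[token] = vector

print(sorted(toy_embeddings))           # tokens that were loaded
print(toy_embeddings[". . ."].shape)    # (3,)
```

With the real file, `embed_len` would be 300. Whether recovering these extra tokens matters depends on the dataset's vocabulary.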

1.2 Load Dataset

In this section, we load the dataset that we are going to use for our text classification task. We use the AG News dataset available from the Python torchtext library. The dataset contains news articles from 4 different categories (["World", "Sports", "Business", "Sci/Tech"]).

import numpy as np
import gc

train_dataset, test_dataset = torchtext.datasets.AG_NEWS()

X_train_text, Y_train = [], []
for Y, X in train_dataset:
    X_train_text.append(X)
    Y_train.append(Y)

X_test_text, Y_test = [], []
for Y, X in test_dataset:
    X_test_text.append(X)
    Y_test.append(Y)

unique_classes = list(set(Y_train))
target_classes = ["World", "Sports", "Business", "Sci/Tech"]

#Subtracted 1 from labels to bring range from 1-4 to 0-3
Y_train, Y_test = np.array(Y_train) - 1, np.array(Y_test) - 1

len(X_train_text), len(X_test_text)
train.csv: 29.5MB [00:00, 85.8MB/s]
test.csv: 1.86MB [00:00, 39.3MB/s]
(120000, 7600)

1.3 Populate Vocabulary and Vectorize Data

In this section, we have performed steps 3 and 4 listed earlier.

First, we populate the vocabulary of unique tokens (words). We create an instance of the Tokenizer class available from the Python deep learning library Keras and call its fit_on_texts() method with the train and test examples. This method loops through the text examples, tokenizes them (splits them into tokens), and populates the vocabulary of unique tokens. The vocabulary is available through the index_word and word_index attributes of the tokenizer object. We also print the size of the vocabulary.

After populating the vocabulary, we vectorize the text data by calling the texts_to_sequences() method on the tokenizer object, once with the train examples and once with the test examples. This method tokenizes each text example and retrieves the integer index of each token from the populated vocabulary. The output is a list of integers per text example. Since text examples contain different numbers of words, these lists have different lengths. Our network needs inputs of constant length, so we keep 50 tokens per text example by calling pad_sequences() on the vectorized output. Examples with more than 50 indexes are truncated to 50, and examples with fewer than 50 are padded at the end with 0s.

After vectorizing the data, we convert the arrays to JAX arrays, as Haiku networks work on JAX arrays.

from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from jax import numpy as jnp

max_tokens = 50

tokenizer = Tokenizer()
tokenizer.fit_on_texts(X_train_text+X_test_text)

print("Vocabulary Size : {}".format(len(tokenizer.index_word)))

print("Vocabulary Starts @ Index 1: {}".format(list(tokenizer.index_word.items())[:5]))

X_train_vect = pad_sequences(tokenizer.texts_to_sequences(X_train_text), maxlen=max_tokens, padding="post", truncating="post", value=0.)
X_test_vect  = pad_sequences(tokenizer.texts_to_sequences(X_test_text), maxlen=max_tokens, padding="post", truncating="post", value=0.)

X_train_vect, X_test_vect = jnp.array(X_train_vect, dtype=jnp.int32), jnp.array(X_test_vect, dtype=jnp.int32)
Y_train, Y_test = jnp.array(Y_train), jnp.array(Y_test)

X_train_vect.shape, X_test_vect.shape
Vocabulary Size : 72002
Vocabulary Starts @ Index 1: [(1, 'the'), (2, 'to'), (3, 'a'), (4, 'of'), (5, 'in')]
((120000, 50), (7600, 50))
print(X_train_vect[:3])
[[  444   440  1697 15012   108    64     1   852    21    21   739  8198
    444  6337 10243  2965     4  5937 26696    40  4014   801   335     0
      0     0     0     0     0     0     0     0     0     0     0     0
      0     0     0     0     0     0     0     0     0     0     0     0
      0     0]
 [15470  1111   871  1313  4306    21    21   919   809   359 15470    99
    101    22     3  4508     8   504   511 13730     6 15471  1514  2135
      5     1   522   247    22  3938  2289    15  6459     7   209   368
      4     1   129     0     0     0     0     0     0     0     0     0
      0     0]
 [   53     6   379  4509 26697   770    21    21  2446   467    90  1885
   1280    66     1   379     6     1   770     8   285    40   190     2
   5766    34     1   296   129   111    82   230     1  6391     4     1
   1208 15472     0     0     0     0     0     0     0     0     0     0
      0     0]]
# what is word 444

print(tokenizer.index_word[444])

## How many times it comes in first text document?? 
print()
print(X_train_text[0]) ## two times
wall

Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again.
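
To make the vocabulary and padding steps of this section concrete, here is a simplified pure-Python sketch. It is not the Keras implementation (the Keras Tokenizer also strips punctuation by default, for example), and the helper names `build_vocab` and `vectorize` are made up for illustration. It only mirrors the two ideas used above: indexes start at 1 with the most frequent token first, and every example is truncated or zero-padded at the end to a fixed length.

```python
from collections import Counter


def build_vocab(texts):
    """Map each token to an integer index, most frequent token first.

    Index 0 is left unused so it can serve as the padding value.
    """
    counts = Counter(token for text in texts for token in text.lower().split())
    return {token: idx for idx, (token, _) in enumerate(counts.most_common(), start=1)}


def vectorize(text, vocab, max_tokens):
    """Turn text into exactly max_tokens indexes (truncate or pad at the end)."""
    indices = [vocab[token] for token in text.lower().split() if token in vocab]
    indices = indices[:max_tokens]
    return indices + [0] * (max_tokens - len(indices))


toy_texts = ["the cat sat on the mat", "the dog barked"]
toy_vocab = build_vocab(toy_texts)

print(toy_vocab["the"])                                              # 1, the most frequent token
print(vectorize("the dog sat", toy_vocab, max_tokens=5))             # [1, 6, 3, 0, 0]
print(vectorize("the cat sat on the mat", toy_vocab, max_tokens=5))  # [1, 2, 3, 4, 1]
```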

1.4 Retrieve GloVe Embeddings for Vectorized Data

In this section, we create the GloVe embedding matrix that will be set as the weight matrix of the embedding layer of the network. We first create an array of zeros of shape (vocab_size + 1, embed_len) = (72003, 300); the extra row at index 0 corresponds to the padding value. We then loop through the vocabulary, which maps integer indexes starting at 1 to words. For each index, we look up the word's embedding in the GloVe dictionary loaded earlier and store it in the matrix row with that index. Words that are not present in the GloVe dictionary keep an all-zero vector.

The input to the network consists of indexes that represent words, and row i of the embedding matrix holds the embedding of the word with index i. E.g., the word 'the' has an integer index of 1 in our vocabulary, so its GloVe embedding can be retrieved as 'word_embeddings[1]'.

%%time

embed_len = 300

word_embeddings = np.zeros((len(tokenizer.index_word)+1, embed_len))

for idx, word in tokenizer.index_word.items():
    word_embeddings[idx] = glove_embeddings.get(word, np.zeros(embed_len))

word_embeddings = jnp.array(word_embeddings)
CPU times: user 316 ms, sys: 135 ms, total: 451 ms
Wall time: 455 ms
word_embeddings[1][:10]
DeviceArray([ 0.27204  , -0.06203  , -0.1884   ,  0.023225 , -0.018158 ,
              0.0067192, -0.13877  ,  0.17708  ,  0.17709  ,  2.5882   ],            dtype=float32)
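
The following NumPy sketch repeats the matrix construction on toy data and shows what the embedding lookup does to array shapes. All names prefixed with `toy_` are illustrative stand-ins for `glove_embeddings`, `tokenizer.index_word`, and `word_embeddings`. The coverage figure is an extra check that is not computed in the tutorial; it counts how many vocabulary words actually received a GloVe vector.

```python
import numpy as np

toy_embed_len = 3  # toy size; the tutorial uses 300

# Toy stand-ins for glove_embeddings and tokenizer.index_word.
toy_glove = {
    "the": np.array([0.1, 0.2, 0.3], dtype=np.float32),
    "cat": np.array([0.4, 0.5, 0.6], dtype=np.float32),
}
toy_index_word = {1: "the", 2: "cat", 3: "zzyzx"}  # "zzyzx" has no GloVe vector

# Row 0 is reserved for padding, so the matrix needs len(vocab) + 1 rows.
toy_matrix = np.zeros((len(toy_index_word) + 1, toy_embed_len), dtype=np.float32)
found = 0
for idx, word in toy_index_word.items():
    vector = toy_glove.get(word)
    if vector is not None:
        toy_matrix[idx] = vector
        found += 1

coverage = found / len(toy_index_word)
print("Coverage : {:.2f}".format(coverage))

# Indexing with a (batch, max_tokens) integer array returns
# a (batch, max_tokens, embed_len) array of embeddings.
toy_batch = np.array([[1, 2, 0, 0], [2, 3, 1, 0]])
looked_up = toy_matrix[toy_batch]
print(looked_up.shape)  # (2, 4, 3)
```

A low coverage value on real data would suggest that many tokens are mapped to all-zero vectors, for instance because of casing or punctuation differences between the tokenizer and the GloVe vocabulary.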

2. Approach 1: Flattened GloVe Embeddings

In this section, we have explained the first approach of using GloVe word embeddings. Our first approach simply flattens word embeddings (stacks them next to each other) of a single text example before giving it to a linear layer for processing. After training the network, we have also evaluated the performance by calculating various ML metrics. We have also used LIME algorithm to further explain predictions made by the network.

2.1 Define Network

In this section, we have defined a network that we'll use for our text classification task. The network consists of three layers (one embedding and two linear).

The first layer of our network is the embedding layer, which we create using the hk.Embed() constructor. We give the vocabulary size and embedding length to the constructor. We also set the embedding_matrix argument to the GloVe embedding matrix created in the previous section, which makes that matrix the weight matrix of the layer. The embedding layer transforms the shape of the data from (batch_size, max_tokens) = (batch_size, 50) to (batch_size, max_tokens, embed_len) = (batch_size, 50, 300).

The output of embedding layer is flattened which transforms shape from (batch_size, 50, 300) to (batch_size, 50 x 300) = (batch_size, 15000).

This flattened output is given to the first linear layer that has 128 output units. It transforms shape to (batch_size, 128) after processing.

The output of the first linear layer is given to the second linear layer, which has 4 output units (the same as the number of target classes). The output of the second linear layer is the prediction of our network.

After defining the network, we have transformed it (using hk.transform()) to pure JAX function form and initialized it. After initializing it, we printed the shape of weights/biases and performed a forward pass for verification purposes. We have also verified that glove embeddings were set properly.

class EmbeddingClassifier(hk.Module):
    def __init__(self):
        super().__init__(name="EmbeddingClassifier")
        self.embedding = hk.Embed(vocab_size=len(tokenizer.word_index)+1, embed_dim=embed_len,
                                  embedding_matrix=word_embeddings, ## Set GloVe Embeddings as Layer Weights
                                  name="Word_Embeddings")
        self.linear1 = hk.Linear(128, name="Dense1")
        self.linear2 = hk.Linear(len(target_classes), name="Dense2")
        self.flatten = hk.Flatten()

    def __call__(self, X_batch):
        x = self.embedding(X_batch) ## (batch_size, max_tokens, embed_len) = (batch_size, 50, 300)
        x = self.flatten(x) ## (batch_size, max_tokens x embed_len) = (batch_size, 15000)
        x = self.linear1(x)
        return self.linear2(x)
def EmbeddingClassifierrNet(x):
    classif = EmbeddingClassifier()
    return classif(x)

embed_classif = hk.transform(EmbeddingClassifierrNet)
rng = jax.random.PRNGKey(42)

params = embed_classif.init(rng, X_train_vect[:5])

print("Weights Type : {}\n".format(type(params)))

for layer_name, weights in params.items():
    print(layer_name)
    #print(weights.keys())
    if "Embeddings" in layer_name:
        print("Embeddings : {}\n".format(weights["embeddings"].shape))
    else:
        print("Weights : {}, Biases : {}\n".format(params[layer_name]["w"].shape,params[layer_name]["b"].shape))
Weights Type : <class 'dict'>

EmbeddingClassifier/~/Word_Embeddings
Embeddings : (72003, 300)

EmbeddingClassifier/~/Dense1
Weights : (15000, 128), Biases : (128,)

EmbeddingClassifier/~/Dense2
Weights : (128, 4), Biases : (4,)

params["EmbeddingClassifier/~/Word_Embeddings"]["embeddings"][1][:10], word_embeddings[1][:10]
(DeviceArray([ 0.27204  , -0.06203  , -0.1884   ,  0.023225 , -0.018158 ,
               0.0067192, -0.13877  ,  0.17708  ,  0.17709  ,  2.5882   ],            dtype=float32),
 DeviceArray([ 0.27204  , -0.06203  , -0.1884   ,  0.023225 , -0.018158 ,
               0.0067192, -0.13877  ,  0.17708  ,  0.17709  ,  2.5882   ],            dtype=float32))
preds = embed_classif.apply(params, rng, X_train_vect[:5])

preds[:5]
DeviceArray([[ 0.3479766 , -0.05430334,  0.00601142,  0.00389299],
             [ 0.13846956, -0.01907502,  0.02674406, -0.02271202],
             [ 0.16998033, -0.20147529, -0.02332148, -0.11550201],
             [-0.10779881, -0.21140411,  0.07253836,  0.1694662 ],
             [ 0.6230196 , -0.14240463,  0.2389124 , -0.03436066]],            dtype=float32)
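
The shape transformations described in this section can be traced with plain NumPy. This is only a sketch with random weights and a toy vocabulary size; it does not use Haiku, and the variable names are chosen for illustration.

```python
import numpy as np

rng_np = np.random.default_rng(0)

batch_size, max_tokens, embed_len = 2, 50, 300
vocab_size, hidden_units, num_classes = 100, 128, 4   # vocab_size is a toy value

embedding_matrix = rng_np.normal(size=(vocab_size, embed_len)).astype(np.float32)
w1 = rng_np.normal(size=(max_tokens * embed_len, hidden_units)).astype(np.float32)
b1 = np.zeros(hidden_units, dtype=np.float32)
w2 = rng_np.normal(size=(hidden_units, num_classes)).astype(np.float32)
b2 = np.zeros(num_classes, dtype=np.float32)

token_ids = rng_np.integers(0, vocab_size, size=(batch_size, max_tokens))

embedded = embedding_matrix[token_ids]           # (2, 50, 300)
flattened = embedded.reshape(batch_size, -1)     # (2, 15000)
hidden = flattened @ w1 + b1                     # (2, 128)
logits = hidden @ w2 + b2                        # (2, 4)

for name, arr in [("embedded", embedded), ("flattened", flattened),
                  ("hidden", hidden), ("logits", logits)]:
    print("{:<10}: {}".format(name, arr.shape))
```

As in the network above, the sketch applies no activation function between the two linear layers.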

2.2 Define Loss

In this section, we define the cross entropy loss function for our text classification task. The function makes predictions with the network, one-hot encodes the actual target values, and calculates the mean loss over the batch using the softmax_cross_entropy() function available from the Python optax library.

def CrossEntropyLoss(params, input_data, actual):
    logits_preds = model.apply(params, rng, input_data)
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(target_classes))
    return optax.softmax_cross_entropy(logits=logits_preds, labels=one_hot_actual).mean()
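
For intuition about what the loss computes, here is a hand-written NumPy version of softmax cross entropy. It is a sketch of the standard formula, not the optax source code: the loss of one example is the negative log-probability that softmax assigns to the correct class.

```python
import numpy as np


def softmax_cross_entropy(logits, one_hot_labels):
    """Per-example cross entropy: -sum(labels * log_softmax(logits))."""
    shifted = logits - logits.max(axis=-1, keepdims=True)   # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -(one_hot_labels * log_probs).sum(axis=-1)


num_classes = 4
labels = np.array([0, 3])
one_hot = np.eye(num_classes)[labels]

# All-zero logits give a uniform prediction, so the loss is log(4).
uniform_logits = np.zeros((2, num_classes))
uniform_loss = softmax_cross_entropy(uniform_logits, one_hot).mean()
print("{:.4f}".format(uniform_loss))   # 1.3863

# Logits that strongly favour the correct class give a loss near zero.
confident_logits = 10.0 * one_hot
confident_loss = softmax_cross_entropy(confident_logits, one_hot).mean()
print(confident_loss < 0.01)           # True
```

The value log(4) ≈ 1.386 is a useful reference point: a 4-class network whose average loss is near it is doing no better than guessing uniformly.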

2.3 Train Network

Now, we'll train the network we defined earlier. To do so, we have created a function that takes train data (X_train, Y_train), validation data (X_val, Y_val), the number of epochs, network parameters, optimizer state, and batch size. It runs the training loop for the given number of epochs. During each epoch, it loops through the training data in batches. For each batch, it performs a forward pass to make predictions, calculates the loss, calculates gradients, and updates network parameters. It records the loss of each batch and prints the average loss over all batches at the end of the epoch, followed by the validation accuracy. Finally, the function returns the updated network parameters.

Please note that during training we exclude the embedding layer from updates, because we want to use the GloVe embeddings as they are. The function does this by replacing the updates of any parameter whose shape matches the embedding matrix with zeros.

from jax import value_and_grad
from tqdm import tqdm
from sklearn.metrics import accuracy_score

def TrainModelInBatches(X_train, Y_train, X_val, Y_val, epochs, params, optimizer_state, batch_size=32):
    for i in range(1, epochs+1):
        batches = jnp.arange(int(np.ceil(X_train.shape[0] / batch_size))) ### Batch Indices (ceil avoids an empty last batch)

        losses = [] ## Record loss of each batch
        for batch in tqdm(batches):
            if batch != batches[-1]:
                start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
            else:
                start, end = int(batch*batch_size), None

            X_batch, Y_batch = X_train[start:end], Y_train[start:end] ## Single batch of data

            loss, gradients = value_and_grad(CrossEntropyLoss)(params, X_batch, Y_batch)

            #params = jax.tree_map(UpdateWeights, params, param_grads) ## Update Params
            updates, optimizer_state = optimizer.update(gradients, optimizer_state)
            ## Prevent Updates to Embedding Layer by setting updates to zeros
            updates = jax.tree_map(lambda x: jnp.zeros(x.shape, dtype=np.float32) if x.shape == word_embeddings.shape else x, updates)
            params = optax.apply_updates(params, updates)
            losses.append(loss) ## Record Loss

        print("CrossEntropy Loss : {:.3f}".format(jnp.array(losses).mean()))
        gc.collect()
        Y_val_preds = model.apply(params, rng, X_val)
        val_acc = accuracy_score(Y_val, jnp.argmax(Y_val_preds, axis=1))
        print("Validation  Accuracy : {:.3f}".format(val_acc))
        gc.collect()
    return params
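
The training function above freezes the embedding layer by comparing shapes, which would also zero the updates of any other parameter that happens to have the same shape as the embedding matrix. As a hedged alternative, the sketch below selects the frozen parameters by module name instead. It assumes the updates form a Haiku-style nested dictionary of the form {module_name: {param_name: array}}; the helper `zero_updates_for` is hypothetical and uses NumPy arrays so that it runs on its own.

```python
import numpy as np


def zero_updates_for(updates, frozen_substring):
    """Return a copy of `updates` with zeros for every matching module.

    `updates` is assumed to be a Haiku-style nested dict:
    {module_name: {param_name: array}}.
    """
    return {
        module_name: {
            param_name: np.zeros_like(value) if frozen_substring in module_name else value
            for param_name, value in module_params.items()
        }
        for module_name, module_params in updates.items()
    }


# Toy updates that mimic the layer names printed earlier.
toy_updates = {
    "EmbeddingClassifier/~/Word_Embeddings": {"embeddings": np.ones((5, 3))},
    "EmbeddingClassifier/~/Dense1": {"w": np.ones((3, 2)), "b": np.ones(2)},
}

masked = zero_updates_for(toy_updates, "Word_Embeddings")

print(masked["EmbeddingClassifier/~/Word_Embeddings"]["embeddings"].sum())  # 0.0
print(masked["EmbeddingClassifier/~/Dense1"]["w"].sum())                    # 6.0
```

Inside a JAX training loop, the same idea would use jnp.zeros_like in place of np.zeros_like.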

Below, we train our network by calling the training routine defined in the previous cell. We set the number of epochs to 8, the batch size to 1024, and the learning rate to 0.001. Then, we initialize the network and the Adam optimizer. Finally, we call our training routine with the necessary parameters. Note that we pass the test set as validation data. The training loss falls from 0.437 to about 0.20 over the epochs, while validation accuracy is highest after the first epoch (0.886) and drops to 0.864 by the last one, which suggests that the network starts overfitting the training data. After training the network, we also verify that the GloVe embeddings were not updated by mistake.

from jax import value_and_grad

rng = jax.random.PRNGKey(42) ## Reproducibility ## Initializes model with same weights each time.
epochs = 8
batch_size = 1024
learning_rate = 1e-3

model = hk.transform(EmbeddingClassifierrNet)
params = model.init(rng, X_train_vect[:5])
optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(params)

final_params = TrainModelInBatches(X_train_vect, Y_train, X_test_vect, Y_test, epochs, params, optimizer_state, batch_size=batch_size)
100%|██████████| 118/118 [02:30<00:00,  1.28s/it]
CrossEntropy Loss : 0.437
Validation  Accuracy : 0.886
100%|██████████| 118/118 [02:23<00:00,  1.22s/it]
CrossEntropy Loss : 0.282
Validation  Accuracy : 0.884
100%|██████████| 118/118 [02:23<00:00,  1.22s/it]
CrossEntropy Loss : 0.249
Validation  Accuracy : 0.881
100%|██████████| 118/118 [02:25<00:00,  1.23s/it]
CrossEntropy Loss : 0.229
Validation  Accuracy : 0.879
100%|██████████| 118/118 [02:34<00:00,  1.31s/it]
CrossEntropy Loss : 0.215
Validation  Accuracy : 0.876
100%|██████████| 118/118 [02:26<00:00,  1.24s/it]
CrossEntropy Loss : 0.204
Validation  Accuracy : 0.873
100%|██████████| 118/118 [02:26<00:00,  1.24s/it]
CrossEntropy Loss : 0.200
Validation  Accuracy : 0.869
100%|██████████| 118/118 [02:27<00:00,  1.25s/it]
CrossEntropy Loss : 0.203
Validation  Accuracy : 0.864
word_embeddings[1][:10], final_params["EmbeddingClassifier/~/Word_Embeddings"]["embeddings"][1][:10], params["EmbeddingClassifier/~/Word_Embeddings"]["embeddings"][1][:10]
(DeviceArray([ 0.27204  , -0.06203  , -0.1884   ,  0.023225 , -0.018158 ,
               0.0067192, -0.13877  ,  0.17708  ,  0.17709  ,  2.5882   ],            dtype=float32),
 DeviceArray([ 0.27204  , -0.06203  , -0.1884   ,  0.023225 , -0.018158 ,
               0.0067192, -0.13877  ,  0.17708  ,  0.17709  ,  2.5882   ],            dtype=float32),
 DeviceArray([ 0.27204  , -0.06203  , -0.1884   ,  0.023225 , -0.018158 ,
               0.0067192, -0.13877  ,  0.17708  ,  0.17709  ,  2.5882   ],            dtype=float32))
gc.collect()
42

2.4 Evaluate Network Performance

In this section, we evaluate the performance of our network on the test dataset by calculating the accuracy score, the classification report (precision, recall, and f1-score), and the confusion matrix. The test accuracy of 0.864 tells us that our network is doing a good job at the task. Per-category precision and recall range from 0.81 to 0.96; Sports is predicted most reliably, while Business and Sci/Tech are most often confused with each other. We have calculated these metrics using scikit-learn.

We have also created a heatmap of confusion matrix to have a better look at the performance of the network per target category. The chart is created using python library scikit-plot. Please do check the below link if you want to learn about the library in-depth as it provides charts for many ML metrics.

Scikit-Plot: Visualizing Machine Learning Algorithm Results & Performance Metrics

from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

#train_preds = model.apply(final_params, rng, X_train_vect)
test_preds = model.apply(final_params, rng, X_test_vect)

#print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, np.argmax(train_preds, axis=1))))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=target_classes))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_test, np.argmax(test_preds, axis=1)))
Test  Accuracy : 0.864

Classification Report :
              precision    recall  f1-score   support

       World       0.84      0.89      0.87      1900
      Sports       0.96      0.91      0.94      1900
    Business       0.81      0.83      0.82      1900
    Sci/Tech       0.85      0.82      0.84      1900

    accuracy                           0.86      7600
   macro avg       0.87      0.86      0.86      7600
weighted avg       0.87      0.86      0.86      7600


Confusion Matrix :
[[1699   44   98   59]
 [  96 1729   42   33]
 [ 128   11 1577  184]
 [  99   12  224 1565]]
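
The numbers in the classification report can be recomputed from the confusion matrix alone. The NumPy sketch below uses the matrix printed above, with rows as actual classes and columns as predicted classes (the scikit-learn convention), and derives accuracy, precision, recall, and f1-score by hand.

```python
import numpy as np

# Confusion matrix printed above (rows: actual class, columns: predicted class).
conf_mat = np.array([
    [1699,   44,   98,   59],
    [  96, 1729,   42,   33],
    [ 128,   11, 1577,  184],
    [  99,   12,  224, 1565],
])
class_names = ["World", "Sports", "Business", "Sci/Tech"]

correct = np.diag(conf_mat)
accuracy = correct.sum() / conf_mat.sum()
precision = correct / conf_mat.sum(axis=0)   # correct / everything predicted as the class
recall = correct / conf_mat.sum(axis=1)      # correct / everything actually in the class
f1 = 2 * precision * recall / (precision + recall)

print("Accuracy : {:.3f}".format(accuracy))
for name, p, r, f in zip(class_names, precision, recall, f1):
    print("{:<9} precision={:.2f} recall={:.2f} f1={:.2f}".format(name, p, r, f))
```

The results match the report: for example, Sports recall is 1729 / 1900 ≈ 0.91, and Business precision is 1577 / 1941 ≈ 0.81.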
from sklearn.metrics import confusion_matrix
import scikitplot as skplt
import matplotlib.pyplot as plt

skplt.metrics.plot_confusion_matrix([target_classes[i] for i in Y_test], [target_classes[i] for i in np.argmax(test_preds, axis=1)],
                                    normalize=True,
                                    title="Confusion Matrix",
                                    cmap="Purples",
                                    hide_zeros=True,
                                    figsize=(5,5)
                                    );
plt.xticks(rotation=90);

2.5 Explain Predictions using LIME Algorithm

In this section, we dig deeper into the performance of our network by looking at individual predictions. We'll use the LIME (Local Interpretable Model-Agnostic Explanations) algorithm to check which words of a text example contribute to predicting a particular target category. This helps us judge whether our model has generalized or not. We'll use the Python library lime, which provides an implementation of the algorithm. It even lets us create a visualization highlighting the words that contribute to a prediction.

To interpret prediction using LIME, we first need to create an instance of LimeTextExplainer which we have done below. We'll use the method of explainer object to create an explanation.

from lime import lime_text

explainer = lime_text.LimeTextExplainer(class_names=target_classes, verbose=True)

In the cell below, we first define a function that takes a list of text examples as input and returns the probabilities predicted by the model for them. This function will be used later by the explainer object. It vectorizes the text, gives it to the network, and applies softmax to the network's output.

Then, we randomly selected a text example from the test dataset and made predictions on it using our trained model. Our model is able to correctly predict the target label Sci/Tech for the selected text example. Next, we'll create an explanation for this selected text example.

import numpy as np

def make_predictions(X_batch_text):
    X_batch = pad_sequences(tokenizer.texts_to_sequences(X_batch_text), maxlen=max_tokens, padding="post", truncating="post", value=0)
    preds = model.apply(final_params, rng, jnp.array(X_batch))
    preds = jax.nn.softmax(preds)
    return np.asarray(preds) ## LIME expects a NumPy array of class probabilities

rnd_st = np.random.RandomState(1234)
idx = rnd_st.randint(1, len(X_test_text))

print("Prediction : ", target_classes[model.apply(final_params, rng, X_test_vect[idx:idx+1]).argmax(axis=-1)[0]])
print("Actual :     ", target_classes[Y_test[idx]])
Prediction :  Sci/Tech
Actual :      Sci/Tech
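
To build intuition for word-level explanations, the sketch below runs a much simpler leave-one-word-out check on a toy classifier. It is not the LIME algorithm (LIME fits a local surrogate model on many randomly perturbed texts), and the keyword table `toy_keyword_class` is a hypothetical stand-in for the trained network. It does follow the same contract as make_predictions(): a list of strings goes in, and an array of shape (n_examples, n_classes) with rows summing to 1 comes out.

```python
import numpy as np

class_names = ["World", "Sports", "Business", "Sci/Tech"]

# Hypothetical keyword table standing in for a trained network: token -> class index.
toy_keyword_class = {"privacy": 3, "technology": 3, "match": 1}


def toy_predict_proba(texts):
    """Return an array of shape (len(texts), num_classes) whose rows sum to 1."""
    logits = np.zeros((len(texts), len(class_names)))
    for row, text in enumerate(texts):
        for token in text.lower().split():
            if token in toy_keyword_class:
                logits[row, toy_keyword_class[token]] += 2.0
    exp = np.exp(logits - logits.max(axis=1, keepdims=True))
    return exp / exp.sum(axis=1, keepdims=True)


text = "new technology raises privacy concerns"
tokens = text.split()
target = class_names.index("Sci/Tech")
base_prob = toy_predict_proba([text])[0, target]

# Drop one word at a time and record how much the target probability falls.
drops = {}
for i, token in enumerate(tokens):
    perturbed = " ".join(tokens[:i] + tokens[i + 1:])
    drops[token] = base_prob - toy_predict_proba([perturbed])[0, target]

for token, drop in sorted(drops.items(), key=lambda item: -item[1]):
    print("{:<12} {:+.3f}".format(token, drop))
```

Words whose removal lowers the target probability the most are the ones the classifier relies on, which is the kind of signal the LIME visualization highlights.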

Below, we have first called explain_instance() method on the explainer object. We have provided a selected text example, prediction function, and target label to the function. This method returns an explanation object which has details about words contributing to predicting the target label Sci/Tech.

Then, we call the show_in_notebook() method on the explanation object to generate the visualization. The visualization shows that words like 'privacy', 'technology', 'RFID', 'identification', 'threat', etc. contribute to predicting the target label 'Sci/Tech', which makes sense as they are commonly used words in the tech field.

explanation = explainer.explain_instance(X_test_text[idx], classifier_fn=make_predictions, labels=np.asarray(Y_test[idx:idx+1]))
explanation.show_in_notebook()

3. Approach 2: Averaged GloVe Embeddings

In this section, we introduce one more way of handling GloVe word embeddings. This approach takes the average of the word embeddings of each text example before giving them to the linear layer. Most of the code is exactly the same as earlier; the only difference is the network architecture, which now averages word embeddings instead of flattening them.

3.1 Define Network

Below, we have defined a network that we'll use for our text classification task in this section. The network has the same number of layers as earlier (one embedding and two linear). The only difference is in the way word embeddings are handled in the forward pass. This time we have taken the average of embeddings at the text example level before giving them to the linear layer. The word embeddings of each text example will be averaged. The rest of the network architecture is the same as earlier.

As usual, we have initialized the network after defining it, printed shape of weights/biases of layers, and performed a forward pass to make predictions for verification purposes.

class EmbeddingClassifier(hk.Module):
    def __init__(self):
        super().__init__(name="EmbeddingClassifier")
        self.embedding = hk.Embed(vocab_size=len(tokenizer.word_index)+1, embed_dim=embed_len,
                                  embedding_matrix=word_embeddings, ## Set GloVe Embeddings as Layer Weights
                                  name="Word_Embeddings")
        self.linear1 = hk.Linear(128, name="Dense1")
        self.linear2 = hk.Linear(len(target_classes), name="Dense2")

    def __call__(self, X_batch):
        x = self.embedding(X_batch) ## (batch_size, max_tokens, embed_len) = (batch_size, 50, 300)
        x = jnp.mean(x, axis=1) ## (batch_size, embed_len) = (batch_size, 300)
        x = self.linear1(x)
        return self.linear2(x)
def EmbeddingClassifierrNet(x):
    classif = EmbeddingClassifier()
    return classif(x)

embed_classif = hk.transform(EmbeddingClassifierrNet)
rng = jax.random.PRNGKey(42)

params = embed_classif.init(rng, X_train_vect[:5])

print("Weights Type : {}\n".format(type(params)))

for layer_name, weights in params.items():
    print(layer_name)
    #print(weights.keys())
    if "Embeddings" in layer_name:
        print("Embeddings : {}\n".format(weights["embeddings"].shape))
    else:
        print("Weights : {}, Biases : {}\n".format(params[layer_name]["w"].shape,params[layer_name]["b"].shape))
Weights Type : <class 'dict'>

EmbeddingClassifier/~/Word_Embeddings
Embeddings : (72003, 300)

EmbeddingClassifier/~/Dense1
Weights : (300, 128), Biases : (128,)

EmbeddingClassifier/~/Dense2
Weights : (128, 4), Biases : (4,)

params["EmbeddingClassifier/~/Word_Embeddings"]["embeddings"][1][:10], word_embeddings[1][:10]
(DeviceArray([ 0.27204  , -0.06203  , -0.1884   ,  0.023225 , -0.018158 ,
               0.0067192, -0.13877  ,  0.17708  ,  0.17709  ,  2.5882   ],            dtype=float32),
 DeviceArray([ 0.27204  , -0.06203  , -0.1884   ,  0.023225 , -0.018158 ,
               0.0067192, -0.13877  ,  0.17708  ,  0.17709  ,  2.5882   ],            dtype=float32))
preds = embed_classif.apply(params, rng, X_train_vect[:5])

preds[:5]
DeviceArray([[-0.00721485,  0.08446569,  0.01554516,  0.03061965],
             [-0.11697381,  0.15633112, -0.1060637 ,  0.08666744],
             [ 0.01437217,  0.01379906,  0.01799658,  0.07644538],
             [ 0.02324617,  0.11502138, -0.00413199,  0.07082729],
             [ 0.05059408,  0.07714069,  0.00601102,  0.06528533]],            dtype=float32)
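
The values returned by the forward pass are raw logits, not probabilities. As a minimal sketch, the snippet below converts the logits printed above into probabilities and class labels. The class order (World, Sports, Business, Sci/Tech) is assumed from the order used in the classification reports of this tutorial.

```python
import jax
import jax.numpy as jnp

# Assumed class order (taken from the classification report in this tutorial).
target_classes = ["World", "Sports", "Business", "Sci/Tech"]

# Logits copied from the forward-pass output printed above.
logits = jnp.array([[-0.00721485,  0.08446569,  0.01554516,  0.03061965],
                    [-0.11697381,  0.15633112, -0.1060637 ,  0.08666744],
                    [ 0.01437217,  0.01379906,  0.01799658,  0.07644538],
                    [ 0.02324617,  0.11502138, -0.00413199,  0.07082729],
                    [ 0.05059408,  0.07714069,  0.00601102,  0.06528533]])

probs = jax.nn.softmax(logits, axis=-1)   # each row sums to 1
pred_idx = jnp.argmax(logits, axis=-1)    # argmax of logits equals argmax of probabilities
pred_labels = [target_classes[int(i)] for i in pred_idx]

print(probs.round(3))
print(pred_labels)
```

Since the dense layers are freshly initialized at this point, these labels carry no meaning yet. The check only confirms that the output has one score per class for each example.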

3.2 Train Network

Below, we have trained our network using exactly the same settings as in our previous approach (8 epochs, batch size of 1024, Adam optimizer with a learning rate of 1e-3). We have kept the training settings the same across all approaches so that their results can be compared fairly. The loss and accuracy values printed after each epoch show that the network is doing a good job at the given task.

from jax import value_and_grad

rng = jax.random.PRNGKey(42) ## Reproducibility ## Initializes model with same weights each time.
epochs = 8
batch_size = 1024
learning_rate = 1e-3

model = hk.transform(EmbeddingClassifierrNet)
params = model.init(rng, X_train_vect[:5])
optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(params)

final_params = TrainModelInBatches(X_train_vect, Y_train, X_test_vect, Y_test, epochs, params, optimizer_state, batch_size=batch_size)
100%|██████████| 118/118 [02:00<00:00,  1.03s/it]
CrossEntropy Loss : 0.675
Validation  Accuracy : 0.867
100%|██████████| 118/118 [01:55<00:00,  1.02it/s]
CrossEntropy Loss : 0.375
Validation  Accuracy : 0.880
100%|██████████| 118/118 [01:57<00:00,  1.01it/s]
CrossEntropy Loss : 0.350
Validation  Accuracy : 0.884
100%|██████████| 118/118 [01:56<00:00,  1.02it/s]
CrossEntropy Loss : 0.339
Validation  Accuracy : 0.885
100%|██████████| 118/118 [01:57<00:00,  1.01it/s]
CrossEntropy Loss : 0.332
Validation  Accuracy : 0.886
100%|██████████| 118/118 [02:04<00:00,  1.05s/it]
CrossEntropy Loss : 0.328
Validation  Accuracy : 0.889
100%|██████████| 118/118 [02:03<00:00,  1.05s/it]
CrossEntropy Loss : 0.325
Validation  Accuracy : 0.889
100%|██████████| 118/118 [01:59<00:00,  1.01s/it]
CrossEntropy Loss : 0.323
Validation  Accuracy : 0.889
word_embeddings[1][:10], final_params["EmbeddingClassifier/~/Word_Embeddings"]["embeddings"][1][:10], params["EmbeddingClassifier/~/Word_Embeddings"]["embeddings"][1][:10]
(DeviceArray([ 0.27204  , -0.06203  , -0.1884   ,  0.023225 , -0.018158 ,
               0.0067192, -0.13877  ,  0.17708  ,  0.17709  ,  2.5882   ],            dtype=float32),
 DeviceArray([ 0.27204  , -0.06203  , -0.1884   ,  0.023225 , -0.018158 ,
               0.0067192, -0.13877  ,  0.17708  ,  0.17709  ,  2.5882   ],            dtype=float32),
 DeviceArray([ 0.27204  , -0.06203  , -0.1884   ,  0.023225 , -0.018158 ,
               0.0067192, -0.13877  ,  0.17708  ,  0.17709  ,  2.5882   ],            dtype=float32))

3.3 Evaluate Network Performance

Here, we have evaluated network performance as usual by calculating the accuracy score, classification report, and confusion matrix on test predictions. The test accuracy of 88.9% is better than the 86.4% of our first approach. The classification report and the confusion matrix visualization highlight that the performance of the model has improved for all categories.

from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

#train_preds = model.apply(final_params, rng, X_train_vect)
test_preds = model.apply(final_params, rng, X_test_vect)

#print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, np.argmax(train_preds, axis=1))))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=target_classes))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_test, np.argmax(test_preds, axis=1)))
Test  Accuracy : 0.889

Classification Report :
              precision    recall  f1-score   support

       World       0.89      0.89      0.89      1900
      Sports       0.95      0.97      0.96      1900
    Business       0.83      0.86      0.84      1900
    Sci/Tech       0.89      0.84      0.86      1900

    accuracy                           0.89      7600
   macro avg       0.89      0.89      0.89      7600
weighted avg       0.89      0.89      0.89      7600


Confusion Matrix :
[[1690   67  103   40]
 [  28 1848   17    7]
 [  96   18 1629  157]
 [  76   22  211 1591]]
from sklearn.metrics import confusion_matrix
import scikitplot as skplt
import matplotlib.pyplot as plt

skplt.metrics.plot_confusion_matrix([target_classes[i] for i in Y_test], [target_classes[i] for i in np.argmax(test_preds, axis=1)],
                                    normalize=True,
                                    title="Confusion Matrix",
                                    cmap="Purples",
                                    hide_zeros=True,
                                    figsize=(5,5)
                                    );
plt.xticks(rotation=90);
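
As a worked check, the numbers in the classification report can be recomputed directly from the confusion matrix printed above. The sketch below only assumes that rows are actual classes and columns are predicted classes, which is the convention of scikit-learn's confusion_matrix().

```python
import numpy as np

classes = ["World", "Sports", "Business", "Sci/Tech"]

# Confusion matrix printed above (rows = actual, columns = predicted).
conf_mat = np.array([[1690,   67,  103,   40],
                     [  28, 1848,   17,    7],
                     [  96,   18, 1629,  157],
                     [  76,   22,  211, 1591]])

correct = np.diag(conf_mat)                    # correctly classified examples per class
accuracy = correct.sum() / conf_mat.sum()      # overall fraction correct
recall = correct / conf_mat.sum(axis=1)        # correct / actual examples of the class
precision = correct / conf_mat.sum(axis=0)     # correct / predicted examples of the class
f1 = 2 * precision * recall / (precision + recall)

print("Accuracy : {:.3f}".format(accuracy))
for name, p, r, f in zip(classes, precision, recall, f1):
    print("{:>10} precision={:.2f} recall={:.2f} f1={:.2f}".format(name, p, r, f))
```

The off-diagonal entries show where the errors concentrate: most mistakes are confusions between the Business and Sci/Tech categories.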


3.4 Explain Predictions using LIME Algorithm

Below, we have explained an individual prediction of the network using the LIME algorithm. Like earlier, we randomly selected a text example from the test data and made a prediction on it. Then, we interpreted the prediction using LIME. We can notice from the visualization that words like 'technology', 'privacy', 'proponents', 'identification', 'threat', 'RFID', etc. contribute to predicting the target label as 'Sci/Tech' for the selected text example.

from lime import lime_text

explainer = lime_text.LimeTextExplainer(class_names=target_classes, verbose=True)

rnd_st = np.random.RandomState(1234)
idx = rnd_st.randint(1, len(X_test_text))

print("Prediction : ", target_classes[model.apply(final_params, rng, X_test_vect[idx:idx+1]).argmax(axis=-1)[0]])
print("Actual :     ", target_classes[Y_test[idx]])

explanation = explainer.explain_instance(X_test_text[idx], classifier_fn=make_predictions, labels=Y_test[idx:idx+1].to_py())
explanation.show_in_notebook()
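
The make_predictions() function passed as classifier_fn is defined earlier in the tutorial. LIME calls this function with a list of perturbed text strings and expects a NumPy array of shape (number of texts, number of classes) holding class probabilities. The sketch below illustrates that contract only. The vocabulary, the toy_vectorize() helper, and the toy_logits() stand-in for model.apply() are hypothetical, and the padding side used here is an assumption.

```python
import numpy as np
import jax
import jax.numpy as jnp

max_tokens = 50  # tokens per text example, as used in this tutorial

# Hypothetical tiny vocabulary; index 0 is used for padding and unknown words.
toy_vocab = {"technology": 1, "privacy": 2, "game": 3, "market": 4}

def toy_vectorize(texts):
    """Map each text to a fixed-length row of token ids (padded with 0 at the end)."""
    batch = np.zeros((len(texts), max_tokens), dtype=np.int32)
    for row, text in enumerate(texts):
        ids = [toy_vocab.get(tok, 0) for tok in text.lower().split()][:max_tokens]
        batch[row, :len(ids)] = ids
    return batch

def toy_logits(token_batch):
    """Hypothetical stand-in for model.apply(final_params, rng, token_batch)."""
    table = jax.random.normal(jax.random.PRNGKey(0), (len(toy_vocab) + 1, 4))
    return table[token_batch].mean(axis=1)     # (n_texts, 4)

def predict_proba(texts):
    """List of strings in, (n_texts, n_classes) NumPy array of probabilities out."""
    logits = toy_logits(jnp.asarray(toy_vectorize(texts)))
    return np.asarray(jax.nn.softmax(logits, axis=-1))

probs = predict_proba(["privacy technology threat", "market game"])
print(probs.shape)
print(probs.sum(axis=1))
```

The important design point is the softmax at the end: LIME fits its local surrogate model on probabilities, so returning raw logits would distort the explanation.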


4. Approach 3: Summed GloVe Embeddings

In this section, we have introduced the third approach to using word embeddings. It is almost the same as our previous approach, with the only change being that we take the sum of the token embeddings instead of their average. The rest of the code is exactly the same as earlier.

4.1 Define Network

Below, we have defined the network that we'll use for our task in this section. The network has the same layers as earlier (one embedding layer and two linear layers). The only difference is in the forward pass. This time, we sum the embeddings of all tokens of an example using the jnp.sum() function before giving the result to the first linear layer. The rest of the architecture is the same.

After defining the network, we have initialized it, printed the shapes of its weights/biases, and performed a forward pass for verification purposes.
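
The difference between the two aggregations can be checked on a random tensor. The sketch below uses a random stand-in for the output of the embedding layer, so the values are illustrative only, while the shapes follow the ones used in this tutorial.

```python
import jax
import jax.numpy as jnp

batch_size, max_tokens, embed_len = 5, 50, 300

# Random stand-in for the embedding layer output: (batch_size, max_tokens, embed_len).
embedded = jax.random.normal(jax.random.PRNGKey(0), (batch_size, max_tokens, embed_len))

averaged = jnp.mean(embedded, axis=1)   # Approach 2: (batch_size, embed_len)
summed = jnp.sum(embedded, axis=1)      # Approach 3: (batch_size, embed_len)

print(averaged.shape, summed.shape)

# With a fixed number of tokens per example, the sum is the mean scaled by max_tokens.
same_up_to_scale = bool(jnp.allclose(summed, averaged * max_tokens, atol=1e-3))
print(same_up_to_scale)
```

Because every example is padded or truncated to 50 tokens, the summed vector is simply 50 times the averaged one. This is consistent with the initial forward-pass outputs printed in this section, which are 50 times the ones printed for the averaged approach, since the freshly initialized network has no activation function between its linear layers. The two approaches therefore differ mainly in the scale of the inputs that the linear layers see during training.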

class EmbeddingClassifier(hk.Module):
    def __init__(self):
        super().__init__(name="EmbeddingClassifier")
        self.embedding = hk.Embed(vocab_size=len(tokenizer.word_index)+1, embed_dim=embed_len,
                                  embedding_matrix=word_embeddings, ## Set GloVe Embeddings as Layer Weights
                                  name="Word_Embeddings")
        self.linear1 = hk.Linear(128, name="Dense1")
        self.linear2 = hk.Linear(len(target_classes), name="Dense2")

    def __call__(self, X_batch):
        x = self.embedding(X_batch) ## (batch_size, max_tokens, embed_len) = (32, 50, 300)
        x = jnp.sum(x, axis=1) ## (batch_size, embed_len) = (32, 300)
        x = self.linear1(x)
        return self.linear2(x)
def EmbeddingClassifierrNet(x):
    classif = EmbeddingClassifier()
    return classif(x)

embed_classif = hk.transform(EmbeddingClassifierrNet)
rng = jax.random.PRNGKey(42)

params = embed_classif.init(rng, X_train_vect[:5])

print("Weights Type : {}\n".format(type(params)))

for layer_name, weights in params.items():
    print(layer_name)
    #print(weights.keys())
    if "Embeddings" in layer_name:
        print("Embeddings : {}\n".format(weights["embeddings"].shape))
    else:
        print("Weights : {}, Biases : {}\n".format(params[layer_name]["w"].shape,params[layer_name]["b"].shape))
Weights Type : <class 'dict'>

EmbeddingClassifier/~/Word_Embeddings
Embeddings : (72003, 300)

EmbeddingClassifier/~/Dense1
Weights : (300, 128), Biases : (128,)

EmbeddingClassifier/~/Dense2
Weights : (128, 4), Biases : (4,)

params["EmbeddingClassifier/~/Word_Embeddings"]["embeddings"][1][:10], word_embeddings[1][:10]
(DeviceArray([ 0.27204  , -0.06203  , -0.1884   ,  0.023225 , -0.018158 ,
               0.0067192, -0.13877  ,  0.17708  ,  0.17709  ,  2.5882   ],            dtype=float32),
 DeviceArray([ 0.27204  , -0.06203  , -0.1884   ,  0.023225 , -0.018158 ,
               0.0067192, -0.13877  ,  0.17708  ,  0.17709  ,  2.5882   ],            dtype=float32))
preds = embed_classif.apply(params, rng, X_train_vect[:5])

preds[:5]
DeviceArray([[-0.36074197,  4.223285  ,  0.7772586 ,  1.5309834 ],
             [-5.848691  ,  7.816555  , -5.3031816 ,  4.333373  ],
             [ 0.71860635,  0.6899552 ,  0.8998296 ,  3.8222692 ],
             [ 1.1623046 ,  5.7510695 , -0.2066021 ,  3.5413632 ],
             [ 2.5297008 ,  3.8570359 ,  0.30055428,  3.264265  ]],            dtype=float32)

4.2 Train Network

Below, we have trained our network using the same settings that we have used for all our previous approaches. The loss and accuracy values printed after each epoch show that the network is doing a good job at the text classification task, although the validation accuracy stays slightly below that of the averaged approach.

from jax import value_and_grad

rng = jax.random.PRNGKey(42) ## Reproducibility ## Initializes model with same weights each time.
epochs = 8
batch_size = 1024
learning_rate = 1e-3

model = hk.transform(EmbeddingClassifierrNet)
params = model.init(rng, X_train_vect[:5])
optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(params)

final_params = TrainModelInBatches(X_train_vect, Y_train, X_test_vect, Y_test, epochs, params, optimizer_state, batch_size=batch_size)
100%|██████████| 118/118 [01:56<00:00,  1.01it/s]
CrossEntropy Loss : 0.598
Validation  Accuracy : 0.871
100%|██████████| 118/118 [01:55<00:00,  1.02it/s]
CrossEntropy Loss : 0.366
Validation  Accuracy : 0.879
100%|██████████| 118/118 [01:56<00:00,  1.01it/s]
CrossEntropy Loss : 0.352
Validation  Accuracy : 0.879
100%|██████████| 118/118 [02:15<00:00,  1.15s/it]
CrossEntropy Loss : 0.347
Validation  Accuracy : 0.880
100%|██████████| 118/118 [02:00<00:00,  1.02s/it]
CrossEntropy Loss : 0.344
Validation  Accuracy : 0.879
100%|██████████| 118/118 [02:00<00:00,  1.02s/it]
CrossEntropy Loss : 0.343
Validation  Accuracy : 0.877
100%|██████████| 118/118 [02:01<00:00,  1.03s/it]
CrossEntropy Loss : 0.341
Validation  Accuracy : 0.880
100%|██████████| 118/118 [02:00<00:00,  1.03s/it]
CrossEntropy Loss : 0.337
Validation  Accuracy : 0.881
word_embeddings[1][:10], final_params["EmbeddingClassifier/~/Word_Embeddings"]["embeddings"][1][:10], params["EmbeddingClassifier/~/Word_Embeddings"]["embeddings"][1][:10]
(DeviceArray([ 0.27204  , -0.06203  , -0.1884   ,  0.023225 , -0.018158 ,
               0.0067192, -0.13877  ,  0.17708  ,  0.17709  ,  2.5882   ],            dtype=float32),
 DeviceArray([ 0.27204  , -0.06203  , -0.1884   ,  0.023225 , -0.018158 ,
               0.0067192, -0.13877  ,  0.17708  ,  0.17709  ,  2.5882   ],            dtype=float32),
 DeviceArray([ 0.27204  , -0.06203  , -0.1884   ,  0.023225 , -0.018158 ,
               0.0067192, -0.13877  ,  0.17708  ,  0.17709  ,  2.5882   ],            dtype=float32))
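
The progress bars above show 118 iterations per epoch. As a small arithmetic sketch, this is the number of batches obtained when 120,000 training examples are split into batches of 1024. The figure of 120,000 is an assumption (it is the size of the AG News training split, which matches the classes and test size seen here), and batch_slices() is a hypothetical helper, not the tutorial's training function.

```python
import math

n_examples = 120_000   # assumed number of training examples
batch_size = 1024      # batch size used in this tutorial

n_batches = math.ceil(n_examples / batch_size)
print(n_batches)  # 118

def batch_slices(n, size):
    """Yield (start, end) index pairs that cover range(n) in chunks of `size`."""
    for start in range(0, n, size):
        yield start, min(start + size, n)

slices = list(batch_slices(n_examples, batch_size))
last_start, last_end = slices[-1]
print(len(slices), last_end - last_start)  # the last batch is smaller than the others
```

Under this assumption, the final batch of each epoch holds only 192 examples, which is worth keeping in mind when averaging per-batch losses.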

4.3 Evaluate Network Performance

Below, we have evaluated the performance of our network by calculating the same ML metrics that we calculated for our previous approaches. The test accuracy of 88.1% is better than that of the first approach (86.4%) but slightly lower than that of the previous approach (88.9%). Next, we'll explain an individual prediction using LIME.

from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

#train_preds = model.apply(final_params, rng, X_train_vect)
test_preds = model.apply(final_params, rng, X_test_vect)

#print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, np.argmax(train_preds, axis=1))))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=target_classes))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_test, np.argmax(test_preds, axis=1)))
Test  Accuracy : 0.881

Classification Report :
              precision    recall  f1-score   support

       World       0.90      0.88      0.89      1900
      Sports       0.92      0.98      0.95      1900
    Business       0.80      0.87      0.84      1900
    Sci/Tech       0.90      0.79      0.84      1900

    accuracy                           0.88      7600
   macro avg       0.88      0.88      0.88      7600
weighted avg       0.88      0.88      0.88      7600


Confusion Matrix :
[[1677   81  104   38]
 [  17 1870    8    5]
 [  88   36 1657  119]
 [  79   35  291 1495]]
from sklearn.metrics import confusion_matrix
import scikitplot as skplt
import matplotlib.pyplot as plt

skplt.metrics.plot_confusion_matrix([target_classes[i] for i in Y_test], [target_classes[i] for i in np.argmax(test_preds, axis=1)],
                                    normalize=True,
                                    title="Confusion Matrix",
                                    cmap="Purples",
                                    hide_zeros=True,
                                    figsize=(5,5)
                                    );
plt.xticks(rotation=90);
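
To see where the summed approach loses accuracy relative to the averaged one, the two confusion matrices printed in this tutorial can be compared class by class. The sketch below row-normalizes each matrix so that the diagonal holds the per-class recall. The row_normalize() helper is introduced here for illustration.

```python
import numpy as np

classes = ["World", "Sports", "Business", "Sci/Tech"]

# Confusion matrices printed in this tutorial (rows = actual, columns = predicted).
cm_avg = np.array([[1690, 67, 103, 40], [28, 1848, 17, 7],
                   [96, 18, 1629, 157], [76, 22, 211, 1591]])
cm_sum = np.array([[1677, 81, 104, 38], [17, 1870, 8, 5],
                   [88, 36, 1657, 119], [79, 35, 291, 1495]])

def row_normalize(cm):
    """Divide each row by its total so that entries are fractions of the actual class."""
    return cm / cm.sum(axis=1, keepdims=True)

recall_avg = np.diag(row_normalize(cm_avg))
recall_sum = np.diag(row_normalize(cm_sum))

for name, a, s in zip(classes, recall_avg, recall_sum):
    print("{:>10} averaged={:.3f} summed={:.3f} change={:+.3f}".format(name, a, s, s - a))
```

The comparison shows that the drop comes almost entirely from the Sci/Tech category, where more examples are now predicted as Business, while Sports and Business recall improve slightly.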


4.4 Explain Predictions using LIME Algorithm

Below, we have explained an individual prediction of the network using the LIME algorithm. We randomly selected a text example from the test dataset and made a prediction on it using our trained network. Then, we interpreted the prediction using LIME. The network correctly predicts the target label as 'Sci/Tech'. The visualization highlights that words like 'technology', 'privacy', 'proponents', 'threat', 'identification', 'advocates', etc. contribute to predicting the target label 'Sci/Tech'.

from lime import lime_text

explainer = lime_text.LimeTextExplainer(class_names=target_classes, verbose=True)

rnd_st = np.random.RandomState(1234)
idx = rnd_st.randint(1, len(X_test_text))

print("Prediction : ", target_classes[model.apply(final_params, rng, X_test_vect[idx:idx+1]).argmax(axis=-1)[0]])
print("Actual :     ", target_classes[Y_test[idx]])

explanation = explainer.explain_instance(X_test_text[idx], classifier_fn=make_predictions, labels=Y_test[idx:idx+1].to_py())
explanation.show_in_notebook()


5. Results Summary

Approach                                  GloVe Embeddings   Test Accuracy
Approach 1: Flattened Glove Embeddings    840B.300d          86.4 %
Approach 2: Averaged Glove Embeddings     840B.300d          88.9 %
Approach 3: Summed Glove Embeddings       840B.300d          88.1 %

6. Further Suggestions

  • Train Network for more epochs.
  • Try other GloVe embeddings such as 6B, 42B, or Twitter 27B. We have used 840B in this tutorial.
  • Try GloVe embeddings of different lengths. We have used word embeddings of length 300 in this tutorial.
  • Try other aggregating operations (min, max, std, etc) on embeddings. We tried average and summation.
  • Try different activation functions (relu, tanh, etc).
  • Try different weight initialization methods.
  • Try different learning rate schedulers.
  • Try regularization (dropout, batch normalization, etc).
  • Try different token lengths. We have used 50 tokens per text example.
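
As a starting point for two of the suggestions above (other aggregating operations and activation functions), here is a minimal sketch. It works on a random stand-in for the embedding layer output, and the weight matrix w and bias b are hypothetical placeholders for a linear layer, not parameters from the tutorial's networks.

```python
import jax
import jax.numpy as jnp

k_emb, k_w = jax.random.split(jax.random.PRNGKey(0))

# Random stand-in for the embedding layer output: (batch_size, max_tokens, embed_len).
embedded = jax.random.normal(k_emb, (5, 50, 300))

# Different ways to aggregate over the token axis; each gives (batch_size, embed_len).
pooled = {
    "mean": jnp.mean(embedded, axis=1),
    "sum": jnp.sum(embedded, axis=1),
    "max": jnp.max(embedded, axis=1),
    "min": jnp.min(embedded, axis=1),
    "std": jnp.std(embedded, axis=1),
}
for name, value in pooled.items():
    print(name, value.shape)

# Aggregations can also be concatenated to give the linear layer more information.
combined = jnp.concatenate([pooled["mean"], pooled["max"]], axis=-1)   # (5, 600)

# Hypothetical linear layer followed by a ReLU activation.
w = jax.random.normal(k_w, (600, 128)) * 0.01
b = jnp.zeros(128)
hidden = jax.nn.relu(combined @ w + b)                                  # (5, 128)
print(combined.shape, hidden.shape)
```

Inside the Haiku modules of this tutorial, the same idea would amount to swapping the aggregation line in the forward pass and wrapping the output of the first linear layer in an activation such as jax.nn.relu() or jnp.tanh().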
Sunny Solanki
