Updated On : May-15,2022 Time Investment : ~30 mins

Flax (JAX): GloVe Embeddings for Text Classification Tasks

Word embeddings are one of the most popular text encoding approaches for solving NLP tasks with deep learning algorithms. Other encoding approaches, like word frequency, one-hot, Tf-Idf, etc., assign just a single scalar value to each word, which has very low representation power. Such values cannot capture the different meanings of a word in different contexts or its relationships with other words. Word embeddings, on the other hand, use a real-valued vector to represent each word. The length of this vector varies a lot: it can be 10, 15, 20, 50, 100, 200, 300, etc. values. This gives the approach more representation power, which deep learning algorithms can utilize to produce better results. We can generate embeddings for the words of our dataset by training them in the network, or, if our dataset is small, we can use embeddings already trained by some other network.

GloVe (Global Vectors) is an unsupervised algorithm that can generate embeddings for words. Researchers at Stanford have already generated GloVe word embeddings for many words by training the algorithm on large corpora. They have released embeddings of different lengths trained on datasets such as Wikipedia (with Gigaword), Common Crawl, and Twitter. GloVe embeddings are a very good option if your dataset is too small to train word embeddings on it. Please feel free to check the below link if you are looking for detailed information on GloVe.

As a part of this tutorial, we explain how we can use GloVe word embeddings with Flax networks for text classification tasks. Flax is a Python deep learning library built on top of JAX, a lower-level Python library for high-performance numerical computing. We have tried different approaches to using GloVe embeddings in the tutorial and compared their results. After training the networks, we have also evaluated their performance by calculating various ML metrics. We have even explained predictions made by the network using the LIME (Local Interpretable Model-Agnostic Explanations) algorithm.

Please check the below link if you are looking to train your own word embeddings using Flax (JAX) Networks.

Below, we have listed the important sections of the tutorial to give an overview of the material covered.

Important Sections Of Tutorial

  1. Prepare Data
    • 1.1 Download And Unzip GloVe Embeddings (840B.300d)
    • 1.2 Load GloVe Word Embeddings in Memory
    • 1.3 Load Dataset
    • 1.4 Tokenize Examples, Populate Vocabulary And Vectorize Text Examples
    • 1.5 Create Matrix Of GloVe Embeddings for Vocabulary Tokens
  2. Approach 1: GloVe Embeddings Flattened (Max Tokens=50, Embedding Length=300)
    • Define Network
    • Define Loss Function
    • Train Network
    • Evaluate Network Performance
    • Explain Network Predictions using LIME Algorithm
  3. Approach 2: GloVe Embeddings Averaged (Max Tokens=50, Embedding Length=300)
  4. Approach 3: GloVe Embeddings Summed (Max Tokens=50, Embedding Length=300)
  5. Results Summary And Further Suggestions

Below, we have imported the necessary Python libraries that we have used in our tutorial and printed their versions for reference purposes.

import jax

print("JAX Version : {}".format(jax.__version__))
JAX Version : 0.3.10
import flax

print("Flax Version : {}".format(flax.__version__))
Flax Version : 0.4.2
import optax

print("OPTAX Version : {}".format(optax.__version__))
OPTAX Version : 0.1.2
import torchtext

print("Torchtext Version : {}".format(torchtext.__version__))
Torchtext Version : 0.12.0
from tensorflow import keras

print("Keras Version : {}".format(keras.__version__))
Keras Version : 2.6.0
import gc

1. Prepare Data

In this section, we'll prepare data for the neural network. As we are going to use the word embeddings approach for encoding text data, and we have decided to use GloVe word embeddings for this purpose, we'll follow the below steps to encode the text data.

  1. Loop through all text examples, tokenize them and build a vocabulary of unique tokens (words). A vocabulary is a simple mapping from tokens to their integer indexes. Each token (word) is assigned a unique index; with Keras' Tokenizer, which we use below, indexes start from 1 because index 0 is reserved for padding.
  2. Loop through each text example, tokenize it and retrieve the indexes of its tokens using the vocabulary created in the previous step. After this step, we'll have a list of token indexes for each text example. The output of this step will be given to the neural network for training.
  3. Next, loop through the tokens of our vocabulary and retrieve their GloVe embeddings. Stack the GloVe embeddings of the vocabulary tokens in a matrix, which we'll refer to as the embedding matrix.
  4. Now, retrieve GloVe embeddings for each token index created in the 2nd step by indexing the embeddings matrix created in the 3rd step.
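The four steps above can be sketched end to end on a toy corpus. This is a minimal illustration under stated assumptions: a plain whitespace tokenizer stands in for Keras' Tokenizer, random vectors stand in for GloVe, and the frequency-ordered, 1-based vocabulary mirrors what we understand Keras' Tokenizer to do (index 0 stays free for padding).

```python
import numpy as np
from collections import Counter


def tokenize(text):
    """Toy tokenizer (assumption): lowercase and split on whitespace only."""
    return text.lower().split()


texts = ["the cat sat", "the dog sat down"]

# Step 1: vocabulary; most frequent tokens get the smallest indexes, index 0 is kept for padding
counts = Counter(tok for text in texts for tok in tokenize(text))
vocab = {tok: i + 1 for i, (tok, _) in enumerate(counts.most_common())}

# Step 2: map every example to a list of token indexes
sequences = [[vocab[tok] for tok in tokenize(text)] for text in texts]

# Step 3: one embedding row per vocabulary index (random stand-ins for GloVe vectors)
embed_len = 4
rng = np.random.default_rng(0)
fake_glove = {tok: rng.normal(size=embed_len).astype(np.float32) for tok in vocab}
embedding_matrix = np.zeros((len(vocab) + 1, embed_len), dtype=np.float32)  # row 0 = padding
for tok, idx in vocab.items():
    embedding_matrix[idx] = fake_glove[tok]

# Step 4: retrieve embeddings by indexing the matrix with the token indexes
first_example = embedding_matrix[np.array(sequences[0])]
print(sequences)
print(first_example.shape)
```

In the network, step 4 is exactly what the embedding layer does when its weight matrix is set to the embedding matrix.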

So basically, we first convert our text examples to a list of token indexes and then retrieve GloVe embeddings for them by indexing the embedding matrix.

The first 3 steps mentioned above will be performed in this section. The 4th step will be implemented at the embedding layer in the neural network which will map token indexes to their embeddings. We'll set the embedding matrix as the weight matrix of the embedding layer.

Below, we have included an image giving an idea about word embeddings.

[Figure: illustration of word embeddings]

1.1 Download And Unzip GloVe Embeddings (840B.300d)

In this section, we have simply downloaded GloVe embeddings from the Stanford website and unzipped them. We have downloaded GloVe 840B.300d embeddings that have word embeddings of length 300 for 2.2 Million tokens.

!wget https://nlp.stanford.edu/data/glove.840B.300d.zip
--2022-05-16 06:52:49--  https://nlp.stanford.edu/data/glove.840B.300d.zip
Resolving nlp.stanford.edu (nlp.stanford.edu)... 171.64.67.140
Connecting to nlp.stanford.edu (nlp.stanford.edu)|171.64.67.140|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: http://downloads.cs.stanford.edu/nlp/data/glove.840B.300d.zip [following]
--2022-05-16 06:52:49--  http://downloads.cs.stanford.edu/nlp/data/glove.840B.300d.zip
Resolving downloads.cs.stanford.edu (downloads.cs.stanford.edu)... 171.64.64.22
Connecting to downloads.cs.stanford.edu (downloads.cs.stanford.edu)|171.64.64.22|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2176768927 (2.0G) [application/zip]
Saving to: ‘glove.840B.300d.zip’

glove.840B.300d.zip 100%[===================>]   2.03G  5.05MB/s    in 6m 50s

2022-05-16 06:59:39 (5.06 MB/s) - ‘glove.840B.300d.zip’ saved [2176768927/2176768927]

!unzip glove.840B.300d.zip
Archive:  glove.840B.300d.zip
  inflating: glove.840B.300d.txt

1.2 Load GloVe Word Embeddings in Memory

In this section, we have loaded the word embeddings into memory. We have created a Python dictionary whose keys are tokens and whose values are NumPy arrays holding the embeddings of those tokens.

%%time

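import numpy as np

glove_embeddings = {}
with open("glove.840B.300d.txt", encoding="utf-8") as f: ## The file contains non-ASCII tokens
    for line in f:
        try:
            line = line.split()
            glove_embeddings[line[0]] = np.array(line[1:], dtype=np.float32)
        except ValueError:
            ## Skip lines whose values cannot be converted to floats
            continue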
CPU times: user 3min 4s, sys: 6.42 s, total: 3min 11s
Wall time: 3min 7s
embeddings = glove_embeddings["the"]

embeddings.shape, embeddings.dtype
((300,), dtype('float32'))
gc.collect()
126
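The loading loop above skips any line whose values cannot be converted to floats. Some lines of glove.840B.300d.txt are reportedly tokens that themselves contain spaces, which a plain `line.split()` mis-parses. Below is a hedged sketch of a parser that instead treats the last `dim` fields as the vector; `parse_glove_line` is a hypothetical helper, and whether keeping those extra tokens matters for this dataset is untested here.

```python
import numpy as np


def parse_glove_line(line, dim=300):
    """Split one GloVe text line into (token, vector).

    Assumes the last `dim` space-separated fields are the vector values, so a
    token containing spaces is kept intact instead of being skipped.
    """
    parts = line.rstrip().split(" ")
    token = " ".join(parts[:-dim])
    vector = np.asarray(parts[-dim:], dtype=np.float32)
    return token, vector


# Usage on synthetic lines (3-dimensional vectors keep the example short)
print(parse_glove_line("the 0.1 -0.2 0.3\n", dim=3))
print(parse_glove_line(". . . 1.0 2.0 3.0\n", dim=3))
```

Inside the loading loop, `token, vector = parse_glove_line(line)` would replace the `split()`-based lines.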

1.3 Load Dataset

In this section, we have loaded the dataset that we are going to use for our text classification task. We'll be using the 20 Newsgroups dataset available from scikit-learn. The dataset has text documents for 20 different news categories, though we'll use only 5 of them. The dataset is already divided into train and test sets for our convenience.

import numpy as np
from sklearn import datasets
import gc

all_categories = ['alt.atheism','comp.graphics','comp.os.ms-windows.misc','comp.sys.ibm.pc.hardware',
                  'comp.sys.mac.hardware','comp.windows.x', 'misc.forsale','rec.autos','rec.motorcycles',
                  'rec.sport.baseball','rec.sport.hockey','sci.crypt','sci.electronics','sci.med',
                  'sci.space','soc.religion.christian','talk.politics.guns','talk.politics.mideast',
                  'talk.politics.misc','talk.religion.misc']

target_classes = ['alt.atheism', 'comp.graphics','rec.autos','rec.sport.hockey','talk.politics.mideast']

X_train_text, Y_train = datasets.fetch_20newsgroups(subset="train", categories=target_classes, return_X_y=True)
X_test_text , Y_test  = datasets.fetch_20newsgroups(subset="test", categories=target_classes, return_X_y=True)

classes = np.unique(Y_train)
mapping = dict(zip(classes, target_classes))

len(X_train_text), len(X_test_text), classes, mapping
(2822,
 1879,
 array([0, 1, 2, 3, 4]),
 {0: 'alt.atheism',
  1: 'comp.graphics',
  2: 'rec.autos',
  3: 'rec.sport.hockey',
  4: 'talk.politics.mideast'})

1.4 Tokenize Examples, Populate Vocabulary, And Vectorize Text Examples

In this section, we have completed the first two steps of our embedding process that we had explained earlier. We have populated vocabulary and transformed text examples into a list of token indexes using this vocabulary here.

First, we have created an instance of Tokenizer() available from the Python deep learning library Keras. We have called its fit_on_texts() method with the train and test examples. This populates the vocabulary of all unique tokens and keeps it in the tokenizer object. The vocabulary can be accessed through the word_index property of the tokenizer object. Later, we have also printed the size of the vocabulary to show the number of tokens.

Next, we have called the texts_to_sequences() method with the train and test examples separately. This tokenizes each text example into a list of tokens (words) and retrieves their indexes from the vocabulary, so its output is a list of token indexes per example. Different text examples can have different numbers of tokens, but the neural network needs inputs of a fixed size. For this, we have decided to keep at most 50 tokens per text example. To enforce this, we have called the pad_sequences() method on the output of texts_to_sequences(). It ensures that every example has exactly 50 tokens: examples with more than 50 tokens are truncated to their first 50 tokens (truncating="post"), and examples with fewer than 50 tokens are padded with 0s at the end (padding="post").
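The post-truncation and post-padding behaviour described above can be reproduced in a few lines of NumPy. `pad_post` is a hypothetical helper written for illustration only; it is not the Keras implementation, which also supports "pre" modes and other dtypes.

```python
import numpy as np


def pad_post(sequences, maxlen, value=0):
    """Illustrative equivalent of padding="post", truncating="post"."""
    out = np.full((len(sequences), maxlen), value, dtype=np.int32)
    for i, seq in enumerate(sequences):
        kept = seq[:maxlen]          # truncating="post": keep the first maxlen tokens
        out[i, :len(kept)] = kept    # padding="post": remaining slots keep `value`
    return out


print(pad_post([[5, 9, 2, 7], [3], []], maxlen=3))
```

The first sequence loses its last token, the second gets two trailing zeros, and the empty one becomes all padding.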

We have also printed the first few train examples to show how text gets encoded to a list of token indexes.

Below, we have explained with a simple example how a text example is vectorized.

text = "Hello, How are you? Where are you planning to go?"

tokens = ['hello', ',', 'how', 'are', 'you', '?', 'where',
            'are', 'you', 'planning', 'to', 'go', '?']

vocab = {
    'hello': 0,
    'bye': 1,
    'how': 2,
    'the': 3,
    'welcome': 4,
    'are': 5,
    'you': 6,
    'to': 7,
    '<unk>': 8,
}

vector = [0,8,2,5,6,8,8,5,6,8,7,8,8]
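The toy example above can be made runnable: each token is looked up in the vocabulary, and tokens that are missing fall back to the index of '<unk>'. Note that this mirrors the illustration only. Keras' Tokenizer, as we understand its defaults, strips punctuation and simply drops unknown words unless an `oov_token` is configured.

```python
tokens = ['hello', ',', 'how', 'are', 'you', '?', 'where',
          'are', 'you', 'planning', 'to', 'go', '?']

vocab = {'hello': 0, 'bye': 1, 'how': 2, 'the': 3, 'welcome': 4,
         'are': 5, 'you': 6, 'to': 7, '<unk>': 8}

# Tokens missing from the vocabulary are mapped to the '<unk>' index
vector = [vocab.get(tok, vocab['<unk>']) for tok in tokens]
print(vector)
```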
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from jax import numpy as jnp

max_tokens = 50 ## Hyperparameter

tokenizer = Tokenizer()
tokenizer.fit_on_texts(X_train_text+X_test_text)

## Vectorizing data to keep 50 words per sample.
X_train_vect = pad_sequences(tokenizer.texts_to_sequences(X_train_text), maxlen=max_tokens, padding="post", truncating="post", value=0.)
X_test_vect  = pad_sequences(tokenizer.texts_to_sequences(X_test_text), maxlen=max_tokens, padding="post", truncating="post", value=0.)

print(X_train_vect[:3])

X_train_vect, X_test_vect = jnp.array(X_train_vect, dtype=jnp.int32), jnp.array(X_test_vect, dtype=jnp.int32)
Y_train, Y_test = jnp.array(Y_train), jnp.array(Y_test)

X_train_vect.shape, X_test_vect.shape
[[   13 13349 26798 26799  5691  6675   743 13349    32    51  4261    47
   2877     6   155     3 11447 21894  6236   334  5844   584     2    93
    636   453   253   253   252   923    36  3172     4   185   337 26800
     67 26801 26802 16357 13350    41  2353 36106 26803  3488    41 36107
  21895  6236]
 [   13  7228 16358  4642  2117  6676  7228    32    51 14634 10675    15
   8590  2758    88    76    85 18672    36  4642    67     2 10067  1676
     33   198     7    19    16   402    28     7   642    46    16  3604
     10     3    62    40  3416     1  1801     2   209   375     4   184
   1706    28]
 [   32    51 12292     4     1   270   117    13  4017 21896 21897   504
   2232  4017 21898    36 26806   108  4262   238  1960   250    33   196
      6    56 11448 36113 36114  6238  6457  1173  5692  6238  6457  1173
     48   563  1412     7    70     5  3417 36115  6935     4   105   182
      7   132]]
((2822, 50), (1879, 50))
print("Vocab Size : {}".format(len(tokenizer.word_index)))
Vocab Size : 54983
## Which word has index 13?

print(tokenizer.index_word[13])

## How many times does it appear in the first text document?

print(X_train_text[0]) ## Once, in the "From:" header
from
From: fischer@iesd.auc.dk (Lars Peter Fischer)
Subject: Re: Rumours about 3DO ???
In-Reply-To: archer@elysium.esd.sgi.com's message of 6 Apr 93 18:18:30 GMT
Organization: Mathematics and Computer Science, Aalborg University
	<C51Eyz.4Ix@optimla.aimla.com> <1993Apr6.144520.2190@unocal.com>
	<h48vtis@zola.esd.sgi.com>
Lines: 11


>>>>> "Archer" == Archer (Bad Cop) Surly (archer@elysium.esd.sgi.com)

Archer> How about "Interactive Sex with Madonna"?

or "Sexium" for short.

/Lars
--
Lars Fischer, fischer@iesd.auc.dk | It takes an uncommon mind to think of
CS Dept., Aalborg Univ., DENMARK. | these things.  -- Calvin

1.5 Create Matrix Of GloVe Embeddings for Vocabulary Tokens

In this section, we have created the embedding matrix of GloVe embeddings described in the 3rd step of our encoding process.

Here, we loop through the tokens of our vocabulary and retrieve the GloVe embedding of each token from the GloVe embedding dictionary that we loaded earlier. Tokens that have no GloVe embedding get a vector of zeros, and so does index 0, which is reserved for padding. We have stacked the embeddings of all tokens to create the embedding matrix. Since the embedding length of GloVe 840B.300d is 300, our embedding matrix has shape (vocab_size + 1, 300), i.e. (54984, 300) here.

We can retrieve the embedding of a token by indexing this matrix with the token's index. For example, if the index of the token 'how' in our vocabulary is 10, then word_embeddings[10] is the embedding of 'how'.

Later on, we'll set this matrix as the weight matrix of the embedding layer of the network which will take token indexes as input and index this matrix to retrieve embeddings for tokens.
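Here is a small sketch of the same construction on toy data, which also counts how many vocabulary tokens were found in the pretrained dictionary. `build_embedding_matrix`, `index_word`, and `pretrained` are hypothetical stand-ins for the tokenizer's index_word and the GloVe dictionary; the coverage count is an extra check that the tutorial itself does not perform.

```python
import numpy as np


def build_embedding_matrix(index_word, pretrained, embed_len):
    """Stack pretrained vectors so that row i holds the vector of token index i.

    Row 0 (padding) and tokens missing from `pretrained` stay all zeros.
    Returns the matrix and the number of tokens that were found.
    """
    matrix = np.zeros((len(index_word) + 1, embed_len), dtype=np.float32)
    found = 0
    for idx, word in index_word.items():
        vec = pretrained.get(word)
        if vec is not None:
            matrix[idx] = vec
            found += 1
    return matrix, found


# Toy stand-ins for tokenizer.index_word and the GloVe dictionary
index_word = {1: "the", 2: "how", 3: "qwertyzz"}
pretrained = {"the": np.ones(3, np.float32), "how": np.full(3, 2.0, np.float32)}

matrix, found = build_embedding_matrix(index_word, pretrained, embed_len=3)
print(matrix[2])                       # embedding of 'how', retrieved by its index
print(found, "of", len(index_word))    # vocabulary coverage
```

Printing the same coverage figure for the real vocabulary would show how many tokens ended up with zero vectors.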

%%time

embed_len = 300

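## One row per token index; row 0 is reserved for padding, hence the +1
word_embeddings = np.zeros((len(tokenizer.index_word)+1, embed_len))

for idx, word in tokenizer.index_word.items():
    ## Tokens missing from GloVe keep a zero vector
    word_embeddings[idx] = glove_embeddings.get(word, np.zeros(embed_len))

word_embeddings = jnp.array(word_embeddings) ## JAX converts to float32 by default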
CPU times: user 312 ms, sys: 219 ms, total: 531 ms
Wall time: 534 ms
word_embeddings[1][:10]
DeviceArray([ 0.27204  , -0.06203  , -0.1884   ,  0.023225 , -0.018158 ,
              0.0067192, -0.13877  ,  0.17708  ,  0.17709  ,  2.5882   ],            dtype=float32)

Approach 1: GloVe Embeddings Flattened (Max Tokens=50, Embedding Length=300)

Our approach in this section is based on the concept of flattening embeddings. The flattened embeddings are then given to dense layers for processing. After training the network in this section, we have also evaluated various ML metrics for checking the performance of the network.

Define Network

Below, we have defined the network that we are going to use for our text classification task. The network consists of one embedding layer and two dense layers.

The first layer of our network is the embedding layer. We have created it using the Embed() constructor from the linen module of Flax. We have set num_embeddings to the vocabulary size plus one (index 0 is reserved for padding) and features to the embedding length of 300. We have initialized the weights of this layer by passing a callable to the embedding_init parameter; the callable simply returns the embedding matrix that we created in the previous section. This layer maps token indexes to their embeddings by indexing this weight matrix. The shape of the input data to this layer is (batch_size, max_tokens) = (batch_size, 50) and the output shape is (batch_size, max_tokens, embed_len) = (batch_size, 50, 300).

Please make a NOTE that as we are already using trained embeddings (GloVe embeddings), we need to prevent updates to the embeddings during the training process. We have done that later with a minor modification to the training routine that discards updates to the embedding layer.

The output of the embedding layer is flattened so that its shape gets transformed from (batch_size, 50, 300) to (batch_size, 50 x 300) = (batch_size, 15000).

This flattened output is then given to our first dense layer, which has 100 output units. This transforms the shape from (batch_size, 15000) to (batch_size, 100). We then apply relu activation to its output.

The relu-activated output of the first dense layer is given to the second dense layer, which has 5 output units (the same as the number of target classes). This transforms the shape from (batch_size, 100) to (batch_size, 5). The output of the second dense layer is returned as the prediction of our network.
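To make the shape bookkeeping concrete, the forward pass described above can be traced in NumPy. The weights here are random stand-ins (an assumption: only the shapes matter, not the values), and a smaller vocabulary is used to keep memory low.

```python
import numpy as np

batch_size, max_tokens, vocab_size, embed_len, hidden, n_classes = 4, 50, 1000, 300, 100, 5
rng = np.random.default_rng(0)

# Random stand-ins for the network parameters (shapes only)
embedding = rng.normal(size=(vocab_size, embed_len)).astype(np.float32)
W1 = rng.normal(size=(max_tokens * embed_len, hidden)).astype(np.float32) * 0.01
b1 = np.zeros(hidden, np.float32)
W2 = rng.normal(size=(hidden, n_classes)).astype(np.float32) * 0.01
b2 = np.zeros(n_classes, np.float32)

X_batch = rng.integers(0, vocab_size, size=(batch_size, max_tokens))  # token indexes

x = embedding[X_batch]             # (4, 50, 300): embedding lookup
x = x.reshape(batch_size, -1)      # (4, 15000): flatten
x = np.maximum(x @ W1 + b1, 0.0)   # (4, 100): dense layer + ReLU
logits = x @ W2 + b2               # (4, 5): one score per class
print(logits.shape)
```

The flatten step is what makes the first dense layer large: its kernel alone has 15000 x 100 = 1,500,000 weights, matching the (15000, 100) shape printed after initialization.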

After defining the network, we have initialized it and printed the shapes of the weights and biases of its layers. We have also printed a few weights of the embedding layer to verify that it was initialized properly. Finally, we have performed a forward pass through the network with a few train examples for verification purposes.

Please make a NOTE that we have not described the process of network creation in detail as we expect that the reader has some background on the Flax library. Please feel free to go through the below link if you are new to the library, as it'll help you get started with neural network creation.

from flax import linen

embed_len = 300

class EmbeddingClassifier(linen.Module):
    def setup(self):
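        ## embedding_init is called with a PRNG key and shape info; we ignore them and return the GloVe matrix
        self.embedding = linen.Embed(num_embeddings = len(tokenizer.word_index)+1, features=embed_len,
                                     name="Word Embeddings", embedding_init=lambda *args: word_embeddings)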
        self.linear1 = linen.Dense(100, name="Dense1")
        self.linear2 = linen.Dense(len(target_classes), name="Dense2")

    def __call__(self, X_batch):
        x = self.embedding(X_batch)
        x = x.reshape(len(X_batch), -1) ## Flatten Embeddings

        x = self.linear1(x)
        x = linen.relu(x)

        logits = self.linear2(x)
        return logits
from jax import numpy as jnp

seed = jax.random.PRNGKey(0)

embed_classif = EmbeddingClassifier()

params = embed_classif.init(seed, jax.random.randint(seed, (100, max_tokens), minval=1, maxval=20))

for layer_params in params["params"].items():
    print("Layer Name : {}".format(layer_params[0]))
    if "Embedding" in layer_params[0]:
        weights = layer_params[1]["embedding"]
        print("\tLayer Weights : {}".format(weights.shape))
    else:
        weights, biases = layer_params[1]["kernel"], layer_params[1]["bias"]
        print("\tLayer Weights : {}, Biases : {}".format(weights.shape, biases.shape))
Layer Name : Word Embeddings
	Layer Weights : (54984, 300)
Layer Name : Dense1
	Layer Weights : (15000, 100), Biases : (100,)
Layer Name : Dense2
	Layer Weights : (100, 5), Biases : (5,)
params["params"]["Word Embeddings"]["embedding"][1][:10]
DeviceArray([ 0.27204  , -0.06203  , -0.1884   ,  0.023225 , -0.018158 ,
              0.0067192, -0.13877  ,  0.17708  ,  0.17709  ,  2.5882   ],            dtype=float32)
preds = embed_classif.apply(params, X_train_vect[:5])

preds
DeviceArray([[ 0.16070068, -0.18725368, -0.10340895,  0.13814029,
              -0.3553394 ],
             [-0.03378607,  0.00632447, -0.15137541,  0.2485371 ,
              -0.21806909],
             [ 0.11125846, -0.24498412, -0.09532435,  0.3489216 ,
               0.17814054],
             [ 0.23810926, -0.3992119 ,  0.18238424, -0.04051574,
              -0.30003887],
             [ 0.03021425, -0.4305239 , -0.01692869,  0.084884  ,
              -0.09168024]], dtype=float32)

Define Loss Function

In this section, we have created the loss function for our classification task: cross entropy loss. The function takes network parameters, data, and target labels as input. It performs a forward pass through the network to make predictions and one-hot encodes the target labels. Then, it calculates the loss using the softmax_cross_entropy() function available from Optax. The per-example losses are summed (not averaged) over the batch. Note that the function uses the global model object, which we create just before training.

def CrossEntropyLoss(params, input_data, actual):
    logits_preds = model.apply(params, input_data)
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(target_classes))
    return optax.softmax_cross_entropy(logits=logits_preds, labels=one_hot_actual).sum()
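Per example, softmax cross entropy with one-hot labels is -sum_c y_c * log softmax(z)_c, which reduces to minus the log-probability of the true class. The NumPy sketch below (a hypothetical re-implementation, not Optax's code) applies the same batch sum as CrossEntropyLoss; this summation is why the printed loss values scale with the batch size.

```python
import numpy as np


def summed_softmax_cross_entropy(logits, labels, num_classes):
    """Sum over the batch of -log softmax(logits)[true class]."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    one_hot = np.eye(num_classes)[labels]
    return -(one_hot * log_probs).sum()


logits = np.zeros((2, 5))   # uniform predictions over 5 classes
labels = np.array([0, 3])
loss = summed_softmax_cross_entropy(logits, labels, num_classes=5)
print(loss)                 # equals 2 * log(5): each example contributes log(5)
```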

Train Network

In this section, we train the network that we defined earlier. We have defined a function that performs the training. It takes train data (X, Y), validation data (X_val, Y_val), number of epochs, network parameters, optimizer state, and batch size as input, and runs the training loop for the given number of epochs. In each epoch, it loops through the training data in batches. For each batch, it performs a forward pass to make predictions, calculates the loss and gradients, and updates the network parameters. It records the loss of each batch and, at the end of each epoch, prints the average of the batch losses. It also calculates and prints the validation accuracy of the model to track its performance on validation data.
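Splitting the data into batches only needs the slice bounds of each batch. The sketch below (`batch_bounds` is a hypothetical helper) uses ceiling division, so the last batch may be smaller and no empty batch is ever produced, which a formula such as `n // batch_size + 1` would do whenever n is an exact multiple of the batch size.

```python
import math


def batch_bounds(n_examples, batch_size):
    """(start, end) slice bounds for every mini-batch; the last one may be smaller."""
    n_batches = math.ceil(n_examples / batch_size)
    return [(i * batch_size, min((i + 1) * batch_size, n_examples)) for i in range(n_batches)]


bounds = batch_bounds(2822, 64)   # 2822 training examples, batch size 64
print(len(bounds), bounds[-1])    # 45 batches; the last one covers examples 2816..2821
```

The 45 batches match the 45/45 progress bars printed during training.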

Please make a NOTE that we are not updating the weights of the embedding layer. We have included a special line in the training routine that applies jax.tree_map() to the updates and replaces any update whose shape matches the embedding matrix with zeros, hence preventing any changes to the embedding layer.
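Matching by shape works here because the embedding matrix is the only parameter with shape (54984, 300). A sketch of an alternative that matches the layer by name is shown below, on plain nested dicts. It is not a drop-in replacement for the jax.tree_map() line: Flax wraps parameters in FrozenDict containers, and this helper returns plain dicts. Optax also provides utilities such as optax.multi_transform for assigning different transformations to different parameter groups.

```python
from collections.abc import Mapping

import numpy as np


def zero_updates_for(updates, frozen_layer, frozen=False):
    """Copy nested updates, zeroing every array stored below the key `frozen_layer`."""
    if isinstance(updates, Mapping):
        return {key: zero_updates_for(value, frozen_layer, frozen or key == frozen_layer)
                for key, value in updates.items()}
    return np.zeros_like(updates) if frozen else updates


# Hypothetical updates in which Dense1's kernel happens to share the embedding's shape
updates = {"params": {"Word Embeddings": {"embedding": np.ones((3, 2))},
                      "Dense1": {"kernel": np.ones((3, 2)), "bias": np.ones(2)}}}

new_updates = zero_updates_for(updates, "Word Embeddings")
print(new_updates["params"]["Word Embeddings"]["embedding"].sum())  # 0.0: frozen
print(new_updates["params"]["Dense1"]["kernel"].sum())              # 6.0: still updated
```

A shape-based rule would have zeroed the Dense1 kernel in this example too.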

from jax import value_and_grad
from tqdm import tqdm
from sklearn.metrics import accuracy_score

def TrainModelInBatches(X, Y, X_val, Y_val, epochs, params, optimizer_state, batch_size=32):
    ## Ceiling division: the last batch may be smaller, and no empty batch is produced
    n_batches = int(np.ceil(X.shape[0] / batch_size))

    for i in range(1, epochs+1):
        losses = [] ## Record loss of each batch
        for batch in tqdm(range(n_batches)):
            start, end = batch*batch_size, (batch+1)*batch_size ## Slicing past the end is safe
            X_batch, Y_batch = X[start:end], Y[start:end] ## Single batch of data

            loss, gradients = value_and_grad(CrossEntropyLoss)(params, X_batch, Y_batch)

            ## Compute updates with the global `optimizer` (Adam, created below)
            updates, optimizer_state = optimizer.update(gradients, optimizer_state)

            ## IMPORTANT: Prevent Updates to Embedding Layer by setting updates to zeros
            ## (the embedding matrix is the only parameter with this shape)
            updates = jax.tree_map(lambda x: jnp.zeros_like(x) if x.shape == word_embeddings.shape else x, updates)

            params = optax.apply_updates(params, updates)
            losses.append(loss) ## Record Loss
        print("CrossEntropyLoss : {:.3f}".format(jnp.array(losses).mean()))

        Y_val_preds = model.apply(params, X_val)
        val_acc = accuracy_score(Y_val, jnp.argmax(Y_val_preds, axis=1))
        print("Validation  Accuracy : {:.3f}".format(val_acc))

    return params

Below, we train our network using the training routine defined in the previous cell. We have first set the batch size to 64, the number of epochs to 8, and the learning rate to 0.001. Then, we have initialized our text classifier and the Adam optimizer. At last, we have called our training routine with the necessary parameters. We can notice from the printed loss and accuracy values that the loss drops steadily and the network reaches a validation accuracy of about 85%.

After training, we have verified the final weights of the embedding layer to confirm that it's not updated.

from jax import random

seed = random.PRNGKey(0)
batch_size=64
epochs=8
learning_rate = jnp.array(1e-3)

model = EmbeddingClassifier()
params = model.init(seed, X_train_vect[:5])

optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(params)

final_weights = TrainModelInBatches(X_train_vect, Y_train, X_test_vect, Y_test, epochs, params, optimizer_state, batch_size=batch_size)
100%|██████████| 45/45 [00:41<00:00,  1.08it/s]
CrossEntropyLoss : 45.886
Validation  Accuracy : 0.817
100%|██████████| 45/45 [00:44<00:00,  1.02it/s]
CrossEntropyLoss : 2.921
Validation  Accuracy : 0.834
100%|██████████| 45/45 [00:38<00:00,  1.18it/s]
CrossEntropyLoss : 0.515
Validation  Accuracy : 0.839
100%|██████████| 45/45 [00:37<00:00,  1.19it/s]
CrossEntropyLoss : 0.213
Validation  Accuracy : 0.847
100%|██████████| 45/45 [00:37<00:00,  1.20it/s]
CrossEntropyLoss : 0.141
Validation  Accuracy : 0.850
100%|██████████| 45/45 [00:37<00:00,  1.19it/s]
CrossEntropyLoss : 0.107
Validation  Accuracy : 0.850
100%|██████████| 45/45 [00:32<00:00,  1.37it/s]
CrossEntropyLoss : 0.085
Validation  Accuracy : 0.850
100%|██████████| 45/45 [00:31<00:00,  1.43it/s]
CrossEntropyLoss : 0.070
Validation  Accuracy : 0.849
word_embeddings[1][:10], final_weights["params"]["Word Embeddings"]["embedding"][1][:10], params["params"]["Word Embeddings"]["embedding"][1][:10]
(DeviceArray([ 0.27204  , -0.06203  , -0.1884   ,  0.023225 , -0.018158 ,
               0.0067192, -0.13877  ,  0.17708  ,  0.17709  ,  2.5882   ],            dtype=float32),
 DeviceArray([ 0.27204  , -0.06203  , -0.1884   ,  0.023225 , -0.018158 ,
               0.0067192, -0.13877  ,  0.17708  ,  0.17709  ,  2.5882   ],            dtype=float32),
 DeviceArray([ 0.27204  , -0.06203  , -0.1884   ,  0.023225 , -0.018158 ,
               0.0067192, -0.13877  ,  0.17708  ,  0.17709  ,  2.5882   ],            dtype=float32))

Evaluate Network Performance

In this section, we have evaluated the performance of our trained network by calculating the accuracy score, the classification report (precision, recall, and f1-score per target class), and the confusion matrix on test predictions. The network reaches a test accuracy of about 85%. We have calculated these ML metrics using scikit-learn.

If you want to learn about various ML metrics available from sklearn then we recommend that you go through the below link. It covers the majority of them in detail.

Apart from the calculations, we have also plotted the confusion matrix using the Python library scikit-plot. We can notice from the visualization that our network does a better job at classifying text documents from the categories 'comp.graphics', 'rec.autos' and 'rec.sport.hockey' than from the categories 'alt.atheism' and 'talk.politics.mideast'.

scikit-plot provides visualizations for many ML metrics. Please feel free to check the below link in your spare time if you want to learn about it.

from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

test_preds = model.apply(final_weights, X_test_vect)

print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=target_classes))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_test, np.argmax(test_preds, axis=1)))
Test  Accuracy : 0.849

Classification Report :
                       precision    recall  f1-score   support

          alt.atheism       0.82      0.78      0.80       319
        comp.graphics       0.84      0.88      0.86       389
            rec.autos       0.84      0.90      0.87       396
     rec.sport.hockey       0.87      0.89      0.88       399
talk.politics.mideast       0.87      0.79      0.83       376

             accuracy                           0.85      1879
            macro avg       0.85      0.85      0.85      1879
         weighted avg       0.85      0.85      0.85      1879


Confusion Matrix :
[[248  18  12  11  30]
 [  9 341  20   9  10]
 [  8  16 356  13   3]
 [  9  11  22 354   3]
 [ 27  18  13  21 297]]
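To connect the confusion matrix with the classification report, the sketch below recomputes accuracy and per-class recall and precision from the matrix printed above: recall divides each diagonal entry by its row sum (actual class), and precision divides it by its column sum (predicted class).

```python
import numpy as np

# Confusion matrix printed above (rows = actual class, columns = predicted class)
cm = np.array([[248,  18,  12,  11,  30],
               [  9, 341,  20,   9,  10],
               [  8,  16, 356,  13,   3],
               [  9,  11,  22, 354,   3],
               [ 27,  18,  13,  21, 297]])

accuracy = np.trace(cm) / cm.sum()        # correct predictions / all predictions
recall = np.diag(cm) / cm.sum(axis=1)     # per actual class (row-wise)
precision = np.diag(cm) / cm.sum(axis=0)  # per predicted class (column-wise)

print(round(accuracy, 3))
print(recall.round(2))
print(precision.round(2))
```

The rounded values agree with the test accuracy and the precision/recall columns of the classification report.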
from sklearn.metrics import confusion_matrix
import scikitplot as skplt
import matplotlib.pyplot as plt

skplt.metrics.plot_confusion_matrix([target_classes[i] for i in Y_test], [target_classes[i] for i in np.argmax(test_preds, axis=1)],
                                    normalize=True,
                                    title="Confusion Matrix",
                                    cmap="Purples",
                                    hide_zeros=True,
                                    figsize=(5,5)
                                    );
plt.xticks(rotation=90);

[Figure: normalized confusion matrix]

Explain Network Predictions using LIME Algorithm

In this section, we explain predictions made by our network using the LIME (Local Interpretable Model-Agnostic Explanations) algorithm. LIME is a commonly used algorithm for explaining predictions made by black-box ML models. We have used the Python library lime, which implements the algorithm.

If you are new to the concept of LIME, then we suggest that you go through the below links. They will give you in-depth knowledge of it and get you started with using it.

In order to explain prediction using LIME, we need to create an instance of LimeTextExplainer which we have done below.

from lime import lime_text

explainer = lime_text.LimeTextExplainer(class_names=target_classes, verbose=True)

Below, we have first created a prediction function that takes a batch of text examples and returns the probabilities predicted by the network. Please make a NOTE that the network outputs are not probabilities; we have converted them to probabilities by applying softmax activation.

Then, we have randomly selected a text example from the test data and made a prediction on it using our trained model. The model correctly predicts the target label as 'rec.sport.hockey'. Next, we'll create an explanation for this prediction.
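The softmax conversion described above turns each row of logits into probabilities that are positive and sum to 1, which is the form LIME's classifier_fn is expected to return. A NumPy sketch, applied to the first two rows of logits printed for the untrained network earlier:

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax: raw network scores (logits) -> class probabilities."""
    shifted = logits - logits.max(axis=-1, keepdims=True)  # avoids overflow in exp
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


# First two rows of the untrained network's output shown earlier
logits = np.array([[ 0.16070068, -0.18725368, -0.10340895,  0.13814029, -0.3553394 ],
                   [-0.03378607,  0.00632447, -0.15137541,  0.2485371 , -0.21806909]])
probs = softmax(logits)
print(probs.sum(axis=1))                                # each row sums to 1
print(probs.argmax(axis=1) == logits.argmax(axis=1))    # the predicted class is unchanged
```

Because softmax is monotonic, taking argmax over logits (as the prediction cell does) or over probabilities gives the same class.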

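def make_predictions(X_batch_text):
    ## Vectorize raw text exactly as during training
    X_batch = pad_sequences(tokenizer.texts_to_sequences(X_batch_text), maxlen=max_tokens, padding="post", truncating="post", value=0.)
    logits = model.apply(final_weights, jnp.array(X_batch))
    preds = linen.softmax(logits) ## Convert logits to probabilities
    return np.asarray(preds) ## LIME expects a NumPy array of shape (n_samples, n_classes)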

rng = np.random.RandomState(1234)
idx = rng.randint(1, len(X_test_text))

print("Prediction : ", target_classes[model.apply(final_weights, X_test_vect[idx:idx+1]).argmax(axis=-1)[0]])
print("Actual :     ", target_classes[Y_test[idx]])
Prediction :  rec.sport.hockey
Actual :      rec.sport.hockey

Below, we have first created an Explanation object by calling explain_instance() method on LimeTextExplainer instance. We have provided the method with a selected text example, prediction function, and target label. The explanation object has details about words contributing to prediction. Next, we have visualized the explanation by calling show_in_notebook() method on it.

The visualization highlights that words like 'leafs', 'games', 'goals', 'goalie', etc contributed to predicting target label as 'rec.sport.hockey'. This makes sense as these are commonly used words in the hockey world.

explanation = explainer.explain_instance(X_test_text[idx], classifier_fn=make_predictions,
                                         labels=np.asarray(Y_test[idx:idx+1]))
explanation.show_in_notebook()

[Figure: LIME explanation of the 'rec.sport.hockey' prediction]

Approach 2: GloVe Embeddings Averaged (Max Tokens=50, Embedding Length=300)

Our approach in this section changes only the way we handle word embeddings: we take the average of the word embeddings across the tokens of each example. The majority of the code is exactly the same as earlier, with a minor change in the network architecture.

Define Network

Below, we have defined the network that we'll use for our text classification task. It has 3 layers, like earlier; the only difference is in the forward pass logic. Earlier, we flattened the embeddings, whereas here we average them across the tokens of each text example using the mean() method. This transforms the data from shape (batch_size, 50, 300) to (batch_size, 300), which is then given to a dense layer for processing.
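To see the shape change concretely, the NumPy sketch below applies both reductions to a random stand-in for the embedding layer output (4 examples, 50 tokens, 300 dimensions). The random array is an illustrative assumption, and the parameter counts assume a 100-unit dense layer, like Dense1 here, placed directly after each reduction.

```python
import numpy as np

batch_size, max_tokens, embed_len = 4, 50, 300

# Random stand-in for the embedding layer output: one 300-d vector per token.
x = np.random.default_rng(0).normal(size=(batch_size, max_tokens, embed_len))

flattened = x.reshape(batch_size, -1)  # previous approach: concatenate all token vectors
averaged = x.mean(axis=1)              # this approach: average over the token axis

print(flattened.shape)  # (4, 15000)
print(averaged.shape)   # (4, 300)

# Weights + biases of a 100-unit dense layer fed by each representation.
units = 100
print(flattened.shape[1] * units + units)  # 1500100
print(averaged.shape[1] * units + units)   # 30100
```

Averaging shrinks the first dense layer's input from 15,000 to 300 features, and the layer's size no longer depends on the number of tokens.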

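After defining the network, we initialized it and performed a forward pass on a few training examples to verify that it produces predictions. We also printed the first 10 values of the embedding for token index 1 to verify that the embedding layer was initialized with the GloVe vectors.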

from flax import linen

embed_len = 300

class EmbeddingClassifier(linen.Module):
    def setup(self):
        self.embedding = linen.Embed(num_embeddings = len(tokenizer.word_index)+1, features=embed_len,
                                     name="Word Embeddings", embedding_init=lambda *args: word_embeddings)
        self.linear1 = linen.Dense(100, name="Dense1")
        self.linear2 = linen.Dense(len(target_classes), name="Dense2")

    def __call__(self, X_batch):
        x = self.embedding(X_batch)
        x = x.mean(axis=1)  # average over tokens: (batch_size, 50, 300) -> (batch_size, 300)

        x = self.linear1(x)
        x = linen.relu(x)

        logits = self.linear2(x)
        return logits

import jax
from jax import numpy as jnp

seed = jax.random.PRNGKey(0)

embed_classif = EmbeddingClassifier()

params = embed_classif.init(seed, jax.random.randint(seed, (100, max_tokens), minval=1, maxval=20))

preds = embed_classif.apply(params, X_train_vect[:5])

preds
DeviceArray([[-0.06569727, -0.10394667,  0.04159085,  0.02477646,
              -0.00985948],
             [-0.02463551, -0.12782736, -0.04452181,  0.05409296,
               0.01808915],
             [-0.06321083, -0.061309  , -0.0065282 ,  0.08547784,
               0.05013282],
             [-0.06172216, -0.11052934,  0.01591276,  0.03255364,
               0.05282662],
             [-0.02530712, -0.142962  , -0.03585762,  0.05743529,
              -0.00298506]], dtype=float32)
params["params"]["Word Embeddings"]["embedding"][1][:10]
DeviceArray([ 0.27204  , -0.06203  , -0.1884   ,  0.023225 , -0.018158 ,
              0.0067192, -0.13877  ,  0.17708  ,  0.17709  ,  2.5882   ],            dtype=float32)

Train Network

In this section, we have trained the network using the same settings as in the previous section. The cross-entropy loss printed after each epoch falls from 92.191 to 20.530, and the validation accuracy rises from 0.652 to 0.870, so the network does a good job at the classification task.

from jax import random

seed = random.PRNGKey(0)
batch_size=64
epochs=8
learning_rate = jnp.array(1e-3)

model = EmbeddingClassifier()
params = model.init(seed, X_train_vect[:5])

optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(params)

final_weights = TrainModelInBatches(X_train_vect, Y_train, X_test_vect, Y_test, epochs, params, optimizer_state, batch_size=batch_size)
100%|██████████| 45/45 [00:34<00:00,  1.32it/s]
CrossEntropyLoss : 92.191
Validation  Accuracy : 0.652
100%|██████████| 45/45 [00:33<00:00,  1.34it/s]
CrossEntropyLoss : 68.851
Validation  Accuracy : 0.778
100%|██████████| 45/45 [00:33<00:00,  1.34it/s]
CrossEntropyLoss : 48.988
Validation  Accuracy : 0.824
100%|██████████| 45/45 [00:33<00:00,  1.34it/s]
CrossEntropyLoss : 37.212
Validation  Accuracy : 0.843
100%|██████████| 45/45 [00:33<00:00,  1.34it/s]
CrossEntropyLoss : 30.333
Validation  Accuracy : 0.858
100%|██████████| 45/45 [00:34<00:00,  1.31it/s]
CrossEntropyLoss : 25.915
Validation  Accuracy : 0.863
100%|██████████| 45/45 [00:35<00:00,  1.27it/s]
CrossEntropyLoss : 22.830
Validation  Accuracy : 0.864
100%|██████████| 45/45 [00:32<00:00,  1.39it/s]
CrossEntropyLoss : 20.530
Validation  Accuracy : 0.870

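Below, we have compared the first 10 values of the vector for token index 1 in the GloVe matrix (word_embeddings), the trained weights (final_weights), and the initial weights (params). The three arrays are identical, so training left these GloVe values unchanged.

word_embeddings[1][:10], final_weights["params"]["Word Embeddings"]["embedding"][1][:10], params["params"]["Word Embeddings"]["embedding"][1][:10]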
(DeviceArray([ 0.27204  , -0.06203  , -0.1884   ,  0.023225 , -0.018158 ,
               0.0067192, -0.13877  ,  0.17708  ,  0.17709  ,  2.5882   ],            dtype=float32),
 DeviceArray([ 0.27204  , -0.06203  , -0.1884   ,  0.023225 , -0.018158 ,
               0.0067192, -0.13877  ,  0.17708  ,  0.17709  ,  2.5882   ],            dtype=float32),
 DeviceArray([ 0.27204  , -0.06203  , -0.1884   ,  0.023225 , -0.018158 ,
               0.0067192, -0.13877  ,  0.17708  ,  0.17709  ,  2.5882   ],            dtype=float32))

Evaluate Network Performance

In this section, we have evaluated the performance of our trained network by calculating the accuracy score, classification report, and confusion matrix on test predictions. The test accuracy of 0.870 is noticeably higher than the 84.9% of the previous (flattened) approach. We have also plotted a normalized confusion matrix to show which categories the network confuses with each other.

from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

test_preds = model.apply(final_weights, X_test_vect)

print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=target_classes))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_test, np.argmax(test_preds, axis=1)))
Test  Accuracy : 0.870

Classification Report :
                       precision    recall  f1-score   support

          alt.atheism       0.84      0.82      0.83       319
        comp.graphics       0.86      0.89      0.87       389
            rec.autos       0.84      0.90      0.87       396
     rec.sport.hockey       0.91      0.88      0.90       399
talk.politics.mideast       0.89      0.85      0.87       376

             accuracy                           0.87      1879
            macro avg       0.87      0.87      0.87      1879
         weighted avg       0.87      0.87      0.87      1879


Confusion Matrix :
[[261  16  17   6  19]
 [  8 346  17  10   8]
 [  6  22 356   6   6]
 [  5  11  24 352   7]
 [ 29   7   9  12 319]]
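As a quick cross-check, the metrics above can be recomputed from the printed confusion matrix, where rows are actual classes and columns are predicted classes (scikit-learn's convention). Accuracy is the diagonal total divided by all 1879 test examples, recall divides each diagonal entry by its row total, and precision divides it by its column total. This NumPy sketch simply hard-codes the matrix printed above.

```python
import numpy as np

# Confusion matrix printed above: rows = actual class, columns = predicted class.
cm = np.array([[261,  16,  17,   6,  19],
               [  8, 346,  17,  10,   8],
               [  6,  22, 356,   6,   6],
               [  5,  11,  24, 352,   7],
               [ 29,   7,   9,  12, 319]])

correct = np.diag(cm)                  # correctly classified examples per class
accuracy = correct.sum() / cm.sum()    # 1634 / 1879
recall = correct / cm.sum(axis=1)      # divide by row totals (support)
precision = correct / cm.sum(axis=0)   # divide by column totals (predicted counts)

print(f"accuracy : {accuracy:.3f}")    # accuracy : 0.870
print("recall   :", np.round(recall, 2))
print("precision:", np.round(precision, 2))
```

The rounded values match the classification report. With normalize=True, scikit-plot divides each row by its total, so the diagonal of the normalized plot shows these recall values.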
from sklearn.metrics import confusion_matrix
import scikitplot as skplt
import matplotlib.pyplot as plt

skplt.metrics.plot_confusion_matrix([target_classes[i] for i in Y_test], [target_classes[i] for i in np.argmax(test_preds, axis=1)],
                                    normalize=True,
                                    title="Confusion Matrix",
                                    cmap="Purples",
                                    hide_zeros=True,
                                    figsize=(5,5)
                                    );
plt.xticks(rotation=90);

Explain Network Predictions using LIME Algorithm

In this section, we have explained the predictions made by our trained network using the LIME algorithm. The network correctly predicts the target label as 'rec.sport.hockey'. The visualization shows that words like 'leafs', 'games', 'goalie', 'goals', and 'defense' contributed to predicting the target label as 'rec.sport.hockey'.

from lime import lime_text

explainer = lime_text.LimeTextExplainer(class_names=target_classes)

rng = np.random.RandomState(1234)
idx = rng.randint(1, len(X_test_text))

print("Prediction : ", target_classes[model.apply(final_weights, X_test_vect[idx:idx+1]).argmax(axis=-1)[0]])
print("Actual :     ", target_classes[Y_test[idx]])

explanation = explainer.explain_instance(X_test_text[idx], classifier_fn=make_predictions, labels=np.asarray(Y_test[idx:idx+1]))
explanation.show_in_notebook()

Approach 3: GloVe Embeddings Summed (Max Tokens=50, Embedding Length=300)

This approach is the same as the previous one, except that we take the sum of the embeddings across tokens instead of their average. Most of the code is the same as earlier; only the network architecture changes.

Define Network

Below, we have defined the network architecture that we'll use for our text classification task. The network has the same layers as earlier. The only difference is in the forward pass logic, which now uses the sum() method to sum the embeddings across tokens instead of averaging them.
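Summing instead of averaging multiplies the pooled vector by the number of tokens (50 here, padding included). Because Flax's Dense layers initialize biases to zero by default and ReLU satisfies relu(c * z) = c * relu(z) for c > 0, the untrained logits of this network should be exactly 50 times those of the averaging network when both are initialized with the same seed. The initial predictions printed for the two networks are consistent with this (for example, 50 × -0.06569727 ≈ -3.284863). The NumPy sketch below checks that property on a hypothetical two-layer head with random weights and zero biases; it is not the Flax model itself.

```python
import numpy as np

gen = np.random.default_rng(0)
batch_size, max_tokens, embed_len, hidden, n_classes = 5, 50, 300, 100, 5

# Random stand-in for the embedding layer output.
x = gen.normal(size=(batch_size, max_tokens, embed_len))

# Hypothetical head with zero biases, mirroring Flax's default Dense bias init.
W1 = gen.normal(scale=0.05, size=(embed_len, hidden))
W2 = gen.normal(scale=0.05, size=(hidden, n_classes))

def head(features):
    h = np.maximum(features @ W1, 0.0)  # Dense1 with zero bias, then ReLU
    return h @ W2                       # Dense2 with zero bias -> logits

logits_mean = head(x.mean(axis=1))  # averaging network
logits_sum = head(x.sum(axis=1))    # summing network

print(np.allclose(logits_sum, max_tokens * logits_mean))  # True
```

The larger input scale explains why the summing network's logits start out much larger; once biases are trained, the two networks no longer differ by a fixed factor.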

After defining the network, we initialized it, performed a forward pass to make predictions for verification purposes, and checked that the embedding layer holds the GloVe vectors.

from flax import linen

embed_len = 300

class EmbeddingClassifier(linen.Module):
    def setup(self):
        self.embedding = linen.Embed(num_embeddings = len(tokenizer.word_index)+1, features=embed_len,
                                     name="Word Embeddings", embedding_init=lambda *args: word_embeddings)
        self.linear1 = linen.Dense(100, name="Dense1")
        self.linear2 = linen.Dense(len(target_classes), name="Dense2")

    def __call__(self, X_batch):
        x = self.embedding(X_batch)
        x = x.sum(axis=1)  # sum over tokens: (batch_size, 50, 300) -> (batch_size, 300)

        x = self.linear1(x)
        x = linen.relu(x)

        logits = self.linear2(x)
        return logits

import jax
from jax import numpy as jnp

seed = jax.random.PRNGKey(0)

embed_classif = EmbeddingClassifier()

params = embed_classif.init(seed, jax.random.randint(seed, (100, max_tokens), minval=1, maxval=20))

preds = embed_classif.apply(params, X_train_vect[:5])

preds
DeviceArray([[-3.284863  , -5.1973324 ,  2.0795426 ,  1.238823  ,
              -0.49297437],
             [-1.2317724 , -6.391368  , -2.226091  ,  2.7046487 ,
               0.9044571 ],
             [-3.160542  , -3.0654523 , -0.32640934,  4.2738943 ,
               2.50664   ],
             [-3.0861063 , -5.5264664 ,  0.7956378 ,  1.6276834 ,
               2.6413312 ],
             [-1.2653544 , -7.148101  , -1.7928789 ,  2.871761  ,
              -0.1492536 ]], dtype=float32)
params["params"]["Word Embeddings"]["embedding"][1][:10]
DeviceArray([ 0.27204  , -0.06203  , -0.1884   ,  0.023225 , -0.018158 ,
              0.0067192, -0.13877  ,  0.17708  ,  0.17709  ,  2.5882   ],            dtype=float32)

Train Network

In this section, we have trained the network using exactly the same settings as for all our approaches. Over 8 epochs, the loss falls from 87.922 to 5.806 and the validation accuracy rises from 0.758 to 0.876, so the network does a good job at the given task.

from jax import random

seed = random.PRNGKey(0)
batch_size=64
epochs=8
learning_rate = jnp.array(1e-3)

model = EmbeddingClassifier()
params = model.init(seed, X_train_vect[:5])

optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(params)

final_weights = TrainModelInBatches(X_train_vect, Y_train, X_test_vect, Y_test, epochs, params, optimizer_state, batch_size=batch_size)
100%|██████████| 45/45 [00:30<00:00,  1.47it/s]
CrossEntropyLoss : 87.922
Validation  Accuracy : 0.758
100%|██████████| 45/45 [00:30<00:00,  1.47it/s]
CrossEntropyLoss : 26.698
Validation  Accuracy : 0.817
100%|██████████| 45/45 [00:41<00:00,  1.07it/s]
CrossEntropyLoss : 18.094
Validation  Accuracy : 0.840
100%|██████████| 45/45 [00:32<00:00,  1.38it/s]
CrossEntropyLoss : 13.771
Validation  Accuracy : 0.847
100%|██████████| 45/45 [00:31<00:00,  1.41it/s]
CrossEntropyLoss : 10.899
Validation  Accuracy : 0.857
100%|██████████| 45/45 [00:32<00:00,  1.39it/s]
CrossEntropyLoss : 8.810
Validation  Accuracy : 0.863
100%|██████████| 45/45 [00:32<00:00,  1.40it/s]
CrossEntropyLoss : 7.174
Validation  Accuracy : 0.873
100%|██████████| 45/45 [00:32<00:00,  1.37it/s]
CrossEntropyLoss : 5.806
Validation  Accuracy : 0.876

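As before, we have compared the first 10 values of the vector for token index 1 in word_embeddings, final_weights, and params. They are again identical, so training left these GloVe values unchanged.

word_embeddings[1][:10], final_weights["params"]["Word Embeddings"]["embedding"][1][:10], params["params"]["Word Embeddings"]["embedding"][1][:10]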
(DeviceArray([ 0.27204  , -0.06203  , -0.1884   ,  0.023225 , -0.018158 ,
               0.0067192, -0.13877  ,  0.17708  ,  0.17709  ,  2.5882   ],            dtype=float32),
 DeviceArray([ 0.27204  , -0.06203  , -0.1884   ,  0.023225 , -0.018158 ,
               0.0067192, -0.13877  ,  0.17708  ,  0.17709  ,  2.5882   ],            dtype=float32),
 DeviceArray([ 0.27204  , -0.06203  , -0.1884   ,  0.023225 , -0.018158 ,
               0.0067192, -0.13877  ,  0.17708  ,  0.17709  ,  2.5882   ],            dtype=float32))

Evaluate Network Performance

In this section, we have evaluated the performance of our network by calculating the accuracy score, classification report, and confusion matrix on test predictions. The test accuracy of 0.876 is slightly higher than the 0.870 of the averaging approach and the highest of all our approaches. We have also plotted a confusion matrix for reference.

from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

test_preds = model.apply(final_weights, X_test_vect)

print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=target_classes))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_test, np.argmax(test_preds, axis=1)))
Test  Accuracy : 0.876

Classification Report :
                       precision    recall  f1-score   support

          alt.atheism       0.90      0.81      0.85       319
        comp.graphics       0.85      0.89      0.87       389
            rec.autos       0.85      0.90      0.87       396
     rec.sport.hockey       0.93      0.86      0.89       399
talk.politics.mideast       0.87      0.91      0.89       376

             accuracy                           0.88      1879
            macro avg       0.88      0.87      0.88      1879
         weighted avg       0.88      0.88      0.88      1879


Confusion Matrix :
[[258  20  16   5  20]
 [  6 348  19   4  12]
 [  3  23 356   9   5]
 [  5  14  21 343  16]
 [ 14   6   7   8 341]]
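To see where the small gain over the averaging approach comes from, the sketch below compares the diagonals (correct predictions per class) of the confusion matrices printed for the averaging and summing approaches. The counts are copied from those outputs; the class order is the one used in the classification reports.

```python
import numpy as np

classes = ["alt.atheism", "comp.graphics", "rec.autos",
           "rec.sport.hockey", "talk.politics.mideast"]

# Diagonals of the two printed confusion matrices (correct predictions per class).
correct_avg = np.array([261, 346, 356, 352, 319])  # Approach 2: averaged embeddings
correct_sum = np.array([258, 348, 356, 343, 341])  # Approach 3: summed embeddings
support = np.array([319, 389, 396, 399, 376])      # test examples per class

for name, a, s, n in zip(classes, correct_avg, correct_sum, support):
    print(f"{name:<22} recall {a / n:.3f} -> {s / n:.3f} ({int(s - a):+d} correct)")

print("total correct:", correct_avg.sum(), "->", correct_sum.sum())
print(f"accuracy: {correct_avg.sum() / support.sum():.3f} -> {correct_sum.sum() / support.sum():.3f}")
```

Summing adds 12 correct predictions overall (1634 to 1646), mainly from talk.politics.mideast (+22), while rec.sport.hockey loses 9.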
from sklearn.metrics import confusion_matrix
import scikitplot as skplt
import matplotlib.pyplot as plt

skplt.metrics.plot_confusion_matrix([target_classes[i] for i in Y_test], [target_classes[i] for i in np.argmax(test_preds, axis=1)],
                                    normalize=True,
                                    title="Confusion Matrix",
                                    cmap="Purples",
                                    hide_zeros=True,
                                    figsize=(5,5)
                                    );
plt.xticks(rotation=90);

Explain Network Predictions using LIME Algorithm

In this section, we have explained the predictions made by the network using the LIME algorithm. The network correctly predicts the target category as 'rec.sport.hockey'. The visualization highlights that words like 'leafs', 'games', 'host', and 'goalie' contributed to predicting the target label as 'rec.sport.hockey'.

from lime import lime_text

explainer = lime_text.LimeTextExplainer(class_names=target_classes)

rng = np.random.RandomState(1234)
idx = rng.randint(1, len(X_test_text))

print("Prediction : ", target_classes[model.apply(final_weights, X_test_vect[idx:idx+1]).argmax(axis=-1)[0]])
print("Actual :     ", target_classes[Y_test[idx]])

explanation = explainer.explain_instance(X_test_text[idx], classifier_fn=make_predictions, labels=np.asarray(Y_test[idx:idx+1]))
explanation.show_in_notebook()

5. Results Summary And Further Suggestions

Approach                                  Max Tokens   Embedding Length   Test Accuracy (%)
GloVe Embeddings (840B.300d) Flattened    50           300                84.9
GloVe Embeddings (840B.300d) Averaged     50           300                87.0
GloVe Embeddings (840B.300d) Summed       50           300                87.6

Further Recommendations

  • Try different maximum token lengths (max tokens).
  • Try different GloVe embeddings like 42B, 6B, 27B, etc.
  • Try different combinations of dense layers after the embedding layer.
  • Try different ways to combine embeddings other than flattening, averaging, and summing (max, min, std, etc.).
  • Try different weight initialization methods.
  • Try different optimizers.
  • Train the network for more epochs.
  • Try learning rate schedulers.
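For the suggestion on other ways to combine embeddings, the NumPy sketch below shows max, min, and standard-deviation pooling over the token axis. Each returns one 300-dimensional vector per example, just like averaging and summing, so Dense1 would not need to change. The POOLERS dictionary and the random input are illustrative assumptions; in the Flax module, the corresponding change would be replacing x.mean(axis=1) in __call__ with, for example, x.max(axis=1), since JAX arrays support the same reduction methods.

```python
import numpy as np

# Hypothetical reducers over the token axis (axis=1) of a (batch, tokens, embed) array.
POOLERS = {
    "max": lambda x: x.max(axis=1),  # largest value of each embedding dimension
    "min": lambda x: x.min(axis=1),  # smallest value of each embedding dimension
    "std": lambda x: x.std(axis=1),  # spread of each embedding dimension across tokens
}

# Random stand-in for the embedding layer output: 2 examples, 50 tokens, 300 dims.
x = np.random.default_rng(0).normal(size=(2, 50, 300))

for name, pool in POOLERS.items():
    print(name, pool(x).shape)  # each prints (2, 300)
```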
Sunny Solanki