Updated On : Mar-31,2022 Tags flax, jax, word-embeddin…
Guide To Use Word Embeddings For Flax (JAX) Text Classification Networks

When we want to develop a deep learning or machine learning model that works on text data, we first need to convert the text to numbers, because ML/DL models only work with numeric (float) inputs. Over the years, many approaches have been developed, like word frequency, one-hot encoding of words, TF-IDF (Term Frequency-Inverse Document Frequency), word embeddings, etc. Approaches like word frequency and TF-IDF use just one float to represent one word/token. This generally works well for NLP tasks like text classification with small amounts of data, but might not work well for other NLP tasks, like text generation, that require understanding the context of the text. Because they use just one float per word/token, the amount of information they can represent is limited. Word embeddings were invented to solve this. With word embeddings, a list of floats (a vector), generally referred to as the embedding of that word/token, represents a single word/token. This float vector can capture the meaning of the word better and can also capture contextual information. Generally, when we train and update these embeddings through our networks, the embeddings of similar words (similar in meaning, or appearing in similar contexts) end up near one another.
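One common way to measure how "near" two embeddings are is cosine similarity. The sketch below compares three hand-made, hypothetical 4-dimensional vectors (illustrative values only, not embeddings learned by any network in this tutorial); real learned embeddings are longer and far less tidy.

```python
import jax.numpy as jnp

def cosine_similarity(a, b):
    """Cosine of the angle between two embedding vectors (1.0 = same direction)."""
    return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))

# Hypothetical embeddings chosen by hand for illustration (NOT learned values).
embeddings = {
    "football": jnp.array([0.9, 0.8, 0.1, 0.0]),
    "soccer":   jnp.array([0.8, 0.9, 0.2, 0.1]),
    "stocks":   jnp.array([0.0, 0.1, 0.9, 0.8]),
}

sim_related = cosine_similarity(embeddings["football"], embeddings["soccer"])
sim_unrelated = cosine_similarity(embeddings["football"], embeddings["stocks"])
print(float(sim_related), float(sim_unrelated))  # related pair scores much higher
```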

As a part of this tutorial, we have explained how to use word embeddings in text classification networks designed using Flax (JAX). Flax is a high-level Python deep learning library built on top of JAX to simplify the process of creating neural networks. We have tried several approaches to working with word embeddings. We have used the AG NEWS dataset available from torchtext and the word tokenization functionality available from Keras. Apart from this, we have also explained the predictions made by the network using the LIME algorithm.

Below, we have listed the important sections of the tutorial to give an overview of the material covered.

Important Sections Of Tutorial

  1. Prepare Data
    • 1.1 Load Data
    • 1.2 Vectorize Data
  2. Approach 1 - Word Embeddings
    • 2.1 Create Network
    • 2.2 Define Loss
    • 2.3 Train Network
    • 2.4 Evaluate Network Performance
    • 2.5 Explain Network Predictions Using LIME
  3. Approach 2 - Word Embeddings With More Values
  4. Approach 3 - Word Embeddings Averaged
  5. Approach 4 - Word Embeddings Summed

Below, we have imported the necessary libraries and printed the versions that we have used in our tutorial.

In [1]:
import jax

print("JAX Version : {}".format(jax.__version__))
JAX Version : 0.3.1
In [2]:
import flax

print("Flax Version : {}".format(flax.__version__))
Flax Version : 0.4.0
In [3]:
import optax

print("OPTAX Version : {}".format(optax.__version__))
OPTAX Version : 0.1.1
In [4]:
import torchtext

print("Torchtext Version : {}".format(torchtext.__version__))
Torchtext Version : 0.10.1
In [5]:
from tensorflow import keras

print("Keras Version : {}".format(keras.__version__))
Keras Version : 2.6.0

1. Prepare Data

In this section, we have prepared the data that we'll be feeding to our neural network. We have first loaded the dataset, populated a vocabulary (a mapping from unique words to indexes) with tokens/words, and then transformed each text into a list of token indexes using that vocabulary. In the end, we'll have a list of indexes for each text document, where each index maps to a token/word in our vocabulary. We'll give these lists of indexes as input to our network, which will look up an embedding for each index and update the embeddings as we train it.
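The vocabulary idea can be sketched in a few lines of plain Python. `build_vocab` and `texts_to_indexes` are hypothetical helpers written for illustration, a simplified stand-in for the Keras Tokenizer used below (which additionally strips punctuation and offers more options). Like Keras, the sketch gives index 1 to the most frequent word and never assigns index 0, so 0 stays free for padding; `Counter.most_common()` keeps first-seen order for words with equal counts.

```python
from collections import Counter

def build_vocab(texts):
    """Map each unique lowercase word to an index, most frequent word first.
    Index 0 is never assigned so it can be reserved for padding."""
    counts = Counter()
    for text in texts:
        counts.update(text.lower().split())
    # most_common() sorts by count; ties keep the order words were first seen
    return {word: i for i, (word, _) in enumerate(counts.most_common(), start=1)}

def texts_to_indexes(texts, vocab):
    """Replace each known word with its index; unknown words are dropped."""
    return [[vocab[w] for w in text.lower().split() if w in vocab] for text in texts]

docs = ["wall street bears", "wall street rally", "sports news"]
vocab = build_vocab(docs)
seqs = texts_to_indexes(["Wall Street news", "unseen words"], vocab)
print(vocab)
print(seqs)
```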

1.1 Load Data

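Below, we have simply loaded the AG NEWS dataset available from torchtext (PyTorch's text library). The dataset has text documents for 4 different categories of news (["World", "Sports", "Business", "Sci/Tech"]). The original labels run from 1 to 4, as listed below; in the code, we shift them to 0-3.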

Index   Category
1       World
2       Sports
3       Business
4       Sci/Tech
In [6]:
import numpy as np

train_dataset, test_dataset = torchtext.datasets.AG_NEWS()

X_train_text, Y_train = [], []
for Y, X in train_dataset:
    X_train_text.append(X)
    Y_train.append(Y)

X_test_text, Y_test = [], []
for Y, X in test_dataset:
    X_test_text.append(X)
    Y_test.append(Y)

unique_classes = list(set(Y_train))
target_classes = ["World", "Sports", "Business", "Sci/Tech"]

## Subtracted 1 from labels to bring range from 1-4 to 0-3
Y_train, Y_test = np.array(Y_train) - 1, np.array(Y_test) - 1

len(X_train_text), len(X_test_text)
train.csv: 29.5MB [00:00, 77.9MB/s]
test.csv: 1.86MB [00:00, 51.8MB/s]
Out[6]:
(120000, 7600)

1.2 Vectorize Data

In this section, we have first populated the vocabulary using the text data and then mapped each text to a list of indexes based on that vocabulary. To perform these operations, we have used the Tokenizer() constructor available from the Keras library. We have first initialized a Tokenizer and then called its fit_on_texts() method with the train and test text documents to populate the vocabulary. Keras assigns index 1 to the most frequent word, index 2 to the next most frequent, and so on; index 0 is never assigned to a word, so it can be used for padding.

After the vocabulary is populated, we have called the texts_to_sequences() method on the Tokenizer object with the train and test documents to generate a list of indexes for each text document. The texts_to_sequences() method first tokenizes each text document into a list of tokens/words and then maps each token/word to its index in the populated vocabulary.

We have decided to keep a maximum of 50 words per document for our classification task. Some documents have more than 50 words, whereas some have fewer. We have used the pad_sequences() function from Keras to bring all vectorized examples to a length of 50. With padding="post" and truncating="post", it appends 0s to the end of examples shorter than 50 and drops the words beyond the 50th from examples longer than 50.
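To make the padding and truncation behaviour concrete, here is a small NumPy sketch. `pad_post` is a hypothetical helper written for illustration; it mimics what pad_sequences(..., padding="post", truncating="post", value=0) does for these particular arguments and is not the Keras implementation.

```python
import numpy as np

def pad_post(sequences, maxlen, value=0):
    """Pad with `value` at the end / truncate from the end to exactly `maxlen`."""
    out = np.full((len(sequences), maxlen), value, dtype=np.int32)
    for row, seq in enumerate(sequences):
        trimmed = seq[:maxlen]              # keep only the first maxlen indexes
        out[row, :len(trimmed)] = trimmed   # remaining positions stay as padding
    return out

padded = pad_post([[5, 3, 8], [1, 2, 3, 4, 5, 6, 7]], maxlen=5)
print(padded)
# [[5 3 8 0 0]
#  [1 2 3 4 5]]
```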

After vectorizing the text documents to lists of indexes, we have also converted them to JAX arrays, as required by our neural networks.

In [7]:
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from jax import numpy as jnp

tokenizer = Tokenizer()
tokenizer.fit_on_texts(X_train_text+X_test_text)

## Vectorizing data to keep 50 words per sample.
X_train_vect = pad_sequences(tokenizer.texts_to_sequences(X_train_text), maxlen=50, padding="post", truncating="post", value=0.)
X_test_vect  = pad_sequences(tokenizer.texts_to_sequences(X_test_text), maxlen=50, padding="post", truncating="post", value=0.)

print(X_train_vect[:3])

X_train_vect, X_test_vect = jnp.array(X_train_vect, dtype=jnp.int32), jnp.array(X_test_vect, dtype=jnp.int32)
Y_train, Y_test = jnp.array(Y_train), jnp.array(Y_test)

X_train_vect.shape, X_test_vect.shape
[[  444   440  1697 15012   108    64     1   852    21    21   739  8198
    444  6337 10243  2965     4  5937 26696    40  4014   801   335     0
      0     0     0     0     0     0     0     0     0     0     0     0
      0     0     0     0     0     0     0     0     0     0     0     0
      0     0]
 [15470  1111   871  1313  4306    21    21   919   809   359 15470    99
    101    22     3  4508     8   504   511 13730     6 15471  1514  2135
      5     1   522   247    22  3938  2289    15  6459     7   209   368
      4     1   129     0     0     0     0     0     0     0     0     0
      0     0]
 [   53     6   379  4509 26697   770    21    21  2446   467    90  1885
   1280    66     1   379     6     1   770     8   285    40   190     2
   5766    34     1   296   129   111    82   230     1  6391     4     1
   1208 15472     0     0     0     0     0     0     0     0     0     0
      0     0]]
Out[7]:
((120000, 50), (7600, 50))
In [8]:
## Which word does index 444 map to?

print(tokenizer.index_word[444])

## How many times does it appear in the first text document?

print(X_train_text[0]) ## 2 times ("Wall St." and "Wall Street's")
wall
Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again.

2. Approach 1 - Word Embeddings

In our first approach, we train a network on the vectorized data generated earlier. We have created a network with an embedding layer and dense layers to classify text documents. We use an embedding length of 10, which means that each word is mapped to a vector of 10 floats.

2.1 Create Network

Below, we have created a network that we'll use for our text classification tasks. Our network consists of one embedding layer and 2 dense layers.

The embedding layer holds one embedding of length 10 for every index in the vocabulary. We have defined it using the Embed() constructor, giving it the vocabulary size plus one (len(tokenizer.word_index)+1, because Keras indexes start at 1 and index 0 is used for padding) and an embedding length of 10. The embedding layer initializes a weight matrix of shape (vocab_len+1, 10), and each index value selects the corresponding row of this matrix. As we have already translated our text data to lists of token/word indexes, the embedding layer simply maps those indexes to their respective embeddings. During training, these weights/embeddings are updated to improve the accuracy of the network.

The output of the embedding layer, of shape (batch_size, 50, 10), is flattened to (batch_size, 500) and given to a dense layer that has 100 output units; this is why the first dense layer's weights have shape (500, 100). Then, we have applied relu activation to the output of the first dense layer. After that, the output is given to the second dense layer, which has 4 output units (one per target class). The output of the second dense layer (the logits) is the prediction of our network.

After defining the network, we have initialized it and printed the shapes of the layer parameters. We have also performed a forward pass through the network on a few train samples for verification purposes.

Please feel free to check the below tutorial if you are looking for some background on how to create networks using Flax. We have covered various modules of Flax over there in detail.

In [9]:
from flax import linen

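class EmbeddingClassifier(linen.Module):
    def setup(self):
        ## One 10-dimensional vector per vocabulary index (index 0 is the padding index)
        self.embedding = linen.Embed(len(tokenizer.word_index)+1, 10, name="Word Embeddings")
        self.linear1 = linen.Dense(100, name="Dense1")
        self.linear2 = linen.Dense(len(unique_classes), name="Dense2")

    def __call__(self, X_batch):
        x = self.embedding(X_batch)      ## (batch, 50) -> (batch, 50, 10)
        x = x.reshape(len(X_batch), -1)  ## flatten -> (batch, 500)

        x = self.linear1(x)              ## (batch, 500) -> (batch, 100)
        x = linen.relu(x)

        logits = self.linear2(x)         ## (batch, 100) -> (batch, 4), one logit per class
        return logits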
In [10]:
from jax import numpy as jnp

seed = jax.random.PRNGKey(0)

embed_classif = EmbeddingClassifier()

params = embed_classif.init(seed, jax.random.randint(seed, (100, 50), minval=1, maxval=20))

for layer_params in params["params"].items():
    print("Layer Name : {}".format(layer_params[0]))
    if "Embedding" in layer_params[0]:
        weights = layer_params[1]["embedding"]
        print("\tLayer Weights : {}".format(weights.shape))
    else:
        weights, biases = layer_params[1]["kernel"], layer_params[1]["bias"]
        print("\tLayer Weights : {}, Biases : {}".format(weights.shape, biases.shape))
Layer Name : Word Embeddings
	Layer Weights : (72003, 10)
Layer Name : Dense1
	Layer Weights : (500, 100), Biases : (100,)
Layer Name : Dense2
	Layer Weights : (100, 4), Biases : (4,)
In [11]:
preds = embed_classif.apply(params, X_train_vect[:5])

preds
Out[11]:
DeviceArray([[-0.19533372,  0.23670919,  0.5155503 , -0.03628911],
             [-0.04272176,  0.20425099,  0.90916896, -0.21494554],
             [-0.02283784,  0.15638712,  0.5786724 , -0.19889665],
             [-0.06193038,  0.36467755,  0.04414063,  0.01532627],
             [-0.09400441,  0.12597957,  0.18625824, -0.10882761]],            dtype=float32)

2.2 Define Loss

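Below, we have defined the loss function whose output we'll try to minimize. We have used cross-entropy loss. The function takes network parameters, input data, and actual target labels as input. It first makes predictions (logits) on the input data using the network parameters and one-hot encodes the actual targets. Then, it calculates cross-entropy using the softmax_cross_entropy() function available from the Optax library, which returns one loss value per example. We sum these values over the batch rather than averaging them, which is why the loss values printed during training are large (e.g., 1255.593 after the first epoch with batches of 1024 examples). Note that the function uses the global model variable, which is created in the training cell below.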

In [12]:
def CrossEntropyLoss(params, input_data, actual):
    logits_preds = model.apply(params, input_data)
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(unique_classes))
    return optax.softmax_cross_entropy(logits=logits_preds, labels=one_hot_actual).sum()

2.3 Train Network

Now, we'll train our network. To do so, we have defined a simple function that takes the train data (X, Y), validation data (X_val, Y_val), number of epochs, network parameters, optimizer state, and batch size as input (the optimizer and the model themselves are read from the global variables optimizer and model). The function runs the training loop for the given number of epochs. During each epoch, it loops through the training data in batches. For each batch, it performs a forward pass to make predictions, calculates the loss and gradients, and updates the network parameters using the gradients. The function records the loss of each batch and prints the average loss at the end of each epoch. It also calculates and prints the validation accuracy at the end of each epoch. Finally, the function returns the updated network parameters.

In [13]:
from jax import value_and_grad
from tqdm import tqdm
from sklearn.metrics import accuracy_score

def TrainModelInBatches(X, Y, X_val, Y_val, epochs, params, optimizer_state, batch_size=32):
    for i in range(1, epochs+1):
        batches = jnp.arange((X.shape[0]//batch_size)+1) ### Batch Indices

        losses = [] ## Record loss of each batch
        for batch in tqdm(batches):
            if batch != batches[-1]:
                start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
            else:
                start, end = int(batch*batch_size), None

            X_batch, Y_batch = X[start:end], Y[start:end] ## Single batch of data

            loss, gradients = value_and_grad(CrossEntropyLoss)(params, X_batch,Y_batch)

            ## Update Network Parameters
            updates, optimizer_state = optimizer.update(gradients, optimizer_state)
            params = optax.apply_updates(params, updates)

            losses.append(loss) ## Record Loss

        print("CrossEntropyLoss : {:.3f}".format(jnp.array(losses).mean()))

        Y_val_preds = model.apply(params, X_val)
        val_acc = accuracy_score(Y_val, jnp.argmax(Y_val_preds, axis=1))
        print("Validation  Accuracy : {:.3f}".format(val_acc))

    return params

Below, we have called our training routine to perform training. We have first set the batch size to 1024, the number of epochs to 15, and the learning rate to 0.001. Then, we have initialized our network and the Adam optimizer. At last, we have called our training routine with the necessary parameters. The training loss falls steadily, and validation accuracy rises from 0.623 after the first epoch to 0.898 after the fifteenth, so the model is learning to classify the text documents well.

In [14]:
from jax import random

seed = random.PRNGKey(0)
batch_size=1024
epochs=15
learning_rate = jnp.array(1e-3)

model = EmbeddingClassifier()
params = model.init(seed, X_train_vect[:5])

optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(params)

final_weights = TrainModelInBatches(X_train_vect, Y_train, X_test_vect, Y_test, epochs, params, optimizer_state, batch_size=batch_size)
100%|██████████| 118/118 [00:23<00:00,  5.09it/s]
CrossEntropyLoss : 1255.593
Validation  Accuracy : 0.623
100%|██████████| 118/118 [00:18<00:00,  6.40it/s]
CrossEntropyLoss : 698.493
Validation  Accuracy : 0.795
100%|██████████| 118/118 [00:19<00:00,  6.15it/s]
CrossEntropyLoss : 456.550
Validation  Accuracy : 0.841
100%|██████████| 118/118 [00:18<00:00,  6.45it/s]
CrossEntropyLoss : 358.194
Validation  Accuracy : 0.862
100%|██████████| 118/118 [00:18<00:00,  6.34it/s]
CrossEntropyLoss : 300.330
Validation  Accuracy : 0.873
100%|██████████| 118/118 [00:19<00:00,  6.11it/s]
CrossEntropyLoss : 259.826
Validation  Accuracy : 0.880
100%|██████████| 118/118 [00:18<00:00,  6.39it/s]
CrossEntropyLoss : 228.544
Validation  Accuracy : 0.886
100%|██████████| 118/118 [00:18<00:00,  6.46it/s]
CrossEntropyLoss : 202.844
Validation  Accuracy : 0.890
100%|██████████| 118/118 [00:19<00:00,  6.08it/s]
CrossEntropyLoss : 180.826
Validation  Accuracy : 0.894
100%|██████████| 118/118 [00:19<00:00,  6.21it/s]
CrossEntropyLoss : 161.414
Validation  Accuracy : 0.896
100%|██████████| 118/118 [00:18<00:00,  6.27it/s]
CrossEntropyLoss : 143.937
Validation  Accuracy : 0.897
100%|██████████| 118/118 [00:19<00:00,  6.09it/s]
CrossEntropyLoss : 128.089
Validation  Accuracy : 0.897
100%|██████████| 118/118 [00:19<00:00,  6.18it/s]
CrossEntropyLoss : 113.588
Validation  Accuracy : 0.897
100%|██████████| 118/118 [00:18<00:00,  6.33it/s]
CrossEntropyLoss : 100.272
Validation  Accuracy : 0.897
100%|██████████| 118/118 [00:19<00:00,  6.20it/s]
CrossEntropyLoss : 88.142
Validation  Accuracy : 0.898

2.4 Evaluate Network Performance

Here, we have evaluated the performance of our network by calculating accuracy, a classification report (precision, recall, and f1-score per class), and a confusion matrix on test predictions. The test accuracy is 0.898, with every class at an f1-score of 0.86 or higher. The train accuracy of 0.978 is noticeably higher, which suggests some overfitting.

We have calculated these metrics using functions available from scikit-learn. Please feel free to check the below link if you are interested in learning about the various metrics available from sklearn in detail.

In the cell after the one below, we have also plotted the confusion matrix using a function available from the scikit-plot Python library. We can notice from the plot that our model does better for the Sports and World categories than for Business and Sci/Tech, which are most often confused with each other (173 Business documents predicted as Sci/Tech and 149 Sci/Tech documents predicted as Business). Please feel free to check the below link if you want to learn about the scikit-plot library, which provides visualizations for many ML metrics.

In [15]:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

train_preds = model.apply(final_weights, X_train_vect)
test_preds = model.apply(final_weights, X_test_vect)

print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, np.argmax(train_preds, axis=1))))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=target_classes))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_test, np.argmax(test_preds, axis=1)))
Train Accuracy : 0.978
Test  Accuracy : 0.898

Classification Report :
              precision    recall  f1-score   support

       World       0.90      0.90      0.90      1900
      Sports       0.96      0.95      0.96      1900
    Business       0.87      0.86      0.86      1900
    Sci/Tech       0.87      0.87      0.87      1900

    accuracy                           0.90      7600
   macro avg       0.90      0.90      0.90      7600
weighted avg       0.90      0.90      0.90      7600


Confusion Matrix :
[[1715   38   84   63]
 [  54 1813   17   16]
 [  73   20 1634  173]
 [  68   23  149 1660]]
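As a sanity check on the report above, accuracy, per-class recall, and per-class precision can all be recomputed from the printed confusion matrix alone (rows are actual classes, columns are predicted classes). The row-normalized values on the diagonal are also what the normalized confusion-matrix plot displays.

```python
import numpy as np

# Confusion matrix printed above (rows = actual class, columns = predicted class).
cm = np.array([[1715,   38,   84,   63],
               [  54, 1813,   17,   16],
               [  73,   20, 1634,  173],
               [  68,   23,  149, 1660]])

accuracy = np.trace(cm) / cm.sum()          # correct predictions / all predictions
recall = np.diag(cm) / cm.sum(axis=1)       # per actual class (row-normalized diagonal)
precision = np.diag(cm) / cm.sum(axis=0)    # per predicted class
print(round(float(accuracy), 3))            # 0.898
print(np.round(recall, 2))
print(np.round(precision, 2))
```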
In [ ]:
from sklearn.metrics import confusion_matrix
import scikitplot as skplt
import matplotlib.pyplot as plt

skplt.metrics.plot_confusion_matrix([target_classes[i] for i in Y_test], [target_classes[i] for i in np.argmax(test_preds, axis=1)],
                                    normalize=True,
                                    title="Confusion Matrix",
                                    cmap="Purples",
                                    hide_zeros=True,
                                    figsize=(5,5)
                                    );
plt.xticks(rotation=90);

2.5 Explain Network Predictions Using LIME

In this section, we have tried to explain the predictions made by the network using the LIME algorithm. The lime Python library provides an implementation of LIME. To use it, we first need to create an instance of LimeTextExplainer and call its explain_instance() method to generate an Explanation object. The Explanation object has details about which words contributed to predicting a particular category. We can call the show_in_notebook() method on the Explanation object to generate a visualization showing each word's contribution to the prediction. Please feel free to check the below tutorial if you do not have a background in LIME and are interested in learning it.

Below, we have created an instance of LimeTextExplainer by giving it the target class names.

In [17]:
from lime import lime_text

explainer = lime_text.LimeTextExplainer(class_names=target_classes, verbose=True)

Below, we have designed a simple function that takes a list of text documents and returns predictions for them. The function converts the documents to padded lists of indexes using the tokenizer fitted earlier, gives these indexes to the network to make predictions, and converts the network's output (logits) to probabilities using the softmax function before returning them. This is the format LIME expects from a classifier function: an array of shape (number of documents, number of classes) containing class probabilities.
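The softmax step is what turns raw logits into the probability array LIME expects. The sketch below applies jax.nn.softmax (linen.softmax in the tutorial's code is the same softmax function exposed through Flax) to made-up logits for two hypothetical documents, and uses np.asarray to obtain a NumPy array. Each row sums to 1, and softmax never changes which class has the highest score.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Made-up logits for 2 hypothetical documents over the 4 AG NEWS classes.
logits = jnp.array([[0.3, -1.2, 2.1, 0.0],
                    [1.5,  0.2, -0.7, 0.4]])
probs = np.asarray(jax.nn.softmax(logits, axis=-1))  # shape (documents, classes)

print(probs.shape)          # (2, 4)
print(probs.sum(axis=1))    # each row sums to 1 (up to float rounding)
```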

After defining the function, we have randomly selected a text sample from the test data and made a prediction on it using our model. We have printed the actual and predicted labels for the selected sample. The actual label is Sci/Tech, and our model predicted the same.

In [18]:
import numpy as np

def make_predictions(X_batch_text):
    X_batch = pad_sequences(tokenizer.texts_to_sequences(X_batch_text), maxlen=50, padding="post", truncating="post", value=0.)
    logits = model.apply(final_weights, jnp.array(X_batch))
    preds = linen.softmax(logits)
    return preds.to_py()

rng = np.random.RandomState(1234)
idx = rng.randint(1, len(X_test_text))

print("Prediction : ", target_classes[model.apply(final_weights, X_test_vect[idx:idx+1]).argmax(axis=-1)[0]])
print("Actual :     ", target_classes[Y_test[idx]])
Prediction :  Sci/Tech
Actual :      Sci/Tech

Below, we have called the explain_instance() method on the explainer object with the selected sample, the classifier function (the one defined in the previous cell), and the label that we want explained (the actual label). The method returns an Explanation object. Then, we have called the show_in_notebook() method on the Explanation instance to create a visualization.

We can notice from the visualization that words like 'RFID', 'frequency', 'identification', 'technology', 'threat', 'privacy', etc., contribute to predicting the category as Sci/Tech.

In [ ]:
explanation = explainer.explain_instance(X_test_text[idx], classifier_fn=make_predictions, labels=Y_test[idx:idx+1].to_py())
explanation.show_in_notebook()

3. Approach 2 - Word Embeddings With More Values

In this section, we have tried a second approach that is exactly the same as the previous one except for the embedding size: we have used an embedding length of 20 per word/token instead of 10. The majority of the code in this section is therefore the same as in the previous section.

3.1 Define Network

Below, we have defined the network that we'll use for classification. It has exactly the same structure as the previous network, with the only difference being the embedding length given when creating the embedding layer, which we have set to 20 this time. Consequently, the flattened input to the first dense layer now has 50 * 20 = 1000 values instead of 500.
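Doubling the embedding length roughly doubles the size of the network, because it doubles both the embedding table and the first dense layer's kernel. The arithmetic below uses the vocabulary size implied by the printed embedding shape (72003, 10); `count_params` is a hypothetical helper written for this illustration, and its numbers follow directly from the layer shapes rather than from any run.

```python
def count_params(vocab_rows, embed_dim, seq_len=50, hidden=100, n_classes=4):
    """Parameter count of the flattened Embed -> Dense(100) -> Dense(4) classifier."""
    embed = vocab_rows * embed_dim                            # embedding table
    dense1 = (seq_len * embed_dim) * hidden + hidden          # kernel + bias on flattened input
    dense2 = hidden * n_classes + n_classes                   # kernel + bias
    return embed + dense1 + dense2

VOCAB_ROWS = 72003   # from the printed embedding shape (72003, 10)
print(count_params(VOCAB_ROWS, 10))   # 770534  (approach 1)
print(count_params(VOCAB_ROWS, 20))   # 1540564 (this approach)
```

In both cases the embedding table accounts for over 90% of the parameters.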

In [20]:
from flax import linen

class EmbeddingClassifier(linen.Module):
    def setup(self):
        self.embedding = linen.Embed(len(tokenizer.word_index)+1, 20, name="Word Embeddings") ## Word embeddings size increased
        self.linear1 = linen.Dense(100, name="Dense1")
        self.linear2 = linen.Dense(len(unique_classes), name="Dense2")

    def __call__(self, X_batch):
        x = self.embedding(X_batch)
        x = x.reshape(len(X_batch), -1)

        x = self.linear1(x)
        x = linen.relu(x)

        logits = self.linear2(x)
        return logits

3.2 Train Network

Now, we have trained our new network with longer embeddings. We have set the batch size to 1024, the number of epochs to 15, and the learning rate to 0.001. Then, we have initialized our network and the Adam optimizer, and called our training routine. Validation accuracy climbs faster than in the first approach (0.865 after the second epoch versus 0.795) and peaks at 0.905 after the tenth epoch, then drifts down to 0.896 while the training loss keeps falling, a sign that the network starts to overfit.

In [21]:
from jax import random

seed = random.PRNGKey(0)
batch_size=1024
epochs=15
learning_rate = jnp.array(1e-3)

model = EmbeddingClassifier()
params = model.init(seed, X_train_vect[:5])

optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(params)

final_weights = TrainModelInBatches(X_train_vect, Y_train, X_test_vect, Y_test, epochs, params, optimizer_state, batch_size=batch_size)
100%|██████████| 118/118 [00:17<00:00,  6.83it/s]
CrossEntropyLoss : 1031.812
Validation  Accuracy : 0.792
100%|██████████| 118/118 [00:16<00:00,  7.01it/s]
CrossEntropyLoss : 427.626
Validation  Accuracy : 0.865
100%|██████████| 118/118 [00:16<00:00,  6.97it/s]
CrossEntropyLoss : 299.141
Validation  Accuracy : 0.886
100%|██████████| 118/118 [00:16<00:00,  7.00it/s]
CrossEntropyLoss : 236.821
Validation  Accuracy : 0.895
100%|██████████| 118/118 [00:16<00:00,  7.19it/s]
CrossEntropyLoss : 193.992
Validation  Accuracy : 0.899
100%|██████████| 118/118 [00:16<00:00,  7.11it/s]
CrossEntropyLoss : 160.696
Validation  Accuracy : 0.900
100%|██████████| 118/118 [00:17<00:00,  6.89it/s]
CrossEntropyLoss : 133.005
Validation  Accuracy : 0.900
100%|██████████| 118/118 [00:16<00:00,  7.12it/s]
CrossEntropyLoss : 109.176
Validation  Accuracy : 0.901
100%|██████████| 118/118 [00:16<00:00,  7.13it/s]
CrossEntropyLoss : 88.537
Validation  Accuracy : 0.903
100%|██████████| 118/118 [00:17<00:00,  6.75it/s]
CrossEntropyLoss : 70.895
Validation  Accuracy : 0.905
100%|██████████| 118/118 [00:17<00:00,  6.93it/s]
CrossEntropyLoss : 56.203
Validation  Accuracy : 0.904
100%|██████████| 118/118 [00:16<00:00,  6.98it/s]
CrossEntropyLoss : 44.988
Validation  Accuracy : 0.900
100%|██████████| 118/118 [00:17<00:00,  6.89it/s]
CrossEntropyLoss : 37.415
Validation  Accuracy : 0.895
100%|██████████| 118/118 [00:18<00:00,  6.54it/s]
CrossEntropyLoss : 33.463
Validation  Accuracy : 0.894
100%|██████████| 118/118 [00:17<00:00,  6.79it/s]
CrossEntropyLoss : 30.530
Validation  Accuracy : 0.896

3.3 Evaluate Network Performance

Here, we have calculated accuracy, a classification report, and a confusion matrix for test predictions. The test accuracy (0.896) is almost the same as in the previous section (0.898), so increasing the embedding length has not improved accuracy. Train accuracy, however, rose from 0.978 to 0.996, which suggests that the larger network overfits the training data more. The other metrics are also almost the same, with little to no improvement.

In [22]:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

train_preds = model.apply(final_weights, X_train_vect)
test_preds = model.apply(final_weights, X_test_vect)

print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, np.argmax(train_preds, axis=1))))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=target_classes))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_test, np.argmax(test_preds, axis=1)))
Train Accuracy : 0.996
Test  Accuracy : 0.896

Classification Report :
              precision    recall  f1-score   support

       World       0.88      0.91      0.89      1900
      Sports       0.95      0.96      0.96      1900
    Business       0.87      0.85      0.86      1900
    Sci/Tech       0.88      0.86      0.87      1900

    accuracy                           0.90      7600
   macro avg       0.90      0.90      0.90      7600
weighted avg       0.90      0.90      0.90      7600


Confusion Matrix :
[[1730   48   67   55]
 [  46 1824   19   11]
 [ 109   18 1613  160]
 [  83   25  152 1640]]
In [ ]:
from sklearn.metrics import confusion_matrix
import scikitplot as skplt
import matplotlib.pyplot as plt

skplt.metrics.plot_confusion_matrix([target_classes[i] for i in Y_test], [target_classes[i] for i in np.argmax(test_preds, axis=1)],
                                    normalize=True,
                                    title="Confusion Matrix",
                                    cmap="Purples",
                                    hide_zeros=True,
                                    figsize=(5,5)
                                    );
plt.xticks(rotation=90);

3.4 Explain Network Predictions Using LIME

Below, we have again tried to explain the prediction made by the network. Because we use the same random seed (1234), the same test example as in the previous section is selected. It has the Sci/Tech label, and the network predicted the same. We can notice from the visualization that this time a few words ('frequency', 'identification', 'threat', etc.) that were contributing to the prediction in the previous section are no longer contributing.

In [ ]:
from lime import lime_text

explainer = lime_text.LimeTextExplainer(class_names=target_classes)

rng = np.random.RandomState(1234)
idx = rng.randint(1, len(X_test_text))

print("Prediction : ", target_classes[model.apply(final_weights, X_test_vect[idx:idx+1]).argmax(axis=-1)[0]])
print("Actual :     ", target_classes[Y_test[idx]])

explanation = explainer.explain_instance(X_test_text[idx], classifier_fn=make_predictions, labels=Y_test[idx:idx+1].to_py())
explanation.show_in_notebook()

4. Approach 3 - Word Embeddings Averaged

In this section, we have tried another approach involving word embeddings. We have used an embedding length of 20, as in the previous section. The only difference is that, whereas until now we flattened the embeddings, this time we average the embeddings of all words/tokens in each text example. The averaged embedding is then given to the dense layers.

4.1 Define Network

Our network for this section is the same as the network from the previous section except for the forward pass. Here, we have averaged the output of the embedding layer over the word/token axis before giving it to the first dense layer, so each text example is represented by the average of its word embeddings: an output of shape (batch_size, 50, 20) becomes (batch_size, 20). As we keep 50 tokens/words per text example, the network averages the embeddings of all 50 positions, including the padding positions (index 0) of shorter documents.
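The sketch below shows the effect of `mean(axis=1)` on a tiny hypothetical batch (1 document, 4 positions, 3-dimensional embeddings, the last two positions being padding). For comparison it also computes a masked mean that averages only the non-padding positions; this masked variant is an alternative that is not used in this tutorial.

```python
import jax.numpy as jnp

# Hypothetical embedded batch: 1 document, 4 positions, 3-dim embeddings.
# The last two positions are padding (token index 0).
token_ids = jnp.array([[12, 7, 0, 0]])
embedded = jnp.array([[[1.0, 2.0, 3.0],
                       [3.0, 2.0, 1.0],
                       [0.5, 0.5, 0.5],     # embedding row of index 0 (padding)
                       [0.5, 0.5, 0.5]]])

# What the tutorial's __call__ does: average over all positions, (1, 4, 3) -> (1, 3)
plain_mean = embedded.mean(axis=1)           # 1.25 in every dimension

# Alternative (NOT used in this tutorial): average only the non-padding positions.
mask = (token_ids != 0)[..., None]           # (1, 4, 1), True for real tokens
masked_mean = (embedded * mask).sum(axis=1) / mask.sum(axis=1)   # 2.0 in every dimension

print(plain_mean.shape, masked_mean.shape)
```

With plain averaging, a short document's representation is pulled toward the padding embedding; because that embedding is learned like any other row, the network can partly compensate for it.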

In [25]:
from flax import linen

class EmbeddingClassifier(linen.Module):
    def setup(self):
        self.embedding = linen.Embed(len(tokenizer.word_index)+1, 20, name="Word Embeddings")
        self.linear1 = linen.Dense(100, name="Dense1")
        self.linear2 = linen.Dense(len(unique_classes), name="Dense2")

    def __call__(self, X_batch):
        x = self.embedding(X_batch)
        x = x.mean(axis=1) ## Average the 50 word embeddings of each example: (batch, 50, 20) -> (batch, 20)

        x = self.linear1(x)
        x = linen.relu(x)

        logits = self.linear2(x)
        return logits

4.2 Train Network

Below, we have trained our network using exactly the same settings as in the previous sections. Validation accuracy reaches 0.916 (after epochs 11 and 12) and ends at 0.913, higher than both previous approaches.

In [26]:
from jax import random

seed = random.PRNGKey(0)
batch_size=1024
epochs=15
learning_rate = jnp.array(1e-3)

model = EmbeddingClassifier()
params = model.init(seed, X_train_vect[:5])

optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(params)

final_weights = TrainModelInBatches(X_train_vect, Y_train, X_test_vect, Y_test, epochs, params, optimizer_state, batch_size=batch_size)
100%|██████████| 118/118 [00:16<00:00,  7.07it/s]
CrossEntropyLoss : 1315.283
Validation  Accuracy : 0.774
100%|██████████| 118/118 [00:16<00:00,  7.07it/s]
CrossEntropyLoss : 651.022
Validation  Accuracy : 0.867
100%|██████████| 118/118 [00:15<00:00,  7.45it/s]
CrossEntropyLoss : 353.634
Validation  Accuracy : 0.890
100%|██████████| 118/118 [00:16<00:00,  7.26it/s]
CrossEntropyLoss : 282.815
Validation  Accuracy : 0.902
100%|██████████| 118/118 [00:17<00:00,  6.94it/s]
CrossEntropyLoss : 243.787
Validation  Accuracy : 0.908
100%|██████████| 118/118 [00:16<00:00,  6.96it/s]
CrossEntropyLoss : 215.755
Validation  Accuracy : 0.912
100%|██████████| 118/118 [00:16<00:00,  7.14it/s]
CrossEntropyLoss : 193.397
Validation  Accuracy : 0.913
100%|██████████| 118/118 [00:16<00:00,  7.22it/s]
CrossEntropyLoss : 174.637
Validation  Accuracy : 0.913
100%|██████████| 118/118 [00:16<00:00,  7.00it/s]
CrossEntropyLoss : 158.385
Validation  Accuracy : 0.914
100%|██████████| 118/118 [00:16<00:00,  7.27it/s]
CrossEntropyLoss : 143.973
Validation  Accuracy : 0.915
100%|██████████| 118/118 [00:16<00:00,  7.22it/s]
CrossEntropyLoss : 131.030
Validation  Accuracy : 0.916
100%|██████████| 118/118 [00:16<00:00,  6.95it/s]
CrossEntropyLoss : 119.342
Validation  Accuracy : 0.916
100%|██████████| 118/118 [00:16<00:00,  7.19it/s]
CrossEntropyLoss : 108.748
Validation  Accuracy : 0.915
100%|██████████| 118/118 [00:16<00:00,  7.32it/s]
CrossEntropyLoss : 99.136
Validation  Accuracy : 0.914
100%|██████████| 118/118 [00:16<00:00,  7.24it/s]
CrossEntropyLoss : 90.404
Validation  Accuracy : 0.913

4.3 Evaluate Network Performance

Below, we calculate accuracy, the classification report and the confusion matrix on test predictions as usual. The test accuracy of 0.913 is better than both of our previous approaches, and the model classifies Business and Sci/Tech documents more accurately than before. These two classes are still the ones most often confused with each other: 125 Business articles are predicted as Sci/Tech and 146 Sci/Tech articles as Business.

In [27]:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

train_preds = model.apply(final_weights, X_train_vect)
test_preds = model.apply(final_weights, X_test_vect)

print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, np.argmax(train_preds, axis=1))))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=target_classes))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_test, np.argmax(test_preds, axis=1)))
Train Accuracy : 0.976
Test  Accuracy : 0.913

Classification Report :
              precision    recall  f1-score   support

       World       0.91      0.91      0.91      1900
      Sports       0.96      0.97      0.97      1900
    Business       0.88      0.89      0.88      1900
    Sci/Tech       0.90      0.88      0.89      1900

    accuracy                           0.91      7600
   macro avg       0.91      0.91      0.91      7600
weighted avg       0.91      0.91      0.91      7600


Confusion Matrix :
[[1735   45   69   51]
 [  28 1845   19    8]
 [  70   16 1689  125]
 [  67   17  146 1670]]
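The classification report can be reproduced directly from the confusion matrix above, where rows are true classes and columns are predicted classes in the order World, Sports, Business, Sci/Tech. The sketch below uses only NumPy and the printed counts: accuracy comes from the diagonal, recall from each row and precision from each column, and the rounded values match the report.

```python
import numpy as np

# Confusion matrix printed above for Approach 3 (rows = true class, columns = predicted class).
cm = np.array([[1735,   45,   69,   51],
               [  28, 1845,   19,    8],
               [  70,   16, 1689,  125],
               [  67,   17,  146, 1670]])
classes = ["World", "Sports", "Business", "Sci/Tech"]

accuracy = np.trace(cm) / cm.sum()        # correct predictions / all predictions
recall = np.diag(cm) / cm.sum(axis=1)     # per true class (row)
precision = np.diag(cm) / cm.sum(axis=0)  # per predicted class (column)

print(f"accuracy: {accuracy:.3f}")
for name, p, r in zip(classes, precision, recall):
    print(f"{name:>8}  precision={p:.2f}  recall={r:.2f}")
```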
In [ ]:
from sklearn.metrics import confusion_matrix
import scikitplot as skplt
import matplotlib.pyplot as plt

skplt.metrics.plot_confusion_matrix([target_classes[i] for i in Y_test], [target_classes[i] for i in np.argmax(test_preds, axis=1)],
                                    normalize=True,
                                    title="Confusion Matrix",
                                    cmap="Purples",
                                    hide_zeros=True,
                                    figsize=(5,5)
                                    );
plt.xticks(rotation=90);

4.4 Explain Network Predictions Using LIME

Below, we explain one test prediction using LIME. The visualization shows that words like 'privacy', 'RFID', 'frequency' and 'technology' contribute to predicting the category Sci/Tech.

In [ ]:
from lime import lime_text

explainer = lime_text.LimeTextExplainer(class_names=target_classes)

rng = np.random.RandomState(1234)
idx = rng.randint(1, len(X_test_text))

print("Prediction : ", target_classes[model.apply(final_weights, X_test_vect[idx:idx+1]).argmax(axis=-1)[0]])
print("Actual :     ", target_classes[Y_test[idx]])

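# np.asarray converts the JAX label array to NumPy (.to_py() is deprecated in newer JAX releases)
explanation = explainer.explain_instance(X_test_text[idx], classifier_fn=make_predictions, labels=np.asarray(Y_test[idx:idx+1]))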
explanation.show_in_notebook()

5. Approach 4 - Word Embeddings Summed

Most of the code in this section is the same as in the previous approach. The only difference is that this time we sum the embeddings of each text example instead of averaging them.
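Because every text example is padded or truncated to exactly 50 tokens, the summed embedding is always 50 times the averaged one, as the check below on random dummy embeddings confirms. The first dense layer can absorb that constant factor into its weights (W · (50 · mean) = (50 · W) · mean), so both networks can represent the same functions; any difference in results likely comes from optimization, such as the larger input scale and gradients seen by the first dense layer, rather than from model capacity.

```python
import jax
import jax.numpy as jnp

# Dummy batch of embeddings: 3 texts, 50 tokens each, 20 dims.
embeddings = jax.random.normal(jax.random.PRNGKey(0), (3, 50, 20))

summed = embeddings.sum(axis=1)     # Approach 4
averaged = embeddings.mean(axis=1)  # Approach 3

# With a fixed length of 50 tokens, the sum equals 50x the mean (up to float rounding).
print(jnp.allclose(summed, 50 * averaged, atol=1e-4))
```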

5.1 Define Network

Below, we define a network with almost the same structure as the one from the previous approach. The only difference is in the forward pass: this time, we sum the embeddings of each text example instead of averaging them.

In [30]:
from flax import linen

class EmbeddingClassifier(linen.Module):
    def setup(self):
        self.embedding = linen.Embed(len(tokenizer.word_index)+1, 20, name="Word Embeddings")
        self.linear1 = linen.Dense(100, name="Dense1")
        self.linear2 = linen.Dense(len(unique_classes), name="Dense2")

    def __call__(self, X_batch):
        x = self.embedding(X_batch)
        x = x.sum(axis=1)  # Sum the token embeddings of each example: (batch, 50, 20) -> (batch, 20)

        x = self.linear1(x)
        x = linen.relu(x)

        logits = self.linear2(x)
        return logits

5.2 Train Network

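Below, we train the network with the same settings used for all our previous approaches. The loss and accuracy printed after each epoch show that the model does a good job on the text classification task. However, validation accuracy peaks at 0.905 at epochs 9-10 and then falls back to 0.897 while the training loss keeps decreasing, which suggests that this model starts to overfit earlier than the averaged-embedding model.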

In [31]:
from jax import random

seed = random.PRNGKey(0)
batch_size=1024
epochs=15
learning_rate = jnp.array(1e-3)

model = EmbeddingClassifier()
params = model.init(seed, X_train_vect[:5])

optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(params)

final_weights = TrainModelInBatches(X_train_vect, Y_train, X_test_vect, Y_test, epochs, params, optimizer_state, batch_size=batch_size)
100%|██████████| 118/118 [00:16<00:00,  7.12it/s]
CrossEntropyLoss : 1156.244
Validation  Accuracy : 0.756
100%|██████████| 118/118 [00:16<00:00,  7.21it/s]
CrossEntropyLoss : 508.957
Validation  Accuracy : 0.852
100%|██████████| 118/118 [00:16<00:00,  7.22it/s]
CrossEntropyLoss : 348.453
Validation  Accuracy : 0.880
100%|██████████| 118/118 [00:16<00:00,  7.03it/s]
CrossEntropyLoss : 276.214
Validation  Accuracy : 0.892
100%|██████████| 118/118 [00:15<00:00,  7.38it/s]
CrossEntropyLoss : 229.587
Validation  Accuracy : 0.897
100%|██████████| 118/118 [00:16<00:00,  7.30it/s]
CrossEntropyLoss : 194.919
Validation  Accuracy : 0.900
100%|██████████| 118/118 [00:15<00:00,  7.45it/s]
CrossEntropyLoss : 167.057
Validation  Accuracy : 0.903
100%|██████████| 118/118 [00:16<00:00,  6.97it/s]
CrossEntropyLoss : 143.562
Validation  Accuracy : 0.904
100%|██████████| 118/118 [00:16<00:00,  7.36it/s]
CrossEntropyLoss : 123.298
Validation  Accuracy : 0.905
100%|██████████| 118/118 [00:16<00:00,  7.35it/s]
CrossEntropyLoss : 105.800
Validation  Accuracy : 0.905
100%|██████████| 118/118 [00:16<00:00,  7.09it/s]
CrossEntropyLoss : 90.706
Validation  Accuracy : 0.904
100%|██████████| 118/118 [00:16<00:00,  7.34it/s]
CrossEntropyLoss : 77.586
Validation  Accuracy : 0.904
100%|██████████| 118/118 [00:16<00:00,  7.36it/s]
CrossEntropyLoss : 66.281
Validation  Accuracy : 0.902
100%|██████████| 118/118 [00:15<00:00,  7.41it/s]
CrossEntropyLoss : 56.605
Validation  Accuracy : 0.900
100%|██████████| 118/118 [00:16<00:00,  7.01it/s]
CrossEntropyLoss : 48.390
Validation  Accuracy : 0.897
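Both training logs show validation accuracy peaking before the final epoch, while the tutorial keeps the weights from the last epoch. A simple form of checkpoint selection (early stopping by best validation accuracy) would keep the weights from the best epoch instead. The sketch below finds that epoch using only the validation accuracies copied from the two logs above; `best_epoch` is a hypothetical helper, and in practice you would save the parameters at that epoch inside the training loop. Note that the tutorial passes the test set as validation data, so choosing an epoch this way would bias the reported test accuracy; a separate validation split is preferable.

```python
# Validation accuracies copied from the two training logs above.
approach3_val_acc = [0.774, 0.867, 0.890, 0.902, 0.908, 0.912, 0.913, 0.913,
                     0.914, 0.915, 0.916, 0.916, 0.915, 0.914, 0.913]
approach4_val_acc = [0.756, 0.852, 0.880, 0.892, 0.897, 0.900, 0.903, 0.904,
                     0.905, 0.905, 0.904, 0.904, 0.902, 0.900, 0.897]


def best_epoch(val_acc):
    """Return (1-based epoch, accuracy) of the first epoch with the highest accuracy."""
    best_idx = max(range(len(val_acc)), key=lambda i: val_acc[i])
    return best_idx + 1, val_acc[best_idx]


print(best_epoch(approach3_val_acc))  # (11, 0.916)
print(best_epoch(approach4_val_acc))  # (9, 0.905)
```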

5.3 Evaluate Network Performance

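Below, we evaluate network performance as usual by calculating accuracy, the classification report and the confusion matrix on test predictions. The test accuracy of 0.897 is good, though slightly lower than the 0.913 achieved with averaged embeddings, and the gap between train accuracy (0.989) and test accuracy is wider.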

In [32]:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

train_preds = model.apply(final_weights, X_train_vect)
test_preds = model.apply(final_weights, X_test_vect)

print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, np.argmax(train_preds, axis=1))))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=target_classes))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_test, np.argmax(test_preds, axis=1)))
Train Accuracy : 0.989
Test  Accuracy : 0.897

Classification Report :
              precision    recall  f1-score   support

       World       0.89      0.90      0.90      1900
      Sports       0.95      0.95      0.95      1900
    Business       0.86      0.87      0.87      1900
    Sci/Tech       0.88      0.87      0.87      1900

    accuracy                           0.90      7600
   macro avg       0.90      0.90      0.90      7600
weighted avg       0.90      0.90      0.90      7600


Confusion Matrix :
[[1708   50   78   64]
 [  52 1809   21   18]
 [  88   17 1649  146]
 [  66   24  159 1651]]
In [ ]:
from sklearn.metrics import confusion_matrix
import scikitplot as skplt
import matplotlib.pyplot as plt

skplt.metrics.plot_confusion_matrix([target_classes[i] for i in Y_test], [target_classes[i] for i in np.argmax(test_preds, axis=1)],
                                    normalize=True,
                                    title="Confusion Matrix",
                                    cmap="Purples",
                                    hide_zeros=True,
                                    figsize=(5,5)
                                    );
plt.xticks(rotation=90);
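With `normalize=True`, scikit-plot divides each row of the confusion matrix by the number of true examples in that class (and rounds to two decimals) before plotting, so the diagonal shows per-class recall. The NumPy sketch below applies the same row normalization to the Approach 4 counts printed above; its diagonal matches the recall column of the classification report (0.90, 0.95, 0.87, 0.87).

```python
import numpy as np

# Approach 4 confusion matrix printed above (rows = true class, columns = predicted class).
cm = np.array([[1708,   50,   78,   64],
               [  52, 1809,   21,   18],
               [  88,   17, 1649,  146],
               [  66,   24,  159, 1651]])

# Row-normalize: divide each row by the number of true examples of that class,
# so each row sums to 1 and the diagonal equals per-class recall.
cm_normalized = cm / cm.sum(axis=1, keepdims=True)
print(np.round(cm_normalized, 2))
```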

5.4 Explain Network Predictions Using LIME

Here, we explain the prediction made by our model on a random test example using LIME. The network correctly predicts the category 'Sci/Tech' for the selected sample, and the visualization shows that words like 'privacy', 'RFID', 'charging', 'technology' and 'threat' contribute to that prediction.

In [ ]:
from lime import lime_text

explainer = lime_text.LimeTextExplainer(class_names=target_classes)

rng = np.random.RandomState(1234)
idx = rng.randint(1, len(X_test_text))

print("Prediction : ", target_classes[model.apply(final_weights, X_test_vect[idx:idx+1]).argmax(axis=-1)[0]])
print("Actual :     ", target_classes[Y_test[idx]])

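# np.asarray converts the JAX label array to NumPy (.to_py() is deprecated in newer JAX releases)
explanation = explainer.explain_instance(X_test_text[idx], classifier_fn=make_predictions, labels=np.asarray(Y_test[idx:idx+1]))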
explanation.show_in_notebook()

This ends our small tutorial explaining how we can use word embeddings for Flax (JAX) text classification networks. Please feel free to let us know your views in the comments section.

Sunny Solanki
