Updated On : Apr-30,2022 Tags flax, jax, conv1d, text-…
Flax (JAX): Conv1D For Text Classification Tasks


Recurrent Neural Networks (RNNs) and their variants like LSTM and GRU are the most commonly used networks for ML tasks that involve sequence data (time-series data, text data, speech data, etc.). RNNs are designed to capture dependencies across the elements of a sequence, but they process a sequence one step at a time, which makes them slow to train. Research has shown that CNNs made of 1D convolution layers can also capture sequential patterns well. Hence, we can use a network of 1D convolutions for NLP tasks like text classification, text generation, etc. The main benefit of CNNs is that a convolution processes all positions of a sequence in parallel, so such networks can have fewer parameters than comparable RNNs and train faster.

As a part of this tutorial, we have designed CNNs with 1D convolution layers using Flax for a text classification task. Flax is a Python deep learning library built on top of JAX for designing deep neural networks. We have used the AG NEWS dataset, which has text documents for 4 different categories of news. We have tried different approaches to using Conv1D layers for text classification. To encode text data before giving it to the convolution layers, we have used the word embeddings approach. After training the networks, we have also evaluated their performance by calculating various ML metrics and explained their predictions using the LIME algorithm.

Below, we have listed the important sections of the tutorial to give an overview of the material covered.

Important Sections Of Tutorial

  1. Prepare Data
    • 1.1 Load Datasets
    • 1.2 Vectorize Text Data
  2. Approach 1: Single Conv1D Layer Network (Max Tokens=50, Embedding Length=128, Conv Output Channels=32)
    • Define Network
    • Define Loss Function
    • Train Network
    • Evaluate Performance Of Network
    • Explain Network Predictions using LIME Algorithm
  3. Approach 2: Multiple Conv1D Layers Network (Max Tokens=50, Embedding Length=128, Conv Output Channels=32,32)
  4. Results Summary And Further Recommendation

Below, we have imported the Python libraries used in this tutorial and printed their versions.

In [1]:
import jax

print("JAX Version : {}".format(jax.__version__))
JAX Version : 0.3.5
In [2]:
import flax

print("FLAX Version : {}".format(flax.__version__))
FLAX Version : 0.4.1
In [3]:
import optax

print("OPTAX Version : {}".format(optax.__version__))
OPTAX Version : 0.1.1
In [4]:
import torchtext

print("Torchtext Version : {}".format(torchtext.__version__))
Torchtext Version : 0.10.1
In [5]:
from tensorflow import keras

print("Keras Version : {}".format(keras.__version__))
Keras Version : 2.6.0
In [6]:
import warnings

warnings.filterwarnings("ignore")

1. Prepare Data

In this section, we prepare the data to be given to the neural network. Our raw data is text, but a neural network works on real-valued data, so we need to convert the text to real-valued data.

In this tutorial, we are going to use the word embeddings approach, where we break text into a list of tokens (typically words) and assign a real-valued vector to each token. We'll achieve this in two steps.

  1. Map each token of text to a unique integer index.
  2. Map integer index to real-valued vector.

We'll complete the first step in this section. The second step is implemented inside the neural network, whose first layer is an Embedding layer responsible for mapping the integer index of each token to its embedding. These embeddings are updated during training so that they better capture the meaning of each token.
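The two steps can be illustrated with a toy example. The vocabulary, tokens, and randomly initialized embedding table below are hypothetical; in the tutorial, the table is a trainable parameter of the Embedding layer with one row per vocabulary index.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy vocabulary; index 0 is left free (used for padding later)
vocab = {"stocks": 1, "rally": 2, "on": 3, "earnings": 4}
tokens = ["stocks", "rally", "on", "earnings"]

# Step 1: map each token to its integer index
indices = jnp.array([vocab[t] for t in tokens])

# Step 2: map each index to a real-valued vector by looking up its row
# in an embedding table of shape (vocab_size + 1, embed_len)
embed_len = 3
table = jax.random.normal(jax.random.PRNGKey(0), (len(vocab) + 1, embed_len))
vectors = table[indices]

print(indices)        # [1 2 3 4]
print(vectors.shape)  # (4, 3): one 3-dimensional vector per token
```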


1.1 Load Datasets

In this section, we have loaded the AG NEWS dataset that we are going to use. It has text documents for 4 different news categories (["World", "Sports", "Business", "Sci/Tech"]) which our trained network will try to classify. The dataset is already divided into a train set (120,000 examples) and a test set (7,600 examples).

In [7]:
import numpy as np

train_dataset, test_dataset = torchtext.datasets.AG_NEWS()

X_train_text, Y_train = [], []
for Y, X in train_dataset:
    X_train_text.append(X)
    Y_train.append(Y)

X_test_text, Y_test = [], []
for Y, X in test_dataset:
    X_test_text.append(X)
    Y_test.append(Y)

unique_classes = list(set(Y_train))
target_classes = ["World", "Sports", "Business", "Sci/Tech"]

## Subtracted 1 from labels to bring range from 1-4 to 0-3
Y_train, Y_test = np.array(Y_train) - 1, np.array(Y_test) - 1

len(X_train_text), len(X_test_text)
train.csv: 29.5MB [00:00, 90.5MB/s]
test.csv: 1.86MB [00:00, 59.9MB/s]
Out[7]:
(120000, 7600)

1.2 Vectorize Text Data

In this section, we are vectorizing our text data.

First, we have initialized the Tokenizer object available from Keras. After initializing the tokenizer, we have fit it by calling its fit_on_texts() method with the train and test text examples. This call populates the vocabulary inside the Tokenizer object with all unique tokens. With the default settings, the tokenizer lowercases text and strips punctuation marks, so the tokens are words. The vocabulary is a simple mapping from a token to a unique integer index, in which more frequent words get smaller indexes, starting from 1. Index 0 is never assigned to a word; here, it is used for padding. Words that are not in the vocabulary are simply dropped by texts_to_sequences(), because no oov_token is set.
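A rough pure-Python approximation of what fit_on_texts() and texts_to_sequences() do with default settings: lowercase the text, strip punctuation, split on whitespace, give index 1 to the most frequent word, and drop words missing from the vocabulary. The helper names are hypothetical and this is not the Keras source; details such as the exact filter characters may differ.

```python
import re
import string
from collections import Counter

# Hypothetical approximation of the Keras Tokenizer defaults (apostrophes are kept)
PUNCT = re.compile("[" + re.escape(string.punctuation.replace("'", "")) + "]")

def simple_tokenize(text):
    # Lowercase, replace punctuation with spaces, split on whitespace
    return PUNCT.sub(" ", text.lower()).split()

def fit_vocab(texts):
    counts = Counter(tok for t in texts for tok in simple_tokenize(t))
    # Most frequent word gets index 1 (ties keep first-seen order); 0 is never assigned
    ranked = sorted(counts, key=counts.get, reverse=True)
    return {word: i for i, word in enumerate(ranked, start=1)}

def to_sequences(texts, vocab):
    # Words missing from the vocabulary are dropped (no OOV token)
    return [[vocab[tok] for tok in simple_tokenize(t) if tok in vocab] for t in texts]

vocab = fit_vocab(["The cat sat.", "The dog sat, then the dog ran!"])
print(vocab)                                     # 'the' -> 1, 'sat' -> 2, 'dog' -> 3, ...
print(to_sequences(["The bird sat."], vocab))    # 'bird' is unknown and dropped
```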

After the vocabulary is populated, we can call the texts_to_sequences() method on the tokenizer object with a list of text documents. It tokenizes them and returns the token indexes for each text example. We have called texts_to_sequences() on the train and test text documents and wrapped the result in a call to pad_sequences(), which brings all text examples to the same length. Different documents can have a different number of tokens, but we have decided to keep a maximum of 50 tokens per text example. With truncating="post", tokens beyond the first 50 are dropped, and with padding="post", examples with fewer than 50 tokens are padded with 0s at the end.
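The effect of padding="post" and truncating="post" can be mirrored with a hypothetical helper. The defaults of pad_sequences() are "pre" for both arguments, which is why the tutorial passes them explicitly.

```python
def pad_post(sequences, maxlen, value=0):
    """Hypothetical helper mirroring pad_sequences(..., padding="post", truncating="post")."""
    out = []
    for seq in sequences:
        seq = list(seq)[:maxlen]                          # keep only the first maxlen tokens
        out.append(seq + [value] * (maxlen - len(seq)))   # pad the end with `value`
    return out

print(pad_post([[5, 8, 2], [1, 2, 3, 4, 5, 6, 7]], maxlen=5))
# [[5, 8, 2, 0, 0], [1, 2, 3, 4, 5]]
```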

After vectorizing the text data, we have also printed the indexes of the first three training examples.

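Below is a simplified, generic example of how a text document is mapped to indexes. Unlike the Keras Tokenizer used in this tutorial, it keeps punctuation marks and maps unknown tokens to a special '<unk>' index.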

text = "Hello, How are you? Where are you planning to go?"

tokens = ['hello', ',', 'how', 'are', 'you', '?', 'where',
            'are', 'you', 'planning', 'to', 'go', '?']

vocab = {
    'hello': 0,
    'bye': 1,
    'how': 2,
    'the': 3,
    'welcome': 4,
    'are': 5,
    'you': 6,
    'to': 7,
    '<unk>': 8,
}

vector = [0,8,2,5,6,8,8,5,6,8,7,8,8]
In [8]:
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from jax import numpy as jnp

max_tokens = 50 ## Hyperparameter to tune for better performance

tokenizer = Tokenizer()
tokenizer.fit_on_texts(X_train_text+X_test_text)

## Vectorizing data to keep 50 words per sample.
X_train_vect = pad_sequences(tokenizer.texts_to_sequences(X_train_text), maxlen=max_tokens, padding="post", truncating="post", value=0.)
X_test_vect  = pad_sequences(tokenizer.texts_to_sequences(X_test_text), maxlen=max_tokens, padding="post", truncating="post", value=0.)

print(X_train_vect[:3])

X_train_vect, X_test_vect = jnp.array(X_train_vect, dtype=jnp.int32), jnp.array(X_test_vect, dtype=jnp.int32)
Y_train, Y_test = jnp.array(Y_train), jnp.array(Y_test)

X_train_vect.shape, X_test_vect.shape
[[  444   440  1697 15012   108    64     1   852    21    21   739  8198
    444  6337 10243  2965     4  5937 26696    40  4014   801   335     0
      0     0     0     0     0     0     0     0     0     0     0     0
      0     0     0     0     0     0     0     0     0     0     0     0
      0     0]
 [15470  1111   871  1313  4306    21    21   919   809   359 15470    99
    101    22     3  4508     8   504   511 13730     6 15471  1514  2135
      5     1   522   247    22  3938  2289    15  6459     7   209   368
      4     1   129     0     0     0     0     0     0     0     0     0
      0     0]
 [   53     6   379  4509 26697   770    21    21  2446   467    90  1885
   1280    66     1   379     6     1   770     8   285    40   190     2
   5766    34     1   296   129   111    82   230     1  6391     4     1
   1208 15472     0     0     0     0     0     0     0     0     0     0
      0     0]]
Out[8]:
((120000, 50), (7600, 50))
In [9]:
## What is word 21

print(tokenizer.index_word[21])

## How many times it comes in first text document??

print(X_train_text[0]) ## 2 times
reuters
Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again.

2. Approach 1: Single Conv1D Layer Network (Max Tokens=50, Embedding Length=128, Conv Output Channels=32)

Our first approach performs text classification with a simple network containing a single convolution layer. The network consists of an embedding layer, a 1D convolution layer, and a dense layer. Below, we have defined the network, trained it, evaluated its performance, and tried to explain its predictions.

Define Network

In this section, we have defined the network that we'll use for our text classification task. The network consists of 3 layers.

  • Embedding Layer
  • Conv1D Layer
  • Dense Layer

The first layer of the network is the embedding layer. We have created it using the Embed() constructor available from the linen sub-module of Flax. As the number of embeddings (first parameter), we have given the vocabulary length plus one (to make room for index 0, which is used for padding), and we have used an embedding length of 128. This layer internally creates a weight matrix of shape (vocab_len+1, embed_len) holding an embedding of length 128 for every token of our vocabulary. The embedding layer simply takes token indexes as input and returns the rows of the weight matrix at those indexes. The input to the embedding layer is of shape (batch_size, max_tokens) = (batch_size, 50) and the output is of shape (batch_size, max_tokens, embed_len) = (batch_size, 50, 128).

The output of the embedding layer is given to the 1D convolution layer, created with Flax's Conv() constructor, with 32 output channels and a kernel size of 7. Flax convolutions expect channels in the last dimension, so the embedding length dimension acts as the channels dimension. Hence, the input to the convolution layer is of shape (batch_size, max_tokens, embed_len) = (batch_size, 50, 128). Because Conv() uses "SAME" padding by default, the output keeps the sequence length and is of shape (batch_size, max_tokens, conv_output_channels) = (batch_size, 50, 32). We have applied relu activation to the output of the convolution layer.

After applying relu activation, we have taken the max over the max_tokens dimension (global max pooling; the average could be tried instead), which returns an output of shape (batch_size, 32). This output is then given to a Dense layer with 4 output units (the number of target classes). The output of the Dense layer (unnormalized logits) is returned from the network as the prediction.
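The difference between the max used here and the average mentioned as an alternative can be seen on a hypothetical toy tensor. Note that in the tutorial the max runs over all 50 positions, including padded ones, because no padding mask is applied.

```python
import jax.numpy as jnp

# Toy conv output: (batch_size=1, max_tokens=4, channels=2)
conv_out = jnp.array([[[0.1, 0.0],
                       [0.9, 0.2],
                       [0.3, 0.7],
                       [0.0, 0.4]]])

max_pooled = conv_out.max(axis=1)    # strongest response per channel: 0.9 and 0.7
mean_pooled = conv_out.mean(axis=1)  # average response per channel: 0.325 and 0.325

print(max_pooled.shape, mean_pooled.shape)   # (1, 2) (1, 2): the token axis is gone
```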

After defining the network, we have initialized it using random token indexes, printed the shapes of the weights/biases of its layers, and performed a forward pass on the first five training examples for verification.


In [10]:
from flax import linen

embed_len = 128

class Conv1DTextClassifier(linen.Module):
    def setup(self):
        self.embedding = linen.Embed(len(tokenizer.word_index)+1, embed_len, name="Word Embeddings")
        self.conv1 = linen.Conv(32, kernel_size=(7,),name="Conv1")
        self.linear1 = linen.Dense(len(unique_classes), name="Dense1")

    def __call__(self, X_batch):
        x = self.embedding(X_batch)
        x = linen.relu(self.conv1(x))
        x = x.max(axis=1)
        logits = self.linear1(x)
        return logits
In [11]:
from jax import numpy as jnp

seed = jax.random.PRNGKey(0)

text_classif = Conv1DTextClassifier()

params = text_classif.init(seed, jax.random.randint(seed, (100, max_tokens), minval=1, maxval=20))

for layer_params in params["params"].items():
    print("Layer Name : {}".format(layer_params[0]))
    if "Embedding" in layer_params[0]:
        weights = layer_params[1]["embedding"]
        print("\tLayer Weights : {}".format(weights.shape))
    else:
        weights, biases = layer_params[1]["kernel"], layer_params[1]["bias"]
        print("\tLayer Weights : {}, Biases : {}".format(weights.shape, biases.shape))
Layer Name : Word Embeddings
	Layer Weights : (72003, 128)
Layer Name : Conv1
	Layer Weights : (7, 128, 32), Biases : (32,)
Layer Name : Dense1
	Layer Weights : (32, 4), Biases : (4,)
In [12]:
preds = text_classif.apply(params, X_train_vect[:5])

preds
Out[12]:
DeviceArray([[-0.20374326, -0.01410167,  0.24069656, -0.20879401],
             [-0.1101338 , -0.15477017,  0.22043025, -0.17012575],
             [-0.08682966, -0.14510089,  0.29758802, -0.25169313],
             [-0.14112273, -0.14207381,  0.33813193, -0.17564772],
             [-0.07097152, -0.16631724,  0.2733841 , -0.18639344]],            dtype=float32)

Define Loss Function

Below, we have defined the loss function for our task: cross entropy loss. The function takes the network parameters, input data, and target labels as input. It makes predictions on the data using the network parameters and one-hot encodes the target labels. Finally, it calculates the cross entropy loss by giving the predictions (logits) and the one-hot encoded labels to the softmax_cross_entropy() function available from the Optax library, and sums the per-example losses over the batch. Note that the function uses the global model object, which is created in the training cell below.
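The sketch below, using hypothetical logits and labels, shows what the loss computes per example: the one-hot form equals the negative log-softmax probability of the true class. It also shows why the printed losses are in the hundreds: summing over a batch of 1024 examples scales the loss with the batch size, while a mean does not. Adam is largely insensitive to a constant rescaling of the loss, so a switch to the mean is not expected to change much, although that has not been tested here.

```python
import jax
import jax.numpy as jnp

logits = jnp.array([[2.0, 0.5, -1.0, 0.0],
                    [0.1, 0.2, 3.0, -0.5]])   # (batch_size, num_classes)
labels = jnp.array([0, 2])                    # integer class ids

log_probs = jax.nn.log_softmax(logits, axis=-1)

# One-hot form: -sum(one_hot * log_softmax), as optax.softmax_cross_entropy computes per example
one_hot = jax.nn.one_hot(labels, num_classes=4)
per_example = -(one_hot * log_probs).sum(axis=-1)

# Equivalent integer-label form: pick the log-probability of the true class
per_example_int = -jnp.take_along_axis(log_probs, labels[:, None], axis=-1)[:, 0]

loss_sum = per_example.sum()    # what the tutorial's CrossEntropyLoss returns
loss_mean = per_example.mean()  # batch-size independent alternative
```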

In [13]:
def CrossEntropyLoss(params, input_data, actual):
    logits_preds = model.apply(params, input_data)
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(unique_classes))
    return optax.softmax_cross_entropy(logits=logits_preds, labels=one_hot_actual).sum()

Train Network

In this section, we train the neural network defined earlier using a simple training function. The function takes the train data (X, Y), validation data (X_val, Y_val), number of epochs, network parameters, optimizer state, and batch size as input, and runs the training loop for the given number of epochs. In each epoch, it loops through the training data in batches. For each batch, it performs a forward pass to make predictions, calculates the loss and gradients, and updates the network parameters using the gradients. It records the loss of each batch and prints the average batch loss at the end of the epoch, followed by the accuracy on the validation data.

In [14]:
from jax import value_and_grad
from tqdm import tqdm
from sklearn.metrics import accuracy_score

def TrainModelInBatches(X, Y, X_val, Y_val, epochs, params, optimizer_state, batch_size=32):
    ## Uses the global `model` and `optimizer` objects created in the training cell below
    num_batches = int(np.ceil(X.shape[0] / batch_size)) ## Avoids an empty last batch when X.shape[0] is a multiple of batch_size
    for i in range(1, epochs+1):
        losses = [] ## Record loss of each batch
        for batch in tqdm(range(num_batches)):
            start, end = batch*batch_size, (batch+1)*batch_size ## Last batch may be smaller; slicing past the end is safe
            X_batch, Y_batch = X[start:end], Y[start:end] ## Single batch of data

            loss, gradients = value_and_grad(CrossEntropyLoss)(params, X_batch, Y_batch)

            ## Update Network Parameters
            updates, optimizer_state = optimizer.update(gradients, optimizer_state)
            params = optax.apply_updates(params, updates)

            losses.append(loss) ## Record Loss

        print("CrossEntropyLoss : {:.3f}".format(jnp.array(losses).mean()))

        ## Validation accuracy at the end of the epoch
        Y_val_preds = model.apply(params, X_val)
        val_acc = accuracy_score(Y_val, jnp.argmax(Y_val_preds, axis=1))
        print("Validation  Accuracy : {:.3f}".format(val_acc))

    return params

Below, we have trained our network using the training routine defined in the previous cell. First, we have set the batch size to 1024, the number of epochs to 10, and the learning rate to 0.001. Then, we have initialized our text classification network and the Adam optimizer. Finally, we have called the training routine with the necessary parameters. From the values printed after each epoch, we can see that validation accuracy rises from 87.9% to about 91% by the third epoch and then stays flat, while the training loss keeps falling, which suggests that the network starts to overfit after a few epochs.

In [15]:
from jax import random

seed = random.PRNGKey(0)
batch_size=1024
epochs=10
learning_rate = jnp.array(1e-3)

model = Conv1DTextClassifier()
params = model.init(seed, X_train_vect[:5])

optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(params)

final_weights = TrainModelInBatches(X_train_vect, Y_train, X_test_vect, Y_test, epochs, params, optimizer_state, batch_size=batch_size)
100%|██████████| 118/118 [01:16<00:00,  1.55it/s]
CrossEntropyLoss : 921.101
Validation  Accuracy : 0.879
100%|██████████| 118/118 [01:11<00:00,  1.66it/s]
CrossEntropyLoss : 302.596
Validation  Accuracy : 0.906
100%|██████████| 118/118 [01:10<00:00,  1.68it/s]
CrossEntropyLoss : 213.165
Validation  Accuracy : 0.912
100%|██████████| 118/118 [01:20<00:00,  1.46it/s]
CrossEntropyLoss : 160.152
Validation  Accuracy : 0.912
100%|██████████| 118/118 [01:13<00:00,  1.61it/s]
CrossEntropyLoss : 120.241
Validation  Accuracy : 0.913
100%|██████████| 118/118 [01:08<00:00,  1.73it/s]
CrossEntropyLoss : 88.937
Validation  Accuracy : 0.913
100%|██████████| 118/118 [01:00<00:00,  1.95it/s]
CrossEntropyLoss : 64.543
Validation  Accuracy : 0.913
100%|██████████| 118/118 [01:03<00:00,  1.85it/s]
CrossEntropyLoss : 46.213
Validation  Accuracy : 0.912
100%|██████████| 118/118 [01:04<00:00,  1.82it/s]
CrossEntropyLoss : 32.868
Validation  Accuracy : 0.911
100%|██████████| 118/118 [01:04<00:00,  1.84it/s]
CrossEntropyLoss : 23.687
Validation  Accuracy : 0.912

Evaluate Performance Of Network

In this section, we have evaluated the performance of our trained network by calculating the accuracy score, classification report, and confusion matrix on test predictions. The test accuracy of 91.2% shows that the network does a good job at the classification task, although the much higher train accuracy (99.7%) indicates some overfitting.

We have calculated ML metrics using functions available from scikit-learn.

We have also plotted the confusion matrix using the Python library scikit-plot. The network classifies documents of the Sports (97% recall) and World (91%) categories better than those of Business (89%) and Sci/Tech (87%); the most common mistakes are confusions between Business and Sci/Tech.


In [16]:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

train_preds = model.apply(final_weights, X_train_vect)
test_preds = model.apply(final_weights, X_test_vect)

print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, np.argmax(train_preds, axis=1))))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=target_classes))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_test, np.argmax(test_preds, axis=1)))
Train Accuracy : 0.997
Test  Accuracy : 0.912

Classification Report :
              precision    recall  f1-score   support

       World       0.92      0.91      0.92      1900
      Sports       0.96      0.97      0.96      1900
    Business       0.87      0.89      0.88      1900
    Sci/Tech       0.90      0.87      0.89      1900

    accuracy                           0.91      7600
   macro avg       0.91      0.91      0.91      7600
weighted avg       0.91      0.91      0.91      7600


Confusion Matrix :
[[1735   50   66   49]
 [  25 1848   20    7]
 [  68   18 1696  118]
 [  53   19  178 1650]]
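The figures in the classification report can be recomputed from the confusion matrix printed above: accuracy is the diagonal sum over the total, recall divides each diagonal entry by its row (actual class) total, and precision divides it by its column (predicted class) total. A short NumPy check using the Approach 1 numbers:

```python
import numpy as np

target_classes = ["World", "Sports", "Business", "Sci/Tech"]
# Approach 1 confusion matrix (rows = actual class, columns = predicted class)
cm = np.array([[1735,   50,   66,   49],
               [  25, 1848,   20,    7],
               [  68,   18, 1696,  118],
               [  53,   19,  178, 1650]])

accuracy = np.trace(cm) / cm.sum()
recall = np.diag(cm) / cm.sum(axis=1)      # per actual class
precision = np.diag(cm) / cm.sum(axis=0)   # per predicted class

print("Accuracy : {:.3f}".format(accuracy))   # 0.912
for name, p, r in zip(target_classes, precision, recall):
    print("{:10s} precision={:.2f} recall={:.2f}".format(name, p, r))
```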
In [ ]:
from sklearn.metrics import confusion_matrix
import scikitplot as skplt
import matplotlib.pyplot as plt

skplt.metrics.plot_confusion_matrix([target_classes[i] for i in Y_test], [target_classes[i] for i in np.argmax(test_preds, axis=1)],
                                    normalize=True,
                                    title="Confusion Matrix",
                                    cmap="Blues",
                                    hide_zeros=True,
                                    figsize=(5,5)
                                    );
plt.xticks(rotation=90);


Explain Network Predictions using LIME Algorithm

In this section, we have tried to explain the predictions made by our network using the LIME algorithm. We use the implementation available from the lime Python library, which lets us generate a visualization showing which words of a text example contributed to predicting a particular target label.
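To make the idea behind LIME concrete, here is a heavily simplified sketch, not the lime library's implementation: randomly remove words from the text, query the classifier on the perturbed texts, weight each sample by how close it is to the original, and fit a weighted linear model whose coefficients serve as word importances. The function lime_text_sketch, its distance measure and kernel width, and toy_classifier are all hypothetical choices for illustration; the library uses its own distance, kernel, and regularized regression.

```python
import numpy as np

def lime_text_sketch(text, classifier_fn, label, num_samples=500, kernel_width=0.25, seed=0):
    """Simplified LIME-style word attribution (illustrative sketch only)."""
    words = text.split()
    d = len(words)
    rng = np.random.default_rng(seed)
    # Binary masks: 1 keeps a word, 0 removes it; the first row is the original text
    masks = rng.integers(0, 2, size=(num_samples, d))
    masks[0] = 1
    perturbed = [" ".join(w for w, keep in zip(words, m) if keep) for m in masks]
    probs = classifier_fn(perturbed)[:, label]    # probability of the explained label
    # Weight samples by closeness to the original (fraction of words removed)
    distance = 1.0 - masks.sum(axis=1) / d
    weights = np.exp(-(distance ** 2) / kernel_width ** 2)
    # Weighted least squares: probs ~ masks @ coef + intercept
    X = np.hstack([masks, np.ones((num_samples, 1))])
    sw = np.sqrt(weights)
    coef, *_ = np.linalg.lstsq(X * sw[:, None], probs * sw, rcond=None)
    return sorted(zip(words, coef[:d]), key=lambda t: -abs(t[1]))

def toy_classifier(texts):
    # Hypothetical stand-in for make_predictions(): class 1 is likely only if "rfid" appears
    p = np.array([0.9 if "rfid" in t.lower().split() else 0.2 for t in texts])
    return np.stack([1 - p, p], axis=1)

weights = lime_text_sketch("New RFID tags raise privacy concerns", toy_classifier, label=1)
print(weights[0])   # 'RFID' with a weight of about 0.7
```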


In order to explain predictions using LIME, we first need to create an instance of LimeTextExplainer which we have done below.

In [18]:
from lime import lime_text

explainer = lime_text.LimeTextExplainer(class_names=target_classes, verbose=True)

Now, we have randomly selected a text example from the test dataset and made predictions on it using our trained network. Our network correctly predicts the target label as Sci/Tech for the selected text example.

We have also defined the prediction function in the cell below. It takes a list of text examples as input and returns their prediction probabilities: it tokenizes and pads the texts, gives them to the network, and applies softmax to the network output to convert it to probabilities.

In [19]:
import numpy as np

def make_predictions(X_batch_text):
    X_batch = pad_sequences(tokenizer.texts_to_sequences(X_batch_text), maxlen=max_tokens, padding="post", truncating="post", value=0.)
    logits = model.apply(final_weights, jnp.array(X_batch))
    preds = linen.softmax(logits)
    return preds.to_py()

rng = np.random.RandomState(1234)
idx = rng.randint(1, len(X_test_text))

print("Prediction : ", target_classes[model.apply(final_weights, X_test_vect[idx:idx+1]).argmax(axis=-1)[0]])
print("Actual :     ", target_classes[Y_test[idx]])
Prediction :  Sci/Tech
Actual :      Sci/Tech

Below, we have first generated an Explanation object by calling the explain_instance() method of the LimeTextExplainer object, passing the text example, the prediction function, and the target label. This Explanation object holds details about the words contributing to the prediction.

Then, we have called the show_in_notebook() method on the Explanation object to generate a visualization showing the words contributing to predicting the target label as Sci/Tech.

We can notice from the visualization that words like 'RFID', 'privacy', 'frequency', 'identification', etc., push the prediction toward Sci/Tech, which makes sense as these words are common in technology news.

In [ ]:
explanation = explainer.explain_instance(X_test_text[idx], classifier_fn=make_predictions, labels=Y_test[idx:idx+1].to_py())
explanation.show_in_notebook()


3. Approach 2: Multiple Conv1D Layers Network (Max Tokens=50, Embedding Length=128, Conv Output Channels=32,32)

Our approach in this section uses multiple convolution layers. Most of the code in this section is the same as in the previous section, except for the network definition, which uses two 1D convolution layers instead of one.

Define Network

In this section, we have defined the network for this approach. It consists of one embedding layer, two convolution layers, and one dense layer. We have again used an embedding length of 128. Both convolution layers have 32 output channels and a kernel size of 7. The output of the embedding layer is given to the first convolution layer. We apply relu activation to its output and then a max-pooling operation with a window size of 5. Flax's max_pool() uses a stride of 1 and "VALID" padding by default, so this only reduces the sequence length from 50 to 46. The output of max-pooling is given to the second convolution layer. We then apply relu activation to the output of the second convolution layer, take the max over the token dimension as in the previous approach, and give the result to the dense layer. The output of the dense layer is returned as the prediction of the network.

After defining the network, we have initialized it using random token indexes, printed the shapes of the weights/biases of its layers, and performed a forward pass on the first five training examples for verification.

In [21]:
from flax import linen

embed_len = 128 ## Hyperparameter

class Conv1DTextClassifier(linen.Module):
    def setup(self):
        self.embedding = linen.Embed(len(tokenizer.word_index)+1, embed_len, name="Word Embeddings")
        self.conv1 = linen.Conv(32, kernel_size=(7,),name="Conv1")
        self.conv2 = linen.Conv(32, kernel_size=(7,),name="Conv2")
        self.linear1 = linen.Dense(len(unique_classes), name="Dense1")

    def __call__(self, X_batch):
        x = self.embedding(X_batch)

        x = linen.relu(self.conv1(x))

        x = linen.max_pool(x, window_shape=(5,))

        x = linen.relu(self.conv2(x))

        x = x.max(axis=1)

        logits = self.linear1(x)
        return logits
In [22]:
from jax import numpy as jnp

seed = jax.random.PRNGKey(0)

text_classif = Conv1DTextClassifier()

params = text_classif.init(seed, jax.random.randint(seed, (100, max_tokens), minval=1, maxval=20))

for layer_params in params["params"].items():
    print("Layer Name : {}".format(layer_params[0]))
    if "Embedding" in layer_params[0]:
        weights = layer_params[1]["embedding"]
        print("\tLayer Weights : {}".format(weights.shape))
    else:
        weights, biases = layer_params[1]["kernel"], layer_params[1]["bias"]
        print("\tLayer Weights : {}, Biases : {}".format(weights.shape, biases.shape))
Layer Name : Word Embeddings
	Layer Weights : (72003, 128)
Layer Name : Conv1
	Layer Weights : (7, 128, 32), Biases : (32,)
Layer Name : Conv2
	Layer Weights : (7, 32, 32), Biases : (32,)
Layer Name : Dense1
	Layer Weights : (32, 4), Biases : (4,)
In [23]:
preds = text_classif.apply(params, X_train_vect[:5])

preds
Out[23]:
DeviceArray([[-0.10522123, -0.03273935,  0.27437395, -0.0857728 ],
             [-0.0656735 , -0.06544174,  0.2836984 , -0.1417185 ],
             [-0.1282662 , -0.00825347,  0.3193404 , -0.1120418 ],
             [-0.12225489, -0.03385258,  0.32456946, -0.12756759],
             [-0.13004538, -0.0449469 ,  0.29020992, -0.10896963]],            dtype=float32)

Train Network

Here, we have trained our network using exactly the same settings as in the previous section. From the values printed after each epoch, we can see that validation accuracy peaks at 90.9% in the third and fourth epochs and then drifts down to about 90%, while the training loss keeps decreasing overall, which suggests that the network overfits after the first few epochs.

In [24]:
from jax import random

seed = random.PRNGKey(0)
batch_size=1024
epochs=10
learning_rate = jnp.array(1e-3)

model = Conv1DTextClassifier()
params = model.init(seed, X_train_vect[:5])

optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(params)

final_weights = TrainModelInBatches(X_train_vect, Y_train, X_test_vect, Y_test, epochs, params, optimizer_state, batch_size=batch_size)
100%|██████████| 118/118 [01:09<00:00,  1.69it/s]
CrossEntropyLoss : 787.356
Validation  Accuracy : 0.894
100%|██████████| 118/118 [01:35<00:00,  1.24it/s]
CrossEntropyLoss : 257.627
Validation  Accuracy : 0.906
100%|██████████| 118/118 [01:11<00:00,  1.65it/s]
CrossEntropyLoss : 168.014
Validation  Accuracy : 0.909
100%|██████████| 118/118 [01:22<00:00,  1.42it/s]
CrossEntropyLoss : 112.964
Validation  Accuracy : 0.909
100%|██████████| 118/118 [01:25<00:00,  1.39it/s]
CrossEntropyLoss : 73.333
Validation  Accuracy : 0.906
100%|██████████| 118/118 [01:10<00:00,  1.68it/s]
CrossEntropyLoss : 47.830
Validation  Accuracy : 0.904
100%|██████████| 118/118 [01:25<00:00,  1.38it/s]
CrossEntropyLoss : 39.231
Validation  Accuracy : 0.896
100%|██████████| 118/118 [01:15<00:00,  1.57it/s]
CrossEntropyLoss : 50.822
Validation  Accuracy : 0.900
100%|██████████| 118/118 [01:18<00:00,  1.50it/s]
CrossEntropyLoss : 48.646
Validation  Accuracy : 0.898
100%|██████████| 118/118 [01:36<00:00,  1.22it/s]
CrossEntropyLoss : 34.234
Validation  Accuracy : 0.899

Evaluate Performance Of Network

In this section, we have evaluated the performance of our trained network by calculating the accuracy score, classification report, and confusion matrix on test predictions. The test accuracy (89.9%) is a little lower than that of the previous approach (91.2%). This is somewhat surprising, as we had expected that more convolution layers might improve accuracy; the high train accuracy (99.4%) and the declining validation accuracy during training suggest that overfitting may explain the gap. We have also plotted the confusion matrix for reference.

In [25]:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

train_preds = model.apply(final_weights, X_train_vect)
test_preds = model.apply(final_weights, X_test_vect)

print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, np.argmax(train_preds, axis=1))))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=target_classes))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_test, np.argmax(test_preds, axis=1)))
Train Accuracy : 0.994
Test  Accuracy : 0.899

Classification Report :
              precision    recall  f1-score   support

       World       0.88      0.93      0.90      1900
      Sports       0.96      0.95      0.96      1900
    Business       0.86      0.88      0.87      1900
    Sci/Tech       0.90      0.84      0.87      1900

    accuracy                           0.90      7600
   macro avg       0.90      0.90      0.90      7600
weighted avg       0.90      0.90      0.90      7600


Confusion Matrix :
[[1762   48   53   37]
 [  56 1814   24    6]
 [  88   17 1663  132]
 [ 106   18  183 1593]]
In [ ]:
from sklearn.metrics import confusion_matrix
import scikitplot as skplt
import matplotlib.pyplot as plt

skplt.metrics.plot_confusion_matrix([target_classes[i] for i in Y_test], [target_classes[i] for i in np.argmax(test_preds, axis=1)],
                                    normalize=True,
                                    title="Confusion Matrix",
                                    cmap="Blues",
                                    hide_zeros=True,
                                    figsize=(5,5)
                                    );
plt.xticks(rotation=90);


Explain Network Predictions using LIME Algorithm

In this section, we have again tried to explain the predictions made by our network using the LIME algorithm. Our network correctly predicts the target label as Sci/Tech for the same randomly selected test example. From the visualization created using LIME, we can notice that words like 'RFID', 'technology', 'frequency', 'identification', 'privacy', etc., contribute to predicting the target label as Sci/Tech.

In [ ]:
from lime import lime_text

explainer = lime_text.LimeTextExplainer(class_names=target_classes)

rng = np.random.RandomState(1234)
idx = rng.randint(1, len(X_test_text))

print("Prediction : ", target_classes[model.apply(final_weights, X_test_vect[idx:idx+1]).argmax(axis=-1)[0]])
print("Actual :     ", target_classes[Y_test[idx]])

explanation = explainer.explain_instance(X_test_text[idx], classifier_fn=make_predictions, labels=Y_test[idx:idx+1].to_py())
explanation.show_in_notebook()


4. Results Summary And Further Recommendation

Below is a summary of the approaches we tried above.

Approach                         | Max Tokens | Embedding Length | Conv Output Channels | Test Accuracy (%)
Single Conv1D Layer Network      | 50         | 128              | 32                   | 91.2
Multiple Conv1D Layers Network   | 50         | 128              | 32, 32               | 89.9

Further Suggestions

  • Try different convolution layer output channels.
  • Try different kernel sizes.
  • Try a different number of maximum tokens.
  • Try different embedding lengths.
  • Add more dense layers at the end.
  • Try different weight initialization methods.
  • Try different optimizers (RMSProp, Adagrad, etc).
  • Add recurrent layers (LSTM, GRU) after the convolution layer.
  • Try learning rate schedulers.

This ends our small tutorial explaining how to design CNNs with 1D convolution layers using the Flax (JAX) framework for text classification tasks. Please feel free to let us know your views in the comments section.

Sunny Solanki
