Updated On : Feb-11,2022 Time Investment : ~30 mins

Text Classification Using Flax (JAX) Networks

Text classification, also called document classification, is a very common natural language processing (NLP) task. It involves assigning text documents to categories based on their contents. There are various ways to approach this task. We'll be primarily concentrating on classifying text documents using deep learning techniques involving neural networks.

As a part of this tutorial, we have explained how to perform text classification using neural networks designed with the Flax Python library. Flax is a high-level neural network library built on top of JAX. JAX is a deep learning research library developed by Google research teams. Flax makes designing neural networks considerably easier than working with JAX alone. We can think of Flax as a PyTorch-like version of JAX.

We assume that the reader has a background in Flax, JAX, and text vectorization. If you want to refresh your knowledge or build a background in these topics, please feel free to check our separate tutorials on them.

We have used the 20 newsgroups dataset available from scikit-learn as a part of our tutorial. We have used vectorizers available from scikit-learn to convert text data to lists of floats, as required by ML algorithms. The tutorial tries different text vectorization approaches to explain them and compare their results. We also suggest that readers go through our tutorial that explains in detail how text vectorization happens, as it'll help with this one.

Below, we have listed the important sections of the tutorial to give an overview of the material covered.

Important Sections Of Tutorial

  1. Word Frequency Count Model
    • Load Data
    • Vectorize Data (Text to List of Floats)
    • Define Network
    • Define Loss
    • Train Network
    • Evaluate Network Performance
  2. Word Frequency Count Model (Stop Words Removed)
  3. Word Frequency Count Model (Stop Words Removed + n-Grams (1,3))
  4. TfIdf (Term Frequency-Inverse Document Frequency) Model
  5. TfIdf Model (Stop Words Removed)
  6. TfIdf Model (Stop Words Removed + n-Grams (1,3))
  7. Try Other Approaches to Improve Results Further

Below, we have imported the important libraries used in our tutorial and printed their versions.

import jax

print("JAX Version : {}".format(jax.__version__))
JAX Version : 0.2.27
import flax

print("FLAX Version : {}".format(flax.__version__))
FLAX Version : 0.3.6
import optax

print("OPTAX Version : {}".format(optax.__version__))
OPTAX Version : 0.1.0

1. Word Frequency Count Model

In this section, we have trained a neural network on a vectorized version of our input text, where each document becomes a list of floats holding the frequency of each vocabulary word in that document. We have used the scikit-learn count vectorizer, which takes text data as input and converts it to these frequency values. After training, we have evaluated the performance of the network by calculating various metrics on test data: accuracy, f1-score, precision, recall, and the confusion matrix.

Load Data

In this section, we have loaded the 20 newsgroups dataset available from scikit-learn. It has about 18k newsgroup posts on 20 different topics. We have selected 5 of these 20 topics and loaded only the entries for those 5 topics. Scikit-learn provides the train and test datasets separately through the function fetch_20newsgroups().

import numpy as np
from sklearn import datasets
import gc

all_categories = ['alt.atheism','comp.graphics','comp.os.ms-windows.misc','comp.sys.ibm.pc.hardware',
                  'comp.sys.mac.hardware','comp.windows.x', 'misc.forsale','rec.autos','rec.motorcycles',
                  'rec.sport.baseball','rec.sport.hockey','sci.crypt','sci.electronics','sci.med',
                  'sci.space','soc.religion.christian','talk.politics.guns','talk.politics.mideast',
                  'talk.politics.misc','talk.religion.misc']

selected_categories = ['comp.sys.mac.hardware','comp.windows.x','rec.motorcycles','sci.crypt','talk.politics.mideast']

X_train_text, Y_train = datasets.fetch_20newsgroups(subset="train", categories=selected_categories, return_X_y=True)
X_test_text , Y_test  = datasets.fetch_20newsgroups(subset="test", categories=selected_categories, return_X_y=True)

X_train_text = np.array(X_train_text)
X_test_text = np.array(X_test_text)

classes = np.unique(Y_train)
mapping = dict(zip(classes, selected_categories))

len(X_train_text), len(X_test_text), classes, mapping
(2928,
 1950,
 array([0, 1, 2, 3, 4]),
 {0: 'comp.sys.mac.hardware',
  1: 'comp.windows.x',
  2: 'rec.motorcycles',
  3: 'sci.crypt',
  4: 'talk.politics.mideast'})

Vectorize Data (Text to List of Floats)

Neural networks work on numbers only, so in order to use text data with them we need a way to convert each document to a list of floats. There are various ways to do this. In this section, we have used the simplest one, which takes the text of a sample as input and returns a list with one float per vocabulary word, where each float is the frequency of that word in the document.

Scikit-learn provides a CountVectorizer estimator that can help us with this task. We need to fit it on data first to populate its vocabulary. In our case, we have fitted it on both the train and test datasets so that the vocabulary has words from both. We have set max_features to 50000, which tells the count vectorizer to keep only the 50000 most frequent words and drop all others. This limits the vocabulary size to 50000 words. The vocabulary is simply a list of words, so when we transform an individual text document it becomes a list of 50000 floats, and the position that corresponds to a word in the vocabulary holds the frequency of that word in the document.
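
To make the idea concrete, here is a minimal hand-rolled sketch of word-frequency vectorization in plain Python. It is not scikit-learn's implementation: the helper names (tokenize, fit_vocabulary, transform) and the toy documents are our own assumptions for illustration. It only mirrors the general behaviour described above, which is to build a vocabulary from the most frequent words and then count each vocabulary word per document.

```python
import re
from collections import Counter

def tokenize(text):
    # Lowercase the text and keep tokens of two or more word characters,
    # which matches CountVectorizer's default token pattern.
    return re.findall(r"(?u)\b\w\w+\b", text.lower())

def fit_vocabulary(docs, max_features=None):
    # Count every token across the corpus and keep the most frequent ones.
    totals = Counter(tok for doc in docs for tok in tokenize(doc))
    kept = [tok for tok, _ in totals.most_common(max_features)]
    # Sort alphabetically so that each word gets a stable column index.
    return {tok: idx for idx, tok in enumerate(sorted(kept))}

def transform(docs, vocabulary):
    rows = []
    for doc in docs:
        counts = Counter(tokenize(doc))
        row = [0.0] * len(vocabulary)
        for tok, idx in vocabulary.items():
            row[idx] = float(counts[tok])  # frequency of this word in the document
        rows.append(row)
    return rows

# Hypothetical toy corpus, not taken from the newsgroups data.
toy_docs = ["the bike is fast", "the bike and the mac", "the mac"]
vocab = fit_vocabulary(toy_docs, max_features=3)
vectors = transform(toy_docs, vocab)

print(vocab)    # {'bike': 0, 'mac': 1, 'the': 2}
print(vectors)  # [[1.0, 0.0, 1.0], [1.0, 1.0, 2.0], [0.0, 1.0, 1.0]]
```

With max_features=3, the rarer words ('is', 'fast', 'and') never get a column, which is the same effect the 50000 limit has on the real dataset.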

If you want to understand the count vectorizer in depth with examples, we suggest that you go through our separate tutorial on it. It'll help with this tutorial as well. There, we have explained in detail how scikit-learn's CountVectorizer and TfidfVectorizer work.

import sklearn
from jax import numpy as jnp
from sklearn.feature_extraction.text import CountVectorizer

vectorizer = CountVectorizer(max_features=50000)

vectorizer.fit(np.concatenate((X_train_text, X_test_text)))
X_train = vectorizer.transform(X_train_text)
X_test = vectorizer.transform(X_test_text)

X_train = jnp.array(X_train.toarray(), dtype=jnp.float16)
X_test  = jnp.array(X_test.toarray(), dtype=jnp.float16)

X_train.shape, X_test.shape
((2928, 50000), (1950, 50000))
import gc

gc.collect()
21

Define Network

In this section, we have defined our neural network using Flax. We have created a network of 3 linear/dense layers. The first layer has 128 units, the second layer has 64 units, and the final layer has as many units as there are target classes, which is 5 in our case (5 topics). We apply relu (rectified linear unit) activation to the output of each layer except the last one, which returns raw logits. We have created the network by defining a class that extends the linen.Module class. The layers are defined in the setup() method and the forward pass is defined in the __call__() method.

In the next cell, we have initialized the network and its parameters (weights & biases of layers). We have printed the shapes of the weights and biases of each layer. We have also performed a forward pass through the network using a few train samples to verify that it works as expected.

from flax import linen
from jax import random

class TextClassifier(linen.Module):
    def setup(self):
        self.linear1 = linen.Dense(features=128, name="DENSE1")
        self.linear2 = linen.Dense(features=64, name="DENSE2")
        self.linear3 = linen.Dense(len(classes), name="DENSE3")

    def __call__(self, inputs):
        x = linen.relu(self.linear1(inputs))
        x = linen.relu(self.linear2(x))
        logits = self.linear3(x)

        return logits #linen.softmax(x)
seed = jax.random.PRNGKey(0)

model = TextClassifier()
params = model.init(seed, X_train[:5])

for layer_params in params["params"].items():
    print("Layer Name : {}".format(layer_params[0]))
    weights, biases = layer_params[1]["kernel"], layer_params[1]["bias"]
    print("\tLayer Weights : {}, Biases : {}".format(weights.shape, biases.shape))
Layer Name : DENSE1
	Layer Weights : (50000, 128), Biases : (128,)
Layer Name : DENSE2
	Layer Weights : (128, 64), Biases : (64,)
Layer Name : DENSE3
	Layer Weights : (64, 5), Biases : (5,)
preds = model.apply(params, X_train[:5])

preds.shape
(5, 5)
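
The layer shapes printed above also tell us how many trainable parameters the network has. The small sketch below is plain Python arithmetic, with a helper name (dense_param_count) that we made up for illustration. It assumes each dense layer has a kernel of shape (inputs, outputs) plus one bias per output unit, as the printed shapes suggest.

```python
def dense_param_count(in_features, out_features):
    # A dense layer holds a kernel of shape (in, out) plus one bias per output unit.
    return in_features * out_features + out_features

# Input width (vocabulary size) followed by the widths of the three layers.
layer_sizes = [50000, 128, 64, 5]

per_layer = [dense_param_count(i, o) for i, o in zip(layer_sizes[:-1], layer_sizes[1:])]
total_params = sum(per_layer)

print(per_layer)     # [6400128, 8256, 325]
print(total_params)  # 6408709
```

Almost all parameters sit in the first layer, because its input width equals the 50000-word vocabulary.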

Define Loss

Below, we have defined the loss function for our multi-class text classification task. We'll be using cross entropy loss for our purpose. The function takes parameters, input data, and actual target values of that data as input. It then performs a forward pass through the network using the parameters and input data to make predictions. Then, it one-hot encodes the target values, calculates the cross entropy between predictions and targets for each sample, and returns the sum over the batch. Because the losses are summed rather than averaged, the reported value grows with the batch size.

def CrossEntropyLoss(weights, input_data, actual):
    logits_preds = model.apply(weights, input_data)
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(classes))
    return optax.softmax_cross_entropy(logits=logits_preds, labels=one_hot_actual).sum()
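
To show what the loss above computes for each sample, here is a hand-written sketch of softmax cross entropy in plain Python. It is an illustration under our own assumptions (the function name and the toy logits are made up) and is not how optax implements it internally. It also shows the difference between summing and averaging the per-sample losses.

```python
import math

def softmax_cross_entropy(logits, label, num_classes):
    # Numerically stable log-softmax: subtract the largest logit before exponentiating.
    max_logit = max(logits)
    log_norm = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))
    log_probs = [z - log_norm for z in logits]
    # One-hot encode the integer label, as jax.nn.one_hot does in the loss function.
    one_hot = [1.0 if k == label else 0.0 for k in range(num_classes)]
    return -sum(t * lp for t, lp in zip(one_hot, log_probs))

# Hypothetical logits for a batch of two samples and five classes.
batch_logits = [[2.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]]
batch_labels = [0, 3]

per_sample = [softmax_cross_entropy(l, y, 5) for l, y in zip(batch_logits, batch_labels)]
summed = sum(per_sample)          # what a summed loss reports
mean = summed / len(per_sample)   # what an averaged loss would report

print(per_sample, summed, mean)
```

The second sample has identical logits for all five classes, so its loss equals ln(5), roughly 1.609. With a summed loss, a batch of 256 such samples would report about 412, which is why the losses printed during the first epochs of training are in the hundreds.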

Train Network

In this section, we have first defined a function that we'll be using for training our network in all sections of the tutorial. The function takes train data features (X), train target values (Y), validation data (X_val, Y_val), number of epochs, network parameters (weights), optimizer state, and batch size as input. It also relies on the globally defined model and optimizer. It then loops over the data once per epoch. Within each epoch, it goes through the data in batches and, for each batch, performs a forward pass to make predictions, calculates the loss, calculates the gradients, and updates the network parameters. At the end of each epoch, it prints the average training loss across batches, calculates the accuracy on the validation data, and prints that as well.

from jax import value_and_grad
from tqdm import tqdm
from sklearn.metrics import accuracy_score

def TrainModelInBatches(X, Y, X_val, Y_val, epochs, weights, optimizer_state, batch_size=32):
    for i in range(1, epochs+1):
        batches = jnp.arange((X.shape[0]//batch_size)+1) ### Batch Indices

        losses = [] ## Record loss of each batch
        for batch in tqdm(batches):
            if batch != batches[-1]:
                start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
            else:
                start, end = int(batch*batch_size), None

            X_batch, Y_batch = X[start:end], Y[start:end] ## Single batch of data

            loss, gradients = value_and_grad(CrossEntropyLoss)(weights, X_batch,Y_batch)

            ## Update Weights
            updates, optimizer_state = optimizer.update(gradients, optimizer_state)
            weights = optax.apply_updates(weights, updates)

            losses.append(loss) ## Record Loss

        print("CrossEntropyLoss : {:.3f}".format(jnp.array(losses).mean()))

        Y_val_preds = model.apply(weights, X_val)
        val_acc = accuracy_score(Y_val, jnp.argmax(Y_val_preds, axis=1))
        print("Validation  Accuracy : {:.3f}".format(val_acc))

    return weights
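
The batch index arithmetic in the loop above can be checked in isolation. The sketch below reproduces it in plain Python with a helper name (batch_bounds) that we introduced only for illustration.

```python
def batch_bounds(num_samples, batch_size):
    # Same arithmetic as the training loop: floor(n / batch_size) + 1 batches,
    # with the final batch running to the end of the data.
    num_batches = num_samples // batch_size + 1
    bounds = []
    for batch in range(num_batches):
        start = batch * batch_size
        end = start + batch_size if batch != num_batches - 1 else num_samples
        bounds.append((start, end))
    return bounds

bounds = batch_bounds(2928, 256)
sizes = [end - start for start, end in bounds]

print(len(bounds))  # 12
print(sizes[-1])    # 112
```

For our 2928 training samples and a batch size of 256, this gives 11 full batches and a final batch of 112 samples, which matches the 12 steps shown by the progress bar. One caveat of this scheme: when the number of samples is an exact multiple of the batch size, the final batch is empty (for example, batch_bounds(512, 256) ends with (512, 512)). A ceiling division would avoid that case.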

Below, we have trained our network using the function created in the previous cell. We have set the batch size to 256, the number of epochs to 8, and the learning rate to 0.001. We have then initialized the network and its parameters. Then, we have initialized the Adam optimizer for updating network parameters and created an optimizer state from the network parameters. At last, we have called our training function with the necessary arguments to train the network. Note that the test set is passed as the validation data.

We can notice from the values printed at the end of each epoch that the training loss falls steadily and the validation accuracy settles at around 0.965, so our model has done a decent job at the task.

seed = random.PRNGKey(0)
batch_size=256
epochs=8
learning_rate = jnp.array(1/1e3)

model = TextClassifier()
weights = model.init(seed, X_train[:5])

optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(weights)

final_weights = TrainModelInBatches(X_train, Y_train, X_test, Y_test, epochs, weights, optimizer_state, batch_size=batch_size)
100%|██████████| 12/12 [00:06<00:00,  1.89it/s]
CrossEntropyLoss : 261.351
Validation  Accuracy : 0.956
100%|██████████| 12/12 [00:03<00:00,  3.47it/s]
CrossEntropyLoss : 53.057
Validation  Accuracy : 0.960
100%|██████████| 12/12 [00:03<00:00,  3.52it/s]
CrossEntropyLoss : 12.941
Validation  Accuracy : 0.963
100%|██████████| 12/12 [00:03<00:00,  3.44it/s]
CrossEntropyLoss : 4.580
Validation  Accuracy : 0.968
100%|██████████| 12/12 [00:03<00:00,  3.43it/s]
CrossEntropyLoss : 2.220
Validation  Accuracy : 0.967
100%|██████████| 12/12 [00:03<00:00,  3.51it/s]
CrossEntropyLoss : 1.377
Validation  Accuracy : 0.966
100%|██████████| 12/12 [00:03<00:00,  3.45it/s]
CrossEntropyLoss : 0.960
Validation  Accuracy : 0.966
100%|██████████| 12/12 [00:03<00:00,  3.40it/s]
CrossEntropyLoss : 0.716
Validation  Accuracy : 0.965

Evaluate Network Performance

In this section, we have evaluated the performance of our network. We have calculated train and test accuracies, per-class precision, recall, and f1-score, and the confusion matrix. We can notice from the confusion matrix that some samples of sci.crypt and comp.windows.x are confused with the category comp.sys.mac.hardware. Due to that, the recall and f1-score for those two categories are a bit lower than for the others.

If you want to know in detail about the various ML metrics available through scikit-learn, please check our separate tutorial where we have covered the majority of them.

from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

train_preds = model.apply(final_weights, X_train)
test_preds = model.apply(final_weights, X_test)

print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, np.argmax(train_preds, axis=1))))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=selected_categories))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_test, np.argmax(test_preds, axis=1)))
Train Accuracy : 1.000
Test  Accuracy : 0.965

Classification Report :
                       precision    recall  f1-score   support

comp.sys.mac.hardware       0.94      0.97      0.95       385
       comp.windows.x       0.96      0.93      0.95       395
      rec.motorcycles       0.98      0.99      0.99       398
            sci.crypt       0.96      0.95      0.95       396
talk.politics.mideast       0.99      0.98      0.99       376

             accuracy                           0.97      1950
            macro avg       0.97      0.97      0.97      1950
         weighted avg       0.97      0.97      0.97      1950


Confusion Matrix :
[[373   2   4   5   1]
 [ 13 369   2  10   1]
 [  1   1 395   1   0]
 [ 11   7   3 375   0]
 [  0   5   0   1 370]]
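
The per-class numbers in the classification report can be recovered from the confusion matrix itself. The sketch below is plain Python with a helper name (per_class_metrics) of our own choosing. It relies on scikit-learn's convention that rows are actual classes and columns are predicted classes, and uses the matrix printed above.

```python
# Confusion matrix copied from the output above (rows: actual, columns: predicted).
confusion = [
    [373, 2, 4, 5, 1],
    [13, 369, 2, 10, 1],
    [1, 1, 395, 1, 0],
    [11, 7, 3, 375, 0],
    [0, 5, 0, 1, 370],
]

def per_class_metrics(matrix):
    metrics = []
    for k in range(len(matrix)):
        true_positive = matrix[k][k]
        actual_total = sum(matrix[k])                    # row sum: samples that truly belong to class k
        predicted_total = sum(row[k] for row in matrix)  # column sum: samples predicted as class k
        recall = true_positive / actual_total
        precision = true_positive / predicted_total
        f1 = 2 * precision * recall / (precision + recall)
        metrics.append((precision, recall, f1))
    return metrics

metrics = per_class_metrics(confusion)
accuracy = sum(confusion[k][k] for k in range(5)) / sum(map(sum, confusion))

print(round(metrics[1][0], 2), round(metrics[1][1], 2))  # 0.96 0.93
print(round(accuracy, 3))                                # 0.965
```

The second row is comp.windows.x: 369 of its 395 samples are classified correctly, giving the recall of 0.93 shown in the report, and most of its errors land in the first column (comp.sys.mac.hardware).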

2. Word Frequency Count Model (Stop Words Removed)

In this section, we have again trained our neural network using the same word frequency count approach that we used in the previous section. But this time, we have removed commonly used stop words (like 'the', 'a', 'an', 'then', etc.) from our vocabulary. These words appear in almost all documents, so they carry little information about the category and will now be ignored. To make the model more accurate, we want words that are characteristic of a particular category.

Vectorize Data (Text to List of Floats)

Below, we have vectorized our text data to lists of floats using the count vectorizer from scikit-learn. The code is almost exactly the same as in the previous section with one minor change: we have passed the value 'english' to the parameter 'stop_words', which removes common English stop words.
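
As a rough illustration of what stop word removal does to a document before counting, here is a minimal plain Python sketch. The stop list (TOY_STOP_WORDS) and the sentence are hypothetical and deliberately tiny. Scikit-learn's built-in 'english' list is much longer.

```python
import re

# Hypothetical, deliberately tiny stop list for illustration only.
TOY_STOP_WORDS = {"the", "a", "an", "then", "is", "and", "of"}

def tokenize(text, stop_words=frozenset()):
    # Lowercase, split into word tokens, then drop any token in the stop list.
    tokens = re.findall(r"(?u)\b\w\w+\b", text.lower())
    return [tok for tok in tokens if tok not in stop_words]

sentence = "The key is then sent and an encrypted copy of the message follows"
all_tokens = tokenize(sentence)
kept_tokens = tokenize(sentence, TOY_STOP_WORDS)

print(len(all_tokens))  # 13
print(kept_tokens)      # ['key', 'sent', 'encrypted', 'copy', 'message', 'follows']
```

Only the words that say something about the topic survive, so the vocabulary slots are spent on more informative words.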

import sklearn
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer

vectorizer = CountVectorizer(max_features=50000, stop_words="english")

vectorizer.fit(np.concatenate((X_train_text, X_test_text)))
X_train = vectorizer.transform(X_train_text)
X_test = vectorizer.transform(X_test_text)

X_train = jnp.array(X_train.toarray(), dtype=jnp.float16)
X_test  = jnp.array(X_test.toarray(), dtype=jnp.float16)

X_train.shape, X_test.shape
((2928, 50000), (1950, 50000))

Train Network

In this section, we have trained our network on the vectorized dataset. All the training settings are the same as in the previous section; only the vectorized input data has changed. We can notice from the results that the accuracy has improved slightly after we removed stop words (final validation accuracy of 0.969 compared to 0.965).

seed = random.PRNGKey(0)
batch_size=256
epochs=8
learning_rate = jnp.array(1/1e3)

model = TextClassifier()
weights = model.init(seed, X_train[:5])

optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(weights)

final_weights = TrainModelInBatches(X_train, Y_train, X_test, Y_test, epochs, weights, optimizer_state, batch_size=batch_size)
100%|██████████| 12/12 [00:03<00:00,  3.54it/s]
CrossEntropyLoss : 267.135
Validation  Accuracy : 0.954
100%|██████████| 12/12 [00:03<00:00,  3.59it/s]
CrossEntropyLoss : 50.685
Validation  Accuracy : 0.972
100%|██████████| 12/12 [00:03<00:00,  3.62it/s]
CrossEntropyLoss : 9.192
Validation  Accuracy : 0.969
100%|██████████| 12/12 [00:03<00:00,  3.73it/s]
CrossEntropyLoss : 2.867
Validation  Accuracy : 0.969
100%|██████████| 12/12 [00:03<00:00,  3.61it/s]
CrossEntropyLoss : 1.392
Validation  Accuracy : 0.969
100%|██████████| 12/12 [00:03<00:00,  3.59it/s]
CrossEntropyLoss : 0.895
Validation  Accuracy : 0.970
100%|██████████| 12/12 [00:03<00:00,  3.51it/s]
CrossEntropyLoss : 0.664
Validation  Accuracy : 0.970
100%|██████████| 12/12 [00:03<00:00,  3.53it/s]
CrossEntropyLoss : 0.528
Validation  Accuracy : 0.969

Evaluate Network Performance

In this section, we have evaluated the performance of our network by calculating accuracy, classification report, and confusion matrix metrics. We can notice from the results that a few samples of categories 'comp.windows.x' and 'sci.crypt' are confused with category 'comp.sys.mac.hardware', as per the first column of the confusion matrix. A few samples of 'comp.windows.x' and 'sci.crypt' are also confused with each other. All these 3 categories are related to computers and can have overlapping keywords that could confuse the model.

from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

train_preds = model.apply(final_weights, X_train)
test_preds = model.apply(final_weights, X_test)

print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, np.argmax(train_preds, axis=1))))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=selected_categories))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_test, np.argmax(test_preds, axis=1)))
Train Accuracy : 1.000
Test  Accuracy : 0.969

Classification Report :
                       precision    recall  f1-score   support

comp.sys.mac.hardware       0.94      0.98      0.96       385
       comp.windows.x       0.95      0.95      0.95       395
      rec.motorcycles       0.99      0.99      0.99       398
            sci.crypt       0.97      0.94      0.96       396
talk.politics.mideast       1.00      0.98      0.99       376

             accuracy                           0.97      1950
            macro avg       0.97      0.97      0.97      1950
         weighted avg       0.97      0.97      0.97      1950


Confusion Matrix :
[[378   3   1   2   1]
 [ 12 376   1   6   0]
 [  0   1 396   1   0]
 [ 12   9   3 372   0]
 [  0   6   0   2 368]]

3. Word Frequency Count Model (Stop Words Removed + n-Grams (1,3))

In this section, we are taking our word frequency approach further and trying 2-word and 3-word combinations along with single words. All our previous approaches vectorized text by taking into account only single words, which are generally referred to as 1-grams. We can also use sequences of 2 and 3 consecutive words (2-grams and 3-grams) to catch a little contextual detail. We'll check whether this approach helps improve accuracy further.

Vectorize Data (Text to List of Floats)

Below, we have vectorized our data using the scikit-learn count vectorizer. All the code of this section is exactly the same as the previous section with one minor change. We have provided value (1,3) to parameter 'ngram_range' of count vectorizer. This will instruct it to consider all 1-gram, 2-gram, and 3-gram word combinations.
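
To see what an n-gram range of (1,3) produces, here is a small plain Python sketch. The function name (word_ngrams) and the token list are our own assumptions for illustration. It is a simplified version of the idea, not scikit-learn's code.

```python
def word_ngrams(tokens, ngram_range=(1, 3)):
    min_n, max_n = ngram_range
    grams = []
    for n in range(min_n, max_n + 1):
        # Slide a window of n consecutive tokens over the document.
        for start in range(len(tokens) - n + 1):
            grams.append(" ".join(tokens[start:start + n]))
    return grams

# Hypothetical tokens of a short document.
tokens = ["public", "key", "encryption", "scheme"]
grams = word_ngrams(tokens)

print(grams)
# ['public', 'key', 'encryption', 'scheme',
#  'public key', 'key encryption', 'encryption scheme',
#  'public key encryption', 'key encryption scheme']
```

A 4-token document yields 9 features instead of 4. Since max_features is still 50000, these multi-word features compete with single words for the same number of vocabulary slots.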

import sklearn
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer

vectorizer = CountVectorizer(max_features=50000, stop_words="english", ngram_range=(1,3))

vectorizer.fit(np.concatenate((X_train_text, X_test_text)))
X_train = vectorizer.transform(X_train_text)
X_test = vectorizer.transform(X_test_text)

X_train = jnp.array(X_train.toarray(), dtype=jnp.float16)
X_test  = jnp.array(X_test.toarray(), dtype=jnp.float16)

X_train.shape, X_test.shape
((2928, 50000), (1950, 50000))

Train Network

Below, we have trained our network using the vectorized data we created in the previous cell. We can notice from the results that accuracy does not improve with this approach: the final validation accuracy is 0.962, slightly below the 0.969 of the previous section.

seed = random.PRNGKey(0)
batch_size=256
epochs=8
learning_rate = jnp.array(1/1e3)

model = TextClassifier()
weights = model.init(seed, X_train[:5])

optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(weights)

final_weights = TrainModelInBatches(X_train, Y_train, X_test, Y_test, epochs, weights, optimizer_state, batch_size=batch_size)
100%|██████████| 12/12 [00:03<00:00,  3.66it/s]
CrossEntropyLoss : 228.179
Validation  Accuracy : 0.959
100%|██████████| 12/12 [00:03<00:00,  3.62it/s]
CrossEntropyLoss : 28.782
Validation  Accuracy : 0.957
100%|██████████| 12/12 [00:03<00:00,  3.60it/s]
CrossEntropyLoss : 5.586
Validation  Accuracy : 0.960
100%|██████████| 12/12 [00:03<00:00,  3.52it/s]
CrossEntropyLoss : 1.756
Validation  Accuracy : 0.958
100%|██████████| 12/12 [00:03<00:00,  3.82it/s]
CrossEntropyLoss : 0.873
Validation  Accuracy : 0.961
100%|██████████| 12/12 [00:03<00:00,  3.90it/s]
CrossEntropyLoss : 0.572
Validation  Accuracy : 0.961
100%|██████████| 12/12 [00:02<00:00,  4.02it/s]
CrossEntropyLoss : 0.431
Validation  Accuracy : 0.962
100%|██████████| 12/12 [00:03<00:00,  3.93it/s]
CrossEntropyLoss : 0.346
Validation  Accuracy : 0.962

Evaluate Network Performance

In this section, we have evaluated the performance of our network by calculating accuracy, classification report, and confusion matrix metrics on the test dataset. We can notice from the results that samples of categories 'comp.sys.mac.hardware', 'comp.windows.x' and 'sci.crypt' are confused with each other.

from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

train_preds = model.apply(final_weights, X_train)
test_preds = model.apply(final_weights, X_test)

print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, np.argmax(train_preds, axis=1))))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=selected_categories))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_test, np.argmax(test_preds, axis=1)))
Train Accuracy : 1.000
Test  Accuracy : 0.962

Classification Report :
                       precision    recall  f1-score   support

comp.sys.mac.hardware       0.92      0.98      0.95       385
       comp.windows.x       0.94      0.95      0.94       395
      rec.motorcycles       0.98      0.98      0.98       398
            sci.crypt       0.98      0.93      0.95       396
talk.politics.mideast       1.00      0.97      0.98       376

             accuracy                           0.96      1950
            macro avg       0.96      0.96      0.96      1950
         weighted avg       0.96      0.96      0.96      1950


Confusion Matrix :
[[377   4   2   1   1]
 [ 14 374   1   6   0]
 [  4   2 392   0   0]
 [ 11  13   4 368   0]
 [  6   4   2   0 364]]

4. TfIdf (Term Frequency-Inverse Document Frequency) Model

In this section, we have used an approach called term frequency-inverse document frequency (Tf-Idf). This approach creates a float value for each word such that words that appear commonly across all documents get a lower value and words that are specific to a few documents get a higher value. This helps the model focus on the distinctive words that actually contribute to a particular classification category.
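
The weighting can be sketched in a few lines of plain Python. The code below is a hand-written illustration, not scikit-learn's implementation. It assumes the smoothed formula idf = ln((1 + n) / (1 + df)) + 1 followed by L2 normalization of each row, which is our understanding of TfidfVectorizer's default settings. The function name (tfidf_rows) and the toy documents are hypothetical.

```python
import math
from collections import Counter

def tfidf_rows(tokenized_docs):
    n_docs = len(tokenized_docs)
    vocabulary = sorted({tok for doc in tokenized_docs for tok in doc})
    # Document frequency: number of documents that contain each term.
    doc_freq = Counter(tok for doc in tokenized_docs for tok in set(doc))
    # Smoothed inverse document frequency (assumed default formula).
    idf = {tok: math.log((1 + n_docs) / (1 + doc_freq[tok])) + 1.0 for tok in vocabulary}
    rows = []
    for doc in tokenized_docs:
        counts = Counter(doc)
        row = [counts[tok] * idf[tok] for tok in vocabulary]
        # Scale each row to unit length (L2 norm).
        norm = math.sqrt(sum(v * v for v in row)) or 1.0
        rows.append([v / norm for v in row])
    return vocabulary, rows

# Hypothetical, already tokenized toy documents.
docs = [["the", "bike", "engine"], ["the", "mac", "keyboard"], ["the", "crypt", "key"]]
vocabulary, rows = tfidf_rows(docs)

the_weight = rows[0][vocabulary.index("the")]
bike_weight = rows[0][vocabulary.index("bike")]
print(the_weight < bike_weight)  # True
```

The word 'the' appears in every document, so it ends up with a lower weight than 'bike', which appears in only one.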

Please feel free to check our separate tutorial if you want to learn in depth how Tf-Idf works internally. We have explained it in detail over there.

Vectorize Data (Text to List of Floats)

Below, we have vectorized our text data using the Tf-Idf vectorizer available from scikit-learn. We'll be using this vectorized data to train our network.

import sklearn
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer

vectorizer = TfidfVectorizer(max_features=50000)

vectorizer.fit(np.concatenate((X_train_text, X_test_text)))
X_train = vectorizer.transform(X_train_text)
X_test = vectorizer.transform(X_test_text)

X_train = jnp.array(X_train.toarray(), dtype=jnp.float16)
X_test  = jnp.array(X_test.toarray(), dtype=jnp.float16)

X_train.shape, X_test.shape
((2928, 50000), (1950, 50000))

Train Network

Below, we have trained our network using Tf-Idf vectorized data. We can notice that the results are almost the same as with the count vectorizer (final validation accuracy of 0.967 compared to 0.965 for the first count model). There is not much improvement in results.

seed = random.PRNGKey(0)
batch_size=256
epochs=8
learning_rate = jnp.array(1/1e3)

model = TextClassifier()
weights = model.init(seed, X_train[:5])

optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(weights)

final_weights = TrainModelInBatches(X_train, Y_train, X_test, Y_test, epochs, weights, optimizer_state, batch_size=batch_size)
100%|██████████| 12/12 [00:03<00:00,  3.83it/s]
CrossEntropyLoss : 379.083
Validation  Accuracy : 0.782
100%|██████████| 12/12 [00:03<00:00,  3.67it/s]
CrossEntropyLoss : 307.200
Validation  Accuracy : 0.891
100%|██████████| 12/12 [00:03<00:00,  3.55it/s]
CrossEntropyLoss : 195.671
Validation  Accuracy : 0.943
100%|██████████| 12/12 [00:03<00:00,  3.62it/s]
CrossEntropyLoss : 88.351
Validation  Accuracy : 0.960
100%|██████████| 12/12 [00:03<00:00,  3.60it/s]
CrossEntropyLoss : 31.819
Validation  Accuracy : 0.965
100%|██████████| 12/12 [00:03<00:00,  3.55it/s]
CrossEntropyLoss : 12.776
Validation  Accuracy : 0.965
100%|██████████| 12/12 [00:03<00:00,  3.27it/s]
CrossEntropyLoss : 6.534
Validation  Accuracy : 0.965
100%|██████████| 12/12 [00:03<00:00,  3.68it/s]
CrossEntropyLoss : 4.084
Validation  Accuracy : 0.967

Evaluate Network Performance

In this section, we have evaluated the performance of the network by calculating accuracy, classification report, and confusion matrix metrics on the test dataset. We can notice that as usual, the samples of categories 'comp.sys.mac.hardware', 'comp.windows.x' and 'sci.crypt' are confused with each other.

from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

train_preds = model.apply(final_weights, X_train)
test_preds = model.apply(final_weights, X_test)

print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, np.argmax(train_preds, axis=1))))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=selected_categories))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_test, np.argmax(test_preds, axis=1)))
Train Accuracy : 1.000
Test  Accuracy : 0.967

Classification Report :
                       precision    recall  f1-score   support

comp.sys.mac.hardware       0.94      0.97      0.95       385
       comp.windows.x       0.95      0.95      0.95       395
      rec.motorcycles       0.99      0.99      0.99       398
            sci.crypt       0.96      0.94      0.95       396
talk.politics.mideast       1.00      0.99      0.99       376

             accuracy                           0.97      1950
            macro avg       0.97      0.97      0.97      1950
         weighted avg       0.97      0.97      0.97      1950


Confusion Matrix :
[[373   5   2   4   1]
 [ 13 375   0   7   0]
 [  2   2 393   1   0]
 [ 10  11   2 373   0]
 [  0   3   0   2 371]]

5. TfIdf Model (Stop Words Removed)

In this section, we have again used Tf-Idf approach but this time we have removed stop words from the dataset. We'll check whether this helps improve results further.

Vectorize Data (Text to List of Floats)

In this section, we have vectorized our text dataset with Tf-Idf vectorizer from scikit-learn. The majority of the code is repeated from previous sections.

import sklearn
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer

vectorizer = TfidfVectorizer(max_features=50000, stop_words="english")

vectorizer.fit(np.concatenate((X_train_text, X_test_text)))
X_train = vectorizer.transform(X_train_text)
X_test = vectorizer.transform(X_test_text)

X_train = jnp.array(X_train.toarray(), dtype=jnp.float16)
X_test  = jnp.array(X_test.toarray(), dtype=jnp.float16)

X_train.shape, X_test.shape
((2928, 50000), (1950, 50000))

Train Network

Below, we have trained our network on Tf-Idf vectorized data with stop words removed. We can notice from the results that the accuracy (0.970) is the highest of all the approaches we have tried so far.

seed = random.PRNGKey(0)
batch_size=256
epochs=8
learning_rate = jnp.array(1/1e3)

model = TextClassifier()
weights = model.init(seed, X_train[:5])

optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(weights)

final_weights = TrainModelInBatches(X_train, Y_train, X_test, Y_test, epochs, weights, optimizer_state, batch_size=batch_size)
100%|██████████| 12/12 [00:03<00:00,  3.93it/s]
CrossEntropyLoss : 379.706
Validation  Accuracy : 0.843
100%|██████████| 12/12 [00:03<00:00,  3.89it/s]
CrossEntropyLoss : 309.512
Validation  Accuracy : 0.908
100%|██████████| 12/12 [00:03<00:00,  3.95it/s]
CrossEntropyLoss : 197.086
Validation  Accuracy : 0.947
100%|██████████| 12/12 [00:03<00:00,  3.95it/s]
CrossEntropyLoss : 88.640
Validation  Accuracy : 0.963
100%|██████████| 12/12 [00:03<00:00,  3.98it/s]
CrossEntropyLoss : 31.219
Validation  Accuracy : 0.968
100%|██████████| 12/12 [00:03<00:00,  3.96it/s]
CrossEntropyLoss : 12.172
Validation  Accuracy : 0.969
100%|██████████| 12/12 [00:03<00:00,  3.98it/s]
CrossEntropyLoss : 6.143
Validation  Accuracy : 0.969
100%|██████████| 12/12 [00:03<00:00,  3.93it/s]
CrossEntropyLoss : 3.830
Validation  Accuracy : 0.970

Evaluate Network Performance

In this section, we have evaluated the performance of the network by calculating accuracy, classification report, and confusion matrix metrics on the test dataset. As usual, the 3 computer-related categories are confused with each other, but this time the confusion is a little less than in the previous section.

from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

train_preds = model.apply(final_weights, X_train)
test_preds = model.apply(final_weights, X_test)

print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, np.argmax(train_preds, axis=1))))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=selected_categories))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_test, np.argmax(test_preds, axis=1)))
Train Accuracy : 1.000
Test  Accuracy : 0.970

Classification Report :
                       precision    recall  f1-score   support

comp.sys.mac.hardware       0.94      0.98      0.96       385
       comp.windows.x       0.95      0.95      0.95       395
      rec.motorcycles       0.99      0.99      0.99       398
            sci.crypt       0.97      0.94      0.96       396
talk.politics.mideast       1.00      0.99      0.99       376

             accuracy                           0.97      1950
            macro avg       0.97      0.97      0.97      1950
         weighted avg       0.97      0.97      0.97      1950


Confusion Matrix :
[[378   3   0   3   1]
 [ 12 377   1   5   0]
 [  2   2 393   1   0]
 [ 10  13   1 372   0]
 [  0   3   0   2 371]]
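
The per-class figures in the classification report can be recovered directly from the confusion matrix above. The following is a minimal sketch, assuming scikit-learn's convention that rows are true labels and columns are predicted labels; the variable names are ours and the matrix is copied by hand from the output.

```python
import numpy as np

# Confusion matrix copied from the output above
# (rows = true class, columns = predicted class).
cm = np.array([
    [378,   3,   0,   3,   1],
    [ 12, 377,   1,   5,   0],
    [  2,   2, 393,   1,   0],
    [ 10,  13,   1, 372,   0],
    [  0,   3,   0,   2, 371],
])

correct = np.diag(cm)                 # correctly classified samples per class
recall = correct / cm.sum(axis=1)     # divide by row totals (true counts)
precision = correct / cm.sum(axis=0)  # divide by column totals (predicted counts)
accuracy = correct.sum() / cm.sum()   # overall fraction of correct predictions

print("Recall    :", np.round(recall, 2))
print("Precision :", np.round(precision, 2))
print("Accuracy  : {:.3f}".format(accuracy))
```

The lowest recall belongs to sci.crypt (fourth row), whose errors land mostly in the two comp.* columns, which matches the confusion described above.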

6. TfIdf Model (Stop Words Removed + n-Grams (1,3))

In this section, we try our last approach: Tf-Idf vectorization with stop words removed and n-grams in the range (1,3).
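
To make the effect of these settings concrete, here is a small sketch on a toy sentence. The sentence and the names toy_docs and toy_vectorizer are made up for illustration and are not part of the 20 newsgroups data.

```python
from sklearn.feature_extraction.text import TfidfVectorizer

# Toy document, used only to show which features get generated.
toy_docs = ["the engine of the motorcycle stalled"]

# Same settings as in this section, minus the vocabulary size limit.
toy_vectorizer = TfidfVectorizer(stop_words="english", ngram_range=(1, 3))
toy_vectorizer.fit(toy_docs)

# vocabulary_ maps each feature (n-gram) to its column index.
features = sorted(toy_vectorizer.vocabulary_)
for feature in features:
    print(feature)
```

Stop words are dropped before n-grams are formed, so "engine motorcycle" appears as a bigram even though the two words are not adjacent in the original sentence.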

Vectorize Data (Text to List of Floats)

In this section, we have vectorized our data as usual.

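import numpy as np
import jax.numpy as jnp
from sklearn.feature_extraction.text import TfidfVectorizer

# Tf-Idf features over unigrams, bigrams and trigrams with English stop words removed,
# keeping the 50,000 most frequent terms.
vectorizer = TfidfVectorizer(max_features=50000, stop_words="english", ngram_range=(1,3))

# Note: the vocabulary and IDF weights are fitted on train + test text together.
# Fit on X_train_text alone if test documents must stay completely unseen.
vectorizer.fit(np.concatenate((X_train_text, X_test_text)))
X_train = vectorizer.transform(X_train_text)
X_test = vectorizer.transform(X_test_text)

# Convert the sparse matrices to dense float16 JAX arrays.
X_train = jnp.array(X_train.toarray(), dtype=jnp.float16)
X_test  = jnp.array(X_test.toarray(), dtype=jnp.float16)

X_train.shape, X_test.shape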
((2928, 50000), (1950, 50000))

Train Network

Below, we train our network on the vectorized data created in the previous cell. The results show no improvement in accuracy; in fact, it is the lowest of all our approaches.

seed = random.PRNGKey(0)
batch_size=256
epochs=8
learning_rate = jnp.array(1/1e3)

model = TextClassifier()
weights = model.init(seed, X_train[:5])

optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(weights)

final_weights = TrainModelInBatches(X_train, Y_train, X_test, Y_test, epochs, weights, optimizer_state, batch_size=batch_size)
100%|██████████| 12/12 [00:03<00:00,  3.95it/s]
CrossEntropyLoss : 374.707
Validation  Accuracy : 0.885
100%|██████████| 12/12 [00:03<00:00,  3.82it/s]
CrossEntropyLoss : 284.541
Validation  Accuracy : 0.918
100%|██████████| 12/12 [00:03<00:00,  3.58it/s]
CrossEntropyLoss : 157.985
Validation  Accuracy : 0.951
100%|██████████| 12/12 [00:03<00:00,  3.59it/s]
CrossEntropyLoss : 62.078
Validation  Accuracy : 0.955
100%|██████████| 12/12 [00:03<00:00,  3.53it/s]
CrossEntropyLoss : 22.056
Validation  Accuracy : 0.958
100%|██████████| 12/12 [00:03<00:00,  3.37it/s]
CrossEntropyLoss : 9.458
Validation  Accuracy : 0.958
100%|██████████| 12/12 [00:03<00:00,  3.92it/s]
CrossEntropyLoss : 5.102
Validation  Accuracy : 0.957
100%|██████████| 12/12 [00:03<00:00,  3.93it/s]
CrossEntropyLoss : 3.283
Validation  Accuracy : 0.957
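
The log above shows the loss still falling while validation accuracy flattens after the fifth epoch. As a small sketch of how one might read off the best epoch from such a log, the values below are copied by hand from the output; nothing here is part of TrainModelInBatches.

```python
# Validation accuracy per epoch, copied from the training log above.
val_accuracy = [0.885, 0.918, 0.951, 0.955, 0.958, 0.958, 0.957, 0.957]

best_accuracy = max(val_accuracy)
# index() returns the first occurrence, i.e. the earliest epoch reaching the best value.
best_epoch = val_accuracy.index(best_accuracy) + 1  # epochs are counted from 1

print("Best validation accuracy {:.3f} first reached at epoch {}".format(best_accuracy, best_epoch))
```

Note that the validation data passed to the training function here is the test set, so picking an epoch this way would make the reported test accuracy optimistic. A separate validation split would be needed to use this for early stopping.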

Evaluate Network Performance

In this section, we evaluate our network by computing the same metrics on the test dataset. As before, samples from the three computing-related categories are confused with each other.

from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

train_preds = model.apply(final_weights, X_train)
test_preds = model.apply(final_weights, X_test)

print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, np.argmax(train_preds, axis=1))))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=selected_categories))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_test, np.argmax(test_preds, axis=1)))
Train Accuracy : 1.000
Test  Accuracy : 0.957

Classification Report :
                       precision    recall  f1-score   support

comp.sys.mac.hardware       0.92      0.97      0.94       385
       comp.windows.x       0.93      0.94      0.93       395
      rec.motorcycles       0.98      0.98      0.98       398
            sci.crypt       0.98      0.93      0.95       396
talk.politics.mideast       0.99      0.97      0.98       376

             accuracy                           0.96      1950
            macro avg       0.96      0.96      0.96      1950
         weighted avg       0.96      0.96      0.96      1950


Confusion Matrix :
[[372   9   1   2   1]
 [ 16 372   1   6   0]
 [  5   3 390   0   0]
 [  8  15   4 368   1]
 [  5   3   2   1 365]]
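
To quantify how much of the error comes from the three computing-related categories, we can count the off-diagonal entries among them. This is a sketch with the matrix copied by hand from the output above; treating indices 0, 1, and 3 as the "technical" group is our reading of the category names.

```python
import numpy as np

# Confusion matrix copied from the output above (rows = true, columns = predicted).
# Class order: comp.sys.mac.hardware, comp.windows.x, rec.motorcycles,
#              sci.crypt, talk.politics.mideast
cm = np.array([
    [372,   9,   1,   2,   1],
    [ 16, 372,   1,   6,   0],
    [  5,   3, 390,   0,   0],
    [  8,  15,   4, 368,   1],
    [  5,   3,   2,   1, 365],
])

# All misclassified test samples: everything off the main diagonal.
total_errors = cm.sum() - np.trace(cm)

# Sub-matrix restricted to the three technical categories (assumed grouping).
tech = [0, 1, 3]
tech_block = cm[np.ix_(tech, tech)]
tech_errors = tech_block.sum() - np.trace(tech_block)

print("Total misclassified     :", total_errors)
print("Within tech categories  :", tech_errors)
print("Share of all errors     : {:.0%}".format(tech_errors / total_errors))
```

On this run, roughly two thirds of all mistakes are one technical category being predicted as another.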

7. Try Other Approaches to Improve Results Further

  • In our case, we kept vocabulary size to 50k words. We suggest trying different vocabulary sizes.
  • We also suggest trying different neural network architectures.
  • Trying different n-grams with different vocabulary sizes might improve results.
  • Create your own tokenizer to split text documents into lists of words and pass it to scikit-learn vectorizers.
  • Keep uppercase letters as uppercase and check whether it improves accuracy. You can do this by setting the 'lowercase' parameter to False when creating vectorizers.
  • Try using character vectorization instead of word vectorization to see if it helps.
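
As a starting point for the last three suggestions, the sketch below shows how each variant could be configured with scikit-learn's TfidfVectorizer. The toy documents, the whitespace_tokenizer helper, and the n-gram range (2, 4) are illustrative assumptions, not settings we have evaluated.

```python
from sklearn.feature_extraction.text import TfidfVectorizer

# Toy documents (made up for illustration only).
docs = ["X11 window manager crashed", "DES encryption key length"]

# (a) Keep the original casing instead of lowercasing everything.
cased_vectorizer = TfidfVectorizer(lowercase=False)
cased_vectorizer.fit(docs)

# (b) Character n-grams built inside word boundaries instead of word tokens.
char_vectorizer = TfidfVectorizer(analyzer="char_wb", ngram_range=(2, 4))
char_vectorizer.fit(docs)

# (c) A custom tokenizer: here a hypothetical helper that splits on whitespace.
def whitespace_tokenizer(text):
    return text.split()

# token_pattern=None avoids the warning about the unused default pattern.
custom_vectorizer = TfidfVectorizer(tokenizer=whitespace_tokenizer, token_pattern=None)
custom_vectorizer.fit(docs)

print(sorted(cased_vectorizer.vocabulary_))
print(len(char_vectorizer.vocabulary_), "character n-grams")
print(sorted(custom_vectorizer.vocabulary_))
```

Any of these vectorizers can replace the one used in the sections above; the rest of the pipeline (conversion to JAX arrays, training, evaluation) stays the same.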

This ends our small tutorial explaining how we can perform text classification tasks using Flax (JAX). Please feel free to let us know your views in the comments section.

Sunny Solanki
