Document classification, commonly referred to as text classification, is a very common natural language processing (NLP) task. It involves assigning text documents to categories based on their contents. There are various ways to approach this task. In this tutorial, we concentrate on classifying text documents with deep learning techniques based on neural networks.
As a part of this tutorial, we explain how to perform text classification using neural networks designed with the Flax Python library. Flax is a high-level neural network library built on top of JAX, a library for high-performance numerical computing and machine learning research developed by Google research teams. Flax makes designing neural networks considerably easier than working with JAX directly. We can think of Flax as a PyTorch-like layer on top of JAX.
We assume that the reader has a background in Flax, JAX, and text vectorization. Please feel free to check our tutorials on these topics if you want to refresh your knowledge or build a background in them.
We have used the 20 newsgroups dataset available from scikit-learn for this tutorial. We have used vectorizers from scikit-learn to convert text data into arrays of floats, as required by ML algorithms. The tutorial tries different text vectorization approaches to explain them and compare their results. We also suggest that readers go through our tutorial on text vectorization, which explains the process in detail and will help with this tutorial.
Below, we have listed the important sections of the tutorial to give an overview of the material covered.
Below, we have imported the libraries used in this tutorial and printed their versions.
import jax
print("JAX Version : {}".format(jax.__version__))
import flax
print("FLAX Version : {}".format(flax.__version__))
import optax
print("OPTAX Version : {}".format(optax.__version__))
In this section, we have trained a neural network on a vectorized version of our input text, where each document is represented by an array of floats holding the frequency of each vocabulary word in that document. We have used the scikit-learn count vectorizer, which takes text data as input and converts it to these word frequencies. After training, we have evaluated the performance of the network on test data by calculating accuracy, precision, recall, f1-score, and the confusion matrix.
In this section, we have loaded the 20 newsgroups dataset available from scikit-learn. It has around 18,000 newsgroup posts on 20 different topics. We have selected 5 of these 20 topics and loaded only the entries that belong to them. Scikit-learn provides the train and test sets separately through the function fetch_20newsgroups().
import numpy as np
from sklearn import datasets
import gc
all_categories = ['alt.atheism','comp.graphics','comp.os.ms-windows.misc','comp.sys.ibm.pc.hardware',
                  'comp.sys.mac.hardware','comp.windows.x', 'misc.forsale','rec.autos','rec.motorcycles',
                  'rec.sport.baseball','rec.sport.hockey','sci.crypt','sci.electronics','sci.med',
                  'sci.space','soc.religion.christian','talk.politics.guns','talk.politics.mideast',
                  'talk.politics.misc','talk.religion.misc']
selected_categories = ['comp.sys.mac.hardware','comp.windows.x','rec.motorcycles','sci.crypt','talk.politics.mideast']
X_train_text, Y_train = datasets.fetch_20newsgroups(subset="train", categories=selected_categories, return_X_y=True)
X_test_text , Y_test  = datasets.fetch_20newsgroups(subset="test", categories=selected_categories, return_X_y=True)
X_train_text = np.array(X_train_text)
X_test_text = np.array(X_test_text)
classes = np.unique(Y_train)
mapping = dict(zip(classes, selected_categories))
len(X_train_text), len(X_test_text), classes, mapping
Neural networks work on numbers only, so we need a way to convert our text documents to arrays of floats before using them. There are various ways to do this. In this section, we have used the simplest one, which takes the text of a sample as input and returns one float per vocabulary word, where each float is the frequency of that word in the document.
Scikit-learn provides a CountVectorizer estimator that can help us with this task. We first need to fit it on data to populate its vocabulary. In our case, we have fitted it on both the train and test datasets so that the vocabulary has words from both. We have set max_features to 50000, which tells the count vectorizer to keep only the 50,000 most frequent words across the corpus and drop all others. This limits the vocabulary size to 50,000 words. The vocabulary is simply a mapping from words to positions. When we transform an individual text document, we get an array of 50,000 floats in which the position assigned to a word holds the frequency of that word in the document.
If you want to understand the count vectorizer in depth with examples, we suggest that you go through our tutorial on text vectorization, where we have explained in detail how scikit-learn's CountVectorizer and TfidfVectorizer work. It'll help the reader with this tutorial as well.
import sklearn
from jax import numpy as jnp
from sklearn.feature_extraction.text import CountVectorizer
vectorizer = CountVectorizer(max_features=50000)
vectorizer.fit(np.concatenate((X_train_text, X_test_text)))
X_train = vectorizer.transform(X_train_text)
X_test = vectorizer.transform(X_test_text)
X_train = jnp.array(X_train.toarray(), dtype=jnp.float16)
X_test  = jnp.array(X_test.toarray(), dtype=jnp.float16)
X_train.shape, X_test.shape
import gc
gc.collect()
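To make the idea concrete, here is a minimal pure-Python sketch of what a count vectorizer does on a toy corpus. It is a simplification for illustration, not scikit-learn's implementation: the helper names (`tokenize`, `build_vocabulary`, `count_vectorize`) and the toy documents are our own assumptions, and the tokenizer only approximates the default behavior of CountVectorizer (lowercasing and keeping tokens of two or more word characters).

```python
import re
from collections import Counter

import numpy as np


def tokenize(text):
    # Lowercase and keep runs of 2+ word characters (assumed approximation
    # of CountVectorizer's default token pattern).
    return re.findall(r"\b\w\w+\b", text.lower())


def build_vocabulary(docs, max_features=None):
    # Count every token over the whole corpus, keep the most frequent ones,
    # and assign column indices in alphabetical order.
    totals = Counter()
    for doc in docs:
        totals.update(tokenize(doc))
    kept = [tok for tok, _ in totals.most_common(max_features)]
    return {tok: idx for idx, tok in enumerate(sorted(kept))}


def count_vectorize(docs, vocabulary):
    # One row per document, one column per vocabulary word.
    matrix = np.zeros((len(docs), len(vocabulary)), dtype=np.float32)
    for row, doc in enumerate(docs):
        for tok, freq in Counter(tokenize(doc)).items():
            if tok in vocabulary:
                matrix[row, vocabulary[tok]] = freq
    return matrix


# Hypothetical toy corpus, not taken from the 20 newsgroups data.
toy_docs = [
    "the mac has a fast disk",
    "the disk on the mac failed",
    "encryption keys keep the disk safe",
]
vocab = build_vocabulary(toy_docs, max_features=3)
counts = count_vectorize(toy_docs, vocab)
print(vocab)
print(counts)
```

Words outside the three most frequent ones are simply dropped, which is the same effect max_features=50000 has on the real dataset at a larger scale.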
In this section, we have defined our neural network using Flax. The network consists of 3 linear/dense layers. The first layer has 128 units, the second layer has 64 units, and the final layer has as many units as there are target classes, which is 5 in our case (5 topics). We apply relu (rectified linear unit) activation to the output of each layer except the last one, so the network returns raw logits. We have created the network as a class that extends the linen.Module class. The layers are defined in the setup() method and the forward pass is defined in the __call__() method.
After defining the class, we have initialized the network and its parameters (weights and biases of the layers) and printed the shapes of the weights of each layer. We have also performed a forward pass through the network using a few train samples to verify that it works as expected.
from flax import linen
from jax import random
class TextClassifier(linen.Module):
    def setup(self):
        self.linear1 = linen.Dense(features=128, name="DENSE1")
        self.linear2 = linen.Dense(features=64, name="DENSE2")
        self.linear3 = linen.Dense(len(classes), name="DENSE3")
    def __call__(self, inputs):
        x = linen.relu(self.linear1(inputs))
        x = linen.relu(self.linear2(x))
        logits = self.linear3(x)
        return logits #linen.softmax(x)
seed = jax.random.PRNGKey(0)
model = TextClassifier()
params = model.init(seed, X_train[:5])
for layer_params in params["params"].items():
    print("Layer Name : {}".format(layer_params[0]))
    weights, biases = layer_params[1]["kernel"], layer_params[1]["bias"]
    print("\tLayer Weights : {}, Biases : {}".format(weights.shape, biases.shape))
preds = model.apply(params, X_train[:5])
preds.shape
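As a rough illustration of what the three dense layers compute, the forward pass can be sketched in plain NumPy. This is only a sketch under stated assumptions: the vocabulary size is shrunk to 1000 to keep it light, the weights are drawn from a small normal distribution rather than with Flax's default initializers, and the helper names (`init_dense`, `forward`) are ours.

```python
import numpy as np

rng = np.random.default_rng(0)


def init_dense(in_features, out_features):
    # Assumed simple initialization; Flax uses its own default initializers.
    kernel = rng.normal(scale=0.01, size=(in_features, out_features)).astype(np.float32)
    bias = np.zeros(out_features, dtype=np.float32)
    return {"kernel": kernel, "bias": bias}


def relu(x):
    return np.maximum(x, 0.0)


def forward(params, inputs):
    # Each dense layer is a matrix product with the kernel plus a bias.
    x = relu(inputs @ params["DENSE1"]["kernel"] + params["DENSE1"]["bias"])
    x = relu(x @ params["DENSE2"]["kernel"] + params["DENSE2"]["bias"])
    # No activation on the last layer: the outputs are raw logits.
    return x @ params["DENSE3"]["kernel"] + params["DENSE3"]["bias"]


vocab_size, n_classes = 1000, 5  # vocab_size is reduced for the sketch
params = {
    "DENSE1": init_dense(vocab_size, 128),
    "DENSE2": init_dense(128, 64),
    "DENSE3": init_dense(64, n_classes),
}
# Fake word counts for 5 documents.
batch = rng.integers(0, 3, size=(5, vocab_size)).astype(np.float32)
logits = forward(params, batch)
print(logits.shape)
```

The kernel of the first layer has one row per vocabulary word, which is why the first layer holds most of the parameters of the network.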
Below, we have defined the loss function for our multi-class text classification task. We'll be using cross entropy loss for our purpose. The function takes network parameters, input data, and the actual target values of that data as input. It performs a forward pass through the network using the parameters and input data to get logits, one-hot encodes the target values, and then returns the cross entropy loss summed over all samples of the batch.
def CrossEntropyLoss(weights, input_data, actual):
    logits_preds = model.apply(weights, input_data)
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(classes))
    return optax.softmax_cross_entropy(logits=logits_preds, labels=one_hot_actual).sum()
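For readers who want to see the arithmetic behind the loss, here is a small NumPy sketch of softmax cross entropy summed over a batch. It is our own illustrative re-implementation (the function name `softmax_cross_entropy_sum` and the toy logits are assumptions), not the optax source code.

```python
import numpy as np


def softmax_cross_entropy_sum(logits, labels, num_classes):
    # One-hot encode the integer labels.
    one_hot = np.eye(num_classes)[labels]
    # Log-softmax computed in a numerically stable way.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Cross entropy per sample, then summed over the batch.
    per_sample = -(one_hot * log_probs).sum(axis=1)
    return per_sample.sum()


# Toy logits for 2 samples and 3 classes.
logits = np.array([[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
labels = np.array([0, 2])
loss = softmax_cross_entropy_sum(logits, labels, num_classes=3)
print(loss)
```

Because the loss is summed rather than averaged over the batch, its magnitude grows with the batch size, which is worth keeping in mind when comparing the printed loss values across different batch sizes.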
In this section, we have first defined a function that we'll be using for training our network in all sections of the tutorial. The function takes train data features (X), train target values (Y), validation data (X_val, Y_val), number of epochs, network parameters (weights), optimizer state, and batch size as input. It then runs the training loop for the given number of epochs. In each epoch, it goes through the training data in batches. For every batch, it performs a forward pass to make predictions, calculates the loss and gradients, and updates the network parameters. At the end of each epoch, it prints the mean training loss over the batches, calculates accuracy on the validation data, and prints it as well. Note that the function uses the global optimizer object rather than taking it as an argument, so the optimizer must be defined before the function is called.
from jax import value_and_grad
from tqdm import tqdm
from sklearn.metrics import accuracy_score
def TrainModelInBatches(X, Y, X_val, Y_val, epochs, weights, optimizer_state, batch_size=32):
    for i in range(1, epochs+1):
        batches = jnp.arange((X.shape[0]//batch_size)+1) ### Batch Indices
        losses = [] ## Record loss of each batch
        for batch in tqdm(batches):
            if batch != batches[-1]:
                start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
            else:
                start, end = int(batch*batch_size), None
            X_batch, Y_batch = X[start:end], Y[start:end] ## Single batch of data
            loss, gradients = value_and_grad(CrossEntropyLoss)(weights, X_batch,Y_batch)
            ## Update Weights
            updates, optimizer_state = optimizer.update(gradients, optimizer_state)
            weights = optax.apply_updates(weights, updates)
            losses.append(loss) ## Record Loss
        print("CrossEntropyLoss : {:.3f}".format(jnp.array(losses).mean()))
        Y_val_preds = model.apply(weights, X_val)
        val_acc = accuracy_score(Y_val, jnp.argmax(Y_val_preds, axis=1))
        print("Validation  Accuracy : {:.3f}".format(val_acc))
    return weights
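The batch slicing used in the loop above can be isolated into a small sketch to see which index ranges it produces. The two helper functions below are hypothetical names of ours. The first mirrors the slicing logic of the training function, and the second is one possible alternative that steps through the data directly.

```python
def batch_bounds_tutorial(n_samples, batch_size):
    # Mirrors the training loop: n // batch_size + 1 batches, and the last
    # batch runs to the end of the data (end=None).
    n_batches = n_samples // batch_size + 1
    bounds = []
    for batch in range(n_batches):
        start = batch * batch_size
        end = start + batch_size if batch != n_batches - 1 else None
        bounds.append((start, end))
    return bounds


def batch_bounds_stepped(n_samples, batch_size):
    # Alternative: step through the data so that no empty batch is produced.
    return [
        (start, min(start + batch_size, n_samples))
        for start in range(0, n_samples, batch_size)
    ]


print(batch_bounds_tutorial(10, 4))  # last batch holds the 2 leftover samples
print(batch_bounds_tutorial(8, 4))   # last slice [8:None] is empty
print(batch_bounds_stepped(8, 4))    # no empty batch
```

When the number of samples is an exact multiple of the batch size, the original logic yields one extra empty batch. Its summed loss is zero, so it slightly lowers the mean loss printed for that epoch. With other dataset sizes the two approaches cover the same samples.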
Below, we have trained our network using the function created in the previous cell. We have set the batch size to 256, the number of epochs to 8, and the learning rate to 0.001. We have then initialized the network and its parameters. Next, we have initialized the Adam optimizer for updating network parameters and created an optimizer state from the network parameters. Finally, we have called our training function with the necessary parameters to train the network.
We can notice from the training loss and validation accuracy printed at the end of each epoch that our model has done a decent job at the task.
seed = random.PRNGKey(0)
batch_size=256
epochs=8
learning_rate = jnp.array(1/1e3)
model = TextClassifier()
weights = model.init(seed, X_train[:5])
optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(weights)
final_weights = TrainModelInBatches(X_train, Y_train, X_test, Y_test, epochs, weights, optimizer_state, batch_size=batch_size)
In this section, we have evaluated the performance of our network. We have calculated train and test accuracies, the precision, recall, and f1-score per target class, and the confusion matrix. We can notice from the confusion matrix that some samples of sci.crypt and comp.windows.x are confused with the category comp.sys.mac.hardware. Due to that, the recall and f1-score for those two categories are a bit lower compared to the others.
If you want to know in detail about the various ML metrics available through scikit-learn, please check our tutorial on scikit-learn metrics, where we have covered the majority of them in detail.
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
train_preds = model.apply(final_weights, X_train)
test_preds = model.apply(final_weights, X_test)
print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, np.argmax(train_preds, axis=1))))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=selected_categories))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_test, np.argmax(test_preds, axis=1)))
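The per-class precision and recall shown in the classification report can be derived directly from a confusion matrix. The sketch below shows this in NumPy, assuming the scikit-learn convention that rows are actual classes and columns are predicted classes. The function name `per_class_metrics` and the toy matrix are made up for illustration and are not results from this tutorial.

```python
import numpy as np


def per_class_metrics(conf_mat):
    conf_mat = np.asarray(conf_mat, dtype=np.float64)
    true_pos = np.diag(conf_mat)
    # Column sums count how often each class was predicted.
    precision = true_pos / conf_mat.sum(axis=0)
    # Row sums count how many samples actually belong to each class.
    recall = true_pos / conf_mat.sum(axis=1)
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


# Made-up 3-class confusion matrix (rows: actual, columns: predicted).
toy_conf_mat = [
    [8, 2, 0],
    [1, 9, 0],
    [1, 0, 9],
]
precision, recall, f1 = per_class_metrics(toy_conf_mat)
accuracy = np.trace(np.asarray(toy_conf_mat)) / np.sum(toy_conf_mat)
print(precision, recall, f1, accuracy)
```

In this toy matrix, the first column contains off-diagonal entries from the other two classes, which lowers the precision of the first class and the recall of the classes that leak into it. This is the same pattern described above for the first column of our confusion matrix.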
In this section, we have again trained our neural network using the same word frequency approach that we used in the previous section. This time, however, we have removed commonly used stop words (like 'the', 'a', 'an', 'then', etc.) from our vocabulary. These words appear in almost all documents regardless of topic, so they carry little information about the category and will now be ignored. To make the model more accurate, we want it to rely on words that are characteristic of a particular category.
Below, we have vectorized our text data using the count vectorizer from scikit-learn. The code is almost exactly the same as in the previous section with one minor change: we have set the parameter 'stop_words' to 'english', which removes common English stop words.
import sklearn
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
vectorizer = CountVectorizer(max_features=50000, stop_words="english")
vectorizer.fit(np.concatenate((X_train_text, X_test_text)))
X_train = vectorizer.transform(X_train_text)
X_test = vectorizer.transform(X_test_text)
X_train = jnp.array(X_train.toarray(), dtype=jnp.float16)
X_test  = jnp.array(X_test.toarray(), dtype=jnp.float16)
X_train.shape, X_test.shape
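The effect of stop word removal on word counts can be seen in a small pure-Python sketch. The stop word set below is a tiny hypothetical subset chosen for the example, not the list scikit-learn uses, and the helper names are ours.

```python
import re
from collections import Counter

# Hypothetical, tiny stop word set for illustration only.
TOY_STOP_WORDS = frozenset({"the", "a", "an", "then", "is", "on", "of"})


def tokenize(text):
    # Lowercase and keep tokens of two or more word characters.
    return re.findall(r"\b\w\w+\b", text.lower())


def count_tokens(text, stop_words=frozenset()):
    # Count tokens, skipping any that appear in the stop word set.
    return Counter(tok for tok in tokenize(text) if tok not in stop_words)


doc = "The key is on the disk then the disk is encrypted"
with_stop = count_tokens(doc)
without_stop = count_tokens(doc, TOY_STOP_WORDS)
print(with_stop.most_common(3))
print(without_stop.most_common(3))
```

Without filtering, the most frequent token is a stop word that says nothing about the topic. After filtering, only topic-related words remain in the counts.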
In this section, we have trained our network on the vectorized dataset. All the training settings are the same as in the previous section. Only the vectorized input data has changed. We can notice from the results that the accuracy has improved a little after we removed stop words.
seed = random.PRNGKey(0)
batch_size=256
epochs=8
learning_rate = jnp.array(1/1e3)
model = TextClassifier()
weights = model.init(seed, X_train[:5])
optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(weights)
final_weights = TrainModelInBatches(X_train, Y_train, X_test, Y_test, epochs, weights, optimizer_state, batch_size=batch_size)
In this section, we have evaluated the performance of our network by calculating accuracy, the classification report, and the confusion matrix. We can notice from the first column of the confusion matrix that a few samples of the categories 'comp.windows.x' and 'sci.crypt' are confused with the category 'comp.sys.mac.hardware'. A few samples of 'comp.windows.x' and 'sci.crypt' are also confused with each other. All 3 of these categories are related to computers and can have overlapping keywords that could confuse the model.
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
train_preds = model.apply(final_weights, X_train)
test_preds = model.apply(final_weights, X_test)
print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, np.argmax(train_preds, axis=1))))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=selected_categories))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_test, np.argmax(test_preds, axis=1)))
In this section, we take our word frequency approach further and try combinations of 2 and 3 words along with single words. All our previous approaches vectorized text by taking into account only single words, which are generally referred to as 1-grams (unigrams). We can also use sequences of 2 and 3 consecutive words (2-grams and 3-grams) to capture a little contextual detail. We'll check whether this approach helps improve accuracy further.
Below, we have vectorized our data using the scikit-learn count vectorizer. The code of this section is exactly the same as in the previous section with one minor change: we have set the parameter 'ngram_range' of the count vectorizer to (1,3). This instructs it to consider all 1-gram, 2-gram, and 3-gram word combinations.
import sklearn
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
vectorizer = CountVectorizer(max_features=50000, stop_words="english", ngram_range=(1,3))
vectorizer.fit(np.concatenate((X_train_text, X_test_text)))
X_train = vectorizer.transform(X_train_text)
X_test = vectorizer.transform(X_test_text)
X_train = jnp.array(X_train.toarray(), dtype=jnp.float16)
X_test  = jnp.array(X_test.toarray(), dtype=jnp.float16)
X_train.shape, X_test.shape
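To see what an n-gram range of (1,3) produces, here is a short pure-Python sketch that generates word n-grams from a list of tokens. The function name `word_ngrams` and the example tokens are our own, and the sketch only illustrates the idea rather than reproducing scikit-learn's code.

```python
def word_ngrams(tokens, ngram_range=(1, 3)):
    # Build all runs of n consecutive tokens for every n in the range.
    min_n, max_n = ngram_range
    grams = []
    for n in range(min_n, max_n + 1):
        for i in range(len(tokens) - n + 1):
            grams.append(" ".join(tokens[i:i + n]))
    return grams


tokens = ["public", "key", "encryption", "works"]
grams = word_ngrams(tokens)
print(grams)
```

Four tokens yield 4 one-word, 3 two-word, and 2 three-word entries. Since max_features is still 50000, these longer n-grams compete with single words for the same vocabulary slots, which may be one reason why adding them does not automatically improve accuracy.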
Below, we have trained our network using the vectorized data we created in the previous cell. We can notice from the results that this approach does not bring much improvement in accuracy.
seed = random.PRNGKey(0)
batch_size=256
epochs=8
learning_rate = jnp.array(1/1e3)
model = TextClassifier()
weights = model.init(seed, X_train[:5])
optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(weights)
final_weights = TrainModelInBatches(X_train, Y_train, X_test, Y_test, epochs, weights, optimizer_state, batch_size=batch_size)
In this section, we have evaluated the performance of our network by calculating accuracy, the classification report, and the confusion matrix on the test dataset. We can notice from the results that samples of the categories 'comp.sys.mac.hardware', 'comp.windows.x' and 'sci.crypt' are confused with each other.
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
train_preds = model.apply(final_weights, X_train)
test_preds = model.apply(final_weights, X_test)
print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, np.argmax(train_preds, axis=1))))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=selected_categories))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_test, np.argmax(test_preds, axis=1)))
In this section, we have used an approach called term frequency-inverse document frequency (Tf-Idf). This approach creates a float value for each word such that words that appear commonly across all documents get a lower value and words that are distinctive to a document get a higher value. This helps highlight the words that actually contribute to a particular classification category.
Please feel free to check our tutorial on text vectorization if you want to learn in depth how Tf-Idf works internally. We have explained it in detail over there.
Below, we have vectorized our text data using the Tf-Idf vectorizer available from scikit-learn. We'll be using this vectorized data to train our network.
import sklearn
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
vectorizer = TfidfVectorizer(max_features=50000)
vectorizer.fit(np.concatenate((X_train_text, X_test_text)))
X_train = vectorizer.transform(X_train_text)
X_test = vectorizer.transform(X_test_text)
X_train = jnp.array(X_train.toarray(), dtype=jnp.float16)
X_test  = jnp.array(X_test.toarray(), dtype=jnp.float16)
X_train.shape, X_test.shape
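The weighting itself can be sketched in a few lines of NumPy starting from a matrix of word counts. The sketch assumes the smoothed formula idf = ln((1 + n) / (1 + df)) + 1 followed by L2 normalization of each row, which is our understanding of the TfidfVectorizer defaults. The function name `tfidf_from_counts` and the toy counts are our own.

```python
import numpy as np


def tfidf_from_counts(counts):
    counts = np.asarray(counts, dtype=np.float64)
    n_docs = counts.shape[0]
    # Document frequency: in how many documents each word appears.
    doc_freq = (counts > 0).sum(axis=0)
    # Smoothed inverse document frequency (assumed default formula).
    idf = np.log((1 + n_docs) / (1 + doc_freq)) + 1.0
    weighted = counts * idf
    # Scale each document vector to unit L2 norm.
    norms = np.linalg.norm(weighted, axis=1, keepdims=True)
    norms[norms == 0] = 1.0  # avoid dividing empty documents by zero
    return weighted / norms, idf


# Toy counts for 3 documents; columns stand for the words: disk, mac, the.
toy_counts = [
    [1, 1, 1],
    [1, 1, 2],
    [1, 0, 1],
]
tfidf, idf = tfidf_from_counts(toy_counts)
print(idf)
print(tfidf)
```

The word 'mac' appears in only two of the three documents, so it gets a higher idf than 'disk' and 'the', which appear in all of them.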
Below, we have trained our network using Tf-Idf vectorized data. We can notice that the results are almost the same as with the count vectorizer. There is not much improvement.
seed = random.PRNGKey(0)
batch_size=256
epochs=8
learning_rate = jnp.array(1/1e3)
model = TextClassifier()
weights = model.init(seed, X_train[:5])
optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(weights)
final_weights = TrainModelInBatches(X_train, Y_train, X_test, Y_test, epochs, weights, optimizer_state, batch_size=batch_size)
In this section, we have evaluated the performance of the network by calculating accuracy, the classification report, and the confusion matrix on the test dataset. We can notice that, as before, samples of the categories 'comp.sys.mac.hardware', 'comp.windows.x' and 'sci.crypt' are confused with each other.
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
train_preds = model.apply(final_weights, X_train)
test_preds = model.apply(final_weights, X_test)
print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, np.argmax(train_preds, axis=1))))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=selected_categories))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_test, np.argmax(test_preds, axis=1)))
In this section, we have again used the Tf-Idf approach, but this time we have removed stop words from the dataset. We'll check whether this helps improve results further.
Below, we have vectorized our text dataset with the Tf-Idf vectorizer from scikit-learn, this time setting 'stop_words' to 'english'. The majority of the code is repeated from previous sections.
import sklearn
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
vectorizer = TfidfVectorizer(max_features=50000, stop_words="english")
vectorizer.fit(np.concatenate((X_train_text, X_test_text)))
X_train = vectorizer.transform(X_train_text)
X_test = vectorizer.transform(X_test_text)
X_train = jnp.array(X_train.toarray(), dtype=jnp.float16)
X_test  = jnp.array(X_test.toarray(), dtype=jnp.float16)
X_train.shape, X_test.shape
Below, we have trained our network on Tf-Idf vectorized data with stop words removed. We can notice from the results that the accuracy is the highest of all the approaches we have tried so far.
seed = random.PRNGKey(0)
batch_size=256
epochs=8
learning_rate = jnp.array(1/1e3)
model = TextClassifier()
weights = model.init(seed, X_train[:5])
optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(weights)
final_weights = TrainModelInBatches(X_train, Y_train, X_test, Y_test, epochs, weights, optimizer_state, batch_size=batch_size)
In this section, we have evaluated the performance of the network by calculating accuracy, the classification report, and the confusion matrix on the test dataset. As before, the 3 computer-related categories are confused with each other, but this time the confusion is a little lower compared to previous approaches.
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
train_preds = model.apply(final_weights, X_train)
test_preds = model.apply(final_weights, X_test)
print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, np.argmax(train_preds, axis=1))))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=selected_categories))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_test, np.argmax(test_preds, axis=1)))
In this section, we have tried our last approach, which uses Tf-Idf vectorization with stop words removed and n-grams in the range (1,3).
Below, we have vectorized our data in the same way as in previous sections, with 'ngram_range' set to (1,3).
import sklearn
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
vectorizer = TfidfVectorizer(max_features=50000, stop_words="english", ngram_range=(1,3))
vectorizer.fit(np.concatenate((X_train_text, X_test_text)))
X_train = vectorizer.transform(X_train_text)
X_test = vectorizer.transform(X_test_text)
X_train = jnp.array(X_train.toarray(), dtype=jnp.float16)
X_test  = jnp.array(X_test.toarray(), dtype=jnp.float16)
X_train.shape, X_test.shape
Below, we have trained our network using the vectorized data created in the previous cell. We can notice from the results that accuracy has not improved. Instead, it is the lowest of all our approaches.
seed = random.PRNGKey(0)
batch_size=256
epochs=8
learning_rate = jnp.array(1/1e3)
model = TextClassifier()
weights = model.init(seed, X_train[:5])
optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(weights)
final_weights = TrainModelInBatches(X_train, Y_train, X_test, Y_test, epochs, weights, optimizer_state, batch_size=batch_size)
In this section, we have evaluated the performance of our network by calculating various metrics on our test dataset. As before, samples of the three computer-related categories are confused with each other.
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
train_preds = model.apply(final_weights, X_train)
test_preds = model.apply(final_weights, X_test)
print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, np.argmax(train_preds, axis=1))))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=selected_categories))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_test, np.argmax(test_preds, axis=1)))
This ends our small tutorial explaining how we can perform text classification tasks using Flax (JAX). Please feel free to let us know your views in the comments section.
 Sunny Solanki