Word embeddings are one of the most popular text encoding approaches for solving NLP tasks with deep learning algorithms. Other encoding approaches like word frequency, one-hot, Tf-Idf, etc. assign just a single scalar value to each word, which gives them very low representation power. They cannot capture the different meanings a word takes in different contexts, nor its relationships with other words. Word embeddings, on the other hand, represent each word with a real-valued vector. The size of this vector can vary quite a bit: it can be a vector of 10, 15, 20, 50, 100, 200, 300, etc. scalars. This gives the approach much more representation power, which deep learning algorithms can use to produce better results. We can generate embeddings for the words of our dataset by training them as part of the network, or, if our dataset is small, we can reuse embeddings already trained by some other network.
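To make the difference concrete, the short sketch below contrasts a one-hot style encoding with a dense embedding for the same toy vocabulary. All vocabulary entries and vector values here are made up purely for illustration and are not taken from any trained model.

import numpy as np

## Toy vocabulary with integer indexes (illustrative only).
vocab = {"cat": 0, "dog": 1, "car": 2}

## One-hot encoding: one sparse vector per word, with no notion of similarity between words.
one_hot = np.eye(len(vocab))
print(one_hot[vocab["cat"]])          ## [1. 0. 0.]

## Embedding encoding: one dense real-valued vector per word (values are made up).
embedding_matrix = np.array([[ 0.21, -0.40, 0.75,  0.05],   ## "cat"
                             [ 0.19, -0.38, 0.70,  0.11],   ## "dog" (close to "cat")
                             [-0.60,  0.52, 0.02, -0.33]])  ## "car" (far from both)
print(embedding_matrix[vocab["dog"]])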
GloVe (Global Vectors) is an unsupervised algorithm that can generate embeddings for words. Stanford researchers have already generated GloVe word embeddings for a very large number of words by training the algorithm on large corpora. They provide embeddings of different lengths trained on Twitter and Wikipedia datasets. GloVe embeddings are a very good option if your dataset is too small to train good word embeddings on it. Please feel free to check the below link if you are looking for detailed information on GloVe.
As a part of this tutorial, we have explained how to use GloVe word embeddings with Flax networks for text classification tasks. Flax is a Python deep learning library designed on top of the low-level deep learning library JAX. We have tried different ways of using GloVe embeddings in the tutorial and compared their results. After training the networks, we have also evaluated their performance by calculating various ML metrics. We have even explained the predictions made by the networks using the LIME (Local Interpretable Model-Agnostic Explanations) algorithm.
Please check the below link if you are looking to train your own word embeddings using Flax (JAX) Networks.
Below, we have listed the important sections of the tutorial to give an overview of the material covered.
Below, we have imported the necessary Python libraries that we have used in our tutorial and printed their versions for reference purposes.
import jax
print("JAX Version : {}".format(jax.__version__))
import flax
print("Flax Version : {}".format(flax.__version__))
import optax
print("OPTAX Version : {}".format(optax.__version__))
import torchtext
print("Torchtext Version : {}".format(torchtext.__version__))
from tensorflow import keras
print("Keras Version : {}".format(keras.__version__))
import gc
In this section, we'll prepare data for the neural network. As we are going to use the word embeddings approach to encode text data, and we have decided to use GloVe word embeddings for our purpose, we'll be following the below steps for encoding text data.

1. Build a vocabulary of all unique tokens (words) present in our text corpus, mapping each token to an integer index.
2. Tokenize each text example and convert it to a fixed-length list of token indexes using this vocabulary.
3. Create an embedding matrix whose row at each token's index holds the GloVe embedding of that token.
4. Use an embedding layer in the network to map token indexes to their embeddings by indexing this matrix.
So basically, we first convert our text examples to a list of token indexes and then retrieve GloVe embeddings for them by indexing the embedding matrix.
The first 3 steps mentioned above will be performed in this section. The 4th step will be implemented at the embedding layer in the neural network which will map token indexes to their embeddings. We'll set the embedding matrix as the weight matrix of the embedding layer.
Below, we have included an image giving an idea about word embeddings.
In this section, we have simply downloaded GloVe embeddings from the Stanford website and unzipped them. We have downloaded GloVe 840B.300d embeddings that have word embeddings of length 300 for 2.2 Million tokens.
!wget https://nlp.stanford.edu/data/glove.840B.300d.zip
!unzip glove.840B.300d.zip
In this section, we have simply loaded word embeddings in memory. We have created a Python dictionary whose keys are tokens and values are numpy arrays representing embeddings of those tokens.
%%time
import numpy as np
glove_embeddings = {}
with open("glove.840B.300d.txt") as f:
for line in f:
try:
line = line.split()
glove_embeddings[line[0]] = np.array(line[1:], dtype=np.float32)
except:
continue
embeddings = glove_embeddings["the"]
embeddings.shape, embeddings.dtype
gc.collect()
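As a quick, hedged sanity check of what these vectors capture, the short sketch below compares a few words using cosine similarity. It assumes the tokens 'king', 'queen', and 'car' are present in the GloVe dictionary loaded above (they are part of the 840B.300d vocabulary).

import numpy as np

def cosine_similarity(a, b):
    ## Cosine of the angle between two embedding vectors.
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

print(cosine_similarity(glove_embeddings["king"], glove_embeddings["queen"]))  ## Relatively high
print(cosine_similarity(glove_embeddings["king"], glove_embeddings["car"]))    ## Noticeably lower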
In this section, we have loaded the dataset that we are going to use for our text classification task. We'll be using the 20 newsgroups dataset available from scikit-learn. The dataset has text documents for 20 different news categories, though we'll be using only 5 categories for our purpose. The dataset is already divided into train and test sets for our convenience.
import numpy as np
from sklearn import datasets
import gc
all_categories = ['alt.atheism','comp.graphics','comp.os.ms-windows.misc','comp.sys.ibm.pc.hardware',
'comp.sys.mac.hardware','comp.windows.x', 'misc.forsale','rec.autos','rec.motorcycles',
'rec.sport.baseball','rec.sport.hockey','sci.crypt','sci.electronics','sci.med',
'sci.space','soc.religion.christian','talk.politics.guns','talk.politics.mideast',
'talk.politics.misc','talk.religion.misc']
target_classes = ['alt.atheism', 'comp.graphics','rec.autos','rec.sport.hockey','talk.politics.mideast']
X_train_text, Y_train = datasets.fetch_20newsgroups(subset="train", categories=target_classes, return_X_y=True)
X_test_text , Y_test = datasets.fetch_20newsgroups(subset="test", categories=target_classes, return_X_y=True)
classes = np.unique(Y_train)
mapping = dict(zip(classes, target_classes))
len(X_train_text), len(X_test_text), classes, mapping
In this section, we have completed the first two steps of the embedding process that we explained earlier. We have populated the vocabulary and transformed text examples into lists of token indexes using this vocabulary.
First, we have created an instance of Tokenizer() available from the Python deep learning library Keras. We have called the fit_on_texts() method of the tokenizer with the train and test examples. This populates the vocabulary of all unique tokens and keeps it in the tokenizer object. The vocabulary can be accessed through the word_index property of the tokenizer object. We have also printed the size of the vocabulary later to show the number of tokens.
Next, we have called the texts_to_sequences() method with the train and test examples separately. This method tokenizes each text example into a list of tokens (words) and retrieves their indexes from the vocabulary, so its output is a list of token indexes per example. Different text examples can have different numbers of tokens, but the neural network needs inputs of a fixed size. For this, we have decided to keep a maximum of 50 tokens per text example. To enforce this, we have called the pad_sequences() method on the output of texts_to_sequences(). It ensures that every example has exactly 50 tokens: examples with more than 50 tokens are truncated at 50, and examples with fewer than 50 tokens are padded with 0s.
We have also printed the first few train examples to show how text gets encoded to a list of token indexes.
Below, we have explained with a simple example how a text example is vectorized.
text = "Hello, How are you? Where are you planning to go?"
tokens = ['hello', ',', 'how', 'are', 'you', '?', 'where',
'are', 'you', 'planning', 'to', 'go', '?']
vocab = {
'hello': 0,
'bye': 1,
'how': 2,
'the': 3,
'welcome': 4,
'are': 5,
'you': 6,
'to': 7,
'<unk>': 8,
}
vector = [0,8,2,5,6,8,8,5,6,8,7,8,8]
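For the padding/truncation part, the tiny sketch below (using a made-up maxlen and toy index lists, not the real vocabulary) shows what pad_sequences() does to sequences shorter and longer than the chosen length.

from keras.preprocessing.sequence import pad_sequences

toy_sequences = [[0, 8, 2, 5, 6],                          ## shorter than maxlen -> padded with 0s at the end
                 [0, 8, 2, 5, 6, 8, 8, 5, 6, 8, 7, 8, 8]]  ## longer than maxlen -> truncated at the end
print(pad_sequences(toy_sequences, maxlen=8, padding="post", truncating="post", value=0.))
## [[0 8 2 5 6 0 0 0]
##  [0 8 2 5 6 8 8 5]]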
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from jax import numpy as jnp
max_tokens = 50 ## Hyperparameter
tokenizer = Tokenizer()
tokenizer.fit_on_texts(X_train_text+X_test_text)
## Vectorizing data to keep 50 words per sample.
X_train_vect = pad_sequences(tokenizer.texts_to_sequences(X_train_text), maxlen=max_tokens, padding="post", truncating="post", value=0.)
X_test_vect = pad_sequences(tokenizer.texts_to_sequences(X_test_text), maxlen=max_tokens, padding="post", truncating="post", value=0.)
print(X_train_vect[:3])
X_train_vect, X_test_vect = jnp.array(X_train_vect, dtype=jnp.int32), jnp.array(X_test_vect, dtype=jnp.int32)
Y_train, Y_test = jnp.array(Y_train), jnp.array(Y_test)
X_train_vect.shape, X_test_vect.shape
print("Vocab Size : {}".format(len(tokenizer.word_index)))
## Which word has index 13 in the vocabulary?
print(tokenizer.index_word[13])
## How many times does it appear in the first text document?
print(X_train_text[0]) ## It appears once
In this section, we have created an embedding matrix using GloVe embeddings, as described in the 3rd step of our encoding process earlier.
Here, we are looping through tokens of our vocabulary and retrieving GloVe embedding for each token from the glove embedding dictionary that we had loaded earlier. We have stacked embeddings for each token to create an embedding matrix. As we know that the embedding length for GloVe 840B.300d is 300, our embedding matrix shape will be (vocab_len, 300).
We can retrieve the embedding of a token by indexing this matrix with the token's index. To explain it with an example: if the index of the token 'how' in our vocabulary is 10, then we can simply index the embedding matrix as 'word_embeddings[10]' to retrieve the embedding of 'how'.
Later on, we'll set this matrix as the weight matrix of the embedding layer of the network which will take token indexes as input and index this matrix to retrieve embeddings for tokens.
%%time
embed_len = 300
word_embeddings = np.zeros((len(tokenizer.index_word)+1, embed_len))
for idx, word in tokenizer.index_word.items():
    ## Tokens missing from GloVe get an all-zeros embedding.
    word_embeddings[idx] = glove_embeddings.get(word, np.zeros(embed_len))
word_embeddings = jnp.array(word_embeddings)
word_embeddings[1][:10]
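As a small, hedged check (assuming the token 'the' is present both in our vocabulary and in the GloVe dictionary, which it is for this dataset), we can verify that a row of the matrix really matches the corresponding GloVe vector:

idx = tokenizer.word_index["the"]  ## Index of the token "the" in our vocabulary
print(np.allclose(word_embeddings[idx], glove_embeddings["the"]))  ## Expected: True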
Our approach in this section is based on the concept of flattening embeddings. The flattened embeddings are then given to dense layers for processing. After training the network in this section, we have also evaluated various ML metrics for checking the performance of the network.
Below, we have defined the network that we are going to use for our text classification task. The network consists of one embedding layer and two dense layers.
The first layer of our network is the embedding layer. We have created it using the Embed() constructor from the linen module of Flax. We have given it the vocabulary length as the number of embeddings and set the features parameter to the embedding length of 300. We have initialized the weights of this layer by providing a callable to the embedding_init parameter; the callable simply returns the embedding matrix that we created in the previous section. This layer is responsible for mapping token indexes to their embeddings by indexing this weight matrix. The input to this layer has shape (batch_size, max_tokens) = (batch_size, 50) and the output has shape (batch_size, max_tokens, embed_len) = (batch_size, 50, 300).
Please make a NOTE that as we are using trained embeddings (GloVe embeddings) already, we need to prevent updates to embeddings during the training process. We have done that later by making a minor modification to the training routine that ignores updates to the embedding layer.
The output from the embedding layer is flattened so that its shape transforms from (batch_size, 50, 300) to (batch_size, 50 x 300) = (batch_size, 15000).
This flattened output is then given to our first dense layer, which has 100 output units. This transforms the shape from (batch_size, 15000) to (batch_size, 100). We have also applied relu activation to its output.
The relu-activated output of the first dense layer is given to the second dense layer, which has 5 output units (the same as the number of target classes). This transforms the shape from (batch_size, 100) to (batch_size, 5). The output of the second dense layer is returned as the prediction of our network.
After defining the network, we initialized it and printed the shapes of the weights and biases of its layers. We have also printed a few weights of the embedding layer to verify that it was initialized properly. After initializing the network, we performed a forward pass through it using a few train examples for verification purposes.
Please make a NOTE that we have not described the process of network creation in detail as we expect that the reader has a little background on the Flax library. Please feel free to go through the below link if you are new to the library, as it'll help you get started with neural network creation.
from flax import linen
embed_len = 300
class EmbeddingClassifier(linen.Module):
    def setup(self):
        self.embedding = linen.Embed(num_embeddings=len(tokenizer.word_index)+1, features=embed_len,
                                     name="Word Embeddings", embedding_init=lambda *args: word_embeddings)
        self.linear1 = linen.Dense(100, name="Dense1")
        self.linear2 = linen.Dense(len(target_classes), name="Dense2")

    def __call__(self, X_batch):
        x = self.embedding(X_batch)
        x = x.reshape(len(X_batch), -1) ## Flatten Embeddings
        x = self.linear1(x)
        x = linen.relu(x)
        logits = self.linear2(x)
        return logits
from jax import numpy as jnp
seed = jax.random.PRNGKey(0)
embed_classif = EmbeddingClassifier()
params = embed_classif.init(seed, jax.random.randint(seed, (100, max_tokens), minval=1, maxval=20))
for layer_params in params["params"].items():
print("Layer Name : {}".format(layer_params[0]))
if "Embedding" in layer_params[0]:
weights = layer_params[1]["embedding"]
print("\tLayer Weights : {}".format(weights.shape))
else:
weights, biases = layer_params[1]["kernel"], layer_params[1]["bias"]
print("\tLayer Weights : {}, Biases : {}".format(weights.shape, biases.shape))
params["params"]["Word Embeddings"]["embedding"][1][:10]
preds = embed_classif.apply(params, X_train_vect[:5])
preds
In this section, we have created the loss function that we'll use for our classification task. We'll be using cross entropy loss. The function takes network parameters, data, and target labels as input. It then performs a forward pass through the network to make predictions and one-hot encodes the target labels. Then, we calculate the loss using the softmax_cross_entropy() function available from Optax.
def CrossEntropyLoss(params, input_data, actual):
    logits_preds = model.apply(params, input_data)
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(target_classes))
    return optax.softmax_cross_entropy(logits=logits_preds, labels=one_hot_actual).sum()
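As a hedged aside (assuming a reasonably recent Optax version; this variant is not used in the rest of the tutorial), the one-hot step can be avoided with optax.softmax_cross_entropy_with_integer_labels(), which accepts integer class labels directly:

def CrossEntropyLossIntLabels(params, input_data, actual):
    logits_preds = model.apply(params, input_data)
    ## Same loss as above but computed directly from integer labels (no explicit one-hot encoding).
    return optax.softmax_cross_entropy_with_integer_labels(logits=logits_preds, labels=actual).sum()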
In this section, we are training the network that we defined earlier. We have defined a function that will help us with the training process. The function takes train data (X, Y), validation data (X_val, Y_val), the number of epochs, network parameters, optimizer state, and batch size as input. It then runs the training loop for the given number of epochs. For each epoch, it loops through the training data in batches. For each batch, it performs a forward pass to make predictions, calculates the loss, calculates gradients, and updates the network parameters. It records the loss of each batch and, at the end of each epoch, prints the average training loss across batches. We also calculate and print the validation accuracy of the model to check network performance on validation data.
Please make a NOTE that we are not updating the weights of the embedding layer. We have included a special line in the training routine that applies the tree_map() function to the updates and zeros out the updates meant for the embedding layer, hence preventing any changes to it.
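A minimal alternative sketch (an assumption on our part, not the approach used in this tutorial): the embedding layer can also be frozen by blocking its gradients with jax.lax.stop_gradient() inside the forward pass, in which case the training routine needs no special handling, because Adam produces zero updates for parameters whose gradients are always zero. The class name below is hypothetical and only for illustration.

import jax
from flax import linen

class FrozenEmbeddingClassifier(linen.Module):  ## Hypothetical name, not used elsewhere in this tutorial
    def setup(self):
        self.embedding = linen.Embed(num_embeddings=len(tokenizer.word_index)+1, features=embed_len,
                                     name="Word Embeddings", embedding_init=lambda *args: word_embeddings)
        self.linear1 = linen.Dense(100, name="Dense1")
        self.linear2 = linen.Dense(len(target_classes), name="Dense2")

    def __call__(self, X_batch):
        x = jax.lax.stop_gradient(self.embedding(X_batch))  ## Gradients do not flow into the embeddings
        x = x.reshape(len(X_batch), -1)
        x = self.linear1(x)
        x = linen.relu(x)
        return self.linear2(x)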
from jax import value_and_grad
from tqdm import tqdm
from sklearn.metrics import accuracy_score
def TrainModelInBatches(X, Y, X_val, Y_val, epochs, params, optimizer_state, batch_size=32):
    for i in range(1, epochs+1):
        batches = jnp.arange((X.shape[0]//batch_size)+1) ### Batch Indices
        losses = [] ## Record loss of each batch
        for batch in tqdm(batches):
            if batch != batches[-1]:
                start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
            else:
                start, end = int(batch*batch_size), None
            X_batch, Y_batch = X[start:end], Y[start:end] ## Single batch of data
            loss, gradients = value_and_grad(CrossEntropyLoss)(params, X_batch, Y_batch)
            ## Update Network Parameters
            updates, optimizer_state = optimizer.update(gradients, optimizer_state)
            ## IMPORTANT: Prevent updates to the embedding layer by setting its updates to zeros
            updates = jax.tree_map(lambda x: jnp.zeros(x.shape, dtype=np.float32) if x.shape == word_embeddings.shape else x, updates)
            params = optax.apply_updates(params, updates)
            losses.append(loss) ## Record Loss
        print("CrossEntropyLoss : {:.3f}".format(jnp.array(losses).mean()))
        Y_val_preds = model.apply(params, X_val)
        val_acc = accuracy_score(Y_val, jnp.argmax(Y_val_preds, axis=1))
        print("Validation Accuracy : {:.3f}".format(val_acc))
    return params
Below, we are actually training our network using the training routine defined in the previous cell. We have first initialized the batch size to 64, the number of epochs to 8, and the learning rate to 0.001. Then, we have initialized our text classifier and the Adam optimizer. At last, we have called our training routine with the necessary parameters to perform training. We can notice from the loss and accuracy values getting printed that our network is doing a good job at the classification task.
After training, we have verified the final weights of the embedding layer to confirm that they were not updated.
from jax import random
seed = random.PRNGKey(0)
batch_size=64
epochs=8
learning_rate = jnp.array(1e-3)
model = EmbeddingClassifier()
params = model.init(seed, X_train_vect[:5])
optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(params)
final_weights = TrainModelInBatches(X_train_vect, Y_train, X_test_vect, Y_test, epochs, params, optimizer_state, batch_size=batch_size)
word_embeddings[1][:10], final_weights["params"]["Word Embeddings"]["embedding"][1][:10], params["params"]["Word Embeddings"]["embedding"][1][:10]
In this section, we have evaluated the performance of our trained network by calculating the accuracy score, classification report (precision, recall, and f1-score per target class), and confusion matrix metrics on test predictions. We can notice from the accuracy score that the network does a decent job. We have calculated these ML metrics using scikit-learn.
If you want to learn about various ML metrics available from sklearn then we recommend that you go through the below link. It covers the majority of them in detail.
Apart from the calculations, we have also plotted the confusion matrix using the Python library scikit-plot. We can notice from the visualization that our network does a better job of classifying text documents from the categories 'comp.graphics', 'rec.autos' and 'rec.sport.hockey' than from the categories 'alt.atheism' and 'talk.politics.mideast'.
scikit-plot provides visualizations for many ML metrics. Please feel free to check the below link in your spare time if you want to learn more about it.
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
test_preds = model.apply(final_weights, X_test_vect)
print("Test Accuracy : {:.3f}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=target_classes))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_test, np.argmax(test_preds, axis=1)))
from sklearn.metrics import confusion_matrix
import scikitplot as skplt
import matplotlib.pyplot as plt
skplt.metrics.plot_confusion_matrix([target_classes[i] for i in Y_test], [target_classes[i] for i in np.argmax(test_preds, axis=1)],
normalize=True,
title="Confusion Matrix",
cmap="Purples",
hide_zeros=True,
figsize=(5,5)
);
plt.xticks(rotation=90);
In this section, we are trying to explain the predictions made by our network using the LIME (Local Interpretable Model-Agnostic Explanations) algorithm. LIME is a commonly used algorithm for explaining predictions made by black-box ML models. We have used the Python library lime, which has an implementation of the algorithm.
If you are new to the concept of LIME then we suggest that you go through the below links. They will give you in-depth knowledge of it as well as get you started with using it.
In order to explain a prediction using LIME, we need to create an instance of LimeTextExplainer, which we have done below.
from lime import lime_text
explainer = lime_text.LimeTextExplainer(class_names=target_classes, verbose=True)
Below, we have first created a prediction function that takes a batch of text examples and returns their predicted probabilities from the network. Please make a NOTE that the network outputs are not probabilities; we have converted them to probabilities by applying softmax activation.
Then, we randomly selected a text example from the test data and made a prediction for it using our trained model. The model correctly predicts the target label as 'rec.sport.hockey'. Next, we'll create an explanation for this prediction.
def make_predictions(X_batch_text):
    X_batch = pad_sequences(tokenizer.texts_to_sequences(X_batch_text), maxlen=50, padding="post", truncating="post", value=0.)
    logits = model.apply(final_weights, jnp.array(X_batch))
    preds = linen.softmax(logits)
    return preds.to_py()
rng = np.random.RandomState(1234)
idx = rng.randint(1, len(X_test_text))
print("Prediction : ", target_classes[model.apply(final_weights, X_test_vect[idx:idx+1]).argmax(axis=-1)[0]])
print("Actual : ", target_classes[Y_test[idx]])
Below, we have first created an Explanation object by calling the explain_instance() method on the LimeTextExplainer instance. We have provided the method with the selected text example, the prediction function, and the target label. The explanation object has details about the words contributing to the prediction. Next, we have visualized the explanation by calling the show_in_notebook() method on it.
The visualization highlights that words like 'leafs', 'games', 'goals', 'goalie', etc. contributed to predicting the target label as 'rec.sport.hockey'. This makes sense as these are commonly used words in the context of hockey.
explanation = explainer.explain_instance(X_test_text[idx], classifier_fn=make_predictions,
labels=Y_test[idx:idx+1].to_py())
explanation.show_in_notebook()
Our approach in this section has a minor change in the way we handle word embeddings: we take the average of the word embeddings across the tokens of each example. The majority of the code is exactly the same as earlier, with a minor change in the network architecture.
Below, we have defined the network that we'll use for our text classification task. The network has 3 layers like earlier. The only difference is in the forward pass logic. Earlier, we had flattened the embeddings, whereas here we take the average of the embeddings across the tokens of each text example using the mean() function. This transforms the data shape from (batch_size, 50, 300) to (batch_size, 300). This output is then given to the dense layers for processing.
After defining the network, we initialized it and performed a forward pass using a few train examples to make predictions for verification purposes.
from flax import linen
embed_len = 300
class EmbeddingClassifier(linen.Module):
    def setup(self):
        self.embedding = linen.Embed(num_embeddings=len(tokenizer.word_index)+1, features=embed_len,
                                     name="Word Embeddings", embedding_init=lambda *args: word_embeddings)
        self.linear1 = linen.Dense(100, name="Dense1")
        self.linear2 = linen.Dense(len(target_classes), name="Dense2")

    def __call__(self, X_batch):
        x = self.embedding(X_batch)
        x = x.mean(axis=1) ## Average word embeddings across tokens
        x = self.linear1(x)
        x = linen.relu(x)
        logits = self.linear2(x)
        return logits
from jax import numpy as jnp
seed = jax.random.PRNGKey(0)
embed_classif = EmbeddingClassifier()
params = embed_classif.init(seed, jax.random.randint(seed, (100, max_tokens), minval=1, maxval=20))
preds = embed_classif.apply(params, X_train_vect[:5])
preds
params["params"]["Word Embeddings"]["embedding"][1][:10]
In this section, we have trained the network using the same settings that we had used in our previous section. We can notice from the loss and accuracy values getting printed after each epoch that our network is doing a good job at the classification task.
from jax import random
seed = random.PRNGKey(0)
batch_size=64
epochs=8
learning_rate = jnp.array(1e-3)
model = EmbeddingClassifier()
params = model.init(seed, X_train_vect[:5])
optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(params)
final_weights = TrainModelInBatches(X_train_vect, Y_train, X_test_vect, Y_test, epochs, params, optimizer_state, batch_size=batch_size)
word_embeddings[1][:10], final_weights["params"]["Word Embeddings"]["embedding"][1][:10], params["params"]["Word Embeddings"]["embedding"][1][:10]
In this section, we have evaluated the performance of our trained network by calculating the accuracy score, classification report, and confusion matrix metrics on test predictions. We can notice from the accuracy score that it is noticeably better compared to our previous approach. We have also plotted a confusion matrix for reference purposes, which also hints at improvements in classifying text documents.
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
test_preds = model.apply(final_weights, X_test_vect)
print("Test Accuracy : {:.3f}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=target_classes))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_test, np.argmax(test_preds, axis=1)))
from sklearn.metrics import confusion_matrix
import scikitplot as skplt
import matplotlib.pyplot as plt
skplt.metrics.plot_confusion_matrix([target_classes[i] for i in Y_test], [target_classes[i] for i in np.argmax(test_preds, axis=1)],
normalize=True,
title="Confusion Matrix",
cmap="Purples",
hide_zeros=True,
figsize=(5,5)
);
plt.xticks(rotation=90);
In this section, we have explained predictions made by our trained network using LIME algorithm. The network correctly predicts the target label as 'rec.sport.hockey'. The visualization shows that words like 'leafs', 'games', 'goalie', 'goals', 'defense', etc are contributing to predicting target label as 'rec.sport.hockey'.
from lime import lime_text
explainer = lime_text.LimeTextExplainer(class_names=target_classes)
rng = np.random.RandomState(1234)
idx = rng.randint(1, len(X_test_text))
print("Prediction : ", target_classes[model.apply(final_weights, X_test_vect[idx:idx+1]).argmax(axis=-1)[0]])
print("Actual : ", target_classes[Y_test[idx]])
explanation = explainer.explain_instance(X_test_text[idx], classifier_fn=make_predictions, labels=Y_test[idx:idx+1].to_py())
explanation.show_in_notebook()
Our approach in this section is the same as the previous one, with a minor change: this time we take the sum of the embeddings across tokens instead of the average. The majority of the code is the same as earlier, with only a change in the network architecture.
Below, we have defined the network architecture that we'll use for our text classification task. The network has the same layers as earlier. The only difference is in the forward pass logic, which now uses the sum() function to take the sum of the embeddings instead of the average.
After defining the network, we initialized it and performed a forward pass to make predictions for verification purposes.
from flax import linen
embed_len = 300
class EmbeddingClassifier(linen.Module):
    def setup(self):
        self.embedding = linen.Embed(num_embeddings=len(tokenizer.word_index)+1, features=embed_len,
                                     name="Word Embeddings", embedding_init=lambda *args: word_embeddings)
        self.linear1 = linen.Dense(100, name="Dense1")
        self.linear2 = linen.Dense(len(target_classes), name="Dense2")

    def __call__(self, X_batch):
        x = self.embedding(X_batch)
        x = x.sum(axis=1) ## Sum word embeddings across tokens
        x = self.linear1(x)
        x = linen.relu(x)
        logits = self.linear2(x)
        return logits
from jax import numpy as jnp
seed = jax.random.PRNGKey(0)
embed_classif = EmbeddingClassifier()
params = embed_classif.init(seed, jax.random.randint(seed, (100, max_tokens), minval=1, maxval=20))
preds = embed_classif.apply(params, X_train_vect[:5])
preds
params["params"]["Word Embeddings"]["embedding"][1][:10]
In this section, we have trained the network using exactly the same settings that we have used for all our approaches. We can notice from the loss and accuracy score that the network is doing a good job at the given task.
from jax import random
seed = random.PRNGKey(0)
batch_size=64
epochs=8
learning_rate = jnp.array(1e-3)
model = EmbeddingClassifier()
params = model.init(seed, X_train_vect[:5])
optimizer = optax.adam(learning_rate=learning_rate)
optimizer_state = optimizer.init(params)
final_weights = TrainModelInBatches(X_train_vect, Y_train, X_test_vect, Y_test, epochs, params, optimizer_state, batch_size=batch_size)
word_embeddings[1][:10], final_weights["params"]["Word Embeddings"]["embedding"][1][:10], params["params"]["Word Embeddings"]["embedding"][1][:10]
In this section, we have evaluated the performance of our network by calculating ML metrics like accuracy score, classification report and confusion matrix on test predictions. We can notice from the accuracy score that it is a little better compared to the previous approach and the highest of all our approaches. We have even plotted a confusion matrix for reference purposes.
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
test_preds = model.apply(final_weights, X_test_vect)
print("Test Accuracy : {:.3f}".format(accuracy_score(Y_test, np.argmax(test_preds, axis=1))))
print("\nClassification Report : ")
print(classification_report(Y_test, np.argmax(test_preds, axis=1), target_names=target_classes))
print("\nConfusion Matrix : ")
print(confusion_matrix(Y_test, np.argmax(test_preds, axis=1)))
from sklearn.metrics import confusion_matrix
import scikitplot as skplt
import matplotlib.pyplot as plt
skplt.metrics.plot_confusion_matrix([target_classes[i] for i in Y_test], [target_classes[i] for i in np.argmax(test_preds, axis=1)],
normalize=True,
title="Confusion Matrix",
cmap="Purples",
hide_zeros=True,
figsize=(5,5)
);
plt.xticks(rotation=90);
In this section, we have explained the predictions made by the network using the LIME algorithm. The network correctly predicts the target category as 'rec.sport.hockey'. The visualization highlights that words like 'leafs', 'games', 'host', 'goalie', etc. contribute to predicting the target label as 'rec.sport.hockey'.
from lime import lime_text
explainer = lime_text.LimeTextExplainer(class_names=target_classes)
rng = np.random.RandomState(1234)
idx = rng.randint(1, len(X_test_text))
print("Prediction : ", target_classes[model.apply(final_weights, X_test_vect[idx:idx+1]).argmax(axis=-1)[0]])
print("Actual : ", target_classes[Y_test[idx]])
explanation = explainer.explain_instance(X_test_text[idx], classifier_fn=make_predictions, labels=Y_test[idx:idx+1].to_py())
explanation.show_in_notebook()
| Approach | Max Tokens | Embedding Length | Test Accuracy (%) |
|---|---|---|---|
| GloVe Embeddings (840B.300d) Flattened | 50 | 300 | 84.9 |
| GloVe Embeddings (840B.300d) Averaged | 50 | 300 | 87.0 |
| GloVe Embeddings (840B.300d) Summed | 50 | 300 | 87.6 |
This ends our small tutorial explaining how we can use GloVe embeddings with Flax (JAX) networks for text classification tasks. Please feel free to let us know your views in the comments section.
If you are more comfortable learning through video tutorials then we would recommend that you subscribe to our YouTube channel.
When going through coding examples, it's quite common to have doubts and errors.
If you have doubts about some code examples or are stuck somewhere when trying our code, send us an email at coderzcolumn07@gmail.com. We'll help you or point you in the direction where you can find a solution to your problem.
You can even send us a mail if you are trying something new and need guidance regarding coding. We'll try to respond as soon as possible.