Flax is a high-level framework built on top of JAX that makes developing neural networks easier and faster, letting researchers and developers run more experiments in less time. We have already covered a tutorial explaining how to create neural networks using Flax. Please feel free to check it if you are looking for fully connected networks.
As part of this tutorial, we'll explain how to create convolutional neural networks (CNNs) using Flax, and how to use convolution layers to build a simple CNN. Flax used to have a sub-module for optimizers, but it has been deprecated in favor of the optax package, which implements the majority of common optimizers. The Flax team recommends using optimizers from the optax library.
The tutorial assumes that the reader has a background in JAX and is familiar with neural network terms like optimization, loss function, activations, etc. If you want to learn about JAX or want to create neural networks using the high-level API of JAX, please feel free to check the links below. They will help with this tutorial as well.
Below, we have highlighted important sections of the tutorial to give an overview of the material covered.
import jax
print("JAX Version : {}".format(jax.__version__))
import jax.numpy as jnp
import flax
print("Flax Version : {}".format(flax.__version__))
import optax
print("Optax Version : {}".format(optax.__version__))
In this section, we'll explain how we can create a simple CNN using convolution layers to solve classification tasks. We'll be using the Fashion MNIST dataset available from keras for our purpose.
In this section, we have loaded the Fashion MNIST dataset available from keras. The dataset has 60k train images and 10k test images of 10 different types of fashion items, and it is already split into train and test sets when loaded from keras. After loading it, we convert the datasets from numpy arrays to JAX arrays because Flax models work on JAX arrays. Then, we reshape the datasets to add one extra dimension at the end. This is the channel dimension required by convolution layers, which they transform as data flows through the network. Color (RGB) images already have a channel dimension with 3 channels; our dataset has grayscale images with only one channel, hence we add an extra dimension of size 1 to represent that channel. Finally, we divide the arrays by 255.0 to bring all values into the range [0,1]; by default, the pixel values are in the range [0,255].
from tensorflow import keras
from sklearn.model_selection import train_test_split
(X_train, Y_train), (X_test, Y_test) = keras.datasets.fashion_mnist.load_data()
X_train, X_test, Y_train, Y_test = jnp.array(X_train, dtype=jnp.float32),\
                                   jnp.array(X_test, dtype=jnp.float32),\
                                   jnp.array(Y_train, dtype=jnp.float32),\
                                   jnp.array(Y_test, dtype=jnp.float32)
X_train, X_test = X_train.reshape(-1,28,28,1), X_test.reshape(-1,28,28,1)
X_train, X_test = X_train/255.0, X_test/255.0
classes = jnp.unique(Y_train)
X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
In this section, we have defined our CNN by creating a class that extends the linen.Module class of Flax. The layers are available through the flax.linen module. We have implemented two methods in our CNN class: setup(), where the layers are created, and __call__(), where the forward pass is performed.
Our CNN uses two convolution layers. The first layer has 32 output channels and a kernel size of (3,3); the second has 16 output channels and a kernel size of (3,3). Both convolution layers use padding 'SAME', which means the height and width of the image stay the same after the convolution kernels are applied. We apply the convolution layers to the input data one by one in our __call__() method, applying the relu (rectified linear unit) activation function to the output of each. We then flatten the output of the second convolution layer using the reshape() method of the JAX array and apply a linear layer with 10 units to the flattened output. Finally, we apply the softmax activation function to the output and return it. The softmax() activation function converts the 10 values per sample to probabilities that sum to 1.
Our input data has shape (n_samples, 28, 28, 1). The first convolution layer transforms the shape from (n_samples, 28, 28, 1) to (n_samples, 28, 28, 32). The second convolution layer transforms it from (n_samples, 28, 28, 32) to (n_samples, 28, 28, 16). The flatten operation then transforms it from (n_samples, 28, 28, 16) to (n_samples, 28 x 28 x 16) = (n_samples, 12544). The linear layer transforms the shape from (n_samples, 12544) to (n_samples, 10), which is our output shape. Later on, we'll include logic to pick the predicted class per sample from these 10 values by taking the class with the highest probability.
After defining the CNN, we have initialized it in the next cell. The model object has two important methods: init(), which initializes the model weights/parameters, and apply(), which performs a forward pass through the network.
We have initialized the weights of our CNN by calling the init() method and then printed the shapes of the weights of the various layers for information purposes.
Then, in the next cell, we have performed a forward pass through the CNN using the apply() method for verification purposes, giving a few data samples as input to make predictions.
from flax import linen
from jax import random
class CNN(linen.Module):
    def setup(self):
        self.conv1 = linen.Conv(features=32, kernel_size=(3,3), padding="SAME", name="CONV1")
        self.conv2 = linen.Conv(features=16, kernel_size=(3,3), padding="SAME", name="CONV2")
        self.linear1 = linen.Dense(len(classes), name="DENSE")

    def __call__(self, inputs):
        x = linen.relu(self.conv1(inputs)) ## (n_samples,28,28,1)  -> (n_samples,28,28,32)
        x = linen.relu(self.conv2(x))      ## (n_samples,28,28,32) -> (n_samples,28,28,16)
        x = x.reshape((x.shape[0], -1))    ## Flatten : (n_samples,28,28,16) -> (n_samples,12544)
        x = self.linear1(x)                ## (n_samples,12544) -> (n_samples,10)
        return linen.softmax(x)
seed = jax.random.PRNGKey(0)
model = CNN()
params = model.init(seed, X_train[:5])
for layer_params in params["params"].items():
    print("Layer Name : {}".format(layer_params[0]))
    weights, biases = layer_params[1]["kernel"], layer_params[1]["bias"]
    print("\tLayer Weights : {}, Biases : {}".format(weights.shape, biases.shape))
preds = model.apply(params, X_train[:5])
preds
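As a quick sanity check of the shape arithmetic described above, we can inspect the predictions from the previous cell directly. The snippet below is a small optional check (it assumes preds from the previous cell is still in scope): each sample should have 10 probabilities that sum to 1.
## Optional sanity check (assumes `preds` from the previous cell is in scope).
print("Predictions Shape : {}".format(preds.shape)) ## Expected : (5, 10)
print("Per-Sample Probability Sums : {}".format(jnp.sum(preds, axis=1))) ## Each should be ~1.0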
In this section, we have defined the loss function for our multi-class classification task: cross entropy loss. The function takes model parameters, input data, and actual target values as input. It makes predictions using the apply() method, giving the model parameters and input data to it. We then convert the actual target values to one-hot encoded values and take the log of the model predictions. At last, we multiply the one-hot encoded targets by the log of the predictions and return the negative of the sum of all values in the resulting array.
def CrossEntropyLoss(weights, input_data, actual):
    preds = model.apply(weights, input_data)
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(classes))
    log_preds = jnp.log(preds)
    return - jnp.sum(one_hot_actual * log_preds)
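As a side note, optax also provides a cross-entropy helper, optax.softmax_cross_entropy(). It expects raw logits rather than softmax probabilities, so it would only apply if the model returned the output of the Dense layer directly (without the final softmax). Below is a minimal sketch under that assumption; the logits-returning model is hypothetical and not the CNN defined above.
## Minimal sketch, assuming a hypothetical model that returns raw logits (no final softmax).
def CrossEntropyLossFromLogits(weights, input_data, actual):
    logits = model.apply(weights, input_data) ## Assumed to be raw logits here
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(classes))
    return optax.softmax_cross_entropy(logits, one_hot_actual).sum()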
In this section, we are training our CNN. We'll be using the gradient descent optimizer available from the optax library. We have designed a small function that holds the logic to train our CNN, and we'll call it to perform training in the next cell.
The function takes data features, actual target values, the number of epochs, model parameters, an optimizer state object, and the batch size as input. It then executes the training loop for the given number of epochs.
For each epoch, it calculates the start and end indexes of the batches of data and loops through the data in batches using these indexes. For each batch, it performs a forward pass through the network to make predictions, calculates the loss using these predictions and the actual target values, and calculates the gradients of the loss with respect to the model parameters. It does all these steps using the value_and_grad() function of JAX. This function takes as input any function that operates on JAX arrays and returns another function. We can call the returned function with the same parameters as the original function, and it returns two values: the value of the original function and the gradient of that value with respect to the function's first input parameter.
In our case, we give our loss function to value_and_grad(). When called with the input values, it returns the loss value and the gradients of the loss with respect to the model parameters. We then call the update() method on the optimizer object with the gradients and the optimizer state; it returns the updates to be made to the model parameters and a new optimizer state. We update the model weights by giving the weights and the updates to the apply_updates() method of optax, which returns the updated model parameters/weights. We also record the loss of each batch. Once all epochs are completed, we return the final updated weights from the function.
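Before looking at the full training loop, here is a small standalone sketch of the same mechanics on a toy quadratic function (the function and values are hypothetical, used only for illustration): value_and_grad() returns the value and gradients in one call, update() turns gradients into parameter updates, and apply_updates() applies them.
## Standalone illustration of value_and_grad() + optax updates on a toy function.
## The function f and the starting point are hypothetical, used only for illustration.
from jax import value_and_grad

def f(w):
    return jnp.sum(w ** 2) ## Simple quadratic with minimum at w = 0

w = jnp.array([3.0, -2.0])
opt = optax.sgd(learning_rate=0.1)
opt_state = opt.init(w)

for step in range(3):
    value, grads = value_and_grad(f)(w) ## f(w) and df/dw in a single call
    updates, opt_state = opt.update(grads, opt_state)
    w = optax.apply_updates(w, updates) ## w moves toward the minimum
    print("Step {} : f(w) = {:.3f}".format(step, value))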
from jax import value_and_grad
def TrainModelInBatches(X, Y, epochs, weights, optimizer_state, batch_size=32):
    for i in range(1, epochs+1):
        batches = jnp.arange((X.shape[0]//batch_size)+1) ### Batch Indices
        losses = [] ## Record loss of each batch
        for batch in batches:
            if batch != batches[-1]:
                start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
            else:
                start, end = int(batch*batch_size), None

            X_batch, Y_batch = X[start:end], Y[start:end] ## Single batch of data

            loss, gradients = value_and_grad(CrossEntropyLoss)(weights, X_batch, Y_batch)

            ## Update Weights
            updates, optimizer_state = optimizer.update(gradients, optimizer_state)
            weights = optax.apply_updates(weights, updates)

            losses.append(loss) ## Record Loss

        print("CrossEntropyLoss : {:.3f}".format(jnp.array(losses).mean()))

    return weights
In the below cell, we have first initialized the necessary objects to train our CNN and then called the function from the previous cell to train it.
We have set the batch size to 256, the number of epochs to 15, and the learning rate to 0.0001. Then, we have initialized our model and its weights/parameters. After initializing the model, we have initialized the optimizer by giving the learning rate to the sgd() optimizer available from optax. The optimizer object has two important methods: init(), which creates the optimizer state from the model parameters, and update(), which computes parameter updates from the gradients.
We have created the optimizer state by calling the init() method of the optimizer object with the model parameters. At last, we have called our function from the previous cell to train our CNN, providing all the necessary parameters. We can notice from the loss value printed after every epoch that our model seems to be doing a good job.
seed = random.PRNGKey(0)
batch_size=256
epochs=15
learning_rate = jnp.array(1/1e4)
model = CNN()
weights = model.init(seed, X_train[:5])
optimizer = optax.sgd(learning_rate=learning_rate) ## Initialize SGD Optimizer
optimizer_state = optimizer.init(weights)
final_weights = TrainModelInBatches(X_train, Y_train, epochs, weights, optimizer_state, batch_size=batch_size)
In this section, we are making predictions with the trained model on the train and test datasets. We have designed a small function that loops through the input data in batches and makes predictions. The function takes the final updated model parameters, input data, and batch size as input and returns a list of per-batch predictions, which we then combine into a single array.
Our predictions from the CNN have 10 values per sample, as we discussed earlier. Because we applied the softmax activation function to the output of the CNN, these 10 values per sample sum to 1; they are probabilities. To convert these probabilities to the actual target class, we retrieve the index of the highest probability per sample, and that index value is our predicted class.
def MakePredictions(weights, input_data, batch_size=32):
    batches = jnp.arange((input_data.shape[0]//batch_size)+1) ### Batch Indices
    preds = []
    for batch in batches:
        if batch != batches[-1]:
            start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
        else:
            start, end = int(batch*batch_size), None

        X_batch = input_data[start:end]

        if X_batch.shape[0] != 0:
            preds.append(model.apply(weights, X_batch))

    return preds
test_preds = MakePredictions(final_weights, X_test, batch_size=256)
test_preds = jnp.concatenate(test_preds).squeeze() ## Combine predictions of all batches
test_preds = jnp.argmax(test_preds, axis=1)
train_preds = MakePredictions(final_weights, X_train, batch_size=256)
train_preds = jnp.concatenate(train_preds).squeeze() ## Combine predictions of all batches
train_preds = jnp.argmax(train_preds, axis=1)
test_preds[:5], train_preds[:5]
In this section, we have evaluated the performance of our CNN by calculating the accuracy of train and test predictions. We have also calculated a classification report on test predictions which has information like precision, recall, and f1-score per target class. We have calculated accuracy and classification report using functions available through scikit-learn.
If you want to learn about various machine learning metrics calculation functions available through scikit-learn then please feel free to check our tutorial that covers the majority of them in detail.
from sklearn.metrics import accuracy_score
print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, train_preds)))
print("Test Accuracy : {:.3f}".format(accuracy_score(Y_test, test_preds)))
from sklearn.metrics import classification_report
print("Test Classification Report ")
print(classification_report(Y_test, test_preds))
In this section, we have trained our CNN again, but this time using the Adam optimizer. This lets us compare its performance with the SGD optimizer. All other parameter settings are the same as in our SGD training.
seed = random.PRNGKey(0)
batch_size=256
epochs=15
learning_rate = jnp.array(1/1e4)
model = CNN()
weights = model.init(seed, X_train[:5])
optimizer = optax.adam(learning_rate=learning_rate) ## Initialize Adam Optimizer
optimizer_state = optimizer.init(weights)
final_weights = TrainModelInBatches(X_train, Y_train, epochs, weights, optimizer_state, batch_size=batch_size)
In this section, we have made predictions on train and test sets using our CNN trained with Adam optimizer.
test_preds = MakePredictions(final_weights, X_test, batch_size=256)
test_preds = jnp.concatenate(test_preds).squeeze() ## Combine predictions of all batches
test_preds = jnp.argmax(test_preds, axis=1)
train_preds = MakePredictions(final_weights, X_train, batch_size=256)
train_preds = jnp.concatenate(train_preds).squeeze() ## Combine predictions of all batches
train_preds = jnp.argmax(train_preds, axis=1)
test_preds[:5], train_preds[:5]
In this section, we have evaluated the performance of our CNN by calculating the accuracy of train and test predictions. We have also calculated the classification report for test predictions.
from sklearn.metrics import accuracy_score
print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, train_preds)))
print("Test Accuracy : {:.3f}".format(accuracy_score(Y_test, test_preds)))
from sklearn.metrics import classification_report
print("Test Classification Report ")
print(classification_report(Y_test, test_preds))
In our example, we used grayscale images. Grayscale images generally do not have an explicit channel dimension since there is only one channel, but because the convolution layer requires a channel dimension, we introduced it in our images. Color (RGB) images generally have 3 channels. We can represent channel details in a multi-dimensional array of images in two different ways: Channels Last, where the array has shape (n_samples, height, width, channels), and Channels First, where it has shape (n_samples, channels, height, width).
By default, the convolution layers available through Flax require the channel dimension to be the last one, i.e., Channels Last format. They currently cannot handle Channels First format.
Below, we have first included an example where we create 2 convolution layers and apply them to input data in Channels Last format. We print the output shapes of both layers as well as the weight shapes to give an idea of how the convolution operations are applied.
Then, in the next cell, we again create two convolution layers and apply them to input data that has the channel dimension at the beginning (Channels First). We can notice from the output shapes how the calculations go wrong with Channels First data.
conv_layer1 = flax.linen.Conv(16, (3,3))
conv_layer2 = flax.linen.Conv(32, (3,3))
seed = jax.random.PRNGKey(123)
params1 = conv_layer1.init(seed, jax.random.uniform(seed, (50,28,28,1)))
preds1 = conv_layer1.apply(params1, jax.random.uniform(seed,(50,28,28,1)))
params2 = conv_layer2.init(seed, jax.random.uniform(seed, preds1.shape))
preds2 = conv_layer2.apply(params2, jax.random.uniform(seed,preds1.shape))
print("Weights of First Conv Layer : {}".format(params1["params"]["kernel"].shape))
print("Weights of Second Conv Layer : {}".format(params2["params"]["kernel"].shape))
print("\nInput Shape : {}".format((50,28,28,1)))
print("Conv Layer 1 Output Shape : {}".format(preds1.shape))
print("Conv Layer 2 Output Shape : {}".format(preds2.shape))
conv_layer1 = flax.linen.Conv(16, (3,3))
conv_layer2 = flax.linen.Conv(32, (3,3))
seed = jax.random.PRNGKey(123)
params1 = conv_layer1.init(seed, jax.random.uniform(seed, (50,1,28,28)))
preds1 = conv_layer1.apply(params1, jax.random.uniform(seed,(50,1,28,28)))
params2 = conv_layer2.init(seed, jax.random.uniform(seed, preds1.shape))
preds2 = conv_layer2.apply(params2, jax.random.uniform(seed,preds1.shape))
print("Weights of First Conv Layer : {}".format(params1["params"]["kernel"].shape))
print("Weights of Second Conv Layer : {}".format(params2["params"]["kernel"].shape))
print("\nInput Shape : {}".format((50,1,28,28)))
print("Conv Layer 1 Output Shape : {}".format(preds1.shape))
print("Conv Layer 2 Output Shape : {}".format(preds2.shape))
This ends our small tutorial explaining how we can design a convolutional neural network (CNN) using Flax, a high-level framework designed on top of JAX. We used the optax library for optimizers, as Flax has deprecated its own optimizer module. Please feel free to let us know your views in the comments section.