Updated On: Jan-12, 2022 | Time Investment: ~30 mins

Flax: Convolutional Neural Networks (CNN)

Flax is a high-level framework built on top of JAX that makes developing neural networks easier and faster, letting researchers and developers perform more experiments in less time. We have already covered a tutorial explaining how to create neural networks using Flax. Please feel free to check it if you are looking for fully connected networks.

As a part of this tutorial, we'll explain how to create convolutional neural networks (CNNs) using Flax. We'll explain how to use convolution layers to build simple CNNs. Flax earlier had a sub-module for optimizers, but it has been decommissioned in favor of the optax package, which implements the majority of commonly used optimizers. The Flax team now recommends using optimizers from the optax library.

The tutorial assumes that the reader has a background in JAX and is familiar with neural network terms like optimization, loss function, activations, etc. If you want to learn about JAX or want to create neural networks using the high-level API of JAX, please feel free to check the links below. They will help with this tutorial as well.

Below, we have highlighted important sections of the tutorial to give an overview of the material covered.

Important Sections of Tutorial

  1. Simple Convolutional Neural Network (CNN)
    • Load Dataset
    • Create Convolutional Neural Network (CNN)
    • Define Loss Function
    • Train CNN (SGD)
    • Make Predictions
    • Evaluate Model Performance
    • Train CNN (Adam Optimizer)
    • Make Predictions
    • Evaluate Model Performance
  2. Channels First vs Channels Last
import jax

print("JAX Version : {}".format(jax.__version__))
JAX Version : 0.2.25
import jax.numpy as jnp
import flax

print("Flax Version : {}".format(flax.__version__))
Flax Version : 0.3.6
import optax

print("Optax Version : {}".format(optax.__version__))
Optax Version : 0.1.0

1. Simple Convolutional Neural Network (CNN)

In this section, we'll explain how to create a simple CNN using convolution layers to solve a classification task. We'll be using the Fashion MNIST dataset available from keras for our purpose.

Load Dataset

In this section, we have loaded the Fashion MNIST dataset available from keras. The dataset has 60k train images and 10k test images spread across 10 different categories of fashion items, and it comes already divided into train and test sets. After loading it, we convert the datasets from numpy arrays to JAX arrays, as all Flax models work on JAX arrays. Then, we reshape the datasets to add one extra dimension at the end. This is the channel dimension required by convolution layers, and it is the dimension that the convolution layers transform. Color (RGB) images already have a channel dimension with 3 channels; our dataset has grayscale images with only one channel, hence we add the extra dimension explicitly. At last, we divide the datasets by the float value 255 to bring all array values into the range [0,1]. By default, the arrays have values in the range [0,255].

from tensorflow import keras

(X_train, Y_train), (X_test, Y_test) = keras.datasets.fashion_mnist.load_data()

X_train, X_test, Y_train, Y_test = jnp.array(X_train, dtype=jnp.float32),\
                                   jnp.array(X_test, dtype=jnp.float32),\
                                   jnp.array(Y_train, dtype=jnp.float32),\
                                   jnp.array(Y_test, dtype=jnp.float32)

X_train, X_test = X_train.reshape(-1,28,28,1), X_test.reshape(-1,28,28,1)

X_train, X_test = X_train/255.0, X_test/255.0

classes =  jnp.unique(Y_train)

X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
32768/29515 [=================================] - 0s 0us/step
40960/29515 [=========================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26427392/26421880 [==============================] - 0s 0us/step
26435584/26421880 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
16384/5148 [===============================================================================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4423680/4422102 [==============================] - 0s 0us/step
4431872/4422102 [==============================] - 0s 0us/step
((60000, 28, 28, 1), (10000, 28, 28, 1), (60000,), (10000,))

Create Convolutional Neural Network (CNN)

In this section, we have defined our CNN. We have created it by defining a class that extends the linen.Module class of Flax. The layers are available through the flax.linen module. We have implemented two methods in our CNN class.

  1. setup() - In this method, we have initialized the layers of our CNN.
  2. __call__() - In this method, we perform a forward pass through the data using the layers defined in the setup() method. This method returns predictions.

Our CNN uses two convolution layers. The first layer has 32 output channels and a kernel size of (3,3). The second convolution layer has 16 output channels and a kernel size of (3,3). Both convolution layers have padding specified as 'SAME', which indicates that the height and width of the image should stay the same after the application of the convolution kernels. We apply the convolution layers to the input data one by one in our __call__() method and apply the relu (rectified linear unit) activation function to the output of both convolution layers. We have flattened the output of the second convolution layer using the reshape() method of the JAX array. Then, we have applied a linear layer with 10 units to the flattened output. At last, we have applied the softmax activation function to the output and returned it. The softmax() activation function converts the 10 values per sample to probabilities whose sum is equal to 1.

Our input data has shape (n_samples,28,28,1). The first convolution layer will transform the shape from (n_samples,28,28,1) to (n_samples,28,28,32). The second convolution layer will transform the shape from (n_samples,28,28,32) to (n_samples,28,28,16). The flatten operation will then transform the shape from (n_samples,28,28,16) to (n_samples, 28 x 28 x 16) = (n_samples,12544). The linear layer will transform the shape from (n_samples,12544) to (n_samples,10), which will be our output shape. Later on, we'll include logic to guess the actual class per sample from these 10 values by taking the class with the highest probability.
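
To make this shape arithmetic concrete, below is a small standalone sketch (our own addition, using random dummy data rather than the actual Fashion MNIST arrays) that applies two 'SAME'-padded convolution layers followed by a flatten step and prints the resulting shapes.

import jax
from flax import linen

dummy = jax.random.uniform(jax.random.PRNGKey(0), (5, 28, 28, 1)) ## Dummy grayscale batch

conv1 = linen.Conv(features=32, kernel_size=(3,3), padding="SAME")
conv2 = linen.Conv(features=16, kernel_size=(3,3), padding="SAME")

params1 = conv1.init(jax.random.PRNGKey(0), dummy)
out1 = conv1.apply(params1, dummy)            ## (5, 28, 28, 32)

params2 = conv2.init(jax.random.PRNGKey(0), out1)
out2 = conv2.apply(params2, out1)             ## (5, 28, 28, 16)

flat = out2.reshape((out2.shape[0], -1))      ## (5, 12544) because 28 x 28 x 16 = 12544

print(out1.shape, out2.shape, flat.shape)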

After defining CNN, we have initialized it in the next cell. The initialized model has two important methods.

  1. init(seed, sample_input_data) - This method takes a PRNG seed and sample data as input to initialize model weights. It returns model weights as a dictionary-like object.
  2. apply(params, input_data) - This method performs forward pass through the network on given input data using given parameters.

We have initialized the weights of our CNN by calling init() method. We have then printed the shape of weights of various layers as well for information purposes.

Then, in the next cell, we have performed a forward pass through CNN using apply() method for verification purposes. We have given a few data samples as input to make predictions.

from flax import linen
from jax import random

class CNN(linen.Module):
    def setup(self):
        self.conv1 = linen.Conv(features=32, kernel_size=(3,3), padding="SAME", name="CONV1")
        self.conv2 = linen.Conv(features=16, kernel_size=(3,3), padding="SAME", name="CONV2")
        self.linear1 = linen.Dense(len(classes), name="DENSE")

    def __call__(self, inputs):
        x = linen.relu(self.conv1(inputs))
        x = linen.relu(self.conv2(x))

        x = x.reshape((x.shape[0], -1))
        x = self.linear1(x)

        return  linen.softmax(x)
seed = jax.random.PRNGKey(0)

model = CNN()
params = model.init(seed, X_train[:5])

for layer_params in params["params"].items():
    print("Layer Name : {}".format(layer_params[0]))
    weights, biases = layer_params[1]["kernel"], layer_params[1]["bias"]
    print("\tLayer Weights : {}, Biases : {}".format(weights.shape, biases.shape))
Layer Name : CONV1
	Layer Weights : (3, 3, 1, 32), Biases : (32,)
Layer Name : CONV2
	Layer Weights : (3, 3, 32, 16), Biases : (16,)
Layer Name : DENSE
	Layer Weights : (12544, 10), Biases : (10,)
preds = model.apply(params, X_train[:5])

preds
DeviceArray([[0.11021595, 0.09791899, 0.10163375, 0.0848286 , 0.11105911,
              0.10501531, 0.09337082, 0.09462234, 0.10987436, 0.09146074],
             [0.11308134, 0.08098681, 0.10621079, 0.10240593, 0.12000108,
              0.09611297, 0.08810329, 0.08588154, 0.11353485, 0.09368135],
             [0.10423359, 0.09533928, 0.1022153 , 0.10492188, 0.11806756,
              0.096218  , 0.09491643, 0.09707405, 0.09605328, 0.09096077],
             [0.09996213, 0.08457012, 0.10385586, 0.10420775, 0.12418522,
              0.09746528, 0.08569309, 0.1008865 , 0.11318334, 0.08599078],
             [0.10923128, 0.08600571, 0.10318682, 0.10372809, 0.14236486,
              0.09100852, 0.07801039, 0.08804696, 0.10861228, 0.08980505]],            dtype=float32)

Define Loss Function

In this section, we have defined the loss function for our multi-class classification task. We'll be using cross entropy loss for our task. The function takes model parameters, input data, and actual target values as input. It then makes predictions using the apply() method by giving model parameters and input data to it. We then convert the actual target values to one-hot encoded values and take the log of the model predictions. At last, we multiply the one-hot encoded targets with the log of the predictions and return the negative of the sum of all values of the resulting array.

def CrossEntropyLoss(weights, input_data, actual):
    preds = model.apply(weights, input_data)
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(classes))
    log_preds = jnp.log(preds)
    return - jnp.sum(one_hot_actual * log_preds)
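
As a quick sanity check of this formula, here is a tiny hand-made example (hypothetical probabilities, not actual model output): for a single sample whose true class has a predicted probability of 0.7, the loss reduces to -log(0.7).

probs = jnp.array([[0.7, 0.2, 0.1]])          ## Predicted probabilities for one sample
one_hot = jax.nn.one_hot(jnp.array([0]), 3)   ## Actual class 0 -> [1., 0., 0.]

loss = -jnp.sum(one_hot * jnp.log(probs))     ## Only -log(0.7) survives the multiplication

print(loss)  ## ~0.357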

Train CNN (SGD)

In this section, we are training our CNN. We'll be using the gradient descent optimizer available from the optax library. We have designed a small function that has the logic to train our CNN. We'll be calling this function to perform training in the next cell.

The function takes data features, actual target values, number of epochs, model parameters, optimizer state object, and batch size as input. It then executes the training loop for the given number of epochs.

For each epoch, it calculates the start and end indexes of the batches of data and loops through the data in batches using these indexes. For each batch, it performs a forward pass through the network to make predictions, calculates the loss using these predictions and the actual target values, and then calculates the gradients of the loss with respect to the model parameters. It does all these steps using the value_and_grad() function of JAX. This function takes as input any function that operates on JAX arrays and returns another function. We can call this returned function with the same parameters as our original function. It returns two values: the first is the actual value of the function, and the second is the gradient of the function's output with respect to its first input parameter.

In our case, we have given our loss function to value_and_grad(). When called with the input values, it returns the loss value and the gradients of the loss with respect to the model parameters. Then, we call the update() method of the optimizer object with the gradients and the optimizer state. It returns the updates to be made to the model parameters and a new optimizer state. We apply the updates to the model weights using the apply_updates() method of optax, giving it the model weights and the updates; it returns the updated model parameters/weights. We also record the loss of each batch. Once all epochs are completed, the function returns the final updated weights.
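
For readers new to value_and_grad(), below is a minimal illustration on a toy function (separate from the tutorial's training code; the function and variable names are made up for illustration) showing that it returns both the function value and the gradient with respect to the first argument.

from jax import value_and_grad

def square_sum(w, x):
    return jnp.sum((w * x) ** 2)

w, x = jnp.array(2.0), jnp.array(3.0)
value, grad = value_and_grad(square_sum)(w, x)

print(value)  ## (2*3)^2 = 36.0
print(grad)   ## d/dw of (w*x)^2 = 2*w*x^2 = 36.0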

from jax import value_and_grad

def TrainModelInBatches(X, Y, epochs, weights, optimizer_state, batch_size=32):
    for i in range(1, epochs+1):
        batches = jnp.arange((X.shape[0]//batch_size)+1) ### Batch Indices

        losses = [] ## Record loss of each batch
        for batch in batches:
            if batch != batches[-1]:
                start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
            else:
                start, end = int(batch*batch_size), None

            X_batch, Y_batch = X[start:end], Y[start:end] ## Single batch of data

            loss, gradients = value_and_grad(CrossEntropyLoss)(weights, X_batch,Y_batch)

            ## Update Weights
            updates, optimizer_state = optimizer.update(gradients, optimizer_state)
            weights = optax.apply_updates(weights, updates)

            losses.append(loss) ## Record Loss

        print("CrossEntropyLoss : {:.3f}".format(jnp.array(losses).mean()))

    return weights

In the below cell, we have first initialized necessary things to train our CNN and then called function from the previous cell to train our CNN.

We have initialized the batch size to 256, the number of epochs to 15, and the learning rate to 0.0001. Then, we have initialized our model and its weights/parameters. After initializing the model, we have initialized the optimizer by calling the sgd() function available from optax, giving the learning rate to it. The optimizer object has two important methods.

  1. init(params) - This method takes model parameters and returns an OptimizerState object holding the optimizer's internal state.
  2. update(gradients, optimizer_state) - This method takes the gradients of the loss and the current optimizer state as input. It then returns the updates to be applied to the model parameters and the new optimizer state.

We have initialized the OptimizerState by calling the init() method of the optimizer object, giving the model parameters to it. At last, we have called our function from the previous cell to train our CNN, providing all the necessary parameters. We can notice from the loss value printed after every epoch that our model seems to be doing a good job.
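
Before the actual training run below, here is a minimal self-contained sketch of the init() -> update() -> apply_updates() cycle on a toy scalar parameter (the names toy_params, toy_loss, etc. are hypothetical and used for illustration only).

toy_params = {"w": jnp.array(5.0)}             ## Toy parameter
toy_loss = lambda p: (p["w"] - 2.0) ** 2       ## Minimized at w = 2

toy_optimizer = optax.sgd(learning_rate=0.1)
toy_state = toy_optimizer.init(toy_params)

for _ in range(3):
    grads = jax.grad(toy_loss)(toy_params)
    updates, toy_state = toy_optimizer.update(grads, toy_state)
    toy_params = optax.apply_updates(toy_params, updates)

print(toy_params)  ## 'w' moves from 5.0 towards 2.0 (about 3.54 after 3 steps)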

seed = random.PRNGKey(0)
batch_size=256
epochs=15
learning_rate = jnp.array(1/1e4)

model = CNN()
weights = model.init(seed, X_train[:5])

optimizer = optax.sgd(learning_rate=learning_rate) ## Initialize SGD Optimizer
optimizer_state = optimizer.init(weights)

final_weights = TrainModelInBatches(X_train, Y_train, epochs, weights, optimizer_state, batch_size=batch_size)
CrossEntropyLoss : 186.879
CrossEntropyLoss : 116.241
CrossEntropyLoss : 103.818
CrossEntropyLoss : 97.013
CrossEntropyLoss : 92.256
CrossEntropyLoss : 88.584
CrossEntropyLoss : 85.530
CrossEntropyLoss : 82.880
CrossEntropyLoss : 80.526
CrossEntropyLoss : 78.407
CrossEntropyLoss : 76.492
CrossEntropyLoss : 74.721
CrossEntropyLoss : 73.085
CrossEntropyLoss : 71.568
CrossEntropyLoss : 70.153

Make Predictions

In this section, we are making predictions on the train and test datasets using the trained model. We have designed a small function that loops through the input data in batches and makes predictions for each batch. The function takes the final updated model parameters, input data, and batch size as input and returns a list of per-batch predictions, which we then combine.

Our predictions through the CNN have 10 values per sample, as discussed earlier. Since we applied the softmax activation function to the output of the CNN, these 10 values per sample sum to 1; they are probabilities. To convert these probabilities to the actual target class, we retrieve the index of the highest probability per sample, and that index value is our predicted class.
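
As a tiny hypothetical example of this argmax step (made-up probabilities, not actual model output):

probs = jnp.array([[0.1, 0.7, 0.2],
                   [0.5, 0.3, 0.2]])   ## 2 samples, 3 classes

print(jnp.argmax(probs, axis=1))       ## [1 0] -> predicted class index per sample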

def MakePredictions(weights, input_data, batch_size=32):
    batches = jnp.arange((input_data.shape[0]//batch_size)+1) ### Batch Indices

    preds = []
    for batch in batches:
        if batch != batches[-1]:
            start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
        else:
            start, end = int(batch*batch_size), None

        X_batch = input_data[start:end]

        if X_batch.shape[0] != 0:
            preds.append(model.apply(weights, X_batch))

    return preds
test_preds = MakePredictions(final_weights, X_test, batch_size=256)

test_preds = jnp.concatenate(test_preds).squeeze() ## Combine predictions of all batches

test_preds = jnp.argmax(test_preds, axis=1)

train_preds = MakePredictions(final_weights, X_train, batch_size=256)

train_preds = jnp.concatenate(train_preds).squeeze() ## Combine predictions of all batches

train_preds = jnp.argmax(train_preds, axis=1)

test_preds[:5], train_preds[:5]
(DeviceArray([9, 2, 1, 1, 6], dtype=int32),
 DeviceArray([9, 0, 0, 3, 1], dtype=int32))

Evaluate Model Performance

In this section, we have evaluated the performance of our CNN by calculating the accuracy of train and test predictions. We have also calculated a classification report on test predictions which has information like precision, recall, and f1-score per target class. We have calculated accuracy and classification report using functions available through scikit-learn.

If you want to learn about various machine learning metrics calculation functions available through scikit-learn then please feel free to check our tutorial that covers the majority of them in detail.

from sklearn.metrics import accuracy_score

print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, train_preds)))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, test_preds)))
Train Accuracy : 0.903
Test  Accuracy : 0.881
from sklearn.metrics import classification_report

print("Test Classification Report ")
print(classification_report(Y_test, test_preds))
Test Classification Report
              precision    recall  f1-score   support

         0.0       0.80      0.87      0.84      1000
         1.0       0.99      0.96      0.98      1000
         2.0       0.75      0.87      0.81      1000
         3.0       0.85      0.93      0.89      1000
         4.0       0.86      0.70      0.77      1000
         5.0       0.98      0.95      0.97      1000
         6.0       0.73      0.64      0.68      1000
         7.0       0.94      0.95      0.94      1000
         8.0       0.97      0.97      0.97      1000
         9.0       0.95      0.96      0.95      1000

    accuracy                           0.88     10000
   macro avg       0.88      0.88      0.88     10000
weighted avg       0.88      0.88      0.88     10000

Train CNN (Adam Optimizer)

In this section, we have trained our CNN again, but this time using the Adam optimizer. It'll be useful to compare its performance with the SGD optimizer. All other settings are the same as in our SGD training.

seed = random.PRNGKey(0)
batch_size=256
epochs=15
learning_rate = jnp.array(1/1e4)

model = CNN()
weights = model.init(seed, X_train[:5])

optimizer = optax.adam(learning_rate=learning_rate) ## Initialize Adam Optimizer
optimizer_state = optimizer.init(weights)

final_weights = TrainModelInBatches(X_train, Y_train, epochs, weights, optimizer_state, batch_size=batch_size)
CrossEntropyLoss : 192.716
CrossEntropyLoss : 111.581
CrossEntropyLoss : 100.706
CrossEntropyLoss : 94.950
CrossEntropyLoss : 90.966
CrossEntropyLoss : 87.804
CrossEntropyLoss : 85.118
CrossEntropyLoss : 82.761
CrossEntropyLoss : 80.669
CrossEntropyLoss : 78.795
CrossEntropyLoss : 77.107
CrossEntropyLoss : 75.570
CrossEntropyLoss : 74.154
CrossEntropyLoss : 72.841
CrossEntropyLoss : 71.615

Make Predictions

In this section, we have made predictions on train and test sets using our CNN trained with Adam optimizer.

test_preds = MakePredictions(final_weights, X_test, batch_size=256)

test_preds = jnp.concatenate(test_preds).squeeze() ## Combine predictions of all batches

test_preds = jnp.argmax(test_preds, axis=1)

train_preds = MakePredictions(final_weights, X_train, batch_size=256)

train_preds = jnp.concatenate(train_preds).squeeze() ## Combine predictions of all batches

train_preds = jnp.argmax(train_preds, axis=1)

test_preds[:5], train_preds[:5]
(DeviceArray([9, 2, 1, 1, 6], dtype=int32),
 DeviceArray([9, 0, 0, 3, 1], dtype=int32))

Evaluate Model Performance

In this section, we have evaluated the performance of our CNN by calculating the accuracy of train and test predictions. We have also calculated the classification report for test predictions.

from sklearn.metrics import accuracy_score

print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, train_preds)))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, test_preds)))
Train Accuracy : 0.902
Test  Accuracy : 0.883
from sklearn.metrics import classification_report

print("Test Classification Report ")
print(classification_report(Y_test, test_preds))
Test Classification Report
              precision    recall  f1-score   support

         0.0       0.78      0.89      0.83      1000
         1.0       0.99      0.97      0.98      1000
         2.0       0.80      0.83      0.82      1000
         3.0       0.87      0.90      0.88      1000
         4.0       0.79      0.84      0.81      1000
         5.0       0.96      0.97      0.96      1000
         6.0       0.77      0.57      0.65      1000
         7.0       0.95      0.93      0.94      1000
         8.0       0.98      0.96      0.97      1000
         9.0       0.95      0.97      0.96      1000

    accuracy                           0.88     10000
   macro avg       0.88      0.88      0.88     10000
weighted avg       0.88      0.88      0.88     10000

Channels First vs Channels Last

In our example, we used grayscale images. Grayscale images generally do not have an explicit channel dimension, as there is only one channel, but since the convolution layers require one, we introduced that dimension in our images. Color (RGB) images have 3 channels. We can represent the channel dimension in our multi-dimensional array of images in two different ways.

  1. Channels First - Here, we represent color image of (28,28) pixels as (3,28,28) array.
  2. Channels Last - Here, we represent color image of (28,28) pixels as (28,28,3) array.

By default, the convolution layers available through Flax expect the channel dimension to be the last one, i.e. the Channels Last format. They currently can not handle the Channels First format.
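
If data happens to come in Channels First format, one simple workaround (our own suggestion, not something the examples below do) is to move the channel axis to the end with a transpose before feeding the data to the convolution layer.

channels_first = jnp.zeros((50, 3, 28, 28))                   ## (N, C, H, W)
channels_last = jnp.transpose(channels_first, (0, 2, 3, 1))   ## (N, H, W, C)

print(channels_last.shape)  ## (50, 28, 28, 3)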

Below, we have first shown an example where we create 2 convolution layers and apply them to input data in the Channels Last format. We have printed the output shape of both layers as well as the weights shapes to give an idea of how the convolution operations are applied.

Then, in the next cell, we have again created two convolution layers and applied them to input data which has the channel dimension at the beginning (Channels First). We can notice from the output shapes how the calculations go wrong with Channels First data: the last dimension (here the image width) gets treated as the channel dimension.

conv_layer1 = flax.linen.Conv(16, (3,3))
conv_layer2 = flax.linen.Conv(32, (3,3))

seed = jax.random.PRNGKey(123)

params1 = conv_layer1.init(seed, jax.random.uniform(seed, (50,28,28,1)))
preds1 = conv_layer1.apply(params1, jax.random.uniform(seed,(50,28,28,1)))

params2 = conv_layer2.init(seed, jax.random.uniform(seed, preds1.shape))
preds2 = conv_layer2.apply(params2, jax.random.uniform(seed,preds1.shape))

print("Weights of First Conv Layer : {}".format(params1["params"]["kernel"].shape))
print("Weights of Second Conv Layer : {}".format(params2["params"]["kernel"].shape))

print("\nInput Shape               : {}".format((50,28,28,1)))
print("Conv Layer 1 Output Shape : {}".format(preds1.shape))
print("Conv Layer 2 Output Shape : {}".format(preds2.shape))
Weights of First Conv Layer : (3, 3, 1, 16)
Weights of Second Conv Layer : (3, 3, 16, 32)

Input Shape               : (50, 28, 28, 1)
Conv Layer 1 Output Shape : (50, 28, 28, 16)
Conv Layer 2 Output Shape : (50, 28, 28, 32)
conv_layer1 = flax.linen.Conv(16, (3,3))
conv_layer2 = flax.linen.Conv(32, (3,3))

seed = jax.random.PRNGKey(123)

params1 = conv_layer1.init(seed, jax.random.uniform(seed, (50,1,28,28)))
preds1 = conv_layer1.apply(params1, jax.random.uniform(seed,(50,1,28,28)))

params2 = conv_layer2.init(seed, jax.random.uniform(seed, preds1.shape))
preds2 = conv_layer2.apply(params2, jax.random.uniform(seed,preds1.shape))

print("Weights of First Conv Layer : {}".format(params1["params"]["kernel"].shape))
print("Weights of Second Conv Layer : {}".format(params2["params"]["kernel"].shape))

print("\nInput Shape               : {}".format((50,1,28,28)))
print("Conv Layer 1 Output Shape : {}".format(preds1.shape))
print("Conv Layer 2 Output Shape : {}".format(preds2.shape))
Weights of First Conv Layer : (3, 3, 28, 16)
Weights of Second Conv Layer : (3, 3, 16, 32)

Input Shape               : (50, 1, 28, 28)
Conv Layer 1 Output Shape : (50, 1, 28, 16)
Conv Layer 2 Output Shape : (50, 1, 28, 32)

This ends our small tutorial explaining how to design a convolutional neural network (CNN) using Flax, a high-level framework designed on top of JAX. We used the optax library for optimizers, as Flax has deprecated its own optimizer module. Please feel free to let us know your views in the comments section.
