Convolutional Neural Networks (CNNs) are a class of neural networks that use convolution layers, which apply convolution operations to the input data to detect patterns. CNNs are commonly used for tasks involving visual imagery (object detection, image classification, image segmentation, etc.), where they generally outperform other types of neural networks. They are also used for natural language processing and time-series tasks. For inputs with many features, CNNs generally have far fewer parameters to train than fully connected (dense) neural networks, because each convolution kernel is small and its weights are shared across the whole input. If you want to learn about convolutional neural networks in depth, please feel free to check our blog, which covers the theory in detail.
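
To make the parameter claim concrete, here is a small back-of-the-envelope sketch in plain Python. It assumes a hypothetical comparison between two layers that both map a 28x28 grayscale image to a 28x28x32 output. The first is one 3x3 convolution layer with 'SAME' padding. The second is a dense layer that produces the same number of output values.

```python
# Hypothetical comparison: both layers map a 28x28x1 image to 28x28x32 values.
height, width, in_ch, out_ch, k = 28, 28, 1, 32, 3

# Convolution: one k x k kernel per (input channel, output channel) pair,
# plus one bias per output channel. The image size does not appear here.
conv_params = k * k * in_ch * out_ch + out_ch

# Dense: every input value is connected to every output value,
# plus one bias per output value.
n_in, n_out = height * width * in_ch, height * width * out_ch
dense_params = n_in * n_out + n_out

print(conv_params, dense_params)  # 320 19694080
```

The convolution's count depends only on the kernel size and channel counts. The dense layer's count grows with the product of the input and output sizes.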

As a part of this tutorial, we'll cover how to create a simple convolutional neural network using **JAX**. **JAX** is a popular Python library for high-performance numerical computing and machine learning research. It provides a NumPy-like API that runs on CPUs/GPUs/TPUs, automatic differentiation, just-in-time compilation, etc. If you do not have a background in **JAX**, we recommend going through our tutorial that covers its basics. It'll help you with this tutorial as well.

In this tutorial, we'll design a simple convolutional neural network using the high-level **stax** API of **JAX**. We have another tutorial on the **stax** API that describes how to create simple fully connected neural networks. Please feel free to check it, as it will help with this tutorial too.

Below we have highlighted important sections of the tutorial.

- Simple Convolutional Neural Network
    - Load Fashion MNIST Dataset
    - Create Neural Network
    - Define Loss Function
    - Train Neural Network (SGD Optimizer)
    - Make Predictions
    - Evaluate Model Performance
    - Train Neural Network (Adam Optimizer)
    - Make Predictions
    - Evaluate Model Performance
- Guide to Handle Channels First vs Channels Last

Below we have imported **JAX** and printed the version of it that we'll be using in our tutorial.

```
import jax
print("JAX Version : {}".format(jax.__version__))
```

In this section, we'll explain the step-by-step process of creating a convolutional neural network and training it. Below, we have imported the **JAX** submodules that we'll use in this tutorial. These are **stax** and **optimizers**, for creating the neural network and the optimizers respectively, and **jax.numpy**, which provides a NumPy-like API for creating and manipulating arrays.

```
from jax.example_libraries import stax, optimizers
from jax import numpy as jnp
```

In this section, we have loaded the Fashion MNIST dataset available from **keras**. The dataset has **28x28** grayscale images of 10 fashion items and is already divided into train (60k samples) and test (10k samples) sets. We have converted the datasets, which are loaded as NumPy arrays, to **JAX** arrays as required by a model built in **JAX**. Then, we have reshaped the images to shape **(28,28,1)**, where the extra last dimension represents the channel. Convolution layers require a channel dimension. Our images are grayscale and do not have 3 channels (RGB) like color images, so we have introduced a single-channel dimension. After reshaping, we have divided both datasets by the float 255 to normalize them, because the images are loaded as integers in the range **[0,255]**. Scaling the inputs to **[0,1]** generally helps the optimization algorithm converge faster.

```
from tensorflow import keras

(X_train, Y_train), (X_test, Y_test) = keras.datasets.fashion_mnist.load_data()

## Images as float32, labels as integer class ids
X_train, X_test = jnp.array(X_train, dtype=jnp.float32), jnp.array(X_test, dtype=jnp.float32)
Y_train, Y_test = jnp.array(Y_train, dtype=jnp.int32), jnp.array(Y_test, dtype=jnp.int32)

X_train, X_test = X_train.reshape(-1,28,28,1), X_test.reshape(-1,28,28,1) ## Add channel dimension
X_train, X_test = X_train/255.0, X_test/255.0 ## Scale pixel values to [0,1]
classes = jnp.unique(Y_train)
X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
```
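
You can check the reshaping and scaling steps without downloading the dataset. Below is a minimal sketch that applies them to a small batch of random `uint8` images. The fake data and the name `fake_images` are assumptions for illustration only and stand in for the Fashion MNIST arrays returned by keras.

```python
import numpy as np
from jax import numpy as jnp

# Hypothetical stand-in for keras' (n_samples, 28, 28) uint8 Fashion MNIST images.
fake_images = np.random.default_rng(0).integers(0, 256, size=(4, 28, 28), dtype=np.uint8)

x = jnp.array(fake_images, dtype=jnp.float32)  # convert to a float32 JAX array
x = x.reshape(-1, 28, 28, 1)                   # add a trailing channel dimension
x = x / 255.0                                  # scale pixel values from [0, 255] to [0, 1]

print(x.shape)  # (4, 28, 28, 1)
print(float(x.min()) >= 0.0, float(x.max()) <= 1.0)  # True True
```

Indexing with `x[..., None]` is an equivalent way to add the trailing channel dimension.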

In this section, we have created the convolutional neural network that we'll use to classify the Fashion MNIST images. Our CNN is simple, with only 2 convolution layers and one linear/dense layer. The first convolution layer has **32** output channels and a kernel size of **(3,3)**. The second convolution layer has **16** output channels and a kernel size of **(3,3)**. Both use **'SAME'** padding. This pads the input before the kernel is applied, so that the output has the same height and width as the input (with the default stride of 1).

After the first layer, the images are transformed from shape **(n_samples,height,width,1)** to **(n_samples,height,width,32)**. In other words, the 1-channel images become 32-channel images. In our case, both height and width are **28**. After the convolution layer, we have applied the **ReLU (rectified linear unit)** activation to its output. The second layer transforms the images from shape **(n_samples,height,width,32)** to **(n_samples,height,width,16)**, and we have again applied **ReLU** activation to its output.
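
The effect of the padding mode on the output shape can be checked without running any data through the layer. Below is a small sketch that uses only the output shape returned by `init()`. It compares the tutorial's 'SAME' padding with 'VALID' padding (no padding), which is used here only for contrast. The batch size of 8 is an arbitrary assumption.

```python
import jax
from jax.example_libraries import stax

rng = jax.random.PRNGKey(0)
input_shape = (8, 28, 28, 1)  # (n_samples, height, width, channels)

# init() returns (output_shape, params); only the shape is needed here.
same_init, _ = stax.Conv(32, (3, 3), padding="SAME")
valid_init, _ = stax.Conv(32, (3, 3), padding="VALID")

same_shape, _ = same_init(rng, input_shape)
valid_shape, _ = valid_init(rng, input_shape)

# Convert dimensions to plain ints for readable printing.
print("SAME :", tuple(int(d) for d in same_shape))   # (8, 28, 28, 32)
print("VALID:", tuple(int(d) for d in valid_shape))  # (8, 26, 26, 32)
```

With a 3x3 kernel and no padding, each spatial dimension shrinks by 2. 'SAME' pads the input so that height and width are preserved.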

After the two convolution layers, we have flattened the output, so the new shape is **(n_samples, height x width x 16) = (n_samples, 28 x 28 x 16) = (n_samples, 12544)**. After flattening, we have added the final dense layer with as many units as there are classes, i.e. **10**. Finally, we have applied the **softmax** activation function, which maps the 10 output values into the range **[0,1]** so that they sum to 1. We can then predict the image class as the one with the highest of these 10 values.
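
Below is a quick sketch that checks the shapes described above and counts the trainable parameters of this architecture. It rebuilds the same layer stack and assumes 10 classes. It also checks that the softmax outputs of each sample sum to 1. The dummy all-zeros input is only there to run a forward pass.

```python
import jax
from jax import numpy as jnp
from jax.example_libraries import stax

# Same layer stack as the CNN described above, assuming 10 classes.
init_fn, apply_fn = stax.serial(
    stax.Conv(32, (3, 3), padding="SAME"), stax.Relu,
    stax.Conv(16, (3, 3), padding="SAME"), stax.Relu,
    stax.Flatten,
    stax.Dense(10),
    stax.Softmax,
)
out_shape, params = init_fn(jax.random.PRNGKey(0), (2, 28, 28, 1))

# Relu/Flatten/Softmax have no weights (empty tuples), so only arrays are counted.
n_params = sum(p.size for p in jax.tree_util.tree_leaves(params))
print("Output shape :", tuple(int(d) for d in out_shape))  # (2, 10)
print("Dense weights:", params[5][0].shape)                 # (12544, 10)
print("Parameters   :", n_params)
# Hand count: conv1 3*3*1*32+32 = 320, conv2 3*3*32*16+16 = 4624,
# dense 12544*10+10 = 125450, for 130394 in total.

probs = apply_fn(params, jnp.zeros((2, 28, 28, 1)))
print(probs.sum(axis=1))  # each row sums to 1 (up to float32 rounding)
```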

We have created our neural network using the **serial()** function of **stax**, passing it the layers in the order in which they should be applied. The **serial()** function returns two functions.

- **init()** - This function takes a **JAX** pseudo-random number generator key and the input data shape as input. It generates the model weights and returns a tuple of two values: the output shape of the network and the weights.
- **apply()** - This function takes model weights and data as input. It then performs a forward pass of the data through the network using the weights.
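
The sketch below inspects what `init()` returns. It uses a tiny hypothetical model that is not the tutorial's CNN and exists only to keep the output small. The weights hold one entry per layer, and layers without weights (such as **Relu** and **Flatten**) hold an empty tuple. This is why the weight-printing loop further down skips empty entries.

```python
import jax
from jax.example_libraries import stax

# A tiny hypothetical model, used only to inspect what init() returns.
init_fn, apply_fn = stax.serial(
    stax.Conv(4, (3, 3), padding="SAME"), stax.Relu, stax.Flatten, stax.Dense(3)
)
rng = jax.random.PRNGKey(0)
output_shape, params = init_fn(rng, (5, 8, 8, 1))  # init() -> (output_shape, params)

# One entry per layer: (W, b) for Conv/Dense, () for Relu/Flatten.
print([len(layer) for layer in params])  # [2, 0, 0, 2]
print(params[0][0].shape)                 # Conv kernel in HWIO layout: (3, 3, 1, 4)

# apply() runs the forward pass with those params.
out = apply_fn(params, jax.random.normal(rng, (5, 8, 8, 1)))
print(out.shape)                          # (5, 3)
```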

After creating the CNN, in the next cell, we have initialized the weights of the neural network using the **init()** function and printed their shapes. Then, in the following cell, we have performed a forward pass of a few data samples through the network using the **apply()** function and printed the predictions.

```
conv_init, conv_apply = stax.serial(
    stax.Conv(32,(3,3), padding="SAME"),
    stax.Relu,
    stax.Conv(16, (3,3), padding="SAME"),
    stax.Relu,
    stax.Flatten,
    stax.Dense(len(classes)),
    stax.Softmax
)
```

```
rng = jax.random.PRNGKey(123)
output_shape, weights = conv_init(rng, (18,28,28,1)) ## init() returns (output shape, weights)
for w in weights:
    if w: ## Layers without weights (Relu, Flatten, Softmax) hold an empty tuple
        w, b = w
        print("Weights : {}, Biases : {}".format(w.shape, b.shape))
```

```
preds = conv_apply(weights, X_train[:5])
preds
```

In this section, we have defined the loss function for our task: categorical **cross-entropy**. The function takes weights, input data, and actual target values as input. It first makes predictions on the input data with the given weights using the **apply()** function. Then it one-hot encodes the actual target values as required by the loss, takes the log of the predictions, and multiplies the log of the predictions by the one-hot encoded targets. Finally, it returns the negative sum of the resulting array. Note that this sums (rather than averages) the loss over the samples in a batch.

Later, we'll be calculating the gradient of this loss function with respect to model weights which is the first parameter of the function.

```
def CrossEntropyLoss(weights, input_data, actual):
    preds = conv_apply(weights, input_data) ## Forward pass (softmax probabilities)
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(classes))
    log_preds = jnp.log(preds)
    return - jnp.sum(one_hot_actual * log_preds)
```
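
The small worked example below uses made-up probabilities for 2 samples and 3 classes, which are assumptions for illustration. It shows that multiplying by the one-hot targets keeps only the log-probability of each sample's true class.

```python
import jax
from jax import numpy as jnp

# Hypothetical softmax outputs for 2 samples and 3 classes, with their true labels.
preds = jnp.array([[0.7, 0.2, 0.1],
                   [0.1, 0.3, 0.6]])
actual = jnp.array([0, 2])

one_hot_actual = jax.nn.one_hot(actual, num_classes=3)  # [[1,0,0],[0,0,1]]
loss = -jnp.sum(one_hot_actual * jnp.log(preds))

# Only the true-class log-probabilities survive the multiplication,
# so the loss equals -(log 0.7 + log 0.6), which is about 0.8675.
print(loss, -(jnp.log(0.7) + jnp.log(0.6)))
```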

In this section, we train our CNN. We have designed a function that takes the data and a few other parameters as input and performs the training. Its inputs are the data features, the target values, the number of epochs, the optimizer state (which holds the weights), and the batch size. It then runs the training loop for the given number of epochs.

For each epoch, the function splits the data into batches. For each batch, it performs a forward pass through the network to make predictions, calculates the loss, and then calculates the gradients of the loss with respect to the weights. It does all of that with a single call to a function created by **JAX**'s **value_and_grad()**. **value_and_grad()** takes our loss function as input and returns a new function. When called with the loss function's arguments, this new function returns two values: the loss value (cross-entropy) and the gradients of the loss with respect to the weights. By default, **value_and_grad()** differentiates with respect to the first parameter of the loss function, which is the weights in our case. The forward pass is performed inside the loss function.
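
Here is a minimal sketch of **value_and_grad()** on a toy squared-error loss. The function `toy_loss` and its data are hypothetical. Its first argument is a `(w, b)` tuple, in the same way that the CNN weights are the first argument of `CrossEntropyLoss`. The gradients come back with the same tuple structure as that first argument.

```python
import jax
from jax import numpy as jnp

# Hypothetical toy loss: the (w, b) tuple plays the role of the CNN weights.
def toy_loss(params, x, y):
    w, b = params
    return jnp.sum((w * x + b - y) ** 2)

params = (jnp.array(2.0), jnp.array(1.0))
x, y = jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])

# Gradients are taken w.r.t. the first argument (params) only.
loss, grads = jax.value_and_grad(toy_loss)(params, x, y)
print(loss)   # residuals are [0, 1], so loss = 1.0
print(grads)  # (w, b) structure: d/dw = 2*1*2 = 4.0, d/db = 2*1 = 2.0
```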

The weights of our neural network are kept in the optimizer state object. We can retrieve them by passing the optimizer state to the third function returned by the optimizer. The JAX documentation usually names this function **get_params()**, and we store it as **opt_get_weights**. The optimizer state is explained in more detail when we create the optimizer.

After calculating the gradients, we update the weights of the neural network by calling the optimizer's **update()** function. It takes the current step index, the gradients, and the optimizer state as input, and returns a new optimizer state holding the updated weights.

We also print the average cross-entropy loss across all batches for each epoch. Finally, the function returns the last optimizer state, which holds the updated weights of the neural network.

```
from jax import value_and_grad

## Compile the loss + gradient computation once; it is reused for every batch
loss_and_grad = jax.jit(value_and_grad(CrossEntropyLoss))

def TrainModelInBatches(X, Y, epochs, opt_state, batch_size=32):
    step = 0 ## Global step counter passed to opt_update() (Adam uses it for bias correction)
    for i in range(1, epochs+1):
        losses = [] ## Record loss of each batch
        for start in range(0, X.shape[0], batch_size): ## Start index of each batch
            X_batch, Y_batch = X[start:start+batch_size], Y[start:start+batch_size] ## Single batch of data
            ## Forward pass, loss and gradients w.r.t. weights in one call
            loss, gradients = loss_and_grad(opt_get_weights(opt_state), X_batch, Y_batch)
            ## Update Weights
            opt_state = opt_update(step, gradients, opt_state)
            step += 1
            losses.append(loss) ## Record Loss
        print("CrossEntropyLoss : {:.3f}".format(jnp.array(losses).mean()))
    return opt_state
```

Now, we train our neural network by calling the **TrainModelInBatches()** function defined above. We have first set the learning rate to **0.0001**, the number of epochs to **25**, and the batch size to **256**.

We have then initialized the weights of our CNN using the **init()** function.

Then, we have created an **SGD** optimizer by passing the learning rate to **optimizers.sgd()**, which returns three functions.

- **init()** - It initializes the optimizer with the weights of the neural network and returns the initial **OptimizerState** object, which holds the model weights.
- **update()** - It takes the step index, the gradients, and the optimizer state as input, and returns a new optimizer state with updated weights (for SGD, the weights minus the learning rate times the gradients).
- **get_params()** - It takes an **OptimizerState** object as input and returns the model weights. We store this function as **opt_get_weights**.
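
To see the three functions in isolation, here is a minimal sketch of a single SGD step on a toy scalar weight. The loss `w**2`, whose gradient is `2*w`, is an assumption chosen so that the update is easy to verify by hand.

```python
from jax import numpy as jnp
from jax.example_libraries import optimizers

# One SGD step on the toy loss w**2 (gradient 2*w), starting from w = 1.0.
opt_init, opt_update, get_params = optimizers.sgd(0.1)
state = opt_init(jnp.array(1.0))     # OptimizerState wrapping the weights
w = get_params(state)                # read the weights back out
state = opt_update(0, 2 * w, state)  # arguments: step index, gradients, state
print(get_params(state))             # 1.0 - 0.1 * 2.0 = 0.8
```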

After initializing the optimizer and the optimizer state, we have called our **TrainModelInBatches()** function with the training data and the other parameters. We have stored the final optimizer state it returns, which we'll use later to make predictions.

We can notice from the printed cross-entropy loss that it decreases over the epochs, which suggests that the model is improving. We'll evaluate its performance later by calculating accuracy and classification metrics like precision, recall, and f1-score.

```
seed = jax.random.PRNGKey(123)
learning_rate = 1e-4
epochs = 25
batch_size = 256
_, weights = conv_init(seed, (batch_size,28,28,1)) ## init() returns (output shape, weights)
opt_init, opt_update, opt_get_weights = optimizers.sgd(learning_rate)
opt_state = opt_init(weights) ## OptimizerState holding the initial weights
final_opt_state = TrainModelInBatches(X_train, Y_train, epochs, opt_state, batch_size=batch_size)
```

In this section, we make predictions on the train and test data. We have created a small function that loops through the data in batches and makes predictions, and we then combine the predictions of all batches. Our CNN outputs 10 probabilities per sample. To convert them into a class label, we take the index of the maximum value using the **argmax()** function.
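
Here is a tiny sketch of the combine-and-argmax step. It uses made-up per-batch outputs for 3 classes, which are assumptions for illustration, in place of the CNN's 10-class outputs.

```python
from jax import numpy as jnp

# Hypothetical per-batch softmax outputs (2 batches, 3 classes), collected in a list.
batch_outputs = [jnp.array([[0.1, 0.7, 0.2],
                            [0.8, 0.1, 0.1]]),
                 jnp.array([[0.2, 0.2, 0.6]])]

probs = jnp.concatenate(batch_outputs)  # shape (3, 3): one row per sample
labels = jnp.argmax(probs, axis=1)      # index of the largest probability in each row
print(labels)                           # [1 0 2]
```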

```
def MakePredictions(weights, input_data, batch_size=32):
    preds = []
    for start in range(0, input_data.shape[0], batch_size): ## Start index of each batch
        X_batch = input_data[start:start+batch_size] ## The last batch may be smaller
        preds.append(conv_apply(weights, X_batch))
    return preds
```

```
test_preds = MakePredictions(opt_get_weights(final_opt_state), X_test, batch_size=batch_size)
test_preds = jnp.concatenate(test_preds).squeeze() ## Combine predictions of all batches
test_preds = jnp.argmax(test_preds, axis=1)
train_preds = MakePredictions(opt_get_weights(final_opt_state), X_train, batch_size=batch_size)
train_preds = jnp.concatenate(train_preds).squeeze() ## Combine predictions of all batches
train_preds = jnp.argmax(train_preds, axis=1)
test_preds[:5], train_preds[:5]
```

In this section, we have evaluated the performance of our CNN by calculating the accuracy of the train and test predictions with the **accuracy_score()** function from scikit-learn.

Then, we have calculated the classification report of the test predictions using scikit-learn's **classification_report()** function. The report includes the precision, recall, and f1-score for each target class.

If you want to learn about various metrics available through scikit-learn then please feel free to check our tutorial that covers the majority of them in detail.

```
from sklearn.metrics import accuracy_score
print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, train_preds)))
print("Test Accuracy : {:.3f}".format(accuracy_score(Y_test, test_preds)))
```

```
from sklearn.metrics import classification_report
print("Test Classification Report ")
print(classification_report(Y_test, test_preds))
```
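
Accuracy is simply the fraction of predictions that match the true labels. The quick cross-check below computes it directly with **jax.numpy** and compares the result with **accuracy_score()**. The toy labels are assumptions for illustration.

```python
import numpy as np
from jax import numpy as jnp
from sklearn.metrics import accuracy_score

# Hypothetical labels and predictions for 5 samples.
y_true = jnp.array([0, 1, 2, 2, 1])
y_pred = jnp.array([0, 2, 2, 2, 1])

manual = jnp.mean(y_true == y_pred)  # 4 of 5 correct, so about 0.8
sk = accuracy_score(np.asarray(y_true), np.asarray(y_pred))
print(float(manual), sk)
```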

In this section, we have trained our CNN again, this time using the **Adam** optimizer instead of **SGD**. The other parameters, like the learning rate, number of epochs, and batch size, are the same as in the previous training. The only change is the optimizer.
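
Unlike SGD with a constant learning rate, Adam uses the step index passed to **update()** for its bias correction. This is why the step argument should count update steps (batches) rather than, for example, epochs. The sketch below shows that the same gradient applied to the same state gives different updates at step 0 and step 5. The toy scalar weight and gradient are assumptions.

```python
from jax import numpy as jnp
from jax.example_libraries import optimizers

# Adam bias-corrects its moment estimates using the step index i,
# so the same gradient gives different updates at different step numbers.
opt_init, opt_update, get_params = optimizers.adam(0.1)
state = opt_init(jnp.array(1.0))
grad = jnp.array(2.0)

w_step0 = get_params(opt_update(0, grad, state))  # first step: about 1.0 - 0.1 = 0.9
w_step5 = get_params(opt_update(5, grad, state))  # same state, but step index 5
print(w_step0, w_step5)
```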

```
seed = jax.random.PRNGKey(123)
learning_rate = 1e-4
epochs = 25
batch_size = 256
_, weights = conv_init(seed, (batch_size,28,28,1)) ## init() returns (output shape, weights)
opt_init, opt_update, opt_get_weights = optimizers.adam(learning_rate)
opt_state = opt_init(weights) ## OptimizerState holding the initial weights
final_opt_state = TrainModelInBatches(X_train, Y_train, epochs, opt_state, batch_size=batch_size)
```

In this section, we have made predictions on train and test datasets using new weights calculated using **Adam** optimizer.

```
test_preds = MakePredictions(opt_get_weights(final_opt_state), X_test, batch_size=batch_size)
test_preds = jnp.concatenate(test_preds).squeeze() ## Combine predictions of all batches
test_preds = jnp.argmax(test_preds, axis=1)
train_preds = MakePredictions(opt_get_weights(final_opt_state), X_train, batch_size=batch_size)
train_preds = jnp.concatenate(train_preds).squeeze() ## Combine predictions of all batches
train_preds = jnp.argmax(train_preds, axis=1)
test_preds[:5], train_preds[:5]
```

In this section, we have evaluated the performance of the CNN trained with the **Adam** optimizer by calculating the accuracy and the classification report. In our run, the results suggested that **SGD** did a better job than **Adam** with these settings.

```
from sklearn.metrics import accuracy_score
print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, train_preds)))
print("Test Accuracy : {:.3f}".format(accuracy_score(Y_test, test_preds)))
```

```
from sklearn.metrics import classification_report
print("Test Classification Report ")
print(classification_report(Y_test, test_preds))
```

Color (RGB) images have one channel for each color: red, green, and blue. Other image types, like RGBA and CMYK (4 channels) or HSV (3 channels), also combine several channels to create the final image. A CNN applies its kernels across the channels of an image and transforms them. Images are generally represented as multi-dimensional arrays, and there are two common layouts for representing a single RGB image.

- **Channels First** - Here, we maintain 3 **2D** arrays, one per channel. A 28x28-pixel RGB image is represented with shape **(3,28,28)**, i.e. 3 arrays of shape **(28,28)**.
- **Channels Last** - Here, we keep the channel values of each pixel together. A 28x28-pixel RGB image is represented with shape **(28,28,3)**, i.e. 3 values for each pixel of the **(28,28)** image.
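
Converting between the two layouts is just a transpose of the axes. The sketch below converts a hypothetical batch of channels-first RGB images to channels-last and back with `jnp.transpose`. This is a generic alternative for feeding channels-first data to a channels-last layer, separate from the **GeneralConv()** approach discussed below.

```python
from jax import numpy as jnp

# Hypothetical batch of 50 channels-first RGB images: (N, C, H, W).
images_nchw = jnp.zeros((50, 3, 28, 28))

# Move the channel axis to the end: (N, H, W, C).
images_nhwc = jnp.transpose(images_nchw, (0, 2, 3, 1))
# And back again: (N, C, H, W).
images_back = jnp.transpose(images_nhwc, (0, 3, 1, 2))

print(images_nhwc.shape, images_back.shape)  # (50, 28, 28, 3) (50, 3, 28, 28)
```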

In our example, we modified our grayscale images from shape **(28,28)** to **(28,28,1)**, introducing an extra channel dimension for our convolution layers. The default convolution layer **Conv()** of **stax** expects the channel dimension to be last. It does not handle images whose channel dimension comes first.

Our next two cells show with examples how the default **Conv()** layer properly handles images whose channel dimension is last, but fails to handle images whose channel dimension is first.

The first cell below shows how it properly transforms images of size **(28,28)** from **1** channel to **16** channels and then from **16** to **32** channels. The second cell runs without an error but gives wrong results. For input of shape **(50,1,28,28)**, **Conv()** treats the last dimension (of size **28**) as the channels, so the outputs have shapes **(50,1,28,16)** and **(50,1,28,32)** instead of the single channel being transformed.

```
rng = jax.random.PRNGKey(123)
init_conv1, apply_conv1 = stax.Conv(16, (3,3), padding="SAME")
weights1 = init_conv1(rng, (50,28,28,1))
preds1 = apply_conv1(weights1[1], jax.random.uniform(rng, (50,28,28,1)))
init_conv2, apply_conv2 = stax.Conv(32, (3,3), padding="SAME")
weights2 = init_conv2(rng, preds1.shape)
preds2 = apply_conv2(weights2[1], jax.random.uniform(rng, preds1.shape))
print("Weights of First Conv Layer : {}".format(weights1[1][0].shape))
print("Weights of Second Conv Layer : {}".format(weights2[1][0].shape))
print("\nInput Shape : {}".format((50,28,28,1)))
print("Conv Layer 1 Output Shape : {}".format(preds1.shape))
print("Conv Layer 2 Output Shape : {}".format(preds2.shape))
```

```
rng = jax.random.PRNGKey(123)
init_conv1, apply_conv1 = stax.Conv(16, (3,3), padding="SAME")
weights1 = init_conv1(rng, (50,1,28,28))
preds1 = apply_conv1(weights1[1], jax.random.uniform(rng, (50,1,28,28)))
init_conv2, apply_conv2 = stax.Conv(32, (3,3), padding="SAME")
weights2 = init_conv2(rng, preds1.shape)
preds2 = apply_conv2(weights2[1], jax.random.uniform(rng, preds1.shape))
print("Weights of First Conv Layer : {}".format(weights1[1][0].shape))
print("Weights of Second Conv Layer : {}".format(weights2[1][0].shape))
print("\nInput Shape : {}".format((50,1,28,28)))
print("Conv Layer 1 Output Shape : {}".format(preds1.shape))
print("Conv Layer 2 Output Shape : {}".format(preds2.shape))
```

Below, we have printed the definition of the **Conv()** layer. We can notice that it fixes the layout as a tuple of three strings. The first string describes the dimensions of the input images, the second the dimensions of the kernel, and the third the dimensions of the output produced by applying the kernel to the input images.

- **NHWC** - (Number of samples, Height, Width, Channels) - The dimensions of the input images (and, as the third string, of the output images).
- **HWIO** - (Height, Width, Input Channels, Output Channels) - The dimensions of the convolution layer's kernel.

We can notice from this default configuration that the **Conv()** layer requires the channel dimension to be last.

```
stax.Conv
```
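
Judging from the printed definition, **Conv()** appears to be **GeneralConv()** with the layout fixed to `("NHWC", "HWIO", "NHWC")`. The quick sketch below checks that both layers give the same output when they are given the same weights. The input size and the 8 output channels are arbitrary assumptions.

```python
import jax
from jax import numpy as jnp
from jax.example_libraries import stax

rng = jax.random.PRNGKey(0)
x = jax.random.uniform(rng, (2, 28, 28, 1))  # channels-last input

# Default Conv vs. GeneralConv with the NHWC/HWIO/NHWC layout spelled out.
init_a, apply_a = stax.Conv(8, (3, 3), padding="SAME")
_, apply_b = stax.GeneralConv(("NHWC", "HWIO", "NHWC"), 8, (3, 3), padding="SAME")

_, params = init_a(rng, x.shape)  # reuse the same weights for both layers
out_a = apply_a(params, x)
out_b = apply_b(params, x)
print(out_a.shape, bool(jnp.allclose(out_a, out_b)))  # (2, 28, 28, 8) True
```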

But what if our images have the channel dimension first? **stax** provides a layer named **GeneralConv()** that lets us specify the layouts of the images and the kernel as a tuple.

The **GeneralConv()** layer works exactly like **Conv()**, but its first argument is a tuple of 3 strings specifying the dimensions of the input images, the kernel, and the output images. We can handle channels-first images using the **GeneralConv()** layer.

Below, we have explained with two simple examples how to use the **GeneralConv()** layer when the channel dimension is first. The first example keeps the kernel in **HWIO** layout, and the second uses an **IOHW** kernel layout.

```
rng = jax.random.PRNGKey(123)
init_conv1, apply_conv1 = stax.GeneralConv(("NCHW", "HWIO", "NCHW"),16, (3,3), padding="SAME")
weights1 = init_conv1(rng, (50,1,28,28))
preds1 = apply_conv1(weights1[1], jax.random.uniform(rng, (50,1,28,28)))
init_conv2, apply_conv2 = stax.GeneralConv(("NCHW", "HWIO", "NCHW"),32, (3,3), padding="SAME")
weights2 = init_conv2(rng, preds1.shape)
preds2 = apply_conv2(weights2[1], jax.random.uniform(rng, preds1.shape))
print("Weights of First Conv Layer : {}".format(weights1[1][0].shape))
print("Weights of Second Conv Layer : {}".format(weights2[1][0].shape))
print("\nInput Shape : {}".format((50,1,28,28)))
print("Conv Layer 1 Output Shape : {}".format(preds1.shape))
print("Conv Layer 2 Output Shape : {}".format(preds2.shape))
```

```
rng = jax.random.PRNGKey(123)
init_conv1, apply_conv1 = stax.GeneralConv(("NCHW", "IOHW", "NCHW"),16, (3,3), padding="SAME")
weights1 = init_conv1(rng, (50,1,28,28))
preds1 = apply_conv1(weights1[1], jax.random.uniform(rng, (50,1,28,28)))
init_conv2, apply_conv2 = stax.GeneralConv(("NCHW", "IOHW", "NCHW"),32, (3,3), padding="SAME")
weights2 = init_conv2(rng, preds1.shape)
preds2 = apply_conv2(weights2[1], jax.random.uniform(rng, preds1.shape))
print("Weights of First Conv Layer : {}".format(weights1[1][0].shape))
print("Weights of Second Conv Layer : {}".format(weights2[1][0].shape))
print("\nInput Shape : {}".format((50,1,28,28)))
print("Conv Layer 1 Output Shape : {}".format(preds1.shape))
print("Conv Layer 2 Output Shape : {}".format(preds2.shape))
```
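
As a sanity check, **GeneralConv()** on channels-first data should compute the same result as transposing to channels-last, applying **Conv()** with the same kernel, and transposing back. The sketch below assumes this equivalence and compares the two numerically. The bias has to be reshaped, because its shape follows the output layout: (1,16,1,1) for NCHW and (1,1,1,16) for NHWC.

```python
import jax
from jax import numpy as jnp
from jax.example_libraries import stax

rng = jax.random.PRNGKey(0)
x_nchw = jax.random.uniform(rng, (4, 1, 28, 28))  # channels-first input

# Option 1: GeneralConv working on NCHW directly (kernel kept in HWIO).
init_g, apply_g = stax.GeneralConv(("NCHW", "HWIO", "NCHW"), 16, (3, 3), padding="SAME")
_, (W, b_nchw) = init_g(rng, x_nchw.shape)  # b_nchw has shape (1, 16, 1, 1)
out_g = apply_g((W, b_nchw), x_nchw)        # (4, 16, 28, 28)

# Option 2: transpose to NHWC, use stax.Conv with the same kernel, transpose back.
_, apply_c = stax.Conv(16, (3, 3), padding="SAME")
b_nhwc = b_nchw.reshape(1, 1, 1, 16)        # Conv adds a (1, 1, 1, 16) bias
out_c = apply_c((W, b_nhwc), jnp.transpose(x_nchw, (0, 2, 3, 1)))
out_c = jnp.transpose(out_c, (0, 3, 1, 2))  # back to (4, 16, 28, 28)

print(out_g.shape, bool(jnp.allclose(out_g, out_c, atol=1e-4)))
```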

This ends our small tutorial explaining how we can create simple convolutional neural networks using **JAX**. Please feel free to let us know your views in the comments section.
