**Haiku** is a neural network library built on top of **JAX** to simplify machine learning research. **JAX** is a numerical computing library that provides a numpy-like API that runs on accelerators (GPUs/TPUs), automatic differentiation, just-in-time compilation, and more. **Haiku** makes neural network development easier: it lets users define a neural network using object-oriented concepts (classes) and then translates that code into pure **JAX** functions. We have already covered how to use **Haiku** to create simple neural networks in an earlier tutorial. We recommend that readers go through that tutorial first, as it provides useful background for this one.

As a part of this tutorial, we'll create convolutional neural networks (CNNs) using **Haiku**. We'll create a simple CNN with a few convolution layers to solve a simple classification task on the Fashion MNIST dataset. If you want to learn about the theory behind CNNs and their pros/cons, please check our blog on the topic.

As **Haiku** is built on top of **JAX**, some background in **JAX** will help with this tutorial. If you want to learn about **JAX**, please check our short tutorial on it.

Below, we have highlighted the important sections of the tutorial to give an overview of the material covered.

- Simple Convolutional Neural Network
  - Load Dataset
  - Create CNN
    - Create CNN by Defining Individual Layers
    - Create CNN using Sequential API of Haiku
  - Define Loss Function
  - Train CNN (SGD)
  - Make Predictions
  - Evaluate Model Performance
  - Train CNN (Adam)
  - Make Predictions
  - Evaluate Model Performance
- Guide to Handle Channels First vs Channels Last

Below, we have imported the libraries that we'll use in this tutorial. We'll use an optimizer from the **optax** library in our example.

```
import haiku as hk
print("Haiku Version : {}".format(hk.__version__))
```

```
import jax
print("JAX Version : {}".format(jax.__version__))
```

```
import optax
print("Optax Version : {}".format(optax.__version__))
```

In this section, we'll explain step by step how to create and train a CNN using **Haiku**. We'll create a small convolutional neural network with 2 convolution layers for our classification task on the Fashion MNIST dataset.

In this section, we have loaded the Fashion MNIST dataset available from Keras. It has grayscale images (**28 x 28**) of 10 different fashion items like boots, shirts, pants, etc. Keras already divides the dataset into train (60k images) and test (10k images) sets. After loading the datasets, we have converted them from numpy arrays to **JAX** arrays.

Then, we have reshaped the images to introduce one extra dimension at the end. In computer vision, this dimension is referred to as the channel dimension. Color (RGB) images have 3 channels (one each for Red, Green, and Blue), whereas grayscale images have a single channel, and the arrays returned by Keras do not include a channel axis at all. Convolution layers operate over the channels of their input and transform them, so we need to add this axis explicitly for grayscale data. Finally, we have divided both train and test images by the float value 255. This brings pixel values into the range **[0,1]**, which typically helps optimization algorithms like gradient descent converge faster.

```
from jax import numpy as jnp
```

```
from tensorflow import keras

(X_train, Y_train), (X_test, Y_test) = keras.datasets.fashion_mnist.load_data()

X_train, X_test = jnp.array(X_train, dtype=jnp.float32), jnp.array(X_test, dtype=jnp.float32)
Y_train, Y_test = jnp.array(Y_train, dtype=jnp.int32), jnp.array(Y_test, dtype=jnp.int32)  ## Integer class labels

X_train, X_test = X_train.reshape(-1,28,28,1), X_test.reshape(-1,28,28,1)  ## Add channel dimension
X_train, X_test = X_train/255.0, X_test/255.0  ## Scale pixels to [0,1]

classes = jnp.unique(Y_train)

X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
```
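The reshape that adds the channel axis and the scaling to **[0,1]** can be tried on a small synthetic batch. The NumPy sketch below uses random `uint8` images as a stand-in for the Keras data (an assumption; nothing is downloaded) and checks that `reshape(-1,28,28,1)` is equivalent to appending a length-1 axis with `np.newaxis`.

```python
import numpy as np

# Hypothetical stand-in for the Keras data: 4 grayscale 28x28 images stored
# as uint8 pixel intensities in [0, 255], with no channel axis.
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(4, 28, 28), dtype=np.uint8)

# Option 1: reshape, as the tutorial does.
with_channel = images.reshape(-1, 28, 28, 1).astype(np.float32)
# Option 2: append a length-1 channel axis explicitly.
with_channel_alt = images[..., np.newaxis].astype(np.float32)

# Scale intensities from [0, 255] to [0, 1].
scaled = with_channel / 255.0

print(with_channel.shape)                        # (4, 28, 28, 1)
print(scaled.min() >= 0.0, scaled.max() <= 1.0)  # True True
```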

In this section, we have created the CNN that we'll use for our multi-class classification task. We explain two different ways of creating it.

First, we have created a class that extends the **hk.Module** class of **Haiku** to define our CNN. When we define a neural network by extending **hk.Module**, we need to implement two methods.

- **`__init__()`** - In this method, we define the layers of our model or the whole model.
- **`__call__()`** - In this method, we perform the forward pass through the network using the layers defined in the **`__init__()`** method. It returns the predictions.

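In our case below, we have defined two convolution layers, a flattening layer, and one dense/linear layer in the **`__init__()`** method. The first convolution layer has 32 output channels and the second has 16; both use a kernel of shape **(3,3)** and **"SAME"** padding. The linear layer has 10 output units, one per target class.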

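The **`__call__()`** method takes a batch of data as input. It applies the two convolution layers, flattens the output of the second convolution layer, and applies the linear layer to the result. We have applied the **relu** activation to the outputs of both convolution layers and the **softmax** activation to the output of the linear layer, so the network returns class probabilities.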

Our input shape to the CNN is **(n_samples,28,28,1)**. Because both convolution layers use **"SAME"** padding with the default stride of 1, they keep the spatial size at 28 x 28 and only change the number of channels. The first convolution layer translates the shape from **(n_samples,28,28,1)** to **(n_samples,28,28,32)**, and the second from **(n_samples,28,28,32)** to **(n_samples,28,28,16)**. The flatten layer translates the data from shape **(n_samples,28,28,16)** to **(n_samples, 28 x 28 x 16) = (n_samples,12544)**. Finally, the linear layer translates the data from shape **(n_samples,12544)** to **(n_samples,10)**.
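The shape bookkeeping above, and the parameter shapes we expect `init()` to print further down, can be checked with a few lines of arithmetic. This plain-Python sketch assumes Haiku's default NHWC layout, a `(kernel_h, kernel_w, in_channels, out_channels)` convolution kernel with an `(out_channels,)` bias, and an `(in_features, out_features)` linear weight; the helper names are our own.

```python
import math

def same_conv_shape(shape_nhwc, out_channels):
    """Output shape of a stride-1 convolution with "SAME" padding (NHWC layout)."""
    n, h, w, _ = shape_nhwc
    return (n, h, w, out_channels)  # "SAME" padding with stride 1 keeps H and W

def n_params(shapes):
    """Total number of scalars across all parameter arrays."""
    return sum(math.prod(s) for s in shapes.values())

n = 5  # e.g. the 5 samples passed to init() in this tutorial
conv1_out = same_conv_shape((n, 28, 28, 1), 32)
conv2_out = same_conv_shape(conv1_out, 16)
flat_features = math.prod(conv2_out[1:])  # 28 * 28 * 16

# Assumed parameter layouts (see lead-in).
conv1_params = {"w": (3, 3, 1, 32), "b": (32,)}
conv2_params = {"w": (3, 3, 32, 16), "b": (16,)}
linear_params = {"w": (flat_features, 10), "b": (10,)}

print(conv1_out, conv2_out, flat_features)  # (5, 28, 28, 32) (5, 28, 28, 16) 12544
print(n_params(conv1_params), n_params(conv2_params), n_params(linear_params))  # 320 4624 125450
```

Almost all of the roughly 130k parameters sit in the final linear layer, because flattening keeps the full 28 x 28 spatial grid.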

**JAX** is designed to work with pure functions, so we need to convert our CNN class into functions. To do that, we have created a function that takes input data, creates the CNN, applies it to the input data, and returns predictions. We then transform this function with Haiku's **transform()** function, which returns a **Transformed** object with two important methods.

- **init(seed, input_data)** - We can call this method on the **Transformed** object. It takes a **JAX** pseudo-random key and a few data samples as input, and returns a dictionary-like object that holds the model parameters (weights & biases) of each layer.
- **apply(params, seed, input_data)** - This method can also be called on the **Transformed** object. It takes model parameters, a pseudo-random key, and input data, performs a forward pass through the network using those weights, and returns predictions.

After transforming the network using **transform()**, we have initialized the model weights using the **init()** method and printed the shapes of the model parameters. We have then called the **apply()** method to perform a forward pass through the network with a few data samples. We can see from the output that the model works as expected.

```
class CNN(hk.Module):
    def __init__(self):
        super().__init__(name="CNN")
        self.conv1 = hk.Conv2D(output_channels=32, kernel_shape=(3,3), padding="SAME")
        self.conv2 = hk.Conv2D(output_channels=16, kernel_shape=(3,3), padding="SAME")
        self.flatten = hk.Flatten()
        self.linear = hk.Linear(len(classes))

    def __call__(self, x_batch):
        x = self.conv1(x_batch)
        x = jax.nn.relu(x)
        x = self.conv2(x)
        x = jax.nn.relu(x)
        x = self.flatten(x)
        x = self.linear(x)
        x = jax.nn.softmax(x)
        return x
```

```
def ConvNet(x):
    cnn = CNN()
    return cnn(x)

conv_net = hk.transform(ConvNet)
```

```
rng = jax.random.PRNGKey(42)
params = conv_net.init(rng, X_train[:5])
print("Weights Type : {}\n".format(type(params)))
for layer_name, weights in params.items():
    print(layer_name)
    print("Weights : {}, Biases : {}\n".format(weights["w"].shape, weights["b"].shape))
```

```
preds = conv_net.apply(params, rng, X_train[:5])
preds[:5]
```
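To make the `init()`/`apply()` contract concrete, here is a plain NumPy imitation of it for a single linear layer. This is not how Haiku implements `transform()`; `linear_init` and `linear_apply` are hypothetical names, and unlike Haiku's `apply()` this version takes no random key. It shows the two ideas that matter: parameter shapes are inferred from a sample batch, and the forward pass is a pure function of the parameters and the input.

```python
import numpy as np

def linear_init(seed, x, out_features=10):
    """Hypothetical init(): infer parameter shapes from a sample batch x."""
    rng = np.random.default_rng(seed)
    in_features = x.shape[-1]
    return {"linear": {"w": rng.normal(0.0, 0.01, size=(in_features, out_features)),
                       "b": np.zeros(out_features)}}

def linear_apply(params, x):
    """Hypothetical apply(): a pure forward pass using only params and x."""
    layer = params["linear"]
    return x @ layer["w"] + layer["b"]

x_sample = np.ones((5, 12544))          # e.g. 5 flattened CNN feature vectors
params = linear_init(42, x_sample)      # same seed -> same weights
out = linear_apply(params, x_sample)
print(params["linear"]["w"].shape, out.shape)  # (12544, 10) (5, 10)
```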

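In this section, we have explained the second way of creating a CNN using **Haiku**. Like last time, we have created a class by extending the **hk.Module** class. Inside the **`__init__()`** method, we have defined the whole model by stacking layers (and activation functions) inside **hk.Sequential()**, and the **`__call__()`** method simply passes the input batch through this sequential model.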

After defining the CNN, we have transformed it using the **transform()** function like earlier. We have then initialized the model parameters using the **init()** method and printed their shapes. We have also performed a forward pass through the network using the **apply()** method with a few samples for verification purposes.

```
class CNN(hk.Module):
    def __init__(self):
        super().__init__(name="CNN")
        self.conv_model = hk.Sequential([
            hk.Conv2D(output_channels=32, kernel_shape=(3,3), padding="SAME"),
            jax.nn.relu,
            hk.Conv2D(output_channels=16, kernel_shape=(3,3), padding="SAME"),
            jax.nn.relu,
            hk.Flatten(),
            hk.Linear(len(classes)),
            jax.nn.softmax
        ])

    def __call__(self, x_batch):
        return self.conv_model(x_batch)
```

```
def ConvNet(x):
    cnn = CNN()
    return cnn(x)

conv_net = hk.transform(ConvNet)
```

```
rng = jax.random.PRNGKey(42)
params = conv_net.init(rng, X_train[:5])
print("Weights Type : {}\n".format(type(params)))
for layer_name, weights in params.items():
    print(layer_name)
    print("Weights : {}, Biases : {}\n".format(weights["w"].shape, weights["b"].shape))
```

```
preds = conv_net.apply(params, rng, X_train[:5])
preds[:5]
```
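**hk.Sequential** calls its layers one after another, each on the output of the previous one. The NumPy sketch below imitates only that chaining with `functools.reduce`; `sequential`, `relu` and `flatten` are hypothetical stand-ins rather than Haiku code, and the sketch has no trainable parameters at all.

```python
import numpy as np
from functools import reduce

def sequential(layers):
    """Return a function that applies `layers` left to right."""
    def forward(x):
        return reduce(lambda acc, layer: layer(acc), layers, x)
    return forward

relu = lambda x: np.maximum(x, 0.0)
flatten = lambda x: x.reshape(x.shape[0], -1)  # keep the batch axis, merge the rest

model = sequential([relu, flatten])
x = np.array([[[-1.0, 2.0], [3.0, -4.0]]])     # shape (1, 2, 2)
print(model(x))  # [[0. 2. 3. 0.]]
```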

In this section, we have defined the loss function for our multi-class classification task. We use cross-entropy loss. The loss function takes model parameters, input data, and actual target values as input. It makes predictions on the input data by calling the **apply()** method with the model parameters, and one-hot encodes the target values. It then takes the log of the predicted probabilities, multiplies it element-wise with the one-hot encoded targets, sums the result over all samples in the batch, and returns the negated sum. Later, when we calculate gradients of this loss with respect to the parameters, they will be taken with respect to the first argument of this function (the model weights).

```
def CrossEntropyLoss(weights, input_data, actual):
    preds = conv_net.apply(weights, rng, input_data)
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(classes))
    log_preds = jnp.log(preds)
    return - jnp.sum(one_hot_actual * log_preds)
```
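Our model ends with **softmax** and the loss takes the log of those probabilities, so a probability that underflows to 0 turns into `-inf`. A common alternative (not what this tutorial does) is to let the network output raw logits and compute the loss with a log-softmax, for example `jax.nn.log_softmax`. The NumPy sketch below, with helper names of our own, checks that both formulas agree on ordinary inputs and that the logits version stays finite on an extreme input.

```python
import numpy as np

def log_softmax(logits):
    """Numerically stable log-softmax over the class axis."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

def cross_entropy_from_probs(probs, labels, num_classes):
    """Tutorial-style loss: log of softmax probabilities, summed over the batch."""
    one_hot = np.eye(num_classes)[labels]
    return -np.sum(one_hot * np.log(probs))

def cross_entropy_from_logits(logits, labels, num_classes):
    """Same loss computed from raw logits via log-softmax."""
    one_hot = np.eye(num_classes)[labels]
    return -np.sum(one_hot * log_softmax(logits))

logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0]])
labels = np.array([0, 2])
probs = np.exp(log_softmax(logits))  # softmax probabilities

print(np.isclose(cross_entropy_from_probs(probs, labels, 3),
                 cross_entropy_from_logits(logits, labels, 3)))  # True

# An extreme input where the true class's probability underflows to 0.0.
extreme = np.array([[0.0, 1000.0, 0.0]])
print(cross_entropy_from_logits(extreme, np.array([0]), 3))  # 1000.0
```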

In this section, we are training our CNN. To train it, we have set the number of epochs to **25**, the batch size to **256**, and the learning rate to **0.0001**. We have also initialized the model and its parameters. After initialization, we run the training loop for the given number of epochs.

For each epoch, we calculate the start and end indexes of batches of data and loop through the data in batches using these indexes. To perform a forward pass through the network and calculate the gradients of the loss, we have used the **value_and_grad()** function available from **JAX**. This function takes a function as input and returns another function that computes both the value of the input function and its gradient with respect to the first parameter. In our case, we have given our loss function to **value_and_grad()**, so the returned function calculates gradients of the loss with respect to the first parameter, which is the model weights. We call the returned function with the same parameters as the loss function. It returns two values: the first is the value of the loss for the given inputs, and the second is the gradients of the loss with respect to the model weights.
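The (value, gradient) pair that **value_and_grad()** returns can be illustrated on a tiny case. For the softmax cross-entropy of one sample, the gradient with respect to the logits is `softmax(z) - one_hot(y)`. The NumPy sketch below returns that pair by hand and verifies the gradient with central finite differences. It differentiates with respect to logits rather than the CNN weights purely to stay small, and `value_and_grad_manual` is a hypothetical name, not a JAX function.

```python
import numpy as np

def loss(z, y_one_hot):
    """Softmax cross-entropy of one example, computed from logits z."""
    shifted = z - z.max()
    log_p = shifted - np.log(np.exp(shifted).sum())
    return -np.sum(y_one_hot * log_p)

def value_and_grad_manual(z, y_one_hot):
    """Hypothetical helper: return (loss value, d loss / d z) as a pair."""
    p = np.exp(z - z.max())
    p = p / p.sum()                      # softmax probabilities
    return loss(z, y_one_hot), p - y_one_hot

z = np.array([1.0, -0.5, 2.0])
y = np.array([0.0, 0.0, 1.0])
value, grad = value_and_grad_manual(z, y)

# Central finite differences, one coordinate at a time.
eps = 1e-6
numeric = np.array([(loss(z + eps * e, y) - loss(z - eps * e, y)) / (2 * eps)
                     for e in np.eye(len(z))])
print(np.allclose(grad, numeric, atol=1e-6))  # True
```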

After calculating gradients, we have updated the model weights using the **JAX** utility function **tree_map()**, which applies a function to every leaf of a tree data structure. The weights of our CNN are stored in a nested dictionary-like object, so **tree_map()** lets us update every weight array in one call. It takes as input the update function, the current parameters, and the gradients, and calls the update function on each matching pair of parameter and gradient arrays. The update function subtracts the learning rate times the gradients from the parameters. This process of updating weights is referred to as gradient descent. In our case, it is stochastic gradient descent, the variant of gradient descent that works with batches of data rather than the whole dataset at once.

We are also recording and printing the average loss for each epoch. We can see from the printed loss values that our model seems to be doing a good job.

```
def UpdateWeights(weights,gradients):
    return weights - learning_rate * gradients
```
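`UpdateWeights` is applied leaf by leaf to the nested parameter dictionary. The NumPy sketch below imitates that for a two-level `{layer: {"w": ..., "b": ...}}` dictionary. `tree_map_two_level` is a hypothetical, simplified helper and the layer name is made up; the real `jax.tree_util.tree_map` handles arbitrary nesting and other container types.

```python
import numpy as np

learning_rate = 1e-4

def update_weights(weights, gradients):
    """One SGD step for a single parameter array."""
    return weights - learning_rate * gradients

def tree_map_two_level(fn, params, grads):
    """Apply fn to each (param, grad) leaf pair of a {layer: {name: array}} dict."""
    return {layer: {name: fn(params[layer][name], grads[layer][name])
                    for name in params[layer]}
            for layer in params}

params = {"linear": {"w": np.ones((4, 2)), "b": np.zeros(2)}}
grads = {"linear": {"w": np.full((4, 2), 10.0), "b": np.ones(2)}}
new_params = tree_map_two_level(update_weights, params, grads)
print(new_params["linear"]["w"].shape, new_params["linear"]["b"].shape)  # (4, 2) (2,)
```

The original `params` dictionary is left untouched; a new dictionary is returned, which matches the functional style JAX expects.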

```
from jax import value_and_grad

rng = jax.random.PRNGKey(42) ## Reproducibility ## Initializes model with same weights each time.
conv_net = hk.transform(ConvNet)
params = conv_net.init(rng, X_train[:5])
epochs = 25
batch_size = 256
learning_rate = jnp.array(1/1e4)  ## 0.0001

for i in range(1, epochs+1):
    batches = jnp.arange((X_train.shape[0]//batch_size)+1) ### Batch Indices
    losses = [] ## Record loss of each batch
    for batch in batches:
        if batch != batches[-1]:
            start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
        else:
            start, end = int(batch*batch_size), None
        X_batch, Y_batch = X_train[start:end], Y_train[start:end] ## Single batch of data
        loss, param_grads = value_and_grad(CrossEntropyLoss)(params, X_batch, Y_batch) ## Loss and gradients w.r.t. params
        params = jax.tree_util.tree_map(UpdateWeights, params, param_grads) ## SGD update (jax.tree_util.tree_map works across JAX versions)
        losses.append(loss) ## Record Loss
    print("CrossEntropy Loss : {:.3f}".format(jnp.array(losses).mean()))
```
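The batch-index formula in the loop can be checked with plain Python arithmetic. With 60,000 training images and a batch size of 256 it yields 234 full batches plus a final batch of 96 images. One caveat: when the number of samples is an exact multiple of the batch size, the same formula produces an empty last batch. The sketch also shows an alternative based on `range()` that never does; both helper names are hypothetical.

```python
def original_bounds(n_samples, batch_size):
    """(start, end) pairs produced by the training loop's formula."""
    n_batches = n_samples // batch_size + 1
    return [(b * batch_size, b * batch_size + batch_size) if b != n_batches - 1
            else (b * batch_size, n_samples)
            for b in range(n_batches)]

def range_bounds(n_samples, batch_size):
    """Alternative: step through start offsets with range(); no empty batches."""
    return [(start, min(start + batch_size, n_samples))
            for start in range(0, n_samples, batch_size)]

bounds = original_bounds(60000, 256)
sizes = [end - start for start, end in bounds]
print(len(bounds), sizes[0], sizes[-1], sum(sizes))  # 235 256 96 60000
print(original_bounds(512, 256)[-1])                 # (512, 512)
```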

In this section, we are using the updated model parameters to make predictions on the train and test datasets. We have designed a small function that takes the CNN parameters, input data, and batch size as input. It loops through the data in batches, makes predictions for each batch, and returns a list of per-batch predictions. We have combined the predictions of all batches using the **concatenate()** function. As the output of our neural network is probabilities, we need to convert them to predicted target classes. Our predictions have shape **(n_samples,10)**, that is, 10 probabilities for each data sample. We take the index of the highest of these 10 probabilities as the predicted class, which we do for all data samples at once using the **argmax()** method of **JAX** arrays.

```
def MakePredictions(weights, input_data, batch_size=32):
    batches = jnp.arange((input_data.shape[0]//batch_size)+1) ### Batch Indices
    preds = []
    for batch in batches:
        if batch != batches[-1]:
            start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
        else:
            start, end = int(batch*batch_size), None
        X_batch = input_data[start:end]
        preds.append(conv_net.apply(weights, rng, X_batch))
    return preds
```

```
train_preds = MakePredictions(params, X_train, 256)
train_preds = jnp.concatenate(train_preds).squeeze()
train_preds = train_preds.argmax(axis=1)
test_preds = MakePredictions(params, X_test, 256)
test_preds = jnp.concatenate(test_preds).squeeze()
test_preds = test_preds.argmax(axis=1)
```
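`MakePredictions` returns a list of per-batch probability arrays; concatenating them and taking `argmax` along axis 1 gives one class index per sample. A NumPy sketch with made-up probabilities for 3 samples and 3 classes:

```python
import numpy as np

probs = np.array([[0.1, 0.7, 0.2],
                  [0.5, 0.3, 0.2],
                  [0.05, 0.05, 0.9]])
batches = [probs[:2], probs[2:]]         # list of per-batch predictions
all_probs = np.concatenate(batches)      # back to shape (n_samples, n_classes)
pred_classes = all_probs.argmax(axis=1)  # index of the highest probability per row
print(pred_classes)  # [1 0 2]
```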

In this section, we have evaluated the performance of our CNN by calculating the accuracy of our train and test predictions. We have also calculated a classification report on the test data, which includes precision, recall, and f1-score for each target class. We have used functions available from scikit-learn to calculate the accuracy and the classification report.

If you want to learn about the functions scikit-learn provides for calculating ML metrics, please check our separate tutorial that covers most of them in detail.

```
from sklearn.metrics import accuracy_score
print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, train_preds)))
print("Test Accuracy : {:.3f}".format(accuracy_score(Y_test, test_preds)))
```

```
from sklearn.metrics import classification_report
print("Test Classification Report ")
print(classification_report(Y_test, test_preds))
```
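Accuracy is the fraction of samples whose predicted class equals the true class, and the per-class precision and recall in the classification report are ratios of the same kind of counts. A NumPy sketch with made-up labels (not results from our model):

```python
import numpy as np

y_true = np.array([0, 1, 2, 2, 1])
y_pred = np.array([0, 2, 2, 2, 1])

accuracy = np.mean(y_true == y_pred)                   # 4 of 5 correct
print(accuracy)  # 0.8

cls = 2
tp = np.sum((y_pred == cls) & (y_true == cls))         # true positives for class 2
precision = tp / np.sum(y_pred == cls)                 # of predicted 2s, how many are right
recall = tp / np.sum(y_true == cls)                    # of actual 2s, how many were found
```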

In this section, we have initialized our CNN afresh and trained it again, this time using the **Adam** optimizer available from the **optax** library. The training code is almost the same as our previous training code with a few changes: the model weights are now maintained and updated by the **Adam** optimizer from **optax**.

We have initialized the **Adam** optimizer with the learning rate. The optimizer object has two important methods, **init()** and **update()**. The **init()** method takes model parameters as input and returns the initial optimizer state; for **Adam**, this state holds a step count and running estimates of the first and second moments of the gradients, with the same structure as the parameters. The **update()** method takes the gradients and the current optimizer state as input and returns the updates to be applied to the weights together with the new optimizer state. We then update the model weights using the **apply_updates()** function of **optax** by giving it the model parameters and the updates; it returns the updated weights. We have included training with **Adam** to show how **optax** can be used with **Haiku**. We can see from the printed loss values that our model seems to be doing a good job.

```
from jax import value_and_grad

rng = jax.random.PRNGKey(42) ## Reproducibility ## Initializes model with same weights each time.
epochs = 25
batch_size = 256
learning_rate = jnp.array(1/1e4)
conv_net = hk.transform(ConvNet)
params = conv_net.init(rng, X_train[:5])
optimizer = optax.adam(learning_rate=learning_rate) ## Initialize Adam Optimizer
optimizer_state = optimizer.init(params)

for i in range(1, epochs+1):
    batches = jnp.arange((X_train.shape[0]//batch_size)+1) ### Batch Indices
    losses = [] ## Record loss of each batch
    for batch in batches:
        if batch != batches[-1]:
            start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
        else:
            start, end = int(batch*batch_size), None
        X_batch, Y_batch = X_train[start:end], Y_train[start:end] ## Single batch of data
        loss, param_grads = value_and_grad(CrossEntropyLoss)(params, X_batch, Y_batch) ## Forward pass, loss and grads calculation
        updates, optimizer_state = optimizer.update(param_grads, optimizer_state) ## Calculate parameter updates
        params = optax.apply_updates(params, updates) ## Update model weights
        losses.append(loss) ## Record Loss
    print("CrossEntropy Loss : {:.2f}".format(jnp.array(losses).mean()))
```
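To make the Adam optimizer state more concrete, below is a NumPy sketch of a single Adam step on a small parameter vector. It follows the standard Adam update with bias-corrected first and second moment estimates; the defaults `b1=0.9`, `b2=0.999` and `eps=1e-8` are, to our knowledge, also the defaults of `optax.adam`. `adam_init` and `adam_update` are hypothetical names, and this is not optax's implementation, which operates on whole parameter trees.

```python
import numpy as np

def adam_init(params):
    """State: a step counter plus two moment estimates shaped like params."""
    return {"count": 0, "mu": np.zeros_like(params), "nu": np.zeros_like(params)}

def adam_update(grads, state, lr=1e-4, b1=0.9, b2=0.999, eps=1e-8):
    """Return (updates, new_state) for one Adam step."""
    count = state["count"] + 1
    mu = b1 * state["mu"] + (1 - b1) * grads          # first moment (mean of grads)
    nu = b2 * state["nu"] + (1 - b2) * grads ** 2     # second moment (mean of squared grads)
    mu_hat = mu / (1 - b1 ** count)                   # bias correction
    nu_hat = nu / (1 - b2 ** count)
    updates = -lr * mu_hat / (np.sqrt(nu_hat) + eps)
    return updates, {"count": count, "mu": mu, "nu": nu}

params = np.array([0.5, -0.3, 1.2])
grads = np.array([2.0, -0.01, 0.0])
state = adam_init(params)
updates, state = adam_update(grads, state)
params = params + updates   # the analogue of optax.apply_updates
```

On the first step every non-zero gradient moves its parameter by roughly the learning rate, whatever the gradient's magnitude, which is one visible difference from the plain SGD update used earlier.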

In this section, we are making predictions on our train and test datasets using the function that we had defined earlier for making predictions. We are making predictions using updated model weights.

```
train_preds = MakePredictions(params, X_train, 256)
train_preds = jnp.concatenate(train_preds).squeeze()
train_preds = train_preds.argmax(axis=1)
test_preds = MakePredictions(params, X_test, 256)
test_preds = jnp.concatenate(test_preds).squeeze()
test_preds = test_preds.argmax(axis=1)
```

In this section, we are evaluating the performance of our trained CNN by calculating the accuracy of train and test predictions. We have also calculated the classification report on test predictions. We can notice from the performance metrics results that our CNN with **Adam** has almost the same result as with **SGD**.

```
from sklearn.metrics import accuracy_score
print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, train_preds)))
print("Test Accuracy : {:.3f}".format(accuracy_score(Y_test, test_preds)))
```

```
from sklearn.metrics import classification_report
print("Test Classification Report ")
print(classification_report(Y_test, test_preds))
```

In our example above, we have used grayscale images and introduced the channel dimension at the end when loading the dataset. There are generally two ways to represent an image as a multi-dimensional tensor.

- **Channels First** - Here, we represent a color image of **(28,28)** pixels using a tensor of shape **(3,28,28)**.
- **Channels Last** - Here, we represent a color image of **(28,28)** pixels using a tensor of shape **(28,28,3)**.

By default, the **Conv2D** layer of **Haiku** expects the channels-last format. But we may face situations where the data is in the channels-first format. To work with such data, the **Conv2D** layer of **Haiku** lets us specify the data format using the **data_format** parameter. The default value of that parameter is **NHWC**.

- **N** - Number of data samples
- **H** - Height of image
- **W** - Width of image
- **C** - Number of channels

If our data has channels first data format then we need to provide data format as **NCHW** using **data_format** parameter.
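Instead of setting **data_format**, another option is to convert the data itself to channels-last before feeding it to layers that use the default **NHWC** format. A NumPy sketch with random data; `jnp.transpose` works the same way on **JAX** arrays.

```python
import numpy as np

rng = np.random.default_rng(0)
x_nchw = rng.random((50, 1, 28, 28)).astype(np.float32)  # channels first (N, C, H, W)

x_nhwc = np.transpose(x_nchw, (0, 2, 3, 1))   # move C to the end -> (N, H, W, C)
back = np.transpose(x_nhwc, (0, 3, 1, 2))     # move C back to axis 1 -> (N, C, H, W)

print(x_nhwc.shape, back.shape)  # (50, 28, 28, 1) (50, 1, 28, 28)
```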

Below, we have explained with examples how to handle both data formats. We have also shown the problems that arise when the data format is not handled properly.

```
def Conv2DFunc1(x):
    conv2d = hk.Conv2D(16, (3,3), padding="SAME")
    return conv2d(x)

def Conv2DFunc2(x):
    conv2d = hk.Conv2D(32, (3,3), padding="SAME")
    return conv2d(x)

conv2d1 = hk.transform(Conv2DFunc1)
conv2d2 = hk.transform(Conv2DFunc2)
```

```
rng = jax.random.PRNGKey(0)
params1 = conv2d1.init(rng, jax.random.uniform(rng, (50,28,28,1))) ### Channels Last
preds1 = conv2d1.apply(params1, rng, jax.random.uniform(rng, (50,28,28,1)))
params2 = conv2d2.init(rng, jax.random.uniform(rng, preds1.shape)) ### Channels Last
preds2 = conv2d2.apply(params2, rng, jax.random.uniform(rng, preds1.shape))
print("Weights of First Conv Layer : {}".format(params1["conv2_d"]["w"].shape))
print("Weights of Second Conv Layer : {}".format(params2["conv2_d"]["w"].shape))
print("\nInput Shape : {}".format((50,28,28,1)))
print("Conv Layer 1 Output Shape : {}".format(preds1.shape))
print("Conv Layer 2 Output Shape : {}".format(preds2.shape))
```

```
rng = jax.random.PRNGKey(0)
params1 = conv2d1.init(rng, jax.random.uniform(rng, (50,1,28,28))) ### Channels First
preds1 = conv2d1.apply(params1, rng, jax.random.uniform(rng, (50,1,28,28)))
params2 = conv2d2.init(rng, jax.random.uniform(rng, preds1.shape)) ### Channels Last
preds2 = conv2d2.apply(params2, rng, jax.random.uniform(rng, preds1.shape))
print("Weights of First Conv Layer : {}".format(params1["conv2_d"]["w"].shape))
print("Weights of Second Conv Layer : {}".format(params2["conv2_d"]["w"].shape))
print("\nInput Shape : {}".format((50,1,28,28)))
print("Conv Layer 1 Output Shape : {}".format(preds1.shape))
print("Conv Layer 2 Output Shape : {}".format(preds2.shape))
```
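The shapes printed by the code above follow from how a layout string names the axes. With the default **NHWC** format, a **(50,1,28,28)** channels-first batch is read as height 1, width 28 and 28 channels, so we expect the first layer's kernel to expect 28 input channels and its output to keep a height of 1. The plain-Python sketch below reproduces that bookkeeping; `interpret()` is a hypothetical helper, and the expected kernel shape assumes Haiku's `(kh, kw, in, out)` kernel layout.

```python
def interpret(shape, data_format):
    """Hypothetical helper: name each axis of a 4-D shape using a layout string."""
    return dict(zip(data_format, shape))

channels_first_batch = (50, 1, 28, 28)

# Default layout: the batch is misread as H=1, W=28, C=28.
as_nhwc = interpret(channels_first_batch, "NHWC")
print(as_nhwc)  # {'N': 50, 'H': 1, 'W': 28, 'C': 28}
kernel_shape = (3, 3, as_nhwc["C"], 16)                     # kernel expects 28 input channels
out_nhwc = (as_nhwc["N"], as_nhwc["H"], as_nhwc["W"], 16)   # "SAME" padding, stride 1

# Correct layout for this data.
as_nchw = interpret(channels_first_batch, "NCHW")
out_nchw = (as_nchw["N"], 16, as_nchw["H"], as_nchw["W"])

print(kernel_shape, out_nhwc, out_nchw)  # (3, 3, 28, 16) (50, 1, 28, 16) (50, 16, 28, 28)
```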

```
def Conv2DFunc1(x):
    conv2d = hk.Conv2D(16, (3,3), padding="SAME", data_format="NCHW") # NHWC or NCHW
    return conv2d(x)

def Conv2DFunc2(x):
    conv2d = hk.Conv2D(32, (3,3), padding="SAME", data_format="NCHW")
    return conv2d(x)

conv2d1 = hk.transform(Conv2DFunc1)
conv2d2 = hk.transform(Conv2DFunc2)
```

```
rng = jax.random.PRNGKey(0)
params1 = conv2d1.init(rng, jax.random.uniform(rng, (50,1,28,28))) ### Channels First
preds1 = conv2d1.apply(params1, rng, jax.random.uniform(rng, (50,1,28,28)))
params2 = conv2d2.init(rng, jax.random.uniform(rng, preds1.shape)) ### Channels First
preds2 = conv2d2.apply(params2, rng, jax.random.uniform(rng, preds1.shape))
print("Weights of First Conv Layer : {}".format(params1["conv2_d"]["w"].shape))
print("Weights of Second Conv Layer : {}".format(params2["conv2_d"]["w"].shape))
print("\nInput Shape : {}".format((50,1,28,28)))
print("Conv Layer 1 Output Shape : {}".format(preds1.shape))
print("Conv Layer 2 Output Shape : {}".format(preds2.shape))
```

This ends our small tutorial explaining how we can create convolutional neural networks (CNNs) using **Haiku**. Please feel free to let us know your views in the comments section.

- Haiku: Guide to Create Neural Networks
- JAX: Guide to Create Convolutional Neural Networks
- Flax: Convolutional Neural Networks (CNN)
- MXNet: Convolutional Neural Networks (CNN)
- PyTorch - Convolutional Neural Networks
- JAX - (Numpy + Automatic Gradients) on Accelerators (GPUs/TPUs)
- Sonnet: Convolutional Neural Networks (CNNs)
