Updated On: Jan 19, 2022 | Time Investment: ~45 mins

Haiku: Convolutional Neural Networks (CNNs)

Haiku is built on top of JAX to simplify machine learning research. JAX is a numerical computing library that provides functionality like a numpy-like API on accelerators (GPUs/TPUs), automatic gradients, just-in-time compilation, etc. Haiku makes the process of neural network development easier. It lets users define a neural network using OOP concepts and then translates those classes into JAX-based functions. We have already covered how to create simple neural networks with Haiku in a separate tutorial. We recommend that the reader goes through it (linked below), as it provides useful background for this one.

As a part of this tutorial, we'll create convolutional neural networks (CNNs) using Haiku. We'll build a small CNN with a few convolution layers to solve a classification task on the Fashion MNIST dataset. If you want to know the theory behind CNNs and their pros/cons, please feel free to check our blog from the below link.

As Haiku is built on top of JAX, the background of JAX will help the reader with this tutorial. If you want to know about JAX then please feel free to check our small tutorial on it.

Below, we have highlighted the important sections of the tutorial to give an overview of the material covered.

Important Sections of Tutorial

  1. Simple Convolutional Neural Network
    • Load Dataset
    • Create CNN
      • Create CNN by Defining Individual Layers
      • Create CNN using Sequential API of Haiku
    • Define Loss Function
    • Train CNN (SGD)
    • Make Predictions
    • Evaluate Model Performance
    • Train CNN (Adam)
    • Make Predictions
    • Evaluate Model Performance
  2. Guide to Handle Channels First vs Channels Last

Below, we have imported the libraries that we'll be using in this tutorial. We'll be using optimizers available from the optax library in our examples.

import haiku as hk

print("Haiku Version : {}".format(hk.__version__))
Haiku Version : 0.0.5
import jax

print("JAX Version : {}".format(jax.__version__))
JAX Version : 0.2.25
import optax

print("Optax Version : {}".format(optax.__version__))
Optax Version : 0.1.0

1. Simple Convolutional Neural Network

In this section, we'll explain the step-by-step process of creating and training a CNN using Haiku. We'll create a small convolutional neural network with 2 convolution layers for a simple classification task involving the Fashion MNIST dataset.

Load Dataset

In this section, we have loaded the Fashion MNIST dataset available from Keras that we'll be using for our purpose. It has grayscale images (28 x 28) of 10 different fashion items like boots, shirts, pants, etc. The dataset is already divided into train (60k images) and test (10k images) sets by Keras. After loading the datasets, we have converted them from numpy arrays to JAX arrays. Then, we have reshaped the dataset to introduce one extra dimension at the end. This extra dimension is generally referred to as a channel in computer vision. Color (RGB) images have 3 channels (one each for red, green, and blue), whereas grayscale images have a single channel that is usually not present as an explicit dimension, hence we introduce it. Convolution layers operate on the channels of the input data and transform them, so the channel dimension must be present in the data arrays. Finally, we have divided both train and test images by the float value 255. This brings the image values into the range [0,1], which helps optimization algorithms like gradient descent converge faster.

from jax import numpy as jnp
from tensorflow import keras
from sklearn.model_selection import train_test_split

(X_train, Y_train), (X_test, Y_test) = keras.datasets.fashion_mnist.load_data()

X_train, X_test, Y_train, Y_test = jnp.array(X_train, dtype=jnp.float32),\
                                   jnp.array(X_test, dtype=jnp.float32),\
                                   jnp.array(Y_train, dtype=jnp.float32),\
                                   jnp.array(Y_test, dtype=jnp.float32)

X_train, X_test = X_train.reshape(-1,28,28,1), X_test.reshape(-1,28,28,1)

X_train, X_test = X_train/255.0, X_test/255.0

classes =  jnp.unique(Y_train)

X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
32768/29515 [=================================] - 0s 0us/step
40960/29515 [=========================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26427392/26421880 [==============================] - 0s 0us/step
26435584/26421880 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
16384/5148 [===============================================================================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4423680/4422102 [==============================] - 0s 0us/step
4431872/4422102 [==============================] - 0s 0us/step
((60000, 28, 28, 1), (10000, 28, 28, 1), (60000,), (10000,))

Create CNN

In this section, we have created a CNN that we'll be using for our multi-class classification task. We have explained two different ways of creating neural networks.

Create CNN by Defining Individual Layers

In this section, we have created a class for our CNN by extending the hk.Module class of Haiku. We need to implement two methods when defining a neural network by extending hk.Module.

  1. __init__() - In this method, we define the layers of our model (or the whole model).
  2. __call__() - In this method, we perform a forward pass through the network using the layers defined in __init__() and return the predictions.

In our case below, we have defined two convolution layers, a flattening layer, and one dense/linear layer in the __init__() method. The first convolution layer has 32 output channels and a kernel size of (3,3). The second convolution layer has 16 output channels and a kernel size of (3,3). Both convolution layers have padding set to 'SAME', which ensures that the output height & width of the image are the same as the input; padding is added to maintain the dimensions of the images. The linear layer has 10 output units, which is the same as the number of output classes.

The __call__() method takes a batch of data as input. It applies the two convolution layers, flattens the output of the second convolution layer, and applies the linear layer to the result. We have applied the relu (rectified linear unit) activation function to the output of both convolution layers and the softmax activation function to the output of the last layer before returning it. The softmax function maps the 10 values of each sample to probabilities in the range [0,1] that sum to 1. We'll predict the target class as the one with the highest probability among the 10 returned values per sample.
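
Below is a minimal sketch (our own illustrative example, not part of the original tutorial code) showing how jax.nn.softmax maps raw scores to probabilities that sum to 1, and how argmax() picks the predicted class. The logits values are made up for illustration.

import jax
from jax import numpy as jnp

logits = jnp.array([2.0, 1.0, 0.1])   ## Hypothetical raw outputs of the last layer for one sample
probs = jax.nn.softmax(logits)        ## Map raw scores to probabilities in range [0,1]

print(probs)           ## approximately [0.66, 0.24, 0.10]
print(probs.sum())     ## 1.0
print(probs.argmax())  ## 0 -> index of the highest probability, i.e. the predicted class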

Our input shape to the CNN is (n_samples,28,28,1). The first convolution layer will translate this shape from (n_samples,28,28,1) to (n_samples,28,28,32). The second convolution layer will translate the shape from (n_samples,28,28,32) to (n_samples,28,28,16). The flatten layer will translate the data from shape (n_samples,28,28,16) to (n_samples, 28 x 28 x 16) = (n_samples,12544). And the final linear layer will translate the data from shape (n_samples,12544) to (n_samples,10).
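
As a quick sanity check of the flattened size mentioned above (an illustrative calculation, not part of the original code), the number of features after flattening is simply height x width x channels.

print(28 * 28 * 16)   ## 12544 features per sample after flattening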

The JAX framework is designed to work with functions, hence we need to translate our CNN class into a function. To do that, we have created a function that takes input data, initializes the CNN, applies it to the input data, and returns predictions. We then transform this function using Haiku's transform() function, which returns a Transformed object that has two important methods.

  1. init(seed, input_data) - This method is called on the Transformed object. It takes a JAX pseudo-random key and a few data samples as input and returns a dictionary-like object holding the model parameters (weights & biases) for each layer.
  2. apply(params, seed, input_data) - This method is also called on the Transformed object. It takes the model parameters, a pseudo-random key, and input data, performs a forward pass through the network using the given weights, and returns predictions.

After transforming the network using the transform() function, we have initialized the model parameters using the init() method and printed their shapes. We have then called the apply() method to perform a forward pass through the network using a few data samples and make predictions. We can notice from the output that the model seems to be working as expected.

class CNN(hk.Module):
    def __init__(self):
        super().__init__(name="CNN")
        self.conv1 = hk.Conv2D(output_channels=32, kernel_shape=(3,3), padding="SAME")
        self.conv2 = hk.Conv2D(output_channels=16, kernel_shape=(3,3), padding="SAME")
        self.flatten = hk.Flatten()
        self.linear = hk.Linear(len(classes))

    def __call__(self, x_batch):
        x = self.conv1(x_batch)
        x = jax.nn.relu(x)
        x = self.conv2(x)
        x = jax.nn.relu(x)
        x = self.flatten(x)
        x = self.linear(x)
        x = jax.nn.softmax(x)
        return x
def ConvNet(x):
    cnn = CNN()
    return cnn(x)

conv_net = hk.transform(ConvNet)
rng = jax.random.PRNGKey(42)

params = conv_net.init(rng, X_train[:5])

print("Weights Type : {}\n".format(type(params)))

for layer_name, weights in params.items():
    print(layer_name)
    print("Weights : {}, Biases : {}\n".format(params[layer_name]["w"].shape,params[layer_name]["b"].shape))
Weights Type : <class 'haiku._src.data_structures.FlatMap'>

CNN/~/conv2_d
Weights : (3, 3, 1, 32), Biases : (32,)

CNN/~/conv2_d_1
Weights : (3, 3, 32, 16), Biases : (16,)

CNN/~/linear
Weights : (12544, 10), Biases : (10,)

preds = conv_net.apply(params, rng, X_train[:5])

preds[:5]
DeviceArray([[0.11908672, 0.08093624, 0.11933971, 0.08424285, 0.09047501,
              0.08459727, 0.10602134, 0.1412306 , 0.08764312, 0.08642722],
             [0.0948781 , 0.11035708, 0.12414759, 0.07415   , 0.08888784,
              0.07233351, 0.10289562, 0.1523459 , 0.08882214, 0.09118223],
             [0.10695767, 0.0984986 , 0.11776689, 0.07959657, 0.10304435,
              0.09098944, 0.09596974, 0.11778899, 0.08268595, 0.10670183],
             [0.10329415, 0.10112754, 0.11333219, 0.08018617, 0.1020422 ,
              0.08837533, 0.10473682, 0.11802705, 0.08745508, 0.10142355],
             [0.10525166, 0.10007066, 0.11870965, 0.06995527, 0.09680063,
              0.08000981, 0.10547846, 0.14934564, 0.06930337, 0.10507484]],            dtype=float32)

Create CNN using Sequential API of Haiku

In this section, we have explained a second way of creating a CNN using Haiku. We have created a class like last time by extending the hk.Module class. Inside the __init__() method, we have defined the model by stacking layers inside the hk.Sequential() class. This returns a Sequential instance which applies the layers to the input data in the order in which they are given. We have used exactly the same layers with the same configuration as in our previous example. In the __call__() method, we call the Sequential object to perform a forward pass on the input data. This way of defining a CNN is almost the same as in Keras.

After defining the CNN, we have transformed it using the transform() function like earlier, initialized the model parameters using the init() method, and printed their shapes. We have also performed a forward pass through the network using the apply() method with a few samples for verification purposes.

class CNN(hk.Module):
    def __init__(self):
        super().__init__(name="CNN")
        self.conv_model = hk.Sequential([
                                        hk.Conv2D(output_channels=32, kernel_shape=(3,3), padding="SAME"),
                                        jax.nn.relu,

                                        hk.Conv2D(output_channels=16, kernel_shape=(3,3), padding="SAME"),
                                        jax.nn.relu,

                                        hk.Flatten(),
                                        hk.Linear(len(classes)),
                                        jax.nn.softmax
                                    ])

    def __call__(self, x_batch):
        return self.conv_model(x_batch)
def ConvNet(x):
    cnn = CNN()
    return cnn(x)

conv_net = hk.transform(ConvNet)
rng = jax.random.PRNGKey(42)

params = conv_net.init(rng, X_train[:5])

print("Weights Type : {}\n".format(type(params)))

for layer_name, weights in params.items():
    print(layer_name)
    print("Weights : {}, Biases : {}\n".format(params[layer_name]["w"].shape,params[layer_name]["b"].shape))
Weights Type : <class 'haiku._src.data_structures.FlatMap'>

CNN/~/conv2_d
Weights : (3, 3, 1, 32), Biases : (32,)

CNN/~/conv2_d_1
Weights : (3, 3, 32, 16), Biases : (16,)

CNN/~/linear
Weights : (12544, 10), Biases : (10,)

preds = conv_net.apply(params, rng, X_train[:5])

preds[:5]
DeviceArray([[0.11908672, 0.08093624, 0.11933971, 0.08424285, 0.09047501,
              0.08459727, 0.10602134, 0.1412306 , 0.08764312, 0.08642722],
             [0.0948781 , 0.11035708, 0.12414759, 0.07415   , 0.08888784,
              0.07233351, 0.10289562, 0.1523459 , 0.08882214, 0.09118223],
             [0.10695767, 0.0984986 , 0.11776689, 0.07959657, 0.10304435,
              0.09098944, 0.09596974, 0.11778899, 0.08268595, 0.10670183],
             [0.10329415, 0.10112754, 0.11333219, 0.08018617, 0.1020422 ,
              0.08837533, 0.10473682, 0.11802705, 0.08745508, 0.10142355],
             [0.10525166, 0.10007066, 0.11870965, 0.06995527, 0.09680063,
              0.08000981, 0.10547846, 0.14934564, 0.06930337, 0.10507484]],            dtype=float32)

Define Loss Function

In this section, we have defined the loss function for our multi-class classification task. We'll be using cross entropy loss. The loss function takes model parameters, input data, and actual target values as input. It makes predictions on the input data with the apply() method using the model parameters, one-hot encodes the target values, and takes the log of the predictions. We then multiply the log of the predictions with the one-hot encoded target values, sum the result, and return the negated sum. Later on, when we calculate gradients of this loss function, they will be taken with respect to its first argument (the model parameters).

def CrossEntropyLoss(weights, input_data, actual):
    preds = conv_net.apply(weights, rng, input_data)
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(classes))
    log_preds = jnp.log(preds)
    return - jnp.sum(one_hot_actual * log_preds)

Train CNN (SGD)

In this section, we are training our CNN. To train it, we have initialized the number of epochs to 25, the batch size to 256, and the learning rate to 0.0001. We have also initialized the model and its parameters. After initialization, we execute the training loop for the specified number of epochs.

For each epoch, we calculate the start and end indexes of the batches of data and loop through the data in batches using these indexes. To perform a forward pass through the network and calculate gradients of the loss, we use the value_and_grad() function available from JAX. This function takes a function as input and returns another function that calculates the gradient of the input function with respect to its first parameter. In our case, we have given our loss function to value_and_grad(), so the returned function calculates gradients of the loss with respect to its first parameter, which is the model weights. The returned function is called with the same arguments as the input function and returns two values: the first is the actual value of the function for the given inputs and the second is the gradients with respect to the first parameter.
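
Below is a minimal sketch (our own toy example, not part of the original tutorial code) of how value_and_grad() behaves on a simple function before we use it on our loss. The function square() is a hypothetical stand-in for the loss function.

from jax import value_and_grad

def square(x):
    return x ** 2

value, grad = value_and_grad(square)(3.0)
print(value)   ## 9.0 -> value of the function at x=3
print(grad)    ## 6.0 -> gradient of x**2 with respect to x at x=3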

After calculating gradients, we update the model weights using the JAX utility function tree_map(), which applies an input function to each leaf of a tree data structure (here, the dictionary-like object holding the CNN weights). It takes as input the update function, the actual parameters, and the gradients, and updates the parameters by subtracting the learning rate times the gradients from them. This process of updating weights by subtracting the learning rate times the gradients is referred to as gradient descent; in our case it is stochastic gradient descent, the version of gradient descent that works with batches of data rather than the whole dataset at once.
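
Below is a small illustrative sketch (toy parameter values made up for the example) of jax.tree_map() applying the same update rule to every leaf of a dictionary of parameters, which is what happens to our CNN weights in the training loop.

import jax
from jax import numpy as jnp

toy_params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array([0.5])}  ## Hypothetical parameters
toy_grads  = {"w": jnp.array([0.1, 0.2]), "b": jnp.array([0.3])}  ## Hypothetical gradients
lr = 0.1

updated = jax.tree_map(lambda p, g: p - lr * g, toy_params, toy_grads)
print(updated)   ## {'b': [0.47], 'w': [0.99, 1.98]}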

We are also recording and printing loss for each epoch. We can notice from the loss getting printed that our model seems to be doing a good job.

def UpdateWeights(weights,gradients):
    return weights - learning_rate * gradients
from jax import value_and_grad

rng = jax.random.PRNGKey(42) ## Reproducibility ## Initializes model with same weights each time.

conv_net = hk.transform(ConvNet)
params = conv_net.init(rng, X_train[:5])
epochs = 25
batch_size = 256
learning_rate = jnp.array(1/1e4)

for i in range(1, epochs+1):
    batches = jnp.arange((X_train.shape[0]//batch_size)+1) ### Batch Indices

    losses = [] ## Record loss of each batch
    for batch in batches:
        if batch != batches[-1]:
            start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
        else:
            start, end = int(batch*batch_size), None

        X_batch, Y_batch = X_train[start:end], Y_train[start:end] ## Single batch of data

        loss, param_grads = value_and_grad(CrossEntropyLoss)(params, X_batch, Y_batch)
        #print(param_grads)
        params = jax.tree_map(UpdateWeights, params, param_grads) ## Update Params
        losses.append(loss) ## Record Loss

    print("CrossEntropy Loss : {:.3f}".format(jnp.array(losses).mean()))
CrossEntropy Loss : 198.377
CrossEntropy Loss : 121.871
CrossEntropy Loss : 107.817
CrossEntropy Loss : 100.251
CrossEntropy Loss : 95.148
CrossEntropy Loss : 91.252
CrossEntropy Loss : 88.042
CrossEntropy Loss : 85.257
CrossEntropy Loss : 82.777
CrossEntropy Loss : 80.551
CrossEntropy Loss : 78.525
CrossEntropy Loss : 76.664
CrossEntropy Loss : 74.953
CrossEntropy Loss : 73.392
CrossEntropy Loss : 71.941
CrossEntropy Loss : 70.602
CrossEntropy Loss : 69.341
CrossEntropy Loss : 68.184
CrossEntropy Loss : 67.075
CrossEntropy Loss : 66.021
CrossEntropy Loss : 65.024
CrossEntropy Loss : 64.059
CrossEntropy Loss : 63.149
CrossEntropy Loss : 62.265
CrossEntropy Loss : 61.409

Make Predictions

In this section, we are using the updated model parameters to make predictions on the train and test datasets. We have designed a small function that takes the updated CNN parameters, input data, and batch size as input, loops through the data in batches, makes predictions, and returns the predictions for all batches. We have combined the predictions of all batches using the concatenate() method. As the output of our neural network is probabilities, we need to convert these probabilities to predicted target classes. Our predictions are of shape (n_samples,10), i.e. 10 probabilities per data sample. For each sample, we take the index of the highest of these 10 probabilities as the predicted class using the argmax() method.

def MakePredictions(weights, input_data, batch_size=32):
    batches = jnp.arange((input_data.shape[0]//batch_size)+1) ### Batch Indices

    preds = []
    for batch in batches:
        if batch != batches[-1]:
            start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
        else:
            start, end = int(batch*batch_size), None

        X_batch = input_data[start:end]

        preds.append(conv_net.apply(weights, rng, X_batch))

    return preds
train_preds = MakePredictions(params, X_train, 256)
train_preds = jnp.concatenate(train_preds).squeeze()
train_preds = train_preds.argmax(axis=1)

test_preds = MakePredictions(params, X_test, 256)
test_preds = jnp.concatenate(test_preds).squeeze()
test_preds = test_preds.argmax(axis=1)

Evaluate Model Performance

In this section, we have evaluated the performance of our CNN by calculating the accuracy of our train and test predictions. We have also calculated a classification report on the test data, which has information like precision, recall, and f1-score for each target class. We have used functions available from scikit-learn to calculate the accuracy and classification report.

If you want to learn about the functions available in scikit-learn for calculating ML metrics, please check the below link, which covers the majority of them in detail.

from sklearn.metrics import accuracy_score

print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, train_preds)))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, test_preds)))
Train Accuracy : 0.917
Test  Accuracy : 0.891
from sklearn.metrics import classification_report

print("Test Classification Report ")
print(classification_report(Y_test, test_preds))
Test Classification Report
              precision    recall  f1-score   support

         0.0       0.82      0.87      0.84      1000
         1.0       0.99      0.97      0.98      1000
         2.0       0.79      0.87      0.83      1000
         3.0       0.87      0.92      0.90      1000
         4.0       0.87      0.76      0.81      1000
         5.0       0.98      0.96      0.97      1000
         6.0       0.73      0.67      0.70      1000
         7.0       0.94      0.95      0.95      1000
         8.0       0.97      0.97      0.97      1000
         9.0       0.95      0.96      0.96      1000

    accuracy                           0.89     10000
   macro avg       0.89      0.89      0.89     10000
weighted avg       0.89      0.89      0.89     10000

Train CNN (Adam)

In this section, we have again initialized our CNN from scratch and trained it, this time using the Adam optimizer available from the optax library. The code in this section is almost the same as our previous training code, with a few changes: we maintain and update the model weights using the Adam optimizer from optax, initialized with the learning rate. The optimizer object has two important methods, init() and update(). The init() method takes the model parameters as input and returns an OptimizerState object. The update() method takes the gradients and the OptimizerState object as input and returns the updates to be applied to the weights along with a new OptimizerState object. We can then update the model weights using optax's apply_updates() function by giving it the model parameters and the updates; it returns the updated weights. We have included training with the Adam optimizer to explain how optax can be used with Haiku. We can notice from the loss values getting printed that our model seems to be doing a good job.
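
Below is a minimal sketch (toy values, not part of the original tutorial code) of the optax pattern described above: init() creates the optimizer state, update() turns gradients into parameter updates, and apply_updates() applies them to the parameters.

import optax
from jax import numpy as jnp

toy_params = {"w": jnp.array([1.0, 2.0])}   ## Hypothetical parameters
toy_grads  = {"w": jnp.array([0.5, 0.5])}   ## Hypothetical gradients

opt = optax.adam(learning_rate=0.001)
opt_state = opt.init(toy_params)                        ## Optimizer state from parameters
updates, opt_state = opt.update(toy_grads, opt_state)   ## Parameter updates from gradients
toy_params = optax.apply_updates(toy_params, updates)   ## Apply updates to parameters
print(toy_params)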

from jax import value_and_grad

rng = jax.random.PRNGKey(42) ## Reproducibility ## Initializes model with same weights each time.
epochs = 25
batch_size = 256
learning_rate = jnp.array(1/1e4)

conv_net = hk.transform(ConvNet)
params = conv_net.init(rng, X_train[:5])

optimizer = optax.adam(learning_rate=learning_rate) ## Initialize Adam Optimizer
optimizer_state = optimizer.init(params)

for i in range(1, epochs+1):
    batches = jnp.arange((X_train.shape[0]//batch_size)+1) ### Batch Indices

    losses = [] ## Record loss of each batch
    for batch in batches:
        if batch != batches[-1]:
            start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
        else:
            start, end = int(batch*batch_size), None

        X_batch, Y_batch = X_train[start:end], Y_train[start:end] ## Single batch of data

        loss, param_grads = value_and_grad(CrossEntropyLoss)(params, X_batch, Y_batch) ## Forward pass, loss and grads calculation
        #print(param_grads)
        updates, optimizer_state = optimizer.update(param_grads, optimizer_state) ## Calculate parameter updates
        params = optax.apply_updates(params, updates) ## Update model weights
        #params = jax.tree_map(UpdateWeights, params, param_grads) ## Update Params
        losses.append(loss) ## Record Loss

    print("CrossEntropy Loss : {:.2f}".format(jnp.array(losses).mean()))
CrossEntropy Loss : 196.32
CrossEntropy Loss : 114.20
CrossEntropy Loss : 102.79
CrossEntropy Loss : 97.02
CrossEntropy Loss : 93.17
CrossEntropy Loss : 90.24
CrossEntropy Loss : 87.81
CrossEntropy Loss : 85.67
CrossEntropy Loss : 83.70
CrossEntropy Loss : 81.87
CrossEntropy Loss : 80.15
CrossEntropy Loss : 78.52
CrossEntropy Loss : 76.99
CrossEntropy Loss : 75.53
CrossEntropy Loss : 74.16
CrossEntropy Loss : 72.87
CrossEntropy Loss : 71.64
CrossEntropy Loss : 70.48
CrossEntropy Loss : 69.38
CrossEntropy Loss : 68.34
CrossEntropy Loss : 67.35
CrossEntropy Loss : 66.42
CrossEntropy Loss : 65.52
CrossEntropy Loss : 64.67
CrossEntropy Loss : 63.86

Make Predictions

In this section, we are making predictions on our train and test datasets using the function that we had defined earlier for making predictions. We are making predictions using updated model weights.

train_preds = MakePredictions(params, X_train, 256)
train_preds = jnp.concatenate(train_preds).squeeze()
train_preds = train_preds.argmax(axis=1)

test_preds = MakePredictions(params, X_test, 256)
test_preds = jnp.concatenate(test_preds).squeeze()
test_preds = test_preds.argmax(axis=1)

Evaluate Model Performance

In this section, we are evaluating the performance of our trained CNN by calculating the accuracy of the train and test predictions. We have also calculated the classification report on the test predictions. We can notice from the performance metrics that our CNN trained with Adam gives almost the same results as the one trained with SGD.

from sklearn.metrics import accuracy_score

print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, train_preds)))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, test_preds)))
Train Accuracy : 0.912
Test  Accuracy : 0.891
from sklearn.metrics import classification_report

print("Test Classification Report ")
print(classification_report(Y_test, test_preds))
Test Classification Report
              precision    recall  f1-score   support

         0.0       0.80      0.89      0.84      1000
         1.0       0.99      0.97      0.98      1000
         2.0       0.82      0.83      0.83      1000
         3.0       0.87      0.92      0.89      1000
         4.0       0.80      0.86      0.83      1000
         5.0       0.96      0.98      0.97      1000
         6.0       0.79      0.60      0.68      1000
         7.0       0.96      0.94      0.95      1000
         8.0       0.98      0.97      0.97      1000
         9.0       0.95      0.96      0.96      1000

    accuracy                           0.89     10000
   macro avg       0.89      0.89      0.89     10000
weighted avg       0.89      0.89      0.89     10000

2. Guide to Handle Channels First vs Channels Last

In our example above, we used grayscale images and introduced the channel dimension at the end when loading the dataset. There are generally two ways to represent an image with a multi-dimensional tensor.

  1. Channels First - Here, we represent a color image of (28,28) pixels using a tensor of dimension (3,28,28).
  2. Channels Last - Here, we represent a color image of (28,28) pixels using a tensor of dimension (28,28,3).

By default, the Conv2D layer of Haiku expects the channels-last format. But we can face situations where the data is in channels-first format. To work with such data, the Conv2D layer of Haiku lets us specify the data format using the data_format parameter. The default value of that parameter is NHWC.

  • N - Number of data samples
  • H - Height of Image
  • W - Width of Image
  • C - Number of Channels.

If our data has the channels-first format, then we need to provide the data format as NCHW using the data_format parameter.

Below, we have explained with examples how to handle both data formats. We have also shown that not handling the data format properly can cause issues.

def Conv2DFunc1(x):
    conv2d = hk.Conv2D(16, (3,3), padding="SAME")
    return conv2d(x)

def Conv2DFunc2(x):
    conv2d = hk.Conv2D(32, (3,3), padding="SAME")
    return conv2d(x)

conv2d1 = hk.transform(Conv2DFunc1)
conv2d2 = hk.transform(Conv2DFunc2)
rng = jax.random.PRNGKey(0)

params1 = conv2d1.init(rng, jax.random.uniform(rng, (50,28,28,1))) ### Channels Last
preds1 = conv2d1.apply(params1, rng, jax.random.uniform(rng, (50,28,28,1)))

params2 = conv2d2.init(rng, jax.random.uniform(rng, preds1.shape)) ### Channels Last
preds2 = conv2d2.apply(params2, rng, jax.random.uniform(rng, preds1.shape))

print("Weights of First Conv Layer : {}".format(params1["conv2_d"]["w"].shape))
print("Weights of First Conv Layer : {}".format(params2["conv2_d"]["w"].shape))

print("\nInput Shape               : {}".format((50,28,28,1)))
print("Conv Layer 1 Output Shape : {}".format(preds1.shape))
print("Conv Layer 2 Output Shape : {}".format(preds2.shape))
Weights of First Conv Layer : (3, 3, 1, 16)
Weights of Second Conv Layer : (3, 3, 16, 32)

Input Shape               : (50, 28, 28, 1)
Conv Layer 1 Output Shape : (50, 28, 28, 16)
Conv Layer 2 Output Shape : (50, 28, 28, 32)
rng = jax.random.PRNGKey(0)

params1 = conv2d1.init(rng, jax.random.uniform(rng, (50,1,28,28))) ### Channels First
preds1 = conv2d1.apply(params1, rng, jax.random.uniform(rng, (50,1,28,28)))

params2 = conv2d2.init(rng, jax.random.uniform(rng, preds1.shape)) ### Channels Last
preds2 = conv2d2.apply(params2, rng, jax.random.uniform(rng, preds1.shape))

print("Weights of First Conv Layer : {}".format(params1["conv2_d"]["w"].shape))
print("Weights of First Conv Layer : {}".format(params2["conv2_d"]["w"].shape))

print("\nInput Shape               : {}".format((50,1,28,28)))
print("Conv Layer 1 Output Shape : {}".format(preds1.shape))
print("Conv Layer 2 Output Shape : {}".format(preds2.shape))
Weights of First Conv Layer : (3, 3, 28, 16)
Weights of Second Conv Layer : (3, 3, 16, 32)

Input Shape               : (50, 1, 28, 28)
Conv Layer 1 Output Shape : (50, 1, 28, 16)
Conv Layer 2 Output Shape : (50, 1, 28, 32)
def Conv2DFunc1(x):
    conv2d = hk.Conv2D(16, (3,3), padding="SAME", data_format="NCHW") # NHWC or NCHW
    return conv2d(x)

def Conv2DFunc2(x):
    conv2d = hk.Conv2D(32, (3,3), padding="SAME", data_format="NCHW")
    return conv2d(x)

conv2d1 = hk.transform(Conv2DFunc1)
conv2d2 = hk.transform(Conv2DFunc2)
rng = jax.random.PRNGKey(0)

params1 = conv2d1.init(rng, jax.random.uniform(rng, (50,1,28,28))) ### Channels First
preds1 = conv2d1.apply(params1, rng, jax.random.uniform(rng, (50,1,28,28)))

params2 = conv2d2.init(rng, jax.random.uniform(rng, preds1.shape)) ### Channels First
preds2 = conv2d2.apply(params2, rng, jax.random.uniform(rng, preds1.shape))

print("Weights of First Conv Layer : {}".format(params1["conv2_d"]["w"].shape))
print("Weights of First Conv Layer : {}".format(params2["conv2_d"]["w"].shape))

print("\nInput Shape               : {}".format((50,1,28,28)))
print("Conv Layer 1 Output Shape : {}".format(preds1.shape))
print("Conv Layer 2 Output Shape : {}".format(preds2.shape))
Weights of First Conv Layer : (3, 3, 1, 16)
Weights of Second Conv Layer : (3, 3, 16, 32)

Input Shape               : (50, 1, 28, 28)
Conv Layer 1 Output Shape : (50, 16, 28, 28)
Conv Layer 2 Output Shape : (50, 32, 28, 28)
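
As an alternative to setting data_format="NCHW" (our own suggestion, not covered in the tutorial code above), channels-first data can also be transposed to channels-last before feeding it to the default NHWC convolution layers.

import jax
from jax import numpy as jnp

x_nchw = jax.random.uniform(jax.random.PRNGKey(0), (50, 1, 28, 28))   ## Channels-first data
x_nhwc = jnp.transpose(x_nchw, (0, 2, 3, 1))                          ## Move channels to the last axis

print(x_nhwc.shape)   ## (50, 28, 28, 1)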