Updated On : Jan-05, 2022
Time Investment : ~30 mins

JAX: Guide to Create Convolutional Neural Networks

Convolutional Neural Networks (CNNs) are a class of neural networks that use convolution layers, which apply convolution operations to input data to detect patterns. CNNs are commonly used for tasks involving visual imagery (like object detection, image classification, image segmentation, etc.), where they generally outperform other types of neural networks. They are also used for natural language processing and time-series data. For inputs with many features, CNNs generally have far fewer parameters to train than fully connected dense networks, because each convolution kernel is small and is shared across all positions of the input. If you want to learn about convolutional neural networks in depth, please feel free to check our blog, which covers the theory in detail.

As a part of this tutorial, we'll cover how to create a simple convolutional neural network using JAX. JAX is a popular Python library for high-performance numerical computing that provides a NumPy-like API on CPUs/GPUs/TPUs, automatic differentiation, just-in-time compilation, etc. If you do not have a background in JAX, we recommend going through our tutorial that covers its basics. It'll help you with this tutorial as well.

In this tutorial, we'll design a simple convolutional neural network using the high-level stax API of JAX. We have another tutorial on stax that describes how to create simple fully connected neural networks. Going through it will help with this tutorial as well.

Below we have highlighted important sections of the tutorial.

Important Sections of Tutorial

  1. Simple Convolutional Neural Network
    • Load Fashion MNIST Dataset
    • Create Neural Network
    • Define Loss Function
    • Train Neural Network (SGD Optimizer)
    • Make Predictions
    • Evaluate Model Performance
    • Train Neural Network (Adam Optimizer)
    • Make Predictions
    • Evaluate Model Performance
  2. Guide to Handle Channels First vs Channels Last

Below, we import JAX and print the version that we'll be using in this tutorial.

import jax

print("JAX Version : {}".format(jax.__version__))
JAX Version : 0.2.25

Simple Convolutional Neural Network

In this section, we'll explain the step-by-step process of creating a convolutional neural network and training it. Below, we import the necessary submodules of JAX: stax for creating the neural network and optimizers for creating optimizers. The jax.numpy module provides NumPy-like array operations.

from jax.example_libraries import stax, optimizers

from jax import numpy as jnp

Load Fashion MNIST Dataset

In this section, we load the Fashion MNIST dataset available from keras. The dataset has 28x28 grayscale images of 10 fashion items and is already divided into train (60k samples) and test (10k samples) sets. We convert the datasets, which are loaded as numpy arrays, to JAX arrays for use with our JAX model. Then we reshape the images to shape (28,28,1), where the last extra dimension represents the channel. Convolution layers require a channel dimension; as our images are grayscale and do not have 3 channels (RGB) like color images, we introduce an extra dimension of size 1 as the channel. After reshaping, we divide both datasets by 255.0 to normalize them, because the pixel values are loaded as integers in the range [0, 255]. This maps them to the range [0, 1] and helps the optimization algorithm converge faster.

from tensorflow import keras

(X_train, Y_train), (X_test, Y_test) = keras.datasets.fashion_mnist.load_data()

X_train, X_test, Y_train, Y_test = jnp.array(X_train, dtype=jnp.float32),\
                                   jnp.array(X_test, dtype=jnp.float32),\
                                   jnp.array(Y_train, dtype=jnp.float32),\
                                   jnp.array(Y_test, dtype=jnp.float32)

X_train, X_test = X_train.reshape(-1,28,28,1), X_test.reshape(-1,28,28,1)

X_train, X_test = X_train/255.0, X_test/255.0

classes =  jnp.unique(Y_train)

X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
((60000, 28, 28, 1), (10000, 28, 28, 1), (60000,), (10000,))
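The reshape above adds a trailing channel axis of size 1. As a small sketch on a hypothetical uint8 array standing in for the Keras images, jnp.expand_dims() produces the same channels-last shape, and dividing by 255.0 maps the pixel values into [0, 1]:

```python
import numpy as np
import jax.numpy as jnp

# Hypothetical stand-in for Fashion MNIST images: 4 random 28x28 uint8 images
images = np.random.default_rng(0).integers(0, 256, size=(4, 28, 28), dtype=np.uint8)
x = jnp.array(images, dtype=jnp.float32)

via_reshape = x.reshape(-1, 28, 28, 1)    # approach used in the tutorial
via_expand = jnp.expand_dims(x, axis=-1)  # adds a new last axis (channels last)

normalized = via_reshape / 255.0          # integers in [0, 255] -> floats in [0, 1]
print(via_reshape.shape, via_expand.shape)
print(float(normalized.min()), float(normalized.max()))
```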

Create Neural Network

In this section, we create the convolutional neural network that we'll use to classify Fashion MNIST images. Our CNN is simple, with only 2 convolution layers and one linear/dense layer. The first convolution layer has 32 output channels and a kernel size of (3,3). The second convolution layer has 16 output channels and a kernel size of (3,3). Both use 'SAME' padding, which zero-pads the input so that (with the default stride of 1) the output height and width are the same as those of the input.

The first layer transforms images of shape (n_samples,height,width,1) to (n_samples,height,width,32), i.e. it maps a 1-channel image to 32 channels. The height and width in our case are both 28. After the first convolution layer, we apply ReLU (rectified linear unit) activation to its output. The second layer transforms images from shape (n_samples,height,width,32) to (n_samples,height,width,16), and we again apply ReLU activation to its output.

After the two convolution layers, we flatten the output. The new shape is (n_samples, height x width x 16) = (n_samples, 28 x 28 x 16) = (n_samples, 12544). After flattening, we add the final dense layer with as many units as there are classes (10). Finally, we apply the softmax activation function, which maps the 10 output values to the range [0,1] such that they sum to 1. We can then predict the image class as the one with the highest of those 10 values.
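As a quick check of the softmax behavior described above, here is a minimal sketch on hypothetical logits using jax.nn.softmax(), which is the function stax.Softmax applies along the last axis:

```python
import jax
import jax.numpy as jnp

# Hypothetical raw scores (logits) for 2 samples and 3 classes
logits = jnp.array([[2.0, 1.0, 0.1],
                    [0.5, 0.5, 3.0]])

probs = jax.nn.softmax(logits, axis=-1)  # every value lands in [0, 1]
print(probs.sum(axis=-1))                # each row sums to 1 (up to float rounding)
print(jnp.argmax(probs, axis=-1))        # predicted classes: [0 2]
```

Softmax is monotonic within a row, so taking argmax() of the probabilities picks the same class as taking argmax() of the raw logits.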

We create our neural network using the serial() function of stax, passing it the layers in the order in which they should be applied. The serial() function returns a pair of functions.

  1. init() - This function takes a JAX pseudo-random number generator key and the input data shape as input. It generates the model weights and returns a tuple of (output shape, weights).
  2. apply() - This function takes model weights and data as input and performs a forward pass of the data through the network using those weights.
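The two functions returned by serial() can be sketched on the same layer stack as our CNN. This minimal example (the batch size of 3 and the all-ones input are arbitrary choices) shows that init() returns the output shape alongside the weights, and that apply() produces one row of 10 probabilities per sample:

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

# Same layer stack as the tutorial's CNN
init_fn, apply_fn = stax.serial(
    stax.Conv(32, (3, 3), padding="SAME"), stax.Relu,
    stax.Conv(16, (3, 3), padding="SAME"), stax.Relu,
    stax.Flatten,
    stax.Dense(10),
    stax.Softmax,
)

# init() returns (output_shape, params); only the input shape is used, not any data
output_shape, params = init_fn(jax.random.PRNGKey(0), (3, 28, 28, 1))
print(tuple(int(d) for d in output_shape))  # (3, 10)

# apply() runs a forward pass of a batch through the network with the given params
batch = jnp.ones((3, 28, 28, 1))
out = apply_fn(params, batch)
print(out.shape)  # (3, 10)
```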

After creating the CNN, in the next cell, we initialize the weights of the neural network using the init() function and print their shapes. In the following cell, we perform a forward pass of a few data samples through the network using the apply() function and print the predictions.

conv_init, conv_apply = stax.serial(
    stax.Conv(32,(3,3), padding="SAME"),
    stax.Relu,
    stax.Conv(16, (3,3), padding="SAME"),
    stax.Relu,

    stax.Flatten,
    stax.Dense(len(classes)),
    stax.Softmax
)
rng = jax.random.PRNGKey(123)

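weights = conv_init(rng, (18,28,28,1)) ## Batch size 18 is arbitrary; init() only uses the shape

weights = weights[1] ## init() returns (output_shape, weights); keep the weights

for w in weights:
    if w: ## Layers without weights (Relu, Flatten, Softmax) hold an empty tuple
        w, b = w
        print("Weights : {}, Biases : {}".format(w.shape, b.shape))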
Weights : (3, 3, 1, 32), Biases : (1, 1, 1, 32)
Weights : (3, 3, 32, 16), Biases : (1, 1, 1, 16)
Weights : (12544, 10), Biases : (10,)
preds = conv_apply(weights, X_train[:5])

preds
DeviceArray([[0.1010583 , 0.08450613, 0.08876862, 0.10357412, 0.0942447 ,
              0.06837884, 0.13585268, 0.10559722, 0.10213769, 0.11588173],
             [0.10652138, 0.08236859, 0.11018915, 0.11165013, 0.08258919,
              0.0772452 , 0.14238864, 0.09115906, 0.08840751, 0.10748117],
             [0.09566505, 0.09545239, 0.10094573, 0.10249694, 0.09882495,
              0.08883354, 0.11344208, 0.09690957, 0.10275006, 0.10467971],
             [0.10200436, 0.08598677, 0.10633808, 0.10407417, 0.09481844,
              0.0829622 , 0.12135128, 0.093064  , 0.09924015, 0.11016052],
             [0.09390713, 0.08359886, 0.10012675, 0.11463808, 0.09753538,
              0.07206484, 0.12989207, 0.08960547, 0.10824842, 0.11038305]],            dtype=float32)

Define Loss Function

In this section, we define the loss function that we'll use: categorical cross-entropy. The function takes weights, input data, and actual target values as input. It makes predictions on the input data using the weights (the apply() function), one-hot encodes the actual target values, and takes the log of the predictions. Finally, it multiplies the log of the predictions by the one-hot encoded targets, sums the resulting array, and negates the sum. Because the loss is summed (not averaged) over the samples in a batch, its value grows with the batch size.

Later, we'll calculate the gradient of this loss function with respect to the model weights, which are the first parameter of the function.

def CrossEntropyLoss(weights, input_data, actual):
    preds = conv_apply(weights, input_data) ## Forward pass: softmax probabilities
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(classes)) ## Labels -> one-hot vectors
    log_preds = jnp.log(preds)
    return - jnp.sum(one_hot_actual * log_preds) ## Summed (not averaged) over the batch
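CrossEntropyLoss() takes jnp.log() of the softmax output, which becomes -inf when a predicted probability underflows to exactly 0. A common alternative, sketched below under the assumption that the network returned raw logits (or ended with stax.LogSoftmax instead of stax.Softmax), is to compute log-probabilities in one step with jax.nn.log_softmax(). The helpers cross_entropy_from_logits() and naive_cross_entropy() are hypothetical names for illustration:

```python
import jax
import jax.numpy as jnp

def cross_entropy_from_logits(logits, labels, num_classes):
    """Hypothetical helper: summed categorical cross-entropy computed from raw logits."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)  # log(softmax(x)) in one stable step
    one_hot = jax.nn.one_hot(labels, num_classes)
    return -jnp.sum(one_hot * log_probs)

def naive_cross_entropy(logits, labels, num_classes):
    """Two-step version mirroring CrossEntropyLoss(): softmax first, then log."""
    probs = jax.nn.softmax(logits, axis=-1)
    return -jnp.sum(jax.nn.one_hot(labels, num_classes) * jnp.log(probs))

labels = jnp.array([1])
moderate = jnp.array([[2.0, 0.5, -1.0]])
extreme = jnp.array([[1000.0, 0.0, 0.0]])  # softmax underflows to exactly 0 for classes 1 and 2

print(cross_entropy_from_logits(moderate, labels, 3), naive_cross_entropy(moderate, labels, 3))
print(cross_entropy_from_logits(extreme, labels, 3))  # finite: 1000.0
print(naive_cross_entropy(extreme, labels, 3))        # nan, because 0 * log(0) = 0 * -inf
```

For moderate logits both versions agree; only the log_softmax() version stays finite for extreme ones.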

Train Neural Network (SGD Optimizer)

In this section, we train our CNN. We have designed a function that takes the data features, target values, number of epochs, optimizer state (which holds the weights), and batch size as input. It then runs the training loop for the given number of epochs.

For each epoch, the function computes batch indices. For each batch of data, it performs a forward pass through the network to make predictions, calculates the loss, and calculates the gradients of the loss with respect to the weights. It does all of that with a single call to the value_and_grad() function of JAX. This function takes our loss function as input and returns a new function. When called with input values, this new function returns two values: the loss value (cross-entropy) and the gradients of the loss with respect to the weights. By default, value_and_grad() differentiates with respect to the first parameter of the loss function, which is weights in our case. The forward pass is performed inside the loss function.
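The behavior of value_and_grad() can be seen on a tiny hand-checkable example. In this sketch, the single-layer parameters, the input, and the squared_error() loss are hypothetical; the parameters are nested as ((weights, biases),) like stax weights, and the returned gradients have exactly the same nested structure, which is what lets an optimizer match each gradient to its weight:

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters nested like one stax layer: ((weights, biases),)
params = ((jnp.array([[1.0], [2.0]]), jnp.array([0.5])),)
x = jnp.array([[3.0, 4.0]])  # one sample with 2 features
y = jnp.array([[10.0]])      # its target

def squared_error(params, x, y):
    (w, b), = params
    pred = x @ w + b           # 3*1 + 4*2 + 0.5 = 11.5
    return jnp.sum((pred - y) ** 2)

# Differentiates with respect to the first argument (params) only
loss, grads = jax.value_and_grad(squared_error)(params, x, y)
print(loss)         # 2.25, i.e. (11.5 - 10)^2
print(grads[0][0])  # dL/dw = 2 * (pred - y) * x^T = [[9.], [12.]]
print(grads[0][1])  # dL/db = 2 * (pred - y) = [3.]
```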

The weights of our neural network are kept in the optimizer state object. We can retrieve them by passing the optimizer state to the optimizer's get-weights function (named get_params in JAX, which we unpack as opt_get_weights). We explain the optimizer state further when we create the optimizer.

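After calculating the gradients, we update the weights of the neural network by calling the optimizer's update function (opt_update). It takes a step index, the gradients, and the optimizer state, and returns a new optimizer state with updated weights. Note that our function passes the epoch number i as the step index. This makes no difference for SGD with a constant learning rate, but optimizers.adam() uses the step index for bias correction, where it is normally the number of updates (batches) performed so far.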

We also print the average cross-entropy across all batches of each epoch. Finally, we return the last optimizer state, which holds the updated weights of the neural network.

from jax import value_and_grad

def TrainModelInBatches(X, Y, epochs, opt_state, batch_size=32):
    for i in range(1, epochs+1):
        batches = jnp.arange((X.shape[0]//batch_size)+1) ### Batch Indices

        losses = [] ## Record loss of each batch
        for batch in batches:
            if batch != batches[-1]:
                start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
            else:
                start, end = int(batch*batch_size), None

            X_batch, Y_batch = X[start:end], Y[start:end] ## Single batch of data

            loss, gradients = value_and_grad(CrossEntropyLoss)(opt_get_weights(opt_state), X_batch,Y_batch)

            ## Update Weights
            opt_state = opt_update(i, gradients, opt_state)

            losses.append(loss) ## Record Loss

        print("CrossEntropyLoss : {:.3f}".format(jnp.array(losses).mean()))

    return opt_state
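TrainModelInBatches() calls value_and_grad() without jax.jit() and passes the epoch number to opt_update(). As a hedged alternative sketch, the loop below jit-compiles a single training step and passes a global update counter as the step index. The small Flatten-Dense model, the random data, and the names train_step() and train() are hypothetical stand-ins so that the sketch runs on its own; with the tutorial's CNN you would use conv_apply and CrossEntropyLoss() instead:

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax, optimizers

# Hypothetical small model and random data, only so that this sketch runs on its own
net_init, net_apply = stax.serial(stax.Flatten, stax.Dense(10), stax.LogSoftmax)
_, params = net_init(jax.random.PRNGKey(0), (-1, 28, 28, 1))
X = jax.random.uniform(jax.random.PRNGKey(1), (100, 28, 28, 1))
Y = jax.random.randint(jax.random.PRNGKey(2), (100,), 0, 10)

opt_init, opt_update, get_params = optimizers.adam(1e-3)

def loss_fn(params, x, y):
    log_probs = net_apply(params, x)  # LogSoftmax output, so no separate log is needed
    return -jnp.sum(jax.nn.one_hot(y, 10) * log_probs)

@jax.jit
def train_step(step, opt_state, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(get_params(opt_state), x, y)
    # 'step' counts the updates made so far; adam() uses it for bias correction
    return loss, opt_update(step, grads, opt_state)

def train(X, Y, epochs, opt_state, batch_size=32):
    step = 0  # global update counter across all epochs
    for epoch in range(1, epochs + 1):
        losses = []
        # range() with a stride covers every sample and never yields an empty batch
        for start in range(0, X.shape[0], batch_size):
            loss, opt_state = train_step(step, opt_state,
                                         X[start:start + batch_size], Y[start:start + batch_size])
            losses.append(loss)
            step += 1
        print("Epoch {} CrossEntropyLoss : {:.3f}".format(epoch, float(jnp.array(losses).mean())))
    return opt_state, step

final_state, num_steps = train(X, Y, epochs=2, opt_state=opt_init(params), batch_size=32)
print(num_steps)  # 8: four batches per epoch (32 + 32 + 32 + 4 samples)
```

The smaller final batch has a different shape, so jit compiles train_step() twice; after that, every step reuses the compiled code.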

Now, we train our neural network by calling TrainModelInBatches(). We first set the learning rate to 0.0001, the number of epochs to 25, and the batch size to 256.

We then initialize the weights of our CNN using the init() function.

Then, we create an SGD optimizer using the optimizers module, giving it the learning rate. optimizers.sgd() returns three functions.

  1. init() - It initializes the optimizer with the weights of the neural network and returns the initial OptimizerState object, which holds the model weights.
  2. update() - It takes a step index, gradients, and the optimizer state as input and returns a new optimizer state with updated weights (for SGD, the weights minus the learning rate times the gradients).
  3. get_params() - It takes an OptimizerState object as input and returns the model weights. We unpack it as opt_get_weights.
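The three optimizer functions can be checked by hand on a hypothetical 2-element parameter vector. This minimal sketch performs one SGD update with a learning rate of 0.1 and then shows that the same functions also accept nested tuples like the weights produced by stax:

```python
import jax.numpy as jnp
from jax.example_libraries import optimizers

opt_init, opt_update, get_params = optimizers.sgd(0.1)

# One SGD step on a hypothetical 2-element parameter vector
sgd_state = opt_init(jnp.array([1.0, 2.0]))  # wraps the initial weights in an OptimizerState
grads = jnp.array([0.5, -1.0])
sgd_state = opt_update(0, grads, sgd_state)  # step index 0; SGD ignores it for a constant rate
print(get_params(sgd_state))                 # approximately [0.95, 2.1], i.e. w - 0.1 * g

# Parameters may also be nested tuples, like the weights returned by stax init functions
nested_state = opt_init(((jnp.ones((2, 2)), jnp.zeros(2)),))
print(get_params(nested_state)[0][1])        # the bias vector: [0. 0.]
```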

After initializing the optimizer and the optimizer state, we call TrainModelInBatches() with the training data and other parameters to perform training. We store the final optimizer state returned by the function, which we'll use later to make predictions.

We can notice from the printed cross-entropy loss, which decreases steadily across epochs, that our model is learning. We'll evaluate its performance later by calculating accuracy and classification metrics like precision, recall, and f1-score.

seed = jax.random.PRNGKey(123)
learning_rate = jnp.array(1/1e4)
epochs = 25
batch_size=256

weights = conv_init(seed, (batch_size,28,28,1)) ## Use the seed defined above
weights = weights[1]


opt_init, opt_update, opt_get_weights = optimizers.sgd(learning_rate)
opt_state = opt_init(weights)

final_opt_state = TrainModelInBatches(X_train, Y_train, epochs, opt_state, batch_size=batch_size)
CrossEntropyLoss : 229.267
CrossEntropyLoss : 145.876
CrossEntropyLoss : 127.667
CrossEntropyLoss : 117.161
CrossEntropyLoss : 109.770
CrossEntropyLoss : 104.089
CrossEntropyLoss : 99.591
CrossEntropyLoss : 95.963
CrossEntropyLoss : 92.993
CrossEntropyLoss : 90.504
CrossEntropyLoss : 88.365
CrossEntropyLoss : 86.485
CrossEntropyLoss : 84.809
CrossEntropyLoss : 83.273
CrossEntropyLoss : 81.874
CrossEntropyLoss : 80.553
CrossEntropyLoss : 79.321
CrossEntropyLoss : 78.159
CrossEntropyLoss : 77.078
CrossEntropyLoss : 76.028
CrossEntropyLoss : 75.039
CrossEntropyLoss : 74.089
CrossEntropyLoss : 73.167
CrossEntropyLoss : 72.287
CrossEntropyLoss : 71.416

Make Predictions

In this section, we make predictions on the train and test data. We create a small function that loops through the data in batches and makes predictions, and we then combine the predictions of all batches. The output of our CNN is 10 probabilities per sample, so we convert them to the actual class label by taking the index of the maximum value with the argmax() function.

def MakePredictions(weights, input_data, batch_size=32):
    batches = jnp.arange((input_data.shape[0]//batch_size)+1) ### Batch Indices

    preds = []
    for batch in batches:
        if batch != batches[-1]:
            start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
        else:
            start, end = int(batch*batch_size), None

        X_batch = input_data[start:end]

        if X_batch.shape[0] != 0:
            preds.append(conv_apply(weights, X_batch))

    return preds
test_preds = MakePredictions(opt_get_weights(final_opt_state), X_test, batch_size=batch_size)

test_preds = jnp.concatenate(test_preds).squeeze() ## Combine predictions of all batches

test_preds = jnp.argmax(test_preds, axis=1)

train_preds = MakePredictions(opt_get_weights(final_opt_state), X_train, batch_size=batch_size)

train_preds = jnp.concatenate(train_preds).squeeze() ## Combine predictions of all batches

train_preds = jnp.argmax(train_preds, axis=1)

test_preds[:5], train_preds[:5]
(DeviceArray([9, 2, 1, 1, 6], dtype=int32),
 DeviceArray([9, 0, 0, 3, 1], dtype=int32))
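MakePredictions() builds one extra batch index and then skips the batch if it is empty. A slightly shorter variant, sketched below, steps through the data with range() (which never yields an empty batch), jit-compiles the forward pass, and returns class labels directly. The small Flatten-Dense model, the random data, and the name predict_labels() are hypothetical stand-ins for the tutorial's CNN and datasets:

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

# Hypothetical small model and data standing in for the tutorial's CNN and X_test
init_fn, apply_fn = stax.serial(stax.Flatten, stax.Dense(10), stax.Softmax)
_, params = init_fn(jax.random.PRNGKey(0), (-1, 28, 28, 1))
data = jax.random.uniform(jax.random.PRNGKey(1), (1000, 28, 28, 1))

fast_apply = jax.jit(apply_fn)  # compiled once per distinct batch shape

def predict_labels(params, input_data, batch_size=256):
    # range() with a stride covers every row and never produces an empty batch
    probs = [fast_apply(params, input_data[start:start + batch_size])
             for start in range(0, input_data.shape[0], batch_size)]
    return jnp.argmax(jnp.concatenate(probs), axis=1)  # class with the highest probability

labels = predict_labels(params, data)
print(labels.shape)  # (1000,)
```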

Evaluate Model Performance

In this section, we evaluate the performance of our CNN by calculating the accuracy of the train and test predictions using the accuracy_score() function from scikit-learn.

Then, in the next cell, we calculate the classification report of the test predictions using the classification_report() function from scikit-learn. The classification report includes precision, recall, and f1-score for each target class.

If you want to learn about the various metrics available in scikit-learn, please feel free to check our tutorial that covers most of them in detail.

from sklearn.metrics import accuracy_score

print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, train_preds)))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, test_preds)))
Train Accuracy : 0.902
Test  Accuracy : 0.882
from sklearn.metrics import classification_report

print("Test Classification Report ")
print(classification_report(Y_test, test_preds))
Test Classification Report
              precision    recall  f1-score   support

         0.0       0.81      0.86      0.83      1000
         1.0       0.99      0.96      0.98      1000
         2.0       0.75      0.87      0.81      1000
         3.0       0.85      0.93      0.89      1000
         4.0       0.84      0.75      0.79      1000
         5.0       0.98      0.96      0.97      1000
         6.0       0.74      0.62      0.67      1000
         7.0       0.93      0.95      0.94      1000
         8.0       0.97      0.97      0.97      1000
         9.0       0.96      0.96      0.96      1000

    accuracy                           0.88     10000
   macro avg       0.88      0.88      0.88     10000
weighted avg       0.88      0.88      0.88     10000
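Accuracy and per-class recall can also be computed directly with jax.numpy, which is handy when scikit-learn is not available. This is a minimal sketch on hypothetical labels and predictions; like Y_test in the tutorial, the actual labels are floats:

```python
import jax.numpy as jnp

# Hypothetical float labels (like Y_test) and integer predictions (like argmax output)
actual = jnp.array([0.0, 1.0, 2.0, 2.0])
predicted = jnp.array([0, 1, 2, 1])

accuracy = jnp.mean(actual == predicted)  # fraction of matching labels
print(float(accuracy))  # 0.75

# Recall per class: of the samples that truly belong to class c, how many were predicted as c
recall_per_class = [float(jnp.mean(predicted[actual == c] == c)) for c in range(3)]
print(recall_per_class)  # [1.0, 1.0, 0.5]
```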

Train Neural Network (Adam Optimizer)

In this section, we train our CNN again, but this time using the Adam optimizer instead of SGD. The other parameters, like learning rate, number of epochs, and batch size, are the same as in the previous training run. The only change is the optimizer.

seed = jax.random.PRNGKey(123)
learning_rate = jnp.array(1/1e4)
epochs = 25
batch_size=256

weights = conv_init(seed, (batch_size,28,28,1)) ## Use the seed defined above
weights = weights[1]


opt_init, opt_update, opt_get_weights = optimizers.adam(learning_rate)
opt_state = opt_init(weights)

final_opt_state = TrainModelInBatches(X_train, Y_train, epochs, opt_state, batch_size=batch_size)
CrossEntropyLoss : 231.014
CrossEntropyLoss : 142.974
CrossEntropyLoss : 130.786
CrossEntropyLoss : 124.205
CrossEntropyLoss : 119.794
CrossEntropyLoss : 116.557
CrossEntropyLoss : 114.069
CrossEntropyLoss : 112.018
CrossEntropyLoss : 110.263
CrossEntropyLoss : 108.722
CrossEntropyLoss : 107.340
CrossEntropyLoss : 106.081
CrossEntropyLoss : 104.918
CrossEntropyLoss : 103.831
CrossEntropyLoss : 102.820
CrossEntropyLoss : 101.880
CrossEntropyLoss : 100.992
CrossEntropyLoss : 100.147
CrossEntropyLoss : 99.340
CrossEntropyLoss : 98.573
CrossEntropyLoss : 97.851
CrossEntropyLoss : 97.159
CrossEntropyLoss : 96.497
CrossEntropyLoss : 95.860
CrossEntropyLoss : 95.248

Make Predictions

In this section, we make predictions on the train and test datasets using the new weights learned with the Adam optimizer.

test_preds = MakePredictions(opt_get_weights(final_opt_state), X_test, batch_size=batch_size)

test_preds = jnp.concatenate(test_preds).squeeze() ## Combine predictions of all batches

test_preds = jnp.argmax(test_preds, axis=1)

train_preds = MakePredictions(opt_get_weights(final_opt_state), X_train, batch_size=batch_size)

train_preds = jnp.concatenate(train_preds).squeeze() ## Combine predictions of all batches

train_preds = jnp.argmax(train_preds, axis=1)

test_preds[:5], train_preds[:5]
(DeviceArray([9, 2, 1, 1, 6], dtype=int32),
 DeviceArray([9, 0, 0, 3, 3], dtype=int32))

Evaluate Model Performance

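In this section, we evaluate the performance of our new CNN trained with the Adam optimizer by calculating accuracy and a classification report. With these settings, SGD has done a better job than Adam in our case. This comparison is specific to the chosen hyperparameters; also note that TrainModelInBatches() passes the epoch number, not the number of updates, as the step index that optimizers.adam() uses for bias correction.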

from sklearn.metrics import accuracy_score

print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, train_preds)))
print("Test  Accuracy : {:.3f}".format(accuracy_score(Y_test, test_preds)))
Train Accuracy : 0.870
Test  Accuracy : 0.858
from sklearn.metrics import classification_report

print("Test Classification Report ")
print(classification_report(Y_test, test_preds))
Test Classification Report
              precision    recall  f1-score   support

         0.0       0.80      0.83      0.82      1000
         1.0       0.97      0.96      0.97      1000
         2.0       0.79      0.76      0.77      1000
         3.0       0.86      0.86      0.86      1000
         4.0       0.74      0.80      0.77      1000
         5.0       0.94      0.95      0.95      1000
         6.0       0.64      0.58      0.61      1000
         7.0       0.94      0.90      0.92      1000
         8.0       0.96      0.95      0.95      1000
         9.0       0.93      0.96      0.95      1000

    accuracy                           0.86     10000
   macro avg       0.86      0.86      0.86     10000
weighted avg       0.86      0.86      0.86     10000

Guide to Handle Channels First vs Channels Last

Color (RGB) images have one channel for each of red, green, and blue. Other image types like RGBA, CMYK, and HSV have 3 or more channels that are combined to create the final image. A CNN applies its kernels across the channels of an image and transforms them into new channels. Images are represented as multi-dimensional arrays, and there are two common formats for representing a single RGB image.

  1. Channels First - The channel dimension comes before height and width. A 28x28 RGB image is represented as an array of shape (3,28,28), i.e. 3 arrays of shape (28,28), one per channel.
  2. Channels Last - The channel values of each pixel are kept together. A 28x28 RGB image is represented as an array of shape (28,28,3), i.e. 3 values for each pixel of the (28,28) image.
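Converting a batch between the two formats is a single axis permutation with jnp.transpose(). This sketch uses a hypothetical batch of 50 RGB images; converting channels-first data to channels-last this way is one option for using the default stax.Conv() layer on it:

```python
import jax.numpy as jnp

# Hypothetical batch of 50 channels-first RGB images: (N, C, H, W)
batch_nchw = jnp.arange(50 * 3 * 28 * 28, dtype=jnp.float32).reshape(50, 3, 28, 28)

batch_nhwc = jnp.transpose(batch_nchw, (0, 2, 3, 1))  # to channels last: (N, H, W, C)
back = jnp.transpose(batch_nhwc, (0, 3, 1, 2))        # and back to (N, C, H, W)

print(batch_nhwc.shape, back.shape)  # (50, 28, 28, 3) (50, 3, 28, 28)
```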

In our example, we reshaped our grayscale images from shape (28,28) to (28,28,1), introducing an extra channel dimension for our convolution layers. The default Conv() layer of stax expects the channel dimension to be last. It cannot handle images with the channel dimension first.

The next two cells show how the default Conv() layer properly handles images with channels as the last dimension but mishandles images where channels come first.

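The first cell shows how it properly transforms images of shape (28,28) from 1 channel to 16 channels and then from 16 channels to 32 channels. The second cell runs without any error, but Conv() treats the last axis (the width, 28) as the channel axis, as the first layer's kernel shape (3, 3, 28, 16) shows, so the output shapes are wrong.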

rng = jax.random.PRNGKey(123)

init_conv1, apply_conv1 = stax.Conv(16, (3,3), padding="SAME")

weights1 = init_conv1(rng, (50,28,28,1))
preds1 = apply_conv1(weights1[1], jax.random.uniform(rng, (50,28,28,1)))

init_conv2, apply_conv2 = stax.Conv(32, (3,3), padding="SAME")
weights2 = init_conv2(rng, preds1.shape)
preds2 = apply_conv2(weights2[1], jax.random.uniform(rng, preds1.shape))

print("Weights of First Conv Layer : {}".format(weights1[1][0].shape))
print("Weights of Second Conv Layer : {}".format(weights2[1][0].shape))

print("\nInput Shape               : {}".format((50,28,28,1)))
print("Conv Layer 1 Output Shape : {}".format(preds1.shape))
print("Conv Layer 2 Output Shape : {}".format(preds2.shape))
Weights of First Conv Layer : (3, 3, 1, 16)
Weights of Second Conv Layer : (3, 3, 16, 32)

Input Shape               : (50, 28, 28, 1)
Conv Layer 1 Output Shape : (50, 28, 28, 16)
Conv Layer 2 Output Shape : (50, 28, 28, 32)
rng = jax.random.PRNGKey(123)

init_conv1, apply_conv1 = stax.Conv(16, (3,3), padding="SAME")

weights1 = init_conv1(rng, (50,1,28,28))
preds1 = apply_conv1(weights1[1], jax.random.uniform(rng, (50,1,28,28)))

init_conv2, apply_conv2 = stax.Conv(32, (3,3), padding="SAME")
weights2 = init_conv2(rng, preds1.shape)
preds2 = apply_conv2(weights2[1], jax.random.uniform(rng, preds1.shape))

print("Weights of First Conv Layer : {}".format(weights1[1][0].shape))
print("Weights of Second Conv Layer : {}".format(weights2[1][0].shape))

print("\nInput Shape               : {}".format((50,1,28,28)))
print("Conv Layer 1 Output Shape : {}".format(preds1.shape))
print("Conv Layer 2 Output Shape : {}".format(preds2.shape))
Weights of First Conv Layer : (3, 3, 28, 16)
Weights of Second Conv Layer : (3, 3, 16, 32)

Input Shape               : (50, 1, 28, 28)
Conv Layer 1 Output Shape : (50, 1, 28, 16)
Conv Layer 2 Output Shape : (50, 1, 28, 32)

Below, we print the definition of the Conv() layer. We can notice that it is GeneralConv() with its dimension layout specified as a tuple of three strings. The first string describes the dimensions of the input images, the second the dimensions of the kernel, and the third the dimensions of the output after applying the kernel to the input images.

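  1. NHWC - (Number of samples, Height, Width, Channels) - These are the input image dimensions.
  2. HWIO - (Height, Width, Input Channels, Output Channels) - These are the kernel dimensions of the convolution layer.
  3. NHWC - (Number of samples, Height, Width, Channels) - These are the output image dimensions.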

We can notice from this default configuration that the Conv() layer expects the channel dimension to be last.

stax.Conv
functools.partial(<function GeneralConv at 0x7fcec369aa70>, ('NHWC', 'HWIO', 'NHWC'))
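These layout strings are passed on to jax.lax.conv_general_dilated(), which performs the actual convolution. This minimal sketch, on hypothetical random images and a random HWIO kernel, runs the same convolution once on channels-last images and once on channels-first images; only the axis order of the result differs:

```python
import jax
import jax.numpy as jnp
from jax import lax

key_x, key_k = jax.random.split(jax.random.PRNGKey(0))
x_nhwc = jax.random.normal(key_x, (2, 8, 8, 3))  # 2 channels-last 8x8 images with 3 channels
kernel = jax.random.normal(key_k, (3, 3, 3, 4))  # HWIO: 3x3 window, 3 input and 4 output channels

# Channels-last convolution, the layout stax.Conv uses by default
out_nhwc = lax.conv_general_dilated(x_nhwc, kernel, window_strides=(1, 1), padding="SAME",
                                    dimension_numbers=("NHWC", "HWIO", "NHWC"))

# The same images in channels-first layout, with the same HWIO kernel
x_nchw = jnp.transpose(x_nhwc, (0, 3, 1, 2))
out_nchw = lax.conv_general_dilated(x_nchw, kernel, window_strides=(1, 1), padding="SAME",
                                    dimension_numbers=("NCHW", "HWIO", "NCHW"))

print(out_nhwc.shape, out_nchw.shape)  # (2, 8, 8, 4) (2, 4, 8, 8)
print(jnp.allclose(out_nhwc, jnp.transpose(out_nchw, (0, 2, 3, 1)), atol=1e-4))
```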

But what if we have images with the channel dimension first? JAX provides a layer named GeneralConv() that lets us specify the dimension layout of images and kernels as a tuple.

The GeneralConv() layer works exactly like Conv(), but its first argument is a tuple of 3 strings specifying the dimension layout of the input images, kernel, and output images. We can handle channels-first images using the GeneralConv() layer.

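Below, we show two simple examples of using the GeneralConv() layer when the channel dimension comes first. The first example keeps the kernel in HWIO layout, while the second also stores the kernel channels-first, in IOHW layout.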

rng = jax.random.PRNGKey(123)

init_conv1, apply_conv1 = stax.GeneralConv(("NCHW", "HWIO", "NCHW"),16, (3,3), padding="SAME")

weights1 = init_conv1(rng, (50,1,28,28))
preds1 = apply_conv1(weights1[1], jax.random.uniform(rng, (50,1,28,28)))

init_conv2, apply_conv2 = stax.GeneralConv(("NCHW", "HWIO", "NCHW"),32, (3,3), padding="SAME")
weights2 = init_conv2(rng, preds1.shape)
preds2 = apply_conv2(weights2[1], jax.random.uniform(rng, preds1.shape))

print("Weights of First Conv Layer : {}".format(weights1[1][0].shape))
print("Weights of Second Conv Layer : {}".format(weights2[1][0].shape))

print("\nInput Shape               : {}".format((50,1,28,28)))
print("Conv Layer 1 Output Shape : {}".format(preds1.shape))
print("Conv Layer 2 Output Shape : {}".format(preds2.shape))
Weights of First Conv Layer : (3, 3, 1, 16)
Weights of Second Conv Layer : (3, 3, 16, 32)

Input Shape               : (50, 1, 28, 28)
Conv Layer 1 Output Shape : (50, 16, 28, 28)
Conv Layer 2 Output Shape : (50, 32, 28, 28)
rng = jax.random.PRNGKey(123)

init_conv1, apply_conv1 = stax.GeneralConv(("NCHW", "IOHW", "NCHW"),16, (3,3), padding="SAME")

weights1 = init_conv1(rng, (50,1,28,28))
preds1 = apply_conv1(weights1[1], jax.random.uniform(rng, (50,1,28,28)))

init_conv2, apply_conv2 = stax.GeneralConv(("NCHW", "IOHW", "NCHW"),32, (3,3), padding="SAME")
weights2 = init_conv2(rng, preds1.shape)
preds2 = apply_conv2(weights2[1], jax.random.uniform(rng, preds1.shape))

print("Weights of First Conv Layer : {}".format(weights1[1][0].shape))
print("Weights of Second Conv Layer : {}".format(weights2[1][0].shape))

print("\nInput Shape               : {}".format((50,1,28,28)))
print("Conv Layer 1 Output Shape : {}".format(preds1.shape))
print("Conv Layer 2 Output Shape : {}".format(preds2.shape))
Weights of First Conv Layer : (1, 16, 3, 3)
Weights of Second Conv Layer : (16, 32, 3, 3)

Input Shape               : (50, 1, 28, 28)
Conv Layer 1 Output Shape : (50, 16, 28, 28)
Conv Layer 2 Output Shape : (50, 32, 28, 28)

This ends our small tutorial explaining how to create simple convolutional neural networks using JAX. Please feel free to let us know your views in the comments section.


Sunny Solanki
