Convolutional Neural Networks (CNNs) are a class of neural networks that use convolution layers, which apply convolution operations to input data in order to detect patterns. CNNs are commonly used for tasks involving visual imagery (like object detection, image classification, image segmentation, etc.), where they generally outperform other types of neural networks. They are also commonly used for natural language processing and time-series tasks. Compared to fully connected dense neural networks, CNNs generally need far fewer trainable parameters for tasks with many input features. If you want to learn about convolutional neural networks in depth, please feel free to check our blog that covers the theory in detail.
As a part of this tutorial, we'll cover how to create a simple convolutional neural network using JAX. JAX is a well-known framework for designing neural networks that provides functionality like numpy-like idioms on CPUs/GPUs/TPUs, automatic differentiation, just-in-time compilation, etc. If you do not have a background in JAX, we recommend going through our tutorial that covers the basics. It'll help you with this tutorial as well.
In this tutorial, we'll design a simple convolutional neural network using the high-level stax API of JAX. We have another tutorial on the stax API describing how to create simple fully connected neural networks. Please feel free to check it out; it'll help with this tutorial as well.
Below we have highlighted important sections of the tutorial.
Below we have imported JAX and printed the version of it that we'll be using in our tutorial.
import jax
print("JAX Version : {}".format(jax.__version__))
In this section, we'll explain the step-by-step process of creating and training a convolutional neural network. Below, we have imported the necessary submodules of JAX that we'll be using in our tutorial: stax and optimizers for creating neural networks and optimizers respectively, and the jax.numpy module for working with arrays.
from jax.example_libraries import stax, optimizers
from jax import numpy as jnp
In this section, we have loaded the Fashion MNIST dataset available from keras. The dataset has 28x28 grayscale images of 10 fashion items. It is already divided into train (60k samples) and test (10k samples) sets. We have converted the dataset, loaded as numpy arrays, to JAX arrays as required by the model built in JAX. We have then reshaped the loaded images to shape (28,28,1), where the last extra dimension represents the channel. Convolution layers require a channel dimension; as our images are grayscale and do not have 3 channels (RGB) like color images, we have introduced an extra dimension of size 1 as the channel. After reshaping, we have divided both datasets by 255.0 to normalize them, as the images are loaded as integers in the range (0,255). This will help the optimization algorithm converge faster.
from tensorflow import keras
from sklearn.model_selection import train_test_split
(X_train, Y_train), (X_test, Y_test) = keras.datasets.fashion_mnist.load_data()
X_train, X_test, Y_train, Y_test = jnp.array(X_train, dtype=jnp.float32),\
jnp.array(X_test, dtype=jnp.float32),\
jnp.array(Y_train, dtype=jnp.float32),\
jnp.array(Y_test, dtype=jnp.float32)
X_train, X_test = X_train.reshape(-1,28,28,1), X_test.reshape(-1,28,28,1)
X_train, X_test = X_train/255.0, X_test/255.0
classes = jnp.unique(Y_train)
X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
In this section, we have created the convolutional neural network that we'll be using for Fashion MNIST image classification. Our CNN is simple, with only 2 convolution layers and one linear/dense layer. The first convolution layer has 32 channels and a kernel size of (3,3). The second convolution layer has 16 channels and a kernel size of (3,3). Both use padding of 'SAME', which indicates that the kernel should be applied after padding the input so that the input and output height and width of an image are the same.
The first layer will transform images from shape (n_samples,height,width,1) to (n_samples,height,width,32), turning 1-channel images into 32-channel images. The height and width in our case are both 28. After the first convolution layer, we have applied Relu (rectified linear unit) activation to the output. The second layer will transform images from shape (n_samples,height,width,32) to (n_samples,height,width,16). After the second convolution layer, we have again applied Relu activation to the output.
After the two convolution layers, we have flattened the output. The new shape will be (n_samples, height x width x 16) = (n_samples, 28 x 28 x 16) = (n_samples, 12544). After flattening, we have added a final dense layer with the number of units equal to the number of classes, which is 10. At last, we have applied the softmax activation function, which maps all 10 output values to the range [0,1] such that they sum to 1. We can then predict the image class as the one with the highest of those 10 values.
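As a quick aside, the short sketch below is our own illustrative example (not part of the main model); it compares how 'SAME' padding and the default 'VALID' padding affect output height and width for a single standalone Conv layer applied to a dummy image.

import jax
from jax.example_libraries import stax

rng = jax.random.PRNGKey(0)

## Two standalone Conv layers that differ only in padding (illustrative only)
init_same, apply_same = stax.Conv(8, (3,3), padding="SAME")
init_valid, apply_valid = stax.Conv(8, (3,3), padding="VALID")

x = jax.random.uniform(rng, (1,28,28,1)) ## One dummy 28x28 grayscale image

_, params_same = init_same(rng, x.shape)
_, params_valid = init_valid(rng, x.shape)

print(apply_same(params_same, x).shape)   ## (1, 28, 28, 8) -> height/width preserved
print(apply_valid(params_valid, x).shape) ## (1, 26, 26, 8) -> shrinks by kernel size - 1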
We have created our neural network using the serial() API of stax, giving it the layers in the sequence in which they will be applied. The serial() function returns two functions: one to initialize the network weights and one to perform a forward pass through the network.
After creating the CNN, in the next cell, we have initialized the weights of the neural network using the init() method and printed their shapes. Then, in the following cell, we have performed a forward pass of a few data samples through the network using the apply() method and printed the predictions.
conv_init, conv_apply = stax.serial(
    stax.Conv(32, (3,3), padding="SAME"),
    stax.Relu,
    stax.Conv(16, (3,3), padding="SAME"),
    stax.Relu,
    stax.Flatten,
    stax.Dense(len(classes)),
    stax.Softmax
)
rng = jax.random.PRNGKey(123)
weights = conv_init(rng, (18,28,28,1))
weights = weights[1] ## conv_init() returns (output_shape, params); the actual weights are the second element
for w in weights:
    if w:
        w, b = w
        print("Weights : {}, Biases : {}".format(w.shape, b.shape))
preds = conv_apply(weights, X_train[:5])
preds
In this section, we have defined the loss function that we'll be using in our case: categorical cross entropy. The function takes weights, input data, and actual target values as input. It first makes predictions on the input data using the weights (apply() method). Then we one-hot encode the actual target values as required by our loss function. Next, we take the log of the predictions and multiply it element-wise by the one-hot encoded target values. At last, we return the negative of the sum of the resulting array.
Later, we'll calculate the gradient of this loss function with respect to the model weights, which are the first parameter of the function.
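To make the loss computation concrete, here is a tiny standalone illustration with made-up numbers showing what jax.nn.one_hot() produces and how multiplying it by the log-predictions keeps only the log-probability of the true class for each sample.

import jax
from jax import numpy as jnp

labels = jnp.array([0, 2, 1])                   ## Hypothetical target labels for 3 samples
one_hot = jax.nn.one_hot(labels, num_classes=3) ## [[1,0,0],[0,0,1],[0,1,0]]

preds = jnp.array([[0.7, 0.2, 0.1],             ## Hypothetical softmax outputs for the same samples
                   [0.1, 0.1, 0.8],
                   [0.3, 0.5, 0.2]])

loss = - jnp.sum(one_hot * jnp.log(preds))      ## Only the true-class log-probabilities contribute
print(loss)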
def CrossEntropyLoss(weights, input_data, actual):
    preds = conv_apply(weights, input_data)
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(classes))
    log_preds = jnp.log(preds)
    return - jnp.sum(one_hot_actual * log_preds)
In this section, we actually train our CNN. We have designed a function that takes data and a few other pieces of information as input and performs training. The function takes data features, target values, number of epochs, optimizer state (weights), and batch size as input. It then executes the training loop for the given number of epochs.
For each epoch, it calculates batch indices for the batches of data. It then performs a forward pass of a batch of data through the network to make predictions, calculates the loss, and calculates the gradients of the loss with respect to the weights. It does all of that with a single call to the value_and_grad() function of JAX. This function takes our loss function as input and returns another function. When called with input values, the returned function produces two outputs: the first is the loss value (cross entropy) and the second is the gradients of the loss with respect to the weights. The value_and_grad() function calculates the gradient of the loss with respect to the first parameter of the loss function, which is the weights in our case. The forward pass is performed inside the loss function.
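Here is a minimal standalone sketch of value_and_grad() on a toy function of our own, just to show the two values it returns; the actual usage with our loss function appears in the training loop below.

from jax import value_and_grad
from jax import numpy as jnp

def toy_loss(w):
    return jnp.sum(w ** 2)  ## Simple scalar loss (illustrative only)

w = jnp.array([1.0, 2.0, 3.0])
loss, grads = value_and_grad(toy_loss)(w)
print(loss)  ## 14.0 (value of the loss)
print(grads) ## [2. 4. 6.] (gradient 2*w w.r.t. the first argument)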
The weights of our neural network are kept in the optimizer state object. We can retrieve the weights of the network by calling the get_weights() method of the optimizer, giving it the optimizer state. We explain the optimizer state further when creating the optimizer below.
After calculating the gradients, we update the weights of the neural network by calling the update() method of the optimizer. It takes the step number, the gradients, and the optimizer state (weights) as input, and returns a new optimizer state with updated weights.
We also print the average cross entropy across all batches for each epoch. At last, we return the last updated optimizer state object, which holds the updated weights of the neural network.
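Before looking at the training loop, the toy sketch below (our own minimal example with a hypothetical one-parameter loss, not part of the main model) shows how the three SGD optimizer functions work together: initialize a state from weights, pull weights out of the state, and update the state from gradients.

from jax import grad
from jax import numpy as jnp
from jax.example_libraries import optimizers

def toy_loss(w):
    return (w - 3.0) ** 2  ## Hypothetical loss with minimum at w = 3

toy_init, toy_update, toy_get_weights = optimizers.sgd(0.1)
toy_state = toy_init(jnp.array(0.0))  ## Wrap the initial weight into an optimizer state

for step in range(50):
    w = toy_get_weights(toy_state)                     ## Retrieve current weight from state
    gradients = grad(toy_loss)(w)                      ## Gradient of toy loss w.r.t. weight
    toy_state = toy_update(step, gradients, toy_state) ## Returns a new state with updated weight

print(toy_get_weights(toy_state))  ## Close to 3.0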
from jax import value_and_grad
def TrainModelInBatches(X, Y, epochs, opt_state, batch_size=32):
    for i in range(1, epochs+1):
        batches = jnp.arange((X.shape[0]//batch_size)+1) ### Batch Indices

        losses = [] ## Record loss of each batch
        for batch in batches:
            if batch != batches[-1]:
                start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
            else:
                start, end = int(batch*batch_size), None

            X_batch, Y_batch = X[start:end], Y[start:end] ## Single batch of data

            loss, gradients = value_and_grad(CrossEntropyLoss)(opt_get_weights(opt_state), X_batch, Y_batch)

            ## Update Weights
            opt_state = opt_update(i, gradients, opt_state)

            losses.append(loss) ## Record Loss

        print("CrossEntropyLoss : {:.3f}".format(jnp.array(losses).mean()))

    return opt_state
Now, we actually train our neural network by calling the function from the previous cell. We have first initialized the learning rate to 0.0001, the number of epochs to 25, and the batch size to 256.
We have then initialized the weights of our CNN using the init() method.
Next, we have created the SGD optimizer using the optimizers module, providing the learning rate to it. The optimizer returns three functions: one to initialize the optimizer state from weights, one to update the state using gradients, and one to retrieve the current weights from the state.
After initializing the optimizer and the optimizer state, we have called our function from the previous cell to perform training, providing the training data and other parameters. We have stored the final optimizer state returned by the function, which we'll use later to make predictions.
We can notice from the cross-entropy loss getting printed that our model seems to be improving with each epoch. We'll evaluate its performance later by calculating accuracy and classification metrics like precision, recall, and f1-score.
seed = jax.random.PRNGKey(123)
learning_rate = jnp.array(1/1e4)
epochs = 25
batch_size=256
weights = conv_init(rng, (batch_size,28,28,1))
weights = weights[1]
opt_init, opt_update, opt_get_weights = optimizers.sgd(learning_rate)
opt_state = opt_init(weights)
final_opt_state = TrainModelInBatches(X_train, Y_train, epochs, opt_state, batch_size=batch_size)
In this section, we make predictions on the train and test data. We have created a small function that loops through the data in batches and makes predictions, and we then combine the predictions of all batches. The output of our CNN is 10 values per sample, hence we have included logic to convert those 10 values to an actual class label: we predict the index of the maximum value as the class label using the argmax() function.
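As a tiny illustration with made-up probabilities, argmax() simply picks the index of the largest of the values for each sample.

from jax import numpy as jnp

probs = jnp.array([[0.1, 0.7, 0.2],  ## Hypothetical softmax outputs for 2 samples, 3 classes
                   [0.6, 0.3, 0.1]])
print(jnp.argmax(probs, axis=1))     ## [1 0] -> predicted class labels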
def MakePredictions(weights, input_data, batch_size=32):
    batches = jnp.arange((input_data.shape[0]//batch_size)+1) ### Batch Indices

    preds = []
    for batch in batches:
        if batch != batches[-1]:
            start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
        else:
            start, end = int(batch*batch_size), None

        X_batch = input_data[start:end]

        if X_batch.shape[0] != 0:
            preds.append(conv_apply(weights, X_batch))

    return preds
test_preds = MakePredictions(opt_get_weights(final_opt_state), X_test, batch_size=batch_size)
test_preds = jnp.concatenate(test_preds).squeeze() ## Combine predictions of all batches
test_preds = jnp.argmax(test_preds, axis=1)
train_preds = MakePredictions(opt_get_weights(final_opt_state), X_train, batch_size=batch_size)
train_preds = jnp.concatenate(train_preds).squeeze() ## Combine predictions of all batches
train_preds = jnp.argmax(train_preds, axis=1)
test_preds[:5], train_preds[:5]
In this section, we have evaluated the performance of our CNN by calculating the accuracy of the train and test predictions, using the accuracy_score() function available from scikit-learn.
Then, in the next cell, we have calculated the classification report of the test predictions using the classification_report() function available from scikit-learn. The classification report has information like precision, recall, and f1-score for each target class.
If you want to learn about various metrics available through scikit-learn then please feel free to check our tutorial that covers the majority of them in detail.
from sklearn.metrics import accuracy_score
print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, train_preds)))
print("Test Accuracy : {:.3f}".format(accuracy_score(Y_test, test_preds)))
from sklearn.metrics import classification_report
print("Test Classification Report ")
print(classification_report(Y_test, test_preds))
In this section, we have again trained our CNN, but this time using the Adam optimizer instead of SGD. The majority of the parameters like learning rate, number of epochs, and batch size are the same as in the previous training. The only change is the optimizer.
seed = jax.random.PRNGKey(123)
learning_rate = jnp.array(1/1e4)
epochs = 25
batch_size=256
weights = conv_init(rng, (batch_size,28,28,1))
weights = weights[1]
opt_init, opt_update, opt_get_weights = optimizers.adam(learning_rate)
opt_state = opt_init(weights)
final_opt_state = TrainModelInBatches(X_train, Y_train, epochs, opt_state, batch_size=batch_size)
In this section, we have made predictions on the train and test datasets using the new weights calculated with the Adam optimizer.
test_preds = MakePredictions(opt_get_weights(final_opt_state), X_test, batch_size=256)
test_preds = jnp.concatenate(test_preds).squeeze() ## Combine predictions of all batches
test_preds = jnp.argmax(test_preds, axis=1)
train_preds = MakePredictions(opt_get_weights(final_opt_state), X_train)
train_preds = jnp.concatenate(train_preds).squeeze() ## Combine predictions of all batches
train_preds = jnp.argmax(train_preds, axis=1)
test_preds[:5], train_preds[:5]
In this section, we have evaluated the performance of our new CNN trained with the Adam optimizer by calculating the accuracy and a classification report. It seems from the results that, in our case, the SGD optimizer has done a better job compared to Adam.
from sklearn.metrics import accuracy_score
print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, train_preds)))
print("Test Accuracy : {:.3f}".format(accuracy_score(Y_test, test_preds)))
from sklearn.metrics import classification_report
print("Test Classification Report ")
print(classification_report(Y_test, test_preds))
Color or RGB images have a channel for each of red, green, and blue. Other image types like RGBA, CMYK, and HSV can have 3 or more channels that are combined to create the final image. A CNN generally applies kernels across the channels of images and transforms them. Images are generally represented as multi-dimensional arrays, and a single RGB image can be represented using two different layouts: channels-last (height, width, channels) or channels-first (channels, height, width).
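The small sketch below, our own illustration with a dummy image, shows the same single RGB image in both layouts.

from jax import numpy as jnp

img_hwc = jnp.zeros((4, 4, 3))         ## Dummy 4x4 RGB image, channels-last (H, W, C)
img_chw = jnp.moveaxis(img_hwc, -1, 0) ## Same image, channels-first (C, H, W)

print(img_hwc.shape) ## (4, 4, 3)
print(img_chw.shape) ## (3, 4, 4)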
In our example, we modified our grayscale images from shape (28,28) to (28,28,1), introducing an extra dimension specifying the channel for our convolution layers. The default Conv() layer available from JAX requires the channel dimension to be present last. It cannot handle images where the channel dimension comes first.
Our next two cells show with examples how the default Conv() layer of JAX properly handles images with the channel as the last dimension but fails to handle cases where the channel comes first.
The first cell below shows how it properly transforms an image of shape (28,28) from 1 channel to 16 channels and then from 16 channels to 32 channels. The next cell after it fails to transform the channels of the input images.
rng = jax.random.PRNGKey(123)
init_conv1, apply_conv1 = stax.Conv(16, (3,3), padding="SAME")
weights1 = init_conv1(rng, (50,28,28,1))
preds1 = apply_conv1(weights1[1], jax.random.uniform(rng, (50,28,28,1)))
init_conv2, apply_conv2 = stax.Conv(32, (3,3), padding="SAME")
weights2 = init_conv2(rng, preds1.shape)
preds2 = apply_conv2(weights2[1], jax.random.uniform(rng, preds1.shape))
print("Weights of First Conv Layer : {}".format(weights1[1][0].shape))
print("Weights of Second Conv Layer : {}".format(weights2[1][0].shape))
print("\nInput Shape : {}".format((50,28,28,1)))
print("Conv Layer 1 Output Shape : {}".format(preds1.shape))
print("Conv Layer 2 Output Shape : {}".format(preds2.shape))
rng = jax.random.PRNGKey(123)
init_conv1, apply_conv1 = stax.Conv(16, (3,3), padding="SAME")
weights1 = init_conv1(rng, (50,1,28,28))
preds1 = apply_conv1(weights1[1], jax.random.uniform(rng, (50,1,28,28)))
init_conv2, apply_conv2 = stax.Conv(32, (3,3), padding="SAME")
weights2 = init_conv2(rng, preds1.shape)
preds2 = apply_conv2(weights2[1], jax.random.uniform(rng, preds1.shape))
print("Weights of First Conv Layer : {}".format(weights1[1][0].shape))
print("Weights of Second Conv Layer : {}".format(weights2[1][0].shape))
print("\nInput Shape : {}".format((50,1,28,28)))
print("Conv Layer 1 Output Shape : {}".format(preds1.shape))
print("Conv Layer 2 Output Shape : {}".format(preds2.shape))
Below, we have printed the definition of the Conv() layer, and we can notice that there are shape details specified as a tuple of three strings. The first string in the tuple represents the input image dimensions, the second string represents the kernel dimensions, and the third string represents the output after applying the kernel to the input images.
We can notice from the default configuration of the Conv() layer, ('NHWC', 'HWIO', 'NHWC'), that it requires the channel dimension to be present last.
stax.Conv
But what if we have images with the channel dimension first? JAX provides a layer named GeneralConv() that lets us specify the layouts of the images and the kernel as a tuple of strings.
The GeneralConv() layer works exactly like Conv(), but its first argument is a tuple of 3 strings specifying the dimensions of the input images, the kernel, and the output images. We can handle channels-first cases using the GeneralConv() layer.
Below, we have shown with two simple examples how we can use the GeneralConv() layer when the channel dimension comes first.
rng = jax.random.PRNGKey(123)
init_conv1, apply_conv1 = stax.GeneralConv(("NCHW", "HWIO", "NCHW"),16, (3,3), padding="SAME")
weights1 = init_conv1(rng, (50,1,28,28))
preds1 = apply_conv1(weights1[1], jax.random.uniform(rng, (50,1,28,28)))
init_conv2, apply_conv2 = stax.GeneralConv(("NCHW", "HWIO", "NCHW"),32, (3,3), padding="SAME")
weights2 = init_conv2(rng, preds1.shape)
preds2 = apply_conv2(weights2[1], jax.random.uniform(rng, preds1.shape))
print("Weights of First Conv Layer : {}".format(weights1[1][0].shape))
print("Weights of Second Conv Layer : {}".format(weights2[1][0].shape))
print("\nInput Shape : {}".format((50,1,28,28)))
print("Conv Layer 1 Output Shape : {}".format(preds1.shape))
print("Conv Layer 2 Output Shape : {}".format(preds2.shape))
rng = jax.random.PRNGKey(123)
init_conv1, apply_conv1 = stax.GeneralConv(("NCHW", "IOHW", "NCHW"),16, (3,3), padding="SAME")
weights1 = init_conv1(rng, (50,1,28,28))
preds1 = apply_conv1(weights1[1], jax.random.uniform(rng, (50,1,28,28)))
init_conv2, apply_conv2 = stax.GeneralConv(("NCHW", "IOHW", "NCHW"),32, (3,3), padding="SAME")
weights2 = init_conv2(rng, preds1.shape)
preds2 = apply_conv2(weights2[1], jax.random.uniform(rng, preds1.shape))
print("Weights of First Conv Layer : {}".format(weights1[1][0].shape))
print("Weights of Second Conv Layer : {}".format(weights2[1][0].shape))
print("\nInput Shape : {}".format((50,1,28,28)))
print("Conv Layer 1 Output Shape : {}".format(preds1.shape))
print("Conv Layer 2 Output Shape : {}".format(preds2.shape))
This ends our small tutorial explaining how we can create simple convolutional neural networks using JAX. Please feel free to let us know your views in the comments section.
If you are more comfortable learning through video tutorials then we would recommend that you subscribe to our YouTube channel.
When going through coding examples, it's quite common to have doubts and errors.
If you have doubts about some code examples or are stuck somewhere when trying our code, send us an email at coderzcolumn07@gmail.com. We'll help you or point you in the direction where you can find a solution to your problem.
You can even send us a mail if you are trying something new and need guidance regarding coding. We'll try to respond as soon as possible.
If you want to