Updated On : Dec-13,2021 Time Investment : ~45 mins

Guide to Create Neural Networks using High-level JAX API

JAX is one of the most commonly used frameworks for deep learning research nowadays. It implements most of the NumPy API, and that code can run on accelerators like GPUs and TPUs. Apart from that, it provides automatic differentiation (gradient calculation), which eases the process of designing neural networks, and just-in-time compilation with XLA (Accelerated Linear Algebra) to speed up Python functions even further. JAX has a very flexible API which lets us design neural networks either with the lower-level API (numpy-like functions) or with the higher-level API available through modules like stax and optimizers. We have already covered a separate tutorial where we explain how to create a neural network using the lower-level JAX API. Please feel free to check it if you are interested.

If you want to learn about JAX from the basics then please feel free to check our tutorial which covers basics with different examples.
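As a quick illustration of the three features mentioned above (the NumPy-like API, automatic differentiation, and just-in-time compilation with XLA), below is a minimal sketch. The function f is a hypothetical example and is not used elsewhere in this tutorial.

```python
import jax
import jax.numpy as jnp

def f(x):
    # NumPy-style code written with jax.numpy
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0, 3.0])

# jax.grad returns a new function computing df/dx; for sum(x^2) that is 2 * x
df = jax.grad(f)
print(df(x))      # [2. 4. 6.]

# jax.jit compiles f with XLA; the result is the same, only execution changes
f_fast = jax.jit(f)
print(f_fast(x))  # 14.0
```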

As a part of this tutorial, we'll be concentrating on the high-level API available from JAX to create neural networks. We'll be using small toy datasets available from scikit-learn to keep the examples easy to understand. Our main aim with this tutorial is to get readers started creating neural networks with the high-level JAX API. We won't be covering how optimizers work, how weights are calculated, etc. We expect that the reader has a background in neural networks and knows about things like optimizers, loss functions, the gradient descent algorithm, etc.

Below we have highlighted important sections of the tutorial to give an overview of the material covered in it.

Important Sections of Tutorial

  1. Regression
    • Load Dataset
    • Normalize Data
    • Create Neural Network
    • Define Loss Function
    • Train Neural Network
    • Make Predictions
    • Evaluate Model Performance
    • Train Model in Batches of Data
    • Make Predictions in Batches
    • Evaluate Model Performance
  2. Classification

Below we have imported JAX and printed the version that we'll be using in this tutorial. We have also imported the high-level sub-modules stax and optimizers from the example_libraries module of JAX, which we'll use to create and train neural networks. We have also imported the jax.numpy module, as we'll need it to convert input data to JAX arrays and for a few other calculations.

import jax

print("JAX Version : {}".format(jax.__version__))
JAX Version : 0.2.26
from jax.example_libraries import stax, optimizers

import jax.numpy as jnp

1. Regression

In this section, we'll explain how we can create neural networks using higher-level JAX API available from stax and optimizers submodules to solve regression tasks. We'll be using the Boston housing dataset available from scikit-learn.

Load Dataset

In this section, we have first loaded the Boston housing dataset available from scikit-learn. We have loaded the data features into variable X and the target values into variable Y. The target variable is the median house price in $1000s, and the features are various attributes of the house and its neighborhood. We have then split the dataset into train (80%) and test (20%) sets. After dividing the dataset, we have converted each numpy array to a JAX array using the jax.numpy.array() constructor. We have also printed the shapes of the train and test datasets at the end. Note that load_boston() was deprecated in scikit-learn 1.0 and removed in 1.2, so this cell needs an older scikit-learn version to run.

from sklearn import datasets
from sklearn.model_selection import train_test_split
from jax import numpy as jnp

X, Y = datasets.load_boston(return_X_y=True)

X_train, X_test, Y_train, Y_test = train_test_split(X, Y, train_size=0.8, random_state=123)

X_train, X_test, Y_train, Y_test = jnp.array(X_train, dtype=jnp.float32),\
                                   jnp.array(X_test, dtype=jnp.float32),\
                                   jnp.array(Y_train, dtype=jnp.float32),\
                                   jnp.array(Y_test, dtype=jnp.float32)

samples, features = X_train.shape

X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
((404, 13), (102, 13), (404,), (102,))

Normalize Data

In this section, we have normalized our dataset. To do so, we have first calculated the mean and standard deviation of each feature on the training dataset. We have then subtracted that mean from both the train and test sets and divided the result by the standard deviation. The main reason for normalization is to bring the values of each feature onto roughly the same scale, which helps the gradient descent optimization algorithm converge faster. If features are on very different scales, training can take longer because gradient descent has a hard time converging.

mean = X_train.mean(axis=0)
std = X_train.std(axis=0)

X_train = (X_train - mean) / std
X_test = (X_test - mean) / std

Create Neural Network

In this section, we have created a neural network that we'll be using for our regression task.

The stax module of JAX provides various ready-made layers that we can stack together to create a neural network. We'll be creating a fully connected neural network using the stax API. The process of creating a neural network with the stax module is almost the same as creating one with the Sequential() API of Keras, so readers with a Keras background should find it easy to follow.

The stax module provides a method named serial() that accepts layers and activation functions as input and creates a neural network from them. During the forward pass, it applies the layers in the order in which they were given.

We can create dense (fully connected) layers using the Dense() method. It takes the number of units in that layer as input. We can also provide weight and bias initialization functions through its W_init and b_init arguments if we don't want the default initialization (Glorot normal for weights and a normal distribution for biases).
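As a minimal sketch of passing our own initializers, the snippet below gives Dense() the W_init and b_init keyword arguments. The particular choices (Glorot normal weights, zero biases) and the input size of 13 are only illustrative assumptions; other initializers from jax.nn.initializers that accept a random key and a shape could be used instead.

```python
import jax
from jax.example_libraries import stax
from jax.nn import initializers

# Dense layer with 5 units and explicitly chosen (illustrative) initializers
dense_init, dense_apply = stax.Dense(
    5,
    W_init=initializers.glorot_normal(),  # weights drawn from a Glorot normal distribution
    b_init=initializers.zeros,            # biases start at exactly zero
)

# 13 input features, as in the Boston housing data
out_shape, (W, b) = dense_init(jax.random.PRNGKey(0), (13,))
print(out_shape, W.shape, b.shape)  # (5,) (13, 5) (5,)
```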

Most stax functions (all layers such as Dense and Conv, as well as the serial() method) return a pair of callables when executed.

  1. init_fun - This function takes a random key (the seed for weight initialization) and the input shape of the layer/network. It returns a tuple of two values: the output shape and the parameters. For a single layer, the parameters are its weights and biases; for a neural network, they are a list holding the weights and biases of each layer.
  2. apply_fun - This function takes the parameters (weights and biases) of the layer/network and input data, and runs the layer/network on the data using those parameters. In other words, it performs the forward pass.
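To see what init_fun actually returns, here is a small sketch with a hypothetical two-layer network. The first returned value is the output shape and the second is the list of per-layer parameters, which is why the code later in this tutorial keeps only weights[1]. Activation layers such as Relu have no parameters, so their entry is an empty tuple.

```python
import jax
from jax.example_libraries import stax

# hypothetical small network: Dense -> Relu -> Dense
net_init, net_apply = stax.serial(stax.Dense(5), stax.Relu, stax.Dense(1))

output_shape, params = net_init(jax.random.PRNGKey(0), (13,))

print(output_shape)                # (1,)
print([len(p) for p in params])    # [2, 0, 2]: (W, b) for each Dense, () for Relu
```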

All activation functions are available as simple attributes of the stax module, and we don't need to call them with brackets, because each one is already an (init_fun, apply_fun) pair. We can pass them to serial() after layers, and they will be applied to the outputs of those layers.

Below we have first created a single Dense() layer with 5 units to show its output. We can notice that it returns the two callables described above.

We have then created our neural network with layer sizes [5, 10, 15, 1]. The last layer is the output layer and all other layers are hidden layers. Each hidden Dense() layer is followed by the Relu (Rectified Linear Unit) activation function. All layers and activation functions are passed to serial() in sequence, separated by commas. The Relu function takes an array as input and returns a new array of the same shape in which all values less than 0 are replaced by 0.
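The reason activation functions are passed without brackets can be checked directly: stax.Relu is already an (init_fun, apply_fun) pair. The sketch below, using a made-up input array, unpacks it and applies it.

```python
import jax.numpy as jnp
from jax.example_libraries import stax

# stax.Relu is already an (init_fun, apply_fun) pair, so it is used without brackets
relu_init, relu_apply = stax.Relu

x = jnp.array([[-2.0, 0.5, 3.0]])
# Relu ignores the random key (None here) and has no parameters to initialize
out_shape, relu_params = relu_init(None, x.shape)

print(relu_params)                 # ()
print(relu_apply(relu_params, x))  # negative entries become 0, others are unchanged
```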

As we have said earlier, the serial() method returns two callables, which we have stored in separate variables because we'll be using them later.

stax.Dense(5)
(<function jax.example_libraries.stax.Dense.<locals>.init_fun(rng, input_shape)>,
 <function jax.example_libraries.stax.Dense.<locals>.apply_fun(params, inputs, **kwargs)>)
neural_net_init, neural_net_apply = stax.serial(
                                                  stax.Dense(5),
                                                  stax.Relu,
                                                  stax.Dense(10),
                                                  stax.Relu,
                                                  stax.Dense(15),
                                                  stax.Relu,
                                                  stax.Dense(1),
                                                )
neural_net_init, neural_net_apply
(<function jax.example_libraries.stax.serial.<locals>.init_fun(rng, input_shape)>,
 <function jax.example_libraries.stax.serial.<locals>.apply_fun(params, inputs, **kwargs)>)

Now, we have initialized the weights of our neural network by calling the init_fun() function (neural_net_init). We have given it a PRNG key (jax.random.PRNGKey(123)) as the seed and the shape of the input data. It uses them to initialize the weights and biases of each layer of the neural network.

After initializing weights, we have also printed the shape of weights and biases for each layer. This can help us verify that all things went as expected.

rng = jax.random.PRNGKey(123)

weights = neural_net_init(rng, (features,))

weights = weights[1] ## Weights are actually stored in second element of two value tuple

for w in weights:
    if w:
        w, b = w
        print("Weights : {}, Biases : {}".format(w.shape, b.shape))
Weights : (13, 5), Biases : (5,)
Weights : (5, 10), Biases : (10,)
Weights : (10, 15), Biases : (15,)
Weights : (15, 1), Biases : (1,)

In the below cell, we have performed a forward pass through our neural network. We have taken a few samples of our data and given them to the apply_fun() function along with the weights. The weights are given first, followed by the small batch of data. apply_fun() performs one forward pass and returns the predictions.

preds = neural_net_apply(weights, X_train[:5])

preds
DeviceArray([[-2.0367734 ],
             [-0.2672972 ],
             [-0.19287597],
             [-0.5991536 ],
             [-1.5175911 ]], dtype=float32)
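apply_fun simply chains the computations of each layer, and a Dense layer computes inputs @ W + b. As a sanity check (a sketch with a smaller hypothetical network and random inputs), the forward pass can be reproduced by hand and compared with apply_fun:

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

# hypothetical network and random inputs, smaller than the tutorial's network
net_init, net_apply = stax.serial(stax.Dense(5), stax.Relu, stax.Dense(1))
_, params = net_init(jax.random.PRNGKey(0), (13,))
x = jax.random.normal(jax.random.PRNGKey(1), (4, 13))  # 4 samples, 13 features

# unpack per-layer parameters; the Relu entry is an empty tuple
(W1, b1), _, (W2, b2) = params

# Dense -> Relu -> Dense written out by hand
manual = jax.nn.relu(x @ W1 + b1) @ W2 + b2

print(manual.shape)                                           # (4, 1)
print(jnp.allclose(net_apply(params, x), manual, atol=1e-5))  # True
```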

Define Loss Function

In this section, we have created a loss function for our neural network. We'll calculate the gradient of the loss function with respect to the weights and then update the weights using those gradients.

We'll be using mean squared error (MSE) as our loss. It subtracts the predictions from the actual values, squares the differences, and then takes their mean:

$$\mathrm{MSE}(actual, preds) = \frac{1}{n}\sum_{i=1}^{n}\left(actual_i - preds_i\right)^2$$

Our loss function takes weights, data, and actual target values as input. It performs a forward pass through the neural network by giving the weights and data to the apply_fun() function and stores the predictions in a variable. The predictions have shape (n, 1), so we squeeze them to shape (n,) first; otherwise, subtracting them from the (n,) target array would broadcast to an (n, n) matrix. We then calculate the MSE from the actual target values and the predictions.

def MeanSquaredErrorLoss(weights, input_data, actual):
    preds = neural_net_apply(weights, input_data)
    preds = preds.squeeze()
    return jnp.power(actual - preds, 2).mean()
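To make the MSE formula concrete, here is a small worked example with made-up values, using the same expression as MeanSquaredErrorLoss:

```python
import jax.numpy as jnp

actual = jnp.array([3.0, 5.0, 2.0])   # made-up target values
preds = jnp.array([2.5, 5.0, 4.0])    # made-up predictions

# squared errors are 0.25, 0.0 and 4.0, so MSE = 4.25 / 3
mse = jnp.power(actual - preds, 2).mean()
print(round(float(mse), 4))  # 1.4167
```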

Train Neural Network

In this section, we have included code to train our neural network. We have created a small function that we'll call to train the network. It takes data features, target values, the number of epochs, and an optimizer state as input. The optimizer state is an object created by the optimizer that holds the weights of our model. We'll explain the optimizer in more detail further below, when we initialize it.

Our function loops for the given number of epochs. In each iteration, it first calculates the loss value and the gradients using the value_and_grad() function. value_and_grad() takes a function as input (our MSE loss function here) and returns another callable which, when called, returns both the value of the function and its gradient with respect to the first parameter, which is the weights in our case. We call that returned function with the weights, data features, and target values, which are the three inputs of our loss function. The call returns the MSE value and the gradients for the weights and biases of each layer of the neural network.

We have then called the optimizer's update function, which takes the step number (the epoch number here, since we make one update per epoch), the gradients, and the current optimizer state holding the current weights. It returns a new optimizer state whose weights have been updated by subtracting the gradients scaled by the learning rate.

At the end, we return the last optimizer state (the final updated weights). We also print the MSE every 100 epochs to keep track of training progress.
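The behavior of value_and_grad() described above can be checked on a hypothetical one-parameter model, prediction = w * x, where the loss and its derivative are easy to compute by hand. At w = 1 the residuals are [1, 2], so the loss is (1 + 4) / 2 = 2.5 and the derivative with respect to w is mean(-2 * x * residual) = (-2 - 8) / 2 = -5.

```python
import jax.numpy as jnp
from jax import value_and_grad

def loss(w, x, y):
    # hypothetical one-parameter model: prediction = w * x
    return jnp.mean((y - w * x) ** 2)

x = jnp.array([1.0, 2.0])
y = jnp.array([2.0, 4.0])

# value_and_grad differentiates with respect to the first argument (w) by default
value, grad_w = value_and_grad(loss)(1.0, x, y)
print(value, grad_w)  # 2.5 -5.0
```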

from jax import grad, value_and_grad

def TrainModel(X, Y, epochs, opt_state):

    for i in range(1,epochs+1):
        loss, gradients = value_and_grad(MeanSquaredErrorLoss)(opt_get_weights(opt_state), X, Y)

        ## Update Weights
        opt_state = opt_update(i, gradients, opt_state)

        if i%100 ==0: ## Print MSE every 100 epochs
            print("MSE : {:.2f}".format(loss))

    return opt_state

Here, we are actually training our neural network by calling the function we designed in the previous cell.

First, we have initialized the weights of our neural network by calling init_fun function from earlier.

We have then initialized an optimizer for our neural network. The optimizer is the algorithm responsible for finding the minimum of our loss function. The optimizers module available from the example_libraries module of JAX provides several different optimizers. We'll be using the sgd() (stochastic gradient descent) optimizer; since each update here uses the whole training set, it performs plain gradient descent. We have initialized it with a learning rate of 0.001. It returns three callables that are needed for maintaining and updating the weights of the neural network, which we unpack as opt_init, opt_update, and opt_get_weights in the code below.

  1. init_fn - This function takes the weights of a neural network as input and returns an OptimizerState object, which is a wrapper for holding and updating weights.
  2. update_fn - This function takes the step number, the gradients, and the optimizer state as input. It updates the weights held in the optimizer state by subtracting the learning rate times the gradients, and returns a new OptimizerState object with the updated weights.
  3. params_fn - This function takes an OptimizerState object as input and returns the actual weights of the neural network.
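The three optimizer functions can be tried on a tiny hand-made example. The sketch below assumes a learning rate of 0.1 and made-up parameters and gradients stored in a dictionary (the optimizers work on any nested structure of arrays), and checks that one sgd() step computes params - learning_rate * gradients.

```python
import jax.numpy as jnp
from jax.example_libraries import optimizers

opt_init, opt_update, get_params = optimizers.sgd(0.1)  # learning rate 0.1 (illustrative)

params = {"w": jnp.array([1.0, -2.0])}  # made-up parameters
grads = {"w": jnp.array([0.5, 0.5])}    # made-up gradients

state = opt_init(params)              # wrap the parameters in an OptimizerState
state = opt_update(0, grads, state)   # one step: params - 0.1 * grads
print(get_params(state)["w"])         # approximately [0.95, -2.05]
```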

After initializing the optimizer with the weights, we have called our training routine with the data, target values, number of epochs, and optimizer state (weights). We train the network for 2500 epochs. We can notice from the MSE printed every 100 epochs that the model is getting better at the task.

seed = jax.random.PRNGKey(123)
learning_rate = jnp.array(1/1e3)
epochs = 2500

weights = neural_net_init(seed, (features,))
weights = weights[1] ## Keep the parameters, drop the output shape


opt_init, opt_update, opt_get_weights = optimizers.sgd(learning_rate)
opt_state = opt_init(weights)

final_opt_state = TrainModel(X_train, Y_train, epochs, opt_state)
MSE : 21.85
MSE : 12.79
MSE : 10.86
MSE : 10.09
MSE : 9.61
MSE : 9.29
MSE : 9.08
MSE : 8.93
MSE : 8.80
MSE : 8.67
MSE : 8.57
MSE : 8.48
MSE : 8.41
MSE : 8.34
MSE : 8.27
MSE : 8.21
MSE : 8.15
MSE : 8.08
MSE : 8.02
MSE : 7.95
MSE : 7.89
MSE : 7.84
MSE : 7.79
MSE : 7.75
MSE : 7.72
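TrainModel() above calls value_and_grad() eagerly in every epoch, so nothing is compiled. A possible variation, sketched below under the assumption of a small hypothetical network and toy data, wraps one training step in jax.jit so that XLA compiles it once and reuses it on later iterations. The OptimizerState object can be passed through the jitted function because the optimizers module registers it as a pytree.

```python
import jax
import jax.numpy as jnp
from jax import jit, value_and_grad
from jax.example_libraries import stax, optimizers

# hypothetical small network and toy regression data (y = sum of the 3 features)
net_init, net_apply = stax.serial(stax.Dense(8), stax.Relu, stax.Dense(1))
x = jax.random.normal(jax.random.PRNGKey(0), (64, 3))
y = x.sum(axis=1)

def mse_loss(params, x, y):
    preds = net_apply(params, x).squeeze()
    return jnp.mean((y - preds) ** 2)

opt_init, opt_update, get_params = optimizers.sgd(1e-2)

@jit
def train_step(i, opt_state, x, y):
    # one full gradient step, compiled by XLA on the first call and reused afterwards
    loss, grads = value_and_grad(mse_loss)(get_params(opt_state), x, y)
    return opt_update(i, grads, opt_state), loss

_, params = net_init(jax.random.PRNGKey(1), (3,))
opt_state = opt_init(params)

losses = []
for i in range(200):
    opt_state, loss = train_step(i, opt_state, x, y)
    losses.append(float(loss))

print(losses[0], losses[-1])  # the loss should go down over the 200 steps
```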

Make Predictions

In this section, we have made predictions for both the train and test datasets using our trained neural network.

We have retrieved the weights of the network using the params_fn optimizer function (opt_get_weights) and then given the weights and data features to the apply_fun function (neural_net_apply), which makes the predictions.

test_preds = neural_net_apply(opt_get_weights(final_opt_state), X_test) ## Make Predictions on test dataset

test_preds = test_preds.ravel()

train_preds = neural_net_apply(opt_get_weights(final_opt_state), X_train) ## Make Predictions on train dataset

train_preds = train_preds.ravel()

test_preds[:5], train_preds[:5]
(DeviceArray([16.238165, 26.803806, 43.55032 , 20.445574, 28.37116 ], dtype=float32),
 DeviceArray([47.88758 , 11.842489, 20.448275, 26.680183, 14.941553], dtype=float32))

Evaluate Model Performance

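In this section, we are evaluating the performance of our regression model by calculating the R^2 score for both the train and test predictions with the r2_score() method of scikit-learn. The best possible R^2 score is 1; a score of 0 means the model does no better than always predicting the mean, and the score can be negative for models that do worse than that. We can notice from the R^2 scores that our model seems to be doing a good job. Note that r2_score() expects the true values as its first argument and the predictions as its second; the calls below pass them in the reverse order, and since R^2 is not symmetric, the printed scores are not exactly the standard R^2 of the predictions.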

If you are interested in learning in detail about R^2 score and other metrics available from scikit-learn for different kinds of tasks then please feel free to check our tutorial which covers the majority of metrics in detail with examples.

from sklearn.metrics import r2_score

print("Train R^2 Score : {:.2f}".format(r2_score(train_preds.to_py(), Y_train.to_py())))
print("Test  R^2 Score : {:.2f}".format(r2_score(test_preds.to_py(), Y_test.to_py())))
Train R^2 Score : 0.90
Test  R^2 Score : 0.75
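To see how R^2 is computed, and why the argument order of r2_score() matters, here is a small sketch with made-up values. It computes R^2 = 1 - SS_res / SS_tot by hand and compares it with scikit-learn, then shows that swapping the arguments changes the result and that a poor model can get a negative score.

```python
import numpy as np
from sklearn.metrics import r2_score

y_true = np.array([3.0, 5.0, 7.0, 9.0])   # made-up targets
y_pred = np.array([2.5, 5.5, 6.0, 9.5])   # made-up predictions

# R^2 = 1 - SS_res / SS_tot, with SS_tot measured around the mean of y_true
ss_res = np.sum((y_true - y_pred) ** 2)         # 1.75
ss_tot = np.sum((y_true - y_true.mean()) ** 2)  # 20.0
manual = 1 - ss_res / ss_tot                    # 0.9125

print(np.isclose(manual, r2_score(y_true, y_pred)))  # True
print(r2_score(y_pred, y_true))  # swapped: uses the spread of y_pred instead, about 0.929
print(r2_score(y_true, np.zeros_like(y_true)))       # -7.2: worse than predicting the mean
```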

Train Model in Batches of Data

In real life, datasets are sometimes too big to fit into the main memory of the computer. In those situations, we bring only a small batch of data into memory at a time and train the model batch by batch until the whole dataset is covered. The optimization algorithm used in this case is referred to as (mini-batch) stochastic gradient descent, as it works on a small batch of data at a time.

In this section, we have explained how to modify our code to perform training on batches of data. Below, we have declared a training function like last time. It takes data features, target values, the number of epochs, the optimizer state (weights), and the batch size (default 32) as input. For each epoch, it calculates the start and end indexes of each batch, and then performs a forward pass, calculates the loss, and updates the weights on one batch at a time until the whole dataset is covered. When the whole dataset fits in main memory, we update the weights only once per epoch; with training in batches, we update the weights once per batch, i.e. several times per epoch.

def TrainModelInBatches(X, Y, epochs, opt_state, batch_size=32):
    for i in range(1, epochs+1):
        batches = jnp.arange((X.shape[0]//batch_size)+1) ### Batch Indices

        losses = [] ## Record loss of each batch
        for batch in batches:
            if batch != batches[-1]:
                start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
            else:
                start, end = int(batch*batch_size), None

            X_batch, Y_batch = X[start:end], Y[start:end] ## Single batch of data

            loss, gradients = value_and_grad(MeanSquaredErrorLoss)(opt_get_weights(opt_state), X_batch, Y_batch)

            ## Update Weights
            opt_state = opt_update(i, gradients, opt_state)

            losses.append(loss) ## Record Loss

        if i % 100 == 0: ## Print MSE every 100 epochs
            print("MSE : {:.2f}".format(jnp.array(losses).mean()))

    return opt_state

Now, we have trained our neural network in batches using the function we designed in the previous cell. We have first initialized the weights of the neural network using init_fun by giving it a seed and the input shape. Then we have initialized our optimizer by calling the sgd() function with a learning rate of 0.001 and created the first optimizer state with the weights. Finally, we have called our function from the previous cell to train the neural network in batches for 500 epochs.

seed = jax.random.PRNGKey(123)
learning_rate = jnp.array(1/1e3)
epochs = 500

weights = neural_net_init(seed, (features,))
weights = weights[1] ## Keep the parameters, drop the output shape


opt_init, opt_update, opt_get_weights = optimizers.sgd(learning_rate)
opt_state = opt_init(weights)

final_opt_state = TrainModelInBatches(X_train, Y_train, epochs, opt_state)
MSE : 8.84
MSE : 8.12
MSE : 7.66
MSE : 7.20
MSE : 6.69
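TrainModelInBatches() computes X.shape[0]//batch_size + 1 batches. That works for the 404 training samples here, but when the number of samples is an exact multiple of batch_size, the last batch is empty, and the mean loss over an empty batch is NaN. A sketch of an alternative, using a hypothetical helper named batch_slices() built on Python's range(), avoids that case:

```python
def batch_slices(n_samples, batch_size=32):
    """Hypothetical helper: yield (start, end) pairs that cover range(n_samples) in order."""
    for start in range(0, n_samples, batch_size):
        yield start, min(start + batch_size, n_samples)

print(len(list(batch_slices(404))))   # 13 batches, same as 404 // 32 + 1
print(list(batch_slices(404))[-1])    # (384, 404): the last batch is smaller
print(len(list(batch_slices(64))))    # 2: no empty trailing batch when 64 % 32 == 0

# usage inside a training loop:
# for start, end in batch_slices(X.shape[0], batch_size):
#     X_batch, Y_batch = X[start:end], Y[start:end]
```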

Make Predictions in Batches

When the data does not fit into main memory, we need to make predictions in batches as well. In this section, we have designed a function that takes weights and data as input and makes predictions on the data in batches. It calculates the batch indexes with almost the same logic as the training function.

def MakePredictions(weights, input_data, batch_size=32):
    batches = jnp.arange((input_data.shape[0]//batch_size)+1) ### Batch Indices

    preds = []
    for batch in batches:
        if batch != batches[-1]:
            start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
        else:
            start, end = int(batch*batch_size), None

        X_batch = input_data[start:end]

        preds.append(neural_net_apply(weights, X_batch))

    return preds

Below we have used our function from the previous cell to make predictions on the train and test datasets in batches, and then combined the predictions of all batches.

test_preds = MakePredictions(opt_get_weights(final_opt_state), X_test)

test_preds = jnp.concatenate(test_preds).squeeze() ## Combine predictions of all batches

train_preds = MakePredictions(opt_get_weights(final_opt_state), X_train)

train_preds = jnp.concatenate(train_preds).squeeze() ## Combine predictions of all batches

test_preds[:5], train_preds[:5]
(DeviceArray([25.908619, 26.660995, 44.688435, 19.98889 , 28.745905], dtype=float32),
 DeviceArray([48.370625 , 12.734019 , 20.412106 , 26.055237 , 13.6108055],            dtype=float32))

Evaluate Model Performance

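In this section, we have evaluated the performance of our neural network by calculating the R^2 score on train and test predictions. Note that r2_score() expects the true values as its first argument; the calls below pass the predictions first, and since R^2 is not symmetric, the printed scores are not exactly the standard R^2 of the predictions.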

from sklearn.metrics import r2_score

print("Test  R^2 Score : {:.2f}".format(r2_score(test_preds, Y_test)))
print("Train R^2 Score : {:.2f}".format(r2_score(train_preds, Y_train)))
Test  R^2 Score : 0.76
Train R^2 Score : 0.92

2. Classification

In this section, we'll explain how to create a neural network to solve classification tasks using the high-level API available from JAX. We'll be using a small toy dataset from scikit-learn for the explanation. We'll also reuse much of the code from the previous regression section, so we won't describe repeated code in detail here. Please refer to the regression section if you want a detailed explanation of a step that is not explained here.

Load Dataset

In this section, we have loaded the breast cancer dataset available from scikit-learn. The data features are various measurements of tumors. The target value is either 0 (malignant tumor) or 1 (benign tumor). As the output has only two classes, this is a binary classification task. We have then divided the dataset into train (80%) and test (20%) sets, stratified by class so that both sets have the same class proportions. The code for this section is almost the same as the data-loading code from the regression section.

from sklearn import datasets
from sklearn.model_selection import train_test_split

X, Y = datasets.load_breast_cancer(return_X_y=True)

X_train, X_test, Y_train, Y_test = train_test_split(X, Y, train_size=0.8, stratify=Y, random_state=123)

X_train, X_test, Y_train, Y_test = jnp.array(X_train, dtype=jnp.float32),\
                                   jnp.array(X_test, dtype=jnp.float32),\
                                   jnp.array(Y_train, dtype=jnp.float32),\
                                   jnp.array(Y_test, dtype=jnp.float32)

samples, features = X_train.shape
classes = jnp.unique(Y_test)

X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
((455, 30), (114, 30), (455,), (114,))
samples, features, classes
(455, 30, DeviceArray([0., 1.], dtype=float32))

Normalize Data

In this section, we have normalized our train and test datasets using the mean and standard deviation of each feature in the training dataset.

mean = X_train.mean(axis=0)
std = X_train.std(axis=0)

X_train = (X_train - mean) / std
X_test = (X_test - mean) / std

Create Neural Network

In this section, we have created the neural network that we'll use for classification. It is almost exactly the same as the network from the regression section, with one change: we have applied the sigmoid function as the activation of the last layer. The sigmoid function returns outputs in the range (0, 1), which we interpret as the probability of class 1.

Then in the next two cells, we have initialized the weights of the neural network, printed their shapes, and made predictions on the first few training samples to verify that our functions work as expected.

neural_net_init, neural_net_apply = stax.serial(
                                                  stax.Dense(5),
                                                  stax.Relu,
                                                  stax.Dense(10),
                                                  stax.Relu,
                                                  stax.Dense(15),
                                                  stax.Relu,
                                                  stax.Dense(1),
                                                  stax.Sigmoid
                                                )
rng = jax.random.PRNGKey(123)

weights = neural_net_init(rng, (features,))

weights = weights[1] ## Weights are actually stored in second element of two value tuple

for w in weights:
    if w:
        w, b = w
        print("Weights : {}, Biases : {}".format(w.shape, b.shape))
Weights : (30, 5), Biases : (5,)
Weights : (5, 10), Biases : (10,)
Weights : (10, 15), Biases : (15,)
Weights : (15, 1), Biases : (1,)
preds = neural_net_apply(weights, X_train[:5])

preds
DeviceArray([[0.41089782],
             [0.37661743],
             [0.46994004],
             [0.40898827],
             [0.45434952]], dtype=float32)
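The predictions above all lie between 0 and 1 because of the final stax.Sigmoid layer. Like Relu, stax.Sigmoid is an (init_fun, apply_fun) pair; the sketch below applies it to a few made-up values to show that sigmoid(0) = 0.5 and that large positive or negative inputs are pushed towards 1 or 0.

```python
import jax.numpy as jnp
from jax.example_libraries import stax

sig_init, sig_apply = stax.Sigmoid
_, sig_params = sig_init(None, (3,))   # no parameters, the random key is ignored

z = jnp.array([-10.0, 0.0, 10.0])      # made-up pre-activation values
probs = sig_apply(sig_params, z)       # 1 / (1 + exp(-z)) elementwise
print(probs)                           # close to 0, exactly 0.5, close to 1
```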

Define Loss Function

In this section, we have defined the loss function for our binary classification task: log loss (binary cross-entropy).

$$\mathrm{log\_loss}(preds, actual) = \frac{1}{n}\sum_{i=1}^{n}\left(-actual_i \log(preds_i) - (1 - actual_i)\log(1 - preds_i)\right)$$

The loss function takes weights, input features data, and actual target values as input. It then makes predictions using weights and input data. Then, it calculates log loss values based on predictions and actual target values.

def NegLogLoss(weights, input_data, actual):
    preds = neural_net_apply(weights, input_data)
    preds = preds.squeeze()
    return (- actual * jnp.log(preds) - (1 - actual) * jnp.log(1 - preds)).mean()
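NegLogLoss() takes the log of the sigmoid outputs, so if a prediction saturates to exactly 0 or 1 in float32, one of the log terms becomes log(0) and the loss becomes infinite. A common alternative, which this tutorial does not use, is to remove stax.Sigmoid from the network and compute the same loss from the raw outputs (logits) with jax.nn.log_sigmoid, using log(sigmoid(z)) = log_sigmoid(z) and log(1 - sigmoid(z)) = log_sigmoid(-z). The sketch below compares both forms on made-up values.

```python
import jax
import jax.numpy as jnp

def log_loss_from_probs(probs, actual):
    # same formula as NegLogLoss, applied to probabilities
    return (-actual * jnp.log(probs) - (1 - actual) * jnp.log(1 - probs)).mean()

def log_loss_from_logits(logits, actual):
    # same quantity computed from logits without taking log of a saturated probability
    return (-actual * jax.nn.log_sigmoid(logits)
            - (1 - actual) * jax.nn.log_sigmoid(-logits)).mean()

logits = jnp.array([-2.0, 0.5, 3.0])   # made-up network outputs
actual = jnp.array([0.0, 1.0, 1.0])    # made-up labels

# both forms agree for moderate logits
print(jnp.allclose(log_loss_from_probs(jax.nn.sigmoid(logits), actual),
                   log_loss_from_logits(logits, actual)))   # True

big = jnp.array([40.0])   # sigmoid(40) rounds to exactly 1.0 in float32
print(log_loss_from_probs(jax.nn.sigmoid(big), jnp.array([0.0])))   # inf
print(log_loss_from_logits(big, jnp.array([0.0])))                  # about 40.0
```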

Train Model

In this section, we'll be training our neural network for the binary classification task. We have simply copied code to train the neural network from the regression section. The code is almost exactly the same as the regression section with only a change in the loss function.

from jax import grad, value_and_grad

def TrainModel(X, Y, epochs, opt_state):

    for i in range(1,epochs+1):
        loss, gradients = value_and_grad(NegLogLoss)(opt_get_weights(opt_state), X, Y)

        ## Update Weights
        opt_state = opt_update(i, gradients, opt_state)

        if i%100 ==0: ## Print log loss every 100 epochs
            print("NegLogLoss : {:.2f}".format(loss))

    return opt_state

Now, we are training our neural network by calling the training function we designed in the previous cell. We have first initialized the seed for random numbers, the learning rate (0.0001), and the number of epochs (1500). We have then initialized the weights of the neural network using the init_fun function, providing the seed and input shape as usual.

Then we have initialized the optimizer that we'll use for optimizing the weights of the neural network. This time we are using the rmsprop() optimizer. RMSProp (Root Mean Square Propagation) is a refinement of the SGD optimizer that we used in the regression section: it keeps a running average of squared gradients and divides each update by its square root, which adapts the effective learning rate of each weight over time and often speeds up convergence. We have then called the init function of the optimizer with the network weights to create the initial optimizer state, and called the training routine with the train data, train target values, number of epochs, and optimizer state (weights). We can notice from the log loss printed every 100 epochs that our neural network seems to be doing a good job.

seed = jax.random.PRNGKey(123)
learning_rate = jnp.array(1/1e4)
epochs = 1500

weights = neural_net_init(seed, (features,))
weights = weights[1] ## Keep the parameters, drop the output shape


opt_init, opt_update, opt_get_weights = optimizers.rmsprop(learning_rate)
opt_state = opt_init(weights)

final_opt_state = TrainModel(X_train, Y_train, epochs, opt_state)
NegLogLoss : 0.73
NegLogLoss : 0.67
NegLogLoss : 0.63
NegLogLoss : 0.58
NegLogLoss : 0.53
NegLogLoss : 0.47
NegLogLoss : 0.41
NegLogLoss : 0.36
NegLogLoss : 0.31
NegLogLoss : 0.27
NegLogLoss : 0.23
NegLogLoss : 0.20
NegLogLoss : 0.17
NegLogLoss : 0.15
NegLogLoss : 0.13
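The RMSProp update described above can be written out for a single step. The sketch below assumes the usual form of the rule (a running average of squared gradients with decay gamma, then a step of learning_rate * gradient / sqrt(average + eps)), passes gamma and eps explicitly, and checks it against optimizers.rmsprop() on two made-up gradients of very different sizes. After the first step both weights move by about learning_rate / sqrt(0.1), which shows how RMSProp normalizes the step size per weight.

```python
import jax.numpy as jnp
from jax.example_libraries import optimizers

lr, gamma, eps = 1e-3, 0.9, 1e-8       # illustrative hyperparameters
opt_init, opt_update, get_params = optimizers.rmsprop(lr, gamma=gamma, eps=eps)

w = jnp.array([1.0, 1.0])              # made-up weights
g = jnp.array([0.01, 10.0])            # made-up gradients, 1000x apart in size

state = opt_update(0, g, opt_init(w))

# manual RMSProp step: update the running average of squared gradients, then scale the step
avg_sq = gamma * jnp.zeros_like(w) + (1 - gamma) * g ** 2
manual = w - lr * g / jnp.sqrt(avg_sq + eps)

print(jnp.allclose(get_params(state), manual))   # True
print(w - manual)   # both steps are about lr / sqrt(0.1), roughly 0.00316
```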

Make Predictions

In this section, we are making predictions on the train and test datasets using our updated weights and neural network. We call the apply_fun method of the neural network with the weights and features data to make predictions. The output of our neural network is a probability in the range (0, 1) due to the sigmoid activation function, and we need to convert these probabilities to actual class predictions. To do so, we have set a threshold of 0.5: values greater than 0.5 are predicted as class 1 (benign tumor), and values of 0.5 or less are predicted as class 0 (malignant tumor).

test_preds = neural_net_apply(opt_get_weights(final_opt_state), X_test) ## Make Predictions on test dataset

test_preds = test_preds.ravel() ## Flatten predictions from shape (n, 1) to (n,)

test_preds = (test_preds > 0.5).astype(jnp.float32)

test_preds[:5], Y_test[:5]
(DeviceArray([0., 0., 1., 1., 1.], dtype=float32),
 DeviceArray([0., 0., 1., 1., 1.], dtype=float32))
train_preds = neural_net_apply(opt_get_weights(final_opt_state), X_train) ## Make Predictions on train dataset

train_preds = train_preds.ravel()  ## Flatten predictions from shape (n, 1) to (n,)

train_preds = (train_preds > 0.5).astype(jnp.float32)

train_preds[:5], Y_train[:5]
(DeviceArray([1., 1., 0., 1., 1.], dtype=float32),
 DeviceArray([1., 1., 0., 0., 1.], dtype=float32))

Evaluate Model Performance

In this section, we have evaluated the performance of our model by calculating the accuracy of train and test predictions. We can notice from the results that our model seems to be doing a decent job at prediction. We have used accuracy_score() method available from scikit-learn to calculate accuracy.

Then in the next cell, we have printed the classification report of our test predictions, which includes information like precision, recall, and f1-score for each class.

We recommend that you go through our tutorial on ML metrics available through scikit-learn if you don't have a background in these metrics.

from sklearn.metrics import accuracy_score

print("Train Accuracy : {:.2f}".format(accuracy_score(Y_train, train_preds)))
print("Test  Accuracy : {:.2f}".format(accuracy_score(Y_test, test_preds)))
Train Accuracy : 0.97
Test  Accuracy : 0.95
from sklearn.metrics import classification_report

print("Test Classification Report ")
print(classification_report(Y_test, test_preds))
Test Classification Report
              precision    recall  f1-score   support

         0.0       0.91      0.95      0.93        42
         1.0       0.97      0.94      0.96        72

    accuracy                           0.95       114
   macro avg       0.94      0.95      0.94       114
weighted avg       0.95      0.95      0.95       114
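The precision, recall, and f1-score columns all come from counts of true/false positives and negatives. Working backwards from the rounded recalls and supports in the report above, the only counts consistent with it are 40 of 42 class-0 samples and 68 of 72 class-1 samples predicted correctly (TN=40, FP=2, FN=4, TP=68). The sketch below recomputes the report's rows from these reconstructed counts; the helper `binary_report` is hypothetical, not a scikit-learn function.

```python
def binary_report(tn, fp, fn, tp):
    """Per-class precision, recall, F1 and support for a binary confusion matrix."""
    rows = {}
    # For class 1.0 the "positives" are label-1 samples. For class 0.0 the roles
    # swap: tn acts as its true positives, fn as its false positives, fp as its false negatives.
    for label, (t_pos, f_pos, f_neg) in {0.0: (tn, fn, fp), 1.0: (tp, fp, fn)}.items():
        precision = t_pos / (t_pos + f_pos)
        recall = t_pos / (t_pos + f_neg)
        f1 = 2 * precision * recall / (precision + recall)
        rows[label] = (precision, recall, f1, t_pos + f_neg)  # last item is the support
    accuracy = (tn + tp) / (tn + fp + fn + tp)
    return rows, accuracy

rows, acc = binary_report(tn=40, fp=2, fn=4, tp=68)
for label, (p, r, f1, support) in rows.items():
    print("{:>4}  precision={:.2f}  recall={:.2f}  f1={:.2f}  support={}".format(label, p, r, f1, support))
print("accuracy={:.2f}".format(acc))
# Prints 0.91/0.95/0.93 (support 42), 0.97/0.94/0.96 (support 72) and accuracy 0.95
```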

Train Model in Batches of Data

In this section, we explain how to perform training in batches of data. We have copied the training function from the regression section; the only difference is that we use the log loss function (NegLogLoss) here. The rest of the code is the same as earlier.
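NegLogLoss, used below, was defined before this point in the tutorial. For reference, binary log loss averages -[y log(p) + (1 - y) log(1 - p)] over all samples. The NumPy sketch below is a hypothetical stand-alone version (the name `neg_log_loss` and the clipping constant `eps` are assumptions, not the tutorial's definition). It clips probabilities so log() stays finite, and flattens both inputs so an (N, 1) network output lines up with (N,) labels instead of broadcasting to an (N, N) matrix.

```python
import numpy as np

def neg_log_loss(y_true, y_prob, eps=1e-7):
    """Mean binary cross-entropy between labels in {0, 1} and predicted probabilities."""
    # ravel() turns an (N, 1) network output into shape (N,) to match the labels.
    y_true = np.asarray(y_true, dtype=np.float64).ravel()
    # Clip so that log(0) never occurs for probabilities of exactly 0 or 1.
    y_prob = np.clip(np.asarray(y_prob, dtype=np.float64).ravel(), eps, 1 - eps)
    return float(-np.mean(y_true * np.log(y_prob) + (1 - y_true) * np.log(1 - y_prob)))

# Confident correct predictions give a small loss, confident wrong ones a large loss.
print(neg_log_loss([1, 0], [0.9, 0.1]))
print(neg_log_loss([1, 0], [0.1, 0.9]))
```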

def TrainModelInBatches(X, Y, epochs, opt_state, batch_size=32):
    for i in range(1, epochs+1):
        batches = jnp.arange((X.shape[0] + batch_size - 1)//batch_size) ### Batch Indices: ceil(N/batch_size), so the last batch is never empty

        losses = [] ## Record loss of each batch
        for batch in batches:
            if batch != batches[-1]:
                start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
            else:
                start, end = int(batch*batch_size), None

            X_batch, Y_batch = X[start:end], Y[start:end] ## Single batch of data

            loss, gradients = value_and_grad(NegLogLoss)(opt_get_weights(opt_state), X_batch, Y_batch)

            ## Update Weights
            opt_state = opt_update(i, gradients, opt_state)

            losses.append(loss) ## Record Loss

        if i % 100 == 0: ## Print NegLogLoss every 100 epochs
            print("NegLogLoss : {:.2f}".format(jnp.array(losses).mean()))

    return opt_state
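To see which rows each batch covers, the boundary logic can be run on its own. The plain-Python sketch below (the helper `batch_bounds` is hypothetical) uses ceiling division for the number of batches, so when the sample count is an exact multiple of batch_size no empty trailing batch is produced; with `n // batch_size + 1` batches, the final slice would start at n and contain no rows.

```python
def batch_bounds(n_samples, batch_size=32):
    """Return (start, end) row ranges covering n_samples in batches of batch_size."""
    n_batches = (n_samples + batch_size - 1) // batch_size  # ceil(n_samples / batch_size)
    bounds = []
    for batch in range(n_batches):
        start = batch * batch_size
        end = min(start + batch_size, n_samples)  # the last batch may be shorter
        bounds.append((start, end))
    return bounds

# 455 rows: 14 full batches of 32 plus one final batch of 7 rows.
print(len(batch_bounds(455)), batch_bounds(455)[-1])
# 448 rows is an exact multiple of 32: 14 batches, none of them empty.
print(len(batch_bounds(448)), batch_bounds(448)[-1])
```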

Below, we train our neural network by calling the training function from the previous cell. We first initialize the seed for generating random numbers, the learning rate (0.0001), and the number of epochs (200). Then, we initialize the weights of the neural network by calling the neural_net_init function with the seed and the input shape; it returns the output shape and the weights, and we keep only the weights. We then initialize the RMSProp optimizer that we'll use to optimize the weights and call the training function to perform training on our train data in batches. The loss printed every 100 epochs drops from 0.22 to 0.09, so our model seems to be doing a decent job.

seed = jax.random.PRNGKey(123)
learning_rate = jnp.array(1/1e4)
epochs = 200

weights = neural_net_init(seed, (features,)) ## Returns (output_shape, weights)
weights = weights[1] ## Keep only the weights

opt_init, opt_update, opt_get_weights = optimizers.rmsprop(learning_rate)
opt_state = opt_init(weights)

final_opt_state = TrainModelInBatches(X_train, Y_train, epochs, opt_state)
NegLogLoss : 0.22
NegLogLoss : 0.09
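The optimizers module represents an optimizer as three functions: opt_init wraps the initial weights into an optimizer state, opt_update(step, gradients, state) returns a new state, and opt_get_weights(state) extracts the current weights. To show how that state threads through a training loop, here is a hypothetical plain-Python mimic of the pattern using vanilla gradient descent on a single weight. It is not RMSProp and not JAX's implementation, and it uses `toy_*` names so it does not overwrite the tutorial's opt_* functions if run in the same notebook.

```python
def sgd(step_size):
    """Hypothetical mimic of the (init, update, get_params) optimizer triple."""
    def init(weights):
        return {"weights": weights}  # plain SGD keeps only the weights in its state

    def update(step, gradient, state):
        # step is unused with a constant step size; a learning-rate schedule could use it.
        return {"weights": state["weights"] - step_size * gradient}

    def get_params(state):
        return state["weights"]

    return init, update, get_params

# Minimise f(w) = (w - 3)^2, whose gradient is 2 * (w - 3).
toy_init, toy_update, toy_get_params = sgd(0.1)
toy_state = toy_init(0.0)
for i in range(1, 101):
    w = toy_get_params(toy_state)
    toy_state = toy_update(i, 2 * (w - 3), toy_state)
print("w = {:.4f}".format(toy_get_params(toy_state)))  # prints w = 3.0000
```

The state is passed in and returned on every step rather than mutated, which is the same functional style TrainModelInBatches() follows with opt_state.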

Make Predictions in Batches

In this section, we have made predictions on the train and test datasets in batches. We have copied the prediction function that we used in the regression section. We have then combined the predictions of all batches and converted the probabilities to class labels using a threshold of 0.5.

def MakePredictions(weights, input_data, batch_size=32):
    batches = jnp.arange((input_data.shape[0] + batch_size - 1)//batch_size) ### Batch Indices: ceil(N/batch_size), so the last batch is never empty

    preds = []
    for batch in batches:
        if batch != batches[-1]:
            start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
        else:
            start, end = int(batch*batch_size), None

        X_batch = input_data[start:end]

        preds.append(neural_net_apply(weights, X_batch))

    return preds
test_preds = MakePredictions(opt_get_weights(final_opt_state), X_test)

test_preds = jnp.concatenate(test_preds).squeeze() ## Combine predictions of all batches

test_preds = (test_preds > 0.5).astype(jnp.float32)

train_preds = MakePredictions(opt_get_weights(final_opt_state), X_train)

train_preds = jnp.concatenate(train_preds).squeeze() ## Combine predictions of all batches

train_preds = (train_preds > 0.5).astype(jnp.float32)

test_preds[:5], train_preds[:5]
(DeviceArray([0., 0., 1., 1., 1.], dtype=float32),
 DeviceArray([1., 1., 0., 1., 1.], dtype=float32))
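Because the network is applied to each row independently, predicting in batches and concatenating the results should match predicting on the whole array at once. The NumPy sketch below checks this with a stand-in model: `predict_proba` is a hypothetical logistic unit with fixed random weights (not the tutorial's network), and the data is arbitrary. It also mirrors the `.squeeze()` from (N, 1) to (N,) and the 0.5 threshold used above.

```python
import numpy as np

np_rng = np.random.default_rng(0)
X_demo = np_rng.normal(size=(114, 5))  # 114 rows, 5 features (arbitrary stand-in data)
W_demo = np_rng.normal(size=(5, 1))    # stand-in weights for a single sigmoid output unit

def predict_proba(X_batch):
    """Stand-in model: sigmoid(X @ W), returning shape (batch, 1) like a Sigmoid output layer."""
    return 1.0 / (1.0 + np.exp(-(X_batch @ W_demo)))

def predict_in_batches(X, batch_size=32):
    """List of per-batch predictions; the last batch may be shorter."""
    return [predict_proba(X[start:start + batch_size]) for start in range(0, X.shape[0], batch_size)]

batched = np.concatenate(predict_in_batches(X_demo)).squeeze()  # (114, 1) -> (114,)
full = predict_proba(X_demo).squeeze()

print(batched.shape, np.allclose(batched, full))  # prints (114,) True
labels = (batched > 0.5).astype(np.float32)  # probabilities -> class labels 0. / 1.
```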

Evaluate Model Performance

In this section, we have evaluated the performance of our model by calculating the accuracy of our train and test predictions (0.98 and 0.96 respectively). We have then also printed a classification report of our test predictions.

from sklearn.metrics import accuracy_score

print("Train Accuracy : {:.2f}".format(accuracy_score(Y_train, train_preds)))
print("Test  Accuracy : {:.2f}".format(accuracy_score(Y_test, test_preds)))
Train Accuracy : 0.98
Test  Accuracy : 0.96
from sklearn.metrics import classification_report

print("Test Classification Report ")
print(classification_report(Y_test, test_preds))
Test Classification Report
              precision    recall  f1-score   support

         0.0       0.93      0.95      0.94        42
         1.0       0.97      0.96      0.97        72

    accuracy                           0.96       114
   macro avg       0.95      0.96      0.95       114
weighted avg       0.96      0.96      0.96       114

This ends our small tutorial explaining how we can use the high-level JAX API available through the 'stax' and 'optimizers' modules to create neural networks. Please feel free to let us know your views in the comments section.

Sunny Solanki
