JAX is a Python library designed specifically to speed up machine learning research. It provides features like a numpy-like API that runs on GPUs/TPUs, automatic gradient calculation, just-in-time compilation, and faster code through the XLA compiler. Due to the simplicity of its API, it has been widely adopted by researchers. We have already covered a tutorial explaining the features of JAX with simple examples. Please feel free to check it as it provides the necessary background for this tutorial.
As a part of this tutorial, we'll be using the JAX API to create simple neural networks. We'll be using the lower-level API to create networks of dense layers which we'll use to solve regression and classification tasks.
JAX also provides a high-level API for creating neural networks through the stax and optimizers modules. Using the high-level API can reduce the amount of code and simplify the process. Please feel free to check the below tutorial if you want to develop neural networks using the high-level API.
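For reference, a minimal sketch of what the high-level approach looks like is shown below. This assumes a recent JAX version where stax lives under jax.example_libraries (older releases exposed it as jax.experimental.stax); the layer sizes here are purely illustrative and not the ones used later in this tutorial.
import jax
from jax.example_libraries import stax   ## on older JAX versions: from jax.experimental import stax
## A network of two dense layers: a hidden layer with 10 units followed by a single output unit.
init_fn, apply_fn = stax.serial(stax.Dense(10), stax.Relu, stax.Dense(1))
rng = jax.random.PRNGKey(0)
_, params = init_fn(rng, (-1, 13))        ## 13 input features (illustrative value)
## preds = apply_fn(params, input_batch)  ## forward pass on a batch of data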
Below we have highlighted important sections of the tutorial to give an overview of the material covered.
Please make a NOTE that we assume the readers have a little bit of background in machine learning / deep learning and topics like activation functions, weights, biases, loss functions, gradients, etc., as we won't be covering them in much detail. The main aim of this tutorial is to get individuals started designing neural networks using JAX.
Below we have imported JAX and its submodules that we'll be using in our tutorials.
import jax
print("JAX Version : {}".format(jax.__version__))
from jax import numpy as jnp
from jax import grad, value_and_grad
import numpy as np
In this section, we'll create a simple neural network using JAX to solve a regression task. We'll be using the Boston housing toy dataset available from scikit-learn for our purpose. We'll create the small parts of the network (weight initialization, loss function, gradient calculation function, training function, etc.) individually, test them, and then connect them all. We'll implement one training loop that works on the whole data at once (as this is a toy dataset that fits into main memory) and another that works on data in batches.
In this section, we have loaded the Boston housing dataset available from scikit-learn. We have loaded data into X and Y variables, where X holds the data features and Y is the target variable that we'll predict based on the values in X. The target variable is the median house price in the Boston area, which is a continuous variable, hence this is a regression task.
After loading the dataset, we have divided it into train (80%) and test (20%) sets. We have then converted the datasets, which are currently numpy arrays, to JAX arrays using the jax.numpy module. We have also recorded the training samples count and features count in separate variables as we'll need them later.
from sklearn import datasets
from sklearn.model_selection import train_test_split
from jax import numpy as jnp
X, Y = datasets.load_boston(return_X_y=True)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, train_size=0.8, random_state=123)
X_train, X_test, Y_train, Y_test = jnp.array(X_train, dtype=jnp.float32),\
jnp.array(X_test, dtype=jnp.float32),\
jnp.array(Y_train, dtype=jnp.float32),\
jnp.array(Y_test, dtype=jnp.float32)
samples, features = X_train.shape
X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
samples, features
In this section, we have normalized our data in order to bring all features to roughly the same scale. This will help the gradient descent optimizer converge faster. If the features of the data are on widely different scales, it can be hard for the gradient descent algorithm to converge. Scaling helps us avoid this problem.
In order to perform scaling, we have first calculated the mean and standard deviation of the train data. Then we have subtracted the mean from both datasets (train & test) and divided them by the standard deviation.
mean = X_train.mean(axis=0)
std = X_train.std(axis=0)
X_train = (X_train - mean) / std
X_test = (X_test - mean) / std
In this section, we'll write a small function which we'll use to initialize the weights and biases of the layers of our neural network. We'll be using the jax.random module to initialize weights with random numbers.
Our function takes two arguments, initializes the weights of the neural network, and returns them. The first argument, layer sizes, is a list of integers specifying the number of neurons/units to keep in each layer (including the last layer). The second argument, seed, is a jax.random.PRNGKey object which all jax.random module functions require for reproducibility.
We loop through the layer sizes and initialize weights and biases for each layer using the jax.random.uniform() function in the range (-1.0, 1.0). For each layer, the shape of the weights will be (#units x #units_previous_layer), except for the first layer whose shape will be (#units x #features). Here, #units refers to the units (neurons) of that particular layer. All biases will have shape (#units,). We keep the weights and biases of each layer in a list and add that list to the final weights list. At last, we return the final list with weights and biases for each layer.
If you want to know how to use jax.random module functions then please feel free to check our tutorial on jax which covers it.
def InitializeWeights(layer_sizes, seed):
    weights = []
    for i, units in enumerate(layer_sizes):
        if i==0:
            w = jax.random.uniform(key=seed, shape=(units, features), minval=-1.0, maxval=1.0, dtype=jnp.float32)
        else:
            w = jax.random.uniform(key=seed, shape=(units, layer_sizes[i-1]), minval=-1.0, maxval=1.0,
                                   dtype=jnp.float32)
        b = jax.random.uniform(key=seed, minval=-1.0, maxval=1.0, shape=(units,), dtype=jnp.float32)
        weights.append([w,b])
    return weights
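One thing worth noting is that the function above reuses the same key for every call to jax.random.uniform(), and JAX random functions are deterministic for a given key, so arrays with identical shapes will receive identical values. If you want independent values per layer, a common variant (shown below as an optional sketch, not used in the rest of the tutorial) is to split the key first using jax.random.split().
def InitializeWeightsSplitKeys(layer_sizes, seed):
    weights = []
    keys = jax.random.split(seed, num=len(layer_sizes) * 2)   ## One key per weight/bias array
    for i, units in enumerate(layer_sizes):
        in_dim = features if i == 0 else layer_sizes[i-1]
        w = jax.random.uniform(key=keys[2*i], shape=(units, in_dim), minval=-1.0, maxval=1.0, dtype=jnp.float32)
        b = jax.random.uniform(key=keys[2*i+1], shape=(units,), minval=-1.0, maxval=1.0, dtype=jnp.float32)
        weights.append([w, b])
    return weights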
Below we have tested our weight-initializing function. We have created weights for a neural network with layer sizes [5,10,1] for our current Boston housing problem. After initializing the weights, we have looped through the weights returned by the function and printed their shapes to verify them. We can notice that the weights are initialized with the proper shapes, including the first and last layers.
seed = jax.random.PRNGKey(123)
weights = InitializeWeights([5,10,1], seed)
for w in weights:
    print(w[0].shape, w[1].shape)
In this section, we have created the activation function that we'll be using to activate the output of the hidden layers. We have created an activation function named Relu (Rectified Linear Unit) which takes an input array and returns an array where all values less than 0 are replaced by zero. This is a commonly used activation function for the hidden layers of neural networks.
def Relu(x):
    return jnp.maximum(x, jnp.zeros_like(x)) # max(0,x)
Below we have tested our Relu function by giving a simple array.
x = jnp.array([-1,0,1,-2,4,-6,5])
Relu(x)
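JAX also ships a built-in version of this activation, jax.nn.relu(); as a quick sanity check, it should produce the same output as our hand-written function on the array above.
jax.nn.relu(x)   ## built-in ReLU, should match Relu(x) above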
In this section, we have created a function that includes the logic of a single layer of neural network. It takes as input weights (weights & biases) for the layer, input data (output of previous layer or data for the first layer), and activation function for the current layer. It then performs matrix multiplication of input data and weights and adds bias values to the result. At last, it applies the activation function to the result before returning it.
When performing the dot product of the input data and the weights, we have taken the transpose of the weights. The reason is that our weights have shape (#current_layer_units x #previous_layer_units) while the input data has shape (#batch_size x #previous_layer_units), so we transpose the weights to (#previous_layer_units x #current_layer_units) for the dimensions of the matrix dot product to match.
def LinearLayer(weights, input_data, activation=lambda x: x):
    w, b = weights
    out = jnp.dot(input_data, w.T) + b
    return activation(out)
Below we have tested our function which performs the logic of a single layer of the neural network. We have created random data of shape (5, #features). We have then fed the weights of the first layer and the random data to the function and recorded the output. We have printed the shapes of both the input and the output. We have used the weights that we initialized in our weight initialization section. There, we had 5 units in the first layer of the neural network, hence the output shape is (5,5) which represents (#batch_size x #units). Here, we have multiplied a matrix of shape (5,13) (input data) with one of shape (13,5) (transposed first-layer weights).
rand_data = jax.random.uniform(key=seed, shape=(5, X_train.shape[1]))
out = LinearLayer(weights[0], rand_data)
print("Data Shape : {}".format(rand_data.shape))
print("Output Shape : {}".format(out.shape))
In this section, we have defined a function that performs a single pass through the whole neural network and returns predictions calculated by the neural network. The function takes as input neural network weights and input data, it then performs one forward pass through input data and returns predictions.
The function loops through weights and applies weights & biases to input data by calling the function which we had designed for a single layer in the previous section. It applies Relu activation to all layers except the last layer. The last layer does not have any activation function. We have separated logic for the last layer out of the loop.
def ForwardPass(weights, input_data):
    layer_out = input_data
    for i in range(len(weights[:-1])):
        layer_out = LinearLayer(weights[i], layer_out, Relu)
    preds = LinearLayer(weights[-1], layer_out)
    return preds.squeeze()
Below we have tested our function which performs one forward pass through the data. We have given the weights and train data as input to the function and made predictions. We have printed the shape of the predictions to match it with the number of input samples.
preds = ForwardPass(weights, X_train)
preds.shape
In this section, we have defined the loss function for our regression problem. We'll be using the mean squared error (MSE) loss function for our purposes. To calculate MSE, we first take the difference between the actual target and the predictions, square the differences, and then take the average of the squared differences.
Our loss function takes the weights of the neural network, the input data, and the actual target as input. It then performs a forward pass on the input data and makes predictions. At last, it calculates MSE from the actual target values and the predictions and returns it.
MSE(actual, preds) = 1/n * sum((actual - preds)^2)
n = number of samples
def MeanSquaredErrorLoss(weights, input_data, actual):
    preds = ForwardPass(weights, input_data)
    return jnp.power(actual - preds, 2).mean()
In order to execute a gradient descent algorithm when training to update weights of our neural network, we need to calculate gradients of the loss function with respect to weights. We'll then use these gradient values to update our actual weights to reduce the loss of our neural network and improve its performance.
JAX lets us easily calculate the gradients of any function using the grad() function. Below we have designed a function that takes the weights, input data, and actual target values as input. It first creates the gradient function of our MSE loss using grad() and then executes it with the weights, input data, and actual target values to get the gradients of the loss with respect to the weights. By default, grad() calculates gradients with respect to the first input argument of the function, which in the case of our MSE loss function is the weights of the neural network. At last, our function returns the gradients.
If you are interested in learning about the grad() function in depth then please feel free to check our tutorial on JAX basics where we discuss it in detail.
Apart from grad(), JAX also provides another important function named value_and_grad() which returns both the value of the input function and its gradients. So if we had used value_and_grad() in our case, it would have returned the MSE value and the gradients together (a small sketch of this is included after the function below).
from jax import grad, value_and_grad
def CalculateGradients(weights, input_data, actual):
    Grad_MSELoss = grad(MeanSquaredErrorLoss)
    gradients = Grad_MSELoss(weights, input_data, actual)
    return gradients
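As mentioned above, value_and_grad() can give us the loss value and the gradients in a single call. Below is a small sketch of how CalculateGradients() could be written that way; we stick with the grad() version for the rest of the tutorial.
def CalculateLossAndGradients(weights, input_data, actual):
    ## value_and_grad() returns the loss value and the gradients of the loss w.r.t. weights together
    loss, gradients = value_and_grad(MeanSquaredErrorLoss)(weights, input_data, actual)
    return loss, gradients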
In this section, we'll be actually training our neural network by combining all pieces which we had designed till now. We have written a small function to train our neural network and update weights.
The function takes as input the weights, input data, target values, a learning rate, and the number of epochs. Epochs refers to the number of times we perform a forward pass through the full data. The learning rate is a small number that controls how much we modify the weights at each update.
The function loops for the given number of epochs. Each time, it calculates the loss and the gradients of the loss with respect to the weights. It then updates the weights and biases of each layer of the neural network by subtracting the learning rate times the gradients from them. This is commonly referred to as the gradient descent algorithm. We also print the loss every 100 epochs.
def TrainModel(weights, X, Y, learning_rate, epochs):
    for i in range(epochs):
        loss = MeanSquaredErrorLoss(weights, X, Y)
        gradients = CalculateGradients(weights, X, Y)

        ## Update Weights
        for j in range(len(weights)):
            weights[j][0] -= learning_rate * gradients[j][0] ## Update Weights
            weights[j][1] -= learning_rate * gradients[j][1] ## Update Biases

        if i % 100 == 0: ## Print MSE every 100 epochs
            print("MSE : {:.2f}".format(loss))
Below we are actually performing training of our neural network by calling the TrainModel() function we designed above.
We first initialize the seed for the weights, the learning rate, the epochs, and a list of layer sizes. We have created a neural network with layer sizes [5,10,15,1]. We then call our weight initialization function, giving it the layer sizes and the seed.
Once the weights are initialized, we call our train model function with the weights, train data, train target values, learning rate, and epochs. We'll be running our training loop for 1500 epochs with a learning rate of 0.001 for good results. We can notice that the function prints the MSE loss every 100 epochs.
seed = jax.random.PRNGKey(42)
learning_rate = jnp.array(1/1e3)
epochs = 1500
layer_sizes = [5,10,15,1]
weights = InitializeWeights(layer_sizes, seed)
TrainModel(weights, X_train, Y_train, learning_rate, epochs)
In this section, we are actually making predictions on our train and test datasets. We'll be using our forward pass function which we had designed earlier to make predictions.
Below we have first made predictions on the test dataset and then on the training dataset. We have also printed the first few predictions.
test_preds = ForwardPass(weights, X_test)
test_preds[:5], Y_test[:5]
train_preds = ForwardPass(weights, X_train)
train_preds[:5], Y_train[:5]
In this section, we are evaluating the performance of our neural network by calculating loss and R^2 score on the train and test datasets.
Below we have first calculated MSE loss on both train and test datasets using our loss function.
In the next cell, we have used the r2_score() function available from scikit-learn to calculate the R^2 score on the train and test datasets from the actual target values and the predictions. The R^2 score generally lies in the range [0-1], where values near 1 indicate a good model.
If you are interested in learning more about R^2 score and other metrics available from scikit-learn to measure the performance of models then please feel free to check our tutorial on the same. It discusses many metrics in detail with examples.
print("Test MSE Score : {:.2f}".format(MeanSquaredErrorLoss(weights, X_test, Y_test)))
print("Train MSE Score : {:.2f}".format(MeanSquaredErrorLoss(weights, X_train, Y_train)))
from sklearn.metrics import r2_score
print("Test R^2 Score : {:.2f}".format(r2_score(test_preds.squeeze(), Y_test)))
print("Train R^2 Score : {:.2f}".format(r2_score(train_preds.squeeze(), Y_train)))
There are situations in real life where our full dataset does not fit into the main memory of the computer. In those situations, we need to take a small batch of data that fits into memory, train the model on it, and update the weights. To handle this, we have modified our training routine to work on the input data in batches. This routine can be modified as per the need. Our dataset for this example is small and fits into main memory, but we are including batch training for explanation purposes.
When we perform training in batches and update the weights for each batch of data, the algorithm is referred to as stochastic gradient descent, a modified version of gradient descent. Gradient descent works on the whole data at once, whereas stochastic gradient descent works on small subsets (batches) of the data and updates the weights after each one.
Our training function takes as input the weights, input data, target values, learning rate, number of epochs, and batch size. It loops over the training data for the given number of epochs. For each epoch, it generates the start and end indices of the batches, creates each batch of data, and calculates the loss and gradients on it. We update the weights using the gradients calculated on each batch; the weight-update logic is separated into its own function. We also print the MSE every 100 epochs.
The main difference compared to training on the whole data is that there the weights are updated once per epoch, whereas with batch training the weights are updated once per batch, i.e. several times per epoch.
def UpdateWeights(learning_rate, weights, gradients):
    for j in range(len(weights)): ## Update Weights
        weights[j][0] -= learning_rate * gradients[j][0] ## Update Weights
        weights[j][1] -= learning_rate * gradients[j][1] ## Update Biases

def TrainModelInBatches(weights, X, Y, learning_rate, epochs, batch_size=32):
    for i in range(epochs):
        batches = jnp.arange((X.shape[0]//batch_size)+1) ### Batch Indices

        losses = [] ## Record loss of each batch
        for batch in batches:
            if batch != batches[-1]:
                start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
            else:
                start, end = int(batch*batch_size), None

            X_batch, Y_batch = X[start:end], Y[start:end] ## Single batch of data

            loss = MeanSquaredErrorLoss(weights, X_batch, Y_batch) ## Loss of batch
            gradients = CalculateGradients(weights, X_batch, Y_batch)
            losses.append(loss) ## Record Loss

            UpdateWeights(learning_rate, weights, gradients) ## Update Weights

        if i % 100 == 0: ## Print MSE every 100 epochs
            print("MSE : {:.2f}".format(jnp.array(losses).mean()))
Below we have first initialized the seed, learning rate, epochs, and layer sizes. We'll be creating a neural network with layer sizes [5,10,15,1]. We have then initialized the weights of the network using the function we created earlier. We have then called our training routine to train on the data in batches with a batch size of 32. We are training the neural network for 500 epochs this time with a learning rate of 0.001.
seed = jax.random.PRNGKey(42)
learning_rate = jnp.array(1/1e3)
epochs = 500
layer_sizes = [5,10,15,1]
weights = InitializeWeights(layer_sizes, seed)
TrainModelInBatches(weights, X_train, Y_train, learning_rate, epochs, batch_size=32)
Here, we have created a function that is used to make predictions on data in batches. As we have assumed that we are training in batches because the total data does not fit into main memory, we'll also be making predictions in batches.
The function takes input weights, input data, and batch size. It then loops through data in batches making predictions on a single batch at a time. We have then combined predictions of each batch of data.
def MakePredictions(weights, input_data, batch_size=32):
    batches = jnp.arange((input_data.shape[0]//batch_size)+1) ### Batch Indices

    preds = []
    for batch in batches:
        if batch != batches[-1]:
            start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
        else:
            start, end = int(batch*batch_size), None

        X_batch = input_data[start:end]

        preds.append(ForwardPass(weights, X_batch))

    return preds
Below we have used our function from the previous cell to make predictions on test and train data in batches. We have then combined predictions of individual batches into one array.
test_preds = MakePredictions(weights, X_test)
test_preds = jnp.concatenate(test_preds).squeeze()
train_preds = MakePredictions(weights, X_train)
train_preds = jnp.concatenate(train_preds).squeeze()
Below we have evaluated R^2 score on train and test predictions using scikit-learn r2_score() method. The score seems to be a little better compared to our previous score on the whole data.
from sklearn.metrics import r2_score
print("Test R^2 Score : {:.2f}".format(r2_score(test_preds.squeeze(), Y_test)))
print("Train R^2 Score : {:.2f}".format(r2_score(train_preds.squeeze(), Y_train)))
In this section, we'll explain how we can create a neural network for solving classification tasks. We'll be using a small toy dataset available from scikit-learn for our classification task. We'll be reusing many functions and much of the code from the regression section, hence we won't include a detailed description of the repeated parts. We'll be designing the neural network to solve a binary classification task.
In this section, we have loaded the breast cancer dataset available from scikit-learn. The dataset has measurements of various features of a tumor (X), and the target variable (Y) is '0' for a malignant tumor and '1' for a benign tumor.
After loading the dataset, we have split it into train (80%) and test (20%) sets. We have then converted the data, loaded as numpy arrays, to JAX arrays as our code works on JAX arrays.
We have also recorded the number of training samples, the number of data features, and the class labels, which we have printed in the cell below.
from sklearn import datasets
from sklearn.model_selection import train_test_split
X, Y = datasets.load_breast_cancer(return_X_y=True)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, train_size=0.8, stratify=Y, random_state=123)
X_train, X_test, Y_train, Y_test = jnp.array(X_train, dtype=jnp.float32),\
jnp.array(X_test, dtype=jnp.float32),\
jnp.array(Y_train, dtype=jnp.float32),\
jnp.array(Y_test, dtype=jnp.float32)
samples, features = X_train.shape
classes = np.unique(Y)
X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
samples, features, classes
In this section, we have normalized our data as we did in the regression section by subtracting the mean and dividing by the standard deviation of train data.
mean = X_train.mean(axis=0)
std = X_train.std(axis=0)
X_train = (X_train - mean) / std
X_test = (X_test - mean) / std
In this section, we have included the code for initializing weights. The function is exactly the same as that of the regression section. We have included it here so that someone who starts directly from this section can follow the code.
def InitializeWeights(layer_sizes, seed):
    weights = []
    for i, units in enumerate(layer_sizes):
        if i==0:
            w = jax.random.uniform(key=seed, shape=(units, features), minval=-1.0, maxval=1.0, dtype=jnp.float32)
        else:
            w = jax.random.uniform(key=seed, shape=(units, layer_sizes[i-1]), minval=-1.0, maxval=1.0,
                                   dtype=jnp.float32)
        b = jax.random.uniform(key=seed, minval=-1.0, maxval=1.0, shape=(units,), dtype=jnp.float32)
        weights.append([w,b])
    return weights
In this section, we have again included Relu (Rectified Linear Units) function which we'll use to activate hidden layers of our neural networks.
def Relu(x):
    return jnp.maximum(x, jnp.zeros_like(x)) # max(0,x)
In this section, we have designed a sigmoid activation function which will be applied to the last layer of the neural network for the binary classification task. It maps input values into the range [0-1]. We'll then put a threshold at 0.5, predicting values below it as class 0 (malignant) and values above it as class 1 (benign).
sigmoid(x) = 1 / ( 1 + e^-x)
We have also tested the accuracy of our function by comparing the results against JAX function jax.nn.sigmoid.
def Sigmoid(x):
    return 1 / (1 + jnp.exp(-1 * x))
arr = jnp.array([1,2,3,4,5], dtype=jnp.float32)
Sigmoid(arr), jax.nn.sigmoid(arr)
In this section, we have included a function that applies weights, biases, and an activation function of a single layer on input data. The code is exactly the same as that from the regression section.
def LinearLayer(weights, input_data, activation=lambda x: x):
    w, b = weights
    out = jnp.dot(input_data, w.T) + b
    return activation(out)
In this section, we have included a function that performs a single pass of input data through the whole neural network. It has almost exactly the same code as that of the regression section with only one minor change. In the regression section, we did not apply any activation function to the last layer, whereas here we apply the sigmoid function as the activation of the last layer.
def ForwardPass(weights, input_data):
    layer_out = input_data
    for i in range(len(weights[:-1])):
        layer_out = LinearLayer(weights[i], layer_out, Relu)
    preds = LinearLayer(weights[-1], layer_out, Sigmoid)
    return preds.squeeze()
In this section, we have defined the loss function for our classification task. We'll be using the negative log loss function for our task. The function takes the weights, input data, and actual target values as input. It then makes predictions on the input data using the weights and calculates the log loss from the predicted values and the actual target values.
log_loss(actual, preds) = 1/n * sum(- actual * log(preds) - (1 - actual) * log(1 - preds))
def NegLogLoss(weights, input_data, actual):
    preds = ForwardPass(weights, input_data)
    return (- actual * jnp.log(preds) - (1 - actual) * jnp.log(1 - preds)).mean()
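One practical caveat (not handled in the function above) is that jnp.log() will produce -inf/NaN values if the sigmoid output ever reaches exactly 0 or 1. A common safeguard is to clip the predictions away from those endpoints, as in the sketch below; the eps value is just an illustrative choice.
def NegLogLossStable(weights, input_data, actual, eps=1e-7):
    preds = ForwardPass(weights, input_data)
    preds = jnp.clip(preds, eps, 1 - eps)   ## Keep predictions strictly inside (0, 1)
    return (- actual * jnp.log(preds) - (1 - actual) * jnp.log(1 - preds)).mean()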
Below we have defined the function which calculates gradients of our loss function with respect to the weights. The code is exactly the same as that of the regression section with only one change: we take the gradient of the negative log loss here.
from jax import grad, value_and_grad
def CalculateGradients(weights, input_data, actual):
    Grad_NegLogLoss = grad(NegLogLoss)
    gradients = Grad_NegLogLoss(weights, input_data, actual)
    return gradients
In this section, we have created a function that actually performs training on our train dataset. It loops through data for the specified epochs making predictions each time and updating weights based on gradients of the loss function. The code for this function is exactly the same as the one we used in the regression section. We print loss at every 100 epochs.
def TrainModel(weights, X, Y, learning_rate, epochs):
    for i in range(epochs):
        loss = NegLogLoss(weights, X, Y)
        gradients = CalculateGradients(weights, X, Y)

        ## Update Weights
        for j in range(len(weights)):
            weights[j][0] -= learning_rate * gradients[j][0] ## Update Weights
            weights[j][1] -= learning_rate * gradients[j][1] ## Update Biases

        if i % 100 == 0: ## Print LogLoss every 100 epochs
            print("NegLogLoss : {:.2f}".format(loss))
Below we have initialized seed, learning rate, epochs, and layer sizes. We'll be first initializing weights of the neural network using the weight initialization method we had created earlier. We'll be then giving weights, the train features data, train target values, learning rate, and epochs as input to function to actually perform training.
We are using the network with layer sizes [5,10,15,1] and a learning rate of 0.01 for our training purpose. We are training the network for 1500 epochs for good results.
seed = jax.random.PRNGKey(42)
learning_rate = jnp.array(1/1e2)
epochs = 1500
layer_sizes = [5,10,15,1]
weights = InitializeWeights(layer_sizes, seed)
TrainModel(weights, X_train, Y_train, learning_rate, epochs)
In this section, we have actually made predictions on the train and test datasets using the forward pass function we designed earlier. As the output of our neural network comes from the sigmoid function, the values will be in the range [0,1] and we need to convert them to actual class labels.
We have set a threshold of 0.5 for determining the predicted class. Values below this threshold are predicted as class 0 (malignant) and values above it as class 1 (benign).
We have made predictions for both train and test datasets.
test_preds = ForwardPass(weights, X_test)
test_preds = (test_preds > 0.5).astype(jnp.float32)
test_preds[:5], Y_test[:5]
train_preds = ForwardPass(weights, X_train)
train_preds = (train_preds > 0.5).astype(jnp.float32)
train_preds[:5], Y_train[:5]
In this section, we have evaluated the performance of our model by calculating the log loss and accuracy of both train and test sets. We can notice from the results that the network seems to have done a decent job of prediction with nearly 93-94% accuracy on both train and test sets.
print("Test NegLogLoss Score : {:.2f}".format(NegLogLoss(weights, X_test, Y_test)))
print("Train NegLogLoss Score : {:.2f}".format(NegLogLoss(weights, X_train, Y_train)))
from sklearn.metrics import accuracy_score
print("Train Accuracy : {:.2f}".format(accuracy_score(Y_train, train_preds)))
print("Test Accuracy : {:.2f}".format(accuracy_score(Y_test, test_preds)))
In this section, we have created a function to train on data in batches. The code is almost exactly the same as our function from the regression section, with only one minor difference: we use the negative log loss function to calculate the loss.
def UpdateWeights(learning_rate, weights, gradients):
    for j in range(len(weights)): ## Update Weights
        weights[j][0] -= learning_rate * gradients[j][0] ## Update Weights
        weights[j][1] -= learning_rate * gradients[j][1] ## Update Biases

def TrainModelInBatches(weights, X, Y, learning_rate, epochs, batch_size=32):
    for i in range(epochs):
        batches = jnp.arange((X.shape[0]//batch_size)+1) ### Batch Indices

        losses = [] ## Record loss of each batch
        for batch in batches:
            if batch != batches[-1]:
                start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
            else:
                start, end = int(batch*batch_size), None

            X_batch, Y_batch = X[start:end], Y[start:end] ## Single batch of data

            loss = NegLogLoss(weights, X_batch, Y_batch)
            gradients = CalculateGradients(weights, X_batch, Y_batch)
            losses.append(loss) ## Record Loss

            UpdateWeights(learning_rate, weights, gradients) ## Update Weights

        if i % 100 == 0: ## Print LogLoss every 100 epochs
            print("NegLogLoss : {:.2f}".format(jnp.array(losses).mean()))
Below we have initialized seed, learning rate (0.001), epochs (1000) and layer sizes first ([5,10,15,1]). We have then initialized neural network weights using the weight initialization function which we had designed earlier by giving layer sizes to it. We have then called our training function with weights, train features, train target variables, learning rate, epochs, and batch size (32) to perform the actual training process.
seed = jax.random.PRNGKey(42)
learning_rate = jnp.array(1/1e3)
epochs = 1000
layer_sizes = [5,10,15,1]
weights = InitializeWeights(layer_sizes, seed)
TrainModelInBatches(weights, X_train, Y_train, learning_rate, epochs, batch_size=32)
In this section, we have included the function to make predictions in batches. We have copied code for the same function defined in the regression section.
def MakePredictions(weights, input_data, batch_size=32):
    batches = jnp.arange((input_data.shape[0]//batch_size)+1) ### Batch Indices

    preds = []
    for batch in batches:
        if batch != batches[-1]:
            start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
        else:
            start, end = int(batch*batch_size), None

        X_batch = input_data[start:end]

        preds.append(ForwardPass(weights, X_batch))

    return preds
Below we have made predictions for both the train and test datasets. As the output of our neural network is a probability in the range [0-1], we have included logic to convert the probabilities to actual classes by setting the threshold at 0.5.
train_preds = MakePredictions(weights, X_train)
train_preds = jnp.concatenate(train_preds).squeeze()
train_preds = (train_preds > 0.5).astype(jnp.float32)
train_preds[:5], Y_train[:5]
test_preds = MakePredictions(weights, X_test)
test_preds = jnp.concatenate(test_preds).squeeze()
test_preds = (test_preds > 0.5).astype(jnp.float32)
test_preds[:5], Y_test[:5]
In this section, we have evaluated the performance of our neural network classifier by calculating the accuracy of train and test dataset predictions.
from sklearn.metrics import accuracy_score
print("Train Accuracy : {:.2f}".format(accuracy_score(Y_train, train_preds)))
print("Test Accuracy : {:.2f}".format(accuracy_score(Y_test, test_preds)))
This ends our small tutorial explaining how we can create simple neural networks using JAX. Please feel free to let us know your views in the comments section.
If you are more comfortable learning through video tutorials then we would recommend that you subscribe to our YouTube channel.
When going through coding examples, it's quite common to have doubts and errors.
If you have doubts about some code examples or are stuck somewhere when trying our code, send us an email at coderzcolumn07@gmail.com. We'll help you or point you in the direction where you can find a solution to your problem.
You can even send us a mail if you are trying something new and need guidance regarding coding. We'll try to respond as soon as possible.