**JAX** is one of the most commonly used frameworks for deep learning research. It implements the majority of the numpy API in a form that can run on accelerators like GPUs and TPUs. Apart from that, it provides automatic differentiation (gradient calculation), which eases the process of designing neural networks. It also provides just-in-time compilation through **XLA (Accelerated Linear Algebra)** to speed up Python functions even more. **JAX** has a very flexible API which lets us design neural networks using the lower-level API (numpy-like functions) or the higher-level API available through modules like **stax** and **optimizers**. We have already covered how to create a neural network using the lower-level JAX API in a separate tutorial.
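As a rough illustration of the three features mentioned above (a numpy-like API, automatic differentiation, and just-in-time compilation), a minimal sketch could look like the one below. The function `sum_of_squares` and its input are hypothetical names chosen only for this example.

```python
import jax
import jax.numpy as jnp

def sum_of_squares(x):
    ## Plain numpy-style code written with jax.numpy
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0, 3.0])

value = sum_of_squares(x)                  ## 1 + 4 + 9
gradient = jax.grad(sum_of_squares)(x)     ## Derivative of sum(x^2) is 2*x
fast_value = jax.jit(sum_of_squares)(x)    ## Same function, compiled through XLA

print(value, gradient, fast_value)
```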

If you want to learn about **JAX** from the basics then please feel free to check our tutorial which covers basics with different examples.

As a part of this tutorial, we'll be concentrating on high-level API available from **JAX** to create neural networks. We'll be using small toy datasets available from scikit-learn to make examples easy to understand. Our main aim with this tutorial is to get individuals started creating neural networks using high-level **JAX** API. We won't be covering how optimizers work, how weights are calculated, etc. We expect that the reader has a background in neural networks and knows about things like an optimizer, loss functions, gradient descent algorithm, etc.

Below we have highlighted important sections of the tutorial to give an overview of the material covered in it.

**Regression**

- Load Dataset
- Normalize Data
- Create Neural Network
- Define Loss Function
- Train Neural Network
- Make Predictions
- Evaluate Model Performance
- Train Model in Batches of Data
- Make Predictions in Batches
- Evaluate Model Performance

**Classification**

Below we have imported **JAX** and printed the version that we'll be using in this tutorial. We have also imported high-level sub-modules **stax** and **optimizers** available from **example_libraries** module of **JAX** that we'll be using to create neural networks and train them. We have also imported **jax.numpy** module as we'll require it to convert input data to JAX arrays and a few other calculations.

In [1]:

```
import jax
print("JAX Version : {}".format(jax.__version__))
```

In [2]:

```
from jax.example_libraries import stax, optimizers
import jax.numpy as jnp
```

In this section, we'll explain how we can create neural networks using higher-level **JAX** API available from **stax** and **optimizers** submodules to solve regression tasks. We'll be using the Boston housing dataset available from scikit-learn.

In this section, we have first loaded the Boston housing dataset available from scikit-learn. We have loaded data features in variable **X** and target values in variable **Y**. The target variable is the median house price in thousands of dollars and the features are various attributes of the houses and their neighborhoods. We have then split the dataset into train (80%) and test (20%) sets. After dividing the dataset, we have converted each numpy array to a JAX array using the **jax.numpy.array()** constructor. We have also printed the shapes of the train and test datasets at the end.

In [3]:

```
from sklearn import datasets
from sklearn.model_selection import train_test_split
from jax import numpy as jnp
## Note: load_boston() was removed in scikit-learn 1.2, so this cell needs an older scikit-learn release
X, Y = datasets.load_boston(return_X_y=True)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, train_size=0.8, random_state=123)
X_train, X_test, Y_train, Y_test = jnp.array(X_train, dtype=jnp.float32),\
                                   jnp.array(X_test, dtype=jnp.float32),\
                                   jnp.array(Y_train, dtype=jnp.float32),\
                                   jnp.array(Y_test, dtype=jnp.float32)
samples, features = X_train.shape
X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
```

In this section, we have normalized our dataset. In order to normalize data, we have first calculated the mean and standard deviation of each feature on the training dataset. We have then subtracted the mean from both train and test sets. Finally, we have divided the resulting values by the standard deviation. The main reason behind performing normalization is to bring the values of each feature onto almost the same scale. This helps the gradient descent optimization algorithm converge faster. If the values of different features are on different scales and vary a lot, training time can increase because gradient descent will have a hard time converging.

In [4]:

```
mean = X_train.mean(axis=0)
std = X_train.std(axis=0)
X_train = (X_train - mean) / std
X_test = (X_test - mean) / std
```
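To see the effect of this step in isolation, here is a minimal sketch on a tiny made-up dataset (the `toy_*` arrays are hypothetical and not part of the tutorial data). It shows that the train set ends up with mean 0 and standard deviation 1 per feature, while the test set is transformed with the train statistics.

```python
import jax.numpy as jnp

## Hypothetical toy data: 2 features on very different scales
toy_train = jnp.array([[1.0, 100.0], [2.0, 200.0], [3.0, 300.0]])
toy_test = jnp.array([[2.0, 400.0]])

toy_mean = toy_train.mean(axis=0)   ## Per-feature mean of the train set
toy_std = toy_train.std(axis=0)     ## Per-feature standard deviation of the train set

toy_train_norm = (toy_train - toy_mean) / toy_std
toy_test_norm = (toy_test - toy_mean) / toy_std   ## Test set reuses the train statistics

print(toy_train_norm.mean(axis=0), toy_train_norm.std(axis=0))
print(toy_test_norm)
```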

In this section, we have created a neural network that we'll be using for our regression task.

The **stax** module of **JAX** provides various readily available layers that we can stack together to create a neural network. We'll be creating a fully connected neural network using **stax** API. The process of creating a neural network using **stax** module is almost the same as that of creating a neural network using **Sequential()** API of **keras**. If the reader has a background with **keras** then it'll be easy to understand this.

The **stax** module provides a method named **serial()** that accepts a sequence of layers and activation functions as arguments and creates a neural network. It applies the layers in the order in which they are given when performing a forward pass through the data.

We can create dense layers/fully connected layers using the **Dense()** method. It takes as input the number of units that will be present in that layer. We can also provide weight and bias initialization functions to **Dense()** if we don't want to use its default initialization.

The majority of **stax** module methods (all layers such as **Dense** and **Conv**, as well as the **serial()** method) return two callables as output when executed.

- **init_fun** - This function takes a seed (PRNG key) for weight initialization and the input shape of that layer/network as input. It returns a tuple of the output shape and the initialized parameters. For a single layer, the parameters are just its weights and biases, and for a neural network, they are a list with the weights and biases of each layer.
- **apply_fun** - This function takes the weights & biases of the layer/network and data as input. It then executes the layer/network on the input data using the weights, i.e. it performs a forward pass through the network.
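The two callables can be tried out on a single layer. The sketch below assumes a hypothetical input with 13 features and a batch of 4 samples; the names `dense_init`, `dense_apply`, and `sample_batch` are made up for this example.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

dense_init, dense_apply = stax.Dense(5)   ## Layer with 5 units

key = jax.random.PRNGKey(0)
## init_fun returns the output shape and the parameters (weights, biases)
out_shape, dense_params = dense_init(key, (13,))
W, b = dense_params

sample_batch = jnp.ones((4, 13))                    ## 4 samples, 13 features each
outputs = dense_apply(dense_params, sample_batch)   ## Forward pass through the layer

print(out_shape, W.shape, b.shape, outputs.shape)
```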

All activation functions are available as simple attributes of the **stax** module and we don't need to call them with brackets. We can just give them as input to the **serial()** method after a layer and they will be applied to the output of that layer.
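The reason no brackets are needed is that an activation such as `stax.Relu` is already an (init_fun, apply_fun) pair. A minimal sketch, with a made-up input array, could look like this:

```python
import jax.numpy as jnp
from jax.example_libraries import stax

## stax.Relu is already an (init_fun, apply_fun) pair, so it is not called with brackets
relu_init, relu_apply = stax.Relu

sample = jnp.array([-2.0, -0.5, 0.0, 1.5])
## Activation layers have no weights, hence the empty tuple as parameters
activated = relu_apply((), sample)
print(activated)
```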

Below we have first created a single **Dense()** layer with 5 units to show the output returned by it. We can notice that it returns the two callables which we described above.

We have then created our neural network which has layer sizes **[5,10,15,1]**. The last layer is the output layer and all other layers are hidden layers. Each hidden layer is created using the **Dense()** method followed by the **Relu (Rectified Linear Unit)** activation function, while the output layer has no activation. All the layers and activation functions are provided to the **serial()** method in sequence, separated by commas. The **Relu** function takes an array as input and returns a new array of the same size where all values less than 0 are replaced by 0.

As we have said earlier, the **serial()** method returns two callables as output, which we have stored in different variables as we'll be using them later.

In [5]:

```
stax.Dense(5)
```

In [6]:

```
neural_net_init, neural_net_apply = stax.serial(
    stax.Dense(5),
    stax.Relu,
    stax.Dense(10),
    stax.Relu,
    stax.Dense(15),
    stax.Relu,
    stax.Dense(1),
)
```

In [7]:

```
neural_net_init, neural_net_apply
```

Now, we have simply initialized the weights of our neural network by calling **init_fun()** function. We have given seed (**jax.random.PRNGKey(123)**) and input data shape as input to function. It then uses seed and shape information to initialize the weights and biases of each layer of the neural network.

After initializing weights, we have also printed the shape of weights and biases for each layer. This can help us verify that all things went as expected.

In [8]:

```
rng = jax.random.PRNGKey(123)
weights = neural_net_init(rng, (features,))
weights = weights[1] ## Weights are actually stored in second element of two value tuple
for w in weights:
    if w: ## Activation layers have no weights (empty tuple)
        w, b = w
        print("Weights : {}, Biases : {}".format(w.shape, b.shape))
```
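An alternative way to inspect the shapes, shown here only as a sketch, is to flatten the nested parameters with `jax.tree_util.tree_leaves()`. The small network below (13 input features, layers of 5 and 1 units) is a hypothetical stand-in so that the example runs on its own.

```python
import jax
from jax.example_libraries import stax

## Hypothetical small network: 13 input features -> 5 units -> 1 unit
toy_init, toy_apply = stax.serial(stax.Dense(5), stax.Relu, stax.Dense(1))
toy_out_shape, toy_params = toy_init(jax.random.PRNGKey(0), (13,))

## tree_leaves() flattens the nested parameters into a flat list of arrays
leaf_shapes = [leaf.shape for leaf in jax.tree_util.tree_leaves(toy_params)]
print(toy_out_shape)
print(leaf_shapes)
```

The empty parameter tuple of the activation layer contributes no leaves, so only the weight and bias arrays of the two dense layers are listed.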

In the below cell, we have performed a forward pass through our neural network. We have taken a few samples of our data and given them as input to the **apply_fun()** function along with the weights. The weights are given first, followed by a small batch of data. The **apply_fun()** function performs one forward pass through the data using the weights and returns predictions.

In [9]:

```
preds = neural_net_apply(weights, X_train[:5])
preds
```

In this section, we have created a loss function for our neural network. We'll be calculating the gradient of the loss function with respect to weights and then update weights using gradients.

We'll be using **Mean squared error loss** as our loss. It simply subtracts predictions from actual values, squares the differences, and then takes the mean of them.

`MSE(actual, preds) = 1/n * sum((actual - preds)^2)`
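As a small worked example of the formula, the sketch below computes the loss step by step on three made-up values (`toy_actual` and `toy_preds` are hypothetical).

```python
import jax.numpy as jnp

toy_actual = jnp.array([3.0, 5.0, 2.0])
toy_preds = jnp.array([2.5, 5.0, 4.0])

errors = toy_actual - toy_preds     ## [0.5, 0.0, -2.0]
squared_errors = errors ** 2        ## [0.25, 0.0, 4.0]
toy_mse = squared_errors.mean()     ## (0.25 + 0.0 + 4.0) / 3

print(toy_mse)
```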

Our loss function takes weights, data, and actual target values as input. It then performs a forward pass through the neural network using **apply_fun()** function providing weights and data to it. The predictions made by the network are stored in a variable. We can then actually calculate **MSE** using actual target values and predictions.

In [10]:

```
def MeanSquaredErrorLoss(weights, input_data, actual):
    preds = neural_net_apply(weights, input_data)
    preds = preds.squeeze() ## Drop the trailing dimension of size 1
    return jnp.power(actual - preds, 2).mean()
```

In this section, we have included code to train our neural network. We have created a small function which we'll call to train our neural network. The function takes data features, target values, number of epochs, and optimizer state as input. Optimizer state is an object created by optimizer which has weights of our model. We'll explain about optimizer in the next cell below when we initialize it.

Our function then loops for the given number of epochs. Each time, it first calculates the loss value and gradients using the **value_and_grad()** function. This function takes another function as input, which is the **MSE** loss function in our case. It returns a new callable which, when called, returns the value of the function as well as the gradient of the function with respect to its first parameter, which is the weights in our case. We call the returned function with weights, data features, and target values, which are the three inputs of our loss function. This call returns the **MSE** value and the gradients for the weights and biases of each layer of our neural network.
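The behavior of **value_and_grad()** is easier to see on a tiny function. The sketch below assumes a hypothetical one-parameter model (prediction = `w * x`); `toy_loss` and its inputs are made up for this example.

```python
import jax.numpy as jnp
from jax import value_and_grad

## Hypothetical one-parameter model: prediction = w * x
def toy_loss(w, x, y):
    return jnp.mean((w * x - y) ** 2)

toy_x = jnp.array([1.0, 2.0, 3.0])
toy_y = jnp.array([2.0, 4.0, 6.0])
toy_w = 1.0

## The gradient is taken with respect to the first argument (w) by default
loss_value, grad_w = value_and_grad(toy_loss)(toy_w, toy_x, toy_y)
print(loss_value, grad_w)
```

Here the errors are [-1, -2, -3], so the loss is 14/3 and the gradient with respect to `w` is the mean of `2 * error * x`, which is -28/3.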

We have then called the optimizer update function, which takes as input the epoch number, the gradients, and the current optimizer state that holds the current weights. It returns a new optimizer state whose weights are updated by subtracting the learning rate times the gradients from them.

Finally, we return the last optimizer state (final updated weights). We are also printing **MSE** every 100 epochs to keep track of training progress.

In [11]:

```
from jax import grad, value_and_grad

def TrainModel(X, Y, epochs, opt_state):
    for i in range(1,epochs+1):
        loss, gradients = value_and_grad(MeanSquaredErrorLoss)(opt_get_weights(opt_state), X, Y)

        ## Update Weights
        opt_state = opt_update(i, gradients, opt_state)

        if i%100 ==0: ## Print MSE every 100 epochs
            print("MSE : {:.2f}".format(loss))

    return opt_state
```

Here, we are actually training our neural network by calling the function we designed in the previous cell.

First, we have initialized the weights of our neural network by calling **init_fun** function from earlier.

We have then initialized an optimizer for our neural network. The optimizer is an algorithm responsible for finding the minimum value of our loss function. The **optimizers** module available from the **example_libraries** module of **jax** provides us with a list of different optimizers. We'll be using the **sgd() (gradient descent)** optimizer for our purpose. We have initialized our optimizer by giving a learning rate (**0.001**) to it. The optimizer returns three callables that are needed for maintaining and updating the weights of the neural network.

- **init** - This function takes the weights of a neural network as input and returns an **OptimizerState** object, which is a wrapper for holding and updating weights.
- **update_fn** - This function takes the epoch number, gradients, and optimizer state as input. It updates the weights present in the optimizer state by subtracting the learning rate times the gradients from them. It then returns a new **OptimizerState** object which has the updated weights.
- **params_fn** - This function takes an **OptimizerState** object as input and returns the actual weights of the neural network.
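The three callables can be exercised on a small example. The sketch below uses a made-up parameter vector and made-up gradients with a learning rate of 0.1; the `toy_*` names are hypothetical.

```python
import jax.numpy as jnp
from jax.example_libraries import optimizers

toy_opt_init, toy_opt_update, toy_opt_get_params = optimizers.sgd(0.1)   ## Learning rate 0.1

toy_params = jnp.array([1.0, 2.0])
toy_grads = jnp.array([0.5, -1.0])    ## Hypothetical gradients

toy_state = toy_opt_init(toy_params)                  ## Wrap parameters in an optimizer state
toy_state = toy_opt_update(0, toy_grads, toy_state)   ## One step: params - 0.1 * grads
new_params = toy_opt_get_params(toy_state)            ## Unwrap the updated parameters
print(new_params)
```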

After initializing the optimizer with weights, we have called our training routine to actually perform training by providing data, target values, the number of epochs, and the optimizer state (weights). We are training the network for **2500** epochs. We can notice from the **MSE** getting printed every 100 epochs that the model is getting better at the task.

In [12]:

```
seed = jax.random.PRNGKey(123)
learning_rate = jnp.array(1/1e3)
epochs = 2500
weights = neural_net_init(seed, (features,)) ## Use the seed defined above
weights = weights[1]
opt_init, opt_update, opt_get_weights = optimizers.sgd(learning_rate)
opt_state = opt_init(weights)
final_opt_state = TrainModel(X_train, Y_train, epochs, opt_state)
```

In this section, we are making predictions using our trained neural network. We have made predictions for both train and test datasets.

We have retrieved the weights of the neural network using the **params_fn** function of the optimizer. We have then given the weights and data features as input to **apply_fun**, which makes the predictions.

In [13]:

```
test_preds = neural_net_apply(opt_get_weights(final_opt_state), X_test) ## Make Predictions on test dataset
test_preds = test_preds.ravel()
train_preds = neural_net_apply(opt_get_weights(final_opt_state), X_train) ## Make Predictions on train dataset
train_preds = train_preds.ravel()
test_preds[:5], train_preds[:5]
```

In this section, we are evaluating the performance of our regression model. We are calculating the **R^2 score** for both our train and test predictions using the **r2_score()** method of scikit-learn. The **R^2** score is at most 1, and a value near 1 indicates a good model. It usually lies in the range **[0,1]** but can be negative when the model performs worse than always predicting the mean of the target. We can notice from the **R^2 score** that our model seems to be doing a good job.
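To make the metric less abstract, the sketch below computes R^2 directly from its definition on made-up values. The helper `r2` is a hypothetical stand-in written for this example, not the scikit-learn implementation.

```python
import jax.numpy as jnp

def r2(y_true, y_pred):
    ## R^2 = 1 - (residual sum of squares / total sum of squares)
    ss_res = jnp.sum((y_true - y_pred) ** 2)
    ss_tot = jnp.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot

toy_true = jnp.array([1.0, 2.0, 3.0])

perfect_score = r2(toy_true, toy_true)                       ## Perfect predictions
reversed_score = r2(toy_true, jnp.array([3.0, 2.0, 1.0]))    ## Worse than predicting the mean
print(perfect_score, reversed_score)
```

Note that the definition is not symmetric in its two arguments, so the true values need to be passed first and the predictions second.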

If you are interested in learning in detail about **R^2** score and other metrics available from scikit-learn for different kinds of tasks then please feel free to check our tutorial which covers the majority of metrics in detail with examples.

In [14]:

```
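import numpy as np
from sklearn.metrics import r2_score
## r2_score() expects (y_true, y_pred) in that order; np.asarray() converts JAX arrays to numpy arrays
print("Train R^2 Score : {:.2f}".format(r2_score(np.asarray(Y_train), np.asarray(train_preds))))
print("Test R^2 Score : {:.2f}".format(r2_score(np.asarray(Y_test), np.asarray(test_preds))))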
```

In real life, datasets are often quite big and do not fit into the main memory of the computer. In those situations, we bring only a small batch of data into the main memory at a time and train the model batch by batch until the whole dataset is covered. The optimization algorithm used in this case is commonly referred to as mini-batch **stochastic gradient descent**, as each weight update is computed from a small batch of data rather than the full dataset.

In this section, we have explained how we can modify our code so that we can perform training on data in batches. We have below declared a function like last time which we'll use for training purposes. The function takes data features, target values, the number of epochs, the optimizer state (weights), and the batch size (default **32**) as input. We then run the training loop for the given number of epochs. In each epoch, we calculate the start and end indexes of each batch of data. We perform a forward pass, calculate the loss, and update the weights on a single batch of data at a time until the whole dataset is covered. When we can bring the whole dataset into main memory, we update weights only once per epoch, but when training in batches, we update weights once per batch, so many times per epoch.
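The index calculation can be sketched on its own. The helper `batch_bounds` below is a hypothetical name used only for this illustration; it yields the start and end index of every batch, with a shorter final batch when the number of samples is not a multiple of the batch size.

```python
def batch_bounds(n_samples, batch_size):
    ## Yield (start, end) index pairs covering all samples
    for start in range(0, n_samples, batch_size):
        yield start, min(start + batch_size, n_samples)

bounds = list(batch_bounds(10, 4))
print(bounds)   ## One weight update would happen per pair in an epoch
```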

In [15]:

```
def TrainModelInBatches(X, Y, epochs, opt_state, batch_size=32):
    for i in range(1, epochs+1):
        losses = [] ## Record loss of each batch
        for start in range(0, X.shape[0], batch_size): ## Start index of each batch
            end = start + batch_size ## Slicing stops at the last sample for the final batch
            X_batch, Y_batch = X[start:end], Y[start:end] ## Single batch of data

            loss, gradients = value_and_grad(MeanSquaredErrorLoss)(opt_get_weights(opt_state), X_batch, Y_batch)

            ## Update Weights
            opt_state = opt_update(i, gradients, opt_state)
            losses.append(loss) ## Record Loss

        if i % 100 == 0: ## Print MSE every 100 epochs
            print("MSE : {:.2f}".format(jnp.array(losses).mean()))

    return opt_state
```

Now, we have actually trained our neural network in batches using the function we designed in the previous cell. We have first initialized the weights of the neural network using **init_fun** by giving seed and input shape to it. Then we have initialized our optimizer by calling **sgd()** function giving learning rate (**0.001**) to it. Then we have created the first optimizer state with weights. We have then called our function from the previous cell to perform training in batches. We are training the neural network for **500** epochs.

In [16]:

```
seed = jax.random.PRNGKey(123)
learning_rate = jnp.array(1/1e3)
epochs = 500
weights = neural_net_init(seed, (features,)) ## Use the seed defined above
weights = weights[1]
opt_init, opt_update, opt_get_weights = optimizers.sgd(learning_rate)
opt_state = opt_init(weights)
final_opt_state = TrainModelInBatches(X_train, Y_train, epochs, opt_state)
```

When data does not fit into the main memory, we need to make predictions in batches as well. In this section, we have designed a function that takes weights and data as input and then makes predictions on the data in batches. We have followed almost the same logic that was present in the training function to calculate the indexes of batches.

In [17]:

```
def MakePredictions(weights, input_data, batch_size=32):
    preds = []
    for start in range(0, input_data.shape[0], batch_size): ## Start index of each batch
        X_batch = input_data[start:start+batch_size] ## Slicing stops at the last sample for the final batch
        preds.append(neural_net_apply(weights, X_batch))

    return preds
```

Below we have used our function from the previous cell and made predictions on train and test datasets in batches. We have then combined predictions of batches as well.

In [18]:

```
test_preds = MakePredictions(opt_get_weights(final_opt_state), X_test)
test_preds = jnp.concatenate(test_preds).squeeze() ## Combine predictions of all batches
train_preds = MakePredictions(opt_get_weights(final_opt_state), X_train)
train_preds = jnp.concatenate(train_preds).squeeze() ## Combine predictions of all batches
test_preds[:5], train_preds[:5]
```

In this section, we have evaluated the performance of our neural network by calculating **R^2 score** on train and test predictions.

In [19]:

```
from sklearn.metrics import r2_score
## r2_score() expects (y_true, y_pred) in that order
print("Test R^2 Score : {:.2f}".format(r2_score(Y_test, test_preds)))
print("Train R^2 Score : {:.2f}".format(r2_score(Y_train, train_preds)))
```

In this section, we'll explain how we can create a neural network to solve classification tasks using the high-level API available from **JAX**. We'll be using a small toy dataset from scikit-learn for an explanation. We'll also be reusing much of the code which we used in our previous regression section. Due to this, we won't include a detailed description of repeated code over here. Please feel free to look in the regression section if you want a detailed explanation of some section which is not present here.

In this section, we have loaded the breast cancer dataset available from scikit-learn. The dataset has various measurements of tumors as data features. The target value is either 0 (malignant tumor) or 1 (benign tumor). As the output has only two classes, this will be a binary classification task. We have then divided the dataset into train (80%) and test (20%) sets. The code for this section is almost the same as the data loading code from the regression section.

In [20]:

```
from sklearn import datasets
from sklearn.model_selection import train_test_split
X, Y = datasets.load_breast_cancer(return_X_y=True)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, train_size=0.8, stratify=Y, random_state=123)
X_train, X_test, Y_train, Y_test = jnp.array(X_train, dtype=jnp.float32),\
jnp.array(X_test, dtype=jnp.float32),\
jnp.array(Y_train, dtype=jnp.float32),\
jnp.array(Y_test, dtype=jnp.float32)
samples, features = X_train.shape
classes = jnp.unique(Y_test)
X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
```

In [21]:

```
samples, features, classes
```

In this section, we have normalized our train and test datasets based on a mean and standard deviation of the training dataset for each feature.

In [22]:

```
mean = X_train.mean(axis=0)
std = X_train.std(axis=0)
X_train = (X_train - mean) / std
X_test = (X_test - mean) / std
```

In this section, we have created a neural network that we'll be using for classification. Our neural network for this task is almost exactly the same as our neural network from the regression section with one minor change. We have applied the **sigmoid** function as the activation function of the last layer. The sigmoid function returns output in the range (0, 1), which we can interpret as the probability of class 1.

Then in the next two cells, we have initialized weights of the neural network, printed their shape, and also made predictions on the first few training samples for verification purposes that our functions work well.

In [23]:

```
neural_net_init, neural_net_apply = stax.serial(
    stax.Dense(5),
    stax.Relu,
    stax.Dense(10),
    stax.Relu,
    stax.Dense(15),
    stax.Relu,
    stax.Dense(1),
    stax.Sigmoid
)
```

In [24]:

```
rng = jax.random.PRNGKey(123)
weights = neural_net_init(rng, (features,))
weights = weights[1] ## Weights are actually stored in second element of two value tuple
for w in weights:
    if w: ## Activation layers have no weights (empty tuple)
        w, b = w
        print("Weights : {}, Biases : {}".format(w.shape, b.shape))
```

In [25]:

```
preds = neural_net_apply(weights, X_train[:5])
preds
```

In this section, we have defined a loss function that we'll be using for our purpose. We'll be using the log loss function as a loss function for our binary classification task.

`log_loss(predictions, actuals) = 1/n * sum(- actuals * log(predictions) - (1 - actuals) * log(1 - predictions))`

The loss function takes weights, input features data, and actual target values as input. It then makes predictions using weights and input data. Then, it calculates log loss values based on predictions and actual target values.
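As a small worked example of the log loss, the sketch below evaluates it on two made-up samples. The clipping step is an addition made for this sketch (it is not part of the tutorial's loss function) and only guards against taking the log of 0.

```python
import jax.numpy as jnp

toy_actual = jnp.array([1.0, 0.0])
toy_probs = jnp.array([0.9, 0.2])    ## Hypothetical predicted probabilities of class 1

## Clipping is an addition for this sketch; it avoids log(0)
eps = 1e-7
clipped = jnp.clip(toy_probs, eps, 1 - eps)

per_sample = -toy_actual * jnp.log(clipped) - (1 - toy_actual) * jnp.log(1 - clipped)
toy_log_loss = per_sample.mean()     ## (-log(0.9) - log(0.8)) / 2
print(per_sample, toy_log_loss)
```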

In [26]:

```
def NegLogLoss(weights, input_data, actual):
    preds = neural_net_apply(weights, input_data)
    preds = preds.squeeze() ## Drop the trailing dimension of size 1
    return (- actual * jnp.log(preds) - (1 - actual) * jnp.log(1 - preds)).mean()
```

In this section, we'll be training our neural network for the binary classification task. We have simply copied code to train the neural network from the regression section. The code is almost exactly the same as the regression section with only a change in the loss function.

In [27]:

```
from jax import grad, value_and_grad

def TrainModel(X, Y, epochs, opt_state):
    for i in range(1,epochs+1):
        loss, gradients = value_and_grad(NegLogLoss)(opt_get_weights(opt_state), X, Y)

        ## Update Weights
        opt_state = opt_update(i, gradients, opt_state)

        if i%100 ==0: ## Print NegLogLoss every 100 epochs
            print("NegLogLoss : {:.2f}".format(loss))

    return opt_state
```

Now, we are actually training our neural network by calling the training function we designed in the previous cell. We have first initialized the seed for random numbers, the learning rate (**0.0001**), and the number of epochs (**1500**). We have then initialized the weights of the neural network using the **init_fun** function by providing the seed to it as usual.

Then we have initialized the optimizer that we'll use for optimizing the weights of the neural network. This time we are using the **rmsprop()** optimizer. **RMSProp (Root Mean Squared Propagation)** extends **SGD**, which we used in the regression section. It scales the step for each weight by a running average of the squared gradients of that weight, which often gives better performance and faster convergence. We have then called the init method of the optimizer with the network weights to create the initial optimizer state. We have then called the training routine with train data, train target values, the number of epochs, and the optimizer state (weights). We can notice from the log loss getting printed every 100 epochs that our neural network seems to be doing a good job.
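The difference between the two optimizers can be seen on a single update step. The sketch below uses made-up parameters and gradients on very different scales; the helper `one_step` is hypothetical and only measures how far each parameter moves.

```python
import jax.numpy as jnp
from jax.example_libraries import optimizers

toy_params = jnp.array([1.0, 1.0])
toy_grads = jnp.array([0.01, 100.0])   ## Hypothetical gradients on very different scales

def one_step(optimizer):
    init_fn, update_fn, params_fn = optimizer
    state = update_fn(0, toy_grads, init_fn(toy_params))
    return toy_params - params_fn(state)    ## Size of the step taken per parameter

sgd_step = one_step(optimizers.sgd(0.1))
rmsprop_step = one_step(optimizers.rmsprop(0.1))
print(sgd_step)       ## Proportional to the raw gradient
print(rmsprop_step)   ## Roughly the same size for both parameters
```

With plain SGD the second parameter moves far more than the first, whereas RMSProp normalizes each gradient by its own running magnitude, so both parameters take steps of similar size.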

In [28]:

```
seed = jax.random.PRNGKey(123)
learning_rate = jnp.array(1/1e4)
epochs = 1500
weights = neural_net_init(seed, (features,)) ## Use the seed defined above
weights = weights[1]
opt_init, opt_update, opt_get_weights = optimizers.rmsprop(learning_rate)
opt_state = opt_init(weights)
final_opt_state = TrainModel(X_train, Y_train, epochs, opt_state)
```

In this section, we are making predictions on train and test datasets using our updated weights and neural network. We are calling the **apply_fun** function of the neural network with weights and features data to make predictions. The output of our neural network is a probability in the range (0, 1) due to the sigmoid activation function. We need to convert these probabilities to the actual class of prediction. In order to find the class from the probability, we have set a threshold of 0.5 so that values less than or equal to 0.5 are predicted as class 0 (malignant tumor), and values greater than 0.5 are predicted as class 1 (benign tumor).
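The thresholding step can be sketched on a handful of made-up probabilities (`toy_probs` is hypothetical). Note that a value of exactly 0.5 falls into class 0 because the comparison is strict.

```python
import jax.numpy as jnp

toy_probs = jnp.array([0.10, 0.49, 0.50, 0.51, 0.97])   ## Hypothetical sigmoid outputs

## Strictly greater than 0.5 -> class 1, everything else -> class 0
toy_classes = (toy_probs > 0.5).astype(jnp.float32)
print(toy_classes)
```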

In [29]:

```
test_preds = neural_net_apply(opt_get_weights(final_opt_state), X_test) ## Make Predictions on test dataset
test_preds = test_preds.ravel() ## Flatten predictions to a 1D array
test_preds = (test_preds > 0.5).astype(jnp.float32) ## Convert probabilities to classes
test_preds[:5], Y_test[:5]
```

In [30]:

```
train_preds = neural_net_apply(opt_get_weights(final_opt_state), X_train) ## Make Predictions on train dataset
train_preds = train_preds.ravel() ## Flatten predictions to a 1D array
train_preds = (train_preds > 0.5).astype(jnp.float32) ## Convert probabilities to classes
train_preds[:5], Y_train[:5]
```

In this section, we have evaluated the performance of our model by calculating the accuracy of train and test predictions. We can notice from the results that our model seems to be doing a decent job at prediction. We have used **accuracy_score()** method available from scikit-learn to calculate accuracy.

Then in the next cell, we have printed classification report of our test predictions which includes information like precision, recall, and f1-score for each class.
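For readers who want to see what these metrics measure, the sketch below computes accuracy, precision, recall, and f1-score for class 1 by hand on a few made-up labels. It is only an illustration of the definitions, not the scikit-learn implementation.

```python
import jax.numpy as jnp

toy_actual = jnp.array([1, 0, 1, 1, 0])
toy_predicted = jnp.array([1, 0, 0, 1, 1])

accuracy = (toy_actual == toy_predicted).mean()   ## Fraction of correct predictions

## Counts for class 1
true_pos = jnp.sum((toy_predicted == 1) & (toy_actual == 1))
false_pos = jnp.sum((toy_predicted == 1) & (toy_actual == 0))
false_neg = jnp.sum((toy_predicted == 0) & (toy_actual == 1))

precision = true_pos / (true_pos + false_pos)   ## Of predicted class 1, how many are right
recall = true_pos / (true_pos + false_neg)      ## Of actual class 1, how many are found
f1 = 2 * precision * recall / (precision + recall)
print(accuracy, precision, recall, f1)
```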

We recommend that you go through our tutorial on ML metrics available through scikit-learn if you don't have a background on these metrics to better understand them.

In [31]:

```
from sklearn.metrics import accuracy_score
print("Train Accuracy : {:.2f}".format(accuracy_score(Y_train, train_preds)))
print("Test Accuracy : {:.2f}".format(accuracy_score(Y_test, test_preds)))
```

In [32]:

```
from sklearn.metrics import classification_report
print("Test Classification Report ")
print(classification_report(Y_test, test_preds))
```

In this section, we have explained how we can perform training in batches of data. We have copied the code from the regression section to perform training on batches of data. The only difference in code is that we are using the log loss function here. The rest of the code is exactly the same as earlier.

In [33]:

```
def TrainModelInBatches(X, Y, epochs, opt_state, batch_size=32):
    for i in range(1, epochs+1):
        losses = [] ## Record loss of each batch
        for start in range(0, X.shape[0], batch_size): ## Start index of each batch
            end = start + batch_size ## Slicing stops at the last sample for the final batch
            X_batch, Y_batch = X[start:end], Y[start:end] ## Single batch of data

            loss, gradients = value_and_grad(NegLogLoss)(opt_get_weights(opt_state), X_batch, Y_batch)

            ## Update Weights
            opt_state = opt_update(i, gradients, opt_state)
            losses.append(loss) ## Record Loss

        if i % 100 == 0: ## Print NegLogLoss every 100 epochs
            print("NegLogLoss : {:.2f}".format(jnp.array(losses).mean()))

    return opt_state
```

Below, we are actually training our neural network by calling the training function from the previous cell. We have first initialized the seed for generating random numbers, the learning rate (**0.0001**), and the number of epochs (**200**). Then, we have initialized the weights of the neural network by calling the **init_fun** function with the seed and input shape. We have then initialized the **RMSProp** optimizer that we'll be using for optimizing weights. We have then called the training function to actually perform training with our train data in batches. We can notice from the loss value getting printed that our model seems to be doing a decent job.

In [34]:

```
seed = jax.random.PRNGKey(123)
learning_rate = jnp.array(1/1e4)
epochs = 200
weights = neural_net_init(seed, (features,)) ## Use the seed defined above
weights = weights[1]
opt_init, opt_update, opt_get_weights = optimizers.rmsprop(learning_rate)
opt_state = opt_init(weights)
final_opt_state = TrainModelInBatches(X_train, Y_train, epochs, opt_state)
```

In this section, we have made predictions on train and test datasets in batches. We have copied the function from the regression section which we had used there to make predictions in batches. We have then combined predictions of all batches and converted probabilities to prediction classes by setting the threshold at 0.5.

In [35]:

```
def MakePredictions(weights, input_data, batch_size=32):
    preds = []
    for start in range(0, input_data.shape[0], batch_size): ## Start index of each batch
        X_batch = input_data[start:start+batch_size] ## Slicing stops at the last sample for the final batch
        preds.append(neural_net_apply(weights, X_batch))

    return preds
```

In [36]:

```
test_preds = MakePredictions(opt_get_weights(final_opt_state), X_test)
test_preds = jnp.concatenate(test_preds).squeeze() ## Combine predictions of all batches
test_preds = (test_preds > 0.5).astype(jnp.float32)
train_preds = MakePredictions(opt_get_weights(final_opt_state), X_train)
train_preds = jnp.concatenate(train_preds).squeeze() ## Combine predictions of all batches
train_preds = (train_preds > 0.5).astype(jnp.float32)
test_preds[:5], train_preds[:5]
```

In this section, we have evaluated the performance of our model by calculating the accuracy of our train and test predictions. We have then also printed a classification report of our test predictions.

In [37]:

```
from sklearn.metrics import accuracy_score
print("Train Accuracy : {:.2f}".format(accuracy_score(Y_train, train_preds)))
print("Test Accuracy : {:.2f}".format(accuracy_score(Y_test, test_preds)))
```

In [38]:

```
from sklearn.metrics import classification_report
print("Test Classification Report ")
print(classification_report(Y_test, test_preds))
```

This ends our small tutorial explaining how we can use high-level **JAX** API available through **'stax'** and **'optimizers'** module to create neural networks. Please feel free to let us know your views in the comments section.
