**Flax** is a deep learning framework built on top of **JAX**. Its high-level API lets us create neural networks easily. **JAX** is a library that provides NumPy-like arrays (and functions to work on them) on CPUs, GPUs, and TPUs, along with automatic differentiation of functions that operate on those arrays. Neural networks created using **Flax** can run fast because they can take advantage of **JAX** transformations like JIT (just-in-time compilation), vmap (vectorization), pmap (parallelization), etc. As **Flax** is built on top of **JAX**, a little **JAX** background is required. If you want to learn about **JAX**, please feel free to check our small tutorial on it. We recommend going through our **JAX** tutorial and the other **JAX** tutorials mentioned in the reference section, as they will help with this tutorial.
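To make the **JAX** features mentioned above a little more concrete, here is a minimal sketch that is not part of the original tutorial. The toy function **f** is our own illustration: **jax.grad** differentiates it, **jax.jit** compiles it, and **jax.vmap** maps it over a leading batch dimension.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Toy scalar function (illustration only): f(x) = sum of squares of x
    return jnp.sum(x ** 2)

grad_f = jax.grad(f)      # automatic differentiation: gradient is 2 * x
fast_f = jax.jit(f)       # just-in-time compiled version of f
batched_f = jax.vmap(f)   # applies f to each row of a batched input

x = jnp.array([1.0, 2.0, 3.0])
print(grad_f(x))                          # [2. 4. 6.]
print(fast_f(x))                          # 14.0
print(batched_f(jnp.stack([x, 2 * x])))   # [14. 56.]
```

The same transformations compose, so a jitted, vmapped gradient such as **jax.jit(jax.vmap(jax.grad(f)))** is also valid.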

As a part of this tutorial, we'll explain how to create simple neural networks using **Flax** to solve simple regression and classification tasks, using small toy datasets available from scikit-learn. **Flax** previously had its own implementation of optimizers, but that sub-module was deprecated in favor of a separate library named **Optax**, which implements the majority of commonly used optimizers. Hence, we'll be using optimizers from **Optax** to optimize our loss functions.

The tutorial assumes that the reader has some background on neural networks and their parts like optimizers, layers, activation functions, loss functions, etc., because the main aim of the tutorial is to get individuals started designing neural networks using the **Flax API**, not to explain in depth how neural networks work.

Below we have highlighted important sections of the tutorial to give an overview of the material covered.

- Regression
    - Load Dataset
    - Normalize Data
    - Define Neural Network
    - Define Loss Function
    - Train Model
    - Make Predictions
    - Evaluate Model Performance
- Classification

**pip install --upgrade jax jaxlib**

**pip install flax**

Below, we have imported the necessary libraries that we'll use in our tutorial and printed their versions.

In [1]:

```
import flax
print("Flax Version : {}".format(flax.__version__))
```

In [2]:

```
import jax
print("Jax Version : {}".format(jax.__version__))
```

In [3]:

```
import optax
print("Optax Version : {}".format(optax.__version__))
```

In this section, we'll explain how we can create simple neural networks to solve regression tasks. We'll be using the Boston housing dataset available from scikit-learn for our purposes.

In this section, we have loaded the Boston housing dataset available from scikit-learn. We have loaded the data features (independent variables) into variable **X** and the target values into variable **Y**. The target values are median house prices in thousands of dollars. The features describe the houses and their surrounding area.

After loading the dataset, we have divided it into the train (80%) and test (20%) sets. We have then converted datasets from numpy to **JAX** arrays.

In [4]:

```
from sklearn import datasets
from sklearn.model_selection import train_test_split
from jax import numpy as jnp

## NOTE: load_boston() was removed in scikit-learn 1.2, so this cell needs scikit-learn < 1.2
X, Y = datasets.load_boston(return_X_y=True)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, train_size=0.8, random_state=123)

## Convert numpy arrays to float32 JAX arrays
X_train, X_test, Y_train, Y_test = jnp.array(X_train, dtype=jnp.float32),\
                                   jnp.array(X_test, dtype=jnp.float32),\
                                   jnp.array(Y_train, dtype=jnp.float32),\
                                   jnp.array(Y_test, dtype=jnp.float32)

samples, features = X_train.shape

X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
```


In [5]:

```
samples, features
```


In this section, we have normalized our train and test datasets. Normalization brings features that are on very different scales onto a common scale, which helps optimization algorithms like gradient descent converge faster.

To normalize the data, we first calculated the per-feature mean and standard deviation of the train data. We then subtracted the mean from the train and test datasets and divided the results by the standard deviation. The test set is scaled with the train statistics so that no information from the test data leaks into training.

In [6]:

```
mean = X_train.mean(axis=0)
std = X_train.std(axis=0)
X_train = (X_train - mean) / std
X_test = (X_test - mean) / std
```
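As a quick sanity check of the normalization recipe, the sketch below applies it to made-up data; the arrays **X_train_demo** and **X_test_demo** are hypothetical and are not the Boston data. After scaling, every training feature has a mean close to 0 and a standard deviation close to 1, while the test set is only approximately standardized because it reuses the training statistics.

```python
import jax
import jax.numpy as jnp

# Hypothetical data: three features on very different scales
key_train, key_test = jax.random.split(jax.random.PRNGKey(0))
scales = jnp.array([1.0, 100.0, 10000.0])
X_train_demo = jax.random.normal(key_train, (200, 3)) * scales + 5.0
X_test_demo = jax.random.normal(key_test, (50, 3)) * scales + 5.0

# Statistics are computed on the training data only
demo_mean = X_train_demo.mean(axis=0)
demo_std = X_train_demo.std(axis=0)

X_train_norm = (X_train_demo - demo_mean) / demo_std
X_test_norm = (X_test_demo - demo_mean) / demo_std  # reuse training statistics

print(X_train_norm.mean(axis=0))  # close to 0 for every feature
print(X_train_norm.std(axis=0))   # close to 1 for every feature
```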

In this section, we have created a neural network using **Flax**. **Flax** provides a module named **linen** that has all the layers required to create neural networks.

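In order to create a neural network using **Flax**, we need to create a class that extends the **linen.Module** class. We then need to define the **setup()** and **`__call__()`** methods. Inside **setup()**, we create the layers of the network, and inside **`__call__()`**, we define the forward pass that applies those layers to the input.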

Below, we have first created a class representing our neural network by extending the **linen.Module** class. We have then declared a **features** class variable that holds the layer sizes. Inside the **setup()** method, we have created a list of linear/dense layers using that **features** variable. Inside the **`__call__()`** method, we have looped through the layers initialized in **setup()**, applying each layer to the input and applying **relu** activation to the output of every layer except the last one.

After defining our neural network, we have created an instance of it. We can initialize the weights of the neural network by calling the **init()** method on it. To initialize the weights, we need to provide a pseudo-random number generator key and sample input data to the **init()** method; the sample data lets **Flax** infer the input shapes of the layers. It returns a dictionary-like object that holds the parameters/weights of the neural network under the **'params'** key. We have printed the shapes of the weights and biases of each layer by looping through this parameters dictionary.

Then, in the next cell, we have performed a forward pass through the neural network by calling the **apply()** method on it with the parameters and sample data. We have then printed the predictions to verify that the network is working as expected.

In [7]:

```
from typing import Sequence
from jax import random
import jax.numpy as jnp
from flax import linen

class MultiLayerPerceptronRegressor(linen.Module):
    features: Sequence[int] = (5,10,15,1)  ## Output size of each Dense layer

    def setup(self):
        ## Create one Dense layer per entry in features
        self.layers = [linen.Dense(feat) for feat in self.features]

    def __call__(self, inputs):
        x = inputs
        for i, lyr in enumerate(self.layers):
            x = lyr(x)
            if i != len(self.layers) - 1:  ## No activation on the output layer
                x = linen.relu(x)
        return x

seed = random.PRNGKey(0)
model = MultiLayerPerceptronRegressor()
params = model.init(seed, X_train[:5])  ## A batch of 5 samples is used to infer shapes

## Print the weight and bias shapes of each layer
for layer_name, layer_params in params["params"].items():
    print("Layer Name : {}".format(layer_name))
    weights, biases = layer_params["kernel"], layer_params["bias"]
    print("\tLayer Weights : {}, Biases : {}".format(weights.shape, biases.shape))
```

In [8]:

```
preds = model.apply(params, X_train[:5])
preds
```


In this section, we have defined the loss function for our neural network, which is the mean squared error loss function. We'll be calculating gradients with respect to this function. The function takes weights, input data, and actual target values as input. It first makes predictions using the **apply()** method of the neural network and then calculates the loss from the predictions and the actual target values.

The mean squared error loss is simply the average of the squared differences between actual target values and predictions. We return the scalar loss value from the function.

$$\text{MSE}(\hat{y}, y) = \frac{1}{n}\sum_{i=1}^{n}\left(y_i - \hat{y}_i\right)^2$$

where $y_i$ are the actual target values, $\hat{y}_i$ the predictions, and $n$ the number of samples.

In [9]:

```
def MeanSquaredErrorLoss(weights, input_data, actual):
    preds = model.apply(weights, input_data)
    ## squeeze() turns (n, 1) predictions into (n,) so they line up with actual
    return jnp.power(actual - preds.squeeze(), 2).mean()
```
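Here is a small worked example of the MSE formula using made-up numbers. It also shows why the loss function calls **squeeze()**: the network returns predictions of shape **(n, 1)**, and subtracting those from targets of shape **(n,)** would silently broadcast to an **(n, n)** matrix and produce the wrong loss.

```python
import jax.numpy as jnp

# Made-up targets and predictions
actual_demo = jnp.array([3.0, -0.5, 2.0, 7.0])
preds_demo = jnp.array([2.5, 0.0, 2.0, 8.0])

# Squared errors: 0.25, 0.25, 0.0, 1.0 -> mean = 1.5 / 4 = 0.375
mse = jnp.mean((actual_demo - preds_demo) ** 2)
print(mse)  # 0.375

# Model outputs have shape (n, 1); without squeeze() they broadcast to (n, n)
preds_2d = preds_demo[:, None]
print((actual_demo - preds_2d).shape)                       # (4, 4)
print(jnp.mean((actual_demo - preds_2d.squeeze()) ** 2))    # 0.375 again
```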

In this section, we have trained our neural network. We have first set the number of epochs to **1000** and initialized a pseudo-random number generator key that will be used to initialize the model weights. We have then created an instance of the neural network and initialized its weights by calling the **init()** method with the key and a batch of random sample data.

Then, we have created a gradient descent optimizer using the **optax** library. The optimizer that we'll be using for training is **sgd()**, initialized with a learning rate of **0.001**. We have then initialized the optimizer state by calling the **init()** method on the optimizer with the neural network weights. The optimizer state holds whatever the optimizer tracks between steps (for plain SGD it is essentially empty, while optimizers like Adam store running moment estimates in it); the weights themselves stay in **params** and are updated in the training loop.

Then, in the next line, we have created another function by wrapping our loss function with the **value_and_grad()** function of **JAX**. The returned function takes the same parameters as our original loss function and returns two values: the loss computed for the given weights, input data, and actual target values, and the gradients of the loss with respect to the first parameter of the loss function, which is the weights of the neural network. We'll use this function during the training loop to calculate gradients.

After all initializations, we run the training loop for the given number of epochs. In each epoch, we first calculate the loss value and gradients using the function created earlier; this also performs a forward pass through the network, because the loss function calls **apply()**. Then, we call the **update()** method on the optimizer with the gradients and the optimizer state, which returns the updates and a new optimizer state. Finally, we update the model weights by calling the **apply_updates()** function of **optax** with the current weights and the updates. Because the whole training set is used at every step, each epoch here is a single full-batch gradient descent step.

We print the loss value every 100 epochs. We can notice from the printed loss values that our model is doing a decent job.

We recommend that readers go through our other tutorials on creating neural networks using **JAX**, as they will help in better understanding **JAX** and the frameworks built on it.

In [10]:

```
seed = random.PRNGKey(0)
epochs=1000

model = MultiLayerPerceptronRegressor() ## Define Model
random_arr = jax.random.normal(key=seed, shape=(5, features))
params = model.init(seed, random_arr) ## Initialize Model Parameters

optimizer = optax.sgd(learning_rate=1/1e3) ## Initialize SGD Optimizer using OPTAX
optimizer_state = optimizer.init(params)

loss_grad = jax.value_and_grad(MeanSquaredErrorLoss)

for i in range(1,epochs+1):
    loss_val, gradients = loss_grad(params, X_train, Y_train) ## Calculate Loss and Gradients
    updates, optimizer_state = optimizer.update(gradients, optimizer_state)
    params = optax.apply_updates(params, updates) ## Update weights

    if i % 100 == 0:
        print('MSE After {} Epochs : {:.2f}'.format(i, loss_val))
```

In this section, we are making predictions on the train and test sets by calling the **apply()** method on the network with the updated model weights and input data.

In [11]:

```
test_preds = model.apply(params, X_test) ## Make Predictions on test dataset
test_preds = test_preds.ravel()
train_preds = model.apply(params, X_train) ## Make Predictions on train dataset
train_preds = train_preds.ravel()
```

In this section, we have evaluated the performance of our network by calculating the **r^2 score** on train and test predictions. The **r^2 score** is at most **1**, which corresponds to perfect predictions; a score of **0** means the model does no better than always predicting the mean target value, and the score can even be negative for worse models. Values near **1** indicate a good model. We have calculated the **r^2 score** using the **r2_score()** function available from scikit-learn, passing the actual target values first and the predictions second. We can notice from the scores on our train and test predictions that our model seems to be doing a good job.

If you want to learn about the **r^2 score** and other metrics provided by scikit-learn for ML tasks, please feel free to check our tutorial that covers the majority of them in detail.

In [12]:

```
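import numpy as np
from sklearn.metrics import r2_score

## r2_score expects (y_true, y_pred); np.asarray() converts JAX arrays to numpy arrays
print("Train R^2 Score : {:.2f}".format(r2_score(np.asarray(Y_train), np.asarray(train_preds))))
print("Test R^2 Score : {:.2f}".format(r2_score(np.asarray(Y_test), np.asarray(test_preds))))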
```
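To illustrate the properties of the **r^2 score** described above, the sketch below computes it by hand as one minus the ratio of the residual sum of squares to the total sum of squares, using made-up numbers, and compares the result with scikit-learn's **r2_score()**. It also shows that swapping the arguments changes the score and that a poor enough model gets a negative score. The helper **r2_manual** is our own illustration.

```python
import numpy as np
from sklearn.metrics import r2_score

# Made-up targets and predictions
y_true = np.array([3.0, -0.5, 2.0, 7.0])
y_pred = np.array([2.5, 0.0, 2.0, 8.0])

def r2_manual(y_true, y_pred):
    ss_res = np.sum((y_true - y_pred) ** 2)         # residual sum of squares
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)  # total sum of squares
    return 1.0 - ss_res / ss_tot

print(r2_manual(y_true, y_pred), r2_score(y_true, y_pred))  # both about 0.949
print(r2_score(y_pred, y_true))              # about 0.957: argument order matters
print(r2_score(y_true, np.full(4, 100.0)))   # negative: worse than predicting the mean
```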

In this section, we have explained how we can create simple neural networks for classification tasks. We'll be using the breast cancer dataset available from scikit-learn. We'll be reusing the majority of the code that we created in the regression section; hence, we haven't included a detailed description of repeated code parts. Please feel free to check their description in the regression section if you have started directly from the classification section.

In this section, we have loaded the breast cancer dataset available from scikit-learn. The features of the dataset are various measurements of tumors, and the target value is binary (0 - malignant tumor, 1 - benign tumor). As our target values are binary, this is a binary classification task.

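After loading the dataset, we have divided it into train (80%) and test (20%) sets, stratified by the target so that both sets keep the same class proportions, and converted them to **JAX** arrays.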

In [13]:

```
from sklearn import datasets
from sklearn.model_selection import train_test_split
from jax import numpy as jnp

X, Y = datasets.load_breast_cancer(return_X_y=True)
## stratify=Y keeps the class proportions the same in the train and test sets
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, train_size=0.8, stratify=Y, random_state=123)

X_train, X_test, Y_train, Y_test = jnp.array(X_train, dtype=jnp.float32),\
                                   jnp.array(X_test, dtype=jnp.float32),\
                                   jnp.array(Y_train, dtype=jnp.float32),\
                                   jnp.array(Y_test, dtype=jnp.float32)

samples, features = X_train.shape
classes = jnp.unique(Y_test)

X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
```
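Since it is easy to mix up which class is encoded as 1, the snippet below (our own check, not part of the original notebook) inspects the dataset's **target_names** attribute. scikit-learn lists 'malignant' first, so label 0 means malignant and label 1 means benign.

```python
import numpy as np
from sklearn import datasets

data = datasets.load_breast_cancer()
# target_names[i] is the name of the class encoded as label i
print(data.target_names)          # ['malignant' 'benign']
print(np.bincount(data.target))   # samples per label: [212 357]
```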


In [12]:

```
samples, features, classes
```


In this section, we have normalized our train and test sets by using the mean and standard deviation calculated on the train set. As we had explained earlier, it helps with faster convergence.

In [13]:

```
mean = X_train.mean(axis=0)
std = X_train.std(axis=0)
X_train = (X_train - mean) / std
X_test = (X_test - mean) / std
```

In this section, we have defined a neural network to perform the binary classification task. The code is almost exactly the same as in the regression section, with one minor change. We have designed the same neural network with layer sizes **[5,10,15,1]** as earlier; the only difference is that we have applied the **sigmoid** activation to the output of the last layer. The **sigmoid** function maps input values to floats in the range **(0,1)**, so the output of our neural network will be a float in that range, which can be interpreted as the probability of class 1. We'll later convert these floats to actual prediction classes (0 - malignant, 1 - benign).

In [15]:

```
from typing import Sequence
from jax import random
import jax.numpy as jnp
from flax import linen

class MultiLayerPerceptronClassifier(linen.Module):
    features: Sequence[int] = (5,10,15,1)

    def setup(self):
        self.layers = [linen.Dense(feat) for feat in self.features]

    def __call__(self, inputs):
        x = inputs
        for i, lyr in enumerate(self.layers):
            x = lyr(x)
            if i != len(self.layers) - 1:
                x = linen.relu(x)
        return linen.sigmoid(x)

seed = random.PRNGKey(0)
model = MultiLayerPerceptronClassifier()
params = model.init(seed, X_train[:5])

for layer_params in params["params"].items():
    print("Layer Name : {}".format(layer_params[0]))
    weights, biases = layer_params[1]["kernel"], layer_params[1]["bias"]
    print("\tLayer Weights : {}, Biases : {}".format(weights.shape, biases.shape))
```

In [17]:

```
preds = model.apply(params, X_train[:5])
preds
```


In this section, we have defined the loss function for our neural network. We'll be using the log loss (also called binary cross-entropy) for our network. The function takes weights, input data, and actual target values as input. It makes predictions using the weights and input data and then calculates the loss from the predictions and the actual target values.

$$\text{LogLoss}(\hat{y}, y) = \frac{1}{n}\sum_{i=1}^{n}\left[-y_i \log(\hat{y}_i) - (1 - y_i)\log(1 - \hat{y}_i)\right]$$

where $y_i \in \{0, 1\}$ are the actual target values and $\hat{y}_i$ the predicted probabilities.

In [20]:

```
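def NegLogLoss(weights, input_data, actual):
    preds = model.apply(weights, input_data)
    preds = preds.squeeze()  ## (n, 1) -> (n,) to match actual
    ## Keep probabilities away from exactly 0 and 1 so that log() stays finite
    preds = jnp.clip(preds, 1e-7, 1 - 1e-7)
    return (- actual * jnp.log(preds) - (1 - actual) * jnp.log(1 - preds)).mean()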
```

In this section, we have included code to train the neural network. The code is almost exactly the same as in the regression section; the only changes are that we use the log loss function as our loss and a learning rate of **0.01**. We train the network for **1000** epochs. We can notice from the loss values printed every 100 epochs that our model seems to be doing a good job.

In [26]:

```
seed = random.PRNGKey(0)
epochs=1000

model = MultiLayerPerceptronClassifier() ## Define Model
random_arr = jax.random.normal(key=seed, shape=(5, features))
params = model.init(seed, random_arr) ## Initialize Model Parameters

optimizer = optax.sgd(learning_rate=1/1e2) ## Initialize SGD Optimizer
optimizer_state = optimizer.init(params)

loss_grad = jax.value_and_grad(NegLogLoss)

for i in range(1, epochs+1):
    loss_val, gradients = loss_grad(params, X_train, Y_train)
    updates, optimizer_state = optimizer.update(gradients, optimizer_state)
    params = optax.apply_updates(params, updates)

    if i % 100 == 0:
        print('NegLogLoss After {} Epochs : {:.2f}'.format(i, loss_val))
```

In this section, we have made predictions on the train and test sets by calling the **apply()** method on the model instance with the updated model weights and input data. As the outputs of our model are floats in the range **(0,1)** due to the **sigmoid** activation function, we need to convert these floats to actual prediction classes. To do this, we use a threshold of 0.5: values greater than 0.5 are predicted as class 1, and all other values as class 0.

In [27]:

```
test_preds = model.apply(params, X_test) ## Make Predictions on test dataset
test_preds = test_preds.ravel()
test_preds = (test_preds > 0.5).astype(jnp.float32)
test_preds[:5], Y_test[:5]
```


In [28]:

```
train_preds = model.apply(params, X_train) ## Make Predictions on train dataset
train_preds = train_preds.ravel()
train_preds = (train_preds > 0.5).astype(jnp.float32)
train_preds[:5], Y_train[:5]
```


In this section, we have evaluated the performance of our model by calculating the accuracy of train and test predictions. We have also printed a classification report for the test predictions, which includes the precision, recall, and f1-score of each class. We can notice from these metrics that our model seems to be doing a decent job.

In [29]:

```
from sklearn.metrics import accuracy_score
print("Train Accuracy : {:.2f}".format(accuracy_score(Y_train, train_preds)))
print("Test Accuracy : {:.2f}".format(accuracy_score(Y_test, test_preds)))
```

In [30]:

```
from sklearn.metrics import classification_report
print("Test Data Classification Report : ")
print(classification_report(Y_test, test_preds))
```
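To make the 0.5 threshold and the reported metrics concrete, here is a tiny worked example with made-up probabilities. Predictions above 0.5 become class 1, accuracy is the fraction of matching labels, and precision and recall for class 1 follow from the true-positive, false-positive, and false-negative counts. scikit-learn's **accuracy_score()**, **precision_score()**, and **recall_score()** return the same numbers.

```python
import numpy as np
import jax.numpy as jnp
from sklearn.metrics import accuracy_score, precision_score, recall_score

y_true = jnp.array([0.0, 1.0, 1.0, 0.0, 1.0])  # made-up labels
probs = jnp.array([0.2, 0.7, 0.4, 0.1, 0.9])   # made-up sigmoid outputs
y_pred = (probs > 0.5).astype(jnp.float32)      # [0, 1, 0, 0, 1]

accuracy = jnp.mean(y_pred == y_true)            # 4 of 5 correct -> 0.8
tp = jnp.sum((y_pred == 1) & (y_true == 1))      # 2 true positives
fp = jnp.sum((y_pred == 1) & (y_true == 0))      # 0 false positives
fn = jnp.sum((y_pred == 0) & (y_true == 1))      # 1 false negative
precision = tp / (tp + fp)                       # 2 / 2 = 1.0
recall = tp / (tp + fn)                          # 2 / 3

# Integer numpy copies for the scikit-learn metrics
y_true_np = np.asarray(y_true).astype(int)
y_pred_np = np.asarray(y_pred).astype(int)
print(float(accuracy), accuracy_score(y_true_np, y_pred_np))
print(float(precision), precision_score(y_true_np, y_pred_np))
print(float(recall), recall_score(y_true_np, y_pred_np))
```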

This ends our small tutorial explaining how we can use **Flax** to create neural networks. Please feel free to let us know your views in the comments section.

- Guide to Create Neural Networks using High-level JAX API
- Guide to Create Simple Neural Networks using JAX
- JAX - (Numpy + Automatic Gradients) on Accelerators (GPUs/TPUs)
- Create Simple PyTorch Neural Networks using 'torch.nn' Module
- Guide to Create Simple Neural Networks using PyTorch
- Sonnet: Guide to Create Simple Neural Networks
- Scikit-Learn - Neural Network
- MXNet: Guide to Create Neural Networks
- Flax: Convolutional Neural Networks



Sunny Solanki