Flax is a deep learning framework built on top of JAX. It lets us create neural networks easily using its high-level API. JAX is a library that provides NumPy-like arrays (and functions to work on them) on CPUs/GPUs/TPUs along with automatic differentiation of functions working with those arrays. Neural networks created using Flax are fast because they can utilize the optimization functionalities provided by JAX like JIT (just-in-time compilation), vmap (vectorization), pmap (parallelization), etc. As Flax is built on top of JAX, a little JAX background is required. If you want to learn about JAX, please feel free to check our small tutorial on it. We recommend that you go through our JAX tutorial and the other JAX tutorials mentioned in the reference section to understand JAX better, as it'll help with this tutorial.
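Below is a minimal sketch (not part of the tutorial code) illustrating the JAX features mentioned above: NumPy-like arrays, automatic differentiation with jax.grad(), and just-in-time compilation with jax.jit(). The function squared_sum is only an illustrative example.
import jax
import jax.numpy as jnp

x = jnp.arange(5.0)                  ## NumPy-like array usable on CPU/GPU/TPU

def squared_sum(arr):
    return jnp.sum(arr ** 2)

grad_fn = jax.grad(squared_sum)      ## Automatic differentiation
fast_grad_fn = jax.jit(grad_fn)      ## Just-in-time compilation of the gradient function

print(fast_grad_fn(x))               ## [0. 2. 4. 6. 8.], i.e. 2*x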
As a part of this tutorial, we'll explain how we can create simple neural networks using Flax to solve simple regression and classification tasks. We'll be using small toy datasets available from scikit-learn for our purposes. Flax earlier had its own implementation of optimizers, but that sub-module has been deprecated in favor of a new library named Optax, which provides implementations of the majority of optimizers. Hence, we'll be using optimizers from Optax in this tutorial to optimize our loss functions.
The tutorial requires that the reader has a little background on neural networks and their parts like optimizers, layers, activation functions, loss functions, etc., because the main aim of the tutorial is to get individuals started designing neural networks using the Flax API and not to explain how neural networks work in-depth.
Below we have highlighted important sections of the tutorial to give an overview of the material covered.
Below we have imported the necessary libraries that we'll use in our tutorial and printed the version of them as well.
import flax
print("Flax Version : {}".format(flax.__version__))
import jax
print("Jax Version : {}".format(jax.__version__))
import optax
print("Optax Version : {}".format(optax.__version__))
In this section, we'll explain how we can create simple neural networks to solve regression tasks. We'll be using the Boston housing dataset available from scikit-learn for our purposes.
In this section, we have loaded the Boston housing dataset available from scikit-learn. We have loaded data features (independent variables) into variable X and target values into variable Y. The target values are median house prices in 1000 dollars. The features describe various attributes of the house and the adjoining area.
After loading the dataset, we have divided it into the train (80%) and test (20%) sets. We have then converted datasets from numpy to JAX arrays.
from sklearn import datasets
from sklearn.model_selection import train_test_split
from jax import numpy as jnp
X, Y = datasets.load_boston(return_X_y=True)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, train_size=0.8, random_state=123)
X_train, X_test, Y_train, Y_test = jnp.array(X_train, dtype=jnp.float32),\
jnp.array(X_test, dtype=jnp.float32),\
jnp.array(Y_train, dtype=jnp.float32),\
jnp.array(Y_test, dtype=jnp.float32)
samples, features = X_train.shape
X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
samples, features
In this section, we have normalized our train and test datasets. We normalize datasets so that features that are on different scales and vary a lot in their values, come on the same scale. This will help optimization algorithms like gradient descent to converge faster.
In order to normalize datasets, we have first calculated the mean and standard deviation features of train data. We have then subtracted the mean from the train and test dataset followed by dividing subtracted results with standard deviation.
mean = X_train.mean(axis=0)
std = X_train.std(axis=0)
X_train = (X_train - mean) / std
X_test = (X_test - mean) / std
In this section, we have created a neural network using Flax. Flax provides a module named linen which has all necessary layers that are required to create neural networks.
In order to create a neural network using Flax, we need to create a class that extends the linen.Module class. We then need to define setup() and __call__() methods. Inside the setup() method, we declare various layers and a few default variables. Inside the __call__() method, we implement the forward pass logic of our network using the layers defined inside setup(). The actual logic of how layers are applied is kept in __call__(). The __call__() method takes data as input and applies the forward pass of the neural network to it.
Below, we have first created a class representing our neural network by extending the linen.Module class. We have declared a features class variable that holds the layer size details. We have then created a list of linear/dense layers using that features variable inside the setup() method. Inside the __call__() method, we loop through the layers initialized in setup() and execute them on the input data one by one. We apply the Relu (Rectified Linear Units) activation function to the output of each layer except the last using the linen.relu() function. At last, we return predictions from the __call__() method.
After defining our neural network, we have created an instance of it. We can initialize the weights of the neural network by calling the init() method on it. In order to initialize weights, we need to provide a pseudo-random number seed and sample data as input to the init() method. It returns a dictionary-like object which holds the parameters/weights of the neural network. The weights are kept under the 'params' key of the dictionary. We have printed the shapes of the weights of the neural network by looping through the weights/parameters dictionary.
Then, in the next cell, we have performed a forward pass through the neural network by calling the apply() method on it with sample data. We have also printed the predictions to verify that the network is working as expected.
from typing import Sequence, Tuple
from jax import random
import jax.numpy as jnp
from flax import linen
class MultiLayerPerceptronRegressor(linen.Module):
    features: Sequence[int] = (5, 10, 15, 1)

    def setup(self):
        self.layers = [linen.Dense(feat) for feat in self.features]

    def __call__(self, inputs):
        x = inputs
        for i, lyr in enumerate(self.layers):
            x = lyr(x)
            if i != len(self.layers) - 1:
                x = linen.relu(x)  ## Relu activation after all layers except the last
        return x
seed = random.PRNGKey(0)
model = MultiLayerPerceptronRegressor()
params = model.init(seed, X_train[:5])
for layer_params in params["params"].items():
    print("Layer Name : {}".format(layer_params[0]))
    weights, biases = layer_params[1]["kernel"], layer_params[1]["bias"]
    print("\tLayer Weights : {}, Biases : {}".format(weights.shape, biases.shape))
preds = model.apply(params, X_train[:5])
preds
In this section, we have defined the loss function for our neural network, which is the mean squared error loss function. We'll be calculating gradients with respect to this function. The function takes weights, input data, and actual target values as input. It first makes predictions using the apply() method of the neural network. Then, it calculates the loss using the predictions and actual target values.
The mean squared error loss is simply the average of the squared difference between actual target values and predictions. We return the scalar loss value from the function.
mean squared error(predictions, actuals) = 1/n * sum((actuals - predictions)^2)
def MeanSquaredErrorLoss(weights, input_data, actual):
    preds = model.apply(weights, input_data)
    return jnp.power(actual - preds.squeeze(), 2).mean()
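As a quick sanity check (not part of the original code), the loss function can be called directly with the parameters initialized earlier; it should return a single scalar value. The name initial_loss below is only illustrative.
initial_loss = MeanSquaredErrorLoss(params, X_train, Y_train)  ## Scalar MSE with the initial (untrained) weights
print("Initial MSE : {:.2f}".format(initial_loss))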
In this section, we have trained our neural network. We have first initialized the number of epochs to 1000 and also initialized the pseudo-random number seed that will be used to initialize model weights. We have then created an instance of the neural network and initialized model weights by calling the init() method on the network instance, giving it the seed and sample random data.
Then, we have created a gradient descent optimizer using the optax library. The optimizer that we'll be using for our training is the sgd() optimizer. We have initialized it with a learning rate of 0.001. We have then initialized the optimizer state by calling the init() method on the optimizer instance, giving the neural network weights to it. This optimizer state keeps track of any statistics the optimizer needs and is updated on each step of the training loop.
Then, in the next line, we have created another function by wrapping our loss function inside of value_and_grad() method of JAX. This method returns a function that takes the same parameters as our original loss function but when called with those parameter values, it returns two values as output. The first value will be the loss value returned after the actual execution of the loss function with weights, input data, and actual target values. The second value will be gradients of loss with respect to the first parameter of loss function which is the weights of the neural network. We'll be using this function during the training loop to calculate gradients.
After all initializations, we run our training loop for the specified number of epochs. Each time, we first calculate the loss value and gradients using the function we initialized earlier. This function also performs a forward pass through the network when calculating predictions inside the loss function. Then, we call the update() method on the optimizer instance, giving the gradients and optimizer state to it. It returns parameter updates and a new optimizer state. At last, we update the model weights by calling the apply_updates() method of optax.
We are printing the loss value every 100 epochs. We can notice from the printed loss values that our model is doing a decent job.
We recommend that readers go through our other tutorials guiding how to create neural networks using JAX as it'll help them better understand JAX and frameworks based on it.
seed = random.PRNGKey(0)
epochs=1000
model = MultiLayerPerceptronRegressor() ## Define Model
random_arr = jax.random.normal(key=seed, shape=(5, features))
params = model.init(seed, random_arr) ## Initialize Model Parameters
optimizer = optax.sgd(learning_rate=1/1e3) ## Initialize SGD Optimizer using OPTAX
optimizer_state = optimizer.init(params)
loss_grad = jax.value_and_grad(MeanSquaredErrorLoss)
for i in range(1, epochs+1):
    loss_val, gradients = loss_grad(params, X_train, Y_train) ## Calculate Loss and Gradients
    updates, optimizer_state = optimizer.update(gradients, optimizer_state)
    params = optax.apply_updates(params, updates) ## Update weights
    if i % 100 == 0:
        print('MSE After {} Epochs : {:.2f}'.format(i, loss_val))
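As a side note, the JIT compilation mentioned at the beginning of the tutorial can also be applied here. Below is an optional sketch (not part of the original training code) that wraps the gradient computation with jax.jit() so the compiled version is reused on every epoch; loss_grad_jitted is an illustrative name.
loss_grad_jitted = jax.jit(jax.value_and_grad(MeanSquaredErrorLoss)) ## JIT-compiled loss + gradient function

loss_val, gradients = loss_grad_jitted(params, X_train, Y_train) ## Same outputs as loss_grad(...), compiled on the first call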
In this section, we are making predictions on the train and test sets by calling the apply() method on the network object with the updated model weights and input data.
test_preds = model.apply(params, X_test) ## Make Predictions on test dataset
test_preds = test_preds.ravel()
train_preds = model.apply(params, X_train) ## Make Predictions on train dataset
train_preds = train_preds.ravel()
In this section, we have evaluated the performance of our network by calculating the R^2 score on train and test predictions. The R^2 score returns a float value in the range [0,1] for good models, where values closer to 1 indicate a better model. We have calculated the R^2 score using the r2_score() method available from scikit-learn. We can notice from the scores calculated on our train and test predictions that our model seems to be doing a good job.
If you want to learn about r^2 score and other metrics provided for ML tasks by scikit-learn then please feel free to check our below tutorial that covers the majority of metrics in detail.
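For clarity, the R^2 score is defined as 1 minus the ratio of the residual sum of squares to the total sum of squares. Below is a small illustrative sketch (not part of the original code) computing it manually for the train predictions; the helper name r_squared is our own, and r2_score() used in the next cell performs the equivalent calculation.
def r_squared(actuals, predictions):
    ss_res = jnp.sum((actuals - predictions) ** 2)      ## Residual sum of squares
    ss_tot = jnp.sum((actuals - actuals.mean()) ** 2)   ## Total sum of squares
    return 1 - ss_res / ss_tot

print("Train R^2 Score (manual) : {:.2f}".format(r_squared(Y_train, train_preds)))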
from sklearn.metrics import r2_score
print("Train R^2 Score : {:.2f}".format(r2_score(train_preds.to_py(), Y_train.to_py())))
print("Test R^2 Score : {:.2f}".format(r2_score(test_preds.to_py(), Y_test.to_py())))
In this section, we have explained how we can create simple neural networks for classification tasks. We'll be using the breast cancer dataset available from scikit-learn for our explanation purposes. We'll be reusing the majority of the code that we created in the regression section. Hence, we haven't included a detailed description of repeated code parts. Please feel free to check their description in the regression section if you have started directly from the classification section.
In this section, we have loaded the breast cancer dataset available from scikit-learn. The features of the dataset are various measurements of the tumor and the target value is binary (0 - malignant tumor, 1 - benign tumor). As our target values are binary, this will be a binary classification task.
After loading the dataset, we have divided it into the train (80%) and test (20%) sets.
from sklearn import datasets
from sklearn.model_selection import train_test_split
X, Y = datasets.load_breast_cancer(return_X_y=True)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, train_size=0.8, stratify=Y, random_state=123)
X_train, X_test, Y_train, Y_test = jnp.array(X_train, dtype=jnp.float32),\
jnp.array(X_test, dtype=jnp.float32),\
jnp.array(Y_train, dtype=jnp.float32),\
jnp.array(Y_test, dtype=jnp.float32)
samples, features = X_train.shape
classes = jnp.unique(Y_test)
X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
samples, features, classes
In this section, we have normalized our train and test sets by using the mean and standard deviation calculated on the train set. As we had explained earlier, it helps with faster convergence.
mean = X_train.mean(axis=0)
std = X_train.std(axis=0)
X_train = (X_train - mean) / std
X_test = (X_test - mean) / std
In this section, we have defined a neural network to perform the binary classification task. The code for this section is almost exactly the same as our code from the regression section with one minor change. We have designed the same neural network with layer sizes [5,10,15,1] as earlier. The only difference is that we have applied the sigmoid activation to the output of the last layer. The sigmoid function maps input values to floats in the range [0,1]. Hence, the output of our neural network will be floats in the range [0,1]. We'll be later converting these floats to actual prediction classes (0 - malignant, 1 - benign).
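For reference, the sigmoid function is defined as below.
sigmoid(x) = 1 / (1 + exp(-x))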
from typing import Sequence, Tuple
from jax import random
import jax.numpy as jnp
from flax import linen
class MultiLayerPerceptronClassifier(linen.Module):
    features: Sequence[int] = (5, 10, 15, 1)

    def setup(self):
        self.layers = [linen.Dense(feat) for feat in self.features]

    def __call__(self, inputs):
        x = inputs
        for i, lyr in enumerate(self.layers):
            x = lyr(x)
            if i != len(self.layers) - 1:
                x = linen.relu(x)  ## Relu activation after all layers except the last
        return linen.sigmoid(x)   ## Sigmoid activation on the output of the last layer
seed = random.PRNGKey(0)
model = MultiLayerPerceptronClassifier()
params = model.init(seed, X_train[:5])
for layer_params in params["params"].items():
    print("Layer Name : {}".format(layer_params[0]))
    weights, biases = layer_params[1]["kernel"], layer_params[1]["bias"]
    print("\tLayer Weights : {}, Biases : {}".format(weights.shape, biases.shape))
preds = model.apply(params, X_train[:5])
preds
In this section, we have defined the loss function for our neural network. We'll be using the negative log loss function for our network. The function takes weights, input data, and actual target values as input. It then makes predictions using the weights and input data. Then, it calculates the loss using the predictions and actual target values.
log_loss(predictions, actuals) = 1/n * sum(- actuals * log(predictions) - (1 - actuals) * log(1 - predictions))
def NegLogLoss(weights, input_data, actual):
    preds = model.apply(weights, input_data)
    preds = preds.squeeze()
    return (- actual * jnp.log(preds) - (1 - actual) * jnp.log(1 - preds)).mean()
In this section, we have included code to train the neural network. The code is almost exactly the same as the code from the regression section, with the only change being that we are using the log loss function as our loss. We are training the network for 1000 epochs with a learning rate of 0.01. We can notice from the loss value printed every 100 epochs that our model seems to be doing a good job.
seed = random.PRNGKey(0)
epochs=1000
model = MultiLayerPerceptronClassifier() ## Define Model
random_arr = jax.random.normal(key=seed, shape=(5, features))
params = model.init(seed, random_arr) ## Initialize Model Parameters
optimizer = optax.sgd(learning_rate=1/1e2) ## Initialize SGD Optimizer
optimizer_state = optimizer.init(params)
loss_grad = jax.value_and_grad(NegLogLoss)
for i in range(1, epochs+1):
    loss_val, gradients = loss_grad(params, X_train, Y_train) ## Calculate Loss and Gradients
    updates, optimizer_state = optimizer.update(gradients, optimizer_state)
    params = optax.apply_updates(params, updates) ## Update weights
    if i % 100 == 0:
        print('NegLogLoss After {} Epochs : {:.2f}'.format(i, loss_val))
In this section, we have made predictions on the train and test sets. We have called the apply() method on the model instance, giving it the updated model weights and input data, to make predictions. As the output of our model is float values in the range [0,1] due to the sigmoid activation function, we need to convert these floats to actual prediction classes. To do this, we have set the threshold at 0.5, predicting class 0 for values less than or equal to 0.5 and class 1 for values greater than 0.5.
test_preds = model.apply(params, X_test) ## Make Predictions on test dataset
test_preds = test_preds.ravel()
test_preds = (test_preds > 0.5).astype(jnp.float32)
test_preds[:5], Y_test[:5]
train_preds = model.apply(params, X_train) ## Make Predictions on train dataset
train_preds = train_preds.ravel()
train_preds = (train_preds > 0.5).astype(jnp.float32)
train_preds[:5], Y_train[:5]
In this section, we have evaluated the performance of our model by calculating the accuracy of train and test predictions. We have also calculated classification report on test predictions that has information like precision, recall, and f1-score. We can notice from metrics results that our model seems to be doing a decent job.
from sklearn.metrics import accuracy_score
print("Train Accuracy : {:.2f}".format(accuracy_score(Y_train, train_preds)))
print("Test Accuracy : {:.2f}".format(accuracy_score(Y_test, test_preds)))
from sklearn.metrics import classification_report
print("Test Data Classification Report : ")
print(classification_report(Y_test, test_preds))
This ends our small tutorial explaining how we can use Flax to create neural networks. Please feel free to let us know your views in the comments section.
If you are more comfortable learning through video tutorials then we would recommend that you subscribe to our YouTube channel.
When going through coding examples, it's quite common to have doubts and errors.
If you have doubts about some code examples or are stuck somewhere when trying our code, send us an email at coderzcolumn07@gmail.com. We'll help you or point you in the direction where you can find a solution to your problem.
You can even send us a mail if you are trying something new and need guidance regarding coding. We'll try to respond as soon as possible.