Updated On : Dec-17,2021 Time Investment : ~30 mins

Haiku: Guide to Create Multi-Layer Perceptrons using JAX

Haiku is a deep learning framework designed by DeepMind on top of JAX to accelerate their deep learning experiments. Haiku provides built-in implementations of multi-layer perceptrons, convolutional nets, etc. As Haiku is built on top of JAX, some background in JAX helps in understanding Haiku.

Haiku is designed to let developers use an object-oriented programming approach to design neural networks and at the same time use JAX's function transformations (grad(), jit(), vmap(), pmap(), etc.) on the object-oriented code. JAX's important transformations are designed to work with pure functions rather than classes, whereas neural networks are generally defined using classes (OOP). Haiku makes it possible to use OOP-designed modules/classes like pure JAX functions.

As a part of this tutorial, we'll explain how to create simple multi-layer perceptrons using Haiku. The main aim of the tutorial is to introduce readers to the Haiku framework. We have created two small examples that use Haiku with small toy datasets to solve regression and classification tasks.

As we have used JAX in this tutorial for some tasks, if you want to learn about JAX then please feel free to check our tutorial on it.

Installation

  • pip install -U dm-haiku

Below we have highlighted important sections of the tutorial to give an overview of the material covered.

Important Sections of Tutorial

  1. Regression
    • Load Data
    • Normalize Data
    • Define Neural Network
    • Define Loss Function
    • Train Neural Network
    • Make Predictions
    • Evaluate Model Performance
    • Train Model in Batches of Data
    • Make Predictions in Batches
    • Evaluate Model Performance
  2. Classification

Below we have imported haiku and JAX libraries which we'll use in our tutorial. We have also printed the versions of both libraries that we'll use in our tutorial.

import haiku as hk

print("Haiku Version :{}".format(hk.__version__))
Haiku Version :0.0.5
import jax

print("JAX Version : {}".format(jax.__version__))
JAX Version : 0.2.26
from jax import numpy as jnp

Regression

In this section, we'll explain how we can create a simple multi-layer perceptron using Haiku to solve simple regression tasks. We'll be using a small dataset available from scikit-learn for our example.

Load Data

In this section, we have loaded the Boston housing dataset available from scikit-learn. We have loaded dataset features in variable X and target values in variable Y. The target values are median house prices in units of 1000 dollars; they are continuous, hence our problem is regression. After loading the dataset, we have divided it into the train (80%) and test (20%) sets. Note that load_boston() was deprecated in scikit-learn 1.0 and removed in version 1.2, so the code below needs an older scikit-learn release.

Scikit-learn loads dataset as numpy arrays. We have also converted all our datasets from numpy to jax arrays.

from sklearn import datasets
from sklearn.model_selection import train_test_split

X, Y = datasets.load_boston(return_X_y=True)

X_train, X_test, Y_train, Y_test = train_test_split(X, Y, train_size=0.8, random_state=123)

X_train, X_test, Y_train, Y_test = jnp.array(X_train, dtype=jnp.float32),\
                                   jnp.array(X_test, dtype=jnp.float32),\
                                   jnp.array(Y_train, dtype=jnp.float32),\
                                   jnp.array(Y_test, dtype=jnp.float32)

samples, features = X_train.shape

X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
((404, 13), (102, 13), (404,), (102,))
samples, features
(404, 13)

Normalize Data

In this section, we have normalized our datasets. Normalization brings all feature values to the same scale so that the optimization algorithm of our neural network converges faster. If features are on scales that differ by large amounts, the optimization process becomes harder and the algorithm takes more time to converge.

To normalize datasets, we have first calculated the mean and standard deviation of each feature on the training dataset. We then subtract this mean from both the train and test datasets and divide the result by the standard deviation. The test set is normalized with the training statistics so that no information from the test data is used.

mean = X_train.mean(axis=0)
std = X_train.std(axis=0)

X_train = (X_train - mean) / std
X_test = (X_test - mean) / std
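
The effect of this normalization can be checked on a small made-up array. The sketch below is only an illustration: the toy values and the zero-std safeguard are assumptions that are not part of the tutorial code.

```python
import jax.numpy as jnp

# Toy feature matrix: 4 samples, 2 features on very different scales.
X_toy = jnp.array([[1.0, 1000.0],
                   [2.0, 2000.0],
                   [3.0, 3000.0],
                   [4.0, 4000.0]], dtype=jnp.float32)

# Statistics are computed per feature (axis=0).
toy_mean = X_toy.mean(axis=0)
toy_std = X_toy.std(axis=0)

# Hypothetical safeguard (not in the tutorial code): a constant feature has
# std 0, so replace a zero std with 1 to avoid dividing by zero.
toy_std_safe = jnp.where(toy_std == 0, 1.0, toy_std)

X_toy_norm = (X_toy - toy_mean) / toy_std_safe

print(X_toy_norm.mean(axis=0))  # each feature is centered near 0
print(X_toy_norm.std(axis=0))   # each feature has std near 1
```

After normalization both columns hold the same values, which shows that the original difference in scale has been removed.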

Define Neural Network

In this section, we have created a neural network using Haiku that we'll use for our regression task. We'll be using a class available from Haiku named MLP (hk.nets.MLP) to create our neural network. The MLP class creates a simple multi-layer perceptron based on the layer sizes we give it. Below we have included the definition of the MLP() constructor.


  • MLP(output_sizes, w_init=None, b_init=None, with_bias=True, activation=jax.nn.relu, activate_final=False, name=None) - This constructor takes layer sizes as input and creates an instance of class MLP which we can later train and use to make predictions.
    • The output_sizes argument takes a list of numbers specifying the number of units per layer of the neural network. It should include the number of units for the output layer as well.
    • The w_init and b_init arguments take callables used to initialize weights and biases.
    • The with_bias argument accepts a boolean value specifying whether to use biases or not. By default, biases are added.
    • The activation argument accepts the activation function that will be applied between the layers of the neural network. The default is jax.nn.relu.
    • The activate_final argument accepts a boolean value specifying whether to apply the function given in the activation argument to the final layer as well. By default, the last layer is not activated.

The majority of neural networks that we create with Haiku are subclasses of the haiku.Module class. We need to transform class-based models into pure functions, and we can do so with the hk.transform() function. It takes a function that builds and calls Haiku modules and transforms it into a pair of pure JAX functions.

We have first defined our MLP with layer sizes [5,10,15,1] inside of another function. That function receives the input data, calls our MLP instance on it, and returns the result. In other words, it creates the neural network, performs a forward pass on the input data, and returns the predictions. We have transformed this function using hk.transform(). It returns a transformed object which has two methods.

  1. init(rng, data) - This method takes a seed for random numbers and a few samples as input. It then initializes the weights of the neural network and returns them. It returns an instance of FlatMap, a dictionary-like object that holds the weights and biases of all layers of the neural network.
  2. apply(weights, rng, data) - This method takes weights, seed, and data features as input. It then performs a forward pass through the data with the given weights and returns predictions.

Below we have first created our neural network and stored it in a variable named model. We have then called the init() method on it to retrieve the weights of the neural network. We have also printed the shapes of the weights and biases of all layers for verification. Then, in the next cell, we have called the apply() method on our neural network with weights, the seed for random numbers, and train data. It returns predictions made on the data. We have printed the first few predictions. These are predictions made with the initial weights; we have not trained the neural network yet.

def FeedForward(x):
    mlp = hk.nets.MLP(output_sizes=[5,10,15,1])
    return mlp(x)
model = hk.transform(FeedForward)
rng = jax.random.PRNGKey(42)

params = model.init(rng, X_train[:5])

print("Weights Type : {}\n".format(type(params)))

for layer_name, weights in params.items():
    print(layer_name)
    print("Weights : {}, Biases : {}\n".format(params[layer_name]["w"].shape,params[layer_name]["b"].shape))
Weights Type : <class 'haiku._src.data_structures.FlatMap'>

mlp/~/linear_0
Weights : (13, 5), Biases : (5,)

mlp/~/linear_1
Weights : (5, 10), Biases : (10,)

mlp/~/linear_2
Weights : (10, 15), Biases : (15,)

mlp/~/linear_3
Weights : (15, 1), Biases : (1,)

preds = model.apply(params, rng, X_train)

preds[:5]
DeviceArray([[-0.7874112 ],
             [-0.27768713],
             [-0.01174072],
             [-0.01407549],
             [-0.3872894 ]], dtype=float32)

Define Loss Function

In this section, we have defined the loss function that we'll be using for our task. We'll be using Mean Squared Error Loss (MSE) function as our loss. The function takes weights, input data, and actual target values as input. It then uses apply() method of the model to make predictions on input data using weights. Then we calculate MSE loss with predictions and actual target values. We calculate MSE by first subtracting predictions from actual target values. Then, we square differences and take the mean of all squared differences.

MSE(actuals, predictions) = 1/n * sum_i (actuals_i - predictions_i)^2

def MeanSquaredErrorLoss(weights, input_data, actual):
    preds = model.apply(weights, rng, input_data)
    preds = preds.squeeze()
    return jnp.power(actual - preds, 2).mean()
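
The squeeze() call in the loss function matters because the network returns predictions of shape (n, 1) while the targets have shape (n,). The toy arrays below (made-up values, not tutorial data) sketch what would happen without it: the subtraction broadcasts to an (n, n) matrix and the mean is taken over the wrong set of differences.

```python
import jax.numpy as jnp

actual = jnp.array([1.0, 2.0, 3.0])            # shape (3,)
preds_2d = jnp.array([[1.5], [2.0], [2.0]])    # shape (3, 1), like the MLP output

# Without squeeze(): (3,) - (3, 1) broadcasts to a (3, 3) matrix
wrong_diff = actual - preds_2d
print(wrong_diff.shape)

# With squeeze(): both arrays have shape (3,), so the difference is element-wise
right_diff = actual - preds_2d.squeeze()
print(right_diff.shape)

# Squared differences are 0.25, 0.0 and 1.0, so the MSE is 1.25 / 3
mse = jnp.power(right_diff, 2).mean()
print(mse)
```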

Train Neural Network

In this section, we are training our neural network.

First, we have defined a simple function that takes weights and gradients as input. It then subtracts learning rate times gradients from weights. We'll be using this function to iteratively update the weights of all layers.

def UpdateWeights(weights,gradients):
    return weights - learning_rate * gradients

Below, we have included logic to train our neural network. We have first initialized the weights of our neural network using init() method of our model by providing seed for random numbers and a few data samples. We have then initialized the number of epochs (1000) and learning rate (0.001).

We then run the training loop once per epoch. The logic inside the loop is simple. We have first called the value_and_grad() function with our loss function. The value_and_grad() function takes a function as input and returns another function. When we call the returned function with parameter values, it returns two values. The first is the value of the wrapped function at those parameter values and the second is the gradient of the wrapped function with respect to its first parameter.

In our case first output will be MSE loss value and second value will be gradients of MSE with respect to weights (first parameter of MeanSquaredErrorLoss()).

Then on the next line, we have logic to update the weights of our neural network. We have used JAX's tree_map() function (jax.tree_map() in the JAX version used here; newer releases expose it as jax.tree_util.tree_map()) to update weights by subtracting learning rate times gradients. The tree_map() function takes a function followed by the arguments of that function as input. The input arguments have a tree-like structure, and the given function is applied to each leaf of the input arguments. This process of updating weights based on learning rate and gradients is generally referred to as gradient descent.

In our case, weights are the tree-like structure that we had printed when we defined the neural network earlier. The gradients have the same structure. We apply the update weights function declared in our previous cell to each pair of matching leaves of our weights and gradients. It returns a new tree-like data structure with weights updated based on gradients and learning rate.

We also print the MSE every 100 epochs. It falls steadily from 17.04 at epoch 100 to 8.13 at epoch 1000, which shows that training is reducing the loss.

from jax import value_and_grad

rng = jax.random.PRNGKey(42) ## Reproducibility ## Initializes model with same weights each time.

params = model.init(rng, X_train[:5])
epochs = 1000
learning_rate = jnp.array(0.001)

for i in range(1, epochs+1):
    loss, param_grads = value_and_grad(MeanSquaredErrorLoss)(params, X_train, Y_train)
    params = jax.tree_util.tree_map(UpdateWeights, params, param_grads) ## Same as jax.tree_map() in older JAX releases

    if i%100 == 0:
        print("MSE : {:.2f}".format(loss))
MSE : 17.04
MSE : 12.32
MSE : 10.59
MSE : 9.66
MSE : 9.09
MSE : 8.76
MSE : 8.55
MSE : 8.38
MSE : 8.23
MSE : 8.13
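
The single gradient-descent step used in the loop above can be seen in isolation on a toy problem. In the sketch below the parameter dictionary, the data, and the step size are all made up for illustration; only value_and_grad() and tree_map() are used the same way as in the training loop.

```python
import jax
import jax.numpy as jnp

# Toy parameters stored as a nested dict (a "pytree"), similar in spirit
# to the per-layer weights/biases mapping that Haiku returns.
toy_params = {"layer": {"w": jnp.array([1.0, -2.0]), "b": jnp.array(0.5)}}

def toy_loss(p, x, y):
    pred = jnp.dot(x, p["layer"]["w"]) + p["layer"]["b"]
    return jnp.power(y - pred, 2).mean()

toy_x = jnp.array([[1.0, 0.0], [0.0, 1.0]])
toy_y = jnp.array([2.0, 0.0])
step_size = 0.1  # hypothetical learning rate for this toy problem

# Returns the loss value and gradients with the same structure as toy_params.
loss_before, toy_grads = jax.value_and_grad(toy_loss)(toy_params, toy_x, toy_y)

# Apply w <- w - step_size * g to every pair of matching leaves.
new_params = jax.tree_util.tree_map(lambda w, g: w - step_size * g,
                                    toy_params, toy_grads)

loss_after = toy_loss(new_params, toy_x, toy_y)
print(loss_before, loss_after)  # the loss drops after one step
```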

Make Predictions

In this section, we have made predictions on our train and test datasets using apply() method of our model.

train_preds = model.apply(params, rng, X_train)

train_preds[:5]
DeviceArray([[48.474415],
             [11.66548 ],
             [21.02784 ],
             [26.184229],
             [15.279277]], dtype=float32)
test_preds = model.apply(params, rng, X_test)

test_preds[:5]
DeviceArray([[20.90538 ],
             [25.025263],
             [44.169964],
             [21.290577],
             [29.036098]], dtype=float32)

Evaluate Model Performance

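In this section, we have evaluated the performance of our model by calculating the MSE loss and R^2 score on our train and test predictions. The R^2 score is at most 1. Values near 1 indicate a good fit, 0 corresponds to always predicting the mean of the targets, and the score can be negative for a model that does worse than that. The R^2 scores on our train and test predictions suggest the model is doing a decent job, although the test MSE (18.38) is noticeably higher than the train MSE (8.13), which hints at some overfitting.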

If you are interested in learning in detail about R^2 score and other metrics available from scikit-learn for different kinds of tasks then please feel free to check our tutorial which covers the majority of metrics in detail with examples.

print("Test  MSE Score : {:.2f}".format(MeanSquaredErrorLoss(params, X_test, Y_test)))
print("Train MSE Score : {:.2f}".format(MeanSquaredErrorLoss(params, X_train, Y_train)))
Test  MSE Score : 18.38
Train MSE Score : 8.13
from sklearn.metrics import r2_score

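## r2_score() expects (y_true, y_pred). The scores shown below were printed with the two arguments swapped, so rerunning may give slightly different values.
print("Test  R^2 Score : {:.2f}".format(r2_score(Y_test, test_preds.squeeze())))
print("Train R^2 Score : {:.2f}".format(r2_score(Y_train, train_preds.squeeze())))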
Test  R^2 Score : 0.75
Train R^2 Score : 0.89
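
To make the behaviour of the R^2 score concrete, the sketch below computes it by hand as 1 - SS_res / SS_tot on made-up values. The helper name r2 and the toy arrays are assumptions for illustration; this is not scikit-learn's implementation, only the same formula for one-dimensional targets.

```python
import jax.numpy as jnp

def r2(y_true, y_pred):
    # R^2 = 1 - (residual sum of squares) / (total sum of squares)
    ss_res = jnp.sum((y_true - y_pred) ** 2)
    ss_tot = jnp.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot

y_true = jnp.array([1.0, 2.0, 3.0, 4.0])

score_perfect = r2(y_true, y_true)                               # perfect fit
score_mean = r2(y_true, jnp.full_like(y_true, y_true.mean()))    # always predict the mean
score_reversed = r2(y_true, y_true[::-1])                        # worse than the mean: negative

# The score is not symmetric: swapping the arguments changes the result.
y_pred = jnp.array([2.0, 2.0, 3.0, 3.0])
score_correct_order = r2(y_true, y_pred)
score_swapped_order = r2(y_pred, y_true)

print(score_perfect, score_mean, score_reversed)
print(score_correct_order, score_swapped_order)
```

The last two values differ because the denominator is computed from whichever array is passed as the true targets.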

Train Model in Batches of Data

In real life, datasets are generally large and often do not fit into the main memory of the computer. To handle such datasets, we bring a small batch of samples into main memory and train the model on that batch. We cover the whole dataset one batch at a time and update the weights of the neural network after each batch. This approach of updating weights using one small batch of data at a time is referred to as mini-batch stochastic gradient descent.

Our current dataset is quite small and easily fits into the main memory of the computer but we'll treat it as a big dataset that does not fit into the main memory of the computer. We'll divide the dataset into batches of data and train the model on a small batch of data at a time. Below, we have included logic to train data in small batches.

We have first initialized model weights using the seed for random numbers. We have then set the number of epochs (500), batch size (32), and learning rate (0.001). The training loop runs once per epoch. In each epoch, we generate indexes for batches of data and loop through the whole dataset in batches. For each batch, we calculate the loss and gradients and then update the model weights, so the weights are updated once per batch.

The mean batch MSE printed every 100 epochs falls from 8.51 to 6.75, so the model keeps improving on the training data.

from jax import value_and_grad

rng = jax.random.PRNGKey(42) ## Reproducibility ## Initializes model with same weights each time.

params = model.init(rng, X_train[:5])
epochs = 500
batch_size = 32
learning_rate = jnp.array(0.001)

for i in range(1, epochs+1):
    batches = jnp.arange((X_train.shape[0]//batch_size)+1) ### Batch Indices

    losses = [] ## Record loss of each batch
    for batch in batches:
        if batch != batches[-1]:
            start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
        else:
            start, end = int(batch*batch_size), None

        X_batch, Y_batch = X_train[start:end], Y_train[start:end] ## Single batch of data

        loss, param_grads = value_and_grad(MeanSquaredErrorLoss)(params, X_batch, Y_batch)
        params = jax.tree_util.tree_map(UpdateWeights, params, param_grads) ## Update Params (jax.tree_map() in older JAX releases)
        losses.append(loss) ## Record Loss

    if i % 100 == 0: ## Print MSE every 100 epochs
        print("MSE : {:.2f}".format(jnp.array(losses).mean()))
MSE : 8.51
MSE : 7.81
MSE : 7.35
MSE : 6.98
MSE : 6.75
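
The batch boundaries used in the loop above can also be produced with Python's range(). The helper below is a hypothetical alternative, not the tutorial's code. One difference is worth noting: the arange((n // batch_size) + 1) scheme would yield an empty final batch if the number of samples were an exact multiple of the batch size, while this version does not.

```python
def batch_bounds(num_samples, batch_size):
    """Return (start, end) index pairs that cover all samples in order."""
    return [(start, min(start + batch_size, num_samples))
            for start in range(0, num_samples, batch_size)]

# 404 training samples with batch size 32: 12 full batches and one of 20
bounds_404 = batch_bounds(404, 32)
print(len(bounds_404), bounds_404[-1])

# Exact multiple of the batch size: two full batches and no empty one
bounds_64 = batch_bounds(64, 32)
print(bounds_64)
```

Each pair can be used directly as X_train[start:end] and Y_train[start:end].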

Make Predictions in Batches

As we can not fit whole data into the main memory of the computer, we need to do predictions also on a batch of data. Below we have defined a function that takes updated model weights and input data. It then loops through data in batches making predictions on a batch of data at a time. It then combines predictions of all batches and returns them. It uses the same logic to create batch indexes that were used during the training section to create batches.

We have then used the function to make predictions on train and test datasets.

def MakePredictions(weights, input_data, batch_size=32):
    batches = jnp.arange((input_data.shape[0]//batch_size)+1) ### Batch Indices

    preds = []
    for batch in batches:
        if batch != batches[-1]:
            start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
        else:
            start, end = int(batch*batch_size), None

        X_batch = input_data[start:end]

        preds.append(model.apply(weights, rng, X_batch))

    return preds
train_preds = MakePredictions(params, X_train, 32)
train_preds = jnp.concatenate(train_preds).squeeze()

test_preds = MakePredictions(params, X_test, 32)
test_preds = jnp.concatenate(test_preds).squeeze()

Evaluate Model Performance

In this section, we have evaluated the performance of the model by calculating the MSE loss and R^2 score on train and test predictions.

print("Test  MSE Score : {:.2f}".format(MeanSquaredErrorLoss(params, X_test, Y_test)))
print("Train MSE Score : {:.2f}".format(MeanSquaredErrorLoss(params, X_train, Y_train)))
Test  MSE Score : 16.50
Train MSE Score : 6.69
from sklearn.metrics import r2_score

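## r2_score() expects (y_true, y_pred). The scores shown below were printed with the two arguments swapped, so rerunning may give slightly different values.
print("Test  R^2 Score : {:.2f}".format(r2_score(Y_test, test_preds.squeeze())))
print("Train R^2 Score : {:.2f}".format(r2_score(Y_train, train_preds.squeeze())))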
Test  R^2 Score : 0.80
Train R^2 Score : 0.92

Classification

In this section, we'll explain how to create a simple multi-layer perceptron using the MLP class of Haiku to solve classification tasks. We'll be using a small toy dataset available from scikit-learn for explanation purposes. We have reused the majority of the code from the regression section, hence we have not included a detailed description of the code that is repeated here. If a code section is not described in detail, please check the matching part of the regression section.

Load Data

In this section, we have loaded the breast cancer dataset available from scikit-learn. We have loaded data features in variable X and target values in variable Y. The target values are either 0 (malignant tumor) or 1 (benign tumor). As our target values have only two classes, this will be a binary classification problem.

After loading the dataset, we have divided it into the train (80%) and test (20%) sets. We have also converted datasets held in numpy arrays to jax arrays.

from sklearn import datasets
from sklearn.model_selection import train_test_split

X, Y = datasets.load_breast_cancer(return_X_y=True)

X_train, X_test, Y_train, Y_test = train_test_split(X, Y, train_size=0.8, stratify=Y, random_state=123)

X_train, X_test, Y_train, Y_test = jnp.array(X_train, dtype=jnp.float32),\
                                   jnp.array(X_test, dtype=jnp.float32),\
                                   jnp.array(Y_train, dtype=jnp.float32),\
                                   jnp.array(Y_test, dtype=jnp.float32)

samples, features = X_train.shape
classes = jnp.unique(Y)

X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
((455, 30), (114, 30), (455,), (114,))
samples, features, classes
(455, 30, DeviceArray([0, 1], dtype=int32))

Normalize Data

In this section, we have normalized our train and test datasets using the mean and standard deviation of features calculated on the training dataset.

mean = X_train.mean(axis=0)
std = X_train.std(axis=0)

X_train = (X_train - mean) / std
X_test = (X_test - mean) / std

Define Neural Network

In this section, we have designed a neural network that we'll be using for our binary classification task. The network design is exactly the same as that of the regression section. The MLP() constructor applies a single activation function to its layers and does not let us choose a different one for just the last layer, so we apply the sigmoid() function from the jax.nn module to the network output inside our loss function. We'll apply the same sigmoid() function when making predictions as well.

def FeedForward(x):
    mlp = hk.nets.MLP(output_sizes=[5,10,15,1])
    return mlp(x)
model = hk.transform(FeedForward)
rng = jax.random.PRNGKey(42)

params = model.init(rng, X_train[:5])

for layer_name, weights in params.items():
    print(layer_name)
    print("Weights : {}, Biases : {}\n".format(params[layer_name]["w"].shape,params[layer_name]["b"].shape))
mlp/~/linear_0
Weights : (30, 5), Biases : (5,)

mlp/~/linear_1
Weights : (5, 10), Biases : (10,)

mlp/~/linear_2
Weights : (10, 15), Biases : (15,)

mlp/~/linear_3
Weights : (15, 1), Biases : (1,)

Define Loss Function

In this section, we have defined the loss function that we'll be using for our binary classification task. We'll be using log loss for our task.

log_loss(actuals, predictions) = 1/n * sum_i ( - actuals_i * log(predictions_i) - (1 - actuals_i) * log(1 - predictions_i) )

The function takes weights, features data, and actual target values as input. It then uses apply() method of the model to make predictions. After making predictions, we have applied sigmoid activation to the output of the last layer using jax.nn.sigmoid() function. Then we have calculated loss using predictions and actual target values.

def NegLogLoss(weights, input_data, actual):
    preds = model.apply(weights, rng, input_data)
    preds = preds.squeeze()
    preds = jax.nn.sigmoid(preds)
    return (- actual * jnp.log(preds) - (1 - actual) * jnp.log(1 - preds)).mean()
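
The loss above takes log() of sigmoid outputs, which can break down when a raw network output is very large in magnitude. As a hedged sketch of one common alternative (not the tutorial's implementation), the version below works on the raw outputs and uses jax.nn.log_sigmoid(), relying on the identities log(sigmoid(z)) = log_sigmoid(z) and log(1 - sigmoid(z)) = log_sigmoid(-z). The function names and toy values are made up for this illustration.

```python
import jax
import jax.numpy as jnp

def naive_log_loss(logits, actual):
    # Same formula as the tutorial's loss, written for raw outputs (logits)
    preds = jax.nn.sigmoid(logits)
    return (-actual * jnp.log(preds) - (1 - actual) * jnp.log(1 - preds)).mean()

def stable_log_loss(logits, actual):
    # log(sigmoid(z)) = log_sigmoid(z), log(1 - sigmoid(z)) = log_sigmoid(-z)
    return (-actual * jax.nn.log_sigmoid(logits)
            - (1 - actual) * jax.nn.log_sigmoid(-logits)).mean()

# For moderate outputs both versions agree
toy_logits = jnp.array([-2.0, 0.5, 3.0])
toy_labels = jnp.array([0.0, 1.0, 1.0])
loss_naive = naive_log_loss(toy_logits, toy_labels)
loss_stable = stable_log_loss(toy_logits, toy_labels)
print(loss_naive, loss_stable)

# For a very large output, sigmoid() rounds to 1.0 in float32, so the naive
# version evaluates log(0) and its result is no longer a finite number.
extreme_logits = jnp.array([100.0])
extreme_labels = jnp.array([1.0])
extreme_naive = naive_log_loss(extreme_logits, extreme_labels)
extreme_stable = stable_log_loss(extreme_logits, extreme_labels)
print(extreme_naive, extreme_stable)
```

With normalized inputs and small initial weights, as in this tutorial, the outputs stay moderate and the simpler formula works fine.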

Train Neural Network

In this section, we have included the logic to train the neural network. It is the same as the logic we used in the regression section; only our parameter settings and loss function are different. We have set the number of epochs to 1500 and the learning rate to 0.001. The loss printed every 100 epochs falls from 0.69 to 0.46, so the model is learning, although slowly at this learning rate.

def UpdateWeights(weights,gradients):
    return weights - learning_rate * gradients
from jax import value_and_grad

rng = jax.random.PRNGKey(42) ## Reproducibility ## Initializes model with same weights each time.

params = model.init(rng, X_train[:5])
epochs = 1500
learning_rate = jnp.array(0.001)

for i in range(1,epochs+1):
    loss, param_grads = value_and_grad(NegLogLoss)(params, X_train, Y_train)
    params = jax.tree_util.tree_map(UpdateWeights, params, param_grads) ## Same as jax.tree_map() in older JAX releases

    if i%100 == 0:
        print("NegLogLoss : {:.2f}".format(loss))
NegLogLoss : 0.69
NegLogLoss : 0.67
NegLogLoss : 0.66
NegLogLoss : 0.65
NegLogLoss : 0.63
NegLogLoss : 0.62
NegLogLoss : 0.60
NegLogLoss : 0.58
NegLogLoss : 0.57
NegLogLoss : 0.55
NegLogLoss : 0.53
NegLogLoss : 0.51
NegLogLoss : 0.50
NegLogLoss : 0.48
NegLogLoss : 0.46

Make Predictions

In this section, we have made predictions using the latest weights on train and test datasets. After making predictions using the apply() method, we have applied the jax.nn.sigmoid() function to the outputs to bring all output values into the range [0, 1]. We have set the threshold at 0.5 to predict the target class. Values greater than 0.5 will be predicted as 1 (benign tumor) and all other values will be predicted as 0 (malignant tumor).

train_preds = model.apply(params, rng, X_train)
train_preds = jax.nn.sigmoid(train_preds.squeeze())
train_preds = (train_preds > 0.5).astype(jnp.float32)

test_preds = model.apply(params, rng, X_test)
test_preds = jax.nn.sigmoid(test_preds.squeeze())
test_preds = (test_preds > 0.5).astype(jnp.float32)
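
Since sigmoid(z) is greater than 0.5 exactly when z is greater than 0, thresholding the sigmoid output at 0.5 gives the same classes as thresholding the raw network output at 0. The short sketch below checks this on made-up values; the array and variable names are illustrative only.

```python
import jax
import jax.numpy as jnp

toy_outputs = jnp.array([-3.0, -0.1, 0.2, 4.0])   # raw network outputs (logits)

# Threshold the probabilities at 0.5, as done in this section
classes_via_sigmoid = (jax.nn.sigmoid(toy_outputs) > 0.5).astype(jnp.float32)

# Threshold the raw outputs at 0 instead
classes_via_logits = (toy_outputs > 0.0).astype(jnp.float32)

print(classes_via_sigmoid, classes_via_logits)
```

Keeping the sigmoid step is still useful when the predicted probabilities themselves are needed, for example to use a threshold other than 0.5.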

Evaluate Model Performance

In this section, we have evaluated the performance of our model by calculating the log loss and accuracy of train and test predictions. The model reaches 92% accuracy on the train set and 94% on the test set.

print("Test  NegLogLoss Score : {:.2f}".format(NegLogLoss(params, X_test, Y_test)))
print("Train NegLogLoss Score : {:.2f}".format(NegLogLoss(params, X_train, Y_train)))
Test  NegLogLoss Score : 0.46
Train NegLogLoss Score : 0.46
from sklearn.metrics import accuracy_score

print("Train Accuracy : {:.2f}".format(accuracy_score(Y_train, train_preds)))
print("Test  Accuracy : {:.2f}".format(accuracy_score(Y_test, test_preds)))
Train Accuracy : 0.92
Test  Accuracy : 0.94

Train Model in Batches of Data

In this section, we have included logic to train a neural network in batches. The code for this section is exactly the same as the code from the regression section with only changes in the loss function and parameter values. We have set epochs to 500, batch size to 32, and learning rate to 0.001. The mean batch loss printed every 100 epochs falls from 0.46 to 0.09, well below what full-batch training reached in 1500 epochs, because the weights are now updated once per batch rather than once per epoch.

from jax import value_and_grad

rng = jax.random.PRNGKey(42) ## Reproducibility ## Initializes model with same weights each time.

params = model.init(rng, X_train[:5])
epochs = 500
batch_size = 32
learning_rate = jnp.array(0.001)

for i in range(1,epochs+1):
    batches = jnp.arange((X_train.shape[0]//batch_size)+1) ### Batch Indices

    losses = [] ## Record loss of each batch
    for batch in batches:
        if batch != batches[-1]:
            start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
        else:
            start, end = int(batch*batch_size), None

        X_batch, Y_batch = X_train[start:end], Y_train[start:end] ## Single batch of data

        loss, param_grads = value_and_grad(NegLogLoss)(params, X_batch, Y_batch)
        params = jax.tree_util.tree_map(UpdateWeights, params, param_grads) ## Update Params (jax.tree_map() in older JAX releases)
        losses.append(loss) ## Record Loss

    if i % 100 == 0: ## Print NegLogLoss every 100 epochs
        print("NegLogLoss : {:.2f}".format(jnp.array(losses).mean()))
NegLogLoss : 0.46
NegLogLoss : 0.23
NegLogLoss : 0.14
NegLogLoss : 0.10
NegLogLoss : 0.09

Make Predictions in Batches

In this section, we have made predictions on train and test datasets in batches. We have used the same function we had defined in the regression section to make predictions on data in batches. The main difference here is that after making a prediction, we have applied the sigmoid function to the output and then predicted target classes by setting the threshold at 0.5.

def MakePredictions(weights, input_data, batch_size=32):
    batches = jnp.arange((input_data.shape[0]//batch_size)+1) ### Batch Indices

    preds = []
    for batch in batches:
        if batch != batches[-1]:
            start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
        else:
            start, end = int(batch*batch_size), None

        X_batch = input_data[start:end]

        preds.append(model.apply(weights, rng, X_batch))

    return preds
train_preds = MakePredictions(params, X_train, 32)
train_preds = jnp.concatenate(train_preds).squeeze()
train_preds = jax.nn.sigmoid(train_preds)
train_preds = (train_preds > 0.5).astype(jnp.float32)

train_preds[:5], Y_train[:5]
(DeviceArray([1., 1., 0., 1., 1.], dtype=float32),
 DeviceArray([1., 1., 0., 0., 1.], dtype=float32))
test_preds = MakePredictions(params, X_test, 32)
test_preds = jnp.concatenate(test_preds).squeeze()
test_preds = jax.nn.sigmoid(test_preds)
test_preds = (test_preds > 0.5).astype(jnp.float32)

test_preds[:5], Y_test[:5]
(DeviceArray([0., 0., 1., 1., 1.], dtype=float32),
 DeviceArray([0., 0., 1., 1., 1.], dtype=float32))

Evaluate Model Performance

In this section, we have evaluated the performance of our model by calculating the log loss and accuracy of train and test predictions.

print("Test  NegLogLoss Score : {:.2f}".format(NegLogLoss(params, X_test, Y_test)))
print("Train NegLogLoss Score : {:.2f}".format(NegLogLoss(params, X_train, Y_train)))
Test  NegLogLoss Score : 0.09
Train NegLogLoss Score : 0.09
from sklearn.metrics import accuracy_score

print("Train Accuracy : {:.2f}".format(accuracy_score(Y_train, train_preds)))
print("Test  Accuracy : {:.2f}".format(accuracy_score(Y_test, test_preds)))
Train Accuracy : 0.97
Test  Accuracy : 0.96

This ends our small tutorial explaining how we can use Haiku to create simple multi-layer perceptrons. Please feel free to let us know your views in the comments section.

Sunny Solanki