Updated On : Oct-30,2021 Tags jax, xla, automatic-gradients
JAX - (Numpy + Automatic Gradients) on Accelerators (GPUs/TPUs)


JAX is a Python library designed to make machine learning research easier. It provides an API that closely mirrors numpy, which lets us create multidimensional arrays and perform operations on them. Arrays can be transferred from CPU to GPU/TPU and vice-versa, or created directly on an accelerator so that no transfer across devices is needed. Apart from this, JAX can automatically differentiate Python functions. JAX is the successor of the well-known automatic differentiation library Autograd. It can differentiate most Python code involving loops, conditions, recursion, etc., and can also compute the derivative of a derivative. JAX compiles code using XLA (Accelerated Linear Algebra), which fuses many operations to reduce the run time of the code and then runs it on CPU/GPU/TPU. JAX also provides a just-in-time (JIT) compiler that compiles a function the first time it is called so that later calls run faster.

Below we have highlighted important features of JAX.

  • Numpy-like API on CPU/GPU/TPU.
  • XLA (Accelerated Linear Algebra) for faster execution
  • Automatic Differentiation
  • Just-In-Time Compilation.

As a part of this tutorial, we'll be covering the basics of JAX where we'll try to explain basic features with a few simple examples. Below we have highlighted important sections of the tutorial to give an overview of the material that we have covered.

Important Sections of Tutorial

  1. Array Creation
  2. Normal Array Operations
  3. Simple Statistics
  4. Random Numbers
  5. Working with Functions
  6. Automatic Gradients
  7. vmap (Vectorized Mapping)
  8. Just In Time (JIT) Compiled
  9. JIT + vmap

Installation

  • CPU
    • pip install --upgrade "jax[cpu]"
  • GPU
    • pip install --upgrade "jax[cuda]"

Please note that the GPU version requires CUDA and cuDNN to be installed as well. They do not come with the pip installation. Apart from these two, it also requires that the jaxlib library is installed, which can be done with pip (pip install -U jaxlib). Please feel free to check the JAX installation link for detailed installation instructions.

We'll start by importing jax, jax.numpy, and numpy as we'll be using these three in our tutorial.

In [1]:
import jax

print("JAX Version : {}".format(jax.__version__))
JAX Version : 0.2.24
In [2]:
from jax import numpy as jnp

import numpy as np

1. Array Creation

In this section, we'll explain various ways to create a JAX array. The jax.numpy module has almost the same API as numpy, hence the majority of its functions are the same as in numpy. A little background with numpy will make this tutorial quite easy to grasp.

Below we have created a JAX array using the arange() function, which starts at 0 by default and increments the value by 1, stopping before it reaches 10. It has parameters start, stop, step, and dtype for generating different arrays according to needs.

In [3]:
arr = jnp.arange(10)

arr
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Out[3]:
DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)
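
The start, stop, step, and dtype parameters mentioned above can be sketched as follows. This is a minimal illustration; the variable names are our own and are not used elsewhere in the tutorial.

```python
from jax import numpy as jnp

# start=2, stop=20 (exclusive), step=3
stepped = jnp.arange(2, 20, 3)

# dtype controls the element type of the generated array
quarters = jnp.arange(0, 1, 0.25, dtype=jnp.float32)

print(stepped)   # values 2, 5, 8, 11, 14, 17
print(quarters)  # values 0.0, 0.25, 0.5, 0.75
```

As with numpy, the stop value itself is never included in the result.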

Device (CPU/GPU/TPU) of JAX Array

JAX arrays are generally divided into two categories.

  1. Uncommitted Arrays - Placed on the default device but not pinned to it.
  2. Committed Arrays - Explicitly attached to a particular device.

By default, JAX arrays are uncommitted and kept on the default device (the first entry returned by jax.devices()).

The list of devices can be found by calling jax.devices() function. It returns a list of Device instances describing individual devices. This list has more than one entry if there are multiple GPUs.

If a single GPU is present then by default JAX arrays are kept on GPU. If there is more than one GPU then the jax arrays will be kept on the first GPU from the list returned by jax.devices() function call. If GPU is not present then jax arrays will be kept on the CPU.

We can transfer JAX arrays from one device to another by calling jax.device_put(). The function takes two parameters, where the first parameter is an array and the second parameter is a Device instance. It'll transfer the array to the device specified. If no device parameter is specified then the array is placed on the default GPU/CPU based on the list returned by jax.devices() function (first entry). Arrays placed on an explicitly specified device using jax.device_put() are committed arrays.

We have created this tutorial on a CPU, hence all our arrays will be on the CPU.

In [24]:
device = arr.device() # this is same as arr.device_buffer.device()

device
Out[24]:
<jaxlib.xla_extension.Device at 0x7fb9f6ada8f0>
In [25]:
device.platform, device.device_kind
Out[25]:
('cpu', 'cpu')
In [27]:
## this is default device where arrays will be created.
## This default can be changed by setting 'JAX_PLATFORM_NAME' environment variable to 'cpu' or 'gpu'

default_device = jax.devices()[0]

default_device.platform, default_device.device_kind
Out[27]:
('cpu', 'cpu')
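
The device_put() call described above can be sketched as below. This is a minimal example that assumes only that jax.devices() returns at least one device; on a CPU-only machine the target is simply the CPU, so the transfer is a no-op in practice. The variable names are our own.

```python
import jax
from jax import numpy as jnp

devices = jax.devices()      # list of available Device instances
target = devices[0]          # first entry is the default device

x = jnp.arange(5)            # created on the default device (uncommitted)

# Passing a device explicitly places the array on that device.
x_on_target = jax.device_put(x, target)

print(target.platform, len(devices))
```

On a machine with several GPUs, any other entry of the devices list could be passed as the second argument instead.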

We can also create an array by using the array() function, giving it any numpy array or Python list. Below we have created a JAX array from a simple Python list.

In [23]:
arr = jnp.array([1,2,3,4,5])

arr
Out[23]:
DeviceArray([1, 2, 3, 4, 5], dtype=int32)
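
The same array() function also accepts a numpy array, as mentioned above. Below is a small sketch of a round trip from numpy to JAX and back; the variable names are our own.

```python
import numpy as np
from jax import numpy as jnp

np_arr = np.array([[1.5, 2.5], [3.5, 4.5]], dtype=np.float32)

jax_arr = jnp.array(np_arr)   # numpy array -> JAX array
back = np.asarray(jax_arr)    # JAX array -> numpy array on the host

print(type(back), back.shape)
```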

There are also functions like ones() and zeros(), as in numpy, which accept array dimensions as input and create an array with all elements set to one or zero.

In [5]:
jnp.ones((2,3))
Out[5]:
DeviceArray([[1., 1., 1.],
             [1., 1., 1.]], dtype=float32)
In [6]:
jnp.zeros((3,4))
Out[6]:
DeviceArray([[0., 0., 0., 0.],
             [0., 0., 0., 0.],
             [0., 0., 0., 0.]], dtype=float32)

We also have methods like eye() and diag() which let us create an array with elements on the diagonal of the array.

The eye() method accepts a single dimension and it'll create a square array of that shape where all elements on diagonal will be one and all other elements will be zero.

The diag() method accepts single dimension array as input. It then creates a two-dimensional array where elements of the input array will be kept on diagonal and all other elements will be set to zero.

In [7]:
jnp.eye(5)
Out[7]:
DeviceArray([[1., 0., 0., 0., 0.],
             [0., 1., 0., 0., 0.],
             [0., 0., 1., 0., 0.],
             [0., 0., 0., 1., 0.],
             [0., 0., 0., 0., 1.]], dtype=float32)
In [8]:
jnp.diag(jnp.arange(1, 5))
Out[8]:
DeviceArray([[1, 0, 0, 0],
             [0, 2, 0, 0],
             [0, 0, 3, 0],
             [0, 0, 0, 4]], dtype=int32)
In [4]:
arr1 = jnp.arange(12, dtype=jnp.float32).reshape(3,4)

arr1
Out[4]:
DeviceArray([[ 0.,  1.,  2.,  3.],
             [ 4.,  5.,  6.,  7.],
             [ 8.,  9., 10., 11.]], dtype=float32)

2. Normal Array Operations

In this section, we'll perform common array operations like reshape, transpose, dot product, addition, scalar operations, etc.

Below we have created a new array using arange() method and then reshaped it from one dimension to a two-dimensional array.

In [7]:
arr2 = jnp.arange(10,18, dtype=jnp.float32).reshape(4,2)

arr2
Out[7]:
DeviceArray([[10., 11.],
             [12., 13.],
             [14., 15.],
             [16., 17.]], dtype=float32)

Below we have reshaped the array arr1 which we had created earlier. We have changed it from a two-dimensional array to a one-dimensional array.

In [8]:
arr1.reshape(-1)
Out[8]:
DeviceArray([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11.],
            dtype=float32)
ravel()

The ravel() method transforms an array of any shape into a one-dimensional array. It basically flattens an array.

Below we have flattened our two-dimensional array which we had created earlier.

In [6]:
arr1.ravel()
Out[6]:
DeviceArray([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11.],
            dtype=float32)
transpose()

We can find the transpose of an array by calling transpose() method on it or by calling T attribute of it.

Below we have transposed our two-dimensional array.

In [12]:
arr1.T ## works exactly like arr1.transpose()
Out[12]:
DeviceArray([[ 0.,  4.,  8.],
             [ 1.,  5.,  9.],
             [ 2.,  6., 10.],
             [ 3.,  7., 11.]], dtype=float32)
dot()

We can also find the dot product of two arrays using the dot() function. It accepts as input the two arrays which need to be multiplied.

Below we have calculated the dot product of arr1, which is 3x4 dimensional, and arr2, which is 4x2 dimensional, hence the resulting array is 3x2 dimensional.

In [9]:
jnp.dot(arr1, arr2)
Out[9]:
DeviceArray([[ 88.,  94.],
             [296., 318.],
             [504., 542.]], dtype=float32)
add()

We can add two arrays of the same dimension using add() method as explained below.

In [15]:
jnp.add(arr1, arr1)
Out[15]:
DeviceArray([[ 0.,  2.,  4.,  6.],
             [ 8., 10., 12., 14.],
             [16., 18., 20., 22.]], dtype=float32)
argmin() & argmax()

We can find the index of the minimum and maximum elements of an array using the argmin() and argmax() functions.

They also accept an axis argument which specifies the axis along which to find the minimum or maximum. By default, they treat the array as one-dimensional, so for an array with more than one dimension the returned index refers to the flattened array rather than to a row and column. We can ravel the array and then use that index to retrieve the min/max element. We can also find the min/max along a particular axis using the axis parameter.

Below we have explained usage with simple examples.

In [16]:
jnp.argmin(arr1)
Out[16]:
DeviceArray(0, dtype=int32)
In [17]:
jnp.argmin(arr1, axis=1)
Out[17]:
DeviceArray([0, 0, 0], dtype=int32)
In [18]:
jnp.argmax(arr1)
Out[18]:
DeviceArray(11, dtype=int32)
In [19]:
jnp.argmax(arr1, axis=1)
Out[19]:
DeviceArray([3, 3, 3], dtype=int32)
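
The point about the flattened index can be sketched as below. This is a minimal example that rebuilds arr1 so it runs on its own; the row/column conversion assumes the default row-major layout.

```python
from jax import numpy as jnp

arr1 = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)

flat_idx = int(jnp.argmax(arr1))      # index into the flattened array: 11
max_value = arr1.ravel()[flat_idx]    # element at that flat index: 11.0

# Convert the flat index into (row, column) for a row-major 3x4 array.
row, col = divmod(flat_idx, arr1.shape[1])   # (2, 3)

print(flat_idx, float(max_value), (row, col))
```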

3. Simple Statistics

In this section, we'll introduce some functions which are commonly used to perform simple statistical operations like mean, median, standard deviation, variance, correlation, etc.

min()

We can find the minimum element of the array using the min() function.

By default, it'll return the minimum element of the array.

If we want then we can find minimum elements across various axes as well using axis parameter.

In [20]:
jnp.min(arr1)
Out[20]:
DeviceArray(0., dtype=float32)

Below we have found the minimum of each column by reducing along axis 0 (that is, across the rows of the array).

In [11]:
jnp.min(arr1, axis=0)
Out[11]:
DeviceArray([0., 1., 2., 3.], dtype=float32)
max()

This method works exactly like min() but finds the maximum element.

In [21]:
jnp.max(arr1)
Out[21]:
DeviceArray(11., dtype=float32)
In [12]:
jnp.max(arr1, axis=0)
Out[12]:
DeviceArray([ 8.,  9., 10., 11.], dtype=float32)
mean()

This function helps find an average of the whole array or across different axes.

In [22]:
jnp.mean(arr1)
Out[22]:
DeviceArray(5.5, dtype=float32)
In [13]:
jnp.mean(arr1, axis=1)
Out[13]:
DeviceArray([1.5, 5.5, 9.5], dtype=float32)
std()

The std() function helps us find the standard deviation of the array. We can use axis parameter to find standard deviation across different axes.

In [23]:
jnp.std(arr1)
Out[23]:
DeviceArray(3.4520526, dtype=float32)
In [14]:
jnp.std(arr1, axis=1)
Out[14]:
DeviceArray([1.118034, 1.118034, 1.118034], dtype=float32)
sum()

The sum() function as the name suggests helps us find the sum of all array elements. We can use axis parameter to find the sum of different axes.

In [24]:
jnp.sum(arr1)
Out[24]:
DeviceArray(66., dtype=float32)
In [15]:
jnp.sum(arr1, axis=0)
Out[15]:
DeviceArray([12., 15., 18., 21.], dtype=float32)
var()

The var() function helps us find the variance of the array. We can use axis parameter to find variance across different axes.

In [25]:
jnp.var(arr1)
Out[25]:
DeviceArray(11.916667, dtype=float32)
In [16]:
jnp.var(arr1, axis=1)
Out[16]:
DeviceArray([1.25, 1.25, 1.25], dtype=float32)
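
The relation between the std() and var() outputs above can be checked with a short sketch. It rebuilds arr1 so it runs on its own and computes the variance of the first row by hand, using the default population formula (mean of squared deviations).

```python
from jax import numpy as jnp

arr1 = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)

row_std = jnp.std(arr1, axis=1)
row_var = jnp.var(arr1, axis=1)

# Variance of the first row [0, 1, 2, 3] by hand:
# mean = 1.5, squared deviations = 2.25, 0.25, 0.25, 2.25, average = 1.25
first_row = arr1[0]
manual_var = jnp.mean((first_row - jnp.mean(first_row)) ** 2)

print(row_std ** 2, row_var, manual_var)
```

The squared standard deviation matches the variance for every row, which is why both outputs above repeat the same value three times.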
correlate()

The correlate() function helps us find the cross-correlation of two 1-dimensional arrays.

Below we have found the correlation between the row at index 1 of arr1 and the column at index 1 of arr2. The function lets us find a correlation between two one-dimensional arrays only.

In [28]:
arr1[1], arr2[:, 1]
Out[28]:
(DeviceArray([4., 5., 6., 7.], dtype=float32),
 DeviceArray([11., 13., 15., 17.], dtype=float32))
In [29]:
jnp.correlate(arr1[1], arr2[:, 1])
Out[29]:
DeviceArray([318.], dtype=float32)
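
For two arrays of equal length, the single value returned above is the sum of element-wise products. Below is a small sketch that checks this by hand; it rebuilds arr1 and arr2 so it runs on its own.

```python
from jax import numpy as jnp

arr1 = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)
arr2 = jnp.arange(10, 18, dtype=jnp.float32).reshape(4, 2)

a = arr1[1]        # [4., 5., 6., 7.]
b = arr2[:, 1]     # [11., 13., 15., 17.]

# 4*11 + 5*13 + 6*15 + 7*17 = 44 + 65 + 90 + 119 = 318
manual = jnp.sum(a * b)
library = jnp.correlate(a, b)   # default mode returns a single value here

print(float(manual), float(library[0]))
```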

4. Random Numbers

We can generate random numbers using random module of jax. In this section, we'll explain with a few simple examples how we can generate random numbers using jax.

All functions of the jax.random module require us to provide a key (seed) to generate random numbers. We create the key by calling jax.random.PRNGKey() with an integer of our choice. Calling a function again with the same key returns the same values.

randint()

We can generate an array of random integers in a particular range using the randint() function. We need to provide the key for generating random numbers as the first parameter, followed by the shape of the output array and the range as minimum (inclusive) and maximum (exclusive) values.

Below we have generated an array of random integers of shape (2,3) in the range [1,10).

In [190]:
jax.random.randint(key=jax.random.PRNGKey(123), shape=(2,3), minval=1, maxval=10)
Out[190]:
DeviceArray([[8, 3, 2],
             [4, 5, 3]], dtype=int32)
uniform()

We can generate an array of samples from a uniform distribution in a particular range using the uniform() function. It works almost exactly like the randint() function. We need to provide the key followed by the shape of the array. The minimum and maximum values of the range default to 0 and 1 respectively.

Below we have generated two uniformly distributed random arrays using the uniform() function. The first one generates floats in the range [0,1) and the second one generates floats in the range [1,10).

In [194]:
jax.random.uniform(key=jax.random.PRNGKey(123), shape=(2,3))
Out[194]:
DeviceArray([[0.38492894, 0.38952553, 0.2153877 ],
             [0.18297386, 0.8140422 , 0.7754953 ]], dtype=float32)
In [195]:
jax.random.uniform(key=jax.random.PRNGKey(123), shape=(2,3), minval=1, maxval=10)
Out[195]:
DeviceArray([[4.46436  , 4.5057297, 2.9384894],
             [2.6467648, 8.32638  , 7.979458 ]], dtype=float32)
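
Because both calls above use the same key, the second output looks like the first one rescaled from [0,1) to [1,10), for example 1 + 9 * 0.38492894 = 4.46436. The sketch below checks this relation numerically. Treat it as an observation about these outputs rather than a documented guarantee.

```python
import jax
from jax import numpy as jnp

key = jax.random.PRNGKey(123)

unit = jax.random.uniform(key, shape=(2, 3))                       # range [0, 1)
scaled = jax.random.uniform(key, shape=(2, 3), minval=1, maxval=10)

# Rescale the unit samples by hand: minval + (maxval - minval) * unit
rescaled_by_hand = 1.0 + 9.0 * unit

print(jnp.allclose(scaled, rescaled_by_hand, atol=1e-5))
```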
normal()

We can generate normal random values using normal() function. We need to give it seed followed by the shape of the expected output array.

In [197]:
jax.random.normal(key=jax.random.PRNGKey(123), shape=(2,3))
Out[197]:
DeviceArray([[-0.29256082, -0.28055587, -0.7878654 ],
             [-0.9040898 ,  0.8928908 ,  0.7570675 ]], dtype=float32)
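
All examples in this section reuse the same key, so repeated calls return identical numbers. To get fresh values we can derive new keys with jax.random.split(). Below is a minimal sketch; the variable names are our own.

```python
import jax
from jax import numpy as jnp

key = jax.random.PRNGKey(123)

a = jax.random.normal(key, shape=(3,))
b = jax.random.normal(key, shape=(3,))       # same key -> same values as a

# Split the key to obtain a new subkey for the next draw.
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, shape=(3,))    # different key -> different values

print(jnp.array_equal(a, b), jnp.array_equal(a, c))
```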

5. Working with Functions

In this section, we'll explain how we can create a simple python function and pass jax arrays to it. We'll be performing simple operations on jax arrays through function. This section will build the background for the next section which is automatic gradients of functions.

Below we have created a simple function that takes three parameters as input. It then squares the first parameter x. Then it multiplies parameter m with the squared value of x and adds parameter c to it. At last, it returns the sum of all values of the output array.

In [6]:
def func(x, m, c):
    x_square = (x * x)
    y = m * x_square + c

    return y.sum()

Below we have created an array with different values and called our function designed above to check whether it’s working properly. We can notice from the output that it seems to do the job.

The first call passes an array of size 2x2 with all elements set to 1. The value of m is set to 10 and the value of c to 5. Squaring the 2x2 array of ones leaves every element at 1. Each element is then multiplied by 10 (value of m), giving a 2x2 array with all elements set to 10. Then 5 (value of c) is added to every element, hence the final 2x2 array has all elements set to 15. At last, we sum the 4 elements, each having a value of 15, which results in a total of 60. The same calculation applies to the second call in the next cell, where we pass an array of size 2x2 with all elements set to 2: each element becomes 10 * 4 + 5 = 45 and the sum is 180.

We'll be using this function in the next section and try to find out the gradient of it.

In [50]:
arr = jnp.ones((2,2), dtype=jnp.float32)
m = jnp.array(10, dtype=jnp.float32)
c = jnp.array(5, dtype=jnp.float32)

func(arr, m, c)
Out[50]:
DeviceArray(60., dtype=float32)
In [51]:
arr = jnp.ones((2,2), dtype=jnp.float32) * 2
m = jnp.array(10, dtype=jnp.float32)
c = jnp.array(5, dtype=jnp.float32)

func(arr, m, c)
Out[51]:
DeviceArray(180., dtype=float32)

6. Automatic Gradients/Differentiation

JAX can help us differentiate a function that works on JAX arrays. There are many situations where we need to find the derivative of a function and evaluate it. We can easily write code for the derivative of simple functions, but things get really complicated for more involved functions.

JAX can automatically differentiate functions which work on JAX arrays with just one function call. It also lets us find the derivative of a derivative and so on. The main requirement is that our function should work on JAX arrays and the output of the function should be a scalar.

JAX provides a function named grad() which can be used to find out the gradient of function working with jax arrays.


  • grad(func,argnums=0,has_aux=False,allow_int=False) - This function takes as input another function working on JAX arrays and differentiates it with respect to the first parameter of the function by default. It returns another function which computes the gradient of our input function. We can call this gradient function with the same parameters as our input function and it'll return the gradient of the input function with respect to the specified parameters.
    • The argnums parameter accepts a single integer or a tuple of integers specifying the index of the argument(s) with respect to which to differentiate. By default, it's set to 0, which differentiates with respect to the first parameter. We can give any index other than 0 as well. We can also give a tuple of indices if we want to differentiate with respect to more than one parameter.
    • The has_aux parameter accepts a boolean value. It is used in situations where our input function returns a pair whose first element is the output to differentiate and whose second element is auxiliary data. The differentiation is calculated based on the first element only. If this parameter is set to True then the gradient function returns the gradient together with the auxiliary data. It's False by default, hence our function needs to return a single value.
    • The allow_int parameter accepts a boolean value. It can be set to True if we want to allow differentiation with respect to integer values. By default, it's False and differentiation with respect to only floats is allowed.
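
The has_aux parameter described above can be sketched as follows. The function func_with_aux is a hypothetical variant of the tutorial's function that also returns the intermediate array y as auxiliary data.

```python
import jax
from jax import numpy as jnp

def func_with_aux(x, m, c):
    y = m * (x * x) + c
    # First element is differentiated, second is passed through untouched.
    return y.sum(), y

x = jnp.ones((2, 2), dtype=jnp.float32)
m = jnp.array(10, dtype=jnp.float32)
c = jnp.array(5, dtype=jnp.float32)

# With has_aux=True the gradient function returns (gradient, auxiliary data).
grads, aux = jax.grad(func_with_aux, has_aux=True)(x, m, c)

print(grads)  # every element is 2 * m * x = 20
print(aux)    # every element is m * x^2 + c = 15
```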

The grad() function can also differentiate a function with respect to a parameter which is a Python list, dictionary, or tuple, but all elements of these Python types should be JAX arrays. It'll differentiate the input function with respect to all JAX arrays given through the Python list, dict, or tuple.
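
Differentiating with respect to a dictionary of arrays can be sketched as below. The function loss and the params dictionary are hypothetical names introduced only for this illustration; the computation is the same as in the tutorial's function.

```python
import jax
from jax import numpy as jnp

def loss(params, x):
    # params is a dict of JAX arrays; grad() differentiates w.r.t. each entry.
    y = params["m"] * (x * x) + params["c"]
    return y.sum()

params = {"m": jnp.array(10.0), "c": jnp.array(5.0)}
x = jnp.ones((2, 2), dtype=jnp.float32)

grads = jax.grad(loss)(params, x)

# The result has the same structure as params.
print(grads["m"], grads["c"])  # both are 4.0 for a 2x2 array of ones
```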


Below we have calculated the gradient of our function from the previous section. By default, grad() function will find out gradient with respect to the first parameter which is x in our case. We have then evaluated the gradient by giving input data (arr,m, and c) as well.

As our earlier function was easy to understand, we can work out its gradient by hand as well. With respect to x it'll be m*2x. We have given as input x (array of size 2x2 with all elements set to 1), m (scalar value of 10) and c (scalar value of 5). To evaluate the gradient, all elements of x are multiplied by 2 and then by 10 (value of m), so every element of the result is 20.

In [ ]:
def func(x, m, c):
    x_square = (x * x)
    y = m * x_square + c

    return y.sum()
In [52]:
arr = jnp.ones((2,2), dtype=jnp.float32)
m = jnp.array(10, dtype=jnp.float32)
c = jnp.array(5, dtype=jnp.float32)

jax.grad(func)(arr, m, c)
Out[52]:
DeviceArray([[20., 20.],
             [20., 20.]], dtype=float32)
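
The hand-derived gradient m*2x can be compared against the output of grad() for an input whose elements are not all equal. Below is a minimal sketch; the input values are our own choice.

```python
import jax
from jax import numpy as jnp

def func(x, m, c):
    x_square = (x * x)
    y = m * x_square + c
    return y.sum()

x = jnp.array([[1., 2.], [3., 4.]])
m = jnp.array(10, dtype=jnp.float32)
c = jnp.array(5, dtype=jnp.float32)

auto_grad = jax.grad(func)(x, m, c)
manual_grad = 2 * m * x           # derivative of m*x^2 + c w.r.t. x

print(auto_grad)    # [[20., 40.], [60., 80.]]
```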

Below we have explained another example of finding out the gradient of function where we are finding out the gradient of our function with respect to the second parameter by setting argnums parameter of grad() function to (1,).

In [54]:
arr = jnp.ones((2,2), dtype=jnp.float32)

jax.grad(func, argnums=(1,))(arr, m, c)
Out[54]:
(DeviceArray(4., dtype=float32),)

Below we have explained another example, where we are finding out the gradient of our input function with respect to all three parameters.

In [58]:
arr = jnp.ones((2,2), dtype=jnp.float32)

grad1, grad2, grad3 = jax.grad(func, argnums=(0, 1, 2))(arr, m, c)

grad1, grad2, grad3
Out[58]:
(DeviceArray([[20., 20.],
              [20., 20.]], dtype=float32),
 DeviceArray(4., dtype=float32),
 DeviceArray(4., dtype=float32))
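
As mentioned at the start of this section, grad() can be applied to its own output to get higher-order derivatives. Below is a small sketch with a hypothetical scalar function cube(), chosen because its derivatives are easy to verify by hand.

```python
import jax

def cube(x):
    return x ** 3

first = jax.grad(cube)              # 3 * x^2
second = jax.grad(jax.grad(cube))   # 6 * x

print(float(first(3.0)), float(second(3.0)))  # 27.0 and 18.0
```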

7. vmap (Vectorized Mapping)

There are situations where we need our user-defined functions to work on a batch of elements in parallel and combine the results. JAX provides a function named vmap() for this purpose. It's short for vectorized map and, as its name suggests, it works like Python's map() function, but it takes JAX arrays as input and lets us specify which axis of an array to vectorize over.


  • vmap(func, in_axes=0,out_axes=0) - This function takes as input another function and returns a vectorized version of it. Its usage will become easy to understand with the examples below.
    • The in_axes parameter takes a single integer or a tuple of integers (one entry per parameter of the function) specifying the axis of each input over which to vectorize. An entry of None means that the parameter is not vectorized and is passed unchanged to every call.

Below we have first created an array of size 3x2x2 whose elements will be used as input for parameter x of our function. We'll be vectorizing this array at axis 0 hence all 2x2 arrays will be given as the value of parameter x of function.

In [4]:
arr = jnp.array([1,1,1,1,2,2,2,2,3,3,3,3])

arr = arr.reshape(3,2,2)

arr
Out[4]:
DeviceArray([[[1, 1],
              [1, 1]],

             [[2, 2],
              [2, 2]],

             [[3, 3],
              [3, 3]]], dtype=int32)

Below we have vectorized our function which we had defined in the working with functions section and used in the automatic gradients section. We have vectorized input for parameter x at the 0th axis of the input array.

We have then called vectorized function with our 3x2x2 array for parameter x and scalar values of parameter m and c. The vectorized function will work on 3 2x2 arrays for parameter x. The value of parameters m and c are not vectorized and will be the same for all 2x2 arrays. The output will be an array of size 3 as we have 3 2x2 arrays as input.

In [ ]:
def func(x, m, c):
    x_square = (x * x)
    y = m * x_square + c

    return y.sum()
In [10]:
m = jnp.array(10, dtype=jnp.float32)
c = jnp.array(5, dtype=jnp.float32)

vmapped_func = jax.vmap(func, in_axes=(0,None,None))

vmapped_func(arr,m,c)
Out[10]:
DeviceArray([ 60., 180., 380.], dtype=float32)
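
One way to read the vectorized call above is as a replacement for a Python loop over the first axis. The sketch below compares the two; it is self-contained and uses a float version of the 3x2x2 array.

```python
import jax
from jax import numpy as jnp

def func(x, m, c):
    x_square = (x * x)
    y = m * x_square + c
    return y.sum()

arr = jnp.array([1., 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]).reshape(3, 2, 2)
m = jnp.array(10, dtype=jnp.float32)
c = jnp.array(5, dtype=jnp.float32)

# Explicit loop: call func once per 2x2 slice and stack the results.
looped = jnp.stack([func(x, m, c) for x in arr])

# vmap: same computation expressed as one batched call.
batched = jax.vmap(func, in_axes=(0, None, None))(arr, m, c)

print(looped, batched)   # both are [60., 180., 380.]
```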

Below we have printed the slices that the function receives if we vectorize the input along axis 1 instead of axis 0. If we vectorize our 3x2x2 input array along axis 1, then each input for parameter x will be an array of size 3x2.

In [13]:
arr[:,0,:], arr[:,1,:]
Out[13]:
(DeviceArray([[1, 1],
              [2, 2],
              [3, 3]], dtype=int32),
 DeviceArray([[1, 1],
              [2, 2],
              [3, 3]], dtype=int32))

In the below cell, we have vectorized the input function along axis 1 of the first parameter, which is x in our case. We can notice that the output array is of size 2, because vectorizing the array along axis 1 yields 2 arrays of size 3x2 for parameter x.

In [11]:
m = jnp.array(10, dtype=jnp.float32)
c = jnp.array(5, dtype=jnp.float32)

vmapped_func = jax.vmap(func, in_axes=(1,None,None))

vmapped_func(arr,m,c)
Out[11]:
DeviceArray([310., 310.], dtype=float32)
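
Vectorizing along axis 1 gives the same result as first moving axis 1 to the front and then vectorizing along axis 0. The sketch below checks this with jnp.moveaxis(); it is self-contained and uses a float version of the 3x2x2 array.

```python
import jax
from jax import numpy as jnp

def func(x, m, c):
    x_square = (x * x)
    y = m * x_square + c
    return y.sum()

arr = jnp.array([1., 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]).reshape(3, 2, 2)
m = jnp.array(10, dtype=jnp.float32)
c = jnp.array(5, dtype=jnp.float32)

along_axis1 = jax.vmap(func, in_axes=(1, None, None))(arr, m, c)

# Move axis 1 to the front (shape becomes 2x3x2), then map over axis 0.
moved = jnp.moveaxis(arr, 1, 0)
along_axis0 = jax.vmap(func, in_axes=(0, None, None))(moved, m, c)

print(along_axis1, along_axis0)   # both are [310., 310.]
```

Each 3x2 slice contains the values 1, 1, 2, 2, 3, 3, so the result is 10 * (1 + 1 + 4 + 4 + 9 + 9) + 6 * 5 = 310.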

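Below we have explained another example where we are vectorizing the gradient of a function. As grad() only accepts floating-point inputs, we use a 2x2 float array of ones for parameter x here. The vectorized gradient is evaluated for each of its two rows.

In [105]:
arr = jnp.ones((2,2), dtype=jnp.float32)
m = jnp.array(10, dtype=jnp.float32)
c = jnp.array(5, dtype=jnp.float32)

jax.vmap(jax.grad(func), in_axes=(0,None,None))(arr,m,c)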
Out[105]:
DeviceArray([[20., 20.],
             [20., 20.]], dtype=float32)

Below we have created another example where we have vectorized function at 2nd parameter m. We have given a single array of size 2x2 for parameter x and a single scalar value for parameter c. For parameter m, we have declared an array of 3 elements.

In [106]:
arr = jnp.ones((2,2), dtype=jnp.float32)
m = jnp.array([10.,15.,20.,])
c = jnp.array(5, dtype=jnp.float32)

jax.vmap(func, in_axes=(None,0,None))(arr,m,c)
Out[106]:
DeviceArray([ 60.,  80., 100.], dtype=float32)

In the next cell, we have created another example where we are vectorizing function for two input parameters (m and c). The value of parameter x is a fixed 2x2 array.

In [107]:
arr = jnp.ones((2,2), dtype=jnp.float32)
m = jnp.array([10.,15.,20.,])
c = jnp.array([5.,10.,15.,])

jax.vmap(func, in_axes=(None,0,0))(arr,m,c)
Out[107]:
DeviceArray([ 60., 100., 140.], dtype=float32)

The below example shows how we can vectorize our function for all three input parameters. Each call of the function will work on a 2x2 array for parameter x and scalar values for parameters m and c.

In [108]:
arr = jnp.array([1.,1,1,1,2,2,2,2,3,3,3,3])
arr = arr.reshape(3,2,2)

m = jnp.array([10.,15.,20.,])
c = jnp.array([5.,10.,15.,])

jax.vmap(func, in_axes=(0,0,0))(arr,m,c)
Out[108]:
DeviceArray([ 60., 280., 780.], dtype=float32)

Below we have vectorized the gradient of our input function for all three input parameters.

In [109]:
jax.vmap(jax.grad(func), in_axes=(0,0,0))(arr,m,c)
Out[109]:
DeviceArray([[[ 20.,  20.],
              [ 20.,  20.]],

             [[ 60.,  60.],
              [ 60.,  60.]],

             [[120., 120.],
              [120., 120.]]], dtype=float32)

8. Just In Time (JIT) Compiled

In this section, we'll explain how we can speed up the operations performed using JAX. As we had explained earlier, JAX can use the XLA (Accelerated Linear Algebra) compiler to fuse multiple operations together and reduce their running time. XLA can speed up many linear algebra operations on both CPUs and GPUs.

In order to use XLA to speed up operations, JAX provides us with a function named jit(). We can either wrap an existing function inside of jit() or decorate that function with the @jit decorator to speed it up.
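
The decorator form mentioned above can be sketched as below. The name func_jitted is our own; the body is the same computation as the tutorial's function.

```python
import jax
from jax import numpy as jnp

@jax.jit
def func_jitted(x, m, c):
    x_square = (x * x)
    y = m * x_square + c
    return y.sum()

arr = jnp.ones((2, 2), dtype=jnp.float32)
m = jnp.array(10, dtype=jnp.float32)
c = jnp.array(5, dtype=jnp.float32)

out = func_jitted(arr, m, c)   # compiled on the first call, reused afterwards

print(float(out))   # 60.0
```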

Below we have taken the function which we had been using in our previous sections and wrapped it inside of jit() to speed it up. We have then called the normal function and the jit-wrapped function a few times and recorded the time taken by each call for comparison. Please note that the first call to a jit-wrapped function for a given input shape includes compilation time, and that JAX dispatches work asynchronously, so timings like these are only indicative unless we wait for the result with block_until_ready().

The speedup provided by the jit-wrapped function becomes more visible as the size of JAX arrays increases.

In [ ]:
def func(x, m, c):
    x_square = (x * x)
    y = m * x_square + c

    return y.sum()
In [184]:
from jax import jit

arr = jnp.ones((2,2), dtype=jnp.float32)
m = jnp.array(10, dtype=jnp.float32)
c = jnp.array(5, dtype=jnp.float32)

jitted_func = jax.jit(func)

%time out1 = func(arr,m,c)

%time out2 = jitted_func(arr,m,c)

%time out3 = jitted_func(arr,m,c)
CPU times: user 131 µs, sys: 0 ns, total: 131 µs
Wall time: 141 µs
CPU times: user 63 µs, sys: 0 ns, total: 63 µs
Wall time: 70.6 µs
CPU times: user 37 µs, sys: 0 ns, total: 37 µs
Wall time: 43.9 µs
In [185]:
arr = jnp.ones((5,5), dtype=jnp.float32)
m = jnp.array(10, dtype=jnp.float32)
c = jnp.array(5, dtype=jnp.float32)

jitted_func = jax.jit(func)

%time out1 = func(arr,m,c)

%time out2 = jitted_func(arr,m,c)

%time out3 = jitted_func(arr,m,c)

%time out4 = jitted_func(arr*2,m,c)
CPU times: user 1.75 ms, sys: 0 ns, total: 1.75 ms
Wall time: 1.23 ms
CPU times: user 127 µs, sys: 0 ns, total: 127 µs
Wall time: 139 µs
CPU times: user 143 µs, sys: 0 ns, total: 143 µs
Wall time: 106 µs
CPU times: user 681 µs, sys: 0 ns, total: 681 µs
Wall time: 531 µs
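
Outside of a notebook, or when more reliable numbers are needed, the comparison above can be sketched with a small timing helper. The helper timed() is a hypothetical name of our own. It calls block_until_ready() so that the measured time covers the actual computation, and the array size is an arbitrary choice.

```python
import time
import jax
from jax import numpy as jnp

def func(x, m, c):
    x_square = (x * x)
    y = m * x_square + c
    return y.sum()

def timed(fn, *args):
    """Run fn(*args), wait for the result, and return (result, seconds)."""
    start = time.perf_counter()
    out = fn(*args)
    out.block_until_ready()   # JAX dispatch is asynchronous
    return out, time.perf_counter() - start

arr = jnp.ones((500, 500), dtype=jnp.float32)
m = jnp.array(10, dtype=jnp.float32)
c = jnp.array(5, dtype=jnp.float32)

jitted_func = jax.jit(func)

out1, t_plain = timed(func, arr, m, c)
out2, t_first = timed(jitted_func, arr, m, c)    # includes compilation
out3, t_second = timed(jitted_func, arr, m, c)   # reuses compiled code

print(t_plain, t_first, t_second)
```

The actual timings depend on the machine and JAX version, so no particular numbers should be expected.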

9. JIT + vmap

In this section, we have explained how we can combine the vmap() and jit() functions to speed up the computation further.

Below we have first taken the function which we have been using for the last few sections and wrapped it using the vmap() function. We have then wrapped the vmap-wrapped function inside of jit(). The vmap() function vectorizes our function so that the whole batch is processed with array operations instead of a Python loop. The jit() function increases speed further by compiling the result with the XLA compiler.

After wrapping our function inside of both vmap() and jit(), we have called it using different array sizes. We have called the vmap-wrapped function as well and recorded its time. The first call to the jit-wrapped function is slower because it includes compilation. The second call shows the time difference between the vmap-wrapped function and the vmap & jit-wrapped function.

In [ ]:
def func(x, m, c):
    x_square = (x * x)
    y = m * x_square + c

    return y.sum()
In [147]:
vmapped_func = jax.vmap(func, in_axes=(0,0,0))
jitted_func = jax.jit(vmapped_func)
In [148]:
arr = jnp.ones((3,2,2), dtype=jnp.float32)

m = jnp.ones(shape=(3,), dtype=jnp.float32) * 10
c = jnp.ones(shape=(3,), dtype=jnp.float32) * 5

%time out1 = vmapped_func(arr,m,c)

%time out2 = jitted_func(arr,m,c)

%time out3 = jitted_func(arr,m,c)
CPU times: user 4.94 ms, sys: 0 ns, total: 4.94 ms
Wall time: 5.12 ms
CPU times: user 38 ms, sys: 0 ns, total: 38 ms
Wall time: 36.8 ms
CPU times: user 29 µs, sys: 0 ns, total: 29 µs
Wall time: 31.9 µs
In [150]:
arr = jnp.ones((100,2,2), dtype=jnp.float32)

m = jnp.ones(shape=(100,), dtype=jnp.float32) * 10
c = jnp.ones(shape=(100,), dtype=jnp.float32) * 5

%time out1 = vmapped_func(arr,m,c)

%time out2 = jitted_func(arr,m,c)

%time out3 = jitted_func(arr,m,c)
CPU times: user 1.11 ms, sys: 171 µs, total: 1.28 ms
Wall time: 1.3 ms
CPU times: user 32.3 ms, sys: 0 ns, total: 32.3 ms
Wall time: 31.9 ms
CPU times: user 634 µs, sys: 0 ns, total: 634 µs
Wall time: 506 µs
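
The same composition also works with grad(). Below is a minimal sketch that stacks jit(), vmap(), and grad() to compute one gradient per batch element and checks it against the hand-derived formula 2*m*x. The name per_example_grad is our own.

```python
import jax
from jax import numpy as jnp

def func(x, m, c):
    x_square = (x * x)
    y = m * x_square + c
    return y.sum()

arr = jnp.array([1., 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]).reshape(3, 2, 2)
m = jnp.array([10., 15., 20.])
c = jnp.array([5., 10., 15.])

# grad -> gradient w.r.t. x, vmap -> one gradient per batch entry, jit -> compiled
per_example_grad = jax.jit(jax.vmap(jax.grad(func), in_axes=(0, 0, 0)))

grads = per_example_grad(arr, m, c)

# Hand-derived gradient 2 * m * x, with m broadcast over each 2x2 block.
expected = 2 * m[:, None, None] * arr

print(grads.shape)   # (3, 2, 2)
```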

This ends our small tutorial explaining JAX library and various functionalities provided by it. Please feel free to let us know your views in the comments.




Sunny Solanki