JAX is a Python library designed to make machine learning research easier. It provides a NumPy-like API for creating multidimensional arrays and performing operations on them. Arrays can be transferred between the CPU and GPU/TPU. They can also be created directly on an accelerator, which avoids a transfer between devices. Apart from this, JAX can compute **automatic differentiation** of Python functions. JAX is the successor of the well-known automatic differentiation library **Autograd**. It can differentiate almost any Python code involving loops, conditions, recursion, etc., and it can also compute derivatives of derivatives. JAX compiles code using **XLA** (Accelerated Linear Algebra), which fuses many operations together to reduce the run time of the code on the underlying hardware, and then runs it on CPU/GPU/TPU. JAX also provides a just-in-time (JIT) compiler that uses XLA to make Python functions run faster.

Below we have highlighted important features of JAX.

- NumPy-like API on CPU/GPU/TPU.
- XLA (Accelerated Linear Algebra) for faster execution.
- Automatic Differentiation.
- Just-In-Time Compilation.

As a part of this tutorial, we'll cover the basics of JAX and explain its main features with a few simple examples. JAX can be installed with pip using one of the commands below.

**CPU**- pip install --upgrade "jax[cpu]"

**GPU**- pip install --upgrade "jax[cuda]"

Please make a **NOTE** that, with the command above, the GPU version requires CUDA and cuDNN to be installed separately; they do not come with the pip installation. JAX also requires the **jaxlib** library; the pip extras shown above normally install it, and it can also be pip installed on its own (pip install -U jaxlib). The exact GPU extra depends on the JAX release (recent releases use extras such as "jax[cuda12]"), so please check the JAX installation guide for detailed instructions.

We'll start by importing **jax**, **jax.numpy**, and **numpy**, as we'll be using these three throughout the tutorial.

In [1]:

```
import jax
print("JAX Version : {}".format(jax.__version__))
```

In [2]:

```
from jax import numpy as jnp
import numpy as np
```

In this section, we'll explain various ways to create a JAX array. The **jax.numpy** module has almost the same API as **numpy**, so most **numpy** functions are available under the same names. A little background with **numpy** will make this tutorial easy to grasp.

Below we have created a JAX array using the **arange()** function, which starts at 0 by default and increments the value by 1 up to, but not including, 10. It has the parameters **start, stop, step, and dtype** for generating different arrays according to our needs.
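As a small sketch of these parameters (the names **a** and **b** are only for illustration), the cell below shows that **stop** is excluded from the result and that **dtype** sets the element type.

```python
import jax.numpy as jnp

# arange(stop): 0, 1, ..., stop - 1 (stop itself is excluded)
a = jnp.arange(10)

# arange(start, stop, step) with an explicit dtype: 2., 5., 8. (11 is excluded)
b = jnp.arange(2, 11, 3, dtype=jnp.float32)

print(a)
print(b, b.dtype)
```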

In [3]:

```
arr = jnp.arange(10)
arr
```

Out[3]:

JAX arrays are generally divided into two categories.

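- **Uncommitted arrays**: placed on the default device but not committed to it, so JAX is free to move them when they are combined with arrays on other devices.
- **Committed arrays**: explicitly placed on (committed to) a specific device (CPU/GPU/TPU).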

By default, JAX arrays are uncommitted and are placed on the default device.

The list of devices can be found by calling **jax.devices()** function. It returns a list of **Device** instances describing individual devices. This list has more than one entry if there are multiple GPUs.

If one or more GPUs are present, JAX arrays are placed by default on the first GPU in the list returned by the **jax.devices()** function. If no GPU is present, JAX arrays are kept on the CPU.

We can transfer JAX arrays from one device to another by calling **jax.device_put()**. The function takes an array as its first parameter and, optionally, a **Device** instance as its second parameter, and transfers the array to the specified device. If no device is specified, the array is transferred to the default device (the first entry of the list returned by **jax.devices()**). Arrays transferred using **jax.device_put()** with an explicit device are committed to that device; without a device argument, the result stays uncommitted.
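A minimal sketch of device placement, assuming only that at least one device is visible to JAX (on a CPU-only machine this is the CPU). It uses **jax.devices()**, **jax.device_put()** and the **devices()** method of JAX arrays; the variable names are illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np

# jax.devices() lists the available devices; the first entry is the default device
devices = jax.devices()
print(devices)

# Move a NumPy array to the default device (no device given, so the result is uncommitted)
x = jax.device_put(np.arange(4))

# Explicitly place the array on a device; the result is committed to that device
y = jax.device_put(x, devices[0])

# .devices() returns the set of devices holding each array
print(x.devices(), y.devices())
```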

We ran this tutorial on a CPU-only machine, hence all our arrays are on the CPU.

In [24]:

```
device = next(iter(arr.devices())) # arr.devices() returns the set of devices holding arr
device
```

Out[24]:

In [25]:

```
device.platform, device.device_kind
```

Out[25]:

In [27]:

```
## This is the default device where arrays will be created.
## The default can be changed with the 'JAX_PLATFORMS' environment variable (e.g. 'cpu'); older releases used 'JAX_PLATFORM_NAME'.
default_device = jax.devices()[0]
default_device.platform, default_device.device_kind
```

Out[27]:

We can also create an array with the **array()** function by passing it a NumPy array or a Python list. Below we have created a JAX array from a simple Python list.

In [23]:

```
arr = jnp.array([1,2,3,4,5])
arr
```

Out[23]:

There are also functions like **ones()** and **zeros()**, as in NumPy, which accept the array shape as input and create an array with all elements set to one or zero, respectively.

In [5]:

```
jnp.ones((2,3))
```

Out[5]:

In [6]:

```
jnp.zeros((3,4))
```

Out[6]:

We also have functions like **eye()** and **diag()** which let us create arrays with elements on the diagonal.

The **eye()** function accepts a single size N and creates an N x N square array where all elements on the diagonal are one and all other elements are zero (an identity matrix).

The **diag()** function accepts a one-dimensional array as input. It then creates a two-dimensional array where the elements of the input array are placed on the diagonal and all other elements are set to zero.
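The two functions are closely related. As a quick sketch (the vector **v** is an arbitrary example), multiplying an identity matrix by a vector scales column j by v[j], which gives the same array as **diag()**.

```python
import jax.numpy as jnp

v = jnp.array([1.0, 2.0, 3.0])

d = jnp.diag(v)   # 3x3 array with v on the diagonal, zeros elsewhere
i = jnp.eye(3)    # 3x3 identity matrix

# Broadcasting scales column j of the identity by v[j], which reproduces diag(v)
print(d)
print(jnp.array_equal(d, i * v))
```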

In [7]:

```
jnp.eye(5)
```

Out[7]:

In [8]:

```
jnp.diag(jnp.arange(1, 5))
```

Out[8]:

In [4]:

```
arr1 = jnp.arange(12, dtype=jnp.float32).reshape(3,4)
arr1
```

Out[4]:

In this section, we'll perform common array operations such as reshape, transpose, dot product, addition, argmin/argmax, etc.

Below we have created a new array using the **arange()** function and then reshaped it from a one-dimensional to a two-dimensional array.

In [7]:

```
arr2 = jnp.arange(10,18, dtype=jnp.float32).reshape(4,2)
arr2
```

Out[7]:

Below we have reshaped the array that we created earlier from a two-dimensional array into a one-dimensional array. Passing **-1** tells JAX to infer that dimension's length from the total number of elements.

In [8]:

```
arr1.reshape(-1)
```

Out[8]:

The **ravel()** method transforms an array of any shape into a one-dimensional array; it basically flattens the array.
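A short sketch comparing the two flattening approaches on the same 3x4 array used in this section: **reshape(-1)** and **ravel()** give the same row-major result, and **-1** can also be used to infer one dimension when reshaping back.

```python
import jax.numpy as jnp

arr1 = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)

flat_a = arr1.reshape(-1)     # -1: JAX infers the length (12)
flat_b = arr1.ravel()         # flattens in row-major order
back = flat_a.reshape(3, -1)  # -1 is inferred as 4 here

print(jnp.array_equal(flat_a, flat_b), back.shape)
```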

Below we have flattened our two-dimensional array which we had created earlier.

In [6]:

```
arr1.ravel()
```

Out[6]:

We can find the transpose of an array by calling its **transpose()** method or by accessing its **T** attribute.

Below we have transposed our two-dimensional array.

In [12]:

```
arr1.T ## works exactly like arr1.transpose()
```

Out[12]:

We can also find the dot product of two arrays using the **dot()** function. It accepts the two arrays that need to be multiplied as input.

Below we have calculated the dot product of **arr1** (shape 3x4) and **arr2** (shape 4x2). For two-dimensional arrays this is matrix multiplication, hence the resulting array has shape 3x2.
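To make the shape rule concrete, here is a sketch with the same **arr1** and **arr2** as in this section. For 2-D inputs, **jnp.dot()** matches the **@** matrix-multiplication operator, and the top-left entry is row 0 of **arr1** times column 0 of **arr2**: 0*10 + 1*12 + 2*14 + 3*16 = 88.

```python
import jax.numpy as jnp

arr1 = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)
arr2 = jnp.arange(10, 18, dtype=jnp.float32).reshape(4, 2)

out = jnp.dot(arr1, arr2)   # (3, 4) x (4, 2) -> (3, 2)

# For 2-D inputs, dot() is matrix multiplication, same as the @ operator
print(out.shape, jnp.allclose(out, arr1 @ arr2))
print(out[0, 0])            # 88.0
```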

In [9]:

```
jnp.dot(arr1, arr2)
```

Out[9]:

We can add two arrays of the same shape element-wise using the **add()** function, as shown below.

In [15]:

```
jnp.add(arr1, arr1)
```

Out[15]:

We can find the index of the minimum and maximum elements of an array using the **argmin()** and **argmax()** functions.

They also accept an **axis** argument, which specifies the axis along which to find the minimum or maximum. By default (no **axis** given), the array is treated as if it were flattened to one dimension, so for a multi-dimensional array the returned index refers to the flattened array rather than to a position in the original shape. We can ravel the array and then use this index to retrieve the min/max element. To find the min/max along a particular axis instead, we use the **axis** parameter.
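A small sketch of the flattened-index behavior, using a hand-made 2x3 array so the answers are easy to check. It also uses **jnp.unravel_index()**, which is not covered in this tutorial, to convert a flat index back into a (row, column) position.

```python
import jax.numpy as jnp

a = jnp.array([[3.0, 7.0, 5.0],
               [9.0, 1.0, 4.0]])

flat_idx = jnp.argmin(a)                          # index into the flattened array: 4
row, col = jnp.unravel_index(flat_idx, a.shape)   # position in the 2-D array: (1, 1)

print(flat_idx, a.ravel()[flat_idx], (row, col))
print(jnp.argmin(a, axis=1))                      # per-row minimum positions: [0, 1]
print(jnp.argmax(a, axis=0))                      # per-column maximum positions: [1, 0, 0]
```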

Below we have explained usage with simple examples.

In [16]:

```
jnp.argmin(arr1)
```

Out[16]:

In [17]:

```
jnp.argmin(arr1, axis=1)
```

Out[17]:

In [18]:

```
jnp.argmax(arr1)
```

Out[18]:

In [19]:

```
jnp.argmax(arr1, axis=1)
```

Out[19]:

In this section, we'll introduce some functions that are commonly used to perform simple statistical operations like minimum, maximum, mean, standard deviation, sum, variance, and correlation.

We can find the minimum element of an array using the **min()** function.

By default, it returns the minimum element of the whole array.

If we want, we can also find the minimum elements along a particular axis using the **axis** parameter.
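A sketch of how **axis** controls the output shape for the reductions in this section, using the same 3x4 **arr1**. The **keepdims** argument shown on **mean()** is an extra NumPy-compatible option not discussed in the text.

```python
import jax.numpy as jnp

arr1 = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)

col_min = jnp.min(arr1, axis=0)                    # reduce over rows: one value per column, shape (4,)
row_max = jnp.max(arr1, axis=1)                    # reduce over columns: one value per row, shape (3,)
row_mean = jnp.mean(arr1, axis=1, keepdims=True)   # keepdims keeps a (3, 1) shape

print(col_min, row_max, row_mean.shape)
```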

In [20]:

```
jnp.min(arr1)
```

Out[20]:

Below we have found the minimum element of each column of the array (reducing along axis 0, i.e., across all rows).

In [11]:

```
jnp.min(arr1, axis=0)
```

Out[11]:

The **max()** function works exactly like **min()** but finds the maximum element.

In [21]:

```
jnp.max(arr1)
```

Out[21]:

In [12]:

```
jnp.max(arr1, axis=0)
```

Out[12]:

The **mean()** function helps us find the average of the whole array or along different axes.

In [22]:

```
jnp.mean(arr1)
```

Out[22]:

In [13]:

```
jnp.mean(arr1, axis=1)
```

Out[13]:

The **std()** function helps us find the standard deviation of the array. We can use **axis** parameter to find standard deviation across different axes.

In [23]:

```
jnp.std(arr1)
```

Out[23]:

In [14]:

```
jnp.std(arr1, axis=1)
```

Out[14]:

The **sum()** function, as the name suggests, helps us find the sum of all array elements. We can use the **axis** parameter to find sums along different axes.

In [24]:

```
jnp.sum(arr1)
```

Out[24]:

In [15]:

```
jnp.sum(arr1, axis=0)
```

Out[15]:

The **var()** function helps us find the variance of the array. We can use **axis** parameter to find variance across different axes.
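As a sketch of what **var()** and **std()** compute by default: the population variance (dividing by N), with **std()** as its square root. The **ddof** argument, not covered in the text, switches to the sample variance (dividing by N - 1). For the values 0 to 11 this gives 143/12 and 143/11 = 13.

```python
import jax.numpy as jnp

arr1 = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)

mean = jnp.mean(arr1)
var_manual = jnp.mean((arr1 - mean) ** 2)   # population variance: divide by N

print(jnp.var(arr1), var_manual, jnp.std(arr1) ** 2)
print(jnp.var(arr1, ddof=1))                # sample variance: divide by N - 1
```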

In [25]:

```
jnp.var(arr1)
```

Out[25]:

In [16]:

```
jnp.var(arr1, axis=1)
```

Out[16]:

The **correlate()** function helps us find the cross-correlation of two one-dimensional arrays.

Below we have found the correlation between the second row (index 1) of **arr1** and the second column (index 1) of **arr2**; both have 4 elements. The function lets us find a correlation between two one-dimensional arrays only.
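A sketch of the **mode** argument of **correlate()**, which follows NumPy's convention. With the default mode "valid" and equal-length inputs, the result is a single value equal to the dot product (4*11 + 5*13 + 6*15 + 7*17 = 318); mode "full" returns every overlap, with lag 0 in the middle.

```python
import jax.numpy as jnp

arr1 = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)
arr2 = jnp.arange(10, 18, dtype=jnp.float32).reshape(4, 2)

a, v = arr1[1], arr2[:, 1]                 # [4, 5, 6, 7] and [11, 13, 15, 17]

valid = jnp.correlate(a, v)                # default mode="valid": one value for equal lengths
full = jnp.correlate(a, v, mode="full")    # every overlap: len(a) + len(v) - 1 = 7 values

print(valid, jnp.dot(a, v))                # "valid" equals the dot product here
print(full)
```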

In [28]:

```
arr1[1], arr2[:, 1]
```

Out[28]:

In [29]:

```
jnp.correlate(arr1[1], arr2[:, 1])
```

Out[29]:

We can generate random numbers using **random** module of **jax**. In this section, we'll explain with a few simple examples how we can generate random numbers using **jax**.

All functions of the **jax.random** module require us to explicitly provide a key (seed) to generate random numbers. We create this key by passing an integer seed to **jax.random.PRNGKey()**.
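Because the key is explicit, reusing it gives the same numbers every time. A minimal sketch of the usual pattern, using **jax.random.split()** (not covered in the text) to derive a fresh subkey whenever new numbers are needed:

```python
import jax

key = jax.random.PRNGKey(123)

# The same key always gives the same numbers
a = jax.random.normal(key, shape=(3,))
b = jax.random.normal(key, shape=(3,))

# For new numbers, split the key and use each subkey only once
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, shape=(3,))

print(a, b, c)
```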

We can generate an array of random integers in a particular range using the **randint()** function. We need to provide the key as the first parameter, followed by the shape of the output array and the range as minimum (**minval**, inclusive) and maximum (**maxval**, exclusive) values.

Below we have generated an array of random integers of shape (2,3) in the range [1, 10), i.e., integers from 1 to 9.

In [190]:

```
jax.random.randint(key=jax.random.PRNGKey(123), shape=(2,3), minval=1, maxval=10)
```

Out[190]:

We can generate an array of samples from a uniform distribution in a particular range using the **uniform()** function. It works almost exactly like the **randint()** function: we provide the key followed by the shape of the array. By default, the minimum and maximum values of the range are 0 and 1, respectively.

Below we have generated two arrays of uniformly distributed values using the **uniform()** function. The first one generates floats in the range [0, 1) and the second one generates floats in the range [1, 10).
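A quick sketch that checks these ranges empirically on 1000 samples each; the sample size is arbitrary, and **jax.random.split()** is used so each call gets its own key.

```python
import jax

key = jax.random.PRNGKey(123)
k1, k2, k3 = jax.random.split(key, 3)

ints = jax.random.randint(k1, shape=(1000,), minval=1, maxval=10)          # integers in [1, 10)
floats = jax.random.uniform(k2, shape=(1000,), minval=1.0, maxval=10.0)    # floats in [1, 10)
unit = jax.random.uniform(k3, shape=(1000,))                                # default range [0, 1)

print(ints.min(), ints.max())
print(floats.min(), floats.max(), unit.min(), unit.max())
```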

In [194]:

```
jax.random.uniform(key=jax.random.PRNGKey(123), shape=(2,3))
```

Out[194]:

In [195]:

```
jax.random.uniform(key=jax.random.PRNGKey(123), shape=(2,3), minval=1, maxval=10)
```

Out[195]:

We can generate samples from the standard normal distribution (mean 0, standard deviation 1) using the **normal()** function. We need to give it the key followed by the shape of the expected output array.

In [197]:

```
jax.random.normal(key=jax.random.PRNGKey(123), shape=(2,3))
```

Out[197]:

In this section, we'll explain how we can create a simple Python function and pass JAX arrays to it. We'll perform simple operations on JAX arrays inside the function. This section builds the background for the next section, which covers automatic gradients of functions.

Below we have created a simple function that takes three parameters as input. It first squares the parameter **x**. Then it multiplies the parameter **m** with the squared value of **x** and adds the parameter **c** to it. Finally, it returns the sum of all values of the output array.

In [6]:

```
def func(x, m, c):
    x_square = (x * x)       ## element-wise square
    y = m * x_square + c     ## scale and shift
    return y.sum()           ## scalar output
```

Below we have created an array with different values and called our function designed above to check whether it’s working properly. We can notice from the output that it seems to do the job.

The first call passes an array of size **2x2** with all elements set to 1. The value of **m** is set to 10 and the value of **c** to 5. Squaring the array leaves all elements at 1. All elements are then multiplied by 10 (the value of **m**), so the new **2x2** array has all elements equal to 10. Then 5 (the value of **c**) is added to all elements, so the final **2x2** array has all elements set to 15. Finally, we sum up all elements of the array: 4 elements, each with a value of 15, give a total of 60. We can do the calculation the same way for the second call in the next cell, where we pass a **2x2** array with all elements set to 2: each element becomes 10 * 2² + 5 = 45, giving a total of 180.
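The hand calculation above can be sketched as a loop that compares **func** with the closed-form value 4 * (m * fill² + c) for both fill values. The name **fill** is only for illustration, and **m** and **c** are passed as plain Python floats here.

```python
import jax.numpy as jnp

def func(x, m, c):
    x_square = x * x
    y = m * x_square + c
    return y.sum()

m, c = 10.0, 5.0
for fill in (1.0, 2.0):
    arr = jnp.full((2, 2), fill, dtype=jnp.float32)
    # by hand: 4 elements, each equal to m * fill**2 + c
    expected = 4 * (m * fill ** 2 + c)
    print(fill, float(func(arr, m, c)), expected)   # 60.0 for fill=1, 180.0 for fill=2
```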

We'll use this function in the next section and find its gradient.

In [50]:

```
arr = jnp.ones((2,2), dtype=jnp.float32)
m = jnp.array(10, dtype=jnp.float32)
c = jnp.array(5, dtype=jnp.float32)
func(arr, m, c)
```

Out[50]:

In [51]:

```
arr = jnp.ones((2,2), dtype=jnp.float32) * 2
m = jnp.array(10, dtype=jnp.float32)
c = jnp.array(5, dtype=jnp.float32)
func(arr, m, c)
```

Out[51]:

**JAX** can help us find the derivative of a function that works on JAX arrays. There are many situations where we need to compute the derivative of a function and evaluate it. Writing code for the derivative of simple functions is easy, but it gets really complicated for more involved functions.

**JAX** can automatically calculate the derivative of functions that work on JAX arrays with just one function call. It also lets us compute derivatives of derivatives, and so on. For **grad()**, the main requirements are that our function works on JAX arrays and that its output is a scalar.

**JAX** provides a function named **grad()** which can be used to find the gradient of a function working with JAX arrays.

**grad(func, argnums=0, has_aux=False, allow_int=False)**- This function takes as input another function working on JAX arrays and returns a new function that computes its gradient, by default with respect to the first parameter. We can call this gradient function with the same parameters as our input function, and it'll return the value of the gradient of the input function with respect to the specified parameters.

- The **argnums** parameter accepts a single integer or a tuple of integers specifying the index (or indices) of the arguments with respect to which to differentiate. By default, it's set to **0**, which differentiates with respect to the first parameter. We can give any other index as well, or a tuple of indices if we want to differentiate with respect to more than one parameter; in that case, the gradient function returns a tuple of gradients.
- The **has_aux** parameter accepts a boolean value. It is used when our input function returns a pair, where the first element is the output of the function to be differentiated and the second element is auxiliary data. The differentiation is computed based on the first element only, and the gradient function then returns a pair of (gradient, auxiliary data). It's **False** by default, hence our function needs to return a single scalar value.
- The **allow_int** parameter accepts a boolean value. It can be set to **True** if we want to allow differentiation with respect to integer-valued inputs. By default, it's **False**, and differentiation is allowed with respect to floating-point inputs only.
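A minimal sketch of **has_aux**, assuming a hypothetical variant of our function, **func_with_aux**, that also returns the intermediate array **y**. Only the first returned value is differentiated; the second is passed through untouched.

```python
import jax
import jax.numpy as jnp

# Hypothetical variant of func that also returns the intermediate array y as auxiliary data
def func_with_aux(x, m, c):
    y = m * x * x + c
    return y.sum(), y   # (scalar to differentiate, auxiliary data)

x = jnp.ones((2, 2), dtype=jnp.float32)
m = jnp.array(10.0, dtype=jnp.float32)
c = jnp.array(5.0, dtype=jnp.float32)

# With has_aux=True, the gradient function returns (gradient, auxiliary data)
grad_x, aux = jax.grad(func_with_aux, has_aux=True)(x, m, c)
print(grad_x)   # 2 * m * x -> all elements 20
print(aux)      # y -> all elements 15
```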

The **grad()** function can also differentiate with respect to a parameter that is a Python list, dictionary, or tuple (or a nested combination of them), as long as all of its elements are JAX arrays. It'll differentiate the input function with respect to all the JAX arrays inside the list, dict, or tuple, and return gradients with the same structure.
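As a sketch of this, assume a hypothetical variant **func_params** that receives **m** and **c** packed in a dictionary. The gradient comes back as a dictionary with the same keys: d/dm = sum(x²) = 4 and d/dc = number of elements = 4 for a 2x2 array of ones.

```python
import jax
import jax.numpy as jnp

# Hypothetical variant of func that takes m and c packed in a dictionary
def func_params(params, x):
    y = params["m"] * x * x + params["c"]
    return y.sum()

params = {"m": jnp.array(10.0), "c": jnp.array(5.0)}
x = jnp.ones((2, 2), dtype=jnp.float32)

grads = jax.grad(func_params)(params, x)   # a dict with the same keys as params
print(grads)
```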

Below we have calculated the gradient of our function from the previous section. By default, the **grad()** function finds the gradient with respect to the first parameter, which is **x** in our case. We have then evaluated the gradient by giving it the input data (**arr, m, and c**).

As our earlier function is easy to understand, we can calculate its gradient by hand as well. The gradient with respect to **x** is **2*m*x** (element-wise), since the derivative of the sum of **m*x² + c** with respect to each element of **x** is **2*m*x**. We have given as input **x** (an array of size 2x2 with all elements set to 1), **m** (a scalar value of 10), and **c** (a scalar value of 5). So, to find the gradient, all elements of **x** are multiplied by 2 and then by 10 (the value of **m**), giving a 2x2 array with all elements equal to 20.
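To check all three analytic gradients at once, here is a sketch with a non-uniform **x** (chosen arbitrarily) so the values are distinguishable: d/dx = 2*m*x, d/dm = sum(x²) = 1 + 4 + 9 + 16 = 30, and d/dc = number of elements = 4.

```python
import jax
import jax.numpy as jnp

def func(x, m, c):
    x_square = x * x
    y = m * x_square + c
    return y.sum()

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
m = jnp.array(10.0)
c = jnp.array(5.0)

gx, gm, gc = jax.grad(func, argnums=(0, 1, 2))(x, m, c)
# Analytic gradients: d/dx = 2*m*x, d/dm = sum(x**2), d/dc = number of elements
print(jnp.allclose(gx, 2 * m * x), gm, gc)
```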

In [ ]:

```
def func(x, m, c):
    x_square = (x * x)
    y = m * x_square + c
    return y.sum()
```

In [52]:

```
arr = jnp.ones((2,2), dtype=jnp.float32)
m = jnp.array(10, dtype=jnp.float32)
c = jnp.array(5, dtype=jnp.float32)
jax.grad(func)(arr, m, c)
```

Out[52]:

Below we have explained another example, where we find the gradient of our function with respect to the second parameter (**m**) by setting the **argnums** parameter of **grad()** to **(1,)**. Because **argnums** is a tuple, the result is a tuple containing one gradient.

In [54]:

```
arr = jnp.ones((2,2), dtype=jnp.float32)
jax.grad(func, argnums=(1,))(arr, m, c)
```

Out[54]:

Below we have explained another example, where we are finding out the gradient of our input function with respect to all three parameters.

In [58]:

```
arr = jnp.ones((2,2), dtype=jnp.float32)
grad1, grad2, grad3 = jax.grad(func, argnums=(0, 1, 2))(arr, m, c)
grad1, grad2, grad3
```

Out[58]:

There are situations where we need our user-defined functions to work on a batch of elements in parallel and combine the results. **JAX** provides a function named **vmap()** for this purpose. It's short for vectorized map and, as its name suggests, works like Python's **map()** function, but it takes JAX arrays as input and lets us specify which axis of each array to vectorize over.

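**vmap(func, in_axes=0, out_axes=0)**- This function takes another function as input and returns a vectorized version of it. Its usage will become easy to understand with the examples below.

- The **in_axes** parameter accepts a single integer or a tuple of integers (one entry per argument) specifying the axis of each input along which to vectorize. An entry of **None** means that the argument is not vectorized and the same value is passed to every call.
- The **out_axes** parameter specifies the axis of the output along which the per-call results are stacked.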
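A sketch of what these parameters mean: mapping **func** over axis 0 gives the same result as an explicit Python loop over the slices, and **out_axes=1** stacks the per-call results along axis 1 instead of axis 0. The 3x2x2 input and the row-sum lambda are arbitrary examples.

```python
import jax
import jax.numpy as jnp

def func(x, m, c):
    x_square = x * x
    y = m * x_square + c
    return y.sum()

arr = jnp.arange(12, dtype=jnp.float32).reshape(3, 2, 2)
m = jnp.array(10.0)
c = jnp.array(5.0)

# in_axes=(0, None, None): map over axis 0 of x, pass m and c unchanged to every call
batched = jax.vmap(func, in_axes=(0, None, None))(arr, m, c)
looped = jnp.stack([func(arr[i], m, c) for i in range(arr.shape[0])])
print(jnp.allclose(batched, looped))

# out_axes=1: each call returns shape (2,), and the 3 results are stacked along axis 1
row_sums = jax.vmap(lambda x: x.sum(axis=1), in_axes=0, out_axes=1)(arr)
print(row_sums.shape)   # (2, 3)
```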

Below we have first created an array of size **3x2x2** whose elements will be used as input for parameter **x** of our function. We'll vectorize this array along axis 0, hence each of the three **2x2** sub-arrays will be passed in turn as the value of parameter **x**.

In [4]:

```
arr = jnp.array([1,1,1,1,2,2,2,2,3,3,3,3])
arr = arr.reshape(3,2,2)
arr
```

Out[4]:

Below we have vectorized our function which we had defined in the working with functions section and used in the automatic gradients section. We have vectorized input for parameter **x** at the 0th axis of the input array.

We have then called vectorized function with our **3x2x2** array for parameter **x** and scalar values of parameter **m** and **c**. The vectorized function will work on 3 **2x2** arrays for parameter **x**. The value of parameters **m** and **c** are not vectorized and will be the same for all **2x2** arrays. The output will be an array of size **3** as we have 3 **2x2** arrays as input.

In [ ]:

```
def func(x, m, c):
    x_square = (x * x)
    y = m * x_square + c
    return y.sum()
```

In [10]:

```
m = jnp.array(10, dtype=jnp.float32)
c = jnp.array(5, dtype=jnp.float32)
vmapped_func = jax.vmap(func, in_axes=(0,None,None))
vmapped_func(arr,m,c)
```

Out[10]:

Below we have printed the slices that parameter **x** would receive if we vectorize the input along axis 1 instead of axis 0. Vectorizing our **3x2x2** input array along axis 1 gives 2 slices, each of size **3x2**.

In [13]:

```
arr[:,0,:], arr[:,1,:]
```

Out[13]:

In the cell below, we have vectorized the input function along axis 1 of the first parameter, which is **x** in our case. We can notice that the output array has size **2**, because vectorizing the array along axis 1 gives 2 **3x2** arrays for parameter **x**.
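A sketch confirming this with the same 3x2x2 array (cast to float here): mapping over axis 1 matches a loop over **arr[:, j, :]**. Each slice is [[1, 1], [2, 2], [3, 3]], so the result per slice is 2 * (15 + 45 + 95) = 310.

```python
import jax
import jax.numpy as jnp

def func(x, m, c):
    x_square = x * x
    y = m * x_square + c
    return y.sum()

arr = jnp.array([1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3], dtype=jnp.float32).reshape(3, 2, 2)
m = jnp.array(10.0)
c = jnp.array(5.0)

# Mapping over axis 1 passes arr[:, 0, :] and arr[:, 1, :] (each 3x2) as x
vmapped = jax.vmap(func, in_axes=(1, None, None))(arr, m, c)
looped = jnp.stack([func(arr[:, j, :], m, c) for j in range(arr.shape[1])])
print(vmapped, looped)   # both [310., 310.]
```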

In [11]:

```
m = jnp.array(10, dtype=jnp.float32)
c = jnp.array(5, dtype=jnp.float32)
vmapped_func = jax.vmap(func, in_axes=(1,None,None))
vmapped_func(arr,m,c)
```

Out[11]:

Below we have explained another example, where we vectorize the gradient of a function.

In [105]:

```
m = jnp.array(10, dtype=jnp.float32)
c = jnp.array(5, dtype=jnp.float32)
jax.vmap(jax.grad(func), in_axes=(0,None,None))(arr,m,c)
```

Out[105]:

Below we have created another example where we have vectorized the function over the second parameter **m**. We have given a single array of size **2x2** for parameter **x** and a single scalar value for parameter **c**. For parameter **m**, we have declared an array of 3 elements.

In [106]:

```
arr = jnp.ones((2,2), dtype=jnp.float32)
m = jnp.array([10.,15.,20.,])
c = jnp.array(5, dtype=jnp.float32)
jax.vmap(func, in_axes=(None,0,None))(arr,m,c)
```

Out[106]:

In the next cell, we have created another example where we are vectorizing function for two input parameters (**m** and **c**). The value of parameter **x** is a fixed **2x2** array.

In [107]:

```
arr = jnp.ones((2,2), dtype=jnp.float32)
m = jnp.array([10.,15.,20.,])
c = jnp.array([5.,10.,15.,])
jax.vmap(func, in_axes=(None,0,0))(arr,m,c)
```

Out[107]:

The example below shows how we can vectorize our function over all three input parameters. It works on **2x2** arrays for parameter **x** and scalar values for parameters **m** and **c**.

In [108]:

```
arr = jnp.array([1.,1,1,1,2,2,2,2,3,3,3,3])
arr = arr.reshape(3,2,2)
m = jnp.array([10.,15.,20.,])
c = jnp.array([5.,10.,15.,])
jax.vmap(func, in_axes=(0,0,0))(arr,m,c)
```

Out[108]:

Below we have vectorized the gradient of our input function for all three input parameters.
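To check the last two results by hand, here is a sketch with the same inputs: batch i gives 4 * (m[i] * x_i² + c[i]), i.e. 60, 280 and 780, and its gradient with respect to **x** is 2 * m[i] * x_i element-wise.

```python
import jax
import jax.numpy as jnp

def func(x, m, c):
    x_square = x * x
    y = m * x_square + c
    return y.sum()

arr = jnp.array([1., 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]).reshape(3, 2, 2)
m = jnp.array([10., 15., 20.])
c = jnp.array([5., 10., 15.])

values = jax.vmap(func, in_axes=(0, 0, 0))(arr, m, c)
grads = jax.vmap(jax.grad(func), in_axes=(0, 0, 0))(arr, m, c)

# By hand: 4 * (10*1 + 5), 4 * (15*4 + 10), 4 * (20*9 + 15)
expected_values = jnp.array([60.0, 280.0, 780.0])
# Gradient of batch i with respect to x: 2 * m[i] * x_i
expected_grads = 2 * m[:, None, None] * arr
print(values, grads.shape)
```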

In [109]:

```
jax.vmap(jax.grad(func), in_axes=(0,0,0))(arr,m,c)
```

Out[109]:

In this section, we'll explain how we can speed up operations performed with JAX. As explained earlier, JAX can use the **XLA** (Accelerated Linear Algebra) compiler to fuse multiple operations together, which can reduce their running time. XLA can speed up many linear algebra operations on both CPUs and GPUs.

To use XLA to speed up operations, JAX provides a function named **jit()**. We can either wrap an existing function inside **jit()** or decorate the function with the **@jit** decorator; either way, the function is compiled with XLA the first time it is called.

Below we have taken the function we have been using in the previous sections and wrapped it inside **jit()** to speed it up. We have then called the normal function and the jit-wrapped function a few times to compare their performance, recording the time taken by each call.

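The speedup provided by the jit-wrapped function becomes more visible as the size of the JAX arrays increases. Note that the first call to a jit-wrapped function includes the time for tracing and compiling it; later calls with arrays of the same shapes and dtypes reuse the compiled code. Also, JAX dispatches operations asynchronously, so calling **block_until_ready()** on the result makes the timing include the actual computation.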
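A sketch of the decorator form combined with a timing pattern based on **time.perf_counter()** and **block_until_ready()**. The 1000x1000 size is an arbitrary choice and the measured times depend on the machine, so no particular numbers are implied.

```python
import time

import jax.numpy as jnp
from jax import jit

@jit
def fast_func(x, m, c):
    x_square = x * x
    y = m * x_square + c
    return y.sum()

arr = jnp.ones((1000, 1000), dtype=jnp.float32)
m = jnp.array(10.0, dtype=jnp.float32)
c = jnp.array(5.0, dtype=jnp.float32)

start = time.perf_counter()
out_first = fast_func(arr, m, c).block_until_ready()    # includes tracing + XLA compilation
first_call = time.perf_counter() - start

start = time.perf_counter()
out_second = fast_func(arr, m, c).block_until_ready()   # reuses the compiled code
second_call = time.perf_counter() - start

print(out_first, f"first: {first_call:.4f}s, second: {second_call:.4f}s")
```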

In [ ]:

```
def func(x, m, c):
    x_square = (x * x)
    y = m * x_square + c
    return y.sum()
```

In [184]:

```
from jax import jit
arr = jnp.ones((2,2), dtype=jnp.float32)
m = jnp.array(10, dtype=jnp.float32)
c = jnp.array(5, dtype=jnp.float32)
jitted_func = jax.jit(func)
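## block_until_ready() waits for the asynchronous computation so %time measures it
%time out1 = func(arr,m,c).block_until_ready()
%time out2 = jitted_func(arr,m,c).block_until_ready()  ## first call includes compilation
%time out3 = jitted_func(arr,m,c).block_until_ready()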
```

In [185]:

```
arr = jnp.ones((5,5), dtype=jnp.float32)
m = jnp.array(10, dtype=jnp.float32)
c = jnp.array(5, dtype=jnp.float32)
jitted_func = jax.jit(func)
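%time out1 = func(arr,m,c).block_until_ready()
%time out2 = jitted_func(arr,m,c).block_until_ready()  ## new input shape, so this call compiles again
%time out3 = jitted_func(arr,m,c).block_until_ready()
%time out4 = jitted_func(arr*2,m,c).block_until_ready()  ## same shape, reuses compiled code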
```

In this section, we have explained how we can combine the **vmap()** and **jit()** functions to speed up the computation further.

Below we have first taken our existing function, which we have been using for the last few sections, and wrapped it using the **vmap()** function. We have then wrapped the vmap-wrapped function inside the **jit()** function to increase speed further. The **vmap()** function vectorizes our function so that it processes the whole batch with array operations instead of one element at a time. The **jit()** function then compiles the result with the **XLA** compiler.

After wrapping our function inside both **vmap()** and **jit()**, we have called it with different array sizes. We have also called the vmap-wrapped function and recorded its time, so we can compare the time taken by the vmap-wrapped function with that of the vmap & jit-wrapped function.
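Before comparing timings, it is worth confirming that compilation does not change the results. A minimal sketch with a batch of 100 (an arbitrary size), where every entry should be 4 * (10 + 5) = 60:

```python
import jax
import jax.numpy as jnp

def func(x, m, c):
    x_square = x * x
    y = m * x_square + c
    return y.sum()

vmapped_func = jax.vmap(func, in_axes=(0, 0, 0))
jitted_func = jax.jit(vmapped_func)

arr = jnp.ones((100, 2, 2), dtype=jnp.float32)
m = jnp.full((100,), 10.0, dtype=jnp.float32)
c = jnp.full((100,), 5.0, dtype=jnp.float32)

out_vmap = vmapped_func(arr, m, c)
out_jit = jitted_func(arr, m, c).block_until_ready()
print(out_jit.shape, jnp.allclose(out_vmap, out_jit))
```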

In [ ]:

```
def func(x, m, c):
    x_square = (x * x)
    y = m * x_square + c
    return y.sum()
```

In [147]:

```
vmapped_func = jax.vmap(func, in_axes=(0,0,0))
jitted_func = jax.jit(vmapped_func)
```

In [148]:

```
arr = jnp.ones((3,2,2), dtype=jnp.float32)
m = jnp.ones(shape=(3,), dtype=jnp.float32) * 10
c = jnp.ones(shape=(3,), dtype=jnp.float32) * 5
%time out1 = vmapped_func(arr,m,c).block_until_ready()
%time out2 = jitted_func(arr,m,c).block_until_ready()  ## first call includes compilation
%time out3 = jitted_func(arr,m,c).block_until_ready()
```

In [150]:

```
arr = jnp.ones((100,2,2), dtype=jnp.float32)
m = jnp.ones(shape=(100,), dtype=jnp.float32) * 10
c = jnp.ones(shape=(100,), dtype=jnp.float32) * 5
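%time out1 = vmapped_func(arr,m,c).block_until_ready()
%time out2 = jitted_func(arr,m,c).block_until_ready()  ## new input shape, so this call compiles again
%time out3 = jitted_func(arr,m,c).block_until_ready()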
```
