JAX is a deep learning research framework from Google, written in Python. It provides a NumPy-like API that runs on CPU/GPU/TPU, automatic gradients, just-in-time compilation, and more, and it is commonly used in many Google deep learning research projects. Like TensorFlow, JAX is a fairly low-level library and can require many lines of code to design a network. To address this, Google researchers designed another library named Flax on top of JAX to simplify the design of neural networks. Flax lets us create neural networks much like we do with PyTorch.
Initially, Flax had its own sub-module of optimizers (SGD, Adam, etc.). But another team at Google designed a package named Optax which implements the majority of optimizers commonly used in deep learning today. Because of this, the Flax team has deprecated its optimizer sub-module and recommends using the optimizers available from Optax instead.
Generally, when we train a neural network using an optimizer like Stochastic Gradient Descent, the learning rate stays the same throughout the training process: it is identical for all batches and epochs. This can sometimes lead to stagnant results after a few epochs, and reducing the learning rate by a small amount can boost results a little further.
This process of reducing the learning rate during training is generally referred to as learning rate scheduling or learning rate annealing. There are various ways to decrease and manage the learning rate during training, and Optax provides different types of learning rate schedules that we'll discuss as part of this tutorial. We have used the Fashion MNIST dataset and trained a simple CNN on it to explain the various schedules.
The tutorial assumes that the reader has some background in Flax and JAX. Please feel free to go through the tutorials below if you want to refresh those concepts.
Below, we have listed the important sections of the tutorial to give an overview of the material covered.
Below, we have imported the necessary Python libraries and printed the versions we have used in this tutorial.
import jax
print("JAX Version : {}".format(jax.__version__))
import flax
print("FLAX Version : {}".format(flax.__version__))
import optax
print("OPTAX Version : {}".format(optax.__version__))
In this section, we have loaded the Fashion MNIST dataset available from keras. It contains grayscale images of shape (28,28) pixels covering 10 different fashion item categories. The data is already divided into a train set (60k images) and a test set (10k images). After loading the dataset, we have added an extra channel dimension at the end of the images, as required by CNNs. We have also divided the images by 255 to bring pixel values into the range [0,1]. The table below shows the mapping from label index to item type.
Label | Description |
---|---|
0 | T-shirt/top |
1 | Trouser |
2 | Pullover |
3 | Dress |
4 | Coat |
5 | Sandal |
6 | Shirt |
7 | Sneaker |
8 | Bag |
9 | Ankle boot |
from tensorflow import keras
from sklearn.model_selection import train_test_split
from jax import numpy as jnp
(X_train, Y_train), (X_test, Y_test) = keras.datasets.fashion_mnist.load_data()
X_train, X_test, Y_train, Y_test = jnp.array(X_train, dtype=jnp.float32),\
jnp.array(X_test, dtype=jnp.float32),\
jnp.array(Y_train, dtype=jnp.float32),\
jnp.array(Y_test, dtype=jnp.float32)
X_train, X_test = X_train.reshape(-1,28,28,1), X_test.reshape(-1,28,28,1)
X_train, X_test = X_train/255.0, X_test/255.0
classes = jnp.unique(Y_train)
X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
In this section, we have defined a simple convolutional neural network that we'll use to classify images while explaining the usage of various learning rate schedules.
The network has two convolution layers with 32 and 16 filters respectively, and one dense layer with as many units as there are target classes. Both convolution layers apply kernels of shape (3,3) to their input. Our forward pass applies the two convolution layers to the input images one after the other, with a relu (rectified linear unit) activation after each. The output of the second convolution layer is flattened after relu and fed to the dense layer, whose output is our prediction (logits).
In the next cell, we have printed the shapes of the weights and biases of the network's layers. We have also initialized the network weights and performed a forward pass through the network to verify that everything works.
from flax import linen
from jax import random
class CNN(linen.Module):
    def setup(self):
        self.conv1 = linen.Conv(features=32, kernel_size=(3,3), padding="SAME", name="CONV1")
        self.conv2 = linen.Conv(features=16, kernel_size=(3,3), padding="SAME", name="CONV2")
        self.linear1 = linen.Dense(len(classes), name="DENSE")

    def __call__(self, inputs):
        x = linen.relu(self.conv1(inputs))
        x = linen.relu(self.conv2(x))
        x = x.reshape((x.shape[0], -1))
        logits = self.linear1(x)
        return logits #linen.softmax(x)
seed = jax.random.PRNGKey(0)
model = CNN()
params = model.init(seed, X_train[:5])
for layer_params in params["params"].items():
    print("Layer Name : {}".format(layer_params[0]))
    weights, biases = layer_params[1]["kernel"], layer_params[1]["bias"]
    print("\tLayer Weights : {}, Biases : {}".format(weights.shape, biases.shape))
preds = model.apply(params, X_train[:5])
preds.shape
In this section, we have defined the loss function that we'll use for our multi-class classification task: cross entropy loss. The function takes the network parameters, input data, and actual labels as input. It first performs a forward pass through the network using the parameters to make predictions, then one-hot encodes the actual labels, and finally calculates the loss using the softmax_cross_entropy() function available from the optax library.
def CrossEntropyLoss(weights, input_data, actual):
    logits = model.apply(weights, input_data)
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(classes))
    return optax.softmax_cross_entropy(logits, one_hot_actual).sum()
In this section, we have defined a function that we'll use to train our network. The function takes data features (X), target values (Y), a validation dataset (X_val, Y_val), the number of epochs, model parameters, optimizer state, and batch size as input. It then executes a training loop for the given number of epochs. In each epoch, it performs forward passes through the network in batches of data, calculates the loss and gradients, and updates the weights. After each epoch, it also calculates validation accuracy. Once training is complete, it returns the model parameters and the last validation accuracy.
from jax import value_and_grad
from tqdm import tqdm
from sklearn.metrics import accuracy_score
def TrainModelInBatches(X, Y, X_val, Y_val, epochs, weights, optimizer_state, batch_size=32):
    for i in range(1, epochs+1):
        batches = jnp.arange((X.shape[0]//batch_size)+1) ### Batch Indices
        losses = [] ## Record loss of each batch
        for batch in tqdm(batches):
            if batch != batches[-1]:
                start, end = int(batch*batch_size), int(batch*batch_size+batch_size)
            else:
                start, end = int(batch*batch_size), None
            X_batch, Y_batch = X[start:end], Y[start:end] ## Single batch of data
            loss, gradients = value_and_grad(CrossEntropyLoss)(weights, X_batch, Y_batch)
            ## Update Weights
            updates, optimizer_state = optimizer.update(gradients, optimizer_state)
            weights = optax.apply_updates(weights, updates)
            losses.append(loss) ## Record Loss
        print("CrossEntropyLoss : {:.3f}".format(jnp.array(losses).mean()))
        Y_val_preds = model.apply(weights, X_val)
        val_acc = accuracy_score(Y_val, jnp.argmax(Y_val_preds, axis=1))
        print("Validation Accuracy : {:.3f}".format(val_acc))
    return weights, val_acc
In this section, we are training our neural network using a constant learning rate. We have set the number of epochs to 10, the batch size to 256, and the learning rate to 0.0001. We have then initialized the model and its parameters, and initialized the SGD optimizer with the constant learning rate. At last, we have called our training routine to train the network. The function returns the final updated parameters and the validation accuracy.
We are maintaining a dictionary where we'll be storing the validation accuracy of all schedulers for comparison purposes later.
Later on, we have also plotted a chart showing the constant learning rate throughout the training process. We'll be plotting a chart like this for all our schedules to show how the learning rate is changing through the training process.
scheduler_val_accs = {}
seed = random.PRNGKey(0)
batch_size=256
epochs=10
learning_rate = jnp.array(1/1e4)
model = CNN()
weights = model.init(seed, X_train[:5])
optimizer = optax.sgd(learning_rate=learning_rate) ## Initialize SGD Optimizer
#constant_scheduler = optax.constant_schedule(0.0001)
#optimizer = optax.sgd(learning_rate=constant_scheduler)
optimizer_state = optimizer.init(weights)
final_weights, val_acc = TrainModelInBatches(X_train, Y_train, X_test, Y_test, epochs, weights, optimizer_state, batch_size=batch_size)
scheduler_val_accs["Constant Learning Rate"] = val_acc
import matplotlib.pyplot as plt
constant_scheduler = optax.constant_schedule(0.0001)
lrs = [constant_scheduler(i) for i in range(100)]
plt.scatter(range(100), lrs)
plt.title("Constant Learning Rate")
plt.ylabel("Learning Rate")
plt.xlabel("Epochs/Steps");
In this section, we are training our network with a cosine decay scheduler. We can create a cosine decay scheduler using the cosine_decay_schedule() function of optax. It accepts the initial learning rate and the total number of decay steps, and then decays the learning rate following a cosine curve over that many steps. Here, one step is the processing of one batch of data. We can compute the number of batches (steps) per epoch by dividing the number of train samples by the batch size, and the total number of steps by multiplying that by the number of epochs (e.g., with 60,000 training images and a batch size of 256, that is 60000//256 = 234 full batches plus one partial batch per epoch, roughly 2,350 steps over 10 epochs). The scheduler will steadily reduce the learning rate over all these training steps.
A scheduler is a function that takes a step number as input and returns the learning rate to use for that step. Below is the formula used by the cosine decay scheduler to compute the learning rate for any step of training. We have tried to include the internal logic of the majority of schedulers for anyone who wants to know how they work internally; this logic is taken from the optax documentation.
def scheduler(step_number):
    count = minimum(step_number, decay_steps)
    cosine_decay = 0.5 * (1 + cos(pi * count / decay_steps))
    decayed = (1 - alpha) * cosine_decay + alpha
    return init_learning_rate * decayed
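To make the formula concrete, here is a small optional sanity check (our own example, not part of the original tutorial; it assumes jnp and optax are imported as above and uses arbitrary hypothetical hyperparameter values): the manual formula and optax.cosine_decay_schedule() should return (roughly) the same learning rate for a given step.
## Hypothetical values chosen just for this check
init_learning_rate, decay_steps, alpha = 0.0001, 100, 0.95

def manual_cosine_decay(step_number):
    count = jnp.minimum(step_number, decay_steps)
    cosine_decay = 0.5 * (1 + jnp.cos(jnp.pi * count / decay_steps))
    decayed = (1 - alpha) * cosine_decay + alpha
    return init_learning_rate * decayed

check_scheduler = optax.cosine_decay_schedule(init_learning_rate, decay_steps=decay_steps, alpha=alpha)
print(manual_cosine_decay(50), check_scheduler(50)) ## Both values should (roughly) match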
In order to use a scheduler, we need to initialize it and provide it to the SGD optimizer, which will then reduce the learning rate using the above formula. All schedulers are used with SGD in the same way.
After training the network with the scheduler, we have stored validation accuracy in our dictionary.
In the next cell, we have also plotted a chart to show how the learning rate will change over time during training using a cosine decay scheduler. This can help us better understand how it works.
seed = random.PRNGKey(0)
batch_size=256
epochs=10
learning_rate = jnp.array(1/1e4)
model = CNN()
weights = model.init(seed, X_train[:5])
total_steps = epochs*(X_train.shape[0]//batch_size) + epochs ## Total Batches
cosine_decay_scheduler = optax.cosine_decay_schedule(0.0001, decay_steps=total_steps, alpha=0.95)
optimizer = optax.sgd(learning_rate=cosine_decay_scheduler) ## Initialize SGD Optimizer
optimizer_state = optimizer.init(weights)
final_weights, val_acc = TrainModelInBatches(X_train, Y_train, X_test, Y_test, epochs, weights, optimizer_state, batch_size=batch_size)
scheduler_val_accs["Cosine Decay Scheduler"] = val_acc
import matplotlib.pyplot as plt
total_steps = epochs*(X_train.shape[0]//batch_size) + epochs
cosine_decay_scheduler = optax.cosine_decay_schedule(0.0001, decay_steps=total_steps, alpha=0.95)
lrs = [cosine_decay_scheduler(i) for i in range(total_steps)]
plt.scatter(range(total_steps), lrs)
plt.title("Cosine Decay Scheduler")
plt.ylabel("Learning Rate")
plt.xlabel("Epochs/Steps");
In this section, we are training our network using a cosine one-cycle scheduler. We can create a cosine one-cycle scheduler using the cosine_onecycle_schedule() function available from optax. It accepts parameters such as transition_steps, peak_value, pct_start, div_factor, and final_div_factor; some of them have default values and are optional.
Below is the scheduler logic used to compute the learning rate for any step number during training.
def _cosine_interpolate(start: float, end: float, pct: float):
    return end + (start-end) / 2.0 * (jnp.cos(jnp.pi * pct) + 1)

def scheduler(step_number):
    init_learning_rate = peak_value / div_factor
    boundaries_and_scales = {int(pct_start * transition_steps): div_factor,
                             int(transition_steps): 1. / (div_factor * final_div_factor)}
    boundaries, scales = zip(*sorted(boundaries_and_scales.items()))
    bounds = np.stack((0,) + boundaries)
    values = np.cumprod(jnp.stack((init_learning_rate,) + scales))
    interval_sizes = (bounds[1:] - bounds[:-1])
    indicator = (bounds[:-1] <= step_number) & (step_number < bounds[1:])
    pct = (step_number - bounds[:-1]) / interval_sizes
    interp_vals = _cosine_interpolate(values[:-1], values[1:], pct)
    return indicator.dot(interp_vals) + (bounds[-1] <= step_number) * values[-1]
In our case, we have set the peak value to 0.0001, the percentage start to 0.20 (20%), the division factor to 30, and the final division factor to 100.
seed = random.PRNGKey(0)
batch_size=256
epochs=10
learning_rate = jnp.array(1/1e4)
model = CNN()
weights = model.init(seed, X_train[:5])
total_steps = epochs*(X_train.shape[0]//batch_size) + epochs
cosine_onecycle_scheduler = optax.cosine_onecycle_schedule(transition_steps=total_steps, peak_value=0.0001,
pct_start=0.20, div_factor=30., final_div_factor=100.)
optimizer = optax.sgd(learning_rate=cosine_onecycle_scheduler) ## Initialize SGD Optimizer
optimizer_state = optimizer.init(weights)
final_weights, val_acc = TrainModelInBatches(X_train, Y_train, X_test, Y_test, epochs, weights, optimizer_state, batch_size=batch_size)
scheduler_val_accs["Cosine One Cycle Scheduler"] = val_acc
import matplotlib.pyplot as plt
total_steps = epochs*(X_train.shape[0]//batch_size) + epochs
cosine_onecycle_scheduler = optax.cosine_onecycle_schedule(transition_steps=total_steps, peak_value=0.0001,
pct_start=0.20, div_factor=30., final_div_factor=100.)
lrs = [cosine_onecycle_scheduler(i) for i in range(total_steps)]
plt.scatter(range(total_steps), lrs)
plt.title("Cosine One Cycle Scheduler")
plt.ylabel("Learning Rate")
plt.xlabel("Epochs/Steps");
In this section, we are training our network using an exponential decay scheduler to reduce the learning rate over time. We can create an exponential decay scheduler using the exponential_decay() function of the optax library. It accepts parameters such as init_value, transition_steps, decay_rate, transition_begin, and staircase; some of them have default values and hence are optional. Below is the formula it uses to compute the learning rate at a given step.
def scheduler(step_number):
    count = maximum(step_number - transition_begin, 0) ## Learning rate stays at init before transition_begin
    decayed_value = init_learning_rate * decay_rate ** (count / transition_steps)
    return decayed_value
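As a small optional check (our own example with made-up values, not from the original tutorial), the formula above can be compared against optax.exponential_decay() directly:
## Hypothetical values: decay over 1000 steps, beginning after step 250
exp_check = optax.exponential_decay(init_value=0.0001, transition_steps=1000,
                                    decay_rate=0.98, transition_begin=250)
print(exp_check(0), exp_check(250), exp_check(1250))
## ~0.0001, ~0.0001, ~0.0001*0.98 (one full transition period after the decay begins)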
Below, we have trained our network for 10 epochs using SGD with exponential decay. We start the transition at 25% of the steps, hence the learning rate stays constant for the first 25% of the steps.
seed = random.PRNGKey(0)
batch_size=256
epochs=10
learning_rate = jnp.array(1/1e4)
model = CNN()
weights = model.init(seed, X_train[:5])
total_steps = epochs*(X_train.shape[0]//batch_size) + epochs
exponential_decay_scheduler = optax.exponential_decay(init_value=0.0001, transition_steps=total_steps,
decay_rate=0.98, transition_begin=int(total_steps*0.25),
staircase=False)
optimizer = optax.sgd(learning_rate=exponential_decay_scheduler) ## Initialize SGD Optimizer
optimizer_state = optimizer.init(weights)
final_weights, val_acc = TrainModelInBatches(X_train, Y_train, X_test, Y_test, epochs, weights, optimizer_state, batch_size=batch_size)
scheduler_val_accs["Exponential Decay"] = val_acc
import matplotlib.pyplot as plt
total_steps = epochs*(X_train.shape[0]//batch_size) + epochs
exponential_decay_scheduler = optax.exponential_decay(init_value=0.0001, transition_steps=total_steps,
decay_rate=0.98, transition_begin=int(total_steps*0.25),
staircase=False)
lrs = [exponential_decay_scheduler(i) for i in range(total_steps)]
plt.scatter(range(total_steps), lrs)
plt.title("Exponential Decay Scheduler")
plt.ylabel("Learning Rate")
plt.xlabel("Epochs/Steps");
In this section, we are training our network with a linear one-cycle scheduler with SGD. We can create a linear one-cycle scheduler using the linear_onecycle_schedule() function of the optax library. It accepts parameters such as transition_steps, peak_value, pct_start, pct_final, div_factor, and final_div_factor; some of them are optional and have default values.
Below, we have included the code that optax uses internally to determine the learning rate at each step.
def _linear_interpolate(start: float, end: float, pct: float):
    return (end-start) * pct + start

def scheduler(step_number):
    init_learning_rate = peak_value / div_factor
    boundaries_and_scales = {int(pct_start * transition_steps): div_factor,
                             int(pct_final * transition_steps): 1. / div_factor,
                             int(transition_steps): 1. / final_div_factor}
    boundaries, scales = zip(*sorted(boundaries_and_scales.items()))
    bounds = np.stack((0,) + boundaries)
    values = np.cumprod(jnp.stack((init_learning_rate,) + scales))
    interval_sizes = (bounds[1:] - bounds[:-1])
    indicator = (bounds[:-1] <= step_number) & (step_number < bounds[1:])
    pct = (step_number - bounds[:-1]) / interval_sizes
    interp_vals = _linear_interpolate(values[:-1], values[1:], pct)
    return indicator.dot(interp_vals) + (bounds[-1] <= step_number) * values[-1]
In our case, we have set the peak value to 0.0001, the percent start to 25%, the percent final to 70%, the division factor for the initial learning rate to 10 (initial lr = 0.00001), and the final division factor to 100 (final lr = initial lr / 100 = 0.0000001).
In the cell after the training cell, we have also plotted a chart showing how the learning rate changes over time during training.
seed = random.PRNGKey(0)
batch_size=256
epochs=10
learning_rate = jnp.array(1/1e4)
model = CNN()
weights = model.init(seed, X_train[:5])
total_steps = epochs*(X_train.shape[0]//batch_size) + epochs
linear_onecycle_decay_scheduler = optax.linear_onecycle_schedule(transition_steps=total_steps, peak_value=0.0001,
pct_start=0.25, pct_final=0.7, div_factor=10.,
final_div_factor=100.)
optimizer = optax.sgd(learning_rate=linear_onecycle_decay_scheduler) ## Initialize SGD Optimizer
optimizer_state = optimizer.init(weights)
final_weights, val_acc = TrainModelInBatches(X_train, Y_train, X_test, Y_test, epochs, weights, optimizer_state, batch_size=batch_size)
scheduler_val_accs["Linear One Cycle Scheduler"] = val_acc
import matplotlib.pyplot as plt
total_steps = epochs*(X_train.shape[0]//batch_size) + epochs
linear_onecycle_decay_scheduler = optax.linear_onecycle_schedule(transition_steps=total_steps, peak_value=0.0001,
pct_start=0.25, pct_final=0.7,
div_factor=10., final_div_factor=100.)
lrs = [linear_onecycle_decay_scheduler(i) for i in range(total_steps)]
plt.scatter(range(total_steps), lrs)
plt.title("Linear One Cycle Decay Scheduler")
plt.ylabel("Learning Rate")
plt.xlabel("Epochs/Steps");
In this section, we have trained our CNN using SGD with the linear scheduler. We can create a linear scheduler using the linear_schedule() function. It accepts parameters such as init_value, end_value, transition_steps, and transition_begin. Below is the logic it uses to compute the learning rate at a given step.
def scheduler(step_number):
    count = np.clip(step_number - transition_begin, 0, transition_steps)
    frac = 1 - count / transition_steps
    return (init_value - end_value) * frac + end_value
In our case below, we have set the initial learning rate to 0.0001, the final learning rate to 0.00001, and the transition to begin after the first 25% of steps.
In the next cell, we have plotted a chart showing how the learning rate will change over time during training.
seed = random.PRNGKey(0)
batch_size=256
epochs=10
learning_rate = jnp.array(1/1e4)
model = CNN()
weights = model.init(seed, X_train[:5])
total_steps = epochs*(X_train.shape[0]//batch_size) + epochs
linear_decay_scheduler = optax.linear_schedule(init_value=0.0001, end_value=0.00001,
transition_steps=total_steps,
transition_begin=int(total_steps*0.25))
optimizer = optax.sgd(learning_rate=linear_decay_scheduler) ## Initialize SGD Optimizer
optimizer_state = optimizer.init(weights)
final_weights, val_acc = TrainModelInBatches(X_train, Y_train, X_test, Y_test, epochs, weights, optimizer_state, batch_size=batch_size)
scheduler_val_accs["Linear Scheduler"] = val_acc
import matplotlib.pyplot as plt
total_steps = epochs*(X_train.shape[0]//batch_size) + epochs
linear_decay_scheduler = optax.linear_schedule(init_value=0.0001, end_value=0.00001,
transition_steps=total_steps,
transition_begin=int(total_steps*0.25))
lrs = [linear_decay_scheduler(i) for i in range(total_steps)]
plt.scatter(range(total_steps), lrs)
plt.title("Linear Decay Scheduler")
plt.ylabel("Learning Rate")
plt.xlabel("Epochs/Steps");
In this section, we have trained our CNN using SGD with the piecewise constant scheduler. We can create a piecewise constant scheduler using the piecewise_constant_schedule() function. It accepts an initial value (init_value) and a boundaries_and_scales mapping from step boundaries to scale factors.
Below, we have included the code that optax uses internally to reduce the learning rate over time.
def scheduler(step_number):
    v = init_value
    if boundaries_and_scales is not None:
        for threshold, scale in sorted(boundaries_and_scales.items()):
            indicator = np.maximum(0., np.sign(threshold - step_number))
            v = v * indicator + (1 - indicator) * scale * v
    return v
In our case, we have set the initial learning rate to 0.0003. The mapping then has 3 entries, at 25%, 50%, and 75% of the total steps, each with a scale of 0.5. This tells the scheduler to scale the learning rate by 0.5 once 25% of the steps have passed, by another 0.5 after 50% of the steps, and by another 0.5 after 75% of the steps.
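As a quick worked example (our own toy boundary values, not the training configuration), halving three times as described takes the learning rate from 0.0003 to 0.00015, then 0.000075, then 0.0000375, which we can confirm directly:
## Hypothetical boundaries at steps 25, 50 and 75, each halving the learning rate
demo_scheduler = optax.piecewise_constant_schedule(init_value=0.0003,
                                                   boundaries_and_scales={25: 0.5, 50: 0.5, 75: 0.5})
print([float(demo_scheduler(i)) for i in [0, 30, 60, 90]])
## ~[0.0003, 0.00015, 0.000075, 0.0000375]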
In the next cell, we have also plotted a chart showing how the learning rate changes, which makes the behavior easier to understand.
seed = random.PRNGKey(0)
batch_size=256
epochs=10
learning_rate = jnp.array(1/1e4)
model = CNN()
weights = model.init(seed, X_train[:5])
total_steps = epochs*(X_train.shape[0]//batch_size) + epochs
piecewise_constant_decay_scheduler = optax.piecewise_constant_schedule(init_value=0.0003,
boundaries_and_scales={int(total_steps*0.25):0.5,
int(total_steps*0.5):0.5,
int(total_steps*0.75):0.5})
optimizer = optax.sgd(learning_rate=piecewise_constant_decay_scheduler) ## Initialize SGD Optimizer
optimizer_state = optimizer.init(weights)
final_weights, val_acc = TrainModelInBatches(X_train, Y_train, X_test, Y_test, epochs, weights, optimizer_state, batch_size=batch_size)
scheduler_val_accs["Piecewise Constant Scheduler"] = val_acc
import matplotlib.pyplot as plt
total_steps = epochs*(X_train.shape[0]//batch_size) + epochs
piecewise_constant_decay_scheduler = optax.piecewise_constant_schedule(init_value=0.0003, boundaries_and_scales={int(total_steps*0.25):0.5, int(total_steps*0.5):0.5, int(total_steps*0.75):0.5})
lrs = [piecewise_constant_decay_scheduler(i) for i in range(total_steps)]
plt.scatter(range(total_steps), lrs)
plt.title("Piecewise Constant Decay Scheduler")
plt.ylabel("Learning Rate")
plt.xlabel("Epochs/Steps");
In this section, we have trained our CNN using SGD with a piecewise interpolate scheduler. We can create a piecewise interpolate scheduler using the piecewise_interpolate_schedule() function. It accepts parameters such as interpolate_type, init_value, and boundaries_and_scales.
Below, we have included the logic used to calculate the learning rate during training.
def _linear_interpolate(start: float, end: float, pct: float):
    return (end-start) * pct + start

def _cosine_interpolate(start: float, end: float, pct: float):
    return end + (start-end) / 2.0 * (jnp.cos(jnp.pi * pct) + 1)

def scheduler(step_number):
    if interpolate_type == 'linear':
        interpolate_fn = _linear_interpolate
    elif interpolate_type == 'cosine':
        interpolate_fn = _cosine_interpolate
    else:
        raise ValueError('`interpolate_type` must be either \'cos\' or \'linear\'')
    if boundaries_and_scales:
        boundaries, scales = zip(*sorted(boundaries_and_scales.items()))
    else:
        boundaries, scales = (), ()
    bounds = np.stack((0,) + boundaries)
    values = np.cumprod(np.stack((init_value,) + scales))
    interval_sizes = (bounds[1:] - bounds[:-1])
    indicator = (bounds[:-1] <= step_number) & (step_number < bounds[1:])
    pct = (step_number - bounds[:-1]) / interval_sizes
    interp_vals = interpolate_fn(values[:-1], values[1:], pct)
    return indicator.dot(interp_vals) + (bounds[-1] <= step_number) * values[-1]
Below, we have initialized our scheduler with interpolate type as 'linear' and initial learning rate of 0.0003. The boundaries and scales parameter is the same as in our previous example.
In the next cell, we have also plotted learning rate change over time to give a better idea of how the scheduler works.
seed = random.PRNGKey(0)
batch_size=256
epochs=10
learning_rate = jnp.array(1/1e4)
model = CNN()
weights = model.init(seed, X_train[:5])
total_steps = epochs*(X_train.shape[0]//batch_size) + epochs
piecewise_interpolate_decay_scheduler = optax.piecewise_interpolate_schedule(interpolate_type="linear",
init_value=0.0003,
boundaries_and_scales={int(total_steps*0.25):0.5,
int(total_steps*0.5):0.5,
int(total_steps*0.75):0.5})
optimizer = optax.sgd(learning_rate=piecewise_interpolate_decay_scheduler) ## Initialize SGD Optimizer
optimizer_state = optimizer.init(weights)
final_weights, val_acc = TrainModelInBatches(X_train, Y_train, X_test, Y_test, epochs, weights, optimizer_state, batch_size=batch_size)
scheduler_val_accs["Piecewise Interpolate Scheduler V1"] = val_acc
import matplotlib.pyplot as plt
total_steps = epochs*(X_train.shape[0]//batch_size) + epochs
piecewise_interpolate_decay_scheduler = optax.piecewise_interpolate_schedule(interpolate_type="linear",
init_value=0.0003,
boundaries_and_scales={int(total_steps*0.25):0.5,
int(total_steps*0.5):0.5,
int(total_steps*0.75):0.5})
lrs = [piecewise_interpolate_decay_scheduler(i) for i in range(total_steps)]
plt.scatter(range(total_steps), lrs)
plt.title("Piecewise Interpolate Decay Scheduler")
plt.ylabel("Learning Rate")
plt.xlabel("Epochs/Steps");
In the below cell, we are training our CNN again using a piecewise interpolate scheduler. The settings are the same as in the above example, except that the interpolate type is changed to 'cosine'.
In the next cell, we have also plotted a chart showing how the learning rate changes using this scheduler during training.
seed = random.PRNGKey(0)
batch_size=256
epochs=10
learning_rate = jnp.array(1/1e4)
model = CNN()
weights = model.init(seed, X_train[:5])
total_steps = epochs*(X_train.shape[0]//batch_size) + epochs
piecewise_interpolate_decay_scheduler = optax.piecewise_interpolate_schedule(interpolate_type="cosine",
init_value=0.0003,
boundaries_and_scales={int(total_steps*0.25):0.5,
int(total_steps*0.5):0.5,
int(total_steps*0.75):0.5})
optimizer = optax.sgd(learning_rate=piecewise_interpolate_decay_scheduler) ## Initialize SGD Optimizer
optimizer_state = optimizer.init(weights)
final_weights, val_acc = TrainModelInBatches(X_train, Y_train, X_test, Y_test, epochs, weights, optimizer_state, batch_size=batch_size)
scheduler_val_accs["Piecewise Interpolate Scheduler V2"] = val_acc
import matplotlib.pyplot as plt
total_steps = epochs*(X_train.shape[0]//batch_size) + epochs
piecewise_interpolate_decay_scheduler = optax.piecewise_interpolate_schedule(interpolate_type="cosine",
init_value=0.0003,
boundaries_and_scales={int(total_steps*0.25):0.5,
int(total_steps*0.5):0.5,
int(total_steps*0.75):0.5})
lrs = [piecewise_interpolate_decay_scheduler(i) for i in range(total_steps)]
plt.scatter(range(total_steps), lrs)
plt.title("Piecewise Interpolate Decay Scheduler")
plt.ylabel("Learning Rate")
plt.xlabel("Epochs/Steps");
In this section, we have trained our CNN using SGD with the polynomial scheduler. We can create a polynomial scheduler using the polynomial_schedule() function. Its important parameters are init_value, end_value, power, transition_steps, and transition_begin.
The formula below is used by optax to calculate the learning rate at any step of training.
def scheduler(step_number):
    count = np.clip(step_number - transition_begin, 0, transition_steps)
    frac = 1 - count / transition_steps
    return (init_value - end_value) * (frac**power) + end_value
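As an optional sanity check (our own example with made-up values), the formula above should agree with optax.polynomial_schedule() step for step:
## Hypothetical settings for the check
poly_check = optax.polynomial_schedule(init_value=0.0001, end_value=0.00001, power=0.5,
                                       transition_steps=1000, transition_begin=200)
step = 700
count = jnp.clip(step - 200, 0, 1000)
frac = 1 - count / 1000
manual = (0.0001 - 0.00001) * (frac**0.5) + 0.00001
print(manual, poly_check(step)) ## Both values should (roughly) match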
In our case, we have set the initial learning rate to 0.0001, the final learning rate to 0.00001, the power to 0.5, and the transition to begin at 20% of the steps.
In the cell after the training cell, we have also plotted a chart showing how the learning rate changes during training.
seed = random.PRNGKey(0)
batch_size=256
epochs=10
learning_rate = jnp.array(1/1e4)
model = CNN()
weights = model.init(seed, X_train[:5])
total_steps = epochs*(X_train.shape[0]//batch_size) + epochs
polynomial_decay_scheduler = optax.polynomial_schedule(init_value=0.0001, end_value=0.00001,
power=0.5, transition_steps=total_steps,
transition_begin=int(total_steps*0.20))
optimizer = optax.sgd(learning_rate=polynomial_decay_scheduler) ## Initialize SGD Optimizer
optimizer_state = optimizer.init(weights)
final_weights, val_acc = TrainModelInBatches(X_train, Y_train, X_test, Y_test, epochs, weights, optimizer_state, batch_size=batch_size)
scheduler_val_accs["Polynomial Scheduler"] = val_acc
import matplotlib.pyplot as plt
total_steps = epochs*(X_train.shape[0]//batch_size) + epochs
polynomial_decay_scheduler = optax.polynomial_schedule(init_value=0.0001, end_value=0.00001,
power=0.5, transition_steps=total_steps,
transition_begin=int(total_steps*0.20))
lrs = [polynomial_decay_scheduler(i) for i in range(total_steps)]
plt.scatter(range(total_steps), lrs)
plt.title("Polynomial Decay Scheduler")
plt.ylabel("Learning Rate")
plt.xlabel("Epochs/Steps");
In this section, we have trained our CNN using SGD with a warmup cosine decay scheduler. We can create a warmup cosine decay scheduler using the warmup_cosine_decay_schedule() function. It accepts parameters such as init_value, peak_value, warmup_steps, decay_steps, and end_value.
This scheduler first applies a linear schedule (Section 6 of this tutorial) to take the learning rate from the initial value to the peak value, and then applies a cosine decay schedule (Section 2) to take the learning rate from the peak to the final/end value.
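To illustrate that composition, here is a minimal sketch based on our understanding (not from the original tutorial; it assumes join_schedules() restarts the step count at each boundary, and the built-in warmup_cosine_decay_schedule() should be preferred in practice). The same overall shape can be built by joining a linear warmup with a cosine decay:
## Hypothetical step counts for the sketch
warmup_steps, decay_steps = 500, 2500
warmup_part = optax.linear_schedule(init_value=0.0001, end_value=0.0003,
                                    transition_steps=warmup_steps)
## alpha sets the final value as a fraction of the peak value
decay_part = optax.cosine_decay_schedule(0.0003, decay_steps=decay_steps - warmup_steps,
                                         alpha=0.00001/0.0003)
approx_warmup_cosine = optax.join_schedules([warmup_part, decay_part],
                                            boundaries=[warmup_steps])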
In our case, we have set the initial learning rate to 0.0001, peak value to 0.0003, warmup steps to 20% of steps, and end value to 0.00001.
In the next cell after the training cell, we have plotted a chart showing how the learning rate changes during training.
seed = random.PRNGKey(0)
batch_size=256
epochs=10
learning_rate = jnp.array(1/1e4)
model = CNN()
weights = model.init(seed, X_train[:5])
total_steps = epochs*(X_train.shape[0]//batch_size) + epochs
warmup_cosine_decay_scheduler = optax.warmup_cosine_decay_schedule(init_value=0.0001, peak_value=0.0003,
warmup_steps=int(total_steps*0.2),
decay_steps=total_steps, end_value=0.00001)
optimizer = optax.sgd(learning_rate=warmup_cosine_decay_scheduler) ## Initialize SGD Optimizer
optimizer_state = optimizer.init(weights)
final_weights, val_acc = TrainModelInBatches(X_train, Y_train, X_test, Y_test, epochs, weights, optimizer_state, batch_size=batch_size)
scheduler_val_accs["Warmup Cosine Decay Scheduler"] = val_acc
import matplotlib.pyplot as plt
total_steps = epochs*(X_train.shape[0]//batch_size) + epochs
warmup_cosine_decay_scheduler = optax.warmup_cosine_decay_schedule(init_value=0.0001, peak_value=0.0003,
warmup_steps=int(total_steps*0.2),
decay_steps=total_steps, end_value=0.00001)
lrs = [warmup_cosine_decay_scheduler(i) for i in range(total_steps)]
plt.scatter(range(total_steps), lrs)
plt.title("Warmup Cosine Decay Scheduler")
plt.ylabel("Learning Rate")
plt.xlabel("Epochs/Steps");
In this section, we have trained our CNN using SGD with a warmup exponential decay scheduler. We can create a warmup exponential decay scheduler using the warmup_exponential_decay_schedule() function. It accepts parameters such as init_value, peak_value, warmup_steps, transition_steps, decay_rate, transition_begin, and end_value.
This scheduler first applies a linear schedule (Section 6 of this tutorial) to take the learning rate from the initial value to the peak value, and then applies an exponential decay schedule (Section 4) to take the learning rate from the peak to the final/end value.
In our case, the initial learning rate is 0.0001, the peak value is 0.0003, the warmup lasts for the first 20% of the steps, the decay rate is 0.8, the transition begins at 10% of the steps (after the first 20% of warmup steps), and the end value of the learning rate is 0.00001.
We have also plotted the learning rate to show how it changes during training.
seed = random.PRNGKey(0)
batch_size=256
epochs=10
learning_rate = jnp.array(1/1e4)
model = CNN()
weights = model.init(seed, X_train[:5])
total_steps = epochs*(X_train.shape[0]//batch_size) + epochs
warmup_exponential_decay_scheduler = optax.warmup_exponential_decay_schedule(init_value=0.0001, peak_value=0.0003,
warmup_steps=int(total_steps*0.2),
transition_steps=total_steps,
decay_rate=0.8,
transition_begin=int(total_steps*0.1),
end_value=0.00001)
optimizer = optax.sgd(learning_rate=warmup_exponential_decay_scheduler) ## Initialize SGD Optimizer
optimizer_state = optimizer.init(weights)
final_weights, val_acc = TrainModelInBatches(X_train, Y_train, X_test, Y_test, epochs, weights, optimizer_state, batch_size=batch_size)
scheduler_val_accs["Warmup Exponential Decay Scheduler"] = val_acc
import matplotlib.pyplot as plt
total_steps = epochs*(X_train.shape[0]//batch_size) + epochs
warmup_exponential_decay_scheduler = optax.warmup_exponential_decay_schedule(init_value=0.0001, peak_value=0.0003,
warmup_steps=int(total_steps*0.2),
transition_steps=total_steps,
decay_rate=0.8,
transition_begin=int(total_steps*0.1),
end_value=0.00001)
lrs = [warmup_exponential_decay_scheduler(i) for i in range(total_steps)]
plt.scatter(range(total_steps), lrs)
plt.title("Warmup Exponential Decay Scheduler")
plt.ylabel("Learning Rate")
plt.xlabel("Epochs/Steps");
Optax also lets us combine multiple schedulers if one alone is not sufficient for our case. It provides a function named join_schedules() that lets us combine multiple schedulers. We need to provide a list of schedulers and the step boundaries at which to switch between them.
In our case below, we have created three schedulers: a cosine decay scheduler, a linear scheduler, and a cosine one-cycle scheduler.
We have set boundaries to [1000,2000]. This will apply a cosine decay scheduler to the first 1000 steps, a linear scheduler to steps from 1000 to 2000, and a cosine one cycle scheduler to the remaining steps after 2000.
We have also plotted how the learning rate will change during training with these combined schedulers.
seed = random.PRNGKey(0)
batch_size=256
epochs=10
learning_rate = jnp.array(1/1e4)
model = CNN()
weights = model.init(seed, X_train[:5])
total_steps = epochs*(X_train.shape[0]//batch_size) + epochs
cosine_decay_scheduler = optax.cosine_decay_schedule(0.0003, decay_steps=1000, alpha=0.7)
linear_decay_scheduler = optax.linear_schedule(init_value=0.0002, end_value=0.0001,
transition_steps=1000)
cosine_onecycle_scheduler = optax.cosine_onecycle_schedule(transition_steps=1000, peak_value=0.0001,
div_factor=30., final_div_factor=1000.)
multiple_schedulers = optax.join_schedules([cosine_decay_scheduler, linear_decay_scheduler,
cosine_onecycle_scheduler], boundaries=[1000, 2000])
optimizer = optax.sgd(learning_rate=multiple_schedulers) ## Initialize SGD Optimizer
optimizer_state = optimizer.init(weights)
final_weights, val_acc = TrainModelInBatches(X_train, Y_train, X_test, Y_test, epochs, weights, optimizer_state, batch_size=batch_size)
scheduler_val_accs["Combining Multiple Scheduler"] = val_acc
import matplotlib.pyplot as plt
total_steps = epochs*(X_train.shape[0]//batch_size) + epochs
cosine_decay_scheduler = optax.cosine_decay_schedule(0.0003, decay_steps=1000, alpha=0.7)
linear_decay_scheduler = optax.linear_schedule(init_value=0.0002, end_value=0.0001,
transition_steps=1000)
cosine_onecycle_scheduler = optax.cosine_onecycle_schedule(transition_steps=1000, peak_value=0.0001,
div_factor=30., final_div_factor=1000.)
multiple_schedulers = optax.join_schedules([cosine_decay_scheduler, linear_decay_scheduler,
cosine_onecycle_scheduler], boundaries=[1000, 2000])
lrs = [multiple_schedulers(i) for i in range(total_steps)]
plt.scatter(range(total_steps), lrs)
plt.title("Multiple Schedulers")
plt.ylabel("Learning Rate")
plt.xlabel("Epochs/Steps");
In this section, we have trained our CNN using SGD with a warm restarts scheduler. We can create an SGD with warm restarts scheduler using the sgdr_schedule() function. It takes a list of dictionaries, where each dictionary specifies the parameters of one warmup cosine decay scheduler. It then executes these warmup cosine decay schedulers one by one, each for the number of steps specified in its configuration.
In our case, we have provided a list of three dictionaries to the sgdr_schedule() function, each specifying a different warmup cosine decay scheduler. The first entry takes the learning rate from 0.0003 up to a peak value of 0.0004 and then down to 0.0002; it runs for the first 1000 steps, taking 100 steps to reach 0.0004 from 0.0003. The same logic is followed for the next 1000 steps using the second dictionary, where the learning rate goes from 0.0002 up to 0.0003 and then drops to 0.0001. For the last 1000 steps, the learning rate goes from 0.0001 up to 0.0002 and then drops to 0.00005.
We have also plotted how the learning rate will change during training using this scheduler.
seed = random.PRNGKey(0)
batch_size=256
epochs=10
learning_rate = jnp.array(1/1e4)
model = CNN()
weights = model.init(seed, X_train[:5])
total_steps = epochs*(X_train.shape[0]//batch_size) + epochs
sgd_with_warm_restarts_decay_scheduler = optax.sgdr_schedule(
[{"init_value":0.0003, "peak_value":0.0004, "decay_steps":1000, "warmup_steps":100, "end_value":0.0002},
{"init_value":0.0002, "peak_value":0.0003, "decay_steps":1000, "warmup_steps":100, "end_value":0.0001},
{"init_value":0.0001, "peak_value":0.0002, "decay_steps":1000, "warmup_steps":100, "end_value":0.00005},
]
)
optimizer = optax.sgd(learning_rate=sgd_with_warm_restarts_decay_scheduler) ## Initialize SGD Optimizer
optimizer_state = optimizer.init(weights)
final_weights, val_acc = TrainModelInBatches(X_train, Y_train, X_test, Y_test, epochs, weights, optimizer_state, batch_size=batch_size)
scheduler_val_accs["SGD With Warm Restarts"] = val_acc
import matplotlib.pyplot as plt
total_steps = epochs*(X_train.shape[0]//batch_size) + epochs
sgd_with_warm_restarts_decay_scheduler = optax.sgdr_schedule(
[{"init_value":0.0003, "peak_value":0.0004, "decay_steps":1000, "warmup_steps":100, "end_value":0.0002},
{"init_value":0.0002, "peak_value":0.0003, "decay_steps":1000, "warmup_steps":100, "end_value":0.0001},
{"init_value":0.0001, "peak_value":0.0002, "decay_steps":1000, "warmup_steps":100, "end_value":0.00005},
]
)
lrs = [sgd_with_warm_restarts_decay_scheduler(i) for i in range(total_steps)]
plt.scatter(range(total_steps), lrs)
plt.title("SGD With Warm Restarts Scheduler")
plt.ylabel("Learning Rate")
plt.xlabel("Epochs/Steps");
In this section, we have simply created a dataframe from the validation accuracies that we stored in the dictionary for all schedulers, for comparison purposes.
import pandas as pd
pd.DataFrame(scheduler_val_accs, index=["Valid Accuracy"]).T
This ends our small tutorial explaining how we can use the learning rate schedulers available from the optax library with our Flax/JAX networks. We can also easily create a custom scheduler for Flax/JAX networks, as all schedulers are basically functions that take a step number as input and return the learning rate for that step. Please feel free to let us know your views in the comments section.
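For completeness, here is a minimal sketch of a custom schedule (our own illustrative example, not from the tutorial): any callable that maps a step count to a learning rate can be passed to an optax optimizer in place of the built-in schedules.
def my_custom_scheduler(step_number):
    ## Hypothetical rule: halve the base learning rate every 500 steps
    return 0.0001 * (0.5 ** (step_number // 500))

optimizer = optax.sgd(learning_rate=my_custom_scheduler) ## Used just like the built-in schedules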
If you are more comfortable learning through video tutorials then we would recommend that you subscribe to our YouTube channel.
When going through coding examples, it's quite common to have doubts and errors.
If you have doubts about some code examples or are stuck somewhere when trying our code, send us an email at coderzcolumn07@gmail.com. We'll help you or point you in the direction where you can find a solution to your problem.
You can even send us a mail if you are trying something new and need guidance regarding coding. We'll try to respond as soon as possible.