How to write a training loop#

This notebook demonstrates how bobbin supports a basic training loop for Flax models.

Preamble: Install prerequisites, import modules.#

!pip -q install --upgrade pip
!pip -q install --upgrade "jax[cpu]"
!pip -q uninstall -y bobbin
!pip -q install --upgrade git+https://github.com/yotarok/bobbin.git
%%capture
import bobbin
import chex
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
# Simulate a multi-device environment with 8 virtual CPU devices
chex.set_n_cpu_devices(8)

Array = chex.Array

Dataset configuration#

In this example, we use the MNIST dataset from tensorflow_datasets (TFDS). The following function returns random batches drawn from the MNIST training split.

def get_dataset(batch_size):
    ds = tfds.load("mnist", split="train", as_supervised=True)
    ds = ds.repeat().shuffle(1024).batch(batch_size).prefetch(1)
    return ds

Model definition#

The model used for this demonstration is a simple feed-forward network with two dense layers: one hidden layer of 1024 units with a sigmoid activation, followed by a 10-way output layer. Dropout is also applied to demonstrate how random number generators (RNGs) are handled.

class MnistClassifier(nn.Module):
    @nn.compact
    def __call__(self, x: Array, *, training=True) -> Array:
        batch_size, *unused_image_dims = x.shape
        x = x.reshape((batch_size, -1))  # flatten the input image.
        hidden = nn.sigmoid(nn.Dense(features=1024)(x))
        hidden = nn.Dropout(0.5)(hidden, deterministic=not training)
        return nn.Dense(features=10)(hidden)
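
To make the shapes and the role of the training flag concrete, the forward pass above can be sketched in plain NumPy. This is only an illustration, not the Flax implementation: the parameter names (w1, b1, w2, b2), the initialization, and the forward helper are hypothetical, and dropout is written as the usual inverted dropout.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def forward(params, x, *, training, rng=None, dropout_rate=0.5):
    """NumPy sketch of the MnistClassifier forward pass (not the Flax code)."""
    batch_size = x.shape[0]
    x = x.reshape((batch_size, -1)).astype(np.float32)  # (B, 784)
    hidden = sigmoid(x @ params["w1"] + params["b1"])   # (B, 1024)
    if training:
        if rng is None:
            raise ValueError("training=True needs an RNG for dropout")
        # Inverted dropout: zero units with probability `dropout_rate` and
        # rescale the survivors so the expected activation is unchanged.
        keep = rng.random(hidden.shape) >= dropout_rate
        hidden = np.where(keep, hidden / (1.0 - dropout_rate), 0.0)
    return hidden @ params["w2"] + params["b2"]         # (B, 10) logits


# Hypothetical parameters, only used to exercise the shapes.
init_rng = np.random.default_rng(0)
params = {
    "w1": init_rng.normal(scale=0.01, size=(784, 1024)).astype(np.float32),
    "b1": np.zeros(1024, np.float32),
    "w2": init_rng.normal(scale=0.01, size=(1024, 10)).astype(np.float32),
    "b2": np.zeros(10, np.float32),
}
images = np.zeros((4, 28, 28, 1), np.float32)
eval_logits = forward(params, images, training=False)
train_logits = forward(params, images, training=True, rng=np.random.default_rng(1))
print(eval_logits.shape, train_logits.shape)
```

With training=False the function is deterministic and needs no RNG, which mirrors why only the training path of the Flax model requires a "dropout" RNG.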

TrainTask definition#

For a bobbin-based training loop, define a subclass of bobbin.TrainTask that holds all training-related artifacts. The example below specifies the following:

  • How to compute the loss, in the overridden compute_loss method.

  • The model to be trained, passed to the base-class constructor.

  • A fake input batch used for initializing the parameters, passed to the base-class constructor as the example_args argument.

  • A list of extra RNG names required for model initialization, passed to the base-class constructor as the required_rngs argument.

It should be noted that the required_rngs argument does not have to include “params”, as it is always used for initialization.
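
A related detail is the trailing comma in example_args, which the class definition that follows marks as important. The sketch below shows the plain-Python reason, under the assumption that example_args is unpacked as positional arguments; the describe helper is hypothetical and only reports what it receives.

```python
import numpy as np


def describe(*args):
    """Hypothetical stand-in for a call that unpacks example_args."""
    return [np.shape(a) for a in args]


fake_image = np.zeros((1, 28, 28, 1), np.float32)

with_comma = (fake_image,)    # a 1-tuple holding one positional argument
without_comma = (fake_image)  # just the array; the parentheses do nothing

print(describe(*with_comma))     # one argument that keeps its batch axis
print(describe(*without_comma))  # unpacking iterates over the batch axis instead
```

Without the comma, unpacking iterates over the array's leading axis, so the batch dimension is silently dropped rather than raising an error.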

class MnistTrainingTask(bobbin.TrainTask):
    def __init__(self):
        super().__init__(
            MnistClassifier(),
            example_args=(
                np.zeros((1, 28, 28, 1), np.float32),  # comma-here is important
            ),
            required_rngs=("dropout",),
        )

    def compute_loss(self, params, batch, *, extra_vars, prng_key, step):
        images, labels = batch
        logits = self.model.apply(
            {"params": params}, images, rngs=self.get_rng_dict(prng_key)
        )
        per_sample_loss = optax.softmax_cross_entropy(
            logits=logits, labels=jax.nn.one_hot(labels, 10)
        )
        return jnp.mean(per_sample_loss), ({}, None)

The return value of compute_loss needs some explanation. compute_loss is expected to return (loss, (updated_vars, loss_aux_info)), typed as Tuple[float, Tuple[VariableCollection, LossAuxInfo]]. Besides the loss value itself, the function can update non-trainable variables by returning them in updated_vars, and can pass auxiliary information out through loss_aux_info. Here, updated_vars is an empty dictionary and loss_aux_info is None.
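
As a rough illustration of both the loss and the nested return structure, here is a NumPy sketch. It re-implements softmax cross-entropy by hand instead of calling optax, and compute_loss_sketch is a hypothetical stand-in that takes logits directly rather than the task's real arguments.

```python
import numpy as np


def one_hot(labels, num_classes):
    return np.eye(num_classes, dtype=np.float32)[labels]


def softmax_cross_entropy(logits, labels_one_hot):
    # Numerically stable log-softmax, then -sum(p_true * log p_model).
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -(labels_one_hot * log_probs).sum(axis=-1)


def compute_loss_sketch(logits, labels):
    per_sample_loss = softmax_cross_entropy(logits, one_hot(labels, 10))
    updated_vars = {}      # no non-trainable variables to update
    loss_aux_info = None   # nothing extra to report
    return per_sample_loss.mean(), (updated_vars, loss_aux_info)


logits = np.zeros((3, 10), np.float32)  # uniform prediction over 10 classes
labels = np.array([0, 7, 9])
loss, (updated_vars, loss_aux_info) = compute_loss_sketch(logits, labels)
print(loss)  # log(10), about 2.3026, for a uniform prediction
```

The outer tuple separates the scalar loss from everything else, so the caller can unpack the two auxiliary values without inspecting the loss.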

Initialization of TrainState and step function#

Once the task is defined, we can initialize train_state, which contains everything needed to continue training, as follows:

task = MnistTrainingTask()
train_state = task.initialize_train_state(jax.random.PRNGKey(0), optax.sgd(0.01))
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

Here, the first argument of initialize_train_state is the base RNG key used for initialization, and the second is an instance of optax.GradientTransformation, i.e. the optimizer used for training.
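
For intuition about the optimizer chosen here: SGD without momentum, as in optax.sgd(0.01), amounts to subtracting the scaled gradient from each parameter. The NumPy sketch below applies that rule to a flat dictionary; sgd_step and the toy values are hypothetical, and the real optimizer works on arbitrary pytrees through the optax API.

```python
import numpy as np


def sgd_step(params, grads, learning_rate=0.01):
    """One plain SGD update applied leaf by leaf to a dict of arrays."""
    return {name: params[name] - learning_rate * grads[name] for name in params}


# Toy values chosen only to make the arithmetic easy to follow.
params = {"w": np.array([1.0, -2.0]), "b": np.array([0.5])}
grads = {"w": np.array([10.0, 10.0]), "b": np.array([-1.0])}
new_params = sgd_step(params, grads)
print(new_params)
```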

In addition to the training state, we prepare the function that maps a training state to the updated training state. This is done by calling TrainTask.make_training_step_fn, as follows:

train_step_fn = task.make_training_step_fn().pmap("d")

The return value of make_training_step_fn is a “configurable” function that can transform itself through its methods. Here, we call its pmap("d") method to parallelize the computation across multiple devices (“d” is passed as the axis_name parameter of jax.pmap). The pmapped version of train_step_fn operates on a replicated TrainState, so we also need to replicate train_state to make it compatible with the step function, as follows:

print(f"{train_state.step=}")
train_state = flax.jax_utils.replicate(train_state, jax.local_devices())
print(f"replicated version of {train_state.step=}")
train_state.step=0
replicated version of train_state.step=Array([0, 0, 0, 0, 0, 0, 0, 0], dtype=int32, weak_type=True)

As the output above shows, replication copies the same values to every device; from the Python interpreter, each value appears as an array with an extra leading axis indexed by device.
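
The effect of replication on shapes can be mimicked with NumPy. This is only a sketch of how the result looks from Python: replicate_sketch and unreplicate_sketch are hypothetical helpers that handle a flat dictionary, whereas the real utility handles arbitrary pytrees and places each copy on a device.

```python
import numpy as np

NUM_DEVICES = 8  # matches chex.set_n_cpu_devices(8) in the preamble


def replicate_sketch(tree, num_devices):
    """Stack every leaf `num_devices` times along a new leading axis."""
    return {k: np.stack([np.asarray(v)] * num_devices) for k, v in tree.items()}


def unreplicate_sketch(tree):
    """Take the copy held by the first device."""
    return {k: v[0] for k, v in tree.items()}


state = {"step": np.int32(0), "w": np.ones((3, 2), np.float32)}
replicated = replicate_sketch(state, NUM_DEVICES)
print(replicated["step"])     # one copy of the step counter per device
print(replicated["w"].shape)  # (8, 3, 2)
restored = unreplicate_sketch(replicated)
```

Taking index 0 along the leading axis recovers a single-device view, which is handy when a replicated value has to be inspected or saved.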

Run the training loop#

Once train_state and train_step_fn are ready, training amounts to repeating (train_state, ...) = train_step_fn(train_state, batch, ...). Here, we do this over the first 500 batches from the dataset.

prng_key = jax.random.PRNGKey(0)
step_infos = []
for batch in get_dataset(64).take(500).as_numpy_iterator():
    rng, prng_key = jax.random.split(prng_key)
    train_state, step_info = train_step_fn(train_state, batch, rng)
    step_infos.append(step_info)

Note the extra complication due to RNGs. Each step needs a fresh RNG key for the Dropout module used in the model. This is done by splitting the root RNG key (here defined as jax.random.PRNGKey(0)) once per step: one half is passed to the step, and the other replaces the root key for the next iteration.
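
The key-splitting pattern can be tried in isolation, without a model. The sketch below runs three iterations of the same split and checks that every step receives a different key; step_keys and all_distinct are names introduced here purely for illustration.

```python
import jax
import numpy as np

prng_key = jax.random.PRNGKey(0)
step_keys = []
for _ in range(3):
    # Same pattern as the training loop: one key is consumed by this step,
    # the other is carried forward to the next iteration.
    rng, prng_key = jax.random.split(prng_key)
    step_keys.append(np.asarray(rng))

# Distinct keys per step mean that dropout draws a new mask every step.
all_distinct = len({k.tobytes() for k in step_keys}) == len(step_keys)
print(all_distinct)
```

Because the keys are derived deterministically from the seed, rerunning the loop with the same seed reproduces the same sequence of keys.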

The return value of train_step_fn is a pair of the updated train_state and a step_info of type bobbin.training.StepInfo, which carries auxiliary information obtained during the step. StepInfo has the following fields:

  • StepInfo.loss: loss value.

  • StepInfo.loss_aux_out: auxiliary output from the loss function. In this example, our compute_loss returns None, so this field is None.

For the pmapped version of train_step_fn, StepInfo is also replicated, as follows:

print(step_infos[0])
StepInfo(loss=Array([2.803742 , 2.6738906, 3.0400233, 3.374338 , 3.1578255, 3.1524312,
       2.930612 , 2.8431826], dtype=float32), loss_aux_out=None)

Here, np.mean is applied to average the per-device losses, which gives the actual loss value used for optimization at each step.

losses = [np.mean(step_info.loss) for step_info in step_infos]
plt.plot(losses)
[<matplotlib.lines.Line2D at 0x7f2997f9b850>]
[Figure: line plot of the per-step mean training loss over the 500 training steps.]
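
As a small check of the averaging step, the per-device losses printed for the first step can be averaged by hand. The values below are copied from the StepInfo output shown earlier; first_step_loss and mean_loss are names introduced only for this sketch.

```python
import numpy as np

# Per-device losses copied from the printed StepInfo of the first step.
first_step_loss = np.array(
    [2.803742, 2.6738906, 3.0400233, 3.374338,
     3.1578255, 3.1524312, 2.930612, 2.8431826],
    dtype=np.float32,
)
mean_loss = float(np.mean(first_step_loss))
print(f"{mean_loss:.3f}")  # 2.997
```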