How to set up periodic action in the loop#
This notebook demonstrates the CronTab feature of bobbin.
Preamble: Install prerequisites, import modules.#
!pip -q install --upgrade pip
!pip -q install --upgrade "jax[cpu]"
!pip -q uninstall -y bobbin
!pip -q install --upgrade git+https://github.com/yotarok/bobbin.git
%%capture
import logging
import sys
import tempfile
import time
import bobbin
import chex
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow_datasets as tfds
# Shorthand used in type annotations throughout the notebook.
Array = chex.Array

# Send log records to stdout and let INFO-level (and above) records through
# the root logger.
logging.basicConfig(stream=sys.stdout)
logger = logging.getLogger()
logger.setLevel(logging.INFO)
# Additionally write records to the "/dev/stdout" file path.
logger.addHandler(logging.FileHandler("/dev/stdout"))
Define tasks and models#
Here, we demonstrate how to construct a loop that involves a full training setup, so some training/evaluation setup is needed. Only a minimal explanation is given for the training/evaluation code below. Please refer to the following documents for training/evaluation tasks in bobbin.
Training: How to write a training loop
Evaluation: How to define an evaluation task
First, let’s build a pipeline for pulling the training and evaluation datasets. The functions can be built as follows:
def get_train_dataset(batch_size):
    """Build the MNIST training pipeline.

    The "train" split is loaded as `(image, label)` pairs, repeated, shuffled
    with a buffer of 1024 elements, grouped into batches of `batch_size`, and
    prefetched one batch ahead.

    Args:
        batch_size: Number of examples per batch.

    Returns:
        The resulting dataset object.
    """
    dataset = tfds.load("mnist", split="train", as_supervised=True)
    dataset = dataset.repeat()
    dataset = dataset.shuffle(1024)
    dataset = dataset.batch(batch_size)
    return dataset.prefetch(1)
def get_eval_dataset(batch_size):
    """Build the MNIST evaluation pipeline.

    Only the first 1000 examples of the "test" split are used. They are
    loaded as `(image, label)` pairs, batched, and prefetched one batch ahead.

    Args:
        batch_size: Number of examples per batch.

    Returns:
        The resulting dataset object.
    """
    examples = tfds.load("mnist", split="test[:1000]", as_supervised=True)
    batched = examples.batch(batch_size)
    return batched.prefetch(1)
Then, we define the classifier model and loss function (in a subclass of TrainTask), as follows:
(please also check How to write a training loop)
class MnistClassifier(nn.Module):
    """Feed-forward classifier: one 512-unit sigmoid layer, then 10 logits."""

    @nn.compact
    def __call__(self, x: Array, *, training=True) -> Array:
        """Map a batch of images to per-class logits.

        Args:
            x: Input batch; the leading axis is treated as the batch axis and
                all remaining axes are flattened together.
            training: Accepted for interface compatibility; this model does
                not consult it.

        Returns:
            An array of logits with 10 entries per example.
        """
        num_examples, *_ = x.shape
        # Collapse every non-batch axis into a single feature axis.
        flat_inputs = x.reshape((num_examples, -1))
        pre_activation = nn.Dense(features=512)(flat_inputs)
        activated = nn.sigmoid(pre_activation)
        return nn.Dense(features=10)(activated)
class MnistTrainingTask(bobbin.TrainTask):
    """Training task that fits `MnistClassifier` with a cross-entropy loss."""

    def __init__(self):
        """Register the model together with a dummy input for initialization."""
        dummy_images = np.zeros((1, 28, 28, 1), np.float32)
        # `example_args` must be a tuple of positional arguments, hence the
        # trailing comma in the one-element tuple.
        super().__init__(MnistClassifier(), example_args=(dummy_images,))

    def compute_loss(self, params, batch, *, extra_vars, prng_key, step):
        """Compute the mean softmax cross-entropy over a batch.

        Args:
            params: Model parameters, applied under the "params" collection.
            batch: An `(images, labels)` pair.
            extra_vars: Unused here.
            prng_key: Unused here.
            step: Unused here.

        Returns:
            A pair `(loss, ({}, None))` where `loss` is the batch-mean loss.
            NOTE(review): the second element is presumably the updated extra
            variables and an auxiliary output expected by `bobbin.TrainTask`;
            confirm against the bobbin documentation.
        """
        images, labels = batch
        logits = self.model.apply({"params": params}, images)
        one_hot_targets = jax.nn.one_hot(labels, 10)
        example_losses = optax.softmax_cross_entropy(
            logits=logits, labels=one_hot_targets
        )
        mean_loss = jnp.mean(example_losses)
        return mean_loss, ({}, None)
# Instantiate the training task, then build its training-step function and
# wrap it with `.jit()`. Both names are reused by the training loops below.
task = MnistTrainingTask()
train_step_fn = task.make_training_step_fn().jit()
The evaluation metrics and how to evaluate the model can be defined as follows: (check How to define an evaluation task, too)
class EvalResults(bobbin.EvalResults):
    """Running tally of classification outcomes for one evaluation set."""

    # Number of examples whose predicted class equals the label.
    correct_count: int
    # Number of examples scored so far.
    predict_count: int

    @property
    def accuracy(self) -> float:
        """Ratio of `correct_count` to `predict_count`."""
        return self.correct_count / self.predict_count

    def is_better_than(self, other: "EvalResults") -> bool:
        """Return whether this result has strictly higher accuracy than `other`."""
        own_accuracy = self.accuracy
        other_accuracy = other.accuracy
        return own_accuracy > other_accuracy

    def reduce(self, other: "EvalResults") -> "EvalResults":
        """Merge two results by adding corresponding fields."""

        def add_leaves(mine, theirs):
            return mine + theirs

        return jax.tree_util.tree_map(add_leaves, self, other)

    def to_log_message(self) -> str:
        """Format the accuracy, to two decimals, as a log message."""
        acc = self.accuracy
        return f"formatted in `EvalResults.to_log_message`. acc={acc:.2f}"
class EvalTask(bobbin.EvalTask):
    """Evaluation task that counts correct predictions of `MnistClassifier`."""

    def __init__(self):
        """Create the model definition used to score batches."""
        self.model = MnistClassifier()

    def create_eval_results(self, dataset_name):
        """Return an empty tally; `dataset_name` is not used."""
        return EvalResults(correct_count=0, predict_count=0)

    def evaluate(self, batch, model_vars) -> EvalResults:
        """Score one batch.

        Args:
            batch: An `(inputs, labels)` pair.
            model_vars: Variables passed to `self.model.apply`.

        Returns:
            An `EvalResults` holding the number of correct predictions and
            the number of examples in this batch.
        """
        images, targets = batch
        scores = self.model.apply(model_vars, images)
        # The predicted class is the index of the largest logit.
        predicted = scores.argmax(axis=-1)
        hits = (predicted == targets).astype(np.int32)
        return EvalResults(
            correct_count=hits.sum(),
            predict_count=targets.shape[0],
        )
# Maps an evaluation dataset name to a callable producing its batches.
# `as_numpy_iterator` is stored without being called, so the consumer decides
# when to create an iterator over the dataset.
eval_batch_gens = {
    "test": get_eval_dataset(32).as_numpy_iterator,
}
# Evaluation task instance shared by the CronTab examples below.
evaler = EvalTask()
Downloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /home/docs/tensorflow_datasets/mnist/3.0.1...
Dl Completed...: 0%| | 0/5 [00:00<?, ? file/s]
Dl Completed...: 20%|██ | 1/5 [00:00<00:01, 3.90 file/s]
Dl Completed...: 20%|██ | 1/5 [00:00<00:01, 3.90 file/s]
Dl Completed...: 40%|████ | 2/5 [00:00<00:00, 3.90 file/s]
Dl Completed...: 60%|██████ | 3/5 [00:00<00:00, 3.90 file/s]
Dl Completed...: 80%|████████ | 4/5 [00:00<00:00, 7.83 file/s]
Dl Completed...: 80%|████████ | 4/5 [00:00<00:00, 7.83 file/s]
Dl Completed...: 100%|██████████| 5/5 [00:00<00:00, 7.86 file/s]
Dl Completed...: 100%|██████████| 5/5 [00:00<00:00, 7.86 file/s]
Dl Completed...: 100%|██████████| 5/5 [00:00<00:00, 7.35 file/s]
Dataset mnist downloaded and prepared to /home/docs/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.
Setup crontab#
Given the above models and tasks, we are now ready to actually write a training loop.
As a first example, we design our main loop to greet the user every 0.1 seconds using the CronTab.schedule method.
def say_hello(train_state, *, message: str, **kwargs):
    """Print a greeting with the current training step and wall-clock time.

    Args:
        train_state: Object exposing a `step` attribute.
        message: Text placed at the start of the printed line.
        **kwargs: Accepted and ignored, so extra context passed by the
            caller does not cause an error.
    """
    current_step = train_state.step
    timestamp = time.time()
    greeting = (
        f"{message} Training is currently at {current_step}-th step. {timestamp}"
    )
    print(greeting)
crontab = bobbin.CronTab()
# Run `say_hello` whenever at least 0.1 seconds have passed since its last run.
crontab.schedule(say_hello, time_interval=0.1)

# Split the root key once so that parameter initialization and the training
# steps use distinct keys; a JAX key should not be consumed and also reused
# as the parent of later splits.
init_key, prng_key = jax.random.split(jax.random.PRNGKey(0))
train_state = task.initialize_train_state(init_key, optax.sgd(0.01))
for batch in get_train_dataset(64).take(500).as_numpy_iterator():
    rng, prng_key = jax.random.split(prng_key)
    train_state, step_info = train_step_fn(train_state, batch, rng)
    # Extra keyword arguments (here `message`) are forwarded to the actions.
    crontab.run(train_state, message="Hello!!", is_train_state_replicated=False)
Hello!! Training is currently at 1-th step. 1689409022.1710103
Hello!! Training is currently at 19-th step. 1689409022.2746735
Hello!! Training is currently at 35-th step. 1689409022.3768137
Hello!! Training is currently at 50-th step. 1689409022.4794903
Hello!! Training is currently at 67-th step. 1689409022.5815556
Hello!! Training is currently at 83-th step. 1689409022.6842482
Hello!! Training is currently at 103-th step. 1689409022.7880812
Hello!! Training is currently at 121-th step. 1689409022.8893607
Hello!! Training is currently at 138-th step. 1689409022.9924822
Hello!! Training is currently at 154-th step. 1689409023.0935805
Hello!! Training is currently at 172-th step. 1689409023.1976123
Hello!! Training is currently at 190-th step. 1689409023.2999516
Hello!! Training is currently at 204-th step. 1689409023.403486
Hello!! Training is currently at 226-th step. 1689409023.5073564
Hello!! Training is currently at 244-th step. 1689409023.6129787
Hello!! Training is currently at 260-th step. 1689409023.7134995
Hello!! Training is currently at 277-th step. 1689409023.815819
Hello!! Training is currently at 294-th step. 1689409023.9218025
Hello!! Training is currently at 312-th step. 1689409024.0272682
Hello!! Training is currently at 329-th step. 1689409024.1277897
Hello!! Training is currently at 346-th step. 1689409024.232746
Hello!! Training is currently at 366-th step. 1689409024.3368897
Hello!! Training is currently at 382-th step. 1689409024.4413776
Hello!! Training is currently at 398-th step. 1689409024.5432444
Hello!! Training is currently at 416-th step. 1689409024.6452277
Hello!! Training is currently at 433-th step. 1689409024.7462764
Hello!! Training is currently at 448-th step. 1689409024.8471355
Hello!! Training is currently at 465-th step. 1689409024.951886
Hello!! Training is currently at 482-th step. 1689409025.0533895
Hello!! Training is currently at 500-th step. 1689409025.1602309
The first argument of CronTab.schedule is an “Action”, which can be any callable that can be invoked as f(train_state, **kwargs).
An action registered via CronTab.schedule is called when you call CronTab.run at the end of each training step, provided that the pre-specified condition is met.
In this case, the condition is satisfied when the time elapsed since the action was last executed is longer than 0.1 seconds.
(In other words, the action is executed only once even if a single step takes longer than 0.2 seconds.)
One can pass additional context information by adding keyword arguments to the call to CronTab.run.
CronTab is designed to be a hub that loosely connects the functionalities provided by other bobbin sub-modules.
For example, TrainTask provides an action that writes training logs to the logger, and EvalTask provides an action that runs the evaluation process over the datasets, as follows:
crontab = bobbin.CronTab()
# Write a training log entry, scheduled with both `at_step` and
# `step_interval` triggers.
crontab.schedule(
    task.make_log_writer(loglevel=logging.WARNING), at_step=123, step_interval=100
)
# Run the evaluation over `eval_batch_gens` on a 123-step interval, without
# writing TensorBoard summaries.
crontab.schedule(
    evaler.make_cron_action(eval_batch_gens, tensorboard_root_path=None),
    step_interval=123,
)

# Split the root key once so that parameter initialization and the training
# steps use distinct keys; a JAX key should not be consumed and also reused
# as the parent of later splits.
init_key, prng_key = jax.random.split(jax.random.PRNGKey(0))
train_state = task.initialize_train_state(init_key, optax.sgd(0.01))
for batch in get_train_dataset(64).take(500).as_numpy_iterator():
    rng, prng_key = jax.random.split(prng_key)
    train_state, step_info = train_step_fn(train_state, batch, rng)
    # `step_info` is passed along as context for the scheduled actions.
    crontab.run(train_state, step_info=step_info, is_train_state_replicated=False)
In this example, TrainTask.make_log_writer writes only a very simple log message; this can be customized by overriding the TrainTask.write_trainer_log method.
CronTab can also be used to tie the training loop to checkpoint writers. In the example below, we use two directories for storing checkpoints: one stores regular checkpoints for resuming the training process, and the other keeps the best-performing checkpoints for future use.
checkpoint_temp_dir = tempfile.TemporaryDirectory()
best_checkpoint_temp_dir = tempfile.TemporaryDirectory()
crontab = bobbin.CronTab()
crontab.schedule(
task.make_checkpoint_saver(checkpoint_temp_dir.name), step_interval=1000
)
crontab.schedule(
evaler.make_cron_action(
eval_batch_gens, tensorboard_root_path=None
).keep_best_checkpoint("test", best_checkpoint_temp_dir.name),
step_interval=1000,
)
prng_key = jax.random.PRNGKey(0)
train_state = task.initialize_train_state(jax.random.PRNGKey(0), optax.sgd(0.1))
for batch in get_train_dataset(64).take(5000).as_numpy_iterator():
rng, prng_key = jax.random.split(prng_key)
train_state, step_info = train_step_fn(train_state, batch, rng)
crontab.run(train_state, step_info=step_info, is_train_state_replicated=False)
print("Latest checkpoints:")
!ls {checkpoint_temp_dir.name}
print("Best checkpoints:")
!ls {best_checkpoint_temp_dir.name}
print("Results of the best checkpoint")
!cat {best_checkpoint_temp_dir.name}/results.json
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-9-37c635701446> in <module>
18 rng, prng_key = jax.random.split(prng_key)
19 train_state, step_info = train_step_fn(train_state, batch, rng)
---> 20 crontab.run(train_state, step_info=step_info, is_train_state_replicated=False)
21
22 print("Latest checkpoints:")
~/checkouts/readthedocs.org/user_builds/bobbin/envs/latest/lib/python3.8/site-packages/bobbin/cron.py in run(self, train_state, is_train_state_replicated, *args, **kwargs)
220 for name, trig, act in self._actions:
221 if trig.check(train_state):
--> 222 results[name] = act(train_state, *args, **kwargs)
223 return results
~/checkouts/readthedocs.org/user_builds/bobbin/envs/latest/lib/python3.8/site-packages/bobbin/training.py in save(train_state, **unused_kwargs)
317 # checkpointing configuration.
318 def save(train_state: TrainState, **unused_kwargs):
--> 319 checkpoints.save_checkpoint(
320 checkpoint_path, train_state, train_state.step, **save_args
321 )
~/checkouts/readthedocs.org/user_builds/bobbin/envs/latest/lib/python3.8/site-packages/flax/training/checkpoints.py in save_checkpoint(ckpt_dir, target, step, prefix, keep, overwrite, keep_every_n_steps, async_manager, orbax_checkpointer)
584 if not orbax_checkpointer:
585 orbax_checkpointer = orbax.Checkpointer(
--> 586 orbax.PyTreeCheckpointHandler(restore_with_serialized_types=False)
587 )
588 # Check singular target.
TypeError: __init__() got an unexpected keyword argument 'restore_with_serialized_types'