Training

BaseTrainTask()

Base class defining training task.

TrainTask(model, example_args[, ...])

Task definition for training of parameters of nn.Module.

TrainState(step, apply_fn, params, tx, ...)

Thin wrapper around flax.training.train_state.TrainState.

Train tasks

class bobbin.BaseTrainTask

Base class defining training task.

compute_loss(params, batch, *, extra_vars, prng_key, step)

Abstract method to be overridden for defining the loss function.

Return type

Tuple[Union[float, int], Tuple[Dict[str, Union[Array, ndarray, bool_, number, Iterable[ForwardRef], Mapping[Any, ForwardRef]]], Union[Array, ndarray, bool_, number, Iterable[ForwardRef], Mapping[Any, ForwardRef]]]]

make_training_step_fn()

Creates a training step function.

Return type

TrainingStepFnBuilder

reduce_extra_vars(colname, tree, *, axis_name)

Abstract method to be overridden for synchronizing non-parameter variables.

Return type

Union[Array, ndarray, bool_, number, Iterable[ForwardRef], Mapping[Any, ForwardRef]]

write_trainer_log(train_state, *, step_info, logger, loglevel)

Abstract method to be overridden for custom logging output.

make_log_writer(*, logger=None, loglevel=20)

Makes a logging action that can be registered in a CronTab.

Return type

Callable

make_checkpoint_saver(checkpoint_path, save_args=None)

Makes an action that saves checkpoints to the specified path.

class bobbin.TrainTask(model, example_args, example_kwargs=None, required_rngs=())

Task definition for training of parameters of nn.Module.

__init__(model, example_args, example_kwargs=None, required_rngs=())

Constructs the train task.

Parameters
  • model (Module) – flax model to be trained.

  • required_rngs (Iterable[str]) – the sequence of RNG names required for training. These names are used by TrainTask.get_rng_dict to simplify RNG handling in, for example, compute_loss.

property model: flax.linen.module.Module

Returns model given in the constructor.

Return type

Module

get_rng_dict(rng_key, extra_keys=())

Splits rng_key and returns an RNG for each name specified in the constructor (required_rngs).

Parameters
  • rng_key (PRNGKeyArray) – base RNG seed to be split.

  • extra_keys (Iterable[str]) – if set, additional RNG seeds are generated and stored to the return value with the provided keys.

Return type

Dict[str, PRNGKeyArray]

initialize_train_state(rng, tx, checkpoint_path=None, compile_init=None)

Initializes TrainState for this task.

Parameters
  • rng (PRNGKeyArray) – RNG seed for variable initialization.

  • tx (GradientTransformation) – optax gradient transformer attached to TrainState.

  • checkpoint_path (Union[str, bytes, PathLike, None]) – if set, this method first tries to restore the train state from a checkpoint in checkpoint_path.

  • compile_init (Optional[Callable]) – if set, self.model.init is wrapped by the given function as compile_init(self.model.init).

Return type

TrainState

Returns

initialized train state.

Train state

class bobbin.TrainState(step, apply_fn, params, tx, opt_state, extra_vars)

Thin wrapper around flax.training.train_state.TrainState.

This class adds an extra_vars field for handling mutable (non-trainable) variables.

property model_vars

Returns model variables by merging params and extra_vars.

is_replicated_for_pmap()

Checks whether this TrainState is replicated for pmap.

__init__(step, apply_fn, params, tx, opt_state, extra_vars)

replace(**updates)

Returns a new object replacing the specified fields with new values.
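A pure-Python sketch of the two TrainState behaviors described above, model_vars and replace. ToyTrainState is a hypothetical stand-in that omits apply_fn, tx, and opt_state. The layout of model_vars (parameters under a "params" key next to each extra_vars collection) is an assumption based on the usual Flax variable-collection convention, not something taken from bobbin's source.

```python
import dataclasses
from typing import Any, Dict


@dataclasses.dataclass(frozen=True)
class ToyTrainState:
    """Hypothetical stand-in for TrainState; apply_fn, tx and opt_state are omitted."""

    step: int
    params: Dict[str, Any]
    extra_vars: Dict[str, Any] = dataclasses.field(default_factory=dict)

    @property
    def model_vars(self) -> Dict[str, Any]:
        # Assumption: Flax-style variable dict, i.e. trainable parameters under
        # "params" plus one entry per mutable collection (e.g. "batch_stats").
        return {"params": self.params, **self.extra_vars}


state = ToyTrainState(step=0, params={"w": 1.0}, extra_vars={"batch_stats": {"mean": 0.0}})
# dataclasses.replace builds a new object; `state` itself is left unchanged.
next_state = dataclasses.replace(state, step=state.step + 1)
```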

Evaluation

EvalResults()

Evaluation results.

SampledSet(max_size[, values, priorities])

Immutable set containing a fixed number of samples drawn from the added elements.

EvalTask()

Base class defining evaluation task.

Evaluation results

class bobbin.EvalResults

Evaluation results.

to_log_message()

Format EvalResults to a logging-friendly string.

reduce(other)

Combines two EvalResults.

Return type

EvalResults

unshard_and_reduce()

Merges sharded results (typically obtained via pmap) using reduce.

Return type

EvalResults

is_better_than(other)

Returns True if self is better than other.

This method is used for keeping track of the best checkpoint. If such a comparison is not required, it can be left unimplemented.

Parameters

other (EvalResults) – Another instance of the same EvalResults class.

Return type

bool

Returns

A boolean value.

write_to_tensorboard(current_train_state, writer)

Writes a summary of this EvalResults to writer.

This method is the default writer used by tensorboard.make_eval_results_writer. If make_eval_results_writer is not used, or a custom method is passed to it, this method can be left unimplemented.

Parameters
  • current_train_state (TrainState) – TrainState used to obtain this EvalResults.

  • writer (SummaryWriter) – destination SummaryWriter.

__init__()

replace(**updates)

Returns a new object replacing the specified fields with new values.
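The EvalResults hooks are easiest to see on a concrete result type. The sketch below is hypothetical: ErrorRateResults is an illustrative class that does not subclass bobbin.EvalResults (so it runs without bobbin installed), and it only shows the intended semantics of reduce, is_better_than, and to_log_message for a simple error-rate metric.

```python
import dataclasses
from typing import List


@dataclasses.dataclass(frozen=True)
class ErrorRateResults:
    """Hypothetical EvalResults-style class for a classification error rate."""

    num_errors: int = 0
    num_examples: int = 0

    @property
    def error_rate(self) -> float:
        return self.num_errors / max(self.num_examples, 1)

    def reduce(self, other: "ErrorRateResults") -> "ErrorRateResults":
        # Counts are additive, so combining partial results is a field-wise sum.
        return ErrorRateResults(
            self.num_errors + other.num_errors,
            self.num_examples + other.num_examples,
        )

    def is_better_than(self, other: "ErrorRateResults") -> bool:
        # A lower error rate is better, e.g. when tracking the best checkpoint.
        return self.error_rate < other.error_rate

    def to_log_message(self) -> str:
        return f"error_rate={self.error_rate:.3f} ({self.num_examples} examples)"


shard_results: List[ErrorRateResults] = [ErrorRateResults(3, 10), ErrorRateResults(1, 10)]
combined = ErrorRateResults()
for partial in shard_results:
    combined = combined.reduce(partial)
```

Because the counts are additive, this reduce is associative, which is what lets per-batch or per-device results be combined in any grouping.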

class bobbin.SampledSet(max_size, values=(), priorities=())

Immutable set containing a fixed number of samples drawn from the added elements.

add(x)

Returns SampledSet with the given element x added to this set.

Return type

SampledSet[~T]

union(iterable)

Returns the union of this set and the given set.

Return type

SampledSet[~T]

__init__(max_size, values=(), priorities=())

replace(**updates)

Returns a new object replacing the specified fields with new values.
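The documentation does not say how SampledSet chooses which elements to keep. One standard way to maintain a fixed-size uniform sample, and the one assumed in this hypothetical sketch, is to give each element a uniform random priority and keep the max_size elements with the highest priorities. The values and priorities fields mirror the constructor signature above, but the sampling rule itself is an assumption.

```python
import dataclasses
import heapq
import random
from typing import Generic, Iterable, List, Tuple, TypeVar, Union

T = TypeVar("T")


@dataclasses.dataclass(frozen=True)
class ToySampledSet(Generic[T]):
    """Hypothetical fixed-size random sample (not bobbin's implementation).

    Each element carries a uniform random priority, and only the `max_size`
    elements with the highest priorities are kept.
    """

    max_size: int
    values: Tuple[T, ...] = ()
    priorities: Tuple[float, ...] = ()

    def add(self, x: T) -> "ToySampledSet[T]":
        return self._keep_top([(random.random(), x)])

    def union(self, other: Union["ToySampledSet[T]", Iterable[T]]) -> "ToySampledSet[T]":
        if isinstance(other, ToySampledSet):
            # Reuse the other sample's priorities, so merging two samples keeps
            # every original element equally likely to survive.
            candidates = list(zip(other.priorities, other.values))
        else:
            candidates = [(random.random(), x) for x in other]
        return self._keep_top(candidates)

    def _keep_top(self, candidates: List[Tuple[float, T]]) -> "ToySampledSet[T]":
        pool = list(zip(self.priorities, self.values)) + candidates
        top = heapq.nlargest(self.max_size, pool, key=lambda pv: pv[0])
        # Returns a new object; self is never modified.
        return ToySampledSet(
            self.max_size,
            values=tuple(v for _, v in top),
            priorities=tuple(p for p, _ in top),
        )


sample = ToySampledSet(max_size=3)
for i in range(10):
    sample = sample.add(i)
```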

Evaluation tasks

class bobbin.EvalTask

Base class defining evaluation task.

create_eval_results(dataset_name)

Initializes evaluation result.

Return type

EvalResults

evaluate(batch, *args, **kwargs)

Evaluates a single batch and returns EvalResults.

Return type

EvalResults

finalize_eval_results(metrics)

Finalizes eval metrics before they are stored or published to TensorBoard.

Return type

EvalResults

make_cron_action(batch_gens, *, tensorboard_root_path=None)

Makes a cron action function that runs the evaluation.
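The order in which these hooks are combined is not specified above. A plausible evaluation loop, sketched here in pure Python with hypothetical ToyEvalTask and CountResults classes, starts from create_eval_results, folds per-batch evaluate outputs together with reduce, and calls finalize_eval_results once at the end. make_cron_action presumably wires such a loop into a CronTab, but its internals are not modeled here.

```python
import dataclasses
from typing import Iterable, List, Tuple

Batch = List[Tuple[int, int]]  # hypothetical batch: (prediction, label) pairs


@dataclasses.dataclass(frozen=True)
class CountResults:
    correct: int = 0
    total: int = 0

    def reduce(self, other: "CountResults") -> "CountResults":
        return CountResults(self.correct + other.correct, self.total + other.total)


class ToyEvalTask:
    """Hypothetical task mirroring the EvalTask hooks (not bobbin code)."""

    def create_eval_results(self, dataset_name: str) -> CountResults:
        return CountResults()  # empty result, the identity element for reduce

    def evaluate(self, batch: Batch) -> CountResults:
        correct = sum(int(pred == label) for pred, label in batch)
        return CountResults(correct, len(batch))

    def finalize_eval_results(self, metrics: CountResults) -> CountResults:
        return metrics  # e.g. turn accumulated sums into averages here


def run_eval(task: ToyEvalTask, dataset_name: str, batches: Iterable[Batch]) -> CountResults:
    results = task.create_eval_results(dataset_name)
    for batch in batches:
        results = results.reduce(task.evaluate(batch))
    return task.finalize_eval_results(results)


results = run_eval(ToyEvalTask(), "valid", [[(1, 1), (0, 1)], [(2, 2)]])
```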

Crontab

CronTab()

Table that contains pairs of triggers and actions.

class bobbin.CronTab

Table that contains pairs of triggers and actions.

__init__()

schedule(action, *, name=None, **kwargs)

Schedules an action for the trigger specified via kwargs.

Parameters
  • action (Callable[…, Optional[Tuple[TrainState, …]]]) – a function f(train_state, …) that takes a TrainState and the additional arguments passed to CronTab.run.

  • name (Optional[str]) – name for the registered action-trigger pair.

  • kwargs

    the following trigger specifiers are currently supported:
    • step_interval=N : run the action every N steps.

    • at_step=N or at_step=(N1, N2, …) : run the action at step N, or at each of the steps N1, N2, ….

    • at_first_steps=N : run the action at each of the first N steps of training.

    • at_first_steps_of_process=N : run the action at each of the first N steps after the current training process started.

    • time_interval=X : run the action if at least X (float) seconds have passed since this trigger last invoked it.

Return type

CronTab

Returns

self, i.e. the updated CronTab.
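The trigger specifiers listed above can be read as predicates over the current step. The following sketch models each one as a small pure-Python function. The function names are hypothetical, and details such as whether step counting starts at 0 or when the time_interval timer starts are assumptions, marked in the comments.

```python
import time
from typing import Callable, Optional, Sequence, Union

Trigger = Callable[[int], bool]


def step_interval(n: int) -> Trigger:
    # Assumption: fires on every multiple of n, including step 0.
    return lambda step: step % n == 0


def at_step(steps: Union[int, Sequence[int]]) -> Trigger:
    targets = {steps} if isinstance(steps, int) else set(steps)
    return lambda step: step in targets


def at_first_steps(n: int) -> Trigger:
    # Assumption: training steps are counted from 0.
    return lambda step: step < n


def at_first_steps_of_process(n: int) -> Trigger:
    start: Optional[int] = None  # first step seen by this process

    def trigger(step: int) -> bool:
        nonlocal start
        if start is None:
            start = step
        return step - start < n

    return trigger


def time_interval(seconds: float, clock: Callable[[], float] = time.monotonic) -> Trigger:
    last = clock()  # Assumption: the timer starts when the trigger is created.

    def trigger(step: int) -> bool:
        nonlocal last
        now = clock()
        if now - last >= seconds:
            last = now
            return True
        return False

    return trigger
```

The difference between at_first_steps and at_first_steps_of_process shows up when training resumes from a checkpoint: the former looks at the global step, while the latter counts from whatever step the current process started at.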

add(name, trigger, action)

Adds a trigger-action pair.

Return type

CronTab

run(train_state, is_train_state_replicated=True, *args, **kwargs)

Checks the triggers and runs the corresponding actions if needed.

Parameters
  • train_state (TrainState) – current training state.

  • is_train_state_replicated (bool) – if True (by default), train_state is unreplicated before being passed to the actions.

  • args – extra parameters passed to the actions.

  • kwargs – extra parameters passed to the actions.

Return type

Dict[str, Any]
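How add and run fit together can be illustrated with a hypothetical ToyCronTab. It is a simplification: run receives the step number directly instead of a TrainState, each action receives that step, and unreplication and bobbin's own trigger objects are left out. Returning self from add is what makes chained registration possible, matching the CronTab return type of schedule and add. Keying run's result dict by action name is an assumption about what the Dict[str, Any] return value holds.

```python
from typing import Any, Callable, Dict, List, Tuple

Trigger = Callable[[int], bool]
Action = Callable[..., Any]


class ToyCronTab:
    """Hypothetical trigger/action table (not bobbin's implementation)."""

    def __init__(self) -> None:
        self._entries: List[Tuple[str, Trigger, Action]] = []

    def add(self, name: str, trigger: Trigger, action: Action) -> "ToyCronTab":
        self._entries.append((name, trigger, action))
        return self  # returning self allows chained registration

    def run(self, step: int, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        # Simplification: takes the step directly instead of a TrainState.
        results: Dict[str, Any] = {}
        for name, trigger, action in self._entries:
            if trigger(step):
                results[name] = action(step, *args, **kwargs)
        return results


logged_steps: List[int] = []


def log_action(step: int) -> str:
    logged_steps.append(step)
    return "logged"


crontab = (
    ToyCronTab()
    .add("log", lambda step: step % 10 == 0, log_action)
    .add("save", lambda step: step == 30, lambda step: "saved")
)
outputs = [crontab.run(step) for step in range(1, 31)]
```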

TensorBoard

NullSummaryWriter()

Null-object counterpart of flax.metrics.tensorboard.SummaryWriter.

ImageSummary(image)

MplImageSummary(image[, h_paddings, ...])

MultiDirectorySummaryWriter(log_dir_root, *)

SummaryWriter that changes the destination depending on the tag.

ScalarSummary(value)

ThreadedSummaryWriter(base_writer[, ...])

SummaryWriter that writes summaries in a separate thread.

publish_train_intermediates(writer, tree, ...)

Writes variables specified in the args to SummaryWriter writer.

publish_trainer_env_info(writer, train_state, *)

Publishes environment information to the TensorBoard "Text" section.

Summary writers

class bobbin.NullSummaryWriter

Null-object counterpart of flax.metrics.tensorboard.SummaryWriter.

__init__()

class bobbin.MultiDirectorySummaryWriter(log_dir_root, *, keys=(), allow_new_keys=True, only_from_leader_process=True, auto_flush=True, tag_to_key=<function _default_key_from_tag>, dirname=<function _default_dirname_from_key>, use_threaded_writer=True)

SummaryWriter that changes the destination depending on the tag. Actual use cases involve wrapping writer functions such as scalar.

__init__(log_dir_root, *, keys=(), allow_new_keys=True, only_from_leader_process=True, auto_flush=True, tag_to_key=<function _default_key_from_tag>, dirname=<function _default_dirname_from_key>, use_threaded_writer=True)

Constructs MultiDirectorySummaryWriter.

Parameters
  • log_dir_root – root directory for this MultiDirectorySummaryWriter. It can be anything that can be passed to epath.Path.

  • keys (Iterable[str]) – pre-specified set of keys.

  • only_from_leader_process – if True (by default), writers do not actually write summaries in processes where jax.process_index() != 0.

  • auto_flush (bool) – if True, sub-writers are instantiated with auto_flush=True argument.

  • tag_to_key (Callable[[str], Tuple[str, str]]) – a function that extracts, from each tag, the sub-writer name and the tag name used within that sub-writer. The default is a function equivalent to lambda s: s.split('/', maxsplit=1).

  • dirname (Callable[[str], str]) – a function that converts keys to the directory names under the root directory.
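The routing performed by MultiDirectorySummaryWriter can be sketched without TensorBoard: tag_to_key splits a tag such as "train/loss" into a sub-writer key and a sub-tag, and dirname maps the key to a directory under log_dir_root. route_scalars is a hypothetical helper. default_tag_to_key follows the documented default, while the identity dirname is an assumption, since the real default is only shown as _default_dirname_from_key.

```python
from pathlib import PurePosixPath
from typing import Callable, Dict, List, Tuple


def default_tag_to_key(tag: str) -> Tuple[str, str]:
    # Equivalent to the documented default, lambda s: s.split('/', maxsplit=1).
    # Tags without "/" raise ValueError here; bobbin's handling of them is not documented.
    key, sub_tag = tag.split("/", maxsplit=1)
    return key, sub_tag


def route_scalars(
    log_dir_root: str,
    scalars: Dict[str, float],
    tag_to_key: Callable[[str], Tuple[str, str]] = default_tag_to_key,
    dirname: Callable[[str], str] = lambda key: key,  # hypothetical identity default
) -> Dict[str, List[Tuple[str, float]]]:
    """Groups (sub_tag, value) pairs by the directory of the sub-writer that would receive them."""
    routed: Dict[str, List[Tuple[str, float]]] = {}
    for tag, value in scalars.items():
        key, sub_tag = tag_to_key(tag)
        directory = str(PurePosixPath(log_dir_root) / dirname(key))
        routed.setdefault(directory, []).append((sub_tag, value))
    return routed


routes = route_scalars("/logs", {"train/loss": 0.5, "eval/loss": 0.7, "train/lr": 0.1})
```

Writing "train/loss" and "eval/loss" to separate directories under the same sub-tag "loss" is what lets TensorBoard overlay them as two runs on one chart.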

class bobbin.ThreadedSummaryWriter(base_writer, max_workers=1, wait_on_flush=False)

SummaryWriter that writes summaries in a separate thread.

This writer defers write operations (including flush operations) and returns from each call as soon as possible. However, by design, the transfer from devices to CPU memory is not deferred, so the latency of method calls is dominated by that memory transfer. This is intentional: it lets devices release the memory used for storing summaries as soon as the write method is called.

__init__(base_writer, max_workers=1, wait_on_flush=False)

Constructs ThreadedSummaryWriter that wraps base_writer.

Parameters
  • base_writer (SummaryWriter) – SummaryWriter or a compatible instance to be wrapped.

  • max_workers (int) – the number of I/O threads.

  • wait_on_flush (bool) – if True, ThreadedSummaryWriter.flush waits until all pending write operations have finished. Otherwise (by default), the flush operation is also deferred and executed eventually.
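A minimal sketch of the deferral pattern described above, built on concurrent.futures.ThreadPoolExecutor. ToyThreadedWriter and RecordingWriter are hypothetical; float(value) stands in for the synchronous device-to-host transfer, and only the call into the base writer is deferred.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, List, Tuple


class RecordingWriter:
    """Stand-in for a SummaryWriter that only records the calls it receives."""

    def __init__(self) -> None:
        self.scalars: List[Tuple[str, float, int]] = []
        self.flush_count = 0

    def scalar(self, tag: str, value: float, step: int) -> None:
        self.scalars.append((tag, value, step))

    def flush(self) -> None:
        self.flush_count += 1


class ToyThreadedWriter:
    """Hypothetical sketch of deferring writes to worker threads."""

    def __init__(self, base_writer: Any, max_workers: int = 1, wait_on_flush: bool = False) -> None:
        self._base = base_writer
        self._executor = ThreadPoolExecutor(max_workers=max_workers)
        self._wait_on_flush = wait_on_flush
        self._pending: List[Future] = []

    def scalar(self, tag: str, value: Any, step: int) -> None:
        # The conversion to a host value happens synchronously (standing in for
        # the device-to-CPU transfer); only the actual write is deferred.
        host_value = float(value)
        self._pending.append(self._executor.submit(self._base.scalar, tag, host_value, step))

    def flush(self) -> None:
        self._pending.append(self._executor.submit(self._base.flush))
        if self._wait_on_flush:
            for future in self._pending:
                future.result()  # re-raises any exception from the worker
            self._pending.clear()
        else:
            # Drop finished futures so the list does not grow without bound.
            self._pending = [f for f in self._pending if not f.done()]

    def close(self) -> None:
        self._executor.shutdown(wait=True)


base = RecordingWriter()
writer = ToyThreadedWriter(base, wait_on_flush=True)
writer.scalar("loss", 0.5, step=1)
writer.scalar("loss", 0.25, step=2)
writer.flush()
writer.close()
```

With max_workers=1 a single thread executes the queued writes, so they reach the base writer in submission order; with more workers, that ordering is no longer guaranteed.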

Summary variable wrappers

class bobbin.ImageSummary(image)

__init__(image)

class bobbin.MplImageSummary(image, h_paddings=None, v_paddings=None, cmap=None, interpolation=None, aspect=None, origin=None, with_colorbar=False)

__init__(image, h_paddings=None, v_paddings=None, cmap=None, interpolation=None, aspect=None, origin=None, with_colorbar=False)

class bobbin.ScalarSummary(value)

__init__(value)

Publish functions

bobbin.publish_train_intermediates(writer, tree, step, *, prefix='summary/')

Writes variables specified in the args to SummaryWriter writer.

Currently, this function only supports scalar summaries.

Parameters
  • writer (SummaryWriter) – Destination as flax.metrics.tensorboard.SummaryWriter.

  • tree (Union[Array, ndarray, bool_, number, Iterable[ForwardRef], Mapping[Any, ForwardRef]]) – Source of summary variables, typically a training state or a variable collection.

  • step (int) – The step number of summary information.

  • scalar_selector – Selector for extracting scalars to be published. By default, variables with “:scalar” suffix will be selected.

  • scalar_tag_rewriter – A function that converts variable paths to a tag name used in TensorBoard. By default, tag names are defined as a name of last path component without “:scalar” suffix.

Return type

None
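The default selection and tag rewriting described above can be sketched on nested dicts. select_scalar_summaries is a hypothetical helper, and prepending prefix to the rewritten tag is an assumption based on the prefix='summary/' parameter in the signature.

```python
from typing import Any, Dict, Iterator, Tuple

SCALAR_SUFFIX = ":scalar"


def iter_leaf_paths(tree: Any, path: str = "") -> Iterator[Tuple[str, Any]]:
    # Walks nested dicts; any non-dict value is treated as a leaf.
    if isinstance(tree, dict):
        for key, child in tree.items():
            yield from iter_leaf_paths(child, f"{path}/{key}" if path else str(key))
    else:
        yield path, tree


def select_scalar_summaries(tree: Any, prefix: str = "summary/") -> Dict[str, float]:
    """Hypothetical re-creation of the documented default selector and tag rewriter."""
    tags: Dict[str, float] = {}
    for path, value in iter_leaf_paths(tree):
        last = path.rsplit("/", 1)[-1]
        if last.endswith(SCALAR_SUFFIX):  # default selector: ":scalar" suffix
            # Default rewriter: last path component without the suffix.
            tags[prefix + last[: -len(SCALAR_SUFFIX)]] = float(value)
    return tags


intermediates = {
    "encoder": {"loss:scalar": 1.5, "hidden": [0.1, 0.2]},
    "accuracy:scalar": 0.9,
}
scalars = select_scalar_summaries(intermediates)
```

Because only the last path component is kept, two variables with the same name in different submodules would map to the same tag under this default.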

bobbin.publish_trainer_env_info(writer, train_state, *, prefix='trainer/diagnosis/', also_do_logging=True, loglevel=20)

Publishes environment information to the TensorBoard “Text” section.

Currently, this function publishes the following information.

  • The numbers of parameters and extra variables in the model.

  • Shape information of the parameter tree.

  • str(jax.local_devices())

  • List of sys.argv elements.

Parameters
  • writer (SummaryWriter) – An instance of flax.metrics.tensorboard.SummaryWriter.

  • train_state (TrainState) – An instance of flax.training.train_state.TrainState that contains parameters.

  • prefix (str) – Prefix to tag names to be published.

  • also_do_logging (bool) – If True (default), published text data is also logged.

  • loglevel (int) – Log level used when also_do_logging == True.

Return type

None

Pmap utils

tpmap(f, axis_name, argtypes[, kwargtypes, ...])

Transparent pmap (tpmap).

unshard(tree)

Return type

List[Union[Array, ndarray, bool_, number, Iterable[ForwardRef], Mapping[Any, ForwardRef]]]

gather_from_jax_processes(v)

Gathers arbitrary trees from distributed processes.

assert_replica_integrity(tree, *[, ...])

Checks that replicas have the same values across devices and processes.

bobbin.tpmap(f, axis_name, argtypes, kwargtypes=None, *, devices=None, backend=None, wrap_return=None, **kwargs)

Transparent pmap (tpmap).

This function wraps a JAX function so that it can be transparently executed on multiple devices. It wraps f with jax.pmap and applies argument and return-value wrappers to keep the API compatible with the original function.

Parameters
  • f (Callable) – Function to be wrapped.

  • axis_name (str) – axis_name used in jax.pmap.

  • argtypes (Sequence[Union[str, ArgType]]) – ArgType instances, or their string representations, describing how each argument is distributed across devices.

  • kwargtypes (Optional[Mapping[str, Union[str, ArgType]]]) – argtypes for keyword arguments.

  • devices (Optional[Any]) – List of devices to use. By default, all devices returned by jax.local_devices for the specified backend are used.

  • backend (Optional[str]) – Backend for computation. By default, jax.default_backend will be used.

  • wrap_return (Optional[Callable[[Union[Array, ndarray, bool_, number, Iterable[ForwardRef], Mapping[Any, ForwardRef]]], Union[Array, ndarray, bool_, number, Iterable[ForwardRef], Mapping[Any, ForwardRef]]]]) – Wrapper for the return value. If None (default), no wrapper is applied, so the return values have an extra leading axis indexing the local devices.

Return type

Callable

Returns

Wrapped function that runs on multiple devices.

bobbin.unshard(tree)

Return type

List[Union[Array, ndarray, bool_, number, Iterable[ForwardRef], Mapping[Any, ForwardRef]]]
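unshard has no description above. Judging from its name and its List[…] return type, it presumably turns a tree whose leaves carry a leading device axis (for example, pmap outputs) into a list with one tree per device. This pure-Python sketch implements that assumed behavior for nested dicts of lists; unshard_sketch is hypothetical.

```python
from typing import Any, Iterator, List


def _leaves(tree: Any) -> Iterator[Any]:
    if isinstance(tree, dict):
        for child in tree.values():
            yield from _leaves(child)
    else:
        yield tree


def _take(tree: Any, index: int) -> Any:
    # Selects entry `index` along the leading (device) axis of every leaf.
    if isinstance(tree, dict):
        return {key: _take(child, index) for key, child in tree.items()}
    return tree[index]


def unshard_sketch(tree: Any) -> List[Any]:
    sizes = {len(leaf) for leaf in _leaves(tree)}
    if len(sizes) != 1:
        raise ValueError(f"expected one common leading-axis size, got {sizes}")
    (num_devices,) = sizes
    return [_take(tree, i) for i in range(num_devices)]


# Per-device outputs for two devices, stacked along the leading axis.
sharded = {"loss": [0.25, 0.75], "metrics": {"count": [3, 5]}}
per_device = unshard_sketch(sharded)
```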

bobbin.gather_from_jax_processes(v)

Gathers arbitrary trees from distributed processes.

Return type

List[Union[Array, ndarray, bool_, number, Iterable[ForwardRef], Mapping[Any, ForwardRef]]]

bobbin.assert_replica_integrity(tree, *, is_device_replicated=True, atol=1e-05, rtol=1e-05, backend='cpu')

Checks that replicas have the same values, within the given tolerances, across devices and processes.

Parameters
  • tree (Union[Array, ndarray, bool_, number, Iterable[ForwardRef], Mapping[Any, ForwardRef]]) – Values to be verified.

  • is_device_replicated (bool) – If True (by default), it assumes that tree is already replicated over devices and has the leading axis corresponding to the local device.

  • atol – absolute tolerance passed to np.testing.assert_allclose.

  • rtol – relative tolerance passed to np.testing.assert_allclose.

  • backend (str) – backend used for collective operations. It is strongly recommended to keep the default value (“cpu”), since using an accelerator backend can require a huge amount of device memory.

Return type

None
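The tolerance check behind assert_replica_integrity can be sketched in pure Python for replicas given as flat lists. This hypothetical helper performs no cross-process gathering; it only applies the per-element criterion documented for np.testing.assert_allclose, comparing every replica against the first one.

```python
from typing import Sequence


def assert_replicas_close(
    replicas: Sequence[Sequence[float]], atol: float = 1e-5, rtol: float = 1e-5
) -> None:
    """Hypothetical check that every replica matches replica 0 within tolerance.

    Uses the same per-element criterion as np.testing.assert_allclose:
    |actual - desired| <= atol + rtol * |desired|.
    """
    reference = replicas[0]
    for index, replica in enumerate(replicas[1:], start=1):
        if len(replica) != len(reference):
            raise AssertionError(f"replica {index} has a different length")
        for actual, desired in zip(replica, reference):
            if abs(actual - desired) > atol + rtol * abs(desired):
                raise AssertionError(f"replica {index} differs: {actual} vs {desired}")


# Two devices holding (almost) identical copies of the same parameters.
assert_replicas_close([[1.0, 2.0], [1.0, 2.000001]])
```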

Var utils

flatten_with_paths(node, *[, is_leaf])

Returns an iterator for leaves in the tree and their paths.

nested_vars_to_paths(node, *[, pathsep, is_leaf])

Constructs a tree with the same structure but containing path names as leaves.

dump_pytree_json(tree)

Return type

str

parse_pytree_json(json_str, template)

Return type

Union[Array, ndarray, bool_, number, Iterable[ForwardRef], Mapping[Any, ForwardRef]]

read_pytree_json_file(path, template)

Return type

Union[Array, ndarray, bool_, number, Iterable[ForwardRef], Mapping[Any, ForwardRef], None]

write_pytree_json_file(path, tree)

Return type

None

summarize_shape(tree)

Returns a string that summarizes shapes and dtypes of the tree.

total_dimensionality(tree)

Returns total dimensionality of the variables in the given tree.

Path

bobbin.flatten_with_paths(node, *, is_leaf=None)

Returns an iterator for leaves in the tree and their paths.

Return type

Iterator[Tuple[str, Union[Array, ndarray, bool_, number]]]

bobbin.nested_vars_to_paths(node, *, pathsep='/', is_leaf=None)

Constructs a tree with the same structure but containing path names as leaves.

Return type

Union[Array, ndarray, bool_, number, Iterable[ForwardRef], Mapping[Any, ForwardRef]]
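The path conventions of flatten_with_paths and nested_vars_to_paths can be illustrated on plain nested dicts and lists. Both functions below are hypothetical sketches; the path format (components joined with pathsep, list indices used as components, no leading separator) is an assumption rather than bobbin's documented output.

```python
from typing import Any, Iterator, Tuple


def _join(prefix: str, key: Any, pathsep: str) -> str:
    return f"{prefix}{pathsep}{key}" if prefix else str(key)


def flatten_with_paths_sketch(node: Any, pathsep: str = "/", prefix: str = "") -> Iterator[Tuple[str, Any]]:
    # Dicts and lists/tuples are containers; everything else is a leaf.
    if isinstance(node, dict):
        children = node.items()
    elif isinstance(node, (list, tuple)):
        children = enumerate(node)
    else:
        yield prefix, node
        return
    for key, child in children:
        yield from flatten_with_paths_sketch(child, pathsep, _join(prefix, key, pathsep))


def nested_vars_to_paths_sketch(node: Any, pathsep: str = "/", prefix: str = "") -> Any:
    # Same structure as node, with every leaf replaced by its path string.
    if isinstance(node, dict):
        return {
            k: nested_vars_to_paths_sketch(v, pathsep, _join(prefix, k, pathsep))
            for k, v in node.items()
        }
    if isinstance(node, (list, tuple)):
        return type(node)(
            nested_vars_to_paths_sketch(v, pathsep, _join(prefix, i, pathsep))
            for i, v in enumerate(node)
        )
    return prefix


params = {"dense": {"kernel": 1.0, "bias": 0.0}, "stats": [3, 4]}
```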

JSON I/O

bobbin.dump_pytree_json(tree)

Return type

str

bobbin.parse_pytree_json(json_str, template)

Return type

Union[Array, ndarray, bool_, number, Iterable[ForwardRef], Mapping[Any, ForwardRef]]

bobbin.read_pytree_json_file(path, template)

Return type

Union[Array, ndarray, bool_, number, Iterable[ForwardRef], Mapping[Any, ForwardRef], None]

bobbin.write_pytree_json_file(path, tree)

Return type

None
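The JSON helpers above are documented only by their signatures. One plausible role for the template argument is restoring information that JSON loses, such as tuples (written as lists) and leaf types like float versus int. The functions below are hypothetical sketches of that idea; returning None for a missing file is also an assumption, chosen to match the Optional return type of read_pytree_json_file.

```python
import json
import pathlib
from typing import Any, Optional, Union


def dump_tree_json_sketch(tree: Any) -> str:
    # JSON has no tuple type, so tuples are written as lists.
    return json.dumps(tree)


def _restore(value: Any, template: Any) -> Any:
    # Rebuilds containers and leaf types from the template's structure.
    if isinstance(template, dict):
        return {key: _restore(value[key], sub) for key, sub in template.items()}
    if isinstance(template, (list, tuple)):
        return type(template)(_restore(v, t) for v, t in zip(value, template))
    return type(template)(value)


def parse_tree_json_sketch(json_str: str, template: Any) -> Any:
    return _restore(json.loads(json_str), template)


def read_tree_json_file_sketch(path: Union[str, pathlib.Path], template: Any) -> Optional[Any]:
    p = pathlib.Path(path)
    if not p.exists():
        return None  # assumption: one reason the return type above is Optional
    return parse_tree_json_sketch(p.read_text(), template)


template = {"step": 0, "learning_rate": 0.0, "shape": (0, 0)}
text = dump_tree_json_sketch({"step": 10, "learning_rate": 1, "shape": (3, 4)})
restored = parse_tree_json_sketch(text, template)
```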

Summarization

bobbin.summarize_shape(tree)

Returns a string that summarizes shapes and dtypes of the tree.

Return type

str

bobbin.total_dimensionality(tree)

Returns total dimensionality of the variables in the given tree.

Return type

int
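Assuming that "total dimensionality" means the total number of scalar elements across all leaves (consistent with publish_trainer_env_info reporting parameter counts), it is the sum over leaves of the product of each leaf's shape. The sketch below uses hypothetical (shape, dtype) descriptions in place of real arrays, and its summary format is illustrative only.

```python
import math
from typing import Dict, Tuple

# Hypothetical leaf description used instead of real arrays: (shape, dtype name).
Leaf = Tuple[Tuple[int, ...], str]


def total_dimensionality_sketch(tree: Dict[str, Leaf]) -> int:
    # Number of scalars per leaf is the product of its shape (1 for shape ()).
    return sum(math.prod(shape) for shape, _ in tree.values())


def summarize_shape_sketch(tree: Dict[str, Leaf]) -> str:
    # One "path: dtype[dims]" line per leaf; the format is illustrative only.
    return "\n".join(f"{path}: {dtype}{list(shape)}" for path, (shape, dtype) in tree.items())


layer = {"dense/kernel": ((784, 128), "float32"), "dense/bias": ((128,), "float32")}
```

For this layer, the kernel contributes 784 × 128 = 100352 elements and the bias 128, for a total of 100480.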