Training
Train tasks
- class bobbin.BaseTrainTask
Base class defining a training task.
- compute_loss(params, batch, *, extra_vars, prng_key, step)
Abstract method to be overridden to define the loss function.
- reduce_extra_vars(colname, tree, *, axis_name)
Abstract method to be overridden to synchronize non-parameter variables.
- write_trainer_log(train_state, *, step_info, logger, loglevel)
Abstract method to be overridden for custom logging output.
- class bobbin.TrainTask(model, example_args, example_kwargs=None, required_rngs=())
Task definition for training the parameters of an nn.Module.
- __init__(model, example_args, example_kwargs=None, required_rngs=())
Constructs the train task.
- Parameters
model (Module) – flax model to be trained.
required_rngs (Iterable[str]) – the sequence of RNG names required for training. These names are used by TrainTask.get_rng_dict to simplify RNG handling, for example, in compute_loss.
- property model: flax.linen.module.Module
Returns the model given to the constructor.
- Return type
Module
- get_rng_dict(rng_key, extra_keys=())
Splits rng_key and returns an RNG for each name specified in the constructor.
- Parameters
rng_key (PRNGKeyArray) – base RNG seed to be split.
extra_keys (Iterable[str]) – if set, additional RNG seeds are generated and stored in the return value under the provided keys.
- Return type
Dict[str, PRNGKeyArray]
- initialize_train_state(rng, tx, checkpoint_path=None, compile_init=None)
Initializes TrainState for this task.
- Parameters
rng (PRNGKeyArray) – RNG seed for variable initialization.
tx (GradientTransformation) – optax gradient transformation attached to the TrainState.
checkpoint_path (Union[str, bytes, PathLike, None]) – if set, this method first tries to deserialize the state from a checkpoint in checkpoint_path.
compile_init (Optional[Callable]) – if set, self.model.init is wrapped with the given function, as compile_init(self.model.init).
- Return type
TrainState
- Returns
the initialized train state.
Train state
- class bobbin.TrainState(step, apply_fn, params, tx, opt_state, extra_vars)
Thin wrapper around flax.training.train_state.TrainState.
This class adds the extra_vars field for handling mutable (non-trainable) variables.
- property model_vars
Returns the model variables obtained by merging params and extra_vars.
- __init__(step, apply_fn, params, tx, opt_state, extra_vars)
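The merge that model_vars performs can be pictured with the minimal sketch below. It assumes Flax's usual variable-dict layout: trainable parameters live under the "params" collection, and mutable collections (for example "batch_stats") sit alongside it. The helper merge_model_vars is hypothetical, and plain dicts stand in for the real pytrees.

```python
from typing import Any, Dict


def merge_model_vars(params: Any, extra_vars: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper that builds a Flax-style variable dict.

    Assumes extra_vars maps collection names (e.g. "batch_stats") to their
    contents and does not itself contain a "params" collection.
    """
    if "params" in extra_vars:
        raise ValueError("extra_vars must not contain a 'params' collection")
    # Trainable parameters go under "params"; every mutable collection is
    # placed next to it unchanged.
    return {"params": params, **extra_vars}


params = {"dense": {"kernel": [[0.1, 0.2]], "bias": [0.0]}}
extra_vars = {"batch_stats": {"bn": {"mean": [0.0], "var": [1.0]}}}
model_vars = merge_model_vars(params, extra_vars)
print(sorted(model_vars))  # ['batch_stats', 'params']
```

A dict in this shape is the kind of variables argument that flax.linen.Module.apply accepts as its first argument.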
Evaluation
Evaluation results
- class bobbin.EvalResults
Evaluation results.
- unshard_and_reduce()
Merges sharded results (typically obtained via pmap) into a single result by reduction.
- Return type
- is_better_than(other)
Returns True if self is better than other.
This method is used to keep track of the best checkpoint. If such a comparison is not required, this method can be left unimplemented.
- Parameters
other (EvalResults) – another instance of the same EvalResults class.
- Return type
bool
- Returns
A boolean value.
- write_to_tensorboard(current_train_state, writer)
Writes a summary of this EvalResults to writer.
This method is the default writer used by tensorboard.make_eval_results_writer. If make_eval_results_writer isn't used, or a custom method is passed to it, this method can be left unimplemented.
- Parameters
current_train_state (TrainState) – TrainState used to obtain this EvalResults.
writer (SummaryWriter) – destination SummaryWriter.
- __init__()
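The sketch below shows the contract an EvalResults subclass is expected to fulfil. It does not import bobbin. AccuracyResults, its reduce helper and track_best are hypothetical names. reduce stands in for the per-shard reduction that unshard_and_reduce performs, and is_better_than ranks results by accuracy so that the best checkpoint can be tracked.

```python
from __future__ import annotations

import dataclasses
from typing import Iterable, Optional


@dataclasses.dataclass(frozen=True)
class AccuracyResults:
    """Hypothetical stand-in for an EvalResults subclass."""

    num_correct: int = 0
    num_total: int = 0

    @property
    def accuracy(self) -> float:
        return self.num_correct / self.num_total if self.num_total else 0.0

    def reduce(self, other: AccuracyResults) -> AccuracyResults:
        # Counts from different shards (or batches) are simply summed.
        return AccuracyResults(
            self.num_correct + other.num_correct,
            self.num_total + other.num_total,
        )

    def is_better_than(self, other: AccuracyResults) -> bool:
        # Higher accuracy is better; this drives best-checkpoint tracking.
        return self.accuracy > other.accuracy


def track_best(history: Iterable[AccuracyResults]) -> Optional[AccuracyResults]:
    """Returns the best results seen so far according to is_better_than."""
    best: Optional[AccuracyResults] = None
    for results in history:
        if best is None or results.is_better_than(best):
            best = results
    return best


shards = [AccuracyResults(8, 10), AccuracyResults(9, 10)]
merged = shards[0].reduce(shards[1])
print(merged.accuracy)  # 0.85
```

Defining the comparison on the results object keeps the trainer generic: it only needs is_better_than, not knowledge of which metric matters.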
- class bobbin.SampledSet(max_size, values=(), priorities=())
Immutable set containing up to max_size samples drawn from the elements added.
- add(x)
Returns a SampledSet with the given element x added to this set.
- Return type
SampledSet[T]
- union(iterable)
Returns the union of this set and the given iterable.
- Return type
SampledSet[T]
- __init__(max_size, values=(), priorities=())
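The constructor's values and priorities arguments suggest a priority-sampling design. The sketch below assumes that design and is not bobbin's implementation. Each added element receives a uniform random priority, and only the max_size elements with the highest priorities are kept. SketchSampledSet is a hypothetical name, and its union takes another sketch set rather than an arbitrary iterable.

```python
from __future__ import annotations

import random
from typing import Generic, Optional, Tuple, TypeVar

T = TypeVar("T")


class SketchSampledSet(Generic[T]):
    """Hypothetical immutable sample of at most max_size elements.

    Assumption: every added element gets a uniform random priority and
    only the max_size elements with the highest priorities are kept.
    """

    def __init__(self, max_size: int, values: Tuple[T, ...] = (),
                 priorities: Tuple[float, ...] = ()) -> None:
        if len(values) != len(priorities):
            raise ValueError("values and priorities must have the same length")
        # Sort by priority only, so values never need to be comparable.
        ranked = sorted(zip(priorities, values), key=lambda pv: pv[0],
                        reverse=True)[:max_size]
        self.max_size = max_size
        self.priorities: Tuple[float, ...] = tuple(p for p, _ in ranked)
        self.values: Tuple[T, ...] = tuple(v for _, v in ranked)

    def add(self, x: T, rng: Optional[random.Random] = None) -> SketchSampledSet[T]:
        # Returns a new set; self is left unchanged.
        priority = random.random() if rng is None else rng.random()
        return SketchSampledSet(self.max_size, self.values + (x,),
                                self.priorities + (priority,))

    def union(self, other: SketchSampledSet[T]) -> SketchSampledSet[T]:
        # Keeping priorities makes merging exact: the top-k of two top-k
        # samples equals the top-k of the combined stream.
        return SketchSampledSet(self.max_size, self.values + other.values,
                                self.priorities + other.priorities)

    def __len__(self) -> int:
        return len(self.values)


rng = random.Random(0)
sample: SketchSampledSet[int] = SketchSampledSet(max_size=3)
for i in range(10):
    sample = sample.add(i, rng)
print(len(sample), sample.values)
```

Storing priorities alongside values is what would let samples collected separately (for example, on different devices) be merged without biasing the result.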
Evaluation tasks
- class bobbin.EvalTask
Base class defining an evaluation task.
- evaluate(batch, *args, **kwargs)
Evaluates a single batch and returns EvalResults.
- Return type
EvalResults
Crontab
- class bobbin.CronTab
Table that contains pairs of triggers and actions.
- schedule(action, *, name=None, **kwargs)
Schedules an action for the trigger specified via kwargs.
- Parameters
action (Callable[…, Optional[Tuple[TrainState, …]]]) – a function f(train_state, …) that takes a TrainState and the additional arguments specified in CronTab.run.
name (Optional[str]) – name for the registered action-trigger pair.
kwargs – the following trigger specifiers are currently supported:
step_interval=N : perform the action every N steps.
at_step=N or at_step=(N1, N2, …) : perform the action at step N, or at each of steps N1, N2, …
at_first_steps=N : perform the action at each of the first N steps of training.
at_first_steps_of_process=N : perform the action at each of the first N steps since the current training process started.
time_interval=X : perform the action if X (float) seconds have passed since this trigger last invoked the action.
- Return type
CronTab
- Returns
self, the updated crontab.
- run(train_state, is_train_state_replicated=True, *args, **kwargs)
Checks the triggers and runs the corresponding actions if needed.
- Parameters
train_state (TrainState) – current training state.
is_train_state_replicated (bool) – if True (default), train_state is unreplicated before being passed to the actions.
args – extra positional arguments passed to the actions.
kwargs – extra keyword arguments passed to the actions.
- Return type
Dict[str, Any]
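The trigger specifiers listed for schedule can be modelled as predicates on the step number. The sketch below is hypothetical. make_trigger and SketchCronTab are not bobbin APIs. The step argument stands in for train_state. It assumes that steps are counted from 0, that a time-based trigger fires on its first check, and that run returns the action results keyed by name.

```python
import time
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union


def make_trigger(
    *,
    step_interval: Optional[int] = None,
    at_step: Union[int, Iterable[int], None] = None,
    at_first_steps: Optional[int] = None,
    time_interval: Optional[float] = None,
    clock: Callable[[], float] = time.monotonic,
) -> Callable[[int], bool]:
    """Hypothetical predicate built from the trigger specifiers."""
    steps = frozenset((at_step,) if isinstance(at_step, int) else (at_step or ()))
    last_fired: Optional[float] = None

    def should_fire(step: int) -> bool:
        nonlocal last_fired
        if step_interval is not None and step % step_interval == 0:
            return True
        if step in steps:
            return True
        # Assumption: steps are counted from 0, so "first N" means step < N.
        if at_first_steps is not None and step < at_first_steps:
            return True
        if time_interval is not None:
            now = clock()
            # Assumption: the very first check fires immediately.
            if last_fired is None or now - last_fired >= time_interval:
                last_fired = now
                return True
        return False

    return should_fire


class SketchCronTab:
    """Hypothetical table of (name, trigger, action) entries."""

    def __init__(self) -> None:
        self._entries: List[Tuple[str, Callable[[int], bool], Callable[..., Any]]] = []

    def schedule(self, action: Callable[..., Any], *, name: Optional[str] = None,
                 **trigger_kwargs: Any) -> "SketchCronTab":
        name = name or f"action_{len(self._entries)}"
        self._entries.append((name, make_trigger(**trigger_kwargs), action))
        return self  # returning self allows chained schedule() calls

    def run(self, step: int, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        # Collects the result of every action whose trigger fired, keyed by name.
        return {name: action(step, *args, **kwargs)
                for name, trigger, action in self._entries if trigger(step)}


log: List[Tuple[str, int]] = []
crontab = (SketchCronTab()
           .schedule(lambda step: log.append(("eval", step)), name="eval", step_interval=100)
           .schedule(lambda step: log.append(("ckpt", step)), name="ckpt", at_step=(250, 500)))
for step in range(1, 501):
    crontab.run(step)
```

Because schedule returns self, a training loop can build the whole table in one chained expression and then call run once per step.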
TensorBoard
Summary writers
- class bobbin.NullSummaryWriter
Null-object counterpart of flax.metrics.tensorboard.SummaryWriter.
- __init__()
- class bobbin.MultiDirectorySummaryWriter(log_dir_root, *, keys=(), allow_new_keys=True, only_from_leader_process=True, auto_flush=True, tag_to_key=<function _default_key_from_tag>, dirname=<function _default_dirname_from_key>, use_threaded_writer=True)
SummaryWriter that changes the destination depending on the tag. Actual use cases for wrapping functions like scalar.
- __init__(log_dir_root, *, keys=(), allow_new_keys=True, only_from_leader_process=True, auto_flush=True, tag_to_key=<function _default_key_from_tag>, dirname=<function _default_dirname_from_key>, use_threaded_writer=True)
Constructs a MultiDirectorySummaryWriter.
- Parameters
log_dir_root – root directory for this MultiDirectorySummaryWriter. It can be anything that can be passed to epath.Path.
keys (Iterable[str]) – pre-specified set of keys.
only_from_leader_process – if True (default), writers do not actually write summaries in processes with jax.process_index() != 0.
auto_flush (bool) – if True, sub-writers are instantiated with the auto_flush=True argument.
tag_to_key (Callable[[str], Tuple[str, str]]) – a function that extracts, from each tag, the sub-writer name and the tag name used in that sub-writer. The default is equivalent to lambda s: s.split('/', maxsplit=1).
dirname (Callable[[str], str]) – a function that converts keys to directory names under the root directory.
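The default tag_to_key behaviour is easiest to see with a few tags. In the sketch below, default_tag_to_key mirrors the documented lambda s: s.split('/', maxsplit=1) for tags that contain '/'. RoutingWriterSketch is a hypothetical router that only records each call per key instead of writing event files under log_dir_root.

```python
from collections import defaultdict
from typing import Callable, Dict, List, Tuple


def default_tag_to_key(tag: str) -> Tuple[str, str]:
    # Same split as lambda s: s.split('/', maxsplit=1); tags without '/'
    # raise ValueError in this sketch.
    key, sub_tag = tag.split("/", maxsplit=1)
    return key, sub_tag


class RoutingWriterSketch:
    """Hypothetical router that keeps one record list per sub-writer key."""

    def __init__(self,
                 tag_to_key: Callable[[str], Tuple[str, str]] = default_tag_to_key) -> None:
        self.tag_to_key = tag_to_key
        self.records: Dict[str, List[Tuple[str, float, int]]] = defaultdict(list)

    def scalar(self, tag: str, value: float, step: int) -> None:
        key, sub_tag = self.tag_to_key(tag)
        # A real implementation would forward to a writer rooted at
        # log_dir_root / dirname(key); here the call is only recorded.
        self.records[key].append((sub_tag, value, step))


writer = RoutingWriterSketch()
writer.scalar("train/loss", 0.5, 1)
writer.scalar("eval/loss", 0.7, 1)
writer.scalar("train/learning_rate", 0.001, 1)
print(dict(writer.records))
```

Since TensorBoard treats each log directory as a separate run, routing "train/loss" and "eval/loss" to different directories with the shared tag "loss" is presumably what lets both curves appear on one chart.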
- class bobbin.ThreadedSummaryWriter(base_writer, max_workers=1, wait_on_flush=False)
SummaryWriter that writes summaries in a separate thread.
This writer defers write operations (including flush) and returns from each method call as soon as possible. However, by design, the transfer from device to CPU memory is not deferred, so the latency of method calls is dominated by that memory transfer. This matters because it lets devices release the memory used for storing summaries as soon as the write method is called.
- __init__(base_writer, max_workers=1, wait_on_flush=False)
Constructs a ThreadedSummaryWriter that wraps base_writer.
- Parameters
base_writer (SummaryWriter) – SummaryWriter or a compatible instance to be wrapped.
max_workers (int) – the number of I/O threads.
wait_on_flush (bool) – if True, ThreadedSummaryWriter.flush waits until all pending write operations have finished. Otherwise (default), the flush operation is also deferred and eventually executed.
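The deferral scheme described above can be sketched with a thread pool. ThreadedWriterSketch and ListWriter are hypothetical. float() stands in for the device-to-CPU transfer, which happens in the caller's thread. Only the call on the base writer is deferred. With max_workers=1, the single worker runs queued writes in submission order.

```python
import concurrent.futures
from typing import Any, List


class ListWriter:
    """Hypothetical base writer that only records the calls it receives."""

    def __init__(self) -> None:
        self.calls: List[tuple] = []

    def scalar(self, tag: str, value: float, step: int) -> None:
        self.calls.append(("scalar", tag, value, step))

    def flush(self) -> None:
        self.calls.append(("flush",))


class ThreadedWriterSketch:
    """Hypothetical writer that defers base-writer calls to a thread pool."""

    def __init__(self, base_writer: Any, max_workers: int = 1,
                 wait_on_flush: bool = False) -> None:
        self._base = base_writer
        self._wait_on_flush = wait_on_flush
        self._pool = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
        self._pending: List[concurrent.futures.Future] = []

    def _submit(self, fn: Any, *args: Any) -> None:
        # Drop futures that already finished so the list stays small.
        self._pending = [f for f in self._pending if not f.done()]
        self._pending.append(self._pool.submit(fn, *args))

    def scalar(self, tag: str, value: Any, step: int) -> None:
        # The host conversion runs now, in the caller's thread; only the
        # actual write is deferred to the worker thread.
        self._submit(self._base.scalar, tag, float(value), step)

    def flush(self) -> None:
        self._submit(self._base.flush)
        if self._wait_on_flush:
            concurrent.futures.wait(self._pending)

    def close(self) -> None:
        self._pool.shutdown(wait=True)


base = ListWriter()
threaded = ThreadedWriterSketch(base, wait_on_flush=True)
for step in range(3):
    threaded.scalar("loss", step * 0.5, step)
threaded.flush()
threaded.close()
```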
Summary variable wrappers
Publish functions
- bobbin.publish_train_intermediates(writer, tree, step, *, prefix='summary/')
Writes the summary variables found in tree to the SummaryWriter writer.
Currently, this function only supports scalar summaries.
- Parameters
writer (SummaryWriter) – destination, as a flax.metrics.tensorboard.SummaryWriter.
tree (Union[Array, ndarray, bool_, number, Iterable[ForwardRef], Mapping[Any, ForwardRef]]) – source of summary variables. Typically, a training state or a variable collection.
step (int) – the step number of the summary information.
scalar_selector – selector for extracting the scalars to be published. By default, variables with the ":scalar" suffix are selected.
scalar_tag_rewriter – a function that converts variable paths to tag names used in TensorBoard. By default, the tag name is the last path component without the ":scalar" suffix.
- Return type
None
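The default selection and tag rewriting described above can be sketched as follows. The sketch assumes a tree of nested dicts with plain Python numbers as leaves, whereas the real function works on JAX pytrees. It also assumes that the published tag is prefix + name. iter_paths and select_scalar_summaries are hypothetical helpers.

```python
from typing import Any, Dict, Iterator, Tuple

SUFFIX = ":scalar"


def iter_paths(tree: Any, path: Tuple[str, ...] = ()) -> Iterator[Tuple[Tuple[str, ...], Any]]:
    """Yields (path, leaf) pairs for a tree made of nested dicts."""
    if isinstance(tree, dict):
        for key, child in tree.items():
            yield from iter_paths(child, path + (str(key),))
    else:
        yield path, tree


def select_scalar_summaries(tree: Any, prefix: str = "summary/") -> Dict[str, float]:
    """Maps tag names to values for every leaf whose name ends with ':scalar'."""
    out: Dict[str, float] = {}
    for path, leaf in iter_paths(tree):
        if path and path[-1].endswith(SUFFIX):
            # Tag = prefix + last path component without the suffix.
            name = path[-1][: -len(SUFFIX)]
            out[prefix + name] = float(leaf)
    return out


tree = {
    "intermediates": {
        "encoder": {"attn_entropy:scalar": 1.25, "hidden": [1, 2]},
        "loss:scalar": 0.5,
    }
}
print(select_scalar_summaries(tree))
```

Each resulting (tag, value) pair would then be written with the writer's scalar method at the given step.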
- bobbin.publish_trainer_env_info(writer, train_state, *, prefix='trainer/diagnosis/', also_do_logging=True, loglevel=20)
Publishes environment information to the TensorBoard "Text" section.
Currently, this function publishes the following information:
The numbers of parameters and extra variables in the model.
Shape information of the parameter tree.
str(jax.local_devices())
The list of sys.argv elements.
- Parameters
writer (SummaryWriter) – an instance of flax.metrics.tensorboard.SummaryWriter.
train_state (TrainState) – an instance of flax.training.train_state.TrainState that contains parameters.
prefix (str) – prefix for the published tag names.
also_do_logging (bool) – if True (default), published text data is also logged.
loglevel (int) – log level used when also_do_logging == True (the default, 20, is logging.INFO).
- Return type
None
Pmap utils
This section provides transparent pmap (tpmap), a function that gathers arbitrary trees from distributed processes, and a check that replicas have exactly the same values over devices and processes.
- bobbin.tpmap(f, axis_name, argtypes, kwargtypes=None, *, devices=None, backend=None, wrap_return=None, **kwargs)
Transparent pmap (tpmap).
This function wraps a JAX function so that it can be transparently run on multiple devices. It wraps f with jax.pmap and applies argument and return-value wrappers that keep the API compatible with the original function.
- Parameters
f (Callable) – function to be wrapped.
axis_name (str) – axis_name used in jax.pmap.
argtypes (Sequence[Union[str, ArgType]]) – ArgType instances, or their string representations, describing the distribution strategy of each argument.
kwargtypes (Optional[Mapping[str, Union[str, ArgType]]]) – argtypes for keyword arguments.
devices (Optional[Any]) – list of devices to use. By default, all jax.local_devices in the specified backend are used.
backend (Optional[str]) – backend for computation. By default, jax.default_backend is used.
wrap_return (Optional[Callable[[Union[Array, ndarray, bool_, number, Iterable[ForwardRef], Mapping[Any, ForwardRef]]], Union[Array, ndarray, bool_, number, Iterable[ForwardRef], Mapping[Any, ForwardRef]]]]) – wrapper for the return value. If None (default), no wrapper is applied, so the return values have an extra leading axis that indexes the local devices.
- Return type
Callable
- Returns
Wrapped function that runs on multiple devices.
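The following pure-Python simulation only illustrates the argument and return wrappers described above. It does not call jax.pmap, and simulate_tpmap, num_devices and the list-based "devices" are hypothetical. It splits the leading axis of a single argument across devices and runs f on each shard. When wrap_return is None, it returns the per-device results with their extra leading axis. bobbin's argtypes strategies are not modelled.

```python
from typing import Any, Callable, List, Optional, Sequence


def simulate_tpmap(
    f: Callable[[List[Any]], Any],
    num_devices: int,
    wrap_return: Optional[Callable[[List[Any]], Any]] = None,
) -> Callable[[Sequence[Any]], Any]:
    """Hypothetical stand-in for tpmap with a single batch-sharded argument."""

    def wrapped(batch: Sequence[Any]) -> Any:
        if len(batch) % num_devices != 0:
            raise ValueError("batch size must be divisible by num_devices")
        per_device = len(batch) // num_devices
        # Argument wrapper: split the leading (batch) axis into one shard per device.
        shards = [list(batch[i * per_device:(i + 1) * per_device])
                  for i in range(num_devices)]
        # Stand-in for pmap: run f once per "device" (sequentially here).
        results = [f(shard) for shard in shards]
        # Return wrapper: without one, results keep a leading per-device axis.
        return results if wrap_return is None else wrap_return(results)

    return wrapped


per_device_sum = simulate_tpmap(sum, num_devices=2)
print(per_device_sum([1, 2, 3, 4]))  # [3, 7]
total_sum = simulate_tpmap(sum, num_devices=2, wrap_return=sum)
print(total_sum([1, 2, 3, 4]))  # 10
```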
- bobbin.assert_replica_integrity(tree, *, is_device_replicated=True, atol=1e-05, rtol=1e-05, backend='cpu')
Checks that replicas have the same values (within atol and rtol) across devices and processes.
- Parameters
tree (Union[Array, ndarray, bool_, number, Iterable[ForwardRef], Mapping[Any, ForwardRef]]) – values to be verified.
is_device_replicated (bool) – if True (default), tree is assumed to be already replicated over devices, with a leading axis corresponding to the local devices.
atol – absolute tolerance passed to np.testing.assert_allclose.
rtol – relative tolerance passed to np.testing.assert_allclose.
backend (str) – backend used for collective operations. Keeping the default value ("cpu") is strongly recommended, as other backends may require a huge amount of memory.
- Return type
None
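The tolerance check can be sketched without NumPy. assert_replicas_close is a hypothetical helper in which a list of flat lists stands in for the replicas, which the real function takes from the leading device axis and from other processes. It applies the same element-wise criterion as np.testing.assert_allclose, namely |actual - desired| <= atol + rtol * |desired|, with replica 0 as the reference.

```python
from typing import Sequence


def assert_replicas_close(replicas: Sequence[Sequence[float]], *,
                          atol: float = 1e-5, rtol: float = 1e-5) -> None:
    """Hypothetical check: every replica must match replica 0 element-wise."""
    reference = replicas[0]
    for r, replica in enumerate(replicas[1:], start=1):
        if len(replica) != len(reference):
            raise AssertionError(f"replica {r} has a different length")
        for i, (actual, desired) in enumerate(zip(replica, reference)):
            # Same criterion as np.testing.assert_allclose.
            if abs(actual - desired) > atol + rtol * abs(desired):
                raise AssertionError(
                    f"replica {r} differs at index {i}: {actual} vs {desired}")


# Passes: the difference (1e-7) is well within atol + rtol * 2.0.
assert_replicas_close([[1.0, 2.0], [1.0, 2.0 + 1e-7]])
```

A failure in such a check during training usually points to replicas that drifted apart, for example because a gradient was not averaged across devices.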
Var utils
This section provides utilities that return an iterator for leaves in a tree and their paths, construct a tree with the same structure but containing path names as leaves, return a string that summarizes the shapes and dtypes of a tree, and return the total dimensionality of the variables in a given tree.
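The first two utilities summarized above can be pictured with the sketch below. It assumes trees made of nested dicts, lists and tuples, with paths joined by "/" and starting with "/". iter_leaves_with_paths and tree_of_paths are hypothetical names, not the bobbin functions themselves, and the real path format may differ.

```python
from typing import Any, Iterator, Tuple


def iter_leaves_with_paths(tree: Any, path: str = "") -> Iterator[Tuple[str, Any]]:
    """Hypothetical: yields (path, leaf) for nested dicts, lists and tuples."""
    if isinstance(tree, dict):
        for key, child in tree.items():
            yield from iter_leaves_with_paths(child, f"{path}/{key}")
    elif isinstance(tree, (list, tuple)):
        for index, child in enumerate(tree):
            yield from iter_leaves_with_paths(child, f"{path}/{index}")
    else:
        yield path, tree


def tree_of_paths(tree: Any, path: str = "") -> Any:
    """Hypothetical: same structure as tree, with each leaf replaced by its path."""
    if isinstance(tree, dict):
        return {key: tree_of_paths(child, f"{path}/{key}") for key, child in tree.items()}
    if isinstance(tree, (list, tuple)):
        # Rebuild the same container type (list or tuple).
        return type(tree)(tree_of_paths(child, f"{path}/{i}")
                          for i, child in enumerate(tree))
    return path


params = {"dense": {"kernel": 1, "bias": 2}, "layers": [3, 4]}
for leaf_path, leaf in iter_leaves_with_paths(params):
    print(leaf_path, leaf)
```

A tree of paths with the same structure as the original is convenient for writing selectors, for instance one that picks variables whose path ends with ":scalar".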