How to use the var_util API
This short notebook demonstrates how to use the “var_util” API of Bobbin. “var_util” aims to provide an easy way to access deeply nested pytree structures.
Preamble: Install prerequisites, import modules.
!pip -q install --upgrade pip
!pip -q install --upgrade "jax[cpu]"
!pip -q install git+https://github.com/yotarok/bobbin.git
%%capture
import bobbin
import chex
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
Define an array tree via nn.Module
In this notebook, we demonstrate how to inspect and manipulate the variables of a Flax module. For this purpose, we define a module that has several parameters, as follows:
Array = chex.Array


# A custom pytree node can be used as part of a variable.
class DiagnosticInfo(flax.struct.PyTreeNode):
    average_entropy: float
    input_norms: Array


class GaussianClassifier(nn.Module):
    class_count: int = 4

    @nn.compact
    def __call__(self, x):
        *unused_batch_sizes, dims = x.shape
        means = self.param("means", nn.initializers.normal(), (dims, self.class_count))
        logprecs = self.param(
            "logprecs", nn.initializers.zeros_init(), (dims, self.class_count)
        )
        diffs = x[..., np.newaxis] - means.reshape((1,) * (x.ndim - 1) + means.shape)
        diffs = jnp.exp(logprecs.reshape((1,) * (x.ndim - 1) + logprecs.shape)) * diffs
        logits = jnp.sum(-diffs, axis=-2)
        class_logprob = jax.nn.log_softmax(logits)
        avg_entropy = jnp.mean(jnp.sum(-class_logprob * np.exp(class_logprob), axis=-1))
        self.sow(
            "diagnosis",
            "info",
            DiagnosticInfo(
                average_entropy=avg_entropy,
                input_norms=jnp.sqrt(jnp.sum(x * x, axis=-1)),
            ),
        )
        return class_logprob
The variable tree for this module can be obtained by following the usual Flax initialization procedure:
batch_size = 4
dims = 3
mod = GaussianClassifier()
variables = mod.init(jax.random.PRNGKey(0), np.zeros((batch_size, dims)))
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Paths for variables
“var_util” provides functions for accessing pytrees via “paths”.
A path is a unique identifier for a node in the tree. The leaves of a tree can be enumerated together with their paths by using the flatten_with_paths function, as follows:
list(bobbin.var_util.flatten_with_paths(variables))
[('/diagnosis/info/0/average_entropy', Array(1.3861325, dtype=float32)),
('/diagnosis/info/0/input_norms', Array([0., 0., 0., 0.], dtype=float32)),
('/params/logprecs',
Array([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]], dtype=float32)),
('/params/means',
Array([[ 0.0065701 , 0.00706267, -0.00381893, -0.01414316],
[ 0.00661003, -0.00954964, -0.00893679, 0.00803079],
[ 0.00558195, -0.01153143, -0.00493697, -0.02342076]], dtype=float32))]
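To make the path format concrete, here is a minimal plain-Python sketch of path enumeration over nested dicts, lists and tuples. It is not Bobbin's implementation: iter_paths is a hypothetical helper written for this illustration, it ignores custom pytree nodes such as DiagnosticInfo, and it assumes that dict keys are visited in sorted order, as the output above suggests.

```python
def iter_paths(tree, prefix=""):
    """Yield (path, leaf) pairs for a tree of dicts, lists and tuples."""
    if isinstance(tree, dict):
        # Assumption: keys are visited in sorted order.
        for key in sorted(tree):
            yield from iter_paths(tree[key], f"{prefix}/{key}")
    elif isinstance(tree, (list, tuple)):
        # Sequence elements are addressed by their index.
        for index, child in enumerate(tree):
            yield from iter_paths(child, f"{prefix}/{index}")
    else:
        # Anything else is treated as a leaf.
        yield prefix, tree


# Toy tree with plain floats standing in for arrays.
toy_tree = {
    "params": {"means": 0.5, "logprecs": 0.0},
    "diagnosis": {"info": ({"average_entropy": 1.38, "input_norms": 0.0},)},
}
toy_paths = [path for path, _ in iter_paths(toy_tree)]
print(toy_paths)
```

With this toy tree, the tuple under "info" contributes the index component "0" to the path, in the same way as in the output above.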
Instead of a flat list of pairs, one can also obtain a path-tree, that is, a tree with the same structure in which each leaf is replaced by its path string:
paths = bobbin.var_util.nested_vars_to_paths(variables)
paths
FrozenDict({
params: {
means: '/params/means',
logprecs: '/params/logprecs',
},
diagnosis: {
info: (DiagnosticInfo(average_entropy='/diagnosis/info/0/average_entropy', input_norms='/diagnosis/info/0/input_norms'),),
},
})
Such path-trees are particularly useful for path-dependent operations over a tree. The following example overwrites the “logprecs” parameters in the tree with ones.
def reset_logprecs(x, path):
    return jnp.ones_like(x) if path.endswith("logprecs") else x


variables = jax.tree_util.tree_map(reset_logprecs, variables, paths)
variables
FrozenDict({
diagnosis: {
info: (DiagnosticInfo(average_entropy=Array(1.3861325, dtype=float32), input_norms=Array([0., 0., 0., 0.], dtype=float32)),),
},
params: {
logprecs: Array([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]], dtype=float32),
means: Array([[ 0.0065701 , 0.00706267, -0.00381893, -0.01414316],
[ 0.00661003, -0.00954964, -0.00893679, 0.00803079],
[ 0.00558195, -0.01153143, -0.00493697, -0.02342076]], dtype=float32),
},
})
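The same pattern can be sketched without JAX. The hypothetical helpers below build a path-tree for nested dicts and then map a two-argument function over the value tree and the path-tree together. They only illustrate the idea and are not how nested_vars_to_paths or jax.tree_util.tree_map are implemented; plain lists stand in for arrays.

```python
def to_path_tree(tree, prefix=""):
    """Return a tree with the same dict structure whose leaves are path strings."""
    if isinstance(tree, dict):
        return {k: to_path_tree(v, f"{prefix}/{k}") for k, v in tree.items()}
    return prefix


def map_with_paths(fn, tree, path_tree):
    """Apply fn(leaf, path) to every leaf, keeping the dict structure."""
    if isinstance(tree, dict):
        return {k: map_with_paths(fn, tree[k], path_tree[k]) for k in tree}
    return fn(tree, path_tree)


def reset_if_logprecs(leaf, path):
    # Lists stand in for arrays here; replace every element with 1.0.
    return [1.0] * len(leaf) if path.endswith("logprecs") else leaf


toy_vars = {"params": {"means": [0.1, -0.2], "logprecs": [0.0, 0.0]}}
toy_path_tree = to_path_tree(toy_vars)
updated = map_with_paths(reset_if_logprecs, toy_vars, toy_path_tree)
print(toy_path_tree)
print(updated)
```

As with tree_map, the mapping returns a new tree and leaves the input untouched, so the result has to be assigned back if the update should replace the original variables.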
The same mechanism can be used to compute the squared L2 norm of specific parameters.
def compute_squared_l2norm_for_logprecs(x, path):
    return jnp.sum(x * x) if path.endswith("logprecs") else 0.0


norm_tree = jax.tree_util.tree_map(
    compute_squared_l2norm_for_logprecs, variables, paths
)
squared_l2_norm = jax.tree_util.tree_reduce(lambda acc, x: acc + x, norm_tree, 0.0)
print(squared_l2_norm)
12.0
JSON dumps
JSON serialization of pytrees is useful in some cases, for example for storing evaluation results. Because a text format is inefficient, storing entire variable trees this way is not recommended, but it is convenient for small pytrees such as evaluation metrics.
A JSON representation can be obtained with the dump_pytree_json function, as follows:
json_text = bobbin.var_util.dump_pytree_json(variables)
print(json_text)
{"diagnosis": {"info": {"0": {"average_entropy": 1.3861324787139893, "input_norms": {"__array__": true, "dtype": "float32", "data": [0.0, 0.0, 0.0, 0.0]}}}}, "params": {"logprecs": {"__array__": true, "dtype": "float32", "data": [[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]}, "means": {"__array__": true, "dtype": "float32", "data": [[0.006570101715624332, 0.007062666118144989, -0.003818930359557271, -0.01414316426962614], [0.006610026117414236, -0.009549644775688648, -0.008936785161495209, 0.008030789904296398], [0.0055819484405219555, -0.011531432159245014, -0.004936968442052603, -0.02342076413333416]]}}}
Here, each array is stored with a special marker, "__array__": true, and a dtype field. Apart from that, the output is ordinary JSON, so it can be manipulated with standard tools. To write it directly to a file system (or a GCS bucket), use write_pytree_json_file instead.
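Because the dump is ordinary JSON, it can be inspected with Python's standard json module. The sketch below uses a shortened, hand-written document in the same layout as the output above and lists every array node with its dtype and shape. find_arrays and nested_shape are hypothetical helpers, not part of Bobbin, and they assume that arrays are always encoded with the "__array__", "dtype" and "data" fields seen above.

```python
import json


def nested_shape(data):
    """Return the shape of a rectangular nested list (a scalar gives ())."""
    shape = []
    while isinstance(data, list):
        shape.append(len(data))
        data = data[0] if data else None
    return tuple(shape)


def find_arrays(node, prefix=""):
    """Yield (path, dtype, shape) for every array-marked JSON object."""
    if isinstance(node, dict):
        if node.get("__array__") is True:
            yield prefix, node["dtype"], nested_shape(node["data"])
        else:
            for key, child in node.items():
                yield from find_arrays(child, f"{prefix}/{key}")


# Hand-written sample in the same layout as the dump above (values shortened).
sample_json = (
    '{"diagnosis": {"info": {"0": {"average_entropy": 1.38, '
    '"input_norms": {"__array__": true, "dtype": "float32", "data": [0.0, 0.0]}}}}, '
    '"params": {"logprecs": {"__array__": true, "dtype": "float32", '
    '"data": [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]}}}'
)
arrays = list(find_arrays(json.loads(sample_json)))
for path, dtype, shape in arrays:
    print(path, dtype, shape)
```

Scalars such as average_entropy are plain JSON numbers in the dump, so they are skipped by this walk.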
JSON can be loaded with parse_pytree_json or its file-based equivalent, read_pytree_json_file.
These functions require a template argument that specifies the structure of the pytree to be loaded. In the example below, the template is obtained by initializing the same Flax module with a different RNG key.
another_vars = mod.init(jax.random.PRNGKey(1), np.zeros((batch_size, dims)))
loaded_vars = bobbin.var_util.parse_pytree_json(json_text, another_vars)
loaded_vars
FrozenDict({
params: {
means: array([[ 0.0065701 , 0.00706267, -0.00381893, -0.01414316],
[ 0.00661003, -0.00954964, -0.00893679, 0.00803079],
[ 0.00558195, -0.01153143, -0.00493697, -0.02342076]],
dtype=float32),
logprecs: array([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]], dtype=float32),
},
diagnosis: {
info: (DiagnosticInfo(average_entropy=1.3861324787139893, input_norms=array([0., 0., 0., 0.], dtype=float32)),),
},
})
Note that the template argument is only used to obtain the tree structure, so it is not altered by calling parse_pytree_json (or read_pytree_json_file).
another_vars
FrozenDict({
params: {
means: Array([[ 0.00078776, -0.00394429, 0.00607885, 0.00394586],
[-0.00017481, -0.00678178, -0.01871471, -0.00491523],
[ 0.00404862, 0.01051817, -0.00541831, -0.00435552]], dtype=float32),
logprecs: Array([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]], dtype=float32),
},
diagnosis: {
info: (DiagnosticInfo(average_entropy=Array(1.3862587, dtype=float32), input_norms=Array([0., 0., 0., 0.], dtype=float32)),),
},
})
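One reason a template is needed is visible in the JSON dump shown earlier: the tuple under "info" is written as a JSON object with the key "0", so the JSON text alone does not say which container type to rebuild. The following plain-Python sketch shows the general idea of template-driven restoration for dicts and tuples. restore_like is a hypothetical helper written for this illustration; it does not reproduce how parse_pytree_json works internally, and it does not handle arrays or custom pytree nodes.

```python
import json


def restore_like(template, data):
    """Rebuild `data` (parsed JSON) using the container types of `template`."""
    if isinstance(template, dict):
        return {key: restore_like(template[key], data[key]) for key in template}
    if isinstance(template, tuple):
        # Assumption: tuples were written as objects keyed by element index.
        return tuple(
            restore_like(item, data[str(i)]) for i, item in enumerate(template)
        )
    # Leaves: take the loaded value; the template value is only a placeholder.
    return data


template = {"diagnosis": {"info": ({"average_entropy": 0.0},)}, "step": 0}
text = '{"diagnosis": {"info": {"0": {"average_entropy": 1.38}}}, "step": 7}'
restored = restore_like(template, json.loads(text))
print(restored)
print(template)  # the template itself is not modified
```

The function only reads from the template and builds new containers, which mirrors the behavior described above: the template supplies structure, and its values are left as they were.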
Miscellaneous utilities
bobbin.summarize_shape can be used to obtain the shapes and dtypes of the arrays in a variable tree.
print(bobbin.summarize_shape(variables))
diagnosis:
  info:
    0:
      average_entropy: () dtype=float32
      input_norms: (4,) dtype=float32
params:
  logprecs: (3, 4) dtype=float32
  means: (3, 4) dtype=float32
Such shape information can be helpful when written as a TensorBoard text summary.
There is also a shortcut for obtaining the total number of scalar elements in a variable tree.
print("# of variables =", bobbin.total_dimensionality(variables))
# of variables = 29.0
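The reported total can be checked by hand from the shape summary above: each array contributes the product of its dimensions, and a scalar with shape () counts as one element. A minimal sketch using only the standard library, with the shapes copied from the summary:

```python
import math

# Shapes copied from the shape summary above.
shapes = {
    "/diagnosis/info/0/average_entropy": (),
    "/diagnosis/info/0/input_norms": (4,),
    "/params/logprecs": (3, 4),
    "/params/means": (3, 4),
}

# math.prod of an empty tuple is 1, so a scalar counts as one element.
sizes = {path: math.prod(shape) for path, shape in shapes.items()}
total = sum(sizes.values())
print(total)  # 1 + 4 + 12 + 12 = 29
```

Note that the count covers the whole tree passed in, so here it includes the "diagnosis" collection as well as "params".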