How to define an evaluation task

This notebook demonstrates how bobbin helps you build an evaluation pipeline.

Preamble: Install prerequisites, import modules.

!pip -q install --upgrade pip
!pip -q install --upgrade "jax[cpu]"
!pip -q uninstall -y bobbin
!pip -q install --upgrade git+https://github.com/yotarok/bobbin.git
%%capture
from typing import Tuple

import bobbin
import chex
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
chex.set_n_cpu_devices(8)

Array = chex.Array

Prepare a model to be evaluated

First, we train a small model on the MNIST dataset so that we have something to evaluate. For more details on training with bobbin, please refer to "How to write a training loop".

The function below defines the dataset pipeline.

def get_dataset(batch_size, *, is_train=True, split="train"):
    ds = tfds.load("mnist", split=split, as_supervised=True)
    if is_train:
        ds = ds.repeat().shuffle(1024)
    ds = ds.batch(batch_size).prefetch(1)
    return ds

Note that the dataset is infinite only when is_train == True.

Here, a simple linear classifier is defined as a Flax module, and a bobbin.TrainTask subclass implements cross-entropy training for that classifier. (See the training documentation for details.)

class MnistLinearClassifier(nn.Module):
    @nn.compact
    def __call__(self, x: Array) -> Array:
        batch_size, *unused_image_dims = x.shape
        x = x.reshape((batch_size, -1))  # flatten the input image.
        return nn.Dense(features=10)(x)


class MnistTrainingTask(bobbin.TrainTask):
    def __init__(self):
        super().__init__(
            MnistLinearClassifier(),
            example_args=(
                np.zeros((1, 28, 28, 1), np.float32),  # comma-here is important
            ),
        )

    def compute_loss(self, params, batch, *, extra_vars, prng_key, step):
        images, labels = batch
        logits = self.model.apply({"params": params}, images)
        per_sample_loss = optax.softmax_cross_entropy(
            logits=logits, labels=jax.nn.one_hot(labels, 10)
        )
        return jnp.mean(per_sample_loss), ({}, None)


task = MnistTrainingTask()
train_state = task.initialize_train_state(
    jax.random.PRNGKey(0), optax.sgd(0.001, momentum=0.9)
)
train_step_fn = task.make_training_step_fn()
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
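
To make the loss in compute_loss concrete, here is a minimal NumPy sketch of softmax cross entropy against one-hot labels. It is written from the textbook definition and is not optax's implementation; the function name, the 3-class logits, and the labels are illustrative assumptions.

```python
import numpy as np


def softmax_cross_entropy(logits: np.ndarray, one_hot_labels: np.ndarray) -> np.ndarray:
    """Per-sample cross entropy between softmax(logits) and one-hot labels."""
    # Subtract the row-wise max before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Only the log-probability of the true class survives the product.
    return -(one_hot_labels * log_probs).sum(axis=-1)


logits = np.array([[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]])  # made-up values
labels = np.array([0, 2])
one_hot = np.eye(3)[labels]  # integer labels -> one-hot rows
per_sample_loss = softmax_cross_entropy(logits, one_hot)
mean_loss = per_sample_loss.mean()  # the scalar a training step would minimize
```

With uniform logits (the second row), the per-sample loss equals log(3), the loss of a classifier that guesses uniformly among three classes.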

We then run 500 steps of stochastic gradient descent as follows:

prng_key = jax.random.PRNGKey(0)
for batch in get_dataset(64).take(500).as_numpy_iterator():
    rng, prng_key = jax.random.split(prng_key)
    train_state, step_info = train_step_fn(train_state, batch, rng)
print(f"Last loss value = {step_info.loss}")
Last loss value = 34.819725036621094

Define EvalResults

First, we define the metrics used for evaluation. To also demonstrate SampledSet, the result contains both a confusion matrix and sampled triples of predicted outputs, labels, and inputs.

class EvalResults(bobbin.EvalResults):
    confusion_matrix: Array
    examples: bobbin.SampledSet[Tuple[int, int, Array]] = flax.struct.field(
        pytree_node=False, default=bobbin.SampledSet(max_size=4)
    )
    dataset_name: str = flax.struct.field(pytree_node=False, default="")

    def prediction_count(self) -> int:
        return jnp.sum(self.confusion_matrix)

    def correct_count(self) -> int:
        return jnp.sum(jnp.diag(self.confusion_matrix))

    def accuracy(self) -> float:
        return self.correct_count() / self.prediction_count()

    def reduce(self, other: "EvalResults") -> "EvalResults":
        return type(self)(
            confusion_matrix=self.confusion_matrix + other.confusion_matrix,
            examples=self.examples.union(other.examples),
        )

EvalResults is a subclass of flax.struct.PyTreeNode, which means that it follows Python’s dataclass semantics. The type annotations at the top of the class definition double as the list of fields, and some special methods, such as the constructor, are generated automatically from that list.
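
As a rough illustration of those dataclass semantics, the sketch below uses the standard library's dataclasses module rather than Flax. ToyResults and its two-class matrices are hypothetical stand-ins, not part of bobbin.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class ToyResults:
    # Hypothetical stand-in for EvalResults; fields come from the annotations.
    confusion_matrix: np.ndarray
    dataset_name: str = ""

    def accuracy(self) -> float:
        # Correct predictions lie on the diagonal of the confusion matrix.
        return float(np.trace(self.confusion_matrix) / self.confusion_matrix.sum())

    def reduce(self, other: "ToyResults") -> "ToyResults":
        # Counts from two batches simply add up.
        return dataclasses.replace(
            self, confusion_matrix=self.confusion_matrix + other.confusion_matrix
        )


# The constructor was generated from the field list above.
a = ToyResults(confusion_matrix=np.array([[3, 1], [0, 4]]), dataset_name="toy")
b = ToyResults(confusion_matrix=np.array([[2, 0], [2, 2]]))
merged = a.reduce(b)
```

Because reduce only adds counts, the order in which batches are combined does not affect the final accuracy.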

Note that some of the fields (“examples” and “dataset_name” in this example) are explicitly excluded from the pytree definition by setting pytree_node=False. This is especially important when a JIT-ed function is used in the EvalTask below. Anything passed to a compiled (JIT-ed) function must have a JAX representation, and some types, such as str, do not have one. Therefore, we have to exclude those fields from the pytree so that they are treated as non-JAX variables.

A subclass of bobbin.EvalTask then implements how to compute the EvalResults defined above.
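
The same restriction can be seen with plain jax.jit, independently of bobbin or Flax. In the sketch below, scaled_sum is a hypothetical function: passing a string as an ordinary argument is rejected, while declaring it static (similar in spirit to pytree_node=False) lets the call compile.

```python
import jax
import jax.numpy as jnp


def scaled_sum(x, name):
    # `name` is a plain Python string; it only selects a branch at trace time.
    scale = 2.0 if name == "double" else 1.0
    return scale * jnp.sum(x)


x = jnp.arange(4.0)

# Passing the string as an ordinary argument fails: str has no JAX representation.
try:
    jax.jit(scaled_sum)(x, "double")
    str_rejected = False
except TypeError:
    str_rejected = True

# Marking it static keeps it out of the traced arguments, so the call compiles.
static_fn = jax.jit(scaled_sum, static_argnames=("name",))
result = float(static_fn(x, name="double"))
```

A static value becomes part of the compilation cache key, so each distinct string triggers its own compilation.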

class EvalTask(bobbin.EvalTask):
    def __init__(self):
        self.model = MnistLinearClassifier()

    def create_eval_results(self, dataset_name):
        return EvalResults(
            confusion_matrix=np.zeros((10, 10)),
            examples=bobbin.SampledSet(max_size=4),
            dataset_name=dataset_name,
        )

    def evaluate(self, batch, model_vars) -> EvalResults:
        inputs, labels = batch
        logits = self.model.apply(model_vars, inputs)
        predicts = logits.argmax(axis=-1)
        confusion_mat = jnp.dot(
            jax.nn.one_hot(predicts, 10).T, jax.nn.one_hot(labels, 10)
        )
        examples = bobbin.SampledSet(max_size=4).union(
            zip(predicts, labels, list(inputs))
        )
        return EvalResults(confusion_matrix=confusion_mat, examples=examples)
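
The confusion matrix in evaluate is a product of one-hot matrices. A small NumPy sketch, with made-up predictions and labels and three classes instead of ten, shows why that product counts (prediction, label) pairs:

```python
import numpy as np

num_classes = 3
predicts = np.array([0, 2, 1, 2, 2])  # made-up model outputs
labels = np.array([0, 1, 1, 2, 0])    # made-up ground truth

one_hot = np.eye(num_classes, dtype=np.int64)
# Entry [i, j] counts samples predicted as class i whose true label is j.
confusion = one_hot[predicts].T @ one_hot[labels]

# The same counts built by an explicit loop, for comparison.
expected = np.zeros((num_classes, num_classes), dtype=np.int64)
for p, t in zip(predicts, labels):
    expected[p, t] += 1
```

Rows index predictions and columns index labels, so the diagonal holds the correctly classified samples.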

EvalTask.create_eval_results is used for initializing the EvalResults instance. For each incoming batch, EvalTask.evaluate is called, and the results are combined following the pseudo-code below:

result = eval_task.create_eval_results(...)
for b in batches:
    result = result.reduce(eval_task.evaluate(b, ...))
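
The pseudo-code above can be made runnable with toy stand-ins. CountResults and ParityTask below are hypothetical and only mimic the shape of the interface (create_eval_results, evaluate, reduce); they do not depend on bobbin and omit the model variables.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CountResults:
    # Hypothetical results type: running counts instead of a confusion matrix.
    correct: int = 0
    total: int = 0

    def reduce(self, other: "CountResults") -> "CountResults":
        return CountResults(self.correct + other.correct, self.total + other.total)


class ParityTask:
    # Hypothetical task: "predicts" whether each integer input is even.
    def create_eval_results(self) -> CountResults:
        return CountResults()

    def evaluate(self, batch) -> CountResults:
        inputs, labels = batch
        predictions = [x % 2 == 0 for x in inputs]
        correct = sum(p == t for p, t in zip(predictions, labels))
        return CountResults(correct=correct, total=len(inputs))


batches = [
    ([1, 2, 3], [False, True, True]),  # the last label disagrees on purpose
    ([4, 5], [True, False]),
]

toy_task = ParityTask()
result = toy_task.create_eval_results()
for batch in batches:
    result = result.reduce(toy_task.evaluate(batch))
```

The initial result acts as the identity for reduce, so an empty dataset yields the empty result.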

Each dataset given to the evaluation process is represented as a nullary function that returns an iterator of batches. If the dataset is built with the tf.data API, such a function is easy to obtain:

eval_datasets = {
    "train": get_dataset(32, split="train[:1000]", is_train=False).as_numpy_iterator,
    "test": get_dataset(32, split="test", is_train=False).as_numpy_iterator,
}
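
One practical consequence of passing a function rather than an iterator is that every call starts a fresh pass over the data. The sketch below illustrates this with plain Python lists; make_batches and toy_datasets are hypothetical names.

```python
from typing import Callable, Dict, Iterator, List

Batch = List[int]


def make_batches() -> Iterator[Batch]:
    # Hypothetical finite dataset; every call starts a fresh pass.
    return iter([[1, 2], [3, 4], [5]])


# Stored as functions, not as iterators, so each evaluation run can restart.
toy_datasets: Dict[str, Callable[[], Iterator[Batch]]] = {"toy": make_batches}

first_pass = [sum(batch) for batch in toy_datasets["toy"]()]
second_pass = [sum(batch) for batch in toy_datasets["toy"]()]

# By contrast, a single iterator object is exhausted after one pass.
shared = make_batches()
consumed_once = list(shared)
consumed_twice = list(shared)
```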

The actual computation is invoked by calling bobbin.eval_datasets as shown below. (This function essentially performs the for-loop written in the pseudo-code above, so you can also write the loop yourself in your own codebase.)

eval_task = EvalTask()
all_results = bobbin.eval_datasets(
    eval_task, eval_datasets, {"params": train_state.params}
)
all_results.keys()
dict_keys(['train', 'test'])

You now have all_results: dict[str, EvalResults], which contains the accumulated results for each dataset. You can compute accuracies from the results and visualize the confusion matrices as follows:

for dataset_name, results in all_results.items():
    print(f"dataset={dataset_name}:\tAccuracy = {results.accuracy()}")

plt.subplot(1, 2, 1)
plt.imshow(all_results["train"].confusion_matrix)
plt.title("train")
plt.subplot(1, 2, 2)
plt.imshow(all_results["test"].confusion_matrix)
plt.title("test")
dataset=train:	Accuracy = 0.847000002861023
dataset=test:	Accuracy = 0.8457000255584717
Text(0.5, 1.0, 'test')
[Figure: confusion matrices for the train (left) and test (right) datasets]

The EvalResults.examples field holds 4 random samples of output-label-input triples. We can use them to visualize how the classifier behaved:

index = 1
for pred, lab, inputs in all_results["train"].examples:
    plt.subplot(2, 4, index)
    plt.imshow(inputs.reshape((28, 28)))
    plt.title(f"out={pred}, lab={lab}")
    index += 1
for pred, lab, inputs in all_results["test"].examples:
    plt.subplot(2, 4, index)
    plt.imshow(inputs.reshape((28, 28)))
    plt.title(f"out={pred}, lab={lab}")
    index += 1
[Figure: sampled inputs from the train (top row) and test (bottom row) datasets, each titled with the predicted output and the label]
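
For intuition about how a bounded set of random samples can be merged across batches, here is one possible sketch. ToySampledSet is hypothetical and is not bobbin.SampledSet's actual implementation: it tags each item with a random priority and keeps the items with the smallest priorities, which makes union a simple merge.

```python
import heapq
import random
from typing import Iterable, List, Tuple


class ToySampledSet:
    """Keeps at most `max_size` items, chosen uniformly at random.

    Hypothetical stand-in; not bobbin.SampledSet's actual implementation.
    """

    def __init__(self, max_size: int, entries: Iterable[Tuple[float, object]] = ()):
        self.max_size = max_size
        # Each entry is (random priority, item); only the smallest priorities survive.
        self.entries: List[Tuple[float, object]] = heapq.nsmallest(
            max_size, entries, key=lambda e: e[0]
        )

    def add_all(self, items: Iterable[object], rng: random.Random) -> "ToySampledSet":
        tagged = [(rng.random(), item) for item in items]
        return ToySampledSet(self.max_size, self.entries + tagged)

    def union(self, other: "ToySampledSet") -> "ToySampledSet":
        return ToySampledSet(self.max_size, self.entries + other.entries)

    def __iter__(self):
        return (item for _, item in self.entries)


rng = random.Random(0)
left = ToySampledSet(max_size=4).add_all(range(0, 50), rng)
right = ToySampledSet(max_size=4).add_all(range(50, 100), rng)
merged = left.union(right)
samples = list(merged)
```

Since each side already holds its own smallest priorities, the merged set holds the smallest priorities over all 100 items, so the result is still a uniform sample of the combined data.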