
MONAI Handlers

Handlers for experiment tracking on Weights & Biases with MONAI engines.

WandbStatsHandler

WandbStatsHandler defines a set of Ignite event handlers for all the Weights & Biases logging logic. It can be used with any Ignite Engine (trainer, validator, or evaluator) and supports both epoch-level and iteration-level logging. The expected data sources are Ignite engine.state.output and engine.state.metrics.

Default behaviors
  • When EPOCH_COMPLETED, write each dictionary item in engine.state.metrics to Weights & Biases.
  • When ITERATION_COMPLETED, write each dictionary item in self.output_transform(engine.state.output) to Weights & Biases.

Usage:

```python
# WandbStatsHandler for logging training metrics and losses at
# every iteration to Weights & Biases
train_wandb_stats_handler = WandbStatsHandler(output_transform=lambda x: x)
train_wandb_stats_handler.attach(trainer)

# WandbStatsHandler for logging validation metrics at the end of
# every epoch to Weights & Biases
val_wandb_stats_handler = WandbStatsHandler(
    output_transform=lambda x: None,
    global_epoch_transform=lambda x: trainer.state.epoch,
)
val_wandb_stats_handler.attach(evaluator)
```
Example notebooks:

  • 3D classification using MONAI
  • 3D segmentation using MONAI

Pull Request to add WandbStatsHandler to MONAI repository

There is an open pull request (Project-MONAI/MONAI#6305) to add WandbStatsHandler to MONAI.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `iteration_log` | `bool` | Whether to write data to Weights & Biases when an iteration completes. | `True` |
| `epoch_log` | `bool` | Whether to write data to Weights & Biases when an epoch completes. | `True` |
| `epoch_event_writer` | `Optional[Callable[[Engine, Any], Any]]` | Customized callable Weights & Biases writer for epoch level. Must accept the parameter `engine`; the default event writer is used if `None`. | `None` |
| `epoch_interval` | `int` | The epoch interval at which the `epoch_event_writer` is called. | `1` |
| `iteration_event_writer` | `Optional[Callable[[Engine, Any], Any]]` | Customized callable Weights & Biases writer for iteration level. Must accept the parameter `engine`; the default event writer is used if `None`. | `None` |
| `iteration_interval` | `int` | The iteration interval at which the `iteration_event_writer` is called. | `1` |
| `output_transform` | `Callable` | A callable that transforms `ignite.engine.state.output` into a scalar to plot, or a dictionary of `{key: scalar}` pairs. In the latter case, each entry is logged as `key: value`. By default, this plotting happens every time an iteration completes. The default behavior takes the loss from `output[0]`, since `output` is a decollated list with the loss value replicated for every item. `engine.state` and `output_transform` follow the Ignite concept described at https://pytorch.org/ignite/concepts.html#state; an explanation and usage example are in the tutorial at https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. | `lambda x: x[0]` |
| `global_epoch_transform` | `Callable` | A callable that customizes the global epoch number. For example, during evaluation, the evaluator engine might want to use the trainer engine's epoch number when plotting epoch-vs-metric curves. | `lambda x: x` |
| `state_attributes` | `Optional[Sequence[str]]` | Expected attributes from `engine.state`; if provided, they are extracted when an epoch completes. | `None` |
| `tag_name` | `str` | When the iteration output is a scalar, `tag_name` is used as the logging key. | `DEFAULT_TAG` (`'Loss'`) |
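To make the `output_transform` contract concrete, here is a small dependency-free sketch. The shape of the decollated output is an assumption modeled on MONAI's decollate behavior, not taken from this source:

```python
# Hypothetical engine.state.output after MONAI decollation: a list with one
# dict per batch item, where the batch loss is replicated into every item.
output = [
    {"loss": 0.25, "pred": [0.9], "label": [1.0]},
    {"loss": 0.25, "pred": [0.1], "label": [0.0]},
]

# The default transform, lambda x: x[0], hands the first item's dict to the
# handler, which then logs each scalar entry under its key.
default_transform = lambda x: x[0]
first_item = default_transform(output)

# A custom transform can instead pick out a single scalar, which the handler
# would log under tag_name:
loss_transform = lambda x: x[0]["loss"]
```

Non-scalar entries such as `pred` and `label` above are skipped by the handler with a warning, so only the replicated loss reaches Weights & Biases.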
Source code in wandb_addons/monai/stats_handler.py
class WandbStatsHandler:
    """
    `WandbStatsHandler` defines a set of Ignite event handlers for all the Weights & Biases logging
    logic. It can be used with any Ignite Engine (trainer, validator, or evaluator) and supports
    both epoch-level and iteration-level logging. The expected data sources are Ignite
    `engine.state.output` and `engine.state.metrics`.

    Default behaviors:
        - When EPOCH_COMPLETED, write each dictionary item in `engine.state.metrics` to
            Weights & Biases.
        - When ITERATION_COMPLETED, write each dictionary item in
            `self.output_transform(engine.state.output)` to Weights & Biases.

    **Usage:**

    ```python
    # WandbStatsHandler for logging training metrics and losses at
    # every iteration to Weights & Biases
    train_wandb_stats_handler = WandbStatsHandler(output_transform=lambda x: x)
    train_wandb_stats_handler.attach(trainer)

    # WandbStatsHandler for logging validation metrics at the end of
    # every epoch to Weights & Biases
    val_wandb_stats_handler = WandbStatsHandler(
        output_transform=lambda x: None,
        global_epoch_transform=lambda x: trainer.state.epoch,
    )
    val_wandb_stats_handler.attach(evaluator)
    ```

    ??? example "Example notebooks:"
        - [3D classification using MonAI](../examples/densenet_training_dict).
        - [3D segmentation using MonAI](../examples/unet_3d_segmentation).

    ??? note "Pull Request to add `WandbStatsHandler` to MonAI repository"

        There is an [open pull request](https://github.com/Project-MONAI/MONAI/pull/6305)
        to add `WandbStatsHandler` to [MonAI](https://github.com/Project-MONAI/MONAI).


    Args:
        iteration_log (bool): Whether to write data to Weights & Biases when iteration completed,
            default to `True`.
        epoch_log (bool): Whether to write data to Weights & Biases when epoch completed, default to
            `True`.
        epoch_event_writer (Optional[Callable[[Engine, Any], Any]]): Customized callable
            Weights & Biases writer for epoch level. Must accept the parameter "engine"; the
            default event writer is used if None.
        epoch_interval (int): The epoch interval at which the epoch_event_writer is called. Defaults
            to 1.
        iteration_event_writer (Optional[Callable[[Engine, Any], Any]]): Customized callable
            Weights & Biases writer for iteration level. Must accept the parameter "engine"; the
            default event writer is used if None.
        iteration_interval (int): The iteration interval at which the iteration_event_writer is
            called. Defaults to 1.
        output_transform (Callable): A callable that is used to transform the
            `ignite.engine.state.output` into a scalar to plot, or a dictionary of `{key: scalar}`. In
            the latter case, the output string will be formatted as key: value. By default this value
            plotting happens when every iteration completed. The default behavior is to print loss
            from output[0] as output is a decollated list and we replicated loss value for every item
            of the decollated list. `engine.state` and `output_transform` inherit from the ignite
            concept: https://pytorch.org/ignite/concepts.html#state, explanation and usage example are
            in the tutorial:
            https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
        global_epoch_transform (Callable): A callable that is used to customize global epoch number. For
            example, in evaluation, the evaluator engine might want to use trainer engines epoch number
            when plotting epoch vs metric curves.
        state_attributes (Optional[Sequence[str]]): Expected attributes from `engine.state`, if provided,
            will extract them when epoch completed.
        tag_name (str): When iteration output is a scalar, tag_name is used to plot, defaults to `'Loss'`.
    """

    def __init__(
        self,
        iteration_log: bool = True,
        epoch_log: bool = True,
        epoch_event_writer: Optional[Callable[[Engine, Any], Any]] = None,
        epoch_interval: int = 1,
        iteration_event_writer: Optional[Callable[[Engine, Any], Any]] = None,
        iteration_interval: int = 1,
        output_transform: Callable = lambda x: x[0],
        global_epoch_transform: Callable = lambda x: x,
        state_attributes: Optional[Sequence[str]] = None,
        tag_name: str = DEFAULT_TAG,
    ):
        if wandb.run is None:
            raise wandb.Error("You must call `wandb.init()` before WandbStatsHandler()")

        self.iteration_log = iteration_log
        self.epoch_log = epoch_log
        self.epoch_event_writer = epoch_event_writer
        self.epoch_interval = epoch_interval
        self.iteration_event_writer = iteration_event_writer
        self.iteration_interval = iteration_interval
        self.output_transform = output_transform
        self.global_epoch_transform = global_epoch_transform
        self.state_attributes = state_attributes
        self.tag_name = tag_name

    def attach(self, engine: Engine) -> None:
        """
        Register a set of Ignite Event-Handlers to a specified Ignite engine.

        Args:
            engine (ignite.engine.engine.Engine): Ignite Engine, it can be a trainer, validator
                or evaluator.
        """
        if self.iteration_log and not engine.has_event_handler(
            self.iteration_completed, Events.ITERATION_COMPLETED
        ):
            engine.add_event_handler(
                Events.ITERATION_COMPLETED(every=self.iteration_interval),
                self.iteration_completed,
            )
        if self.epoch_log and not engine.has_event_handler(
            self.epoch_completed, Events.EPOCH_COMPLETED
        ):
            engine.add_event_handler(
                Events.EPOCH_COMPLETED(every=self.epoch_interval), self.epoch_completed
            )

    def epoch_completed(self, engine: Engine) -> None:
        """
        Handler for train or validation/evaluation epoch completed Event. Write epoch level events
        to Weights & Biases, default values are from Ignite `engine.state.metrics` dict.

        Args:
            engine (ignite.engine.engine.Engine): Ignite Engine, it can be a trainer, validator
                or evaluator.
        """
        if self.epoch_event_writer is not None:
            self.epoch_event_writer(engine)
        else:
            self._default_epoch_writer(engine)

    def iteration_completed(self, engine: Engine) -> None:
        """
        Handler for train or validation/evaluation iteration completed Event. Write iteration level
        events to Weights & Biases, default values are from Ignite `engine.state.output`.

        Args:
            engine (ignite.engine.engine.Engine): Ignite Engine, it can be a trainer, validator
                or evaluator.
        """
        if self.iteration_event_writer is not None:
            self.iteration_event_writer(engine)
        else:
            self._default_iteration_writer(engine)

    def _default_epoch_writer(self, engine: Engine) -> None:
        """
        Execute epoch level event write operation. Default to write the values from Ignite
        `engine.state.metrics` dict and write the values of specified attributes of `engine.state`
        to [Weights & Biases](https://wandb.ai/site).

        Args:
            engine (ignite.engine.engine.Engine): Ignite Engine, it can be a trainer, validator
                or evaluator.
        """
        summary_dict = engine.state.metrics

        for key, value in summary_dict.items():
            if is_scalar(value):
                value = value.item() if isinstance(value, torch.Tensor) else value
                wandb.log({key: value})

        if self.state_attributes is not None:
            for attr in self.state_attributes:
                value = getattr(engine.state, attr, None)
                value = value.item() if isinstance(value, torch.Tensor) else value
                wandb.log({attr: value})

    def _default_iteration_writer(self, engine: Engine) -> None:
        """
        Execute iteration level event write operation based on Ignite `engine.state.output` data.
        Extract the values from `self.output_transform(engine.state.output)`. Since
        `engine.state.output` is a decollated list and we replicated the loss value for every item
        of the decollated list, the default behavior is to track the loss from `output[0]`.

        Args:
            engine (ignite.engine.engine.Engine): Ignite Engine, it can be a trainer, validator
                or evaluator.
        """
        loss = self.output_transform(engine.state.output)
        if loss is None:
            return  # do nothing if output is empty
        log_dict = dict()
        if isinstance(loss, dict):
            for key, value in loss.items():
                if not is_scalar(value):
                    warnings.warn(
                        "ignoring non-scalar output in WandbStatsHandler,"
                        " make sure `output_transform(engine.state.output)` returns"
                        " a scalar or dictionary of key and scalar pairs to avoid this warning."
                        " {}:{}".format(key, type(value))
                    )
                    continue  # not plot multi dimensional output
                log_dict[key] = (
                    value.item() if isinstance(value, torch.Tensor) else value
                )
        elif is_scalar(loss):  # not printing multi dimensional output
            log_dict[self.tag_name] = (
                loss.item() if isinstance(loss, torch.Tensor) else loss
            )
        else:
            warnings.warn(
                "ignoring non-scalar output in WandbStatsHandler,"
                " make sure `output_transform(engine.state.output)` returns"
                " a scalar or a dictionary of key and scalar pairs to avoid this warning."
                " {}".format(type(loss))
            )

        wandb.log(log_dict)

    def close(self):
        """Close `WandbStatsHandler`"""
        wandb.finish()
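The scalar filtering in `_default_iteration_writer` above can be mirrored without the torch and wandb dependencies. This sketch reproduces the branching on the transformed output; the `is_scalar` helper is a simplified stand-in for MONAI's utility of the same name:

```python
import warnings

def build_log_dict(loss, tag_name="Loss"):
    """Mirror of _default_iteration_writer's filtering, minus torch/wandb."""

    def is_scalar(value):
        # Simplified stand-in for monai.utils.is_scalar.
        return isinstance(value, (int, float))

    log_dict = {}
    if isinstance(loss, dict):
        # Keep only scalar entries; warn on everything else.
        for key, value in loss.items():
            if not is_scalar(value):
                warnings.warn(f"ignoring non-scalar output {key}:{type(value)}")
                continue
            log_dict[key] = value
    elif is_scalar(loss):
        # A bare scalar is logged under the configured tag name.
        log_dict[tag_name] = loss
    else:
        warnings.warn(f"ignoring non-scalar output {type(loss)}")
    return log_dict
```

For example, `build_log_dict({"loss": 0.5, "pred": [1, 2]})` keeps only the `loss` entry, and `build_log_dict(0.3)` logs the value under `"Loss"`.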

attach(engine)

Register a set of Ignite Event-Handlers to a specified Ignite engine.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `engine` | `Engine` | Ignite Engine; it can be a trainer, validator, or evaluator. | *required* |
Source code in wandb_addons/monai/stats_handler.py
def attach(self, engine: Engine) -> None:
    """
    Register a set of Ignite Event-Handlers to a specified Ignite engine.

    Args:
        engine (ignite.engine.engine.Engine): Ignite Engine, it can be a trainer, validator
            or evaluator.
    """
    if self.iteration_log and not engine.has_event_handler(
        self.iteration_completed, Events.ITERATION_COMPLETED
    ):
        engine.add_event_handler(
            Events.ITERATION_COMPLETED(every=self.iteration_interval),
            self.iteration_completed,
        )
    if self.epoch_log and not engine.has_event_handler(
        self.epoch_completed, Events.EPOCH_COMPLETED
    ):
        engine.add_event_handler(
            Events.EPOCH_COMPLETED(every=self.epoch_interval), self.epoch_completed
        )
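The `every=` arguments passed in `attach` use Ignite's event filtering. As a sketch of the semantics (assuming Ignite's documented behavior), an event registered with `every=n` fires when the 1-based event counter is a multiple of n:

```python
def fires(event_count, every):
    # Ignite's Events.ITERATION_COMPLETED(every=n) and EPOCH_COMPLETED(every=n)
    # filters fire when the running 1-based event count is a multiple of n.
    return event_count % every == 0

# With iteration_interval=3, iterations 3, 6, 9, ... are logged:
logged = [i for i in range(1, 10) if fires(i, 3)]
```

So `iteration_interval` and `epoch_interval` thin out logging rather than shift it: the first write happens at step `n`, not step 1.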

close()

Close WandbStatsHandler

Source code in wandb_addons/monai/stats_handler.py
def close(self):
    """Close `WandbStatsHandler`"""
    wandb.finish()

epoch_completed(engine)

Handler for the train or validation/evaluation epoch-completed event. Write epoch-level events to Weights & Biases; default values come from the Ignite engine.state.metrics dict.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `engine` | `Engine` | Ignite Engine; it can be a trainer, validator, or evaluator. | *required* |
Source code in wandb_addons/monai/stats_handler.py
def epoch_completed(self, engine: Engine) -> None:
    """
    Handler for train or validation/evaluation epoch completed Event. Write epoch level events
    to Weights & Biases, default values are from Ignite `engine.state.metrics` dict.

    Args:
        engine (ignite.engine.engine.Engine): Ignite Engine, it can be a trainer, validator
            or evaluator.
    """
    if self.epoch_event_writer is not None:
        self.epoch_event_writer(engine)
    else:
        self._default_epoch_writer(engine)

iteration_completed(engine)

Handler for the train or validation/evaluation iteration-completed event. Write iteration-level events to Weights & Biases; default values come from Ignite engine.state.output.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `engine` | `Engine` | Ignite Engine; it can be a trainer, validator, or evaluator. | *required* |
Source code in wandb_addons/monai/stats_handler.py
def iteration_completed(self, engine: Engine) -> None:
    """
    Handler for train or validation/evaluation iteration completed Event. Write iteration level
    events to Weights & Biases, default values are from Ignite `engine.state.output`.

    Args:
        engine (ignite.engine.engine.Engine): Ignite Engine, it can be a trainer, validator
            or evaluator.
    """
    if self.iteration_event_writer is not None:
        self.iteration_event_writer(engine)
    else:
        self._default_iteration_writer(engine)

WandbModelCheckpointSaver

Bases: BaseSaveHandler

WandbModelCheckpointSaver is a save handler for PyTorch Ignite that saves model checkpoints as Weights & Biases Artifacts.

Usage:

```python
from wandb_addons.monai import WandbModelCheckpointSaver

checkpoint_handler = Checkpoint(
    {"model": model, "optimizer": optimizer},
    WandbModelCheckpointSaver(),
    n_saved=1,
    filename_prefix="best_checkpoint",
    score_name=metric_name,
    global_step_transform=global_step_from_engine(trainer)
)
evaluator.add_event_handler(Events.COMPLETED, checkpoint_handler)
```
Source code in wandb_addons/monai/checkpoint_handler.py
class WandbModelCheckpointSaver(BaseSaveHandler):
    """`WandbModelCheckpointSaver` is a save handler for PyTorch Ignite that saves model checkpoints as
    [Weights & Biases Artifacts](https://docs.wandb.ai/guides/artifacts).

    Usage:

    ```python
    from wandb_addons.monai import WandbModelCheckpointSaver

    checkpoint_handler = Checkpoint(
        {"model": model, "optimizer": optimizer},
        WandbModelCheckpointSaver(),
        n_saved=1,
        filename_prefix="best_checkpoint",
        score_name=metric_name,
        global_step_transform=global_step_from_engine(trainer)
    )
    evaluator.add_event_handler(Events.COMPLETED, checkpoint_handler)
    ```
    """

    @one_rank_only()
    def __init__(self):
        if wandb.run is None:
            raise wandb.Error(
                "You must call `wandb.init()` before `WandbModelCheckpointSaver()`"
            )

        self.checkpoint_dir = tempfile.mkdtemp()

    @one_rank_only()
    def __call__(self, checkpoint: Mapping, filename: Union[str, Path]):
        checkpoint_path = os.path.join(self.checkpoint_dir, filename)
        torch.save(checkpoint, checkpoint_path)

        artifact = wandb.Artifact(f"{wandb.run.id}-checkpoint", type="model")

        if os.path.isfile(checkpoint_path):
            artifact.add_file(checkpoint_path)
        elif os.path.isdir(checkpoint_path):
            artifact.add_dir(checkpoint_path)
        else:
            raise wandb.Error(
                f"Unable to add local checkpoint path {checkpoint_path} to artifact"
            )

        wandb.log_artifact(artifact)

    @one_rank_only()
    def remove(self, filename):
        # Resolve against the temporary directory the checkpoint was saved to.
        path = os.path.join(self.checkpoint_dir, filename)
        if os.path.isdir(path):
            shutil.rmtree(path)
        elif os.path.isfile(path):
            os.remove(path)