Custom Callbacks

Custom callbacks extend PyTorch Lightning's Callback and BasePredictionWriter interfaces to add visualisation, metrics reporting, loss normalisation, and polygonisation steps at specific points in the training or prediction loop.

Callbacks are split across three modules:

| Module | Path |
| --- | --- |
| Image visualisation | pytorch_segmentation_models_trainer.custom_callbacks.image_callbacks |
| Metrics | pytorch_segmentation_models_trainer.custom_callbacks.metrics_callbacks |
| Training utilities | pytorch_segmentation_models_trainer.custom_callbacks.training_callbacks |

Adding Callbacks to Your Config

Callbacks are listed under the callbacks: key in your training YAML. Hydra instantiates each entry using its _target_ field.

trainer:
  _target_: pytorch_lightning.Trainer
  max_epochs: 100
  callbacks:
    - _target_: pytorch_segmentation_models_trainer.custom_callbacks.training_callbacks.WarmupCallback
      warmup_epochs: 5
    - _target_: pytorch_segmentation_models_trainer.custom_callbacks.image_callbacks.ImageSegmentationResultCallback
      n_samples: 4
      log_every_k_epochs: 5
    - _target_: pytorch_segmentation_models_trainer.custom_callbacks.metrics_callbacks.ConfusionMatrixCallback
      num_classes: 6
      class_names: [background, building, road, water, vegetation, bare_soil]
      log_every_n_epochs: 10
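For intuition, a simplified, stdlib-only sketch of what Hydra's `_target_` resolution does (not Hydra's actual implementation, which also handles recursion, positional args, and interpolation):

```python
import importlib

def instantiate(cfg: dict):
    """Simplified stand-in for hydra.utils.instantiate: resolve the
    dotted _target_ path to a callable, then call it with the
    remaining config keys as keyword arguments."""
    module_path, _, attr = cfg["_target_"].rpartition(".")
    target = getattr(importlib.import_module(module_path), attr)
    kwargs = {k: v for k, v in cfg.items() if k != "_target_"}
    return target(**kwargs)

# Demonstrated with a stdlib target instead of a callback class:
counter = instantiate({"_target_": "collections.Counter", "a": 2})
# counter == Counter({'a': 2})
```

Each entry under `callbacks:` is resolved this way and the resulting objects are passed to the `Trainer`.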

Image Visualisation Callbacks

ImageSegmentationResultCallback

Import path

from pytorch_segmentation_models_trainer.custom_callbacks.image_callbacks import ImageSegmentationResultCallback

After each validation epoch, runs the model on a small sample of validation images, generates side-by-side ground-truth / prediction plots, saves them to disk, and logs them to TensorBoard.

Fires on

  • on_sanity_check_end — resolves and creates the output directory.
  • on_validation_epoch_end — generates and logs visualisations (rank 0 only).

Constructor Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| n_samples | int | None | Number of images to visualise. Defaults to the validation batch size when None. |
| output_path | str | None | Directory to write PNG files to. Defaults to <log_dir>/image_logs. |
| normalized_input | bool | True | Whether the input images are normalised; controls denormalisation before plotting. |
| norm_params | dict | {} | kwargs passed to denormalize_np_array(). |
| log_every_k_epochs | int | 1 | Only generate visualisations every k epochs. |
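For intuition, a minimal sketch of what a channel-wise denormalisation like denormalize_np_array() typically computes (the real function's signature and defaults may differ):

```python
import numpy as np

def denormalize(img: np.ndarray, mean, std) -> np.ndarray:
    """Invert the usual (x - mean) / std normalisation channel-wise,
    so images can be plotted in their original value range."""
    mean = np.asarray(mean, dtype=img.dtype)
    std = np.asarray(std, dtype=img.dtype)
    return img * std + mean

# Round-trip with ImageNet-style statistics:
x = np.array([0.2, 0.5, 0.8])
norm = (x - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])
restored = denormalize(norm, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# restored is approximately equal to x
```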

Example Config

- _target_: pytorch_segmentation_models_trainer.custom_callbacks.image_callbacks.ImageSegmentationResultCallback
  n_samples: 8
  log_every_k_epochs: 5
  normalized_input: true

FrameFieldResultCallback

Import path

from pytorch_segmentation_models_trainer.custom_callbacks.image_callbacks import FrameFieldResultCallback

Subclass of ImageSegmentationResultCallback tailored for frame-field models. Reads "gt_polygons_image" from batches (a 2-channel ground-truth with polygon and boundary channels) and plots both channels alongside their predictions.

Fires on

Same as ImageSegmentationResultCallback.

Example Config

- _target_: pytorch_segmentation_models_trainer.custom_callbacks.image_callbacks.FrameFieldResultCallback
  n_samples: 4
  log_every_k_epochs: 10

FrameFieldOverlayedResultCallback

Import path

from pytorch_segmentation_models_trainer.custom_callbacks.image_callbacks import FrameFieldOverlayedResultCallback

Subclass of ImageSegmentationResultCallback that generates an overlayed visualisation combining the segmentation mask and the cross-field orientation on top of the original image. Logs directly to TensorBoard without saving files to disk.

Fires on

  • on_validation_epoch_end

Example Config

- _target_: pytorch_segmentation_models_trainer.custom_callbacks.image_callbacks.FrameFieldOverlayedResultCallback
  n_samples: 4

ObjectDetectionResultCallback

Import path

from pytorch_segmentation_models_trainer.custom_callbacks.image_callbacks import ObjectDetectionResultCallback

Subclass of ImageSegmentationResultCallback for object detection models. Draws predicted bounding boxes (filtered by a score threshold) on the input image and logs the result to TensorBoard.

Additional Constructor Parameter

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| threshold | float | 0.5 | Minimum confidence score for a box to be drawn. |
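The filtering step amounts to keeping boxes whose score clears the threshold; a minimal illustration (the callback itself operates on torchvision-style prediction dicts):

```python
def filter_boxes(boxes, scores, threshold=0.5):
    """Keep only the boxes whose confidence meets the threshold."""
    return [box for box, score in zip(boxes, scores) if score >= threshold]

kept = filter_boxes([[0, 0, 10, 10], [5, 5, 20, 20]], [0.9, 0.3], threshold=0.4)
# kept == [[0, 0, 10, 10]]
```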

Fires on

  • on_validation_epoch_end

Example Config

- _target_: pytorch_segmentation_models_trainer.custom_callbacks.image_callbacks.ObjectDetectionResultCallback
  n_samples: 8
  threshold: 0.4

PolygonRNNResultCallback

Import path

from pytorch_segmentation_models_trainer.custom_callbacks.image_callbacks import PolygonRNNResultCallback

Subclass of ImageSegmentationResultCallback for PolygonRNN models. Visualises ground-truth polygon vertices alongside predicted vertices overlaid on the source image.

Fires on

  • on_validation_epoch_end

Example Config

- _target_: pytorch_segmentation_models_trainer.custom_callbacks.image_callbacks.PolygonRNNResultCallback
  n_samples: 8

ModPolyMapperResultCallback

Import path

from pytorch_segmentation_models_trainer.custom_callbacks.image_callbacks import ModPolyMapperResultCallback

Subclass of PolygonRNNResultCallback for the ModPolyMapper architecture (detection + PolygonRNN). Handles the combined detection + polygon output format.

Additional Constructor Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| threshold | float | 0.5 | Detection score threshold. |
| show_label_scores | bool | False | Whether to overlay detection scores on the visualisation. |
| n_samples | int | 16 | Defaults to 16 (overrides parent default). |

Fires on

  • on_validation_epoch_end

Example Config

- _target_: pytorch_segmentation_models_trainer.custom_callbacks.image_callbacks.ModPolyMapperResultCallback
  n_samples: 16
  threshold: 0.5
  show_label_scores: true

Metrics Callbacks

ConfusionMatrixCallback

Import path

from pytorch_segmentation_models_trainer.custom_callbacks.metrics_callbacks import ConfusionMatrixCallback

Builds a multiclass confusion matrix across the full validation set using torchmetrics.ConfusionMatrix (GPU-efficient). At the end of every log_every_n_epochs epochs, the matrix is plotted as a seaborn heatmap, saved to disk, and logged to TensorBoard. Per-class accuracy and precision scalars are also logged.

Fires on

  • on_sanity_check_end — resolves the output directory.
  • on_validation_epoch_start — resets the confusion matrix accumulator.
  • on_validation_batch_end — updates the accumulator with batch predictions.
  • on_validation_epoch_end — computes, plots, and logs the matrix (rank 0 only).
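The reset/update/compute cycle behind these hooks can be sketched without torchmetrics; rows are ground truth and columns are predictions, matching the torchmetrics convention:

```python
def update_confusion_matrix(cm, targets, preds):
    """Accumulate one batch into a num_classes x num_classes matrix
    (rows: ground truth, columns: predictions)."""
    for t, p in zip(targets, preds):
        cm[t][p] += 1
    return cm

num_classes = 3
# on_validation_epoch_start: reset the accumulator
cm = [[0] * num_classes for _ in range(num_classes)]
# on_validation_batch_end: update with each batch
update_confusion_matrix(cm, targets=[0, 1, 2], preds=[0, 1, 1])
update_confusion_matrix(cm, targets=[2, 2], preds=[2, 2])
# on_validation_epoch_end: compute, plot, log.
# Row-normalising ("true" mode) divides each row by its sum.
```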

Constructor Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| num_classes | int | required | Total number of classes. |
| class_names | list[str] | None | Human-readable class labels. Defaults to ["Class 0", "Class 1", ...]. |
| normalize | str | "true" | Normalisation mode: "true" (row-normalised), "pred" (column-normalised), "all" (global), or None (raw counts). |
| log_every_n_epochs | int | 10 | Plot frequency. |
| figsize | tuple | (12, 10) | Matplotlib figure size. |
| output_path | str | None | Save directory. Defaults to <log_dir>/confusion_matrices. |

Example Config

- _target_: pytorch_segmentation_models_trainer.custom_callbacks.metrics_callbacks.ConfusionMatrixCallback
  num_classes: 6
  class_names: [background, building, road, water, vegetation, bare_soil]
  normalize: "true"  # quoted: the parameter is a string, not a YAML boolean
  log_every_n_epochs: 10
  figsize: [14, 12]

ClassificationReportCallback

Import path

from pytorch_segmentation_models_trainer.custom_callbacks.metrics_callbacks import ClassificationReportCallback

Generates a sklearn.metrics.classification_report (precision, recall, F1, support per class) at the end of every log_every_n_epochs epochs and writes it to a timestamped .txt file.

Fires on

  • on_sanity_check_end — resolves the output directory.
  • on_validation_batch_end — accumulates flattened predictions and targets.
  • on_validation_epoch_end — computes and saves the report (rank 0 only).
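The per-class numbers in the report follow the standard definitions; an illustrative, dependency-free sketch of what sklearn computes:

```python
def per_class_prf(targets, preds, num_classes):
    """Precision, recall, F1, and support per class, as reported by
    sklearn.metrics.classification_report (illustrative sketch)."""
    report = {}
    for c in range(num_classes):
        tp = sum(1 for t, p in zip(targets, preds) if t == c and p == c)
        fp = sum(1 for t, p in zip(targets, preds) if t != c and p == c)
        fn = sum(1 for t, p in zip(targets, preds) if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        report[c] = {"precision": precision, "recall": recall,
                     "f1": f1, "support": tp + fn}
    return report

rep = per_class_prf(targets=[0, 0, 1, 1], preds=[0, 1, 1, 1], num_classes=2)
# rep[0] == {"precision": 1.0, "recall": 0.5, ...}
```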

Constructor Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| num_classes | int | required | Number of classes. |
| class_names | list[str] | None | Class labels. Defaults to index strings. |
| log_every_n_epochs | int | 10 | Report frequency. |
| output_path | str | None | Save directory. Defaults to <log_dir>/classification_reports. |

Example Config

- _target_: pytorch_segmentation_models_trainer.custom_callbacks.metrics_callbacks.ClassificationReportCallback
  num_classes: 6
  class_names: [background, building, road, water, vegetation, bare_soil]
  log_every_n_epochs: 10

Training Utility Callbacks

EMACallback

Import path

from pytorch_segmentation_models_trainer.custom_callbacks.training_callbacks import EMACallback

Maintains an Exponential Moving Average (EMA) of model weights alongside the online (SGD-updated) weights. The EMA model tends to be smoother and generalises better than the checkpoint from any single epoch.

How it works:

At each optimizer step the shadow weights are updated as:

shadow = decay_eff × shadow + (1 − decay_eff) × param

where decay_eff = min(decay, (step+1) / (step+10)) applies a warmup that prevents early random-initialisation weights from contaminating the EMA. During validation the shadow weights are swapped in so that metrics reflect the EMA model; after validation the online weights are restored so training continues normally. The EMA state is saved inside every checkpoint automatically.
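The update rule above as a single function, operating on scalars for clarity (the callback applies it tensor-wise to every parameter):

```python
def ema_update(shadow: float, param: float, step: int, decay: float = 0.999) -> float:
    """One EMA step with the warmup schedule described above."""
    decay_eff = min(decay, (step + 1) / (step + 10))
    return decay_eff * shadow + (1 - decay_eff) * param

# At step 0, decay_eff = min(0.999, 1/10) = 0.1, so the live weights
# dominate; as training progresses decay_eff approaches 0.999.
early = ema_update(shadow=0.0, param=1.0, step=0)  # 0.9
```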

Loading EMA checkpoints

Because EMACallback injects the EMA weights into state_dict at save time, loading a checkpoint produces the EMA model automatically. load_from_checkpoint must be called with strict=False (the framework already does this in predict.py).

Fires on

  • on_fit_start — initialises shadow weights from the current model parameters.
  • on_train_batch_end — updates shadow weights (only when the optimizer has actually stepped).
  • on_validation_epoch_start — swaps in shadow weights.
  • on_validation_epoch_end — restores online weights.
  • on_save_checkpoint — injects EMA weights into the checkpoint state dict.

Constructor Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| decay | float | 0.999 | Target EMA decay. Higher values → slower-moving average. Typical range: 0.99 to 0.9999. |

Example Config

callbacks:
  - _target_: pytorch_segmentation_models_trainer.custom_callbacks.training_callbacks.EMACallback
    decay: 0.999

  - _target_: pytorch_lightning.callbacks.ModelCheckpoint
    monitor: loss/val
    mode: min
    save_top_k: 3

MixStyleCallback

Import path

from pytorch_segmentation_models_trainer.custom_callbacks.training_callbacks import MixStyleCallback

Applies MixStyle (Zhou et al., ICLR 2021) as a forward hook on the encoder. MixStyle mixes feature-level statistics (mean and standard deviation) between random pairs of instances in a batch, creating implicit domain augmentation at zero parameter cost. This improves generalisation to unseen acquisition conditions, sensors, or temporal shifts.

How it works:

At each forward pass during training, with probability p, the feature map output of each hooked encoder stage is normalised to zero mean / unit variance and then re-scaled with a Beta-distributed convex combination of the original instance's and a random other instance's statistics:

mu_mix  = λ × mu_i  + (1−λ) × mu_j
sig_mix = λ × sig_i + (1−λ) × sig_j

The hook is only active during training (on_train_batch_start / on_train_batch_end guard). Hooks are registered on encoder.model.stages (SMP TimmUniversalEncoder) or encoder.stages; a warning is emitted if neither attribute is found and MixStyle is silently disabled.
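A NumPy sketch of the mixing step (the real hook operates on torch tensors and draws `lam` from Beta(alpha, alpha) and `perm` as a random batch permutation; both are fixed here for clarity):

```python
import numpy as np

def mixstyle(x: np.ndarray, lam: float, perm: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Mix per-instance channel statistics between batch elements.
    x has shape (N, C, H, W); statistics are computed per instance
    and channel over the spatial dimensions."""
    mu = x.mean(axis=(2, 3), keepdims=True)
    sig = x.std(axis=(2, 3), keepdims=True) + eps
    x_norm = (x - mu) / sig            # zero mean / unit variance per (n, c)
    mu_mix = lam * mu + (1 - lam) * mu[perm]
    sig_mix = lam * sig + (1 - lam) * sig[perm]
    return x_norm * sig_mix + mu_mix   # re-style with the mixed statistics

x = np.random.rand(4, 8, 16, 16)
out = mixstyle(x, lam=0.7, perm=np.array([1, 0, 3, 2]))
```

With `lam = 1` the original statistics are restored and the input passes through unchanged, which is why small `alpha` (mixing weights near 0 or 1) stays close to the original styles.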

Fires on

  • on_fit_start — registers forward hooks on the configured encoder stages.
  • on_train_batch_start — enables the hook.
  • on_train_batch_end — disables the hook.
  • on_validation_epoch_start — ensures the hook is disabled during validation.
  • on_fit_end — removes all hooks.

Constructor Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| p | float | 0.5 | Probability of applying MixStyle on a given forward pass. |
| alpha | float | 0.1 | Beta distribution parameter. Smaller values produce mixing closer to the original style. |
| stages | list[int] | [0, 1] | Encoder stage indices to hook. Early stages carry more domain/style information. |

Choosing stages

Hook only the earliest stages ([0, 1] or just [0]). Later stages encode semantic content that should not be mixed. For ResNet-style encoders, stages 0 and 1 correspond to the first two residual groups.

Example Config

callbacks:
callbacks:
  - _target_: pytorch_segmentation_models_trainer.custom_callbacks.training_callbacks.MixStyleCallback
    p: 0.5
    alpha: 0.1
    stages: [0, 1]

WarmupCallback

Import path

from pytorch_segmentation_models_trainer.custom_callbacks.training_callbacks import WarmupCallback

Freezes the model encoder for the first warmup_epochs training epochs to let the decoder head stabilise before end-to-end fine-tuning begins. Calls pl_module.set_encoder_trainable(trainable) on the Lightning module, so the module must implement that method.
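The module-side contract is small; a hypothetical sketch with stand-in parameter objects so it runs without torch (in practice the method flips `requires_grad` on the encoder's torch parameters):

```python
from types import SimpleNamespace

class MySegmentationModule:
    """Sketch of the contract WarmupCallback relies on; in practice
    this is the pl.LightningModule wrapping the segmentation model."""
    def __init__(self, encoder_parameters):
        self._encoder_parameters = encoder_parameters  # torch parameters in practice

    def set_encoder_trainable(self, trainable: bool) -> None:
        for param in self._encoder_parameters:
            param.requires_grad = trainable

# Stand-in "parameters" so the sketch runs without torch:
params = [SimpleNamespace(requires_grad=True) for _ in range(3)]
module = MySegmentationModule(params)
module.set_encoder_trainable(False)  # frozen during warmup
```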

Fires on

  • on_fit_start — checks whether warmup has already elapsed (for resumed training).
  • on_train_epoch_start — freezes encoder weights if still in warmup.
  • on_train_epoch_end — unfreezes encoder weights once warmup is complete.

Constructor Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| warmup_epochs | int | 2 | Number of epochs during which the encoder is frozen. |

Example Config

- _target_: pytorch_segmentation_models_trainer.custom_callbacks.training_callbacks.WarmupCallback
  warmup_epochs: 5

FrameFieldOnlyCrossfieldWarmupCallback

Import path

from pytorch_segmentation_models_trainer.custom_callbacks.training_callbacks import FrameFieldOnlyCrossfieldWarmupCallback

Variant of WarmupCallback for frame-field models. During warmup, all weights except the cross-field head are frozen by calling pl_module.set_all_but_crossfield_trainable(trainable). This allows the cross-field head to initialise before full model training.

Fires on

Same as WarmupCallback.

Constructor Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| warmup_epochs | int | 2 | Warmup duration in epochs. |

Example Config

- _target_: pytorch_segmentation_models_trainer.custom_callbacks.training_callbacks.FrameFieldOnlyCrossfieldWarmupCallback
  warmup_epochs: 3

ComputeWeightNormLossesCallback

Import path

from pytorch_segmentation_models_trainer.custom_callbacks.training_callbacks import ComputeWeightNormLossesCallback

General-purpose callback for computing loss normalisation weights before training starts. Designed for models that use a compound (multi-term) loss where each term needs to be normalised to a comparable scale. Runs a partial forward pass over the training set on rank 0, then syncs normalisation values across DDP ranks.

The callback is a no-op if:

  • Normalisation has already been computed.
  • The model does not have loss_params in its config.
  • The compound loss does not have normalization_params.
  • The loss norm has already been updated.

Fires on

  • on_train_start — computes normalisation (rank 0 only).

Constructor Parameters

This callback takes no constructor arguments.

Example Config

- _target_: pytorch_segmentation_models_trainer.custom_callbacks.training_callbacks.ComputeWeightNormLossesCallback

FrameFieldComputeWeightNormLossesCallback

Import path

from pytorch_segmentation_models_trainer.custom_callbacks.training_callbacks import FrameFieldComputeWeightNormLossesCallback

Frame-field-specific variant of ComputeWeightNormLossesCallback. Uses the multiloss.normalization_params.min_samples and max_samples config fields to determine the number of batches, then calls pl_module.compute_loss_norms() directly.

Fires on

  • on_fit_start — computes and syncs loss normalisation.

Constructor Parameters

This callback takes no constructor arguments.

Example Config

- _target_: pytorch_segmentation_models_trainer.custom_callbacks.training_callbacks.FrameFieldComputeWeightNormLossesCallback

FrameFieldPolygonizerCallback

Import path

from pytorch_segmentation_models_trainer.custom_callbacks.training_callbacks import FrameFieldPolygonizerCallback

A BasePredictionWriter callback that runs polygonisation on each prediction batch during trainer.predict(). Instantiates the polygonizer from the model config and processes (seg, crossfield) output pairs in a thread pool to avoid blocking the GPU.
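The off-GPU-thread pattern can be sketched with the stdlib; `polygonize` here is a hypothetical stand-in for the configured polygonizer:

```python
from concurrent.futures import ThreadPoolExecutor

def polygonize(seg_crossfield_pair):
    """Hypothetical stand-in for the configured polygonizer."""
    seg, crossfield = seg_crossfield_pair
    return {"seg": seg, "crossfield": crossfield, "polygons": []}

def process_batch(pairs, max_workers=4):
    """Polygonise each (seg, crossfield) pair in worker threads so the
    GPU prediction loop is not blocked; map preserves input order."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(polygonize, pairs))

results = process_batch([("s0", "c0"), ("s1", "c1")])
```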

Fires on

  • on_predict_batch_end

Constructor Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| write_interval | str | "batch" | Lightning write interval ("batch" or "epoch"). |

Example Config

- _target_: pytorch_segmentation_models_trainer.custom_callbacks.training_callbacks.FrameFieldPolygonizerCallback
  write_interval: batch

ActiveSkeletonsPolygonizerCallback

Import path

from pytorch_segmentation_models_trainer.custom_callbacks.training_callbacks import ActiveSkeletonsPolygonizerCallback

BasePredictionWriter callback that runs the Active Skeletons polygonisation algorithm on each prediction batch. Falls back to per-image processing when the batch-level call raises an exception, skipping only images that continue to fail.

Fires on

  • on_predict_batch_end

Constructor Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| write_interval | str | "batch" | Lightning write interval. |

Example Config

- _target_: pytorch_segmentation_models_trainer.custom_callbacks.training_callbacks.ActiveSkeletonsPolygonizerCallback
  write_interval: batch

ModPolymapperPolygonizerCallback

Import path

from pytorch_segmentation_models_trainer.custom_callbacks.training_callbacks import ModPolymapperPolygonizerCallback

BasePredictionWriter callback for the ModPolyMapper prediction pipeline. Processes detection output through a PolygonRNN polygonizer in parallel threads. Optionally reprojects output polygon coordinates to world/CRS coordinates using the source raster profile.
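Pixel-to-world conversion applies the raster's affine transform; a sketch using the rasterio-style `(a, b, c, d, e, f)` coefficient convention (`e` is negative for north-up rasters):

```python
def pixel_to_world(col: float, row: float, transform) -> tuple:
    """Map a pixel (col, row) to world coordinates with an affine
    transform (a, b, c, d, e, f): x = a*col + b*row + c,
    y = d*col + e*row + f."""
    a, b, c, d, e, f = transform
    return (a * col + b * row + c, d * col + e * row + f)

# 1 m pixels, origin at (500000, 4600000), north-up raster:
x, y = pixel_to_world(10, 20, (1.0, 0.0, 500000.0, 0.0, -1.0, 4600000.0))
# x == 500010.0, y == 4599980.0
```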

Fires on

  • on_predict_batch_end

Constructor Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| convert_output_to_world_coords | bool | True | When True, reads the raster CRS profile and reprojects polygon coordinates. |
| write_interval | str | "batch" | Lightning write interval. |

Example Config

- _target_: pytorch_segmentation_models_trainer.custom_callbacks.training_callbacks.ModPolymapperPolygonizerCallback
  convert_output_to_world_coords: true
  write_interval: batch