
Loss Functions

This page is the API reference for all loss classes in pytorch_segmentation_models_trainer.

Base classes and segmentation/frame-field losses live in:

from pytorch_segmentation_models_trainer.custom_losses.base_loss import <ClassName>

The LossWrapper and builder utilities live in:

from pytorch_segmentation_models_trainer.custom_losses.loss_builder import LossWrapper

Loss

from pytorch_segmentation_models_trainer.custom_losses.base_loss import Loss

Abstract base class for all custom loss functions. Extends torch.nn.Module. Implements automatic loss normalization via a running average meter, and supports distributed training synchronization.

Constructor

Loss(name: str)
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| name | str | required | Human-readable name used for logging and __repr__. |

forward Signature

def forward(self, pred_batch, gt_batch, normalize: bool = True) -> Tuple[torch.Tensor, Dict]

Calls self.compute(pred_batch, gt_batch), optionally divides the result by self.norm[0] (the running average of past loss values), then returns a (loss_value, extra_info) tuple.

  • When normalize=True (default), the raw loss is divided by the running norm. The norm must be > 1e-9; an AssertionError is raised otherwise.
  • extra_info is a dict populated by self.compute() with intermediate tensors useful for visualization.

Normalization Methods

| Method | Description |
| --- | --- |
| reset_norm() | Resets the running average meter to an initial value of 1. |
| update_norm(pred_batch, gt_batch, nums) | Runs a forward pass (without gradient) and updates the running average with the result. Call this before training begins to calibrate the norm. |
| sync(world_size) | Performs an all_reduce to average self.norm across all GPUs when using distributed training. |

Abstract Method

def compute(self, pred_batch, gt_batch) -> torch.Tensor

Subclasses must implement this method. It should return a scalar tensor representing the raw (unnormalized) loss value.
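A minimal sketch of the forward/compute contract described above. The class and attribute names here are illustrative, and the real base class additionally implements update_norm and sync:

```python
import torch
import torch.nn as nn
from typing import Dict, Tuple


class SketchLoss(nn.Module):
    """Simplified sketch of the Loss base class; the library's running-average
    meter and distributed synchronization are omitted."""

    def __init__(self, name: str):
        super().__init__()
        self.name = name
        self.norm = [1.0]        # running average of past raw loss values
        self.extra_info: Dict = {}

    def reset_norm(self) -> None:
        self.norm = [1.0]

    def compute(self, pred_batch, gt_batch) -> torch.Tensor:
        # Subclasses return a scalar tensor with the raw (unnormalized) loss,
        # optionally stashing intermediate tensors in self.extra_info.
        raise NotImplementedError

    def forward(
        self, pred_batch, gt_batch, normalize: bool = True
    ) -> Tuple[torch.Tensor, Dict]:
        loss = self.compute(pred_batch, gt_batch)
        if normalize:
            assert self.norm[0] > 1e-9, "norm must be calibrated before normalizing"
            loss = loss / self.norm[0]
        return loss, self.extra_info


class MeanAbsLoss(SketchLoss):
    """Example subclass: mean absolute difference between prediction and GT."""

    def compute(self, pred_batch, gt_batch) -> torch.Tensor:
        return torch.mean(torch.abs(pred_batch["seg"] - gt_batch["gt_polygons_image"]))
```

With normalize=True the returned value is the raw loss divided by the running norm, so once the norm is calibrated all losses fluctuate around 1 regardless of their natural scale.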

YAML Config Snippet

loss:
  _target_: pytorch_segmentation_models_trainer.custom_losses.base_loss.SegLoss
  name: my_loss
  gt_channel_selector: 0

LossWrapper

from pytorch_segmentation_models_trainer.custom_losses.loss_builder import LossWrapper

Wraps any torch.nn.Module loss function to make it compatible with MultiLoss. This enables standard PyTorch losses (e.g. nn.BCELoss, nn.CrossEntropyLoss) and third-party losses to participate in the compound loss system without modification.

Constructor

LossWrapper(loss_func: nn.Module, name: str = None)
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| loss_func | nn.Module | required | The loss function to wrap. Can be a custom Loss subclass or any standard PyTorch/third-party loss module. |
| name | str | None | Name for the wrapper. If None, uses loss_func.name (if present) or loss_func.__class__.__name__. |

forward Signature

def forward(self, pred_batch, gt_batch, normalize: bool = True) -> Tuple[torch.Tensor, Dict]
  • If the wrapped loss is a Loss subclass, delegates to loss_func(pred_batch, gt_batch, normalize=normalize).
  • If the wrapped loss is a standard PyTorch loss, calls loss_func(pred_batch, gt_batch) and returns (loss_value, {}).
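The dispatch can be sketched as follows. This is an illustrative re-implementation, not the library's actual code; in particular, detecting a custom Loss by the presence of a compute method is an assumption:

```python
import torch
import torch.nn as nn


class SketchLossWrapper(nn.Module):
    """Illustrative sketch of the LossWrapper dispatch described above."""

    def __init__(self, loss_func: nn.Module, name: str = None):
        super().__init__()
        self.loss_func = loss_func
        # Fall back to the wrapped module's name attribute, then its class name.
        self.name = name or getattr(loss_func, "name", loss_func.__class__.__name__)

    def forward(self, pred_batch, gt_batch, normalize: bool = True):
        # Custom Loss subclasses already return (loss, extra_info) and accept
        # `normalize`; here we detect them by the presence of `compute`.
        if hasattr(self.loss_func, "compute"):
            return self.loss_func(pred_batch, gt_batch, normalize=normalize)
        # Standard PyTorch losses return a bare tensor; wrap it in the
        # (loss_value, extra_info) tuple that MultiLoss expects.
        return self.loss_func(pred_batch, gt_batch), {}
```

Either way, MultiLoss sees a uniform (loss_value, extra_info) interface.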

YAML Config Snippet

compound_loss:
  losses:
    - loss:
        _target_: segmentation_models_pytorch.losses.DiceLoss
        mode: binary
      weight: 1.0

Non-custom losses are wrapped automatically by build_compound_loss_from_config.


MultiLoss

from pytorch_segmentation_models_trainer.custom_losses.base_loss import MultiLoss

Combines multiple loss functions with scalar or epoch-scheduled weights. Runs optional pre-processing steps before computing individual losses.

Constructor

MultiLoss(
    loss_funcs: List[nn.Module],
    weights: List[Union[float, List[float]]],
    epoch_thresholds: Optional[List[float]] = None,
    pre_processes: Optional[List[Callable]] = None,
)

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| loss_funcs | List[nn.Module] | required | List of individual loss functions. Must be the same length as weights. |
| weights | List[Union[float, List[float]]] | required | Per-loss weights. Each entry can be a scalar float for a constant weight, or a list of floats for epoch-scheduled weights (interpolated via scipy.interpolate.interp1d against epoch_thresholds). |
| epoch_thresholds | Optional[List[float]] | None | Epoch values corresponding to each position in a list-type weight. Required when any weight is a list. |
| pre_processes | Optional[List[Callable]] | None | List of callables with signature (pred_batch, gt_batch) -> (pred_batch, gt_batch). Called in order before any loss is computed. |

forward Signature

def forward(
    self,
    pred_batch,
    gt_batch,
    normalize: bool = True,
    epoch: Optional[float] = None,
) -> Tuple[torch.Tensor, Dict, Dict]

Returns (total_loss, individual_losses_dict, extra_dict):

  • total_loss: Weighted sum of all individual losses.
  • individual_losses_dict: {loss_name: scalar_tensor} for each component.
  • extra_dict: {loss_name: extra_info_dict} forwarded from each component's forward.

When epoch is provided and a weight is a scheduled list, the weight at that epoch is interpolated and applied.
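The weight resolution can be sketched as below. The library uses scipy.interpolate.interp1d; np.interp is an equivalent stand-in for plain linear interpolation, except that it clamps at the ends of the threshold range rather than raising:

```python
import numpy as np


def scheduled_weight(weight, epoch, epoch_thresholds):
    """Resolve a per-loss weight at a given epoch.

    Scalar weights are constant; list weights are linearly interpolated
    against epoch_thresholds (illustrative sketch of the mechanism
    described above, not the library's exact code).
    """
    if isinstance(weight, (int, float)):
        return float(weight)
    return float(np.interp(epoch, epoch_thresholds, weight))
```

For example, with epoch_thresholds [0, 5, 10] and weight [0.0, 0.5, 1.0], the resolved weight at epoch 7.5 is 0.75.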

Weight Scheduling Example

compound_loss:
  epoch_thresholds: [0, 5, 10]
  losses:
    - loss:
        _target_: pytorch_segmentation_models_trainer.custom_losses.base_loss.SegLoss
        name: seg
        gt_channel_selector: 0
        bce_coef: 0.5
        dice_coef: 0.5
      weight: [0.0, 0.5, 1.0] # ramps from 0 to 1 over 10 epochs
    - loss:
        _target_: pytorch_segmentation_models_trainer.custom_losses.base_loss.CrossfieldAlignLoss
        name: crossfield_align
      weight: 1.0

SegLoss

from pytorch_segmentation_models_trainer.custom_losses.base_loss import SegLoss

Combined segmentation loss mixing Binary Cross Entropy (BCE) and Dice loss. For multi-class tasks, uses torch.nn.CrossEntropyLoss. Optionally supports mixed precision, MixUp augmentation, and label smoothing.

Constructor

SegLoss(
    name: str,
    gt_channel_selector: int,
    n_classes: int = 2,
    bce_coef: float = 0.5,
    dice_coef: float = 0.5,
    tversky_focal_coef: float = 0,
    use_mixed_precision: bool = False,
    use_mixup: bool = False,
    mixup_alpha: float = 0.5,
    use_label_smooth: bool = False,
    smooth_factor: float = 0.0,
)

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| name | str | required | Loss name for logging. |
| gt_channel_selector | int or slice | required | Selects which channel(s) of gt_batch["gt_polygons_image"] to compare against pred_batch["seg"]. |
| n_classes | int | 2 | Number of classes. When 2, uses F.binary_cross_entropy; otherwise uses CrossEntropyLoss. |
| bce_coef | float | 0.5 | Weight applied to the BCE/cross-entropy term. |
| dice_coef | float | 0.5 | Weight applied to the Dice loss term. |
| tversky_focal_coef | float | 0 | Weight applied to the focal Tversky loss term. 0 disables this term. |
| use_mixed_precision | bool | False | When True, uses F.binary_cross_entropy_with_logits instead of F.binary_cross_entropy (skip the sigmoid in the model head). |
| use_mixup | bool | False | When True, expects MixUp-augmented fields in gt_batch (mixup_pred, mixup_y_a, mixup_y_b, mixup_lam). |
| mixup_alpha | float | 0.5 | Alpha parameter for the Beta distribution used in MixUp. |
| use_label_smooth | bool | False | When True, applies label smoothing to the cross-entropy term. |
| smooth_factor | float | 0.0 | Smoothing factor for label smoothing. |

The final loss value is: bce_coef * BCE + dice_coef * Dice + tversky_focal_coef * FocalTversky.
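For the binary case, the BCE + Dice combination can be sketched as follows. This is a minimal illustration of the weighted sum above, omitting the focal Tversky term, channel selection, and the MixUp/label-smoothing options; the library's Dice implementation may differ in smoothing details:

```python
import torch
import torch.nn.functional as F


def seg_loss_sketch(pred: torch.Tensor, gt: torch.Tensor,
                    bce_coef: float = 0.5, dice_coef: float = 0.5,
                    eps: float = 1e-6) -> torch.Tensor:
    """Illustrative binary BCE + soft Dice combination.

    pred: probabilities in [0, 1] (post-sigmoid); gt: binary mask.
    """
    bce = F.binary_cross_entropy(pred, gt)
    intersection = (pred * gt).sum()
    # Soft Dice: 1 - 2|P ∩ G| / (|P| + |G|), with eps for numerical stability.
    dice = 1.0 - (2.0 * intersection + eps) / (pred.sum() + gt.sum() + eps)
    return bce_coef * bce + dice_coef * dice
```

A perfect prediction drives both terms (and hence the combined loss) to zero, while the two coefficients trade off pixel-wise calibration (BCE) against region overlap (Dice).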

YAML Config Snippet

compound_loss:
  losses:
    - loss:
        _target_: pytorch_segmentation_models_trainer.custom_losses.base_loss.SegLoss
        name: seg
        gt_channel_selector: 0
        n_classes: 2
        bce_coef: 0.5
        dice_coef: 0.5
      weight: 10.0

CrossfieldAlignLoss

from pytorch_segmentation_models_trainer.custom_losses.base_loss import CrossfieldAlignLoss

Frame-field alignment loss. Penalizes the deviation of the predicted frame field (c0, c2) from the ground-truth tangent field (gt_field) along polygon edges. The loss is masked to zero on non-edge pixels.

Constructor

CrossfieldAlignLoss(name: str)
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| name | str | required | Loss name for logging. |

compute Signature

def compute(self, pred_batch, gt_batch) -> torch.Tensor

Expected batch keys:

| Batch | Key | Description |
| --- | --- | --- |
| pred_batch | "crossfield" | Tensor (N, 4, H, W): first 2 channels are c0, last 2 are c2. |
| gt_batch | "gt_field" | Complex-valued ground-truth tangent field (2-channel tensor). |
| gt_batch | "gt_polygons_image" | At least 2-channel mask; channel 1 is the edge mask used for spatial weighting. |

Stores gt_batch["gt_field"] in self.extra_info for visualization.
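The pointwise alignment error can be sketched as below. This follows the common frame-field parameterization in which the four field directions are the roots of the complex polynomial f(z) = z⁴ + c2·z² + c0; the library's exact sign convention and per-pixel weighting may differ:

```python
import math
import torch


def align_error_sketch(c0: torch.Tensor, c2: torch.Tensor,
                       z: torch.Tensor) -> torch.Tensor:
    """Pointwise alignment error |f(z)|^2 with f(z) = z^4 + c2*z^2 + c0.

    c0, c2 are the complex frame-field coefficients and z is the ground-truth
    tangent direction as a unit complex number; the error vanishes exactly
    when z is one of the four field directions.
    """
    f_z = z ** 4 + c2 * z ** 2 + c0
    return f_z.real ** 2 + f_z.imag ** 2
```

The loss value is then the average of this error weighted by the edge mask, so non-edge pixels contribute nothing.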

YAML Config Snippet

compound_loss:
  losses:
    - loss:
        _target_: pytorch_segmentation_models_trainer.custom_losses.base_loss.CrossfieldAlignLoss
        name: crossfield_align
      weight: 1.0

CrossfieldAlign90Loss

from pytorch_segmentation_models_trainer.custom_losses.base_loss import CrossfieldAlign90Loss

90-degree symmetry loss for frame fields. Enforces that the predicted frame field is also well-aligned when the ground-truth tangent is rotated 90°. This encourages right-angle corners. The loss is applied on edge pixels minus vertex pixels.

Constructor

CrossfieldAlign90Loss(name: str)
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| name | str | required | Loss name for logging. |

compute Signature

def compute(self, pred_batch, gt_batch) -> torch.Tensor

Expected batch keys:

| Batch | Key | Description |
| --- | --- | --- |
| pred_batch | "crossfield" | Tensor (N, 4, H, W). |
| gt_batch | "gt_field" | 2-channel ground-truth tangent field. |
| gt_batch | "gt_polygons_image" | Exactly 3-channel mask (interior, edge, vertex). |

YAML Config Snippet

compound_loss:
  losses:
    - loss:
        _target_: pytorch_segmentation_models_trainer.custom_losses.base_loss.CrossfieldAlign90Loss
        name: crossfield_align90
      weight: 1.0

CrossfieldSmoothLoss

from pytorch_segmentation_models_trainer.custom_losses.base_loss import CrossfieldSmoothLoss

Spatial smoothness loss for frame fields. Penalizes high-frequency spatial variation in the frame field using a Laplacian penalty, but only in non-edge regions. This prevents the frame field from oscillating in homogeneous areas.

Constructor

CrossfieldSmoothLoss(name: str)
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| name | str | required | Loss name for logging. |

Internally creates a frame_field_utils.LaplacianPenalty(channels=4) module.

compute Signature

def compute(self, pred_batch, gt_batch) -> torch.Tensor

Expected batch keys:

| Batch | Key | Description |
| --- | --- | --- |
| pred_batch | "crossfield" | Tensor (N, 4, H, W). |
| gt_batch | "gt_polygons_image" | At least 2-channel mask; channel 1 is the edge mask (inverted to define the smooth region). |
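A masked Laplacian penalty of this kind can be sketched as below. The library's LaplacianPenalty lives in frame_field_utils and may use a different kernel or normalization; the 3x3 discrete Laplacian here is a stand-in:

```python
import torch
import torch.nn.functional as F


def smooth_loss_sketch(crossfield: torch.Tensor,
                       edge_mask: torch.Tensor) -> torch.Tensor:
    """Illustrative Laplacian smoothness penalty restricted to non-edge pixels.

    crossfield: (N, 4, H, W) frame-field coefficients.
    edge_mask: (N, H, W) edge probabilities in [0, 1].
    """
    # 3x3 discrete Laplacian, applied depthwise (one kernel per channel).
    kernel = torch.tensor([[0., 1., 0.],
                           [1., -4., 1.],
                           [0., 1., 0.]]).view(1, 1, 3, 3)
    kernel = kernel.repeat(4, 1, 1, 1)
    lap = F.conv2d(crossfield, kernel, padding=1, groups=4)
    penalty = lap.pow(2).sum(dim=1)          # (N, H, W)
    smooth_region = 1.0 - edge_mask          # only penalize away from edges
    return (penalty * smooth_region).sum() / (smooth_region.sum() + 1e-6)
```

A spatially constant field incurs zero penalty, while rapid local variation in the coefficients is penalized quadratically wherever the edge mask is low.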

YAML Config Snippet

compound_loss:
  losses:
    - loss:
        _target_: pytorch_segmentation_models_trainer.custom_losses.base_loss.CrossfieldSmoothLoss
        name: crossfield_smooth
      weight: 0.1

ComputeSegGrads

from pytorch_segmentation_models_trainer.custom_losses.base_loss import ComputeSegGrads

Not a loss function. A pre-processing callable used inside MultiLoss.pre_processes. Computes spatial gradients of the predicted segmentation mask and populates pred_batch with derived tensors required by SegCrossfieldLoss and SegEdgeInteriorLoss.

Constructor

ComputeSegGrads(device)
| Parameter | Type | Description |
| --- | --- | --- |
| device | str or torch.device | Device on which to instantiate the SpatialGradient (Scharr filter) operator. |

__call__ Signature

def __call__(self, pred_batch, gt_batch) -> Tuple[dict, dict]

Adds the following keys to pred_batch:

| Key | Shape | Description |
| --- | --- | --- |
| "seg_grads" | (N, C, 2, H, W) | Spatial gradients of pred_batch["seg"] computed with the Scharr operator (scaled by 2). |
| "seg_grad_norm" | (N, C, H, W) | L2 norm of the spatial gradient at each pixel. |
| "seg_grads_normed" | (N, C, 2, H, W) | Unit-normalized gradient vectors (seg_grads / (seg_grad_norm + 1e-6)). |
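The three derived tensors can be sketched with explicit Scharr kernels as below. The library builds a SpatialGradient operator instead, and its scale factor differs (the table above notes a factor of 2); the kernel normalization here is illustrative:

```python
import torch
import torch.nn.functional as F


def compute_seg_grads_sketch(pred_batch: dict, gt_batch: dict,
                             eps: float = 1e-6):
    """Illustrative version of ComputeSegGrads using explicit Scharr kernels."""
    seg = pred_batch["seg"]                                    # (N, C, H, W)
    n, c, h, w = seg.shape
    # Scharr x-derivative kernel; the y-kernel is its transpose.
    scharr_x = torch.tensor([[-3., 0., 3.],
                             [-10., 0., 10.],
                             [-3., 0., 3.]]) / 16.0
    scharr_y = scharr_x.t()
    kernel = torch.stack([scharr_x, scharr_y]).unsqueeze(1)    # (2, 1, 3, 3)
    # Fold channels into the batch so each channel is filtered independently.
    flat = seg.reshape(n * c, 1, h, w)
    grads = F.conv2d(flat, kernel, padding=1).reshape(n, c, 2, h, w)
    grad_norm = grads.norm(dim=2)                              # (N, C, H, W)
    pred_batch["seg_grads"] = grads
    pred_batch["seg_grad_norm"] = grad_norm
    pred_batch["seg_grads_normed"] = grads / (grad_norm.unsqueeze(2) + eps)
    return pred_batch, gt_batch
```

The callable signature (pred_batch, gt_batch) -> (pred_batch, gt_batch) matches the pre_processes contract described for MultiLoss.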

Usage in Config

ComputeSegGrads is instantiated programmatically by build_combined_loss when coupling losses require gradient information. When using compound_loss, add it via pre_processes:

compound_loss:
  pre_processes:
    - _target_: pytorch_segmentation_models_trainer.custom_losses.base_loss.ComputeSegGrads
      device: cuda
  losses:
    - loss:
        _target_: pytorch_segmentation_models_trainer.custom_losses.base_loss.SegLoss
        name: seg
        gt_channel_selector: 0
      weight: 10.0

Loss Configuration Patterns

The framework supports three configuration patterns for specifying losses in your Hydra config.

Pattern 1: Single Loss

Use loss directly under the model config for a simple, single loss function:

loss:
  _target_: pytorch_segmentation_models_trainer.custom_losses.base_loss.SegLoss
  name: seg
  gt_channel_selector: 0
  bce_coef: 0.5
  dice_coef: 0.5

Pattern 2: Compound Loss

Use compound_loss under loss_params to combine multiple losses with individual weights. Supports both custom Loss subclasses and standard PyTorch/third-party losses via automatic LossWrapper wrapping:

loss_params:
  compound_loss:
    epoch_thresholds: [0, 5, 10] # required if any weight is a list
    losses:
      - loss:
          _target_: pytorch_segmentation_models_trainer.custom_losses.base_loss.SegLoss
          name: seg
          gt_channel_selector: 0
          bce_coef: 0.5
          dice_coef: 0.5
        weight: 10.0

      - loss:
          _target_: pytorch_segmentation_models_trainer.custom_losses.base_loss.CrossfieldAlignLoss
          name: crossfield_align
        weight: [0.0, 0.5, 1.0] # ramps up over 10 epochs

      - loss:
          _target_: segmentation_models_pytorch.losses.DiceLoss
          mode: binary
        weight: 1.0 # third-party loss, auto-wrapped

Pattern 3: Loss List (legacy multiloss)

Use multi_loss / multiloss under loss_params with the legacy build_combined_loss builder. This is the older configuration style and is maintained for backward compatibility:

loss_params:
  multiloss:
    coefs:
      epoch_thresholds: [0, 10]
      seg: 10.0
      crossfield_align: [0.0, 1.0]
      crossfield_align90: [0.0, 1.0]
      crossfield_smooth: 0.1
      seg_interior_crossfield: 1.0
      seg_edge_crossfield: 1.0
      seg_edge_interior: 1.0
    seg_loss_params:
      bce_coef: 0.5
      dice_coef: 0.5
      use_dist: false
      use_size: false
      w0: 10.0
      sigma: 5.0

The build_loss_from_config utility automatically selects between Pattern 2 and Pattern 3 based on which keys are present in cfg.loss_params.
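The key-based selection can be sketched as follows. This is a hypothetical illustration of the dispatch described above; the real build_loss_from_config inspects cfg.loss_params similarly, but its exact key handling and return values may differ:

```python
def select_loss_builder_sketch(loss_params: dict) -> str:
    """Pick a loss builder based on which config keys are present
    (illustrative, not the library's actual function)."""
    if "compound_loss" in loss_params:
        return "build_compound_loss_from_config"   # Pattern 2
    if "multiloss" in loss_params or "multi_loss" in loss_params:
        return "build_combined_loss"               # Pattern 3 (legacy)
    raise ValueError("loss_params must define either compound_loss or multiloss")
```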