Dual-Head Training

UPerNetDualHead extends UPerNet with two independent decoders sharing a single encoder. The two heads are supervised with different label types simultaneously:

  • Head A — supervised with hard labels (integer class indices, e.g. aggregated annotations).
  • Head B — supervised with soft labels (float probabilities, e.g. Dawid–Skene posteriors or annotator-uncertainty posteriors).

A consistency loss couples the two heads during training, encouraging them to agree while allowing each to specialise.
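The two label types differ only in target shape and dtype. A minimal sketch of what each head's targets look like, using plain cross-entropy for illustration (PyTorch's `F.cross_entropy` accepts both integer-index and probability targets since 1.10; the actual loss is whatever the config specifies):

```python
import torch
import torch.nn.functional as F

B, C, H, W = 2, 6, 8, 8
logits = torch.randn(B, C, H, W)

# Head A targets: integer class indices, shape (B, H, W)
hard = torch.randint(0, C, (B, H, W))
# Head B targets: per-pixel probability vectors, shape (B, C, H, W)
soft = torch.softmax(torch.randn(B, C, H, W), dim=1)

loss_hard = F.cross_entropy(logits, hard)  # standard CE on class indices
loss_soft = F.cross_entropy(logits, soft)  # CE against float probabilities
```

When the soft targets collapse to one-hot vectors, the two losses coincide, which is why the soft-label head can be seen as a strict generalisation of the hard-label one.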


Architecture

encoder (shared)
├── decoder_A (UPerNet) → seg_head_A → logits_A (hard labels)
└── decoder_B (UPerNet) → seg_head_B → logits_B (soft labels)

forward() returns the primary head output selected by inference_head. Intermediate logits from both heads are stored as last_logits_A / last_logits_B so the loss function can access them without changing the standard forward() → loss(pred, target) interface.
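The pattern above can be sketched with stand-in modules (a hypothetical simplification; the real implementation in custom_models.upernet_dual_head uses full UPerNet decoders):

```python
import torch
import torch.nn as nn

class DualHeadSketch(nn.Module):
    """Shared encoder, two independent decode paths, selectable inference head."""

    def __init__(self, classes: int = 6, inference_head: str = "average"):
        super().__init__()
        self.encoder = nn.Conv2d(3, 16, 3, padding=1)   # stand-in for the shared encoder
        self.decoder_a = nn.Conv2d(16, classes, 1)      # stand-in for decoder_A + seg_head_A
        self.decoder_b = nn.Conv2d(16, classes, 1)      # stand-in for decoder_B + seg_head_B
        self.inference_head = inference_head
        self.last_logits_A = None
        self.last_logits_B = None

    def forward(self, x):
        feats = self.encoder(x)
        # Store both heads' logits so the loss function can reach them
        # without changing the forward() -> loss(pred, target) interface.
        self.last_logits_A = self.decoder_a(feats)
        self.last_logits_B = self.decoder_b(feats)
        if self.inference_head == "A":
            return self.last_logits_A
        if self.inference_head == "B":
            return self.last_logits_B
        return 0.5 * (self.last_logits_A + self.last_logits_B)

model = DualHeadSketch()
out = model(torch.randn(1, 3, 32, 32))
```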


When to Use

Dual-head training is beneficial when:

  • You have two label sources for the same data: hard expert-aggregated masks and soft probabilistic labels from a Bayesian annotation model (e.g. Dawid–Skene).
  • You want a model that is calibrated to annotator uncertainty while also being decisive at inference time.
  • You want to regularise a hard-label model with a soft-label auxiliary objective.

Model Config

model:
  _target_: pytorch_segmentation_models_trainer.custom_models.upernet_dual_head.UPerNetDualHead
  encoder_name: tu-convnextv2_tiny.fcmae_ft_in22k_in1k_384
  encoder_weights: imagenet
  in_channels: 3
  classes: 6
  decoder_channels: 256
  inference_head: average # "A", "B", or "average"

Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| encoder_name | str | "tu-convnextv2_tiny..." | timm/SMP encoder name. |
| encoder_depth | int | 5 | Number of encoder stages. |
| encoder_weights | str | "imagenet" | Pretrained weights identifier. |
| decoder_channels | int | 256 | Intermediate channels per decoder. |
| in_channels | int | 3 | Input image bands. |
| classes | int | 6 | Number of output segmentation classes. |
| inference_head | str | "average" | Head used at inference: "A" (hard-trained), "B" (soft-trained), or "average" (mean of both logits). |

Dataset Configuration

The training dataset must supply both hard and soft masks. RandomCropSegmentationDataset with soft_labels: true provides the required output format.

train_dataset:
  _target_: pytorch_segmentation_models_trainer.dataset_loader.dataset.RandomCropSegmentationDataset
  input_csv_path: /data/dual_label/train.csv
  crop_size: 512
  samples_per_epoch: 8000
  n_classes: 6
  soft_labels: true # enables float soft-label output alongside hard masks
  augmentation_list:
    - _target_: albumentations.HorizontalFlip
      p: 0.5
    - _target_: albumentations.Normalize
      mean: [0.485, 0.456, 0.406]
      std: [0.229, 0.224, 0.225]
    - _target_: albumentations.pytorch.ToTensorV2
  data_loader:
    shuffle: true
    num_workers: 8
    batch_size: 16
    drop_last: true

The CSV should include a soft_mask column in addition to the standard image and mask columns:

image,mask,soft_mask
/data/scenes/area_a.tif,/data/masks/hard/area_a.tif,/data/masks/soft/area_a.tif
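A quick structural check that every row carries all three paths can be done with the standard library alone (a sketch; the inline string stands in for the real CSV file):

```python
import csv
import io

# Stand-in for /data/dual_label/train.csv
sample = """image,mask,soft_mask
/data/scenes/area_a.tif,/data/masks/hard/area_a.tif,/data/masks/soft/area_a.tif
"""

required = {"image", "mask", "soft_mask"}
rows = list(csv.DictReader(io.StringIO(sample)))
# Columns missing from the header (empty set means the CSV is well-formed)
missing = required - set(rows[0].keys())
```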

Loss Configuration

The _shared_step training loop detects dual-head models automatically and routes Head A to the hard-label loss and Head B to the soft-label loss. A consistency regularisation term between the two heads is also computed internally.

loss:
  _target_: segmentation_models_pytorch.losses.DiceLoss
  mode: multiclass

No special dual-head loss configuration is needed; the consistency coupling is handled inside _shared_step.
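The routing described above can be sketched as follows. This is a hypothetical simplification of `_shared_step`: the plain cross-entropy terms, the symmetric-KL consistency term, and the 0.1 weight are all illustrative stand-ins for whatever the real training loop uses internally.

```python
import torch
import torch.nn.functional as F

def dual_head_step(model, images, hard_mask, soft_mask, consistency_weight=0.1):
    """Head A -> hard-label loss, Head B -> soft-label loss, plus a
    consistency term encouraging the two heads to agree."""
    model(images)  # populates model.last_logits_A / model.last_logits_B
    logits_a, logits_b = model.last_logits_A, model.last_logits_B
    loss_a = F.cross_entropy(logits_a, hard_mask)  # int class indices
    loss_b = F.cross_entropy(logits_b, soft_mask)  # float probabilities
    log_a = F.log_softmax(logits_a, dim=1)
    log_b = F.log_softmax(logits_b, dim=1)
    consistency = 0.5 * (
        F.kl_div(log_a, log_b, log_target=True, reduction="batchmean")
        + F.kl_div(log_b, log_a, log_target=True, reduction="batchmean")
    )
    return loss_a + loss_b + consistency_weight * consistency

# Tiny stub model with the last_logits_A/B interface, just to exercise the step
class _Stub(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 6, 1)
    def forward(self, x):
        self.last_logits_A = self.conv(x)
        self.last_logits_B = self.conv(x)
        return self.last_logits_A

loss = dual_head_step(
    _Stub(),
    torch.randn(2, 3, 8, 8),
    torch.randint(0, 6, (2, 8, 8)),
    torch.softmax(torch.randn(2, 6, 8, 8), dim=1),
)
```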


Inference

At inference time the returned logits are controlled by inference_head ("A" and "B" use a single head; "average" combines both):

| inference_head | Behaviour |
| --- | --- |
| "A" | Returns logits from the hard-label head. Most decisive. |
| "B" | Returns logits from the soft-label head. Better calibrated. |
| "average" | Returns the mean of both heads' logits (default). Balances decisiveness and calibration. |

Change the inference head without retraining by editing the YAML and reloading the checkpoint:

model:
  inference_head: B # switch to soft-label head at inference
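The three options reduce to a simple selection over the two heads' logits. A self-contained sketch of that selection logic (`select_output` is an illustrative name, not the library's API):

```python
import torch

def select_output(logits_a, logits_b, inference_head="average"):
    """Reduce the two heads' logits to a single output tensor."""
    if inference_head == "A":
        return logits_a
    if inference_head == "B":
        return logits_b
    if inference_head == "average":
        return 0.5 * (logits_a + logits_b)
    raise ValueError(f"unknown inference_head: {inference_head!r}")

a = torch.randn(1, 6, 4, 4)
b = torch.randn(1, 6, 4, 4)
avg = select_output(a, b)  # default: mean of both heads' logits
```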

Full Training Example

model:
  _target_: pytorch_segmentation_models_trainer.custom_models.upernet_dual_head.UPerNetDualHead
  encoder_name: tu-convnextv2_tiny.fcmae_ft_in22k_in1k_384
  encoder_weights: imagenet
  in_channels: 3
  classes: 6
  decoder_channels: 256
  inference_head: average

loss:
  _target_: segmentation_models_pytorch.losses.DiceLoss
  mode: multiclass

train_dataset:
  _target_: pytorch_segmentation_models_trainer.dataset_loader.dataset.RandomCropSegmentationDataset
  input_csv_path: /data/dual_label/train.csv
  crop_size: 512
  samples_per_epoch: 8000
  n_classes: 6
  soft_labels: true
  augmentation_list:
    - _target_: albumentations.HorizontalFlip
      p: 0.5
    - _target_: albumentations.Normalize
      mean: [0.485, 0.456, 0.406]
      std: [0.229, 0.224, 0.225]
    - _target_: albumentations.pytorch.ToTensorV2
  data_loader:
    shuffle: true
    num_workers: 8
    batch_size: 16
    drop_last: true

val_dataset:
  _target_: pytorch_segmentation_models_trainer.dataset_loader.dataset.RandomCropSegmentationDataset
  input_csv_path: /data/dual_label/val.csv
  crop_size: 512
  grid_mode: true
  n_classes: 6
  augmentation_list:
    - _target_: albumentations.Normalize
      mean: [0.485, 0.456, 0.406]
      std: [0.229, 0.224, 0.225]
    - _target_: albumentations.pytorch.ToTensorV2
  data_loader:
    shuffle: false
    num_workers: 4
    batch_size: 16
    drop_last: false

callbacks:
  - _target_: pytorch_segmentation_models_trainer.custom_callbacks.training_callbacks.EMACallback
    decay: 0.999
  - _target_: pytorch_lightning.callbacks.ModelCheckpoint
    monitor: loss/val
    mode: min
    save_top_k: 3