
Domain Adaptation — Configuration Reference

Full reference for all fields in the domain_adaptation config section used by DomainAdaptationModel.


Config skeleton

The snippet below shows the top-level structure of a complete DA training config. Two model-related keys coexist and serve different purposes:

- `pl_model` — tells `train.py` which `LightningModule` subclass to instantiate. For domain adaptation this must always be `DomainAdaptationModel`.
- `model` — the segmentation network (`nn.Module`) consumed by `Model.get_model()`. It is the same key used in standard (non-DA) training configs.
```yaml
# ── LightningModule ──────────────────────────────────────────────────────────
# Tells train.py to instantiate DomainAdaptationModel instead of the default Model.
pl_model:
  _target_: pytorch_segmentation_models_trainer.model_loader.domain_adaptation_model.DomainAdaptationModel

# ── Segmentation network (nn.Module) ─────────────────────────────────────────
# Same key used in standard segmentation configs — read by Model.get_model().
# DomainAdaptationModel forwards both source and target batches through this network.
model:
  _target_: segmentation_models_pytorch.Unet
  encoder_name: resnet50
  encoder_weights: imagenet
  in_channels: 3
  classes: 2

# ── Segmentation loss (source domain only) ───────────────────────────────────
loss:
  _target_: torch.nn.CrossEntropyLoss

optimizer:
  _target_: torch.optim.AdamW
  lr: 1.0e-4
  weight_decay: 1.0e-4

hyperparameters:
  batch_size: 8
  epochs: 50

pl_trainer:
  max_epochs: ${hyperparameters.epochs}
  accelerator: auto
  devices: 1

# ── DA-specific config ───────────────────────────────────────────────────────
domain_adaptation:
  method:
    _target_: my_package.methods.MyDAMethod
    lambda_da: 1.0

  source_dataset:
    _target_: pytorch_segmentation_models_trainer.dataset_loader.dataset.SegmentationDataset
    input_csv_path: /data/source/train.csv
    data_loader:
      batch_size: ${hyperparameters.batch_size}
      num_workers: 4
      shuffle: true

  target_dataset:
    _target_: pytorch_segmentation_models_trainer.dataset_loader.dataset.SegmentationDataset
    input_csv_path: /data/target/train.csv  # masks not required for UDA
    data_loader:
      batch_size: ${hyperparameters.batch_size}
      num_workers: 4
      shuffle: true

  source_val_dataset: null     # optional — omit to disable forgetting monitoring
  target_val_dataset: null     # optional — omit to disable target evaluation
  feature_layers: []           # optional — only needed when method.requires_features=True
  pretrained_checkpoint: null  # optional — warm-start from an existing checkpoint
```

Top-level config keys

| Key | Type | Required | Description |
|---|---|---|---|
| `pl_model._target_` | string | yes | Must be `pytorch_segmentation_models_trainer.model_loader.domain_adaptation_model.DomainAdaptationModel` |
| `model` | dict | yes | Segmentation network config (same as standard training) |
| `loss` | dict | yes | Segmentation loss config (applied to source domain only) |
| `optimizer` | dict | yes | Optimizer config |
| `hyperparameters` | dict | yes | Must contain at least `batch_size` and `epochs` |
| `pl_trainer` | dict | yes | PyTorch Lightning Trainer kwargs |
| `domain_adaptation` | dict | yes | DA-specific config (detailed below) |
| `scheduler_list` | list | no | LR schedulers — same format as standard training |
| `callbacks` | list | no | PL callbacks — same format as standard training |
| `logger` | dict | no | TensorBoard, W&B, or CSV logger |
| `metrics` | list | no | torchmetrics metrics |

domain_adaptation section

method

The DA method to use. Hydra instantiates this as a BaseDomainAdaptationMethod subclass.

```yaml
domain_adaptation:
  method:
    _target_: my_package.methods.MyMethod  # required
    lambda_da: 1.0                         # default: 1.0
    # ... any extra kwargs accepted by your method's __init__
```
| Field | Type | Default | Description |
|---|---|---|---|
| `_target_` | string | | Fully-qualified class name of the method |
| `lambda_da` | float | 1.0 | Global DA loss weight: `total = seg_loss + lambda_da * da_loss`. Used as a fallback when no `lambda_schedule` is set on the method |
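The `compute_da_loss` hook and the `requires_features` / `requires_target_labels` flags referenced on this page suggest the following shape for a custom method. This is a hypothetical sketch: the class name, constructor, and exact hook signature are illustrative assumptions (a real method would subclass `BaseDomainAdaptationMethod` and match its actual interface).

```python
# Hypothetical DA method sketch. The compute_da_loss / requires_features /
# requires_target_labels names appear in this reference; the signature below
# is an assumption for illustration only.
import torch


class FeatureNormAlignment:
    """Toy method: penalize the gap between mean feature norms of the domains."""

    requires_features = True        # ask the model to capture feature_layers
    requires_target_labels = False  # pure UDA: target masks are never read

    def __init__(self, lambda_da: float = 1.0):
        self.lambda_da = lambda_da

    def compute_da_loss(self, source_features, target_features):
        # Both arguments are dicts: layer name -> captured activation tensor.
        loss = torch.zeros(())
        for name in source_features:
            s = source_features[name].flatten(1).norm(dim=1).mean()
            t = target_features[name].flatten(1).norm(dim=1).mean()
            loss = loss + (s - t).abs()
        return loss
```

The extra constructor kwargs (`lambda_da` here) are exactly what the `method:` config block above passes through Hydra.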

source_dataset

Labeled source-domain training dataset.

```yaml
domain_adaptation:
  source_dataset:
    _target_: pytorch_segmentation_models_trainer.dataset_loader.dataset.SegmentationDataset
    input_csv_path: /data/source/train.csv
    data_loader:
      batch_size: 8
      num_workers: 4
      shuffle: true
      pin_memory: true
      drop_last: true
      prefetch_factor: 2
      persistent_workers: false
```
| Field | Type | Default | Description |
|---|---|---|---|
| `_target_` | string | | Dataset class to instantiate |
| `input_csv_path` | string | | Path to the CSV listing image/mask pairs |
| `data_loader.batch_size` | int | 8 | Samples per batch |
| `data_loader.num_workers` | int | 4 | DataLoader worker processes |
| `data_loader.shuffle` | bool | true | Shuffle at each epoch |
| `data_loader.pin_memory` | bool | true | Copy batches into page-locked (pinned) host memory for faster GPU transfer |
| `data_loader.drop_last` | bool | true | Drop the incomplete final batch |
| `data_loader.prefetch_factor` | int | 2 | Batches prefetched per worker |
| `data_loader.persistent_workers` | bool | false | Keep workers alive between epochs |

target_dataset

Unlabeled (or weakly labeled) target-domain training dataset. Same structure as source_dataset.

> **Note:** For UDA (unsupervised DA), the target dataset CSV may omit the mask column or point to empty masks. The mask key is only read from the target batch when `method.requires_target_labels = True`.


source_val_dataset (optional)

Labeled source-domain validation dataset. Used to detect catastrophic forgetting.

When omitted (null), source-domain validation is skipped and DomainAdaptationMonitorCallback will log a warning that forgetting monitoring is disabled.

```yaml
domain_adaptation:
  source_val_dataset:
    _target_: pytorch_segmentation_models_trainer.dataset_loader.dataset.SegmentationDataset
    input_csv_path: /data/source/val.csv
    data_loader:
      batch_size: 8
      num_workers: 4
      shuffle: false
      drop_last: false
```

target_val_dataset (optional)

Labeled target-domain validation dataset. Used to measure adaptation progress.

> **Tip:** Even in UDA settings, it is common to keep a small labeled target validation set (not used during training) to measure true target performance.


feature_layers (optional)

List of dot-separated layer name strings whose outputs will be captured and passed to compute_da_loss as source_features / target_features.

Only active when method.requires_features = True. When empty and requires_features = True, a warning is logged and both feature dicts will be empty.

```yaml
domain_adaptation:
  feature_layers:
    - encoder.layer3
    - encoder.layer4
```

Use `dict(model.named_modules()).keys()` to inspect the available layer names for your architecture.
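For example, with a toy model standing in for the real segmentation network (the classes below are purely illustrative), the dot-separated names look like this:

```python
# Toy stand-in for the segmentation network, used only to show what
# named_modules() returns; in practice you would instantiate the model
# from your config instead.
import torch.nn as nn


class TinyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer3 = nn.Conv2d(8, 16, kernel_size=3)
        self.layer4 = nn.Conv2d(16, 32, kernel_size=3)


class TinySegNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = TinyEncoder()
        self.head = nn.Conv2d(32, 2, kernel_size=1)


model = TinySegNet()
layer_names = list(dict(model.named_modules()).keys())
# → ['', 'encoder', 'encoder.layer3', 'encoder.layer4', 'head']
```

Any of these names (except the empty root name) can be listed under `feature_layers`.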


pretrained_checkpoint (optional)

Warm-start the segmentation model from an existing checkpoint before DA training begins. Loads weights only — optimizer state, epoch counter, and scheduler are reset.

```yaml
domain_adaptation:
  pretrained_checkpoint:
    path: /checkpoints/source_model.ckpt  # required
    source_format: pytorch_lightning      # "pytorch_lightning" or "pytorch"
    strict_loading: true
```
| Field | Type | Default | Description |
|---|---|---|---|
| `path` | string | | Absolute path to the checkpoint file |
| `source_format` | string | pytorch_lightning | `"pytorch_lightning"`: reads `ckpt["state_dict"]` and strips the `"model."` prefix. `"pytorch"`: reads the file directly as a state dict |
| `strict_loading` | bool | true | Passed to `load_state_dict(strict=...)`. Set `false` for partial loads |
Difference from resume_from_checkpoint

pretrained_checkpoint loads weights only. pl_trainer.resume_from_checkpoint (or hyperparameters.resume_from_checkpoint) restores the complete training state including optimizer and epoch counter. Use the former when starting a new adaptation phase; use the latter when resuming an interrupted training run.
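In code, the weight-only load described by the table above behaves roughly like the sketch below. The function name is illustrative, not the library's internal implementation; it only mirrors the documented behavior for the two `source_format` values.

```python
# Sketch of weight-only warm-starting: for "pytorch_lightning" checkpoints,
# read ckpt["state_dict"] and strip the "model." prefix so the keys match
# the bare nn.Module; for "pytorch", the file itself is the state dict.
import torch


def load_pretrained_weights(model, path, source_format="pytorch_lightning", strict=True):
    ckpt = torch.load(path, map_location="cpu")
    if source_format == "pytorch_lightning":
        state_dict = {
            k[len("model."):]: v
            for k, v in ckpt["state_dict"].items()
            if k.startswith("model.")
        }
    else:  # plain "pytorch": the file already is a state dict
        state_dict = ckpt
    model.load_state_dict(state_dict, strict=strict)
    return model
```

Optimizer state, epoch counter, and scheduler are untouched, which is exactly why this differs from `resume_from_checkpoint`.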


Lambda Schedulers

Configure a scheduler inside your method config under `lambda_schedule`. At each step, `DomainAdaptationModel` calls `method.lambda_schedule.get_lambda(epoch, total_epochs)` if the attribute exists.

ConstantScheduler

Returns a fixed value throughout training.

```yaml
lambda_schedule:
  _target_: pytorch_segmentation_models_trainer.domain_adaptation.schedulers.ConstantScheduler
  value: 1.0
```
| Field | Default | Description |
|---|---|---|
| `value` | 1.0 | Constant lambda |

LinearScheduler

Linearly interpolates from start_value to end_value.

```yaml
lambda_schedule:
  _target_: pytorch_segmentation_models_trainer.domain_adaptation.schedulers.LinearScheduler
  start_value: 0.0
  end_value: 1.0
```
| Field | Default | Description |
|---|---|---|
| `start_value` | 0.0 | Lambda at epoch 0 |
| `end_value` | 1.0 | Lambda at the final epoch |
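Numerically, the ramp behaves like this sketch. The exact endpoint convention (dividing by `total_epochs - 1` so the final epoch reaches `end_value`) is an assumption about the implementation.

```python
# Hedged sketch of the linear lambda ramp: start_value at epoch 0,
# end_value at the last epoch, linear in between.
def linear_lambda(epoch, total_epochs, start_value=0.0, end_value=1.0):
    if total_epochs <= 1:
        return end_value
    progress = min(epoch / (total_epochs - 1), 1.0)
    return start_value + (end_value - start_value) * progress
```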

DANNScheduler

Sigmoid-shaped ramp from Ganin et al. (2016): λ = 2 / (1 + exp(-γ·p)) - 1 where p is training progress in [0, 1].

```yaml
lambda_schedule:
  _target_: pytorch_segmentation_models_trainer.domain_adaptation.schedulers.DANNScheduler
  gamma: 10.0
```
| Field | Default | Description |
|---|---|---|
| `gamma` | 10.0 | Steepness of the growth curve. Higher = faster transition |
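Plugging the formula into code gives the sketch below. Computing the progress `p` as `epoch / total_epochs` is an assumption about how the scheduler measures progress; the formula itself is the one stated above.

```python
# DANN lambda ramp: lambda = 2 / (1 + exp(-gamma * p)) - 1, with p the
# training progress in [0, 1]. Starts at 0 and saturates near 1.
import math


def dann_lambda(epoch, total_epochs, gamma=10.0):
    p = epoch / total_epochs
    return 2.0 / (1.0 + math.exp(-gamma * p)) - 1.0
```

With the default `gamma = 10`, lambda is exactly 0 at the start and about 0.9999 by the end of training, so the DA loss is phased in gradually.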

DomainAdaptationMonitorCallback

```yaml
callbacks:
  - _target_: pytorch_segmentation_models_trainer.domain_adaptation.callbacks.DomainAdaptationMonitorCallback
    num_classes: 2
    class_names: ["background", "building"]
    log_every_n_epochs: 1
    forgetting_threshold: 0.05
    eval_batch_size: 8
    eval_num_workers: 0
```
| Field | Default | Description |
|---|---|---|
| `num_classes` | | Number of segmentation classes (required) |
| `class_names` | `["Class 0", …]` | Class names for logging |
| `log_every_n_epochs` | 1 | How often to evaluate both domains |
| `forgetting_threshold` | 0.05 | IoU drop (fraction) that triggers a warning |
| `eval_batch_size` | 8 | Batch size for evaluation forward passes |
| `eval_num_workers` | 0 | DataLoader workers for evaluation (0 = main process) |

Metrics logged by this callback:

| Metric | Description |
|---|---|
| `iou/source_val` | Mean IoU on the source validation set |
| `iou/target_val` | Mean IoU on the target validation set (shown in the progress bar) |
| `iou/gap_source_minus_target` | Difference between source and target IoU |
| `iou/source_drop_from_baseline` | IoU drop on the source domain since epoch 0 |
| `iou/source_baseline` | Baseline IoU recorded at `on_fit_start` |