Skip to main content

Implementing a Domain Adaptation Method

Adding a new DA method requires creating a single file with a subclass of BaseDomainAdaptationMethod. No framework code needs to be modified.


The Contract

from pytorch_segmentation_models_trainer.domain_adaptation.base_method import (
BaseDomainAdaptationMethod,
DomainAdaptationLossOutput,
)

class MyMethod(BaseDomainAdaptationMethod):
    """Template for a custom domain-adaptation method.

    Subclass ``BaseDomainAdaptationMethod``, set the class-level flags, and
    implement :meth:`compute_da_loss`. All other hooks are optional no-ops.
    """

    # ── Class-level flags ────────────────────────────────────────────────
    requires_features: bool = False       # True → FeatureExtractorHook is activated
    requires_target_labels: bool = False  # True → SSDA (target has labels)

    def __init__(self, my_param=0.5, **kwargs):
        super().__init__(**kwargs)  # passes lambda_da and other base kwargs
        self.my_param = my_param
        # auxiliary networks go here as nn.Module attributes:
        # self.discriminator = nn.Sequential(...)

    # ── REQUIRED ────────────────────────────────────────────────────────
    def compute_da_loss(
        self,
        source_batch,     # dict with "image", "mask"
        target_batch,     # dict with "image" (no "mask" for UDA)
        source_output,    # model logits for source, shape (B, C, H, W)
        target_output,    # model logits for target, shape (B, C, H, W)
        source_features,  # dict of layer_name → tensor, or {} if requires_features=False
        target_features,  # dict of layer_name → tensor, or {} if requires_features=False
        **kwargs,
    ) -> DomainAdaptationLossOutput:
        loss = ...  # your logic
        return DomainAdaptationLossOutput(
            loss=loss,
            log_dict={"my_metric": loss.item()},
        )

    # ── OPTIONAL hooks (default = no-op) ────────────────────────────────
    def on_fit_start(self, pl_module): ...
    def on_train_epoch_start(self, pl_module, epoch): ...
    def on_train_epoch_end(self, pl_module, epoch): ...

    # ── OPTIONAL: override if auxiliary networks need a different LR ────
    def get_extra_parameter_groups(self):
        return [{"params": self.discriminator.parameters(), "lr": 1e-4}]

__init__ Convention

Hydra instantiates your method from the config by calling it with all config keys as keyword arguments:

method:
  _target_: my_package.methods.MyMethod
  lambda_da: 1.0
  my_param: 0.5

MyMethod(lambda_da=1.0, my_param=0.5)

Always forward remaining kwargs to super().__init__(**kwargs) so that lambda_da and any future base-class kwargs are handled correctly.


DomainAdaptationLossOutput

Every compute_da_loss must return a DomainAdaptationLossOutput:

from dataclasses import dataclass, field


@dataclass
class DomainAdaptationLossOutput:
    """Return type of every ``compute_da_loss`` implementation."""

    loss: Tensor    # DA loss — combined as: total = seg_loss + λ * loss
    log_dict: dict  # logged under "da/" prefix in TensorBoard/W&B
    # A literal `extra: dict = {}` is a mutable default, which @dataclass
    # rejects with ValueError at class-creation time — use a default_factory
    # so each instance gets its own fresh dict.
    extra: dict = field(default_factory=dict)  # not logged automatically — available to callbacks

Example 1 — Entropy Minimization (ADVENT)

The simplest possible method. No auxiliary networks, no feature maps needed.

class ADVENTMethod(BaseDomainAdaptationMethod):
    """Entropy minimization on the target domain (Vu et al., 2019)."""

    requires_features = False

    def compute_da_loss(
        self, source_batch, target_batch,
        source_output, target_output,
        source_features, target_features, **kwargs
    ):
        # Mean per-pixel Shannon entropy of the target predictions;
        # minimising it pushes the model towards confident target outputs.
        probs = torch.softmax(target_output, dim=1)
        # 1e-6 guards against log(0) on saturated probabilities.
        entropy = -(probs * torch.log(probs + 1e-6)).sum(dim=1).mean()
        return DomainAdaptationLossOutput(
            loss=entropy,
            log_dict={"entropy_loss": entropy.item()},
        )

Config:

method:
  _target_: my_package.methods.ADVENTMethod
  lambda_da: 0.001

Example 2 — DANN with Gradient Reversal

Uses intermediate feature maps and a discriminator network.

import torch
import torch.nn as nn
import torch.nn.functional as F


class GradientReversal(torch.autograd.Function):
    """Identity in the forward pass; scales the gradient by ``-alpha`` in backward."""

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha  # stash the scaling factor for the backward pass
        return x

    @staticmethod
    def backward(ctx, grad):
        # Reverse (and scale) the incoming gradient.
        # Second return is None because `alpha` is a non-tensor input.
        return -ctx.alpha * grad, None


class DANNMethod(BaseDomainAdaptationMethod):
    """Domain-Adversarial Neural Network (Ganin et al., 2016)."""

    requires_features = True  # needs intermediate feature maps

    def __init__(self, feature_dim=256, hidden_dim=1024, **kwargs):
        super().__init__(**kwargs)
        # Binary domain classifier: pooled feature vector → single logit.
        self.discriminator = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )
        self.lambda_schedule = None  # set from config if desired

    def compute_da_loss(
        self, source_batch, target_batch,
        source_output, target_output,
        source_features, target_features, **kwargs
    ):
        # Use the feature map captured by FeatureExtractorHook
        src_feat = source_features["encoder.layer3"]  # matches cfg.feature_layers
        tgt_feat = target_features["encoder.layer3"]

        alpha = self.lambda_da  # or use self.lambda_schedule if configured

        # Gradient reversal — forward is identity, backward reverses gradient
        src_rev = GradientReversal.apply(src_feat, alpha)
        tgt_rev = GradientReversal.apply(tgt_feat, alpha)

        src_pred = self.discriminator(src_rev)
        tgt_pred = self.discriminator(tgt_rev)

        # Domain labels: 0 = source, 1 = target.
        labels = torch.cat([
            torch.zeros(len(src_pred), device=src_pred.device),
            torch.ones(len(tgt_pred), device=tgt_pred.device),
        ])
        loss = F.binary_cross_entropy_with_logits(
            torch.cat([src_pred, tgt_pred]).squeeze(), labels
        )
        return DomainAdaptationLossOutput(
            loss=loss,
            log_dict={"dann_loss": loss.item()},
        )

    def get_extra_parameter_groups(self):
        # discriminator can use a higher LR than the encoder
        return [{"params": self.discriminator.parameters(), "lr": 1e-3}]

Config:

method:
  _target_: my_package.methods.DANNMethod
  lambda_da: 1.0
  feature_dim: 256
  hidden_dim: 1024
  lambda_schedule:
    _target_: pytorch_segmentation_models_trainer.domain_adaptation.schedulers.DANNScheduler
    gamma: 10.0

feature_layers:
  - encoder.layer3

Example 3 — Pseudo-labels with Confidence Threshold

Self-training approach: generate pseudo-labels on the target domain and use them as supervision only where the model is confident.

class PseudoLabelMethod(BaseDomainAdaptationMethod):
    """Self-training with confidence-thresholded pseudo-labels."""

    requires_features = False

    def __init__(self, threshold=0.9, **kwargs):
        super().__init__(**kwargs)
        self.threshold = threshold  # minimum softmax confidence to supervise a pixel

    def compute_da_loss(
        self, source_batch, target_batch,
        source_output, target_output,
        source_features, target_features, **kwargs
    ):
        probs = torch.softmax(target_output, dim=1)
        confidence, pseudo_labels = probs.max(dim=1)  # both (B, H, W)

        # Only supervise pixels above confidence threshold
        mask = confidence > self.threshold
        if mask.sum() == 0:
            # No confident pixel: return a zero loss that still participates
            # in the graph so downstream .backward() does not fail.
            loss = torch.tensor(0.0, device=target_output.device, requires_grad=True)
            return DomainAdaptationLossOutput(
                loss=loss,
                log_dict={"pseudo_label_ratio": 0.0},
            )

        # NOTE: boolean-indexing target_output with an expanded (B, C, H, W)
        # mask flattens elements in memory order (all of channel 0, then
        # channel 1, ...), so `.view(-1, C)` would build rows that are NOT
        # per-pixel logit vectors. Compute the per-pixel loss first and mask
        # afterwards instead.
        per_pixel = F.cross_entropy(target_output, pseudo_labels, reduction="none")
        loss = per_pixel[mask].mean()
        ratio = mask.float().mean().item()
        return DomainAdaptationLossOutput(
            loss=loss,
            log_dict={"pseudo_label_loss": loss.item(), "pseudo_label_ratio": ratio},
        )

Lifecycle Hooks

Use hooks when your method needs to maintain state across epochs:

def on_fit_start(self, pl_module):
    # Called once when trainer.fit() begins.
    # Good for: initialising memory banks, computing dataset statistics.
    pass


def on_train_epoch_start(self, pl_module, epoch):
    # Called at the start of each epoch.
    # Good for: updating threshold schedules, resetting accumulators.
    pass


def on_train_epoch_end(self, pl_module, epoch):
    # Called at the end of each epoch.
    # Good for: recomputing pseudo-labels, logging epoch-level stats.
    pass

Using Intermediate Feature Maps

Set requires_features = True and list the layers to capture in the config:

domain_adaptation:
  feature_layers:
    - encoder.layer3
    - encoder.layer4

DomainAdaptationModel installs a FeatureExtractorHook that captures those layers during forward. The tensors arrive in source_features and target_features as a dict:

src_feat = source_features["encoder.layer3"]   # Tensor (B, C, H, W)
tgt_feat = target_features["encoder.layer3"]

The layer names must match the attribute path on self.model. Use dict(model.named_modules()) to inspect available names.