Training a Semantic Segmentation Model

This guide walks through setting up and running a full semantic segmentation training job using the Model base class, which wraps a segmentation_models_pytorch architecture inside PyTorch Lightning.

The Model Class

The base Model class (pytorch_segmentation_models_trainer.model_loader.model.Model) is a pl.LightningModule that wires together architecture, loss, optimizer, scheduler, datasets, metrics, and GPU augmentations from a single Hydra config object.

To use it, set the pl_model key in your config:

pl_model:
  _target_: pytorch_segmentation_models_trainer.model_loader.model.Model

If pl_model is omitted, the training script defaults to Model automatically.
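
To substitute your own LightningModule, point pl_model at a subclass of Model instead. The module path below is hypothetical, just to illustrate the shape:

```yaml
pl_model:
  # Hypothetical subclass of Model that overrides e.g. training_step
  _target_: my_project.models.CustomSegModel
```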

Config Sections Overview

A complete training config contains the following top-level keys:

| Key | Purpose |
| --- | --- |
| model | Neural network architecture |
| loss | Loss function (simple) or loss_params (compound) |
| optimizer | Optimizer class and hyperparameters |
| scheduler_list | List of LR scheduler configs |
| hyperparameters | Batch size, device count |
| pl_trainer | PyTorch Lightning Trainer kwargs |
| callbacks | Checkpoint, early stopping, LR monitor, etc. |
| metrics | torchmetrics metrics to compute |
| logger | TensorBoard, WandB, CSV logger |
| train_dataset | Training dataset and dataloader settings |
| val_dataset | Validation dataset, monitored at the end of every epoch during fit |
| test_dataset | (optional) Test dataset, evaluated once after fit via trainer.test() |
| mode | train or predict |

Supported Architectures

The model key accepts any class from segmentation_models_pytorch:

| Architecture | _target_ |
| --- | --- |
| U-Net | segmentation_models_pytorch.Unet |
| U-Net++ | segmentation_models_pytorch.UnetPlusPlus |
| DeepLabV3+ | segmentation_models_pytorch.DeepLabV3Plus |
| FPN | segmentation_models_pytorch.FPN |
| PSPNet | segmentation_models_pytorch.PSPNet |
| PAN | segmentation_models_pytorch.PAN |
| MAnet | segmentation_models_pytorch.MAnet |
| Linknet | segmentation_models_pytorch.Linknet |
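
Swapping architectures is just a matter of changing _target_; the surrounding keys stay the same. For example:

```yaml
model:
  _target_: segmentation_models_pytorch.DeepLabV3Plus
  encoder_name: resnet50
  encoder_weights: imagenet
  in_channels: 3
  classes: 1
```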

Any timm-compatible encoder name works. Common choices:

  • resnet34 — lightweight baseline, fast to train
  • resnet50 — stronger features, moderate cost
  • efficientnet-b3 — high accuracy per parameter
  • resnext50_32x4d — strong general-purpose encoder

Scheduler List

The scheduler_list is a list of scheduler configs, each with these keys:

scheduler_list:
  - scheduler:
      _target_: torch.optim.lr_scheduler.CosineAnnealingLR
      T_max: 100
      eta_min: 1e-6
    monitor: loss/val        # metric to monitor (for ReduceLROnPlateau)
    interval: epoch          # "epoch" or "step"
    name: cosine_annealing   # label shown in logger

OneCycleLR Auto-Configuration

When using OneCycleLR, set steps_per_epoch to null or omit it entirely. The framework automatically computes it from the CSV dataset size and batch size at the start of training:

steps_per_epoch = dataset_size // (batch_size * devices * accumulate_grad_batches)

The computed value is printed to the console during setup.

scheduler_list:
  - scheduler:
      _target_: torch.optim.lr_scheduler.OneCycleLR
      max_lr: 0.001
      epochs: 100
      # steps_per_epoch is computed automatically from train_dataset CSV
    interval: step
    name: one_cycle_lr
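
The formula can be sanity-checked by hand. This standalone sketch (the dataset size is made up) mirrors the computation:

```python
def steps_per_epoch(dataset_size: int, batch_size: int,
                    devices: int = 1, accumulate_grad_batches: int = 1) -> int:
    # Mirrors the auto-configuration formula: one optimizer step per
    # effective (global) batch, using floor division.
    return dataset_size // (batch_size * devices * accumulate_grad_batches)

# e.g. 10,000 rows in the training CSV, batch size 8, 2 GPUs:
print(steps_per_epoch(10_000, 8, devices=2))  # 625
```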

Complete Config Example

configs/train_unet.yaml
# ── Model ────────────────────────────────────────────────────────────────────
model:
  _target_: segmentation_models_pytorch.Unet
  encoder_name: resnet34
  encoder_weights: imagenet
  in_channels: 3
  classes: 1
  activation: sigmoid

# ── Loss ─────────────────────────────────────────────────────────────────────
loss:
  _target_: segmentation_models_pytorch.losses.DiceLoss
  mode: binary

# ── Optimizer ────────────────────────────────────────────────────────────────
optimizer:
  _target_: torch.optim.AdamW
  lr: 0.001
  weight_decay: 1e-4

# ── LR Schedulers ────────────────────────────────────────────────────────────
scheduler_list:
  - scheduler:
      _target_: torch.optim.lr_scheduler.CosineAnnealingLR
      T_max: 100
      eta_min: 1e-6
    monitor: loss/val
    interval: epoch
    name: cosine_lr

# ── Hyperparameters ──────────────────────────────────────────────────────────
hyperparameters:
  batch_size: 8

# ── PyTorch Lightning Trainer ────────────────────────────────────────────────
pl_trainer:
  max_epochs: 100
  accelerator: gpu
  devices: 1
  precision: 16
  log_every_n_steps: 50

# ── Callbacks ────────────────────────────────────────────────────────────────
callbacks:
  - _target_: pytorch_lightning.callbacks.ModelCheckpoint
    monitor: loss/val
    mode: min
    save_top_k: 3
    filename: "best-{epoch:02d}-{loss/val:.4f}"
    dirpath: ./checkpoints

  - _target_: pytorch_lightning.callbacks.EarlyStopping
    monitor: loss/val
    patience: 15
    mode: min

  - _target_: pytorch_lightning.callbacks.LearningRateMonitor
    logging_interval: epoch

# ── Metrics ──────────────────────────────────────────────────────────────────
metrics:
  - _target_: torchmetrics.JaccardIndex
    task: binary
  - _target_: torchmetrics.F1Score
    task: binary

# ── Logger ───────────────────────────────────────────────────────────────────
logger:
  _target_: pytorch_lightning.loggers.TensorBoardLogger
  save_dir: ./experiments
  name: unet_resnet34

# ── Datasets ─────────────────────────────────────────────────────────────────
train_dataset:
  _target_: pytorch_segmentation_models_trainer.dataset_loader.dataset.SegmentationDataset
  input_csv_path: /data/train.csv
  root_dir: /data
  augmentation_list:
    - _target_: albumentations.HorizontalFlip
      p: 0.5
    - _target_: albumentations.RandomBrightnessContrast
      p: 0.3
    - _target_: albumentations.Normalize
      mean: [0.485, 0.456, 0.406]
      std: [0.229, 0.224, 0.225]
  data_loader:
    shuffle: true
    num_workers: 4
    pin_memory: true
    drop_last: true
    prefetch_factor: 2

val_dataset:
  _target_: pytorch_segmentation_models_trainer.dataset_loader.dataset.SegmentationDataset
  input_csv_path: /data/val.csv
  root_dir: /data
  augmentation_list:
    - _target_: albumentations.Normalize
      mean: [0.485, 0.456, 0.406]
      std: [0.229, 0.224, 0.225]
  data_loader:
    shuffle: false
    num_workers: 4
    pin_memory: true
    drop_last: false
    prefetch_factor: 2

# ── Test Dataset (optional) ───────────────────────────────────────────────────
# When present, trainer.test() is called automatically after trainer.fit().
# All metrics are logged with the "test/" prefix.
test_dataset:
  _target_: pytorch_segmentation_models_trainer.dataset_loader.dataset.SegmentationDataset
  input_csv_path: /data/test.csv
  root_dir: /data
  augmentation_list:
    - _target_: albumentations.Normalize
      mean: [0.485, 0.456, 0.406]
      std: [0.229, 0.224, 0.225]
  data_loader:
    shuffle: false
    num_workers: 4
    pin_memory: true
    drop_last: false
    prefetch_factor: 2

# ── Mode ─────────────────────────────────────────────────────────────────────
mode: train
device: cuda

Multi-GPU Training

Enable distributed training by setting accelerator, devices, and strategy in pl_trainer:

pl_trainer:
  accelerator: gpu
  devices: 2        # number of GPUs
  strategy: ddp     # DistributedDataParallel
  max_epochs: 100

hyperparameters:
  batch_size: 8     # per-GPU batch size

Effective Batch Size

With DDP, the effective batch size is batch_size * devices. OneCycleLR auto-configuration accounts for this automatically when computing steps_per_epoch.

Use devices: -1 to use all available GPUs on the machine.
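
As a quick illustration of the arithmetic (the linear LR scaling shown is a common heuristic, not something the framework applies for you):

```python
batch_size = 8    # per-GPU batch size from hyperparameters
devices = 2       # pl_trainer.devices
base_lr = 1e-3    # LR tuned for a single GPU

# Under DDP each GPU sees batch_size samples per step, so the
# gradient is averaged over batch_size * devices samples.
effective_batch_size = batch_size * devices

# Linear scaling rule: a heuristic only, tune for your own setup.
scaled_lr = base_lr * devices

print(effective_batch_size)  # 16
print(scaled_lr)             # 0.002
```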


Mixed Precision Training

pl_trainer:
  precision: 16        # FP16 mixed precision
  # or:
  # precision: "bf16"  # BFloat16 (better numerical stability on Ampere+)

Learning Rate with Mixed Precision

When using precision: 16, consider increasing the learning rate slightly (e.g., from 0.001 to 0.002) as mixed precision training can tolerate higher learning rates.


Checkpointing

The ModelCheckpoint callback saves models according to a monitored metric:

callbacks:
  - _target_: pytorch_lightning.callbacks.ModelCheckpoint
    monitor: loss/val    # metric to watch
    mode: min            # save when the metric decreases
    save_top_k: 3        # keep the top 3 checkpoints
    save_last: true      # also save the last epoch
    filename: "epoch={epoch}-val_loss={loss/val:.4f}"
    dirpath: ./checkpoints/my_experiment

Resuming from Checkpoint

To resume a training run from a saved checkpoint, add resume_from_checkpoint to hyperparameters:

hyperparameters:
  batch_size: 8
  resume_from_checkpoint: ./checkpoints/my_experiment/last.ckpt

The training script detects this key and calls load_from_checkpoint before constructing the trainer, restoring model weights, optimizer state, and epoch counter.
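
If you rotate experiment directories, a small helper like this (hypothetical, not part of the framework) can resolve the newest .ckpt file before you write its path into the config:

```python
from pathlib import Path
from typing import Optional

def latest_checkpoint(ckpt_dir: str) -> Optional[Path]:
    """Return the most recently modified .ckpt file in ckpt_dir, or None."""
    ckpts = sorted(Path(ckpt_dir).glob("*.ckpt"),
                   key=lambda p: p.stat().st_mtime)
    return ckpts[-1] if ckpts else None
```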


Running Training

pytorch-smt --config-dir ./configs --config-name train_unet

Override individual parameters at the command line without editing the YAML:

# Change encoder
pytorch-smt --config-dir ./configs --config-name train_unet \
  model.encoder_name=resnet50

# Change batch size and learning rate
pytorch-smt --config-dir ./configs --config-name train_unet \
  hyperparameters.batch_size=16 optimizer.lr=5e-4

# Enable multi-GPU
pytorch-smt --config-dir ./configs --config-name train_unet \
  pl_trainer.devices=4 pl_trainer.strategy=ddp

Preview Your Config

Before training, print the fully resolved config to check interpolations and overrides:

pytorch-smt --config-dir ./configs --config-name train_unet --cfg job

Dataset Splits

The framework follows the standard three-way split used in machine learning:

| Config key | PyTorch Lightning step | When it runs | Metric prefix |
| --- | --- | --- | --- |
| train_dataset | training_step | Every batch during trainer.fit() | train/ |
| val_dataset | validation_step | End of every epoch during trainer.fit() | val/ |
| test_dataset | test_step | Once after trainer.fit(), via trainer.test() | test/ |

val_dataset and test_dataset are both optional. When val_dataset is absent, Lightning skips the validation loop entirely. When test_dataset is absent, trainer.test() is not called. You can use all three, only train + val, or only train.

When to use each split
  • val_dataset: use this for epoch-level monitoring during training — it drives early stopping, ModelCheckpoint, and LR schedulers that monitor a val/ metric.
  • test_dataset: use this for the final held-out evaluation reported in papers or production metrics. It is intentionally kept separate from val_dataset so that hyperparameter tuning decisions are not influenced by test-set performance.
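
If you still need to produce the three CSVs, a generic row-level split works. This sketch assumes one sample per CSV row and is not part of the framework:

```python
import random

def three_way_split(rows, val_frac=0.1, test_frac=0.1, seed=42):
    """Shuffle rows and return (train, val, test) lists."""
    rows = list(rows)
    random.Random(seed).shuffle(rows)  # fixed seed for a reproducible split
    n_test = int(len(rows) * test_frac)
    n_val = int(len(rows) * val_frac)
    return (rows[n_test + n_val:],            # train
            rows[n_test:n_test + n_val],      # val
            rows[:n_test])                    # test

train, val, test = three_way_split(range(100))
print(len(train), len(val), len(test))  # 80 10 10
```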

Logged Metrics

During training, the following metrics are written to the logger automatically:

| Metric key | When logged |
| --- | --- |
| loss/train | Each step and epoch |
| loss/val | Each validation epoch (requires val_dataset) |
| loss/test | Once after training (requires test_dataset) |
| train/<metric_name> | Per step and epoch (from the metrics list) |
| val/<metric_name> | Per validation epoch (from the metrics list) |
| test/<metric_name> | Once after training (from the metrics list) |
| losses/train_<name> | Per component when using a compound loss |
| losses/val_<name> | Per component when using a compound loss |
| losses/test_<name> | Per component when using a compound loss (test run) |
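
As an illustration of the naming scheme (the history values below are invented), picking the best epoch from logged values is a min over the loss/val series:

```python
# Invented logged history, keyed by the metric names from the table above.
history = {
    "loss/train": [0.61, 0.42, 0.35],
    "loss/val":   [0.58, 0.44, 0.39],
    "val/BinaryJaccardIndex": [0.41, 0.52, 0.57],
}

# Index of the epoch with the lowest validation loss.
best_epoch = min(range(len(history["loss/val"])),
                 key=history["loss/val"].__getitem__)
print(best_epoch, history["loss/val"][best_epoch])  # 2 0.39
```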