
Running Inference

After training a model, you can run inference on new images using either of two CLI modes: predict for single-image sliding-window processing, or predict-from-batch for batch processing via PyTorch Lightning's Trainer.predict.


Choosing the Right Mode

| Mode | Entry Point | Best For |
|---|---|---|
| `predict` | `predict.py` | Large images, sliding-window tiling, one-by-one processing |
| `predict-from-batch` | `predict_from_batch.py` | Batch inference with PyTorch Lightning, CSV-driven datasets |

predict Mode

The predict script reads a list of images from an image reader, processes each image through an inference processor (with optional sliding-window tiling), applies a threshold, and writes results using an export strategy.
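The flow described above can be sketched in a few lines of Python. All names here are hypothetical stand-ins for the configured components (the real script wires them together from the Hydra config):

```python
import numpy as np

def run_predict(image_reader, processor, export_strategy,
                threshold=0.5, save_inference=True):
    """Illustrative sketch of the predict loop: read, infer, threshold, export."""
    for image in image_reader:                  # reader yields one image at a time
        probability = processor(image)          # forward pass (+ optional tiling/TTA)
        mask = (probability > threshold).astype(np.uint8)  # binarise
        if save_inference:
            export_strategy(mask)               # write the result to disk/DB
```

Each of the three pluggable pieces (reader, processor, export strategy) is configured independently, which is why the sections below treat them separately.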

Running the Command

```bash
pytorch-smt --config-dir ./configs --config-name predict_config
```

Override parameters on the command line without editing the YAML:

```bash
pytorch-smt --config-dir ./configs --config-name predict_config \
  checkpoint_path=/new/model.ckpt \
  device=cuda:1 \
  inference_threshold=0.4
```

Top-Level Config Keys

| Key | Type | Description |
|---|---|---|
| `checkpoint_path` | str | Path to the `.ckpt` file saved during training |
| `device` | str | Compute device: `cuda:0`, `cuda:1`, `cpu`, etc. |
| `inference_image_reader` | object | Defines which images to process |
| `inference_processor` | object | Defines how each image is processed |
| `export_strategy` | object | Defines where and how output is saved |
| `inference_threshold` | float | Binarisation threshold (default 0.5) |
| `save_inference` | bool | Whether to write output files (default true) |
| `pl_model` | object | PyTorch Lightning model class (same as training config) |
| `hyperparameters` | object | Must include `batch_size` |
| `seg_params` | object | Optional; controls number of output mask bands |

Image Readers

Image readers are set under inference_image_reader and tell the predict script which files to process.

SingleImageReaderProcessor

Processes a single image file.

```yaml
inference_image_reader:
  _target_: pytorch_segmentation_models_trainer.tools.data_handlers.raster_reader.SingleImageReaderProcessor
  file_name: /data/images/tile_001.tif
```

FolderImageReaderProcessor

Scans a folder for all images matching a given extension.

```yaml
inference_image_reader:
  _target_: pytorch_segmentation_models_trainer.tools.data_handlers.raster_reader.FolderImageReaderProcessor
  folder_name: /data/images/
  image_extension: tif  # file extension to match (without dot)
  recursive: true       # search subfolders recursively
```
| Parameter | Default | Description |
|---|---|---|
| `folder_name` | required | Root folder to scan |
| `image_extension` | `tif` | Extension to match (e.g. `tif`, `png`, `jpg`) |
| `recursive` | `true` | Recurse into subdirectories |

CSVImageReaderProcessor

Reads image paths from a CSV file column.

```yaml
inference_image_reader:
  _target_: pytorch_segmentation_models_trainer.tools.data_handlers.raster_reader.CSVImageReaderProcessor
  input_csv_path: /data/test.csv
  key: image      # column name containing image paths
  root_dir: /data # prepended to relative paths
```

Inference Processors

The inference processor does the actual model forward pass and handles tiling for large images.

Sliding-Window Tiling

All processors based on SingleImageInfereceProcessor use sliding-window tiling from pytorch-toolbelt:

  1. The input image is padded to a multiple of model_input_shape.
  2. An ImageSlicer divides the padded image into overlapping tiles of size model_input_shape, with a stride of step_shape.
  3. Each batch of tiles is passed through the model with torch.no_grad() and AMP (autocast).
  4. A TileMerger accumulates predictions using weighted averaging in the overlap regions, reducing seam artifacts.
  5. The merged mask is cropped back to the original image dimensions.
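The five steps can be sketched in NumPy. This is an illustrative reimplementation, not the library's code: it uses uniform overlap weights, whereas pytorch-toolbelt's `TileMerger` also supports tapered weights that further suppress seams:

```python
import numpy as np

def sliding_window_predict(image, model, tile=64, step=32):
    """Sketch of steps 1-5: pad, slice, predict per tile, weighted-merge, crop."""
    h, w = image.shape
    # 1. Pad the image up to a multiple of the tile size.
    ph = int(np.ceil(h / tile)) * tile
    pw = int(np.ceil(w / tile)) * tile
    padded = np.zeros((ph, pw), dtype=image.dtype)
    padded[:h, :w] = image
    # Accumulators for averaging predictions in overlap regions (step 4).
    acc = np.zeros((ph, pw), dtype=np.float64)
    weight = np.zeros((ph, pw), dtype=np.float64)
    # 2-3. Slide a tile-sized window with the given stride; predict each tile.
    for y in range(0, ph - tile + 1, step):
        for x in range(0, pw - tile + 1, step):
            pred = model(padded[y:y + tile, x:x + tile])
            acc[y:y + tile, x:x + tile] += pred
            weight[y:y + tile, x:x + tile] += 1.0
    # 4. Average overlapping predictions (uniform weights in this sketch).
    merged = acc / np.maximum(weight, 1e-9)
    # 5. Crop back to the original image dimensions.
    return merged[:h, :w]
```

In the real pipeline the inner loop is batched and runs under `torch.no_grad()` with autocast, but the padding, striding, averaging, and cropping logic is the same idea.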

Overlap is controlled by the gap between model_input_shape and step_shape. A smaller step_shape means more overlap and smoother blending at the cost of more forward passes.

Choosing step_shape

A common setting is step_shape = model_input_shape / 2. For example, with model_input_shape: [448, 448], set step_shape: [224, 224] for 50% overlap. Reduce the step further for very high-detail predictions.
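To see the cost side of that trade-off, count window positions per axis. A minimal helper (hypothetical name, matching the padding rule described above):

```python
import math

def num_forward_passes(image_size, tile, step):
    """Window positions along one axis after padding the image to a tile multiple."""
    padded = math.ceil(image_size / tile) * tile
    return (padded - tile) // step + 1
```

On a 4480 px axis with 448 px tiles, a step of 224 gives 19 positions per axis while a step of 112 gives 37, so halving the step roughly quadruples the total tile count per image (19² vs 37²).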

Test Time Augmentation (TTA)

TTA improves prediction quality by running the model on multiple transformed versions of each tile, reversing the transformation on each output, and averaging the results before handing the tile to the TileMerger. The sliding-window pipeline itself is not affected — it always receives one already-consolidated prediction per tile.
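The transform/predict/invert/average cycle can be sketched in NumPy. This is an illustration of the scheme, not the library's implementation; the two composite flip+rotation entries are omitted for brevity:

```python
import numpy as np

# (transform, inverse) pairs. Rotations invert by rotating back; flips are
# self-inverse.
TTA = {
    "rot0":   (lambda x: x,              lambda x: x),
    "rot90":  (lambda x: np.rot90(x, 1), lambda x: np.rot90(x, -1)),
    "rot180": (lambda x: np.rot90(x, 2), lambda x: np.rot90(x, -2)),
    "rot270": (lambda x: np.rot90(x, 3), lambda x: np.rot90(x, -3)),
    "flip_h": (np.fliplr, np.fliplr),
    "flip_v": (np.flipud, np.flipud),
}

def tta_predict(tile, model,
                augmentations=("rot0", "rot90", "rot180", "rot270")):
    """Predict on each transformed tile, undo the transform, then average."""
    outputs = []
    for name in augmentations:
        forward, inverse = TTA[name]
        outputs.append(inverse(model(forward(tile))))
    return np.mean(outputs, axis=0)
```

Because every transform is undone before averaging, an ideal equivariant model would produce identical outputs for all entries; in practice the outputs differ slightly, and the average is more robust than any single pass.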

Add use_tta: true and tta_augmentations to any inference_processor block:

```yaml
inference_processor:
  _target_: pytorch_segmentation_models_trainer.tools.inference.inference_processors.SingleImageInfereceProcessor
  model_input_shape: [448, 448]
  step_shape: [224, 224]
  use_tta: true
  tta_augmentations:
    - rot0          # original image (identity — always include this)
    - rot90         # 90° counter-clockwise rotation
    - rot180        # 180° rotation
    - rot270        # 270° counter-clockwise rotation
    - flip_h        # horizontal flip (left-right mirror)
    - flip_v        # vertical flip (up-down mirror)
    - flip_h_rot90  # horizontal flip + 90° counter-clockwise rotation
    - flip_v_rot90  # vertical flip + 90° counter-clockwise rotation
```

The eight augmentations form the D4 dihedral group — all symmetries of a square. You can use any subset; the default when tta_augmentations is omitted is the four 90° rotations (rot0, rot90, rot180, rot270).

| Augmentation name | Transformation |
|---|---|
| `rot0` | Identity — original image |
| `rot90` | 90° counter-clockwise rotation |
| `rot180` | 180° rotation |
| `rot270` | 270° counter-clockwise rotation |
| `flip_h` | Horizontal flip (left-right mirror) |
| `flip_v` | Vertical flip (up-down mirror) |
| `flip_h_rot90` | Horizontal flip + 90° counter-clockwise rotation |
| `flip_v_rot90` | Vertical flip + 90° counter-clockwise rotation |

Each augmentation has an exact mathematical inverse — predictions are de-augmented before averaging, so no spatial artifact is introduced.

TTA and computational cost

Each augmentation in tta_augmentations adds one full model forward pass per tile. Four rotations = 4× the inference time; the full D4 group = 8×. Balance quality vs. speed accordingly.

Frame-field models (SingleImageFromFrameFieldProcessor)

The crossfield output encodes tangent angles that require non-trivial transformation under spatial rotations. TTA is applied normally to seg; crossfield is always taken from the identity (rot0) pass. Include rot0 in tta_augmentations when using this processor with TTA enabled.

See the TTA advanced guide for full details, API reference, and config examples.

SingleImageInfereceProcessor

General-purpose sliding-window processor for binary or single-output segmentation.

```yaml
inference_processor:
  _target_: pytorch_segmentation_models_trainer.tools.inference.inference_processors.SingleImageInfereceProcessor
  model_input_shape: [448, 448] # tile size fed to the model
  step_shape: [224, 224]        # sliding step (controls overlap)
```
| Parameter | Default | Description |
|---|---|---|
| `model_input_shape` | `[448, 448]` | Height and width of each tile |
| `step_shape` | `[224, 224]` | Stride of the sliding window |
| `normalize_mean` | `null` | Per-channel mean for normalization (ImageNet default if null) |
| `normalize_std` | `null` | Per-channel std for normalization (ImageNet default if null) |

MultiClassInferenceProcessor

For multi-class segmentation models that return [B, num_classes, H, W] logits. After tile merging, argmax is applied across the class dimension to produce a single-band class-index map with values 0 to num_classes - 1.

```yaml
inference_processor:
  _target_: pytorch_segmentation_models_trainer.tools.inference.inference_processors.MultiClassInferenceProcessor
  model_input_shape: [512, 512]
  step_shape: [256, 256]
  num_classes: 5
```
| Parameter | Default | Description |
|---|---|---|
| `num_classes` | `2` | Number of output classes |
| `model_input_shape` | `[448, 448]` | Tile size |
| `step_shape` | `[224, 224]` | Sliding stride |

The output raster has 1 band, dtype=uint8, with pixel values equal to the predicted class index. The threshold parameter is ignored for this processor.
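The argmax step amounts to the following (illustrative sketch with NumPy, operating on the merged logits):

```python
import numpy as np

def logits_to_class_map(merged_logits):
    """Collapse merged [num_classes, H, W] logits into a single-band uint8
    class-index map, values 0 .. num_classes - 1."""
    return np.argmax(merged_logits, axis=0).astype(np.uint8)
```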

SingleImageFromFrameFieldProcessor

For frame field segmentation models (e.g. FrameFieldSegmentationPLModel) that return both a segmentation mask and a cross-field tensor. Both outputs are tiled and merged independently.

```yaml
inference_processor:
  _target_: pytorch_segmentation_models_trainer.tools.inference.inference_processors.SingleImageFromFrameFieldProcessor
  model_input_shape: [448, 448]
  step_shape: [224, 224]
  mask_bands: 1
```

This processor is required when using any frame-field-based polygonizer (e.g. ASMPolygonizerProcessor, ACMPolygonizerProcessor), since they consume the crossfield key from the inference output.

BinaryMaskInferenceProcessor / Base Class Threshold

All segmentation processors apply the configured threshold after tile merging:

```python
output_mask = (merged_probability > threshold).astype(uint8)
```

Set the threshold in the top-level config:

```yaml
inference_threshold: 0.5
```

Normalization

By default, all processors normalize images using ImageNet statistics via albumentations.Normalize() (mean [0.485, 0.456, 0.406], std [0.229, 0.224, 0.225]).
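The arithmetic `albumentations.Normalize()` applies is equivalent to the following sketch (scale pixel values to [0, 1], then standardize per channel):

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])

def normalize(image_uint8, mean=IMAGENET_MEAN, std=IMAGENET_STD,
              max_pixel_value=255.0):
    """Per-channel normalization: (img / max_pixel_value - mean) / std."""
    return (image_uint8.astype(np.float32) / max_pixel_value - mean) / std
```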

To override normalization, set normalize_mean and normalize_std in the inference_processor block:

```yaml
inference_processor:
  _target_: pytorch_segmentation_models_trainer.tools.inference.inference_processors.SingleImageInfereceProcessor
  model_input_shape: [448, 448]
  step_shape: [224, 224]
  normalize_mean: [0.5, 0.5, 0.5]
  normalize_std: [0.25, 0.25, 0.25]
```
Custom Normalization

The mean and std values must match whatever normalization was applied during training. Mismatched normalization is a common source of poor inference results.


Export Strategies

Export strategies control how inference outputs are written to disk or a database. Set them under the export_strategy key.

RasterExportInferenceStrategy

Saves a single output raster (the seg band) to a fixed output path.

```yaml
export_strategy:
  _target_: pytorch_segmentation_models_trainer.tools.inference.export_inference.RasterExportInferenceStrategy
  output_file_path: /output/prediction.tif
```

MultipleRasterExportInferenceStrategy

Saves multiple output bands (e.g. seg and crossfield) as separate GeoTIFF files in a folder, one file per input image.

```yaml
export_strategy:
  _target_: pytorch_segmentation_models_trainer.tools.inference.export_inference.MultipleRasterExportInferenceStrategy
  output_folder: /output/predictions/
  output_basename: pred.tif
```

Output filenames follow the pattern {band_key}_{input_name}_{output_basename}.
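Read as code, the pattern is (illustrative helper; the strategy builds the name internally):

```python
def output_filename(band_key, input_name, output_basename):
    """Join the band key, input image name, and configured basename."""
    return f"{band_key}_{input_name}_{output_basename}"
```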

VectorFileExportInferenceStrategy

Writes inference output as a vector file (GeoJSON or Shapefile). This is typically used together with a polygonizer.

```yaml
export_strategy:
  _target_: pytorch_segmentation_models_trainer.tools.inference.export_inference.VectorFileExportInferenceStrategy
  output_file_path: /output/predictions.geojson
  driver: GeoJSON # or "ESRI Shapefile"
```

VectorDatabaseExportInferenceStrategy

Writes polygons directly to a PostGIS-enabled PostgreSQL database.

```yaml
export_strategy:
  _target_: pytorch_segmentation_models_trainer.tools.inference.export_inference.VectorDatabaseExportInferenceStrategy
  user: postgres
  password: secret
  database: geodata
  host: localhost
  port: 5432
  sql: "SELECT id, geom FROM buildings"
  table_name: buildings
  geometry_column: geom
```

predict-from-batch Mode

The predict-from-batch script loads images from a CSV (or builds one from a folder), groups them by spatial dimensions to enable efficient batching, and runs PyTorch Lightning's Trainer.predict.
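The grouping idea is simple: images in one batch must share a height and width, so records are bucketed by shape before batching. A minimal sketch (hypothetical helper; the actual implementation reads shapes from the dataset records):

```python
from collections import defaultdict

def group_by_shape(entries):
    """Bucket (path, (height, width)) records so each batch has uniform size."""
    groups = defaultdict(list)
    for path, (h, w) in entries:
        groups[(h, w)].append(path)
    return dict(groups)
```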

Running the Command

```bash
pytorch-smt --config-dir ./configs --config-name predict_batch_config +mode=predict-from-batch
```

Dataset Configuration

Three input modes are supported, tried in this priority order:

Mode 1 — CSV file:

```yaml
inference_dataset:
  input_csv_path: /data/test.csv
  root_dir: /data
  data_loader:
    num_workers: 4
    prefetch_factor: 2
```

Mode 2 — Build CSV from folder:

```yaml
inference_dataset:
  build_csv_from_folder:
    enabled: true
    images_folder: /data/images/
  root_dir: /data
```

Mode 3 — Legacy (val_dataset):

```yaml
val_dataset:
  input_csv_path: /data/val.csv
  root_dir: /data
```
Windowed Inference

Set use_inference_processor: true in the batch config to use TiledInferenceImageDataset, which applies sliding-window tiling within the dataloader rather than inside the processor. This requires inference_processor.model_input_shape and inference_processor.step_shape to be defined.


Full Config Examples

Binary Segmentation (Single Image)

configs/predict_binary.yaml
```yaml
# ── Checkpoint ────────────────────────────────────────────────────────────────
checkpoint_path: /checkpoints/unet_best.ckpt
device: cuda:0

# ── Model (must match training config) ───────────────────────────────────────
pl_model:
  _target_: pytorch_segmentation_models_trainer.model_loader.model.Model

hyperparameters:
  batch_size: 8

# ── Image Reader ──────────────────────────────────────────────────────────────
inference_image_reader:
  _target_: pytorch_segmentation_models_trainer.tools.data_handlers.raster_reader.FolderImageReaderProcessor
  folder_name: /data/test_images/
  image_extension: tif
  recursive: true

# ── Inference Processor ───────────────────────────────────────────────────────
inference_processor:
  _target_: pytorch_segmentation_models_trainer.tools.inference.inference_processors.SingleImageInfereceProcessor
  model_input_shape: [448, 448]
  step_shape: [224, 224]

# ── Export Strategy ───────────────────────────────────────────────────────────
export_strategy:
  _target_: pytorch_segmentation_models_trainer.tools.inference.export_inference.MultipleRasterExportInferenceStrategy
  output_folder: /output/predictions/
  output_basename: pred.tif

# ── Threshold ─────────────────────────────────────────────────────────────────
inference_threshold: 0.5
```

Multi-Class Segmentation

configs/predict_multiclass.yaml
```yaml
checkpoint_path: /checkpoints/multiclass_best.ckpt
device: cuda:0

pl_model:
  _target_: pytorch_segmentation_models_trainer.model_loader.model.Model

hyperparameters:
  batch_size: 4

inference_image_reader:
  _target_: pytorch_segmentation_models_trainer.tools.data_handlers.raster_reader.FolderImageReaderProcessor
  folder_name: /data/test_images/
  image_extension: tif
  recursive: false

inference_processor:
  _target_: pytorch_segmentation_models_trainer.tools.inference.inference_processors.MultiClassInferenceProcessor
  model_input_shape: [512, 512]
  step_shape: [256, 256]
  num_classes: 5

export_strategy:
  _target_: pytorch_segmentation_models_trainer.tools.inference.export_inference.RasterExportInferenceStrategy
  output_file_path: /output/multiclass_pred.tif
```

Sliding-Window on Large Aerial Images

configs/predict_large_image.yaml
```yaml
checkpoint_path: /checkpoints/building_detector.ckpt
device: cuda:0

pl_model:
  _target_: pytorch_segmentation_models_trainer.model_loader.model.Model

hyperparameters:
  batch_size: 16 # larger batch for tile processing

inference_image_reader:
  _target_: pytorch_segmentation_models_trainer.tools.data_handlers.raster_reader.SingleImageReaderProcessor
  file_name: /data/large_ortho.tif

inference_processor:
  _target_: pytorch_segmentation_models_trainer.tools.inference.inference_processors.SingleImageInfereceProcessor
  model_input_shape: [448, 448]
  step_shape: [112, 112] # 75% overlap for high-detail output
  normalize_mean: [0.485, 0.456, 0.406]
  normalize_std: [0.229, 0.224, 0.225]

export_strategy:
  _target_: pytorch_segmentation_models_trainer.tools.inference.export_inference.RasterExportInferenceStrategy
  output_file_path: /output/large_ortho_pred.tif

inference_threshold: 0.45
save_inference: true
```

PyTorch Lightning Trainer Settings for Inference

Both predict and predict-from-batch use the pl_trainer config block. For inference you typically only need accelerator and device settings:

```yaml
pl_trainer:
  accelerator: gpu
  devices: [0] # specific GPU index; use -1 for all GPUs
```