Main Module

The main entry point for the pytorch-segmentation-models-trainer CLI and core functionality.

Overview

The main module is the single entry point for training, prediction, mask building, configuration validation, and the other modes listed under Supported Modes. It uses Hydra for configuration management and selects the operation to run from the mode key.

CLI Usage

# Training
pytorch-smt --config-dir ./configs --config-name my_config +mode=train

# Prediction
pytorch-smt --config-dir ./configs --config-name my_config +mode=predict

# Mask building
pytorch-smt --config-dir ./configs --config-name my_config +mode=build-mask

# Configuration validation
pytorch-smt --config-dir ./configs --config-name my_config +mode=validate-config
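
When several modes are launched from a script, the commands above can be assembled programmatically. The following is a minimal sketch, assuming the pytorch-smt executable is on the PATH. The helper name build_command is hypothetical and is not part of the package.

```python
import shlex
from typing import List, Sequence

# Modes shown in the CLI examples above.
CLI_MODES = ("train", "predict", "build-mask", "validate-config")


def build_command(mode: str, config_dir: str = "./configs",
                  config_name: str = "my_config",
                  overrides: Sequence[str] = ()) -> List[str]:
    """Return the argv list for one pytorch-smt invocation (hypothetical helper)."""
    if not mode:
        raise ValueError("mode must be a non-empty string")
    return [
        "pytorch-smt",
        "--config-dir", config_dir,
        "--config-name", config_name,
        f"+mode={mode}",
        *overrides,  # extra Hydra overrides, e.g. "pl_trainer.max_epochs=5"
    ]


if __name__ == "__main__":
    for mode in CLI_MODES:
        # shlex.join quotes each argument so the line can be pasted into a shell.
        print(shlex.join(build_command(mode)))
```

The returned list can be passed directly to subprocess.run(cmd, check=True), which avoids shell quoting problems.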

Supported Modes

| Mode | Description | Use Case |
| --- | --- | --- |
| train | Train a model | Model training with PyTorch Lightning |
| predict | Run inference | Batch prediction on images |
| predict-from-batch | Batch prediction | Efficient batch processing |
| validate-config | Config validation | Debug configuration files |
| build-mask | Build masks | Generate training masks from vectors |
| evaluate-experiments | Run evaluation pipeline | Compare model outputs against ground truth |
| convert-dataset | Convert dataset format | Prepare data for Polygon-RNN models |
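
One way to picture how a single entry point serves all of these modes is a dispatch table keyed by mode name. The sketch below is illustrative only: the handler functions and the dispatch helper are hypothetical placeholders and do not reflect the package's internal implementation.

```python
from typing import Any, Callable, Dict, Mapping

# Mode names taken from the table above.
SUPPORTED_MODES = (
    "train", "predict", "predict-from-batch", "validate-config",
    "build-mask", "evaluate-experiments", "convert-dataset",
)

Handler = Callable[[Mapping[str, Any]], str]


def _placeholder(mode: str) -> Handler:
    """Build a stand-in handler; a real handler would do the actual work."""
    def handler(cfg: Mapping[str, Any]) -> str:
        return f"ran {mode}"
    return handler


# Hypothetical registry: one handler per supported mode.
HANDLERS: Dict[str, Handler] = {mode: _placeholder(mode) for mode in SUPPORTED_MODES}


def dispatch(cfg: Mapping[str, Any]) -> str:
    """Look up cfg['mode'] and call the matching handler."""
    mode = cfg.get("mode")
    if mode not in HANDLERS:
        raise ValueError(
            f"Unsupported mode {mode!r}; expected one of {sorted(HANDLERS)}"
        )
    return HANDLERS[mode](cfg)


if __name__ == "__main__":
    print(dispatch({"mode": "build-mask"}))
```

Failing early on an unknown or missing mode gives a clearer error than letting a mistyped name propagate into the run.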

Configuration Structure

The main function expects a Hydra configuration with the following structure:

# Operation mode
mode: train # or predict, build-mask, etc.

# Model configuration
model:
  _target_: segmentation_models_pytorch.Unet
  # ... model parameters

# Training configuration (for train mode)
pl_trainer:
  max_epochs: 100
  gpus: 1

# Prediction configuration (for predict mode)
checkpoint_path: /path/to/model.ckpt
inference_threshold: 0.5
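
Before launching a run, it can help to check that the keys a mode relies on are present. The following is a minimal sketch over a plain dictionary. The REQUIRED_KEYS mapping is an assumption read off the comments in the structure above (pl_trainer for train mode, checkpoint_path for predict mode), not the package's actual validation rules, and missing_keys is a hypothetical helper.

```python
from typing import Any, Dict, List, Mapping

# Assumed requirements, inferred from the comments in the structure above.
REQUIRED_KEYS: Dict[str, List[str]] = {
    "train": ["model", "pl_trainer"],
    "predict": ["model", "checkpoint_path"],
}


def missing_keys(cfg: Mapping[str, Any]) -> List[str]:
    """Return the top-level keys that the selected mode needs but cfg lacks."""
    mode = cfg.get("mode")
    if mode is None:
        return ["mode"]
    # Modes without an entry in REQUIRED_KEYS are not checked by this sketch.
    return [key for key in REQUIRED_KEYS.get(mode, []) if key not in cfg]


if __name__ == "__main__":
    # Plain-dict mirror of the YAML structure shown above.
    cfg = {
        "mode": "train",
        "model": {"_target_": "segmentation_models_pytorch.Unet"},
        "pl_trainer": {"max_epochs": 100, "gpus": 1},
    }
    print(missing_keys(cfg))
```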

Examples

Programmatic Usage

from pytorch_segmentation_models_trainer.main import main
from omegaconf import DictConfig

# Create configuration
config = DictConfig({
    "mode": "train",
    "model": {
        "_target_": "segmentation_models_pytorch.Unet",
        "encoder_name": "resnet34",
        "classes": 1,
    },
    "pl_trainer": {
        "max_epochs": 10,
        "gpus": 0,
    },
    # ... other config
})

# Run training
main(config)

Custom Entry Point

import hydra
from omegaconf import DictConfig, open_dict
from pytorch_segmentation_models_trainer.main import main

@hydra.main(config_path="configs", config_name="my_config")
def my_main(cfg: DictConfig):
    # Custom preprocessing. Hydra passes the config in struct mode,
    # so new keys must be added inside open_dict().
    with open_dict(cfg):
        cfg.custom_param = "my_value"

    # Run main function
    return main(cfg)

if __name__ == "__main__":
    my_main()

Logging

Configure logging and the Hydra output directory in your configuration file:

hydra:
  job:
    chdir: true
  run:
    dir: ./outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}

logging:
  level: INFO
  format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
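
The format string in the logging block uses the standard Python logging placeholders. As a minimal sketch of what the level and format settings correspond to in plain Python, the same values can be applied with the standard library. How the package itself consumes the logging block is not shown here, so treat this wiring as an assumption; make_logger is a hypothetical helper.

```python
import io
import logging

# Values copied from the logging block above.
LOG_LEVEL = "INFO"
LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"


def make_logger(name: str, stream) -> logging.Logger:
    """Create a logger that writes records to `stream` using LOG_FORMAT."""
    logger = logging.getLogger(name)
    logger.setLevel(getattr(logging, LOG_LEVEL))
    logger.propagate = False  # keep the example output in one place
    handler = logging.StreamHandler(stream)
    handler.setFormatter(logging.Formatter(LOG_FORMAT))
    logger.handlers = [handler]  # replace any handlers from earlier calls
    return logger


if __name__ == "__main__":
    buffer = io.StringIO()
    log = make_logger("example", buffer)
    log.debug("hidden: below the INFO threshold")
    log.info("training started")
    print(buffer.getvalue(), end="")
```

With the level set to INFO, the debug call produces no output and only the second message is written, prefixed by the timestamp, logger name, and level.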