
Model Classes

This page is the API reference for all PyTorch Lightning model classes in pytorch_segmentation_models_trainer.


Model

Base model class. All other model classes extend Model.

Import path: pytorch_segmentation_models_trainer.model_loader.model.Model

Constructor

Model(cfg, inference_mode=False)
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| cfg | DictConfig | required | Hydra configuration object |
| inference_mode | bool | False | When True, skips dataset and loss instantiation (for inference-only use) |

Key Methods

| Method | Description |
| --- | --- |
| get_model() -> nn.Module | Instantiates the backbone from cfg.model via Hydra's instantiate. Optionally replaces activations if cfg.replace_model_activation is set |
| get_loss_function() -> Union[nn.Module, MultiLoss] | Resolves the loss function using three dispatch paths (see below) |
| training_step(batch, batch_idx) -> Tensor | Unpacks batch as (images, masks), runs forward pass, computes loss (simple or compound), logs loss/train and individual component losses |
| validation_step(batch, batch_idx) -> Tensor | Same as training_step but logs loss/val; also logs val/ metrics if cfg.metrics is configured |
| test_step(batch, batch_idx) -> Tensor | Runs after training via trainer.test(); logs loss/test and test/ metrics. Applies gpu_test_transform if configured |
| configure_optimizers() -> Tuple[List, List] | Builds optimizer from cfg.optimizer; builds scheduler list from cfg.scheduler_list. Automatically computes steps_per_epoch for OneCycleLR by reading the training CSV |
| train_dataloader() -> DataLoader | Creates DataLoader for train_ds using cfg.hyperparameters.batch_size and cfg.train_dataset.data_loader settings |
| val_dataloader() -> Optional[DataLoader] | Creates DataLoader for val_ds, or returns None if val_dataset is absent from the config |
| test_dataloader() -> Optional[DataLoader] | Creates DataLoader for test_ds, or returns None if test_dataset is absent from the config |
| forward(x) -> Tensor | Delegates to self.model(x) |
| predict_step(batch, batch_idx, dataloader_idx=0) -> Tensor | Runs self(batch) for inference |
| set_encoder_trainable(trainable=False) -> None | Freezes or unfreezes the model.encoder parameters |
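
configure_optimizers() derives steps_per_epoch for OneCycleLR from the training CSV. A minimal sketch of that computation, assuming the CSV has one row per sample (header excluded) and that devices and accumulate_grad_batches default to 1 when absent from the config; the helper name is illustrative, not the library's:

```python
import csv
import io
import math

def steps_per_epoch(csv_text, batch_size, devices=1, accumulate_grad_batches=1):
    """Approximate optimizer steps per epoch from a training CSV.

    Assumes one CSV row per training sample, plus a header row.
    """
    reader = csv.reader(io.StringIO(csv_text))
    n_samples = sum(1 for _ in reader) - 1  # subtract the header row
    effective_batch = batch_size * devices * accumulate_grad_batches
    return math.ceil(n_samples / effective_batch)

# Example: 10 samples, batch size 4 -> ceil(10 / 4) = 3 steps per epoch
csv_text = "image,mask\n" + "\n".join(f"img{i}.png,msk{i}.png" for i in range(10))
print(steps_per_epoch(csv_text, batch_size=4))  # 3
```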

Configuration Keys

| Key | Required | Description |
| --- | --- | --- |
| model | Yes | Hydra target for the backbone architecture |
| loss | Conditional | Simple loss (lowest priority dispatch path) |
| loss_params.compound_loss | Conditional | New YAML-based compound loss config (highest priority) |
| loss_params.multi_loss | Conditional | Legacy multi-loss config (medium priority) |
| optimizer | Yes | Hydra target for the optimizer |
| scheduler_list | No | List of scheduler configs; each entry has a scheduler key and Lightning interval/frequency keys |
| hyperparameters.batch_size | Yes | Batch size for train, val, and test dataloaders |
| hyperparameters.epochs | No | Used when auto-computing steps_per_epoch |
| pl_trainer | No | Passed to the Lightning Trainer; also checked for devices and accumulate_grad_batches |
| callbacks | No | List of Hydra-instantiated callbacks |
| metrics | No | List of torchmetrics metrics to compute; auto-prefixed train/, val/, or test/ |
| logger | No | Logger configuration |
| train_dataset | Yes (unless inference_mode) | Dataset config including input_csv_path and data_loader sub-config |
| val_dataset | No | Dataset config for per-epoch validation during fit. When absent, val_dataloader() returns None and Lightning skips the validation loop |
| test_dataset | No | Dataset config for final held-out evaluation. When present, trainer.test() is called automatically after fit; metrics are logged with the test/ prefix |
| replace_model_activation | No | Dict with old_activation and new_activation for activation replacement |
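
The three conditional loss keys above resolve in a fixed priority order. A schematic sketch of that precedence, using plain dicts (the function body is illustrative, not the library's implementation):

```python
def resolve_loss_key(cfg: dict) -> str:
    """Return the config key the loss dispatch would resolve to.

    Mirrors the documented priority:
    loss_params.compound_loss > loss_params.multi_loss > loss.
    """
    loss_params = cfg.get("loss_params") or {}
    if "compound_loss" in loss_params:
        return "loss_params.compound_loss"  # new YAML-based path, highest priority
    if "multi_loss" in loss_params:
        return "loss_params.multi_loss"     # legacy path, medium priority
    if "loss" in cfg:
        return "loss"                       # simple loss, lowest priority
    raise KeyError("no loss configured")

print(resolve_loss_key({"loss": {"_target_": "torch.nn.BCEWithLogitsLoss"}}))  # loss
```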

FrameFieldSegmentationPLModel

PyTorch Lightning model for frame field segmentation. Extends Model with frame-field-specific loss building, normalization, and output handling.

Import path: pytorch_segmentation_models_trainer.model_loader.frame_field_model.FrameFieldSegmentationPLModel

Constructor

FrameFieldSegmentationPLModel(cfg: DictConfig)
| Parameter | Type | Description |
| --- | --- | --- |
| cfg | DictConfig | Hydra configuration. Must include backbone, compute_seg, compute_crossfield, and seg_params in addition to the base Model keys |

Additional configuration keys beyond Model:

| Key | Description |
| --- | --- |
| backbone | Hydra target for the frame field backbone (e.g., a FrameFieldModel wrapper) |
| compute_seg | Whether to compute segmentation output |
| compute_crossfield | Whether to compute cross-field output |
| seg_params.compute_interior | Whether to predict the interior mask channel |
| seg_params.compute_edge | Whether to predict the edge mask channel |
| seg_params.compute_vertex | Whether to predict the vertex mask channel |
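
Put together, a Hydra config for this model might contain a fragment like the following. This is an illustrative sketch only; the backbone _target_ path is a placeholder, not the library's real import path:

```yaml
backbone:
  _target_: some.module.FrameFieldModel  # placeholder target, not the actual path
compute_seg: true
compute_crossfield: true
seg_params:
  compute_interior: true
  compute_edge: true
  compute_vertex: false
```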

Key Methods

| Method | Description |
| --- | --- |
| training_step(batch, batch_idx) -> Tensor | Runs forward pass, computes MultiLoss, logs loss/train and all component losses under losses/train_{name}; also logs iou/train_IoU_{threshold} |
| validation_step(batch, batch_idx) -> Tensor | Same as training but logs loss/val, losses/val_{name}, and iou/val_IoU_{threshold} |
| test_step(batch, batch_idx) -> Tensor | Runs after training via trainer.test(); logs loss/test, losses/test_{name}, and iou/test_IoU_{threshold} |
| on_train_start() -> None | Triggers loss normalization computation via _compute_loss_normalization() |

Model output dict keys:

| Key | Shape | Description |
| --- | --- | --- |
| "seg" | (N, C_seg, H, W) | Segmentation probabilities |
| "crossfield" | (N, 4, H, W) | Cross-field coefficients c0 and c2 (real and imaginary parts, two channels each) |

ObjectDetectionPLModel

PyTorch Lightning model for object detection using Torchvision-style detection models (e.g., Faster R-CNN).

Import path: pytorch_segmentation_models_trainer.model_loader.detection_model.ObjectDetectionPLModel

Constructor

ObjectDetectionPLModel(cfg)

The loss function is managed internally by the detection model itself (Torchvision detection models return a loss_dict in training mode), so get_loss_function() returns None.

Behavior

  • training_step: calls self.model(images, targets), sums the returned loss_dict, logs total and individual losses
  • validation_step: computes loss in train mode, then switches to eval mode to compute box IoU; logs loss/val and metrics/val_iou
  • train_dataloader / val_dataloader: uses the dataset's collate_fn (required for detection batches with variable-length box lists)
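
Torchvision detection models return a dict of named losses in training mode, and the training step reduces it to a single scalar. A minimal sketch of that reduction, with plain floats standing in for tensors:

```python
def total_detection_loss(loss_dict):
    """Sum the per-component losses returned by a torchvision detection model."""
    return sum(loss_dict.values())

# e.g. the components Faster R-CNN reports in training mode
loss_dict = {
    "loss_classifier": 0.7,
    "loss_box_reg": 0.3,
    "loss_objectness": 0.1,
    "loss_rpn_box_reg": 0.05,
}
total = total_detection_loss(loss_dict)  # approximately 1.15
```

With real tensors, the summed scalar is what gets backpropagated and logged as loss/train.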

InstanceSegmentationPLModel

Extends ObjectDetectionPLModel for instance segmentation (e.g., Mask R-CNN). No additional logic is added; mask handling is done by the underlying Torchvision model.

Import path: pytorch_segmentation_models_trainer.model_loader.detection_model.InstanceSegmentationPLModel

Constructor

InstanceSegmentationPLModel(cfg)

Inherits all behavior from ObjectDetectionPLModel.


PolygonRNNPLModel

PyTorch Lightning model for vertex-by-vertex polygon prediction using a recurrent neural network.

Import path: pytorch_segmentation_models_trainer.model_loader.polygon_rnn_model.PolygonRNNPLModel

Constructor

PolygonRNNPLModel(cfg, grid_size=28)
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| cfg | DictConfig | required | Hydra configuration |
| grid_size | int | 28 | Grid size for the polygon RNN output; output logits have grid_size * grid_size + 3 classes (grid cells + EOS + first vertex + padding) |
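
The grid_size * grid_size + 3 output classes can be read as a flattened vertex grid plus three special tokens. A hypothetical encoding/decoding helper illustrating that layout; the flattening order is an assumption for illustration, not taken from the library:

```python
GRID = 28
N_CLASSES = GRID * GRID + 3  # 784 grid cells + 3 special tokens = 787

def vertex_to_token(x: int, y: int, grid: int = GRID) -> int:
    """Flatten a grid vertex (x, y) to a class index in [0, grid*grid).
    Row-major order is assumed here for illustration."""
    return y * grid + x

def token_to_vertex(token: int, grid: int = GRID):
    """Invert vertex_to_token for non-special tokens."""
    return token % grid, token // grid

assert N_CLASSES == 787
assert token_to_vertex(vertex_to_token(5, 7)) == (5, 7)
```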

Configuration Keys

| Key | Description |
| --- | --- |
| model.load_vgg | Whether to load pretrained VGG weights for the CNN encoder (default: False) |
| val_dataset.sequence_length | Maximum sequence length for test-time polygon generation |

Key Methods

| Method | Description |
| --- | --- |
| training_step(batch, batch_idx) -> Tensor | Calls compute(batch) then compute_loss_acc(batch, result); logs loss/train and metrics/train_acc |
| validation_step(batch, batch_idx) -> dict | Same as training plus runs model.test(...) for sequence generation; evaluates the PoLiS metric and IoU |
| compute(batch) -> Tensor | Runs self.model(image, x1, x2, x3) and reshapes output to (-1, grid_size*grid_size+3) |
| compute_loss_acc(batch, result) -> Tuple[Tensor, Tensor] | Computes cross-entropy loss and per-token accuracy |
| evaluate_batch(batch, result) -> Tuple | Converts predicted/GT token sequences back to polygon vertex lists; computes batch PoLiS and IoU |

The loss function is nn.CrossEntropyLoss() (fixed).