Dataset Classes

This page is the API reference for all dataset classes in pytorch_segmentation_models_trainer.

Most classes live in:

from pytorch_segmentation_models_trainer.dataset_loader.dataset import <ClassName>

RasterPatchDataset lives in its own module:

from pytorch_segmentation_models_trainer.dataset_loader.raster_patch_dataset import RasterPatchDataset

Class Hierarchy

torch.utils.data.Dataset
├── RasterPatchDataset          ← sliding-window, folder-based (no CSV)
└── AbstractDataset             ← CSV / DataFrame-based
    ├── ImageDataset
    │   └── TiledInferenceImageDataset
    ├── SegmentationDataset
    │   ├── SegmentationDatasetFromFolder
    │   └── FrameFieldSegmentationDataset
    ├── RandomCropSegmentationDataset
    ├── ObjectDetectionDataset
    │   └── InstanceSegmentationDataset
    └── PolygonRNNDataset

RasterPatchDataset

from pytorch_segmentation_models_trainer.dataset_loader.raster_patch_dataset import RasterPatchDataset

Systematic sliding-window dataset for semantic segmentation directly over full-size raster files. Discovers image/mask pairs by recursive folder scan, computes all valid patch_size × patch_size windows for each image, and maps a global 1-D index to the correct (image, row, column) in O(log N) via binary search. Only the requested window pixels are read from disk at __getitem__ time — full images are never loaded into memory.

See the Sliding-Window Patch Dataset user guide for usage examples and the full YAML reference.

Constructor Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| image_dir | str \| Path | required | Root directory of input images. Scanned recursively. |
| mask_dir | str \| Path | required | Root directory of segmentation masks. Matching is by relative path. |
| extension | str | ".tif" | Image file extension. Leading dot is optional and normalized internally. |
| patch_size | int | 256 | Side length of each square patch in pixels. |
| stride | int | 128 | Step between patch origins. stride < patch_size produces overlapping patches. |
| mask_extension | str \| None | None | Mask file extension. When None, uses the same value as extension. |
| augmentation_list | list \| A.Compose \| None | None | Albumentations augmentation pipeline. Image is passed as (H, W, C), mask as (H, W). |
| data_loader | Any \| None | None | DataLoader sub-configuration. Stored as ds.data_loader; consumed by the Lightning Model. |
| selected_bands | List[int] \| None | None | 1-based rasterio band indices to load. None loads all bands. |
| image_dtype | str | "uint8" | Array dtype after reading. Accepted: "uint8", "uint16", "float32", "native". |

Raises

| Exception | When |
| --- | --- |
| ValueError | image_dtype is not one of the accepted values. |
| ValueError | selected_bands contains a non-positive integer. |
| ValueError | No valid image/mask pairs are found after scanning both directories. |
| UserWarning | An image is smaller than patch_size (the image is skipped). |
| UserWarning | A mask file is missing for a discovered image (the image is skipped). |

Key Attributes

| Attribute | Type | Description |
| --- | --- | --- |
| image_info | List[Dict] | Per-image metadata: img_path, mask_path, height, width, patches_per_row, patches_per_col. |
| patch_size | int | Patch side length. |
| stride | int | Step between patch origins. |
| image_dtype | str | Configured dtype. |
| selected_bands | List[int] \| None | Band selection (1-based). |
| data_loader | Any | Stored DataLoader config. |

__len__

Returns the total number of patches across all images, not the number of images.

len(ds) = Σ_i  patches_per_row_i × patches_per_col_i
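The patch-count arithmetic and the global-index mapping described above can be sketched in plain Python. This is a hedged illustration, not the library's code; patch_grid and locate are hypothetical names:

```python
from bisect import bisect_right

def patch_grid(height, width, patch_size, stride):
    # Count of valid patch origins along each axis
    # (only windows fully inside the image are kept).
    return ((height - patch_size) // stride + 1,
            (width - patch_size) // stride + 1)

def locate(idx, image_shapes, patch_size, stride):
    # Map a global 1-D index to (image index, row offset, col offset)
    # via binary search over cumulative patch counts -- O(log N) per lookup.
    cum, total = [], 0
    for h, w in image_shapes:
        rows, cols = patch_grid(h, w, patch_size, stride)
        total += rows * cols
        cum.append(total)
    img = bisect_right(cum, idx)
    local = idx - (cum[img - 1] if img else 0)
    _, cols = patch_grid(*image_shapes[img], patch_size, stride)
    row, col = divmod(local, cols)
    return img, row * stride, col * stride
```

For example, with a 512×512 image, patch_size=256, and stride=128, patch_grid yields a 3×3 grid, i.e. 9 patches for that image.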

__getitem__ Returns

Dict[str, torch.Tensor] with keys:

| Key | dtype | shape | Notes |
| --- | --- | --- | --- |
| "image" | torch.float32 | (C, patch_size, patch_size) | Normalised ÷255 (uint8) or ÷65535 (uint16) when no transform; unchanged for float32/native. |
| "mask" | torch.int64 | (patch_size, patch_size) | Raw pixel values from the mask file. |

When augmentation_list is set the return type matches whatever the Albumentations pipeline produces (typically the same keys with transformed tensors after ToTensorV2).
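The default dtype-dependent normalization in the table above can be illustrated with a small sketch (assumed behavior, not the library's code):

```python
import numpy as np

def normalize_patch(arr: np.ndarray) -> np.ndarray:
    # uint8 -> divide by 255, uint16 -> divide by 65535;
    # float32 / "native" arrays pass through unchanged.
    if arr.dtype == np.uint8:
        return arr.astype(np.float32) / 255.0
    if arr.dtype == np.uint16:
        return arr.astype(np.float32) / 65535.0
    return arr
```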


AbstractDataset

from pytorch_segmentation_models_trainer.dataset_loader.dataset import AbstractDataset

Base class for all datasets. Handles CSV loading, root directory resolution, and augmentation setup. Subclasses must implement __getitem__.

Constructor Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| input_csv_path | Path | None | Path to the CSV file. Required when df is not provided. |
| df | pd.DataFrame | None | Pre-loaded DataFrame. When provided, input_csv_path is ignored for loading but still stored. |
| root_dir | Any | None | Root directory prepended to all relative file paths read from the CSV. |
| augmentation_list | Any | None | Albumentations augmentation list or A.Compose object. None disables augmentation. |
| data_loader | Any | None | DataLoader sub-configuration (see DataLoader Config Keys below). |
| image_key | str | "image" | CSV column name for image paths. |
| mask_key | str | "mask" | CSV column name for mask paths. |
| n_first_rows_to_read | int | None | If set, only the first N rows are read from the CSV via pd.read_csv(..., nrows=N). |

Key Methods

| Method | Description |
| --- | --- |
| __len__() -> int | Returns the number of rows in the dataset DataFrame. |
| get_path(idx, key=None, add_root_dir=True) -> str | Returns the file path for item idx under the given CSV column key. |
| load_image(idx, key=None, is_mask=False, force_rgb=False, is_binary_mask=True) -> np.ndarray | Loads and returns a numpy array for the given item. |
| update_df(new_df) -> None | Replaces the internal DataFrame and updates self.len. |
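The path-resolution behavior of get_path can be sketched as follows. This is an illustrative stand-in (resolve_path is a hypothetical helper, not the library's method): the CSV stores relative paths and root_dir is prepended when add_root_dir is True.

```python
import os

def resolve_path(root_dir, relative_path, add_root_dir=True):
    # Hypothetical sketch: prepend root_dir unless disabled or unset.
    if not add_root_dir or root_dir is None:
        return relative_path
    return os.path.join(root_dir, relative_path)
```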

ImageDataset

from pytorch_segmentation_models_trainer.dataset_loader.dataset import ImageDataset

Image-only dataset. Returns a dict with the loaded image and its file path. Suitable for inference pipelines that do not require ground-truth masks.

Constructor Parameters

Inherits all parameters from AbstractDataset. The mask_key parameter is accepted by the parent but not used.

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| input_csv_path | Path | None | Path to the CSV file. |
| df | pd.DataFrame | None | Pre-loaded DataFrame. |
| root_dir | Any | None | Root directory for path resolution. |
| augmentation_list | Any | None | Albumentations augmentation pipeline. |
| data_loader | Any | None | DataLoader sub-configuration. |
| image_key | str | "image" | CSV column name for image paths. |
| n_first_rows_to_read | int | None | Limit on rows to read. |

__getitem__ Returns

Dict[str, Any] with keys:

| Key | Type | Description |
| --- | --- | --- |
| "image" | np.ndarray or torch.Tensor | The loaded image, optionally transformed. |
| "path" | str | Absolute path to the source image file. |

SegmentationDataset

from pytorch_segmentation_models_trainer.dataset_loader.dataset import SegmentationDataset

Semantic segmentation dataset. Loads image-mask pairs and supports both PIL (RGB) and rasterio (multi-band/multispectral) image loading.

Constructor Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| input_csv_path | Path | None | Path to the CSV file. Mutually exclusive with df; at least one must be provided. |
| df | pd.DataFrame | None | Pre-built DataFrame with image and mask columns. Allows creating the dataset without a CSV file on disk (e.g. via SegmentationDatasetFromFolder). |
| root_dir | Any | None | Root directory for path resolution. |
| augmentation_list | Any | None | Albumentations augmentation pipeline. |
| data_loader | Any | None | DataLoader sub-configuration. |
| image_key | str | "image" | CSV column name for image paths. |
| mask_key | str | "mask" | CSV column name for mask paths. |
| n_first_rows_to_read | int | None | Limit on rows to read. |
| n_classes | int | 2 | Number of segmentation classes. When 2, masks are binarized (> 0). |
| selected_bands | Optional[List[int]] | None | 1-based list of band indices to load via rasterio. E.g. [1, 2, 3] loads the first three bands. When None, all bands are loaded. |
| use_rasterio | bool | False | When True, forces rasterio for image loading (recommended for multispectral imagery). |
| image_dtype | str | "uint8" | Data type applied to the image array after rasterio loading. Accepted values: "uint8", "uint16", "float32", "native". Only takes effect when rasterio is used (use_rasterio=True or selected_bands is set). "native" skips the cast entirely. Raises ValueError for unrecognised values. |
| reset_augmentation_function | bool | False | When True, deep-copies the augmentation pipeline before each call to prevent memory leaks from Albumentations caching. |

__getitem__ Returns

Dict[str, Any] with keys:

| Key | Type | Description |
| --- | --- | --- |
| "image" | torch.Tensor (C, H, W) | Float tensor. When no transform is set, automatically normalized: uint8 ÷ 255, uint16 ÷ 65535; float32/native are left unchanged. |
| "mask" | torch.Tensor (H, W) | Long tensor of class labels. Binary (0/1) when n_classes == 2. |

SegmentationDatasetFromFolder

from pytorch_segmentation_models_trainer.dataset_loader.dataset import SegmentationDatasetFromFolder

Extends SegmentationDataset to discover image/mask pairs recursively from two root folders, without requiring a CSV file on disk. Matching is performed by relative subfolder path and file stem (name without extension): only files present in both folders, inside the same subfolder and with the same filename, are included in the dataset.

Expected Folder Structure

images_root/
    area_a/
        tile_001.tif
        tile_002.tif
    area_b/
        tile_003.tif

masks_root/
    area_a/
        tile_001.tif   ← paired with images_root/area_a/tile_001.tif
        tile_002.tif
    area_b/
        tile_003.tif

Only files with a matching (subfolder, stem) key in both trees are included. Files present in only one folder are silently ignored.
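The (subfolder, stem) matching rule can be sketched with pathlib. This is an assumed illustration of the behavior described above, not the library's implementation (match_pairs is a hypothetical name):

```python
from pathlib import Path

def match_pairs(images_root, masks_root, image_ext=".tif", mask_ext=None):
    # Index each tree by (relative subfolder, file stem); keep only keys
    # present in both trees. Unmatched files are silently ignored.
    mask_ext = mask_ext or image_ext
    images_root, masks_root = Path(images_root), Path(masks_root)

    def index(root, ext):
        return {(p.parent.relative_to(root), p.stem): p
                for p in root.rglob(f"*{ext}")}

    imgs = index(images_root, image_ext)
    masks = index(masks_root, mask_ext)
    return [(imgs[k], masks[k]) for k in sorted(imgs.keys() & masks.keys())]
```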

Constructor Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| image_folder | Union[str, Path] | required | Root folder containing the images. |
| mask_folder | Union[str, Path] | required | Root folder containing the masks. |
| image_extension | str | ".tif" | File extension for images. The leading dot is optional and normalized internally (e.g. "tif" and ".tif" are equivalent). |
| mask_extension | Optional[str] | None | File extension for masks. When None, uses the same value as image_extension. |
| augmentation_list | Any | None | Albumentations augmentation pipeline. |
| data_loader | Any | None | DataLoader sub-configuration. |
| n_classes | int | 2 | Number of segmentation classes. When 2, masks are binarized (> 0). |
| selected_bands | Optional[List[int]] | None | 1-based band indices to load via rasterio. None loads all bands. |
| use_rasterio | bool | False | Forces rasterio for image loading. |
| image_dtype | str | "uint8" | Data type for rasterio-loaded images. See SegmentationDataset. |
| reset_augmentation_function | bool | False | Deep-copy the transform to prevent Albumentations memory leaks. |

Raises

| Exception | When |
| --- | --- |
| ValueError | No matching image/mask pairs are found (wrong extension, mismatched subfolder structure, etc.). |

Instance Attributes

After construction the following extra attributes are available:

| Attribute | Type | Description |
| --- | --- | --- |
| image_folder | Path | Resolved root folder for images. |
| mask_folder | Path | Resolved root folder for masks. |
| image_extension | str | Normalized image extension (with leading dot). |
| mask_extension | str | Normalized mask extension (with leading dot). |

__getitem__ Returns

Same as SegmentationDataset:

| Key | Type | Description |
| --- | --- | --- |
| "image" | torch.Tensor (C, H, W), float32 | Normalized image tensor. |
| "mask" | torch.Tensor (H, W), int64 | Class-index mask. |

Static Helper Methods

| Method | Description |
| --- | --- |
| _normalize_extension(ext) | Ensures the extension starts with a dot. |
| _build_dataframe(image_folder, mask_folder, image_extension, mask_extension) | Scans both folders recursively, matches pairs, and returns a pd.DataFrame with image and mask columns. Raises ValueError if no pairs are found. |

FrameFieldSegmentationDataset

from pytorch_segmentation_models_trainer.dataset_loader.dataset import FrameFieldSegmentationDataset

Extends SegmentationDataset for frame-field polygon learning. Loads multiple auxiliary masks (boundary, vertex, crossfield angle, distance transform, size map) in addition to the primary segmentation mask.

Constructor Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| input_csv_path | Path | required | Path to the CSV file. |
| root_dir | Any | None | Root directory for path resolution. |
| augmentation_list | Any | None | Albumentations augmentation pipeline. |
| data_loader | Any | None | DataLoader sub-configuration. |
| image_key | str | "image" | CSV column name for image paths. |
| mask_key | str | "polygon_mask" | CSV column name for the primary polygon mask. |
| multi_band_mask | bool | False | When True, all three masks (polygon, boundary, vertex) are packed into a single multi-band image file. |
| boundary_mask_key | str | "boundary_mask" | CSV column for the boundary mask. |
| return_boundary_mask | bool | True | Whether to load and return the boundary mask. |
| vertex_mask_key | str | "vertex_mask" | CSV column for the vertex mask. |
| return_vertex_mask | bool | True | Whether to load and return the vertex mask. |
| n_first_rows_to_read | int | None | Limit on rows to read. |
| return_crossfield_mask | bool | True | Whether to load the crossfield angle mask. |
| crossfield_mask_key | str | "crossfield_mask" | CSV column for the crossfield angle image. |
| return_distance_mask | bool | True | Whether to load the distance transform mask. |
| distance_mask_key | str | "distance_mask" | CSV column for the distance map. |
| return_size_mask | bool | True | Whether to load the size map. |
| size_mask_key | str | "size_mask" | CSV column for the size map. |
| image_width | int | 224 | Target image width used in the fallback transform when an augmentation crop is invalid. |
| image_height | int | 224 | Target image height used in the fallback transform. |
| gpu_augmentation_list | Any | None | Reserved for GPU-side augmentation (not currently applied inside __getitem__). |

__getitem__ Returns

Dict[str, Any] with keys:

| Key | Present when | Type | Description |
| --- | --- | --- | --- |
| "idx" | always | int | Dataset index of this item. |
| "path" | always | str | File path of the source image. |
| "image" | always | torch.Tensor | Input image tensor. |
| "gt_polygons_image" | always | torch.Tensor (C, H, W) | Stacked polygon/boundary/vertex masks. |
| "class_freq" | always | torch.Tensor | Per-channel class frequency used for loss weighting. |
| "gt_crossfield_angle" | return_crossfield_mask=True | torch.Tensor (1, H, W) | Crossfield angle map in radians. |
| "distances" | return_distance_mask=True | torch.Tensor (1, H, W) | Normalized distance transform. |
| "sizes" | return_size_mask=True | torch.Tensor (1, H, W) | Object size map. |

ObjectDetectionDataset

from pytorch_segmentation_models_trainer.dataset_loader.dataset import ObjectDetectionDataset

Object detection dataset. Loads images with bounding boxes and class labels from JSON annotation files.

Constructor Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| input_csv_path | Path | required | Path to the CSV file. |
| root_dir | Any | None | Root directory for path resolution. |
| augmentation_list | Any | None | Albumentations augmentation pipeline with bbox support. |
| data_loader | Any | None | DataLoader sub-configuration. |
| image_key | str | "image" | CSV column name for image paths. |
| mask_key | str | "mask" | CSV column name (not used for detection but inherited). |
| bounding_box_key | str | "bounding_boxes" | CSV column pointing to JSON files with bounding box annotations. |
| n_first_rows_to_read | int | None | Limit on rows to read. |
| bbox_format | str | "xywh" | Input bounding box format ("xywh" or "xyxy"). |
| bbox_output_format | str | "xyxy" | Output bounding box format after conversion ("xywh" or "xyxy"). |
| bbox_params | Any | None | Albumentations BboxParams (or equivalent dict/config) passed to A.Compose. |

__getitem__ Returns

Tuple[torch.Tensor, Dict[str, torch.Tensor], int]:

| Position | Type | Description |
| --- | --- | --- |
| [0] | torch.Tensor | Image tensor (RGB, float32). |
| [1] | dict | Dict with keys "boxes" (float32 tensor of bounding boxes) and "labels" (int64 tensor of class indices). |
| [2] | int | Dataset index of the item. |
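The bbox_format / bbox_output_format conversion amounts to a coordinate change. A minimal sketch, assuming the common convention xywh = [x_min, y_min, width, height] (these helper names are illustrative, not the library's API):

```python
def xywh_to_xyxy(box):
    # [x_min, y_min, w, h] -> [x_min, y_min, x_max, y_max]
    x, y, w, h = box
    return [x, y, x + w, y + h]

def xyxy_to_xywh(box):
    # [x_min, y_min, x_max, y_max] -> [x_min, y_min, w, h]
    x1, y1, x2, y2 = box
    return [x1, y1, x2 - x1, y2 - y1]
```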

InstanceSegmentationDataset

from pytorch_segmentation_models_trainer.dataset_loader.dataset import InstanceSegmentationDataset

Extends ObjectDetectionDataset with optional per-instance mask and keypoint loading.

Constructor Parameters

Inherits all parameters from ObjectDetectionDataset, plus:

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| keypoint_key | str | "keypoints" | CSV column pointing to JSON files with keypoint or polygon annotations. |
| return_mask | bool | True | When True, loads and returns per-instance binary masks. |
| return_keypoints | bool | False | When True, loads and attaches keypoint data from the JSON file at keypoint_key. |

The mask_key default is overridden to "polygon_mask".

__getitem__ Returns

Same tuple structure as ObjectDetectionDataset. When return_mask=True the target dict additionally contains:

| Key | Type | Description |
| --- | --- | --- |
| "masks" | torch.Tensor (uint8) | Per-instance binary mask tensor. |

PolygonRNNDataset

from pytorch_segmentation_models_trainer.dataset_loader.dataset import PolygonRNNDataset

Dataset for Polygon-RNN training and validation. Reads pre-cropped images and normalized polygon JSON files produced by the Dataset Conversion pipeline.

Constructor Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| input_csv_path | Path | required | Path to the CSV file (typically the output of PolygonRNNDatasetConversionStrategy). |
| sequence_length | int | 60 | Maximum number of polygon vertices in the output sequence. Shorter polygons are padded. |
| root_dir | Any | None | Root directory for path resolution. |
| augmentation_list | Any | None | Albumentations augmentation pipeline. |
| data_loader | Any | None | DataLoader sub-configuration. |
| image_key | str | "image" | CSV column for cropped image paths. |
| mask_key | str | "mask" | CSV column for polygon JSON paths. |
| image_width_key | str | None | CSV column for image width (not used internally but stored for downstream use). |
| image_height_key | str | None | CSV column for image height. |
| scale_h_key | str | "scale_h" | CSV column for vertical scale factor. |
| scale_w_key | str | "scale_w" | CSV column for horizontal scale factor. |
| min_col_key | str | "min_col" | CSV column for the crop left boundary (column). |
| min_row_key | str | "min_row" | CSV column for the crop top boundary (row). |
| original_image_path_key | str | "original_image_path" | CSV column for the source full-image path. |
| original_polygon_key | str | "original_polygon_wkt" | CSV column for the WKT representation of the original polygon. |
| n_first_rows_to_read | int | None | Limit on rows to read. |
| dataset_type | str | "train" | Either "train" or "val". Validation mode returns additional metadata fields per item. |
| grid_size | int | 28 | Grid resolution used for polygon label encoding. |

__getitem__ Returns

Dict[str, Any] with keys:

| Key | Present when | Type | Description |
| --- | --- | --- | --- |
| "x1" | always | torch.Tensor (float32) | Polygon label array at step 2 (first input to RNN). |
| "x2" | always | torch.Tensor (float32) | Polygon label array from step 0 to T-2 (main RNN input sequence). |
| "x3" | always | torch.Tensor (float32) | Polygon label array from step 1 to T-1 (shifted sequence). |
| "ta" | always | torch.Tensor (int64) | Target index array from step 2 to T (supervised labels). |
| "image" | always | torch.Tensor (float32) | Cropped image tensor. |
| "polygon_wkt" | dataset_type="val" | str | WKT of the original polygon for metric computation. |
| "scale_h" | dataset_type="val" | float | Vertical scale factor for back-projection. |
| "scale_w" | dataset_type="val" | float | Horizontal scale factor for back-projection. |
| "min_col" | dataset_type="val" | float | Crop left boundary for back-projection. |
| "min_row" | dataset_type="val" | float | Crop top boundary for back-projection. |
| "original_image_path" | dataset_type="val" | str | Source full-image path for visualization. |

TiledInferenceImageDataset

from pytorch_segmentation_models_trainer.dataset_loader.dataset import TiledInferenceImageDataset

Extends ImageDataset for large-image inference. Slices each image into overlapping tiles so that a fixed-input-size model can process arbitrarily large images. Uses pytorch_toolbelt.inference.tiles.ImageSlicer internally.

Constructor Parameters

Inherits all parameters from ImageDataset, plus:

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| normalize_output | bool | True | When True, applies A.Normalize() (ImageNet mean/std) to each image before tiling. |
| pad_if_needed | bool | False | When True, pads the image to the nearest multiple of model_input_shape before tiling. Requires model_input_shape to be set. |
| model_input_shape | Any | None | Tuple (H, W) of the model's expected input tile size. Required when pad_if_needed=True. |
| step_shape | Any | (224, 224) | Tile step (stride) for the ImageSlicer. Smaller values increase overlap. |

The augmentation_list parameter is accepted but ignored; normalization is controlled by normalize_output instead.

__getitem__ Returns

Dict[str, Any] with keys:

| Key | Type | Description |
| --- | --- | --- |
| "path" | str | File path of the source image. |
| "tiles" | torch.Tensor (N_tiles, C, H, W) | Stack of all tiles from this image. |
| "tile_image_idx" | torch.Tensor (N_tiles,) | Image index repeated for each tile (for reassembly). |
| "tiler_object" | ImageSlicer | Slicer object needed to merge tile predictions back. |
| "original_shape" | tuple | (width, height) of the original image. |

A collate_fn static method is provided for use with DataLoader to correctly batch variable-tile-count images.
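The idea behind that collate step can be sketched with plain lists. This is an illustrative stand-in (collate_tiles is a hypothetical name, not the library's collate_fn): tile stacks of different lengths are concatenated, and a per-tile image index is kept so predictions can be routed back to their source image.

```python
def collate_tiles(batch):
    # Flatten variable-length tile stacks and record, for each tile,
    # which batch item it came from.
    tiles = [t for item in batch for t in item["tiles"]]
    idx = [i for i, item in enumerate(batch) for _ in item["tiles"]]
    return {"tiles": tiles,
            "tile_image_idx": idx,
            "paths": [item["path"] for item in batch]}
```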


DataLoader Config Keys

All dataset classes accept a data_loader sub-configuration. When used with Hydra this is typically a nested config node. The following keys are recognised:

| Key | Type | Default | Description |
| --- | --- | --- | --- |
| shuffle | bool | True | Whether to shuffle the dataset each epoch. Typically True for training, False for validation. |
| num_workers | int | 0 | Number of worker processes for data loading. |
| pin_memory | bool | False | Whether to pin memory for faster GPU transfer. |
| batch_size | int | 8 | Number of samples per batch. |
| drop_last | bool | False | Whether to drop the last incomplete batch. |

Example:

data_loader:
  shuffle: true
  num_workers: 4
  pin_memory: true
  batch_size: 16
  drop_last: false
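A sketch of how such a config node could be merged with the defaults above into keyword arguments for torch.utils.data.DataLoader. The dataloader_kwargs helper is hypothetical (the library wires this up internally via Hydra):

```python
# Defaults from the table above.
_DEFAULTS = {"shuffle": True, "num_workers": 0, "pin_memory": False,
             "batch_size": 8, "drop_last": False}

def dataloader_kwargs(cfg: dict) -> dict:
    # Merge user-supplied keys over the defaults; reject unknown keys
    # so typos in the YAML fail loudly instead of being silently ignored.
    unknown = set(cfg) - set(_DEFAULTS)
    if unknown:
        raise ValueError(f"unrecognised data_loader keys: {sorted(unknown)}")
    return {**_DEFAULTS, **cfg}
```

The resulting dict could then be passed as `DataLoader(dataset, **kwargs)`.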

CSV Format Conventions

Each dataset type expects specific columns in the input CSV:

SegmentationDataset / FrameFieldSegmentationDataset

| Column | Required | Description |
| --- | --- | --- |
| image | yes | Path to the input image (RGB or multi-band). |
| mask / polygon_mask | yes | Path to the ground-truth mask image. |
| boundary_mask | FrameField only | Path to the boundary mask image. |
| vertex_mask | FrameField only | Path to the vertex mask image. |
| crossfield_mask | FrameField only | Path to the crossfield angle image. |
| distance_mask | FrameField only | Path to the distance transform image. |
| size_mask | FrameField only | Path to the size map image. |
| class_freq | optional | Pre-computed class frequency string to skip per-sample computation. |

ObjectDetectionDataset / InstanceSegmentationDataset

| Column | Required | Description |
| --- | --- | --- |
| image | yes | Path to the input image. |
| bounding_boxes | yes | Path to a JSON file with a list of {"bbox": [...], "class": N} entries. |
| polygon_mask | InstanceSeg only | Path to the binary mask image. |
| keypoints | InstanceSeg only | Path to a JSON file with a "keypoints" key. |

PolygonRNNDataset

| Column | Required | Description |
| --- | --- | --- |
| image | yes | Path to the cropped object image. |
| mask | yes | Path to the normalized polygon JSON file. |
| scale_h | yes | Vertical scale factor for coordinate back-projection. |
| scale_w | yes | Horizontal scale factor. |
| min_col | yes | Left boundary of the original crop. |
| min_row | yes | Top boundary of the original crop. |
| original_image_path | yes (val) | Path to the source full-size image. |
| original_polygon_wkt | yes (val) | WKT of the original polygon. |