
Dataset Classes

This page is the API reference for all dataset classes in pytorch_segmentation_models_trainer.

Most classes live in:

from pytorch_segmentation_models_trainer.dataset_loader.dataset import <ClassName>

RasterPatchDataset lives in its own module:

from pytorch_segmentation_models_trainer.dataset_loader.raster_patch_dataset import RasterPatchDataset

MBTiles/mask-aligned datasets live in:

from pytorch_segmentation_models_trainer.dataset_loader.mbtiles_mask_dataset import (
    MBTilesMaskWindowedDataset,
)

Image-only datasets live in:

from pytorch_segmentation_models_trainer.dataset_loader.image_dataset import (
    ImageDataset,
    CSVWindowedImageDataset,
    TiledInferenceImageDataset,
    AutoencoderDataset,
    AutoencoderRandomCropDataset,
    WindowedImageDataset,
    WindowedImageAutoencoderDataset,
)

Class Hierarchy

torch.utils.data.Dataset
├── RasterPatchDataset                  ← sliding-window, folder-based (no CSV)
├── MBTilesMaskWindowedDataset          ← MBTiles imagery aligned to mask windows
└── AbstractDataset                     ← CSV / DataFrame-based
    ├── ImageDataset
    │   ├── AutoencoderDataset
    │   ├── AutoencoderRandomCropDataset
    │   ├── WindowedImageDataset
    │   │   └── WindowedImageAutoencoderDataset
    │   ├── CSVWindowedImageDataset
    │   └── TiledInferenceImageDataset
    ├── SegmentationDataset
    │   ├── SegmentationDatasetFromFolder
    │   ├── CSVWindowedSegmentationDataset
    │   └── FrameFieldSegmentationDataset
    ├── RandomCropSegmentationDataset
    ├── ObjectDetectionDataset
    │   └── InstanceSegmentationDataset
    └── PolygonRNNDataset

AutoencoderDataset

from pytorch_segmentation_models_trainer.dataset_loader.image_dataset import AutoencoderDataset

Whole-image dataset for reconstruction training. It returns the clean image as target; corruption_augmentation_list is applied only to image, while augmentation_list is applied synchronously to image and target.
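
The split between the two pipelines can be illustrated with a minimal sketch. It is not the library's implementation: make_autoencoder_sample, hflip and zero_top_left are hypothetical names, and nested lists stand in for image arrays so the example runs with the standard library only.

```python
import copy
from typing import Callable, Dict, List, Optional

Image = List[List[int]]  # nested lists stand in for an (H, W) image array


def make_autoencoder_sample(
    image: Image,
    shared_aug: Optional[Callable[[Image], Image]] = None,
    corruption: Optional[Callable[[Image], Image]] = None,
) -> Dict[str, Image]:
    """Hypothetical helper: build an (input, clean target) pair."""
    if shared_aug is not None:
        # Shared augmentation changes what both image and target look like.
        image = shared_aug(image)
    target = copy.deepcopy(image)  # the target stays clean
    noisy = copy.deepcopy(image)
    if corruption is not None:
        # Corruption touches the model input only.
        noisy = corruption(noisy)
    return {"image": noisy, "target": target}


def hflip(img: Image) -> Image:
    """Horizontal flip, playing the role of augmentation_list."""
    return [row[::-1] for row in img]


def zero_top_left(img: Image) -> Image:
    """Blank one pixel, playing the role of corruption_augmentation_list."""
    img[0][0] = 0
    return img


clean = [[1, 2], [3, 4]]
sample = make_autoencoder_sample(clean, hflip, zero_top_left)
```

Here the flip shows up in both "image" and "target", while the blanked pixel appears only in "image".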

__getitem__ Returns

| Key | Description |
| --- | --- |
| "image" | Input image, optionally corrupted. |
| "target" | Clean reconstruction target. |
| "path" | Source image path. |

AutoencoderRandomCropDataset

from pytorch_segmentation_models_trainer.dataset_loader.image_dataset import AutoencoderRandomCropDataset

Random-crop dataset for unlabeled image folders or CSV-backed full-size rasters. It scans image_dir recursively, filters by extension, optionally applies a deterministic train/validation split, and reads only the requested raster window at __getitem__ time.

Constructor Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| image_dir | str \| Path \| None | None | Root folder scanned recursively when no CSV/DataFrame is provided. |
| input_csv_path | str \| Path \| None | None | Optional CSV with an image column. |
| image_extensions | List[str] \| None | common image extensions | Extensions used in folder mode. Leading dot is optional. |
| split | str | "all" | One of "all", "train", "val". |
| val_fraction | float | 0.2 | Fraction assigned to validation in folder mode. |
| split_seed | int | 42 | Seed for deterministic splitting. |
| crop_size | List[int] | [256, 256] | Crop size [height, width]. |
| samples_per_epoch | int | 10000 | Number of random crops per epoch; when <= 0, the count is estimated to give roughly 3x area coverage. |
| selected_bands | List[int] \| None | None | 1-based rasterio bands to read. |
| image_dtype | str | "uint8" | "uint8", "uint16", "float32" or "native". |
| corruption_augmentation_list | list | None | Albumentations pipeline applied only to image. |
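
A deterministic folder-mode split driven by split, val_fraction and split_seed could look like the sketch below. The exact shuffling and rounding rules used by the library are not documented here, so split_paths is a hypothetical helper that only illustrates the idea: sort the paths, shuffle them with a private seeded RNG, and slice.

```python
import random
from typing import List


def split_paths(
    paths: List[str],
    split: str = "all",
    val_fraction: float = 0.2,
    split_seed: int = 42,
) -> List[str]:
    """Hypothetical helper: reproducible train/val split of a path list."""
    if split not in ("all", "train", "val"):
        raise ValueError(f"unknown split: {split!r}")
    ordered = sorted(paths)  # makes the result independent of scan order
    if split == "all":
        return ordered
    shuffled = ordered[:]
    # A private RNG keeps the split stable and leaves global random state alone.
    random.Random(split_seed).shuffle(shuffled)
    n_val = int(round(len(shuffled) * val_fraction))
    val, train = shuffled[:n_val], shuffled[n_val:]
    return sorted(val if split == "val" else train)


paths = [f"tiles/img_{i:02d}.tif" for i in range(10)]
train = split_paths(paths, "train")
val = split_paths(paths, "val")
```

Because both calls use the same seed and the same sorted input, the train and validation subsets never overlap and together cover every path.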

__getitem__ Returns

| Key | Description |
| --- | --- |
| "image" | Random crop input. |
| "target" | Clean random crop target. |
| "path" | Source image path used for the crop. |

WindowedImageDataset

from pytorch_segmentation_models_trainer.dataset_loader.image_dataset import WindowedImageDataset

Deterministic sliding-window dataset for image-only tasks over full-size rasters. It computes a global patch index across all images and reads only the requested window from disk at __getitem__ time.

Constructor Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| image_dir | str \| Path \| None | None | Folder scanned recursively when no CSV/DataFrame is provided. |
| crop_size | List[int] | [256, 256] | Patch size as [height, width]. |
| stride | int \| List[int] \| None | crop_size | Step between patch origins. |
| selected_bands | List[int] \| None | None | 1-based rasterio bands to read. |
| image_dtype | str | "uint8" | "uint8", "uint16", "float32" or "native". |
| file_cache_maxsize | int | 0 | Max open rasterio handles; 0 auto-sizes from indexed images. |
| verify_windows | bool | False | Read every candidate window during init and index only readable windows. |
| window_index_cache | str \| Path \| None | None | JSON cache for the verified window index. Rebuilt when paths, file metadata, crop, stride, bands, dtype, or image key change. |
| serialize_rasterio_reads | bool | False | Serialize rasterio reads per source file across DataLoader workers. |
| rasterio_lock_dir | str \| Path \| None | None | Directory for lock files when serialize_rasterio_reads=True; defaults to /tmp/psmt_rasterio_locks. |
| reopen_rasterio_on_read | bool | False | Open and close the raster inside each locked read instead of using the per-worker rasterio handle cache. |
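
One way to decide when a window_index_cache file must be rebuilt is to store a fingerprint of everything the index depends on. The sketch below is an assumption about how such a check could work, not the library's actual cache format: index_fingerprint and cache_is_valid are hypothetical helpers, and "file metadata" is taken to mean (size, modification time) per file.

```python
import hashlib
import json
from typing import Dict, List, Optional, Tuple


def index_fingerprint(
    file_meta: Dict[str, Tuple[int, int]],  # path -> (size_bytes, mtime), assumed
    crop_size: List[int],
    stride: List[int],
    selected_bands: Optional[List[int]],
    image_dtype: str,
    image_key: str,
) -> str:
    """Hypothetical helper: hash every input that affects the window index."""
    payload = {
        "files": sorted(file_meta.items()),
        "crop_size": list(crop_size),
        "stride": list(stride),
        "selected_bands": selected_bands,
        "image_dtype": image_dtype,
        "image_key": image_key,
    }
    blob = json.dumps(payload, sort_keys=True).encode("utf-8")
    return hashlib.sha256(blob).hexdigest()


def cache_is_valid(cached: dict, fingerprint: str) -> bool:
    """A cached index is reusable only if its stored fingerprint matches."""
    return cached.get("fingerprint") == fingerprint


meta = {"a.tif": (1024, 1700000000), "b.tif": (2048, 1700000500)}
fp_old = index_fingerprint(meta, [256, 256], [128, 128], None, "uint8", "image")
fp_same = index_fingerprint(meta, [256, 256], [128, 128], None, "uint8", "image")
fp_new = index_fingerprint(meta, [256, 256], [256, 256], None, "uint8", "image")
```

Changing any single input, such as the stride above, yields a different fingerprint and therefore forces a rebuild.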

__getitem__ Returns

| Key | dtype / shape | Description |
| --- | --- | --- |
| "image" | torch.float32, (C, H, W) without transform | Window image patch. |
| "path" | str | Source raster path. |

WindowedImageAutoencoderDataset

from pytorch_segmentation_models_trainer.dataset_loader.image_dataset import WindowedImageAutoencoderDataset

Autoencoder variant of WindowedImageDataset. It returns the clean crop as target and optionally applies corruption_augmentation_list only to image. The verify_windows and window_index_cache parameters are inherited from WindowedImageDataset.

__getitem__ Returns

| Key | dtype / shape | Description |
| --- | --- | --- |
| "image" | torch.float32, (C, H, W) without transform | Input crop, optionally corrupted. |
| "target" | torch.float32, (C, H, W) without transform | Clean reconstruction target. |
| "path" | str | Source raster path. |

IterableWindowedImageDataset

from pytorch_segmentation_models_trainer.dataset_loader.image_dataset import IterableWindowedImageDataset

Iterable variant of WindowedImageDataset. It yields the same samples but shards source images across DataLoader workers so each worker reads complete windows from its own subset of rasters.
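
Sharding whole images across workers can be sketched with a simple round-robin slice. This is an illustration rather than the library's code: shard_for_worker is a hypothetical helper. Inside a real torch.utils.data.IterableDataset, the worker id and worker count would come from torch.utils.data.get_worker_info(), which returns None in the main process.

```python
from typing import List, Sequence


def shard_for_worker(
    image_paths: Sequence[str], worker_id: int, num_workers: int
) -> List[str]:
    """Hypothetical helper: give each worker a disjoint subset of rasters."""
    if num_workers < 1 or not 0 <= worker_id < num_workers:
        raise ValueError("worker_id must satisfy 0 <= worker_id < num_workers")
    # Round-robin assignment: worker k takes images k, k + n, k + 2n, ...
    return list(image_paths[worker_id::num_workers])


paths = [f"scene_{i}.tif" for i in range(7)]
shards = [shard_for_worker(paths, w, 3) for w in range(3)]
```

Since every raster belongs to exactly one shard, no two workers open the same file, which is the property the iterable variants rely on.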

IterableWindowedImageAutoencoderDataset

from pytorch_segmentation_models_trainer.dataset_loader.image_dataset import IterableWindowedImageAutoencoderDataset

Iterable autoencoder variant. Use it for validation/testing with num_workers > 0 when concurrent reads from the same large GeoTIFF cause GDAL decoder errors. Configure the DataLoader with shuffle: false.


RasterPatchDataset

from pytorch_segmentation_models_trainer.dataset_loader.raster_patch_dataset import RasterPatchDataset

Systematic sliding-window dataset for semantic segmentation directly over full-size raster files. Discovers image/mask pairs by recursive folder scan, computes all valid patch_size × patch_size windows for each image, and maps a global 1-D index to the correct (image, row, column) in O(log N) via binary search. Only the requested window pixels are read from disk at __getitem__ time — full images are never loaded into memory.

See the Sliding-Window Patch Dataset user guide for usage examples and the full YAML reference.

Constructor Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| image_dir | str \| Path | required | Root directory of input images. Scanned recursively. |
| mask_dir | str \| Path | required | Root directory of segmentation masks. Matching is by relative path. |
| extension | str | ".tif" | Image file extension. Leading dot is optional and normalized internally. |
| patch_size | int | 256 | Side length of each square patch in pixels. |
| stride | int | 128 | Step between patch origins. stride < patch_size produces overlapping patches. |
| mask_extension | str \| None | None | Mask file extension. When None, uses the same value as extension. |
| augmentation_list | list \| A.Compose \| None | None | Albumentations augmentation pipeline. Image is passed as (H, W, C), mask as (H, W). |
| data_loader | Any \| None | None | DataLoader sub-configuration. Stored as ds.data_loader; consumed by the Lightning Model. |
| selected_bands | List[int] \| None | None | 1-based rasterio band indices to load. None loads all bands. |
| image_dtype | str | "uint8" | Array dtype after reading. Accepted: "uint8", "uint16", "float32", "native". |

Raises

| Exception | When |
| --- | --- |
| ValueError | image_dtype is not one of the accepted values. |
| ValueError | selected_bands contains a non-positive integer. |
| ValueError | No valid image/mask pairs are found after scanning both directories. |
| UserWarning | An image is smaller than patch_size. The warning is emitted, not raised, and the image is skipped. |
| UserWarning | A mask file is missing for a discovered image. The warning is emitted, not raised, and the image is skipped. |

Key Attributes

| Attribute | Type | Description |
| --- | --- | --- |
| image_info | List[Dict] | Per-image metadata: img_path, mask_path, height, width, patches_per_row, patches_per_col. |
| patch_size | int | Patch side length. |
| stride | int | Step between patch origins. |
| image_dtype | str | Configured dtype. |
| selected_bands | List[int] \| None | Band selection (1-based). |
| data_loader | Any | Stored DataLoader config. |

__len__

Returns the total number of patches across all images, not the number of images.

len(ds) = Σ_i patches_per_row_i × patches_per_col_i
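
The mapping from a global index to (image, row, column) can be sketched with cumulative patch counts and bisect. This is an assumption-laden illustration, not the library's source: it assumes the number of valid windows along an axis is (length - patch_size) // stride + 1, and windows_along, build_offsets and locate are hypothetical names.

```python
from bisect import bisect_right
from itertools import accumulate
from typing import List, Tuple


def windows_along(length: int, patch_size: int, stride: int) -> int:
    """Number of full windows that fit along one axis (assumed formula)."""
    if length < patch_size:
        return 0
    return (length - patch_size) // stride + 1


def build_offsets(
    shapes: List[Tuple[int, int]], patch_size: int, stride: int
) -> List[int]:
    """Cumulative patch counts, one entry per image given as (height, width)."""
    counts = [
        windows_along(h, patch_size, stride) * windows_along(w, patch_size, stride)
        for h, w in shapes
    ]
    return list(accumulate(counts))


def locate(
    index: int,
    shapes: List[Tuple[int, int]],
    cumulative: List[int],
    patch_size: int,
    stride: int,
) -> Tuple[int, int, int]:
    """Map a global index to (image index, row offset, column offset)."""
    if not cumulative or not 0 <= index < cumulative[-1]:
        raise IndexError(index)
    img = bisect_right(cumulative, index)  # O(log N) search over images
    local = index - (cumulative[img - 1] if img else 0)
    n_cols = windows_along(shapes[img][1], patch_size, stride)
    row, col = divmod(local, n_cols)
    return img, row * stride, col * stride


shapes = [(512, 512), (256, 768)]  # (height, width) per image
cumulative = build_offsets(shapes, patch_size=256, stride=128)
```

With these shapes the first image contributes 3 × 3 = 9 patches and the second 1 × 5 = 5, so the dataset length would be 14 and index 9 is the first patch of the second image.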

__getitem__ Returns

Dict[str, torch.Tensor] with keys:

| Key | dtype | shape | Notes |
| --- | --- | --- | --- |
| "image" | torch.float32 | (C, patch_size, patch_size) | Normalized ÷255 (uint8) or ÷65535 (uint16) when no transform; unchanged for float32/native. |
| "mask" | torch.int64 | (patch_size, patch_size) | Raw pixel values from the mask file. |

When augmentation_list is set the return type matches whatever the Albumentations pipeline produces (typically the same keys with transformed tensors after ToTensorV2).
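
The default normalization rule from the table above can be sketched in NumPy. to_unit_float is a hypothetical helper written for this page; it only mirrors the documented behaviour (divide by 255 for uint8, by 65535 for uint16, leave other dtypes unscaled) and is not taken from the library.

```python
import numpy as np


def to_unit_float(patch: np.ndarray) -> np.ndarray:
    """Hypothetical helper: dtype-dependent scaling to float32."""
    if patch.dtype == np.uint8:
        return patch.astype(np.float32) / 255.0
    if patch.dtype == np.uint16:
        return patch.astype(np.float32) / 65535.0
    # float32 / native data is cast but not rescaled.
    return patch.astype(np.float32)


u8 = to_unit_float(np.array([0, 255], dtype=np.uint8))
u16 = to_unit_float(np.array([0, 65535], dtype=np.uint16))
f32 = to_unit_float(np.array([0.5, 300.0], dtype=np.float32))
```

Note that float32 input keeps its original range, so reflectance or elevation data may still need an explicit normalization step in augmentation_list.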


AbstractDataset

from pytorch_segmentation_models_trainer.dataset_loader.dataset import AbstractDataset

Base class for all datasets. Handles CSV loading, root directory resolution, and augmentation setup. Subclasses must implement __getitem__.

Constructor Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| input_csv_path | Path | None | Path to the CSV file. Required when df is not provided. |
| df | pd.DataFrame | None | Pre-loaded DataFrame. When provided, input_csv_path is ignored for loading but still stored. |
| root_dir | Any | None | Root directory prepended to all relative file paths read from the CSV. |
| augmentation_list | Any | None | Albumentations augmentation list or A.Compose object. None disables augmentation. |
| data_loader | Any | None | DataLoader sub-configuration (see DataLoader Config Keys below). |
| image_key | str | "image" | CSV column name for image paths. |
| mask_key | str | "mask" | CSV column name for mask paths. |
| n_first_rows_to_read | int | None | If set, only the first N rows are read from the CSV via pd.read_csv(..., nrows=N). |

Key Methods

| Method | Signature | Description |
| --- | --- | --- |
| __len__ | () -> int | Returns the number of rows in the dataset DataFrame. |
| get_path | (idx, key=None, add_root_dir=True) -> str | Returns the file path for item idx under the given CSV column key. |
| load_image | (idx, key=None, is_mask=False, force_rgb=False, is_binary_mask=True) -> np.ndarray | Loads and returns a numpy array for the given item. |
| update_df | (new_df) -> None | Replaces the internal DataFrame and updates self.len. |

CSVWindowedImageDataset

from pytorch_segmentation_models_trainer.dataset_loader.image_dataset import CSVWindowedImageDataset

Extends ImageDataset to read specific patches from large images using rasterio windowed read. Coordinates (offsets) for each patch are read from the input CSV. Does not include masks.

Constructor Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| input_csv_path | Path \| str | None | Path to the CSV file. |
| df | pd.DataFrame | None | Pre-built DataFrame. |
| root_dir | Any | None | Root directory for path resolution. |
| augmentation_list | Any | None | Albumentations augmentation pipeline. |
| data_loader | Any | None | DataLoader sub-configuration. |
| image_key | str | "image" | CSV column for image paths. |
| row_off_key | str | "row_off" | CSV column for vertical offset. |
| col_off_key | str | "col_off" | CSV column for horizontal offset. |
| patch_size_key | str | "patch_size" | CSV column for patch size. |
| n_first_rows_to_read | int | None | Limit on rows to read. |
| selected_bands | Optional[List[int]] | None | 1-based band indices to load. |
| use_rasterio | bool | True | Forces rasterio for windowed read. |
| image_dtype | str | "uint8" | Data type for rasterio-loaded images. |
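
A CSV using the default column names could be generated as in the sketch below. It assumes that patch_size holds a single integer for a square window and that offsets are in pixels; window_rows is a hypothetical helper and the image path is made up for the example.

```python
import csv
import io
from typing import Dict, List


def window_rows(
    image_path: str, height: int, width: int, patch_size: int, stride: int
) -> List[Dict[str, object]]:
    """Hypothetical helper: one CSV row per full window inside the image."""
    rows = []
    for row_off in range(0, height - patch_size + 1, stride):
        for col_off in range(0, width - patch_size + 1, stride):
            rows.append(
                {
                    "image": image_path,
                    "row_off": row_off,  # vertical offset in pixels
                    "col_off": col_off,  # horizontal offset in pixels
                    "patch_size": patch_size,
                }
            )
    return rows


rows = window_rows("scenes/a.tif", height=512, width=768, patch_size=256, stride=256)

# Write to an in-memory buffer; a real run would write to a file on disk.
buffer = io.StringIO()
writer = csv.DictWriter(
    buffer, fieldnames=["image", "row_off", "col_off", "patch_size"]
)
writer.writeheader()
writer.writerows(rows)
csv_text = buffer.getvalue()
```

When reading such a row with rasterio directly, the equivalent window would be rasterio.windows.Window(col_off, row_off, patch_size, patch_size); note that the column offset comes first in that constructor.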

__getitem__ Returns

Dict[str, Any] with keys:

| Key | Type | Description |
| --- | --- | --- |
| "image" | np.ndarray or torch.Tensor | The loaded patch, optionally transformed. |
| "path" | str | Absolute path to the source image file. |

ImageDataset

from pytorch_segmentation_models_trainer.dataset_loader.image_dataset import ImageDataset

Image-only dataset. Returns a dict with the loaded image and its file path. Suitable for inference pipelines that do not require ground-truth masks.

Constructor Parameters

Inherits all parameters from AbstractDataset. The mask_key parameter is accepted by the parent but not used.

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| input_csv_path | Path | None | Path to the CSV file. |
| df | pd.DataFrame | None | Pre-loaded DataFrame. |
| root_dir | Any | None | Root directory for path resolution. |
| augmentation_list | Any | None | Albumentations augmentation pipeline. |
| data_loader | Any | None | DataLoader sub-configuration. |
| image_key | str | "image" | CSV column name for image paths. |
| n_first_rows_to_read | int | None | Limit on rows to read. |

__getitem__ Returns

Dict[str, Any] with keys:

| Key | Type | Description |
| --- | --- | --- |
| "image" | np.ndarray or torch.Tensor | The loaded image, optionally transformed. |
| "path" | str | Absolute path to the source image file. |

SegmentationDataset

from pytorch_segmentation_models_trainer.dataset_loader.dataset import SegmentationDataset

Semantic segmentation dataset. Loads image-mask pairs and supports both PIL (RGB) and rasterio (multi-band/multispectral) image loading.

Constructor Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| input_csv_path | Path | None | Path to the CSV file. Mutually exclusive with df; at least one must be provided. |
| df | pd.DataFrame | None | Pre-built DataFrame with image and mask columns. Allows creating the dataset without a CSV file on disk (e.g. via SegmentationDatasetFromFolder). |
| root_dir | Any | None | Root directory for path resolution. |
| augmentation_list | Any | None | Albumentations augmentation pipeline. |
| data_loader | Any | None | DataLoader sub-configuration. |
| image_key | str | "image" | CSV column name for image paths. |
| mask_key | str | "mask" | CSV column name for mask paths. |
| n_first_rows_to_read | int | None | Limit on rows to read. |
| n_classes | int | 2 | Number of segmentation classes. When 2, masks are binarized (> 0). |
| selected_bands | Optional[List[int]] | None | 1-based list of band indices to load via rasterio. E.g. [1, 2, 3] loads the first three bands. When None, all bands are loaded. |
| use_rasterio | bool | False | When True, forces rasterio for image loading (recommended for multispectral imagery). |
| image_dtype | str | "uint8" | Data type applied to the image array after rasterio loading. Accepted values: "uint8", "uint16", "float32", "native". Only takes effect when rasterio is used (use_rasterio=True or selected_bands is set). "native" skips the cast entirely. Raises ValueError for unrecognised values. |
| reset_augmentation_function | bool | False | When True, deep-copies the augmentation pipeline before each call to prevent memory leaks from Albumentations caching. |

__getitem__ Returns

Dict[str, Any] with keys:

| Key | Type | Description |
| --- | --- | --- |
| "image" | torch.Tensor (C, H, W) | Float tensor. When no transform is set, automatically normalized: uint8/255, uint16/65535, float32/native → no division. |
| "mask" | torch.Tensor (H, W) | Long tensor of class labels. Binary (0/1) when n_classes == 2. |

SegmentationDatasetFromFolder

from pytorch_segmentation_models_trainer.dataset_loader.dataset import SegmentationDatasetFromFolder

Extends SegmentationDataset to discover image/mask pairs recursively from two root folders, without requiring a CSV file on disk. Matching is performed by relative subfolder path and file stem (name without extension): only files present in both folders, inside the same subfolder and with the same filename, are included in the dataset.

Expected Folder Structure

images_root/
    area_a/
        tile_001.tif
        tile_002.tif
    area_b/
        tile_003.tif

masks_root/
    area_a/
        tile_001.tif   ← paired with images_root/area_a/tile_001.tif
        tile_002.tif
    area_b/
        tile_003.tif

Only files with a matching (subfolder, stem) key in both trees are included. Files present in only one folder are silently ignored.
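
The matching rule can be sketched with pathlib. This is a minimal illustration under stated assumptions, not the library's _build_dataframe: match_pairs and normalize_extension are hypothetical helpers, and the demo builds a throwaway folder tree in a temporary directory.

```python
import tempfile
from pathlib import Path
from typing import Dict, List, Optional, Tuple


def normalize_extension(ext: str) -> str:
    """Make the leading dot optional: "tif" and ".tif" are equivalent."""
    return ext if ext.startswith(".") else f".{ext}"


def match_pairs(
    image_folder: Path,
    mask_folder: Path,
    image_extension: str = ".tif",
    mask_extension: Optional[str] = None,
) -> List[Tuple[Path, Path]]:
    """Hypothetical helper: pair files by (relative subfolder, stem)."""
    img_ext = normalize_extension(image_extension)
    msk_ext = normalize_extension(mask_extension or image_extension)

    def index(root: Path, ext: str) -> Dict[Tuple[str, str], Path]:
        return {
            (p.parent.relative_to(root).as_posix(), p.stem): p
            for p in root.rglob(f"*{ext}")
            if p.is_file()
        }

    images, masks = index(image_folder, img_ext), index(mask_folder, msk_ext)
    keys = sorted(images.keys() & masks.keys())  # present in both trees only
    if not keys:
        raise ValueError("no matching image/mask pairs found")
    return [(images[k], masks[k]) for k in keys]


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for rel in (
        "images/area_a/tile_001.tif",
        "images/area_a/tile_002.tif",  # no mask in the same subfolder
        "masks/area_a/tile_001.tif",
        "masks/area_b/tile_002.tif",   # same stem, different subfolder
    ):
        f = root / rel
        f.parent.mkdir(parents=True, exist_ok=True)
        f.touch()
    pairs = match_pairs(root / "images", root / "masks")
    matched = [img.relative_to(root / "images").as_posix() for img, _ in pairs]
```

In the demo only area_a/tile_001.tif is paired: tile_002 exists in both trees but under different subfolders, so its key differs.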

Constructor Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| image_folder | Union[str, Path] | required | Root folder containing the images. |
| mask_folder | Union[str, Path] | required | Root folder containing the masks. |
| image_extension | str | ".tif" | File extension for images. The leading dot is optional and normalized internally (e.g. "tif" and ".tif" are equivalent). |
| mask_extension | Optional[str] | None | File extension for masks. When None, uses the same value as image_extension. |
| augmentation_list | Any | None | Albumentations augmentation pipeline. |
| data_loader | Any | None | DataLoader sub-configuration. |
| n_classes | int | 2 | Number of segmentation classes. When 2, masks are binarized (> 0). |
| selected_bands | Optional[List[int]] | None | 1-based band indices to load via rasterio. None loads all bands. |
| use_rasterio | bool | False | Forces rasterio for image loading. |
| image_dtype | str | "uint8" | Data type for rasterio-loaded images. See SegmentationDataset. |
| reset_augmentation_function | bool | False | Deep-copy the transform to prevent Albumentations memory leaks. |

Raises

| Exception | When |
| --- | --- |
| ValueError | No matching image/mask pairs are found (wrong extension, mismatched subfolder structure, etc.). |

Instance Attributes

After construction the following extra attributes are available:

| Attribute | Type | Description |
| --- | --- | --- |
| image_folder | Path | Resolved root folder for images. |
| mask_folder | Path | Resolved root folder for masks. |
| image_extension | str | Normalized image extension (with leading dot). |
| mask_extension | str | Normalized mask extension (with leading dot). |

__getitem__ Returns

Same as SegmentationDataset:

| Key | Type | Description |
| --- | --- | --- |
| "image" | torch.Tensor (C, H, W), float32 | Normalized image tensor. |
| "mask" | torch.Tensor (H, W), int64 | Class-index mask. |

Static Helper Methods

| Method | Description |
| --- | --- |
| _normalize_extension(ext) | Ensures the extension starts with a dot. |
| _build_dataframe(image_folder, mask_folder, image_extension, mask_extension) | Scans both folders recursively, matches pairs, and returns a pd.DataFrame with image and mask columns. Raises ValueError if no pairs are found. |

CSVWindowedSegmentationDataset

from pytorch_segmentation_models_trainer.dataset_loader.dataset import CSVWindowedSegmentationDataset

Extends SegmentationDataset to read specific patches from large GeoTIFFs using rasterio windowed read. Coordinates (offsets) for each patch are read from the input CSV.

Constructor Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| input_csv_path | Path \| str | None | Path to the CSV file. |
| df | pd.DataFrame | None | Pre-built DataFrame (alternative to input_csv_path). |
| root_dir | Any | None | Root directory for path resolution. |
| augmentation_list | Any | None | Albumentations augmentation pipeline. |
| data_loader | Any | None | DataLoader sub-configuration. |
| image_key | str | "image" | CSV column for image paths. |
| mask_key | str | "mask" | CSV column for mask paths. |
| row_off_key | str | "row_off" | CSV column for vertical offset. |
| col_off_key | str | "col_off" | CSV column for horizontal offset. |
| patch_size_key | str | "patch_size" | CSV column for patch size. |
| n_classes | int | 2 | Number of classes. If 2, mask is binarized (> 0). |
| selected_bands | Optional[List[int]] | None | 1-based band indices to load via rasterio. |
| use_rasterio | bool | True | Forces rasterio for windowed read. |
| image_dtype | str | "uint8" | Data type for rasterio-loaded images. |

__getitem__ Returns

Same as SegmentationDataset:

| Key | Type | Description |
| --- | --- | --- |
| "image" | torch.Tensor (C, H, W), float32 | Normalized image tensor. |
| "mask" | torch.Tensor (H, W), int64 | Class-index mask. |

FrameFieldSegmentationDataset

from pytorch_segmentation_models_trainer.dataset_loader.dataset import FrameFieldSegmentationDataset

Extends SegmentationDataset for frame-field polygon learning. Loads multiple auxiliary masks (boundary, vertex, crossfield angle, distance transform, size map) in addition to the primary segmentation mask.

Constructor Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| input_csv_path | Path | required | Path to the CSV file. |
| root_dir | Any | None | Root directory for path resolution. |
| augmentation_list | Any | None | Albumentations augmentation pipeline. |
| data_loader | Any | None | DataLoader sub-configuration. |
| image_key | str | "image" | CSV column name for image paths. |
| mask_key | str | "polygon_mask" | CSV column name for the primary polygon mask. |
| multi_band_mask | bool | False | When True, all three masks (polygon, boundary, vertex) are packed into a single multi-band image file. |
| boundary_mask_key | str | "boundary_mask" | CSV column for the boundary mask. |
| return_boundary_mask | bool | True | Whether to load and return the boundary mask. |
| vertex_mask_key | str | "vertex_mask" | CSV column for the vertex mask. |
| return_vertex_mask | bool | True | Whether to load and return the vertex mask. |
| n_first_rows_to_read | int | None | Limit on rows to read. |
| return_crossfield_mask | bool | True | Whether to load the crossfield angle mask. |
| crossfield_mask_key | str | "crossfield_mask" | CSV column for the crossfield angle image. |
| return_distance_mask | bool | True | Whether to load the distance transform mask. |
| distance_mask_key | str | "distance_mask" | CSV column for the distance map. |
| return_size_mask | bool | True | Whether to load the size map. |
| size_mask_key | str | "size_mask" | CSV column for the size map. |
| image_width | int | 224 | Target image width used in the fallback transform when an augmentation crop is invalid. |
| image_height | int | 224 | Target image height used in the fallback transform. |
| gpu_augmentation_list | Any | None | Reserved for GPU-side augmentation (not currently applied inside __getitem__). |

__getitem__ Returns

Dict[str, Any] with keys:

| Key | Present when | Type | Description |
| --- | --- | --- | --- |
| "idx" | always | int | Dataset index of this item. |
| "path" | always | str | File path of the source image. |
| "image" | always | torch.Tensor | Input image tensor. |
| "gt_polygons_image" | always | torch.Tensor (C, H, W) | Stacked polygon/boundary/vertex masks. |
| "class_freq" | always | torch.Tensor | Per-channel class frequency used for loss weighting. |
| "gt_crossfield_angle" | return_crossfield_mask=True | torch.Tensor (1, H, W) | Crossfield angle map in radians. |
| "distances" | return_distance_mask=True | torch.Tensor (1, H, W) | Normalized distance transform. |
| "sizes" | return_size_mask=True | torch.Tensor (1, H, W) | Object size map. |

ObjectDetectionDataset

from pytorch_segmentation_models_trainer.dataset_loader.dataset import ObjectDetectionDataset

Object detection dataset. Loads images with bounding boxes and class labels from JSON annotation files.

Constructor Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| input_csv_path | Path | required | Path to the CSV file. |
| root_dir | Any | None | Root directory for path resolution. |
| augmentation_list | Any | None | Albumentations augmentation pipeline with bbox support. |
| data_loader | Any | None | DataLoader sub-configuration. |
| image_key | str | "image" | CSV column name for image paths. |
| mask_key | str | "mask" | CSV column name (not used for detection but inherited). |
| bounding_box_key | str | "bounding_boxes" | CSV column pointing to JSON files with bounding box annotations. |
| n_first_rows_to_read | int | None | Limit on rows to read. |
| bbox_format | str | "xywh" | Input bounding box format ("xywh" or "xyxy"). |
| bbox_output_format | str | "xyxy" | Output bounding box format after conversion ("xywh" or "xyxy"). |
| bbox_params | Any | None | Albumentations BboxParams (or equivalent dict/config) passed to A.Compose. |

__getitem__ Returns

Tuple[torch.Tensor, Dict[str, torch.Tensor], int]:

| Position | Type | Description |
| --- | --- | --- |
| [0] | torch.Tensor | Image tensor (RGB, float32). |
| [1] | dict | Dict with keys "boxes" (float32 tensor of bounding boxes) and "labels" (int64 tensor of class indices). |
| [2] | int | Dataset index of the item. |
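
The conversion between bbox_format and bbox_output_format can be sketched as below. It assumes "xywh" means [x_min, y_min, width, height] and "xyxy" means [x_min, y_min, x_max, y_max], which is the common convention but is not confirmed on this page; convert_boxes is a hypothetical helper operating on plain lists.

```python
from typing import List, Sequence

Box = List[float]


def xywh_to_xyxy(box: Sequence[float]) -> Box:
    x, y, w, h = box
    return [x, y, x + w, y + h]


def xyxy_to_xywh(box: Sequence[float]) -> Box:
    x1, y1, x2, y2 = box
    return [x1, y1, x2 - x1, y2 - y1]


def convert_boxes(
    boxes: Sequence[Sequence[float]],
    bbox_format: str = "xywh",
    bbox_output_format: str = "xyxy",
) -> List[Box]:
    """Hypothetical helper mirroring the two format parameters."""
    formats = ("xywh", "xyxy")
    if bbox_format not in formats or bbox_output_format not in formats:
        raise ValueError("formats must be 'xywh' or 'xyxy'")
    if bbox_format == bbox_output_format:
        return [list(b) for b in boxes]
    convert = xywh_to_xyxy if bbox_format == "xywh" else xyxy_to_xywh
    return [convert(b) for b in boxes]


converted = convert_boxes([[10, 20, 30, 40]])
round_trip = convert_boxes(converted, "xyxy", "xywh")
```

The "xyxy" output matches what torchvision detection models expect for the "boxes" entry of the target dict.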

InstanceSegmentationDataset

from pytorch_segmentation_models_trainer.dataset_loader.dataset import InstanceSegmentationDataset

Extends ObjectDetectionDataset with optional per-instance mask and keypoint loading.

Constructor Parameters

Inherits all parameters from ObjectDetectionDataset, plus:

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| keypoint_key | str | "keypoints" | CSV column pointing to JSON files with keypoint or polygon annotations. |
| return_mask | bool | True | When True, loads and returns per-instance binary masks. |
| return_keypoints | bool | False | When True, loads and attaches keypoint data from the JSON file at keypoint_key. |

The mask_key default is overridden to "polygon_mask".

__getitem__ Returns

Same tuple structure as ObjectDetectionDataset. When return_mask=True the target dict additionally contains:

| Key | Type | Description |
| --- | --- | --- |
| "masks" | torch.Tensor (uint8) | Per-instance binary mask tensor. |

PolygonRNNDataset

from pytorch_segmentation_models_trainer.dataset_loader.dataset import PolygonRNNDataset

Dataset for Polygon-RNN training and validation. Reads pre-cropped images and normalized polygon JSON files produced by the Dataset Conversion pipeline.

Constructor Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| input_csv_path | Path | required | Path to the CSV file (typically the output of PolygonRNNDatasetConversionStrategy). |
| sequence_length | int | 60 | Maximum number of polygon vertices in the output sequence. Shorter polygons are padded. |
| root_dir | Any | None | Root directory for path resolution. |
| augmentation_list | Any | None | Albumentations augmentation pipeline. |
| data_loader | Any | None | DataLoader sub-configuration. |
| image_key | str | "image" | CSV column for cropped image paths. |
| mask_key | str | "mask" | CSV column for polygon JSON paths. |
| image_width_key | str | None | CSV column for image width (not used internally but stored for downstream use). |
| image_height_key | str | None | CSV column for image height. |
| scale_h_key | str | "scale_h" | CSV column for vertical scale factor. |
| scale_w_key | str | "scale_w" | CSV column for horizontal scale factor. |
| min_col_key | str | "min_col" | CSV column for the crop left boundary (column). |
| min_row_key | str | "min_row" | CSV column for the crop top boundary (row). |
| original_image_path_key | str | "original_image_path" | CSV column for the source full-image path. |
| original_polygon_key | str | "original_polygon_wkt" | CSV column for the WKT representation of the original polygon. |
| n_first_rows_to_read | str | None | Limit on rows to read. |
| dataset_type | str | "train" | Either "train" or "val". Validation mode returns additional metadata fields per item. |
| grid_size | int | 28 | Grid resolution used for polygon label encoding. |

__getitem__ Returns

Dict[str, Any] with keys:

| Key | Present when | Type | Description |
| --- | --- | --- | --- |
| "x1" | always | torch.Tensor (float32) | Polygon label array at step 2 (first input to RNN). |
| "x2" | always | torch.Tensor (float32) | Polygon label array from step 0 to T-2 (main RNN input sequence). |
| "x3" | always | torch.Tensor (float32) | Polygon label array from step 1 to T-1 (shifted sequence). |
| "ta" | always | torch.Tensor (int64) | Target index array from step 2 to T (supervised labels). |
| "image" | always | torch.Tensor (float32) | Cropped image tensor. |
| "polygon_wkt" | dataset_type="val" | str | WKT of the original polygon for metric computation. |
| "scale_h" | dataset_type="val" | float | Vertical scale factor for back-projection. |
| "scale_w" | dataset_type="val" | float | Horizontal scale factor for back-projection. |
| "min_col" | dataset_type="val" | float | Crop left boundary for back-projection. |
| "min_row" | dataset_type="val" | float | Crop top boundary for back-projection. |
| "original_image_path" | dataset_type="val" | str | Source full-image path for visualization. |

TiledInferenceImageDataset

from pytorch_segmentation_models_trainer.dataset_loader.image_dataset import TiledInferenceImageDataset

Extends ImageDataset for large-image inference. Slices each image into overlapping tiles so that a fixed-input-size model can process arbitrarily large images. Uses pytorch_toolbelt.inference.tiles.ImageSlicer internally.

Constructor Parameters

Inherits all parameters from ImageDataset, plus:

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| normalize_output | bool | True | When True, applies A.Normalize() (ImageNet mean/std) to each image before tiling. |
| pad_if_needed | bool | False | When True, pads the image to the nearest multiple of model_input_shape before tiling. Requires model_input_shape to be set. |
| model_input_shape | Any | None | Tuple (H, W) of the model's expected input tile size. Required when pad_if_needed=True. |
| step_shape | Any | (224, 224) | Tile step (stride) for the ImageSlicer. Smaller values increase overlap. |

The augmentation_list parameter is accepted but ignored; normalization is controlled by normalize_output instead.
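
The effect of pad_if_needed on image size can be sketched as follows. Since padding can only grow an image, the sketch reads "nearest multiple" as the next multiple at or above the current size; that reading, and the padded_shape helper itself, are assumptions rather than the library's code.

```python
import math
from typing import Tuple


def padded_shape(
    height: int, width: int, model_input_shape: Tuple[int, int]
) -> Tuple[int, int]:
    """Hypothetical helper: round each side up to a multiple of the tile size."""
    tile_h, tile_w = model_input_shape
    return (
        math.ceil(height / tile_h) * tile_h,
        math.ceil(width / tile_w) * tile_w,
    )


needs_padding = padded_shape(500, 700, (224, 224))
already_aligned = padded_shape(448, 224, (224, 224))
```

A 500 × 700 image would grow to 672 × 896 for 224 × 224 tiles, while an image whose sides are already multiples of the tile size is left unchanged.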

__getitem__ Returns

Dict[str, Any] with keys:

| Key | Type | Description |
| --- | --- | --- |
| "path" | str | File path of the source image. |
| "tiles" | torch.Tensor (N_tiles, C, H, W) | Stack of all tiles from this image. |
| "tile_image_idx" | torch.Tensor (N_tiles,) | Image index repeated for each tile (for reassembly). |
| "tiler_object" | ImageSlicer | Slicer object needed to merge tile predictions back. |
| "original_shape" | tuple | (width, height) of the original image. |

A collate_fn static method is provided for use with DataLoader to correctly batch variable-tile-count images.


DataLoader Config Keys

All dataset classes accept a data_loader sub-configuration. When used with Hydra this is typically a nested config node. The following keys are recognised:

| Key | Type | Default | Description |
| --- | --- | --- | --- |
| shuffle | bool | True | Whether to shuffle the dataset each epoch. Typically True for training, False for validation. |
| num_workers | int | 0 | Number of worker processes for data loading. |
| pin_memory | bool | False | Whether to pin memory for faster GPU transfer. |
| batch_size | int | 8 | Number of samples per batch. |
| drop_last | bool | False | Whether to drop the last incomplete batch. |

Example:

data_loader:
  shuffle: true
  num_workers: 4
  pin_memory: true
  batch_size: 16
  drop_last: false
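
How such a config might be resolved against the documented defaults is sketched below. resolve_loader_kwargs is a hypothetical helper, not a function of the library; the key names happen to match keyword arguments of torch.utils.data.DataLoader, so the resulting dict could be passed to it as **kwargs.

```python
from typing import Any, Dict, Mapping, Optional

# Defaults copied from the table of recognised keys.
LOADER_DEFAULTS: Dict[str, Any] = {
    "shuffle": True,
    "num_workers": 0,
    "pin_memory": False,
    "batch_size": 8,
    "drop_last": False,
}


def resolve_loader_kwargs(cfg: Optional[Mapping[str, Any]] = None) -> Dict[str, Any]:
    """Hypothetical helper: overlay user settings on the documented defaults."""
    return {**LOADER_DEFAULTS, **dict(cfg or {})}


train_kwargs = resolve_loader_kwargs({"num_workers": 4, "batch_size": 16})
val_kwargs = resolve_loader_kwargs({"shuffle": False})
```

Keys that are omitted keep their defaults, so a validation config only needs to override shuffle.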

CSV Format Conventions

Each dataset type expects specific columns in the input CSV:

SegmentationDataset / FrameFieldSegmentationDataset

| Column | Required | Description |
| --- | --- | --- |
| image | yes | Path to the input image (RGB or multi-band). |
| mask / polygon_mask | yes | Path to the ground-truth mask image. |
| boundary_mask | FrameField only | Path to the boundary mask image. |
| vertex_mask | FrameField only | Path to the vertex mask image. |
| crossfield_mask | FrameField only | Path to the crossfield angle image. |
| distance_mask | FrameField only | Path to the distance transform image. |
| size_mask | FrameField only | Path to the size map image. |
| class_freq | optional | Pre-computed class frequency string to skip per-sample computation. |

ObjectDetectionDataset / InstanceSegmentationDataset

| Column | Required | Description |
| --- | --- | --- |
| image | yes | Path to the input image. |
| bounding_boxes | yes | Path to a JSON file with a list of {"bbox": [...], "class": N} entries. |
| polygon_mask | InstanceSeg only | Path to the binary mask image. |
| keypoints | InstanceSeg only | Path to a JSON file with a "keypoints" key. |

PolygonRNNDataset

| Column | Required | Description |
| --- | --- | --- |
| image | yes | Path to the cropped object image. |
| mask | yes | Path to the normalized polygon JSON file. |
| scale_h | yes | Vertical scale factor for coordinate back-projection. |
| scale_w | yes | Horizontal scale factor. |
| min_col | yes | Left boundary of the original crop. |
| min_row | yes | Top boundary of the original crop. |
| original_image_path | yes (val) | Path to the source full-size image. |
| original_polygon_wkt | yes (val) | WKT of the original polygon. |