Building a Segmentation Dataset

This guide explains how to prepare data and configure the SegmentationDataset class for semantic segmentation training. The dataset system is built around a CSV index file that points to image and mask pairs on disk.

CSV-Based Dataset Format

All dataset classes in this project use a CSV file as their index. Each row describes one training sample. The minimal required columns for segmentation are image and mask.

| Column | Description |
| --- | --- |
| image | Path to the input image (absolute or relative to root_dir) |
| mask | Path to the corresponding segmentation mask |

Paths may be absolute or relative. When root_dir is provided in the dataset config, relative paths are resolved against it.

Example CSV

```csv
image,mask
images/tile_001.tif,polygon_masks/tile_001.png
images/tile_002.tif,polygon_masks/tile_002.png
images/tile_003.tif,polygon_masks/tile_003.png
```

Generating a CSV with Python

```python
import pandas as pd
from pathlib import Path

images_dir = Path("/data/my_dataset/images")
masks_dir = Path("/data/my_dataset/polygon_masks")

rows = []
for img_path in sorted(images_dir.glob("*.tif")):
    mask_path = masks_dir / img_path.with_suffix(".png").name
    if mask_path.exists():
        rows.append({"image": str(img_path), "mask": str(mask_path)})

df = pd.DataFrame(rows)
df.to_csv("/data/my_dataset/train.csv", index=False)
print(f"Wrote {len(df)} rows")
```

Supported Image Formats

The dataset supports any format readable by PIL (default) or rasterio (when use_rasterio: true):

  • JPG / JPEG — standard RGB photographs
  • PNG — lossless RGB or RGBA images
  • TIFF / GeoTIFF — remote sensing imagery, including multispectral and hyperspectral data
> **Tip:** For GeoTIFF files with more than three bands or with a specific band ordering, always set use_rasterio: true and specify selected_bands.

Mask Format

Masks must be single-channel (grayscale) PNG files where each pixel value encodes a class:

| Pixel value | Meaning |
| --- | --- |
| 0 | Background |
| 1 | Class 1 |
| 2 | Class 2 |
| … | Additional classes |
| 255 | Foreground (binary segmentation) |

When n_classes=2 (the default), the dataset automatically binarises the mask: any non-zero value becomes 1 and zero stays 0. For multi-class masks, set n_classes to the actual number of classes; pixel values must then match the class indices directly.
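The binarisation rule can be sketched in a few lines of numpy (an illustrative helper following the behaviour described above; binarize_mask is not the project's actual function name):

```python
import numpy as np

def binarize_mask(mask: np.ndarray, n_classes: int = 2) -> np.ndarray:
    # Default binary behaviour: any non-zero pixel becomes 1, zero stays 0.
    # For n_classes > 2, pixel values are assumed to already be class indices.
    if n_classes == 2:
        return (mask != 0).astype(np.int64)
    return mask.astype(np.int64)

mask = np.array([[0, 255], [1, 0]], dtype=np.uint8)
print(binarize_mask(mask))  # [[0 1]
                            #  [1 0]]
```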

The SegmentationDataset Class

SegmentationDataset extends the abstract base class and handles image/mask loading, optional rasterio-based reading, band selection, and augmentation.

Constructor Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| input_csv_path | Path | required | Path to the CSV index file |
| root_dir | str | None | Root directory prepended to relative paths in the CSV |
| augmentation_list | list | None | List of albumentations transforms |
| data_loader | config object | None | DataLoader keyword arguments |
| image_key | str | "image" | CSV column name for the image path |
| mask_key | str | "mask" | CSV column name for the mask path |
| n_first_rows_to_read | int | None | Limit the number of CSV rows read (useful for quick experiments) |
| n_classes | int | 2 | Number of segmentation classes |
| selected_bands | List[int] | None | 1-based list of band indices to read (e.g. [1, 2, 3]) |
| use_rasterio | bool | False | Use rasterio instead of PIL for image loading |
| image_dtype | str | "uint8" | Data type for image interpretation when using rasterio. Accepted values: "uint8", "uint16", "float32", "native". See Image Dtype below. |
| reset_augmentation_function | bool | False | Deep-copy the transform on each call to prevent albumentations memory leaks |

data_loader Sub-Config

The data_loader config block is passed directly as keyword arguments to PyTorch's DataLoader. Supported fields:

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| shuffle | bool | true | Shuffle samples every epoch |
| num_workers | int | 4 | Number of parallel data loading workers |
| pin_memory | bool | true | Pin memory for faster GPU transfer |
| batch_size | int | 8 | Number of samples per batch |
| drop_last | bool | false | Drop the last incomplete batch |
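Because the block is forwarded verbatim, it is equivalent to passing the same keys to DataLoader by hand. A minimal sketch (a stand-in TensorDataset replaces the real dataset; num_workers and pin_memory are lowered so the snippet runs anywhere):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

data_loader_cfg = {"shuffle": True, "num_workers": 0, "pin_memory": False,
                   "batch_size": 8, "drop_last": False}

dataset = TensorDataset(torch.zeros(20, 3, 4, 4))  # 20 dummy samples
loader = DataLoader(dataset, **data_loader_cfg)    # config keys become keyword arguments
print(len(loader))  # 3 batches: ceil(20 / 8), since drop_last is false
```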

Augmentation Pipeline

The project uses albumentations for augmentation. Each transform in augmentation_list is instantiated via Hydra's _target_ mechanism.

Example Augmentation Transforms

```yaml
augmentation_list:
  - _target_: albumentations.RandomCrop
    height: 512
    width: 512
  - _target_: albumentations.HorizontalFlip
    p: 0.5
  - _target_: albumentations.VerticalFlip
    p: 0.5
  - _target_: albumentations.RandomBrightnessContrast
    p: 0.2
  - _target_: albumentations.Normalize
    mean: [0.485, 0.456, 0.406]
    std: [0.229, 0.224, 0.225]
  - _target_: albumentations.pytorch.ToTensorV2
```
> **Note:** ToTensorV2 converts the image to a (C, H, W) float tensor and the mask to a (H, W) long tensor. It must be the last transform in the list.
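The layout change performed by ToTensorV2 can be sketched with numpy (shapes and dtypes only; this mirrors, rather than calls, the albumentations implementation):

```python
import numpy as np

img_hwc = np.random.rand(512, 512, 3).astype(np.float32)       # (H, W, C) after Normalize
mask_hw = np.random.randint(0, 2, (512, 512), dtype=np.uint8)  # (H, W) class indices

img_chw = np.transpose(img_hwc, (2, 0, 1))  # channels-first, as ToTensorV2 emits
mask_long = mask_hw.astype(np.int64)        # masks end up as int64 ("long") values

print(img_chw.shape)    # (3, 512, 512)
print(mask_long.dtype)  # int64
```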

Multispectral Support

For multispectral images (e.g. 4-band or 8-band GeoTIFFs), set use_rasterio: true and specify which bands to load with selected_bands. Band indices are 1-based, matching rasterio's convention.

```yaml
use_rasterio: true
selected_bands: [1, 2, 3, 4]  # load the first four bands
```

If selected_bands is omitted while use_rasterio is true, all available bands are loaded.
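The 1-based convention can be illustrated with a numpy array standing in for an 8-band (bands, H, W) raster; rasterio's read(indexes=...) follows the same 1-based rule:

```python
import numpy as np

raster = np.arange(8 * 2 * 2).reshape(8, 2, 2)  # stand-in raster, shape (bands, H, W)
selected_bands = [1, 2, 3, 4]                   # 1-based indices, as written in the config

subset = raster[np.array(selected_bands) - 1]   # shift to numpy's 0-based indexing
print(subset.shape)  # (4, 2, 2)
```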

> **Tip:** Setting use_rasterio: true also works for standard 3-band GeoTIFFs when you need correct handling of geospatial metadata or when PIL cannot open the format.

Image Dtype

The image_dtype parameter controls how the pixel values are interpreted after loading via rasterio. It only affects the rasterio code path (i.e. when use_rasterio: true or selected_bands is set).

| Value | numpy dtype | Automatic normalization (no-transform path) | Typical use case |
| --- | --- | --- | --- |
| "uint8" (default) | np.uint8 | divided by 255.0 | Standard RGB / 8-bit GeoTIFF |
| "uint16" | np.uint16 | divided by 65535.0 | 16-bit satellite imagery (Sentinel-2, Landsat, WorldView) |
| "float32" | np.float32 | no division | Imagery already stored as normalized floats |
| "native" | unchanged | no division | Preserves the file's original dtype; useful for inspection |

The default "uint8" preserves the original behaviour, so existing configuration files do not need any change.

Transform vs no-transform path

When an augmentation_list is provided, normalization is fully handled by Albumentations (e.g. A.Normalize, A.ToFloat). In that case image_dtype only controls the cast applied to the numpy array before it enters the transform pipeline. The automatic division described in the table above applies only when augmentation_list is null.
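The no-transform scaling can be sketched as follows (scale_image is an illustrative helper reproducing the table above, not the project's actual function):

```python
import numpy as np

DIVISORS = {"uint8": 255.0, "uint16": 65535.0}

def scale_image(arr: np.ndarray, image_dtype: str) -> np.ndarray:
    # uint8/uint16 are divided by their dtype maximum; float32/native pass through unscaled.
    arr = arr.astype(np.float32)
    divisor = DIVISORS.get(image_dtype)
    return arr / divisor if divisor is not None else arr

img = np.array([0, 128, 255], dtype=np.uint8)
print(scale_image(img, "uint8").max())   # 1.0
print(scale_image(img, "native").max())  # 255.0
```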

Albumentations and uint16

Albumentations does not natively support uint16 arrays. When using image_dtype: uint16 with an augmentation pipeline, add A.ToFloat as the first transform to convert values to float32 in [0, 1]:

```yaml
augmentation_list:
  - _target_: albumentations.ToFloat
    max_value: 65535.0
  - _target_: albumentations.RandomCrop
    height: 256
    width: 256
  - _target_: albumentations.Normalize
    mean: [0.5, 0.5, 0.5, 0.4]
    std: [0.2, 0.2, 0.2, 0.15]
    max_pixel_value: 1.0  # values are already in [0, 1] after ToFloat
  - _target_: albumentations.pytorch.ToTensorV2
```

Full YAML Configuration Example

The following example shows a complete train / val / test dataset configuration for a binary building segmentation task.

val_dataset is monitored at the end of every training epoch (early stopping, checkpointing, LR scheduling). test_dataset is evaluated once after training completes via trainer.test() and its metrics are logged with a test/ prefix. Both are optional.

```yaml
# configs/dataset/train_val_test.yaml

train_dataset:
  _target_: pytorch_segmentation_models_trainer.dataset_loader.dataset.SegmentationDataset
  input_csv_path: /data/my_dataset/train.csv
  root_dir: /data/my_dataset
  n_classes: 2
  use_rasterio: false
  selected_bands: null
  augmentation_list:
    - _target_: albumentations.RandomCrop
      height: 512
      width: 512
    - _target_: albumentations.HorizontalFlip
      p: 0.5
    - _target_: albumentations.Normalize
      mean: [0.485, 0.456, 0.406]
      std: [0.229, 0.224, 0.225]
    - _target_: albumentations.pytorch.ToTensorV2
  data_loader:
    shuffle: true
    num_workers: 4
    pin_memory: true
    batch_size: 8
    drop_last: true

val_dataset:
  _target_: pytorch_segmentation_models_trainer.dataset_loader.dataset.SegmentationDataset
  input_csv_path: /data/my_dataset/val.csv
  root_dir: /data/my_dataset
  n_classes: 2
  use_rasterio: false
  selected_bands: null
  augmentation_list:
    - _target_: albumentations.Resize
      height: 512
      width: 512
    - _target_: albumentations.Normalize
      mean: [0.485, 0.456, 0.406]
      std: [0.229, 0.224, 0.225]
    - _target_: albumentations.pytorch.ToTensorV2
  data_loader:
    shuffle: false
    num_workers: 4
    pin_memory: true
    batch_size: 8
    drop_last: false

# Optional — when present, trainer.test() is called automatically after fit.
# Metrics are logged with the "test/" prefix.
test_dataset:
  _target_: pytorch_segmentation_models_trainer.dataset_loader.dataset.SegmentationDataset
  input_csv_path: /data/my_dataset/test.csv
  root_dir: /data/my_dataset
  n_classes: 2
  use_rasterio: false
  selected_bands: null
  augmentation_list:
    - _target_: albumentations.Resize
      height: 512
      width: 512
    - _target_: albumentations.Normalize
      mean: [0.485, 0.456, 0.406]
      std: [0.229, 0.224, 0.225]
    - _target_: albumentations.pytorch.ToTensorV2
  data_loader:
    shuffle: false
    num_workers: 4
    pin_memory: true
    batch_size: 8
    drop_last: false
```

Multispectral Example — 8-bit GeoTIFF

```yaml
train_dataset:
  _target_: pytorch_segmentation_models_trainer.dataset_loader.dataset.SegmentationDataset
  input_csv_path: /data/multispectral/train.csv
  root_dir: /data/multispectral
  n_classes: 5
  use_rasterio: true
  selected_bands: [1, 2, 3, 4]  # RGB + NIR
  image_dtype: uint8  # default; explicit for clarity
  augmentation_list:
    - _target_: albumentations.RandomCrop
      height: 256
      width: 256
    - _target_: albumentations.HorizontalFlip
      p: 0.5
    - _target_: albumentations.Normalize
      mean: [0.485, 0.456, 0.406, 0.4]
      std: [0.229, 0.224, 0.225, 0.2]
    - _target_: albumentations.pytorch.ToTensorV2
  data_loader:
    shuffle: true
    num_workers: 4
    pin_memory: true
    batch_size: 16
    drop_last: true
```

Multispectral Example — 16-bit GeoTIFF (Sentinel-2, Landsat)

```yaml
train_dataset:
  _target_: pytorch_segmentation_models_trainer.dataset_loader.dataset.SegmentationDataset
  input_csv_path: /data/sentinel2/train.csv
  root_dir: /data/sentinel2
  n_classes: 2
  use_rasterio: true
  selected_bands: [1, 2, 3, 4]
  image_dtype: uint16  # preserves 16-bit precision
  augmentation_list:
    - _target_: albumentations.ToFloat
      max_value: 65535.0  # converts uint16 → float32 in [0, 1]
    - _target_: albumentations.RandomCrop
      height: 256
      width: 256
    - _target_: albumentations.HorizontalFlip
      p: 0.5
    - _target_: albumentations.Normalize
      mean: [0.5, 0.5, 0.5, 0.4]
      std: [0.2, 0.2, 0.2, 0.15]
      max_pixel_value: 1.0  # values are already in [0, 1] after ToFloat
    - _target_: albumentations.pytorch.ToTensorV2
  data_loader:
    shuffle: true
    num_workers: 4
    pin_memory: true
    batch_size: 8
    drop_last: true
```

Dataset Output

__getitem__ returns a dictionary with the following keys:

| Key | Shape | dtype | Description |
| --- | --- | --- | --- |
| image | (C, H, W) | torch.float32 | Normalised image tensor |
| mask | (H, W) | torch.int64 | Class-index mask |

When augmentation_list is null, the image is automatically converted to a (C, H, W) float tensor. The normalization factor depends on image_dtype:

| image_dtype | Automatic scaling | Resulting range |
| --- | --- | --- |
| "uint8" | divided by 255.0 | [0, 1] |
| "uint16" | divided by 65535.0 | [0, 1] |
| "float32" | no scaling applied | unchanged |
| "native" | no scaling applied | unchanged |

The mask is always cast to torch.int64.
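Because each sample is a plain dict of tensors, PyTorch's default collate function batches each key independently. A sketch with dummy samples standing in for real dataset items:

```python
import torch
from torch.utils.data import DataLoader

# Four dummy samples shaped like the dataset output described above.
samples = [{"image": torch.rand(3, 64, 64),
            "mask": torch.zeros(64, 64, dtype=torch.int64)} for _ in range(4)]

loader = DataLoader(samples, batch_size=2)
batch = next(iter(loader))
print(batch["image"].shape)  # torch.Size([2, 3, 64, 64])
print(batch["mask"].dtype)   # torch.int64
```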


Folder-Based Dataset (No CSV Required)

When your images and masks are already organized in matching folder hierarchies, you can skip the CSV preparation step entirely and use SegmentationDatasetFromFolder. It scans both folders recursively with pathlib.Path.rglob() and automatically builds the image/mask pairs.

Matching Rules

A pair is valid when both conditions are met:

  1. Same relative subfolder — the path from the root to the file's parent directory is identical in both trees.
  2. Same file stem — the filename without extension is identical.

Files present in only one folder (or in a different subfolder) are silently excluded.
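The matching rules can be sketched with pathlib (an illustrative re-implementation of the pairing logic, not the class's actual code; match_pairs is a hypothetical name):

```python
from pathlib import Path

def match_pairs(image_folder, mask_folder, image_ext=".tif", mask_ext=".tif"):
    # Key = relative subfolder + stem, so both matching rules are enforced at once.
    image_folder, mask_folder = Path(image_folder), Path(mask_folder)
    masks = {p.relative_to(mask_folder).with_suffix(""): p
             for p in mask_folder.rglob(f"*{mask_ext}")}
    return [(img, masks[key])
            for img in sorted(image_folder.rglob(f"*{image_ext}"))
            if (key := img.relative_to(image_folder).with_suffix("")) in masks]
```

Files without a counterpart simply never appear in the returned list, matching the "silently excluded" behaviour.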

Folder Structure Example

```text
/data/my_dataset/
├── images/
│   ├── area_a/
│   │   ├── tile_001.tif
│   │   └── tile_002.tif
│   └── area_b/
│       └── tile_003.tif
└── masks/
    ├── area_a/
    │   ├── tile_001.tif   ← paired
    │   └── tile_002.tif   ← paired
    └── area_b/
        └── tile_003.tif   ← paired
```

Python Usage

```python
from pytorch_segmentation_models_trainer.dataset_loader.dataset import SegmentationDatasetFromFolder

ds = SegmentationDatasetFromFolder(
    image_folder="/data/my_dataset/images",
    mask_folder="/data/my_dataset/masks",
    image_extension=".tif",  # leading dot is optional
)

print(len(ds))              # number of matched pairs
item = ds[0]
print(item["image"].shape)  # (C, H, W) float32 tensor
print(item["mask"].shape)   # (H, W) int64 tensor
```

Different Extensions for Images and Masks

When images and masks have different file extensions, use the mask_extension argument:

```python
ds = SegmentationDatasetFromFolder(
    image_folder="/data/my_dataset/images",
    mask_folder="/data/my_dataset/masks",
    image_extension=".tif",
    mask_extension=".png",  # masks stored as PNG, images as GeoTIFF
)
```

Multispectral Support (SegmentationDatasetFromFolder)

All parameters of SegmentationDataset are available, including use_rasterio, selected_bands, and image_dtype:

```python
ds = SegmentationDatasetFromFolder(
    image_folder="/data/sentinel2/images",
    mask_folder="/data/sentinel2/masks",
    image_extension=".tif",
    mask_extension=".png",
    use_rasterio=True,
    selected_bands=[1, 2, 3, 4],  # RGB + NIR
    image_dtype="uint16",
    n_classes=2,
)
```

With Augmentation Pipeline

```python
import albumentations as A
from albumentations.pytorch import ToTensorV2

augmentation_list = [
    A.RandomCrop(height=512, width=512),
    A.HorizontalFlip(p=0.5),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
]

ds = SegmentationDatasetFromFolder(
    image_folder="/data/my_dataset/images",
    mask_folder="/data/my_dataset/masks",
    image_extension=".tif",
    augmentation_list=augmentation_list,
    n_classes=2,
)
```

Error Handling

A ValueError is raised at construction time if no valid pairs are found. This usually means:

  • The extension is wrong (e.g. searching for .tif in a folder of .png files).
  • The subfolder structure does not match between the image and mask trees.
  • One of the provided folders is empty or does not exist.
```python
try:
    ds = SegmentationDatasetFromFolder(
        image_folder="/data/images",
        mask_folder="/data/masks",
        image_extension=".tif",
    )
except ValueError as e:
    print(e)
    # Nenhum par imagem/máscara encontrado entre ...
    # (Portuguese: "No image/mask pair found between ...")
```

When to Use Each Approach

| Scenario | Recommended class |
| --- | --- |
| Pre-existing CSV index | SegmentationDataset |
| Structured folder hierarchy, no CSV needed | SegmentationDatasetFromFolder |
| Large full-scene images, on-the-fly random cropping | RandomCropSegmentationDataset |
| Systematic sliding-window evaluation | RasterPatchDataset |