Building Detection & Instance Segmentation Datasets

This guide covers the ObjectDetectionDataset and InstanceSegmentationDataset classes, which extend the base CSV-driven dataset system with bounding-box and instance mask support.

CSV Schema for Object Detection

The detection CSV requires at minimum an image column and a bounding_boxes column pointing to a JSON annotation file. A mask column is optional for detection-only tasks but required for instance segmentation.

| Column | Required | Description |
| --- | --- | --- |
| image | Yes | Path to the input image |
| bounding_boxes | Yes | Path to a JSON file containing bounding box annotations |
| mask / polygon_mask | Instance segmentation only | Path to the instance segmentation mask |

Example CSV

```csv
image,bounding_boxes
images/scene_001.tif,bounding_boxes/scene_001.json
images/scene_002.tif,bounding_boxes/scene_002.json
```

For instance segmentation:

```csv
image,polygon_mask,bounding_boxes
images/scene_001.tif,polygon_masks/scene_001.png,bounding_boxes/scene_001.json
images/scene_002.tif,polygon_masks/scene_002.png,bounding_boxes/scene_002.json
```
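An index like the one above can be generated with Python's standard csv module. This is a minimal sketch; the file paths are illustrative, and in practice you would collect them by globbing your image and annotation directories.

```python
import csv
import io

# Illustrative rows; in a real pipeline these would come from scanning
# the image, mask, and bounding-box directories on disk.
rows = [
    ("images/scene_001.tif", "polygon_masks/scene_001.png", "bounding_boxes/scene_001.json"),
    ("images/scene_002.tif", "polygon_masks/scene_002.png", "bounding_boxes/scene_002.json"),
]

buf = io.StringIO()
writer = csv.writer(buf)
writer.writerow(["image", "polygon_mask", "bounding_boxes"])  # header row
writer.writerows(rows)

csv_text = buf.getvalue()
print(csv_text)
```

Writing to a file instead of a StringIO buffer works the same way; pass `newline=""` to open() so the csv module controls line endings.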

Bounding Box JSON Format

Each bounding box JSON file is a JSON array of objects. Each object must have:

| Field | Type | Description |
| --- | --- | --- |
| bbox | array | Bounding box coordinates (see format below) |
| class | integer | Numeric class / category ID |

The bbox field holds coordinates in the format specified by the bbox_format parameter:

  • xywh (default / COCO format): [x_min, y_min, width, height]
  • xyxy: [x_min, y_min, x_max, y_max]
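The relationship between the two formats is a simple coordinate conversion. A small sketch (helper names are ours, not part of the library):

```python
def xywh_to_xyxy(box):
    """Convert [x_min, y_min, width, height] to [x_min, y_min, x_max, y_max]."""
    x, y, w, h = box
    return [x, y, x + w, y + h]

def xyxy_to_xywh(box):
    """Convert [x_min, y_min, x_max, y_max] to [x_min, y_min, width, height]."""
    x1, y1, x2, y2 = box
    return [x1, y1, x2 - x1, y2 - y1]

print(xywh_to_xyxy([120, 45, 80, 60]))  # -> [120, 45, 200, 105]
```

This is the conversion the dataset performs internally when bbox_format and bbox_output_format differ.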

Example JSON File

```json
[
  {
    "bbox": [120, 45, 80, 60],
    "class": 1
  },
  {
    "bbox": [300, 200, 55, 70],
    "class": 1
  },
  {
    "bbox": [500, 10, 100, 120],
    "class": 2
  }
]
```
> **Note:** The JSON field for the class label is "class", not "category_id". The dataset reads box_item["class"] for labels and box_item["bbox"] for coordinates.
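Given that contract, a valid annotation file can be produced with the standard json module. A sketch, writing to a temporary directory for illustration (in a real layout the file would live under bounding_boxes/ and be referenced from the CSV):

```python
import json
import pathlib
import tempfile

annotations = [
    {"bbox": [120, 45, 80, 60], "class": 1},
    {"bbox": [300, 200, 55, 70], "class": 1},
]

# Write the annotation file; the temporary directory stands in for
# the real bounding_boxes/ folder.
out_path = pathlib.Path(tempfile.mkdtemp()) / "scene_001.json"
out_path.write_text(json.dumps(annotations, indent=2))

# Reading it back the way the dataset does:
loaded = json.loads(out_path.read_text())
labels = [item["class"] for item in loaded]
boxes = [item["bbox"] for item in loaded]
```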

The ObjectDetectionDataset Class

Constructor Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| input_csv_path | Path | required | Path to the CSV index file |
| root_dir | str | None | Root directory prepended to relative CSV paths |
| augmentation_list | list | None | albumentations transforms (must include bbox_params) |
| data_loader | config | None | DataLoader keyword arguments |
| image_key | str | "image" | CSV column for image paths |
| mask_key | str | "mask" | CSV column for mask paths |
| bounding_box_key | str | "bounding_boxes" | CSV column pointing to the bounding box JSON file |
| n_first_rows_to_read | int | None | Limit on the number of CSV rows read |
| bbox_format | str | "xywh" | Format of bboxes stored in the JSON file ("xywh" or "xyxy") |
| bbox_output_format | str | "xyxy" | Format of bboxes returned in the output dict |
| bbox_params | config | None | albumentations BboxParams config for bbox-aware augmentation |

Dataset Output

__getitem__ returns a 3-tuple: (image, ds_item_dict, index)

| Element | Type | Description |
| --- | --- | --- |
| image | torch.Tensor (C, H, W) | The input image (loaded as RGB) |
| ds_item_dict | dict | Dictionary with boxes and labels tensors |
| index | int | Index of this sample in the dataset |

Keys in ds_item_dict:

| Key | Shape | dtype | Description |
| --- | --- | --- | --- |
| boxes | (N, 4) | torch.float32 | Bounding boxes in bbox_output_format |
| labels | (N,) | torch.int64 | Class index for each box |
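The shape and dtype contract above can be sketched directly with hypothetical annotation values (three xyxy boxes, made up for illustration):

```python
import torch

# Hypothetical annotations for one image: three boxes in xyxy format.
raw_boxes = [[120, 45, 200, 105], [300, 200, 355, 270], [500, 10, 600, 130]]
raw_labels = [1, 1, 2]

# The contract described above: float32 (N, 4) boxes, int64 (N,) labels.
ds_item_dict = {
    "boxes": torch.tensor(raw_boxes, dtype=torch.float32),
    "labels": torch.tensor(raw_labels, dtype=torch.int64),
}
```

This is the structure that detection heads such as torchvision's Faster R-CNN expect as per-image targets.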

Required: Custom collate_fn

Because different images have varying numbers of bounding boxes, you must use the dataset's built-in collate_fn with your DataLoader. The default PyTorch collate will fail when tensor shapes differ across samples.

```python
from torch.utils.data import DataLoader
from pytorch_segmentation_models_trainer.dataset_loader.dataset import ObjectDetectionDataset

dataset = ObjectDetectionDataset(input_csv_path="train.csv")
loader = DataLoader(
    dataset,
    batch_size=4,
    collate_fn=ObjectDetectionDataset.collate_fn,
)
```

The collate_fn stacks images into a single tensor and returns a list of per-image target dictionaries:

```
(images_tensor [B, C, H, W], List[dict], indexes_tensor [B])
```
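A minimal sketch of that behaviour, written from the description above (use the real ObjectDetectionDataset.collate_fn in practice; this stand-in only illustrates why stacking fails for targets but works for images):

```python
import torch

def collate_fn(batch):
    """Stack images into one tensor, keep per-image targets as a list."""
    images, targets, indexes = zip(*batch)
    return torch.stack(images, dim=0), list(targets), torch.tensor(indexes)

# Two samples with different box counts (2 vs 3): the default collate
# would fail here, because (2, 4) and (3, 4) tensors cannot be stacked.
sample_a = (
    torch.zeros(3, 64, 64),
    {"boxes": torch.zeros(2, 4), "labels": torch.zeros(2, dtype=torch.int64)},
    0,
)
sample_b = (
    torch.zeros(3, 64, 64),
    {"boxes": torch.zeros(3, 4), "labels": torch.zeros(3, dtype=torch.int64)},
    1,
)

images, targets, indexes = collate_fn([sample_a, sample_b])
```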

YAML Configuration Example

```yaml
# configs/dataset/detection_train.yaml

train_dataset:
  _target_: pytorch_segmentation_models_trainer.dataset_loader.dataset.ObjectDetectionDataset
  input_csv_path: /data/detection/train.csv
  root_dir: /data/detection
  bounding_box_key: bounding_boxes
  bbox_format: xywh # format stored in the JSON files
  bbox_output_format: xyxy # format returned by __getitem__
  augmentation_list:
    - _target_: albumentations.RandomCrop
      height: 512
      width: 512
    - _target_: albumentations.HorizontalFlip
      p: 0.5
    - _target_: albumentations.Normalize
      mean: [0.485, 0.456, 0.406]
      std: [0.229, 0.224, 0.225]
    - _target_: albumentations.pytorch.ToTensorV2
  bbox_params:
    format: coco # albumentations bbox format string
    label_fields: [labels]
    min_visibility: 0.1
  data_loader:
    shuffle: true
    num_workers: 4
    pin_memory: true
    batch_size: 4
    drop_last: true
```

The InstanceSegmentationDataset Class

InstanceSegmentationDataset extends ObjectDetectionDataset by additionally loading a per-instance segmentation mask and optionally keypoints.

Additional Constructor Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| mask_key | str | "polygon_mask" | CSV column for instance mask paths |
| return_mask | bool | True | Whether to load and return the mask |
| keypoint_key | str | "keypoints" | CSV column pointing to a keypoints JSON file |
| return_keypoints | bool | False | Whether to load and return keypoints |

Dataset Output

The output is the same 3-tuple as ObjectDetectionDataset, with ds_item_dict gaining an additional key:

| Key | Shape | dtype | Description |
| --- | --- | --- | --- |
| boxes | (N, 4) | torch.float32 | Bounding boxes |
| labels | (N,) | torch.int64 | Class labels |
| masks | (1, H, W) | torch.uint8 | Binary instance mask (if return_mask) |

YAML Configuration Example

```yaml
# configs/dataset/instance_seg_train.yaml

train_dataset:
  _target_: pytorch_segmentation_models_trainer.dataset_loader.dataset.InstanceSegmentationDataset
  input_csv_path: /data/instance_seg/train.csv
  root_dir: /data/instance_seg
  mask_key: polygon_mask
  bounding_box_key: bounding_boxes
  bbox_format: xywh
  bbox_output_format: xyxy
  return_mask: true
  return_keypoints: false
  augmentation_list:
    - _target_: albumentations.RandomCrop
      height: 512
      width: 512
    - _target_: albumentations.HorizontalFlip
      p: 0.5
    - _target_: albumentations.Normalize
      mean: [0.485, 0.456, 0.406]
      std: [0.229, 0.224, 0.225]
    - _target_: albumentations.pytorch.ToTensorV2
  bbox_params:
    format: coco
    label_fields: [labels]
    min_visibility: 0.1
  data_loader:
    shuffle: true
    num_workers: 4
    pin_memory: true
    batch_size: 4
    drop_last: true
```

Augmentation Notes

Because bounding boxes must be transformed consistently with the image, the bbox_params config is passed to albumentations.Compose when building the augmentation pipeline. Without this, augmentations like RandomCrop or HorizontalFlip would not update box coordinates.

Key BboxParams fields:

| Field | Description |
| --- | --- |
| format | Albumentations bbox string: "coco" (xywh), "pascal_voc" (xyxy), etc. |
| label_fields | List of field names in the transform dict that hold class labels |
| min_visibility | Minimum fraction of box area that must remain after a crop |
> **Tip:** Use format: coco in bbox_params when your JSON files store xywh coordinates; albumentations' "coco" format string is equivalent to xywh.