We provide a PyTorch dataset, `XMLDetectionDataset`, and a PyTorch Lightning datamodule, `XMLDetectionDataModule`, to load images and their XML (Pascal VOC-style) bounding-box annotations.

Overview

class XMLDetectionDataset

XMLDetectionDataset(*args, **kwds) :: VisionDataset

Pascal VOC (http://host.robots.ox.ac.uk/pascal/VOC/) detection dataset.

Args:

  • root (string): Root directory of the VOC-style dataset.
  • image_transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed version, e.g. transforms.RandomCrop.
  • target_transform (callable, optional): A function/transform that takes in the target and transforms it.
  • transform (callable, optional): A function/transform that takes the input sample and its target and returns a transformed version of both.
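The targets come from Pascal VOC-style XML files. As a minimal sketch using only the standard library, this is roughly how such a file turns into the [xmin, ymin, xmax, ymax] lists printed further down. `parse_voc_bboxes` is a hypothetical helper, not part of this package, and the real loader may differ (for example in which objects it keeps). The annotation string is made up, reusing the box printed later on this page.

```python
import xml.etree.ElementTree as ET

def parse_voc_bboxes(xml_text):
    """Return [[xmin, ymin, xmax, ymax], ...] from a Pascal VOC-style annotation (hypothetical helper)."""
    root = ET.fromstring(xml_text)
    boxes = []
    for obj in root.iter("object"):
        bnd = obj.find("bndbox")
        # VOC stores each corner as its own child element; values may be written as floats
        boxes.append([int(float(bnd.find(tag).text))
                      for tag in ("xmin", "ymin", "xmax", "ymax")])
    return boxes

# Made-up annotation for illustration only
example = """<annotation>
  <filename>example.jpg</filename>
  <object>
    <name>ball</name>
    <bndbox><xmin>385</xmin><ymin>3665</ymin><xmax>777</xmax><ymax>4058</ymax></bndbox>
  </object>
</annotation>"""

print(parse_voc_bboxes(example))  # [[385, 3665, 777, 4058]]
```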

Example usage of the XMLDetectionDataset

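from rich.progress import track  # progress bar for the loop below

data_dir = '../detection/data/train'  # directory with the .jpg images and their .xml annotations
dataset = XMLDetectionDataset(root=data_dir, image_transform=None, target_transform=None, transform=None)
print('Samples found:', len(dataset))

# Iterate once over every sample to check that all images and annotations load
for idx, (im, ll) in track(enumerate(dataset), total=len(dataset)):
    pass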

XMLDetectionDataset.draw_sample

XMLDetectionDataset.draw_sample(idx=None)

Draw the image at index `idx` together with its bounding box.

idx = 1
image, target = dataset[idx]

# Show a sample image with a bounding box
print('image filename:', dataset.image_files[idx])
print('xml filename:', dataset.xml_files[idx])
dataset.draw_sample(idx)
image filename: ../detection/data/train/10.jpg
xml filename: ../detection/data/train/10.xml
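The implementation of draw_sample is not shown here. As a rough, hypothetical illustration of what overlaying a pascal_voc box involves, this sketch outlines [xmin, ymin, xmax, ymax] directly on a NumPy image array. `draw_box` is an illustrative helper, not this package's API.

```python
import numpy as np

def draw_box(image, box, color=(255, 0, 0), thickness=2):
    """Return a copy of an HxWx3 uint8 image with a pascal_voc box outlined.

    Illustrative helper only; it is not the package's draw_sample implementation.
    """
    out = image.copy()
    h, w = out.shape[:2]
    # Truncate the corners to integer pixels and clip them to the image bounds
    x0, y0 = max(int(box[0]), 0), max(int(box[1]), 0)
    x1, y1 = min(int(box[2]), w), min(int(box[3]), h)
    if x1 <= x0 or y1 <= y0:
        return out  # no part of the box is visible
    t = thickness
    out[y0:min(y0 + t, y1), x0:x1] = color  # top edge
    out[max(y1 - t, y0):y1, x0:x1] = color  # bottom edge
    out[y0:y1, x0:min(x0 + t, x1)] = color  # left edge
    out[y0:y1, max(x1 - t, x0):x1] = color  # right edge
    return out

# Usage on a blank 10x10 RGB image
img = np.zeros((10, 10, 3), dtype=np.uint8)
boxed = draw_box(img, [2, 2, 8, 8], thickness=1)
```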

Get a single item from the dataset

For transforms that must act on both the image and its bounding boxes (crops, flips, resizes), use albumentations and pass the pipeline as `transform`.
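Joint transforms matter because flipping or cropping the image without updating the box leaves the box in the wrong place. The sketch shows the coordinate arithmetic for a horizontal flip of a pascal_voc box; albumentations applies the same mirroring (on normalized coordinates) in HorizontalFlip, but `hflip_bbox` is a hypothetical helper. Extra trailing elements, such as the class label appended by `target_transform` in the example that follows, are carried through unchanged.

```python
def hflip_bbox(box, image_width):
    """Mirror a pascal_voc box [xmin, ymin, xmax, ymax, *extras] horizontally (hypothetical helper)."""
    xmin, ymin, xmax, ymax, *extras = box
    # The old right edge becomes the new left edge, and vice versa; y is unchanged
    return [image_width - xmax, ymin, image_width - xmin, ymax, *extras]

print(hflip_bbox([10, 20, 50, 60, 'ball'], image_width=100))  # [50, 20, 90, 60, 'ball']
```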

import numpy as np
import albumentations as A

idx = 69

# Without a transform
dataset.image_transform = None
dataset.transform = None
dataset.target_transform = None
image, target = dataset[idx]
print('Image size (Without Transform):', image.shape)
print('Target (Without Transform):', target)

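# With a transform
dataset.image_transform = lambda x: np.asarray(x)  # PIL image -> HxWxC NumPy array
# Wrap the single box in a list of boxes and append the class label as a 5th element;
# BboxParams(format='pascal_voc') keeps such extra values attached to each box
dataset.target_transform = lambda x: [x + ['ball']]
dataset.transform = A.Compose([
        A.SmallestMaxSize(256),
        A.RandomCrop(width=224, height=224),
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.2),
    ], bbox_params=A.BboxParams(format='pascal_voc'))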

image, target = dataset[idx]
print('Image size (With Transform):', image.shape)
print('Target (With Transform):', target)
Image size (Without Transform): (4206, 2938)
Target (Without Transform): [385, 3665, 777, 4058]
Image size (With Transform): (224, 224, 3)
Target (With Transform): []
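The empty target after the transform is most likely expected behaviour rather than a bug: RandomCrop can cut the ball out of the frame, and boxes with no area left inside the crop are dropped. Below is a simplified sketch of that clip-and-filter step; albumentations additionally supports thresholds such as min_area and min_visibility in BboxParams. `crop_bbox` and `crop_bboxes` are hypothetical helpers.

```python
def crop_bbox(box, x0, y0, crop_w, crop_h):
    """Shift a pascal_voc box into crop coordinates; return None if nothing is left."""
    xmin, ymin, xmax, ymax, *extras = box
    # Translate into the crop's coordinate frame, then clip to its borders
    nxmin = min(max(xmin - x0, 0), crop_w)
    nymin = min(max(ymin - y0, 0), crop_h)
    nxmax = min(max(xmax - x0, 0), crop_w)
    nymax = min(max(ymax - y0, 0), crop_h)
    if nxmax <= nxmin or nymax <= nymin:
        return None  # the box lies entirely outside the crop
    return [nxmin, nymin, nxmax, nymax, *extras]

def crop_bboxes(boxes, x0, y0, crop_w, crop_h):
    """Apply crop_bbox to every box and keep only the ones that survive."""
    kept = (crop_bbox(b, x0, y0, crop_w, crop_h) for b in boxes)
    return [b for b in kept if b is not None]

# A box below a 224x224 crop taken at the top-left corner disappears...
print(crop_bboxes([[30, 300, 70, 350, 'ball']], 0, 0, 224, 224))    # []
# ...but survives, shifted, when the crop starts lower down
print(crop_bboxes([[30, 300, 70, 350, 'ball']], 0, 200, 224, 224))  # [[30, 100, 70, 150, 'ball']]
```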

To visualize the image after a transform, use the `visualize` function from `detection_nbdev.utils`:

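from fastai.vision.data import get_grid
from detection_nbdev.utils import visualize

def to_int_bboxes(target):
    # Drop the class label (5th element) and cast each box's coordinates to int
    return [list(map(int, box[:4])) for box in target]

# The transform is random, so each grid cell shows a different augmentation of the same sample
for ax in get_grid(9):
    image, target = dataset[idx]
    bboxes = to_int_bboxes(target)
    visualize(
        image,
        bboxes,
        [0] * len(bboxes),  # category id 0 for every box
        {0: 'ball'},
        ax=ax)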

class XMLDetectionDataModule

XMLDetectionDataModule(*args, **kwargs) :: LightningDataModule

A DataModule standardizes the training, validation, and test splits, data preparation, and transforms. The main advantage is consistent data splits, data preparation, and transforms across models.

Example:

from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader

class MyDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()

    def prepare_data(self):
        # download, split, etc...
        # only called on 1 GPU/TPU in distributed
        pass

    def setup(self, stage=None):
        # make assignments here (val/train/test split)
        # called on every process in DDP
        pass

    def train_dataloader(self):
        train_split = Dataset(...)  # placeholder: your training dataset
        return DataLoader(train_split)

    def val_dataloader(self):
        val_split = Dataset(...)  # placeholder: your validation dataset
        return DataLoader(val_split)

    def test_dataloader(self):
        test_split = Dataset(...)  # placeholder: your test dataset
        return DataLoader(test_split)

A DataModule implements 5 key methods:

  • prepare_data: things to do on one GPU/TPU only, not on every GPU/TPU, in distributed mode.
  • setup: things to do on every accelerator in distributed mode.
  • train_dataloader: the training dataloader.
  • val_dataloader: the validation dataloader(s).
  • test_dataloader: the test dataloader(s).

This lets you share a full dataset without having to explain how to download, split, transform, and process the data.

import numpy as np
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
from rich.progress import track

data_dir = '../detection/data/'

# np.array (not np.asarray) copies the PIL image into a writeable array, which avoids
# PyTorch's "NumPy array is not writeable" warning when ToTensorV2 converts it
image_transform = lambda x: np.array(x)
# Wrap the box in a list of boxes and append the class label; images without a box get []
target_transform = lambda x: [] if x is None else [x + ['ball']]
transform = A.Compose([
        A.SmallestMaxSize(224),
        A.CenterCrop(width=224, height=224),
        ToTensorV2(),  # HWC NumPy array -> CHW torch tensor
    ], bbox_params=A.BboxParams(format='pascal_voc'))

dm = XMLDetectionDataModule(data_dir,
                            r_train=0.6,
                            r_val=0.2,
                            r_test=0.2,
                            image_transform=image_transform,
                            target_transform=target_transform,
                            transform=transform)
dm.setup(mode='use_dir')
print(len(dm.trainset), len(dm.valset), len(dm.testset))

# Every sample in every split should come out as a 3x224x224 tensor
for name, split in [('train', dm.trainset), ('val', dm.valset), ('test', dm.testset)]:
    for idx, (im, ll) in track(enumerate(split), total=len(split), description=name):
        assert im.shape == (3, 224, 224), (name, idx, im.shape)
2676 892 893
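The printed sizes 2676, 892 and 893 add up to 4461 and are consistent with flooring the train and validation shares and giving the remainder to the test split. The sketch below reproduces that arithmetic with a seeded shuffle. Whether XMLDetectionDataModule splits exactly this way (or, with mode='use_dir', reads pre-split directories) is an assumption, and `split_sizes` / `split_indices` are hypothetical helpers.

```python
import random

def split_sizes(n, r_train, r_val):
    """Hypothetical split rule: floor the train/val shares, give the rest to test."""
    n_train = int(r_train * n)
    n_val = int(r_val * n)
    return n_train, n_val, n - n_train - n_val

def split_indices(n, r_train, r_val, seed=0):
    """Shuffle indices 0..n-1 reproducibly and cut them into train/val/test lists."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    n_train, n_val, _ = split_sizes(n, r_train, r_val)
    return idx[:n_train], idx[n_train:n_train + n_val], idx[n_train + n_val:]

print(split_sizes(4461, 0.6, 0.2))  # (2676, 892, 893)
```

Flooring the first two shares and assigning the remainder to the last one guarantees that every sample lands in exactly one split, even when the ratios do not divide the dataset size evenly.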
for x, y in dm.train_dataloader():
    pass
for x, y in dm.val_dataloader():
    pass
for x, y in dm.test_dataloader():
    pass
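If the datamodule relies on PyTorch's default collation, batching can fail once images carry different numbers of boxes (for example when a crop removes the only box), because the default collate function requires every target in a batch to have the same length. A common workaround, also used in torchvision's detection reference scripts, is a collate function that keeps images and targets as lists. The sketch below is hypothetical; whether XMLDetectionDataModule exposes a way to pass it is not documented here, but it can be given to `torch.utils.data.DataLoader` via its `collate_fn` argument.

```python
def detection_collate(batch):
    """Keep images and per-image targets as lists instead of stacking them (hypothetical helper)."""
    images, targets = zip(*batch)
    return list(images), list(targets)

# Usage with stand-in samples: one image with a box, one whose box was cropped away
batch = [("img0", [[1, 2, 3, 4, 'ball']]), ("img1", [])]
images, targets = detection_collate(batch)
```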