We provide a PyTorch dataset, `XMLDetectionDataset`, and a PyTorch Lightning datamodule, `XMLDetectionDataModule`, for loading images together with their XML bounding-box annotations.
Example usage of `XMLDetectionDataset`:
data_dir = '../data/train'  # path to data
dataset = XMLDetectionDataset(root=data_dir, image_transform=None, target_transform=None, transform=None)
print('Samples found:', len(dataset))

# Smoke test: load every sample once. Any sample that fails to load (or is not
# an (image, target) pair) raises here instead of later during training.
# NOTE(review): `track` is not imported in this chunk (presumably
# `rich.progress.track`); confirm it is imported earlier in the notebook.
for idx, (im, ll) in track(enumerate(dataset), total=len(dataset)):
    pass
# Inspect a single sample.
idx = 1
image, target = dataset[idx]

# Report which files the sample was loaded from, then draw the image with its bounding box.
image_file = dataset.image_files[idx]
print('image filename:', image_file)
xml_file = dataset.xml_files[idx]
print('xml filename:', xml_file)
dataset.draw_sample(idx)
For transforms that must act on both the image and its bounding boxes, use albumentations.
import numpy as np
import albumentations as A
import cv2
# NOTE(review): assumes the dataset has at least 70 samples.
idx = 69
# Without a transform: clear all three transforms before loading the sample.
dataset.image_transform = None
dataset.transform = None
dataset.target_transform = None
image, target = dataset[idx]
# With no image_transform the raw image may not be a numpy array (the later
# cells wrap it with np.asarray), so convert before reading `.shape`.
# np.asarray is a no-op for an image that is already an ndarray.
print('Image size (Without Transform):', np.asarray(image).shape)
print('Target (Without Transform):', target)
# With a transform.
# Convert the image to a numpy array so albumentations can operate on it.
dataset.image_transform = lambda x: np.asarray(x)
# Append the class label to the box: with no `label_fields` configured,
# albumentations takes labels as trailing values of each pascal_voc box.
# A sample without a box (None) becomes an empty box list, matching the
# datamodule's target_transform, instead of raising a TypeError.
dataset.target_transform = lambda x: [] if x is None else [x + ['ball']]
dataset.transform = A.Compose([
    A.SmallestMaxSize(256),
    A.RandomCrop(width=224, height=224),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
], bbox_params=A.BboxParams(format='pascal_voc'))
image, target = dataset[idx]
print('Image size (With Transform):', image.shape)
print('Target (With Transform):', target)
To visualize an image after a transform, use the `visualize` function from `detection_nbdev.utils`:
from detection_nbdev.utils import visualize
# NOTE(review): `fastai` is not imported in this chunk; confirm it is imported
# earlier in the notebook (e.g. `import fastai.vision.data`).


def f(x):
    """Return the first box of target `x` as integer pascal_voc coordinates.

    `x` is expected to be a transformed target: a list of boxes, each
    `[x_min, y_min, x_max, y_max, label]`. An empty target is returned unchanged.
    """
    # Use the argument itself; the previous version read the global `target`.
    return x if x == [] else list(map(int, x[0][:4]))


# Draw the same sample on a 3x3 grid; any random transforms on `dataset`
# produce a different view in each cell.
for ax in fastai.vision.data.get_grid(9):
    image, target = dataset[idx]
    box = f(target)
    # When the crop removed the box, pass no boxes at all rather than one
    # empty box paired with a category id.
    boxes = [box] if box else []
    visualize(
        image,
        boxes,
        [0] * len(boxes),
        {0: 'ball'},
        ax=ax)
import numpy as np
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
data_dir = '../detection/data/'

# Per-sample transforms handed to the datamodule.
image_transform = lambda x: np.asarray(x)                          # image -> numpy array
target_transform = lambda x: [] if x is None else [x + ['ball']]   # box -> [box + label]; None -> []
transform = A.Compose(
    [
        A.SmallestMaxSize(224),
        A.CenterCrop(width=224, height=224),
        ToTensorV2(),
    ],
    bbox_params=A.BboxParams(format='pascal_voc'),
)

# Train / val / test split ratios.
split_ratios = dict(r_train=0.6, r_val=0.2, r_test=0.2)
dm = XMLDetectionDataModule(
    data_dir,
    **split_ratios,
    image_transform=image_transform,
    target_transform=target_transform,
    transform=transform,
)
dm.setup(mode='use_dir')
print(len(dm.trainset), len(dm.valset), len(dm.testset))
# Every sample of every split must come out of the transform pipeline as a
# (3, 224, 224) image; the assertion message names the split, index and shape.
# NOTE(review): `track` is not imported in this chunk (presumably
# `rich.progress.track`); confirm it is imported earlier in the notebook.
for split_name, split in (('train', dm.trainset), ('val', dm.valset), ('test', dm.testset)):
    for idx, (im, ll) in track(enumerate(split), total=len(split)):
        assert im.shape == (3, 224, 224), (split_name, idx, im.shape)
# Iterate each dataloader once so batching/collation errors surface here.
# Each loader is created just before it is iterated, as in the original cells.
for make_loader in (dm.train_dataloader, dm.val_dataloader, dm.test_dataloader):
    for x, y in make_loader():
        pass