# Augmentation pipeline using Albumentations
import albumentations as A
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2

tform_train = A.Compose([
    A.RandomSizedBBoxSafeCrop(width=128, height=128, erosion_rate=0.2),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),  # transpose_mask = True
], bbox_params=A.BboxParams(format='pascal_voc',
                            label_fields=['class_labels', 'bbox_ids'],
                            min_area=25, min_visibility=0.6))
# 'label_fields' names the per-box fields ('class_labels' and 'bbox_ids') that
# Albumentations filters in sync with the bounding boxes: when a box is dropped,
# its label and id are dropped with it.
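
# A minimal sketch (hypothetical shapes and values) of how that filtering behaves:
#
#   out = tform_train(image=np.zeros((256, 256, 3), dtype=np.uint8),
#                     masks=[np.zeros((256, 256), dtype=np.uint8)],
#                     bboxes=[[10, 10, 50, 50]],
#                     bbox_ids=np.arange(1),
#                     class_labels=[1])
#
# If the random crop pushes a box below min_area or min_visibility, the box is
# removed and out['bboxes'], out['class_labels'], and out['bbox_ids'] all shrink
# together, so out['bbox_ids'] records which of the original boxes survived.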

# Transformation function for pre-processing the hub sample before sending it to the model
def transform(sample_in):
    # Convert boxes to Pascal VOC format
    boxes = coco_2_pascal(sample_in['boxes'])

    # Convert any grayscale images to RGB: a single-channel image is repeated
    # 3 times along the channel axis; a 3-channel image passes through unchanged
    images = sample_in['images']
    if images.shape[2] == 1:
        images = np.repeat(images, 3, axis=2)

    # Pass all data to the Albumentations transformation.
    # Masks must be passed as a list of 2D arrays.
    transformed = tform_train(image=images,
                              masks=[sample_in['masks'][:, :, i].astype(np.uint8) for i in range(sample_in['masks'].shape[2])],
                              bboxes=boxes,
                              bbox_ids=np.arange(boxes.shape[0]),
                              class_labels=sample_in['categories'])
    # Convert boxes and labels from lists to torch tensors, because Albumentations does not do that automatically.
    # Be very careful with rounding and casting to integers, because that can create bounding boxes with invalid dimensions.
    labels_torch = torch.tensor(transformed['class_labels'], dtype=torch.int64)

    boxes_torch = torch.zeros((len(transformed['bboxes']), 4), dtype=torch.int64)
    for b, box in enumerate(transformed['bboxes']):
        boxes_torch[b, :] = torch.tensor(np.round(box))
    # Filter out the masks whose bounding boxes were dropped by the min_area and
    # min_visibility filters; transformed['bbox_ids'] holds the indices of the
    # surviving boxes, which also index the corresponding masks
    masks_torch = torch.zeros((len(transformed['bbox_ids']), transformed['image'].shape[1], transformed['image'].shape[2]), dtype=torch.int64)
    if len(transformed['bbox_ids']) > 0:
        masks_torch = torch.tensor(np.stack([transformed['masks'][i] for i in transformed['bbox_ids']], axis=0), dtype=torch.uint8)
    # Put annotations in a separate object
    target = {'masks': masks_torch, 'labels': labels_torch, 'boxes': boxes_torch}

    return transformed['image'], target
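
# A sketch of how one transformed sample could be consumed, assuming a
# torchvision-style detection model (e.g. Mask R-CNN) that takes a list of
# images and a list of target dicts in training mode; 'model' and 'sample_in'
# are placeholders here:
#
#   image, target = transform(sample_in)
#   loss_dict = model([image], [target])  # returns a dict of training losses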

# Conversion function for bounding boxes from COCO to Pascal VOC format
def coco_2_pascal(boxes):
    # COCO boxes are [x_min, y_min, width, height]; Pascal VOC boxes are
    # [x_min, y_min, x_max, y_max]. Width and height are clipped to at least 1
    # so that every converted box has positive extent.
    return np.stack((boxes[:, 0],
                     boxes[:, 1],
                     boxes[:, 0] + np.clip(boxes[:, 2], 1, None),
                     boxes[:, 1] + np.clip(boxes[:, 3], 1, None)), axis=1)
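
# Worked example (made-up numbers): a COCO box [x=10, y=20, w=30, h=40]
# becomes the Pascal VOC box [10, 20, 40, 60]:
#
#   coco_2_pascal(np.array([[10, 20, 30, 40]]))  # -> array([[10, 20, 40, 60]])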

# Collate function for batching: turns a batch of (image, target) pairs into
# a tuple of images and a tuple of targets, since detection targets have
# variable sizes and cannot be stacked into a single tensor
def collate_fn(batch):
    return tuple(zip(*batch))
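
# A minimal usage sketch, assuming 'ds_train' is a Hub dataset exposing the
# .pytorch() dataloader helper (the dataset name, batch size, and worker count
# here are illustrative assumptions):
#
#   train_loader = ds_train.pytorch(num_workers=2, shuffle=True, batch_size=2,
#                                   transform=transform, collate_fn=collate_fn)
#   images, targets = next(iter(train_loader))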