Training an Object Detection and Segmentation Model in PyTorch
Training an object detection and segmentation model is a great way to learn about the complex data preprocessing required for training models.

How to train an object detection and instance segmentation model in PyTorch using Hub

This tutorial is also available as a Colab Notebook.

The primary objective for Hub is to enable users to manage their data more easily so they can train better ML models. This tutorial shows you how to train an object detection and instance segmentation model while streaming data from a Hub dataset stored in the cloud.
Since these models are often complex, this tutorial focuses on the data preprocessing needed to connect the data to the model. The user should take additional steps to scale up the code for logging, collecting additional metrics, model testing, and running on GPUs.
This tutorial is inspired by this PyTorch tutorial on training object detection and segmentation models.

Data Preprocessing

The first step is to select a dataset for training. This tutorial uses the COCO dataset that has already been converted into hub format. It is a multi-modal image dataset that contains bounding boxes, segmentation masks, keypoints, and other data.
import hub
import numpy as np
import math
import sys
import time
import torchvision
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
import torchvision.models.detection.mask_rcnn

# Connect to the training dataset
ds_train = hub.load('hub://activeloop/coco-train')
Note that the dataset can be visualized at the link printed by the hub.load command above.
We extract the number of classes for use later:
num_classes = len(ds_train.categories.info.class_names)
For complex datasets like this one, it's critical to carefully define the pre-processing function that returns the torch tensors that are used for training. Here we use an Albumentations augmentation pipeline combined with additional pre-processing steps that are necessary for this particular model.
# Augmentation pipeline using Albumentations
tform_train = A.Compose([
    A.RandomSizedBBoxSafeCrop(width=128, height=128, erosion_rate=0.2),
    A.HorizontalFlip(p=0.5),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(), # transpose_mask = True
], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['class_labels', 'bbox_ids'], min_area=25, min_visibility=0.6)) # 'class_labels' and 'bbox_ids' are the label fields that are dropped alongside a bounding box when it is cut


# Transformation function for pre-processing the Hub sample before sending it to the model
def transform(sample_in):

    # Convert boxes to Pascal VOC format
    boxes = coco_2_pascal(sample_in['boxes'])

    # Convert any grayscale images to RGB
    images = sample_in['images']
    if images.shape[2] == 1:
        images = np.repeat(images, 3, axis=2)

    # Pass all data to the Albumentations transformation
    # Masks must be converted to a list
    transformed = tform_train(image=images,
                              masks=[sample_in['masks'][:, :, i].astype(np.uint8) for i in range(sample_in['masks'].shape[2])],
                              bboxes=boxes,
                              bbox_ids=np.arange(boxes.shape[0]),
                              class_labels=sample_in['categories'],
                              )

    # Convert boxes and labels from lists to torch tensors, because Albumentations does not do that automatically.
    # Be very careful with rounding and casting to integers, because that can create bounding boxes with invalid dimensions.
    labels_torch = torch.tensor(transformed['class_labels'], dtype=torch.int64)

    boxes_torch = torch.zeros((len(transformed['bboxes']), 4), dtype=torch.int64)
    for b, box in enumerate(transformed['bboxes']):
        boxes_torch[b, :] = torch.tensor(np.round(box))

    # Filter out the masks that were dropped by the bounding box area and visibility filters
    masks_torch = torch.zeros((len(transformed['bbox_ids']), transformed['image'].shape[1], transformed['image'].shape[2]), dtype=torch.int64)
    if len(transformed['bbox_ids']) > 0:
        masks_torch = torch.tensor(np.stack([transformed['masks'][i] for i in transformed['bbox_ids']], axis=0), dtype=torch.uint8)

    # Put annotations in a separate object
    target = {'masks': masks_torch, 'labels': labels_torch, 'boxes': boxes_torch}

    return transformed['image'], target


# Conversion script for bounding boxes from COCO to Pascal VOC format
def coco_2_pascal(boxes):
    # Convert bounding boxes to Pascal VOC format and clip widths and heights so they are non-negative
    return np.stack((boxes[:, 0], boxes[:, 1], boxes[:, 0] + np.clip(boxes[:, 2], 1, None), boxes[:, 1] + np.clip(boxes[:, 3], 1, None)), axis=1)


def collate_fn(batch):
    return tuple(zip(*batch))
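To make the box conversion concrete, here is a small sketch of `coco_2_pascal` with made-up box values: a COCO-format box `[x, y, width, height]` becomes a Pascal VOC box `[x_min, y_min, x_max, y_max]`, with the width and height clipped to at least 1 pixel.

```python
import numpy as np

def coco_2_pascal(boxes):
    # Same conversion as above: [x, y, w, h] -> [x_min, y_min, x_max, y_max],
    # clipping widths and heights to at least 1 pixel
    return np.stack((boxes[:, 0], boxes[:, 1],
                     boxes[:, 0] + np.clip(boxes[:, 2], 1, None),
                     boxes[:, 1] + np.clip(boxes[:, 3], 1, None)), axis=1)

# Hypothetical boxes: one normal, one with zero width that gets clipped to 1
coco_boxes = np.array([[10., 20., 30., 40.],
                       [ 5.,  5.,  0., 10.]])
pascal_boxes = coco_2_pascal(coco_boxes)
print(pascal_boxes)
# [[10. 20. 40. 60.]
#  [ 5.  5.  6. 15.]]
```

The clipping matters because a degenerate (zero-width or zero-height) box would otherwise be rejected downstream by Albumentations or the model.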
You can now create a PyTorch dataloader that connects the Hub dataset to the PyTorch model using the provided method ds.pytorch(). This method automatically applies the transformation function and takes care of random shuffling (if desired). The num_workers parameter can be used to parallelize data preprocessing, which is critical for ensuring that preprocessing does not bottleneck the overall training workflow.
Since the dataset contains many tensors that are not used for training, a list of tensors for loading is specified in order to avoid streaming of unused data.
batch_size = 8

train_loader = ds_train.pytorch(num_workers=2, shuffle=False,
                                tensors=['images', 'masks', 'categories', 'boxes'], # Specify the tensors that are needed, so we don't load unused data
                                transform=transform,
                                batch_size=batch_size,
                                collate_fn=collate_fn)
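The `collate_fn` passed to the dataloader simply transposes the batch: a list of `(image, target)` pairs becomes one tuple of images and one tuple of targets, which is the input shape torchvision detection models expect. A minimal illustration with placeholder values standing in for the image tensors and target dictionaries:

```python
def collate_fn(batch):
    # Transpose a list of (image, target) pairs into (images, targets) tuples
    return tuple(zip(*batch))

# Hypothetical batch of two samples (strings stand in for image tensors)
batch = [('image_0', {'labels': [1]}), ('image_1', {'labels': [2, 3]})]
images, targets = collate_fn(batch)
print(images)   # ('image_0', 'image_1')
print(targets)  # ({'labels': [1]}, {'labels': [2, 3]})
```

Unlike the default collate function, this does not stack the images into a single tensor, so images within a batch are allowed to have different sizes.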

Model Definition

This tutorial uses a pre-trained torchvision neural network from the torchvision.models module.
Training is performed on a GPU if one is available; otherwise it runs on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)
# Helper function for loading the model
def get_model_instance_segmentation(num_classes):
    # Load an instance segmentation model pre-trained on COCO
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

    # Get the number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # Replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    # Replace the mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
                                                       hidden_layer,
                                                       num_classes)

    return model
Let's initialize the model and optimizer.
model = get_model_instance_segmentation(num_classes)

model.to(device)

# Specify the optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005,
                            momentum=0.9, weight_decay=0.0005)

Training the Model

Helper functions for training and testing the model are defined. Note that the output from Hub's PyTorch dataloader is fed into the model just like data from ordinary PyTorch dataloaders.
# Helper function for training for 1 epoch
def train_one_epoch(model, optimizer, data_loader, device):
    model.train()

    start_time = time.time()
    for i, data in enumerate(data_loader):

        images = list(image.to(device) for image in data[0])
        targets = [{k: v.to(device) for k, v in t.items()} for t in data[1]]

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        loss_value = losses.item()

        # Print performance statistics
        batch_time = time.time()
        speed = (i+1)/(batch_time-start_time)
        print('[%5d] loss: %.3f, speed: %.2f' %
              (i, loss_value, speed))

        if not math.isfinite(loss_value):
            print(f"Loss is {loss_value}, stopping training")
            print(loss_dict)
            break

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
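In training mode, torchvision detection models return a dictionary of individual loss terms rather than predictions, and the `losses = sum(...)` line above combines them into the single scalar that `.backward()` is called on. A sketch with made-up loss values (the real dictionary holds scalar torch tensors under keys such as `loss_classifier` and `loss_mask`):

```python
# Hypothetical loss dictionary returned by a Mask R-CNN forward pass in train mode;
# the values here are illustrative, not real training output
loss_dict = {'loss_classifier': 0.48, 'loss_box_reg': 0.31,
             'loss_mask': 0.27, 'loss_objectness': 0.05, 'loss_rpn_box_reg': 0.04}

# Sum all terms into the single value used for backpropagation
losses = sum(loss for loss in loss_dict.values())
print(round(losses, 2))  # 1.15
```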
The model and data are ready for training πŸš€!
# Train the model for 1 epoch
num_epochs = 1
for epoch in range(num_epochs):  # loop over the dataset multiple times
    print("------------------ Training Epoch {} ------------------".format(epoch+1))
    train_one_epoch(model, optimizer, train_loader, device)

    # --- Insert Testing Code Here ---

print('Finished Training')
Congrats! You successfully trained an object detection and instance segmentation model while streaming data directly from the cloud! πŸŽ‰