Training Models Using PyTorch Lightning
How to Train models using Deep Lake and PyTorch Lightning
Deep Lake's PyTorch integration can also be used to train models with PyTorch Lightning, a popular open-source, high-level interface for PyTorch.
Overview
At a high level, Deep Lake is connected to PyTorch Lightning by passing Deep Lake's PyTorch dataloader to any PyTorch Lightning API that expects a dataloader parameter, such as trainer.fit(..., train_dataloaders = deeplake_dataloader). The only caveat is that Deep Lake handles distributed training via its distributed parameter in the .pytorch() method. Therefore, the PyTorch Lightning Trainer class should be initialized with replace_sampler_ddp = False (in Lightning 2.0 and later, this argument is named use_distributed_sampler).
Example Code
This tutorial uses PyTorch Lightning to execute the identical training workflow that is shown here in PyTorch. It is also available as a Colab Notebook.
Data Preprocessing
The first step is to load the dataset for training. This tutorial uses the Fashion MNIST dataset, which has already been converted into Deep Lake format. It is a simple image classification dataset that categorizes images by clothing type (trouser, shirt, etc.).
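A sketch of the loading step, assuming the Deep Lake 3.x API, where deeplake.load opens a dataset by path. The hub:// paths below are assumed locations of the Activeloop-hosted train and test splits, and load_fashion_mnist is a hypothetical helper; deeplake is imported inside it so the snippet can be read and imported without Deep Lake installed.

```python
# Assumed locations of the hosted Fashion MNIST splits (Deep Lake 3.x paths).
FASHION_MNIST_PATHS = {
    "train": "hub://activeloop/fashion-mnist-train",
    "test": "hub://activeloop/fashion-mnist-test",
}


def load_fashion_mnist(split="train"):
    """Hypothetical helper: open one split of the dataset by path."""
    import deeplake  # lazy import: only needed when the data is actually loaded

    return deeplake.load(FASHION_MNIST_PATHS[split])


# Usage (requires network access and the deeplake package):
# ds_train = load_fashion_mnist("train")
# ds_val = load_fashion_mnist("test")
```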
The next step is to define a transformation function that will process the data and convert it into a format that can be passed into a deep learning model. In this particular example, torchvision.transforms is used as a part of the transformation pipeline that performs operations such as normalization and image augmentation (rotation).
You can now create a PyTorch dataloader that connects the Deep Lake dataset to the PyTorch model using the provided ds.pytorch() method. This method automatically applies the transformation function and takes care of random shuffling (if desired). The num_workers parameter can be used to parallelize data preprocessing, which is critical for ensuring that preprocessing does not bottleneck the overall training workflow.
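A sketch of the dataloader call, assuming the Deep Lake 3.x ds.pytorch() keyword arguments num_workers, batch_size, transform and shuffle, and assuming the dataset's tensors are named images and labels. make_dataloader is a hypothetical wrapper and the default values are arbitrary.

```python
def make_dataloader(ds, transform, batch_size=32, num_workers=2, shuffle=True):
    """Hypothetical wrapper around Deep Lake 3.x's ds.pytorch()."""
    return ds.pytorch(
        num_workers=num_workers,  # parallel workers for fetching and transforming samples
        batch_size=batch_size,
        # Tensor name -> transform; None converts to a torch tensor unchanged.
        transform={"images": transform, "labels": None},
        shuffle=shuffle,          # shuffle samples while streaming
    )


# Usage, assuming ds_train / ds_val and tform from the previous steps:
# train_loader = make_dataloader(ds_train, tform, shuffle=True)
# val_loader = make_dataloader(ds_val, tform, shuffle=False)
```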
The transform input is a dictionary where the key is the tensor name and the value is the transformation function that should be applied to that tensor. If a specific tensor's data does not need to be returned, it should be omitted from the keys. If the transformation function is set as None, the input tensor is converted to a torch tensor without additional modification.
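The three rules for the transform dictionary can be restated in a few lines of plain PyTorch. The helper below is hypothetical and only mirrors the documented behavior for a single sample; it is not how Deep Lake applies transforms internally, and the tensor names (images, labels, ids) are made up for the example.

```python
import torch


def apply_transform_dict(sample, transform):
    """Hypothetical illustration of the transform-dict rules (not Deep Lake's code)."""
    out = {}
    for name, fn in transform.items():
        value = sample[name]
        # None means "convert to a torch tensor without further modification".
        out[name] = torch.as_tensor(value) if fn is None else fn(value)
    # Tensors that are not keys of `transform` are simply not returned.
    return out


sample = {"images": [[0, 255], [255, 0]], "labels": [3], "ids": [42]}
batch = apply_transform_dict(
    sample,
    {"images": lambda x: torch.tensor(x, dtype=torch.float32) / 255, "labels": None},
)
```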
Model and LightningModule Definition
This tutorial uses a pre-trained ResNet18 neural network from the torchvision.models module, converted to a single-channel network for grayscale images. The LightningModule organizes the training code.
Training the Model
PyTorch Lightning takes care of the training loop, so the remaining steps are to initialize the Trainer and call the .fit() method using the training and validation dataloaders.
Congrats! You successfully trained an image classification model using PyTorch Lightning while streaming data directly from the cloud! 🎉