ds.pytorch(). If your model training is highly sensitive to the randomization of the input data, please pre-shuffle the data, or explore our writeup on Shuffling in ds.pytorch().

The transform parameter in ds.pytorch() is a dictionary where the key is the tensor name and the value is the transformation function for that tensor. If a tensor's data does not need to be returned, the tensor should be omitted from the keys. If a tensor's data does not need to be modified during preprocessing, the transformation function for that tensor should be set to None.

transforms.Lambda(lambda x: x.repeat(int(3/x.shape[0]), 1, 1))
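The key/value rules for the transform dictionary can be illustrated without Deep Lake or PyTorch installed. The sketch below is not the library's implementation: apply_transform is a hypothetical helper, and the tensor names images, labels, and boxes are assumptions chosen for illustration. It only mimics the documented behavior: tensors missing from the keys are not returned, and a value of None passes the data through unchanged.

```python
from typing import Any, Callable, Dict, Optional

Transform = Dict[str, Optional[Callable[[Any], Any]]]


def apply_transform(sample: Dict[str, Any], transform: Transform) -> Dict[str, Any]:
    """Hypothetical helper that mimics the documented transform semantics."""
    out = {}
    for name, fn in transform.items():
        value = sample[name]
        # None means "return this tensor's data unmodified".
        out[name] = value if fn is None else fn(value)
    return out


# Hypothetical sample; plain Python values stand in for tensor data.
sample = {"images": [0, 128, 255], "labels": 7, "boxes": [1, 2, 3, 4]}

transform = {
    "images": lambda px: [p / 255 for p in px],  # modified during preprocessing
    "labels": None,                              # returned unchanged
    # "boxes" is omitted from the keys, so it is not returned
}

result = apply_transform(sample, transform)
print(sorted(result))
```

A dictionary of the same shape is what would be passed as the transform argument of ds.pytorch(), with real transformation functions in place of the toy lambda.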
ds object to the PyTorch Dataset's constructor and pulling data in the __getitem__ method using self.ds.image[ids].numpy().
ds.tensorflow(). Downstream, functions from the tf.data API such as map, shuffle, etc. can be applied to process the data before training.