In this RFC, I propose a new feature in TVM called the DataLoader. The DataLoader class is a thin wrapper around the dataset classes of other machine learning frameworks.
Motivation:
There is a wide variety of datasets in the machine learning framework ecosystem, and each has its own API. Since TVM does not have its own datasets, we must write code that uses datasets from other frameworks (most often PyTorch, TensorFlow, Keras, and MXNet). Because these datasets all have different APIs, it is difficult to write generalized code that uses a dataset without assuming it comes from one of these specific frameworks.
For example, when quantizing Relay models using data-aware quantization, it is useful to have a unified API that wraps datasets so we don't have to handle each type of dataset separately.
Existing APIs:
The most popular frameworks have very different APIs for their datasets.
PyTorch datasets support indexing by overriding __getitem__, so you can index directly into the dataset object:
for i in range(len(dataset)):
    sample = dataset[i]
    # Do something with the data
MXNet provides the option to create a Python iterable out of the dataset:
iterable_dataset = iter(dataset)
for data in iterable_dataset:
    sample = data.asnumpy()
    # Do something with the data
TensorFlow is very similar to MXNet; however, converting individual datapoints to NumPy requires using the .numpy() method instead of the .asnumpy() method, so it would look something like this:
iterable_dataset = iter(dataset)
for data in iterable_dataset:
    sample = data.numpy()
    # Do something with the data
Keras datasets are provided as a single large NumPy array, so to use a Keras dataset you have to iterate over the batch axis:
for i in range(0, len(dataset), batch_size):
    data = dataset[i : i + batch_size]
    # Do something with the data
Proposed solution:
I propose writing a class that iterates over an existing dataset and also exposes the batch size and the number of batches in the dataset, information that is useful to software that consumes the DataLoader.
Here is the abstract DataLoader class (subclasses will implement the DataLoader for each framework).
class DataLoader:
    """Wrapper class for data loader or dataset classes implemented by other machine learning
    frameworks. Use this class when you want to use different machine learning framework datasets
    interchangeably."""

    def __iter__(self):
        """Returns the DataLoader, which acts as its own iterator."""
        return self

    def __next__(self):
        """Returns the next batch of data.

        Returns
        -------
        inputs : List of ndarray
            The inputs to be provided to the graph.
            The list is of the form [batched_input_1, batched_input_2, ..., batched_input_n].

        labels : List
            The expected outputs of the graph.
            The length of labels should be equal to the batch size. If the DataLoader doesn't
            have labels, labels will be None.
        """
        raise NotImplementedError

    def get_num_batches(self):
        """Returns the number of batches the DataLoader has.

        Returns
        -------
        num_batches : int
            The number of batches the DataLoader contains.
        """
        raise NotImplementedError

    def get_batch_size(self):
        """Gets the batch size.

        Returns
        -------
        batch_size : int
            The size of the batch returned by the DataLoader.
        """
        raise NotImplementedError
You can easily construct a Python iterator from the DataLoader class. This makes our framework similar to the TensorFlow and MXNet dataset APIs, which should feel familiar to users who want to use it directly. The other methods expose information that is useful for tasks like calculating accuracy or averaging over the dataset, such as the batch size and the total number of batches.
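For example, here is a minimal sketch of how consuming code could use any DataLoader, regardless of the underlying framework. This assumes __next__ returns an (inputs, labels) tuple as described in the docstring above; data_loader, run_model, and count_correct are hypothetical placeholders, not part of the proposal:

num_batches = data_loader.get_num_batches()
batch_size = data_loader.get_batch_size()

total_correct = 0
for inputs, labels in data_loader:
    # inputs is a list of batched ndarrays; labels is a list of expected outputs.
    outputs = run_model(inputs)                      # hypothetical inference function
    total_correct += count_correct(outputs, labels)  # hypothetical comparison helper

accuracy = total_correct / (num_batches * batch_size)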
Here’s a link to the PR that introduces this code: https://github.com/apache/tvm/pull/7710
In the PR, I also implement the RandomDataLoader, MxnetDataLoader, TFDataLoader, and NumpyDataLoader (which loads Keras and other datasets stored in the NumPy format). The RandomDataLoader provides random NumPy data of a specific shape and dtype for testing purposes. The TFDataLoader takes a TensorFlow dataset as input, the MxnetDataLoader takes an MXNet dataset as input, and the NumpyDataLoader can take any data formatted as a NumPy array (Keras datasets provide data in this format).
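To make the idea concrete, here is a minimal sketch of what a NumPy-backed subclass might look like. This is an illustrative simplification under my own assumptions about batching (first axis is the sample axis, incomplete final batches are dropped), not the actual code from the PR:

class NumpyDataLoader(DataLoader):
    """Illustrative sketch: wraps a NumPy array of inputs (and optional labels)."""

    def __init__(self, data, labels=None, batch_size=1):
        self.data = data
        self.labels = labels
        self.batch_size = batch_size
        self.idx = 0

    def __iter__(self):
        # Reset the position so the loader can be iterated more than once.
        self.idx = 0
        return self

    def __next__(self):
        if self.idx + self.batch_size > len(self.data):
            raise StopIteration
        batch = self.data[self.idx : self.idx + self.batch_size]
        labels = None
        if self.labels is not None:
            labels = list(self.labels[self.idx : self.idx + self.batch_size])
        self.idx += self.batch_size
        return [batch], labels

    def get_num_batches(self):
        return len(self.data) // self.batch_size

    def get_batch_size(self):
        return self.batch_size

Because Keras datasets are already NumPy arrays, wrapping one under this sketch would be as simple as NumpyDataLoader(x_train, y_train, batch_size=32).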
Writing a PytorchDataLoader is future work. I expect it to be similar to the implementation of the NumpyDataLoader, and it should not be difficult to implement.
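As a rough illustration of that future work, a PyTorch-backed subclass might collect individual samples from the indexable dataset into batches. This is a hypothetical sketch, assuming each dataset item is a (data, label) pair where data is a CPU torch.Tensor; none of it is code from the PR:

import numpy as np

class PytorchDataLoader(DataLoader):
    """Hypothetical sketch: wraps a PyTorch-style dataset that supports len()
    and indexing via __getitem__."""

    def __init__(self, dataset, batch_size=1):
        self.dataset = dataset
        self.batch_size = batch_size
        self.idx = 0

    def __iter__(self):
        self.idx = 0
        return self

    def __next__(self):
        if self.idx + self.batch_size > len(self.dataset):
            raise StopIteration
        samples = []
        labels = []
        for i in range(self.idx, self.idx + self.batch_size):
            data, label = self.dataset[i]  # assumes (data, label) items
            samples.append(data.numpy())   # torch.Tensor -> NumPy ndarray (CPU tensors)
            labels.append(label)
        self.idx += self.batch_size
        # Stack the individual samples into one batched input array.
        return [np.stack(samples)], labels

    def get_num_batches(self):
        return len(self.dataset) // self.batch_size

    def get_batch_size(self):
        return self.batch_size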