Advanced Dataset usage#


If you use the Dataset APIs, chances are you will need one or more of the processing steps described in this section, especially when working on data ingestion for generative model training.

# @test {"output": "ignore"}
!pip install grain
# @test {"output": "ignore"}
!pip install tensorflow_datasets
import grain
import numpy as np
import tensorflow_datasets as tfds
from pprint import pprint

Checkpointing#

We provide CheckpointSave and CheckpointRestore handlers (in grain.checkpoint) to checkpoint the DatasetIterator. We recommend using them with Orbax, which can checkpoint both the input pipeline and the model, and handles the edge cases of distributed training.

ds = (
    grain.MapDataset.source(tfds.data_source("mnist", split="train"))
    .seed(seed=45)
    .shuffle()
    .to_iter_dataset()
)

num_steps = 4
ds_iter = iter(ds)

# Read some elements.
for i in range(num_steps):
  x = next(ds_iter)
  print(i, x["label"])
0 7
1 4
2 0
3 1
# @test {"output": "ignore"}
!pip install orbax
import orbax.checkpoint as ocp

# Start from a clean checkpoint directory.
!rm -rf /tmp/orbax

mngr = ocp.CheckpointManager("/tmp/orbax")

# Save the checkpoint.
assert mngr.save(
    step=num_steps, args=grain.checkpoint.CheckpointSave(ds_iter), force=True
)
# Checkpoint saving in Orbax is asynchronous by default, so wait until it
# finishes before examining the checkpoint.
mngr.wait_until_finished()

# @test {"output": "ignore"}
!ls -R /tmp/orbax
/tmp/orbax:
4

/tmp/orbax/4:
_CHECKPOINT_METADATA
default

/tmp/orbax/4/default:
process_0-of-1.json
!cat /tmp/orbax/*/*/*.json
{
    "next_index": 4
}
# Read more elements and advance the iterator.
for i in range(4, 8):
  x = next(ds_iter)
  print(i, x["label"])
4 7
5 4
6 8
7 0
# Restore iterator from the previously saved checkpoint.
mngr.restore(num_steps, args=grain.checkpoint.CheckpointRestore(ds_iter))
# The iterator is rewound and resumes from element 4, yielding the same values.
for i in range(4, 8):
  x = next(ds_iter)
  print(i, x["label"])
4 7
5 4
6 8
7 0
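
Orbax is the recommended way to checkpoint in training jobs. For quick in-process experiments you can also capture the iterator state directly; the sketch below assumes the DatasetIterator exposes get_state() and set_state(), which is the state the checkpoint handlers serialize.

# A minimal in-process sketch (assumes DatasetIterator exposes
# get_state()/set_state(); these back the checkpoint handlers above).
state = ds_iter.get_state()   # capture the current position
for _ in range(3):
  next(ds_iter)               # advance the iterator
ds_iter.set_state(state)      # rewind to the captured position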

Mixing datasets#

Dataset allows mixing multiple data sources with potentially different transformations. There are two ways of mixing Datasets: MapDataset.mix and IterDataset.mix. If the mixed Datasets are sparse (e.g. one of the mixture components needs to be filtered), use IterDataset.mix; otherwise use MapDataset.mix.

# Force TFDS to use the ArrayRecord file format, which supports the random
# access needed by data sources.
tfds.core.DatasetInfo.file_format = (
    tfds.core.file_adapters.FileFormat.ARRAY_RECORD
)
# This particular mixture combines medical images with handwritten digits;
# probably not useful in practice, but it illustrates the API on small datasets.
source1 = tfds.data_source(name="pneumonia_mnist", split="train")
source2 = tfds.data_source(name="mnist", split="train")
ds1 = grain.MapDataset.source(source1).map(lambda features: features["image"])
ds2 = grain.MapDataset.source(source2).map(lambda features: features["image"])
ds = grain.MapDataset.mix([ds1, ds2], weights=[0.7, 0.3])
print(f"Mixed dataset length = {len(ds)}")
pprint(np.shape(ds[0]))
Mixed dataset length = 6728
(28, 28, 1)
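
The mixture behaves like any other MapDataset, so further transformations compose as usual. A short sketch, with an arbitrary batch size chosen purely for illustration:

# The mixed dataset is a regular MapDataset, so downstream transformations
# such as batching compose as usual (batch size is illustrative).
batched = ds.batch(batch_size=8)
pprint(np.shape(batched[0]))  # both sources are 28x28 grayscale, so (8, 28, 28, 1)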

If the inputs to the mixture need to be filtered, use IterDataset.mix.

source1 = tfds.data_source(name="pneumonia_mnist", split="train")
source2 = tfds.data_source(name="mnist", split="train")
ds1 = (
    grain.MapDataset.source(source1)
    .filter(lambda features: int(features["label"]) == 1)
    .to_iter_dataset()
)
ds2 = (
    grain.MapDataset.source(source2)
    .filter(lambda features: int(features["label"]) > 4)
    .to_iter_dataset()
)

ds = grain.IterDataset.mix([ds1, ds2], weights=[0.7, 0.3]).map(
    lambda features: features["image"]
)
pprint(np.shape(next(iter(ds))))
(28, 28, 1)
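
For context on why the sparse case calls for IterDataset.mix: a filtered MapDataset keeps its original length and returns None for elements that fail the predicate, so index-based mixing would interleave those placeholders and skew the effective proportions. A small illustrative check:

# MapDataset.filter keeps the dataset length and yields None for filtered-out
# elements, which is why sparse components are mixed as IterDatasets.
sparse = grain.MapDataset.source(source2).filter(
    lambda features: int(features["label"]) > 4
)
print(len(sparse) == len(source2))  # the length is unchanged by filter
# Count how many of the first 10 elements were filtered out (read back as None).
print(sum(sparse[i] is None for i in range(10)))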

Multi-epoch training#

The mixed dataset length is determined by the length of the shortest input dataset and the mixing weights. This means that once the shortest component is exhausted, a new epoch begins and the remaining elements of the other datasets are discarded. This can be avoided by repeating the inputs to the mixture.

source1 = tfds.data_source(name="pneumonia_mnist", split="train")
source2 = tfds.data_source(name="mnist", split="train")
ds1 = grain.MapDataset.source(source1).repeat()
ds2 = grain.MapDataset.source(source2).repeat()

ds = grain.MapDataset.mix([ds1, ds2], weights=[1, 2])
print(f"Mixed dataset length = {len(ds1)}")  # sys.maxsize
print(f"Mixed dataset length = {len(ds2)}")  # sys.maxsize
# Ds1 and ds2 are repeated to fill out the sys.maxsize with respect to weights.
print(f"Mixed dataset length = {len(ds)}")  # sys.maxsize
Mixed dataset length = 9223372036854775807
Mixed dataset length = 9223372036854775807
Mixed dataset length = 9223372036854775807
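
If an unbounded stream is not what you want (e.g. for a fixed evaluation budget), repeat() also accepts an explicit number of epochs; the count below is purely illustrative.

# repeat() with an explicit epoch count gives a bounded number of passes.
ds1_two_epochs = grain.MapDataset.source(source1).repeat(2)
print(len(ds1_two_epochs) == 2 * len(source1))  # True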

Shuffling#

If you need to globally shuffle the mixed data, prefer shuffling the individual Datasets before mixing. This ensures that the actual proportions of the mixture components stay stable and as close as possible to the provided weights.

Additionally, make sure to provide different seeds to the different mixture components. This avoids introducing a seed dependency between the components when their random transformations overlap.

source1 = tfds.data_source(name="pneumonia_mnist", split="train")
source2 = tfds.data_source(name="mnist", split="train")
ds1 = grain.MapDataset.source(source1).seed(42).shuffle().repeat()
ds2 = grain.MapDataset.source(source2).seed(43).shuffle().repeat()

ds = grain.MapDataset.mix([ds1, ds2], weights=[1, 2])
print(f"Mixed dataset length = {len(ds)}")  # sys.maxsize
Mixed dataset length = 9223372036854775807