
TensorFlow - tf.data.Dataset reading large HDF5 files


I stumbled across this question while dealing with a similar issue. I came up with a solution based on using a Python generator, together with the TF dataset construction method from_generator. Because we use a generator, the HDF5 file should be opened for reading only once and kept open as long as there are entries to read. So it will not be opened, read, and then closed for every single call to get the next data element.

Generator definition

To allow the user to pass in the HDF5 filename as an argument, I wrote a class that has a __call__ method, since from_generator specifies that the generator has to be callable. This is the generator:

import h5py
import tensorflow as tf

class generator:
    def __init__(self, file):
        self.file = file

    def __call__(self):
        with h5py.File(self.file, 'r') as hf:
            for im in hf["train_img"]:
                yield im

Because this is a generator, each call picks up where the previous one left off the last time it yielded a result, instead of running everything from the beginning again; in this case, that is the next iteration of the inner for loop. So the file is not opened again for reading, but kept open as long as there is data to yield. For more on generators, see this excellent Q&amp;A.

Of course, you will have to replace anything inside the with block to match how your dataset is constructed and what outputs you want to obtain.
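
For instance, if your file also stored labels in a second dataset (the name "train_labels" below is just an illustrative assumption, not part of the original answer), the generator could yield (image, label) pairs instead:

class generator:
    def __init__(self, file):
        self.file = file

    def __call__(self):
        with h5py.File(self.file, 'r') as hf:
            # "train_labels" is a hypothetical dataset name; adjust to your file layout
            for im, label in zip(hf["train_img"], hf["train_labels"]):
                yield im, label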

Usage example

ds = tf.data.Dataset.from_generator(
    generator(hdf5_path),
    tf.uint8,
    tf.TensorShape([427, 561, 3]))

value = ds.make_one_shot_iterator().get_next()

# Example on how to read elements
while True:
    try:
        data = sess.run(value)
        print(data.shape)
    except tf.errors.OutOfRangeError:
        print('done.')
        break

Again, in my case I had stored uint8 images of height 427, width 561, and 3 color channels in my dataset, so you will need to modify these in the above call to match your use case.
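
As a sketch, if the generator yielded (image, label) pairs as in the variant above, the types and shapes would become tuples (the int64 label dtype and scalar label shape are assumptions for illustration):

ds = tf.data.Dataset.from_generator(
    generator(hdf5_path),
    (tf.uint8, tf.int64),
    (tf.TensorShape([427, 561, 3]), tf.TensorShape([])))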

Handling multiple files

I have a proposed solution for handling multiple HDF5 files. The basic idea is to construct a Dataset from the filenames as usual, and then use the interleave method to process many input files concurrently, getting samples from each of them to form a batch, for example.

The idea is as follows:

ds = tf.data.Dataset.from_tensor_slices(filenames)
# You might want to shuffle() the filenames here depending on the application
ds = ds.interleave(lambda filename: tf.data.Dataset.from_generator(
        generator(filename),
        tf.uint8,
        tf.TensorShape([427, 561, 3])),
    cycle_length, block_length)

What this does is open cycle_length files concurrently, and produce block_length items from each before moving to the next file - see interleave documentation for details. You can set the values here to match what is appropriate for your application: e.g., do you need to process one file at a time or several concurrently, do you only want to have a single sample at a time from each file, and so on.
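
To see how the two parameters shape the output order, here is a small toy example independent of HDF5, where the elements 0 and 1 stand in for samples coming from two different files:

ds = tf.data.Dataset.range(2).interleave(
    lambda x: tf.data.Dataset.from_tensors(x).repeat(4),
    cycle_length=2,   # read from 2 sources concurrently
    block_length=2)   # take 2 consecutive elements from each before switching
# Produces the order: 0 0 1 1 0 0 1 1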

Edit: for a parallel version, take a look at tf.contrib.data.parallel_interleave!
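
A rough sketch of what the parallel version might look like, following the same pattern as above (parallel_interleave is applied via Dataset.apply; the sloppy flag trades deterministic ordering for throughput, and the cycle_length value is only illustrative):

ds = tf.data.Dataset.from_tensor_slices(filenames)
ds = ds.apply(tf.contrib.data.parallel_interleave(
    lambda filename: tf.data.Dataset.from_generator(
        generator(filename),
        tf.uint8,
        tf.TensorShape([427, 561, 3])),
    cycle_length=4,   # number of files read in parallel
    sloppy=True))     # allow out-of-order elements for better throughput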

Possible caveats

Be aware of the peculiarities of using from_generator if you decide to go with this solution. For TensorFlow 1.6.0, the documentation of from_generator includes the following two notes.

It may be challenging to apply this across different environments or with distributed training:

NOTE: The current implementation of Dataset.from_generator() uses tf.py_func and inherits the same constraints. In particular, it requires the Dataset- and Iterator-related operations to be placed on a device in the same process as the Python program that called Dataset.from_generator(). The body of generator will not be serialized in a GraphDef, and you should not use this method if you need to serialize your model and restore it in a different environment.

Be careful if the generator depends on external state:

NOTE: If generator depends on mutable global variables or other external state, be aware that the runtime may invoke generator multiple times (in order to support repeating the Dataset) and at any time between the call to Dataset.from_generator() and the production of the first element from the generator. Mutating global variables or external state can cause undefined behavior, and we recommend that you explicitly cache any external state in generator before calling Dataset.from_generator().
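
A minimal sketch of that recommendation: copy any external state into the generator object when it is constructed, so every invocation of __call__ sees the same value. Here mean_image is a hypothetical precomputed NumPy array, not part of the original answer:

class generator:
    def __init__(self, file, mean_image):
        self.file = file
        # Cache a copy of the external state at construction time,
        # so repeated invocations by the runtime all see the same value.
        self.mean_image = mean_image.copy()

    def __call__(self):
        with h5py.File(self.file, 'r') as hf:
            for im in hf["train_img"]:
                yield im - self.mean_image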


It took me a while to figure this out, so I thought I should record it here. Based on mikkola's answer, this is how to handle multiple files:

import h5py
import tensorflow as tf

class generator:
    def __call__(self, file):
        with h5py.File(file, 'r') as hf:
            for im in hf["train_img"]:
                yield im

ds = tf.data.Dataset.from_tensor_slices(filenames)
ds = ds.interleave(lambda filename: tf.data.Dataset.from_generator(
        generator(),
        tf.uint8,
        tf.TensorShape([427, 561, 3]),
        args=(filename,)),
    cycle_length, block_length)

The key is that you can't pass filename directly to the generator, since it's a Tensor. You have to pass it through args, which TensorFlow evaluates and converts to a regular Python variable.
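
With that in place, a typical pipeline might look like the sketch below; the batch size, the cycle_length/block_length values, and the one-shot iterator are only illustrative:

ds = tf.data.Dataset.from_tensor_slices(filenames)
ds = ds.interleave(lambda filename: tf.data.Dataset.from_generator(
        generator(),
        tf.uint8,
        tf.TensorShape([427, 561, 3]),
        args=(filename,)),
    cycle_length=4, block_length=16)
ds = ds.batch(32)

value = ds.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    images = sess.run(value)
    print(images.shape)   # (32, 427, 561, 3) for a full batch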