How to improve data input pipeline performance?
Summarizing the solution and the important observations from @AlexisBRENON's answer here, for the benefit of the community.

Below are the important observations:
- According to this GitHub issue, the `TFRecordDataset` interleaving is a legacy one, so the `interleave` function is better.
- `batch` before `map` is a good habit (vectorizing your function) and reduces the number of times the mapped function is called.
- No need for `repeat` anymore. Since TF 2.0, the Keras model API supports the dataset API and can use cache (see the SO post).
- Switch from a `VarLenFeature` to a `FixedLenSequenceFeature`, removing a useless call to `tf.sparse.to_dense` (a short sketch of this switch follows the list).
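To illustrate the last point, here is a minimal sketch of the switch. The legacy `VarLenFeature` variant is my reconstruction of what the original code presumably looked like, not something quoted from it; the feature name `features` follows the pipeline code below:

```python
import tensorflow as tf

# Legacy variant (assumed): VarLenFeature parses into a SparseTensor,
# which forces an extra tf.sparse.to_dense call after parsing.
legacy_spec = {"features": tf.io.VarLenFeature(tf.float32)}

def parse_legacy(examples):
    parsed = tf.io.parse_example(examples, legacy_spec)
    return tf.sparse.to_dense(parsed["features"])

# Improved variant: FixedLenSequenceFeature with allow_missing=True
# parses straight to a dense tensor, skipping the conversion entirely.
dense_spec = {
    "features": tf.io.FixedLenSequenceFeature((), tf.float32, allow_missing=True)
}

def parse_dense(examples):
    return tf.io.parse_example(examples, dense_spec)["features"]
```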
Code for the pipeline with improved performance, in line with the above observations:
```python
import tensorflow as tf

def build_dataset(file_pattern):
    return tf.data.Dataset.list_files(
        file_pattern
    ).interleave(  # read from several TFRecord shards in parallel
        tf.data.TFRecordDataset,
        cycle_length=tf.data.experimental.AUTOTUNE,
        num_parallel_calls=tf.data.experimental.AUTOTUNE
    ).shuffle(
        2048
    ).batch(  # batch before map, so the parsing function is vectorized
        batch_size=64,
        drop_remainder=True,
    ).map(
        map_func=parse_examples_batch,
        num_parallel_calls=tf.data.experimental.AUTOTUNE
    ).cache(  # cache the parsed batches; prefetch overlaps input and training
    ).prefetch(
        tf.data.experimental.AUTOTUNE
    )

def parse_examples_batch(examples):
    preprocessed_sample_columns = {
        "features": tf.io.FixedLenSequenceFeature((), tf.float32, allow_missing=True),
        "booleanFeatures": tf.io.FixedLenFeature((), tf.string, ""),
        "label": tf.io.FixedLenFeature((), tf.float32, -1)
    }
    samples = tf.io.parse_example(examples, preprocessed_sample_columns)
    bits_to_float = tf.io.decode_raw(samples["booleanFeatures"], tf.uint8)
    return (
        (samples['features'], bits_to_float),
        tf.expand_dims(samples["label"], 1)
    )
```
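To show how this removes the need for `repeat`, here is a consumption sketch only: the input sizes (`FEATURE_LEN`, `BOOL_BYTES`), the toy model, and the file pattern are all assumptions, not part of the original answer, and must match the actual lengths in your records:

```python
import tensorflow as tf

FEATURE_LEN = 128  # assumed length of the "features" vector per example
BOOL_BYTES = 16    # assumed byte length of the packed "booleanFeatures" string

features_in = tf.keras.Input(shape=(FEATURE_LEN,), dtype=tf.float32)
booleans_in = tf.keras.Input(shape=(BOOL_BYTES,), dtype=tf.uint8)

# Cast the raw uint8 bytes to float before mixing with the float features.
booleans_f = tf.keras.layers.Lambda(lambda t: tf.cast(t, tf.float32))(booleans_in)
x = tf.keras.layers.Concatenate()([features_in, booleans_f])
output = tf.keras.layers.Dense(1)(x)

model = tf.keras.Model(inputs=[features_in, booleans_in], outputs=output)
model.compile(optimizer="adam", loss="mse")

# Since TF 2.0, Keras consumes a tf.data.Dataset directly: no .repeat()
# and no steps_per_epoch; each epoch makes one pass over the dataset.
model.fit(build_dataset("train-*.tfrecord"), epochs=5)
```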
I have a further suggestion to add:
According to the documentation of interleave(), you can pass a mapping function as its first parameter.
This means one can write:
```python
AUTOTUNE = tf.data.experimental.AUTOTUNE

dataset = tf.data.Dataset.list_files(file_pattern)
dataset = dataset.interleave(
    lambda x: tf.data.TFRecordDataset(x).map(parse_fn, num_parallel_calls=AUTOTUNE),
    cycle_length=AUTOTUNE,
    num_parallel_calls=AUTOTUNE
)
```
As I understand it, this maps the parsing function to each shard and then interleaves the results, eliminating the separate dataset.map(...) call later on. (Note that `parse_fn` then receives single examples rather than batches, so if you rely on the batch-before-map vectorization above, keep the parsing in the batched map instead.)
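One way to check which variant is actually faster on your data is a rough timing loop. This is a sketch, not an official tf.data utility; the epoch count and file pattern are placeholders:

```python
import time
import tensorflow as tf

def benchmark(dataset, num_epochs=2):
    # Iterate the dataset and do nothing, to time the input pipeline alone.
    start = time.perf_counter()
    for _ in range(num_epochs):
        for _ in dataset:
            pass
    print(f"{num_epochs} epochs in {time.perf_counter() - start:.2f}s")

benchmark(build_dataset("train-*.tfrecord"))  # hypothetical file pattern
```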