
Tensorflow: Load data in multiple threads on cpu


Assuming you're using the latest TensorFlow (1.4 at the time of this writing), you can keep the generator and use the tf.data.* API as follows (I chose arbitrary values for the number of threads, the prefetch buffer size, the batch size and the output data types):

```python
import tensorflow as tf

NUM_THREADS = 5
sceneGen = SceneGenerator()
dataset = tf.data.Dataset.from_generator(sceneGen.generate_data,
                                         output_types=(tf.float32, tf.int32))
dataset = dataset.map(lambda x, y: (x, y),
                      num_parallel_calls=NUM_THREADS).prefetch(buffer_size=1000)
dataset = dataset.batch(42)
X, y = dataset.make_one_shot_iterator().get_next()
```
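For intuition, here is a rough plain-Python analogue of that pipeline. It is an illustration under my own assumptions, not how tf.data is implemented: several worker threads pull items from one shared iterator into a bounded buffer (playing the role of prefetch), and the caller receives them grouped into batches. The function name parallel_batches and its parameters are hypothetical; the lock is needed because a plain Python generator cannot be advanced by two threads at the same time.

```python
import queue
import threading


def parallel_batches(source, num_threads=5, buffer_size=1000, batch_size=42):
    """Hypothetical analogue of from_generator + parallel map + prefetch + batch.

    Worker threads pull items from one shared iterator and push
    (thread_id, item) pairs into a bounded queue that acts as the
    prefetch buffer; the caller receives them grouped into batches.
    """
    it = iter(source)
    lock = threading.Lock()  # plain generators can't be advanced by two threads at once
    buf = queue.Queue(maxsize=buffer_size)
    done = object()  # sentinel each worker sends when the source is exhausted

    def worker():
        while True:
            with lock:
                try:
                    item = next(it)
                except StopIteration:
                    break
            # Record which thread fetched the item, like threading.get_ident() in the answer.
            buf.put((threading.get_ident(), item))
        buf.put(done)

    threads = [threading.Thread(target=worker, daemon=True) for _ in range(num_threads)]
    for t in threads:
        t.start()

    finished, batch = 0, []
    while finished < num_threads:
        entry = buf.get()
        if entry is done:
            finished += 1
            continue
        batch.append(entry)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:  # final, smaller batch
        yield batch
    for t in threads:
        t.join()
```

For example, consuming parallel_batches(range(100), num_threads=4, batch_size=42) gives batches of 42, 42 and 16 (thread_id, item) pairs. The item order inside the batches is not guaranteed, and which thread IDs appear depends on scheduling (one thread may well fetch most of the items).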

To show that it's actually multiple threads extracting from the generator, I modified your class as follows:

```python
import threading


class SceneGenerator(object):
    def __init__(self):
        # some inits
        pass

    def generate_data(self):
        """
        Generator. Yield data X and labels y after some preprocessing
        """
        while True:
            # opening files, selecting data
            X, y = threading.get_ident(), 2  # self.preprocess(some_params, filenames, ...)
            yield X, y
```

This way, creating a TensorFlow session and fetching one batch shows the IDs of the threads that produced the data. On my PC, running:

```python
sess = tf.Session()
print(sess.run([X, y]))
```

prints

```
[array([  8460.,   8460.,   8460.,  15912.,  16200.,  16200.,   8460.,
         15912.,  16200.,   8460.,  15912.,  16200.,  16200.,   8460.,
         15912.,  15912.,   8460.,   8460.,   6552.,  15912.,  15912.,
          8460.,   8460.,  15912.,   9956.,  16200.,   9956.,  16200.,
         15912.,  15912.,   9956.,  16200.,  15912.,  16200.,  16200.,
         16200.,   6552.,  16200.,  16200.,   9956.,   6552.,   6552.], dtype=float32), array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])]
```

Note: You might want to experiment with removing the map call (which we only use to get multiple threads) and checking whether the prefetch buffer alone is enough to remove the bottleneck in your input pipeline. Even with only one thread, input preprocessing is often faster than the actual graph execution, so the buffer is enough to let preprocessing run as fast as it can.
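The "one thread plus a buffer" setup from the note can be sketched in plain Python as well (again an analogue, not TensorFlow's implementation): a single background thread keeps a bounded buffer filled, so the consumer only waits when the producer falls behind. The names prefetch and buffer_size are hypothetical.

```python
import queue
import threading

_END = object()  # sentinel marking the end of the source


def prefetch(source, buffer_size=1000):
    """Yield items from source, produced ahead of time by one background thread."""
    buf = queue.Queue(maxsize=buffer_size)

    def producer():
        try:
            for item in source:
                buf.put((item, None))  # blocks while the buffer is full
        except Exception as exc:
            buf.put((None, exc))  # hand the error to the consumer
            return
        buf.put((_END, None))

    thread = threading.Thread(target=producer, daemon=True)
    thread.start()
    while True:
        item, exc = buf.get()
        if exc is not None:
            raise exc
        if item is _END:
            break
        yield item
    thread.join()
```

Because there is only one producer, list(prefetch(range(10), buffer_size=3)) returns the items in their original order; compared with several parallel workers, you give up parallel preprocessing but keep ordering and avoid any locking around the generator.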


Running a session with a feed_dict is indeed pretty slow:

feed_dict does a single-threaded memcpy of the contents from the Python runtime into the TensorFlow runtime.

A faster way to feed the data is to use tf.train.string_input_producer + a *Reader + tf.train.Coordinator: queue runners then read and batch the data in background threads. For that, you read the data directly into tensors. For example, here's a way to read and process a CSV file:

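```python
import tensorflow as tf

# Assumed values: these constants are not defined in this snippet of the answer.
N_FEATURES = 9  # number of feature columns before the label
BATCH_SIZE = 2
# One default per CSV column (this also fixes each column's dtype): floats,
# a string in column 5 ('Present'/'Absent'), and an integer label last.
record_defaults = [[1.0] for _ in range(N_FEATURES)]
record_defaults[4] = ['']
record_defaults.append([1])


def batch_generator(filenames):
    filename_queue = tf.train.string_input_producer(filenames)
    reader = tf.TextLineReader(skip_header_lines=1)  # skip the CSV header
    _, value = reader.read(filename_queue)
    content = tf.decode_csv(value, record_defaults=record_defaults)
    # Map the string column to a float: 'Present' -> 1.0, anything else -> 0.0
    content[4] = tf.cond(tf.equal(content[4], tf.constant('Present')),
                         lambda: tf.constant(1.0),
                         lambda: tf.constant(0.0))
    features = tf.stack(content[:N_FEATURES])
    label = content[-1]
    # shuffle_batch adds its own queue runner (num_threads defaults to 1)
    data_batch, label_batch = tf.train.shuffle_batch([features, label],
                                                     batch_size=BATCH_SIZE,
                                                     capacity=20 * BATCH_SIZE,
                                                     min_after_dequeue=10 * BATCH_SIZE)
    return data_batch, label_batch
```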

This function takes the list of input files, creates the reader and the data transformations, and returns tensors that evaluate to (shuffled batches of) the contents of these files. Your scene generator probably applies different transformations, but the idea is the same.
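To see what that pipeline computes, here is a rough plain-Python analogue using the standard csv and random modules (not TensorFlow): skip the header, fill empty fields from per-column defaults, map the fifth column from 'Present'/'Absent' to 1.0/0.0, split features from the label, and draw batches from a shuffle buffer that only releases records once more than min_after_dequeue are held. The column layout (9 float features with a string in column 5, an integer label last) and every name here are assumptions chosen to match the indices used in batch_generator; capacity is not modeled.

```python
import csv
import io
import random

N_FEATURES = 9  # assumed: 9 feature columns, then an integer label
# Assumed per-column defaults: floats, a string in column 5, an int label.
RECORD_DEFAULTS = [1.0] * N_FEATURES + [1]
RECORD_DEFAULTS[4] = ""


def decode_row(row):
    """Rough analogue of tf.decode_csv plus the 'Present' -> 1.0 conversion."""
    # Empty fields fall back to the default; others are cast to the default's type.
    values = [type(d)(f) if f != "" else d for f, d in zip(row, RECORD_DEFAULTS)]
    values[4] = 1.0 if values[4] == "Present" else 0.0
    return values[:N_FEATURES], values[-1]


def read_records(text):
    """Parse CSV text, skipping the header line like skip_header_lines=1."""
    rows = csv.reader(io.StringIO(text))
    next(rows)  # skip header
    return [decode_row(row) for row in rows]


def shuffled(records, min_after_dequeue, rng):
    """Shuffle buffer: release a random record once more than min_after_dequeue are held."""
    buffer = []
    for record in records:
        buffer.append(record)
        if len(buffer) > min_after_dequeue:
            yield buffer.pop(rng.randrange(len(buffer)))
    rng.shuffle(buffer)  # input exhausted: drain the rest in random order
    yield from buffer


def shuffle_batches(records, batch_size, min_after_dequeue, seed=None):
    batch = []
    for record in shuffled(records, min_after_dequeue, random.Random(seed)):
        batch.append(record)
        if len(batch) == batch_size:
            yield batch
            batch = []
    # A trailing partial batch is dropped, mirroring shuffle_batch's
    # default allow_smaller_final_batch=False.
```

Larger min_after_dequeue values mix the data better at the cost of memory and a slower start, which is the same trade-off the capacity and min_after_dequeue arguments control in the TensorFlow version.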

Next, you create a tf.train.Coordinator and start the queue runners, which run this pipeline in parallel threads:

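```python
# Build the input pipeline ('data.csv' is a hypothetical file name)
data_batch, label_batch = batch_generator(['data.csv'])

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)  # starts the reader/batching threads
    for _ in range(10):  # generate 10 batches
        features, labels = sess.run([data_batch, label_batch])
        print(features)
    coord.request_stop()
    coord.join(threads)
```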
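The stop/join protocol used here can be sketched in plain Python with threading.Event. SimpleCoordinator below is a hypothetical stand-in, not TensorFlow's class: it mirrors the should_stop(), request_stop() and join() calls of tf.train.Coordinator, minus its exception handling and stop grace period. Worker threads keep filling a queue until a stop is requested; the main thread takes what it needs, requests the stop, then joins them.

```python
import queue
import threading


class SimpleCoordinator:
    """Hypothetical minimal stand-in for tf.train.Coordinator."""

    def __init__(self):
        self._stop_event = threading.Event()

    def should_stop(self):
        return self._stop_event.is_set()

    def request_stop(self):
        self._stop_event.set()

    def join(self, threads):
        for t in threads:
            t.join()


def producer(coord, out, worker_id):
    """Keep filling the queue until the coordinator asks all threads to stop."""
    n = 0
    while not coord.should_stop():
        try:
            out.put((worker_id, n), timeout=0.1)  # time out to re-check should_stop()
        except queue.Full:
            continue
        n += 1


def run(num_threads=3, num_items=10):
    coord = SimpleCoordinator()
    buf = queue.Queue(maxsize=4)
    threads = [threading.Thread(target=producer, args=(coord, buf, w))
               for w in range(num_threads)]
    for t in threads:
        t.start()
    items = [buf.get() for _ in range(num_items)]  # like the 10 sess.run calls
    coord.request_stop()
    coord.join(threads)
    return items, threads
```

The put timeout matters: a producer blocked on a full queue would otherwise never notice the stop request, and join() would hang.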

In my experience, this queue-based approach feeds the data much faster and lets you use all the available GPU power. A complete working example can be found here.