TensorFlow/Keras multi-threaded model fitting

TensorFlow graphs are not thread-safe (see https://www.tensorflow.org/api_docs/python/tf/Graph), and when you create a new TensorFlow Session it uses the process-wide default graph unless you pass it a graph explicitly.

You can get around this by creating a new Session with its own Graph inside your parallelized function and constructing your Keras model there, so each thread operates on an isolated graph.

Here is some code that creates and fits a model on each available GPU in parallel:

import concurrent.futures
import numpy as np
import keras.backend as K
from keras.layers import Dense
from keras.models import Sequential
import tensorflow as tf
from tensorflow.python.client import device_lib

def get_available_gpus():
    # Enumerate local devices and keep only the GPUs.
    local_device_protos = device_lib.list_local_devices()
    return [x.name for x in local_device_protos if x.device_type == 'GPU']

# Dummy data: 100 samples, 8 features, binary labels.
xdata = np.random.randn(100, 8)
ytrue = np.random.randint(0, 2, 100)

def fit(gpu):
    # A fresh Session with its own Graph isolates this thread from the
    # (non-thread-safe) default graph.
    with tf.Session(graph=tf.Graph()) as sess:
        K.set_session(sess)
        with tf.device(gpu):
            model = Sequential()
            model.add(Dense(12, input_dim=8, activation='relu'))
            model.add(Dense(8, activation='relu'))
            model.add(Dense(1, activation='sigmoid'))
            model.compile(loss='binary_crossentropy', optimizer='adam')
            model.fit(xdata, ytrue, verbose=0)
            return model.evaluate(xdata, ytrue, verbose=0)

gpus = get_available_gpus()
# One thread per GPU; each thread builds and fits its own model.
with concurrent.futures.ThreadPoolExecutor(len(gpus)) as executor:
    results = [x for x in executor.map(fit, gpus)]
print('results: ', results)
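The fan-out pattern above is independent of TensorFlow, so it can be sketched with plain Python; the device names and the stand-in fit function here are hypothetical placeholders for the real per-GPU training call:

```python
import concurrent.futures

def fit(device):
    # Stand-in for per-device training: each call runs in its own
    # thread and returns a result tagged with its device.
    return 'loss-for-' + device

# Hypothetical device names, in place of get_available_gpus().
devices = ['/device:GPU:0', '/device:GPU:1']

# One thread per device; executor.map yields results in input order,
# regardless of which thread finishes first.
with concurrent.futures.ThreadPoolExecutor(len(devices)) as executor:
    results = list(executor.map(fit, devices))

print(results)
```

Because executor.map preserves input order, results[i] always corresponds to devices[i], which makes it easy to match each evaluation score back to the GPU that produced it.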