TensorFlow and Multiprocessing: Passing Sessions
You can't pass a TensorFlow Session into a multiprocessing.Pool in the straightforward way because the Session object can't be pickled: it is fundamentally not serializable, since it may manage resources like GPU memory and runtime state.
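To see why this fails, note that multiprocessing serializes every argument with pickle before sending it to a worker. A minimal sketch of the failure, using a hypothetical `SessionLike` class holding a lock as a stand-in for a Session's unpicklable runtime state:

```python
import pickle
import threading

class SessionLike:
    # Hypothetical stand-in for tf.Session: it holds an OS-level resource
    # (a lock) that the pickle module refuses to serialize, just as a real
    # Session's GPU/runtime state cannot be serialized.
    def __init__(self):
        self.lock = threading.Lock()

try:
    pickle.dumps(SessionLike())
    pickling_failed = False
except TypeError:
    # pickle raises TypeError: cannot pickle '_thread.lock' object
    pickling_failed = True
```

Anything you hand to a Pool worker has to survive this round trip, which a live Session cannot.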
I'd suggest parallelizing the code using actors, which are essentially the parallel computing analog of "objects" and are used to manage state in the distributed setting.
Ray is a good framework for doing this. You can define a Python class which manages the TensorFlow Session
and exposes a method for running your simulation.
```python
import ray
import tensorflow as tf

ray.init()

@ray.remote
class Simulator(object):
    def __init__(self):
        self.sess = tf.Session()
        self.simple_model = tf.constant([1.0])

    def simulate(self):
        return self.sess.run(self.simple_model)

# Create two actors.
simulators = [Simulator.remote() for _ in range(2)]

# Run two simulations in parallel.
results = ray.get([s.simulate.remote() for s in simulators])
```
Here are a few more examples of parallelizing TensorFlow with Ray.
See the Ray documentation. Note that I'm one of the Ray developers.
I use Keras as a wrapper with TensorFlow as a backend, but the same general principle should apply.
If you try something like this:
```python
import keras
from functools import partial
from multiprocessing import Pool

def ModelFunc(SomeData, i):
    YourModel = Here  # build and score your model
    return ModelScore

pool = Pool(processes=4)
for i, Score in enumerate(pool.imap(partial(ModelFunc, SomeData), range(4))):
    print(Score)
```
It will fail. However, if you try something like this:
```python
from functools import partial
from multiprocessing import Pool

def ModelFunc(SomeData, i):
    import keras  # imported inside the worker, so each process gets its own backend state
    YourModel = Here  # build and score your model
    return ModelScore

pool = Pool(processes=4)
for i, Score in enumerate(pool.imap(partial(ModelFunc, SomeData), range(4))):
    print(Score)
```
It should work. The fix is to import TensorFlow (via Keras here) separately inside each worker process, rather than once in the parent.