Tracking progress of joblib.Parallel execution
Going one step beyond dano's and Connor's answers, you can wrap the whole thing in a context manager:
import contextlib

import joblib
from joblib import Parallel, delayed
from tqdm import tqdm


@contextlib.contextmanager
def tqdm_joblib(tqdm_object):
    """Context manager to patch joblib so it reports into the given tqdm progress bar.

    While the ``with`` block is active, ``joblib.parallel.BatchCompletionCallBack``
    is replaced by a subclass that advances ``tqdm_object`` by ``batch_size``
    each time the callback is invoked. On exit (normally or via an exception)
    the original class is restored and the progress bar is closed.

    Args:
        tqdm_object: A ``tqdm`` progress bar. Set its ``total`` to the number
            of tasks you submit to ``Parallel`` so the bar can reach 100%.

    Yields:
        The same ``tqdm_object``, so it can be bound with ``as``.
    """

    class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
        def __call__(self, *args, **kwargs):
            # Called when a batch completes (per the class name). A batch may
            # contain several tasks, so advance by its size rather than by 1.
            tqdm_object.update(n=self.batch_size)
            return super().__call__(*args, **kwargs)

    old_batch_callback = joblib.parallel.BatchCompletionCallBack
    joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
    try:
        yield tqdm_object
    finally:
        # Undo the monkey patch even if the body raised, so joblib is left
        # exactly as it was found.
        joblib.parallel.BatchCompletionCallBack = old_batch_callback
        tqdm_object.close()
Then you can use it like this, and no monkey-patched code is left behind once you're done:
from math import sqrt

with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar:
    # total=10 matches the 10 tasks submitted below, so the bar ends at 100%.
    Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10))
I think this is awesome, and it looks similar to tqdm's pandas integration.
Why can't you simply use `tqdm`? The following worked for me:
from joblib import Parallel, delayed
from tqdm import tqdm


def myfun(x):
    """Return the square of ``x``."""
    return x**2


# Caveat: tqdm wraps the *input* iterable, so the bar counts tasks as joblib
# pulls them for dispatch, not as they finish. Because joblib dispatches some
# tasks ahead of time, the bar can run ahead of the work actually completed.
results = Parallel(n_jobs=8)(delayed(myfun)(i) for i in tqdm(range(1000)))

# Output:
# 100%|██████████| 1000/1000 [00:00<00:00, 10563.37it/s]
The documentation you linked to states that `Parallel` has an optional progress meter. It's implemented using the `callback` keyword argument provided by `multiprocessing.Pool.apply_async`:
# This is inside a dispatch function (excerpt; the code that follows is elided,
# including whatever releases self._lock).
self._lock.acquire()
job = self._pool.apply_async(
    SafeFunction(func),
    args,
    kwargs,
    # multiprocessing calls this callback with the job's result once it is
    # ready. The index passed in is simply n_dispatched at submission time.
    callback=CallBack(self.n_dispatched, self),
)
self._jobs.append(job)
self.n_dispatched += 1
...
class CallBack(object):
    """ Callback used by parallel: it is used for progress reporting, and
        to add data to be processed
    """

    def __init__(self, index, parallel):
        self.parallel = parallel
        # Position of this job in dispatch order (not completion order).
        self.index = index

    def __call__(self, out):
        # ``out`` is the finished job's result; it is not used here.
        self.parallel.print_progress(self.index)
        # While the input iterable has not been fully consumed, each
        # completion triggers dispatching the next task.
        if self.parallel._original_iterable:
            self.parallel.dispatch_next()
And here's `print_progress`:
def print_progress(self, index):
    """Print a progress message for the job dispatched at position ``index``.

    Excerpt from a method of joblib's ``Parallel`` class; ``time``,
    ``_verbosity_filter`` and ``short_format_time`` come from its module.
    """
    elapsed_time = time.time() - self._start_time

    # This is heuristic code to print only 'verbose' times a messages
    # The challenge is that we may not know the queue length
    if self._original_iterable:
        # Still dispatching, so the total number of jobs is not known yet.
        if _verbosity_filter(index, self.verbose):
            return
        self._print('Done %3i jobs | elapsed: %s',
                    (index + 1,
                     short_format_time(elapsed_time),
                     ))
    else:
        # We are finished dispatching
        queue_length = self.n_dispatched
        # We always display the first loop
        if index != 0:
            # Display depending on the number of remaining items
            # A message as soon as we finish dispatching, cursor is 0
            cursor = (queue_length - index + 1
                      - self._pre_dispatch_amount)
            frequency = (queue_length // self.verbose) + 1
            is_last_item = (index + 1 == queue_length)
            # NOTE(review): as written, the last item returns here without
            # printing; confirm against the joblib version being quoted.
            if (is_last_item or cursor % frequency):
                return
        remaining_time = (elapsed_time / (index + 1) *
                          (self.n_dispatched - index - 1.))
        self._print('Done %3i out of %3i | elapsed: %s remaining: %s',
                    (index + 1,
                     queue_length,
                     short_format_time(elapsed_time),
                     short_format_time(remaining_time),
                     ))
The way they implement this is kind of weird, to be honest: it seems to assume that tasks will always be completed in the order they were started. The `index` variable that goes to `print_progress` is just the value of `self.n_dispatched` at the time the job was actually started. So the first job launched will always finish with an `index` of 0, even if, say, the third job finished first. It also means they don't actually keep track of the number of completed jobs, so there's no instance variable for you to monitor.
I think your best bet is to write your own `CallBack` class and monkey-patch it into `joblib.parallel` so that `Parallel` uses it:
import weakref
from math import sqrt

import joblib.parallel
from joblib import Parallel, delayed


class CallBack(object):
    """Drop-in replacement for joblib's CallBack that counts completed jobs.

    Completions are counted per ``Parallel`` instance, regardless of the
    order in which the jobs were dispatched.
    """

    # Completed-job counts keyed by Parallel instance. Weak keys let a finished
    # Parallel object be garbage-collected; a plain class-level dict would keep
    # every Parallel ever used alive for the life of the process.
    completed = weakref.WeakKeyDictionary()

    def __init__(self, index, parallel):
        self.index = index
        self.parallel = parallel

    def __call__(self, out):
        # ``out`` is the finished job's result; it is not needed here.
        count = CallBack.completed.get(self.parallel, 0) + 1
        CallBack.completed[self.parallel] = count
        print("done with {}".format(count))
        # Keep joblib's own behaviour: dispatch the next task while the
        # input iterable still has items.
        if self.parallel._original_iterable:
            self.parallel.dispatch_next()


# Monkey-patch: Parallel's dispatch code looks up CallBack in joblib.parallel,
# so this class is used for every job. NOTE: this only applies to old joblib
# versions that define joblib.parallel.CallBack; newer versions appear to use
# BatchCompletionCallBack instead.
joblib.parallel.CallBack = CallBack

if __name__ == "__main__":
    print(Parallel(n_jobs=2)(delayed(sqrt)(i**2) for i in range(10)))
Output:
done with 1
done with 2
done with 3
done with 4
done with 5
done with 6
done with 7
done with 8
done with 9
done with 10
[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
That way, your callback is called instead of the default one whenever a job completes.