
Tracking progress of joblib.Parallel execution


Yet another step ahead of dano's and Connor's answers is to wrap the whole thing as a context manager:

import contextlib
import joblib
from tqdm import tqdm
from joblib import Parallel, delayed

@contextlib.contextmanager
def tqdm_joblib(tqdm_object):
    """Context manager to patch joblib to report into tqdm progress bar given as argument"""
    class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
        def __call__(self, *args, **kwargs):
            # Advance the bar by the size of the batch that just completed
            tqdm_object.update(n=self.batch_size)
            return super().__call__(*args, **kwargs)

    old_batch_callback = joblib.parallel.BatchCompletionCallBack
    joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
    try:
        yield tqdm_object
    finally:
        # Restore the original callback and close the bar, even on error
        joblib.parallel.BatchCompletionCallBack = old_batch_callback
        tqdm_object.close()

Then you can use it like this, and no monkey-patched code is left behind once you're done:

from math import sqrt

with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar:
    Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10))

which I think is neat, and it looks similar to tqdm's pandas integration.
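
For comparison, here is a minimal sketch of that tqdm pandas integration (the DataFrame and column names are just illustrative):

import pandas as pd
from tqdm import tqdm

tqdm.pandas(desc="My calculation")  # registers progress_apply on pandas objects
df = pd.DataFrame({"x": range(1000)})
squares = df["x"].progress_apply(lambda x: x**2)  # like apply, but with a progress bar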


Why can't you simply use tqdm? The following worked for me:

from joblib import Parallel, delayed
from tqdm import tqdm

def myfun(x):
    return x**2

results = Parallel(n_jobs=8)(delayed(myfun)(i) for i in tqdm(range(1000)))

100%|██████████| 1000/1000 [00:00<00:00, 10563.37it/s]


The documentation you linked to states that Parallel has an optional progress meter. It's implemented by using the callback keyword argument provided by multiprocessing.Pool.apply_async:

# This is inside a dispatch function
self._lock.acquire()
job = self._pool.apply_async(SafeFunction(func), args,
            kwargs, callback=CallBack(self.n_dispatched, self))
self._jobs.append(job)
self.n_dispatched += 1

...

class CallBack(object):
    """ Callback used by parallel: it is used for progress reporting, and
        to add data to be processed
    """
    def __init__(self, index, parallel):
        self.parallel = parallel
        self.index = index

    def __call__(self, out):
        self.parallel.print_progress(self.index)
        if self.parallel._original_iterable:
            self.parallel.dispatch_next()

And here's print_progress:

def print_progress(self, index):
    elapsed_time = time.time() - self._start_time

    # This is heuristic code to print only 'verbose' times a messages
    # The challenge is that we may not know the queue length
    if self._original_iterable:
        if _verbosity_filter(index, self.verbose):
            return
        self._print('Done %3i jobs       | elapsed: %s',
                    (index + 1,
                     short_format_time(elapsed_time),
                    ))
    else:
        # We are finished dispatching
        queue_length = self.n_dispatched
        # We always display the first loop
        if not index == 0:
            # Display depending on the number of remaining items
            # A message as soon as we finish dispatching, cursor is 0
            cursor = (queue_length - index + 1
                      - self._pre_dispatch_amount)
            frequency = (queue_length // self.verbose) + 1
            is_last_item = (index + 1 == queue_length)
            if (is_last_item or cursor % frequency):
                return
        remaining_time = (elapsed_time / (index + 1) *
                    (self.n_dispatched - index - 1.))
        self._print('Done %3i out of %3i | elapsed: %s remaining: %s',
                    (index + 1,
                     queue_length,
                     short_format_time(elapsed_time),
                     short_format_time(remaining_time),
                    ))
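
To make that mechanism concrete, here is a minimal, self-contained sketch of the callback keyword on multiprocessing.Pool.apply_async; the function names are illustrative, not joblib's:

from multiprocessing import Pool

def square(x):
    return x**2

def on_done(result):
    # Runs in the parent process as each task finishes
    print("finished with result:", result)

if __name__ == "__main__":
    with Pool(processes=2) as pool:
        jobs = [pool.apply_async(square, (i,), callback=on_done) for i in range(4)]
        print([job.get() for job in jobs])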

The way they implement this is kind of weird, to be honest - it seems to assume that tasks will always be completed in the order they're started. The index variable that goes to print_progress is just the value of self.n_dispatched at the time the job was dispatched. So the first job launched will always finish with an index of 0, even if, say, the third job finishes first. It also means they don't actually keep track of the number of completed jobs, so there's no instance variable for you to monitor.
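
A quick way to see that ordering assumption break down is to make the first-dispatched task the slowest; this sketch (delays and names are illustrative) prints completions out of dispatch order:

from multiprocessing import Pool
import time

def task(i):
    time.sleep(0.5 if i == 0 else 0)  # the first-dispatched job finishes last
    return i

def on_done(i):
    print("completed job", i)

if __name__ == "__main__":
    with Pool(processes=2) as pool:
        for i in range(3):
            pool.apply_async(task, (i,), callback=on_done)
        pool.close()
        pool.join()  # typically prints jobs 1 and 2 before job 0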

I think your best bet is to make your own CallBack class and monkey patch Parallel:

from math import sqrt
from collections import defaultdict
from joblib import Parallel, delayed

class CallBack(object):
    completed = defaultdict(int)

    def __init__(self, index, parallel):
        self.index = index
        self.parallel = parallel

    def __call__(self, out):
        # out is the task result handed to the callback by apply_async
        CallBack.completed[self.parallel] += 1
        print("done with {}".format(CallBack.completed[self.parallel]))
        if self.parallel._original_iterable:
            self.parallel.dispatch_next()

import joblib.parallel
joblib.parallel.CallBack = CallBack

if __name__ == "__main__":
    print(Parallel(n_jobs=2)(delayed(sqrt)(i**2) for i in range(10)))

Output:

done with 1
done with 2
done with 3
done with 4
done with 5
done with 6
done with 7
done with 8
done with 9
done with 10
[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]

That way, your custom callback gets called whenever a job completes, rather than the default one.
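
Since CallBack.completed is keyed by the Parallel instance, you could also poll it from the main thread to drive your own progress display. A rough sketch building on the patch above (the threading setup is illustrative, and it assumes the same old joblib internals):

import threading
import time
from math import sqrt

parallel = Parallel(n_jobs=2)
worker = threading.Thread(
    target=lambda: parallel(delayed(sqrt)(i**2) for i in range(10)))
worker.start()

# Poll the shared counter while the jobs run
while worker.is_alive():
    print("completed so far:", CallBack.completed[parallel])
    time.sleep(0.1)
worker.join()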