
Easy parallelization of numpy.apply_along_axis()?


Alright, I worked it out: an idea is to use the standard multiprocessing module and split the original array into just a few chunks (so as to limit communication overhead with the workers). This can be done relatively easily as follows:

import multiprocessing
import numpy as np

def parallel_apply_along_axis(func1d, axis, arr, *args, **kwargs):
    """
    Like numpy.apply_along_axis(), but takes advantage of multiple
    cores.
    """
    # Effective axis where apply_along_axis() will be applied by each
    # worker (any non-zero axis number would work, so as to allow the use
    # of `np.array_split()`, which is only done on axis 0):
    effective_axis = 1 if axis == 0 else axis
    if effective_axis != axis:
        arr = arr.swapaxes(axis, effective_axis)

    # Chunks for the mapping (only a few chunks):
    chunks = [(func1d, effective_axis, sub_arr, args, kwargs)
              for sub_arr in np.array_split(arr, multiprocessing.cpu_count())]

    pool = multiprocessing.Pool()
    individual_results = pool.map(unpacking_apply_along_axis, chunks)
    # Freeing the workers:
    pool.close()
    pool.join()

    return np.concatenate(individual_results)

where the function unpacking_apply_along_axis() applied in Pool.map() is defined separately, as it should be (so that subprocesses can import it); it is simply a thin wrapper that works around the fact that Pool.map() only accepts functions of a single argument:

def unpacking_apply_along_axis((func1d, axis, arr, args, kwargs)):
    """
    Like numpy.apply_along_axis(), but with arguments in a tuple
    instead.

    This function is useful with multiprocessing.Pool().map(): (1)
    map() only handles functions that take a single argument, and (2)
    this function can generally be imported from a module, as required
    by map().
    """
    return np.apply_along_axis(func1d, axis, arr, *args, **kwargs)

(in Python 3, this should be written as

def unpacking_apply_along_axis(all_args):
    (func1d, axis, arr, args, kwargs) = all_args

because tuple parameter unpacking was removed from the language (PEP 3113); the rest of the function is unchanged).
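
For reference, here is a minimal usage sketch (the array shape and the choice of np.sum as func1d are purely illustrative; as noted above, both functions should live in an importable module so that the worker processes can find them):

import numpy as np

if __name__ == '__main__':
    data = np.random.rand(1000, 10000)
    # Should give the same result as np.apply_along_axis(np.sum, 1, data),
    # but computed across all available cores:
    result = parallel_apply_along_axis(np.sum, 1, data)
    print(result.shape)  # (1000,)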

In my particular case, this approach resulted in a 2x speed-up on 2 cores with hyper-threading. A factor closer to 4x would have been nicer, but 2x is already good for just a few lines of code, and it should scale better on machines with more cores (which are quite common). Maybe there is also a way of avoiding data copies by using shared memory (perhaps through the multiprocessing module itself)?
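
On Python 3.8+, one option along those lines is the standard multiprocessing.shared_memory module, which lets workers attach to the same buffer by name instead of receiving pickled copies of their chunk. The sketch below only illustrates the idea (the helper name _worker, the 4-chunk split, and the hard-coded np.sum reduction are assumptions for the example, not part of the code above):

from multiprocessing import Pool, shared_memory
import numpy as np

def _worker(args):
    # Re-attach to the shared block by name and view it as an array;
    # each worker reduces its own range of rows without copying them:
    shm_name, shape, dtype, start, stop = args
    shm = shared_memory.SharedMemory(name=shm_name)
    arr = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
    result = arr[start:stop].sum(axis=1)  # illustrative func1d: np.sum
    del arr  # drop the view before closing the block
    shm.close()
    return result

if __name__ == '__main__':
    data = np.random.rand(1000, 10000)

    # Copy the data once into a named shared-memory block:
    shm = shared_memory.SharedMemory(create=True, size=data.nbytes)
    shared = np.ndarray(data.shape, dtype=data.dtype, buffer=shm.buf)
    shared[:] = data

    # A few row ranges, one per chunk (4 chunks here):
    bounds = np.linspace(0, data.shape[0], 5, dtype=int)
    tasks = [(shm.name, data.shape, data.dtype, lo, hi)
             for lo, hi in zip(bounds[:-1], bounds[1:])]

    with Pool() as pool:
        result = np.concatenate(pool.map(_worker, tasks))

    del shared  # release the view before freeing the block
    shm.close()
    shm.unlink()

Only the block name, shape, dtype and row bounds are pickled and sent to each worker, so the per-chunk data copies of the Pool.map() version above are avoided (at the cost of one initial copy into the shared block).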