Easy parallelization of numpy.apply_along_axis()?
Alright, I worked it out: the idea is to use the standard multiprocessing module and to split the original array into just a few chunks (so as to limit communication overhead with the workers). This can be done relatively easily as follows:
import multiprocessing

import numpy as np

def parallel_apply_along_axis(func1d, axis, arr, *args, **kwargs):
    """
    Like numpy.apply_along_axis(), but takes advantage of multiple
    cores.
    """
    # Effective axis where apply_along_axis() will be applied by each
    # worker (any non-zero axis number would work, so as to allow the use
    # of `np.array_split()`, which is only done on axis 0):
    effective_axis = 1 if axis == 0 else axis
    if effective_axis != axis:
        arr = arr.swapaxes(axis, effective_axis)

    # Chunks for the mapping (only a few chunks):
    chunks = [(func1d, effective_axis, sub_arr, args, kwargs)
              for sub_arr in np.array_split(arr, multiprocessing.cpu_count())]

    pool = multiprocessing.Pool()
    individual_results = pool.map(unpacking_apply_along_axis, chunks)
    # Freeing the workers:
    pool.close()
    pool.join()

    result = np.concatenate(individual_results)
    # Undoing the axis swap, so that the axis order matches that of
    # numpy.apply_along_axis() (this only applies when func1d returns an
    # array, since a scalar result drops the applied axis entirely):
    if effective_axis != axis and result.ndim == arr.ndim:
        result = result.swapaxes(axis, effective_axis)
    return result
where the function unpacking_apply_along_axis() applied in Pool.map() is defined separately, as it should be (so that subprocesses can import it); it is simply a thin wrapper that works around the fact that Pool.map() only accepts a single argument:
def unpacking_apply_along_axis((func1d, axis, arr, args, kwargs)):
    """
    Like numpy.apply_along_axis(), but with arguments in a tuple
    instead.

    This function is useful with multiprocessing.Pool().map(): (1)
    map() only handles functions that take a single argument, and (2)
    this function can generally be imported from a module, as required
    by map().
    """
    return np.apply_along_axis(func1d, axis, arr, *args, **kwargs)
(in Python 3, this should be written as

def unpacking_apply_along_axis(all_args):
    (func1d, axis, arr, args, kwargs) = all_args
    return np.apply_along_axis(func1d, axis, arr, *args, **kwargs)

because tuple parameter unpacking was removed from function signatures).
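As a quick sanity check, usage could look like the following minimal example (the sum_of_squares() function here is just an illustration, not part of the code above; note that it must live at module level, and that the __main__ guard is needed so that worker processes can import the module safely):

def sum_of_squares(vector):
    # Example 1D -> scalar function, defined at module level so that
    # the worker processes can import it:
    return (vector ** 2).sum()

if __name__ == '__main__':
    data = np.random.rand(1000, 500)
    result = parallel_apply_along_axis(sum_of_squares, 0, data)
    # The parallel result should match the serial version:
    assert np.allclose(result, np.apply_along_axis(sum_of_squares, 0, data))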
In my particular case, this resulted in a 2x speedup on a 2-core machine with hyper-threading. A factor closer to 4x would have been nicer, but the speedup obtained in just a few lines of code is already good, and it should be better on machines with more cores (which are quite common). Maybe there is a way of avoiding data copies and using shared memory (maybe through the multiprocessing module itself)?
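For what it's worth, here is an untested sketch of that shared-memory idea, using the multiprocessing.shared_memory module (available since Python 3.8). The function names shared_parallel_apply_along_axis() and _shm_apply_along_axis() are hypothetical, and only the array data (not the results) avoids copying:

import multiprocessing
from multiprocessing import shared_memory

import numpy as np

def _shm_apply_along_axis(args):
    # Hypothetical worker: attaches to the shared block by name, so
    # the array data is never pickled and sent through a pipe:
    (func1d, axis, shm_name, shape, dtype, start, stop) = args
    shm = shared_memory.SharedMemory(name=shm_name)
    try:
        arr = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
        # Each worker only reads its own slice along axis 0:
        return np.apply_along_axis(func1d, axis, arr[start:stop])
    finally:
        shm.close()

def shared_parallel_apply_along_axis(func1d, axis, arr):
    # Same axis swap trick as above, so that chunking can be done
    # along axis 0:
    effective_axis = 1 if axis == 0 else axis
    if effective_axis != axis:
        arr = arr.swapaxes(axis, effective_axis)

    # Single copy of the data into shared memory (the workers then
    # read it in place):
    shm = shared_memory.SharedMemory(create=True, size=arr.nbytes)
    try:
        shared_arr = np.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf)
        shared_arr[...] = arr

        # Chunk boundaries along axis 0 (empty chunks are skipped):
        n_chunks = multiprocessing.cpu_count()
        bounds = np.linspace(0, arr.shape[0], n_chunks + 1).astype(int)
        chunks = [(func1d, effective_axis, shm.name, arr.shape,
                   arr.dtype.str, bounds[i], bounds[i + 1])
                  for i in range(n_chunks) if bounds[i] < bounds[i + 1]]

        with multiprocessing.Pool() as pool:
            results = pool.map(_shm_apply_along_axis, chunks)
        result = np.concatenate(results)
    finally:
        shm.close()
        shm.unlink()

    # Same axis restoration as in parallel_apply_along_axis():
    if effective_axis != axis and result.ndim == arr.ndim:
        result = result.swapaxes(axis, effective_axis)
    return result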