Fast Python NumPy `where` functionality?
It turns out that a pure Python loop can be much faster than NumPy indexing (or calls to `np.where`) in this case.
Consider the following alternatives:
```python
import collections

import numpy as np

shape = (2600, 5200)
# shape = (26, 52)
emiss_data = np.random.random(shape)
# IDs 1..800 (randint's upper bound is exclusive; random_integers is deprecated)
obj_data = np.random.randint(1, 801, size=shape)
UNIQ_IDS = np.unique(obj_data)

def using_where():
    max = np.max      # local aliases avoid repeated attribute lookups
    where = np.where
    MAX_EMISS = [max(emiss_data[where(obj_data == i)]) for i in UNIQ_IDS]
    return MAX_EMISS

def using_index():
    max = np.max
    MAX_EMISS = [max(emiss_data[obj_data == i]) for i in UNIQ_IDS]
    return MAX_EMISS

def using_max():
    MAX_EMISS = [(emiss_data[obj_data == i]).max() for i in UNIQ_IDS]
    return MAX_EMISS

def using_loop():
    # Single pass: bucket every value by its ID, then take each bucket's max.
    result = collections.defaultdict(list)
    # zip is lazy in Python 3 (itertools.izip in Python 2)
    for val, idx in zip(emiss_data.ravel(), obj_data.ravel()):
        result[idx].append(val)
    return [max(result[idx]) for idx in UNIQ_IDS]

def using_sort():
    # Position of each element's ID within UNIQ_IDS (0 .. len(UNIQ_IDS) - 1)
    uind = np.digitize(obj_data.ravel(), UNIQ_IDS) - 1
    # Flat indices ordered so that each group's members are contiguous
    vals = uind.argsort()
    count = np.bincount(uind)
    start = 0
    end = 0
    out = np.empty(count.shape[0])
    for ind, x in np.ndenumerate(count):
        end += x
        out[ind] = np.max(np.take(emiss_data, vals[start:end]))
        start += x
    return out

def using_split():
    uind = np.digitize(obj_data.ravel(), UNIQ_IDS) - 1
    vals = uind.argsort()
    count = np.bincount(uind)
    return [np.take(emiss_data, item).max()
            for item in np.split(vals, count.cumsum())[:-1]]

# using_where returns a list while using_sort returns an array,
# so compare numerically rather than with ==.
for func in (using_index, using_max, using_loop, using_sort, using_split):
    assert np.allclose(using_where(), func())
```
Here are the benchmarks, with `shape = (2600, 5200)`:
```
In [57]: %timeit using_loop()
1 loops, best of 3: 9.15 s per loop

In [90]: %timeit using_sort()
1 loops, best of 3: 9.33 s per loop

In [91]: %timeit using_split()
1 loops, best of 3: 9.33 s per loop

In [61]: %timeit using_index()
1 loops, best of 3: 63.2 s per loop

In [62]: %timeit using_max()
1 loops, best of 3: 64.4 s per loop

In [58]: %timeit using_where()
1 loops, best of 3: 112 s per loop
```
Thus `using_loop` (pure Python) turns out to be about 12x faster than `using_where` (9.15 s vs. 112 s).
I'm not entirely sure why pure Python is faster than NumPy here. My guess is that the pure Python version zips (yes, pun intended) through both arrays once, whereas each NumPy version evaluates `obj_data == i` over the entire array for every one of the ~800 IDs. Despite all the fancy indexing, we really just want to visit each value once, so the single pass side-steps the cost of repeatedly determining which group each value in `emiss_data` falls in. But this is just speculation; I didn't know it would be faster until I benchmarked.
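The "visit each value once" idea can also be expressed without a Python-level loop. Below is a minimal sketch, not benchmarked here, that swaps in `np.maximum.at` (an unbuffered in-place ufunc method) to fold every value into its group's running maximum in a single pass. It assumes IDs are non-negative integers small enough to index an array directly (true for the 1..800 IDs above); the helper name `grouped_max_single_pass` and the toy arrays are hypothetical, and `ufunc.at` performance varies noticeably between NumPy versions.

```python
import numpy as np

def grouped_max_single_pass(emiss, ids):
    """Max of `emiss` per distinct ID, in one pass (hypothetical helper).

    Assumes `ids` holds non-negative integers usable as array indices.
    """
    flat_vals = emiss.ravel()
    flat_ids = ids.ravel()
    # One slot per possible ID, initialised to -inf so any real value wins.
    out = np.full(flat_ids.max() + 1, -np.inf)
    # Unbuffered update: out[id] = max(out[id], value), correct for repeated IDs.
    np.maximum.at(out, flat_ids, flat_vals)
    # Keep only IDs that actually occur, in sorted order (same order as np.unique).
    return out[np.unique(flat_ids)]

emiss = np.array([[0.1, 0.5, 0.4], [0.3, 0.2, 0.9]])
ids = np.array([[1, 2, 3], [1, 2, 1]])
print(grouped_max_single_pass(emiss, ids))  # [0.9 0.5 0.4]
```

The result is ordered like `UNIQ_IDS`, so it could be compared against `using_where()` with `np.allclose`.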
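You can get the group indices from `np.unique` with `return_inverse=True` (the commented-out line below); the active line uses `np.digitize` instead: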
```python
def using_sort():
    # UNIQ_IDS, uind = np.unique(obj_data.ravel(), return_inverse=True)
    uind = np.digitize(obj_data.ravel(), UNIQ_IDS) - 1
    vals = uind.argsort()
    count = np.bincount(uind)
    start = 0
    end = 0
    out = np.empty(count.shape[0])
    for ind, x in np.ndenumerate(count):
        end += x
        out[ind] = np.max(np.take(emiss_data, vals[start:end]))
        start += x
    return out
```
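To see why `return_inverse` can stand in for the `np.digitize(...) - 1` line, here is a small check on made-up data (a sketch, not part of the answer): for sorted unique IDs, both give each element's position in `UNIQ_IDS`. Flattening with `.ravel()` before calling `np.unique` keeps the inverse 1-D; recent NumPy releases may otherwise return it in the input's shape.

```python
import numpy as np

obj_data = np.array([[5, 2, 9], [2, 9, 5], [7, 5, 2]])

# digitize route: sorted unique IDs, then bin each flattened element.
UNIQ_IDS = np.unique(obj_data)
uind_digitize = np.digitize(obj_data.ravel(), UNIQ_IDS) - 1

# np.unique route: get each element's group index directly.
uniq, uind_unique = np.unique(obj_data.ravel(), return_inverse=True)

print(UNIQ_IDS)                                     # [2 5 7 9]
print(uind_unique)                                  # [1 0 3 0 3 1 2 1 0]
print(np.array_equal(uind_digitize, uind_unique))  # True
```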
Using @unutbu's answer as a baseline, for `shape = (2600, 5200)`:
```
np.allclose(using_loop(), using_sort())
True

%timeit using_loop()
1 loops, best of 3: 12.3 s per loop

# With np.unique inside the definition
%timeit using_sort()
1 loops, best of 3: 9.06 s per loop

# With np.unique outside the definition
%timeit using_sort()
1 loops, best of 3: 2.75 s per loop

# Using @Jamie's suggestion for uind
%timeit using_sort()
1 loops, best of 3: 6.74 s per loop
```
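The remaining Python-level cost in `using_sort` is the loop over groups. Once the values are sorted by group, the per-group maxima can instead be taken in one vectorized call with `np.maximum.reduceat`, which reduces contiguous segments starting at the given offsets. This is a sketch under the same assumption as the answer (every ID in the unique set occurs at least once, so no segment is empty); the function name `using_reduceat` is hypothetical and its timing has not been measured here.

```python
import numpy as np

def using_reduceat(emiss_data, obj_data):
    """Per-ID maximum via sort + segment reduction (hypothetical sketch)."""
    # Sorted unique IDs and each element's group index (ravel keeps it 1-D).
    uniq, uind = np.unique(obj_data.ravel(), return_inverse=True)
    # Reorder the values so that each group's members are contiguous.
    order = uind.argsort()
    sorted_vals = emiss_data.ravel()[order]
    # Start offset of each group's segment: 0, c0, c0 + c1, ...
    counts = np.bincount(uind)
    starts = np.concatenate(([0], counts.cumsum()[:-1]))
    # Maximum over each segment [starts[k], starts[k + 1]).
    return uniq, np.maximum.reduceat(sorted_vals, starts)

emiss = np.array([[0.1, 0.5, 0.4], [0.3, 0.2, 0.9]])
ids = np.array([[1, 2, 3], [1, 2, 1]])
ids_out, maxima = using_reduceat(emiss, ids)
print(ids_out, maxima)  # [1 2 3] [0.9 0.5 0.4]
```

The design keeps the sort-based grouping of `using_sort`; only the per-group loop is replaced by a single reduction.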
I believe the fastest way to accomplish this is to use the `groupby()` operations in the `pandas` package. Compared to @Ophion's `using_sort()` function, pandas is roughly 8.5x faster (397 ms vs. 3.39 s):
```python
import numpy as np
import pandas as pd

shape = (2600, 5200)
emiss_data = np.random.random(shape)
# IDs 1..800 (randint's upper bound is exclusive; random_integers is deprecated)
obj_data = np.random.randint(1, 801, size=shape)
UNIQ_IDS = np.unique(obj_data)

def using_sort():
    # UNIQ_IDS, uind = np.unique(obj_data.ravel(), return_inverse=True)
    uind = np.digitize(obj_data.ravel(), UNIQ_IDS) - 1
    vals = uind.argsort()
    count = np.bincount(uind)
    start = 0
    end = 0
    out = np.empty(count.shape[0])
    for ind, x in np.ndenumerate(count):
        end += x
        out[ind] = np.max(np.take(emiss_data, vals[start:end]))
        start += x
    return out

def using_pandas():
    return pd.Series(emiss_data.ravel()).groupby(obj_data.ravel()).max()

print('same results:', np.allclose(using_pandas(), using_sort()))
# same results: True

%timeit using_sort()
# 1 loops, best of 3: 3.39 s per loop
%timeit using_pandas()
# 1 loops, best of 3: 397 ms per loop
```
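For readers less familiar with pandas, here is a tiny illustration (made-up data, not from the answer) of what the `using_pandas` one-liner returns: a Series indexed by the IDs, which `groupby` sorts by default. That sorted order matches `np.unique`, which is why `np.allclose` can compare it directly with the output of `using_sort()`.

```python
import numpy as np
import pandas as pd

emiss_data = np.array([[0.1, 0.5, 0.4], [0.3, 0.2, 0.9]])
obj_data = np.array([[1, 2, 3], [1, 2, 1]])

# Group the flattened values by the flattened IDs and take each group's max.
result = pd.Series(emiss_data.ravel()).groupby(obj_data.ravel()).max()

print(result.index.tolist())  # [1, 2, 3]
print(result.to_numpy())      # [0.9 0.5 0.4]
```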