
fast python numpy where functionality?


It turns out that a pure Python loop can be much, much faster than NumPy indexing (or calls to np.where) in this case.

Consider the following alternatives:

import numpy as np
import collections
import itertools as IT

shape = (2600, 5200)
# shape = (26, 52)
emiss_data = np.random.random(shape)
# note: np.random.random_integers is deprecated in newer NumPy;
# np.random.randint(1, 801, size=shape) is the modern equivalent
obj_data = np.random.random_integers(1, 800, size=shape)
UNIQ_IDS = np.unique(obj_data)

def using_where():
    max = np.max
    where = np.where
    MAX_EMISS = [max(emiss_data[where(obj_data == i)]) for i in UNIQ_IDS]
    return MAX_EMISS

def using_index():
    max = np.max
    MAX_EMISS = [max(emiss_data[obj_data == i]) for i in UNIQ_IDS]
    return MAX_EMISS

def using_max():
    MAX_EMISS = [(emiss_data[obj_data == i]).max() for i in UNIQ_IDS]
    return MAX_EMISS

def using_loop():
    result = collections.defaultdict(list)
    # itertools.izip is Python 2; on Python 3 use the builtin zip instead
    for val, idx in IT.izip(emiss_data.ravel(), obj_data.ravel()):
        result[idx].append(val)
    return [max(result[idx]) for idx in UNIQ_IDS]

def using_sort():
    uind = np.digitize(obj_data.ravel(), UNIQ_IDS) - 1
    vals = uind.argsort()
    count = np.bincount(uind)
    start = 0
    end = 0
    out = np.empty(count.shape[0])
    for ind, x in np.ndenumerate(count):
        end += x
        out[ind] = np.max(np.take(emiss_data, vals[start:end]))
        start += x
    return out

def using_split():
    uind = np.digitize(obj_data.ravel(), UNIQ_IDS) - 1
    vals = uind.argsort()
    count = np.bincount(uind)
    return [np.take(emiss_data, item).max()
            for item in np.split(vals, count.cumsum())[:-1]]

# some functions return lists and others return arrays, so compare with np.allclose
for func in (using_index, using_max, using_loop, using_sort, using_split):
    assert np.allclose(using_where(), func())

Here are the benchmarks, with shape = (2600,5200):

In [57]: %timeit using_loop()
1 loops, best of 3: 9.15 s per loop

In [90]: %timeit using_sort()
1 loops, best of 3: 9.33 s per loop

In [91]: %timeit using_split()
1 loops, best of 3: 9.33 s per loop

In [61]: %timeit using_index()
1 loops, best of 3: 63.2 s per loop

In [62]: %timeit using_max()
1 loops, best of 3: 64.4 s per loop

In [58]: %timeit using_where()
1 loops, best of 3: 112 s per loop

Thus using_loop (pure Python) turns out to be more than 11x faster than using_where.

I'm not entirely sure why pure Python is faster than NumPy here. My guess is that the pure Python version zips (yes, pun intended) through both arrays exactly once, whereas the where/indexing versions scan all of obj_data once per unique ID (roughly 800 full passes over the array). The loop leverages the fact that, despite all the fancy indexing, we really just want to visit each value once, so it side-steps having to repeatedly determine which group each value in emiss_data falls in. But this is just vague speculation; I didn't know it would be faster until I benchmarked it.
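To make that single-pass intuition concrete, here is a small sketch (mine, not from the original answer) that keeps only a running maximum per object ID instead of accumulating lists; it visits each (value, ID) pair exactly once, just like using_loop. The name using_running_max is made up for illustration, and I make no claim about how it benchmarks.

def using_running_max():
    # One pass over both flattened arrays, keeping only the current max per ID.
    best = {}
    for val, idx in zip(emiss_data.ravel(), obj_data.ravel()):
        if idx not in best or val > best[idx]:
            best[idx] = val
    return [best[idx] for idx in UNIQ_IDS]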


You can use np.unique with return_inverse:

def using_sort():
    #UNIQ_IDS, uind = np.unique(obj_data, return_inverse=True)
    uind = np.digitize(obj_data.ravel(), UNIQ_IDS) - 1
    vals = uind.argsort()
    count = np.bincount(uind)
    start = 0
    end = 0
    out = np.empty(count.shape[0])
    for ind, x in np.ndenumerate(count):
        end += x
        out[ind] = np.max(np.take(emiss_data, vals[start:end]))
        start += x
    return out
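For reference, here is a sketch of what the commented-out np.unique route looks like when used directly (the wrapper name using_sort_unique is mine, not from the answer): return_inverse gives both the unique IDs and the per-element group index in a single call, and ravelling first keeps uind one-dimensional regardless of NumPy version.

def using_sort_unique():
    # Build the group labels with np.unique instead of np.digitize.
    uniq, uind = np.unique(obj_data.ravel(), return_inverse=True)
    vals = uind.argsort()
    count = np.bincount(uind)
    start = 0
    end = 0
    out = np.empty(count.shape[0])
    for ind, x in np.ndenumerate(count):
        end += x
        out[ind] = np.max(np.take(emiss_data, vals[start:end]))
        start += x
    return out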

Using @unutbu's answer as a baseline for shape = (2600,5200):

np.allclose(using_loop(), using_sort())
True

%timeit using_loop()
1 loops, best of 3: 12.3 s per loop

# With np.unique inside the definition
%timeit using_sort()
1 loops, best of 3: 9.06 s per loop

# With np.unique outside the definition
%timeit using_sort()
1 loops, best of 3: 2.75 s per loop

# Using @Jamie's suggestion for uind
%timeit using_sort()
1 loops, best of 3: 6.74 s per loop


I believe the fastest way to accomplish this is to use the groupby() operation in the pandas package. Compared to @Ophion's using_sort() function, pandas is roughly 8-9x faster (3.39 s vs. 397 ms):

import numpy as np
import pandas as pd

shape = (2600, 5200)
emiss_data = np.random.random(shape)
obj_data = np.random.random_integers(1, 800, size=shape)
UNIQ_IDS = np.unique(obj_data)

def using_sort():
    #UNIQ_IDS, uind = np.unique(obj_data, return_inverse=True)
    uind = np.digitize(obj_data.ravel(), UNIQ_IDS) - 1
    vals = uind.argsort()
    count = np.bincount(uind)
    start = 0
    end = 0
    out = np.empty(count.shape[0])
    for ind, x in np.ndenumerate(count):
        end += x
        out[ind] = np.max(np.take(emiss_data, vals[start:end]))
        start += x
    return out

def using_pandas():
    # Group the flattened emission values by the flattened object IDs and take each group's max.
    return pd.Series(emiss_data.ravel()).groupby(obj_data.ravel()).max()

print('same results:', np.allclose(using_pandas(), using_sort()))
# same results: True

%timeit using_sort()
# 1 loops, best of 3: 3.39 s per loop

%timeit using_pandas()
# 1 loops, best of 3: 397 ms per loop
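One detail worth noting (my observation, not from the answer): using_pandas() returns a pandas Series indexed by object ID rather than a plain list, so if you want a bare NumPy array aligned with UNIQ_IDS you can reindex and take the underlying values, for example:

# Hypothetical follow-up: align the grouped maxima with UNIQ_IDS and extract a NumPy array.
max_per_id = using_pandas().reindex(UNIQ_IDS).values
print(max_per_id.shape)  # one maximum per unique object ID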