Most efficient way to find mode in numpy array
Check scipy.stats.mode()
(inspired by @tom10's comment):
import numpy as np
from scipy import stats

a = np.array([[1, 3, 4, 2, 2, 7],
              [5, 2, 2, 1, 4, 1],
              [3, 3, 2, 2, 1, 1]])
m = stats.mode(a)
print(m)
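Output (from an older SciPy release; since SciPy 1.11 the default keepdims=False returns 1-D arrays here instead):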
ModeResult(mode=array([[1, 3, 2, 2, 1, 1]]), count=array([[1, 2, 2, 2, 1, 2]]))
As you can see, it returns both the modes and their counts. You can select the modes directly via m[0] (or m.mode):
print(m[0])
Output:
[[1 3 2 2 1 1]]
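If you want to sanity-check these numbers without SciPy, here is a minimal NumPy-only sketch. It assumes the array holds small non-negative integers (so numpy.bincount applies) and computes the mode of each column, i.e. along axis 0. The helper name column_mode is hypothetical, not part of any library. Ties resolve to the smallest value, which matches the output above.

```python
import numpy as np


def column_mode(a):
    """Mode and count of each column of a 2-D array of small non-negative ints.

    Hypothetical helper for illustration only. Ties resolve to the smallest
    value because numpy.argmax returns the first maximum it finds.
    """
    a = np.asarray(a)
    n_values = a.max() + 1
    n_cols = a.shape[1]
    # Shift each column's values into its own block of n_values bins,
    # so a single bincount call counts every column at once.
    offsets = np.arange(n_cols) * n_values
    counts = np.bincount((a + offsets).ravel(), minlength=n_cols * n_values)
    counts = counts.reshape(n_cols, n_values)
    modes = counts.argmax(axis=1)
    return modes, counts[np.arange(n_cols), modes]


a = np.array([[1, 3, 4, 2, 2, 7],
              [5, 2, 2, 1, 4, 1],
              [3, 3, 2, 2, 1, 1]])
modes, counts = column_mode(a)
print(modes)   # [1 3 2 2 1 1]
print(counts)  # [1 2 2 2 1 2]
```

The offset trick trades memory (n_cols * (max + 1) bins) for a single vectorized call, so it only makes sense when the values are small integers.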
Update
The scipy.stats.mode function has been significantly optimized since this post and is now the recommended method.
Old answer
This is a tricky problem, since there is not much out there for calculating the mode along an axis. The solution is straightforward for 1-D arrays, where numpy.bincount is handy, along with numpy.unique with the return_counts argument set to True. The most common n-dimensional function I see is scipy.stats.mode, although it is prohibitively slow, especially for large arrays with many unique values. As a solution, I've developed this function, and I use it heavily:
import numpy

def mode(ndarray, axis=0):
    # Check inputs
    ndarray = numpy.asarray(ndarray)
    ndim = ndarray.ndim
    if ndarray.size == 1:
        return (ndarray.flat[0], 1)
    elif ndarray.size == 0:
        raise Exception('Cannot compute mode on empty array')
    try:
        axis = range(ndarray.ndim)[axis]
    except (IndexError, TypeError):
        raise Exception('Axis "{}" incompatible with the {}-dimension array'.format(axis, ndim))

    # For 1-D arrays numpy.unique suffices (return_counts needs NumPy >= 1.9)
    if ndim == 1:
        modals, counts = numpy.unique(ndarray, return_counts=True)
        index = numpy.argmax(counts)
        return modals[index], counts[index]

    # Sort array
    sort = numpy.sort(ndarray, axis=axis)
    # Create array to transpose along the axis and get padding shape
    transpose = numpy.roll(numpy.arange(ndim)[::-1], axis)
    shape = list(sort.shape)
    shape[axis] = 1
    # Create a boolean array along strides of unique values
    strides = numpy.concatenate([numpy.zeros(shape=shape, dtype='bool'),
                                 numpy.diff(sort, axis=axis) == 0,
                                 numpy.zeros(shape=shape, dtype='bool')],
                                axis=axis).transpose(transpose).ravel()
    # Count the stride lengths
    counts = numpy.cumsum(strides)
    counts[~strides] = numpy.concatenate([[0], numpy.diff(counts[~strides])])
    counts[strides] = 0
    # Get shape of padded counts and slice to return to the original shape
    shape = numpy.array(sort.shape)
    shape[axis] += 1
    shape = shape[transpose]
    slices = [slice(None)] * ndim
    slices[axis] = slice(1, None)
    # Reshape and compute final counts (modern NumPy requires a tuple index)
    counts = counts.reshape(shape).transpose(transpose)[tuple(slices)] + 1

    # Find maximum counts and return modals/counts
    slices = [slice(None, i) for i in sort.shape]
    del slices[axis]
    index = list(numpy.ogrid[tuple(slices)])
    index.insert(axis, numpy.argmax(counts, axis=axis))
    return sort[tuple(index)], counts[tuple(index)]
Result:
In [2]: a = numpy.array([[1, 3, 4, 2, 2, 7], [5, 2, 2, 1, 4, 1], [3, 3, 2, 2, 1, 1]])

In [3]: mode(a)
Out[3]: (array([1, 3, 2, 2, 1, 1]), array([1, 2, 2, 2, 1, 2]))
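The function above boils down to sorting along the axis and then measuring the lengths of runs of equal neighbouring values. Stripped of the axis and transpose bookkeeping, the same idea on a 1-D array looks roughly like this. It is a simplified sketch, not the function itself, and mode_1d is a hypothetical name.

```python
import numpy as np


def mode_1d(x):
    """Mode of a non-empty 1-D array via sort + run lengths (ties -> smallest value)."""
    s = np.sort(np.asarray(x))
    # Indices where a new run of equal values starts.
    starts = np.flatnonzero(np.concatenate(([True], s[1:] != s[:-1])))
    # Each run ends where the next one starts (or at the end of the array).
    lengths = np.diff(np.append(starts, s.size))
    best = lengths.argmax()
    return s[starts[best]], lengths[best]


value, count = mode_1d([5, 2, 2, 1, 4, 1, 1])
print(value, count)  # 1 3
```

Because the runs come out in sorted order, argmax picks the smallest of any tied values, the same tie-breaking the n-dimensional version gets from numpy.argmax.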
Some benchmarks:
In [4]: import scipy.stats

In [5]: a = numpy.random.randint(1,10,(1000,1000))

In [6]: %timeit scipy.stats.mode(a)
10 loops, best of 3: 41.6 ms per loop

In [7]: %timeit mode(a)
10 loops, best of 3: 46.7 ms per loop

In [8]: a = numpy.random.randint(1,500,(1000,1000))

In [9]: %timeit scipy.stats.mode(a)
1 loops, best of 3: 1.01 s per loop

In [10]: %timeit mode(a)
10 loops, best of 3: 80 ms per loop

In [11]: a = numpy.random.random((200,200))

In [12]: %timeit scipy.stats.mode(a)
1 loops, best of 3: 3.26 s per loop

In [13]: %timeit mode(a)
1000 loops, best of 3: 1.75 ms per loop
EDIT: Provided more background and modified the approach to be more memory-efficient.
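For the 1-D case mentioned at the start of the old answer, either NumPy function is enough on its own. A minimal sketch with a made-up example array: numpy.bincount assumes non-negative integers, while numpy.unique with return_counts=True (NumPy >= 1.9) works for any sortable dtype.

```python
import numpy as np

x = np.array([5, 2, 2, 1, 4, 1, 1])

# numpy.bincount: counts[v] is how often the integer v occurs in x.
counts = np.bincount(x)
print(counts.argmax(), counts.max())  # 1 3

# numpy.unique: sorted unique values plus how often each occurs.
values, counts = np.unique(x, return_counts=True)
i = counts.argmax()
print(values[i], counts[i])  # 1 3
```

bincount allocates one bin per integer up to x.max(), so it is fastest for small value ranges; unique sorts instead and does not care about the range.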
Expanding on this method: if you also need the index of the mode in the original array (for example, to see how far the value is from the center of the distribution), ask numpy.unique for return_index as well:
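import numpy as np

# a is a 1-D array; idx holds the index of the first occurrence of each unique value
(_, idx, counts) = np.unique(a, return_index=True, return_counts=True)
index = idx[np.argmax(counts)]
mode = a[index]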
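Remember to discard the mode when more than one value reaches the maximum count (i.e. when np.sum(counts == counts.max()) > 1), since np.argmax only returns the first of the tied values. Also, to check whether the mode is actually representative of the center of your data's distribution, you can check whether it falls inside your standard deviation interval (e.g. within one standard deviation of the mean).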
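Both checks can be written out directly. This is a sketch on a made-up 1-D array, and it assumes the standard deviation interval means mean plus or minus one standard deviation; that reading is an assumption, not something the method above prescribes.

```python
import numpy as np

a = np.array([2.0, 3.0, 3.0, 4.0, 5.0, 3.0, 9.0])

_, idx, counts = np.unique(a, return_index=True, return_counts=True)
mode = a[idx[np.argmax(counts)]]

# Ties: if several values share the maximum count there is no single mode.
is_unique_mode = np.count_nonzero(counts == counts.max()) == 1

# Assumed interpretation of the "standard deviation interval": mean +/- 1 std.
mean, std = a.mean(), a.std()
is_central = mean - std <= mode <= mean + std

print(mode, is_unique_mode, is_central)
```

Widening the interval (for example to two standard deviations) is a judgment call that depends on how spread out your data is.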