How do I get indices of N maximum values in a NumPy array?


Newer NumPy versions (1.8 and up) have a function called argpartition for this. To get the indices of the four largest elements, do

>>> import numpy as np
>>> a = np.array([9, 4, 4, 3, 3, 9, 0, 4, 6, 0])
>>> a
array([9, 4, 4, 3, 3, 9, 0, 4, 6, 0])
>>> ind = np.argpartition(a, -4)[-4:]
>>> ind
array([1, 5, 8, 0])
>>> a[ind]
array([4, 9, 6, 9])

Unlike argsort, this function runs in linear time in the worst case, but the returned indices are not sorted, as can be seen from the result of evaluating a[ind]. If you need them sorted as well, sort them afterwards:

>>> ind[np.argsort(a[ind])]
array([1, 8, 5, 0])

Getting the top-k elements in sorted order this way takes O(n + k log k) time.
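For convenience, the two steps can be wrapped in a small helper. This is only a sketch: the name top_k_indices is made up for illustration (it is not a NumPy function), and it assumes a 1-D array with 1 <= k <= a.size. It returns the indices ordered from largest value to smallest, so the final sort is reversed compared to the snippet above.

```python
import numpy as np

def top_k_indices(a, k):
    """Hypothetical helper: indices of the k largest values of a 1-D array,
    ordered from largest to smallest. Assumes 1 <= k <= a.size."""
    a = np.asarray(a)
    # Partition so that the k largest values land in the last k slots
    # (in no particular order among themselves).
    ind = np.argpartition(a, -k)[-k:]
    # Sort only those k candidates by value, then reverse for descending order.
    return ind[np.argsort(a[ind])[::-1]]

a = np.array([9, 4, 4, 3, 3, 9, 0, 4, 6, 0])
top = top_k_indices(a, 4)
print(a[top])  # values come out largest first: [9 9 6 4]
```

Note that when values are tied (the two 9s, or the three 4s here), which of the tied indices you get back, and in what order, is not guaranteed.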


The simplest I've been able to come up with is:

In [1]: import numpy as np
In [2]: arr = np.array([1, 3, 2, 4, 5])
In [3]: arr.argsort()[-3:][::-1]
Out[3]: array([4, 3, 1])

This involves a complete sort of the array. I wonder if numpy provides a built-in way to do a partial sort; so far I haven't been able to find one.

If this solution turns out to be too slow (especially for small n), it may be worth looking at coding something up in Cython.


Simpler yet:

idx = (-arr).argsort()[:n]

where n is the number of maximum values.
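As a quick check, here is that one-liner on the same array as in the argsort answer above. The variable names are just for illustration. Negating the array means the ascending sort puts the largest values first, so no slicing from the end or reversing is needed. Like the plain argsort approach, this still sorts the whole array.

```python
import numpy as np

arr = np.array([1, 3, 2, 4, 5])
n = 3  # number of maximum values wanted

# Sorting the negated values ascending is the same as sorting arr descending,
# so the first n positions are the indices of the n largest elements.
idx = (-arr).argsort()[:n]

print(idx)       # [4 3 1]
print(arr[idx])  # [5 4 3]
```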