Sort a numpy array by another array, along a particular axis
You still have to supply indices for the other two dimensions for this to work correctly.
>>> import numpy
>>> a = numpy.zeros((3, 3, 3))
>>> a += numpy.array((1, 3, 2)).reshape((3, 1, 1))
>>> b = numpy.arange(3*3*3).reshape((3, 3, 3))
>>> sort_indices = numpy.argsort(a, axis=0)
>>> static_indices = numpy.indices((3, 3, 3))
>>> b[sort_indices, static_indices[1], static_indices[2]]
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[18, 19, 20],
        [21, 22, 23],
        [24, 25, 26]],

       [[ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]]])
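For comparison, newer NumPy releases (1.15 and later, as far as I know) also provide numpy.take_along_axis, which pairs with argsort without any explicit index arrays for the other axes. This is an alternative to the indexing shown above, not part of it; the names via_indexing and via_take are only illustrative.

```python
import numpy as np

a = np.zeros((3, 3, 3))
a += np.array((1, 3, 2)).reshape((3, 1, 1))
b = np.arange(3 * 3 * 3).reshape((3, 3, 3))

sort_indices = np.argsort(a, axis=0)

# The explicit approach from above: identity indices for axes 1 and 2.
static_indices = np.indices((3, 3, 3))
via_indexing = b[sort_indices, static_indices[1], static_indices[2]]

# take_along_axis applies sort_indices along axis 0 and leaves the
# other axes untouched.
via_take = np.take_along_axis(b, sort_indices, axis=0)
```

Both results should be the same array, so which one to use is mostly a matter of which NumPy version you can rely on.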
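numpy.indices returns one array per axis, each with the same shape as the target array, whose entries are the position of each element along that axis (so for n axes you get n such arrays, and you replace one of them with the sort order and keep the other n - 1). In other words, this (apologies for the long post):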
>>> static_indices
array([[[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[1, 1, 1],
         [1, 1, 1],
         [1, 1, 1]],

        [[2, 2, 2],
         [2, 2, 2],
         [2, 2, 2]]],


       [[[0, 0, 0],
         [1, 1, 1],
         [2, 2, 2]],

        [[0, 0, 0],
         [1, 1, 1],
         [2, 2, 2]],

        [[0, 0, 0],
         [1, 1, 1],
         [2, 2, 2]]],


       [[[0, 1, 2],
         [0, 1, 2],
         [0, 1, 2]],

        [[0, 1, 2],
         [0, 1, 2],
         [0, 1, 2]],

        [[0, 1, 2],
         [0, 1, 2],
         [0, 1, 2]]]])
These are the identity indices for each axis; when used to index b, they recreate b.
>>> b[static_indices[0], static_indices[1], static_indices[2]]
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]],

       [[18, 19, 20],
        [21, 22, 23],
        [24, 25, 26]]])
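The same idea generalizes to any axis: start from the identity indices and swap in the argsort result for the axis being sorted. Below is a minimal sketch of that, assuming both arrays have the same shape; sort_by is a hypothetical helper name, not a NumPy function.

```python
import numpy as np

def sort_by(values, keys, axis=0):
    """Reorder `values` along `axis` by the sort order of `keys` (hypothetical helper)."""
    order = np.argsort(keys, axis=axis)
    # Identity indices for every axis, as a list so one entry can be replaced.
    index = list(np.indices(keys.shape))
    # Replace the identity index of the chosen axis with the sort order.
    index[axis] = order
    return values[tuple(index)]

a = np.zeros((3, 3, 3))
a += np.array((1, 3, 2)).reshape((3, 1, 1))
b = np.arange(3 * 3 * 3).reshape((3, 3, 3))

result = sort_by(b, a, axis=0)
```

With the a and b used in this answer, result should match the first output above: the second and third blocks of b swapped along axis 0.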
As an alternative to numpy.indices, you could use numpy.ogrid, as unutbu suggests. Since the object generated by ogrid is smaller, I'll create all three axes for consistency's sake, but note unutbu's comment for a way to do this by generating only two.
>>> static_indices = numpy.ogrid[0:a.shape[0], 0:a.shape[1], 0:a.shape[2]]
>>> a[sort_indices, static_indices[1], static_indices[2]]
array([[[ 1.,  1.,  1.],
        [ 1.,  1.,  1.],
        [ 1.,  1.,  1.]],

       [[ 2.,  2.,  2.],
        [ 2.,  2.,  2.],
        [ 2.,  2.,  2.]],

       [[ 3.,  3.,  3.],
        [ 3.,  3.,  3.]],
        [ 3.,  3.,  3.]]])
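One possible way to generate only the two index arrays that are actually needed is sketched below. This is my own reading of the idea, not necessarily what unutbu's comment proposes, and the names rows and cols are just illustrative. It works because ogrid returns "open" arrays of shape (3, 1) and (1, 3), which broadcast against the (3, 3, 3) sort_indices.

```python
import numpy as np

a = np.zeros((3, 3, 3))
a += np.array((1, 3, 2)).reshape((3, 1, 1))
b = np.arange(3 * 3 * 3).reshape((3, 3, 3))

sort_indices = np.argsort(a, axis=0)

# Only the axes that are NOT being sorted need identity indices.
rows, cols = np.ogrid[0:a.shape[1], 0:a.shape[2]]

# rows has shape (3, 1) and cols has shape (1, 3); broadcasting expands
# both to (3, 3, 3) when they are combined with sort_indices.
sorted_a = a[sort_indices, rows, cols]
sorted_b = b[sort_indices, rows, cols]
```

Compared with numpy.indices, this never materializes the full-size index arrays, which should matter more as the arrays grow.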