
Find minimum distances between groups of points in 2D (fast and not too memory consuming)


The best approach is to use a data structure designed specifically for nearest-neighbor search, such as a k-d tree. For example, SciPy's cKDTree lets you solve the problem this way:

```python
from scipy.spatial import cKDTree

min_dists, min_dist_idx = cKDTree(B).query(A, 1)
```
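A minimal self-contained check of the call above. It assumes NumPy and SciPy are installed and uses small random arrays `A` and `B` of shape (n, 2); the sizes and seed are arbitrary choices for illustration. `query(A, 1)` returns one distance and one index into `B` per row of `A`, which is compared here against a brute-force broadcasting result:

```python
import numpy as np
from scipy.spatial import cKDTree

rng = np.random.default_rng(0)
A = rng.uniform(0., 5000., (200, 2))   # query points
B = rng.uniform(0., 5000., (300, 2))   # reference points

# Build the tree on B once, then find the single nearest neighbor in B of each point in A
min_dists, min_dist_idx = cKDTree(B).query(A, 1)

# Brute-force reference: full (len(A), len(B)) distance matrix via broadcasting
full = np.linalg.norm(A[:, None, :] - B[None, :, :], axis=-1)
```

The brute-force matrix is fine at this size, but it is exactly the object that becomes infeasible for large inputs.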

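This is much more efficient than any approach based on broadcasting, in terms of both computation and memory: broadcasting materializes a distance for every pair of points, whereas building the tree on the M points of B takes roughly O(M log M) time, and each query then typically takes about O(log M) time for low-dimensional data like this.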
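For scale, here is a rough estimate. It assumes float64 values (8 bytes each) and N = 10^6 points in both A and B, as in the timing below. A broadcasting approach has to hold at least the full N x N distance matrix:

```latex
N \cdot N \cdot 8\,\mathrm{B} = 10^{6} \cdot 10^{6} \cdot 8\,\mathrm{B} = 8 \times 10^{12}\,\mathrm{B} \approx 8\,\mathrm{TB}
```

The k-d tree, by contrast, only stores the input points plus O(N) tree nodes, and the query returns two length-N arrays.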

For example, even with 1,000,000 points in each of A and B, the computation does not run out of memory and takes only a few seconds on my laptop:

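```python
import numpy as np
from scipy.spatial import cKDTree

N = 1000000
A = np.random.uniform(0., 5000., (N, 2))
B = np.random.uniform(0., 5000., (N, 2))

# IPython/Jupyter magic; in a plain script, use the timeit module instead
%timeit cKDTree(B).query(A, 1)
# 3.25 s ± 17.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```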


The trick here is to maximize the ratio of compute to memory. The output has length N: one index and one distance for each point in A. We can reduce the work to a single loop that produces one output element per iteration. Each iteration processes all the points in B while allocating only one length-len(B) temporary, which gives a high compute-to-memory ratio.

Thus, leveraging einsum and matrix multiplication (inspired by this post), for each point pt in A we get the squared Euclidean distances to all points in B, like so -

```python
for pt in A:
    d = np.einsum('ij,ij->i',B,B) + pt.dot(pt) - 2*B.dot(pt)
```
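The expression in the loop expands the squared Euclidean distance between a point p (`pt`) and each row b_j of B:

```latex
\lVert b_j - p \rVert^2 = \lVert b_j \rVert^2 + \lVert p \rVert^2 - 2\, b_j \cdot p
```

Here `np.einsum('ij,ij->i',B,B)` gives the vector of all ||b_j||^2, `pt.dot(pt)` is ||p||^2, and `B.dot(pt)` computes every dot product b_j · p in a single matrix-vector product.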

Generalizing this to cover all points in A, and pre-computing np.einsum('ij,ij->i',B,B) once outside the loop, we get an implementation like so -

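```python
import numpy as np

N = A.shape[0]
min_idx = np.empty(N, dtype=int)
min_dist = np.empty(N)
Bsqsum = np.einsum('ij,ij->i',B,B)  # ||b_j||^2, computed once outside the loop
for i,pt in enumerate(A):
    d = Bsqsum + pt.dot(pt) - 2*B.dot(pt)  # squared distances from pt to all of B
    min_idx[i] = d.argmin()
    min_dist[i] = d[min_idx[i]]
# Clip tiny negative values caused by floating-point cancellation before sqrt
min_dist = np.sqrt(np.maximum(min_dist, 0))
```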
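The loop above can be wrapped in a reusable function and checked against a brute-force result. This is a sketch: the name `nearest_by_loop` is hypothetical, and the random test data is arbitrary.

```python
import numpy as np

def nearest_by_loop(A, B):
    """Hypothetical helper: for each row of A, return (distance, index) of its nearest row in B."""
    N = A.shape[0]
    min_idx = np.empty(N, dtype=int)
    min_dist = np.empty(N)
    Bsqsum = np.einsum('ij,ij->i', B, B)          # ||b_j||^2 for every row of B, computed once
    for i, pt in enumerate(A):
        d = Bsqsum + pt.dot(pt) - 2 * B.dot(pt)   # squared distances from pt to all of B
        min_idx[i] = d.argmin()
        min_dist[i] = d[min_idx[i]]
    # Clip rounding noise below zero, then convert squared distances to distances
    return np.sqrt(np.maximum(min_dist, 0)), min_idx

rng = np.random.default_rng(1)
A = rng.uniform(0., 5000., (100, 2))
B = rng.uniform(0., 5000., (150, 2))
dist, idx = nearest_by_loop(A, B)

# Brute-force reference via broadcasting, for comparison only
full = np.linalg.norm(A[:, None, :] - B[None, :, :], axis=-1)
```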

Working in chunks

Now, a fully vectorized solution would be -

```python
np.einsum('ij,ij->i',B,B)[:,None] + np.einsum('ij,ij->i',A,A) - 2*B.dot(A.T)
```
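Below is a small sanity check of the fully vectorized expression, assuming small random inputs (sizes chosen arbitrarily). Note the orientation: the result has shape (len(B), len(A)), so the nearest point of B for each point of A is the argmin along axis 0.

```python
import numpy as np

rng = np.random.default_rng(2)
A = rng.uniform(0., 5000., (40, 2))
B = rng.uniform(0., 5000., (60, 2))

# Squared distances, shape (len(B), len(A)): entry [j, i] = ||b_j - a_i||^2
D2 = np.einsum('ij,ij->i', B, B)[:, None] + np.einsum('ij,ij->i', A, A) - 2 * B.dot(A.T)

min_idx = D2.argmin(axis=0)                         # nearest B index for each point of A
min_dist = np.sqrt(np.maximum(D2.min(axis=0), 0))   # clip rounding noise before sqrt

# Brute-force squared distances with the same (len(B), len(A)) orientation
ref = ((B[:, None, :] - A[None, :, :]) ** 2).sum(-1)
```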

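This, however, builds the full len(B) x len(A) matrix at once, which is what runs out of memory for large inputs. So, to work in chunks, we slice out rows of A; the easiest way to do so is to reshape A to 3D, like so -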

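```python
import numpy as np

N = A.shape[0]
chunk_size = 100  # Edit this as per memory setup available;
                  # more means more memory needed
# This version assumes N is an exact multiple of chunk_size
assert N % chunk_size == 0, "N must be divisible by chunk_size"

A3 = A.reshape(N//chunk_size, chunk_size, -1)  # (n_chunks, chunk_size, 2); A itself is left unchanged
min_idx = np.empty((N//chunk_size, chunk_size), dtype=int)
min_dist = np.empty((N//chunk_size, chunk_size))
Bsqsum = np.einsum('ij,ij->i',B,B)[:,None]  # (len(B), 1)
r = np.arange(chunk_size)
for i,chnk in enumerate(A3):
    # (len(B), chunk_size) squared distances between all of B and this chunk of A
    d = Bsqsum + np.einsum('ij,ij->i',chnk,chnk) - 2*B.dot(chnk.T)
    idx = d.argmin(0)  # nearest B index for each point in the chunk
    min_idx[i] = idx
    min_dist[i] = d[idx,r]
# Clip tiny negative values from floating-point cancellation, then flatten back to (N,)
min_dist = np.sqrt(np.maximum(min_dist, 0)).ravel()
min_idx = min_idx.ravel()
```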
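The reshape approach requires N to be an exact multiple of chunk_size. Here is a sketch of the same chunked computation that slices A by row ranges instead, so the last chunk may simply be shorter. The function name `nearest_chunked` is hypothetical.

```python
import numpy as np

def nearest_chunked(A, B, chunk_size=100):
    """Hypothetical helper: nearest row of B for each row of A, processing A in chunks.

    Temporary memory per chunk is on the order of len(B) * chunk_size float64 values.
    """
    N = A.shape[0]
    min_idx = np.empty(N, dtype=int)
    min_dist = np.empty(N)
    Bsqsum = np.einsum('ij,ij->i', B, B)[:, None]        # (len(B), 1)
    for start in range(0, N, chunk_size):
        chnk = A[start:start + chunk_size]                # last chunk may be shorter
        d = Bsqsum + np.einsum('ij,ij->i', chnk, chnk) - 2 * B.dot(chnk.T)
        idx = d.argmin(axis=0)                            # nearest B index per column (per point)
        stop = start + len(chnk)
        min_idx[start:stop] = idx
        min_dist[start:stop] = d[idx, np.arange(len(chnk))]
    return np.sqrt(np.maximum(min_dist, 0)), min_idx

rng = np.random.default_rng(3)
A = rng.uniform(0., 5000., (250, 2))   # 250 is deliberately not a multiple of 100
B = rng.uniform(0., 5000., (120, 2))
dist, idx = nearest_chunked(A, B, chunk_size=100)

# Brute-force reference, for comparison only
full = np.linalg.norm(A[:, None, :] - B[None, :, :], axis=-1)
```

Slicing avoids both the divisibility requirement and any reshaping of the caller's array, at the cost of a few extra index computations per chunk.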