From 26c78ced80f393c3b7d90b827d27c0a0176e23e0 Mon Sep 17 00:00:00 2001 From: Bjotori <120367777+Bjotori@users.noreply.github.com> Date: Mon, 12 Dec 2022 09:34:05 +0100 Subject: [PATCH] fix when points are np.ndarray fix sort with np.ndarray + fix warning 'np.bool_' scalars to be interpreted as an index --- kd_tree.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/kd_tree.py b/kd_tree.py index 5b07fda..cd871b7 100644 --- a/kd_tree.py +++ b/kd_tree.py @@ -39,7 +39,10 @@ def __init__(self, points, dim, dist_sq_func=None): def make(points, i=0): if len(points) > 1: - points.sort(key=lambda x: x[i]) + try: + points.sort(key=lambda x: x[i]) # list|tuple + except TypeError: + points.sort(axis=i) # np.ndarray i = (i + 1) % dim m = len(points) >> 1 return [make(points[:m], i), make(points[m + 1:], i), @@ -68,7 +71,7 @@ def get_knn(node, point, k, return_dist_sq, heap, i=0, tiebreaker=1): i = (i + 1) % dim # Goes into the left branch, then the right branch if needed for b in (dx < 0, dx >= 0)[:1 + (dx * dx < -heap[0][0])]: - get_knn(node[b], point, k, return_dist_sq, + get_knn(node[int(b)], point, k, return_dist_sq, heap, i, (tiebreaker << 1) | b) if tiebreaker == 1: return [(-h[0], h[2]) if return_dist_sq else h[2]