diff --git a/kd_tree.py b/kd_tree.py index 5b07fda..cd871b7 100644 --- a/kd_tree.py +++ b/kd_tree.py @@ -39,7 +39,10 @@ def __init__(self, points, dim, dist_sq_func=None): def make(points, i=0): if len(points) > 1: - points.sort(key=lambda x: x[i]) + try: + points.sort(key=lambda x: x[i]) # list|tuple + except TypeError: + points.sort(axis=i) # np.ndarray i = (i + 1) % dim m = len(points) >> 1 return [make(points[:m], i), make(points[m + 1:], i), @@ -68,7 +71,7 @@ def get_knn(node, point, k, return_dist_sq, heap, i=0, tiebreaker=1): i = (i + 1) % dim # Goes into the left branch, then the right branch if needed for b in (dx < 0, dx >= 0)[:1 + (dx * dx < -heap[0][0])]: - get_knn(node[b], point, k, return_dist_sq, + get_knn(node[int(b)], point, k, return_dist_sq, heap, i, (tiebreaker << 1) | b) if tiebreaker == 1: return [(-h[0], h[2]) if return_dist_sq else h[2]