A Python implementation of Kd-Tree and K-nearest neighbors search
Python has no built-in n-partition function, so for now I sort each layer instead; this makes the code a little slower, but it is still fast enough to pass most test cases.
import sys
import math
import heapq
input = sys.stdin.readline
class k_d_tree:
    """k-d tree over a fixed point set, with k-nearest-neighbour search.

    The tree is implicit: it lives in the flat list ``self.tree``, where the
    node for an index range ``[left, right]`` sits at ``(left + right) // 2``
    and its two subtrees occupy the sub-ranges on either side of that index.
    The splitting coordinate cycles with depth (``depth % self.d``).

    Attributes:
        d: Number of coordinates per point, taken from the first point
            (0 when the point set is empty).
        n: Number of points.
        tree: The points arranged in implicit-tree order.
    """

    def __init__(self, points):
        """Build the tree from ``points``.

        Args:
            points: Iterable of coordinate sequences; every point is expected
                to have at least as many coordinates as the first one.  May be
                empty.  The caller's container is left untouched.
        """
        # build_tree reorders the list it is handed, so work on a shallow copy
        # instead of scrambling the caller's list.
        points = list(points)
        self.n = len(points)
        # Guard the empty case: there is no first point to measure.
        self.d = len(points[0]) if points else 0
        self.tree = [None] * self.n
        self.build_tree(points, 0, 0, self.n - 1)

    def build_tree(self, points, depth, left, right):
        """Recursively place the points of ``points[left:right + 1]``.

        The sub-range is sorted on the coordinate chosen for ``depth``, the
        middle element becomes the node for this range, and both halves are
        built one level deeper.  Each level re-sorts its sub-range rather than
        using a selection algorithm.

        Args:
            points: Mutable list of points; it is reordered in place.
            depth: Current depth, which selects the splitting coordinate.
            left: First index of the range (inclusive).
            right: Last index of the range (inclusive).
        """
        if left > right:
            return
        dimension = depth % self.d
        points[left:right + 1] = sorted(
            points[left:right + 1], key=lambda point: point[dimension]
        )
        mid = (left + right) // 2
        self.tree[mid] = points[mid]
        self.build_tree(points, depth + 1, left, mid - 1)
        self.build_tree(points, depth + 1, mid + 1, right)

    def query(self, query_point, heap, k, depth=0, left=0, right=None):
        """Collect the ``k`` smallest squared distances to ``query_point``.

        Results are accumulated in ``heap`` as *negated* squared Euclidean
        distances, so ``-heap[0]`` is the largest squared distance currently
        kept.  Starting from an empty heap on a tree with at least ``k``
        points, that value is the squared distance to the k-th nearest point.

        Args:
            query_point: Coordinates of the query.
            heap: List used as a ``heapq`` heap; it is modified in place.
            k: Number of neighbours wanted.  For ``k <= 0`` nothing is
                collected and ``heap`` is left unchanged.
            depth: Depth of the range being searched (0 for the whole tree).
            left: First index of the range (inclusive).
            right: Last index of the range (inclusive); ``None`` means the
                last index of the tree.
        """
        if right is None:
            right = self.n - 1
        if k <= 0 or left > right:
            return

        dimension = depth % self.d
        mid = (left + right) // 2
        node = self.tree[mid]

        # Keep at most k candidates; the heap root is always the worst one.
        mid_dist = sum((x1 - x2) ** 2 for x1, x2 in zip(query_point, node))
        if len(heap) < k:
            heapq.heappush(heap, -mid_dist)
        else:
            heapq.heappushpop(heap, -mid_dist)

        # Descend first into the side of the splitting plane that holds the
        # query point, then decide whether the other side can still matter.
        if query_point[dimension] <= node[dimension]:
            near, far = (left, mid - 1), (mid + 1, right)
        else:
            near, far = (mid + 1, right), (left, mid - 1)
        self.query(query_point, heap, k, depth + 1, near[0], near[1])

        # The far side is only worth visiting while fewer than k candidates
        # are known, or when the splitting plane is no farther away than the
        # current worst candidate.
        plane_dist = (query_point[dimension] - node[dimension]) ** 2
        if len(heap) < k or plane_dist <= -heap[0]:
            self.query(query_point, heap, k, depth + 1, far[0], far[1])
# Input header: the number of points followed by their dimensionality.
n, d = (int(token) for token in input().split())

# One point per line, coordinates separated by whitespace.
points = [[float(token) for token in input().split()] for _ in range(n)]
tree = k_d_tree(points)

# Every query searches the whole tree, i.e. the index range [0, last_index].
last_index = len(points) - 1

# Each query line holds k followed by the coordinates of the query point.
q = int(input())
for _ in range(q):
    query = input().split()
    k = int(query[0])
    query_point = [float(token) for token in query[1:]]
    heap = []
    tree.query(query_point, heap, k, 0, 0, last_index)
    # The heap stores negated squared distances; its root is the farthest of
    # the neighbours that were kept.
    print(math.sqrt(-heap[0]))