Although JIT compilation has become the norm and is arguably optimal from a throughput perspective, requiring that everything be JIT-compiled means the lower bound on the cost of any kernel op is actually quite high, because compilation itself is non-trivial.
Taking the example from the documentation with smaller dimensions:

import torch
from pykeops.torch import LazyTensor

M, N, D = 1_000, 2_000, 3
x = torch.randn(M, D, requires_grad=True)
y = torch.randn(N, D)
x_i = LazyTensor(x.view(M, 1, D))
y_j = LazyTensor(y.view(1, N, D))
D_ij = ((x_i - y_j) ** 2).sum(dim=2)  # symbolic (M, N) squared distances
K_ij = (-D_ij).exp()                  # symbolic Gaussian kernel matrix
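For comparison, a minimal sketch of the same Gaussian-kernel reduction in plain torch broadcasting (no KeOps, no compilation step), which materializes the full dense matrix:

```python
import torch

M, N, D = 1_000, 2_000, 3
x = torch.randn(M, D)
y = torch.randn(N, D)

# Dense equivalent of the KeOps reduction above: the (M, N) kernel
# matrix is actually allocated in memory rather than kept symbolic.
D_ij = ((x[:, None, :] - y[None, :, :]) ** 2).sum(dim=2)
K_ij = (-D_ij).exp()
result = K_ij.sum(dim=1)  # shape (M,)
```

This is the fallback I mean when I say torch/numpy would suffice for small sizes.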
On my machine, if the kernel has already been JIT-compiled, of course the reduction is blazing fast:
timeit.timeit(lambda: K_ij.sum(dim=1), number=30) / 30
# 0.001784 seconds
In contrast, if I remove the JIT-compiled targets (via clean_pykeops()) and run it once, the same computation takes around 2.1 seconds, which is about 1200x slower. It would be nice if, like numpy and torch, pykeops shipped with pre-compiled kernels (for very common ops, I guess?) that worked for any size and could be taken advantage of when JIT is not preferred.
I realize that for smaller cases I could just use torch/numpy, but even for modestly sized $n$ (e.g. $n$ > 10k), all-pairwise operations allocate hundreds of MB or even GB of memory.
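A quick back-of-the-envelope check of that memory claim (assuming float32 and a single dense pairwise matrix; the exact figure depends on dtype and how many intermediates the framework keeps alive):

```python
# Memory cost of materializing one dense (n, n) pairwise matrix.
n = 10_000           # number of points on each side
bytes_per_float = 4  # float32

pairwise_bytes = n * n * bytes_per_float
print(pairwise_bytes / 1e6, "MB")  # 400.0 MB for a single float32 matrix
```

And that is per intermediate: the broadcasting approach above holds both D_ij and K_ij at once, so the real peak is a multiple of this.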