diff --git a/src/PointNeighbors.jl b/src/PointNeighbors.jl
index 702c19e4..9eb1d7dc 100644
--- a/src/PointNeighbors.jl
+++ b/src/PointNeighbors.jl
@@ -6,7 +6,7 @@ using Adapt: Adapt
 using Atomix: Atomix
 using Base: @propagate_inbounds
 using GPUArraysCore: AbstractGPUArray
-using KernelAbstractions: KernelAbstractions, @kernel, @index
+using KernelAbstractions: KernelAbstractions, @kernel, @index, @groupsize
 using LinearAlgebra: dot
 using Polyester: Polyester
 @reexport using StaticArrays: SVector
diff --git a/src/util.jl b/src/util.jl
index c018b89a..20299977 100644
--- a/src/util.jl
+++ b/src/util.jl
@@ -155,22 +155,34 @@ end
     # On the GPU, we can only loop over `1:N`. Therefore, we loop over `1:length(iterator)`
     # and index with `iterator[eachindex(iterator)[i]]`.
     # Note that this only works with vector-like iterators that support arbitrary indexing.
-    indices = eachindex(iterator)
+    indices = eachindex(IndexLinear(), iterator)
     ndrange = length(indices)
 
+    # TODO: Is it better to pass `indices` to the kernel,
+    # or should we "recreate" them inside the kernel?
+
     # Skip empty loops
     ndrange == 0 && return
 
-    # Call the generic kernel that is defined below, which only calls a function with
-    # the global GPU index.
-    generic_kernel(backend)(ndrange = ndrange) do i
-        @inbounds @inline f(iterator[indices[i]])
-    end
+    # Launch the kernel that is defined below, which calls `f(iterator[indices[i]])`
+    # for every global GPU index `i` in `1:ndrange`.
+    foreach_ka(backend)(f, iterator, indices, ndrange = ndrange)
 
     KernelAbstractions.synchronize(backend)
 end
 
-@kernel function generic_kernel(f)
-    i = @index(Global)
-    @inline f(i)
+# Kernel that calls `f(iterator[indices[i]])` for the global linear index `i`.
+# With `unsafe_indices=true`, KernelAbstractions does not validate indices implicitly,
+# so the global index is derived from the group and local indices and checked by hand.
+@kernel unsafe_indices=true function foreach_ka(f, iterator, indices)
+    # Calculate global index
+    N = @groupsize()[1]
+    iblock = @index(Group, Linear)
+    ithread = @index(Local, Linear)
+    i = ithread + (iblock - Int32(1)) * N
+
+    # The computed index can exceed `length(indices)` (presumably when `ndrange` is
+    # not a multiple of the group size), so guard the call before using `@inbounds`.
+    if i <= length(indices)
+        @inbounds @inline f(iterator[indices[i]])
+    end
 end