From e2a7bca3bd703ea2d1b916a0ab82f40388790f31 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Tue, 24 Feb 2026 16:01:30 +0100 Subject: [PATCH 1/2] Expose ballot voting intrinsics --- src/device/intrinsics/simd.jl | 37 +++++++++++++- test/device/intrinsics/simd.jl | 91 ++++++++++++++++++++++++++++++++++ 2 files changed, 127 insertions(+), 1 deletion(-) diff --git a/src/device/intrinsics/simd.jl b/src/device/intrinsics/simd.jl index a6508c3c1..73dd3d3ee 100644 --- a/src/device/intrinsics/simd.jl +++ b/src/device/intrinsics/simd.jl @@ -1,6 +1,6 @@ export simdgroup_load, simdgroup_store, simdgroup_multiply, simdgroup_multiply_accumulate, simd_shuffle_down, simd_shuffle_up, simd_shuffle_and_fill_down, simd_shuffle_and_fill_up, - simd_shuffle, simd_shuffle_xor + simd_shuffle, simd_shuffle_xor, simd_ballot, simd_all, simd_any using Core: LLVMPtr @@ -124,6 +124,18 @@ for (jltype, suffix) in simd_shuffle_map end end +## SIMD Voting Functions + +@device_function simd_ballot(predicate::Bool) = + ccall("extern air.simd_ballot.i64", llvmcall, UInt64, (Bool,), predicate) + +@device_function simd_all(bitmask::UInt64) = + ccall("extern air.simd_vote_all.i64", llvmcall, Bool, (UInt64,), bitmask) + +@device_function simd_any(bitmask::UInt64) = + ccall("extern air.simd_vote_any.i64", llvmcall, Bool, (UInt64,), bitmask) + + ## Documentation @doc """ @@ -209,3 +221,26 @@ The `modulo` parameter defines the vector width that splits the SIMD-group into T must be one of the following: Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, or UInt8 """ simd_shuffle_and_fill_up + +@doc """ + simd_ballot(predicate::Bool) + +Returns a UInt64 bitmask of the evaluation of the Boolean expression for all active +threads in the SIMD-group for which `predicate` is true. The function sets the bits that correspond +to inactive threads to 0. +""" +simd_ballot + +@doc """ + simd_all(bitmask::UInt64) + +Returns true if all bits in `bitmask` are set. +""" +simd_all + +@doc """ + simd_any(bitmask::UInt64) + +Returns true if any bits in `bitmask` are set. +""" +simd_any diff --git a/test/device/intrinsics/simd.jl b/test/device/intrinsics/simd.jl index 43d7ab55f..321bce393 100644 --- a/test/device/intrinsics/simd.jl +++ b/test/device/intrinsics/simd.jl @@ -115,6 +115,97 @@ end Metal.@sync @metal threads=N kernel_mod(mtlfilling2, mtlfilling, midN) @test Array(mtlfilling2) == [circshift(data[1:midN], nshift); circshift(data[midN+1:end], nshift)] end + +@testset "simd_ballot" begin + function ballot_kernel(output, threshold) + idx = thread_position_in_grid().x + lane = thread_index_in_simdgroup() + + # Each thread votes true if its lane index is ≤ threshold + predicate = lane ≤ threshold + ballot = simd_ballot(predicate) + + output[idx] = ballot + return + end + + threads_per_simdgroup = 32 + + @testset "threshold=$threshold" for threshold in [0, 1, 8, 16, 31, 32] + output = MtlArray(zeros(UInt64, threads_per_simdgroup)) + Metal.@sync @metal threads = threads_per_simdgroup ballot_kernel(output, UInt32(threshold)) + + # Expected: bits 0..(threshold-1) are set (1-indexed threshold means bits 0 to threshold-1) + expected_ballot = threshold == 0 ? UInt64(0) : (UInt64(1) << threshold) - 1 + result = Array(output) + + # All threads should see the same ballot result + @test all(result .== expected_ballot) + end +end + +@testset "simd_all" begin + function all_kernel(output, threshold) + idx = thread_position_in_grid().x + lane = thread_index_in_simdgroup() + + # First get a ballot mask based on threshold + predicate = lane ≤ threshold + ballot = simd_ballot(predicate) + + # simd_all checks if all bits in the mask are set + result = simd_all(ballot) + + output[idx] = result + return + end + + threads_per_simdgroup = 32 + + # simd_all returns true only when all bits in the ballot mask are set + @testset "threshold=$threshold" for threshold in [0, 16, 31, 32, 33] + output = MtlArray(zeros(UInt8, threads_per_simdgroup)) + Metal.@sync @metal threads = threads_per_simdgroup all_kernel(output, UInt32(threshold)) + + result = Array(output) + # All bits set means threshold ≥ 32 (all 32 lanes voted true) + expected = threshold ≥ threads_per_simdgroup + + @test all(result .== expected) + end +end + +@testset "simd_any" begin + function any_kernel(output, threshold) + idx = thread_position_in_grid().x + lane = thread_index_in_simdgroup() + + # First get a ballot mask based on threshold + predicate = lane ≤ threshold + ballot = simd_ballot(predicate) + + # simd_any checks if any bit in the mask is set + result = simd_any(ballot) + + output[idx] = result + return + end + + threads_per_simdgroup = 32 + + # simd_any returns true when any bit in the ballot mask is set + @testset "threshold=$threshold" for threshold in [0, 1, 16, 32] + output = MtlArray(zeros(UInt8, threads_per_simdgroup)) + Metal.@sync @metal threads = threads_per_simdgroup any_kernel(output, UInt32(threshold)) + + result = Array(output) + # Any bit set means threshold ≥ 1 (at least lane 1 voted true) + expected = threshold ≥ 1 + + @test all(result .== expected) + end +end + @testset "matrix functions" begin @testset "load_store($typ)" for typ in [Float16, Float32] function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T}, From 7a3fb2cf2ec315babdaa31f45d7f40dd2100438f Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Wed, 25 Feb 2026 13:11:58 +0100 Subject: [PATCH 2/2] Rename functions and update docstrings --- src/device/intrinsics/simd.jl | 20 +++++++++++--------- test/device/intrinsics/simd.jl | 16 ++++++++-------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/src/device/intrinsics/simd.jl b/src/device/intrinsics/simd.jl index 73dd3d3ee..368227ac2 100644 --- a/src/device/intrinsics/simd.jl +++ b/src/device/intrinsics/simd.jl @@ -1,6 +1,6 @@ export simdgroup_load, simdgroup_store, simdgroup_multiply, simdgroup_multiply_accumulate, simd_shuffle_down, simd_shuffle_up, simd_shuffle_and_fill_down, simd_shuffle_and_fill_up, - simd_shuffle, simd_shuffle_xor, simd_ballot, simd_all, simd_any + simd_shuffle, simd_shuffle_xor, simd_ballot, simd_vote_all, simd_vote_any using Core: LLVMPtr @@ -129,10 +129,10 @@ end @device_function simd_ballot(predicate::Bool) = ccall("extern air.simd_ballot.i64", llvmcall, UInt64, (Bool,), predicate) -@device_function simd_all(bitmask::UInt64) = +@device_function simd_vote_all(bitmask::UInt64) = ccall("extern air.simd_vote_all.i64", llvmcall, Bool, (UInt64,), bitmask) -@device_function simd_any(bitmask::UInt64) = +@device_function simd_vote_any(bitmask::UInt64) = ccall("extern air.simd_vote_any.i64", llvmcall, Bool, (UInt64,), bitmask) @@ -232,15 +232,17 @@ to inactive threads to 0. simd_ballot @doc """ - simd_all(bitmask::UInt64) + simd_vote_all(bitmask::UInt64) -Returns true if all bits in `bitmask` are set. +Returns true if all bits corresponding to threads in the SIMD-group are set. The input is a +voting `bitmask`, such as the one returned by `simd_ballot`. """ -simd_all +simd_vote_all @doc """ - simd_any(bitmask::UInt64) + simd_vote_any(bitmask::UInt64) -Returns true if any bits in `bitmask` are set. +Returns true if any bits corresponding to threads in the SIMD-group are set. The input is a +voting `bitmask`, such as the one returned by `simd_ballot`. """ -simd_any +simd_vote_any diff --git a/test/device/intrinsics/simd.jl b/test/device/intrinsics/simd.jl index 321bce393..684bb04bf 100644 --- a/test/device/intrinsics/simd.jl +++ b/test/device/intrinsics/simd.jl @@ -144,7 +144,7 @@ end end end -@testset "simd_all" begin +@testset "simd_vote_all" begin function all_kernel(output, threshold) idx = thread_position_in_grid().x lane = thread_index_in_simdgroup() @@ -153,8 +153,8 @@ end predicate = lane ≤ threshold ballot = simd_ballot(predicate) - # simd_all checks if all bits in the mask are set - result = simd_all(ballot) + # simd_vote_all checks if all bits in the mask are set + result = simd_vote_all(ballot) output[idx] = result return @@ -162,7 +162,7 @@ end threads_per_simdgroup = 32 - # simd_all returns true only when all bits in the ballot mask are set + # simd_vote_all returns true only when all bits in the ballot mask are set @testset "threshold=$threshold" for threshold in [0, 16, 31, 32, 33] output = MtlArray(zeros(UInt8, threads_per_simdgroup)) Metal.@sync @metal threads = threads_per_simdgroup all_kernel(output, UInt32(threshold)) @@ -175,7 +175,7 @@ end end end -@testset "simd_any" begin +@testset "simd_vote_any" begin function any_kernel(output, threshold) idx = thread_position_in_grid().x lane = thread_index_in_simdgroup() @@ -184,8 +184,8 @@ end predicate = lane ≤ threshold ballot = simd_ballot(predicate) - # simd_any checks if any bit in the mask is set - result = simd_any(ballot) + # simd_vote_any checks if any bit in the mask is set + result = simd_vote_any(ballot) output[idx] = result return @@ -193,7 +193,7 @@ end threads_per_simdgroup = 32 - # simd_any returns true when any bit in the ballot mask is set + # simd_vote_any returns true when any bit in the ballot mask is set @testset "threshold=$threshold" for threshold in [0, 1, 16, 32] output = MtlArray(zeros(UInt8, threads_per_simdgroup)) Metal.@sync @metal threads = threads_per_simdgroup any_kernel(output, UInt32(threshold))