Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion src/device/intrinsics/simd.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
export simdgroup_load, simdgroup_store, simdgroup_multiply, simdgroup_multiply_accumulate,
simd_shuffle_down, simd_shuffle_up, simd_shuffle_and_fill_down, simd_shuffle_and_fill_up,
simd_shuffle, simd_shuffle_xor
simd_shuffle, simd_shuffle_xor, simd_ballot, simd_vote_all, simd_vote_any

using Core: LLVMPtr

Expand Down Expand Up @@ -124,6 +124,18 @@ for (jltype, suffix) in simd_shuffle_map
end
end

## SIMD Voting Functions

@device_function simd_ballot(predicate::Bool) =
ccall("extern air.simd_ballot.i64", llvmcall, UInt64, (Bool,), predicate)

@device_function simd_vote_all(bitmask::UInt64) =
ccall("extern air.simd_vote_all.i64", llvmcall, Bool, (UInt64,), bitmask)

@device_function simd_vote_any(bitmask::UInt64) =
ccall("extern air.simd_vote_any.i64", llvmcall, Bool, (UInt64,), bitmask)


## Documentation

@doc """
Expand Down Expand Up @@ -209,3 +221,28 @@ The `modulo` parameter defines the vector width that splits the SIMD-group into
T must be one of the following: Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, or UInt8
"""
simd_shuffle_and_fill_up

@doc """
simd_ballot(predicate::Bool)

Returns a UInt64 bitmask of the evaluation of the Boolean expression for all active
threads in the SIMD-group for which `predicate` is true. The function sets the bits that correspond
to inactive threads to 0.
"""
simd_ballot

@doc """
simd_vote_all(bitmask::UInt64)

Returns true if all bits corresponding to threads in the SIMD-group are set. The input is a
voting `bitmask`, such as the one returned by `simd_ballot`.
"""
simd_vote_all

@doc """
simd_vote_any(bitmask::UInt64)

Returns true if any bits corresponding to threads in the SIMD-group are set. The input is a
voting `bitmask`, such as the one returned by `simd_ballot`.
"""
simd_vote_any
91 changes: 91 additions & 0 deletions test/device/intrinsics/simd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,97 @@ end
Metal.@sync @metal threads=N kernel_mod(mtlfilling2, mtlfilling, midN)
@test Array(mtlfilling2) == [circshift(data[1:midN], nshift); circshift(data[midN+1:end], nshift)]
end

@testset "simd_ballot" begin
function ballot_kernel(output, threshold)
idx = thread_position_in_grid().x
lane = thread_index_in_simdgroup()

# Each thread votes true if its lane index is ≤ threshold
predicate = lane ≤ threshold
ballot = simd_ballot(predicate)

output[idx] = ballot
return
end

threads_per_simdgroup = 32

@testset "threshold=$threshold" for threshold in [0, 1, 8, 16, 31, 32]
output = MtlArray(zeros(UInt64, threads_per_simdgroup))
Metal.@sync @metal threads = threads_per_simdgroup ballot_kernel(output, UInt32(threshold))

# Expected: bits 0..(threshold-1) are set (1-indexed threshold means bits 0 to threshold-1)
expected_ballot = threshold == 0 ? UInt64(0) : (UInt64(1) << threshold) - 1
result = Array(output)

# All threads should see the same ballot result
@test all(result .== expected_ballot)
end
end

@testset "simd_vote_all" begin
function all_kernel(output, threshold)
idx = thread_position_in_grid().x
lane = thread_index_in_simdgroup()

# First get a ballot mask based on threshold
predicate = lane ≤ threshold
ballot = simd_ballot(predicate)

# simd_vote_all checks if all bits in the mask are set
result = simd_vote_all(ballot)

output[idx] = result
return
end

threads_per_simdgroup = 32

# simd_vote_all returns true only when all bits in the ballot mask are set
@testset "threshold=$threshold" for threshold in [0, 16, 31, 32, 33]
output = MtlArray(zeros(UInt8, threads_per_simdgroup))
Metal.@sync @metal threads = threads_per_simdgroup all_kernel(output, UInt32(threshold))

result = Array(output)
# All bits set means threshold ≥ 32 (all 32 lanes voted true)
expected = threshold ≥ threads_per_simdgroup

@test all(result .== expected)
end
end

@testset "simd_vote_any" begin
function any_kernel(output, threshold)
idx = thread_position_in_grid().x
lane = thread_index_in_simdgroup()

# First get a ballot mask based on threshold
predicate = lane ≤ threshold
ballot = simd_ballot(predicate)

# simd_vote_any checks if any bit in the mask is set
result = simd_vote_any(ballot)

output[idx] = result
return
end

threads_per_simdgroup = 32

# simd_vote_any returns true when any bit in the ballot mask is set
@testset "threshold=$threshold" for threshold in [0, 1, 16, 32]
output = MtlArray(zeros(UInt8, threads_per_simdgroup))
Metal.@sync @metal threads = threads_per_simdgroup any_kernel(output, UInt32(threshold))

result = Array(output)
# Any bit set means threshold ≥ 1 (at least lane 1 voted true)
expected = threshold ≥ 1

@test all(result .== expected)
end
end

@testset "matrix functions" begin
@testset "load_store($typ)" for typ in [Float16, Float32]
function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T},
Expand Down