Skip to content

[Do not merge] Test KernelIntrinsics#688

Open
christiangnrd wants to merge 3 commits into main from
kaintr
Open

[Do not merge] Test KernelIntrinsics#688
christiangnrd wants to merge 3 commits into main from
kaintr

Conversation

@christiangnrd
Copy link
Member

Not a draft to also run benchmarks

@github-actions
Copy link
Contributor

github-actions bot commented Oct 22, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Click here to view the suggested changes.
diff --git a/src/MetalKernels.jl b/src/MetalKernels.jl
index 4e856194..7573c5e1 100644
--- a/src/MetalKernels.jl
+++ b/src/MetalKernels.jl
@@ -136,26 +136,26 @@ end
 
 KI.argconvert(::MetalBackend, arg) = mtlconvert(arg)
 
-function KI.kernel_function(::MetalBackend, f::F, tt::TT=Tuple{}; name=nothing, kwargs...) where {F,TT}
+function KI.kernel_function(::MetalBackend, f::F, tt::TT = Tuple{}; name = nothing, kwargs...) where {F, TT}
     kern = mtlfunction(f, tt; name, kwargs...)
-    KI.Kernel{MetalBackend, typeof(kern)}(MetalBackend(), kern)
+    return KI.Kernel{MetalBackend, typeof(kern)}(MetalBackend(), kern)
 end
 
-function (obj::KI.Kernel{MetalBackend})(args...; numworkgroups=1, workgroupsize=1)
+function (obj::KI.Kernel{MetalBackend})(args...; numworkgroups = 1, workgroupsize = 1)
     KI.check_launch_args(numworkgroups, workgroupsize)
 
-    obj.kern(args...; threads=workgroupsize, groups=numworkgroups)
+    return obj.kern(args...; threads = workgroupsize, groups = numworkgroups)
 end
 
 
-function KI.kernel_max_work_group_size(kikern::KI.Kernel{<:MetalBackend}; max_work_items::Int=typemax(Int))::Int
-    Int(min(kikern.kern.pipeline.maxTotalThreadsPerThreadgroup, max_work_items))
+function KI.kernel_max_work_group_size(kikern::KI.Kernel{<:MetalBackend}; max_work_items::Int = typemax(Int))::Int
+    return Int(min(kikern.kern.pipeline.maxTotalThreadsPerThreadgroup, max_work_items))
 end
 function KI.max_work_group_size(::MetalBackend)::Int
-    Int(device().maxThreadsPerThreadgroup.width)
+    return Int(device().maxThreadsPerThreadgroup.width)
 end
 function KI.multiprocessor_count(::MetalBackend)::Int
-    Metal.num_gpu_cores()
+    return Metal.num_gpu_cores()
 end
 
 
diff --git a/src/broadcast.jl b/src/broadcast.jl
index 5d107ec2..3455fad2 100644
--- a/src/broadcast.jl
+++ b/src/broadcast.jl
@@ -66,9 +66,9 @@ end
     if _broadcast_shapes[Is] > BROADCAST_SPECIALIZATION_THRESHOLD
         ## COV_EXCL_START
         function broadcast_cartesian_static(dest, bc, Is)
-             i = KI.get_global_id().x
-             stride = KI.get_global_size().x
-             while 1 <= i <= length(dest)
+            i = KI.get_global_id().x
+            stride = KI.get_global_size().x
+            while 1 <= i <= length(dest)
                 I = @inbounds Is[i]
                 @inbounds dest[I] = bc[I]
                 i += stride
@@ -91,13 +91,13 @@ end
        (isa(IndexStyle(dest), IndexLinear) && isa(IndexStyle(bc), IndexLinear))
         ## COV_EXCL_START
         function broadcast_linear(dest, bc)
-             i = KI.get_global_id().x
-             stride = KI.get_global_size().x
-             while 1 <= i <= length(dest)
-                 @inbounds dest[i] = bc[i]
-                 i += stride
-             end
-             return
+            i = KI.get_global_id().x
+            stride = KI.get_global_size().x
+            while 1 <= i <= length(dest)
+                @inbounds dest[i] = bc[i]
+                i += stride
+            end
+            return
         end
         ## COV_EXCL_STOP
 
@@ -168,9 +168,9 @@ end
     else
         ## COV_EXCL_START
         function broadcast_cartesian(dest, bc)
-             i = KI.get_global_id().x
-             stride = KI.get_global_size().x
-             while 1 <= i <= length(dest)
+            i = KI.get_global_id().x
+            stride = KI.get_global_size().x
+            while 1 <= i <= length(dest)
                 I = @inbounds CartesianIndices(dest)[i]
                 @inbounds dest[I] = bc[I]
                 i += stride
diff --git a/src/device/random.jl b/src/device/random.jl
index 12b053a2..edc999cd 100644
--- a/src/device/random.jl
+++ b/src/device/random.jl
@@ -89,8 +89,8 @@ end
         @inbounds global_random_counters()[simdgroupId]
     elseif field === :ctr2
         globalId = KI.get_global_id().x +
-                   (KI.get_global_id().y - 1i32) * KI.get_global_size().x +
-                   (KI.get_global_id().z - 1i32) * KI.get_global_size().x * KI.get_global_size().y
+            (KI.get_global_id().y - 1i32) * KI.get_global_size().x +
+            (KI.get_global_id().z - 1i32) * KI.get_global_size().x * KI.get_global_size().y
         globalId % UInt32
     end::UInt32
 end
diff --git a/src/mapreduce.jl b/src/mapreduce.jl
index 7be5ef43..a737e8d0 100644
--- a/src/mapreduce.jl
+++ b/src/mapreduce.jl
@@ -224,7 +224,8 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
     # we might not be able to launch all those threads to reduce each slice in one go.
     # that's why each threads also loops across their inputs, processing multiple values
     # so that we can span the entire reduction dimension using a single item group.
-    kernel = KI.@kernel backend launch = false partial_mapreduce_device(f, op, init, Val(maxthreads), Val(Rreduce), Val(Rother),
+    kernel = KI.@kernel backend launch = false partial_mapreduce_device(
+        f, op, init, Val(maxthreads), Val(Rreduce), Val(Rother),
                                                           Val(UInt64(length(Rother))), Val(grain), Val(shuffle), R, A)
 
     # how many threads do we want?
@@ -260,7 +261,8 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
         # we can cover the dimensions to reduce using a single group
         kernel(f, op, init, Val(maxthreads), Val(Rreduce), Val(Rother),
                Val(UInt64(length(Rother))), Val(grain), Val(shuffle), R, A;
-               workgroupsize = threads, numworkgroups = groups)
+            workgroupsize = threads, numworkgroups = groups
+        )
     else
         # temporary empty array whose type will match the final partial array
 	    partial = similar(R, ntuple(_ -> 0, Val(ndims(R)+1)))
@@ -287,7 +289,8 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
         partial_kernel(f, op, init, Val(threads), Val(Rreduce),
                         Val(Rother), Val(UInt64(length(Rother))),
                         Val(grain), Val(shuffle), partial, A;
-                        numworkgroups = partial_groups, workgroupsize = partial_threads)
+            numworkgroups = partial_groups, workgroupsize = partial_threads
+        )
 
         GPUArrays.mapreducedim!(identity, op, R, partial; init=init)
     end
diff --git a/test/kernelabstractions.jl b/test/kernelabstractions.jl
index cda5b249..339fcbc8 100644
--- a/test/kernelabstractions.jl
+++ b/test/kernelabstractions.jl
@@ -7,6 +7,6 @@ Testsuite.testsuite(MetalBackend, "Metal", Metal, MtlArray, Metal.MtlDeviceArray
     "Convert",           # depends on https://github.com/JuliaGPU/Metal.jl/issues/69
     "SpecialFunctions",  # gamma and erfc not currently supported on Metal.jl
     "sparse",            # not supported yet
-    "CPU synchronization",
-    "fallback test: callable types",
+            "CPU synchronization",
+            "fallback test: callable types",
 ]))
diff --git a/test/runtests.jl b/test/runtests.jl
index 32b45c8c..14fcfb93 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,6 +1,6 @@
 @static if VERSION < v"1.11" && get(ENV, "BUILDKITE_PIPELINE_NAME", "Metal.jl") == "Metal.jl"
     using Pkg
-    Pkg.add(url="https://github.com/JuliaGPU/KernelAbstractions.jl", rev="main")
+    Pkg.add(url = "https://github.com/JuliaGPU/KernelAbstractions.jl", rev = "main")
 end
 
 using Metal

@christiangnrd christiangnrd force-pushed the kaintr branch 3 times, most recently from 9ac3d49 to 6314372. Compare: October 22, 2025 04:31
Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Metal Benchmarks

Details
Benchmark suite Current: 865af1a Previous: d1a8cc2 Ratio
latency/precompile 30229762937 ns 26336492125 ns 1.15
latency/ttfp 2505249124.5 ns 2442892125 ns 1.03
latency/import 1574101104 ns 1495421459 ns 1.05
integration/metaldevrt 886458.5 ns 518291 ns 1.71
integration/byval/slices=1 1573250 ns 1190833 ns 1.32
integration/byval/slices=3 9320833.5 ns 8102291 ns 1.15
integration/byval/reference 1574125 ns 1189208 ns 1.32
integration/byval/slices=2 2657834 ns 2114958 ns 1.26
kernel/indexing 658250 ns 260375 ns 2.53
kernel/indexing_checked 506416.5 ns 272167 ns 1.86
kernel/launch 13459 ns 12500 ns 1.08
kernel/rand 575291 ns 292458 ns 1.97
array/construct 6958 ns 6709 ns 1.04
array/broadcast 607959 ns 284000 ns 2.14
array/random/randn/Float32 1022000 ns 493667 ns 2.07
array/random/randn!/Float32 755250 ns 418250 ns 1.81
array/random/rand!/Int64 547459 ns 314416.5 ns 1.74
array/random/rand!/Float32 586583 ns 275792 ns 2.13
array/random/rand/Int64 803813 ns 462958 ns 1.74
array/random/rand/Float32 638104 ns 349625 ns 1.83
array/accumulate/Int64/1d 1335374.5 ns 970583 ns 1.38
array/accumulate/Int64/dims=1 1979584 ns 1042125 ns 1.90
array/accumulate/Int64/dims=2 2305000 ns 1363688 ns 1.69
array/accumulate/Int64/dims=1L 11904000 ns 9571750 ns 1.24
array/accumulate/Int64/dims=2L 9977625 ns 7945646 ns 1.26
array/accumulate/Float32/1d 1159875 ns 808292 ns 1.43
array/accumulate/Float32/dims=1 1789563 ns 904895.5 ns 1.98
array/accumulate/Float32/dims=2 2071542 ns 1184584 ns 1.75
array/accumulate/Float32/dims=1L 10015875 ns 8567583 ns 1.17
array/accumulate/Float32/dims=2L 8151125 ns 4391750 ns 1.86
array/reductions/reduce/Int64/1d 1352333 ns 744625 ns 1.82
array/reductions/reduce/Int64/dims=1 1141395.5 ns 685541 ns 1.66
array/reductions/reduce/Int64/dims=2 1342541 ns 722666 ns 1.86
array/reductions/reduce/Int64/dims=1L 2051229 ns 1164834 ns 1.76
array/reductions/reduce/Int64/dims=2L 4290209 ns 2224833 ns 1.93
array/reductions/reduce/Float32/1d 1062542 ns 558917 ns 1.90
array/reductions/reduce/Float32/dims=1 853834 ns 400208 ns 2.13
array/reductions/reduce/Float32/dims=2 895250 ns 430792 ns 2.08
array/reductions/reduce/Float32/dims=1L 1355375 ns 675854.5 ns 2.01
array/reductions/reduce/Float32/dims=2L 1905833.5 ns 1244709 ns 1.53
array/reductions/mapreduce/Int64/1d 1353000 ns 749417 ns 1.81
array/reductions/mapreduce/Int64/dims=1 1095417 ns 687167 ns 1.59
array/reductions/mapreduce/Int64/dims=2 1331625 ns 723333 ns 1.84
array/reductions/mapreduce/Int64/dims=1L 2071542 ns 1138000 ns 1.82
array/reductions/mapreduce/Int64/dims=2L 4315458 ns 1854834 ns 2.33
array/reductions/mapreduce/Float32/1d 1065896 ns 561291 ns 1.90
array/reductions/mapreduce/Float32/dims=1 863895.5 ns 400834 ns 2.16
array/reductions/mapreduce/Float32/dims=2 910354.5 ns 431458 ns 2.11
array/reductions/mapreduce/Float32/dims=1L 1345937.5 ns 690812.5 ns 1.95
array/reductions/mapreduce/Float32/dims=2L 1888062 ns 1251208 ns 1.51
array/private/copyto!/gpu_to_gpu 636084 ns 231917 ns 2.74
array/private/copyto!/cpu_to_gpu 801750 ns 255125 ns 3.14
array/private/copyto!/gpu_to_cpu 803958.5 ns 254833 ns 3.15
array/private/iteration/findall/int 1666354 ns 1211667 ns 1.38
array/private/iteration/findall/bool 1487104 ns 1071937.5 ns 1.39
array/private/iteration/findfirst/int 2136437 ns 1192604.5 ns 1.79
array/private/iteration/findfirst/bool 2116354.5 ns 1183042 ns 1.79
array/private/iteration/scalar 5696459 ns 1706917 ns 3.34
array/private/iteration/logical 2725771 ns 1644459 ns 1.66
array/private/iteration/findmin/1d 2597833.5 ns 1415000 ns 1.84
array/private/iteration/findmin/2d 1874312.5 ns 1220583 ns 1.54
array/private/copy 571958 ns 326542 ns 1.75
array/shared/copyto!/gpu_to_gpu 84541 ns 78584 ns 1.08
array/shared/copyto!/cpu_to_gpu 79916 ns 79729.5 ns 1.00
array/shared/copyto!/gpu_to_cpu 82375 ns 77083 ns 1.07
array/shared/iteration/findall/int 1666375 ns 1213417 ns 1.37
array/shared/iteration/findall/bool 1507416.5 ns 1077291 ns 1.40
array/shared/iteration/findfirst/int 1703958 ns 1003500 ns 1.70
array/shared/iteration/findfirst/bool 1721333 ns 995583 ns 1.73
array/shared/iteration/scalar 211125 ns 184375 ns 1.15
array/shared/iteration/logical 2653625.5 ns 1442938 ns 1.84
array/shared/iteration/findmin/1d 2217666.5 ns 1224583.5 ns 1.81
array/shared/iteration/findmin/2d 1871667 ns 1221208 ns 1.53
array/shared/copy 247000 ns 235833 ns 1.05
array/permutedims/4d 2670917 ns 1723000 ns 1.55
array/permutedims/2d 1167813 ns 558541 ns 2.09
array/permutedims/3d 1698042 ns 1119959 ns 1.52
metal/synchronization/stream 19500 ns 16750 ns 1.16
metal/synchronization/context 20417 ns 17541 ns 1.16

This comment was automatically generated by workflow using github-action-benchmark.

skip scripts tests on 1.10

Project.toml

Better workaround
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant