Add FFT support via AbstractFFTs interface #713
Open
KaanKesginLW wants to merge 39 commits into JuliaGPU:main
Codecov Report
❌ Patch coverage is
Additional details and impacted files: coverage diff

| | main | #713 | +/- |
|---|---|---|---|
| Coverage | 82.01% | 82.21% | +0.20% |
| Files | 62 | 64 | +2 |
| Lines | 2874 | 3008 | +134 |
| Hits | 2357 | 2473 | +116 |
| Misses | 517 | 535 | +18 |
Metal Benchmarks
| Benchmark suite | Current: 25108fd | Previous: 1d2f000 | Ratio |
|---|---|---|---|
| latency/precompile | 25538749208 ns | 25549419083 ns | 1.00 |
| latency/ttfp | 2355936292 ns | 2346831687.5 ns | 1.00 |
| latency/import | 1432847167 ns | 1427666042 ns | 1.00 |
| integration/metaldevrt | 875875 ns | 877750 ns | 1.00 |
| integration/byval/slices=1 | 1565750 ns | 1568625 ns | 1.00 |
| integration/byval/slices=3 | 8760125 ns | 8402792 ns | 1.04 |
| integration/byval/reference | 1561916 ns | 1559958 ns | 1.00 |
| integration/byval/slices=2 | 2656437.5 ns | 2629875 ns | 1.01 |
| kernel/indexing | 652167 ns | 627417 ns | 1.04 |
| kernel/indexing_checked | 636729.5 ns | 608750 ns | 1.05 |
| kernel/launch | 12542 ns | 12667 ns | 0.99 |
| kernel/rand | 569458 ns | 576167 ns | 0.99 |
| array/construct | 6750 ns | 6500 ns | 1.04 |
| array/broadcast | 602750 ns | 606708 ns | 0.99 |
| array/random/randn/Float32 | 1020979 ns | 1011104 ns | 1.01 |
| array/random/randn!/Float32 | 742021 ns | 753875 ns | 0.98 |
| array/random/rand!/Int64 | 550708 ns | 548708 ns | 1.00 |
| array/random/rand!/Float32 | 592333 ns | 586208.5 ns | 1.01 |
| array/random/rand/Int64 | 792687.5 ns | 789709 ns | 1.00 |
| array/random/rand/Float32 | 585584 ns | 645000 ns | 0.91 |
| array/accumulate/Int64/1d | 1263333.5 ns | 1260667 ns | 1.00 |
| array/accumulate/Int64/dims=1 | 1922583 ns | 1859104.5 ns | 1.03 |
| array/accumulate/Int64/dims=2 | 2164958.5 ns | 2179083 ns | 0.99 |
| array/accumulate/Int64/dims=1L | 11703708 ns | 11673271 ns | 1.00 |
| array/accumulate/Int64/dims=2L | 9821083 ns | 9628146 ns | 1.02 |
| array/accumulate/Float32/1d | 1116833.5 ns | 1121395.5 ns | 1.00 |
| array/accumulate/Float32/dims=1 | 1565667 ns | 1571667 ns | 1.00 |
| array/accumulate/Float32/dims=2 | 1894187.5 ns | 1889459 ns | 1.00 |
| array/accumulate/Float32/dims=1L | 9867125 ns | 9834209 ns | 1.00 |
| array/accumulate/Float32/dims=2L | 7226562.5 ns | 7249666.5 ns | 1.00 |
| array/reductions/reduce/Int64/1d | 1353750 ns | 1386875 ns | 0.98 |
| array/reductions/reduce/Int64/dims=1 | 1100666 ns | 1117250 ns | 0.99 |
| array/reductions/reduce/Int64/dims=2 | 1145042 ns | 1152958 ns | 0.99 |
| array/reductions/reduce/Int64/dims=1L | 2028292 ns | 2013209 ns | 1.01 |
| array/reductions/reduce/Int64/dims=2L | 4254000 ns | 4244083 ns | 1.00 |
| array/reductions/reduce/Float32/1d | 1050125 ns | 988750 ns | 1.06 |
| array/reductions/reduce/Float32/dims=1 | 834959 ns | 843520.5 ns | 0.99 |
| array/reductions/reduce/Float32/dims=2 | 864958 ns | 857917 ns | 1.01 |
| array/reductions/reduce/Float32/dims=1L | 1333187 ns | 1326625 ns | 1.00 |
| array/reductions/reduce/Float32/dims=2L | 1831000 ns | 1810667 ns | 1.01 |
| array/reductions/mapreduce/Int64/1d | 1349250 ns | 1356437.5 ns | 0.99 |
| array/reductions/mapreduce/Int64/dims=1 | 1107667 ns | 1102166.5 ns | 1.00 |
| array/reductions/mapreduce/Int64/dims=2 | 1161667 ns | 1149750 ns | 1.01 |
| array/reductions/mapreduce/Int64/dims=1L | 2038875 ns | 1988375 ns | 1.03 |
| array/reductions/mapreduce/Int64/dims=2L | 3628729 ns | 3626916 ns | 1.00 |
| array/reductions/mapreduce/Float32/1d | 1034583.5 ns | 1055917 ns | 0.98 |
| array/reductions/mapreduce/Float32/dims=1 | 847834 ns | 847396 ns | 1.00 |
| array/reductions/mapreduce/Float32/dims=2 | 861584 ns | 860979.5 ns | 1.00 |
| array/reductions/mapreduce/Float32/dims=1L | 1339437.5 ns | 1333042 ns | 1.00 |
| array/reductions/mapreduce/Float32/dims=2L | 1821333 ns | 1898125 ns | 0.96 |
| array/private/copyto!/gpu_to_gpu | 641875 ns | 633020.5 ns | 1.01 |
| array/private/copyto!/cpu_to_gpu | 754854 ns | 804354.5 ns | 0.94 |
| array/private/copyto!/gpu_to_cpu | 807000 ns | 816000 ns | 0.99 |
| array/private/iteration/findall/int | 1582458 ns | 1581312.5 ns | 1.00 |
| array/private/iteration/findall/bool | 1413958 ns | 1404916.5 ns | 1.01 |
| array/private/iteration/findfirst/int | 2120875 ns | 2075167 ns | 1.02 |
| array/private/iteration/findfirst/bool | 2063250 ns | 2048750 ns | 1.01 |
| array/private/iteration/scalar | 3877812.5 ns | 4526479 ns | 0.86 |
| array/private/iteration/logical | 2691500 ns | 2693625 ns | 1.00 |
| array/private/iteration/findmin/1d | 2277458 ns | 2518041 ns | 0.90 |
| array/private/iteration/findmin/2d | 1814146 ns | 1820229.5 ns | 1.00 |
| array/private/copy | 581042 ns | 568854 ns | 1.02 |
| array/shared/copyto!/gpu_to_gpu | 84750 ns | 84291 ns | 1.01 |
| array/shared/copyto!/cpu_to_gpu | 84500 ns | 82875 ns | 1.02 |
| array/shared/copyto!/gpu_to_cpu | 83417 ns | 83000 ns | 1.01 |
| array/shared/iteration/findall/int | 1581291.5 ns | 1585854.5 ns | 1.00 |
| array/shared/iteration/findall/bool | 1431312 ns | 1421875 ns | 1.01 |
| array/shared/iteration/findfirst/int | 1657209 ns | 1654709 ns | 1.00 |
| array/shared/iteration/findfirst/bool | 1650937.5 ns | 1643542 ns | 1.00 |
| array/shared/iteration/scalar | 213334 ns | 210375 ns | 1.01 |
| array/shared/iteration/logical | 2285583.5 ns | 2297959 ns | 0.99 |
| array/shared/iteration/findmin/1d | 2127208 ns | 2134229 ns | 1.00 |
| array/shared/iteration/findmin/2d | 1801646 ns | 1806042 ns | 1.00 |
| array/shared/copy | 233667 ns | 241812 ns | 0.97 |
| array/permutedims/4d | 2392916.5 ns | 2395583 ns | 1.00 |
| array/permutedims/2d | 1168167 ns | 1158833 ns | 1.01 |
| array/permutedims/3d | 1686916.5 ns | 1686541 ns | 1.00 |
| metal/synchronization/stream | 19666 ns | 19583 ns | 1.00 |
| metal/synchronization/context | 20291.5 ns | 20291 ns | 1.00 |
This comment was automatically generated by workflow using github-action-benchmark.
Implements GPU FFT operations for MtlArray using MPSGraph's fastFourierTransformWithTensor. This addresses issue JuliaGPU#270.
Features:
- plan_fft, plan_ifft, plan_bfft for ComplexF32 arrays
- Multi-dimensional FFT support (single axis, multiple axes, all axes)
- FFTW.jl-compatible API via AbstractFFTs.jl interface
- Plan execution via * operator and mul!
Implementation notes:
- Uses MPSGraphFFTDescriptor for forward/inverse control
- Scaling handled manually for multi-axis ifft (not via MPSGraph's scalingMode) to ensure correct normalization across all FFT dimensions
- Axis mapping accounts for Julia's column-major vs Metal's row-major ordering via shape reversal in placeholderTensor
Tested against FFTW.jl with <1e-4 relative tolerance for all operations.
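The two implementation notes above (shape reversal for the row-major API, and manual scaling of the inverse transform) can be sketched in plain Julia. This is a minimal illustration and not the PR's code: `rowmajor_axis`, `rowmajor_shape`, and `ifft_scale` are hypothetical names, and nothing here touches Metal or MPSGraph.

```julia
# Hypothetical helpers; the names are illustrative and not taken from the PR.

# Julia dims are 1-based and column-major. If the array's shape is handed to a
# row-major API in reversed order, Julia dim d of an N-d array corresponds to
# 0-based axis N - d on the row-major side.
rowmajor_axis(d::Integer, N::Integer) = N - d

# Shape as a row-major API would see it: the Julia size tuple, reversed.
rowmajor_shape(sz::Dims) = reverse(sz)

# Normalization for an inverse transform over `region`: divide by the product
# of the lengths of every transformed dimension.
ifft_scale(sz::Dims, region) = 1 / prod(sz[d] for d in region)

sz = (4, 8, 16)
rowmajor_shape(sz)                       # (16, 8, 4)
rowmajor_axis.((1, 2, 3), length(sz))    # (2, 1, 0)
ifft_scale(sz, (1, 3))                   # 1 / (4 * 16)
```

Computing the scale once from all transformed dimensions is one way to get the same normalization no matter how many axes the transform covers.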
Implements real-to-complex and complex-to-real FFT operations using MPSGraph's realToHermiteanFFTWithTensor and HermiteanToRealFFTWithTensor.
Features:
- plan_rfft for Float32 arrays (output size n÷2+1 in first FFT dimension)
- plan_irfft for ComplexF32 arrays (normalized inverse)
- plan_brfft for ComplexF32 arrays (unnormalized inverse)
- Proper handling of odd output sizes via roundToOddHermitean
- FFTW-compatible dimension conventions
The output size reduction follows FFTW convention: the first transformed dimension is reduced to n÷2+1 for rfft, and irfft requires the original size to be specified.
Tested against FFTW.jl with <1e-4 relative tolerance.
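The size convention described above can be worked through with a small sketch. The helper names `rfft_output_size` and `irfft_output_size` are hypothetical and only compute shapes; they show why the inverse needs the original length, since an even and an odd length map to the same complex size.

```julia
# Hypothetical helper names; only array shapes are computed here.

# Output size of a real-to-complex FFT: only the first transformed dimension
# shrinks, from n to n ÷ 2 + 1.
function rfft_output_size(sz::Dims, region)
    d = first(region)
    return ntuple(i -> i == d ? sz[i] ÷ 2 + 1 : sz[i], length(sz))
end

# Inverse direction: the original length d must be supplied, because
# lengths 2(m - 1) and 2(m - 1) + 1 both give m complex outputs.
function irfft_output_size(sz::Dims, d::Integer, region)
    k = first(region)
    sz[k] == d ÷ 2 + 1 ||
        throw(ArgumentError("length $d is inconsistent with input size $(sz[k])"))
    return ntuple(i -> i == k ? d : sz[i], length(sz))
end

rfft_output_size((64, 32), (1, 2))        # (33, 32)
rfft_output_size((65, 32), (1, 2))        # (33, 32), same as for length 64
irfft_output_size((33, 32), 65, (1, 2))   # (65, 32)
```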
Phase 4: Type generalization for FFT operations.
Changes:
- Add ComplexF16 support for fft/ifft/bfft
- Add Float16 support for rfft/irfft/brfft
- Improve error messages for unsupported types (ComplexF64/Float64)
- Add documentation about supported and unsupported types
- Add FFTComplexTypes and FFTRealTypes type unions
Note: ComplexF64/Float64 are NOT supported by Metal's MPSGraph FFT. Users requiring double precision should use FFTW.jl on CPU. ComplexF16 results have ~3 decimal digits precision (expected for Float16).
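A rough sketch of how type unions like these can gate plan creation follows. The union names come from the commit message, but their definitions here are an assumption based on the listed supported types, and `check_fft_eltype` and its error text are hypothetical.

```julia
# Assumed definitions, inferred from the supported types listed above.
const FFTComplexTypes = Union{ComplexF16, ComplexF32}
const FFTRealTypes = Union{Float16, Float32}

# Hypothetical validation helper: supported element types pass silently,
# anything else gets an error that points to the CPU fallback.
check_fft_eltype(::Type{<:FFTComplexTypes}) = nothing
check_fft_eltype(::Type{<:FFTRealTypes}) = nothing
function check_fft_eltype(::Type{T}) where {T}
    throw(ArgumentError("FFT on MtlArray does not support element type $T; " *
                        "for double precision, use FFTW.jl on the CPU"))
end

check_fft_eltype(ComplexF32)   # returns nothing
check_fft_eltype(Float16)      # returns nothing
# check_fft_eltype(Float64)    # would throw an ArgumentError
```

Dispatching on the union keeps the supported set in one place, so widening support later would only mean editing the two constants.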
Phase 5: Verify and test 1D FFT support.
The existing implementation already handles 1D arrays correctly since the multi-dimensional FFT code works for arbitrary dimensions.
Added tests:
- 1D fft correctness vs FFTW
- 1D ifft roundtrip
- 1D rfft correctness vs FFTW
- 1D rfft -> irfft roundtrip
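What a 1D roundtrip check verifies, and how the normalized inverse differs from the unnormalized backward transform, can be shown with a naive CPU-only DFT. This is purely illustrative: the `naive_*` functions are hypothetical stand-ins written in Float64, not the GPU path or the PR's test code.

```julia
using LinearAlgebra  # stdlib; array `≈` relies on norm

# Naive O(n^2) DFT, for illustration only; sign = -1 forward, +1 backward.
function naive_dft(x::AbstractVector{<:Complex}, sign::Integer)
    n = length(x)
    return [sum(x[j + 1] * cispi(sign * 2 * j * k / n) for j in 0:n-1) for k in 0:n-1]
end

naive_fft(x)  = naive_dft(x, -1)
naive_bfft(x) = naive_dft(x, +1)             # unnormalized backward transform
naive_ifft(x) = naive_bfft(x) ./ length(x)   # normalized inverse

x = ComplexF64[1, 2im, -1, 0.5]
naive_ifft(naive_fft(x)) ≈ x                 # roundtrip recovers the input
naive_bfft(naive_fft(x)) ≈ length(x) .* x    # bfft leaves the factor n in place
```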
Phase 6: Implement in-place FFT operations.
New types:
- MtlFFTInplacePlan{T,K,N} - plan that modifies input directly
New functions:
- plan_fft!(x, [region]) - create in-place forward FFT plan
- plan_ifft!(x, [region]) - create in-place inverse FFT plan
- plan_bfft!(x, [region]) - create in-place backward FFT plan
The in-place variants are useful for avoiding memory allocation when
the input data is no longer needed after the transform.
Usage:
x = MtlArray(randn(ComplexF32, 64, 64))
plan = plan_fft!(x)
plan * x # x is modified in-place, returns x
- Remove shift parameter from all plan types and functions
- Remove _apply_fftshift_to_tensor helper
- Simplify _execute_single_axis_fft!
- Remove fftshift tests (14 tests removed)
- Aligns API with CUDA.jl which has no built-in fftshift
- Add Float16 CPU shims for FFTW reference comparisons
- Add tolerance functions based on type precision
- Restructure tests with reusable test functions
- Add @inferred checks for plan creation
- Add comprehensive batched transform tests (1D, 2D, 3D)
- Test both ComplexF16/ComplexF32 and Float16/Float32
- Tests increased from 55 to 117
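As a sketch of how type-based tolerances and a half-precision reference shim might fit together: every name below (`fft_rtol`, `promote_for_reference`, `matches_reference`) is hypothetical, and the tolerance values are assumptions chosen to line up with the precision figures quoted in the earlier commit messages, not values taken from the PR.

```julia
using LinearAlgebra  # stdlib; array isapprox relies on norm

# Hypothetical tolerance helper; the values are illustrative assumptions.
fft_rtol(::Type{Float32}) = 1f-4     # "<1e-4 relative tolerance"
fft_rtol(::Type{Float16}) = 1f-2     # roughly 3 decimal digits, with headroom
fft_rtol(::Type{Complex{T}}) where {T} = fft_rtol(T)

# Hypothetical shim: promote half-precision data to single precision so it can
# be compared against a reference computed by a CPU library that lacks
# Float16 support.
promote_for_reference(x::AbstractArray{Float16}) = Float32.(x)
promote_for_reference(x::AbstractArray{ComplexF16}) = ComplexF32.(x)
promote_for_reference(x::AbstractArray) = x

# Compare a result against the reference at the tolerance of the result's type.
matches_reference(result, reference) =
    isapprox(promote_for_reference(result), reference; rtol = fft_rtol(eltype(result)))

matches_reference(Float16[1, 2, 3], Float32[1, 2, 3.001])   # true
```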
…ormatting in operations.jl
They can be added in a different PR
Adds FFT support for MtlArray via the AbstractFFTs.jl interface. HEAVILY based on CUDA.jl's AbstractFFTs.jl interface implementation, using MPSGraph functionality.
Performance
Benchmarked on Apple M2 Max with 30-core GPU against FFTW.jl on CPU:
Example Usage
Close #270