
Add FFT support via AbstractFFTs interface#713

Open
KaanKesginLW wants to merge 39 commits into JuliaGPU:main from KaanKesginLW:feature/fft-support

Conversation


@KaanKesginLW KaanKesginLW commented Dec 3, 2025

Adds FFT support for MtlArray via the AbstractFFTs.jl interface.

HEAVILY based on CUDA.jl's AbstractFFTs.jl interface implementation, using MPSGraph functionality.

using Metal

x = MtlArray(randn(ComplexF32, 2048, 2048))
y = fft(x)  # Just works!

Performance

Benchmarked on Apple M2 Max with 30-core GPU against FFTW.jl on CPU:

| Size | CPU (FFTW) | GPU (Metal) |
| --- | --- | --- |
| 512×512 | 4.1 ms | 5.3 ms |
| 1024×1024 | 19.7 ms | 8.5 ms |
| 2048×2048 | 119.7 ms | 10.5 ms |
| 4096×4096 | 460.6 ms | 15.8 ms |
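Numbers of this shape can be reproduced along the following lines. This is a sketch, not the benchmark script used for the table above: it assumes FFTW.jl and BenchmarkTools.jl are installed, and that `Metal.@sync` is available to synchronize GPU work before timing.

```julia
using Metal, FFTW, BenchmarkTools

for n in (512, 1024, 2048, 4096)
    x_cpu = randn(ComplexF32, n, n)
    x_gpu = MtlArray(x_cpu)
    t_cpu = @belapsed fft($x_cpu)
    # Wrap in Metal.@sync so the GPU timing covers the whole transform,
    # not just the asynchronous kernel launch.
    t_gpu = @belapsed Metal.@sync fft($x_gpu)
    println(n, "×", n, ": CPU ", round(t_cpu * 1e3; digits = 1),
            " ms, GPU ", round(t_gpu * 1e3; digits = 1), " ms")
end
```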

Example Usage

using Metal

# Complex FFT
x = MtlArray(randn(ComplexF32, 1024, 1024))
y = fft(x)
z = ifft(y)  # z ≈ x

# Real FFT  
r = MtlArray(randn(Float32, 1024, 1024))
c = rfft(r)           # Real → Complex
r2 = irfft(c, 1024)   # Complex → Real, r2 ≈ r

# FFT along specific dimensions
y = fft(x, 1)         # First dimension only
y = fft(x, (1, 2))    # Batched transform

# Plan reuse
x = MtlArray(randn(ComplexF32, 1024, 1024))
another_x = MtlArray(randn(ComplexF32, 1024, 1024))
p = plan_fft(x)
y1 = p * x
y2 = p * another_x    # Same plan, different data
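Since AbstractFFTs plans also support `LinearAlgebra.mul!`, the output buffer can be preallocated and reused as well (a sketch; the explicit `using AbstractFFTs` is an assumption about how the names are brought into scope):

```julia
using Metal, LinearAlgebra
using AbstractFFTs

x = MtlArray(randn(ComplexF32, 1024, 1024))
p = plan_fft(x)
y = similar(x)   # preallocated output buffer
mul!(y, p, x)    # writes fft(x) into y without allocating a new array
```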

Closes #270

@KaanKesginLW KaanKesginLW mentioned this pull request Dec 3, 2025

codecov bot commented Dec 3, 2025

Codecov Report

❌ Patch coverage is 80.59701% with 26 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.21%. Comparing base (1d2f000) to head (25108fd).

Files with missing lines Patch % Lines
src/fft.jl 78.15% 26 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #713      +/-   ##
==========================================
+ Coverage   82.01%   82.21%   +0.20%     
==========================================
  Files          62       64       +2     
  Lines        2874     3008     +134     
==========================================
+ Hits         2357     2473     +116     
- Misses        517      535      +18     



@github-actions github-actions bot left a comment


Metal Benchmarks

Details
Benchmark suite Current: 25108fd Previous: 1d2f000 Ratio
latency/precompile 25538749208 ns 25549419083 ns 1.00
latency/ttfp 2355936292 ns 2346831687.5 ns 1.00
latency/import 1432847167 ns 1427666042 ns 1.00
integration/metaldevrt 875875 ns 877750 ns 1.00
integration/byval/slices=1 1565750 ns 1568625 ns 1.00
integration/byval/slices=3 8760125 ns 8402792 ns 1.04
integration/byval/reference 1561916 ns 1559958 ns 1.00
integration/byval/slices=2 2656437.5 ns 2629875 ns 1.01
kernel/indexing 652167 ns 627417 ns 1.04
kernel/indexing_checked 636729.5 ns 608750 ns 1.05
kernel/launch 12542 ns 12667 ns 0.99
kernel/rand 569458 ns 576167 ns 0.99
array/construct 6750 ns 6500 ns 1.04
array/broadcast 602750 ns 606708 ns 0.99
array/random/randn/Float32 1020979 ns 1011104 ns 1.01
array/random/randn!/Float32 742021 ns 753875 ns 0.98
array/random/rand!/Int64 550708 ns 548708 ns 1.00
array/random/rand!/Float32 592333 ns 586208.5 ns 1.01
array/random/rand/Int64 792687.5 ns 789709 ns 1.00
array/random/rand/Float32 585584 ns 645000 ns 0.91
array/accumulate/Int64/1d 1263333.5 ns 1260667 ns 1.00
array/accumulate/Int64/dims=1 1922583 ns 1859104.5 ns 1.03
array/accumulate/Int64/dims=2 2164958.5 ns 2179083 ns 0.99
array/accumulate/Int64/dims=1L 11703708 ns 11673271 ns 1.00
array/accumulate/Int64/dims=2L 9821083 ns 9628146 ns 1.02
array/accumulate/Float32/1d 1116833.5 ns 1121395.5 ns 1.00
array/accumulate/Float32/dims=1 1565667 ns 1571667 ns 1.00
array/accumulate/Float32/dims=2 1894187.5 ns 1889459 ns 1.00
array/accumulate/Float32/dims=1L 9867125 ns 9834209 ns 1.00
array/accumulate/Float32/dims=2L 7226562.5 ns 7249666.5 ns 1.00
array/reductions/reduce/Int64/1d 1353750 ns 1386875 ns 0.98
array/reductions/reduce/Int64/dims=1 1100666 ns 1117250 ns 0.99
array/reductions/reduce/Int64/dims=2 1145042 ns 1152958 ns 0.99
array/reductions/reduce/Int64/dims=1L 2028292 ns 2013209 ns 1.01
array/reductions/reduce/Int64/dims=2L 4254000 ns 4244083 ns 1.00
array/reductions/reduce/Float32/1d 1050125 ns 988750 ns 1.06
array/reductions/reduce/Float32/dims=1 834959 ns 843520.5 ns 0.99
array/reductions/reduce/Float32/dims=2 864958 ns 857917 ns 1.01
array/reductions/reduce/Float32/dims=1L 1333187 ns 1326625 ns 1.00
array/reductions/reduce/Float32/dims=2L 1831000 ns 1810667 ns 1.01
array/reductions/mapreduce/Int64/1d 1349250 ns 1356437.5 ns 0.99
array/reductions/mapreduce/Int64/dims=1 1107667 ns 1102166.5 ns 1.00
array/reductions/mapreduce/Int64/dims=2 1161667 ns 1149750 ns 1.01
array/reductions/mapreduce/Int64/dims=1L 2038875 ns 1988375 ns 1.03
array/reductions/mapreduce/Int64/dims=2L 3628729 ns 3626916 ns 1.00
array/reductions/mapreduce/Float32/1d 1034583.5 ns 1055917 ns 0.98
array/reductions/mapreduce/Float32/dims=1 847834 ns 847396 ns 1.00
array/reductions/mapreduce/Float32/dims=2 861584 ns 860979.5 ns 1.00
array/reductions/mapreduce/Float32/dims=1L 1339437.5 ns 1333042 ns 1.00
array/reductions/mapreduce/Float32/dims=2L 1821333 ns 1898125 ns 0.96
array/private/copyto!/gpu_to_gpu 641875 ns 633020.5 ns 1.01
array/private/copyto!/cpu_to_gpu 754854 ns 804354.5 ns 0.94
array/private/copyto!/gpu_to_cpu 807000 ns 816000 ns 0.99
array/private/iteration/findall/int 1582458 ns 1581312.5 ns 1.00
array/private/iteration/findall/bool 1413958 ns 1404916.5 ns 1.01
array/private/iteration/findfirst/int 2120875 ns 2075167 ns 1.02
array/private/iteration/findfirst/bool 2063250 ns 2048750 ns 1.01
array/private/iteration/scalar 3877812.5 ns 4526479 ns 0.86
array/private/iteration/logical 2691500 ns 2693625 ns 1.00
array/private/iteration/findmin/1d 2277458 ns 2518041 ns 0.90
array/private/iteration/findmin/2d 1814146 ns 1820229.5 ns 1.00
array/private/copy 581042 ns 568854 ns 1.02
array/shared/copyto!/gpu_to_gpu 84750 ns 84291 ns 1.01
array/shared/copyto!/cpu_to_gpu 84500 ns 82875 ns 1.02
array/shared/copyto!/gpu_to_cpu 83417 ns 83000 ns 1.01
array/shared/iteration/findall/int 1581291.5 ns 1585854.5 ns 1.00
array/shared/iteration/findall/bool 1431312 ns 1421875 ns 1.01
array/shared/iteration/findfirst/int 1657209 ns 1654709 ns 1.00
array/shared/iteration/findfirst/bool 1650937.5 ns 1643542 ns 1.00
array/shared/iteration/scalar 213334 ns 210375 ns 1.01
array/shared/iteration/logical 2285583.5 ns 2297959 ns 0.99
array/shared/iteration/findmin/1d 2127208 ns 2134229 ns 1.00
array/shared/iteration/findmin/2d 1801646 ns 1806042 ns 1.00
array/shared/copy 233667 ns 241812 ns 0.97
array/permutedims/4d 2392916.5 ns 2395583 ns 1.00
array/permutedims/2d 1168167 ns 1158833 ns 1.01
array/permutedims/3d 1686916.5 ns 1686541 ns 1.00
metal/synchronization/stream 19666 ns 19583 ns 1.00
metal/synchronization/context 20291.5 ns 20291 ns 1.00

This comment was automatically generated by workflow using github-action-benchmark.


@christiangnrd christiangnrd force-pushed the feature/fft-support branch 2 times, most recently from 130ed6a to e3aeeea Compare January 19, 2026 01:09


@christiangnrd christiangnrd marked this pull request as draft February 2, 2026 01:36
@christiangnrd christiangnrd force-pushed the feature/fft-support branch 2 times, most recently from e8d6b2c to ffdffe8 Compare February 4, 2026 02:59
@christiangnrd christiangnrd marked this pull request as ready for review February 4, 2026 03:31
@christiangnrd christiangnrd force-pushed the feature/fft-support branch 3 times, most recently from 3c46dc2 to 60077f5 Compare February 19, 2026 16:38
Implements GPU FFT operations for MtlArray using MPSGraph's
fastFourierTransformWithTensor. This addresses issue JuliaGPU#270.

Features:
- plan_fft, plan_ifft, plan_bfft for ComplexF32 arrays
- Multi-dimensional FFT support (single axis, multiple axes, all axes)
- FFTW.jl-compatible API via AbstractFFTs.jl interface
- Plan execution via * operator and mul!

Implementation notes:
- Uses MPSGraphFFTDescriptor for forward/inverse control
- Scaling handled manually for multi-axis ifft (not via MPSGraph's
  scalingMode) to ensure correct normalization across all FFT dimensions
- Axis mapping accounts for Julia's column-major vs Metal's row-major
  ordering via shape reversal in placeholderTensor

Tested against FFTW.jl with <1e-4 relative tolerance for all operations.
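The manual scaling described in the notes above corresponds to the standard AbstractFFTs normalization convention, which can be illustrated as follows (a sketch; `≈` uses default Float32 tolerances):

```julia
using Metal, AbstractFFTs

x = MtlArray(randn(ComplexF32, 256, 256))

# ifft divides by the total length across *all* transformed dimensions;
# bfft is the unnormalized backward transform, so bfft(fft(x)) == length(x) .* x.
@assert Array(ifft(fft(x))) ≈ Array(x)
@assert Array(bfft(fft(x))) ≈ length(x) .* Array(x)
```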

Implements real-to-complex and complex-to-real FFT operations using
MPSGraph's realToHermiteanFFTWithTensor and HermiteanToRealFFTWithTensor.

Features:
- plan_rfft for Float32 arrays (output size n÷2+1 in first FFT dimension)
- plan_irfft for ComplexF32 arrays (normalized inverse)
- plan_brfft for ComplexF32 arrays (unnormalized inverse)
- Proper handling of odd output sizes via roundToOddHermitean
- FFTW-compatible dimension conventions

The output size reduction follows FFTW convention: the first transformed
dimension is reduced to n÷2+1 for rfft, and irfft requires the original
size to be specified.

Tested against FFTW.jl with <1e-4 relative tolerance.
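The size convention described above looks like this in practice (a sketch using assumed array sizes):

```julia
using Metal, AbstractFFTs

r = MtlArray(randn(Float32, 100, 64))

c = rfft(r)                 # first FFT dimension shrinks to 100 ÷ 2 + 1
@assert size(c) == (51, 64)

# irfft needs the original length: a Hermitian half of 51 bins could have
# come from either n = 100 or n = 101, so the caller must disambiguate.
r2 = irfft(c, 100)
@assert size(r2) == (100, 64)
```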

Phase 4: Type generalization for FFT operations.

Changes:
- Add ComplexF16 support for fft/ifft/bfft
- Add Float16 support for rfft/irfft/brfft
- Improve error messages for unsupported types (ComplexF64/Float64)
- Add documentation about supported and unsupported types
- Add FFTComplexTypes and FFTRealTypes type unions

Note: ComplexF64/Float64 are NOT supported by Metal's MPSGraph FFT.
Users requiring double precision should use FFTW.jl on CPU.

ComplexF16 results have roughly 3 decimal digits of precision (expected for Float16).
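For double-precision inputs there are two obvious workarounds, sketched below (the CPU path assumes FFTW.jl is installed):

```julia
using Metal, AbstractFFTs, FFTW

x64 = randn(ComplexF64, 512, 512)

# MPSGraph has no double-precision FFT, so either trade precision for speed
# by demoting to ComplexF32 on the GPU...
y32 = fft(MtlArray(ComplexF32.(x64)))

# ...or keep full Float64 precision with FFTW on the CPU.
y64 = fft(x64)
```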

Phase 5: Verify and test 1D FFT support.

The existing implementation already handles 1D arrays correctly since
the multi-dimensional FFT code works for arbitrary dimensions.

Added tests:
- 1D fft correctness vs FFTW
- 1D ifft roundtrip
- 1D rfft correctness vs FFTW
- 1D rfft -> irfft roundtrip

Phase 6: Implement in-place FFT operations.

New types:
- MtlFFTInplacePlan{T,K,N} - plan that modifies input directly

New functions:
- plan_fft!(x, [region]) - create in-place forward FFT plan
- plan_ifft!(x, [region]) - create in-place inverse FFT plan
- plan_bfft!(x, [region]) - create in-place backward FFT plan

The in-place variants are useful for avoiding memory allocation when
the input data is no longer needed after the transform.

Usage:
    x = MtlArray(randn(ComplexF32, 64, 64))
    plan = plan_fft!(x)
    plan * x  # x is modified in-place, returns x

KaanKesginLW and others added 29 commits February 23, 2026 14:42
- Remove shift parameter from all plan types and functions
- Remove _apply_fftshift_to_tensor helper
- Simplify _execute_single_axis_fft!
- Remove fftshift tests (14 tests removed)
- Aligns API with CUDA.jl which has no built-in fftshift
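Even without a built-in shift parameter, `fftshift` and `ifftshift` from AbstractFFTs are defined generically in terms of `circshift`, so they should still compose with these transforms. A sketch (an assumption about this PR, not something stated in it):

```julia
using Metal, AbstractFFTs

x = MtlArray(randn(ComplexF32, 64, 64))

y  = fftshift(fft(x))      # zero-frequency bin moved to the center
x2 = ifft(ifftshift(y))    # undo the shift before inverting
# Array(x2) should be ≈ Array(x)
```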

- Add Float16 CPU shims for FFTW reference comparisons
- Add tolerance functions based on type precision
- Restructure tests with reusable test functions
- Add @inferred checks for plan creation
- Add comprehensive batched transform tests (1D, 2D, 3D)
- Test both ComplexF16/ComplexF32 and Float16/Float32
- Tests increased from 55 to 117

They can be added in a different PR