diff --git a/.github/workflows/baseline.yml b/.github/workflows/baseline.yml index 4b514dee..a1f6d10f 100644 --- a/.github/workflows/baseline.yml +++ b/.github/workflows/baseline.yml @@ -34,7 +34,7 @@ jobs: name: Save Benchmark Baseline runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Install Rust uses: dtolnay/rust-toolchain@stable @@ -49,7 +49,7 @@ jobs: # Cache keyed by SHA so each merge gets its own entry. # benchmark.yml uses restore-keys prefix matching to find the latest one. - name: Cache baseline - uses: actions/cache/save@v4 + uses: actions/cache/save@v5 with: path: target/fluxbench/baseline.json key: numr-bench-baseline-${{ github.sha }} diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 3c4ee4a6..d0a6fd6a 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -42,7 +42,7 @@ jobs: name: Regression Check runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 with: fetch-depth: 0 @@ -61,7 +61,7 @@ jobs: # picks the latest cache entry starting with "numr-bench-baseline-". - name: Restore baseline from main id: baseline-cache - uses: actions/cache/restore@v4 + uses: actions/cache/restore@v5 with: path: target/fluxbench/baseline.json key: numr-bench-baseline-dummy diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index a53be6c1..2b2ffe1b 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -23,7 +23,7 @@ jobs: outputs: version: ${{ steps.version.outputs.version }} steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Install Rust uses: dtolnay/rust-toolchain@stable @@ -71,7 +71,7 @@ jobs: runs-on: ubuntu-latest environment: crates-io steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Install Rust uses: dtolnay/rust-toolchain@stable diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 696e9828..f9e5221d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -24,7 +24,7 @@ jobs: name: Lint, Format & Docs runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Install Rust uses: dtolnay/rust-toolchain@stable @@ -56,7 +56,7 @@ jobs: os: [ubuntu-latest, macos-latest, windows-latest] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Install Rust uses: dtolnay/rust-toolchain@stable @@ -75,7 +75,7 @@ jobs: name: Backend Compile, Parity & Examples runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Install Rust uses: dtolnay/rust-toolchain@stable @@ -86,7 +86,7 @@ jobs: # Backend compile gates - name: "Compile: cpu-only (no default features)" - run: cargo check --no-default-features --features cpu + run: cargo check --no-default-features - name: "Compile: cpu + f16 + sparse" run: cargo check --features f16,sparse @@ -95,7 +95,7 @@ jobs: run: cargo check --features wgpu,f16,sparse - name: "Compile tests: cpu-only" - run: cargo test --no-run --no-default-features --features cpu + run: cargo test --no-run --no-default-features - name: "Compile tests: wgpu" run: cargo test --no-run --features wgpu,f16,sparse diff --git a/Cargo.toml b/Cargo.toml index e9b1a4d3..09522cb5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "numr" -version = "0.4.0" +version = "0.5.0" edition = "2024" rust-version = "1.89" description = "High-performance numerical computing with multi-backend GPU acceleration (CPU/CUDA/WebGPU)" @@ -15,14 +15,20 @@ features = ["f16", "sparse"] # cuda and wgpu require hardware SDKs not available on docs.rs [features] -default = ["cpu", "rayon"] -cpu = [] +default = ["rayon"] cuda = ["dep:cudarc"] +nccl = ["cuda", "cudarc?/nccl"] +distributed = ["dep:nexar", "dep:tokio"] +distributed-gpu = ["distributed", "nccl", "dep:nexar-nccl"] wgpu = ["dep:wgpu", "dep:pollster"] rayon = ["dep:rayon"] -f16 = ["dep:half", "cudarc?/f16"] # Half-precision floats (F16, BF16) - optional reduced-precision support -fp8 = [] # 8-bit floats (FP8E4M3, FP8E5M2) - optional ultra-low-precision support -sparse = [] # Sparse tensor formats (CSR, CSC, COO) and operations +f16 = [ + "dep:half", + "cudarc?/f16", +] # Half-precision floats (F16, BF16) - optional reduced-precision support +fp8 = [ +] # 8-bit floats (FP8E4M3, FP8E5M2) - optional ultra-low-precision support +sparse = [] # Sparse tensor formats (CSR, CSC, COO) and operations [dependencies] # Core @@ -35,11 +41,7 @@ parking_lot = "0.12" # Optional: Parallelism rayon = { version = "1.11", optional = true } -# Random number generation (required for rand/randn operations) -rand = "0.9" -rand_distr = "0.5" - -# Zero-copy serialization for embedded data +# Zero-copy serialization for embedded data (used by sobol_data) rkyv = "0.8" # Optional: Half-precision floats @@ -48,15 +50,20 @@ half = { version = "2.7", optional = true, features = [ "num-traits", ] } +# Optional: Inter-node distributed communication +nexar = { version = "0.1", optional = true } +nexar-nccl = { version = "0.1", optional = true } +tokio = { version = "1", features = ["rt"], optional = true } + # Optional: CUDA backend -cudarc = { version = "0.18", optional = true, features = [ +cudarc = { version = "0.19", optional = true, features = [ "cuda-version-from-build-system", ] } # Optional: WebGPU backend wgpu = { version = "28.0", optional = true } pollster = { version = "0.4", optional = true } -paste = "1.0.15" +paste = "1.0" [dev-dependencies] approx = "0.5" diff --git a/README.md b/README.md index ca43b8ee..53acc449 100644 --- a/README.md +++ b/README.md @@ -90,7 +90,7 @@ numr implements a comprehensive set of tensor operations across CPU, CUDA, and W ### Shape and Data Movement - **ShapeOps**: cat, stack, split, chunk, repeat, pad, roll -- **IndexingOps**: gather, scatter, gather_nd, scatter_reduce, index_select, masked_select, masked_fill, embedding_lookup, bincount, argmax, argmin +- **IndexingOps**: gather, scatter, gather_nd, scatter_reduce, index_select, masked_select, masked_fill, embedding_lookup, bincount, argmax, argmin, slice_assign - **SortingOps**: sort, argsort, topk, unique, nonzero, searchsorted ### Reductions @@ -106,22 +106,34 @@ numr implements a comprehensive set of tensor operations across CPU, CUDA, and W ### Activation & Normalization Functions -- **ActivationOps**: relu, sigmoid, silu, gelu, leaky_relu, elu, softmax -- **NormalizationOps**: rms_norm, layer_norm +- **ActivationOps**: relu, sigmoid, silu, gelu, swiglu, leaky_relu, elu, softmax, dropout, fused activation-mul (for gated architectures) +- **NormalizationOps**: rms_norm, layer_norm, batch_norm, group_norm, instance_norm, fused add-norm (residual + normalize in one pass) +- **GemmEpilogueOps**: fused matmul+bias+activation in a single kernel (forward + backward) +- **FusedElementwiseOps**: fused element-wise operation chains across all backends - **ConvOps**: conv1d, conv2d, depthwise_conv2d (with stride, padding, dilation, groups) +- **EinsumOps**: Einstein summation notation _These are mathematical functions commonly used in ML, but numr itself is not an ML framework._ ### Linear Algebra -- **MatmulOps**: matmul, matmul_bias (fused GEMM+bias) +- **MatmulOps**: matmul, matmul_bias (fused GEMM+bias), i8×i8→i32 quantized matmul, FP8 matmul - **LinalgOps**: solve, lstsq, pinverse, inverse, det, trace, matrix_rank, diag, matrix_norm, kron, khatri_rao - **ComplexOps**: conj, real, imag, angle (for complex tensor support) +### Automatic Differentiation + +- **Reverse-mode**: `Var` tracked tensors, `backward()` for gradient computation +- **Forward-mode**: `jvp()`, `jacobian_forward()` via dual numbers +- **Second-order**: `hvp()` for Hessian-vector products, `backward_with_graph()` for higher-order gradients +- **Activation checkpointing**: `checkpoint()` to trade compute for memory +- **Backward hooks**: `BackwardHook` trait for gradient notifications (e.g., distributed allreduce) +- **Differentiable ops**: matmul, conv1d, conv2d, softmax, rms_norm, layer_norm, SiLU, softplus, SwiGLU, dropout, fused GEMM epilogue, fused add-norm, dtype cast, narrow, cat + ### Statistics and Probability - **StatisticalOps**: var, std, skew, kurtosis, quantile, percentile, median, cov, corrcoef -- **RandomOps**: rand, randn, randint, multinomial, bernoulli, poisson, binomial, beta, gamma, exponential, chi_squared, student_t, f_distribution +- **RandomOps**: rand, randn, randint, multinomial, bernoulli, poisson, binomial, beta, gamma, exponential, chi_squared, student_t, f_distribution (with seeded deterministic generation) - **MultivariateRandomOps**: multivariate_normal, wishart, dirichlet - **QuasirandomOps**: Sobol, Halton sequences @@ -165,10 +177,38 @@ _These are mathematical functions commonly used in ML, but numr itself is not an - polyroots, polyval, polyfromroots, polymul +**Iterative Solvers (`numr::iterative`):** + +- **Linear solvers**: CG, MINRES, BiCGSTAB, GMRES, LGMRES, CGS, QMR, Jacobi, SOR, Adaptive GMRES +- **Eigensolvers**: Lanczos (symmetric), Arnoldi/IRAM (non-symmetric) +- **Sparse SVD**: via Lanczos bidiagonalization +- **Preconditioners**: ILU(0), IC(0), Algebraic Multigrid (AMG) with V-cycles + **Sparse Tensors (`numr::sparse`, feature-gated):** - Formats: CSR, CSC, COO - Operations: SpGEMM (sparse matrix multiplication), SpMV (sparse matrix-vector), DSMM (dense-sparse matrix) +- 2:4 structured sparsity with multi-backend support + +**Sparse Linear Algebra (`numr::sparse_linalg`):** + +- **Direct solvers**: Sparse LU (Gilbert-Peierls), sparse QR +- **Incomplete factorizations**: ILU(0), ILU(k), IC(0) +- **Preprocessing**: COLAMD ordering, maximum transversal +- **Symbolic/numeric split**: Reuse sparsity structure for repeated solves + +**Graph Capture (`numr::runtime`):** + +- **`Graph` trait**: Capture a sequence of operations and replay them with zero re-launch overhead +- **CUDA Graphs**: Full capture support—fixed-address buffer replay for inference loops and training steps +- **CPU / WebGPU**: Transparent no-op path; callers write backend-agnostic code using `R::supports_graph_capture()` + +**Distributed Computing (`numr::communicator`, feature `nccl`):** + +- **`CommunicatorGroup`**: Single-node multi-GPU all-reduce, broadcast, and allgather via NCCL +- **`HierarchicalCommunicator`**: Two-level collective—NCCL intra-node, nexar inter-node +- **`NexarNetCommunicator`**: Pure-Rust distributed transport (QUIC via nexar) for multi-machine tensor parallelism +- **`BackwardHook`**: Autograd hook interface—trigger cross-node gradient synchronization during `backward()` ## Dtypes @@ -198,15 +238,15 @@ Every operation supports every compatible dtype. No hardcoded f32-only kernels. All backends implement identical algorithms with native kernels—no cuBLAS, MKL, or vendor library dependencies. -| Hardware | Backend | Feature | Status | Notes | -| ------------ | ------- | ------------- | ------- | ------------------ | -| CPU (x86-64) | CPU | cpu (default) | ✓ | AVX-512/AVX2 SIMD | -| CPU (ARM64) | CPU | cpu | ✓ | NEON SIMD | -| NVIDIA GPU | CUDA | cuda | ✓ | Native PTX kernels | -| AMD GPU | WebGPU | wgpu | ✓ | WGSL shaders | -| Intel GPU | WebGPU | wgpu | ✓ | WGSL shaders | -| Apple GPU | WebGPU | wgpu | ✓ | WGSL shaders | -| AMD GPU | ROCm | - | Planned | Native HIP kernels | +| Hardware | Backend | Feature | Status | Notes | +| ------------ | ------- | ------------- | ------- | ------------------------------------------------------ | +| CPU (x86-64) | CPU | cpu (default) | ✓ | AVX-512/AVX2 SIMD | +| CPU (ARM64) | CPU | cpu | ✓ | NEON SIMD | +| NVIDIA GPU | CUDA | cuda | ✓ | Native PTX kernels, caching allocator, GEMV fast paths | +| AMD GPU | WebGPU | wgpu | ✓ | WGSL shaders | +| Intel GPU | WebGPU | wgpu | ✓ | WGSL shaders | +| Apple GPU | WebGPU | wgpu | ✓ | WGSL shaders | +| AMD GPU | ROCm | - | Planned | Native HIP kernels | ### SIMD Acceleration @@ -443,6 +483,45 @@ fn main() -> Result<()> { } ``` +### Automatic Differentiation + +```rust +use numr::prelude::*; +use numr::autograd::*; + +fn main() -> Result<()> { + let client = CpuRuntime::client()?; + + // Create tracked variables + let x = Var::new(Tensor::::from_slice(&[2.0, 3.0], &[2])?, true); + let w = Var::new(Tensor::::from_slice(&[0.5, -1.0], &[2])?, true); + + // Forward pass (builds computation graph) + let y = var_mul(&x, &w, &client)?; + let loss = var_sum(&y, &client)?; + + // Backward pass + let grads = backward(&loss, &client)?; + let dx = grads.get(x.tensor()); // gradients w.r.t. x + let dw = grads.get(w.tensor()); // gradients w.r.t. w + + // Activation checkpointing (trade compute for memory) + let checkpointed = checkpoint(|inputs| { + let h = var_relu(&inputs[0], &client)?; + var_matmul(&h, &inputs[1], &client) + }, &[&x, &w])?; + + // Forward-mode AD (Jacobian-vector products) + let tangent = Tensor::::ones(&[2], &device)?; + let jvp_result = jvp(|x| client.mul(x, x), &x.tensor(), &tangent, &client)?; + + // Hessian-vector product + let hvp_result = hvp(|x, c| c.mul(x, x), &x.tensor(), &tangent, &client)?; + + Ok(()) +} +``` + ## Installation ### CPU-only (default) @@ -484,7 +563,9 @@ numr = { version = "*", features = [ | `wgpu` | Cross-platform GPU (WebGPU) | ✗ | | `rayon` | Multi-threaded CPU via Rayon | ✓ | | `f16` | Half-precision floats (F16, BF16) | ✗ | +| `fp8` | FP8 precision (E4M3, E5M2) | ✗ | | `sparse` | Sparse tensor support (CSR, CSC, COO) | ✗ | +| `nccl` | Multi-GPU communication via NCCL | ✗ | ## Building from Source diff --git a/build.rs b/build.rs index 95a8aa4c..723018c1 100644 --- a/build.rs +++ b/build.rs @@ -37,6 +37,7 @@ fn compile_cuda_kernels() { #[allow(unused_mut)] let mut kernel_files = vec![ "activation.cu", + "softmax.cu", "advanced_random.cu", "binary.cu", "cast.cu", @@ -47,6 +48,10 @@ fn compile_cuda_kernels() { "distance.cu", "distributions.cu", "fft.cu", + "fused_activation_mul.cu", + "fused_activation_mul_bwd.cu", + "fused_add_norm.cu", + "fused_elementwise.cu", "index.cu", "linalg_advanced.cu", "linalg_banded.cu", @@ -59,6 +64,8 @@ fn compile_cuda_kernels() { "linalg_schur.cu", "linalg_solvers.cu", "linalg_svd.cu", + "fp8_matmul.cu", + "gemv.cu", "matmul.cu", "norm.cu", "semiring_matmul.cu", @@ -73,11 +80,14 @@ fn compile_cuda_kernels() { "ternary.cu", "unary.cu", "utility.cu", + "gemm_epilogue.cu", + "gemm_epilogue_bwd.cu", ]; // Add sparse kernels if sparse feature is enabled #[cfg(feature = "sparse")] { + kernel_files.push("sparse_24.cu"); kernel_files.push("sparse_spmv.cu"); kernel_files.push("sparse_merge.cu"); kernel_files.push("sparse_convert.cu"); @@ -114,6 +124,14 @@ fn compile_cuda_kernels() { panic!("nvcc not found - CUDA Toolkit must be installed for the 'cuda' feature"); }); + // Determine compute capability from NUMR_CUDA_ARCH env var, default sm_80 (Ampere) + // sm_80 enables tensor cores for F16/BF16, async copy, and other Ampere features + let cuda_arch = env::var("NUMR_CUDA_ARCH").unwrap_or_else(|_| "sm_80".to_string()); + println!( + "cargo:warning=numr: compiling {} CUDA kernels for {cuda_arch} (set NUMR_CUDA_ARCH to override)", + kernel_files.len() + ); + for kernel_file in kernel_files { let cu_path = kernels_dir.join(kernel_file); let ptx_name = kernel_file.replace(".cu", ".ptx"); @@ -131,15 +149,12 @@ fn compile_cuda_kernels() { ); } - // Compile to PTX - // Target: sm_75 (Turing) - supports CUDA 10.0+ - // This provides good compatibility while enabling modern features let output = Command::new(&nvcc) .args([ "-ptx", "-O3", "--use_fast_math", - "-arch=sm_75", + &format!("-arch={cuda_arch}"), "-o", ptx_path.to_str().unwrap(), cu_path.to_str().unwrap(), diff --git a/src/algorithm/iterative/helpers.rs b/src/algorithm/iterative/helpers.rs index c741e18a..baa23720 100644 --- a/src/algorithm/iterative/helpers.rs +++ b/src/algorithm/iterative/helpers.rs @@ -29,7 +29,7 @@ pub const REORTH_TOL: f64 = 1e-15; /// Uses optimized `item()` for scalar extraction (single element copy, no Vec allocation). pub fn vector_norm(client: &C, v: &Tensor) -> Result where - R: Runtime, + R: Runtime, C: BinaryOps + UnaryOps + ReduceOps, { // v^2 @@ -57,7 +57,7 @@ where /// Uses optimized `item()` for scalar extraction (single element copy, no Vec allocation). pub fn vector_dot(client: &C, u: &Tensor, v: &Tensor) -> Result where - R: Runtime, + R: Runtime, C: BinaryOps + ReduceOps, { // u * v @@ -176,7 +176,7 @@ pub fn update_solution( y: &[f64], ) -> Result> where - R: Runtime, + R: Runtime, C: BinaryOps + ScalarOps, { let m = y.len(); @@ -227,7 +227,7 @@ pub fn accumulate_basis_combination( device: &R::Device, ) -> Result> where - R: Runtime, + R: Runtime, C: BinaryOps + ScalarOps, { let mut result = Tensor::::zeros(&[n], dtype, device); @@ -249,7 +249,7 @@ where /// Used by Jacobi, SOR, and AMG V-cycle smoothing. pub fn extract_diagonal_inv(client: &C, a: &crate::sparse::CsrData) -> Result> where - R: Runtime, + R: Runtime, C: UnaryOps + BinaryOps + ScalarOps + crate::sparse::SparseOps, { let n = a.shape[0]; diff --git a/src/algorithm/iterative/impl_generic/adaptive_gmres.rs b/src/algorithm/iterative/impl_generic/adaptive_gmres.rs index af24e4f7..645c317c 100644 --- a/src/algorithm/iterative/impl_generic/adaptive_gmres.rs +++ b/src/algorithm/iterative/impl_generic/adaptive_gmres.rs @@ -39,7 +39,7 @@ pub fn adaptive_gmres_impl( adaptive_opts: AdaptivePreconditionerOptions, ) -> Result> where - R: Runtime, + R: Runtime, R::Client: SparseOps, C: SparseLinAlgAlgorithms + SparseOps @@ -181,7 +181,7 @@ fn gmres_with_iluk( residual_history: &mut Vec, ) -> Result> where - R: Runtime, + R: Runtime, R::Client: SparseOps, C: SparseLinAlgAlgorithms + SparseOps diff --git a/src/algorithm/iterative/impl_generic/amg.rs b/src/algorithm/iterative/impl_generic/amg.rs index 6e8a1a4c..11ca03bb 100644 --- a/src/algorithm/iterative/impl_generic/amg.rs +++ b/src/algorithm/iterative/impl_generic/amg.rs @@ -36,7 +36,7 @@ use super::amg_coarsen::{ /// The setup is done once and the hierarchy is reused for many V-cycles. pub fn amg_setup(client: &C, a: &CsrData, options: AmgOptions) -> Result> where - R: Runtime, + R: Runtime, R::Client: SparseOps, C: SparseOps + BinaryOps + UnaryOps + ReduceOps + ScalarOps, { @@ -161,7 +161,7 @@ pub fn amg_vcycle( level: usize, ) -> Result> where - R: Runtime, + R: Runtime, R::Client: SparseOps, C: SparseOps + BinaryOps + UnaryOps + ReduceOps + ScalarOps, { @@ -242,7 +242,7 @@ pub fn amg_preconditioned_cg( atol: f64, ) -> Result<(Tensor, usize, f64, bool)> where - R: Runtime, + R: Runtime, R::Client: SparseOps, C: SparseOps + BinaryOps + UnaryOps + ReduceOps + ScalarOps, { diff --git a/src/algorithm/iterative/impl_generic/amg_coarsen.rs b/src/algorithm/iterative/impl_generic/amg_coarsen.rs index 040c6b13..2f8af2f0 100644 --- a/src/algorithm/iterative/impl_generic/amg_coarsen.rs +++ b/src/algorithm/iterative/impl_generic/amg_coarsen.rs @@ -5,6 +5,7 @@ //! - PMIS (Parallel Modified Independent Set) coarsening //! - Classical interpolation with truncation +use crate::dtype::DType; use crate::error::Result; use crate::runtime::Runtime; use crate::sparse::CsrData; @@ -125,7 +126,7 @@ pub fn pmis_coarsening(strong_connections: &[Vec], n: usize) -> CfSplitti /// For coarse points: P[i, coarse_map[i]] = 1 /// For fine points: P[i, j] = -a_ij / a_ii for strongly connected coarse j, /// normalized to sum to 1 -pub fn build_interpolation( +pub fn build_interpolation>( row_ptrs: &[i64], col_indices: &[i64], values: &[f64], @@ -229,7 +230,7 @@ pub fn build_interpolation( } /// Build restriction operator R = P^T (transpose of interpolation) -pub fn build_restriction(p: &CsrData) -> Result> { +pub fn build_restriction>(p: &CsrData) -> Result> { // P^T: CsrData::transpose() returns CscData, then to_csr() gives CSR of P^T let pt = p.transpose().to_csr()?; Ok(pt) @@ -239,7 +240,7 @@ pub fn build_restriction(p: &CsrData) -> Result> { /// /// This is done via sparse matrix multiplication. /// For simplicity, we compute it via explicit SpMM on CPU. -pub fn galerkin_coarse_operator( +pub fn galerkin_coarse_operator>( row_ptrs: &[i64], col_indices: &[i64], values: &[f64], diff --git a/src/algorithm/iterative/impl_generic/arnoldi_eig.rs b/src/algorithm/iterative/impl_generic/arnoldi_eig.rs index 71b5d7b3..288e8656 100644 --- a/src/algorithm/iterative/impl_generic/arnoldi_eig.rs +++ b/src/algorithm/iterative/impl_generic/arnoldi_eig.rs @@ -34,7 +34,7 @@ pub fn arnoldi_eig_impl( options: SparseEigOptions, ) -> Result> where - R: Runtime, + R: Runtime, R::Client: SparseOps, C: SparseLinAlgAlgorithms + SparseOps @@ -223,7 +223,7 @@ fn build_result( nconv: usize, ) -> Result> where - R: Runtime, + R: Runtime, C: BinaryOps + ScalarOps, { let k_actual = k.min(indices.len()); @@ -289,7 +289,7 @@ fn thick_restart( device: &R::Device, ) -> Result<()> where - R: Runtime, + R: Runtime, R::Client: SparseOps, C: SparseLinAlgAlgorithms + SparseOps diff --git a/src/algorithm/iterative/impl_generic/bicgstab.rs b/src/algorithm/iterative/impl_generic/bicgstab.rs index 60466eeb..846cbd32 100644 --- a/src/algorithm/iterative/impl_generic/bicgstab.rs +++ b/src/algorithm/iterative/impl_generic/bicgstab.rs @@ -26,7 +26,7 @@ pub fn bicgstab_impl( options: BiCgStabOptions, ) -> Result> where - R: Runtime, + R: Runtime, R::Client: SparseOps, C: SparseLinAlgAlgorithms + SparseOps diff --git a/src/algorithm/iterative/impl_generic/cg.rs b/src/algorithm/iterative/impl_generic/cg.rs index 5528c1fa..768cca20 100644 --- a/src/algorithm/iterative/impl_generic/cg.rs +++ b/src/algorithm/iterative/impl_generic/cg.rs @@ -38,7 +38,7 @@ pub fn cg_impl( options: CgOptions, ) -> Result> where - R: Runtime, + R: Runtime, R::Client: SparseOps, C: SparseLinAlgAlgorithms + SparseOps diff --git a/src/algorithm/iterative/impl_generic/cgs.rs b/src/algorithm/iterative/impl_generic/cgs.rs index a9429637..de0f6946 100644 --- a/src/algorithm/iterative/impl_generic/cgs.rs +++ b/src/algorithm/iterative/impl_generic/cgs.rs @@ -49,7 +49,7 @@ pub fn cgs_impl( options: CgsOptions, ) -> Result> where - R: Runtime, + R: Runtime, R::Client: SparseOps, C: SparseLinAlgAlgorithms + SparseOps diff --git a/src/algorithm/iterative/impl_generic/gmres.rs b/src/algorithm/iterative/impl_generic/gmres.rs index 63544aea..ef5d9e39 100644 --- a/src/algorithm/iterative/impl_generic/gmres.rs +++ b/src/algorithm/iterative/impl_generic/gmres.rs @@ -38,7 +38,7 @@ pub fn gmres_impl( options: GmresOptions, ) -> Result> where - R: Runtime, + R: Runtime, R::Client: SparseOps, C: SparseLinAlgAlgorithms + SparseOps diff --git a/src/algorithm/iterative/impl_generic/jacobi.rs b/src/algorithm/iterative/impl_generic/jacobi.rs index 4e192535..1fcdc9c9 100644 --- a/src/algorithm/iterative/impl_generic/jacobi.rs +++ b/src/algorithm/iterative/impl_generic/jacobi.rs @@ -34,7 +34,7 @@ pub fn jacobi_impl( options: JacobiOptions, ) -> Result> where - R: Runtime, + R: Runtime, R::Client: SparseOps, C: SparseOps + BinaryOps + UnaryOps + ReduceOps + ScalarOps, { diff --git a/src/algorithm/iterative/impl_generic/lanczos_eig.rs b/src/algorithm/iterative/impl_generic/lanczos_eig.rs index 1a67d5e6..4ac1c1bf 100644 --- a/src/algorithm/iterative/impl_generic/lanczos_eig.rs +++ b/src/algorithm/iterative/impl_generic/lanczos_eig.rs @@ -33,7 +33,7 @@ pub fn lanczos_eig_impl( options: SparseEigOptions, ) -> Result> where - R: Runtime, + R: Runtime, R::Client: SparseOps, C: SparseLinAlgAlgorithms + SparseOps @@ -214,7 +214,7 @@ where /// Each column vector is transferred once from device to host, then the /// complete matrix is transferred back. This is O(k) transfers for final /// output assembly — not used in any iterative loop. -fn assemble_column_matrix( +fn assemble_column_matrix>( columns: &[Tensor], n: usize, k: usize, diff --git a/src/algorithm/iterative/impl_generic/lgmres.rs b/src/algorithm/iterative/impl_generic/lgmres.rs index a1023bd2..39e45cc7 100644 --- a/src/algorithm/iterative/impl_generic/lgmres.rs +++ b/src/algorithm/iterative/impl_generic/lgmres.rs @@ -40,7 +40,7 @@ pub fn lgmres_impl( options: LgmresOptions, ) -> Result> where - R: Runtime, + R: Runtime, R::Client: SparseOps, C: SparseLinAlgAlgorithms + SparseOps diff --git a/src/algorithm/iterative/impl_generic/minres.rs b/src/algorithm/iterative/impl_generic/minres.rs index 7c742bfa..de7d62e4 100644 --- a/src/algorithm/iterative/impl_generic/minres.rs +++ b/src/algorithm/iterative/impl_generic/minres.rs @@ -31,7 +31,7 @@ pub fn minres_impl( options: MinresOptions, ) -> Result> where - R: Runtime, + R: Runtime, R::Client: SparseOps, C: SparseLinAlgAlgorithms + SparseOps diff --git a/src/algorithm/iterative/impl_generic/qmr.rs b/src/algorithm/iterative/impl_generic/qmr.rs index 67c68c04..de8981cb 100644 --- a/src/algorithm/iterative/impl_generic/qmr.rs +++ b/src/algorithm/iterative/impl_generic/qmr.rs @@ -29,7 +29,7 @@ pub fn qmr_impl( options: QmrOptions, ) -> Result> where - R: Runtime, + R: Runtime, R::Client: SparseOps, C: SparseLinAlgAlgorithms + SparseOps diff --git a/src/algorithm/iterative/impl_generic/sor.rs b/src/algorithm/iterative/impl_generic/sor.rs index 2a492436..91f5d39b 100644 --- a/src/algorithm/iterative/impl_generic/sor.rs +++ b/src/algorithm/iterative/impl_generic/sor.rs @@ -37,7 +37,7 @@ pub fn sor_impl( options: SorOptions, ) -> Result> where - R: Runtime, + R: Runtime, R::Client: SparseOps, C: SparseLinAlgAlgorithms + SparseOps @@ -125,7 +125,7 @@ where /// - j < i: omega * a_ij (scaled strict lower triangle) /// - j == i: a_ii (diagonal, unscaled) /// - j > i: excluded (upper triangle) -fn build_sor_lower_triangular( +fn build_sor_lower_triangular>( a: &CsrData, omega: f64, device: &R::Device, diff --git a/src/algorithm/iterative/impl_generic/svds.rs b/src/algorithm/iterative/impl_generic/svds.rs index 5cab8ea2..99f2ea44 100644 --- a/src/algorithm/iterative/impl_generic/svds.rs +++ b/src/algorithm/iterative/impl_generic/svds.rs @@ -37,7 +37,7 @@ pub fn svds_impl( options: SvdsOptions, ) -> Result> where - R: Runtime, + R: Runtime, R::Client: SparseOps, C: SparseLinAlgAlgorithms + SparseOps diff --git a/src/algorithm/linalg/helpers.rs b/src/algorithm/linalg/helpers.rs index 601f52ed..75980e47 100644 --- a/src/algorithm/linalg/helpers.rs +++ b/src/algorithm/linalg/helpers.rs @@ -66,7 +66,7 @@ pub fn linalg_promote<'a, R, C>( tensor: &'a Tensor, ) -> Result<(std::borrow::Cow<'a, Tensor>, DType)> where - R: Runtime, + R: Runtime, C: TypeConversionOps, { let original_dtype = tensor.dtype(); @@ -90,7 +90,7 @@ pub fn linalg_demote( original_dtype: DType, ) -> Result> where - R: Runtime, + R: Runtime, C: TypeConversionOps, { if result.dtype() != original_dtype { diff --git a/src/algorithm/linalg/tensor_decompose_core.rs b/src/algorithm/linalg/tensor_decompose_core.rs index 2391a56e..12aa6781 100644 --- a/src/algorithm/linalg/tensor_decompose_core.rs +++ b/src/algorithm/linalg/tensor_decompose_core.rs @@ -19,7 +19,8 @@ use super::decompositions::{ }; use crate::dtype::DType; use crate::error::{Error, Result}; -use crate::ops::traits::{BinaryOps, MatmulOps, RandomOps, ReduceOps, UnaryOps}; +use crate::ops::traits::RandomOps; +use crate::ops::traits::{BinaryOps, MatmulOps, ReduceOps, UnaryOps}; use crate::runtime::Runtime; use crate::tensor::Tensor; @@ -135,7 +136,7 @@ fn unfold_permutation(mode: usize, ndim: usize) -> Vec { /// /// Unfolds tensor T of shape [I₁, I₂, ..., Iₙ] along mode n into matrix /// of shape [Iₙ, ∏ⱼ≠ₙ Iⱼ]. -pub fn unfold_impl( +pub fn unfold_impl>( tensor: &Tensor, mode: usize, dtype_support: TensorDecomposeDTypeSupport, @@ -168,7 +169,7 @@ pub fn unfold_impl( /// Mode-n folding (tensorization) - inverse of unfolding /// /// Reconstructs tensor from its mode-n unfolding. -pub fn fold_impl( +pub fn fold_impl>( matrix: &Tensor, mode: usize, shape: &[usize], @@ -232,7 +233,7 @@ pub fn mode_n_product_impl( dtype_support: TensorDecomposeDTypeSupport, ) -> Result> where - R: Runtime, + R: Runtime, C: MatmulOps, { let tensor_shape = tensor.shape(); @@ -287,7 +288,7 @@ pub fn hosvd_impl( dtype_support: TensorDecomposeDTypeSupport, ) -> Result> where - R: Runtime, + R: Runtime, C: LinearAlgebraAlgorithms + MatmulOps, { let shape = tensor.shape(); @@ -335,7 +336,7 @@ where /// Compute Frobenius norm of a tensor - returns GPU scalar tensor (no CPU transfer) fn frobenius_norm_tensor(client: &C, tensor: &Tensor) -> Result> where - R: Runtime, + R: Runtime, C: ReduceOps + BinaryOps + UnaryOps, { let sq = client.mul(tensor, tensor)?; @@ -355,7 +356,7 @@ pub fn tucker_impl( dtype_support: TensorDecomposeDTypeSupport, ) -> Result> where - R: Runtime, + R: Runtime, C: LinearAlgebraAlgorithms + MatmulOps + ReduceOps + BinaryOps + RandomOps, { let shape = tensor.shape(); @@ -437,7 +438,7 @@ fn initialize_cp_factors( dtype_support: TensorDecomposeDTypeSupport, ) -> Result>> where - R: Runtime, + R: Runtime, C: LinearAlgebraAlgorithms + RandomOps, { let shape = tensor.shape(); @@ -517,7 +518,7 @@ fn compute_gram_hadamard_except( skip_mode: usize, ) -> Result> where - R: Runtime, + R: Runtime, C: MatmulOps + BinaryOps, { let n = factors.len(); @@ -573,7 +574,7 @@ pub fn cp_decompose_impl( dtype_support: TensorDecomposeDTypeSupport, ) -> Result> where - R: Runtime, + R: Runtime, C: LinearAlgebraAlgorithms + MatmulOps + ReduceOps @@ -647,7 +648,7 @@ pub fn tensor_train_impl( dtype_support: TensorDecomposeDTypeSupport, ) -> Result> where - R: Runtime, + R: Runtime, C: LinearAlgebraAlgorithms + ReduceOps + BinaryOps + UnaryOps, { let shape = tensor.shape(); @@ -784,7 +785,7 @@ pub fn tucker_reconstruct_impl( dtype_support: TensorDecomposeDTypeSupport, ) -> Result> where - R: Runtime, + R: Runtime, C: MatmulOps, { let mut result = decomp.core.clone(); @@ -804,7 +805,7 @@ pub fn cp_reconstruct_impl( _dtype_support: TensorDecomposeDTypeSupport, ) -> Result> where - R: Runtime, + R: Runtime, C: LinearAlgebraAlgorithms + MatmulOps + BinaryOps, { let ndim = decomp.factors.len(); @@ -848,7 +849,7 @@ pub fn tt_reconstruct_impl( decomp: &TensorTrainDecomposition, ) -> Result> where - R: Runtime, + R: Runtime, C: MatmulOps, { if decomp.cores.is_empty() { diff --git a/src/algorithm/polynomial/core/convolve.rs b/src/algorithm/polynomial/core/convolve.rs index 7bb93ef9..5b443524 100644 --- a/src/algorithm/polynomial/core/convolve.rs +++ b/src/algorithm/polynomial/core/convolve.rs @@ -17,6 +17,7 @@ use super::DTypeSupport; use crate::algorithm::fft::{FftAlgorithms, FftNormalization}; +use crate::dtype::DType; use crate::error::Result; use crate::ops::{BinaryOps, ComplexOps, IndexingOps, ReduceOps, ShapeOps, UtilityOps}; use crate::runtime::{Runtime, RuntimeClient}; @@ -60,7 +61,7 @@ pub fn convolve_impl( dtype_support: DTypeSupport, ) -> Result> where - R: Runtime, + R: Runtime, C: RuntimeClient + BinaryOps + IndexingOps @@ -103,7 +104,7 @@ fn convolve_direct( dtype_support: DTypeSupport, ) -> Result> where - R: Runtime, + R: Runtime, C: RuntimeClient + BinaryOps + IndexingOps @@ -179,7 +180,7 @@ fn convolve_fft( dtype_support: DTypeSupport, ) -> Result> where - R: Runtime, + R: Runtime, C: RuntimeClient + BinaryOps + ShapeOps @@ -245,7 +246,7 @@ where /// This uses BinaryOps::mul which handles complex types via the Element trait. fn complex_mul(client: &C, a: &Tensor, b: &Tensor) -> Result> where - R: Runtime, + R: Runtime, C: BinaryOps, { // BinaryOps::mul handles complex multiplication natively diff --git a/src/algorithm/polynomial/core/mod.rs b/src/algorithm/polynomial/core/mod.rs index 6b879cd6..1bb8da4a 100644 --- a/src/algorithm/polynomial/core/mod.rs +++ b/src/algorithm/polynomial/core/mod.rs @@ -111,7 +111,7 @@ impl DTypeSupport { /// * `index` - The index value /// * `index_dtype` - The dtype for the index tensor (I32 or I64) /// * `device` - The device to create the tensor on -pub(crate) fn create_index_tensor( +pub(crate) fn create_index_tensor>( index: usize, index_dtype: DType, device: &R::Device, @@ -130,7 +130,7 @@ pub(crate) fn create_index_tensor( /// * `end` - End index (exclusive) /// * `index_dtype` - The dtype for the index tensor (I32 or I64) /// * `device` - The device to create the tensor on -pub(crate) fn create_arange_tensor( +pub(crate) fn create_arange_tensor>( start: usize, end: usize, index_dtype: DType, diff --git a/src/algorithm/polynomial/core/polyfromroots.rs b/src/algorithm/polynomial/core/polyfromroots.rs index 63a311bf..747fc161 100644 --- a/src/algorithm/polynomial/core/polyfromroots.rs +++ b/src/algorithm/polynomial/core/polyfromroots.rs @@ -3,6 +3,7 @@ use super::{DTypeSupport, convolve_impl, create_index_tensor}; use crate::algorithm::fft::FftAlgorithms; use crate::algorithm::polynomial::helpers::{validate_polynomial_dtype, validate_polynomial_roots}; +use crate::dtype::DType; use crate::error::{Error, Result}; use crate::ops::{BinaryOps, ComplexOps, IndexingOps, ReduceOps, ShapeOps, UnaryOps, UtilityOps}; use crate::runtime::{Runtime, RuntimeClient}; @@ -36,7 +37,7 @@ pub fn polyfromroots_impl( dtype_support: DTypeSupport, ) -> Result> where - R: Runtime, + R: Runtime, C: RuntimeClient + BinaryOps + UnaryOps diff --git a/src/algorithm/polynomial/core/polymul.rs b/src/algorithm/polynomial/core/polymul.rs index 9546e5ba..8f94bc6f 100644 --- a/src/algorithm/polynomial/core/polymul.rs +++ b/src/algorithm/polynomial/core/polymul.rs @@ -5,6 +5,7 @@ use crate::algorithm::fft::FftAlgorithms; use crate::algorithm::polynomial::helpers::{ validate_polynomial_coeffs, validate_polynomial_dtype, }; +use crate::dtype::DType; use crate::error::{Error, Result}; use crate::ops::{BinaryOps, ComplexOps, IndexingOps, ReduceOps, ShapeOps, UtilityOps}; use crate::runtime::{Runtime, RuntimeClient}; @@ -33,7 +34,7 @@ pub fn polymul_impl( dtype_support: DTypeSupport, ) -> Result> where - R: Runtime, + R: Runtime, C: RuntimeClient + BinaryOps + IndexingOps diff --git a/src/algorithm/polynomial/core/polyroots.rs b/src/algorithm/polynomial/core/polyroots.rs index f1a79b01..3be85500 100644 --- a/src/algorithm/polynomial/core/polyroots.rs +++ b/src/algorithm/polynomial/core/polyroots.rs @@ -5,6 +5,7 @@ use crate::algorithm::linalg::LinearAlgebraAlgorithms; use crate::algorithm::polynomial::helpers::validate_polynomial_coeffs; use crate::algorithm::polynomial::helpers::validate_polynomial_dtype; use crate::algorithm::polynomial::types::PolynomialRoots; +use crate::dtype::DType; use crate::error::Result; use crate::ops::{ BinaryOps, CompareOps, IndexingOps, LinalgOps, ReduceOps, ScalarOps, ShapeOps, UtilityOps, @@ -48,7 +49,7 @@ pub fn polyroots_impl( dtype_support: DTypeSupport, ) -> Result> where - R: Runtime, + R: Runtime, C: RuntimeClient + LinearAlgebraAlgorithms + BinaryOps diff --git a/src/algorithm/polynomial/core/polyval.rs b/src/algorithm/polynomial/core/polyval.rs index 4251d944..e31cc51e 100644 --- a/src/algorithm/polynomial/core/polyval.rs +++ b/src/algorithm/polynomial/core/polyval.rs @@ -4,6 +4,7 @@ use super::{DTypeSupport, create_index_tensor}; use crate::algorithm::polynomial::helpers::{ validate_polynomial_coeffs, validate_polynomial_dtype, }; +use crate::dtype::DType; use crate::error::{Error, Result}; use crate::ops::{BinaryOps, IndexingOps, ScalarOps, ShapeOps}; use crate::runtime::{Runtime, RuntimeClient}; @@ -34,7 +35,7 @@ pub fn polyval_impl( dtype_support: DTypeSupport, ) -> Result> where - R: Runtime, + R: Runtime, C: RuntimeClient + BinaryOps + ScalarOps + IndexingOps + ShapeOps, { validate_polynomial_dtype(coeffs.dtype())?; diff --git a/src/algorithm/sparse.rs b/src/algorithm/sparse.rs index f9c2f317..0d37da3d 100644 --- a/src/algorithm/sparse.rs +++ b/src/algorithm/sparse.rs @@ -90,7 +90,7 @@ pub trait SparseAlgorithms { // ============================================================================ /// Zero tolerance threshold for filtering small values -pub use crate::runtime::sparse_utils::zero_tolerance; +pub use crate::runtime::common::sparse_utils::zero_tolerance; /// Validate CSR matrix dimensions for SpGEMM pub fn validate_spgemm_shapes( diff --git a/src/algorithm/sparse_linalg/cpu/ic0.rs b/src/algorithm/sparse_linalg/cpu/ic0.rs index 6f56ada4..65482b51 100644 --- a/src/algorithm/sparse_linalg/cpu/ic0.rs +++ b/src/algorithm/sparse_linalg/cpu/ic0.rs @@ -36,7 +36,10 @@ use crate::tensor::Tensor; /// # Returns /// /// IC decomposition with lower triangular factor L -pub fn ic0_cpu(a: &CsrData, options: IcOptions) -> Result> { +pub fn ic0_cpu>( + a: &CsrData, + options: IcOptions, +) -> Result> { let n = validate_square_sparse(a.shape)?; let dtype = a.values().dtype(); validate_cpu_dtype(dtype)?; @@ -174,7 +177,7 @@ pub fn ic0_cpu(a: &CsrData, options: IcOptions) -> Result( +fn extract_lower_triangle>( n: usize, row_ptrs: &[i64], col_indices: &[i64], diff --git a/src/algorithm/sparse_linalg/cpu/ilu0.rs b/src/algorithm/sparse_linalg/cpu/ilu0.rs index 05f68950..830d8cee 100644 --- a/src/algorithm/sparse_linalg/cpu/ilu0.rs +++ b/src/algorithm/sparse_linalg/cpu/ilu0.rs @@ -31,7 +31,10 @@ use crate::tensor::Tensor; /// # Returns /// /// ILU decomposition with L (unit lower triangular) and U (upper triangular) -pub fn ilu0_cpu(a: &CsrData, options: IluOptions) -> Result> { +pub fn ilu0_cpu>( + a: &CsrData, + options: IluOptions, +) -> Result> { let n = validate_square_sparse(a.shape)?; let dtype = a.values().dtype(); validate_cpu_dtype(dtype)?; @@ -148,7 +151,7 @@ pub fn ilu0_cpu(a: &CsrData, options: IluOptions) -> Result( +fn split_lu>( n: usize, row_ptrs: &[i64], col_indices: &[i64], @@ -243,7 +246,7 @@ fn split_lu( /// Analyzes the sparsity pattern to create an efficient update schedule /// for numeric factorization. This avoids hash map lookups during the /// numeric phase. -pub fn ilu0_symbolic_cpu(pattern: &CsrData) -> Result { +pub fn ilu0_symbolic_cpu>(pattern: &CsrData) -> Result { let n = validate_square_sparse(pattern.shape)?; // Extract CSR structure for CPU-based symbolic analysis @@ -258,7 +261,7 @@ pub fn ilu0_symbolic_cpu(pattern: &CsrData) -> Result( +pub fn ilu0_numeric_cpu>( a: &CsrData, symbolic: &SymbolicIlu0, options: IluOptions, diff --git a/src/algorithm/sparse_linalg/cpu/iluk.rs b/src/algorithm/sparse_linalg/cpu/iluk.rs index 59a112f5..3fc4b0f1 100644 --- a/src/algorithm/sparse_linalg/cpu/iluk.rs +++ b/src/algorithm/sparse_linalg/cpu/iluk.rs @@ -24,7 +24,10 @@ use crate::tensor::Tensor; /// - `level[i,j]` = min over all paths i→k→j of: `level[i,k]` + `level[k,j]` + 1 /// /// Positions with `level[i,j]` ≤ k are included in the fill pattern. -pub fn iluk_symbolic_cpu(a: &CsrData, level: IluFillLevel) -> Result { +pub fn iluk_symbolic_cpu>( + a: &CsrData, + level: IluFillLevel, +) -> Result { let n = validate_square_sparse(a.shape)?; // Extract CSR structure for CPU-based symbolic analysis @@ -36,7 +39,7 @@ pub fn iluk_symbolic_cpu(a: &CsrData, level: IluFillLevel) -> Res } /// ILU(k) numeric factorization on CPU using precomputed symbolic data -pub fn iluk_numeric_cpu( +pub fn iluk_numeric_cpu>( a: &CsrData, symbolic: &IlukSymbolic, opts: &IlukOptions, @@ -251,7 +254,10 @@ pub fn iluk_numeric_cpu( } /// Combined ILU(k) factorization (symbolic + numeric) -pub fn iluk_cpu(a: &CsrData, opts: IlukOptions) -> Result> { +pub fn iluk_cpu>( + a: &CsrData, + opts: IlukOptions, +) -> Result> { let symbolic = iluk_symbolic_cpu(a, opts.fill_level)?; iluk_numeric_cpu(a, &symbolic, &opts) } diff --git a/src/algorithm/sparse_linalg/cpu/triangular_solve.rs b/src/algorithm/sparse_linalg/cpu/triangular_solve.rs index 71dd1dc0..513f20ca 100644 --- a/src/algorithm/sparse_linalg/cpu/triangular_solve.rs +++ b/src/algorithm/sparse_linalg/cpu/triangular_solve.rs @@ -36,7 +36,7 @@ use crate::tensor::Tensor; /// # Returns /// /// Solution vector x `[n]` or matrix `[n, k]` -pub fn sparse_solve_triangular_cpu( +pub fn sparse_solve_triangular_cpu>( l_or_u: &CsrData, b: &Tensor, lower: bool, diff --git a/src/algorithm/sparse_linalg/lu/cpu/lu.rs b/src/algorithm/sparse_linalg/lu/cpu/lu.rs index 59260f41..38aa30d8 100644 --- a/src/algorithm/sparse_linalg/lu/cpu/lu.rs +++ b/src/algorithm/sparse_linalg/lu/cpu/lu.rs @@ -16,7 +16,7 @@ use crate::tensor::Tensor; /// Sparse LU factorization with full symbolic information (CPU) /// /// Uses Gilbert-Peierls left-looking algorithm with partial pivoting. -pub fn sparse_lu_cpu( +pub fn sparse_lu_cpu>( a: &CscData, symbolic: &LuSymbolic, options: &LuOptions, @@ -26,7 +26,7 @@ pub fn sparse_lu_cpu( } /// Sparse LU factorization with metrics (CPU) -pub fn sparse_lu_cpu_with_metrics( +pub fn sparse_lu_cpu_with_metrics>( a: &CscData, symbolic: &LuSymbolic, options: &LuOptions, @@ -101,7 +101,7 @@ pub fn sparse_lu_cpu_with_metrics( /// - Matrix dimensions don't match symbolic structure /// - Workspace dimension doesn't match matrix /// - Zero pivot encountered (unless diagonal shift is enabled) -pub fn sparse_lu_cpu_with_workspace( +pub fn sparse_lu_cpu_with_workspace>( a: &CscData, symbolic: &LuSymbolic, options: &LuOptions, @@ -113,7 +113,7 @@ pub fn sparse_lu_cpu_with_workspace( } /// Sparse LU factorization with workspace reuse and metrics (CPU) -pub fn sparse_lu_cpu_with_workspace_and_metrics( +pub fn sparse_lu_cpu_with_workspace_and_metrics>( a: &CscData, symbolic: &LuSymbolic, options: &LuOptions, @@ -194,7 +194,7 @@ pub fn sparse_lu_cpu_with_workspace_and_metrics( /// /// This version doesn't require full symbolic analysis from solvr. /// Fill-in is discovered dynamically, which is less efficient. -pub fn sparse_lu_simple_cpu( +pub fn sparse_lu_simple_cpu>( a: &CscData, options: &LuOptions, ) -> Result> { @@ -237,7 +237,10 @@ pub fn sparse_lu_simple_cpu( /// Solve Ax = b using precomputed LU factors (CPU) /// /// Solves by: x = U⁻¹ L⁻¹ P b -pub fn sparse_lu_solve_cpu(factors: &LuFactors, b: &Tensor) -> Result> { +pub fn sparse_lu_solve_cpu>( + factors: &LuFactors, + b: &Tensor, +) -> Result> { let n = factors.row_perm.len(); let b_shape = b.shape(); @@ -853,7 +856,7 @@ fn dfs_reach( // ============================================================================ /// Extract values as f64 from CSC matrix -fn extract_values_f64(a: &CscData) -> Result> { +fn extract_values_f64>(a: &CscData) -> Result> { let dtype = a.values().dtype(); match dtype { DType::F32 => Ok(a @@ -871,7 +874,7 @@ fn extract_values_f64(a: &CscData) -> Result> { } /// Extract values as f64 from tensor -fn extract_values_f64_tensor(t: &Tensor) -> Result> { +fn extract_values_f64_tensor>(t: &Tensor) -> Result> { let dtype = t.dtype(); match dtype { DType::F32 => Ok(t.to_vec::().iter().map(|&x| x as f64).collect()), @@ -884,7 +887,7 @@ fn extract_values_f64_tensor(t: &Tensor) -> Result> { } /// Create L and U tensors from computed values -fn create_lu_tensors( +fn create_lu_tensors>( n: usize, l_col_ptrs: &[i64], l_row_indices: &[i64], diff --git a/src/algorithm/sparse_linalg/lu/cuda/lu.rs b/src/algorithm/sparse_linalg/lu/cuda/lu.rs index ec4905d7..41c4189a 100644 --- a/src/algorithm/sparse_linalg/lu/cuda/lu.rs +++ b/src/algorithm/sparse_linalg/lu/cuda/lu.rs @@ -212,13 +212,13 @@ fn run_factorization_f32( let device_index = client.device.index; // Base GPU pointers - let a_values_ptr = a_values_gpu.storage().ptr(); - let a_row_indices_ptr = a_row_indices_gpu.storage().ptr(); - let l_values_ptr = l_values_gpu.storage().ptr(); - let l_row_indices_ptr = l_row_indices_gpu.storage().ptr(); - let u_values_ptr = u_values_gpu.storage().ptr(); - let u_row_indices_ptr = u_row_indices_gpu.storage().ptr(); - let work_ptr = work_gpu.storage().ptr(); + let a_values_ptr = a_values_gpu.ptr(); + let a_row_indices_ptr = a_row_indices_gpu.ptr(); + let l_values_ptr = l_values_gpu.ptr(); + let l_row_indices_ptr = l_row_indices_gpu.ptr(); + let u_values_ptr = u_values_gpu.ptr(); + let u_row_indices_ptr = u_row_indices_gpu.ptr(); + let work_ptr = work_gpu.ptr(); let elem_size = std::mem::size_of::() as u64; let idx_size = std::mem::size_of::() as u64; @@ -417,13 +417,13 @@ fn run_factorization_f64( let device_index = client.device.index; // Base GPU pointers - let a_values_ptr = a_values_gpu.storage().ptr(); - let a_row_indices_ptr = a_row_indices_gpu.storage().ptr(); - let l_values_ptr = l_values_gpu.storage().ptr(); - let l_row_indices_ptr = l_row_indices_gpu.storage().ptr(); - let u_values_ptr = u_values_gpu.storage().ptr(); - let u_row_indices_ptr = u_row_indices_gpu.storage().ptr(); - let work_ptr = work_gpu.storage().ptr(); + let a_values_ptr = a_values_gpu.ptr(); + let a_row_indices_ptr = a_row_indices_gpu.ptr(); + let l_values_ptr = l_values_gpu.ptr(); + let l_row_indices_ptr = l_row_indices_gpu.ptr(); + let u_values_ptr = u_values_gpu.ptr(); + let u_row_indices_ptr = u_row_indices_gpu.ptr(); + let work_ptr = work_gpu.ptr(); let elem_size = std::mem::size_of::() as u64; let idx_size = std::mem::size_of::() as u64; @@ -776,9 +776,9 @@ pub fn sparse_lu_solve_cuda( context, stream, device_index, - l_col_ptrs_gpu.storage().ptr(), - l_row_indices_gpu.storage().ptr(), - l_diag_ptr_gpu.storage().ptr(), + l_col_ptrs_gpu.ptr(), + l_row_indices_gpu.ptr(), + l_diag_ptr_gpu.ptr(), n as i32, )?; @@ -786,9 +786,9 @@ pub fn sparse_lu_solve_cuda( context, stream, device_index, - u_col_ptrs_gpu.storage().ptr(), - u_row_indices_gpu.storage().ptr(), - u_diag_ptr_gpu.storage().ptr(), + u_col_ptrs_gpu.ptr(), + u_row_indices_gpu.ptr(), + u_diag_ptr_gpu.ptr(), n as i32, )?; } @@ -806,9 +806,9 @@ pub fn sparse_lu_solve_cuda( context, stream, device_index, - b.storage().ptr(), - row_perm_gpu.storage().ptr(), - y_gpu.storage().ptr(), + b.ptr(), + row_perm_gpu.ptr(), + y_gpu.ptr(), n as i32, )?; }, @@ -817,9 +817,9 @@ pub fn sparse_lu_solve_cuda( context, stream, device_index, - b.storage().ptr(), - row_perm_gpu.storage().ptr(), - y_gpu.storage().ptr(), + b.ptr(), + row_perm_gpu.ptr(), + y_gpu.ptr(), n as i32, )?; }, @@ -840,8 +840,8 @@ pub fn sparse_lu_solve_cuda( continue; } - let level_cols_ptr = l_level_cols_gpu.storage().ptr() - + (level_start as u64) * std::mem::size_of::() as u64; + let level_cols_ptr = + l_level_cols_gpu.ptr() + (level_start as u64) * std::mem::size_of::() as u64; match dtype { DType::F32 => unsafe { @@ -851,11 +851,11 @@ pub fn sparse_lu_solve_cuda( device_index, level_cols_ptr, level_size, - l_col_ptrs_gpu.storage().ptr(), - l_row_indices_gpu.storage().ptr(), - factors.l.values().storage().ptr(), - l_diag_ptr_gpu.storage().ptr(), - y_gpu.storage().ptr(), + l_col_ptrs_gpu.ptr(), + l_row_indices_gpu.ptr(), + factors.l.values().ptr(), + l_diag_ptr_gpu.ptr(), + y_gpu.ptr(), n as i32, true, // L has unit diagonal for LU )?; @@ -867,11 +867,11 @@ pub fn sparse_lu_solve_cuda( device_index, level_cols_ptr, level_size, - l_col_ptrs_gpu.storage().ptr(), - l_row_indices_gpu.storage().ptr(), - factors.l.values().storage().ptr(), - l_diag_ptr_gpu.storage().ptr(), - y_gpu.storage().ptr(), + l_col_ptrs_gpu.ptr(), + l_row_indices_gpu.ptr(), + factors.l.values().ptr(), + l_diag_ptr_gpu.ptr(), + y_gpu.ptr(), n as i32, true, // L has unit diagonal for LU )?; @@ -894,8 +894,8 @@ pub fn sparse_lu_solve_cuda( continue; } - let level_cols_ptr = u_level_cols_gpu.storage().ptr() - + (level_start as u64) * std::mem::size_of::() as u64; + let level_cols_ptr = + u_level_cols_gpu.ptr() + (level_start as u64) * std::mem::size_of::() as u64; match dtype { DType::F32 => unsafe { @@ -905,11 +905,11 @@ pub fn sparse_lu_solve_cuda( device_index, level_cols_ptr, level_size, - u_col_ptrs_gpu.storage().ptr(), - u_row_indices_gpu.storage().ptr(), - factors.u.values().storage().ptr(), - u_diag_ptr_gpu.storage().ptr(), - y_gpu.storage().ptr(), + u_col_ptrs_gpu.ptr(), + u_row_indices_gpu.ptr(), + factors.u.values().ptr(), + u_diag_ptr_gpu.ptr(), + y_gpu.ptr(), n as i32, )?; }, @@ -920,11 +920,11 @@ pub fn sparse_lu_solve_cuda( device_index, level_cols_ptr, level_size, - u_col_ptrs_gpu.storage().ptr(), - u_row_indices_gpu.storage().ptr(), - factors.u.values().storage().ptr(), - u_diag_ptr_gpu.storage().ptr(), - y_gpu.storage().ptr(), + u_col_ptrs_gpu.ptr(), + u_row_indices_gpu.ptr(), + factors.u.values().ptr(), + u_diag_ptr_gpu.ptr(), + y_gpu.ptr(), n as i32, )?; }, diff --git a/src/algorithm/sparse_linalg/lu/wgpu/lu.rs b/src/algorithm/sparse_linalg/lu/wgpu/lu.rs index b7190062..fff38d0b 100644 --- a/src/algorithm/sparse_linalg/lu/wgpu/lu.rs +++ b/src/algorithm/sparse_linalg/lu/wgpu/lu.rs @@ -209,19 +209,19 @@ fn run_factorization_f32( let wgpu_device = &client.wgpu_device; // Get buffer references - let a_values_buf = get_buffer(a_values_gpu.storage().ptr()) + let a_values_buf = get_buffer(a_values_gpu.ptr()) .ok_or_else(|| Error::Internal("Invalid A values buffer".to_string()))?; - let a_row_indices_buf = get_buffer(a_row_indices_gpu.storage().ptr()) + let a_row_indices_buf = get_buffer(a_row_indices_gpu.ptr()) .ok_or_else(|| Error::Internal("Invalid A row_indices buffer".to_string()))?; - let l_values_buf = get_buffer(l_values_gpu.storage().ptr()) + let l_values_buf = get_buffer(l_values_gpu.ptr()) .ok_or_else(|| Error::Internal("Invalid L values buffer".to_string()))?; - let l_row_indices_buf = get_buffer(l_row_indices_gpu.storage().ptr()) + let l_row_indices_buf = get_buffer(l_row_indices_gpu.ptr()) .ok_or_else(|| Error::Internal("Invalid L row_indices buffer".to_string()))?; - let u_values_buf = get_buffer(u_values_gpu.storage().ptr()) + let u_values_buf = get_buffer(u_values_gpu.ptr()) .ok_or_else(|| Error::Internal("Invalid U values buffer".to_string()))?; - let u_row_indices_buf = get_buffer(u_row_indices_gpu.storage().ptr()) + let u_row_indices_buf = get_buffer(u_row_indices_gpu.ptr()) .ok_or_else(|| Error::Internal("Invalid U row_indices buffer".to_string()))?; - let work_buf = get_buffer(work_gpu.storage().ptr()) + let work_buf = get_buffer(work_gpu.ptr()) .ok_or_else(|| Error::Internal("Invalid work buffer".to_string()))?; // Create reusable uniform buffers for parameters @@ -802,31 +802,31 @@ pub fn sparse_lu_solve_wgpu( Tensor::::zeros(&[n], DType::I32, &device); // Get buffer references - let l_col_ptrs_buf = get_buffer(l_col_ptrs_gpu.storage().ptr()) + let l_col_ptrs_buf = get_buffer(l_col_ptrs_gpu.ptr()) .ok_or_else(|| Error::Internal("Invalid L col_ptrs buffer".to_string()))?; - let l_row_indices_buf = get_buffer(l_row_indices_gpu.storage().ptr()) + let l_row_indices_buf = get_buffer(l_row_indices_gpu.ptr()) .ok_or_else(|| Error::Internal("Invalid L row_indices buffer".to_string()))?; - let l_values_buf = get_buffer(factors.l.values().storage().ptr()) + let l_values_buf = get_buffer(factors.l.values().ptr()) .ok_or_else(|| Error::Internal("Invalid L values buffer".to_string()))?; - let l_diag_ptr_buf = get_buffer(l_diag_ptr_gpu.storage().ptr()) + let l_diag_ptr_buf = get_buffer(l_diag_ptr_gpu.ptr()) .ok_or_else(|| Error::Internal("Invalid L diag_ptr buffer".to_string()))?; - let l_level_cols_buf = get_buffer(l_level_cols_gpu.storage().ptr()) + let l_level_cols_buf = get_buffer(l_level_cols_gpu.ptr()) .ok_or_else(|| Error::Internal("Invalid L level_cols buffer".to_string()))?; - let u_col_ptrs_buf = get_buffer(u_col_ptrs_gpu.storage().ptr()) + let u_col_ptrs_buf = get_buffer(u_col_ptrs_gpu.ptr()) .ok_or_else(|| Error::Internal("Invalid U col_ptrs buffer".to_string()))?; - let u_row_indices_buf = get_buffer(u_row_indices_gpu.storage().ptr()) + let u_row_indices_buf = get_buffer(u_row_indices_gpu.ptr()) .ok_or_else(|| Error::Internal("Invalid U row_indices buffer".to_string()))?; - let u_values_buf = get_buffer(factors.u.values().storage().ptr()) + let u_values_buf = get_buffer(factors.u.values().ptr()) .ok_or_else(|| Error::Internal("Invalid U values buffer".to_string()))?; - let u_diag_ptr_buf = get_buffer(u_diag_ptr_gpu.storage().ptr()) + let u_diag_ptr_buf = get_buffer(u_diag_ptr_gpu.ptr()) .ok_or_else(|| Error::Internal("Invalid U diag_ptr buffer".to_string()))?; - let u_level_cols_buf = get_buffer(u_level_cols_gpu.storage().ptr()) + let u_level_cols_buf = get_buffer(u_level_cols_gpu.ptr()) .ok_or_else(|| Error::Internal("Invalid U level_cols buffer".to_string()))?; - let b_buf = get_buffer(b.storage().ptr()) - .ok_or_else(|| Error::Internal("Invalid b buffer".to_string()))?; - let row_perm_buf = get_buffer(row_perm_gpu.storage().ptr()) + let b_buf = + get_buffer(b.ptr()).ok_or_else(|| Error::Internal("Invalid b buffer".to_string()))?; + let row_perm_buf = get_buffer(row_perm_gpu.ptr()) .ok_or_else(|| Error::Internal("Invalid row_perm buffer".to_string()))?; // Load shader @@ -954,8 +954,8 @@ pub fn sparse_lu_solve_wgpu( // ========================================================================== let y_gpu: Tensor = Tensor::::zeros(&[n], dtype, &device); - let y_buf = get_buffer(y_gpu.storage().ptr()) - .ok_or_else(|| Error::Internal("Invalid y buffer".to_string()))?; + let y_buf = + get_buffer(y_gpu.ptr()).ok_or_else(|| Error::Internal("Invalid y buffer".to_string()))?; let perm_module = cache.get_or_create_module_from_source("sparse_apply_perm", shader_source); let perm_layout = cache.get_or_create_layout(LayoutKey { diff --git a/src/algorithm/sparse_linalg/mod.rs b/src/algorithm/sparse_linalg/mod.rs index 67a17160..3d575955 100644 --- a/src/algorithm/sparse_linalg/mod.rs +++ b/src/algorithm/sparse_linalg/mod.rs @@ -53,6 +53,7 @@ pub mod levels; pub mod lu; pub mod matching; pub mod ordering; +pub mod qr; pub mod symbolic; pub mod traits; pub mod types; @@ -93,3 +94,18 @@ pub use ordering::{ColamdOptions, ColamdStats, SparseOrdering, colamd}; // Re-export matching algorithms pub use matching::{BipartiteMatching, MatchingResult, hopcroft_karp, maximum_transversal}; + +// Re-export sparse QR types and functions +pub use qr::{ + QrFactors, QrMetrics, QrOptions, QrOrdering, QrSymbolic, sparse_qr_cpu, + sparse_qr_cpu_with_metrics, sparse_qr_least_squares_cpu, sparse_qr_simple_cpu, + sparse_qr_solve_cpu, sparse_qr_symbolic, +}; + +// Re-export CUDA QR implementations +#[cfg(feature = "cuda")] +pub use qr::{sparse_qr_cuda, sparse_qr_simple_cuda, sparse_qr_solve_cuda}; + +// Re-export WebGPU QR implementations +#[cfg(feature = "wgpu")] +pub use qr::{sparse_qr_simple_wgpu, sparse_qr_solve_wgpu, sparse_qr_wgpu}; diff --git a/src/algorithm/sparse_linalg/qr/cpu/algorithm.rs b/src/algorithm/sparse_linalg/qr/cpu/algorithm.rs new file mode 100644 index 00000000..83da479e --- /dev/null +++ b/src/algorithm/sparse_linalg/qr/cpu/algorithm.rs @@ -0,0 +1,340 @@ +//! Core Householder QR algorithm for sparse matrices +//! +//! Column-wise left-looking Householder QR with rank detection. + +use crate::algorithm::sparse_linalg::qr::types::QrOptions; +use crate::error::{Error, Result}; + +/// Internal result from numeric QR factorization +pub(crate) struct QrNumericResult { + pub householder_vectors: Vec<(Vec, Vec)>, + pub tau: Vec, + pub r_col_ptrs: Vec, + pub r_row_indices: Vec, + pub r_values: Vec, + pub rank: usize, +} + +/// Column-wise left-looking Householder QR factorization +/// +/// Processes one column at a time: +/// 1. Apply COLAMD permutation to get A*P +/// 2. For each column k: +/// a. Scatter A*P column k into dense work vector +/// b. Apply previous Householder reflectors to the column +/// c. Compute new Householder reflector from column below diagonal +/// d. Store R entries (above diagonal) and reflector +/// 3. Detect rank from R diagonal +pub(crate) fn householder_qr( + m: usize, + n: usize, + col_ptrs: &[i64], + row_indices: &[i64], + values: &[f64], + col_perm: &[usize], + options: &QrOptions, +) -> Result { + let min_mn = m.min(n); + + let mut householder_vectors: Vec<(Vec, Vec)> = Vec::with_capacity(min_mn); + let mut tau_vec: Vec = Vec::with_capacity(min_mn); + + // R stored column by column (dynamically built) + let mut r_col_ptrs: Vec = vec![0i64; n + 1]; + let mut r_row_indices: Vec = Vec::new(); + let mut r_values: Vec = Vec::new(); + + let mut rank = min_mn; + + // Dense work vector for current column + let mut work = vec![0.0f64; m]; + + for k in 0..min_mn { + // Step 1: Scatter permuted column k into work vector + let orig_col = col_perm[k]; + let start = col_ptrs[orig_col] as usize; + let end = col_ptrs[orig_col + 1] as usize; + + work.fill(0.0); + for idx in start..end { + let row = row_indices[idx] as usize; + work[row] = values[idx]; + } + + // Step 2: Apply previous Householder reflectors Q_0..Q_{k-1} to this column + apply_reflectors(&householder_vectors, &tau_vec, &mut work, k); + + // Step 3: Extract R entries for column k (rows 0..k) + for row in 0..k { + if work[row].abs() > 1e-15 { + r_row_indices.push(row as i64); + r_values.push(work[row]); + } + } + + // Step 4: Compute Householder reflector for work[k..m] + let (v_indices, v_values, tau, diag_val) = compute_householder(&work, k, m); + + // Store R diagonal entry + r_row_indices.push(k as i64); + r_values.push(diag_val); + + r_col_ptrs[k + 1] = r_row_indices.len() as i64; + + // Check rank + if diag_val.abs() < options.rank_tolerance { + rank = k; + householder_vectors.push((v_indices, v_values)); + tau_vec.push(tau); + + process_remaining_columns( + k + 1, + min_mn, + n, + col_ptrs, + row_indices, + values, + col_perm, + &mut householder_vectors, + &mut tau_vec, + &mut work, + &mut r_col_ptrs, + &mut r_row_indices, + &mut r_values, + ); + + return Ok(QrNumericResult { + householder_vectors, + tau: tau_vec, + r_col_ptrs, + r_row_indices, + r_values, + rank, + }); + } + + // Store reflector + householder_vectors.push((v_indices, v_values)); + tau_vec.push(tau); + } + + // Fill remaining R col_ptrs for columns beyond min_mn (if n > m, they're empty) + for kk in min_mn..n { + r_col_ptrs[kk + 1] = r_col_ptrs[min_mn]; + } + + Ok(QrNumericResult { + householder_vectors, + tau: tau_vec, + r_col_ptrs, + r_row_indices, + r_values, + rank, + }) +} + +/// Apply Householder reflectors 0..count to a work vector +fn apply_reflectors( + householder_vectors: &[(Vec, Vec)], + tau_vec: &[f64], + work: &mut [f64], + count: usize, +) { + for j in 0..count { + let (ref v_indices, ref v_values) = householder_vectors[j]; + let tau_j = tau_vec[j]; + + let mut dot = 0.0; + for (idx, &vi) in v_indices.iter().zip(v_values.iter()) { + dot += vi * work[*idx as usize]; + } + + let scale = tau_j * dot; + for (idx, &vi) in v_indices.iter().zip(v_values.iter()) { + work[*idx as usize] -= scale * vi; + } + } +} + +/// Process remaining columns after rank deficiency is detected +#[allow(clippy::too_many_arguments)] +fn process_remaining_columns( + start_col: usize, + min_mn: usize, + n: usize, + col_ptrs: &[i64], + row_indices: &[i64], + values: &[f64], + col_perm: &[usize], + householder_vectors: &mut Vec<(Vec, Vec)>, + tau_vec: &mut Vec, + work: &mut [f64], + r_col_ptrs: &mut [i64], + r_row_indices: &mut Vec, + r_values: &mut Vec, +) { + let m = work.len(); + + for kk in start_col..min_mn { + let orig_col2 = col_perm[kk]; + let start2 = col_ptrs[orig_col2] as usize; + let end2 = col_ptrs[orig_col2 + 1] as usize; + + work.fill(0.0); + for idx in start2..end2 { + let row = row_indices[idx] as usize; + work[row] = values[idx]; + } + + // Apply all previous reflectors (including newly added ones) + apply_reflectors( + householder_vectors, + tau_vec, + work, + householder_vectors.len(), + ); + + // Store R column + for row in 0..=kk { + if work[row].abs() > 1e-15 || row == kk { + r_row_indices.push(row as i64); + r_values.push(work[row]); + } + } + r_col_ptrs[kk + 1] = r_row_indices.len() as i64; + + // Compute and store reflector for this column + let (vi, vv, t, _dv) = compute_householder(work, kk, m); + householder_vectors.push((vi, vv)); + tau_vec.push(t); + } + + // Fill remaining R col_ptrs + for kk in min_mn..n { + r_col_ptrs[kk + 1] = r_col_ptrs[kk]; + } +} + +/// Compute Householder reflector for x = work[start..m] +/// +/// Returns: (v_row_indices, v_values, tau, diagonal_value) +/// +/// The reflector satisfies: (I - tau * v * v^T) * x = ||x|| * e_1 +pub(crate) fn compute_householder( + work: &[f64], + start: usize, + m: usize, +) -> (Vec, Vec, f64, f64) { + // Compute norm of x = work[start..m] + let mut norm_sq = 0.0; + for i in start..m { + norm_sq += work[i] * work[i]; + } + let norm = norm_sq.sqrt(); + + if norm < 1e-30 { + // Zero column — no reflector needed + return (vec![start as i64], vec![1.0], 0.0, 0.0); + } + + // Choose sign to avoid cancellation: sigma = -sign(x[start]) * ||x|| + let sigma = if work[start] >= 0.0 { -norm } else { norm }; + let diag_val = sigma; // R[start, start] = sigma + + let v_start = work[start] - sigma; + + // Normalize v so that v[start] = 1 + if v_start.abs() < 1e-30 { + return (vec![start as i64], vec![1.0], 0.0, diag_val); + } + + let inv_v_start = 1.0 / v_start; + + let mut v_indices = Vec::new(); + let mut v_values = Vec::new(); + + v_indices.push(start as i64); + v_values.push(1.0); // v[start] = 1 (normalized) + + for i in (start + 1)..m { + if work[i].abs() > 1e-15 { + v_indices.push(i as i64); + v_values.push(work[i] * inv_v_start); + } + } + + // tau = (sigma - x[start]) / sigma = -v_start / sigma + let tau = -v_start / sigma; + + (v_indices, v_values, tau, diag_val) +} + +/// Apply Q^T to a vector by applying Householder reflectors in forward order. +/// +/// Q^T * b is computed as: for j = 0, 1, ..., k-1: b = (I - tau_j * v_j * v_j^T) * b +pub(crate) fn apply_qt(householder_vectors: &[(Vec, Vec)], tau: &[f64], b: &mut [f64]) { + for j in 0..householder_vectors.len() { + let (ref v_indices, ref v_values) = householder_vectors[j]; + let tau_j = tau[j]; + + if tau_j == 0.0 { + continue; + } + + let mut dot = 0.0; + for (idx, &vi) in v_indices.iter().zip(v_values.iter()) { + dot += vi * b[*idx as usize]; + } + + let scale = tau_j * dot; + for (idx, &vi) in v_indices.iter().zip(v_values.iter()) { + b[*idx as usize] -= scale * vi; + } + } +} + +/// Back-substitute: solve R[0:n, 0:n] * x = rhs +/// R is in CSC format. +pub(crate) fn back_substitute( + n: usize, + r_col_ptrs: &[i64], + r_row_indices: &[i64], + r_values: &[f64], + rhs: &[f64], + x: &mut [f64], +) -> Result<()> { + x[..n].copy_from_slice(rhs); + + for col in (0..n).rev() { + let start = r_col_ptrs[col] as usize; + let end = r_col_ptrs[col + 1] as usize; + + // Find diagonal entry + let mut diag_val = 0.0; + for idx in start..end { + if r_row_indices[idx] as usize == col { + diag_val = r_values[idx]; + break; + } + } + + if diag_val.abs() < 1e-30 { + return Err(Error::Internal(format!( + "sparse_qr back_substitute: zero diagonal at column {}", + col + ))); + } + + x[col] /= diag_val; + + // Update rows above + for idx in start..end { + let row = r_row_indices[idx] as usize; + if row < col { + x[row] -= r_values[idx] * x[col]; + } + } + } + + Ok(()) +} diff --git a/src/algorithm/sparse_linalg/qr/cpu/helpers.rs b/src/algorithm/sparse_linalg/qr/cpu/helpers.rs new file mode 100644 index 00000000..77b8942d --- /dev/null +++ b/src/algorithm/sparse_linalg/qr/cpu/helpers.rs @@ -0,0 +1,155 @@ +//! Helper functions for sparse QR CPU implementation +//! +//! Data extraction and tensor creation utilities. + +use crate::dtype::DType; +use crate::error::{Error, Result}; +use crate::runtime::Runtime; +use crate::sparse::CscData; +use crate::tensor::Tensor; + +/// Extract values as f64 from CSC matrix (sparse QR requires floating-point) +pub(crate) fn extract_values_f64>(a: &CscData) -> Result> { + let dtype = a.values().dtype(); + match dtype { + DType::F32 => Ok(a + .values() + .to_vec::() + .iter() + .map(|&x| x as f64) + .collect()), + DType::F64 => Ok(a.values().to_vec()), + _ => Err(Error::UnsupportedDType { + dtype, + op: "sparse_qr", + }), + } +} + +/// Extract values as f64 from tensor (sparse QR requires floating-point) +pub(crate) fn extract_values_f64_tensor>( + t: &Tensor, +) -> Result> { + let dtype = t.dtype(); + match dtype { + DType::F32 => Ok(t.to_vec::().iter().map(|&x| x as f64).collect()), + DType::F64 => Ok(t.to_vec()), + _ => Err(Error::UnsupportedDType { + dtype, + op: "sparse_qr", + }), + } +} + +/// Create R tensor in CSC format +pub(crate) fn create_r_tensor>( + m: usize, + n: usize, + r_col_ptrs: &[i64], + r_row_indices: &[i64], + r_values: &[f64], + dtype: DType, + device: &R::Device, +) -> Result> { + match dtype { + DType::F32 => { + let vals_f32: Vec = r_values.iter().map(|&x| x as f32).collect(); + CscData::::from_slices(r_col_ptrs, r_row_indices, &vals_f32, [m, n], device) + } + DType::F64 => { + CscData::::from_slices(r_col_ptrs, r_row_indices, r_values, [m, n], device) + } + _ => Err(Error::UnsupportedDType { + dtype, + op: "sparse_qr", + }), + } +} + +/// Create a vector tensor from f64 data +pub(crate) fn create_vector_tensor>( + data: &[f64], + dtype: DType, + device: &R::Device, +) -> Result> { + let n = data.len(); + match dtype { + DType::F32 => { + let data_f32: Vec = data.iter().map(|&x| x as f32).collect(); + Ok(Tensor::::from_slice(&data_f32, &[n], device)) + } + DType::F64 => Ok(Tensor::::from_slice(data, &[n], device)), + _ => Err(Error::UnsupportedDType { + dtype, + op: "sparse_qr", + }), + } +} + +/// Compute dense Householder vector offset for reflector k in a flat buffer. +/// +/// Reflector k has length (m - k), stored at offset `k*m - k*(k-1)/2`. +/// This packs variable-length vectors contiguously: reflector 0 at offset 0 +/// with length m, reflector 1 at offset m with length m-1, etc. +#[cfg(any(feature = "cuda", feature = "wgpu"))] +pub(crate) fn h_offset(k: usize, m: usize) -> usize { + k * m - k * (k.wrapping_sub(1)) / 2 +} + +/// Compute R off-diagonal offset for column k in a flat buffer. +/// +/// Column k has k off-diagonal entries, stored at offset `k*(k-1)/2`. +#[cfg(any(feature = "cuda", feature = "wgpu"))] +pub(crate) fn r_offdiag_offset(k: usize) -> usize { + k * (k.wrapping_sub(1)) / 2 +} + +/// Build R factor in CSC format from flat off-diagonal and diagonal buffers. +/// +/// Off-diagonal entries for column k are stored at `r_offdiag_offset(k)` with +/// k entries. Diagonal entries are in a separate `diag` array. Near-zero +/// off-diagonal entries are dropped. +#[cfg(any(feature = "cuda", feature = "wgpu"))] +pub(crate) fn build_r_csc( + r_offdiag: &[f64], + diag: &[f64], + min_mn: usize, + n: usize, +) -> (Vec, Vec, Vec) { + let mut r_col_ptrs = vec![0i64; n + 1]; + let mut r_row_indices: Vec = Vec::new(); + let mut r_values: Vec = Vec::new(); + + for k in 0..min_mn { + let ro = r_offdiag_offset(k); + for row in 0..k { + let val = r_offdiag[ro + row]; + if val.abs() > 1e-15 { + r_row_indices.push(row as i64); + r_values.push(val); + } + } + r_row_indices.push(k as i64); + r_values.push(diag[k]); + r_col_ptrs[k + 1] = r_row_indices.len() as i64; + } + for k in min_mn..n { + r_col_ptrs[k + 1] = r_col_ptrs[min_mn]; + } + + (r_col_ptrs, r_row_indices, r_values) +} + +/// Detect numerical rank from R diagonal entries. +/// +/// Returns the index of the first diagonal entry whose absolute value is +/// below `rank_tolerance`, or `min_mn` if all entries are above tolerance. +#[cfg(any(feature = "cuda", feature = "wgpu"))] +pub(crate) fn detect_rank(diag: &[f64], min_mn: usize, rank_tolerance: f64) -> usize { + for k in 0..min_mn { + if diag[k].abs() < rank_tolerance { + return k; + } + } + min_mn +} diff --git a/src/algorithm/sparse_linalg/qr/cpu/mod.rs b/src/algorithm/sparse_linalg/qr/cpu/mod.rs new file mode 100644 index 00000000..49d1621d --- /dev/null +++ b/src/algorithm/sparse_linalg/qr/cpu/mod.rs @@ -0,0 +1,12 @@ +//! CPU implementation of sparse QR factorization +//! +//! Householder QR with COLAMD column ordering. + +pub(crate) mod algorithm; +pub(crate) mod helpers; +mod qr; + +pub use qr::{ + sparse_qr_cpu, sparse_qr_cpu_with_metrics, sparse_qr_least_squares_cpu, sparse_qr_simple_cpu, + sparse_qr_solve_cpu, +}; diff --git a/src/algorithm/sparse_linalg/qr/cpu/qr.rs b/src/algorithm/sparse_linalg/qr/cpu/qr.rs new file mode 100644 index 00000000..34eb103b --- /dev/null +++ b/src/algorithm/sparse_linalg/qr/cpu/qr.rs @@ -0,0 +1,443 @@ +//! CPU implementation of sparse Householder QR factorization +//! +//! Column-wise left-looking Householder QR with partial pivoting (rank detection). + +use crate::algorithm::sparse_linalg::qr::symbolic::sparse_qr_symbolic; +use crate::algorithm::sparse_linalg::qr::types::{QrFactors, QrMetrics, QrOptions, QrSymbolic}; +use crate::dtype::DType; +use crate::error::{Error, Result}; +use crate::runtime::Runtime; +use crate::sparse::CscData; +use crate::tensor::Tensor; + +use super::algorithm::{apply_qt, back_substitute, householder_qr}; +use super::helpers::{ + create_r_tensor, create_vector_tensor, extract_values_f64, extract_values_f64_tensor, +}; + +/// Sparse QR factorization with precomputed symbolic information (CPU) +pub fn sparse_qr_cpu>( + a: &CscData, + symbolic: &QrSymbolic, + options: &QrOptions, +) -> Result> { + let (factors, _metrics) = sparse_qr_cpu_with_metrics(a, symbolic, options)?; + Ok(factors) +} + +/// Sparse QR factorization with metrics (CPU) +pub fn sparse_qr_cpu_with_metrics>( + a: &CscData, + symbolic: &QrSymbolic, + options: &QrOptions, +) -> Result<(QrFactors, QrMetrics)> { + let [m, n] = a.shape; + + if m != symbolic.m || n != symbolic.n { + return Err(Error::ShapeMismatch { + expected: vec![symbolic.m, symbolic.n], + got: vec![m, n], + }); + } + + if m < n { + return Err(Error::Internal( + "sparse_qr: requires m >= n (more rows than columns)".to_string(), + )); + } + + let col_ptrs: Vec = a.col_ptrs().to_vec(); + let row_indices: Vec = a.row_indices().to_vec(); + let values = extract_values_f64(a)?; + + let result = householder_qr( + m, + n, + &col_ptrs, + &row_indices, + &values, + &symbolic.col_perm, + options, + )?; + + let device = a.values().device(); + let dtype = a.values().dtype(); + + let r = create_r_tensor::( + m, + n, + &result.r_col_ptrs, + &result.r_row_indices, + &result.r_values, + dtype, + device, + )?; + + let original_nnz = values.len(); + let r_nnz = result.r_values.len(); + + let factors = QrFactors { + householder_vectors: result.householder_vectors, + tau: result.tau, + r, + col_perm: symbolic.col_perm.clone(), + rank: result.rank, + gpu_householder_values: None, + gpu_tau: None, + }; + + let metrics = QrMetrics { + original_nnz, + r_nnz, + fill_ratio: if original_nnz > 0 { + r_nnz as f64 / original_nnz as f64 + } else { + 0.0 + }, + numerical_rank: result.rank, + }; + + Ok((factors, metrics)) +} + +/// Sparse QR factorization without precomputed symbolic information (CPU) +pub fn sparse_qr_simple_cpu>( + a: &CscData, + options: &QrOptions, +) -> Result> { + let [m, n] = a.shape; + let col_ptrs: Vec = a.col_ptrs().to_vec(); + let row_indices: Vec = a.row_indices().to_vec(); + + let symbolic = sparse_qr_symbolic(&col_ptrs, &row_indices, m, n, options)?; + sparse_qr_cpu(a, &symbolic, options) +} + +/// Solve A*x = b using precomputed QR factors (square full-rank systems) +/// +/// Computes x = P * R^{-1} * Q^T * b +pub fn sparse_qr_solve_cpu>( + factors: &QrFactors, + b: &Tensor, +) -> Result> { + let [m, n] = factors.r.shape; + let b_shape = b.shape(); + + if b_shape.is_empty() || b_shape[0] != m { + return Err(Error::ShapeMismatch { + expected: vec![m], + got: b_shape.to_vec(), + }); + } + + if factors.rank < n { + return Err(Error::Internal(format!( + "sparse_qr_solve: matrix is rank-deficient (rank {} < n {})", + factors.rank, n + ))); + } + + let b_vals = extract_values_f64_tensor(b)?; + + // Step 1: Compute Q^T * b by applying Householder reflectors + let mut qtb = b_vals; + apply_qt(&factors.householder_vectors, &factors.tau, &mut qtb); + + // Step 2: Back-substitute R * x = (Q^T * b)[0:n] + let r_col_ptrs: Vec = factors.r.col_ptrs().to_vec(); + let r_row_indices: Vec = factors.r.row_indices().to_vec(); + let r_values = extract_values_f64(&factors.r)?; + + let mut x = vec![0.0f64; n]; + back_substitute(n, &r_col_ptrs, &r_row_indices, &r_values, &qtb[..n], &mut x)?; + + // Step 3: Apply column permutation: x_orig[col_perm[k]] = x[k] + let mut x_perm = vec![0.0f64; n]; + for (k, &orig_col) in factors.col_perm.iter().enumerate() { + x_perm[orig_col] = x[k]; + } + + create_vector_tensor::(&x_perm, b.dtype(), b.device()) +} + +/// Solve least-squares min ||A*x - b||_2 using QR factors (overdetermined systems) +/// +/// For m > n: x = P * R[0:n, 0:n]^{-1} * (Q^T * b)[0:n] +pub fn sparse_qr_least_squares_cpu>( + factors: &QrFactors, + b: &Tensor, +) -> Result> { + let [m, n] = factors.r.shape; + let b_shape = b.shape(); + + if b_shape.is_empty() || b_shape[0] != m { + return Err(Error::ShapeMismatch { + expected: vec![m], + got: b_shape.to_vec(), + }); + } + + if factors.rank == 0 { + return Err(Error::Internal( + "sparse_qr_least_squares: matrix has zero rank".to_string(), + )); + } + + let b_vals = extract_values_f64_tensor(b)?; + + // Step 1: Compute Q^T * b + let mut qtb = b_vals; + apply_qt(&factors.householder_vectors, &factors.tau, &mut qtb); + + // Step 2: Back-substitute R[0:rank, 0:rank] * x = (Q^T * b)[0:rank] + let r_col_ptrs: Vec = factors.r.col_ptrs().to_vec(); + let r_row_indices: Vec = factors.r.row_indices().to_vec(); + let r_values = extract_values_f64(&factors.r)?; + + let rank = factors.rank; + let mut x = vec![0.0f64; n]; + back_substitute( + rank, + &r_col_ptrs, + &r_row_indices, + &r_values, + &qtb[..rank], + &mut x, + )?; + // Columns rank..n remain zero (minimum-norm solution) + + // Step 3: Apply column permutation + let mut x_perm = vec![0.0f64; n]; + for (k, &orig_col) in factors.col_perm.iter().enumerate() { + if k < n { + x_perm[orig_col] = x[k]; + } + } + + create_vector_tensor::(&x_perm, b.dtype(), b.device()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::runtime::cpu::CpuRuntime; + + fn cpu_device() -> ::Device { + ::Device::default() + } + + /// Create a 4x4 tridiagonal SPD matrix in CSC format + fn create_tridiagonal_4x4() -> CscData { + let col_ptrs = vec![0i64, 2, 5, 8, 10]; + let row_indices = vec![0i64, 1, 0, 1, 2, 1, 2, 3, 2, 3]; + let values = vec![4.0f64, 1.0, 1.0, 4.0, 1.0, 1.0, 4.0, 1.0, 1.0, 4.0]; + CscData::from_slices(&col_ptrs, &row_indices, &values, [4, 4], &cpu_device()).unwrap() + } + + /// Create a 5x3 overdetermined matrix in CSC format + fn create_overdetermined_5x3() -> CscData { + let col_ptrs = vec![0i64, 3, 6, 8]; + let row_indices = vec![0i64, 2, 4, 1, 3, 4, 0, 3]; + let values = vec![1.0f64, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]; + CscData::from_slices(&col_ptrs, &row_indices, &values, [5, 3], &cpu_device()).unwrap() + } + + fn verify_ax_eq_b(a_dense: &[&[f64]], x: &[f64], b: &[f64]) { + let m = a_dense.len(); + let n = x.len(); + for i in 0..m { + let mut ax_i = 0.0; + for j in 0..n { + ax_i += a_dense[i][j] * x[j]; + } + assert!( + (ax_i - b[i]).abs() < 1e-10, + "A*x[{}] = {}, expected {}", + i, + ax_i, + b[i] + ); + } + } + + #[test] + fn test_sparse_qr_simple_square() { + let a = create_tridiagonal_4x4(); + let options = QrOptions::no_ordering(); + let factors = sparse_qr_simple_cpu(&a, &options).unwrap(); + + assert_eq!(factors.rank, 4); + assert_eq!(factors.householder_vectors.len(), 4); + assert_eq!(factors.tau.len(), 4); + } + + #[test] + fn test_sparse_qr_solve_square() { + let a = create_tridiagonal_4x4(); + let options = QrOptions::no_ordering(); + let factors = sparse_qr_simple_cpu(&a, &options).unwrap(); + + let b = Tensor::::from_slice(&[1.0f64, 2.0, 3.0, 4.0], &[4], &cpu_device()); + let x = sparse_qr_solve_cpu(&factors, &b).unwrap(); + let x_vals: Vec = x.to_vec(); + + let a_dense: &[&[f64]] = &[ + &[4.0, 1.0, 0.0, 0.0], + &[1.0, 4.0, 1.0, 0.0], + &[0.0, 1.0, 4.0, 1.0], + &[0.0, 0.0, 1.0, 4.0], + ]; + verify_ax_eq_b(a_dense, &x_vals, &[1.0, 2.0, 3.0, 4.0]); + } + + #[test] + fn test_sparse_qr_overdetermined_least_squares() { + let a = create_overdetermined_5x3(); + let options = QrOptions::no_ordering(); + let factors = sparse_qr_simple_cpu(&a, &options).unwrap(); + + assert_eq!(factors.rank, 3); + + let b = + Tensor::::from_slice(&[1.0f64, 2.0, 3.0, 4.0, 5.0], &[5], &cpu_device()); + let x = sparse_qr_least_squares_cpu(&factors, &b).unwrap(); + let x_vals: Vec = x.to_vec(); + + // Verify optimality: A^T * (A*x - b) ≈ 0 + let a_dense = [ + [1.0, 0.0, 1.0], + [0.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 1.0], + [1.0, 1.0, 0.0], + ]; + let b_vals = [1.0, 2.0, 3.0, 4.0, 5.0]; + + let mut residual = vec![0.0f64; 5]; + for i in 0..5 { + for j in 0..3 { + residual[i] += a_dense[i][j] * x_vals[j]; + } + residual[i] -= b_vals[i]; + } + + for j in 0..3 { + let mut at_r = 0.0; + for i in 0..5 { + at_r += a_dense[i][j] * residual[i]; + } + assert!( + at_r.abs() < 1e-10, + "A^T * residual[{}] = {}, expected ~0", + j, + at_r + ); + } + } + + #[test] + fn test_sparse_qr_rank_deficient() { + // Rank-2 matrix (3x3) where col 2 = col 0 + col 1 + let col_ptrs = vec![0i64, 2, 4, 7]; + let row_indices = vec![0i64, 2, 1, 2, 0, 1, 2]; + let values = vec![1.0f64, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0]; + let a = CscData::::from_slices( + &col_ptrs, + &row_indices, + &values, + [3, 3], + &cpu_device(), + ) + .unwrap(); + + let options = QrOptions::no_ordering(); + let factors = sparse_qr_simple_cpu(&a, &options).unwrap(); + + assert!( + factors.rank < 3, + "Expected rank < 3, got rank = {}", + factors.rank + ); + } + + #[test] + fn test_sparse_qr_with_colamd() { + let a = create_tridiagonal_4x4(); + let options = QrOptions::default(); // Uses Colamd + let factors = sparse_qr_simple_cpu(&a, &options).unwrap(); + + assert_eq!(factors.rank, 4); + + let b = Tensor::::from_slice(&[1.0f64, 0.0, 0.0, 0.0], &[4], &cpu_device()); + let x = sparse_qr_solve_cpu(&factors, &b).unwrap(); + let x_vals: Vec = x.to_vec(); + + let a_dense: &[&[f64]] = &[ + &[4.0, 1.0, 0.0, 0.0], + &[1.0, 4.0, 1.0, 0.0], + &[0.0, 1.0, 4.0, 1.0], + &[0.0, 0.0, 1.0, 4.0], + ]; + verify_ax_eq_b(a_dense, &x_vals, &[1.0, 0.0, 0.0, 0.0]); + } + + #[test] + fn test_sparse_qr_known_diagonal() { + // 2x2 identity matrix: QR should give R = I + let col_ptrs = vec![0i64, 1, 2]; + let row_indices = vec![0i64, 1]; + let values = vec![1.0f64, 1.0]; + let a = CscData::::from_slices( + &col_ptrs, + &row_indices, + &values, + [2, 2], + &cpu_device(), + ) + .unwrap(); + + let options = QrOptions::no_ordering(); + let factors = sparse_qr_simple_cpu(&a, &options).unwrap(); + + assert_eq!(factors.rank, 2); + + // R diagonal should be ±1 + let r_values: Vec = factors.r.values().to_vec(); + let r_row_indices: Vec = factors.r.row_indices().to_vec(); + let r_col_ptrs: Vec = factors.r.col_ptrs().to_vec(); + + for col in 0..2 { + let start = r_col_ptrs[col] as usize; + let end = r_col_ptrs[col + 1] as usize; + for idx in start..end { + if r_row_indices[idx] as usize == col { + assert!( + (r_values[idx].abs() - 1.0).abs() < 1e-10, + "R[{},{}] = {}, expected ±1", + r_row_indices[idx], + col, + r_values[idx] + ); + } + } + } + } + + #[test] + fn test_sparse_qr_metrics() { + let a = create_tridiagonal_4x4(); + let options = QrOptions::no_ordering(); + + let col_ptrs: Vec = a.col_ptrs().to_vec(); + let row_indices: Vec = a.row_indices().to_vec(); + let symbolic = sparse_qr_symbolic(&col_ptrs, &row_indices, 4, 4, &options).unwrap(); + + let (factors, metrics) = sparse_qr_cpu_with_metrics(&a, &symbolic, &options).unwrap(); + + assert_eq!(metrics.original_nnz, 10); + assert_eq!(metrics.numerical_rank, 4); + assert!(metrics.r_nnz > 0); + assert!(metrics.fill_ratio > 0.0); + assert_eq!(factors.rank, 4); + } +} diff --git a/src/algorithm/sparse_linalg/qr/cuda/factorize.rs b/src/algorithm/sparse_linalg/qr/cuda/factorize.rs new file mode 100644 index 00000000..a618c4f4 --- /dev/null +++ b/src/algorithm/sparse_linalg/qr/cuda/factorize.rs @@ -0,0 +1,460 @@ +//! CUDA GPU factorization loop for sparse Householder QR +//! +//! Keeps ALL data on GPU with zero intermediate transfers: +//! 1. Structure (col_ptrs, col_perm) on CPU drives the column loop +//! 2. Matrix values and dense work buffers on GPU +//! 3. Householder vectors stored as dense sub-vectors on GPU (kept GPU-resident) +//! 4. Only R structural data (diag, off-diag) transferred to CPU for CSC construction + +use crate::algorithm::sparse_linalg::qr::cpu::helpers::{ + build_r_csc, create_r_tensor, detect_rank, h_offset, +}; +use crate::algorithm::sparse_linalg::qr::types::{QrFactors, QrOptions, QrSymbolic}; +use crate::error::{Error, Result}; +use crate::runtime::cuda::kernels::{ + launch_sparse_qr_apply_reflector_f32, launch_sparse_qr_apply_reflector_f64, + launch_sparse_qr_clear_f32, launch_sparse_qr_clear_f64, launch_sparse_qr_extract_r_f32, + launch_sparse_qr_extract_r_f64, launch_sparse_qr_householder_f32, + launch_sparse_qr_householder_f64, launch_sparse_qr_norm_f32, launch_sparse_qr_norm_f64, + launch_sparse_scatter_f32, launch_sparse_scatter_f64, +}; +use crate::runtime::cuda::{CudaClient, CudaRuntime}; +use crate::sparse::CscData; +use crate::tensor::Tensor; + +/// Run the GPU factorization for a specific dtype +pub(super) fn run_factorization( + client: &CudaClient, + a: &CscData, + symbolic: &QrSymbolic, + options: &QrOptions, + m: usize, + n: usize, +) -> Result> { + let dtype = a.values().dtype(); + let min_mn = m.min(n); + let device = a.values().device(); + let col_ptrs: Vec = a.col_ptrs().to_vec(); + + // A's row_indices as i32 for CUDA kernels + let a_row_indices_i32: Vec = a + .row_indices() + .to_vec::() + .iter() + .map(|&x| x as i32) + .collect(); + let a_row_indices_gpu = + Tensor::::from_slice(&a_row_indices_i32, &[a_row_indices_i32.len()], &device); + + // Pre-compute buffer sizes + let total_h_size = if min_mn > 0 { + h_offset(min_mn - 1, m) + (m - (min_mn - 1)) + } else { + 0 + }; + let total_r_offdiag = min_mn * min_mn.saturating_sub(1) / 2; + + // Allocate GPU buffers + let work_gpu = Tensor::::zeros(&[m], dtype, &device); + let h_values_gpu = Tensor::::zeros(&[total_h_size.max(1)], dtype, &device); + let tau_gpu = Tensor::::zeros(&[min_mn.max(1)], dtype, &device); + let diag_gpu = Tensor::::zeros(&[min_mn.max(1)], dtype, &device); + let r_offdiag_gpu = Tensor::::zeros(&[total_r_offdiag.max(1)], dtype, &device); + let norm_sq_gpu = Tensor::::zeros(&[1], dtype, &device); + + let context = &client.context; + let stream = &client.stream; + let device_index = client.device.index; + + let elem_size = T::ELEM_SIZE as u64; + let idx_size = std::mem::size_of::() as u64; + + let work_ptr = work_gpu.ptr(); + let h_values_ptr = h_values_gpu.ptr(); + let tau_ptr = tau_gpu.ptr(); + let diag_ptr = diag_gpu.ptr(); + let r_offdiag_ptr = r_offdiag_gpu.ptr(); + let norm_sq_ptr = norm_sq_gpu.ptr(); + let a_values_ptr = a.values().ptr(); + let a_indices_ptr = a_row_indices_gpu.ptr(); + + for k in 0..min_mn { + // Step 1: Clear work vector + unsafe { T::launch_clear(context, stream, device_index, work_ptr, m as i32)? }; + + // Step 2: Scatter permuted column into work + let orig_col = symbolic.col_perm[k]; + let a_col_start = col_ptrs[orig_col] as usize; + let a_col_end = col_ptrs[orig_col + 1] as usize; + let a_col_nnz = a_col_end - a_col_start; + + if a_col_nnz > 0 { + let values_offset = a_values_ptr + (a_col_start as u64) * elem_size; + let indices_offset = a_indices_ptr + (a_col_start as u64) * idx_size; + + unsafe { + T::launch_scatter( + context, + stream, + device_index, + values_offset, + indices_offset, + work_ptr, + a_col_nnz as i32, + )?; + } + } + + // Step 3: Apply previous Householder reflectors + for j in 0..k { + let v_offset = h_values_ptr + (h_offset(j, m) as u64) * elem_size; + let tau_j_ptr = tau_ptr + (j as u64) * elem_size; + + unsafe { + T::launch_apply_reflector( + context, + stream, + device_index, + v_offset, + j as i32, + (m - j) as i32, + tau_j_ptr, + work_ptr, + m as i32, + )?; + } + } + + // Step 4: Extract R off-diagonal entries (work[0..k]) + if k > 0 { + let r_out = r_offdiag_ptr + + (crate::algorithm::sparse_linalg::qr::cpu::helpers::r_offdiag_offset(k) as u64) + * elem_size; + unsafe { + T::launch_extract_r(context, stream, device_index, work_ptr, k as i32, r_out)?; + } + } + + // Step 5: Compute norm ||work[k..m]||^2 + unsafe { + T::launch_norm( + context, + stream, + device_index, + work_ptr, + k as i32, + (m - k) as i32, + norm_sq_ptr, + )?; + } + + // Step 6: Compute Householder vector + let h_out = h_values_ptr + (h_offset(k, m) as u64) * elem_size; + let tau_k_ptr = tau_ptr + (k as u64) * elem_size; + let diag_k_ptr = diag_ptr + (k as u64) * elem_size; + + unsafe { + T::launch_householder( + context, + stream, + device_index, + work_ptr, + k as i32, + m as i32, + norm_sq_ptr, + h_out, + tau_k_ptr, + diag_k_ptr, + )?; + } + } + + // Synchronize + client + .stream + .synchronize() + .map_err(|e| Error::Internal(format!("CUDA stream sync failed: {:?}", e)))?; + + // Transfer ONLY R structural data (diag + off-diag) for CSC construction. + // Householder vectors and tau stay GPU-resident — no GPU→CPU transfer. + let diag_cpu = T::structural_to_f64(&diag_gpu, min_mn); + let r_offdiag_cpu = T::structural_to_f64(&r_offdiag_gpu, total_r_offdiag); + + // Build R factor on CPU (small structural data) + let (r_col_ptrs, r_row_indices, r_values) = build_r_csc(&r_offdiag_cpu, &diag_cpu, min_mn, n); + let rank = detect_rank(&diag_cpu, min_mn, options.rank_tolerance); + let r = create_r_tensor::( + m, + n, + &r_col_ptrs, + &r_row_indices, + &r_values, + dtype, + &device, + )?; + + Ok(QrFactors { + // GPU factorization keeps Householder data GPU-resident only. + // CPU sparse representation is empty; use gpu_householder_values for solve. + householder_vectors: Vec::new(), + tau: Vec::new(), + r, + col_perm: symbolic.col_perm.clone(), + rank, + gpu_householder_values: Some(h_values_gpu), + gpu_tau: Some(tau_gpu), + }) +} + +/// Trait for dtype-specific GPU kernel dispatch. +/// +/// Eliminates f32/f64 code duplication by providing a uniform interface +/// to dtype-specific CUDA kernel launchers. +pub(super) trait GpuQrScalar: Sized { + const ELEM_SIZE: usize; + + unsafe fn launch_clear( + ctx: &std::sync::Arc, + stream: &cudarc::driver::safe::CudaStream, + dev: usize, + work: u64, + n: i32, + ) -> Result<()>; + + unsafe fn launch_scatter( + ctx: &std::sync::Arc, + stream: &cudarc::driver::safe::CudaStream, + dev: usize, + values: u64, + indices: u64, + work: u64, + nnz: i32, + ) -> Result<()>; + + unsafe fn launch_apply_reflector( + ctx: &std::sync::Arc, + stream: &cudarc::driver::safe::CudaStream, + dev: usize, + v: u64, + v_start: i32, + v_len: i32, + tau_ptr: u64, + work: u64, + m: i32, + ) -> Result<()>; + + unsafe fn launch_norm( + ctx: &std::sync::Arc, + stream: &cudarc::driver::safe::CudaStream, + dev: usize, + work: u64, + start: i32, + count: i32, + result: u64, + ) -> Result<()>; + + unsafe fn launch_householder( + ctx: &std::sync::Arc, + stream: &cudarc::driver::safe::CudaStream, + dev: usize, + work: u64, + start: i32, + m: i32, + norm_sq: u64, + out_v: u64, + out_tau: u64, + out_diag: u64, + ) -> Result<()>; + + unsafe fn launch_extract_r( + ctx: &std::sync::Arc, + stream: &cudarc::driver::safe::CudaStream, + dev: usize, + work: u64, + count: i32, + output: u64, + ) -> Result<()>; + + /// Extract small structural data (diag, off-diag) as f64 for R CSC construction. + /// Only used for O(n) / O(n²) structural buffers, NOT for large data tensors. + fn structural_to_f64(tensor: &Tensor, count: usize) -> Vec; +} + +impl GpuQrScalar for f32 { + const ELEM_SIZE: usize = 4; + + unsafe fn launch_clear( + ctx: &std::sync::Arc, + stream: &cudarc::driver::safe::CudaStream, + dev: usize, + work: u64, + n: i32, + ) -> Result<()> { + unsafe { launch_sparse_qr_clear_f32(ctx, stream, dev, work, n) } + } + unsafe fn launch_scatter( + ctx: &std::sync::Arc, + stream: &cudarc::driver::safe::CudaStream, + dev: usize, + values: u64, + indices: u64, + work: u64, + nnz: i32, + ) -> Result<()> { + unsafe { launch_sparse_scatter_f32(ctx, stream, dev, values, indices, work, nnz) } + } + unsafe fn launch_apply_reflector( + ctx: &std::sync::Arc, + stream: &cudarc::driver::safe::CudaStream, + dev: usize, + v: u64, + v_start: i32, + v_len: i32, + tau_ptr: u64, + work: u64, + m: i32, + ) -> Result<()> { + unsafe { + launch_sparse_qr_apply_reflector_f32( + ctx, stream, dev, v, v_start, v_len, tau_ptr, work, m, + ) + } + } + unsafe fn launch_norm( + ctx: &std::sync::Arc, + stream: &cudarc::driver::safe::CudaStream, + dev: usize, + work: u64, + start: i32, + count: i32, + result: u64, + ) -> Result<()> { + unsafe { launch_sparse_qr_norm_f32(ctx, stream, dev, work, start, count, result) } + } + unsafe fn launch_householder( + ctx: &std::sync::Arc, + stream: &cudarc::driver::safe::CudaStream, + dev: usize, + work: u64, + start: i32, + m: i32, + norm_sq: u64, + out_v: u64, + out_tau: u64, + out_diag: u64, + ) -> Result<()> { + unsafe { + launch_sparse_qr_householder_f32( + ctx, stream, dev, work, start, m, norm_sq, out_v, out_tau, out_diag, + ) + } + } + unsafe fn launch_extract_r( + ctx: &std::sync::Arc, + stream: &cudarc::driver::safe::CudaStream, + dev: usize, + work: u64, + count: i32, + output: u64, + ) -> Result<()> { + unsafe { launch_sparse_qr_extract_r_f32(ctx, stream, dev, work, count, output) } + } + fn structural_to_f64(tensor: &Tensor, count: usize) -> Vec { + if count == 0 { + return vec![]; + } + tensor + .to_vec::() + .iter() + .take(count) + .map(|&x| x as f64) + .collect() + } +} + +impl GpuQrScalar for f64 { + const ELEM_SIZE: usize = 8; + + unsafe fn launch_clear( + ctx: &std::sync::Arc, + stream: &cudarc::driver::safe::CudaStream, + dev: usize, + work: u64, + n: i32, + ) -> Result<()> { + unsafe { launch_sparse_qr_clear_f64(ctx, stream, dev, work, n) } + } + unsafe fn launch_scatter( + ctx: &std::sync::Arc, + stream: &cudarc::driver::safe::CudaStream, + dev: usize, + values: u64, + indices: u64, + work: u64, + nnz: i32, + ) -> Result<()> { + unsafe { launch_sparse_scatter_f64(ctx, stream, dev, values, indices, work, nnz) } + } + unsafe fn launch_apply_reflector( + ctx: &std::sync::Arc, + stream: &cudarc::driver::safe::CudaStream, + dev: usize, + v: u64, + v_start: i32, + v_len: i32, + tau_ptr: u64, + work: u64, + m: i32, + ) -> Result<()> { + unsafe { + launch_sparse_qr_apply_reflector_f64( + ctx, stream, dev, v, v_start, v_len, tau_ptr, work, m, + ) + } + } + unsafe fn launch_norm( + ctx: &std::sync::Arc, + stream: &cudarc::driver::safe::CudaStream, + dev: usize, + work: u64, + start: i32, + count: i32, + result: u64, + ) -> Result<()> { + unsafe { launch_sparse_qr_norm_f64(ctx, stream, dev, work, start, count, result) } + } + unsafe fn launch_householder( + ctx: &std::sync::Arc, + stream: &cudarc::driver::safe::CudaStream, + dev: usize, + work: u64, + start: i32, + m: i32, + norm_sq: u64, + out_v: u64, + out_tau: u64, + out_diag: u64, + ) -> Result<()> { + unsafe { + launch_sparse_qr_householder_f64( + ctx, stream, dev, work, start, m, norm_sq, out_v, out_tau, out_diag, + ) + } + } + unsafe fn launch_extract_r( + ctx: &std::sync::Arc, + stream: &cudarc::driver::safe::CudaStream, + dev: usize, + work: u64, + count: i32, + output: u64, + ) -> Result<()> { + unsafe { launch_sparse_qr_extract_r_f64(ctx, stream, dev, work, count, output) } + } + fn structural_to_f64(tensor: &Tensor, count: usize) -> Vec { + if count == 0 { + return vec![]; + } + tensor.to_vec::().iter().copied().take(count).collect() + } +} diff --git a/src/algorithm/sparse_linalg/qr/cuda/mod.rs b/src/algorithm/sparse_linalg/qr/cuda/mod.rs new file mode 100644 index 00000000..2511068e --- /dev/null +++ b/src/algorithm/sparse_linalg/qr/cuda/mod.rs @@ -0,0 +1,8 @@ +//! CUDA implementation of sparse Householder QR factorization + +mod factorize; +mod qr; +mod solve; + +pub use qr::{sparse_qr_cuda, sparse_qr_simple_cuda}; +pub use solve::sparse_qr_solve_cuda; diff --git a/src/algorithm/sparse_linalg/qr/cuda/qr.rs b/src/algorithm/sparse_linalg/qr/cuda/qr.rs new file mode 100644 index 00000000..2a01a8f1 --- /dev/null +++ b/src/algorithm/sparse_linalg/qr/cuda/qr.rs @@ -0,0 +1,175 @@ +//! CUDA sparse QR public API: factorize and simple +//! +//! Delegates GPU factorization to `factorize.rs`, solve to `solve.rs`. + +use crate::algorithm::sparse_linalg::qr::symbolic::sparse_qr_symbolic; +use crate::algorithm::sparse_linalg::qr::types::{QrFactors, QrOptions, QrSymbolic}; +use crate::dtype::DType; +use crate::error::{Error, Result}; +use crate::runtime::cuda::{CudaClient, CudaRuntime}; +use crate::sparse::CscData; + +use super::factorize::run_factorization; + +/// Sparse QR factorization with precomputed symbolic information (CUDA) +/// +/// Uses GPU kernels with zero intermediate transfers. Householder vectors and tau +/// stay GPU-resident. Only R structural data (diag, off-diag) transferred to CPU +/// for CSC construction. +pub fn sparse_qr_cuda( + client: &CudaClient, + a: &CscData, + symbolic: &QrSymbolic, + options: &QrOptions, +) -> Result> { + let [m, n] = a.shape; + let dtype = a.values().dtype(); + + if dtype != DType::F32 && dtype != DType::F64 { + return Err(Error::UnsupportedDType { + dtype, + op: "sparse_qr_cuda", + }); + } + + if m != symbolic.m || n != symbolic.n { + return Err(Error::ShapeMismatch { + expected: vec![symbolic.m, symbolic.n], + got: vec![m, n], + }); + } + + if m < n { + return Err(Error::Internal( + "sparse_qr: requires m >= n (more rows than columns)".to_string(), + )); + } + + match dtype { + DType::F32 => run_factorization::(client, a, symbolic, options, m, n), + DType::F64 => run_factorization::(client, a, symbolic, options, m, n), + _ => unreachable!(), + } +} + +/// Sparse QR factorization without precomputed symbolic information (CUDA) +/// +/// `col_ptrs_host` and `row_indices_host` must be the CPU-resident structural +/// data for `a` (the same values used to construct `a` via `CscData::from_slices`). +/// These are kept CPU-side to avoid GPU→CPU transfers during symbolic analysis, +/// which requires irregular graph traversal and runs on the CPU. +pub fn sparse_qr_simple_cuda( + client: &CudaClient, + a: &CscData, + col_ptrs_host: &[i64], + row_indices_host: &[i64], + options: &QrOptions, +) -> Result> { + let [m, n] = a.shape; + let symbolic = sparse_qr_symbolic(col_ptrs_host, row_indices_host, m, n, options)?; + sparse_qr_cuda(client, a, &symbolic, options) +} + +#[cfg(test)] +mod tests { + use super::super::sparse_qr_solve_cuda; + use super::*; + use crate::runtime::cuda::CudaDevice; + use crate::tensor::Tensor; + + fn cuda_setup() -> Option<(::Device, CudaClient)> { + if !crate::runtime::cuda::is_cuda_available() { + return None; + } + let device = ::Device::new(0); + let client = CudaClient::new(CudaDevice::new(0)).expect("CUDA device required"); + Some((device, client)) + } + + #[test] + fn test_sparse_qr_cuda_simple_square() { + let Some((device, client)) = cuda_setup() else { + return; + }; + + let col_ptrs = vec![0i64, 2, 5, 8, 10]; + let row_indices = vec![0i64, 1, 0, 1, 2, 1, 2, 3, 2, 3]; + let values = vec![4.0f64, 1.0, 1.0, 4.0, 1.0, 1.0, 4.0, 1.0, 1.0, 4.0]; + let a = + CscData::::from_slices(&col_ptrs, &row_indices, &values, [4, 4], &device) + .unwrap(); + + let options = QrOptions::no_ordering(); + let factors = + sparse_qr_simple_cuda(&client, &a, &col_ptrs, &row_indices, &options).unwrap(); + + assert_eq!(factors.rank, 4); + // GPU factorization keeps Householder data GPU-resident only + assert!(factors.gpu_householder_values.is_some()); + assert!(factors.gpu_tau.is_some()); + } + + #[test] + fn test_sparse_qr_cuda_solve() { + let Some((device, client)) = cuda_setup() else { + return; + }; + + let col_ptrs = vec![0i64, 2, 5, 8, 10]; + let row_indices = vec![0i64, 1, 0, 1, 2, 1, 2, 3, 2, 3]; + let values = vec![4.0f64, 1.0, 1.0, 4.0, 1.0, 1.0, 4.0, 1.0, 1.0, 4.0]; + let a = + CscData::::from_slices(&col_ptrs, &row_indices, &values, [4, 4], &device) + .unwrap(); + + let options = QrOptions::no_ordering(); + let factors = + sparse_qr_simple_cuda(&client, &a, &col_ptrs, &row_indices, &options).unwrap(); + + let b = Tensor::::from_slice(&[1.0f64, 2.0, 3.0, 4.0], &[4], &device); + let x = sparse_qr_solve_cuda(&client, &factors, &b).unwrap(); + let x_vals: Vec = x.to_vec(); + + // Verify A*x ≈ b + let a_dense: &[&[f64]] = &[ + &[4.0, 1.0, 0.0, 0.0], + &[1.0, 4.0, 1.0, 0.0], + &[0.0, 1.0, 4.0, 1.0], + &[0.0, 0.0, 1.0, 4.0], + ]; + let b_vals = [1.0, 2.0, 3.0, 4.0]; + for i in 0..4 { + let mut ax_i = 0.0; + for j in 0..4 { + ax_i += a_dense[i][j] * x_vals[j]; + } + assert!( + (ax_i - b_vals[i]).abs() < 1e-8, + "A*x[{}] = {}, expected {}", + i, + ax_i, + b_vals[i] + ); + } + } + + #[test] + fn test_sparse_qr_cuda_f32() { + let Some((device, client)) = cuda_setup() else { + return; + }; + + let col_ptrs = vec![0i64, 2, 5, 8, 10]; + let row_indices = vec![0i64, 1, 0, 1, 2, 1, 2, 3, 2, 3]; + let values = vec![4.0f32, 1.0, 1.0, 4.0, 1.0, 1.0, 4.0, 1.0, 1.0, 4.0]; + let a = + CscData::::from_slices(&col_ptrs, &row_indices, &values, [4, 4], &device) + .unwrap(); + + let options = QrOptions::no_ordering(); + let factors = + sparse_qr_simple_cuda(&client, &a, &col_ptrs, &row_indices, &options).unwrap(); + + assert_eq!(factors.rank, 4); + } +} diff --git a/src/algorithm/sparse_linalg/qr/cuda/solve.rs b/src/algorithm/sparse_linalg/qr/cuda/solve.rs new file mode 100644 index 00000000..af299db3 --- /dev/null +++ b/src/algorithm/sparse_linalg/qr/cuda/solve.rs @@ -0,0 +1,380 @@ +//! GPU-resident QR solve for CUDA +//! +//! Solves A*x = b using precomputed QR factors entirely on GPU. +//! No CPU↔GPU data transfers except final result retrieval by the caller. +//! +//! Steps: +//! 1. Q^T * b: apply Householder reflectors via `apply_reflector` kernels +//! 2. R \ (Q^T b): level-scheduled upper triangular solve on GPU +//! 3. Column permutation: scatter kernel with inverse permutation + +use crate::algorithm::sparse_linalg::qr::cpu::helpers::h_offset; +use crate::algorithm::sparse_linalg::qr::types::QrFactors; +use crate::dtype::DType; +use crate::error::{Error, Result}; +use crate::runtime::cuda::kernels::{ + launch_apply_row_perm_f32, launch_apply_row_perm_f64, launch_find_diag_indices_csc, + launch_sparse_qr_apply_reflector_f32, launch_sparse_qr_apply_reflector_f64, + launch_sparse_trsv_csc_upper_level_f32, launch_sparse_trsv_csc_upper_level_f64, +}; +use crate::runtime::cuda::{CudaClient, CudaRuntime}; +use crate::tensor::Tensor; + +/// Solve A*x = b using precomputed QR factors, fully on GPU. +/// +/// Requires `factors.gpu_householder_values` and `factors.gpu_tau` to be populated +/// (they are set automatically by `sparse_qr_cuda`). +pub fn sparse_qr_solve_cuda( + client: &CudaClient, + factors: &QrFactors, + b: &Tensor, +) -> Result> { + let [m, n] = factors.r.shape; + let b_shape = b.shape(); + + if b_shape.is_empty() || b_shape[0] != m { + return Err(Error::ShapeMismatch { + expected: vec![m], + got: b_shape.to_vec(), + }); + } + + if factors.rank < n { + return Err(Error::Internal(format!( + "sparse_qr_solve: matrix is rank-deficient (rank {} < n {})", + factors.rank, n + ))); + } + + let dtype = b.dtype(); + if dtype != DType::F32 && dtype != DType::F64 { + return Err(Error::UnsupportedDType { + dtype, + op: "sparse_qr_solve_cuda", + }); + } + + let gpu_h = factors.gpu_householder_values.as_ref().ok_or_else(|| { + Error::Internal("sparse_qr_solve_cuda: GPU Householder vectors not available".to_string()) + })?; + let gpu_tau = factors.gpu_tau.as_ref().ok_or_else(|| { + Error::Internal("sparse_qr_solve_cuda: GPU tau not available".to_string()) + })?; + + match dtype { + DType::F32 => solve_impl::(client, factors, b, gpu_h, gpu_tau, m, n), + DType::F64 => solve_impl::(client, factors, b, gpu_h, gpu_tau, m, n), + _ => unreachable!(), + } +} + +trait SolveScalar: Sized { + const ELEM_SIZE: usize; + + unsafe fn launch_apply_reflector( + ctx: &std::sync::Arc, + stream: &cudarc::driver::safe::CudaStream, + dev: usize, + v: u64, + v_start: i32, + v_len: i32, + tau_ptr: u64, + work: u64, + m: i32, + ) -> Result<()>; + + unsafe fn launch_trsv_upper_level( + ctx: &std::sync::Arc, + stream: &cudarc::driver::safe::CudaStream, + dev: usize, + level_cols: u64, + level_size: i32, + col_ptrs: u64, + row_indices: u64, + values: u64, + diag_ptr: u64, + x: u64, + n: i32, + ) -> Result<()>; + + unsafe fn launch_perm( + ctx: &std::sync::Arc, + stream: &cudarc::driver::safe::CudaStream, + dev: usize, + b: u64, + perm: u64, + y: u64, + n: i32, + ) -> Result<()>; +} + +impl SolveScalar for f32 { + const ELEM_SIZE: usize = 4; + + unsafe fn launch_apply_reflector( + ctx: &std::sync::Arc, + stream: &cudarc::driver::safe::CudaStream, + dev: usize, + v: u64, + v_start: i32, + v_len: i32, + tau_ptr: u64, + work: u64, + m: i32, + ) -> Result<()> { + unsafe { + launch_sparse_qr_apply_reflector_f32( + ctx, stream, dev, v, v_start, v_len, tau_ptr, work, m, + ) + } + } + + unsafe fn launch_trsv_upper_level( + ctx: &std::sync::Arc, + stream: &cudarc::driver::safe::CudaStream, + dev: usize, + level_cols: u64, + level_size: i32, + col_ptrs: u64, + row_indices: u64, + values: u64, + diag_ptr: u64, + x: u64, + n: i32, + ) -> Result<()> { + unsafe { + launch_sparse_trsv_csc_upper_level_f32( + ctx, + stream, + dev, + level_cols, + level_size, + col_ptrs, + row_indices, + values, + diag_ptr, + x, + n, + ) + } + } + + unsafe fn launch_perm( + ctx: &std::sync::Arc, + stream: &cudarc::driver::safe::CudaStream, + dev: usize, + b: u64, + perm: u64, + y: u64, + n: i32, + ) -> Result<()> { + unsafe { launch_apply_row_perm_f32(ctx, stream, dev, b, perm, y, n) } + } +} + +impl SolveScalar for f64 { + const ELEM_SIZE: usize = 8; + + unsafe fn launch_apply_reflector( + ctx: &std::sync::Arc, + stream: &cudarc::driver::safe::CudaStream, + dev: usize, + v: u64, + v_start: i32, + v_len: i32, + tau_ptr: u64, + work: u64, + m: i32, + ) -> Result<()> { + unsafe { + launch_sparse_qr_apply_reflector_f64( + ctx, stream, dev, v, v_start, v_len, tau_ptr, work, m, + ) + } + } + + unsafe fn launch_trsv_upper_level( + ctx: &std::sync::Arc, + stream: &cudarc::driver::safe::CudaStream, + dev: usize, + level_cols: u64, + level_size: i32, + col_ptrs: u64, + row_indices: u64, + values: u64, + diag_ptr: u64, + x: u64, + n: i32, + ) -> Result<()> { + unsafe { + launch_sparse_trsv_csc_upper_level_f64( + ctx, + stream, + dev, + level_cols, + level_size, + col_ptrs, + row_indices, + values, + diag_ptr, + x, + n, + ) + } + } + + unsafe fn launch_perm( + ctx: &std::sync::Arc, + stream: &cudarc::driver::safe::CudaStream, + dev: usize, + b: u64, + perm: u64, + y: u64, + n: i32, + ) -> Result<()> { + unsafe { launch_apply_row_perm_f64(ctx, stream, dev, b, perm, y, n) } + } +} + +fn solve_impl( + client: &CudaClient, + factors: &QrFactors, + b: &Tensor, + gpu_h: &Tensor, + gpu_tau: &Tensor, + m: usize, + n: usize, +) -> Result> { + use crate::algorithm::sparse_linalg::levels::{compute_levels_csc_upper, flatten_levels}; + + let min_mn = m.min(n); + let dtype = b.dtype(); + let device = b.device(); + let context = &client.context; + let stream = &client.stream; + let dev = client.device.index; + let elem_size = T::ELEM_SIZE as u64; + + // ======================================================================== + // Step 1: Copy b into work buffer (GPU-to-GPU) + // ======================================================================== + let work = b.clone(); + let work_ptr = work.ptr(); + + let h_ptr = gpu_h.ptr(); + let tau_ptr = gpu_tau.ptr(); + + // ======================================================================== + // Step 2: Apply Q^T by launching reflector kernels (CPU drives loop) + // ======================================================================== + for k in 0..min_mn { + let v_offset = h_ptr + (h_offset(k, m) as u64) * elem_size; + let tau_k_ptr = tau_ptr + (k as u64) * elem_size; + + unsafe { + T::launch_apply_reflector( + context, + stream, + dev, + v_offset, + k as i32, + (m - k) as i32, + tau_k_ptr, + work_ptr, + m as i32, + )?; + } + } + + // ======================================================================== + // Step 3: Upper triangular solve R * x = (Q^T b)[0:n] + // ======================================================================== + let r_col_ptrs: Vec = factors.r.col_ptrs().to_vec(); + let r_row_indices: Vec = factors.r.row_indices().to_vec(); + + let u_schedule = compute_levels_csc_upper(n, &r_col_ptrs, &r_row_indices)?; + let (u_level_ptrs, u_level_cols) = flatten_levels(&u_schedule); + + // Upload structure to GPU + let r_col_ptrs_i32: Vec = r_col_ptrs.iter().map(|&x| x as i32).collect(); + let r_row_indices_i32: Vec = r_row_indices.iter().map(|&x| x as i32).collect(); + let r_col_ptrs_gpu = + Tensor::::from_slice(&r_col_ptrs_i32, &[r_col_ptrs_i32.len()], &device); + let r_row_indices_gpu = + Tensor::::from_slice(&r_row_indices_i32, &[r_row_indices_i32.len()], &device); + let u_level_cols_gpu = + Tensor::::from_slice(&u_level_cols, &[u_level_cols.len()], &device); + + // Find diagonal indices on GPU + let u_diag_ptr_gpu = Tensor::::zeros(&[n], DType::I32, &device); + unsafe { + launch_find_diag_indices_csc( + context, + stream, + dev, + r_col_ptrs_gpu.ptr(), + r_row_indices_gpu.ptr(), + u_diag_ptr_gpu.ptr(), + n as i32, + )?; + } + + // Launch level-scheduled upper triangular solve + // work[0:n] = R^{-1} * work[0:n] + let idx_size = std::mem::size_of::() as u64; + for level in 0..u_level_ptrs.len().saturating_sub(1) { + let offset = u_level_ptrs[level]; + let size = (u_level_ptrs[level + 1] - u_level_ptrs[level]) as i32; + if size == 0 { + continue; + } + + // Offset the level_cols pointer to point at this level's columns + let level_cols_ptr = u_level_cols_gpu.ptr() + (offset as u64) * idx_size; + + unsafe { + T::launch_trsv_upper_level( + context, + stream, + dev, + level_cols_ptr, + size, + r_col_ptrs_gpu.ptr(), + r_row_indices_gpu.ptr(), + factors.r.values().ptr(), + u_diag_ptr_gpu.ptr(), + work_ptr, + n as i32, + )?; + } + } + + // ======================================================================== + // Step 4: Apply column permutation x_out[col_perm[k]] = work[k] + // ======================================================================== + let mut inv_perm = vec![0i32; n]; + for (k, &orig_col) in factors.col_perm.iter().enumerate() { + inv_perm[orig_col] = k as i32; + } + let inv_perm_gpu = Tensor::::from_slice(&inv_perm, &[n], &device); + + let result = Tensor::::zeros(&[n], dtype, &device); + unsafe { + T::launch_perm( + context, + stream, + dev, + work_ptr, + inv_perm_gpu.ptr(), + result.ptr(), + n as i32, + )?; + } + + client + .stream + .synchronize() + .map_err(|e| Error::Internal(format!("CUDA stream sync failed: {:?}", e)))?; + + Ok(result) +} diff --git a/src/algorithm/sparse_linalg/qr/mod.rs b/src/algorithm/sparse_linalg/qr/mod.rs new file mode 100644 index 00000000..01416458 --- /dev/null +++ b/src/algorithm/sparse_linalg/qr/mod.rs @@ -0,0 +1,60 @@ +//! Sparse QR Factorization +//! +//! Householder QR factorization for sparse matrices: A*P = Q*R +//! +//! # Algorithm +//! +//! Column-wise left-looking Householder QR: +//! +//! ```text +//! For each column k = 0 to min(m, n) - 1: +//! 1. Apply previous reflectors to column k +//! 2. Compute Householder reflector for column k below diagonal +//! 3. Store R[0:k+1, k] and Householder vector v_k, tau_k +//! ``` +//! +//! # Usage +//! +//! ```ignore +//! use numr::algorithm::sparse_linalg::qr::*; +//! +//! // Simple factorization +//! let factors = sparse_qr_simple_cpu(&matrix, &QrOptions::default())?; +//! +//! // Solve Ax = b +//! let x = sparse_qr_solve_cpu(&factors, &b)?; +//! +//! // Least-squares min ||Ax - b|| +//! let x = sparse_qr_least_squares_cpu(&factors, &b)?; +//! ``` + +pub mod cpu; +pub mod symbolic; +pub mod traits; +pub mod types; + +#[cfg(feature = "cuda")] +pub mod cuda; + +#[cfg(feature = "wgpu")] +pub mod wgpu; + +// Re-export types +pub use types::{QrFactors, QrMetrics, QrOptions, QrOrdering, QrSymbolic}; + +// Re-export symbolic analysis +pub use symbolic::sparse_qr_symbolic; + +// Re-export CPU implementations +pub use cpu::{ + sparse_qr_cpu, sparse_qr_cpu_with_metrics, sparse_qr_least_squares_cpu, sparse_qr_simple_cpu, + sparse_qr_solve_cpu, +}; + +// Re-export CUDA implementations +#[cfg(feature = "cuda")] +pub use cuda::{sparse_qr_cuda, sparse_qr_simple_cuda, sparse_qr_solve_cuda}; + +// Re-export WebGPU implementations +#[cfg(feature = "wgpu")] +pub use wgpu::{sparse_qr_simple_wgpu, sparse_qr_solve_wgpu, sparse_qr_wgpu}; diff --git a/src/algorithm/sparse_linalg/qr/symbolic.rs b/src/algorithm/sparse_linalg/qr/symbolic.rs new file mode 100644 index 00000000..e3f2791d --- /dev/null +++ b/src/algorithm/sparse_linalg/qr/symbolic.rs @@ -0,0 +1,288 @@ +//! Symbolic analysis for sparse QR factorization +//! +//! Computes the elimination tree and column counts for R without +//! forming A^T*A explicitly. Uses the row structure of A instead. + +use crate::algorithm::sparse_linalg::ordering::{ColamdOptions, colamd}; +use crate::error::Result; + +use super::types::{QrOptions, QrOrdering, QrSymbolic}; + +/// Compute symbolic analysis for sparse QR factorization +/// +/// # Arguments +/// +/// * `col_ptrs` - CSC column pointers `[n+1]` +/// * `row_indices` - CSC row indices `[nnz]` +/// * `m` - Number of rows +/// * `n` - Number of columns +/// * `options` - QR options (ordering strategy) +/// +/// # Returns +/// +/// Symbolic structure with elimination tree, column counts, and permutation. +pub fn sparse_qr_symbolic( + col_ptrs: &[i64], + row_indices: &[i64], + m: usize, + n: usize, + options: &QrOptions, +) -> Result { + // Step 1: Compute column permutation + let col_perm = match options.ordering { + QrOrdering::Identity => (0..n).collect::>(), + QrOrdering::Colamd => { + let colamd_opts = ColamdOptions::default(); + let (perm, _stats) = colamd(m, n, col_ptrs, row_indices, &colamd_opts)?; + perm + } + }; + + // Step 2: Build permuted column pointers and row indices + let (perm_col_ptrs, perm_row_indices) = permute_columns(col_ptrs, row_indices, n, &col_perm); + + // Step 3: Compute elimination tree of A^T*A from row structure of A + let etree = compute_etree_ata(&perm_col_ptrs, &perm_row_indices, m, n); + + // Step 4: Compute column counts for R using etree + let r_col_counts = compute_r_col_counts(&perm_col_ptrs, &perm_row_indices, &etree, m, n); + + let predicted_r_nnz: usize = r_col_counts.iter().sum(); + + Ok(QrSymbolic { + m, + n, + etree, + r_col_counts, + col_perm, + predicted_r_nnz, + }) +} + +/// Permute columns of a CSC matrix according to a permutation vector +fn permute_columns( + col_ptrs: &[i64], + row_indices: &[i64], + n: usize, + perm: &[usize], +) -> (Vec, Vec) { + // Count entries per new column + let mut new_counts = vec![0usize; n]; + for new_col in 0..n { + let old_col = perm[new_col]; + let start = col_ptrs[old_col] as usize; + let end = col_ptrs[old_col + 1] as usize; + new_counts[new_col] = end - start; + } + + // Build new column pointers + let mut new_col_ptrs = vec![0i64; n + 1]; + for j in 0..n { + new_col_ptrs[j + 1] = new_col_ptrs[j] + new_counts[j] as i64; + } + + // Copy row indices in new column order + let total_nnz = new_col_ptrs[n] as usize; + let mut new_row_indices = vec![0i64; total_nnz]; + for new_col in 0..n { + let old_col = perm[new_col]; + let old_start = col_ptrs[old_col] as usize; + let old_end = col_ptrs[old_col + 1] as usize; + let new_start = new_col_ptrs[new_col] as usize; + + for (i, &row) in row_indices[old_start..old_end].iter().enumerate() { + new_row_indices[new_start + i] = row; + } + } + + (new_col_ptrs, new_row_indices) +} + +/// Compute the elimination tree of A^T*A from the row structure of A. +/// +/// Uses the column-based algorithm from Gilbert, Ng, Peyton (1994). +/// For each column j (processed left to right), we look at every row i +/// that column j touches. For that row, if we've seen a previous column k < j +/// that also touches row i, then we follow k's path up the tree (path compression) +/// to find its root r, and set parent[r] = j. +/// +/// This correctly builds the etree without forming A^T*A. +fn compute_etree_ata(col_ptrs: &[i64], row_indices: &[i64], m: usize, n: usize) -> Vec { + let mut parent = vec![-1i64; n]; + // ancestor[j] used for path compression in union-find + let mut ancestor = vec![0usize; n]; + for j in 0..n { + ancestor[j] = j; + } + // first_col[row] = first column that touches this row, or usize::MAX if none yet + let mut first_col = vec![usize::MAX; m]; + + for j in 0..n { + // Mark column j as its own ancestor (fresh) + ancestor[j] = j; + + let start = col_ptrs[j] as usize; + let end = col_ptrs[j + 1] as usize; + + for &row in &row_indices[start..end] { + let row = row as usize; + let k = first_col[row]; + if k == usize::MAX { + // First column to touch this row + first_col[row] = j; + } else { + // Column k < j also touches this row → they share a row + // Find root of k with path compression + let mut r = k; + while ancestor[r] != r { + r = ancestor[r]; + } + // Path compression + let mut node = k; + while node != r { + let next = ancestor[node]; + ancestor[node] = r; + node = next; + } + + if r != j { + // Set parent of root to j + parent[r] = j as i64; + ancestor[r] = j; + } + } + } + } + + parent +} + +/// Compute upper bound on R column counts using the elimination tree. +/// +/// For each column j, the column count in R is at most the number of +/// original rows in column j plus fill-in from the etree descendants. +fn compute_r_col_counts( + col_ptrs: &[i64], + _row_indices: &[i64], + etree: &[i64], + m: usize, + n: usize, +) -> Vec { + // Simple upper bound: for each column, count unique rows that appear + // in the column and all its descendants in the etree + // + // For a tighter bound we'd need the row subtree approach, but this + // conservative estimate is sufficient for pre-allocation. + + // Start with direct column counts (capped at min(m, col_index + 1)) + let mut counts = vec![0usize; n]; + for col in 0..n { + let start = col_ptrs[col] as usize; + let end = col_ptrs[col + 1] as usize; + // Number of entries in this column, capped at entries that can be in R + // (only rows 0..=col for R's upper triangular structure, for square; + // for rectangular, min(m, col+1)) + let direct = end - start; + counts[col] = direct.min(m); + } + + // Propagate counts up the etree (children contribute to parent's count) + // Process in reverse order (leaves first) + // This is a conservative estimate - actual fill depends on row overlap + for j in 0..n { + let parent = etree[j]; + if parent >= 0 && (parent as usize) < n { + // Parent gains at most the child's count minus 1 (the diagonal) + let contribution = if counts[j] > 0 { counts[j] - 1 } else { 0 }; + counts[parent as usize] = counts[parent as usize].max(contribution + 1); + } + } + + // Ensure each column has at least 1 entry (the diagonal of R, if rank allows) + for count in &mut counts { + *count = (*count).max(1); + } + + counts +} + +// ============================================================================ +// Tests +// ============================================================================ + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_symbolic_identity_ordering() { + // 3x3 diagonal matrix + let col_ptrs = vec![0i64, 1, 2, 3]; + let row_indices = vec![0i64, 1, 2]; + + let options = QrOptions::no_ordering(); + let symbolic = sparse_qr_symbolic(&col_ptrs, &row_indices, 3, 3, &options).unwrap(); + + assert_eq!(symbolic.m, 3); + assert_eq!(symbolic.n, 3); + assert_eq!(symbolic.col_perm, vec![0, 1, 2]); + } + + #[test] + fn test_symbolic_tridiagonal() { + // 4x4 tridiagonal matrix: + // [x . . .] + // [x x . .] + // [. x x .] + // [. . x x] + let col_ptrs = vec![0i64, 2, 4, 6, 7]; + let row_indices = vec![0i64, 1, 1, 2, 2, 3, 3]; + + let options = QrOptions::no_ordering(); + let symbolic = sparse_qr_symbolic(&col_ptrs, &row_indices, 4, 4, &options).unwrap(); + + assert_eq!(symbolic.m, 4); + assert_eq!(symbolic.n, 4); + // Each column should have a reasonable count + for &count in &symbolic.r_col_counts { + assert!(count >= 1); + } + } + + #[test] + fn test_symbolic_with_colamd() { + // 4x3 overdetermined matrix + let col_ptrs = vec![0i64, 3, 5, 7]; + let row_indices = vec![0i64, 1, 2, 1, 3, 0, 3]; + + let options = QrOptions::default(); // uses Colamd + let symbolic = sparse_qr_symbolic(&col_ptrs, &row_indices, 4, 3, &options).unwrap(); + + assert_eq!(symbolic.m, 4); + assert_eq!(symbolic.n, 3); + assert_eq!(symbolic.col_perm.len(), 3); + // Permutation should be a valid permutation of 0..3 + let mut sorted_perm = symbolic.col_perm.clone(); + sorted_perm.sort_unstable(); + assert_eq!(sorted_perm, vec![0, 1, 2]); + } + + #[test] + fn test_etree_chain() { + // Matrix where columns share rows in a chain pattern + // Col 0: rows {0, 1} + // Col 1: rows {1, 2} + // Col 2: rows {2, 3} + let col_ptrs = vec![0i64, 2, 4, 6]; + let row_indices = vec![0i64, 1, 1, 2, 2, 3]; + + let etree = compute_etree_ata(&col_ptrs, &row_indices, 4, 3); + + // Col 0 and 1 share row 1, so etree[0] = 1 + // Col 1 and 2 share row 2, so etree[1] = 2 + // Col 2 is root + assert_eq!(etree[0], 1); + assert_eq!(etree[1], 2); + assert_eq!(etree[2], -1); + } +} diff --git a/src/algorithm/sparse_linalg/qr/traits.rs b/src/algorithm/sparse_linalg/qr/traits.rs new file mode 100644 index 00000000..622b5638 --- /dev/null +++ b/src/algorithm/sparse_linalg/qr/traits.rs @@ -0,0 +1,5 @@ +//! Trait definitions for sparse QR factorization +//! +//! Sparse QR uses free functions per backend (sparse_qr_cpu, sparse_qr_cuda, etc.) +//! rather than a trait-based dispatch pattern, because the CPU implementation +//! operates on extracted f64 data while GPU backends will need native kernels. diff --git a/src/algorithm/sparse_linalg/qr/types.rs b/src/algorithm/sparse_linalg/qr/types.rs new file mode 100644 index 00000000..c09e1b97 --- /dev/null +++ b/src/algorithm/sparse_linalg/qr/types.rs @@ -0,0 +1,153 @@ +//! Types for sparse QR factorization +//! +//! Contains factorization results, symbolic structures, and options. + +use crate::runtime::Runtime; +use crate::sparse::CscData; +use crate::tensor::Tensor; + +// ============================================================================ +// QR Factorization Types +// ============================================================================ + +/// Result of sparse Householder QR factorization: A*P = Q*R +/// +/// Q is stored implicitly as a sequence of Householder reflectors. +/// R is stored explicitly in CSC format. +/// P is the column permutation from COLAMD ordering. +/// +/// For GPU backends, Householder vectors and tau are stored GPU-resident only +/// (`gpu_householder_values`, `gpu_tau`), and the CPU sparse fields +/// (`householder_vectors`, `tau`) are empty. GPU solve uses the GPU tensors +/// directly. CPU factorization populates the CPU fields instead. +#[derive(Debug, Clone)] +pub struct QrFactors { + /// Householder reflectors stored as sparse vectors (CPU). + /// Each entry is (row_indices, values) for one reflector. + /// Reflector k has support in rows k..m. + /// Empty for GPU-factorized results (use `gpu_householder_values` instead). + pub householder_vectors: Vec<(Vec, Vec)>, + + /// Tau coefficients for each Householder reflector. + /// `Q_k = I - tau_k * v_k * v_k^T` + /// Empty for GPU-factorized results (use `gpu_tau` instead). + pub tau: Vec, + + /// Upper triangular factor R in CSC format. + /// Shape: `[m, n]` but only first `rank` rows of each column are meaningful. + pub r: CscData, + + /// Column permutation from COLAMD ordering. + /// `col_perm[k]` = original column index of the k-th column in the permuted matrix. + pub col_perm: Vec, + + /// Numerical rank detected during factorization. + pub rank: usize, + + /// Dense Householder vectors on GPU (optional, for GPU-resident solve). + /// + /// Flat buffer of length `sum(m-k for k in 0..min(m,n))`. Reflector k is + /// stored at `h_offset(k, m)` with length `m - k`. Only populated by GPU + /// factorization backends; `None` for CPU factorization. + pub gpu_householder_values: Option>, + + /// Tau coefficients on GPU (optional, for GPU-resident solve). + /// + /// Length `min(m, n)`. Only populated by GPU factorization backends. + pub gpu_tau: Option>, +} + +/// Symbolic analysis for sparse QR factorization +/// +/// Precomputed structural information based on the sparsity pattern. +/// Reusable for multiple numeric factorizations with the same pattern. +#[derive(Debug, Clone)] +pub struct QrSymbolic { + /// Number of rows + pub m: usize, + + /// Number of columns + pub n: usize, + + /// Elimination tree for R: `etree[j]` = parent of column j, or -1 if root. + /// Derived from the column structure of A^T*A without forming it explicitly. + pub etree: Vec, + + /// Predicted column counts for R (upper bound on nnz per column). + pub r_col_counts: Vec, + + /// Column permutation from COLAMD. + pub col_perm: Vec, + + /// Predicted total nnz in R. + pub predicted_r_nnz: usize, +} + +impl QrSymbolic { + /// Create a trivial symbolic structure (identity permutation, no etree). + pub fn identity(m: usize, n: usize) -> Self { + Self { + m, + n, + etree: vec![-1; n], + r_col_counts: vec![1; n], + col_perm: (0..n).collect(), + predicted_r_nnz: n, + } + } +} + +/// Configuration for sparse QR factorization +#[derive(Debug, Clone)] +pub struct QrOptions { + /// Tolerance for rank detection (default: 1e-12). + /// Diagonal entries of R with absolute value below this are treated as zero. + pub rank_tolerance: f64, + + /// Column ordering strategy. + pub ordering: QrOrdering, +} + +impl Default for QrOptions { + fn default() -> Self { + Self { + rank_tolerance: 1e-12, + ordering: QrOrdering::Colamd, + } + } +} + +impl QrOptions { + /// Create options with no column ordering. + pub fn no_ordering() -> Self { + Self { + ordering: QrOrdering::Identity, + ..Default::default() + } + } +} + +/// Column ordering strategy for QR factorization +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum QrOrdering { + /// No column permutation (identity permutation, original column order). + Identity, + /// COLAMD approximate minimum degree ordering. + Colamd, +} + +/// Metrics from QR factorization for diagnostics +#[derive(Debug, Clone)] +pub struct QrMetrics { + /// Number of non-zeros in original matrix + pub original_nnz: usize, + + /// Number of non-zeros in R factor + pub r_nnz: usize, + + /// Fill ratio: r_nnz / original_nnz + pub fill_ratio: f64, + + /// Numerical rank detected + pub numerical_rank: usize, +} diff --git a/src/algorithm/sparse_linalg/qr/wgpu/factorize.rs b/src/algorithm/sparse_linalg/qr/wgpu/factorize.rs new file mode 100644 index 00000000..c9f4f7d1 --- /dev/null +++ b/src/algorithm/sparse_linalg/qr/wgpu/factorize.rs @@ -0,0 +1,711 @@ +//! WebGPU GPU factorization loop for sparse Householder QR +//! +//! F32 only. Same architecture as CUDA: dense Householder vectors on GPU, +//! structure-driven column loop on CPU. Householder vectors and tau stay +//! GPU-resident; only R structural data transferred to CPU for CSC construction. + +use wgpu::{BufferDescriptor, BufferUsages}; + +use crate::algorithm::sparse_linalg::qr::cpu::helpers::{ + build_r_csc, create_r_tensor, detect_rank, h_offset, r_offdiag_offset, +}; +use crate::algorithm::sparse_linalg::qr::types::{QrFactors, QrOptions, QrSymbolic}; +use crate::error::{Error, Result}; +use crate::runtime::wgpu::client::get_buffer; +use crate::runtime::wgpu::shaders::{LayoutKey, workgroup_count}; +use crate::runtime::wgpu::{WgpuClient, WgpuRuntime}; +use crate::sparse::CscData; +use crate::tensor::Tensor; + +/// Run the WebGPU factorization for f32 +pub(super) fn run_factorization_wgpu( + client: &WgpuClient, + a: &CscData, + symbolic: &QrSymbolic, + options: &QrOptions, +) -> Result> { + let [m, n] = a.shape; + let dtype = a.values().dtype(); + let min_mn = m.min(n); + let device = a.values().device(); + let col_ptrs: Vec = a.col_ptrs().to_vec(); + + // A's row_indices as i32 + let a_row_indices_i32: Vec = a + .row_indices() + .to_vec::() + .iter() + .map(|&x| x as i32) + .collect(); + let a_row_indices_gpu = + Tensor::::from_slice(&a_row_indices_i32, &[a_row_indices_i32.len()], &device); + + // Buffer sizes + let total_h_size = if min_mn > 0 { + h_offset(min_mn - 1, m) + (m - (min_mn - 1)) + } else { + 0 + }; + let total_r_offdiag = min_mn * min_mn.saturating_sub(1) / 2; + + // Allocate GPU buffers + let work_gpu = Tensor::::zeros(&[m], dtype, &device); + let h_values_gpu = Tensor::::zeros(&[total_h_size.max(1)], dtype, &device); + let tau_gpu = Tensor::::zeros(&[min_mn.max(1)], dtype, &device); + let diag_gpu = Tensor::::zeros(&[min_mn.max(1)], dtype, &device); + let r_offdiag_gpu = Tensor::::zeros(&[total_r_offdiag.max(1)], dtype, &device); + let norm_sq_gpu = Tensor::::zeros(&[1], dtype, &device); + + // Get buffer references + let work_buf = get_buffer(work_gpu.ptr()) + .ok_or_else(|| Error::Internal("Invalid work buffer".to_string()))?; + let h_values_buf = get_buffer(h_values_gpu.ptr()) + .ok_or_else(|| Error::Internal("Invalid h_values buffer".to_string()))?; + let tau_buf = get_buffer(tau_gpu.ptr()) + .ok_or_else(|| Error::Internal("Invalid tau buffer".to_string()))?; + let diag_buf = get_buffer(diag_gpu.ptr()) + .ok_or_else(|| Error::Internal("Invalid diag buffer".to_string()))?; + let r_offdiag_buf = get_buffer(r_offdiag_gpu.ptr()) + .ok_or_else(|| Error::Internal("Invalid r_offdiag buffer".to_string()))?; + let norm_sq_buf = get_buffer(norm_sq_gpu.ptr()) + .ok_or_else(|| Error::Internal("Invalid norm_sq buffer".to_string()))?; + let a_values_buf = get_buffer(a.values().ptr()) + .ok_or_else(|| Error::Internal("Invalid A values buffer".to_string()))?; + let a_indices_buf = get_buffer(a_row_indices_gpu.ptr()) + .ok_or_else(|| Error::Internal("Invalid A indices buffer".to_string()))?; + + let cache = &client.pipeline_cache; + let queue = &client.queue; + let wgpu_device = &client.wgpu_device; + + let shader_source = include_str!("../../../../runtime/wgpu/shaders/sparse_linalg.wgsl"); + + // Create pipelines + let pipelines = create_pipelines(cache, shader_source); + + // Create reusable uniform buffers + let uniform_bufs = create_uniform_buffers(wgpu_device); + + // Tau scalar buffer for per-reflector access (WGPU doesn't support buffer offsets) + let tau_scalar_buf = wgpu_device.create_buffer(&BufferDescriptor { + label: Some("qr_tau_scalar"), + size: 4, + usage: BufferUsages::STORAGE | BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + + let elem_size = 4u64; // f32 + + // Column loop + for k in 0..min_mn { + dispatch_clear(wgpu_device, queue, &pipelines, &uniform_bufs, &work_buf, m); + + dispatch_scatter( + wgpu_device, + queue, + &pipelines, + &uniform_bufs, + &a_values_buf, + &a_indices_buf, + &work_buf, + &col_ptrs, + &symbolic.col_perm, + k, + ); + + dispatch_apply_reflectors( + wgpu_device, + queue, + &pipelines, + &uniform_bufs, + &h_values_buf, + &tau_buf, + &tau_scalar_buf, + &work_buf, + k, + m, + elem_size, + ); + + dispatch_extract_r( + wgpu_device, + queue, + &pipelines, + &uniform_bufs, + &work_buf, + &r_offdiag_buf, + k, + elem_size, + ); + + dispatch_norm( + wgpu_device, + queue, + &pipelines, + &uniform_bufs, + &work_buf, + &norm_sq_buf, + k, + m, + ); + + dispatch_householder( + wgpu_device, + queue, + &pipelines, + &uniform_bufs, + &work_buf, + &norm_sq_buf, + &h_values_buf, + &tau_buf, + &diag_buf, + k, + m, + elem_size, + ); + } + + // Wait for completion + let _ = wgpu_device.poll(wgpu::PollType::Wait { + submission_index: None, + timeout: Some(std::time::Duration::from_secs(60)), + }); + + // Transfer ONLY R structural data (diag + off-diag) for CSC construction. + // Householder vectors and tau stay GPU-resident — no GPU→CPU transfer. + let diag_cpu_f32: Vec = diag_gpu.to_vec(); + let r_offdiag_cpu_f32: Vec = r_offdiag_gpu.to_vec(); + + let diag_cpu: Vec = diag_cpu_f32 + .iter() + .take(min_mn) + .map(|&x| x as f64) + .collect(); + let r_offdiag_cpu: Vec = r_offdiag_cpu_f32.iter().map(|&x| x as f64).collect(); + + // Build R factor on CPU (small structural data) + let (r_col_ptrs, r_row_indices, r_values) = build_r_csc(&r_offdiag_cpu, &diag_cpu, min_mn, n); + let rank = detect_rank(&diag_cpu, min_mn, options.rank_tolerance); + let r = create_r_tensor::( + m, + n, + &r_col_ptrs, + &r_row_indices, + &r_values, + dtype, + &device, + )?; + + Ok(QrFactors { + // GPU factorization keeps Householder data GPU-resident only. + // CPU sparse representation is empty; use gpu_householder_values for solve. + householder_vectors: Vec::new(), + tau: Vec::new(), + r, + col_perm: symbolic.col_perm.clone(), + rank, + gpu_householder_values: Some(h_values_gpu), + gpu_tau: Some(tau_gpu), + }) +} + +// ============================================================================ +// Pipeline and buffer setup +// ============================================================================ + +struct Pipelines { + scatter: std::sync::Arc, + scatter_layout: std::sync::Arc, + reflector: std::sync::Arc, + reflector_layout: std::sync::Arc, + norm: std::sync::Arc, + norm_layout: std::sync::Arc, + householder: std::sync::Arc, + hh_layout: std::sync::Arc, + extract_r: std::sync::Arc, + extract_layout: std::sync::Arc, + clear: std::sync::Arc, + clear_layout: std::sync::Arc, +} + +struct UniformBuffers { + scatter: wgpu::Buffer, + reflector: wgpu::Buffer, + norm: wgpu::Buffer, + householder: wgpu::Buffer, + extract_r: wgpu::Buffer, + clear: wgpu::Buffer, +} + +fn create_pipelines( + cache: &crate::runtime::wgpu::shaders::PipelineCache, + shader_source: &str, +) -> Pipelines { + let make = |name: &str, entry: &str, num_storage: u32, num_readonly: u32| { + let module = cache.get_or_create_module_from_source(name, shader_source); + let layout = cache.get_or_create_layout(LayoutKey { + num_storage_buffers: num_storage, + num_uniform_buffers: 1, + num_readonly_storage: num_readonly, + }); + let pipeline = cache.get_or_create_dynamic_pipeline(name, entry, &module, &layout); + (pipeline, layout) + }; + + let (scatter, scatter_layout) = make("sparse_qr_scatter", "sparse_scatter_offset_f32", 3, 2); + let (reflector, reflector_layout) = + make("sparse_qr_reflector", "sparse_qr_apply_reflector_f32", 3, 2); + let (norm, norm_layout) = make("sparse_qr_norm", "sparse_qr_norm_f32", 2, 1); + let (householder, hh_layout) = make("sparse_qr_householder", "sparse_qr_householder_f32", 5, 2); + let (extract_r, extract_layout) = make("sparse_qr_extract", "sparse_qr_extract_r_f32", 2, 1); + let (clear, clear_layout) = make("sparse_qr_clear", "sparse_qr_clear_f32", 1, 0); + + Pipelines { + scatter, + scatter_layout, + reflector, + reflector_layout, + norm, + norm_layout, + householder, + hh_layout, + extract_r, + extract_layout, + clear, + clear_layout, + } +} + +fn create_uniform_buffers(dev: &wgpu::Device) -> UniformBuffers { + let make = |label| { + dev.create_buffer(&BufferDescriptor { + label: Some(label), + size: 8, + usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST, + mapped_at_creation: false, + }) + }; + UniformBuffers { + scatter: make("qr_scatter_params"), + reflector: make("qr_reflector_params"), + norm: make("qr_norm_params"), + householder: make("qr_hh_params"), + extract_r: make("qr_extract_params"), + clear: make("qr_clear_params"), + } +} + +fn dispatch_clear( + dev: &wgpu::Device, + queue: &wgpu::Queue, + p: &Pipelines, + u: &UniformBuffers, + work_buf: &wgpu::Buffer, + m: usize, +) { + #[repr(C)] + #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)] + struct Params { + n: u32, + _alignment: u32, + } + queue.write_buffer( + &u.clear, + 0, + bytemuck::bytes_of(&Params { + n: m as u32, + _alignment: 0, + }), + ); + + let bg = dev.create_bind_group(&wgpu::BindGroupDescriptor { + label: Some("qr_clear_bg"), + layout: &p.clear_layout, + entries: &[ + wgpu::BindGroupEntry { + binding: 0, + resource: work_buf.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 1, + resource: u.clear.as_entire_binding(), + }, + ], + }); + let mut enc = dev.create_command_encoder(&Default::default()); + { + let mut pass = enc.begin_compute_pass(&Default::default()); + pass.set_pipeline(&p.clear); + pass.set_bind_group(0, Some(&bg), &[]); + pass.dispatch_workgroups(workgroup_count(m), 1, 1); + } + queue.submit(std::iter::once(enc.finish())); +} + +fn dispatch_scatter( + dev: &wgpu::Device, + queue: &wgpu::Queue, + p: &Pipelines, + u: &UniformBuffers, + a_values_buf: &wgpu::Buffer, + a_indices_buf: &wgpu::Buffer, + work_buf: &wgpu::Buffer, + col_ptrs: &[i64], + col_perm: &[usize], + k: usize, +) { + let orig_col = col_perm[k]; + let a_col_start = col_ptrs[orig_col] as u32; + let a_col_end = col_ptrs[orig_col + 1] as u32; + let a_col_nnz = a_col_end - a_col_start; + + if a_col_nnz == 0 { + return; + } + + #[repr(C)] + #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)] + struct Params { + offset: u32, + count: u32, + } + queue.write_buffer( + &u.scatter, + 0, + bytemuck::bytes_of(&Params { + offset: a_col_start, + count: a_col_nnz, + }), + ); + + let bg = dev.create_bind_group(&wgpu::BindGroupDescriptor { + label: Some("qr_scatter_bg"), + layout: &p.scatter_layout, + entries: &[ + wgpu::BindGroupEntry { + binding: 0, + resource: a_values_buf.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 1, + resource: a_indices_buf.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 2, + resource: work_buf.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 3, + resource: u.scatter.as_entire_binding(), + }, + ], + }); + let mut enc = dev.create_command_encoder(&Default::default()); + { + let mut pass = enc.begin_compute_pass(&Default::default()); + pass.set_pipeline(&p.scatter); + pass.set_bind_group(0, Some(&bg), &[]); + pass.dispatch_workgroups(workgroup_count(a_col_nnz as usize), 1, 1); + } + queue.submit(std::iter::once(enc.finish())); +} + +#[allow(clippy::too_many_arguments)] +fn dispatch_apply_reflectors( + dev: &wgpu::Device, + queue: &wgpu::Queue, + p: &Pipelines, + u: &UniformBuffers, + h_values_buf: &wgpu::Buffer, + tau_buf: &wgpu::Buffer, + tau_scalar_buf: &wgpu::Buffer, + work_buf: &wgpu::Buffer, + k: usize, + m: usize, + elem_size: u64, +) { + for j in 0..k { + // Copy tau[j] to scalar buffer (GPU-to-GPU) + let tau_byte_offset = (j as u64) * elem_size; + let mut enc = dev.create_command_encoder(&Default::default()); + enc.copy_buffer_to_buffer(tau_buf, tau_byte_offset, tau_scalar_buf, 0, 4); + queue.submit(std::iter::once(enc.finish())); + + // Extract v sub-range into temp buffer (GPU-to-GPU copy, not CPU transfer) + let v_byte_offset = (h_offset(j, m) as u64) * elem_size; + let v_len = m - j; + let v_byte_len = (v_len as u64) * elem_size; + + let v_temp_buf = dev.create_buffer(&BufferDescriptor { + label: Some("qr_v_temp"), + size: v_byte_len.max(4), + usage: BufferUsages::STORAGE | BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + let mut enc = dev.create_command_encoder(&Default::default()); + enc.copy_buffer_to_buffer(h_values_buf, v_byte_offset, &v_temp_buf, 0, v_byte_len); + queue.submit(std::iter::once(enc.finish())); + + #[repr(C)] + #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)] + struct Params { + v_start: u32, + v_len: u32, + } + queue.write_buffer( + &u.reflector, + 0, + bytemuck::bytes_of(&Params { + v_start: j as u32, + v_len: v_len as u32, + }), + ); + + let bg = dev.create_bind_group(&wgpu::BindGroupDescriptor { + label: Some("qr_reflector_bg"), + layout: &p.reflector_layout, + entries: &[ + wgpu::BindGroupEntry { + binding: 0, + resource: v_temp_buf.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 1, + resource: tau_scalar_buf.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 2, + resource: work_buf.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 3, + resource: u.reflector.as_entire_binding(), + }, + ], + }); + let mut enc = dev.create_command_encoder(&Default::default()); + { + let mut pass = enc.begin_compute_pass(&Default::default()); + pass.set_pipeline(&p.reflector); + pass.set_bind_group(0, Some(&bg), &[]); + pass.dispatch_workgroups(1, 1, 1); + } + queue.submit(std::iter::once(enc.finish())); + } +} + +fn dispatch_extract_r( + dev: &wgpu::Device, + queue: &wgpu::Queue, + p: &Pipelines, + u: &UniformBuffers, + work_buf: &wgpu::Buffer, + r_offdiag_buf: &wgpu::Buffer, + k: usize, + elem_size: u64, +) { + if k == 0 { + return; + } + + #[repr(C)] + #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)] + struct Params { + count: u32, + _alignment: u32, + } + queue.write_buffer( + &u.extract_r, + 0, + bytemuck::bytes_of(&Params { + count: k as u32, + _alignment: 0, + }), + ); + + let r_byte_offset = (r_offdiag_offset(k) as u64) * elem_size; + let r_byte_len = (k as u64) * elem_size; + let r_temp_buf = dev.create_buffer(&BufferDescriptor { + label: Some("qr_r_temp"), + size: r_byte_len.max(4), + usage: BufferUsages::STORAGE | BufferUsages::COPY_DST | BufferUsages::COPY_SRC, + mapped_at_creation: false, + }); + + let bg = dev.create_bind_group(&wgpu::BindGroupDescriptor { + label: Some("qr_extract_bg"), + layout: &p.extract_layout, + entries: &[ + wgpu::BindGroupEntry { + binding: 0, + resource: work_buf.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 1, + resource: r_temp_buf.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 2, + resource: u.extract_r.as_entire_binding(), + }, + ], + }); + let mut enc = dev.create_command_encoder(&Default::default()); + { + let mut pass = enc.begin_compute_pass(&Default::default()); + pass.set_pipeline(&p.extract_r); + pass.set_bind_group(0, Some(&bg), &[]); + pass.dispatch_workgroups(workgroup_count(k), 1, 1); + } + enc.copy_buffer_to_buffer(&r_temp_buf, 0, r_offdiag_buf, r_byte_offset, r_byte_len); + queue.submit(std::iter::once(enc.finish())); +} + +fn dispatch_norm( + dev: &wgpu::Device, + queue: &wgpu::Queue, + p: &Pipelines, + u: &UniformBuffers, + work_buf: &wgpu::Buffer, + norm_sq_buf: &wgpu::Buffer, + k: usize, + m: usize, +) { + #[repr(C)] + #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)] + struct Params { + start: u32, + count: u32, + } + queue.write_buffer( + &u.norm, + 0, + bytemuck::bytes_of(&Params { + start: k as u32, + count: (m - k) as u32, + }), + ); + + let bg = dev.create_bind_group(&wgpu::BindGroupDescriptor { + label: Some("qr_norm_bg"), + layout: &p.norm_layout, + entries: &[ + wgpu::BindGroupEntry { + binding: 0, + resource: work_buf.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 1, + resource: norm_sq_buf.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 2, + resource: u.norm.as_entire_binding(), + }, + ], + }); + let mut enc = dev.create_command_encoder(&Default::default()); + { + let mut pass = enc.begin_compute_pass(&Default::default()); + pass.set_pipeline(&p.norm); + pass.set_bind_group(0, Some(&bg), &[]); + pass.dispatch_workgroups(1, 1, 1); + } + queue.submit(std::iter::once(enc.finish())); +} + +#[allow(clippy::too_many_arguments)] +fn dispatch_householder( + dev: &wgpu::Device, + queue: &wgpu::Queue, + p: &Pipelines, + u: &UniformBuffers, + work_buf: &wgpu::Buffer, + norm_sq_buf: &wgpu::Buffer, + h_values_buf: &wgpu::Buffer, + tau_buf: &wgpu::Buffer, + diag_buf: &wgpu::Buffer, + k: usize, + m: usize, + elem_size: u64, +) { + #[repr(C)] + #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)] + struct Params { + start: u32, + m: u32, + } + queue.write_buffer( + &u.householder, + 0, + bytemuck::bytes_of(&Params { + start: k as u32, + m: m as u32, + }), + ); + + let v_len = m - k; + let v_byte_len = (v_len as u64) * elem_size; + let v_byte_offset = (h_offset(k, m) as u64) * elem_size; + + let v_out_buf = dev.create_buffer(&BufferDescriptor { + label: Some("qr_hh_v_out"), + size: v_byte_len.max(4), + usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC, + mapped_at_creation: false, + }); + let tau_out_buf = dev.create_buffer(&BufferDescriptor { + label: Some("qr_hh_tau_out"), + size: 4, + usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC, + mapped_at_creation: false, + }); + let diag_out_buf = dev.create_buffer(&BufferDescriptor { + label: Some("qr_hh_diag_out"), + size: 4, + usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC, + mapped_at_creation: false, + }); + + let bg = dev.create_bind_group(&wgpu::BindGroupDescriptor { + label: Some("qr_hh_bg"), + layout: &p.hh_layout, + entries: &[ + wgpu::BindGroupEntry { + binding: 0, + resource: work_buf.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 1, + resource: norm_sq_buf.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 2, + resource: v_out_buf.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 3, + resource: tau_out_buf.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 4, + resource: diag_out_buf.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 5, + resource: u.householder.as_entire_binding(), + }, + ], + }); + let mut enc = dev.create_command_encoder(&Default::default()); + { + let mut pass = enc.begin_compute_pass(&Default::default()); + pass.set_pipeline(&p.householder); + pass.set_bind_group(0, Some(&bg), &[]); + pass.dispatch_workgroups(1, 1, 1); + } + enc.copy_buffer_to_buffer(&v_out_buf, 0, h_values_buf, v_byte_offset, v_byte_len); + enc.copy_buffer_to_buffer(&tau_out_buf, 0, tau_buf, (k as u64) * elem_size, 4); + enc.copy_buffer_to_buffer(&diag_out_buf, 0, diag_buf, (k as u64) * elem_size, 4); + queue.submit(std::iter::once(enc.finish())); +} diff --git a/src/algorithm/sparse_linalg/qr/wgpu/mod.rs b/src/algorithm/sparse_linalg/qr/wgpu/mod.rs new file mode 100644 index 00000000..c9a049b3 --- /dev/null +++ b/src/algorithm/sparse_linalg/qr/wgpu/mod.rs @@ -0,0 +1,9 @@ +//! WebGPU implementation of sparse Householder QR factorization + +#[cfg(feature = "wgpu")] +mod factorize; +mod qr; +mod solve; + +pub use qr::{sparse_qr_simple_wgpu, sparse_qr_wgpu}; +pub use solve::sparse_qr_solve_wgpu; diff --git a/src/algorithm/sparse_linalg/qr/wgpu/qr.rs b/src/algorithm/sparse_linalg/qr/wgpu/qr.rs new file mode 100644 index 00000000..5a3bab4b --- /dev/null +++ b/src/algorithm/sparse_linalg/qr/wgpu/qr.rs @@ -0,0 +1,149 @@ +//! WebGPU sparse QR public API: factorize and simple +//! +//! F32 only. Delegates GPU factorization to `factorize.rs`, solve to `solve.rs`. + +#[cfg(feature = "wgpu")] +use crate::algorithm::sparse_linalg::qr::symbolic::sparse_qr_symbolic; +#[cfg(feature = "wgpu")] +use crate::algorithm::sparse_linalg::qr::types::{QrFactors, QrOptions}; +#[cfg(feature = "wgpu")] +use crate::dtype::DType; +#[cfg(feature = "wgpu")] +use crate::error::{Error, Result}; +#[cfg(feature = "wgpu")] +use crate::runtime::wgpu::{WgpuClient, WgpuRuntime}; +#[cfg(feature = "wgpu")] +use crate::sparse::CscData; + +/// Sparse QR factorization with precomputed symbolic information (WebGPU) +/// +/// F32 only. Uses GPU kernels with zero intermediate transfers. +#[cfg(feature = "wgpu")] +pub fn sparse_qr_wgpu( + client: &WgpuClient, + a: &CscData, + symbolic: &crate::algorithm::sparse_linalg::qr::types::QrSymbolic, + options: &QrOptions, +) -> Result> { + let [m, n] = a.shape; + let dtype = a.values().dtype(); + + if dtype != DType::F32 { + return Err(Error::UnsupportedDType { + dtype, + op: "sparse_qr_wgpu", + }); + } + + if m != symbolic.m || n != symbolic.n { + return Err(Error::ShapeMismatch { + expected: vec![symbolic.m, symbolic.n], + got: vec![m, n], + }); + } + + if m < n { + return Err(Error::Internal("sparse_qr: requires m >= n".to_string())); + } + + super::factorize::run_factorization_wgpu(client, a, symbolic, options) +} + +/// Sparse QR factorization without precomputed symbolic information (WebGPU) +/// +/// `col_ptrs_host` and `row_indices_host` must be the CPU-resident structural +/// data for `a` (the same values used to construct `a` via `CscData::from_slices`). +/// These are kept CPU-side to avoid GPU→CPU transfers during symbolic analysis, +/// which requires irregular graph traversal and runs on the CPU. +#[cfg(feature = "wgpu")] +pub fn sparse_qr_simple_wgpu( + client: &WgpuClient, + a: &CscData, + col_ptrs_host: &[i64], + row_indices_host: &[i64], + options: &QrOptions, +) -> Result> { + let [m, n] = a.shape; + let symbolic = sparse_qr_symbolic(col_ptrs_host, row_indices_host, m, n, options)?; + sparse_qr_wgpu(client, a, &symbolic, options) +} + +#[cfg(test)] +#[cfg(feature = "wgpu")] +mod tests { + use super::super::sparse_qr_solve_wgpu; + use super::*; + use crate::tensor::Tensor; + + fn wgpu_device() -> ::Device { + ::Device::default() + } + + fn get_wgpu_client() -> WgpuClient { + WgpuClient::new(wgpu_device()).expect("WGPU device required") + } + + #[test] + fn test_sparse_qr_wgpu_simple_square() { + let device = wgpu_device(); + let client = get_wgpu_client(); + + let col_ptrs = vec![0i64, 2, 5, 8, 10]; + let row_indices = vec![0i64, 1, 0, 1, 2, 1, 2, 3, 2, 3]; + let values = vec![4.0f32, 1.0, 1.0, 4.0, 1.0, 1.0, 4.0, 1.0, 1.0, 4.0]; + let a = + CscData::::from_slices(&col_ptrs, &row_indices, &values, [4, 4], &device) + .unwrap(); + + let options = QrOptions::no_ordering(); + let factors = + sparse_qr_simple_wgpu(&client, &a, &col_ptrs, &row_indices, &options).unwrap(); + + assert_eq!(factors.rank, 4); + // GPU factorization keeps Householder data GPU-resident only + assert!(factors.gpu_householder_values.is_some()); + assert!(factors.gpu_tau.is_some()); + } + + #[test] + fn test_sparse_qr_wgpu_solve() { + let device = wgpu_device(); + let client = get_wgpu_client(); + + let col_ptrs = vec![0i64, 2, 5, 8, 10]; + let row_indices = vec![0i64, 1, 0, 1, 2, 1, 2, 3, 2, 3]; + let values = vec![4.0f32, 1.0, 1.0, 4.0, 1.0, 1.0, 4.0, 1.0, 1.0, 4.0]; + let a = + CscData::::from_slices(&col_ptrs, &row_indices, &values, [4, 4], &device) + .unwrap(); + + let options = QrOptions::no_ordering(); + let factors = + sparse_qr_simple_wgpu(&client, &a, &col_ptrs, &row_indices, &options).unwrap(); + + let b = Tensor::::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4], &device); + let x = sparse_qr_solve_wgpu(&client, &factors, &b).unwrap(); + let x_vals: Vec = x.to_vec(); + + let a_dense: &[&[f32]] = &[ + &[4.0, 1.0, 0.0, 0.0], + &[1.0, 4.0, 1.0, 0.0], + &[0.0, 1.0, 4.0, 1.0], + &[0.0, 0.0, 1.0, 4.0], + ]; + let b_vals = [1.0f32, 2.0, 3.0, 4.0]; + for i in 0..4 { + let mut ax_i: f32 = 0.0; + for j in 0..4 { + ax_i += a_dense[i][j] * x_vals[j]; + } + assert!( + (ax_i - b_vals[i]).abs() < 1e-4, + "A*x[{}] = {}, expected {}", + i, + ax_i, + b_vals[i] + ); + } + } +} diff --git a/src/algorithm/sparse_linalg/qr/wgpu/solve.rs b/src/algorithm/sparse_linalg/qr/wgpu/solve.rs new file mode 100644 index 00000000..ab615b8a --- /dev/null +++ b/src/algorithm/sparse_linalg/qr/wgpu/solve.rs @@ -0,0 +1,490 @@ +//! GPU-resident QR solve for WebGPU (F32 only) +//! +//! Solves A*x = b using precomputed QR factors entirely on GPU. +//! No CPU↔GPU data transfers except final result retrieval by the caller. +//! +//! Steps: +//! 1. Q^T * b: apply Householder reflectors via `apply_reflector` shaders +//! 2. R \ (Q^T b): level-scheduled upper triangular solve on GPU +//! 3. Column permutation: permutation shader with inverse permutation + +#[cfg(feature = "wgpu")] +use crate::algorithm::sparse_linalg::levels::{compute_levels_csc_upper, flatten_levels}; +#[cfg(feature = "wgpu")] +use crate::algorithm::sparse_linalg::qr::cpu::helpers::h_offset; +#[cfg(feature = "wgpu")] +use crate::algorithm::sparse_linalg::qr::types::QrFactors; +#[cfg(feature = "wgpu")] +use crate::dtype::DType; +#[cfg(feature = "wgpu")] +use crate::error::{Error, Result}; +#[cfg(feature = "wgpu")] +use crate::runtime::wgpu::client::get_buffer; +#[cfg(feature = "wgpu")] +use crate::runtime::wgpu::shaders::{LayoutKey, workgroup_count}; +#[cfg(feature = "wgpu")] +use crate::runtime::wgpu::{WgpuClient, WgpuRuntime}; +#[cfg(feature = "wgpu")] +use crate::tensor::Tensor; +#[cfg(feature = "wgpu")] +use wgpu::{BufferDescriptor, BufferUsages}; + +/// Solve A*x = b using precomputed QR factors, fully on GPU (F32 only). +/// +/// Requires `factors.gpu_householder_values` and `factors.gpu_tau` to be populated +/// (they are set automatically by `sparse_qr_wgpu`). +#[cfg(feature = "wgpu")] +pub fn sparse_qr_solve_wgpu( + client: &WgpuClient, + factors: &QrFactors, + b: &Tensor, +) -> Result> { + let [m, n] = factors.r.shape; + let b_shape = b.shape(); + + if b_shape.is_empty() || b_shape[0] != m { + return Err(Error::ShapeMismatch { + expected: vec![m], + got: b_shape.to_vec(), + }); + } + + if factors.rank < n { + return Err(Error::Internal(format!( + "sparse_qr_solve: matrix is rank-deficient (rank {} < n {})", + factors.rank, n + ))); + } + + let dtype = b.dtype(); + if dtype != DType::F32 { + return Err(Error::UnsupportedDType { + dtype, + op: "sparse_qr_solve_wgpu", + }); + } + + let gpu_h = factors.gpu_householder_values.as_ref().ok_or_else(|| { + Error::Internal("sparse_qr_solve_wgpu: GPU Householder vectors not available".to_string()) + })?; + let gpu_tau = factors.gpu_tau.as_ref().ok_or_else(|| { + Error::Internal("sparse_qr_solve_wgpu: GPU tau not available".to_string()) + })?; + + let min_mn = m.min(n); + let device = b.device(); + let wgpu_device = &client.wgpu_device; + let queue = &client.queue; + let cache = &client.pipeline_cache; + let elem_size: u64 = 4; + + let shader_source = include_str!("../../../../runtime/wgpu/shaders/sparse_linalg.wgsl"); + + // Get GPU buffers + let h_buf = get_buffer(gpu_h.ptr()) + .ok_or_else(|| Error::Internal("Invalid h_values buffer".to_string()))?; + let tau_buf = get_buffer(gpu_tau.ptr()) + .ok_or_else(|| Error::Internal("Invalid tau buffer".to_string()))?; + + // Copy b into work buffer (GPU-to-GPU) + let work = b.clone(); + let work_buf = + get_buffer(work.ptr()).ok_or_else(|| Error::Internal("Invalid work buffer".to_string()))?; + + // ======================================================================== + // Step 1: Apply Q^T via Householder reflectors + // ======================================================================== + let make = |name: &str, entry: &str, num_storage: u32, num_readonly: u32| { + let module = cache.get_or_create_module_from_source(name, shader_source); + let layout = cache.get_or_create_layout(LayoutKey { + num_storage_buffers: num_storage, + num_uniform_buffers: 1, + num_readonly_storage: num_readonly, + }); + let pipeline = cache.get_or_create_dynamic_pipeline(name, entry, &module, &layout); + (pipeline, layout) + }; + + let (reflector_pipeline, reflector_layout) = + make("sparse_qr_reflector", "sparse_qr_apply_reflector_f32", 3, 2); + + // Temp buffer for scalar tau value + let tau_scalar_buf = wgpu_device.create_buffer(&BufferDescriptor { + label: Some("qr_solve_tau_scalar"), + size: 4, + usage: BufferUsages::STORAGE | BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + + // Uniform buffer for reflector params + let reflector_params_buf = wgpu_device.create_buffer(&BufferDescriptor { + label: Some("qr_solve_reflector_params"), + size: 8, + usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + + for k in 0..min_mn { + // Copy tau[k] to scalar buffer + let tau_byte_offset = (k as u64) * elem_size; + let mut enc = wgpu_device.create_command_encoder(&Default::default()); + enc.copy_buffer_to_buffer(&tau_buf, tau_byte_offset, &tau_scalar_buf, 0, 4); + queue.submit(std::iter::once(enc.finish())); + + // Copy v sub-range to temp buffer + let v_byte_offset = (h_offset(k, m) as u64) * elem_size; + let v_len = m - k; + let v_byte_len = (v_len as u64) * elem_size; + + let v_temp_buf = wgpu_device.create_buffer(&BufferDescriptor { + label: Some("qr_solve_v_temp"), + size: v_byte_len.max(4), + usage: BufferUsages::STORAGE | BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + let mut enc = wgpu_device.create_command_encoder(&Default::default()); + enc.copy_buffer_to_buffer(&h_buf, v_byte_offset, &v_temp_buf, 0, v_byte_len); + queue.submit(std::iter::once(enc.finish())); + + #[repr(C)] + #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)] + struct ReflectorParams { + v_start: u32, + v_len: u32, + } + queue.write_buffer( + &reflector_params_buf, + 0, + bytemuck::bytes_of(&ReflectorParams { + v_start: k as u32, + v_len: v_len as u32, + }), + ); + + let bg = wgpu_device.create_bind_group(&wgpu::BindGroupDescriptor { + label: Some("qr_solve_reflector_bg"), + layout: &reflector_layout, + entries: &[ + wgpu::BindGroupEntry { + binding: 0, + resource: v_temp_buf.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 1, + resource: tau_scalar_buf.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 2, + resource: work_buf.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 3, + resource: reflector_params_buf.as_entire_binding(), + }, + ], + }); + let mut enc = wgpu_device.create_command_encoder(&Default::default()); + { + let mut pass = enc.begin_compute_pass(&Default::default()); + pass.set_pipeline(&reflector_pipeline); + pass.set_bind_group(0, Some(&bg), &[]); + pass.dispatch_workgroups(1, 1, 1); + } + queue.submit(std::iter::once(enc.finish())); + } + + // ======================================================================== + // Step 2: Upper triangular solve R * x = (Q^T b)[0:n] + // ======================================================================== + let r_col_ptrs: Vec = factors.r.col_ptrs().to_vec(); + let r_row_indices: Vec = factors.r.row_indices().to_vec(); + + let u_schedule = compute_levels_csc_upper(n, &r_col_ptrs, &r_row_indices)?; + let (u_level_ptrs, u_level_cols) = flatten_levels(&u_schedule); + + // Upload structure to GPU + let r_col_ptrs_i32: Vec = r_col_ptrs.iter().map(|&x| x as i32).collect(); + let r_row_indices_i32: Vec = r_row_indices.iter().map(|&x| x as i32).collect(); + let r_col_ptrs_gpu = + Tensor::::from_slice(&r_col_ptrs_i32, &[r_col_ptrs_i32.len()], &device); + let r_row_indices_gpu = + Tensor::::from_slice(&r_row_indices_i32, &[r_row_indices_i32.len()], &device); + let u_level_cols_gpu = + Tensor::::from_slice(&u_level_cols, &[u_level_cols.len()], &device); + + let r_col_ptrs_buf = get_buffer(r_col_ptrs_gpu.ptr()) + .ok_or_else(|| Error::Internal("Invalid r_col_ptrs buffer".to_string()))?; + let r_row_indices_buf = get_buffer(r_row_indices_gpu.ptr()) + .ok_or_else(|| Error::Internal("Invalid r_row_indices buffer".to_string()))?; + let u_level_cols_buf = get_buffer(u_level_cols_gpu.ptr()) + .ok_or_else(|| Error::Internal("Invalid u_level_cols buffer".to_string()))?; + let r_values_buf = get_buffer(factors.r.values().ptr()) + .ok_or_else(|| Error::Internal("Invalid r_values buffer".to_string()))?; + + // Find diagonal indices + let u_diag_gpu = Tensor::::zeros(&[n], DType::I32, &device); + let u_diag_buf = get_buffer(u_diag_gpu.ptr()) + .ok_or_else(|| Error::Internal("Invalid u_diag buffer".to_string()))?; + + let find_diag_module = + cache.get_or_create_module_from_source("sparse_find_diag_csc", shader_source); + let find_diag_layout = cache.get_or_create_layout(LayoutKey { + num_storage_buffers: 3, + num_uniform_buffers: 1, + num_readonly_storage: 2, + }); + let find_diag_pipeline = cache.get_or_create_dynamic_pipeline( + "sparse_find_diag_csc", + "find_diag_indices_csc_f32", + &find_diag_module, + &find_diag_layout, + ); + + #[repr(C)] + #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)] + struct FindDiagParams { + n: u32, + _p1: u32, + _p2: u32, + _p3: u32, + } + + let find_diag_params_buf = wgpu_device.create_buffer(&BufferDescriptor { + label: Some("qr_solve_find_diag_params"), + size: 16, + usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + queue.write_buffer( + &find_diag_params_buf, + 0, + bytemuck::bytes_of(&FindDiagParams { + n: n as u32, + _p1: 0, + _p2: 0, + _p3: 0, + }), + ); + + { + let bg = wgpu_device.create_bind_group(&wgpu::BindGroupDescriptor { + label: Some("qr_solve_find_diag_bg"), + layout: &find_diag_layout, + entries: &[ + wgpu::BindGroupEntry { + binding: 0, + resource: r_col_ptrs_buf.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 1, + resource: r_row_indices_buf.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 2, + resource: u_diag_buf.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 3, + resource: find_diag_params_buf.as_entire_binding(), + }, + ], + }); + let mut enc = wgpu_device.create_command_encoder(&Default::default()); + { + let mut pass = enc.begin_compute_pass(&Default::default()); + pass.set_pipeline(&find_diag_pipeline); + pass.set_bind_group(0, Some(&bg), &[]); + pass.dispatch_workgroups(workgroup_count(n), 1, 1); + } + queue.submit(std::iter::once(enc.finish())); + } + + // Level-scheduled upper triangular solve + let upper_module = + cache.get_or_create_module_from_source("sparse_trsv_csc_upper", shader_source); + let upper_layout = cache.get_or_create_layout(LayoutKey { + num_storage_buffers: 6, + num_uniform_buffers: 1, + num_readonly_storage: 5, + }); + let upper_pipeline = cache.get_or_create_dynamic_pipeline( + "sparse_trsv_csc_upper", + "sparse_trsv_csc_upper_level_f32", + &upper_module, + &upper_layout, + ); + + #[repr(C)] + #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)] + struct TrsvParams { + level_offset: u32, + level_size: u32, + n: u32, + _pad: u32, + } + + let trsv_params_buf = wgpu_device.create_buffer(&BufferDescriptor { + label: Some("qr_solve_trsv_params"), + size: 16, + usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + + for level in 0..u_level_ptrs.len().saturating_sub(1) { + let level_start = u_level_ptrs[level] as u32; + let level_end = u_level_ptrs[level + 1] as u32; + let level_size = level_end - level_start; + if level_size == 0 { + continue; + } + + queue.write_buffer( + &trsv_params_buf, + 0, + bytemuck::bytes_of(&TrsvParams { + level_offset: level_start, + level_size, + n: n as u32, + _pad: 0, + }), + ); + + let bg = wgpu_device.create_bind_group(&wgpu::BindGroupDescriptor { + label: Some("qr_solve_trsv_bg"), + layout: &upper_layout, + entries: &[ + wgpu::BindGroupEntry { + binding: 0, + resource: u_level_cols_buf.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 1, + resource: r_col_ptrs_buf.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 2, + resource: r_row_indices_buf.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 3, + resource: r_values_buf.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 4, + resource: u_diag_buf.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 5, + resource: work_buf.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 6, + resource: trsv_params_buf.as_entire_binding(), + }, + ], + }); + let mut enc = wgpu_device.create_command_encoder(&Default::default()); + { + let mut pass = enc.begin_compute_pass(&Default::default()); + pass.set_pipeline(&upper_pipeline); + pass.set_bind_group(0, Some(&bg), &[]); + pass.dispatch_workgroups(workgroup_count(level_size as usize), 1, 1); + } + queue.submit(std::iter::once(enc.finish())); + } + + // ======================================================================== + // Step 3: Apply column permutation + // ======================================================================== + let mut inv_perm = vec![0i32; n]; + for (k, &orig_col) in factors.col_perm.iter().enumerate() { + inv_perm[orig_col] = k as i32; + } + let inv_perm_gpu = Tensor::::from_slice(&inv_perm, &[n], &device); + let inv_perm_buf = get_buffer(inv_perm_gpu.ptr()) + .ok_or_else(|| Error::Internal("Invalid inv_perm buffer".to_string()))?; + + let result = Tensor::::zeros(&[n], dtype, &device); + let result_buf = get_buffer(result.ptr()) + .ok_or_else(|| Error::Internal("Invalid result buffer".to_string()))?; + + let perm_module = cache.get_or_create_module_from_source("sparse_apply_perm", shader_source); + let perm_layout = cache.get_or_create_layout(LayoutKey { + num_storage_buffers: 3, + num_uniform_buffers: 1, + num_readonly_storage: 2, + }); + let perm_pipeline = cache.get_or_create_dynamic_pipeline( + "sparse_apply_perm", + "apply_row_perm_f32", + &perm_module, + &perm_layout, + ); + + #[repr(C)] + #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)] + struct PermParams { + n: u32, + _p1: u32, + _p2: u32, + _p3: u32, + } + + let perm_params_buf = wgpu_device.create_buffer(&BufferDescriptor { + label: Some("qr_solve_perm_params"), + size: 16, + usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + queue.write_buffer( + &perm_params_buf, + 0, + bytemuck::bytes_of(&PermParams { + n: n as u32, + _p1: 0, + _p2: 0, + _p3: 0, + }), + ); + + { + let bg = wgpu_device.create_bind_group(&wgpu::BindGroupDescriptor { + label: Some("qr_solve_perm_bg"), + layout: &perm_layout, + entries: &[ + wgpu::BindGroupEntry { + binding: 0, + resource: work_buf.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 1, + resource: inv_perm_buf.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 2, + resource: result_buf.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 3, + resource: perm_params_buf.as_entire_binding(), + }, + ], + }); + let mut enc = wgpu_device.create_command_encoder(&Default::default()); + { + let mut pass = enc.begin_compute_pass(&Default::default()); + pass.set_pipeline(&perm_pipeline); + pass.set_bind_group(0, Some(&bg), &[]); + pass.dispatch_workgroups(workgroup_count(n), 1, 1); + } + queue.submit(std::iter::once(enc.finish())); + } + + // Wait for completion + let _ = wgpu_device.poll(wgpu::PollType::Wait { + submission_index: None, + timeout: Some(std::time::Duration::from_secs(60)), + }); + + Ok(result) +} diff --git a/src/algorithm/special/constants.rs b/src/algorithm/special/constants.rs new file mode 100644 index 00000000..a59e2468 --- /dev/null +++ b/src/algorithm/special/constants.rs @@ -0,0 +1,37 @@ +//! Mathematical constants and Lanczos coefficients used by special functions. + +// ============================================================================ +// Mathematical Constants +// ============================================================================ + +/// Square root of pi: √π ≈ 1.7724538509055159 +pub const SQRT_PI: f64 = 1.7724538509055160272981674833411451827975; + +/// 2 / √π ≈ 1.1283791670955126 (used in erf) +pub const TWO_OVER_SQRT_PI: f64 = std::f64::consts::FRAC_2_SQRT_PI; + +/// Euler-Mascheroni constant: γ ≈ 0.5772156649015329 +pub const EULER_MASCHERONI: f64 = 0.5772156649015328606065120900824024310422; + +/// ln(√(2π)) ≈ 0.9189385332046727 (used in Stirling's approximation) +pub const LN_SQRT_2PI: f64 = 0.9189385332046727417803297364056176398614; + +// ============================================================================ +// Lanczos Coefficients for Gamma Function +// ============================================================================ + +/// Lanczos approximation coefficients (g=7, n=9). +pub const LANCZOS_G: f64 = 7.0; + +/// Lanczos coefficients for g=7. +pub const LANCZOS_COEFFICIENTS: [f64; 9] = [ + 0.999_999_999_999_809_9, + 676.520_368_121_885_1, + -1_259.139_216_722_402_8, + 771.323_428_777_653_1, + -176.615_029_162_140_6, + 12.507_343_278_686_905, + -0.138_571_095_265_720_12, + 9.984_369_578_019_572e-6, + 1.505_632_735_149_311_6e-7, +]; diff --git a/src/algorithm/special/helpers.rs b/src/algorithm/special/helpers.rs new file mode 100644 index 00000000..d90d2c31 --- /dev/null +++ b/src/algorithm/special/helpers.rs @@ -0,0 +1,17 @@ +//! Validation helpers for special mathematical functions. + +use crate::dtype::DType; +use crate::error::{Error, Result}; + +/// Validate that dtype is suitable for special functions. +pub fn validate_special_dtype(dtype: DType) -> Result<()> { + match dtype { + DType::F32 | DType::F64 | DType::F16 | DType::BF16 | DType::FP8E4M3 | DType::FP8E5M2 => { + Ok(()) + } + _ => Err(Error::UnsupportedDType { + dtype, + op: "special function", + }), + } +} diff --git a/src/algorithm/special/mod.rs b/src/algorithm/special/mod.rs index b8779211..6a2dface 100644 --- a/src/algorithm/special/mod.rs +++ b/src/algorithm/special/mod.rs @@ -1,627 +1,17 @@ -//! Special mathematical functions for scientific computing +//! Special mathematical functions for scientific computing. //! -//! This module defines traits for special functions required by probability -//! distributions, statistics, and scientific applications. These are critical -//! for solvr::stats to implement distributions like normal, gamma, beta, etc. -//! -//! # Functions Provided -//! -//! ## Error Functions (for normal distribution) -//! - `erf` - Error function -//! - `erfc` - Complementary error function (1 - erf(x)) -//! - `erfinv` - Inverse error function -//! -//! ## Gamma Functions (for gamma, chi2, t, F distributions) -//! - `gamma` - Gamma function Γ(x) -//! - `lgamma` - Log-gamma function ln(Γ(x)) (numerically stable) -//! - `digamma` - Digamma function ψ(x) = Γ'(x)/Γ(x) -//! -//! ## Beta Functions (for beta distribution) -//! - `beta` - Beta function B(a,b) = Γ(a)Γ(b)/Γ(a+b) -//! - `betainc` - Regularized incomplete beta function I_x(a,b) -//! -//! ## Incomplete Gamma (for gamma/chi2 CDF) -//! - `gammainc` - Lower regularized incomplete gamma P(a,x) -//! - `gammaincc` - Upper regularized incomplete gamma Q(a,x) = 1 - P(a,x) -//! -//! ## Bessel Functions -//! - `bessel_j0`, `bessel_j1` - First kind J₀, J₁ -//! - `bessel_y0`, `bessel_y1` - Second kind Y₀, Y₁ -//! - `bessel_i0`, `bessel_i1` - Modified first kind I₀, I₁ -//! - `bessel_k0`, `bessel_k1` - Modified second kind K₀, K₁ -//! -//! ## Elliptic Integrals -//! - `ellipk` - Complete elliptic integral of first kind K(m) -//! - `ellipe` - Complete elliptic integral of second kind E(m) -//! -//! ## Hypergeometric Functions -//! - `hyp2f1` - Gauss hypergeometric function ₂F₁(a, b; c; z) -//! - `hyp1f1` - Confluent hypergeometric function ₁F₁(a; b; z) -//! -//! ## Airy Functions -//! - `airy_ai` - Airy function of first kind Ai(x) -//! - `airy_bi` - Airy function of second kind Bi(x) -//! -//! ## Legendre Functions and Spherical Harmonics -//! - `legendre_p` - Legendre polynomial P_n(x) -//! - `legendre_p_assoc` - Associated Legendre function P_n^m(x) -//! - `sph_harm` - Real spherical harmonic Y_n^m(θ, φ) -//! -//! ## Fresnel Integrals -//! - `fresnel_s` - Fresnel sine integral S(x) -//! - `fresnel_c` - Fresnel cosine integral C(x) -//! -//! # Algorithm Sources -//! -//! Implementations follow well-established numerical algorithms: -//! - Cody's rational approximation for erf/erfc -//! - Lanczos approximation for gamma/lgamma -//! - Continued fraction expansion for incomplete gamma/beta -//! - Newton-Raphson iteration for inverse functions -//! - Numerical Recipes polynomial approximations for Bessel functions -//! - AGM method for elliptic integrals -//! - Power series with transformations for hypergeometric functions -//! - Power series and asymptotic expansions for Airy functions -//! - Three-term recurrence for Legendre polynomials +//! See [`traits`] for the `SpecialFunctions` trait, [`constants`] for +//! mathematical constants, and [`helpers`] for validation utilities. pub mod bessel_coefficients; +pub mod constants; +pub mod helpers; pub mod scalar; +pub mod traits; +pub use constants::{ + EULER_MASCHERONI, LANCZOS_COEFFICIENTS, LANCZOS_G, LN_SQRT_2PI, SQRT_PI, TWO_OVER_SQRT_PI, +}; +pub use helpers::validate_special_dtype; pub use scalar::*; - -use crate::error::{Error, Result}; -use crate::runtime::Runtime; -use crate::tensor::Tensor; - -// ============================================================================ -// Special Functions Trait -// ============================================================================ - -/// Special mathematical functions for scientific computing. -/// -/// All backends must implement these functions to enable solvr probability -/// distributions and statistical functions. -/// -/// # Implementation Notes -/// -/// - Functions operate element-wise on tensors -/// - Input validation (domain checks) should return appropriate errors -/// - Numerical stability is critical - use established algorithms -/// - GPU implementations can use the same algorithms as CPU -pub trait SpecialFunctions { - // ======================================================================== - // Error Functions - // ======================================================================== - - /// Compute the error function element-wise. - /// - /// ```text - /// erf(x) = (2/√π) ∫₀ˣ e^(-t²) dt - /// ``` - /// - /// # Properties - /// - Domain: all real numbers - /// - Range: (-1, 1) - /// - erf(0) = 0 - /// - erf(∞) = 1, erf(-∞) = -1 - /// - erf(-x) = -erf(x) (odd function) - fn erf(&self, x: &Tensor) -> Result> { - let _ = x; - Err(Error::NotImplemented { - feature: "SpecialFunctions::erf", - }) - } - - /// Compute the complementary error function element-wise. - /// - /// ```text - /// erfc(x) = 1 - erf(x) = (2/√π) ∫ₓ^∞ e^(-t²) dt - /// ``` - /// - /// For large x, erf(x) ≈ 1 and computing 1 - erf(x) loses precision. - /// erfc(x) computes the small tail directly, maintaining accuracy. - fn erfc(&self, x: &Tensor) -> Result> { - let _ = x; - Err(Error::NotImplemented { - feature: "SpecialFunctions::erfc", - }) - } - - /// Compute the inverse error function element-wise. - /// - /// Returns y such that erf(y) = x. - /// - /// # Properties - /// - Domain: (-1, 1) - /// - Range: all real numbers - /// - erfinv(0) = 0 - fn erfinv(&self, x: &Tensor) -> Result> { - let _ = x; - Err(Error::NotImplemented { - feature: "SpecialFunctions::erfinv", - }) - } - - // ======================================================================== - // Gamma Functions - // ======================================================================== - - /// Compute the gamma function element-wise. - /// - /// ```text - /// Γ(x) = ∫₀^∞ t^(x-1) e^(-t) dt - /// ``` - /// - /// # Properties - /// - Γ(n) = (n-1)! for positive integers - /// - Γ(1) = 1, Γ(1/2) = √π - /// - Has poles at non-positive integers (returns NaN/Inf) - fn gamma(&self, x: &Tensor) -> Result> { - let _ = x; - Err(Error::NotImplemented { - feature: "SpecialFunctions::gamma", - }) - } - - /// Compute the log-gamma function element-wise. - /// - /// ```text - /// lgamma(x) = ln(|Γ(x)|) - /// ``` - /// - /// Γ(x) grows extremely fast (Γ(171) overflows F64). - /// lgamma computes the logarithm directly without overflow. - fn lgamma(&self, x: &Tensor) -> Result> { - let _ = x; - Err(Error::NotImplemented { - feature: "SpecialFunctions::lgamma", - }) - } - - /// Compute the digamma (psi) function element-wise. - /// - /// ```text - /// ψ(x) = d/dx ln(Γ(x)) = Γ'(x)/Γ(x) - /// ``` - fn digamma(&self, x: &Tensor) -> Result> { - let _ = x; - Err(Error::NotImplemented { - feature: "SpecialFunctions::digamma", - }) - } - - // ======================================================================== - // Beta Functions - // ======================================================================== - - /// Compute the beta function element-wise. - /// - /// ```text - /// B(a, b) = Γ(a)Γ(b)/Γ(a+b) - /// ``` - fn beta(&self, a: &Tensor, b: &Tensor) -> Result> { - let _ = (a, b); - Err(Error::NotImplemented { - feature: "SpecialFunctions::beta", - }) - } - - /// Compute the regularized incomplete beta function element-wise. - /// - /// ```text - /// I_x(a,b) = B(x;a,b)/B(a,b) = (1/B(a,b)) ∫₀ˣ t^(a-1)(1-t)^(b-1) dt - /// ``` - fn betainc(&self, a: &Tensor, b: &Tensor, x: &Tensor) -> Result> { - let _ = (a, b, x); - Err(Error::NotImplemented { - feature: "SpecialFunctions::betainc", - }) - } - - // ======================================================================== - // Incomplete Gamma Functions - // ======================================================================== - - /// Compute the lower regularized incomplete gamma function. - /// - /// ```text - /// P(a, x) = γ(a,x)/Γ(a) = (1/Γ(a)) ∫₀ˣ t^(a-1) e^(-t) dt - /// ``` - fn gammainc(&self, a: &Tensor, x: &Tensor) -> Result> { - let _ = (a, x); - Err(Error::NotImplemented { - feature: "SpecialFunctions::gammainc", - }) - } - - /// Compute the upper regularized incomplete gamma function. - /// - /// ```text - /// Q(a, x) = 1 - P(a, x) - /// ``` - fn gammaincc(&self, a: &Tensor, x: &Tensor) -> Result> { - let _ = (a, x); - Err(Error::NotImplemented { - feature: "SpecialFunctions::gammaincc", - }) - } - - /// Compute the inverse of the lower regularized incomplete gamma function. - /// - /// Returns x such that P(a, x) = p. - /// - /// # Properties - /// - Domain: p in [0, 1], a > 0 - /// - Range: x >= 0 - /// - gammaincinv(a, 0) = 0 - /// - gammaincinv(a, 1) = ∞ - fn gammaincinv(&self, a: &Tensor, p: &Tensor) -> Result> { - let _ = (a, p); - Err(Error::NotImplemented { - feature: "SpecialFunctions::gammaincinv", - }) - } - - /// Compute the inverse of the regularized incomplete beta function. - /// - /// Returns x such that I_x(a, b) = p. - /// - /// # Properties - /// - Domain: p in [0, 1], a > 0, b > 0 - /// - Range: x in [0, 1] - /// - betaincinv(a, b, 0) = 0 - /// - betaincinv(a, b, 1) = 1 - fn betaincinv(&self, a: &Tensor, b: &Tensor, p: &Tensor) -> Result> { - let _ = (a, b, p); - Err(Error::NotImplemented { - feature: "SpecialFunctions::betaincinv", - }) - } - - // ======================================================================== - // Bessel Functions - // ======================================================================== - - /// Compute Bessel function of the first kind, order 0. - /// - /// J₀(0) = 1, even function, oscillates with decreasing amplitude. - fn bessel_j0(&self, x: &Tensor) -> Result> { - let _ = x; - Err(Error::NotImplemented { - feature: "SpecialFunctions::bessel_j0", - }) - } - - /// Compute Bessel function of the first kind, order 1. - /// - /// J₁(0) = 0, odd function, oscillates with decreasing amplitude. - fn bessel_j1(&self, x: &Tensor) -> Result> { - let _ = x; - Err(Error::NotImplemented { - feature: "SpecialFunctions::bessel_j1", - }) - } - - /// Compute Bessel function of the second kind, order 0 (Neumann function). - /// - /// Y₀(x) → -∞ as x → 0⁺. Domain: x > 0. - fn bessel_y0(&self, x: &Tensor) -> Result> { - let _ = x; - Err(Error::NotImplemented { - feature: "SpecialFunctions::bessel_y0", - }) - } - - /// Compute Bessel function of the second kind, order 1 (Neumann function). - /// - /// Y₁(x) → -∞ as x → 0⁺. Domain: x > 0. - fn bessel_y1(&self, x: &Tensor) -> Result> { - let _ = x; - Err(Error::NotImplemented { - feature: "SpecialFunctions::bessel_y1", - }) - } - - /// Compute modified Bessel function of the first kind, order 0. - /// - /// I₀(0) = 1, even function, grows exponentially. - fn bessel_i0(&self, x: &Tensor) -> Result> { - let _ = x; - Err(Error::NotImplemented { - feature: "SpecialFunctions::bessel_i0", - }) - } - - /// Compute modified Bessel function of the first kind, order 1. - /// - /// I₁(0) = 0, odd function, grows exponentially. - fn bessel_i1(&self, x: &Tensor) -> Result> { - let _ = x; - Err(Error::NotImplemented { - feature: "SpecialFunctions::bessel_i1", - }) - } - - /// Compute modified Bessel function of the second kind, order 0. - /// - /// K₀(x) → ∞ as x → 0⁺. Domain: x > 0. Decays exponentially. - fn bessel_k0(&self, x: &Tensor) -> Result> { - let _ = x; - Err(Error::NotImplemented { - feature: "SpecialFunctions::bessel_k0", - }) - } - - /// Compute modified Bessel function of the second kind, order 1. - /// - /// K₁(x) → ∞ as x → 0⁺. Domain: x > 0. Decays exponentially. - fn bessel_k1(&self, x: &Tensor) -> Result> { - let _ = x; - Err(Error::NotImplemented { - feature: "SpecialFunctions::bessel_k1", - }) - } - - // ======================================================================== - // Elliptic Integrals - // ======================================================================== - - /// Compute the complete elliptic integral of the first kind K(m). - /// - /// ```text - /// K(m) = ∫₀^(π/2) dθ / √(1 - m·sin²θ) - /// ``` - /// - /// # Properties - /// - Domain: m ∈ [0, 1) - /// - K(0) = π/2 - /// - K(m) → ∞ as m → 1 - /// - Uses parameter convention m = k², where k is the modulus - fn ellipk(&self, m: &Tensor) -> Result> { - let _ = m; - Err(Error::NotImplemented { - feature: "SpecialFunctions::ellipk", - }) - } - - /// Compute the complete elliptic integral of the second kind E(m). - /// - /// ```text - /// E(m) = ∫₀^(π/2) √(1 - m·sin²θ) dθ - /// ``` - /// - /// # Properties - /// - Domain: m ∈ [0, 1] - /// - E(0) = π/2 - /// - E(1) = 1 - fn ellipe(&self, m: &Tensor) -> Result> { - let _ = m; - Err(Error::NotImplemented { - feature: "SpecialFunctions::ellipe", - }) - } - - // ======================================================================== - // Hypergeometric Functions - // ======================================================================== - - /// Compute the Gauss hypergeometric function ₂F₁(a, b; c; z). - /// - /// ```text - /// ₂F₁(a, b; c; z) = Σ_{n=0}^∞ (a)_n (b)_n / ((c)_n n!) z^n - /// ``` - /// - /// # Properties - /// - Converges for |z| < 1 - /// - ₂F₁(a, b; c; 0) = 1 - /// - /// # Arguments - /// - a, b, c: Scalar parameters - /// - z: Input tensor - fn hyp2f1(&self, a: f64, b: f64, c: f64, z: &Tensor) -> Result> { - let _ = (a, b, c, z); - Err(Error::NotImplemented { - feature: "SpecialFunctions::hyp2f1", - }) - } - - /// Compute the confluent hypergeometric function ₁F₁(a; b; z) (Kummer's M). - /// - /// ```text - /// ₁F₁(a; b; z) = M(a, b, z) = Σ_{n=0}^∞ (a)_n / ((b)_n n!) z^n - /// ``` - /// - /// # Properties - /// - ₁F₁(a; b; 0) = 1 - /// - ₁F₁(0; b; z) = 1 - /// - Entire function in z - fn hyp1f1(&self, a: f64, b: f64, z: &Tensor) -> Result> { - let _ = (a, b, z); - Err(Error::NotImplemented { - feature: "SpecialFunctions::hyp1f1", - }) - } - - // ======================================================================== - // Airy Functions - // ======================================================================== - - /// Compute the Airy function of the first kind Ai(x). - /// - /// ```text - /// Ai(x) is the solution of y'' - xy = 0 that decays as x → +∞ - /// ``` - /// - /// # Properties - /// - Ai(x) → 0 as x → +∞ (exponentially) - /// - Ai(x) oscillates for x < 0 - /// - Ai(0) ≈ 0.3550280538878172 - fn airy_ai(&self, x: &Tensor) -> Result> { - let _ = x; - Err(Error::NotImplemented { - feature: "SpecialFunctions::airy_ai", - }) - } - - /// Compute the Airy function of the second kind Bi(x). - /// - /// ```text - /// Bi(x) is the solution of y'' - xy = 0 that grows as x → +∞ - /// ``` - /// - /// # Properties - /// - Bi(x) → +∞ as x → +∞ (exponentially) - /// - Bi(x) oscillates for x < 0 - /// - Bi(0) ≈ 0.6149266274460007 - fn airy_bi(&self, x: &Tensor) -> Result> { - let _ = x; - Err(Error::NotImplemented { - feature: "SpecialFunctions::airy_bi", - }) - } - - // ======================================================================== - // Legendre Functions - // ======================================================================== - - /// Compute the Legendre polynomial P_n(x). - /// - /// # Properties - /// - Domain: x ∈ [-1, 1] - /// - P_n(1) = 1 - /// - P_n(-1) = (-1)^n - /// - P_0(x) = 1, P_1(x) = x - fn legendre_p(&self, n: i32, x: &Tensor) -> Result> { - let _ = (n, x); - Err(Error::NotImplemented { - feature: "SpecialFunctions::legendre_p", - }) - } - - /// Compute the associated Legendre function P_n^m(x). - /// - /// Uses Condon-Shortley phase convention (factor of (-1)^m). - /// - /// # Properties - /// - Domain: x ∈ [-1, 1], 0 ≤ m ≤ n - /// - P_n^0(x) = P_n(x) - fn legendre_p_assoc(&self, n: i32, m: i32, x: &Tensor) -> Result> { - let _ = (n, m, x); - Err(Error::NotImplemented { - feature: "SpecialFunctions::legendre_p_assoc", - }) - } - - /// Compute the real spherical harmonic Y_n^m(θ, φ). - /// - /// Returns the real-valued spherical harmonic with Schmidt semi-normalization. - /// - m > 0: Y_n^m ∝ P_n^m(cos θ) cos(mφ) - /// - m = 0: Y_n^0 ∝ P_n(cos θ) - /// - m < 0: Y_n^m ∝ P_n^|m|(cos θ) sin(|m|φ) - /// - /// # Arguments - /// - n: degree (n ≥ 0) - /// - m: order (-n ≤ m ≤ n) - /// - theta: polar angle θ ∈ [0, π] (colatitude) - /// - phi: azimuthal angle φ ∈ [0, 2π) - fn sph_harm(&self, n: i32, m: i32, theta: &Tensor, phi: &Tensor) -> Result> { - let _ = (n, m, theta, phi); - Err(Error::NotImplemented { - feature: "SpecialFunctions::sph_harm", - }) - } - - // ======================================================================== - // Fresnel Integrals - // ======================================================================== - - /// Compute the Fresnel sine integral S(x). - /// - /// ```text - /// S(x) = ∫₀ˣ sin(π t²/2) dt - /// ``` - /// - /// # Properties - /// - S(0) = 0 - /// - S(∞) = 0.5 - /// - S(-x) = -S(x) (odd function) - fn fresnel_s(&self, x: &Tensor) -> Result> { - let _ = x; - Err(Error::NotImplemented { - feature: "SpecialFunctions::fresnel_s", - }) - } - - /// Compute the Fresnel cosine integral C(x). - /// - /// ```text - /// C(x) = ∫₀ˣ cos(π t²/2) dt - /// ``` - /// - /// # Properties - /// - C(0) = 0 - /// - C(∞) = 0.5 - /// - C(-x) = -C(x) (odd function) - fn fresnel_c(&self, x: &Tensor) -> Result> { - let _ = x; - Err(Error::NotImplemented { - feature: "SpecialFunctions::fresnel_c", - }) - } -} - -// ============================================================================ -// Validation Helpers -// ============================================================================ - -/// Validate that dtype is suitable for special functions. -pub fn validate_special_dtype(dtype: crate::dtype::DType) -> Result<()> { - use crate::dtype::DType; - use crate::error::Error; - - match dtype { - DType::F32 | DType::F64 | DType::F16 | DType::BF16 | DType::FP8E4M3 | DType::FP8E5M2 => { - Ok(()) - } - _ => Err(Error::UnsupportedDType { - dtype, - op: "special function", - }), - } -} - -// ============================================================================ -// Mathematical Constants -// ============================================================================ - -/// Square root of pi: √π ≈ 1.7724538509055159 -pub const SQRT_PI: f64 = 1.7724538509055160272981674833411451827975; - -/// 2 / √π ≈ 1.1283791670955126 (used in erf) -pub const TWO_OVER_SQRT_PI: f64 = std::f64::consts::FRAC_2_SQRT_PI; - -/// Euler-Mascheroni constant: γ ≈ 0.5772156649015329 -pub const EULER_MASCHERONI: f64 = 0.5772156649015328606065120900824024310422; - -/// ln(√(2π)) ≈ 0.9189385332046727 (used in Stirling's approximation) -pub const LN_SQRT_2PI: f64 = 0.9189385332046727417803297364056176398614; - -// ============================================================================ -// Lanczos Coefficients for Gamma Function -// ============================================================================ - -/// Lanczos approximation coefficients (g=7, n=9). -pub const LANCZOS_G: f64 = 7.0; - -/// Lanczos coefficients for g=7. -pub const LANCZOS_COEFFICIENTS: [f64; 9] = [ - 0.999_999_999_999_809_9, - 676.520_368_121_885_1, - -1_259.139_216_722_402_8, - 771.323_428_777_653_1, - -176.615_029_162_140_6, - 12.507_343_278_686_905, - -0.138_571_095_265_720_12, - 9.984_369_578_019_572e-6, - 1.505_632_735_149_311_6e-7, -]; +pub use traits::SpecialFunctions; diff --git a/src/algorithm/special/traits.rs b/src/algorithm/special/traits.rs new file mode 100644 index 00000000..f5c85809 --- /dev/null +++ b/src/algorithm/special/traits.rs @@ -0,0 +1,566 @@ +//! Special mathematical functions trait for scientific computing. +//! +//! Defines the `SpecialFunctions` trait required by probability distributions, +//! statistics, and scientific applications. These are critical for +//! solvr::stats to implement distributions like normal, gamma, beta, etc. +//! +//! # Functions Provided +//! +//! ## Error Functions (for normal distribution) +//! - `erf` - Error function +//! - `erfc` - Complementary error function (1 - erf(x)) +//! - `erfinv` - Inverse error function +//! +//! ## Gamma Functions (for gamma, chi2, t, F distributions) +//! - `gamma` - Gamma function Γ(x) +//! - `lgamma` - Log-gamma function ln(Γ(x)) (numerically stable) +//! - `digamma` - Digamma function ψ(x) = Γ'(x)/Γ(x) +//! +//! ## Beta Functions (for beta distribution) +//! - `beta` - Beta function B(a,b) = Γ(a)Γ(b)/Γ(a+b) +//! - `betainc` - Regularized incomplete beta function I_x(a,b) +//! +//! ## Incomplete Gamma (for gamma/chi2 CDF) +//! - `gammainc` - Lower regularized incomplete gamma P(a,x) +//! - `gammaincc` - Upper regularized incomplete gamma Q(a,x) = 1 - P(a,x) +//! +//! ## Bessel Functions +//! - `bessel_j0`, `bessel_j1` - First kind J₀, J₁ +//! - `bessel_y0`, `bessel_y1` - Second kind Y₀, Y₁ +//! - `bessel_i0`, `bessel_i1` - Modified first kind I₀, I₁ +//! - `bessel_k0`, `bessel_k1` - Modified second kind K₀, K₁ +//! +//! ## Elliptic Integrals +//! - `ellipk` - Complete elliptic integral of first kind K(m) +//! - `ellipe` - Complete elliptic integral of second kind E(m) +//! +//! ## Hypergeometric Functions +//! - `hyp2f1` - Gauss hypergeometric function ₂F₁(a, b; c; z) +//! - `hyp1f1` - Confluent hypergeometric function ₁F₁(a; b; z) +//! +//! ## Airy Functions +//! - `airy_ai` - Airy function of first kind Ai(x) +//! - `airy_bi` - Airy function of second kind Bi(x) +//! +//! ## Legendre Functions and Spherical Harmonics +//! - `legendre_p` - Legendre polynomial P_n(x) +//! - `legendre_p_assoc` - Associated Legendre function P_n^m(x) +//! - `sph_harm` - Real spherical harmonic Y_n^m(θ, φ) +//! +//! ## Fresnel Integrals +//! - `fresnel_s` - Fresnel sine integral S(x) +//! - `fresnel_c` - Fresnel cosine integral C(x) +//! +//! # Algorithm Sources +//! +//! Implementations follow well-established numerical algorithms: +//! - Cody's rational approximation for erf/erfc +//! - Lanczos approximation for gamma/lgamma +//! - Continued fraction expansion for incomplete gamma/beta +//! - Newton-Raphson iteration for inverse functions +//! - Numerical Recipes polynomial approximations for Bessel functions +//! - AGM method for elliptic integrals +//! - Power series with transformations for hypergeometric functions +//! - Power series and asymptotic expansions for Airy functions +//! - Three-term recurrence for Legendre polynomials + +use crate::error::{Error, Result}; +use crate::runtime::Runtime; +use crate::tensor::Tensor; + +// ============================================================================ +// Special Functions Trait +// ============================================================================ + +/// Special mathematical functions for scientific computing. +/// +/// All backends must implement these functions to enable solvr probability +/// distributions and statistical functions. +/// +/// # Implementation Notes +/// +/// - Functions operate element-wise on tensors +/// - Input validation (domain checks) should return appropriate errors +/// - Numerical stability is critical - use established algorithms +/// - GPU implementations can use the same algorithms as CPU +pub trait SpecialFunctions { + // ======================================================================== + // Error Functions + // ======================================================================== + + /// Compute the error function element-wise. + /// + /// ```text + /// erf(x) = (2/√π) ∫₀ˣ e^(-t²) dt + /// ``` + /// + /// # Properties + /// - Domain: all real numbers + /// - Range: (-1, 1) + /// - erf(0) = 0 + /// - erf(∞) = 1, erf(-∞) = -1 + /// - erf(-x) = -erf(x) (odd function) + fn erf(&self, x: &Tensor) -> Result> { + let _ = x; + Err(Error::NotImplemented { + feature: "SpecialFunctions::erf", + }) + } + + /// Compute the complementary error function element-wise. + /// + /// ```text + /// erfc(x) = 1 - erf(x) = (2/√π) ∫ₓ^∞ e^(-t²) dt + /// ``` + /// + /// For large x, erf(x) ≈ 1 and computing 1 - erf(x) loses precision. + /// erfc(x) computes the small tail directly, maintaining accuracy. + fn erfc(&self, x: &Tensor) -> Result> { + let _ = x; + Err(Error::NotImplemented { + feature: "SpecialFunctions::erfc", + }) + } + + /// Compute the inverse error function element-wise. + /// + /// Returns y such that erf(y) = x. + /// + /// # Properties + /// - Domain: (-1, 1) + /// - Range: all real numbers + /// - erfinv(0) = 0 + fn erfinv(&self, x: &Tensor) -> Result> { + let _ = x; + Err(Error::NotImplemented { + feature: "SpecialFunctions::erfinv", + }) + } + + // ======================================================================== + // Gamma Functions + // ======================================================================== + + /// Compute the gamma function element-wise. + /// + /// ```text + /// Γ(x) = ∫₀^∞ t^(x-1) e^(-t) dt + /// ``` + /// + /// # Properties + /// - Γ(n) = (n-1)! for positive integers + /// - Γ(1) = 1, Γ(1/2) = √π + /// - Has poles at non-positive integers (returns NaN/Inf) + fn gamma(&self, x: &Tensor) -> Result> { + let _ = x; + Err(Error::NotImplemented { + feature: "SpecialFunctions::gamma", + }) + } + + /// Compute the log-gamma function element-wise. + /// + /// ```text + /// lgamma(x) = ln(|Γ(x)|) + /// ``` + /// + /// Γ(x) grows extremely fast (Γ(171) overflows F64). + /// lgamma computes the logarithm directly without overflow. + fn lgamma(&self, x: &Tensor) -> Result> { + let _ = x; + Err(Error::NotImplemented { + feature: "SpecialFunctions::lgamma", + }) + } + + /// Compute the digamma (psi) function element-wise. + /// + /// ```text + /// ψ(x) = d/dx ln(Γ(x)) = Γ'(x)/Γ(x) + /// ``` + fn digamma(&self, x: &Tensor) -> Result> { + let _ = x; + Err(Error::NotImplemented { + feature: "SpecialFunctions::digamma", + }) + } + + // ======================================================================== + // Beta Functions + // ======================================================================== + + /// Compute the beta function element-wise. + /// + /// ```text + /// B(a, b) = Γ(a)Γ(b)/Γ(a+b) + /// ``` + fn beta(&self, a: &Tensor, b: &Tensor) -> Result> { + let _ = (a, b); + Err(Error::NotImplemented { + feature: "SpecialFunctions::beta", + }) + } + + /// Compute the regularized incomplete beta function element-wise. + /// + /// ```text + /// I_x(a,b) = B(x;a,b)/B(a,b) = (1/B(a,b)) ∫₀ˣ t^(a-1)(1-t)^(b-1) dt + /// ``` + fn betainc(&self, a: &Tensor, b: &Tensor, x: &Tensor) -> Result> { + let _ = (a, b, x); + Err(Error::NotImplemented { + feature: "SpecialFunctions::betainc", + }) + } + + // ======================================================================== + // Incomplete Gamma Functions + // ======================================================================== + + /// Compute the lower regularized incomplete gamma function. + /// + /// ```text + /// P(a, x) = γ(a,x)/Γ(a) = (1/Γ(a)) ∫₀ˣ t^(a-1) e^(-t) dt + /// ``` + fn gammainc(&self, a: &Tensor, x: &Tensor) -> Result> { + let _ = (a, x); + Err(Error::NotImplemented { + feature: "SpecialFunctions::gammainc", + }) + } + + /// Compute the upper regularized incomplete gamma function. + /// + /// ```text + /// Q(a, x) = 1 - P(a, x) + /// ``` + fn gammaincc(&self, a: &Tensor, x: &Tensor) -> Result> { + let _ = (a, x); + Err(Error::NotImplemented { + feature: "SpecialFunctions::gammaincc", + }) + } + + /// Compute the inverse of the lower regularized incomplete gamma function. + /// + /// Returns x such that P(a, x) = p. + /// + /// # Properties + /// - Domain: p in [0, 1], a > 0 + /// - Range: x >= 0 + /// - gammaincinv(a, 0) = 0 + /// - gammaincinv(a, 1) = ∞ + fn gammaincinv(&self, a: &Tensor, p: &Tensor) -> Result> { + let _ = (a, p); + Err(Error::NotImplemented { + feature: "SpecialFunctions::gammaincinv", + }) + } + + /// Compute the inverse of the regularized incomplete beta function. + /// + /// Returns x such that I_x(a, b) = p. + /// + /// # Properties + /// - Domain: p in [0, 1], a > 0, b > 0 + /// - Range: x in [0, 1] + /// - betaincinv(a, b, 0) = 0 + /// - betaincinv(a, b, 1) = 1 + fn betaincinv(&self, a: &Tensor, b: &Tensor, p: &Tensor) -> Result> { + let _ = (a, b, p); + Err(Error::NotImplemented { + feature: "SpecialFunctions::betaincinv", + }) + } + + // ======================================================================== + // Bessel Functions + // ======================================================================== + + /// Compute Bessel function of the first kind, order 0. + /// + /// J₀(0) = 1, even function, oscillates with decreasing amplitude. + fn bessel_j0(&self, x: &Tensor) -> Result> { + let _ = x; + Err(Error::NotImplemented { + feature: "SpecialFunctions::bessel_j0", + }) + } + + /// Compute Bessel function of the first kind, order 1. + /// + /// J₁(0) = 0, odd function, oscillates with decreasing amplitude. + fn bessel_j1(&self, x: &Tensor) -> Result> { + let _ = x; + Err(Error::NotImplemented { + feature: "SpecialFunctions::bessel_j1", + }) + } + + /// Compute Bessel function of the second kind, order 0 (Neumann function). + /// + /// Y₀(x) → -∞ as x → 0⁺. Domain: x > 0. + fn bessel_y0(&self, x: &Tensor) -> Result> { + let _ = x; + Err(Error::NotImplemented { + feature: "SpecialFunctions::bessel_y0", + }) + } + + /// Compute Bessel function of the second kind, order 1 (Neumann function). + /// + /// Y₁(x) → -∞ as x → 0⁺. Domain: x > 0. + fn bessel_y1(&self, x: &Tensor) -> Result> { + let _ = x; + Err(Error::NotImplemented { + feature: "SpecialFunctions::bessel_y1", + }) + } + + /// Compute modified Bessel function of the first kind, order 0. + /// + /// I₀(0) = 1, even function, grows exponentially. + fn bessel_i0(&self, x: &Tensor) -> Result> { + let _ = x; + Err(Error::NotImplemented { + feature: "SpecialFunctions::bessel_i0", + }) + } + + /// Compute modified Bessel function of the first kind, order 1. + /// + /// I₁(0) = 0, odd function, grows exponentially. + fn bessel_i1(&self, x: &Tensor) -> Result> { + let _ = x; + Err(Error::NotImplemented { + feature: "SpecialFunctions::bessel_i1", + }) + } + + /// Compute modified Bessel function of the second kind, order 0. + /// + /// K₀(x) → ∞ as x → 0⁺. Domain: x > 0. Decays exponentially. + fn bessel_k0(&self, x: &Tensor) -> Result> { + let _ = x; + Err(Error::NotImplemented { + feature: "SpecialFunctions::bessel_k0", + }) + } + + /// Compute modified Bessel function of the second kind, order 1. + /// + /// K₁(x) → ∞ as x → 0⁺. Domain: x > 0. Decays exponentially. + fn bessel_k1(&self, x: &Tensor) -> Result> { + let _ = x; + Err(Error::NotImplemented { + feature: "SpecialFunctions::bessel_k1", + }) + } + + // ======================================================================== + // Elliptic Integrals + // ======================================================================== + + /// Compute the complete elliptic integral of the first kind K(m). + /// + /// ```text + /// K(m) = ∫₀^(π/2) dθ / √(1 - m·sin²θ) + /// ``` + /// + /// # Properties + /// - Domain: m ∈ [0, 1) + /// - K(0) = π/2 + /// - K(m) → ∞ as m → 1 + /// - Uses parameter convention m = k², where k is the modulus + fn ellipk(&self, m: &Tensor) -> Result> { + let _ = m; + Err(Error::NotImplemented { + feature: "SpecialFunctions::ellipk", + }) + } + + /// Compute the complete elliptic integral of the second kind E(m). + /// + /// ```text + /// E(m) = ∫₀^(π/2) √(1 - m·sin²θ) dθ + /// ``` + /// + /// # Properties + /// - Domain: m ∈ [0, 1] + /// - E(0) = π/2 + /// - E(1) = 1 + fn ellipe(&self, m: &Tensor) -> Result> { + let _ = m; + Err(Error::NotImplemented { + feature: "SpecialFunctions::ellipe", + }) + } + + // ======================================================================== + // Hypergeometric Functions + // ======================================================================== + + /// Compute the Gauss hypergeometric function ₂F₁(a, b; c; z). + /// + /// ```text + /// ₂F₁(a, b; c; z) = Σ_{n=0}^∞ (a)_n (b)_n / ((c)_n n!) z^n + /// ``` + /// + /// # Properties + /// - Converges for |z| < 1 + /// - ₂F₁(a, b; c; 0) = 1 + /// + /// # Arguments + /// - a, b, c: Scalar parameters + /// - z: Input tensor + fn hyp2f1(&self, a: f64, b: f64, c: f64, z: &Tensor) -> Result> { + let _ = (a, b, c, z); + Err(Error::NotImplemented { + feature: "SpecialFunctions::hyp2f1", + }) + } + + /// Compute the confluent hypergeometric function ₁F₁(a; b; z) (Kummer's M). + /// + /// ```text + /// ₁F₁(a; b; z) = M(a, b, z) = Σ_{n=0}^∞ (a)_n / ((b)_n n!) z^n + /// ``` + /// + /// # Properties + /// - ₁F₁(a; b; 0) = 1 + /// - ₁F₁(0; b; z) = 1 + /// - Entire function in z + fn hyp1f1(&self, a: f64, b: f64, z: &Tensor) -> Result> { + let _ = (a, b, z); + Err(Error::NotImplemented { + feature: "SpecialFunctions::hyp1f1", + }) + } + + // ======================================================================== + // Airy Functions + // ======================================================================== + + /// Compute the Airy function of the first kind Ai(x). + /// + /// ```text + /// Ai(x) is the solution of y'' - xy = 0 that decays as x → +∞ + /// ``` + /// + /// # Properties + /// - Ai(x) → 0 as x → +∞ (exponentially) + /// - Ai(x) oscillates for x < 0 + /// - Ai(0) ≈ 0.3550280538878172 + fn airy_ai(&self, x: &Tensor) -> Result> { + let _ = x; + Err(Error::NotImplemented { + feature: "SpecialFunctions::airy_ai", + }) + } + + /// Compute the Airy function of the second kind Bi(x). + /// + /// ```text + /// Bi(x) is the solution of y'' - xy = 0 that grows as x → +∞ + /// ``` + /// + /// # Properties + /// - Bi(x) → +∞ as x → +∞ (exponentially) + /// - Bi(x) oscillates for x < 0 + /// - Bi(0) ≈ 0.6149266274460007 + fn airy_bi(&self, x: &Tensor) -> Result> { + let _ = x; + Err(Error::NotImplemented { + feature: "SpecialFunctions::airy_bi", + }) + } + + // ======================================================================== + // Legendre Functions + // ======================================================================== + + /// Compute the Legendre polynomial P_n(x). + /// + /// # Properties + /// - Domain: x ∈ [-1, 1] + /// - P_n(1) = 1 + /// - P_n(-1) = (-1)^n + /// - P_0(x) = 1, P_1(x) = x + fn legendre_p(&self, n: i32, x: &Tensor) -> Result> { + let _ = (n, x); + Err(Error::NotImplemented { + feature: "SpecialFunctions::legendre_p", + }) + } + + /// Compute the associated Legendre function P_n^m(x). + /// + /// Uses Condon-Shortley phase convention (factor of (-1)^m). + /// + /// # Properties + /// - Domain: x ∈ [-1, 1], 0 ≤ m ≤ n + /// - P_n^0(x) = P_n(x) + fn legendre_p_assoc(&self, n: i32, m: i32, x: &Tensor) -> Result> { + let _ = (n, m, x); + Err(Error::NotImplemented { + feature: "SpecialFunctions::legendre_p_assoc", + }) + } + + /// Compute the real spherical harmonic Y_n^m(θ, φ). + /// + /// Returns the real-valued spherical harmonic with Schmidt semi-normalization. + /// - m > 0: Y_n^m ∝ P_n^m(cos θ) cos(mφ) + /// - m = 0: Y_n^0 ∝ P_n(cos θ) + /// - m < 0: Y_n^m ∝ P_n^|m|(cos θ) sin(|m|φ) + /// + /// # Arguments + /// - n: degree (n ≥ 0) + /// - m: order (-n ≤ m ≤ n) + /// - theta: polar angle θ ∈ [0, π] (colatitude) + /// - phi: azimuthal angle φ ∈ [0, 2π) + fn sph_harm(&self, n: i32, m: i32, theta: &Tensor, phi: &Tensor) -> Result> { + let _ = (n, m, theta, phi); + Err(Error::NotImplemented { + feature: "SpecialFunctions::sph_harm", + }) + } + + // ======================================================================== + // Fresnel Integrals + // ======================================================================== + + /// Compute the Fresnel sine integral S(x). + /// + /// ```text + /// S(x) = ∫₀ˣ sin(π t²/2) dt + /// ``` + /// + /// # Properties + /// - S(0) = 0 + /// - S(∞) = 0.5 + /// - S(-x) = -S(x) (odd function) + fn fresnel_s(&self, x: &Tensor) -> Result> { + let _ = x; + Err(Error::NotImplemented { + feature: "SpecialFunctions::fresnel_s", + }) + } + + /// Compute the Fresnel cosine integral C(x). + /// + /// ```text + /// C(x) = ∫₀ˣ cos(π t²/2) dt + /// ``` + /// + /// # Properties + /// - C(0) = 0 + /// - C(∞) = 0.5 + /// - C(-x) = -C(x) (odd function) + fn fresnel_c(&self, x: &Tensor) -> Result> { + let _ = x; + Err(Error::NotImplemented { + feature: "SpecialFunctions::fresnel_c", + }) + } +} diff --git a/src/autograd/backward.rs b/src/autograd/backward.rs index 70b2cb64..e14fd2d0 100644 --- a/src/autograd/backward.rs +++ b/src/autograd/backward.rs @@ -14,6 +14,7 @@ //! `Var`s that retain their computation history, enabling Hessians and HVPs. use super::{GradFn, GradStore, Var, VarGradStore, var_add}; +use crate::dtype::DType; use crate::error::{Error, Result}; use crate::ops::TensorOps; use crate::runtime::{Runtime, RuntimeClient}; @@ -21,6 +22,29 @@ use crate::tensor::{Tensor, TensorId}; use std::collections::HashSet; use std::sync::Arc; +// ============================================================================ +// Backward Hooks +// ============================================================================ + +/// Hook called during backward when a leaf variable's gradient is fully accumulated. +/// +/// This enables overlapping gradient communication with backward computation +/// in distributed training scenarios (e.g., bucketed allreduce). +pub trait BackwardHook: Send { + /// Called when a leaf variable's gradient is fully accumulated. + /// + /// At the point this is called, the gradient for `id` in the grad store + /// is complete — all upstream contributions have been accumulated. + fn on_leaf_grad_ready(&mut self, id: TensorId, grad: &Tensor); +} + +/// No-op backward hook for use when no hook behavior is needed. +pub struct NoOpHook; + +impl BackwardHook for NoOpHook { + fn on_leaf_grad_ready(&mut self, _id: TensorId, _grad: &Tensor) {} +} + // ============================================================================ // Helper Functions // ============================================================================ @@ -51,7 +75,7 @@ fn validate_loss(loss: &Var, fn_name: &str) -> Result<()> { /// Create the initial gradient tensor for the loss (dL/dL = 1) #[inline] -fn create_loss_gradient(loss: &Var) -> Tensor { +fn create_loss_gradient>(loss: &Var) -> Tensor { Tensor::::ones(loss.shape(), loss.tensor().dtype(), loss.tensor().device()) } @@ -93,10 +117,43 @@ fn create_loss_gradient(loss: &Var) -> Tensor { /// ``` pub fn backward(loss: &Var, client: &C) -> Result> where - R: Runtime, + R: Runtime, + C: RuntimeClient + TensorOps, +{ + backward_with_hooks(loss, client, &mut NoOpHook) +} + +/// Compute gradients with hooks that fire when leaf gradients are ready. +/// +/// Identical to [`backward`], but calls `hooks.on_leaf_grad_ready(id, grad)` +/// after a leaf variable's gradient is fully accumulated. This enables +/// overlapping gradient communication with backward computation (e.g., +/// bucketed allreduce in distributed training). +/// +/// A leaf variable is one with no `grad_fn` (i.e., a model parameter or +/// input created with `requires_grad = true`). By the time the hook fires, +/// all upstream contributions to that leaf's gradient have been accumulated. +/// +/// # Arguments +/// +/// * `loss` - The scalar loss tensor to differentiate +/// * `client` - The runtime client for tensor operations +/// * `hooks` - Hook implementation called when each leaf gradient is ready +/// +/// # Returns +/// +/// A `GradStore` containing gradients for all tensors in the graph. +pub fn backward_with_hooks( + loss: &Var, + client: &C, + hooks: &mut H, +) -> Result> +where + R: Runtime, C: RuntimeClient + TensorOps, + H: BackwardHook, { - validate_loss(loss, "backward")?; + validate_loss(loss, "backward_with_hooks")?; // Initialize gradient store with dL/dL = 1 let mut grad_store = GradStore::new(); @@ -129,6 +186,9 @@ where })?; } } + } else { + // Leaf node (no grad_fn) with a gradient — notify hook + hooks.on_leaf_grad_ready(var_id, &grad_output); } } @@ -183,7 +243,7 @@ where /// when you actually need second-order derivatives. pub fn backward_with_graph(loss: &Var, client: &C) -> Result> where - R: Runtime, + R: Runtime, C: RuntimeClient + TensorOps, R::Client: TensorOps, { @@ -279,6 +339,97 @@ mod tests { use crate::autograd::{var_mul, var_sum}; use crate::runtime::cpu::{CpuDevice, CpuRuntime}; + use std::cell::RefCell; + use std::rc::Rc; + + /// Test hook that records leaf gradient notifications + struct RecordingHook { + leaf_ids: Rc>>, + } + + impl RecordingHook { + fn new() -> (Self, Rc>>) { + let ids = Rc::new(RefCell::new(Vec::new())); + ( + Self { + leaf_ids: ids.clone(), + }, + ids, + ) + } + } + + // RecordingHook is not Send (due to Rc), so we wrap for single-threaded tests + unsafe impl Send for RecordingHook {} + + impl BackwardHook for RecordingHook { + fn on_leaf_grad_ready(&mut self, id: TensorId, _grad: &Tensor) { + self.leaf_ids.borrow_mut().push(id); + } + } + + #[test] + fn test_backward_with_hooks_matches_backward() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + let x = Var::new( + Tensor::::from_slice(&[2.0f32], &[1], &device), + true, + ); + let y = Var::new( + Tensor::::from_slice(&[3.0f32], &[1], &device), + true, + ); + + // z = x * y + let z1 = var_mul(&x, &y, &client).unwrap(); + let z2 = var_mul(&x, &y, &client).unwrap(); + + let grads1 = backward(&z1, &client).unwrap(); + + let (mut hook, leaf_ids) = RecordingHook::new(); + let grads2 = backward_with_hooks(&z2, &client, &mut hook).unwrap(); + + // Gradients should match + let gx1: Vec = grads1.get(x.id()).unwrap().to_vec(); + let gx2: Vec = grads2.get(x.id()).unwrap().to_vec(); + assert!((gx1[0] - gx2[0]).abs() < 1e-6); + + let gy1: Vec = grads1.get(y.id()).unwrap().to_vec(); + let gy2: Vec = grads2.get(y.id()).unwrap().to_vec(); + assert!((gy1[0] - gy2[0]).abs() < 1e-6); + + // Hook should have been called for both leaf variables + let ids = leaf_ids.borrow(); + assert_eq!(ids.len(), 2); + assert!(ids.contains(&x.id())); + assert!(ids.contains(&y.id())); + } + + #[test] + fn test_backward_with_hooks_no_hook_for_non_leaf() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + let x = Var::new( + Tensor::::from_slice(&[2.0f32, 3.0], &[2], &device), + true, + ); + + // y = sum(x * x) — intermediate x*x is NOT a leaf + let x_sq = var_mul(&x, &x, &client).unwrap(); + let loss = var_sum(&x_sq, &[0], false, &client).unwrap(); + + let (mut hook, leaf_ids) = RecordingHook::new(); + let _grads = backward_with_hooks(&loss, &client, &mut hook).unwrap(); + + // Only x is a leaf, not x_sq or loss + let ids = leaf_ids.borrow(); + assert_eq!(ids.len(), 1); + assert!(ids.contains(&x.id())); + } + #[test] fn test_backward_requires_scalar() { let device = CpuDevice::new(); diff --git a/src/autograd/checkpoint.rs b/src/autograd/checkpoint.rs new file mode 100644 index 00000000..f13ae7ef --- /dev/null +++ b/src/autograd/checkpoint.rs @@ -0,0 +1,325 @@ +//! Activation checkpointing for memory-efficient training. +//! +//! Discards intermediate activations during forward and recomputes them during +//! backward. Trades ~33% extra compute for dramatically less activation memory. +//! +//! # Example +//! +//! ``` +//! # use numr::prelude::*; +//! # use numr::autograd::{Var, backward, checkpoint, var_mul, var_sum}; +//! # let device = CpuDevice::new(); +//! # let client = CpuRuntime::default_client(&device); +//! let x = Var::new(Tensor::from_slice(&[3.0f32], &[1], &device), true); +//! +//! // Wrap computation in checkpoint — intermediates are dropped and recomputed +//! let y = checkpoint(|inputs, c| { +//! let x_sq = var_mul(&inputs[0], &inputs[0], c)?; +//! Ok(x_sq) +//! }, &[&x])?; +//! +//! let loss = var_sum(&y, &[], false, &client)?; +//! let grads = backward(&loss, &client)?; +//! // grad_x = 2 * 3 = 6 +//! # Ok::<(), numr::error::Error>(()) +//! ``` + +use std::sync::Arc; + +use crate::autograd::{GradFn, Var, backward, var_mul, var_sum}; +use crate::dtype::DType; +use crate::error::Result; +use crate::ops::TensorOps; +use crate::runtime::Runtime; +use crate::tensor::{Tensor, TensorId}; + +/// Run `f` on `inputs` with activation checkpointing. +/// +/// During forward, `f` runs on detached copies of the inputs so no intermediate +/// graph nodes are retained. During backward, `f` is re-run with grad tracking +/// to reconstruct the graph and propagate gradients. +pub fn checkpoint(f: F, inputs: &[&Var]) -> Result> +where + R: Runtime, + R::Client: TensorOps, + F: Fn(&[Var], &R::Client) -> Result> + Send + Sync + 'static, +{ + if inputs.is_empty() { + return Err(crate::error::Error::Internal( + "checkpoint requires at least one input".to_string(), + )); + } + + // Save original input info for backward + let input_ids: Vec = inputs.iter().map(|v| v.id()).collect(); + let input_tensors: Vec> = inputs.iter().map(|v| v.tensor().clone()).collect(); + let input_grad_fns: Vec>>> = + inputs.iter().map(|v| v.grad_fn().cloned()).collect(); + + // Forward: run on detached inputs (no grad tracking inside the segment) + let detached: Vec> = inputs + .iter() + .map(|v| Var::new(v.tensor().clone(), false)) + .collect(); + + let device = inputs[0].tensor().device(); + let client = R::default_client(device); + + let output = f(&detached, &client)?; + // output has no grad graph inside — all intermediates are already dropped + + let checkpoint_backward = CheckpointBackward { + func: Arc::new(f), + input_ids: input_ids.clone(), + input_tensors, + input_grad_fns, + }; + + Ok(Var::from_op( + output.tensor().clone(), + Arc::new(checkpoint_backward), + )) +} + +struct CheckpointBackward { + func: Arc], &R::Client) -> Result> + Send + Sync>, + input_ids: Vec, + input_tensors: Vec>, + input_grad_fns: Vec>>>, +} + +impl GradFn for CheckpointBackward +where + R: Runtime, + R::Client: TensorOps, +{ + fn backward(&self, grad_output: &Tensor) -> Result>>> { + let client = R::default_client(grad_output.device()); + + // Reconstruct input Vars as LEAF nodes with original IDs. + // They have no grad_fn so backward stops here — the outer backward + // pass handles continuing through input_grad_fns() returned below. + let reconstructed: Vec> = self + .input_ids + .iter() + .zip(self.input_tensors.iter()) + .map(|(id, tensor)| Var::with_id(tensor.clone(), *id, true)) + .collect(); + + // Re-run forward WITH grad tracking — rebuilds the intermediate graph + let recomputed_output = (self.func)(&reconstructed, &client)?; + + // Backprop grad_output through the recomputed graph. + // loss = sum(recomputed * grad_output) is a scalar whose gradient w.r.t. + // each input is exactly the VJP: sum_j(grad_output_j * d(output_j)/d(input_i)) + let grad_output_var = Var::new(grad_output.clone(), false); + let product = var_mul(&recomputed_output, &grad_output_var, &client)?; + let loss = var_sum(&product, &[], false, &client)?; + + let grads = backward(&loss, &client)?; + + Ok(self + .input_ids + .iter() + .map(|id| grads.get(*id).cloned()) + .collect()) + } + + fn inputs(&self) -> &[TensorId] { + &self.input_ids + } + + fn input_grad_fns(&self) -> Vec>>> { + self.input_grad_fns.clone() + } + + fn name(&self) -> &'static str { + "CheckpointBackward" + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::autograd::{BackwardHook, backward, backward_with_hooks, var_add, var_mul, var_sum}; + use crate::runtime::cpu::{CpuDevice, CpuRuntime}; + + fn device_and_client() -> (CpuDevice, ::Client) { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + (device, client) + } + + #[test] + fn test_checkpoint_x_squared() { + // f(x) = x^2, df/dx = 2x + let (device, client) = device_and_client(); + + let x = Var::new( + Tensor::::from_slice(&[3.0f32], &[1], &device), + true, + ); + + // Without checkpoint + let y_normal = var_mul(&x, &x, &client).unwrap(); + let loss_normal = var_sum(&y_normal, &[], false, &client).unwrap(); + let grads_normal = backward(&loss_normal, &client).unwrap(); + + // With checkpoint + let y_ckpt = checkpoint(|inputs, c| var_mul(&inputs[0], &inputs[0], c), &[&x]).unwrap(); + let loss_ckpt = var_sum(&y_ckpt, &[], false, &client).unwrap(); + let grads_ckpt = backward(&loss_ckpt, &client).unwrap(); + + let g_normal: Vec = grads_normal.get(x.id()).unwrap().to_vec(); + let g_ckpt: Vec = grads_ckpt.get(x.id()).unwrap().to_vec(); + + assert!( + (g_normal[0] - g_ckpt[0]).abs() < 1e-6, + "normal={}, checkpoint={}", + g_normal[0], + g_ckpt[0] + ); + assert!((g_ckpt[0] - 6.0).abs() < 1e-6); + } + + #[test] + fn test_checkpoint_multi_input() { + // f(x, y) = x * y + let (device, client) = device_and_client(); + + let x = Var::new( + Tensor::::from_slice(&[2.0f32], &[1], &device), + true, + ); + let y = Var::new( + Tensor::::from_slice(&[5.0f32], &[1], &device), + true, + ); + + let out = checkpoint(|inputs, c| var_mul(&inputs[0], &inputs[1], c), &[&x, &y]).unwrap(); + + let grads = backward(&out, &client).unwrap(); + + // d(x*y)/dx = y = 5 + let gx: Vec = grads.get(x.id()).unwrap().to_vec(); + assert!((gx[0] - 5.0).abs() < 1e-6); + + // d(x*y)/dy = x = 2 + let gy: Vec = grads.get(y.id()).unwrap().to_vec(); + assert!((gy[0] - 2.0).abs() < 1e-6); + } + + #[test] + fn test_checkpoint_chained() { + // checkpoint(f1) -> checkpoint(f2) + // f1(x) = x^2, f2(z) = z^2, so total = x^4 + // d(x^4)/dx = 4x^3 = 4*8 = 32 at x=2 + let (device, client) = device_and_client(); + + let x = Var::new( + Tensor::::from_slice(&[2.0f32], &[1], &device), + true, + ); + + let z = checkpoint(|inputs, c| var_mul(&inputs[0], &inputs[0], c), &[&x]).unwrap(); + + let w = checkpoint(|inputs, c| var_mul(&inputs[0], &inputs[0], c), &[&z]).unwrap(); + + let loss = var_sum(&w, &[], false, &client).unwrap(); + let grads = backward(&loss, &client).unwrap(); + + let gx: Vec = grads.get(x.id()).unwrap().to_vec(); + assert!((gx[0] - 32.0).abs() < 1e-4, "expected 32.0, got {}", gx[0]); + } + + #[test] + fn test_checkpoint_matches_normal_complex() { + // More complex: f(x) = (x + x) * x = 2x^2 + // df/dx = 4x = 12 at x=3 + let (device, client) = device_and_client(); + + let x = Var::new( + Tensor::::from_slice(&[3.0f32], &[1], &device), + true, + ); + + let y = checkpoint( + |inputs, c| { + let sum = var_add(&inputs[0], &inputs[0], c)?; + var_mul(&sum, &inputs[0], c) + }, + &[&x], + ) + .unwrap(); + + let loss = var_sum(&y, &[], false, &client).unwrap(); + let grads = backward(&loss, &client).unwrap(); + + let gx: Vec = grads.get(x.id()).unwrap().to_vec(); + assert!((gx[0] - 12.0).abs() < 1e-5, "expected 12.0, got {}", gx[0]); + } + + #[test] + fn test_checkpoint_with_backward_hooks() { + // Verify leaf hooks still fire through checkpointed segments + use std::cell::RefCell; + use std::rc::Rc; + + struct RecordingHook { + leaf_ids: Rc>>, + } + + unsafe impl Send for RecordingHook {} + + impl BackwardHook for RecordingHook { + fn on_leaf_grad_ready(&mut self, id: TensorId, _grad: &Tensor) { + self.leaf_ids.borrow_mut().push(id); + } + } + + let (device, client) = device_and_client(); + + let x = Var::new( + Tensor::::from_slice(&[3.0f32], &[1], &device), + true, + ); + + let y = checkpoint(|inputs, c| var_mul(&inputs[0], &inputs[0], c), &[&x]).unwrap(); + + let loss = var_sum(&y, &[], false, &client).unwrap(); + + let ids = Rc::new(RefCell::new(Vec::new())); + let mut hook = RecordingHook { + leaf_ids: ids.clone(), + }; + let _grads = backward_with_hooks(&loss, &client, &mut hook).unwrap(); + + let recorded = ids.borrow(); + assert!( + recorded.contains(&x.id()), + "leaf hook should have fired for x" + ); + } + + #[test] + fn test_checkpoint_vector_output() { + // f(x) = x * x where x is a vector [2, 3] + // loss = sum(f(x)) = 4 + 9 = 13 + // d(loss)/dx = [4, 6] + let (device, client) = device_and_client(); + + let x = Var::new( + Tensor::::from_slice(&[2.0f32, 3.0], &[2], &device), + true, + ); + + let y = checkpoint(|inputs, c| var_mul(&inputs[0], &inputs[0], c), &[&x]).unwrap(); + + let loss = var_sum(&y, &[], false, &client).unwrap(); + let grads = backward(&loss, &client).unwrap(); + + let gx: Vec = grads.get(x.id()).unwrap().to_vec(); + assert!((gx[0] - 4.0).abs() < 1e-6); + assert!((gx[1] - 6.0).abs() < 1e-6); + } +} diff --git a/src/autograd/dual.rs b/src/autograd/dual.rs index c0e32da2..f119ec50 100644 --- a/src/autograd/dual.rs +++ b/src/autograd/dual.rs @@ -92,7 +92,10 @@ impl DualTensor { /// /// The tangent is initialized to all ones with the same shape as the primal. /// This is useful when computing the derivative of a scalar function. - pub fn with_unit_tangent(primal: Tensor, device: &R::Device) -> Self { + pub fn with_unit_tangent(primal: Tensor, device: &R::Device) -> Self + where + R: Runtime, + { let tangent = Tensor::ones(primal.shape(), primal.dtype(), device); Self { primal, @@ -144,7 +147,10 @@ impl DualTensor { /// Get the data type #[inline] - pub fn dtype(&self) -> DType { + pub fn dtype(&self) -> DType + where + R: Runtime, + { self.primal.dtype() } @@ -178,7 +184,10 @@ impl DualTensor { /// /// This is useful when we need an explicit zero tangent for operations /// that can't handle `Option` directly. - pub fn zero_tangent(&self, device: &R::Device) -> Tensor { + pub fn zero_tangent(&self, device: &R::Device) -> Tensor + where + R: Runtime, + { Tensor::zeros(self.primal.shape(), self.primal.dtype(), device) } } diff --git a/src/autograd/dual_ops/activation.rs b/src/autograd/dual_ops/activation.rs index 332faebc..67805347 100644 --- a/src/autograd/dual_ops/activation.rs +++ b/src/autograd/dual_ops/activation.rs @@ -1,6 +1,7 @@ //! Activation operations on dual tensors use crate::autograd::DualTensor; +use crate::dtype::DType; use crate::error::Result; use crate::ops::{ActivationOps, BinaryOps, CompareOps, ScalarOps, TensorOps}; use crate::runtime::{Runtime, RuntimeClient}; @@ -9,7 +10,7 @@ use crate::tensor::Tensor; /// Dual ReLU: relu(a, ȧ) = (relu(a), ȧ * (a > 0)) pub fn dual_relu(a: &DualTensor, client: &C) -> Result> where - R: Runtime, + R: Runtime, C: RuntimeClient + ActivationOps + CompareOps + BinaryOps + TensorOps, { let primal = client.relu(a.primal())?; @@ -30,7 +31,7 @@ where /// Dual sigmoid: sigmoid(a, ȧ) = (σ(a), σ(a) * (1 - σ(a)) * ȧ) pub fn dual_sigmoid(a: &DualTensor, client: &C) -> Result> where - R: Runtime, + R: Runtime, C: RuntimeClient + ActivationOps + BinaryOps + ScalarOps, { let primal = client.sigmoid(a.primal())?; diff --git a/src/autograd/dual_ops/unary.rs b/src/autograd/dual_ops/unary.rs index e70b73a7..b6a7e3e3 100644 --- a/src/autograd/dual_ops/unary.rs +++ b/src/autograd/dual_ops/unary.rs @@ -1,6 +1,7 @@ //! Unary operations on dual tensors use crate::autograd::DualTensor; +use crate::dtype::DType; use crate::error::Result; use crate::ops::{BinaryOps, ScalarOps, UnaryOps}; use crate::runtime::{Runtime, RuntimeClient}; @@ -138,7 +139,7 @@ where /// Dual hyperbolic tangent: tanh(a, ȧ) = (tanh(a), (1 - tanh²(a)) * ȧ) pub fn dual_tanh(a: &DualTensor, client: &C) -> Result> where - R: Runtime, + R: Runtime, C: RuntimeClient + UnaryOps + BinaryOps + ScalarOps, { let primal = client.tanh(a.primal())?; diff --git a/src/autograd/forward.rs b/src/autograd/forward.rs index 655566c2..c69e9673 100644 --- a/src/autograd/forward.rs +++ b/src/autograd/forward.rs @@ -53,6 +53,7 @@ //! ``` use super::DualTensor; +use crate::dtype::DType; use crate::error::Result; use crate::ops::TensorOps; use crate::runtime::{Runtime, RuntimeClient}; @@ -116,7 +117,7 @@ pub fn jvp( client: &C, ) -> Result<(Tensor, Tensor)> where - R: Runtime, + R: Runtime, C: RuntimeClient + TensorOps, F: FnOnce(&[DualTensor], &C) -> Result>, { @@ -175,7 +176,7 @@ pub fn jvp_multi( client: &C, ) -> Result<(Vec>, Vec>)> where - R: Runtime, + R: Runtime, C: RuntimeClient + TensorOps, F: FnOnce(&[DualTensor], &C) -> Result>>, { @@ -246,7 +247,7 @@ where /// ``` pub fn jacobian_forward(f: F, x: &Tensor, client: &C) -> Result> where - R: Runtime, + R: Runtime, C: RuntimeClient + TensorOps, F: Fn(&DualTensor, &C) -> Result>, { @@ -323,7 +324,7 @@ where /// second-order derivatives through the existing reverse-mode infrastructure. pub fn hvp(grad_f: F, x: &Tensor, v: &Tensor, client: &C) -> Result> where - R: Runtime, + R: Runtime, C: RuntimeClient + TensorOps, F: Fn(&DualTensor, &C) -> Result>, { diff --git a/src/autograd/mod.rs b/src/autograd/mod.rs index cb66a42c..e2d2138c 100644 --- a/src/autograd/mod.rs +++ b/src/autograd/mod.rs @@ -107,6 +107,7 @@ // Reverse-mode AD mod backward; +mod checkpoint; mod grad_fn; mod grad_store; mod var; @@ -122,17 +123,27 @@ pub mod ops; // Reverse-mode exports pub use crate::tensor::id::TensorId; -pub use backward::{backward, backward_with_graph}; +pub use backward::{BackwardHook, NoOpHook, backward, backward_with_graph, backward_with_hooks}; +pub use checkpoint::checkpoint; pub use grad_fn::GradFn; pub use grad_store::GradStore; pub use var::Var; pub use var_grad_store::VarGradStore; +pub use var_ops::var_dropout; pub use var_ops::{ - var_abs, var_add, var_add_scalar, var_cholesky, var_clamp, var_cos, var_cumprod, var_cumsum, - var_det, var_div, var_div_scalar, var_exp, var_gather, var_inverse, var_log, var_matmul, + var_abs, var_add, var_add_scalar, var_cast, var_cholesky, var_clamp, var_conv1d, var_conv2d, + var_cos, var_cumprod, var_cumsum, var_det, var_div, var_div_scalar, var_exp, + var_fused_add_layer_norm, var_fused_add_rms_norm, var_gather, var_gelu_mul, var_group_norm, + var_inverse, var_layer_norm, var_log, var_log_softmax, var_matmul, var_matmul_bias_activation, var_max, var_mean, var_min, var_mul, var_mul_scalar, var_neg, var_pow, var_pow_scalar, - var_recip, var_relu, var_sigmoid, var_sin, var_softmax, var_solve, var_sqrt, var_square, - var_std, var_sub, var_sub_scalar, var_sum, var_tan, var_tanh, var_trace, var_var, + var_recip, var_relu, var_relu_mul, var_rms_norm, var_sigmoid, var_sigmoid_mul, var_silu, + var_silu_mul, var_sin, var_softmax, var_softplus, var_solve, var_sqrt, var_square, var_std, + var_sub, var_sub_scalar, var_sum, var_swiglu, var_tan, var_tanh, var_trace, var_var, +}; + +// Shape operation exports (re-exported via autograd::ops::*) +pub use self::ops::{ + var_broadcast_to, var_cat, var_narrow, var_permute, var_reshape, var_transpose, }; // Forward-mode exports diff --git a/src/autograd/ops/activation.rs b/src/autograd/ops/activation.rs index 767f471f..008e22a1 100644 --- a/src/autograd/ops/activation.rs +++ b/src/autograd/ops/activation.rs @@ -1,19 +1,17 @@ //! Backward implementations for activation functions //! -//! Implements gradient computation for relu, sigmoid, and softmax. +//! Implements gradient computation for relu, sigmoid, silu, softplus, softmax, and log_softmax. use crate::autograd::GradFn; use crate::autograd::var::Var; use crate::autograd::var_ops::{var_mul, var_sub, var_sum}; +use crate::dtype::DType; use crate::error::Result; -use crate::ops::{BinaryOps, CompareOps, ReduceOps, ScalarOps, TensorOps}; +use crate::ops::{ActivationOps, BinaryOps, CompareOps, ReduceOps, ScalarOps, TensorOps, UnaryOps}; use crate::runtime::{Runtime, RuntimeClient}; use crate::tensor::{Tensor, TensorId}; use std::sync::Arc; -#[cfg(test)] -use crate::ops::ActivationOps; - // ============================================================================ // ReluBackward // ============================================================================ @@ -43,7 +41,7 @@ impl ReluBackward { } } -impl GradFn for ReluBackward +impl> GradFn for ReluBackward where R::Client: TensorOps + CompareOps, { @@ -134,7 +132,7 @@ impl SigmoidBackward { } } -impl GradFn for SigmoidBackward +impl> GradFn for SigmoidBackward where R::Client: TensorOps, { @@ -198,6 +196,96 @@ where } } +// ============================================================================ +// SiluBackward +// ============================================================================ + +/// Backward for SiLU (Swish): z = a * sigmoid(a) +/// +/// Gradient: dL/da = dL/dz * (sigmoid(a) + a * sigmoid(a) * (1 - sigmoid(a))) +/// = dL/dz * (sigmoid(a) * (1 + a * (1 - sigmoid(a)))) +/// = dL/dz * (z/a * (1 + a - z)) [numerically: use saved input + output] +pub struct SiluBackward { + input_id: TensorId, + saved_input: Tensor, + saved_output: Tensor, // silu(a) + input_grad_fn: Option>>, +} + +impl SiluBackward { + /// Create a new SiluBackward + pub fn new( + input_id: TensorId, + input: Tensor, + output: Tensor, + input_grad_fn: Option>>, + ) -> Self { + Self { + input_id, + saved_input: input, + saved_output: output, + input_grad_fn, + } + } +} + +impl> GradFn for SiluBackward +where + R::Client: TensorOps + ActivationOps + ScalarOps, +{ + fn backward(&self, grad_output: &Tensor) -> Result>>> { + let client = R::default_client(grad_output.device()); + + // silu'(x) = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x)) + // = sigmoid(x) * (1 + x * (1 - sigmoid(x))) + // = sigmoid(x) * (1 + x - x*sigmoid(x)) + // = sigmoid(x) * (1 + x - silu(x)) + let sigmoid = client.sigmoid(&self.saved_input)?; + let one_plus_x = client.add_scalar(&self.saved_input, 1.0)?; + let one_plus_x_minus_silu = client.sub(&one_plus_x, &self.saved_output)?; + let deriv = client.mul(&sigmoid, &one_plus_x_minus_silu)?; + let grad = client.mul(grad_output, &deriv)?; + + Ok(vec![Some(grad)]) + } + + fn backward_var(&self, grad_output: &Var) -> Result>>> + where + R::Client: RuntimeClient + TensorOps + ActivationOps + ScalarOps, + { + let client = R::default_client(grad_output.tensor().device()); + + let sigmoid = client.sigmoid(&self.saved_input)?; + let one_plus_x = client.add_scalar(&self.saved_input, 1.0)?; + let one_plus_x_minus_silu = client.sub(&one_plus_x, &self.saved_output)?; + let deriv = client.mul(&sigmoid, &one_plus_x_minus_silu)?; + + let deriv_var = Var::new(deriv, false); + let grad = var_mul(grad_output, &deriv_var, &client)?; + + Ok(vec![Some(grad)]) + } + + fn inputs(&self) -> &[TensorId] { + std::slice::from_ref(&self.input_id) + } + + fn input_grad_fns(&self) -> Vec>>> { + vec![self.input_grad_fn.clone()] + } + + fn saved_tensors(&self) -> &[Tensor] { + // Both saved_input and saved_output are stored internally for gradient computation. + // The trait returns a slice, so we expose only the input here; saved_output is + // accessed directly during backward() and backward_var(). + std::slice::from_ref(&self.saved_input) + } + + fn name(&self) -> &'static str { + "SiluBackward" + } +} + // ============================================================================ // SoftmaxBackward // ============================================================================ @@ -307,6 +395,181 @@ where } } +// ============================================================================ +// LogSoftmaxBackward +// ============================================================================ + +/// Backward for log_softmax: z = log(softmax(a, dim)) +/// +/// Gradient: dL/da = dL/dz - softmax(a) * sum(dL/dz, dim) +/// = dL/dz - exp(z) * sum(dL/dz, dim) +pub struct LogSoftmaxBackward { + input_id: TensorId, + saved_output: Tensor, // log_softmax(a) + dim: isize, + input_grad_fn: Option>>, +} + +impl LogSoftmaxBackward { + /// Create a new LogSoftmaxBackward + pub fn new( + input_id: TensorId, + output: Tensor, + dim: isize, + input_grad_fn: Option>>, + ) -> Self { + Self { + input_id, + saved_output: output, + dim, + input_grad_fn, + } + } +} + +impl> GradFn for LogSoftmaxBackward +where + R::Client: TensorOps + UnaryOps + ReduceOps + ScalarOps, +{ + fn backward(&self, grad_output: &Tensor) -> Result>>> { + let client = R::default_client(grad_output.device()); + + let ndim = self.saved_output.ndim(); + let dim_idx = if self.dim < 0 { + (ndim as isize + self.dim) as usize + } else { + self.dim as usize + }; + + // log_softmax gradient: grad_input = grad_output - exp(output) * sum(grad_output, dim) + let softmax_output = client.exp(&self.saved_output)?; + let sum_grad = client.sum(grad_output, &[dim_idx], true)?; + let softmax_sum = client.mul(&softmax_output, &sum_grad)?; + let grad = client.sub(grad_output, &softmax_sum)?; + + Ok(vec![Some(grad)]) + } + + fn backward_var(&self, grad_output: &Var) -> Result>>> + where + R::Client: RuntimeClient + TensorOps + UnaryOps + ReduceOps + ScalarOps, + { + let client = R::default_client(grad_output.tensor().device()); + + let ndim = self.saved_output.ndim(); + let dim_idx = if self.dim < 0 { + (ndim as isize + self.dim) as usize + } else { + self.dim as usize + }; + + // exp(log_softmax(x)) = softmax(x), treated as constant + let softmax_output = client.exp(&self.saved_output)?; + let softmax_var = Var::new(softmax_output, false); + + let sum_grad = var_sum(grad_output, &[dim_idx], true, &client)?; + let softmax_sum = var_mul(&softmax_var, &sum_grad, &client)?; + let grad = var_sub(grad_output, &softmax_sum, &client)?; + + Ok(vec![Some(grad)]) + } + + fn inputs(&self) -> &[TensorId] { + std::slice::from_ref(&self.input_id) + } + + fn input_grad_fns(&self) -> Vec>>> { + vec![self.input_grad_fn.clone()] + } + + fn saved_tensors(&self) -> &[Tensor] { + std::slice::from_ref(&self.saved_output) + } + + fn name(&self) -> &'static str { + "LogSoftmaxBackward" + } +} + +// ============================================================================ +// SoftplusBackward +// ============================================================================ + +/// Backward for softplus: `z = log(1 + exp(a))` +/// +/// Gradient: `dL/da = dL/dz * sigmoid(a)` +/// +/// `d/da log(1 + exp(a)) = exp(a) / (1 + exp(a)) = sigmoid(a)` +/// +/// The backward is numerically stable since `sigmoid` is bounded in `(0, 1)`. +/// The forward must be computed via the stable form `relu(a) + log(1 + exp(-|a|))` +/// (see `softplus_impl`) — never the naive `log(1 + exp(a))` which overflows for +/// large positive inputs. +pub struct SoftplusBackward { + input_id: TensorId, + saved_input: Tensor, + input_grad_fn: Option>>, +} + +impl SoftplusBackward { + /// Create a new SoftplusBackward + pub fn new( + input_id: TensorId, + input: Tensor, + input_grad_fn: Option>>, + ) -> Self { + Self { + input_id, + saved_input: input, + input_grad_fn, + } + } +} + +impl> GradFn for SoftplusBackward +where + R::Client: TensorOps + ActivationOps, +{ + fn backward(&self, grad_output: &Tensor) -> Result>>> { + let client = R::default_client(grad_output.device()); + + // softplus'(x) = sigmoid(x) + let sigmoid = client.sigmoid(&self.saved_input)?; + let grad = client.mul(grad_output, &sigmoid)?; + + Ok(vec![Some(grad)]) + } + + fn backward_var(&self, grad_output: &Var) -> Result>>> + where + R::Client: RuntimeClient + TensorOps + ActivationOps, + { + let client = R::default_client(grad_output.tensor().device()); + + let sigmoid = client.sigmoid(&self.saved_input)?; + let sigmoid_var = Var::new(sigmoid, false); + let grad = var_mul(grad_output, &sigmoid_var, &client)?; + + Ok(vec![Some(grad)]) + } + + fn inputs(&self) -> &[TensorId] { + std::slice::from_ref(&self.input_id) + } + + fn input_grad_fns(&self) -> Vec>>> { + vec![self.input_grad_fn.clone()] + } + + fn saved_tensors(&self) -> &[Tensor] { + std::slice::from_ref(&self.saved_input) + } + + fn name(&self) -> &'static str { + "SoftplusBackward" + } +} + #[cfg(test)] mod tests { use super::*; @@ -370,6 +633,247 @@ mod tests { assert!((grad_data[0] - 0.25).abs() < 1e-6); } + #[test] + fn test_silu_backward() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + // silu(0) = 0 * sigmoid(0) = 0 * 0.5 = 0 + // silu'(0) = sigmoid(0) * (1 + 0 * (1 - sigmoid(0))) = 0.5 * 1 = 0.5 + let input = Tensor::::from_slice(&[0.0f32], &[1], &device); + let output = client.silu(&input).unwrap(); + + let grad_out = Tensor::::ones(&[1], DType::F32, &device); + + let backward = SiluBackward::::new(input.id(), input.clone(), output, None); + let grads = backward.backward(&grad_out).unwrap(); + + let grad_data: Vec = grads[0].as_ref().unwrap().to_vec(); + assert!((grad_data[0] - 0.5).abs() < 1e-6); + } + + #[test] + fn test_silu_backward_nonzero() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + // silu(1) = 1 * sigmoid(1) ≈ 0.7311 + // silu'(1) = sigmoid(1) * (1 + 1 * (1 - sigmoid(1))) + // ≈ 0.7311 * (1 + 1 * 0.2689) ≈ 0.7311 * 1.2689 ≈ 0.9277 + let input = Tensor::::from_slice(&[1.0f32], &[1], &device); + let output = client.silu(&input).unwrap(); + + let grad_out = Tensor::::ones(&[1], DType::F32, &device); + + let backward = SiluBackward::::new(input.id(), input.clone(), output, None); + let grads = backward.backward(&grad_out).unwrap(); + + let grad_data: Vec = grads[0].as_ref().unwrap().to_vec(); + let sigmoid_1 = 1.0f32 / (1.0 + (-1.0f32).exp()); + let expected = sigmoid_1 * (1.0 + 1.0 * (1.0 - sigmoid_1)); + assert!((grad_data[0] - expected).abs() < 1e-5); + } + + #[test] + fn test_silu_backward_2d() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + // Shape [2, 3] — verifies element-wise gradient correctness on batched tensors. + // silu'(x) = sigmoid(x) * (1 + x - silu(x)) + let data = [-1.0f32, 0.0, 1.0, 2.0, -2.0, 0.5]; + let input = Tensor::::from_slice(&data, &[2, 3], &device); + let output = client.silu(&input).unwrap(); + let grad_out = Tensor::::ones(&[2, 3], DType::F32, &device); + + let backward = + SiluBackward::::new(input.id(), input.clone(), output.clone(), None); + let grads = backward.backward(&grad_out).unwrap(); + + let grad_data: Vec = grads[0].as_ref().unwrap().to_vec(); + let out_data: Vec = output.to_vec(); + + for (i, &x) in data.iter().enumerate() { + let sigmoid_x = 1.0f32 / (1.0 + (-x).exp()); + let expected = sigmoid_x * (1.0 + x - out_data[i]); + assert!( + (grad_data[i] - expected).abs() < 1e-5, + "mismatch at index {i}: got {}, expected {expected}", + grad_data[i] + ); + } + } + + #[test] + fn test_silu_backward_negative_gradient() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + // Verify chain rule: grad_output scales the derivative correctly. + let input = Tensor::::from_slice(&[1.0f32, -1.0], &[2], &device); + let output = client.silu(&input).unwrap(); + + // grad_output = [2.0, 3.0] — non-unit upstream gradient + let grad_out = Tensor::::from_slice(&[2.0f32, 3.0], &[2], &device); + + let backward = + SiluBackward::::new(input.id(), input.clone(), output.clone(), None); + let grads = backward.backward(&grad_out).unwrap(); + + let grad_data: Vec = grads[0].as_ref().unwrap().to_vec(); + let out_data: Vec = output.to_vec(); + let upstream = [2.0f32, 3.0]; + + for (i, (&x, &up)) in [1.0f32, -1.0].iter().zip(upstream.iter()).enumerate() { + let sigmoid_x = 1.0f32 / (1.0 + (-x).exp()); + let local_deriv = sigmoid_x * (1.0 + x - out_data[i]); + let expected = up * local_deriv; + assert!( + (grad_data[i] - expected).abs() < 1e-5, + "mismatch at index {i}: got {}, expected {expected}", + grad_data[i] + ); + } + } + + #[test] + fn test_softplus_backward() { + let device = CpuDevice::new(); + + // softplus(0) = log(1 + exp(0)) = log(2) ≈ 0.6931 + // softplus'(0) = sigmoid(0) = 0.5 + let input = Tensor::::from_slice(&[0.0f32], &[1], &device); + let grad_out = Tensor::::ones(&[1], DType::F32, &device); + + let backward = SoftplusBackward::::new(input.id(), input, None); + let grads = backward.backward(&grad_out).unwrap(); + + let grad_data: Vec = grads[0].as_ref().unwrap().to_vec(); + assert!((grad_data[0] - 0.5).abs() < 1e-6); + } + + #[test] + fn test_softplus_backward_nonzero() { + let device = CpuDevice::new(); + + // softplus'(x) = sigmoid(x) + let input = Tensor::::from_slice(&[1.0f32, -1.0, 2.0], &[3], &device); + let grad_out = Tensor::::ones(&[3], DType::F32, &device); + + let backward = SoftplusBackward::::new(input.id(), input, None); + let grads = backward.backward(&grad_out).unwrap(); + + let grad_data: Vec = grads[0].as_ref().unwrap().to_vec(); + for (i, &x) in [1.0f32, -1.0, 2.0].iter().enumerate() { + let expected = 1.0 / (1.0 + (-x).exp()); + assert!( + (grad_data[i] - expected).abs() < 1e-5, + "mismatch at {i}: got {}, expected {expected}", + grad_data[i] + ); + } + } + + #[test] + fn test_softplus_backward_large_positive() { + let device = CpuDevice::new(); + + // For large positive x, sigmoid(x) → 1.0; must not produce NaN. + // This exercises the numerical stability of the stable softplus formula. + let input = Tensor::::from_slice(&[100.0f32], &[1], &device); + let grad_out = Tensor::::ones(&[1], DType::F32, &device); + + let backward = SoftplusBackward::::new(input.id(), input, None); + let grads = backward.backward(&grad_out).unwrap(); + + let grad_data: Vec = grads[0].as_ref().unwrap().to_vec(); + assert!( + !grad_data[0].is_nan(), + "gradient must not be NaN for large positive input" + ); + assert!( + !grad_data[0].is_infinite(), + "gradient must not be Inf for large positive input" + ); + // sigmoid(100) ≈ 1.0 + assert!((grad_data[0] - 1.0).abs() < 1e-5); + } + + #[test] + fn test_softplus_backward_large_negative() { + let device = CpuDevice::new(); + + // For large negative x, sigmoid(x) → 0.0; must not produce NaN. + let input = Tensor::::from_slice(&[-100.0f32], &[1], &device); + let grad_out = Tensor::::ones(&[1], DType::F32, &device); + + let backward = SoftplusBackward::::new(input.id(), input, None); + let grads = backward.backward(&grad_out).unwrap(); + + let grad_data: Vec = grads[0].as_ref().unwrap().to_vec(); + assert!( + !grad_data[0].is_nan(), + "gradient must not be NaN for large negative input" + ); + assert!( + !grad_data[0].is_infinite(), + "gradient must not be Inf for large negative input" + ); + // sigmoid(-100) ≈ 0.0 + assert!(grad_data[0].abs() < 1e-5); + } + + #[test] + fn test_softplus_backward_2d() { + let device = CpuDevice::new(); + + // Shape [2, 3] — verifies element-wise gradient on batched tensors. + let data = [-2.0f32, -1.0, 0.0, 1.0, 2.0, 100.0]; + let input = Tensor::::from_slice(&data, &[2, 3], &device); + let grad_out = Tensor::::ones(&[2, 3], DType::F32, &device); + + let backward = SoftplusBackward::::new(input.id(), input, None); + let grads = backward.backward(&grad_out).unwrap(); + + let grad_data: Vec = grads[0].as_ref().unwrap().to_vec(); + for (i, &x) in data.iter().enumerate() { + let expected = 1.0f32 / (1.0 + (-x).exp()); + assert!( + !grad_data[i].is_nan(), + "gradient NaN at index {i} for x={x}" + ); + assert!( + (grad_data[i] - expected).abs() < 1e-4, + "mismatch at index {i} for x={x}: got {}, expected {expected}", + grad_data[i] + ); + } + } + + #[test] + fn test_softplus_backward_non_unit_gradient() { + let device = CpuDevice::new(); + + // Verify chain rule: upstream gradient scales local derivative. + let input = Tensor::::from_slice(&[0.0f32, 1.0], &[2], &device); + let grad_out = Tensor::::from_slice(&[2.0f32, 3.0], &[2], &device); + + let backward = SoftplusBackward::::new(input.id(), input, None); + let grads = backward.backward(&grad_out).unwrap(); + + let grad_data: Vec = grads[0].as_ref().unwrap().to_vec(); + let upstream = [2.0f32, 3.0]; + for (i, (&x, &up)) in [0.0f32, 1.0].iter().zip(upstream.iter()).enumerate() { + let sigmoid_x = 1.0f32 / (1.0 + (-x).exp()); + let expected = up * sigmoid_x; + assert!( + (grad_data[i] - expected).abs() < 1e-5, + "mismatch at index {i}: got {}, expected {expected}", + grad_data[i] + ); + } + } + #[test] fn test_softmax_backward() { let device = CpuDevice::new(); @@ -394,4 +898,33 @@ mod tests { assert!((grad_data[0] - 0.25).abs() < 1e-6); assert!((grad_data[1] - (-0.25)).abs() < 1e-6); } + + #[test] + fn test_log_softmax_backward() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + // Simple 2-element log_softmax + let input = Tensor::::from_slice(&[0.0f32, 0.0], &[2], &device); + let output = client.log_softmax(&input, -1).unwrap(); // [ln(0.5), ln(0.5)] + + let output_data: Vec = output.to_vec(); + let expected_log = (0.5f32).ln(); + assert!((output_data[0] - expected_log).abs() < 1e-6); + assert!((output_data[1] - expected_log).abs() < 1e-6); + + // dL/dz = [1, 0] + let grad_out = Tensor::::from_slice(&[1.0f32, 0.0], &[2], &device); + + let backward = LogSoftmaxBackward::::new(input.id(), output, -1, None); + let grads = backward.backward(&grad_out).unwrap(); + + let grad_data: Vec = grads[0].as_ref().unwrap().to_vec(); + // log_softmax gradient: grad = dy - exp(z) * sum(dy, dim) + // exp(z) = [0.5, 0.5], sum(dy) = 1.0 + // grad[0] = 1.0 - 0.5 * 1.0 = 0.5 + // grad[1] = 0.0 - 0.5 * 1.0 = -0.5 + assert!((grad_data[0] - 0.5).abs() < 1e-6); + assert!((grad_data[1] - (-0.5)).abs() < 1e-6); + } } diff --git a/src/autograd/ops/cast.rs b/src/autograd/ops/cast.rs new file mode 100644 index 00000000..75ae31c2 --- /dev/null +++ b/src/autograd/ops/cast.rs @@ -0,0 +1,53 @@ +//! Backward implementation for dtype cast operation +//! +//! The backward of cast(x, target_dtype) is cast(grad_output, input_dtype). + +use crate::autograd::GradFn; +use crate::dtype::DType; +use crate::error::Result; +use crate::ops::TypeConversionOps; +use crate::runtime::Runtime; +use crate::tensor::{Tensor, TensorId}; + +/// Backward for cast: z = cast(a, target_dtype) +/// +/// Gradient: dL/da = cast(dL/dz, a.dtype) +pub struct CastBackward { + input_id: TensorId, + input_dtype: DType, + _marker: std::marker::PhantomData, +} + +impl CastBackward { + /// Create a new CastBackward + pub fn new(input_id: TensorId, input_dtype: DType) -> Self { + Self { + input_id, + input_dtype, + _marker: std::marker::PhantomData, + } + } +} + +impl> GradFn for CastBackward +where + R::Client: TypeConversionOps, +{ + fn backward(&self, grad_output: &Tensor) -> Result>>> { + let client = R::default_client(grad_output.device()); + let grad = if grad_output.dtype() == self.input_dtype { + grad_output.clone() + } else { + client.cast(grad_output, self.input_dtype)? + }; + Ok(vec![Some(grad)]) + } + + fn inputs(&self) -> &[TensorId] { + std::slice::from_ref(&self.input_id) + } + + fn name(&self) -> &'static str { + "CastBackward" + } +} diff --git a/src/autograd/ops/gemm_epilogue.rs b/src/autograd/ops/gemm_epilogue.rs new file mode 100644 index 00000000..b8f54cab --- /dev/null +++ b/src/autograd/ops/gemm_epilogue.rs @@ -0,0 +1,237 @@ +//! Backward implementation for fused GEMM + bias + activation + +use crate::autograd::GradFn; +use crate::autograd::var::Var; +use crate::autograd::var_ops::{var_matmul, var_sum}; +use crate::error::Result; +use crate::ops::{BinaryOps, GemmActivation, MatmulOps, ReduceOps, ScalarOps, TensorOps, UnaryOps}; +use crate::runtime::{Runtime, RuntimeClient}; +use crate::tensor::{Tensor, TensorId}; +use std::sync::Arc; + +/// Backward for fused GEMM + bias + activation: output = activation(A @ B + bias) +pub struct MatmulBiasActivationBackward { + input_ids: [TensorId; 3], + saved_tensors: Vec>, // [a, b, bias] + activation: GemmActivation, + input_grad_fns: [Option>>; 3], +} + +impl MatmulBiasActivationBackward { + /// Create a new MatmulBiasActivationBackward + pub fn new( + a_id: TensorId, + b_id: TensorId, + bias_id: TensorId, + a: Tensor, + b: Tensor, + bias: Tensor, + activation: GemmActivation, + a_grad_fn: Option>>, + b_grad_fn: Option>>, + bias_grad_fn: Option>>, + ) -> Self { + Self { + input_ids: [a_id, b_id, bias_id], + saved_tensors: vec![a, b, bias], + activation, + input_grad_fns: [a_grad_fn, b_grad_fn, bias_grad_fn], + } + } +} + +impl GradFn for MatmulBiasActivationBackward +where + R::Client: + TensorOps + ScalarOps + BinaryOps + ReduceOps + UnaryOps + MatmulOps, +{ + fn backward(&self, grad_output: &Tensor) -> Result>>> { + let client = R::default_client(grad_output.device()); + let a = &self.saved_tensors[0]; + let b = &self.saved_tensors[1]; + let bias = &self.saved_tensors[2]; + + // Recompute pre_activation = A @ B + bias + let matmul_out = client.matmul(a, b)?; + let pre_act = client.add(&matmul_out, bias)?; + + // Compute activation gradient: grad_pre = grad_output * activation'(pre_act) + let grad_pre = apply_activation_grad(&client, grad_output, &pre_act, self.activation)?; + + // d_a = grad_pre @ B^T + let b_t = b.transpose(-2, -1)?; + let d_a = client.matmul(&grad_pre, &b_t)?; + + // d_b = A^T @ grad_pre + let a_t = a.transpose(-2, -1)?; + let d_b = client.matmul(&a_t, &grad_pre)?; + + // d_bias = sum(grad_pre, batch_and_row_dims) + let ndim = grad_output.ndim(); + let batch_dims: Vec = (0..ndim - 1).collect(); + let d_bias = if batch_dims.is_empty() { + grad_pre + } else { + client.sum(&grad_pre, &batch_dims, false)? + }; + + Ok(vec![Some(d_a), Some(d_b), Some(d_bias)]) + } + + fn backward_var(&self, grad_output: &Var) -> Result>>> + where + R::Client: RuntimeClient + + TensorOps + + ScalarOps + + BinaryOps + + ReduceOps + + UnaryOps + + MatmulOps, + { + let client = R::default_client(grad_output.tensor().device()); + let a = &self.saved_tensors[0]; + let b = &self.saved_tensors[1]; + let bias = &self.saved_tensors[2]; + + // Recompute pre_activation from saved tensors + let matmul_out = client.matmul(a, b)?; + let pre_act = client.add(&matmul_out, bias)?; + + // Compute activation gradient as a constant tensor + let ones = client.add_scalar(&client.mul_scalar(&pre_act, 0.0)?, 1.0)?; + let act_grad = apply_activation_grad(&client, &ones, &pre_act, self.activation)?; + + // grad_pre = grad_output * activation'(pre_act) + let act_grad_var = Var::new(act_grad, false); + let grad_pre = crate::autograd::var_ops::var_mul(grad_output, &act_grad_var, &client)?; + + // d_a = grad_pre @ B^T + let b_t = b.transpose(-2, -1)?; + let b_t_var = Var::new(b_t, false); + let d_a = var_matmul(&grad_pre, &b_t_var, &client)?; + + // d_b = A^T @ grad_pre + let a_t = a.transpose(-2, -1)?; + let a_t_var = Var::new(a_t, false); + let d_b = var_matmul(&a_t_var, &grad_pre, &client)?; + + // d_bias = sum(grad_pre, batch_dims) + let ndim = grad_output.tensor().ndim(); + let batch_dims: Vec = (0..ndim - 1).collect(); + let d_bias = if batch_dims.is_empty() { + grad_pre + } else { + var_sum(&grad_pre, &batch_dims, false, &client)? + }; + + Ok(vec![Some(d_a), Some(d_b), Some(d_bias)]) + } + + fn inputs(&self) -> &[TensorId] { + &self.input_ids + } + + fn input_grad_fns(&self) -> Vec>>> { + self.input_grad_fns.to_vec() + } + + fn saved_tensors(&self) -> &[Tensor] { + &self.saved_tensors + } + + fn name(&self) -> &'static str { + "MatmulBiasActivationBackward" + } +} + +/// Compute grad_output * activation'(pre_act) using only basic ops +fn apply_activation_grad( + client: &R::Client, + grad: &Tensor, + pre_act: &Tensor, + activation: GemmActivation, +) -> Result> +where + R::Client: TensorOps + ScalarOps + BinaryOps + UnaryOps, +{ + match activation { + GemmActivation::None => { + // Identity: derivative is 1, so just return grad + Ok(grad.clone()) + } + GemmActivation::ReLU => { + // ReLU': 1 if x > 0, 0 if x <= 0 + // Approximate mask: clamp(sign(x), 0, 1) using: (x + |x|) / (2 * |x| + eps) + // Simpler: use step = (sign(x) + 1) / 2 where sign uses abs + let abs_x = client.abs(pre_act)?; + // For x > 0: sign = x/|x| = 1, for x < 0: sign = -1, x=0: 0 + let abs_plus_eps = client.add_scalar(&abs_x, 1e-30)?; + let sign = client.div(pre_act, &abs_plus_eps)?; + // mask = (sign + 1) / 2: maps 1->1, -1->0, 0->0.5 (close enough) + let mask = client.mul_scalar(&client.add_scalar(&sign, 1.0)?, 0.5)?; + client.mul(grad, &mask) + } + GemmActivation::Sigmoid => { + // sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)) + // sigmoid(x) = 1 / (1 + exp(-x)) + let neg_x = client.neg(pre_act)?; + let exp_neg = client.exp(&neg_x)?; + let one_plus = client.add_scalar(&exp_neg, 1.0)?; + let sig = client.recip(&one_plus)?; + let one_minus_sig = client.rsub_scalar(&sig, 1.0)?; + let deriv = client.mul(&sig, &one_minus_sig)?; + client.mul(grad, &deriv) + } + GemmActivation::Tanh => { + // tanh'(x) = 1 - tanh(x)^2 + let t = client.tanh(pre_act)?; + let t_sq = client.mul(&t, &t)?; + let deriv = client.rsub_scalar(&t_sq, 1.0)?; + client.mul(grad, &deriv) + } + GemmActivation::SiLU => { + // silu(x) = x * sigmoid(x) + // silu'(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))) + let neg_x = client.neg(pre_act)?; + let exp_neg = client.exp(&neg_x)?; + let one_plus = client.add_scalar(&exp_neg, 1.0)?; + let sig = client.recip(&one_plus)?; + let one_minus_sig = client.rsub_scalar(&sig, 1.0)?; + let x_one_minus_sig = client.mul(pre_act, &one_minus_sig)?; + let inner = client.add_scalar(&x_one_minus_sig, 1.0)?; + let deriv = client.mul(&sig, &inner)?; + client.mul(grad, &deriv) + } + GemmActivation::GELU => { + // GELU(x) = 0.5 * x * (1 + tanh(k)), k = sqrt(2/pi) * (x + 0.044715 * x^3) + // d/dx = 0.5 * (1 + tanh(k)) + 0.5 * x * sech²(k) * dk/dx + // dk/dx = sqrt(2/pi) * (1 + 3*0.044715*x²) + let sqrt_2_pi: f64 = (2.0f64 / std::f64::consts::PI).sqrt(); + let x_sq = client.mul(pre_act, pre_act)?; + let x_cubed = client.mul(pre_act, &x_sq)?; + let inner = client.add(pre_act, &client.mul_scalar(&x_cubed, 0.044715)?)?; + let k = client.mul_scalar(&inner, sqrt_2_pi)?; + let tanh_k = client.tanh(&k)?; + + // 0.5 * (1 + tanh(k)) + let term1 = client.mul_scalar(&client.add_scalar(&tanh_k, 1.0)?, 0.5)?; + + // sech²(k) = 1 - tanh²(k) + let tanh_sq = client.mul(&tanh_k, &tanh_k)?; + let sech_sq = client.rsub_scalar(&tanh_sq, 1.0)?; + + // dk/dx = sqrt(2/pi) * (1 + 3 * 0.044715 * x²) + let dk_dx = client.mul_scalar( + &client.add_scalar(&client.mul_scalar(&x_sq, 3.0 * 0.044715)?, 1.0)?, + sqrt_2_pi, + )?; + + // 0.5 * x * sech²(k) * dk/dx + let term2 = + client.mul_scalar(&client.mul(pre_act, &client.mul(&sech_sq, &dk_dx)?)?, 0.5)?; + + let deriv = client.add(&term1, &term2)?; + client.mul(grad, &deriv) + } + } +} diff --git a/src/autograd/ops/indexing.rs b/src/autograd/ops/indexing.rs index d901bcad..de5ff719 100644 --- a/src/autograd/ops/indexing.rs +++ b/src/autograd/ops/indexing.rs @@ -2,6 +2,7 @@ use crate::autograd::GradFn; use crate::autograd::var::Var; +use crate::dtype::DType; use crate::error::Result; use crate::ops::IndexingOps; use crate::runtime::Runtime; @@ -42,7 +43,7 @@ impl GatherBackward { } } -impl GradFn for GatherBackward +impl> GradFn for GatherBackward where R::Client: IndexingOps, { diff --git a/src/autograd/ops/linalg.rs b/src/autograd/ops/linalg.rs index 32820cdd..ca0ccdcc 100644 --- a/src/autograd/ops/linalg.rs +++ b/src/autograd/ops/linalg.rs @@ -20,6 +20,7 @@ use crate::algorithm::LinearAlgebraAlgorithms; use crate::autograd::var_ops::{var_matmul, var_mul, var_neg}; use crate::autograd::{GradFn, Var}; +use crate::dtype::DType; use crate::error::Result; use crate::ops::{ BinaryOps, LinalgOps, MatmulOps, ScalarOps, TensorOps, TypeConversionOps, UnaryOps, @@ -44,7 +45,10 @@ use std::sync::Arc; /// # Returns /// A tensor where upper triangular elements are zero, lower triangular elements /// are unchanged, and diagonal elements are halved. -fn tril_with_halved_diagonal(x: &Tensor, client: &R::Client) -> Result> +fn tril_with_halved_diagonal>( + x: &Tensor, + client: &R::Client, +) -> Result> where R::Client: TensorOps + ScalarOps, { @@ -100,7 +104,7 @@ impl TraceBackward { } } -impl GradFn for TraceBackward +impl> GradFn for TraceBackward where R::Client: TensorOps + ScalarOps + LinearAlgebraAlgorithms, { @@ -523,7 +527,7 @@ impl CholeskyBackward { } } -impl GradFn for CholeskyBackward +impl> GradFn for CholeskyBackward where R::Client: MatmulOps + TensorOps + ScalarOps + LinearAlgebraAlgorithms, { diff --git a/src/autograd/ops/mod.rs b/src/autograd/ops/mod.rs index 2499c22e..63d3b8dd 100644 --- a/src/autograd/ops/mod.rs +++ b/src/autograd/ops/mod.rs @@ -16,10 +16,13 @@ mod activation; mod arithmetic; +mod cast; mod cumulative; +mod gemm_epilogue; mod indexing; mod linalg; mod matmul; +mod normalization; mod reduce; mod scalar; mod shape; @@ -27,10 +30,13 @@ mod unary; pub use activation::*; pub use arithmetic::*; +pub use cast::*; pub use cumulative::*; +pub use gemm_epilogue::*; pub use indexing::*; pub use linalg::*; pub use matmul::*; +pub use normalization::*; pub use reduce::*; pub use scalar::*; pub use shape::*; diff --git a/src/autograd/ops/normalization/fused_add_layer_norm.rs b/src/autograd/ops/normalization/fused_add_layer_norm.rs new file mode 100644 index 00000000..0e6bf148 --- /dev/null +++ b/src/autograd/ops/normalization/fused_add_layer_norm.rs @@ -0,0 +1,223 @@ +//! Backward implementation for Fused Add + Layer Normalization + +use crate::autograd::GradFn; +use crate::autograd::var::Var; +use crate::autograd::var_ops::{var_mean, var_mul, var_sub, var_sum}; +use crate::error::Result; +use crate::ops::{BinaryOps, NormalizationOps, ReduceOps, ScalarOps, TensorOps, UnaryOps}; +use crate::runtime::{Runtime, RuntimeClient}; +use crate::tensor::{Tensor, TensorId}; +use std::sync::Arc; + +/// Backward for Fused Add + Layer Normalization: +/// pre_norm = x + residual, output = layer_norm(pre_norm, weight, bias, eps) +/// +/// Gradients: +/// - d_input_residual = shared gradient for both x and residual +/// - d_weight = sum(grad_out * x_norm, batch_dims) +/// - d_bias = sum(grad_out, batch_dims) +pub struct FusedAddLayerNormBackward { + input_ids: [TensorId; 4], + saved_tensors: Vec>, // [pre_norm, weight, bias] + eps: f32, + input_grad_fns: [Option>>; 4], +} + +impl FusedAddLayerNormBackward { + /// Create a new FusedAddLayerNormBackward + pub fn new( + x_id: TensorId, + residual_id: TensorId, + weight_id: TensorId, + bias_id: TensorId, + pre_norm: Tensor, + weight: Tensor, + bias: Tensor, + eps: f32, + x_grad_fn: Option>>, + residual_grad_fn: Option>>, + weight_grad_fn: Option>>, + bias_grad_fn: Option>>, + ) -> Self { + Self { + input_ids: [x_id, residual_id, weight_id, bias_id], + saved_tensors: vec![pre_norm, weight, bias], + eps, + input_grad_fns: [x_grad_fn, residual_grad_fn, weight_grad_fn, bias_grad_fn], + } + } +} + +impl GradFn for FusedAddLayerNormBackward +where + R::Client: TensorOps + ScalarOps + BinaryOps + ReduceOps + UnaryOps, +{ + fn backward(&self, grad_output: &Tensor) -> Result>>> { + let client = R::default_client(grad_output.device()); + let pre_norm = &self.saved_tensors[0]; + let weight = &self.saved_tensors[1]; + let bias = &self.saved_tensors[2]; + + let (d_input_residual, d_weight, d_bias) = + client.fused_add_layer_norm_bwd(grad_output, pre_norm, weight, bias, self.eps)?; + + Ok(vec![ + Some(d_input_residual.clone()), + Some(d_input_residual), + Some(d_weight), + Some(d_bias), + ]) + } + + fn backward_var(&self, grad_output: &Var) -> Result>>> + where + R::Client: RuntimeClient + + TensorOps + + ScalarOps + + BinaryOps + + ReduceOps + + UnaryOps, + { + let client = R::default_client(grad_output.tensor().device()); + let pre_norm = &self.saved_tensors[0]; + let weight = &self.saved_tensors[1]; + let ndim = pre_norm.ndim(); + let last_dim = ndim - 1; + + // Recompute rstd and x_norm from pre_norm (treat as constants) + let mu = client.mean(pre_norm, &[last_dim], true)?; + let x_centered = client.sub(pre_norm, &mu)?; + let x_centered_sq = client.mul(&x_centered, &x_centered)?; + let variance = client.mean(&x_centered_sq, &[last_dim], true)?; + let variance_eps = client.add_scalar(&variance, self.eps as f64)?; + let std = client.sqrt(&variance_eps)?; + let rstd = client.recip(&std)?; + let x_norm = client.mul(&x_centered, &rstd)?; + + // Wrap as non-differentiable Vars + let rstd_var = Var::new(rstd, false); + let x_norm_var = Var::new(x_norm, false); + let weight_var = Var::new(weight.clone(), false); + + // d_input_residual = rstd * (gw - mean(gw) - x_norm * mean(gw * x_norm)) + let gw = var_mul(grad_output, &weight_var, &client)?; + let mean_gw = var_mean(&gw, &[last_dim], true, &client)?; + let gw_xn = var_mul(&gw, &x_norm_var, &client)?; + let mean_gw_xn = var_mean(&gw_xn, &[last_dim], true, &client)?; + let xn_mean_gw_xn = var_mul(&x_norm_var, &mean_gw_xn, &client)?; + let inner = var_sub(&gw, &mean_gw, &client)?; + let inner = var_sub(&inner, &xn_mean_gw_xn, &client)?; + let d_input_residual = var_mul(&inner, &rstd_var, &client)?; + + // d_weight = sum(grad_output * x_norm, batch_dims) + let g_xn = var_mul(grad_output, &x_norm_var, &client)?; + let batch_dims: Vec = (0..last_dim).collect(); + let d_weight = if batch_dims.is_empty() { + g_xn + } else { + var_sum(&g_xn, &batch_dims, false, &client)? + }; + + // d_bias = sum(grad_output, batch_dims) + let d_bias = if batch_dims.is_empty() { + grad_output.clone() + } else { + var_sum(grad_output, &batch_dims, false, &client)? + }; + + Ok(vec![ + Some(d_input_residual.clone()), + Some(d_input_residual), + Some(d_weight), + Some(d_bias), + ]) + } + + fn inputs(&self) -> &[TensorId] { + &self.input_ids + } + + fn input_grad_fns(&self) -> Vec>>> { + self.input_grad_fns.to_vec() + } + + fn saved_tensors(&self) -> &[Tensor] { + &self.saved_tensors + } + + fn name(&self) -> &'static str { + "FusedAddLayerNormBackward" + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::runtime::cpu::{CpuDevice, CpuRuntime}; + + #[test] + fn test_fused_add_layer_norm_backward_basic() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + let pre_norm = Tensor::::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[1, 4], &device); + let weight = Tensor::::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[4], &device); + let bias = Tensor::::from_slice(&[0.0f32, 0.0, 0.0, 0.0], &[4], &device); + let eps = 1e-5f32; + + let grad_out = Tensor::::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[1, 4], &device); + + let (d_input_residual, d_weight, d_bias) = client + .fused_add_layer_norm_bwd(&grad_out, &pre_norm, &weight, &bias, eps) + .unwrap(); + + let di: Vec = d_input_residual.to_vec(); + let dw: Vec = d_weight.to_vec(); + let db: Vec = d_bias.to_vec(); + + // d_input for uniform grad through layer norm should sum to ~0 + let sum: f32 = di.iter().sum(); + assert!( + sum.abs() < 1e-5, + "d_input_residual sum should be ~0, got {sum}" + ); + + for val in dw.iter().chain(db.iter()) { + assert!(val.is_finite()); + } + } + + #[test] + fn test_fused_add_layer_norm_backward_shared_gradient() { + let device = CpuDevice::new(); + + let pre_norm = Tensor::::from_slice(&[1.0f32, 2.0, 3.0], &[1, 3], &device); + let weight = Tensor::::from_slice(&[1.0f32, 1.0, 1.0], &[3], &device); + let bias = Tensor::::from_slice(&[0.0f32, 0.0, 0.0], &[3], &device); + + let grad_out = Tensor::::from_slice(&[1.0f32, 1.0, 1.0], &[1, 3], &device); + + let backward = FusedAddLayerNormBackward::::new( + TensorId::new(), + TensorId::new(), + weight.id(), + bias.id(), + pre_norm, + weight, + bias, + 1e-5, + None, + None, + None, + None, + ); + let grads = backward.backward(&grad_out).unwrap(); + + assert_eq!(grads.len(), 4); + let d_x: Vec = grads[0].as_ref().unwrap().to_vec(); + let d_r: Vec = grads[1].as_ref().unwrap().to_vec(); + for (a, b) in d_x.iter().zip(d_r.iter()) { + assert!((a - b).abs() < 1e-10, "x and residual grads must match"); + } + } +} diff --git a/src/autograd/ops/normalization/fused_add_rms_norm.rs b/src/autograd/ops/normalization/fused_add_rms_norm.rs new file mode 100644 index 00000000..b053e00e --- /dev/null +++ b/src/autograd/ops/normalization/fused_add_rms_norm.rs @@ -0,0 +1,196 @@ +//! Backward implementation for Fused Add + RMS Normalization + +use crate::autograd::GradFn; +use crate::autograd::var::Var; +use crate::autograd::var_ops::{var_mean, var_mul, var_sub, var_sum}; +use crate::error::Result; +use crate::ops::{BinaryOps, NormalizationOps, ReduceOps, ScalarOps, TensorOps, UnaryOps}; +use crate::runtime::{Runtime, RuntimeClient}; +use crate::tensor::{Tensor, TensorId}; +use std::sync::Arc; + +/// Backward for Fused Add + RMS Normalization: pre_norm = x + residual, output = rms_norm(pre_norm, weight, eps) +/// +/// Gradients: +/// - d_input_residual = shared gradient for both x and residual (since d(x+r)/dx = d(x+r)/dr = 1) +/// - d_weight = sum(grad_out * x_norm, batch_dims) +pub struct FusedAddRmsNormBackward { + input_ids: [TensorId; 3], + saved_tensors: Vec>, // [pre_norm, weight] + eps: f32, + input_grad_fns: [Option>>; 3], +} + +impl FusedAddRmsNormBackward { + /// Create a new FusedAddRmsNormBackward + pub fn new( + x_id: TensorId, + residual_id: TensorId, + weight_id: TensorId, + pre_norm: Tensor, + weight: Tensor, + eps: f32, + x_grad_fn: Option>>, + residual_grad_fn: Option>>, + weight_grad_fn: Option>>, + ) -> Self { + Self { + input_ids: [x_id, residual_id, weight_id], + saved_tensors: vec![pre_norm, weight], + eps, + input_grad_fns: [x_grad_fn, residual_grad_fn, weight_grad_fn], + } + } +} + +impl GradFn for FusedAddRmsNormBackward +where + R::Client: TensorOps + ScalarOps + BinaryOps + ReduceOps + UnaryOps, +{ + fn backward(&self, grad_output: &Tensor) -> Result>>> { + let client = R::default_client(grad_output.device()); + let pre_norm = &self.saved_tensors[0]; + let weight = &self.saved_tensors[1]; + + let (d_input_residual, d_weight) = + client.fused_add_rms_norm_bwd(grad_output, pre_norm, weight, self.eps)?; + + // x and residual share the same gradient + Ok(vec![ + Some(d_input_residual.clone()), + Some(d_input_residual), + Some(d_weight), + ]) + } + + fn backward_var(&self, grad_output: &Var) -> Result>>> + where + R::Client: RuntimeClient + + TensorOps + + ScalarOps + + BinaryOps + + ReduceOps + + UnaryOps, + { + let client = R::default_client(grad_output.tensor().device()); + let pre_norm = &self.saved_tensors[0]; + let weight = &self.saved_tensors[1]; + let ndim = pre_norm.ndim(); + let last_dim = ndim - 1; + + // Recompute rstd and x_norm from pre_norm (treat as constants) + let x_sq = client.mul(pre_norm, pre_norm)?; + let mean_x_sq = client.mean(&x_sq, &[last_dim], true)?; + let variance_eps = client.add_scalar(&mean_x_sq, self.eps as f64)?; + let rms = client.sqrt(&variance_eps)?; + let rstd = client.recip(&rms)?; + let x_norm = client.mul(pre_norm, &rstd)?; + + // Wrap as non-differentiable Vars + let rstd_var = Var::new(rstd, false); + let x_norm_var = Var::new(x_norm, false); + let weight_var = Var::new(weight.clone(), false); + + // d_input_residual = rstd * (grad_output * weight - x_norm * mean(grad_output * weight * x_norm)) + let gw = var_mul(grad_output, &weight_var, &client)?; + let gw_xn = var_mul(&gw, &x_norm_var, &client)?; + let mean_gw_xn = var_mean(&gw_xn, &[last_dim], true, &client)?; + let correction = var_mul(&x_norm_var, &mean_gw_xn, &client)?; + let inner = var_sub(&gw, &correction, &client)?; + let d_input_residual = var_mul(&inner, &rstd_var, &client)?; + + // d_weight = sum(grad_output * x_norm, batch_dims) + let g_xn = var_mul(grad_output, &x_norm_var, &client)?; + let batch_dims: Vec = (0..last_dim).collect(); + let d_weight = if batch_dims.is_empty() { + g_xn + } else { + var_sum(&g_xn, &batch_dims, false, &client)? + }; + + Ok(vec![ + Some(d_input_residual.clone()), + Some(d_input_residual), + Some(d_weight), + ]) + } + + fn inputs(&self) -> &[TensorId] { + &self.input_ids + } + + fn input_grad_fns(&self) -> Vec>>> { + self.input_grad_fns.to_vec() + } + + fn saved_tensors(&self) -> &[Tensor] { + &self.saved_tensors + } + + fn name(&self) -> &'static str { + "FusedAddRmsNormBackward" + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::runtime::cpu::{CpuDevice, CpuRuntime}; + + #[test] + fn test_fused_add_rms_norm_backward_basic() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + let pre_norm = Tensor::::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[1, 4], &device); + let weight = Tensor::::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[4], &device); + let eps = 1e-5f32; + + let grad_out = Tensor::::from_slice(&[1.0f32, 0.0, 0.0, 0.0], &[1, 4], &device); + + let (d_input_residual, d_weight) = client + .fused_add_rms_norm_bwd(&grad_out, &pre_norm, &weight, eps) + .unwrap(); + + let di: Vec = d_input_residual.to_vec(); + let dw: Vec = d_weight.to_vec(); + + for val in &di { + assert!(val.is_finite(), "d_input_residual should be finite"); + } + for val in &dw { + assert!(val.is_finite(), "d_weight should be finite"); + } + } + + #[test] + fn test_fused_add_rms_norm_backward_shared_gradient() { + let device = CpuDevice::new(); + + let pre_norm = Tensor::::from_slice(&[1.0f32, 2.0, 3.0], &[1, 3], &device); + let weight = Tensor::::from_slice(&[1.0f32, 1.0, 1.0], &[3], &device); + + let grad_out = Tensor::::from_slice(&[1.0f32, 1.0, 1.0], &[1, 3], &device); + + let backward = FusedAddRmsNormBackward::::new( + TensorId::new(), + TensorId::new(), + weight.id(), + pre_norm, + weight, + 1e-5, + None, + None, + None, + ); + let grads = backward.backward(&grad_out).unwrap(); + + assert_eq!(grads.len(), 3); + // x and residual gradients should be identical + let d_x: Vec = grads[0].as_ref().unwrap().to_vec(); + let d_r: Vec = grads[1].as_ref().unwrap().to_vec(); + for (a, b) in d_x.iter().zip(d_r.iter()) { + assert!((a - b).abs() < 1e-10, "x and residual grads must match"); + } + } +} diff --git a/src/autograd/ops/normalization/group_norm.rs b/src/autograd/ops/normalization/group_norm.rs new file mode 100644 index 00000000..82f807f6 --- /dev/null +++ b/src/autograd/ops/normalization/group_norm.rs @@ -0,0 +1,142 @@ +//! Backward implementation for Group Normalization + +use crate::autograd::GradFn; +use crate::autograd::var::Var; +use crate::error::Result; +use crate::ops::{BinaryOps, ReduceOps, ScalarOps, TensorOps, UnaryOps}; +use crate::runtime::{Runtime, RuntimeClient}; +use crate::tensor::{Tensor, TensorId}; +use std::sync::Arc; + +/// Backward for Group Normalization. +/// +/// Input shape: `[B, C, *spatial]`. Normalizes over (C/G, *spatial) per group. +/// +/// Gradients: +/// - d_input: similar to layer_norm but per-group +/// - d_weight = sum(grad_out * x_norm, batch_and_spatial_dims) +/// - d_bias = sum(grad_out, batch_and_spatial_dims) +pub struct GroupNormBackward { + input_ids: [TensorId; 3], // [input, weight, bias] + saved_input: Tensor, + saved_weight: Tensor, + num_groups: usize, + eps: f32, + input_grad_fns: [Option>>; 3], +} + +impl GroupNormBackward { + /// Create a new GroupNormBackward + pub fn new( + input_id: TensorId, + weight_id: TensorId, + bias_id: TensorId, + input: Tensor, + weight: Tensor, + num_groups: usize, + eps: f32, + input_grad_fn: Option>>, + weight_grad_fn: Option>>, + bias_grad_fn: Option>>, + ) -> Self { + Self { + input_ids: [input_id, weight_id, bias_id], + saved_input: input, + saved_weight: weight, + num_groups, + eps, + input_grad_fns: [input_grad_fn, weight_grad_fn, bias_grad_fn], + } + } +} + +impl GradFn for GroupNormBackward +where + R::Client: TensorOps + ScalarOps + ReduceOps + BinaryOps + UnaryOps, +{ + fn backward(&self, grad_output: &Tensor) -> Result>>> { + let client = R::default_client(grad_output.device()); + let input = &self.saved_input; + let weight = &self.saved_weight; + let shape = input.shape(); + let batch = shape[0]; + let channels = shape[1]; + let cpg = channels / self.num_groups; + let spatial: usize = shape[2..].iter().product::().max(1); + let group_size = cpg * spatial; + + // Flatten to [B, G, C/G * spatial] for per-group normalization + let flat_shape = [batch, self.num_groups, group_size]; + let input_flat = input.reshape(&flat_shape)?; + let grad_flat = grad_output.reshape(&flat_shape)?; + + // Per-group mean and variance: reduce over dim 2 + let mu = client.mean(&input_flat, &[2], true)?; + let x_centered = client.sub(&input_flat, &mu)?; + let x_sq = client.mul(&x_centered, &x_centered)?; + let variance = client.mean(&x_sq, &[2], true)?; + let var_eps = client.add_scalar(&variance, self.eps as f64)?; + let std = client.sqrt(&var_eps)?; + let rstd = client.recip(&std)?; + let x_norm_flat = client.mul(&x_centered, &rstd)?; + + // Reshape weight [C] → [1, G, cpg, 1] → broadcast → [1, G, cpg, spatial] → [1, G, group_size] + let weight_4d = weight.reshape(&[1, self.num_groups, cpg, 1])?; + let weight_bcast = weight_4d + .broadcast_to(&[1, self.num_groups, cpg, spatial])? + .contiguous(); + let weight_flat = weight_bcast.reshape(&[1, self.num_groups, group_size])?; + + // d_input (per-group layer norm backward) + let gw = client.mul(&grad_flat, &weight_flat)?; + let mean_gw = client.mean(&gw, &[2], true)?; + let gw_xn = client.mul(&gw, &x_norm_flat)?; + let mean_gw_xn = client.mean(&gw_xn, &[2], true)?; + let xn_correction = client.mul(&x_norm_flat, &mean_gw_xn)?; + let inner = client.sub(&gw, &mean_gw)?; + let inner = client.sub(&inner, &xn_correction)?; + let d_input_flat = client.mul(&inner, &rstd)?; + let d_input = d_input_flat.reshape(shape)?; + + // x_norm reshaped back to [B, C, spatial] + let x_norm_bcs = x_norm_flat.reshape(&[batch, channels, spatial])?; + let grad_bcs = grad_output.reshape(&[batch, channels, spatial])?; + + // d_weight = sum(grad * x_norm, dims=[0, 2]) → [C] + let gxn = client.mul(&grad_bcs, &x_norm_bcs)?; + let d_weight = client.sum(&gxn, &[0, 2], false)?; + + // d_bias = sum(grad, dims=[0, 2]) → [C] + let d_bias = client.sum(&grad_bcs, &[0, 2], false)?; + + Ok(vec![Some(d_input), Some(d_weight), Some(d_bias)]) + } + + fn backward_var(&self, grad_output: &Var) -> Result>>> + where + R::Client: RuntimeClient + TensorOps + ScalarOps, + { + // For higher-order gradients, fall back to tensor backward wrapped in Var + let grads = self.backward(grad_output.tensor())?; + Ok(grads + .into_iter() + .map(|g| g.map(|t| Var::new(t, false))) + .collect()) + } + + fn inputs(&self) -> &[TensorId] { + &self.input_ids + } + + fn input_grad_fns(&self) -> Vec>>> { + self.input_grad_fns.to_vec() + } + + fn saved_tensors(&self) -> &[Tensor] { + std::slice::from_ref(&self.saved_input) + } + + fn name(&self) -> &'static str { + "GroupNormBackward" + } +} diff --git a/src/autograd/ops/normalization/layer_norm.rs b/src/autograd/ops/normalization/layer_norm.rs new file mode 100644 index 00000000..e9a7f3ea --- /dev/null +++ b/src/autograd/ops/normalization/layer_norm.rs @@ -0,0 +1,242 @@ +//! Backward implementation for Layer Normalization + +use crate::autograd::GradFn; +use crate::autograd::var::Var; +use crate::autograd::var_ops::{var_mean, var_mul, var_sub, var_sum}; +use crate::error::Result; +use crate::ops::{BinaryOps, ReduceOps, ScalarOps, TensorOps, UnaryOps}; +use crate::runtime::{Runtime, RuntimeClient}; +use crate::tensor::{Tensor, TensorId}; +use std::sync::Arc; + +/// Backward for Layer Normalization: y = (x - mean(x)) / sqrt(var(x) + eps) * weight + bias +/// +/// Gradients: +/// - d_input = rstd * (gw - mean(gw) - x_norm * mean(gw * x_norm)) +/// - d_weight = sum(grad_out * x_norm, batch_dims) +/// - d_bias = sum(grad_out, batch_dims) +/// +/// Where gw = grad_out * weight, rstd = 1/sqrt(var+eps), x_norm = (x-mean)*rstd +pub struct LayerNormBackward { + input_ids: [TensorId; 3], + saved_tensors: Vec>, // [input, weight] + eps: f32, + input_grad_fns: [Option>>; 3], +} + +impl LayerNormBackward { + /// Create a new LayerNormBackward + pub fn new( + input_id: TensorId, + weight_id: TensorId, + bias_id: TensorId, + input: Tensor, + weight: Tensor, + eps: f32, + input_grad_fn: Option>>, + weight_grad_fn: Option>>, + bias_grad_fn: Option>>, + ) -> Self { + Self { + input_ids: [input_id, weight_id, bias_id], + saved_tensors: vec![input, weight], + eps, + input_grad_fns: [input_grad_fn, weight_grad_fn, bias_grad_fn], + } + } +} + +impl GradFn for LayerNormBackward +where + R::Client: TensorOps + ScalarOps + BinaryOps + ReduceOps + UnaryOps, +{ + fn backward(&self, grad_output: &Tensor) -> Result>>> { + let client = R::default_client(grad_output.device()); + let saved_input = &self.saved_tensors[0]; + let saved_weight = &self.saved_tensors[1]; + let ndim = saved_input.ndim(); + let last_dim = ndim - 1; + + // Recompute rstd and x_norm + let mu = client.mean(saved_input, &[last_dim], true)?; + let x_centered = client.sub(saved_input, &mu)?; + let x_centered_sq = client.mul(&x_centered, &x_centered)?; + let variance = client.mean(&x_centered_sq, &[last_dim], true)?; + let variance_eps = client.add_scalar(&variance, self.eps as f64)?; + let std = client.sqrt(&variance_eps)?; + let rstd = client.recip(&std)?; + let x_norm = client.mul(&x_centered, &rstd)?; + + // d_input = rstd * (gw - mean(gw) - x_norm * mean(gw * x_norm)) + let gw = client.mul(grad_output, saved_weight)?; + let mean_gw = client.mean(&gw, &[last_dim], true)?; + let gw_xn = client.mul(&gw, &x_norm)?; + let mean_gw_xn = client.mean(&gw_xn, &[last_dim], true)?; + let xn_mean_gw_xn = client.mul(&x_norm, &mean_gw_xn)?; + let inner = client.sub(&gw, &mean_gw)?; + let inner = client.sub(&inner, &xn_mean_gw_xn)?; + let d_input = client.mul(&inner, &rstd)?; + + // d_weight = sum(grad_output * x_norm, batch_dims) + let g_xn = client.mul(grad_output, &x_norm)?; + let batch_dims: Vec = (0..last_dim).collect(); + let d_weight = if batch_dims.is_empty() { + g_xn + } else { + client.sum(&g_xn, &batch_dims, false)? + }; + + // d_bias = sum(grad_output, batch_dims) + let d_bias = if batch_dims.is_empty() { + grad_output.clone() + } else { + client.sum(grad_output, &batch_dims, false)? + }; + + Ok(vec![Some(d_input), Some(d_weight), Some(d_bias)]) + } + + fn backward_var(&self, grad_output: &Var) -> Result>>> + where + R::Client: RuntimeClient + + TensorOps + + ScalarOps + + BinaryOps + + ReduceOps + + UnaryOps, + { + let client = R::default_client(grad_output.tensor().device()); + let saved_input = &self.saved_tensors[0]; + let saved_weight = &self.saved_tensors[1]; + let ndim = saved_input.ndim(); + let last_dim = ndim - 1; + + // Recompute from saved tensors (constants w.r.t. grad_output) + let mu = client.mean(saved_input, &[last_dim], true)?; + let x_centered = client.sub(saved_input, &mu)?; + let x_centered_sq = client.mul(&x_centered, &x_centered)?; + let variance = client.mean(&x_centered_sq, &[last_dim], true)?; + let variance_eps = client.add_scalar(&variance, self.eps as f64)?; + let std = client.sqrt(&variance_eps)?; + let rstd = client.recip(&std)?; + let x_norm = client.mul(&x_centered, &rstd)?; + + // Wrap as non-differentiable Vars + let rstd_var = Var::new(rstd, false); + let x_norm_var = Var::new(x_norm, false); + let weight_var = Var::new(saved_weight.clone(), false); + + // d_input = rstd * (gw - mean(gw) - x_norm * mean(gw * x_norm)) + let gw = var_mul(grad_output, &weight_var, &client)?; + let mean_gw = var_mean(&gw, &[last_dim], true, &client)?; + let gw_xn = var_mul(&gw, &x_norm_var, &client)?; + let mean_gw_xn = var_mean(&gw_xn, &[last_dim], true, &client)?; + let xn_mean_gw_xn = var_mul(&x_norm_var, &mean_gw_xn, &client)?; + let inner = var_sub(&gw, &mean_gw, &client)?; + let inner = var_sub(&inner, &xn_mean_gw_xn, &client)?; + let d_input = var_mul(&inner, &rstd_var, &client)?; + + // d_weight = sum(grad_output * x_norm, batch_dims) + let g_xn = var_mul(grad_output, &x_norm_var, &client)?; + let batch_dims: Vec = (0..last_dim).collect(); + let d_weight = if batch_dims.is_empty() { + g_xn + } else { + var_sum(&g_xn, &batch_dims, false, &client)? + }; + + // d_bias = sum(grad_output, batch_dims) + let d_bias = if batch_dims.is_empty() { + grad_output.clone() + } else { + var_sum(grad_output, &batch_dims, false, &client)? + }; + + Ok(vec![Some(d_input), Some(d_weight), Some(d_bias)]) + } + + fn inputs(&self) -> &[TensorId] { + &self.input_ids + } + + fn input_grad_fns(&self) -> Vec>>> { + self.input_grad_fns.to_vec() + } + + fn saved_tensors(&self) -> &[Tensor] { + &self.saved_tensors + } + + fn name(&self) -> &'static str { + "LayerNormBackward" + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::dtype::DType; + use crate::runtime::cpu::{CpuDevice, CpuRuntime}; + + #[test] + fn test_layer_norm_backward_uniform_grad() { + let device = CpuDevice::new(); + + let input = Tensor::::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[1, 4], &device); + let weight = Tensor::::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[4], &device); + let eps = 1e-5f32; + + let grad_out = Tensor::::ones(&[1, 4], DType::F32, &device); + + let backward = LayerNormBackward::::new( + input.id(), + weight.id(), + TensorId::new(), + input, + weight, + eps, + None, + None, + None, + ); + let grads = backward.backward(&grad_out).unwrap(); + + assert_eq!(grads.len(), 3); + let d_input: Vec = grads[0].as_ref().unwrap().to_vec(); + + let sum: f32 = d_input.iter().sum(); + assert!( + sum.abs() < 1e-5, + "sum of d_input should be ~0 for uniform grad, got {}", + sum + ); + } + + #[test] + fn test_layer_norm_backward_bias_grad() { + let device = CpuDevice::new(); + + let input = Tensor::::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2], &device); + let weight = Tensor::::from_slice(&[1.0f32, 1.0], &[2], &device); + + let grad_out = Tensor::::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2], &device); + + let backward = LayerNormBackward::::new( + input.id(), + weight.id(), + TensorId::new(), + input, + weight, + 1e-5, + None, + None, + None, + ); + let grads = backward.backward(&grad_out).unwrap(); + + let d_bias: Vec = grads[2].as_ref().unwrap().to_vec(); + + assert!((d_bias[0] - 4.0).abs() < 1e-5); + assert!((d_bias[1] - 6.0).abs() < 1e-5); + } +} diff --git a/src/autograd/ops/normalization/mod.rs b/src/autograd/ops/normalization/mod.rs new file mode 100644 index 00000000..ab81029d --- /dev/null +++ b/src/autograd/ops/normalization/mod.rs @@ -0,0 +1,13 @@ +//! Backward implementations for normalization operations + +mod fused_add_layer_norm; +mod fused_add_rms_norm; +mod group_norm; +mod layer_norm; +mod rms_norm; + +pub use fused_add_layer_norm::*; +pub use fused_add_rms_norm::*; +pub use group_norm::*; +pub use layer_norm::*; +pub use rms_norm::*; diff --git a/src/autograd/ops/normalization/rms_norm.rs b/src/autograd/ops/normalization/rms_norm.rs new file mode 100644 index 00000000..f79aabd8 --- /dev/null +++ b/src/autograd/ops/normalization/rms_norm.rs @@ -0,0 +1,215 @@ +//! Backward implementation for RMS Normalization + +use crate::autograd::GradFn; +use crate::autograd::var::Var; +use crate::autograd::var_ops::{var_mean, var_mul, var_sub, var_sum}; +use crate::error::Result; +use crate::ops::{BinaryOps, ReduceOps, ScalarOps, TensorOps, UnaryOps}; +use crate::runtime::{Runtime, RuntimeClient}; +use crate::tensor::{Tensor, TensorId}; +use std::sync::Arc; + +/// Backward for RMS Normalization: y = x / rms(x) * weight +/// +/// Where rms(x) = sqrt(mean(x^2, dim=-1) + eps) +/// +/// Gradients: +/// - d_input = rstd * (grad_out * weight - x_norm * mean(grad_out * weight * x_norm, dim=-1)) +/// - d_weight = sum(grad_out * x_norm, batch_dims) +/// +/// Where rstd = 1/rms(x), x_norm = x * rstd +pub struct RmsNormBackward { + input_ids: [TensorId; 2], + saved_tensors: Vec>, // [input, weight] + eps: f32, + input_grad_fns: [Option>>; 2], +} + +impl RmsNormBackward { + /// Create a new RmsNormBackward + pub fn new( + input_id: TensorId, + weight_id: TensorId, + input: Tensor, + weight: Tensor, + eps: f32, + input_grad_fn: Option>>, + weight_grad_fn: Option>>, + ) -> Self { + Self { + input_ids: [input_id, weight_id], + saved_tensors: vec![input, weight], + eps, + input_grad_fns: [input_grad_fn, weight_grad_fn], + } + } +} + +impl GradFn for RmsNormBackward +where + R::Client: TensorOps + ScalarOps + BinaryOps + ReduceOps + UnaryOps, +{ + fn backward(&self, grad_output: &Tensor) -> Result>>> { + let client = R::default_client(grad_output.device()); + let saved_input = &self.saved_tensors[0]; + let saved_weight = &self.saved_tensors[1]; + let ndim = saved_input.ndim(); + let last_dim = ndim - 1; + + // Recompute rstd = 1 / sqrt(mean(x^2, dim=-1, keepdim=True) + eps) + let x_sq = client.mul(saved_input, saved_input)?; + let mean_x_sq = client.mean(&x_sq, &[last_dim], true)?; + let variance_eps = client.add_scalar(&mean_x_sq, self.eps as f64)?; + let rms = client.sqrt(&variance_eps)?; + let rstd = client.recip(&rms)?; + + // x_norm = x * rstd + let x_norm = client.mul(saved_input, &rstd)?; + + // d_input = rstd * (grad_output * weight - x_norm * mean(grad_output * weight * x_norm)) + let gw = client.mul(grad_output, saved_weight)?; + let gw_xn = client.mul(&gw, &x_norm)?; + let mean_gw_xn = client.mean(&gw_xn, &[last_dim], true)?; + let correction = client.mul(&x_norm, &mean_gw_xn)?; + let inner = client.sub(&gw, &correction)?; + let d_input = client.mul(&inner, &rstd)?; + + // d_weight = sum(grad_output * x_norm, batch_dims) + let g_xn = client.mul(grad_output, &x_norm)?; + let batch_dims: Vec = (0..last_dim).collect(); + let d_weight = if batch_dims.is_empty() { + g_xn + } else { + client.sum(&g_xn, &batch_dims, false)? + }; + + Ok(vec![Some(d_input), Some(d_weight)]) + } + + fn backward_var(&self, grad_output: &Var) -> Result>>> + where + R::Client: RuntimeClient + + TensorOps + + ScalarOps + + BinaryOps + + ReduceOps + + UnaryOps, + { + let client = R::default_client(grad_output.tensor().device()); + let saved_input = &self.saved_tensors[0]; + let saved_weight = &self.saved_tensors[1]; + let ndim = saved_input.ndim(); + let last_dim = ndim - 1; + + // Recompute rstd and x_norm from saved tensors (treat as constants) + let x_sq = client.mul(saved_input, saved_input)?; + let mean_x_sq = client.mean(&x_sq, &[last_dim], true)?; + let variance_eps = client.add_scalar(&mean_x_sq, self.eps as f64)?; + let rms = client.sqrt(&variance_eps)?; + let rstd = client.recip(&rms)?; + let x_norm = client.mul(saved_input, &rstd)?; + + // Wrap as non-differentiable Vars (constants w.r.t. grad_output) + let rstd_var = Var::new(rstd, false); + let x_norm_var = Var::new(x_norm, false); + let weight_var = Var::new(saved_weight.clone(), false); + + // d_input = rstd * (grad_output * weight - x_norm * mean(grad_output * weight * x_norm)) + let gw = var_mul(grad_output, &weight_var, &client)?; + let gw_xn = var_mul(&gw, &x_norm_var, &client)?; + let mean_gw_xn = var_mean(&gw_xn, &[last_dim], true, &client)?; + let correction = var_mul(&x_norm_var, &mean_gw_xn, &client)?; + let inner = var_sub(&gw, &correction, &client)?; + let d_input = var_mul(&inner, &rstd_var, &client)?; + + // d_weight = sum(grad_output * x_norm, batch_dims) + let g_xn = var_mul(grad_output, &x_norm_var, &client)?; + let batch_dims: Vec = (0..last_dim).collect(); + let d_weight = if batch_dims.is_empty() { + g_xn + } else { + var_sum(&g_xn, &batch_dims, false, &client)? + }; + + Ok(vec![Some(d_input), Some(d_weight)]) + } + + fn inputs(&self) -> &[TensorId] { + &self.input_ids + } + + fn input_grad_fns(&self) -> Vec>>> { + self.input_grad_fns.to_vec() + } + + fn saved_tensors(&self) -> &[Tensor] { + &self.saved_tensors + } + + fn name(&self) -> &'static str { + "RmsNormBackward" + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::dtype::DType; + use crate::runtime::cpu::{CpuDevice, CpuRuntime}; + + #[test] + fn test_rms_norm_backward_uniform() { + let device = CpuDevice::new(); + + let input = Tensor::::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[1, 4], &device); + let weight = Tensor::::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[4], &device); + let eps = 1e-5f32; + + let grad_out = Tensor::::from_slice(&[1.0f32, 0.0, 0.0, 0.0], &[1, 4], &device); + + let backward = RmsNormBackward::::new( + input.id(), + weight.id(), + input, + weight, + eps, + None, + None, + ); + let grads = backward.backward(&grad_out).unwrap(); + + assert_eq!(grads.len(), 2); + let d_input: Vec = grads[0].as_ref().unwrap().to_vec(); + let d_weight: Vec = grads[1].as_ref().unwrap().to_vec(); + + assert!(d_input[0] > 0.0, "d_input[0] should be positive"); + assert!(d_input[1] < 0.0, "d_input[1] should be negative"); + assert!((d_weight[0] - 1.0).abs() < 0.01); + assert!(d_weight[1].abs() < 1e-5); + } + + #[test] + fn test_rms_norm_backward_gradient_sum() { + let device = CpuDevice::new(); + + let input = Tensor::::from_slice(&[1.0f32, 2.0, 3.0], &[1, 3], &device); + let weight = Tensor::::from_slice(&[1.0f32, 1.0, 1.0], &[3], &device); + let grad_out = Tensor::::ones(&[1, 3], DType::F32, &device); + + let backward = RmsNormBackward::::new( + input.id(), + weight.id(), + input, + weight, + 1e-5, + None, + None, + ); + let grads = backward.backward(&grad_out).unwrap(); + let d_input: Vec = grads[0].as_ref().unwrap().to_vec(); + + for val in &d_input { + assert!(val.is_finite(), "gradient should be finite"); + } + } +} diff --git a/src/autograd/ops/reduce.rs b/src/autograd/ops/reduce.rs deleted file mode 100644 index fdbcad07..00000000 --- a/src/autograd/ops/reduce.rs +++ /dev/null @@ -1,1025 +0,0 @@ -//! Backward implementations for reduction operations -//! -//! Implements gradient computation for sum, mean, max, and min reductions. - -use crate::autograd::GradFn; -use crate::autograd::var::Var; -use crate::autograd::var_ops::{var_div_scalar, var_mul}; -use crate::error::Result; -use crate::ops::{BinaryOps, CompareOps, ReduceOps, ScalarOps, TensorOps}; -use crate::runtime::{Runtime, RuntimeClient}; -use crate::tensor::{Tensor, TensorId}; -use std::sync::Arc; - -// ============================================================================ -// Helper Functions -// ============================================================================ - -/// Ensure a tensor is contiguous, making a copy if necessary. -#[inline] -fn ensure_contiguous(tensor: Tensor) -> Tensor { - if tensor.is_contiguous() { - tensor - } else { - tensor.contiguous() - } -} - -// ============================================================================ -// SumBackward -// ============================================================================ - -/// Backward for sum reduction: z = sum(a, dims) -/// -/// The gradient of sum is broadcast expansion. -/// For z = sum(a, dims), dL/da = broadcast(dL/dz, original_shape) -/// -/// If keepdim=false, we need to unsqueeze the gradient before broadcasting. -pub struct SumBackward { - input_id: TensorId, - input_shape: Vec, - dims: Vec, - keepdim: bool, - input_grad_fn: Option>>, -} - -impl SumBackward { - /// Create a new SumBackward - pub fn new( - input_id: TensorId, - input_shape: &[usize], - dims: &[usize], - keepdim: bool, - input_grad_fn: Option>>, - ) -> Self { - Self { - input_id, - input_shape: input_shape.to_vec(), - dims: dims.to_vec(), - keepdim, - input_grad_fn, - } - } -} - -impl GradFn for SumBackward { - fn backward(&self, grad_output: &Tensor) -> Result>>> { - // For sum, the gradient is broadcast back to the original shape - // All elements contribute equally to the sum, so each gets the full gradient - - let mut grad = grad_output.clone(); - - // If keepdim=false, we need to unsqueeze the dimensions that were reduced - if !self.keepdim { - // Sort dims in ascending order to unsqueeze correctly - let mut sorted_dims = self.dims.clone(); - sorted_dims.sort(); - - for &dim in &sorted_dims { - grad = grad.unsqueeze(dim as isize)?; - } - } - - // Broadcast to original shape and ensure contiguous - grad = ensure_contiguous(grad.broadcast_to(&self.input_shape)?); - - Ok(vec![Some(grad)]) - } - - fn backward_var(&self, grad_output: &Var) -> Result>>> { - // For sum, the gradient is just shape manipulation (unsqueeze + broadcast) - // The operations are linear/constant, so second derivative is 0 - // We still need to track the gradient flow through grad_output - - let mut grad_tensor = grad_output.tensor().clone(); - - // If keepdim=false, we need to unsqueeze the dimensions that were reduced - if !self.keepdim { - let mut sorted_dims = self.dims.clone(); - sorted_dims.sort(); - - for &dim in &sorted_dims { - grad_tensor = grad_tensor.unsqueeze(dim as isize)?; - } - } - - // Broadcast to original shape and ensure contiguous - grad_tensor = ensure_contiguous(grad_tensor.broadcast_to(&self.input_shape)?); - - // Wrap in Var - since sum's backward is purely linear (identity broadcast), - // the computation graph for second-order derivatives flows through grad_output - // which is already tracked. The broadcast is a view operation. - Ok(vec![Some(Var::new(grad_tensor, true))]) - } - - fn inputs(&self) -> &[TensorId] { - std::slice::from_ref(&self.input_id) - } - - fn input_grad_fns(&self) -> Vec>>> { - vec![self.input_grad_fn.clone()] - } - - fn name(&self) -> &'static str { - "SumBackward" - } -} - -// ============================================================================ -// MeanBackward -// ============================================================================ - -/// Backward for mean reduction: z = mean(a, dims) -/// -/// For z = mean(a, dims), dL/da = broadcast(dL/dz, original_shape) / count -/// where count is the number of elements being averaged. -pub struct MeanBackward { - input_id: TensorId, - input_shape: Vec, - dims: Vec, - keepdim: bool, - input_grad_fn: Option>>, -} - -impl MeanBackward { - /// Create a new MeanBackward - pub fn new( - input_id: TensorId, - input_shape: &[usize], - dims: &[usize], - keepdim: bool, - input_grad_fn: Option>>, - ) -> Self { - Self { - input_id, - input_shape: input_shape.to_vec(), - dims: dims.to_vec(), - keepdim, - input_grad_fn, - } - } -} - -impl GradFn for MeanBackward -where - R::Client: ScalarOps, -{ - fn backward(&self, grad_output: &Tensor) -> Result>>> { - let client = R::default_client(grad_output.device()); - - // Calculate the count (number of elements being averaged) - let count: usize = self.dims.iter().map(|&d| self.input_shape[d]).product(); - let count_f64 = count as f64; - - let mut grad = grad_output.clone(); - - // If keepdim=false, we need to unsqueeze the dimensions that were reduced - if !self.keepdim { - let mut sorted_dims = self.dims.clone(); - sorted_dims.sort(); - - for &dim in &sorted_dims { - grad = grad.unsqueeze(dim as isize)?; - } - } - - // Broadcast to original shape and ensure contiguous - grad = ensure_contiguous(grad.broadcast_to(&self.input_shape)?); - - // Divide by count - let grad = client.div_scalar(&grad, count_f64)?; - - Ok(vec![Some(grad)]) - } - - fn backward_var(&self, grad_output: &Var) -> Result>>> - where - R::Client: RuntimeClient + TensorOps + ScalarOps, - { - let client = R::default_client(grad_output.tensor().device()); - - // Calculate the count (number of elements being averaged) - let count: usize = self.dims.iter().map(|&d| self.input_shape[d]).product(); - let count_f64 = count as f64; - - let mut grad_tensor = grad_output.tensor().clone(); - - // If keepdim=false, we need to unsqueeze the dimensions that were reduced - if !self.keepdim { - let mut sorted_dims = self.dims.clone(); - sorted_dims.sort(); - - for &dim in &sorted_dims { - grad_tensor = grad_tensor.unsqueeze(dim as isize)?; - } - } - - // Broadcast to original shape and ensure contiguous - grad_tensor = ensure_contiguous(grad_tensor.broadcast_to(&self.input_shape)?); - - // Create a Var for the broadcast gradient - let grad_var = Var::new(grad_tensor, grad_output.requires_grad()); - - // Divide by count using var_div_scalar to track gradients - let grad = var_div_scalar(&grad_var, count_f64, &client)?; - - Ok(vec![Some(grad)]) - } - - fn inputs(&self) -> &[TensorId] { - std::slice::from_ref(&self.input_id) - } - - fn input_grad_fns(&self) -> Vec>>> { - vec![self.input_grad_fn.clone()] - } - - fn name(&self) -> &'static str { - "MeanBackward" - } -} - -// ============================================================================ -// MaxBackward -// ============================================================================ - -/// Backward for max reduction: z = max(a, dims) -/// -/// The gradient flows only to the element(s) that had the maximum value. -/// For ties, the gradient is distributed equally among tied elements. -pub struct MaxBackward { - input_id: TensorId, - saved_input: Tensor, - dims: Vec, - keepdim: bool, - input_grad_fn: Option>>, -} - -impl MaxBackward { - /// Create a new MaxBackward - pub fn new( - input_id: TensorId, - input: Tensor, - dims: &[usize], - keepdim: bool, - input_grad_fn: Option>>, - ) -> Self { - Self { - input_id, - saved_input: input, - dims: dims.to_vec(), - keepdim, - input_grad_fn, - } - } -} - -impl GradFn for MaxBackward -where - R::Client: TensorOps + ScalarOps + CompareOps + ReduceOps, -{ - fn backward(&self, grad_output: &Tensor) -> Result>>> { - let client = R::default_client(grad_output.device()); - - // Recompute max to get the max values - let max_vals = client.max(&self.saved_input, &self.dims, true)?; - - // Broadcast max values to input shape for comparison - let max_broadcast = ensure_contiguous(max_vals.broadcast_to(self.saved_input.shape())?); - - // Create mask where input equals max (handles ties) - let mask = client.eq(&self.saved_input, &max_broadcast)?; - - // Count how many elements equal the max per reduction group (for distributing gradient in case of ties) - let mask_sum = client.sum(&mask, &self.dims, true)?; - - // Broadcast mask_sum to input shape - let mask_sum_broadcast = - ensure_contiguous(mask_sum.broadcast_to(self.saved_input.shape())?); - - // Normalize mask by count (distribute gradient equally among tied elements) - let normalized_mask = client.div(&mask, &mask_sum_broadcast)?; - - // Broadcast grad_output to input shape - let mut grad = grad_output.clone(); - if !self.keepdim { - let mut sorted_dims = self.dims.clone(); - sorted_dims.sort(); - for &dim in &sorted_dims { - grad = grad.unsqueeze(dim as isize)?; - } - } - let grad_broadcast = ensure_contiguous(grad.broadcast_to(self.saved_input.shape())?); - - // Multiply gradient by normalized mask - let grad_input = client.mul(&grad_broadcast, &normalized_mask)?; - - Ok(vec![Some(grad_input)]) - } - - fn backward_var(&self, grad_output: &Var) -> Result>>> - where - R::Client: RuntimeClient + TensorOps + ScalarOps + CompareOps + ReduceOps, - { - let client = R::default_client(grad_output.tensor().device()); - - // Recompute max to get the max values - let max_vals = client.max(&self.saved_input, &self.dims, true)?; - - // Broadcast max values to input shape for comparison - let max_broadcast = ensure_contiguous(max_vals.broadcast_to(self.saved_input.shape())?); - - // Create mask where input equals max (handles ties) - let mask = client.eq(&self.saved_input, &max_broadcast)?; - - // Count how many elements equal the max per reduction group - let mask_sum = client.sum(&mask, &self.dims, true)?; - - // Broadcast mask_sum to input shape - let mask_sum_broadcast = - ensure_contiguous(mask_sum.broadcast_to(self.saved_input.shape())?); - - // Normalize mask by count (distribute gradient equally among tied elements) - let normalized_mask = client.div(&mask, &mask_sum_broadcast)?; - - // Broadcast grad_output to input shape - let mut grad_tensor = grad_output.tensor().clone(); - if !self.keepdim { - let mut sorted_dims = self.dims.clone(); - sorted_dims.sort(); - for &dim in &sorted_dims { - grad_tensor = grad_tensor.unsqueeze(dim as isize)?; - } - } - let grad_broadcast = ensure_contiguous(grad_tensor.broadcast_to(self.saved_input.shape())?); - - // Create Vars for the multiplication - // The normalized_mask is constant w.r.t. grad_output (it's a hard mask based on input) - // So we wrap it as a detached Var - let grad_var = Var::new(grad_broadcast, grad_output.requires_grad()); - let mask_var = Var::new(normalized_mask, false); // mask is not differentiable - - // Multiply gradient by normalized mask using var_mul to track gradients through grad_output - let grad_input = var_mul(&grad_var, &mask_var, &client)?; - - Ok(vec![Some(grad_input)]) - } - - fn inputs(&self) -> &[TensorId] { - std::slice::from_ref(&self.input_id) - } - - fn input_grad_fns(&self) -> Vec>>> { - vec![self.input_grad_fn.clone()] - } - - fn saved_tensors(&self) -> &[Tensor] { - std::slice::from_ref(&self.saved_input) - } - - fn name(&self) -> &'static str { - "MaxBackward" - } -} - -// ============================================================================ -// MinBackward -// ============================================================================ - -/// Backward for min reduction: z = min(a, dims) -/// -/// The gradient flows only to the element(s) that had the minimum value. -/// For ties, the gradient is distributed equally among tied elements. -pub struct MinBackward { - input_id: TensorId, - saved_input: Tensor, - dims: Vec, - keepdim: bool, - input_grad_fn: Option>>, -} - -impl MinBackward { - /// Create a new MinBackward - pub fn new( - input_id: TensorId, - input: Tensor, - dims: &[usize], - keepdim: bool, - input_grad_fn: Option>>, - ) -> Self { - Self { - input_id, - saved_input: input, - dims: dims.to_vec(), - keepdim, - input_grad_fn, - } - } -} - -impl GradFn for MinBackward -where - R::Client: TensorOps + ScalarOps + CompareOps + ReduceOps, -{ - fn backward(&self, grad_output: &Tensor) -> Result>>> { - let client = R::default_client(grad_output.device()); - - // Recompute min to get the min values - let min_vals = client.min(&self.saved_input, &self.dims, true)?; - - // Broadcast min values to input shape for comparison - let min_broadcast = ensure_contiguous(min_vals.broadcast_to(self.saved_input.shape())?); - - // Create mask where input equals min (handles ties) - let mask = client.eq(&self.saved_input, &min_broadcast)?; - - // Count how many elements equal the min per reduction group - let mask_sum = client.sum(&mask, &self.dims, true)?; - - // Broadcast mask_sum to input shape - let mask_sum_broadcast = - ensure_contiguous(mask_sum.broadcast_to(self.saved_input.shape())?); - - // Normalize mask by count - let normalized_mask = client.div(&mask, &mask_sum_broadcast)?; - - // Broadcast grad_output to input shape - let mut grad = grad_output.clone(); - if !self.keepdim { - let mut sorted_dims = self.dims.clone(); - sorted_dims.sort(); - for &dim in &sorted_dims { - grad = grad.unsqueeze(dim as isize)?; - } - } - let grad_broadcast = ensure_contiguous(grad.broadcast_to(self.saved_input.shape())?); - - // Multiply gradient by normalized mask - let grad_input = client.mul(&grad_broadcast, &normalized_mask)?; - - Ok(vec![Some(grad_input)]) - } - - fn backward_var(&self, grad_output: &Var) -> Result>>> - where - R::Client: RuntimeClient + TensorOps + ScalarOps + CompareOps + ReduceOps, - { - let client = R::default_client(grad_output.tensor().device()); - - // Recompute min to get the min values - let min_vals = client.min(&self.saved_input, &self.dims, true)?; - - // Broadcast min values to input shape for comparison - let min_broadcast = ensure_contiguous(min_vals.broadcast_to(self.saved_input.shape())?); - - // Create mask where input equals min (handles ties) - let mask = client.eq(&self.saved_input, &min_broadcast)?; - - // Count how many elements equal the min per reduction group - let mask_sum = client.sum(&mask, &self.dims, true)?; - - // Broadcast mask_sum to input shape - let mask_sum_broadcast = - ensure_contiguous(mask_sum.broadcast_to(self.saved_input.shape())?); - - // Normalize mask by count - let normalized_mask = client.div(&mask, &mask_sum_broadcast)?; - - // Broadcast grad_output to input shape - let mut grad_tensor = grad_output.tensor().clone(); - if !self.keepdim { - let mut sorted_dims = self.dims.clone(); - sorted_dims.sort(); - for &dim in &sorted_dims { - grad_tensor = grad_tensor.unsqueeze(dim as isize)?; - } - } - let grad_broadcast = ensure_contiguous(grad_tensor.broadcast_to(self.saved_input.shape())?); - - // Create Vars for the multiplication - // The normalized_mask is constant w.r.t. grad_output (it's a hard mask based on input) - let grad_var = Var::new(grad_broadcast, grad_output.requires_grad()); - let mask_var = Var::new(normalized_mask, false); // mask is not differentiable - - // Multiply gradient by normalized mask using var_mul to track gradients through grad_output - let grad_input = var_mul(&grad_var, &mask_var, &client)?; - - Ok(vec![Some(grad_input)]) - } - - fn inputs(&self) -> &[TensorId] { - std::slice::from_ref(&self.input_id) - } - - fn input_grad_fns(&self) -> Vec>>> { - vec![self.input_grad_fn.clone()] - } - - fn saved_tensors(&self) -> &[Tensor] { - std::slice::from_ref(&self.saved_input) - } - - fn name(&self) -> &'static str { - "MinBackward" - } -} - -// ============================================================================ -// VarBackward -// ============================================================================ - -/// Backward for variance reduction: z = var(a, dims, correction) -/// -/// The gradient of variance is: -/// dL/da = dL/dz * 2 * (a - mean(a)) / (N - correction) -/// -/// where N is the number of elements being reduced. -pub struct VarBackward { - input_id: TensorId, - saved_input: Tensor, - dims: Vec, - keepdim: bool, - correction: usize, - input_grad_fn: Option>>, -} - -impl VarBackward { - /// Create a new VarBackward - pub fn new( - input_id: TensorId, - input: Tensor, - dims: &[usize], - keepdim: bool, - correction: usize, - input_grad_fn: Option>>, - ) -> Self { - Self { - input_id, - saved_input: input, - dims: dims.to_vec(), - keepdim, - correction, - input_grad_fn, - } - } -} - -impl GradFn for VarBackward -where - R::Client: TensorOps + ScalarOps + ReduceOps, -{ - fn backward(&self, grad_output: &Tensor) -> Result>>> { - let client = R::default_client(grad_output.device()); - - // Calculate N (number of elements in reduction) - let n: usize = self - .dims - .iter() - .map(|&d| self.saved_input.shape()[d]) - .product(); - let n_minus_corr = (n - self.correction) as f64; - - // Compute mean of input - let mean = client.mean(&self.saved_input, &self.dims, true)?; - - // Broadcast mean to input shape - let mean_broadcast = ensure_contiguous(mean.broadcast_to(self.saved_input.shape())?); - - // a - mean(a) - let centered = client.sub(&self.saved_input, &mean_broadcast)?; - - // 2 * (a - mean) / (N - correction) - let scale = 2.0 / n_minus_corr; - let grad_contrib = client.mul_scalar(¢ered, scale)?; - - // Handle grad_output shape - broadcast to input shape - let mut grad = grad_output.clone(); - if !self.keepdim { - let mut sorted_dims = self.dims.clone(); - sorted_dims.sort(); - for &dim in &sorted_dims { - grad = grad.unsqueeze(dim as isize)?; - } - } - let grad_broadcast = ensure_contiguous(grad.broadcast_to(self.saved_input.shape())?); - - // Final gradient - let grad_input = client.mul(&grad_broadcast, &grad_contrib)?; - - Ok(vec![Some(grad_input)]) - } - - fn backward_var(&self, grad_output: &Var) -> Result>>> - where - R::Client: RuntimeClient + TensorOps + ScalarOps + ReduceOps, - { - let client = R::default_client(grad_output.tensor().device()); - - // Calculate N (number of elements in reduction) - let n: usize = self - .dims - .iter() - .map(|&d| self.saved_input.shape()[d]) - .product(); - let n_minus_corr = (n - self.correction) as f64; - - // Compute mean of input - let mean = client.mean(&self.saved_input, &self.dims, true)?; - - // Broadcast mean to input shape - let mean_broadcast = ensure_contiguous(mean.broadcast_to(self.saved_input.shape())?); - - // a - mean(a) - let centered = client.sub(&self.saved_input, &mean_broadcast)?; - - // 2 * (a - mean) / (N - correction) - let scale = 2.0 / n_minus_corr; - let grad_contrib = client.mul_scalar(¢ered, scale)?; - - // Handle grad_output shape - broadcast to input shape - let mut grad_tensor = grad_output.tensor().clone(); - if !self.keepdim { - let mut sorted_dims = self.dims.clone(); - sorted_dims.sort(); - for &dim in &sorted_dims { - grad_tensor = grad_tensor.unsqueeze(dim as isize)?; - } - } - let grad_broadcast = ensure_contiguous(grad_tensor.broadcast_to(self.saved_input.shape())?); - - // Create Vars for the multiplication - // grad_contrib depends on input (through centering), but for second-order - // differentiation of variance w.r.t. grad_output, it's treated as constant - let grad_var = Var::new(grad_broadcast, grad_output.requires_grad()); - let contrib_var = Var::new(grad_contrib, false); - - // Final gradient using var_mul to track gradients through grad_output - let grad_input = var_mul(&grad_var, &contrib_var, &client)?; - - Ok(vec![Some(grad_input)]) - } - - fn inputs(&self) -> &[TensorId] { - std::slice::from_ref(&self.input_id) - } - - fn input_grad_fns(&self) -> Vec>>> { - vec![self.input_grad_fn.clone()] - } - - fn saved_tensors(&self) -> &[Tensor] { - std::slice::from_ref(&self.saved_input) - } - - fn name(&self) -> &'static str { - "VarBackward" - } -} - -// ============================================================================ -// StdBackward -// ============================================================================ - -/// Backward for standard deviation reduction: z = std(a, dims, correction) -/// -/// std = sqrt(var), so by chain rule: -/// dL/da = dL/dz * d(sqrt(var))/dvar * dvar/da -/// = dL/dz * 1/(2*std) * 2*(a - mean) / (N - correction) -/// = dL/dz * (a - mean) / ((N - correction) * std) -pub struct StdBackward { - input_id: TensorId, - saved_input: Tensor, - saved_output: Tensor, // std(a) - dims: Vec, - keepdim: bool, - correction: usize, - input_grad_fn: Option>>, -} - -impl StdBackward { - /// Create a new StdBackward - pub fn new( - input_id: TensorId, - input: Tensor, - output: Tensor, - dims: &[usize], - keepdim: bool, - correction: usize, - input_grad_fn: Option>>, - ) -> Self { - Self { - input_id, - saved_input: input, - saved_output: output, - dims: dims.to_vec(), - keepdim, - correction, - input_grad_fn, - } - } -} - -impl GradFn for StdBackward -where - R::Client: TensorOps + ScalarOps + ReduceOps, -{ - fn backward(&self, grad_output: &Tensor) -> Result>>> { - let client = R::default_client(grad_output.device()); - - // Calculate N (number of elements in reduction) - let n: usize = self - .dims - .iter() - .map(|&d| self.saved_input.shape()[d]) - .product(); - let n_minus_corr = (n - self.correction) as f64; - - // Compute mean of input - let mean = client.mean(&self.saved_input, &self.dims, true)?; - - // Broadcast mean and std to input shape - let mean_broadcast = ensure_contiguous(mean.broadcast_to(self.saved_input.shape())?); - - let std_for_broadcast = if self.keepdim { - self.saved_output.clone() - } else { - let mut std_expanded = self.saved_output.clone(); - let mut sorted_dims = self.dims.clone(); - sorted_dims.sort(); - for &dim in &sorted_dims { - std_expanded = std_expanded.unsqueeze(dim as isize)?; - } - std_expanded - }; - let std_broadcast = - ensure_contiguous(std_for_broadcast.broadcast_to(self.saved_input.shape())?); - - // (a - mean) - let centered = client.sub(&self.saved_input, &mean_broadcast)?; - - // (a - mean) / ((N - correction) * std) - let denominator = client.mul_scalar(&std_broadcast, n_minus_corr)?; - let grad_contrib = client.div(¢ered, &denominator)?; - - // Handle grad_output shape - broadcast to input shape - let mut grad = grad_output.clone(); - if !self.keepdim { - let mut sorted_dims = self.dims.clone(); - sorted_dims.sort(); - for &dim in &sorted_dims { - grad = grad.unsqueeze(dim as isize)?; - } - } - let grad_broadcast = ensure_contiguous(grad.broadcast_to(self.saved_input.shape())?); - - // Final gradient - let grad_input = client.mul(&grad_broadcast, &grad_contrib)?; - - Ok(vec![Some(grad_input)]) - } - - fn backward_var(&self, grad_output: &Var) -> Result>>> - where - R::Client: RuntimeClient + TensorOps + ScalarOps + ReduceOps, - { - let client = R::default_client(grad_output.tensor().device()); - - // Calculate N (number of elements in reduction) - let n: usize = self - .dims - .iter() - .map(|&d| self.saved_input.shape()[d]) - .product(); - let n_minus_corr = (n - self.correction) as f64; - - // Compute mean of input - let mean = client.mean(&self.saved_input, &self.dims, true)?; - - // Broadcast mean and std to input shape - let mean_broadcast = ensure_contiguous(mean.broadcast_to(self.saved_input.shape())?); - - let std_for_broadcast = if self.keepdim { - self.saved_output.clone() - } else { - let mut std_expanded = self.saved_output.clone(); - let mut sorted_dims = self.dims.clone(); - sorted_dims.sort(); - for &dim in &sorted_dims { - std_expanded = std_expanded.unsqueeze(dim as isize)?; - } - std_expanded - }; - let std_broadcast = - ensure_contiguous(std_for_broadcast.broadcast_to(self.saved_input.shape())?); - - // (a - mean) - let centered = client.sub(&self.saved_input, &mean_broadcast)?; - - // (a - mean) / ((N - correction) * std) - let denominator = client.mul_scalar(&std_broadcast, n_minus_corr)?; - let grad_contrib = client.div(¢ered, &denominator)?; - - // Handle grad_output shape - broadcast to input shape - let mut grad_tensor = grad_output.tensor().clone(); - if !self.keepdim { - let mut sorted_dims = self.dims.clone(); - sorted_dims.sort(); - for &dim in &sorted_dims { - grad_tensor = grad_tensor.unsqueeze(dim as isize)?; - } - } - let grad_broadcast = ensure_contiguous(grad_tensor.broadcast_to(self.saved_input.shape())?); - - // Create Vars for the multiplication - // grad_contrib depends on input and saved_output, but for second-order - // differentiation of std w.r.t. grad_output, it's treated as constant - let grad_var = Var::new(grad_broadcast, grad_output.requires_grad()); - let contrib_var = Var::new(grad_contrib, false); - - // Final gradient using var_mul to track gradients through grad_output - let grad_input = var_mul(&grad_var, &contrib_var, &client)?; - - Ok(vec![Some(grad_input)]) - } - - fn inputs(&self) -> &[TensorId] { - std::slice::from_ref(&self.input_id) - } - - fn input_grad_fns(&self) -> Vec>>> { - vec![self.input_grad_fn.clone()] - } - - fn saved_tensors(&self) -> &[Tensor] { - // Return both saved tensors - but we can only return a slice, so just input for now - std::slice::from_ref(&self.saved_input) - } - - fn name(&self) -> &'static str { - "StdBackward" - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::dtype::DType; - use crate::runtime::cpu::{CpuDevice, CpuRuntime}; - - #[test] - fn test_sum_backward_keepdim() { - let device = CpuDevice::new(); - - // a = [[1, 2, 3], [4, 5, 6]] (2x3) - // sum(a, dim=1, keepdim=True) = [[6], [15]] (2x1) - // dL/dz = [[1], [1]] (2x1) - // dL/da = [[1, 1, 1], [1, 1, 1]] (2x3) - - let grad_out = Tensor::::ones(&[2, 1], DType::F32, &device); - - let backward = SumBackward::::new( - TensorId::new(), - &[2, 3], - &[1], - true, // keepdim - None, // input_grad_fn - ); - let grads = backward.backward(&grad_out).unwrap(); - - let grad_a = grads[0].as_ref().unwrap(); - assert_eq!(grad_a.shape(), &[2, 3]); - - let grad_data: Vec = grad_a.to_vec(); - assert_eq!(grad_data, vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0]); - } - - #[test] - fn test_sum_backward_no_keepdim() { - let device = CpuDevice::new(); - - // a = [[1, 2, 3], [4, 5, 6]] (2x3) - // sum(a, dim=1, keepdim=False) = [6, 15] (2,) - // dL/dz = [1, 1] (2,) - // dL/da = [[1, 1, 1], [1, 1, 1]] (2x3) - - let grad_out = Tensor::::ones(&[2], DType::F32, &device); - - let backward = SumBackward::::new( - TensorId::new(), - &[2, 3], - &[1], - false, // no keepdim - None, // input_grad_fn - ); - let grads = backward.backward(&grad_out).unwrap(); - - let grad_a = grads[0].as_ref().unwrap(); - assert_eq!(grad_a.shape(), &[2, 3]); - - let grad_data: Vec = grad_a.to_vec(); - assert_eq!(grad_data, vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0]); - } - - #[test] - fn test_mean_backward() { - let device = CpuDevice::new(); - - // a = [[1, 2, 3], [4, 5, 6]] (2x3) - // mean(a, dim=1, keepdim=True) = [[2], [5]] (2x1) - // dL/dz = [[1], [1]] (2x1) - // dL/da = [[1/3, 1/3, 1/3], [1/3, 1/3, 1/3]] (2x3) - - let grad_out = Tensor::::ones(&[2, 1], DType::F32, &device); - - let backward = MeanBackward::::new( - TensorId::new(), - &[2, 3], - &[1], - true, // keepdim - None, // input_grad_fn - ); - let grads = backward.backward(&grad_out).unwrap(); - - let grad_a = grads[0].as_ref().unwrap(); - assert_eq!(grad_a.shape(), &[2, 3]); - - let grad_data: Vec = grad_a.to_vec(); - let expected = 1.0 / 3.0; - for val in grad_data { - assert!((val - expected).abs() < 1e-6); - } - } - - #[test] - fn test_max_backward() { - let device = CpuDevice::new(); - let _client = CpuRuntime::default_client(&device); - - // a = [[1, 3, 2], [4, 2, 5]] (2x3) - // max(a, dim=1, keepdim=True) = [[3], [5]] (2x1) - // dL/dz = [[1], [1]] (2x1) - // dL/da = [[0, 1, 0], [0, 0, 1]] (gradient flows only to max elements) - let a = - Tensor::::from_slice(&[1.0f32, 3.0, 2.0, 4.0, 2.0, 5.0], &[2, 3], &device); - let grad_out = Tensor::::ones(&[2, 1], DType::F32, &device); - - let backward = MaxBackward::::new(a.id(), a.clone(), &[1], true, None); - let grads = backward.backward(&grad_out).unwrap(); - - let grad_a = grads[0].as_ref().unwrap(); - assert_eq!(grad_a.shape(), &[2, 3]); - - let grad_data: Vec = grad_a.to_vec(); - // Max at index 1 for first row, index 2 for second row - assert_eq!(grad_data, vec![0.0, 1.0, 0.0, 0.0, 0.0, 1.0]); - } - - #[test] - fn test_min_backward() { - let device = CpuDevice::new(); - let _client = CpuRuntime::default_client(&device); - - // a = [[3, 1, 2], [4, 2, 5]] (2x3) - // min(a, dim=1, keepdim=True) = [[1], [2]] (2x1) - // dL/dz = [[1], [1]] (2x1) - // dL/da = [[0, 1, 0], [0, 1, 0]] (gradient flows only to min elements) - let a = - Tensor::::from_slice(&[3.0f32, 1.0, 2.0, 4.0, 2.0, 5.0], &[2, 3], &device); - let grad_out = Tensor::::ones(&[2, 1], DType::F32, &device); - - let backward = MinBackward::::new(a.id(), a.clone(), &[1], true, None); - let grads = backward.backward(&grad_out).unwrap(); - - let grad_a = grads[0].as_ref().unwrap(); - assert_eq!(grad_a.shape(), &[2, 3]); - - let grad_data: Vec = grad_a.to_vec(); - // Min at index 1 for first row, index 1 for second row - assert_eq!(grad_data, vec![0.0, 1.0, 0.0, 0.0, 1.0, 0.0]); - } - - #[test] - fn test_max_backward_with_ties() { - let device = CpuDevice::new(); - let _client = CpuRuntime::default_client(&device); - - // a = [[3, 3, 1]] (1x3) - two tied max values - // max(a, dim=1, keepdim=True) = [[3]] (1x1) - // dL/dz = [[1]] (1x1) - // dL/da = [[0.5, 0.5, 0]] (gradient split equally among tied max elements) - let a = Tensor::::from_slice(&[3.0f32, 3.0, 1.0], &[1, 3], &device); - let grad_out = Tensor::::ones(&[1, 1], DType::F32, &device); - - let backward = MaxBackward::::new(a.id(), a.clone(), &[1], true, None); - let grads = backward.backward(&grad_out).unwrap(); - - let grad_a = grads[0].as_ref().unwrap(); - assert_eq!(grad_a.shape(), &[1, 3]); - - let grad_data: Vec = grad_a.to_vec(); - // Gradient split equally among two max elements - assert!((grad_data[0] - 0.5).abs() < 1e-6); - assert!((grad_data[1] - 0.5).abs() < 1e-6); - assert!((grad_data[2] - 0.0).abs() < 1e-6); - } -} diff --git a/src/autograd/ops/reduce/common.rs b/src/autograd/ops/reduce/common.rs new file mode 100644 index 00000000..0668a3b6 --- /dev/null +++ b/src/autograd/ops/reduce/common.rs @@ -0,0 +1,14 @@ +//! Shared utilities for reduction backward implementations + +use crate::runtime::Runtime; +use crate::tensor::Tensor; + +/// Ensure a tensor is contiguous, making a copy if necessary. +#[inline] +pub(super) fn ensure_contiguous(tensor: Tensor) -> Tensor { + if tensor.is_contiguous() { + tensor + } else { + tensor.contiguous() + } +} diff --git a/src/autograd/ops/reduce/extremum.rs b/src/autograd/ops/reduce/extremum.rs new file mode 100644 index 00000000..d2d51a82 --- /dev/null +++ b/src/autograd/ops/reduce/extremum.rs @@ -0,0 +1,327 @@ +//! Backward implementations for max and min reductions + +use crate::autograd::GradFn; +use crate::autograd::var::Var; +use crate::autograd::var_ops::var_mul; +use crate::error::Result; +use crate::ops::{BinaryOps, CompareOps, ReduceOps, ScalarOps, TensorOps}; +use crate::runtime::{Runtime, RuntimeClient}; +use crate::tensor::{Tensor, TensorId}; +use std::sync::Arc; + +use super::common::ensure_contiguous; + +// ============================================================================ +// MaxBackward +// ============================================================================ + +/// Backward for max reduction: z = max(a, dims) +/// +/// The gradient flows only to the element(s) that had the maximum value. +/// For ties, the gradient is distributed equally among tied elements. +pub struct MaxBackward { + input_id: TensorId, + saved_input: Tensor, + dims: Vec, + keepdim: bool, + input_grad_fn: Option>>, +} + +impl MaxBackward { + /// Create a new MaxBackward + pub fn new( + input_id: TensorId, + input: Tensor, + dims: &[usize], + keepdim: bool, + input_grad_fn: Option>>, + ) -> Self { + Self { + input_id, + saved_input: input, + dims: dims.to_vec(), + keepdim, + input_grad_fn, + } + } +} + +/// Shared logic for extremum (max/min) backward pass +fn extremum_backward( + saved_input: &Tensor, + grad_output: &Tensor, + dims: &[usize], + keepdim: bool, + is_max: bool, +) -> Result> +where + R::Client: TensorOps + ScalarOps + CompareOps + ReduceOps, +{ + let client = R::default_client(grad_output.device()); + + // Recompute extremum values + let extremum_vals = if is_max { + client.max(saved_input, dims, true)? + } else { + client.min(saved_input, dims, true)? + }; + + // Broadcast to input shape for comparison + let extremum_broadcast = ensure_contiguous(extremum_vals.broadcast_to(saved_input.shape())?); + + // Create mask where input equals extremum (handles ties) + let mask = client.eq(saved_input, &extremum_broadcast)?; + + // Count ties per reduction group + let mask_sum = client.sum(&mask, dims, true)?; + let mask_sum_broadcast = ensure_contiguous(mask_sum.broadcast_to(saved_input.shape())?); + + // Normalize mask by count + let normalized_mask = client.div(&mask, &mask_sum_broadcast)?; + + // Broadcast grad_output to input shape + let mut grad = grad_output.clone(); + if !keepdim { + let mut sorted_dims = dims.to_vec(); + sorted_dims.sort(); + for &dim in &sorted_dims { + grad = grad.unsqueeze(dim as isize)?; + } + } + let grad_broadcast = ensure_contiguous(grad.broadcast_to(saved_input.shape())?); + + client.mul(&grad_broadcast, &normalized_mask) +} + +/// Shared logic for extremum backward_var pass +fn extremum_backward_var( + saved_input: &Tensor, + grad_output: &Var, + dims: &[usize], + keepdim: bool, + is_max: bool, +) -> Result> +where + R::Client: RuntimeClient + TensorOps + ScalarOps + CompareOps + ReduceOps, +{ + let client = R::default_client(grad_output.tensor().device()); + + let extremum_vals = if is_max { + client.max(saved_input, dims, true)? + } else { + client.min(saved_input, dims, true)? + }; + + let extremum_broadcast = ensure_contiguous(extremum_vals.broadcast_to(saved_input.shape())?); + let mask = client.eq(saved_input, &extremum_broadcast)?; + let mask_sum = client.sum(&mask, dims, true)?; + let mask_sum_broadcast = ensure_contiguous(mask_sum.broadcast_to(saved_input.shape())?); + let normalized_mask = client.div(&mask, &mask_sum_broadcast)?; + + let mut grad_tensor = grad_output.tensor().clone(); + if !keepdim { + let mut sorted_dims = dims.to_vec(); + sorted_dims.sort(); + for &dim in &sorted_dims { + grad_tensor = grad_tensor.unsqueeze(dim as isize)?; + } + } + let grad_broadcast = ensure_contiguous(grad_tensor.broadcast_to(saved_input.shape())?); + + let grad_var = Var::new(grad_broadcast, grad_output.requires_grad()); + let mask_var = Var::new(normalized_mask, false); + + var_mul(&grad_var, &mask_var, &client) +} + +impl GradFn for MaxBackward +where + R::Client: TensorOps + ScalarOps + CompareOps + ReduceOps, +{ + fn backward(&self, grad_output: &Tensor) -> Result>>> { + let grad_input = extremum_backward( + &self.saved_input, + grad_output, + &self.dims, + self.keepdim, + true, + )?; + Ok(vec![Some(grad_input)]) + } + + fn backward_var(&self, grad_output: &Var) -> Result>>> + where + R::Client: RuntimeClient + TensorOps + ScalarOps + CompareOps + ReduceOps, + { + let grad_input = extremum_backward_var( + &self.saved_input, + grad_output, + &self.dims, + self.keepdim, + true, + )?; + Ok(vec![Some(grad_input)]) + } + + fn inputs(&self) -> &[TensorId] { + std::slice::from_ref(&self.input_id) + } + + fn input_grad_fns(&self) -> Vec>>> { + vec![self.input_grad_fn.clone()] + } + + fn saved_tensors(&self) -> &[Tensor] { + std::slice::from_ref(&self.saved_input) + } + + fn name(&self) -> &'static str { + "MaxBackward" + } +} + +// ============================================================================ +// MinBackward +// ============================================================================ + +/// Backward for min reduction: z = min(a, dims) +/// +/// The gradient flows only to the element(s) that had the minimum value. +/// For ties, the gradient is distributed equally among tied elements. +pub struct MinBackward { + input_id: TensorId, + saved_input: Tensor, + dims: Vec, + keepdim: bool, + input_grad_fn: Option>>, +} + +impl MinBackward { + /// Create a new MinBackward + pub fn new( + input_id: TensorId, + input: Tensor, + dims: &[usize], + keepdim: bool, + input_grad_fn: Option>>, + ) -> Self { + Self { + input_id, + saved_input: input, + dims: dims.to_vec(), + keepdim, + input_grad_fn, + } + } +} + +impl GradFn for MinBackward +where + R::Client: TensorOps + ScalarOps + CompareOps + ReduceOps, +{ + fn backward(&self, grad_output: &Tensor) -> Result>>> { + let grad_input = extremum_backward( + &self.saved_input, + grad_output, + &self.dims, + self.keepdim, + false, + )?; + Ok(vec![Some(grad_input)]) + } + + fn backward_var(&self, grad_output: &Var) -> Result>>> + where + R::Client: RuntimeClient + TensorOps + ScalarOps + CompareOps + ReduceOps, + { + let grad_input = extremum_backward_var( + &self.saved_input, + grad_output, + &self.dims, + self.keepdim, + false, + )?; + Ok(vec![Some(grad_input)]) + } + + fn inputs(&self) -> &[TensorId] { + std::slice::from_ref(&self.input_id) + } + + fn input_grad_fns(&self) -> Vec>>> { + vec![self.input_grad_fn.clone()] + } + + fn saved_tensors(&self) -> &[Tensor] { + std::slice::from_ref(&self.saved_input) + } + + fn name(&self) -> &'static str { + "MinBackward" + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::dtype::DType; + use crate::runtime::cpu::{CpuDevice, CpuRuntime}; + + #[test] + fn test_max_backward() { + let device = CpuDevice::new(); + let _client = CpuRuntime::default_client(&device); + + let a = + Tensor::::from_slice(&[1.0f32, 3.0, 2.0, 4.0, 2.0, 5.0], &[2, 3], &device); + let grad_out = Tensor::::ones(&[2, 1], DType::F32, &device); + + let backward = MaxBackward::::new(a.id(), a.clone(), &[1], true, None); + let grads = backward.backward(&grad_out).unwrap(); + + let grad_a = grads[0].as_ref().unwrap(); + assert_eq!(grad_a.shape(), &[2, 3]); + + let grad_data: Vec = grad_a.to_vec(); + assert_eq!(grad_data, vec![0.0, 1.0, 0.0, 0.0, 0.0, 1.0]); + } + + #[test] + fn test_min_backward() { + let device = CpuDevice::new(); + let _client = CpuRuntime::default_client(&device); + + let a = + Tensor::::from_slice(&[3.0f32, 1.0, 2.0, 4.0, 2.0, 5.0], &[2, 3], &device); + let grad_out = Tensor::::ones(&[2, 1], DType::F32, &device); + + let backward = MinBackward::::new(a.id(), a.clone(), &[1], true, None); + let grads = backward.backward(&grad_out).unwrap(); + + let grad_a = grads[0].as_ref().unwrap(); + assert_eq!(grad_a.shape(), &[2, 3]); + + let grad_data: Vec = grad_a.to_vec(); + assert_eq!(grad_data, vec![0.0, 1.0, 0.0, 0.0, 1.0, 0.0]); + } + + #[test] + fn test_max_backward_with_ties() { + let device = CpuDevice::new(); + let _client = CpuRuntime::default_client(&device); + + let a = Tensor::::from_slice(&[3.0f32, 3.0, 1.0], &[1, 3], &device); + let grad_out = Tensor::::ones(&[1, 1], DType::F32, &device); + + let backward = MaxBackward::::new(a.id(), a.clone(), &[1], true, None); + let grads = backward.backward(&grad_out).unwrap(); + + let grad_a = grads[0].as_ref().unwrap(); + assert_eq!(grad_a.shape(), &[1, 3]); + + let grad_data: Vec = grad_a.to_vec(); + assert!((grad_data[0] - 0.5).abs() < 1e-6); + assert!((grad_data[1] - 0.5).abs() < 1e-6); + assert!((grad_data[2] - 0.0).abs() < 1e-6); + } +} diff --git a/src/autograd/ops/reduce/mod.rs b/src/autograd/ops/reduce/mod.rs new file mode 100644 index 00000000..b3d07989 --- /dev/null +++ b/src/autograd/ops/reduce/mod.rs @@ -0,0 +1,10 @@ +//! Backward implementations for reduction operations + +mod common; +mod extremum; +mod statistical; +mod sum_mean; + +pub use extremum::*; +pub use statistical::*; +pub use sum_mean::*; diff --git a/src/autograd/ops/reduce/statistical.rs b/src/autograd/ops/reduce/statistical.rs new file mode 100644 index 00000000..824d080c --- /dev/null +++ b/src/autograd/ops/reduce/statistical.rs @@ -0,0 +1,309 @@ +//! Backward implementations for variance and standard deviation reductions + +use crate::autograd::GradFn; +use crate::autograd::var::Var; +use crate::autograd::var_ops::var_mul; +use crate::error::Result; +use crate::ops::{BinaryOps, ReduceOps, ScalarOps, TensorOps}; +use crate::runtime::{Runtime, RuntimeClient}; +use crate::tensor::{Tensor, TensorId}; +use std::sync::Arc; + +use super::common::ensure_contiguous; + +// ============================================================================ +// VarBackward +// ============================================================================ + +/// Backward for variance reduction: z = var(a, dims, correction) +/// +/// The gradient of variance is: +/// dL/da = dL/dz * 2 * (a - mean(a)) / (N - correction) +/// +/// where N is the number of elements being reduced. +pub struct VarBackward { + input_id: TensorId, + saved_input: Tensor, + dims: Vec, + keepdim: bool, + correction: usize, + input_grad_fn: Option>>, +} + +impl VarBackward { + /// Create a new VarBackward + pub fn new( + input_id: TensorId, + input: Tensor, + dims: &[usize], + keepdim: bool, + correction: usize, + input_grad_fn: Option>>, + ) -> Self { + Self { + input_id, + saved_input: input, + dims: dims.to_vec(), + keepdim, + correction, + input_grad_fn, + } + } +} + +impl GradFn for VarBackward +where + R::Client: TensorOps + ScalarOps + ReduceOps, +{ + fn backward(&self, grad_output: &Tensor) -> Result>>> { + let client = R::default_client(grad_output.device()); + + let n: usize = self + .dims + .iter() + .map(|&d| self.saved_input.shape()[d]) + .product(); + let n_minus_corr = (n - self.correction) as f64; + + let mean = client.mean(&self.saved_input, &self.dims, true)?; + let mean_broadcast = ensure_contiguous(mean.broadcast_to(self.saved_input.shape())?); + + let centered = client.sub(&self.saved_input, &mean_broadcast)?; + + let scale = 2.0 / n_minus_corr; + let grad_contrib = client.mul_scalar(¢ered, scale)?; + + let mut grad = grad_output.clone(); + if !self.keepdim { + let mut sorted_dims = self.dims.clone(); + sorted_dims.sort(); + for &dim in &sorted_dims { + grad = grad.unsqueeze(dim as isize)?; + } + } + let grad_broadcast = ensure_contiguous(grad.broadcast_to(self.saved_input.shape())?); + + let grad_input = client.mul(&grad_broadcast, &grad_contrib)?; + + Ok(vec![Some(grad_input)]) + } + + fn backward_var(&self, grad_output: &Var) -> Result>>> + where + R::Client: RuntimeClient + TensorOps + ScalarOps + ReduceOps, + { + let client = R::default_client(grad_output.tensor().device()); + + let n: usize = self + .dims + .iter() + .map(|&d| self.saved_input.shape()[d]) + .product(); + let n_minus_corr = (n - self.correction) as f64; + + let mean = client.mean(&self.saved_input, &self.dims, true)?; + let mean_broadcast = ensure_contiguous(mean.broadcast_to(self.saved_input.shape())?); + + let centered = client.sub(&self.saved_input, &mean_broadcast)?; + + let scale = 2.0 / n_minus_corr; + let grad_contrib = client.mul_scalar(¢ered, scale)?; + + let mut grad_tensor = grad_output.tensor().clone(); + if !self.keepdim { + let mut sorted_dims = self.dims.clone(); + sorted_dims.sort(); + for &dim in &sorted_dims { + grad_tensor = grad_tensor.unsqueeze(dim as isize)?; + } + } + let grad_broadcast = ensure_contiguous(grad_tensor.broadcast_to(self.saved_input.shape())?); + + let grad_var = Var::new(grad_broadcast, grad_output.requires_grad()); + let contrib_var = Var::new(grad_contrib, false); + + let grad_input = var_mul(&grad_var, &contrib_var, &client)?; + + Ok(vec![Some(grad_input)]) + } + + fn inputs(&self) -> &[TensorId] { + std::slice::from_ref(&self.input_id) + } + + fn input_grad_fns(&self) -> Vec>>> { + vec![self.input_grad_fn.clone()] + } + + fn saved_tensors(&self) -> &[Tensor] { + std::slice::from_ref(&self.saved_input) + } + + fn name(&self) -> &'static str { + "VarBackward" + } +} + +// ============================================================================ +// StdBackward +// ============================================================================ + +/// Backward for standard deviation reduction: z = std(a, dims, correction) +/// +/// std = sqrt(var), so by chain rule: +/// dL/da = dL/dz * d(sqrt(var))/dvar * dvar/da +/// = dL/dz * 1/(2*std) * 2*(a - mean) / (N - correction) +/// = dL/dz * (a - mean) / ((N - correction) * std) +pub struct StdBackward { + input_id: TensorId, + saved_input: Tensor, + saved_output: Tensor, + dims: Vec, + keepdim: bool, + correction: usize, + input_grad_fn: Option>>, +} + +impl StdBackward { + /// Create a new StdBackward + pub fn new( + input_id: TensorId, + input: Tensor, + output: Tensor, + dims: &[usize], + keepdim: bool, + correction: usize, + input_grad_fn: Option>>, + ) -> Self { + Self { + input_id, + saved_input: input, + saved_output: output, + dims: dims.to_vec(), + keepdim, + correction, + input_grad_fn, + } + } +} + +impl GradFn for StdBackward +where + R::Client: TensorOps + ScalarOps + ReduceOps, +{ + fn backward(&self, grad_output: &Tensor) -> Result>>> { + let client = R::default_client(grad_output.device()); + + let n: usize = self + .dims + .iter() + .map(|&d| self.saved_input.shape()[d]) + .product(); + let n_minus_corr = (n - self.correction) as f64; + + let mean = client.mean(&self.saved_input, &self.dims, true)?; + let mean_broadcast = ensure_contiguous(mean.broadcast_to(self.saved_input.shape())?); + + let std_for_broadcast = if self.keepdim { + self.saved_output.clone() + } else { + let mut std_expanded = self.saved_output.clone(); + let mut sorted_dims = self.dims.clone(); + sorted_dims.sort(); + for &dim in &sorted_dims { + std_expanded = std_expanded.unsqueeze(dim as isize)?; + } + std_expanded + }; + let std_broadcast = + ensure_contiguous(std_for_broadcast.broadcast_to(self.saved_input.shape())?); + + let centered = client.sub(&self.saved_input, &mean_broadcast)?; + + let denominator = client.mul_scalar(&std_broadcast, n_minus_corr)?; + let grad_contrib = client.div(¢ered, &denominator)?; + + let mut grad = grad_output.clone(); + if !self.keepdim { + let mut sorted_dims = self.dims.clone(); + sorted_dims.sort(); + for &dim in &sorted_dims { + grad = grad.unsqueeze(dim as isize)?; + } + } + let grad_broadcast = ensure_contiguous(grad.broadcast_to(self.saved_input.shape())?); + + let grad_input = client.mul(&grad_broadcast, &grad_contrib)?; + + Ok(vec![Some(grad_input)]) + } + + fn backward_var(&self, grad_output: &Var) -> Result>>> + where + R::Client: RuntimeClient + TensorOps + ScalarOps + ReduceOps, + { + let client = R::default_client(grad_output.tensor().device()); + + let n: usize = self + .dims + .iter() + .map(|&d| self.saved_input.shape()[d]) + .product(); + let n_minus_corr = (n - self.correction) as f64; + + let mean = client.mean(&self.saved_input, &self.dims, true)?; + let mean_broadcast = ensure_contiguous(mean.broadcast_to(self.saved_input.shape())?); + + let std_for_broadcast = if self.keepdim { + self.saved_output.clone() + } else { + let mut std_expanded = self.saved_output.clone(); + let mut sorted_dims = self.dims.clone(); + sorted_dims.sort(); + for &dim in &sorted_dims { + std_expanded = std_expanded.unsqueeze(dim as isize)?; + } + std_expanded + }; + let std_broadcast = + ensure_contiguous(std_for_broadcast.broadcast_to(self.saved_input.shape())?); + + let centered = client.sub(&self.saved_input, &mean_broadcast)?; + + let denominator = client.mul_scalar(&std_broadcast, n_minus_corr)?; + let grad_contrib = client.div(¢ered, &denominator)?; + + let mut grad_tensor = grad_output.tensor().clone(); + if !self.keepdim { + let mut sorted_dims = self.dims.clone(); + sorted_dims.sort(); + for &dim in &sorted_dims { + grad_tensor = grad_tensor.unsqueeze(dim as isize)?; + } + } + let grad_broadcast = ensure_contiguous(grad_tensor.broadcast_to(self.saved_input.shape())?); + + let grad_var = Var::new(grad_broadcast, grad_output.requires_grad()); + let contrib_var = Var::new(grad_contrib, false); + + let grad_input = var_mul(&grad_var, &contrib_var, &client)?; + + Ok(vec![Some(grad_input)]) + } + + fn inputs(&self) -> &[TensorId] { + std::slice::from_ref(&self.input_id) + } + + fn input_grad_fns(&self) -> Vec>>> { + vec![self.input_grad_fn.clone()] + } + + fn saved_tensors(&self) -> &[Tensor] { + std::slice::from_ref(&self.saved_input) + } + + fn name(&self) -> &'static str { + "StdBackward" + } +} diff --git a/src/autograd/ops/reduce/sum_mean.rs b/src/autograd/ops/reduce/sum_mean.rs new file mode 100644 index 00000000..2a9b708f --- /dev/null +++ b/src/autograd/ops/reduce/sum_mean.rs @@ -0,0 +1,252 @@ +//! Backward implementations for sum and mean reductions + +use crate::autograd::GradFn; +use crate::autograd::var::Var; +use crate::autograd::var_ops::var_div_scalar; +use crate::error::Result; +use crate::ops::ScalarOps; +use crate::runtime::{Runtime, RuntimeClient}; +use crate::tensor::{Tensor, TensorId}; +use std::sync::Arc; + +use super::common::ensure_contiguous; + +// ============================================================================ +// SumBackward +// ============================================================================ + +/// Backward for sum reduction: z = sum(a, dims) +/// +/// The gradient of sum is broadcast expansion. +/// For z = sum(a, dims), dL/da = broadcast(dL/dz, original_shape) +/// +/// If keepdim=false, we need to unsqueeze the gradient before broadcasting. +pub struct SumBackward { + input_id: TensorId, + input_shape: Vec, + dims: Vec, + keepdim: bool, + input_grad_fn: Option>>, +} + +impl SumBackward { + /// Create a new SumBackward + pub fn new( + input_id: TensorId, + input_shape: &[usize], + dims: &[usize], + keepdim: bool, + input_grad_fn: Option>>, + ) -> Self { + Self { + input_id, + input_shape: input_shape.to_vec(), + dims: dims.to_vec(), + keepdim, + input_grad_fn, + } + } +} + +impl GradFn for SumBackward { + fn backward(&self, grad_output: &Tensor) -> Result>>> { + let mut grad = grad_output.clone(); + + if !self.keepdim { + let mut sorted_dims = self.dims.clone(); + sorted_dims.sort(); + for &dim in &sorted_dims { + grad = grad.unsqueeze(dim as isize)?; + } + } + + grad = ensure_contiguous(grad.broadcast_to(&self.input_shape)?); + + Ok(vec![Some(grad)]) + } + + fn backward_var(&self, grad_output: &Var) -> Result>>> { + let mut grad_tensor = grad_output.tensor().clone(); + + if !self.keepdim { + let mut sorted_dims = self.dims.clone(); + sorted_dims.sort(); + for &dim in &sorted_dims { + grad_tensor = grad_tensor.unsqueeze(dim as isize)?; + } + } + + grad_tensor = ensure_contiguous(grad_tensor.broadcast_to(&self.input_shape)?); + + Ok(vec![Some(Var::new(grad_tensor, true))]) + } + + fn inputs(&self) -> &[TensorId] { + std::slice::from_ref(&self.input_id) + } + + fn input_grad_fns(&self) -> Vec>>> { + vec![self.input_grad_fn.clone()] + } + + fn name(&self) -> &'static str { + "SumBackward" + } +} + +// ============================================================================ +// MeanBackward +// ============================================================================ + +/// Backward for mean reduction: z = mean(a, dims) +/// +/// For z = mean(a, dims), dL/da = broadcast(dL/dz, original_shape) / count +/// where count is the number of elements being averaged. +pub struct MeanBackward { + input_id: TensorId, + input_shape: Vec, + dims: Vec, + keepdim: bool, + input_grad_fn: Option>>, +} + +impl MeanBackward { + /// Create a new MeanBackward + pub fn new( + input_id: TensorId, + input_shape: &[usize], + dims: &[usize], + keepdim: bool, + input_grad_fn: Option>>, + ) -> Self { + Self { + input_id, + input_shape: input_shape.to_vec(), + dims: dims.to_vec(), + keepdim, + input_grad_fn, + } + } +} + +impl GradFn for MeanBackward +where + R::Client: ScalarOps, +{ + fn backward(&self, grad_output: &Tensor) -> Result>>> { + let client = R::default_client(grad_output.device()); + + let count: usize = self.dims.iter().map(|&d| self.input_shape[d]).product(); + let count_f64 = count as f64; + + let mut grad = grad_output.clone(); + + if !self.keepdim { + let mut sorted_dims = self.dims.clone(); + sorted_dims.sort(); + for &dim in &sorted_dims { + grad = grad.unsqueeze(dim as isize)?; + } + } + + grad = ensure_contiguous(grad.broadcast_to(&self.input_shape)?); + + let grad = client.div_scalar(&grad, count_f64)?; + + Ok(vec![Some(grad)]) + } + + fn backward_var(&self, grad_output: &Var) -> Result>>> + where + R::Client: RuntimeClient + crate::ops::TensorOps + ScalarOps, + { + let client = R::default_client(grad_output.tensor().device()); + + let count: usize = self.dims.iter().map(|&d| self.input_shape[d]).product(); + let count_f64 = count as f64; + + let mut grad_tensor = grad_output.tensor().clone(); + + if !self.keepdim { + let mut sorted_dims = self.dims.clone(); + sorted_dims.sort(); + for &dim in &sorted_dims { + grad_tensor = grad_tensor.unsqueeze(dim as isize)?; + } + } + + grad_tensor = ensure_contiguous(grad_tensor.broadcast_to(&self.input_shape)?); + + let grad_var = Var::new(grad_tensor, grad_output.requires_grad()); + let grad = var_div_scalar(&grad_var, count_f64, &client)?; + + Ok(vec![Some(grad)]) + } + + fn inputs(&self) -> &[TensorId] { + std::slice::from_ref(&self.input_id) + } + + fn input_grad_fns(&self) -> Vec>>> { + vec![self.input_grad_fn.clone()] + } + + fn name(&self) -> &'static str { + "MeanBackward" + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::dtype::DType; + use crate::runtime::cpu::{CpuDevice, CpuRuntime}; + + #[test] + fn test_sum_backward_keepdim() { + let device = CpuDevice::new(); + let grad_out = Tensor::::ones(&[2, 1], DType::F32, &device); + + let backward = SumBackward::::new(TensorId::new(), &[2, 3], &[1], true, None); + let grads = backward.backward(&grad_out).unwrap(); + + let grad_a = grads[0].as_ref().unwrap(); + assert_eq!(grad_a.shape(), &[2, 3]); + + let grad_data: Vec = grad_a.to_vec(); + assert_eq!(grad_data, vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0]); + } + + #[test] + fn test_sum_backward_no_keepdim() { + let device = CpuDevice::new(); + let grad_out = Tensor::::ones(&[2], DType::F32, &device); + + let backward = SumBackward::::new(TensorId::new(), &[2, 3], &[1], false, None); + let grads = backward.backward(&grad_out).unwrap(); + + let grad_a = grads[0].as_ref().unwrap(); + assert_eq!(grad_a.shape(), &[2, 3]); + + let grad_data: Vec = grad_a.to_vec(); + assert_eq!(grad_data, vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0]); + } + + #[test] + fn test_mean_backward() { + let device = CpuDevice::new(); + let grad_out = Tensor::::ones(&[2, 1], DType::F32, &device); + + let backward = MeanBackward::::new(TensorId::new(), &[2, 3], &[1], true, None); + let grads = backward.backward(&grad_out).unwrap(); + + let grad_a = grads[0].as_ref().unwrap(); + assert_eq!(grad_a.shape(), &[2, 3]); + + let grad_data: Vec = grad_a.to_vec(); + let expected = 1.0 / 3.0; + for val in grad_data { + assert!((val - expected).abs() < 1e-6); + } + } +} diff --git a/src/autograd/ops/shape.rs b/src/autograd/ops/shape.rs index c4a1876c..fcfc1fe0 100644 --- a/src/autograd/ops/shape.rs +++ b/src/autograd/ops/shape.rs @@ -9,8 +9,9 @@ //! is just reshaping the gradient back to the original shape. use crate::autograd::{GradFn, Var}; +use crate::dtype::DType; use crate::error::Result; -use crate::ops::ReduceOps; +use crate::ops::{ReduceOps, ShapeOps}; use crate::runtime::{Runtime, RuntimeClient}; use crate::tensor::{Tensor, TensorId}; use std::sync::Arc; @@ -419,6 +420,303 @@ where } } +// ============================================================================ +// NarrowBackward +// ============================================================================ + +/// Backward for narrow: z = narrow(a, dim, start, length) +/// +/// Gradient: dL/da is a zero tensor with dL/dz placed at the sliced region. +/// We use pad-with-zeros: create zeros of original shape, then add the gradient +/// into the narrow region. +pub struct NarrowBackward { + input_id: TensorId, + input_shape: Vec, + dim: usize, + start: usize, + input_grad_fn: Option>>, +} + +impl NarrowBackward { + /// Create a new `NarrowBackward` node. + /// + /// - `input_id` — ID of the input tensor before narrowing + /// - `input_shape` — original shape of the input tensor + /// - `dim` — dimension that was narrowed + /// - `start` — start index along `dim` + /// - `input_grad_fn` — gradient function of the input, if it requires grad + pub fn new( + input_id: TensorId, + input_shape: Vec, + dim: usize, + start: usize, + input_grad_fn: Option>>, + ) -> Self { + Self { + input_id, + input_shape, + dim, + start, + input_grad_fn, + } + } +} + +impl> GradFn for NarrowBackward +where + R::Client: RuntimeClient + crate::ops::TensorOps + ShapeOps, +{ + fn backward(&self, grad_output: &Tensor) -> Result>>> { + let client = R::default_client(grad_output.device()); + + // Pad gradient back to original size along the narrowed dimension. + // Before: zeros of size [start], After: zeros of size [orig_dim - start - length] + let length = grad_output.shape()[self.dim]; + let orig_dim_size = self.input_shape[self.dim]; + let end = self.start + length; + + let mut parts: Vec> = Vec::new(); + + // Padding before the narrow region + if self.start > 0 { + let mut pad_shape = self.input_shape.clone(); + pad_shape[self.dim] = self.start; + parts.push(Tensor::::zeros( + &pad_shape, + grad_output.dtype(), + grad_output.device(), + )); + } + + // The gradient itself (make contiguous for cat) + parts.push(grad_output.contiguous()); + + // Padding after the narrow region + if end < orig_dim_size { + let mut pad_shape = self.input_shape.clone(); + pad_shape[self.dim] = orig_dim_size - end; + parts.push(Tensor::::zeros( + &pad_shape, + grad_output.dtype(), + grad_output.device(), + )); + } + + let refs: Vec<&Tensor> = parts.iter().collect(); + let grad_input = client.cat(&refs, self.dim as isize)?; + + Ok(vec![Some(grad_input)]) + } + + fn backward_var(&self, grad_output: &Var) -> Result>>> { + let client = R::default_client(grad_output.tensor().device()); + + let length = grad_output.shape()[self.dim]; + let orig_dim_size = self.input_shape[self.dim]; + let end = self.start + length; + + let mut parts: Vec> = Vec::new(); + + if self.start > 0 { + let mut pad_shape = self.input_shape.clone(); + pad_shape[self.dim] = self.start; + parts.push(Tensor::::zeros( + &pad_shape, + grad_output.tensor().dtype(), + grad_output.tensor().device(), + )); + } + + parts.push(grad_output.tensor().contiguous()); + + if end < orig_dim_size { + let mut pad_shape = self.input_shape.clone(); + pad_shape[self.dim] = orig_dim_size - end; + parts.push(Tensor::::zeros( + &pad_shape, + grad_output.tensor().dtype(), + grad_output.tensor().device(), + )); + } + + let refs: Vec<&Tensor> = parts.iter().collect(); + let grad_input = client.cat(&refs, self.dim as isize)?; + + Ok(vec![Some(Var::new(grad_input, false))]) + } + + fn inputs(&self) -> &[TensorId] { + std::slice::from_ref(&self.input_id) + } + + fn input_grad_fns(&self) -> Vec>>> { + vec![self.input_grad_fn.clone()] + } + + fn name(&self) -> &'static str { + "NarrowBackward" + } +} + +// ============================================================================ +// CatBackward +// ============================================================================ + +/// Backward for cat: z = cat([a, b, ...], dim) +/// +/// Gradient: split dL/dz along dim, one slice per input. +pub struct CatBackward { + input_ids: Vec, + /// Size of each input along the cat dimension + split_sizes: Vec, + dim: usize, + input_grad_fns: Vec>>>, +} + +impl CatBackward { + /// Create a new `CatBackward` node. + /// + /// - `input_ids` — IDs of the input tensors that were concatenated + /// - `split_sizes` — size of each input along the cat dimension + /// - `dim` — dimension along which the inputs were concatenated + /// - `input_grad_fns` — gradient functions of each input, if they require grad + pub fn new( + input_ids: Vec, + split_sizes: Vec, + dim: usize, + input_grad_fns: Vec>>>, + ) -> Self { + Self { + input_ids, + split_sizes, + dim, + input_grad_fns, + } + } +} + +impl GradFn for CatBackward { + fn backward(&self, grad_output: &Tensor) -> Result>>> { + let mut grads = Vec::with_capacity(self.split_sizes.len()); + let mut offset = 0; + for &size in &self.split_sizes { + let grad_slice = grad_output.narrow(self.dim as isize, offset, size)?; + // Make contiguous so downstream ops get clean data + grads.push(Some(grad_slice.contiguous())); + offset += size; + } + Ok(grads) + } + + fn backward_var(&self, grad_output: &Var) -> Result>>> { + let mut grads = Vec::with_capacity(self.split_sizes.len()); + let mut offset = 0; + for &size in &self.split_sizes { + let grad_slice = grad_output + .tensor() + .narrow(self.dim as isize, offset, size)? + .contiguous(); + grads.push(Some(Var::new(grad_slice, false))); + offset += size; + } + Ok(grads) + } + + fn inputs(&self) -> &[TensorId] { + &self.input_ids + } + + fn input_grad_fns(&self) -> Vec>>> { + self.input_grad_fns.clone() + } + + fn name(&self) -> &'static str { + "CatBackward" + } +} + +// ============================================================================ +// Var Operations for Narrow and Cat +// ============================================================================ + +/// Narrow (slice) a Var along a dimension +/// +/// Creates NarrowBackward for gradient computation. +pub fn var_narrow>( + a: &Var, + dim: isize, + start: usize, + length: usize, +) -> Result> +where + R::Client: RuntimeClient + crate::ops::TensorOps + ShapeOps, +{ + let dim_idx = + a.tensor() + .layout() + .normalize_dim(dim) + .ok_or(crate::error::Error::InvalidDimension { + dim, + ndim: a.ndim(), + })?; + + let output = a.tensor().narrow(dim, start, length)?; + + if a.requires_grad() { + let grad_fn = NarrowBackward::::new( + a.id(), + a.shape().to_vec(), + dim_idx, + start, + a.grad_fn().cloned(), + ); + Ok(Var::from_op(output, Arc::new(grad_fn))) + } else { + Ok(Var::new(output, false)) + } +} + +/// Concatenate Vars along a dimension +/// +/// Creates CatBackward for gradient computation. +pub fn var_cat(vars: &[&Var], dim: isize, client: &C) -> Result> +where + R: Runtime, + C: RuntimeClient + crate::ops::ShapeOps, +{ + if vars.is_empty() { + return Err(crate::error::Error::InvalidArgument { + arg: "vars", + reason: "var_cat requires at least one input".into(), + }); + } + + let tensors: Vec<&Tensor> = vars.iter().map(|v| v.tensor()).collect(); + let output = client.cat(&tensors, dim)?; + + let any_requires_grad = vars.iter().any(|v| v.requires_grad()); + + if any_requires_grad { + // Normalize dim for split_sizes + let dim_idx = vars[0].tensor().layout().normalize_dim(dim).ok_or( + crate::error::Error::InvalidDimension { + dim, + ndim: vars[0].ndim(), + }, + )?; + + let input_ids: Vec = vars.iter().map(|v| v.id()).collect(); + let split_sizes: Vec = vars.iter().map(|v| v.shape()[dim_idx]).collect(); + let input_grad_fns: Vec>>> = + vars.iter().map(|v| v.grad_fn().cloned()).collect(); + + let grad_fn = CatBackward::::new(input_ids, split_sizes, dim_idx, input_grad_fns); + Ok(Var::from_op(output, Arc::new(grad_fn))) + } else { + Ok(Var::new(output, false)) + } +} + #[cfg(test)] mod tests { use super::*; @@ -629,4 +927,90 @@ mod tests { let grad_data: Vec = grad.to_vec(); assert_eq!(grad_data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); } + + #[test] + fn test_var_narrow() { + let device = CpuDevice::new(); + + let tensor = + Tensor::::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[6], &device); + let x = Var::new(tensor, true); + + let y = var_narrow(&x, 0, 1, 3).unwrap(); + assert_eq!(y.shape(), &[3]); + assert!(y.requires_grad()); + assert_eq!(y.grad_fn().unwrap().name(), "NarrowBackward"); + + let y_data: Vec = y.tensor().to_vec(); + assert_eq!(y_data, vec![2.0, 3.0, 4.0]); + } + + #[test] + fn test_narrow_backward() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + let x = Var::new( + Tensor::::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0], &[5], &device), + true, + ); + + // narrow(dim=0, start=1, length=3) -> [2.0, 3.0, 4.0] + let y = var_narrow(&x, 0, 1, 3).unwrap(); + let loss = crate::autograd::var_sum(&y, &[0], false, &client).unwrap(); + let grads = crate::autograd::backward(&loss, &client).unwrap(); + + let grad_x: Vec = grads.get(x.id()).unwrap().to_vec(); + // Gradient should be [0, 1, 1, 1, 0] — ones in the narrow region, zeros outside + assert_eq!(grad_x, vec![0.0, 1.0, 1.0, 1.0, 0.0]); + } + + #[test] + fn test_var_cat() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + let a = Var::new( + Tensor::::from_slice(&[1.0f32, 2.0], &[2], &device), + true, + ); + let b = Var::new( + Tensor::::from_slice(&[3.0f32, 4.0, 5.0], &[3], &device), + true, + ); + + let c = var_cat(&[&a, &b], 0, &client).unwrap(); + assert_eq!(c.shape(), &[5]); + assert!(c.requires_grad()); + assert_eq!(c.grad_fn().unwrap().name(), "CatBackward"); + + let c_data: Vec = c.tensor().to_vec(); + assert_eq!(c_data, vec![1.0, 2.0, 3.0, 4.0, 5.0]); + } + + #[test] + fn test_cat_backward() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + let a = Var::new( + Tensor::::from_slice(&[1.0f32, 2.0], &[2], &device), + true, + ); + let b = Var::new( + Tensor::::from_slice(&[3.0f32, 4.0, 5.0], &[3], &device), + true, + ); + + let c = var_cat(&[&a, &b], 0, &client).unwrap(); + let loss = crate::autograd::var_sum(&c, &[0], false, &client).unwrap(); + let grads = crate::autograd::backward(&loss, &client).unwrap(); + + let grad_a: Vec = grads.get(a.id()).unwrap().to_vec(); + let grad_b: Vec = grads.get(b.id()).unwrap().to_vec(); + + // Sum backward → all ones, split back to original sizes + assert_eq!(grad_a, vec![1.0, 1.0]); + assert_eq!(grad_b, vec![1.0, 1.0, 1.0]); + } } diff --git a/src/autograd/ops/unary.rs b/src/autograd/ops/unary.rs index de5ffe1f..a74ef77a 100644 --- a/src/autograd/ops/unary.rs +++ b/src/autograd/ops/unary.rs @@ -6,6 +6,7 @@ use crate::autograd::{ GradFn, Var, var_abs, var_cos, var_div, var_mul, var_mul_scalar, var_neg, var_sin, var_square, var_sub, }; +use crate::dtype::DType; use crate::error::Result; use crate::ops::{BinaryOps, CompareOps, ScalarOps, TensorOps, UnaryOps}; use crate::runtime::{Runtime, RuntimeClient}; @@ -363,7 +364,7 @@ impl TanhBackward { } } -impl GradFn for TanhBackward +impl> GradFn for TanhBackward where R::Client: TensorOps + ScalarOps, { @@ -685,7 +686,7 @@ impl ClampBackward { } } -impl GradFn for ClampBackward +impl> GradFn for ClampBackward where R::Client: TensorOps + ScalarOps + CompareOps, { diff --git a/src/autograd/var_grad_store.rs b/src/autograd/var_grad_store.rs index 1cdbfe7e..adddc4fd 100644 --- a/src/autograd/var_grad_store.rs +++ b/src/autograd/var_grad_store.rs @@ -17,7 +17,7 @@ use std::collections::HashMap; /// /// # Example /// -/// ``` +/// ```no_run /// # use numr::prelude::*; /// # use numr::autograd::{backward_with_graph, backward, Var, var_mul, var_sum}; /// # let device = CpuDevice::new(); diff --git a/src/autograd/var_ops/activation.rs b/src/autograd/var_ops/activation.rs index 88b12ffa..81f623ac 100644 --- a/src/autograd/var_ops/activation.rs +++ b/src/autograd/var_ops/activation.rs @@ -2,15 +2,16 @@ use super::ops::*; use crate::autograd::Var; +use crate::dtype::DType; use crate::error::Result; -use crate::ops::{CompareOps, ReduceOps, ScalarOps, TensorOps}; +use crate::ops::{ActivationOps, CompareOps, ReduceOps, ScalarOps, TensorOps, UnaryOps}; use crate::runtime::{Runtime, RuntimeClient}; use std::sync::Arc; /// ReLU: z = max(0, a) pub fn var_relu(a: &Var, client: &C) -> Result> where - R: Runtime, + R: Runtime, C: RuntimeClient + TensorOps + CompareOps, R::Client: TensorOps + CompareOps, { @@ -27,7 +28,7 @@ where /// Sigmoid: z = 1 / (1 + exp(-a)) pub fn var_sigmoid(a: &Var, client: &C) -> Result> where - R: Runtime, + R: Runtime, C: RuntimeClient + TensorOps, R::Client: TensorOps, { @@ -41,6 +42,59 @@ where } } +/// SiLU (Swish) activation: `z = a * sigmoid(a)` +/// +/// A smooth, non-monotonic activation function popular in modern architectures +/// (e.g., SwiGLU in LLaMA). Often preferred over ReLU for its non-zero gradient +/// at negative inputs. +/// +/// Gradient: `dz/da = sigmoid(a) * (1 + a - silu(a))` +pub fn var_silu(a: &Var, client: &C) -> Result> +where + R: Runtime, + C: RuntimeClient + TensorOps + ActivationOps + ScalarOps, + R::Client: TensorOps + ActivationOps + ScalarOps, +{ + let output = client.silu(a.tensor())?; + + if a.requires_grad() { + let grad_fn = SiluBackward::::new( + a.id(), + a.tensor().clone(), + output.clone(), + a.grad_fn().cloned(), + ); + Ok(Var::from_op(output, Arc::new(grad_fn))) + } else { + Ok(Var::new(output, false)) + } +} + +/// Softplus: `z = log(1 + exp(a))` +/// +/// A smooth, always-positive approximation to ReLU. Used in Mamba2 for dt +/// (step size) processing via `softplus(dt_proj(x)) + dt_bias`. +/// +/// Computed via the numerically stable form `relu(a) + log(1 + exp(-|a|))` +/// to avoid overflow for large positive inputs. +/// +/// Gradient: `dz/da = sigmoid(a)` +pub fn var_softplus(a: &Var, client: &C) -> Result> +where + R: Runtime, + C: RuntimeClient + ActivationOps, + R::Client: TensorOps + ActivationOps, +{ + let output = client.softplus(a.tensor())?; + + if a.requires_grad() { + let grad_fn = SoftplusBackward::::new(a.id(), a.tensor().clone(), a.grad_fn().cloned()); + Ok(Var::from_op(output, Arc::new(grad_fn))) + } else { + Ok(Var::new(output, false)) + } +} + /// Softmax along dimension: z_i = exp(a_i) / sum(exp(a)) pub fn var_softmax(a: &Var, dim: isize, client: &C) -> Result> where @@ -57,3 +111,21 @@ where Ok(Var::new(output, false)) } } + +/// Log-softmax along dimension: z = log(softmax(a, dim)) +pub fn var_log_softmax(a: &Var, dim: isize, client: &C) -> Result> +where + R: Runtime, + C: RuntimeClient + TensorOps + ActivationOps, + R::Client: TensorOps + UnaryOps + ReduceOps + ScalarOps, +{ + let output = client.log_softmax(a.tensor(), dim)?; + + if a.requires_grad() { + let grad_fn = + LogSoftmaxBackward::::new(a.id(), output.clone(), dim, a.grad_fn().cloned()); + Ok(Var::from_op(output, Arc::new(grad_fn))) + } else { + Ok(Var::new(output, false)) + } +} diff --git a/src/autograd/var_ops/cast.rs b/src/autograd/var_ops/cast.rs new file mode 100644 index 00000000..33079122 --- /dev/null +++ b/src/autograd/var_ops/cast.rs @@ -0,0 +1,95 @@ +//! Autograd-aware dtype casting + +use crate::autograd::Var; +use crate::autograd::var_ops::ops::CastBackward; +use crate::dtype::DType; +use crate::error::Result; +use crate::ops::TypeConversionOps; +use crate::runtime::{Runtime, RuntimeClient}; +use std::sync::Arc; + +/// Cast a variable to a different dtype, preserving gradient flow. +/// +/// The backward pass casts the gradient back to the input's original dtype. +/// +/// # Arguments +/// * `a` - Input variable +/// * `dtype` - Target dtype +/// * `client` - Runtime client +pub fn var_cast(a: &Var, dtype: DType, client: &C) -> Result> +where + R: Runtime, + C: RuntimeClient + TypeConversionOps, + R::Client: TypeConversionOps, +{ + let input_dtype = a.tensor().dtype(); + + // No-op if already the target dtype + if input_dtype == dtype { + return Ok(Var::with_id(a.tensor().clone(), a.id(), a.requires_grad())); + } + + let output = client.cast(a.tensor(), dtype)?; + + if a.requires_grad() { + let grad_fn = CastBackward::::new(a.id(), input_dtype); + Ok(Var::from_op(output, Arc::new(grad_fn))) + } else { + Ok(Var::new(output, false)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::autograd::backward; + use crate::runtime::cpu::{CpuDevice, CpuRuntime}; + use crate::tensor::Tensor; + + #[test] + fn test_var_cast_noop_same_dtype() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + let t = Tensor::::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device); + let v = Var::new(t, true); + let result = var_cast(&v, DType::F32, &client).unwrap(); + // Same dtype returns clone — data should match + assert_eq!(result.tensor().dtype(), DType::F32); + let data = result.tensor().to_vec::(); + assert_eq!(data, vec![1.0, 2.0, 3.0]); + } + + #[test] + fn test_var_cast_f32_to_f64_gradient() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + let t = Tensor::::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device); + let x = Var::new(t, true); + + // Cast F32 → F64 + let y = var_cast(&x, DType::F64, &client).unwrap(); + assert_eq!(y.tensor().dtype(), DType::F64); + + // Sum to scalar for backward + let sum = crate::autograd::var_sum(&y, &[], false, &client).unwrap(); + let grads = backward(&sum, &client).unwrap(); + + // Gradient should be F32 (cast back from F64) + let grad = grads.get(x.id()).unwrap(); + assert_eq!(grad.dtype(), DType::F32); + let data = grad.to_vec::(); + assert_eq!(data, vec![1.0, 1.0, 1.0]); + } + + #[test] + fn test_var_cast_no_grad() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + let t = Tensor::::from_slice(&[1.0f32, 2.0], &[2], &device); + let v = Var::new(t, false); + let result = var_cast(&v, DType::F64, &client).unwrap(); + assert!(!result.requires_grad()); + assert_eq!(result.tensor().dtype(), DType::F64); + } +} diff --git a/src/autograd/var_ops/conv1d.rs b/src/autograd/var_ops/conv1d.rs new file mode 100644 index 00000000..f2aaed29 --- /dev/null +++ b/src/autograd/var_ops/conv1d.rs @@ -0,0 +1,535 @@ +//! Conv1d autograd operation +//! +//! Wraps `ConvOps::conv1d` with gradient tracking. +//! +//! Backward computes: +//! - d_input = transposed convolution of grad_output with weight +//! - d_weight = cross-correlation of input with grad_output +//! - d_bias = sum(grad_output) over batch and spatial dims + +use crate::autograd::Var; +use crate::dtype::DType; +use crate::error::Result; +use crate::ops::{BinaryOps, ConvOps, PaddingMode, ReduceOps, ScalarOps, TensorOps}; +use crate::runtime::{Runtime, RuntimeClient}; +use std::sync::Arc; + +use super::conv_common::compute_padding; + +/// Differentiable 1D convolution. +/// +/// Wraps the forward `conv1d` and builds autograd graph for backward. +/// +/// # Arguments +/// * `input` - Input Var of shape `[batch, in_channels, length]` +/// * `weight` - Weight Var of shape `[out_channels, in_channels/groups, kernel_size]` +/// * `bias` - Optional bias Var of shape `[out_channels]` +/// * `stride` - Stride +/// * `padding` - Padding mode +/// * `dilation` - Dilation +/// * `groups` - Groups +/// * `client` - Runtime client +pub fn var_conv1d( + input: &Var, + weight: &Var, + bias: Option<&Var>, + stride: usize, + padding: PaddingMode, + dilation: usize, + groups: usize, + client: &C, +) -> Result> +where + R: Runtime, + C: RuntimeClient + ConvOps + TensorOps + ReduceOps + BinaryOps + ScalarOps, + R::Client: ConvOps + TensorOps + ReduceOps + BinaryOps + ScalarOps, +{ + let output = client.conv1d( + input.tensor(), + weight.tensor(), + bias.map(|b| b.tensor()), + stride, + padding, + dilation, + groups, + )?; + + let needs_grad = + input.requires_grad() || weight.requires_grad() || bias.is_some_and(|b| b.requires_grad()); + + if needs_grad { + let grad_fn = Conv1dBackward::::new( + input.id(), + weight.id(), + bias.map(|b| b.id()), + input.tensor().clone(), + weight.tensor().clone(), + input.tensor().shape().to_vec(), + stride, + padding, + dilation, + groups, + input.grad_fn().cloned(), + weight.grad_fn().cloned(), + bias.and_then(|b| b.grad_fn().cloned()), + ); + Ok(Var::from_op(output, Arc::new(grad_fn))) + } else { + Ok(Var::new(output, false)) + } +} + +/// Backward for conv1d. +/// +/// Computes gradients for input, weight, and bias using: +/// - d_input: transposed convolution (conv with flipped kernel, adjusted padding) +/// - d_weight: cross-correlation of input with grad_output +/// - d_bias: sum of grad_output over batch and spatial dims +pub struct Conv1dBackward { + input_ids: Vec, + saved_input: crate::tensor::Tensor, + saved_weight: crate::tensor::Tensor, + input_shape: Vec, + stride: usize, + padding: PaddingMode, + dilation: usize, + groups: usize, + input_grad_fn: Option>>, + weight_grad_fn: Option>>, + bias_grad_fn: Option>>, +} + +impl Conv1dBackward { + #[allow(clippy::too_many_arguments)] + pub fn new( + input_id: crate::tensor::TensorId, + weight_id: crate::tensor::TensorId, + bias_id: Option, + input: crate::tensor::Tensor, + + weight: crate::tensor::Tensor, + input_shape: Vec, + stride: usize, + padding: PaddingMode, + dilation: usize, + groups: usize, + input_grad_fn: Option>>, + weight_grad_fn: Option>>, + bias_grad_fn: Option>>, + ) -> Self { + let mut ids = vec![input_id, weight_id]; + if let Some(bid) = bias_id { + ids.push(bid); + } + Self { + input_ids: ids, + saved_input: input, + saved_weight: weight, + input_shape, + stride, + padding, + dilation, + groups, + input_grad_fn, + weight_grad_fn, + bias_grad_fn, + } + } +} + +/// Compute conv1d backward for input using tensor operations. +/// +/// d_input[n, c_in, l] = sum over c_out, k of: +/// weight[c_out, c_in, k] * grad_output[n, c_out, l*stride - pad + k*dilation] +/// +/// This is equivalent to a transposed convolution (conv_transpose1d). +/// +/// IMPLEMENTATION NOTE: Uses tensor operations (no to_vec/to_cpu). All computation +/// is performed through the client, which works on any backend. The Rust loop +/// structures the iteration, but actual mathematical operations (matmul, add) +/// happen on the device via the client. +fn conv1d_input_backward( + client: &C, + grad_output: &crate::tensor::Tensor, + weight: &crate::tensor::Tensor, + input_shape: &[usize], + stride: usize, + padding: PaddingMode, + dilation: usize, + groups: usize, +) -> Result> +where + R: Runtime, + C: TensorOps + BinaryOps + ReduceOps + ScalarOps, +{ + let batch = input_shape[0]; + let _c_in = input_shape[1]; + let input_len = input_shape[2]; + let c_out = weight.shape()[0]; + let c_in_per_group = weight.shape()[1]; + let kernel_size = weight.shape()[2]; + let output_len = grad_output.shape()[2]; + let c_out_per_group = c_out / groups; + + let (pad_left, _pad_right) = compute_padding(padding, kernel_size, dilation); + + let device = grad_output.device(); + let dtype = grad_output.dtype(); + + let mut d_input = crate::tensor::Tensor::::zeros(input_shape, dtype, device); + + // Accumulate contributions by iterating and accumulating tensor operations + for k in 0..kernel_size { + let weight_k = weight.narrow(2, k, 1)?; + let weight_k = weight_k.squeeze(Some(2)); + + for o in 0..output_len { + let i_pos = o * stride + k * dilation; + + if i_pos >= pad_left && i_pos < pad_left + input_len { + let i = i_pos - pad_left; + + let grad_o = grad_output.narrow(2, o, 1)?; + let grad_o = grad_o.squeeze(Some(2)); + + for g in 0..groups { + let c_in_start = g * c_in_per_group; + let c_out_start = g * c_out_per_group; + + let grad_g = grad_o.narrow(1, c_out_start, c_out_per_group)?; + let weight_g = weight_k.narrow(0, c_out_start, c_out_per_group)?; + + // Compute contribution: [batch, c_out_per_group] @ [c_out_per_group, c_in_per_group].T + let contrib_g = client.matmul(&grad_g, &weight_g.transpose(0, 1)?)?; + + // Reshape to [batch, c_in_per_group, 1] + let contrib_g_3d = contrib_g.reshape(&[batch, c_in_per_group, 1])?; + + // Get the slice at position i in the full d_input + let mut d_input_at_i = d_input.narrow(2, i, 1)?; // [batch, c_in, 1] + + // Get the group slice + let d_input_group = d_input_at_i.narrow(1, c_in_start, c_in_per_group)?; // [batch, c_in_per_group, 1] + + // Add contribution + let updated_group = client.add(&d_input_group, &contrib_g_3d)?; + + // Now put it back. We need to use slice_assign correctly. + // The challenge is that we have a [batch, c_in_per_group, 1] but + // we need to update a specific region of a [batch, c_in, 1]. + // slice_assign along dim 1 requires src to have the same dimension count + // and the same size on all dims except dim. + // So src should be [batch, c_in_per_group, 1] and we use dim=1, start=c_in_start + d_input_at_i = + client.slice_assign(&d_input_at_i, &updated_group, 1, c_in_start)?; + + // Now put d_input_at_i back into d_input at position i + d_input = client.slice_assign(&d_input, &d_input_at_i, 2, i)?; + } + } + } + } + + Ok(d_input) +} + +/// Compute conv1d backward for weight using tensor operations. +/// +/// d_weight[c_out, c_in, k] = sum over n, o of: +/// input[n, c_in, o*stride - pad + k*dilation] * grad_output[n, c_out, o] +/// +/// This function uses only tensor operations (no to_vec/to_cpu). All computation +/// is performed through the client, which works on any backend. +fn conv1d_weight_backward( + client: &C, + grad_output: &crate::tensor::Tensor, + input: &crate::tensor::Tensor, + weight_shape: &[usize], + stride: usize, + padding: PaddingMode, + dilation: usize, + groups: usize, +) -> Result> +where + R: Runtime, + C: TensorOps + BinaryOps + ReduceOps + ScalarOps, +{ + let _batch = input.shape()[0]; + let _c_in = input.shape()[1]; + let input_len = input.shape()[2]; + let c_out = weight_shape[0]; + let c_in_per_group = weight_shape[1]; + let kernel_size = weight_shape[2]; + let output_len = grad_output.shape()[2]; + let c_out_per_group = c_out / groups; + + let (pad_left, _pad_right) = compute_padding(padding, kernel_size, dilation); + + let device = grad_output.device(); + let dtype = grad_output.dtype(); + + let mut d_weight = crate::tensor::Tensor::::zeros(weight_shape, dtype, device); + + // Accumulate contributions by iterating and accumulating tensor operations + for o in 0..output_len { + for k in 0..kernel_size { + let i_pos = o * stride + k * dilation; + + if i_pos >= pad_left && i_pos < pad_left + input_len { + let i = i_pos - pad_left; + + let input_i = input.narrow(2, i, 1)?; + let input_i = input_i.squeeze(Some(2)); + + let grad_o = grad_output.narrow(2, o, 1)?; + let grad_o = grad_o.squeeze(Some(2)); + + for g in 0..groups { + let c_in_start = g * c_in_per_group; + let c_out_start = g * c_out_per_group; + + let input_g = input_i.narrow(1, c_in_start, c_in_per_group)?; + let grad_g = grad_o.narrow(1, c_out_start, c_out_per_group)?; + + // Compute: [c_out_per_group, batch] @ [batch, c_in_per_group] + // = [c_out_per_group, c_in_per_group] + let contrib_2d = client.matmul(&grad_g.transpose(0, 1)?, &input_g)?; + + // Reshape to [c_out_per_group, c_in_per_group, 1] + let contrib_3d = contrib_2d.reshape(&[c_out_per_group, c_in_per_group, 1])?; + + // Get the weight slice at kernel position k + let mut d_weight_at_k = d_weight.narrow(2, k, 1)?; // [c_out, c_in_per_group, 1] + + // Get the group slice + let d_weight_group = d_weight_at_k.narrow(0, c_out_start, c_out_per_group)?; // [c_out_per_group, c_in_per_group, 1] + + // Add contribution + let updated_group = client.add(&d_weight_group, &contrib_3d)?; + + // Put back along dimension 0 + d_weight_at_k = + client.slice_assign(&d_weight_at_k, &updated_group, 0, c_out_start)?; + + // Put back into d_weight along dimension 2 + d_weight = client.slice_assign(&d_weight, &d_weight_at_k, 2, k)?; + } + } + } + } + + Ok(d_weight) +} + +impl> crate::autograd::GradFn for Conv1dBackward +where + R::Client: ConvOps + TensorOps + ReduceOps + BinaryOps + ScalarOps, +{ + fn backward( + &self, + grad_output: &crate::tensor::Tensor, + ) -> Result>>> { + let client = R::default_client(grad_output.device()); + + // d_input via transposed convolution + let d_input = conv1d_input_backward::( + &client, + grad_output, + &self.saved_weight, + &self.input_shape, + self.stride, + self.padding, + self.dilation, + self.groups, + )?; + + // d_weight via cross-correlation + let d_weight = conv1d_weight_backward::( + &client, + grad_output, + &self.saved_input, + self.saved_weight.shape(), + self.stride, + self.padding, + self.dilation, + self.groups, + )?; + + // d_bias = sum over batch and length dims + let d_bias = if self.input_ids.len() > 2 { + // grad_output shape: [batch, c_out, output_len] + // sum over dim 0 (batch) and dim 2 (length) → [c_out] + let summed = client.sum(grad_output, &[0, 2], false)?; + Some(summed) + } else { + None + }; + + Ok(vec![Some(d_input), Some(d_weight), d_bias]) + } + + fn backward_var(&self, grad_output: &Var) -> Result>>> + where + R::Client: RuntimeClient + + ConvOps + + TensorOps + + ReduceOps + + BinaryOps + + ScalarOps, + { + // First-order only for conv — second-order conv is rarely needed + let grads = self.backward(grad_output.tensor())?; + Ok(grads + .into_iter() + .map(|g| g.map(|t| Var::new(t, true))) + .collect()) + } + + fn inputs(&self) -> &[crate::tensor::TensorId] { + &self.input_ids + } + + fn input_grad_fns(&self) -> Vec>>> { + let mut fns = vec![self.input_grad_fn.clone(), self.weight_grad_fn.clone()]; + if self.input_ids.len() > 2 { + fns.push(self.bias_grad_fn.clone()); + } + fns + } + + fn saved_tensors(&self) -> &[crate::tensor::Tensor] { + std::slice::from_ref(&self.saved_input) + } + + fn name(&self) -> &'static str { + "Conv1dBackward" + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::autograd::backward; + use crate::runtime::cpu::{CpuDevice, CpuRuntime}; + use crate::tensor::Tensor; + + #[test] + fn test_var_conv1d_forward() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + // weight: [out=1, in=1, kernel=1] → identity-like + let input = Var::new( + Tensor::::from_slice(&[1.0f32, 2.0, 3.0], &[1, 1, 3], &device), + false, + ); + let weight = Var::new( + Tensor::::from_slice(&[2.0f32], &[1, 1, 1], &device), + false, + ); + + let output = + var_conv1d(&input, &weight, None, 1, PaddingMode::Valid, 1, 1, &client).unwrap(); + let data: Vec = output.tensor().to_vec(); + assert_eq!(data, vec![2.0, 4.0, 6.0]); + } + + #[test] + fn test_var_conv1d_backward_input() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + let input = Var::new( + Tensor::::from_slice(&[1.0f32, 2.0, 3.0], &[1, 1, 3], &device), + true, + ); + let weight = Var::new( + Tensor::::from_slice(&[2.0f32], &[1, 1, 1], &device), + true, + ); + + let output = + var_conv1d(&input, &weight, None, 1, PaddingMode::Valid, 1, 1, &client).unwrap(); + let loss = crate::autograd::var_sum(&output, &[], false, &client).unwrap(); + let grads = backward(&loss, &client).unwrap(); + + let d_input: Vec = grads.get(input.id()).unwrap().to_vec(); + // d_input should be weight broadcast: [2, 2, 2] + assert_eq!(d_input, vec![2.0, 2.0, 2.0]); + + let d_weight: Vec = grads.get(weight.id()).unwrap().to_vec(); + // d_weight = sum of input = 1+2+3 = 6 + assert!((d_weight[0] - 6.0).abs() < 1e-5); + } + + #[test] + fn test_var_conv1d_backward_with_bias() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + let input = Var::new( + Tensor::::from_slice(&[1.0f32, 2.0], &[1, 1, 2], &device), + true, + ); + let weight = Var::new( + Tensor::::from_slice(&[1.0f32], &[1, 1, 1], &device), + true, + ); + let bias = Var::new( + Tensor::::from_slice(&[10.0f32], &[1], &device), + true, + ); + + let output = var_conv1d( + &input, + &weight, + Some(&bias), + 1, + PaddingMode::Valid, + 1, + 1, + &client, + ) + .unwrap(); + let loss = crate::autograd::var_sum(&output, &[], false, &client).unwrap(); + let grads = backward(&loss, &client).unwrap(); + + let d_bias: Vec = grads.get(bias.id()).unwrap().to_vec(); + // d_bias = sum of grad_output (all ones) over batch and length = 2 + assert!((d_bias[0] - 2.0).abs() < 1e-5); + } + + #[test] + fn test_var_conv1d_kernel3() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + // kernel_size=3, input_length=5 → output_length=3 + let input = Var::new( + Tensor::::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5], &device), + true, + ); + let weight = Var::new( + Tensor::::from_slice(&[1.0f32, 1.0, 1.0], &[1, 1, 3], &device), + true, + ); + + let output = + var_conv1d(&input, &weight, None, 1, PaddingMode::Valid, 1, 1, &client).unwrap(); + let data: Vec = output.tensor().to_vec(); + // [1+2+3, 2+3+4, 3+4+5] = [6, 9, 12] + assert_eq!(data, vec![6.0, 9.0, 12.0]); + + let loss = crate::autograd::var_sum(&output, &[], false, &client).unwrap(); + let grads = backward(&loss, &client).unwrap(); + + let d_input: Vec = grads.get(input.id()).unwrap().to_vec(); + // Each input position contributes to 1-3 output positions + // pos 0: contributes to output 0 → weight[0] = 1 + // pos 1: contributes to outputs 0,1 → weight[1]+weight[0] = 2 + // pos 2: contributes to outputs 0,1,2 → weight[2]+weight[1]+weight[0] = 3 + // pos 3: contributes to outputs 1,2 → weight[2]+weight[1] = 2 + // pos 4: contributes to output 2 → weight[2] = 1 + assert_eq!(d_input, vec![1.0, 2.0, 3.0, 2.0, 1.0]); + } +} diff --git a/src/autograd/var_ops/conv2d.rs b/src/autograd/var_ops/conv2d.rs new file mode 100644 index 00000000..62fa2ff1 --- /dev/null +++ b/src/autograd/var_ops/conv2d.rs @@ -0,0 +1,595 @@ +//! Conv2d autograd operation +//! +//! Wraps `ConvOps::conv2d` with gradient tracking. +//! +//! Backward computes: +//! - d_input = transposed convolution of grad_output with weight +//! - d_weight = cross-correlation of input with grad_output +//! - d_bias = sum(grad_output) over batch and spatial dims + +use crate::autograd::Var; +use crate::dtype::DType; +use crate::error::Result; +use crate::ops::{BinaryOps, ConvOps, PaddingMode, ReduceOps, ScalarOps, TensorOps}; +use crate::runtime::{Runtime, RuntimeClient}; +use std::sync::Arc; + +use super::conv_common::compute_padding_2d; + +/// Differentiable 2D convolution. +/// +/// Wraps the forward `conv2d` and builds autograd graph for backward. +/// +/// # Arguments +/// * `input` - Input Var of shape `[batch, in_channels, height, width]` +/// * `weight` - Weight Var of shape `[out_channels, in_channels/groups, kH, kW]` +/// * `bias` - Optional bias Var of shape `[out_channels]` +/// * `stride` - Stride as `(stride_h, stride_w)` +/// * `padding` - Padding mode +/// * `dilation` - Dilation as `(dilation_h, dilation_w)` +/// * `groups` - Groups +/// * `client` - Runtime client +pub fn var_conv2d( + input: &Var, + weight: &Var, + bias: Option<&Var>, + stride: (usize, usize), + padding: PaddingMode, + dilation: (usize, usize), + groups: usize, + client: &C, +) -> Result> +where + R: Runtime, + C: RuntimeClient + ConvOps + TensorOps + ReduceOps + BinaryOps + ScalarOps, + R::Client: ConvOps + TensorOps + ReduceOps + BinaryOps + ScalarOps, +{ + let output = client.conv2d( + input.tensor(), + weight.tensor(), + bias.map(|b| b.tensor()), + stride, + padding, + dilation, + groups, + )?; + + let needs_grad = + input.requires_grad() || weight.requires_grad() || bias.is_some_and(|b| b.requires_grad()); + + if needs_grad { + let grad_fn = Conv2dBackward::::new( + input.id(), + weight.id(), + bias.map(|b| b.id()), + input.tensor().clone(), + weight.tensor().clone(), + input.tensor().shape().to_vec(), + stride, + padding, + dilation, + groups, + input.grad_fn().cloned(), + weight.grad_fn().cloned(), + bias.and_then(|b| b.grad_fn().cloned()), + ); + Ok(Var::from_op(output, Arc::new(grad_fn))) + } else { + Ok(Var::new(output, false)) + } +} + +/// Backward for conv2d. +/// +/// Computes gradients for input, weight, and bias using: +/// - d_input: transposed convolution (conv with flipped kernel, adjusted padding) +/// - d_weight: cross-correlation of input with grad_output +/// - d_bias: sum of grad_output over batch and spatial dims +pub struct Conv2dBackward { + input_ids: Vec, + saved_input: crate::tensor::Tensor, + saved_weight: crate::tensor::Tensor, + input_shape: Vec, + stride: (usize, usize), + padding: PaddingMode, + dilation: (usize, usize), + groups: usize, + input_grad_fn: Option>>, + weight_grad_fn: Option>>, + bias_grad_fn: Option>>, +} + +impl Conv2dBackward { + #[allow(clippy::too_many_arguments)] + pub fn new( + input_id: crate::tensor::TensorId, + weight_id: crate::tensor::TensorId, + bias_id: Option, + input: crate::tensor::Tensor, + weight: crate::tensor::Tensor, + input_shape: Vec, + stride: (usize, usize), + padding: PaddingMode, + dilation: (usize, usize), + groups: usize, + input_grad_fn: Option>>, + weight_grad_fn: Option>>, + bias_grad_fn: Option>>, + ) -> Self { + let mut ids = vec![input_id, weight_id]; + if let Some(bid) = bias_id { + ids.push(bid); + } + Self { + input_ids: ids, + saved_input: input, + saved_weight: weight, + input_shape, + stride, + padding, + dilation, + groups, + input_grad_fn, + weight_grad_fn, + bias_grad_fn, + } + } +} + +/// Compute conv2d backward for input using tensor operations. +/// +/// d_input[n, c_in, h, w] = sum over c_out, kh, kw of: +/// weight[c_out, c_in, kh, kw] * grad_output[n, c_out, h*sh - pad_top + kh*dh, w*sw - pad_left + kw*dw] +fn conv2d_input_backward( + client: &C, + grad_output: &crate::tensor::Tensor, + weight: &crate::tensor::Tensor, + input_shape: &[usize], + stride: (usize, usize), + padding: PaddingMode, + dilation: (usize, usize), + groups: usize, +) -> Result> +where + R: Runtime, + C: TensorOps + BinaryOps + ReduceOps + ScalarOps, +{ + let batch = input_shape[0]; + let _c_in = input_shape[1]; + let input_h = input_shape[2]; + let input_w = input_shape[3]; + let c_out = weight.shape()[0]; + let c_in_per_group = weight.shape()[1]; + let kernel_h = weight.shape()[2]; + let kernel_w = weight.shape()[3]; + let output_h = grad_output.shape()[2]; + let output_w = grad_output.shape()[3]; + let c_out_per_group = c_out / groups; + + let (pad_top, _pad_bottom, pad_left, _pad_right) = + compute_padding_2d(padding, kernel_h, kernel_w, dilation.0, dilation.1); + + let device = grad_output.device(); + let dtype = grad_output.dtype(); + + let mut d_input = crate::tensor::Tensor::::zeros(input_shape, dtype, device); + + for kh in 0..kernel_h { + for kw in 0..kernel_w { + // Extract weight slice at [kh, kw]: weight[:, :, kh, kw] → [c_out, c_in_per_group] + let weight_kh = weight.narrow(2, kh, 1)?; + let weight_khkw = weight_kh.narrow(3, kw, 1)?; + let weight_2d = weight_khkw.squeeze(Some(3)).squeeze(Some(2)); + + for oh in 0..output_h { + let ih_pos = oh * stride.0 + kh * dilation.0; + if ih_pos < pad_top || ih_pos >= pad_top + input_h { + continue; + } + let ih = ih_pos - pad_top; + + for ow in 0..output_w { + let iw_pos = ow * stride.1 + kw * dilation.1; + if iw_pos < pad_left || iw_pos >= pad_left + input_w { + continue; + } + let iw = iw_pos - pad_left; + + // grad_output[:, :, oh, ow] → [batch, c_out] + let grad_o = grad_output.narrow(2, oh, 1)?.narrow(3, ow, 1)?; + let grad_o_2d = grad_o.squeeze(Some(3)).squeeze(Some(2)); + + for g in 0..groups { + let c_in_start = g * c_in_per_group; + let c_out_start = g * c_out_per_group; + + let grad_g = grad_o_2d.narrow(1, c_out_start, c_out_per_group)?; + let weight_g = weight_2d.narrow(0, c_out_start, c_out_per_group)?; + + // [batch, c_out_per_group] @ [c_out_per_group, c_in_per_group] + let contrib_g = client.matmul(&grad_g, &weight_g.transpose(0, 1)?)?; + + // Reshape to [batch, c_in_per_group, 1, 1] + let contrib_4d = contrib_g.reshape(&[batch, c_in_per_group, 1, 1])?; + + // Get the slice at position (ih, iw) in the full d_input + let mut d_input_at = d_input.narrow(2, ih, 1)?.narrow(3, iw, 1)?; + + // Get the group slice + let d_input_group = d_input_at.narrow(1, c_in_start, c_in_per_group)?; + + // Add contribution + let updated_group = client.add(&d_input_group, &contrib_4d)?; + + // Put back along dim 1 + d_input_at = + client.slice_assign(&d_input_at, &updated_group, 1, c_in_start)?; + + // Put back into d_input: first along dim 3 (width), then dim 2 (height) + let mut d_input_h = d_input.narrow(2, ih, 1)?; + d_input_h = client.slice_assign(&d_input_h, &d_input_at, 3, iw)?; + d_input = client.slice_assign(&d_input, &d_input_h, 2, ih)?; + } + } + } + } + } + + Ok(d_input) +} + +/// Compute conv2d backward for weight using tensor operations. +/// +/// d_weight[c_out, c_in, kh, kw] = sum over n, oh, ow of: +/// input[n, c_in, oh*sh - pad_top + kh*dh, ow*sw - pad_left + kw*dw] * grad_output[n, c_out, oh, ow] +fn conv2d_weight_backward( + client: &C, + grad_output: &crate::tensor::Tensor, + input: &crate::tensor::Tensor, + weight_shape: &[usize], + stride: (usize, usize), + padding: PaddingMode, + dilation: (usize, usize), + groups: usize, +) -> Result> +where + R: Runtime, + C: TensorOps + BinaryOps + ReduceOps + ScalarOps, +{ + let _batch = input.shape()[0]; + let _c_in = input.shape()[1]; + let input_h = input.shape()[2]; + let input_w = input.shape()[3]; + let c_out = weight_shape[0]; + let c_in_per_group = weight_shape[1]; + let kernel_h = weight_shape[2]; + let kernel_w = weight_shape[3]; + let output_h = grad_output.shape()[2]; + let output_w = grad_output.shape()[3]; + let c_out_per_group = c_out / groups; + + let (pad_top, _pad_bottom, pad_left, _pad_right) = + compute_padding_2d(padding, kernel_h, kernel_w, dilation.0, dilation.1); + + let device = grad_output.device(); + let dtype = grad_output.dtype(); + + let mut d_weight = crate::tensor::Tensor::::zeros(weight_shape, dtype, device); + + for oh in 0..output_h { + for ow in 0..output_w { + // grad_output[:, :, oh, ow] → [batch, c_out] + let grad_o = grad_output.narrow(2, oh, 1)?.narrow(3, ow, 1)?; + let grad_o_2d = grad_o.squeeze(Some(3)).squeeze(Some(2)); + + for kh in 0..kernel_h { + let ih_pos = oh * stride.0 + kh * dilation.0; + if ih_pos < pad_top || ih_pos >= pad_top + input_h { + continue; + } + let ih = ih_pos - pad_top; + + for kw in 0..kernel_w { + let iw_pos = ow * stride.1 + kw * dilation.1; + if iw_pos < pad_left || iw_pos >= pad_left + input_w { + continue; + } + let iw = iw_pos - pad_left; + + // input[:, :, ih, iw] → [batch, c_in] + let input_hw = input.narrow(2, ih, 1)?.narrow(3, iw, 1)?; + let input_2d = input_hw.squeeze(Some(3)).squeeze(Some(2)); + + for g in 0..groups { + let c_in_start = g * c_in_per_group; + let c_out_start = g * c_out_per_group; + + let input_g = input_2d.narrow(1, c_in_start, c_in_per_group)?; + let grad_g = grad_o_2d.narrow(1, c_out_start, c_out_per_group)?; + + // [c_out_per_group, batch] @ [batch, c_in_per_group] + // = [c_out_per_group, c_in_per_group] + let contrib_2d = client.matmul(&grad_g.transpose(0, 1)?, &input_g)?; + + // Reshape to [c_out_per_group, c_in_per_group, 1, 1] + let contrib_4d = + contrib_2d.reshape(&[c_out_per_group, c_in_per_group, 1, 1])?; + + // Get the weight slice at kernel position (kh, kw) + let mut d_weight_at = d_weight.narrow(2, kh, 1)?.narrow(3, kw, 1)?; + + // Get the group slice + let d_weight_group = d_weight_at.narrow(0, c_out_start, c_out_per_group)?; + + // Add contribution + let updated_group = client.add(&d_weight_group, &contrib_4d)?; + + // Put back along dim 0 + d_weight_at = + client.slice_assign(&d_weight_at, &updated_group, 0, c_out_start)?; + + // Put back into d_weight: first along dim 3, then dim 2 + let mut d_weight_kh = d_weight.narrow(2, kh, 1)?; + d_weight_kh = client.slice_assign(&d_weight_kh, &d_weight_at, 3, kw)?; + d_weight = client.slice_assign(&d_weight, &d_weight_kh, 2, kh)?; + } + } + } + } + } + + Ok(d_weight) +} + +impl> crate::autograd::GradFn for Conv2dBackward +where + R::Client: ConvOps + TensorOps + ReduceOps + BinaryOps + ScalarOps, +{ + fn backward( + &self, + grad_output: &crate::tensor::Tensor, + ) -> Result>>> { + let client = R::default_client(grad_output.device()); + + // d_input via transposed convolution + let d_input = conv2d_input_backward::( + &client, + grad_output, + &self.saved_weight, + &self.input_shape, + self.stride, + self.padding, + self.dilation, + self.groups, + )?; + + // d_weight via cross-correlation + let d_weight = conv2d_weight_backward::( + &client, + grad_output, + &self.saved_input, + self.saved_weight.shape(), + self.stride, + self.padding, + self.dilation, + self.groups, + )?; + + // d_bias = sum over batch, height, and width dims + let d_bias = if self.input_ids.len() > 2 { + // grad_output shape: [batch, c_out, out_h, out_w] + // sum over dim 0 (batch), dim 2 (height), dim 3 (width) → [c_out] + let summed = client.sum(grad_output, &[0, 2, 3], false)?; + Some(summed) + } else { + None + }; + + Ok(vec![Some(d_input), Some(d_weight), d_bias]) + } + + fn backward_var(&self, grad_output: &Var) -> Result>>> + where + R::Client: RuntimeClient + + ConvOps + + TensorOps + + ReduceOps + + BinaryOps + + ScalarOps, + { + // First-order only for conv — second-order conv is rarely needed + let grads = self.backward(grad_output.tensor())?; + Ok(grads + .into_iter() + .map(|g| g.map(|t| Var::new(t, true))) + .collect()) + } + + fn inputs(&self) -> &[crate::tensor::TensorId] { + &self.input_ids + } + + fn input_grad_fns(&self) -> Vec>>> { + let mut fns = vec![self.input_grad_fn.clone(), self.weight_grad_fn.clone()]; + if self.input_ids.len() > 2 { + fns.push(self.bias_grad_fn.clone()); + } + fns + } + + fn saved_tensors(&self) -> &[crate::tensor::Tensor] { + std::slice::from_ref(&self.saved_input) + } + + fn name(&self) -> &'static str { + "Conv2dBackward" + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::autograd::backward; + use crate::runtime::cpu::{CpuDevice, CpuRuntime}; + use crate::tensor::Tensor; + + #[test] + fn test_var_conv2d_forward() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + // Input: [batch=1, c_in=1, h=2, w=2], weight: [c_out=1, c_in=1, kH=1, kW=1] = 2.0 + let input = Var::new( + Tensor::::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[1, 1, 2, 2], &device), + false, + ); + let weight = Var::new( + Tensor::::from_slice(&[2.0f32], &[1, 1, 1, 1], &device), + false, + ); + + let output = var_conv2d( + &input, + &weight, + None, + (1, 1), + PaddingMode::Valid, + (1, 1), + 1, + &client, + ) + .unwrap(); + let data: Vec = output.tensor().to_vec(); + assert_eq!(data, vec![2.0, 4.0, 6.0, 8.0]); + } + + #[test] + fn test_var_conv2d_backward_input() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + // Input: [1, 1, 2, 2], weight: [1, 1, 1, 1] = 2.0 + let input = Var::new( + Tensor::::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[1, 1, 2, 2], &device), + true, + ); + let weight = Var::new( + Tensor::::from_slice(&[2.0f32], &[1, 1, 1, 1], &device), + true, + ); + + let output = var_conv2d( + &input, + &weight, + None, + (1, 1), + PaddingMode::Valid, + (1, 1), + 1, + &client, + ) + .unwrap(); + let loss = crate::autograd::var_sum(&output, &[], false, &client).unwrap(); + let grads = backward(&loss, &client).unwrap(); + + let d_input: Vec = grads.get(input.id()).unwrap().to_vec(); + // With 1x1 kernel of weight=2, d_input should be [2, 2, 2, 2] + assert_eq!(d_input, vec![2.0, 2.0, 2.0, 2.0]); + + let d_weight: Vec = grads.get(weight.id()).unwrap().to_vec(); + // d_weight = sum of input = 1+2+3+4 = 10 + assert!((d_weight[0] - 10.0).abs() < 1e-5); + } + + #[test] + fn test_var_conv2d_backward_with_bias() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + // Input: [1, 1, 2, 2], weight: [1, 1, 1, 1] = 1.0, bias: [1] = 10.0 + let input = Var::new( + Tensor::::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[1, 1, 2, 2], &device), + true, + ); + let weight = Var::new( + Tensor::::from_slice(&[1.0f32], &[1, 1, 1, 1], &device), + true, + ); + let bias = Var::new( + Tensor::::from_slice(&[10.0f32], &[1], &device), + true, + ); + + let output = var_conv2d( + &input, + &weight, + Some(&bias), + (1, 1), + PaddingMode::Valid, + (1, 1), + 1, + &client, + ) + .unwrap(); + let loss = crate::autograd::var_sum(&output, &[], false, &client).unwrap(); + let grads = backward(&loss, &client).unwrap(); + + let d_bias: Vec = grads.get(bias.id()).unwrap().to_vec(); + // d_bias = sum of grad_output (all ones) over batch, h, w = 2*2 = 4 + assert!((d_bias[0] - 4.0).abs() < 1e-5); + } + + #[test] + fn test_var_conv2d_kernel2x2() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + // Input: [1, 1, 3, 3], weight: [1, 1, 2, 2] all ones + // Output: [1, 1, 2, 2] + #[rustfmt::skip] + let input_data: Vec = vec![ + 1.0, 2.0, 3.0, + 4.0, 5.0, 6.0, + 7.0, 8.0, 9.0, + ]; + let input = Var::new( + Tensor::::from_slice(&input_data, &[1, 1, 3, 3], &device), + true, + ); + let weight = Var::new( + Tensor::::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[1, 1, 2, 2], &device), + true, + ); + + let output = var_conv2d( + &input, + &weight, + None, + (1, 1), + PaddingMode::Valid, + (1, 1), + 1, + &client, + ) + .unwrap(); + let data: Vec = output.tensor().to_vec(); + // [1+2+4+5, 2+3+5+6, 4+5+7+8, 5+6+8+9] = [12, 16, 24, 28] + assert_eq!(data, vec![12.0, 16.0, 24.0, 28.0]); + + let loss = crate::autograd::var_sum(&output, &[], false, &client).unwrap(); + let grads = backward(&loss, &client).unwrap(); + + let d_input: Vec = grads.get(input.id()).unwrap().to_vec(); + // Each input position contributes to 1-4 output positions (2x2 kernel, all 1s) + // pos(0,0): out(0,0) → 1 + // pos(0,1): out(0,0)+out(0,1) → 2 + // pos(0,2): out(0,1) → 1 + // pos(1,0): out(0,0)+out(1,0) → 2 + // pos(1,1): out(0,0)+out(0,1)+out(1,0)+out(1,1) → 4 + // pos(1,2): out(0,1)+out(1,1) → 2 + // pos(2,0): out(1,0) → 1 + // pos(2,1): out(1,0)+out(1,1) → 2 + // pos(2,2): out(1,1) → 1 + assert_eq!(d_input, vec![1.0, 2.0, 1.0, 2.0, 4.0, 2.0, 1.0, 2.0, 1.0]); + } +} diff --git a/src/autograd/var_ops/conv_common.rs b/src/autograd/var_ops/conv_common.rs new file mode 100644 index 00000000..f363108e --- /dev/null +++ b/src/autograd/var_ops/conv_common.rs @@ -0,0 +1,43 @@ +//! Shared utilities for conv autograd operations. + +use crate::ops::PaddingMode; + +/// Compute effective padding amounts for a single spatial dimension. +/// +/// Returns `(pad_before, pad_after)` for the given kernel size and dilation. +pub(super) fn compute_padding( + padding: PaddingMode, + kernel_size: usize, + dilation: usize, +) -> (usize, usize) { + match padding { + PaddingMode::Valid => (0, 0), + PaddingMode::Same => { + let effective_k = dilation * (kernel_size - 1) + 1; + let total = effective_k.saturating_sub(1); + (total / 2, total - total / 2) + } + PaddingMode::Custom(left, right, _, _) => (left, right), + } +} + +/// Compute effective padding amounts for 2D convolution. +/// +/// Returns `(pad_top, pad_bottom, pad_left, pad_right)`. +pub(super) fn compute_padding_2d( + padding: PaddingMode, + kernel_h: usize, + kernel_w: usize, + dilation_h: usize, + dilation_w: usize, +) -> (usize, usize, usize, usize) { + match padding { + PaddingMode::Valid => (0, 0, 0, 0), + PaddingMode::Same => { + let (top, bottom) = compute_padding(PaddingMode::Same, kernel_h, dilation_h); + let (left, right) = compute_padding(PaddingMode::Same, kernel_w, dilation_w); + (top, bottom, left, right) + } + PaddingMode::Custom(top, bottom, left, right) => (top, bottom, left, right), + } +} diff --git a/src/autograd/var_ops/dropout.rs b/src/autograd/var_ops/dropout.rs new file mode 100644 index 00000000..b1903a0d --- /dev/null +++ b/src/autograd/var_ops/dropout.rs @@ -0,0 +1,228 @@ +//! Dropout operation with gradient support +//! +//! Dropout randomly zeroes elements with probability `p` during training, +//! scaling remaining elements by `1/(1-p)` (inverted dropout). +//! During inference, it's a no-op (identity function). + +use crate::autograd::Var; +use crate::autograd::var_ops::var_mul; +use crate::dtype::DType; +use crate::error::Result; +use crate::ops::{BinaryOps, RandomOps, ScalarOps, TensorOps}; +use crate::runtime::{Runtime, RuntimeClient}; +use std::sync::Arc; + +/// Dropout with inverted scaling: zero elements with probability `p`, +/// scale survivors by `1/(1-p)`. +/// +/// Returns `(output, mask)` where mask is the binary mask scaled by `1/(1-p)`. +/// The mask is needed by the caller to store for potential reuse (e.g., in +/// the `Dropout` module) and is also saved internally for the backward pass. +/// +/// When `p == 0.0`, this is an identity operation (no dropout applied). +pub fn var_dropout( + a: &Var, + p: f64, + client: &C, +) -> Result<(Var, crate::tensor::Tensor)> +where + R: Runtime, + C: RuntimeClient + TensorOps + RandomOps + ScalarOps, + R::Client: TensorOps + ScalarOps, +{ + if p == 0.0 { + // No dropout — return input unchanged with a ones mask + let mask = crate::tensor::Tensor::::ones( + a.tensor().shape(), + a.tensor().dtype(), + a.tensor().device(), + ); + return Ok((Var::new(a.tensor().clone(), a.requires_grad()), mask)); + } + + if p >= 1.0 { + // Drop everything — return zeros + let zeros = crate::tensor::Tensor::::zeros( + a.tensor().shape(), + a.tensor().dtype(), + a.tensor().device(), + ); + return Ok((Var::new(zeros.clone(), a.requires_grad()), zeros)); + } + + // Generate bernoulli mask: 1 with probability (1-p), 0 with probability p + let keep_prob = 1.0 - p; + let mask = client.bernoulli(keep_prob, a.tensor().shape(), a.tensor().dtype())?; + + // Scale mask by 1/(1-p) for inverted dropout + let scale = 1.0 / keep_prob; + let scaled_mask = client.mul_scalar(&mask, scale)?; + + // output = input * scaled_mask + let output = client.mul(a.tensor(), &scaled_mask)?; + + if a.requires_grad() { + let grad_fn = DropoutBackward::::new(a.id(), scaled_mask.clone(), a.grad_fn().cloned()); + Ok((Var::from_op(output, Arc::new(grad_fn)), scaled_mask)) + } else { + Ok((Var::new(output, false), scaled_mask)) + } +} + +/// Backward for dropout. +/// +/// Gradient: `dL/da = dL/dz * scaled_mask` +/// +/// The same mask used in forward is applied to the gradient — zeroed positions +/// remain zeroed, surviving positions are scaled by `1/(1-p)`. +pub struct DropoutBackward { + input_id: crate::tensor::TensorId, + saved_mask: crate::tensor::Tensor, + input_grad_fn: Option>>, +} + +impl DropoutBackward { + pub fn new( + input_id: crate::tensor::TensorId, + mask: crate::tensor::Tensor, + input_grad_fn: Option>>, + ) -> Self { + Self { + input_id, + saved_mask: mask, + input_grad_fn, + } + } +} + +impl> crate::autograd::GradFn for DropoutBackward +where + R::Client: TensorOps + BinaryOps, +{ + fn backward( + &self, + grad_output: &crate::tensor::Tensor, + ) -> Result>>> { + let client = R::default_client(grad_output.device()); + let grad = client.mul(grad_output, &self.saved_mask)?; + Ok(vec![Some(grad)]) + } + + fn backward_var(&self, grad_output: &Var) -> Result>>> + where + R::Client: RuntimeClient + TensorOps, + { + let client = R::default_client(grad_output.tensor().device()); + let mask_var = Var::new(self.saved_mask.clone(), false); + let grad = var_mul(grad_output, &mask_var, &client)?; + Ok(vec![Some(grad)]) + } + + fn inputs(&self) -> &[crate::tensor::TensorId] { + std::slice::from_ref(&self.input_id) + } + + fn input_grad_fns(&self) -> Vec>>> { + vec![self.input_grad_fn.clone()] + } + + fn saved_tensors(&self) -> &[crate::tensor::Tensor] { + std::slice::from_ref(&self.saved_mask) + } + + fn name(&self) -> &'static str { + "DropoutBackward" + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::autograd::backward; + use crate::runtime::cpu::{CpuDevice, CpuRuntime}; + + #[test] + fn test_dropout_zero_rate() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + let input = Var::new( + crate::tensor::Tensor::::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device), + false, + ); + let (output, _mask) = var_dropout(&input, 0.0, &client).unwrap(); + + let data: Vec = output.tensor().to_vec(); + assert_eq!(data, vec![1.0, 2.0, 3.0]); + } + + #[test] + fn test_dropout_full_rate() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + let input = Var::new( + crate::tensor::Tensor::::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device), + false, + ); + // p=1.0 means drop everything + let (output, _mask) = var_dropout(&input, 1.0, &client).unwrap(); + + let data: Vec = output.tensor().to_vec(); + for val in data { + assert_eq!(val, 0.0); + } + } + + #[test] + fn test_dropout_scaling() { + // With p=0.5, surviving elements should be scaled by 2.0 + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + let input = Var::new( + crate::tensor::Tensor::::from_slice(&[1.0f32; 1000], &[1000], &device), + false, + ); + let (output, _mask) = var_dropout(&input, 0.5, &client).unwrap(); + + let data: Vec = output.tensor().to_vec(); + for val in &data { + // Each element is either 0.0 or 2.0 (1.0 * 1/(1-0.5)) + assert!(*val == 0.0 || (*val - 2.0).abs() < 1e-5, "got {val}"); + } + + // Statistically, roughly half should be non-zero + let nonzero = data.iter().filter(|&&v| v != 0.0).count(); + assert!(nonzero > 300 && nonzero < 700, "nonzero count: {nonzero}"); + } + + #[test] + fn test_dropout_backward_gradient() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + let input = Var::new( + crate::tensor::Tensor::::from_slice( + &[1.0f32, 2.0, 3.0, 4.0], + &[4], + &device, + ), + true, + ); + let (output, mask) = var_dropout(&input, 0.5, &client).unwrap(); + + // Sum to get scalar loss + let loss = crate::autograd::var_sum(&output, &[], false, &client).unwrap(); + let grads = backward(&loss, &client).unwrap(); + let grad = grads.get(input.id()).unwrap(); + + let grad_data: Vec = grad.to_vec(); + let mask_data: Vec = mask.to_vec(); + + // Gradient should equal the mask (since d(sum(x*mask))/dx = mask) + for (g, m) in grad_data.iter().zip(mask_data.iter()) { + assert!((g - m).abs() < 1e-5, "grad {g} != mask {m}"); + } + } +} diff --git a/src/autograd/var_ops/fused_activation_mul.rs b/src/autograd/var_ops/fused_activation_mul.rs new file mode 100644 index 00000000..f65815c3 --- /dev/null +++ b/src/autograd/var_ops/fused_activation_mul.rs @@ -0,0 +1,578 @@ +//! Fused activation-multiplication with gradient support +//! +//! Each function computes `activation(a) * b` in a single memory pass. +//! Backward computes: +//! - d_a = grad_output * b * activation'(a) +//! - d_b = grad_output * activation(a) + +use crate::autograd::Var; +use crate::autograd::var_ops::var_mul; +use crate::dtype::DType; +use crate::error::Result; +use crate::ops::{ + ActivationOps, BinaryOps, CompareOps, ConditionalOps, ScalarOps, TensorOps, UnaryOps, +}; +use crate::runtime::{Runtime, RuntimeClient}; +use std::sync::Arc; + +/// Which fused activation-mul variant +#[derive(Clone, Copy)] +enum FusedKind { + Silu, + Gelu, + Relu, + Sigmoid, +} + +/// Fused SiLU-Mul: output = silu(a) * b +pub fn var_silu_mul(a: &Var, b: &Var, client: &C) -> Result> +where + R: Runtime, + C: RuntimeClient + + TensorOps + + ActivationOps + + ScalarOps + + BinaryOps + + CompareOps + + ConditionalOps + + UnaryOps, + R::Client: TensorOps + + ActivationOps + + ScalarOps + + BinaryOps + + CompareOps + + ConditionalOps + + UnaryOps, +{ + var_fused_activation_mul(a, b, client, FusedKind::Silu) +} + +/// Fused GELU-Mul: output = gelu(a) * b +pub fn var_gelu_mul(a: &Var, b: &Var, client: &C) -> Result> +where + R: Runtime, + C: RuntimeClient + + TensorOps + + ActivationOps + + ScalarOps + + BinaryOps + + CompareOps + + ConditionalOps + + UnaryOps, + R::Client: TensorOps + + ActivationOps + + ScalarOps + + BinaryOps + + CompareOps + + ConditionalOps + + UnaryOps, +{ + var_fused_activation_mul(a, b, client, FusedKind::Gelu) +} + +/// Fused ReLU-Mul: output = relu(a) * b +pub fn var_relu_mul(a: &Var, b: &Var, client: &C) -> Result> +where + R: Runtime, + C: RuntimeClient + + TensorOps + + ActivationOps + + ScalarOps + + BinaryOps + + CompareOps + + ConditionalOps + + UnaryOps, + R::Client: TensorOps + + ActivationOps + + ScalarOps + + BinaryOps + + CompareOps + + ConditionalOps + + UnaryOps, +{ + var_fused_activation_mul(a, b, client, FusedKind::Relu) +} + +/// Fused Sigmoid-Mul: output = sigmoid(a) * b +pub fn var_sigmoid_mul(a: &Var, b: &Var, client: &C) -> Result> +where + R: Runtime, + C: RuntimeClient + + TensorOps + + ActivationOps + + ScalarOps + + BinaryOps + + CompareOps + + ConditionalOps + + UnaryOps, + R::Client: TensorOps + + ActivationOps + + ScalarOps + + BinaryOps + + CompareOps + + ConditionalOps + + UnaryOps, +{ + var_fused_activation_mul(a, b, client, FusedKind::Sigmoid) +} + +/// Shared implementation for all fused activation-mul variants +fn var_fused_activation_mul( + a: &Var, + b: &Var, + client: &C, + kind: FusedKind, +) -> Result> +where + R: Runtime, + C: RuntimeClient + + TensorOps + + ActivationOps + + ScalarOps + + BinaryOps + + CompareOps + + ConditionalOps + + UnaryOps, + R::Client: TensorOps + + ActivationOps + + ScalarOps + + BinaryOps + + CompareOps + + ConditionalOps + + UnaryOps, +{ + // Forward: use fused kernel + let output = match kind { + FusedKind::Silu => client.silu_mul(a.tensor(), b.tensor())?, + FusedKind::Gelu => client.gelu_mul(a.tensor(), b.tensor())?, + FusedKind::Relu => client.relu_mul(a.tensor(), b.tensor())?, + FusedKind::Sigmoid => client.sigmoid_mul(a.tensor(), b.tensor())?, + }; + + if a.requires_grad() || b.requires_grad() { + // Compute activation(a) for backward (needed for d_b) + let activation_a = match kind { + FusedKind::Silu => client.silu(a.tensor())?, + FusedKind::Gelu => client.gelu(a.tensor())?, + FusedKind::Relu => client.relu(a.tensor())?, + FusedKind::Sigmoid => client.sigmoid(a.tensor())?, + }; + + let grad_fn = FusedActivationMulBackward::::new( + a.id(), + b.id(), + a.tensor().clone(), + b.tensor().clone(), + activation_a, + kind, + a.grad_fn().cloned(), + b.grad_fn().cloned(), + ); + Ok(Var::from_op(output, Arc::new(grad_fn))) + } else { + Ok(Var::new(output, false)) + } +} + +/// Backward for fused activation-mul: output = activation(a) * b +/// +/// Gradients: +/// - d_b = grad_output * activation(a) +/// - d_a = grad_output * b * activation'(a) +/// +/// Derivatives: +/// - silu'(x) = sigmoid(x) * (1 + x - silu(x)) +/// - gelu'(x) = 0.5*(1+tanh(inner)) + 0.5*x*sech²(inner)*sqrt(2/π)*(1+3*0.044715*x²) +/// - relu'(x) = 1 if x > 0, else 0 +/// - sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)) +pub struct FusedActivationMulBackward { + input_ids: [crate::tensor::TensorId; 2], + saved_a: crate::tensor::Tensor, + saved_b: crate::tensor::Tensor, + saved_activation_a: crate::tensor::Tensor, + kind: FusedKind, + a_grad_fn: Option>>, + b_grad_fn: Option>>, +} + +impl FusedActivationMulBackward { + #[allow(clippy::too_many_arguments)] + fn new( + a_id: crate::tensor::TensorId, + b_id: crate::tensor::TensorId, + a: crate::tensor::Tensor, + b: crate::tensor::Tensor, + activation_a: crate::tensor::Tensor, + kind: FusedKind, + a_grad_fn: Option>>, + b_grad_fn: Option>>, + ) -> Self { + Self { + input_ids: [a_id, b_id], + saved_a: a, + saved_b: b, + saved_activation_a: activation_a, + kind, + a_grad_fn, + b_grad_fn, + } + } +} + +impl> crate::autograd::GradFn for FusedActivationMulBackward +where + R::Client: TensorOps + + ActivationOps + + ScalarOps + + BinaryOps + + CompareOps + + ConditionalOps + + UnaryOps, +{ + fn backward( + &self, + grad_output: &crate::tensor::Tensor, + ) -> Result>>> { + let client = R::default_client(grad_output.device()); + + // Delegate to fused backward trait method — allows backends (e.g. CUDA) + // to provide a single fused kernel for the entire backward pass. + let (d_a, d_b) = match self.kind { + FusedKind::Silu => client.silu_mul_bwd(grad_output, &self.saved_a, &self.saved_b)?, + FusedKind::Gelu => client.gelu_mul_bwd(grad_output, &self.saved_a, &self.saved_b)?, + FusedKind::Relu => client.relu_mul_bwd(grad_output, &self.saved_a, &self.saved_b)?, + FusedKind::Sigmoid => { + client.sigmoid_mul_bwd(grad_output, &self.saved_a, &self.saved_b)? + } + }; + + Ok(vec![Some(d_a), Some(d_b)]) + } + + fn backward_var(&self, grad_output: &Var) -> Result>>> + where + R::Client: RuntimeClient + + TensorOps + + ActivationOps + + ScalarOps + + BinaryOps + + CompareOps + + ConditionalOps + + UnaryOps, + { + let client = R::default_client(grad_output.tensor().device()); + + // d_b = grad_output * activation(a) (activation_a is constant w.r.t. higher-order) + let act_var = Var::new(self.saved_activation_a.clone(), false); + let d_b = var_mul(grad_output, &act_var, &client)?; + + // d_a = grad_output * b * activation'(a) + let activation_deriv = compute_activation_derivative( + &client, + &self.saved_a, + &self.saved_activation_a, + self.kind, + )?; + let deriv_var = Var::new(activation_deriv, false); + let b_var = Var::new(self.saved_b.clone(), false); + let grad_times_b = var_mul(grad_output, &b_var, &client)?; + let d_a = var_mul(&grad_times_b, &deriv_var, &client)?; + + Ok(vec![Some(d_a), Some(d_b)]) + } + + fn inputs(&self) -> &[crate::tensor::TensorId] { + &self.input_ids + } + + fn input_grad_fns(&self) -> Vec>>> { + vec![self.a_grad_fn.clone(), self.b_grad_fn.clone()] + } + + fn saved_tensors(&self) -> &[crate::tensor::Tensor] { + std::slice::from_ref(&self.saved_a) + } + + fn name(&self) -> &'static str { + match self.kind { + FusedKind::Silu => "SiluMulBackward", + FusedKind::Gelu => "GeluMulBackward", + FusedKind::Relu => "ReluMulBackward", + FusedKind::Sigmoid => "SigmoidMulBackward", + } + } +} + +/// Compute activation'(x) for the backward pass +fn compute_activation_derivative( + client: &C, + a: &crate::tensor::Tensor, + activation_a: &crate::tensor::Tensor, + kind: FusedKind, +) -> Result> +where + R: Runtime, + C: TensorOps + + ActivationOps + + ScalarOps + + BinaryOps + + CompareOps + + ConditionalOps + + UnaryOps, +{ + match kind { + FusedKind::Silu => { + // silu'(x) = sigmoid(x) * (1 + x - silu(x)) + let sigmoid_a = client.sigmoid(a)?; + let one_plus_a = client.add_scalar(a, 1.0)?; + let one_plus_a_minus_silu = client.sub(&one_plus_a, activation_a)?; + client.mul(&sigmoid_a, &one_plus_a_minus_silu) + } + FusedKind::Gelu => { + // gelu'(x) = 0.5*(1+tanh(inner)) + 0.5*x*sech²(inner)*sqrt(2/π)*(1+3*0.044715*x²) + // where inner = sqrt(2/π) * (x + 0.044715*x³) + // + // Simpler: d/dx gelu(x) = gelu(x)/x + x * pdf(x) + // But that has x=0 issues. Use the direct form: + // + // Let's use: gelu(x) = 0.5*x*(1+tanh(inner)) + // gelu'(x) = 0.5*(1+tanh(inner)) + 0.5*x*(1-tanh²(inner))*inner' + // inner' = sqrt(2/π)*(1 + 3*0.044715*x²) + let x_sq = client.mul(a, a)?; + let x_cu = client.mul(&x_sq, a)?; + let coef_x_cu = client.mul_scalar(&x_cu, 0.044715)?; + let inner_arg = client.add(a, &coef_x_cu)?; + let sqrt_2_pi = 0.7978845608028654; + let inner = client.mul_scalar(&inner_arg, sqrt_2_pi)?; + + // tanh(inner) + let tanh_inner = { + // Use exp to compute tanh: tanh(x) = (exp(2x)-1)/(exp(2x)+1) + let two_inner = client.mul_scalar(&inner, 2.0)?; + let exp_2 = client.exp(&two_inner)?; + let num = client.add_scalar(&exp_2, -1.0)?; + let den = client.add_scalar(&exp_2, 1.0)?; + client.div(&num, &den)? + }; + + // 0.5*(1+tanh(inner)) + let one_plus_tanh = client.add_scalar(&tanh_inner, 1.0)?; + let term1 = client.mul_scalar(&one_plus_tanh, 0.5)?; + + // sech²(inner) = 1 - tanh²(inner) + let tanh_sq = client.mul(&tanh_inner, &tanh_inner)?; + let sech_sq = client.add_scalar(&tanh_sq, -1.0)?; + let sech_sq = client.neg(&sech_sq)?; + + // inner' = sqrt(2/π) * (1 + 3*0.044715*x²) + let three_coef_x_sq = client.mul_scalar(&x_sq, 3.0 * 0.044715)?; + let inner_deriv_unscaled = client.add_scalar(&three_coef_x_sq, 1.0)?; + let inner_deriv = client.mul_scalar(&inner_deriv_unscaled, sqrt_2_pi)?; + + // term2 = 0.5 * x * sech²(inner) * inner' + let x_sech_sq = client.mul(a, &sech_sq)?; + let x_sech_sq_inner_d = client.mul(&x_sech_sq, &inner_deriv)?; + let term2 = client.mul_scalar(&x_sech_sq_inner_d, 0.5)?; + + client.add(&term1, &term2) + } + FusedKind::Relu => { + // relu'(x) = 1 if x > 0, else 0 + let zeros = crate::tensor::Tensor::::zeros(a.shape(), a.dtype(), a.device()); + let ones = crate::tensor::Tensor::::ones(a.shape(), a.dtype(), a.device()); + let mask = client.gt(a, &zeros)?; + client.where_cond(&mask, &ones, &zeros) + } + FusedKind::Sigmoid => { + // sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)) + let sigmoid_a = client.sigmoid(a)?; + let one_minus_sig = client.add_scalar(&sigmoid_a, -1.0)?; + let one_minus_sig = client.neg(&one_minus_sig)?; + client.mul(&sigmoid_a, &one_minus_sig) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::autograd::backward; + use crate::runtime::cpu::{CpuDevice, CpuRuntime}; + use crate::tensor::Tensor; + + #[test] + fn test_silu_mul_forward() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + let a = Var::new( + Tensor::::from_slice(&[0.0f32, 1.0, -1.0], &[3], &device), + false, + ); + let b = Var::new( + Tensor::::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device), + false, + ); + + let output = var_silu_mul(&a, &b, &client).unwrap(); + let data: Vec = output.tensor().to_vec(); + + // silu(0)*1 = 0, silu(1)*2, silu(-1)*3 + assert!(data[0].abs() < 1e-6); + let silu_1 = 1.0 / (1.0 + (-1.0f32).exp()); + assert!((data[1] - silu_1 * 2.0).abs() < 1e-4); + let silu_neg1 = -1.0 / (1.0 + 1.0f32.exp()); + assert!((data[2] - silu_neg1 * 3.0).abs() < 1e-4); + } + + #[test] + fn test_silu_mul_matches_separate_ops() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + let a_data = vec![0.5f32, -0.3, 1.2, -2.0, 0.0]; + let b_data = vec![1.0f32, 2.0, 0.5, -1.0, 3.0]; + + // Fused + let fused = client + .silu_mul( + &Tensor::::from_slice(&a_data, &[5], &device), + &Tensor::::from_slice(&b_data, &[5], &device), + ) + .unwrap(); + + // Separate + let silu_a = client + .silu(&Tensor::::from_slice(&a_data, &[5], &device)) + .unwrap(); + let separate = client + .mul( + &silu_a, + &Tensor::::from_slice(&b_data, &[5], &device), + ) + .unwrap(); + + let fused_v: Vec = fused.to_vec(); + let separate_v: Vec = separate.to_vec(); + for i in 0..5 { + assert!( + (fused_v[i] - separate_v[i]).abs() < 1e-5, + "mismatch at {i}: {} vs {}", + fused_v[i], + separate_v[i] + ); + } + } + + #[test] + fn test_silu_mul_backward() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + let a = Var::new( + Tensor::::from_slice(&[1.0f32, -1.0], &[2], &device), + true, + ); + let b = Var::new( + Tensor::::from_slice(&[2.0f32, 3.0], &[2], &device), + true, + ); + + let output = var_silu_mul(&a, &b, &client).unwrap(); + let loss = crate::autograd::var_sum(&output, &[], false, &client).unwrap(); + let grads = backward(&loss, &client).unwrap(); + + let d_a: Vec = grads.get(a.id()).unwrap().to_vec(); + let d_b: Vec = grads.get(b.id()).unwrap().to_vec(); + + // Verify d_b = silu(a) + for (i, &g) in [1.0f32, -1.0].iter().enumerate() { + let expected = g / (1.0 + (-g).exp()); + assert!( + (d_b[i] - expected).abs() < 1e-4, + "d_b[{i}]: got {}, expected {expected}", + d_b[i] + ); + } + + // Verify d_a = b * silu'(a) + for (i, (&g, &u)) in [1.0f32, -1.0].iter().zip([2.0f32, 3.0].iter()).enumerate() { + let sig = 1.0 / (1.0 + (-g).exp()); + let silu_g = g * sig; + let silu_deriv = sig * (1.0 + g - silu_g); + let expected = u * silu_deriv; + assert!( + (d_a[i] - expected).abs() < 1e-4, + "d_a[{i}]: got {}, expected {expected}", + d_a[i] + ); + } + } + + #[test] + fn test_relu_mul_forward() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + let a = Var::new( + Tensor::::from_slice(&[-1.0f32, 0.0, 2.0], &[3], &device), + false, + ); + let b = Var::new( + Tensor::::from_slice(&[5.0f32, 5.0, 5.0], &[3], &device), + false, + ); + + let output = var_relu_mul(&a, &b, &client).unwrap(); + let data: Vec = output.tensor().to_vec(); + assert!((data[0] - 0.0).abs() < 1e-6); + assert!((data[1] - 0.0).abs() < 1e-6); + assert!((data[2] - 10.0).abs() < 1e-6); + } + + #[test] + fn test_sigmoid_mul_backward() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + let a = Var::new( + Tensor::::from_slice(&[0.0f32], &[1], &device), + true, + ); + let b = Var::new( + Tensor::::from_slice(&[2.0f32], &[1], &device), + true, + ); + + let output = var_sigmoid_mul(&a, &b, &client).unwrap(); + let loss = crate::autograd::var_sum(&output, &[], false, &client).unwrap(); + let grads = backward(&loss, &client).unwrap(); + + let d_a: Vec = grads.get(a.id()).unwrap().to_vec(); + let d_b: Vec = grads.get(b.id()).unwrap().to_vec(); + + // d_b = sigmoid(0) = 0.5 + assert!((d_b[0] - 0.5).abs() < 1e-4); + + // d_a = b * sigmoid'(0) = 2 * sigmoid(0)*(1-sigmoid(0)) = 2 * 0.25 = 0.5 + assert!((d_a[0] - 0.5).abs() < 1e-4); + } + + #[test] + fn test_no_grad() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + let a = Var::new( + Tensor::::from_slice(&[1.0f32], &[1], &device), + false, + ); + let b = Var::new( + Tensor::::from_slice(&[2.0f32], &[1], &device), + false, + ); + + let output = var_gelu_mul(&a, &b, &client).unwrap(); + assert!(!output.requires_grad()); + } +} diff --git a/src/autograd/var_ops/gemm_epilogue.rs b/src/autograd/var_ops/gemm_epilogue.rs new file mode 100644 index 00000000..8a2ddb15 --- /dev/null +++ b/src/autograd/var_ops/gemm_epilogue.rs @@ -0,0 +1,152 @@ +//! Fused GEMM + bias + activation var operations + +use super::ops::*; +use crate::autograd::Var; +use crate::error::Result; +use crate::ops::{GemmActivation, GemmEpilogueOps, ScalarOps, TensorOps}; +use crate::runtime::{Runtime, RuntimeClient}; +use std::sync::Arc; + +/// Fused GEMM + bias + activation: output = activation(A @ B + bias) +/// +/// # Arguments +/// +/// * `a` - Input variable of shape `[..., M, K]` +/// * `b` - Weight variable of shape `[..., K, N]` +/// * `bias` - Bias variable of shape `[N]` +/// * `activation` - Activation function to apply +/// * `client` - Runtime client +pub fn var_matmul_bias_activation( + a: &Var, + b: &Var, + bias: &Var, + activation: GemmActivation, + client: &C, +) -> Result> +where + R: Runtime, + C: RuntimeClient + GemmEpilogueOps, + R::Client: TensorOps + ScalarOps, +{ + let output = + client.matmul_bias_activation(a.tensor(), b.tensor(), bias.tensor(), activation)?; + + if a.requires_grad() || b.requires_grad() || bias.requires_grad() { + let grad_fn = MatmulBiasActivationBackward::::new( + a.id(), + b.id(), + bias.id(), + a.tensor().clone(), + b.tensor().clone(), + bias.tensor().clone(), + activation, + a.grad_fn().cloned(), + b.grad_fn().cloned(), + bias.grad_fn().cloned(), + ); + Ok(Var::from_op(output, Arc::new(grad_fn))) + } else { + Ok(Var::new(output, false)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::autograd::backward; + use crate::runtime::cpu::{CpuDevice, CpuRuntime}; + use crate::tensor::Tensor; + + #[test] + fn test_var_matmul_bias_activation_forward_none() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + let a = Var::new( + Tensor::::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2], &device), + true, + ); + let b = Var::new( + Tensor::::from_slice(&[1.0f32, 0.0, 0.0, 1.0], &[2, 2], &device), + true, + ); + let bias = Var::new( + Tensor::::from_slice(&[0.1f32, 0.2], &[2], &device), + true, + ); + + let result = + var_matmul_bias_activation(&a, &b, &bias, GemmActivation::None, &client).unwrap(); + let data: Vec = result.tensor().to_vec(); + + // A @ B = [[1, 2], [3, 4]] @ [[1, 0], [0, 1]] = [[1, 2], [3, 4]] + // + bias = [[1.1, 2.2], [3.1, 4.2]] + assert!((data[0] - 1.1).abs() < 1e-5); + assert!((data[1] - 2.2).abs() < 1e-5); + assert!((data[2] - 3.1).abs() < 1e-5); + assert!((data[3] - 4.2).abs() < 1e-5); + } + + #[test] + fn test_var_matmul_bias_activation_backward() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + let a = Var::new( + Tensor::::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2], &device), + true, + ); + let b = Var::new( + Tensor::::from_slice(&[1.0f32, 0.0, 0.0, 1.0], &[2, 2], &device), + true, + ); + let bias = Var::new( + Tensor::::from_slice(&[0.0f32, 0.0], &[2], &device), + true, + ); + + let output = + var_matmul_bias_activation(&a, &b, &bias, GemmActivation::None, &client).unwrap(); + let loss = crate::autograd::var_sum(&output, &[0, 1], false, &client).unwrap(); + let grads = backward(&loss, &client).unwrap(); + + let ga: Vec = grads.get(a.id()).unwrap().to_vec(); + let gb: Vec = grads.get(b.id()).unwrap().to_vec(); + let gbias: Vec = grads.get(bias.id()).unwrap().to_vec(); + + assert_eq!(ga.len(), 4); + assert_eq!(gb.len(), 4); + assert_eq!(gbias.len(), 2); + + for val in ga.iter().chain(gb.iter()).chain(gbias.iter()) { + assert!(val.is_finite(), "gradient should be finite"); + } + + // d_bias should be sum over rows = [2.0, 2.0] (2 rows, each contributing 1.0) + assert!((gbias[0] - 2.0).abs() < 1e-5); + assert!((gbias[1] - 2.0).abs() < 1e-5); + } + + #[test] + fn test_var_matmul_bias_activation_no_grad() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + let a = Var::new( + Tensor::::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2], &device), + false, + ); + let b = Var::new( + Tensor::::from_slice(&[1.0f32, 0.0, 0.0, 1.0], &[2, 2], &device), + false, + ); + let bias = Var::new( + Tensor::::from_slice(&[0.0f32, 0.0], &[2], &device), + false, + ); + + let result = + var_matmul_bias_activation(&a, &b, &bias, GemmActivation::None, &client).unwrap(); + assert!(!result.requires_grad()); + } +} diff --git a/src/autograd/var_ops/indexing.rs b/src/autograd/var_ops/indexing.rs index 7e36fa78..9d364fb9 100644 --- a/src/autograd/var_ops/indexing.rs +++ b/src/autograd/var_ops/indexing.rs @@ -2,6 +2,7 @@ use super::ops::*; use crate::autograd::Var; +use crate::dtype::DType; use crate::error::Result; use crate::ops::IndexingOps; use crate::runtime::{Runtime, RuntimeClient}; @@ -15,7 +16,7 @@ pub fn var_gather( client: &C, ) -> Result> where - R: Runtime, + R: Runtime, C: RuntimeClient + IndexingOps, R::Client: IndexingOps, { diff --git a/src/autograd/var_ops/linalg.rs b/src/autograd/var_ops/linalg.rs index 3d8ae64d..b67a0968 100644 --- a/src/autograd/var_ops/linalg.rs +++ b/src/autograd/var_ops/linalg.rs @@ -3,6 +3,7 @@ use super::ops::*; use crate::algorithm::LinearAlgebraAlgorithms; use crate::autograd::Var; +use crate::dtype::DType; use crate::error::Result; use crate::ops::{ScalarOps, TensorOps}; use crate::runtime::{Runtime, RuntimeClient}; @@ -13,7 +14,7 @@ use std::sync::Arc; /// Creates TraceBackward for gradient computation. pub fn var_trace(a: &Var, client: &C) -> Result> where - R: Runtime, + R: Runtime, C: RuntimeClient + LinearAlgebraAlgorithms, R::Client: TensorOps + ScalarOps + LinearAlgebraAlgorithms, { @@ -102,7 +103,7 @@ where /// Creates CholeskyBackward for gradient computation. pub fn var_cholesky(a: &Var, client: &C) -> Result> where - R: Runtime, + R: Runtime, C: RuntimeClient + LinearAlgebraAlgorithms, R::Client: TensorOps + ScalarOps + LinearAlgebraAlgorithms, { diff --git a/src/autograd/var_ops/macros.rs b/src/autograd/var_ops/macros.rs index d4dfc6af..07123702 100644 --- a/src/autograd/var_ops/macros.rs +++ b/src/autograd/var_ops/macros.rs @@ -180,7 +180,7 @@ macro_rules! impl_var_unary_op_output_scalar { $(#[$meta])* pub fn $fn_name(a: &Var, client: &C) -> Result> where - R: Runtime, + R: Runtime, C: RuntimeClient + TensorOps, R::Client: TensorOps + ScalarOps, { diff --git a/src/autograd/var_ops/mod.rs b/src/autograd/var_ops/mod.rs index 8fdb0265..47adc882 100644 --- a/src/autograd/var_ops/mod.rs +++ b/src/autograd/var_ops/mod.rs @@ -27,26 +27,46 @@ pub mod ops; mod activation; mod arithmetic; +mod cast; +mod conv1d; +mod conv2d; +mod conv_common; mod cumulative; +mod dropout; +mod fused_activation_mul; +mod gemm_epilogue; mod indexing; + pub mod linalg; mod matmul; +mod normalization; pub mod reduce; mod scalar; mod stats; +mod swiglu; mod unary; mod utility; // Re-export all public functions -pub use activation::{var_relu, var_sigmoid, var_softmax}; +pub use activation::{var_log_softmax, var_relu, var_sigmoid, var_silu, var_softmax, var_softplus}; pub use arithmetic::{var_add, var_div, var_mul, var_pow, var_sub}; +pub use cast::var_cast; +pub use conv1d::var_conv1d; +pub use conv2d::var_conv2d; pub use cumulative::{var_cumprod, var_cumsum}; +pub use dropout::var_dropout; +pub use fused_activation_mul::{var_gelu_mul, var_relu_mul, var_sigmoid_mul, var_silu_mul}; +pub use gemm_epilogue::var_matmul_bias_activation; pub use indexing::var_gather; pub use linalg::{var_cholesky, var_det, var_inverse, var_solve, var_trace}; pub use matmul::var_matmul; +pub use normalization::{ + var_fused_add_layer_norm, var_fused_add_rms_norm, var_group_norm, var_layer_norm, var_rms_norm, +}; pub use reduce::{var_max, var_mean, var_min, var_sum}; pub use scalar::{var_add_scalar, var_div_scalar, var_mul_scalar, var_pow_scalar, var_sub_scalar}; pub use stats::{var_std, var_var}; +pub use swiglu::var_swiglu; pub use unary::{ var_abs, var_cos, var_exp, var_log, var_neg, var_recip, var_sin, var_sqrt, var_square, var_tan, var_tanh, diff --git a/src/autograd/var_ops/normalization.rs b/src/autograd/var_ops/normalization.rs new file mode 100644 index 00000000..e0e88362 --- /dev/null +++ b/src/autograd/var_ops/normalization.rs @@ -0,0 +1,696 @@ +//! Normalization operations (rms_norm, layer_norm) + +use super::ops::*; +use crate::autograd::Var; +use crate::error::Result; +use crate::ops::{NormalizationOps, ScalarOps, TensorOps}; +use crate::runtime::{Runtime, RuntimeClient}; +use std::sync::Arc; + +/// RMS Normalization: y = x / rms(x) * weight +/// +/// Uses the fused `NormalizationOps::rms_norm` kernel for the forward pass +/// and tracks gradients for both input and weight. +/// +/// # Arguments +/// +/// * `input` - Input variable of shape `[..., hidden_size]` +/// * `weight` - Weight variable of shape `[hidden_size]` +/// * `eps` - Small constant for numerical stability +/// * `client` - Runtime client +pub fn var_rms_norm(input: &Var, weight: &Var, eps: f32, client: &C) -> Result> +where + R: Runtime, + C: RuntimeClient + NormalizationOps, + R::Client: TensorOps + ScalarOps, +{ + let output = client.rms_norm(input.tensor(), weight.tensor(), eps)?; + + if input.requires_grad() || weight.requires_grad() { + let grad_fn = RmsNormBackward::::new( + input.id(), + weight.id(), + input.tensor().clone(), + weight.tensor().clone(), + eps, + input.grad_fn().cloned(), + weight.grad_fn().cloned(), + ); + Ok(Var::from_op(output, Arc::new(grad_fn))) + } else { + Ok(Var::new(output, false)) + } +} + +/// Layer Normalization: y = (x - mean(x)) / sqrt(var(x) + eps) * weight + bias +/// +/// Uses the fused `NormalizationOps::layer_norm` kernel for the forward pass +/// and tracks gradients for input, weight, and bias. +/// +/// # Arguments +/// +/// * `input` - Input variable of shape `[..., hidden_size]` +/// * `weight` - Weight (gamma) variable of shape `[hidden_size]` +/// * `bias` - Bias (beta) variable of shape `[hidden_size]` +/// * `eps` - Small constant for numerical stability +/// * `client` - Runtime client +pub fn var_layer_norm( + input: &Var, + weight: &Var, + bias: &Var, + eps: f32, + client: &C, +) -> Result> +where + R: Runtime, + C: RuntimeClient + NormalizationOps, + R::Client: TensorOps + ScalarOps, +{ + let output = client.layer_norm(input.tensor(), weight.tensor(), bias.tensor(), eps)?; + + if input.requires_grad() || weight.requires_grad() || bias.requires_grad() { + let grad_fn = LayerNormBackward::::new( + input.id(), + weight.id(), + bias.id(), + input.tensor().clone(), + weight.tensor().clone(), + eps, + input.grad_fn().cloned(), + weight.grad_fn().cloned(), + bias.grad_fn().cloned(), + ); + Ok(Var::from_op(output, Arc::new(grad_fn))) + } else { + Ok(Var::new(output, false)) + } +} + +/// Group Normalization with autograd support. +/// +/// Input: `[batch, channels, *spatial]` +/// Normalizes over groups of channels independently. +/// +/// # Arguments +/// * `input` - Input variable `[batch, channels, *spatial]` +/// * `weight` - Gamma variable `[channels]` +/// * `bias` - Beta variable `[channels]` +/// * `num_groups` - Number of groups (must divide channels) +/// * `eps` - Numerical stability constant +/// * `client` - Runtime client +pub fn var_group_norm( + input: &Var, + weight: &Var, + bias: &Var, + num_groups: usize, + eps: f32, + client: &C, +) -> Result> +where + R: Runtime, + C: RuntimeClient + NormalizationOps, + R::Client: TensorOps + ScalarOps, +{ + let output = client.group_norm( + input.tensor(), + weight.tensor(), + bias.tensor(), + num_groups, + eps, + )?; + + if input.requires_grad() || weight.requires_grad() || bias.requires_grad() { + let grad_fn = GroupNormBackward::::new( + input.id(), + weight.id(), + bias.id(), + input.tensor().clone(), + weight.tensor().clone(), + num_groups, + eps, + input.grad_fn().cloned(), + weight.grad_fn().cloned(), + bias.grad_fn().cloned(), + ); + Ok(Var::from_op(output, Arc::new(grad_fn))) + } else { + Ok(Var::new(output, false)) + } +} + +/// Fused Add + RMS Normalization: pre_norm = x + residual, output = rms_norm(pre_norm, weight, eps) +/// +/// Returns a single output variable. Both `x` and `residual` receive the same gradient. +/// +/// # Arguments +/// +/// * `x` - Input variable of shape `[..., hidden_size]` +/// * `residual` - Residual variable of same shape as `x` +/// * `weight` - Weight variable of shape `[hidden_size]` +/// * `eps` - Small constant for numerical stability +/// * `client` - Runtime client +pub fn var_fused_add_rms_norm( + x: &Var, + residual: &Var, + weight: &Var, + eps: f32, + client: &C, +) -> Result> +where + R: Runtime, + C: RuntimeClient + NormalizationOps, + R::Client: TensorOps + ScalarOps, +{ + let (output, pre_norm) = + client.fused_add_rms_norm(x.tensor(), residual.tensor(), weight.tensor(), eps)?; + + if x.requires_grad() || residual.requires_grad() || weight.requires_grad() { + let grad_fn = FusedAddRmsNormBackward::::new( + x.id(), + residual.id(), + weight.id(), + pre_norm, + weight.tensor().clone(), + eps, + x.grad_fn().cloned(), + residual.grad_fn().cloned(), + weight.grad_fn().cloned(), + ); + Ok(Var::from_op(output, Arc::new(grad_fn))) + } else { + Ok(Var::new(output, false)) + } +} + +/// Fused Add + Layer Normalization: pre_norm = x + residual, output = layer_norm(pre_norm, weight, bias, eps) +/// +/// Returns a single output variable. Both `x` and `residual` receive the same gradient. +/// +/// # Arguments +/// +/// * `x` - Input variable of shape `[..., hidden_size]` +/// * `residual` - Residual variable of same shape as `x` +/// * `weight` - Weight (gamma) variable of shape `[hidden_size]` +/// * `bias` - Bias (beta) variable of shape `[hidden_size]` +/// * `eps` - Small constant for numerical stability +/// * `client` - Runtime client +pub fn var_fused_add_layer_norm( + x: &Var, + residual: &Var, + weight: &Var, + bias: &Var, + eps: f32, + client: &C, +) -> Result> +where + R: Runtime, + C: RuntimeClient + NormalizationOps, + R::Client: TensorOps + ScalarOps, +{ + let (output, pre_norm) = client.fused_add_layer_norm( + x.tensor(), + residual.tensor(), + weight.tensor(), + bias.tensor(), + eps, + )?; + + if x.requires_grad() + || residual.requires_grad() + || weight.requires_grad() + || bias.requires_grad() + { + let grad_fn = FusedAddLayerNormBackward::::new( + x.id(), + residual.id(), + weight.id(), + bias.id(), + pre_norm, + weight.tensor().clone(), + bias.tensor().clone(), + eps, + x.grad_fn().cloned(), + residual.grad_fn().cloned(), + weight.grad_fn().cloned(), + bias.grad_fn().cloned(), + ); + Ok(Var::from_op(output, Arc::new(grad_fn))) + } else { + Ok(Var::new(output, false)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::autograd::backward; + use crate::runtime::cpu::{CpuDevice, CpuRuntime}; + use crate::tensor::Tensor; + + #[test] + fn test_var_rms_norm_forward() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + let input = Var::new( + Tensor::::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[1, 4], &device), + true, + ); + let weight = Var::new( + Tensor::::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[4], &device), + true, + ); + + let result = var_rms_norm(&input, &weight, 1e-5, &client).unwrap(); + let data: Vec = result.tensor().to_vec(); + + // rms = sqrt(mean([1, 4, 9, 16]) + 1e-5) = sqrt(7.5 + 1e-5) ~ 2.7386 + // output = [1/rms, 2/rms, 3/rms, 4/rms] * [1,1,1,1] + let rms = (7.5f32 + 1e-5).sqrt(); + for i in 0..4 { + let expected = (i as f32 + 1.0) / rms; + assert!( + (data[i] - expected).abs() < 1e-5, + "data[{}] = {}, expected {}", + i, + data[i], + expected, + ); + } + } + + #[test] + fn test_var_rms_norm_backward() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + let input = Var::new( + Tensor::::from_slice(&[1.0f32, 2.0, 3.0], &[1, 3], &device), + true, + ); + let weight = Var::new( + Tensor::::from_slice(&[1.0f32, 1.0, 1.0], &[3], &device), + true, + ); + + let output = var_rms_norm(&input, &weight, 1e-5, &client).unwrap(); + + // Sum the output to get a scalar for backward + // Sum over all dims to get a scalar for backward + let loss = crate::autograd::var_sum(&output, &[0, 1], false, &client).unwrap(); + let grads = backward(&loss, &client).unwrap(); + + let grad_input = grads.get(input.id()).unwrap(); + let grad_weight = grads.get(weight.id()).unwrap(); + + let gi: Vec = grad_input.to_vec(); + let gw: Vec = grad_weight.to_vec(); + + // Verify gradients are finite and have correct shapes + assert_eq!(gi.len(), 3); + assert_eq!(gw.len(), 3); + for val in &gi { + assert!(val.is_finite(), "input gradient should be finite"); + } + for val in &gw { + assert!(val.is_finite(), "weight gradient should be finite"); + } + } + + #[test] + fn test_var_layer_norm_forward() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + let input = Var::new( + Tensor::::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[1, 4], &device), + true, + ); + let weight = Var::new( + Tensor::::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[4], &device), + true, + ); + let bias = Var::new( + Tensor::::from_slice(&[0.0f32, 0.0, 0.0, 0.0], &[4], &device), + true, + ); + + let result = var_layer_norm(&input, &weight, &bias, 1e-5, &client).unwrap(); + let data: Vec = result.tensor().to_vec(); + + // mean = 2.5, var = mean([(-1.5)^2, (-0.5)^2, (0.5)^2, (1.5)^2]) = 1.25 + // rstd = 1/sqrt(1.25 + 1e-5) + // output should have mean ~0 and unit variance + let sum: f32 = data.iter().sum(); + assert!(sum.abs() < 1e-4, "layer norm output should have ~0 mean"); + } + + #[test] + fn test_var_layer_norm_backward() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + let input = Var::new( + Tensor::::from_slice(&[1.0f32, 2.0, 3.0], &[1, 3], &device), + true, + ); + let weight = Var::new( + Tensor::::from_slice(&[1.0f32, 1.0, 1.0], &[3], &device), + true, + ); + let bias = Var::new( + Tensor::::from_slice(&[0.0f32, 0.0, 0.0], &[3], &device), + true, + ); + + let output = var_layer_norm(&input, &weight, &bias, 1e-5, &client).unwrap(); + + // Sum over all dims to get a scalar for backward + let loss = crate::autograd::var_sum(&output, &[0, 1], false, &client).unwrap(); + let grads = backward(&loss, &client).unwrap(); + + let grad_input = grads.get(input.id()).unwrap(); + let grad_weight = grads.get(weight.id()).unwrap(); + let grad_bias = grads.get(bias.id()).unwrap(); + + let gi: Vec = grad_input.to_vec(); + let gw: Vec = grad_weight.to_vec(); + let gb: Vec = grad_bias.to_vec(); + + // Verify shapes + assert_eq!(gi.len(), 3); + assert_eq!(gw.len(), 3); + assert_eq!(gb.len(), 3); + + // For layer norm with sum loss: + // d_bias = sum(grad_output) = [1, 1, 1] (each element contributes 1) + for val in &gb { + assert!( + (*val - 1.0).abs() < 1e-5, + "bias gradient should be 1.0, got {}", + val, + ); + } + + // d_input should sum to ~0 (layer norm property) + let sum: f32 = gi.iter().sum(); + assert!( + sum.abs() < 1e-5, + "sum of input gradients should be ~0, got {}", + sum, + ); + + // All gradients should be finite + for val in &gi { + assert!(val.is_finite()); + } + for val in &gw { + assert!(val.is_finite()); + } + } + + #[test] + fn test_var_rms_norm_no_grad() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + // When no inputs require grad, output should not track gradients + let input = Var::new( + Tensor::::from_slice(&[1.0f32, 2.0], &[1, 2], &device), + false, + ); + let weight = Var::new( + Tensor::::from_slice(&[1.0f32, 1.0], &[2], &device), + false, + ); + + let result = var_rms_norm(&input, &weight, 1e-5, &client).unwrap(); + assert!(!result.requires_grad()); + assert!(result.grad_fn().is_none()); + } + + #[test] + fn test_var_group_norm_forward() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + // [batch=1, channels=4, spatial=3], 2 groups + let input = Var::new( + Tensor::::from_slice( + &[ + 1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + ], + &[1, 4, 3], + &device, + ), + false, + ); + let weight = Var::new( + Tensor::::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[4], &device), + false, + ); + let bias = Var::new( + Tensor::::from_slice(&[0.0f32, 0.0, 0.0, 0.0], &[4], &device), + false, + ); + + let result = var_group_norm(&input, &weight, &bias, 2, 1e-5, &client).unwrap(); + assert_eq!(result.tensor().shape(), &[1, 4, 3]); + + // Each group should have approximately zero mean + let data: Vec = result.tensor().to_vec(); + // Group 0: channels 0,1 → indices 0..6 + let group0_sum: f32 = data[0..6].iter().sum(); + assert!( + group0_sum.abs() < 1e-4, + "group 0 mean should be ~0, sum={group0_sum}" + ); + // Group 1: channels 2,3 → indices 6..12 + let group1_sum: f32 = data[6..12].iter().sum(); + assert!( + group1_sum.abs() < 1e-4, + "group 1 mean should be ~0, sum={group1_sum}" + ); + } + + #[test] + fn test_var_group_norm_backward() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + // [batch=1, channels=4, spatial=2], 2 groups + let input = Var::new( + Tensor::::from_slice( + &[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], + &[1, 4, 2], + &device, + ), + true, + ); + let weight = Var::new( + Tensor::::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[4], &device), + true, + ); + let bias = Var::new( + Tensor::::from_slice(&[0.0f32, 0.0, 0.0, 0.0], &[4], &device), + true, + ); + + let output = var_group_norm(&input, &weight, &bias, 2, 1e-5, &client).unwrap(); + let loss = crate::autograd::var_sum(&output, &[], false, &client).unwrap(); + let grads = backward(&loss, &client).unwrap(); + + let d_input: Vec = grads.get(input.id()).unwrap().to_vec(); + let d_weight: Vec = grads.get(weight.id()).unwrap().to_vec(); + let d_bias: Vec = grads.get(bias.id()).unwrap().to_vec(); + + assert_eq!(d_input.len(), 8); + assert_eq!(d_weight.len(), 4); + assert_eq!(d_bias.len(), 4); + + // d_bias should be sum of grad_output over batch and spatial = [2, 2, 2, 2] + for &b in &d_bias { + assert!((b - 2.0).abs() < 1e-5, "d_bias should be 2.0, got {b}"); + } + + // All gradients should be finite + for v in d_input.iter().chain(d_weight.iter()) { + assert!(v.is_finite()); + } + } + + #[test] + fn test_var_fused_add_rms_norm_forward() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + let x = Var::new( + Tensor::::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[1, 4], &device), + true, + ); + let residual = Var::new( + Tensor::::from_slice(&[0.1f32, 0.2, 0.3, 0.4], &[1, 4], &device), + true, + ); + let weight = Var::new( + Tensor::::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[4], &device), + true, + ); + + let result = var_fused_add_rms_norm(&x, &residual, &weight, 1e-5, &client).unwrap(); + let data: Vec = result.tensor().to_vec(); + + assert_eq!(data.len(), 4); + for val in &data { + assert!(val.is_finite()); + } + } + + #[test] + fn test_var_fused_add_rms_norm_backward() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + let x = Var::new( + Tensor::::from_slice(&[1.0f32, 2.0, 3.0], &[1, 3], &device), + true, + ); + let residual = Var::new( + Tensor::::from_slice(&[0.1f32, 0.2, 0.3], &[1, 3], &device), + true, + ); + let weight = Var::new( + Tensor::::from_slice(&[1.0f32, 1.0, 1.0], &[3], &device), + true, + ); + + let output = var_fused_add_rms_norm(&x, &residual, &weight, 1e-5, &client).unwrap(); + let loss = crate::autograd::var_sum(&output, &[0, 1], false, &client).unwrap(); + let grads = backward(&loss, &client).unwrap(); + + let gx: Vec = grads.get(x.id()).unwrap().to_vec(); + let gr: Vec = grads.get(residual.id()).unwrap().to_vec(); + let gw: Vec = grads.get(weight.id()).unwrap().to_vec(); + + assert_eq!(gx.len(), 3); + assert_eq!(gr.len(), 3); + assert_eq!(gw.len(), 3); + + // x and residual should get the same gradient + for (a, b) in gx.iter().zip(gr.iter()) { + assert!( + (a - b).abs() < 1e-5, + "x and residual grads must match: {a} vs {b}" + ); + } + for val in gx.iter().chain(gw.iter()) { + assert!(val.is_finite()); + } + } + + #[test] + fn test_var_fused_add_rms_norm_no_grad() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + let x = Var::new( + Tensor::::from_slice(&[1.0f32, 2.0], &[1, 2], &device), + false, + ); + let residual = Var::new( + Tensor::::from_slice(&[0.1f32, 0.2], &[1, 2], &device), + false, + ); + let weight = Var::new( + Tensor::::from_slice(&[1.0f32, 1.0], &[2], &device), + false, + ); + + let result = var_fused_add_rms_norm(&x, &residual, &weight, 1e-5, &client).unwrap(); + assert!(!result.requires_grad()); + } + + #[test] + fn test_var_fused_add_layer_norm_forward() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + let x = Var::new( + Tensor::::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[1, 4], &device), + true, + ); + let residual = Var::new( + Tensor::::from_slice(&[0.1f32, 0.2, 0.3, 0.4], &[1, 4], &device), + true, + ); + let weight = Var::new( + Tensor::::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[4], &device), + true, + ); + let bias = Var::new( + Tensor::::from_slice(&[0.0f32, 0.0, 0.0, 0.0], &[4], &device), + true, + ); + + let result = + var_fused_add_layer_norm(&x, &residual, &weight, &bias, 1e-5, &client).unwrap(); + let data: Vec = result.tensor().to_vec(); + + // Layer norm output should have ~0 mean + let sum: f32 = data.iter().sum(); + assert!( + sum.abs() < 1e-4, + "output should have ~0 mean, got sum={sum}" + ); + } + + #[test] + fn test_var_fused_add_layer_norm_backward() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + let x = Var::new( + Tensor::::from_slice(&[1.0f32, 2.0, 3.0], &[1, 3], &device), + true, + ); + let residual = Var::new( + Tensor::::from_slice(&[0.1f32, 0.2, 0.3], &[1, 3], &device), + true, + ); + let weight = Var::new( + Tensor::::from_slice(&[1.0f32, 1.0, 1.0], &[3], &device), + true, + ); + let bias = Var::new( + Tensor::::from_slice(&[0.0f32, 0.0, 0.0], &[3], &device), + true, + ); + + let output = + var_fused_add_layer_norm(&x, &residual, &weight, &bias, 1e-5, &client).unwrap(); + let loss = crate::autograd::var_sum(&output, &[0, 1], false, &client).unwrap(); + let grads = backward(&loss, &client).unwrap(); + + let gx: Vec = grads.get(x.id()).unwrap().to_vec(); + let gr: Vec = grads.get(residual.id()).unwrap().to_vec(); + let gw: Vec = grads.get(weight.id()).unwrap().to_vec(); + let gb: Vec = grads.get(bias.id()).unwrap().to_vec(); + + // x and residual should get the same gradient + for (a, b) in gx.iter().zip(gr.iter()) { + assert!((a - b).abs() < 1e-5, "x and residual grads must match"); + } + + // d_bias should be [1, 1, 1] for sum loss + for val in &gb { + assert!( + (*val - 1.0).abs() < 1e-5, + "bias gradient should be 1.0, got {val}" + ); + } + + for val in gx.iter().chain(gw.iter()) { + assert!(val.is_finite()); + } + } +} diff --git a/src/autograd/var_ops/swiglu.rs b/src/autograd/var_ops/swiglu.rs new file mode 100644 index 00000000..4a4e1647 --- /dev/null +++ b/src/autograd/var_ops/swiglu.rs @@ -0,0 +1,252 @@ +//! Fused SwiGLU activation with gradient support +//! +//! SwiGLU(gate, up) = silu(gate) * up +//! +//! Fused version saves one intermediate tensor vs composing var_silu + var_mul: +//! - Composed: stores gate, silu(gate), up (3 tensors) +//! - Fused: stores gate, up, output (3 tensors but recomputes sigmoid in backward) +//! +//! More importantly, the fused backward computes gradients in fewer ops. + +use crate::autograd::Var; +use crate::autograd::var_ops::var_mul; +use crate::dtype::DType; +use crate::error::Result; +use crate::ops::{ActivationOps, BinaryOps, ScalarOps, TensorOps}; +use crate::runtime::{Runtime, RuntimeClient}; +use std::sync::Arc; + +/// Fused SwiGLU: output = silu(gate) * up +/// +/// # Arguments +/// * `gate` - Gate input (will have silu applied) +/// * `up` - Up projection (multiplied element-wise with activated gate) +/// * `client` - Runtime client +/// +/// # Returns +/// The SwiGLU output with autograd tracking. +pub fn var_swiglu(gate: &Var, up: &Var, client: &C) -> Result> +where + R: Runtime, + C: RuntimeClient + TensorOps + ActivationOps + ScalarOps + BinaryOps, + R::Client: TensorOps + ActivationOps + ScalarOps + BinaryOps, +{ + // Forward: output = silu(gate) * up (fused single-pass kernel) + let silu_gate = client.silu(gate.tensor())?; + let output = client.silu_mul(gate.tensor(), up.tensor())?; + + if gate.requires_grad() || up.requires_grad() { + let grad_fn = SwiGLUBackward::::new( + gate.id(), + up.id(), + gate.tensor().clone(), + up.tensor().clone(), + silu_gate, + gate.grad_fn().cloned(), + up.grad_fn().cloned(), + ); + Ok(Var::from_op(output, Arc::new(grad_fn))) + } else { + Ok(Var::new(output, false)) + } +} + +/// Backward for fused SwiGLU: output = silu(gate) * up +/// +/// Gradients: +/// - d_gate = grad_output * up * silu'(gate) +/// = grad_output * up * (sigmoid(gate) * (1 + gate - silu(gate))) +/// - d_up = grad_output * silu(gate) +pub struct SwiGLUBackward { + input_ids: [crate::tensor::TensorId; 2], + saved_gate: crate::tensor::Tensor, + saved_up: crate::tensor::Tensor, + saved_silu_gate: crate::tensor::Tensor, + gate_grad_fn: Option>>, + up_grad_fn: Option>>, +} + +impl SwiGLUBackward { + pub fn new( + gate_id: crate::tensor::TensorId, + up_id: crate::tensor::TensorId, + gate: crate::tensor::Tensor, + up: crate::tensor::Tensor, + silu_gate: crate::tensor::Tensor, + gate_grad_fn: Option>>, + up_grad_fn: Option>>, + ) -> Self { + Self { + input_ids: [gate_id, up_id], + saved_gate: gate, + saved_up: up, + saved_silu_gate: silu_gate, + gate_grad_fn, + up_grad_fn, + } + } +} + +impl> crate::autograd::GradFn for SwiGLUBackward +where + R::Client: TensorOps + ActivationOps + ScalarOps + BinaryOps, +{ + fn backward( + &self, + grad_output: &crate::tensor::Tensor, + ) -> Result>>> { + let client = R::default_client(grad_output.device()); + + // d_up = grad_output * silu(gate) + let d_up = client.mul(grad_output, &self.saved_silu_gate)?; + + // d_gate = grad_output * up * silu'(gate) + // silu'(x) = sigmoid(x) * (1 + x - silu(x)) + let sigmoid_gate = client.sigmoid(&self.saved_gate)?; + let one_plus_gate = client.add_scalar(&self.saved_gate, 1.0)?; + let one_plus_gate_minus_silu = client.sub(&one_plus_gate, &self.saved_silu_gate)?; + let silu_deriv = client.mul(&sigmoid_gate, &one_plus_gate_minus_silu)?; + let grad_times_up = client.mul(grad_output, &self.saved_up)?; + let d_gate = client.mul(&grad_times_up, &silu_deriv)?; + + Ok(vec![Some(d_gate), Some(d_up)]) + } + + fn backward_var(&self, grad_output: &Var) -> Result>>> + where + R::Client: RuntimeClient + TensorOps + ActivationOps + ScalarOps + BinaryOps, + { + let client = R::default_client(grad_output.tensor().device()); + + // d_up = grad_output * silu(gate) [silu_gate is constant w.r.t. higher-order] + let silu_var = Var::new(self.saved_silu_gate.clone(), false); + let d_up = var_mul(grad_output, &silu_var, &client)?; + + // d_gate = grad_output * up * silu'(gate) + let sigmoid_gate = client.sigmoid(&self.saved_gate)?; + let one_plus_gate = client.add_scalar(&self.saved_gate, 1.0)?; + let one_plus_gate_minus_silu = client.sub(&one_plus_gate, &self.saved_silu_gate)?; + let silu_deriv = client.mul(&sigmoid_gate, &one_plus_gate_minus_silu)?; + let silu_deriv_var = Var::new(silu_deriv, false); + + let up_var = Var::new(self.saved_up.clone(), false); + let grad_times_up = var_mul(grad_output, &up_var, &client)?; + let d_gate = var_mul(&grad_times_up, &silu_deriv_var, &client)?; + + Ok(vec![Some(d_gate), Some(d_up)]) + } + + fn inputs(&self) -> &[crate::tensor::TensorId] { + &self.input_ids + } + + fn input_grad_fns(&self) -> Vec>>> { + vec![self.gate_grad_fn.clone(), self.up_grad_fn.clone()] + } + + fn saved_tensors(&self) -> &[crate::tensor::Tensor] { + std::slice::from_ref(&self.saved_gate) + } + + fn name(&self) -> &'static str { + "SwiGLUBackward" + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::autograd::backward; + use crate::runtime::cpu::{CpuDevice, CpuRuntime}; + use crate::tensor::Tensor; + + #[test] + fn test_swiglu_forward() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + let gate = Var::new( + Tensor::::from_slice(&[0.0f32, 1.0, -1.0], &[3], &device), + false, + ); + let up = Var::new( + Tensor::::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device), + false, + ); + + let output = var_swiglu(&gate, &up, &client).unwrap(); + let data: Vec = output.tensor().to_vec(); + + // silu(0) * 1 = 0 * 0.5 * 1 = 0 + assert!(data[0].abs() < 1e-6); + // silu(1) * 2 = 0.7311 * 2 ≈ 1.4621 + let silu_1 = 1.0 / (1.0 + (-1.0f32).exp()); + assert!((data[1] - silu_1 * 2.0).abs() < 1e-4); + // silu(-1) * 3 = -0.2689 * 3 ≈ -0.8067 + let silu_neg1 = -1.0 / (1.0 + 1.0f32.exp()); + assert!((data[2] - silu_neg1 * 3.0).abs() < 1e-4); + } + + #[test] + fn test_swiglu_backward_gate() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + let gate = Var::new( + Tensor::::from_slice(&[1.0f32, -1.0], &[2], &device), + true, + ); + let up = Var::new( + Tensor::::from_slice(&[2.0f32, 3.0], &[2], &device), + true, + ); + + let output = var_swiglu(&gate, &up, &client).unwrap(); + let loss = crate::autograd::var_sum(&output, &[], false, &client).unwrap(); + let grads = backward(&loss, &client).unwrap(); + + let d_gate: Vec = grads.get(gate.id()).unwrap().to_vec(); + let d_up: Vec = grads.get(up.id()).unwrap().to_vec(); + + // Verify d_up = silu(gate) + for (i, &g) in [1.0f32, -1.0].iter().enumerate() { + let expected_d_up = g * (1.0 / (1.0 + (-g).exp())); + assert!( + (d_up[i] - expected_d_up).abs() < 1e-5, + "d_up[{i}]: got {}, expected {expected_d_up}", + d_up[i] + ); + } + + // Verify d_gate = up * silu'(gate) + for (i, (&g, &u)) in [1.0f32, -1.0].iter().zip([2.0f32, 3.0].iter()).enumerate() { + let sig = 1.0 / (1.0 + (-g).exp()); + let silu_g = g * sig; + let silu_deriv = sig * (1.0 + g - silu_g); + let expected = u * silu_deriv; + assert!( + (d_gate[i] - expected).abs() < 1e-4, + "d_gate[{i}]: got {}, expected {expected}", + d_gate[i] + ); + } + } + + #[test] + fn test_swiglu_no_grad() { + let device = CpuDevice::new(); + let client = CpuRuntime::default_client(&device); + + let gate = Var::new( + Tensor::::from_slice(&[1.0f32], &[1], &device), + false, + ); + let up = Var::new( + Tensor::::from_slice(&[2.0f32], &[1], &device), + false, + ); + + let output = var_swiglu(&gate, &up, &client).unwrap(); + assert!(!output.requires_grad()); + } +} diff --git a/src/autograd/var_ops/utility.rs b/src/autograd/var_ops/utility.rs index ad272dc8..7fd599a4 100644 --- a/src/autograd/var_ops/utility.rs +++ b/src/autograd/var_ops/utility.rs @@ -2,6 +2,7 @@ use super::ops::*; use crate::autograd::Var; +use crate::dtype::DType; use crate::error::Result; use crate::ops::{CompareOps, ScalarOps, TensorOps}; use crate::runtime::{Runtime, RuntimeClient}; @@ -12,7 +13,7 @@ use std::sync::Arc; /// Creates ClampBackward for gradient computation. pub fn var_clamp(a: &Var, min_val: f64, max_val: f64, client: &C) -> Result> where - R: Runtime, + R: Runtime, C: RuntimeClient + TensorOps, R::Client: TensorOps + ScalarOps + CompareOps, { diff --git a/src/dtype/complex.rs b/src/dtype/complex.rs index 47a9efc1..87285720 100644 --- a/src/dtype/complex.rs +++ b/src/dtype/complex.rs @@ -32,7 +32,7 @@ use bytemuck::{Pod, Zeroable}; use std::fmt; -use std::ops::{Add, Div, Mul, Neg, Sub}; +use std::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign}; // ============================================================================ // CUDA Compatibility Traits @@ -243,6 +243,29 @@ macro_rules! impl_complex { } } + impl AddAssign for $name { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.re += rhs.re; + self.im += rhs.im; + } + } + + impl SubAssign for $name { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.re -= rhs.re; + self.im -= rhs.im; + } + } + + impl MulAssign for $name { + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } + } + impl PartialOrd for $name { /// Complex numbers are not naturally ordered. /// This compares by magnitude for sorting purposes. diff --git a/src/dtype/data_type.rs b/src/dtype/data_type.rs new file mode 100644 index 00000000..f0d78d1a --- /dev/null +++ b/src/dtype/data_type.rs @@ -0,0 +1,104 @@ +//! Extensible data type trait for tensor element types. + +use std::fmt; +use std::hash::Hash; + +use super::DType; + +/// Trait for data types that can be stored in tensors. +/// +/// numr's [`DType`] implements this. Downstream libraries (e.g. boostr) can +/// define their own dtype enums with quantized variants that also implement +/// this trait. The [`Runtime`](crate::runtime::Runtime) trait has an associated +/// `DType` type bounded by `DataType`, enabling each runtime to specify its +/// own dtype enum. +pub trait DataType: + Copy + Clone + fmt::Debug + PartialEq + Eq + Hash + Send + Sync + 'static +{ + /// Size of one element in bytes. + /// + /// For block-quantized types, returns 1 as placeholder — use + /// [`block_bytes`](Self::block_bytes) / [`block_size`](Self::block_size) for exact sizing. + fn size_in_bytes(self) -> usize; + + /// Short display name (e.g., "f32", "q4_0"). + fn short_name(self) -> &'static str; + + /// Whether this is a floating point type. + fn is_float(self) -> bool; + + /// Whether this is an integer type. + fn is_int(self) -> bool; + + /// Whether this is a quantized/block type. + fn is_quantized(self) -> bool { + false + } + + /// Block size for quantized types (elements per block), 1 for scalar types. + fn block_size(self) -> usize { + 1 + } + + /// Bytes per block for quantized types, `size_in_bytes()` for scalar types. + fn block_bytes(self) -> usize { + self.size_in_bytes() + } + + /// Total storage bytes for `numel` elements. + fn storage_bytes(self, numel: usize) -> usize { + if self.is_quantized() { + let bs = self.block_size(); + let bb = self.block_bytes(); + ((numel + bs - 1) / bs) * bb + } else { + numel * self.size_in_bytes() + } + } + + /// Try to convert to numr's standard [`DType`]. + /// + /// Returns `None` for custom/quantized types that have no numr equivalent. + fn as_standard(&self) -> Option; + + /// Fill a buffer with `count` elements set to `value`, returning raw bytes. + /// + /// This enables generic constructors (zeros, ones, full_scalar) to work + /// with any DType, not just numr's built-in DType. The default impl + /// delegates to `as_standard()` and uses numr's fill logic. + /// + /// Downstream libraries with custom dtypes (e.g. quantized types) should + /// override this if they need fill support. + fn fill_bytes(self, value: f64, count: usize) -> Option> { + self.as_standard() + .map(|std_dtype| std_dtype.fill_bytes_impl(value, count)) + } +} + +/// Implement `DataType` for numr's built-in `DType`. +impl DataType for DType { + #[inline] + fn size_in_bytes(self) -> usize { + DType::size_in_bytes(self) + } + + #[inline] + fn short_name(self) -> &'static str { + DType::short_name(self) + } + + #[inline] + fn is_float(self) -> bool { + DType::is_float(self) + } + + #[inline] + fn is_int(self) -> bool { + DType::is_int(self) + } + + #[inline] + fn as_standard(&self) -> Option { + Some(*self) + } +} diff --git a/src/dtype/dtype_enum.rs b/src/dtype/dtype_enum.rs new file mode 100644 index 00000000..1e04b8a6 --- /dev/null +++ b/src/dtype/dtype_enum.rs @@ -0,0 +1,275 @@ +//! Core DType enum and methods. + +use std::fmt; + +use super::complex::{Complex64, Complex128}; +use super::fp8::{FP8E4M3, FP8E5M2}; + +/// Data types supported by numr tensors +/// +/// This enum represents the element type of a tensor at runtime. +/// Using an enum (rather than generics) allows: +/// - Mixed-precision operations +/// - Runtime type selection +/// - Support for quantized types that aren't `Copy` +/// +/// # Discriminant Values (Serialization Stability) +/// +/// The discriminant values are **stable** for serialization purposes: +/// - Floats: 0-9 (F64=0, F32=1, F16=2, BF16=3, FP8E4M3=4, FP8E5M2=5) +/// - Signed ints: 10-19 (I64=10, I32=11, I16=12, I8=13) +/// - Unsigned ints: 20-29 (U64=20, U32=21, U16=22, U8=23) +/// - Bool: 30 +/// - Complex: 40-49 (Complex64=40, Complex128=41) +/// +/// New types will use reserved ranges. Existing values are NEVER changed. +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +#[non_exhaustive] +#[repr(u8)] +pub enum DType { + // Floating point types (0-9) + /// 64-bit floating point + F64 = 0, + /// 32-bit floating point (most common) + F32 = 1, + /// 16-bit floating point (IEEE 754) + F16 = 2, + /// 16-bit brain floating point + BF16 = 3, + /// 8-bit floating point (1 sign + 4 exp + 3 mant), range ~[-448, 448] + /// Best for: weights, activations (higher precision, smaller range) + FP8E4M3 = 4, + /// 8-bit floating point (1 sign + 5 exp + 2 mant), range ~[-57344, 57344] + /// Best for: gradients (larger dynamic range, lower precision) + FP8E5M2 = 5, + + // Integer types + /// 64-bit signed integer + I64 = 10, + /// 32-bit signed integer + I32 = 11, + /// 16-bit signed integer + I16 = 12, + /// 8-bit signed integer + I8 = 13, + + // Unsigned integer types + /// 64-bit unsigned integer + U64 = 20, + /// 32-bit unsigned integer + U32 = 21, + /// 16-bit unsigned integer + U16 = 22, + /// 8-bit unsigned integer + U8 = 23, + + /// Boolean type + Bool = 30, + + // Complex types + /// 64-bit complex (two f32: re, im) + Complex64 = 40, + /// 128-bit complex (two f64: re, im) + Complex128 = 41, +} + +impl DType { + /// Size of one element in bytes + #[inline] + pub const fn size_in_bytes(self) -> usize { + match self { + Self::Complex128 => 16, + Self::F64 | Self::I64 | Self::U64 | Self::Complex64 => 8, + Self::F32 | Self::I32 | Self::U32 => 4, + Self::F16 | Self::BF16 | Self::I16 | Self::U16 => 2, + Self::FP8E4M3 | Self::FP8E5M2 | Self::I8 | Self::U8 | Self::Bool => 1, + } + } + + /// Returns true if this is a floating point type + #[inline] + pub const fn is_float(self) -> bool { + matches!( + self, + Self::F64 | Self::F32 | Self::F16 | Self::BF16 | Self::FP8E4M3 | Self::FP8E5M2 + ) + } + + /// Returns true if this is a complex number type + #[inline] + pub const fn is_complex(self) -> bool { + matches!(self, Self::Complex64 | Self::Complex128) + } + + /// Returns the underlying float type for complex types + /// Returns None for non-complex types + #[inline] + pub const fn complex_component_dtype(self) -> Option { + match self { + Self::Complex64 => Some(Self::F32), + Self::Complex128 => Some(Self::F64), + _ => None, + } + } + + /// Returns true if this is a signed integer type + #[inline] + pub const fn is_signed_int(self) -> bool { + matches!(self, Self::I64 | Self::I32 | Self::I16 | Self::I8) + } + + /// Returns true if this is an unsigned integer type + #[inline] + pub const fn is_unsigned_int(self) -> bool { + matches!(self, Self::U64 | Self::U32 | Self::U16 | Self::U8) + } + + /// Returns true if this is any integer type (signed or unsigned) + #[inline] + pub const fn is_int(self) -> bool { + self.is_signed_int() || self.is_unsigned_int() + } + + /// Returns true if this is a boolean type + #[inline] + pub const fn is_bool(self) -> bool { + matches!(self, Self::Bool) + } + + /// Returns true if this type can represent negative values + #[inline] + pub const fn is_signed(self) -> bool { + self.is_float() || self.is_signed_int() || self.is_complex() + } + + /// Get the default dtype for floating point operations + #[inline] + pub const fn default_float() -> Self { + Self::F32 + } + + /// Get the default dtype for integer operations + #[inline] + pub const fn default_int() -> Self { + Self::I64 + } + + /// Short name for display (e.g., "f32", "i64") + pub const fn short_name(self) -> &'static str { + match self { + Self::F64 => "f64", + Self::F32 => "f32", + Self::F16 => "f16", + Self::BF16 => "bf16", + Self::FP8E4M3 => "fp8e4m3", + Self::FP8E5M2 => "fp8e5m2", + Self::I64 => "i64", + Self::I32 => "i32", + Self::I16 => "i16", + Self::I8 => "i8", + Self::U64 => "u64", + Self::U32 => "u32", + Self::U16 => "u16", + Self::U8 => "u8", + Self::Bool => "bool", + Self::Complex64 => "c64", + Self::Complex128 => "c128", + } + } + + /// Minimum value representable by this dtype (as f64) + /// + /// For complex types, returns the minimum value of each component + pub fn min_value(self) -> f64 { + match self { + Self::F64 => f64::MIN, + Self::F32 => f32::MIN as f64, + Self::F16 => -65504.0, // IEEE 754 half precision + Self::BF16 => -3.4e38, // Approximate + Self::FP8E4M3 => -448.0, // 1 sign + 4 exp + 3 mant + Self::FP8E5M2 => -57344.0, // 1 sign + 5 exp + 2 mant + Self::I64 => i64::MIN as f64, + Self::I32 => i32::MIN as f64, + Self::I16 => i16::MIN as f64, + Self::I8 => i8::MIN as f64, + Self::U64 => 0.0, + Self::U32 => 0.0, + Self::U16 => 0.0, + Self::U8 => 0.0, + Self::Bool => 0.0, + Self::Complex64 => f32::MIN as f64, + Self::Complex128 => f64::MIN, + } + } + + /// Fill a buffer with `count` elements of this DType set to `value`. + /// + /// Returns the raw bytes. Used by generic constructors (zeros, ones, full_scalar). + pub fn fill_bytes_impl(self, value: f64, count: usize) -> Vec { + #[inline] + fn typed_to_bytes(v: Vec) -> Vec { + bytemuck::cast_slice::(&v).to_vec() + } + + match self { + DType::F64 => typed_to_bytes(vec![value; count]), + DType::F32 => typed_to_bytes(vec![value as f32; count]), + DType::F16 => { + let bits = crate::dtype::half_from_f32_util(value as f32, true); + typed_to_bytes(vec![bits; count]) + } + DType::BF16 => { + let bits = crate::dtype::half_from_f32_util(value as f32, false); + typed_to_bytes(vec![bits; count]) + } + DType::FP8E4M3 => { + vec![FP8E4M3::from_f32(value as f32).to_bits(); count] + } + DType::FP8E5M2 => { + vec![FP8E5M2::from_f32(value as f32).to_bits(); count] + } + DType::I64 => typed_to_bytes(vec![value as i64; count]), + DType::I32 => typed_to_bytes(vec![value as i32; count]), + DType::I16 => typed_to_bytes(vec![value as i16; count]), + DType::I8 => typed_to_bytes(vec![value as i8; count]), + DType::U64 => typed_to_bytes(vec![value as u64; count]), + DType::U32 => typed_to_bytes(vec![value as u32; count]), + DType::U16 => typed_to_bytes(vec![value as u16; count]), + DType::U8 => vec![value as u8; count], + DType::Bool => vec![if value != 0.0 { 1u8 } else { 0u8 }; count], + DType::Complex64 => typed_to_bytes(vec![Complex64::new(value as f32, 0.0); count]), + DType::Complex128 => typed_to_bytes(vec![Complex128::new(value, 0.0); count]), + } + } + + /// Maximum value representable by this dtype (as f64) + /// + /// For complex types, returns the maximum value of each component + pub fn max_value(self) -> f64 { + match self { + Self::F64 => f64::MAX, + Self::F32 => f32::MAX as f64, + Self::F16 => 65504.0, + Self::BF16 => 3.4e38, + Self::FP8E4M3 => 448.0, + Self::FP8E5M2 => 57344.0, + Self::I64 => i64::MAX as f64, + Self::I32 => i32::MAX as f64, + Self::I16 => i16::MAX as f64, + Self::I8 => i8::MAX as f64, + Self::U64 => u64::MAX as f64, + Self::U32 => u32::MAX as f64, + Self::U16 => u16::MAX as f64, + Self::U8 => u8::MAX as f64, + Self::Bool => 1.0, + Self::Complex64 => f32::MAX as f64, + Self::Complex128 => f64::MAX, + } + } +} + +impl fmt::Display for DType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.short_name()) + } +} diff --git a/src/dtype/dtype_set.rs b/src/dtype/dtype_set.rs new file mode 100644 index 00000000..4e396489 --- /dev/null +++ b/src/dtype/dtype_set.rs @@ -0,0 +1,91 @@ +//! Efficient bitset for DType membership testing. + +use super::DType; + +/// Set of dtypes for efficient membership testing +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub struct DTypeSet { + bits: u64, +} + +impl DTypeSet { + /// Empty set + pub const EMPTY: Self = Self { bits: 0 }; + + /// All floating point types + pub const FLOATS: Self = Self { + bits: (1 << DType::F64 as u8) + | (1 << DType::F32 as u8) + | (1 << DType::F16 as u8) + | (1 << DType::BF16 as u8) + | (1 << DType::FP8E4M3 as u8) + | (1 << DType::FP8E5M2 as u8), + }; + + /// All signed integer types + pub const SIGNED_INTS: Self = Self { + bits: (1 << DType::I64 as u8) + | (1 << DType::I32 as u8) + | (1 << DType::I16 as u8) + | (1 << DType::I8 as u8), + }; + + /// All unsigned integer types + pub const UNSIGNED_INTS: Self = Self { + bits: (1 << DType::U64 as u8) + | (1 << DType::U32 as u8) + | (1 << DType::U16 as u8) + | (1 << DType::U8 as u8), + }; + + /// All integer types + pub const INTS: Self = Self { + bits: Self::SIGNED_INTS.bits | Self::UNSIGNED_INTS.bits, + }; + + /// All numeric types (floats + ints) + pub const NUMERIC: Self = Self { + bits: Self::FLOATS.bits | Self::INTS.bits, + }; + + /// All complex types + pub const COMPLEX: Self = Self { + bits: (1 << DType::Complex64 as u8) | (1 << DType::Complex128 as u8), + }; + + /// Create a set containing a single dtype + #[inline] + pub const fn single(dtype: DType) -> Self { + Self { + bits: 1 << dtype as u8, + } + } + + /// Check if the set contains a dtype + #[inline] + pub const fn contains(self, dtype: DType) -> bool { + self.bits & (1 << dtype as u8) != 0 + } + + /// Union of two sets + #[inline] + pub const fn union(self, other: Self) -> Self { + Self { + bits: self.bits | other.bits, + } + } + + /// Intersection of two sets + #[inline] + pub const fn intersection(self, other: Self) -> Self { + Self { + bits: self.bits & other.bits, + } + } + + /// Check if set is empty + #[inline] + pub const fn is_empty(self) -> bool { + self.bits == 0 + } +} diff --git a/src/dtype/fp8.rs b/src/dtype/fp8.rs index e763b90e..c3d49985 100644 --- a/src/dtype/fp8.rs +++ b/src/dtype/fp8.rs @@ -208,6 +208,34 @@ impl Div for FP8E4M3 { } } +impl std::ops::AddAssign for FP8E4M3 { + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = Self::from_f32(self.to_f32() + rhs.to_f32()); + } +} + +impl std::ops::SubAssign for FP8E4M3 { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = Self::from_f32(self.to_f32() - rhs.to_f32()); + } +} + +impl std::ops::MulAssign for FP8E4M3 { + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = Self::from_f32(self.to_f32() * rhs.to_f32()); + } +} + +impl std::ops::DivAssign for FP8E4M3 { + #[inline] + fn div_assign(&mut self, rhs: Self) { + *self = Self::from_f32(self.to_f32() / rhs.to_f32()); + } +} + // ============================================================================ // FP8E5M2 Type // ============================================================================ @@ -389,6 +417,34 @@ impl Div for FP8E5M2 { } } +impl std::ops::AddAssign for FP8E5M2 { + #[inline] + fn add_assign(&mut self, rhs: Self) { + *self = Self::from_f32(self.to_f32() + rhs.to_f32()); + } +} + +impl std::ops::SubAssign for FP8E5M2 { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = Self::from_f32(self.to_f32() - rhs.to_f32()); + } +} + +impl std::ops::MulAssign for FP8E5M2 { + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = Self::from_f32(self.to_f32() * rhs.to_f32()); + } +} + +impl std::ops::DivAssign for FP8E5M2 { + #[inline] + fn div_assign(&mut self, rhs: Self) { + *self = Self::from_f32(self.to_f32() / rhs.to_f32()); + } +} + // ============================================================================ // CUDA Trait Implementations // ============================================================================ diff --git a/src/dtype/half_util.rs b/src/dtype/half_util.rs new file mode 100644 index 00000000..e365ae87 --- /dev/null +++ b/src/dtype/half_util.rs @@ -0,0 +1,47 @@ +//! Half-precision float conversion utilities. + +/// Convert f32 to half-precision bit representation. +/// +/// If `is_f16` is true, converts to IEEE 754 half-precision (F16). +/// If false, converts to brain floating point (BF16). +/// +/// This is a simple conversion for common cases. For full compliance, +/// enable the `f16` feature which uses the `half` crate. +pub fn half_from_f32_util(value: f32, is_f16: bool) -> u16 { + #[cfg(feature = "f16")] + { + if is_f16 { + half::f16::from_f32(value).to_bits() + } else { + half::bf16::from_f32(value).to_bits() + } + } + #[cfg(not(feature = "f16"))] + { + let bits = value.to_bits(); + let sign = (bits >> 31) & 1; + let exp = ((bits >> 23) & 0xFF) as i32; + let frac = bits & 0x7FFFFF; + + if !is_f16 { + // BF16: truncate mantissa + ((bits >> 16) & 0xFFFF) as u16 + } else { + // F16: IEEE 754 half precision + if exp == 0 { + (sign << 15) as u16 + } else if exp == 0xFF { + ((sign << 15) | 0x7C00 | if frac != 0 { 0x200 } else { 0 }) as u16 + } else { + let new_exp = exp - 127 + 15; + if new_exp <= 0 { + (sign << 15) as u16 + } else if new_exp >= 31 { + ((sign << 15) | 0x7C00) as u16 + } else { + ((sign << 15) | ((new_exp as u32) << 10) | (frac >> 13)) as u16 + } + } + } + } +} diff --git a/src/dtype/mod.rs b/src/dtype/mod.rs index e5139edf..d50d2bbe 100644 --- a/src/dtype/mod.rs +++ b/src/dtype/mod.rs @@ -1,391 +1,25 @@ -//! Data type system for numr tensors -//! -//! This module provides the `DType` enum representing all supported element types, -//! along with type promotion rules and conversion utilities. +//! Data type system for numr tensors. pub mod complex; +mod data_type; +mod dtype_enum; +mod dtype_set; mod element; pub mod fp8; +mod half_util; +mod precision; mod promotion; pub use complex::{Complex64, Complex128}; +pub use data_type::DataType; +pub use dtype_enum::DType; +pub use dtype_set::DTypeSet; pub use element::Element; pub use fp8::{FP8E4M3, FP8E5M2}; +pub use half_util::half_from_f32_util; +pub use precision::ComputePrecision; pub use promotion::promote; -use std::fmt; - -// ============================================================================ -// Mixed Precision Configuration -// ============================================================================ - -/// Compute precision for intermediate calculations with reduced-precision types. -/// -/// When operating on reduced-precision types (F16, BF16, FP8), values are typically -/// converted to a higher precision format for computation, then converted back. -/// This allows trading off speed vs precision. -/// -/// # Precision Comparison -/// -/// | Precision | Decimal Digits | Speed | Use Case | -/// |-----------|----------------|---------|----------| -/// | **F64** | ~15-16 | Slowest | Scientific computing requiring maximum precision | -/// | **F32** | ~7 | Medium | High-precision ML, when BF16 isn't enough | -/// | **BF16** | ~3 | Fastest | ML training/inference (default, industry standard) | -/// -/// # Applicability -/// -/// - **FP8**: Always needs upcasting (8-bit storage, compute in BF16, F32, or F64) -/// - **F16/BF16**: Can optionally upcast to F32/F64 for higher precision -/// - **F32**: Can upcast to F64 for scientific computing -/// - **F64**: No upcasting needed (already highest precision) -/// -/// # Resolution Order -/// -/// `per-operation > tensor-level > client default` -/// -/// # Default -/// -/// BF16 is the default, as it provides good speed with the same dynamic range as F32. -/// This is the industry standard for mixed-precision ML training. -#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)] -#[non_exhaustive] -pub enum ComputePrecision { - /// Compute in F64 (highest precision, slowest) - /// Use for: scientific simulations, physics, when F32 precision is insufficient - F64, - /// Compute in F32 (high precision, medium speed) - /// Use for: high-precision ML, numerical algorithms sensitive to rounding - F32, - /// Compute in BF16 (lower precision, fastest, industry standard for ML) - /// Use for: ML training/inference, when speed matters more than precision - #[default] - BF16, -} - -// ============================================================================ -// DType Enum -// ============================================================================ - -/// Data types supported by numr tensors -/// -/// This enum represents the element type of a tensor at runtime. -/// Using an enum (rather than generics) allows: -/// - Mixed-precision operations -/// - Runtime type selection -/// - Support for quantized types that aren't `Copy` -/// -/// # Discriminant Values (Serialization Stability) -/// -/// The discriminant values are **stable** for serialization purposes: -/// - Floats: 0-9 (F64=0, F32=1, F16=2, BF16=3, FP8E4M3=4, FP8E5M2=5) -/// - Signed ints: 10-19 (I64=10, I32=11, I16=12, I8=13) -/// - Unsigned ints: 20-29 (U64=20, U32=21, U16=22, U8=23) -/// - Bool: 30 -/// - Complex: 40-49 (Complex64=40, Complex128=41) -/// -/// New types will use reserved ranges. Existing values are NEVER changed. -#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] -#[non_exhaustive] -#[repr(u8)] -pub enum DType { - // Floating point types (0-9) - /// 64-bit floating point - F64 = 0, - /// 32-bit floating point (most common) - F32 = 1, - /// 16-bit floating point (IEEE 754) - F16 = 2, - /// 16-bit brain floating point - BF16 = 3, - /// 8-bit floating point (1 sign + 4 exp + 3 mant), range ~[-448, 448] - /// Best for: weights, activations (higher precision, smaller range) - FP8E4M3 = 4, - /// 8-bit floating point (1 sign + 5 exp + 2 mant), range ~[-57344, 57344] - /// Best for: gradients (larger dynamic range, lower precision) - FP8E5M2 = 5, - - // Integer types - /// 64-bit signed integer - I64 = 10, - /// 32-bit signed integer - I32 = 11, - /// 16-bit signed integer - I16 = 12, - /// 8-bit signed integer - I8 = 13, - - // Unsigned integer types - /// 64-bit unsigned integer - U64 = 20, - /// 32-bit unsigned integer - U32 = 21, - /// 16-bit unsigned integer - U16 = 22, - /// 8-bit unsigned integer - U8 = 23, - - /// Boolean type - Bool = 30, - - // Complex types - /// 64-bit complex (two f32: re, im) - Complex64 = 40, - /// 128-bit complex (two f64: re, im) - Complex128 = 41, -} - -impl DType { - /// Size of one element in bytes - #[inline] - pub const fn size_in_bytes(self) -> usize { - match self { - Self::Complex128 => 16, - Self::F64 | Self::I64 | Self::U64 | Self::Complex64 => 8, - Self::F32 | Self::I32 | Self::U32 => 4, - Self::F16 | Self::BF16 | Self::I16 | Self::U16 => 2, - Self::FP8E4M3 | Self::FP8E5M2 | Self::I8 | Self::U8 | Self::Bool => 1, - } - } - - /// Returns true if this is a floating point type - #[inline] - pub const fn is_float(self) -> bool { - matches!( - self, - Self::F64 | Self::F32 | Self::F16 | Self::BF16 | Self::FP8E4M3 | Self::FP8E5M2 - ) - } - - /// Returns true if this is a complex number type - #[inline] - pub const fn is_complex(self) -> bool { - matches!(self, Self::Complex64 | Self::Complex128) - } - - /// Returns the underlying float type for complex types - /// Returns None for non-complex types - #[inline] - pub const fn complex_component_dtype(self) -> Option { - match self { - Self::Complex64 => Some(Self::F32), - Self::Complex128 => Some(Self::F64), - _ => None, - } - } - - /// Returns true if this is a signed integer type - #[inline] - pub const fn is_signed_int(self) -> bool { - matches!(self, Self::I64 | Self::I32 | Self::I16 | Self::I8) - } - - /// Returns true if this is an unsigned integer type - #[inline] - pub const fn is_unsigned_int(self) -> bool { - matches!(self, Self::U64 | Self::U32 | Self::U16 | Self::U8) - } - - /// Returns true if this is any integer type (signed or unsigned) - #[inline] - pub const fn is_int(self) -> bool { - self.is_signed_int() || self.is_unsigned_int() - } - - /// Returns true if this is a boolean type - #[inline] - pub const fn is_bool(self) -> bool { - matches!(self, Self::Bool) - } - - /// Returns true if this type can represent negative values - #[inline] - pub const fn is_signed(self) -> bool { - self.is_float() || self.is_signed_int() || self.is_complex() - } - - /// Get the default dtype for floating point operations - #[inline] - pub const fn default_float() -> Self { - Self::F32 - } - - /// Get the default dtype for integer operations - #[inline] - pub const fn default_int() -> Self { - Self::I64 - } - - /// Short name for display (e.g., "f32", "i64") - pub const fn short_name(self) -> &'static str { - match self { - Self::F64 => "f64", - Self::F32 => "f32", - Self::F16 => "f16", - Self::BF16 => "bf16", - Self::FP8E4M3 => "fp8e4m3", - Self::FP8E5M2 => "fp8e5m2", - Self::I64 => "i64", - Self::I32 => "i32", - Self::I16 => "i16", - Self::I8 => "i8", - Self::U64 => "u64", - Self::U32 => "u32", - Self::U16 => "u16", - Self::U8 => "u8", - Self::Bool => "bool", - Self::Complex64 => "c64", - Self::Complex128 => "c128", - } - } - - /// Minimum value representable by this dtype (as f64) - /// - /// For complex types, returns the minimum value of each component - pub fn min_value(self) -> f64 { - match self { - Self::F64 => f64::MIN, - Self::F32 => f32::MIN as f64, - Self::F16 => -65504.0, // IEEE 754 half precision - Self::BF16 => -3.4e38, // Approximate - Self::FP8E4M3 => -448.0, // 1 sign + 4 exp + 3 mant - Self::FP8E5M2 => -57344.0, // 1 sign + 5 exp + 2 mant - Self::I64 => i64::MIN as f64, - Self::I32 => i32::MIN as f64, - Self::I16 => i16::MIN as f64, - Self::I8 => i8::MIN as f64, - Self::U64 => 0.0, - Self::U32 => 0.0, - Self::U16 => 0.0, - Self::U8 => 0.0, - Self::Bool => 0.0, - // Complex types: component min - Self::Complex64 => f32::MIN as f64, - Self::Complex128 => f64::MIN, - } - } - - /// Maximum value representable by this dtype (as f64) - /// - /// For complex types, returns the maximum value of each component - pub fn max_value(self) -> f64 { - match self { - Self::F64 => f64::MAX, - Self::F32 => f32::MAX as f64, - Self::F16 => 65504.0, - Self::BF16 => 3.4e38, - Self::FP8E4M3 => 448.0, - Self::FP8E5M2 => 57344.0, - Self::I64 => i64::MAX as f64, - Self::I32 => i32::MAX as f64, - Self::I16 => i16::MAX as f64, - Self::I8 => i8::MAX as f64, - Self::U64 => u64::MAX as f64, - Self::U32 => u32::MAX as f64, - Self::U16 => u16::MAX as f64, - Self::U8 => u8::MAX as f64, - Self::Bool => 1.0, - // Complex types: component max - Self::Complex64 => f32::MAX as f64, - Self::Complex128 => f64::MAX, - } - } -} - -impl fmt::Display for DType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.short_name()) - } -} - -/// Set of dtypes for efficient membership testing -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub struct DTypeSet { - bits: u64, -} - -impl DTypeSet { - /// Empty set - pub const EMPTY: Self = Self { bits: 0 }; - - /// All floating point types - pub const FLOATS: Self = Self { - bits: (1 << DType::F64 as u8) - | (1 << DType::F32 as u8) - | (1 << DType::F16 as u8) - | (1 << DType::BF16 as u8) - | (1 << DType::FP8E4M3 as u8) - | (1 << DType::FP8E5M2 as u8), - }; - - /// All signed integer types - pub const SIGNED_INTS: Self = Self { - bits: (1 << DType::I64 as u8) - | (1 << DType::I32 as u8) - | (1 << DType::I16 as u8) - | (1 << DType::I8 as u8), - }; - - /// All unsigned integer types - pub const UNSIGNED_INTS: Self = Self { - bits: (1 << DType::U64 as u8) - | (1 << DType::U32 as u8) - | (1 << DType::U16 as u8) - | (1 << DType::U8 as u8), - }; - - /// All integer types - pub const INTS: Self = Self { - bits: Self::SIGNED_INTS.bits | Self::UNSIGNED_INTS.bits, - }; - - /// All numeric types (floats + ints) - pub const NUMERIC: Self = Self { - bits: Self::FLOATS.bits | Self::INTS.bits, - }; - - /// All complex types - pub const COMPLEX: Self = Self { - bits: (1 << DType::Complex64 as u8) | (1 << DType::Complex128 as u8), - }; - - /// Create a set containing a single dtype - #[inline] - pub const fn single(dtype: DType) -> Self { - Self { - bits: 1 << dtype as u8, - } - } - - /// Check if the set contains a dtype - #[inline] - pub const fn contains(self, dtype: DType) -> bool { - self.bits & (1 << dtype as u8) != 0 - } - - /// Union of two sets - #[inline] - pub const fn union(self, other: Self) -> Self { - Self { - bits: self.bits | other.bits, - } - } - - /// Intersection of two sets - #[inline] - pub const fn intersection(self, other: Self) -> Self { - Self { - bits: self.bits & other.bits, - } - } - - /// Check if set is empty - #[inline] - pub const fn is_empty(self) -> bool { - self.bits == 0 - } -} - #[cfg(test)] mod tests { use super::*; @@ -397,7 +31,6 @@ mod tests { assert_eq!(DType::F16.size_in_bytes(), 2); assert_eq!(DType::I8.size_in_bytes(), 1); assert_eq!(DType::Bool.size_in_bytes(), 1); - // FP8 types are 1 byte assert_eq!(DType::FP8E4M3.size_in_bytes(), 1); assert_eq!(DType::FP8E5M2.size_in_bytes(), 1); } @@ -409,7 +42,6 @@ mod tests { assert!(DType::I32.is_signed_int()); assert!(DType::U32.is_unsigned_int()); assert!(!DType::U32.is_signed()); - // FP8 types are floats assert!(DType::FP8E4M3.is_float()); assert!(DType::FP8E5M2.is_float()); assert!(DType::FP8E4M3.is_signed()); @@ -423,17 +55,14 @@ mod tests { assert!(DTypeSet::INTS.contains(DType::I32)); assert!(DTypeSet::NUMERIC.contains(DType::F32)); assert!(DTypeSet::NUMERIC.contains(DType::I32)); - // FP8 types in FLOATS set assert!(DTypeSet::FLOATS.contains(DType::FP8E4M3)); assert!(DTypeSet::FLOATS.contains(DType::FP8E5M2)); } #[test] fn test_fp8_dtype_values() { - // FP8E4M3: range ~[-448, 448] assert_eq!(DType::FP8E4M3.min_value(), -448.0); assert_eq!(DType::FP8E4M3.max_value(), 448.0); - // FP8E5M2: range ~[-57344, 57344] assert_eq!(DType::FP8E5M2.min_value(), -57344.0); assert_eq!(DType::FP8E5M2.max_value(), 57344.0); } diff --git a/src/dtype/precision.rs b/src/dtype/precision.rs new file mode 100644 index 00000000..f57ccf7c --- /dev/null +++ b/src/dtype/precision.rs @@ -0,0 +1,45 @@ +//! Mixed precision configuration for intermediate calculations. + +/// Compute precision for intermediate calculations with reduced-precision types. +/// +/// When operating on reduced-precision types (F16, BF16, FP8), values are typically +/// converted to a higher precision format for computation, then converted back. +/// This allows trading off speed vs precision. +/// +/// # Precision Comparison +/// +/// | Precision | Decimal Digits | Speed | Use Case | +/// |-----------|----------------|---------|----------| +/// | **F64** | ~15-16 | Slowest | Scientific computing requiring maximum precision | +/// | **F32** | ~7 | Medium | High-precision ML, when BF16 isn't enough | +/// | **BF16** | ~3 | Fastest | ML training/inference (default, industry standard) | +/// +/// # Applicability +/// +/// - **FP8**: Always needs upcasting (8-bit storage, compute in BF16, F32, or F64) +/// - **F16/BF16**: Can optionally upcast to F32/F64 for higher precision +/// - **F32**: Can upcast to F64 for scientific computing +/// - **F64**: No upcasting needed (already highest precision) +/// +/// # Resolution Order +/// +/// `per-operation > tensor-level > client default` +/// +/// # Default +/// +/// BF16 is the default, as it provides good speed with the same dynamic range as F32. +/// This is the industry standard for mixed-precision ML training. +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)] +#[non_exhaustive] +pub enum ComputePrecision { + /// Compute in F64 (highest precision, slowest) + /// Use for: scientific simulations, physics, when F32 precision is insufficient + F64, + /// Compute in F32 (high precision, medium speed) + /// Use for: high-precision ML, numerical algorithms sensitive to rounding + F32, + /// Compute in BF16 (lower precision, fastest, industry standard for ML) + /// Use for: ML training/inference, when speed matters more than precision + #[default] + BF16, +} diff --git a/src/error.rs b/src/error.rs index feddc785..9ef8b08a 100644 --- a/src/error.rs +++ b/src/error.rs @@ -112,6 +112,10 @@ pub enum Error { #[error("CUDA error: {0}")] Cuda(#[from] cudarc::driver::DriverError), + /// Generic message error + #[error("{0}")] + Msg(String), + /// Generic internal error #[error("Internal error: {0}")] Internal(String), @@ -133,6 +137,17 @@ pub enum Error { /// The cargo feature name to enable feature: &'static str, }, + + /// Allocator cannot reset while allocations are still live + #[error("Allocator busy: {active_allocations} allocations still active")] + AllocatorBusy { + /// Number of allocations that are still live + active_allocations: usize, + }, + + /// Allocator is frozen — no new allocations permitted + #[error("Allocator frozen: allocation rejected while frozen")] + AllocatorFrozen, } impl Error { diff --git a/src/lib.rs b/src/lib.rs index c51576d8..a4e31bd4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,8 +27,9 @@ //! ```rust,ignore //! use numr::prelude::*; //! -//! let a = Tensor::::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2])?; -//! let b = Tensor::::from_slice(&[5.0, 6.0, 7.0, 8.0], &[2, 2])?; +//! let device = CpuDevice; +//! let a = Tensor::::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2], &device); +//! let b = Tensor::::from_slice(&[5.0, 6.0, 7.0, 8.0], &[2, 2], &device); //! //! let c = &a + &b; //! let d = a.matmul(&b)?; @@ -94,7 +95,7 @@ pub mod tensor; /// - Backend runtimes: `CpuRuntime`, `CudaRuntime`, `WgpuRuntime` (feature-gated) pub mod prelude { // Core types - pub use crate::dtype::DType; + pub use crate::dtype::{DType, DataType}; pub use crate::error::{Error, Result}; pub use crate::tensor::{Layout, Shape, Strides, Tensor}; @@ -103,19 +104,18 @@ pub mod prelude { // Operation traits (same API across all backends) pub use crate::ops::{ - ActivationOps, AdvancedRandomOps, BinaryOps, CompareOps, ComplexOps, ConditionalOps, - ConvOps, CumulativeOps, DistanceMetric, DistanceOps, IndexingOps, LinalgOps, LogicalOps, - MatmulOps, MeshgridIndexing, MultivariateRandomOps, NormalizationOps, PaddingMode, - QuasiRandomOps, RandomOps, ReduceOps, ScalarOps, ShapeOps, SortingOps, StatisticalOps, - TensorOps, TypeConversionOps, UnaryOps, UtilityOps, + ActivationOps, BinaryOps, CompareOps, ComplexOps, ConditionalOps, ConvOps, CumulativeOps, + DistanceMetric, DistanceOps, IndexingOps, LinalgOps, LogicalOps, MatmulOps, + MeshgridIndexing, NormalizationOps, PaddingMode, ReduceOps, ScalarOps, ShapeOps, + SortingOps, StatisticalOps, TensorOps, TypeConversionOps, UnaryOps, UtilityOps, }; + pub use crate::ops::{AdvancedRandomOps, MultivariateRandomOps, QuasiRandomOps, RandomOps}; // Algorithm traits pub use crate::algorithm::SpecialFunctions; pub use crate::algorithm::fft::{FftAlgorithms, FftDirection, FftNormalization}; // Backend runtimes - #[cfg(feature = "cpu")] pub use crate::runtime::cpu::{CpuClient, CpuDevice, CpuRuntime, ParallelismConfig}; #[cfg(feature = "cuda")] @@ -126,7 +126,9 @@ pub mod prelude { // Sparse tensors (feature-gated) #[cfg(feature = "sparse")] - pub use crate::sparse::{SparseFormat, SparseOps, SparseTensor}; + pub use crate::sparse::Sparse24Ops; + #[cfg(feature = "sparse")] + pub use crate::sparse::{Sparse24Tensor, SparseFormat, SparseOps, SparseTensor}; } /// Default runtime based on enabled features diff --git a/src/ops/common/complex_validation.rs b/src/ops/common/complex_validation.rs index 3a92b15a..ce86a29b 100644 --- a/src/ops/common/complex_validation.rs +++ b/src/ops/common/complex_validation.rs @@ -20,7 +20,10 @@ use crate::tensor::Tensor; /// - `ShapeMismatch` if real and imag have different shapes /// - `DTypeMismatch` if real and imag have different dtypes /// - `UnsupportedDType` if dtype is not F32 or F64 -pub fn validate_make_complex_inputs(real: &Tensor, imag: &Tensor) -> Result<()> { +pub fn validate_make_complex_inputs>( + real: &Tensor, + imag: &Tensor, +) -> Result<()> { // Check shapes match if real.shape() != imag.shape() { return Err(Error::ShapeMismatch { @@ -57,7 +60,7 @@ pub fn validate_make_complex_inputs(real: &Tensor, imag: &Tensor< /// - `DTypeMismatch` if real and imag have different dtypes /// - `UnsupportedDType` if dtype is not F32 #[cfg(feature = "wgpu")] -pub fn validate_make_complex_inputs_f32_only( +pub fn validate_make_complex_inputs_f32_only>( real: &Tensor, imag: &Tensor, ) -> Result<()> { @@ -103,7 +106,7 @@ pub fn validate_make_complex_inputs_f32_only( /// - `ShapeMismatch` if shapes don't match /// - `DTypeMismatch` if real dtype doesn't match complex component dtype /// - `UnsupportedDType` if complex is not Complex64/Complex128 -pub fn validate_complex_real_inputs( +pub fn validate_complex_real_inputs>( complex: &Tensor, real: &Tensor, op: &'static str, @@ -142,7 +145,7 @@ pub fn validate_complex_real_inputs( /// - `DTypeMismatch` if real dtype is not F32 /// - `UnsupportedDType` if complex is not Complex64 or if Complex128 is used #[cfg(feature = "wgpu")] -pub fn validate_complex_real_inputs_f32_only( +pub fn validate_complex_real_inputs_f32_only>( complex: &Tensor, real: &Tensor, op: &'static str, diff --git a/src/ops/cpu/activation.rs b/src/ops/cpu/activation.rs index 886d1c2c..add5eb66 100644 --- a/src/ops/cpu/activation.rs +++ b/src/ops/cpu/activation.rs @@ -1,12 +1,16 @@ //! CPU implementation of activation operations. use crate::error::{Error, Result}; -use crate::ops::{ActivationOps, activation::normalize_softmax_dim}; +use crate::ops::impl_generic::activation::{dropout_impl, log_softmax_impl, softplus_impl}; +use crate::ops::{ + ActivationOps, BinaryOps, CompareOps, ConditionalOps, ScalarOps, UnaryOps, + activation::normalize_softmax_dim, +}; use crate::runtime::cpu::{ CpuClient, CpuRuntime, helpers::{ - ActivationOp, activation_op_impl, dispatch_dtype, elu_impl, ensure_contiguous, - leaky_relu_impl, + ActivationOp, FusedActivationMulOp, activation_op_impl, dispatch_dtype, elu_impl, + ensure_contiguous, fused_activation_mul_impl, leaky_relu_impl, }, kernels, }; @@ -30,6 +34,131 @@ impl ActivationOps for CpuClient { activation_op_impl(self, a, ActivationOp::Gelu, "gelu") } + fn silu_mul( + &self, + a: &Tensor, + b: &Tensor, + ) -> Result> { + fused_activation_mul_impl(self, a, b, FusedActivationMulOp::SiluMul, "silu_mul") + } + + fn gelu_mul( + &self, + a: &Tensor, + b: &Tensor, + ) -> Result> { + fused_activation_mul_impl(self, a, b, FusedActivationMulOp::GeluMul, "gelu_mul") + } + + fn relu_mul( + &self, + a: &Tensor, + b: &Tensor, + ) -> Result> { + fused_activation_mul_impl(self, a, b, FusedActivationMulOp::ReluMul, "relu_mul") + } + + fn sigmoid_mul( + &self, + a: &Tensor, + b: &Tensor, + ) -> Result> { + fused_activation_mul_impl(self, a, b, FusedActivationMulOp::SigmoidMul, "sigmoid_mul") + } + + fn silu_mul_bwd( + &self, + grad: &Tensor, + a: &Tensor, + b: &Tensor, + ) -> Result<(Tensor, Tensor)> { + // silu(a) = a * sigmoid(a) + let silu_a = self.silu(a)?; + let d_b = self.mul(grad, &silu_a)?; + // silu'(x) = sigmoid(x) * (1 + x - silu(x)) + let sigmoid_a = self.sigmoid(a)?; + let one_plus_a = self.add_scalar(a, 1.0)?; + let one_plus_a_minus_silu = self.sub(&one_plus_a, &silu_a)?; + let silu_deriv = self.mul(&sigmoid_a, &one_plus_a_minus_silu)?; + let grad_times_b = self.mul(grad, b)?; + let d_a = self.mul(&grad_times_b, &silu_deriv)?; + Ok((d_a, d_b)) + } + + fn gelu_mul_bwd( + &self, + grad: &Tensor, + a: &Tensor, + b: &Tensor, + ) -> Result<(Tensor, Tensor)> { + let gelu_a = self.gelu(a)?; + let d_b = self.mul(grad, &gelu_a)?; + // gelu'(x) = 0.5*(1+tanh(inner)) + 0.5*x*sech²(inner)*inner' + // inner = sqrt(2/π) * (x + 0.044715*x³), inner' = sqrt(2/π)*(1 + 3*0.044715*x²) + let x_sq = self.mul(a, a)?; + let x_cu = self.mul(&x_sq, a)?; + let coef_x_cu = self.mul_scalar(&x_cu, 0.044715)?; + let inner_arg = self.add(a, &coef_x_cu)?; + let sqrt_2_pi: f64 = 0.7978845608028654; + let inner = self.mul_scalar(&inner_arg, sqrt_2_pi)?; + // Use tanh op directly — avoids exp overflow for low-precision dtypes (F16/FP8) + let tanh_inner = self.tanh(&inner)?; + // term1 = 0.5*(1+tanh(inner)) + let one_plus_tanh = self.add_scalar(&tanh_inner, 1.0)?; + let term1 = self.mul_scalar(&one_plus_tanh, 0.5)?; + // sech²(inner) = 1 - tanh²(inner) + let tanh_sq = self.mul(&tanh_inner, &tanh_inner)?; + let sech_sq = self.add_scalar(&tanh_sq, -1.0)?; + let sech_sq = self.neg(&sech_sq)?; + // inner' = sqrt(2/π) * (1 + 3*0.044715*x²) + let three_coef_x_sq = self.mul_scalar(&x_sq, 3.0 * 0.044715)?; + let inner_deriv_unscaled = self.add_scalar(&three_coef_x_sq, 1.0)?; + let inner_deriv = self.mul_scalar(&inner_deriv_unscaled, sqrt_2_pi)?; + // term2 = 0.5 * x * sech²(inner) * inner' + let x_sech_sq = self.mul(a, &sech_sq)?; + let x_sech_sq_inner_d = self.mul(&x_sech_sq, &inner_deriv)?; + let term2 = self.mul_scalar(&x_sech_sq_inner_d, 0.5)?; + let gelu_deriv = self.add(&term1, &term2)?; + let grad_times_b = self.mul(grad, b)?; + let d_a = self.mul(&grad_times_b, &gelu_deriv)?; + Ok((d_a, d_b)) + } + + fn relu_mul_bwd( + &self, + grad: &Tensor, + a: &Tensor, + b: &Tensor, + ) -> Result<(Tensor, Tensor)> { + let relu_a = self.relu(a)?; + let d_b = self.mul(grad, &relu_a)?; + // relu'(x) = 1 if x > 0, else 0 + let zeros = Tensor::::zeros(a.shape(), a.dtype(), a.device()); + let ones = Tensor::::ones(a.shape(), a.dtype(), a.device()); + let mask = self.gt(a, &zeros)?; + let relu_deriv = self.where_cond(&mask, &ones, &zeros)?; + let grad_times_b = self.mul(grad, b)?; + let d_a = self.mul(&grad_times_b, &relu_deriv)?; + Ok((d_a, d_b)) + } + + fn sigmoid_mul_bwd( + &self, + grad: &Tensor, + a: &Tensor, + b: &Tensor, + ) -> Result<(Tensor, Tensor)> { + let sigmoid_a = self.sigmoid(a)?; + let d_b = self.mul(grad, &sigmoid_a)?; + // sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)) + let one_minus_sig = self.add_scalar(&sigmoid_a, -1.0)?; + let one_minus_sig = self.neg(&one_minus_sig)?; + let sigmoid_deriv = self.mul(&sigmoid_a, &one_minus_sig)?; + let grad_times_b = self.mul(grad, b)?; + let d_a = self.mul(&grad_times_b, &sigmoid_deriv)?; + Ok((d_a, d_b)) + } + fn leaky_relu( &self, a: &Tensor, @@ -68,8 +197,8 @@ impl ActivationOps for CpuClient { if dim_idx == ndim - 1 { // Simple case: softmax over last dimension - let a_ptr = a_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a_contig.ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -84,8 +213,8 @@ impl ActivationOps for CpuClient { } else { // General case: softmax over non-last dimension // Pre-allocate buffer outside loops to avoid repeated allocations - let a_ptr = a_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a_contig.ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -102,46 +231,262 @@ impl ActivationOps for CpuClient { Ok(out) } + + fn softmax_bwd( + &self, + grad: &Tensor, + output: &Tensor, + dim: isize, + ) -> Result> { + let dtype = grad.dtype(); + let ndim = grad.ndim(); + let dim_idx = + normalize_softmax_dim(ndim, dim).ok_or(Error::InvalidDimension { dim, ndim })?; + + let grad_contig = ensure_contiguous(grad); + let output_contig = ensure_contiguous(output); + let d_input = Tensor::::empty(grad.shape(), dtype, &self.device); + + let shape = grad.shape(); + let outer_size: usize = shape[..dim_idx].iter().product(); + let dim_size = shape[dim_idx]; + let inner_size: usize = shape[dim_idx + 1..].iter().product(); + + if dim_idx == ndim - 1 { + // Last dim: use fused SIMD kernel + let g_ptr = grad_contig.ptr(); + let o_ptr = output_contig.ptr(); + let d_ptr = d_input.ptr(); + + dispatch_dtype!(dtype, T => { + unsafe { + kernels::softmax_bwd_kernel::( + g_ptr as *const T, + o_ptr as *const T, + d_ptr as *mut T, + outer_size, + dim_size, + ); + } + }, "softmax_bwd"); + } else { + // Non-last dim: strided access pattern + let g_ptr = grad_contig.ptr(); + let o_ptr = output_contig.ptr(); + let d_ptr = d_input.ptr(); + + dispatch_dtype!(dtype, T => { + unsafe { + softmax_bwd_non_last_dim::( + g_ptr as *const T, + o_ptr as *const T, + d_ptr as *mut T, + outer_size, + dim_size, + inner_size, + ); + } + }, "softmax_bwd"); + } + + Ok(d_input) + } + + fn softplus(&self, a: &Tensor) -> Result> { + softplus_impl(self, a) + } + + fn log_softmax(&self, a: &Tensor, dim: isize) -> Result> { + log_softmax_impl(self, a, dim) + } + + fn dropout( + &self, + a: &Tensor, + p: f64, + training: bool, + ) -> Result> { + dropout_impl(self, a, p, training) + } } -unsafe fn softmax_non_last_dim( - a_ptr: *const T, - out_ptr: *mut T, +#[cfg(test)] +mod tests { + use super::*; + use crate::ops::ActivationOps; + use crate::runtime::cpu::CpuDevice; + + #[test] + fn test_log_softmax_basic() { + let device = CpuDevice::new(); + let client = CpuClient::new(device.clone()); + + let input = Tensor::::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device); + let result = client.log_softmax(&input, -1).unwrap(); + let data: Vec = result.to_vec(); + + // log_softmax should sum to something reasonable + // exp(log_softmax) should sum to 1 + let exp_sum: f32 = data.iter().map(|x| x.exp()).sum(); + assert!((exp_sum - 1.0).abs() < 1e-5); + + // Values should be negative (log of probability) + for &v in &data { + assert!(v < 0.0); + } + } + + #[test] + fn test_log_softmax_2d() { + let device = CpuDevice::new(); + let client = CpuClient::new(device.clone()); + + let input = + Tensor::::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], &device); + let result = client.log_softmax(&input, -1).unwrap(); + let data: Vec = result.to_vec(); + + // Each row should independently sum (in exp space) to 1 + let row1_sum: f32 = data[0..3].iter().map(|x| x.exp()).sum(); + let row2_sum: f32 = data[3..6].iter().map(|x| x.exp()).sum(); + assert!((row1_sum - 1.0).abs() < 1e-5); + assert!((row2_sum - 1.0).abs() < 1e-5); + } + + #[test] + fn test_dropout_training() { + let device = CpuDevice::new(); + let client = CpuClient::new(device.clone()); + + let input = Tensor::::ones(&[1000], crate::dtype::DType::F32, &device); + let result = client.dropout(&input, 0.5, true).unwrap(); + let data: Vec = result.to_vec(); + + // Some values should be 0 (dropped), others should be 2.0 (scaled by 1/(1-0.5)) + let zeros = data.iter().filter(|&&v| v == 0.0).count(); + let scaled = data.iter().filter(|&&v| (v - 2.0).abs() < 1e-5).count(); + + // With p=0.5, roughly half should be dropped (allow wide margin for randomness) + assert!(zeros > 200, "too few zeros: {zeros}"); + assert!(zeros < 800, "too many zeros: {zeros}"); + assert_eq!(zeros + scaled, 1000); + } + + #[test] + fn test_dropout_inference() { + let device = CpuDevice::new(); + let client = CpuClient::new(device.clone()); + + let input = Tensor::::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device); + let result = client.dropout(&input, 0.5, false).unwrap(); + let data: Vec = result.to_vec(); + + // During inference, dropout is identity + assert!((data[0] - 1.0).abs() < 1e-6); + assert!((data[1] - 2.0).abs() < 1e-6); + assert!((data[2] - 3.0).abs() < 1e-6); + } + + #[test] + fn test_dropout_p_zero() { + let device = CpuDevice::new(); + let client = CpuClient::new(device.clone()); + + let input = Tensor::::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device); + let result = client.dropout(&input, 0.0, true).unwrap(); + let data: Vec = result.to_vec(); + + // p=0 means no dropout + assert!((data[0] - 1.0).abs() < 1e-6); + assert!((data[1] - 2.0).abs() < 1e-6); + assert!((data[2] - 3.0).abs() < 1e-6); + } + + #[test] + fn test_dropout_p_one() { + let device = CpuDevice::new(); + let client = CpuClient::new(device.clone()); + + let input = Tensor::::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device); + let result = client.dropout(&input, 1.0, true).unwrap(); + let data: Vec = result.to_vec(); + + // p=1 means all dropped + for &v in &data { + assert!((v).abs() < 1e-6); + } + } +} + +/// Softmax backward for non-last dimension (strided access pattern). +/// +/// d_input = output * (grad - dot), where dot = sum(grad * output) along dim. +unsafe fn softmax_bwd_non_last_dim( + grad: *const T, + output: *const T, + d_input: *mut T, outer_size: usize, dim_size: usize, inner_size: usize, ) { unsafe { - // Pre-allocate reusable buffer for softmax computation - let mut slice = vec![0.0f64; dim_size]; - for outer in 0..outer_size { for inner in 0..inner_size { - // Elements are at: outer * dim_size * inner_size + d * inner_size + inner let base_idx = outer * dim_size * inner_size + inner; let stride = inner_size; - // Read slice into buffer - for (d, slot) in slice.iter_mut().enumerate() { + // Pass 1: dot = sum(grad * output) along dim + let mut dot = 0.0f64; + for d in 0..dim_size { let idx = base_idx + d * stride; - *slot = (*a_ptr.add(idx)).to_f64(); + dot += (*grad.add(idx)).to_f64() * (*output.add(idx)).to_f64(); } - // Compute softmax with numerical stability - let max_val = slice.iter().cloned().fold(f64::NEG_INFINITY, f64::max); - let mut exp_sum = 0.0f64; - for val in &mut slice { - *val = (*val - max_val).exp(); - exp_sum += *val; + // Pass 2: d_input = output * (grad - dot) + for d in 0..dim_size { + let idx = base_idx + d * stride; + let g = (*grad.add(idx)).to_f64(); + let o = (*output.add(idx)).to_f64(); + *d_input.add(idx) = T::from_f64(o * (g - dot)); } + } + } + } +} + +unsafe fn softmax_non_last_dim( + a_ptr: *const T, + out_ptr: *mut T, + outer_size: usize, + dim_size: usize, + inner_size: usize, +) { + unsafe { + for outer in 0..outer_size { + for inner in 0..inner_size { + let base_idx = outer * dim_size * inner_size + inner; + let stride = inner_size; - // Handle edge case: avoid division by zero - let inv_sum = if exp_sum > 0.0 { 1.0 / exp_sum } else { 0.0 }; + // Pass 1: Online max + sum (reads strided input once) + let mut max_val = (*a_ptr.add(base_idx)).to_f64(); + let mut sum = 1.0f64; + for d in 1..dim_size { + let idx = base_idx + d * stride; + let val = (*a_ptr.add(idx)).to_f64(); + if val > max_val { + sum = sum * (max_val - val).exp() + 1.0; + max_val = val; + } else { + sum += (val - max_val).exp(); + } + } - // Write normalized values back - for (d, &val) in slice.iter().enumerate() { + // Pass 2: exp(x - max) / sum (reads input, writes output) + let inv_sum = if sum > 0.0 { 1.0 / sum } else { 0.0 }; + for d in 0..dim_size { let idx = base_idx + d * stride; - *out_ptr.add(idx) = T::from_f64(val * inv_sum); + let val = (*a_ptr.add(idx)).to_f64(); + *out_ptr.add(idx) = T::from_f64((val - max_val).exp() * inv_sum); } } } diff --git a/src/ops/cpu/advanced_random.rs b/src/ops/cpu/advanced_random.rs index d7c8b466..e5b091b2 100644 --- a/src/ops/cpu/advanced_random.rs +++ b/src/ops/cpu/advanced_random.rs @@ -29,7 +29,7 @@ impl AdvancedRandomOps for CpuClient { return Ok(out); } - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -61,7 +61,7 @@ impl AdvancedRandomOps for CpuClient { return Ok(out); } - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -93,7 +93,7 @@ impl AdvancedRandomOps for CpuClient { return Ok(out); } - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -125,7 +125,7 @@ impl AdvancedRandomOps for CpuClient { return Ok(out); } - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -157,7 +157,7 @@ impl AdvancedRandomOps for CpuClient { return Ok(out); } - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -189,7 +189,7 @@ impl AdvancedRandomOps for CpuClient { return Ok(out); } - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -220,7 +220,7 @@ impl AdvancedRandomOps for CpuClient { return Ok(out); } - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -251,7 +251,7 @@ impl AdvancedRandomOps for CpuClient { return Ok(out); } - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { diff --git a/src/ops/cpu/binary.rs b/src/ops/cpu/binary.rs index 0a01fa90..e68dd7e0 100644 --- a/src/ops/cpu/binary.rs +++ b/src/ops/cpu/binary.rs @@ -4,7 +4,7 @@ use crate::error::Result; use crate::ops::BinaryOps; use crate::runtime::cpu::{ CpuClient, CpuRuntime, - helpers::{BinaryOp, binary_op_impl}, + helpers::{BinaryOp, binary_op_impl, fused_add_mul_impl, fused_mul_add_impl}, }; use crate::tensor::Tensor; @@ -49,4 +49,22 @@ impl BinaryOps for CpuClient { fn atan2(&self, y: &Tensor, x: &Tensor) -> Result> { binary_op_impl(self, BinaryOp::Atan2, y, x, "atan2") } + + fn fused_mul_add( + &self, + a: &Tensor, + b: &Tensor, + c: &Tensor, + ) -> Result> { + fused_mul_add_impl(self, a, b, c) + } + + fn fused_add_mul( + &self, + a: &Tensor, + b: &Tensor, + c: &Tensor, + ) -> Result> { + fused_add_mul_impl(self, a, b, c) + } } diff --git a/src/ops/cpu/complex.rs b/src/ops/cpu/complex.rs index 06d4b9bf..b942483e 100644 --- a/src/ops/cpu/complex.rs +++ b/src/ops/cpu/complex.rs @@ -27,8 +27,8 @@ impl ComplexOps for CpuClient { return Ok(out); } - let a_ptr = a_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a_contig.ptr(); + let out_ptr = out.ptr(); let chunk_size = self.chunk_size_hint(); match dtype { @@ -82,8 +82,8 @@ impl ComplexOps for CpuClient { return Ok(out); } - let a_ptr = a_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a_contig.ptr(); + let out_ptr = out.ptr(); let chunk_size = self.chunk_size_hint(); match dtype { @@ -137,8 +137,8 @@ impl ComplexOps for CpuClient { return Ok(out); } - let a_ptr = a_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a_contig.ptr(); + let out_ptr = out.ptr(); let chunk_size = self.chunk_size_hint(); match dtype { @@ -185,8 +185,8 @@ impl ComplexOps for CpuClient { return Ok(out); } - let a_ptr = a_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a_contig.ptr(); + let out_ptr = out.ptr(); match dtype { DType::F32 => { @@ -230,8 +230,8 @@ impl ComplexOps for CpuClient { return Ok(out); } - let a_ptr = a_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a_contig.ptr(); + let out_ptr = out.ptr(); match dtype { DType::Complex64 => { @@ -287,9 +287,9 @@ impl ComplexOps for CpuClient { return Ok(out); } - let real_ptr = real_contig.storage().ptr(); - let imag_ptr = imag_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let real_ptr = real_contig.ptr(); + let imag_ptr = imag_contig.ptr(); + let out_ptr = out.ptr(); let chunk_size = self.chunk_size_hint(); match input_dtype { @@ -341,9 +341,9 @@ impl ComplexOps for CpuClient { return Ok(out); } - let complex_ptr = complex_contig.storage().ptr(); - let real_ptr = real_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let complex_ptr = complex_contig.ptr(); + let real_ptr = real_contig.ptr(); + let out_ptr = out.ptr(); let chunk_size = self.chunk_size_hint(); match dtype { @@ -395,9 +395,9 @@ impl ComplexOps for CpuClient { return Ok(out); } - let complex_ptr = complex_contig.storage().ptr(); - let real_ptr = real_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let complex_ptr = complex_contig.ptr(); + let real_ptr = real_contig.ptr(); + let out_ptr = out.ptr(); let chunk_size = self.chunk_size_hint(); match dtype { diff --git a/src/ops/cpu/conditional.rs b/src/ops/cpu/conditional.rs index 69a4be23..7aaa8040 100644 --- a/src/ops/cpu/conditional.rs +++ b/src/ops/cpu/conditional.rs @@ -43,7 +43,7 @@ impl ConditionalOps for CpuClient { })?; let out = Tensor::::empty(&out_shape, dtype, &self.device); - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); // Fast path: all same shape, use simple kernel if cond.shape() == x.shape() && x.shape() == y.shape() { @@ -51,9 +51,9 @@ impl ConditionalOps for CpuClient { let x_contig = ensure_contiguous(x); let y_contig = ensure_contiguous(y); - let cond_ptr = cond_contig.storage().ptr(); - let x_ptr = x_contig.storage().ptr(); - let y_ptr = y_contig.storage().ptr(); + let cond_ptr = cond_contig.ptr(); + let x_ptr = x_contig.ptr(); + let y_ptr = y_contig.ptr(); let numel = x.numel(); // Double dispatch: cond dtype and value dtype @@ -93,9 +93,9 @@ impl ConditionalOps for CpuClient { let x_broadcast = x.broadcast_to(&out_shape)?; let y_broadcast = y.broadcast_to(&out_shape)?; - let cond_ptr = cond_broadcast.storage().ptr(); - let x_ptr = x_broadcast.storage().ptr(); - let y_ptr = y_broadcast.storage().ptr(); + let cond_ptr = cond_broadcast.ptr(); + let x_ptr = x_broadcast.ptr(); + let y_ptr = y_broadcast.ptr(); // Get strides from broadcast layouts let cond_strides: Vec = cond_broadcast.layout().strides().to_vec(); diff --git a/src/ops/cpu/conv.rs b/src/ops/cpu/conv.rs index a5887a5b..1b9765b7 100644 --- a/src/ops/cpu/conv.rs +++ b/src/ops/cpu/conv.rs @@ -144,10 +144,10 @@ impl ConvOps for CpuClient { &self.device, ); - let input_ptr = input.storage().ptr(); - let weight_ptr = weight.storage().ptr(); - let bias_ptr = bias.as_ref().map(|b| b.storage().ptr()); - let output_ptr = output.storage().ptr(); + let input_ptr = input.ptr(); + let weight_ptr = weight.ptr(); + let bias_ptr = bias.as_ref().map(|b| b.ptr()); + let output_ptr = output.ptr(); dispatch_conv!( dtype, conv1d, input_ptr, weight_ptr, bias_ptr, output_ptr, params @@ -203,10 +203,10 @@ impl ConvOps for CpuClient { &self.device, ); - let input_ptr = input.storage().ptr(); - let weight_ptr = weight.storage().ptr(); - let bias_ptr = bias.as_ref().map(|b| b.storage().ptr()); - let output_ptr = output.storage().ptr(); + let input_ptr = input.ptr(); + let weight_ptr = weight.ptr(); + let bias_ptr = bias.as_ref().map(|b| b.ptr()); + let output_ptr = output.ptr(); dispatch_conv!( dtype, conv2d, input_ptr, weight_ptr, bias_ptr, output_ptr, params @@ -260,10 +260,10 @@ impl ConvOps for CpuClient { &self.device, ); - let input_ptr = input.storage().ptr(); - let weight_ptr = weight.storage().ptr(); - let bias_ptr = bias.as_ref().map(|b| b.storage().ptr()); - let output_ptr = output.storage().ptr(); + let input_ptr = input.ptr(); + let weight_ptr = weight.ptr(); + let bias_ptr = bias.as_ref().map(|b| b.ptr()); + let output_ptr = output.ptr(); dispatch_conv!( dtype, diff --git a/src/ops/cpu/distance.rs b/src/ops/cpu/distance.rs index e1619e76..198933b6 100644 --- a/src/ops/cpu/distance.rs +++ b/src/ops/cpu/distance.rs @@ -2,6 +2,8 @@ use crate::dtype::DType; use crate::error::{Error, Result}; +#[cfg(feature = "fp8")] +use crate::ops::TypeConversionOps; use crate::ops::distance_common::*; use crate::ops::{DistanceMetric, DistanceOps}; use crate::runtime::cpu::{CpuClient, CpuRuntime, helpers::ensure_contiguous, kernels}; @@ -72,9 +74,29 @@ impl DistanceOps for CpuClient { let y = ensure_contiguous(y); let out = Tensor::::empty(&[n, m], dtype, &self.device); - let x_ptr = x.storage().ptr(); - let y_ptr = y.storage().ptr(); - let out_ptr = out.storage().ptr(); + let x_ptr = x.ptr(); + let y_ptr = y.ptr(); + let out_ptr = out.ptr(); + + // FP8 types: compute in F32, then cast result back + #[cfg(feature = "fp8")] + if dtype == DType::FP8E4M3 || dtype == DType::FP8E5M2 { + let x_f32 = self.cast(&x, DType::F32)?; + let y_f32 = self.cast(&y, DType::F32)?; + let out_f32 = Tensor::::empty(&[n, m], DType::F32, &self.device); + unsafe { + kernels::cdist_kernel::( + x_f32.ptr() as *const f32, + y_f32.ptr() as *const f32, + out_f32.ptr() as *mut f32, + n, + m, + d, + metric, + ); + } + return self.cast(&out_f32, dtype); + } dispatch_float_dtype!(dtype, T => { unsafe { @@ -112,8 +134,25 @@ impl DistanceOps for CpuClient { let x = ensure_contiguous(x); let out = Tensor::::empty(&[out_size], dtype, &self.device); - let x_ptr = x.storage().ptr(); - let out_ptr = out.storage().ptr(); + let x_ptr = x.ptr(); + let out_ptr = out.ptr(); + + // FP8 types: compute in F32, then cast result back + #[cfg(feature = "fp8")] + if dtype == DType::FP8E4M3 || dtype == DType::FP8E5M2 { + let x_f32 = self.cast(&x, DType::F32)?; + let out_f32 = Tensor::::empty(&[out_size], DType::F32, &self.device); + unsafe { + kernels::pdist_kernel::( + x_f32.ptr() as *const f32, + out_f32.ptr() as *mut f32, + n, + d, + metric, + ); + } + return self.cast(&out_f32, dtype); + } dispatch_float_dtype!(dtype, T => { unsafe { @@ -151,8 +190,8 @@ impl DistanceOps for CpuClient { let condensed = ensure_contiguous(condensed); let out = Tensor::::empty(&[n, n], dtype, &self.device); - let cond_ptr = condensed.storage().ptr(); - let out_ptr = out.storage().ptr(); + let cond_ptr = condensed.ptr(); + let out_ptr = out.ptr(); dispatch_float_dtype!(dtype, T => { unsafe { @@ -191,8 +230,8 @@ impl DistanceOps for CpuClient { let out_size = n * (n - 1) / 2; let out = Tensor::::empty(&[out_size], dtype, &self.device); - let sq_ptr = square.storage().ptr(); - let out_ptr = out.storage().ptr(); + let sq_ptr = square.ptr(); + let out_ptr = out.ptr(); dispatch_float_dtype!(dtype, T => { unsafe { diff --git a/src/ops/cpu/fp8_matmul.rs b/src/ops/cpu/fp8_matmul.rs new file mode 100644 index 00000000..914d7b1d --- /dev/null +++ b/src/ops/cpu/fp8_matmul.rs @@ -0,0 +1,198 @@ +//! CPU implementation of FP8 matrix multiplication operations. +//! +//! Fused kernel: reads FP8, converts to F32 inline during accumulation, +//! applies scaling, and writes output in the target dtype. No intermediate +//! tensor allocations. + +use crate::dtype::{DType, FP8E4M3, FP8E5M2}; +use crate::error::{Error, Result}; +use crate::ops::Fp8MatmulOps; +use crate::runtime::cpu::{CpuClient, CpuRuntime}; +use crate::tensor::Tensor; + +/// Validate FP8 matmul arguments. +fn validate_fp8_matmul( + a: &Tensor, + b: &Tensor, + expected_a_dtype: DType, + expected_b_dtype: DType, + out_dtype: DType, +) -> Result<(Vec, usize, usize, usize, usize)> { + if a.dtype() != expected_a_dtype { + return Err(Error::DTypeMismatch { + lhs: a.dtype(), + rhs: expected_a_dtype, + }); + } + if b.dtype() != expected_b_dtype { + return Err(Error::DTypeMismatch { + lhs: b.dtype(), + rhs: expected_b_dtype, + }); + } + match out_dtype { + DType::F32 | DType::F16 | DType::BF16 => {} + _ => { + return Err(Error::UnsupportedDType { + dtype: out_dtype, + op: "fp8_matmul output", + }); + } + } + let a_shape = a.shape(); + let b_shape = b.shape(); + if a_shape.len() < 2 || b_shape.len() < 2 { + return Err(Error::ShapeMismatch { + expected: a_shape.to_vec(), + got: b_shape.to_vec(), + }); + } + let m = a_shape[a_shape.len() - 2]; + let k = a_shape[a_shape.len() - 1]; + let k_b = b_shape[b_shape.len() - 2]; + let n = b_shape[b_shape.len() - 1]; + if k != k_b { + return Err(Error::ShapeMismatch { + expected: a_shape.to_vec(), + got: b_shape.to_vec(), + }); + } + + let out_shape = + crate::ops::matmul_output_shape(a_shape, b_shape).ok_or(Error::ShapeMismatch { + expected: a_shape.to_vec(), + got: b_shape.to_vec(), + })?; + + let batch_size: usize = out_shape + .iter() + .take(out_shape.len().saturating_sub(2)) + .product(); + let batch_size = batch_size.max(1); + + Ok((out_shape, batch_size, m, k, n)) +} + +/// Fused FP8 matmul kernel: converts FP8→F32 inline during multiply-accumulate, +/// applies combined scale, writes output directly in target dtype. +/// +/// `convert_a` and `convert_b` are FP8→f32 conversion functions. +fn fused_fp8_matmul_kernel( + a_ptr: *const u8, + b_ptr: *const u8, + out_ptr: u64, + convert_a: fn(u8) -> f32, + convert_b: fn(u8) -> f32, + combined_scale: f32, + out_dtype: DType, + batch_size: usize, + m: usize, + k: usize, + n: usize, +) { + let a_batch_stride = m * k; + let b_batch_stride = k * n; + let out_batch_stride = m * n; + + for batch in 0..batch_size { + let a_base = unsafe { a_ptr.add(batch * a_batch_stride) }; + let b_base = unsafe { b_ptr.add(batch * b_batch_stride) }; + + for i in 0..m { + for j in 0..n { + let mut acc: f32 = 0.0; + for p in 0..k { + let a_val = convert_a(unsafe { *a_base.add(i * k + p) }); + let b_val = convert_b(unsafe { *b_base.add(p * n + j) }); + acc += a_val * b_val; + } + acc *= combined_scale; + + let out_idx = batch * out_batch_stride + i * n + j; + match out_dtype { + DType::F32 => unsafe { + let ptr = out_ptr as *mut f32; + *ptr.add(out_idx) = acc; + }, + #[cfg(feature = "f16")] + DType::F16 => unsafe { + let ptr = out_ptr as *mut half::f16; + *ptr.add(out_idx) = half::f16::from_f32(acc); + }, + #[cfg(feature = "f16")] + DType::BF16 => unsafe { + let ptr = out_ptr as *mut half::bf16; + *ptr.add(out_idx) = half::bf16::from_f32(acc); + }, + _ => {} // validated above + } + } + } + } +} + +impl Fp8MatmulOps for CpuClient { + fn fp8_matmul( + &self, + a: &Tensor, + b: &Tensor, + scale_a: f32, + scale_b: f32, + out_dtype: DType, + ) -> Result> { + let (out_shape, batch_size, m, k, n) = + validate_fp8_matmul(a, b, DType::FP8E4M3, DType::FP8E4M3, out_dtype)?; + + let a_contig = crate::runtime::cpu::helpers::ensure_contiguous(a); + let b_contig = crate::runtime::cpu::helpers::ensure_contiguous(b); + let out = Tensor::::empty(&out_shape, out_dtype, &self.device); + + fused_fp8_matmul_kernel( + a_contig.ptr() as *const u8, + b_contig.ptr() as *const u8, + out.ptr(), + |byte| FP8E4M3::from_bits(byte).to_f32(), + |byte| FP8E4M3::from_bits(byte).to_f32(), + scale_a * scale_b, + out_dtype, + batch_size, + m, + k, + n, + ); + + Ok(out) + } + + fn fp8_matmul_e5m2( + &self, + a: &Tensor, + b: &Tensor, + scale_a: f32, + scale_b: f32, + out_dtype: DType, + ) -> Result> { + let (out_shape, batch_size, m, k, n) = + validate_fp8_matmul(a, b, DType::FP8E5M2, DType::FP8E4M3, out_dtype)?; + + let a_contig = crate::runtime::cpu::helpers::ensure_contiguous(a); + let b_contig = crate::runtime::cpu::helpers::ensure_contiguous(b); + let out = Tensor::::empty(&out_shape, out_dtype, &self.device); + + fused_fp8_matmul_kernel( + a_contig.ptr() as *const u8, + b_contig.ptr() as *const u8, + out.ptr(), + |byte| FP8E5M2::from_bits(byte).to_f32(), + |byte| FP8E4M3::from_bits(byte).to_f32(), + scale_a * scale_b, + out_dtype, + batch_size, + m, + k, + n, + ); + + Ok(out) + } +} diff --git a/src/ops/cpu/gemm_epilogue.rs b/src/ops/cpu/gemm_epilogue.rs new file mode 100644 index 00000000..0913da8f --- /dev/null +++ b/src/ops/cpu/gemm_epilogue.rs @@ -0,0 +1,350 @@ +//! CPU implementation of GEMM epilogue operations. + +use crate::dtype::Element; +use crate::error::{Error, Result}; +use crate::ops::{GemmActivation, GemmEpilogueOps}; +use crate::ops::{matmul_bias_output_shape, validate_matmul_bias_dtypes}; +use crate::runtime::cpu::helpers::{dispatch_dtype, ensure_contiguous}; +use crate::runtime::cpu::kernels::{ + matmul_bias_activation_bwd_kernel, matmul_bias_activation_kernel, matmul_bias_residual_kernel, +}; +use crate::runtime::cpu::{CpuClient, CpuRuntime}; +use crate::tensor::Tensor; + +impl GemmEpilogueOps for CpuClient { + fn matmul_bias_activation( + &self, + a: &Tensor, + b: &Tensor, + bias: &Tensor, + activation: GemmActivation, + ) -> Result> { + let dtype = validate_matmul_bias_dtypes(a.dtype(), b.dtype(), bias.dtype())?; + + let out_shape = matmul_bias_output_shape(a.shape(), b.shape(), bias.shape()).ok_or( + Error::ShapeMismatch { + expected: a.shape().to_vec(), + got: b.shape().to_vec(), + }, + )?; + + let a_shape = a.shape(); + let b_shape = b.shape(); + let m = if a_shape.len() >= 2 { + a_shape[a_shape.len() - 2] + } else { + 1 + }; + let k = a_shape[a_shape.len() - 1]; + let n = b_shape[b_shape.len() - 1]; + + let a_contig = ensure_contiguous(a); + let b_contig = ensure_contiguous(b); + let bias_contig = ensure_contiguous(bias); + + let batch_size: usize = out_shape + .iter() + .take(out_shape.len().saturating_sub(2)) + .product::() + .max(1); + + let out = Tensor::::empty(&out_shape, dtype, &self.device); + + let a_ptr = a_contig.ptr(); + let b_ptr = b_contig.ptr(); + let bias_ptr = bias_contig.ptr(); + let out_ptr = out.ptr(); + + let lda = k; + let ldb = n; + let ldc = n; + + dispatch_dtype!(dtype, T => { + #[cfg(feature = "rayon")] + { + use rayon::prelude::*; + + if batch_size > 1 { + let min_len = self.rayon_min_len(); + self.install_parallelism(|| { + (0..batch_size) + .into_par_iter() + .with_min_len(min_len) + .for_each(|batch| unsafe { + matmul_bias_activation_kernel::( + (a_ptr as *const T).add(batch * m * k), + (b_ptr as *const T).add(batch * k * n), + bias_ptr as *const T, + (out_ptr as *mut T).add(batch * m * n), + m, n, k, lda, ldb, ldc, + activation, + ); + }); + }); + } else { + unsafe { + matmul_bias_activation_kernel::( + a_ptr as *const T, + b_ptr as *const T, + bias_ptr as *const T, + out_ptr as *mut T, + m, n, k, lda, ldb, ldc, + activation, + ); + } + } + } + + #[cfg(not(feature = "rayon"))] + unsafe { + for batch in 0..batch_size { + matmul_bias_activation_kernel::( + (a_ptr as *const T).add(batch * m * k), + (b_ptr as *const T).add(batch * k * n), + bias_ptr as *const T, + (out_ptr as *mut T).add(batch * m * n), + m, n, k, lda, ldb, ldc, + activation, + ); + } + } + }, "matmul_bias_activation"); + + Ok(out) + } + + fn matmul_bias_residual( + &self, + a: &Tensor, + b: &Tensor, + bias: &Tensor, + residual: &Tensor, + ) -> Result> { + let dtype = validate_matmul_bias_dtypes(a.dtype(), b.dtype(), bias.dtype())?; + if residual.dtype() != dtype { + return Err(Error::DTypeMismatch { + lhs: dtype, + rhs: residual.dtype(), + }); + } + + let out_shape = matmul_bias_output_shape(a.shape(), b.shape(), bias.shape()).ok_or( + Error::ShapeMismatch { + expected: a.shape().to_vec(), + got: b.shape().to_vec(), + }, + )?; + + // Validate residual shape matches output shape + if residual.shape() != out_shape.as_slice() { + return Err(Error::ShapeMismatch { + expected: out_shape.clone(), + got: residual.shape().to_vec(), + }); + } + + let a_shape = a.shape(); + let b_shape = b.shape(); + let m = if a_shape.len() >= 2 { + a_shape[a_shape.len() - 2] + } else { + 1 + }; + let k = a_shape[a_shape.len() - 1]; + let n = b_shape[b_shape.len() - 1]; + + let a_contig = ensure_contiguous(a); + let b_contig = ensure_contiguous(b); + let bias_contig = ensure_contiguous(bias); + let residual_contig = ensure_contiguous(residual); + + let batch_size: usize = out_shape + .iter() + .take(out_shape.len().saturating_sub(2)) + .product::() + .max(1); + + let out = Tensor::::empty(&out_shape, dtype, &self.device); + + let a_ptr = a_contig.ptr(); + let b_ptr = b_contig.ptr(); + let bias_ptr = bias_contig.ptr(); + let res_ptr = residual_contig.ptr(); + let out_ptr = out.ptr(); + + let lda = k; + let ldb = n; + let ldc = n; + + dispatch_dtype!(dtype, T => { + #[cfg(feature = "rayon")] + { + use rayon::prelude::*; + + if batch_size > 1 { + let min_len = self.rayon_min_len(); + self.install_parallelism(|| { + (0..batch_size) + .into_par_iter() + .with_min_len(min_len) + .for_each(|batch| unsafe { + matmul_bias_residual_kernel::( + (a_ptr as *const T).add(batch * m * k), + (b_ptr as *const T).add(batch * k * n), + bias_ptr as *const T, + (res_ptr as *const T).add(batch * m * n), + (out_ptr as *mut T).add(batch * m * n), + m, n, k, lda, ldb, ldc, + ); + }); + }); + } else { + unsafe { + matmul_bias_residual_kernel::( + a_ptr as *const T, + b_ptr as *const T, + bias_ptr as *const T, + res_ptr as *const T, + out_ptr as *mut T, + m, n, k, lda, ldb, ldc, + ); + } + } + } + + #[cfg(not(feature = "rayon"))] + unsafe { + for batch in 0..batch_size { + matmul_bias_residual_kernel::( + (a_ptr as *const T).add(batch * m * k), + (b_ptr as *const T).add(batch * k * n), + bias_ptr as *const T, + (res_ptr as *const T).add(batch * m * n), + (out_ptr as *mut T).add(batch * m * n), + m, n, k, lda, ldb, ldc, + ); + } + } + }, "matmul_bias_residual"); + + Ok(out) + } + + fn matmul_bias_activation_bwd( + &self, + grad: &Tensor, + a: &Tensor, + b: &Tensor, + bias: &Tensor, + activation: GemmActivation, + ) -> Result<(Tensor, Tensor, Tensor)> { + let dtype = validate_matmul_bias_dtypes(a.dtype(), b.dtype(), bias.dtype())?; + if grad.dtype() != dtype { + return Err(Error::DTypeMismatch { + lhs: dtype, + rhs: grad.dtype(), + }); + } + + let a_shape = a.shape(); + let b_shape = b.shape(); + let m = if a_shape.len() >= 2 { + a_shape[a_shape.len() - 2] + } else { + 1 + }; + let k = a_shape[a_shape.len() - 1]; + let n = b_shape[b_shape.len() - 1]; + + let a_contig = ensure_contiguous(a); + let b_contig = ensure_contiguous(b); + let bias_contig = ensure_contiguous(bias); + let grad_contig = ensure_contiguous(grad); + + let batch_size: usize = a_shape + .iter() + .take(a_shape.len().saturating_sub(2)) + .product::() + .max(1); + + // Output gradients + let d_a = Tensor::::empty(a_shape, dtype, &self.device); + let d_b = Tensor::::zeros(b_shape, dtype, &self.device); + + // d_bias is always [N] — we need to sum across batches + let d_bias_full = Tensor::::empty(&[n], dtype, &self.device); + + let a_ptr = a_contig.ptr(); + let b_ptr = b_contig.ptr(); + let bias_ptr = bias_contig.ptr(); + let grad_ptr = grad_contig.ptr(); + let d_a_ptr = d_a.ptr(); + let d_b_ptr = d_b.ptr(); + let d_bias_ptr = d_bias_full.ptr(); + + let lda = k; + let ldb = n; + let ld_grad = n; + + dispatch_dtype!(dtype, T => { + if batch_size == 1 { + unsafe { + matmul_bias_activation_bwd_kernel::( + grad_ptr as *const T, + a_ptr as *const T, + b_ptr as *const T, + bias_ptr as *const T, + d_a_ptr as *mut T, + d_b_ptr as *mut T, + d_bias_ptr as *mut T, + m, n, k, lda, ldb, ld_grad, + activation, + ); + } + } else { + // For batched: compute per-batch, accumulate d_b and d_bias + // Zero out d_b and d_bias first + unsafe { + for i in 0..k * n { + *(d_b_ptr as *mut T).add(i) = T::zero(); + } + for j in 0..n { + *(d_bias_ptr as *mut T).add(j) = T::zero(); + } + } + + let mut temp_d_b = vec![T::zero(); k * n]; + let mut temp_d_bias = vec![T::zero(); n]; + + for batch in 0..batch_size { + unsafe { + matmul_bias_activation_bwd_kernel::( + (grad_ptr as *const T).add(batch * m * n), + (a_ptr as *const T).add(batch * m * k), + (b_ptr as *const T).add(batch * k * n), + bias_ptr as *const T, + (d_a_ptr as *mut T).add(batch * m * k), + temp_d_b.as_mut_ptr(), + temp_d_bias.as_mut_ptr(), + m, n, k, lda, ldb, ld_grad, + activation, + ); + + // Accumulate d_b + for i in 0..k * n { + let ptr = (d_b_ptr as *mut T).add(i); + *ptr += temp_d_b[i]; + } + // Accumulate d_bias + for j in 0..n { + let ptr = (d_bias_ptr as *mut T).add(j); + *ptr += temp_d_bias[j]; + } + } + } + } + }, "matmul_bias_activation_bwd"); + + Ok((d_a, d_b, d_bias_full)) + } +} diff --git a/src/ops/cpu/indexing.rs b/src/ops/cpu/indexing.rs index bd298299..895a3814 100644 --- a/src/ops/cpu/indexing.rs +++ b/src/ops/cpu/indexing.rs @@ -11,7 +11,7 @@ use crate::runtime::cpu::{ helpers::{ bincount_impl, dispatch_dtype, embedding_lookup_impl, ensure_contiguous, gather_2d_impl, gather_impl, gather_nd_impl, index_put_impl, index_select_impl, masked_fill_impl, - masked_select_impl, scatter_impl, scatter_reduce_impl, + masked_select_impl, scatter_impl, scatter_reduce_impl, slice_assign_impl, }, kernels, }; @@ -43,8 +43,8 @@ impl IndexingOps for CpuClient { let a_contig = ensure_contiguous(a); let out = Tensor::::empty(&out_shape, DType::I64, &self.device); - let a_ptr = a_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a_contig.ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -85,8 +85,8 @@ impl IndexingOps for CpuClient { let a_contig = ensure_contiguous(a); let out = Tensor::::empty(&out_shape, DType::I64, &self.device); - let a_ptr = a_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a_contig.ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -203,4 +203,14 @@ impl IndexingOps for CpuClient { ) -> Result> { gather_2d_impl(self, input, rows, cols) } + + fn slice_assign( + &self, + dst: &Tensor, + src: &Tensor, + dim: usize, + start: usize, + ) -> Result> { + slice_assign_impl(self, dst, src, dim, start) + } } diff --git a/src/ops/cpu/logical.rs b/src/ops/cpu/logical.rs index 09c2bdad..07e3ef4c 100644 --- a/src/ops/cpu/logical.rs +++ b/src/ops/cpu/logical.rs @@ -38,9 +38,9 @@ impl LogicalOps for CpuClient { let b_contig = ensure_contiguous(b); let out = Tensor::::empty(a.shape(), DType::U8, &self.device); - let a_ptr = a_contig.storage().ptr() as *const u8; - let b_ptr = b_contig.storage().ptr() as *const u8; - let out_ptr = out.storage().ptr() as *mut u8; + let a_ptr = a_contig.ptr() as *const u8; + let b_ptr = b_contig.ptr() as *const u8; + let out_ptr = out.ptr() as *mut u8; let numel = a.numel(); unsafe { @@ -81,9 +81,9 @@ impl LogicalOps for CpuClient { let b_contig = ensure_contiguous(b); let out = Tensor::::empty(a.shape(), DType::U8, &self.device); - let a_ptr = a_contig.storage().ptr() as *const u8; - let b_ptr = b_contig.storage().ptr() as *const u8; - let out_ptr = out.storage().ptr() as *mut u8; + let a_ptr = a_contig.ptr() as *const u8; + let b_ptr = b_contig.ptr() as *const u8; + let out_ptr = out.ptr() as *mut u8; let numel = a.numel(); unsafe { @@ -124,9 +124,9 @@ impl LogicalOps for CpuClient { let b_contig = ensure_contiguous(b); let out = Tensor::::empty(a.shape(), DType::U8, &self.device); - let a_ptr = a_contig.storage().ptr() as *const u8; - let b_ptr = b_contig.storage().ptr() as *const u8; - let out_ptr = out.storage().ptr() as *mut u8; + let a_ptr = a_contig.ptr() as *const u8; + let b_ptr = b_contig.ptr() as *const u8; + let out_ptr = out.ptr() as *mut u8; let numel = a.numel(); unsafe { @@ -148,8 +148,8 @@ impl LogicalOps for CpuClient { let a_contig = ensure_contiguous(a); let out = Tensor::::empty(a.shape(), DType::U8, &self.device); - let a_ptr = a_contig.storage().ptr() as *const u8; - let out_ptr = out.storage().ptr() as *mut u8; + let a_ptr = a_contig.ptr() as *const u8; + let out_ptr = out.ptr() as *mut u8; let numel = a.numel(); unsafe { diff --git a/src/ops/cpu/matmul.rs b/src/ops/cpu/matmul.rs index 8b94693c..53cb25ed 100644 --- a/src/ops/cpu/matmul.rs +++ b/src/ops/cpu/matmul.rs @@ -1,5 +1,6 @@ //! CPU implementation of matrix multiplication operations. +use crate::dtype::DType; use crate::error::{Error, Result}; use crate::ops::{Kernel, MatmulOps}; use crate::runtime::cpu::{ @@ -40,30 +41,209 @@ impl MatmulOps for CpuClient { let k = a_shape[a_shape.len() - 1]; let n = b_shape[b_shape.len() - 1]; - // Require row-major contiguous tensors for SIMD-optimized packing - // Non-contiguous tensors (transposed, views) are copied to contiguous layout - let a_contig = ensure_contiguous(a); - let b_contig = ensure_contiguous(b); - - // Calculate batch size + // Calculate batch size from output shape, and per-operand batch sizes for broadcasting let batch_size: usize = out_shape .iter() .take(out_shape.len().saturating_sub(2)) .product(); let batch_size = batch_size.max(1); - // Create output tensor - let out = Tensor::::empty(&out_shape, dtype, &self.device); + let a_batch: usize = a_shape + .iter() + .take(a_shape.len().saturating_sub(2)) + .product::() + .max(1); + let b_batch: usize = b_shape + .iter() + .take(b_shape.len().saturating_sub(2)) + .product::() + .max(1); + + // GEMV-BT fast path: detect transposed B and use dot-product kernel + // When B has shape [K,N] with strides [1,K], it's a transpose of contiguous [N,K]. + // For small M (decode), we can dot A rows against B's original [N,K] rows directly, + // avoiding the costly contiguous copy (e.g. 500MB for lm_head weights). + if m <= 16 && b_shape.len() >= 2 && dtype != DType::I8 { + let b_strides = b.strides(); + let ndim = b_shape.len(); + let stride_row = b_strides[ndim - 2]; // stride for K dimension + let stride_col = b_strides[ndim - 1]; // stride for N dimension + + // Check if B is a simple transpose: shape [K,N], strides [1, K] + // meaning the underlying data is contiguous [N,K] + if stride_row == 1 && stride_col == k as isize { + let a_contig = ensure_contiguous(a); + let a_ptr = a_contig.ptr(); + let b_ptr = b.ptr(); // Use original ptr - data is contiguous [N,K] + + // Create output tensor + let out = Tensor::::empty(&out_shape, dtype, &self.device); + let out_ptr = out.ptr(); + let ldc = n; + + dispatch_dtype!(dtype, T => { + for batch in 0..batch_size { + let a_offset = if a_batch > 1 { batch * m * k } else { 0 }; + let b_offset = if b_batch > 1 { batch * n * k } else { 0 }; + let out_offset = batch * m * n; + + #[cfg(feature = "rayon")] + { + use rayon::prelude::*; + + // Parallelize over output columns for large N + // Each thread computes a chunk of columns independently + let min_cols_per_thread = 64usize; + let num_threads = rayon::current_num_threads(); + let chunk_size = ((n + num_threads - 1) / num_threads).max(min_cols_per_thread); + + if n > min_cols_per_thread && num_threads > 1 { + // Convert to usize for Send safety - each thread + // accesses disjoint memory regions + let a_send = (a_ptr as usize) + a_offset * std::mem::size_of::(); + let b_send = (b_ptr as usize) + b_offset * std::mem::size_of::(); + let out_send = (out_ptr as usize) + out_offset * std::mem::size_of::(); + let elem_size = std::mem::size_of::(); + + self.install_parallelism(|| { + (0..n).into_par_iter().step_by(chunk_size).for_each(|col_start| { + let col_end = (col_start + chunk_size).min(n); + let chunk_n = col_end - col_start; + unsafe { + let a_base = a_send as *const T; + let b_chunk = (b_send + col_start * k * elem_size) as *const T; + let out_chunk = (out_send + col_start * elem_size) as *mut T; + + crate::runtime::cpu::kernels::gemv_bt_kernel::( + a_base, + b_chunk, + out_chunk, + m, chunk_n, k, n, + ); + } + }); + }); + } else { + unsafe { + crate::runtime::cpu::kernels::gemv_bt_kernel::( + (a_ptr as *const T).add(a_offset), + (b_ptr as *const T).add(b_offset), + (out_ptr as *mut T).add(out_offset), + m, n, k, ldc, + ); + } + } + } + + #[cfg(not(feature = "rayon"))] + unsafe { + crate::runtime::cpu::kernels::gemv_bt_kernel::( + (a_ptr as *const T).add(a_offset), + (b_ptr as *const T).add(b_offset), + (out_ptr as *mut T).add(out_offset), + m, n, k, ldc, + ); + } + } + }, "matmul_gemv_bt"); + + return Ok(out); + } + } + + // Require row-major contiguous tensors for SIMD-optimized packing + // Non-contiguous tensors (transposed, views) are copied to contiguous layout + let a_contig = ensure_contiguous(a); + let b_contig = ensure_contiguous(b); - let a_ptr = a_contig.storage().ptr(); - let b_ptr = b_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a_contig.ptr(); + let b_ptr = b_contig.ptr(); // Leading dimensions for contiguous row-major matrices let lda = k; let ldb = n; let ldc = n; + // Special case: i8 × i8 → i32 matmul (quantized accumulation) + if dtype == DType::I8 { + use crate::runtime::cpu::kernels::matmul_i8_to_i32_kernel; + + let out = Tensor::::empty(&out_shape, DType::I32, &self.device); + let out_ptr = out.ptr(); + + #[cfg(feature = "rayon")] + { + use rayon::prelude::*; + + if batch_size > 1 { + let min_len = self.rayon_min_len(); + self.install_parallelism(|| { + (0..batch_size) + .into_par_iter() + .with_min_len(min_len) + .for_each(|batch| unsafe { + let a_offset = if a_batch > 1 { batch * m * k } else { 0 }; + let b_offset = if b_batch > 1 { batch * k * n } else { 0 }; + let out_offset = batch * m * n; + + matmul_i8_to_i32_kernel( + (a_ptr as *const i8).add(a_offset), + (b_ptr as *const i8).add(b_offset), + (out_ptr as *mut i32).add(out_offset), + m, + n, + k, + lda, + ldb, + ldc, + ); + }); + }); + } else { + unsafe { + matmul_i8_to_i32_kernel( + a_ptr as *const i8, + b_ptr as *const i8, + out_ptr as *mut i32, + m, + n, + k, + lda, + ldb, + ldc, + ); + } + } + } + + #[cfg(not(feature = "rayon"))] + unsafe { + for batch in 0..batch_size { + let a_offset = if a_batch > 1 { batch * m * k } else { 0 }; + let b_offset = if b_batch > 1 { batch * k * n } else { 0 }; + let out_offset = batch * m * n; + + matmul_i8_to_i32_kernel( + (a_ptr as *const i8).add(a_offset), + (b_ptr as *const i8).add(b_offset), + (out_ptr as *mut i32).add(out_offset), + m, + n, + k, + lda, + ldb, + ldc, + ); + } + } + + return Ok(out); + } + + // Create output tensor + let out = Tensor::::empty(&out_shape, dtype, &self.device); + let out_ptr = out.ptr(); + // Dispatch based on dtype dispatch_dtype!(dtype, T => { #[cfg(feature = "rayon")] @@ -77,8 +257,8 @@ impl MatmulOps for CpuClient { .into_par_iter() .with_min_len(min_len) .for_each(|batch| unsafe { - let a_offset = batch * m * k; - let b_offset = batch * k * n; + let a_offset = if a_batch > 1 { batch * m * k } else { 0 }; + let b_offset = if b_batch > 1 { batch * k * n } else { 0 }; let out_offset = batch * m * n; >::matmul::( @@ -119,8 +299,8 @@ impl MatmulOps for CpuClient { #[cfg(not(feature = "rayon"))] unsafe { for batch in 0..batch_size { - let a_offset = batch * m * k; - let b_offset = batch * k * n; + let a_offset = if a_batch > 1 { batch * m * k } else { 0 }; + let b_offset = if b_batch > 1 { batch * k * n } else { 0 }; let out_offset = batch * m * n; >::matmul::( @@ -178,20 +358,31 @@ impl MatmulOps for CpuClient { let b_contig = ensure_contiguous(b); let bias_contig = ensure_contiguous(bias); - // Calculate batch size + // Calculate batch size from output shape, and per-operand batch sizes for broadcasting let batch_size: usize = out_shape .iter() .take(out_shape.len().saturating_sub(2)) .product(); let batch_size = batch_size.max(1); + let a_batch: usize = a_shape + .iter() + .take(a_shape.len().saturating_sub(2)) + .product::() + .max(1); + let b_batch: usize = b_shape + .iter() + .take(b_shape.len().saturating_sub(2)) + .product::() + .max(1); + // Create output tensor let out = Tensor::::empty(&out_shape, dtype, &self.device); - let a_ptr = a_contig.storage().ptr(); - let b_ptr = b_contig.storage().ptr(); - let bias_ptr = bias_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a_contig.ptr(); + let b_ptr = b_contig.ptr(); + let bias_ptr = bias_contig.ptr(); + let out_ptr = out.ptr(); // Leading dimensions for contiguous row-major matrices let lda = k; @@ -211,8 +402,8 @@ impl MatmulOps for CpuClient { .into_par_iter() .with_min_len(min_len) .for_each(|batch| unsafe { - let a_offset = batch * m * k; - let b_offset = batch * k * n; + let a_offset = if a_batch > 1 { batch * m * k } else { 0 }; + let b_offset = if b_batch > 1 { batch * k * n } else { 0 }; let out_offset = batch * m * n; matmul_bias_kernel::( @@ -254,8 +445,8 @@ impl MatmulOps for CpuClient { #[cfg(not(feature = "rayon"))] unsafe { for batch in 0..batch_size { - let a_offset = batch * m * k; - let b_offset = batch * k * n; + let a_offset = if a_batch > 1 { batch * m * k } else { 0 }; + let b_offset = if b_batch > 1 { batch * k * n } else { 0 }; let out_offset = batch * m * n; matmul_bias_kernel::( diff --git a/src/ops/cpu/mod.rs b/src/ops/cpu/mod.rs index 39515f1b..0e0aab53 100644 --- a/src/ops/cpu/mod.rs +++ b/src/ops/cpu/mod.rs @@ -13,6 +13,9 @@ pub mod conv; pub mod cumulative; pub mod distance; pub mod einsum; +#[cfg(feature = "fp8")] +pub mod fp8_matmul; +pub mod gemm_epilogue; pub mod indexing; pub mod linalg; pub mod logical; @@ -25,6 +28,8 @@ pub mod random; pub mod reduce; pub mod scalar; pub mod shape; +#[cfg(feature = "sparse")] +pub mod sparse_24; pub mod sorting; pub mod statistics; pub mod type_conversion; diff --git a/src/ops/cpu/normalization.rs b/src/ops/cpu/normalization.rs index 2ba0fc23..826b786f 100644 --- a/src/ops/cpu/normalization.rs +++ b/src/ops/cpu/normalization.rs @@ -45,9 +45,9 @@ impl NormalizationOps for CpuClient { let weight_contig = ensure_contiguous(weight); let out = Tensor::::empty(input_shape, dtype, &self.device); - let input_ptr = input_contig.storage().ptr(); - let weight_ptr = weight_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let input_ptr = input_contig.ptr(); + let weight_ptr = weight_contig.ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -111,10 +111,10 @@ impl NormalizationOps for CpuClient { let bias_contig = ensure_contiguous(bias); let out = Tensor::::empty(input_shape, dtype, &self.device); - let input_ptr = input_contig.storage().ptr(); - let weight_ptr = weight_contig.storage().ptr(); - let bias_ptr = bias_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let input_ptr = input_contig.ptr(); + let weight_ptr = weight_contig.ptr(); + let bias_ptr = bias_contig.ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -132,4 +132,356 @@ impl NormalizationOps for CpuClient { Ok(out) } + + fn group_norm( + &self, + input: &Tensor, + weight: &Tensor, + bias: &Tensor, + num_groups: usize, + eps: f32, + ) -> Result> { + let dtype = input.dtype(); + + if weight.dtype() != dtype || bias.dtype() != dtype { + return Err(Error::DTypeMismatch { + lhs: dtype, + rhs: if weight.dtype() != dtype { + weight.dtype() + } else { + bias.dtype() + }, + }); + } + + let shape = input.shape(); + if shape.len() < 2 { + return Err(Error::InvalidArgument { + arg: "input", + reason: "group_norm requires at least 2D input [batch, channels, ...]".into(), + }); + } + + let batch = shape[0]; + let channels = shape[1]; + if !channels.is_multiple_of(num_groups) { + return Err(Error::InvalidArgument { + arg: "num_groups", + reason: format!("channels {channels} not divisible by num_groups {num_groups}"), + }); + } + let channels_per_group = channels / num_groups; + let spatial: usize = shape[2..].iter().product::().max(1); + + if weight.shape() != [channels] || bias.shape() != [channels] { + return Err(Error::ShapeMismatch { + expected: vec![channels], + got: if weight.shape() != [channels] { + weight.shape().to_vec() + } else { + bias.shape().to_vec() + }, + }); + } + + let input_contig = ensure_contiguous(input); + let weight_contig = ensure_contiguous(weight); + let bias_contig = ensure_contiguous(bias); + let out = Tensor::::empty(shape, dtype, &self.device); + + dispatch_dtype!(dtype, T => { + unsafe { + kernels::group_norm_kernel::( + input_contig.ptr() as *const T, + weight_contig.ptr() as *const T, + bias_contig.ptr() as *const T, + out.ptr() as *mut T, + batch, + channels, + spatial, + num_groups, + channels_per_group, + eps, + ); + } + }, "group_norm"); + + Ok(out) + } + + fn fused_add_rms_norm( + &self, + x: &Tensor, + residual: &Tensor, + weight: &Tensor, + eps: f32, + ) -> Result<(Tensor, Tensor)> { + let dtype = x.dtype(); + + if residual.dtype() != dtype || weight.dtype() != dtype { + return Err(Error::DTypeMismatch { + lhs: dtype, + rhs: if residual.dtype() != dtype { + residual.dtype() + } else { + weight.dtype() + }, + }); + } + + let input_shape = x.shape(); + if residual.shape() != input_shape { + return Err(Error::ShapeMismatch { + expected: input_shape.to_vec(), + got: residual.shape().to_vec(), + }); + } + + let hidden_size = input_shape[input_shape.len() - 1]; + if weight.shape() != [hidden_size] { + return Err(Error::ShapeMismatch { + expected: vec![hidden_size], + got: weight.shape().to_vec(), + }); + } + + let batch_size: usize = input_shape[..input_shape.len() - 1].iter().product(); + let batch_size = batch_size.max(1); + + let x_contig = ensure_contiguous(x); + let res_contig = ensure_contiguous(residual); + let weight_contig = ensure_contiguous(weight); + let out = Tensor::::empty(input_shape, dtype, &self.device); + let pre_norm = Tensor::::empty(input_shape, dtype, &self.device); + + dispatch_dtype!(dtype, T => { + unsafe { + kernels::fused_add_rms_norm_kernel::( + x_contig.ptr() as *const T, + res_contig.ptr() as *const T, + weight_contig.ptr() as *const T, + out.ptr() as *mut T, + pre_norm.ptr() as *mut T, + batch_size, + hidden_size, + eps, + ); + } + }, "fused_add_rms_norm"); + + Ok((out, pre_norm)) + } + + fn fused_add_rms_norm_bwd( + &self, + grad: &Tensor, + pre_norm: &Tensor, + weight: &Tensor, + eps: f32, + ) -> Result<(Tensor, Tensor)> { + let dtype = grad.dtype(); + + if pre_norm.dtype() != dtype || weight.dtype() != dtype { + return Err(Error::DTypeMismatch { + lhs: dtype, + rhs: if pre_norm.dtype() != dtype { + pre_norm.dtype() + } else { + weight.dtype() + }, + }); + } + + let grad_shape = grad.shape(); + if pre_norm.shape() != grad_shape { + return Err(Error::ShapeMismatch { + expected: grad_shape.to_vec(), + got: pre_norm.shape().to_vec(), + }); + } + + let hidden_size = grad_shape[grad_shape.len() - 1]; + if weight.shape() != [hidden_size] { + return Err(Error::ShapeMismatch { + expected: vec![hidden_size], + got: weight.shape().to_vec(), + }); + } + + let batch_size: usize = grad_shape[..grad_shape.len() - 1].iter().product(); + let batch_size = batch_size.max(1); + + let grad_contig = ensure_contiguous(grad); + let pre_norm_contig = ensure_contiguous(pre_norm); + let weight_contig = ensure_contiguous(weight); + let d_input_residual = Tensor::::empty(grad_shape, dtype, &self.device); + let d_weight = Tensor::::zeros(&[hidden_size], dtype, &self.device); + + dispatch_dtype!(dtype, T => { + unsafe { + kernels::fused_add_rms_norm_bwd_kernel::( + grad_contig.ptr() as *const T, + pre_norm_contig.ptr() as *const T, + weight_contig.ptr() as *const T, + d_input_residual.ptr() as *mut T, + d_weight.ptr() as *mut T, + batch_size, + hidden_size, + eps, + ); + } + }, "fused_add_rms_norm_bwd"); + + Ok((d_input_residual, d_weight)) + } + + fn fused_add_layer_norm( + &self, + x: &Tensor, + residual: &Tensor, + weight: &Tensor, + bias: &Tensor, + eps: f32, + ) -> Result<(Tensor, Tensor)> { + let dtype = x.dtype(); + + if residual.dtype() != dtype || weight.dtype() != dtype || bias.dtype() != dtype { + return Err(Error::DTypeMismatch { + lhs: dtype, + rhs: if residual.dtype() != dtype { + residual.dtype() + } else if weight.dtype() != dtype { + weight.dtype() + } else { + bias.dtype() + }, + }); + } + + let input_shape = x.shape(); + if residual.shape() != input_shape { + return Err(Error::ShapeMismatch { + expected: input_shape.to_vec(), + got: residual.shape().to_vec(), + }); + } + + let hidden_size = input_shape[input_shape.len() - 1]; + if weight.shape() != [hidden_size] { + return Err(Error::ShapeMismatch { + expected: vec![hidden_size], + got: weight.shape().to_vec(), + }); + } + if bias.shape() != [hidden_size] { + return Err(Error::ShapeMismatch { + expected: vec![hidden_size], + got: bias.shape().to_vec(), + }); + } + + let batch_size: usize = input_shape[..input_shape.len() - 1].iter().product(); + let batch_size = batch_size.max(1); + + let x_contig = ensure_contiguous(x); + let res_contig = ensure_contiguous(residual); + let weight_contig = ensure_contiguous(weight); + let bias_contig = ensure_contiguous(bias); + let out = Tensor::::empty(input_shape, dtype, &self.device); + let pre_norm = Tensor::::empty(input_shape, dtype, &self.device); + + dispatch_dtype!(dtype, T => { + unsafe { + kernels::fused_add_layer_norm_kernel::( + x_contig.ptr() as *const T, + res_contig.ptr() as *const T, + weight_contig.ptr() as *const T, + bias_contig.ptr() as *const T, + out.ptr() as *mut T, + pre_norm.ptr() as *mut T, + batch_size, + hidden_size, + eps, + ); + } + }, "fused_add_layer_norm"); + + Ok((out, pre_norm)) + } + + fn fused_add_layer_norm_bwd( + &self, + grad: &Tensor, + pre_norm: &Tensor, + weight: &Tensor, + bias: &Tensor, + eps: f32, + ) -> Result<(Tensor, Tensor, Tensor)> { + let dtype = grad.dtype(); + + if pre_norm.dtype() != dtype || weight.dtype() != dtype || bias.dtype() != dtype { + return Err(Error::DTypeMismatch { + lhs: dtype, + rhs: if pre_norm.dtype() != dtype { + pre_norm.dtype() + } else if weight.dtype() != dtype { + weight.dtype() + } else { + bias.dtype() + }, + }); + } + + let grad_shape = grad.shape(); + if pre_norm.shape() != grad_shape { + return Err(Error::ShapeMismatch { + expected: grad_shape.to_vec(), + got: pre_norm.shape().to_vec(), + }); + } + + let hidden_size = grad_shape[grad_shape.len() - 1]; + if weight.shape() != [hidden_size] { + return Err(Error::ShapeMismatch { + expected: vec![hidden_size], + got: weight.shape().to_vec(), + }); + } + if bias.shape() != [hidden_size] { + return Err(Error::ShapeMismatch { + expected: vec![hidden_size], + got: bias.shape().to_vec(), + }); + } + + let batch_size: usize = grad_shape[..grad_shape.len() - 1].iter().product(); + let batch_size = batch_size.max(1); + + let grad_contig = ensure_contiguous(grad); + let pre_norm_contig = ensure_contiguous(pre_norm); + let weight_contig = ensure_contiguous(weight); + let bias_contig = ensure_contiguous(bias); + let d_input_residual = Tensor::::empty(grad_shape, dtype, &self.device); + let d_weight = Tensor::::zeros(&[hidden_size], dtype, &self.device); + let d_bias = Tensor::::zeros(&[hidden_size], dtype, &self.device); + + dispatch_dtype!(dtype, T => { + unsafe { + kernels::fused_add_layer_norm_bwd_kernel::( + grad_contig.ptr() as *const T, + pre_norm_contig.ptr() as *const T, + weight_contig.ptr() as *const T, + bias_contig.ptr() as *const T, + d_input_residual.ptr() as *mut T, + d_weight.ptr() as *mut T, + d_bias.ptr() as *mut T, + batch_size, + hidden_size, + eps, + ); + } + }, "fused_add_layer_norm_bwd"); + + Ok((d_input_residual, d_weight, d_bias)) + } } diff --git a/src/ops/cpu/quasirandom.rs b/src/ops/cpu/quasirandom.rs index aba9dc06..d53be983 100644 --- a/src/ops/cpu/quasirandom.rs +++ b/src/ops/cpu/quasirandom.rs @@ -26,10 +26,10 @@ impl QuasiRandomOps for CpuClient { match dtype { DType::F32 => unsafe { - kernels::sobol_f32(out.storage().ptr() as *mut f32, n_points, dimension, skip); + kernels::sobol_f32(out.ptr() as *mut f32, n_points, dimension, skip); }, DType::F64 => unsafe { - kernels::sobol_f64(out.storage().ptr() as *mut f64, n_points, dimension, skip); + kernels::sobol_f64(out.ptr() as *mut f64, n_points, dimension, skip); }, _ => unreachable!("dtype validation should prevent this"), } @@ -50,10 +50,10 @@ impl QuasiRandomOps for CpuClient { match dtype { DType::F32 => unsafe { - kernels::halton_f32(out.storage().ptr() as *mut f32, n_points, dimension, skip); + kernels::halton_f32(out.ptr() as *mut f32, n_points, dimension, skip); }, DType::F64 => unsafe { - kernels::halton_f64(out.storage().ptr() as *mut f64, n_points, dimension, skip); + kernels::halton_f64(out.ptr() as *mut f64, n_points, dimension, skip); }, _ => unreachable!("dtype validation should prevent this"), } @@ -79,10 +79,10 @@ impl QuasiRandomOps for CpuClient { match dtype { DType::F32 => unsafe { - kernels::latin_hypercube_f32(out.storage().ptr() as *mut f32, n_samples, dimension); + kernels::latin_hypercube_f32(out.ptr() as *mut f32, n_samples, dimension); }, DType::F64 => unsafe { - kernels::latin_hypercube_f64(out.storage().ptr() as *mut f64, n_samples, dimension); + kernels::latin_hypercube_f64(out.ptr() as *mut f64, n_samples, dimension); }, _ => unreachable!("dtype validation should prevent this"), } diff --git a/src/ops/cpu/random.rs b/src/ops/cpu/random.rs index 6ada13b7..119dc152 100644 --- a/src/ops/cpu/random.rs +++ b/src/ops/cpu/random.rs @@ -26,7 +26,7 @@ impl RandomOps for CpuClient { return Ok(out); } - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -37,6 +37,34 @@ impl RandomOps for CpuClient { Ok(out) } + fn rand_seeded(&self, shape: &[usize], dtype: DType, seed: u64) -> Result> { + if !dtype.is_float() { + return Err(Error::UnsupportedDType { + dtype, + op: "rand_seeded", + }); + } + + let out = Tensor::::empty(shape, dtype, &self.device); + let numel = out.numel(); + + if numel == 0 { + return Ok(out); + } + + let out_ptr = out.ptr(); + + dispatch_dtype!(dtype, T => { + unsafe { + kernels::xoshiro256_uniform_kernel::( + out_ptr as *mut T, numel, seed, + ); + } + }, "rand_seeded"); + + Ok(out) + } + fn randn(&self, shape: &[usize], dtype: DType) -> Result> { // Validate dtype is floating point if !dtype.is_float() { @@ -51,7 +79,7 @@ impl RandomOps for CpuClient { return Ok(out); } - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -107,7 +135,7 @@ impl RandomOps for CpuClient { return Ok(out); } - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -186,17 +214,15 @@ impl RandomOps for CpuClient { // Check the max value - if all values are <= 0, we cannot sample let max_prob: f64 = match dtype { DType::F32 => { - let data: &[f32] = unsafe { - std::slice::from_raw_parts(probs.storage().ptr() as *const f32, probs.numel()) - }; + let data: &[f32] = + unsafe { std::slice::from_raw_parts(probs.ptr() as *const f32, probs.numel()) }; data.iter() .cloned() .fold(f64::NEG_INFINITY, |a, b| a.max(b as f64)) } DType::F64 => { - let data: &[f64] = unsafe { - std::slice::from_raw_parts(probs.storage().ptr() as *const f64, probs.numel()) - }; + let data: &[f64] = + unsafe { std::slice::from_raw_parts(probs.ptr() as *const f64, probs.numel()) }; data.iter().cloned().fold(f64::NEG_INFINITY, f64::max) } _ => { @@ -220,8 +246,8 @@ impl RandomOps for CpuClient { } let out = Tensor::::empty(&out_shape, DType::I64, &self.device); - let out_ptr = out.storage().ptr() as *mut i64; - let probs_ptr = probs.storage().ptr(); + let out_ptr = out.ptr() as *mut i64; + let probs_ptr = probs.ptr(); // Dispatch based on input dtype dispatch_dtype!(dtype, T => { @@ -272,7 +298,7 @@ impl RandomOps for CpuClient { return Ok(out); } - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { kernels::bernoulli_kernel::(out_ptr as *mut T, p, numel); } }, "bernoulli"); @@ -312,7 +338,7 @@ impl RandomOps for CpuClient { return Ok(out); } - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { kernels::beta_kernel::(out_ptr as *mut T, alpha, beta, numel); } }, "beta"); @@ -352,7 +378,7 @@ impl RandomOps for CpuClient { return Ok(out); } - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { kernels::gamma_kernel::(out_ptr as *mut T, shape_param, scale, numel); } }, "gamma"); @@ -383,7 +409,7 @@ impl RandomOps for CpuClient { return Ok(out); } - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { kernels::exponential_kernel::(out_ptr as *mut T, rate, numel); } }, "exponential"); @@ -414,7 +440,7 @@ impl RandomOps for CpuClient { return Ok(out); } - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { kernels::poisson_kernel::(out_ptr as *mut T, lambda, numel); } }, "poisson"); @@ -457,7 +483,7 @@ impl RandomOps for CpuClient { return Ok(out); } - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { kernels::binomial_kernel::(out_ptr as *mut T, n, p, numel); } }, "binomial"); @@ -494,7 +520,7 @@ impl RandomOps for CpuClient { return Ok(out); } - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { kernels::laplace_kernel::(out_ptr as *mut T, loc, scale, numel); } }, "laplace"); @@ -525,7 +551,7 @@ impl RandomOps for CpuClient { return Ok(out); } - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { kernels::chi_squared_kernel::(out_ptr as *mut T, df, numel); } }, "chi_squared"); @@ -556,7 +582,7 @@ impl RandomOps for CpuClient { return Ok(out); } - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { kernels::student_t_kernel::(out_ptr as *mut T, df, numel); } }, "student_t"); @@ -573,7 +599,7 @@ impl RandomOps for CpuClient { } let out = Tensor::::empty(&[n], DType::I64, &self.device); - let out_ptr = out.storage().ptr() as *mut i64; + let out_ptr = out.ptr() as *mut i64; unsafe { kernels::randperm_kernel(out_ptr, n); @@ -617,7 +643,7 @@ impl RandomOps for CpuClient { return Ok(out); } - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { kernels::f_distribution_kernel::(out_ptr as *mut T, df1, df2, numel); } }, "f_distribution"); diff --git a/src/ops/cpu/scalar.rs b/src/ops/cpu/scalar.rs index c4019c82..9b0b6e7a 100644 --- a/src/ops/cpu/scalar.rs +++ b/src/ops/cpu/scalar.rs @@ -3,7 +3,8 @@ use crate::error::Result; use crate::ops::{BinaryOp, ScalarOps}; use crate::runtime::cpu::{ - CpuClient, CpuRuntime, helpers::scalar::rsub_scalar_op_impl, helpers::scalar_op_impl, + CpuClient, CpuRuntime, helpers::fused_mul_add_scalar_impl, + helpers::scalar::rsub_scalar_op_impl, helpers::scalar_op_impl, }; use crate::tensor::Tensor; @@ -31,6 +32,15 @@ impl ScalarOps for CpuClient { fn rsub_scalar(&self, a: &Tensor, scalar: f64) -> Result> { rsub_scalar_op_impl(self, a, scalar) } + + fn fused_mul_add_scalar( + &self, + a: &Tensor, + scale: f64, + bias: f64, + ) -> Result> { + fused_mul_add_scalar_impl(self, a, scale, bias) + } } #[cfg(test)] diff --git a/src/ops/cpu/semiring_matmul.rs b/src/ops/cpu/semiring_matmul.rs index e69fcb9e..c61aeb32 100644 --- a/src/ops/cpu/semiring_matmul.rs +++ b/src/ops/cpu/semiring_matmul.rs @@ -57,19 +57,30 @@ impl SemiringMatmulOps for CpuClient { let a_contig = ensure_contiguous(a); let b_contig = ensure_contiguous(b); - // Calculate batch size + // Calculate batch size from output shape and per-input batch counts let batch_size: usize = out_shape .iter() .take(out_shape.len().saturating_sub(2)) .product(); let batch_size = batch_size.max(1); + let a_batch_count: usize = a_shape + .iter() + .take(a_shape.len().saturating_sub(2)) + .product(); + let a_batch_count = a_batch_count.max(1); + let b_batch_count: usize = b_shape + .iter() + .take(b_shape.len().saturating_sub(2)) + .product(); + let b_batch_count = b_batch_count.max(1); + // Create output tensor let out = Tensor::::empty(&out_shape, dtype, &self.device); - let a_ptr = a_contig.storage().ptr(); - let b_ptr = b_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a_contig.ptr(); + let b_ptr = b_contig.ptr(); + let out_ptr = out.ptr(); let lda = k; let ldb = n; @@ -80,8 +91,8 @@ impl SemiringMatmulOps for CpuClient { // Bool is stored as u8 internally unsafe { for batch in 0..batch_size { - let a_offset = batch * m * k; - let b_offset = batch * k * n; + let a_offset = (batch % a_batch_count) * m * k; + let b_offset = (batch % b_batch_count) * k * n; let out_offset = batch * m * n; or_and_kernel( @@ -104,8 +115,8 @@ impl SemiringMatmulOps for CpuClient { dispatch_dtype!(dtype, T => { unsafe { for batch in 0..batch_size { - let a_offset = batch * m * k; - let b_offset = batch * k * n; + let a_offset = (batch % a_batch_count) * m * k; + let b_offset = (batch % b_batch_count) * k * n; let out_offset = batch * m * n; semiring_matmul_kernel::( diff --git a/src/ops/cpu/sparse_24.rs b/src/ops/cpu/sparse_24.rs new file mode 100644 index 00000000..cd50a0e5 --- /dev/null +++ b/src/ops/cpu/sparse_24.rs @@ -0,0 +1,98 @@ +//! CPU implementation of 2:4 structured sparsity operations. + +use crate::dispatch_dtype; +use crate::dtype::DType; +use crate::error::{Error, Result}; +use crate::ops::MatmulOps; +use crate::ops::traits::Sparse24Ops; +use crate::runtime::cpu::kernels::sparse_24; +use crate::runtime::cpu::{CpuClient, CpuRuntime}; +use crate::runtime::ensure_contiguous; +use crate::sparse::structured::{Sparse24Tensor, meta_cols_for_k}; +use crate::tensor::Tensor; + +impl Sparse24Ops for CpuClient { + fn prune_to_24(&self, dense: &Tensor) -> Result> { + if dense.ndim() != 2 { + return Err(Error::InvalidArgument { + arg: "dense", + reason: format!("Expected 2D tensor, got {}D", dense.ndim()), + }); + } + + let m = dense.shape()[0]; + let k = dense.shape()[1]; + + if !k.is_multiple_of(4) { + return Err(Error::InvalidArgument { + arg: "dense", + reason: format!("K dimension ({k}) must be divisible by 4 for 2:4 sparsity"), + }); + } + + let dtype = dense.dtype(); + let device = dense.device().clone(); + let dense_contig = ensure_contiguous(dense); + + let half_k = k / 2; + let mc = meta_cols_for_k(k); + + let compressed = Tensor::::empty(&[m, half_k], dtype, &device); + let metadata = Tensor::::empty(&[m, mc], DType::U32, &device); + + dispatch_dtype!(dtype, T => { + unsafe { + sparse_24::prune_to_24_kernel::( + dense_contig.ptr() as *const T, + compressed.ptr() as *mut T, + metadata.ptr() as *mut u32, + m, + k, + ); + } + }, "prune_to_24"); + + Sparse24Tensor::new(compressed, metadata, [m, k]) + } + + fn sparse_24_to_dense( + &self, + sparse: &Sparse24Tensor, + ) -> Result> { + let [m, k] = sparse.shape(); + let dtype = sparse.dtype(); + let device = sparse.compressed_values().device().clone(); + + let dense = Tensor::::empty(&[m, k], dtype, &device); + + let vals = ensure_contiguous(sparse.compressed_values()); + let meta = ensure_contiguous(sparse.metadata()); + + dispatch_dtype!(dtype, T => { + unsafe { + sparse_24::decompress_24_kernel::( + vals.ptr() as *const T, + meta.ptr() as *const u32, + dense.ptr() as *mut T, + m, + k, + ); + } + }, "sparse_24_to_dense"); + + Ok(dense) + } + + fn sparse_24_matmul( + &self, + input: &Tensor, + weight: &Sparse24Tensor, + ) -> Result> { + // CPU fallback: decompress weight to dense, then standard matmul + // input: [N, K], weight: [M, K] → output: [N, M] + // matmul(input, weight^T) = matmul(input [N,K], dense_weight^T [K,M]) → [N, M] + let dense_weight = self.sparse_24_to_dense(weight)?; + let weight_t = dense_weight.t()?; + self.matmul(input, &weight_t) + } +} diff --git a/src/ops/cpu/statistics.rs b/src/ops/cpu/statistics.rs index b786d978..48acdced 100644 --- a/src/ops/cpu/statistics.rs +++ b/src/ops/cpu/statistics.rs @@ -33,7 +33,7 @@ impl StatisticalOps for CpuClient { // Reduce over all dimensions - return scalar variance let numel = a.numel(); let a_contig = ensure_contiguous(a); - let a_ptr = a_contig.storage().ptr(); + let a_ptr = a_contig.ptr(); let variance = dispatch_dtype!(dtype, T => { unsafe { @@ -58,7 +58,7 @@ impl StatisticalOps for CpuClient { let out_shape = if keepdim { vec![1; ndim] } else { vec![] }; let out = Tensor::::empty(&out_shape, dtype, &self.device); - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -85,8 +85,8 @@ impl StatisticalOps for CpuClient { let a_contig = ensure_contiguous(a); let out = Tensor::::empty(&out_shape, dtype, &self.device); - let a_ptr = a_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a_contig.ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { diff --git a/src/ops/cpu/type_conversion.rs b/src/ops/cpu/type_conversion.rs index 57c7616e..da3c6b22 100644 --- a/src/ops/cpu/type_conversion.rs +++ b/src/ops/cpu/type_conversion.rs @@ -21,8 +21,8 @@ impl TypeConversionOps for CpuClient { let a_contig = ensure_contiguous(a); let out = Tensor::::empty(shape, target_dtype, &self.device); - let src_ptr = a_contig.storage().ptr() as *const u8; - let dst_ptr = out.storage().ptr() as *mut u8; + let src_ptr = a_contig.ptr() as *const u8; + let dst_ptr = out.ptr() as *mut u8; unsafe { kernels::cast_kernel(src_ptr, dst_ptr, numel, src_dtype, target_dtype)?; diff --git a/src/ops/cpu/unary.rs b/src/ops/cpu/unary.rs index 0b74f8f5..a406230a 100644 --- a/src/ops/cpu/unary.rs +++ b/src/ops/cpu/unary.rs @@ -141,8 +141,8 @@ impl UnaryOps for CpuClient { let a_contig = ensure_contiguous(a); let out = Tensor::::empty(a.shape(), DType::U8, &self.device); - let a_ptr = a_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a_contig.ptr(); + let out_ptr = out.ptr(); let numel = a.numel(); dispatch_dtype!(dtype, T => { @@ -163,8 +163,8 @@ impl UnaryOps for CpuClient { let a_contig = ensure_contiguous(a); let out = Tensor::::empty(a.shape(), DType::U8, &self.device); - let a_ptr = a_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a_contig.ptr(); + let out_ptr = out.ptr(); let numel = a.numel(); dispatch_dtype!(dtype, T => { diff --git a/src/ops/cpu/utility.rs b/src/ops/cpu/utility.rs index af44debe..0c954a87 100644 --- a/src/ops/cpu/utility.rs +++ b/src/ops/cpu/utility.rs @@ -25,8 +25,8 @@ impl UtilityOps for CpuClient { let a_contig = ensure_contiguous(a); let out = Tensor::::empty(a.shape(), dtype, &self.device); - let a_ptr = a_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a_contig.ptr(); + let out_ptr = out.ptr(); let numel = a.numel(); dispatch_dtype!(dtype, T => { @@ -46,7 +46,7 @@ impl UtilityOps for CpuClient { fn fill(&self, shape: &[usize], value: f64, dtype: DType) -> Result> { let out = Tensor::::empty(shape, dtype, &self.device); - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); let numel = out.numel(); dispatch_dtype!(dtype, T => { @@ -72,7 +72,7 @@ impl UtilityOps for CpuClient { } let out = Tensor::::empty(&[numel], dtype, &self.device); - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -100,7 +100,7 @@ impl UtilityOps for CpuClient { if steps == 1 { let out = Tensor::::empty(&[1], dtype, &self.device); - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -112,7 +112,7 @@ impl UtilityOps for CpuClient { } let out = Tensor::::empty(&[steps], dtype, &self.device); - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -138,7 +138,7 @@ impl UtilityOps for CpuClient { } let out = Tensor::::empty(&[rows, cols], dtype, &self.device); - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -179,14 +179,14 @@ impl UtilityOps for CpuClient { out_shape.push(num_classes); let out = Tensor::::empty(&out_shape, DType::F32, &self.device); - let out_ptr = out.storage().ptr() as *mut f32; + let out_ptr = out.ptr() as *mut f32; // Zero-fill output unsafe { std::ptr::write_bytes(out_ptr, 0, numel * num_classes); } - let indices_ptr = indices.storage().ptr(); + let indices_ptr = indices.ptr(); // Dispatch on index dtype to read indices, write into f32 output dispatch_dtype!(dtype, T => { diff --git a/src/ops/cuda/activation.rs b/src/ops/cuda/activation.rs index 9b3793e2..0b5e9ae7 100644 --- a/src/ops/cuda/activation.rs +++ b/src/ops/cuda/activation.rs @@ -2,9 +2,12 @@ use crate::error::{Error, Result}; use crate::ops::ActivationOps; use crate::ops::activation::normalize_softmax_dim; +use crate::ops::impl_generic::activation::{dropout_impl, log_softmax_impl, softplus_impl}; use crate::runtime::cuda::kernels::{ - launch_elu, launch_gelu, launch_leaky_relu, launch_relu, launch_sigmoid, launch_silu, - launch_softmax, launch_softmax_dim, + launch_elu, launch_gelu, launch_gelu_mul, launch_gelu_mul_bwd, launch_leaky_relu, launch_relu, + launch_relu_mul, launch_relu_mul_bwd, launch_sigmoid, launch_sigmoid_mul, + launch_sigmoid_mul_bwd, launch_silu, launch_silu_mul, launch_silu_mul_bwd, launch_softmax, + launch_softmax_bwd, launch_softmax_bwd_dim, launch_softmax_dim, }; use crate::runtime::cuda::{CudaClient, CudaRuntime}; use crate::runtime::ensure_contiguous; @@ -22,8 +25,8 @@ impl ActivationOps for CudaClient { &self.stream, self.device.index, dtype, - a_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + out.ptr(), out.numel(), )?; } @@ -42,8 +45,8 @@ impl ActivationOps for CudaClient { &self.stream, self.device.index, dtype, - a_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + out.ptr(), out.numel(), )?; } @@ -62,8 +65,8 @@ impl ActivationOps for CudaClient { &self.stream, self.device.index, dtype, - a_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + out.ptr(), out.numel(), )?; } @@ -82,8 +85,8 @@ impl ActivationOps for CudaClient { &self.stream, self.device.index, dtype, - a_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + out.ptr(), out.numel(), )?; } @@ -91,6 +94,258 @@ impl ActivationOps for CudaClient { Ok(out) } + fn silu_mul( + &self, + a: &Tensor, + b: &Tensor, + ) -> Result> { + let dtype = a.dtype(); + if b.dtype() != dtype { + return Err(Error::DTypeMismatch { + lhs: dtype, + rhs: b.dtype(), + }); + } + let a_contig = ensure_contiguous(a); + let b_contig = ensure_contiguous(b); + let out = Tensor::::empty(a.shape(), dtype, &self.device); + + unsafe { + launch_silu_mul( + &self.context, + &self.stream, + self.device.index, + dtype, + a_contig.ptr(), + b_contig.ptr(), + out.ptr(), + out.numel(), + )?; + } + + Ok(out) + } + + fn gelu_mul( + &self, + a: &Tensor, + b: &Tensor, + ) -> Result> { + let dtype = a.dtype(); + if b.dtype() != dtype { + return Err(Error::DTypeMismatch { + lhs: dtype, + rhs: b.dtype(), + }); + } + let a_contig = ensure_contiguous(a); + let b_contig = ensure_contiguous(b); + let out = Tensor::::empty(a.shape(), dtype, &self.device); + + unsafe { + launch_gelu_mul( + &self.context, + &self.stream, + self.device.index, + dtype, + a_contig.ptr(), + b_contig.ptr(), + out.ptr(), + out.numel(), + )?; + } + + Ok(out) + } + + fn relu_mul( + &self, + a: &Tensor, + b: &Tensor, + ) -> Result> { + let dtype = a.dtype(); + if b.dtype() != dtype { + return Err(Error::DTypeMismatch { + lhs: dtype, + rhs: b.dtype(), + }); + } + let a_contig = ensure_contiguous(a); + let b_contig = ensure_contiguous(b); + let out = Tensor::::empty(a.shape(), dtype, &self.device); + + unsafe { + launch_relu_mul( + &self.context, + &self.stream, + self.device.index, + dtype, + a_contig.ptr(), + b_contig.ptr(), + out.ptr(), + out.numel(), + )?; + } + + Ok(out) + } + + fn sigmoid_mul( + &self, + a: &Tensor, + b: &Tensor, + ) -> Result> { + let dtype = a.dtype(); + if b.dtype() != dtype { + return Err(Error::DTypeMismatch { + lhs: dtype, + rhs: b.dtype(), + }); + } + let a_contig = ensure_contiguous(a); + let b_contig = ensure_contiguous(b); + let out = Tensor::::empty(a.shape(), dtype, &self.device); + + unsafe { + launch_sigmoid_mul( + &self.context, + &self.stream, + self.device.index, + dtype, + a_contig.ptr(), + b_contig.ptr(), + out.ptr(), + out.numel(), + )?; + } + + Ok(out) + } + + fn silu_mul_bwd( + &self, + grad: &Tensor, + a: &Tensor, + b: &Tensor, + ) -> Result<(Tensor, Tensor)> { + let dtype = a.dtype(); + let grad_contig = ensure_contiguous(grad); + let a_contig = ensure_contiguous(a); + let b_contig = ensure_contiguous(b); + let d_a = Tensor::::empty(a.shape(), dtype, &self.device); + let d_b = Tensor::::empty(b.shape(), dtype, &self.device); + + unsafe { + launch_silu_mul_bwd( + &self.context, + &self.stream, + self.device.index, + dtype, + grad_contig.ptr(), + a_contig.ptr(), + b_contig.ptr(), + d_a.ptr(), + d_b.ptr(), + a.numel(), + )?; + } + + Ok((d_a, d_b)) + } + + fn gelu_mul_bwd( + &self, + grad: &Tensor, + a: &Tensor, + b: &Tensor, + ) -> Result<(Tensor, Tensor)> { + let dtype = a.dtype(); + let grad_contig = ensure_contiguous(grad); + let a_contig = ensure_contiguous(a); + let b_contig = ensure_contiguous(b); + let d_a = Tensor::::empty(a.shape(), dtype, &self.device); + let d_b = Tensor::::empty(b.shape(), dtype, &self.device); + + unsafe { + launch_gelu_mul_bwd( + &self.context, + &self.stream, + self.device.index, + dtype, + grad_contig.ptr(), + a_contig.ptr(), + b_contig.ptr(), + d_a.ptr(), + d_b.ptr(), + a.numel(), + )?; + } + + Ok((d_a, d_b)) + } + + fn relu_mul_bwd( + &self, + grad: &Tensor, + a: &Tensor, + b: &Tensor, + ) -> Result<(Tensor, Tensor)> { + let dtype = a.dtype(); + let grad_contig = ensure_contiguous(grad); + let a_contig = ensure_contiguous(a); + let b_contig = ensure_contiguous(b); + let d_a = Tensor::::empty(a.shape(), dtype, &self.device); + let d_b = Tensor::::empty(b.shape(), dtype, &self.device); + + unsafe { + launch_relu_mul_bwd( + &self.context, + &self.stream, + self.device.index, + dtype, + grad_contig.ptr(), + a_contig.ptr(), + b_contig.ptr(), + d_a.ptr(), + d_b.ptr(), + a.numel(), + )?; + } + + Ok((d_a, d_b)) + } + + fn sigmoid_mul_bwd( + &self, + grad: &Tensor, + a: &Tensor, + b: &Tensor, + ) -> Result<(Tensor, Tensor)> { + let dtype = a.dtype(); + let grad_contig = ensure_contiguous(grad); + let a_contig = ensure_contiguous(a); + let b_contig = ensure_contiguous(b); + let d_a = Tensor::::empty(a.shape(), dtype, &self.device); + let d_b = Tensor::::empty(b.shape(), dtype, &self.device); + + unsafe { + launch_sigmoid_mul_bwd( + &self.context, + &self.stream, + self.device.index, + dtype, + grad_contig.ptr(), + a_contig.ptr(), + b_contig.ptr(), + d_a.ptr(), + d_b.ptr(), + a.numel(), + )?; + } + + Ok((d_a, d_b)) + } + fn leaky_relu( &self, a: &Tensor, @@ -106,8 +361,8 @@ impl ActivationOps for CudaClient { &self.stream, self.device.index, dtype, - a_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + out.ptr(), out.numel(), negative_slope as f32, )?; @@ -127,8 +382,8 @@ impl ActivationOps for CudaClient { &self.stream, self.device.index, dtype, - a_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + out.ptr(), out.numel(), alpha as f32, )?; @@ -163,8 +418,8 @@ impl ActivationOps for CudaClient { &self.stream, self.device.index, dtype, - a_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + out.ptr(), outer_size, dim_size, )?; @@ -175,8 +430,8 @@ impl ActivationOps for CudaClient { &self.stream, self.device.index, dtype, - a_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + out.ptr(), outer_size, dim_size, inner_size, @@ -186,4 +441,73 @@ impl ActivationOps for CudaClient { Ok(out) } + + fn softmax_bwd( + &self, + grad: &Tensor, + output: &Tensor, + dim: isize, + ) -> Result> { + let dtype = grad.dtype(); + let ndim = grad.ndim(); + let dim_idx = + normalize_softmax_dim(ndim, dim).ok_or(Error::InvalidDimension { dim, ndim })?; + + let grad_contig = ensure_contiguous(grad); + let output_contig = ensure_contiguous(output); + let d_input = Tensor::::empty(grad.shape(), dtype, &self.device); + + let shape = grad.shape(); + let outer_size: usize = shape[..dim_idx].iter().product::().max(1); + let dim_size = shape[dim_idx]; + let inner_size: usize = shape[dim_idx + 1..].iter().product::().max(1); + + unsafe { + if dim_idx == ndim - 1 { + launch_softmax_bwd( + &self.context, + &self.stream, + self.device.index, + dtype, + grad_contig.ptr(), + output_contig.ptr(), + d_input.ptr(), + outer_size, + dim_size, + )?; + } else { + launch_softmax_bwd_dim( + &self.context, + &self.stream, + self.device.index, + dtype, + grad_contig.ptr(), + output_contig.ptr(), + d_input.ptr(), + outer_size, + dim_size, + inner_size, + )?; + } + } + + Ok(d_input) + } + + fn softplus(&self, a: &Tensor) -> Result> { + softplus_impl(self, a) + } + + fn log_softmax(&self, a: &Tensor, dim: isize) -> Result> { + log_softmax_impl(self, a, dim) + } + + fn dropout( + &self, + a: &Tensor, + p: f64, + training: bool, + ) -> Result> { + dropout_impl(self, a, p, training) + } } diff --git a/src/ops/cuda/advanced_random.rs b/src/ops/cuda/advanced_random.rs index 5432d176..bc20d751 100644 --- a/src/ops/cuda/advanced_random.rs +++ b/src/ops/cuda/advanced_random.rs @@ -37,7 +37,7 @@ impl AdvancedRandomOps for CudaClient { dtype, key, counter, - out.storage().ptr(), + out.ptr(), numel, )?; } @@ -74,7 +74,7 @@ impl AdvancedRandomOps for CudaClient { dtype, key, counter, - out.storage().ptr(), + out.ptr(), numel, )?; } @@ -111,7 +111,7 @@ impl AdvancedRandomOps for CudaClient { dtype, key, counter, - out.storage().ptr(), + out.ptr(), numel, )?; } @@ -148,7 +148,7 @@ impl AdvancedRandomOps for CudaClient { dtype, key, counter, - out.storage().ptr(), + out.ptr(), numel, )?; } @@ -185,7 +185,7 @@ impl AdvancedRandomOps for CudaClient { dtype, seed, stream, - out.storage().ptr(), + out.ptr(), numel, )?; } @@ -222,7 +222,7 @@ impl AdvancedRandomOps for CudaClient { dtype, seed, stream, - out.storage().ptr(), + out.ptr(), numel, )?; } @@ -257,7 +257,7 @@ impl AdvancedRandomOps for CudaClient { self.device.index, dtype, seed, - out.storage().ptr(), + out.ptr(), numel, )?; } @@ -292,7 +292,7 @@ impl AdvancedRandomOps for CudaClient { self.device.index, dtype, seed, - out.storage().ptr(), + out.ptr(), numel, )?; } diff --git a/src/ops/cuda/binary.rs b/src/ops/cuda/binary.rs index 2bafe9e5..8a8f410c 100644 --- a/src/ops/cuda/binary.rs +++ b/src/ops/cuda/binary.rs @@ -1,8 +1,10 @@ //! Binary operations for CUDA runtime -use crate::error::Result; +use crate::error::{Error, Result}; use crate::ops::BinaryOps; +use crate::runtime::cuda::kernels::{launch_fused_add_mul, launch_fused_mul_add}; use crate::runtime::cuda::ops::helpers::native_binary_op; use crate::runtime::cuda::{CudaClient, CudaRuntime}; +use crate::runtime::ensure_contiguous; use crate::tensor::Tensor; impl BinaryOps for CudaClient { @@ -49,4 +51,82 @@ impl BinaryOps for CudaClient { ) -> Result> { native_binary_op(self, y, x, "atan2") } + + fn fused_mul_add( + &self, + a: &Tensor, + b: &Tensor, + c: &Tensor, + ) -> Result> { + let dtype = a.dtype(); + if b.dtype() != dtype || c.dtype() != dtype { + return Err(Error::DTypeMismatch { + lhs: dtype, + rhs: if b.dtype() != dtype { + b.dtype() + } else { + c.dtype() + }, + }); + } + let a_contig = ensure_contiguous(a); + let b_contig = ensure_contiguous(b); + let c_contig = ensure_contiguous(c); + let out = Tensor::::empty(a.shape(), dtype, &self.device); + + unsafe { + launch_fused_mul_add( + &self.context, + &self.stream, + self.device.index, + dtype, + a_contig.ptr(), + b_contig.ptr(), + c_contig.ptr(), + out.ptr(), + out.numel(), + )?; + } + + Ok(out) + } + + fn fused_add_mul( + &self, + a: &Tensor, + b: &Tensor, + c: &Tensor, + ) -> Result> { + let dtype = a.dtype(); + if b.dtype() != dtype || c.dtype() != dtype { + return Err(Error::DTypeMismatch { + lhs: dtype, + rhs: if b.dtype() != dtype { + b.dtype() + } else { + c.dtype() + }, + }); + } + let a_contig = ensure_contiguous(a); + let b_contig = ensure_contiguous(b); + let c_contig = ensure_contiguous(c); + let out = Tensor::::empty(a.shape(), dtype, &self.device); + + unsafe { + launch_fused_add_mul( + &self.context, + &self.stream, + self.device.index, + dtype, + a_contig.ptr(), + b_contig.ptr(), + c_contig.ptr(), + out.ptr(), + out.numel(), + )?; + } + + Ok(out) + } } diff --git a/src/ops/cuda/complex.rs b/src/ops/cuda/complex.rs index 92785c94..a0bdcd63 100644 --- a/src/ops/cuda/complex.rs +++ b/src/ops/cuda/complex.rs @@ -29,8 +29,8 @@ impl ComplexOps for CudaClient { &self.stream, self.device.index, dtype, - a_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + out.ptr(), a.numel(), )?; } @@ -62,8 +62,8 @@ impl ComplexOps for CudaClient { &self.stream, self.device.index, dtype, - a_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + out.ptr(), a.numel(), )?; } @@ -97,7 +97,7 @@ impl ComplexOps for CudaClient { self.device.index, out_dtype, 0.0, - out.storage().ptr(), + out.ptr(), out.numel(), )?; } @@ -113,8 +113,8 @@ impl ComplexOps for CudaClient { &self.stream, self.device.index, dtype, - a_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + out.ptr(), a.numel(), )?; } @@ -149,8 +149,8 @@ impl ComplexOps for CudaClient { &self.stream, self.device.index, dtype, - a_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + out.ptr(), a.numel(), )?; }, @@ -163,7 +163,7 @@ impl ComplexOps for CudaClient { self.device.index, out_dtype, 0.0, - out.storage().ptr(), + out.ptr(), out.numel(), )?; } @@ -179,8 +179,8 @@ impl ComplexOps for CudaClient { &self.stream, self.device.index, dtype, - a_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + out.ptr(), a.numel(), )?; } @@ -221,9 +221,9 @@ impl ComplexOps for CudaClient { &self.stream, self.device.index, input_dtype, - real_contig.storage().ptr(), - imag_contig.storage().ptr(), - out.storage().ptr(), + real_contig.ptr(), + imag_contig.ptr(), + out.ptr(), numel, )?; } @@ -257,9 +257,9 @@ impl ComplexOps for CudaClient { &self.stream, self.device.index, dtype, - complex_contig.storage().ptr(), - real_contig.storage().ptr(), - out.storage().ptr(), + complex_contig.ptr(), + real_contig.ptr(), + out.ptr(), numel, )?; } @@ -293,9 +293,9 @@ impl ComplexOps for CudaClient { &self.stream, self.device.index, dtype, - complex_contig.storage().ptr(), - real_contig.storage().ptr(), - out.storage().ptr(), + complex_contig.ptr(), + real_contig.ptr(), + out.ptr(), numel, )?; } diff --git a/src/ops/cuda/conditional.rs b/src/ops/cuda/conditional.rs index 981b5ad9..45c84c76 100644 --- a/src/ops/cuda/conditional.rs +++ b/src/ops/cuda/conditional.rs @@ -38,10 +38,10 @@ impl ConditionalOps for CudaClient { &self.stream, self.device.index, dtype, - cond_contig.storage().ptr(), - x_contig.storage().ptr(), - y_contig.storage().ptr(), - out.storage().ptr(), + cond_contig.ptr(), + x_contig.ptr(), + y_contig.ptr(), + out.ptr(), out.numel(), )?; } else { @@ -52,10 +52,10 @@ impl ConditionalOps for CudaClient { self.device.index, cond_dtype, dtype, - cond_contig.storage().ptr(), - x_contig.storage().ptr(), - y_contig.storage().ptr(), - out.storage().ptr(), + cond_contig.ptr(), + x_contig.ptr(), + y_contig.ptr(), + out.ptr(), out.numel(), )?; } @@ -88,10 +88,10 @@ impl ConditionalOps for CudaClient { self.device.index, &self.device, dtype, - cond_contig.storage().ptr(), - x_contig.storage().ptr(), - y_contig.storage().ptr(), - out.storage().ptr(), + cond_contig.ptr(), + x_contig.ptr(), + y_contig.ptr(), + out.ptr(), cond.shape(), x.shape(), y.shape(), @@ -106,10 +106,10 @@ impl ConditionalOps for CudaClient { &self.device, cond_dtype, dtype, - cond_contig.storage().ptr(), - x_contig.storage().ptr(), - y_contig.storage().ptr(), - out.storage().ptr(), + cond_contig.ptr(), + x_contig.ptr(), + y_contig.ptr(), + out.ptr(), cond.shape(), x.shape(), y.shape(), diff --git a/src/ops/cuda/conv.rs b/src/ops/cuda/conv.rs index ce7bdb93..a3cbef19 100644 --- a/src/ops/cuda/conv.rs +++ b/src/ops/cuda/conv.rs @@ -57,10 +57,10 @@ impl ConvOps for CudaClient { ); // Get device pointers - let input_ptr = input.storage().ptr(); - let weight_ptr = weight.storage().ptr(); - let bias_ptr = bias.as_ref().map(|b| b.storage().ptr()); - let output_ptr = output.storage().ptr(); + let input_ptr = input.ptr(); + let weight_ptr = weight.ptr(); + let bias_ptr = bias.as_ref().map(|b| b.ptr()); + let output_ptr = output.ptr(); // Launch CUDA kernel unsafe { @@ -137,10 +137,10 @@ impl ConvOps for CudaClient { ); // Get device pointers - let input_ptr = input.storage().ptr(); - let weight_ptr = weight.storage().ptr(); - let bias_ptr = bias.as_ref().map(|b| b.storage().ptr()); - let output_ptr = output.storage().ptr(); + let input_ptr = input.ptr(); + let weight_ptr = weight.ptr(); + let bias_ptr = bias.as_ref().map(|b| b.ptr()); + let output_ptr = output.ptr(); // Launch CUDA kernel unsafe { @@ -221,10 +221,10 @@ impl ConvOps for CudaClient { ); // Get device pointers - let input_ptr = input.storage().ptr(); - let weight_ptr = weight.storage().ptr(); - let bias_ptr = bias.as_ref().map(|b| b.storage().ptr()); - let output_ptr = output.storage().ptr(); + let input_ptr = input.ptr(); + let weight_ptr = weight.ptr(); + let bias_ptr = bias.as_ref().map(|b| b.ptr()); + let output_ptr = output.ptr(); // Launch CUDA kernel unsafe { diff --git a/src/ops/cuda/cumulative.rs b/src/ops/cuda/cumulative.rs index 43d62b93..c5a5587a 100644 --- a/src/ops/cuda/cumulative.rs +++ b/src/ops/cuda/cumulative.rs @@ -54,8 +54,8 @@ impl CumulativeOps for CudaClient { &self.stream, self.device.index, a.dtype(), - a_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + out.ptr(), scan_size, outer, )?; @@ -68,8 +68,8 @@ impl CumulativeOps for CudaClient { &self.stream, self.device.index, a.dtype(), - a_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + out.ptr(), scan_size, outer_size.max(1), inner_size, @@ -124,8 +124,8 @@ impl CumulativeOps for CudaClient { &self.stream, self.device.index, a.dtype(), - a_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + out.ptr(), scan_size, outer, )?; @@ -138,8 +138,8 @@ impl CumulativeOps for CudaClient { &self.stream, self.device.index, a.dtype(), - a_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + out.ptr(), scan_size, outer_size.max(1), inner_size, @@ -256,8 +256,8 @@ impl CumulativeOps for CudaClient { &self.stream, self.device.index, a_compute.dtype(), - a_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + out.ptr(), reduce_size, outer, )?; @@ -270,8 +270,8 @@ impl CumulativeOps for CudaClient { &self.stream, self.device.index, a_compute.dtype(), - a_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + out.ptr(), reduce_size, outer_size.max(1), inner_size, diff --git a/src/ops/cuda/distance.rs b/src/ops/cuda/distance.rs index e97ec611..dc398be4 100644 --- a/src/ops/cuda/distance.rs +++ b/src/ops/cuda/distance.rs @@ -46,9 +46,9 @@ impl DistanceOps for CudaClient { &self.stream, self.device.index, dtype, - x.storage().ptr(), - y.storage().ptr(), - out.storage().ptr(), + x.ptr(), + y.ptr(), + out.ptr(), n, m, d, @@ -91,8 +91,8 @@ impl DistanceOps for CudaClient { &self.stream, self.device.index, dtype, - x.storage().ptr(), - out.storage().ptr(), + x.ptr(), + out.ptr(), n, d, metric, @@ -131,8 +131,8 @@ impl DistanceOps for CudaClient { &self.stream, self.device.index, dtype, - condensed.storage().ptr(), - out.storage().ptr(), + condensed.ptr(), + out.ptr(), n, )?; } @@ -168,8 +168,8 @@ impl DistanceOps for CudaClient { &self.stream, self.device.index, dtype, - square.storage().ptr(), - out.storage().ptr(), + square.ptr(), + out.ptr(), n, )?; } diff --git a/src/ops/cuda/fp8_matmul.rs b/src/ops/cuda/fp8_matmul.rs new file mode 100644 index 00000000..8a8fbc1a --- /dev/null +++ b/src/ops/cuda/fp8_matmul.rs @@ -0,0 +1,185 @@ +//! CUDA implementation of FP8 matrix multiplication operations. + +use crate::dtype::DType; +use crate::error::{Error, Result}; +use crate::ops::{Fp8MatmulOps, matmul_output_shape}; +use crate::runtime::cuda::kernels::{ + launch_fp8_matmul_e4m3, launch_fp8_matmul_e4m3_batched, launch_fp8_matmul_e5m2, + launch_fp8_matmul_e5m2_batched, +}; +use crate::runtime::cuda::{CudaClient, CudaRuntime}; +use crate::runtime::ensure_contiguous; +use crate::tensor::Tensor; + +/// Validate FP8 matmul inputs and extract dimensions. +fn validate_and_extract( + a: &Tensor, + b: &Tensor, + expected_a_dtype: DType, + expected_b_dtype: DType, + out_dtype: DType, +) -> Result<(Vec, usize, usize, usize, usize)> { + if a.dtype() != expected_a_dtype { + return Err(Error::DTypeMismatch { + lhs: a.dtype(), + rhs: expected_a_dtype, + }); + } + if b.dtype() != expected_b_dtype { + return Err(Error::DTypeMismatch { + lhs: b.dtype(), + rhs: expected_b_dtype, + }); + } + match out_dtype { + DType::F32 | DType::F16 | DType::BF16 => {} + _ => { + return Err(Error::UnsupportedDType { + dtype: out_dtype, + op: "fp8_matmul output", + }); + } + } + + let a_shape = a.shape(); + let b_shape = b.shape(); + if a_shape.len() < 2 || b_shape.len() < 2 { + return Err(Error::ShapeMismatch { + expected: a_shape.to_vec(), + got: b_shape.to_vec(), + }); + } + + let m = a_shape[a_shape.len() - 2]; + let k = a_shape[a_shape.len() - 1]; + let k_b = b_shape[b_shape.len() - 2]; + let n = b_shape[b_shape.len() - 1]; + + if k != k_b { + return Err(Error::ShapeMismatch { + expected: a_shape.to_vec(), + got: b_shape.to_vec(), + }); + } + + let out_shape = matmul_output_shape(a_shape, b_shape).ok_or(Error::ShapeMismatch { + expected: a_shape.to_vec(), + got: b_shape.to_vec(), + })?; + + let batch_size: usize = out_shape + .iter() + .take(out_shape.len().saturating_sub(2)) + .product(); + let batch_size = batch_size.max(1); + + Ok((out_shape, batch_size, m, k, n)) +} + +impl Fp8MatmulOps for CudaClient { + fn fp8_matmul( + &self, + a: &Tensor, + b: &Tensor, + scale_a: f32, + scale_b: f32, + out_dtype: DType, + ) -> Result> { + let (out_shape, batch_size, m, k, n) = + validate_and_extract(a, b, DType::FP8E4M3, DType::FP8E4M3, out_dtype)?; + + let a_contig = ensure_contiguous(a); + let b_contig = ensure_contiguous(b); + let out = Tensor::::empty(&out_shape, out_dtype, &self.device); + + unsafe { + if batch_size > 1 { + launch_fp8_matmul_e4m3_batched( + &self.context, + &self.stream, + self.device.index, + out_dtype, + a_contig.ptr(), + b_contig.ptr(), + out.ptr(), + scale_a, + scale_b, + batch_size, + m, + n, + k, + )?; + } else { + launch_fp8_matmul_e4m3( + &self.context, + &self.stream, + self.device.index, + out_dtype, + a_contig.ptr(), + b_contig.ptr(), + out.ptr(), + scale_a, + scale_b, + m, + n, + k, + )?; + } + } + + Ok(out) + } + + fn fp8_matmul_e5m2( + &self, + a: &Tensor, + b: &Tensor, + scale_a: f32, + scale_b: f32, + out_dtype: DType, + ) -> Result> { + let (out_shape, batch_size, m, k, n) = + validate_and_extract(a, b, DType::FP8E5M2, DType::FP8E4M3, out_dtype)?; + + let a_contig = ensure_contiguous(a); + let b_contig = ensure_contiguous(b); + let out = Tensor::::empty(&out_shape, out_dtype, &self.device); + + unsafe { + if batch_size > 1 { + launch_fp8_matmul_e5m2_batched( + &self.context, + &self.stream, + self.device.index, + out_dtype, + a_contig.ptr(), + b_contig.ptr(), + out.ptr(), + scale_a, + scale_b, + batch_size, + m, + n, + k, + )?; + } else { + launch_fp8_matmul_e5m2( + &self.context, + &self.stream, + self.device.index, + out_dtype, + a_contig.ptr(), + b_contig.ptr(), + out.ptr(), + scale_a, + scale_b, + m, + n, + k, + )?; + } + } + + Ok(out) + } +} diff --git a/src/ops/cuda/gemm_epilogue.rs b/src/ops/cuda/gemm_epilogue.rs new file mode 100644 index 00000000..d45e4f98 --- /dev/null +++ b/src/ops/cuda/gemm_epilogue.rs @@ -0,0 +1,308 @@ +//! CUDA implementation of GEMM epilogue operations. + +use crate::dtype::DType; +use crate::error::{Error, Result}; +use crate::ops::{ + GemmActivation, GemmEpilogueOps, TypeConversionOps, matmul_bias_output_shape, + validate_matmul_bias_dtypes, +}; +use crate::runtime::cuda::kernels::{ + launch_gemm_bias_act_batched_kernel, launch_gemm_bias_act_bwd_batched_kernel, + launch_gemm_bias_act_bwd_kernel, launch_gemm_bias_act_kernel, + launch_gemm_bias_residual_batched_kernel, launch_gemm_bias_residual_kernel, +}; +use crate::runtime::cuda::{CudaClient, CudaRuntime}; +use crate::runtime::ensure_contiguous; +use crate::tensor::Tensor; + +impl GemmEpilogueOps for CudaClient { + fn matmul_bias_activation( + &self, + a: &Tensor, + b: &Tensor, + bias: &Tensor, + activation: GemmActivation, + ) -> Result> { + let dtype = validate_matmul_bias_dtypes(a.dtype(), b.dtype(), bias.dtype())?; + + // FP8: compute in F32 (tiled GEMM with shared memory needs native arithmetic) + #[cfg(feature = "fp8")] + if dtype == DType::FP8E4M3 || dtype == DType::FP8E5M2 { + let a_f32 = self.cast(a, DType::F32)?; + let b_f32 = self.cast(b, DType::F32)?; + let bias_f32 = self.cast(bias, DType::F32)?; + let result = self.matmul_bias_activation(&a_f32, &b_f32, &bias_f32, activation)?; + return self.cast(&result, dtype); + } + + if bias.shape().len() != 1 { + return Err(Error::InvalidArgument { + arg: "bias", + reason: format!("bias must be 1D tensor, got shape {:?}", bias.shape()), + }); + } + + let a_shape = a.shape(); + let b_shape = b.shape(); + let m = if a_shape.len() >= 2 { + a_shape[a_shape.len() - 2] + } else { + 1 + }; + let k = a_shape[a_shape.len() - 1]; + let n = b_shape[b_shape.len() - 1]; + + let out_shape = matmul_bias_output_shape(a_shape, b_shape, bias.shape()).ok_or( + Error::ShapeMismatch { + expected: a_shape.to_vec(), + got: b_shape.to_vec(), + }, + )?; + + let batch_size: usize = out_shape + .iter() + .take(out_shape.len().saturating_sub(2)) + .product::() + .max(1); + + let a_contig = ensure_contiguous(a); + let b_contig = ensure_contiguous(b); + let bias_contig = ensure_contiguous(bias); + + let out = Tensor::::empty(&out_shape, dtype, &self.device); + + unsafe { + if batch_size > 1 { + launch_gemm_bias_act_batched_kernel( + &self.context, + &self.stream, + self.device.index, + dtype, + a_contig.ptr(), + b_contig.ptr(), + bias_contig.ptr(), + out.ptr(), + batch_size, + m, + n, + k, + activation, + )?; + } else { + launch_gemm_bias_act_kernel( + &self.context, + &self.stream, + self.device.index, + dtype, + a_contig.ptr(), + b_contig.ptr(), + bias_contig.ptr(), + out.ptr(), + m, + n, + k, + activation, + )?; + } + } + + Ok(out) + } + + fn matmul_bias_residual( + &self, + a: &Tensor, + b: &Tensor, + bias: &Tensor, + residual: &Tensor, + ) -> Result> { + let dtype = validate_matmul_bias_dtypes(a.dtype(), b.dtype(), bias.dtype())?; + + // FP8: compute in F32 + #[cfg(feature = "fp8")] + if dtype == DType::FP8E4M3 || dtype == DType::FP8E5M2 { + let a_f32 = self.cast(a, DType::F32)?; + let b_f32 = self.cast(b, DType::F32)?; + let bias_f32 = self.cast(bias, DType::F32)?; + let res_f32 = self.cast(residual, DType::F32)?; + let result = self.matmul_bias_residual(&a_f32, &b_f32, &bias_f32, &res_f32)?; + return self.cast(&result, dtype); + } + + if residual.dtype() != dtype { + return Err(Error::DTypeMismatch { + lhs: dtype, + rhs: residual.dtype(), + }); + } + + let a_shape = a.shape(); + let b_shape = b.shape(); + + let out_shape = matmul_bias_output_shape(a_shape, b_shape, bias.shape()).ok_or( + Error::ShapeMismatch { + expected: a_shape.to_vec(), + got: b_shape.to_vec(), + }, + )?; + + if residual.shape() != out_shape.as_slice() { + return Err(Error::ShapeMismatch { + expected: out_shape.clone(), + got: residual.shape().to_vec(), + }); + } + + let m = if a_shape.len() >= 2 { + a_shape[a_shape.len() - 2] + } else { + 1 + }; + let k = a_shape[a_shape.len() - 1]; + let n = b_shape[b_shape.len() - 1]; + + let batch_size: usize = out_shape + .iter() + .take(out_shape.len().saturating_sub(2)) + .product::() + .max(1); + + let a_contig = ensure_contiguous(a); + let b_contig = ensure_contiguous(b); + let bias_contig = ensure_contiguous(bias); + let res_contig = ensure_contiguous(residual); + + let out = Tensor::::empty(&out_shape, dtype, &self.device); + + unsafe { + if batch_size > 1 { + launch_gemm_bias_residual_batched_kernel( + &self.context, + &self.stream, + self.device.index, + dtype, + a_contig.ptr(), + b_contig.ptr(), + bias_contig.ptr(), + res_contig.ptr(), + out.ptr(), + batch_size, + m, + n, + k, + )?; + } else { + launch_gemm_bias_residual_kernel( + &self.context, + &self.stream, + self.device.index, + dtype, + a_contig.ptr(), + b_contig.ptr(), + bias_contig.ptr(), + res_contig.ptr(), + out.ptr(), + m, + n, + k, + )?; + } + } + + Ok(out) + } + + fn matmul_bias_activation_bwd( + &self, + grad: &Tensor, + a: &Tensor, + b: &Tensor, + bias: &Tensor, + activation: GemmActivation, + ) -> Result<( + Tensor, + Tensor, + Tensor, + )> { + let dtype = validate_matmul_bias_dtypes(a.dtype(), b.dtype(), bias.dtype())?; + if grad.dtype() != dtype { + return Err(Error::DTypeMismatch { + lhs: dtype, + rhs: grad.dtype(), + }); + } + + let a_shape = a.shape(); + let b_shape = b.shape(); + let m = if a_shape.len() >= 2 { + a_shape[a_shape.len() - 2] + } else { + 1 + }; + let k = a_shape[a_shape.len() - 1]; + let n = b_shape[b_shape.len() - 1]; + + let batch_size: usize = a_shape + .iter() + .take(a_shape.len().saturating_sub(2)) + .product::() + .max(1); + + let a_contig = ensure_contiguous(a); + let b_contig = ensure_contiguous(b); + let bias_contig = ensure_contiguous(bias); + let grad_contig = ensure_contiguous(grad); + + let d_a = Tensor::::empty(a_shape, dtype, &self.device); + let d_b = Tensor::::zeros(b_shape, dtype, &self.device); + let d_bias = Tensor::::zeros(&[n], dtype, &self.device); + + // Temporary buffer for grad_pre (M * N elements, reused per batch) + let grad_pre = Tensor::::empty(&[m, n], dtype, &self.device); + + unsafe { + if batch_size > 1 { + launch_gemm_bias_act_bwd_batched_kernel( + &self.context, + &self.stream, + self.device.index, + dtype, + grad_contig.ptr(), + a_contig.ptr(), + b_contig.ptr(), + bias_contig.ptr(), + grad_pre.ptr(), + d_a.ptr(), + d_b.ptr(), + d_bias.ptr(), + batch_size, + m, + n, + k, + activation, + )?; + } else { + launch_gemm_bias_act_bwd_kernel( + &self.context, + &self.stream, + self.device.index, + dtype, + grad_contig.ptr(), + a_contig.ptr(), + b_contig.ptr(), + bias_contig.ptr(), + grad_pre.ptr(), + d_a.ptr(), + d_b.ptr(), + d_bias.ptr(), + m, + n, + k, + activation, + )?; + } + } + + Ok((d_a, d_b, d_bias)) + } +} diff --git a/src/ops/cuda/indexing/advanced.rs b/src/ops/cuda/indexing/advanced.rs index e1781856..d7200883 100644 --- a/src/ops/cuda/indexing/advanced.rs +++ b/src/ops/cuda/indexing/advanced.rs @@ -52,9 +52,9 @@ pub fn embedding_lookup( &client.stream, client.device.index, dtype, - emb_contig.storage().ptr(), - idx_contig.storage().ptr(), - out.storage().ptr(), + emb_contig.ptr(), + idx_contig.ptr(), + out.ptr(), num_indices, vocab_size, embedding_dim, @@ -152,8 +152,8 @@ pub fn scatter_reduce( &client.stream, client.device.index, dtype, - dst_contig.storage().ptr(), - out.storage().ptr(), + dst_contig.ptr(), + out.ptr(), dst.numel(), )?; } @@ -172,7 +172,7 @@ pub fn scatter_reduce( client.device.index, dtype, identity, - out.storage().ptr(), + out.ptr(), dst.numel(), )?; } @@ -190,9 +190,9 @@ pub fn scatter_reduce( &client.stream, client.device.index, dtype, - src_contig.storage().ptr(), - index_contig.storage().ptr(), - out.storage().ptr(), + src_contig.ptr(), + index_contig.ptr(), + out.ptr(), dim, outer_size, dim_size, @@ -221,7 +221,7 @@ pub fn scatter_reduce( client.device.index, dtype, 0.0, - count.storage().ptr(), + count.ptr(), dst.numel(), )?; } @@ -235,7 +235,7 @@ pub fn scatter_reduce( client.device.index, dtype, 1.0, - count.storage().ptr(), + count.ptr(), dst.numel(), )?; } @@ -248,8 +248,8 @@ pub fn scatter_reduce( &client.stream, client.device.index, dtype, - index_contig.storage().ptr(), - count.storage().ptr(), + index_contig.ptr(), + count.ptr(), dim, outer_size, dim_size, @@ -266,9 +266,9 @@ pub fn scatter_reduce( &client.stream, client.device.index, dtype, - out.storage().ptr(), - count.storage().ptr(), - result.storage().ptr(), + out.ptr(), + count.ptr(), + result.ptr(), dst.numel(), )?; } @@ -361,9 +361,9 @@ pub fn gather_nd( &client.stream, client.device.index, dtype, - input_contig.storage().ptr(), - indices_contig.storage().ptr(), - out.storage().ptr(), + input_contig.ptr(), + indices_contig.ptr(), + out.ptr(), shape_ptr, strides_ptr, num_slices, @@ -449,13 +449,13 @@ pub fn bincount( client.device.index, out_dtype, 0.0, - out.storage().ptr(), + out.ptr(), output_len, )?; } let weights_contig = weights.map(ensure_contiguous); - let weights_ptr = weights_contig.as_ref().map(|w| w.storage().ptr()); + let weights_ptr = weights_contig.as_ref().map(|w| w.ptr()); unsafe { launch_bincount_weighted( @@ -464,9 +464,9 @@ pub fn bincount( client.device.index, input_dtype, weights_dtype, - input_contig.storage().ptr(), + input_contig.ptr(), weights_ptr, - out.storage().ptr(), + out.ptr(), numel, output_len, )?; diff --git a/src/ops/cuda/indexing/argmax.rs b/src/ops/cuda/indexing/argmax.rs index dd72601c..29a30947 100644 --- a/src/ops/cuda/indexing/argmax.rs +++ b/src/ops/cuda/indexing/argmax.rs @@ -39,8 +39,8 @@ pub fn argmax( &client.stream, client.device.index, dtype, - a_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + out.ptr(), outer_size, reduce_size, inner_size, @@ -81,8 +81,8 @@ pub fn argmin( &client.stream, client.device.index, dtype, - a_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + out.ptr(), outer_size, reduce_size, inner_size, diff --git a/src/ops/cuda/indexing/gather_scatter.rs b/src/ops/cuda/indexing/gather_scatter.rs index c4be89a7..a377f7c3 100644 --- a/src/ops/cuda/indexing/gather_scatter.rs +++ b/src/ops/cuda/indexing/gather_scatter.rs @@ -4,7 +4,7 @@ use crate::dtype::DType; use crate::error::{Error, Result}; use crate::runtime::cuda::kernels::{ launch_copy, launch_fill_with_f64, launch_gather, launch_gather_2d, launch_index_put, - launch_index_select, launch_scatter, launch_validate_indices, + launch_index_select, launch_scatter, launch_slice_assign, launch_validate_indices, }; use crate::runtime::cuda::{CudaClient, CudaRuntime}; use crate::runtime::{Runtime, compute_contiguous_strides, ensure_contiguous}; @@ -82,9 +82,9 @@ pub fn gather( &client.stream, client.device.index, dtype, - a_contig.storage().ptr(), - index_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + index_contig.ptr(), + out.ptr(), ndim, dim, input_shape_ptr, @@ -155,8 +155,8 @@ pub fn scatter( &client.stream, client.device.index, dtype, - a_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + out.ptr(), a.numel(), )?; } @@ -208,10 +208,10 @@ pub fn scatter( &client.stream, client.device.index, dtype, - a_contig.storage().ptr(), - index_contig.storage().ptr(), - src_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + index_contig.ptr(), + src_contig.ptr(), + out.ptr(), ndim, dim, output_shape_ptr, @@ -281,7 +281,7 @@ pub fn index_select( client.device.index, DType::U32, 0.0, - error_count_tensor.storage().ptr(), + error_count_tensor.ptr(), 1, )?; @@ -290,8 +290,8 @@ pub fn index_select( &client.context, &client.stream, client.device.index, - index_contig.storage().ptr(), - error_count_tensor.storage().ptr(), + index_contig.ptr(), + error_count_tensor.ptr(), index_len, dim_size, )?; @@ -321,9 +321,9 @@ pub fn index_select( &client.stream, client.device.index, dtype, - a_contig.storage().ptr(), - index_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + index_contig.ptr(), + out.ptr(), outer_size, dim_size, inner_size, @@ -397,10 +397,10 @@ pub fn gather_2d( &client.stream, client.device.index, dtype, - input_contig.storage().ptr(), - rows_contig.storage().ptr(), - cols_contig.storage().ptr(), - out.storage().ptr(), + input_contig.ptr(), + rows_contig.ptr(), + cols_contig.ptr(), + out.ptr(), nrows, ncols, num_indices, @@ -477,7 +477,7 @@ pub fn index_put( client.device.index, DType::U32, 0.0, - error_count_tensor.storage().ptr(), + error_count_tensor.ptr(), 1, )?; @@ -486,8 +486,8 @@ pub fn index_put( &client.context, &client.stream, client.device.index, - index_contig.storage().ptr(), - error_count_tensor.storage().ptr(), + index_contig.ptr(), + error_count_tensor.ptr(), index_len, dim_size, )?; @@ -518,9 +518,9 @@ pub fn index_put( &client.stream, client.device.index, dtype, - index_contig.storage().ptr(), - src_contig.storage().ptr(), - out.storage().ptr(), + index_contig.ptr(), + src_contig.ptr(), + out.ptr(), outer_size, dim_size, inner_size, @@ -530,3 +530,95 @@ pub fn index_put( Ok(out) } + +/// Execute slice_assign operation: assign src into a slice of dst along dim. +pub fn slice_assign( + client: &CudaClient, + dst: &Tensor, + src: &Tensor, + dim: usize, + start: usize, +) -> Result> { + let ndim = dst.ndim(); + if dim >= ndim { + return Err(Error::InvalidDimension { + dim: dim as isize, + ndim, + }); + } + + if src.ndim() != ndim { + return Err(Error::ShapeMismatch { + expected: dst.shape().to_vec(), + got: src.shape().to_vec(), + }); + } + for d in 0..ndim { + if d != dim && src.shape()[d] != dst.shape()[d] { + return Err(Error::ShapeMismatch { + expected: dst.shape().to_vec(), + got: src.shape().to_vec(), + }); + } + } + + let src_dim_size = src.shape()[dim]; + let dst_dim_size = dst.shape()[dim]; + if start + src_dim_size > dst_dim_size { + return Err(Error::InvalidArgument { + arg: "start", + reason: format!( + "start ({}) + src dim size ({}) exceeds dst dim size ({})", + start, src_dim_size, dst_dim_size + ), + }); + } + + let dtype = dst.dtype(); + if src.dtype() != dtype { + return Err(Error::DTypeMismatch { + lhs: dtype, + rhs: src.dtype(), + }); + } + + let outer_size: usize = dst.shape()[..dim].iter().product(); + let outer_size = outer_size.max(1); + let inner_size: usize = dst.shape()[dim + 1..].iter().product(); + let inner_size = inner_size.max(1); + + let dst_contig = ensure_contiguous(dst); + let src_contig = ensure_contiguous(src); + + let out = Tensor::::empty(dst.shape(), dtype, &client.device); + + unsafe { + // Copy dst → output + launch_copy( + &client.context, + &client.stream, + client.device.index, + dtype, + dst_contig.ptr(), + out.ptr(), + dst_contig.numel(), + )?; + + // Overwrite the slice with src + launch_slice_assign( + &client.context, + &client.stream, + client.device.index, + dtype, + src_contig.ptr(), + out.ptr(), + outer_size, + dst_dim_size, + src_dim_size, + inner_size, + start, + )?; + } + + Ok(out) +} diff --git a/src/ops/cuda/indexing/helpers.rs b/src/ops/cuda/indexing/helpers.rs index 981533c8..2126b02a 100644 --- a/src/ops/cuda/indexing/helpers.rs +++ b/src/ops/cuda/indexing/helpers.rs @@ -112,10 +112,7 @@ impl BroadcastContext { self.needs_broadcast, "strides_ptr() called on non-broadcast context" ); - self.strides_tensor - .as_ref() - .map(|t| t.storage().ptr()) - .unwrap_or(0) + self.strides_tensor.as_ref().map(|t| t.ptr()).unwrap_or(0) } /// Get shape pointer. @@ -130,9 +127,6 @@ impl BroadcastContext { self.needs_broadcast, "shape_ptr() called on non-broadcast context" ); - self.shape_tensor - .as_ref() - .map(|t| t.storage().ptr()) - .unwrap_or(0) + self.shape_tensor.as_ref().map(|t| t.ptr()).unwrap_or(0) } } diff --git a/src/ops/cuda/indexing/masked.rs b/src/ops/cuda/indexing/masked.rs index 84a79640..b964f0c2 100644 --- a/src/ops/cuda/indexing/masked.rs +++ b/src/ops/cuda/indexing/masked.rs @@ -39,7 +39,7 @@ pub fn masked_select( &client.context, &client.stream, client.device.index, - mask_contig.storage().ptr(), + mask_contig.ptr(), count_ptr, bcast.strides_ptr(), bcast.shape_ptr(), @@ -53,7 +53,7 @@ pub fn masked_select( &client.context, &client.stream, client.device.index, - mask_contig.storage().ptr(), + mask_contig.ptr(), count_ptr, numel, )?; @@ -84,7 +84,7 @@ pub fn masked_select( &client.context, &client.stream, client.device.index, - mask_contig.storage().ptr(), + mask_contig.ptr(), prefix_sum_ptr, bcast.strides_ptr(), bcast.shape_ptr(), @@ -98,7 +98,7 @@ pub fn masked_select( &client.context, &client.stream, client.device.index, - mask_contig.storage().ptr(), + mask_contig.ptr(), prefix_sum_ptr, numel, )?; @@ -113,9 +113,9 @@ pub fn masked_select( &client.stream, client.device.index, dtype, - a_contig.storage().ptr(), - mask_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + mask_contig.ptr(), + out.ptr(), prefix_sum_ptr, bcast.strides_ptr(), bcast.shape_ptr(), @@ -130,9 +130,9 @@ pub fn masked_select( &client.stream, client.device.index, dtype, - a_contig.storage().ptr(), - mask_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + mask_contig.ptr(), + out.ptr(), prefix_sum_ptr, numel, )?; @@ -167,9 +167,9 @@ pub fn masked_fill( &client.stream, client.device.index, dtype, - a_contig.storage().ptr(), - mask_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + mask_contig.ptr(), + out.ptr(), value, bcast.strides_ptr(), bcast.shape_ptr(), @@ -184,9 +184,9 @@ pub fn masked_fill( &client.stream, client.device.index, dtype, - a_contig.storage().ptr(), - mask_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + mask_contig.ptr(), + out.ptr(), value, numel, )?; diff --git a/src/ops/cuda/indexing/mod.rs b/src/ops/cuda/indexing/mod.rs index 932219a9..86f03c91 100644 --- a/src/ops/cuda/indexing/mod.rs +++ b/src/ops/cuda/indexing/mod.rs @@ -130,4 +130,14 @@ impl IndexingOps for CudaClient { ) -> Result> { gather_scatter::gather_2d(self, input, rows, cols) } + + fn slice_assign( + &self, + dst: &Tensor, + src: &Tensor, + dim: usize, + start: usize, + ) -> Result> { + gather_scatter::slice_assign(self, dst, src, dim, start) + } } diff --git a/src/ops/cuda/logical.rs b/src/ops/cuda/logical.rs index 6c32cc9a..320eae52 100644 --- a/src/ops/cuda/logical.rs +++ b/src/ops/cuda/logical.rs @@ -49,9 +49,9 @@ impl LogicalOps for CudaClient { &self.context, &self.stream, self.device.index, - a_contig.storage().ptr(), - b_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + b_contig.ptr(), + out.ptr(), out.numel(), )?; } @@ -74,9 +74,9 @@ impl LogicalOps for CudaClient { &self.context, &self.stream, self.device.index, - a_contig.storage().ptr(), - b_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + b_contig.ptr(), + out.ptr(), out.numel(), )?; } @@ -99,9 +99,9 @@ impl LogicalOps for CudaClient { &self.context, &self.stream, self.device.index, - a_contig.storage().ptr(), - b_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + b_contig.ptr(), + out.ptr(), out.numel(), )?; } @@ -125,8 +125,8 @@ impl LogicalOps for CudaClient { &self.context, &self.stream, self.device.index, - a_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + out.ptr(), out.numel(), )?; } diff --git a/src/ops/cuda/mod.rs b/src/ops/cuda/mod.rs index 325e59de..f8c98a16 100644 --- a/src/ops/cuda/mod.rs +++ b/src/ops/cuda/mod.rs @@ -12,6 +12,9 @@ pub mod conv; pub mod cumulative; pub mod distance; pub mod einsum; +#[cfg(feature = "fp8")] +pub mod fp8_matmul; +pub mod gemm_epilogue; pub mod indexing; pub mod linalg; pub mod logical; @@ -24,6 +27,8 @@ pub mod random; pub mod reduce; pub mod scalar; pub mod shape; +#[cfg(feature = "sparse")] +pub mod sparse_24; pub mod sorting; pub mod statistics; pub mod type_conversion; diff --git a/src/ops/cuda/multivariate.rs b/src/ops/cuda/multivariate.rs index 1e4c5d6c..927d4262 100644 --- a/src/ops/cuda/multivariate.rs +++ b/src/ops/cuda/multivariate.rs @@ -78,9 +78,9 @@ impl MultinomialSamplingOps for CudaClient { let output = Tensor::::zeros(&[n_samples, k], dtype, &self.device); // Get device pointers - let cdf_ptr = cdf.storage().ptr(); - let uniforms_ptr = uniforms.storage().ptr(); - let output_ptr = output.storage().ptr(); + let cdf_ptr = cdf.ptr(); + let uniforms_ptr = uniforms.ptr(); + let output_ptr = output.ptr(); // Launch kernel unsafe { diff --git a/src/ops/cuda/normalization.rs b/src/ops/cuda/normalization.rs index 29bcc3c3..7f6d9f2a 100644 --- a/src/ops/cuda/normalization.rs +++ b/src/ops/cuda/normalization.rs @@ -1,7 +1,11 @@ //! Normalization operations for CUDA runtime +use crate::dtype::DType; use crate::error::{Error, Result}; -use crate::ops::NormalizationOps; -use crate::runtime::cuda::kernels::{launch_layer_norm, launch_rms_norm}; +use crate::ops::{NormalizationOps, TypeConversionOps}; +use crate::runtime::cuda::kernels::{ + launch_fused_add_layer_norm, launch_fused_add_layer_norm_bwd, launch_fused_add_rms_norm, + launch_fused_add_rms_norm_bwd, launch_group_norm, launch_layer_norm, launch_rms_norm, +}; use crate::runtime::cuda::{CudaClient, CudaRuntime}; use crate::runtime::ensure_contiguous; use crate::tensor::Tensor; @@ -47,9 +51,9 @@ impl NormalizationOps for CudaClient { &self.stream, self.device.index, dtype, - input_contig.storage().ptr(), - weight_contig.storage().ptr(), - out.storage().ptr(), + input_contig.ptr(), + weight_contig.ptr(), + out.ptr(), batch_size, hidden_size, eps, @@ -111,10 +115,10 @@ impl NormalizationOps for CudaClient { &self.stream, self.device.index, dtype, - input_contig.storage().ptr(), - weight_contig.storage().ptr(), - bias_contig.storage().ptr(), - out.storage().ptr(), + input_contig.ptr(), + weight_contig.ptr(), + bias_contig.ptr(), + out.ptr(), batch_size, hidden_size, eps, @@ -123,4 +127,399 @@ impl NormalizationOps for CudaClient { Ok(out) } + + fn group_norm( + &self, + input: &Tensor, + weight: &Tensor, + bias: &Tensor, + num_groups: usize, + eps: f32, + ) -> Result> { + let dtype = input.dtype(); + + if weight.dtype() != dtype || bias.dtype() != dtype { + return Err(Error::DTypeMismatch { + lhs: dtype, + rhs: if weight.dtype() != dtype { + weight.dtype() + } else { + bias.dtype() + }, + }); + } + + let shape = input.shape(); + if shape.len() < 2 { + return Err(Error::InvalidArgument { + arg: "input", + reason: "group_norm requires at least 2D input [batch, channels, ...]".into(), + }); + } + + let batch = shape[0]; + let channels = shape[1]; + if !channels.is_multiple_of(num_groups) { + return Err(Error::InvalidArgument { + arg: "num_groups", + reason: format!("channels {channels} not divisible by num_groups {num_groups}"), + }); + } + let channels_per_group = channels / num_groups; + let spatial: usize = shape[2..].iter().product::().max(1); + + if weight.shape() != [channels] || bias.shape() != [channels] { + return Err(Error::ShapeMismatch { + expected: vec![channels], + got: if weight.shape() != [channels] { + weight.shape().to_vec() + } else { + bias.shape().to_vec() + }, + }); + } + + let input_contig = ensure_contiguous(input); + let weight_contig = ensure_contiguous(weight); + let bias_contig = ensure_contiguous(bias); + let out = Tensor::::empty(shape, dtype, &self.device); + + unsafe { + launch_group_norm( + &self.context, + &self.stream, + self.device.index, + dtype, + input_contig.ptr(), + weight_contig.ptr(), + bias_contig.ptr(), + out.ptr(), + batch, + channels, + spatial, + num_groups, + channels_per_group, + eps, + )?; + } + + Ok(out) + } + + fn fused_add_rms_norm( + &self, + x: &Tensor, + residual: &Tensor, + weight: &Tensor, + eps: f32, + ) -> Result<(Tensor, Tensor)> { + let dtype = x.dtype(); + + // Validate dtypes match + if residual.dtype() != dtype || weight.dtype() != dtype { + return Err(Error::DTypeMismatch { + lhs: dtype, + rhs: if residual.dtype() != dtype { + residual.dtype() + } else { + weight.dtype() + }, + }); + } + + // Weight must be 1D with size matching input's last dimension + let x_shape = x.shape(); + let hidden_size = x_shape[x_shape.len() - 1]; + if weight.shape() != [hidden_size] { + return Err(Error::ShapeMismatch { + expected: vec![hidden_size], + got: weight.shape().to_vec(), + }); + } + + // Residual must match x shape + if residual.shape() != x_shape { + return Err(Error::ShapeMismatch { + expected: x_shape.to_vec(), + got: residual.shape().to_vec(), + }); + } + + // Compute batch_size as product of all dimensions except last + let batch_size: usize = x_shape[..x_shape.len() - 1].iter().product(); + let batch_size = batch_size.max(1); + + let x_contig = ensure_contiguous(x); + let residual_contig = ensure_contiguous(residual); + let weight_contig = ensure_contiguous(weight); + let output = Tensor::::empty(x_shape, dtype, &self.device); + let pre_norm = Tensor::::empty(x_shape, dtype, &self.device); + + unsafe { + launch_fused_add_rms_norm( + &self.context, + &self.stream, + self.device.index, + dtype, + x_contig.ptr(), + residual_contig.ptr(), + weight_contig.ptr(), + output.ptr(), + pre_norm.ptr(), + batch_size, + hidden_size, + eps, + )?; + } + + Ok((output, pre_norm)) + } + + fn fused_add_rms_norm_bwd( + &self, + grad: &Tensor, + pre_norm: &Tensor, + weight: &Tensor, + eps: f32, + ) -> Result<(Tensor, Tensor)> { + let dtype = grad.dtype(); + + // Validate dtypes match + if pre_norm.dtype() != dtype || weight.dtype() != dtype { + return Err(Error::DTypeMismatch { + lhs: dtype, + rhs: if pre_norm.dtype() != dtype { + pre_norm.dtype() + } else { + weight.dtype() + }, + }); + } + + // Shapes must match + let grad_shape = grad.shape(); + if pre_norm.shape() != grad_shape { + return Err(Error::ShapeMismatch { + expected: grad_shape.to_vec(), + got: pre_norm.shape().to_vec(), + }); + } + + let hidden_size = grad_shape[grad_shape.len() - 1]; + if weight.shape() != [hidden_size] { + return Err(Error::ShapeMismatch { + expected: vec![hidden_size], + got: weight.shape().to_vec(), + }); + } + + let batch_size: usize = grad_shape[..grad_shape.len() - 1].iter().product(); + let batch_size = batch_size.max(1); + + let grad_contig = ensure_contiguous(grad); + let pre_norm_contig = ensure_contiguous(pre_norm); + let weight_contig = ensure_contiguous(weight); + let d_input_residual = Tensor::::empty(grad_shape, dtype, &self.device); + let d_weight = Tensor::::zeros(&[hidden_size], dtype, &self.device); + + unsafe { + launch_fused_add_rms_norm_bwd( + &self.context, + &self.stream, + self.device.index, + dtype, + grad_contig.ptr(), + pre_norm_contig.ptr(), + weight_contig.ptr(), + d_input_residual.ptr(), + d_weight.ptr(), + batch_size, + hidden_size, + eps, + )?; + } + + Ok((d_input_residual, d_weight)) + } + + fn fused_add_layer_norm( + &self, + x: &Tensor, + residual: &Tensor, + weight: &Tensor, + bias: &Tensor, + eps: f32, + ) -> Result<(Tensor, Tensor)> { + let dtype = x.dtype(); + + // Validate dtypes match + if residual.dtype() != dtype || weight.dtype() != dtype || bias.dtype() != dtype { + return Err(Error::DTypeMismatch { + lhs: dtype, + rhs: if residual.dtype() != dtype { + residual.dtype() + } else if weight.dtype() != dtype { + weight.dtype() + } else { + bias.dtype() + }, + }); + } + + // Weight and bias must be 1D with size matching input's last dimension + let x_shape = x.shape(); + let hidden_size = x_shape[x_shape.len() - 1]; + if weight.shape() != [hidden_size] { + return Err(Error::ShapeMismatch { + expected: vec![hidden_size], + got: weight.shape().to_vec(), + }); + } + if bias.shape() != [hidden_size] { + return Err(Error::ShapeMismatch { + expected: vec![hidden_size], + got: bias.shape().to_vec(), + }); + } + + // Residual must match x shape + if residual.shape() != x_shape { + return Err(Error::ShapeMismatch { + expected: x_shape.to_vec(), + got: residual.shape().to_vec(), + }); + } + + let batch_size: usize = x_shape[..x_shape.len() - 1].iter().product(); + let batch_size = batch_size.max(1); + + let x_contig = ensure_contiguous(x); + let residual_contig = ensure_contiguous(residual); + let weight_contig = ensure_contiguous(weight); + let bias_contig = ensure_contiguous(bias); + let output = Tensor::::empty(x_shape, dtype, &self.device); + let pre_norm = Tensor::::empty(x_shape, dtype, &self.device); + + unsafe { + launch_fused_add_layer_norm( + &self.context, + &self.stream, + self.device.index, + dtype, + x_contig.ptr(), + residual_contig.ptr(), + weight_contig.ptr(), + bias_contig.ptr(), + output.ptr(), + pre_norm.ptr(), + batch_size, + hidden_size, + eps, + )?; + } + + Ok((output, pre_norm)) + } + + fn fused_add_layer_norm_bwd( + &self, + grad: &Tensor, + pre_norm: &Tensor, + weight: &Tensor, + bias: &Tensor, + eps: f32, + ) -> Result<( + Tensor, + Tensor, + Tensor, + )> { + let dtype = grad.dtype(); + + // Validate dtypes match + if pre_norm.dtype() != dtype || weight.dtype() != dtype || bias.dtype() != dtype { + return Err(Error::DTypeMismatch { + lhs: dtype, + rhs: if pre_norm.dtype() != dtype { + pre_norm.dtype() + } else if weight.dtype() != dtype { + weight.dtype() + } else { + bias.dtype() + }, + }); + } + + // Shapes must match + let grad_shape = grad.shape(); + if pre_norm.shape() != grad_shape { + return Err(Error::ShapeMismatch { + expected: grad_shape.to_vec(), + got: pre_norm.shape().to_vec(), + }); + } + + let hidden_size = grad_shape[grad_shape.len() - 1]; + if weight.shape() != [hidden_size] || bias.shape() != [hidden_size] { + return Err(Error::ShapeMismatch { + expected: vec![hidden_size], + got: if weight.shape() != [hidden_size] { + weight.shape().to_vec() + } else { + bias.shape().to_vec() + }, + }); + } + + let batch_size: usize = grad_shape[..grad_shape.len() - 1].iter().product(); + let batch_size = batch_size.max(1); + + // FP8: compute backward in F32, then cast results back (FP8 precision too low for + // multi-pass backward with atomicAdd accumulation) + #[cfg(feature = "fp8")] + if dtype == DType::FP8E4M3 || dtype == DType::FP8E5M2 { + let grad_f32 = self.cast(grad, DType::F32)?; + let pre_norm_f32 = self.cast(pre_norm, DType::F32)?; + let weight_f32 = self.cast(weight, DType::F32)?; + let bias_f32 = self.cast(bias, DType::F32)?; + let (d_ir, d_w, d_b) = self.fused_add_layer_norm_bwd( + &grad_f32, + &pre_norm_f32, + &weight_f32, + &bias_f32, + eps, + )?; + return Ok(( + self.cast(&d_ir, dtype)?, + self.cast(&d_w, dtype)?, + self.cast(&d_b, dtype)?, + )); + } + + let grad_contig = ensure_contiguous(grad); + let pre_norm_contig = ensure_contiguous(pre_norm); + let weight_contig = ensure_contiguous(weight); + let d_input_residual = Tensor::::empty(grad_shape, dtype, &self.device); + let d_weight = Tensor::::zeros(&[hidden_size], dtype, &self.device); + let d_bias = Tensor::::zeros(&[hidden_size], dtype, &self.device); + + unsafe { + launch_fused_add_layer_norm_bwd( + &self.context, + &self.stream, + self.device.index, + dtype, + grad_contig.ptr(), + pre_norm_contig.ptr(), + weight_contig.ptr(), + d_input_residual.ptr(), + d_weight.ptr(), + d_bias.ptr(), + batch_size, + hidden_size, + eps, + )?; + } + + Ok((d_input_residual, d_weight, d_bias)) + } } diff --git a/src/ops/cuda/quasirandom.rs b/src/ops/cuda/quasirandom.rs index 41d5bef2..07b1a23d 100644 --- a/src/ops/cuda/quasirandom.rs +++ b/src/ops/cuda/quasirandom.rs @@ -41,7 +41,7 @@ impl QuasiRandomOps for CudaClient { &self.stream, self.device.index, &self.device, - out.storage().ptr(), + out.ptr(), n_points, dimension, skip, @@ -53,7 +53,7 @@ impl QuasiRandomOps for CudaClient { &self.stream, self.device.index, &self.device, - out.storage().ptr(), + out.ptr(), n_points, dimension, skip, @@ -87,7 +87,7 @@ impl QuasiRandomOps for CudaClient { &self.context, &self.stream, self.device.index, - out.storage().ptr(), + out.ptr(), n_points, dimension, skip, @@ -98,7 +98,7 @@ impl QuasiRandomOps for CudaClient { &self.context, &self.stream, self.device.index, - out.storage().ptr(), + out.ptr(), n_points, dimension, skip, @@ -140,7 +140,7 @@ impl QuasiRandomOps for CudaClient { &self.context, &self.stream, self.device.index, - out.storage().ptr(), + out.ptr(), n_samples, dimension, seed, @@ -151,7 +151,7 @@ impl QuasiRandomOps for CudaClient { &self.context, &self.stream, self.device.index, - out.storage().ptr(), + out.ptr(), n_samples, dimension, seed, @@ -198,15 +198,20 @@ mod tests { use crate::runtime::Runtime; use crate::runtime::cuda::CudaDevice; - fn setup() -> (CudaDevice, CudaClient) { + fn setup() -> Option<(CudaDevice, CudaClient)> { + if !crate::runtime::cuda::is_cuda_available() { + return None; + } let device = CudaDevice::new(0); let client = CudaRuntime::default_client(&device); - (device, client) + Some((device, client)) } #[test] fn test_sobol_basic() { - let (_device, client) = setup(); + let Some((_device, client)) = setup() else { + return; + }; let points = client.sobol(10, 2, 0, DType::F32).unwrap(); assert_eq!(points.shape(), &[10, 2]); @@ -220,7 +225,9 @@ mod tests { #[test] fn test_halton_basic() { - let (_device, client) = setup(); + let Some((_device, client)) = setup() else { + return; + }; let points = client.halton(10, 3, 0, DType::F32).unwrap(); assert_eq!(points.shape(), &[10, 3]); @@ -234,7 +241,9 @@ mod tests { #[test] fn test_latin_hypercube_basic() { - let (_device, client) = setup(); + let Some((_device, client)) = setup() else { + return; + }; let samples = client.latin_hypercube(20, 4, DType::F32).unwrap(); assert_eq!(samples.shape(), &[20, 4]); @@ -248,7 +257,9 @@ mod tests { #[test] fn test_sobol_deterministic() { - let (_device, client) = setup(); + let Some((_device, client)) = setup() else { + return; + }; let points1 = client.sobol(5, 2, 0, DType::F32).unwrap(); let points2 = client.sobol(5, 2, 0, DType::F32).unwrap(); @@ -264,21 +275,27 @@ mod tests { #[test] fn test_error_zero_points() { - let (_device, client) = setup(); + let Some((_device, client)) = setup() else { + return; + }; let result = client.sobol(0, 2, 0, DType::F32); assert!(result.is_err()); } #[test] fn test_error_unsupported_dtype() { - let (_device, client) = setup(); + let Some((_device, client)) = setup() else { + return; + }; let result = client.sobol(10, 2, 0, DType::I32); assert!(result.is_err()); } #[test] fn test_sobol_dimension_limit() { - let (_device, client) = setup(); + let Some((_device, client)) = setup() else { + return; + }; // Should work up to 21,201 dimensions (full Joe & Kuo dataset) let result = client.sobol(10, 100, 0, DType::F32); @@ -294,7 +311,9 @@ mod tests { #[test] fn test_halton_dimension_limit() { - let (_device, client) = setup(); + let Some((_device, client)) = setup() else { + return; + }; // Should work up to 100 dimensions let result = client.halton(10, 100, 0, DType::F32); diff --git a/src/ops/cuda/random.rs b/src/ops/cuda/random.rs index cdb78edf..1ac2750b 100644 --- a/src/ops/cuda/random.rs +++ b/src/ops/cuda/random.rs @@ -2,7 +2,8 @@ use crate::dtype::DType; use crate::error::{Error, Result}; use crate::ops::RandomOps; -use crate::ops::TypeConversionOps; // Required for self.cast() method resolution +#[cfg(feature = "fp8")] +use crate::ops::TypeConversionOps; use crate::runtime::cuda::kernels::{ launch_bernoulli, launch_beta_dist, launch_binomial, launch_chi_squared, launch_exponential, launch_f_distribution, launch_gamma_dist, launch_laplace, launch_multinomial_with_replacement, @@ -48,7 +49,43 @@ impl RandomOps for CudaClient { self.device.index, dtype, seed, - out.storage().ptr(), + out.ptr(), + numel, + )?; + } + + Ok(out) + } + + fn rand_seeded(&self, shape: &[usize], dtype: DType, seed: u64) -> Result> { + #[cfg(feature = "fp8")] + if matches!(dtype, DType::FP8E4M3 | DType::FP8E5M2) { + let f32_result = self.rand_seeded(shape, DType::F32, seed)?; + return self.cast(&f32_result, dtype); + } + + if !matches!(dtype, DType::F32 | DType::F64 | DType::F16 | DType::BF16) { + return Err(Error::UnsupportedDType { + dtype, + op: "rand_seeded", + }); + } + + let numel: usize = shape.iter().product(); + if numel == 0 { + return Ok(Tensor::::empty(shape, dtype, &self.device)); + } + + let out = Tensor::::empty(shape, dtype, &self.device); + + unsafe { + launch_rand( + &self.context, + &self.stream, + self.device.index, + dtype, + seed, + out.ptr(), numel, )?; } @@ -89,7 +126,7 @@ impl RandomOps for CudaClient { self.device.index, dtype, seed, - out.storage().ptr(), + out.ptr(), numel, )?; } @@ -157,7 +194,7 @@ impl RandomOps for CudaClient { low, range, seed, - out.storage().ptr(), + out.ptr(), numel, )?; } @@ -243,8 +280,8 @@ impl RandomOps for CudaClient { &self.stream, self.device.index, dtype, - probs.storage().ptr(), - out.storage().ptr(), + probs.ptr(), + out.ptr(), seed, num_distributions, num_categories, @@ -256,8 +293,8 @@ impl RandomOps for CudaClient { &self.stream, self.device.index, dtype, - probs.storage().ptr(), - out.storage().ptr(), + probs.ptr(), + out.ptr(), seed, num_distributions, num_categories, @@ -299,7 +336,7 @@ impl RandomOps for CudaClient { dtype, p, seed, - out.storage().ptr(), + out.ptr(), numel, )?; } @@ -347,7 +384,7 @@ impl RandomOps for CudaClient { alpha, beta, seed, - out.storage().ptr(), + out.ptr(), numel, )?; } @@ -395,7 +432,7 @@ impl RandomOps for CudaClient { shape_param, scale, seed, - out.storage().ptr(), + out.ptr(), numel, )?; } @@ -433,7 +470,7 @@ impl RandomOps for CudaClient { dtype, rate, seed, - out.storage().ptr(), + out.ptr(), numel, )?; } @@ -471,7 +508,7 @@ impl RandomOps for CudaClient { dtype, lambda, seed, - out.storage().ptr(), + out.ptr(), numel, )?; } @@ -522,7 +559,7 @@ impl RandomOps for CudaClient { n, p, seed, - out.storage().ptr(), + out.ptr(), numel, )?; } @@ -567,7 +604,7 @@ impl RandomOps for CudaClient { loc, scale, seed, - out.storage().ptr(), + out.ptr(), numel, )?; } @@ -605,7 +642,7 @@ impl RandomOps for CudaClient { dtype, df, seed, - out.storage().ptr(), + out.ptr(), numel, )?; } @@ -643,7 +680,7 @@ impl RandomOps for CudaClient { dtype, df, seed, - out.storage().ptr(), + out.ptr(), numel, )?; } @@ -694,7 +731,7 @@ impl RandomOps for CudaClient { df1, df2, seed, - out.storage().ptr(), + out.ptr(), numel, )?; } diff --git a/src/ops/cuda/scalar.rs b/src/ops/cuda/scalar.rs index 2492a66b..0df23fc5 100644 --- a/src/ops/cuda/scalar.rs +++ b/src/ops/cuda/scalar.rs @@ -2,8 +2,10 @@ use crate::error::Result; use crate::ops::ScalarOps; +use crate::runtime::cuda::kernels::launch_fused_mul_add_scalar; use crate::runtime::cuda::ops::helpers::native_scalar_op; use crate::runtime::cuda::{CudaClient, CudaRuntime}; +use crate::runtime::ensure_contiguous; use crate::tensor::Tensor; impl ScalarOps for CudaClient { @@ -30,4 +32,31 @@ impl ScalarOps for CudaClient { fn rsub_scalar(&self, a: &Tensor, scalar: f64) -> Result> { native_scalar_op(self, a, "rsub_scalar", scalar) } + + fn fused_mul_add_scalar( + &self, + a: &Tensor, + scale: f64, + bias: f64, + ) -> Result> { + let dtype = a.dtype(); + let a_contig = ensure_contiguous(a); + let out = Tensor::::empty(a.shape(), dtype, &self.device); + + unsafe { + launch_fused_mul_add_scalar( + &self.context, + &self.stream, + self.device.index, + dtype, + a_contig.ptr(), + out.ptr(), + out.numel(), + scale, + bias, + )?; + } + + Ok(out) + } } diff --git a/src/ops/cuda/semiring_matmul.rs b/src/ops/cuda/semiring_matmul.rs index 060dd7da..b3d10c09 100644 --- a/src/ops/cuda/semiring_matmul.rs +++ b/src/ops/cuda/semiring_matmul.rs @@ -38,9 +38,13 @@ impl SemiringMatmulOps for CudaClient { }); } - // Only F32, F64, I32 have CUDA kernels + // Supported CUDA kernel dtypes match dtype { - DType::F32 | DType::F64 | DType::I32 => {} + DType::F32 | DType::F64 | DType::I32 | DType::Bool | DType::U8 => {} + #[cfg(feature = "f16")] + DType::F16 | DType::BF16 => {} + #[cfg(feature = "fp8")] + DType::FP8E4M3 | DType::FP8E5M2 => {} _ => { return Err(Error::UnsupportedDType { dtype, @@ -84,10 +88,17 @@ impl SemiringMatmulOps for CudaClient { let op_code = semiring_op_code(op); + // Bool uses the u8 kernel (same underlying type) + let kernel_dtype = if dtype == DType::Bool { + DType::U8 + } else { + dtype + }; + if batch_size > 1 { - semiring_matmul_batched_native(self, a, b, dtype, batch_size, m, k, n, op_code) + semiring_matmul_batched_native(self, a, b, kernel_dtype, batch_size, m, k, n, op_code) } else { - semiring_matmul_native(self, a, b, dtype, m, k, n, op_code) + semiring_matmul_native(self, a, b, kernel_dtype, m, k, n, op_code) } } } diff --git a/src/ops/cuda/shape.rs b/src/ops/cuda/shape.rs index 674b7bcc..620eef58 100644 --- a/src/ops/cuda/shape.rs +++ b/src/ops/cuda/shape.rs @@ -4,12 +4,12 @@ use crate::ops::ShapeOps; use crate::ops::impl_generic::{repeat_interleave_impl, unfold_impl}; use crate::runtime::cuda::kernels::{launch_cat_copy, launch_pad, launch_repeat, launch_roll}; use crate::runtime::cuda::{CudaClient, CudaRuntime}; -use crate::runtime::{ensure_contiguous, shape_ops}; +use crate::runtime::{common::shape_ops, ensure_contiguous}; use crate::tensor::Tensor; impl ShapeOps for CudaClient { fn cat(&self, tensors: &[&Tensor], dim: isize) -> Result> { - let params = crate::runtime::shape_ops::validate_cat(tensors, dim)?; + let params = crate::runtime::common::shape_ops::validate_cat(tensors, dim)?; // Allocate output let out = Tensor::::empty(¶ms.out_shape, params.dtype, &self.device); @@ -26,8 +26,8 @@ impl ShapeOps for CudaClient { &self.stream, self.device.index, params.dtype, - tensor_contig.storage().ptr(), - out.storage().ptr(), + tensor_contig.ptr(), + out.ptr(), params.outer_size, src_cat_size, params.cat_dim_total, @@ -44,7 +44,7 @@ impl ShapeOps for CudaClient { fn stack(&self, tensors: &[&Tensor], dim: isize) -> Result> { // Validate tensors and get normalized dimension - let _ = crate::runtime::shape_ops::validate_stack(tensors, dim)?; + let _ = crate::runtime::common::shape_ops::validate_stack(tensors, dim)?; // stack(tensors, dim) = cat([t.unsqueeze(dim) for t in tensors], dim) let unsqueezed: Vec> = tensors @@ -96,8 +96,8 @@ impl ShapeOps for CudaClient { self.device.index, &self.device, tensor.dtype(), - tensor_contig.storage().ptr(), - out.storage().ptr(), + tensor_contig.ptr(), + out.ptr(), tensor.shape(), ¶ms.out_shape, )?; @@ -132,8 +132,8 @@ impl ShapeOps for CudaClient { self.device.index, &self.device, tensor.dtype(), - tensor_contig.storage().ptr(), - out.storage().ptr(), + tensor_contig.ptr(), + out.ptr(), value, tensor.shape(), ¶ms.out_shape, @@ -172,8 +172,8 @@ impl ShapeOps for CudaClient { &self.stream, self.device.index, tensor.dtype(), - tensor_contig.storage().ptr(), - out.storage().ptr(), + tensor_contig.ptr(), + out.ptr(), outer_size, params.dim_size, inner_size, diff --git a/src/ops/cuda/sorting.rs b/src/ops/cuda/sorting.rs index fe7e7bf2..0b191056 100644 --- a/src/ops/cuda/sorting.rs +++ b/src/ops/cuda/sorting.rs @@ -37,8 +37,8 @@ impl SortingOps for CudaClient { &self.stream, self.device.index, dtype, - a_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + out.ptr(), outer_size, sort_size, inner_size, @@ -76,9 +76,9 @@ impl SortingOps for CudaClient { &self.stream, self.device.index, dtype, - a_contig.storage().ptr(), - out_values.storage().ptr(), - out_indices.storage().ptr(), + a_contig.ptr(), + out_values.ptr(), + out_indices.ptr(), outer_size, sort_size, inner_size, @@ -118,8 +118,8 @@ impl SortingOps for CudaClient { &self.stream, self.device.index, dtype, - a_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + out.ptr(), outer_size, sort_size, inner_size, @@ -188,9 +188,9 @@ impl SortingOps for CudaClient { &self.stream, self.device.index, dtype, - a_contig.storage().ptr(), - out_values.storage().ptr(), - out_indices.storage().ptr(), + a_contig.ptr(), + out_values.ptr(), + out_indices.ptr(), outer_size, sort_size, inner_size, @@ -228,8 +228,8 @@ impl SortingOps for CudaClient { &self.stream, self.device.index, dtype, - sorted_tensor.storage().ptr(), - counter.storage().ptr(), + sorted_tensor.ptr(), + counter.ptr(), numel, )?; } @@ -256,9 +256,9 @@ impl SortingOps for CudaClient { &self.stream, self.device.index, dtype, - sorted_tensor.storage().ptr(), - out.storage().ptr(), - counter.storage().ptr(), + sorted_tensor.ptr(), + out.ptr(), + counter.ptr(), numel, )?; } @@ -300,8 +300,8 @@ impl SortingOps for CudaClient { &self.context, &self.stream, self.device.index, - inverse.storage().ptr(), - counts.storage().ptr(), + inverse.ptr(), + counts.ptr(), numel, unique_count, )?; @@ -335,8 +335,8 @@ impl SortingOps for CudaClient { &self.stream, self.device.index, dtype, - a_contig.storage().ptr(), - counter.storage().ptr(), + a_contig.ptr(), + counter.ptr(), numel, )?; } @@ -374,9 +374,9 @@ impl SortingOps for CudaClient { &self.stream, self.device.index, dtype, - a_contig.storage().ptr(), - flat_indices.storage().ptr(), - counter.storage().ptr(), + a_contig.ptr(), + flat_indices.ptr(), + counter.ptr(), numel, )?; } @@ -394,11 +394,11 @@ impl SortingOps for CudaClient { &self.context, &self.stream, self.device.index, - flat_indices.storage().ptr(), - out.storage().ptr(), + flat_indices.ptr(), + out.ptr(), nnz, ndim, - shape_tensor.storage().ptr(), + shape_tensor.ptr(), )?; } @@ -447,9 +447,9 @@ impl SortingOps for CudaClient { &self.stream, self.device.index, dtype, - seq_contig.storage().ptr(), - values_contig.storage().ptr(), - out.storage().ptr(), + seq_contig.ptr(), + values_contig.ptr(), + out.ptr(), seq_len, num_values, right, diff --git a/src/ops/cuda/sparse_24.rs b/src/ops/cuda/sparse_24.rs new file mode 100644 index 00000000..426b06f5 --- /dev/null +++ b/src/ops/cuda/sparse_24.rs @@ -0,0 +1,136 @@ +//! CUDA implementation of 2:4 structured sparsity operations. + +use crate::dtype::DType; +use crate::error::{Error, Result}; +use crate::ops::traits::Sparse24Ops; +use crate::runtime::cuda::kernels::{ + launch_sparse_24_decompress, launch_sparse_24_matmul, launch_sparse_24_prune, +}; +use crate::runtime::cuda::{CudaClient, CudaRuntime}; +use crate::runtime::ensure_contiguous; +use crate::sparse::structured::{Sparse24Tensor, meta_cols_for_k}; +use crate::tensor::Tensor; + +impl Sparse24Ops for CudaClient { + fn prune_to_24(&self, dense: &Tensor) -> Result> { + if dense.ndim() != 2 { + return Err(Error::InvalidArgument { + arg: "dense", + reason: format!("Expected 2D tensor, got {}D", dense.ndim()), + }); + } + + let m = dense.shape()[0]; + let k = dense.shape()[1]; + + if !k.is_multiple_of(4) { + return Err(Error::InvalidArgument { + arg: "dense", + reason: format!("K dimension ({k}) must be divisible by 4 for 2:4 sparsity"), + }); + } + + let dtype = dense.dtype(); + let dense_contig = ensure_contiguous(dense); + let half_k = k / 2; + let mc = meta_cols_for_k(k); + + let compressed = Tensor::::empty(&[m, half_k], dtype, &self.device); + // Metadata must be zeroed before kernel's atomic OR operations + let metadata = Tensor::::zeros(&[m, mc], DType::U32, &self.device); + + unsafe { + launch_sparse_24_prune( + &self.context, + &self.stream, + self.device.index, + dtype, + dense_contig.ptr(), + compressed.ptr(), + metadata.ptr(), + m, + k, + )?; + } + + Sparse24Tensor::new(compressed, metadata, [m, k]) + } + + fn sparse_24_to_dense( + &self, + sparse: &Sparse24Tensor, + ) -> Result> { + let [m, k] = sparse.shape(); + let dtype = sparse.dtype(); + + let dense = Tensor::::empty(&[m, k], dtype, &self.device); + + let vals = ensure_contiguous(sparse.compressed_values()); + let meta = ensure_contiguous(sparse.metadata()); + + unsafe { + launch_sparse_24_decompress( + &self.context, + &self.stream, + self.device.index, + dtype, + vals.ptr(), + meta.ptr(), + dense.ptr(), + m, + k, + )?; + } + + Ok(dense) + } + + fn sparse_24_matmul( + &self, + input: &Tensor, + weight: &Sparse24Tensor, + ) -> Result> { + if input.ndim() != 2 { + return Err(Error::InvalidArgument { + arg: "input", + reason: format!("Expected 2D tensor, got {}D", input.ndim()), + }); + } + + let n = input.shape()[0]; + let input_k = input.shape()[1]; + let [m, weight_k] = weight.shape(); + + if input_k != weight_k { + return Err(Error::ShapeMismatch { + expected: vec![n, weight_k], + got: vec![n, input_k], + }); + } + + let dtype = input.dtype(); + let input_contig = ensure_contiguous(input); + let vals = ensure_contiguous(weight.compressed_values()); + let meta = ensure_contiguous(weight.metadata()); + + let output = Tensor::::empty(&[n, m], dtype, &self.device); + + unsafe { + launch_sparse_24_matmul( + &self.context, + &self.stream, + self.device.index, + dtype, + input_contig.ptr(), + vals.ptr(), + meta.ptr(), + output.ptr(), + n, + m, + weight_k, + )?; + } + + Ok(output) + } +} diff --git a/src/ops/cuda/type_conversion.rs b/src/ops/cuda/type_conversion.rs index 2afc2288..f422417b 100644 --- a/src/ops/cuda/type_conversion.rs +++ b/src/ops/cuda/type_conversion.rs @@ -28,8 +28,8 @@ impl TypeConversionOps for CudaClient { self.device.index, src_dtype, target_dtype, - a_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + out.ptr(), numel, )?; } diff --git a/src/ops/cuda/unary.rs b/src/ops/cuda/unary.rs index 842d04ad..ba2a730a 100644 --- a/src/ops/cuda/unary.rs +++ b/src/ops/cuda/unary.rs @@ -145,8 +145,8 @@ impl UnaryOps for CudaClient { &self.stream, self.device.index, dtype, - a_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + out.ptr(), out.numel(), )?; } @@ -166,8 +166,8 @@ impl UnaryOps for CudaClient { &self.stream, self.device.index, dtype, - a_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + out.ptr(), out.numel(), )?; } diff --git a/src/ops/cuda/utility.rs b/src/ops/cuda/utility.rs index 03f973fe..fdf8019a 100644 --- a/src/ops/cuda/utility.rs +++ b/src/ops/cuda/utility.rs @@ -49,7 +49,7 @@ impl UtilityOps for CudaClient { self.device.index, dtype, value, - out.storage().ptr(), + out.ptr(), numel, )?; } @@ -82,7 +82,7 @@ impl UtilityOps for CudaClient { dtype, start, step, - out.storage().ptr(), + out.ptr(), numel, )?; } @@ -119,7 +119,7 @@ impl UtilityOps for CudaClient { dtype, start, stop, - out.storage().ptr(), + out.ptr(), steps, )?; } @@ -166,7 +166,7 @@ impl UtilityOps for CudaClient { dtype, rows, cols, - out.storage().ptr(), + out.ptr(), )?; } diff --git a/src/ops/impl_generic/activation.rs b/src/ops/impl_generic/activation.rs new file mode 100644 index 00000000..31dce82f --- /dev/null +++ b/src/ops/impl_generic/activation.rs @@ -0,0 +1,88 @@ +//! Generic implementations of composite activation operations. + +use crate::dtype::DType; +use crate::error::{Error, Result}; +use crate::ops::activation::normalize_softmax_dim; +use crate::ops::traits::{ + ActivationOps, BinaryOps, CompareOps, ConditionalOps, CumulativeOps, RandomOps, ScalarOps, + UnaryOps, +}; +use crate::runtime::{Runtime, RuntimeClient}; +use crate::tensor::Tensor; + +/// Generic softplus implementation: softplus(x) = log(1 + exp(x)) +/// +/// Uses the numerically stable form: `relu(x) + log(1 + exp(-|x|))` +/// +/// The naive formula `log(1 + exp(x))` overflows to `Inf` for large positive x +/// (e.g., x = 100: `exp(100) = Inf`). The stable decomposition keeps all +/// intermediate values bounded: +/// - For large x > 0: `relu(x) ≈ x`, `log(1 + exp(-x)) ≈ 0` → result ≈ x ✓ +/// - For large x < 0: `relu(x) = 0`, `log(1 + exp(-|x|)) ≈ exp(x)` → result ≈ exp(x) ✓ +/// - At x = 0: `0 + log(2) ≈ 0.693` ✓ +/// +/// All backends delegate here — guarantees identical numerical behaviour. +pub fn softplus_impl(client: &C, a: &Tensor) -> Result> +where + R: Runtime, + C: ActivationOps + UnaryOps + ScalarOps + BinaryOps, +{ + // relu(x) = max(0, x) + let relu_x = client.relu(a)?; + + // log(1 + exp(-|x|)) — all values bounded: exp(-|x|) ∈ (0, 1] + let abs_x = client.abs(a)?; + let neg_abs = client.neg(&abs_x)?; + let exp_neg_abs = client.exp(&neg_abs)?; + let one_plus = client.add_scalar(&exp_neg_abs, 1.0)?; + let log_term = client.log(&one_plus)?; + + client.add(&relu_x, &log_term) +} + +/// Generic log_softmax implementation: log_softmax(x, dim) = x - logsumexp(x, dim, keepdim=true) +/// +/// This is the canonical algorithm — all backends delegate here. +/// Numerically stable because logsumexp uses the max-subtraction trick internally. +pub fn log_softmax_impl(client: &C, a: &Tensor, dim: isize) -> Result> +where + R: Runtime, + C: BinaryOps + CumulativeOps, +{ + let ndim = a.ndim(); + let dim_idx = normalize_softmax_dim(ndim, dim).ok_or(Error::InvalidDimension { dim, ndim })?; + + let lse = client.logsumexp(a, &[dim_idx], true)?; + client.sub(a, &lse) +} + +/// Generic dropout implementation: where(rand > p, x / (1-p), 0) +/// +/// During training, randomly zeros elements with probability `p` and scales +/// remaining elements by `1/(1-p)` to preserve expected values. +/// During inference (`training=false`), returns input unchanged. +pub fn dropout_impl(client: &C, a: &Tensor, p: f64, training: bool) -> Result> +where + R: Runtime, + C: RandomOps + CompareOps + ConditionalOps + ScalarOps + RuntimeClient, +{ + if !training || p == 0.0 { + return Ok(a.clone()); + } + if p >= 1.0 { + return Ok(Tensor::::zeros(a.shape(), a.dtype(), client.device())); + } + + // Generate random mask: rand > p means "keep" + let rand_tensor = client.rand(a.shape(), a.dtype())?; + let threshold = Tensor::::full_scalar(a.shape(), a.dtype(), p, client.device()); + let mask = client.gt(&rand_tensor, &threshold)?; + + // Scale kept values by 1/(1-p) + let scale = 1.0 / (1.0 - p); + let scaled = client.mul_scalar(a, scale)?; + + // Zero out dropped elements + let zeros = Tensor::::zeros(a.shape(), a.dtype(), client.device()); + client.where_cond(&mask, &scaled, &zeros) +} diff --git a/src/ops/impl_generic/linalg.rs b/src/ops/impl_generic/linalg.rs index 02e89a16..1c39866b 100644 --- a/src/ops/impl_generic/linalg.rs +++ b/src/ops/impl_generic/linalg.rs @@ -40,7 +40,7 @@ fn triangular_mask_impl( triangle: Triangle, ) -> Result> where - R: Runtime, + R: Runtime, C: UtilityOps + ScalarOps + CompareOps + TypeConversionOps + BinaryOps, { let (m, n) = validate_matrix_2d(a.shape())?; @@ -74,7 +74,7 @@ where #[cfg(any(feature = "cuda", feature = "wgpu"))] pub fn triu_impl(client: &C, a: &Tensor, diagonal: i64) -> Result> where - R: Runtime, + R: Runtime, C: UtilityOps + ScalarOps + CompareOps + TypeConversionOps + BinaryOps, { triangular_mask_impl(client, a, diagonal, Triangle::Upper) @@ -86,7 +86,7 @@ where #[cfg(any(feature = "cuda", feature = "wgpu"))] pub fn tril_impl(client: &C, a: &Tensor, diagonal: i64) -> Result> where - R: Runtime, + R: Runtime, C: UtilityOps + ScalarOps + CompareOps + TypeConversionOps + BinaryOps, { triangular_mask_impl(client, a, diagonal, Triangle::Lower) @@ -100,7 +100,7 @@ where #[cfg(any(feature = "cuda", feature = "wgpu"))] pub fn slogdet_impl(client: &C, a: &Tensor) -> Result> where - R: Runtime, + R: Runtime, C: LinearAlgebraAlgorithms + UtilityOps + BinaryOps diff --git a/src/ops/impl_generic/mod.rs b/src/ops/impl_generic/mod.rs index 2b6bf006..120463bb 100644 --- a/src/ops/impl_generic/mod.rs +++ b/src/ops/impl_generic/mod.rs @@ -19,6 +19,7 @@ //! └── wgpu/multivariate.rs delegates here //! ``` +pub mod activation; pub mod einsum; pub mod linalg; pub mod multivariate; diff --git a/src/ops/impl_generic/multivariate.rs b/src/ops/impl_generic/multivariate.rs index f3b15dd1..3be49c3b 100644 --- a/src/ops/impl_generic/multivariate.rs +++ b/src/ops/impl_generic/multivariate.rs @@ -44,7 +44,7 @@ impl DTypeSupport { // Validation Helpers (parameter extraction is OK - these are small user inputs) // ============================================================================ -fn validate_multivariate_normal_inputs( +fn validate_multivariate_normal_inputs>( mean: &Tensor, cov: &Tensor, n_samples: usize, @@ -109,7 +109,7 @@ fn validate_multivariate_normal_inputs( Ok(d) } -fn validate_wishart_inputs( +fn validate_wishart_inputs>( scale: &Tensor, df: usize, n_samples: usize, @@ -164,7 +164,7 @@ fn validate_wishart_inputs( } /// Validate dirichlet inputs. Extracts alpha values (small parameter vector). -fn validate_dirichlet_inputs( +fn validate_dirichlet_inputs>( alpha: &Tensor, n_samples: usize, ) -> Result<(usize, Vec)> { @@ -213,7 +213,7 @@ fn validate_dirichlet_inputs( } /// Validate multinomial inputs. Extracts probs and computes CDF (small parameter vector). -fn validate_multinomial_inputs( +fn validate_multinomial_inputs>( probs: &Tensor, n_trials: usize, n_samples: usize, @@ -280,7 +280,7 @@ pub fn multivariate_normal_impl( dtype_support: DTypeSupport, ) -> Result> where - R: Runtime, + R: Runtime, C: LinearAlgebraAlgorithms + MatmulOps + BinaryOps + RandomOps, { let d = validate_multivariate_normal_inputs(mean, cov, n_samples, dtype_support)?; @@ -318,7 +318,7 @@ pub fn wishart_impl( dtype_support: DTypeSupport, ) -> Result> where - R: Runtime, + R: Runtime, C: LinearAlgebraAlgorithms + MatmulOps + BinaryOps @@ -398,7 +398,7 @@ fn construct_bartlett_matrices( device: &R::Device, ) -> Result> where - R: Runtime, + R: Runtime, C: BinaryOps + ShapeOps, { // We need to place values at specific positions. @@ -465,7 +465,7 @@ where /// ALL OPERATIONS ON GPU - only alpha parameters extracted (small user input). pub fn dirichlet_impl(client: &C, alpha: &Tensor, n_samples: usize) -> Result> where - R: Runtime, + R: Runtime, C: RandomOps + ReduceOps + BinaryOps + ShapeOps, { let (k, alpha_data) = validate_dirichlet_inputs(alpha, n_samples)?; @@ -511,7 +511,7 @@ pub fn multinomial_samples_impl( n_samples: usize, ) -> Result> where - R: Runtime, + R: Runtime, C: MultinomialSamplingOps, { let k = validate_multinomial_inputs(probs, n_trials, n_samples)?; @@ -536,7 +536,7 @@ where /// /// This requires a GPU kernel because CDF lookup + counting cannot be /// efficiently expressed with standard tensor operations. -pub trait MultinomialSamplingOps { +pub trait MultinomialSamplingOps> { /// Multinomial sampling kernel. /// /// Given probability vector, generates n_samples where each sample diff --git a/src/ops/impl_generic/utility.rs b/src/ops/impl_generic/utility.rs index e6eb3ccf..c980bc0c 100644 --- a/src/ops/impl_generic/utility.rs +++ b/src/ops/impl_generic/utility.rs @@ -23,7 +23,7 @@ use crate::tensor::Tensor; #[cfg(any(feature = "cuda", feature = "wgpu"))] pub fn one_hot_impl(client: &C, indices: &Tensor, num_classes: usize) -> Result> where - R: Runtime, + R: Runtime, C: UtilityOps + TypeConversionOps + CompareOps, { if num_classes == 0 { diff --git a/src/ops/mod.rs b/src/ops/mod.rs index 39dde6fa..1a0be85e 100644 --- a/src/ops/mod.rs +++ b/src/ops/mod.rs @@ -34,7 +34,7 @@ //! let out = Tensor::empty(&out_shape, a.dtype(), self.device()); //! //! // 3. Dispatch kernel -//! cuda_add_kernel(a.storage().ptr(), b.storage().ptr(), out.storage().ptr(), ...); +//! cuda_add_kernel(a.ptr(), b.ptr(), out.ptr(), ...); //! //! Ok(out) //! } @@ -99,10 +99,12 @@ pub(crate) use matmul::{ pub(crate) use reduce::{ AccumulationPrecision, compute_reduce_strides, reduce_dim_output_shape, reduce_output_shape, }; +pub use traits::Fp8MatmulOps; pub use traits::{ - ActivationOps, AdvancedRandomOps, BinaryOps, CompareOps, ComplexOps, ConditionalOps, ConvOps, - CumulativeOps, DistanceMetric, DistanceOps, EinsumOps, IndexingOps, Kernel, LinalgOps, - LogicalOps, MatmulOps, MeshgridIndexing, MultivariateRandomOps, NormalizationOps, PaddingMode, - QuasiRandomOps, RandomOps, ReduceOps, ScalarOps, ScatterReduceOp, SemiringMatmulOps, ShapeOps, - SortingOps, StatisticalOps, TensorOps, TypeConversionOps, UnaryOps, UtilityOps, + ActivationOps, BinaryOps, CompareOps, ComplexOps, ConditionalOps, ConvOps, CumulativeOps, + DistanceMetric, DistanceOps, EinsumOps, GemmActivation, GemmEpilogueOps, IndexingOps, Kernel, + LinalgOps, LogicalOps, MatmulOps, MeshgridIndexing, NormalizationOps, PaddingMode, ReduceOps, + ScalarOps, ScatterReduceOp, SemiringMatmulOps, ShapeOps, SortingOps, StatisticalOps, TensorOps, + TypeConversionOps, UnaryOps, UtilityOps, }; +pub use traits::{AdvancedRandomOps, MultivariateRandomOps, QuasiRandomOps, RandomOps}; diff --git a/src/ops/reduce.rs b/src/ops/reduce.rs index 1efbf0f2..a3b7d5b5 100644 --- a/src/ops/reduce.rs +++ b/src/ops/reduce.rs @@ -141,7 +141,10 @@ mod tests { ); // Reduce all dims - assert_eq!(reduce_output_shape(&[2, 3, 4], &[0, 1, 2], false), vec![]); + assert_eq!( + reduce_output_shape(&[2, 3, 4], &[0, 1, 2], false), + Vec::::new() + ); assert_eq!( reduce_output_shape(&[2, 3, 4], &[0, 1, 2], true), vec![1, 1, 1] diff --git a/src/ops/semiring.rs b/src/ops/semiring.rs index 2ecfeb0b..322aaf42 100644 --- a/src/ops/semiring.rs +++ b/src/ops/semiring.rs @@ -139,13 +139,14 @@ impl SemiringOp { _ => { matches!(dtype, DType::F32 | DType::F64 | DType::I32 | DType::I64) || { #[cfg(feature = "f16")] - { - matches!(dtype, DType::F16 | DType::BF16) + if matches!(dtype, DType::F16 | DType::BF16) { + return true; } - #[cfg(not(feature = "f16"))] - { - false + #[cfg(feature = "fp8")] + if matches!(dtype, DType::FP8E4M3 | DType::FP8E5M2) { + return true; } + false } } } diff --git a/src/ops/traits/activation.rs b/src/ops/traits/activation.rs index 36349ab9..4d96599b 100644 --- a/src/ops/traits/activation.rs +++ b/src/ops/traits/activation.rs @@ -71,4 +71,172 @@ pub trait ActivationOps { feature: "ActivationOps::softmax", }) } + + /// Log-softmax along a dimension: log(softmax(x, dim)) + /// + /// Computed as `x - logsumexp(x, dim)` for numerical stability. + /// Used in log-probability calculations, Bayesian inference, + /// categorical distributions, and information theory. + fn log_softmax(&self, a: &Tensor, dim: isize) -> Result> { + let _ = (a, dim); + Err(Error::NotImplemented { + feature: "ActivationOps::log_softmax", + }) + } + + /// Softmax backward pass: computes gradient w.r.t. input given output gradient and softmax output. + /// + /// Formula: `d_input = output * (grad - sum(grad * output, dim, keepdim=true))` + /// + /// This is the Jacobian-vector product for softmax, used in training backward passes. + /// + /// # Arguments + /// * `grad` - Upstream gradient (same shape as output) + /// * `output` - The softmax output from the forward pass + /// * `dim` - The dimension along which softmax was computed + fn softmax_bwd(&self, grad: &Tensor, output: &Tensor, dim: isize) -> Result> { + let _ = (grad, output, dim); + Err(Error::NotImplemented { + feature: "ActivationOps::softmax_bwd", + }) + } + + /// Softplus: `log(1 + exp(a))` + /// + /// A smooth approximation to ReLU that is always positive and differentiable. + /// Used in Mamba2 for dt (step size) processing via `softplus(dt_proj(x)) + dt_bias`. + /// + /// Gradient: `sigmoid(a)` + fn softplus(&self, a: &Tensor) -> Result> { + let _ = a; + Err(Error::NotImplemented { + feature: "ActivationOps::softplus", + }) + } + + /// Fused SiLU-Mul: `silu(a) * b` in a single pass. + /// + /// Computes `(a / (1 + exp(-a))) * b` element-wise with one memory pass + /// instead of two (activation + multiply). Used in SwiGLU and similar gated architectures. + fn silu_mul(&self, a: &Tensor, b: &Tensor) -> Result> { + let _ = (a, b); + Err(Error::NotImplemented { + feature: "ActivationOps::silu_mul", + }) + } + + /// Fused GELU-Mul: `gelu(a) * b` in a single pass. + /// + /// Computes `(0.5 * a * (1 + tanh(sqrt(2/pi) * (a + 0.044715*a^3)))) * b` element-wise. + /// Used in GeGLU gated architectures. + fn gelu_mul(&self, a: &Tensor, b: &Tensor) -> Result> { + let _ = (a, b); + Err(Error::NotImplemented { + feature: "ActivationOps::gelu_mul", + }) + } + + /// Fused ReLU-Mul: `relu(a) * b` in a single pass. + /// + /// Computes `max(0, a) * b` element-wise. Used in ReGLU gated architectures. + fn relu_mul(&self, a: &Tensor, b: &Tensor) -> Result> { + let _ = (a, b); + Err(Error::NotImplemented { + feature: "ActivationOps::relu_mul", + }) + } + + /// Fused Sigmoid-Mul: `sigmoid(a) * b` in a single pass. + /// + /// Computes `(1 / (1 + exp(-a))) * b` element-wise. Used in SiGLU gated architectures. + fn sigmoid_mul(&self, a: &Tensor, b: &Tensor) -> Result> { + let _ = (a, b); + Err(Error::NotImplemented { + feature: "ActivationOps::sigmoid_mul", + }) + } + + /// Fused SiLU-Mul backward: computes gradients for `output = silu(a) * b`. + /// + /// Returns `(d_a, d_b)` where: + /// - `d_a = grad * b * silu'(a)` with `silu'(x) = sigmoid(x) * (1 + x - silu(x))` + /// - `d_b = grad * silu(a)` + /// + /// Backends may implement this as a single fused kernel for better performance. + fn silu_mul_bwd( + &self, + grad: &Tensor, + a: &Tensor, + b: &Tensor, + ) -> Result<(Tensor, Tensor)> { + let _ = (grad, a, b); + Err(Error::NotImplemented { + feature: "ActivationOps::silu_mul_bwd", + }) + } + + /// Fused GELU-Mul backward: computes gradients for `output = gelu(a) * b`. + /// + /// Returns `(d_a, d_b)` where: + /// - `d_a = grad * b * gelu'(a)` + /// - `d_b = grad * gelu(a)` + fn gelu_mul_bwd( + &self, + grad: &Tensor, + a: &Tensor, + b: &Tensor, + ) -> Result<(Tensor, Tensor)> { + let _ = (grad, a, b); + Err(Error::NotImplemented { + feature: "ActivationOps::gelu_mul_bwd", + }) + } + + /// Fused ReLU-Mul backward: computes gradients for `output = relu(a) * b`. + /// + /// Returns `(d_a, d_b)` where: + /// - `d_a = grad * b * relu'(a)` with `relu'(x) = 1 if x > 0, else 0` + /// - `d_b = grad * relu(a)` + fn relu_mul_bwd( + &self, + grad: &Tensor, + a: &Tensor, + b: &Tensor, + ) -> Result<(Tensor, Tensor)> { + let _ = (grad, a, b); + Err(Error::NotImplemented { + feature: "ActivationOps::relu_mul_bwd", + }) + } + + /// Fused Sigmoid-Mul backward: computes gradients for `output = sigmoid(a) * b`. + /// + /// Returns `(d_a, d_b)` where: + /// - `d_a = grad * b * sigmoid'(a)` with `sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x))` + /// - `d_b = grad * sigmoid(a)` + fn sigmoid_mul_bwd( + &self, + grad: &Tensor, + a: &Tensor, + b: &Tensor, + ) -> Result<(Tensor, Tensor)> { + let _ = (grad, a, b); + Err(Error::NotImplemented { + feature: "ActivationOps::sigmoid_mul_bwd", + }) + } + + /// Dropout: randomly zero elements with probability `p` during training. + /// + /// When `training` is true, each element is independently zeroed with probability `p`, + /// and remaining elements are scaled by `1/(1-p)` to maintain expected values. + /// When `training` is false, returns the input unchanged. + /// + /// Used in regularization, Monte Carlo dropout, and Bayesian approximation. + fn dropout(&self, a: &Tensor, p: f64, training: bool) -> Result> { + let _ = (a, p, training); + Err(Error::NotImplemented { + feature: "ActivationOps::dropout", + }) + } } diff --git a/src/ops/traits/binary.rs b/src/ops/traits/binary.rs index cbf81f7c..dcb1acf1 100644 --- a/src/ops/traits/binary.rs +++ b/src/ops/traits/binary.rs @@ -255,4 +255,31 @@ pub trait BinaryOps { /// # Ok::<(), numr::error::Error>(()) /// ``` fn atan2(&self, y: &Tensor, x: &Tensor) -> Result>; + + /// Fused multiply-add: a * b + c + /// + /// Computes the element-wise fused multiply-add of three tensors in a single pass, + /// reducing memory bandwidth compared to separate multiply and add operations. + /// Uses hardware FMA instructions where available (AVX2/AVX-512/NEON). + /// + /// All three tensors must have the same shape (no broadcasting). + /// + /// # Arguments + /// * `a` - First multiplicand + /// * `b` - Second multiplicand + /// * `c` - Addend + fn fused_mul_add(&self, a: &Tensor, b: &Tensor, c: &Tensor) -> Result>; + + /// Fused add-multiply: (a + b) * c + /// + /// Computes the element-wise fused add-multiply of three tensors in a single pass. + /// Common in residual + scaling patterns. + /// + /// All three tensors must have the same shape (no broadcasting). + /// + /// # Arguments + /// * `a` - First addend + /// * `b` - Second addend + /// * `c` - Multiplicand + fn fused_add_mul(&self, a: &Tensor, b: &Tensor, c: &Tensor) -> Result>; } diff --git a/src/ops/traits/fp8_matmul.rs b/src/ops/traits/fp8_matmul.rs new file mode 100644 index 00000000..a905de5d --- /dev/null +++ b/src/ops/traits/fp8_matmul.rs @@ -0,0 +1,92 @@ +//! FP8 matrix multiplication operations trait. +//! +//! FP8 matmul differs from standard matmul in two key ways: +//! 1. Per-tensor scale factors compensate for the limited dynamic range of FP8 +//! 2. Accumulation is always in FP32 for numerical accuracy +//! +//! The output dtype can differ from input dtype (typically F32, F16, or BF16). + +use crate::dtype::DType; +use crate::error::Result; +use crate::runtime::Runtime; +use crate::tensor::Tensor; + +/// FP8 matrix multiplication operations with per-tensor scaling. +/// +/// FP8 GEMM computes: `output = (scale_a * A) @ (scale_b * B)` where A and B are +/// FP8 tensors, arithmetic is performed in FP32, and the result is cast to `out_dtype`. +/// +/// # Scale Factors +/// +/// FP8 has very limited dynamic range (~[-448, 448] for E4M3, ~[-57344, 57344] for E5M2). +/// Per-tensor scale factors map the original tensor range into the FP8 representable range: +/// +/// ```text +/// quantize: fp8_tensor = original_tensor / scale +/// dequantize: original_tensor = fp8_tensor * scale +/// matmul: C = (A * scale_a) @ (B * scale_b) = scale_a * scale_b * (A_fp8 @ B_fp8) +/// ``` +/// +/// # Use Cases +/// +/// - `fp8_matmul`: E4M3 x E4M3 — forward pass (weights and activations) +/// - `fp8_matmul_e5m2`: E5M2 x E4M3 — backward pass (gradients x weights) +pub trait Fp8MatmulOps { + /// FP8 E4M3 x E4M3 matrix multiplication with per-tensor scaling. + /// + /// Computes: `output = scale_a * scale_b * (a_e4m3 @ b_e4m3)` + /// with FP32 accumulation, then casts to `out_dtype`. + /// + /// # Arguments + /// + /// * `a` - Input tensor of shape `[..., M, K]` with dtype FP8E4M3 + /// * `b` - Weight tensor of shape `[..., K, N]` with dtype FP8E4M3 + /// * `scale_a` - Per-tensor scale factor for A (scalar f32) + /// * `scale_b` - Per-tensor scale factor for B (scalar f32) + /// * `out_dtype` - Output dtype (F32, F16, or BF16) + /// + /// # Returns + /// + /// Output tensor of shape `[..., M, N]` with dtype `out_dtype`. + /// + /// # Errors + /// + /// - `DTypeMismatch` if inputs are not FP8E4M3 + /// - `ShapeMismatch` if inner dimensions don't match + /// - `UnsupportedDType` if `out_dtype` is not F32/F16/BF16 + fn fp8_matmul( + &self, + a: &Tensor, + b: &Tensor, + scale_a: f32, + scale_b: f32, + out_dtype: DType, + ) -> Result>; + + /// FP8 E5M2 x E4M3 matrix multiplication with per-tensor scaling. + /// + /// Used for backward pass: gradients (E5M2, larger range) x weights (E4M3, higher precision). + /// + /// Computes: `output = scale_a * scale_b * (a_e5m2 @ b_e4m3)` + /// with FP32 accumulation, then casts to `out_dtype`. + /// + /// # Arguments + /// + /// * `a` - Gradient tensor of shape `[..., M, K]` with dtype FP8E5M2 + /// * `b` - Weight tensor of shape `[..., K, N]` with dtype FP8E4M3 + /// * `scale_a` - Per-tensor scale factor for A (scalar f32) + /// * `scale_b` - Per-tensor scale factor for B (scalar f32) + /// * `out_dtype` - Output dtype (F32, F16, or BF16) + /// + /// # Returns + /// + /// Output tensor of shape `[..., M, N]` with dtype `out_dtype`. + fn fp8_matmul_e5m2( + &self, + a: &Tensor, + b: &Tensor, + scale_a: f32, + scale_b: f32, + out_dtype: DType, + ) -> Result>; +} diff --git a/src/ops/traits/gemm_epilogue.rs b/src/ops/traits/gemm_epilogue.rs new file mode 100644 index 00000000..45a2ff55 --- /dev/null +++ b/src/ops/traits/gemm_epilogue.rs @@ -0,0 +1,114 @@ +//! GEMM epilogue operations trait. +//! +//! Fused matrix multiplication with bias and activation/residual in a single kernel. +//! Eliminates extra kernel launches and memory round-trips for `Linear + Activation` patterns. + +use crate::error::Result; +use crate::runtime::Runtime; +use crate::tensor::Tensor; + +/// Activation function to fuse into the GEMM epilogue. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum GemmActivation { + /// No activation (identity) + None, + /// ReLU: max(0, x) + ReLU, + /// GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) + GELU, + /// SiLU/Swish: x * sigmoid(x) + SiLU, + /// Sigmoid: 1 / (1 + exp(-x)) + Sigmoid, + /// Tanh: hyperbolic tangent + Tanh, +} + +/// Fused GEMM + bias + activation/residual operations. +/// +/// These operations fuse post-processing into the GEMM epilogue, avoiding extra +/// kernel launches and memory round-trips compared to separate matmul_bias + activation. +/// +/// # Performance +/// +/// For a typical `Linear + ReLU` pattern: +/// - **Unfused**: `temp = A @ B + bias` (write temp), `out = relu(temp)` (read temp, write out) +/// - **Fused**: `out = relu(A @ B + bias)` (single write) +/// +/// This saves one full read+write of the output matrix. +/// +/// # Backend Support +/// +/// | Backend | Supported DTypes | Notes | +/// |---------|------------------|-------| +/// | CPU | All dtypes | SIMD-accelerated activations | +/// | CUDA | F32, F64, F16, BF16 | Fused in GEMM epilogue | +/// | WebGPU | F32 | Per-activation entry points | +pub trait GemmEpilogueOps { + /// Fused GEMM + bias + activation: `activation(A @ B + bias)` + /// + /// # Arguments + /// + /// * `a` - Input tensor of shape `[..., M, K]` + /// * `b` - Weight tensor of shape `[..., K, N]` + /// * `bias` - Bias tensor of shape `[N]` (1D, broadcast across rows) + /// * `activation` - Activation function to apply element-wise after bias addition + /// + /// # Returns + /// + /// Output tensor of shape `[..., M, N]` + fn matmul_bias_activation( + &self, + a: &Tensor, + b: &Tensor, + bias: &Tensor, + activation: GemmActivation, + ) -> Result>; + + /// Fused GEMM + bias + residual: `A @ B + bias + residual` + /// + /// # Arguments + /// + /// * `a` - Input tensor of shape `[..., M, K]` + /// * `b` - Weight tensor of shape `[..., K, N]` + /// * `bias` - Bias tensor of shape `[N]` (1D, broadcast across rows) + /// * `residual` - Residual tensor of shape `[..., M, N]` (same shape as output) + /// + /// # Returns + /// + /// Output tensor of shape `[..., M, N]` + fn matmul_bias_residual( + &self, + a: &Tensor, + b: &Tensor, + bias: &Tensor, + residual: &Tensor, + ) -> Result>; + + /// Backward pass for fused GEMM + bias + activation. + /// + /// Computes gradients for `activation(A @ B + bias)`. + /// + /// # Arguments + /// + /// * `grad` - Gradient of the loss w.r.t. the output, shape `[..., M, N]` + /// * `a` - Input tensor from forward pass, shape `[..., M, K]` + /// * `b` - Weight tensor from forward pass, shape `[..., K, N]` + /// * `bias` - Bias tensor from forward pass, shape `[N]` + /// * `activation` - Activation function used in forward pass + /// + /// # Returns + /// + /// Tuple of `(d_a, d_b, d_bias)`: + /// * `d_a` - Gradient w.r.t. input A, shape `[..., M, K]` + /// * `d_b` - Gradient w.r.t. weight B, shape `[..., K, N]` + /// * `d_bias` - Gradient w.r.t. bias, shape `[N]` + fn matmul_bias_activation_bwd( + &self, + grad: &Tensor, + a: &Tensor, + b: &Tensor, + bias: &Tensor, + activation: GemmActivation, + ) -> Result<(Tensor, Tensor, Tensor)>; +} diff --git a/src/ops/traits/indexing.rs b/src/ops/traits/indexing.rs index 2c3515a6..dfde8f6b 100644 --- a/src/ops/traits/indexing.rs +++ b/src/ops/traits/indexing.rs @@ -24,7 +24,7 @@ pub enum ScatterReduceOp { } /// Validate that indices tensor has an integer dtype (I32 or I64). -fn validate_index_dtype(indices: &Tensor) -> Result<()> { +fn validate_index_dtype>(indices: &Tensor) -> Result<()> { match indices.dtype() { DType::I32 | DType::I64 => Ok(()), other => Err(Error::InvalidArgument { @@ -228,7 +228,10 @@ pub trait IndexingOps { /// # Returns /// /// Tensor of shape `indices.shape()` with gathered values - fn take(&self, tensor: &Tensor, indices: &Tensor) -> Result> { + fn take(&self, tensor: &Tensor, indices: &Tensor) -> Result> + where + R: Runtime, + { validate_index_dtype(indices)?; let flat = tensor.contiguous().flatten()?; let indices_flat = indices.contiguous().flatten()?; @@ -250,12 +253,10 @@ pub trait IndexingOps { /// # Returns /// /// New tensor with the same shape as `tensor` and updated values - fn put( - &self, - tensor: &Tensor, - indices: &Tensor, - values: &Tensor, - ) -> Result> { + fn put(&self, tensor: &Tensor, indices: &Tensor, values: &Tensor) -> Result> + where + R: Runtime, + { validate_index_dtype(indices)?; let flat = tensor.contiguous().flatten()?; let indices_flat = indices.contiguous().flatten()?; @@ -562,4 +563,33 @@ pub trait IndexingOps { feature: "IndexingOps::gather_2d", }) } + + /// Assign `src` into a slice of `dst` along dimension `dim` starting at `start`. + /// + /// Returns a new tensor equal to `dst` except that the region + /// `dst[..., start..start+src.shape[dim], ...]` is replaced by `src`. + /// + /// # Arguments + /// + /// * `dst` - Destination tensor + /// * `src` - Source tensor. Must have same shape as `dst` except at `dim`, + /// where `src.shape[dim] + start <= dst.shape[dim]` + /// * `dim` - Dimension along which to assign + /// * `start` - Starting index in `dst` along `dim` + /// + /// # Returns + /// + /// New tensor with the slice replaced + fn slice_assign( + &self, + dst: &Tensor, + src: &Tensor, + dim: usize, + start: usize, + ) -> Result> { + let _ = (dst, src, dim, start); + Err(Error::NotImplemented { + feature: "IndexingOps::slice_assign", + }) + } } diff --git a/src/ops/traits/mod.rs b/src/ops/traits/mod.rs index 9d0faeb3..d3aa3d79 100644 --- a/src/ops/traits/mod.rs +++ b/src/ops/traits/mod.rs @@ -13,6 +13,8 @@ mod conv; mod cumulative; mod distance; mod einsum; +mod fp8_matmul; +mod gemm_epilogue; mod indexing; mod kernel; mod linalg; @@ -27,6 +29,8 @@ mod scalar; mod semiring_matmul; mod shape; mod sorting; +#[cfg(feature = "sparse")] +mod sparse_24; mod statistics; mod tensor_ops; mod type_conversion; @@ -43,6 +47,8 @@ pub use conv::{ConvOps, PaddingMode}; pub use cumulative::CumulativeOps; pub use distance::{DistanceMetric, DistanceOps}; pub use einsum::EinsumOps; +pub use fp8_matmul::Fp8MatmulOps; +pub use gemm_epilogue::{GemmActivation, GemmEpilogueOps}; pub use indexing::{IndexingOps, ScatterReduceOp}; pub use kernel::Kernel; pub use linalg::LinalgOps; @@ -57,6 +63,8 @@ pub use scalar::ScalarOps; pub use semiring_matmul::SemiringMatmulOps; pub use shape::ShapeOps; pub use sorting::SortingOps; +#[cfg(feature = "sparse")] +pub use sparse_24::Sparse24Ops; pub use statistics::StatisticalOps; pub use tensor_ops::TensorOps; pub use type_conversion::TypeConversionOps; diff --git a/src/ops/traits/normalization.rs b/src/ops/traits/normalization.rs index ec982797..2a14812a 100644 --- a/src/ops/traits/normalization.rs +++ b/src/ops/traits/normalization.rs @@ -45,4 +45,131 @@ pub trait NormalizationOps { feature: "NormalizationOps::layer_norm", }) } + + /// Group Normalization: normalize over groups of channels. + /// + /// Divides channels into `num_groups` groups and normalizes each group + /// independently. Used in some vision architectures and diffusion models. + /// + /// # Arguments + /// + /// * `input` - Input tensor of shape `[batch, channels, ...]` + /// * `weight` - Scale (gamma) of shape `[channels]` + /// * `bias` - Bias (beta) of shape `[channels]` + /// * `num_groups` - Number of groups (must divide channels evenly) + /// * `eps` - Small constant for numerical stability + fn group_norm( + &self, + input: &Tensor, + weight: &Tensor, + bias: &Tensor, + num_groups: usize, + eps: f32, + ) -> Result> { + let _ = (input, weight, bias, num_groups, eps); + Err(Error::NotImplemented { + feature: "NormalizationOps::group_norm", + }) + } + + /// Fused Add + RMS Normalization: pre_norm = x + residual, output = rms_norm(pre_norm, weight, eps) + /// + /// Saves one full memory pass vs separate add + rms_norm. Used in every + /// transformer residual connection. Returns `(output, pre_norm)` where + /// `pre_norm` is needed for backward pass and residual chaining. + /// + /// # Arguments + /// + /// * `x` - Input tensor of shape `[..., hidden_size]` + /// * `residual` - Residual tensor of same shape as `x` + /// * `weight` - Weight tensor of shape `[hidden_size]` + /// * `eps` - Small constant for numerical stability + fn fused_add_rms_norm( + &self, + x: &Tensor, + residual: &Tensor, + weight: &Tensor, + eps: f32, + ) -> Result<(Tensor, Tensor)> { + let _ = (x, residual, weight, eps); + Err(Error::NotImplemented { + feature: "NormalizationOps::fused_add_rms_norm", + }) + } + + /// Backward pass for fused add + RMS normalization. + /// + /// Returns `(d_input_residual, d_weight)` where `d_input_residual` is the + /// gradient for both `x` and `residual` (they share the same gradient since + /// `d(x + residual)/dx = d(x + residual)/d(residual) = 1`). + /// + /// # Arguments + /// + /// * `grad` - Upstream gradient of shape `[..., hidden_size]` + /// * `pre_norm` - The `x + residual` value from forward pass + /// * `weight` - Weight tensor of shape `[hidden_size]` + /// * `eps` - Same eps used in forward pass + fn fused_add_rms_norm_bwd( + &self, + grad: &Tensor, + pre_norm: &Tensor, + weight: &Tensor, + eps: f32, + ) -> Result<(Tensor, Tensor)> { + let _ = (grad, pre_norm, weight, eps); + Err(Error::NotImplemented { + feature: "NormalizationOps::fused_add_rms_norm_bwd", + }) + } + + /// Fused Add + Layer Normalization: pre_norm = x + residual, output = layer_norm(pre_norm, weight, bias, eps) + /// + /// Saves one full memory pass vs separate add + layer_norm. + /// Returns `(output, pre_norm)`. + /// + /// # Arguments + /// + /// * `x` - Input tensor of shape `[..., hidden_size]` + /// * `residual` - Residual tensor of same shape as `x` + /// * `weight` - Weight (gamma) tensor of shape `[hidden_size]` + /// * `bias` - Bias (beta) tensor of shape `[hidden_size]` + /// * `eps` - Small constant for numerical stability + fn fused_add_layer_norm( + &self, + x: &Tensor, + residual: &Tensor, + weight: &Tensor, + bias: &Tensor, + eps: f32, + ) -> Result<(Tensor, Tensor)> { + let _ = (x, residual, weight, bias, eps); + Err(Error::NotImplemented { + feature: "NormalizationOps::fused_add_layer_norm", + }) + } + + /// Backward pass for fused add + layer normalization. + /// + /// Returns `(d_input_residual, d_weight, d_bias)`. + /// + /// # Arguments + /// + /// * `grad` - Upstream gradient of shape `[..., hidden_size]` + /// * `pre_norm` - The `x + residual` value from forward pass + /// * `weight` - Weight (gamma) tensor of shape `[hidden_size]` + /// * `bias` - Bias (beta) tensor of shape `[hidden_size]` + /// * `eps` - Same eps used in forward pass + fn fused_add_layer_norm_bwd( + &self, + grad: &Tensor, + pre_norm: &Tensor, + weight: &Tensor, + bias: &Tensor, + eps: f32, + ) -> Result<(Tensor, Tensor, Tensor)> { + let _ = (grad, pre_norm, weight, bias, eps); + Err(Error::NotImplemented { + feature: "NormalizationOps::fused_add_layer_norm_bwd", + }) + } } diff --git a/src/ops/traits/random.rs b/src/ops/traits/random.rs index bd456933..20de4603 100644 --- a/src/ops/traits/random.rs +++ b/src/ops/traits/random.rs @@ -32,6 +32,28 @@ pub trait RandomOps { }) } + /// Generate uniform random values in [0, 1) with a deterministic seed + /// + /// Same as `rand()` but uses the provided seed for reproducible output. + /// Calling with the same seed and shape always produces the same tensor. + /// + /// # Arguments + /// + /// * `shape` - Shape of the output tensor + /// * `dtype` - Data type of the output tensor (must be floating point) + /// * `seed` - Deterministic seed for the PRNG + fn rand_seeded( + &self, + shape: &[usize], + dtype: crate::dtype::DType, + seed: u64, + ) -> Result> { + let _ = (shape, dtype, seed); + Err(Error::NotImplemented { + feature: "RandomOps::rand_seeded", + }) + } + /// Generate standard normal random values (mean=0, std=1) /// /// Creates a tensor filled with random values from standard normal distribution N(0, 1). diff --git a/src/ops/traits/scalar.rs b/src/ops/traits/scalar.rs index 0e466365..0c0b3e1f 100644 --- a/src/ops/traits/scalar.rs +++ b/src/ops/traits/scalar.rs @@ -25,4 +25,10 @@ pub trait ScalarOps: TensorOps { /// Reverse subtract: scalar - a fn rsub_scalar(&self, a: &Tensor, scalar: f64) -> Result>; + + /// Fused multiply-add scalar: a * scale + bias + /// + /// Applies an affine transform to each element in a single pass. + /// Common in normalization (scale + shift) and quantization. + fn fused_mul_add_scalar(&self, a: &Tensor, scale: f64, bias: f64) -> Result>; } diff --git a/src/ops/traits/sparse_24.rs b/src/ops/traits/sparse_24.rs new file mode 100644 index 00000000..c628a1fe --- /dev/null +++ b/src/ops/traits/sparse_24.rs @@ -0,0 +1,57 @@ +//! 2:4 structured sparsity operations trait. + +use crate::error::{Error, Result}; +use crate::runtime::Runtime; +use crate::sparse::Sparse24Tensor; +use crate::tensor::Tensor; + +/// Operations for 2:4 structured sparsity +/// +/// Provides pruning (dense → 2:4 compressed), decompression (2:4 → dense), +/// and sparse matrix multiplication using the compressed format. +pub trait Sparse24Ops { + /// Prune a dense matrix to 2:4 structured sparsity + /// + /// For each group of 4 consecutive elements along the K dimension, + /// keeps the 2 with largest magnitude and zeros the rest. + /// + /// # Arguments + /// * `dense` - Input tensor of shape [M, K] where K is divisible by 4 + /// + /// # Returns + /// A `Sparse24Tensor` containing the compressed values and metadata + fn prune_to_24(&self, dense: &Tensor) -> Result> { + let _ = dense; + Err(Error::NotImplemented { + feature: "Sparse24Ops::prune_to_24", + }) + } + + /// Decompress a 2:4 sparse tensor back to dense format + /// + /// Reconstructs the dense [M, K] matrix by placing non-zero values + /// at their original positions (zeros elsewhere). + fn sparse_24_to_dense(&self, sparse: &Sparse24Tensor) -> Result> { + let _ = sparse; + Err(Error::NotImplemented { + feature: "Sparse24Ops::sparse_24_to_dense", + }) + } + + /// Matrix multiplication with 2:4 sparse weight matrix + /// + /// Computes `input @ weight^T` where weight is in 2:4 compressed format. + /// + /// # Arguments + /// * `input` - Dense input tensor of shape [N, K] + /// * `weight` - 2:4 sparse weight of original shape [M, K] + /// + /// # Returns + /// Dense output tensor of shape [N, M] + fn sparse_24_matmul(&self, input: &Tensor, weight: &Sparse24Tensor) -> Result> { + let _ = (input, weight); + Err(Error::NotImplemented { + feature: "Sparse24Ops::sparse_24_matmul", + }) + } +} diff --git a/src/ops/traits/tensor_ops.rs b/src/ops/traits/tensor_ops.rs index 1d2ad98a..7ac9e382 100644 --- a/src/ops/traits/tensor_ops.rs +++ b/src/ops/traits/tensor_ops.rs @@ -6,8 +6,8 @@ use crate::runtime::Runtime; use super::{ ActivationOps, BinaryOps, ComplexOps, ConditionalOps, CumulativeOps, DistanceOps, IndexingOps, - LinalgOps, MatmulOps, NormalizationOps, RandomOps, ReduceOps, SemiringMatmulOps, ShapeOps, - SortingOps, StatisticalOps, TypeConversionOps, UnaryOps, UtilityOps, + LinalgOps, MatmulOps, NormalizationOps, ReduceOps, SemiringMatmulOps, ShapeOps, SortingOps, + StatisticalOps, TypeConversionOps, UnaryOps, UtilityOps, }; /// Core tensor operations trait @@ -43,7 +43,6 @@ pub trait TensorOps: + ShapeOps + SortingOps + StatisticalOps - + RandomOps + UnaryOps + BinaryOps + SemiringMatmulOps diff --git a/src/ops/wgpu/activation.rs b/src/ops/wgpu/activation.rs index 8cdb82d0..6d2aa474 100644 --- a/src/ops/wgpu/activation.rs +++ b/src/ops/wgpu/activation.rs @@ -2,10 +2,12 @@ use crate::error::Result; use crate::ops::ActivationOps; +use crate::ops::impl_generic::activation::{dropout_impl, log_softmax_impl, softplus_impl}; use crate::runtime::wgpu::WgpuClient; use crate::runtime::wgpu::WgpuRuntime; use crate::runtime::wgpu::ops::native::{ - native_parametric_activation, native_softmax, native_unary_op, + native_fused_activation_mul_bwd, native_fused_activation_mul_fwd, native_parametric_activation, + native_softmax, native_softmax_bwd, native_unary_op, }; use crate::tensor::Tensor; @@ -22,6 +24,15 @@ impl ActivationOps for WgpuClient { native_softmax(self, a, dim) } + fn softmax_bwd( + &self, + grad: &Tensor, + output: &Tensor, + dim: isize, + ) -> Result> { + native_softmax_bwd(self, grad, output, dim) + } + fn silu(&self, a: &Tensor) -> Result> { native_unary_op(self, "silu", a) } @@ -41,4 +52,89 @@ impl ActivationOps for WgpuClient { fn elu(&self, a: &Tensor, alpha: f64) -> Result> { native_parametric_activation(self, "elu", a, alpha) } + + fn silu_mul( + &self, + a: &Tensor, + b: &Tensor, + ) -> Result> { + native_fused_activation_mul_fwd(self, "silu_mul", a, b) + } + + fn gelu_mul( + &self, + a: &Tensor, + b: &Tensor, + ) -> Result> { + native_fused_activation_mul_fwd(self, "gelu_mul", a, b) + } + + fn relu_mul( + &self, + a: &Tensor, + b: &Tensor, + ) -> Result> { + native_fused_activation_mul_fwd(self, "relu_mul", a, b) + } + + fn sigmoid_mul( + &self, + a: &Tensor, + b: &Tensor, + ) -> Result> { + native_fused_activation_mul_fwd(self, "sigmoid_mul", a, b) + } + + fn silu_mul_bwd( + &self, + grad: &Tensor, + a: &Tensor, + b: &Tensor, + ) -> Result<(Tensor, Tensor)> { + native_fused_activation_mul_bwd(self, "silu_mul_bwd", grad, a, b) + } + + fn gelu_mul_bwd( + &self, + grad: &Tensor, + a: &Tensor, + b: &Tensor, + ) -> Result<(Tensor, Tensor)> { + native_fused_activation_mul_bwd(self, "gelu_mul_bwd", grad, a, b) + } + + fn relu_mul_bwd( + &self, + grad: &Tensor, + a: &Tensor, + b: &Tensor, + ) -> Result<(Tensor, Tensor)> { + native_fused_activation_mul_bwd(self, "relu_mul_bwd", grad, a, b) + } + + fn sigmoid_mul_bwd( + &self, + grad: &Tensor, + a: &Tensor, + b: &Tensor, + ) -> Result<(Tensor, Tensor)> { + native_fused_activation_mul_bwd(self, "sigmoid_mul_bwd", grad, a, b) + } + + fn softplus(&self, a: &Tensor) -> Result> { + softplus_impl(self, a) + } + + fn log_softmax(&self, a: &Tensor, dim: isize) -> Result> { + log_softmax_impl(self, a, dim) + } + + fn dropout( + &self, + a: &Tensor, + p: f64, + training: bool, + ) -> Result> { + dropout_impl(self, a, p, training) + } } diff --git a/src/ops/wgpu/binary.rs b/src/ops/wgpu/binary.rs index 6f22d344..61a09aa9 100644 --- a/src/ops/wgpu/binary.rs +++ b/src/ops/wgpu/binary.rs @@ -4,7 +4,9 @@ use crate::error::Result; use crate::ops::BinaryOps; use crate::runtime::wgpu::WgpuClient; use crate::runtime::wgpu::WgpuRuntime; -use crate::runtime::wgpu::ops::native::native_binary_op; +use crate::runtime::wgpu::ops::native::{ + native_binary_op, native_fused_add_mul, native_fused_mul_add, +}; use crate::tensor::Tensor; impl BinaryOps for WgpuClient { @@ -51,4 +53,22 @@ impl BinaryOps for WgpuClient { ) -> Result> { native_binary_op(self, "atan2", y, x) } + + fn fused_mul_add( + &self, + a: &Tensor, + b: &Tensor, + c: &Tensor, + ) -> Result> { + native_fused_mul_add(self, a, b, c) + } + + fn fused_add_mul( + &self, + a: &Tensor, + b: &Tensor, + c: &Tensor, + ) -> Result> { + native_fused_add_mul(self, a, b, c) + } } diff --git a/src/ops/wgpu/fp8_matmul.rs b/src/ops/wgpu/fp8_matmul.rs new file mode 100644 index 00000000..ea2180d0 --- /dev/null +++ b/src/ops/wgpu/fp8_matmul.rs @@ -0,0 +1,40 @@ +//! WebGPU implementation of FP8 matrix multiplication operations. +//! +//! WebGPU is intentionally limited to 32-bit types (F32, I32, U32). +//! FP8 dtypes are not supported on the WebGPU backend. + +use crate::dtype::DType; +use crate::error::{Error, Result}; +use crate::ops::Fp8MatmulOps; +use crate::runtime::wgpu::{WgpuClient, WgpuRuntime}; +use crate::tensor::Tensor; + +impl Fp8MatmulOps for WgpuClient { + fn fp8_matmul( + &self, + a: &Tensor, + _b: &Tensor, + _scale_a: f32, + _scale_b: f32, + _out_dtype: DType, + ) -> Result> { + Err(Error::UnsupportedDType { + dtype: a.dtype(), + op: "fp8_matmul (WebGPU does not support FP8 types)", + }) + } + + fn fp8_matmul_e5m2( + &self, + a: &Tensor, + _b: &Tensor, + _scale_a: f32, + _scale_b: f32, + _out_dtype: DType, + ) -> Result> { + Err(Error::UnsupportedDType { + dtype: a.dtype(), + op: "fp8_matmul_e5m2 (WebGPU does not support FP8 types)", + }) + } +} diff --git a/src/ops/wgpu/gemm_epilogue.rs b/src/ops/wgpu/gemm_epilogue.rs new file mode 100644 index 00000000..e5806dc4 --- /dev/null +++ b/src/ops/wgpu/gemm_epilogue.rs @@ -0,0 +1,46 @@ +//! WebGPU implementation of GEMM epilogue operations. + +use crate::error::{Error, Result}; +use crate::ops::{GemmActivation, GemmEpilogueOps}; +use crate::runtime::wgpu::ops::native::{native_gemm_bias_activation, native_gemm_bias_residual}; +use crate::runtime::wgpu::{WgpuClient, WgpuRuntime}; +use crate::tensor::Tensor; + +impl GemmEpilogueOps for WgpuClient { + fn matmul_bias_activation( + &self, + a: &Tensor, + b: &Tensor, + bias: &Tensor, + activation: GemmActivation, + ) -> Result> { + native_gemm_bias_activation(self, a, b, bias, activation) + } + + fn matmul_bias_residual( + &self, + a: &Tensor, + b: &Tensor, + bias: &Tensor, + residual: &Tensor, + ) -> Result> { + native_gemm_bias_residual(self, a, b, bias, residual) + } + + fn matmul_bias_activation_bwd( + &self, + _grad: &Tensor, + _a: &Tensor, + _b: &Tensor, + _bias: &Tensor, + _activation: GemmActivation, + ) -> Result<( + Tensor, + Tensor, + Tensor, + )> { + Err(Error::NotImplemented { + feature: "matmul_bias_activation_bwd on WebGPU; use CPU backend for training", + }) + } +} diff --git a/src/ops/wgpu/indexing.rs b/src/ops/wgpu/indexing.rs index 372ba88d..68439c02 100644 --- a/src/ops/wgpu/indexing.rs +++ b/src/ops/wgpu/indexing.rs @@ -14,6 +14,7 @@ use crate::runtime::wgpu::ops::helpers::{ use crate::runtime::wgpu::ops::native::{ native_argreduce_op, native_embedding_lookup, native_gather, native_index_put, native_index_select, native_masked_fill, native_masked_select, native_scatter, + native_slice_assign, }; use crate::runtime::wgpu::shaders::{ launch_bincount, launch_gather_2d, launch_gather_nd, launch_scatter_reduce, @@ -602,4 +603,14 @@ impl IndexingOps for WgpuClient { Ok(output) } + + fn slice_assign( + &self, + dst: &Tensor, + src: &Tensor, + dim: usize, + start: usize, + ) -> Result> { + native_slice_assign(self, dst, src, dim, start) + } } diff --git a/src/ops/wgpu/mod.rs b/src/ops/wgpu/mod.rs index b685d604..6bcdaf4c 100644 --- a/src/ops/wgpu/mod.rs +++ b/src/ops/wgpu/mod.rs @@ -25,8 +25,12 @@ pub mod random; pub mod reduce; pub mod scalar; pub mod shape; +#[cfg(feature = "sparse")] +pub mod sparse_24; pub mod sorting; pub mod statistics; pub mod type_conversion; pub mod unary; pub mod utility; +pub mod fp8_matmul; +pub mod gemm_epilogue; diff --git a/src/ops/wgpu/multivariate.rs b/src/ops/wgpu/multivariate.rs index 7146ead5..0e6247d7 100644 --- a/src/ops/wgpu/multivariate.rs +++ b/src/ops/wgpu/multivariate.rs @@ -109,11 +109,11 @@ fn dispatch_multinomial_count_shader( let output = Tensor::::empty(&[n_samples, k], DType::F32, client.device()); // Get buffers - let cdf_buf = get_buffer(cdf.storage().ptr()) - .ok_or_else(|| Error::Internal("CDF buffer not found".to_string()))?; - let uniforms_buf = get_buffer(uniforms.storage().ptr()) + let cdf_buf = + get_buffer(cdf.ptr()).ok_or_else(|| Error::Internal("CDF buffer not found".to_string()))?; + let uniforms_buf = get_buffer(uniforms.ptr()) .ok_or_else(|| Error::Internal("Uniforms buffer not found".to_string()))?; - let output_buf = get_buffer(output.storage().ptr()) + let output_buf = get_buffer(output.ptr()) .ok_or_else(|| Error::Internal("Output buffer not found".to_string()))?; // Create params buffer diff --git a/src/ops/wgpu/normalization.rs b/src/ops/wgpu/normalization.rs index 37f1570f..1cd86ae9 100644 --- a/src/ops/wgpu/normalization.rs +++ b/src/ops/wgpu/normalization.rs @@ -4,7 +4,10 @@ use crate::error::Result; use crate::ops::NormalizationOps; use crate::runtime::wgpu::WgpuClient; use crate::runtime::wgpu::WgpuRuntime; -use crate::runtime::wgpu::ops::native::{native_layer_norm, native_rms_norm}; +use crate::runtime::wgpu::ops::native::{ + native_fused_add_layer_norm, native_fused_add_layer_norm_bwd, native_fused_add_rms_norm, + native_fused_add_rms_norm_bwd, native_group_norm, native_layer_norm, native_rms_norm, +}; use crate::tensor::Tensor; impl NormalizationOps for WgpuClient { @@ -26,4 +29,61 @@ impl NormalizationOps for WgpuClient { ) -> Result> { native_layer_norm(self, a, weight, bias, eps) } + + fn group_norm( + &self, + input: &Tensor, + weight: &Tensor, + bias: &Tensor, + num_groups: usize, + eps: f32, + ) -> Result> { + native_group_norm(self, input, weight, bias, num_groups, eps) + } + + fn fused_add_rms_norm( + &self, + x: &Tensor, + residual: &Tensor, + weight: &Tensor, + eps: f32, + ) -> Result<(Tensor, Tensor)> { + native_fused_add_rms_norm(self, x, residual, weight, eps) + } + + fn fused_add_layer_norm( + &self, + x: &Tensor, + residual: &Tensor, + weight: &Tensor, + bias: &Tensor, + eps: f32, + ) -> Result<(Tensor, Tensor)> { + native_fused_add_layer_norm(self, x, residual, weight, bias, eps) + } + + fn fused_add_rms_norm_bwd( + &self, + grad: &Tensor, + pre_norm: &Tensor, + weight: &Tensor, + eps: f32, + ) -> Result<(Tensor, Tensor)> { + native_fused_add_rms_norm_bwd(self, grad, pre_norm, weight, eps) + } + + fn fused_add_layer_norm_bwd( + &self, + grad: &Tensor, + pre_norm: &Tensor, + weight: &Tensor, + bias: &Tensor, + eps: f32, + ) -> Result<( + Tensor, + Tensor, + Tensor, + )> { + native_fused_add_layer_norm_bwd(self, grad, pre_norm, weight, bias, eps) + } } diff --git a/src/ops/wgpu/random.rs b/src/ops/wgpu/random.rs index d7e8f212..9104f726 100644 --- a/src/ops/wgpu/random.rs +++ b/src/ops/wgpu/random.rs @@ -66,6 +66,46 @@ impl RandomOps for WgpuClient { Ok(out) } + fn rand_seeded(&self, shape: &[usize], dtype: DType, seed: u64) -> Result> { + if !matches!(dtype, DType::F32) { + return Err(Error::UnsupportedDType { + dtype, + op: "rand_seeded", + }); + } + + let numel: usize = shape.iter().product(); + if numel == 0 { + return Ok(Tensor::empty(shape, dtype, self.device())); + } + + let out = alloc_output(self, shape, dtype); + let out_buf = get_tensor_buffer(&out)?; + + // Truncate u64 seed to u32 — WGSL has no native u64 support. + // Determinism is still guaranteed: same seed → same u32 → same output. + let seed = seed as u32; + + let params = RandParams { + numel: numel as u32, + seed, + _pad1: 0, + _pad2: 0, + }; + let params_buf = create_params_buffer(self, ¶ms); + + shape::launch_rand( + self.pipeline_cache(), + self.wgpu_queue(), + &out_buf, + ¶ms_buf, + numel, + dtype, + )?; + + Ok(out) + } + fn randn(&self, shape: &[usize], dtype: DType) -> Result> { // WebGPU randn only supports F32 if !matches!(dtype, DType::F32) { diff --git a/src/ops/wgpu/scalar.rs b/src/ops/wgpu/scalar.rs index 5197cdef..2e1f792c 100644 --- a/src/ops/wgpu/scalar.rs +++ b/src/ops/wgpu/scalar.rs @@ -2,7 +2,7 @@ use crate::error::Result; use crate::ops::ScalarOps; -use crate::runtime::wgpu::ops::native::native_scalar_op; +use crate::runtime::wgpu::ops::native::{native_fused_mul_add_scalar, native_scalar_op}; use crate::runtime::wgpu::{WgpuClient, WgpuRuntime}; use crate::tensor::Tensor; @@ -30,4 +30,13 @@ impl ScalarOps for WgpuClient { fn rsub_scalar(&self, a: &Tensor, scalar: f64) -> Result> { native_scalar_op(self, "rsub_scalar", a, scalar) } + + fn fused_mul_add_scalar( + &self, + a: &Tensor, + scale: f64, + bias: f64, + ) -> Result> { + native_fused_mul_add_scalar(self, a, scale, bias) + } } diff --git a/src/ops/wgpu/shape.rs b/src/ops/wgpu/shape.rs index dd2ece85..9125f9e9 100644 --- a/src/ops/wgpu/shape.rs +++ b/src/ops/wgpu/shape.rs @@ -4,8 +4,8 @@ use crate::dtype::DType; use crate::error::{Error, Result}; use crate::ops::ShapeOps; use crate::ops::impl_generic::{repeat_interleave_impl, unfold_impl}; -use crate::runtime::shape_ops; -use crate::runtime::shape_ops::{validate_cat, validate_stack}; +use crate::runtime::common::shape_ops; +use crate::runtime::common::shape_ops::{validate_cat, validate_stack}; use crate::runtime::wgpu::WgpuClient; use crate::runtime::wgpu::WgpuRuntime; use crate::runtime::wgpu::ops::helpers::{ @@ -34,11 +34,7 @@ impl ShapeOps for WgpuClient { // Copy data from each tensor using WGSL kernel let mut cat_offset = 0usize; for &tensor in tensors { - let tensor_contig = if tensor.is_contiguous() { - tensor.clone() - } else { - tensor.contiguous() - }; + let tensor_contig = tensor.contiguous(); let src_cat_size = tensor.shape()[cat_params.dim_idx]; let total_elements = cat_params.outer_size * src_cat_size * cat_params.inner_size; diff --git a/src/ops/wgpu/sorting.rs b/src/ops/wgpu/sorting.rs index 29a4ee14..38b473f7 100644 --- a/src/ops/wgpu/sorting.rs +++ b/src/ops/wgpu/sorting.rs @@ -1,5 +1,8 @@ //! Sorting operations for WebGPU runtime +/// Maximum sort dimension size supported by the WebGPU bitonic sort (shared memory limit). +const MAX_SHARED_SORT_SIZE: usize = 512; + use crate::dtype::DType; use crate::error::{Error, Result}; use crate::ops::{CumulativeOps, SortingOps, TypeConversionOps}; @@ -39,14 +42,13 @@ impl SortingOps for WgpuClient { let sort_size = shape[dim_idx]; // Check sort size limit (WebGPU bitonic sort in shared memory) - if sort_size > crate::runtime::wgpu::shaders::generator::MAX_SHARED_SORT_SIZE { + if sort_size > MAX_SHARED_SORT_SIZE { return Err(Error::backend_limitation( "WebGPU", "sort", format!( "max {} elements per dimension, got {}", - crate::runtime::wgpu::shaders::generator::MAX_SHARED_SORT_SIZE, - sort_size + MAX_SHARED_SORT_SIZE, sort_size ), )); } @@ -123,14 +125,13 @@ impl SortingOps for WgpuClient { let dim_idx = normalize_dim(dim, ndim)?; let sort_size = shape[dim_idx]; - if sort_size > crate::runtime::wgpu::shaders::generator::MAX_SHARED_SORT_SIZE { + if sort_size > MAX_SHARED_SORT_SIZE { return Err(Error::backend_limitation( "WebGPU", "sort_with_indices", format!( "max {} elements per dimension, got {}", - crate::runtime::wgpu::shaders::generator::MAX_SHARED_SORT_SIZE, - sort_size + MAX_SHARED_SORT_SIZE, sort_size ), )); } @@ -197,14 +198,13 @@ impl SortingOps for WgpuClient { let dim_idx = normalize_dim(dim, ndim)?; let sort_size = shape[dim_idx]; - if sort_size > crate::runtime::wgpu::shaders::generator::MAX_SHARED_SORT_SIZE { + if sort_size > MAX_SHARED_SORT_SIZE { return Err(Error::backend_limitation( "WebGPU", "argsort", format!( "max {} elements per dimension, got {}", - crate::runtime::wgpu::shaders::generator::MAX_SHARED_SORT_SIZE, - sort_size + MAX_SHARED_SORT_SIZE, sort_size ), )); } @@ -277,14 +277,13 @@ impl SortingOps for WgpuClient { }); } - if sort_size > crate::runtime::wgpu::shaders::generator::MAX_SHARED_SORT_SIZE { + if sort_size > MAX_SHARED_SORT_SIZE { return Err(Error::backend_limitation( "WebGPU", "topk", format!( "max {} elements per dimension, got {}", - crate::runtime::wgpu::shaders::generator::MAX_SHARED_SORT_SIZE, - sort_size + MAX_SHARED_SORT_SIZE, sort_size ), )); } diff --git a/src/ops/wgpu/sparse_24.rs b/src/ops/wgpu/sparse_24.rs new file mode 100644 index 00000000..274a361b --- /dev/null +++ b/src/ops/wgpu/sparse_24.rs @@ -0,0 +1,149 @@ +//! WebGPU implementation of 2:4 structured sparsity operations. +//! +//! WebGPU uses decompress + standard matmul (no hardware sparse tensor cores). +//! F32 only (WebGPU constraint). + +use crate::dtype::DType; +use crate::error::{Error, Result}; +use crate::ops::MatmulOps; +use crate::ops::traits::Sparse24Ops; +use crate::runtime::ensure_contiguous; +use crate::runtime::wgpu::WgpuClient; +use crate::runtime::wgpu::WgpuRuntime; +use crate::runtime::wgpu::ops::helpers::{alloc_output, create_params_buffer, get_tensor_buffer}; +use crate::runtime::wgpu::shaders::sparse_24::{ + Sparse24Params, launch_sparse_24_decompress, launch_sparse_24_prune, +}; +use crate::sparse::structured::{Sparse24Tensor, meta_cols_for_k}; +use crate::tensor::Tensor; + +impl Sparse24Ops for WgpuClient { + fn prune_to_24(&self, dense: &Tensor) -> Result> { + if dense.ndim() != 2 { + return Err(Error::InvalidArgument { + arg: "dense", + reason: format!("Expected 2D tensor, got {}D", dense.ndim()), + }); + } + + let dtype = dense.dtype(); + if dtype != DType::F32 { + return Err(Error::UnsupportedDType { + dtype, + op: "sparse_24_prune (WebGPU: F32 only)", + }); + } + + let m = dense.shape()[0]; + let k = dense.shape()[1]; + + if !k.is_multiple_of(4) { + return Err(Error::InvalidArgument { + arg: "dense", + reason: format!("K dimension ({k}) must be divisible by 4 for 2:4 sparsity"), + }); + } + + let dense_contig = ensure_contiguous(dense); + let half_k = k / 2; + let mc = meta_cols_for_k(k); + let num_groups = k / 4; + let total_groups = m * num_groups; + + let compressed = alloc_output(self, &[m, half_k], dtype); + let metadata = alloc_output(self, &[m, mc], DType::U32); + + // wgpu buffers are zero-initialized by default (spec requirement) + + let dense_buf = get_tensor_buffer(&dense_contig)?; + let comp_buf = get_tensor_buffer(&compressed)?; + let meta_buf = get_tensor_buffer(&metadata)?; + + let params = Sparse24Params { + total_groups: total_groups as u32, + num_groups_per_row: num_groups as u32, + meta_cols: mc as u32, + half_k: half_k as u32, + k: k as u32, + _pad0: 0, + _pad1: 0, + _pad2: 0, + }; + let params_buf = create_params_buffer(self, ¶ms); + + launch_sparse_24_prune( + self.pipeline_cache(), + self.wgpu_queue(), + &dense_buf, + &comp_buf, + &meta_buf, + ¶ms_buf, + total_groups, + )?; + + Sparse24Tensor::new(compressed, metadata, [m, k]) + } + + fn sparse_24_to_dense( + &self, + sparse: &Sparse24Tensor, + ) -> Result> { + let [m, k] = sparse.shape(); + let dtype = sparse.dtype(); + + if dtype != DType::F32 { + return Err(Error::UnsupportedDType { + dtype, + op: "sparse_24_to_dense (WebGPU: F32 only)", + }); + } + + let num_groups = k / 4; + let total_groups = m * num_groups; + let mc = meta_cols_for_k(k); + let half_k = k / 2; + + let vals = ensure_contiguous(sparse.compressed_values()); + let meta = ensure_contiguous(sparse.metadata()); + let dense = alloc_output(self, &[m, k], dtype); + + let vals_buf = get_tensor_buffer(&vals)?; + let meta_buf = get_tensor_buffer(&meta)?; + let dense_buf = get_tensor_buffer(&dense)?; + + let params = Sparse24Params { + total_groups: total_groups as u32, + num_groups_per_row: num_groups as u32, + meta_cols: mc as u32, + half_k: half_k as u32, + k: k as u32, + _pad0: 0, + _pad1: 0, + _pad2: 0, + }; + let params_buf = create_params_buffer(self, ¶ms); + + launch_sparse_24_decompress( + self.pipeline_cache(), + self.wgpu_queue(), + &vals_buf, + &meta_buf, + &dense_buf, + ¶ms_buf, + total_groups, + )?; + + Ok(dense) + } + + fn sparse_24_matmul( + &self, + input: &Tensor, + weight: &Sparse24Tensor, + ) -> Result> { + // WebGPU: decompress weight to dense, then standard matmul + let dense_weight = self.sparse_24_to_dense(weight)?; + let weight_t = dense_weight.t()?; + self.matmul(input, &weight_t) + } +} diff --git a/src/runtime/allocator.rs b/src/runtime/allocator.rs deleted file mode 100644 index 3bb9b7a9..00000000 --- a/src/runtime/allocator.rs +++ /dev/null @@ -1,145 +0,0 @@ -//! Memory allocator traits and default implementation -//! -//! The Allocator trait provides memory management with optional "freeze" support -//! for graph capture scenarios (e.g., CUDA Graphs). - -/// Memory allocator trait for runtime backends -/// -/// Allocators manage device memory with optional support for "freezing" - -/// a mode where allocations are captured for graph replay. -pub trait Allocator: Clone + Send + Sync { - /// Allocate memory of given size - /// - /// Returns a device pointer (u64) that can be used for operations. - /// Returns `Err(OutOfMemory)` if allocation fails. - fn allocate(&self, size_bytes: usize) -> crate::error::Result; - - /// Deallocate memory - fn deallocate(&self, ptr: u64, size_bytes: usize); - - /// Freeze the allocator for graph capture - /// - /// After freezing, allocations may be captured for replay. - /// Not all allocators support this (returns false by default). - fn freeze(&self) -> bool { - false - } - - /// Unfreeze the allocator - fn unfreeze(&self) { - // Default: no-op - } - - /// Check if the allocator is frozen - fn is_frozen(&self) -> bool { - false - } - - /// Get the total allocated bytes - fn allocated_bytes(&self) -> usize { - 0 // Default: tracking not supported - } -} - -/// Default allocator that delegates to Runtime methods -/// -/// This is a simple allocator that just calls the runtime's allocate/deallocate. -/// It doesn't support freezing or memory tracking. -#[derive(Clone, Debug)] -pub struct DefaultAllocator { - device: D, - allocate_fn: fn(usize, &D) -> crate::error::Result, - deallocate_fn: fn(u64, usize, &D), -} - -impl DefaultAllocator { - /// Create a new default allocator - pub fn new( - device: D, - allocate_fn: fn(usize, &D) -> crate::error::Result, - deallocate_fn: fn(u64, usize, &D), - ) -> Self { - Self { - device, - allocate_fn, - deallocate_fn, - } - } - - /// Get the device this allocator is associated with - pub fn device(&self) -> &D { - &self.device - } -} - -impl Allocator for DefaultAllocator { - fn allocate(&self, size_bytes: usize) -> crate::error::Result { - (self.allocate_fn)(size_bytes, &self.device) - } - - fn deallocate(&self, ptr: u64, size_bytes: usize) { - (self.deallocate_fn)(ptr, size_bytes, &self.device) - } -} - -#[cfg(any(feature = "cuda", feature = "wgpu"))] -/// RAII guard for GPU memory allocations. -/// -/// Ensures memory is deallocated when the guard is dropped, preventing leaks -/// on error paths. Call [`release`](AllocGuard::release) to take ownership of the -/// pointer (e.g., when transferring it into a `Tensor`). -pub struct AllocGuard<'a, A: Allocator> { - allocator: &'a A, - ptr: u64, - size: usize, - released: bool, -} - -#[cfg(any(feature = "cuda", feature = "wgpu"))] -impl<'a, A: Allocator> AllocGuard<'a, A> { - /// Allocate memory and wrap it in a guard. - pub fn new(allocator: &'a A, size_bytes: usize) -> crate::error::Result { - let ptr = allocator.allocate(size_bytes)?; - Ok(Self { - allocator, - ptr, - size: size_bytes, - released: false, - }) - } - - /// Get the raw pointer. - #[inline] - pub fn ptr(&self) -> u64 { - self.ptr - } - - /// Release ownership of the pointer, preventing deallocation on drop. - /// - /// Returns the raw pointer for use in tensor construction. - #[inline] - pub fn release(mut self) -> u64 { - self.released = true; - self.ptr - } -} - -#[cfg(any(feature = "cuda", feature = "wgpu"))] -impl Drop for AllocGuard<'_, A> { - fn drop(&mut self) { - if !self.released && self.ptr != 0 { - self.allocator.deallocate(self.ptr, self.size); - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_default_allocator_trait_bounds() { - fn assert_allocator() {} - assert_allocator::>(); - } -} diff --git a/src/runtime/common/allocator.rs b/src/runtime/common/allocator.rs new file mode 100644 index 00000000..f94918b5 --- /dev/null +++ b/src/runtime/common/allocator.rs @@ -0,0 +1,531 @@ +//! Memory allocator traits and default implementation +//! +//! The Allocator trait provides memory management with optional "freeze" support +//! for graph capture scenarios (e.g., CUDA Graphs). + +/// Allocation statistics for debugging and profiling +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct AllocationStats { + /// Total number of allocations made (cumulative) + pub total_allocations: usize, + /// Total bytes allocated (cumulative) + pub total_bytes: usize, + /// Number of allocations currently live (not yet deallocated) + pub active_allocations: usize, + /// Whether the allocator is currently frozen + pub is_frozen: bool, + /// Peak memory usage in bytes (high-water mark) + pub peak_usage: usize, +} + +/// Memory allocator trait for runtime backends +/// +/// Allocators manage device memory with optional support for "freezing" - +/// a mode where allocations are captured for graph replay. +pub trait Allocator: Clone + Send + Sync { + /// Allocate memory of given size + /// + /// Returns a device pointer (u64) that can be used for operations. + /// Returns `Err(OutOfMemory)` if allocation fails. + fn allocate(&self, size_bytes: usize) -> crate::error::Result; + + /// Deallocate memory + fn deallocate(&self, ptr: u64, size_bytes: usize); + + /// Freeze the allocator for graph capture + /// + /// After freezing, allocations may be captured for replay. + /// Not all allocators support this (returns false by default). + fn freeze(&self) -> bool { + false + } + + /// Unfreeze the allocator + fn unfreeze(&self) { + // Default: no-op + } + + /// Check if the allocator is frozen + fn is_frozen(&self) -> bool { + false + } + + /// Get the total allocated bytes + fn allocated_bytes(&self) -> usize { + 0 // Default: tracking not supported + } + + /// Get allocation statistics + /// + /// Returns detailed allocation stats including active count, peak usage, + /// and frozen state. Default returns zeroed stats for allocators without tracking. + fn stats(&self) -> AllocationStats { + AllocationStats::default() + } + + /// Reset allocator counters and reclaim pooled memory. + /// + /// When `active_allocations == 0`, this zeros out stats counters + /// (total_allocations, total_bytes, peak_usage) and releases any + /// internally pooled/cached buffers back to the OS or driver. + /// + /// # Errors + /// + /// Returns `Err(AllocatorBusy)` if `active_allocations > 0`. + /// Caller must drop all tensors/storage referencing this allocator's + /// memory before calling reset — active allocations mean live + /// Storage references exist, and reclaiming that memory would + /// cause use-after-free. + fn reset(&self) -> crate::error::Result<()> { + Ok(()) + } +} + +/// Default allocator that delegates to Runtime methods +/// +/// This is a simple allocator that just calls the runtime's allocate/deallocate. +/// It doesn't support freezing or memory tracking. +#[derive(Clone, Debug)] +pub struct DefaultAllocator { + device: D, + allocate_fn: fn(usize, &D) -> crate::error::Result, + deallocate_fn: fn(u64, usize, &D), +} + +impl DefaultAllocator { + /// Create a new default allocator + pub fn new( + device: D, + allocate_fn: fn(usize, &D) -> crate::error::Result, + deallocate_fn: fn(u64, usize, &D), + ) -> Self { + Self { + device, + allocate_fn, + deallocate_fn, + } + } + + /// Get the device this allocator is associated with + pub fn device(&self) -> &D { + &self.device + } +} + +impl Allocator for DefaultAllocator { + fn allocate(&self, size_bytes: usize) -> crate::error::Result { + (self.allocate_fn)(size_bytes, &self.device) + } + + fn deallocate(&self, ptr: u64, size_bytes: usize) { + (self.deallocate_fn)(ptr, size_bytes, &self.device) + } +} + +/// Tracking allocator state (behind Arc> for thread-safe sharing) +#[derive(Debug)] +struct TrackingState { + inner: A, + total_allocations: usize, + total_bytes: usize, + active_allocations: usize, + active_bytes: usize, + peak_usage: usize, + frozen: bool, +} + +/// Allocator wrapper that tracks allocation statistics. +/// +/// Wraps any `Allocator` implementation with proper tracking of active +/// allocations, total bytes, peak usage, and freeze/reset support. +/// +/// Thread-safe via `Arc>` — cloning shares the same state. +/// +/// # Example +/// +/// ```ignore +/// let inner = DefaultAllocator::new(device, alloc_fn, dealloc_fn); +/// let tracking = TrackingAllocator::new(inner); +/// +/// let ptr = tracking.allocate(1024)?; +/// assert_eq!(tracking.stats().active_allocations, 1); +/// assert_eq!(tracking.stats().active_bytes(), 1024); +/// +/// tracking.deallocate(ptr, 1024); +/// assert_eq!(tracking.stats().active_allocations, 0); +/// +/// tracking.reset()?; // succeeds: no active allocations +/// ``` +#[derive(Debug)] +pub struct TrackingAllocator { + state: std::sync::Arc>>, +} + +impl Clone for TrackingAllocator { + fn clone(&self) -> Self { + Self { + state: self.state.clone(), + } + } +} + +impl TrackingAllocator { + /// Acquire the inner lock, recovering from poison if another thread panicked. + /// + /// Poisoning means a thread panicked while holding the lock. The tracking + /// counters may be inconsistent, but the inner allocator is still usable. + /// Recovering is safer than panicking the caller. + fn lock(&self) -> std::sync::MutexGuard<'_, TrackingState> { + self.state + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) + } + + /// Create a new tracking allocator wrapping `inner`. + pub fn new(inner: A) -> Self { + Self { + state: std::sync::Arc::new(std::sync::Mutex::new(TrackingState { + inner, + total_allocations: 0, + total_bytes: 0, + active_allocations: 0, + active_bytes: 0, + peak_usage: 0, + frozen: false, + })), + } + } + + /// Get the current number of live bytes (convenience for active_bytes in stats) + pub fn active_bytes(&self) -> usize { + let s = self.lock(); + s.active_bytes + } +} + +impl Allocator for TrackingAllocator { + fn allocate(&self, size_bytes: usize) -> crate::error::Result { + let mut s = self.lock(); + if s.frozen { + return Err(crate::error::Error::AllocatorFrozen); + } + let ptr = s.inner.allocate(size_bytes)?; + s.total_allocations += 1; + s.total_bytes += size_bytes; + s.active_allocations += 1; + s.active_bytes += size_bytes; + if s.active_bytes > s.peak_usage { + s.peak_usage = s.active_bytes; + } + Ok(ptr) + } + + fn deallocate(&self, ptr: u64, size_bytes: usize) { + let mut s = self.lock(); + s.inner.deallocate(ptr, size_bytes); + s.active_allocations = s.active_allocations.saturating_sub(1); + s.active_bytes = s.active_bytes.saturating_sub(size_bytes); + } + + fn freeze(&self) -> bool { + let mut s = self.lock(); + s.frozen = true; + true + } + + fn unfreeze(&self) { + let mut s = self.lock(); + s.frozen = false; + } + + fn is_frozen(&self) -> bool { + let s = self.lock(); + s.frozen + } + + fn allocated_bytes(&self) -> usize { + let s = self.lock(); + s.active_bytes + } + + fn stats(&self) -> AllocationStats { + let s = self.lock(); + AllocationStats { + total_allocations: s.total_allocations, + total_bytes: s.total_bytes, + active_allocations: s.active_allocations, + is_frozen: s.frozen, + peak_usage: s.peak_usage, + } + } + + fn reset(&self) -> crate::error::Result<()> { + let mut s = self.lock(); + if s.active_allocations > 0 { + return Err(crate::error::Error::AllocatorBusy { + active_allocations: s.active_allocations, + }); + } + s.total_allocations = 0; + s.total_bytes = 0; + s.active_bytes = 0; + s.peak_usage = 0; + // frozen state is NOT reset — caller must explicitly unfreeze + Ok(()) + } +} + +#[cfg(any(feature = "cuda", feature = "wgpu"))] +/// RAII guard for GPU memory allocations. +/// +/// Ensures memory is deallocated when the guard is dropped, preventing leaks +/// on error paths. Call [`release`](AllocGuard::release) to take ownership of the +/// pointer (e.g., when transferring it into a `Tensor`). +pub struct AllocGuard<'a, A: Allocator> { + allocator: &'a A, + ptr: u64, + size: usize, + released: bool, +} + +#[cfg(any(feature = "cuda", feature = "wgpu"))] +impl<'a, A: Allocator> AllocGuard<'a, A> { + /// Allocate memory and wrap it in a guard. + pub fn new(allocator: &'a A, size_bytes: usize) -> crate::error::Result { + let ptr = allocator.allocate(size_bytes)?; + Ok(Self { + allocator, + ptr, + size: size_bytes, + released: false, + }) + } + + /// Get the raw pointer. + #[inline] + pub fn ptr(&self) -> u64 { + self.ptr + } + + /// Release ownership of the pointer, preventing deallocation on drop. + /// + /// Returns the raw pointer for use in tensor construction. + #[inline] + pub fn release(mut self) -> u64 { + self.released = true; + self.ptr + } +} + +#[cfg(any(feature = "cuda", feature = "wgpu"))] +impl Drop for AllocGuard<'_, A> { + fn drop(&mut self) { + if !self.released && self.ptr != 0 { + self.allocator.deallocate(self.ptr, self.size); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default_allocator_trait_bounds() { + fn assert_allocator() {} + assert_allocator::>(); + } + + /// Simple in-memory allocator for testing (uses Vec storage behind the scenes) + #[derive(Clone)] + struct TestAllocator; + + impl Allocator for TestAllocator { + fn allocate(&self, size_bytes: usize) -> crate::error::Result { + if size_bytes == 0 { + return Ok(0); + } + let layout = std::alloc::Layout::from_size_align(size_bytes, 8).unwrap(); + let ptr = unsafe { std::alloc::alloc(layout) }; + if ptr.is_null() { + return Err(crate::error::Error::OutOfMemory { size: size_bytes }); + } + Ok(ptr as u64) + } + + fn deallocate(&self, ptr: u64, size_bytes: usize) { + if ptr == 0 || size_bytes == 0 { + return; + } + let layout = std::alloc::Layout::from_size_align(size_bytes, 8).unwrap(); + unsafe { std::alloc::dealloc(ptr as *mut u8, layout) }; + } + } + + #[test] + fn test_tracking_allocator_basic_stats() { + let tracking = TrackingAllocator::new(TestAllocator); + + let stats = tracking.stats(); + assert_eq!(stats.total_allocations, 0); + assert_eq!(stats.total_bytes, 0); + assert_eq!(stats.active_allocations, 0); + assert_eq!(stats.peak_usage, 0); + assert!(!stats.is_frozen); + + let ptr1 = tracking.allocate(1024).unwrap(); + let stats = tracking.stats(); + assert_eq!(stats.total_allocations, 1); + assert_eq!(stats.total_bytes, 1024); + assert_eq!(stats.active_allocations, 1); + assert_eq!(stats.peak_usage, 1024); + + let ptr2 = tracking.allocate(2048).unwrap(); + let stats = tracking.stats(); + assert_eq!(stats.total_allocations, 2); + assert_eq!(stats.total_bytes, 3072); + assert_eq!(stats.active_allocations, 2); + assert_eq!(stats.peak_usage, 3072); + + tracking.deallocate(ptr1, 1024); + let stats = tracking.stats(); + assert_eq!(stats.active_allocations, 1); + assert_eq!(stats.peak_usage, 3072); // peak unchanged + + tracking.deallocate(ptr2, 2048); + let stats = tracking.stats(); + assert_eq!(stats.active_allocations, 0); + assert_eq!(stats.peak_usage, 3072); // peak unchanged + } + + #[test] + fn test_tracking_allocator_allocated_bytes() { + let tracking = TrackingAllocator::new(TestAllocator); + + assert_eq!(tracking.allocated_bytes(), 0); + + let ptr = tracking.allocate(512).unwrap(); + assert_eq!(tracking.allocated_bytes(), 512); + assert_eq!(tracking.active_bytes(), 512); + + tracking.deallocate(ptr, 512); + assert_eq!(tracking.allocated_bytes(), 0); + } + + #[test] + fn test_tracking_allocator_freeze() { + let tracking = TrackingAllocator::new(TestAllocator); + + assert!(!tracking.is_frozen()); + assert!(tracking.freeze()); + assert!(tracking.is_frozen()); + + // Allocation must fail while frozen + let result = tracking.allocate(128); + assert!(result.is_err()); + match result.unwrap_err() { + crate::error::Error::AllocatorFrozen => {} + other => panic!("expected AllocatorFrozen, got: {other}"), + } + + tracking.unfreeze(); + assert!(!tracking.is_frozen()); + + // Allocation succeeds after unfreeze + let ptr = tracking.allocate(128).unwrap(); + tracking.deallocate(ptr, 128); + } + + #[test] + fn test_tracking_allocator_reset_success() { + let tracking = TrackingAllocator::new(TestAllocator); + + let ptr = tracking.allocate(1024).unwrap(); + tracking.deallocate(ptr, 1024); + + // All deallocated, reset should succeed + tracking.reset().unwrap(); + + let stats = tracking.stats(); + assert_eq!(stats.total_allocations, 0); + assert_eq!(stats.total_bytes, 0); + assert_eq!(stats.active_allocations, 0); + assert_eq!(stats.peak_usage, 0); + } + + #[test] + fn test_tracking_allocator_reset_busy() { + let tracking = TrackingAllocator::new(TestAllocator); + + let ptr = tracking.allocate(1024).unwrap(); + + // Active allocation, reset must fail + let result = tracking.reset(); + assert!(result.is_err()); + match result.unwrap_err() { + crate::error::Error::AllocatorBusy { + active_allocations: 1, + } => {} + other => panic!("expected AllocatorBusy(1), got: {other}"), + } + + // Clean up + tracking.deallocate(ptr, 1024); + } + + #[test] + fn test_tracking_allocator_peak_across_cycles() { + let tracking = TrackingAllocator::new(TestAllocator); + + // Cycle 1: allocate 4096 bytes total + let p1 = tracking.allocate(2048).unwrap(); + let p2 = tracking.allocate(2048).unwrap(); + assert_eq!(tracking.stats().peak_usage, 4096); + tracking.deallocate(p1, 2048); + tracking.deallocate(p2, 2048); + + // Peak is still 4096 (cumulative until reset) + assert_eq!(tracking.stats().peak_usage, 4096); + + // Reset clears peak + tracking.reset().unwrap(); + assert_eq!(tracking.stats().peak_usage, 0); + + // Cycle 2: smaller allocation + let p3 = tracking.allocate(512).unwrap(); + assert_eq!(tracking.stats().peak_usage, 512); + tracking.deallocate(p3, 512); + } + + #[test] + fn test_tracking_allocator_clone_shares_state() { + let tracking = TrackingAllocator::new(TestAllocator); + let clone = tracking.clone(); + + let ptr = tracking.allocate(256).unwrap(); + // Clone sees the same stats (Arc-shared state) + assert_eq!(clone.stats().active_allocations, 1); + + clone.deallocate(ptr, 256); + assert_eq!(tracking.stats().active_allocations, 0); + } + + #[test] + fn test_tracking_allocator_freeze_preserved_on_reset() { + let tracking = TrackingAllocator::new(TestAllocator); + tracking.freeze(); + // Reset with no active allocations succeeds but freeze is preserved + tracking.reset().unwrap(); + assert!(tracking.is_frozen()); + } + + #[test] + fn test_allocation_stats_default() { + let stats = AllocationStats::default(); + assert_eq!(stats.total_allocations, 0); + assert_eq!(stats.total_bytes, 0); + assert_eq!(stats.active_allocations, 0); + assert!(!stats.is_frozen); + assert_eq!(stats.peak_usage, 0); + } +} diff --git a/src/runtime/common/graph.rs b/src/runtime/common/graph.rs new file mode 100644 index 00000000..ff4e6735 --- /dev/null +++ b/src/runtime/common/graph.rs @@ -0,0 +1,86 @@ +//! Graph capture and replay for compute backends +//! +//! Graph capture records a sequence of operations that can be replayed efficiently. +//! This is a runtime-level concept (CUDA Graphs, Vulkan command buffers, etc.) +//! that benefits any compute workload — not just ML. + +/// A captured computation sequence that can be replayed. +/// +/// # Replay semantics +/// +/// On capture-capable backends (CUDA), `launch()` replays the recorded +/// computation on the same fixed-address buffers. Callers update input +/// data in-place, then call `launch()` to re-execute with new values. +/// +/// On non-capture backends (CPU, WebGPU), `capture_graph` executes the +/// closure eagerly and returns `NoOpGraph`. `launch()` is a no-op — +/// the computation already ran. Callers wanting repeated execution on +/// these backends must call the operations directly (not via launch). +/// +/// Use `R::supports_graph_capture()` to check capability without +/// side effects, then branch: +/// +/// ```ignore +/// if R::supports_graph_capture() { +/// let (graph, _) = R::capture_graph(client, |c| hot_path(c))?; +/// loop { update_inputs(); graph.launch()?; read_outputs(); } +/// } else { +/// loop { update_inputs(); hot_path(client)?; } +/// } +/// ``` +pub trait Graph: Send + Sync + Clone { + /// Replay the recorded computation. + fn launch(&self) -> crate::error::Result<()>; + + /// Whether `launch()` actually replays computation. + /// + /// Returns `true` for backends with real capture (CUDA), `false` for no-op (CPU, WebGPU). + /// + /// # Invariant + /// + /// Must be consistent with `Runtime::supports_graph_capture()`: + /// if `supports_graph_capture()` returns true, then any `Graph` produced + /// by `capture_graph()` MUST return true from `is_replay_capable()`, + /// and vice versa. + fn is_replay_capable(&self) -> bool { + false + } +} + +/// No-op graph for backends without capture support (CPU, WebGPU). +/// +/// Operations execute eagerly during "capture" — `launch()` is a no-op. +#[derive(Clone, Debug, Default)] +pub struct NoOpGraph; + +impl Graph for NoOpGraph { + fn launch(&self) -> crate::error::Result<()> { + Ok(()) + } + // is_replay_capable() returns false (default) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_noop_graph_launch() { + let graph = NoOpGraph; + assert!(graph.launch().is_ok()); + assert!(!graph.is_replay_capable()); + } + + #[test] + fn test_noop_graph_clone() { + let graph = NoOpGraph; + let cloned = graph.clone(); + assert!(cloned.launch().is_ok()); + } + + #[test] + fn test_noop_graph_send_sync() { + fn assert_send_sync() {} + assert_send_sync::(); + } +} diff --git a/src/runtime/helpers.rs b/src/runtime/common/helpers.rs similarity index 95% rename from src/runtime/helpers.rs rename to src/runtime/common/helpers.rs index 53331b97..84a154f0 100644 --- a/src/runtime/helpers.rs +++ b/src/runtime/common/helpers.rs @@ -141,7 +141,7 @@ pub fn validate_eye(n: usize, m: Option) -> (usize, usize) { /// A new tensor that is guaranteed to be contiguous. If the input was already /// contiguous, this is zero-copy (just clones the Arc). Otherwise, data is copied. #[inline] -pub fn ensure_contiguous(tensor: &Tensor) -> Tensor { +pub fn ensure_contiguous>(tensor: &Tensor) -> Tensor { if tensor.is_contiguous() { tensor.clone() } else { @@ -171,7 +171,10 @@ pub fn ensure_contiguous(tensor: &Tensor) -> Tensor { /// /// Returns `Error::DTypeMismatch` if the tensors have different dtypes. #[inline] -pub fn validate_binary_dtypes(a: &Tensor, b: &Tensor) -> Result { +pub fn validate_binary_dtypes>( + a: &Tensor, + b: &Tensor, +) -> Result { if a.dtype() != b.dtype() { return Err(Error::DTypeMismatch { lhs: a.dtype(), @@ -202,7 +205,10 @@ pub fn validate_binary_dtypes(a: &Tensor, b: &Tensor) -> Resul /// /// Returns `Error::BroadcastError` if shapes cannot be broadcast together. #[inline] -pub fn compute_broadcast_shape(a: &Tensor, b: &Tensor) -> Result> { +pub fn compute_broadcast_shape>( + a: &Tensor, + b: &Tensor, +) -> Result> { broadcast_shape(a.shape(), b.shape()).ok_or_else(|| Error::BroadcastError { lhs: a.shape().to_vec(), rhs: b.shape().to_vec(), diff --git a/src/runtime/common/mod.rs b/src/runtime/common/mod.rs new file mode 100644 index 00000000..7accc854 --- /dev/null +++ b/src/runtime/common/mod.rs @@ -0,0 +1,38 @@ +pub(crate) mod helpers; +pub(crate) mod shape_ops; +pub(crate) mod statistics_common; + +mod allocator; +mod graph; + +#[cfg(feature = "sparse")] +pub(crate) mod sparse_utils; + +// Allocator re-exports +#[cfg(any(feature = "cuda", feature = "wgpu"))] +pub(crate) use allocator::AllocGuard; +pub(crate) use allocator::DefaultAllocator; +pub use allocator::{AllocationStats, Allocator, TrackingAllocator}; + +// Graph re-exports +pub use graph::{Graph, NoOpGraph}; + +// Helper re-exports +pub(crate) use helpers::{ + compute_broadcast_shape, ensure_contiguous, normalize_dim, validate_arange, + validate_binary_dtypes, validate_eye, +}; + +/// Compute contiguous (row-major) strides for a given shape. +#[cfg(any(feature = "cuda", feature = "wgpu"))] +#[inline] +pub(crate) fn compute_contiguous_strides(shape: &[usize]) -> Vec { + if shape.is_empty() { + return Vec::new(); + } + let mut strides = vec![1usize; shape.len()]; + for i in (0..shape.len().saturating_sub(1)).rev() { + strides[i] = strides[i + 1] * shape[i + 1]; + } + strides +} diff --git a/src/runtime/shape_ops.rs b/src/runtime/common/shape_ops.rs similarity index 98% rename from src/runtime/shape_ops.rs rename to src/runtime/common/shape_ops.rs index 1ec74e27..7fd943c3 100644 --- a/src/runtime/shape_ops.rs +++ b/src/runtime/common/shape_ops.rs @@ -82,7 +82,10 @@ pub struct CatParams { /// Validate inputs for cat operation and compute output parameters. /// /// This is the single source of truth for cat validation, used by all backends. -pub fn validate_cat(tensors: &[&Tensor], dim: isize) -> Result { +pub fn validate_cat>( + tensors: &[&Tensor], + dim: isize, +) -> Result { // Validate: need at least one tensor if tensors.is_empty() { return Err(Error::InvalidArgument { @@ -159,7 +162,10 @@ pub fn validate_cat(tensors: &[&Tensor], dim: isize) -> Result(tensors: &[&Tensor], dim: isize) -> Result { +pub fn validate_stack>( + tensors: &[&Tensor], + dim: isize, +) -> Result { // Validate: need at least one tensor if tensors.is_empty() { return Err(Error::InvalidArgument { diff --git a/src/runtime/sparse_utils.rs b/src/runtime/common/sparse_utils.rs similarity index 100% rename from src/runtime/sparse_utils.rs rename to src/runtime/common/sparse_utils.rs diff --git a/src/runtime/statistics_common.rs b/src/runtime/common/statistics_common.rs similarity index 99% rename from src/runtime/statistics_common.rs rename to src/runtime/common/statistics_common.rs index a5a8bd48..518b1917 100644 --- a/src/runtime/statistics_common.rs +++ b/src/runtime/common/statistics_common.rs @@ -245,7 +245,7 @@ pub fn skew_composite( correction: usize, ) -> Result> where - R: crate::runtime::Runtime, + R: crate::runtime::Runtime, C: crate::ops::BinaryOps + crate::ops::ReduceOps + crate::ops::StatisticalOps @@ -294,7 +294,7 @@ pub fn kurtosis_composite( correction: usize, ) -> Result> where - R: crate::runtime::Runtime, + R: crate::runtime::Runtime, C: crate::ops::BinaryOps + crate::ops::ReduceOps + crate::ops::StatisticalOps diff --git a/src/runtime/communicator/group.rs b/src/runtime/communicator/group.rs new file mode 100644 index 00000000..50252365 --- /dev/null +++ b/src/runtime/communicator/group.rs @@ -0,0 +1,172 @@ +//! Communicator groups for multi-dimensional parallelism. +//! +//! Splits a world communicator into sub-communicators for Tensor Parallelism +//! (TP), Pipeline Parallelism (PP), Data Parallelism (DP), and Expert +//! Parallelism (EP). Uses the `Communicator::split()` method to create +//! sub-groups. + +use std::collections::HashMap; +use std::sync::Arc; + +use super::Communicator; +use crate::error::{Error, Result}; + +/// Dimension of parallelism. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum ParallelDim { + /// Data parallelism: replicate model, shard data. + Data, + /// Tensor parallelism: shard weight matrices within a layer. + Tensor, + /// Pipeline parallelism: shard layers across stages. + Pipeline, + /// Expert parallelism: distribute MoE experts across devices. + Expert, +} + +/// A group of sub-communicators for multi-dimensional parallelism. +/// +/// Created by splitting a world communicator along TP, PP, and DP dimensions. +/// The layout is `[DP, PP, TP]` (TP innermost), meaning consecutive ranks +/// form a TP group. +/// +/// # Example +/// +/// ```ignore +/// // 8 GPUs: TP=2, PP=2, DP=2 +/// let group = CommunicatorGroup::new(world_comm, 2, 2, 2)?; +/// let tp_comm = group.tp(); // 2 ranks per group +/// let pp_comm = group.pp(); // 2 ranks per group +/// let dp_comm = group.dp(); // 2 ranks per group +/// ``` +pub struct CommunicatorGroup { + world: Arc, + dims: HashMap>, +} + +impl CommunicatorGroup { + /// Create communicator groups from a world communicator. + /// + /// Layout: `[DP, PP, TP]` (TP innermost). + /// - Ranks `[0..tp_size)` form the first TP group + /// - Ranks with the same `rank % tp_size` and same PP stage form a DP group + /// - etc. + /// + /// Requires `tp_size * pp_size * dp_size == world_size`. + pub fn new( + world: Arc, + tp_size: usize, + pp_size: usize, + dp_size: usize, + ) -> Result { + let ws = world.world_size(); + if tp_size * pp_size * dp_size != ws { + return Err(Error::Backend(format!( + "CommunicatorGroup: tp({tp_size}) * pp({pp_size}) * dp({dp_size}) = {} != world_size({ws})", + tp_size * pp_size * dp_size, + ))); + } + + let rank = world.rank(); + let mut dims = HashMap::new(); + + // Layout: [DP, PP, TP] — TP innermost + // rank = dp_idx * (pp_size * tp_size) + pp_idx * tp_size + tp_idx + let tp_idx = rank % tp_size; + let pp_idx = (rank / tp_size) % pp_size; + let dp_idx = rank / (tp_size * pp_size); + + // TP group: same dp_idx, same pp_idx → color = dp_idx * pp_size + pp_idx + if tp_size > 1 { + let tp_color = (dp_idx * pp_size + pp_idx) as u32; + if let Some(comm) = world.split(tp_color, tp_idx as u32)? { + dims.insert(ParallelDim::Tensor, Arc::from(comm)); + } + } + + // PP group: same dp_idx, same tp_idx → color = dp_idx * tp_size + tp_idx + // Use offset to avoid color collision with TP + if pp_size > 1 { + let color_offset = dp_size * pp_size; + let pp_color = (color_offset + dp_idx * tp_size + tp_idx) as u32; + if let Some(comm) = world.split(pp_color, pp_idx as u32)? { + dims.insert(ParallelDim::Pipeline, Arc::from(comm)); + } + } + + // DP group: same pp_idx, same tp_idx → color = pp_idx * tp_size + tp_idx + // Use offset to avoid collision with TP and PP + if dp_size > 1 { + let color_offset = dp_size * pp_size + dp_size * tp_size; + let dp_color = (color_offset + pp_idx * tp_size + tp_idx) as u32; + if let Some(comm) = world.split(dp_color, dp_idx as u32)? { + dims.insert(ParallelDim::Data, Arc::from(comm)); + } + } + + Ok(Self { world, dims }) + } + + /// The world communicator (all ranks). + pub fn world(&self) -> &Arc { + &self.world + } + + /// Tensor parallelism communicator. `None` if `tp_size == 1`. + pub fn tp(&self) -> Option<&Arc> { + self.dims.get(&ParallelDim::Tensor) + } + + /// Pipeline parallelism communicator. `None` if `pp_size == 1`. + pub fn pp(&self) -> Option<&Arc> { + self.dims.get(&ParallelDim::Pipeline) + } + + /// Data parallelism communicator. `None` if `dp_size == 1`. + pub fn dp(&self) -> Option<&Arc> { + self.dims.get(&ParallelDim::Data) + } + + /// Get communicator for an arbitrary parallelism dimension. + pub fn get(&self, dim: ParallelDim) -> Option<&Arc> { + self.dims.get(&dim) + } + + /// Add an expert parallelism communicator after construction. + /// + /// EP is orthogonal to the TP/PP/DP layout and may use a custom split. + pub fn set_expert(&mut self, comm: Arc) { + self.dims.insert(ParallelDim::Expert, comm); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::runtime::communicator::NoOpCommunicator; + + #[test] + fn test_parallel_dim_eq() { + assert_eq!(ParallelDim::Data, ParallelDim::Data); + assert_ne!(ParallelDim::Data, ParallelDim::Tensor); + } + + #[test] + fn test_single_rank_group() { + let world = Arc::new(NoOpCommunicator) as Arc; + let group = CommunicatorGroup::new(world, 1, 1, 1).unwrap(); + assert_eq!(group.world().world_size(), 1); + // All dims are size 1, so no sub-communicators created + assert!(group.tp().is_none()); + assert!(group.pp().is_none()); + assert!(group.dp().is_none()); + } + + #[test] + fn test_invalid_dimensions() { + let world = Arc::new(NoOpCommunicator) as Arc; + // 2*2*2=8 != 1 + let result = CommunicatorGroup::new(world, 2, 2, 2); + assert!(result.is_err()); + } +} diff --git a/src/runtime/communicator/hierarchical.rs b/src/runtime/communicator/hierarchical.rs new file mode 100644 index 00000000..d01aa651 --- /dev/null +++ b/src/runtime/communicator/hierarchical.rs @@ -0,0 +1,169 @@ +//! Hierarchical communicator: NCCL intra-node + nexar inter-node. +//! +//! Wraps [`nexar_nccl::HierarchicalComm`] and implements [`Communicator`] so +//! that numr's distributed patterns work transparently over hierarchical +//! GPU clusters. Uses NCCL for same-node GPU-GPU (NVLink/PCIe) and nexar +//! QUIC for cross-node communication. + +use super::nexar_compat::{to_nexar_dtype, to_nexar_op}; +use super::{Communicator, ReduceOp}; +use crate::dtype::DType; +use crate::error::{Error, Result}; + +/// Maps a nexar-nccl error to a numr error. +fn map_err(e: nexar_nccl::NcclCommError) -> Error { + Error::Backend(format!("hierarchical communicator: {e}")) +} + +/// Maps a nexar error to a numr error. +fn map_nexar_err(e: nexar::NexarError) -> Error { + Error::Backend(format!("hierarchical communicator (nexar): {e}")) +} + +/// Hierarchical communicator backed by [`nexar_nccl::HierarchicalComm`]. +/// +/// Combines NCCL for intra-node GPU-GPU with nexar for inter-node +/// communication. This is the standard 2D decomposition used by +/// Megatron-LM and DeepSpeed. +/// +/// # Construction +/// +/// Use [`nexar_nccl::form_hierarchical_comm`] to create the underlying +/// `HierarchicalComm`, then wrap it: +/// +/// ```ignore +/// let hcomm = unsafe { form_hierarchical_comm(nexar_client, stream).await? }; +/// let rt = tokio::runtime::Runtime::new()?; +/// let comm = HierarchicalCommunicator::new(hcomm, rt); +/// ``` +pub struct HierarchicalCommunicator { + comm: nexar_nccl::HierarchicalComm, + rt: tokio::runtime::Runtime, +} + +impl HierarchicalCommunicator { + /// Wrap an existing `HierarchicalComm` with a tokio runtime for async→sync bridging. + pub fn new(comm: nexar_nccl::HierarchicalComm, rt: tokio::runtime::Runtime) -> Self { + Self { comm, rt } + } + + /// Reference to the underlying hierarchical communicator. + pub fn inner(&self) -> &nexar_nccl::HierarchicalComm { + &self.comm + } +} + +impl Communicator for HierarchicalCommunicator { + fn world_size(&self) -> usize { + self.comm.world_size() as usize + } + + fn rank(&self) -> usize { + self.comm.rank() as usize + } + + unsafe fn all_reduce(&self, ptr: u64, count: usize, dtype: DType, op: ReduceOp) -> Result<()> { + let nd = to_nexar_dtype(dtype)?; + let no = to_nexar_op(op); + self.rt + .block_on(unsafe { self.comm.allreduce(ptr, count, nd, no) }) + .map_err(map_err) + } + + unsafe fn broadcast(&self, ptr: u64, count: usize, dtype: DType, root: usize) -> Result<()> { + let nd = to_nexar_dtype(dtype)?; + self.rt + .block_on(unsafe { self.comm.broadcast(ptr, count, nd, root as u32) }) + .map_err(map_err) + } + + unsafe fn all_gather( + &self, + send_ptr: u64, + recv_ptr: u64, + count: usize, + dtype: DType, + ) -> Result<()> { + let nd = to_nexar_dtype(dtype)?; + self.rt + .block_on(unsafe { self.comm.allgather(send_ptr, recv_ptr, count, nd) }) + .map_err(map_err) + } + + unsafe fn reduce_scatter( + &self, + send_ptr: u64, + recv_ptr: u64, + count: usize, + dtype: DType, + op: ReduceOp, + ) -> Result<()> { + // HierarchicalComm doesn't expose reduce_scatter directly. + // Compose: allreduce the full buffer, then each rank copies its chunk. + // + // allreduce is in-place on send_ptr, so we need send_ptr to hold the + // full data (count * world_size elements). After allreduce, each rank + // copies its slice (rank * count .. (rank+1) * count) into recv_ptr. + let nd = to_nexar_dtype(dtype)?; + let no = to_nexar_op(op); + let ws = self.comm.world_size() as usize; + let total = count * ws; + + // Step 1: allreduce the full buffer in-place + self.rt + .block_on(unsafe { self.comm.allreduce(send_ptr, total, nd, no) }) + .map_err(map_err)?; + + // Step 2: copy this rank's chunk to recv_ptr + let elem_size = dtype.size_in_bytes(); + let offset = self.comm.rank() as usize * count * elem_size; + let bytes = count * elem_size; + unsafe { + std::ptr::copy_nonoverlapping( + (send_ptr as *const u8).add(offset), + recv_ptr as *mut u8, + bytes, + ); + } + Ok(()) + } + + unsafe fn send( + &self, + ptr: u64, + count: usize, + dtype: DType, + dest: usize, + tag: u32, + ) -> Result<()> { + // Route through the nexar client for point-to-point. + let nd = to_nexar_dtype(dtype)?; + let size = count * nd.size_in_bytes(); + self.rt + .block_on(unsafe { self.comm.nexar().send(ptr, size, dest as u32, tag) }) + .map_err(map_nexar_err) + } + + unsafe fn recv( + &self, + ptr: u64, + count: usize, + dtype: DType, + src: usize, + tag: u32, + ) -> Result<()> { + let nd = to_nexar_dtype(dtype)?; + let size = count * nd.size_in_bytes(); + self.rt + .block_on(unsafe { self.comm.nexar().recv(ptr, size, src as u32, tag) }) + .map_err(map_nexar_err) + } + + fn sync(&self) -> Result<()> { + self.comm.synchronize().map_err(map_err) + } + + fn barrier(&self) -> Result<()> { + self.rt.block_on(self.comm.barrier()).map_err(map_err) + } +} diff --git a/src/runtime/communicator/mod.rs b/src/runtime/communicator/mod.rs new file mode 100644 index 00000000..edfbdb89 --- /dev/null +++ b/src/runtime/communicator/mod.rs @@ -0,0 +1,19 @@ +//! Multi-device collective communication. + +mod group; +#[cfg(feature = "distributed-gpu")] +mod hierarchical; +#[cfg(feature = "distributed")] +mod nexar; +#[cfg(feature = "distributed")] +mod nexar_compat; +mod noop; +mod traits; + +#[cfg(feature = "distributed")] +pub use self::nexar::NexarNetCommunicator; +pub use group::{CommunicatorGroup, ParallelDim}; +#[cfg(feature = "distributed-gpu")] +pub use hierarchical::HierarchicalCommunicator; +pub use noop::NoOpCommunicator; +pub use traits::{Communicator, ReduceOp, StreamSyncOps}; diff --git a/src/runtime/communicator/nexar.rs b/src/runtime/communicator/nexar.rs new file mode 100644 index 00000000..e744a070 --- /dev/null +++ b/src/runtime/communicator/nexar.rs @@ -0,0 +1,202 @@ +//! nexar-backed distributed communicator for inter-node collective operations. +//! +//! Wraps [`nexar::SyncClient`] and implements [`Communicator`] so that numr's +//! existing distributed patterns (gradient sync, tensor parallelism) work +//! transparently over QUIC transport. + +use super::nexar_compat::{to_nexar_dtype, to_nexar_op}; +use super::{Communicator, ReduceOp}; +use crate::dtype::DType; +use crate::error::{Error, Result}; + +/// Maps a nexar error to a numr error. +fn map_err(e: nexar::NexarError) -> Error { + Error::Backend(format!("nexar: {e}")) +} + +/// Distributed communicator backed by [`nexar::SyncClient`]. +/// +/// Provides inter-node collective operations (allreduce, broadcast, etc.) +/// over QUIC transport. For intra-node GPU-GPU communication, use +/// `NcclCommunicator` instead — NVLink/PCIe is orders of magnitude faster +/// than any network. +/// +/// # Usage +/// +/// ```ignore +/// use nexar::{CpuAdapter, SyncClient}; +/// use numr::runtime::{NexarNetCommunicator, Communicator}; +/// use std::sync::Arc; +/// +/// let adapter = Arc::new(CpuAdapter::new()); +/// let clients = SyncClient::bootstrap_local(4, adapter).unwrap(); +/// let comms: Vec = clients +/// .into_iter() +/// .map(NexarNetCommunicator::new) +/// .collect(); +/// ``` +pub struct NexarNetCommunicator { + client: nexar::SyncClient, +} + +impl NexarNetCommunicator { + /// Wrap an existing nexar `SyncClient`. + pub fn new(client: nexar::SyncClient) -> Self { + Self { client } + } +} + +impl Communicator for NexarNetCommunicator { + fn world_size(&self) -> usize { + self.client.world_size() as usize + } + + fn rank(&self) -> usize { + self.client.rank() as usize + } + + unsafe fn all_reduce(&self, ptr: u64, count: usize, dtype: DType, op: ReduceOp) -> Result<()> { + let nd = to_nexar_dtype(dtype)?; + let no = to_nexar_op(op); + unsafe { self.client.all_reduce(ptr, count, nd, no).map_err(map_err) } + } + + unsafe fn broadcast(&self, ptr: u64, count: usize, dtype: DType, root: usize) -> Result<()> { + let nd = to_nexar_dtype(dtype)?; + unsafe { + self.client + .broadcast(ptr, count, nd, root as u32) + .map_err(map_err) + } + } + + unsafe fn all_gather( + &self, + send_ptr: u64, + recv_ptr: u64, + count: usize, + dtype: DType, + ) -> Result<()> { + let nd = to_nexar_dtype(dtype)?; + unsafe { + self.client + .all_gather(send_ptr, recv_ptr, count, nd) + .map_err(map_err) + } + } + + unsafe fn reduce_scatter( + &self, + send_ptr: u64, + recv_ptr: u64, + count: usize, + dtype: DType, + op: ReduceOp, + ) -> Result<()> { + let nd = to_nexar_dtype(dtype)?; + let no = to_nexar_op(op); + unsafe { + self.client + .reduce_scatter(send_ptr, recv_ptr, count, nd, no) + .map_err(map_err) + } + } + + unsafe fn send( + &self, + ptr: u64, + count: usize, + dtype: DType, + dest: usize, + tag: u32, + ) -> Result<()> { + let nd = to_nexar_dtype(dtype)?; + let size = count * nd.size_in_bytes(); + unsafe { + self.client + .send(ptr, size, dest as u32, tag) + .map_err(map_err) + } + } + + unsafe fn recv( + &self, + ptr: u64, + count: usize, + dtype: DType, + src: usize, + tag: u32, + ) -> Result<()> { + let nd = to_nexar_dtype(dtype)?; + let size = count * nd.size_in_bytes(); + unsafe { + self.client + .recv(ptr, size, src as u32, tag) + .map_err(map_err) + } + } + + fn sync(&self) -> Result<()> { + // nexar operations are synchronous (block_on), so sync is a no-op. + Ok(()) + } + + fn barrier(&self) -> Result<()> { + self.client.barrier().map_err(map_err) + } + + fn split(&self, color: u32, key: u32) -> Result>> { + let sub = self.client.split(color, key).map_err(map_err)?; + Ok(Some(Box::new(NexarNetCommunicator::new(sub)))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_nexar_communicator_metadata() { + let adapter = std::sync::Arc::new(nexar::CpuAdapter::new()); + let clients = nexar::SyncClient::bootstrap_local(2, adapter).unwrap(); + let comms: Vec = + clients.into_iter().map(NexarNetCommunicator::new).collect(); + + assert_eq!(comms[0].world_size(), 2); + assert_eq!(comms[0].rank(), 0); + assert_eq!(comms[1].rank(), 1); + } + + #[test] + fn test_nexar_allreduce_f32() { + let adapter = std::sync::Arc::new(nexar::CpuAdapter::new()); + let clients = nexar::SyncClient::bootstrap_local(2, adapter).unwrap(); + let comms: Vec = + clients.into_iter().map(NexarNetCommunicator::new).collect(); + + // Each rank has its own data; run allreduce concurrently. + std::thread::scope(|s| { + let handles: Vec<_> = comms + .iter() + .enumerate() + .map(|(i, comm)| { + s.spawn(move || { + let val = (i + 1) as f32; + let mut data = vec![val; 4]; + let ptr = data.as_mut_ptr() as u64; + unsafe { + comm.all_reduce(ptr, 4, DType::F32, ReduceOp::Sum).unwrap(); + } + data + }) + }) + .collect(); + + for h in handles { + let data = h.join().unwrap(); + // 1.0 + 2.0 = 3.0 + assert_eq!(data, vec![3.0f32; 4]); + } + }); + } +} diff --git a/src/runtime/communicator/nexar_compat.rs b/src/runtime/communicator/nexar_compat.rs new file mode 100644 index 00000000..2e31932e --- /dev/null +++ b/src/runtime/communicator/nexar_compat.rs @@ -0,0 +1,70 @@ +//! Shared conversion helpers between numr and nexar types. + +use super::ReduceOp; +use crate::dtype::DType; +use crate::error::{Error, Result}; + +/// Maps a numr `DType` to a nexar `DataType`. +/// +/// Returns `Err` for types nexar doesn't support (Complex, Bool, FP8, I16, U16). +pub fn to_nexar_dtype(dtype: DType) -> Result { + match dtype { + DType::F32 => Ok(nexar::DataType::F32), + DType::F64 => Ok(nexar::DataType::F64), + DType::F16 => Ok(nexar::DataType::F16), + DType::BF16 => Ok(nexar::DataType::BF16), + DType::I8 => Ok(nexar::DataType::I8), + DType::I32 => Ok(nexar::DataType::I32), + DType::I64 => Ok(nexar::DataType::I64), + DType::U8 => Ok(nexar::DataType::U8), + DType::U32 => Ok(nexar::DataType::U32), + DType::U64 => Ok(nexar::DataType::U64), + _ => Err(Error::Backend(format!( + "nexar: unsupported dtype {dtype:?} for collective operation" + ))), + } +} + +/// Maps a numr `ReduceOp` to a nexar `ReduceOp`. +pub fn to_nexar_op(op: ReduceOp) -> nexar::ReduceOp { + match op { + ReduceOp::Sum => nexar::ReduceOp::Sum, + ReduceOp::Prod => nexar::ReduceOp::Prod, + ReduceOp::Min => nexar::ReduceOp::Min, + ReduceOp::Max => nexar::ReduceOp::Max, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_dtype_mapping() { + assert_eq!(to_nexar_dtype(DType::F32).unwrap(), nexar::DataType::F32); + assert_eq!(to_nexar_dtype(DType::F64).unwrap(), nexar::DataType::F64); + assert_eq!(to_nexar_dtype(DType::F16).unwrap(), nexar::DataType::F16); + assert_eq!(to_nexar_dtype(DType::BF16).unwrap(), nexar::DataType::BF16); + assert_eq!(to_nexar_dtype(DType::I8).unwrap(), nexar::DataType::I8); + assert_eq!(to_nexar_dtype(DType::I32).unwrap(), nexar::DataType::I32); + assert_eq!(to_nexar_dtype(DType::I64).unwrap(), nexar::DataType::I64); + assert_eq!(to_nexar_dtype(DType::U8).unwrap(), nexar::DataType::U8); + assert_eq!(to_nexar_dtype(DType::U32).unwrap(), nexar::DataType::U32); + assert_eq!(to_nexar_dtype(DType::U64).unwrap(), nexar::DataType::U64); + } + + #[test] + fn test_dtype_mapping_unsupported() { + assert!(to_nexar_dtype(DType::Bool).is_err()); + assert!(to_nexar_dtype(DType::Complex64).is_err()); + assert!(to_nexar_dtype(DType::Complex128).is_err()); + } + + #[test] + fn test_reduce_op_mapping() { + assert_eq!(to_nexar_op(ReduceOp::Sum), nexar::ReduceOp::Sum); + assert_eq!(to_nexar_op(ReduceOp::Prod), nexar::ReduceOp::Prod); + assert_eq!(to_nexar_op(ReduceOp::Min), nexar::ReduceOp::Min); + assert_eq!(to_nexar_op(ReduceOp::Max), nexar::ReduceOp::Max); + } +} diff --git a/src/runtime/communicator/noop.rs b/src/runtime/communicator/noop.rs new file mode 100644 index 00000000..a5851541 --- /dev/null +++ b/src/runtime/communicator/noop.rs @@ -0,0 +1,231 @@ +//! No-op communicator for single-device operation. + +use crate::dtype::DType; +use crate::error::Result; + +use super::{Communicator, ReduceOp}; + +/// No-op communicator for single-device operation (world_size=1). +/// +/// - In-place collectives (`all_reduce`, `broadcast`): true no-ops +/// - Separate-buffer collectives (`all_gather`, `reduce_scatter`): memcpy send→recv +/// - Point-to-point (`send`, `recv`): no-ops (nothing to communicate) +/// - `sync`, `barrier`: no-ops +#[derive(Clone, Debug, Default)] +pub struct NoOpCommunicator; + +impl Communicator for NoOpCommunicator { + fn world_size(&self) -> usize { + 1 + } + + fn rank(&self) -> usize { + 0 + } + + unsafe fn all_reduce( + &self, + _ptr: u64, + _count: usize, + _dtype: DType, + _op: ReduceOp, + ) -> Result<()> { + // Single device: buffer already contains the "reduced" result + Ok(()) + } + + unsafe fn broadcast( + &self, + _ptr: u64, + _count: usize, + _dtype: DType, + _root: usize, + ) -> Result<()> { + // Single device: buffer already has root's data (we are root) + Ok(()) + } + + unsafe fn all_gather( + &self, + send_ptr: u64, + recv_ptr: u64, + count: usize, + dtype: DType, + ) -> Result<()> { + // Single device: copy send → recv (output = input for world_size=1) + if send_ptr != recv_ptr { + let bytes = count * dtype.size_in_bytes(); + unsafe { + std::ptr::copy_nonoverlapping(send_ptr as *const u8, recv_ptr as *mut u8, bytes); + } + } + Ok(()) + } + + unsafe fn reduce_scatter( + &self, + send_ptr: u64, + recv_ptr: u64, + count: usize, + dtype: DType, + _op: ReduceOp, + ) -> Result<()> { + // Single device: the "reduced" result is just the input, + // and the single rank gets the full slice + if send_ptr != recv_ptr { + let bytes = count * dtype.size_in_bytes(); + unsafe { + std::ptr::copy_nonoverlapping(send_ptr as *const u8, recv_ptr as *mut u8, bytes); + } + } + Ok(()) + } + + unsafe fn send( + &self, + _ptr: u64, + _count: usize, + _dtype: DType, + _dest: usize, + _tag: u32, + ) -> Result<()> { + // Single device: no-op + Ok(()) + } + + unsafe fn recv( + &self, + _ptr: u64, + _count: usize, + _dtype: DType, + _src: usize, + _tag: u32, + ) -> Result<()> { + // Single device: no-op + Ok(()) + } + + fn sync(&self) -> Result<()> { + Ok(()) + } + + fn barrier(&self) -> Result<()> { + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_noop_metadata() { + let comm = NoOpCommunicator; + assert_eq!(comm.world_size(), 1); + assert_eq!(comm.rank(), 0); + } + + #[test] + fn test_noop_all_reduce() { + let comm = NoOpCommunicator; + let mut data = [1.0f32, 2.0, 3.0, 4.0]; + unsafe { + comm.all_reduce(data.as_mut_ptr() as u64, 4, DType::F32, ReduceOp::Sum) + .unwrap(); + } + // Data unchanged (single device) + assert_eq!(data, [1.0, 2.0, 3.0, 4.0]); + } + + #[test] + fn test_noop_broadcast() { + let comm = NoOpCommunicator; + let mut data = [1.0f32, 2.0]; + unsafe { + comm.broadcast(data.as_mut_ptr() as u64, 2, DType::F32, 0) + .unwrap(); + } + assert_eq!(data, [1.0, 2.0]); + } + + #[test] + fn test_noop_all_gather() { + let comm = NoOpCommunicator; + let send = [1.0f32, 2.0, 3.0]; + let mut recv = [0.0f32; 3]; + unsafe { + comm.all_gather( + send.as_ptr() as u64, + recv.as_mut_ptr() as u64, + 3, + DType::F32, + ) + .unwrap(); + } + assert_eq!(recv, [1.0, 2.0, 3.0]); + } + + #[test] + fn test_noop_reduce_scatter() { + let comm = NoOpCommunicator; + let send = [10.0f32, 20.0]; + let mut recv = [0.0f32; 2]; + unsafe { + comm.reduce_scatter( + send.as_ptr() as u64, + recv.as_mut_ptr() as u64, + 2, + DType::F32, + ReduceOp::Sum, + ) + .unwrap(); + } + assert_eq!(recv, [10.0, 20.0]); + } + + #[test] + fn test_noop_send_recv() { + let comm = NoOpCommunicator; + let data = [1.0f32]; + unsafe { + comm.send(data.as_ptr() as u64, 1, DType::F32, 0, 0) + .unwrap(); + comm.recv(data.as_ptr() as u64, 1, DType::F32, 0, 0) + .unwrap(); + } + } + + #[test] + fn test_noop_sync_barrier() { + let comm = NoOpCommunicator; + comm.sync().unwrap(); + comm.barrier().unwrap(); + } + + #[test] + fn test_noop_send_sync() { + fn assert_send_sync() {} + assert_send_sync::(); + } + + #[test] + fn test_noop_all_gather_same_ptr() { + let comm = NoOpCommunicator; + let mut data = [1.0f32, 2.0]; + let ptr = data.as_mut_ptr() as u64; + unsafe { + comm.all_gather(ptr, ptr, 2, DType::F32).unwrap(); + } + assert_eq!(data, [1.0, 2.0]); + } + + #[test] + fn test_reduce_op_variants() { + let ops = [ReduceOp::Sum, ReduceOp::Prod, ReduceOp::Min, ReduceOp::Max]; + for i in 0..ops.len() { + for j in (i + 1)..ops.len() { + assert_ne!(ops[i], ops[j]); + } + } + } +} diff --git a/src/runtime/communicator/traits.rs b/src/runtime/communicator/traits.rs new file mode 100644 index 00000000..141fe2a9 --- /dev/null +++ b/src/runtime/communicator/traits.rs @@ -0,0 +1,209 @@ +//! Communicator trait and reduction operations. + +use crate::dtype::DType; +use crate::error::Result; + +/// Reduction operation for collective communication +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum ReduceOp { + /// Element-wise sum across ranks + Sum, + /// Element-wise product across ranks + Prod, + /// Element-wise minimum across ranks + Min, + /// Element-wise maximum across ranks + Max, +} + +/// Multi-device collective communication +/// +/// Operates on device pointers (`u64`) + element count + `DType`, matching +/// NCCL's and MPI's native calling conventions. The `u64` pointer is the +/// same abstraction as `Runtime::allocate()` / `Runtime::deallocate()`. +/// +/// `DType` provides unambiguous type information so backends can dispatch +/// to the correct reduction unit (e.g., f16 vs bf16 vs i16 are all 2 bytes +/// but require different hardware reduction units). +/// +/// # Safety +/// +/// All pointer-based methods are `unsafe fn` because passing an invalid `u64` +/// (dangling, wrong device, wrong provenance) causes undefined behavior. +/// Callers MUST ensure: +/// - **NCCL**: pointers are GPU device pointers from the same CUDA context +/// - **MPI**: pointers are valid host pointers +/// - Pointer provenance matches the communicator backend +/// - Buffers remain allocated until `sync()` or `barrier()` +/// +/// Higher-level wrappers (boostr's distributed patterns) accept `Tensor` +/// and extract pointers internally, providing a safe public API. +/// +/// # Drop contract +/// +/// Dropping with pending non-blocking operations attempts best-effort sync +/// with a bounded timeout. On failure the destructor **logs** the error +/// (via `tracing::error!`) and proceeds — it **never panics**. +/// +/// # Thread safety +/// +/// `Send + Sync` so it can be stored in `Arc`. If multiple threads call +/// `send()`/`recv()` concurrently, submission order is implementation-defined. +/// For deterministic ordering, serialize submissions externally. +pub trait Communicator: Send + Sync { + /// Number of participants + fn world_size(&self) -> usize; + + /// This participant's rank (0-indexed) + fn rank(&self) -> usize; + + /// AllReduce in-place: reduce across all ranks, result on all ranks. + /// + /// Completion semantics are implementation-defined. On NCCL the operation + /// is non-blocking (stream-ordered). **Portable code must call `sync()` + /// before reading the result buffer.** + /// + /// # Safety + /// + /// `ptr` must be a valid device pointer with at least `count` elements of `dtype`. + unsafe fn all_reduce(&self, ptr: u64, count: usize, dtype: DType, op: ReduceOp) -> Result<()>; + + /// Broadcast from root rank to all other ranks. + /// + /// # Safety + /// + /// `ptr` must be a valid device pointer with at least `count` elements of `dtype`. + unsafe fn broadcast(&self, ptr: u64, count: usize, dtype: DType, root: usize) -> Result<()>; + + /// AllGather: each rank contributes `count` elements, result is + /// `count * world_size` elements on all ranks. + /// + /// # Safety + /// + /// - `send_ptr` must point to at least `count` elements + /// - `recv_ptr` must point to at least `count * world_size` elements + unsafe fn all_gather( + &self, + send_ptr: u64, + recv_ptr: u64, + count: usize, + dtype: DType, + ) -> Result<()>; + + /// ReduceScatter: reduce + scatter. Each rank gets a different slice + /// of the reduced result. + /// + /// # Safety + /// + /// - `send_ptr` must point to at least `count * world_size` elements + /// - `recv_ptr` must point to at least `count` elements + unsafe fn reduce_scatter( + &self, + send_ptr: u64, + recv_ptr: u64, + count: usize, + dtype: DType, + op: ReduceOp, + ) -> Result<()>; + + /// Point-to-point send to a specific rank (non-blocking). + /// + /// The send buffer must NOT be modified or deallocated until `sync()`. + /// + /// `tag` is used for message matching on MPI. On NCCL, `tag` is accepted + /// but ignored (stream-ordered submission determines matching). + /// + /// # Safety + /// + /// `ptr` must be a valid device pointer with at least `count` elements of `dtype`. + unsafe fn send( + &self, + ptr: u64, + count: usize, + dtype: DType, + dest: usize, + tag: u32, + ) -> Result<()>; + + /// Point-to-point receive from a specific rank (non-blocking). + /// + /// The recv buffer contains valid data only after `sync()` or `barrier()`. + /// + /// # Safety + /// + /// `ptr` must be a valid device pointer with at least `count` elements of `dtype`. + unsafe fn recv(&self, ptr: u64, count: usize, dtype: DType, src: usize, tag: u32) + -> Result<()>; + + /// Wait for all pending operations to complete. + /// + /// After sync returns, all output/recv buffers contain valid data and + /// all send/input buffers are safe to reuse. + fn sync(&self) -> Result<()>; + + /// Barrier: block until all ranks reach this point. + /// + /// Implies `sync()` — all pending operations complete before the barrier. + fn barrier(&self) -> Result<()>; + + /// Split this communicator into sub-communicators by color and key. + /// + /// All ranks must call `split()` collectively. Ranks with the same `color` + /// end up in the same sub-communicator, ordered by `key`. + /// + /// Returns `None` for backends that don't support splitting (e.g., NCCL + /// without `ncclCommSplit`, or the no-op communicator). + fn split(&self, _color: u32, _key: u32) -> Result>> { + Ok(None) + } + + /// Downcast to `StreamSyncOps` if this communicator supports CUDA + /// stream/event synchronization for compute-communication overlap. + /// + /// Returns `None` by default. Backends with separate communication + /// streams (e.g., NCCL) override this to return `Some(self)`. + fn as_stream_sync(&self) -> Option<&dyn StreamSyncOps> { + None + } +} + +/// Stream/event synchronization for compute-communication overlap. +/// +/// Enables launching allreduce on a separate communication stream while +/// backward computation continues on the compute stream. Events provide +/// GPU-side synchronization without blocking the CPU. +/// +/// # Event Lifecycle +/// +/// 1. Create event with [`create_event`] +/// 2. Record on compute stream (gradient ready) with [`record_on_stream`] +/// 3. Make comm stream wait with [`comm_stream_wait_event`] +/// 4. Launch allreduce (runs on comm stream) +/// 5. Record completion on comm stream with [`record_on_comm_stream`] +/// 6. Make compute stream wait with [`stream_wait_event`] +/// 7. Destroy event with [`destroy_event`] +pub trait StreamSyncOps { + /// Create a CUDA event for synchronization. + /// + /// Returns an opaque event handle. Uses `CU_EVENT_DISABLE_TIMING` for + /// minimal overhead (only ordering semantics needed, not timing). + fn create_event(&self) -> Result; + + /// Destroy a previously created event. + fn destroy_event(&self, event: u64) -> Result<()>; + + /// Record an event on the communicator's internal stream. + fn record_on_comm_stream(&self, event: u64) -> Result<()>; + + /// Record an event on an external stream (e.g., the compute stream). + fn record_on_stream(&self, event: u64, stream_handle: u64) -> Result<()>; + + /// Make the communicator's internal stream wait for an event. + fn comm_stream_wait_event(&self, event: u64) -> Result<()>; + + /// Make an external stream wait for an event. + fn stream_wait_event(&self, stream_handle: u64, event: u64) -> Result<()>; + + /// Synchronize the communicator's internal stream (CPU-blocking). + fn sync_comm_stream(&self) -> Result<()>; +} diff --git a/src/runtime/cpu/fft/mod.rs b/src/runtime/cpu/fft/mod.rs index 6321dd60..cb243c5c 100644 --- a/src/runtime/cpu/fft/mod.rs +++ b/src/runtime/cpu/fft/mod.rs @@ -211,8 +211,8 @@ impl CpuClient { let batch_size = batch_size.max(1); let min_len = self.chunk_size_hint(); - let input_ptr = input.storage().ptr(); - let output_ptr = output.storage().ptr(); + let input_ptr = input.ptr(); + let output_ptr = output.ptr(); match dtype { DType::Complex64 => { diff --git a/src/runtime/cpu/fft/real.rs b/src/runtime/cpu/fft/real.rs index 854423e0..fd788ca0 100644 --- a/src/runtime/cpu/fft/real.rs +++ b/src/runtime/cpu/fft/real.rs @@ -49,8 +49,8 @@ pub(super) fn rfft_impl( #[cfg(feature = "rayon")] let min_len = client.rayon_min_len(); - let input_ptr = input_contig.storage().ptr(); - let output_ptr = output.storage().ptr(); + let input_ptr = input_contig.ptr(); + let output_ptr = output.ptr(); match dtype { DType::F32 => { @@ -199,8 +199,8 @@ pub(super) fn irfft_impl( #[cfg(feature = "rayon")] let min_len = client.rayon_min_len(); - let input_ptr = input_contig.storage().ptr(); - let output_ptr = output.storage().ptr(); + let input_ptr = input_contig.ptr(); + let output_ptr = output.ptr(); match dtype { DType::Complex64 => { diff --git a/src/runtime/cpu/fft/shift.rs b/src/runtime/cpu/fft/shift.rs index b6cdb29f..3877bd11 100644 --- a/src/runtime/cpu/fft/shift.rs +++ b/src/runtime/cpu/fft/shift.rs @@ -47,8 +47,8 @@ fn shift_impl( #[cfg(feature = "rayon")] let min_len = client.rayon_min_len(); - let input_ptr = input_contig.storage().ptr(); - let output_ptr = output.storage().ptr(); + let input_ptr = input_contig.ptr(); + let output_ptr = output.ptr(); let op_name = if inverse { "ifftshift" } else { "fftshift" }; @@ -158,7 +158,7 @@ pub(super) fn fftfreq_impl( let output = Tensor::::empty(&[n], dtype, device); let scale = 1.0 / (d * n as f64); - let output_ptr = output.storage().ptr(); + let output_ptr = output.ptr(); match dtype { DType::F32 => { @@ -216,7 +216,7 @@ pub(super) fn rfftfreq_impl( let output_len = n / 2 + 1; let output = Tensor::::empty(&[output_len], dtype, device); let scale = 1.0 / (d * n as f64); - let output_ptr = output.storage().ptr(); + let output_ptr = output.ptr(); match dtype { DType::F32 => { diff --git a/src/runtime/cpu/helpers/activation.rs b/src/runtime/cpu/helpers/activation.rs index 8544da02..fee8590e 100644 --- a/src/runtime/cpu/helpers/activation.rs +++ b/src/runtime/cpu/helpers/activation.rs @@ -37,8 +37,8 @@ pub fn activation_op_impl( let out = Tensor::::empty(a.shape(), dtype, &client.device); let len = a.numel(); - let a_ptr = a_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a_contig.ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -70,6 +70,69 @@ pub fn activation_op_impl( Ok(out) } +/// Fused activation-mul operation kind +#[derive(Copy, Clone)] +#[allow(clippy::enum_variant_names)] +pub enum FusedActivationMulOp { + SiluMul, + GeluMul, + ReluMul, + SigmoidMul, +} + +/// Helper for fused activation-mul operations: activation(a) * b +pub fn fused_activation_mul_impl( + client: &CpuClient, + a: &Tensor, + b: &Tensor, + op: FusedActivationMulOp, + op_name: &'static str, +) -> Result> { + let dtype = a.dtype(); + if b.dtype() != dtype { + return Err(crate::error::Error::DTypeMismatch { + lhs: dtype, + rhs: b.dtype(), + }); + } + if a.shape() != b.shape() { + return Err(crate::error::Error::ShapeMismatch { + expected: a.shape().to_vec(), + got: b.shape().to_vec(), + }); + } + + let a_contig = ensure_contiguous(a); + let b_contig = ensure_contiguous(b); + let out = Tensor::::empty(a.shape(), dtype, &client.device); + + let len = a.numel(); + let a_ptr = a_contig.ptr(); + let b_ptr = b_contig.ptr(); + let out_ptr = out.ptr(); + + dispatch_dtype!(dtype, T => { + unsafe { + match op { + FusedActivationMulOp::SiluMul => kernels::silu_mul_kernel::( + a_ptr as *const T, b_ptr as *const T, out_ptr as *mut T, len, + ), + FusedActivationMulOp::GeluMul => kernels::gelu_mul_kernel::( + a_ptr as *const T, b_ptr as *const T, out_ptr as *mut T, len, + ), + FusedActivationMulOp::ReluMul => kernels::relu_mul_kernel::( + a_ptr as *const T, b_ptr as *const T, out_ptr as *mut T, len, + ), + FusedActivationMulOp::SigmoidMul => kernels::sigmoid_mul_kernel::( + a_ptr as *const T, b_ptr as *const T, out_ptr as *mut T, len, + ), + } + } + }, op_name); + + Ok(out) +} + /// Helper for parametric activation operations (leaky_relu, elu) /// /// These activations take a single f64 parameter in addition to the input tensor. @@ -85,8 +148,8 @@ pub fn parametric_activation_impl( let out = Tensor::::empty(a.shape(), dtype, &client.device); let len = a.numel(); - let a_ptr = a_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a_contig.ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { diff --git a/src/runtime/cpu/helpers/binary.rs b/src/runtime/cpu/helpers/binary.rs index 5d70b069..1df7bd35 100644 --- a/src/runtime/cpu/helpers/binary.rs +++ b/src/runtime/cpu/helpers/binary.rs @@ -21,7 +21,7 @@ pub fn binary_op_impl( // Create output tensor let out = Tensor::::empty(&out_shape, dtype, &client.device); - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); // Check if we can use the fast path (same shapes, both contiguous) let same_shapes = a.shape() == b.shape() && a.shape() == out_shape.as_slice(); @@ -30,8 +30,8 @@ pub fn binary_op_impl( if same_shapes && both_contiguous { // Fast path: no broadcasting needed, use contiguous kernel let len = a.numel(); - let a_ptr = a.storage().ptr(); - let b_ptr = b.storage().ptr(); + let a_ptr = a.ptr(); + let b_ptr = b.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -50,14 +50,12 @@ pub fn binary_op_impl( let a_broadcast = a.broadcast_to(&out_shape)?; let b_broadcast = b.broadcast_to(&out_shape)?; - let a_ptr = a_broadcast.storage().ptr(); - let b_ptr = b_broadcast.storage().ptr(); + let a_ptr = a_broadcast.ptr(); + let b_ptr = b_broadcast.ptr(); // Get strides from broadcast layouts let a_strides: Vec = a_broadcast.layout().strides().to_vec(); let b_strides: Vec = b_broadcast.layout().strides().to_vec(); - let a_offset = a_broadcast.layout().offset(); - let b_offset = b_broadcast.layout().offset(); dispatch_dtype!(dtype, T => { unsafe { @@ -69,8 +67,8 @@ pub fn binary_op_impl( &out_shape, &a_strides, &b_strides, - a_offset, - b_offset, + 0, + 0, ); } }, op_name); diff --git a/src/runtime/cpu/helpers/compare.rs b/src/runtime/cpu/helpers/compare.rs index 13c42996..da051238 100644 --- a/src/runtime/cpu/helpers/compare.rs +++ b/src/runtime/cpu/helpers/compare.rs @@ -19,7 +19,7 @@ pub fn compare_op_impl( let dtype = validate_binary_dtypes(a, b)?; let out_shape = compute_broadcast_shape(a, b)?; let out = Tensor::::empty(&out_shape, dtype, &client.device); - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); // Fast path for same shapes, both contiguous let same_shapes = a.shape() == b.shape() && a.shape() == out_shape.as_slice(); @@ -27,8 +27,8 @@ pub fn compare_op_impl( if same_shapes && both_contiguous { let len = a.numel(); - let a_ptr = a.storage().ptr(); - let b_ptr = b.storage().ptr(); + let a_ptr = a.ptr(); + let b_ptr = b.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -48,10 +48,8 @@ pub fn compare_op_impl( let a_strides: Vec = a_broadcast.layout().strides().to_vec(); let b_strides: Vec = b_broadcast.layout().strides().to_vec(); - let a_offset = a_broadcast.layout().offset(); - let b_offset = b_broadcast.layout().offset(); - let a_ptr = a_broadcast.storage().ptr(); - let b_ptr = b_broadcast.storage().ptr(); + let a_ptr = a_broadcast.ptr(); + let b_ptr = b_broadcast.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -63,8 +61,8 @@ pub fn compare_op_impl( &out_shape, &a_strides, &b_strides, - a_offset, - b_offset, + 0, + 0, ); } }, op_name); diff --git a/src/runtime/cpu/helpers/cumulative.rs b/src/runtime/cpu/helpers/cumulative.rs index 7fe03e13..0bdec499 100644 --- a/src/runtime/cpu/helpers/cumulative.rs +++ b/src/runtime/cpu/helpers/cumulative.rs @@ -50,8 +50,8 @@ pub fn cumsum_impl( let inner_size: usize = shape[dim_idx + 1..].iter().product(); let inner_size = inner_size.max(1); - let a_ptr = a_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a_contig.ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -109,8 +109,8 @@ pub fn cumprod_impl( let inner_size: usize = shape[dim_idx + 1..].iter().product(); let inner_size = inner_size.max(1); - let a_ptr = a_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a_contig.ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -178,8 +178,8 @@ pub fn logsumexp_impl( let out_shape = reduce_output_shape(shape, dims, keepdim); let out = Tensor::::empty(&out_shape, dtype, &client.device); - let a_ptr = a.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a.ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -256,8 +256,8 @@ fn logsumexp_single_dim( let out = Tensor::::empty(&out_shape, dtype, &client.device); - let a_ptr = a.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a.ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { diff --git a/src/runtime/cpu/helpers/fused_elementwise.rs b/src/runtime/cpu/helpers/fused_elementwise.rs new file mode 100644 index 00000000..69c8a168 --- /dev/null +++ b/src/runtime/cpu/helpers/fused_elementwise.rs @@ -0,0 +1,148 @@ +//! Fused elementwise operation helpers for CPU tensors + +use super::super::kernels; +use super::super::{CpuClient, CpuRuntime}; +use crate::dispatch_dtype; +use crate::error::{Error, Result}; +use crate::runtime::ensure_contiguous; +use crate::tensor::Tensor; + +/// Helper for fused_mul_add: out = a * b + c +pub fn fused_mul_add_impl( + client: &CpuClient, + a: &Tensor, + b: &Tensor, + c: &Tensor, +) -> Result> { + let dtype = a.dtype(); + if b.dtype() != dtype || c.dtype() != dtype { + return Err(Error::DTypeMismatch { + lhs: dtype, + rhs: if b.dtype() != dtype { + b.dtype() + } else { + c.dtype() + }, + }); + } + if a.shape() != b.shape() || a.shape() != c.shape() { + return Err(Error::ShapeMismatch { + expected: a.shape().to_vec(), + got: if a.shape() != b.shape() { + b.shape().to_vec() + } else { + c.shape().to_vec() + }, + }); + } + + let a_contig = ensure_contiguous(a); + let b_contig = ensure_contiguous(b); + let c_contig = ensure_contiguous(c); + let out = Tensor::::empty(a.shape(), dtype, &client.device); + + let len = a.numel(); + let a_ptr = a_contig.ptr(); + let b_ptr = b_contig.ptr(); + let c_ptr = c_contig.ptr(); + let out_ptr = out.ptr(); + + dispatch_dtype!(dtype, T => { + unsafe { + kernels::fused_mul_add_kernel::( + a_ptr as *const T, + b_ptr as *const T, + c_ptr as *const T, + out_ptr as *mut T, + len, + ); + } + }, "fused_mul_add"); + + Ok(out) +} + +/// Helper for fused_add_mul: out = (a + b) * c +pub fn fused_add_mul_impl( + client: &CpuClient, + a: &Tensor, + b: &Tensor, + c: &Tensor, +) -> Result> { + let dtype = a.dtype(); + if b.dtype() != dtype || c.dtype() != dtype { + return Err(Error::DTypeMismatch { + lhs: dtype, + rhs: if b.dtype() != dtype { + b.dtype() + } else { + c.dtype() + }, + }); + } + if a.shape() != b.shape() || a.shape() != c.shape() { + return Err(Error::ShapeMismatch { + expected: a.shape().to_vec(), + got: if a.shape() != b.shape() { + b.shape().to_vec() + } else { + c.shape().to_vec() + }, + }); + } + + let a_contig = ensure_contiguous(a); + let b_contig = ensure_contiguous(b); + let c_contig = ensure_contiguous(c); + let out = Tensor::::empty(a.shape(), dtype, &client.device); + + let len = a.numel(); + let a_ptr = a_contig.ptr(); + let b_ptr = b_contig.ptr(); + let c_ptr = c_contig.ptr(); + let out_ptr = out.ptr(); + + dispatch_dtype!(dtype, T => { + unsafe { + kernels::fused_add_mul_kernel::( + a_ptr as *const T, + b_ptr as *const T, + c_ptr as *const T, + out_ptr as *mut T, + len, + ); + } + }, "fused_add_mul"); + + Ok(out) +} + +/// Helper for fused_mul_add_scalar: out = a * scale + bias +pub fn fused_mul_add_scalar_impl( + client: &CpuClient, + a: &Tensor, + scale: f64, + bias: f64, +) -> Result> { + let dtype = a.dtype(); + let a_contig = ensure_contiguous(a); + let out = Tensor::::empty(a.shape(), dtype, &client.device); + + let len = a.numel(); + let a_ptr = a_contig.ptr(); + let out_ptr = out.ptr(); + + dispatch_dtype!(dtype, T => { + unsafe { + kernels::fused_mul_add_scalar_kernel::( + a_ptr as *const T, + scale, + bias, + out_ptr as *mut T, + len, + ); + } + }, "fused_mul_add_scalar"); + + Ok(out) +} diff --git a/src/runtime/cpu/helpers/indexing.rs b/src/runtime/cpu/helpers/indexing.rs index f0a5355e..025dbeee 100644 --- a/src/runtime/cpu/helpers/indexing.rs +++ b/src/runtime/cpu/helpers/indexing.rs @@ -94,10 +94,10 @@ pub fn gather_2d_impl( // Allocate output let out = Tensor::::empty(&[num_indices], dtype, &client.device); - let input_ptr = input_contig.storage().ptr(); - let rows_ptr = rows_contig.storage().ptr(); - let cols_ptr = cols_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let input_ptr = input_contig.ptr(); + let rows_ptr = rows_contig.ptr(); + let cols_ptr = cols_contig.ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { let success = unsafe { @@ -158,9 +158,9 @@ pub fn gather_impl( let index_contig = ensure_contiguous(&index_i64); let out = Tensor::::empty(&out_shape, dtype, &client.device); - let a_ptr = a_contig.storage().ptr(); - let index_ptr = index_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a_contig.ptr(); + let index_ptr = index_contig.ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -220,10 +220,10 @@ pub fn scatter_impl( let src_contig = ensure_contiguous(src); let out = Tensor::::empty(shape, dtype, &client.device); - let a_ptr = a_contig.storage().ptr(); - let index_ptr = index_contig.storage().ptr(); - let src_ptr = src_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a_contig.ptr(); + let index_ptr = index_contig.ptr(); + let src_ptr = src_contig.ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -282,9 +282,8 @@ pub fn index_select_impl( // Validate all indices are within bounds (before calling unsafe kernel) let dim_size = shape[dim]; - let index_data = unsafe { - std::slice::from_raw_parts(index_contig.storage().ptr() as *const i64, index_len) - }; + let index_data = + unsafe { std::slice::from_raw_parts(index_contig.ptr() as *const i64, index_len) }; for &idx in index_data.iter() { // Negative indices are not supported - must be in [0, dim_size) if idx < 0 || idx as usize >= dim_size { @@ -297,9 +296,9 @@ pub fn index_select_impl( let out = Tensor::::empty(&out_shape, dtype, &client.device); - let a_ptr = a_contig.storage().ptr(); - let index_ptr = index_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a_contig.ptr(); + let index_ptr = index_contig.ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -373,9 +372,8 @@ pub fn index_put_impl( // Validate all indices are within bounds (before calling unsafe kernel) let dim_size = shape[dim]; - let index_data = unsafe { - std::slice::from_raw_parts(index_contig.storage().ptr() as *const i64, index_len) - }; + let index_data = + unsafe { std::slice::from_raw_parts(index_contig.ptr() as *const i64, index_len) }; for &idx in index_data.iter() { // Negative indices are not supported - must be in [0, dim_size) if idx < 0 || idx as usize >= dim_size { @@ -389,10 +387,10 @@ pub fn index_put_impl( // Clone a's data for output let out = Tensor::::empty(shape, dtype, &client.device); - let a_ptr = a_contig.storage().ptr(); - let index_ptr = index_contig.storage().ptr(); - let src_ptr = src_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a_contig.ptr(); + let index_ptr = index_contig.ptr(); + let src_ptr = src_contig.ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -434,8 +432,8 @@ pub fn masked_select_impl( let mask_contig = ensure_contiguous(&mask_broadcast); let numel = a.numel(); - let a_ptr = a_contig.storage().ptr(); - let mask_ptr = mask_contig.storage().ptr(); + let a_ptr = a_contig.ptr(); + let mask_ptr = mask_contig.ptr(); // Use SIMD for f32/f64 on x86_64 #[cfg(target_arch = "x86_64")] @@ -445,7 +443,7 @@ pub fn masked_select_impl( // Allocate output with correct size let out = Tensor::::empty(&[count], dtype, &client.device); - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); match dtype { DType::F32 => { @@ -495,7 +493,7 @@ pub fn masked_select_impl( // Allocate output with correct size let out = Tensor::::empty(&[count], dtype, &client.device); - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -537,9 +535,9 @@ pub fn masked_fill_impl( let out = Tensor::::empty(a.shape(), dtype, &client.device); let numel = a.numel(); - let a_ptr = a_contig.storage().ptr(); - let mask_ptr = mask_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a_contig.ptr(); + let mask_ptr = mask_contig.ptr(); + let out_ptr = out.ptr(); // Use SIMD for f32/f64 on x86_64 #[cfg(target_arch = "x86_64")] @@ -626,9 +624,9 @@ pub fn embedding_lookup_impl( let idx_contig = ensure_contiguous(&indices_i64); let out = Tensor::::empty(&out_shape, dtype, &client.device); - let emb_ptr = emb_contig.storage().ptr(); - let idx_ptr = idx_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let emb_ptr = emb_contig.ptr(); + let idx_ptr = idx_contig.ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -703,10 +701,10 @@ pub fn scatter_reduce_impl( let dst_numel: usize = shape.iter().product(); let counts_buffer: Vec = vec![0; dst_numel]; - let dst_ptr = dst_contig.storage().ptr(); - let index_ptr = index_contig.storage().ptr(); - let src_ptr = src_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let dst_ptr = dst_contig.ptr(); + let index_ptr = index_contig.ptr(); + let src_ptr = src_contig.ptr(); + let out_ptr = out.ptr(); let counts_ptr = if op == ScatterReduceOp::Mean { counts_buffer.as_ptr() as *mut u32 } else { @@ -778,9 +776,9 @@ pub fn gather_nd_impl( let indices_contig = ensure_contiguous(&indices_i64); let out = Tensor::::empty(&out_shape, dtype, &client.device); - let input_ptr = input_contig.storage().ptr(); - let indices_ptr = indices_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let input_ptr = input_contig.ptr(); + let indices_ptr = indices_contig.ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -840,14 +838,11 @@ pub fn bincount_impl( // Convert input to i64 if needed let input_i64: Vec = if input_dtype == DType::I64 { - unsafe { - std::slice::from_raw_parts(input_contig.storage().ptr() as *const i64, numel).to_vec() - } + unsafe { std::slice::from_raw_parts(input_contig.ptr() as *const i64, numel).to_vec() } } else { // I32 input - let i32_slice = unsafe { - std::slice::from_raw_parts(input_contig.storage().ptr() as *const i32, numel) - }; + let i32_slice = + unsafe { std::slice::from_raw_parts(input_contig.ptr() as *const i32, numel) }; i32_slice.iter().map(|&x| x as i64).collect() }; @@ -862,11 +857,11 @@ pub fn bincount_impl( let output_len = (max_val as usize + 1).max(minlength); let out = Tensor::::empty(&[output_len], out_dtype, &client.device); - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); if let Some(w) = weights { let w_contig = ensure_contiguous(w); - let w_ptr = w_contig.storage().ptr(); + let w_ptr = w_contig.ptr(); dispatch_dtype!(out_dtype, T => { let success = unsafe { @@ -906,3 +901,87 @@ pub fn bincount_impl( Ok(out) } + +/// Slice assign implementation: copies src into a slice of dst along dim starting at start. +pub fn slice_assign_impl( + client: &CpuClient, + dst: &Tensor, + src: &Tensor, + dim: usize, + start: usize, +) -> Result> { + let ndim = dst.ndim(); + if dim >= ndim { + return Err(Error::InvalidDimension { + dim: dim as isize, + ndim, + }); + } + + // Validate shapes match except at dim + if src.ndim() != ndim { + return Err(Error::ShapeMismatch { + expected: dst.shape().to_vec(), + got: src.shape().to_vec(), + }); + } + for d in 0..ndim { + if d != dim && src.shape()[d] != dst.shape()[d] { + return Err(Error::ShapeMismatch { + expected: dst.shape().to_vec(), + got: src.shape().to_vec(), + }); + } + } + + let src_dim_size = src.shape()[dim]; + let dst_dim_size = dst.shape()[dim]; + if start + src_dim_size > dst_dim_size { + return Err(Error::InvalidArgument { + arg: "start", + reason: format!( + "start ({}) + src dim size ({}) exceeds dst dim size ({})", + start, src_dim_size, dst_dim_size + ), + }); + } + + let dtype = dst.dtype(); + if src.dtype() != dtype { + return Err(Error::DTypeMismatch { + lhs: dtype, + rhs: src.dtype(), + }); + } + + // Compute outer/inner sizes + let outer_size: usize = dst.shape()[..dim].iter().product(); + let outer_size = if outer_size == 0 { 1 } else { outer_size }; + let inner_size: usize = dst.shape()[dim + 1..].iter().product(); + let inner_size = if inner_size == 0 { 1 } else { inner_size }; + + let dst_c = ensure_contiguous(dst); + let src_c = ensure_contiguous(src); + let out = Tensor::::empty(dst.shape(), dtype, &client.device); + + let dst_ptr = dst_c.ptr(); + let src_ptr = src_c.ptr(); + let out_ptr = out.ptr(); + + dispatch_dtype!(dtype, T => { + unsafe { + kernels::slice_assign_kernel::( + dst_ptr as *const T, + src_ptr as *const T, + out_ptr as *mut T, + outer_size, + dst_dim_size, + src_dim_size, + inner_size, + start, + ); + } + }, "slice_assign"); + + Ok(out) +} diff --git a/src/runtime/cpu/helpers/mod.rs b/src/runtime/cpu/helpers/mod.rs index 2d54dd2e..d5b5b585 100644 --- a/src/runtime/cpu/helpers/mod.rs +++ b/src/runtime/cpu/helpers/mod.rs @@ -7,6 +7,7 @@ pub mod activation; pub mod binary; pub mod compare; pub mod cumulative; +pub mod fused_elementwise; pub mod indexing; pub mod reduce; pub mod scalar; @@ -14,14 +15,18 @@ pub mod shape; pub mod unary; // Re-export all helper functions -pub use activation::{ActivationOp, activation_op_impl, elu_impl, leaky_relu_impl}; +pub use activation::{ + ActivationOp, FusedActivationMulOp, activation_op_impl, elu_impl, fused_activation_mul_impl, + leaky_relu_impl, +}; pub use binary::binary_op_impl; pub use compare::compare_op_impl; pub use cumulative::{cumprod_impl, cumsum_impl, logsumexp_impl}; +pub use fused_elementwise::{fused_add_mul_impl, fused_mul_add_impl, fused_mul_add_scalar_impl}; pub use indexing::{ bincount_impl, embedding_lookup_impl, gather_2d_impl, gather_impl, gather_nd_impl, index_put_impl, index_select_impl, masked_fill_impl, masked_select_impl, scatter_impl, - scatter_reduce_impl, + scatter_reduce_impl, slice_assign_impl, }; pub use reduce::{reduce_impl, reduce_impl_with_precision}; pub use scalar::scalar_op_impl; diff --git a/src/runtime/cpu/helpers/reduce/mod.rs b/src/runtime/cpu/helpers/reduce/mod.rs index 39999be4..c56f2531 100644 --- a/src/runtime/cpu/helpers/reduce/mod.rs +++ b/src/runtime/cpu/helpers/reduce/mod.rs @@ -49,8 +49,8 @@ pub fn reduce_impl( let out_shape = reduce_output_shape(shape, dims, keepdim); let out = Tensor::::empty(&out_shape, dtype, &client.device); - let a_ptr = a.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a.ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -67,7 +67,9 @@ pub fn reduce_impl( Ok(out) } else if dims.is_empty() { - Ok(a.clone()) + // Empty dims = reduce over ALL dimensions → scalar + let all_dims: Vec = (0..ndim).collect(); + return reduce_impl(client, op, a, &all_dims, keepdim, op_name); } else if should_fuse_multi_dim_reduction(a, dims) { reduce_multi_dim_fused( client, diff --git a/src/runtime/cpu/helpers/reduce/multi_dim.rs b/src/runtime/cpu/helpers/reduce/multi_dim.rs index 83ccf2cc..14c2a1b4 100644 --- a/src/runtime/cpu/helpers/reduce/multi_dim.rs +++ b/src/runtime/cpu/helpers/reduce/multi_dim.rs @@ -44,8 +44,8 @@ pub(super) fn reduce_multi_dim_fused( let numel = a.numel(); let out_numel = out.numel(); - let in_ptr = a.storage().ptr(); - let out_ptr = out.storage().ptr(); + let in_ptr = a.ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(a.dtype(), T => { unsafe { diff --git a/src/runtime/cpu/helpers/reduce/precision.rs b/src/runtime/cpu/helpers/reduce/precision.rs index 68f0e5ac..a4b8bbc4 100644 --- a/src/runtime/cpu/helpers/reduce/precision.rs +++ b/src/runtime/cpu/helpers/reduce/precision.rs @@ -46,8 +46,8 @@ pub fn reduce_impl_with_precision( let out_shape = reduce_output_shape(shape, dims, keepdim); let out = Tensor::::empty(&out_shape, dtype, &client.device); - let a_ptr = a.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a.ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -115,8 +115,8 @@ fn reduce_single_dim_with_precision( let out_shape = reduce_output_shape(shape, &[dim], keepdim); let out = Tensor::::empty(&out_shape, dtype, &client.device); - let a_ptr = a.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a.ptr(); + let out_ptr = out.ptr(); if dim == ndim - 1 { dispatch_dtype!(dtype, T => { diff --git a/src/runtime/cpu/helpers/reduce/single_dim.rs b/src/runtime/cpu/helpers/reduce/single_dim.rs index f2709fa2..e277d1ad 100644 --- a/src/runtime/cpu/helpers/reduce/single_dim.rs +++ b/src/runtime/cpu/helpers/reduce/single_dim.rs @@ -41,8 +41,8 @@ pub(super) fn reduce_single_dim( let out = Tensor::::empty(&out_shape, dtype, &client.device); if dim == ndim - 1 { - let a_ptr = a.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a.ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -57,8 +57,8 @@ pub(super) fn reduce_single_dim( } }, op_name); } else { - let a_ptr = a.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a.ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { diff --git a/src/runtime/cpu/helpers/scalar.rs b/src/runtime/cpu/helpers/scalar.rs index f66b652a..bec9ff9c 100644 --- a/src/runtime/cpu/helpers/scalar.rs +++ b/src/runtime/cpu/helpers/scalar.rs @@ -21,8 +21,8 @@ pub fn scalar_op_impl( let out = Tensor::::empty(a.shape(), dtype, &client.device); let len = a.numel(); - let a_ptr = a_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a_contig.ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -50,8 +50,8 @@ pub fn rsub_scalar_op_impl( let out = Tensor::::empty(a.shape(), dtype, &client.device); let len = a.numel(); - let a_ptr = a_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a_contig.ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { diff --git a/src/runtime/cpu/helpers/shape.rs b/src/runtime/cpu/helpers/shape.rs index 58c624e1..a968d7c2 100644 --- a/src/runtime/cpu/helpers/shape.rs +++ b/src/runtime/cpu/helpers/shape.rs @@ -4,7 +4,8 @@ use super::super::{CpuClient, CpuRuntime}; use crate::dispatch_dtype; use crate::dtype::Element; use crate::error::Result; -use crate::runtime::{ensure_contiguous, shape_ops}; +use crate::runtime::common::shape_ops; +use crate::runtime::ensure_contiguous; use crate::tensor::Tensor; /// Concatenate tensors along a dimension @@ -16,7 +17,7 @@ pub fn cat_impl( let params = shape_ops::validate_cat(tensors, dim)?; let out = Tensor::::empty(¶ms.out_shape, params.dtype, &client.device); - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); let elem_size = params.dtype.size_in_bytes(); // Byte-level copies — memcpy doesn't need type dispatch, and dispatch_dtype! @@ -26,10 +27,10 @@ pub fn cat_impl( for &tensor in tensors { let contig_tmp; let src_ptr = if tensor.is_contiguous() { - tensor.storage().ptr() as *const u8 + tensor.ptr() as *const u8 } else { contig_tmp = tensor.contiguous(); - contig_tmp.storage().ptr() as *const u8 + contig_tmp.ptr() as *const u8 }; let src_cat_size = tensor.shape()[params.dim_idx]; let src_bytes = src_cat_size * params.inner_size * elem_size; @@ -119,8 +120,8 @@ pub fn repeat_impl( // Make input contiguous let tensor_contig = ensure_contiguous(tensor); - let src_ptr = tensor_contig.storage().ptr(); - let dst_ptr = out.storage().ptr(); + let src_ptr = tensor_contig.ptr(); + let dst_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -200,8 +201,8 @@ pub fn pad_impl( // Make input contiguous let tensor_contig = ensure_contiguous(tensor); - let src_ptr = tensor_contig.storage().ptr(); - let dst_ptr = out.storage().ptr(); + let src_ptr = tensor_contig.ptr(); + let dst_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -278,8 +279,8 @@ pub fn roll_impl( // Make input contiguous let tensor_contig = ensure_contiguous(tensor); - let src_ptr = tensor_contig.storage().ptr(); - let dst_ptr = out.storage().ptr(); + let src_ptr = tensor_contig.ptr(); + let dst_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { diff --git a/src/runtime/cpu/helpers/unary.rs b/src/runtime/cpu/helpers/unary.rs index 6b12deba..0d04d5eb 100644 --- a/src/runtime/cpu/helpers/unary.rs +++ b/src/runtime/cpu/helpers/unary.rs @@ -19,8 +19,8 @@ pub fn unary_op_impl( let out = Tensor::::empty(a.shape(), dtype, &client.device); let len = a.numel(); - let a_ptr = a_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a_contig.ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { diff --git a/src/runtime/cpu/kernels/binary.rs b/src/runtime/cpu/kernels/binary.rs index c7dc2507..a37b1feb 100644 --- a/src/runtime/cpu/kernels/binary.rs +++ b/src/runtime/cpu/kernels/binary.rs @@ -4,7 +4,7 @@ //! On x86-64, f32 and f64 operations use AVX-512 or AVX2 when available. //! On aarch64, f32 and f64 operations use NEON when available. -use crate::dtype::{DType, Element}; +use crate::dtype::Element; use crate::ops::BinaryOp; /// Execute a binary operation element-wise with automatic SIMD dispatch @@ -30,6 +30,7 @@ pub unsafe fn binary_op_kernel( #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] { use super::simd::binary; + use crate::dtype::DType; match T::DTYPE { DType::F32 => { @@ -40,6 +41,32 @@ pub unsafe fn binary_op_kernel( binary::binary_f64(op, a as *const f64, b as *const f64, out as *mut f64, len); return; } + DType::I32 => { + binary::binary_i32(op, a as *const i32, b as *const i32, out as *mut i32, len); + return; + } + #[cfg(feature = "f16")] + DType::F16 => { + binary::binary_f16( + op, + a as *const half::f16, + b as *const half::f16, + out as *mut half::f16, + len, + ); + return; + } + #[cfg(feature = "f16")] + DType::BF16 => { + binary::binary_bf16( + op, + a as *const half::bf16, + b as *const half::bf16, + out as *mut half::bf16, + len, + ); + return; + } _ => {} // Fall through to scalar } } @@ -237,6 +264,72 @@ pub unsafe fn binary_scalar_f64( } } +/// Scalar binary operation for i32 (used by SIMD for small arrays and tail) +#[inline] +pub unsafe fn binary_scalar_i32( + op: BinaryOp, + a: *const i32, + b: *const i32, + out: *mut i32, + len: usize, +) { + match op { + BinaryOp::Add => { + for i in 0..len { + *out.add(i) = (*a.add(i)).wrapping_add(*b.add(i)); + } + } + BinaryOp::Sub => { + for i in 0..len { + *out.add(i) = (*a.add(i)).wrapping_sub(*b.add(i)); + } + } + BinaryOp::Mul => { + for i in 0..len { + *out.add(i) = (*a.add(i)).wrapping_mul(*b.add(i)); + } + } + BinaryOp::Div => { + for i in 0..len { + let bv = *b.add(i); + *out.add(i) = if bv != 0 { + (*a.add(i)).wrapping_div(bv) + } else { + 0 + }; + } + } + BinaryOp::Max => { + for i in 0..len { + let av = *a.add(i); + let bv = *b.add(i); + *out.add(i) = if av > bv { av } else { bv }; + } + } + BinaryOp::Min => { + for i in 0..len { + let av = *a.add(i); + let bv = *b.add(i); + *out.add(i) = if av < bv { av } else { bv }; + } + } + BinaryOp::Pow => { + for i in 0..len { + let base = *a.add(i) as f64; + let exp = *b.add(i) as f64; + *out.add(i) = base.powf(exp) as i32; + } + } + BinaryOp::Atan2 => { + for i in 0..len { + let y = *a.add(i) as f64; + let x = *b.add(i) as f64; + *out.add(i) = y.atan2(x) as i32; + } + } + } +} + /// Execute a binary operation with broadcasting support /// /// Uses strides to handle arbitrary broadcasting patterns. Stride of 0 means diff --git a/src/runtime/cpu/kernels/compare.rs b/src/runtime/cpu/kernels/compare.rs index 62c18608..8f4e58e5 100644 --- a/src/runtime/cpu/kernels/compare.rs +++ b/src/runtime/cpu/kernels/compare.rs @@ -1,6 +1,6 @@ //! Comparison operation kernels -use crate::dtype::{DType, Element}; +use crate::dtype::Element; use crate::ops::CompareOp; /// Execute a comparison operation element-wise @@ -26,6 +26,7 @@ pub unsafe fn compare_op_kernel( #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] { use super::simd::compare; + use crate::dtype::DType; match T::DTYPE { DType::F32 => { @@ -36,6 +37,28 @@ pub unsafe fn compare_op_kernel( compare::compare_f64(op, a as *const f64, b as *const f64, out as *mut f64, len); return; } + #[cfg(feature = "f16")] + DType::F16 => { + compare::compare_f16( + op, + a as *const half::f16, + b as *const half::f16, + out as *mut half::f16, + len, + ); + return; + } + #[cfg(feature = "f16")] + DType::BF16 => { + compare::compare_bf16( + op, + a as *const half::bf16, + b as *const half::bf16, + out as *mut half::bf16, + len, + ); + return; + } _ => {} // Fall through to scalar } } diff --git a/src/runtime/cpu/kernels/conv.rs b/src/runtime/cpu/kernels/conv.rs index b0132a58..1ff4fe6b 100644 --- a/src/runtime/cpu/kernels/conv.rs +++ b/src/runtime/cpu/kernels/conv.rs @@ -2,6 +2,8 @@ //! //! Direct convolution implementations without im2col transformation. +#[cfg(feature = "f16")] +use crate::dtype::DType; use crate::dtype::Element; use crate::ops::conv_common::{Conv1dParams, Conv2dParams}; @@ -20,6 +22,36 @@ pub unsafe fn conv1d_kernel( output: *mut T, params: Conv1dParams, ) { + // Dispatch to SIMD for f16/bf16 on x86-64 and aarch64 + #[cfg(all(feature = "f16", any(target_arch = "x86_64", target_arch = "aarch64")))] + { + use super::simd::conv as simd_conv; + + match T::DTYPE { + DType::F16 => { + simd_conv::conv1d_f16( + input as *const half::f16, + weight as *const half::f16, + bias.map(|b| b as *const half::f16), + output as *mut half::f16, + params, + ); + return; + } + DType::BF16 => { + simd_conv::conv1d_bf16( + input as *const half::bf16, + weight as *const half::bf16, + bias.map(|b| b as *const half::bf16), + output as *mut half::bf16, + params, + ); + return; + } + _ => {} // Fall through to scalar + } + } + let Conv1dParams { batch, c_in, @@ -106,6 +138,36 @@ pub unsafe fn conv2d_kernel( output: *mut T, params: Conv2dParams, ) { + // Dispatch to SIMD for f16/bf16 on x86-64 and aarch64 + #[cfg(all(feature = "f16", any(target_arch = "x86_64", target_arch = "aarch64")))] + { + use super::simd::conv as simd_conv; + + match T::DTYPE { + DType::F16 => { + simd_conv::conv2d_f16( + input as *const half::f16, + weight as *const half::f16, + bias.map(|b| b as *const half::f16), + output as *mut half::f16, + params, + ); + return; + } + DType::BF16 => { + simd_conv::conv2d_bf16( + input as *const half::bf16, + weight as *const half::bf16, + bias.map(|b| b as *const half::bf16), + output as *mut half::bf16, + params, + ); + return; + } + _ => {} // Fall through to scalar + } + } + let Conv2dParams { batch, c_in, @@ -222,6 +284,36 @@ pub unsafe fn depthwise_conv2d_kernel( output: *mut T, params: Conv2dParams, ) { + // Dispatch to SIMD for f16/bf16 on x86-64 and aarch64 + #[cfg(all(feature = "f16", any(target_arch = "x86_64", target_arch = "aarch64")))] + { + use super::simd::conv as simd_conv; + + match T::DTYPE { + DType::F16 => { + simd_conv::depthwise_conv2d_f16( + input as *const half::f16, + weight as *const half::f16, + bias.map(|b| b as *const half::f16), + output as *mut half::f16, + params, + ); + return; + } + DType::BF16 => { + simd_conv::depthwise_conv2d_bf16( + input as *const half::bf16, + weight as *const half::bf16, + bias.map(|b| b as *const half::bf16), + output as *mut half::bf16, + params, + ); + return; + } + _ => {} // Fall through to scalar + } + } + let Conv2dParams { batch, c_in, diff --git a/src/runtime/cpu/kernels/cumulative.rs b/src/runtime/cpu/kernels/cumulative.rs index f122aaa1..f73b9c84 100644 --- a/src/runtime/cpu/kernels/cumulative.rs +++ b/src/runtime/cpu/kernels/cumulative.rs @@ -1,6 +1,6 @@ //! Cumulative operation kernels (cumsum, cumprod, logsumexp) -use crate::dtype::{DType, Element}; +use crate::dtype::Element; /// Cumulative sum along a contiguous dimension /// @@ -53,6 +53,7 @@ pub unsafe fn cumsum_strided_kernel( #[cfg(target_arch = "x86_64")] { use super::simd::cumulative; + use crate::dtype::DType; match T::DTYPE { DType::F32 => { @@ -75,6 +76,28 @@ pub unsafe fn cumsum_strided_kernel( ); return; } + #[cfg(feature = "f16")] + DType::F16 => { + cumulative::cumsum_strided_f16( + a as *const half::f16, + out as *mut half::f16, + scan_size, + outer_size, + inner_size, + ); + return; + } + #[cfg(feature = "f16")] + DType::BF16 => { + cumulative::cumsum_strided_bf16( + a as *const half::bf16, + out as *mut half::bf16, + scan_size, + outer_size, + inner_size, + ); + return; + } _ => {} // Fall through to scalar } } @@ -144,6 +167,7 @@ pub unsafe fn cumprod_strided_kernel( #[cfg(target_arch = "x86_64")] { use super::simd::cumulative; + use crate::dtype::DType; match T::DTYPE { DType::F32 => { @@ -166,6 +190,28 @@ pub unsafe fn cumprod_strided_kernel( ); return; } + #[cfg(feature = "f16")] + DType::F16 => { + cumulative::cumprod_strided_f16( + a as *const half::f16, + out as *mut half::f16, + scan_size, + outer_size, + inner_size, + ); + return; + } + #[cfg(feature = "f16")] + DType::BF16 => { + cumulative::cumprod_strided_bf16( + a as *const half::bf16, + out as *mut half::bf16, + scan_size, + outer_size, + inner_size, + ); + return; + } _ => {} // Fall through to scalar } } @@ -212,6 +258,7 @@ pub unsafe fn logsumexp_kernel( #[cfg(target_arch = "x86_64")] { use super::simd::logsumexp; + use crate::dtype::DType; match T::DTYPE { DType::F32 => { @@ -222,6 +269,26 @@ pub unsafe fn logsumexp_kernel( logsumexp::logsumexp_f64(a as *const f64, out as *mut f64, reduce_size, outer_size); return; } + #[cfg(feature = "f16")] + DType::F16 => { + logsumexp::logsumexp_f16( + a as *const half::f16, + out as *mut half::f16, + reduce_size, + outer_size, + ); + return; + } + #[cfg(feature = "f16")] + DType::BF16 => { + logsumexp::logsumexp_bf16( + a as *const half::bf16, + out as *mut half::bf16, + reduce_size, + outer_size, + ); + return; + } _ => {} // Fall through to scalar } } diff --git a/src/runtime/cpu/kernels/distributions.rs b/src/runtime/cpu/kernels/distributions.rs index f365a749..81fbfeb2 100644 --- a/src/runtime/cpu/kernels/distributions.rs +++ b/src/runtime/cpu/kernels/distributions.rs @@ -1,11 +1,10 @@ //! Distribution sampling kernels for CPU //! -//! Implements probability distribution sampling using the rand_distr crate. +//! Implements probability distribution sampling using numr's own PRNG and samplers. //! All kernels support F32, F64, and optionally F16/BF16 via the Element trait. +use super::rng; use crate::dtype::Element; -use rand::Rng; -use rand_distr::{Beta, Binomial, Distribution, Exp, Gamma, Poisson, StandardNormal}; /// Sample from Bernoulli distribution (binary outcomes) /// @@ -16,11 +15,11 @@ use rand_distr::{Beta, Binomial, Distribution, Exp, Gamma, Poisson, StandardNorm /// - `p` must be in [0, 1] #[inline] pub unsafe fn bernoulli_kernel(out: *mut T, p: f64, len: usize) { - let mut rng = rand::rng(); + let mut prng = rng::thread_rng(); let out_slice = std::slice::from_raw_parts_mut(out, len); for elem in out_slice.iter_mut() { - let u: f64 = rng.random(); + let u = rng::sample_uniform(&mut prng); let val = if u < p { 1.0 } else { 0.0 }; *elem = T::from_f64(val); } @@ -28,20 +27,19 @@ pub unsafe fn bernoulli_kernel(out: *mut T, p: f64, len: usize) { /// Sample from Beta distribution /// -/// Uses the relationship: if X ~ Gamma(α, 1) and Y ~ Gamma(β, 1), -/// then X / (X + Y) ~ Beta(α, β). +/// Uses the relationship: if X ~ Gamma(a, 1) and Y ~ Gamma(b, 1), +/// then X / (X + Y) ~ Beta(a, b). /// /// # Safety /// - `out` must be a valid pointer to `len` elements /// - `alpha > 0` and `beta > 0` #[inline] pub unsafe fn beta_kernel(out: *mut T, alpha: f64, beta: f64, len: usize) { - let mut rng = rand::rng(); - let dist = Beta::new(alpha, beta).expect("Invalid beta parameters"); + let mut prng = rng::thread_rng(); let out_slice = std::slice::from_raw_parts_mut(out, len); for elem in out_slice.iter_mut() { - let val: f64 = dist.sample(&mut rng); + let val = rng::sample_beta(&mut prng, alpha, beta); *elem = T::from_f64(val); } } @@ -56,12 +54,11 @@ pub unsafe fn beta_kernel(out: *mut T, alpha: f64, beta: f64, len: u /// - `shape_param > 0` and `scale > 0` #[inline] pub unsafe fn gamma_kernel(out: *mut T, shape_param: f64, scale: f64, len: usize) { - let mut rng = rand::rng(); - let dist = Gamma::new(shape_param, scale).expect("Invalid gamma parameters"); + let mut prng = rng::thread_rng(); let out_slice = std::slice::from_raw_parts_mut(out, len); for elem in out_slice.iter_mut() { - let val: f64 = dist.sample(&mut rng); + let val = rng::sample_gamma(&mut prng, shape_param, scale); *elem = T::from_f64(val); } } @@ -75,12 +72,11 @@ pub unsafe fn gamma_kernel(out: *mut T, shape_param: f64, scale: f64 /// - `rate > 0` #[inline] pub unsafe fn exponential_kernel(out: *mut T, rate: f64, len: usize) { - let mut rng = rand::rng(); - let dist = Exp::new(rate).expect("Invalid exponential rate"); + let mut prng = rng::thread_rng(); let out_slice = std::slice::from_raw_parts_mut(out, len); for elem in out_slice.iter_mut() { - let val: f64 = dist.sample(&mut rng); + let val = rng::sample_exponential(&mut prng, rate); *elem = T::from_f64(val); } } @@ -88,19 +84,18 @@ pub unsafe fn exponential_kernel(out: *mut T, rate: f64, len: usize) /// Sample from Poisson distribution /// /// For small lambda (< 30): uses direct inversion method. -/// For large lambda: uses normal approximation internally. +/// For large lambda: uses normal approximation. /// /// # Safety /// - `out` must be a valid pointer to `len` elements /// - `lambda > 0` #[inline] pub unsafe fn poisson_kernel(out: *mut T, lambda: f64, len: usize) { - let mut rng = rand::rng(); - let dist = Poisson::new(lambda).expect("Invalid poisson lambda"); + let mut prng = rng::thread_rng(); let out_slice = std::slice::from_raw_parts_mut(out, len); for elem in out_slice.iter_mut() { - let val: f64 = dist.sample(&mut rng); + let val = rng::sample_poisson(&mut prng, lambda) as f64; *elem = T::from_f64(val); } } @@ -108,26 +103,25 @@ pub unsafe fn poisson_kernel(out: *mut T, lambda: f64, len: usize) { /// Sample from Binomial distribution /// /// For small n: direct simulation (sum of Bernoulli trials). -/// For large n: uses BTRD algorithm internally. +/// For large n: uses normal approximation. /// /// # Safety /// - `out` must be a valid pointer to `len` elements /// - `n > 0` and `p` in [0, 1] #[inline] pub unsafe fn binomial_kernel(out: *mut T, n: u64, p: f64, len: usize) { - let mut rng = rand::rng(); - let dist = Binomial::new(n, p).expect("Invalid binomial parameters"); + let mut prng = rng::thread_rng(); let out_slice = std::slice::from_raw_parts_mut(out, len); for elem in out_slice.iter_mut() { - let val = dist.sample(&mut rng); - *elem = T::from_f64(val as f64); + let val = rng::sample_binomial(&mut prng, n, p) as f64; + *elem = T::from_f64(val); } } /// Sample from Laplace (double exponential) distribution /// -/// Uses inverse transform: X = μ - b * sign(U - 0.5) * ln(1 - 2|U - 0.5|) +/// Uses inverse transform: X = mu - b * sign(U - 0.5) * ln(1 - 2|U - 0.5|) /// where U ~ Uniform(0, 1). /// /// # Safety @@ -135,11 +129,11 @@ pub unsafe fn binomial_kernel(out: *mut T, n: u64, p: f64, len: usiz /// - `scale > 0` #[inline] pub unsafe fn laplace_kernel(out: *mut T, loc: f64, scale: f64, len: usize) { - let mut rng = rand::rng(); + let mut prng = rng::thread_rng(); let out_slice = std::slice::from_raw_parts_mut(out, len); for elem in out_slice.iter_mut() { - let u: f64 = rng.random::() - 0.5; + let u = rng::sample_uniform(&mut prng) - 0.5; // Avoid log(0) by clamping let abs_u = u.abs().max(1e-300); let val = loc - scale * u.signum() * (1.0 - 2.0 * abs_u).ln(); @@ -149,42 +143,37 @@ pub unsafe fn laplace_kernel(out: *mut T, loc: f64, scale: f64, len: /// Sample from Chi-squared distribution /// -/// Implemented as Gamma(df/2, 2) since χ²(k) = Gamma(k/2, 2). +/// Implemented as Gamma(df/2, 2) since chi2(k) = Gamma(k/2, 2). /// /// # Safety /// - `out` must be a valid pointer to `len` elements /// - `df > 0` #[inline] pub unsafe fn chi_squared_kernel(out: *mut T, df: f64, len: usize) { - let mut rng = rand::rng(); - // χ²(df) = Gamma(df/2, 2) - let dist = Gamma::new(df / 2.0, 2.0).expect("Invalid chi-squared df"); + let mut prng = rng::thread_rng(); let out_slice = std::slice::from_raw_parts_mut(out, len); for elem in out_slice.iter_mut() { - let val: f64 = dist.sample(&mut rng); + let val = rng::sample_gamma(&mut prng, df / 2.0, 2.0); *elem = T::from_f64(val); } } /// Sample from Student's t distribution /// -/// Uses the relationship: T = Z / sqrt(V/ν) where Z ~ N(0,1) and V ~ χ²(ν). +/// Uses the relationship: T = Z / sqrt(V/nu) where Z ~ N(0,1) and V ~ chi2(nu). /// /// # Safety /// - `out` must be a valid pointer to `len` elements /// - `df > 0` #[inline] pub unsafe fn student_t_kernel(out: *mut T, df: f64, len: usize) { - let mut rng = rand::rng(); - let normal = StandardNormal; - // χ²(df) = Gamma(df/2, 2) - let chi2 = Gamma::new(df / 2.0, 2.0).expect("Invalid student-t df"); + let mut prng = rng::thread_rng(); let out_slice = std::slice::from_raw_parts_mut(out, len); for elem in out_slice.iter_mut() { - let z: f64 = normal.sample(&mut rng); - let v: f64 = chi2.sample(&mut rng); + let z = rng::sample_normal(&mut prng); + let v = rng::sample_gamma(&mut prng, df / 2.0, 2.0); let val = z / (v / df).sqrt(); *elem = T::from_f64(val); } @@ -192,23 +181,20 @@ pub unsafe fn student_t_kernel(out: *mut T, df: f64, len: usize) { /// Sample from F distribution /// -/// Uses the relationship: F = (X₁/d₁) / (X₂/d₂) -/// where X₁ ~ χ²(d₁) and X₂ ~ χ²(d₂). +/// Uses the relationship: F = (X1/d1) / (X2/d2) +/// where X1 ~ chi2(d1) and X2 ~ chi2(d2). /// /// # Safety /// - `out` must be a valid pointer to `len` elements /// - `df1 > 0` and `df2 > 0` #[inline] pub unsafe fn f_distribution_kernel(out: *mut T, df1: f64, df2: f64, len: usize) { - let mut rng = rand::rng(); - // χ²(df) = Gamma(df/2, 2) - let chi2_1 = Gamma::new(df1 / 2.0, 2.0).expect("Invalid F df1"); - let chi2_2 = Gamma::new(df2 / 2.0, 2.0).expect("Invalid F df2"); + let mut prng = rng::thread_rng(); let out_slice = std::slice::from_raw_parts_mut(out, len); for elem in out_slice.iter_mut() { - let x1: f64 = chi2_1.sample(&mut rng); - let x2: f64 = chi2_2.sample(&mut rng); + let x1 = rng::sample_gamma(&mut prng, df1 / 2.0, 2.0); + let x2 = rng::sample_gamma(&mut prng, df2 / 2.0, 2.0); let val = (x1 / df1) / (x2 / df2); *elem = T::from_f64(val); } @@ -239,7 +225,7 @@ mod tests { // All values should be in (0, 1) assert!(out.iter().all(|&x| x > 0.0 && x < 1.0)); - // Mean should be approximately α/(α+β) = 2/7 ≈ 0.286 + // Mean should be approximately alpha/(alpha+beta) = 2/7 ~ 0.286 let mean: f64 = out.iter().sum::() / 1000.0; assert!((mean - 0.286).abs() < 0.05); } @@ -252,7 +238,7 @@ mod tests { // All values should be positive assert!(out.iter().all(|&x| x > 0.0)); - // Mean should be approximately k*θ = 2 + // Mean should be approximately k*theta = 2 let mean: f64 = out.iter().sum::() / 1000.0; assert!((mean - 2.0).abs() < 0.3); } @@ -265,7 +251,7 @@ mod tests { // All values should be non-negative assert!(out.iter().all(|&x| x >= 0.0)); - // Mean should be approximately 1/λ = 2 + // Mean should be approximately 1/lambda = 2 let mean: f64 = out.iter().sum::() / 1000.0; assert!((mean - 2.0).abs() < 0.4); } @@ -278,7 +264,7 @@ mod tests { // All values should be non-negative integers assert!(out.iter().all(|&x| x >= 0.0 && x == x.floor())); - // Mean should be approximately λ = 5 + // Mean should be approximately lambda = 5 let mean: f64 = out.iter().sum::() / 1000.0; assert!((mean - 5.0).abs() < 0.5); } @@ -337,7 +323,7 @@ mod tests { // All values should be positive assert!(out.iter().all(|&x| x > 0.0)); - // Mean should be approximately d₂/(d₂-2) = 20/18 ≈ 1.11 for d₂ > 2 + // Mean should be approximately d2/(d2-2) = 20/18 ~ 1.11 for d2 > 2 let mean: f64 = out.iter().sum::() / 1000.0; assert!((mean - 1.11).abs() < 0.3); } diff --git a/src/runtime/cpu/kernels/fused_add_norm.rs b/src/runtime/cpu/kernels/fused_add_norm.rs new file mode 100644 index 00000000..2cc5cc1c --- /dev/null +++ b/src/runtime/cpu/kernels/fused_add_norm.rs @@ -0,0 +1,546 @@ +//! Fused Add + Normalization kernels +//! +//! Provides fused add+norm operations with automatic SIMD dispatch. + +use crate::dtype::Element; + +/// Fused Add + RMS Norm kernel: pre_norm = input + residual, output = rms_norm(pre_norm) +#[inline] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_rms_norm_kernel( + input: *const T, + residual: *const T, + weight: *const T, + out: *mut T, + pre_norm: *mut T, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] + { + use super::simd::norm; + use crate::dtype::DType; + match T::DTYPE { + DType::F32 => { + norm::fused_add_rms_norm_f32( + input as *const f32, + residual as *const f32, + weight as *const f32, + out as *mut f32, + pre_norm as *mut f32, + batch_size, + hidden_size, + eps, + ); + return; + } + DType::F64 => { + norm::fused_add_rms_norm_f64( + input as *const f64, + residual as *const f64, + weight as *const f64, + out as *mut f64, + pre_norm as *mut f64, + batch_size, + hidden_size, + eps as f64, + ); + return; + } + #[cfg(feature = "f16")] + DType::F16 => { + norm::fused_add_rms_norm_f16( + input as *const half::f16, + residual as *const half::f16, + weight as *const half::f16, + out as *mut half::f16, + pre_norm as *mut half::f16, + batch_size, + hidden_size, + eps, + ); + return; + } + #[cfg(feature = "f16")] + DType::BF16 => { + norm::fused_add_rms_norm_bf16( + input as *const half::bf16, + residual as *const half::bf16, + weight as *const half::bf16, + out as *mut half::bf16, + pre_norm as *mut half::bf16, + batch_size, + hidden_size, + eps, + ); + return; + } + _ => {} + } + } + fused_add_rms_norm_scalar( + input, + residual, + weight, + out, + pre_norm, + batch_size, + hidden_size, + eps, + ); +} + +#[inline] +unsafe fn fused_add_rms_norm_scalar( + input: *const T, + residual: *const T, + weight: *const T, + out: *mut T, + pre_norm_out: *mut T, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let eps = eps as f64; + let weight_slice = std::slice::from_raw_parts(weight, hidden_size); + for batch in 0..batch_size { + let row = batch * hidden_size; + let mut sum_sq = 0.0f64; + for i in 0..hidden_size { + let pn = (*input.add(row + i)).to_f64() + (*residual.add(row + i)).to_f64(); + *pre_norm_out.add(row + i) = T::from_f64(pn); + sum_sq += pn * pn; + } + let inv_rms = 1.0 / (sum_sq / hidden_size as f64 + eps).sqrt(); + for (i, &w) in weight_slice.iter().enumerate() { + let pn = (*pre_norm_out.add(row + i)).to_f64(); + *out.add(row + i) = T::from_f64(pn * inv_rms * w.to_f64()); + } + } +} + +/// Backward pass for fused add + RMS norm +#[inline] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_rms_norm_bwd_kernel( + grad: *const T, + pre_norm: *const T, + weight: *const T, + d_input_residual: *mut T, + d_weight: *mut T, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] + { + use super::simd::norm; + use crate::dtype::DType; + match T::DTYPE { + DType::F32 => { + norm::fused_add_rms_norm_bwd_f32( + grad as *const f32, + pre_norm as *const f32, + weight as *const f32, + d_input_residual as *mut f32, + d_weight as *mut f32, + batch_size, + hidden_size, + eps, + ); + return; + } + DType::F64 => { + norm::fused_add_rms_norm_bwd_f64( + grad as *const f64, + pre_norm as *const f64, + weight as *const f64, + d_input_residual as *mut f64, + d_weight as *mut f64, + batch_size, + hidden_size, + eps as f64, + ); + return; + } + #[cfg(feature = "f16")] + DType::F16 => { + norm::fused_add_rms_norm_bwd_f16( + grad as *const half::f16, + pre_norm as *const half::f16, + weight as *const half::f16, + d_input_residual as *mut half::f16, + d_weight as *mut half::f16, + batch_size, + hidden_size, + eps, + ); + return; + } + #[cfg(feature = "f16")] + DType::BF16 => { + norm::fused_add_rms_norm_bwd_bf16( + grad as *const half::bf16, + pre_norm as *const half::bf16, + weight as *const half::bf16, + d_input_residual as *mut half::bf16, + d_weight as *mut half::bf16, + batch_size, + hidden_size, + eps, + ); + return; + } + _ => {} + } + } + fused_add_rms_norm_bwd_scalar( + grad, + pre_norm, + weight, + d_input_residual, + d_weight, + batch_size, + hidden_size, + eps, + ); +} + +#[inline] +unsafe fn fused_add_rms_norm_bwd_scalar( + grad: *const T, + pre_norm: *const T, + weight: *const T, + d_input_residual: *mut T, + d_weight: *mut T, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let eps = eps as f64; + let weight_slice = std::slice::from_raw_parts(weight, hidden_size); + // d_weight is pre-zeroed by caller + for batch in 0..batch_size { + let row = batch * hidden_size; + // Recompute inv_rms + let mut sum_sq = 0.0f64; + for i in 0..hidden_size { + let pn = (*pre_norm.add(row + i)).to_f64(); + sum_sq += pn * pn; + } + let mean_sq = sum_sq / hidden_size as f64; + let inv_rms = 1.0 / (mean_sq + eps).sqrt(); + // Compute dot = sum(grad * weight * pre_norm) + let mut dot = 0.0f64; + for i in 0..hidden_size { + let g = (*grad.add(row + i)).to_f64(); + let w = weight_slice[i].to_f64(); + let pn = (*pre_norm.add(row + i)).to_f64(); + dot += g * w * pn; + } + let coeff = dot * inv_rms / (hidden_size as f64 * (mean_sq + eps)); + for i in 0..hidden_size { + let g = (*grad.add(row + i)).to_f64(); + let w = weight_slice[i].to_f64(); + let pn = (*pre_norm.add(row + i)).to_f64(); + let d_ir = (g * w - pn * coeff) * inv_rms; + *d_input_residual.add(row + i) = T::from_f64(d_ir); + // Accumulate d_weight + let dw_old = (*d_weight.add(i)).to_f64(); + *d_weight.add(i) = T::from_f64(dw_old + g * pn * inv_rms); + } + } +} + +/// Fused Add + Layer Norm kernel +#[inline] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_layer_norm_kernel( + input: *const T, + residual: *const T, + weight: *const T, + bias: *const T, + out: *mut T, + pre_norm: *mut T, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] + { + use super::simd::norm; + use crate::dtype::DType; + match T::DTYPE { + DType::F32 => { + norm::fused_add_layer_norm_f32( + input as *const f32, + residual as *const f32, + weight as *const f32, + bias as *const f32, + out as *mut f32, + pre_norm as *mut f32, + batch_size, + hidden_size, + eps, + ); + return; + } + DType::F64 => { + norm::fused_add_layer_norm_f64( + input as *const f64, + residual as *const f64, + weight as *const f64, + bias as *const f64, + out as *mut f64, + pre_norm as *mut f64, + batch_size, + hidden_size, + eps as f64, + ); + return; + } + #[cfg(feature = "f16")] + DType::F16 => { + norm::fused_add_layer_norm_f16( + input as *const half::f16, + residual as *const half::f16, + weight as *const half::f16, + bias as *const half::f16, + out as *mut half::f16, + pre_norm as *mut half::f16, + batch_size, + hidden_size, + eps, + ); + return; + } + #[cfg(feature = "f16")] + DType::BF16 => { + norm::fused_add_layer_norm_bf16( + input as *const half::bf16, + residual as *const half::bf16, + weight as *const half::bf16, + bias as *const half::bf16, + out as *mut half::bf16, + pre_norm as *mut half::bf16, + batch_size, + hidden_size, + eps, + ); + return; + } + _ => {} + } + } + fused_add_layer_norm_scalar( + input, + residual, + weight, + bias, + out, + pre_norm, + batch_size, + hidden_size, + eps, + ); +} + +#[inline] +unsafe fn fused_add_layer_norm_scalar( + input: *const T, + residual: *const T, + weight: *const T, + bias: *const T, + out: *mut T, + pre_norm_out: *mut T, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let eps = eps as f64; + let weight_slice = std::slice::from_raw_parts(weight, hidden_size); + let bias_slice = std::slice::from_raw_parts(bias, hidden_size); + for batch in 0..batch_size { + let row = batch * hidden_size; + // Pass 1: add + compute mean + let mut sum = 0.0f64; + for i in 0..hidden_size { + let pn = (*input.add(row + i)).to_f64() + (*residual.add(row + i)).to_f64(); + *pre_norm_out.add(row + i) = T::from_f64(pn); + sum += pn; + } + let mean = sum / hidden_size as f64; + // Pass 2: variance + let mut var_sum = 0.0f64; + for i in 0..hidden_size { + let pn = (*pre_norm_out.add(row + i)).to_f64(); + let diff = pn - mean; + var_sum += diff * diff; + } + let inv_std = 1.0 / (var_sum / hidden_size as f64 + eps).sqrt(); + // Pass 3: normalize + for i in 0..hidden_size { + let pn = (*pre_norm_out.add(row + i)).to_f64(); + let w = weight_slice[i].to_f64(); + let b = bias_slice[i].to_f64(); + *out.add(row + i) = T::from_f64((pn - mean) * inv_std * w + b); + } + } +} + +/// Backward pass for fused add + layer norm +#[inline] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_layer_norm_bwd_kernel( + grad: *const T, + pre_norm: *const T, + weight: *const T, + _bias: *const T, + d_input_residual: *mut T, + d_weight: *mut T, + d_bias: *mut T, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] + { + use super::simd::norm; + use crate::dtype::DType; + match T::DTYPE { + DType::F32 => { + norm::fused_add_layer_norm_bwd_f32( + grad as *const f32, + pre_norm as *const f32, + weight as *const f32, + d_input_residual as *mut f32, + d_weight as *mut f32, + d_bias as *mut f32, + batch_size, + hidden_size, + eps, + ); + return; + } + DType::F64 => { + norm::fused_add_layer_norm_bwd_f64( + grad as *const f64, + pre_norm as *const f64, + weight as *const f64, + d_input_residual as *mut f64, + d_weight as *mut f64, + d_bias as *mut f64, + batch_size, + hidden_size, + eps as f64, + ); + return; + } + #[cfg(feature = "f16")] + DType::F16 => { + norm::fused_add_layer_norm_bwd_f16( + grad as *const half::f16, + pre_norm as *const half::f16, + weight as *const half::f16, + d_input_residual as *mut half::f16, + d_weight as *mut half::f16, + d_bias as *mut half::f16, + batch_size, + hidden_size, + eps, + ); + return; + } + #[cfg(feature = "f16")] + DType::BF16 => { + norm::fused_add_layer_norm_bwd_bf16( + grad as *const half::bf16, + pre_norm as *const half::bf16, + weight as *const half::bf16, + d_input_residual as *mut half::bf16, + d_weight as *mut half::bf16, + d_bias as *mut half::bf16, + batch_size, + hidden_size, + eps, + ); + return; + } + _ => {} + } + } + fused_add_layer_norm_bwd_scalar( + grad, + pre_norm, + weight, + d_input_residual, + d_weight, + d_bias, + batch_size, + hidden_size, + eps, + ); +} + +#[inline] +unsafe fn fused_add_layer_norm_bwd_scalar( + grad: *const T, + pre_norm: *const T, + weight: *const T, + d_input_residual: *mut T, + d_weight: *mut T, + d_bias: *mut T, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let eps = eps as f64; + let weight_slice = std::slice::from_raw_parts(weight, hidden_size); + // d_weight and d_bias are pre-zeroed + for batch in 0..batch_size { + let row = batch * hidden_size; + // Recompute mean and inv_std from pre_norm + let mut sum = 0.0f64; + for i in 0..hidden_size { + sum += (*pre_norm.add(row + i)).to_f64(); + } + let mean = sum / hidden_size as f64; + let mut var_sum = 0.0f64; + for i in 0..hidden_size { + let diff = (*pre_norm.add(row + i)).to_f64() - mean; + var_sum += diff * diff; + } + let inv_std = 1.0 / (var_sum / hidden_size as f64 + eps).sqrt(); + + // Compute intermediate sums for d_input_residual + let mut mean_gs = 0.0f64; + let mut mean_gs_n = 0.0f64; + for i in 0..hidden_size { + let g = (*grad.add(row + i)).to_f64(); + let w = weight_slice[i].to_f64(); + let pn = (*pre_norm.add(row + i)).to_f64(); + let normalized = (pn - mean) * inv_std; + let gs = g * w; + mean_gs += gs; + mean_gs_n += gs * normalized; + } + mean_gs /= hidden_size as f64; + mean_gs_n /= hidden_size as f64; + + for i in 0..hidden_size { + let g = (*grad.add(row + i)).to_f64(); + let w = weight_slice[i].to_f64(); + let pn = (*pre_norm.add(row + i)).to_f64(); + let normalized = (pn - mean) * inv_std; + let gs = g * w; + let d_ir = inv_std * (gs - mean_gs - normalized * mean_gs_n); + *d_input_residual.add(row + i) = T::from_f64(d_ir); + // Accumulate d_weight and d_bias + let dw_old = (*d_weight.add(i)).to_f64(); + *d_weight.add(i) = T::from_f64(dw_old + g * normalized); + let db_old = (*d_bias.add(i)).to_f64(); + *d_bias.add(i) = T::from_f64(db_old + g); + } + } +} diff --git a/src/runtime/cpu/kernels/fused_elementwise.rs b/src/runtime/cpu/kernels/fused_elementwise.rs new file mode 100644 index 00000000..001008ea --- /dev/null +++ b/src/runtime/cpu/kernels/fused_elementwise.rs @@ -0,0 +1,237 @@ +//! Fused elementwise kernel entry points +//! +//! - fused_mul_add: out = a * b + c +//! - fused_add_mul: out = (a + b) * c +//! - fused_mul_add_scalar: out = a * scale + bias + +use crate::dtype::Element; + +/// Fused multiply-add: `out[i] = a[i] * b[i] + c[i]` +/// +/// # Safety +/// - `a`, `b`, `c`, and `out` must be valid pointers to `len` elements +#[inline] +pub unsafe fn fused_mul_add_kernel( + a: *const T, + b: *const T, + c: *const T, + out: *mut T, + len: usize, +) { + #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] + { + use super::simd::fused_elementwise; + use crate::dtype::DType; + + match T::DTYPE { + DType::F32 => { + fused_elementwise::fused_mul_add_f32( + a as *const f32, + b as *const f32, + c as *const f32, + out as *mut f32, + len, + ); + return; + } + DType::F64 => { + fused_elementwise::fused_mul_add_f64( + a as *const f64, + b as *const f64, + c as *const f64, + out as *mut f64, + len, + ); + return; + } + #[cfg(feature = "f16")] + DType::F16 => { + fused_elementwise::fused_mul_add_f16( + a as *const half::f16, + b as *const half::f16, + c as *const half::f16, + out as *mut half::f16, + len, + ); + return; + } + #[cfg(feature = "f16")] + DType::BF16 => { + fused_elementwise::fused_mul_add_bf16( + a as *const half::bf16, + b as *const half::bf16, + c as *const half::bf16, + out as *mut half::bf16, + len, + ); + return; + } + _ => {} + } + } + + fused_ternary_scalar(a, b, c, out, len, |x, y, z| x * y + z); +} + +/// Fused add-multiply: `out[i] = (a[i] + b[i]) * c[i]` +/// +/// # Safety +/// - `a`, `b`, `c`, and `out` must be valid pointers to `len` elements +#[inline] +pub unsafe fn fused_add_mul_kernel( + a: *const T, + b: *const T, + c: *const T, + out: *mut T, + len: usize, +) { + #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] + { + use super::simd::fused_elementwise; + use crate::dtype::DType; + + match T::DTYPE { + DType::F32 => { + fused_elementwise::fused_add_mul_f32( + a as *const f32, + b as *const f32, + c as *const f32, + out as *mut f32, + len, + ); + return; + } + DType::F64 => { + fused_elementwise::fused_add_mul_f64( + a as *const f64, + b as *const f64, + c as *const f64, + out as *mut f64, + len, + ); + return; + } + #[cfg(feature = "f16")] + DType::F16 => { + fused_elementwise::fused_add_mul_f16( + a as *const half::f16, + b as *const half::f16, + c as *const half::f16, + out as *mut half::f16, + len, + ); + return; + } + #[cfg(feature = "f16")] + DType::BF16 => { + fused_elementwise::fused_add_mul_bf16( + a as *const half::bf16, + b as *const half::bf16, + c as *const half::bf16, + out as *mut half::bf16, + len, + ); + return; + } + _ => {} + } + } + + fused_ternary_scalar(a, b, c, out, len, |x, y, z| (x + y) * z); +} + +/// Fused multiply-add scalar: `out[i] = a[i] * scale + bias` +/// +/// # Safety +/// - `a` and `out` must be valid pointers to `len` elements +#[inline] +pub unsafe fn fused_mul_add_scalar_kernel( + a: *const T, + scale: f64, + bias: f64, + out: *mut T, + len: usize, +) { + #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] + { + use super::simd::fused_elementwise; + use crate::dtype::DType; + + match T::DTYPE { + DType::F32 => { + fused_elementwise::fused_mul_add_scalar_f32_kernel( + a as *const f32, + scale as f32, + bias as f32, + out as *mut f32, + len, + ); + return; + } + DType::F64 => { + fused_elementwise::fused_mul_add_scalar_f64_kernel( + a as *const f64, + scale, + bias, + out as *mut f64, + len, + ); + return; + } + #[cfg(feature = "f16")] + DType::F16 => { + fused_elementwise::fused_mul_add_scalar_f32_f16( + a as *const half::f16, + scale as f32, + bias as f32, + out as *mut half::f16, + len, + ); + return; + } + #[cfg(feature = "f16")] + DType::BF16 => { + fused_elementwise::fused_mul_add_scalar_f32_bf16( + a as *const half::bf16, + scale as f32, + bias as f32, + out as *mut half::bf16, + len, + ); + return; + } + _ => {} + } + } + + // Scalar fallback + let a_slice = std::slice::from_raw_parts(a, len); + let out_slice = std::slice::from_raw_parts_mut(out, len); + for i in 0..len { + let val = a_slice[i].to_f64(); + out_slice[i] = T::from_f64(val * scale + bias); + } +} + +/// Generic scalar fallback for ternary fused ops +#[inline] +unsafe fn fused_ternary_scalar f64>( + a: *const T, + b: *const T, + c: *const T, + out: *mut T, + len: usize, + op: F, +) { + let a_slice = std::slice::from_raw_parts(a, len); + let b_slice = std::slice::from_raw_parts(b, len); + let c_slice = std::slice::from_raw_parts(c, len); + let out_slice = std::slice::from_raw_parts_mut(out, len); + + for i in 0..len { + let x = a_slice[i].to_f64(); + let y = b_slice[i].to_f64(); + let z = c_slice[i].to_f64(); + out_slice[i] = T::from_f64(op(x, y, z)); + } +} diff --git a/src/runtime/cpu/kernels/gemm_epilogue/backward.rs b/src/runtime/cpu/kernels/gemm_epilogue/backward.rs new file mode 100644 index 00000000..8f1fe7dc --- /dev/null +++ b/src/runtime/cpu/kernels/gemm_epilogue/backward.rs @@ -0,0 +1,281 @@ +//! Backward kernel for GEMM epilogue operations. +//! +//! Computes gradients for `activation(A @ B + bias)`. +//! Accumulation is done in f32 for sub-f32 types (F16, BF16) and in native +//! precision for F32/F64, matching standard ML framework practice. + +use crate::dtype::{DType, Element}; +use crate::ops::GemmActivation; + +/// Float type used for accumulation in backward pass. +/// +/// Only f32 and f64 are used as accumulation types. This trait provides +/// the minimal interface needed for the backward kernel to be generic +/// over both precisions. +trait AccFloat: + Copy + + std::ops::Add + + std::ops::AddAssign + + std::ops::Sub + + std::ops::Mul + + std::ops::Neg + + PartialOrd +{ + fn zero() -> Self; + fn one() -> Self; + fn half() -> Self; + fn from_elem(v: T) -> Self; + fn to_elem(self) -> T; + fn tanh(self) -> Self; + fn exp(self) -> Self; + fn recip(self) -> Self; + fn from_f64_const(v: f64) -> Self; + fn is_finite(self) -> bool; +} + +impl AccFloat for f32 { + #[inline] + fn zero() -> Self { + 0.0 + } + #[inline] + fn one() -> Self { + 1.0 + } + #[inline] + fn half() -> Self { + 0.5 + } + #[inline] + fn from_elem(v: T) -> Self { + v.to_f32() + } + #[inline] + fn to_elem(self) -> T { + T::from_f32(self) + } + #[inline] + fn tanh(self) -> Self { + f32::tanh(self) + } + #[inline] + fn exp(self) -> Self { + f32::exp(self) + } + #[inline] + fn recip(self) -> Self { + 1.0 / self + } + #[inline] + fn from_f64_const(v: f64) -> Self { + v as f32 + } + #[inline] + fn is_finite(self) -> bool { + f32::is_finite(self) + } +} + +impl AccFloat for f64 { + #[inline] + fn zero() -> Self { + 0.0 + } + #[inline] + fn one() -> Self { + 1.0 + } + #[inline] + fn half() -> Self { + 0.5 + } + #[inline] + fn from_elem(v: T) -> Self { + v.to_f64() + } + #[inline] + fn to_elem(self) -> T { + T::from_f64(self) + } + #[inline] + fn tanh(self) -> Self { + f64::tanh(self) + } + #[inline] + fn exp(self) -> Self { + f64::exp(self) + } + #[inline] + fn recip(self) -> Self { + 1.0 / self + } + #[inline] + fn from_f64_const(v: f64) -> Self { + v + } + #[inline] + fn is_finite(self) -> bool { + f64::is_finite(self) + } +} + +/// Backward pass for fused matmul + bias + activation. +/// +/// Given `output = activation(A @ B + bias)`, computes: +/// - `d_a = (grad * activation'(pre_act)) @ B^T` +/// - `d_b = A^T @ (grad * activation'(pre_act))` +/// - `d_bias = sum(grad * activation'(pre_act), dim=0)` +/// +/// where `pre_act = A @ B + bias`. +/// +/// # Safety +/// - All pointers must be valid for the specified dimensions +/// - Output pointers must not alias with input pointers +#[inline] +#[allow(clippy::too_many_arguments)] +pub unsafe fn matmul_bias_activation_bwd_kernel( + grad: *const T, + a: *const T, + b: *const T, + bias: *const T, + d_a: *mut T, + d_b: *mut T, + d_bias: *mut T, + m: usize, + n: usize, + k: usize, + lda: usize, + ldb: usize, + ld_grad: usize, + activation: GemmActivation, +) { + if T::DTYPE == DType::F64 { + bwd_in::( + grad, a, b, bias, d_a, d_b, d_bias, m, n, k, lda, ldb, ld_grad, activation, + ); + } else { + bwd_in::( + grad, a, b, bias, d_a, d_b, d_bias, m, n, k, lda, ldb, ld_grad, activation, + ); + } +} + +#[inline] +#[allow(clippy::too_many_arguments)] +unsafe fn bwd_in( + grad: *const T, + a: *const T, + b: *const T, + bias: *const T, + d_a: *mut T, + d_b: *mut T, + d_bias: *mut T, + m: usize, + n: usize, + k: usize, + lda: usize, + ldb: usize, + ld_grad: usize, + activation: GemmActivation, +) { + let total = m * n; + + // Step 1: pre_act = A @ B + bias, then grad_pre = grad * activation'(pre_act) + let mut grad_pre = vec![A::zero(); total]; + for i in 0..m { + for j in 0..n { + grad_pre[i * n + j] = A::from_elem(*bias.add(j)); + } + } + for i in 0..m { + for kk in 0..k { + let a_val: A = A::from_elem(*a.add(i * lda + kk)); + for j in 0..n { + grad_pre[i * n + j] += a_val * A::from_elem(*b.add(kk * ldb + j)); + } + } + } + for i in 0..total { + let g: A = A::from_elem(*grad.add((i / n) * ld_grad + (i % n))); + let deriv = activation_derivative(grad_pre[i], activation); + // Guard against non-finite derivatives from platform-specific FP edge cases + let deriv = if deriv.is_finite() { deriv } else { A::zero() }; + grad_pre[i] = g * deriv; + } + + // Step 2: d_a = grad_pre @ B^T + let mut d_a_buf = vec![A::zero(); m * k]; + for i in 0..m { + for j in 0..n { + let gp = grad_pre[i * n + j]; + for kk in 0..k { + d_a_buf[i * k + kk] += gp * A::from_elem(*b.add(kk * ldb + j)); + } + } + } + for i in 0..m * k { + *d_a.add(i) = d_a_buf[i].to_elem::(); + } + + // Step 3: d_b = A^T @ grad_pre + let mut d_b_buf = vec![A::zero(); k * n]; + for i in 0..m { + for kk in 0..k { + let a_val: A = A::from_elem(*a.add(i * lda + kk)); + for j in 0..n { + d_b_buf[kk * n + j] += a_val * grad_pre[i * n + j]; + } + } + } + for i in 0..k * n { + *d_b.add(i) = d_b_buf[i].to_elem::(); + } + + // Step 4: d_bias = sum(grad_pre, dim=0) + let mut d_bias_buf = vec![A::zero(); n]; + for i in 0..m { + for j in 0..n { + d_bias_buf[j] += grad_pre[i * n + j]; + } + } + for j in 0..n { + *d_bias.add(j) = d_bias_buf[j].to_elem::(); + } +} + +/// Compute activation derivative at the pre-activation value. +fn activation_derivative(pre_act: A, activation: GemmActivation) -> A { + match activation { + GemmActivation::None => A::one(), + GemmActivation::ReLU => { + if pre_act > A::zero() { + A::one() + } else { + A::zero() + } + } + GemmActivation::GELU => { + let sqrt_2_over_pi = A::from_f64_const(0.7978845608028654); + let coef = A::from_f64_const(0.044715); + let three = A::from_f64_const(3.0); + let x = pre_act; + let inner = sqrt_2_over_pi * (x + coef * x * x * x); + let tanh_val = inner.tanh(); + let sech2 = A::one() - tanh_val * tanh_val; + let d_inner = sqrt_2_over_pi * (A::one() + three * coef * x * x); + A::half() * (A::one() + tanh_val) + A::half() * x * sech2 * d_inner + } + GemmActivation::SiLU => { + let sig = (A::one() + (-pre_act).exp()).recip(); + sig + pre_act * sig * (A::one() - sig) + } + GemmActivation::Sigmoid => { + let sig = (A::one() + (-pre_act).exp()).recip(); + sig * (A::one() - sig) + } + GemmActivation::Tanh => { + let t = pre_act.tanh(); + A::one() - t * t + } + } +} diff --git a/src/runtime/cpu/kernels/gemm_epilogue/forward.rs b/src/runtime/cpu/kernels/gemm_epilogue/forward.rs new file mode 100644 index 00000000..39c78e29 --- /dev/null +++ b/src/runtime/cpu/kernels/gemm_epilogue/forward.rs @@ -0,0 +1,330 @@ +//! Forward kernels for GEMM epilogue operations. +//! +//! matmul_bias_activation: C = activation(A @ B + bias) +//! matmul_bias_residual: C = A @ B + bias + residual + +use crate::dtype::Element; +use crate::ops::GemmActivation; + +/// Fused matmul + bias + activation kernel. +/// +/// Computes `activation(A @ B + bias)` in a single pass: +/// 1. Initialize output with bias +/// 2. Accumulate matmul result (ikj order) +/// 3. Apply activation in-place +/// +/// # Safety +/// - All pointers must be valid for the specified dimensions +/// - `out` must not alias with `a`, `b`, or `bias` +#[inline] +#[allow(clippy::too_many_arguments)] +pub unsafe fn matmul_bias_activation_kernel( + a: *const T, + b: *const T, + bias: *const T, + out: *mut T, + m: usize, + n: usize, + k: usize, + lda: usize, + ldb: usize, + ldc: usize, + activation: GemmActivation, +) { + // For GemmActivation::None, just do matmul_bias (avoid activation dispatch overhead) + if activation == GemmActivation::None { + crate::runtime::cpu::kernels::matmul_bias_kernel(a, b, bias, out, m, n, k, lda, ldb, ldc); + return; + } + + // SIMD dispatch for f32/f64 on x86_64: matmul_bias first, then apply activation via SIMD + #[cfg(target_arch = "x86_64")] + { + use crate::dtype::DType; + match T::DTYPE { + DType::F32 => { + matmul_bias_activation_simd_f32( + a as *const f32, + b as *const f32, + bias as *const f32, + out as *mut f32, + m, + n, + k, + lda, + ldb, + ldc, + activation, + ); + return; + } + DType::F64 => { + matmul_bias_activation_simd_f64( + a as *const f64, + b as *const f64, + bias as *const f64, + out as *mut f64, + m, + n, + k, + lda, + ldb, + ldc, + activation, + ); + return; + } + _ => {} // Fall through to scalar + } + } + + matmul_bias_activation_scalar(a, b, bias, out, m, n, k, lda, ldb, ldc, activation); +} + +/// Fused matmul + bias + residual kernel. +/// +/// Computes `A @ B + bias + residual` in a single pass. +/// +/// # Safety +/// - All pointers must be valid for the specified dimensions +/// - `out` must not alias with `a`, `b`, `bias`, or `residual` +#[inline] +#[allow(clippy::too_many_arguments)] +pub unsafe fn matmul_bias_residual_kernel( + a: *const T, + b: *const T, + bias: *const T, + residual: *const T, + out: *mut T, + m: usize, + n: usize, + k: usize, + lda: usize, + ldb: usize, + ldc: usize, +) { + // Initialize output with bias + residual + for i in 0..m { + for j in 0..n { + *out.add(i * ldc + j) = *bias.add(j) + *residual.add(i * ldc + j); + } + } + + // Accumulate matmul result (ikj order for cache locality) + for i in 0..m { + for kk in 0..k { + let a_val = *a.add(i * lda + kk); + for j in 0..n { + let out_ptr = out.add(i * ldc + j); + *out_ptr = *out_ptr + a_val * *b.add(kk * ldb + j); + } + } + } +} + +// ============================================================================ +// SIMD-accelerated paths (matmul_bias then SIMD activation) +// ============================================================================ + +#[cfg(target_arch = "x86_64")] +#[allow(clippy::too_many_arguments, dead_code)] +unsafe fn matmul_bias_activation_simd_f32( + a: *const f32, + b: *const f32, + bias: *const f32, + out: *mut f32, + m: usize, + n: usize, + k: usize, + lda: usize, + ldb: usize, + ldc: usize, + activation: GemmActivation, +) { + use super::super::simd::matmul; + + // Step 1: Compute matmul_bias into output buffer + matmul::matmul_bias_f32(a, b, bias, out, m, n, k, lda, ldb, ldc); + + // Step 2: Apply activation in-place using SIMD + let total = m * n; + apply_activation_inplace_f32(out, total, activation); +} + +#[cfg(target_arch = "x86_64")] +#[allow(clippy::too_many_arguments, dead_code)] +unsafe fn matmul_bias_activation_simd_f64( + a: *const f64, + b: *const f64, + bias: *const f64, + out: *mut f64, + m: usize, + n: usize, + k: usize, + lda: usize, + ldb: usize, + ldc: usize, + activation: GemmActivation, +) { + use super::super::simd::matmul; + + // Step 1: Compute matmul_bias into output buffer + matmul::matmul_bias_f64(a, b, bias, out, m, n, k, lda, ldb, ldc); + + // Step 2: Apply activation in-place using SIMD + let total = m * n; + apply_activation_inplace_f64(out, total, activation); +} + +/// Apply activation in-place on f32 buffer using SIMD helpers. +#[cfg(target_arch = "x86_64")] +#[allow(dead_code)] +unsafe fn apply_activation_inplace_f32(buf: *mut f32, len: usize, activation: GemmActivation) { + use super::super::simd::activations; + + match activation { + GemmActivation::None => {} + GemmActivation::ReLU => { + // ReLU is simple: max(0, x) — use scalar for in-place + for i in 0..len { + let val = *buf.add(i); + if val < 0.0 { + *buf.add(i) = 0.0; + } + } + } + GemmActivation::GELU => { + // Use SIMD gelu (reads from buf, writes to buf — safe since non-overlapping access) + activations::gelu_f32(buf as *const f32, buf, len); + } + GemmActivation::SiLU => { + activations::silu_f32(buf as *const f32, buf, len); + } + GemmActivation::Sigmoid => { + activations::sigmoid_f32(buf as *const f32, buf, len); + } + GemmActivation::Tanh => { + for i in 0..len { + *buf.add(i) = (*buf.add(i)).tanh(); + } + } + } +} + +/// Apply activation in-place on f64 buffer using SIMD helpers. +#[cfg(target_arch = "x86_64")] +#[allow(dead_code)] +unsafe fn apply_activation_inplace_f64(buf: *mut f64, len: usize, activation: GemmActivation) { + use super::super::simd::activations; + + match activation { + GemmActivation::None => {} + GemmActivation::ReLU => { + for i in 0..len { + let val = *buf.add(i); + if val < 0.0 { + *buf.add(i) = 0.0; + } + } + } + GemmActivation::GELU => { + activations::gelu_f64(buf as *const f64, buf, len); + } + GemmActivation::SiLU => { + activations::silu_f64(buf as *const f64, buf, len); + } + GemmActivation::Sigmoid => { + activations::sigmoid_f64(buf as *const f64, buf, len); + } + GemmActivation::Tanh => { + for i in 0..len { + *buf.add(i) = (*buf.add(i)).tanh(); + } + } + } +} + +// ============================================================================ +// Scalar fallback +// ============================================================================ + +#[allow(clippy::too_many_arguments, dead_code)] +unsafe fn matmul_bias_activation_scalar( + a: *const T, + b: *const T, + bias: *const T, + out: *mut T, + m: usize, + n: usize, + k: usize, + lda: usize, + ldb: usize, + ldc: usize, + activation: GemmActivation, +) { + // Initialize output with bias + for i in 0..m { + for j in 0..n { + *out.add(i * ldc + j) = *bias.add(j); + } + } + + // Accumulate matmul result (ikj order) + for i in 0..m { + for kk in 0..k { + let a_val = *a.add(i * lda + kk); + for j in 0..n { + let out_ptr = out.add(i * ldc + j); + *out_ptr = *out_ptr + a_val * *b.add(kk * ldb + j); + } + } + } + + // Apply activation in-place + apply_activation_scalar(out, m * n, activation); +} + +/// Apply activation element-wise using scalar math (generic over Element). +#[allow(dead_code)] +unsafe fn apply_activation_scalar(buf: *mut T, len: usize, activation: GemmActivation) { + match activation { + GemmActivation::None => {} + GemmActivation::ReLU => { + for i in 0..len { + let val = *buf.add(i); + if val < T::zero() { + *buf.add(i) = T::zero(); + } + } + } + GemmActivation::GELU => { + // GELU needs float math — convert through f64 + for i in 0..len { + let x = (*buf.add(i)).to_f64(); + let inner = 0.7978845608028654 * (x + 0.044715 * x * x * x); + let result = 0.5 * x * (1.0 + inner.tanh()); + *buf.add(i) = T::from_f64(result); + } + } + GemmActivation::SiLU => { + for i in 0..len { + let x = (*buf.add(i)).to_f64(); + let result = x / (1.0 + (-x).exp()); + *buf.add(i) = T::from_f64(result); + } + } + GemmActivation::Sigmoid => { + for i in 0..len { + let x = (*buf.add(i)).to_f64(); + let result = 1.0 / (1.0 + (-x).exp()); + *buf.add(i) = T::from_f64(result); + } + } + GemmActivation::Tanh => { + for i in 0..len { + let x = (*buf.add(i)).to_f64(); + *buf.add(i) = T::from_f64(x.tanh()); + } + } + } +} diff --git a/src/runtime/cpu/kernels/gemm_epilogue/mod.rs b/src/runtime/cpu/kernels/gemm_epilogue/mod.rs new file mode 100644 index 00000000..a2c5415b --- /dev/null +++ b/src/runtime/cpu/kernels/gemm_epilogue/mod.rs @@ -0,0 +1,9 @@ +//! GEMM epilogue CPU kernels +//! +//! Fused matmul + bias + activation/residual kernels. + +pub mod backward; +pub mod forward; + +pub use backward::matmul_bias_activation_bwd_kernel; +pub use forward::{matmul_bias_activation_kernel, matmul_bias_residual_kernel}; diff --git a/src/runtime/cpu/kernels/index.rs b/src/runtime/cpu/kernels/index.rs index 5fb096d2..5e1da487 100644 --- a/src/runtime/cpu/kernels/index.rs +++ b/src/runtime/cpu/kernels/index.rs @@ -871,3 +871,35 @@ pub unsafe fn gather_2d_kernel( true } + +/// Slice assign kernel: copies src into a slice of dst along a dimension. +/// +/// dst is first fully copied to output, then src overwrites the slice region. +/// +/// # Safety +/// +/// All pointers must be valid with the correct element counts. +pub unsafe fn slice_assign_kernel( + dst: *const T, + src: *const T, + out: *mut T, + outer_size: usize, + dst_dim_size: usize, + src_dim_size: usize, + inner_size: usize, + start: usize, +) { + let dst_total = outer_size * dst_dim_size * inner_size; + + // Copy entire dst to output + std::ptr::copy_nonoverlapping(dst, out, dst_total); + + // Overwrite the slice region with src + for o in 0..outer_size { + for s in 0..src_dim_size { + let src_offset = o * src_dim_size * inner_size + s * inner_size; + let dst_offset = o * dst_dim_size * inner_size + (start + s) * inner_size; + std::ptr::copy_nonoverlapping(src.add(src_offset), out.add(dst_offset), inner_size); + } + } +} diff --git a/src/runtime/cpu/kernels/matmul.rs b/src/runtime/cpu/kernels/matmul.rs index 3c6e2365..7a52f974 100644 --- a/src/runtime/cpu/kernels/matmul.rs +++ b/src/runtime/cpu/kernels/matmul.rs @@ -3,7 +3,379 @@ //! This module provides matrix multiplication with automatic SIMD dispatch. //! On x86-64, f32 and f64 matmuls use AVX-512 or AVX2+FMA when available. -use crate::dtype::{DType, Element}; +use crate::dtype::Element; + +/// SIMD-accelerated f32 dot product for use in half-precision GEMV-BT. +/// +/// Dispatches to AVX-512 or AVX2+FMA based on detected SIMD level. +/// Each backend is a separate function with `#[target_feature]` so the compiler +/// can optimize the entire function body for that ISA. +/// +/// # Safety +/// - `a` and `b` must be valid pointers to `len` f32 elements +#[cfg(all(feature = "f16", target_arch = "x86_64"))] +#[inline] +unsafe fn simd_dot_f32( + a: *const f32, + b: *const f32, + len: usize, + level: super::simd::SimdLevel, +) -> f32 { + use super::simd::SimdLevel; + + match level { + SimdLevel::Avx512 => simd_dot_f32_avx512(a, b, len), + SimdLevel::Avx2Fma => simd_dot_f32_avx2(a, b, len), + _ => { + let mut sum = 0.0f32; + for i in 0..len { + sum += *a.add(i) * *b.add(i); + } + sum + } + } +} + +#[cfg(all(feature = "f16", target_arch = "x86_64"))] +#[target_feature(enable = "avx512f")] +unsafe fn simd_dot_f32_avx512(a: *const f32, b: *const f32, len: usize) -> f32 { + use std::arch::x86_64::*; + let mut offset = 0; + let mut acc0 = _mm512_setzero_ps(); + let mut acc1 = _mm512_setzero_ps(); + while offset + 32 <= len { + let av0 = _mm512_loadu_ps(a.add(offset)); + let bv0 = _mm512_loadu_ps(b.add(offset)); + acc0 = _mm512_fmadd_ps(av0, bv0, acc0); + let av1 = _mm512_loadu_ps(a.add(offset + 16)); + let bv1 = _mm512_loadu_ps(b.add(offset + 16)); + acc1 = _mm512_fmadd_ps(av1, bv1, acc1); + offset += 32; + } + acc0 = _mm512_add_ps(acc0, acc1); + while offset + 16 <= len { + let av = _mm512_loadu_ps(a.add(offset)); + let bv = _mm512_loadu_ps(b.add(offset)); + acc0 = _mm512_fmadd_ps(av, bv, acc0); + offset += 16; + } + let mut sum = _mm512_reduce_add_ps(acc0); + while offset < len { + sum += *a.add(offset) * *b.add(offset); + offset += 1; + } + sum +} + +#[cfg(all(feature = "f16", target_arch = "x86_64"))] +#[target_feature(enable = "avx2", enable = "fma")] +unsafe fn simd_dot_f32_avx2(a: *const f32, b: *const f32, len: usize) -> f32 { + use std::arch::x86_64::*; + let mut offset = 0; + let mut acc0 = _mm256_setzero_ps(); + let mut acc1 = _mm256_setzero_ps(); + // Process 16 floats per iteration with 2 independent accumulators + // to hide FMA latency (4-5 cycles on modern x86) + while offset + 16 <= len { + let av0 = _mm256_loadu_ps(a.add(offset)); + let bv0 = _mm256_loadu_ps(b.add(offset)); + acc0 = _mm256_fmadd_ps(av0, bv0, acc0); + let av1 = _mm256_loadu_ps(a.add(offset + 8)); + let bv1 = _mm256_loadu_ps(b.add(offset + 8)); + acc1 = _mm256_fmadd_ps(av1, bv1, acc1); + offset += 16; + } + acc0 = _mm256_add_ps(acc0, acc1); + // Handle remaining 8-float chunk + while offset + 8 <= len { + let av = _mm256_loadu_ps(a.add(offset)); + let bv = _mm256_loadu_ps(b.add(offset)); + acc0 = _mm256_fmadd_ps(av, bv, acc0); + offset += 8; + } + // Horizontal sum of 256-bit accumulator + let hi = _mm256_extractf128_ps(acc0, 1); + let lo = _mm256_castps256_ps128(acc0); + let sum128 = _mm_add_ps(lo, hi); + let shuf = _mm_movehdup_ps(sum128); + let sums = _mm_add_ps(sum128, shuf); + let shuf2 = _mm_movehl_ps(sums, sums); + let sums2 = _mm_add_ss(sums, shuf2); + let mut sum = _mm_cvtss_f32(sums2); + while offset < len { + sum += *a.add(offset) * *b.add(offset); + offset += 1; + } + sum +} + +/// GEMV-BT kernel: C[M,N] = A[M,K] @ B^T where B is stored as contiguous [N,K] +/// +/// This avoids the costly contiguous copy of transposed weight matrices during +/// decode (M=1). Both A rows and B rows are contiguous, making this ideal for +/// SIMD dot products. +/// +/// # Arguments +/// * `a` - Pointer to matrix A (m × k), contiguous row-major +/// * `b_nk` - Pointer to B in [N,K] layout (NOT the transposed view) +/// * `out` - Pointer to output C (m × n), row-major with leading dimension ldc +/// * `m`, `n`, `k` - Matrix dimensions +/// * `ldc` - Leading dimension of output +/// +/// # Safety +/// - `a` must be valid for m*k contiguous reads +/// - `b_nk` must be valid for n*k contiguous reads +/// - `out` must be valid for m*ldc writes +#[inline] +#[allow(clippy::too_many_arguments)] +pub unsafe fn gemv_bt_kernel( + a: *const T, + b_nk: *const T, + out: *mut T, + m: usize, + n: usize, + k: usize, + ldc: usize, +) { + #[cfg(target_arch = "x86_64")] + { + use super::simd::detect_simd; + use super::simd::matmul::gemv_bt; + use crate::dtype::DType; + + match T::DTYPE { + DType::F32 => { + let level = detect_simd(); + gemv_bt::gemv_bt_f32( + a as *const f32, + b_nk as *const f32, + out as *mut f32, + m, + n, + k, + ldc, + level, + ); + return; + } + DType::F64 => { + let level = detect_simd(); + gemv_bt::gemv_bt_f64( + a as *const f64, + b_nk as *const f64, + out as *mut f64, + m, + n, + k, + ldc, + level, + ); + return; + } + #[cfg(feature = "f16")] + DType::F16 | DType::BF16 => { + gemv_bt_via_f32(a, b_nk, out, m, n, k, ldc); + return; + } + _ => {} + } + } + + #[cfg(not(target_arch = "x86_64"))] + { + #[allow(unused_imports)] + use crate::dtype::DType; + match T::DTYPE { + #[cfg(feature = "f16")] + DType::F16 | DType::BF16 => { + gemv_bt_via_f32(a, b_nk, out, m, n, k, ldc); + return; + } + _ => {} + } + } + + // Scalar fallback + gemv_bt_scalar(a, b_nk, out, m, n, k, ldc); +} + +/// Scalar GEMV-BT fallback +#[inline] +#[allow(clippy::too_many_arguments)] +unsafe fn gemv_bt_scalar( + a: *const T, + b_nk: *const T, + out: *mut T, + m: usize, + n: usize, + k: usize, + ldc: usize, +) { + for row in 0..m { + let a_row = a.add(row * k); + let out_row = out.add(row * ldc); + for col in 0..n { + let b_row = b_nk.add(col * k); + let mut sum = T::zero(); + for i in 0..k { + sum = sum + *a_row.add(i) * *b_row.add(i); + } + *out_row.add(col) = sum; + } + } +} + +/// GEMV-BT for f16/bf16 via f32 conversion +/// +/// Converts A row to f32 (batch SIMD conversion), then converts each B row +/// to f32 in SIMD chunks and uses the f32 AVX2/AVX-512 dot product. +#[cfg(feature = "f16")] +#[inline] +#[allow(clippy::too_many_arguments)] +unsafe fn gemv_bt_via_f32( + a: *const T, + b_nk: *const T, + out: *mut T, + m: usize, + n: usize, + k: usize, + ldc: usize, +) { + // Convert A row to f32 once (small buffer, reused per row) + let mut a_f32 = vec![0.0f32; k]; + let mut b_f32 = vec![0.0f32; k]; + + #[cfg(target_arch = "x86_64")] + let level = super::simd::detect_simd(); + + for row in 0..m { + let a_row = a.add(row * k); + // Batch convert A row to f32 + batch_half_to_f32::(a_row, a_f32.as_mut_ptr(), k); + + let out_row = out.add(row * ldc); + + for col in 0..n { + let b_row = b_nk.add(col * k); + // Batch convert B row to f32 + batch_half_to_f32::(b_row, b_f32.as_mut_ptr(), k); + + // Use SIMD f32 dot product + #[cfg(target_arch = "x86_64")] + { + let dot = simd_dot_f32(a_f32.as_ptr(), b_f32.as_ptr(), k, level); + *out_row.add(col) = T::from_f32(dot); + } + #[cfg(not(target_arch = "x86_64"))] + { + let mut sum = 0.0f32; + for i in 0..k { + sum += a_f32[i] * b_f32[i]; + } + *out_row.add(col) = T::from_f32(sum); + } + } + } +} + +/// Batch convert half-precision (f16/bf16) elements to f32 using SIMD when available. +#[cfg(feature = "f16")] +#[inline] +unsafe fn batch_half_to_f32(src: *const T, dst: *mut f32, len: usize) { + use crate::dtype::DType; + match T::DTYPE { + #[cfg(target_arch = "x86_64")] + DType::BF16 => { + // BF16 → f32: shift left by 16 bits (bf16 is upper 16 bits of f32) + batch_bf16_to_f32(src as *const u16, dst, len); + } + #[cfg(target_arch = "x86_64")] + DType::F16 => { + // F16 → f32: use F16C instruction if available + batch_f16_to_f32(src as *const u16, dst, len); + } + _ => { + for i in 0..len { + *dst.add(i) = (*src.add(i)).to_f32(); + } + } + } +} + +/// BF16 → f32 conversion using SIMD bit-shift (bf16 is just f32 with lower 16 bits zeroed) +#[cfg(all(feature = "f16", target_arch = "x86_64"))] +#[inline] +unsafe fn batch_bf16_to_f32(src: *const u16, dst: *mut f32, len: usize) { + if is_x86_feature_detected!("avx2") { + batch_bf16_to_f32_avx2(src, dst, len); + } else { + batch_bf16_to_f32_scalar(src, dst, len); + } +} + +#[cfg(all(feature = "f16", target_arch = "x86_64"))] +#[target_feature(enable = "avx2")] +unsafe fn batch_bf16_to_f32_avx2(src: *const u16, dst: *mut f32, len: usize) { + use std::arch::x86_64::*; + let mut i = 0usize; + while i + 8 <= len { + let bf16_vals = _mm_loadu_si128(src.add(i) as *const __m128i); + let i32_vals = _mm256_cvtepu16_epi32(bf16_vals); + let f32_bits = _mm256_slli_epi32(i32_vals, 16); + _mm256_storeu_ps(dst.add(i), _mm256_castsi256_ps(f32_bits)); + i += 8; + } + // Scalar tail + while i < len { + let bits = (*src.add(i) as u32) << 16; + *dst.add(i) = f32::from_bits(bits); + i += 1; + } +} + +#[cfg(all(feature = "f16", target_arch = "x86_64"))] +unsafe fn batch_bf16_to_f32_scalar(src: *const u16, dst: *mut f32, len: usize) { + for i in 0..len { + let bits = (*src.add(i) as u32) << 16; + *dst.add(i) = f32::from_bits(bits); + } +} + +/// F16 → f32 conversion using F16C instructions (vcvtph2ps) +#[cfg(all(feature = "f16", target_arch = "x86_64"))] +#[inline] +unsafe fn batch_f16_to_f32(src: *const u16, dst: *mut f32, len: usize) { + if is_x86_feature_detected!("f16c") { + batch_f16_to_f32_f16c(src, dst, len); + } else { + batch_f16_to_f32_scalar(src, dst, len); + } +} + +#[cfg(all(feature = "f16", target_arch = "x86_64"))] +#[target_feature(enable = "f16c", enable = "avx")] +unsafe fn batch_f16_to_f32_f16c(src: *const u16, dst: *mut f32, len: usize) { + use std::arch::x86_64::*; + let mut i = 0usize; + while i + 8 <= len { + let f16_vals = _mm_loadu_si128(src.add(i) as *const __m128i); + let f32_vals = _mm256_cvtph_ps(f16_vals); + _mm256_storeu_ps(dst.add(i), f32_vals); + i += 8; + } + // Scalar tail + while i < len { + *dst.add(i) = half::f16::from_bits(*src.add(i)).to_f32(); + i += 1; + } +} + +#[cfg(all(feature = "f16", target_arch = "x86_64"))] +unsafe fn batch_f16_to_f32_scalar(src: *const u16, dst: *mut f32, len: usize) { + for i in 0..len { + *dst.add(i) = half::f16::from_bits(*src.add(i)).to_f32(); + } +} /// Matrix multiplication with automatic SIMD dispatch: C = A @ B /// @@ -39,8 +411,23 @@ pub unsafe fn matmul_kernel( #[cfg(target_arch = "x86_64")] { use super::simd::matmul; + use crate::dtype::DType; match T::DTYPE { + DType::I32 => { + matmul::int32::matmul_i32( + a as *const i32, + b as *const i32, + out as *mut i32, + m, + n, + k, + lda, + ldb, + ldc, + ); + return; + } DType::F32 => { matmul::matmul_f32( a as *const f32, @@ -152,6 +539,7 @@ pub unsafe fn matmul_bias_kernel( #[cfg(target_arch = "x86_64")] { use super::simd::matmul; + use crate::dtype::DType; match T::DTYPE { DType::F32 => { diff --git a/src/runtime/cpu/kernels/matmul_i8.rs b/src/runtime/cpu/kernels/matmul_i8.rs new file mode 100644 index 00000000..0fd8794b --- /dev/null +++ b/src/runtime/cpu/kernels/matmul_i8.rs @@ -0,0 +1,27 @@ +//! i8 × i8 → i32 matrix multiplication kernel +//! +//! Entry point for i8 matmul that dispatches to SIMD dot-product-based implementation. + +/// i8 × i8 → i32 matmul: C[m×n] = A[m×k] @ B[k×n] +/// +/// Input matrices are i8, output is i32 (standard quantized matmul accumulation). +/// +/// # Safety +/// - `a` must point to m×lda i8 elements +/// - `b` must point to k×ldb i8 elements +/// - `out` must point to m×ldc i32 elements +#[inline] +#[allow(clippy::too_many_arguments)] +pub unsafe fn matmul_i8_to_i32_kernel( + a: *const i8, + b: *const i8, + out: *mut i32, + m: usize, + n: usize, + k: usize, + lda: usize, + ldb: usize, + ldc: usize, +) { + super::simd::matmul::int8::matmul_i8_to_i32(a, b, out, m, n, k, lda, ldb, ldc); +} diff --git a/src/runtime/cpu/kernels/memory.rs b/src/runtime/cpu/kernels/memory.rs index 570bb4a3..ffe48a06 100644 --- a/src/runtime/cpu/kernels/memory.rs +++ b/src/runtime/cpu/kernels/memory.rs @@ -1,8 +1,7 @@ //! Memory operation kernels (fill, copy, cast, random) +use super::rng; use crate::dtype::Element; -use rand::Rng; -use rand_distr::{Distribution, StandardNormal}; /// Fill buffer with a constant value /// @@ -311,14 +310,14 @@ pub unsafe fn cast_kernel( /// - `out` must be a valid pointer to `len` elements #[inline] pub unsafe fn rand_uniform_kernel(out: *mut T, len: usize) { - let mut rng = rand::rng(); + let mut prng = rng::thread_rng(); let out_slice = std::slice::from_raw_parts_mut(out, len); // Check once if this type can round values near 1.0 up to 1.0 let needs_clamp = T::from_f64(0.9999).to_f64() >= 1.0; for elem in out_slice.iter_mut() { - let val: f64 = rng.random(); + let val = rng::sample_uniform(&mut prng); *elem = T::from_f64(val); // For reduced-precision types (BF16, FP8), rounding can push values // near 1.0 up to exactly 1.0. Clamp to the largest representable @@ -337,12 +336,11 @@ pub unsafe fn rand_uniform_kernel(out: *mut T, len: usize) { /// - `out` must be a valid pointer to `len` elements #[inline] pub unsafe fn rand_normal_kernel(out: *mut T, len: usize) { - let mut rng = rand::rng(); - let normal = StandardNormal; + let mut prng = rng::thread_rng(); let out_slice = std::slice::from_raw_parts_mut(out, len); for elem in out_slice.iter_mut() { - let val: f64 = normal.sample(&mut rng); + let val = rng::sample_normal(&mut prng); *elem = T::from_f64(val); } } @@ -356,15 +354,11 @@ pub unsafe fn rand_normal_kernel(out: *mut T, len: usize) { /// - `low < high` must be satisfied #[inline] pub unsafe fn randint_kernel(out: *mut T, low: i64, high: i64, len: usize) { - use rand::distr::Uniform; - use rand::prelude::Distribution; - - let mut rng = rand::rng(); - let dist = Uniform::new(low, high).unwrap(); + let mut prng = rng::thread_rng(); let out_slice = std::slice::from_raw_parts_mut(out, len); for elem in out_slice.iter_mut() { - let val: i64 = dist.sample(&mut rng); + let val = rng::sample_uniform_int(&mut prng, low, high); *elem = T::from_f64(val as f64); } } @@ -448,7 +442,7 @@ pub unsafe fn multinomial_kernel_with_replacement( num_categories: usize, num_samples: usize, ) { - let mut rng = rand::rng(); + let mut prng = rng::thread_rng(); for dist in 0..num_distributions { let prob_row = std::slice::from_raw_parts(probs.add(dist * num_categories), num_categories); @@ -475,7 +469,7 @@ pub unsafe fn multinomial_kernel_with_replacement( let out_row = std::slice::from_raw_parts_mut(out.add(dist * num_samples), num_samples); for sample in out_row { - let u: f64 = rng.random(); + let u = rng::sample_uniform(&mut prng); // Binary search: find first index where cdf[i] >= u let idx = cdf.partition_point(|&c| c < u); *sample = idx.min(num_categories - 1) as i64; @@ -502,7 +496,7 @@ pub unsafe fn multinomial_kernel_without_replacement( num_categories: usize, num_samples: usize, ) { - let mut rng = rand::rng(); + let mut prng = rng::thread_rng(); for dist in 0..num_distributions { let prob_row = std::slice::from_raw_parts(probs.add(dist * num_categories), num_categories); @@ -527,7 +521,7 @@ pub unsafe fn multinomial_kernel_without_replacement( } // Sample - let u: f64 = rng.random(); + let u = rng::sample_uniform(&mut prng); let idx = cdf.partition_point(|&c| c < u).min(num_categories - 1); *sample = idx as i64; @@ -542,7 +536,7 @@ pub unsafe fn multinomial_kernel_without_replacement( /// # Safety /// - `out` must be a valid pointer to `n` elements of i64 pub unsafe fn randperm_kernel(out: *mut i64, n: usize) { - let mut rng = rand::rng(); + let mut prng = rng::thread_rng(); let out_slice = std::slice::from_raw_parts_mut(out, n); // Initialize with [0, 1, 2, ..., n-1] @@ -552,7 +546,7 @@ pub unsafe fn randperm_kernel(out: *mut i64, n: usize) { // Fisher-Yates shuffle for i in (1..n).rev() { - let j = rng.random_range(0..=i); + let j = (prng.next() % (i as u64 + 1)) as usize; out_slice.swap(i, j); } } diff --git a/src/runtime/cpu/kernels/mod.rs b/src/runtime/cpu/kernels/mod.rs index 000788c6..c6d713a0 100644 --- a/src/runtime/cpu/kernels/mod.rs +++ b/src/runtime/cpu/kernels/mod.rs @@ -14,13 +14,18 @@ pub mod cumulative; pub mod distance; pub mod distributions; pub mod fft; +pub mod fused_add_norm; +pub mod fused_elementwise; +pub mod gemm_epilogue; pub mod index; pub mod logical; pub mod matmul; +pub mod matmul_i8; pub mod memory; pub mod norm; pub mod quasirandom; pub mod reduce; +pub(crate) mod rng; pub mod scalar; pub mod semiring_matmul; pub mod simd; @@ -28,6 +33,8 @@ pub mod sobol_data; pub mod sort; #[cfg(feature = "sparse")] pub mod sparse; +#[cfg(feature = "sparse")] +pub mod sparse_24; pub mod unary; pub mod where_select; @@ -59,25 +66,36 @@ pub use fft::{ fftshift_c64, fftshift_c128, ifftshift_c64, ifftshift_c128, irfft_c64, irfft_c128, rfft_c64, rfft_c128, stockham_fft_batched_c64, stockham_fft_batched_c128, }; +pub use fused_add_norm::{ + fused_add_layer_norm_bwd_kernel, fused_add_layer_norm_kernel, fused_add_rms_norm_bwd_kernel, + fused_add_rms_norm_kernel, +}; +pub use fused_elementwise::{ + fused_add_mul_kernel, fused_mul_add_kernel, fused_mul_add_scalar_kernel, +}; +pub use gemm_epilogue::{ + matmul_bias_activation_bwd_kernel, matmul_bias_activation_kernel, matmul_bias_residual_kernel, +}; pub use index::{ bincount_kernel, embedding_lookup_kernel, gather_2d_kernel, gather_kernel, gather_nd_kernel, index_put_kernel, index_select_kernel, masked_fill_kernel, masked_select_kernel, - max_i64_kernel, scatter_kernel, scatter_reduce_kernel, + max_i64_kernel, scatter_kernel, scatter_reduce_kernel, slice_assign_kernel, }; pub use logical::{logical_and_kernel, logical_not_kernel, logical_or_kernel, logical_xor_kernel}; -pub use matmul::{matmul_bias_kernel, matmul_kernel}; +pub use matmul::{gemv_bt_kernel, matmul_bias_kernel, matmul_kernel}; +pub use matmul_i8::matmul_i8_to_i32_kernel; pub use memory::{ arange_kernel, cast_kernel, copy_kernel, eye_kernel, fill_kernel, linspace_kernel, multinomial_kernel_with_replacement, multinomial_kernel_without_replacement, one_hot_kernel, rand_normal_kernel, rand_uniform_kernel, randint_kernel, randperm_kernel, }; -pub use norm::{layer_norm_kernel, rms_norm_kernel}; +pub use norm::{group_norm_kernel, layer_norm_kernel, rms_norm_kernel}; pub use quasirandom::{ halton_f32, halton_f64, latin_hypercube_f32, latin_hypercube_f64, sobol_f32, sobol_f64, }; pub use reduce::{ Accumulator, argmax_kernel, argmin_kernel, reduce_kernel, reduce_kernel_with_precision, - softmax_kernel, variance_kernel, + softmax_bwd_kernel, softmax_kernel, variance_kernel, }; pub use scalar::{rsub_scalar_kernel, scalar_op_kernel}; pub use sort::{ @@ -86,13 +104,18 @@ pub use sort::{ sort_values_kernel, topk_kernel, unique_with_counts_kernel, }; pub use unary::{ - clamp_kernel, elu_kernel, gelu_kernel, isinf_kernel, isnan_kernel, leaky_relu_kernel, - relu_kernel, sigmoid_kernel, silu_kernel, unary_op_kernel, + clamp_kernel, elu_kernel, gelu_kernel, gelu_mul_kernel, isinf_kernel, isnan_kernel, + leaky_relu_kernel, relu_kernel, relu_mul_kernel, sigmoid_kernel, sigmoid_mul_kernel, + silu_kernel, silu_mul_kernel, unary_op_kernel, }; pub use where_select::{ where_kernel, where_kernel_generic, where_strided_kernel, where_strided_kernel_generic, }; +// Re-export SIMD dot product kernels for downstream crates (e.g., boostr quantized ops) +#[allow(unused_imports)] +pub use simd::dot::{i8xi8_dot_f32, i8xi8_dot_i32}; + // Re-export sparse kernel functions for external use #[cfg(feature = "sparse")] #[allow(unused_imports)] diff --git a/src/runtime/cpu/kernels/norm.rs b/src/runtime/cpu/kernels/norm.rs index c6a98188..1e14a3d3 100644 --- a/src/runtime/cpu/kernels/norm.rs +++ b/src/runtime/cpu/kernels/norm.rs @@ -3,7 +3,7 @@ //! Provides normalization operations with automatic SIMD dispatch. //! On x86-64, f32 and f64 operations use AVX-512 or AVX2 when available. -use crate::dtype::{DType, Element}; +use crate::dtype::Element; /// RMS Normalization: output = input * rsqrt(mean(input^2) + eps) * weight /// @@ -39,6 +39,7 @@ pub unsafe fn rms_norm_kernel( #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] { use super::simd::norm; + use crate::dtype::DType; match T::DTYPE { DType::F32 => { @@ -63,6 +64,30 @@ pub unsafe fn rms_norm_kernel( ); return; } + #[cfg(feature = "f16")] + DType::F16 => { + norm::rms_norm_f16( + input as *const half::f16, + weight as *const half::f16, + out as *mut half::f16, + batch_size, + hidden_size, + eps, + ); + return; + } + #[cfg(feature = "f16")] + DType::BF16 => { + norm::rms_norm_bf16( + input as *const half::bf16, + weight as *const half::bf16, + out as *mut half::bf16, + batch_size, + hidden_size, + eps, + ); + return; + } _ => {} // Fall through to scalar } } @@ -143,6 +168,7 @@ pub unsafe fn layer_norm_kernel( #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] { use super::simd::norm; + use crate::dtype::DType; match T::DTYPE { DType::F32 => { @@ -169,6 +195,32 @@ pub unsafe fn layer_norm_kernel( ); return; } + #[cfg(feature = "f16")] + DType::F16 => { + norm::layer_norm_f16( + input as *const half::f16, + weight as *const half::f16, + bias as *const half::f16, + out as *mut half::f16, + batch_size, + hidden_size, + eps, + ); + return; + } + #[cfg(feature = "f16")] + DType::BF16 => { + norm::layer_norm_bf16( + input as *const half::bf16, + weight as *const half::bf16, + bias as *const half::bf16, + out as *mut half::bf16, + batch_size, + hidden_size, + eps, + ); + return; + } _ => {} // Fall through to scalar } } @@ -224,3 +276,72 @@ unsafe fn layer_norm_kernel_scalar( } } } + +/// Group Normalization kernel. +/// +/// Input layout: `[batch, channels, spatial]` (contiguous). +/// For each (batch, group), computes mean/var over `channels_per_group * spatial` elements, +/// then applies per-channel weight and bias. +/// +/// # Safety +/// - `input` and `out`: valid for `batch * channels * spatial` elements +/// - `weight` and `bias`: valid for `channels` elements +#[inline] +#[allow(clippy::too_many_arguments)] +pub unsafe fn group_norm_kernel( + input: *const T, + weight: *const T, + bias: *const T, + out: *mut T, + batch: usize, + channels: usize, + spatial: usize, + num_groups: usize, + channels_per_group: usize, + eps: f32, +) { + let eps = eps as f64; + let group_size = channels_per_group * spatial; + + for b in 0..batch { + for g in 0..num_groups { + let c_start = g * channels_per_group; + + // Compute mean over group + let mut sum = 0.0f64; + for c in 0..channels_per_group { + let ch = c_start + c; + let offset = (b * channels + ch) * spatial; + for s in 0..spatial { + sum += (*input.add(offset + s)).to_f64(); + } + } + let mean = sum / group_size as f64; + + // Compute variance over group + let mut var_sum = 0.0f64; + for c in 0..channels_per_group { + let ch = c_start + c; + let offset = (b * channels + ch) * spatial; + for s in 0..spatial { + let diff = (*input.add(offset + s)).to_f64() - mean; + var_sum += diff * diff; + } + } + let inv_std = 1.0 / (var_sum / group_size as f64 + eps).sqrt(); + + // Normalize and apply per-channel affine + for c in 0..channels_per_group { + let ch = c_start + c; + let w = (*weight.add(ch)).to_f64(); + let bi = (*bias.add(ch)).to_f64(); + let offset = (b * channels + ch) * spatial; + for s in 0..spatial { + let x = (*input.add(offset + s)).to_f64(); + let result = (x - mean) * inv_std * w + bi; + *out.add(offset + s) = T::from_f64(result); + } + } + } + } +} diff --git a/src/runtime/cpu/kernels/quasirandom.rs b/src/runtime/cpu/kernels/quasirandom.rs index 5ee6f428..1bddf32d 100644 --- a/src/runtime/cpu/kernels/quasirandom.rs +++ b/src/runtime/cpu/kernels/quasirandom.rs @@ -2,12 +2,11 @@ //! //! Implements Sobol, Halton, and Latin Hypercube Sampling sequences. +use super::rng; use super::sobol_data::{MAX_SOBOL_DIMENSION, get_polynomial}; use crate::ops::common::quasirandom::{ SOBOL_BITS, compute_direction_vectors, dimension_zero_vectors, }; -use rand::Rng; -use rand::seq::SliceRandom; /// Generate Sobol sequence points (F32 version). /// @@ -237,20 +236,20 @@ fn van_der_corput_single_f64(mut index: usize, base: u32) -> f64 { /// - `out` must point to valid memory of length `n_samples * dimension` #[inline] pub unsafe fn latin_hypercube_f32(out: *mut f32, n_samples: usize, dimension: usize) { - let mut rng = rand::rng(); + let mut prng = rng::thread_rng(); for d in 0..dimension { // Create stratified intervals let mut intervals: Vec = (0..n_samples).collect(); // Shuffle intervals - intervals.shuffle(&mut rng); + rng::shuffle(&mut prng, &mut intervals); // Generate random point within each interval for (i, &interval) in intervals.iter().enumerate() { let lower = interval as f32 / n_samples as f32; let upper = (interval + 1) as f32 / n_samples as f32; - let random_offset: f32 = rng.random_range(0.0..1.0); + let random_offset = rng::sample_uniform(&mut prng) as f32; *out.add(i * dimension + d) = lower + random_offset * (upper - lower); } @@ -260,16 +259,16 @@ pub unsafe fn latin_hypercube_f32(out: *mut f32, n_samples: usize, dimension: us /// Generate Latin Hypercube samples (F64 version). #[inline] pub unsafe fn latin_hypercube_f64(out: *mut f64, n_samples: usize, dimension: usize) { - let mut rng = rand::rng(); + let mut prng = rng::thread_rng(); for d in 0..dimension { let mut intervals: Vec = (0..n_samples).collect(); - intervals.shuffle(&mut rng); + rng::shuffle(&mut prng, &mut intervals); for (i, &interval) in intervals.iter().enumerate() { let lower = interval as f64 / n_samples as f64; let upper = (interval + 1) as f64 / n_samples as f64; - let random_offset: f64 = rng.random_range(0.0..1.0); + let random_offset = rng::sample_uniform(&mut prng); *out.add(i * dimension + d) = lower + random_offset * (upper - lower); } diff --git a/src/runtime/cpu/kernels/reduce/mod.rs b/src/runtime/cpu/kernels/reduce/mod.rs index 4c8fa225..593206a8 100644 --- a/src/runtime/cpu/kernels/reduce/mod.rs +++ b/src/runtime/cpu/kernels/reduce/mod.rs @@ -5,9 +5,11 @@ mod special; -pub use special::{argmax_kernel, argmin_kernel, softmax_kernel, variance_kernel}; +pub use special::{ + argmax_kernel, argmin_kernel, softmax_bwd_kernel, softmax_kernel, variance_kernel, +}; -use crate::dtype::{DType, Element}; +use crate::dtype::Element; use crate::ops::{AccumulationPrecision, ReduceOp}; /// Reduce along contiguous dimension with automatic SIMD dispatch @@ -39,6 +41,7 @@ pub unsafe fn reduce_kernel( #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] { use super::simd::reduce; + use crate::dtype::DType; match T::DTYPE { DType::F32 => { @@ -61,6 +64,28 @@ pub unsafe fn reduce_kernel( ); return; } + #[cfg(feature = "f16")] + DType::F16 => { + reduce::reduce_f16( + op, + a as *const half::f16, + out as *mut half::f16, + reduce_size, + outer_size, + ); + return; + } + #[cfg(feature = "f16")] + DType::BF16 => { + reduce::reduce_bf16( + op, + a as *const half::bf16, + out as *mut half::bf16, + reduce_size, + outer_size, + ); + return; + } _ => {} // Fall through to scalar } } diff --git a/src/runtime/cpu/kernels/reduce/special.rs b/src/runtime/cpu/kernels/reduce/special.rs index eb5c8b6d..c94397ca 100644 --- a/src/runtime/cpu/kernels/reduce/special.rs +++ b/src/runtime/cpu/kernels/reduce/special.rs @@ -2,7 +2,7 @@ //! //! Contains argmax, argmin, softmax, and variance kernels. -use crate::dtype::{DType, Element}; +use crate::dtype::Element; /// Argmax along a dimension - returns indices of maximum values /// @@ -117,6 +117,7 @@ pub unsafe fn softmax_kernel( // Dispatch to SIMD for f32/f64 on x86-64 #[cfg(target_arch = "x86_64")] { + use crate::dtype::DType; use crate::runtime::cpu::kernels::simd::softmax; match T::DTYPE { @@ -128,6 +129,26 @@ pub unsafe fn softmax_kernel( softmax::softmax_f64(a as *const f64, out as *mut f64, outer_size, dim_size); return; } + #[cfg(feature = "f16")] + DType::F16 => { + softmax::softmax_f16( + a as *const half::f16, + out as *mut half::f16, + outer_size, + dim_size, + ); + return; + } + #[cfg(feature = "f16")] + DType::BF16 => { + softmax::softmax_bf16( + a as *const half::bf16, + out as *mut half::bf16, + outer_size, + dim_size, + ); + return; + } _ => {} // Fall through to scalar } } @@ -136,7 +157,7 @@ pub unsafe fn softmax_kernel( softmax_kernel_scalar(a, out, outer_size, dim_size); } -/// Scalar softmax for all Element types +/// Scalar softmax for all Element types using online algorithm (2-pass). #[inline] unsafe fn softmax_kernel_scalar( a: *const T, @@ -147,29 +168,175 @@ unsafe fn softmax_kernel_scalar( for o in 0..outer_size { let base = o * dim_size; - // Find max for numerical stability + // Pass 1: Online max + sum let mut max_val = (*a.add(base)).to_f64(); + let mut sum = 1.0f64; for d in 1..dim_size { let val = (*a.add(base + d)).to_f64(); if val > max_val { + sum = sum * (max_val - val).exp() + 1.0; max_val = val; + } else { + sum += (val - max_val).exp(); } } - // Compute exp(x - max) and sum - let mut sum = 0.0f64; + // Pass 2: exp(x - max) / sum + let inv_sum = 1.0 / sum; for d in 0..dim_size { let val = (*a.add(base + d)).to_f64(); - let exp_val = (val - max_val).exp(); - *out.add(base + d) = T::from_f64(exp_val); - sum += exp_val; + *out.add(base + d) = T::from_f64((val - max_val).exp() * inv_sum); } + } +} - // Normalize by sum - let inv_sum = 1.0 / sum; +/// Softmax backward kernel: d_input = output * (grad - sum(grad * output)) +/// +/// Dispatches to SIMD for f32/f64, with f16/bf16 block-convert wrappers. +/// Falls back to scalar for other types. +/// +/// # Safety +/// - `grad`, `output`, `d_input` must point to `outer_size * dim_size` elements +#[inline] +pub unsafe fn softmax_bwd_kernel( + grad: *const T, + output: *const T, + d_input: *mut T, + outer_size: usize, + dim_size: usize, +) { + #[cfg(target_arch = "x86_64")] + { + use crate::dtype::DType; + use crate::runtime::cpu::kernels::simd::softmax_bwd; + + match T::DTYPE { + DType::F32 => { + softmax_bwd::softmax_bwd_f32( + grad as *const f32, + output as *const f32, + d_input as *mut f32, + outer_size, + dim_size, + ); + return; + } + DType::F64 => { + softmax_bwd::softmax_bwd_f64( + grad as *const f64, + output as *const f64, + d_input as *mut f64, + outer_size, + dim_size, + ); + return; + } + #[cfg(feature = "f16")] + DType::F16 => { + softmax_bwd::softmax_bwd_f16( + grad as *const half::f16, + output as *const half::f16, + d_input as *mut half::f16, + outer_size, + dim_size, + ); + return; + } + #[cfg(feature = "f16")] + DType::BF16 => { + softmax_bwd::softmax_bwd_bf16( + grad as *const half::bf16, + output as *const half::bf16, + d_input as *mut half::bf16, + outer_size, + dim_size, + ); + return; + } + _ => {} // Fall through to scalar + } + } + + #[cfg(target_arch = "aarch64")] + { + use crate::dtype::DType; + use crate::runtime::cpu::kernels::simd::softmax_bwd; + + match T::DTYPE { + DType::F32 => { + softmax_bwd::softmax_bwd_f32( + grad as *const f32, + output as *const f32, + d_input as *mut f32, + outer_size, + dim_size, + ); + return; + } + DType::F64 => { + softmax_bwd::softmax_bwd_f64( + grad as *const f64, + output as *const f64, + d_input as *mut f64, + outer_size, + dim_size, + ); + return; + } + #[cfg(feature = "f16")] + DType::F16 => { + softmax_bwd::softmax_bwd_f16( + grad as *const half::f16, + output as *const half::f16, + d_input as *mut half::f16, + outer_size, + dim_size, + ); + return; + } + #[cfg(feature = "f16")] + DType::BF16 => { + softmax_bwd::softmax_bwd_bf16( + grad as *const half::bf16, + output as *const half::bf16, + d_input as *mut half::bf16, + outer_size, + dim_size, + ); + return; + } + _ => {} // Fall through to scalar + } + } + + // Scalar fallback + softmax_bwd_kernel_scalar(grad, output, d_input, outer_size, dim_size); +} + +/// Scalar softmax backward for all Element types +#[inline] +unsafe fn softmax_bwd_kernel_scalar( + grad: *const T, + output: *const T, + d_input: *mut T, + outer_size: usize, + dim_size: usize, +) { + for o in 0..outer_size { + let base = o * dim_size; + + // Pass 1: dot = sum(grad * output) + let mut dot = 0.0f64; + for d in 0..dim_size { + dot += (*grad.add(base + d)).to_f64() * (*output.add(base + d)).to_f64(); + } + + // Pass 2: d_input = output * (grad - dot) for d in 0..dim_size { - let val = (*out.add(base + d)).to_f64(); - *out.add(base + d) = T::from_f64(val * inv_sum); + let idx = base + d; + let g = (*grad.add(idx)).to_f64(); + let out = (*output.add(idx)).to_f64(); + *d_input.add(idx) = T::from_f64(out * (g - dot)); } } } diff --git a/src/runtime/cpu/kernels/rng.rs b/src/runtime/cpu/kernels/rng.rs new file mode 100644 index 00000000..fbb04bca --- /dev/null +++ b/src/runtime/cpu/kernels/rng.rs @@ -0,0 +1,365 @@ +//! Shared PRNG and distribution sampling for CPU kernels. +//! +//! Provides Xoshiro256++ as the standard PRNG and distribution samplers +//! that replace the `rand` and `rand_distr` crate dependencies. + +use std::f64::consts::PI; +use std::sync::atomic::{AtomicU64, Ordering}; + +// --------------------------------------------------------------------------- +// Xoshiro256++ PRNG +// --------------------------------------------------------------------------- + +/// Xoshiro256++ state (Blackman & Vigna 2018). +#[derive(Clone)] +pub(crate) struct Xoshiro256 { + s: [u64; 4], +} + +impl Xoshiro256 { + /// Create from seed using SplitMix64 to expand the seed. + #[inline(always)] + pub(crate) fn from_seed(seed: u64) -> Self { + let mut sm_state = seed; + let mut splitmix = || { + sm_state = sm_state.wrapping_add(0x9e3779b97f4a7c15); + let mut z = sm_state; + z = (z ^ (z >> 30)).wrapping_mul(0xbf58476d1ce4e5b9); + z = (z ^ (z >> 27)).wrapping_mul(0x94d049bb133111eb); + z ^ (z >> 31) + }; + + Self { + s: [splitmix(), splitmix(), splitmix(), splitmix()], + } + } + + /// Generate next u64. + #[inline(always)] + pub(crate) fn next(&mut self) -> u64 { + let result = self.s[0] + .wrapping_add(self.s[3]) + .rotate_left(23) + .wrapping_add(self.s[0]); + + let t = self.s[1] << 17; + + self.s[2] ^= self.s[0]; + self.s[3] ^= self.s[1]; + self.s[1] ^= self.s[2]; + self.s[0] ^= self.s[3]; + + self.s[2] ^= t; + self.s[3] = self.s[3].rotate_left(45); + + result + } +} + +// --------------------------------------------------------------------------- +// Entropy-based seeding (no getrandom / no rand crate) +// --------------------------------------------------------------------------- + +static COUNTER: AtomicU64 = AtomicU64::new(0); + +#[cfg(not(target_arch = "wasm32"))] +fn get_thread_entropy() -> u64 { + let id = std::thread::current().id(); + let s = format!("{:?}", id); + let mut h: u64 = 0xcbf29ce484222325; + for b in s.bytes() { + h ^= b as u64; + h = h.wrapping_mul(0x100000001b3); + } + h +} + +#[cfg(target_arch = "wasm32")] +fn get_thread_entropy() -> u64 { + // No threads on wasm, use a different mixing constant. + 0xd1342543de82ef95 +} + +/// Create a new Xoshiro256++ seeded from available entropy. +/// +/// Uses a combination of address-space randomisation (ASLR), an atomic +/// counter, and thread ID to generate unique seeds without OS entropy. +pub(crate) fn thread_rng() -> Xoshiro256 { + let counter = COUNTER.fetch_add(1, Ordering::Relaxed); + let thread_id = get_thread_entropy(); + // Mix in a stack address for ASLR entropy. + let stack_addr = &counter as *const _ as u64; + let seed = counter + .wrapping_mul(0x9e3779b97f4a7c15) + .wrapping_add(thread_id) + .wrapping_add(stack_addr); + Xoshiro256::from_seed(seed) +} + +// --------------------------------------------------------------------------- +// Primitive samplers +// --------------------------------------------------------------------------- + +/// Convert a raw u64 to a uniform f64 in [0, 1) using 53 bits. +#[inline(always)] +pub(crate) fn u64_to_uniform(u: u64) -> f64 { + (u >> 11) as f64 / (1u64 << 53) as f64 +} + +/// Sample a uniform f64 in [0, 1). +#[inline(always)] +pub(crate) fn sample_uniform(rng: &mut Xoshiro256) -> f64 { + u64_to_uniform(rng.next()) +} + +/// Sample a standard-normal f64 (mean 0, std 1) via Box-Muller. +/// +/// Generates a pair and discards the second value for simplicity. +#[inline(always)] +pub(crate) fn sample_normal(rng: &mut Xoshiro256) -> f64 { + let u1 = sample_uniform(rng).clamp(1e-10, 1.0 - 1e-10); + let u2 = sample_uniform(rng); + let r = (-2.0 * u1.ln()).sqrt(); + r * (2.0 * PI * u2).cos() +} + +/// Sample a uniform integer in [low, high). +/// +/// Uses rejection sampling to avoid modulo bias: we reject values from the +/// incomplete final bucket of size `u64::MAX % range` at the top of the range. +#[inline(always)] +pub(crate) fn sample_uniform_int(rng: &mut Xoshiro256, low: i64, high: i64) -> i64 { + debug_assert!(low < high); + let range = (high - low) as u64; + // Largest multiple of `range` that fits in u64: reject anything >= limit. + // For power-of-2 ranges, limit == 0 (wraps), so the loop always accepts on first try. + let limit = range.wrapping_neg() % range; // = (2^64 - range) % range = 2^64 % range + loop { + let raw = rng.next(); + if raw >= limit { + return low + (raw % range) as i64; + } + } +} + +/// Sample from Exponential(rate) via inverse transform. +#[inline(always)] +pub(crate) fn sample_exponential(rng: &mut Xoshiro256, rate: f64) -> f64 { + let u = sample_uniform(rng).clamp(1e-300, 1.0 - 1e-10); + -u.ln() / rate +} + +/// Sample from Gamma(shape, scale) using Marsaglia & Tsang (2000). +pub(crate) fn sample_gamma(rng: &mut Xoshiro256, shape: f64, scale: f64) -> f64 { + if shape < 1.0 { + // Gamma(shape) = Gamma(shape+1) * U^(1/shape) + let g = sample_gamma(rng, shape + 1.0, 1.0); + let u = sample_uniform(rng).clamp(1e-300, 1.0); + return g * u.powf(1.0 / shape) * scale; + } + + let d = shape - 1.0 / 3.0; + let c = 1.0 / (9.0 * d).sqrt(); + + loop { + let x = sample_normal(rng); + let v_base = 1.0 + c * x; + if v_base <= 0.0 { + continue; + } + let v = v_base * v_base * v_base; + let u = sample_uniform(rng).clamp(1e-300, 1.0); + + // Squeeze test (fast path) + if u < 1.0 - 0.0331 * (x * x) * (x * x) { + return d * v * scale; + } + // Full test + if u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) { + return d * v * scale; + } + } +} + +/// Sample from Beta(alpha, beta) via two Gamma samples. +#[inline] +pub(crate) fn sample_beta(rng: &mut Xoshiro256, alpha: f64, beta: f64) -> f64 { + let x = sample_gamma(rng, alpha, 1.0); + let y = sample_gamma(rng, beta, 1.0); + x / (x + y) +} + +/// Sample from Poisson(lambda). +/// +/// Knuth's algorithm for small lambda (<30), normal approximation for large. +pub(crate) fn sample_poisson(rng: &mut Xoshiro256, lambda: f64) -> u64 { + if lambda < 30.0 { + let l = (-lambda).exp(); + let mut k: u64 = 0; + let mut p = 1.0f64; + loop { + k += 1; + p *= sample_uniform(rng); + if p < l { + return k - 1; + } + } + } else { + // Normal approximation + let val = lambda + lambda.sqrt() * sample_normal(rng); + val.round().max(0.0) as u64 + } +} + +/// Sample from Binomial(n, p). +/// +/// For small n, sum of Bernoulli trials. For large n, normal approximation. +pub(crate) fn sample_binomial(rng: &mut Xoshiro256, n: u64, p: f64) -> u64 { + if n <= 64 { + let mut successes = 0u64; + for _ in 0..n { + if sample_uniform(rng) < p { + successes += 1; + } + } + successes + } else { + // Normal approximation: N(np, np(1-p)) + let mean = n as f64 * p; + let std = (mean * (1.0 - p)).sqrt(); + let val = mean + std * sample_normal(rng); + val.round().clamp(0.0, n as f64) as u64 + } +} + +/// Fisher-Yates shuffle of a mutable slice. +pub(crate) fn shuffle(rng: &mut Xoshiro256, slice: &mut [T]) { + let n = slice.len(); + for i in (1..n).rev() { + let bound = i as u64 + 1; + let j = sample_uniform_int(rng, 0, bound as i64) as usize; + slice.swap(i, j); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_uniform_range() { + let mut rng = Xoshiro256::from_seed(42); + for _ in 0..10_000 { + let v = sample_uniform(&mut rng); + assert!((0.0..1.0).contains(&v)); + } + } + + #[test] + fn test_normal_statistics() { + let mut rng = Xoshiro256::from_seed(42); + let n = 50_000; + let samples: Vec = (0..n).map(|_| sample_normal(&mut rng)).collect(); + let mean = samples.iter().sum::() / n as f64; + let var = samples.iter().map(|x| (x - mean).powi(2)).sum::() / n as f64; + assert!(mean.abs() < 0.05, "mean = {mean}"); + assert!((var - 1.0).abs() < 0.1, "var = {var}"); + } + + #[test] + fn test_uniform_int() { + let mut rng = Xoshiro256::from_seed(42); + for _ in 0..10_000 { + let v = sample_uniform_int(&mut rng, -5, 10); + assert!((-5..10).contains(&v)); + } + } + + #[test] + fn test_exponential_positive() { + let mut rng = Xoshiro256::from_seed(42); + for _ in 0..1_000 { + assert!(sample_exponential(&mut rng, 1.0) > 0.0); + } + } + + #[test] + fn test_gamma_statistics() { + let mut rng = Xoshiro256::from_seed(42); + let n = 10_000; + let shape = 2.0; + let scale = 1.0; + let samples: Vec = (0..n) + .map(|_| sample_gamma(&mut rng, shape, scale)) + .collect(); + let mean = samples.iter().sum::() / n as f64; + assert!(samples.iter().all(|&x| x > 0.0)); + assert!((mean - shape * scale).abs() < 0.3, "mean = {mean}"); + } + + #[test] + fn test_gamma_small_shape() { + let mut rng = Xoshiro256::from_seed(42); + let n = 5_000; + let samples: Vec = (0..n).map(|_| sample_gamma(&mut rng, 0.5, 1.0)).collect(); + assert!(samples.iter().all(|&x| x > 0.0)); + let mean = samples.iter().sum::() / n as f64; + assert!((mean - 0.5).abs() < 0.2, "mean = {mean}"); + } + + #[test] + fn test_beta_range() { + let mut rng = Xoshiro256::from_seed(42); + for _ in 0..1_000 { + let v = sample_beta(&mut rng, 2.0, 5.0); + assert!((0.0..=1.0).contains(&v)); + } + } + + #[test] + fn test_poisson_small() { + let mut rng = Xoshiro256::from_seed(42); + let n = 10_000; + let lambda = 5.0; + let samples: Vec = (0..n).map(|_| sample_poisson(&mut rng, lambda)).collect(); + let mean = samples.iter().sum::() as f64 / n as f64; + assert!((mean - lambda).abs() < 0.5, "mean = {mean}"); + } + + #[test] + fn test_poisson_large() { + let mut rng = Xoshiro256::from_seed(42); + let n = 10_000; + let lambda = 100.0; + let samples: Vec = (0..n).map(|_| sample_poisson(&mut rng, lambda)).collect(); + let mean = samples.iter().sum::() as f64 / n as f64; + assert!((mean - lambda).abs() < 5.0, "mean = {mean}"); + } + + #[test] + fn test_binomial_small() { + let mut rng = Xoshiro256::from_seed(42); + let n = 10_000; + let trials = 10u64; + let p = 0.5; + let samples: Vec = (0..n) + .map(|_| sample_binomial(&mut rng, trials, p)) + .collect(); + assert!(samples.iter().all(|&x| x <= trials)); + let mean = samples.iter().sum::() as f64 / n as f64; + assert!((mean - trials as f64 * p).abs() < 0.5, "mean = {mean}"); + } + + #[test] + fn test_shuffle() { + let mut rng = Xoshiro256::from_seed(42); + let mut v: Vec = (0..100).collect(); + shuffle(&mut rng, &mut v); + // Should still contain all elements + let mut sorted = v.clone(); + sorted.sort(); + assert_eq!(sorted, (0..100).collect::>()); + // Should not be in original order (extremely unlikely) + assert_ne!(v, (0..100).collect::>()); + } +} diff --git a/src/runtime/cpu/kernels/scalar.rs b/src/runtime/cpu/kernels/scalar.rs index 1afe2019..791e7458 100644 --- a/src/runtime/cpu/kernels/scalar.rs +++ b/src/runtime/cpu/kernels/scalar.rs @@ -3,7 +3,7 @@ //! Provides tensor-scalar operations with automatic SIMD dispatch. //! On x86-64, f32 and f64 operations use AVX-512 or AVX2 when available. -use crate::dtype::{DType, Element}; +use crate::dtype::Element; use crate::ops::BinaryOp; /// Binary operation with a scalar (tensor op scalar) with automatic SIMD dispatch @@ -27,6 +27,7 @@ pub unsafe fn scalar_op_kernel( #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] { use super::simd::scalar; + use crate::dtype::DType; match T::DTYPE { DType::F32 => { @@ -37,6 +38,28 @@ pub unsafe fn scalar_op_kernel( scalar::scalar_f64(op, a as *const f64, scalar, out as *mut f64, len); return; } + #[cfg(feature = "f16")] + DType::F16 => { + scalar::scalar_f16( + op, + a as *const half::f16, + scalar as f32, + out as *mut half::f16, + len, + ); + return; + } + #[cfg(feature = "f16")] + DType::BF16 => { + scalar::scalar_bf16( + op, + a as *const half::bf16, + scalar as f32, + out as *mut half::bf16, + len, + ); + return; + } _ => {} // Fall through to scalar } } @@ -116,6 +139,7 @@ pub unsafe fn rsub_scalar_kernel(a: *const T, scalar: f64, out: *mut #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] { use super::simd::scalar; + use crate::dtype::DType; match T::DTYPE { DType::F32 => { @@ -126,6 +150,26 @@ pub unsafe fn rsub_scalar_kernel(a: *const T, scalar: f64, out: *mut scalar::rsub_scalar_f64(a as *const f64, scalar, out as *mut f64, len); return; } + #[cfg(feature = "f16")] + DType::F16 => { + scalar::rsub_scalar_f16( + a as *const half::f16, + scalar as f32, + out as *mut half::f16, + len, + ); + return; + } + #[cfg(feature = "f16")] + DType::BF16 => { + scalar::rsub_scalar_bf16( + a as *const half::bf16, + scalar as f32, + out as *mut half::bf16, + len, + ); + return; + } _ => {} // Fall through to scalar } } diff --git a/src/runtime/cpu/kernels/simd/activations/dispatch.rs b/src/runtime/cpu/kernels/simd/activations/dispatch.rs new file mode 100644 index 00000000..abd1160a --- /dev/null +++ b/src/runtime/cpu/kernels/simd/activations/dispatch.rs @@ -0,0 +1,557 @@ +//! SIMD-accelerated activation function dispatch and scalar fallbacks. +//! +//! This module provides: +//! - Top-level dispatch functions that select the best SIMD implementation +//! - Scalar fallback implementations for all activations + +#[cfg(target_arch = "aarch64")] +use super::aarch64; +#[cfg(target_arch = "x86_64")] +use super::avx2; +#[cfg(target_arch = "x86_64")] +use super::avx512; +use crate::runtime::cpu::kernels::simd::{SimdLevel, detect_simd}; + +/// Minimum length to justify SIMD overhead +const SIMD_THRESHOLD: usize = 32; + +/// SIMD sigmoid for f32 +/// +/// # Safety +/// - `a` and `out` must point to `len` elements +#[inline] +pub unsafe fn sigmoid_f32(a: *const f32, out: *mut f32, len: usize) { + let level = detect_simd(); + + if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { + sigmoid_scalar_f32(a, out, len); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => avx512::sigmoid_f32(a, out, len), + SimdLevel::Avx2Fma => avx2::sigmoid_f32(a, out, len), + _ => sigmoid_scalar_f32(a, out, len), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::sigmoid_f32(a, out, len), + _ => sigmoid_scalar_f32(a, out, len), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + sigmoid_scalar_f32(a, out, len); +} + +/// SIMD sigmoid for f64 +/// +/// # Safety +/// - `a` and `out` must point to `len` elements +#[inline] +pub unsafe fn sigmoid_f64(a: *const f64, out: *mut f64, len: usize) { + let level = detect_simd(); + + if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { + sigmoid_scalar_f64(a, out, len); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => avx512::sigmoid_f64(a, out, len), + SimdLevel::Avx2Fma => avx2::sigmoid_f64(a, out, len), + _ => sigmoid_scalar_f64(a, out, len), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::sigmoid_f64(a, out, len), + _ => sigmoid_scalar_f64(a, out, len), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + sigmoid_scalar_f64(a, out, len); +} + +/// SIMD SiLU for f32 +/// +/// # Safety +/// - `a` and `out` must point to `len` elements +#[inline] +pub unsafe fn silu_f32(a: *const f32, out: *mut f32, len: usize) { + let level = detect_simd(); + + if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { + silu_scalar_f32(a, out, len); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => avx512::silu_f32(a, out, len), + SimdLevel::Avx2Fma => avx2::silu_f32(a, out, len), + _ => silu_scalar_f32(a, out, len), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::silu_f32(a, out, len), + _ => silu_scalar_f32(a, out, len), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + silu_scalar_f32(a, out, len); +} + +/// SIMD SiLU for f64 +/// +/// # Safety +/// - `a` and `out` must point to `len` elements +#[inline] +pub unsafe fn silu_f64(a: *const f64, out: *mut f64, len: usize) { + let level = detect_simd(); + + if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { + silu_scalar_f64(a, out, len); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => avx512::silu_f64(a, out, len), + SimdLevel::Avx2Fma => avx2::silu_f64(a, out, len), + _ => silu_scalar_f64(a, out, len), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::silu_f64(a, out, len), + _ => silu_scalar_f64(a, out, len), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + silu_scalar_f64(a, out, len); +} + +/// SIMD GELU for f32 +/// +/// # Safety +/// - `a` and `out` must point to `len` elements +#[inline] +pub unsafe fn gelu_f32(a: *const f32, out: *mut f32, len: usize) { + let level = detect_simd(); + + if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { + gelu_scalar_f32(a, out, len); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => avx512::gelu_f32(a, out, len), + SimdLevel::Avx2Fma => avx2::gelu_f32(a, out, len), + _ => gelu_scalar_f32(a, out, len), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::gelu_f32(a, out, len), + _ => gelu_scalar_f32(a, out, len), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + gelu_scalar_f32(a, out, len); +} + +/// SIMD GELU for f64 +/// +/// # Safety +/// - `a` and `out` must point to `len` elements +#[inline] +pub unsafe fn gelu_f64(a: *const f64, out: *mut f64, len: usize) { + let level = detect_simd(); + + if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { + gelu_scalar_f64(a, out, len); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => avx512::gelu_f64(a, out, len), + SimdLevel::Avx2Fma => avx2::gelu_f64(a, out, len), + _ => gelu_scalar_f64(a, out, len), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::gelu_f64(a, out, len), + _ => gelu_scalar_f64(a, out, len), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + gelu_scalar_f64(a, out, len); +} + +/// SIMD Leaky ReLU for f32 +/// +/// # Safety +/// - `a` and `out` must point to `len` elements +#[inline] +pub unsafe fn leaky_relu_f32(a: *const f32, out: *mut f32, len: usize, negative_slope: f32) { + let level = detect_simd(); + + if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { + leaky_relu_scalar_f32(a, out, len, negative_slope); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => avx512::leaky_relu_f32(a, out, len, negative_slope), + SimdLevel::Avx2Fma => avx2::leaky_relu_f32(a, out, len, negative_slope), + _ => leaky_relu_scalar_f32(a, out, len, negative_slope), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => { + aarch64::neon::leaky_relu_f32(a, out, len, negative_slope) + } + _ => leaky_relu_scalar_f32(a, out, len, negative_slope), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + leaky_relu_scalar_f32(a, out, len, negative_slope); +} + +/// SIMD Leaky ReLU for f64 +/// +/// # Safety +/// - `a` and `out` must point to `len` elements +#[inline] +pub unsafe fn leaky_relu_f64(a: *const f64, out: *mut f64, len: usize, negative_slope: f64) { + let level = detect_simd(); + + if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { + leaky_relu_scalar_f64(a, out, len, negative_slope); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => avx512::leaky_relu_f64(a, out, len, negative_slope), + SimdLevel::Avx2Fma => avx2::leaky_relu_f64(a, out, len, negative_slope), + _ => leaky_relu_scalar_f64(a, out, len, negative_slope), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => { + aarch64::neon::leaky_relu_f64(a, out, len, negative_slope) + } + _ => leaky_relu_scalar_f64(a, out, len, negative_slope), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + leaky_relu_scalar_f64(a, out, len, negative_slope); +} + +/// SIMD ELU for f32 +/// +/// # Safety +/// - `a` and `out` must point to `len` elements +#[inline] +pub unsafe fn elu_f32(a: *const f32, out: *mut f32, len: usize, alpha: f32) { + let level = detect_simd(); + + if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { + elu_scalar_f32(a, out, len, alpha); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => avx512::elu_f32(a, out, len, alpha), + SimdLevel::Avx2Fma => avx2::elu_f32(a, out, len, alpha), + _ => elu_scalar_f32(a, out, len, alpha), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::elu_f32(a, out, len, alpha), + _ => elu_scalar_f32(a, out, len, alpha), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + elu_scalar_f32(a, out, len, alpha); +} + +/// SIMD ELU for f64 +/// +/// # Safety +/// - `a` and `out` must point to `len` elements +#[inline] +pub unsafe fn elu_f64(a: *const f64, out: *mut f64, len: usize, alpha: f64) { + let level = detect_simd(); + + if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { + elu_scalar_f64(a, out, len, alpha); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => avx512::elu_f64(a, out, len, alpha), + SimdLevel::Avx2Fma => avx2::elu_f64(a, out, len, alpha), + _ => elu_scalar_f64(a, out, len, alpha), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::elu_f64(a, out, len, alpha), + _ => elu_scalar_f64(a, out, len, alpha), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + elu_scalar_f64(a, out, len, alpha); +} + +// ============================================================================ +// Scalar fallbacks +// ============================================================================ + +/// Scalar sigmoid for f32 +#[inline] +pub unsafe fn sigmoid_scalar_f32(a: *const f32, out: *mut f32, len: usize) { + for i in 0..len { + let x = *a.add(i); + *out.add(i) = 1.0 / (1.0 + (-x).exp()); + } +} + +/// Scalar sigmoid for f64 +#[inline] +pub unsafe fn sigmoid_scalar_f64(a: *const f64, out: *mut f64, len: usize) { + for i in 0..len { + let x = *a.add(i); + *out.add(i) = 1.0 / (1.0 + (-x).exp()); + } +} + +/// Scalar SiLU for f32 +#[inline] +pub unsafe fn silu_scalar_f32(a: *const f32, out: *mut f32, len: usize) { + for i in 0..len { + let x = *a.add(i); + *out.add(i) = x / (1.0 + (-x).exp()); + } +} + +/// Scalar SiLU for f64 +#[inline] +pub unsafe fn silu_scalar_f64(a: *const f64, out: *mut f64, len: usize) { + for i in 0..len { + let x = *a.add(i); + *out.add(i) = x / (1.0 + (-x).exp()); + } +} + +/// Scalar GELU for f32 +#[inline] +pub unsafe fn gelu_scalar_f32(a: *const f32, out: *mut f32, len: usize) { + const SQRT_2_OVER_PI: f32 = 0.7978845608; // sqrt(2/pi) + const TANH_COEF: f32 = 0.044715; + + for i in 0..len { + let x = *a.add(i); + let inner = SQRT_2_OVER_PI * (x + TANH_COEF * x * x * x); + *out.add(i) = 0.5 * x * (1.0 + inner.tanh()); + } +} + +/// Scalar GELU for f64 +#[inline] +pub unsafe fn gelu_scalar_f64(a: *const f64, out: *mut f64, len: usize) { + const SQRT_2_OVER_PI: f64 = 0.7978845608028654; // sqrt(2/pi) + const TANH_COEF: f64 = 0.044715; + + for i in 0..len { + let x = *a.add(i); + let inner = SQRT_2_OVER_PI * (x + TANH_COEF * x * x * x); + *out.add(i) = 0.5 * x * (1.0 + inner.tanh()); + } +} + +/// Scalar Leaky ReLU for f32 +#[inline] +pub unsafe fn leaky_relu_scalar_f32(a: *const f32, out: *mut f32, len: usize, negative_slope: f32) { + for i in 0..len { + let x = *a.add(i); + *out.add(i) = if x > 0.0 { x } else { negative_slope * x }; + } +} + +/// Scalar Leaky ReLU for f64 +#[inline] +pub unsafe fn leaky_relu_scalar_f64(a: *const f64, out: *mut f64, len: usize, negative_slope: f64) { + for i in 0..len { + let x = *a.add(i); + *out.add(i) = if x > 0.0 { x } else { negative_slope * x }; + } +} + +/// Scalar ELU for f32 +#[inline] +pub unsafe fn elu_scalar_f32(a: *const f32, out: *mut f32, len: usize, alpha: f32) { + for i in 0..len { + let x = *a.add(i); + *out.add(i) = if x > 0.0 { x } else { alpha * (x.exp() - 1.0) }; + } +} + +/// Scalar ELU for f64 +#[inline] +pub unsafe fn elu_scalar_f64(a: *const f64, out: *mut f64, len: usize, alpha: f64) { + for i in 0..len { + let x = *a.add(i); + *out.add(i) = if x > 0.0 { x } else { alpha * (x.exp() - 1.0) }; + } +} + +// ============================================================================ +// f16/bf16 wrappers (block-convert through f32) +// ============================================================================ + +half_unary!(sigmoid, sigmoid_f32); +half_unary!(silu, silu_f32); +half_unary!(gelu, gelu_f32); +half_unary_param!(leaky_relu, leaky_relu_f32); +half_unary_param!(elu, elu_f32); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sigmoid_f32() { + let len = 128; + let input: Vec = (0..len).map(|x| (x as f32) / 32.0 - 2.0).collect(); + let mut out = vec![0.0f32; len]; + let mut out_ref = vec![0.0f32; len]; + + unsafe { + sigmoid_f32(input.as_ptr(), out.as_mut_ptr(), len); + sigmoid_scalar_f32(input.as_ptr(), out_ref.as_mut_ptr(), len); + } + + for i in 0..len { + let diff = (out[i] - out_ref[i]).abs(); + let rel_err = diff / out_ref[i].abs().max(1e-6); + assert!( + rel_err < 0.01, + "sigmoid mismatch at {}: {} vs {}", + i, + out[i], + out_ref[i] + ); + } + } + + #[test] + fn test_silu_f32() { + let len = 128; + let input: Vec = (0..len).map(|x| (x as f32) / 32.0 - 2.0).collect(); + let mut out = vec![0.0f32; len]; + let mut out_ref = vec![0.0f32; len]; + + unsafe { + silu_f32(input.as_ptr(), out.as_mut_ptr(), len); + silu_scalar_f32(input.as_ptr(), out_ref.as_mut_ptr(), len); + } + + for i in 0..len { + let diff = (out[i] - out_ref[i]).abs(); + let denom = out_ref[i].abs().max(1e-6); + let rel_err = diff / denom; + assert!( + rel_err < 0.01, + "silu mismatch at {}: {} vs {}", + i, + out[i], + out_ref[i] + ); + } + } + + #[test] + fn test_gelu_f32() { + let len = 128; + let input: Vec = (0..len).map(|x| (x as f32) / 32.0 - 2.0).collect(); + let mut out = vec![0.0f32; len]; + let mut out_ref = vec![0.0f32; len]; + + unsafe { + gelu_f32(input.as_ptr(), out.as_mut_ptr(), len); + gelu_scalar_f32(input.as_ptr(), out_ref.as_mut_ptr(), len); + } + + for i in 0..len { + let diff = (out[i] - out_ref[i]).abs(); + let denom = out_ref[i].abs().max(1e-6); + let rel_err = diff / denom; + assert!( + rel_err < 0.02, + "gelu mismatch at {}: {} vs {}", + i, + out[i], + out_ref[i] + ); + } + } + + #[test] + fn test_leaky_relu_f32() { + let len = 128; + let input: Vec = (0..len).map(|x| (x as f32) - 64.0).collect(); + let mut out = vec![0.0f32; len]; + let mut out_ref = vec![0.0f32; len]; + let negative_slope = 0.1f32; + + unsafe { + leaky_relu_f32(input.as_ptr(), out.as_mut_ptr(), len, negative_slope); + leaky_relu_scalar_f32(input.as_ptr(), out_ref.as_mut_ptr(), len, negative_slope); + } + + assert_eq!(out, out_ref); + } + + #[test] + fn test_elu_f32() { + let len = 128; + let input: Vec = (0..len).map(|x| (x as f32) / 32.0 - 2.0).collect(); + let mut out = vec![0.0f32; len]; + let mut out_ref = vec![0.0f32; len]; + let alpha = 1.0f32; + + unsafe { + elu_f32(input.as_ptr(), out.as_mut_ptr(), len, alpha); + elu_scalar_f32(input.as_ptr(), out_ref.as_mut_ptr(), len, alpha); + } + + for i in 0..len { + let diff = (out[i] - out_ref[i]).abs(); + let denom = out_ref[i].abs().max(1e-6); + let rel_err = diff / denom; + assert!( + rel_err < 0.01, + "elu mismatch at {}: {} vs {}", + i, + out[i], + out_ref[i] + ); + } + } +} diff --git a/src/runtime/cpu/kernels/simd/activations/mod.rs b/src/runtime/cpu/kernels/simd/activations/mod.rs index d9bf80c8..c2a294c2 100644 --- a/src/runtime/cpu/kernels/simd/activations/mod.rs +++ b/src/runtime/cpu/kernels/simd/activations/mod.rs @@ -1,558 +1,14 @@ -//! SIMD-accelerated activation functions +//! SIMD-accelerated activation functions. //! -//! Provides vectorized implementations of common neural network activations: -//! - Sigmoid: 1 / (1 + exp(-x)) -//! - SiLU (Swish): x * sigmoid(x) -//! - GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) -//! - Leaky ReLU: max(negative_slope * x, x) -//! - ELU: x if x > 0, else alpha * (exp(x) - 1) -//! -//! # SIMD Approach -//! -//! Uses polynomial approximations for exp and tanh: -//! - exp(x): Range reduction + Taylor series -//! - tanh(x): Based on exp via tanh(x) = (exp(2x) - 1) / (exp(2x) + 1) +//! See [`dispatch`] for the public dispatch functions and scalar fallbacks. #[cfg(target_arch = "x86_64")] -mod avx2; +pub(crate) mod avx2; #[cfg(target_arch = "x86_64")] -mod avx512; +pub(crate) mod avx512; +pub(crate) mod dispatch; #[cfg(target_arch = "aarch64")] -mod aarch64; - -use super::{SimdLevel, detect_simd}; - -/// Minimum length to justify SIMD overhead -const SIMD_THRESHOLD: usize = 32; - -/// SIMD sigmoid for f32 -/// -/// # Safety -/// - `a` and `out` must point to `len` elements -#[inline] -pub unsafe fn sigmoid_f32(a: *const f32, out: *mut f32, len: usize) { - let level = detect_simd(); - - if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { - sigmoid_scalar_f32(a, out, len); - return; - } - - #[cfg(target_arch = "x86_64")] - match level { - SimdLevel::Avx512 => avx512::sigmoid_f32(a, out, len), - SimdLevel::Avx2Fma => avx2::sigmoid_f32(a, out, len), - _ => sigmoid_scalar_f32(a, out, len), - } - - #[cfg(target_arch = "aarch64")] - match level { - SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::sigmoid_f32(a, out, len), - _ => sigmoid_scalar_f32(a, out, len), - } - - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] - sigmoid_scalar_f32(a, out, len); -} - -/// SIMD sigmoid for f64 -/// -/// # Safety -/// - `a` and `out` must point to `len` elements -#[inline] -pub unsafe fn sigmoid_f64(a: *const f64, out: *mut f64, len: usize) { - let level = detect_simd(); - - if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { - sigmoid_scalar_f64(a, out, len); - return; - } - - #[cfg(target_arch = "x86_64")] - match level { - SimdLevel::Avx512 => avx512::sigmoid_f64(a, out, len), - SimdLevel::Avx2Fma => avx2::sigmoid_f64(a, out, len), - _ => sigmoid_scalar_f64(a, out, len), - } - - #[cfg(target_arch = "aarch64")] - match level { - SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::sigmoid_f64(a, out, len), - _ => sigmoid_scalar_f64(a, out, len), - } - - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] - sigmoid_scalar_f64(a, out, len); -} - -/// SIMD SiLU for f32 -/// -/// # Safety -/// - `a` and `out` must point to `len` elements -#[inline] -pub unsafe fn silu_f32(a: *const f32, out: *mut f32, len: usize) { - let level = detect_simd(); - - if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { - silu_scalar_f32(a, out, len); - return; - } - - #[cfg(target_arch = "x86_64")] - match level { - SimdLevel::Avx512 => avx512::silu_f32(a, out, len), - SimdLevel::Avx2Fma => avx2::silu_f32(a, out, len), - _ => silu_scalar_f32(a, out, len), - } - - #[cfg(target_arch = "aarch64")] - match level { - SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::silu_f32(a, out, len), - _ => silu_scalar_f32(a, out, len), - } - - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] - silu_scalar_f32(a, out, len); -} - -/// SIMD SiLU for f64 -/// -/// # Safety -/// - `a` and `out` must point to `len` elements -#[inline] -pub unsafe fn silu_f64(a: *const f64, out: *mut f64, len: usize) { - let level = detect_simd(); - - if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { - silu_scalar_f64(a, out, len); - return; - } - - #[cfg(target_arch = "x86_64")] - match level { - SimdLevel::Avx512 => avx512::silu_f64(a, out, len), - SimdLevel::Avx2Fma => avx2::silu_f64(a, out, len), - _ => silu_scalar_f64(a, out, len), - } - - #[cfg(target_arch = "aarch64")] - match level { - SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::silu_f64(a, out, len), - _ => silu_scalar_f64(a, out, len), - } - - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] - silu_scalar_f64(a, out, len); -} - -/// SIMD GELU for f32 -/// -/// # Safety -/// - `a` and `out` must point to `len` elements -#[inline] -pub unsafe fn gelu_f32(a: *const f32, out: *mut f32, len: usize) { - let level = detect_simd(); - - if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { - gelu_scalar_f32(a, out, len); - return; - } - - #[cfg(target_arch = "x86_64")] - match level { - SimdLevel::Avx512 => avx512::gelu_f32(a, out, len), - SimdLevel::Avx2Fma => avx2::gelu_f32(a, out, len), - _ => gelu_scalar_f32(a, out, len), - } - - #[cfg(target_arch = "aarch64")] - match level { - SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::gelu_f32(a, out, len), - _ => gelu_scalar_f32(a, out, len), - } - - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] - gelu_scalar_f32(a, out, len); -} - -/// SIMD GELU for f64 -/// -/// # Safety -/// - `a` and `out` must point to `len` elements -#[inline] -pub unsafe fn gelu_f64(a: *const f64, out: *mut f64, len: usize) { - let level = detect_simd(); - - if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { - gelu_scalar_f64(a, out, len); - return; - } - - #[cfg(target_arch = "x86_64")] - match level { - SimdLevel::Avx512 => avx512::gelu_f64(a, out, len), - SimdLevel::Avx2Fma => avx2::gelu_f64(a, out, len), - _ => gelu_scalar_f64(a, out, len), - } - - #[cfg(target_arch = "aarch64")] - match level { - SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::gelu_f64(a, out, len), - _ => gelu_scalar_f64(a, out, len), - } - - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] - gelu_scalar_f64(a, out, len); -} - -/// SIMD Leaky ReLU for f32 -/// -/// # Safety -/// - `a` and `out` must point to `len` elements -#[inline] -pub unsafe fn leaky_relu_f32(a: *const f32, out: *mut f32, len: usize, negative_slope: f32) { - let level = detect_simd(); - - if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { - leaky_relu_scalar_f32(a, out, len, negative_slope); - return; - } - - #[cfg(target_arch = "x86_64")] - match level { - SimdLevel::Avx512 => avx512::leaky_relu_f32(a, out, len, negative_slope), - SimdLevel::Avx2Fma => avx2::leaky_relu_f32(a, out, len, negative_slope), - _ => leaky_relu_scalar_f32(a, out, len, negative_slope), - } - - #[cfg(target_arch = "aarch64")] - match level { - SimdLevel::Neon | SimdLevel::NeonFp16 => { - aarch64::neon::leaky_relu_f32(a, out, len, negative_slope) - } - _ => leaky_relu_scalar_f32(a, out, len, negative_slope), - } - - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] - leaky_relu_scalar_f32(a, out, len, negative_slope); -} - -/// SIMD Leaky ReLU for f64 -/// -/// # Safety -/// - `a` and `out` must point to `len` elements -#[inline] -pub unsafe fn leaky_relu_f64(a: *const f64, out: *mut f64, len: usize, negative_slope: f64) { - let level = detect_simd(); - - if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { - leaky_relu_scalar_f64(a, out, len, negative_slope); - return; - } - - #[cfg(target_arch = "x86_64")] - match level { - SimdLevel::Avx512 => avx512::leaky_relu_f64(a, out, len, negative_slope), - SimdLevel::Avx2Fma => avx2::leaky_relu_f64(a, out, len, negative_slope), - _ => leaky_relu_scalar_f64(a, out, len, negative_slope), - } - - #[cfg(target_arch = "aarch64")] - match level { - SimdLevel::Neon | SimdLevel::NeonFp16 => { - aarch64::neon::leaky_relu_f64(a, out, len, negative_slope) - } - _ => leaky_relu_scalar_f64(a, out, len, negative_slope), - } - - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] - leaky_relu_scalar_f64(a, out, len, negative_slope); -} - -/// SIMD ELU for f32 -/// -/// # Safety -/// - `a` and `out` must point to `len` elements -#[inline] -pub unsafe fn elu_f32(a: *const f32, out: *mut f32, len: usize, alpha: f32) { - let level = detect_simd(); - - if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { - elu_scalar_f32(a, out, len, alpha); - return; - } - - #[cfg(target_arch = "x86_64")] - match level { - SimdLevel::Avx512 => avx512::elu_f32(a, out, len, alpha), - SimdLevel::Avx2Fma => avx2::elu_f32(a, out, len, alpha), - _ => elu_scalar_f32(a, out, len, alpha), - } - - #[cfg(target_arch = "aarch64")] - match level { - SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::elu_f32(a, out, len, alpha), - _ => elu_scalar_f32(a, out, len, alpha), - } - - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] - elu_scalar_f32(a, out, len, alpha); -} - -/// SIMD ELU for f64 -/// -/// # Safety -/// - `a` and `out` must point to `len` elements -#[inline] -pub unsafe fn elu_f64(a: *const f64, out: *mut f64, len: usize, alpha: f64) { - let level = detect_simd(); - - if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { - elu_scalar_f64(a, out, len, alpha); - return; - } - - #[cfg(target_arch = "x86_64")] - match level { - SimdLevel::Avx512 => avx512::elu_f64(a, out, len, alpha), - SimdLevel::Avx2Fma => avx2::elu_f64(a, out, len, alpha), - _ => elu_scalar_f64(a, out, len, alpha), - } - - #[cfg(target_arch = "aarch64")] - match level { - SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::elu_f64(a, out, len, alpha), - _ => elu_scalar_f64(a, out, len, alpha), - } - - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] - elu_scalar_f64(a, out, len, alpha); -} - -// ============================================================================ -// Scalar fallbacks -// ============================================================================ - -/// Scalar sigmoid for f32 -#[inline] -pub unsafe fn sigmoid_scalar_f32(a: *const f32, out: *mut f32, len: usize) { - for i in 0..len { - let x = *a.add(i); - *out.add(i) = 1.0 / (1.0 + (-x).exp()); - } -} - -/// Scalar sigmoid for f64 -#[inline] -pub unsafe fn sigmoid_scalar_f64(a: *const f64, out: *mut f64, len: usize) { - for i in 0..len { - let x = *a.add(i); - *out.add(i) = 1.0 / (1.0 + (-x).exp()); - } -} - -/// Scalar SiLU for f32 -#[inline] -pub unsafe fn silu_scalar_f32(a: *const f32, out: *mut f32, len: usize) { - for i in 0..len { - let x = *a.add(i); - *out.add(i) = x / (1.0 + (-x).exp()); - } -} - -/// Scalar SiLU for f64 -#[inline] -pub unsafe fn silu_scalar_f64(a: *const f64, out: *mut f64, len: usize) { - for i in 0..len { - let x = *a.add(i); - *out.add(i) = x / (1.0 + (-x).exp()); - } -} - -/// Scalar GELU for f32 -#[inline] -pub unsafe fn gelu_scalar_f32(a: *const f32, out: *mut f32, len: usize) { - const SQRT_2_OVER_PI: f32 = 0.7978845608; // sqrt(2/pi) - const TANH_COEF: f32 = 0.044715; - - for i in 0..len { - let x = *a.add(i); - let inner = SQRT_2_OVER_PI * (x + TANH_COEF * x * x * x); - *out.add(i) = 0.5 * x * (1.0 + inner.tanh()); - } -} - -/// Scalar GELU for f64 -#[inline] -pub unsafe fn gelu_scalar_f64(a: *const f64, out: *mut f64, len: usize) { - const SQRT_2_OVER_PI: f64 = 0.7978845608028654; // sqrt(2/pi) - const TANH_COEF: f64 = 0.044715; - - for i in 0..len { - let x = *a.add(i); - let inner = SQRT_2_OVER_PI * (x + TANH_COEF * x * x * x); - *out.add(i) = 0.5 * x * (1.0 + inner.tanh()); - } -} - -/// Scalar Leaky ReLU for f32 -#[inline] -pub unsafe fn leaky_relu_scalar_f32(a: *const f32, out: *mut f32, len: usize, negative_slope: f32) { - for i in 0..len { - let x = *a.add(i); - *out.add(i) = if x > 0.0 { x } else { negative_slope * x }; - } -} - -/// Scalar Leaky ReLU for f64 -#[inline] -pub unsafe fn leaky_relu_scalar_f64(a: *const f64, out: *mut f64, len: usize, negative_slope: f64) { - for i in 0..len { - let x = *a.add(i); - *out.add(i) = if x > 0.0 { x } else { negative_slope * x }; - } -} - -/// Scalar ELU for f32 -#[inline] -pub unsafe fn elu_scalar_f32(a: *const f32, out: *mut f32, len: usize, alpha: f32) { - for i in 0..len { - let x = *a.add(i); - *out.add(i) = if x > 0.0 { x } else { alpha * (x.exp() - 1.0) }; - } -} - -/// Scalar ELU for f64 -#[inline] -pub unsafe fn elu_scalar_f64(a: *const f64, out: *mut f64, len: usize, alpha: f64) { - for i in 0..len { - let x = *a.add(i); - *out.add(i) = if x > 0.0 { x } else { alpha * (x.exp() - 1.0) }; - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_sigmoid_f32() { - let len = 128; - let input: Vec = (0..len).map(|x| (x as f32) / 32.0 - 2.0).collect(); - let mut out = vec![0.0f32; len]; - let mut out_ref = vec![0.0f32; len]; - - unsafe { - sigmoid_f32(input.as_ptr(), out.as_mut_ptr(), len); - sigmoid_scalar_f32(input.as_ptr(), out_ref.as_mut_ptr(), len); - } - - for i in 0..len { - let diff = (out[i] - out_ref[i]).abs(); - let rel_err = diff / out_ref[i].abs().max(1e-6); - assert!( - rel_err < 0.01, - "sigmoid mismatch at {}: {} vs {}", - i, - out[i], - out_ref[i] - ); - } - } - - #[test] - fn test_silu_f32() { - let len = 128; - let input: Vec = (0..len).map(|x| (x as f32) / 32.0 - 2.0).collect(); - let mut out = vec![0.0f32; len]; - let mut out_ref = vec![0.0f32; len]; - - unsafe { - silu_f32(input.as_ptr(), out.as_mut_ptr(), len); - silu_scalar_f32(input.as_ptr(), out_ref.as_mut_ptr(), len); - } - - for i in 0..len { - let diff = (out[i] - out_ref[i]).abs(); - let denom = out_ref[i].abs().max(1e-6); - let rel_err = diff / denom; - assert!( - rel_err < 0.01, - "silu mismatch at {}: {} vs {}", - i, - out[i], - out_ref[i] - ); - } - } - - #[test] - fn test_gelu_f32() { - let len = 128; - let input: Vec = (0..len).map(|x| (x as f32) / 32.0 - 2.0).collect(); - let mut out = vec![0.0f32; len]; - let mut out_ref = vec![0.0f32; len]; - - unsafe { - gelu_f32(input.as_ptr(), out.as_mut_ptr(), len); - gelu_scalar_f32(input.as_ptr(), out_ref.as_mut_ptr(), len); - } - - for i in 0..len { - let diff = (out[i] - out_ref[i]).abs(); - let denom = out_ref[i].abs().max(1e-6); - let rel_err = diff / denom; - assert!( - rel_err < 0.02, - "gelu mismatch at {}: {} vs {}", - i, - out[i], - out_ref[i] - ); - } - } - - #[test] - fn test_leaky_relu_f32() { - let len = 128; - let input: Vec = (0..len).map(|x| (x as f32) - 64.0).collect(); - let mut out = vec![0.0f32; len]; - let mut out_ref = vec![0.0f32; len]; - let negative_slope = 0.1f32; - - unsafe { - leaky_relu_f32(input.as_ptr(), out.as_mut_ptr(), len, negative_slope); - leaky_relu_scalar_f32(input.as_ptr(), out_ref.as_mut_ptr(), len, negative_slope); - } - - assert_eq!(out, out_ref); - } - - #[test] - fn test_elu_f32() { - let len = 128; - let input: Vec = (0..len).map(|x| (x as f32) / 32.0 - 2.0).collect(); - let mut out = vec![0.0f32; len]; - let mut out_ref = vec![0.0f32; len]; - let alpha = 1.0f32; - - unsafe { - elu_f32(input.as_ptr(), out.as_mut_ptr(), len, alpha); - elu_scalar_f32(input.as_ptr(), out_ref.as_mut_ptr(), len, alpha); - } +pub(crate) mod aarch64; - for i in 0..len { - let diff = (out[i] - out_ref[i]).abs(); - let denom = out_ref[i].abs().max(1e-6); - let rel_err = diff / denom; - assert!( - rel_err < 0.01, - "elu mismatch at {}: {} vs {}", - i, - out[i], - out_ref[i] - ); - } - } -} +pub use dispatch::*; diff --git a/src/runtime/cpu/kernels/simd/binary/aarch64/mod.rs b/src/runtime/cpu/kernels/simd/binary/aarch64/mod.rs index 6069fa8d..1b5d22a5 100644 --- a/src/runtime/cpu/kernels/simd/binary/aarch64/mod.rs +++ b/src/runtime/cpu/kernels/simd/binary/aarch64/mod.rs @@ -1,3 +1,4 @@ //! ARM64 SIMD implementations for binary operations pub mod neon; +pub mod neon_int; diff --git a/src/runtime/cpu/kernels/simd/binary/aarch64/neon_int.rs b/src/runtime/cpu/kernels/simd/binary/aarch64/neon_int.rs new file mode 100644 index 00000000..0f584433 --- /dev/null +++ b/src/runtime/cpu/kernels/simd/binary/aarch64/neon_int.rs @@ -0,0 +1,106 @@ +//! NEON binary operation kernels for i32 on ARM64 +//! +//! Processes 4 i32s per iteration using 128-bit vectors. + +#[cfg(target_arch = "aarch64")] +use std::arch::aarch64::*; + +use super::super::binary_scalar_i32; +use crate::ops::BinaryOp; + +const I32_LANES: usize = 4; + +/// NEON binary operation for i32 +/// +/// # Safety +/// - CPU must support NEON (always true on AArch64) +/// - All pointers must be valid for `len` elements +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +pub unsafe fn binary_i32(op: BinaryOp, a: *const i32, b: *const i32, out: *mut i32, len: usize) { + let chunks = len / I32_LANES; + let remainder = len % I32_LANES; + + // Ops without SIMD integer support + if !matches!( + op, + BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Max | BinaryOp::Min + ) { + binary_scalar_i32(op, a, b, out, len); + return; + } + + match op { + BinaryOp::Add => binary_add_i32(a, b, out, chunks), + BinaryOp::Sub => binary_sub_i32(a, b, out, chunks), + BinaryOp::Mul => binary_mul_i32(a, b, out, chunks), + BinaryOp::Max => binary_max_i32(a, b, out, chunks), + BinaryOp::Min => binary_min_i32(a, b, out, chunks), + _ => unreachable!(), + } + + if remainder > 0 { + let offset = chunks * I32_LANES; + binary_scalar_i32(op, a.add(offset), b.add(offset), out.add(offset), remainder); + } +} + +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +unsafe fn binary_add_i32(a: *const i32, b: *const i32, out: *mut i32, chunks: usize) { + for i in 0..chunks { + let offset = i * I32_LANES; + let va = vld1q_s32(a.add(offset)); + let vb = vld1q_s32(b.add(offset)); + let vr = vaddq_s32(va, vb); + vst1q_s32(out.add(offset), vr); + } +} + +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +unsafe fn binary_sub_i32(a: *const i32, b: *const i32, out: *mut i32, chunks: usize) { + for i in 0..chunks { + let offset = i * I32_LANES; + let va = vld1q_s32(a.add(offset)); + let vb = vld1q_s32(b.add(offset)); + let vr = vsubq_s32(va, vb); + vst1q_s32(out.add(offset), vr); + } +} + +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +unsafe fn binary_mul_i32(a: *const i32, b: *const i32, out: *mut i32, chunks: usize) { + for i in 0..chunks { + let offset = i * I32_LANES; + let va = vld1q_s32(a.add(offset)); + let vb = vld1q_s32(b.add(offset)); + let vr = vmulq_s32(va, vb); + vst1q_s32(out.add(offset), vr); + } +} + +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +unsafe fn binary_max_i32(a: *const i32, b: *const i32, out: *mut i32, chunks: usize) { + for i in 0..chunks { + let offset = i * I32_LANES; + let va = vld1q_s32(a.add(offset)); + let vb = vld1q_s32(b.add(offset)); + let vr = vmaxq_s32(va, vb); + vst1q_s32(out.add(offset), vr); + } +} + +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +unsafe fn binary_min_i32(a: *const i32, b: *const i32, out: *mut i32, chunks: usize) { + for i in 0..chunks { + let offset = i * I32_LANES; + let va = vld1q_s32(a.add(offset)); + let vb = vld1q_s32(b.add(offset)); + let vr = vminq_s32(va, vb); + vst1q_s32(out.add(offset), vr); + } +} diff --git a/src/runtime/cpu/kernels/simd/binary/dispatch.rs b/src/runtime/cpu/kernels/simd/binary/dispatch.rs new file mode 100644 index 00000000..ab83d3b6 --- /dev/null +++ b/src/runtime/cpu/kernels/simd/binary/dispatch.rs @@ -0,0 +1,507 @@ +//! SIMD-accelerated binary operation dispatch. +//! +//! This module provides multi-architecture SIMD implementations for element-wise +//! binary operations (add, sub, mul, div, max, min, pow). +//! +//! # Architecture Support +//! +//! | Architecture | Instruction Set | Vector Width | f32 lanes | f64 lanes | +//! |--------------|-----------------|--------------|-----------|-----------| +//! | x86-64 | AVX-512 | 512 bits | 16 | 8 | +//! | x86-64 | AVX2 + FMA | 256 bits | 8 | 4 | +//! | ARM64 | NEON | 128 bits | 4 | 2 | + +#[cfg(target_arch = "aarch64")] +use super::aarch64; +#[cfg(target_arch = "x86_64")] +use super::x86_64; + +use crate::ops::BinaryOp; +use crate::runtime::cpu::kernels::simd::{SimdLevel, detect_simd}; + +// Import scalar fallbacks from kernels module (single source of truth) +pub use crate::runtime::cpu::kernels::binary::{ + binary_scalar_f32, binary_scalar_f64, binary_scalar_i32, +}; + +/// Minimum elements to justify SIMD overhead +const SIMD_THRESHOLD: usize = 32; + +/// SIMD binary operation for f32 +/// +/// # Safety +/// - `a`, `b`, and `out` must be valid pointers to `len` elements +#[inline] +pub unsafe fn binary_f32(op: BinaryOp, a: *const f32, b: *const f32, out: *mut f32, len: usize) { + let level = detect_simd(); + + if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { + binary_scalar_f32(op, a, b, out, len); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => x86_64::avx512::binary_f32(op, a, b, out, len), + SimdLevel::Avx2Fma => x86_64::avx2::binary_f32(op, a, b, out, len), + _ => binary_scalar_f32(op, a, b, out, len), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::binary_f32(op, a, b, out, len), + _ => binary_scalar_f32(op, a, b, out, len), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + binary_scalar_f32(op, a, b, out, len); +} + +/// SIMD binary operation for f64 +/// +/// # Safety +/// - `a`, `b`, and `out` must be valid pointers to `len` elements +#[inline] +pub unsafe fn binary_f64(op: BinaryOp, a: *const f64, b: *const f64, out: *mut f64, len: usize) { + let level = detect_simd(); + + if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { + binary_scalar_f64(op, a, b, out, len); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => x86_64::avx512::binary_f64(op, a, b, out, len), + SimdLevel::Avx2Fma => x86_64::avx2::binary_f64(op, a, b, out, len), + _ => binary_scalar_f64(op, a, b, out, len), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::binary_f64(op, a, b, out, len), + _ => binary_scalar_f64(op, a, b, out, len), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + binary_scalar_f64(op, a, b, out, len); +} + +/// SIMD binary operation for i32 +/// +/// # Safety +/// - `a`, `b`, and `out` must be valid pointers to `len` elements +#[inline] +pub unsafe fn binary_i32(op: BinaryOp, a: *const i32, b: *const i32, out: *mut i32, len: usize) { + let level = detect_simd(); + + if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { + binary_scalar_i32(op, a, b, out, len); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => x86_64::avx512_int::binary_i32(op, a, b, out, len), + SimdLevel::Avx2Fma => x86_64::avx2_int::binary_i32(op, a, b, out, len), + _ => binary_scalar_i32(op, a, b, out, len), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon_int::binary_i32(op, a, b, out, len), + _ => binary_scalar_i32(op, a, b, out, len), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + binary_scalar_i32(op, a, b, out, len); +} + +half_binary_op!(binary, binary_f32, BinaryOp); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_binary_add_f32() { + let a: Vec = (0..100).map(|x| x as f32).collect(); + let b: Vec = (0..100).map(|x| (x * 2) as f32).collect(); + let mut out = vec![0.0f32; 100]; + + unsafe { binary_f32(BinaryOp::Add, a.as_ptr(), b.as_ptr(), out.as_mut_ptr(), 100) } + + for i in 0..100 { + assert_eq!(out[i], a[i] + b[i], "mismatch at index {}", i); + } + } + + #[test] + fn test_binary_mul_f64() { + let a: Vec = (1..101).map(|x| x as f64).collect(); + let b: Vec = (1..101).map(|x| (x * 2) as f64).collect(); + let mut out = vec![0.0f64; 100]; + + unsafe { binary_f64(BinaryOp::Mul, a.as_ptr(), b.as_ptr(), out.as_mut_ptr(), 100) } + + for i in 0..100 { + assert_eq!(out[i], a[i] * b[i], "mismatch at index {}", i); + } + } + + #[test] + fn test_small_array_uses_scalar() { + let a = [1.0f32, 2.0, 3.0, 4.0]; + let b = [5.0f32, 6.0, 7.0, 8.0]; + let mut out = [0.0f32; 4]; + + unsafe { binary_f32(BinaryOp::Add, a.as_ptr(), b.as_ptr(), out.as_mut_ptr(), 4) } + + assert_eq!(out, [6.0, 8.0, 10.0, 12.0]); + } + + #[test] + fn test_non_aligned_length() { + let a: Vec = (0..67).map(|x| x as f32).collect(); + let b: Vec = (0..67).map(|x| (x * 2) as f32).collect(); + let mut out = vec![0.0f32; 67]; + + unsafe { binary_f32(BinaryOp::Add, a.as_ptr(), b.as_ptr(), out.as_mut_ptr(), 67) } + + for i in 0..67 { + assert_eq!(out[i], a[i] + b[i], "mismatch at index {}", i); + } + } + + #[test] + fn test_binary_pow_f32() { + let a: Vec = (1..101).map(|x| x as f32 * 0.1).collect(); + let b: Vec = (1..101).map(|x| (x % 5) as f32 + 0.5).collect(); + let mut out = vec![0.0f32; 100]; + + unsafe { binary_f32(BinaryOp::Pow, a.as_ptr(), b.as_ptr(), out.as_mut_ptr(), 100) } + + for i in 0..100 { + let expected = a[i].powf(b[i]); + let diff = (out[i] - expected).abs(); + assert!( + diff < 1e-3 * expected.abs().max(1.0), + "pow mismatch at {}: got {}, expected {} (a={}, b={})", + i, + out[i], + expected, + a[i], + b[i] + ); + } + } + + #[test] + fn test_binary_pow_f64() { + let a: Vec = (1..101).map(|x| x as f64 * 0.1).collect(); + let b: Vec = (1..101).map(|x| (x % 5) as f64 + 0.5).collect(); + let mut out = vec![0.0f64; 100]; + + unsafe { binary_f64(BinaryOp::Pow, a.as_ptr(), b.as_ptr(), out.as_mut_ptr(), 100) } + + for i in 0..100 { + let expected = a[i].powf(b[i]); + let diff = (out[i] - expected).abs(); + assert!( + diff < 1e-4 * expected.abs().max(1.0), + "pow mismatch at {}: got {}, expected {} (a={}, b={})", + i, + out[i], + expected, + a[i], + b[i] + ); + } + } + + #[test] + fn test_binary_max_min_f32() { + let a: Vec = (0..100).map(|x| (x as f32 - 50.0) * 0.5).collect(); + let b: Vec = (0..100).map(|x| ((x + 25) as f32 - 50.0) * 0.5).collect(); + let mut out_max = vec![0.0f32; 100]; + let mut out_min = vec![0.0f32; 100]; + + unsafe { + binary_f32( + BinaryOp::Max, + a.as_ptr(), + b.as_ptr(), + out_max.as_mut_ptr(), + 100, + ); + binary_f32( + BinaryOp::Min, + a.as_ptr(), + b.as_ptr(), + out_min.as_mut_ptr(), + 100, + ); + } + + for i in 0..100 { + assert_eq!(out_max[i], a[i].max(b[i]), "max mismatch at {}", i); + assert_eq!(out_min[i], a[i].min(b[i]), "min mismatch at {}", i); + } + } + + #[test] + fn test_binary_sub_div_f32() { + let a: Vec = (1..101).map(|x| x as f32 * 2.0).collect(); + let b: Vec = (1..101).map(|x| x as f32).collect(); + let mut out_sub = vec![0.0f32; 100]; + let mut out_div = vec![0.0f32; 100]; + + unsafe { + binary_f32( + BinaryOp::Sub, + a.as_ptr(), + b.as_ptr(), + out_sub.as_mut_ptr(), + 100, + ); + binary_f32( + BinaryOp::Div, + a.as_ptr(), + b.as_ptr(), + out_div.as_mut_ptr(), + 100, + ); + } + + for i in 0..100 { + assert_eq!(out_sub[i], a[i] - b[i], "sub mismatch at {}", i); + assert_eq!(out_div[i], a[i] / b[i], "div mismatch at {}", i); + } + } + + // ============================================================================ + // Streaming store tests (x86-64 only) + // ============================================================================ + + #[cfg(target_arch = "x86_64")] + mod streaming_tests { + use super::super::super::super::streaming::{ + STREAMING_THRESHOLD_F32, STREAMING_THRESHOLD_F64, + }; + + /// Test streaming threshold constant is correctly defined + #[test] + fn test_streaming_threshold_defined() { + // 1MB = 262144 f32s, 131072 f64s + assert_eq!(STREAMING_THRESHOLD_F32, 262144); + assert_eq!(STREAMING_THRESHOLD_F64, 131072); + } + } + + /// Test that large arrays produce correct results (exercises streaming path if aligned) + #[test] + fn test_large_array_correctness_f32() { + const LEN: usize = 1024; + let a: Vec = (0..LEN).map(|x| (x as f32) * 0.1).collect(); + let b: Vec = (0..LEN).map(|x| (x as f32) * 0.2 + 1.0).collect(); + let mut out = vec![0.0f32; LEN]; + + unsafe { binary_f32(BinaryOp::Add, a.as_ptr(), b.as_ptr(), out.as_mut_ptr(), LEN) } + + for i in 0..LEN { + let expected = a[i] + b[i]; + assert!( + (out[i] - expected).abs() < 1e-6, + "large array mismatch at {}: got {}, expected {}", + i, + out[i], + expected + ); + } + } + + /// Test that large arrays produce correct results for all operations + #[test] + fn test_large_array_all_ops_f32() { + const LEN: usize = 512; + let a: Vec = (1..=LEN as i32).map(|x| x as f32).collect(); + let b: Vec = (1..=LEN as i32).map(|x| (x as f32) * 0.5 + 0.5).collect(); + + for op in [ + BinaryOp::Add, + BinaryOp::Sub, + BinaryOp::Mul, + BinaryOp::Div, + BinaryOp::Max, + BinaryOp::Min, + ] { + let mut out = vec![0.0f32; LEN]; + unsafe { binary_f32(op, a.as_ptr(), b.as_ptr(), out.as_mut_ptr(), LEN) } + + for i in 0..LEN { + let expected = match op { + BinaryOp::Add => a[i] + b[i], + BinaryOp::Sub => a[i] - b[i], + BinaryOp::Mul => a[i] * b[i], + BinaryOp::Div => a[i] / b[i], + BinaryOp::Max => a[i].max(b[i]), + BinaryOp::Min => a[i].min(b[i]), + BinaryOp::Pow => a[i].powf(b[i]), + BinaryOp::Atan2 => a[i].atan2(b[i]), + }; + assert!( + (out[i] - expected).abs() < 1e-5 * expected.abs().max(1.0), + "{:?} mismatch at {}: got {}, expected {}", + op, + i, + out[i], + expected + ); + } + } + } + + #[test] + fn test_binary_add_i32() { + let a: Vec = (0..100).collect(); + let b: Vec = (0..100).map(|x| x * 2).collect(); + let mut out = vec![0i32; 100]; + + unsafe { binary_i32(BinaryOp::Add, a.as_ptr(), b.as_ptr(), out.as_mut_ptr(), 100) } + + for i in 0..100 { + assert_eq!(out[i], a[i] + b[i], "i32 add mismatch at index {}", i); + } + } + + #[test] + fn test_binary_all_ops_i32() { + let a: Vec = (1..101).collect(); + let b: Vec = (1..101).map(|x| x * 2 + 1).collect(); + + for op in [ + BinaryOp::Add, + BinaryOp::Sub, + BinaryOp::Mul, + BinaryOp::Max, + BinaryOp::Min, + ] { + let mut out = vec![0i32; 100]; + unsafe { binary_i32(op, a.as_ptr(), b.as_ptr(), out.as_mut_ptr(), 100) } + + for i in 0..100 { + let expected = match op { + BinaryOp::Add => a[i] + b[i], + BinaryOp::Sub => a[i] - b[i], + BinaryOp::Mul => a[i] * b[i], + BinaryOp::Max => a[i].max(b[i]), + BinaryOp::Min => a[i].min(b[i]), + _ => unreachable!(), + }; + assert_eq!(out[i], expected, "{:?} i32 mismatch at {}", op, i); + } + } + } + + #[test] + fn test_binary_i32_non_aligned_length() { + let a: Vec = (0..67).collect(); + let b: Vec = (0..67).map(|x| x * 3).collect(); + let mut out = vec![0i32; 67]; + + unsafe { binary_i32(BinaryOp::Add, a.as_ptr(), b.as_ptr(), out.as_mut_ptr(), 67) } + + for i in 0..67 { + assert_eq!(out[i], a[i] + b[i], "i32 add tail mismatch at index {}", i); + } + } + + #[test] + fn test_binary_i32_small_array() { + let a = [1i32, 2, 3, 4]; + let b = [5i32, 6, 7, 8]; + let mut out = [0i32; 4]; + + unsafe { binary_i32(BinaryOp::Add, a.as_ptr(), b.as_ptr(), out.as_mut_ptr(), 4) } + + assert_eq!(out, [6, 8, 10, 12]); + } + + #[test] + fn test_binary_div_i32() { + let a: Vec = (1..101).collect(); + let b: Vec = (1..101).map(|x| x * 2 + 1).collect(); + let mut out = vec![0i32; 100]; + + unsafe { binary_i32(BinaryOp::Div, a.as_ptr(), b.as_ptr(), out.as_mut_ptr(), 100) } + + for i in 0..100 { + assert_eq!(out[i], a[i] / b[i], "div mismatch at {}", i); + } + } + + #[test] + fn test_binary_div_i32_by_zero() { + let a = [10i32, 20, 0, 30, -5, 100, i32::MAX, i32::MIN]; + let b = [0i32, 2, 5, 0, 0, -3, 0, 0]; + let mut out = [0i32; 8]; + + unsafe { binary_i32(BinaryOp::Div, a.as_ptr(), b.as_ptr(), out.as_mut_ptr(), 8) } + + assert_eq!(out[0], 0, "10 / 0 should be 0"); + assert_eq!(out[1], 10, "20 / 2 should be 10"); + assert_eq!(out[2], 0, "0 / 5 should be 0"); + assert_eq!(out[3], 0, "30 / 0 should be 0"); + assert_eq!(out[4], 0, "-5 / 0 should be 0"); + assert_eq!(out[5], -33, "100 / -3 should be -33"); + assert_eq!(out[6], 0, "i32::MAX / 0 should be 0"); + assert_eq!(out[7], 0, "i32::MIN / 0 should be 0"); + } + + #[test] + fn test_binary_pow_i32() { + let a = [2i32, 3, 10, 0, -2, 1, 5, 100]; + let b = [10i32, 5, 3, 5, 3, 100, 0, 1]; + let mut out = [0i32; 8]; + + unsafe { binary_i32(BinaryOp::Pow, a.as_ptr(), b.as_ptr(), out.as_mut_ptr(), 8) } + + assert_eq!(out[0], 1024, "2^10"); + assert_eq!(out[1], 243, "3^5"); + assert_eq!(out[2], 1000, "10^3"); + assert_eq!(out[3], 0, "0^5"); + assert_eq!(out[4], -8, "(-2)^3"); + assert_eq!(out[5], 1, "1^100"); + assert_eq!(out[6], 1, "5^0"); + assert_eq!(out[7], 100, "100^1"); + } + + #[test] + fn test_binary_atan2_i32() { + let a = [0i32, 1, -1, 10, 0, 100]; + let b = [1i32, 0, 0, 10, 0, 1]; + let mut out = [0i32; 6]; + + unsafe { binary_i32(BinaryOp::Atan2, a.as_ptr(), b.as_ptr(), out.as_mut_ptr(), 6) } + + assert_eq!(out[0], 0, "atan2(0,1) = 0"); + assert_eq!(out[1], 1, "atan2(1,0) truncates to 1"); + assert_eq!(out[2], -1, "atan2(-1,0) truncates to -1"); + assert_eq!(out[3], 0, "atan2(10,10) truncates to 0"); + } + + /// Test alignment check functions (x86-64 only) + #[cfg(target_arch = "x86_64")] + #[test] + fn test_alignment_checks() { + use crate::runtime::cpu::kernels::simd::streaming::{is_aligned_avx2, is_aligned_avx512}; + + assert!(is_aligned_avx2(32 as *const f32)); + assert!(is_aligned_avx2(64 as *const f32)); + assert!(!is_aligned_avx2(16 as *const f32)); + + assert!(is_aligned_avx512(64 as *const f32)); + assert!(is_aligned_avx512(128 as *const f32)); + assert!(!is_aligned_avx512(32 as *const f32)); + } +} diff --git a/src/runtime/cpu/kernels/simd/binary/mod.rs b/src/runtime/cpu/kernels/simd/binary/mod.rs index af6afb27..21046414 100644 --- a/src/runtime/cpu/kernels/simd/binary/mod.rs +++ b/src/runtime/cpu/kernels/simd/binary/mod.rs @@ -1,347 +1,12 @@ -//! SIMD-accelerated binary operations +//! SIMD-accelerated binary operations. //! -//! This module provides multi-architecture SIMD implementations for element-wise -//! binary operations (add, sub, mul, div, max, min, pow). -//! -//! # Architecture Support -//! -//! | Architecture | Instruction Set | Vector Width | f32 lanes | f64 lanes | -//! |--------------|-----------------|--------------|-----------|-----------| -//! | x86-64 | AVX-512 | 512 bits | 16 | 8 | -//! | x86-64 | AVX2 + FMA | 256 bits | 8 | 4 | -//! | ARM64 | NEON | 128 bits | 4 | 2 | +//! See [`dispatch`] for the public dispatch functions. #[cfg(target_arch = "aarch64")] -mod aarch64; +pub(crate) mod aarch64; #[cfg(target_arch = "x86_64")] -mod x86_64; - -use super::{SimdLevel, detect_simd}; -use crate::ops::BinaryOp; - -// Import scalar fallbacks from kernels module (single source of truth) -pub use crate::runtime::cpu::kernels::binary::{binary_scalar_f32, binary_scalar_f64}; - -/// Minimum elements to justify SIMD overhead -const SIMD_THRESHOLD: usize = 32; - -/// SIMD binary operation for f32 -/// -/// # Safety -/// - `a`, `b`, and `out` must be valid pointers to `len` elements -#[inline] -pub unsafe fn binary_f32(op: BinaryOp, a: *const f32, b: *const f32, out: *mut f32, len: usize) { - let level = detect_simd(); - - if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { - binary_scalar_f32(op, a, b, out, len); - return; - } - - #[cfg(target_arch = "x86_64")] - match level { - SimdLevel::Avx512 => x86_64::avx512::binary_f32(op, a, b, out, len), - SimdLevel::Avx2Fma => x86_64::avx2::binary_f32(op, a, b, out, len), - _ => binary_scalar_f32(op, a, b, out, len), - } - - #[cfg(target_arch = "aarch64")] - match level { - SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::binary_f32(op, a, b, out, len), - _ => binary_scalar_f32(op, a, b, out, len), - } - - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] - binary_scalar_f32(op, a, b, out, len); -} - -/// SIMD binary operation for f64 -/// -/// # Safety -/// - `a`, `b`, and `out` must be valid pointers to `len` elements -#[inline] -pub unsafe fn binary_f64(op: BinaryOp, a: *const f64, b: *const f64, out: *mut f64, len: usize) { - let level = detect_simd(); - - if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { - binary_scalar_f64(op, a, b, out, len); - return; - } - - #[cfg(target_arch = "x86_64")] - match level { - SimdLevel::Avx512 => x86_64::avx512::binary_f64(op, a, b, out, len), - SimdLevel::Avx2Fma => x86_64::avx2::binary_f64(op, a, b, out, len), - _ => binary_scalar_f64(op, a, b, out, len), - } - - #[cfg(target_arch = "aarch64")] - match level { - SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::binary_f64(op, a, b, out, len), - _ => binary_scalar_f64(op, a, b, out, len), - } - - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] - binary_scalar_f64(op, a, b, out, len); -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_binary_add_f32() { - let a: Vec = (0..100).map(|x| x as f32).collect(); - let b: Vec = (0..100).map(|x| (x * 2) as f32).collect(); - let mut out = vec![0.0f32; 100]; - - unsafe { binary_f32(BinaryOp::Add, a.as_ptr(), b.as_ptr(), out.as_mut_ptr(), 100) } - - for i in 0..100 { - assert_eq!(out[i], a[i] + b[i], "mismatch at index {}", i); - } - } - - #[test] - fn test_binary_mul_f64() { - let a: Vec = (1..101).map(|x| x as f64).collect(); - let b: Vec = (1..101).map(|x| (x * 2) as f64).collect(); - let mut out = vec![0.0f64; 100]; - - unsafe { binary_f64(BinaryOp::Mul, a.as_ptr(), b.as_ptr(), out.as_mut_ptr(), 100) } - - for i in 0..100 { - assert_eq!(out[i], a[i] * b[i], "mismatch at index {}", i); - } - } - - #[test] - fn test_small_array_uses_scalar() { - let a = [1.0f32, 2.0, 3.0, 4.0]; - let b = [5.0f32, 6.0, 7.0, 8.0]; - let mut out = [0.0f32; 4]; - - unsafe { binary_f32(BinaryOp::Add, a.as_ptr(), b.as_ptr(), out.as_mut_ptr(), 4) } - - assert_eq!(out, [6.0, 8.0, 10.0, 12.0]); - } - - #[test] - fn test_non_aligned_length() { - let a: Vec = (0..67).map(|x| x as f32).collect(); - let b: Vec = (0..67).map(|x| (x * 2) as f32).collect(); - let mut out = vec![0.0f32; 67]; - - unsafe { binary_f32(BinaryOp::Add, a.as_ptr(), b.as_ptr(), out.as_mut_ptr(), 67) } - - for i in 0..67 { - assert_eq!(out[i], a[i] + b[i], "mismatch at index {}", i); - } - } - - #[test] - fn test_binary_pow_f32() { - let a: Vec = (1..101).map(|x| x as f32 * 0.1).collect(); - let b: Vec = (1..101).map(|x| (x % 5) as f32 + 0.5).collect(); - let mut out = vec![0.0f32; 100]; - - unsafe { binary_f32(BinaryOp::Pow, a.as_ptr(), b.as_ptr(), out.as_mut_ptr(), 100) } - - for i in 0..100 { - let expected = a[i].powf(b[i]); - let diff = (out[i] - expected).abs(); - // pow uses exp(b*log(a)), so errors compound - ~1e-3 relative error is acceptable - assert!( - diff < 1e-3 * expected.abs().max(1.0), - "pow mismatch at {}: got {}, expected {} (a={}, b={})", - i, - out[i], - expected, - a[i], - b[i] - ); - } - } - - #[test] - fn test_binary_pow_f64() { - let a: Vec = (1..101).map(|x| x as f64 * 0.1).collect(); - let b: Vec = (1..101).map(|x| (x % 5) as f64 + 0.5).collect(); - let mut out = vec![0.0f64; 100]; - - unsafe { binary_f64(BinaryOp::Pow, a.as_ptr(), b.as_ptr(), out.as_mut_ptr(), 100) } - - for i in 0..100 { - let expected = a[i].powf(b[i]); - let diff = (out[i] - expected).abs(); - // pow uses exp(b*log(a)), so errors compound - ~1e-4 relative error is acceptable - assert!( - diff < 1e-4 * expected.abs().max(1.0), - "pow mismatch at {}: got {}, expected {} (a={}, b={})", - i, - out[i], - expected, - a[i], - b[i] - ); - } - } - - #[test] - fn test_binary_max_min_f32() { - let a: Vec = (0..100).map(|x| (x as f32 - 50.0) * 0.5).collect(); - let b: Vec = (0..100).map(|x| ((x + 25) as f32 - 50.0) * 0.5).collect(); - let mut out_max = vec![0.0f32; 100]; - let mut out_min = vec![0.0f32; 100]; - - unsafe { - binary_f32( - BinaryOp::Max, - a.as_ptr(), - b.as_ptr(), - out_max.as_mut_ptr(), - 100, - ); - binary_f32( - BinaryOp::Min, - a.as_ptr(), - b.as_ptr(), - out_min.as_mut_ptr(), - 100, - ); - } - - for i in 0..100 { - assert_eq!(out_max[i], a[i].max(b[i]), "max mismatch at {}", i); - assert_eq!(out_min[i], a[i].min(b[i]), "min mismatch at {}", i); - } - } - - #[test] - fn test_binary_sub_div_f32() { - let a: Vec = (1..101).map(|x| x as f32 * 2.0).collect(); - let b: Vec = (1..101).map(|x| x as f32).collect(); - let mut out_sub = vec![0.0f32; 100]; - let mut out_div = vec![0.0f32; 100]; - - unsafe { - binary_f32( - BinaryOp::Sub, - a.as_ptr(), - b.as_ptr(), - out_sub.as_mut_ptr(), - 100, - ); - binary_f32( - BinaryOp::Div, - a.as_ptr(), - b.as_ptr(), - out_div.as_mut_ptr(), - 100, - ); - } - - for i in 0..100 { - assert_eq!(out_sub[i], a[i] - b[i], "sub mismatch at {}", i); - assert_eq!(out_div[i], a[i] / b[i], "div mismatch at {}", i); - } - } - - // ============================================================================ - // Streaming store tests (x86-64 only) - // ============================================================================ - - #[cfg(target_arch = "x86_64")] - mod streaming_tests { - use super::super::super::streaming::{STREAMING_THRESHOLD_F32, STREAMING_THRESHOLD_F64}; - - /// Test streaming threshold constant is correctly defined - #[test] - fn test_streaming_threshold_defined() { - // 1MB = 262144 f32s, 131072 f64s - assert_eq!(STREAMING_THRESHOLD_F32, 262144); - assert_eq!(STREAMING_THRESHOLD_F64, 131072); - } - } - - /// Test that large arrays produce correct results (exercises streaming path if aligned) - #[test] - fn test_large_array_correctness_f32() { - // Use a size that triggers streaming (> 1MB = 262144 f32s) - // For testing we use a smaller aligned buffer to avoid OOM - const LEN: usize = 1024; // Small but validates the code path - let a: Vec = (0..LEN).map(|x| (x as f32) * 0.1).collect(); - let b: Vec = (0..LEN).map(|x| (x as f32) * 0.2 + 1.0).collect(); - let mut out = vec![0.0f32; LEN]; - - unsafe { binary_f32(BinaryOp::Add, a.as_ptr(), b.as_ptr(), out.as_mut_ptr(), LEN) } - - for i in 0..LEN { - let expected = a[i] + b[i]; - assert!( - (out[i] - expected).abs() < 1e-6, - "large array mismatch at {}: got {}, expected {}", - i, - out[i], - expected - ); - } - } - - /// Test that large arrays produce correct results for all operations - #[test] - fn test_large_array_all_ops_f32() { - const LEN: usize = 512; - let a: Vec = (1..=LEN as i32).map(|x| x as f32).collect(); - let b: Vec = (1..=LEN as i32).map(|x| (x as f32) * 0.5 + 0.5).collect(); - - for op in [ - BinaryOp::Add, - BinaryOp::Sub, - BinaryOp::Mul, - BinaryOp::Div, - BinaryOp::Max, - BinaryOp::Min, - ] { - let mut out = vec![0.0f32; LEN]; - unsafe { binary_f32(op, a.as_ptr(), b.as_ptr(), out.as_mut_ptr(), LEN) } - - for i in 0..LEN { - let expected = match op { - BinaryOp::Add => a[i] + b[i], - BinaryOp::Sub => a[i] - b[i], - BinaryOp::Mul => a[i] * b[i], - BinaryOp::Div => a[i] / b[i], - BinaryOp::Max => a[i].max(b[i]), - BinaryOp::Min => a[i].min(b[i]), - BinaryOp::Pow => a[i].powf(b[i]), - BinaryOp::Atan2 => a[i].atan2(b[i]), - }; - assert!( - (out[i] - expected).abs() < 1e-5 * expected.abs().max(1.0), - "{:?} mismatch at {}: got {}, expected {}", - op, - i, - out[i], - expected - ); - } - } - } - - /// Test alignment check functions (x86-64 only) - #[cfg(target_arch = "x86_64")] - #[test] - fn test_alignment_checks() { - use super::super::streaming::{is_aligned_avx2, is_aligned_avx512}; +pub(crate) mod x86_64; - // Test known aligned addresses - assert!(is_aligned_avx2(32 as *const f32)); // 32 % 32 == 0 - assert!(is_aligned_avx2(64 as *const f32)); // 64 % 32 == 0 - assert!(!is_aligned_avx2(16 as *const f32)); // 16 % 32 != 0 +pub(crate) mod dispatch; - assert!(is_aligned_avx512(64 as *const f32)); // 64 % 64 == 0 - assert!(is_aligned_avx512(128 as *const f32)); // 128 % 64 == 0 - assert!(!is_aligned_avx512(32 as *const f32)); // 32 % 64 != 0 - } -} +pub use dispatch::*; diff --git a/src/runtime/cpu/kernels/simd/binary/x86_64/avx2_int.rs b/src/runtime/cpu/kernels/simd/binary/x86_64/avx2_int.rs new file mode 100644 index 00000000..7d0bfaf7 --- /dev/null +++ b/src/runtime/cpu/kernels/simd/binary/x86_64/avx2_int.rs @@ -0,0 +1,67 @@ +//! AVX2 binary operation kernels for i32 +//! +//! Processes 8 i32s per iteration using 256-bit vectors. + +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +use crate::ops::BinaryOp; + +const I32_LANES: usize = 8; + +macro_rules! impl_binary_i32_avx2 { + ($name:ident, $vec_op:ident) => { + #[target_feature(enable = "avx2")] + unsafe fn $name(a: *const i32, b: *const i32, out: *mut i32, chunks: usize) { + for i in 0..chunks { + let offset = i * I32_LANES; + let va = _mm256_loadu_si256(a.add(offset) as *const __m256i); + let vb = _mm256_loadu_si256(b.add(offset) as *const __m256i); + let vr = $vec_op(va, vb); + _mm256_storeu_si256(out.add(offset) as *mut __m256i, vr); + } + } + }; +} + +impl_binary_i32_avx2!(binary_add_i32, _mm256_add_epi32); +impl_binary_i32_avx2!(binary_sub_i32, _mm256_sub_epi32); +impl_binary_i32_avx2!(binary_mul_i32, _mm256_mullo_epi32); +impl_binary_i32_avx2!(binary_max_i32, _mm256_max_epi32); +impl_binary_i32_avx2!(binary_min_i32, _mm256_min_epi32); + +/// AVX2 binary operation for i32 +/// +/// # Safety +/// - CPU must support AVX2 +/// - All pointers must be valid for `len` elements +#[target_feature(enable = "avx2")] +pub unsafe fn binary_i32(op: BinaryOp, a: *const i32, b: *const i32, out: *mut i32, len: usize) { + let chunks = len / I32_LANES; + let remainder = len % I32_LANES; + + match op { + BinaryOp::Add => binary_add_i32(a, b, out, chunks), + BinaryOp::Sub => binary_sub_i32(a, b, out, chunks), + BinaryOp::Mul => binary_mul_i32(a, b, out, chunks), + BinaryOp::Max => binary_max_i32(a, b, out, chunks), + BinaryOp::Min => binary_min_i32(a, b, out, chunks), + // Div, Pow, Atan2 have no integer SIMD — use scalar fallback + _ => { + super::super::binary_scalar_i32(op, a, b, out, len); + return; + } + } + + // Handle tail with scalar + if remainder > 0 { + let offset = chunks * I32_LANES; + super::super::binary_scalar_i32( + op, + a.add(offset), + b.add(offset), + out.add(offset), + remainder, + ); + } +} diff --git a/src/runtime/cpu/kernels/simd/binary/x86_64/avx512_int.rs b/src/runtime/cpu/kernels/simd/binary/x86_64/avx512_int.rs new file mode 100644 index 00000000..b2347190 --- /dev/null +++ b/src/runtime/cpu/kernels/simd/binary/x86_64/avx512_int.rs @@ -0,0 +1,65 @@ +//! AVX-512 binary operation kernels for i32 +//! +//! Processes 16 i32s per iteration using 512-bit vectors. + +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +use crate::ops::BinaryOp; + +const I32_LANES: usize = 16; + +macro_rules! impl_binary_i32_avx512 { + ($name:ident, $vec_op:ident) => { + #[target_feature(enable = "avx512f")] + unsafe fn $name(a: *const i32, b: *const i32, out: *mut i32, chunks: usize) { + for i in 0..chunks { + let offset = i * I32_LANES; + let va = _mm512_loadu_si512(a.add(offset) as *const __m512i); + let vb = _mm512_loadu_si512(b.add(offset) as *const __m512i); + let vr = $vec_op(va, vb); + _mm512_storeu_si512(out.add(offset) as *mut __m512i, vr); + } + } + }; +} + +impl_binary_i32_avx512!(binary_add_i32, _mm512_add_epi32); +impl_binary_i32_avx512!(binary_sub_i32, _mm512_sub_epi32); +impl_binary_i32_avx512!(binary_mul_i32, _mm512_mullo_epi32); +impl_binary_i32_avx512!(binary_max_i32, _mm512_max_epi32); +impl_binary_i32_avx512!(binary_min_i32, _mm512_min_epi32); + +/// AVX-512 binary operation for i32 +/// +/// # Safety +/// - CPU must support AVX-512F +/// - All pointers must be valid for `len` elements +#[target_feature(enable = "avx512f")] +pub unsafe fn binary_i32(op: BinaryOp, a: *const i32, b: *const i32, out: *mut i32, len: usize) { + let chunks = len / I32_LANES; + let remainder = len % I32_LANES; + + match op { + BinaryOp::Add => binary_add_i32(a, b, out, chunks), + BinaryOp::Sub => binary_sub_i32(a, b, out, chunks), + BinaryOp::Mul => binary_mul_i32(a, b, out, chunks), + BinaryOp::Max => binary_max_i32(a, b, out, chunks), + BinaryOp::Min => binary_min_i32(a, b, out, chunks), + _ => { + super::super::binary_scalar_i32(op, a, b, out, len); + return; + } + } + + if remainder > 0 { + let offset = chunks * I32_LANES; + super::super::binary_scalar_i32( + op, + a.add(offset), + b.add(offset), + out.add(offset), + remainder, + ); + } +} diff --git a/src/runtime/cpu/kernels/simd/binary/x86_64/mod.rs b/src/runtime/cpu/kernels/simd/binary/x86_64/mod.rs index dc317472..f338e82c 100644 --- a/src/runtime/cpu/kernels/simd/binary/x86_64/mod.rs +++ b/src/runtime/cpu/kernels/simd/binary/x86_64/mod.rs @@ -1,4 +1,6 @@ //! x86-64 SIMD implementations for binary operations pub mod avx2; +pub mod avx2_int; pub mod avx512; +pub mod avx512_int; diff --git a/src/runtime/cpu/kernels/simd/clamp/mod.rs b/src/runtime/cpu/kernels/simd/clamp/mod.rs index 529bd23e..550dde4a 100644 --- a/src/runtime/cpu/kernels/simd/clamp/mod.rs +++ b/src/runtime/cpu/kernels/simd/clamp/mod.rs @@ -118,6 +118,8 @@ pub unsafe fn clamp_scalar_f64( } } +half_clamp!(clamp, clamp_f32); + #[cfg(test)] mod tests { use super::*; diff --git a/src/runtime/cpu/kernels/simd/compare/mod.rs b/src/runtime/cpu/kernels/simd/compare/mod.rs index 81298601..1252220b 100644 --- a/src/runtime/cpu/kernels/simd/compare/mod.rs +++ b/src/runtime/cpu/kernels/simd/compare/mod.rs @@ -173,6 +173,8 @@ pub unsafe fn compare_scalar_f64( } } +half_binary_op!(compare, compare_f32, CompareOp); + #[cfg(test)] mod tests { use super::*; diff --git a/src/runtime/cpu/kernels/simd/conv/half.rs b/src/runtime/cpu/kernels/simd/conv/half.rs new file mode 100644 index 00000000..977cd175 --- /dev/null +++ b/src/runtime/cpu/kernels/simd/conv/half.rs @@ -0,0 +1,122 @@ +//! f16/bf16 convolution wrappers via bulk f32 conversion +//! +//! Convolutions need random access across the entire input (sliding window), +//! so block-convert is not feasible. Instead we pre-convert all inputs to f32 +//! using a single allocation (partitioned into input/weight/output/bias regions) +//! to minimize allocator overhead. + +use super::super::half_convert_utils::*; +use super::*; +use crate::ops::conv_common::{Conv1dParams, Conv2dParams}; + +/// Generate f16 and bf16 conv wrappers that pre-convert to f32 via a single allocation. +macro_rules! half_conv_wrapper { + ( + $fn_f16:ident, $fn_bf16:ident, $f32_fn:path, $params_ty:ty, + sizes: |$p:ident| ($in_expr:expr, $w_expr:expr, $out_expr:expr, $bias_expr:expr) + ) => { + #[cfg(feature = "f16")] + pub unsafe fn $fn_f16( + input: *const ::half::f16, + weight: *const ::half::f16, + bias: Option<*const ::half::f16>, + output: *mut ::half::f16, + $p: $params_ty, + ) { + let (input_len, weight_len, output_len, bias_len) = + ($in_expr, $w_expr, $out_expr, $bias_expr); + let total = + input_len + weight_len + output_len + if bias.is_some() { bias_len } else { 0 }; + let mut buf = vec![0.0f32; total]; + let (input_f32, rest) = buf.split_at_mut(input_len); + let (weight_f32, rest) = rest.split_at_mut(weight_len); + let (output_f32, bias_f32) = rest.split_at_mut(output_len); + + convert_f16_to_f32(input as *const u16, input_f32.as_mut_ptr(), input_len); + convert_f16_to_f32(weight as *const u16, weight_f32.as_mut_ptr(), weight_len); + + let bias_ptr = if let Some(b) = bias { + convert_f16_to_f32(b as *const u16, bias_f32.as_mut_ptr(), bias_len); + Some(bias_f32.as_ptr()) + } else { + None + }; + + $f32_fn( + input_f32.as_ptr(), + weight_f32.as_ptr(), + bias_ptr, + output_f32.as_mut_ptr(), + $p, + ); + convert_f32_to_f16(output_f32.as_ptr(), output as *mut u16, output_len); + } + + #[cfg(feature = "f16")] + pub unsafe fn $fn_bf16( + input: *const ::half::bf16, + weight: *const ::half::bf16, + bias: Option<*const ::half::bf16>, + output: *mut ::half::bf16, + $p: $params_ty, + ) { + let (input_len, weight_len, output_len, bias_len) = + ($in_expr, $w_expr, $out_expr, $bias_expr); + let total = + input_len + weight_len + output_len + if bias.is_some() { bias_len } else { 0 }; + let mut buf = vec![0.0f32; total]; + let (input_f32, rest) = buf.split_at_mut(input_len); + let (weight_f32, rest) = rest.split_at_mut(weight_len); + let (output_f32, bias_f32) = rest.split_at_mut(output_len); + + convert_bf16_to_f32(input as *const u16, input_f32.as_mut_ptr(), input_len); + convert_bf16_to_f32(weight as *const u16, weight_f32.as_mut_ptr(), weight_len); + + let bias_ptr = if let Some(b) = bias { + convert_bf16_to_f32(b as *const u16, bias_f32.as_mut_ptr(), bias_len); + Some(bias_f32.as_ptr()) + } else { + None + }; + + $f32_fn( + input_f32.as_ptr(), + weight_f32.as_ptr(), + bias_ptr, + output_f32.as_mut_ptr(), + $p, + ); + convert_f32_to_bf16(output_f32.as_ptr(), output as *mut u16, output_len); + } + }; +} + +half_conv_wrapper!( + conv1d_f16, conv1d_bf16, conv1d_f32, Conv1dParams, + sizes: |params| ( + params.batch * params.c_in * params.length, + params.c_out * (params.c_in / params.groups) * params.kernel_size, + params.batch * params.c_out * params.output_length, + params.c_out + ) +); + +half_conv_wrapper!( + conv2d_f16, conv2d_bf16, conv2d_f32, Conv2dParams, + sizes: |params| ( + params.batch * params.c_in * params.height * params.width, + params.c_out * (params.c_in / params.groups) * params.kernel_h * params.kernel_w, + params.batch * params.c_out * params.output_h * params.output_w, + params.c_out + ) +); + +half_conv_wrapper!( + depthwise_conv2d_f16, depthwise_conv2d_bf16, depthwise_conv2d_f32, Conv2dParams, + sizes: |params| ( + params.batch * params.c_in * params.height * params.width, + params.c_in * params.kernel_h * params.kernel_w, + params.batch * params.c_out * params.output_h * params.output_w, + params.c_out + ) +); diff --git a/src/runtime/cpu/kernels/simd/conv/mod.rs b/src/runtime/cpu/kernels/simd/conv/mod.rs index fe325de0..d8126ecb 100644 --- a/src/runtime/cpu/kernels/simd/conv/mod.rs +++ b/src/runtime/cpu/kernels/simd/conv/mod.rs @@ -19,9 +19,17 @@ mod avx512; #[cfg(target_arch = "aarch64")] mod aarch64; +#[cfg(feature = "f16")] +mod half; +mod scalar; + use super::{SimdLevel, detect_simd}; use crate::ops::conv_common::{Conv1dParams, Conv2dParams}; +#[cfg(feature = "f16")] +pub use half::*; +pub use scalar::*; + /// Minimum input channels to justify SIMD overhead for f32 const SIMD_THRESHOLD_F32: usize = 8; @@ -283,86 +291,6 @@ pub unsafe fn depthwise_conv2d_f64( depthwise_conv2d_scalar_f64(input, weight, bias, output, params); } -// ============================================================================ -// Scalar fallbacks -// ============================================================================ - -/// Scalar conv1d for f32 -#[inline] -pub unsafe fn conv1d_scalar_f32( - input: *const f32, - weight: *const f32, - bias: Option<*const f32>, - output: *mut f32, - params: Conv1dParams, -) { - crate::runtime::cpu::kernels::conv::conv1d_kernel(input, weight, bias, output, params); -} - -/// Scalar conv1d for f64 -#[inline] -pub unsafe fn conv1d_scalar_f64( - input: *const f64, - weight: *const f64, - bias: Option<*const f64>, - output: *mut f64, - params: Conv1dParams, -) { - crate::runtime::cpu::kernels::conv::conv1d_kernel(input, weight, bias, output, params); -} - -/// Scalar conv2d for f32 -#[inline] -pub unsafe fn conv2d_scalar_f32( - input: *const f32, - weight: *const f32, - bias: Option<*const f32>, - output: *mut f32, - params: Conv2dParams, -) { - crate::runtime::cpu::kernels::conv::conv2d_kernel(input, weight, bias, output, params); -} - -/// Scalar conv2d for f64 -#[inline] -pub unsafe fn conv2d_scalar_f64( - input: *const f64, - weight: *const f64, - bias: Option<*const f64>, - output: *mut f64, - params: Conv2dParams, -) { - crate::runtime::cpu::kernels::conv::conv2d_kernel(input, weight, bias, output, params); -} - -/// Scalar depthwise conv2d for f32 -#[inline] -pub unsafe fn depthwise_conv2d_scalar_f32( - input: *const f32, - weight: *const f32, - bias: Option<*const f32>, - output: *mut f32, - params: Conv2dParams, -) { - crate::runtime::cpu::kernels::conv::depthwise_conv2d_kernel( - input, weight, bias, output, params, - ); -} - -/// Scalar depthwise conv2d for f64 -#[inline] -pub unsafe fn depthwise_conv2d_scalar_f64( - input: *const f64, - weight: *const f64, - bias: Option<*const f64>, - output: *mut f64, - params: Conv2dParams, -) { - crate::runtime::cpu::kernels::conv::depthwise_conv2d_kernel( - input, weight, bias, output, params, - ); -} - #[cfg(test)] mod tests { use super::*; diff --git a/src/runtime/cpu/kernels/simd/conv/scalar.rs b/src/runtime/cpu/kernels/simd/conv/scalar.rs new file mode 100644 index 00000000..e19e909c --- /dev/null +++ b/src/runtime/cpu/kernels/simd/conv/scalar.rs @@ -0,0 +1,79 @@ +//! Scalar fallbacks for convolution operations + +use crate::ops::conv_common::{Conv1dParams, Conv2dParams}; + +/// Scalar conv1d for f32 +#[inline] +pub unsafe fn conv1d_scalar_f32( + input: *const f32, + weight: *const f32, + bias: Option<*const f32>, + output: *mut f32, + params: Conv1dParams, +) { + crate::runtime::cpu::kernels::conv::conv1d_kernel(input, weight, bias, output, params); +} + +/// Scalar conv1d for f64 +#[inline] +pub unsafe fn conv1d_scalar_f64( + input: *const f64, + weight: *const f64, + bias: Option<*const f64>, + output: *mut f64, + params: Conv1dParams, +) { + crate::runtime::cpu::kernels::conv::conv1d_kernel(input, weight, bias, output, params); +} + +/// Scalar conv2d for f32 +#[inline] +pub unsafe fn conv2d_scalar_f32( + input: *const f32, + weight: *const f32, + bias: Option<*const f32>, + output: *mut f32, + params: Conv2dParams, +) { + crate::runtime::cpu::kernels::conv::conv2d_kernel(input, weight, bias, output, params); +} + +/// Scalar conv2d for f64 +#[inline] +pub unsafe fn conv2d_scalar_f64( + input: *const f64, + weight: *const f64, + bias: Option<*const f64>, + output: *mut f64, + params: Conv2dParams, +) { + crate::runtime::cpu::kernels::conv::conv2d_kernel(input, weight, bias, output, params); +} + +/// Scalar depthwise conv2d for f32 +#[inline] +pub unsafe fn depthwise_conv2d_scalar_f32( + input: *const f32, + weight: *const f32, + bias: Option<*const f32>, + output: *mut f32, + params: Conv2dParams, +) { + crate::runtime::cpu::kernels::conv::depthwise_conv2d_kernel( + input, weight, bias, output, params, + ); +} + +/// Scalar depthwise conv2d for f64 +#[inline] +pub unsafe fn depthwise_conv2d_scalar_f64( + input: *const f64, + weight: *const f64, + bias: Option<*const f64>, + output: *mut f64, + params: Conv2dParams, +) { + crate::runtime::cpu::kernels::conv::depthwise_conv2d_kernel( + input, weight, bias, output, params, + ); +} diff --git a/src/runtime/cpu/kernels/simd/cumulative/aarch64/neon.rs b/src/runtime/cpu/kernels/simd/cumulative/aarch64/neon.rs index 73153bf0..4cfbdf3d 100644 --- a/src/runtime/cpu/kernels/simd/cumulative/aarch64/neon.rs +++ b/src/runtime/cpu/kernels/simd/cumulative/aarch64/neon.rs @@ -35,7 +35,7 @@ pub unsafe fn cumsum_strided_f32( ) { let lanes = 4; let chunks = inner_size / lanes; - let remainder = inner_size % lanes; + let _remainder = inner_size % lanes; for o in 0..outer_size { let outer_offset = o * scan_size * inner_size; diff --git a/src/runtime/cpu/kernels/simd/cumulative/mod.rs b/src/runtime/cpu/kernels/simd/cumulative/mod.rs index cdee660f..9021791b 100644 --- a/src/runtime/cpu/kernels/simd/cumulative/mod.rs +++ b/src/runtime/cpu/kernels/simd/cumulative/mod.rs @@ -251,6 +251,118 @@ unsafe fn cumprod_strided_scalar_f64( } } +// ============================================================================ +// f16 / bf16 wrappers +// ============================================================================ + +#[cfg(feature = "f16")] +/// f16 wrapper for cumsum_strided: converts input to f32, runs f32 cumsum, converts output back. +/// +/// # Safety +/// - All pointers must be valid for the specified sizes +pub unsafe fn cumsum_strided_f16( + a: *const half::f16, + out: *mut half::f16, + scan_size: usize, + outer_size: usize, + inner_size: usize, +) { + use super::half_convert_utils::*; + let total = outer_size * scan_size * inner_size; + let mut a_f32 = vec![0.0f32; total]; + let mut out_f32 = vec![0.0f32; total]; + convert_f16_to_f32(a as *const u16, a_f32.as_mut_ptr(), total); + cumsum_strided_f32( + a_f32.as_ptr(), + out_f32.as_mut_ptr(), + scan_size, + outer_size, + inner_size, + ); + convert_f32_to_f16(out_f32.as_ptr(), out as *mut u16, total); +} + +#[cfg(feature = "f16")] +/// bf16 wrapper for cumsum_strided: converts input to f32, runs f32 cumsum, converts output back. +/// +/// # Safety +/// - All pointers must be valid for the specified sizes +pub unsafe fn cumsum_strided_bf16( + a: *const half::bf16, + out: *mut half::bf16, + scan_size: usize, + outer_size: usize, + inner_size: usize, +) { + use super::half_convert_utils::*; + let total = outer_size * scan_size * inner_size; + let mut a_f32 = vec![0.0f32; total]; + let mut out_f32 = vec![0.0f32; total]; + convert_bf16_to_f32(a as *const u16, a_f32.as_mut_ptr(), total); + cumsum_strided_f32( + a_f32.as_ptr(), + out_f32.as_mut_ptr(), + scan_size, + outer_size, + inner_size, + ); + convert_f32_to_bf16(out_f32.as_ptr(), out as *mut u16, total); +} + +#[cfg(feature = "f16")] +/// f16 wrapper for cumprod_strided: converts input to f32, runs f32 cumprod, converts output back. +/// +/// # Safety +/// - All pointers must be valid for the specified sizes +pub unsafe fn cumprod_strided_f16( + a: *const half::f16, + out: *mut half::f16, + scan_size: usize, + outer_size: usize, + inner_size: usize, +) { + use super::half_convert_utils::*; + let total = outer_size * scan_size * inner_size; + let mut a_f32 = vec![0.0f32; total]; + let mut out_f32 = vec![0.0f32; total]; + convert_f16_to_f32(a as *const u16, a_f32.as_mut_ptr(), total); + cumprod_strided_f32( + a_f32.as_ptr(), + out_f32.as_mut_ptr(), + scan_size, + outer_size, + inner_size, + ); + convert_f32_to_f16(out_f32.as_ptr(), out as *mut u16, total); +} + +#[cfg(feature = "f16")] +/// bf16 wrapper for cumprod_strided: converts input to f32, runs f32 cumprod, converts output back. +/// +/// # Safety +/// - All pointers must be valid for the specified sizes +pub unsafe fn cumprod_strided_bf16( + a: *const half::bf16, + out: *mut half::bf16, + scan_size: usize, + outer_size: usize, + inner_size: usize, +) { + use super::half_convert_utils::*; + let total = outer_size * scan_size * inner_size; + let mut a_f32 = vec![0.0f32; total]; + let mut out_f32 = vec![0.0f32; total]; + convert_bf16_to_f32(a as *const u16, a_f32.as_mut_ptr(), total); + cumprod_strided_f32( + a_f32.as_ptr(), + out_f32.as_mut_ptr(), + scan_size, + outer_size, + inner_size, + ); + convert_f32_to_bf16(out_f32.as_ptr(), out as *mut u16, total); +} + // ============================================================================ // Tests // ============================================================================ diff --git a/src/runtime/cpu/kernels/simd/dot/aarch64/mod.rs b/src/runtime/cpu/kernels/simd/dot/aarch64/mod.rs new file mode 100644 index 00000000..1f5d76af --- /dev/null +++ b/src/runtime/cpu/kernels/simd/dot/aarch64/mod.rs @@ -0,0 +1,3 @@ +//! ARM64 SIMD implementations for integer dot products + +pub mod neon; diff --git a/src/runtime/cpu/kernels/simd/dot/aarch64/neon.rs b/src/runtime/cpu/kernels/simd/dot/aarch64/neon.rs new file mode 100644 index 00000000..6afc2592 --- /dev/null +++ b/src/runtime/cpu/kernels/simd/dot/aarch64/neon.rs @@ -0,0 +1,50 @@ +//! NEON i8 dot product kernels for ARM64 +//! +//! Uses vmull_s8 + vpadalq_s16 for i8 x i8 → i32 accumulation. + +#[cfg(target_arch = "aarch64")] +use std::arch::aarch64::*; + +const I8_LANES: usize = 16; // 128-bit / 8-bit (process 8 at a time via vmull) + +/// Dot product of signed i8 vectors, accumulated in i32. +/// +/// Processes 16 i8 elements per iteration using two vmull_s8 (low/high halves). +/// +/// # Safety +/// - CPU must support NEON (always true on AArch64) +/// - Pointers must be valid for `len` elements +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +pub unsafe fn i8xi8_dot_i32(a: *const i8, b: *const i8, len: usize) -> i32 { + let chunks = len / I8_LANES; + let remainder = len % I8_LANES; + + let mut acc = vdupq_n_s32(0); + + for i in 0..chunks { + let offset = i * I8_LANES; + let va = vld1q_s8(a.add(offset)); + let vb = vld1q_s8(b.add(offset)); + + // Multiply low 8 elements: i8 x i8 → 8x i16 + let prod_lo = vmull_s8(vget_low_s8(va), vget_low_s8(vb)); + // Multiply high 8 elements: i8 x i8 → 8x i16 + let prod_hi = vmull_s8(vget_high_s8(va), vget_high_s8(vb)); + + // Pairwise add and accumulate i16 → i32 + acc = vpadalq_s16(acc, prod_lo); + acc = vpadalq_s16(acc, prod_hi); + } + + // Horizontal sum of 4 i32 lanes + let mut result = vaddvq_s32(acc); + + // Scalar tail + for i in 0..remainder { + let offset = chunks * I8_LANES + i; + result += (*a.add(offset) as i32) * (*b.add(offset) as i32); + } + + result +} diff --git a/src/runtime/cpu/kernels/simd/dot/mod.rs b/src/runtime/cpu/kernels/simd/dot/mod.rs new file mode 100644 index 00000000..561045c6 --- /dev/null +++ b/src/runtime/cpu/kernels/simd/dot/mod.rs @@ -0,0 +1,189 @@ +//! SIMD-accelerated integer dot product operations +//! +//! Provides high-throughput i8 x i8 → i32 dot products for quantized inference. +//! +//! # Architecture Support +//! +//! | Architecture | Instruction Set | Elements/cycle | Key Intrinsic | +//! |--------------|------------------|----------------|------------------------| +//! | x86-64 | AVX-512BW | 64 | maddubs + madd | +//! | x86-64 | AVX2 | 32 | maddubs + madd | +//! | ARM64 | NEON | 16 | vmull_s8 + vpadalq_s16 | + +#[cfg(target_arch = "aarch64")] +mod aarch64; +#[cfg(target_arch = "x86_64")] +mod x86_64; + +use super::{SimdLevel, detect_simd}; + +/// Minimum elements to justify SIMD overhead for dot products +const DOT_SIMD_THRESHOLD: usize = 32; + +/// Dot product of signed i8 vectors, accumulated in i32. +/// +/// Automatically dispatches to the best SIMD implementation available: +/// - x86-64/AVX-512BW: 64 elements per iteration via `_mm512_maddubs_epi16` + `_mm512_madd_epi16` +/// - x86-64/AVX2: 32 elements per iteration via `_mm256_maddubs_epi16` + `_mm256_madd_epi16` +/// - ARM64/NEON: 16 elements per iteration via `vmull_s8` + `vpadalq_s16` +/// - Scalar fallback for small arrays (<32 elements) or unsupported platforms +/// +/// Computes sum(a[i] * b[i]) for i in 0..len. +/// +/// # Safety +/// - `a` and `b` must be valid pointers to `len` elements +#[inline] +pub unsafe fn i8xi8_dot_i32(a: *const i8, b: *const i8, len: usize) -> i32 { + let level = detect_simd(); + + if len < DOT_SIMD_THRESHOLD || level == SimdLevel::Scalar { + return i8xi8_dot_scalar(a, b, len); + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => return x86_64::avx512::i8xi8_dot_i32(a, b, len), + SimdLevel::Avx2Fma => return x86_64::avx2::i8xi8_dot_i32(a, b, len), + _ => return i8xi8_dot_scalar(a, b, len), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => return aarch64::neon::i8xi8_dot_i32(a, b, len), + _ => return i8xi8_dot_scalar(a, b, len), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + i8xi8_dot_scalar(a, b, len) +} + +/// Scaled dot product of signed i8 vectors, returning f32. +/// +/// Computes scale * sum(a[i] * b[i]) for i in 0..len. +/// +/// # Safety +/// - `a` and `b` must be valid pointers to `len` elements +#[inline] +#[allow(dead_code)] // Public API for downstream crates (e.g., boostr quantized ops) +pub unsafe fn i8xi8_dot_f32(a: *const i8, b: *const i8, scale: f32, len: usize) -> f32 { + (i8xi8_dot_i32(a, b, len) as f32) * scale +} + +/// Scalar fallback for i8 dot product +#[inline] +unsafe fn i8xi8_dot_scalar(a: *const i8, b: *const i8, len: usize) -> i32 { + let mut acc = 0i32; + for i in 0..len { + acc += (*a.add(i) as i32) * (*b.add(i) as i32); + } + acc +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_i8xi8_dot_basic() { + let a: Vec = (0..100).map(|x| (x % 127) as i8).collect(); + let b: Vec = (0..100).map(|x| ((x * 3) % 127) as i8).collect(); + + let result = unsafe { i8xi8_dot_i32(a.as_ptr(), b.as_ptr(), a.len()) }; + + // Compute expected + let expected: i32 = a + .iter() + .zip(b.iter()) + .map(|(&x, &y)| x as i32 * y as i32) + .sum(); + assert_eq!(result, expected); + } + + #[test] + fn test_i8xi8_dot_negative() { + let a: Vec = (0..64).map(|x| (x as i8) - 32).collect(); + let b: Vec = (0..64).map(|x| (x as i8) - 16).collect(); + + let result = unsafe { i8xi8_dot_i32(a.as_ptr(), b.as_ptr(), a.len()) }; + let expected: i32 = a + .iter() + .zip(b.iter()) + .map(|(&x, &y)| x as i32 * y as i32) + .sum(); + assert_eq!(result, expected); + } + + #[test] + fn test_i8xi8_dot_tail() { + // Non-aligned length to exercise scalar tail + let a: Vec = (0..67).map(|x| (x % 50) as i8).collect(); + let b: Vec = (0..67).map(|x| ((x * 2) % 50) as i8).collect(); + + let result = unsafe { i8xi8_dot_i32(a.as_ptr(), b.as_ptr(), a.len()) }; + let expected: i32 = a + .iter() + .zip(b.iter()) + .map(|(&x, &y)| x as i32 * y as i32) + .sum(); + assert_eq!(result, expected); + } + + #[test] + fn test_i8xi8_dot_small() { + let a: Vec = vec![1, 2, 3, 4]; + let b: Vec = vec![5, 6, 7, 8]; + + let result = unsafe { i8xi8_dot_i32(a.as_ptr(), b.as_ptr(), a.len()) }; + assert_eq!(result, 1 * 5 + 2 * 6 + 3 * 7 + 4 * 8); + } + + #[test] + fn test_i8xi8_dot_f32_scaled() { + let a: Vec = vec![10, 20, 30, 40]; + let b: Vec = vec![1, 2, 3, 4]; + let scale = 0.5f32; + + let result = unsafe { i8xi8_dot_f32(a.as_ptr(), b.as_ptr(), scale, a.len()) }; + let expected = (10 + 40 + 90 + 160) as f32 * scale; + assert!((result - expected).abs() < 1e-6); + } + + #[test] + fn test_i8xi8_dot_extremes() { + // Test with extreme i8 values + let a: Vec = vec![ + -128, 127, -128, 127, -128, 127, -128, 127, -128, 127, -128, 127, -128, 127, -128, 127, + -128, 127, -128, 127, -128, 127, -128, 127, -128, 127, -128, 127, -128, 127, -128, 127, + ]; + let b: Vec = vec![ + 127, -128, 127, -128, 127, -128, 127, -128, 127, -128, 127, -128, 127, -128, 127, -128, + 127, -128, 127, -128, 127, -128, 127, -128, 127, -128, 127, -128, 127, -128, 127, -128, + ]; + + let result = unsafe { i8xi8_dot_i32(a.as_ptr(), b.as_ptr(), a.len()) }; + let expected: i32 = a + .iter() + .zip(b.iter()) + .map(|(&x, &y)| x as i32 * y as i32) + .sum(); + assert_eq!(result, expected); + } + + #[test] + fn test_i8xi8_dot_large() { + let a: Vec = (0..1024) + .map(|x| ((x * 7 + 13) % 256 - 128) as i8) + .collect(); + let b: Vec = (0..1024) + .map(|x| ((x * 11 + 5) % 256 - 128) as i8) + .collect(); + + let result = unsafe { i8xi8_dot_i32(a.as_ptr(), b.as_ptr(), a.len()) }; + let expected: i32 = a + .iter() + .zip(b.iter()) + .map(|(&x, &y)| x as i32 * y as i32) + .sum(); + assert_eq!(result, expected); + } +} diff --git a/src/runtime/cpu/kernels/simd/dot/x86_64/avx2.rs b/src/runtime/cpu/kernels/simd/dot/x86_64/avx2.rs new file mode 100644 index 00000000..7c6e6556 --- /dev/null +++ b/src/runtime/cpu/kernels/simd/dot/x86_64/avx2.rs @@ -0,0 +1,67 @@ +//! AVX2 i8 dot product kernels +//! +//! Uses i8 → i16 widening + _mm256_madd_epi16 for correct signed i8 x i8 → i32 accumulation. +//! Processes 32 elements per iteration (two 16-element halves widened to i16). + +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +const I8_LANES: usize = 32; // Process 32 i8s per iteration + +/// Horizontal sum of 8 i32 lanes in __m256i +#[target_feature(enable = "avx2")] +unsafe fn hsum_epi32(v: __m256i) -> i32 { + let hi128 = _mm256_extracti128_si256(v, 1); + let lo128 = _mm256_castsi256_si128(v); + let sum128 = _mm_add_epi32(lo128, hi128); + let hi64 = _mm_unpackhi_epi64(sum128, sum128); + let sum64 = _mm_add_epi32(sum128, hi64); + let hi32 = _mm_shuffle_epi32(sum64, 0b_00_00_00_01); + let sum32 = _mm_add_epi32(sum64, hi32); + _mm_cvtsi128_si32(sum32) +} + +/// Dot product of signed i8 vectors, accumulated in i32. +/// +/// Strategy: Load 32 bytes, split into low/high 16 bytes, sign-extend to i16, +/// use _mm256_madd_epi16 (signed i16 pairs → i32) to accumulate. +/// +/// # Safety +/// - CPU must support AVX2 +/// - Pointers must be valid for `len` elements +#[target_feature(enable = "avx2")] +pub unsafe fn i8xi8_dot_i32(a: *const i8, b: *const i8, len: usize) -> i32 { + let chunks = len / I8_LANES; + let remainder = len % I8_LANES; + + let mut acc = _mm256_setzero_si256(); + + for i in 0..chunks { + let offset = i * I8_LANES; + let va = _mm256_loadu_si256(a.add(offset) as *const __m256i); + let vb = _mm256_loadu_si256(b.add(offset) as *const __m256i); + + // Process low 16 bytes: sign-extend i8 → i16 + let va_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(va)); + let vb_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vb)); + // madd: multiply pairs of i16, sum adjacent → i32 + let prod_lo = _mm256_madd_epi16(va_lo, vb_lo); + acc = _mm256_add_epi32(acc, prod_lo); + + // Process high 16 bytes + let va_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(va, 1)); + let vb_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vb, 1)); + let prod_hi = _mm256_madd_epi16(va_hi, vb_hi); + acc = _mm256_add_epi32(acc, prod_hi); + } + + let mut result = hsum_epi32(acc); + + // Scalar tail + for i in 0..remainder { + let offset = chunks * I8_LANES + i; + result += (*a.add(offset) as i32) * (*b.add(offset) as i32); + } + + result +} diff --git a/src/runtime/cpu/kernels/simd/dot/x86_64/avx512.rs b/src/runtime/cpu/kernels/simd/dot/x86_64/avx512.rs new file mode 100644 index 00000000..ffde7579 --- /dev/null +++ b/src/runtime/cpu/kernels/simd/dot/x86_64/avx512.rs @@ -0,0 +1,69 @@ +//! AVX-512 i8 dot product kernels +//! +//! Uses i8 → i16 widening + _mm512_madd_epi16 for correct signed i8 x i8 → i32 accumulation. +//! Processes 64 elements per iteration (two 32-element halves widened to i16). + +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +const I8_LANES: usize = 64; // Process 64 i8s per iteration + +/// Horizontal sum of 16 i32 lanes in __m512i +#[target_feature(enable = "avx512f")] +unsafe fn hsum_epi32_512(v: __m512i) -> i32 { + let lo256 = _mm512_castsi512_si256(v); + let hi256 = _mm512_extracti64x4_epi64(v, 1); + let sum256 = _mm256_add_epi32(lo256, hi256); + let hi128 = _mm256_extracti128_si256(sum256, 1); + let lo128 = _mm256_castsi256_si128(sum256); + let sum128 = _mm_add_epi32(lo128, hi128); + let hi64 = _mm_unpackhi_epi64(sum128, sum128); + let sum64 = _mm_add_epi32(sum128, hi64); + let hi32 = _mm_shuffle_epi32(sum64, 0b_00_00_00_01); + let sum32 = _mm_add_epi32(sum64, hi32); + _mm_cvtsi128_si32(sum32) +} + +/// Dot product of signed i8 vectors using AVX-512BW, accumulated in i32. +/// +/// Strategy: Load 64 bytes, split into low/high 32 bytes, sign-extend to i16, +/// use _mm512_madd_epi16 (signed i16 pairs → i32) to accumulate. +/// +/// # Safety +/// - CPU must support AVX-512F + AVX-512BW +/// - Pointers must be valid for `len` elements +#[target_feature(enable = "avx512f", enable = "avx512bw")] +pub unsafe fn i8xi8_dot_i32(a: *const i8, b: *const i8, len: usize) -> i32 { + let chunks = len / I8_LANES; + let remainder = len % I8_LANES; + + let mut acc = _mm512_setzero_si512(); + + for i in 0..chunks { + let offset = i * I8_LANES; + let va = _mm512_loadu_si512(a.add(offset) as *const __m512i); + let vb = _mm512_loadu_si512(b.add(offset) as *const __m512i); + + // Process low 32 bytes: sign-extend i8 → i16 in 512-bit + let va_lo = _mm512_cvtepi8_epi16(_mm512_castsi512_si256(va)); + let vb_lo = _mm512_cvtepi8_epi16(_mm512_castsi512_si256(vb)); + let prod_lo = _mm512_madd_epi16(va_lo, vb_lo); + acc = _mm512_add_epi32(acc, prod_lo); + + // Process high 32 bytes + let va_hi = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(va, 1)); + let vb_hi = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(vb, 1)); + let prod_hi = _mm512_madd_epi16(va_hi, vb_hi); + acc = _mm512_add_epi32(acc, prod_hi); + } + + let mut result = hsum_epi32_512(acc); + + // Scalar tail + for i in 0..remainder { + let offset = chunks * I8_LANES + i; + result += (*a.add(offset) as i32) * (*b.add(offset) as i32); + } + + result +} diff --git a/src/runtime/cpu/kernels/simd/dot/x86_64/mod.rs b/src/runtime/cpu/kernels/simd/dot/x86_64/mod.rs new file mode 100644 index 00000000..e6b9d14e --- /dev/null +++ b/src/runtime/cpu/kernels/simd/dot/x86_64/mod.rs @@ -0,0 +1,4 @@ +//! x86-64 SIMD implementations for integer dot products + +pub mod avx2; +pub mod avx512; diff --git a/src/runtime/cpu/kernels/simd/fused_activation_mul/aarch64/mod.rs b/src/runtime/cpu/kernels/simd/fused_activation_mul/aarch64/mod.rs new file mode 100644 index 00000000..d143322f --- /dev/null +++ b/src/runtime/cpu/kernels/simd/fused_activation_mul/aarch64/mod.rs @@ -0,0 +1 @@ +pub mod neon; diff --git a/src/runtime/cpu/kernels/simd/fused_activation_mul/aarch64/neon.rs b/src/runtime/cpu/kernels/simd/fused_activation_mul/aarch64/neon.rs new file mode 100644 index 00000000..5fdfa091 --- /dev/null +++ b/src/runtime/cpu/kernels/simd/fused_activation_mul/aarch64/neon.rs @@ -0,0 +1,320 @@ +//! NEON fused activation-mul function kernels for ARM64 +//! +//! Provides vectorized implementations of fused activation * multiplication +//! using 128-bit NEON registers. Functions take two inputs (a, b) and compute +//! activation(a) * b in a single pass. + +#[cfg(target_arch = "aarch64")] +use std::arch::aarch64::*; + +use super::super::super::math::aarch64::neon::{exp_f32, exp_f64, tanh_f32}; + +const F32_LANES: usize = 4; +const F64_LANES: usize = 2; + +// ============================================================================ +// SiLU_mul: (x / (1 + exp(-x))) * y +// ============================================================================ + +/// NEON silu_mul for f32 +/// +/// # Safety +/// - CPU must support NEON (always true on AArch64) +/// - `a`, `b`, and `out` must point to `len` valid elements +/// - Elements must not overlap +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +pub unsafe fn silu_mul_f32(a: *const f32, b: *const f32, out: *mut f32, len: usize) { + let chunks = len / F32_LANES; + let remainder = len % F32_LANES; + let one = vdupq_n_f32(1.0); + + for i in 0..chunks { + let offset = i * F32_LANES; + let x = vld1q_f32(a.add(offset)); + let y = vld1q_f32(b.add(offset)); + let neg_x = vnegq_f32(x); + let exp_neg_x = exp_f32(neg_x); + let activation = vdivq_f32(x, vaddq_f32(one, exp_neg_x)); + let result = vmulq_f32(activation, y); + vst1q_f32(out.add(offset), result); + } + + if remainder > 0 { + let offset = chunks * F32_LANES; + super::super::silu_mul_scalar_f32(a.add(offset), b.add(offset), out.add(offset), remainder); + } +} + +/// NEON silu_mul for f64 +/// +/// # Safety +/// - CPU must support NEON (always true on AArch64) +/// - `a`, `b`, and `out` must point to `len` valid elements +/// - Elements must not overlap +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +pub unsafe fn silu_mul_f64(a: *const f64, b: *const f64, out: *mut f64, len: usize) { + let chunks = len / F64_LANES; + let remainder = len % F64_LANES; + let one = vdupq_n_f64(1.0); + + for i in 0..chunks { + let offset = i * F64_LANES; + let x = vld1q_f64(a.add(offset)); + let y = vld1q_f64(b.add(offset)); + let neg_x = vnegq_f64(x); + let exp_neg_x = exp_f64(neg_x); + let activation = vdivq_f64(x, vaddq_f64(one, exp_neg_x)); + let result = vmulq_f64(activation, y); + vst1q_f64(out.add(offset), result); + } + + if remainder > 0 { + let offset = chunks * F64_LANES; + super::super::silu_mul_scalar_f64(a.add(offset), b.add(offset), out.add(offset), remainder); + } +} + +// ============================================================================ +// GELU_mul: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) * y +// ============================================================================ + +/// NEON gelu_mul for f32 +/// +/// # Safety +/// - CPU must support NEON (always true on AArch64) +/// - `a`, `b`, and `out` must point to `len` valid elements +/// - Elements must not overlap +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +pub unsafe fn gelu_mul_f32(a: *const f32, b: *const f32, out: *mut f32, len: usize) { + let chunks = len / F32_LANES; + let remainder = len % F32_LANES; + + let half = vdupq_n_f32(0.5); + let one = vdupq_n_f32(1.0); + let sqrt_2_over_pi = vdupq_n_f32(0.7978845608); + let tanh_coef = vdupq_n_f32(0.044715); + + for i in 0..chunks { + let offset = i * F32_LANES; + let x = vld1q_f32(a.add(offset)); + let y = vld1q_f32(b.add(offset)); + + // x_cubed = x * x * x + let x_sq = vmulq_f32(x, x); + let x_cubed = vmulq_f32(x_sq, x); + + // inner = sqrt_2_over_pi * (x + tanh_coef * x_cubed) + let tanh_coef_x_cubed = vmulq_f32(tanh_coef, x_cubed); + let x_plus = vaddq_f32(x, tanh_coef_x_cubed); + let inner = vmulq_f32(sqrt_2_over_pi, x_plus); + + // tanh_inner = tanh(inner) + let tanh_inner = tanh_f32(inner); + + // activation = 0.5 * x * (1 + tanh_inner) + let one_plus = vaddq_f32(one, tanh_inner); + let x_times = vmulq_f32(x, one_plus); + let activation = vmulq_f32(half, x_times); + + // result = activation * y + let result = vmulq_f32(activation, y); + vst1q_f32(out.add(offset), result); + } + + if remainder > 0 { + let offset = chunks * F32_LANES; + super::super::gelu_mul_scalar_f32(a.add(offset), b.add(offset), out.add(offset), remainder); + } +} + +/// NEON gelu_mul for f64 +/// +/// # Safety +/// - CPU must support NEON (always true on AArch64) +/// - `a`, `b`, and `out` must point to `len` valid elements +/// - Elements must not overlap +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +pub unsafe fn gelu_mul_f64(a: *const f64, b: *const f64, out: *mut f64, len: usize) { + let chunks = len / F64_LANES; + let remainder = len % F64_LANES; + + let half = vdupq_n_f64(0.5); + let one = vdupq_n_f64(1.0); + let sqrt_2_over_pi = vdupq_n_f64(0.7978845608028654); + let tanh_coef = vdupq_n_f64(0.044715); + + for i in 0..chunks { + let offset = i * F64_LANES; + let x = vld1q_f64(a.add(offset)); + let y = vld1q_f64(b.add(offset)); + + // x_cubed = x * x * x + let x_sq = vmulq_f64(x, x); + let x_cubed = vmulq_f64(x_sq, x); + + // inner = sqrt_2_over_pi * (x + tanh_coef * x_cubed) + let tanh_coef_x_cubed = vmulq_f64(tanh_coef, x_cubed); + let x_plus = vaddq_f64(x, tanh_coef_x_cubed); + let inner = vmulq_f64(sqrt_2_over_pi, x_plus); + + // tanh_inner = tanh(inner) - using exp-based approximation + // tanh(x) = (exp(2x) - 1) / (exp(2x) + 1) + let two_inner = vmulq_f64(vdupq_n_f64(2.0), inner); + let exp_2x = exp_f64(two_inner); + let exp_2x_minus_1 = vsubq_f64(exp_2x, one); + let exp_2x_plus_1 = vaddq_f64(exp_2x, one); + let tanh_inner = vdivq_f64(exp_2x_minus_1, exp_2x_plus_1); + + // activation = 0.5 * x * (1 + tanh_inner) + let one_plus = vaddq_f64(one, tanh_inner); + let x_times = vmulq_f64(x, one_plus); + let activation = vmulq_f64(half, x_times); + + // result = activation * y + let result = vmulq_f64(activation, y); + vst1q_f64(out.add(offset), result); + } + + if remainder > 0 { + let offset = chunks * F64_LANES; + super::super::gelu_mul_scalar_f64(a.add(offset), b.add(offset), out.add(offset), remainder); + } +} + +// ============================================================================ +// ReLU_mul: max(0, x) * y +// ============================================================================ + +/// NEON relu_mul for f32 +/// +/// # Safety +/// - CPU must support NEON (always true on AArch64) +/// - `a`, `b`, and `out` must point to `len` valid elements +/// - Elements must not overlap +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +pub unsafe fn relu_mul_f32(a: *const f32, b: *const f32, out: *mut f32, len: usize) { + let chunks = len / F32_LANES; + let remainder = len % F32_LANES; + let zero = vdupq_n_f32(0.0); + + for i in 0..chunks { + let offset = i * F32_LANES; + let x = vld1q_f32(a.add(offset)); + let y = vld1q_f32(b.add(offset)); + let activation = vmaxq_f32(zero, x); + let result = vmulq_f32(activation, y); + vst1q_f32(out.add(offset), result); + } + + if remainder > 0 { + let offset = chunks * F32_LANES; + super::super::relu_mul_scalar_f32(a.add(offset), b.add(offset), out.add(offset), remainder); + } +} + +/// NEON relu_mul for f64 +/// +/// # Safety +/// - CPU must support NEON (always true on AArch64) +/// - `a`, `b`, and `out` must point to `len` valid elements +/// - Elements must not overlap +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +pub unsafe fn relu_mul_f64(a: *const f64, b: *const f64, out: *mut f64, len: usize) { + let chunks = len / F64_LANES; + let remainder = len % F64_LANES; + let zero = vdupq_n_f64(0.0); + + for i in 0..chunks { + let offset = i * F64_LANES; + let x = vld1q_f64(a.add(offset)); + let y = vld1q_f64(b.add(offset)); + let activation = vmaxq_f64(zero, x); + let result = vmulq_f64(activation, y); + vst1q_f64(out.add(offset), result); + } + + if remainder > 0 { + let offset = chunks * F64_LANES; + super::super::relu_mul_scalar_f64(a.add(offset), b.add(offset), out.add(offset), remainder); + } +} + +// ============================================================================ +// Sigmoid_mul: (1 / (1 + exp(-x))) * y +// ============================================================================ + +/// NEON sigmoid_mul for f32 +/// +/// # Safety +/// - CPU must support NEON (always true on AArch64) +/// - `a`, `b`, and `out` must point to `len` valid elements +/// - Elements must not overlap +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +pub unsafe fn sigmoid_mul_f32(a: *const f32, b: *const f32, out: *mut f32, len: usize) { + let chunks = len / F32_LANES; + let remainder = len % F32_LANES; + let one = vdupq_n_f32(1.0); + + for i in 0..chunks { + let offset = i * F32_LANES; + let x = vld1q_f32(a.add(offset)); + let y = vld1q_f32(b.add(offset)); + let neg_x = vnegq_f32(x); + let exp_neg_x = exp_f32(neg_x); + let activation = vdivq_f32(one, vaddq_f32(one, exp_neg_x)); + let result = vmulq_f32(activation, y); + vst1q_f32(out.add(offset), result); + } + + if remainder > 0 { + let offset = chunks * F32_LANES; + super::super::sigmoid_mul_scalar_f32( + a.add(offset), + b.add(offset), + out.add(offset), + remainder, + ); + } +} + +/// NEON sigmoid_mul for f64 +/// +/// # Safety +/// - CPU must support NEON (always true on AArch64) +/// - `a`, `b`, and `out` must point to `len` valid elements +/// - Elements must not overlap +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +pub unsafe fn sigmoid_mul_f64(a: *const f64, b: *const f64, out: *mut f64, len: usize) { + let chunks = len / F64_LANES; + let remainder = len % F64_LANES; + let one = vdupq_n_f64(1.0); + + for i in 0..chunks { + let offset = i * F64_LANES; + let x = vld1q_f64(a.add(offset)); + let y = vld1q_f64(b.add(offset)); + let neg_x = vnegq_f64(x); + let exp_neg_x = exp_f64(neg_x); + let activation = vdivq_f64(one, vaddq_f64(one, exp_neg_x)); + let result = vmulq_f64(activation, y); + vst1q_f64(out.add(offset), result); + } + + if remainder > 0 { + let offset = chunks * F64_LANES; + super::super::sigmoid_mul_scalar_f64( + a.add(offset), + b.add(offset), + out.add(offset), + remainder, + ); + } +} diff --git a/src/runtime/cpu/kernels/simd/fused_activation_mul/avx2.rs b/src/runtime/cpu/kernels/simd/fused_activation_mul/avx2.rs new file mode 100644 index 00000000..d058519b --- /dev/null +++ b/src/runtime/cpu/kernels/simd/fused_activation_mul/avx2.rs @@ -0,0 +1,266 @@ +//! AVX2 fused activation-mul kernels +//! +//! Vectorized implementations of fused activation * multiplication using 256-bit registers. +//! Functions take two inputs (a, b) and compute activation(a) * b in a single pass. + +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +use super::super::math::avx2::{exp_f32, exp_f64, tanh_f32, tanh_f64}; +use super::{ + gelu_mul_scalar_f32, gelu_mul_scalar_f64, relu_mul_scalar_f32, relu_mul_scalar_f64, + sigmoid_mul_scalar_f32, sigmoid_mul_scalar_f64, silu_mul_scalar_f32, silu_mul_scalar_f64, +}; + +const F32_LANES: usize = 8; +const F64_LANES: usize = 4; + +/// AVX2 silu_mul for f32 +/// +/// Computes: (a / (1 + exp(-a))) * b +#[target_feature(enable = "avx2", enable = "fma")] +pub unsafe fn silu_mul_f32(a: *const f32, b: *const f32, out: *mut f32, len: usize) { + let chunks = len / F32_LANES; + let one = _mm256_set1_ps(1.0); + + for c in 0..chunks { + let offset = c * F32_LANES; + let x = _mm256_loadu_ps(a.add(offset)); + let y = _mm256_loadu_ps(b.add(offset)); + let neg_x = _mm256_sub_ps(_mm256_setzero_ps(), x); + let exp_neg_x = exp_f32(neg_x); + let activation = _mm256_div_ps(x, _mm256_add_ps(one, exp_neg_x)); + let result = _mm256_mul_ps(activation, y); + _mm256_storeu_ps(out.add(offset), result); + } + + let processed = chunks * F32_LANES; + if processed < len { + silu_mul_scalar_f32( + a.add(processed), + b.add(processed), + out.add(processed), + len - processed, + ); + } +} + +/// AVX2 silu_mul for f64 +/// +/// Computes: (a / (1 + exp(-a))) * b +#[target_feature(enable = "avx2", enable = "fma")] +pub unsafe fn silu_mul_f64(a: *const f64, b: *const f64, out: *mut f64, len: usize) { + let chunks = len / F64_LANES; + let one = _mm256_set1_pd(1.0); + + for c in 0..chunks { + let offset = c * F64_LANES; + let x = _mm256_loadu_pd(a.add(offset)); + let y = _mm256_loadu_pd(b.add(offset)); + let neg_x = _mm256_sub_pd(_mm256_setzero_pd(), x); + let exp_neg_x = exp_f64(neg_x); + let activation = _mm256_div_pd(x, _mm256_add_pd(one, exp_neg_x)); + let result = _mm256_mul_pd(activation, y); + _mm256_storeu_pd(out.add(offset), result); + } + + let processed = chunks * F64_LANES; + if processed < len { + silu_mul_scalar_f64( + a.add(processed), + b.add(processed), + out.add(processed), + len - processed, + ); + } +} + +/// AVX2 gelu_mul for f32 +/// +/// Computes: 0.5 * a * (1 + tanh(sqrt(2/pi) * (a + 0.044715 * a^3))) * b +#[target_feature(enable = "avx2", enable = "fma")] +pub unsafe fn gelu_mul_f32(a: *const f32, b: *const f32, out: *mut f32, len: usize) { + let chunks = len / F32_LANES; + let half = _mm256_set1_ps(0.5); + let one = _mm256_set1_ps(1.0); + let sqrt_2_over_pi = _mm256_set1_ps(0.7978845608); + let tanh_coef = _mm256_set1_ps(0.044715); + + for c in 0..chunks { + let offset = c * F32_LANES; + let x = _mm256_loadu_ps(a.add(offset)); + let y = _mm256_loadu_ps(b.add(offset)); + + let x_cubed = _mm256_mul_ps(_mm256_mul_ps(x, x), x); + let inner = _mm256_mul_ps(sqrt_2_over_pi, _mm256_fmadd_ps(tanh_coef, x_cubed, x)); + + let tanh_inner = tanh_f32(inner); + let activation = _mm256_mul_ps(half, _mm256_mul_ps(x, _mm256_add_ps(one, tanh_inner))); + + let result = _mm256_mul_ps(activation, y); + _mm256_storeu_ps(out.add(offset), result); + } + + let processed = chunks * F32_LANES; + if processed < len { + gelu_mul_scalar_f32( + a.add(processed), + b.add(processed), + out.add(processed), + len - processed, + ); + } +} + +/// AVX2 gelu_mul for f64 +/// +/// Computes: 0.5 * a * (1 + tanh(sqrt(2/pi) * (a + 0.044715 * a^3))) * b +#[target_feature(enable = "avx2", enable = "fma")] +pub unsafe fn gelu_mul_f64(a: *const f64, b: *const f64, out: *mut f64, len: usize) { + let chunks = len / F64_LANES; + let half = _mm256_set1_pd(0.5); + let one = _mm256_set1_pd(1.0); + let sqrt_2_over_pi = _mm256_set1_pd(0.7978845608028654); + let tanh_coef = _mm256_set1_pd(0.044715); + + for c in 0..chunks { + let offset = c * F64_LANES; + let x = _mm256_loadu_pd(a.add(offset)); + let y = _mm256_loadu_pd(b.add(offset)); + + let x_cubed = _mm256_mul_pd(_mm256_mul_pd(x, x), x); + let inner = _mm256_mul_pd(sqrt_2_over_pi, _mm256_fmadd_pd(tanh_coef, x_cubed, x)); + + let tanh_inner = tanh_f64(inner); + let activation = _mm256_mul_pd(half, _mm256_mul_pd(x, _mm256_add_pd(one, tanh_inner))); + + let result = _mm256_mul_pd(activation, y); + _mm256_storeu_pd(out.add(offset), result); + } + + let processed = chunks * F64_LANES; + if processed < len { + gelu_mul_scalar_f64( + a.add(processed), + b.add(processed), + out.add(processed), + len - processed, + ); + } +} + +/// AVX2 relu_mul for f32 +/// +/// Computes: max(0, a) * b +#[target_feature(enable = "avx2", enable = "fma")] +pub unsafe fn relu_mul_f32(a: *const f32, b: *const f32, out: *mut f32, len: usize) { + let chunks = len / F32_LANES; + let zero = _mm256_setzero_ps(); + + for c in 0..chunks { + let offset = c * F32_LANES; + let x = _mm256_loadu_ps(a.add(offset)); + let y = _mm256_loadu_ps(b.add(offset)); + let activation = _mm256_max_ps(zero, x); + let result = _mm256_mul_ps(activation, y); + _mm256_storeu_ps(out.add(offset), result); + } + + let processed = chunks * F32_LANES; + if processed < len { + relu_mul_scalar_f32( + a.add(processed), + b.add(processed), + out.add(processed), + len - processed, + ); + } +} + +/// AVX2 relu_mul for f64 +/// +/// Computes: max(0, a) * b +#[target_feature(enable = "avx2", enable = "fma")] +pub unsafe fn relu_mul_f64(a: *const f64, b: *const f64, out: *mut f64, len: usize) { + let chunks = len / F64_LANES; + let zero = _mm256_setzero_pd(); + + for c in 0..chunks { + let offset = c * F64_LANES; + let x = _mm256_loadu_pd(a.add(offset)); + let y = _mm256_loadu_pd(b.add(offset)); + let activation = _mm256_max_pd(zero, x); + let result = _mm256_mul_pd(activation, y); + _mm256_storeu_pd(out.add(offset), result); + } + + let processed = chunks * F64_LANES; + if processed < len { + relu_mul_scalar_f64( + a.add(processed), + b.add(processed), + out.add(processed), + len - processed, + ); + } +} + +/// AVX2 sigmoid_mul for f32 +/// +/// Computes: (1 / (1 + exp(-a))) * b +#[target_feature(enable = "avx2", enable = "fma")] +pub unsafe fn sigmoid_mul_f32(a: *const f32, b: *const f32, out: *mut f32, len: usize) { + let chunks = len / F32_LANES; + let one = _mm256_set1_ps(1.0); + + for c in 0..chunks { + let offset = c * F32_LANES; + let x = _mm256_loadu_ps(a.add(offset)); + let y = _mm256_loadu_ps(b.add(offset)); + let neg_x = _mm256_sub_ps(_mm256_setzero_ps(), x); + let exp_neg_x = exp_f32(neg_x); + let activation = _mm256_div_ps(one, _mm256_add_ps(one, exp_neg_x)); + let result = _mm256_mul_ps(activation, y); + _mm256_storeu_ps(out.add(offset), result); + } + + let processed = chunks * F32_LANES; + if processed < len { + sigmoid_mul_scalar_f32( + a.add(processed), + b.add(processed), + out.add(processed), + len - processed, + ); + } +} + +/// AVX2 sigmoid_mul for f64 +/// +/// Computes: (1 / (1 + exp(-a))) * b +#[target_feature(enable = "avx2", enable = "fma")] +pub unsafe fn sigmoid_mul_f64(a: *const f64, b: *const f64, out: *mut f64, len: usize) { + let chunks = len / F64_LANES; + let one = _mm256_set1_pd(1.0); + + for c in 0..chunks { + let offset = c * F64_LANES; + let x = _mm256_loadu_pd(a.add(offset)); + let y = _mm256_loadu_pd(b.add(offset)); + let neg_x = _mm256_sub_pd(_mm256_setzero_pd(), x); + let exp_neg_x = exp_f64(neg_x); + let activation = _mm256_div_pd(one, _mm256_add_pd(one, exp_neg_x)); + let result = _mm256_mul_pd(activation, y); + _mm256_storeu_pd(out.add(offset), result); + } + + let processed = chunks * F64_LANES; + if processed < len { + sigmoid_mul_scalar_f64( + a.add(processed), + b.add(processed), + out.add(processed), + len - processed, + ); + } +} diff --git a/src/runtime/cpu/kernels/simd/fused_activation_mul/avx512.rs b/src/runtime/cpu/kernels/simd/fused_activation_mul/avx512.rs new file mode 100644 index 00000000..c45cdddd --- /dev/null +++ b/src/runtime/cpu/kernels/simd/fused_activation_mul/avx512.rs @@ -0,0 +1,266 @@ +//! AVX-512 fused activation-mul kernels +//! +//! Vectorized implementations of fused activation * multiplication using 512-bit registers. +//! Functions take two inputs (a, b) and compute activation(a) * b in a single pass. + +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +use super::super::math::avx512::{exp_f32, exp_f64, tanh_f32, tanh_f64}; +use super::{ + gelu_mul_scalar_f32, gelu_mul_scalar_f64, relu_mul_scalar_f32, relu_mul_scalar_f64, + sigmoid_mul_scalar_f32, sigmoid_mul_scalar_f64, silu_mul_scalar_f32, silu_mul_scalar_f64, +}; + +const F32_LANES: usize = 16; +const F64_LANES: usize = 8; + +/// AVX-512 silu_mul for f32 +/// +/// Computes: (a / (1 + exp(-a))) * b +#[target_feature(enable = "avx512f")] +pub unsafe fn silu_mul_f32(a: *const f32, b: *const f32, out: *mut f32, len: usize) { + let chunks = len / F32_LANES; + let one = _mm512_set1_ps(1.0); + + for c in 0..chunks { + let offset = c * F32_LANES; + let x = _mm512_loadu_ps(a.add(offset)); + let y = _mm512_loadu_ps(b.add(offset)); + let neg_x = _mm512_sub_ps(_mm512_setzero_ps(), x); + let exp_neg_x = exp_f32(neg_x); + let activation = _mm512_div_ps(x, _mm512_add_ps(one, exp_neg_x)); + let result = _mm512_mul_ps(activation, y); + _mm512_storeu_ps(out.add(offset), result); + } + + let processed = chunks * F32_LANES; + if processed < len { + silu_mul_scalar_f32( + a.add(processed), + b.add(processed), + out.add(processed), + len - processed, + ); + } +} + +/// AVX-512 silu_mul for f64 +/// +/// Computes: (a / (1 + exp(-a))) * b +#[target_feature(enable = "avx512f")] +pub unsafe fn silu_mul_f64(a: *const f64, b: *const f64, out: *mut f64, len: usize) { + let chunks = len / F64_LANES; + let one = _mm512_set1_pd(1.0); + + for c in 0..chunks { + let offset = c * F64_LANES; + let x = _mm512_loadu_pd(a.add(offset)); + let y = _mm512_loadu_pd(b.add(offset)); + let neg_x = _mm512_sub_pd(_mm512_setzero_pd(), x); + let exp_neg_x = exp_f64(neg_x); + let activation = _mm512_div_pd(x, _mm512_add_pd(one, exp_neg_x)); + let result = _mm512_mul_pd(activation, y); + _mm512_storeu_pd(out.add(offset), result); + } + + let processed = chunks * F64_LANES; + if processed < len { + silu_mul_scalar_f64( + a.add(processed), + b.add(processed), + out.add(processed), + len - processed, + ); + } +} + +/// AVX-512 gelu_mul for f32 +/// +/// Computes: 0.5 * a * (1 + tanh(sqrt(2/pi) * (a + 0.044715 * a^3))) * b +#[target_feature(enable = "avx512f")] +pub unsafe fn gelu_mul_f32(a: *const f32, b: *const f32, out: *mut f32, len: usize) { + let chunks = len / F32_LANES; + let half = _mm512_set1_ps(0.5); + let one = _mm512_set1_ps(1.0); + let sqrt_2_over_pi = _mm512_set1_ps(0.7978845608); + let tanh_coef = _mm512_set1_ps(0.044715); + + for c in 0..chunks { + let offset = c * F32_LANES; + let x = _mm512_loadu_ps(a.add(offset)); + let y = _mm512_loadu_ps(b.add(offset)); + + let x_cubed = _mm512_mul_ps(_mm512_mul_ps(x, x), x); + let inner = _mm512_mul_ps(sqrt_2_over_pi, _mm512_fmadd_ps(tanh_coef, x_cubed, x)); + + let tanh_inner = tanh_f32(inner); + let activation = _mm512_mul_ps(half, _mm512_mul_ps(x, _mm512_add_ps(one, tanh_inner))); + + let result = _mm512_mul_ps(activation, y); + _mm512_storeu_ps(out.add(offset), result); + } + + let processed = chunks * F32_LANES; + if processed < len { + gelu_mul_scalar_f32( + a.add(processed), + b.add(processed), + out.add(processed), + len - processed, + ); + } +} + +/// AVX-512 gelu_mul for f64 +/// +/// Computes: 0.5 * a * (1 + tanh(sqrt(2/pi) * (a + 0.044715 * a^3))) * b +#[target_feature(enable = "avx512f")] +pub unsafe fn gelu_mul_f64(a: *const f64, b: *const f64, out: *mut f64, len: usize) { + let chunks = len / F64_LANES; + let half = _mm512_set1_pd(0.5); + let one = _mm512_set1_pd(1.0); + let sqrt_2_over_pi = _mm512_set1_pd(0.7978845608028654); + let tanh_coef = _mm512_set1_pd(0.044715); + + for c in 0..chunks { + let offset = c * F64_LANES; + let x = _mm512_loadu_pd(a.add(offset)); + let y = _mm512_loadu_pd(b.add(offset)); + + let x_cubed = _mm512_mul_pd(_mm512_mul_pd(x, x), x); + let inner = _mm512_mul_pd(sqrt_2_over_pi, _mm512_fmadd_pd(tanh_coef, x_cubed, x)); + + let tanh_inner = tanh_f64(inner); + let activation = _mm512_mul_pd(half, _mm512_mul_pd(x, _mm512_add_pd(one, tanh_inner))); + + let result = _mm512_mul_pd(activation, y); + _mm512_storeu_pd(out.add(offset), result); + } + + let processed = chunks * F64_LANES; + if processed < len { + gelu_mul_scalar_f64( + a.add(processed), + b.add(processed), + out.add(processed), + len - processed, + ); + } +} + +/// AVX-512 relu_mul for f32 +/// +/// Computes: max(0, a) * b +#[target_feature(enable = "avx512f")] +pub unsafe fn relu_mul_f32(a: *const f32, b: *const f32, out: *mut f32, len: usize) { + let chunks = len / F32_LANES; + let zero = _mm512_setzero_ps(); + + for c in 0..chunks { + let offset = c * F32_LANES; + let x = _mm512_loadu_ps(a.add(offset)); + let y = _mm512_loadu_ps(b.add(offset)); + let activation = _mm512_max_ps(zero, x); + let result = _mm512_mul_ps(activation, y); + _mm512_storeu_ps(out.add(offset), result); + } + + let processed = chunks * F32_LANES; + if processed < len { + relu_mul_scalar_f32( + a.add(processed), + b.add(processed), + out.add(processed), + len - processed, + ); + } +} + +/// AVX-512 relu_mul for f64 +/// +/// Computes: max(0, a) * b +#[target_feature(enable = "avx512f")] +pub unsafe fn relu_mul_f64(a: *const f64, b: *const f64, out: *mut f64, len: usize) { + let chunks = len / F64_LANES; + let zero = _mm512_setzero_pd(); + + for c in 0..chunks { + let offset = c * F64_LANES; + let x = _mm512_loadu_pd(a.add(offset)); + let y = _mm512_loadu_pd(b.add(offset)); + let activation = _mm512_max_pd(zero, x); + let result = _mm512_mul_pd(activation, y); + _mm512_storeu_pd(out.add(offset), result); + } + + let processed = chunks * F64_LANES; + if processed < len { + relu_mul_scalar_f64( + a.add(processed), + b.add(processed), + out.add(processed), + len - processed, + ); + } +} + +/// AVX-512 sigmoid_mul for f32 +/// +/// Computes: (1 / (1 + exp(-a))) * b +#[target_feature(enable = "avx512f")] +pub unsafe fn sigmoid_mul_f32(a: *const f32, b: *const f32, out: *mut f32, len: usize) { + let chunks = len / F32_LANES; + let one = _mm512_set1_ps(1.0); + + for c in 0..chunks { + let offset = c * F32_LANES; + let x = _mm512_loadu_ps(a.add(offset)); + let y = _mm512_loadu_ps(b.add(offset)); + let neg_x = _mm512_sub_ps(_mm512_setzero_ps(), x); + let exp_neg_x = exp_f32(neg_x); + let activation = _mm512_div_ps(one, _mm512_add_ps(one, exp_neg_x)); + let result = _mm512_mul_ps(activation, y); + _mm512_storeu_ps(out.add(offset), result); + } + + let processed = chunks * F32_LANES; + if processed < len { + sigmoid_mul_scalar_f32( + a.add(processed), + b.add(processed), + out.add(processed), + len - processed, + ); + } +} + +/// AVX-512 sigmoid_mul for f64 +/// +/// Computes: (1 / (1 + exp(-a))) * b +#[target_feature(enable = "avx512f")] +pub unsafe fn sigmoid_mul_f64(a: *const f64, b: *const f64, out: *mut f64, len: usize) { + let chunks = len / F64_LANES; + let one = _mm512_set1_pd(1.0); + + for c in 0..chunks { + let offset = c * F64_LANES; + let x = _mm512_loadu_pd(a.add(offset)); + let y = _mm512_loadu_pd(b.add(offset)); + let neg_x = _mm512_sub_pd(_mm512_setzero_pd(), x); + let exp_neg_x = exp_f64(neg_x); + let activation = _mm512_div_pd(one, _mm512_add_pd(one, exp_neg_x)); + let result = _mm512_mul_pd(activation, y); + _mm512_storeu_pd(out.add(offset), result); + } + + let processed = chunks * F64_LANES; + if processed < len { + sigmoid_mul_scalar_f64( + a.add(processed), + b.add(processed), + out.add(processed), + len - processed, + ); + } +} diff --git a/src/runtime/cpu/kernels/simd/fused_activation_mul/mod.rs b/src/runtime/cpu/kernels/simd/fused_activation_mul/mod.rs new file mode 100644 index 00000000..d9c025de --- /dev/null +++ b/src/runtime/cpu/kernels/simd/fused_activation_mul/mod.rs @@ -0,0 +1,534 @@ +//! SIMD-accelerated fused activation-multiplication operations +//! +//! Provides vectorized implementations of fused activation * multiplication: +//! - silu_mul: (x / (1 + exp(-x))) * y +//! - gelu_mul: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) * y +//! - relu_mul: max(0, x) * y +//! - sigmoid_mul: (1 / (1 + exp(-x))) * y +//! +//! These operations take TWO inputs (a, b) and compute `activation(a) * b` in one pass, +//! reducing memory bandwidth compared to separate operations. + +#[cfg(target_arch = "x86_64")] +mod avx2; +#[cfg(target_arch = "x86_64")] +mod avx512; + +#[cfg(target_arch = "aarch64")] +mod aarch64; + +use super::{SimdLevel, detect_simd}; + +/// Minimum length to justify SIMD overhead +const SIMD_THRESHOLD: usize = 32; + +/// SIMD silu_mul for f32 +/// +/// Computes: (a / (1 + exp(-a))) * b +/// +/// # Safety +/// - `a`, `b`, and `out` must point to `len` elements +/// - Elements must not overlap +#[inline] +pub unsafe fn silu_mul_f32(a: *const f32, b: *const f32, out: *mut f32, len: usize) { + let level = detect_simd(); + + if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { + silu_mul_scalar_f32(a, b, out, len); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => avx512::silu_mul_f32(a, b, out, len), + SimdLevel::Avx2Fma => avx2::silu_mul_f32(a, b, out, len), + _ => silu_mul_scalar_f32(a, b, out, len), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::silu_mul_f32(a, b, out, len), + _ => silu_mul_scalar_f32(a, b, out, len), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + silu_mul_scalar_f32(a, b, out, len); +} + +/// SIMD silu_mul for f64 +/// +/// Computes: (a / (1 + exp(-a))) * b +/// +/// # Safety +/// - `a`, `b`, and `out` must point to `len` elements +/// - Elements must not overlap +#[inline] +pub unsafe fn silu_mul_f64(a: *const f64, b: *const f64, out: *mut f64, len: usize) { + let level = detect_simd(); + + if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { + silu_mul_scalar_f64(a, b, out, len); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => avx512::silu_mul_f64(a, b, out, len), + SimdLevel::Avx2Fma => avx2::silu_mul_f64(a, b, out, len), + _ => silu_mul_scalar_f64(a, b, out, len), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::silu_mul_f64(a, b, out, len), + _ => silu_mul_scalar_f64(a, b, out, len), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + silu_mul_scalar_f64(a, b, out, len); +} + +/// SIMD gelu_mul for f32 +/// +/// Computes: 0.5 * a * (1 + tanh(sqrt(2/pi) * (a + 0.044715 * a^3))) * b +/// +/// # Safety +/// - `a`, `b`, and `out` must point to `len` elements +/// - Elements must not overlap +#[inline] +pub unsafe fn gelu_mul_f32(a: *const f32, b: *const f32, out: *mut f32, len: usize) { + let level = detect_simd(); + + if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { + gelu_mul_scalar_f32(a, b, out, len); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => avx512::gelu_mul_f32(a, b, out, len), + SimdLevel::Avx2Fma => avx2::gelu_mul_f32(a, b, out, len), + _ => gelu_mul_scalar_f32(a, b, out, len), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::gelu_mul_f32(a, b, out, len), + _ => gelu_mul_scalar_f32(a, b, out, len), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + gelu_mul_scalar_f32(a, b, out, len); +} + +/// SIMD gelu_mul for f64 +/// +/// Computes: 0.5 * a * (1 + tanh(sqrt(2/pi) * (a + 0.044715 * a^3))) * b +/// +/// # Safety +/// - `a`, `b`, and `out` must point to `len` elements +/// - Elements must not overlap +#[inline] +pub unsafe fn gelu_mul_f64(a: *const f64, b: *const f64, out: *mut f64, len: usize) { + let level = detect_simd(); + + if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { + gelu_mul_scalar_f64(a, b, out, len); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => avx512::gelu_mul_f64(a, b, out, len), + SimdLevel::Avx2Fma => avx2::gelu_mul_f64(a, b, out, len), + _ => gelu_mul_scalar_f64(a, b, out, len), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::gelu_mul_f64(a, b, out, len), + _ => gelu_mul_scalar_f64(a, b, out, len), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + gelu_mul_scalar_f64(a, b, out, len); +} + +/// SIMD relu_mul for f32 +/// +/// Computes: max(0, a) * b +/// +/// # Safety +/// - `a`, `b`, and `out` must point to `len` elements +/// - Elements must not overlap +#[inline] +pub unsafe fn relu_mul_f32(a: *const f32, b: *const f32, out: *mut f32, len: usize) { + let level = detect_simd(); + + if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { + relu_mul_scalar_f32(a, b, out, len); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => avx512::relu_mul_f32(a, b, out, len), + SimdLevel::Avx2Fma => avx2::relu_mul_f32(a, b, out, len), + _ => relu_mul_scalar_f32(a, b, out, len), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::relu_mul_f32(a, b, out, len), + _ => relu_mul_scalar_f32(a, b, out, len), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + relu_mul_scalar_f32(a, b, out, len); +} + +/// SIMD relu_mul for f64 +/// +/// Computes: max(0, a) * b +/// +/// # Safety +/// - `a`, `b`, and `out` must point to `len` elements +/// - Elements must not overlap +#[inline] +pub unsafe fn relu_mul_f64(a: *const f64, b: *const f64, out: *mut f64, len: usize) { + let level = detect_simd(); + + if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { + relu_mul_scalar_f64(a, b, out, len); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => avx512::relu_mul_f64(a, b, out, len), + SimdLevel::Avx2Fma => avx2::relu_mul_f64(a, b, out, len), + _ => relu_mul_scalar_f64(a, b, out, len), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::relu_mul_f64(a, b, out, len), + _ => relu_mul_scalar_f64(a, b, out, len), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + relu_mul_scalar_f64(a, b, out, len); +} + +/// SIMD sigmoid_mul for f32 +/// +/// Computes: (1 / (1 + exp(-a))) * b +/// +/// # Safety +/// - `a`, `b`, and `out` must point to `len` elements +/// - Elements must not overlap +#[inline] +pub unsafe fn sigmoid_mul_f32(a: *const f32, b: *const f32, out: *mut f32, len: usize) { + let level = detect_simd(); + + if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { + sigmoid_mul_scalar_f32(a, b, out, len); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => avx512::sigmoid_mul_f32(a, b, out, len), + SimdLevel::Avx2Fma => avx2::sigmoid_mul_f32(a, b, out, len), + _ => sigmoid_mul_scalar_f32(a, b, out, len), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::sigmoid_mul_f32(a, b, out, len), + _ => sigmoid_mul_scalar_f32(a, b, out, len), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + sigmoid_mul_scalar_f32(a, b, out, len); +} + +/// SIMD sigmoid_mul for f64 +/// +/// Computes: (1 / (1 + exp(-a))) * b +/// +/// # Safety +/// - `a`, `b`, and `out` must point to `len` elements +/// - Elements must not overlap +#[inline] +pub unsafe fn sigmoid_mul_f64(a: *const f64, b: *const f64, out: *mut f64, len: usize) { + let level = detect_simd(); + + if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { + sigmoid_mul_scalar_f64(a, b, out, len); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => avx512::sigmoid_mul_f64(a, b, out, len), + SimdLevel::Avx2Fma => avx2::sigmoid_mul_f64(a, b, out, len), + _ => sigmoid_mul_scalar_f64(a, b, out, len), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::sigmoid_mul_f64(a, b, out, len), + _ => sigmoid_mul_scalar_f64(a, b, out, len), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + sigmoid_mul_scalar_f64(a, b, out, len); +} + +// ============================================================================ +// Scalar fallbacks +// ============================================================================ + +/// Scalar silu_mul for f32 +#[inline] +pub unsafe fn silu_mul_scalar_f32(a: *const f32, b: *const f32, out: *mut f32, len: usize) { + for i in 0..len { + let x = *a.add(i); + let y = *b.add(i); + *out.add(i) = (x / (1.0 + (-x).exp())) * y; + } +} + +/// Scalar silu_mul for f64 +#[inline] +pub unsafe fn silu_mul_scalar_f64(a: *const f64, b: *const f64, out: *mut f64, len: usize) { + for i in 0..len { + let x = *a.add(i); + let y = *b.add(i); + *out.add(i) = (x / (1.0 + (-x).exp())) * y; + } +} + +/// Scalar gelu_mul for f32 +#[inline] +pub unsafe fn gelu_mul_scalar_f32(a: *const f32, b: *const f32, out: *mut f32, len: usize) { + const SQRT_2_OVER_PI: f32 = 0.7978845608; + const TANH_COEF: f32 = 0.044715; + + for i in 0..len { + let x = *a.add(i); + let y = *b.add(i); + let inner = SQRT_2_OVER_PI * (x + TANH_COEF * x * x * x); + *out.add(i) = 0.5 * x * (1.0 + inner.tanh()) * y; + } +} + +/// Scalar gelu_mul for f64 +#[inline] +pub unsafe fn gelu_mul_scalar_f64(a: *const f64, b: *const f64, out: *mut f64, len: usize) { + const SQRT_2_OVER_PI: f64 = 0.7978845608028654; + const TANH_COEF: f64 = 0.044715; + + for i in 0..len { + let x = *a.add(i); + let y = *b.add(i); + let inner = SQRT_2_OVER_PI * (x + TANH_COEF * x * x * x); + *out.add(i) = 0.5 * x * (1.0 + inner.tanh()) * y; + } +} + +/// Scalar relu_mul for f32 +#[inline] +pub unsafe fn relu_mul_scalar_f32(a: *const f32, b: *const f32, out: *mut f32, len: usize) { + for i in 0..len { + let x = *a.add(i); + let y = *b.add(i); + *out.add(i) = if x > 0.0 { x * y } else { 0.0 }; + } +} + +/// Scalar relu_mul for f64 +#[inline] +pub unsafe fn relu_mul_scalar_f64(a: *const f64, b: *const f64, out: *mut f64, len: usize) { + for i in 0..len { + let x = *a.add(i); + let y = *b.add(i); + *out.add(i) = if x > 0.0 { x * y } else { 0.0 }; + } +} + +/// Scalar sigmoid_mul for f32 +#[inline] +pub unsafe fn sigmoid_mul_scalar_f32(a: *const f32, b: *const f32, out: *mut f32, len: usize) { + for i in 0..len { + let x = *a.add(i); + let y = *b.add(i); + *out.add(i) = (1.0 / (1.0 + (-x).exp())) * y; + } +} + +/// Scalar sigmoid_mul for f64 +#[inline] +pub unsafe fn sigmoid_mul_scalar_f64(a: *const f64, b: *const f64, out: *mut f64, len: usize) { + for i in 0..len { + let x = *a.add(i); + let y = *b.add(i); + *out.add(i) = (1.0 / (1.0 + (-x).exp())) * y; + } +} + +// ============================================================================ +// f16/bf16 block-convert-compute wrappers +// ============================================================================ + +/// Generate f16/bf16 wrappers for binary fused ops: `fn(a, b, out, len)` +macro_rules! _half_binary_fused { + ($fn_name:ident, $half_ty:ty, $to_f32:path, $from_f32:path, $f32_fn:path) => { + #[cfg(feature = "f16")] + #[inline] + pub unsafe fn $fn_name( + a: *const $half_ty, + b: *const $half_ty, + out: *mut $half_ty, + len: usize, + ) { + use super::half_convert_utils::HALF_BLOCK; + let mut a_buf = [0.0f32; HALF_BLOCK]; + let mut b_buf = [0.0f32; HALF_BLOCK]; + let mut out_buf = [0.0f32; HALF_BLOCK]; + let mut offset = 0; + while offset < len { + let chunk = (len - offset).min(HALF_BLOCK); + $to_f32(a.add(offset) as *const u16, a_buf.as_mut_ptr(), chunk); + $to_f32(b.add(offset) as *const u16, b_buf.as_mut_ptr(), chunk); + $f32_fn(a_buf.as_ptr(), b_buf.as_ptr(), out_buf.as_mut_ptr(), chunk); + $from_f32(out_buf.as_ptr(), out.add(offset) as *mut u16, chunk); + offset += chunk; + } + } + }; +} + +macro_rules! half_binary_fused { + ($name:ident, $f32_fn:path) => { + paste::paste! { + _half_binary_fused!([<$name _f16>], half::f16, + super::half_convert_utils::convert_f16_to_f32, + super::half_convert_utils::convert_f32_to_f16, $f32_fn); + _half_binary_fused!([<$name _bf16>], half::bf16, + super::half_convert_utils::convert_bf16_to_f32, + super::half_convert_utils::convert_f32_to_bf16, $f32_fn); + } + }; +} + +half_binary_fused!(silu_mul, silu_mul_f32); +half_binary_fused!(gelu_mul, gelu_mul_f32); +half_binary_fused!(relu_mul, relu_mul_f32); +half_binary_fused!(sigmoid_mul, sigmoid_mul_f32); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_silu_mul_f32() { + let len = 128; + let a: Vec = (0..len).map(|x| (x as f32) / 32.0 - 2.0).collect(); + let b: Vec = (0..len).map(|x| (x as f32) / 64.0 + 0.5).collect(); + let mut out = vec![0.0f32; len]; + let mut out_ref = vec![0.0f32; len]; + + unsafe { + silu_mul_f32(a.as_ptr(), b.as_ptr(), out.as_mut_ptr(), len); + silu_mul_scalar_f32(a.as_ptr(), b.as_ptr(), out_ref.as_mut_ptr(), len); + } + + for i in 0..len { + let diff = (out[i] - out_ref[i]).abs(); + let denom = out_ref[i].abs().max(1e-6); + let rel_err = diff / denom; + assert!( + rel_err < 0.01, + "silu_mul mismatch at {}: {} vs {} (err: {})", + i, + out[i], + out_ref[i], + rel_err + ); + } + } + + #[test] + fn test_gelu_mul_f32() { + let len = 128; + let a: Vec = (0..len).map(|x| (x as f32) / 32.0 - 2.0).collect(); + let b: Vec = (0..len).map(|x| (x as f32) / 64.0 + 0.5).collect(); + let mut out = vec![0.0f32; len]; + let mut out_ref = vec![0.0f32; len]; + + unsafe { + gelu_mul_f32(a.as_ptr(), b.as_ptr(), out.as_mut_ptr(), len); + gelu_mul_scalar_f32(a.as_ptr(), b.as_ptr(), out_ref.as_mut_ptr(), len); + } + + for i in 0..len { + let diff = (out[i] - out_ref[i]).abs(); + let denom = out_ref[i].abs().max(1e-6); + let rel_err = diff / denom; + assert!( + rel_err < 0.02, + "gelu_mul mismatch at {}: {} vs {} (err: {})", + i, + out[i], + out_ref[i], + rel_err + ); + } + } + + #[test] + fn test_relu_mul_f32() { + let len = 128; + let a: Vec = (0..len).map(|x| (x as f32) - 64.0).collect(); + let b: Vec = (0..len).map(|x| (x as f32) / 64.0 + 0.5).collect(); + let mut out = vec![0.0f32; len]; + let mut out_ref = vec![0.0f32; len]; + + unsafe { + relu_mul_f32(a.as_ptr(), b.as_ptr(), out.as_mut_ptr(), len); + relu_mul_scalar_f32(a.as_ptr(), b.as_ptr(), out_ref.as_mut_ptr(), len); + } + + assert_eq!(out, out_ref); + } + + #[test] + fn test_sigmoid_mul_f32() { + let len = 128; + let a: Vec = (0..len).map(|x| (x as f32) / 32.0 - 2.0).collect(); + let b: Vec = (0..len).map(|x| (x as f32) / 64.0 + 0.5).collect(); + let mut out = vec![0.0f32; len]; + let mut out_ref = vec![0.0f32; len]; + + unsafe { + sigmoid_mul_f32(a.as_ptr(), b.as_ptr(), out.as_mut_ptr(), len); + sigmoid_mul_scalar_f32(a.as_ptr(), b.as_ptr(), out_ref.as_mut_ptr(), len); + } + + for i in 0..len { + let diff = (out[i] - out_ref[i]).abs(); + let denom = out_ref[i].abs().max(1e-6); + let rel_err = diff / denom; + assert!( + rel_err < 0.01, + "sigmoid_mul mismatch at {}: {} vs {} (err: {})", + i, + out[i], + out_ref[i], + rel_err + ); + } + } +} diff --git a/src/runtime/cpu/kernels/simd/fused_elementwise/aarch64/mod.rs b/src/runtime/cpu/kernels/simd/fused_elementwise/aarch64/mod.rs new file mode 100644 index 00000000..d143322f --- /dev/null +++ b/src/runtime/cpu/kernels/simd/fused_elementwise/aarch64/mod.rs @@ -0,0 +1 @@ +pub mod neon; diff --git a/src/runtime/cpu/kernels/simd/fused_elementwise/aarch64/neon.rs b/src/runtime/cpu/kernels/simd/fused_elementwise/aarch64/neon.rs new file mode 100644 index 00000000..ce4a3225 --- /dev/null +++ b/src/runtime/cpu/kernels/simd/fused_elementwise/aarch64/neon.rs @@ -0,0 +1,210 @@ +//! NEON fused elementwise kernels (128-bit) + +#[cfg(target_arch = "aarch64")] +use std::arch::aarch64::*; + +use super::super::{ + fused_add_mul_scalar_f32 as fused_add_mul_fallback_f32, + fused_add_mul_scalar_f64 as fused_add_mul_fallback_f64, + fused_mul_add_scalar_f32 as fused_mul_add_fallback_f32, + fused_mul_add_scalar_f64 as fused_mul_add_fallback_f64, fused_mul_add_scalar_loop_f32, + fused_mul_add_scalar_loop_f64, +}; + +const F32_LANES: usize = 4; +const F64_LANES: usize = 2; + +/// NEON fused_mul_add for f32: out = a * b + c +#[target_feature(enable = "neon")] +pub unsafe fn fused_mul_add_f32( + a: *const f32, + b: *const f32, + c: *const f32, + out: *mut f32, + len: usize, +) { + let chunks = len / F32_LANES; + + for i in 0..chunks { + let offset = i * F32_LANES; + let va = vld1q_f32(a.add(offset)); + let vb = vld1q_f32(b.add(offset)); + let vc = vld1q_f32(c.add(offset)); + // vfmaq_f32: vc + va * vb + let result = vfmaq_f32(vc, va, vb); + vst1q_f32(out.add(offset), result); + } + + let processed = chunks * F32_LANES; + if processed < len { + fused_mul_add_fallback_f32( + a.add(processed), + b.add(processed), + c.add(processed), + out.add(processed), + len - processed, + ); + } +} + +/// NEON fused_mul_add for f64: out = a * b + c +#[target_feature(enable = "neon")] +pub unsafe fn fused_mul_add_f64( + a: *const f64, + b: *const f64, + c: *const f64, + out: *mut f64, + len: usize, +) { + let chunks = len / F64_LANES; + + for i in 0..chunks { + let offset = i * F64_LANES; + let va = vld1q_f64(a.add(offset)); + let vb = vld1q_f64(b.add(offset)); + let vc = vld1q_f64(c.add(offset)); + let result = vfmaq_f64(vc, va, vb); + vst1q_f64(out.add(offset), result); + } + + let processed = chunks * F64_LANES; + if processed < len { + fused_mul_add_fallback_f64( + a.add(processed), + b.add(processed), + c.add(processed), + out.add(processed), + len - processed, + ); + } +} + +/// NEON fused_add_mul for f32: out = (a + b) * c +#[target_feature(enable = "neon")] +pub unsafe fn fused_add_mul_f32( + a: *const f32, + b: *const f32, + c: *const f32, + out: *mut f32, + len: usize, +) { + let chunks = len / F32_LANES; + + for i in 0..chunks { + let offset = i * F32_LANES; + let va = vld1q_f32(a.add(offset)); + let vb = vld1q_f32(b.add(offset)); + let vc = vld1q_f32(c.add(offset)); + let sum = vaddq_f32(va, vb); + let result = vmulq_f32(sum, vc); + vst1q_f32(out.add(offset), result); + } + + let processed = chunks * F32_LANES; + if processed < len { + fused_add_mul_fallback_f32( + a.add(processed), + b.add(processed), + c.add(processed), + out.add(processed), + len - processed, + ); + } +} + +/// NEON fused_add_mul for f64: out = (a + b) * c +#[target_feature(enable = "neon")] +pub unsafe fn fused_add_mul_f64( + a: *const f64, + b: *const f64, + c: *const f64, + out: *mut f64, + len: usize, +) { + let chunks = len / F64_LANES; + + for i in 0..chunks { + let offset = i * F64_LANES; + let va = vld1q_f64(a.add(offset)); + let vb = vld1q_f64(b.add(offset)); + let vc = vld1q_f64(c.add(offset)); + let sum = vaddq_f64(va, vb); + let result = vmulq_f64(sum, vc); + vst1q_f64(out.add(offset), result); + } + + let processed = chunks * F64_LANES; + if processed < len { + fused_add_mul_fallback_f64( + a.add(processed), + b.add(processed), + c.add(processed), + out.add(processed), + len - processed, + ); + } +} + +/// NEON fused_mul_add_scalar for f32: out = a * scale + bias +#[target_feature(enable = "neon")] +pub unsafe fn fused_mul_add_scalar_f32( + a: *const f32, + scale: f32, + bias: f32, + out: *mut f32, + len: usize, +) { + let chunks = len / F32_LANES; + let vscale = vdupq_n_f32(scale); + let vbias = vdupq_n_f32(bias); + + for i in 0..chunks { + let offset = i * F32_LANES; + let va = vld1q_f32(a.add(offset)); + let result = vfmaq_f32(vbias, va, vscale); + vst1q_f32(out.add(offset), result); + } + + let processed = chunks * F32_LANES; + if processed < len { + fused_mul_add_scalar_loop_f32( + a.add(processed), + scale, + bias, + out.add(processed), + len - processed, + ); + } +} + +/// NEON fused_mul_add_scalar for f64: out = a * scale + bias +#[target_feature(enable = "neon")] +pub unsafe fn fused_mul_add_scalar_f64( + a: *const f64, + scale: f64, + bias: f64, + out: *mut f64, + len: usize, +) { + let chunks = len / F64_LANES; + let vscale = vdupq_n_f64(scale); + let vbias = vdupq_n_f64(bias); + + for i in 0..chunks { + let offset = i * F64_LANES; + let va = vld1q_f64(a.add(offset)); + let result = vfmaq_f64(vbias, va, vscale); + vst1q_f64(out.add(offset), result); + } + + let processed = chunks * F64_LANES; + if processed < len { + fused_mul_add_scalar_loop_f64( + a.add(processed), + scale, + bias, + out.add(processed), + len - processed, + ); + } +} diff --git a/src/runtime/cpu/kernels/simd/fused_elementwise/dispatch.rs b/src/runtime/cpu/kernels/simd/fused_elementwise/dispatch.rs new file mode 100644 index 00000000..f084ff38 --- /dev/null +++ b/src/runtime/cpu/kernels/simd/fused_elementwise/dispatch.rs @@ -0,0 +1,532 @@ +//! SIMD-accelerated fused elementwise operation dispatch and scalar fallbacks. +//! +//! Provides vectorized implementations of: +//! - fused_mul_add: a * b + c (FMA) +//! - fused_add_mul: (a + b) * c +//! - fused_mul_add_scalar: a * scale + bias (affine transform) +//! +//! These use hardware FMA intrinsics where available for better accuracy +//! and throughput (single rounding instead of two). + +#[cfg(target_arch = "aarch64")] +use super::aarch64; +#[cfg(target_arch = "x86_64")] +use super::x86_64; +use crate::runtime::cpu::kernels::simd::{SimdLevel, detect_simd}; + +/// Minimum length to justify SIMD overhead +const SIMD_THRESHOLD: usize = 32; + +// ============================================================================ +// fused_mul_add: a * b + c +// ============================================================================ + +/// SIMD fused_mul_add for f32: out[i] = a[i] * b[i] + c[i] +/// +/// # Safety +/// - `a`, `b`, `c`, and `out` must point to `len` elements +/// - Elements must not overlap +#[inline] +pub unsafe fn fused_mul_add_f32( + a: *const f32, + b: *const f32, + c: *const f32, + out: *mut f32, + len: usize, +) { + let level = detect_simd(); + + if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { + fused_mul_add_scalar_f32(a, b, c, out, len); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => x86_64::avx512::fused_mul_add_f32(a, b, c, out, len), + SimdLevel::Avx2Fma => x86_64::avx2::fused_mul_add_f32(a, b, c, out, len), + _ => fused_mul_add_scalar_f32(a, b, c, out, len), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => { + aarch64::neon::fused_mul_add_f32(a, b, c, out, len) + } + _ => fused_mul_add_scalar_f32(a, b, c, out, len), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + fused_mul_add_scalar_f32(a, b, c, out, len); +} + +/// SIMD fused_mul_add for f64: out[i] = a[i] * b[i] + c[i] +/// +/// # Safety +/// - `a`, `b`, `c`, and `out` must point to `len` elements +#[inline] +pub unsafe fn fused_mul_add_f64( + a: *const f64, + b: *const f64, + c: *const f64, + out: *mut f64, + len: usize, +) { + let level = detect_simd(); + + if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { + fused_mul_add_scalar_f64(a, b, c, out, len); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => x86_64::avx512::fused_mul_add_f64(a, b, c, out, len), + SimdLevel::Avx2Fma => x86_64::avx2::fused_mul_add_f64(a, b, c, out, len), + _ => fused_mul_add_scalar_f64(a, b, c, out, len), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => { + aarch64::neon::fused_mul_add_f64(a, b, c, out, len) + } + _ => fused_mul_add_scalar_f64(a, b, c, out, len), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + fused_mul_add_scalar_f64(a, b, c, out, len); +} + +// ============================================================================ +// fused_add_mul: (a + b) * c +// ============================================================================ + +/// SIMD fused_add_mul for f32: out[i] = (a[i] + b[i]) * c[i] +/// +/// # Safety +/// - `a`, `b`, `c`, and `out` must point to `len` elements +#[inline] +pub unsafe fn fused_add_mul_f32( + a: *const f32, + b: *const f32, + c: *const f32, + out: *mut f32, + len: usize, +) { + let level = detect_simd(); + + if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { + fused_add_mul_scalar_f32(a, b, c, out, len); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => x86_64::avx512::fused_add_mul_f32(a, b, c, out, len), + SimdLevel::Avx2Fma => x86_64::avx2::fused_add_mul_f32(a, b, c, out, len), + _ => fused_add_mul_scalar_f32(a, b, c, out, len), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => { + aarch64::neon::fused_add_mul_f32(a, b, c, out, len) + } + _ => fused_add_mul_scalar_f32(a, b, c, out, len), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + fused_add_mul_scalar_f32(a, b, c, out, len); +} + +/// SIMD fused_add_mul for f64: out[i] = (a[i] + b[i]) * c[i] +/// +/// # Safety +/// - `a`, `b`, `c`, and `out` must point to `len` elements +#[inline] +pub unsafe fn fused_add_mul_f64( + a: *const f64, + b: *const f64, + c: *const f64, + out: *mut f64, + len: usize, +) { + let level = detect_simd(); + + if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { + fused_add_mul_scalar_f64(a, b, c, out, len); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => x86_64::avx512::fused_add_mul_f64(a, b, c, out, len), + SimdLevel::Avx2Fma => x86_64::avx2::fused_add_mul_f64(a, b, c, out, len), + _ => fused_add_mul_scalar_f64(a, b, c, out, len), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => { + aarch64::neon::fused_add_mul_f64(a, b, c, out, len) + } + _ => fused_add_mul_scalar_f64(a, b, c, out, len), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + fused_add_mul_scalar_f64(a, b, c, out, len); +} + +// ============================================================================ +// fused_mul_add_scalar: a * scale + bias +// ============================================================================ + +/// SIMD fused_mul_add_scalar for f32: out[i] = a[i] * scale + bias +/// +/// # Safety +/// - `a` and `out` must point to `len` elements +#[inline] +pub unsafe fn fused_mul_add_scalar_f32_kernel( + a: *const f32, + scale: f32, + bias: f32, + out: *mut f32, + len: usize, +) { + let level = detect_simd(); + + if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { + fused_mul_add_scalar_loop_f32(a, scale, bias, out, len); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => x86_64::avx512::fused_mul_add_scalar_f32(a, scale, bias, out, len), + SimdLevel::Avx2Fma => x86_64::avx2::fused_mul_add_scalar_f32(a, scale, bias, out, len), + _ => fused_mul_add_scalar_loop_f32(a, scale, bias, out, len), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => { + aarch64::neon::fused_mul_add_scalar_f32(a, scale, bias, out, len) + } + _ => fused_mul_add_scalar_loop_f32(a, scale, bias, out, len), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + fused_mul_add_scalar_loop_f32(a, scale, bias, out, len); +} + +/// SIMD fused_mul_add_scalar for f64: out[i] = a[i] * scale + bias +/// +/// # Safety +/// - `a` and `out` must point to `len` elements +#[inline] +pub unsafe fn fused_mul_add_scalar_f64_kernel( + a: *const f64, + scale: f64, + bias: f64, + out: *mut f64, + len: usize, +) { + let level = detect_simd(); + + if len < SIMD_THRESHOLD || level == SimdLevel::Scalar { + fused_mul_add_scalar_loop_f64(a, scale, bias, out, len); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => x86_64::avx512::fused_mul_add_scalar_f64(a, scale, bias, out, len), + SimdLevel::Avx2Fma => x86_64::avx2::fused_mul_add_scalar_f64(a, scale, bias, out, len), + _ => fused_mul_add_scalar_loop_f64(a, scale, bias, out, len), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => { + aarch64::neon::fused_mul_add_scalar_f64(a, scale, bias, out, len) + } + _ => fused_mul_add_scalar_loop_f64(a, scale, bias, out, len), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + fused_mul_add_scalar_loop_f64(a, scale, bias, out, len); +} + +// ============================================================================ +// Scalar fallbacks +// ============================================================================ + +#[inline] +pub unsafe fn fused_mul_add_scalar_f32( + a: *const f32, + b: *const f32, + c: *const f32, + out: *mut f32, + len: usize, +) { + for i in 0..len { + *out.add(i) = (*a.add(i)).mul_add(*b.add(i), *c.add(i)); + } +} + +#[inline] +pub unsafe fn fused_mul_add_scalar_f64( + a: *const f64, + b: *const f64, + c: *const f64, + out: *mut f64, + len: usize, +) { + for i in 0..len { + *out.add(i) = (*a.add(i)).mul_add(*b.add(i), *c.add(i)); + } +} + +#[inline] +pub unsafe fn fused_add_mul_scalar_f32( + a: *const f32, + b: *const f32, + c: *const f32, + out: *mut f32, + len: usize, +) { + for i in 0..len { + *out.add(i) = (*a.add(i) + *b.add(i)) * *c.add(i); + } +} + +#[inline] +pub unsafe fn fused_add_mul_scalar_f64( + a: *const f64, + b: *const f64, + c: *const f64, + out: *mut f64, + len: usize, +) { + for i in 0..len { + *out.add(i) = (*a.add(i) + *b.add(i)) * *c.add(i); + } +} + +#[inline] +pub unsafe fn fused_mul_add_scalar_loop_f32( + a: *const f32, + scale: f32, + bias: f32, + out: *mut f32, + len: usize, +) { + for i in 0..len { + *out.add(i) = (*a.add(i)).mul_add(scale, bias); + } +} + +#[inline] +pub unsafe fn fused_mul_add_scalar_loop_f64( + a: *const f64, + scale: f64, + bias: f64, + out: *mut f64, + len: usize, +) { + for i in 0..len { + *out.add(i) = (*a.add(i)).mul_add(scale, bias); + } +} + +// ============================================================================ +// f16/bf16 block-convert-compute wrappers +// ============================================================================ + +/// Generate f16/bf16 wrappers for ternary fused ops: `fn(a, b, c, out, len)` +macro_rules! _half_ternary_fused { + ($fn_name:ident, $half_ty:ty, $to_f32:path, $from_f32:path, $f32_fn:path) => { + #[cfg(feature = "f16")] + #[inline] + pub unsafe fn $fn_name( + a: *const $half_ty, + b: *const $half_ty, + c: *const $half_ty, + out: *mut $half_ty, + len: usize, + ) { + use super::super::half_convert_utils::HALF_BLOCK; + let mut a_buf = [0.0f32; HALF_BLOCK]; + let mut b_buf = [0.0f32; HALF_BLOCK]; + let mut c_buf = [0.0f32; HALF_BLOCK]; + let mut out_buf = [0.0f32; HALF_BLOCK]; + let mut offset = 0; + while offset < len { + let chunk = (len - offset).min(HALF_BLOCK); + $to_f32(a.add(offset) as *const u16, a_buf.as_mut_ptr(), chunk); + $to_f32(b.add(offset) as *const u16, b_buf.as_mut_ptr(), chunk); + $to_f32(c.add(offset) as *const u16, c_buf.as_mut_ptr(), chunk); + $f32_fn( + a_buf.as_ptr(), + b_buf.as_ptr(), + c_buf.as_ptr(), + out_buf.as_mut_ptr(), + chunk, + ); + $from_f32(out_buf.as_ptr(), out.add(offset) as *mut u16, chunk); + offset += chunk; + } + } + }; +} + +macro_rules! half_ternary_fused { + ($name:ident, $f32_fn:path) => { + paste::paste! { + _half_ternary_fused!([<$name _f16>], half::f16, + super::super::half_convert_utils::convert_f16_to_f32, + super::super::half_convert_utils::convert_f32_to_f16, $f32_fn); + _half_ternary_fused!([<$name _bf16>], half::bf16, + super::super::half_convert_utils::convert_bf16_to_f32, + super::super::half_convert_utils::convert_f32_to_bf16, $f32_fn); + } + }; +} + +half_ternary_fused!(fused_mul_add, fused_mul_add_f32); +half_ternary_fused!(fused_add_mul, fused_add_mul_f32); + +/// Generate f16/bf16 wrappers for scalar fused ops: `fn(a, scale, bias, out, len)` +macro_rules! _half_scalar_fused { + ($fn_name:ident, $half_ty:ty, $to_f32:path, $from_f32:path, $f32_fn:path) => { + #[cfg(feature = "f16")] + #[inline] + pub unsafe fn $fn_name( + a: *const $half_ty, + scale: f32, + bias: f32, + out: *mut $half_ty, + len: usize, + ) { + use super::super::half_convert_utils::HALF_BLOCK; + let mut a_buf = [0.0f32; HALF_BLOCK]; + let mut out_buf = [0.0f32; HALF_BLOCK]; + let mut offset = 0; + while offset < len { + let chunk = (len - offset).min(HALF_BLOCK); + $to_f32(a.add(offset) as *const u16, a_buf.as_mut_ptr(), chunk); + $f32_fn(a_buf.as_ptr(), scale, bias, out_buf.as_mut_ptr(), chunk); + $from_f32(out_buf.as_ptr(), out.add(offset) as *mut u16, chunk); + offset += chunk; + } + } + }; +} + +macro_rules! half_scalar_fused { + ($name:ident, $f32_fn:path) => { + paste::paste! { + _half_scalar_fused!([<$name _f32_f16>], half::f16, + super::super::half_convert_utils::convert_f16_to_f32, + super::super::half_convert_utils::convert_f32_to_f16, $f32_fn); + _half_scalar_fused!([<$name _f32_bf16>], half::bf16, + super::super::half_convert_utils::convert_bf16_to_f32, + super::super::half_convert_utils::convert_f32_to_bf16, $f32_fn); + } + }; +} + +half_scalar_fused!(fused_mul_add_scalar, fused_mul_add_scalar_f32_kernel); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_fused_mul_add_f32() { + let len = 128; + let a: Vec = (0..len).map(|x| x as f32 * 0.1).collect(); + let b: Vec = (0..len).map(|x| x as f32 * 0.2 + 1.0).collect(); + let c: Vec = (0..len).map(|x| x as f32 * 0.05 - 0.5).collect(); + let mut out = vec![0.0f32; len]; + let mut out_ref = vec![0.0f32; len]; + + unsafe { + fused_mul_add_f32(a.as_ptr(), b.as_ptr(), c.as_ptr(), out.as_mut_ptr(), len); + fused_mul_add_scalar_f32( + a.as_ptr(), + b.as_ptr(), + c.as_ptr(), + out_ref.as_mut_ptr(), + len, + ); + } + + for i in 0..len { + let diff = (out[i] - out_ref[i]).abs(); + assert!( + diff < 1e-5, + "fused_mul_add mismatch at {i}: {} vs {}", + out[i], + out_ref[i] + ); + } + } + + #[test] + fn test_fused_add_mul_f32() { + let len = 128; + let a: Vec = (0..len).map(|x| x as f32 * 0.1).collect(); + let b: Vec = (0..len).map(|x| x as f32 * 0.2 + 1.0).collect(); + let c: Vec = (0..len).map(|x| x as f32 * 0.05 + 0.5).collect(); + let mut out = vec![0.0f32; len]; + let mut out_ref = vec![0.0f32; len]; + + unsafe { + fused_add_mul_f32(a.as_ptr(), b.as_ptr(), c.as_ptr(), out.as_mut_ptr(), len); + fused_add_mul_scalar_f32( + a.as_ptr(), + b.as_ptr(), + c.as_ptr(), + out_ref.as_mut_ptr(), + len, + ); + } + + for i in 0..len { + let diff = (out[i] - out_ref[i]).abs(); + assert!( + diff < 1e-5, + "fused_add_mul mismatch at {i}: {} vs {}", + out[i], + out_ref[i] + ); + } + } + + #[test] + fn test_fused_mul_add_scalar_f32() { + let len = 128; + let a: Vec = (0..len).map(|x| x as f32 * 0.1 - 5.0).collect(); + let scale = 2.5f32; + let bias = -1.0f32; + let mut out = vec![0.0f32; len]; + let mut out_ref = vec![0.0f32; len]; + + unsafe { + fused_mul_add_scalar_f32_kernel(a.as_ptr(), scale, bias, out.as_mut_ptr(), len); + fused_mul_add_scalar_loop_f32(a.as_ptr(), scale, bias, out_ref.as_mut_ptr(), len); + } + + for i in 0..len { + let diff = (out[i] - out_ref[i]).abs(); + assert!( + diff < 1e-5, + "fused_mul_add_scalar mismatch at {i}: {} vs {}", + out[i], + out_ref[i] + ); + } + } +} diff --git a/src/runtime/cpu/kernels/simd/fused_elementwise/mod.rs b/src/runtime/cpu/kernels/simd/fused_elementwise/mod.rs new file mode 100644 index 00000000..89adfb2c --- /dev/null +++ b/src/runtime/cpu/kernels/simd/fused_elementwise/mod.rs @@ -0,0 +1,13 @@ +//! SIMD-accelerated fused elementwise operations. +//! +//! See [`dispatch`] for the public dispatch functions and scalar fallbacks. + +#[cfg(target_arch = "x86_64")] +pub(crate) mod x86_64; + +#[cfg(target_arch = "aarch64")] +pub(crate) mod aarch64; + +pub(crate) mod dispatch; + +pub use dispatch::*; diff --git a/src/runtime/cpu/kernels/simd/fused_elementwise/x86_64/avx2.rs b/src/runtime/cpu/kernels/simd/fused_elementwise/x86_64/avx2.rs new file mode 100644 index 00000000..0e2ece9d --- /dev/null +++ b/src/runtime/cpu/kernels/simd/fused_elementwise/x86_64/avx2.rs @@ -0,0 +1,210 @@ +//! AVX2+FMA fused elementwise kernels (256-bit) + +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +use super::super::{ + fused_add_mul_scalar_f32 as fused_add_mul_fallback_f32, + fused_add_mul_scalar_f64 as fused_add_mul_fallback_f64, + fused_mul_add_scalar_f32 as fused_mul_add_fallback_f32, + fused_mul_add_scalar_f64 as fused_mul_add_fallback_f64, fused_mul_add_scalar_loop_f32, + fused_mul_add_scalar_loop_f64, +}; + +const F32_LANES: usize = 8; +const F64_LANES: usize = 4; + +/// AVX2+FMA fused_mul_add for f32: out = a * b + c +#[target_feature(enable = "avx2", enable = "fma")] +pub unsafe fn fused_mul_add_f32( + a: *const f32, + b: *const f32, + c: *const f32, + out: *mut f32, + len: usize, +) { + let chunks = len / F32_LANES; + + for i in 0..chunks { + let offset = i * F32_LANES; + let va = _mm256_loadu_ps(a.add(offset)); + let vb = _mm256_loadu_ps(b.add(offset)); + let vc = _mm256_loadu_ps(c.add(offset)); + // FMA: va * vb + vc in single instruction + let result = _mm256_fmadd_ps(va, vb, vc); + _mm256_storeu_ps(out.add(offset), result); + } + + let processed = chunks * F32_LANES; + if processed < len { + fused_mul_add_fallback_f32( + a.add(processed), + b.add(processed), + c.add(processed), + out.add(processed), + len - processed, + ); + } +} + +/// AVX2+FMA fused_mul_add for f64: out = a * b + c +#[target_feature(enable = "avx2", enable = "fma")] +pub unsafe fn fused_mul_add_f64( + a: *const f64, + b: *const f64, + c: *const f64, + out: *mut f64, + len: usize, +) { + let chunks = len / F64_LANES; + + for i in 0..chunks { + let offset = i * F64_LANES; + let va = _mm256_loadu_pd(a.add(offset)); + let vb = _mm256_loadu_pd(b.add(offset)); + let vc = _mm256_loadu_pd(c.add(offset)); + let result = _mm256_fmadd_pd(va, vb, vc); + _mm256_storeu_pd(out.add(offset), result); + } + + let processed = chunks * F64_LANES; + if processed < len { + fused_mul_add_fallback_f64( + a.add(processed), + b.add(processed), + c.add(processed), + out.add(processed), + len - processed, + ); + } +} + +/// AVX2 fused_add_mul for f32: out = (a + b) * c +#[target_feature(enable = "avx2", enable = "fma")] +pub unsafe fn fused_add_mul_f32( + a: *const f32, + b: *const f32, + c: *const f32, + out: *mut f32, + len: usize, +) { + let chunks = len / F32_LANES; + + for i in 0..chunks { + let offset = i * F32_LANES; + let va = _mm256_loadu_ps(a.add(offset)); + let vb = _mm256_loadu_ps(b.add(offset)); + let vc = _mm256_loadu_ps(c.add(offset)); + let sum = _mm256_add_ps(va, vb); + let result = _mm256_mul_ps(sum, vc); + _mm256_storeu_ps(out.add(offset), result); + } + + let processed = chunks * F32_LANES; + if processed < len { + fused_add_mul_fallback_f32( + a.add(processed), + b.add(processed), + c.add(processed), + out.add(processed), + len - processed, + ); + } +} + +/// AVX2 fused_add_mul for f64: out = (a + b) * c +#[target_feature(enable = "avx2", enable = "fma")] +pub unsafe fn fused_add_mul_f64( + a: *const f64, + b: *const f64, + c: *const f64, + out: *mut f64, + len: usize, +) { + let chunks = len / F64_LANES; + + for i in 0..chunks { + let offset = i * F64_LANES; + let va = _mm256_loadu_pd(a.add(offset)); + let vb = _mm256_loadu_pd(b.add(offset)); + let vc = _mm256_loadu_pd(c.add(offset)); + let sum = _mm256_add_pd(va, vb); + let result = _mm256_mul_pd(sum, vc); + _mm256_storeu_pd(out.add(offset), result); + } + + let processed = chunks * F64_LANES; + if processed < len { + fused_add_mul_fallback_f64( + a.add(processed), + b.add(processed), + c.add(processed), + out.add(processed), + len - processed, + ); + } +} + +/// AVX2+FMA fused_mul_add_scalar for f32: out = a * scale + bias +#[target_feature(enable = "avx2", enable = "fma")] +pub unsafe fn fused_mul_add_scalar_f32( + a: *const f32, + scale: f32, + bias: f32, + out: *mut f32, + len: usize, +) { + let chunks = len / F32_LANES; + let vscale = _mm256_set1_ps(scale); + let vbias = _mm256_set1_ps(bias); + + for i in 0..chunks { + let offset = i * F32_LANES; + let va = _mm256_loadu_ps(a.add(offset)); + let result = _mm256_fmadd_ps(va, vscale, vbias); + _mm256_storeu_ps(out.add(offset), result); + } + + let processed = chunks * F32_LANES; + if processed < len { + fused_mul_add_scalar_loop_f32( + a.add(processed), + scale, + bias, + out.add(processed), + len - processed, + ); + } +} + +/// AVX2+FMA fused_mul_add_scalar for f64: out = a * scale + bias +#[target_feature(enable = "avx2", enable = "fma")] +pub unsafe fn fused_mul_add_scalar_f64( + a: *const f64, + scale: f64, + bias: f64, + out: *mut f64, + len: usize, +) { + let chunks = len / F64_LANES; + let vscale = _mm256_set1_pd(scale); + let vbias = _mm256_set1_pd(bias); + + for i in 0..chunks { + let offset = i * F64_LANES; + let va = _mm256_loadu_pd(a.add(offset)); + let result = _mm256_fmadd_pd(va, vscale, vbias); + _mm256_storeu_pd(out.add(offset), result); + } + + let processed = chunks * F64_LANES; + if processed < len { + fused_mul_add_scalar_loop_f64( + a.add(processed), + scale, + bias, + out.add(processed), + len - processed, + ); + } +} diff --git a/src/runtime/cpu/kernels/simd/fused_elementwise/x86_64/avx512.rs b/src/runtime/cpu/kernels/simd/fused_elementwise/x86_64/avx512.rs new file mode 100644 index 00000000..07d87897 --- /dev/null +++ b/src/runtime/cpu/kernels/simd/fused_elementwise/x86_64/avx512.rs @@ -0,0 +1,209 @@ +//! AVX-512 fused elementwise kernels (512-bit) + +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +use super::super::{ + fused_add_mul_scalar_f32 as fused_add_mul_fallback_f32, + fused_add_mul_scalar_f64 as fused_add_mul_fallback_f64, + fused_mul_add_scalar_f32 as fused_mul_add_fallback_f32, + fused_mul_add_scalar_f64 as fused_mul_add_fallback_f64, fused_mul_add_scalar_loop_f32, + fused_mul_add_scalar_loop_f64, +}; + +const F32_LANES: usize = 16; +const F64_LANES: usize = 8; + +/// AVX-512 fused_mul_add for f32: out = a * b + c +#[target_feature(enable = "avx512f")] +pub unsafe fn fused_mul_add_f32( + a: *const f32, + b: *const f32, + c: *const f32, + out: *mut f32, + len: usize, +) { + let chunks = len / F32_LANES; + + for i in 0..chunks { + let offset = i * F32_LANES; + let va = _mm512_loadu_ps(a.add(offset)); + let vb = _mm512_loadu_ps(b.add(offset)); + let vc = _mm512_loadu_ps(c.add(offset)); + let result = _mm512_fmadd_ps(va, vb, vc); + _mm512_storeu_ps(out.add(offset), result); + } + + let processed = chunks * F32_LANES; + if processed < len { + fused_mul_add_fallback_f32( + a.add(processed), + b.add(processed), + c.add(processed), + out.add(processed), + len - processed, + ); + } +} + +/// AVX-512 fused_mul_add for f64: out = a * b + c +#[target_feature(enable = "avx512f")] +pub unsafe fn fused_mul_add_f64( + a: *const f64, + b: *const f64, + c: *const f64, + out: *mut f64, + len: usize, +) { + let chunks = len / F64_LANES; + + for i in 0..chunks { + let offset = i * F64_LANES; + let va = _mm512_loadu_pd(a.add(offset)); + let vb = _mm512_loadu_pd(b.add(offset)); + let vc = _mm512_loadu_pd(c.add(offset)); + let result = _mm512_fmadd_pd(va, vb, vc); + _mm512_storeu_pd(out.add(offset), result); + } + + let processed = chunks * F64_LANES; + if processed < len { + fused_mul_add_fallback_f64( + a.add(processed), + b.add(processed), + c.add(processed), + out.add(processed), + len - processed, + ); + } +} + +/// AVX-512 fused_add_mul for f32: out = (a + b) * c +#[target_feature(enable = "avx512f")] +pub unsafe fn fused_add_mul_f32( + a: *const f32, + b: *const f32, + c: *const f32, + out: *mut f32, + len: usize, +) { + let chunks = len / F32_LANES; + + for i in 0..chunks { + let offset = i * F32_LANES; + let va = _mm512_loadu_ps(a.add(offset)); + let vb = _mm512_loadu_ps(b.add(offset)); + let vc = _mm512_loadu_ps(c.add(offset)); + let sum = _mm512_add_ps(va, vb); + let result = _mm512_mul_ps(sum, vc); + _mm512_storeu_ps(out.add(offset), result); + } + + let processed = chunks * F32_LANES; + if processed < len { + fused_add_mul_fallback_f32( + a.add(processed), + b.add(processed), + c.add(processed), + out.add(processed), + len - processed, + ); + } +} + +/// AVX-512 fused_add_mul for f64: out = (a + b) * c +#[target_feature(enable = "avx512f")] +pub unsafe fn fused_add_mul_f64( + a: *const f64, + b: *const f64, + c: *const f64, + out: *mut f64, + len: usize, +) { + let chunks = len / F64_LANES; + + for i in 0..chunks { + let offset = i * F64_LANES; + let va = _mm512_loadu_pd(a.add(offset)); + let vb = _mm512_loadu_pd(b.add(offset)); + let vc = _mm512_loadu_pd(c.add(offset)); + let sum = _mm512_add_pd(va, vb); + let result = _mm512_mul_pd(sum, vc); + _mm512_storeu_pd(out.add(offset), result); + } + + let processed = chunks * F64_LANES; + if processed < len { + fused_add_mul_fallback_f64( + a.add(processed), + b.add(processed), + c.add(processed), + out.add(processed), + len - processed, + ); + } +} + +/// AVX-512 fused_mul_add_scalar for f32: out = a * scale + bias +#[target_feature(enable = "avx512f")] +pub unsafe fn fused_mul_add_scalar_f32( + a: *const f32, + scale: f32, + bias: f32, + out: *mut f32, + len: usize, +) { + let chunks = len / F32_LANES; + let vscale = _mm512_set1_ps(scale); + let vbias = _mm512_set1_ps(bias); + + for i in 0..chunks { + let offset = i * F32_LANES; + let va = _mm512_loadu_ps(a.add(offset)); + let result = _mm512_fmadd_ps(va, vscale, vbias); + _mm512_storeu_ps(out.add(offset), result); + } + + let processed = chunks * F32_LANES; + if processed < len { + fused_mul_add_scalar_loop_f32( + a.add(processed), + scale, + bias, + out.add(processed), + len - processed, + ); + } +} + +/// AVX-512 fused_mul_add_scalar for f64: out = a * scale + bias +#[target_feature(enable = "avx512f")] +pub unsafe fn fused_mul_add_scalar_f64( + a: *const f64, + scale: f64, + bias: f64, + out: *mut f64, + len: usize, +) { + let chunks = len / F64_LANES; + let vscale = _mm512_set1_pd(scale); + let vbias = _mm512_set1_pd(bias); + + for i in 0..chunks { + let offset = i * F64_LANES; + let va = _mm512_loadu_pd(a.add(offset)); + let result = _mm512_fmadd_pd(va, vscale, vbias); + _mm512_storeu_pd(out.add(offset), result); + } + + let processed = chunks * F64_LANES; + if processed < len { + fused_mul_add_scalar_loop_f64( + a.add(processed), + scale, + bias, + out.add(processed), + len - processed, + ); + } +} diff --git a/src/runtime/cpu/kernels/simd/fused_elementwise/x86_64/mod.rs b/src/runtime/cpu/kernels/simd/fused_elementwise/x86_64/mod.rs new file mode 100644 index 00000000..451cc92d --- /dev/null +++ b/src/runtime/cpu/kernels/simd/fused_elementwise/x86_64/mod.rs @@ -0,0 +1,2 @@ +pub mod avx2; +pub mod avx512; diff --git a/src/runtime/cpu/kernels/simd/half_convert_utils/aarch64.rs b/src/runtime/cpu/kernels/simd/half_convert_utils/aarch64.rs new file mode 100644 index 00000000..9e232256 --- /dev/null +++ b/src/runtime/cpu/kernels/simd/half_convert_utils/aarch64.rs @@ -0,0 +1,113 @@ +//! aarch64 NEON implementations for f16/bf16 ↔ f32 conversion +//! +//! - f16: NEON `vcvt_f32_f16` / `vcvt_f16_f32` +//! - bf16: NEON integer bit-shift + +// --------------------------------------------------------------------------- +// NEON: f16 ↔ f32 +// --------------------------------------------------------------------------- + +pub(super) unsafe fn convert_f16_to_f32_neon(src: *const u16, dst: *mut f32, len: usize) { + use std::arch::aarch64::*; + + let mut i = 0usize; + + // Process 4 elements at a time using vcvt_f32_f16 + while i + 4 <= len { + let half_vec = vld1_u16(src.add(i)); + let half_f16 = vreinterpret_f16_u16(half_vec); + let float_vec = vcvt_f32_f16(half_f16); + vst1q_f32(dst.add(i), float_vec); + i += 4; + } + + // Scalar tail + while i < len { + *dst.add(i) = half::f16::from_bits(*src.add(i)).to_f32(); + i += 1; + } +} + +pub(super) unsafe fn convert_f32_to_f16_neon(src: *const f32, dst: *mut u16, len: usize) { + use std::arch::aarch64::*; + + let mut i = 0usize; + + // Process 4 elements at a time using vcvt_f16_f32 + while i + 4 <= len { + let float_vec = vld1q_f32(src.add(i)); + let half_f16 = vcvt_f16_f32(float_vec); + let half_u16 = vreinterpret_u16_f16(half_f16); + vst1_u16(dst.add(i), half_u16); + i += 4; + } + + // Scalar tail + while i < len { + *dst.add(i) = half::f16::from_f32(*src.add(i)).to_bits(); + i += 1; + } +} + +// --------------------------------------------------------------------------- +// NEON: bf16 ↔ f32 (integer bit-shift) +// --------------------------------------------------------------------------- + +pub(super) unsafe fn convert_bf16_to_f32_neon(src: *const u16, dst: *mut f32, len: usize) { + use std::arch::aarch64::*; + + let mut i = 0usize; + + // Process 4 bf16 values at a time: zero-extend to u32, shift left 16 + while i + 4 <= len { + let bf16_vec = vld1_u16(src.add(i)); + // vmovl_u16: uint16x4 → uint32x4 (zero-extend) + let u32_vec = vmovl_u16(bf16_vec); + let shifted = vshlq_n_u32(u32_vec, 16); + let f32_vec = vreinterpretq_f32_u32(shifted); + vst1q_f32(dst.add(i), f32_vec); + i += 4; + } + + // Scalar tail + while i < len { + let bits = (*src.add(i) as u32) << 16; + *dst.add(i) = f32::from_bits(bits); + i += 1; + } +} + +pub(super) unsafe fn convert_f32_to_bf16_neon(src: *const f32, dst: *mut u16, len: usize) { + use std::arch::aarch64::*; + + let mut i = 0usize; + + let rounding_bias = vdupq_n_u32(0x7FFF); + let one = vdupq_n_u32(1); + + // Process 4 f32 values at a time + while i + 4 <= len { + let f32_vec = vld1q_f32(src.add(i)); + let bits = vreinterpretq_u32_f32(f32_vec); + + // Round-to-nearest-even + let shifted = vshrq_n_u32(bits, 16); + let lsb = vandq_u32(shifted, one); + let bias = vaddq_u32(rounding_bias, lsb); + let rounded = vaddq_u32(bits, bias); + let bf16_u32 = vshrq_n_u32(rounded, 16); + + // Narrow u32x4 → u16x4 + let bf16_u16 = vmovn_u32(bf16_u32); + vst1_u16(dst.add(i), bf16_u16); + i += 4; + } + + // Scalar tail with same rounding + while i < len { + let bits = (*src.add(i)).to_bits(); + let rounded = bits.wrapping_add(0x7FFF + ((bits >> 16) & 1)); + *dst.add(i) = (rounded >> 16) as u16; + i += 1; + } +} diff --git a/src/runtime/cpu/kernels/simd/half_convert_utils/mod.rs b/src/runtime/cpu/kernels/simd/half_convert_utils/mod.rs new file mode 100644 index 00000000..e3cc3c45 --- /dev/null +++ b/src/runtime/cpu/kernels/simd/half_convert_utils/mod.rs @@ -0,0 +1,294 @@ +//! SIMD-accelerated f16/bf16 ↔ f32 conversion utilities +//! +//! These are the building blocks for the block-convert-compute pattern: +//! convert half-precision data to f32 in L1-sized blocks, run existing +//! f32 SIMD kernels, then convert back. +//! +//! # Conversion strategies +//! +//! - **x86 f16**: F16C instructions (`_mm256_cvtph_ps` / `_mm256_cvtps_ph`) +//! - **x86 bf16**: SIMD integer bit-shift (`u32 << 16` for load, rounded `>> 16` for store) +//! - **ARM f16**: NEON `vcvt_f32_f16` / `vcvt_f16_f32` +//! - **ARM bf16**: NEON integer bit-shift +//! - **Fallback**: `half` crate scalar conversion + +#[cfg(target_arch = "aarch64")] +mod aarch64; +#[cfg(target_arch = "x86_64")] +mod x86_64; + +/// Block size for stack-allocated conversion buffers. +/// 256 f32s = 1024 bytes, fits comfortably in L1 cache. +pub const HALF_BLOCK: usize = 256; + +// --------------------------------------------------------------------------- +// Public API +// --------------------------------------------------------------------------- + +/// Convert f16 values to f32 using SIMD when available. +/// +/// # Safety +/// - `src` must be valid for reads of `len` u16 values (f16 bit patterns) +/// - `dst` must be valid for writes of `len` f32 values +#[inline] +pub unsafe fn convert_f16_to_f32(src: *const u16, dst: *mut f32, len: usize) { + if len == 0 { + return; + } + + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("f16c") { + return x86_64::convert_f16_to_f32_f16c(src, dst, len); + } + } + + #[cfg(target_arch = "aarch64")] + { + return aarch64::convert_f16_to_f32_neon(src, dst, len); + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + convert_f16_to_f32_scalar(src, dst, len); +} + +/// Convert f32 values to f16 using SIMD when available. +/// +/// # Safety +/// - `src` must be valid for reads of `len` f32 values +/// - `dst` must be valid for writes of `len` u16 values (f16 bit patterns) +#[inline] +pub unsafe fn convert_f32_to_f16(src: *const f32, dst: *mut u16, len: usize) { + if len == 0 { + return; + } + + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("f16c") { + return x86_64::convert_f32_to_f16_f16c(src, dst, len); + } + } + + #[cfg(target_arch = "aarch64")] + { + return aarch64::convert_f32_to_f16_neon(src, dst, len); + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + convert_f32_to_f16_scalar(src, dst, len); +} + +/// Convert bf16 values to f32 using SIMD when available. +/// +/// # Safety +/// - `src` must be valid for reads of `len` u16 values (bf16 bit patterns) +/// - `dst` must be valid for writes of `len` f32 values +#[inline] +pub unsafe fn convert_bf16_to_f32(src: *const u16, dst: *mut f32, len: usize) { + if len == 0 { + return; + } + + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("avx2") { + return x86_64::convert_bf16_to_f32_avx2(src, dst, len); + } + } + + #[cfg(target_arch = "aarch64")] + { + return aarch64::convert_bf16_to_f32_neon(src, dst, len); + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + convert_bf16_to_f32_scalar(src, dst, len); +} + +/// Convert f32 values to bf16 using SIMD when available (with round-to-nearest-even). +/// +/// # Safety +/// - `src` must be valid for reads of `len` f32 values +/// - `dst` must be valid for writes of `len` u16 values (bf16 bit patterns) +#[inline] +pub unsafe fn convert_f32_to_bf16(src: *const f32, dst: *mut u16, len: usize) { + if len == 0 { + return; + } + + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("avx2") { + return x86_64::convert_f32_to_bf16_avx2(src, dst, len); + } + } + + #[cfg(target_arch = "aarch64")] + { + return aarch64::convert_f32_to_bf16_neon(src, dst, len); + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + convert_f32_to_bf16_scalar(src, dst, len); +} + +// --------------------------------------------------------------------------- +// Scalar fallbacks +// --------------------------------------------------------------------------- + +#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] +#[inline] +unsafe fn convert_f16_to_f32_scalar(src: *const u16, dst: *mut f32, len: usize) { + for i in 0..len { + *dst.add(i) = half::f16::from_bits(*src.add(i)).to_f32(); + } +} + +#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] +#[inline] +unsafe fn convert_f32_to_f16_scalar(src: *const f32, dst: *mut u16, len: usize) { + for i in 0..len { + *dst.add(i) = half::f16::from_f32(*src.add(i)).to_bits(); + } +} + +#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] +#[inline] +unsafe fn convert_bf16_to_f32_scalar(src: *const u16, dst: *mut f32, len: usize) { + for i in 0..len { + *dst.add(i) = half::bf16::from_bits(*src.add(i)).to_f32(); + } +} + +#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] +#[inline] +unsafe fn convert_f32_to_bf16_scalar(src: *const f32, dst: *mut u16, len: usize) { + for i in 0..len { + let bits = (*src.add(i)).to_bits(); + let rounded = bits.wrapping_add(0x7FFF + ((bits >> 16) & 1)); + *dst.add(i) = (rounded >> 16) as u16; + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_f16_roundtrip() { + let values: Vec = vec![ + 0.0, + 1.0, + -1.0, + 0.5, + -0.5, + 65504.0, + -65504.0, + 0.000061035156, + 3.15, + ]; + let f16_bits: Vec = values + .iter() + .map(|&v| half::f16::from_f32(v).to_bits()) + .collect(); + let mut f32_out = vec![0.0f32; values.len()]; + let mut f16_out = vec![0u16; values.len()]; + + unsafe { + convert_f16_to_f32(f16_bits.as_ptr(), f32_out.as_mut_ptr(), values.len()); + convert_f32_to_f16(f32_out.as_ptr(), f16_out.as_mut_ptr(), f32_out.len()); + } + + for i in 0..values.len() { + assert_eq!( + f16_bits[i], f16_out[i], + "f16 roundtrip failed at index {}: input bits {:04x}, output bits {:04x}", + i, f16_bits[i], f16_out[i] + ); + } + } + + #[test] + fn test_bf16_roundtrip() { + let values: Vec = vec![0.0, 1.0, -1.0, 0.5, -0.5, 100.0, -100.0, 3.15]; + let bf16_bits: Vec = values + .iter() + .map(|&v| half::bf16::from_f32(v).to_bits()) + .collect(); + let mut f32_out = vec![0.0f32; values.len()]; + let mut bf16_out = vec![0u16; values.len()]; + + unsafe { + convert_bf16_to_f32(bf16_bits.as_ptr(), f32_out.as_mut_ptr(), values.len()); + convert_f32_to_bf16(f32_out.as_ptr(), bf16_out.as_mut_ptr(), f32_out.len()); + } + + for i in 0..values.len() { + assert_eq!( + bf16_bits[i], bf16_out[i], + "bf16 roundtrip failed at index {}: input bits {:04x}, output bits {:04x}", + i, bf16_bits[i], bf16_out[i] + ); + } + } + + #[test] + fn test_f16_conversion_accuracy() { + let f16_bits: Vec = (0..512) + .map(|i| half::f16::from_f32((i as f32 - 256.0) * 0.1).to_bits()) + .collect(); + let mut f32_out = vec![0.0f32; f16_bits.len()]; + unsafe { convert_f16_to_f32(f16_bits.as_ptr(), f32_out.as_mut_ptr(), f16_bits.len()) } + + for i in 0..f16_bits.len() { + let expected = half::f16::from_bits(f16_bits[i]).to_f32(); + assert_eq!(f32_out[i], expected, "f16→f32 mismatch at index {}", i); + } + } + + #[test] + fn test_bf16_conversion_accuracy() { + let bf16_bits: Vec = (0..512) + .map(|i| half::bf16::from_f32((i as f32 - 256.0) * 0.1).to_bits()) + .collect(); + let mut f32_out = vec![0.0f32; bf16_bits.len()]; + unsafe { convert_bf16_to_f32(bf16_bits.as_ptr(), f32_out.as_mut_ptr(), bf16_bits.len()) } + + for i in 0..bf16_bits.len() { + let expected = half::bf16::from_bits(bf16_bits[i]).to_f32(); + assert_eq!(f32_out[i], expected, "bf16→f32 mismatch at index {}", i); + } + } + + #[test] + fn test_empty_conversion() { + unsafe { + convert_f16_to_f32(std::ptr::null(), std::ptr::null_mut(), 0); + convert_f32_to_f16(std::ptr::null(), std::ptr::null_mut(), 0); + convert_bf16_to_f32(std::ptr::null(), std::ptr::null_mut(), 0); + convert_f32_to_bf16(std::ptr::null(), std::ptr::null_mut(), 0); + } + } + + #[test] + fn test_unaligned_lengths() { + for len in [1, 3, 5, 7, 9, 15, 17, 31, 33] { + let f16_bits: Vec = (0..len) + .map(|i| half::f16::from_f32(i as f32).to_bits()) + .collect(); + let mut f32_out = vec![0.0f32; len]; + + unsafe { convert_f16_to_f32(f16_bits.as_ptr(), f32_out.as_mut_ptr(), len) } + + for i in 0..len { + let expected = half::f16::from_bits(f16_bits[i]).to_f32(); + assert_eq!(f32_out[i], expected, "mismatch at len={}, index={}", len, i); + } + } + } +} diff --git a/src/runtime/cpu/kernels/simd/half_convert_utils/x86_64.rs b/src/runtime/cpu/kernels/simd/half_convert_utils/x86_64.rs new file mode 100644 index 00000000..197bda35 --- /dev/null +++ b/src/runtime/cpu/kernels/simd/half_convert_utils/x86_64.rs @@ -0,0 +1,119 @@ +//! x86_64 SIMD implementations for f16/bf16 ↔ f32 conversion +//! +//! - f16: F16C instructions (`_mm256_cvtph_ps` / `_mm256_cvtps_ph`) +//! - bf16: AVX2 integer bit-shift (`u32 << 16` / rounded `>> 16`) + +// --------------------------------------------------------------------------- +// F16C: f16 ↔ f32 +// --------------------------------------------------------------------------- + +#[target_feature(enable = "f16c,avx")] +pub(super) unsafe fn convert_f16_to_f32_f16c(src: *const u16, dst: *mut f32, len: usize) { + use std::arch::x86_64::*; + + let mut i = 0usize; + + // Process 8 elements at a time + while i + 8 <= len { + let half_vec = _mm_loadu_si128(src.add(i) as *const __m128i); + let float_vec = _mm256_cvtph_ps(half_vec); + _mm256_storeu_ps(dst.add(i), float_vec); + i += 8; + } + + // Scalar tail + while i < len { + *dst.add(i) = half::f16::from_bits(*src.add(i)).to_f32(); + i += 1; + } +} + +#[target_feature(enable = "f16c,avx")] +pub(super) unsafe fn convert_f32_to_f16_f16c(src: *const f32, dst: *mut u16, len: usize) { + use std::arch::x86_64::*; + + let mut i = 0usize; + + // _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC = 0x08 + while i + 8 <= len { + let float_vec = _mm256_loadu_ps(src.add(i)); + let half_vec = _mm256_cvtps_ph(float_vec, _MM_FROUND_TO_NEAREST_INT); + _mm_storeu_si128(dst.add(i) as *mut __m128i, half_vec); + i += 8; + } + + // Scalar tail + while i < len { + *dst.add(i) = half::f16::from_f32(*src.add(i)).to_bits(); + i += 1; + } +} + +// --------------------------------------------------------------------------- +// AVX2: bf16 ↔ f32 (integer bit-shift) +// --------------------------------------------------------------------------- + +#[target_feature(enable = "avx2")] +pub(super) unsafe fn convert_bf16_to_f32_avx2(src: *const u16, dst: *mut f32, len: usize) { + use std::arch::x86_64::*; + + let mut i = 0usize; + + // bf16 → f32: zero-extend u16 to u32, shift left by 16 + while i + 8 <= len { + let bf16_vec = _mm_loadu_si128(src.add(i) as *const __m128i); + let u32_vec = _mm256_cvtepu16_epi32(bf16_vec); + let f32_bits = _mm256_slli_epi32(u32_vec, 16); + _mm256_storeu_ps(dst.add(i), _mm256_castsi256_ps(f32_bits)); + i += 8; + } + + // Scalar tail + while i < len { + let bits = (*src.add(i) as u32) << 16; + *dst.add(i) = f32::from_bits(bits); + i += 1; + } +} + +#[target_feature(enable = "avx2")] +pub(super) unsafe fn convert_f32_to_bf16_avx2(src: *const f32, dst: *mut u16, len: usize) { + use std::arch::x86_64::*; + + let mut i = 0usize; + + // f32 → bf16 with round-to-nearest-even: + // Add rounding bias 0x7FFF + ((bits >> 16) & 1), then shift right 16 + let rounding_bias = _mm256_set1_epi32(0x7FFF); + let one = _mm256_set1_epi32(1); + + while i + 8 <= len { + let f32_vec = _mm256_loadu_ps(src.add(i)); + let bits = _mm256_castps_si256(f32_vec); + + // Round-to-nearest-even: bias = 0x7FFF + ((bits >> 16) & 1) + let shifted = _mm256_srli_epi32(bits, 16); + let lsb = _mm256_and_si256(shifted, one); + let bias = _mm256_add_epi32(rounding_bias, lsb); + + // Add bias and shift right + let rounded = _mm256_add_epi32(bits, bias); + let bf16_u32 = _mm256_srli_epi32(rounded, 16); + + // Pack 8 u32 values down to 8 u16 values + let lo = _mm256_castsi256_si128(bf16_u32); + let hi = _mm256_extracti128_si256(bf16_u32, 1); + let packed = _mm_packus_epi32(lo, hi); + + _mm_storeu_si128(dst.add(i) as *mut __m128i, packed); + i += 8; + } + + // Scalar tail with same rounding + while i < len { + let bits = (*src.add(i)).to_bits(); + let rounded = bits.wrapping_add(0x7FFF + ((bits >> 16) & 1)); + *dst.add(i) = (rounded >> 16) as u16; + i += 1; + } +} diff --git a/src/runtime/cpu/kernels/simd/half_macros.rs b/src/runtime/cpu/kernels/simd/half_macros.rs new file mode 100644 index 00000000..e1c327cd --- /dev/null +++ b/src/runtime/cpu/kernels/simd/half_macros.rs @@ -0,0 +1,337 @@ +//! Macros for generating f16/bf16 block-convert-compute wrappers. +//! +//! These macros eliminate boilerplate by generating both f16 and bf16 variants +//! of functions that operate via the block-convert-compute pattern: +//! 1. Convert half-precision input(s) to f32 in L1-sized stack blocks +//! 2. Call the existing f32 SIMD kernel +//! 3. Convert f32 output back to half-precision +//! +//! # Available Macros +//! +//! | Macro | Pattern | Example | +//! |-------|---------|---------| +//! | `half_unary!` | `fn(in, out, len)` | sigmoid, relu, erf | +//! | `half_unary_op!` | `fn(op, in, out, len)` | unary(UnaryOp) | +//! | `half_unary_param!` | `fn(in, out, len, p)` | leaky_relu, elu | +//! | `half_binary_op!` | `fn(op, a, b, out, len)` | binary, compare | +//! | `half_scalar_op!` | `fn(op, a, s, out, len)` | scalar ops | +//! | `half_unary_scalar!` | `fn(a, s, out, len)` | rsub_scalar | +//! | `half_where!` | `fn(cond, x, y, out, len)` | where_select | +//! | `half_clamp!` | `fn(a, out, len, min, max)` | clamp | + +/// Internal: generate a single half-precision variant (f16 or bf16). +/// All public macros delegate to this to avoid duplicating the block-convert loop. +macro_rules! _half_variant { + // 1-input, no extra args: fn(input, output, len) + (unary, $fn_name:ident, $half_ty:ty, $to_f32:path, $from_f32:path, $f32_fn:path) => { + #[cfg(feature = "f16")] + #[inline] + pub unsafe fn $fn_name(input: *const $half_ty, output: *mut $half_ty, len: usize) { + use crate::runtime::cpu::kernels::simd::half_convert_utils::HALF_BLOCK; + let mut a_buf = [0.0f32; HALF_BLOCK]; + let mut out_buf = [0.0f32; HALF_BLOCK]; + let mut offset = 0; + while offset < len { + let chunk = (len - offset).min(HALF_BLOCK); + $to_f32(input.add(offset) as *const u16, a_buf.as_mut_ptr(), chunk); + $f32_fn(a_buf.as_ptr(), out_buf.as_mut_ptr(), chunk); + $from_f32(out_buf.as_ptr(), output.add(offset) as *mut u16, chunk); + offset += chunk; + } + } + }; + // 1-input with leading op: fn(op, input, output, len) + (unary_op, $fn_name:ident, $half_ty:ty, $to_f32:path, $from_f32:path, $f32_fn:path, $op_ty:ty) => { + #[cfg(feature = "f16")] + #[inline] + pub unsafe fn $fn_name( + op: $op_ty, + input: *const $half_ty, + output: *mut $half_ty, + len: usize, + ) { + use crate::runtime::cpu::kernels::simd::half_convert_utils::HALF_BLOCK; + let mut a_buf = [0.0f32; HALF_BLOCK]; + let mut out_buf = [0.0f32; HALF_BLOCK]; + let mut offset = 0; + while offset < len { + let chunk = (len - offset).min(HALF_BLOCK); + $to_f32(input.add(offset) as *const u16, a_buf.as_mut_ptr(), chunk); + $f32_fn(op, a_buf.as_ptr(), out_buf.as_mut_ptr(), chunk); + $from_f32(out_buf.as_ptr(), output.add(offset) as *mut u16, chunk); + offset += chunk; + } + } + }; + // 1-input with trailing f32 param: fn(input, output, len, param) + (unary_param, $fn_name:ident, $half_ty:ty, $to_f32:path, $from_f32:path, $f32_fn:path) => { + #[cfg(feature = "f16")] + #[inline] + pub unsafe fn $fn_name( + input: *const $half_ty, + output: *mut $half_ty, + len: usize, + param: f32, + ) { + use crate::runtime::cpu::kernels::simd::half_convert_utils::HALF_BLOCK; + let mut a_buf = [0.0f32; HALF_BLOCK]; + let mut out_buf = [0.0f32; HALF_BLOCK]; + let mut offset = 0; + while offset < len { + let chunk = (len - offset).min(HALF_BLOCK); + $to_f32(input.add(offset) as *const u16, a_buf.as_mut_ptr(), chunk); + $f32_fn(a_buf.as_ptr(), out_buf.as_mut_ptr(), chunk, param); + $from_f32(out_buf.as_ptr(), output.add(offset) as *mut u16, chunk); + offset += chunk; + } + } + }; + // 2-input with leading op: fn(op, a, b, output, len) + (binary_op, $fn_name:ident, $half_ty:ty, $to_f32:path, $from_f32:path, $f32_fn:path, $op_ty:ty) => { + #[cfg(feature = "f16")] + #[inline] + pub unsafe fn $fn_name( + op: $op_ty, + a: *const $half_ty, + b: *const $half_ty, + out: *mut $half_ty, + len: usize, + ) { + use crate::runtime::cpu::kernels::simd::half_convert_utils::HALF_BLOCK; + let mut a_buf = [0.0f32; HALF_BLOCK]; + let mut b_buf = [0.0f32; HALF_BLOCK]; + let mut out_buf = [0.0f32; HALF_BLOCK]; + let mut offset = 0; + while offset < len { + let chunk = (len - offset).min(HALF_BLOCK); + $to_f32(a.add(offset) as *const u16, a_buf.as_mut_ptr(), chunk); + $to_f32(b.add(offset) as *const u16, b_buf.as_mut_ptr(), chunk); + $f32_fn( + op, + a_buf.as_ptr(), + b_buf.as_ptr(), + out_buf.as_mut_ptr(), + chunk, + ); + $from_f32(out_buf.as_ptr(), out.add(offset) as *mut u16, chunk); + offset += chunk; + } + } + }; + // 1-input with op + scalar: fn(op, a, scalar, output, len) + (scalar_op, $fn_name:ident, $half_ty:ty, $to_f32:path, $from_f32:path, $f32_fn:path, $op_ty:ty) => { + #[cfg(feature = "f16")] + #[inline] + pub unsafe fn $fn_name( + op: $op_ty, + a: *const $half_ty, + scalar: f32, + out: *mut $half_ty, + len: usize, + ) { + use crate::runtime::cpu::kernels::simd::half_convert_utils::HALF_BLOCK; + let mut a_buf = [0.0f32; HALF_BLOCK]; + let mut out_buf = [0.0f32; HALF_BLOCK]; + let mut offset = 0; + while offset < len { + let chunk = (len - offset).min(HALF_BLOCK); + $to_f32(a.add(offset) as *const u16, a_buf.as_mut_ptr(), chunk); + $f32_fn(op, a_buf.as_ptr(), scalar, out_buf.as_mut_ptr(), chunk); + $from_f32(out_buf.as_ptr(), out.add(offset) as *mut u16, chunk); + offset += chunk; + } + } + }; + // 1-input with scalar (no op): fn(a, scalar, output, len) + (unary_scalar, $fn_name:ident, $half_ty:ty, $to_f32:path, $from_f32:path, $f32_fn:path) => { + #[cfg(feature = "f16")] + #[inline] + pub unsafe fn $fn_name(a: *const $half_ty, scalar: f32, out: *mut $half_ty, len: usize) { + use crate::runtime::cpu::kernels::simd::half_convert_utils::HALF_BLOCK; + let mut a_buf = [0.0f32; HALF_BLOCK]; + let mut out_buf = [0.0f32; HALF_BLOCK]; + let mut offset = 0; + while offset < len { + let chunk = (len - offset).min(HALF_BLOCK); + $to_f32(a.add(offset) as *const u16, a_buf.as_mut_ptr(), chunk); + $f32_fn(a_buf.as_ptr(), scalar, out_buf.as_mut_ptr(), chunk); + $from_f32(out_buf.as_ptr(), out.add(offset) as *mut u16, chunk); + offset += chunk; + } + } + }; + // where/select: fn(cond, x, y, output, len) + (where_select, $fn_name:ident, $half_ty:ty, $to_f32:path, $from_f32:path, $f32_fn:path) => { + #[cfg(feature = "f16")] + #[inline] + pub unsafe fn $fn_name( + cond: *const u8, + x: *const $half_ty, + y: *const $half_ty, + out: *mut $half_ty, + len: usize, + ) { + use crate::runtime::cpu::kernels::simd::half_convert_utils::HALF_BLOCK; + let mut x_buf = [0.0f32; HALF_BLOCK]; + let mut y_buf = [0.0f32; HALF_BLOCK]; + let mut out_buf = [0.0f32; HALF_BLOCK]; + let mut offset = 0; + while offset < len { + let chunk = (len - offset).min(HALF_BLOCK); + $to_f32(x.add(offset) as *const u16, x_buf.as_mut_ptr(), chunk); + $to_f32(y.add(offset) as *const u16, y_buf.as_mut_ptr(), chunk); + $f32_fn( + cond.add(offset), + x_buf.as_ptr(), + y_buf.as_ptr(), + out_buf.as_mut_ptr(), + chunk, + ); + $from_f32(out_buf.as_ptr(), out.add(offset) as *mut u16, chunk); + offset += chunk; + } + } + }; + // clamp: fn(a, output, len, min, max) + (clamp, $fn_name:ident, $half_ty:ty, $to_f32:path, $from_f32:path, $f32_fn:path) => { + #[cfg(feature = "f16")] + #[inline] + pub unsafe fn $fn_name( + a: *const $half_ty, + out: *mut $half_ty, + len: usize, + min_val: f32, + max_val: f32, + ) { + use crate::runtime::cpu::kernels::simd::half_convert_utils::HALF_BLOCK; + let mut a_buf = [0.0f32; HALF_BLOCK]; + let mut out_buf = [0.0f32; HALF_BLOCK]; + let mut offset = 0; + while offset < len { + let chunk = (len - offset).min(HALF_BLOCK); + $to_f32(a.add(offset) as *const u16, a_buf.as_mut_ptr(), chunk); + $f32_fn( + a_buf.as_ptr(), + out_buf.as_mut_ptr(), + chunk, + min_val, + max_val, + ); + $from_f32(out_buf.as_ptr(), out.add(offset) as *mut u16, chunk); + offset += chunk; + } + } + }; +} + +/// Generate f16/bf16 wrappers for unary: `fn(input, output, len)` +macro_rules! half_unary { + ($name:ident, $f32_fn:path) => { + paste::paste! { + _half_variant!(unary, [<$name _f16>], half::f16, + crate::runtime::cpu::kernels::simd::half_convert_utils::convert_f16_to_f32, + crate::runtime::cpu::kernels::simd::half_convert_utils::convert_f32_to_f16, $f32_fn); + _half_variant!(unary, [<$name _bf16>], half::bf16, + crate::runtime::cpu::kernels::simd::half_convert_utils::convert_bf16_to_f32, + crate::runtime::cpu::kernels::simd::half_convert_utils::convert_f32_to_bf16, $f32_fn); + } + }; +} + +/// Generate f16/bf16 wrappers for unary with leading op: `fn(op, input, output, len)` +macro_rules! half_unary_op { + ($name:ident, $f32_fn:path, $op_ty:ty) => { + paste::paste! { + _half_variant!(unary_op, [<$name _f16>], half::f16, + crate::runtime::cpu::kernels::simd::half_convert_utils::convert_f16_to_f32, + crate::runtime::cpu::kernels::simd::half_convert_utils::convert_f32_to_f16, $f32_fn, $op_ty); + _half_variant!(unary_op, [<$name _bf16>], half::bf16, + crate::runtime::cpu::kernels::simd::half_convert_utils::convert_bf16_to_f32, + crate::runtime::cpu::kernels::simd::half_convert_utils::convert_f32_to_bf16, $f32_fn, $op_ty); + } + }; +} + +/// Generate f16/bf16 wrappers for unary with trailing f32 param: `fn(input, output, len, param)` +macro_rules! half_unary_param { + ($name:ident, $f32_fn:path) => { + paste::paste! { + _half_variant!(unary_param, [<$name _f16>], half::f16, + crate::runtime::cpu::kernels::simd::half_convert_utils::convert_f16_to_f32, + crate::runtime::cpu::kernels::simd::half_convert_utils::convert_f32_to_f16, $f32_fn); + _half_variant!(unary_param, [<$name _bf16>], half::bf16, + crate::runtime::cpu::kernels::simd::half_convert_utils::convert_bf16_to_f32, + crate::runtime::cpu::kernels::simd::half_convert_utils::convert_f32_to_bf16, $f32_fn); + } + }; +} + +/// Generate f16/bf16 wrappers for binary with op: `fn(op, a, b, output, len)` +macro_rules! half_binary_op { + ($name:ident, $f32_fn:path, $op_ty:ty) => { + paste::paste! { + _half_variant!(binary_op, [<$name _f16>], half::f16, + crate::runtime::cpu::kernels::simd::half_convert_utils::convert_f16_to_f32, + crate::runtime::cpu::kernels::simd::half_convert_utils::convert_f32_to_f16, $f32_fn, $op_ty); + _half_variant!(binary_op, [<$name _bf16>], half::bf16, + crate::runtime::cpu::kernels::simd::half_convert_utils::convert_bf16_to_f32, + crate::runtime::cpu::kernels::simd::half_convert_utils::convert_f32_to_bf16, $f32_fn, $op_ty); + } + }; +} + +/// Generate f16/bf16 wrappers for scalar op: `fn(op, a, scalar, output, len)` +macro_rules! half_scalar_op { + ($name:ident, $f32_fn:path, $op_ty:ty) => { + paste::paste! { + _half_variant!(scalar_op, [<$name _f16>], half::f16, + crate::runtime::cpu::kernels::simd::half_convert_utils::convert_f16_to_f32, + crate::runtime::cpu::kernels::simd::half_convert_utils::convert_f32_to_f16, $f32_fn, $op_ty); + _half_variant!(scalar_op, [<$name _bf16>], half::bf16, + crate::runtime::cpu::kernels::simd::half_convert_utils::convert_bf16_to_f32, + crate::runtime::cpu::kernels::simd::half_convert_utils::convert_f32_to_bf16, $f32_fn, $op_ty); + } + }; +} + +/// Generate f16/bf16 wrappers for simple scalar fn: `fn(a, scalar, output, len)` +macro_rules! half_unary_scalar { + ($name:ident, $f32_fn:path) => { + paste::paste! { + _half_variant!(unary_scalar, [<$name _f16>], half::f16, + crate::runtime::cpu::kernels::simd::half_convert_utils::convert_f16_to_f32, + crate::runtime::cpu::kernels::simd::half_convert_utils::convert_f32_to_f16, $f32_fn); + _half_variant!(unary_scalar, [<$name _bf16>], half::bf16, + crate::runtime::cpu::kernels::simd::half_convert_utils::convert_bf16_to_f32, + crate::runtime::cpu::kernels::simd::half_convert_utils::convert_f32_to_bf16, $f32_fn); + } + }; +} + +/// Generate f16/bf16 wrappers for where/select: `fn(cond, x, y, output, len)` +macro_rules! half_where { + ($name:ident, $f32_fn:path) => { + paste::paste! { + _half_variant!(where_select, [<$name _f16>], half::f16, + crate::runtime::cpu::kernels::simd::half_convert_utils::convert_f16_to_f32, + crate::runtime::cpu::kernels::simd::half_convert_utils::convert_f32_to_f16, $f32_fn); + _half_variant!(where_select, [<$name _bf16>], half::bf16, + crate::runtime::cpu::kernels::simd::half_convert_utils::convert_bf16_to_f32, + crate::runtime::cpu::kernels::simd::half_convert_utils::convert_f32_to_bf16, $f32_fn); + } + }; +} + +/// Generate f16/bf16 wrappers for clamp: `fn(a, output, len, min, max)` +macro_rules! half_clamp { + ($name:ident, $f32_fn:path) => { + paste::paste! { + _half_variant!(clamp, [<$name _f16>], half::f16, + crate::runtime::cpu::kernels::simd::half_convert_utils::convert_f16_to_f32, + crate::runtime::cpu::kernels::simd::half_convert_utils::convert_f32_to_f16, $f32_fn); + _half_variant!(clamp, [<$name _bf16>], half::bf16, + crate::runtime::cpu::kernels::simd::half_convert_utils::convert_bf16_to_f32, + crate::runtime::cpu::kernels::simd::half_convert_utils::convert_f32_to_bf16, $f32_fn); + } + }; +} diff --git a/src/runtime/cpu/kernels/simd/index/aarch64/neon.rs b/src/runtime/cpu/kernels/simd/index/aarch64/neon.rs index 4e20ada3..05e932bc 100644 --- a/src/runtime/cpu/kernels/simd/index/aarch64/neon.rs +++ b/src/runtime/cpu/kernels/simd/index/aarch64/neon.rs @@ -222,7 +222,7 @@ pub unsafe fn masked_count(mask: *const u8, len: usize) -> usize { // Horizontal sum let sum16 = vpaddlq_u8(total_acc); let sum32 = vpaddlq_u16(sum16); - let sum64 = vpaddlq_u32(sum32); + let _sum64 = vpaddlq_u32(sum32); // Will handle at final reduction } } diff --git a/src/runtime/cpu/kernels/simd/logsumexp/mod.rs b/src/runtime/cpu/kernels/simd/logsumexp/mod.rs index 8f150e06..f4d60e93 100644 --- a/src/runtime/cpu/kernels/simd/logsumexp/mod.rs +++ b/src/runtime/cpu/kernels/simd/logsumexp/mod.rs @@ -162,6 +162,58 @@ pub unsafe fn logsumexp_scalar_f64( } } +#[cfg(feature = "f16")] +/// f16 wrapper for logsumexp: converts input to f32, runs f32 logsumexp, converts output back. +/// +/// # Safety +/// - `a` must point to `reduce_size * outer_size` elements +/// - `out` must point to `outer_size` elements +pub unsafe fn logsumexp_f16( + a: *const half::f16, + out: *mut half::f16, + reduce_size: usize, + outer_size: usize, +) { + use super::half_convert_utils::*; + let input_len = outer_size * reduce_size; + let mut a_f32 = vec![0.0f32; input_len]; + let mut out_f32 = vec![0.0f32; outer_size]; + convert_f16_to_f32(a as *const u16, a_f32.as_mut_ptr(), input_len); + logsumexp_f32( + a_f32.as_ptr(), + out_f32.as_mut_ptr(), + reduce_size, + outer_size, + ); + convert_f32_to_f16(out_f32.as_ptr(), out as *mut u16, outer_size); +} + +#[cfg(feature = "f16")] +/// bf16 wrapper for logsumexp: converts input to f32, runs f32 logsumexp, converts output back. +/// +/// # Safety +/// - `a` must point to `reduce_size * outer_size` elements +/// - `out` must point to `outer_size` elements +pub unsafe fn logsumexp_bf16( + a: *const half::bf16, + out: *mut half::bf16, + reduce_size: usize, + outer_size: usize, +) { + use super::half_convert_utils::*; + let input_len = outer_size * reduce_size; + let mut a_f32 = vec![0.0f32; input_len]; + let mut out_f32 = vec![0.0f32; outer_size]; + convert_bf16_to_f32(a as *const u16, a_f32.as_mut_ptr(), input_len); + logsumexp_f32( + a_f32.as_ptr(), + out_f32.as_mut_ptr(), + reduce_size, + outer_size, + ); + convert_f32_to_bf16(out_f32.as_ptr(), out as *mut u16, outer_size); +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/runtime/cpu/kernels/simd/math/avx2/exp_log.rs b/src/runtime/cpu/kernels/simd/math/avx2/exp_log.rs new file mode 100644 index 00000000..9750b6fa --- /dev/null +++ b/src/runtime/cpu/kernels/simd/math/avx2/exp_log.rs @@ -0,0 +1,475 @@ +//! AVX2 exponential and logarithm implementations (exp, log, and derived functions) +//! +//! # Safety +//! +//! All functions require AVX2 and FMA CPU features. + +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +use super::super::common::{exp_coefficients, log_coefficients}; + +// ============================================================================ +// Exponential function: exp(x) +// ============================================================================ + +/// Fast SIMD exp approximation for f32 using AVX2+FMA +/// +/// See `common::_EXP_ALGORITHM_DOC` for algorithm details. +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn exp_f32(x: __m256) -> __m256 { + use exp_coefficients::*; + + let log2e = _mm256_set1_ps(std::f32::consts::LOG2_E); + let ln2 = _mm256_set1_ps(std::f32::consts::LN_2); + + let c0 = _mm256_set1_ps(C0_F32); + let c1 = _mm256_set1_ps(C1_F32); + let c2 = _mm256_set1_ps(C2_F32); + let c3 = _mm256_set1_ps(C3_F32); + let c4 = _mm256_set1_ps(C4_F32); + let c5 = _mm256_set1_ps(C5_F32); + let c6 = _mm256_set1_ps(C6_F32); + + // Clamp input to avoid overflow/underflow + let x = _mm256_max_ps(x, _mm256_set1_ps(MIN_F32)); + let x = _mm256_min_ps(x, _mm256_set1_ps(MAX_F32)); + + // y = x * log2(e) + let y = _mm256_mul_ps(x, log2e); + + // n = round(y) - integer part + let n = _mm256_round_ps::<{ _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC }>(y); + + // f = y - n - fractional part in [-0.5, 0.5] + let f = _mm256_sub_ps(y, n); + + // r = f * ln(2) - convert back to natural log scale + let r = _mm256_mul_ps(f, ln2); + + // Polynomial approximation using Horner's method + let r2 = _mm256_mul_ps(r, r); + let r3 = _mm256_mul_ps(r2, r); + let r4 = _mm256_mul_ps(r2, r2); + let r5 = _mm256_mul_ps(r4, r); + let r6 = _mm256_mul_ps(r4, r2); + + let mut poly = c0; + poly = _mm256_fmadd_ps(c1, r, poly); + poly = _mm256_fmadd_ps(c2, r2, poly); + poly = _mm256_fmadd_ps(c3, r3, poly); + poly = _mm256_fmadd_ps(c4, r4, poly); + poly = _mm256_fmadd_ps(c5, r5, poly); + poly = _mm256_fmadd_ps(c6, r6, poly); + + // Compute 2^n using IEEE 754 bit manipulation + // 2^n = reinterpret((n + 127) << 23) for f32 + let n_i32 = _mm256_cvtps_epi32(n); + let bias = _mm256_set1_epi32(127); + let exp_bits = _mm256_slli_epi32::<23>(_mm256_add_epi32(n_i32, bias)); + let pow2n = _mm256_castsi256_ps(exp_bits); + + // Result = 2^n * exp(r) + _mm256_mul_ps(pow2n, poly) +} + +/// Fast SIMD exp approximation for f64 using AVX2+FMA +/// +/// See `common::_EXP_ALGORITHM_DOC` for algorithm details. +/// +/// # Note +/// AVX2 lacks native 64-bit integer <-> double conversion. This implementation +/// uses scalar extraction for the 2^n computation, which is the standard +/// workaround. The polynomial computation remains fully vectorized. +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn exp_f64(x: __m256d) -> __m256d { + use exp_coefficients::*; + + let log2e = _mm256_set1_pd(std::f64::consts::LOG2_E); + let ln2 = _mm256_set1_pd(std::f64::consts::LN_2); + + let c0 = _mm256_set1_pd(C0_F64); + let c1 = _mm256_set1_pd(C1_F64); + let c2 = _mm256_set1_pd(C2_F64); + let c3 = _mm256_set1_pd(C3_F64); + let c4 = _mm256_set1_pd(C4_F64); + let c5 = _mm256_set1_pd(C5_F64); + let c6 = _mm256_set1_pd(C6_F64); + + // Clamp input + let x = _mm256_max_pd(x, _mm256_set1_pd(MIN_F64)); + let x = _mm256_min_pd(x, _mm256_set1_pd(MAX_F64)); + + let y = _mm256_mul_pd(x, log2e); + let n = _mm256_round_pd::<{ _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC }>(y); + let f = _mm256_sub_pd(y, n); + let r = _mm256_mul_pd(f, ln2); + + let r2 = _mm256_mul_pd(r, r); + let r3 = _mm256_mul_pd(r2, r); + let r4 = _mm256_mul_pd(r2, r2); + let r5 = _mm256_mul_pd(r4, r); + let r6 = _mm256_mul_pd(r4, r2); + + let mut poly = c0; + poly = _mm256_fmadd_pd(c1, r, poly); + poly = _mm256_fmadd_pd(c2, r2, poly); + poly = _mm256_fmadd_pd(c3, r3, poly); + poly = _mm256_fmadd_pd(c4, r4, poly); + poly = _mm256_fmadd_pd(c5, r5, poly); + poly = _mm256_fmadd_pd(c6, r6, poly); + + // AVX2 lacks _mm256_cvtpd_epi64, use scalar conversion for 2^n + // This is a known AVX2 limitation - polynomial eval is still SIMD + let mut result = [0.0f64; 4]; + let mut n_arr = [0.0f64; 4]; + let mut poly_arr = [0.0f64; 4]; + + _mm256_storeu_pd(n_arr.as_mut_ptr(), n); + _mm256_storeu_pd(poly_arr.as_mut_ptr(), poly); + + for i in 0..4 { + let n_i = n_arr[i] as i64; + let exp_bits = ((n_i + 1023) as u64) << 52; + let pow2n = f64::from_bits(exp_bits); + result[i] = pow2n * poly_arr[i]; + } + + _mm256_loadu_pd(result.as_ptr()) +} + +// ============================================================================ +// Natural logarithm: log(x) +// ============================================================================ + +/// Fast SIMD log approximation for f32 using AVX2+FMA +/// +/// See `common::_LOG_ALGORITHM_DOC` for algorithm details. +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn log_f32(x: __m256) -> __m256 { + use log_coefficients::*; + + let one = _mm256_set1_ps(1.0); + let ln2 = _mm256_set1_ps(std::f32::consts::LN_2); + let sqrt2 = _mm256_set1_ps(std::f32::consts::SQRT_2); + let half = _mm256_set1_ps(0.5); + + let c1 = _mm256_set1_ps(C1_F32); + let c2 = _mm256_set1_ps(C2_F32); + let c3 = _mm256_set1_ps(C3_F32); + let c4 = _mm256_set1_ps(C4_F32); + let c5 = _mm256_set1_ps(C5_F32); + let c6 = _mm256_set1_ps(C6_F32); + let c7 = _mm256_set1_ps(C7_F32); + + // Extract exponent: reinterpret as int, shift right by 23, subtract bias + let x_bits = _mm256_castps_si256(x); + let exp_raw = _mm256_srli_epi32::<23>(x_bits); + let exp_unbiased = _mm256_sub_epi32(exp_raw, _mm256_set1_epi32(EXP_BIAS_F32)); + let mut n = _mm256_cvtepi32_ps(exp_unbiased); + + // Extract mantissa and set exponent to 0 (so mantissa is in [1, 2)) + let mantissa_mask = _mm256_set1_epi32(MANTISSA_MASK_F32); + let exp_zero = _mm256_set1_epi32(EXP_ZERO_F32); + let m_bits = _mm256_or_si256(_mm256_and_si256(x_bits, mantissa_mask), exp_zero); + let mut m = _mm256_castsi256_ps(m_bits); + + // Normalize: if m > sqrt(2), divide by 2 and increment exponent + // This keeps f in [-0.2929, 0.4142] for better polynomial accuracy + let need_adjust = _mm256_cmp_ps::<_CMP_GT_OQ>(m, sqrt2); + m = _mm256_blendv_ps(m, _mm256_mul_ps(m, half), need_adjust); + n = _mm256_blendv_ps(n, _mm256_add_ps(n, one), need_adjust); + + // f = m - 1, so log(m) = log(1 + f), f is now in [-0.2929, 0.4142] + let f = _mm256_sub_ps(m, one); + + // Horner's method: ((((((c7*f + c6)*f + c5)*f + c4)*f + c3)*f + c2)*f + c1)*f + let mut poly = c7; + poly = _mm256_fmadd_ps(poly, f, c6); + poly = _mm256_fmadd_ps(poly, f, c5); + poly = _mm256_fmadd_ps(poly, f, c4); + poly = _mm256_fmadd_ps(poly, f, c3); + poly = _mm256_fmadd_ps(poly, f, c2); + poly = _mm256_fmadd_ps(poly, f, c1); + poly = _mm256_mul_ps(poly, f); + + // Result = n * ln(2) + log(m) + _mm256_fmadd_ps(n, ln2, poly) +} + +/// Fast SIMD log approximation for f64 using AVX2+FMA +/// +/// See `common::_LOG_ALGORITHM_DOC` for algorithm details. +/// +/// # Implementation Note +/// Unlike the naive scalar-loop approach, this implementation uses native AVX2 +/// 64-bit SIMD operations for exponent extraction. The only scalar operations +/// are for the normalization conditional and final reconstruction, which cannot +/// be avoided due to AVX2's lack of 64-bit comparison and conversion intrinsics. +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn log_f64(x: __m256d) -> __m256d { + use log_coefficients::*; + + let one = _mm256_set1_pd(1.0); + let ln2 = _mm256_set1_pd(std::f64::consts::LN_2); + let sqrt2_val = std::f64::consts::SQRT_2; + + let c1 = _mm256_set1_pd(C1_F64); + let c2 = _mm256_set1_pd(C2_F64); + let c3 = _mm256_set1_pd(C3_F64); + let c4 = _mm256_set1_pd(C4_F64); + let c5 = _mm256_set1_pd(C5_F64); + let c6 = _mm256_set1_pd(C6_F64); + let c7 = _mm256_set1_pd(C7_F64); + let c8 = _mm256_set1_pd(C8_F64); + let c9 = _mm256_set1_pd(C9_F64); + + // Use SIMD for bit manipulation - AVX2 has 64-bit shifts + let x_bits = _mm256_castpd_si256(x); + + // Extract exponent using 64-bit SIMD shift + let exp_raw = _mm256_srli_epi64::<52>(x_bits); + + // Extract mantissa and set exponent to bias (so mantissa is in [1, 2)) + let mantissa_mask = _mm256_set1_epi64x(MANTISSA_MASK_F64 as i64); + let exp_zero = _mm256_set1_epi64x(EXP_ZERO_F64 as i64); + let m_bits = _mm256_or_si256(_mm256_and_si256(x_bits, mantissa_mask), exp_zero); + let m_initial = _mm256_castsi256_pd(m_bits); + + // AVX2 lacks 64-bit int comparison and conversion, so we extract for + // normalization and exponent calculation. The heavy lifting (polynomial + // evaluation) remains fully vectorized. + let mut m_arr = [0.0f64; 4]; + let mut exp_arr = [0i64; 4]; + _mm256_storeu_pd(m_arr.as_mut_ptr(), m_initial); + _mm256_storeu_si256(exp_arr.as_mut_ptr() as *mut __m256i, exp_raw); + + let mut n_arr = [0.0f64; 4]; + for i in 0..4 { + let mut exp_unbiased = exp_arr[i] - EXP_BIAS_F64; + let mut m = m_arr[i]; + + // Normalize: if m > sqrt(2), divide by 2 and increment exponent + if m > sqrt2_val { + m *= 0.5; + exp_unbiased += 1; + } + + n_arr[i] = exp_unbiased as f64; + m_arr[i] = m; + } + + let n = _mm256_loadu_pd(n_arr.as_ptr()); + let m = _mm256_loadu_pd(m_arr.as_ptr()); + + // f = m - 1 (fully SIMD from here) + let f = _mm256_sub_pd(m, one); + + // Horner's method for polynomial (fully vectorized) + let mut poly = c9; + poly = _mm256_fmadd_pd(poly, f, c8); + poly = _mm256_fmadd_pd(poly, f, c7); + poly = _mm256_fmadd_pd(poly, f, c6); + poly = _mm256_fmadd_pd(poly, f, c5); + poly = _mm256_fmadd_pd(poly, f, c4); + poly = _mm256_fmadd_pd(poly, f, c3); + poly = _mm256_fmadd_pd(poly, f, c2); + poly = _mm256_fmadd_pd(poly, f, c1); + poly = _mm256_mul_pd(poly, f); + + // Result = n * ln(2) + log(m) (fully SIMD) + _mm256_fmadd_pd(n, ln2, poly) +} + +// ============================================================================ +// Derived exponential/logarithm functions +// ============================================================================ + +/// Fast SIMD exp2 (2^x) for f32 using AVX2 +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn exp2_f32(x: __m256) -> __m256 { + // 2^x = e^(x * ln(2)) + let ln2 = _mm256_set1_ps(std::f32::consts::LN_2); + exp_f32(_mm256_mul_ps(x, ln2)) +} + +/// Fast SIMD exp2 (2^x) for f64 using AVX2 +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn exp2_f64(x: __m256d) -> __m256d { + let ln2 = _mm256_set1_pd(std::f64::consts::LN_2); + exp_f64(_mm256_mul_pd(x, ln2)) +} + +/// Fast SIMD expm1 (e^x - 1) for f32 using AVX2 +/// Uses direct computation for |x| > 0.5, Taylor series for small x +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn expm1_f32(x: __m256) -> __m256 { + let one = _mm256_set1_ps(1.0); + let half = _mm256_set1_ps(0.5); + let abs_x = _mm256_andnot_ps(_mm256_set1_ps(-0.0), x); + + // For small |x|, use Taylor series: x + x^2/2 + x^3/6 + x^4/24 + let x2 = _mm256_mul_ps(x, x); + let x3 = _mm256_mul_ps(x2, x); + let x4 = _mm256_mul_ps(x2, x2); + let c2 = _mm256_set1_ps(0.5); + let c3 = _mm256_set1_ps(1.0 / 6.0); + let c4 = _mm256_set1_ps(1.0 / 24.0); + let taylor = _mm256_fmadd_ps(c4, x4, _mm256_fmadd_ps(c3, x3, _mm256_fmadd_ps(c2, x2, x))); + + // For large |x|, use exp(x) - 1 + let exp_result = _mm256_sub_ps(exp_f32(x), one); + + // Blend based on |x| > 0.5 + let mask = _mm256_cmp_ps::<_CMP_GT_OQ>(abs_x, half); + _mm256_blendv_ps(taylor, exp_result, mask) +} + +/// Fast SIMD expm1 (e^x - 1) for f64 using AVX2 +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn expm1_f64(x: __m256d) -> __m256d { + let one = _mm256_set1_pd(1.0); + let half = _mm256_set1_pd(0.5); + let abs_x = _mm256_andnot_pd(_mm256_set1_pd(-0.0), x); + + let x2 = _mm256_mul_pd(x, x); + let x3 = _mm256_mul_pd(x2, x); + let x4 = _mm256_mul_pd(x2, x2); + let c2 = _mm256_set1_pd(0.5); + let c3 = _mm256_set1_pd(1.0 / 6.0); + let c4 = _mm256_set1_pd(1.0 / 24.0); + let taylor = _mm256_fmadd_pd(c4, x4, _mm256_fmadd_pd(c3, x3, _mm256_fmadd_pd(c2, x2, x))); + + let exp_result = _mm256_sub_pd(exp_f64(x), one); + let mask = _mm256_cmp_pd::<_CMP_GT_OQ>(abs_x, half); + _mm256_blendv_pd(taylor, exp_result, mask) +} + +/// Fast SIMD log2 for f32 using AVX2 +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn log2_f32(x: __m256) -> __m256 { + // log2(x) = log(x) * log2(e) + let log2e = _mm256_set1_ps(std::f32::consts::LOG2_E); + _mm256_mul_ps(log_f32(x), log2e) +} + +/// Fast SIMD log2 for f64 using AVX2 +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn log2_f64(x: __m256d) -> __m256d { + let log2e = _mm256_set1_pd(std::f64::consts::LOG2_E); + _mm256_mul_pd(log_f64(x), log2e) +} + +/// Fast SIMD log10 for f32 using AVX2 +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn log10_f32(x: __m256) -> __m256 { + // log10(x) = log(x) * log10(e) + let log10e = _mm256_set1_ps(std::f32::consts::LOG10_E); + _mm256_mul_ps(log_f32(x), log10e) +} + +/// Fast SIMD log10 for f64 using AVX2 +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn log10_f64(x: __m256d) -> __m256d { + let log10e = _mm256_set1_pd(std::f64::consts::LOG10_E); + _mm256_mul_pd(log_f64(x), log10e) +} + +/// Fast SIMD log1p (log(1+x)) for f32 using AVX2 +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn log1p_f32(x: __m256) -> __m256 { + let one = _mm256_set1_ps(1.0); + let half = _mm256_set1_ps(0.5); + let abs_x = _mm256_andnot_ps(_mm256_set1_ps(-0.0), x); + + // For small |x|, use Taylor series: x - x^2/2 + x^3/3 - x^4/4 + let x2 = _mm256_mul_ps(x, x); + let x3 = _mm256_mul_ps(x2, x); + let x4 = _mm256_mul_ps(x2, x2); + let c2 = _mm256_set1_ps(-0.5); + let c3 = _mm256_set1_ps(1.0 / 3.0); + let c4 = _mm256_set1_ps(-0.25); + let taylor = _mm256_fmadd_ps(c4, x4, _mm256_fmadd_ps(c3, x3, _mm256_fmadd_ps(c2, x2, x))); + + // For large |x|, use log(1 + x) + let log_result = log_f32(_mm256_add_ps(one, x)); + + let mask = _mm256_cmp_ps::<_CMP_GT_OQ>(abs_x, half); + _mm256_blendv_ps(taylor, log_result, mask) +} + +/// Fast SIMD log1p (log(1+x)) for f64 using AVX2 +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn log1p_f64(x: __m256d) -> __m256d { + let one = _mm256_set1_pd(1.0); + let half = _mm256_set1_pd(0.5); + let abs_x = _mm256_andnot_pd(_mm256_set1_pd(-0.0), x); + + let x2 = _mm256_mul_pd(x, x); + let x3 = _mm256_mul_pd(x2, x); + let x4 = _mm256_mul_pd(x2, x2); + let c2 = _mm256_set1_pd(-0.5); + let c3 = _mm256_set1_pd(1.0 / 3.0); + let c4 = _mm256_set1_pd(-0.25); + let taylor = _mm256_fmadd_pd(c4, x4, _mm256_fmadd_pd(c3, x3, _mm256_fmadd_pd(c2, x2, x))); + + let log_result = log_f64(_mm256_add_pd(one, x)); + let mask = _mm256_cmp_pd::<_CMP_GT_OQ>(abs_x, half); + _mm256_blendv_pd(taylor, log_result, mask) +} diff --git a/src/runtime/cpu/kernels/simd/math/avx2/hyperbolic.rs b/src/runtime/cpu/kernels/simd/math/avx2/hyperbolic.rs new file mode 100644 index 00000000..94723db3 --- /dev/null +++ b/src/runtime/cpu/kernels/simd/math/avx2/hyperbolic.rs @@ -0,0 +1,197 @@ +//! AVX2 hyperbolic function implementations (tanh, sinh, cosh, asinh, acosh, atanh) +//! +//! # Safety +//! +//! All functions require AVX2 and FMA CPU features. + +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +use super::exp_log::{exp_f32, exp_f64, log_f32, log_f64}; + +// ============================================================================ +// Hyperbolic tangent: tanh(x) +// ============================================================================ + +/// Fast SIMD tanh approximation for f32 using AVX2+FMA +/// +/// Algorithm: tanh(x) = (exp(2x) - 1) / (exp(2x) + 1) +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn tanh_f32(x: __m256) -> __m256 { + let two = _mm256_set1_ps(2.0); + let one = _mm256_set1_ps(1.0); + + let exp2x = exp_f32(_mm256_mul_ps(two, x)); + let num = _mm256_sub_ps(exp2x, one); + let den = _mm256_add_ps(exp2x, one); + + _mm256_div_ps(num, den) +} + +/// Fast SIMD tanh approximation for f64 using AVX2+FMA +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn tanh_f64(x: __m256d) -> __m256d { + let two = _mm256_set1_pd(2.0); + let one = _mm256_set1_pd(1.0); + + let exp2x = exp_f64(_mm256_mul_pd(two, x)); + let num = _mm256_sub_pd(exp2x, one); + let den = _mm256_add_pd(exp2x, one); + + _mm256_div_pd(num, den) +} + +// ============================================================================ +// Hyperbolic sine and cosine: sinh(x), cosh(x) +// ============================================================================ + +/// Fast SIMD sinh for f32 using AVX2 +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn sinh_f32(x: __m256) -> __m256 { + // sinh(x) = (exp(x) - exp(-x)) / 2 + let half = _mm256_set1_ps(0.5); + let exp_x = exp_f32(x); + let exp_neg_x = exp_f32(_mm256_sub_ps(_mm256_setzero_ps(), x)); + _mm256_mul_ps(half, _mm256_sub_ps(exp_x, exp_neg_x)) +} + +/// Fast SIMD sinh for f64 using AVX2 +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn sinh_f64(x: __m256d) -> __m256d { + let half = _mm256_set1_pd(0.5); + let exp_x = exp_f64(x); + let exp_neg_x = exp_f64(_mm256_sub_pd(_mm256_setzero_pd(), x)); + _mm256_mul_pd(half, _mm256_sub_pd(exp_x, exp_neg_x)) +} + +/// Fast SIMD cosh for f32 using AVX2 +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn cosh_f32(x: __m256) -> __m256 { + // cosh(x) = (exp(x) + exp(-x)) / 2 + let half = _mm256_set1_ps(0.5); + let exp_x = exp_f32(x); + let exp_neg_x = exp_f32(_mm256_sub_ps(_mm256_setzero_ps(), x)); + _mm256_mul_ps(half, _mm256_add_ps(exp_x, exp_neg_x)) +} + +/// Fast SIMD cosh for f64 using AVX2 +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn cosh_f64(x: __m256d) -> __m256d { + let half = _mm256_set1_pd(0.5); + let exp_x = exp_f64(x); + let exp_neg_x = exp_f64(_mm256_sub_pd(_mm256_setzero_pd(), x)); + _mm256_mul_pd(half, _mm256_add_pd(exp_x, exp_neg_x)) +} + +// ============================================================================ +// Inverse hyperbolic functions: asinh, acosh, atanh +// ============================================================================ + +/// Fast SIMD asinh for f32 using AVX2 +/// asinh(x) = log(x + sqrt(x^2 + 1)) +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn asinh_f32(x: __m256) -> __m256 { + let one = _mm256_set1_ps(1.0); + let x2 = _mm256_mul_ps(x, x); + let sqrt_term = _mm256_sqrt_ps(_mm256_add_ps(x2, one)); + log_f32(_mm256_add_ps(x, sqrt_term)) +} + +/// Fast SIMD asinh for f64 using AVX2 +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn asinh_f64(x: __m256d) -> __m256d { + let one = _mm256_set1_pd(1.0); + let x2 = _mm256_mul_pd(x, x); + let sqrt_term = _mm256_sqrt_pd(_mm256_add_pd(x2, one)); + log_f64(_mm256_add_pd(x, sqrt_term)) +} + +/// Fast SIMD acosh for f32 using AVX2 +/// acosh(x) = log(x + sqrt(x^2 - 1)) for x >= 1 +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn acosh_f32(x: __m256) -> __m256 { + let one = _mm256_set1_ps(1.0); + let x2 = _mm256_mul_ps(x, x); + let sqrt_term = _mm256_sqrt_ps(_mm256_sub_ps(x2, one)); + log_f32(_mm256_add_ps(x, sqrt_term)) +} + +/// Fast SIMD acosh for f64 using AVX2 +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn acosh_f64(x: __m256d) -> __m256d { + let one = _mm256_set1_pd(1.0); + let x2 = _mm256_mul_pd(x, x); + let sqrt_term = _mm256_sqrt_pd(_mm256_sub_pd(x2, one)); + log_f64(_mm256_add_pd(x, sqrt_term)) +} + +/// Fast SIMD atanh for f32 using AVX2 +/// atanh(x) = 0.5 * log((1 + x) / (1 - x)) for |x| < 1 +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn atanh_f32(x: __m256) -> __m256 { + let half = _mm256_set1_ps(0.5); + let one = _mm256_set1_ps(1.0); + let one_plus_x = _mm256_add_ps(one, x); + let one_minus_x = _mm256_sub_ps(one, x); + let ratio = _mm256_div_ps(one_plus_x, one_minus_x); + _mm256_mul_ps(half, log_f32(ratio)) +} + +/// Fast SIMD atanh for f64 using AVX2 +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn atanh_f64(x: __m256d) -> __m256d { + let half = _mm256_set1_pd(0.5); + let one = _mm256_set1_pd(1.0); + let one_plus_x = _mm256_add_pd(one, x); + let one_minus_x = _mm256_sub_pd(one, x); + let ratio = _mm256_div_pd(one_plus_x, one_minus_x); + _mm256_mul_pd(half, log_f64(ratio)) +} diff --git a/src/runtime/cpu/kernels/simd/math/avx2/reduce.rs b/src/runtime/cpu/kernels/simd/math/avx2/reduce.rs new file mode 100644 index 00000000..04279a52 --- /dev/null +++ b/src/runtime/cpu/kernels/simd/math/avx2/reduce.rs @@ -0,0 +1,76 @@ +//! AVX2 horizontal reduction operations (hmax, hsum) +//! +//! # Safety +//! +//! All functions require AVX2 and FMA CPU features. + +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +// ============================================================================ +// Horizontal reductions +// ============================================================================ + +/// Horizontal maximum of 8 f32 values in an AVX2 register +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn hmax_f32(v: __m256) -> f32 { + let high = _mm256_extractf128_ps(v, 1); + let low = _mm256_castps256_ps128(v); + let max128 = _mm_max_ps(low, high); + let shuf = _mm_movehdup_ps(max128); + let max64 = _mm_max_ps(max128, shuf); + let shuf2 = _mm_movehl_ps(max64, max64); + let max32 = _mm_max_ss(max64, shuf2); + _mm_cvtss_f32(max32) +} + +/// Horizontal maximum of 4 f64 values in an AVX2 register +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn hmax_f64(v: __m256d) -> f64 { + let high = _mm256_extractf128_pd(v, 1); + let low = _mm256_castpd256_pd128(v); + let max128 = _mm_max_pd(low, high); + let shuf = _mm_unpackhi_pd(max128, max128); + let max64 = _mm_max_sd(max128, shuf); + _mm_cvtsd_f64(max64) +} + +/// Horizontal sum of 8 f32 values in an AVX2 register +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn hsum_f32(v: __m256) -> f32 { + let high = _mm256_extractf128_ps(v, 1); + let low = _mm256_castps256_ps128(v); + let sum128 = _mm_add_ps(low, high); + let shuf = _mm_movehdup_ps(sum128); + let sum64 = _mm_add_ps(sum128, shuf); + let shuf2 = _mm_movehl_ps(sum64, sum64); + let sum32 = _mm_add_ss(sum64, shuf2); + _mm_cvtss_f32(sum32) +} + +/// Horizontal sum of 4 f64 values in an AVX2 register +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn hsum_f64(v: __m256d) -> f64 { + let high = _mm256_extractf128_pd(v, 1); + let low = _mm256_castpd256_pd128(v); + let sum128 = _mm_add_pd(low, high); + let shuf = _mm_unpackhi_pd(sum128, sum128); + let sum64 = _mm_add_sd(sum128, shuf); + _mm_cvtsd_f64(sum64) +} diff --git a/src/runtime/cpu/kernels/simd/math/avx2/special.rs b/src/runtime/cpu/kernels/simd/math/avx2/special.rs new file mode 100644 index 00000000..46ff2dfd --- /dev/null +++ b/src/runtime/cpu/kernels/simd/math/avx2/special.rs @@ -0,0 +1,117 @@ +//! AVX2 special function implementations (rsqrt, cbrt) +//! +//! # Safety +//! +//! All functions require AVX2 and FMA CPU features. + +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +use super::exp_log::{exp_f64, log_f64}; + +// ============================================================================ +// Additional transcendental functions +// ============================================================================ + +/// Fast SIMD rsqrt (1/sqrt(x)) for f32 using AVX2 +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn rsqrt_f32(x: __m256) -> __m256 { + // Use Newton-Raphson refinement on the fast approximation + let approx = _mm256_rsqrt_ps(x); + let half = _mm256_set1_ps(0.5); + let three = _mm256_set1_ps(3.0); + // One Newton-Raphson iteration: y = 0.5 * y * (3 - x * y * y) + let x_approx2 = _mm256_mul_ps(x, _mm256_mul_ps(approx, approx)); + let factor = _mm256_sub_ps(three, x_approx2); + _mm256_mul_ps(half, _mm256_mul_ps(approx, factor)) +} + +/// Fast SIMD rsqrt (1/sqrt(x)) for f64 using AVX2 +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn rsqrt_f64(x: __m256d) -> __m256d { + let sqrt_x = _mm256_sqrt_pd(x); + _mm256_div_pd(_mm256_set1_pd(1.0), sqrt_x) +} + +/// Fast SIMD cbrt (cube root) for f32 using AVX2 +/// Uses Halley's method for refinement +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn cbrt_f32(x: __m256) -> __m256 { + // Handle sign separately + let sign_mask = _mm256_set1_ps(-0.0); + let sign = _mm256_and_ps(x, sign_mask); + let abs_x = _mm256_andnot_ps(sign_mask, x); + + // Initial approximation using bit manipulation + // cbrt(x) ≈ 2^(log2(x)/3) via IEEE 754 + let one_third = _mm256_set1_ps(1.0 / 3.0); + let bias = _mm256_set1_ps(127.0); + + // Extract exponent: e = floor(log2(|x|)) + let xi = _mm256_castps_si256(abs_x); + let exp_bits = _mm256_srli_epi32::<23>(xi); + let exp_f = _mm256_cvtepi32_ps(_mm256_sub_epi32(exp_bits, _mm256_set1_epi32(127))); + + // Initial guess: 2^(e/3) + let new_exp = _mm256_mul_ps(exp_f, one_third); + let new_exp_i = _mm256_cvtps_epi32(_mm256_add_ps(new_exp, bias)); + let guess = _mm256_castsi256_ps(_mm256_slli_epi32::<23>(new_exp_i)); + + // Newton-Raphson iteration: y = y * (2*y^3 + x) / (2*x + y^3) + // Simplified: y = (2*y + x/y^2) / 3 + let two = _mm256_set1_ps(2.0); + let three = _mm256_set1_ps(3.0); + + let y = guess; + let y2 = _mm256_mul_ps(y, y); + let y_new = _mm256_div_ps(_mm256_fmadd_ps(two, y, _mm256_div_ps(abs_x, y2)), three); + + // One more iteration + let y2 = _mm256_mul_ps(y_new, y_new); + let result = _mm256_div_ps(_mm256_fmadd_ps(two, y_new, _mm256_div_ps(abs_x, y2)), three); + + // Restore sign + _mm256_or_ps(result, sign) +} + +/// Fast SIMD cbrt (cube root) for f64 using AVX2 +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn cbrt_f64(x: __m256d) -> __m256d { + let sign_mask = _mm256_set1_pd(-0.0); + let sign = _mm256_and_pd(x, sign_mask); + let abs_x = _mm256_andnot_pd(sign_mask, x); + + let one_third = _mm256_set1_pd(1.0 / 3.0); + + // Initial guess: cbrt(x) ≈ exp(log(x) / 3) + let log_x = log_f64(abs_x); + let guess = exp_f64(_mm256_mul_pd(log_x, one_third)); + + let two = _mm256_set1_pd(2.0); + let three = _mm256_set1_pd(3.0); + + let y = guess; + let y2 = _mm256_mul_pd(y, y); + let y_new = _mm256_div_pd(_mm256_fmadd_pd(two, y, _mm256_div_pd(abs_x, y2)), three); + + let y2 = _mm256_mul_pd(y_new, y_new); + let result = _mm256_div_pd(_mm256_fmadd_pd(two, y_new, _mm256_div_pd(abs_x, y2)), three); + + _mm256_or_pd(result, sign) +} diff --git a/src/runtime/cpu/kernels/simd/math/avx2/trig.rs b/src/runtime/cpu/kernels/simd/math/avx2/trig.rs new file mode 100644 index 00000000..a821095e --- /dev/null +++ b/src/runtime/cpu/kernels/simd/math/avx2/trig.rs @@ -0,0 +1,485 @@ +//! AVX2 trigonometric function implementations (sin, cos, tan, atan, asin, acos) +//! +//! # Safety +//! +//! All functions require AVX2 and FMA CPU features. + +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +use super::super::common::{atan_coefficients, tan_coefficients, trig_coefficients}; + +// ============================================================================ +// Trigonometric functions: sin, cos, tan +// ============================================================================ + +/// Fast SIMD sin approximation for f32 using AVX2+FMA +/// +/// See `common::_TRIG_ALGORITHM_DOC` for algorithm details. +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn sin_f32(x: __m256) -> __m256 { + use trig_coefficients::*; + + let two_over_pi = _mm256_set1_ps(std::f32::consts::FRAC_2_PI); + let pi_over_2 = _mm256_set1_ps(std::f32::consts::FRAC_PI_2); + + let s1 = _mm256_set1_ps(S1_F32); + let s3 = _mm256_set1_ps(S3_F32); + let s5 = _mm256_set1_ps(S5_F32); + let s7 = _mm256_set1_ps(S7_F32); + + let c0 = _mm256_set1_ps(C0_F32); + let c2 = _mm256_set1_ps(C2_F32); + let c4 = _mm256_set1_ps(C4_F32); + let c6 = _mm256_set1_ps(C6_F32); + + // Range reduction: j = round(x * 2/π), y = x - j * π/2 + let j = _mm256_round_ps::<{ _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC }>(_mm256_mul_ps( + x, + two_over_pi, + )); + let j_int = _mm256_cvtps_epi32(j); + + let y = _mm256_fnmadd_ps(j, pi_over_2, x); + + let y2 = _mm256_mul_ps(y, y); + let y3 = _mm256_mul_ps(y2, y); + let y4 = _mm256_mul_ps(y2, y2); + let y5 = _mm256_mul_ps(y4, y); + let y6 = _mm256_mul_ps(y4, y2); + let y7 = _mm256_mul_ps(y4, y3); + + // sin(y) polynomial + let sin_y = _mm256_fmadd_ps( + s7, + y7, + _mm256_fmadd_ps(s5, y5, _mm256_fmadd_ps(s3, y3, _mm256_mul_ps(s1, y))), + ); + + // cos(y) polynomial + let cos_y = _mm256_fmadd_ps(c6, y6, _mm256_fmadd_ps(c4, y4, _mm256_fmadd_ps(c2, y2, c0))); + + // Select sin or cos based on j mod 4 + // j mod 4 = 0: sin(y), 1: cos(y), 2: -sin(y), 3: -cos(y) + let j_mod_4 = _mm256_and_si256(j_int, _mm256_set1_epi32(3)); + + // Use cos when j mod 4 is 1 or 3 + let use_cos_mask = _mm256_cmpeq_epi32( + _mm256_and_si256(j_mod_4, _mm256_set1_epi32(1)), + _mm256_set1_epi32(1), + ); + let use_cos_mask = _mm256_castsi256_ps(use_cos_mask); + + // Negate when j mod 4 is 2 or 3 + let negate_mask = _mm256_cmpeq_epi32( + _mm256_and_si256(j_mod_4, _mm256_set1_epi32(2)), + _mm256_set1_epi32(2), + ); + let negate_mask = _mm256_castsi256_ps(negate_mask); + let sign_bit = _mm256_set1_ps(-0.0); // Just the sign bit + + let result = _mm256_blendv_ps(sin_y, cos_y, use_cos_mask); + let negated = _mm256_xor_ps(result, sign_bit); + _mm256_blendv_ps(result, negated, negate_mask) +} + +/// Fast SIMD sin approximation for f64 using AVX2+FMA +/// +/// See `common::_TRIG_ALGORITHM_DOC` for algorithm details. +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn sin_f64(x: __m256d) -> __m256d { + use trig_coefficients::*; + + let two_over_pi = _mm256_set1_pd(std::f64::consts::FRAC_2_PI); + let pi_over_2 = _mm256_set1_pd(std::f64::consts::FRAC_PI_2); + + let s1 = _mm256_set1_pd(S1_F64); + let s3 = _mm256_set1_pd(S3_F64); + let s5 = _mm256_set1_pd(S5_F64); + let s7 = _mm256_set1_pd(S7_F64); + let s9 = _mm256_set1_pd(S9_F64); + + let c0 = _mm256_set1_pd(C0_F64); + let c2 = _mm256_set1_pd(C2_F64); + let c4 = _mm256_set1_pd(C4_F64); + let c6 = _mm256_set1_pd(C6_F64); + let c8 = _mm256_set1_pd(C8_F64); + + let j = _mm256_round_pd::<{ _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC }>(_mm256_mul_pd( + x, + two_over_pi, + )); + + // Get j as integers for quadrant selection (AVX2 lacks 64-bit int conversion) + let mut j_arr = [0.0f64; 4]; + _mm256_storeu_pd(j_arr.as_mut_ptr(), j); + let j_int: [i32; 4] = [ + j_arr[0] as i32, + j_arr[1] as i32, + j_arr[2] as i32, + j_arr[3] as i32, + ]; + + let y = _mm256_fnmadd_pd(j, pi_over_2, x); + + let y2 = _mm256_mul_pd(y, y); + let y3 = _mm256_mul_pd(y2, y); + let y4 = _mm256_mul_pd(y2, y2); + let y5 = _mm256_mul_pd(y4, y); + let y6 = _mm256_mul_pd(y4, y2); + let y7 = _mm256_mul_pd(y4, y3); + let y8 = _mm256_mul_pd(y4, y4); + let y9 = _mm256_mul_pd(y8, y); + + // sin(y) and cos(y) polynomials + let mut sin_y = _mm256_mul_pd(s1, y); + sin_y = _mm256_fmadd_pd(s3, y3, sin_y); + sin_y = _mm256_fmadd_pd(s5, y5, sin_y); + sin_y = _mm256_fmadd_pd(s7, y7, sin_y); + sin_y = _mm256_fmadd_pd(s9, y9, sin_y); + + let mut cos_y = c0; + cos_y = _mm256_fmadd_pd(c2, y2, cos_y); + cos_y = _mm256_fmadd_pd(c4, y4, cos_y); + cos_y = _mm256_fmadd_pd(c6, y6, cos_y); + cos_y = _mm256_fmadd_pd(c8, y8, cos_y); + + // Compute result per-element based on quadrant + let mut sin_arr = [0.0f64; 4]; + let mut cos_arr = [0.0f64; 4]; + _mm256_storeu_pd(sin_arr.as_mut_ptr(), sin_y); + _mm256_storeu_pd(cos_arr.as_mut_ptr(), cos_y); + + let mut result = [0.0f64; 4]; + for i in 0..4 { + let quadrant = j_int[i] & 3; + result[i] = match quadrant { + 0 => sin_arr[i], + 1 => cos_arr[i], + 2 => -sin_arr[i], + 3 => -cos_arr[i], + _ => unreachable!(), + }; + } + + _mm256_loadu_pd(result.as_ptr()) +} + +/// Fast SIMD cos approximation for f32 using AVX2+FMA +/// +/// Implemented as: cos(x) = sin(x + π/2) +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn cos_f32(x: __m256) -> __m256 { + let pi_over_2 = _mm256_set1_ps(std::f32::consts::FRAC_PI_2); + sin_f32(_mm256_add_ps(x, pi_over_2)) +} + +/// Fast SIMD cos approximation for f64 using AVX2+FMA +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn cos_f64(x: __m256d) -> __m256d { + let pi_over_2 = _mm256_set1_pd(std::f64::consts::FRAC_PI_2); + sin_f64(_mm256_add_pd(x, pi_over_2)) +} + +/// Fast SIMD tan approximation for f32 using AVX2+FMA +/// +/// See `common::_TAN_ALGORITHM_DOC` for algorithm details. +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn tan_f32(x: __m256) -> __m256 { + use tan_coefficients::*; + + let two_over_pi = _mm256_set1_ps(std::f32::consts::FRAC_2_PI); + let pi_over_2 = _mm256_set1_ps(std::f32::consts::FRAC_PI_2); + + // Range reduction + let j = _mm256_round_ps::<{ _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC }>(_mm256_mul_ps( + x, + two_over_pi, + )); + let y = _mm256_fnmadd_ps(j, pi_over_2, x); + + let t1 = _mm256_set1_ps(T1_F32); + let t3 = _mm256_set1_ps(T3_F32); + let t5 = _mm256_set1_ps(T5_F32); + let t7 = _mm256_set1_ps(T7_F32); + let t9 = _mm256_set1_ps(T9_F32); + let t11 = _mm256_set1_ps(T11_F32); + + let y2 = _mm256_mul_ps(y, y); + + // Horner's method: tan(y) ≈ y * (1 + y²*(t3 + y²*(t5 + y²*(t7 + y²*(t9 + y²*t11))))) + let mut poly = t11; + poly = _mm256_fmadd_ps(poly, y2, t9); + poly = _mm256_fmadd_ps(poly, y2, t7); + poly = _mm256_fmadd_ps(poly, y2, t5); + poly = _mm256_fmadd_ps(poly, y2, t3); + poly = _mm256_fmadd_ps(poly, y2, t1); + let tan_y = _mm256_mul_ps(y, poly); + + // For quadrants 1 and 3, tan(y + π/2) = -1/tan(y) = -cot(y) + let j_int = _mm256_cvtps_epi32(j); + let use_cot_mask = _mm256_cmpeq_epi32( + _mm256_and_si256(j_int, _mm256_set1_epi32(1)), + _mm256_set1_epi32(1), + ); + let use_cot_mask = _mm256_castsi256_ps(use_cot_mask); + + let neg_one = _mm256_set1_ps(-1.0); + let cot_y = _mm256_div_ps(neg_one, tan_y); + + _mm256_blendv_ps(tan_y, cot_y, use_cot_mask) +} + +/// Fast SIMD tan approximation for f64 using AVX2+FMA +/// +/// See `common::_TAN_ALGORITHM_DOC` for algorithm details. +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn tan_f64(x: __m256d) -> __m256d { + use tan_coefficients::*; + + let two_over_pi = _mm256_set1_pd(std::f64::consts::FRAC_2_PI); + let pi_over_2 = _mm256_set1_pd(std::f64::consts::FRAC_PI_2); + + let j = _mm256_round_pd::<{ _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC }>(_mm256_mul_pd( + x, + two_over_pi, + )); + let y = _mm256_fnmadd_pd(j, pi_over_2, x); + + let t1 = _mm256_set1_pd(T1_F64); + let t3 = _mm256_set1_pd(T3_F64); + let t5 = _mm256_set1_pd(T5_F64); + let t7 = _mm256_set1_pd(T7_F64); + let t9 = _mm256_set1_pd(T9_F64); + let t11 = _mm256_set1_pd(T11_F64); + let t13 = _mm256_set1_pd(T13_F64); + + let y2 = _mm256_mul_pd(y, y); + + // Horner's method + let mut poly = t13; + poly = _mm256_fmadd_pd(poly, y2, t11); + poly = _mm256_fmadd_pd(poly, y2, t9); + poly = _mm256_fmadd_pd(poly, y2, t7); + poly = _mm256_fmadd_pd(poly, y2, t5); + poly = _mm256_fmadd_pd(poly, y2, t3); + poly = _mm256_fmadd_pd(poly, y2, t1); + let tan_y = _mm256_mul_pd(y, poly); + + // Handle quadrant for cotangent (AVX2 lacks 64-bit int comparison) + let mut j_arr = [0.0f64; 4]; + let mut tan_arr = [0.0f64; 4]; + _mm256_storeu_pd(j_arr.as_mut_ptr(), j); + _mm256_storeu_pd(tan_arr.as_mut_ptr(), tan_y); + + let mut result = [0.0f64; 4]; + for i in 0..4 { + let j_int = j_arr[i] as i32; + result[i] = if (j_int & 1) == 1 { + -1.0 / tan_arr[i] + } else { + tan_arr[i] + }; + } + + _mm256_loadu_pd(result.as_ptr()) +} + +// ============================================================================ +// Inverse tangent function: atan(x) +// ============================================================================ + +/// Fast SIMD atan approximation for f32 using AVX2+FMA +/// +/// See `common::_ATAN_ALGORITHM_DOC` for algorithm details. +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn atan_f32(x: __m256) -> __m256 { + use atan_coefficients::*; + + let one = _mm256_set1_ps(1.0); + let pi_over_2 = _mm256_set1_ps(std::f32::consts::FRAC_PI_2); + + // Save sign and work with absolute value + let sign_mask = _mm256_set1_ps(-0.0); // 0x80000000 + let sign = _mm256_and_ps(x, sign_mask); + let abs_x = _mm256_andnot_ps(sign_mask, x); + + // Range reduction: for |x| > 1, compute atan(1/x) then adjust + let need_recip = _mm256_cmp_ps::<_CMP_GT_OQ>(abs_x, one); + let recip_x = _mm256_div_ps(one, abs_x); + let y = _mm256_blendv_ps(abs_x, recip_x, need_recip); + + // Polynomial approximation for atan(y) where y in [0, 1] + let a0 = _mm256_set1_ps(A0_F32); + let a2 = _mm256_set1_ps(A2_F32); + let a4 = _mm256_set1_ps(A4_F32); + let a6 = _mm256_set1_ps(A6_F32); + let a8 = _mm256_set1_ps(A8_F32); + let a10 = _mm256_set1_ps(A10_F32); + let a12 = _mm256_set1_ps(A12_F32); + + let y2 = _mm256_mul_ps(y, y); + + // Horner's method: a0 + y²*(a2 + y²*(a4 + y²*(a6 + y²*(a8 + y²*(a10 + y²*a12))))) + let mut poly = a12; + poly = _mm256_fmadd_ps(poly, y2, a10); + poly = _mm256_fmadd_ps(poly, y2, a8); + poly = _mm256_fmadd_ps(poly, y2, a6); + poly = _mm256_fmadd_ps(poly, y2, a4); + poly = _mm256_fmadd_ps(poly, y2, a2); + poly = _mm256_fmadd_ps(poly, y2, a0); + let atan_y = _mm256_mul_ps(y, poly); + + // Apply range reduction inverse: if |x| > 1, result = π/2 - atan(1/x) + let adjusted = _mm256_sub_ps(pi_over_2, atan_y); + let result = _mm256_blendv_ps(atan_y, adjusted, need_recip); + + // Restore sign + _mm256_or_ps(result, sign) +} + +/// Fast SIMD atan approximation for f64 using AVX2+FMA +/// +/// See `common::_ATAN_ALGORITHM_DOC` for algorithm details. +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn atan_f64(x: __m256d) -> __m256d { + use atan_coefficients::*; + + let one = _mm256_set1_pd(1.0); + let pi_over_2 = _mm256_set1_pd(std::f64::consts::FRAC_PI_2); + + // Save sign and work with absolute value + let sign_mask = _mm256_set1_pd(-0.0); // 0x8000000000000000 + let sign = _mm256_and_pd(x, sign_mask); + let abs_x = _mm256_andnot_pd(sign_mask, x); + + // Range reduction: for |x| > 1, compute atan(1/x) then adjust + let need_recip = _mm256_cmp_pd::<_CMP_GT_OQ>(abs_x, one); + let recip_x = _mm256_div_pd(one, abs_x); + let y = _mm256_blendv_pd(abs_x, recip_x, need_recip); + + // Polynomial approximation for atan(y) where y in [0, 1] + let a0 = _mm256_set1_pd(A0_F64); + let a2 = _mm256_set1_pd(A2_F64); + let a4 = _mm256_set1_pd(A4_F64); + let a6 = _mm256_set1_pd(A6_F64); + let a8 = _mm256_set1_pd(A8_F64); + let a10 = _mm256_set1_pd(A10_F64); + let a12 = _mm256_set1_pd(A12_F64); + let a14 = _mm256_set1_pd(A14_F64); + let a16 = _mm256_set1_pd(A16_F64); + let a18 = _mm256_set1_pd(A18_F64); + let a20 = _mm256_set1_pd(A20_F64); + + let y2 = _mm256_mul_pd(y, y); + + // Horner's method with 11 terms for higher precision + let mut poly = a20; + poly = _mm256_fmadd_pd(poly, y2, a18); + poly = _mm256_fmadd_pd(poly, y2, a16); + poly = _mm256_fmadd_pd(poly, y2, a14); + poly = _mm256_fmadd_pd(poly, y2, a12); + poly = _mm256_fmadd_pd(poly, y2, a10); + poly = _mm256_fmadd_pd(poly, y2, a8); + poly = _mm256_fmadd_pd(poly, y2, a6); + poly = _mm256_fmadd_pd(poly, y2, a4); + poly = _mm256_fmadd_pd(poly, y2, a2); + poly = _mm256_fmadd_pd(poly, y2, a0); + let atan_y = _mm256_mul_pd(y, poly); + + // Apply range reduction inverse: if |x| > 1, result = π/2 - atan(1/x) + let adjusted = _mm256_sub_pd(pi_over_2, atan_y); + let result = _mm256_blendv_pd(atan_y, adjusted, need_recip); + + // Restore sign + _mm256_or_pd(result, sign) +} + +// ============================================================================ +// Inverse trigonometric functions: asin, acos +// ============================================================================ + +/// Fast SIMD asin for f32 using AVX2 +/// Uses polynomial approximation with range reduction +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn asin_f32(x: __m256) -> __m256 { + // asin(x) = atan(x / sqrt(1 - x^2)) + let one = _mm256_set1_ps(1.0); + let x2 = _mm256_mul_ps(x, x); + let sqrt_term = _mm256_sqrt_ps(_mm256_sub_ps(one, x2)); + let ratio = _mm256_div_ps(x, sqrt_term); + atan_f32(ratio) +} + +/// Fast SIMD asin for f64 using AVX2 +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn asin_f64(x: __m256d) -> __m256d { + let one = _mm256_set1_pd(1.0); + let x2 = _mm256_mul_pd(x, x); + let sqrt_term = _mm256_sqrt_pd(_mm256_sub_pd(one, x2)); + let ratio = _mm256_div_pd(x, sqrt_term); + atan_f64(ratio) +} + +/// Fast SIMD acos for f32 using AVX2 +/// acos(x) = pi/2 - asin(x) +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn acos_f32(x: __m256) -> __m256 { + let pi_half = _mm256_set1_ps(std::f32::consts::FRAC_PI_2); + _mm256_sub_ps(pi_half, asin_f32(x)) +} + +/// Fast SIMD acos for f64 using AVX2 +/// +/// # Safety +/// Requires AVX2 and FMA CPU features. +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub unsafe fn acos_f64(x: __m256d) -> __m256d { + let pi_half = _mm256_set1_pd(std::f64::consts::FRAC_PI_2); + _mm256_sub_pd(pi_half, asin_f64(x)) +} diff --git a/src/runtime/cpu/kernels/simd/matmul/dispatch.rs b/src/runtime/cpu/kernels/simd/matmul/dispatch.rs new file mode 100644 index 00000000..6e2be9fa --- /dev/null +++ b/src/runtime/cpu/kernels/simd/matmul/dispatch.rs @@ -0,0 +1,604 @@ +//! SIMD-optimized matrix multiplication with cache-aware tiling +//! +//! This module provides the tiled matmul algorithm that dispatches to +//! SIMD microkernels based on runtime CPU feature detection. +//! +//! # Algorithm Overview (BLIS-style) +//! +//! ```text +//! for jc in (0..N).step_by(NC): # L3 cache blocking +//! for pc in (0..K).step_by(KC): # L2 cache blocking +//! pack B[pc:pc+KC, jc:jc+NC] → B̃ # Pack B panel +//! for ic in (0..M).step_by(MC): # L2 cache blocking +//! pack A[ic:ic+MC, pc:pc+KC] → Ã # Pack A panel +//! for jr in (0..NC).step_by(NR): # Microkernel loop +//! for ir in (0..MC).step_by(MR): +//! microkernel(Ã[ir], B̃[jr], C[ic+ir, jc+jr]) +//! ``` +//! +//! # Microkernel Dimensions +//! +//! | SIMD Level | f32 (MR×NR) | f64 (MR×NR) | +//! |------------|-------------|-------------| +//! | AVX-512 | 6×16 | 6×8 | +//! | AVX2+FMA | 6×8 | 6×4 | +//! | Scalar | 6×4 | 6×4 | + +#[cfg(target_arch = "aarch64")] +use super::aarch64; +#[cfg(target_arch = "x86_64")] +use super::avx2; +#[cfg(target_arch = "x86_64")] +use super::avx512; +use super::scalar::{matmul_bias_scalar_f32, matmul_bias_scalar_f64}; +use super::scalar::{matmul_scalar_f32, matmul_scalar_f64}; +use super::scalar::{microkernel_edge_f32, microkernel_edge_f64}; +use super::small; +use super::tiling::{matmul_bias_tiled_f32, matmul_bias_tiled_f64}; +use super::tiling::{matmul_tiled_f32, matmul_tiled_f64}; +use crate::runtime::cpu::kernels::simd::{SimdLevel, detect_simd}; + +// ============================================================================ +// Constants +// ============================================================================ + +/// Micro-kernel row dimension (Mr) +pub const MR: usize = 6; + +/// L3 cache blocking: M dimension (Mc) +/// Must be a multiple of MR to avoid buffer overflow in packing. +pub const MC: usize = 126; // 21 * MR(6) + +/// L2 cache blocking: K dimension (Kc) +/// Sized so packed_A (MC×KC×4) fits in L2 cache (~256KB): +/// 126 × 256 × 4 = 129KB +pub const KC: usize = 256; + +/// L3 cache blocking: N dimension (Nc) +pub const NC: usize = 512; + +/// Small matrix threshold - below this, register-blocked SIMD is faster than tiled +const SMALL_MATRIX_THRESHOLD: usize = 128 * 128 * 128 + 1; + +// ============================================================================ +// Public API +// ============================================================================ + +/// SIMD-optimized matrix multiplication: C = A @ B +/// +/// Dispatches to the best available SIMD implementation based on CPU features. +/// Falls back to scalar for unsupported CPUs or small matrices. +/// +/// # Safety +/// - All pointers must be valid for the specified dimensions +/// - `out` must not alias with `a` or `b` +#[inline] +#[allow(clippy::too_many_arguments)] +pub unsafe fn matmul_f32( + a: *const f32, + b: *const f32, + out: *mut f32, + m: usize, + n: usize, + k: usize, + lda: usize, + ldb: usize, + ldc: usize, +) { + let level = detect_simd(); + + if m * n * k < SMALL_MATRIX_THRESHOLD { + small::small_matmul_f32(a, b, out, m, n, k, lda, ldb, ldc, level); + return; + } + + // Use double-width NR for 12 FMA chains (2×NR columns per microkernel) + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => matmul_tiled_f32::<32>(a, b, out, m, n, k, lda, ldb, ldc, level), + SimdLevel::Avx2Fma => matmul_tiled_f32::<16>(a, b, out, m, n, k, lda, ldb, ldc, level), + _ => matmul_scalar_f32(a, b, out, m, n, k, lda, ldb, ldc), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => { + matmul_tiled_f32::<8>(a, b, out, m, n, k, lda, ldb, ldc, level) + } + _ => matmul_scalar_f32(a, b, out, m, n, k, lda, ldb, ldc), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + matmul_scalar_f32(a, b, out, m, n, k, lda, ldb, ldc); +} + +/// SIMD-optimized matrix multiplication for f64 +#[inline] +#[allow(clippy::too_many_arguments)] +pub unsafe fn matmul_f64( + a: *const f64, + b: *const f64, + out: *mut f64, + m: usize, + n: usize, + k: usize, + lda: usize, + ldb: usize, + ldc: usize, +) { + let level = detect_simd(); + + if m * n * k < SMALL_MATRIX_THRESHOLD { + small::small_matmul_f64(a, b, out, m, n, k, lda, ldb, ldc, level); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => matmul_tiled_f64::<16>(a, b, out, m, n, k, lda, ldb, ldc, level), + SimdLevel::Avx2Fma => matmul_tiled_f64::<8>(a, b, out, m, n, k, lda, ldb, ldc, level), + _ => matmul_scalar_f64(a, b, out, m, n, k, lda, ldb, ldc), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => { + matmul_tiled_f64::<4>(a, b, out, m, n, k, lda, ldb, ldc, level) + } + _ => matmul_scalar_f64(a, b, out, m, n, k, lda, ldb, ldc), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + matmul_scalar_f64(a, b, out, m, n, k, lda, ldb, ldc); +} + +/// Fused matmul with bias: C = A @ B + bias (single-pass, cache-efficient) +/// +/// Initializes C with bias, then accumulates the matmul result. +/// This is more cache-efficient than separate matmul + bias addition. +#[inline] +#[allow(clippy::too_many_arguments)] +pub unsafe fn matmul_bias_f32( + a: *const f32, + b: *const f32, + bias: *const f32, + out: *mut f32, + m: usize, + n: usize, + k: usize, + lda: usize, + ldb: usize, + ldc: usize, +) { + let level = detect_simd(); + + if m * n * k < SMALL_MATRIX_THRESHOLD { + small::small_matmul_bias_f32(a, b, bias, out, m, n, k, lda, ldb, ldc, level); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => { + matmul_bias_tiled_f32::<32>(a, b, bias, out, m, n, k, lda, ldb, ldc, level) + } + SimdLevel::Avx2Fma => { + matmul_bias_tiled_f32::<16>(a, b, bias, out, m, n, k, lda, ldb, ldc, level) + } + _ => matmul_bias_scalar_f32(a, b, bias, out, m, n, k, lda, ldb, ldc), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => { + matmul_bias_tiled_f32::<8>(a, b, bias, out, m, n, k, lda, ldb, ldc, level) + } + _ => matmul_bias_scalar_f32(a, b, bias, out, m, n, k, lda, ldb, ldc), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + matmul_bias_scalar_f32(a, b, bias, out, m, n, k, lda, ldb, ldc); +} + +/// Fused matmul with bias for f64 +#[inline] +#[allow(clippy::too_many_arguments)] +pub unsafe fn matmul_bias_f64( + a: *const f64, + b: *const f64, + bias: *const f64, + out: *mut f64, + m: usize, + n: usize, + k: usize, + lda: usize, + ldb: usize, + ldc: usize, +) { + let level = detect_simd(); + + if m * n * k < SMALL_MATRIX_THRESHOLD { + small::small_matmul_bias_f64(a, b, bias, out, m, n, k, lda, ldb, ldc, level); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => { + matmul_bias_tiled_f64::<16>(a, b, bias, out, m, n, k, lda, ldb, ldc, level) + } + SimdLevel::Avx2Fma => { + matmul_bias_tiled_f64::<8>(a, b, bias, out, m, n, k, lda, ldb, ldc, level) + } + _ => matmul_bias_scalar_f64(a, b, bias, out, m, n, k, lda, ldb, ldc), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => { + matmul_bias_tiled_f64::<4>(a, b, bias, out, m, n, k, lda, ldb, ldc, level) + } + _ => matmul_bias_scalar_f64(a, b, bias, out, m, n, k, lda, ldb, ldc), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + matmul_bias_scalar_f64(a, b, bias, out, m, n, k, lda, ldb, ldc); +} + +// ============================================================================ +// Microkernel dispatch +// ============================================================================ + +/// Dispatch to the appropriate SIMD microkernel for f32 (single-width NR) +/// +/// `first_k`: when true, accumulators start from zero (beta=0, no load from C). +#[inline] +pub unsafe fn call_microkernel_f32( + a: *const f32, + b: *const f32, + c: *mut f32, + k: usize, + ldc: usize, + level: SimdLevel, + first_k: bool, +) { + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => avx512::microkernel_6x16_f32(a, b, c, k, ldc, first_k), + SimdLevel::Avx2Fma => avx2::microkernel_6x8_f32(a, b, c, k, ldc, first_k), + _ => microkernel_edge_f32(a, b, c, MR, 4, k, ldc, first_k), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => { + aarch64::neon::microkernel_6x4_f32(a, b, c, k, ldc, first_k) + } + _ => microkernel_edge_f32(a, b, c, MR, 4, k, ldc, first_k), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + microkernel_edge_f32(a, b, c, MR, 4, k, ldc, first_k); +} + +/// Dispatch to the double-width SIMD microkernel for f32 (2×NR columns) +/// +/// Processes 6 rows × 2*NR columns = 12 independent FMA chains. +#[inline] +pub unsafe fn call_microkernel_2x_f32( + a: *const f32, + b: *const f32, + c: *mut f32, + k: usize, + ldc: usize, + level: SimdLevel, + first_k: bool, +) { + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => avx512::microkernel_6x32_f32(a, b, c, k, ldc, first_k), + SimdLevel::Avx2Fma => avx2::microkernel_6x16_f32(a, b, c, k, ldc, first_k), + _ => { + // Fallback: call single-width twice + let nr = 4usize; + microkernel_edge_f32(a, b, c, MR, nr, k, ldc, first_k); + microkernel_edge_f32(a, b.add(nr * k), c.add(nr), MR, nr, k, ldc, first_k); + } + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => { + // NEON: call single-width twice (4+4=8) + aarch64::neon::microkernel_6x4_f32(a, b, c, k, ldc, first_k); + aarch64::neon::microkernel_6x4_f32(a, b.add(4 * k), c.add(4), k, ldc, first_k); + } + _ => { + let nr = 4usize; + microkernel_edge_f32(a, b, c, MR, nr, k, ldc, first_k); + microkernel_edge_f32(a, b.add(nr * k), c.add(nr), MR, nr, k, ldc, first_k); + } + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + { + let nr = 4usize; + microkernel_edge_f32(a, b, c, MR, nr, k, ldc, first_k); + microkernel_edge_f32(a, b.add(nr * k), c.add(nr), MR, nr, k, ldc, first_k); + } +} + +/// Dispatch to the appropriate SIMD microkernel for f64 (single-width NR) +#[inline] +pub unsafe fn call_microkernel_f64( + a: *const f64, + b: *const f64, + c: *mut f64, + k: usize, + ldc: usize, + level: SimdLevel, + first_k: bool, +) { + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => avx512::microkernel_6x8_f64(a, b, c, k, ldc, first_k), + SimdLevel::Avx2Fma => avx2::microkernel_6x4_f64(a, b, c, k, ldc, first_k), + _ => microkernel_edge_f64(a, b, c, MR, 4, k, ldc, first_k), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => { + aarch64::neon::microkernel_6x2_f64(a, b, c, k, ldc, first_k) + } + _ => microkernel_edge_f64(a, b, c, MR, 2, k, ldc, first_k), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + microkernel_edge_f64(a, b, c, MR, 4, k, ldc, first_k); +} + +/// Dispatch to the double-width SIMD microkernel for f64 (2×NR columns) +#[inline] +pub unsafe fn call_microkernel_2x_f64( + a: *const f64, + b: *const f64, + c: *mut f64, + k: usize, + ldc: usize, + level: SimdLevel, + first_k: bool, +) { + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => avx512::microkernel_6x16_f64(a, b, c, k, ldc, first_k), + SimdLevel::Avx2Fma => avx2::microkernel_6x8_f64(a, b, c, k, ldc, first_k), + _ => { + let nr = 4usize; + microkernel_edge_f64(a, b, c, MR, nr, k, ldc, first_k); + microkernel_edge_f64(a, b.add(nr * k), c.add(nr), MR, nr, k, ldc, first_k); + } + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => { + aarch64::neon::microkernel_6x2_f64(a, b, c, k, ldc, first_k); + aarch64::neon::microkernel_6x2_f64(a, b.add(2 * k), c.add(2), k, ldc, first_k); + } + _ => { + let nr = 2usize; + microkernel_edge_f64(a, b, c, MR, nr, k, ldc, first_k); + microkernel_edge_f64(a, b.add(nr * k), c.add(nr), MR, nr, k, ldc, first_k); + } + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + { + let nr = 4usize; + microkernel_edge_f64(a, b, c, MR, nr, k, ldc, first_k); + microkernel_edge_f64(a, b.add(nr * k), c.add(nr), MR, nr, k, ldc, first_k); + } +} + +// ============================================================================ +// Tests +// ============================================================================ + +#[cfg(test)] +mod tests { + use super::*; + + fn reference_matmul_f32(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Vec { + let mut c = vec![0.0f32; m * n]; + for i in 0..m { + for j in 0..n { + let mut sum = 0.0f32; + for kk in 0..k { + sum += a[i * k + kk] * b[kk * n + j]; + } + c[i * n + j] = sum; + } + } + c + } + + fn reference_matmul_f64(a: &[f64], b: &[f64], m: usize, n: usize, k: usize) -> Vec { + let mut c = vec![0.0f64; m * n]; + for i in 0..m { + for j in 0..n { + let mut sum = 0.0f64; + for kk in 0..k { + sum += a[i * k + kk] * b[kk * n + j]; + } + c[i * n + j] = sum; + } + } + c + } + + fn reference_matmul_bias_f32( + a: &[f32], + b: &[f32], + bias: &[f32], + m: usize, + n: usize, + k: usize, + ) -> Vec { + let mut c = reference_matmul_f32(a, b, m, n, k); + for i in 0..m { + for j in 0..n { + c[i * n + j] += bias[j]; + } + } + c + } + + const F32_SMALL_TOL: f32 = 1e-4; + const F32_LARGE_TOL: f32 = 1e-3; + const F64_SMALL_TOL: f64 = 1e-10; + const F64_LARGE_TOL: f64 = 1e-9; + + #[test] + fn test_matmul_f32_small() { + let (m, n, k) = (4, 4, 4); + let a: Vec = (0..m * k).map(|i| (i + 1) as f32).collect(); + let b: Vec = (0..k * n).map(|i| (i + 1) as f32).collect(); + let mut c = vec![0.0f32; m * n]; + let expected = reference_matmul_f32(&a, &b, m, n, k); + + unsafe { matmul_f32(a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), m, n, k, k, n, n) }; + + for i in 0..m * n { + assert!((c[i] - expected[i]).abs() < F32_SMALL_TOL); + } + } + + #[test] + fn test_matmul_f32_large() { + let (m, n, k) = (128, 128, 128); + let a: Vec = (0..m * k).map(|i| ((i % 17) as f32) * 0.1).collect(); + let b: Vec = (0..k * n).map(|i| ((i % 13) as f32) * 0.1).collect(); + let mut c = vec![0.0f32; m * n]; + let expected = reference_matmul_f32(&a, &b, m, n, k); + + unsafe { matmul_f32(a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), m, n, k, k, n, n) }; + + let max_diff = (0..m * n) + .map(|i| (c[i] - expected[i]).abs()) + .fold(0.0f32, f32::max); + assert!(max_diff < F32_LARGE_TOL); + } + + #[test] + fn test_matmul_f64_small() { + let (m, n, k) = (4, 4, 4); + let a: Vec = (0..m * k).map(|i| (i + 1) as f64).collect(); + let b: Vec = (0..k * n).map(|i| (i + 1) as f64).collect(); + let mut c = vec![0.0f64; m * n]; + let expected = reference_matmul_f64(&a, &b, m, n, k); + + unsafe { matmul_f64(a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), m, n, k, k, n, n) }; + + for i in 0..m * n { + assert!((c[i] - expected[i]).abs() < F64_SMALL_TOL); + } + } + + #[test] + fn test_matmul_f64_large() { + let (m, n, k) = (128, 128, 128); + let a: Vec = (0..m * k).map(|i| ((i % 17) as f64) * 0.1).collect(); + let b: Vec = (0..k * n).map(|i| ((i % 13) as f64) * 0.1).collect(); + let mut c = vec![0.0f64; m * n]; + let expected = reference_matmul_f64(&a, &b, m, n, k); + + unsafe { matmul_f64(a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), m, n, k, k, n, n) }; + + let max_diff = (0..m * n) + .map(|i| (c[i] - expected[i]).abs()) + .fold(0.0f64, f64::max); + assert!(max_diff < F64_LARGE_TOL); + } + + #[test] + fn test_matmul_non_square() { + let (m, n, k) = (37, 53, 41); + let a: Vec = (0..m * k).map(|i| ((i % 7) as f32) * 0.5).collect(); + let b: Vec = (0..k * n).map(|i| ((i % 11) as f32) * 0.3).collect(); + let mut c = vec![0.0f32; m * n]; + let expected = reference_matmul_f32(&a, &b, m, n, k); + + unsafe { matmul_f32(a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), m, n, k, k, n, n) }; + + let max_diff = (0..m * n) + .map(|i| (c[i] - expected[i]).abs()) + .fold(0.0f32, f32::max); + assert!(max_diff < F32_LARGE_TOL); + } + + #[test] + fn test_matmul_bias_f32_small() { + let (m, n, k) = (4, 4, 4); + let a: Vec = (0..m * k).map(|i| (i + 1) as f32).collect(); + let b: Vec = (0..k * n).map(|i| (i + 1) as f32).collect(); + let bias: Vec = (0..n).map(|i| (i * 10) as f32).collect(); + let mut c = vec![0.0f32; m * n]; + let expected = reference_matmul_bias_f32(&a, &b, &bias, m, n, k); + + unsafe { + matmul_bias_f32( + a.as_ptr(), + b.as_ptr(), + bias.as_ptr(), + c.as_mut_ptr(), + m, + n, + k, + k, + n, + n, + ) + }; + + for i in 0..m * n { + assert!((c[i] - expected[i]).abs() < F32_SMALL_TOL); + } + } + + #[test] + fn test_matmul_bias_f32_large() { + let (m, n, k) = (128, 128, 128); + let a: Vec = (0..m * k).map(|i| ((i % 17) as f32) * 0.1).collect(); + let b: Vec = (0..k * n).map(|i| ((i % 13) as f32) * 0.1).collect(); + let bias: Vec = (0..n).map(|i| ((i % 7) as f32) * 0.5).collect(); + let mut c = vec![0.0f32; m * n]; + let expected = reference_matmul_bias_f32(&a, &b, &bias, m, n, k); + + unsafe { + matmul_bias_f32( + a.as_ptr(), + b.as_ptr(), + bias.as_ptr(), + c.as_mut_ptr(), + m, + n, + k, + k, + n, + n, + ) + }; + + let max_diff = (0..m * n) + .map(|i| (c[i] - expected[i]).abs()) + .fold(0.0f32, f32::max); + assert!(max_diff < F32_LARGE_TOL); + } + + #[test] + fn test_simd_level_detection() { + let level = detect_simd(); + println!("Detected SIMD level: {level:?}"); + } +} diff --git a/src/runtime/cpu/kernels/simd/matmul/gemv_bt.rs b/src/runtime/cpu/kernels/simd/matmul/gemv_bt.rs new file mode 100644 index 00000000..cb3d0b36 --- /dev/null +++ b/src/runtime/cpu/kernels/simd/matmul/gemv_bt.rs @@ -0,0 +1,646 @@ +//! GEMV-BT kernel: C[M,N] = A[M,K] @ B^T where B is stored as [N,K] +//! +//! When a weight matrix W[N,K] is transposed to get W^T[K,N], the result has +//! shape [K,N] and strides [1,K] — it's a view into the original [N,K] data. +//! Rather than copying to make it contiguous (which allocates K*N elements), +//! we can compute the matmul directly: each output C[m,n] = dot(A[m,:], B[n,:]) +//! where both A[m,:] and B[n,:] are contiguous K-element vectors. +//! +//! For decode (M=1), this eliminates: +//! - The contiguous copy of the entire weight matrix (e.g. 500MB for lm_head) +//! - The full B→f32 conversion buffer allocation (another 1GB for BF16) +//! - The overhead of the tiled GEMM algorithm for a single row + +use super::super::SimdLevel; + +/// GEMV-BT for f32: C[M,N] = A[M,K] @ B^T, B stored [N,K] row-major +/// +/// # Safety +/// - `a` must point to M*K contiguous f32 elements (row-major, stride=K) +/// - `b` must point to N*K contiguous f32 elements (row-major, stride=K) +/// - `out` must point to M*N writable f32 elements (row-major, stride=ldc) +#[allow(clippy::too_many_arguments)] +pub unsafe fn gemv_bt_f32( + a: *const f32, + b: *const f32, + out: *mut f32, + m: usize, + n: usize, + k: usize, + ldc: usize, + level: SimdLevel, +) { + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => gemv_bt_f32_avx512(a, b, out, m, n, k, ldc), + SimdLevel::Avx2Fma => gemv_bt_f32_avx2(a, b, out, m, n, k, ldc), + _ => gemv_bt_f32_scalar(a, b, out, m, n, k, ldc), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => gemv_bt_f32_neon(a, b, out, m, n, k, ldc), + _ => gemv_bt_f32_scalar(a, b, out, m, n, k, ldc), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + { + let _ = level; + gemv_bt_f32_scalar(a, b, out, m, n, k, ldc); + } +} + +#[allow(clippy::too_many_arguments)] +unsafe fn gemv_bt_f32_scalar( + a: *const f32, + b: *const f32, + out: *mut f32, + m: usize, + n: usize, + k: usize, + ldc: usize, +) { + for row in 0..m { + let a_row = a.add(row * k); + let out_row = out.add(row * ldc); + for col in 0..n { + let b_row = b.add(col * k); + let mut sum = 0.0f32; + for i in 0..k { + sum += *a_row.add(i) * *b_row.add(i); + } + *out_row.add(col) = sum; + } + } +} + +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2,fma")] +#[allow(clippy::too_many_arguments)] +unsafe fn gemv_bt_f32_avx2( + a: *const f32, + b: *const f32, + out: *mut f32, + m: usize, + n: usize, + k: usize, + ldc: usize, +) { + use std::arch::x86_64::*; + + for row in 0..m { + let a_row = a.add(row * k); + let out_row = out.add(row * ldc); + + // Process 4 output columns at a time for better ILP + let mut col = 0usize; + while col + 4 <= n { + let b0 = b.add(col * k); + let b1 = b.add((col + 1) * k); + let b2 = b.add((col + 2) * k); + let b3 = b.add((col + 3) * k); + + let mut acc0 = _mm256_setzero_ps(); + let mut acc1 = _mm256_setzero_ps(); + let mut acc2 = _mm256_setzero_ps(); + let mut acc3 = _mm256_setzero_ps(); + + let mut i = 0usize; + while i + 8 <= k { + let av = _mm256_loadu_ps(a_row.add(i)); + acc0 = _mm256_fmadd_ps(av, _mm256_loadu_ps(b0.add(i)), acc0); + acc1 = _mm256_fmadd_ps(av, _mm256_loadu_ps(b1.add(i)), acc1); + acc2 = _mm256_fmadd_ps(av, _mm256_loadu_ps(b2.add(i)), acc2); + acc3 = _mm256_fmadd_ps(av, _mm256_loadu_ps(b3.add(i)), acc3); + i += 8; + } + + let mut s0 = hsum_avx2(acc0); + let mut s1 = hsum_avx2(acc1); + let mut s2 = hsum_avx2(acc2); + let mut s3 = hsum_avx2(acc3); + + // Scalar tail + while i < k { + let av = *a_row.add(i); + s0 += av * *b0.add(i); + s1 += av * *b1.add(i); + s2 += av * *b2.add(i); + s3 += av * *b3.add(i); + i += 1; + } + + *out_row.add(col) = s0; + *out_row.add(col + 1) = s1; + *out_row.add(col + 2) = s2; + *out_row.add(col + 3) = s3; + col += 4; + } + + // Remaining columns + while col < n { + let b_row = b.add(col * k); + let mut acc = _mm256_setzero_ps(); + let mut i = 0usize; + while i + 8 <= k { + let av = _mm256_loadu_ps(a_row.add(i)); + acc = _mm256_fmadd_ps(av, _mm256_loadu_ps(b_row.add(i)), acc); + i += 8; + } + let mut s = hsum_avx2(acc); + while i < k { + s += *a_row.add(i) * *b_row.add(i); + i += 1; + } + *out_row.add(col) = s; + col += 1; + } + } +} + +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2")] +#[inline] +pub unsafe fn hsum_avx2(v: std::arch::x86_64::__m256) -> f32 { + use std::arch::x86_64::*; + // [a0+a4, a1+a5, a2+a6, a3+a7] as 128-bit + let hi = _mm256_extractf128_ps(v, 1); + let lo = _mm256_castps256_ps128(v); + let sum128 = _mm_add_ps(lo, hi); + // [s0+s2, s1+s3, ...] + let shuf = _mm_movehdup_ps(sum128); + let sums = _mm_add_ps(sum128, shuf); + let shuf2 = _mm_movehl_ps(sums, sums); + let sums2 = _mm_add_ss(sums, shuf2); + _mm_cvtss_f32(sums2) +} + +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx512f")] +#[allow(clippy::too_many_arguments)] +unsafe fn gemv_bt_f32_avx512( + a: *const f32, + b: *const f32, + out: *mut f32, + m: usize, + n: usize, + k: usize, + ldc: usize, +) { + use std::arch::x86_64::*; + + for row in 0..m { + let a_row = a.add(row * k); + let out_row = out.add(row * ldc); + + // Process 4 output columns at a time + let mut col = 0usize; + while col + 4 <= n { + let b0 = b.add(col * k); + let b1 = b.add((col + 1) * k); + let b2 = b.add((col + 2) * k); + let b3 = b.add((col + 3) * k); + + let mut acc0 = _mm512_setzero_ps(); + let mut acc1 = _mm512_setzero_ps(); + let mut acc2 = _mm512_setzero_ps(); + let mut acc3 = _mm512_setzero_ps(); + + let mut i = 0usize; + while i + 16 <= k { + let av = _mm512_loadu_ps(a_row.add(i)); + acc0 = _mm512_fmadd_ps(av, _mm512_loadu_ps(b0.add(i)), acc0); + acc1 = _mm512_fmadd_ps(av, _mm512_loadu_ps(b1.add(i)), acc1); + acc2 = _mm512_fmadd_ps(av, _mm512_loadu_ps(b2.add(i)), acc2); + acc3 = _mm512_fmadd_ps(av, _mm512_loadu_ps(b3.add(i)), acc3); + i += 16; + } + + let mut s0 = _mm512_reduce_add_ps(acc0); + let mut s1 = _mm512_reduce_add_ps(acc1); + let mut s2 = _mm512_reduce_add_ps(acc2); + let mut s3 = _mm512_reduce_add_ps(acc3); + + while i < k { + let av = *a_row.add(i); + s0 += av * *b0.add(i); + s1 += av * *b1.add(i); + s2 += av * *b2.add(i); + s3 += av * *b3.add(i); + i += 1; + } + + *out_row.add(col) = s0; + *out_row.add(col + 1) = s1; + *out_row.add(col + 2) = s2; + *out_row.add(col + 3) = s3; + col += 4; + } + + while col < n { + let b_row = b.add(col * k); + let mut acc = _mm512_setzero_ps(); + let mut i = 0usize; + while i + 16 <= k { + let av = _mm512_loadu_ps(a_row.add(i)); + acc = _mm512_fmadd_ps(av, _mm512_loadu_ps(b_row.add(i)), acc); + i += 16; + } + let mut s = _mm512_reduce_add_ps(acc); + while i < k { + s += *a_row.add(i) * *b_row.add(i); + i += 1; + } + *out_row.add(col) = s; + col += 1; + } + } +} + +/// GEMV-BT for f64: C[M,N] = A[M,K] @ B^T, B stored [N,K] row-major +#[allow(clippy::too_many_arguments)] +pub unsafe fn gemv_bt_f64( + a: *const f64, + b: *const f64, + out: *mut f64, + m: usize, + n: usize, + k: usize, + ldc: usize, + level: SimdLevel, +) { + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => gemv_bt_f64_avx512(a, b, out, m, n, k, ldc), + SimdLevel::Avx2Fma => gemv_bt_f64_avx2(a, b, out, m, n, k, ldc), + _ => gemv_bt_f64_scalar(a, b, out, m, n, k, ldc), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => gemv_bt_f64_neon(a, b, out, m, n, k, ldc), + _ => gemv_bt_f64_scalar(a, b, out, m, n, k, ldc), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + { + let _ = level; + gemv_bt_f64_scalar(a, b, out, m, n, k, ldc); + } +} + +#[allow(clippy::too_many_arguments)] +unsafe fn gemv_bt_f64_scalar( + a: *const f64, + b: *const f64, + out: *mut f64, + m: usize, + n: usize, + k: usize, + ldc: usize, +) { + for row in 0..m { + let a_row = a.add(row * k); + let out_row = out.add(row * ldc); + for col in 0..n { + let b_row = b.add(col * k); + let mut sum = 0.0f64; + for i in 0..k { + sum += *a_row.add(i) * *b_row.add(i); + } + *out_row.add(col) = sum; + } + } +} + +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2,fma")] +#[allow(clippy::too_many_arguments)] +unsafe fn gemv_bt_f64_avx2( + a: *const f64, + b: *const f64, + out: *mut f64, + m: usize, + n: usize, + k: usize, + ldc: usize, +) { + use std::arch::x86_64::*; + + for row in 0..m { + let a_row = a.add(row * k); + let out_row = out.add(row * ldc); + + for col in 0..n { + let b_row = b.add(col * k); + let mut acc0 = _mm256_setzero_pd(); + let mut acc1 = _mm256_setzero_pd(); + + let mut i = 0usize; + while i + 8 <= k { + acc0 = _mm256_fmadd_pd( + _mm256_loadu_pd(a_row.add(i)), + _mm256_loadu_pd(b_row.add(i)), + acc0, + ); + acc1 = _mm256_fmadd_pd( + _mm256_loadu_pd(a_row.add(i + 4)), + _mm256_loadu_pd(b_row.add(i + 4)), + acc1, + ); + i += 8; + } + let mut acc = _mm256_add_pd(acc0, acc1); + + while i + 4 <= k { + acc = _mm256_fmadd_pd( + _mm256_loadu_pd(a_row.add(i)), + _mm256_loadu_pd(b_row.add(i)), + acc, + ); + i += 4; + } + + let mut s = hsum_avx2_f64(acc); + while i < k { + s += *a_row.add(i) * *b_row.add(i); + i += 1; + } + *out_row.add(col) = s; + } + } +} + +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2")] +#[inline] +unsafe fn hsum_avx2_f64(v: std::arch::x86_64::__m256d) -> f64 { + use std::arch::x86_64::*; + let hi = _mm256_extractf128_pd(v, 1); + let lo = _mm256_castpd256_pd128(v); + let sum128 = _mm_add_pd(lo, hi); + let hi64 = _mm_unpackhi_pd(sum128, sum128); + let sum = _mm_add_sd(sum128, hi64); + _mm_cvtsd_f64(sum) +} + +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx512f")] +#[allow(clippy::too_many_arguments)] +unsafe fn gemv_bt_f64_avx512( + a: *const f64, + b: *const f64, + out: *mut f64, + m: usize, + n: usize, + k: usize, + ldc: usize, +) { + use std::arch::x86_64::*; + + for row in 0..m { + let a_row = a.add(row * k); + let out_row = out.add(row * ldc); + + for col in 0..n { + let b_row = b.add(col * k); + let mut acc = _mm512_setzero_pd(); + let mut i = 0usize; + while i + 8 <= k { + let av = _mm512_loadu_pd(a_row.add(i)); + acc = _mm512_fmadd_pd(av, _mm512_loadu_pd(b_row.add(i)), acc); + i += 8; + } + let mut s = _mm512_reduce_add_pd(acc); + while i < k { + s += *a_row.add(i) * *b_row.add(i); + i += 1; + } + *out_row.add(col) = s; + } + } +} + +// ============================================================================ +// NEON implementations (aarch64) +// ============================================================================ + +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +#[allow(clippy::too_many_arguments)] +unsafe fn gemv_bt_f32_neon( + a: *const f32, + b: *const f32, + out: *mut f32, + m: usize, + n: usize, + k: usize, + ldc: usize, +) { + use std::arch::aarch64::*; + + for row in 0..m { + let a_row = a.add(row * k); + let out_row = out.add(row * ldc); + + // Process 4 output columns at a time + let mut col = 0usize; + while col + 4 <= n { + let b0 = b.add(col * k); + let b1 = b.add((col + 1) * k); + let b2 = b.add((col + 2) * k); + let b3 = b.add((col + 3) * k); + + let mut acc0 = vdupq_n_f32(0.0); + let mut acc1 = vdupq_n_f32(0.0); + let mut acc2 = vdupq_n_f32(0.0); + let mut acc3 = vdupq_n_f32(0.0); + + let mut i = 0usize; + while i + 4 <= k { + let av = vld1q_f32(a_row.add(i)); + acc0 = vfmaq_f32(acc0, av, vld1q_f32(b0.add(i))); + acc1 = vfmaq_f32(acc1, av, vld1q_f32(b1.add(i))); + acc2 = vfmaq_f32(acc2, av, vld1q_f32(b2.add(i))); + acc3 = vfmaq_f32(acc3, av, vld1q_f32(b3.add(i))); + i += 4; + } + + let mut s0 = vaddvq_f32(acc0); + let mut s1 = vaddvq_f32(acc1); + let mut s2 = vaddvq_f32(acc2); + let mut s3 = vaddvq_f32(acc3); + + while i < k { + let av = *a_row.add(i); + s0 += av * *b0.add(i); + s1 += av * *b1.add(i); + s2 += av * *b2.add(i); + s3 += av * *b3.add(i); + i += 1; + } + + *out_row.add(col) = s0; + *out_row.add(col + 1) = s1; + *out_row.add(col + 2) = s2; + *out_row.add(col + 3) = s3; + col += 4; + } + + while col < n { + let b_row = b.add(col * k); + let mut acc = vdupq_n_f32(0.0); + let mut i = 0usize; + while i + 4 <= k { + acc = vfmaq_f32(acc, vld1q_f32(a_row.add(i)), vld1q_f32(b_row.add(i))); + i += 4; + } + let mut s = vaddvq_f32(acc); + while i < k { + s += *a_row.add(i) * *b_row.add(i); + i += 1; + } + *out_row.add(col) = s; + col += 1; + } + } +} + +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +#[allow(clippy::too_many_arguments)] +unsafe fn gemv_bt_f64_neon( + a: *const f64, + b: *const f64, + out: *mut f64, + m: usize, + n: usize, + k: usize, + ldc: usize, +) { + use std::arch::aarch64::*; + + for row in 0..m { + let a_row = a.add(row * k); + let out_row = out.add(row * ldc); + + for col in 0..n { + let b_row = b.add(col * k); + let mut acc0 = vdupq_n_f64(0.0); + let mut acc1 = vdupq_n_f64(0.0); + + let mut i = 0usize; + while i + 4 <= k { + acc0 = vfmaq_f64(acc0, vld1q_f64(a_row.add(i)), vld1q_f64(b_row.add(i))); + acc1 = vfmaq_f64( + acc1, + vld1q_f64(a_row.add(i + 2)), + vld1q_f64(b_row.add(i + 2)), + ); + i += 4; + } + let mut acc = vaddq_f64(acc0, acc1); + + while i + 2 <= k { + acc = vfmaq_f64(acc, vld1q_f64(a_row.add(i)), vld1q_f64(b_row.add(i))); + i += 2; + } + + let mut s = vaddvq_f64(acc); + while i < k { + s += *a_row.add(i) * *b_row.add(i); + i += 1; + } + *out_row.add(col) = s; + } + } +} + +// ============================================================================ +// Tests +// ============================================================================ + +#[cfg(test)] +mod tests { + use super::*; + + fn reference_gemv_bt(a: &[f32], b_nk: &[f32], m: usize, n: usize, k: usize) -> Vec { + let mut c = vec![0.0f32; m * n]; + for i in 0..m { + for j in 0..n { + let mut sum = 0.0f32; + for kk in 0..k { + sum += a[i * k + kk] * b_nk[j * k + kk]; + } + c[i * n + j] = sum; + } + } + c + } + + #[test] + fn test_gemv_bt_f32_m1() { + let (m, n, k) = (1, 64, 128); + let a: Vec = (0..m * k).map(|i| ((i % 17) as f32) * 0.1).collect(); + let b: Vec = (0..n * k).map(|i| ((i % 13) as f32) * 0.1).collect(); + let mut c = vec![0.0f32; m * n]; + let expected = reference_gemv_bt(&a, &b, m, n, k); + + let level = super::super::super::detect_simd(); + unsafe { gemv_bt_f32(a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), m, n, k, n, level) }; + + let max_diff = c + .iter() + .zip(&expected) + .map(|(a, b)| (a - b).abs()) + .fold(0.0f32, f32::max); + assert!(max_diff < 1e-4, "max_diff={max_diff}"); + } + + #[test] + fn test_gemv_bt_f32_m4() { + let (m, n, k) = (4, 53, 97); + let a: Vec = (0..m * k).map(|i| ((i % 7) as f32) * 0.3).collect(); + let b: Vec = (0..n * k).map(|i| ((i % 11) as f32) * 0.2).collect(); + let mut c = vec![0.0f32; m * n]; + let expected = reference_gemv_bt(&a, &b, m, n, k); + + let level = super::super::super::detect_simd(); + unsafe { gemv_bt_f32(a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), m, n, k, n, level) }; + + let max_diff = c + .iter() + .zip(&expected) + .map(|(a, b)| (a - b).abs()) + .fold(0.0f32, f32::max); + assert!(max_diff < 1e-3, "max_diff={max_diff}"); + } + + #[test] + fn test_gemv_bt_f64_m1() { + let (m, n, k) = (1, 64, 128); + let a: Vec = (0..m * k).map(|i| ((i % 17) as f64) * 0.1).collect(); + let b_nk: Vec = (0..n * k).map(|i| ((i % 13) as f64) * 0.1).collect(); + let mut c = vec![0.0f64; m * n]; + + // Reference + let mut expected = vec![0.0f64; m * n]; + for j in 0..n { + let mut sum = 0.0f64; + for kk in 0..k { + sum += a[kk] * b_nk[j * k + kk]; + } + expected[j] = sum; + } + + let level = super::super::super::detect_simd(); + unsafe { gemv_bt_f64(a.as_ptr(), b_nk.as_ptr(), c.as_mut_ptr(), m, n, k, n, level) }; + + let max_diff = c + .iter() + .zip(&expected) + .map(|(a, b)| (a - b).abs()) + .fold(0.0f64, f64::max); + assert!(max_diff < 1e-10, "max_diff={max_diff}"); + } +} diff --git a/src/runtime/cpu/kernels/simd/matmul/int32.rs b/src/runtime/cpu/kernels/simd/matmul/int32.rs new file mode 100644 index 00000000..0b06384f --- /dev/null +++ b/src/runtime/cpu/kernels/simd/matmul/int32.rs @@ -0,0 +1,175 @@ +//! SIMD-optimized i32 matrix multiplication +//! +//! Uses AVX2 `_mm256_mullo_epi32` for 8-wide i32 multiply-accumulate. + +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +#[cfg(target_arch = "x86_64")] +use super::super::SimdLevel; +use super::super::detect_simd; + +/// SIMD-optimized i32 matrix multiplication: C = A @ B +/// +/// # Safety +/// - All pointers must be valid for the specified dimensions +/// - `out` must not alias with `a` or `b` +#[allow(clippy::too_many_arguments)] +pub unsafe fn matmul_i32( + a: *const i32, + b: *const i32, + out: *mut i32, + m: usize, + n: usize, + k: usize, + lda: usize, + ldb: usize, + ldc: usize, +) { + let level = detect_simd(); + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 | SimdLevel::Avx2Fma => { + matmul_i32_avx2(a, b, out, m, n, k, lda, ldb, ldc); + return; + } + _ => {} + } + + // Scalar fallback + #[cfg(target_arch = "aarch64")] + let _ = level; + + matmul_i32_scalar(a, b, out, m, n, k, lda, ldb, ldc); +} + +/// AVX2 i32 matmul: row × column with 8-wide multiply-accumulate +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2")] +#[allow(clippy::too_many_arguments)] +unsafe fn matmul_i32_avx2( + a: *const i32, + b: *const i32, + out: *mut i32, + m: usize, + n: usize, + k: usize, + lda: usize, + ldb: usize, + ldc: usize, +) { + const LANES: usize = 8; + + for i in 0..m { + let a_row = a.add(i * lda); + + // Process 8 output columns at a time + let mut j = 0; + while j + LANES <= n { + let mut acc = _mm256_setzero_si256(); + + for kk in 0..k { + let a_val = _mm256_set1_epi32(*a_row.add(kk)); + let b_vals = _mm256_loadu_si256(b.add(kk * ldb + j) as *const __m256i); + let prod = _mm256_mullo_epi32(a_val, b_vals); + acc = _mm256_add_epi32(acc, prod); + } + + _mm256_storeu_si256(out.add(i * ldc + j) as *mut __m256i, acc); + j += LANES; + } + + // Scalar tail for remaining columns + while j < n { + let mut sum = 0i32; + for kk in 0..k { + sum += (*a_row.add(kk)) * (*b.add(kk * ldb + j)); + } + *out.add(i * ldc + j) = sum; + j += 1; + } + } +} + +/// Scalar i32 matmul fallback +#[allow(clippy::too_many_arguments)] +unsafe fn matmul_i32_scalar( + a: *const i32, + b: *const i32, + out: *mut i32, + m: usize, + n: usize, + k: usize, + lda: usize, + ldb: usize, + ldc: usize, +) { + // Zero output + for i in 0..m { + for j in 0..n { + *out.add(i * ldc + j) = 0; + } + } + + // ikj order for cache locality + for i in 0..m { + for kk in 0..k { + let a_val = *a.add(i * lda + kk); + for j in 0..n { + let out_ptr = out.add(i * ldc + j); + *out_ptr += a_val * (*b.add(kk * ldb + j)); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_matmul_i32_basic() { + // A = [[1, 2], [3, 4]], B = [[5, 6], [7, 8]] + // C = [[19, 22], [43, 50]] + let a = [1i32, 2, 3, 4]; + let b = [5i32, 6, 7, 8]; + let mut c = [0i32; 4]; + + unsafe { matmul_i32(a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), 2, 2, 2, 2, 2, 2) }; + assert_eq!(c, [19, 22, 43, 50]); + } + + #[test] + fn test_matmul_i32_non_square() { + // A(3x2) @ B(2x4) = C(3x4) + let a = [1i32, 2, 3, 4, 5, 6]; + let b = [1i32, 2, 3, 4, 5, 6, 7, 8]; + let mut c = [0i32; 12]; + + unsafe { matmul_i32(a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), 3, 4, 2, 2, 4, 4) }; + assert_eq!(c, [11, 14, 17, 20, 23, 30, 37, 44, 35, 46, 57, 68]); + } + + #[test] + fn test_matmul_i32_wide() { + // Test with n > 8 to exercise SIMD path + let (m, n, k) = (2, 16, 3); + let a: Vec = (0..m * k).map(|i| (i + 1) as i32).collect(); + let b: Vec = (0..k * n).map(|i| (i + 1) as i32).collect(); + let mut c = vec![0i32; m * n]; + + unsafe { matmul_i32(a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), m, n, k, k, n, n) }; + + // Reference + let mut expected = vec![0i32; m * n]; + for i in 0..m { + for j in 0..n { + for kk in 0..k { + expected[i * n + j] += a[i * k + kk] * b[kk * n + j]; + } + } + } + assert_eq!(c, expected); + } +} diff --git a/src/runtime/cpu/kernels/simd/matmul/int8.rs b/src/runtime/cpu/kernels/simd/matmul/int8.rs new file mode 100644 index 00000000..243c1980 --- /dev/null +++ b/src/runtime/cpu/kernels/simd/matmul/int8.rs @@ -0,0 +1,103 @@ +//! i8 × i8 → i32 matrix multiplication using SIMD dot product kernels +//! +//! Each output element C[i][j] = sum_k(A[i][k] * B[k][j]) where A,B are i8 +//! and accumulation is in i32. Uses the SIMD dot product from `simd::dot`. + +use super::super::dot::i8xi8_dot_i32; + +/// i8 × i8 → i32 matmul: C[m×n] = A[m×k] @ B[k×n] +/// +/// Packs columns of B into a contiguous scratch buffer so each dot product +/// operates on contiguous memory. +/// +/// # Safety +/// - `a` must be valid for m*lda i8 elements +/// - `b` must be valid for k*ldb i8 elements +/// - `out` must be valid for m*ldc i32 elements +#[allow(clippy::too_many_arguments)] +pub unsafe fn matmul_i8_to_i32( + a: *const i8, + b: *const i8, + out: *mut i32, + m: usize, + n: usize, + k: usize, + lda: usize, + ldb: usize, + ldc: usize, +) { + // Pack column j of B into contiguous memory for efficient dot products + let mut b_col = vec![0i8; k]; + + for j in 0..n { + // Pack column j + for kk in 0..k { + *b_col.as_mut_ptr().add(kk) = *b.add(kk * ldb + j); + } + + // Compute dot product of each row of A with packed column + for i in 0..m { + let a_row = a.add(i * lda); + *out.add(i * ldc + j) = i8xi8_dot_i32(a_row, b_col.as_ptr(), k); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_matmul_i8_to_i32_basic() { + let a: Vec = vec![1, 2, 3, 4]; + let b: Vec = vec![5, 6, 7, 8]; + let mut c = [0i32; 4]; + + unsafe { + matmul_i8_to_i32(a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), 2, 2, 2, 2, 2, 2); + } + // [[1,2],[3,4]] @ [[5,6],[7,8]] = [[19,22],[43,50]] + assert_eq!(c, [19, 22, 43, 50]); + } + + #[test] + fn test_matmul_i8_to_i32_negative() { + let a: Vec = vec![-1, 2, 3, -4]; + let b: Vec = vec![5, -6, -7, 8]; + let mut c = [0i32; 4]; + + unsafe { + matmul_i8_to_i32(a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), 2, 2, 2, 2, 2, 2); + } + // [[-1,2],[3,-4]] @ [[5,-6],[-7,8]] = [[-19,22],[43,-50]] + assert_eq!(c, [-19, 22, 43, -50]); + } + + #[test] + fn test_matmul_i8_to_i32_wide() { + // Test with larger k to exercise SIMD dot product path + let (m, n, k) = (2, 3, 64); + let a: Vec = (0..m * k) + .map(|i| ((i % 127) as i8).wrapping_sub(64)) + .collect(); + let b: Vec = (0..k * n) + .map(|i| ((i * 3 % 127) as i8).wrapping_sub(64)) + .collect(); + let mut c = vec![0i32; m * n]; + + unsafe { + matmul_i8_to_i32(a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), m, n, k, k, n, n); + } + + // Reference + let mut expected = vec![0i32; m * n]; + for i in 0..m { + for j in 0..n { + for kk in 0..k { + expected[i * n + j] += a[i * k + kk] as i32 * b[kk * n + j] as i32; + } + } + } + assert_eq!(c, expected); + } +} diff --git a/src/runtime/cpu/kernels/simd/matmul/macros.rs b/src/runtime/cpu/kernels/simd/matmul/macros.rs index 5f481a03..fb564fc1 100644 --- a/src/runtime/cpu/kernels/simd/matmul/macros.rs +++ b/src/runtime/cpu/kernels/simd/matmul/macros.rs @@ -16,6 +16,7 @@ //! Each k iteration: 2 B loads shared across 6 A broadcasts = good reuse. /// Generate a 6×NR matmul microkernel for f32 (single column chunk) +#[cfg(target_arch = "x86_64")] macro_rules! define_microkernel_f32 { ( $name:ident, @@ -99,6 +100,7 @@ macro_rules! define_microkernel_f32 { /// Generate a 6×(2*NR) double-width matmul microkernel for f32 /// /// Processes 2 column chunks per row = 12 independent FMA chains. +#[cfg(target_arch = "x86_64")] macro_rules! define_microkernel_2x_f32 { ( $name:ident, @@ -211,6 +213,7 @@ macro_rules! define_microkernel_2x_f32 { } /// Generate a 6×NR matmul microkernel for f64 (single column chunk) +#[cfg(target_arch = "x86_64")] macro_rules! define_microkernel_f64 { ( $name:ident, @@ -292,6 +295,7 @@ macro_rules! define_microkernel_f64 { } /// Generate a 6×(2*NR) double-width matmul microkernel for f64 +#[cfg(target_arch = "x86_64")] macro_rules! define_microkernel_2x_f64 { ( $name:ident, @@ -399,7 +403,11 @@ macro_rules! define_microkernel_2x_f64 { }; } +#[cfg(target_arch = "x86_64")] pub(crate) use define_microkernel_2x_f32; +#[cfg(target_arch = "x86_64")] pub(crate) use define_microkernel_2x_f64; +#[cfg(target_arch = "x86_64")] pub(crate) use define_microkernel_f32; +#[cfg(target_arch = "x86_64")] pub(crate) use define_microkernel_f64; diff --git a/src/runtime/cpu/kernels/simd/matmul/mod.rs b/src/runtime/cpu/kernels/simd/matmul/mod.rs index e3d25652..83c0cb4d 100644 --- a/src/runtime/cpu/kernels/simd/matmul/mod.rs +++ b/src/runtime/cpu/kernels/simd/matmul/mod.rs @@ -1,622 +1,30 @@ -//! SIMD-optimized matrix multiplication with cache-aware tiling +//! SIMD-optimized matrix multiplication. //! -//! This module provides the tiled matmul algorithm that dispatches to -//! SIMD microkernels based on runtime CPU feature detection. -//! -//! # Algorithm Overview (BLIS-style) -//! -//! ```text -//! for jc in (0..N).step_by(NC): # L3 cache blocking -//! for pc in (0..K).step_by(KC): # L2 cache blocking -//! pack B[pc:pc+KC, jc:jc+NC] → B̃ # Pack B panel -//! for ic in (0..M).step_by(MC): # L2 cache blocking -//! pack A[ic:ic+MC, pc:pc+KC] → Ã # Pack A panel -//! for jr in (0..NC).step_by(NR): # Microkernel loop -//! for ir in (0..MC).step_by(MR): -//! microkernel(Ã[ir], B̃[jr], C[ic+ir, jc+jr]) -//! ``` -//! -//! # Microkernel Dimensions -//! -//! | SIMD Level | f32 (MR×NR) | f64 (MR×NR) | -//! |------------|-------------|-------------| -//! | AVX-512 | 6×16 | 6×8 | -//! | AVX2+FMA | 6×8 | 6×4 | -//! | Scalar | 6×4 | 6×4 | -//! -//! # Module Structure -//! -//! - `avx512.rs` / `avx2.rs`: SIMD microkernels (macro-generated) -//! - `macros.rs`: Macro definitions for microkernel generation -//! - `packing.rs`: Matrix packing functions -//! - `scalar.rs`: Scalar fallback implementations -//! - `tiling.rs`: Cache-aware tiled algorithm +//! See [`dispatch`] for the public API and microkernel dispatch functions. #[cfg(target_arch = "x86_64")] -mod avx2; +pub(crate) mod avx2; #[cfg(target_arch = "x86_64")] -mod avx512; -mod macros; -mod packing; -mod scalar; -mod small; -mod small_kernels; -mod tiling; +pub(crate) mod avx512; +pub(crate) mod dispatch; +pub(crate) mod gemv_bt; +pub(crate) mod int32; +pub(crate) mod int8; +pub(crate) mod macros; +pub(crate) mod packing; +pub(crate) mod scalar; +pub(crate) mod small; +pub(crate) mod small_kernels; +pub(crate) mod tiling; #[cfg(target_arch = "aarch64")] -mod aarch64; +pub(crate) mod aarch64; #[cfg(all(feature = "f16", target_arch = "x86_64"))] pub(crate) mod half_convert; -use super::{SimdLevel, detect_simd}; -use scalar::{matmul_bias_scalar_f32, matmul_bias_scalar_f64}; -use scalar::{matmul_scalar_f32, matmul_scalar_f64}; -use scalar::{microkernel_edge_f32, microkernel_edge_f64}; -use tiling::{matmul_bias_tiled_f32, matmul_bias_tiled_f64}; -use tiling::{matmul_tiled_f32, matmul_tiled_f64}; - -// ============================================================================ -// Constants -// ============================================================================ - -/// Micro-kernel row dimension (Mr) -pub const MR: usize = 6; - -/// L3 cache blocking: M dimension (Mc) -/// Must be a multiple of MR to avoid buffer overflow in packing. -pub const MC: usize = 126; // 21 * MR(6) - -/// L2 cache blocking: K dimension (Kc) -/// Sized so packed_A (MC×KC×4) fits in L2 cache (~256KB): -/// 126 × 256 × 4 = 129KB -pub const KC: usize = 256; - -/// L3 cache blocking: N dimension (Nc) -pub const NC: usize = 512; - -/// Small matrix threshold - below this, register-blocked SIMD is faster than tiled -const SMALL_MATRIX_THRESHOLD: usize = 128 * 128 * 128 + 1; - -// ============================================================================ -// Public API -// ============================================================================ - -/// SIMD-optimized matrix multiplication: C = A @ B -/// -/// Dispatches to the best available SIMD implementation based on CPU features. -/// Falls back to scalar for unsupported CPUs or small matrices. -/// -/// # Safety -/// - All pointers must be valid for the specified dimensions -/// - `out` must not alias with `a` or `b` -#[inline] -#[allow(clippy::too_many_arguments)] -pub unsafe fn matmul_f32( - a: *const f32, - b: *const f32, - out: *mut f32, - m: usize, - n: usize, - k: usize, - lda: usize, - ldb: usize, - ldc: usize, -) { - let level = detect_simd(); - - if m * n * k < SMALL_MATRIX_THRESHOLD { - small::small_matmul_f32(a, b, out, m, n, k, lda, ldb, ldc, level); - return; - } - - // Use double-width NR for 12 FMA chains (2×NR columns per microkernel) - #[cfg(target_arch = "x86_64")] - match level { - SimdLevel::Avx512 => matmul_tiled_f32::<32>(a, b, out, m, n, k, lda, ldb, ldc, level), - SimdLevel::Avx2Fma => matmul_tiled_f32::<16>(a, b, out, m, n, k, lda, ldb, ldc, level), - _ => matmul_scalar_f32(a, b, out, m, n, k, lda, ldb, ldc), - } - - #[cfg(target_arch = "aarch64")] - match level { - SimdLevel::Neon | SimdLevel::NeonFp16 => { - matmul_tiled_f32::<8>(a, b, out, m, n, k, lda, ldb, ldc, level) - } - _ => matmul_scalar_f32(a, b, out, m, n, k, lda, ldb, ldc), - } - - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] - matmul_scalar_f32(a, b, out, m, n, k, lda, ldb, ldc); -} - -/// SIMD-optimized matrix multiplication for f64 -#[inline] -#[allow(clippy::too_many_arguments)] -pub unsafe fn matmul_f64( - a: *const f64, - b: *const f64, - out: *mut f64, - m: usize, - n: usize, - k: usize, - lda: usize, - ldb: usize, - ldc: usize, -) { - let level = detect_simd(); - - if m * n * k < SMALL_MATRIX_THRESHOLD { - small::small_matmul_f64(a, b, out, m, n, k, lda, ldb, ldc, level); - return; - } - - #[cfg(target_arch = "x86_64")] - match level { - SimdLevel::Avx512 => matmul_tiled_f64::<16>(a, b, out, m, n, k, lda, ldb, ldc, level), - SimdLevel::Avx2Fma => matmul_tiled_f64::<8>(a, b, out, m, n, k, lda, ldb, ldc, level), - _ => matmul_scalar_f64(a, b, out, m, n, k, lda, ldb, ldc), - } - - #[cfg(target_arch = "aarch64")] - match level { - SimdLevel::Neon | SimdLevel::NeonFp16 => { - matmul_tiled_f64::<4>(a, b, out, m, n, k, lda, ldb, ldc, level) - } - _ => matmul_scalar_f64(a, b, out, m, n, k, lda, ldb, ldc), - } - - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] - matmul_scalar_f64(a, b, out, m, n, k, lda, ldb, ldc); -} - -/// Fused matmul with bias: C = A @ B + bias (single-pass, cache-efficient) -/// -/// Initializes C with bias, then accumulates the matmul result. -/// This is more cache-efficient than separate matmul + bias addition. -#[inline] -#[allow(clippy::too_many_arguments)] -pub unsafe fn matmul_bias_f32( - a: *const f32, - b: *const f32, - bias: *const f32, - out: *mut f32, - m: usize, - n: usize, - k: usize, - lda: usize, - ldb: usize, - ldc: usize, -) { - let level = detect_simd(); - - if m * n * k < SMALL_MATRIX_THRESHOLD { - small::small_matmul_bias_f32(a, b, bias, out, m, n, k, lda, ldb, ldc, level); - return; - } - - #[cfg(target_arch = "x86_64")] - match level { - SimdLevel::Avx512 => { - matmul_bias_tiled_f32::<32>(a, b, bias, out, m, n, k, lda, ldb, ldc, level) - } - SimdLevel::Avx2Fma => { - matmul_bias_tiled_f32::<16>(a, b, bias, out, m, n, k, lda, ldb, ldc, level) - } - _ => matmul_bias_scalar_f32(a, b, bias, out, m, n, k, lda, ldb, ldc), - } - - #[cfg(target_arch = "aarch64")] - match level { - SimdLevel::Neon | SimdLevel::NeonFp16 => { - matmul_bias_tiled_f32::<8>(a, b, bias, out, m, n, k, lda, ldb, ldc, level) - } - _ => matmul_bias_scalar_f32(a, b, bias, out, m, n, k, lda, ldb, ldc), - } - - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] - matmul_bias_scalar_f32(a, b, bias, out, m, n, k, lda, ldb, ldc); -} - -/// Fused matmul with bias for f64 -#[inline] -#[allow(clippy::too_many_arguments)] -pub unsafe fn matmul_bias_f64( - a: *const f64, - b: *const f64, - bias: *const f64, - out: *mut f64, - m: usize, - n: usize, - k: usize, - lda: usize, - ldb: usize, - ldc: usize, -) { - let level = detect_simd(); - - if m * n * k < SMALL_MATRIX_THRESHOLD { - small::small_matmul_bias_f64(a, b, bias, out, m, n, k, lda, ldb, ldc, level); - return; - } - - #[cfg(target_arch = "x86_64")] - match level { - SimdLevel::Avx512 => { - matmul_bias_tiled_f64::<16>(a, b, bias, out, m, n, k, lda, ldb, ldc, level) - } - SimdLevel::Avx2Fma => { - matmul_bias_tiled_f64::<8>(a, b, bias, out, m, n, k, lda, ldb, ldc, level) - } - _ => matmul_bias_scalar_f64(a, b, bias, out, m, n, k, lda, ldb, ldc), - } - - #[cfg(target_arch = "aarch64")] - match level { - SimdLevel::Neon | SimdLevel::NeonFp16 => { - matmul_bias_tiled_f64::<4>(a, b, bias, out, m, n, k, lda, ldb, ldc, level) - } - _ => matmul_bias_scalar_f64(a, b, bias, out, m, n, k, lda, ldb, ldc), - } - - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] - matmul_bias_scalar_f64(a, b, bias, out, m, n, k, lda, ldb, ldc); -} - -// ============================================================================ -// Microkernel dispatch (must be here for target_feature to work) -// ============================================================================ - -/// Dispatch to the appropriate SIMD microkernel for f32 (single-width NR) -/// -/// `first_k`: when true, accumulators start from zero (beta=0, no load from C). -#[inline] -pub(crate) unsafe fn call_microkernel_f32( - a: *const f32, - b: *const f32, - c: *mut f32, - k: usize, - ldc: usize, - level: SimdLevel, - first_k: bool, -) { - #[cfg(target_arch = "x86_64")] - match level { - SimdLevel::Avx512 => avx512::microkernel_6x16_f32(a, b, c, k, ldc, first_k), - SimdLevel::Avx2Fma => avx2::microkernel_6x8_f32(a, b, c, k, ldc, first_k), - _ => microkernel_edge_f32(a, b, c, MR, 4, k, ldc, first_k), - } - - #[cfg(target_arch = "aarch64")] - match level { - SimdLevel::Neon | SimdLevel::NeonFp16 => { - aarch64::neon::microkernel_6x4_f32(a, b, c, k, ldc, first_k) - } - _ => microkernel_edge_f32(a, b, c, MR, 4, k, ldc, first_k), - } - - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] - microkernel_edge_f32(a, b, c, MR, 4, k, ldc, first_k); -} - -/// Dispatch to the double-width SIMD microkernel for f32 (2×NR columns) -/// -/// Processes 6 rows × 2*NR columns = 12 independent FMA chains. -#[inline] -pub(crate) unsafe fn call_microkernel_2x_f32( - a: *const f32, - b: *const f32, - c: *mut f32, - k: usize, - ldc: usize, - level: SimdLevel, - first_k: bool, -) { - #[cfg(target_arch = "x86_64")] - match level { - SimdLevel::Avx512 => avx512::microkernel_6x32_f32(a, b, c, k, ldc, first_k), - SimdLevel::Avx2Fma => avx2::microkernel_6x16_f32(a, b, c, k, ldc, first_k), - _ => { - // Fallback: call single-width twice - let nr = 4usize; - microkernel_edge_f32(a, b, c, MR, nr, k, ldc, first_k); - microkernel_edge_f32(a, b.add(nr * k), c.add(nr), MR, nr, k, ldc, first_k); - } - } - - #[cfg(target_arch = "aarch64")] - match level { - SimdLevel::Neon | SimdLevel::NeonFp16 => { - // NEON: call single-width twice (4+4=8) - aarch64::neon::microkernel_6x4_f32(a, b, c, k, ldc, first_k); - aarch64::neon::microkernel_6x4_f32(a, b.add(4 * k), c.add(4), k, ldc, first_k); - } - _ => { - let nr = 4usize; - microkernel_edge_f32(a, b, c, MR, nr, k, ldc, first_k); - microkernel_edge_f32(a, b.add(nr * k), c.add(nr), MR, nr, k, ldc, first_k); - } - } - - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] - { - let nr = 4usize; - microkernel_edge_f32(a, b, c, MR, nr, k, ldc, first_k); - microkernel_edge_f32(a, b.add(nr * k), c.add(nr), MR, nr, k, ldc, first_k); - } -} - -/// Dispatch to the appropriate SIMD microkernel for f64 (single-width NR) -#[inline] -pub(crate) unsafe fn call_microkernel_f64( - a: *const f64, - b: *const f64, - c: *mut f64, - k: usize, - ldc: usize, - level: SimdLevel, - first_k: bool, -) { - #[cfg(target_arch = "x86_64")] - match level { - SimdLevel::Avx512 => avx512::microkernel_6x8_f64(a, b, c, k, ldc, first_k), - SimdLevel::Avx2Fma => avx2::microkernel_6x4_f64(a, b, c, k, ldc, first_k), - _ => microkernel_edge_f64(a, b, c, MR, 4, k, ldc, first_k), - } - - #[cfg(target_arch = "aarch64")] - match level { - SimdLevel::Neon | SimdLevel::NeonFp16 => { - aarch64::neon::microkernel_6x2_f64(a, b, c, k, ldc, first_k) - } - _ => microkernel_edge_f64(a, b, c, MR, 2, k, ldc, first_k), - } - - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] - microkernel_edge_f64(a, b, c, MR, 4, k, ldc, first_k); -} - -/// Dispatch to the double-width SIMD microkernel for f64 (2×NR columns) -#[inline] -pub(crate) unsafe fn call_microkernel_2x_f64( - a: *const f64, - b: *const f64, - c: *mut f64, - k: usize, - ldc: usize, - level: SimdLevel, - first_k: bool, -) { - #[cfg(target_arch = "x86_64")] - match level { - SimdLevel::Avx512 => avx512::microkernel_6x16_f64(a, b, c, k, ldc, first_k), - SimdLevel::Avx2Fma => avx2::microkernel_6x8_f64(a, b, c, k, ldc, first_k), - _ => { - let nr = 4usize; - microkernel_edge_f64(a, b, c, MR, nr, k, ldc, first_k); - microkernel_edge_f64(a, b.add(nr * k), c.add(nr), MR, nr, k, ldc, first_k); - } - } - - #[cfg(target_arch = "aarch64")] - match level { - SimdLevel::Neon | SimdLevel::NeonFp16 => { - aarch64::neon::microkernel_6x2_f64(a, b, c, k, ldc, first_k); - aarch64::neon::microkernel_6x2_f64(a, b.add(2 * k), c.add(2), k, ldc, first_k); - } - _ => { - let nr = 2usize; - microkernel_edge_f64(a, b, c, MR, nr, k, ldc, first_k); - microkernel_edge_f64(a, b.add(nr * k), c.add(nr), MR, nr, k, ldc, first_k); - } - } - - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] - { - let nr = 4usize; - microkernel_edge_f64(a, b, c, MR, nr, k, ldc, first_k); - microkernel_edge_f64(a, b.add(nr * k), c.add(nr), MR, nr, k, ldc, first_k); - } -} - -// ============================================================================ -// Tests -// ============================================================================ - -#[cfg(test)] -mod tests { - use super::*; - - fn reference_matmul_f32(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Vec { - let mut c = vec![0.0f32; m * n]; - for i in 0..m { - for j in 0..n { - let mut sum = 0.0f32; - for kk in 0..k { - sum += a[i * k + kk] * b[kk * n + j]; - } - c[i * n + j] = sum; - } - } - c - } - - fn reference_matmul_f64(a: &[f64], b: &[f64], m: usize, n: usize, k: usize) -> Vec { - let mut c = vec![0.0f64; m * n]; - for i in 0..m { - for j in 0..n { - let mut sum = 0.0f64; - for kk in 0..k { - sum += a[i * k + kk] * b[kk * n + j]; - } - c[i * n + j] = sum; - } - } - c - } - - fn reference_matmul_bias_f32( - a: &[f32], - b: &[f32], - bias: &[f32], - m: usize, - n: usize, - k: usize, - ) -> Vec { - let mut c = reference_matmul_f32(a, b, m, n, k); - for i in 0..m { - for j in 0..n { - c[i * n + j] += bias[j]; - } - } - c - } - - const F32_SMALL_TOL: f32 = 1e-4; - const F32_LARGE_TOL: f32 = 1e-3; - const F64_SMALL_TOL: f64 = 1e-10; - const F64_LARGE_TOL: f64 = 1e-9; - - #[test] - fn test_matmul_f32_small() { - let (m, n, k) = (4, 4, 4); - let a: Vec = (0..m * k).map(|i| (i + 1) as f32).collect(); - let b: Vec = (0..k * n).map(|i| (i + 1) as f32).collect(); - let mut c = vec![0.0f32; m * n]; - let expected = reference_matmul_f32(&a, &b, m, n, k); - - unsafe { matmul_f32(a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), m, n, k, k, n, n) }; - - for i in 0..m * n { - assert!((c[i] - expected[i]).abs() < F32_SMALL_TOL); - } - } - - #[test] - fn test_matmul_f32_large() { - let (m, n, k) = (128, 128, 128); - let a: Vec = (0..m * k).map(|i| ((i % 17) as f32) * 0.1).collect(); - let b: Vec = (0..k * n).map(|i| ((i % 13) as f32) * 0.1).collect(); - let mut c = vec![0.0f32; m * n]; - let expected = reference_matmul_f32(&a, &b, m, n, k); - - unsafe { matmul_f32(a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), m, n, k, k, n, n) }; - - let max_diff = (0..m * n) - .map(|i| (c[i] - expected[i]).abs()) - .fold(0.0f32, f32::max); - assert!(max_diff < F32_LARGE_TOL); - } - - #[test] - fn test_matmul_f64_small() { - let (m, n, k) = (4, 4, 4); - let a: Vec = (0..m * k).map(|i| (i + 1) as f64).collect(); - let b: Vec = (0..k * n).map(|i| (i + 1) as f64).collect(); - let mut c = vec![0.0f64; m * n]; - let expected = reference_matmul_f64(&a, &b, m, n, k); - - unsafe { matmul_f64(a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), m, n, k, k, n, n) }; - - for i in 0..m * n { - assert!((c[i] - expected[i]).abs() < F64_SMALL_TOL); - } - } - - #[test] - fn test_matmul_f64_large() { - let (m, n, k) = (128, 128, 128); - let a: Vec = (0..m * k).map(|i| ((i % 17) as f64) * 0.1).collect(); - let b: Vec = (0..k * n).map(|i| ((i % 13) as f64) * 0.1).collect(); - let mut c = vec![0.0f64; m * n]; - let expected = reference_matmul_f64(&a, &b, m, n, k); - - unsafe { matmul_f64(a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), m, n, k, k, n, n) }; - - let max_diff = (0..m * n) - .map(|i| (c[i] - expected[i]).abs()) - .fold(0.0f64, f64::max); - assert!(max_diff < F64_LARGE_TOL); - } - - #[test] - fn test_matmul_non_square() { - let (m, n, k) = (37, 53, 41); - let a: Vec = (0..m * k).map(|i| ((i % 7) as f32) * 0.5).collect(); - let b: Vec = (0..k * n).map(|i| ((i % 11) as f32) * 0.3).collect(); - let mut c = vec![0.0f32; m * n]; - let expected = reference_matmul_f32(&a, &b, m, n, k); - - unsafe { matmul_f32(a.as_ptr(), b.as_ptr(), c.as_mut_ptr(), m, n, k, k, n, n) }; - - let max_diff = (0..m * n) - .map(|i| (c[i] - expected[i]).abs()) - .fold(0.0f32, f32::max); - assert!(max_diff < F32_LARGE_TOL); - } - - #[test] - fn test_matmul_bias_f32_small() { - let (m, n, k) = (4, 4, 4); - let a: Vec = (0..m * k).map(|i| (i + 1) as f32).collect(); - let b: Vec = (0..k * n).map(|i| (i + 1) as f32).collect(); - let bias: Vec = (0..n).map(|i| (i * 10) as f32).collect(); - let mut c = vec![0.0f32; m * n]; - let expected = reference_matmul_bias_f32(&a, &b, &bias, m, n, k); - - unsafe { - matmul_bias_f32( - a.as_ptr(), - b.as_ptr(), - bias.as_ptr(), - c.as_mut_ptr(), - m, - n, - k, - k, - n, - n, - ) - }; - - for i in 0..m * n { - assert!((c[i] - expected[i]).abs() < F32_SMALL_TOL); - } - } - - #[test] - fn test_matmul_bias_f32_large() { - let (m, n, k) = (128, 128, 128); - let a: Vec = (0..m * k).map(|i| ((i % 17) as f32) * 0.1).collect(); - let b: Vec = (0..k * n).map(|i| ((i % 13) as f32) * 0.1).collect(); - let bias: Vec = (0..n).map(|i| ((i % 7) as f32) * 0.5).collect(); - let mut c = vec![0.0f32; m * n]; - let expected = reference_matmul_bias_f32(&a, &b, &bias, m, n, k); - - unsafe { - matmul_bias_f32( - a.as_ptr(), - b.as_ptr(), - bias.as_ptr(), - c.as_mut_ptr(), - m, - n, - k, - k, - n, - n, - ) - }; - - let max_diff = (0..m * n) - .map(|i| (c[i] - expected[i]).abs()) - .fold(0.0f32, f32::max); - assert!(max_diff < F32_LARGE_TOL); - } +pub use dispatch::{KC, MC, MR, NC, matmul_bias_f32, matmul_bias_f64, matmul_f32, matmul_f64}; - #[test] - fn test_simd_level_detection() { - let level = detect_simd(); - println!("Detected SIMD level: {level:?}"); - } -} +pub use dispatch::{ + call_microkernel_2x_f32, call_microkernel_2x_f64, call_microkernel_f32, call_microkernel_f64, +}; diff --git a/src/runtime/cpu/kernels/simd/mod.rs b/src/runtime/cpu/kernels/simd/mod.rs index 0671f3b1..7f45b0c9 100644 --- a/src/runtime/cpu/kernels/simd/mod.rs +++ b/src/runtime/cpu/kernels/simd/mod.rs @@ -29,6 +29,15 @@ //! | ARM64 | NEON | 128 bits | Supported | //! | Any | Scalar | N/A | Fallback | +// Shared f16/bf16 ↔ f32 SIMD conversion utilities +#[cfg(feature = "f16")] +pub mod half_convert_utils; + +// Macros for generating f16/bf16 block-convert-compute wrappers (must come before users) +// Always compiled - macros internally gate generated code with #[cfg(feature = "f16")] +#[macro_use] +mod half_macros; + // Operation modules - available on all architectures // Each operation's mod.rs handles internal architecture dispatch pub mod activations; @@ -37,6 +46,9 @@ pub mod clamp; pub mod compare; pub mod conv; pub mod cumulative; +pub mod dot; +pub mod fused_activation_mul; +pub mod fused_elementwise; pub mod index; pub mod logsumexp; pub mod math; @@ -45,6 +57,7 @@ pub mod norm; pub mod reduce; pub mod scalar; pub mod softmax; +pub mod softmax_bwd; pub mod special; pub mod unary; pub mod where_select; @@ -196,6 +209,7 @@ fn detect_simd_uncached() -> SimdLevel { return SimdLevel::Neon; } + #[allow(unreachable_code)] SimdLevel::Scalar } diff --git a/src/runtime/cpu/kernels/simd/norm/aarch64/neon/fused_add_layer_norm.rs b/src/runtime/cpu/kernels/simd/norm/aarch64/neon/fused_add_layer_norm.rs new file mode 100644 index 00000000..75988e4b --- /dev/null +++ b/src/runtime/cpu/kernels/simd/norm/aarch64/neon/fused_add_layer_norm.rs @@ -0,0 +1,450 @@ +//! NEON fused add + layer normalization kernels (forward and backward) + +#[cfg(target_arch = "aarch64")] +use std::arch::aarch64::*; + +use super::super::super::super::math::aarch64::neon::{hsum_f32, hsum_f64}; +use super::{F32_LANES, F64_LANES}; + +/// NEON Fused Add + Layer Normalization for f32 +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_layer_norm_f32( + input: *const f32, + residual: *const f32, + weight: *const f32, + bias: *const f32, + out: *mut f32, + pre_norm: *mut f32, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let chunks = hidden_size / F32_LANES; + let remainder = hidden_size % F32_LANES; + + for b in 0..batch_size { + let base = input.add(b * hidden_size); + let res_base = residual.add(b * hidden_size); + let pn_base = pre_norm.add(b * hidden_size); + let out_base = out.add(b * hidden_size); + + // Phase 1: Compute mean + let mut sum_acc = vdupq_n_f32(0.0); + for i in 0..chunks { + let offset = i * F32_LANES; + let v_in = vld1q_f32(base.add(offset)); + let v_res = vld1q_f32(res_base.add(offset)); + let pn = vaddq_f32(v_in, v_res); + vst1q_f32(pn_base.add(offset), pn); + sum_acc = vaddq_f32(sum_acc, pn); + } + let mut sum = hsum_f32(sum_acc); + + for i in 0..remainder { + let offset = chunks * F32_LANES + i; + let pn = *base.add(offset) + *res_base.add(offset); + *pn_base.add(offset) = pn; + sum += pn; + } + + let mean = sum / hidden_size as f32; + let v_mean = vdupq_n_f32(mean); + + // Phase 2: Compute variance + let mut var_acc = vdupq_n_f32(0.0); + for i in 0..chunks { + let offset = i * F32_LANES; + let pn = vld1q_f32(pn_base.add(offset)); + let diff = vsubq_f32(pn, v_mean); + var_acc = vfmaq_f32(var_acc, diff, diff); + } + let mut var_sum = hsum_f32(var_acc); + + for i in 0..remainder { + let offset = chunks * F32_LANES + i; + let diff = *pn_base.add(offset) - mean; + var_sum += diff * diff; + } + + let inv_std = 1.0 / (var_sum / hidden_size as f32 + eps).sqrt(); + let v_inv_std = vdupq_n_f32(inv_std); + + // Phase 3: Apply normalization, weight, and bias + for i in 0..chunks { + let offset = i * F32_LANES; + let pn = vld1q_f32(pn_base.add(offset)); + let v_w = vld1q_f32(weight.add(offset)); + let v_b = vld1q_f32(bias.add(offset)); + + let normalized = vmulq_f32(vsubq_f32(pn, v_mean), v_inv_std); + let result = vfmaq_f32(v_b, normalized, v_w); + vst1q_f32(out_base.add(offset), result); + } + + for i in 0..remainder { + let offset = chunks * F32_LANES + i; + let x = *pn_base.add(offset); + let w = *weight.add(offset); + let b = *bias.add(offset); + *out_base.add(offset) = (x - mean) * inv_std * w + b; + } + } +} + +/// NEON Fused Add + Layer Normalization for f64 +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_layer_norm_f64( + input: *const f64, + residual: *const f64, + weight: *const f64, + bias: *const f64, + out: *mut f64, + pre_norm: *mut f64, + batch_size: usize, + hidden_size: usize, + eps: f64, +) { + let chunks = hidden_size / F64_LANES; + let remainder = hidden_size % F64_LANES; + + for b in 0..batch_size { + let base = input.add(b * hidden_size); + let res_base = residual.add(b * hidden_size); + let pn_base = pre_norm.add(b * hidden_size); + let out_base = out.add(b * hidden_size); + + let mut sum_acc = vdupq_n_f64(0.0); + for i in 0..chunks { + let offset = i * F64_LANES; + let v_in = vld1q_f64(base.add(offset)); + let v_res = vld1q_f64(res_base.add(offset)); + let pn = vaddq_f64(v_in, v_res); + vst1q_f64(pn_base.add(offset), pn); + sum_acc = vaddq_f64(sum_acc, pn); + } + let mut sum = hsum_f64(sum_acc); + + for i in 0..remainder { + let offset = chunks * F64_LANES + i; + let pn = *base.add(offset) + *res_base.add(offset); + *pn_base.add(offset) = pn; + sum += pn; + } + + let mean = sum / hidden_size as f64; + let v_mean = vdupq_n_f64(mean); + + let mut var_acc = vdupq_n_f64(0.0); + for i in 0..chunks { + let offset = i * F64_LANES; + let pn = vld1q_f64(pn_base.add(offset)); + let diff = vsubq_f64(pn, v_mean); + var_acc = vfmaq_f64(var_acc, diff, diff); + } + let mut var_sum = hsum_f64(var_acc); + + for i in 0..remainder { + let offset = chunks * F64_LANES + i; + let diff = *pn_base.add(offset) - mean; + var_sum += diff * diff; + } + + let inv_std = 1.0 / (var_sum / hidden_size as f64 + eps).sqrt(); + let v_inv_std = vdupq_n_f64(inv_std); + + for i in 0..chunks { + let offset = i * F64_LANES; + let pn = vld1q_f64(pn_base.add(offset)); + let v_w = vld1q_f64(weight.add(offset)); + let v_b = vld1q_f64(bias.add(offset)); + + let normalized = vmulq_f64(vsubq_f64(pn, v_mean), v_inv_std); + let result = vfmaq_f64(v_b, normalized, v_w); + vst1q_f64(out_base.add(offset), result); + } + + for i in 0..remainder { + let offset = chunks * F64_LANES + i; + let x = *pn_base.add(offset); + let w = *weight.add(offset); + let b = *bias.add(offset); + *out_base.add(offset) = (x - mean) * inv_std * w + b; + } + } +} + +/// NEON Fused Add + Layer Norm Backward for f32 +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_layer_norm_bwd_f32( + grad: *const f32, + pre_norm: *const f32, + weight: *const f32, + d_input_residual: *mut f32, + d_weight: *mut f32, + d_bias: *mut f32, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let chunks = hidden_size / F32_LANES; + let remainder = hidden_size % F32_LANES; + + for b in 0..batch_size { + let pn_base = pre_norm.add(b * hidden_size); + let grad_base = grad.add(b * hidden_size); + let d_ir_base = d_input_residual.add(b * hidden_size); + + // Recompute mean from pre_norm + let mut sum_acc = vdupq_n_f32(0.0); + for i in 0..chunks { + let offset = i * F32_LANES; + let pn = vld1q_f32(pn_base.add(offset)); + sum_acc = vaddq_f32(sum_acc, pn); + } + let mut sum = hsum_f32(sum_acc); + + for i in 0..remainder { + sum += *pn_base.add(chunks * F32_LANES + i); + } + + let mean = sum / hidden_size as f32; + let v_mean = vdupq_n_f32(mean); + + // Recompute variance + let mut var_acc = vdupq_n_f32(0.0); + for i in 0..chunks { + let offset = i * F32_LANES; + let pn = vld1q_f32(pn_base.add(offset)); + let diff = vsubq_f32(pn, v_mean); + var_acc = vfmaq_f32(var_acc, diff, diff); + } + let mut var_sum = hsum_f32(var_acc); + + for i in 0..remainder { + let offset = chunks * F32_LANES + i; + let diff = *pn_base.add(offset) - mean; + var_sum += diff * diff; + } + + let inv_std = 1.0 / (var_sum / hidden_size as f32 + eps).sqrt(); + + // Compute mean_gs = mean(grad * weight) and mean_gs_n = mean(grad * weight * normalized) + let mut gs_acc = vdupq_n_f32(0.0); + let mut gsn_acc = vdupq_n_f32(0.0); + for i in 0..chunks { + let offset = i * F32_LANES; + let g = vld1q_f32(grad_base.add(offset)); + let w = vld1q_f32(weight.add(offset)); + let pn = vld1q_f32(pn_base.add(offset)); + + let gs = vmulq_f32(g, w); + gs_acc = vaddq_f32(gs_acc, gs); + + let diff = vsubq_f32(pn, v_mean); + let normalized = vmulq_f32(diff, vdupq_n_f32(inv_std)); + let gsn = vmulq_f32(gs, normalized); + gsn_acc = vaddq_f32(gsn_acc, gsn); + } + let mut mean_gs_simd = hsum_f32(gs_acc); + let mut mean_gsn_simd = hsum_f32(gsn_acc); + + for i in 0..remainder { + let offset = chunks * F32_LANES + i; + let g = *grad_base.add(offset); + let w = *weight.add(offset); + let pn = *pn_base.add(offset); + + let gs = g * w; + mean_gs_simd += gs; + + let normalized = (pn - mean) * inv_std; + mean_gsn_simd += gs * normalized; + } + + let mean_gs = mean_gs_simd / hidden_size as f32; + let mean_gs_n = mean_gsn_simd / hidden_size as f32; + let v_inv_std = vdupq_n_f32(inv_std); + let v_mean_gs = vdupq_n_f32(mean_gs); + let v_mean_gs_n = vdupq_n_f32(mean_gs_n); + + // Apply and accumulate + for i in 0..chunks { + let offset = i * F32_LANES; + let g = vld1q_f32(grad_base.add(offset)); + let w = vld1q_f32(weight.add(offset)); + let pn = vld1q_f32(pn_base.add(offset)); + + let normalized = vmulq_f32(vsubq_f32(pn, v_mean), v_inv_std); + let gs = vmulq_f32(g, w); + let d_ir = vmulq_f32( + v_inv_std, + vsubq_f32(gs, vaddq_f32(v_mean_gs, vmulq_f32(normalized, v_mean_gs_n))), + ); + vst1q_f32(d_ir_base.add(offset), d_ir); + + let dw_old = vld1q_f32(d_weight.add(offset)); + let dw_add = vmulq_f32(g, normalized); + let dw_new = vaddq_f32(dw_old, dw_add); + vst1q_f32(d_weight.add(offset), dw_new); + + let db_old = vld1q_f32(d_bias.add(offset)); + let db_new = vaddq_f32(db_old, g); + vst1q_f32(d_bias.add(offset), db_new); + } + + for i in 0..remainder { + let offset = chunks * F32_LANES + i; + let g = *grad_base.add(offset); + let w = *weight.add(offset); + let pn = *pn_base.add(offset); + + let normalized = (pn - mean) * inv_std; + let gs = g * w; + let d_ir = inv_std * (gs - mean_gs - normalized * mean_gs_n); + *d_ir_base.add(offset) = d_ir; + + *d_weight.add(offset) += g * normalized; + *d_bias.add(offset) += g; + } + } +} + +/// NEON Fused Add + Layer Norm Backward for f64 +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_layer_norm_bwd_f64( + grad: *const f64, + pre_norm: *const f64, + weight: *const f64, + d_input_residual: *mut f64, + d_weight: *mut f64, + d_bias: *mut f64, + batch_size: usize, + hidden_size: usize, + eps: f64, +) { + let chunks = hidden_size / F64_LANES; + let remainder = hidden_size % F64_LANES; + + for b in 0..batch_size { + let pn_base = pre_norm.add(b * hidden_size); + let grad_base = grad.add(b * hidden_size); + let d_ir_base = d_input_residual.add(b * hidden_size); + + let mut sum_acc = vdupq_n_f64(0.0); + for i in 0..chunks { + let offset = i * F64_LANES; + let pn = vld1q_f64(pn_base.add(offset)); + sum_acc = vaddq_f64(sum_acc, pn); + } + let mut sum = hsum_f64(sum_acc); + + for i in 0..remainder { + sum += *pn_base.add(chunks * F64_LANES + i); + } + + let mean = sum / hidden_size as f64; + let v_mean = vdupq_n_f64(mean); + + let mut var_acc = vdupq_n_f64(0.0); + for i in 0..chunks { + let offset = i * F64_LANES; + let pn = vld1q_f64(pn_base.add(offset)); + let diff = vsubq_f64(pn, v_mean); + var_acc = vfmaq_f64(var_acc, diff, diff); + } + let mut var_sum = hsum_f64(var_acc); + + for i in 0..remainder { + let offset = chunks * F64_LANES + i; + let diff = *pn_base.add(offset) - mean; + var_sum += diff * diff; + } + + let inv_std = 1.0 / (var_sum / hidden_size as f64 + eps).sqrt(); + + let mut gs_acc = vdupq_n_f64(0.0); + let mut gsn_acc = vdupq_n_f64(0.0); + for i in 0..chunks { + let offset = i * F64_LANES; + let g = vld1q_f64(grad_base.add(offset)); + let w = vld1q_f64(weight.add(offset)); + let pn = vld1q_f64(pn_base.add(offset)); + + let gs = vmulq_f64(g, w); + gs_acc = vaddq_f64(gs_acc, gs); + + let diff = vsubq_f64(pn, v_mean); + let normalized = vmulq_f64(diff, vdupq_n_f64(inv_std)); + let gsn = vmulq_f64(gs, normalized); + gsn_acc = vaddq_f64(gsn_acc, gsn); + } + let mut mean_gs_simd = hsum_f64(gs_acc); + let mut mean_gsn_simd = hsum_f64(gsn_acc); + + for i in 0..remainder { + let offset = chunks * F64_LANES + i; + let g = *grad_base.add(offset); + let w = *weight.add(offset); + let pn = *pn_base.add(offset); + + let gs = g * w; + mean_gs_simd += gs; + + let normalized = (pn - mean) * inv_std; + mean_gsn_simd += gs * normalized; + } + + let mean_gs = mean_gs_simd / hidden_size as f64; + let mean_gs_n = mean_gsn_simd / hidden_size as f64; + let v_inv_std = vdupq_n_f64(inv_std); + let v_mean_gs = vdupq_n_f64(mean_gs); + let v_mean_gs_n = vdupq_n_f64(mean_gs_n); + + for i in 0..chunks { + let offset = i * F64_LANES; + let g = vld1q_f64(grad_base.add(offset)); + let w = vld1q_f64(weight.add(offset)); + let pn = vld1q_f64(pn_base.add(offset)); + + let normalized = vmulq_f64(vsubq_f64(pn, v_mean), v_inv_std); + let gs = vmulq_f64(g, w); + let d_ir = vmulq_f64( + v_inv_std, + vsubq_f64(gs, vaddq_f64(v_mean_gs, vmulq_f64(normalized, v_mean_gs_n))), + ); + vst1q_f64(d_ir_base.add(offset), d_ir); + + let dw_old = vld1q_f64(d_weight.add(offset)); + let dw_add = vmulq_f64(g, normalized); + let dw_new = vaddq_f64(dw_old, dw_add); + vst1q_f64(d_weight.add(offset), dw_new); + + let db_old = vld1q_f64(d_bias.add(offset)); + let db_new = vaddq_f64(db_old, g); + vst1q_f64(d_bias.add(offset), db_new); + } + + for i in 0..remainder { + let offset = chunks * F64_LANES + i; + let g = *grad_base.add(offset); + let w = *weight.add(offset); + let pn = *pn_base.add(offset); + + let normalized = (pn - mean) * inv_std; + let gs = g * w; + let d_ir = inv_std * (gs - mean_gs - normalized * mean_gs_n); + *d_ir_base.add(offset) = d_ir; + + *d_weight.add(offset) += g * normalized; + *d_bias.add(offset) += g; + } + } +} diff --git a/src/runtime/cpu/kernels/simd/norm/aarch64/neon/fused_add_rms_norm.rs b/src/runtime/cpu/kernels/simd/norm/aarch64/neon/fused_add_rms_norm.rs new file mode 100644 index 00000000..61833e34 --- /dev/null +++ b/src/runtime/cpu/kernels/simd/norm/aarch64/neon/fused_add_rms_norm.rs @@ -0,0 +1,332 @@ +//! NEON fused add + RMS normalization kernels (forward and backward) + +#[cfg(target_arch = "aarch64")] +use std::arch::aarch64::*; + +use super::super::super::super::math::aarch64::neon::{hsum_f32, hsum_f64}; +use super::{F32_LANES, F64_LANES}; + +/// NEON Fused Add + RMS Normalization for f32 +/// +/// Computes: output = (input + residual) * rsqrt(mean((input + residual)^2) + eps) * weight +/// Stores intermediate (input + residual) in pre_norm for backward pass. +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_rms_norm_f32( + input: *const f32, + residual: *const f32, + weight: *const f32, + out: *mut f32, + pre_norm: *mut f32, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let chunks = hidden_size / F32_LANES; + let remainder = hidden_size % F32_LANES; + + for b in 0..batch_size { + let base = input.add(b * hidden_size); + let res_base = residual.add(b * hidden_size); + let pn_base = pre_norm.add(b * hidden_size); + let out_base = out.add(b * hidden_size); + + // Phase 1: Add input + residual, store in pre_norm, accumulate sum of squares + let mut ss_acc = vdupq_n_f32(0.0); + for i in 0..chunks { + let offset = i * F32_LANES; + let v_in = vld1q_f32(base.add(offset)); + let v_res = vld1q_f32(res_base.add(offset)); + let pn = vaddq_f32(v_in, v_res); + vst1q_f32(pn_base.add(offset), pn); + ss_acc = vfmaq_f32(ss_acc, pn, pn); + } + let mut sum_sq = hsum_f32(ss_acc) as f64; + + for i in 0..remainder { + let offset = chunks * F32_LANES + i; + let pn = *base.add(offset) + *res_base.add(offset); + *pn_base.add(offset) = pn; + let pn64 = pn as f64; + sum_sq += pn64 * pn64; + } + + let inv_rms = (1.0f64 / (sum_sq / hidden_size as f64 + eps as f64).sqrt()) as f32; + let v_inv_rms = vdupq_n_f32(inv_rms); + + // Phase 2: Apply normalization and weight + for i in 0..chunks { + let offset = i * F32_LANES; + let pn = vld1q_f32(pn_base.add(offset)); + let v_w = vld1q_f32(weight.add(offset)); + let result = vmulq_f32(vmulq_f32(pn, v_inv_rms), v_w); + vst1q_f32(out_base.add(offset), result); + } + + for i in 0..remainder { + let offset = chunks * F32_LANES + i; + let pn = *pn_base.add(offset); + let w = *weight.add(offset); + *out_base.add(offset) = pn * inv_rms * w; + } + } +} + +/// NEON Fused Add + RMS Normalization for f64 +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_rms_norm_f64( + input: *const f64, + residual: *const f64, + weight: *const f64, + out: *mut f64, + pre_norm: *mut f64, + batch_size: usize, + hidden_size: usize, + eps: f64, +) { + let chunks = hidden_size / F64_LANES; + let remainder = hidden_size % F64_LANES; + + for b in 0..batch_size { + let base = input.add(b * hidden_size); + let res_base = residual.add(b * hidden_size); + let pn_base = pre_norm.add(b * hidden_size); + let out_base = out.add(b * hidden_size); + + let mut ss_acc = vdupq_n_f64(0.0); + for i in 0..chunks { + let offset = i * F64_LANES; + let v_in = vld1q_f64(base.add(offset)); + let v_res = vld1q_f64(res_base.add(offset)); + let pn = vaddq_f64(v_in, v_res); + vst1q_f64(pn_base.add(offset), pn); + ss_acc = vfmaq_f64(ss_acc, pn, pn); + } + let mut sum_sq = hsum_f64(ss_acc); + + for i in 0..remainder { + let offset = chunks * F64_LANES + i; + let pn = *base.add(offset) + *res_base.add(offset); + *pn_base.add(offset) = pn; + sum_sq += pn * pn; + } + + let inv_rms = 1.0 / (sum_sq / hidden_size as f64 + eps).sqrt(); + let v_inv_rms = vdupq_n_f64(inv_rms); + + for i in 0..chunks { + let offset = i * F64_LANES; + let pn = vld1q_f64(pn_base.add(offset)); + let v_w = vld1q_f64(weight.add(offset)); + let result = vmulq_f64(vmulq_f64(pn, v_inv_rms), v_w); + vst1q_f64(out_base.add(offset), result); + } + + for i in 0..remainder { + let offset = chunks * F64_LANES + i; + let pn = *pn_base.add(offset); + let w = *weight.add(offset); + *out_base.add(offset) = pn * inv_rms * w; + } + } +} + +/// NEON Fused Add + RMS Norm Backward for f32 +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_rms_norm_bwd_f32( + grad: *const f32, + pre_norm: *const f32, + weight: *const f32, + d_input_residual: *mut f32, + d_weight: *mut f32, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let chunks = hidden_size / F32_LANES; + let remainder = hidden_size % F32_LANES; + + for b in 0..batch_size { + let pn_base = pre_norm.add(b * hidden_size); + let grad_base = grad.add(b * hidden_size); + let d_ir_base = d_input_residual.add(b * hidden_size); + + // Recompute mean square from pre_norm + let mut acc_sq = vdupq_n_f32(0.0); + for i in 0..chunks { + let offset = i * F32_LANES; + let pn = vld1q_f32(pn_base.add(offset)); + acc_sq = vfmaq_f32(acc_sq, pn, pn); + } + let mut sum_sq = hsum_f32(acc_sq); + + for i in 0..remainder { + let offset = chunks * F32_LANES + i; + let pn = *pn_base.add(offset); + sum_sq += pn * pn; + } + + let mean_sq = sum_sq / hidden_size as f32; + let inv_rms = 1.0 / (mean_sq + eps).sqrt(); + + // Compute dot = sum(grad * weight * pre_norm) + let mut dot_acc = vdupq_n_f32(0.0); + for i in 0..chunks { + let offset = i * F32_LANES; + let g = vld1q_f32(grad_base.add(offset)); + let w = vld1q_f32(weight.add(offset)); + let pn = vld1q_f32(pn_base.add(offset)); + let gw = vmulq_f32(g, w); + dot_acc = vfmaq_f32(dot_acc, gw, pn); + } + let mut dot = hsum_f32(dot_acc); + + for i in 0..remainder { + let offset = chunks * F32_LANES + i; + let g = *grad_base.add(offset); + let w = *weight.add(offset); + let pn = *pn_base.add(offset); + dot += g * w * pn; + } + + let coeff = dot * inv_rms / (hidden_size as f32 * (mean_sq + eps)); + let v_inv_rms = vdupq_n_f32(inv_rms); + let v_coeff = vdupq_n_f32(coeff); + + // Compute d_input_residual and accumulate d_weight + for i in 0..chunks { + let offset = i * F32_LANES; + let g = vld1q_f32(grad_base.add(offset)); + let w = vld1q_f32(weight.add(offset)); + let pn = vld1q_f32(pn_base.add(offset)); + + // d_ir = (g*w - pn*coeff) * inv_rms + let gw = vmulq_f32(g, w); + let pn_coeff = vmulq_f32(pn, v_coeff); + let diff = vsubq_f32(gw, pn_coeff); + let d_ir = vmulq_f32(diff, v_inv_rms); + vst1q_f32(d_ir_base.add(offset), d_ir); + + // d_weight += g * pn * inv_rms + let dw_old = vld1q_f32(d_weight.add(offset)); + let gp = vmulq_f32(g, pn); + let gp_inv = vmulq_f32(gp, v_inv_rms); + let dw_new = vaddq_f32(dw_old, gp_inv); + vst1q_f32(d_weight.add(offset), dw_new); + } + + for i in 0..remainder { + let offset = chunks * F32_LANES + i; + let g = *grad_base.add(offset); + let w = *weight.add(offset); + let pn = *pn_base.add(offset); + + let d_ir = (g * w - pn * coeff) * inv_rms; + *d_ir_base.add(offset) = d_ir; + + let d_w = g * pn * inv_rms; + *d_weight.add(offset) += d_w; + } + } +} + +/// NEON Fused Add + RMS Norm Backward for f64 +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_rms_norm_bwd_f64( + grad: *const f64, + pre_norm: *const f64, + weight: *const f64, + d_input_residual: *mut f64, + d_weight: *mut f64, + batch_size: usize, + hidden_size: usize, + eps: f64, +) { + let chunks = hidden_size / F64_LANES; + let remainder = hidden_size % F64_LANES; + + for b in 0..batch_size { + let pn_base = pre_norm.add(b * hidden_size); + let grad_base = grad.add(b * hidden_size); + let d_ir_base = d_input_residual.add(b * hidden_size); + + let mut acc_sq = vdupq_n_f64(0.0); + for i in 0..chunks { + let offset = i * F64_LANES; + let pn = vld1q_f64(pn_base.add(offset)); + acc_sq = vfmaq_f64(acc_sq, pn, pn); + } + let mut sum_sq = hsum_f64(acc_sq); + + for i in 0..remainder { + let offset = chunks * F64_LANES + i; + let pn = *pn_base.add(offset); + sum_sq += pn * pn; + } + + let mean_sq = sum_sq / hidden_size as f64; + let inv_rms = 1.0 / (mean_sq + eps).sqrt(); + + let mut dot_acc = vdupq_n_f64(0.0); + for i in 0..chunks { + let offset = i * F64_LANES; + let g = vld1q_f64(grad_base.add(offset)); + let w = vld1q_f64(weight.add(offset)); + let pn = vld1q_f64(pn_base.add(offset)); + let gw = vmulq_f64(g, w); + dot_acc = vfmaq_f64(dot_acc, gw, pn); + } + let mut dot = hsum_f64(dot_acc); + + for i in 0..remainder { + let offset = chunks * F64_LANES + i; + let g = *grad_base.add(offset); + let w = *weight.add(offset); + let pn = *pn_base.add(offset); + dot += g * w * pn; + } + + let coeff = dot * inv_rms / (hidden_size as f64 * (mean_sq + eps)); + let v_inv_rms = vdupq_n_f64(inv_rms); + let v_coeff = vdupq_n_f64(coeff); + + for i in 0..chunks { + let offset = i * F64_LANES; + let g = vld1q_f64(grad_base.add(offset)); + let w = vld1q_f64(weight.add(offset)); + let pn = vld1q_f64(pn_base.add(offset)); + + let gw = vmulq_f64(g, w); + let pn_coeff = vmulq_f64(pn, v_coeff); + let diff = vsubq_f64(gw, pn_coeff); + let d_ir = vmulq_f64(diff, v_inv_rms); + vst1q_f64(d_ir_base.add(offset), d_ir); + + let dw_old = vld1q_f64(d_weight.add(offset)); + let gp = vmulq_f64(g, pn); + let gp_inv = vmulq_f64(gp, v_inv_rms); + let dw_new = vaddq_f64(dw_old, gp_inv); + vst1q_f64(d_weight.add(offset), dw_new); + } + + for i in 0..remainder { + let offset = chunks * F64_LANES + i; + let g = *grad_base.add(offset); + let w = *weight.add(offset); + let pn = *pn_base.add(offset); + + let d_ir = (g * w - pn * coeff) * inv_rms; + *d_ir_base.add(offset) = d_ir; + + let d_w = g * pn * inv_rms; + *d_weight.add(offset) += d_w; + } + } +} diff --git a/src/runtime/cpu/kernels/simd/norm/aarch64/neon.rs b/src/runtime/cpu/kernels/simd/norm/aarch64/neon/layer_norm.rs similarity index 54% rename from src/runtime/cpu/kernels/simd/norm/aarch64/neon.rs rename to src/runtime/cpu/kernels/simd/norm/aarch64/neon/layer_norm.rs index d7b59168..3af53048 100644 --- a/src/runtime/cpu/kernels/simd/norm/aarch64/neon.rs +++ b/src/runtime/cpu/kernels/simd/norm/aarch64/neon/layer_norm.rs @@ -1,140 +1,10 @@ -//! NEON normalization kernels for ARM64 -//! -//! Provides vectorized RMS normalization and Layer normalization using 128-bit NEON registers. -//! -//! # RMS Normalization -//! output = input * rsqrt(mean(input^2) + eps) * weight -//! -//! # Layer Normalization -//! output = (input - mean) * rsqrt(var + eps) * weight + bias -//! -//! # SIMD Strategy -//! -//! 1. SIMD sum of squares (FMA: acc += x * x) -//! 2. Horizontal reduction for sum -//! 3. Compute inverse RMS/std -//! 4. SIMD multiply for normalization and weight +//! NEON layer normalization kernels #[cfg(target_arch = "aarch64")] use std::arch::aarch64::*; -use super::super::super::math::aarch64::neon::{hsum_f32, hsum_f64}; - -const F32_LANES: usize = 4; -const F64_LANES: usize = 2; - -/// NEON RMS normalization for f32 -/// -/// # Safety -/// - CPU must support NEON (always true on AArch64) -/// - `input` and `out` must point to `batch_size * hidden_size` valid f32 elements -/// - `weight` must point to `hidden_size` valid f32 elements -#[cfg(target_arch = "aarch64")] -#[target_feature(enable = "neon")] -pub unsafe fn rms_norm_f32( - input: *const f32, - weight: *const f32, - out: *mut f32, - batch_size: usize, - hidden_size: usize, - eps: f32, -) { - let chunks = hidden_size / F32_LANES; - let remainder = hidden_size % F32_LANES; - - for b in 0..batch_size { - let base = input.add(b * hidden_size); - let out_base = out.add(b * hidden_size); - - // Phase 1: Sum of squares using FMA - let mut ss_acc = vdupq_n_f32(0.0); - for i in 0..chunks { - let v = vld1q_f32(base.add(i * F32_LANES)); - ss_acc = vfmaq_f32(ss_acc, v, v); // FMA: acc += v * v - } - let mut sum_sq = hsum_f32(ss_acc); - - // Scalar tail for sum of squares - for i in 0..remainder { - let v = *base.add(chunks * F32_LANES + i); - sum_sq += v * v; - } - - // Compute inverse RMS: 1 / sqrt(mean_sq + eps) - let inv_rms = 1.0 / (sum_sq / hidden_size as f32 + eps).sqrt(); - let v_inv_rms = vdupq_n_f32(inv_rms); - - // Phase 2: Apply normalization and weight - for i in 0..chunks { - let offset = i * F32_LANES; - let v_in = vld1q_f32(base.add(offset)); - let v_w = vld1q_f32(weight.add(offset)); - let result = vmulq_f32(vmulq_f32(v_in, v_inv_rms), v_w); - vst1q_f32(out_base.add(offset), result); - } - - // Scalar tail for normalization - for i in 0..remainder { - let offset = chunks * F32_LANES + i; - *out_base.add(offset) = *base.add(offset) * inv_rms * *weight.add(offset); - } - } -} - -/// NEON RMS normalization for f64 -/// -/// # Safety -/// - CPU must support NEON (always true on AArch64) -/// - `input` and `out` must point to `batch_size * hidden_size` valid f64 elements -/// - `weight` must point to `hidden_size` valid f64 elements -#[cfg(target_arch = "aarch64")] -#[target_feature(enable = "neon")] -pub unsafe fn rms_norm_f64( - input: *const f64, - weight: *const f64, - out: *mut f64, - batch_size: usize, - hidden_size: usize, - eps: f64, -) { - let chunks = hidden_size / F64_LANES; - let remainder = hidden_size % F64_LANES; - - for b in 0..batch_size { - let base = input.add(b * hidden_size); - let out_base = out.add(b * hidden_size); - - // Phase 1: Sum of squares - let mut ss_acc = vdupq_n_f64(0.0); - for i in 0..chunks { - let v = vld1q_f64(base.add(i * F64_LANES)); - ss_acc = vfmaq_f64(ss_acc, v, v); - } - let mut sum_sq = hsum_f64(ss_acc); - - for i in 0..remainder { - let v = *base.add(chunks * F64_LANES + i); - sum_sq += v * v; - } - - let inv_rms = 1.0 / (sum_sq / hidden_size as f64 + eps).sqrt(); - let v_inv_rms = vdupq_n_f64(inv_rms); - - // Phase 2: Apply normalization and weight - for i in 0..chunks { - let offset = i * F64_LANES; - let v_in = vld1q_f64(base.add(offset)); - let v_w = vld1q_f64(weight.add(offset)); - let result = vmulq_f64(vmulq_f64(v_in, v_inv_rms), v_w); - vst1q_f64(out_base.add(offset), result); - } - - for i in 0..remainder { - let offset = chunks * F64_LANES + i; - *out_base.add(offset) = *base.add(offset) * inv_rms * *weight.add(offset); - } - } -} +use super::super::super::super::math::aarch64::neon::{hsum_f32, hsum_f64}; +use super::{F32_LANES, F64_LANES}; /// NEON Layer normalization for f32 /// diff --git a/src/runtime/cpu/kernels/simd/norm/aarch64/neon/mod.rs b/src/runtime/cpu/kernels/simd/norm/aarch64/neon/mod.rs new file mode 100644 index 00000000..20597ede --- /dev/null +++ b/src/runtime/cpu/kernels/simd/norm/aarch64/neon/mod.rs @@ -0,0 +1,22 @@ +//! NEON normalization kernels for ARM64 +//! +//! Provides vectorized RMS normalization and Layer normalization using 128-bit NEON registers. + +pub(super) const F32_LANES: usize = 4; +pub(super) const F64_LANES: usize = 2; + +mod fused_add_layer_norm; +mod fused_add_rms_norm; +mod layer_norm; +mod rms_norm; + +pub use fused_add_layer_norm::{ + fused_add_layer_norm_bwd_f32, fused_add_layer_norm_bwd_f64, fused_add_layer_norm_f32, + fused_add_layer_norm_f64, +}; +pub use fused_add_rms_norm::{ + fused_add_rms_norm_bwd_f32, fused_add_rms_norm_bwd_f64, fused_add_rms_norm_f32, + fused_add_rms_norm_f64, +}; +pub use layer_norm::{layer_norm_f32, layer_norm_f64}; +pub use rms_norm::{rms_norm_f32, rms_norm_f64}; diff --git a/src/runtime/cpu/kernels/simd/norm/aarch64/neon/rms_norm.rs b/src/runtime/cpu/kernels/simd/norm/aarch64/neon/rms_norm.rs new file mode 100644 index 00000000..c4179d9f --- /dev/null +++ b/src/runtime/cpu/kernels/simd/norm/aarch64/neon/rms_norm.rs @@ -0,0 +1,120 @@ +//! NEON RMS normalization kernels + +#[cfg(target_arch = "aarch64")] +use std::arch::aarch64::*; + +use super::super::super::super::math::aarch64::neon::{hsum_f32, hsum_f64}; +use super::{F32_LANES, F64_LANES}; + +/// NEON RMS normalization for f32 +/// +/// # Safety +/// - CPU must support NEON (always true on AArch64) +/// - `input` and `out` must point to `batch_size * hidden_size` valid f32 elements +/// - `weight` must point to `hidden_size` valid f32 elements +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +pub unsafe fn rms_norm_f32( + input: *const f32, + weight: *const f32, + out: *mut f32, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let chunks = hidden_size / F32_LANES; + let remainder = hidden_size % F32_LANES; + + for b in 0..batch_size { + let base = input.add(b * hidden_size); + let out_base = out.add(b * hidden_size); + + // Phase 1: Sum of squares using FMA + let mut ss_acc = vdupq_n_f32(0.0); + for i in 0..chunks { + let v = vld1q_f32(base.add(i * F32_LANES)); + ss_acc = vfmaq_f32(ss_acc, v, v); + } + let mut sum_sq = hsum_f32(ss_acc) as f64; + + // Scalar tail for sum of squares + for i in 0..remainder { + let v = *base.add(chunks * F32_LANES + i) as f64; + sum_sq += v * v; + } + + // Compute inverse RMS in f64 for precision (matches llama.cpp) + let inv_rms = (1.0f64 / (sum_sq / hidden_size as f64 + eps as f64).sqrt()) as f32; + let v_inv_rms = vdupq_n_f32(inv_rms); + + // Phase 2: Apply normalization and weight + for i in 0..chunks { + let offset = i * F32_LANES; + let v_in = vld1q_f32(base.add(offset)); + let v_w = vld1q_f32(weight.add(offset)); + let result = vmulq_f32(vmulq_f32(v_in, v_inv_rms), v_w); + vst1q_f32(out_base.add(offset), result); + } + + // Scalar tail for normalization + for i in 0..remainder { + let offset = chunks * F32_LANES + i; + *out_base.add(offset) = *base.add(offset) * inv_rms * *weight.add(offset); + } + } +} + +/// NEON RMS normalization for f64 +/// +/// # Safety +/// - CPU must support NEON (always true on AArch64) +/// - `input` and `out` must point to `batch_size * hidden_size` valid f64 elements +/// - `weight` must point to `hidden_size` valid f64 elements +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +pub unsafe fn rms_norm_f64( + input: *const f64, + weight: *const f64, + out: *mut f64, + batch_size: usize, + hidden_size: usize, + eps: f64, +) { + let chunks = hidden_size / F64_LANES; + let remainder = hidden_size % F64_LANES; + + for b in 0..batch_size { + let base = input.add(b * hidden_size); + let out_base = out.add(b * hidden_size); + + // Phase 1: Sum of squares + let mut ss_acc = vdupq_n_f64(0.0); + for i in 0..chunks { + let v = vld1q_f64(base.add(i * F64_LANES)); + ss_acc = vfmaq_f64(ss_acc, v, v); + } + let mut sum_sq = hsum_f64(ss_acc); + + for i in 0..remainder { + let v = *base.add(chunks * F64_LANES + i); + sum_sq += v * v; + } + + let inv_rms = 1.0 / (sum_sq / hidden_size as f64 + eps).sqrt(); + let v_inv_rms = vdupq_n_f64(inv_rms); + + // Phase 2: Apply normalization and weight + for i in 0..chunks { + let offset = i * F64_LANES; + let v_in = vld1q_f64(base.add(offset)); + let v_w = vld1q_f64(weight.add(offset)); + let result = vmulq_f64(vmulq_f64(v_in, v_inv_rms), v_w); + vst1q_f64(out_base.add(offset), result); + } + + for i in 0..remainder { + let offset = chunks * F64_LANES + i; + *out_base.add(offset) = *base.add(offset) * inv_rms * *weight.add(offset); + } + } +} diff --git a/src/runtime/cpu/kernels/simd/norm/avx2.rs b/src/runtime/cpu/kernels/simd/norm/avx2.rs deleted file mode 100644 index 7812166b..00000000 --- a/src/runtime/cpu/kernels/simd/norm/avx2.rs +++ /dev/null @@ -1,295 +0,0 @@ -//! AVX2 normalization kernels -//! -//! SIMD-optimized RMS norm and layer norm with manual horizontal reductions. - -#[cfg(target_arch = "x86_64")] -use std::arch::x86_64::*; - -use super::{ - layer_norm_scalar_f32, layer_norm_scalar_f64, rms_norm_scalar_f32, rms_norm_scalar_f64, -}; - -const F32_LANES: usize = 8; -const F64_LANES: usize = 4; - -// ============================================================================ -// Horizontal reduction helpers -// ============================================================================ - -#[target_feature(enable = "avx2", enable = "fma")] -#[inline] -unsafe fn hsum_f32(v: __m256) -> f32 { - let high = _mm256_extractf128_ps(v, 1); - let low = _mm256_castps256_ps128(v); - let sum128 = _mm_add_ps(low, high); - let shuf = _mm_movehdup_ps(sum128); - let sum64 = _mm_add_ps(sum128, shuf); - let shuf2 = _mm_movehl_ps(sum64, sum64); - let sum32 = _mm_add_ss(sum64, shuf2); - _mm_cvtss_f32(sum32) -} - -#[target_feature(enable = "avx2", enable = "fma")] -#[inline] -unsafe fn hsum_f64(v: __m256d) -> f64 { - let high = _mm256_extractf128_pd(v, 1); - let low = _mm256_castpd256_pd128(v); - let sum128 = _mm_add_pd(low, high); - let shuf = _mm_unpackhi_pd(sum128, sum128); - let sum64 = _mm_add_sd(sum128, shuf); - _mm_cvtsd_f64(sum64) -} - -// ============================================================================ -// RMS Norm -// ============================================================================ - -/// AVX2 RMS normalization for f32 -#[target_feature(enable = "avx2", enable = "fma")] -pub unsafe fn rms_norm_f32( - input: *const f32, - weight: *const f32, - out: *mut f32, - batch_size: usize, - hidden_size: usize, - eps: f32, -) { - let chunks = hidden_size / F32_LANES; - - for batch in 0..batch_size { - let row_start = batch * hidden_size; - - // SIMD sum of squares using FMA - let mut acc = _mm256_setzero_ps(); - for c in 0..chunks { - let offset = row_start + c * F32_LANES; - let v = _mm256_loadu_ps(input.add(offset)); - acc = _mm256_fmadd_ps(v, v, acc); - } - let mut sum_sq = hsum_f32(acc); - - for i in (chunks * F32_LANES)..hidden_size { - let x = *input.add(row_start + i); - sum_sq += x * x; - } - - let inv_rms = 1.0 / (sum_sq / hidden_size as f32 + eps).sqrt(); - let v_inv_rms = _mm256_set1_ps(inv_rms); - - for c in 0..chunks { - let offset = row_start + c * F32_LANES; - let w_offset = c * F32_LANES; - let v_input = _mm256_loadu_ps(input.add(offset)); - let v_weight = _mm256_loadu_ps(weight.add(w_offset)); - let v_result = _mm256_mul_ps(_mm256_mul_ps(v_input, v_inv_rms), v_weight); - _mm256_storeu_ps(out.add(offset), v_result); - } - - for i in (chunks * F32_LANES)..hidden_size { - let x = *input.add(row_start + i); - let w = *weight.add(i); - *out.add(row_start + i) = x * inv_rms * w; - } - } -} - -/// AVX2 RMS normalization for f64 -#[target_feature(enable = "avx2", enable = "fma")] -pub unsafe fn rms_norm_f64( - input: *const f64, - weight: *const f64, - out: *mut f64, - batch_size: usize, - hidden_size: usize, - eps: f64, -) { - let chunks = hidden_size / F64_LANES; - - for batch in 0..batch_size { - let row_start = batch * hidden_size; - - let mut acc = _mm256_setzero_pd(); - for c in 0..chunks { - let offset = row_start + c * F64_LANES; - let v = _mm256_loadu_pd(input.add(offset)); - acc = _mm256_fmadd_pd(v, v, acc); - } - let mut sum_sq = hsum_f64(acc); - - for i in (chunks * F64_LANES)..hidden_size { - let x = *input.add(row_start + i); - sum_sq += x * x; - } - - let inv_rms = 1.0 / (sum_sq / hidden_size as f64 + eps).sqrt(); - let v_inv_rms = _mm256_set1_pd(inv_rms); - - for c in 0..chunks { - let offset = row_start + c * F64_LANES; - let w_offset = c * F64_LANES; - let v_input = _mm256_loadu_pd(input.add(offset)); - let v_weight = _mm256_loadu_pd(weight.add(w_offset)); - let v_result = _mm256_mul_pd(_mm256_mul_pd(v_input, v_inv_rms), v_weight); - _mm256_storeu_pd(out.add(offset), v_result); - } - - for i in (chunks * F64_LANES)..hidden_size { - let x = *input.add(row_start + i); - let w = *weight.add(i); - *out.add(row_start + i) = x * inv_rms * w; - } - } -} - -// ============================================================================ -// Layer Norm -// ============================================================================ - -/// AVX2 Layer normalization for f32 -#[target_feature(enable = "avx2", enable = "fma")] -pub unsafe fn layer_norm_f32( - input: *const f32, - weight: *const f32, - bias: *const f32, - out: *mut f32, - batch_size: usize, - hidden_size: usize, - eps: f32, -) { - let chunks = hidden_size / F32_LANES; - - for batch in 0..batch_size { - let row_start = batch * hidden_size; - - // SIMD sum for mean - let mut sum_acc = _mm256_setzero_ps(); - for c in 0..chunks { - let v = _mm256_loadu_ps(input.add(row_start + c * F32_LANES)); - sum_acc = _mm256_add_ps(sum_acc, v); - } - let mut sum = hsum_f32(sum_acc); - - for i in (chunks * F32_LANES)..hidden_size { - sum += *input.add(row_start + i); - } - let mean = sum / hidden_size as f32; - let v_mean = _mm256_set1_ps(mean); - - // SIMD variance computation - let mut var_acc = _mm256_setzero_ps(); - for c in 0..chunks { - let v = _mm256_loadu_ps(input.add(row_start + c * F32_LANES)); - let diff = _mm256_sub_ps(v, v_mean); - var_acc = _mm256_fmadd_ps(diff, diff, var_acc); - } - let mut var_sum = hsum_f32(var_acc); - - for i in (chunks * F32_LANES)..hidden_size { - let diff = *input.add(row_start + i) - mean; - var_sum += diff * diff; - } - let inv_std = 1.0 / (var_sum / hidden_size as f32 + eps).sqrt(); - let v_inv_std = _mm256_set1_ps(inv_std); - - // SIMD normalization with weight and bias - for c in 0..chunks { - let offset = row_start + c * F32_LANES; - let w_offset = c * F32_LANES; - let v_input = _mm256_loadu_ps(input.add(offset)); - let v_weight = _mm256_loadu_ps(weight.add(w_offset)); - let v_bias = _mm256_loadu_ps(bias.add(w_offset)); - - let diff = _mm256_sub_ps(v_input, v_mean); - let normalized = _mm256_mul_ps(diff, v_inv_std); - let scaled = _mm256_mul_ps(normalized, v_weight); - let result = _mm256_add_ps(scaled, v_bias); - - _mm256_storeu_ps(out.add(offset), result); - } - - for i in (chunks * F32_LANES)..hidden_size { - let x = *input.add(row_start + i); - let w = *weight.add(i); - let b = *bias.add(i); - *out.add(row_start + i) = (x - mean) * inv_std * w + b; - } - } -} - -/// AVX2 Layer normalization for f64 -#[target_feature(enable = "avx2", enable = "fma")] -pub unsafe fn layer_norm_f64( - input: *const f64, - weight: *const f64, - bias: *const f64, - out: *mut f64, - batch_size: usize, - hidden_size: usize, - eps: f64, -) { - let chunks = hidden_size / F64_LANES; - - for batch in 0..batch_size { - let row_start = batch * hidden_size; - - let mut sum_acc = _mm256_setzero_pd(); - for c in 0..chunks { - let v = _mm256_loadu_pd(input.add(row_start + c * F64_LANES)); - sum_acc = _mm256_add_pd(sum_acc, v); - } - let mut sum = hsum_f64(sum_acc); - - for i in (chunks * F64_LANES)..hidden_size { - sum += *input.add(row_start + i); - } - let mean = sum / hidden_size as f64; - let v_mean = _mm256_set1_pd(mean); - - let mut var_acc = _mm256_setzero_pd(); - for c in 0..chunks { - let v = _mm256_loadu_pd(input.add(row_start + c * F64_LANES)); - let diff = _mm256_sub_pd(v, v_mean); - var_acc = _mm256_fmadd_pd(diff, diff, var_acc); - } - let mut var_sum = hsum_f64(var_acc); - - for i in (chunks * F64_LANES)..hidden_size { - let diff = *input.add(row_start + i) - mean; - var_sum += diff * diff; - } - let inv_std = 1.0 / (var_sum / hidden_size as f64 + eps).sqrt(); - let v_inv_std = _mm256_set1_pd(inv_std); - - for c in 0..chunks { - let offset = row_start + c * F64_LANES; - let w_offset = c * F64_LANES; - let v_input = _mm256_loadu_pd(input.add(offset)); - let v_weight = _mm256_loadu_pd(weight.add(w_offset)); - let v_bias = _mm256_loadu_pd(bias.add(w_offset)); - - let diff = _mm256_sub_pd(v_input, v_mean); - let normalized = _mm256_mul_pd(diff, v_inv_std); - let scaled = _mm256_mul_pd(normalized, v_weight); - let result = _mm256_add_pd(scaled, v_bias); - - _mm256_storeu_pd(out.add(offset), result); - } - - for i in (chunks * F64_LANES)..hidden_size { - let x = *input.add(row_start + i); - let w = *weight.add(i); - let b = *bias.add(i); - *out.add(row_start + i) = (x - mean) * inv_std * w + b; - } - } -} - -// Suppress unused warnings for scalar fallback imports used in dispatch -const _: () = { - let _ = rms_norm_scalar_f32 as unsafe fn(*const f32, *const f32, *mut f32, usize, usize, f32); - let _ = rms_norm_scalar_f64 as unsafe fn(*const f64, *const f64, *mut f64, usize, usize, f64); - let _ = layer_norm_scalar_f32 - as unsafe fn(*const f32, *const f32, *const f32, *mut f32, usize, usize, f32); - let _ = layer_norm_scalar_f64 - as unsafe fn(*const f64, *const f64, *const f64, *mut f64, usize, usize, f64); -}; diff --git a/src/runtime/cpu/kernels/simd/norm/avx2/fused_add_layer_norm.rs b/src/runtime/cpu/kernels/simd/norm/avx2/fused_add_layer_norm.rs new file mode 100644 index 00000000..594f1566 --- /dev/null +++ b/src/runtime/cpu/kernels/simd/norm/avx2/fused_add_layer_norm.rs @@ -0,0 +1,516 @@ +//! AVX2 fused add + layer normalization kernels (forward and backward) + +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +use super::{F32_LANES, F64_LANES, hsum_f32, hsum_f64}; + +/// AVX2 Fused Add + Layer Normalization for f32 +/// +/// Computes: output = (input + residual - mean) / sqrt(var + eps) * weight + bias +/// Stores intermediate (input + residual) in pre_norm for backward pass. +#[target_feature(enable = "avx2", enable = "fma")] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_layer_norm_f32( + input: *const f32, + residual: *const f32, + weight: *const f32, + bias: *const f32, + out: *mut f32, + pre_norm: *mut f32, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let chunks = hidden_size / F32_LANES; + + for batch in 0..batch_size { + let row_start = batch * hidden_size; + + // Phase 1: Add and store in pre_norm, compute mean + let mut sum_acc = _mm256_setzero_ps(); + for c in 0..chunks { + let offset = row_start + c * F32_LANES; + let v_in = _mm256_loadu_ps(input.add(offset)); + let v_res = _mm256_loadu_ps(residual.add(offset)); + let pn = _mm256_add_ps(v_in, v_res); + _mm256_storeu_ps(pre_norm.add(offset), pn); + sum_acc = _mm256_add_ps(sum_acc, pn); + } + let mut sum = hsum_f32(sum_acc); + + for i in (chunks * F32_LANES)..hidden_size { + let pn = *input.add(row_start + i) + *residual.add(row_start + i); + *pre_norm.add(row_start + i) = pn; + sum += pn; + } + + let mean = sum / hidden_size as f32; + let v_mean = _mm256_set1_ps(mean); + + // Phase 2: Compute variance (dual accumulators) + let mut var_acc0 = _mm256_setzero_ps(); + let mut var_acc1 = _mm256_setzero_ps(); + let mut c = 0; + let chunk_pairs = chunks / 2 * 2; + while c < chunk_pairs { + let diff0 = _mm256_sub_ps( + _mm256_loadu_ps(pre_norm.add(row_start + c * F32_LANES)), + v_mean, + ); + var_acc0 = _mm256_fmadd_ps(diff0, diff0, var_acc0); + let diff1 = _mm256_sub_ps( + _mm256_loadu_ps(pre_norm.add(row_start + (c + 1) * F32_LANES)), + v_mean, + ); + var_acc1 = _mm256_fmadd_ps(diff1, diff1, var_acc1); + c += 2; + } + while c < chunks { + let diff = _mm256_sub_ps( + _mm256_loadu_ps(pre_norm.add(row_start + c * F32_LANES)), + v_mean, + ); + var_acc0 = _mm256_fmadd_ps(diff, diff, var_acc0); + c += 1; + } + let mut var_sum = hsum_f32(_mm256_add_ps(var_acc0, var_acc1)); + + for i in (chunks * F32_LANES)..hidden_size { + let diff = *pre_norm.add(row_start + i) - mean; + var_sum += diff * diff; + } + + let inv_std = 1.0 / (var_sum / hidden_size as f32 + eps).sqrt(); + let v_inv_std = _mm256_set1_ps(inv_std); + + // Phase 3: Normalize, apply weight and bias + for c in 0..chunks { + let offset = row_start + c * F32_LANES; + let w_offset = c * F32_LANES; + let pn = _mm256_loadu_ps(pre_norm.add(offset)); + let v_weight = _mm256_loadu_ps(weight.add(w_offset)); + let v_bias = _mm256_loadu_ps(bias.add(w_offset)); + + let diff = _mm256_sub_ps(pn, v_mean); + let normalized = _mm256_mul_ps(diff, v_inv_std); + let scaled = _mm256_mul_ps(normalized, v_weight); + let result = _mm256_add_ps(scaled, v_bias); + + _mm256_storeu_ps(out.add(offset), result); + } + + for i in (chunks * F32_LANES)..hidden_size { + let pn = *pre_norm.add(row_start + i); + let w = *weight.add(i); + let b = *bias.add(i); + *out.add(row_start + i) = (pn - mean) * inv_std * w + b; + } + } +} + +/// AVX2 Fused Add + Layer Normalization for f64 +#[target_feature(enable = "avx2", enable = "fma")] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_layer_norm_f64( + input: *const f64, + residual: *const f64, + weight: *const f64, + bias: *const f64, + out: *mut f64, + pre_norm: *mut f64, + batch_size: usize, + hidden_size: usize, + eps: f64, +) { + let chunks = hidden_size / F64_LANES; + + for batch in 0..batch_size { + let row_start = batch * hidden_size; + + let mut sum_acc = _mm256_setzero_pd(); + for c in 0..chunks { + let offset = row_start + c * F64_LANES; + let v_in = _mm256_loadu_pd(input.add(offset)); + let v_res = _mm256_loadu_pd(residual.add(offset)); + let pn = _mm256_add_pd(v_in, v_res); + _mm256_storeu_pd(pre_norm.add(offset), pn); + sum_acc = _mm256_add_pd(sum_acc, pn); + } + let mut sum = hsum_f64(sum_acc); + + for i in (chunks * F64_LANES)..hidden_size { + let pn = *input.add(row_start + i) + *residual.add(row_start + i); + *pre_norm.add(row_start + i) = pn; + sum += pn; + } + + let mean = sum / hidden_size as f64; + let v_mean = _mm256_set1_pd(mean); + + let mut var_acc0 = _mm256_setzero_pd(); + let mut var_acc1 = _mm256_setzero_pd(); + let mut c = 0; + let chunk_pairs_v = chunks / 2 * 2; + while c < chunk_pairs_v { + let diff0 = _mm256_sub_pd( + _mm256_loadu_pd(pre_norm.add(row_start + c * F64_LANES)), + v_mean, + ); + var_acc0 = _mm256_fmadd_pd(diff0, diff0, var_acc0); + let diff1 = _mm256_sub_pd( + _mm256_loadu_pd(pre_norm.add(row_start + (c + 1) * F64_LANES)), + v_mean, + ); + var_acc1 = _mm256_fmadd_pd(diff1, diff1, var_acc1); + c += 2; + } + while c < chunks { + let diff = _mm256_sub_pd( + _mm256_loadu_pd(pre_norm.add(row_start + c * F64_LANES)), + v_mean, + ); + var_acc0 = _mm256_fmadd_pd(diff, diff, var_acc0); + c += 1; + } + let mut var_sum = hsum_f64(_mm256_add_pd(var_acc0, var_acc1)); + + for i in (chunks * F64_LANES)..hidden_size { + let diff = *pre_norm.add(row_start + i) - mean; + var_sum += diff * diff; + } + + let inv_std = 1.0 / (var_sum / hidden_size as f64 + eps).sqrt(); + let v_inv_std = _mm256_set1_pd(inv_std); + + for c in 0..chunks { + let offset = row_start + c * F64_LANES; + let w_offset = c * F64_LANES; + let pn = _mm256_loadu_pd(pre_norm.add(offset)); + let v_weight = _mm256_loadu_pd(weight.add(w_offset)); + let v_bias = _mm256_loadu_pd(bias.add(w_offset)); + + let diff = _mm256_sub_pd(pn, v_mean); + let normalized = _mm256_mul_pd(diff, v_inv_std); + let scaled = _mm256_mul_pd(normalized, v_weight); + let result = _mm256_add_pd(scaled, v_bias); + + _mm256_storeu_pd(out.add(offset), result); + } + + for i in (chunks * F64_LANES)..hidden_size { + let pn = *pre_norm.add(row_start + i); + let w = *weight.add(i); + let b = *bias.add(i); + *out.add(row_start + i) = (pn - mean) * inv_std * w + b; + } + } +} + +/// AVX2 Fused Add + Layer Norm Backward for f32 +/// +/// Computes gradients for backward pass of layer norm +#[target_feature(enable = "avx2", enable = "fma")] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_layer_norm_bwd_f32( + grad: *const f32, + pre_norm: *const f32, + weight: *const f32, + d_input_residual: *mut f32, + d_weight: *mut f32, + d_bias: *mut f32, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let chunks = hidden_size / F32_LANES; + + for batch in 0..batch_size { + let row_start = batch * hidden_size; + + // Recompute mean from pre_norm + let mut sum_acc = _mm256_setzero_ps(); + for c in 0..chunks { + let offset = row_start + c * F32_LANES; + let pn = _mm256_loadu_ps(pre_norm.add(offset)); + sum_acc = _mm256_add_ps(sum_acc, pn); + } + let mut sum = hsum_f32(sum_acc); + + for i in (chunks * F32_LANES)..hidden_size { + sum += *pre_norm.add(row_start + i); + } + + let mean = sum / hidden_size as f32; + let v_mean = _mm256_set1_ps(mean); + + // Recompute variance (dual accumulators) + let mut var_acc0 = _mm256_setzero_ps(); + let mut var_acc1 = _mm256_setzero_ps(); + let mut c = 0; + let chunk_pairs_v = chunks / 2 * 2; + while c < chunk_pairs_v { + let diff0 = _mm256_sub_ps( + _mm256_loadu_ps(pre_norm.add(row_start + c * F32_LANES)), + v_mean, + ); + var_acc0 = _mm256_fmadd_ps(diff0, diff0, var_acc0); + let diff1 = _mm256_sub_ps( + _mm256_loadu_ps(pre_norm.add(row_start + (c + 1) * F32_LANES)), + v_mean, + ); + var_acc1 = _mm256_fmadd_ps(diff1, diff1, var_acc1); + c += 2; + } + while c < chunks { + let diff = _mm256_sub_ps( + _mm256_loadu_ps(pre_norm.add(row_start + c * F32_LANES)), + v_mean, + ); + var_acc0 = _mm256_fmadd_ps(diff, diff, var_acc0); + c += 1; + } + let mut var_sum = hsum_f32(_mm256_add_ps(var_acc0, var_acc1)); + + for i in (chunks * F32_LANES)..hidden_size { + let diff = *pre_norm.add(row_start + i) - mean; + var_sum += diff * diff; + } + + let inv_std = 1.0 / (var_sum / hidden_size as f32 + eps).sqrt(); + + // Compute mean_gs = mean(grad * weight) and mean_gs_n = mean(grad * weight * normalized) + let mut gs_acc = _mm256_setzero_ps(); + let mut gsn_acc = _mm256_setzero_ps(); + for c in 0..chunks { + let offset = row_start + c * F32_LANES; + let w_offset = c * F32_LANES; + let g = _mm256_loadu_ps(grad.add(offset)); + let w = _mm256_loadu_ps(weight.add(w_offset)); + let pn = _mm256_loadu_ps(pre_norm.add(offset)); + + let gs = _mm256_mul_ps(g, w); + gs_acc = _mm256_add_ps(gs_acc, gs); + + let diff = _mm256_sub_ps(pn, v_mean); + let normalized = _mm256_mul_ps(diff, _mm256_set1_ps(inv_std)); + let gsn = _mm256_mul_ps(gs, normalized); + gsn_acc = _mm256_add_ps(gsn_acc, gsn); + } + let mut mean_gs_simd = hsum_f32(gs_acc); + let mut mean_gsn_simd = hsum_f32(gsn_acc); + + for i in (chunks * F32_LANES)..hidden_size { + let g = *grad.add(row_start + i); + let w = *weight.add(i); + let pn = *pre_norm.add(row_start + i); + + let gs = g * w; + mean_gs_simd += gs; + + let normalized = (pn - mean) * inv_std; + mean_gsn_simd += gs * normalized; + } + + let mean_gs = mean_gs_simd / hidden_size as f32; + let mean_gs_n = mean_gsn_simd / hidden_size as f32; + let v_inv_std = _mm256_set1_ps(inv_std); + let v_mean_gs = _mm256_set1_ps(mean_gs); + let v_mean_gs_n = _mm256_set1_ps(mean_gs_n); + + // Apply and accumulate + for c in 0..chunks { + let offset = row_start + c * F32_LANES; + let w_offset = c * F32_LANES; + let g = _mm256_loadu_ps(grad.add(offset)); + let w = _mm256_loadu_ps(weight.add(w_offset)); + let pn = _mm256_loadu_ps(pre_norm.add(offset)); + + let normalized = _mm256_mul_ps(_mm256_sub_ps(pn, v_mean), v_inv_std); + let gs = _mm256_mul_ps(g, w); + let d_ir = _mm256_mul_ps( + v_inv_std, + _mm256_sub_ps( + gs, + _mm256_add_ps(v_mean_gs, _mm256_mul_ps(normalized, v_mean_gs_n)), + ), + ); + _mm256_storeu_ps(d_input_residual.add(offset), d_ir); + + // d_weight += g * normalized + let dw_old = _mm256_loadu_ps(d_weight.add(w_offset)); + let dw_add = _mm256_mul_ps(g, normalized); + let dw_new = _mm256_add_ps(dw_old, dw_add); + _mm256_storeu_ps(d_weight.add(w_offset), dw_new); + + // d_bias += g + let db_old = _mm256_loadu_ps(d_bias.add(w_offset)); + let db_new = _mm256_add_ps(db_old, g); + _mm256_storeu_ps(d_bias.add(w_offset), db_new); + } + + for i in (chunks * F32_LANES)..hidden_size { + let g = *grad.add(row_start + i); + let w = *weight.add(i); + let pn = *pre_norm.add(row_start + i); + + let normalized = (pn - mean) * inv_std; + let gs = g * w; + let d_ir = inv_std * (gs - mean_gs - normalized * mean_gs_n); + *d_input_residual.add(row_start + i) = d_ir; + + *d_weight.add(i) += g * normalized; + *d_bias.add(i) += g; + } + } +} + +/// AVX2 Fused Add + Layer Norm Backward for f64 +#[target_feature(enable = "avx2", enable = "fma")] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_layer_norm_bwd_f64( + grad: *const f64, + pre_norm: *const f64, + weight: *const f64, + d_input_residual: *mut f64, + d_weight: *mut f64, + d_bias: *mut f64, + batch_size: usize, + hidden_size: usize, + eps: f64, +) { + let chunks = hidden_size / F64_LANES; + + for batch in 0..batch_size { + let row_start = batch * hidden_size; + + let mut sum_acc = _mm256_setzero_pd(); + for c in 0..chunks { + let offset = row_start + c * F64_LANES; + let pn = _mm256_loadu_pd(pre_norm.add(offset)); + sum_acc = _mm256_add_pd(sum_acc, pn); + } + let mut sum = hsum_f64(sum_acc); + + for i in (chunks * F64_LANES)..hidden_size { + sum += *pre_norm.add(row_start + i); + } + + let mean = sum / hidden_size as f64; + let v_mean = _mm256_set1_pd(mean); + + let mut var_acc0 = _mm256_setzero_pd(); + let mut var_acc1 = _mm256_setzero_pd(); + let mut c = 0; + let chunk_pairs_v = chunks / 2 * 2; + while c < chunk_pairs_v { + let diff0 = _mm256_sub_pd( + _mm256_loadu_pd(pre_norm.add(row_start + c * F64_LANES)), + v_mean, + ); + var_acc0 = _mm256_fmadd_pd(diff0, diff0, var_acc0); + let diff1 = _mm256_sub_pd( + _mm256_loadu_pd(pre_norm.add(row_start + (c + 1) * F64_LANES)), + v_mean, + ); + var_acc1 = _mm256_fmadd_pd(diff1, diff1, var_acc1); + c += 2; + } + while c < chunks { + let diff = _mm256_sub_pd( + _mm256_loadu_pd(pre_norm.add(row_start + c * F64_LANES)), + v_mean, + ); + var_acc0 = _mm256_fmadd_pd(diff, diff, var_acc0); + c += 1; + } + let mut var_sum = hsum_f64(_mm256_add_pd(var_acc0, var_acc1)); + + for i in (chunks * F64_LANES)..hidden_size { + let diff = *pre_norm.add(row_start + i) - mean; + var_sum += diff * diff; + } + + let inv_std = 1.0 / (var_sum / hidden_size as f64 + eps).sqrt(); + + let mut gs_acc = _mm256_setzero_pd(); + let mut gsn_acc = _mm256_setzero_pd(); + for c in 0..chunks { + let offset = row_start + c * F64_LANES; + let w_offset = c * F64_LANES; + let g = _mm256_loadu_pd(grad.add(offset)); + let w = _mm256_loadu_pd(weight.add(w_offset)); + let pn = _mm256_loadu_pd(pre_norm.add(offset)); + + let gs = _mm256_mul_pd(g, w); + gs_acc = _mm256_add_pd(gs_acc, gs); + + let diff = _mm256_sub_pd(pn, v_mean); + let normalized = _mm256_mul_pd(diff, _mm256_set1_pd(inv_std)); + let gsn = _mm256_mul_pd(gs, normalized); + gsn_acc = _mm256_add_pd(gsn_acc, gsn); + } + let mut mean_gs_simd = hsum_f64(gs_acc); + let mut mean_gsn_simd = hsum_f64(gsn_acc); + + for i in (chunks * F64_LANES)..hidden_size { + let g = *grad.add(row_start + i); + let w = *weight.add(i); + let pn = *pre_norm.add(row_start + i); + + let gs = g * w; + mean_gs_simd += gs; + + let normalized = (pn - mean) * inv_std; + mean_gsn_simd += gs * normalized; + } + + let mean_gs = mean_gs_simd / hidden_size as f64; + let mean_gs_n = mean_gsn_simd / hidden_size as f64; + let v_inv_std = _mm256_set1_pd(inv_std); + let v_mean_gs = _mm256_set1_pd(mean_gs); + let v_mean_gs_n = _mm256_set1_pd(mean_gs_n); + + for c in 0..chunks { + let offset = row_start + c * F64_LANES; + let w_offset = c * F64_LANES; + let g = _mm256_loadu_pd(grad.add(offset)); + let w = _mm256_loadu_pd(weight.add(w_offset)); + let pn = _mm256_loadu_pd(pre_norm.add(offset)); + + let normalized = _mm256_mul_pd(_mm256_sub_pd(pn, v_mean), v_inv_std); + let gs = _mm256_mul_pd(g, w); + let d_ir = _mm256_mul_pd( + v_inv_std, + _mm256_sub_pd( + gs, + _mm256_add_pd(v_mean_gs, _mm256_mul_pd(normalized, v_mean_gs_n)), + ), + ); + _mm256_storeu_pd(d_input_residual.add(offset), d_ir); + + let dw_old = _mm256_loadu_pd(d_weight.add(w_offset)); + let dw_add = _mm256_mul_pd(g, normalized); + let dw_new = _mm256_add_pd(dw_old, dw_add); + _mm256_storeu_pd(d_weight.add(w_offset), dw_new); + + let db_old = _mm256_loadu_pd(d_bias.add(w_offset)); + let db_new = _mm256_add_pd(db_old, g); + _mm256_storeu_pd(d_bias.add(w_offset), db_new); + } + + for i in (chunks * F64_LANES)..hidden_size { + let g = *grad.add(row_start + i); + let w = *weight.add(i); + let pn = *pre_norm.add(row_start + i); + + let normalized = (pn - mean) * inv_std; + let gs = g * w; + let d_ir = inv_std * (gs - mean_gs - normalized * mean_gs_n); + *d_input_residual.add(row_start + i) = d_ir; + + *d_weight.add(i) += g * normalized; + *d_bias.add(i) += g; + } + } +} diff --git a/src/runtime/cpu/kernels/simd/norm/avx2/fused_add_rms_norm.rs b/src/runtime/cpu/kernels/simd/norm/avx2/fused_add_rms_norm.rs new file mode 100644 index 00000000..8705707c --- /dev/null +++ b/src/runtime/cpu/kernels/simd/norm/avx2/fused_add_rms_norm.rs @@ -0,0 +1,360 @@ +//! AVX2 fused add + RMS normalization kernels (forward and backward) + +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +use super::{F32_LANES, F64_LANES, hsum_f32, hsum_f64}; + +/// AVX2 Fused Add + RMS Normalization for f32 +/// +/// Computes: output = (input + residual) * rsqrt(mean((input + residual)^2) + eps) * weight +/// Stores intermediate (input + residual) in pre_norm for backward pass. +#[target_feature(enable = "avx2", enable = "fma")] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_rms_norm_f32( + input: *const f32, + residual: *const f32, + weight: *const f32, + out: *mut f32, + pre_norm: *mut f32, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let chunks = hidden_size / F32_LANES; + + for batch in 0..batch_size { + let row_start = batch * hidden_size; + + // Phase 1: Add input + residual, store in pre_norm, accumulate sum of squares in f64 + let mut acc_lo = _mm256_setzero_pd(); + let mut acc_hi = _mm256_setzero_pd(); + for c in 0..chunks { + let offset = row_start + c * F32_LANES; + let v_in = _mm256_loadu_ps(input.add(offset)); + let v_res = _mm256_loadu_ps(residual.add(offset)); + let pn = _mm256_add_ps(v_in, v_res); + _mm256_storeu_ps(pre_norm.add(offset), pn); + let lo = _mm256_cvtps_pd(_mm256_castps256_ps128(pn)); + let hi = _mm256_cvtps_pd(_mm256_extractf128_ps(pn, 1)); + acc_lo = _mm256_fmadd_pd(lo, lo, acc_lo); + acc_hi = _mm256_fmadd_pd(hi, hi, acc_hi); + } + let mut sum_sq = hsum_f64(_mm256_add_pd(acc_lo, acc_hi)); + + // Scalar tail for add and sum of squares + for i in (chunks * F32_LANES)..hidden_size { + let pn = *input.add(row_start + i) + *residual.add(row_start + i); + *pre_norm.add(row_start + i) = pn; + let pn64 = pn as f64; + sum_sq += pn64 * pn64; + } + + // Compute inverse RMS in f64 for precision (matches llama.cpp) + let inv_rms = (1.0f64 / (sum_sq / hidden_size as f64 + eps as f64).sqrt()) as f32; + let v_inv_rms = _mm256_set1_ps(inv_rms); + + // Phase 2: Normalize and apply weight + for c in 0..chunks { + let offset = row_start + c * F32_LANES; + let w_offset = c * F32_LANES; + let pn = _mm256_loadu_ps(pre_norm.add(offset)); + let v_weight = _mm256_loadu_ps(weight.add(w_offset)); + let v_result = _mm256_mul_ps(_mm256_mul_ps(pn, v_inv_rms), v_weight); + _mm256_storeu_ps(out.add(offset), v_result); + } + + // Scalar tail for normalization + for i in (chunks * F32_LANES)..hidden_size { + let pn = *pre_norm.add(row_start + i); + let w = *weight.add(i); + *out.add(row_start + i) = pn * inv_rms * w; + } + } +} + +/// AVX2 Fused Add + RMS Normalization for f64 +#[target_feature(enable = "avx2", enable = "fma")] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_rms_norm_f64( + input: *const f64, + residual: *const f64, + weight: *const f64, + out: *mut f64, + pre_norm: *mut f64, + batch_size: usize, + hidden_size: usize, + eps: f64, +) { + let chunks = hidden_size / F64_LANES; + + for batch in 0..batch_size { + let row_start = batch * hidden_size; + + let mut acc0 = _mm256_setzero_pd(); + let mut acc1 = _mm256_setzero_pd(); + let mut c = 0; + let chunk_pairs = chunks / 2 * 2; + while c < chunk_pairs { + let offset0 = row_start + c * F64_LANES; + let offset1 = row_start + (c + 1) * F64_LANES; + let v_in0 = _mm256_loadu_pd(input.add(offset0)); + let v_res0 = _mm256_loadu_pd(residual.add(offset0)); + let pn0 = _mm256_add_pd(v_in0, v_res0); + _mm256_storeu_pd(pre_norm.add(offset0), pn0); + acc0 = _mm256_fmadd_pd(pn0, pn0, acc0); + + let v_in1 = _mm256_loadu_pd(input.add(offset1)); + let v_res1 = _mm256_loadu_pd(residual.add(offset1)); + let pn1 = _mm256_add_pd(v_in1, v_res1); + _mm256_storeu_pd(pre_norm.add(offset1), pn1); + acc1 = _mm256_fmadd_pd(pn1, pn1, acc1); + c += 2; + } + while c < chunks { + let offset = row_start + c * F64_LANES; + let v_in = _mm256_loadu_pd(input.add(offset)); + let v_res = _mm256_loadu_pd(residual.add(offset)); + let pn = _mm256_add_pd(v_in, v_res); + _mm256_storeu_pd(pre_norm.add(offset), pn); + acc0 = _mm256_fmadd_pd(pn, pn, acc0); + c += 1; + } + let mut sum_sq = hsum_f64(_mm256_add_pd(acc0, acc1)); + + for i in (chunks * F64_LANES)..hidden_size { + let pn = *input.add(row_start + i) + *residual.add(row_start + i); + *pre_norm.add(row_start + i) = pn; + sum_sq += pn * pn; + } + + let inv_rms = 1.0 / (sum_sq / hidden_size as f64 + eps).sqrt(); + let v_inv_rms = _mm256_set1_pd(inv_rms); + + for c in 0..chunks { + let offset = row_start + c * F64_LANES; + let w_offset = c * F64_LANES; + let pn = _mm256_loadu_pd(pre_norm.add(offset)); + let v_weight = _mm256_loadu_pd(weight.add(w_offset)); + let v_result = _mm256_mul_pd(_mm256_mul_pd(pn, v_inv_rms), v_weight); + _mm256_storeu_pd(out.add(offset), v_result); + } + + for i in (chunks * F64_LANES)..hidden_size { + let pn = *pre_norm.add(row_start + i); + let w = *weight.add(i); + *out.add(row_start + i) = pn * inv_rms * w; + } + } +} + +/// AVX2 Fused Add + RMS Norm Backward for f32 +/// +/// Computes gradients: d_input_residual = (grad * weight - pre_norm * coeff) / inv_rms +/// d_weight += grad * pre_norm / inv_rms +#[target_feature(enable = "avx2", enable = "fma")] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_rms_norm_bwd_f32( + grad: *const f32, + pre_norm: *const f32, + weight: *const f32, + d_input_residual: *mut f32, + d_weight: *mut f32, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let chunks = hidden_size / F32_LANES; + + for batch in 0..batch_size { + let row_start = batch * hidden_size; + + // Recompute mean square from pre_norm (dual accumulators) + let mut acc_sq0 = _mm256_setzero_ps(); + let mut acc_sq1 = _mm256_setzero_ps(); + let mut c = 0; + let chunk_pairs = chunks / 2 * 2; + while c < chunk_pairs { + let pn0 = _mm256_loadu_ps(pre_norm.add(row_start + c * F32_LANES)); + acc_sq0 = _mm256_fmadd_ps(pn0, pn0, acc_sq0); + let pn1 = _mm256_loadu_ps(pre_norm.add(row_start + (c + 1) * F32_LANES)); + acc_sq1 = _mm256_fmadd_ps(pn1, pn1, acc_sq1); + c += 2; + } + while c < chunks { + let pn = _mm256_loadu_ps(pre_norm.add(row_start + c * F32_LANES)); + acc_sq0 = _mm256_fmadd_ps(pn, pn, acc_sq0); + c += 1; + } + let mut sum_sq = hsum_f32(_mm256_add_ps(acc_sq0, acc_sq1)); + + for i in (chunks * F32_LANES)..hidden_size { + let pn = *pre_norm.add(row_start + i); + sum_sq += pn * pn; + } + + let mean_sq = sum_sq / hidden_size as f32; + let inv_rms = 1.0 / (mean_sq + eps).sqrt(); + + // Compute dot = sum(grad * weight * pre_norm) + let mut dot_acc = _mm256_setzero_ps(); + for c in 0..chunks { + let offset = row_start + c * F32_LANES; + let w_offset = c * F32_LANES; + let g = _mm256_loadu_ps(grad.add(offset)); + let w = _mm256_loadu_ps(weight.add(w_offset)); + let pn = _mm256_loadu_ps(pre_norm.add(offset)); + let gw = _mm256_mul_ps(g, w); + dot_acc = _mm256_fmadd_ps(gw, pn, dot_acc); + } + let mut dot = hsum_f32(dot_acc); + + for i in (chunks * F32_LANES)..hidden_size { + let g = *grad.add(row_start + i); + let w = *weight.add(i); + let pn = *pre_norm.add(row_start + i); + dot += g * w * pn; + } + + let coeff = dot * inv_rms / (hidden_size as f32 * (mean_sq + eps)); + let v_inv_rms = _mm256_set1_ps(inv_rms); + let v_coeff = _mm256_set1_ps(coeff); + + // Compute d_input_residual and accumulate d_weight + for c in 0..chunks { + let offset = row_start + c * F32_LANES; + let w_offset = c * F32_LANES; + let g = _mm256_loadu_ps(grad.add(offset)); + let w = _mm256_loadu_ps(weight.add(w_offset)); + let pn = _mm256_loadu_ps(pre_norm.add(offset)); + + // d_ir = (g*w - pn*coeff) * inv_rms + let gw = _mm256_mul_ps(g, w); + let pn_coeff = _mm256_mul_ps(pn, v_coeff); + let diff = _mm256_sub_ps(gw, pn_coeff); + let d_ir = _mm256_mul_ps(diff, v_inv_rms); + _mm256_storeu_ps(d_input_residual.add(offset), d_ir); + + // d_weight += g * pn * inv_rms + let dw_old = _mm256_loadu_ps(d_weight.add(w_offset)); + let gp = _mm256_mul_ps(g, pn); + let gp_inv = _mm256_mul_ps(gp, v_inv_rms); + let dw_new = _mm256_add_ps(dw_old, gp_inv); + _mm256_storeu_ps(d_weight.add(w_offset), dw_new); + } + + // Scalar tail + for i in (chunks * F32_LANES)..hidden_size { + let g = *grad.add(row_start + i); + let w = *weight.add(i); + let pn = *pre_norm.add(row_start + i); + + let d_ir = (g * w - pn * coeff) * inv_rms; + *d_input_residual.add(row_start + i) = d_ir; + + let d_w = g * pn * inv_rms; + *d_weight.add(i) += d_w; + } + } +} + +/// AVX2 Fused Add + RMS Norm Backward for f64 +#[target_feature(enable = "avx2", enable = "fma")] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_rms_norm_bwd_f64( + grad: *const f64, + pre_norm: *const f64, + weight: *const f64, + d_input_residual: *mut f64, + d_weight: *mut f64, + batch_size: usize, + hidden_size: usize, + eps: f64, +) { + let chunks = hidden_size / F64_LANES; + + for batch in 0..batch_size { + let row_start = batch * hidden_size; + + let mut acc_sq0 = _mm256_setzero_pd(); + let mut acc_sq1 = _mm256_setzero_pd(); + let mut c = 0; + let chunk_pairs = chunks / 2 * 2; + while c < chunk_pairs { + let pn0 = _mm256_loadu_pd(pre_norm.add(row_start + c * F64_LANES)); + acc_sq0 = _mm256_fmadd_pd(pn0, pn0, acc_sq0); + let pn1 = _mm256_loadu_pd(pre_norm.add(row_start + (c + 1) * F64_LANES)); + acc_sq1 = _mm256_fmadd_pd(pn1, pn1, acc_sq1); + c += 2; + } + while c < chunks { + let pn = _mm256_loadu_pd(pre_norm.add(row_start + c * F64_LANES)); + acc_sq0 = _mm256_fmadd_pd(pn, pn, acc_sq0); + c += 1; + } + let mut sum_sq = hsum_f64(_mm256_add_pd(acc_sq0, acc_sq1)); + + for i in (chunks * F64_LANES)..hidden_size { + let pn = *pre_norm.add(row_start + i); + sum_sq += pn * pn; + } + + let mean_sq = sum_sq / hidden_size as f64; + let inv_rms = 1.0 / (mean_sq + eps).sqrt(); + + let mut dot_acc = _mm256_setzero_pd(); + for c in 0..chunks { + let offset = row_start + c * F64_LANES; + let w_offset = c * F64_LANES; + let g = _mm256_loadu_pd(grad.add(offset)); + let w = _mm256_loadu_pd(weight.add(w_offset)); + let pn = _mm256_loadu_pd(pre_norm.add(offset)); + let gw = _mm256_mul_pd(g, w); + dot_acc = _mm256_fmadd_pd(gw, pn, dot_acc); + } + let mut dot = hsum_f64(dot_acc); + + for i in (chunks * F64_LANES)..hidden_size { + let g = *grad.add(row_start + i); + let w = *weight.add(i); + let pn = *pre_norm.add(row_start + i); + dot += g * w * pn; + } + + let coeff = dot * inv_rms / (hidden_size as f64 * (mean_sq + eps)); + let v_inv_rms = _mm256_set1_pd(inv_rms); + let v_coeff = _mm256_set1_pd(coeff); + + for c in 0..chunks { + let offset = row_start + c * F64_LANES; + let w_offset = c * F64_LANES; + let g = _mm256_loadu_pd(grad.add(offset)); + let w = _mm256_loadu_pd(weight.add(w_offset)); + let pn = _mm256_loadu_pd(pre_norm.add(offset)); + + let gw = _mm256_mul_pd(g, w); + let pn_coeff = _mm256_mul_pd(pn, v_coeff); + let diff = _mm256_sub_pd(gw, pn_coeff); + let d_ir = _mm256_mul_pd(diff, v_inv_rms); + _mm256_storeu_pd(d_input_residual.add(offset), d_ir); + + let dw_old = _mm256_loadu_pd(d_weight.add(w_offset)); + let gp = _mm256_mul_pd(g, pn); + let gp_inv = _mm256_mul_pd(gp, v_inv_rms); + let dw_new = _mm256_add_pd(dw_old, gp_inv); + _mm256_storeu_pd(d_weight.add(w_offset), dw_new); + } + + for i in (chunks * F64_LANES)..hidden_size { + let g = *grad.add(row_start + i); + let w = *weight.add(i); + let pn = *pre_norm.add(row_start + i); + + let d_ir = (g * w - pn * coeff) * inv_rms; + *d_input_residual.add(row_start + i) = d_ir; + + let d_w = g * pn * inv_rms; + *d_weight.add(i) += d_w; + } + } +} diff --git a/src/runtime/cpu/kernels/simd/norm/avx2/layer_norm.rs b/src/runtime/cpu/kernels/simd/norm/avx2/layer_norm.rs new file mode 100644 index 00000000..462500bf --- /dev/null +++ b/src/runtime/cpu/kernels/simd/norm/avx2/layer_norm.rs @@ -0,0 +1,145 @@ +//! AVX2 layer normalization kernels + +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +use super::{F32_LANES, F64_LANES, hsum_f32, hsum_f64}; + +/// AVX2 Layer normalization for f32 +#[target_feature(enable = "avx2", enable = "fma")] +pub unsafe fn layer_norm_f32( + input: *const f32, + weight: *const f32, + bias: *const f32, + out: *mut f32, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let chunks = hidden_size / F32_LANES; + + for batch in 0..batch_size { + let row_start = batch * hidden_size; + + // SIMD sum for mean + let mut sum_acc = _mm256_setzero_ps(); + for c in 0..chunks { + let v = _mm256_loadu_ps(input.add(row_start + c * F32_LANES)); + sum_acc = _mm256_add_ps(sum_acc, v); + } + let mut sum = hsum_f32(sum_acc); + + for i in (chunks * F32_LANES)..hidden_size { + sum += *input.add(row_start + i); + } + let mean = sum / hidden_size as f32; + let v_mean = _mm256_set1_ps(mean); + + // SIMD variance computation + let mut var_acc = _mm256_setzero_ps(); + for c in 0..chunks { + let v = _mm256_loadu_ps(input.add(row_start + c * F32_LANES)); + let diff = _mm256_sub_ps(v, v_mean); + var_acc = _mm256_fmadd_ps(diff, diff, var_acc); + } + let mut var_sum = hsum_f32(var_acc); + + for i in (chunks * F32_LANES)..hidden_size { + let diff = *input.add(row_start + i) - mean; + var_sum += diff * diff; + } + let inv_std = 1.0 / (var_sum / hidden_size as f32 + eps).sqrt(); + let v_inv_std = _mm256_set1_ps(inv_std); + + // SIMD normalization with weight and bias + for c in 0..chunks { + let offset = row_start + c * F32_LANES; + let w_offset = c * F32_LANES; + let v_input = _mm256_loadu_ps(input.add(offset)); + let v_weight = _mm256_loadu_ps(weight.add(w_offset)); + let v_bias = _mm256_loadu_ps(bias.add(w_offset)); + + let diff = _mm256_sub_ps(v_input, v_mean); + let normalized = _mm256_mul_ps(diff, v_inv_std); + let scaled = _mm256_mul_ps(normalized, v_weight); + let result = _mm256_add_ps(scaled, v_bias); + + _mm256_storeu_ps(out.add(offset), result); + } + + for i in (chunks * F32_LANES)..hidden_size { + let x = *input.add(row_start + i); + let w = *weight.add(i); + let b = *bias.add(i); + *out.add(row_start + i) = (x - mean) * inv_std * w + b; + } + } +} + +/// AVX2 Layer normalization for f64 +#[target_feature(enable = "avx2", enable = "fma")] +pub unsafe fn layer_norm_f64( + input: *const f64, + weight: *const f64, + bias: *const f64, + out: *mut f64, + batch_size: usize, + hidden_size: usize, + eps: f64, +) { + let chunks = hidden_size / F64_LANES; + + for batch in 0..batch_size { + let row_start = batch * hidden_size; + + let mut sum_acc = _mm256_setzero_pd(); + for c in 0..chunks { + let v = _mm256_loadu_pd(input.add(row_start + c * F64_LANES)); + sum_acc = _mm256_add_pd(sum_acc, v); + } + let mut sum = hsum_f64(sum_acc); + + for i in (chunks * F64_LANES)..hidden_size { + sum += *input.add(row_start + i); + } + let mean = sum / hidden_size as f64; + let v_mean = _mm256_set1_pd(mean); + + let mut var_acc = _mm256_setzero_pd(); + for c in 0..chunks { + let v = _mm256_loadu_pd(input.add(row_start + c * F64_LANES)); + let diff = _mm256_sub_pd(v, v_mean); + var_acc = _mm256_fmadd_pd(diff, diff, var_acc); + } + let mut var_sum = hsum_f64(var_acc); + + for i in (chunks * F64_LANES)..hidden_size { + let diff = *input.add(row_start + i) - mean; + var_sum += diff * diff; + } + let inv_std = 1.0 / (var_sum / hidden_size as f64 + eps).sqrt(); + let v_inv_std = _mm256_set1_pd(inv_std); + + for c in 0..chunks { + let offset = row_start + c * F64_LANES; + let w_offset = c * F64_LANES; + let v_input = _mm256_loadu_pd(input.add(offset)); + let v_weight = _mm256_loadu_pd(weight.add(w_offset)); + let v_bias = _mm256_loadu_pd(bias.add(w_offset)); + + let diff = _mm256_sub_pd(v_input, v_mean); + let normalized = _mm256_mul_pd(diff, v_inv_std); + let scaled = _mm256_mul_pd(normalized, v_weight); + let result = _mm256_add_pd(scaled, v_bias); + + _mm256_storeu_pd(out.add(offset), result); + } + + for i in (chunks * F64_LANES)..hidden_size { + let x = *input.add(row_start + i); + let w = *weight.add(i); + let b = *bias.add(i); + *out.add(row_start + i) = (x - mean) * inv_std * w + b; + } + } +} diff --git a/src/runtime/cpu/kernels/simd/norm/avx2/mod.rs b/src/runtime/cpu/kernels/simd/norm/avx2/mod.rs new file mode 100644 index 00000000..3a9c27ae --- /dev/null +++ b/src/runtime/cpu/kernels/simd/norm/avx2/mod.rs @@ -0,0 +1,55 @@ +//! AVX2 normalization kernels +//! +//! SIMD-optimized RMS norm and layer norm with manual horizontal reductions. + +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +pub(super) const F32_LANES: usize = 8; +pub(super) const F64_LANES: usize = 4; + +mod fused_add_layer_norm; +mod fused_add_rms_norm; +mod layer_norm; +mod rms_norm; + +pub use fused_add_layer_norm::{ + fused_add_layer_norm_bwd_f32, fused_add_layer_norm_bwd_f64, fused_add_layer_norm_f32, + fused_add_layer_norm_f64, +}; +pub use fused_add_rms_norm::{ + fused_add_rms_norm_bwd_f32, fused_add_rms_norm_bwd_f64, fused_add_rms_norm_f32, + fused_add_rms_norm_f64, +}; +pub use layer_norm::{layer_norm_f32, layer_norm_f64}; +pub use rms_norm::{rms_norm_f32, rms_norm_f64}; + +// ============================================================================ +// Horizontal reduction helpers (used by sub-modules) +// ============================================================================ + +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub(super) unsafe fn hsum_f32(v: __m256) -> f32 { + let high = _mm256_extractf128_ps(v, 1); + let low = _mm256_castps256_ps128(v); + let sum128 = _mm_add_ps(low, high); + let shuf = _mm_movehdup_ps(sum128); + let sum64 = _mm_add_ps(sum128, shuf); + let shuf2 = _mm_movehl_ps(sum64, sum64); + let sum32 = _mm_add_ss(sum64, shuf2); + _mm_cvtss_f32(sum32) +} + +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2", enable = "fma")] +#[inline] +pub(super) unsafe fn hsum_f64(v: __m256d) -> f64 { + let high = _mm256_extractf128_pd(v, 1); + let low = _mm256_castpd256_pd128(v); + let sum128 = _mm_add_pd(low, high); + let shuf = _mm_unpackhi_pd(sum128, sum128); + let sum64 = _mm_add_sd(sum128, shuf); + _mm_cvtsd_f64(sum64) +} diff --git a/src/runtime/cpu/kernels/simd/norm/avx2/rms_norm.rs b/src/runtime/cpu/kernels/simd/norm/avx2/rms_norm.rs new file mode 100644 index 00000000..7f3d733a --- /dev/null +++ b/src/runtime/cpu/kernels/simd/norm/avx2/rms_norm.rs @@ -0,0 +1,108 @@ +//! AVX2 RMS normalization kernels + +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +use super::{F32_LANES, F64_LANES, hsum_f64}; + +/// AVX2 RMS normalization for f32 +#[target_feature(enable = "avx2", enable = "fma")] +pub unsafe fn rms_norm_f32( + input: *const f32, + weight: *const f32, + out: *mut f32, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let chunks = hidden_size / F32_LANES; + + for batch in 0..batch_size { + let row_start = batch * hidden_size; + + // Accumulate sum of squares in f64 for precision (matches llama.cpp) + let mut acc_lo = _mm256_setzero_pd(); + let mut acc_hi = _mm256_setzero_pd(); + for c in 0..chunks { + let offset = row_start + c * F32_LANES; + let v = _mm256_loadu_ps(input.add(offset)); + // Split 8xf32 into 2x4xf64 + let lo = _mm256_cvtps_pd(_mm256_castps256_ps128(v)); + let hi = _mm256_cvtps_pd(_mm256_extractf128_ps(v, 1)); + acc_lo = _mm256_fmadd_pd(lo, lo, acc_lo); + acc_hi = _mm256_fmadd_pd(hi, hi, acc_hi); + } + let mut sum_sq = hsum_f64(_mm256_add_pd(acc_lo, acc_hi)); + + for i in (chunks * F32_LANES)..hidden_size { + let x = *input.add(row_start + i) as f64; + sum_sq += x * x; + } + + let inv_rms = (1.0f64 / (sum_sq / hidden_size as f64 + eps as f64).sqrt()) as f32; + let v_inv_rms = _mm256_set1_ps(inv_rms); + + for c in 0..chunks { + let offset = row_start + c * F32_LANES; + let w_offset = c * F32_LANES; + let v_input = _mm256_loadu_ps(input.add(offset)); + let v_weight = _mm256_loadu_ps(weight.add(w_offset)); + let v_result = _mm256_mul_ps(_mm256_mul_ps(v_input, v_inv_rms), v_weight); + _mm256_storeu_ps(out.add(offset), v_result); + } + + for i in (chunks * F32_LANES)..hidden_size { + let x = *input.add(row_start + i); + let w = *weight.add(i); + *out.add(row_start + i) = x * inv_rms * w; + } + } +} + +/// AVX2 RMS normalization for f64 +#[target_feature(enable = "avx2", enable = "fma")] +pub unsafe fn rms_norm_f64( + input: *const f64, + weight: *const f64, + out: *mut f64, + batch_size: usize, + hidden_size: usize, + eps: f64, +) { + let chunks = hidden_size / F64_LANES; + + for batch in 0..batch_size { + let row_start = batch * hidden_size; + + let mut acc = _mm256_setzero_pd(); + for c in 0..chunks { + let offset = row_start + c * F64_LANES; + let v = _mm256_loadu_pd(input.add(offset)); + acc = _mm256_fmadd_pd(v, v, acc); + } + let mut sum_sq = hsum_f64(acc); + + for i in (chunks * F64_LANES)..hidden_size { + let x = *input.add(row_start + i); + sum_sq += x * x; + } + + let inv_rms = 1.0 / (sum_sq / hidden_size as f64 + eps).sqrt(); + let v_inv_rms = _mm256_set1_pd(inv_rms); + + for c in 0..chunks { + let offset = row_start + c * F64_LANES; + let w_offset = c * F64_LANES; + let v_input = _mm256_loadu_pd(input.add(offset)); + let v_weight = _mm256_loadu_pd(weight.add(w_offset)); + let v_result = _mm256_mul_pd(_mm256_mul_pd(v_input, v_inv_rms), v_weight); + _mm256_storeu_pd(out.add(offset), v_result); + } + + for i in (chunks * F64_LANES)..hidden_size { + let x = *input.add(row_start + i); + let w = *weight.add(i); + *out.add(row_start + i) = x * inv_rms * w; + } + } +} diff --git a/src/runtime/cpu/kernels/simd/norm/avx512/fused_add_layer_norm.rs b/src/runtime/cpu/kernels/simd/norm/avx512/fused_add_layer_norm.rs new file mode 100644 index 00000000..d902e348 --- /dev/null +++ b/src/runtime/cpu/kernels/simd/norm/avx512/fused_add_layer_norm.rs @@ -0,0 +1,502 @@ +//! AVX-512 fused add + layer normalization kernels (forward and backward) + +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +use super::{F32_LANES, F64_LANES}; + +/// AVX-512 Fused Add + Layer Normalization for f32 +#[target_feature(enable = "avx512f")] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_layer_norm_f32( + input: *const f32, + residual: *const f32, + weight: *const f32, + bias: *const f32, + out: *mut f32, + pre_norm: *mut f32, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let chunks = hidden_size / F32_LANES; + + for batch in 0..batch_size { + let row_start = batch * hidden_size; + + let mut sum_acc = _mm512_setzero_ps(); + for c in 0..chunks { + let offset = row_start + c * F32_LANES; + let v_in = _mm512_loadu_ps(input.add(offset)); + let v_res = _mm512_loadu_ps(residual.add(offset)); + let pn = _mm512_add_ps(v_in, v_res); + _mm512_storeu_ps(pre_norm.add(offset), pn); + sum_acc = _mm512_add_ps(sum_acc, pn); + } + let mut sum = _mm512_reduce_add_ps(sum_acc); + + for i in (chunks * F32_LANES)..hidden_size { + let pn = *input.add(row_start + i) + *residual.add(row_start + i); + *pre_norm.add(row_start + i) = pn; + sum += pn; + } + + let mean = sum / hidden_size as f32; + let v_mean = _mm512_set1_ps(mean); + + let mut var_acc0 = _mm512_setzero_ps(); + let mut var_acc1 = _mm512_setzero_ps(); + let mut c = 0; + let chunk_pairs = chunks / 2 * 2; + while c < chunk_pairs { + let diff0 = _mm512_sub_ps( + _mm512_loadu_ps(pre_norm.add(row_start + c * F32_LANES)), + v_mean, + ); + var_acc0 = _mm512_fmadd_ps(diff0, diff0, var_acc0); + let diff1 = _mm512_sub_ps( + _mm512_loadu_ps(pre_norm.add(row_start + (c + 1) * F32_LANES)), + v_mean, + ); + var_acc1 = _mm512_fmadd_ps(diff1, diff1, var_acc1); + c += 2; + } + while c < chunks { + let diff = _mm512_sub_ps( + _mm512_loadu_ps(pre_norm.add(row_start + c * F32_LANES)), + v_mean, + ); + var_acc0 = _mm512_fmadd_ps(diff, diff, var_acc0); + c += 1; + } + let mut var_sum = _mm512_reduce_add_ps(_mm512_add_ps(var_acc0, var_acc1)); + + for i in (chunks * F32_LANES)..hidden_size { + let diff = *pre_norm.add(row_start + i) - mean; + var_sum += diff * diff; + } + + let inv_std = 1.0 / (var_sum / hidden_size as f32 + eps).sqrt(); + let v_inv_std = _mm512_set1_ps(inv_std); + + for c in 0..chunks { + let offset = row_start + c * F32_LANES; + let w_offset = c * F32_LANES; + let pn = _mm512_loadu_ps(pre_norm.add(offset)); + let v_weight = _mm512_loadu_ps(weight.add(w_offset)); + let v_bias = _mm512_loadu_ps(bias.add(w_offset)); + + let diff = _mm512_sub_ps(pn, v_mean); + let normalized = _mm512_mul_ps(diff, v_inv_std); + let scaled = _mm512_mul_ps(normalized, v_weight); + let result = _mm512_add_ps(scaled, v_bias); + + _mm512_storeu_ps(out.add(offset), result); + } + + for i in (chunks * F32_LANES)..hidden_size { + let pn = *pre_norm.add(row_start + i); + let w = *weight.add(i); + let b = *bias.add(i); + *out.add(row_start + i) = (pn - mean) * inv_std * w + b; + } + } +} + +/// AVX-512 Fused Add + Layer Normalization for f64 +#[target_feature(enable = "avx512f")] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_layer_norm_f64( + input: *const f64, + residual: *const f64, + weight: *const f64, + bias: *const f64, + out: *mut f64, + pre_norm: *mut f64, + batch_size: usize, + hidden_size: usize, + eps: f64, +) { + let chunks = hidden_size / F64_LANES; + + for batch in 0..batch_size { + let row_start = batch * hidden_size; + + let mut sum_acc = _mm512_setzero_pd(); + for c in 0..chunks { + let offset = row_start + c * F64_LANES; + let v_in = _mm512_loadu_pd(input.add(offset)); + let v_res = _mm512_loadu_pd(residual.add(offset)); + let pn = _mm512_add_pd(v_in, v_res); + _mm512_storeu_pd(pre_norm.add(offset), pn); + sum_acc = _mm512_add_pd(sum_acc, pn); + } + let mut sum = _mm512_reduce_add_pd(sum_acc); + + for i in (chunks * F64_LANES)..hidden_size { + let pn = *input.add(row_start + i) + *residual.add(row_start + i); + *pre_norm.add(row_start + i) = pn; + sum += pn; + } + + let mean = sum / hidden_size as f64; + let v_mean = _mm512_set1_pd(mean); + + let mut var_acc0 = _mm512_setzero_pd(); + let mut var_acc1 = _mm512_setzero_pd(); + let mut c = 0; + let chunk_pairs_v = chunks / 2 * 2; + while c < chunk_pairs_v { + let diff0 = _mm512_sub_pd( + _mm512_loadu_pd(pre_norm.add(row_start + c * F64_LANES)), + v_mean, + ); + var_acc0 = _mm512_fmadd_pd(diff0, diff0, var_acc0); + let diff1 = _mm512_sub_pd( + _mm512_loadu_pd(pre_norm.add(row_start + (c + 1) * F64_LANES)), + v_mean, + ); + var_acc1 = _mm512_fmadd_pd(diff1, diff1, var_acc1); + c += 2; + } + while c < chunks { + let diff = _mm512_sub_pd( + _mm512_loadu_pd(pre_norm.add(row_start + c * F64_LANES)), + v_mean, + ); + var_acc0 = _mm512_fmadd_pd(diff, diff, var_acc0); + c += 1; + } + let mut var_sum = _mm512_reduce_add_pd(_mm512_add_pd(var_acc0, var_acc1)); + + for i in (chunks * F64_LANES)..hidden_size { + let diff = *pre_norm.add(row_start + i) - mean; + var_sum += diff * diff; + } + + let inv_std = 1.0 / (var_sum / hidden_size as f64 + eps).sqrt(); + let v_inv_std = _mm512_set1_pd(inv_std); + + for c in 0..chunks { + let offset = row_start + c * F64_LANES; + let w_offset = c * F64_LANES; + let pn = _mm512_loadu_pd(pre_norm.add(offset)); + let v_weight = _mm512_loadu_pd(weight.add(w_offset)); + let v_bias = _mm512_loadu_pd(bias.add(w_offset)); + + let diff = _mm512_sub_pd(pn, v_mean); + let normalized = _mm512_mul_pd(diff, v_inv_std); + let scaled = _mm512_mul_pd(normalized, v_weight); + let result = _mm512_add_pd(scaled, v_bias); + + _mm512_storeu_pd(out.add(offset), result); + } + + for i in (chunks * F64_LANES)..hidden_size { + let pn = *pre_norm.add(row_start + i); + let w = *weight.add(i); + let b = *bias.add(i); + *out.add(row_start + i) = (pn - mean) * inv_std * w + b; + } + } +} + +/// AVX-512 Fused Add + Layer Norm Backward for f32 +#[target_feature(enable = "avx512f")] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_layer_norm_bwd_f32( + grad: *const f32, + pre_norm: *const f32, + weight: *const f32, + d_input_residual: *mut f32, + d_weight: *mut f32, + d_bias: *mut f32, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let chunks = hidden_size / F32_LANES; + + for batch in 0..batch_size { + let row_start = batch * hidden_size; + + let mut sum_acc = _mm512_setzero_ps(); + for c in 0..chunks { + let offset = row_start + c * F32_LANES; + let pn = _mm512_loadu_ps(pre_norm.add(offset)); + sum_acc = _mm512_add_ps(sum_acc, pn); + } + let mut sum = _mm512_reduce_add_ps(sum_acc); + + for i in (chunks * F32_LANES)..hidden_size { + sum += *pre_norm.add(row_start + i); + } + + let mean = sum / hidden_size as f32; + let v_mean = _mm512_set1_ps(mean); + + let mut var_acc0 = _mm512_setzero_ps(); + let mut var_acc1 = _mm512_setzero_ps(); + let mut c = 0; + let chunk_pairs_v = chunks / 2 * 2; + while c < chunk_pairs_v { + let diff0 = _mm512_sub_ps( + _mm512_loadu_ps(pre_norm.add(row_start + c * F32_LANES)), + v_mean, + ); + var_acc0 = _mm512_fmadd_ps(diff0, diff0, var_acc0); + let diff1 = _mm512_sub_ps( + _mm512_loadu_ps(pre_norm.add(row_start + (c + 1) * F32_LANES)), + v_mean, + ); + var_acc1 = _mm512_fmadd_ps(diff1, diff1, var_acc1); + c += 2; + } + while c < chunks { + let diff = _mm512_sub_ps( + _mm512_loadu_ps(pre_norm.add(row_start + c * F32_LANES)), + v_mean, + ); + var_acc0 = _mm512_fmadd_ps(diff, diff, var_acc0); + c += 1; + } + let mut var_sum = _mm512_reduce_add_ps(_mm512_add_ps(var_acc0, var_acc1)); + + for i in (chunks * F32_LANES)..hidden_size { + let diff = *pre_norm.add(row_start + i) - mean; + var_sum += diff * diff; + } + + let inv_std = 1.0 / (var_sum / hidden_size as f32 + eps).sqrt(); + + let mut gs_acc = _mm512_setzero_ps(); + let mut gsn_acc = _mm512_setzero_ps(); + for c in 0..chunks { + let offset = row_start + c * F32_LANES; + let w_offset = c * F32_LANES; + let g = _mm512_loadu_ps(grad.add(offset)); + let w = _mm512_loadu_ps(weight.add(w_offset)); + let pn = _mm512_loadu_ps(pre_norm.add(offset)); + + let gs = _mm512_mul_ps(g, w); + gs_acc = _mm512_add_ps(gs_acc, gs); + + let diff = _mm512_sub_ps(pn, v_mean); + let normalized = _mm512_mul_ps(diff, _mm512_set1_ps(inv_std)); + let gsn = _mm512_mul_ps(gs, normalized); + gsn_acc = _mm512_add_ps(gsn_acc, gsn); + } + let mut mean_gs_simd = _mm512_reduce_add_ps(gs_acc); + let mut mean_gsn_simd = _mm512_reduce_add_ps(gsn_acc); + + for i in (chunks * F32_LANES)..hidden_size { + let g = *grad.add(row_start + i); + let w = *weight.add(i); + let pn = *pre_norm.add(row_start + i); + + let gs = g * w; + mean_gs_simd += gs; + + let normalized = (pn - mean) * inv_std; + mean_gsn_simd += gs * normalized; + } + + let mean_gs = mean_gs_simd / hidden_size as f32; + let mean_gs_n = mean_gsn_simd / hidden_size as f32; + let v_inv_std = _mm512_set1_ps(inv_std); + let v_mean_gs = _mm512_set1_ps(mean_gs); + let v_mean_gs_n = _mm512_set1_ps(mean_gs_n); + + for c in 0..chunks { + let offset = row_start + c * F32_LANES; + let w_offset = c * F32_LANES; + let g = _mm512_loadu_ps(grad.add(offset)); + let w = _mm512_loadu_ps(weight.add(w_offset)); + let pn = _mm512_loadu_ps(pre_norm.add(offset)); + + let normalized = _mm512_mul_ps(_mm512_sub_ps(pn, v_mean), v_inv_std); + let gs = _mm512_mul_ps(g, w); + let d_ir = _mm512_mul_ps( + v_inv_std, + _mm512_sub_ps( + gs, + _mm512_add_ps(v_mean_gs, _mm512_mul_ps(normalized, v_mean_gs_n)), + ), + ); + _mm512_storeu_ps(d_input_residual.add(offset), d_ir); + + let dw_old = _mm512_loadu_ps(d_weight.add(w_offset)); + let dw_add = _mm512_mul_ps(g, normalized); + let dw_new = _mm512_add_ps(dw_old, dw_add); + _mm512_storeu_ps(d_weight.add(w_offset), dw_new); + + let db_old = _mm512_loadu_ps(d_bias.add(w_offset)); + let db_new = _mm512_add_ps(db_old, g); + _mm512_storeu_ps(d_bias.add(w_offset), db_new); + } + + for i in (chunks * F32_LANES)..hidden_size { + let g = *grad.add(row_start + i); + let w = *weight.add(i); + let pn = *pre_norm.add(row_start + i); + + let normalized = (pn - mean) * inv_std; + let gs = g * w; + let d_ir = inv_std * (gs - mean_gs - normalized * mean_gs_n); + *d_input_residual.add(row_start + i) = d_ir; + + *d_weight.add(i) += g * normalized; + *d_bias.add(i) += g; + } + } +} + +/// AVX-512 Fused Add + Layer Norm Backward for f64 +#[target_feature(enable = "avx512f")] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_layer_norm_bwd_f64( + grad: *const f64, + pre_norm: *const f64, + weight: *const f64, + d_input_residual: *mut f64, + d_weight: *mut f64, + d_bias: *mut f64, + batch_size: usize, + hidden_size: usize, + eps: f64, +) { + let chunks = hidden_size / F64_LANES; + + for batch in 0..batch_size { + let row_start = batch * hidden_size; + + let mut sum_acc = _mm512_setzero_pd(); + for c in 0..chunks { + let offset = row_start + c * F64_LANES; + let pn = _mm512_loadu_pd(pre_norm.add(offset)); + sum_acc = _mm512_add_pd(sum_acc, pn); + } + let mut sum = _mm512_reduce_add_pd(sum_acc); + + for i in (chunks * F64_LANES)..hidden_size { + sum += *pre_norm.add(row_start + i); + } + + let mean = sum / hidden_size as f64; + let v_mean = _mm512_set1_pd(mean); + + let mut var_acc0 = _mm512_setzero_pd(); + let mut var_acc1 = _mm512_setzero_pd(); + let mut c = 0; + let chunk_pairs_v = chunks / 2 * 2; + while c < chunk_pairs_v { + let diff0 = _mm512_sub_pd( + _mm512_loadu_pd(pre_norm.add(row_start + c * F64_LANES)), + v_mean, + ); + var_acc0 = _mm512_fmadd_pd(diff0, diff0, var_acc0); + let diff1 = _mm512_sub_pd( + _mm512_loadu_pd(pre_norm.add(row_start + (c + 1) * F64_LANES)), + v_mean, + ); + var_acc1 = _mm512_fmadd_pd(diff1, diff1, var_acc1); + c += 2; + } + while c < chunks { + let diff = _mm512_sub_pd( + _mm512_loadu_pd(pre_norm.add(row_start + c * F64_LANES)), + v_mean, + ); + var_acc0 = _mm512_fmadd_pd(diff, diff, var_acc0); + c += 1; + } + let mut var_sum = _mm512_reduce_add_pd(_mm512_add_pd(var_acc0, var_acc1)); + + for i in (chunks * F64_LANES)..hidden_size { + let diff = *pre_norm.add(row_start + i) - mean; + var_sum += diff * diff; + } + + let inv_std = 1.0 / (var_sum / hidden_size as f64 + eps).sqrt(); + + let mut gs_acc = _mm512_setzero_pd(); + let mut gsn_acc = _mm512_setzero_pd(); + for c in 0..chunks { + let offset = row_start + c * F64_LANES; + let w_offset = c * F64_LANES; + let g = _mm512_loadu_pd(grad.add(offset)); + let w = _mm512_loadu_pd(weight.add(w_offset)); + let pn = _mm512_loadu_pd(pre_norm.add(offset)); + + let gs = _mm512_mul_pd(g, w); + gs_acc = _mm512_add_pd(gs_acc, gs); + + let diff = _mm512_sub_pd(pn, v_mean); + let normalized = _mm512_mul_pd(diff, _mm512_set1_pd(inv_std)); + let gsn = _mm512_mul_pd(gs, normalized); + gsn_acc = _mm512_add_pd(gsn_acc, gsn); + } + let mut mean_gs_simd = _mm512_reduce_add_pd(gs_acc); + let mut mean_gsn_simd = _mm512_reduce_add_pd(gsn_acc); + + for i in (chunks * F64_LANES)..hidden_size { + let g = *grad.add(row_start + i); + let w = *weight.add(i); + let pn = *pre_norm.add(row_start + i); + + let gs = g * w; + mean_gs_simd += gs; + + let normalized = (pn - mean) * inv_std; + mean_gsn_simd += gs * normalized; + } + + let mean_gs = mean_gs_simd / hidden_size as f64; + let mean_gs_n = mean_gsn_simd / hidden_size as f64; + let v_inv_std = _mm512_set1_pd(inv_std); + let v_mean_gs = _mm512_set1_pd(mean_gs); + let v_mean_gs_n = _mm512_set1_pd(mean_gs_n); + + for c in 0..chunks { + let offset = row_start + c * F64_LANES; + let w_offset = c * F64_LANES; + let g = _mm512_loadu_pd(grad.add(offset)); + let w = _mm512_loadu_pd(weight.add(w_offset)); + let pn = _mm512_loadu_pd(pre_norm.add(offset)); + + let normalized = _mm512_mul_pd(_mm512_sub_pd(pn, v_mean), v_inv_std); + let gs = _mm512_mul_pd(g, w); + let d_ir = _mm512_mul_pd( + v_inv_std, + _mm512_sub_pd( + gs, + _mm512_add_pd(v_mean_gs, _mm512_mul_pd(normalized, v_mean_gs_n)), + ), + ); + _mm512_storeu_pd(d_input_residual.add(offset), d_ir); + + let dw_old = _mm512_loadu_pd(d_weight.add(w_offset)); + let dw_add = _mm512_mul_pd(g, normalized); + let dw_new = _mm512_add_pd(dw_old, dw_add); + _mm512_storeu_pd(d_weight.add(w_offset), dw_new); + + let db_old = _mm512_loadu_pd(d_bias.add(w_offset)); + let db_new = _mm512_add_pd(db_old, g); + _mm512_storeu_pd(d_bias.add(w_offset), db_new); + } + + for i in (chunks * F64_LANES)..hidden_size { + let g = *grad.add(row_start + i); + let w = *weight.add(i); + let pn = *pre_norm.add(row_start + i); + + let normalized = (pn - mean) * inv_std; + let gs = g * w; + let d_ir = inv_std * (gs - mean_gs - normalized * mean_gs_n); + *d_input_residual.add(row_start + i) = d_ir; + + *d_weight.add(i) += g * normalized; + *d_bias.add(i) += g; + } + } +} diff --git a/src/runtime/cpu/kernels/simd/norm/avx512/fused_add_rms_norm.rs b/src/runtime/cpu/kernels/simd/norm/avx512/fused_add_rms_norm.rs new file mode 100644 index 00000000..583a0446 --- /dev/null +++ b/src/runtime/cpu/kernels/simd/norm/avx512/fused_add_rms_norm.rs @@ -0,0 +1,366 @@ +//! AVX-512 fused add + RMS normalization kernels (forward and backward) + +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +use super::{F32_LANES, F64_LANES}; + +/// AVX-512 Fused Add + RMS Normalization for f32 +/// +/// Computes: output = (input + residual) * rsqrt(mean((input + residual)^2) + eps) * weight +/// Stores intermediate (input + residual) in pre_norm for backward pass. +#[target_feature(enable = "avx512f")] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_rms_norm_f32( + input: *const f32, + residual: *const f32, + weight: *const f32, + out: *mut f32, + pre_norm: *mut f32, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let chunks = hidden_size / F32_LANES; + + for batch in 0..batch_size { + let row_start = batch * hidden_size; + + let mut acc0 = _mm512_setzero_ps(); + let mut acc1 = _mm512_setzero_ps(); + let mut c = 0; + let chunk_pairs = chunks / 2 * 2; + while c < chunk_pairs { + let offset0 = row_start + c * F32_LANES; + let offset1 = row_start + (c + 1) * F32_LANES; + let pn0 = _mm512_add_ps( + _mm512_loadu_ps(input.add(offset0)), + _mm512_loadu_ps(residual.add(offset0)), + ); + _mm512_storeu_ps(pre_norm.add(offset0), pn0); + acc0 = _mm512_fmadd_ps(pn0, pn0, acc0); + let pn1 = _mm512_add_ps( + _mm512_loadu_ps(input.add(offset1)), + _mm512_loadu_ps(residual.add(offset1)), + ); + _mm512_storeu_ps(pre_norm.add(offset1), pn1); + acc1 = _mm512_fmadd_ps(pn1, pn1, acc1); + c += 2; + } + while c < chunks { + let offset = row_start + c * F32_LANES; + let pn = _mm512_add_ps( + _mm512_loadu_ps(input.add(offset)), + _mm512_loadu_ps(residual.add(offset)), + ); + _mm512_storeu_ps(pre_norm.add(offset), pn); + acc0 = _mm512_fmadd_ps(pn, pn, acc0); + c += 1; + } + let mut sum_sq = _mm512_reduce_add_ps(_mm512_add_ps(acc0, acc1)) as f64; + + for i in (chunks * F32_LANES)..hidden_size { + let pn = *input.add(row_start + i) + *residual.add(row_start + i); + *pre_norm.add(row_start + i) = pn; + let pn64 = pn as f64; + sum_sq += pn64 * pn64; + } + + let inv_rms = (1.0f64 / (sum_sq / hidden_size as f64 + eps as f64).sqrt()) as f32; + let v_inv_rms = _mm512_set1_ps(inv_rms); + + for c in 0..chunks { + let offset = row_start + c * F32_LANES; + let w_offset = c * F32_LANES; + let pn = _mm512_loadu_ps(pre_norm.add(offset)); + let v_weight = _mm512_loadu_ps(weight.add(w_offset)); + let v_result = _mm512_mul_ps(_mm512_mul_ps(pn, v_inv_rms), v_weight); + _mm512_storeu_ps(out.add(offset), v_result); + } + + for i in (chunks * F32_LANES)..hidden_size { + let pn = *pre_norm.add(row_start + i); + let w = *weight.add(i); + *out.add(row_start + i) = pn * inv_rms * w; + } + } +} + +/// AVX-512 Fused Add + RMS Normalization for f64 +#[target_feature(enable = "avx512f")] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_rms_norm_f64( + input: *const f64, + residual: *const f64, + weight: *const f64, + out: *mut f64, + pre_norm: *mut f64, + batch_size: usize, + hidden_size: usize, + eps: f64, +) { + let chunks = hidden_size / F64_LANES; + + for batch in 0..batch_size { + let row_start = batch * hidden_size; + + let mut acc0 = _mm512_setzero_pd(); + let mut acc1 = _mm512_setzero_pd(); + let mut c = 0; + let chunk_pairs = chunks / 2 * 2; + while c < chunk_pairs { + let offset0 = row_start + c * F64_LANES; + let offset1 = row_start + (c + 1) * F64_LANES; + let pn0 = _mm512_add_pd( + _mm512_loadu_pd(input.add(offset0)), + _mm512_loadu_pd(residual.add(offset0)), + ); + _mm512_storeu_pd(pre_norm.add(offset0), pn0); + acc0 = _mm512_fmadd_pd(pn0, pn0, acc0); + let pn1 = _mm512_add_pd( + _mm512_loadu_pd(input.add(offset1)), + _mm512_loadu_pd(residual.add(offset1)), + ); + _mm512_storeu_pd(pre_norm.add(offset1), pn1); + acc1 = _mm512_fmadd_pd(pn1, pn1, acc1); + c += 2; + } + while c < chunks { + let offset = row_start + c * F64_LANES; + let pn = _mm512_add_pd( + _mm512_loadu_pd(input.add(offset)), + _mm512_loadu_pd(residual.add(offset)), + ); + _mm512_storeu_pd(pre_norm.add(offset), pn); + acc0 = _mm512_fmadd_pd(pn, pn, acc0); + c += 1; + } + let mut sum_sq = _mm512_reduce_add_pd(_mm512_add_pd(acc0, acc1)); + + for i in (chunks * F64_LANES)..hidden_size { + let pn = *input.add(row_start + i) + *residual.add(row_start + i); + *pre_norm.add(row_start + i) = pn; + sum_sq += pn * pn; + } + + let inv_rms = 1.0 / (sum_sq / hidden_size as f64 + eps).sqrt(); + let v_inv_rms = _mm512_set1_pd(inv_rms); + + for c in 0..chunks { + let offset = row_start + c * F64_LANES; + let w_offset = c * F64_LANES; + let pn = _mm512_loadu_pd(pre_norm.add(offset)); + let v_weight = _mm512_loadu_pd(weight.add(w_offset)); + let v_result = _mm512_mul_pd(_mm512_mul_pd(pn, v_inv_rms), v_weight); + _mm512_storeu_pd(out.add(offset), v_result); + } + + for i in (chunks * F64_LANES)..hidden_size { + let pn = *pre_norm.add(row_start + i); + let w = *weight.add(i); + *out.add(row_start + i) = pn * inv_rms * w; + } + } +} + +/// AVX-512 Fused Add + RMS Norm Backward for f32 +#[target_feature(enable = "avx512f")] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_rms_norm_bwd_f32( + grad: *const f32, + pre_norm: *const f32, + weight: *const f32, + d_input_residual: *mut f32, + d_weight: *mut f32, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let chunks = hidden_size / F32_LANES; + + for batch in 0..batch_size { + let row_start = batch * hidden_size; + + let mut acc_sq0 = _mm512_setzero_ps(); + let mut acc_sq1 = _mm512_setzero_ps(); + let mut c = 0; + let chunk_pairs = chunks / 2 * 2; + while c < chunk_pairs { + let pn0 = _mm512_loadu_ps(pre_norm.add(row_start + c * F32_LANES)); + acc_sq0 = _mm512_fmadd_ps(pn0, pn0, acc_sq0); + let pn1 = _mm512_loadu_ps(pre_norm.add(row_start + (c + 1) * F32_LANES)); + acc_sq1 = _mm512_fmadd_ps(pn1, pn1, acc_sq1); + c += 2; + } + while c < chunks { + let pn = _mm512_loadu_ps(pre_norm.add(row_start + c * F32_LANES)); + acc_sq0 = _mm512_fmadd_ps(pn, pn, acc_sq0); + c += 1; + } + let mut sum_sq = _mm512_reduce_add_ps(_mm512_add_ps(acc_sq0, acc_sq1)); + + for i in (chunks * F32_LANES)..hidden_size { + let pn = *pre_norm.add(row_start + i); + sum_sq += pn * pn; + } + + let mean_sq = sum_sq / hidden_size as f32; + let inv_rms = 1.0 / (mean_sq + eps).sqrt(); + + let mut dot_acc = _mm512_setzero_ps(); + for c in 0..chunks { + let offset = row_start + c * F32_LANES; + let w_offset = c * F32_LANES; + let g = _mm512_loadu_ps(grad.add(offset)); + let w = _mm512_loadu_ps(weight.add(w_offset)); + let pn = _mm512_loadu_ps(pre_norm.add(offset)); + let gw = _mm512_mul_ps(g, w); + dot_acc = _mm512_fmadd_ps(gw, pn, dot_acc); + } + let mut dot = _mm512_reduce_add_ps(dot_acc); + + for i in (chunks * F32_LANES)..hidden_size { + let g = *grad.add(row_start + i); + let w = *weight.add(i); + let pn = *pre_norm.add(row_start + i); + dot += g * w * pn; + } + + let coeff = dot * inv_rms / (hidden_size as f32 * (mean_sq + eps)); + let v_inv_rms = _mm512_set1_ps(inv_rms); + let v_coeff = _mm512_set1_ps(coeff); + + for c in 0..chunks { + let offset = row_start + c * F32_LANES; + let w_offset = c * F32_LANES; + let g = _mm512_loadu_ps(grad.add(offset)); + let w = _mm512_loadu_ps(weight.add(w_offset)); + let pn = _mm512_loadu_ps(pre_norm.add(offset)); + + let gw = _mm512_mul_ps(g, w); + let pn_coeff = _mm512_mul_ps(pn, v_coeff); + let diff = _mm512_sub_ps(gw, pn_coeff); + let d_ir = _mm512_mul_ps(diff, v_inv_rms); + _mm512_storeu_ps(d_input_residual.add(offset), d_ir); + + let dw_old = _mm512_loadu_ps(d_weight.add(w_offset)); + let gp = _mm512_mul_ps(g, pn); + let gp_inv = _mm512_mul_ps(gp, v_inv_rms); + let dw_new = _mm512_add_ps(dw_old, gp_inv); + _mm512_storeu_ps(d_weight.add(w_offset), dw_new); + } + + for i in (chunks * F32_LANES)..hidden_size { + let g = *grad.add(row_start + i); + let w = *weight.add(i); + let pn = *pre_norm.add(row_start + i); + + let d_ir = (g * w - pn * coeff) * inv_rms; + *d_input_residual.add(row_start + i) = d_ir; + + let d_w = g * pn * inv_rms; + *d_weight.add(i) += d_w; + } + } +} + +/// AVX-512 Fused Add + RMS Norm Backward for f64 +#[target_feature(enable = "avx512f")] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_rms_norm_bwd_f64( + grad: *const f64, + pre_norm: *const f64, + weight: *const f64, + d_input_residual: *mut f64, + d_weight: *mut f64, + batch_size: usize, + hidden_size: usize, + eps: f64, +) { + let chunks = hidden_size / F64_LANES; + + for batch in 0..batch_size { + let row_start = batch * hidden_size; + + let mut acc_sq0 = _mm512_setzero_pd(); + let mut acc_sq1 = _mm512_setzero_pd(); + let mut c = 0; + let chunk_pairs = chunks / 2 * 2; + while c < chunk_pairs { + let pn0 = _mm512_loadu_pd(pre_norm.add(row_start + c * F64_LANES)); + acc_sq0 = _mm512_fmadd_pd(pn0, pn0, acc_sq0); + let pn1 = _mm512_loadu_pd(pre_norm.add(row_start + (c + 1) * F64_LANES)); + acc_sq1 = _mm512_fmadd_pd(pn1, pn1, acc_sq1); + c += 2; + } + while c < chunks { + let pn = _mm512_loadu_pd(pre_norm.add(row_start + c * F64_LANES)); + acc_sq0 = _mm512_fmadd_pd(pn, pn, acc_sq0); + c += 1; + } + let mut sum_sq = _mm512_reduce_add_pd(_mm512_add_pd(acc_sq0, acc_sq1)); + + for i in (chunks * F64_LANES)..hidden_size { + let pn = *pre_norm.add(row_start + i); + sum_sq += pn * pn; + } + + let mean_sq = sum_sq / hidden_size as f64; + let inv_rms = 1.0 / (mean_sq + eps).sqrt(); + + let mut dot_acc = _mm512_setzero_pd(); + for c in 0..chunks { + let offset = row_start + c * F64_LANES; + let w_offset = c * F64_LANES; + let g = _mm512_loadu_pd(grad.add(offset)); + let w = _mm512_loadu_pd(weight.add(w_offset)); + let pn = _mm512_loadu_pd(pre_norm.add(offset)); + let gw = _mm512_mul_pd(g, w); + dot_acc = _mm512_fmadd_pd(gw, pn, dot_acc); + } + let mut dot = _mm512_reduce_add_pd(dot_acc); + + for i in (chunks * F64_LANES)..hidden_size { + let g = *grad.add(row_start + i); + let w = *weight.add(i); + let pn = *pre_norm.add(row_start + i); + dot += g * w * pn; + } + + let coeff = dot * inv_rms / (hidden_size as f64 * (mean_sq + eps)); + let v_inv_rms = _mm512_set1_pd(inv_rms); + let v_coeff = _mm512_set1_pd(coeff); + + for c in 0..chunks { + let offset = row_start + c * F64_LANES; + let w_offset = c * F64_LANES; + let g = _mm512_loadu_pd(grad.add(offset)); + let w = _mm512_loadu_pd(weight.add(w_offset)); + let pn = _mm512_loadu_pd(pre_norm.add(offset)); + + let gw = _mm512_mul_pd(g, w); + let pn_coeff = _mm512_mul_pd(pn, v_coeff); + let diff = _mm512_sub_pd(gw, pn_coeff); + let d_ir = _mm512_mul_pd(diff, v_inv_rms); + _mm512_storeu_pd(d_input_residual.add(offset), d_ir); + + let dw_old = _mm512_loadu_pd(d_weight.add(w_offset)); + let gp = _mm512_mul_pd(g, pn); + let gp_inv = _mm512_mul_pd(gp, v_inv_rms); + let dw_new = _mm512_add_pd(dw_old, gp_inv); + _mm512_storeu_pd(d_weight.add(w_offset), dw_new); + } + + for i in (chunks * F64_LANES)..hidden_size { + let g = *grad.add(row_start + i); + let w = *weight.add(i); + let pn = *pre_norm.add(row_start + i); + + let d_ir = (g * w - pn * coeff) * inv_rms; + *d_input_residual.add(row_start + i) = d_ir; + + let d_w = g * pn * inv_rms; + *d_weight.add(i) += d_w; + } + } +} diff --git a/src/runtime/cpu/kernels/simd/norm/avx512.rs b/src/runtime/cpu/kernels/simd/norm/avx512/layer_norm.rs similarity index 53% rename from src/runtime/cpu/kernels/simd/norm/avx512.rs rename to src/runtime/cpu/kernels/simd/norm/avx512/layer_norm.rs index 741435e0..1947433f 100644 --- a/src/runtime/cpu/kernels/simd/norm/avx512.rs +++ b/src/runtime/cpu/kernels/simd/norm/avx512/layer_norm.rs @@ -1,122 +1,9 @@ -//! AVX-512 normalization kernels -//! -//! SIMD-optimized RMS norm and layer norm using: -//! - Vertical FMA accumulation for sum of squares -//! - Horizontal reduction intrinsics -//! - Vectorized final normalization pass +//! AVX-512 layer normalization kernels #[cfg(target_arch = "x86_64")] use std::arch::x86_64::*; -use super::{ - layer_norm_scalar_f32, layer_norm_scalar_f64, rms_norm_scalar_f32, rms_norm_scalar_f64, -}; - -const F32_LANES: usize = 16; -const F64_LANES: usize = 8; - -/// AVX-512 RMS normalization for f32 -#[target_feature(enable = "avx512f")] -pub unsafe fn rms_norm_f32( - input: *const f32, - weight: *const f32, - out: *mut f32, - batch_size: usize, - hidden_size: usize, - eps: f32, -) { - let chunks = hidden_size / F32_LANES; - let remainder = hidden_size % F32_LANES; - - for batch in 0..batch_size { - let row_start = batch * hidden_size; - - // SIMD sum of squares - let mut acc = _mm512_setzero_ps(); - for c in 0..chunks { - let offset = row_start + c * F32_LANES; - let v = _mm512_loadu_ps(input.add(offset)); - acc = _mm512_fmadd_ps(v, v, acc); // acc += v * v - } - let mut sum_sq = _mm512_reduce_add_ps(acc); - - // Scalar tail for sum of squares - for i in (chunks * F32_LANES)..hidden_size { - let x = *input.add(row_start + i); - sum_sq += x * x; - } - - // Compute inverse RMS - let inv_rms = 1.0 / (sum_sq / hidden_size as f32 + eps).sqrt(); - let v_inv_rms = _mm512_set1_ps(inv_rms); - - // SIMD normalization with weight - for c in 0..chunks { - let offset = row_start + c * F32_LANES; - let w_offset = c * F32_LANES; - let v_input = _mm512_loadu_ps(input.add(offset)); - let v_weight = _mm512_loadu_ps(weight.add(w_offset)); - let v_result = _mm512_mul_ps(_mm512_mul_ps(v_input, v_inv_rms), v_weight); - _mm512_storeu_ps(out.add(offset), v_result); - } - - // Scalar tail for normalization - for i in (chunks * F32_LANES)..hidden_size { - let x = *input.add(row_start + i); - let w = *weight.add(i); - *out.add(row_start + i) = x * inv_rms * w; - } - let _ = remainder; - } -} - -/// AVX-512 RMS normalization for f64 -#[target_feature(enable = "avx512f")] -pub unsafe fn rms_norm_f64( - input: *const f64, - weight: *const f64, - out: *mut f64, - batch_size: usize, - hidden_size: usize, - eps: f64, -) { - let chunks = hidden_size / F64_LANES; - - for batch in 0..batch_size { - let row_start = batch * hidden_size; - - let mut acc = _mm512_setzero_pd(); - for c in 0..chunks { - let offset = row_start + c * F64_LANES; - let v = _mm512_loadu_pd(input.add(offset)); - acc = _mm512_fmadd_pd(v, v, acc); - } - let mut sum_sq = _mm512_reduce_add_pd(acc); - - for i in (chunks * F64_LANES)..hidden_size { - let x = *input.add(row_start + i); - sum_sq += x * x; - } - - let inv_rms = 1.0 / (sum_sq / hidden_size as f64 + eps).sqrt(); - let v_inv_rms = _mm512_set1_pd(inv_rms); - - for c in 0..chunks { - let offset = row_start + c * F64_LANES; - let w_offset = c * F64_LANES; - let v_input = _mm512_loadu_pd(input.add(offset)); - let v_weight = _mm512_loadu_pd(weight.add(w_offset)); - let v_result = _mm512_mul_pd(_mm512_mul_pd(v_input, v_inv_rms), v_weight); - _mm512_storeu_pd(out.add(offset), v_result); - } - - for i in (chunks * F64_LANES)..hidden_size { - let x = *input.add(row_start + i); - let w = *weight.add(i); - *out.add(row_start + i) = x * inv_rms * w; - } - } -} +use super::{F32_LANES, F64_LANES}; /// AVX-512 Layer normalization for f32 #[target_feature(enable = "avx512f")] @@ -256,13 +143,3 @@ pub unsafe fn layer_norm_f64( } } } - -// Suppress unused warnings for scalar fallback imports used in dispatch -const _: () = { - let _ = rms_norm_scalar_f32 as unsafe fn(*const f32, *const f32, *mut f32, usize, usize, f32); - let _ = rms_norm_scalar_f64 as unsafe fn(*const f64, *const f64, *mut f64, usize, usize, f64); - let _ = layer_norm_scalar_f32 - as unsafe fn(*const f32, *const f32, *const f32, *mut f32, usize, usize, f32); - let _ = layer_norm_scalar_f64 - as unsafe fn(*const f64, *const f64, *const f64, *mut f64, usize, usize, f64); -}; diff --git a/src/runtime/cpu/kernels/simd/norm/avx512/mod.rs b/src/runtime/cpu/kernels/simd/norm/avx512/mod.rs new file mode 100644 index 00000000..5f148c15 --- /dev/null +++ b/src/runtime/cpu/kernels/simd/norm/avx512/mod.rs @@ -0,0 +1,25 @@ +//! AVX-512 normalization kernels +//! +//! SIMD-optimized RMS norm and layer norm using: +//! - Vertical FMA accumulation for sum of squares +//! - Horizontal reduction intrinsics +//! - Vectorized final normalization pass + +pub(super) const F32_LANES: usize = 16; +pub(super) const F64_LANES: usize = 8; + +mod fused_add_layer_norm; +mod fused_add_rms_norm; +mod layer_norm; +mod rms_norm; + +pub use fused_add_layer_norm::{ + fused_add_layer_norm_bwd_f32, fused_add_layer_norm_bwd_f64, fused_add_layer_norm_f32, + fused_add_layer_norm_f64, +}; +pub use fused_add_rms_norm::{ + fused_add_rms_norm_bwd_f32, fused_add_rms_norm_bwd_f64, fused_add_rms_norm_f32, + fused_add_rms_norm_f64, +}; +pub use layer_norm::{layer_norm_f32, layer_norm_f64}; +pub use rms_norm::{rms_norm_f32, rms_norm_f64}; diff --git a/src/runtime/cpu/kernels/simd/norm/avx512/rms_norm.rs b/src/runtime/cpu/kernels/simd/norm/avx512/rms_norm.rs new file mode 100644 index 00000000..9c284458 --- /dev/null +++ b/src/runtime/cpu/kernels/simd/norm/avx512/rms_norm.rs @@ -0,0 +1,109 @@ +//! AVX-512 RMS normalization kernels + +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +use super::{F32_LANES, F64_LANES}; + +/// AVX-512 RMS normalization for f32 +#[target_feature(enable = "avx512f")] +pub unsafe fn rms_norm_f32( + input: *const f32, + weight: *const f32, + out: *mut f32, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let chunks = hidden_size / F32_LANES; + let remainder = hidden_size % F32_LANES; + + for batch in 0..batch_size { + let row_start = batch * hidden_size; + + // SIMD sum of squares + let mut acc = _mm512_setzero_ps(); + for c in 0..chunks { + let offset = row_start + c * F32_LANES; + let v = _mm512_loadu_ps(input.add(offset)); + acc = _mm512_fmadd_ps(v, v, acc); + } + let mut sum_sq = _mm512_reduce_add_ps(acc) as f64; + + // Scalar tail for sum of squares + for i in (chunks * F32_LANES)..hidden_size { + let x = *input.add(row_start + i) as f64; + sum_sq += x * x; + } + + // Compute inverse RMS in f64 for precision (matches llama.cpp) + let inv_rms = (1.0f64 / (sum_sq / hidden_size as f64 + eps as f64).sqrt()) as f32; + let v_inv_rms = _mm512_set1_ps(inv_rms); + + // SIMD normalization with weight + for c in 0..chunks { + let offset = row_start + c * F32_LANES; + let w_offset = c * F32_LANES; + let v_input = _mm512_loadu_ps(input.add(offset)); + let v_weight = _mm512_loadu_ps(weight.add(w_offset)); + let v_result = _mm512_mul_ps(_mm512_mul_ps(v_input, v_inv_rms), v_weight); + _mm512_storeu_ps(out.add(offset), v_result); + } + + // Scalar tail for normalization + for i in (chunks * F32_LANES)..hidden_size { + let x = *input.add(row_start + i); + let w = *weight.add(i); + *out.add(row_start + i) = x * inv_rms * w; + } + let _ = remainder; + } +} + +/// AVX-512 RMS normalization for f64 +#[target_feature(enable = "avx512f")] +pub unsafe fn rms_norm_f64( + input: *const f64, + weight: *const f64, + out: *mut f64, + batch_size: usize, + hidden_size: usize, + eps: f64, +) { + let chunks = hidden_size / F64_LANES; + + for batch in 0..batch_size { + let row_start = batch * hidden_size; + + let mut acc = _mm512_setzero_pd(); + for c in 0..chunks { + let offset = row_start + c * F64_LANES; + let v = _mm512_loadu_pd(input.add(offset)); + acc = _mm512_fmadd_pd(v, v, acc); + } + let mut sum_sq = _mm512_reduce_add_pd(acc); + + for i in (chunks * F64_LANES)..hidden_size { + let x = *input.add(row_start + i); + sum_sq += x * x; + } + + let inv_rms = 1.0 / (sum_sq / hidden_size as f64 + eps).sqrt(); + let v_inv_rms = _mm512_set1_pd(inv_rms); + + for c in 0..chunks { + let offset = row_start + c * F64_LANES; + let w_offset = c * F64_LANES; + let v_input = _mm512_loadu_pd(input.add(offset)); + let v_weight = _mm512_loadu_pd(weight.add(w_offset)); + let v_result = _mm512_mul_pd(_mm512_mul_pd(v_input, v_inv_rms), v_weight); + _mm512_storeu_pd(out.add(offset), v_result); + } + + for i in (chunks * F64_LANES)..hidden_size { + let x = *input.add(row_start + i); + let w = *weight.add(i); + *out.add(row_start + i) = x * inv_rms * w; + } + } +} diff --git a/src/runtime/cpu/kernels/simd/norm/fused_add_layer_norm.rs b/src/runtime/cpu/kernels/simd/norm/fused_add_layer_norm.rs new file mode 100644 index 00000000..2e4a733b --- /dev/null +++ b/src/runtime/cpu/kernels/simd/norm/fused_add_layer_norm.rs @@ -0,0 +1,649 @@ +//! SIMD dispatch and scalar fallbacks for fused Add + Layer normalization + +use super::super::{SimdLevel, detect_simd}; +use super::SIMD_THRESHOLD; + +#[cfg(target_arch = "x86_64")] +use super::avx2; +#[cfg(target_arch = "x86_64")] +use super::avx512; + +#[cfg(target_arch = "aarch64")] +use super::aarch64; + +// ============================================================================ +// Fused Add + Layer Norm (forward) +// ============================================================================ + +/// SIMD Fused Add + Layer Normalization for f32 +#[inline] +pub unsafe fn fused_add_layer_norm_f32( + input: *const f32, + residual: *const f32, + weight: *const f32, + bias: *const f32, + out: *mut f32, + pre_norm: *mut f32, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let level = detect_simd(); + + if hidden_size < SIMD_THRESHOLD || level == SimdLevel::Scalar { + fused_add_layer_norm_scalar_f32( + input, + residual, + weight, + bias, + out, + pre_norm, + batch_size, + hidden_size, + eps, + ); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => avx512::fused_add_layer_norm_f32( + input, + residual, + weight, + bias, + out, + pre_norm, + batch_size, + hidden_size, + eps, + ), + SimdLevel::Avx2Fma => avx2::fused_add_layer_norm_f32( + input, + residual, + weight, + bias, + out, + pre_norm, + batch_size, + hidden_size, + eps, + ), + _ => fused_add_layer_norm_scalar_f32( + input, + residual, + weight, + bias, + out, + pre_norm, + batch_size, + hidden_size, + eps, + ), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::fused_add_layer_norm_f32( + input, + residual, + weight, + bias, + out, + pre_norm, + batch_size, + hidden_size, + eps, + ), + _ => fused_add_layer_norm_scalar_f32( + input, + residual, + weight, + bias, + out, + pre_norm, + batch_size, + hidden_size, + eps, + ), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + fused_add_layer_norm_scalar_f32( + input, + residual, + weight, + bias, + out, + pre_norm, + batch_size, + hidden_size, + eps, + ); +} + +/// SIMD Fused Add + Layer Normalization for f64 +#[inline] +pub unsafe fn fused_add_layer_norm_f64( + input: *const f64, + residual: *const f64, + weight: *const f64, + bias: *const f64, + out: *mut f64, + pre_norm: *mut f64, + batch_size: usize, + hidden_size: usize, + eps: f64, +) { + let level = detect_simd(); + + if hidden_size < SIMD_THRESHOLD || level == SimdLevel::Scalar { + fused_add_layer_norm_scalar_f64( + input, + residual, + weight, + bias, + out, + pre_norm, + batch_size, + hidden_size, + eps, + ); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => avx512::fused_add_layer_norm_f64( + input, + residual, + weight, + bias, + out, + pre_norm, + batch_size, + hidden_size, + eps, + ), + SimdLevel::Avx2Fma => avx2::fused_add_layer_norm_f64( + input, + residual, + weight, + bias, + out, + pre_norm, + batch_size, + hidden_size, + eps, + ), + _ => fused_add_layer_norm_scalar_f64( + input, + residual, + weight, + bias, + out, + pre_norm, + batch_size, + hidden_size, + eps, + ), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::fused_add_layer_norm_f64( + input, + residual, + weight, + bias, + out, + pre_norm, + batch_size, + hidden_size, + eps, + ), + _ => fused_add_layer_norm_scalar_f64( + input, + residual, + weight, + bias, + out, + pre_norm, + batch_size, + hidden_size, + eps, + ), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + fused_add_layer_norm_scalar_f64( + input, + residual, + weight, + bias, + out, + pre_norm, + batch_size, + hidden_size, + eps, + ); +} + +// ============================================================================ +// Fused Add + Layer Norm (backward) +// ============================================================================ + +/// SIMD Fused Add + Layer Norm Backward for f32 +#[inline] +pub unsafe fn fused_add_layer_norm_bwd_f32( + grad: *const f32, + pre_norm: *const f32, + weight: *const f32, + d_input_residual: *mut f32, + d_weight: *mut f32, + d_bias: *mut f32, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let level = detect_simd(); + + if hidden_size < SIMD_THRESHOLD || level == SimdLevel::Scalar { + fused_add_layer_norm_bwd_scalar_f32( + grad, + pre_norm, + weight, + d_input_residual, + d_weight, + d_bias, + batch_size, + hidden_size, + eps, + ); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => avx512::fused_add_layer_norm_bwd_f32( + grad, + pre_norm, + weight, + d_input_residual, + d_weight, + d_bias, + batch_size, + hidden_size, + eps, + ), + SimdLevel::Avx2Fma => avx2::fused_add_layer_norm_bwd_f32( + grad, + pre_norm, + weight, + d_input_residual, + d_weight, + d_bias, + batch_size, + hidden_size, + eps, + ), + _ => fused_add_layer_norm_bwd_scalar_f32( + grad, + pre_norm, + weight, + d_input_residual, + d_weight, + d_bias, + batch_size, + hidden_size, + eps, + ), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::fused_add_layer_norm_bwd_f32( + grad, + pre_norm, + weight, + d_input_residual, + d_weight, + d_bias, + batch_size, + hidden_size, + eps, + ), + _ => fused_add_layer_norm_bwd_scalar_f32( + grad, + pre_norm, + weight, + d_input_residual, + d_weight, + d_bias, + batch_size, + hidden_size, + eps, + ), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + fused_add_layer_norm_bwd_scalar_f32( + grad, + pre_norm, + weight, + d_input_residual, + d_weight, + d_bias, + batch_size, + hidden_size, + eps, + ); +} + +/// SIMD Fused Add + Layer Norm Backward for f64 +#[inline] +pub unsafe fn fused_add_layer_norm_bwd_f64( + grad: *const f64, + pre_norm: *const f64, + weight: *const f64, + d_input_residual: *mut f64, + d_weight: *mut f64, + d_bias: *mut f64, + batch_size: usize, + hidden_size: usize, + eps: f64, +) { + let level = detect_simd(); + + if hidden_size < SIMD_THRESHOLD || level == SimdLevel::Scalar { + fused_add_layer_norm_bwd_scalar_f64( + grad, + pre_norm, + weight, + d_input_residual, + d_weight, + d_bias, + batch_size, + hidden_size, + eps, + ); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => avx512::fused_add_layer_norm_bwd_f64( + grad, + pre_norm, + weight, + d_input_residual, + d_weight, + d_bias, + batch_size, + hidden_size, + eps, + ), + SimdLevel::Avx2Fma => avx2::fused_add_layer_norm_bwd_f64( + grad, + pre_norm, + weight, + d_input_residual, + d_weight, + d_bias, + batch_size, + hidden_size, + eps, + ), + _ => fused_add_layer_norm_bwd_scalar_f64( + grad, + pre_norm, + weight, + d_input_residual, + d_weight, + d_bias, + batch_size, + hidden_size, + eps, + ), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::fused_add_layer_norm_bwd_f64( + grad, + pre_norm, + weight, + d_input_residual, + d_weight, + d_bias, + batch_size, + hidden_size, + eps, + ), + _ => fused_add_layer_norm_bwd_scalar_f64( + grad, + pre_norm, + weight, + d_input_residual, + d_weight, + d_bias, + batch_size, + hidden_size, + eps, + ), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + fused_add_layer_norm_bwd_scalar_f64( + grad, + pre_norm, + weight, + d_input_residual, + d_weight, + d_bias, + batch_size, + hidden_size, + eps, + ); +} + +// ============================================================================ +// Scalar fallbacks for fused add + layer norm +// ============================================================================ + +/// Scalar fused add + layer norm for f32 +#[inline] +pub unsafe fn fused_add_layer_norm_scalar_f32( + input: *const f32, + residual: *const f32, + weight: *const f32, + bias: *const f32, + out: *mut f32, + pre_norm: *mut f32, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + for batch in 0..batch_size { + let row_start = batch * hidden_size; + + // Add and store pre_norm, compute mean + let mut sum = 0.0f32; + for i in 0..hidden_size { + let pn = *input.add(row_start + i) + *residual.add(row_start + i); + *pre_norm.add(row_start + i) = pn; + sum += pn; + } + let mean = sum / hidden_size as f32; + + // Compute variance + let mut var_sum = 0.0f32; + for i in 0..hidden_size { + let diff = *pre_norm.add(row_start + i) - mean; + var_sum += diff * diff; + } + let inv_std = 1.0 / (var_sum / hidden_size as f32 + eps).sqrt(); + + // Normalize, apply weight and bias + for i in 0..hidden_size { + let pn = *pre_norm.add(row_start + i); + let w = *weight.add(i); + let b = *bias.add(i); + *out.add(row_start + i) = (pn - mean) * inv_std * w + b; + } + } +} + +/// Scalar fused add + layer norm for f64 +#[inline] +pub unsafe fn fused_add_layer_norm_scalar_f64( + input: *const f64, + residual: *const f64, + weight: *const f64, + bias: *const f64, + out: *mut f64, + pre_norm: *mut f64, + batch_size: usize, + hidden_size: usize, + eps: f64, +) { + for batch in 0..batch_size { + let row_start = batch * hidden_size; + + let mut sum = 0.0f64; + for i in 0..hidden_size { + let pn = *input.add(row_start + i) + *residual.add(row_start + i); + *pre_norm.add(row_start + i) = pn; + sum += pn; + } + let mean = sum / hidden_size as f64; + + let mut var_sum = 0.0f64; + for i in 0..hidden_size { + let diff = *pre_norm.add(row_start + i) - mean; + var_sum += diff * diff; + } + let inv_std = 1.0 / (var_sum / hidden_size as f64 + eps).sqrt(); + + for i in 0..hidden_size { + let pn = *pre_norm.add(row_start + i); + let w = *weight.add(i); + let b = *bias.add(i); + *out.add(row_start + i) = (pn - mean) * inv_std * w + b; + } + } +} + +/// Scalar fused add + layer norm backward for f32 +#[inline] +pub unsafe fn fused_add_layer_norm_bwd_scalar_f32( + grad: *const f32, + pre_norm: *const f32, + weight: *const f32, + d_input_residual: *mut f32, + d_weight: *mut f32, + d_bias: *mut f32, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + for batch in 0..batch_size { + let row_start = batch * hidden_size; + + let mut sum = 0.0f32; + for i in 0..hidden_size { + sum += *pre_norm.add(row_start + i); + } + let mean = sum / hidden_size as f32; + + let mut var_sum = 0.0f32; + for i in 0..hidden_size { + let diff = *pre_norm.add(row_start + i) - mean; + var_sum += diff * diff; + } + let inv_std = 1.0 / (var_sum / hidden_size as f32 + eps).sqrt(); + + let mut mean_gs = 0.0f32; + let mut mean_gs_n = 0.0f32; + for i in 0..hidden_size { + let g = *grad.add(row_start + i); + let w = *weight.add(i); + let pn = *pre_norm.add(row_start + i); + let gs = g * w; + mean_gs += gs; + mean_gs_n += gs * (pn - mean) * inv_std; + } + mean_gs /= hidden_size as f32; + mean_gs_n /= hidden_size as f32; + + for i in 0..hidden_size { + let g = *grad.add(row_start + i); + let w = *weight.add(i); + let pn = *pre_norm.add(row_start + i); + let normalized = (pn - mean) * inv_std; + let gs = g * w; + let d_ir = inv_std * (gs - mean_gs - normalized * mean_gs_n); + *d_input_residual.add(row_start + i) = d_ir; + + *d_weight.add(i) += g * normalized; + *d_bias.add(i) += g; + } + } +} + +/// Scalar fused add + layer norm backward for f64 +#[inline] +pub unsafe fn fused_add_layer_norm_bwd_scalar_f64( + grad: *const f64, + pre_norm: *const f64, + weight: *const f64, + d_input_residual: *mut f64, + d_weight: *mut f64, + d_bias: *mut f64, + batch_size: usize, + hidden_size: usize, + eps: f64, +) { + for batch in 0..batch_size { + let row_start = batch * hidden_size; + + let mut sum = 0.0f64; + for i in 0..hidden_size { + sum += *pre_norm.add(row_start + i); + } + let mean = sum / hidden_size as f64; + + let mut var_sum = 0.0f64; + for i in 0..hidden_size { + let diff = *pre_norm.add(row_start + i) - mean; + var_sum += diff * diff; + } + let inv_std = 1.0 / (var_sum / hidden_size as f64 + eps).sqrt(); + + let mut mean_gs = 0.0f64; + let mut mean_gs_n = 0.0f64; + for i in 0..hidden_size { + let g = *grad.add(row_start + i); + let w = *weight.add(i); + let pn = *pre_norm.add(row_start + i); + let gs = g * w; + mean_gs += gs; + mean_gs_n += gs * (pn - mean) * inv_std; + } + mean_gs /= hidden_size as f64; + mean_gs_n /= hidden_size as f64; + + for i in 0..hidden_size { + let g = *grad.add(row_start + i); + let w = *weight.add(i); + let pn = *pre_norm.add(row_start + i); + let normalized = (pn - mean) * inv_std; + let gs = g * w; + let d_ir = inv_std * (gs - mean_gs - normalized * mean_gs_n); + *d_input_residual.add(row_start + i) = d_ir; + + *d_weight.add(i) += g * normalized; + *d_bias.add(i) += g; + } + } +} diff --git a/src/runtime/cpu/kernels/simd/norm/fused_add_rms_norm.rs b/src/runtime/cpu/kernels/simd/norm/fused_add_rms_norm.rs new file mode 100644 index 00000000..0adac7b2 --- /dev/null +++ b/src/runtime/cpu/kernels/simd/norm/fused_add_rms_norm.rs @@ -0,0 +1,582 @@ +//! SIMD dispatch and scalar fallbacks for fused Add + RMS normalization + +use super::super::{SimdLevel, detect_simd}; +use super::SIMD_THRESHOLD; + +#[cfg(target_arch = "x86_64")] +use super::avx2; +#[cfg(target_arch = "x86_64")] +use super::avx512; + +#[cfg(target_arch = "aarch64")] +use super::aarch64; + +// ============================================================================ +// Fused Add + RMS Norm (forward) +// ============================================================================ + +/// SIMD Fused Add + RMS normalization for f32 +#[inline] +pub unsafe fn fused_add_rms_norm_f32( + input: *const f32, + residual: *const f32, + weight: *const f32, + out: *mut f32, + pre_norm: *mut f32, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let level = detect_simd(); + + if hidden_size < SIMD_THRESHOLD || level == SimdLevel::Scalar { + fused_add_rms_norm_scalar_f32( + input, + residual, + weight, + out, + pre_norm, + batch_size, + hidden_size, + eps, + ); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => avx512::fused_add_rms_norm_f32( + input, + residual, + weight, + out, + pre_norm, + batch_size, + hidden_size, + eps, + ), + SimdLevel::Avx2Fma => avx2::fused_add_rms_norm_f32( + input, + residual, + weight, + out, + pre_norm, + batch_size, + hidden_size, + eps, + ), + _ => fused_add_rms_norm_scalar_f32( + input, + residual, + weight, + out, + pre_norm, + batch_size, + hidden_size, + eps, + ), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::fused_add_rms_norm_f32( + input, + residual, + weight, + out, + pre_norm, + batch_size, + hidden_size, + eps, + ), + _ => fused_add_rms_norm_scalar_f32( + input, + residual, + weight, + out, + pre_norm, + batch_size, + hidden_size, + eps, + ), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + fused_add_rms_norm_scalar_f32( + input, + residual, + weight, + out, + pre_norm, + batch_size, + hidden_size, + eps, + ); +} + +/// SIMD Fused Add + RMS normalization for f64 +#[inline] +pub unsafe fn fused_add_rms_norm_f64( + input: *const f64, + residual: *const f64, + weight: *const f64, + out: *mut f64, + pre_norm: *mut f64, + batch_size: usize, + hidden_size: usize, + eps: f64, +) { + let level = detect_simd(); + + if hidden_size < SIMD_THRESHOLD || level == SimdLevel::Scalar { + fused_add_rms_norm_scalar_f64( + input, + residual, + weight, + out, + pre_norm, + batch_size, + hidden_size, + eps, + ); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => avx512::fused_add_rms_norm_f64( + input, + residual, + weight, + out, + pre_norm, + batch_size, + hidden_size, + eps, + ), + SimdLevel::Avx2Fma => avx2::fused_add_rms_norm_f64( + input, + residual, + weight, + out, + pre_norm, + batch_size, + hidden_size, + eps, + ), + _ => fused_add_rms_norm_scalar_f64( + input, + residual, + weight, + out, + pre_norm, + batch_size, + hidden_size, + eps, + ), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::fused_add_rms_norm_f64( + input, + residual, + weight, + out, + pre_norm, + batch_size, + hidden_size, + eps, + ), + _ => fused_add_rms_norm_scalar_f64( + input, + residual, + weight, + out, + pre_norm, + batch_size, + hidden_size, + eps, + ), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + fused_add_rms_norm_scalar_f64( + input, + residual, + weight, + out, + pre_norm, + batch_size, + hidden_size, + eps, + ); +} + +// ============================================================================ +// Fused Add + RMS Norm (backward) +// ============================================================================ + +/// SIMD Fused Add + RMS Norm Backward for f32 +#[inline] +pub unsafe fn fused_add_rms_norm_bwd_f32( + grad: *const f32, + pre_norm: *const f32, + weight: *const f32, + d_input_residual: *mut f32, + d_weight: *mut f32, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let level = detect_simd(); + + if hidden_size < SIMD_THRESHOLD || level == SimdLevel::Scalar { + fused_add_rms_norm_bwd_scalar_f32( + grad, + pre_norm, + weight, + d_input_residual, + d_weight, + batch_size, + hidden_size, + eps, + ); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => avx512::fused_add_rms_norm_bwd_f32( + grad, + pre_norm, + weight, + d_input_residual, + d_weight, + batch_size, + hidden_size, + eps, + ), + SimdLevel::Avx2Fma => avx2::fused_add_rms_norm_bwd_f32( + grad, + pre_norm, + weight, + d_input_residual, + d_weight, + batch_size, + hidden_size, + eps, + ), + _ => fused_add_rms_norm_bwd_scalar_f32( + grad, + pre_norm, + weight, + d_input_residual, + d_weight, + batch_size, + hidden_size, + eps, + ), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::fused_add_rms_norm_bwd_f32( + grad, + pre_norm, + weight, + d_input_residual, + d_weight, + batch_size, + hidden_size, + eps, + ), + _ => fused_add_rms_norm_bwd_scalar_f32( + grad, + pre_norm, + weight, + d_input_residual, + d_weight, + batch_size, + hidden_size, + eps, + ), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + fused_add_rms_norm_bwd_scalar_f32( + grad, + pre_norm, + weight, + d_input_residual, + d_weight, + batch_size, + hidden_size, + eps, + ); +} + +/// SIMD Fused Add + RMS Norm Backward for f64 +#[inline] +pub unsafe fn fused_add_rms_norm_bwd_f64( + grad: *const f64, + pre_norm: *const f64, + weight: *const f64, + d_input_residual: *mut f64, + d_weight: *mut f64, + batch_size: usize, + hidden_size: usize, + eps: f64, +) { + let level = detect_simd(); + + if hidden_size < SIMD_THRESHOLD || level == SimdLevel::Scalar { + fused_add_rms_norm_bwd_scalar_f64( + grad, + pre_norm, + weight, + d_input_residual, + d_weight, + batch_size, + hidden_size, + eps, + ); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => avx512::fused_add_rms_norm_bwd_f64( + grad, + pre_norm, + weight, + d_input_residual, + d_weight, + batch_size, + hidden_size, + eps, + ), + SimdLevel::Avx2Fma => avx2::fused_add_rms_norm_bwd_f64( + grad, + pre_norm, + weight, + d_input_residual, + d_weight, + batch_size, + hidden_size, + eps, + ), + _ => fused_add_rms_norm_bwd_scalar_f64( + grad, + pre_norm, + weight, + d_input_residual, + d_weight, + batch_size, + hidden_size, + eps, + ), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::fused_add_rms_norm_bwd_f64( + grad, + pre_norm, + weight, + d_input_residual, + d_weight, + batch_size, + hidden_size, + eps, + ), + _ => fused_add_rms_norm_bwd_scalar_f64( + grad, + pre_norm, + weight, + d_input_residual, + d_weight, + batch_size, + hidden_size, + eps, + ), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + fused_add_rms_norm_bwd_scalar_f64( + grad, + pre_norm, + weight, + d_input_residual, + d_weight, + batch_size, + hidden_size, + eps, + ); +} + +// ============================================================================ +// Scalar fallbacks for fused add + RMS norm +// ============================================================================ + +/// Scalar fused add + RMS norm for f32 +#[inline] +pub unsafe fn fused_add_rms_norm_scalar_f32( + input: *const f32, + residual: *const f32, + weight: *const f32, + out: *mut f32, + pre_norm: *mut f32, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + for batch in 0..batch_size { + let row_start = batch * hidden_size; + + // Add and store pre_norm, compute sum of squares in f64 (matches llama.cpp) + let mut sum_sq = 0.0f64; + for i in 0..hidden_size { + let pn = *input.add(row_start + i) + *residual.add(row_start + i); + *pre_norm.add(row_start + i) = pn; + let pn64 = pn as f64; + sum_sq += pn64 * pn64; + } + + let inv_rms = (1.0f64 / (sum_sq / hidden_size as f64 + eps as f64).sqrt()) as f32; + + for i in 0..hidden_size { + let pn = *pre_norm.add(row_start + i); + let w = *weight.add(i); + *out.add(row_start + i) = pn * inv_rms * w; + } + } +} + +/// Scalar fused add + RMS norm for f64 +#[inline] +pub unsafe fn fused_add_rms_norm_scalar_f64( + input: *const f64, + residual: *const f64, + weight: *const f64, + out: *mut f64, + pre_norm: *mut f64, + batch_size: usize, + hidden_size: usize, + eps: f64, +) { + for batch in 0..batch_size { + let row_start = batch * hidden_size; + + let mut sum_sq = 0.0f64; + for i in 0..hidden_size { + let pn = *input.add(row_start + i) + *residual.add(row_start + i); + *pre_norm.add(row_start + i) = pn; + sum_sq += pn * pn; + } + + let inv_rms = 1.0 / (sum_sq / hidden_size as f64 + eps).sqrt(); + + for i in 0..hidden_size { + let pn = *pre_norm.add(row_start + i); + let w = *weight.add(i); + *out.add(row_start + i) = pn * inv_rms * w; + } + } +} + +/// Scalar fused add + RMS norm backward for f32 +#[inline] +pub unsafe fn fused_add_rms_norm_bwd_scalar_f32( + grad: *const f32, + pre_norm: *const f32, + weight: *const f32, + d_input_residual: *mut f32, + d_weight: *mut f32, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + for batch in 0..batch_size { + let row_start = batch * hidden_size; + + let mut sum_sq = 0.0f32; + for i in 0..hidden_size { + let pn = *pre_norm.add(row_start + i); + sum_sq += pn * pn; + } + + let mean_sq = sum_sq / hidden_size as f32; + let inv_rms = 1.0 / (mean_sq + eps).sqrt(); + + let mut dot = 0.0f32; + for i in 0..hidden_size { + let g = *grad.add(row_start + i); + let w = *weight.add(i); + let pn = *pre_norm.add(row_start + i); + dot += g * w * pn; + } + + let coeff = dot * inv_rms / (hidden_size as f32 * (mean_sq + eps)); + + for i in 0..hidden_size { + let g = *grad.add(row_start + i); + let w = *weight.add(i); + let pn = *pre_norm.add(row_start + i); + + let d_ir = (g * w - pn * coeff) * inv_rms; + *d_input_residual.add(row_start + i) = d_ir; + + let d_w = g * pn * inv_rms; + *d_weight.add(i) += d_w; + } + } +} + +/// Scalar fused add + RMS norm backward for f64 +#[inline] +pub unsafe fn fused_add_rms_norm_bwd_scalar_f64( + grad: *const f64, + pre_norm: *const f64, + weight: *const f64, + d_input_residual: *mut f64, + d_weight: *mut f64, + batch_size: usize, + hidden_size: usize, + eps: f64, +) { + for batch in 0..batch_size { + let row_start = batch * hidden_size; + + let mut sum_sq = 0.0f64; + for i in 0..hidden_size { + let pn = *pre_norm.add(row_start + i); + sum_sq += pn * pn; + } + + let mean_sq = sum_sq / hidden_size as f64; + let inv_rms = 1.0 / (mean_sq + eps).sqrt(); + + let mut dot = 0.0f64; + for i in 0..hidden_size { + let g = *grad.add(row_start + i); + let w = *weight.add(i); + let pn = *pre_norm.add(row_start + i); + dot += g * w * pn; + } + + let coeff = dot * inv_rms / (hidden_size as f64 * (mean_sq + eps)); + + for i in 0..hidden_size { + let g = *grad.add(row_start + i); + let w = *weight.add(i); + let pn = *pre_norm.add(row_start + i); + + let d_ir = (g * w - pn * coeff) * inv_rms; + *d_input_residual.add(row_start + i) = d_ir; + + let d_w = g * pn * inv_rms; + *d_weight.add(i) += d_w; + } + } +} diff --git a/src/runtime/cpu/kernels/simd/norm/half.rs b/src/runtime/cpu/kernels/simd/norm/half.rs new file mode 100644 index 00000000..ab4f21a0 --- /dev/null +++ b/src/runtime/cpu/kernels/simd/norm/half.rs @@ -0,0 +1,486 @@ +//! f16/bf16 normalization wrappers via bulk f32 conversion +//! +//! Pre-converts all inputs to f32 using a single allocation, runs the f32 SIMD +//! norm kernel, then converts the output back. + +use super::super::half_convert_utils::*; + +/// f16 wrapper for RMS norm. +/// +/// # Safety +/// - `input` and `out` must point to `batch_size * hidden_size` elements +/// - `weight` must point to `hidden_size` elements +#[cfg(feature = "f16")] +pub unsafe fn rms_norm_f16( + input: *const half::f16, + weight: *const half::f16, + out: *mut half::f16, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let total = batch_size * hidden_size; + let mut buf = vec![0.0f32; total + hidden_size + total]; + let (input_f32, rest) = buf.split_at_mut(total); + let (weight_f32, out_f32) = rest.split_at_mut(hidden_size); + convert_f16_to_f32(input as *const u16, input_f32.as_mut_ptr(), total); + convert_f16_to_f32(weight as *const u16, weight_f32.as_mut_ptr(), hidden_size); + super::rms_norm_f32( + input_f32.as_ptr(), + weight_f32.as_ptr(), + out_f32.as_mut_ptr(), + batch_size, + hidden_size, + eps, + ); + convert_f32_to_f16(out_f32.as_ptr(), out as *mut u16, total); +} + +/// bf16 wrapper for RMS norm. +/// +/// # Safety +/// - `input` and `out` must point to `batch_size * hidden_size` elements +/// - `weight` must point to `hidden_size` elements +#[cfg(feature = "f16")] +pub unsafe fn rms_norm_bf16( + input: *const half::bf16, + weight: *const half::bf16, + out: *mut half::bf16, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let total = batch_size * hidden_size; + let mut buf = vec![0.0f32; total + hidden_size + total]; + let (input_f32, rest) = buf.split_at_mut(total); + let (weight_f32, out_f32) = rest.split_at_mut(hidden_size); + convert_bf16_to_f32(input as *const u16, input_f32.as_mut_ptr(), total); + convert_bf16_to_f32(weight as *const u16, weight_f32.as_mut_ptr(), hidden_size); + super::rms_norm_f32( + input_f32.as_ptr(), + weight_f32.as_ptr(), + out_f32.as_mut_ptr(), + batch_size, + hidden_size, + eps, + ); + convert_f32_to_bf16(out_f32.as_ptr(), out as *mut u16, total); +} + +/// f16 wrapper for layer norm. +/// +/// # Safety +/// - `input` and `out` must point to `batch_size * hidden_size` elements +/// - `weight` and `bias` must point to `hidden_size` elements +#[cfg(feature = "f16")] +pub unsafe fn layer_norm_f16( + input: *const half::f16, + weight: *const half::f16, + bias: *const half::f16, + out: *mut half::f16, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let total = batch_size * hidden_size; + let mut buf = vec![0.0f32; total + hidden_size + hidden_size + total]; + let (input_f32, rest) = buf.split_at_mut(total); + let (weight_f32, rest) = rest.split_at_mut(hidden_size); + let (bias_f32, out_f32) = rest.split_at_mut(hidden_size); + convert_f16_to_f32(input as *const u16, input_f32.as_mut_ptr(), total); + convert_f16_to_f32(weight as *const u16, weight_f32.as_mut_ptr(), hidden_size); + convert_f16_to_f32(bias as *const u16, bias_f32.as_mut_ptr(), hidden_size); + super::layer_norm_f32( + input_f32.as_ptr(), + weight_f32.as_ptr(), + bias_f32.as_ptr(), + out_f32.as_mut_ptr(), + batch_size, + hidden_size, + eps, + ); + convert_f32_to_f16(out_f32.as_ptr(), out as *mut u16, total); +} + +/// bf16 wrapper for layer norm. +/// +/// # Safety +/// - `input` and `out` must point to `batch_size * hidden_size` elements +/// - `weight` and `bias` must point to `hidden_size` elements +#[cfg(feature = "f16")] +pub unsafe fn layer_norm_bf16( + input: *const half::bf16, + weight: *const half::bf16, + bias: *const half::bf16, + out: *mut half::bf16, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let total = batch_size * hidden_size; + let mut buf = vec![0.0f32; total + hidden_size + hidden_size + total]; + let (input_f32, rest) = buf.split_at_mut(total); + let (weight_f32, rest) = rest.split_at_mut(hidden_size); + let (bias_f32, out_f32) = rest.split_at_mut(hidden_size); + convert_bf16_to_f32(input as *const u16, input_f32.as_mut_ptr(), total); + convert_bf16_to_f32(weight as *const u16, weight_f32.as_mut_ptr(), hidden_size); + convert_bf16_to_f32(bias as *const u16, bias_f32.as_mut_ptr(), hidden_size); + super::layer_norm_f32( + input_f32.as_ptr(), + weight_f32.as_ptr(), + bias_f32.as_ptr(), + out_f32.as_mut_ptr(), + batch_size, + hidden_size, + eps, + ); + convert_f32_to_bf16(out_f32.as_ptr(), out as *mut u16, total); +} + +/// f16 wrapper for fused add + RMS norm. +/// +/// # Safety +/// - `input`, `residual`, and `out` must point to `batch_size * hidden_size` elements +/// - `weight` must point to `hidden_size` elements +/// - `pre_norm` must point to `batch_size * hidden_size` elements +#[cfg(feature = "f16")] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_rms_norm_f16( + input: *const half::f16, + residual: *const half::f16, + weight: *const half::f16, + out: *mut half::f16, + pre_norm: *mut half::f16, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let total = batch_size * hidden_size; + let mut buf = vec![0.0f32; total + total + hidden_size + total + total]; + let (input_f32, rest) = buf.split_at_mut(total); + let (residual_f32, rest) = rest.split_at_mut(total); + let (weight_f32, rest) = rest.split_at_mut(hidden_size); + let (out_f32, pre_norm_f32) = rest.split_at_mut(total); + convert_f16_to_f32(input as *const u16, input_f32.as_mut_ptr(), total); + convert_f16_to_f32(residual as *const u16, residual_f32.as_mut_ptr(), total); + convert_f16_to_f32(weight as *const u16, weight_f32.as_mut_ptr(), hidden_size); + super::fused_add_rms_norm_f32( + input_f32.as_ptr(), + residual_f32.as_ptr(), + weight_f32.as_ptr(), + out_f32.as_mut_ptr(), + pre_norm_f32.as_mut_ptr(), + batch_size, + hidden_size, + eps, + ); + convert_f32_to_f16(out_f32.as_ptr(), out as *mut u16, total); + convert_f32_to_f16(pre_norm_f32.as_ptr(), pre_norm as *mut u16, total); +} + +/// bf16 wrapper for fused add + RMS norm. +/// +/// # Safety +/// - `input`, `residual`, and `out` must point to `batch_size * hidden_size` elements +/// - `weight` must point to `hidden_size` elements +/// - `pre_norm` must point to `batch_size * hidden_size` elements +#[cfg(feature = "f16")] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_rms_norm_bf16( + input: *const half::bf16, + residual: *const half::bf16, + weight: *const half::bf16, + out: *mut half::bf16, + pre_norm: *mut half::bf16, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let total = batch_size * hidden_size; + let mut buf = vec![0.0f32; total + total + hidden_size + total + total]; + let (input_f32, rest) = buf.split_at_mut(total); + let (residual_f32, rest) = rest.split_at_mut(total); + let (weight_f32, rest) = rest.split_at_mut(hidden_size); + let (out_f32, pre_norm_f32) = rest.split_at_mut(total); + convert_bf16_to_f32(input as *const u16, input_f32.as_mut_ptr(), total); + convert_bf16_to_f32(residual as *const u16, residual_f32.as_mut_ptr(), total); + convert_bf16_to_f32(weight as *const u16, weight_f32.as_mut_ptr(), hidden_size); + super::fused_add_rms_norm_f32( + input_f32.as_ptr(), + residual_f32.as_ptr(), + weight_f32.as_ptr(), + out_f32.as_mut_ptr(), + pre_norm_f32.as_mut_ptr(), + batch_size, + hidden_size, + eps, + ); + convert_f32_to_bf16(out_f32.as_ptr(), out as *mut u16, total); + convert_f32_to_bf16(pre_norm_f32.as_ptr(), pre_norm as *mut u16, total); +} + +/// f16 wrapper for backward pass of fused add + RMS norm. +/// +/// # Safety +/// - `grad` and `pre_norm` must point to `batch_size * hidden_size` elements +/// - `weight` must point to `hidden_size` elements +/// - `d_input_residual` must point to `batch_size * hidden_size` elements +/// - `d_weight` must point to `hidden_size` elements (pre-zeroed by caller) +#[cfg(feature = "f16")] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_rms_norm_bwd_f16( + grad: *const half::f16, + pre_norm: *const half::f16, + weight: *const half::f16, + d_input_residual: *mut half::f16, + d_weight: *mut half::f16, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let total = batch_size * hidden_size; + let mut buf = vec![0.0f32; total + total + hidden_size + total + hidden_size]; + let (grad_f32, rest) = buf.split_at_mut(total); + let (pre_norm_f32, rest) = rest.split_at_mut(total); + let (weight_f32, rest) = rest.split_at_mut(hidden_size); + let (d_ir_f32, d_weight_f32) = rest.split_at_mut(total); + convert_f16_to_f32(grad as *const u16, grad_f32.as_mut_ptr(), total); + convert_f16_to_f32(pre_norm as *const u16, pre_norm_f32.as_mut_ptr(), total); + convert_f16_to_f32(weight as *const u16, weight_f32.as_mut_ptr(), hidden_size); + super::fused_add_rms_norm_bwd_f32( + grad_f32.as_ptr(), + pre_norm_f32.as_ptr(), + weight_f32.as_ptr(), + d_ir_f32.as_mut_ptr(), + d_weight_f32.as_mut_ptr(), + batch_size, + hidden_size, + eps, + ); + convert_f32_to_f16(d_ir_f32.as_ptr(), d_input_residual as *mut u16, total); + convert_f32_to_f16(d_weight_f32.as_ptr(), d_weight as *mut u16, hidden_size); +} + +/// bf16 wrapper for backward pass of fused add + RMS norm. +/// +/// # Safety +/// - `grad` and `pre_norm` must point to `batch_size * hidden_size` elements +/// - `weight` must point to `hidden_size` elements +/// - `d_input_residual` must point to `batch_size * hidden_size` elements +/// - `d_weight` must point to `hidden_size` elements (pre-zeroed by caller) +#[cfg(feature = "f16")] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_rms_norm_bwd_bf16( + grad: *const half::bf16, + pre_norm: *const half::bf16, + weight: *const half::bf16, + d_input_residual: *mut half::bf16, + d_weight: *mut half::bf16, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let total = batch_size * hidden_size; + let mut buf = vec![0.0f32; total + total + hidden_size + total + hidden_size]; + let (grad_f32, rest) = buf.split_at_mut(total); + let (pre_norm_f32, rest) = rest.split_at_mut(total); + let (weight_f32, rest) = rest.split_at_mut(hidden_size); + let (d_ir_f32, d_weight_f32) = rest.split_at_mut(total); + convert_bf16_to_f32(grad as *const u16, grad_f32.as_mut_ptr(), total); + convert_bf16_to_f32(pre_norm as *const u16, pre_norm_f32.as_mut_ptr(), total); + convert_bf16_to_f32(weight as *const u16, weight_f32.as_mut_ptr(), hidden_size); + super::fused_add_rms_norm_bwd_f32( + grad_f32.as_ptr(), + pre_norm_f32.as_ptr(), + weight_f32.as_ptr(), + d_ir_f32.as_mut_ptr(), + d_weight_f32.as_mut_ptr(), + batch_size, + hidden_size, + eps, + ); + convert_f32_to_bf16(d_ir_f32.as_ptr(), d_input_residual as *mut u16, total); + convert_f32_to_bf16(d_weight_f32.as_ptr(), d_weight as *mut u16, hidden_size); +} + +/// f16 wrapper for fused add + layer norm. +/// +/// # Safety +/// - `input`, `residual`, and `out` must point to `batch_size * hidden_size` elements +/// - `weight` and `bias` must point to `hidden_size` elements +/// - `pre_norm` must point to `batch_size * hidden_size` elements +#[cfg(feature = "f16")] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_layer_norm_f16( + input: *const half::f16, + residual: *const half::f16, + weight: *const half::f16, + bias: *const half::f16, + out: *mut half::f16, + pre_norm: *mut half::f16, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let total = batch_size * hidden_size; + let mut buf = vec![0.0f32; total + total + hidden_size + hidden_size + total + total]; + let (input_f32, rest) = buf.split_at_mut(total); + let (residual_f32, rest) = rest.split_at_mut(total); + let (weight_f32, rest) = rest.split_at_mut(hidden_size); + let (bias_f32, rest) = rest.split_at_mut(hidden_size); + let (out_f32, pre_norm_f32) = rest.split_at_mut(total); + convert_f16_to_f32(input as *const u16, input_f32.as_mut_ptr(), total); + convert_f16_to_f32(residual as *const u16, residual_f32.as_mut_ptr(), total); + convert_f16_to_f32(weight as *const u16, weight_f32.as_mut_ptr(), hidden_size); + convert_f16_to_f32(bias as *const u16, bias_f32.as_mut_ptr(), hidden_size); + super::fused_add_layer_norm_f32( + input_f32.as_ptr(), + residual_f32.as_ptr(), + weight_f32.as_ptr(), + bias_f32.as_ptr(), + out_f32.as_mut_ptr(), + pre_norm_f32.as_mut_ptr(), + batch_size, + hidden_size, + eps, + ); + convert_f32_to_f16(out_f32.as_ptr(), out as *mut u16, total); + convert_f32_to_f16(pre_norm_f32.as_ptr(), pre_norm as *mut u16, total); +} + +/// bf16 wrapper for fused add + layer norm. +/// +/// # Safety +/// - `input`, `residual`, and `out` must point to `batch_size * hidden_size` elements +/// - `weight` and `bias` must point to `hidden_size` elements +/// - `pre_norm` must point to `batch_size * hidden_size` elements +#[cfg(feature = "f16")] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_layer_norm_bf16( + input: *const half::bf16, + residual: *const half::bf16, + weight: *const half::bf16, + bias: *const half::bf16, + out: *mut half::bf16, + pre_norm: *mut half::bf16, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let total = batch_size * hidden_size; + let mut buf = vec![0.0f32; total + total + hidden_size + hidden_size + total + total]; + let (input_f32, rest) = buf.split_at_mut(total); + let (residual_f32, rest) = rest.split_at_mut(total); + let (weight_f32, rest) = rest.split_at_mut(hidden_size); + let (bias_f32, rest) = rest.split_at_mut(hidden_size); + let (out_f32, pre_norm_f32) = rest.split_at_mut(total); + convert_bf16_to_f32(input as *const u16, input_f32.as_mut_ptr(), total); + convert_bf16_to_f32(residual as *const u16, residual_f32.as_mut_ptr(), total); + convert_bf16_to_f32(weight as *const u16, weight_f32.as_mut_ptr(), hidden_size); + convert_bf16_to_f32(bias as *const u16, bias_f32.as_mut_ptr(), hidden_size); + super::fused_add_layer_norm_f32( + input_f32.as_ptr(), + residual_f32.as_ptr(), + weight_f32.as_ptr(), + bias_f32.as_ptr(), + out_f32.as_mut_ptr(), + pre_norm_f32.as_mut_ptr(), + batch_size, + hidden_size, + eps, + ); + convert_f32_to_bf16(out_f32.as_ptr(), out as *mut u16, total); + convert_f32_to_bf16(pre_norm_f32.as_ptr(), pre_norm as *mut u16, total); +} + +/// f16 wrapper for backward pass of fused add + layer norm. +/// +/// # Safety +/// - `grad` and `pre_norm` must point to `batch_size * hidden_size` elements +/// - `weight` must point to `hidden_size` elements +/// - `d_input_residual` must point to `batch_size * hidden_size` elements +/// - `d_weight` and `d_bias` must point to `hidden_size` elements (pre-zeroed by caller) +#[cfg(feature = "f16")] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_layer_norm_bwd_f16( + grad: *const half::f16, + pre_norm: *const half::f16, + weight: *const half::f16, + d_input_residual: *mut half::f16, + d_weight: *mut half::f16, + d_bias: *mut half::f16, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let total = batch_size * hidden_size; + let mut buf = vec![0.0f32; total + total + hidden_size + total + hidden_size + hidden_size]; + let (grad_f32, rest) = buf.split_at_mut(total); + let (pre_norm_f32, rest) = rest.split_at_mut(total); + let (weight_f32, rest) = rest.split_at_mut(hidden_size); + let (d_ir_f32, rest) = rest.split_at_mut(total); + let (d_weight_f32, d_bias_f32) = rest.split_at_mut(hidden_size); + convert_f16_to_f32(grad as *const u16, grad_f32.as_mut_ptr(), total); + convert_f16_to_f32(pre_norm as *const u16, pre_norm_f32.as_mut_ptr(), total); + convert_f16_to_f32(weight as *const u16, weight_f32.as_mut_ptr(), hidden_size); + super::fused_add_layer_norm_bwd_f32( + grad_f32.as_ptr(), + pre_norm_f32.as_ptr(), + weight_f32.as_ptr(), + d_ir_f32.as_mut_ptr(), + d_weight_f32.as_mut_ptr(), + d_bias_f32.as_mut_ptr(), + batch_size, + hidden_size, + eps, + ); + convert_f32_to_f16(d_ir_f32.as_ptr(), d_input_residual as *mut u16, total); + convert_f32_to_f16(d_weight_f32.as_ptr(), d_weight as *mut u16, hidden_size); + convert_f32_to_f16(d_bias_f32.as_ptr(), d_bias as *mut u16, hidden_size); +} + +/// bf16 wrapper for backward pass of fused add + layer norm. +/// +/// # Safety +/// - `grad` and `pre_norm` must point to `batch_size * hidden_size` elements +/// - `weight` must point to `hidden_size` elements +/// - `d_input_residual` must point to `batch_size * hidden_size` elements +/// - `d_weight` and `d_bias` must point to `hidden_size` elements (pre-zeroed by caller) +#[cfg(feature = "f16")] +#[allow(clippy::too_many_arguments)] +pub unsafe fn fused_add_layer_norm_bwd_bf16( + grad: *const half::bf16, + pre_norm: *const half::bf16, + weight: *const half::bf16, + d_input_residual: *mut half::bf16, + d_weight: *mut half::bf16, + d_bias: *mut half::bf16, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let total = batch_size * hidden_size; + let mut buf = vec![0.0f32; total + total + hidden_size + total + hidden_size + hidden_size]; + let (grad_f32, rest) = buf.split_at_mut(total); + let (pre_norm_f32, rest) = rest.split_at_mut(total); + let (weight_f32, rest) = rest.split_at_mut(hidden_size); + let (d_ir_f32, rest) = rest.split_at_mut(total); + let (d_weight_f32, d_bias_f32) = rest.split_at_mut(hidden_size); + convert_bf16_to_f32(grad as *const u16, grad_f32.as_mut_ptr(), total); + convert_bf16_to_f32(pre_norm as *const u16, pre_norm_f32.as_mut_ptr(), total); + convert_bf16_to_f32(weight as *const u16, weight_f32.as_mut_ptr(), hidden_size); + super::fused_add_layer_norm_bwd_f32( + grad_f32.as_ptr(), + pre_norm_f32.as_ptr(), + weight_f32.as_ptr(), + d_ir_f32.as_mut_ptr(), + d_weight_f32.as_mut_ptr(), + d_bias_f32.as_mut_ptr(), + batch_size, + hidden_size, + eps, + ); + convert_f32_to_bf16(d_ir_f32.as_ptr(), d_input_residual as *mut u16, total); + convert_f32_to_bf16(d_weight_f32.as_ptr(), d_weight as *mut u16, hidden_size); + convert_f32_to_bf16(d_bias_f32.as_ptr(), d_bias as *mut u16, hidden_size); +} diff --git a/src/runtime/cpu/kernels/simd/norm/layer_norm.rs b/src/runtime/cpu/kernels/simd/norm/layer_norm.rs new file mode 100644 index 00000000..0d065f55 --- /dev/null +++ b/src/runtime/cpu/kernels/simd/norm/layer_norm.rs @@ -0,0 +1,226 @@ +//! SIMD dispatch and scalar fallbacks for Layer normalization + +use super::super::{SimdLevel, detect_simd}; +use super::SIMD_THRESHOLD; + +#[cfg(target_arch = "x86_64")] +use super::avx2; +#[cfg(target_arch = "x86_64")] +use super::avx512; + +#[cfg(target_arch = "aarch64")] +use super::aarch64; + +/// SIMD Layer normalization for f32 +/// +/// # Safety +/// - `input` and `out` must point to `batch_size * hidden_size` elements +/// - `weight` and `bias` must point to `hidden_size` elements +#[inline] +pub unsafe fn layer_norm_f32( + input: *const f32, + weight: *const f32, + bias: *const f32, + out: *mut f32, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let level = detect_simd(); + + if hidden_size < SIMD_THRESHOLD || level == SimdLevel::Scalar { + layer_norm_scalar_f32(input, weight, bias, out, batch_size, hidden_size, eps); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => { + avx512::layer_norm_f32(input, weight, bias, out, batch_size, hidden_size, eps) + } + SimdLevel::Avx2Fma => { + avx2::layer_norm_f32(input, weight, bias, out, batch_size, hidden_size, eps) + } + _ => layer_norm_scalar_f32(input, weight, bias, out, batch_size, hidden_size, eps), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => { + aarch64::neon::layer_norm_f32(input, weight, bias, out, batch_size, hidden_size, eps) + } + _ => layer_norm_scalar_f32(input, weight, bias, out, batch_size, hidden_size, eps), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + layer_norm_scalar_f32(input, weight, bias, out, batch_size, hidden_size, eps); +} + +/// SIMD Layer normalization for f64 +/// +/// # Safety +/// - `input` and `out` must point to `batch_size * hidden_size` elements +/// - `weight` and `bias` must point to `hidden_size` elements +#[inline] +pub unsafe fn layer_norm_f64( + input: *const f64, + weight: *const f64, + bias: *const f64, + out: *mut f64, + batch_size: usize, + hidden_size: usize, + eps: f64, +) { + let level = detect_simd(); + + if hidden_size < SIMD_THRESHOLD || level == SimdLevel::Scalar { + layer_norm_scalar_f64(input, weight, bias, out, batch_size, hidden_size, eps); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => { + avx512::layer_norm_f64(input, weight, bias, out, batch_size, hidden_size, eps) + } + SimdLevel::Avx2Fma => { + avx2::layer_norm_f64(input, weight, bias, out, batch_size, hidden_size, eps) + } + _ => layer_norm_scalar_f64(input, weight, bias, out, batch_size, hidden_size, eps), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => { + aarch64::neon::layer_norm_f64(input, weight, bias, out, batch_size, hidden_size, eps) + } + _ => layer_norm_scalar_f64(input, weight, bias, out, batch_size, hidden_size, eps), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + layer_norm_scalar_f64(input, weight, bias, out, batch_size, hidden_size, eps); +} + +/// Scalar layer norm for f32 +#[inline] +pub unsafe fn layer_norm_scalar_f32( + input: *const f32, + weight: *const f32, + bias: *const f32, + out: *mut f32, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + for batch in 0..batch_size { + let row_start = batch * hidden_size; + + // Compute mean + let mut sum = 0.0f32; + for i in 0..hidden_size { + sum += *input.add(row_start + i); + } + let mean = sum / hidden_size as f32; + + // Compute variance + let mut var_sum = 0.0f32; + for i in 0..hidden_size { + let diff = *input.add(row_start + i) - mean; + var_sum += diff * diff; + } + let inv_std = 1.0 / (var_sum / hidden_size as f32 + eps).sqrt(); + + // Apply normalization, weight, and bias + for i in 0..hidden_size { + let x = *input.add(row_start + i); + let w = *weight.add(i); + let b = *bias.add(i); + *out.add(row_start + i) = (x - mean) * inv_std * w + b; + } + } +} + +/// Scalar layer norm for f64 +#[inline] +pub unsafe fn layer_norm_scalar_f64( + input: *const f64, + weight: *const f64, + bias: *const f64, + out: *mut f64, + batch_size: usize, + hidden_size: usize, + eps: f64, +) { + for batch in 0..batch_size { + let row_start = batch * hidden_size; + + let mut sum = 0.0f64; + for i in 0..hidden_size { + sum += *input.add(row_start + i); + } + let mean = sum / hidden_size as f64; + + let mut var_sum = 0.0f64; + for i in 0..hidden_size { + let diff = *input.add(row_start + i) - mean; + var_sum += diff * diff; + } + let inv_std = 1.0 / (var_sum / hidden_size as f64 + eps).sqrt(); + + for i in 0..hidden_size { + let x = *input.add(row_start + i); + let w = *weight.add(i); + let b = *bias.add(i); + *out.add(row_start + i) = (x - mean) * inv_std * w + b; + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_layer_norm_f32() { + let hidden_size = 128; + let batch_size = 4; + let input: Vec = (0..(batch_size * hidden_size)) + .map(|x| (x as f32) / 100.0 - 2.5) + .collect(); + let weight: Vec = vec![1.0f32; hidden_size]; + let bias: Vec = vec![0.0f32; hidden_size]; + let mut out = vec![0.0f32; batch_size * hidden_size]; + let mut out_ref = vec![0.0f32; batch_size * hidden_size]; + + unsafe { + layer_norm_f32( + input.as_ptr(), + weight.as_ptr(), + bias.as_ptr(), + out.as_mut_ptr(), + batch_size, + hidden_size, + 1e-5, + ); + layer_norm_scalar_f32( + input.as_ptr(), + weight.as_ptr(), + bias.as_ptr(), + out_ref.as_mut_ptr(), + batch_size, + hidden_size, + 1e-5, + ); + } + + for i in 0..(batch_size * hidden_size) { + assert!( + (out[i] - out_ref[i]).abs() < 1e-4, + "mismatch at {}: {} vs {}", + i, + out[i], + out_ref[i] + ); + } + } +} diff --git a/src/runtime/cpu/kernels/simd/norm/mod.rs b/src/runtime/cpu/kernels/simd/norm/mod.rs index 6b98b93d..4f33f86e 100644 --- a/src/runtime/cpu/kernels/simd/norm/mod.rs +++ b/src/runtime/cpu/kernels/simd/norm/mod.rs @@ -17,407 +17,31 @@ mod avx512; #[cfg(target_arch = "aarch64")] mod aarch64; -use super::{SimdLevel, detect_simd}; +#[cfg(feature = "f16")] +mod half; +#[cfg(feature = "f16")] +pub use half::{ + fused_add_layer_norm_bf16, fused_add_layer_norm_bwd_bf16, fused_add_layer_norm_bwd_f16, + fused_add_layer_norm_f16, fused_add_rms_norm_bf16, fused_add_rms_norm_bwd_bf16, + fused_add_rms_norm_bwd_f16, fused_add_rms_norm_f16, layer_norm_bf16, layer_norm_f16, + rms_norm_bf16, rms_norm_f16, +}; /// Minimum hidden_size to justify SIMD overhead -const SIMD_THRESHOLD: usize = 64; - -/// SIMD RMS normalization for f32 -/// -/// # Safety -/// - `input` and `out` must point to `batch_size * hidden_size` elements -/// - `weight` must point to `hidden_size` elements -#[inline] -pub unsafe fn rms_norm_f32( - input: *const f32, - weight: *const f32, - out: *mut f32, - batch_size: usize, - hidden_size: usize, - eps: f32, -) { - let level = detect_simd(); - - if hidden_size < SIMD_THRESHOLD || level == SimdLevel::Scalar { - rms_norm_scalar_f32(input, weight, out, batch_size, hidden_size, eps); - return; - } - - #[cfg(target_arch = "x86_64")] - match level { - SimdLevel::Avx512 => avx512::rms_norm_f32(input, weight, out, batch_size, hidden_size, eps), - SimdLevel::Avx2Fma => avx2::rms_norm_f32(input, weight, out, batch_size, hidden_size, eps), - _ => rms_norm_scalar_f32(input, weight, out, batch_size, hidden_size, eps), - } - - #[cfg(target_arch = "aarch64")] - match level { - SimdLevel::Neon | SimdLevel::NeonFp16 => { - aarch64::neon::rms_norm_f32(input, weight, out, batch_size, hidden_size, eps) - } - _ => rms_norm_scalar_f32(input, weight, out, batch_size, hidden_size, eps), - } - - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] - rms_norm_scalar_f32(input, weight, out, batch_size, hidden_size, eps); -} - -/// SIMD RMS normalization for f64 -/// -/// # Safety -/// - `input` and `out` must point to `batch_size * hidden_size` elements -/// - `weight` must point to `hidden_size` elements -#[inline] -pub unsafe fn rms_norm_f64( - input: *const f64, - weight: *const f64, - out: *mut f64, - batch_size: usize, - hidden_size: usize, - eps: f64, -) { - let level = detect_simd(); - - if hidden_size < SIMD_THRESHOLD || level == SimdLevel::Scalar { - rms_norm_scalar_f64(input, weight, out, batch_size, hidden_size, eps); - return; - } - - #[cfg(target_arch = "x86_64")] - match level { - SimdLevel::Avx512 => avx512::rms_norm_f64(input, weight, out, batch_size, hidden_size, eps), - SimdLevel::Avx2Fma => avx2::rms_norm_f64(input, weight, out, batch_size, hidden_size, eps), - _ => rms_norm_scalar_f64(input, weight, out, batch_size, hidden_size, eps), - } - - #[cfg(target_arch = "aarch64")] - match level { - SimdLevel::Neon | SimdLevel::NeonFp16 => { - aarch64::neon::rms_norm_f64(input, weight, out, batch_size, hidden_size, eps) - } - _ => rms_norm_scalar_f64(input, weight, out, batch_size, hidden_size, eps), - } - - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] - rms_norm_scalar_f64(input, weight, out, batch_size, hidden_size, eps); -} - -/// SIMD Layer normalization for f32 -/// -/// # Safety -/// - `input` and `out` must point to `batch_size * hidden_size` elements -/// - `weight` and `bias` must point to `hidden_size` elements -#[inline] -pub unsafe fn layer_norm_f32( - input: *const f32, - weight: *const f32, - bias: *const f32, - out: *mut f32, - batch_size: usize, - hidden_size: usize, - eps: f32, -) { - let level = detect_simd(); - - if hidden_size < SIMD_THRESHOLD || level == SimdLevel::Scalar { - layer_norm_scalar_f32(input, weight, bias, out, batch_size, hidden_size, eps); - return; - } - - #[cfg(target_arch = "x86_64")] - match level { - SimdLevel::Avx512 => { - avx512::layer_norm_f32(input, weight, bias, out, batch_size, hidden_size, eps) - } - SimdLevel::Avx2Fma => { - avx2::layer_norm_f32(input, weight, bias, out, batch_size, hidden_size, eps) - } - _ => layer_norm_scalar_f32(input, weight, bias, out, batch_size, hidden_size, eps), - } - - #[cfg(target_arch = "aarch64")] - match level { - SimdLevel::Neon | SimdLevel::NeonFp16 => { - aarch64::neon::layer_norm_f32(input, weight, bias, out, batch_size, hidden_size, eps) - } - _ => layer_norm_scalar_f32(input, weight, bias, out, batch_size, hidden_size, eps), - } - - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] - layer_norm_scalar_f32(input, weight, bias, out, batch_size, hidden_size, eps); -} - -/// SIMD Layer normalization for f64 -/// -/// # Safety -/// - `input` and `out` must point to `batch_size * hidden_size` elements -/// - `weight` and `bias` must point to `hidden_size` elements -#[inline] -pub unsafe fn layer_norm_f64( - input: *const f64, - weight: *const f64, - bias: *const f64, - out: *mut f64, - batch_size: usize, - hidden_size: usize, - eps: f64, -) { - let level = detect_simd(); - - if hidden_size < SIMD_THRESHOLD || level == SimdLevel::Scalar { - layer_norm_scalar_f64(input, weight, bias, out, batch_size, hidden_size, eps); - return; - } - - #[cfg(target_arch = "x86_64")] - match level { - SimdLevel::Avx512 => { - avx512::layer_norm_f64(input, weight, bias, out, batch_size, hidden_size, eps) - } - SimdLevel::Avx2Fma => { - avx2::layer_norm_f64(input, weight, bias, out, batch_size, hidden_size, eps) - } - _ => layer_norm_scalar_f64(input, weight, bias, out, batch_size, hidden_size, eps), - } - - #[cfg(target_arch = "aarch64")] - match level { - SimdLevel::Neon | SimdLevel::NeonFp16 => { - aarch64::neon::layer_norm_f64(input, weight, bias, out, batch_size, hidden_size, eps) - } - _ => layer_norm_scalar_f64(input, weight, bias, out, batch_size, hidden_size, eps), - } - - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] - layer_norm_scalar_f64(input, weight, bias, out, batch_size, hidden_size, eps); -} - -// ============================================================================ -// Scalar fallbacks -// ============================================================================ - -/// Scalar RMS norm for f32 -#[inline] -pub unsafe fn rms_norm_scalar_f32( - input: *const f32, - weight: *const f32, - out: *mut f32, - batch_size: usize, - hidden_size: usize, - eps: f32, -) { - for batch in 0..batch_size { - let row_start = batch * hidden_size; - - // Compute sum of squares - let mut sum_sq = 0.0f32; - for i in 0..hidden_size { - let x = *input.add(row_start + i); - sum_sq += x * x; - } - - // Compute inverse RMS - let inv_rms = 1.0 / (sum_sq / hidden_size as f32 + eps).sqrt(); - - // Apply normalization and weight - for i in 0..hidden_size { - let x = *input.add(row_start + i); - let w = *weight.add(i); - *out.add(row_start + i) = x * inv_rms * w; - } - } -} - -/// Scalar RMS norm for f64 -#[inline] -pub unsafe fn rms_norm_scalar_f64( - input: *const f64, - weight: *const f64, - out: *mut f64, - batch_size: usize, - hidden_size: usize, - eps: f64, -) { - for batch in 0..batch_size { - let row_start = batch * hidden_size; - - let mut sum_sq = 0.0f64; - for i in 0..hidden_size { - let x = *input.add(row_start + i); - sum_sq += x * x; - } - - let inv_rms = 1.0 / (sum_sq / hidden_size as f64 + eps).sqrt(); - - for i in 0..hidden_size { - let x = *input.add(row_start + i); - let w = *weight.add(i); - *out.add(row_start + i) = x * inv_rms * w; - } - } -} - -/// Scalar layer norm for f32 -#[inline] -pub unsafe fn layer_norm_scalar_f32( - input: *const f32, - weight: *const f32, - bias: *const f32, - out: *mut f32, - batch_size: usize, - hidden_size: usize, - eps: f32, -) { - for batch in 0..batch_size { - let row_start = batch * hidden_size; - - // Compute mean - let mut sum = 0.0f32; - for i in 0..hidden_size { - sum += *input.add(row_start + i); - } - let mean = sum / hidden_size as f32; - - // Compute variance - let mut var_sum = 0.0f32; - for i in 0..hidden_size { - let diff = *input.add(row_start + i) - mean; - var_sum += diff * diff; - } - let inv_std = 1.0 / (var_sum / hidden_size as f32 + eps).sqrt(); - - // Apply normalization, weight, and bias - for i in 0..hidden_size { - let x = *input.add(row_start + i); - let w = *weight.add(i); - let b = *bias.add(i); - *out.add(row_start + i) = (x - mean) * inv_std * w + b; - } - } -} - -/// Scalar layer norm for f64 -#[inline] -pub unsafe fn layer_norm_scalar_f64( - input: *const f64, - weight: *const f64, - bias: *const f64, - out: *mut f64, - batch_size: usize, - hidden_size: usize, - eps: f64, -) { - for batch in 0..batch_size { - let row_start = batch * hidden_size; - - let mut sum = 0.0f64; - for i in 0..hidden_size { - sum += *input.add(row_start + i); - } - let mean = sum / hidden_size as f64; - - let mut var_sum = 0.0f64; - for i in 0..hidden_size { - let diff = *input.add(row_start + i) - mean; - var_sum += diff * diff; - } - let inv_std = 1.0 / (var_sum / hidden_size as f64 + eps).sqrt(); - - for i in 0..hidden_size { - let x = *input.add(row_start + i); - let w = *weight.add(i); - let b = *bias.add(i); - *out.add(row_start + i) = (x - mean) * inv_std * w + b; - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_rms_norm_f32() { - let hidden_size = 128; - let batch_size = 4; - let input: Vec = (0..(batch_size * hidden_size)) - .map(|x| (x as f32) / 100.0 - 2.5) - .collect(); - let weight: Vec = vec![1.0f32; hidden_size]; - let mut out = vec![0.0f32; batch_size * hidden_size]; - let mut out_ref = vec![0.0f32; batch_size * hidden_size]; - - unsafe { - rms_norm_f32( - input.as_ptr(), - weight.as_ptr(), - out.as_mut_ptr(), - batch_size, - hidden_size, - 1e-5, - ); - rms_norm_scalar_f32( - input.as_ptr(), - weight.as_ptr(), - out_ref.as_mut_ptr(), - batch_size, - hidden_size, - 1e-5, - ); - } - - for i in 0..(batch_size * hidden_size) { - assert!( - (out[i] - out_ref[i]).abs() < 1e-4, - "mismatch at {}: {} vs {}", - i, - out[i], - out_ref[i] - ); - } - } - - #[test] - fn test_layer_norm_f32() { - let hidden_size = 128; - let batch_size = 4; - let input: Vec = (0..(batch_size * hidden_size)) - .map(|x| (x as f32) / 100.0 - 2.5) - .collect(); - let weight: Vec = vec![1.0f32; hidden_size]; - let bias: Vec = vec![0.0f32; hidden_size]; - let mut out = vec![0.0f32; batch_size * hidden_size]; - let mut out_ref = vec![0.0f32; batch_size * hidden_size]; - - unsafe { - layer_norm_f32( - input.as_ptr(), - weight.as_ptr(), - bias.as_ptr(), - out.as_mut_ptr(), - batch_size, - hidden_size, - 1e-5, - ); - layer_norm_scalar_f32( - input.as_ptr(), - weight.as_ptr(), - bias.as_ptr(), - out_ref.as_mut_ptr(), - batch_size, - hidden_size, - 1e-5, - ); - } - - for i in 0..(batch_size * hidden_size) { - assert!( - (out[i] - out_ref[i]).abs() < 1e-4, - "mismatch at {}: {} vs {}", - i, - out[i], - out_ref[i] - ); - } - } -} +pub(super) const SIMD_THRESHOLD: usize = 64; + +mod fused_add_layer_norm; +mod fused_add_rms_norm; +mod layer_norm; +mod rms_norm; + +pub use fused_add_layer_norm::{ + fused_add_layer_norm_bwd_f32, fused_add_layer_norm_bwd_f64, fused_add_layer_norm_f32, + fused_add_layer_norm_f64, +}; +pub use fused_add_rms_norm::{ + fused_add_rms_norm_bwd_f32, fused_add_rms_norm_bwd_f64, fused_add_rms_norm_f32, + fused_add_rms_norm_f64, +}; +pub use layer_norm::{layer_norm_f32, layer_norm_f64}; +pub use rms_norm::{rms_norm_f32, rms_norm_f64}; diff --git a/src/runtime/cpu/kernels/simd/norm/rms_norm.rs b/src/runtime/cpu/kernels/simd/norm/rms_norm.rs new file mode 100644 index 00000000..e7109b8e --- /dev/null +++ b/src/runtime/cpu/kernels/simd/norm/rms_norm.rs @@ -0,0 +1,199 @@ +//! SIMD dispatch and scalar fallbacks for RMS normalization + +use super::super::{SimdLevel, detect_simd}; +use super::SIMD_THRESHOLD; + +#[cfg(target_arch = "x86_64")] +use super::avx2; +#[cfg(target_arch = "x86_64")] +use super::avx512; + +#[cfg(target_arch = "aarch64")] +use super::aarch64; + +/// SIMD RMS normalization for f32 +/// +/// # Safety +/// - `input` and `out` must point to `batch_size * hidden_size` elements +/// - `weight` must point to `hidden_size` elements +#[inline] +pub unsafe fn rms_norm_f32( + input: *const f32, + weight: *const f32, + out: *mut f32, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + let level = detect_simd(); + + if hidden_size < SIMD_THRESHOLD || level == SimdLevel::Scalar { + rms_norm_scalar_f32(input, weight, out, batch_size, hidden_size, eps); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => avx512::rms_norm_f32(input, weight, out, batch_size, hidden_size, eps), + SimdLevel::Avx2Fma => avx2::rms_norm_f32(input, weight, out, batch_size, hidden_size, eps), + _ => rms_norm_scalar_f32(input, weight, out, batch_size, hidden_size, eps), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => { + aarch64::neon::rms_norm_f32(input, weight, out, batch_size, hidden_size, eps) + } + _ => rms_norm_scalar_f32(input, weight, out, batch_size, hidden_size, eps), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + rms_norm_scalar_f32(input, weight, out, batch_size, hidden_size, eps); +} + +/// SIMD RMS normalization for f64 +/// +/// # Safety +/// - `input` and `out` must point to `batch_size * hidden_size` elements +/// - `weight` must point to `hidden_size` elements +#[inline] +pub unsafe fn rms_norm_f64( + input: *const f64, + weight: *const f64, + out: *mut f64, + batch_size: usize, + hidden_size: usize, + eps: f64, +) { + let level = detect_simd(); + + if hidden_size < SIMD_THRESHOLD || level == SimdLevel::Scalar { + rms_norm_scalar_f64(input, weight, out, batch_size, hidden_size, eps); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => avx512::rms_norm_f64(input, weight, out, batch_size, hidden_size, eps), + SimdLevel::Avx2Fma => avx2::rms_norm_f64(input, weight, out, batch_size, hidden_size, eps), + _ => rms_norm_scalar_f64(input, weight, out, batch_size, hidden_size, eps), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => { + aarch64::neon::rms_norm_f64(input, weight, out, batch_size, hidden_size, eps) + } + _ => rms_norm_scalar_f64(input, weight, out, batch_size, hidden_size, eps), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + rms_norm_scalar_f64(input, weight, out, batch_size, hidden_size, eps); +} + +/// Scalar RMS norm for f32 +#[inline] +pub unsafe fn rms_norm_scalar_f32( + input: *const f32, + weight: *const f32, + out: *mut f32, + batch_size: usize, + hidden_size: usize, + eps: f32, +) { + for batch in 0..batch_size { + let row_start = batch * hidden_size; + + // Compute sum of squares in f64 for precision (matches llama.cpp's ggml_float) + let mut sum_sq = 0.0f64; + for i in 0..hidden_size { + let x = *input.add(row_start + i) as f64; + sum_sq += x * x; + } + + // Compute inverse RMS in f64, then cast to f32 + let inv_rms = (1.0f64 / (sum_sq / hidden_size as f64 + eps as f64).sqrt()) as f32; + + // Apply normalization and weight + for i in 0..hidden_size { + let x = *input.add(row_start + i); + let w = *weight.add(i); + *out.add(row_start + i) = x * inv_rms * w; + } + } +} + +/// Scalar RMS norm for f64 +#[inline] +pub unsafe fn rms_norm_scalar_f64( + input: *const f64, + weight: *const f64, + out: *mut f64, + batch_size: usize, + hidden_size: usize, + eps: f64, +) { + for batch in 0..batch_size { + let row_start = batch * hidden_size; + + let mut sum_sq = 0.0f64; + for i in 0..hidden_size { + let x = *input.add(row_start + i); + sum_sq += x * x; + } + + let inv_rms = 1.0 / (sum_sq / hidden_size as f64 + eps).sqrt(); + + for i in 0..hidden_size { + let x = *input.add(row_start + i); + let w = *weight.add(i); + *out.add(row_start + i) = x * inv_rms * w; + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_rms_norm_f32() { + let hidden_size = 128; + let batch_size = 4; + let input: Vec = (0..(batch_size * hidden_size)) + .map(|x| (x as f32) / 100.0 - 2.5) + .collect(); + let weight: Vec = vec![1.0f32; hidden_size]; + let mut out = vec![0.0f32; batch_size * hidden_size]; + let mut out_ref = vec![0.0f32; batch_size * hidden_size]; + + unsafe { + rms_norm_f32( + input.as_ptr(), + weight.as_ptr(), + out.as_mut_ptr(), + batch_size, + hidden_size, + 1e-5, + ); + rms_norm_scalar_f32( + input.as_ptr(), + weight.as_ptr(), + out_ref.as_mut_ptr(), + batch_size, + hidden_size, + 1e-5, + ); + } + + for i in 0..(batch_size * hidden_size) { + assert!( + (out[i] - out_ref[i]).abs() < 1e-4, + "mismatch at {}: {} vs {}", + i, + out[i], + out_ref[i] + ); + } + } +} diff --git a/src/runtime/cpu/kernels/simd/reduce/mod.rs b/src/runtime/cpu/kernels/simd/reduce/mod.rs index 531db295..8b6981fb 100644 --- a/src/runtime/cpu/kernels/simd/reduce/mod.rs +++ b/src/runtime/cpu/kernels/simd/reduce/mod.rs @@ -271,6 +271,62 @@ pub unsafe fn reduce_scalar_f64( } } +#[cfg(feature = "f16")] +/// f16 wrapper for reduce: converts input to f32, runs f32 reduce, converts output back. +/// +/// # Safety +/// - `a` must point to `reduce_size * outer_size` elements +/// - `out` must point to `outer_size` elements +pub unsafe fn reduce_f16( + op: ReduceOp, + a: *const half::f16, + out: *mut half::f16, + reduce_size: usize, + outer_size: usize, +) { + use super::half_convert_utils::*; + let input_len = outer_size * reduce_size; + let mut a_f32 = vec![0.0f32; input_len]; + let mut out_f32 = vec![0.0f32; outer_size]; + convert_f16_to_f32(a as *const u16, a_f32.as_mut_ptr(), input_len); + reduce_f32( + op, + a_f32.as_ptr(), + out_f32.as_mut_ptr(), + reduce_size, + outer_size, + ); + convert_f32_to_f16(out_f32.as_ptr(), out as *mut u16, outer_size); +} + +#[cfg(feature = "f16")] +/// bf16 wrapper for reduce: converts input to f32, runs f32 reduce, converts output back. +/// +/// # Safety +/// - `a` must point to `reduce_size * outer_size` elements +/// - `out` must point to `outer_size` elements +pub unsafe fn reduce_bf16( + op: ReduceOp, + a: *const half::bf16, + out: *mut half::bf16, + reduce_size: usize, + outer_size: usize, +) { + use super::half_convert_utils::*; + let input_len = outer_size * reduce_size; + let mut a_f32 = vec![0.0f32; input_len]; + let mut out_f32 = vec![0.0f32; outer_size]; + convert_bf16_to_f32(a as *const u16, a_f32.as_mut_ptr(), input_len); + reduce_f32( + op, + a_f32.as_ptr(), + out_f32.as_mut_ptr(), + reduce_size, + outer_size, + ); + convert_f32_to_bf16(out_f32.as_ptr(), out as *mut u16, outer_size); +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/runtime/cpu/kernels/simd/scalar/mod.rs b/src/runtime/cpu/kernels/simd/scalar/mod.rs index c1600435..c8a3c985 100644 --- a/src/runtime/cpu/kernels/simd/scalar/mod.rs +++ b/src/runtime/cpu/kernels/simd/scalar/mod.rs @@ -300,6 +300,9 @@ pub unsafe fn rsub_scalar_f64(a: *const f64, scalar: f64, out: *mut f64, len: us } } +half_scalar_op!(scalar, scalar_f32, BinaryOp); +half_unary_scalar!(rsub_scalar, rsub_scalar_f32); + #[cfg(test)] mod tests { use super::*; diff --git a/src/runtime/cpu/kernels/simd/softmax/aarch64/neon.rs b/src/runtime/cpu/kernels/simd/softmax/aarch64/neon.rs index b042df5c..36604481 100644 --- a/src/runtime/cpu/kernels/simd/softmax/aarch64/neon.rs +++ b/src/runtime/cpu/kernels/simd/softmax/aarch64/neon.rs @@ -1,15 +1,7 @@ -//! NEON softmax kernels for ARM64 +//! NEON softmax kernels for ARM64 using online algorithm (2-pass). //! -//! Provides vectorized softmax operation using 128-bit NEON registers. -//! -//! softmax(x)[i] = exp(x[i] - max(x)) / sum(exp(x - max(x))) -//! -//! # SIMD Strategy -//! -//! 1. SIMD max-reduce to find maximum for numerical stability -//! 2. SIMD exp computation with shifted values -//! 3. SIMD sum-reduce for normalization factor -//! 4. SIMD multiply by inverse sum +//! Pass 1: Online SIMD max + sum (single read of input) +//! Pass 2: Compute exp(x - max) / sum and write output (one read + one write) #[cfg(target_arch = "aarch64")] use std::arch::aarch64::*; @@ -21,7 +13,7 @@ use super::super::super::math::aarch64::neon::{ const F32_LANES: usize = 4; const F64_LANES: usize = 2; -/// NEON softmax for f32 +/// NEON softmax for f32 using online algorithm. /// /// # Safety /// - CPU must support NEON (always true on AArch64) @@ -36,63 +28,87 @@ pub unsafe fn softmax_f32(a: *const f32, out: *mut f32, outer_size: usize, dim_s let base = a.add(o * dim_size); let out_base = out.add(o * dim_size); - // Phase 1: Find max (for numerical stability) - let mut max_acc = vdupq_n_f32(f32::NEG_INFINITY); + // Pass 1: Online max + sum + let mut max_vec = vdupq_n_f32(f32::NEG_INFINITY); + let mut sum_vec = vdupq_n_f32(0.0); + for i in 0..chunks { let v = vld1q_f32(base.add(i * F32_LANES)); - max_acc = vmaxq_f32(max_acc, v); + + let old_max = max_vec; + max_vec = vmaxq_f32(max_vec, v); + + // Rescale previous sum + // Guard: when old_max == max_vec == -inf, exp(-inf-(-inf)) = NaN. + // Use mask to zero out -inf lanes (their sum contribution is 0). + let neg_inf = vdupq_n_f32(f32::NEG_INFINITY); + let valid_old = vmvnq_u32(vceqq_f32(old_max, neg_inf)); // != -inf + let rescale = exp_f32(vsubq_f32(old_max, max_vec)); + let rescale = + vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(rescale), valid_old)); + sum_vec = vmulq_f32(sum_vec, rescale); + + // Add new contributions + let valid_new = vmvnq_u32(vceqq_f32(max_vec, neg_inf)); // != -inf + let exp_v = exp_f32(vsubq_f32(v, max_vec)); + let exp_v = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(exp_v), valid_new)); + sum_vec = vaddq_f32(sum_vec, exp_v); } - let mut max_val = hmax_f32(max_acc); - // Scalar tail for max + // Horizontal reduce to get per-lane max, then reconcile with scalar tail + let mut max_val = hmax_f32(max_vec); + + // Scalar tail (online) + let mut tail_sum = 0.0f32; for i in 0..remainder { let val = *base.add(chunks * F32_LANES + i); if val > max_val { + let rescale = if max_val == f32::NEG_INFINITY { + 0.0 + } else { + (max_val - val).exp() + }; + tail_sum = tail_sum * rescale + 1.0; max_val = val; + } else if val == f32::NEG_INFINITY { + // skip + } else { + tail_sum += (val - max_val).exp(); } } + // Reconcile SIMD sum with global max + // Guard -inf lanes to avoid NaN from exp(-inf - (-inf)) + let neg_inf = vdupq_n_f32(f32::NEG_INFINITY); + let valid_mask = vmvnq_u32(vceqq_f32(max_vec, neg_inf)); + let v_global_max = vdupq_n_f32(max_val); + let rescale = exp_f32(vsubq_f32(max_vec, v_global_max)); + let rescale = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(rescale), valid_mask)); + let rescaled_sum = vmulq_f32(sum_vec, rescale); + let sum = hsum_f32(rescaled_sum) + tail_sum; + + // Pass 2: exp(x - max) / sum let v_max = vdupq_n_f32(max_val); + let inv_sum_vec = vdupq_n_f32(1.0 / sum); - // Phase 2: Compute exp(x - max) and sum - let mut sum_acc = vdupq_n_f32(0.0); for i in 0..chunks { let offset = i * F32_LANES; let v = vld1q_f32(base.add(offset)); let shifted = vsubq_f32(v, v_max); - let exp_v = exp_f32(shifted); - vst1q_f32(out_base.add(offset), exp_v); - sum_acc = vaddq_f32(sum_acc, exp_v); - } - let mut sum = hsum_f32(sum_acc); - - // Scalar tail for exp and sum - for i in 0..remainder { - let offset = chunks * F32_LANES + i; - let val = *base.add(offset); - let exp_val = (val - max_val).exp(); - *out_base.add(offset) = exp_val; - sum += exp_val; - } - - // Phase 3: Normalize by sum - let inv_sum = vdupq_n_f32(1.0 / sum); - for i in 0..chunks { - let offset = i * F32_LANES; - let v = vld1q_f32(out_base.add(offset)); - vst1q_f32(out_base.add(offset), vmulq_f32(v, inv_sum)); + let normalized = vmulq_f32(exp_f32(shifted), inv_sum_vec); + vst1q_f32(out_base.add(offset), normalized); } - // Scalar tail for normalization let scalar_inv_sum = 1.0 / sum; for i in 0..remainder { let offset = chunks * F32_LANES + i; - *out_base.add(offset) *= scalar_inv_sum; + let val = *base.add(offset); + *out_base.add(offset) = (val - max_val).exp() * scalar_inv_sum; } } } -/// NEON softmax for f64 +/// NEON softmax for f64 using online algorithm. /// /// # Safety /// - CPU must support NEON (always true on AArch64) @@ -107,55 +123,76 @@ pub unsafe fn softmax_f64(a: *const f64, out: *mut f64, outer_size: usize, dim_s let base = a.add(o * dim_size); let out_base = out.add(o * dim_size); - // Phase 1: Find max - let mut max_acc = vdupq_n_f64(f64::NEG_INFINITY); + // Pass 1: Online max + sum + let mut max_vec = vdupq_n_f64(f64::NEG_INFINITY); + let mut sum_vec = vdupq_n_f64(0.0); + for i in 0..chunks { let v = vld1q_f64(base.add(i * F64_LANES)); - max_acc = vmaxq_f64(max_acc, v); + + let old_max = max_vec; + max_vec = vmaxq_f64(max_vec, v); + + // Guard -inf lanes + let neg_inf = vdupq_n_f64(f64::NEG_INFINITY); + let valid_old = veorq_u64(vceqq_f64(old_max, neg_inf), vdupq_n_u64(!0)); + let rescale = exp_f64(vsubq_f64(old_max, max_vec)); + let rescale = + vreinterpretq_f64_u64(vandq_u64(vreinterpretq_u64_f64(rescale), valid_old)); + sum_vec = vmulq_f64(sum_vec, rescale); + + let valid_new = veorq_u64(vceqq_f64(max_vec, neg_inf), vdupq_n_u64(!0)); + let exp_v = exp_f64(vsubq_f64(v, max_vec)); + let exp_v = vreinterpretq_f64_u64(vandq_u64(vreinterpretq_u64_f64(exp_v), valid_new)); + sum_vec = vaddq_f64(sum_vec, exp_v); } - let mut max_val = hmax_f64(max_acc); + let mut max_val = hmax_f64(max_vec); + + let mut tail_sum = 0.0f64; for i in 0..remainder { let val = *base.add(chunks * F64_LANES + i); if val > max_val { + let rescale = if max_val == f64::NEG_INFINITY { + 0.0 + } else { + (max_val - val).exp() + }; + tail_sum = tail_sum * rescale + 1.0; max_val = val; + } else if val == f64::NEG_INFINITY { + // skip + } else { + tail_sum += (val - max_val).exp(); } } + // Reconcile SIMD sum with global max + let neg_inf = vdupq_n_f64(f64::NEG_INFINITY); + let valid_mask = veorq_u64(vceqq_f64(max_vec, neg_inf), vdupq_n_u64(!0)); + let v_global_max = vdupq_n_f64(max_val); + let rescale = exp_f64(vsubq_f64(max_vec, v_global_max)); + let rescale = vreinterpretq_f64_u64(vandq_u64(vreinterpretq_u64_f64(rescale), valid_mask)); + let rescaled_sum = vmulq_f64(sum_vec, rescale); + let sum = hsum_f64(rescaled_sum) + tail_sum; + + // Pass 2: exp(x - max) / sum let v_max = vdupq_n_f64(max_val); + let inv_sum_vec = vdupq_n_f64(1.0 / sum); - // Phase 2: Compute exp(x - max) and sum - let mut sum_acc = vdupq_n_f64(0.0); for i in 0..chunks { let offset = i * F64_LANES; let v = vld1q_f64(base.add(offset)); let shifted = vsubq_f64(v, v_max); - let exp_v = exp_f64(shifted); - vst1q_f64(out_base.add(offset), exp_v); - sum_acc = vaddq_f64(sum_acc, exp_v); - } - let mut sum = hsum_f64(sum_acc); - - for i in 0..remainder { - let offset = chunks * F64_LANES + i; - let val = *base.add(offset); - let exp_val = (val - max_val).exp(); - *out_base.add(offset) = exp_val; - sum += exp_val; - } - - // Phase 3: Normalize - let inv_sum = vdupq_n_f64(1.0 / sum); - for i in 0..chunks { - let offset = i * F64_LANES; - let v = vld1q_f64(out_base.add(offset)); - vst1q_f64(out_base.add(offset), vmulq_f64(v, inv_sum)); + let normalized = vmulq_f64(exp_f64(shifted), inv_sum_vec); + vst1q_f64(out_base.add(offset), normalized); } let scalar_inv_sum = 1.0 / sum; for i in 0..remainder { let offset = chunks * F64_LANES + i; - *out_base.add(offset) *= scalar_inv_sum; + let val = *base.add(offset); + *out_base.add(offset) = (val - max_val).exp() * scalar_inv_sum; } } } diff --git a/src/runtime/cpu/kernels/simd/softmax/avx2.rs b/src/runtime/cpu/kernels/simd/softmax/avx2.rs index a3e63803..a8b2e423 100644 --- a/src/runtime/cpu/kernels/simd/softmax/avx2.rs +++ b/src/runtime/cpu/kernels/simd/softmax/avx2.rs @@ -1,6 +1,7 @@ -//! AVX2 softmax kernels +//! AVX2 softmax kernels using online algorithm (2-pass). //! -//! Uses SIMD for max-reduce, sum-reduce, and final normalization. +//! Pass 1: Online SIMD max + sum (single read of input) +//! Pass 2: Compute exp(x - max) / sum and write output (one read + one write) #[cfg(target_arch = "x86_64")] use std::arch::x86_64::*; @@ -10,7 +11,7 @@ use super::super::math::avx2::{exp_f32, exp_f64, hmax_f32, hmax_f64, hsum_f32, h const F32_LANES: usize = 8; const F64_LANES: usize = 4; -/// AVX2 softmax for f32 +/// AVX2 softmax for f32 using online algorithm. #[target_feature(enable = "avx2", enable = "fma")] pub unsafe fn softmax_f32(a: *const f32, out: *mut f32, outer_size: usize, dim_size: usize) { let chunks = dim_size / F32_LANES; @@ -18,65 +19,92 @@ pub unsafe fn softmax_f32(a: *const f32, out: *mut f32, outer_size: usize, dim_s for o in 0..outer_size { let base = o * dim_size; - // Step 1: SIMD max-reduce + // Pass 1: Online max + sum in a single read pass let mut max_vec = _mm256_set1_ps(f32::NEG_INFINITY); + let mut sum_vec = _mm256_setzero_ps(); + for c in 0..chunks { let offset = base + c * F32_LANES; let v = _mm256_loadu_ps(a.add(offset)); + + // Save old max, compute new max + let old_max = max_vec; max_vec = _mm256_max_ps(max_vec, v); + + // Rescale previous sum: sum *= exp(old_max - new_max) + // Guard: when old_max == new_max == -inf, exp(-inf-(-inf)) = NaN. + // Use a validity mask to zero out -inf lanes (their sum contribution is 0). + let neg_inf = _mm256_set1_ps(f32::NEG_INFINITY); + let valid_old = _mm256_cmp_ps(old_max, neg_inf, _CMP_GT_OQ); + let rescale = exp_f32(_mm256_sub_ps(old_max, max_vec)); + let rescale = _mm256_and_ps(rescale, valid_old); + sum_vec = _mm256_mul_ps(sum_vec, rescale); + + // Add new contributions: sum += exp(v - new_max) + let valid_new = _mm256_cmp_ps(max_vec, neg_inf, _CMP_GT_OQ); + let exp_v = exp_f32(_mm256_sub_ps(v, max_vec)); + let exp_v = _mm256_and_ps(exp_v, valid_new); + sum_vec = _mm256_add_ps(sum_vec, exp_v); } - let mut max_val = hmax_f32(max_vec); - // Scalar tail for max + // Horizontal reduce: reconcile per-lane max/sum to scalar + let max_val_simd = hmax_f32(max_vec); + let mut max_val = max_val_simd; + + // Handle scalar tail for max (online) + let mut tail_sum = 0.0f32; for d in (chunks * F32_LANES)..dim_size { let val = *a.add(base + d); if val > max_val { + let rescale = if max_val == f32::NEG_INFINITY { + 0.0 + } else { + (max_val - val).exp() + }; + tail_sum = tail_sum * rescale + 1.0; max_val = val; + } else if val == f32::NEG_INFINITY { + // skip: contribution is 0 + } else { + tail_sum += (val - max_val).exp(); } } - // Step 2: Compute exp(x - max) and accumulate sum + // Reconcile SIMD sum with scalar max: each lane's sum must be rescaled + // sum_vec[i] was computed relative to max_vec[i], but we need it relative to max_val + // Guard: if a lane's max is -inf (all elements were -inf), its sum contribution is 0, + // not NaN. We zero out those lanes to avoid NaN from exp(-inf - (-inf)). + let v_max_vec = max_vec; // per-lane max values + let v_global_max = _mm256_set1_ps(max_val); + let neg_inf = _mm256_set1_ps(f32::NEG_INFINITY); + let valid_mask = _mm256_cmp_ps(v_max_vec, neg_inf, _CMP_GT_OQ); + let rescale = exp_f32(_mm256_sub_ps(v_max_vec, v_global_max)); + let rescale = _mm256_and_ps(rescale, valid_mask); + let rescaled_sum = _mm256_mul_ps(sum_vec, rescale); + let sum = hsum_f32(rescaled_sum) + tail_sum; + + // Pass 2: Compute exp(x - max) / sum in a single write pass let v_max = _mm256_set1_ps(max_val); - let mut sum_vec = _mm256_setzero_ps(); + let v_inv_sum = _mm256_set1_ps(1.0 / sum); for c in 0..chunks { let offset = base + c * F32_LANES; let v = _mm256_loadu_ps(a.add(offset)); let diff = _mm256_sub_ps(v, v_max); - let exp_v = exp_f32(diff); - _mm256_storeu_ps(out.add(offset), exp_v); - sum_vec = _mm256_add_ps(sum_vec, exp_v); - } - - let mut sum = hsum_f32(sum_vec); - - // Scalar tail for exp and sum - for d in (chunks * F32_LANES)..dim_size { - let val = *a.add(base + d); - let exp_val = (val - max_val).exp(); - *out.add(base + d) = exp_val; - sum += exp_val; - } - - // Step 3: SIMD normalize - let v_inv_sum = _mm256_set1_ps(1.0 / sum); - - for c in 0..chunks { - let offset = base + c * F32_LANES; - let v = _mm256_loadu_ps(out.add(offset)); - let normalized = _mm256_mul_ps(v, v_inv_sum); + let normalized = _mm256_mul_ps(exp_f32(diff), v_inv_sum); _mm256_storeu_ps(out.add(offset), normalized); } - // Scalar tail for normalization + // Scalar tail let inv_sum = 1.0 / sum; for d in (chunks * F32_LANES)..dim_size { - *out.add(base + d) *= inv_sum; + let val = *a.add(base + d); + *out.add(base + d) = (val - max_val).exp() * inv_sum; } } } -/// AVX2 softmax for f64 +/// AVX2 softmax for f64 using online algorithm. #[target_feature(enable = "avx2", enable = "fma")] pub unsafe fn softmax_f64(a: *const f64, out: *mut f64, outer_size: usize, dim_size: usize) { let chunks = dim_size / F64_LANES; @@ -84,60 +112,69 @@ pub unsafe fn softmax_f64(a: *const f64, out: *mut f64, outer_size: usize, dim_s for o in 0..outer_size { let base = o * dim_size; - // Step 1: SIMD max-reduce + // Pass 1: Online max + sum let mut max_vec = _mm256_set1_pd(f64::NEG_INFINITY); + let mut sum_vec = _mm256_setzero_pd(); + for c in 0..chunks { let offset = base + c * F64_LANES; let v = _mm256_loadu_pd(a.add(offset)); + + let old_max = max_vec; max_vec = _mm256_max_pd(max_vec, v); + + let rescale = exp_f64(_mm256_sub_pd(old_max, max_vec)); + sum_vec = _mm256_mul_pd(sum_vec, rescale); + + let exp_v = exp_f64(_mm256_sub_pd(v, max_vec)); + sum_vec = _mm256_add_pd(sum_vec, exp_v); } - let mut max_val = hmax_f64(max_vec); - // Scalar tail for max + let max_val_simd = hmax_f64(max_vec); + let mut max_val = max_val_simd; + + // Scalar tail (online) + let mut tail_sum = 0.0f64; for d in (chunks * F64_LANES)..dim_size { let val = *a.add(base + d); if val > max_val { + let rescale = if max_val == f64::NEG_INFINITY { + 0.0 + } else { + (max_val - val).exp() + }; + tail_sum = tail_sum * rescale + 1.0; max_val = val; + } else if val == f64::NEG_INFINITY { + // skip + } else { + tail_sum += (val - max_val).exp(); } } - // Step 2: Compute exp(x - max) and accumulate sum + // Reconcile SIMD sum with global max + let v_max_vec = max_vec; + let v_global_max = _mm256_set1_pd(max_val); + let rescale = exp_f64(_mm256_sub_pd(v_max_vec, v_global_max)); + let rescaled_sum = _mm256_mul_pd(sum_vec, rescale); + let sum = hsum_f64(rescaled_sum) + tail_sum; + + // Pass 2: exp(x - max) / sum let v_max = _mm256_set1_pd(max_val); - let mut sum_vec = _mm256_setzero_pd(); + let v_inv_sum = _mm256_set1_pd(1.0 / sum); for c in 0..chunks { let offset = base + c * F64_LANES; let v = _mm256_loadu_pd(a.add(offset)); let diff = _mm256_sub_pd(v, v_max); - let exp_v = exp_f64(diff); - _mm256_storeu_pd(out.add(offset), exp_v); - sum_vec = _mm256_add_pd(sum_vec, exp_v); - } - - let mut sum = hsum_f64(sum_vec); - - // Scalar tail for exp and sum - for d in (chunks * F64_LANES)..dim_size { - let val = *a.add(base + d); - let exp_val = (val - max_val).exp(); - *out.add(base + d) = exp_val; - sum += exp_val; - } - - // Step 3: SIMD normalize - let v_inv_sum = _mm256_set1_pd(1.0 / sum); - - for c in 0..chunks { - let offset = base + c * F64_LANES; - let v = _mm256_loadu_pd(out.add(offset)); - let normalized = _mm256_mul_pd(v, v_inv_sum); + let normalized = _mm256_mul_pd(exp_f64(diff), v_inv_sum); _mm256_storeu_pd(out.add(offset), normalized); } - // Scalar tail for normalization let inv_sum = 1.0 / sum; for d in (chunks * F64_LANES)..dim_size { - *out.add(base + d) *= inv_sum; + let val = *a.add(base + d); + *out.add(base + d) = (val - max_val).exp() * inv_sum; } } } diff --git a/src/runtime/cpu/kernels/simd/softmax/avx512.rs b/src/runtime/cpu/kernels/simd/softmax/avx512.rs index 4d43ac73..e9f76b1b 100644 --- a/src/runtime/cpu/kernels/simd/softmax/avx512.rs +++ b/src/runtime/cpu/kernels/simd/softmax/avx512.rs @@ -1,6 +1,7 @@ -//! AVX-512 softmax kernels +//! AVX-512 softmax kernels using online algorithm (2-pass). //! -//! Uses SIMD for max-reduce, sum-reduce, and final normalization. +//! Pass 1: Online SIMD max + sum (single read of input) +//! Pass 2: Compute exp(x - max) / sum and write output (one read + one write) #[cfg(target_arch = "x86_64")] use std::arch::x86_64::*; @@ -10,7 +11,7 @@ use super::super::math::avx512::{exp_f32, exp_f64}; const F32_LANES: usize = 16; const F64_LANES: usize = 8; -/// AVX-512 softmax for f32 +/// AVX-512 softmax for f32 using online algorithm. #[target_feature(enable = "avx512f")] pub unsafe fn softmax_f32(a: *const f32, out: *mut f32, outer_size: usize, dim_size: usize) { let chunks = dim_size / F32_LANES; @@ -18,65 +19,84 @@ pub unsafe fn softmax_f32(a: *const f32, out: *mut f32, outer_size: usize, dim_s for o in 0..outer_size { let base = o * dim_size; - // Step 1: SIMD max-reduce + // Pass 1: Online max + sum let mut max_vec = _mm512_set1_ps(f32::NEG_INFINITY); + let mut sum_vec = _mm512_setzero_ps(); + for c in 0..chunks { let offset = base + c * F32_LANES; let v = _mm512_loadu_ps(a.add(offset)); + + let old_max = max_vec; max_vec = _mm512_max_ps(max_vec, v); + + // Rescale previous sum and add new contributions. + // Guard: when old_max == max_vec == -inf, exp(-inf-(-inf)) = NaN. + // Use mask to zero out -inf lanes (their sum contribution is 0). + let neg_inf = _mm512_set1_ps(f32::NEG_INFINITY); + let valid_old = _mm512_cmp_ps_mask(old_max, neg_inf, _CMP_GT_OQ); + let rescale = exp_f32(_mm512_sub_ps(old_max, max_vec)); + let rescale = _mm512_maskz_mov_ps(valid_old, rescale); + sum_vec = _mm512_mul_ps(sum_vec, rescale); + + let valid_new = _mm512_cmp_ps_mask(max_vec, neg_inf, _CMP_GT_OQ); + let exp_v = exp_f32(_mm512_sub_ps(v, max_vec)); + let exp_v = _mm512_maskz_mov_ps(valid_new, exp_v); + sum_vec = _mm512_add_ps(sum_vec, exp_v); } + let mut max_val = _mm512_reduce_max_ps(max_vec); - // Scalar tail for max + // Scalar tail (online) + let mut tail_sum = 0.0f32; for d in (chunks * F32_LANES)..dim_size { let val = *a.add(base + d); if val > max_val { + let rescale = if max_val == f32::NEG_INFINITY { + 0.0 + } else { + (max_val - val).exp() + }; + tail_sum = tail_sum * rescale + 1.0; max_val = val; + } else if val == f32::NEG_INFINITY { + // skip + } else { + tail_sum += (val - max_val).exp(); } } - // Step 2: Compute exp(x - max) and accumulate sum + // Reconcile SIMD sum with global max + // Guard -inf lanes to avoid NaN from exp(-inf - (-inf)) + let v_global_max = _mm512_set1_ps(max_val); + let neg_inf = _mm512_set1_ps(f32::NEG_INFINITY); + let valid_mask = _mm512_cmp_ps_mask(max_vec, neg_inf, _CMP_GT_OQ); + let rescale = exp_f32(_mm512_sub_ps(max_vec, v_global_max)); + let rescale = _mm512_maskz_mov_ps(valid_mask, rescale); + let rescaled_sum = _mm512_mul_ps(sum_vec, rescale); + let sum = _mm512_reduce_add_ps(rescaled_sum) + tail_sum; + + // Pass 2: exp(x - max) / sum let v_max = _mm512_set1_ps(max_val); - let mut sum_vec = _mm512_setzero_ps(); + let v_inv_sum = _mm512_set1_ps(1.0 / sum); for c in 0..chunks { let offset = base + c * F32_LANES; let v = _mm512_loadu_ps(a.add(offset)); let diff = _mm512_sub_ps(v, v_max); - let exp_v = exp_f32(diff); - _mm512_storeu_ps(out.add(offset), exp_v); - sum_vec = _mm512_add_ps(sum_vec, exp_v); - } - - let mut sum = _mm512_reduce_add_ps(sum_vec); - - // Scalar tail for exp and sum - for d in (chunks * F32_LANES)..dim_size { - let val = *a.add(base + d); - let exp_val = (val - max_val).exp(); - *out.add(base + d) = exp_val; - sum += exp_val; - } - - // Step 3: SIMD normalize by 1/sum - let v_inv_sum = _mm512_set1_ps(1.0 / sum); - - for c in 0..chunks { - let offset = base + c * F32_LANES; - let v = _mm512_loadu_ps(out.add(offset)); - let normalized = _mm512_mul_ps(v, v_inv_sum); + let normalized = _mm512_mul_ps(exp_f32(diff), v_inv_sum); _mm512_storeu_ps(out.add(offset), normalized); } - // Scalar tail for normalization let inv_sum = 1.0 / sum; for d in (chunks * F32_LANES)..dim_size { - *out.add(base + d) *= inv_sum; + let val = *a.add(base + d); + *out.add(base + d) = (val - max_val).exp() * inv_sum; } } } -/// AVX-512 softmax for f64 +/// AVX-512 softmax for f64 using online algorithm. #[target_feature(enable = "avx512f")] pub unsafe fn softmax_f64(a: *const f64, out: *mut f64, outer_size: usize, dim_size: usize) { let chunks = dim_size / F64_LANES; @@ -84,60 +104,78 @@ pub unsafe fn softmax_f64(a: *const f64, out: *mut f64, outer_size: usize, dim_s for o in 0..outer_size { let base = o * dim_size; - // Step 1: SIMD max-reduce + // Pass 1: Online max + sum let mut max_vec = _mm512_set1_pd(f64::NEG_INFINITY); + let mut sum_vec = _mm512_setzero_pd(); + for c in 0..chunks { let offset = base + c * F64_LANES; let v = _mm512_loadu_pd(a.add(offset)); + + let old_max = max_vec; max_vec = _mm512_max_pd(max_vec, v); + + // Guard: when old_max == max_vec == -inf, exp(-inf-(-inf)) = NaN. + // Use mask to zero out -inf lanes (their sum contribution is 0). + let neg_inf = _mm512_set1_pd(f64::NEG_INFINITY); + let valid_old = _mm512_cmp_pd_mask(old_max, neg_inf, _CMP_GT_OQ); + let rescale = exp_f64(_mm512_sub_pd(old_max, max_vec)); + let rescale = _mm512_maskz_mov_pd(valid_old, rescale); + sum_vec = _mm512_mul_pd(sum_vec, rescale); + + let valid_new = _mm512_cmp_pd_mask(max_vec, neg_inf, _CMP_GT_OQ); + let exp_v = exp_f64(_mm512_sub_pd(v, max_vec)); + let exp_v = _mm512_maskz_mov_pd(valid_new, exp_v); + sum_vec = _mm512_add_pd(sum_vec, exp_v); } + let mut max_val = _mm512_reduce_max_pd(max_vec); - // Scalar tail for max + // Scalar tail (online) + let mut tail_sum = 0.0f64; for d in (chunks * F64_LANES)..dim_size { let val = *a.add(base + d); if val > max_val { + let rescale = if max_val == f64::NEG_INFINITY { + 0.0 + } else { + (max_val - val).exp() + }; + tail_sum = tail_sum * rescale + 1.0; max_val = val; + } else if val == f64::NEG_INFINITY { + // skip + } else { + tail_sum += (val - max_val).exp(); } } - // Step 2: Compute exp(x - max) and accumulate sum + // Reconcile SIMD sum with global max + // Guard -inf lanes to avoid NaN from exp(-inf - (-inf)) + let v_global_max = _mm512_set1_pd(max_val); + let neg_inf = _mm512_set1_pd(f64::NEG_INFINITY); + let valid_mask = _mm512_cmp_pd_mask(max_vec, neg_inf, _CMP_GT_OQ); + let rescale = exp_f64(_mm512_sub_pd(max_vec, v_global_max)); + let rescale = _mm512_maskz_mov_pd(valid_mask, rescale); + let rescaled_sum = _mm512_mul_pd(sum_vec, rescale); + let sum = _mm512_reduce_add_pd(rescaled_sum) + tail_sum; + + // Pass 2: exp(x - max) / sum let v_max = _mm512_set1_pd(max_val); - let mut sum_vec = _mm512_setzero_pd(); + let v_inv_sum = _mm512_set1_pd(1.0 / sum); for c in 0..chunks { let offset = base + c * F64_LANES; let v = _mm512_loadu_pd(a.add(offset)); let diff = _mm512_sub_pd(v, v_max); - let exp_v = exp_f64(diff); - _mm512_storeu_pd(out.add(offset), exp_v); - sum_vec = _mm512_add_pd(sum_vec, exp_v); - } - - let mut sum = _mm512_reduce_add_pd(sum_vec); - - // Scalar tail for exp and sum - for d in (chunks * F64_LANES)..dim_size { - let val = *a.add(base + d); - let exp_val = (val - max_val).exp(); - *out.add(base + d) = exp_val; - sum += exp_val; - } - - // Step 3: SIMD normalize - let v_inv_sum = _mm512_set1_pd(1.0 / sum); - - for c in 0..chunks { - let offset = base + c * F64_LANES; - let v = _mm512_loadu_pd(out.add(offset)); - let normalized = _mm512_mul_pd(v, v_inv_sum); + let normalized = _mm512_mul_pd(exp_f64(diff), v_inv_sum); _mm512_storeu_pd(out.add(offset), normalized); } - // Scalar tail for normalization let inv_sum = 1.0 / sum; for d in (chunks * F64_LANES)..dim_size { - *out.add(base + d) *= inv_sum; + let val = *a.add(base + d); + *out.add(base + d) = (val - max_val).exp() * inv_sum; } } } diff --git a/src/runtime/cpu/kernels/simd/softmax/mod.rs b/src/runtime/cpu/kernels/simd/softmax/mod.rs index 9b76d1fd..8787e990 100644 --- a/src/runtime/cpu/kernels/simd/softmax/mod.rs +++ b/src/runtime/cpu/kernels/simd/softmax/mod.rs @@ -1,14 +1,19 @@ -//! SIMD-accelerated softmax operation +//! SIMD-accelerated softmax operation using the online softmax algorithm. //! //! Softmax is critical for attention mechanisms in transformers. //! softmax(x)[i] = exp(x[i] - max(x)) / sum(exp(x - max(x))) //! -//! # SIMD Optimizations +//! # Online Softmax Algorithm (2-pass) //! -//! - SIMD max-reduce for finding maximum -//! - SIMD exp computation (vectorized) -//! - SIMD sum-reduce for normalization -//! - SIMD multiply for final division +//! Instead of the traditional 3-pass approach (find max, compute exp+sum, normalize), +//! we use a 2-pass online algorithm: +//! +//! **Pass 1 (online max + sum):** For each element x[i], maintain running max `m` and +//! running sum `s`. When a new max is found, rescale the accumulated sum. +//! +//! **Pass 2 (normalize):** output[i] = exp(x[i] - m) / s +//! +//! This saves one full read+write pass over the output buffer compared to 3-pass. #[cfg(target_arch = "x86_64")] mod avx2; @@ -97,67 +102,128 @@ pub unsafe fn softmax_f64(a: *const f64, out: *mut f64, outer_size: usize, dim_s // Scalar fallbacks // ============================================================================ -/// Scalar softmax for f32 +/// Scalar softmax for f32 using online algorithm (2-pass). #[inline] pub unsafe fn softmax_scalar_f32(a: *const f32, out: *mut f32, outer_size: usize, dim_size: usize) { for o in 0..outer_size { let base = o * dim_size; - // Find max + // Pass 1: Online max + sum — single read of input let mut max_val = *a.add(base); + let mut sum = if max_val.is_finite() { 1.0f32 } else { 0.0f32 }; for d in 1..dim_size { let val = *a.add(base + d); if val > max_val { + // Guard: if max_val == -inf, rescale factor is 0 (not NaN) + let rescale = if max_val == f32::NEG_INFINITY { + 0.0 + } else { + (max_val - val).exp() + }; + sum = sum * rescale + 1.0; max_val = val; + } else if val == f32::NEG_INFINITY { + // exp(-inf - anything) = 0, skip to avoid NaN from -inf - (-inf) + } else { + sum += (val - max_val).exp(); } } - // Compute exp(x - max) and sum - let mut sum = 0.0f32; - for d in 0..dim_size { - let val = *a.add(base + d); - let exp_val = (val - max_val).exp(); - *out.add(base + d) = exp_val; - sum += exp_val; - } - - // Normalize + // Pass 2: Compute exp(x - max) / sum — one read of input, one write of output let inv_sum = 1.0 / sum; for d in 0..dim_size { - *out.add(base + d) *= inv_sum; + let val = *a.add(base + d); + *out.add(base + d) = if val == f32::NEG_INFINITY { + 0.0 + } else { + (val - max_val).exp() * inv_sum + }; } } } -/// Scalar softmax for f64 +/// Scalar softmax for f64 using online algorithm (2-pass). #[inline] pub unsafe fn softmax_scalar_f64(a: *const f64, out: *mut f64, outer_size: usize, dim_size: usize) { for o in 0..outer_size { let base = o * dim_size; - // Find max + // Pass 1: Online max + sum let mut max_val = *a.add(base); + let mut sum = if max_val.is_finite() { 1.0f64 } else { 0.0f64 }; for d in 1..dim_size { let val = *a.add(base + d); if val > max_val { + let rescale = if max_val == f64::NEG_INFINITY { + 0.0 + } else { + (max_val - val).exp() + }; + sum = sum * rescale + 1.0; max_val = val; + } else if val == f64::NEG_INFINITY { + // exp(-inf - anything) = 0, skip to avoid NaN from -inf - (-inf) + } else { + sum += (val - max_val).exp(); } } - // Compute exp(x - max) and sum - let mut sum = 0.0f64; + // Pass 2: Compute exp(x - max) / sum + let inv_sum = 1.0 / sum; for d in 0..dim_size { let val = *a.add(base + d); - let exp_val = (val - max_val).exp(); - *out.add(base + d) = exp_val; - sum += exp_val; + *out.add(base + d) = if val == f64::NEG_INFINITY { + 0.0 + } else { + (val - max_val).exp() * inv_sum + }; } + } +} - // Normalize - let inv_sum = 1.0 / sum; - for d in 0..dim_size { - *out.add(base + d) *= inv_sum; - } +#[cfg(feature = "f16")] +/// f16 wrapper for softmax: processes one row at a time via f32 conversion. +/// +/// # Safety +/// - `a` and `out` must point to `outer_size * dim_size` elements +pub unsafe fn softmax_f16( + a: *const half::f16, + out: *mut half::f16, + outer_size: usize, + dim_size: usize, +) { + use super::half_convert_utils::*; + let row_len = dim_size; + let mut a_buf = vec![0.0f32; row_len]; + let mut out_buf = vec![0.0f32; row_len]; + for i in 0..outer_size { + let offset = i * dim_size; + convert_f16_to_f32(a.add(offset) as *const u16, a_buf.as_mut_ptr(), row_len); + softmax_f32(a_buf.as_ptr(), out_buf.as_mut_ptr(), 1, dim_size); + convert_f32_to_f16(out_buf.as_ptr(), out.add(offset) as *mut u16, row_len); + } +} + +#[cfg(feature = "f16")] +/// bf16 wrapper for softmax: processes one row at a time via f32 conversion. +/// +/// # Safety +/// - `a` and `out` must point to `outer_size * dim_size` elements +pub unsafe fn softmax_bf16( + a: *const half::bf16, + out: *mut half::bf16, + outer_size: usize, + dim_size: usize, +) { + use super::half_convert_utils::*; + let row_len = dim_size; + let mut a_buf = vec![0.0f32; row_len]; + let mut out_buf = vec![0.0f32; row_len]; + for i in 0..outer_size { + let offset = i * dim_size; + convert_bf16_to_f32(a.add(offset) as *const u16, a_buf.as_mut_ptr(), row_len); + softmax_f32(a_buf.as_ptr(), out_buf.as_mut_ptr(), 1, dim_size); + convert_f32_to_bf16(out_buf.as_ptr(), out.add(offset) as *mut u16, row_len); } } diff --git a/src/runtime/cpu/kernels/simd/softmax_bwd/aarch64/mod.rs b/src/runtime/cpu/kernels/simd/softmax_bwd/aarch64/mod.rs new file mode 100644 index 00000000..ad60b5cd --- /dev/null +++ b/src/runtime/cpu/kernels/simd/softmax_bwd/aarch64/mod.rs @@ -0,0 +1,3 @@ +//! AArch64-specific softmax backward SIMD implementations + +pub mod neon; diff --git a/src/runtime/cpu/kernels/simd/softmax_bwd/aarch64/neon.rs b/src/runtime/cpu/kernels/simd/softmax_bwd/aarch64/neon.rs new file mode 100644 index 00000000..161cf7b5 --- /dev/null +++ b/src/runtime/cpu/kernels/simd/softmax_bwd/aarch64/neon.rs @@ -0,0 +1,121 @@ +//! NEON softmax backward kernels for ARM64. +//! +//! Fused 2-pass: SIMD dot product, then SIMD elementwise output * (grad - dot). + +#[cfg(target_arch = "aarch64")] +use std::arch::aarch64::*; + +use super::super::super::math::aarch64::neon::{hsum_f32, hsum_f64}; + +const F32_LANES: usize = 4; +const F64_LANES: usize = 2; + +/// NEON softmax backward for f32. +/// +/// # Safety +/// - CPU must support NEON (always true on AArch64) +/// - All pointers must point to `outer_size * dim_size` valid f32 elements +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +pub unsafe fn softmax_bwd_f32( + grad: *const f32, + output: *const f32, + d_input: *mut f32, + outer_size: usize, + dim_size: usize, +) { + let chunks = dim_size / F32_LANES; + let remainder = dim_size % F32_LANES; + + for o in 0..outer_size { + let g_base = grad.add(o * dim_size); + let o_base = output.add(o * dim_size); + let d_base = d_input.add(o * dim_size); + + // Pass 1: SIMD dot product + let mut dot_acc = vdupq_n_f32(0.0); + for i in 0..chunks { + let offset = i * F32_LANES; + let g = vld1q_f32(g_base.add(offset)); + let out = vld1q_f32(o_base.add(offset)); + dot_acc = vfmaq_f32(dot_acc, g, out); + } + let mut dot = hsum_f32(dot_acc); + + for i in 0..remainder { + let offset = chunks * F32_LANES + i; + dot += *g_base.add(offset) * *o_base.add(offset); + } + + // Pass 2: d_input = output * (grad - dot) + let v_dot = vdupq_n_f32(dot); + for i in 0..chunks { + let offset = i * F32_LANES; + let g = vld1q_f32(g_base.add(offset)); + let out = vld1q_f32(o_base.add(offset)); + let shifted = vsubq_f32(g, v_dot); + let result = vmulq_f32(out, shifted); + vst1q_f32(d_base.add(offset), result); + } + + for i in 0..remainder { + let offset = chunks * F32_LANES + i; + *d_base.add(offset) = *o_base.add(offset) * (*g_base.add(offset) - dot); + } + } +} + +/// NEON softmax backward for f64. +/// +/// # Safety +/// - CPU must support NEON (always true on AArch64) +/// - All pointers must point to `outer_size * dim_size` valid f64 elements +#[cfg(target_arch = "aarch64")] +#[target_feature(enable = "neon")] +pub unsafe fn softmax_bwd_f64( + grad: *const f64, + output: *const f64, + d_input: *mut f64, + outer_size: usize, + dim_size: usize, +) { + let chunks = dim_size / F64_LANES; + let remainder = dim_size % F64_LANES; + + for o in 0..outer_size { + let g_base = grad.add(o * dim_size); + let o_base = output.add(o * dim_size); + let d_base = d_input.add(o * dim_size); + + // Pass 1: SIMD dot product + let mut dot_acc = vdupq_n_f64(0.0); + for i in 0..chunks { + let offset = i * F64_LANES; + let g = vld1q_f64(g_base.add(offset)); + let out = vld1q_f64(o_base.add(offset)); + dot_acc = vfmaq_f64(dot_acc, g, out); + } + let mut dot = hsum_f64(dot_acc); + + for i in 0..remainder { + let offset = chunks * F64_LANES + i; + dot += *g_base.add(offset) * *o_base.add(offset); + } + + // Pass 2: d_input = output * (grad - dot) + let v_dot = vdupq_n_f64(dot); + for i in 0..chunks { + let offset = i * F64_LANES; + let g = vld1q_f64(g_base.add(offset)); + let out = vld1q_f64(o_base.add(offset)); + let shifted = vsubq_f64(g, v_dot); + let result = vmulq_f64(out, shifted); + vst1q_f64(d_base.add(offset), result); + } + + for i in 0..remainder { + let offset = chunks * F64_LANES + i; + *d_base.add(offset) = *o_base.add(offset) * (*g_base.add(offset) - dot); + } + } +} diff --git a/src/runtime/cpu/kernels/simd/softmax_bwd/avx2.rs b/src/runtime/cpu/kernels/simd/softmax_bwd/avx2.rs new file mode 100644 index 00000000..5af5d990 --- /dev/null +++ b/src/runtime/cpu/kernels/simd/softmax_bwd/avx2.rs @@ -0,0 +1,105 @@ +//! AVX2 softmax backward kernels. +//! +//! Fused 2-pass: SIMD dot product, then SIMD elementwise output * (grad - dot). + +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +use super::super::math::avx2::{hsum_f32, hsum_f64}; + +const F32_LANES: usize = 8; +const F64_LANES: usize = 4; + +/// AVX2 softmax backward for f32. +#[target_feature(enable = "avx2", enable = "fma")] +pub unsafe fn softmax_bwd_f32( + grad: *const f32, + output: *const f32, + d_input: *mut f32, + outer_size: usize, + dim_size: usize, +) { + let chunks = dim_size / F32_LANES; + + for o in 0..outer_size { + let base = o * dim_size; + + // Pass 1: SIMD dot product — dot = sum(grad * output) + let mut dot_vec = _mm256_setzero_ps(); + for c in 0..chunks { + let offset = base + c * F32_LANES; + let g = _mm256_loadu_ps(grad.add(offset)); + let out = _mm256_loadu_ps(output.add(offset)); + dot_vec = _mm256_fmadd_ps(g, out, dot_vec); + } + let mut dot = hsum_f32(dot_vec); + + // Scalar tail for dot + for d in (chunks * F32_LANES)..dim_size { + dot += *grad.add(base + d) * *output.add(base + d); + } + + // Pass 2: SIMD d_input = output * (grad - dot) + let v_dot = _mm256_set1_ps(dot); + for c in 0..chunks { + let offset = base + c * F32_LANES; + let g = _mm256_loadu_ps(grad.add(offset)); + let out = _mm256_loadu_ps(output.add(offset)); + let shifted = _mm256_sub_ps(g, v_dot); + let result = _mm256_mul_ps(out, shifted); + _mm256_storeu_ps(d_input.add(offset), result); + } + + // Scalar tail + for d in (chunks * F32_LANES)..dim_size { + let idx = base + d; + *d_input.add(idx) = *output.add(idx) * (*grad.add(idx) - dot); + } + } +} + +/// AVX2 softmax backward for f64. +#[target_feature(enable = "avx2", enable = "fma")] +pub unsafe fn softmax_bwd_f64( + grad: *const f64, + output: *const f64, + d_input: *mut f64, + outer_size: usize, + dim_size: usize, +) { + let chunks = dim_size / F64_LANES; + + for o in 0..outer_size { + let base = o * dim_size; + + // Pass 1: SIMD dot product + let mut dot_vec = _mm256_setzero_pd(); + for c in 0..chunks { + let offset = base + c * F64_LANES; + let g = _mm256_loadu_pd(grad.add(offset)); + let out = _mm256_loadu_pd(output.add(offset)); + dot_vec = _mm256_fmadd_pd(g, out, dot_vec); + } + let mut dot = hsum_f64(dot_vec); + + for d in (chunks * F64_LANES)..dim_size { + dot += *grad.add(base + d) * *output.add(base + d); + } + + // Pass 2: d_input = output * (grad - dot) + let v_dot = _mm256_set1_pd(dot); + for c in 0..chunks { + let offset = base + c * F64_LANES; + let g = _mm256_loadu_pd(grad.add(offset)); + let out = _mm256_loadu_pd(output.add(offset)); + let shifted = _mm256_sub_pd(g, v_dot); + let result = _mm256_mul_pd(out, shifted); + _mm256_storeu_pd(d_input.add(offset), result); + } + + for d in (chunks * F64_LANES)..dim_size { + let idx = base + d; + *d_input.add(idx) = *output.add(idx) * (*grad.add(idx) - dot); + } + } +} diff --git a/src/runtime/cpu/kernels/simd/softmax_bwd/avx512.rs b/src/runtime/cpu/kernels/simd/softmax_bwd/avx512.rs new file mode 100644 index 00000000..61eded71 --- /dev/null +++ b/src/runtime/cpu/kernels/simd/softmax_bwd/avx512.rs @@ -0,0 +1,101 @@ +//! AVX-512 softmax backward kernels. +//! +//! Fused 2-pass: SIMD dot product, then SIMD elementwise output * (grad - dot). + +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +const F32_LANES: usize = 16; +const F64_LANES: usize = 8; + +/// AVX-512 softmax backward for f32. +#[target_feature(enable = "avx512f")] +pub unsafe fn softmax_bwd_f32( + grad: *const f32, + output: *const f32, + d_input: *mut f32, + outer_size: usize, + dim_size: usize, +) { + let chunks = dim_size / F32_LANES; + + for o in 0..outer_size { + let base = o * dim_size; + + // Pass 1: SIMD dot product + let mut dot_vec = _mm512_setzero_ps(); + for c in 0..chunks { + let offset = base + c * F32_LANES; + let g = _mm512_loadu_ps(grad.add(offset)); + let out = _mm512_loadu_ps(output.add(offset)); + dot_vec = _mm512_fmadd_ps(g, out, dot_vec); + } + let mut dot = _mm512_reduce_add_ps(dot_vec); + + for d in (chunks * F32_LANES)..dim_size { + dot += *grad.add(base + d) * *output.add(base + d); + } + + // Pass 2: d_input = output * (grad - dot) + let v_dot = _mm512_set1_ps(dot); + for c in 0..chunks { + let offset = base + c * F32_LANES; + let g = _mm512_loadu_ps(grad.add(offset)); + let out = _mm512_loadu_ps(output.add(offset)); + let shifted = _mm512_sub_ps(g, v_dot); + let result = _mm512_mul_ps(out, shifted); + _mm512_storeu_ps(d_input.add(offset), result); + } + + for d in (chunks * F32_LANES)..dim_size { + let idx = base + d; + *d_input.add(idx) = *output.add(idx) * (*grad.add(idx) - dot); + } + } +} + +/// AVX-512 softmax backward for f64. +#[target_feature(enable = "avx512f")] +pub unsafe fn softmax_bwd_f64( + grad: *const f64, + output: *const f64, + d_input: *mut f64, + outer_size: usize, + dim_size: usize, +) { + let chunks = dim_size / F64_LANES; + + for o in 0..outer_size { + let base = o * dim_size; + + // Pass 1: SIMD dot product + let mut dot_vec = _mm512_setzero_pd(); + for c in 0..chunks { + let offset = base + c * F64_LANES; + let g = _mm512_loadu_pd(grad.add(offset)); + let out = _mm512_loadu_pd(output.add(offset)); + dot_vec = _mm512_fmadd_pd(g, out, dot_vec); + } + let mut dot = _mm512_reduce_add_pd(dot_vec); + + for d in (chunks * F64_LANES)..dim_size { + dot += *grad.add(base + d) * *output.add(base + d); + } + + // Pass 2: d_input = output * (grad - dot) + let v_dot = _mm512_set1_pd(dot); + for c in 0..chunks { + let offset = base + c * F64_LANES; + let g = _mm512_loadu_pd(grad.add(offset)); + let out = _mm512_loadu_pd(output.add(offset)); + let shifted = _mm512_sub_pd(g, v_dot); + let result = _mm512_mul_pd(out, shifted); + _mm512_storeu_pd(d_input.add(offset), result); + } + + for d in (chunks * F64_LANES)..dim_size { + let idx = base + d; + *d_input.add(idx) = *output.add(idx) * (*grad.add(idx) - dot); + } + } +} diff --git a/src/runtime/cpu/kernels/simd/softmax_bwd/mod.rs b/src/runtime/cpu/kernels/simd/softmax_bwd/mod.rs new file mode 100644 index 00000000..a777591e --- /dev/null +++ b/src/runtime/cpu/kernels/simd/softmax_bwd/mod.rs @@ -0,0 +1,326 @@ +//! SIMD-accelerated softmax backward operation. +//! +//! Computes: d_input[i] = output[i] * (grad[i] - dot) +//! where dot = sum(grad * output) along the softmax dimension. +//! +//! Fused 2-pass kernel: +//! - Pass 1: SIMD dot product (grad * output, reduced to scalar) +//! - Pass 2: SIMD elementwise output * (grad - dot) + +#[cfg(target_arch = "x86_64")] +mod avx2; +#[cfg(target_arch = "x86_64")] +mod avx512; + +#[cfg(target_arch = "aarch64")] +mod aarch64; + +use super::{SimdLevel, detect_simd}; + +/// Minimum dimension size to justify SIMD overhead +const SIMD_THRESHOLD: usize = 32; + +/// SIMD softmax backward for f32 +/// +/// # Safety +/// - `grad`, `output`, `d_input` must point to `outer_size * dim_size` elements +#[inline] +pub unsafe fn softmax_bwd_f32( + grad: *const f32, + output: *const f32, + d_input: *mut f32, + outer_size: usize, + dim_size: usize, +) { + let level = detect_simd(); + + if dim_size < SIMD_THRESHOLD || level == SimdLevel::Scalar { + softmax_bwd_scalar_f32(grad, output, d_input, outer_size, dim_size); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => avx512::softmax_bwd_f32(grad, output, d_input, outer_size, dim_size), + SimdLevel::Avx2Fma => avx2::softmax_bwd_f32(grad, output, d_input, outer_size, dim_size), + _ => softmax_bwd_scalar_f32(grad, output, d_input, outer_size, dim_size), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => { + aarch64::neon::softmax_bwd_f32(grad, output, d_input, outer_size, dim_size) + } + _ => softmax_bwd_scalar_f32(grad, output, d_input, outer_size, dim_size), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + softmax_bwd_scalar_f32(grad, output, d_input, outer_size, dim_size); +} + +/// SIMD softmax backward for f64 +/// +/// # Safety +/// - `grad`, `output`, `d_input` must point to `outer_size * dim_size` elements +#[inline] +pub unsafe fn softmax_bwd_f64( + grad: *const f64, + output: *const f64, + d_input: *mut f64, + outer_size: usize, + dim_size: usize, +) { + let level = detect_simd(); + + if dim_size < SIMD_THRESHOLD || level == SimdLevel::Scalar { + softmax_bwd_scalar_f64(grad, output, d_input, outer_size, dim_size); + return; + } + + #[cfg(target_arch = "x86_64")] + match level { + SimdLevel::Avx512 => avx512::softmax_bwd_f64(grad, output, d_input, outer_size, dim_size), + SimdLevel::Avx2Fma => avx2::softmax_bwd_f64(grad, output, d_input, outer_size, dim_size), + _ => softmax_bwd_scalar_f64(grad, output, d_input, outer_size, dim_size), + } + + #[cfg(target_arch = "aarch64")] + match level { + SimdLevel::Neon | SimdLevel::NeonFp16 => { + aarch64::neon::softmax_bwd_f64(grad, output, d_input, outer_size, dim_size) + } + _ => softmax_bwd_scalar_f64(grad, output, d_input, outer_size, dim_size), + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + softmax_bwd_scalar_f64(grad, output, d_input, outer_size, dim_size); +} + +// ============================================================================ +// Scalar fallbacks +// ============================================================================ + +/// Scalar softmax backward for f32 +#[inline] +pub unsafe fn softmax_bwd_scalar_f32( + grad: *const f32, + output: *const f32, + d_input: *mut f32, + outer_size: usize, + dim_size: usize, +) { + for o in 0..outer_size { + let base = o * dim_size; + + // Pass 1: dot = sum(grad * output) + let mut dot = 0.0f32; + for d in 0..dim_size { + dot += *grad.add(base + d) * *output.add(base + d); + } + + // Pass 2: d_input = output * (grad - dot) + for d in 0..dim_size { + let idx = base + d; + *d_input.add(idx) = *output.add(idx) * (*grad.add(idx) - dot); + } + } +} + +/// Scalar softmax backward for f64 +#[inline] +pub unsafe fn softmax_bwd_scalar_f64( + grad: *const f64, + output: *const f64, + d_input: *mut f64, + outer_size: usize, + dim_size: usize, +) { + for o in 0..outer_size { + let base = o * dim_size; + + // Pass 1: dot = sum(grad * output) + let mut dot = 0.0f64; + for d in 0..dim_size { + dot += *grad.add(base + d) * *output.add(base + d); + } + + // Pass 2: d_input = output * (grad - dot) + for d in 0..dim_size { + let idx = base + d; + *d_input.add(idx) = *output.add(idx) * (*grad.add(idx) - dot); + } + } +} + +#[cfg(feature = "f16")] +/// f16 wrapper for softmax backward: processes one row at a time via f32 conversion. +/// +/// # Safety +/// - All pointers must point to `outer_size * dim_size` elements +pub unsafe fn softmax_bwd_f16( + grad: *const half::f16, + output: *const half::f16, + d_input: *mut half::f16, + outer_size: usize, + dim_size: usize, +) { + use super::half_convert_utils::*; + let row_len = dim_size; + let mut grad_buf = vec![0.0f32; row_len]; + let mut out_buf = vec![0.0f32; row_len]; + let mut result_buf = vec![0.0f32; row_len]; + for i in 0..outer_size { + let offset = i * dim_size; + convert_f16_to_f32( + grad.add(offset) as *const u16, + grad_buf.as_mut_ptr(), + row_len, + ); + convert_f16_to_f32( + output.add(offset) as *const u16, + out_buf.as_mut_ptr(), + row_len, + ); + softmax_bwd_f32( + grad_buf.as_ptr(), + out_buf.as_ptr(), + result_buf.as_mut_ptr(), + 1, + dim_size, + ); + convert_f32_to_f16( + result_buf.as_ptr(), + d_input.add(offset) as *mut u16, + row_len, + ); + } +} + +#[cfg(feature = "f16")] +/// bf16 wrapper for softmax backward: processes one row at a time via f32 conversion. +/// +/// # Safety +/// - All pointers must point to `outer_size * dim_size` elements +pub unsafe fn softmax_bwd_bf16( + grad: *const half::bf16, + output: *const half::bf16, + d_input: *mut half::bf16, + outer_size: usize, + dim_size: usize, +) { + use super::half_convert_utils::*; + let row_len = dim_size; + let mut grad_buf = vec![0.0f32; row_len]; + let mut out_buf = vec![0.0f32; row_len]; + let mut result_buf = vec![0.0f32; row_len]; + for i in 0..outer_size { + let offset = i * dim_size; + convert_bf16_to_f32( + grad.add(offset) as *const u16, + grad_buf.as_mut_ptr(), + row_len, + ); + convert_bf16_to_f32( + output.add(offset) as *const u16, + out_buf.as_mut_ptr(), + row_len, + ); + softmax_bwd_f32( + grad_buf.as_ptr(), + out_buf.as_ptr(), + result_buf.as_mut_ptr(), + 1, + dim_size, + ); + convert_f32_to_bf16( + result_buf.as_ptr(), + d_input.add(offset) as *mut u16, + row_len, + ); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_softmax_bwd_f32() { + // softmax output that sums to 1 + let output = [0.09003057f32, 0.24472847, 0.66524096]; // softmax([1,2,3]) + let grad = [1.0f32, 0.0, 0.0]; // d_loss/d_softmax + let mut d_input = [0.0f32; 3]; + + unsafe { + softmax_bwd_f32(grad.as_ptr(), output.as_ptr(), d_input.as_mut_ptr(), 1, 3); + } + + // dot = 1.0 * 0.09003057 = 0.09003057 + // d_input[0] = 0.09003057 * (1.0 - 0.09003057) = 0.0819 + // d_input[1] = 0.24472847 * (0.0 - 0.09003057) = -0.02203 + // d_input[2] = 0.66524096 * (0.0 - 0.09003057) = -0.05989 + assert!((d_input[0] - 0.08192507).abs() < 1e-5); + assert!((d_input[1] - (-0.02203645)).abs() < 1e-5); + assert!((d_input[2] - (-0.05988862)).abs() < 1e-5); + + // Gradients should sum to 0 (softmax outputs sum to 1, so Jacobian rows sum to 0) + let sum: f32 = d_input.iter().sum(); + assert!(sum.abs() < 1e-6, "gradients should sum to 0, got {sum}"); + } + + #[test] + fn test_softmax_bwd_simd() { + let dim_size = 128; + let outer_size = 4; + + // Create valid softmax outputs (sum to 1 per row) + let mut output = vec![0.0f32; outer_size * dim_size]; + for o in 0..outer_size { + let base = o * dim_size; + let sum: f32 = (0..dim_size).map(|d| ((d as f32) * 0.1 - 5.0).exp()).sum(); + for d in 0..dim_size { + output[base + d] = ((d as f32) * 0.1 - 5.0).exp() / sum; + } + } + + let grad: Vec = (0..(outer_size * dim_size)) + .map(|x| (x as f32) / 100.0 - 2.5) + .collect(); + + let mut d_input_simd = vec![0.0f32; outer_size * dim_size]; + let mut d_input_ref = vec![0.0f32; outer_size * dim_size]; + + unsafe { + softmax_bwd_f32( + grad.as_ptr(), + output.as_ptr(), + d_input_simd.as_mut_ptr(), + outer_size, + dim_size, + ); + softmax_bwd_scalar_f32( + grad.as_ptr(), + output.as_ptr(), + d_input_ref.as_mut_ptr(), + outer_size, + dim_size, + ); + } + + for i in 0..(outer_size * dim_size) { + let rel_err = if d_input_ref[i].abs() > 1e-10 { + (d_input_simd[i] - d_input_ref[i]).abs() / d_input_ref[i].abs() + } else { + (d_input_simd[i] - d_input_ref[i]).abs() + }; + assert!( + rel_err < 1e-3, + "mismatch at {}: {} vs {} (rel_err: {})", + i, + d_input_simd[i], + d_input_ref[i], + rel_err + ); + } + } +} diff --git a/src/runtime/cpu/kernels/simd/special/aarch64/neon.rs b/src/runtime/cpu/kernels/simd/special/aarch64/neon.rs index 4bc37ccc..503c136d 100644 --- a/src/runtime/cpu/kernels/simd/special/aarch64/neon.rs +++ b/src/runtime/cpu/kernels/simd/special/aarch64/neon.rs @@ -18,7 +18,7 @@ use std::arch::aarch64::*; use crate::algorithm::special::scalar::{ - bessel_i0_scalar, bessel_i1_scalar, bessel_j0_scalar, bessel_j1_scalar, erf_scalar, erfc_scalar, + bessel_i0_scalar, bessel_i1_scalar, bessel_j0_scalar, bessel_j1_scalar, erf_scalar, }; // ============================================================================ diff --git a/src/runtime/cpu/kernels/simd/special/mod.rs b/src/runtime/cpu/kernels/simd/special/mod.rs index c331b369..ac860328 100644 --- a/src/runtime/cpu/kernels/simd/special/mod.rs +++ b/src/runtime/cpu/kernels/simd/special/mod.rs @@ -191,6 +191,17 @@ impl_scalar_only!(gamma); impl_scalar_only!(lgamma); impl_scalar_only!(digamma); +// F16/BF16 Wrappers via macros +half_unary!(erf, erf_f32); +half_unary!(erfc, erfc_f32); +half_unary!(bessel_j0, bessel_j0_f32); +half_unary!(bessel_j1, bessel_j1_f32); +half_unary!(bessel_i0, bessel_i0_f32); +half_unary!(bessel_i1, bessel_i1_f32); +half_unary!(gamma, gamma_f32); +half_unary!(lgamma, lgamma_f32); +half_unary!(digamma, digamma_f32); + // ============================================================================ // Tests // ============================================================================ diff --git a/src/runtime/cpu/kernels/simd/unary/aarch64/neon.rs b/src/runtime/cpu/kernels/simd/unary/aarch64/neon.rs index ad5889b2..22c30749 100644 --- a/src/runtime/cpu/kernels/simd/unary/aarch64/neon.rs +++ b/src/runtime/cpu/kernels/simd/unary/aarch64/neon.rs @@ -117,11 +117,6 @@ pub unsafe fn unary_f32(op: UnaryOp, a: *const f32, out: *mut f32, len: usize) { UnaryOp::Asinh => unary_transcendental_f32(a, out, chunks, math::asinh_f32), UnaryOp::Acosh => unary_transcendental_f32(a, out, chunks, math::acosh_f32), UnaryOp::Atanh => unary_transcendental_f32(a, out, chunks, math::atanh_f32), - _ => { - // Unsupported ops handled above - unary_scalar_f32(op, a, out, len); - return; - } } if remainder > 0 { @@ -181,11 +176,6 @@ pub unsafe fn unary_f64(op: UnaryOp, a: *const f64, out: *mut f64, len: usize) { UnaryOp::Asinh => unary_transcendental_f64(a, out, chunks, math::asinh_f64), UnaryOp::Acosh => unary_transcendental_f64(a, out, chunks, math::acosh_f64), UnaryOp::Atanh => unary_transcendental_f64(a, out, chunks, math::atanh_f64), - _ => { - // Unsupported ops handled above - unary_scalar_f64(op, a, out, len); - return; - } } if remainder > 0 { diff --git a/src/runtime/cpu/kernels/simd/unary/mod.rs b/src/runtime/cpu/kernels/simd/unary/mod.rs index 8ab5ca26..fd725870 100644 --- a/src/runtime/cpu/kernels/simd/unary/mod.rs +++ b/src/runtime/cpu/kernels/simd/unary/mod.rs @@ -194,6 +194,13 @@ pub unsafe fn relu_f64(a: *const f64, out: *mut f64, len: usize) { relu_scalar_f64(a, out, len); } +// --------------------------------------------------------------------------- +// f16/bf16 via f32 block-convert-compute +// --------------------------------------------------------------------------- + +half_unary_op!(unary, unary_f32, UnaryOp); +half_unary!(relu, relu_f32); + #[cfg(test)] mod tests { use super::*; diff --git a/src/runtime/cpu/kernels/simd/where_select/mod.rs b/src/runtime/cpu/kernels/simd/where_select/mod.rs index 60eaebeb..a2ce3026 100644 --- a/src/runtime/cpu/kernels/simd/where_select/mod.rs +++ b/src/runtime/cpu/kernels/simd/where_select/mod.rs @@ -122,6 +122,8 @@ pub unsafe fn where_scalar_f64( } } +half_where!(r#where, where_f32); + #[cfg(test)] mod tests { use super::*; diff --git a/src/runtime/cpu/kernels/sparse_24.rs b/src/runtime/cpu/kernels/sparse_24.rs new file mode 100644 index 00000000..695945a8 --- /dev/null +++ b/src/runtime/cpu/kernels/sparse_24.rs @@ -0,0 +1,218 @@ +//! CPU kernels for 2:4 structured sparsity +//! +//! Low-level kernels for pruning to 2:4 format, decompression, and sparse matmul. + +use crate::dtype::Element; + +/// Prune a dense [M, K] matrix to 2:4 structured sparsity. +/// +/// For each group of 4 elements along K, keeps the 2 with largest magnitude. +/// +/// # Arguments +/// * `dense` - Input data, row-major [M, K] +/// * `compressed` - Output compressed values, row-major [M, K/2] +/// * `metadata` - Output packed metadata, row-major [M, meta_cols] as u32 +/// * `m` - Number of rows +/// * `k` - Number of columns (must be divisible by 4) +/// +/// # Safety +/// Caller must ensure all pointers are valid and buffers are correctly sized. +pub unsafe fn prune_to_24_kernel( + dense: *const T, + compressed: *mut T, + metadata: *mut u32, + m: usize, + k: usize, +) { + let num_groups = k / 4; + let meta_cols = (num_groups + 7) / 8; + let half_k = k / 2; + + for row in 0..m { + let row_in = dense.add(row * k); + let row_out = compressed.add(row * half_k); + let row_meta = metadata.add(row * meta_cols); + + // Zero out metadata + for mc in 0..meta_cols { + *row_meta.add(mc) = 0; + } + + let mut out_idx = 0usize; + + for g in 0..num_groups { + let base = g * 4; + let vals = [ + *row_in.add(base), + *row_in.add(base + 1), + *row_in.add(base + 2), + *row_in.add(base + 3), + ]; + + // Compute magnitudes and find top-2 + let mags: [f64; 4] = [ + vals[0].to_f64().abs(), + vals[1].to_f64().abs(), + vals[2].to_f64().abs(), + vals[3].to_f64().abs(), + ]; + + // Find the 2 largest magnitudes (stable: prefer earlier indices on tie) + let (idx0, idx1) = top_2_indices(&mags); + + // Write compressed values (lower index first) + let (first, second) = if idx0 < idx1 { + (idx0, idx1) + } else { + (idx1, idx0) + }; + *row_out.add(out_idx) = vals[first]; + *row_out.add(out_idx + 1) = vals[second]; + out_idx += 2; + + // Build 4-bit bitmask: bit i set means position i is kept + let mask: u32 = (1 << first) | (1 << second); + + // Pack into metadata word + let word_idx = g / 8; + let nibble_idx = g % 8; + let word = row_meta.add(word_idx); + *word |= mask << (nibble_idx * 4); + } + } +} + +/// Find indices of the 2 largest values in a 4-element array. +/// On ties, prefers earlier indices. +#[inline] +fn top_2_indices(mags: &[f64; 4]) -> (usize, usize) { + // Simple approach: find max, then find second max + let mut indices = [0usize, 1, 2, 3]; + // Sort by magnitude descending, stable (preserves order on ties) + indices.sort_by(|&a, &b| { + mags[b] + .partial_cmp(&mags[a]) + .unwrap_or(std::cmp::Ordering::Equal) + }); + (indices[0], indices[1]) +} + +/// Decompress a 2:4 sparse tensor back to dense format. +/// +/// # Arguments +/// * `compressed` - Input compressed values, row-major [M, K/2] +/// * `metadata` - Input packed metadata, row-major [M, meta_cols] as u32 +/// * `dense` - Output dense values, row-major [M, K] +/// * `m` - Number of rows +/// * `k` - Number of columns +/// +/// # Safety +/// Caller must ensure all pointers are valid and buffers are correctly sized. +pub unsafe fn decompress_24_kernel( + compressed: *const T, + metadata: *const u32, + dense: *mut T, + m: usize, + k: usize, +) { + let num_groups = k / 4; + let meta_cols = (num_groups + 7) / 8; + let half_k = k / 2; + let zero = T::zeroed(); + + for row in 0..m { + let row_in = compressed.add(row * half_k); + let row_meta = metadata.add(row * meta_cols); + let row_out = dense.add(row * k); + + let mut in_idx = 0usize; + + for g in 0..num_groups { + let base = g * 4; + let word_idx = g / 8; + let nibble_idx = g % 8; + let word = *row_meta.add(word_idx); + let mask = (word >> (nibble_idx * 4)) & 0xF; + + // Write zeros first, then overwrite kept positions + for i in 0..4 { + *row_out.add(base + i) = zero; + } + + // Place the 2 compressed values at their original positions + for bit in 0..4u32 { + if mask & (1 << bit) != 0 { + *row_out.add(base + bit as usize) = *row_in.add(in_idx); + in_idx += 1; + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_prune_roundtrip_f32() { + // Dense matrix: 2x8 + let dense: Vec = vec![ + 1.0, -3.0, 2.0, 0.5, // group 0: keep -3.0 (idx1), 2.0 (idx2) + 0.1, 0.2, 0.3, 0.4, // group 1: keep 0.3 (idx2), 0.4 (idx3) + 4.0, 1.0, -5.0, 3.0, // group 2: keep 4.0 (idx0), -5.0 (idx2) + 0.0, 0.0, 0.0, 0.0, // group 3: all zero, keep first 2 + ]; + let m = 2; + let k = 8; + let half_k = k / 2; + let meta_cols = 1; // 2 groups per row, fits in 1 u32 + + let mut compressed = vec![0.0f32; m * half_k]; + let mut metadata = vec![0u32; m * meta_cols]; + + unsafe { + prune_to_24_kernel( + dense.as_ptr(), + compressed.as_mut_ptr(), + metadata.as_mut_ptr(), + m, + k, + ); + } + + // Verify: group 0 (row 0): -3.0 (idx1) and 2.0 (idx2) are top-2 + // compressed[0] should be -3.0 (idx1), compressed[1] should be 2.0 (idx2) + // (sorted by index: idx1 < idx2) + assert_eq!(compressed[0], -3.0); + assert_eq!(compressed[1], 2.0); + + // Now decompress and verify roundtrip + let mut reconstructed = vec![0.0f32; m * k]; + unsafe { + decompress_24_kernel( + compressed.as_ptr(), + metadata.as_ptr(), + reconstructed.as_mut_ptr(), + m, + k, + ); + } + + // Row 0, group 0: positions 1,2 kept → [0, -3, 2, 0] + assert_eq!(reconstructed[0], 0.0); + assert_eq!(reconstructed[1], -3.0); + assert_eq!(reconstructed[2], 2.0); + assert_eq!(reconstructed[3], 0.0); + } + + #[test] + fn test_top_2_indices() { + // Basic case + assert_eq!(top_2_indices(&[1.0, 3.0, 2.0, 0.5]), (1, 2)); + // Ties: prefer earlier indices + assert_eq!(top_2_indices(&[1.0, 1.0, 1.0, 1.0]), (0, 1)); + // Negative magnitudes (should not happen since we pass abs, but test anyway) + assert_eq!(top_2_indices(&[0.0, 0.0, 0.0, 0.0]), (0, 1)); + } +} diff --git a/src/runtime/cpu/kernels/unary/activations.rs b/src/runtime/cpu/kernels/unary/activations.rs index 11126a2c..46f69e03 100644 --- a/src/runtime/cpu/kernels/unary/activations.rs +++ b/src/runtime/cpu/kernels/unary/activations.rs @@ -3,7 +3,7 @@ //! Provides element-wise activation functions with automatic SIMD dispatch. //! On x86-64, f32 and f64 operations use AVX-512 or AVX2 when available. -use crate::dtype::{DType, Element}; +use crate::dtype::Element; /// Sigmoid activation: 1 / (1 + exp(-x)) /// @@ -18,6 +18,7 @@ pub unsafe fn sigmoid_kernel(a: *const T, out: *mut T, len: usize) { #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] { use super::super::simd::activations; + use crate::dtype::DType; match T::DTYPE { DType::F32 => { @@ -28,6 +29,16 @@ pub unsafe fn sigmoid_kernel(a: *const T, out: *mut T, len: usize) { activations::sigmoid_f64(a as *const f64, out as *mut f64, len); return; } + #[cfg(feature = "f16")] + DType::F16 => { + activations::sigmoid_f16(a as *const half::f16, out as *mut half::f16, len); + return; + } + #[cfg(feature = "f16")] + DType::BF16 => { + activations::sigmoid_bf16(a as *const half::bf16, out as *mut half::bf16, len); + return; + } _ => {} } } @@ -59,6 +70,7 @@ pub unsafe fn silu_kernel(a: *const T, out: *mut T, len: usize) { #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] { use super::super::simd::activations; + use crate::dtype::DType; match T::DTYPE { DType::F32 => { @@ -69,6 +81,16 @@ pub unsafe fn silu_kernel(a: *const T, out: *mut T, len: usize) { activations::silu_f64(a as *const f64, out as *mut f64, len); return; } + #[cfg(feature = "f16")] + DType::F16 => { + activations::silu_f16(a as *const half::f16, out as *mut half::f16, len); + return; + } + #[cfg(feature = "f16")] + DType::BF16 => { + activations::silu_bf16(a as *const half::bf16, out as *mut half::bf16, len); + return; + } _ => {} } } @@ -102,6 +124,7 @@ pub unsafe fn gelu_kernel(a: *const T, out: *mut T, len: usize) { #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] { use super::super::simd::activations; + use crate::dtype::DType; match T::DTYPE { DType::F32 => { @@ -112,6 +135,16 @@ pub unsafe fn gelu_kernel(a: *const T, out: *mut T, len: usize) { activations::gelu_f64(a as *const f64, out as *mut f64, len); return; } + #[cfg(feature = "f16")] + DType::F16 => { + activations::gelu_f16(a as *const half::f16, out as *mut half::f16, len); + return; + } + #[cfg(feature = "f16")] + DType::BF16 => { + activations::gelu_bf16(a as *const half::bf16, out as *mut half::bf16, len); + return; + } _ => {} } } @@ -150,6 +183,7 @@ pub unsafe fn leaky_relu_kernel( #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] { use super::super::simd::activations; + use crate::dtype::DType; match T::DTYPE { DType::F32 => { @@ -165,6 +199,26 @@ pub unsafe fn leaky_relu_kernel( activations::leaky_relu_f64(a as *const f64, out as *mut f64, len, negative_slope); return; } + #[cfg(feature = "f16")] + DType::F16 => { + activations::leaky_relu_f16( + a as *const half::f16, + out as *mut half::f16, + len, + negative_slope as f32, + ); + return; + } + #[cfg(feature = "f16")] + DType::BF16 => { + activations::leaky_relu_bf16( + a as *const half::bf16, + out as *mut half::bf16, + len, + negative_slope as f32, + ); + return; + } _ => {} } } @@ -198,6 +252,7 @@ pub unsafe fn elu_kernel(a: *const T, out: *mut T, len: usize, alpha #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] { use super::super::simd::activations; + use crate::dtype::DType; match T::DTYPE { DType::F32 => { @@ -208,6 +263,26 @@ pub unsafe fn elu_kernel(a: *const T, out: *mut T, len: usize, alpha activations::elu_f64(a as *const f64, out as *mut f64, len, alpha); return; } + #[cfg(feature = "f16")] + DType::F16 => { + activations::elu_f16( + a as *const half::f16, + out as *mut half::f16, + len, + alpha as f32, + ); + return; + } + #[cfg(feature = "f16")] + DType::BF16 => { + activations::elu_bf16( + a as *const half::bf16, + out as *mut half::bf16, + len, + alpha as f32, + ); + return; + } _ => {} } } diff --git a/src/runtime/cpu/kernels/unary/fused_activations.rs b/src/runtime/cpu/kernels/unary/fused_activations.rs new file mode 100644 index 00000000..cb4ce9dd --- /dev/null +++ b/src/runtime/cpu/kernels/unary/fused_activations.rs @@ -0,0 +1,259 @@ +//! Fused activation-multiplication kernels +//! +//! Each function computes `activation(a) * b` element-wise with automatic SIMD dispatch. +//! Fusing saves one full memory pass compared to separate activation + multiply. + +use crate::dtype::Element; + +/// Fused SiLU-Mul: `silu(a) * b = (a / (1 + exp(-a))) * b` +/// +/// # Safety +/// - `a`, `b`, and `out` must be valid pointers to `len` elements +#[inline] +pub unsafe fn silu_mul_kernel(a: *const T, b: *const T, out: *mut T, len: usize) { + #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] + { + use super::super::simd::fused_activation_mul; + use crate::dtype::DType; + + match T::DTYPE { + DType::F32 => { + fused_activation_mul::silu_mul_f32( + a as *const f32, + b as *const f32, + out as *mut f32, + len, + ); + return; + } + DType::F64 => { + fused_activation_mul::silu_mul_f64( + a as *const f64, + b as *const f64, + out as *mut f64, + len, + ); + return; + } + #[cfg(feature = "f16")] + DType::F16 => { + fused_activation_mul::silu_mul_f16( + a as *const half::f16, + b as *const half::f16, + out as *mut half::f16, + len, + ); + return; + } + #[cfg(feature = "f16")] + DType::BF16 => { + fused_activation_mul::silu_mul_bf16( + a as *const half::bf16, + b as *const half::bf16, + out as *mut half::bf16, + len, + ); + return; + } + _ => {} + } + } + + fused_scalar(a, b, out, len, |x| x / (1.0 + (-x).exp())); +} + +/// Fused GELU-Mul: `gelu(a) * b` +/// +/// # Safety +/// - `a`, `b`, and `out` must be valid pointers to `len` elements +#[inline] +pub unsafe fn gelu_mul_kernel(a: *const T, b: *const T, out: *mut T, len: usize) { + #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] + { + use super::super::simd::fused_activation_mul; + use crate::dtype::DType; + + match T::DTYPE { + DType::F32 => { + fused_activation_mul::gelu_mul_f32( + a as *const f32, + b as *const f32, + out as *mut f32, + len, + ); + return; + } + DType::F64 => { + fused_activation_mul::gelu_mul_f64( + a as *const f64, + b as *const f64, + out as *mut f64, + len, + ); + return; + } + #[cfg(feature = "f16")] + DType::F16 => { + fused_activation_mul::gelu_mul_f16( + a as *const half::f16, + b as *const half::f16, + out as *mut half::f16, + len, + ); + return; + } + #[cfg(feature = "f16")] + DType::BF16 => { + fused_activation_mul::gelu_mul_bf16( + a as *const half::bf16, + b as *const half::bf16, + out as *mut half::bf16, + len, + ); + return; + } + _ => {} + } + } + + const SQRT_2_OVER_PI: f64 = 0.7978845608028654; + const TANH_COEF: f64 = 0.044715; + fused_scalar(a, b, out, len, |x| { + let inner = SQRT_2_OVER_PI * (x + TANH_COEF * x * x * x); + 0.5 * x * (1.0 + inner.tanh()) + }); +} + +/// Fused ReLU-Mul: `relu(a) * b = max(0, a) * b` +/// +/// # Safety +/// - `a`, `b`, and `out` must be valid pointers to `len` elements +#[inline] +pub unsafe fn relu_mul_kernel(a: *const T, b: *const T, out: *mut T, len: usize) { + #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] + { + use super::super::simd::fused_activation_mul; + use crate::dtype::DType; + + match T::DTYPE { + DType::F32 => { + fused_activation_mul::relu_mul_f32( + a as *const f32, + b as *const f32, + out as *mut f32, + len, + ); + return; + } + DType::F64 => { + fused_activation_mul::relu_mul_f64( + a as *const f64, + b as *const f64, + out as *mut f64, + len, + ); + return; + } + #[cfg(feature = "f16")] + DType::F16 => { + fused_activation_mul::relu_mul_f16( + a as *const half::f16, + b as *const half::f16, + out as *mut half::f16, + len, + ); + return; + } + #[cfg(feature = "f16")] + DType::BF16 => { + fused_activation_mul::relu_mul_bf16( + a as *const half::bf16, + b as *const half::bf16, + out as *mut half::bf16, + len, + ); + return; + } + _ => {} + } + } + + fused_scalar(a, b, out, len, |x| if x > 0.0 { x } else { 0.0 }); +} + +/// Fused Sigmoid-Mul: `sigmoid(a) * b = (1 / (1 + exp(-a))) * b` +/// +/// # Safety +/// - `a`, `b`, and `out` must be valid pointers to `len` elements +#[inline] +pub unsafe fn sigmoid_mul_kernel(a: *const T, b: *const T, out: *mut T, len: usize) { + #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] + { + use super::super::simd::fused_activation_mul; + use crate::dtype::DType; + + match T::DTYPE { + DType::F32 => { + fused_activation_mul::sigmoid_mul_f32( + a as *const f32, + b as *const f32, + out as *mut f32, + len, + ); + return; + } + DType::F64 => { + fused_activation_mul::sigmoid_mul_f64( + a as *const f64, + b as *const f64, + out as *mut f64, + len, + ); + return; + } + #[cfg(feature = "f16")] + DType::F16 => { + fused_activation_mul::sigmoid_mul_f16( + a as *const half::f16, + b as *const half::f16, + out as *mut half::f16, + len, + ); + return; + } + #[cfg(feature = "f16")] + DType::BF16 => { + fused_activation_mul::sigmoid_mul_bf16( + a as *const half::bf16, + b as *const half::bf16, + out as *mut half::bf16, + len, + ); + return; + } + _ => {} + } + } + + fused_scalar(a, b, out, len, |x| 1.0 / (1.0 + (-x).exp())); +} + +/// Generic scalar fallback for fused activation-mul: `activation(a[i]) * b[i]` +#[inline] +unsafe fn fused_scalar f64>( + a: *const T, + b: *const T, + out: *mut T, + len: usize, + activation: F, +) { + let a_slice = std::slice::from_raw_parts(a, len); + let b_slice = std::slice::from_raw_parts(b, len); + let out_slice = std::slice::from_raw_parts_mut(out, len); + + for i in 0..len { + let x = a_slice[i].to_f64(); + let y = b_slice[i].to_f64(); + out_slice[i] = T::from_f64(activation(x) * y); + } +} diff --git a/src/runtime/cpu/kernels/unary/mod.rs b/src/runtime/cpu/kernels/unary/mod.rs index 710beeb9..94ee0894 100644 --- a/src/runtime/cpu/kernels/unary/mod.rs +++ b/src/runtime/cpu/kernels/unary/mod.rs @@ -5,9 +5,13 @@ pub mod activations; mod complex; +pub mod fused_activations; pub mod scalar; pub use activations::{elu_kernel, gelu_kernel, leaky_relu_kernel, sigmoid_kernel, silu_kernel}; +pub use fused_activations::{ + gelu_mul_kernel, relu_mul_kernel, sigmoid_mul_kernel, silu_mul_kernel, +}; pub use scalar::{relu_scalar_f32, relu_scalar_f64, unary_scalar_f32, unary_scalar_f64}; use crate::dtype::{DType, Element}; @@ -50,6 +54,16 @@ pub unsafe fn unary_op_kernel(op: UnaryOp, a: *const T, out: *mut T, unary::unary_f64(op, a as *const f64, out as *mut f64, len); return; } + #[cfg(feature = "f16")] + DType::F16 => { + unary::unary_f16(op, a as *const half::f16, out as *mut half::f16, len); + return; + } + #[cfg(feature = "f16")] + DType::BF16 => { + unary::unary_bf16(op, a as *const half::bf16, out as *mut half::bf16, len); + return; + } _ => {} } } @@ -279,6 +293,16 @@ pub unsafe fn relu_kernel(a: *const T, out: *mut T, len: usize) { unary::relu_f64(a as *const f64, out as *mut f64, len); return; } + #[cfg(feature = "f16")] + DType::F16 => { + unary::relu_f16(a as *const half::f16, out as *mut half::f16, len); + return; + } + #[cfg(feature = "f16")] + DType::BF16 => { + unary::relu_bf16(a as *const half::bf16, out as *mut half::bf16, len); + return; + } _ => {} } } @@ -370,6 +394,28 @@ pub unsafe fn clamp_kernel( clamp::clamp_f64(a as *const f64, out as *mut f64, len, min_val, max_val); return; } + #[cfg(feature = "f16")] + DType::F16 => { + clamp::clamp_f16( + a as *const half::f16, + out as *mut half::f16, + len, + min_val as f32, + max_val as f32, + ); + return; + } + #[cfg(feature = "f16")] + DType::BF16 => { + clamp::clamp_bf16( + a as *const half::bf16, + out as *mut half::bf16, + len, + min_val as f32, + max_val as f32, + ); + return; + } _ => {} } } diff --git a/src/runtime/cpu/kernels/where_select.rs b/src/runtime/cpu/kernels/where_select.rs index fb053b54..02b99445 100644 --- a/src/runtime/cpu/kernels/where_select.rs +++ b/src/runtime/cpu/kernels/where_select.rs @@ -6,7 +6,7 @@ //! - `where_strided_kernel` - U8 condition with broadcasting //! - `where_strided_kernel_generic` - Generic condition with broadcasting -use crate::dtype::{DType, Element}; +use crate::dtype::Element; /// Where (conditional select): out[i] = cond[i] ? x[i] : y[i] /// @@ -31,6 +31,7 @@ pub unsafe fn where_kernel( #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] { use super::simd::where_select; + use crate::dtype::DType; match T::DTYPE { DType::F32 => { @@ -53,6 +54,28 @@ pub unsafe fn where_kernel( ); return; } + #[cfg(feature = "f16")] + DType::F16 => { + where_select::where_f16( + cond, + x as *const half::f16, + y as *const half::f16, + out as *mut half::f16, + len, + ); + return; + } + #[cfg(feature = "f16")] + DType::BF16 => { + where_select::where_bf16( + cond, + x as *const half::bf16, + y as *const half::bf16, + out as *mut half::bf16, + len, + ); + return; + } _ => {} // Fall through to scalar } } diff --git a/src/runtime/cpu/ops.rs b/src/runtime/cpu/ops.rs index 03006ada..ffb69247 100644 --- a/src/runtime/cpu/ops.rs +++ b/src/runtime/cpu/ops.rs @@ -92,3 +92,14 @@ mod semiring_matmul; #[path = "../../ops/cpu/einsum.rs"] mod einsum; + +#[path = "../../ops/cpu/gemm_epilogue.rs"] +mod gemm_epilogue; + +#[cfg(feature = "fp8")] +#[path = "../../ops/cpu/fp8_matmul.rs"] +mod fp8_matmul; + +#[cfg(feature = "sparse")] +#[path = "../../ops/cpu/sparse_24.rs"] +mod sparse_24; diff --git a/src/runtime/cpu/runtime.rs b/src/runtime/cpu/runtime.rs index 840249be..4244fe59 100644 --- a/src/runtime/cpu/runtime.rs +++ b/src/runtime/cpu/runtime.rs @@ -2,7 +2,7 @@ use super::client::{CpuAllocator, CpuClient}; use super::device::CpuDevice; -use crate::runtime::Runtime; +use crate::runtime::{NoOpGraph, Runtime}; use std::alloc::{Layout as AllocLayout, alloc, dealloc}; /// CPU compute runtime @@ -16,12 +16,23 @@ impl Runtime for CpuRuntime { type Device = CpuDevice; type Client = CpuClient; type Allocator = CpuAllocator; - type RawHandle = (); // CPU has no special handle needed + type Graph = NoOpGraph; + type RawHandle = (); + type DType = crate::dtype::DType; fn name() -> &'static str { "cpu" } + fn capture_graph(client: &Self::Client, f: F) -> crate::error::Result<(Self::Graph, T)> + where + F: FnOnce(&Self::Client) -> crate::error::Result, + { + // CPU: execute eagerly, return NoOpGraph + let result = f(client)?; + Ok((NoOpGraph, result)) + } + fn allocate(size_bytes: usize, _device: &Self::Device) -> crate::error::Result { if size_bytes == 0 { return Ok(0); @@ -165,3 +176,48 @@ impl Runtime for CpuRuntime { &() } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::runtime::Graph; + + #[test] + fn test_cpu_supports_graph_capture() { + assert!(!CpuRuntime::supports_graph_capture()); + } + + #[test] + fn test_cpu_capture_graph_executes_eagerly() { + let device = CpuRuntime::default_device(); + let client = CpuRuntime::default_client(&device); + + let mut executed = false; + let (graph, result) = CpuRuntime::capture_graph(&client, |_c| { + executed = true; + Ok(42) + }) + .unwrap(); + + // Closure executed eagerly + assert!(executed); + assert_eq!(result, 42); + + // Graph is NoOp + assert!(!graph.is_replay_capable()); + assert!(graph.launch().is_ok()); + } + + #[test] + fn test_cpu_capture_graph_propagates_error() { + let device = CpuRuntime::default_device(); + let client = CpuRuntime::default_client(&device); + + let result: crate::error::Result<(NoOpGraph, ())> = + CpuRuntime::capture_graph(&client, |_c| { + Err(crate::error::Error::Internal("test error".into())) + }); + + assert!(result.is_err()); + } +} diff --git a/src/runtime/cpu/sort.rs b/src/runtime/cpu/sort.rs index e68f8993..9d63290e 100644 --- a/src/runtime/cpu/sort.rs +++ b/src/runtime/cpu/sort.rs @@ -28,8 +28,8 @@ pub fn sort_impl( let a_contig = ensure_contiguous(a); let out = Tensor::::empty(shape, dtype, &client.device); - let a_ptr = a_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a_contig.ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -69,9 +69,9 @@ pub fn sort_with_indices_impl( let out_values = Tensor::::empty(shape, dtype, &client.device); let out_indices = Tensor::::empty(shape, DType::I64, &client.device); - let a_ptr = a_contig.storage().ptr(); - let values_ptr = out_values.storage().ptr(); - let indices_ptr = out_indices.storage().ptr(); + let a_ptr = a_contig.ptr(); + let values_ptr = out_values.ptr(); + let indices_ptr = out_indices.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -114,8 +114,8 @@ pub fn argsort_impl( let a_contig = ensure_contiguous(a); let out = Tensor::::empty(shape, DType::I64, &client.device); - let a_ptr = a_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let a_ptr = a_contig.ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -186,9 +186,9 @@ pub fn topk_impl( let out_values = Tensor::::empty(&out_shape, dtype, &client.device); let out_indices = Tensor::::empty(&out_shape, DType::I64, &client.device); - let a_ptr = a_contig.storage().ptr(); - let values_ptr = out_values.storage().ptr(); - let indices_ptr = out_indices.storage().ptr(); + let a_ptr = a_contig.ptr(); + let values_ptr = out_values.ptr(); + let indices_ptr = out_indices.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -227,7 +227,7 @@ pub fn unique_impl( // Sort first let sorted_tensor = sort_impl(client, &a_contig, 0, false)?; - let sorted_ptr = sorted_tensor.storage().ptr(); + let sorted_ptr = sorted_tensor.ptr(); // Count unique let unique_count = dispatch_dtype!(dtype, T => { @@ -236,7 +236,7 @@ pub fn unique_impl( // Extract unique let out = Tensor::::empty(&[unique_count], dtype, &client.device); - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -275,7 +275,7 @@ pub fn unique_with_counts_impl( // Gather sorted data let sorted_tensor = client.gather(&a_contig, 0, &sort_indices)?; - let sorted_ptr = sorted_tensor.storage().ptr(); + let sorted_ptr = sorted_tensor.ptr(); // Count unique let unique_count = dispatch_dtype!(dtype, T => { @@ -287,11 +287,11 @@ pub fn unique_with_counts_impl( let out_inverse = Tensor::::empty(&[numel], DType::I64, &client.device); let out_counts = Tensor::::empty(&[unique_count], DType::I64, &client.device); - let a_ptr = a_contig.storage().ptr(); - let sort_indices_ptr = sort_indices.storage().ptr(); - let unique_ptr = out_unique.storage().ptr(); - let inverse_ptr = out_inverse.storage().ptr(); - let counts_ptr = out_counts.storage().ptr(); + let a_ptr = a_contig.ptr(); + let sort_indices_ptr = sort_indices.ptr(); + let unique_ptr = out_unique.ptr(); + let inverse_ptr = out_inverse.ptr(); + let counts_ptr = out_counts.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -327,7 +327,7 @@ pub fn nonzero_impl(client: &CpuClient, a: &Tensor) -> Result { @@ -352,7 +352,7 @@ pub fn nonzero_impl(client: &CpuClient, a: &Tensor) -> Result::empty(&[nnz], DType::I64, &client.device); - let flat_ptr = flat_indices.storage().ptr() as *mut i64; + let flat_ptr = flat_indices.ptr() as *mut i64; dispatch_dtype!(dtype, T => { unsafe { kernels::nonzero_flat_kernel::(a_ptr as *const T, flat_ptr, numel); } @@ -360,7 +360,7 @@ pub fn nonzero_impl(client: &CpuClient, a: &Tensor) -> Result::empty(&[nnz, ndim], DType::I64, &client.device); - let out_ptr = out.storage().ptr() as *mut i64; + let out_ptr = out.ptr() as *mut i64; unsafe { kernels::flat_to_multi_index_kernel(flat_ptr, out_ptr, nnz, shape); @@ -406,9 +406,9 @@ pub fn searchsorted_impl( let values_contig = ensure_contiguous(values); let out = Tensor::::empty(values.shape(), DType::I64, &client.device); - let seq_ptr = seq_contig.storage().ptr(); - let values_ptr = values_contig.storage().ptr(); - let out_ptr = out.storage().ptr() as *mut i64; + let seq_ptr = seq_contig.ptr(); + let values_ptr = values_contig.ptr(); + let out_ptr = out.ptr() as *mut i64; dispatch_dtype!(dtype, T => { unsafe { diff --git a/src/runtime/cpu/sparse/merge.rs b/src/runtime/cpu/sparse/merge.rs index cebd3b6e..3cc65ce7 100644 --- a/src/runtime/cpu/sparse/merge.rs +++ b/src/runtime/cpu/sparse/merge.rs @@ -10,7 +10,7 @@ use crate::tensor::Tensor; // Re-export zero_tolerance from shared utilities module // See runtime::sparse_utils::zero_tolerance for full documentation -pub(crate) use crate::runtime::sparse_utils::zero_tolerance; +pub(crate) use crate::runtime::common::sparse_utils::zero_tolerance; // ============================================================================= // Merge Strategy and Operation Semantics diff --git a/src/runtime/cpu/special/helpers/simd.rs b/src/runtime/cpu/special/helpers/simd.rs index df4fb5f0..1cecf9d7 100644 --- a/src/runtime/cpu/special/helpers/simd.rs +++ b/src/runtime/cpu/special/helpers/simd.rs @@ -25,7 +25,7 @@ use crate::runtime::cpu::kernels::simd::special as simd_special; /// 2. Dispatches to architecture-specific SIMD kernel if available /// 3. Falls back to scalar implementation otherwise macro_rules! impl_simd_special_fn { - ($fn_name:ident, $simd_f32:ident, $simd_f64:ident, $scalar_fn:path) => { + ($fn_name:ident, $simd_f32:ident, $simd_f64:ident, $simd_f16:ident, $simd_bf16:ident, $scalar_fn:path) => { pub fn $fn_name(x: &Tensor, device: &CpuDevice) -> Result> { // SIMD requires contiguous memory layout if !x.is_contiguous() { @@ -38,7 +38,7 @@ macro_rules! impl_simd_special_fn { { let len = x.numel(); let mut result = vec![0.0f32; len]; - let input_ptr = x.storage().ptr() as *const f32; + let input_ptr = x.ptr() as *const f32; unsafe { simd_special::$simd_f32(input_ptr, result.as_mut_ptr(), len); } @@ -53,7 +53,7 @@ macro_rules! impl_simd_special_fn { { let len = x.numel(); let mut result = vec![0.0f64; len]; - let input_ptr = x.storage().ptr() as *const f64; + let input_ptr = x.ptr() as *const f64; unsafe { simd_special::$simd_f64(input_ptr, result.as_mut_ptr(), len); } @@ -63,10 +63,42 @@ macro_rules! impl_simd_special_fn { #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] apply_unary(x, device, $scalar_fn) } - // F16/BF16/FP8: Convert to F32, compute, convert back - DType::F16 | DType::BF16 | DType::FP8E4M3 | DType::FP8E5M2 => { + #[cfg(feature = "f16")] + DType::F16 => { + #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] + { + let len = x.numel(); + let mut result = vec![half::f16::ZERO; len]; + let input_ptr = x.ptr() as *const half::f16; + unsafe { + simd_special::$simd_f16(input_ptr, result.as_mut_ptr(), len); + } + return Ok(Tensor::from_slice(&result, x.shape(), device)); + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + apply_unary(x, device, $scalar_fn) + } + #[cfg(feature = "f16")] + DType::BF16 => { + #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] + { + let len = x.numel(); + let mut result = vec![half::bf16::ZERO; len]; + let input_ptr = x.ptr() as *const half::bf16; + unsafe { + simd_special::$simd_bf16(input_ptr, result.as_mut_ptr(), len); + } + return Ok(Tensor::from_slice(&result, x.shape(), device)); + } + + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] apply_unary(x, device, $scalar_fn) } + // FP8 and others: scalar fallback + #[cfg(not(feature = "f16"))] + DType::F16 | DType::BF16 => apply_unary(x, device, $scalar_fn), + DType::FP8E4M3 | DType::FP8E5M2 => apply_unary(x, device, $scalar_fn), _ => unreachable!("dtype validated by caller"), } } @@ -78,6 +110,8 @@ impl_simd_special_fn!( apply_erf, erf_f32, erf_f64, + erf_f16, + erf_bf16, crate::algorithm::special::scalar::erf_scalar ); @@ -85,6 +119,8 @@ impl_simd_special_fn!( apply_erfc, erfc_f32, erfc_f64, + erfc_f16, + erfc_bf16, crate::algorithm::special::scalar::erfc_scalar ); @@ -92,6 +128,8 @@ impl_simd_special_fn!( apply_bessel_j0, bessel_j0_f32, bessel_j0_f64, + bessel_j0_f16, + bessel_j0_bf16, crate::algorithm::special::scalar::bessel_j0_scalar ); @@ -99,6 +137,8 @@ impl_simd_special_fn!( apply_bessel_j1, bessel_j1_f32, bessel_j1_f64, + bessel_j1_f16, + bessel_j1_bf16, crate::algorithm::special::scalar::bessel_j1_scalar ); @@ -106,6 +146,8 @@ impl_simd_special_fn!( apply_bessel_i0, bessel_i0_f32, bessel_i0_f64, + bessel_i0_f16, + bessel_i0_bf16, crate::algorithm::special::scalar::bessel_i0_scalar ); @@ -113,6 +155,8 @@ impl_simd_special_fn!( apply_bessel_i1, bessel_i1_f32, bessel_i1_f64, + bessel_i1_f16, + bessel_i1_bf16, crate::algorithm::special::scalar::bessel_i1_scalar ); @@ -120,6 +164,8 @@ impl_simd_special_fn!( apply_gamma, gamma_f32, gamma_f64, + gamma_f16, + gamma_bf16, crate::algorithm::special::scalar::gamma_scalar ); @@ -127,6 +173,8 @@ impl_simd_special_fn!( apply_lgamma, lgamma_f32, lgamma_f64, + lgamma_f16, + lgamma_bf16, crate::algorithm::special::scalar::lgamma_scalar ); @@ -134,5 +182,7 @@ impl_simd_special_fn!( apply_digamma, digamma_f32, digamma_f64, + digamma_f16, + digamma_bf16, crate::algorithm::special::scalar::digamma_scalar ); diff --git a/src/runtime/cpu/statistics/histogram.rs b/src/runtime/cpu/statistics/histogram.rs index deb07469..25f4c4cc 100644 --- a/src/runtime/cpu/statistics/histogram.rs +++ b/src/runtime/cpu/statistics/histogram.rs @@ -51,7 +51,7 @@ pub fn histogram_impl( // Flatten input let flat = a.reshape(&[numel])?; let flat_contig = ensure_contiguous(&flat); - let flat_ptr = flat_contig.storage().ptr(); + let flat_ptr = flat_contig.ptr(); // Determine range let (min_val, max_val) = if let Some((min, max)) = range { @@ -78,7 +78,7 @@ pub fn histogram_impl( // Create histogram counts tensor let hist = Tensor::::zeros(&[bins], DType::I64, &client.device); - let hist_ptr = hist.storage().ptr() as *mut i64; + let hist_ptr = hist.ptr() as *mut i64; // Compute histogram using optimized kernel dispatch_dtype!(dtype, T => { diff --git a/src/runtime/cpu/statistics/mod.rs b/src/runtime/cpu/statistics/mod.rs index 016510ff..9f5f8754 100644 --- a/src/runtime/cpu/statistics/mod.rs +++ b/src/runtime/cpu/statistics/mod.rs @@ -32,11 +32,11 @@ use super::helpers::dispatch_dtype; use super::{CpuClient, CpuRuntime}; use crate::dtype::{DType, Element}; use crate::error::Result; -use crate::runtime::statistics_common::{self, compute_bin_edges_f64}; +use crate::runtime::common::statistics_common::{self, compute_bin_edges_f64}; use crate::tensor::Tensor; // Re-export Interpolation for submodules -pub(crate) use crate::runtime::statistics_common::Interpolation; +pub(crate) use crate::runtime::common::statistics_common::Interpolation; // ============================================================================ // Optimized CPU Kernels @@ -170,7 +170,7 @@ pub(crate) fn create_bin_edges( // Create tensor and copy data based on dtype let edges = Tensor::::empty(&[bins + 1], dtype, &client.device); - let edges_ptr = edges.storage().ptr(); + let edges_ptr = edges.ptr(); dispatch_dtype!(dtype, T => { unsafe { @@ -187,7 +187,7 @@ pub(crate) fn create_bin_edges( /// Extract scalar f64 value from tensor. pub(crate) fn tensor_to_f64(t: &Tensor) -> Result { let dtype = t.dtype(); - let ptr = t.storage().ptr(); + let ptr = t.ptr(); let val = dispatch_dtype!(dtype, T => { unsafe { (*(ptr as *const T)).to_f64() } diff --git a/src/runtime/cpu/statistics/mode.rs b/src/runtime/cpu/statistics/mode.rs index 4d2def52..7fb44e4a 100644 --- a/src/runtime/cpu/statistics/mode.rs +++ b/src/runtime/cpu/statistics/mode.rs @@ -6,8 +6,8 @@ use super::super::{CpuClient, CpuRuntime}; use crate::dtype::DType; use crate::error::Result; use crate::ops::{TypeConversionOps, compute_reduce_strides, reduce_dim_output_shape}; +use crate::runtime::common::statistics_common::compute_mode_strided; use crate::runtime::normalize_dim; -use crate::runtime::statistics_common::compute_mode_strided; use crate::tensor::Tensor; /// Compute mode (most frequent value) along a dimension. diff --git a/src/runtime/cpu/statistics/moments.rs b/src/runtime/cpu/statistics/moments.rs index 6e1bca3f..c8e3f278 100644 --- a/src/runtime/cpu/statistics/moments.rs +++ b/src/runtime/cpu/statistics/moments.rs @@ -5,7 +5,9 @@ use super::super::{CpuClient, CpuRuntime}; use crate::dtype::Element; use crate::error::Result; use crate::ops::{BinaryOps, ReduceOps, ScalarOps, StatisticalOps}; -use crate::runtime::statistics_common::{DIVISION_EPSILON, compute_kurtosis, compute_skewness}; +use crate::runtime::common::statistics_common::{ + DIVISION_EPSILON, compute_kurtosis, compute_skewness, +}; use crate::tensor::Tensor; /// Compute skewness (third standardized moment) along dimensions. @@ -44,7 +46,7 @@ pub fn skew_impl( if dims.is_empty() { let numel = a.numel(); let a_contig = ensure_contiguous(a); - let a_ptr = a_contig.storage().ptr(); + let a_ptr = a_contig.ptr(); let skewness = dispatch_dtype!(dtype, T => { unsafe { @@ -55,7 +57,7 @@ pub fn skew_impl( let out_shape = if keepdim { vec![1; ndim] } else { vec![] }; let out = Tensor::::empty(&out_shape, dtype, &client.device); - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { *(out_ptr as *mut T) = T::from_f64(skewness); } @@ -125,7 +127,7 @@ pub fn kurtosis_impl( if dims.is_empty() { let numel = a.numel(); let a_contig = ensure_contiguous(a); - let a_ptr = a_contig.storage().ptr(); + let a_ptr = a_contig.ptr(); let kurtosis = dispatch_dtype!(dtype, T => { unsafe { @@ -136,7 +138,7 @@ pub fn kurtosis_impl( let out_shape = if keepdim { vec![1; ndim] } else { vec![] }; let out = Tensor::::empty(&out_shape, dtype, &client.device); - let out_ptr = out.storage().ptr(); + let out_ptr = out.ptr(); dispatch_dtype!(dtype, T => { unsafe { *(out_ptr as *mut T) = T::from_f64(kurtosis); } diff --git a/src/runtime/cpu/statistics/quantile.rs b/src/runtime/cpu/statistics/quantile.rs index 506b5f97..a815467e 100644 --- a/src/runtime/cpu/statistics/quantile.rs +++ b/src/runtime/cpu/statistics/quantile.rs @@ -93,8 +93,8 @@ pub fn quantile_impl( let (outer_size, reduce_size, inner_size) = compute_reduce_strides(shape, dim_idx); let sorted_contig = ensure_contiguous(&sorted); - let sorted_ptr = sorted_contig.storage().ptr(); - let out_ptr = out.storage().ptr(); + let sorted_ptr = sorted_contig.ptr(); + let out_ptr = out.ptr(); // Dispatch to typed kernel dispatch_dtype!(dtype, T => { diff --git a/src/runtime/cuda/cache.rs b/src/runtime/cuda/cache.rs index b42d9c4b..ecbe3e53 100644 --- a/src/runtime/cuda/cache.rs +++ b/src/runtime/cuda/cache.rs @@ -52,33 +52,14 @@ pub(super) fn get_or_create_client(device: &CudaDevice) -> CudaClient { client } -/// Reset the cached client for a device, creating a fresh one. +/// Try to get a cached client for a device. /// -/// This is used to recover from sticky CUDA stream errors (e.g., -/// CUDA_ERROR_MISALIGNED_ADDRESS) that permanently poison a stream. -/// Creates a new client with a fresh context, stream, and cuBLAS handle. -/// -/// Returns the new client, or None if client creation fails. -pub(super) fn reset_client(device: &CudaDevice) -> Option { - let cache = CLIENT_CACHE.get_or_init(|| Mutex::new(HashMap::new())); - let mut cache_guard = lock_client_cache(cache); - - // Remove old client and create a fresh one - cache_guard.remove(&device.index); - - // Also clear any cached modules since they're tied to the old context - if let Some(mod_cache) = super::kernels::loader::module_cache() { - let mut mod_guard = mod_cache.lock().unwrap_or_else(PoisonError::into_inner); - mod_guard.retain(|(dev_idx, _), _| *dev_idx != device.index); - } - - match CudaClient::new(device.clone()) { - Ok(client) => { - cache_guard.insert(device.index, client.clone()); - Some(client) - } - Err(_) => None, - } +/// Returns `None` if no client is cached or if the cache lock is unavailable. +#[inline] +pub(super) fn try_get_cached_client(device_index: usize) -> Option { + let cache = CLIENT_CACHE.get()?; + let guard = lock_client_cache(cache); + guard.get(&device_index).cloned() } /// Try to get the stream from a cached client for a device. diff --git a/src/runtime/cuda/client.rs b/src/runtime/cuda/client.rs index f87286b4..2f054679 100644 --- a/src/runtime/cuda/client.rs +++ b/src/runtime/cuda/client.rs @@ -34,16 +34,6 @@ unsafe fn is_cuda_context_valid() -> bool { result == cudarc::driver::sys::CUresult::CUDA_SUCCESS && !ctx.is_null() } -/// Log a CUDA memory operation failure. -#[cold] -#[inline(never)] -fn log_cuda_memory_error(operation: &str, ptr: u64, result: cudarc::driver::sys::CUresult) { - eprintln!( - "[numr::cuda] {} failed for ptr 0x{:x}: {:?}", - operation, ptr, result - ); -} - // ============================================================================ // CudaClient // ============================================================================ @@ -70,9 +60,12 @@ pub struct CudaClient { /// CUDA context for this device (owns GPU context) pub(crate) context: Arc, - /// Stream on which all kernels launch + /// Stream on which all kernels launch (compute stream) pub(crate) stream: Arc, + /// Dedicated stream for D2H copies (overlaps with compute stream) + pub(crate) copy_stream: Arc, + /// cuBLAS handle for GEMM operations pub(crate) cublas: Arc, @@ -95,34 +88,45 @@ impl std::fmt::Debug for CudaClient { // CudaAllocator // ============================================================================ -/// CUDA allocator that uses stream-ordered allocation. +/// CUDA caching allocator with Rust-side free lists. /// -/// This allocator uses `cuMemAllocAsync` and `cuMemFreeAsync` for efficient -/// stream-ordered memory management. Memory operations are synchronized with -/// kernel execution on the associated stream. +/// Maintains per-size free lists of GPU buffers. On deallocation, buffers are +/// returned to the free list instead of calling `cuMemFreeAsync`. On allocation, +/// the free list is checked first, bypassing the CUDA driver entirely for repeat +/// allocations of the same size. This is critical for inference decode loops where +/// the same buffer sizes are allocated every step. /// -/// # Panics +/// Falls through to `cuMemAllocAsync` for sizes not in the cache. /// -/// The `allocate` method panics if CUDA memory allocation fails, following -/// CUDA best practices where OOM is typically unrecoverable. #[derive(Clone)] pub struct CudaAllocator { stream: Arc, + /// Free list: size_bytes → Vec + cache: Arc>>>, + /// When frozen, bypass the cache entirely. Used during CUDA graph capture + /// so that `cuMemAllocAsync`/`cuMemFreeAsync` create proper graph nodes. + frozen: Arc, } impl Allocator for CudaAllocator { - /// Allocate GPU memory using stream-ordered allocation. - /// - /// If the first allocation attempt fails, synchronizes the stream to flush - /// pending async frees, then retries once. This handles the common case where - /// `cuMemFreeAsync` calls haven't completed yet. - /// - /// Returns `Err(OutOfMemory)` if CUDA memory allocation fails even after retry. fn allocate(&self, size_bytes: usize) -> crate::error::Result { if size_bytes == 0 { return Ok(0); } + // When frozen (graph capture), bypass cache — go straight to driver + // so cuMemAllocAsync creates a proper graph allocation node. + if !self.frozen.load(std::sync::atomic::Ordering::Relaxed) { + // Check free list first + let mut cache = self.cache.lock().unwrap(); + if let Some(ptrs) = cache.get_mut(&size_bytes) + && let Some(ptr) = ptrs.pop() + { + return Ok(ptr); + } + } + + // Allocate from CUDA driver (stream-ordered) unsafe { let mut ptr: u64 = 0; let result = @@ -132,8 +136,7 @@ impl Allocator for CudaAllocator { return Ok(ptr); } - // First attempt failed - synchronize stream to flush pending async frees, - // then retry. + // Sync stream to flush pending async frees, then retry let _ = self.stream.synchronize(); let result = @@ -147,40 +150,53 @@ impl Allocator for CudaAllocator { } } - fn deallocate(&self, ptr: u64, _size_bytes: usize) { + fn deallocate(&self, ptr: u64, size_bytes: usize) { if ptr == 0 { return; } - unsafe { - // Check if CUDA context is still valid before attempting free - if !is_cuda_context_valid() { - // Context is gone - memory will be reclaimed by driver - return; - } - - let result = cudarc::driver::sys::cuMemFreeAsync(ptr, self.stream.cu_stream()); - - // Log failures but don't panic - deallocation errors are typically benign - if result != cudarc::driver::sys::CUresult::CUDA_SUCCESS - && result != cudarc::driver::sys::CUresult::CUDA_ERROR_ILLEGAL_ADDRESS - { - log_cuda_memory_error("cuMemFreeAsync", ptr, result); + // When frozen (graph capture), bypass cache — call cuMemFreeAsync + // so the driver creates a proper graph free node. + if self.frozen.load(std::sync::atomic::Ordering::Relaxed) { + unsafe { + let _ = cudarc::driver::sys::cuMemFreeAsync(ptr, self.stream.cu_stream()); } + return; } + + // Return to free list for reuse + let mut cache = self.cache.lock().unwrap(); + cache.entry(size_bytes).or_default().push(ptr); } fn is_frozen(&self) -> bool { - false // CUDA allocator doesn't support freeze + self.frozen.load(std::sync::atomic::Ordering::Relaxed) } fn freeze(&self) -> bool { - // No-op for CUDA - always succeeds + self.frozen + .store(true, std::sync::atomic::Ordering::Relaxed); true } fn unfreeze(&self) { - // No-op for CUDA + self.frozen + .store(false, std::sync::atomic::Ordering::Relaxed); + } + + fn reset(&self) -> crate::error::Result<()> { + // Flush all cached buffers back to CUDA + let mut cache = self.cache.lock().unwrap(); + for (_size, ptrs) in cache.drain() { + for ptr in ptrs { + unsafe { + if is_cuda_context_valid() { + let _ = cudarc::driver::sys::cuMemFreeAsync(ptr, self.stream.cu_stream()); + } + } + } + } + Ok(()) } } @@ -213,17 +229,41 @@ impl CudaClient { CudaError::ContextError(format!("Failed to bind CUDA context to thread: {:?}", e)) })?; - // Create a stream in this context + // Create compute stream let stream = context.new_stream().map_err(|e| { CudaError::ContextError(format!("Failed to create CUDA stream: {:?}", e)) })?; + // Create dedicated copy stream for overlapped D2H transfers + let copy_stream = context.new_stream().map_err(|e| { + CudaError::ContextError(format!("Failed to create CUDA copy stream: {:?}", e)) + })?; + // Initialize cuBLAS handle for GEMM operations let cublas = CudaBlas::new(stream.clone()) .map_err(|e| CudaError::CublasError(format!("Failed to initialize cuBLAS: {:?}", e)))?; + // Configure the default memory pool to cache freed allocations instead + // of returning them to the OS. This dramatically reduces allocation overhead + // for repetitive workloads (e.g. inference decode loops). + unsafe { + let mut pool: cudarc::driver::sys::CUmemoryPool = std::ptr::null_mut(); + let result = + cudarc::driver::sys::cuDeviceGetDefaultMemPool(&mut pool, device.index as i32); + if result == cudarc::driver::sys::CUresult::CUDA_SUCCESS && !pool.is_null() { + let threshold: u64 = u64::MAX; // Keep all freed memory cached + let _ = cudarc::driver::sys::cuMemPoolSetAttribute( + pool, + cudarc::driver::sys::CUmemPool_attribute::CU_MEMPOOL_ATTR_RELEASE_THRESHOLD, + &threshold as *const u64 as *mut std::ffi::c_void, + ); + } + } + let allocator = CudaAllocator { stream: stream.clone(), + cache: Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())), + frozen: Arc::new(std::sync::atomic::AtomicBool::new(false)), }; let raw_handle = CudaRawHandle { @@ -235,6 +275,7 @@ impl CudaClient { device, context, stream, + copy_stream, cublas: Arc::new(cublas), allocator, raw_handle, @@ -261,11 +302,85 @@ impl CudaClient { &self.context } + /// Get reference to the copy stream (for overlapped D2H transfers). + #[inline] + pub fn copy_stream(&self) -> &CudaStream { + &self.copy_stream + } + /// Get reference to the cuBLAS handle. #[inline] pub fn cublas(&self) -> &CudaBlas { &self.cublas } + + /// Record an event on the compute stream. + /// + /// Returns an event handle that can be passed to `copy_stream_wait_event`. + pub fn record_event_on_compute(&self) -> Result { + use cudarc::driver::sys::{CUevent_flags, cuEventCreate, cuEventRecord}; + unsafe { + let mut event = std::ptr::null_mut(); + let r = cuEventCreate(&mut event, CUevent_flags::CU_EVENT_DISABLE_TIMING as u32); + if r != cudarc::driver::sys::CUresult::CUDA_SUCCESS { + return Err(CudaError::ContextError(format!( + "cuEventCreate failed: {:?}", + r + ))); + } + let r = cuEventRecord(event, self.stream.cu_stream()); + if r != cudarc::driver::sys::CUresult::CUDA_SUCCESS { + cudarc::driver::sys::cuEventDestroy_v2(event); + return Err(CudaError::ContextError(format!( + "cuEventRecord failed: {:?}", + r + ))); + } + Ok(event as u64) + } + } + + /// Make the copy stream wait for an event recorded on the compute stream. + pub fn copy_stream_wait_event(&self, event: u64) -> Result<(), CudaError> { + use cudarc::driver::sys::cuStreamWaitEvent; + unsafe { + let r = cuStreamWaitEvent( + self.copy_stream.cu_stream(), + event as cudarc::driver::sys::CUevent, + 0, + ); + if r != cudarc::driver::sys::CUresult::CUDA_SUCCESS { + return Err(CudaError::ContextError(format!( + "cuStreamWaitEvent failed: {:?}", + r + ))); + } + } + Ok(()) + } + + /// Pre-load CUDA PTX modules to avoid JIT compilation latency on first use. + /// + /// Call this during warmup with the list of numr kernel module names + /// that will be used during inference. + pub fn preload_modules(&self, module_names: &[&'static str]) -> crate::error::Result<()> { + crate::runtime::cuda::kernels::preload_modules( + &self.context, + self.device.index, + module_names, + ) + } + + /// Destroy a CUDA event handle returned by `record_event_on_compute`. + /// + /// Must be called after the copy stream has finished using the event + /// (i.e., after `copy_stream.synchronize()`). Passing an already-destroyed + /// or invalid handle is safe (CUDA ignores it). + pub fn destroy_event(&self, event: u64) { + unsafe { + cudarc::driver::sys::cuEventDestroy_v2(event as cudarc::driver::sys::CUevent); + } + } } impl RuntimeClient for CudaClient { @@ -282,6 +397,10 @@ impl RuntimeClient for CudaClient { fn allocator(&self) -> &CudaAllocator { &self.allocator } + + fn compute_stream_handle(&self) -> Option { + Some(self.stream.cu_stream() as u64) + } } // ============================================================================ diff --git a/src/runtime/cuda/communicator.rs b/src/runtime/cuda/communicator.rs new file mode 100644 index 00000000..0838da0b --- /dev/null +++ b/src/runtime/cuda/communicator.rs @@ -0,0 +1,645 @@ +//! NCCL-backed collective communication for multi-GPU +//! +//! Wraps cudarc's `nccl::Comm` and implements numr's `Communicator` trait. +//! Uses raw `nccl::result` FFI to handle runtime `DType` dispatch (cudarc's +//! safe API requires compile-time `NcclType` generics). + +use std::sync::Arc; + +use cudarc::driver::CudaStream; +use cudarc::nccl::{self, result as nccl_result, sys as nccl_sys}; + +use crate::dtype::DType; +use crate::error::{Error, Result}; +use crate::runtime::communicator::{Communicator, ReduceOp, StreamSyncOps}; + +/// NCCL communicator wrapping a single `cudarc::nccl::Comm` (one per rank). +pub struct NcclCommunicator { + comm: nccl::Comm, +} + +// SAFETY: NCCL comms are thread-safe for submission from the owning thread. +// The Comm internally holds an Arc which is Send+Sync. +unsafe impl Send for NcclCommunicator {} +unsafe impl Sync for NcclCommunicator {} + +impl NcclCommunicator { + /// Wrap an existing cudarc NCCL communicator. + pub fn new(comm: nccl::Comm) -> Self { + Self { comm } + } + + /// Create communicators for all given streams (single-process, multi-GPU). + /// + /// Returns one `NcclCommunicator` per stream, with ranks assigned in order. + pub fn from_streams(streams: Vec>) -> Result> { + let comms = nccl::Comm::from_devices(streams) + .map_err(|e| Error::Backend(format!("NCCL init failed: {e:?}")))?; + Ok(comms.into_iter().map(|c| Self { comm: c }).collect()) + } + + /// Access the underlying cudarc `Comm`. + pub fn inner(&self) -> &nccl::Comm { + &self.comm + } + + /// Get the raw NCCL comm handle for FFI calls. + fn raw_comm(&self) -> nccl_sys::ncclComm_t { + // Access the private field via the Comm's public API indirectly. + // We need the raw pointer. Comm stores it as `comm: sys::ncclComm_t`. + // Unfortunately cudarc doesn't expose this directly, so we use + // a transmute-based approach to read the first field. + // + // SAFETY: Comm's first field is `comm: sys::ncclComm_t` (a raw pointer). + // This is verified by cudarc 0.18's source code. + unsafe { std::ptr::read((&self.comm as *const nccl::Comm).cast::()) } + } + + /// Get the raw CUDA stream handle for FFI calls. + fn raw_stream(&self) -> nccl_sys::cudaStream_t { + self.comm.stream().cu_stream() as nccl_sys::cudaStream_t + } +} + +/// Map numr `DType` to NCCL data type. +fn dtype_to_nccl(dtype: DType) -> Result { + match dtype { + DType::F32 => Ok(nccl_sys::ncclDataType_t::ncclFloat32), + DType::F64 => Ok(nccl_sys::ncclDataType_t::ncclFloat64), + DType::F16 => Ok(nccl_sys::ncclDataType_t::ncclFloat16), + DType::BF16 => Ok(nccl_sys::ncclDataType_t::ncclBfloat16), + DType::FP8E4M3 => Ok(nccl_sys::ncclDataType_t::ncclFloat8e4m3), + DType::FP8E5M2 => Ok(nccl_sys::ncclDataType_t::ncclFloat8e5m2), + DType::I32 => Ok(nccl_sys::ncclDataType_t::ncclInt32), + DType::I64 => Ok(nccl_sys::ncclDataType_t::ncclInt64), + DType::I8 => Ok(nccl_sys::ncclDataType_t::ncclInt8), + DType::U32 => Ok(nccl_sys::ncclDataType_t::ncclUint32), + DType::U8 => Ok(nccl_sys::ncclDataType_t::ncclUint8), + _ => Err(Error::UnsupportedDType { + dtype, + op: "nccl_communication", + }), + } +} + +/// Map numr `ReduceOp` to NCCL reduction operation. +fn reduce_op_to_nccl(op: ReduceOp) -> nccl_sys::ncclRedOp_t { + match op { + ReduceOp::Sum => nccl_sys::ncclRedOp_t::ncclSum, + ReduceOp::Prod => nccl_sys::ncclRedOp_t::ncclProd, + ReduceOp::Min => nccl_sys::ncclRedOp_t::ncclMin, + ReduceOp::Max => nccl_sys::ncclRedOp_t::ncclMax, + } +} + +/// Convert NCCL error to numr error. +fn nccl_err(e: nccl_result::NcclError) -> Error { + Error::Backend(format!("NCCL error: {e:?}")) +} + +impl Communicator for NcclCommunicator { + fn world_size(&self) -> usize { + self.comm.world_size() + } + + fn rank(&self) -> usize { + self.comm.rank() + } + + unsafe fn all_reduce(&self, ptr: u64, count: usize, dtype: DType, op: ReduceOp) -> Result<()> { + let nccl_dtype = dtype_to_nccl(dtype)?; + let nccl_op = reduce_op_to_nccl(op); + // In-place: sendbuff == recvbuff + unsafe { + nccl_result::all_reduce( + ptr as *const std::ffi::c_void, + ptr as *mut std::ffi::c_void, + count, + nccl_dtype, + nccl_op, + self.raw_comm(), + self.raw_stream(), + ) + .map_err(nccl_err)?; + } + Ok(()) + } + + unsafe fn broadcast(&self, ptr: u64, count: usize, dtype: DType, root: usize) -> Result<()> { + let nccl_dtype = dtype_to_nccl(dtype)?; + // In-place: sendbuff == recvbuff + unsafe { + nccl_result::broadcast( + ptr as *const std::ffi::c_void, + ptr as *mut std::ffi::c_void, + count, + nccl_dtype, + root as i32, + self.raw_comm(), + self.raw_stream(), + ) + .map_err(nccl_err)?; + } + Ok(()) + } + + unsafe fn all_gather( + &self, + send_ptr: u64, + recv_ptr: u64, + count: usize, + dtype: DType, + ) -> Result<()> { + let nccl_dtype = dtype_to_nccl(dtype)?; + unsafe { + nccl_result::all_gather( + send_ptr as *const std::ffi::c_void, + recv_ptr as *mut std::ffi::c_void, + count, + nccl_dtype, + self.raw_comm(), + self.raw_stream(), + ) + .map_err(nccl_err)?; + } + Ok(()) + } + + unsafe fn reduce_scatter( + &self, + send_ptr: u64, + recv_ptr: u64, + count: usize, + dtype: DType, + op: ReduceOp, + ) -> Result<()> { + let nccl_dtype = dtype_to_nccl(dtype)?; + let nccl_op = reduce_op_to_nccl(op); + unsafe { + nccl_result::reduce_scatter( + send_ptr as *const std::ffi::c_void, + recv_ptr as *mut std::ffi::c_void, + count, + nccl_dtype, + nccl_op, + self.raw_comm(), + self.raw_stream(), + ) + .map_err(nccl_err)?; + } + Ok(()) + } + + unsafe fn send( + &self, + ptr: u64, + count: usize, + dtype: DType, + dest: usize, + _tag: u32, + ) -> Result<()> { + let nccl_dtype = dtype_to_nccl(dtype)?; + unsafe { + nccl_result::send( + ptr as *const std::ffi::c_void, + count, + nccl_dtype, + dest as i32, + self.raw_comm(), + self.raw_stream(), + ) + .map_err(nccl_err)?; + } + Ok(()) + } + + unsafe fn recv( + &self, + ptr: u64, + count: usize, + dtype: DType, + src: usize, + _tag: u32, + ) -> Result<()> { + let nccl_dtype = dtype_to_nccl(dtype)?; + unsafe { + nccl_result::recv( + ptr as *mut std::ffi::c_void, + count, + nccl_dtype, + src as i32, + self.raw_comm(), + self.raw_stream(), + ) + .map_err(nccl_err)?; + } + Ok(()) + } + + fn sync(&self) -> Result<()> { + self.comm + .stream() + .synchronize() + .map_err(|e| Error::Backend(format!("CUDA stream sync failed: {e}")))?; + Ok(()) + } + + fn as_stream_sync(&self) -> Option<&dyn StreamSyncOps> { + Some(self) + } + + fn barrier(&self) -> Result<()> { + // NCCL has no explicit barrier. Sync the stream first, then do a + // zero-byte all_reduce as a collective synchronization point. + self.sync()?; + unsafe { + nccl_result::all_reduce( + std::ptr::null(), + std::ptr::null_mut(), + 0, + nccl_sys::ncclDataType_t::ncclFloat32, + nccl_sys::ncclRedOp_t::ncclSum, + self.raw_comm(), + self.raw_stream(), + ) + .map_err(nccl_err)?; + } + self.sync() + } +} + +impl StreamSyncOps for NcclCommunicator { + fn create_event(&self) -> Result { + use cudarc::driver::sys::{CUevent_flags, cuEventCreate}; + let mut event = std::ptr::null_mut(); + let result = + unsafe { cuEventCreate(&mut event, CUevent_flags::CU_EVENT_DISABLE_TIMING as u32) }; + if result != cudarc::driver::sys::CUresult::CUDA_SUCCESS { + return Err(Error::Backend(format!("cuEventCreate failed: {result:?}"))); + } + Ok(event as u64) + } + + fn destroy_event(&self, event: u64) -> Result<()> { + use cudarc::driver::sys::cuEventDestroy_v2; + let result = unsafe { cuEventDestroy_v2(event as cudarc::driver::sys::CUevent) }; + if result != cudarc::driver::sys::CUresult::CUDA_SUCCESS { + return Err(Error::Backend(format!("cuEventDestroy failed: {result:?}"))); + } + Ok(()) + } + + fn record_on_comm_stream(&self, event: u64) -> Result<()> { + use cudarc::driver::sys::cuEventRecord; + let result = unsafe { + cuEventRecord( + event as cudarc::driver::sys::CUevent, + self.raw_stream() as cudarc::driver::sys::CUstream, + ) + }; + if result != cudarc::driver::sys::CUresult::CUDA_SUCCESS { + return Err(Error::Backend(format!( + "cuEventRecord on comm stream failed: {result:?}" + ))); + } + Ok(()) + } + + fn record_on_stream(&self, event: u64, stream_handle: u64) -> Result<()> { + use cudarc::driver::sys::cuEventRecord; + let result = unsafe { + cuEventRecord( + event as cudarc::driver::sys::CUevent, + stream_handle as cudarc::driver::sys::CUstream, + ) + }; + if result != cudarc::driver::sys::CUresult::CUDA_SUCCESS { + return Err(Error::Backend(format!( + "cuEventRecord on stream failed: {result:?}" + ))); + } + Ok(()) + } + + fn comm_stream_wait_event(&self, event: u64) -> Result<()> { + use cudarc::driver::sys::cuStreamWaitEvent; + let result = unsafe { + cuStreamWaitEvent( + self.raw_stream() as cudarc::driver::sys::CUstream, + event as cudarc::driver::sys::CUevent, + 0, + ) + }; + if result != cudarc::driver::sys::CUresult::CUDA_SUCCESS { + return Err(Error::Backend(format!( + "cuStreamWaitEvent on comm stream failed: {result:?}" + ))); + } + Ok(()) + } + + fn stream_wait_event(&self, stream_handle: u64, event: u64) -> Result<()> { + use cudarc::driver::sys::cuStreamWaitEvent; + let result = unsafe { + cuStreamWaitEvent( + stream_handle as cudarc::driver::sys::CUstream, + event as cudarc::driver::sys::CUevent, + 0, + ) + }; + if result != cudarc::driver::sys::CUresult::CUDA_SUCCESS { + return Err(Error::Backend(format!( + "cuStreamWaitEvent on stream failed: {result:?}" + ))); + } + Ok(()) + } + + fn sync_comm_stream(&self) -> Result<()> { + self.comm + .stream() + .synchronize() + .map_err(|e| Error::Backend(format!("CUDA comm stream sync failed: {e}")))?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_send_sync_bounds() { + fn assert_send_sync() {} + assert_send_sync::(); + } + + #[test] + fn test_dtype_to_nccl_mapping() { + assert!(dtype_to_nccl(DType::F32).is_ok()); + assert!(dtype_to_nccl(DType::F64).is_ok()); + assert!(dtype_to_nccl(DType::F16).is_ok()); + assert!(dtype_to_nccl(DType::BF16).is_ok()); + assert!(dtype_to_nccl(DType::I32).is_ok()); + assert!(dtype_to_nccl(DType::I64).is_ok()); + assert!(dtype_to_nccl(DType::U32).is_ok()); + assert!(dtype_to_nccl(DType::U8).is_ok()); + assert!(dtype_to_nccl(DType::Bool).is_err()); + } + + #[test] + fn test_reduce_op_mapping() { + assert_eq!( + reduce_op_to_nccl(ReduceOp::Sum), + nccl_sys::ncclRedOp_t::ncclSum + ); + assert_eq!( + reduce_op_to_nccl(ReduceOp::Prod), + nccl_sys::ncclRedOp_t::ncclProd + ); + assert_eq!( + reduce_op_to_nccl(ReduceOp::Min), + nccl_sys::ncclRedOp_t::ncclMin + ); + assert_eq!( + reduce_op_to_nccl(ReduceOp::Max), + nccl_sys::ncclRedOp_t::ncclMax + ); + } + + // Helper: get raw device pointer from a CudaSlice for test use + fn slice_ptr(slice: &cudarc::driver::CudaSlice, stream: &Arc) -> u64 { + use cudarc::driver::DevicePtr; + let (ptr, _guard) = slice.device_ptr(stream); + ptr as u64 + } + + // ── Multi-GPU tests (require 2+ GPUs) ────────────────────────────── + + #[test] + #[ignore] + fn test_nccl_metadata() { + let ctx0 = cudarc::driver::CudaContext::new(0).unwrap(); + let ctx1 = cudarc::driver::CudaContext::new(1).unwrap(); + let streams = vec![ctx0.default_stream(), ctx1.default_stream()]; + let comms = NcclCommunicator::from_streams(streams).unwrap(); + assert_eq!(comms.len(), 2); + assert_eq!(comms[0].world_size(), 2); + assert_eq!(comms[1].world_size(), 2); + assert_eq!(comms[0].rank(), 0); + assert_eq!(comms[1].rank(), 1); + } + + #[test] + #[ignore] + fn test_nccl_all_reduce_f32() { + use cudarc::driver::CudaContext; + use cudarc::nccl::result as nr; + + let n = 4; + let n_devices = CudaContext::device_count().unwrap().min(2) as usize; + if n_devices < 2 { + return; + } + + let streams: Vec<_> = (0..n_devices) + .map(|i| { + let ctx = CudaContext::new(i).unwrap(); + ctx.default_stream() + }) + .collect(); + let comms = NcclCommunicator::from_streams(streams.clone()).unwrap(); + + // Each rank has [rank+1, rank+1, rank+1, rank+1] + let mut slices = Vec::new(); + for i in 0..n_devices { + let data = vec![(i + 1) as f32; n]; + let slice = streams[i].clone_htod(&data).unwrap(); + slices.push(slice); + } + + nr::group_start().unwrap(); + for (i, comm) in comms.iter().enumerate() { + unsafe { + comm.all_reduce( + slice_ptr(&slices[i], &streams[i]), + n, + DType::F32, + ReduceOp::Sum, + ) + .unwrap(); + } + } + nr::group_end().unwrap(); + + for (i, comm) in comms.iter().enumerate() { + comm.sync().unwrap(); + let out = streams[i].clone_dtoh(&slices[i]).unwrap(); + let expected = (n_devices * (n_devices + 1)) as f32 / 2.0; + for v in &out { + assert!( + (*v - expected).abs() < 1e-5, + "rank {i}: expected {expected}, got {v}" + ); + } + } + } + + #[test] + #[ignore] + fn test_nccl_broadcast() { + use cudarc::driver::CudaContext; + use cudarc::nccl::result as nr; + + let n = 4; + let n_devices = CudaContext::device_count().unwrap().min(2) as usize; + if n_devices < 2 { + return; + } + + let streams: Vec<_> = (0..n_devices) + .map(|i| CudaContext::new(i).unwrap().default_stream()) + .collect(); + let comms = NcclCommunicator::from_streams(streams.clone()).unwrap(); + + let mut slices = Vec::new(); + for (i, stream) in streams.iter().enumerate() { + let data = if i == 0 { + vec![42.0f32; n] + } else { + vec![0.0f32; n] + }; + slices.push(stream.clone_htod(&data).unwrap()); + } + + nr::group_start().unwrap(); + for (i, comm) in comms.iter().enumerate() { + unsafe { + comm.broadcast(slice_ptr(&slices[i], &streams[i]), n, DType::F32, 0) + .unwrap(); + } + } + nr::group_end().unwrap(); + + for (i, comm) in comms.iter().enumerate() { + comm.sync().unwrap(); + let out = streams[i].clone_dtoh(&slices[i]).unwrap(); + assert_eq!(out, vec![42.0f32; n], "rank {i} broadcast mismatch"); + } + } + + #[test] + #[ignore] + fn test_nccl_all_gather() { + use cudarc::driver::CudaContext; + use cudarc::nccl::result as nr; + + let n = 2; // elements per rank + let n_devices = CudaContext::device_count().unwrap().min(2) as usize; + if n_devices < 2 { + return; + } + + let streams: Vec<_> = (0..n_devices) + .map(|i| CudaContext::new(i).unwrap().default_stream()) + .collect(); + let comms = NcclCommunicator::from_streams(streams.clone()).unwrap(); + + let mut send_slices = Vec::new(); + let mut recv_slices = Vec::new(); + for (i, stream) in streams.iter().enumerate() { + let data = vec![(i + 1) as f32; n]; + send_slices.push(stream.clone_htod(&data).unwrap()); + recv_slices.push(stream.alloc_zeros::(n * n_devices).unwrap()); + } + + nr::group_start().unwrap(); + for (i, comm) in comms.iter().enumerate() { + unsafe { + comm.all_gather( + slice_ptr(&send_slices[i], &streams[i]), + slice_ptr(&recv_slices[i], &streams[i]), + n, + DType::F32, + ) + .unwrap(); + } + } + nr::group_end().unwrap(); + + for (i, comm) in comms.iter().enumerate() { + comm.sync().unwrap(); + let out = streams[i].clone_dtoh(&recv_slices[i]).unwrap(); + // Expected: [1.0, 1.0, 2.0, 2.0] for 2 devices + let mut expected = Vec::new(); + for rank in 0..n_devices { + expected.extend(std::iter::repeat_n((rank + 1) as f32, n)); + } + assert_eq!(out, expected, "rank {i} all_gather mismatch"); + } + } + + #[test] + #[ignore] + fn test_nccl_send_recv() { + use cudarc::driver::CudaContext; + use cudarc::nccl::result as nr; + + let n = 4; + let n_devices = CudaContext::device_count().unwrap().min(2) as usize; + if n_devices < 2 { + return; + } + + let streams: Vec<_> = (0..n_devices) + .map(|i| CudaContext::new(i).unwrap().default_stream()) + .collect(); + let comms = NcclCommunicator::from_streams(streams.clone()).unwrap(); + + let send_data = vec![99.0f32; n]; + let send_slice = streams[0].clone_htod(&send_data).unwrap(); + let recv_slice = streams[1].alloc_zeros::(n).unwrap(); + + nr::group_start().unwrap(); + unsafe { + comms[0] + .send(slice_ptr(&send_slice, &streams[0]), n, DType::F32, 1, 0) + .unwrap(); + comms[1] + .recv(slice_ptr(&recv_slice, &streams[1]), n, DType::F32, 0, 0) + .unwrap(); + } + nr::group_end().unwrap(); + + comms[0].sync().unwrap(); + comms[1].sync().unwrap(); + let out = streams[1].clone_dtoh(&recv_slice).unwrap(); + assert_eq!(out, vec![99.0f32; n]); + } + + #[test] + #[ignore] + fn test_nccl_sync_barrier() { + use cudarc::driver::CudaContext; + + let n_devices = CudaContext::device_count().unwrap().min(2) as usize; + if n_devices < 2 { + return; + } + + let streams: Vec<_> = (0..n_devices) + .map(|i| CudaContext::new(i).unwrap().default_stream()) + .collect(); + let comms = NcclCommunicator::from_streams(streams).unwrap(); + + for comm in &comms { + comm.sync().unwrap(); + } + // barrier requires all ranks to participate + cudarc::nccl::result::group_start().unwrap(); + for comm in &comms { + comm.barrier().unwrap(); + } + cudarc::nccl::result::group_end().unwrap(); + } +} diff --git a/src/runtime/cuda/fft.rs b/src/runtime/cuda/fft.rs index 3b3c3dc0..54622bd1 100644 --- a/src/runtime/cuda/fft.rs +++ b/src/runtime/cuda/fft.rs @@ -61,7 +61,7 @@ impl FftAlgorithms for CudaClient { let output_guard = AllocGuard::new(self.allocator(), output_size)?; let output_ptr = output_guard.ptr(); - let input_ptr = input_contig.storage().ptr(); + let input_ptr = input_contig.ptr(); // Choose small FFT (shared memory) or large FFT (multi-stage) based on size if n <= kernels::MAX_SHARED_MEM_FFT_SIZE { @@ -238,7 +238,7 @@ impl FftAlgorithms for CudaClient { self.stream(), device.index, dtype, - input_contig.storage().ptr(), + input_contig.ptr(), complex_ptr, n, batch_size, @@ -389,7 +389,7 @@ impl FftAlgorithms for CudaClient { self.stream(), device.index, dtype, - input_contig.storage().ptr(), + input_contig.ptr(), full_complex_ptr, input_n, output_n, @@ -575,7 +575,7 @@ impl FftAlgorithms for CudaClient { self.stream(), device.index, dtype, - input_contig.storage().ptr(), + input_contig.ptr(), output_ptr, n, batch_size, @@ -622,7 +622,7 @@ impl FftAlgorithms for CudaClient { self.stream(), device.index, dtype, - input_contig.storage().ptr(), + input_contig.ptr(), output_ptr, n, batch_size, diff --git a/src/runtime/cuda/graph.rs b/src/runtime/cuda/graph.rs new file mode 100644 index 00000000..32105708 --- /dev/null +++ b/src/runtime/cuda/graph.rs @@ -0,0 +1,88 @@ +//! CUDA graph capture and replay +//! +//! Wraps cudarc's `CudaGraph` with `Send + Sync + Clone` for use with +//! numr's `Graph` trait. + +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{Arc, Mutex}; + +use cudarc::driver::safe::CudaGraph as CudarcGraph; + +/// Wrapper to make cudarc's CudaGraph safe to send across threads. +/// +/// # Safety +/// +/// cudarc's `CudaGraph` contains raw CUDA pointers (`CUgraph`, `CUgraphExec`) +/// which don't auto-implement `Send`. We wrap it in `Mutex` to serialize all +/// access. The only operation after instantiation is `launch()`, which: +/// 1. Binds the CUDA context to the current thread (`ctx.bind_to_thread()`) +/// 2. Calls `cuGraphLaunch` (a stream-ordered operation) +/// +/// No concurrent graph structure modification ever occurs. +struct CudaGraphInner(CudarcGraph); + +// SAFETY: Access is serialized via Mutex. After instantiation, only launch() +// is called, which binds CUDA context to the calling thread. +unsafe impl Send for CudaGraphInner {} + +/// CUDA graph — a captured computation sequence replayed via `cuGraphLaunch`. +/// +/// Created by `CudaRuntime::capture_graph()`. Thread-safe via internal `Mutex`. +/// `Clone` bumps the `Arc` refcount (no graph duplication). +pub struct CudaGraph { + inner: Arc>, + launch_count: Arc, +} + +impl Clone for CudaGraph { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + launch_count: self.launch_count.clone(), + } + } +} + +impl std::fmt::Debug for CudaGraph { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CudaGraph") + .field("launch_count", &self.launch_count.load(Ordering::Relaxed)) + .finish() + } +} + +impl CudaGraph { + /// Create a new CudaGraph wrapping cudarc's graph. + pub(crate) fn new(graph: CudarcGraph) -> Self { + Self { + inner: Arc::new(Mutex::new(CudaGraphInner(graph))), + launch_count: Arc::new(AtomicUsize::new(0)), + } + } + + /// How many times this graph has been launched. + pub fn launch_count(&self) -> usize { + self.launch_count.load(Ordering::Relaxed) + } +} + +impl crate::runtime::Graph for CudaGraph { + fn launch(&self) -> crate::error::Result<()> { + let guard = self.inner.lock().unwrap_or_else(|p| p.into_inner()); + guard + .0 + .launch() + .map_err(|e| crate::error::Error::Backend(format!("CUDA graph launch failed: {e}")))?; + self.launch_count.fetch_add(1, Ordering::Relaxed); + Ok(()) + } + + fn is_replay_capable(&self) -> bool { + true + } +} + +// SAFETY: All interior access is serialized via Mutex. Arc provides shared ownership. +// The CudaGraph is only ever launched (no structural modification after instantiation). +unsafe impl Send for CudaGraph {} +unsafe impl Sync for CudaGraph {} diff --git a/src/runtime/cuda/kernels/activation.cu b/src/runtime/cuda/kernels/activation.cu index 80a2aad3..70521bf9 100644 --- a/src/runtime/cuda/kernels/activation.cu +++ b/src/runtime/cuda/kernels/activation.cu @@ -1,6 +1,8 @@ -// Activation CUDA kernels -// Supports: relu, sigmoid, softmax, silu, gelu +// Element-wise activation CUDA kernels +// Supports: relu, sigmoid, silu, gelu, leaky_relu, elu // Types: f32, f64, f16, bf16, fp8_e4m3, fp8_e5m2 +// +// Softmax kernels are in softmax.cu #include #include @@ -26,7 +28,6 @@ __global__ void sigmoid_f32(const float* a, float* out, unsigned int n) { } } -// SiLU (Swish): x * sigmoid(x) = x / (1 + exp(-x)) __global__ void silu_f32(const float* a, float* out, unsigned int n) { unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < n) { @@ -35,117 +36,15 @@ __global__ void silu_f32(const float* a, float* out, unsigned int n) { } } -// GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) -// Using the tanh approximation for better performance __global__ void gelu_f32(const float* a, float* out, unsigned int n) { unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < n) { float x = a[idx]; - // sqrt(2/pi) ≈ 0.7978845608 float cdf = 0.5f * (1.0f + tanhf(0.7978845608f * (x + 0.044715f * x * x * x))); out[idx] = x * cdf; } } -// Softmax over the last dimension -// outer_size = product of all dims except last -// dim_size = size of last dimension -__global__ void softmax_f32( - const float* input, float* output, - unsigned int outer_size, unsigned int dim_size -) { - unsigned int outer_idx = blockIdx.x; - if (outer_idx >= outer_size) return; - - extern __shared__ float shared[]; - float* max_val = shared; - float* sum_exp = shared + blockDim.x; - - const float* row_in = input + outer_idx * dim_size; - float* row_out = output + outer_idx * dim_size; - - // Phase 1: Find max value for numerical stability - float thread_max = -INFINITY; - for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { - thread_max = fmaxf(thread_max, row_in[i]); - } - max_val[threadIdx.x] = thread_max; - __syncthreads(); - - // Reduce max across threads - for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { - if (threadIdx.x < s) { - max_val[threadIdx.x] = fmaxf(max_val[threadIdx.x], max_val[threadIdx.x + s]); - } - __syncthreads(); - } - float row_max = max_val[0]; - __syncthreads(); - - // Phase 2: Compute exp(x - max) and sum - float thread_sum = 0.0f; - for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { - float val = expf(row_in[i] - row_max); - row_out[i] = val; // Temporarily store exp values - thread_sum += val; - } - sum_exp[threadIdx.x] = thread_sum; - __syncthreads(); - - // Reduce sum across threads - for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { - if (threadIdx.x < s) { - sum_exp[threadIdx.x] += sum_exp[threadIdx.x + s]; - } - __syncthreads(); - } - float row_sum = sum_exp[0]; - __syncthreads(); - - // Phase 3: Normalize - float inv_sum = 1.0f / row_sum; - for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { - row_out[i] *= inv_sum; - } -} - -// Softmax over non-last dimension -// For shape [A, B, C] with softmax over dim=1: -// outer_size = A, dim_size = B, inner_size = C -__global__ void softmax_dim_f32( - const float* input, float* output, - unsigned int outer_size, unsigned int dim_size, unsigned int inner_size -) { - unsigned int outer_idx = blockIdx.x; - unsigned int inner_idx = blockIdx.y; - - if (outer_idx >= outer_size || inner_idx >= inner_size) return; - - // Base offset for this (outer, inner) position - unsigned int base = outer_idx * dim_size * inner_size + inner_idx; - unsigned int stride = inner_size; - - // Find max - float max_val = -INFINITY; - for (unsigned int i = 0; i < dim_size; i++) { - max_val = fmaxf(max_val, input[base + i * stride]); - } - - // Compute exp and sum - float sum = 0.0f; - for (unsigned int i = 0; i < dim_size; i++) { - float val = expf(input[base + i * stride] - max_val); - output[base + i * stride] = val; - sum += val; - } - - // Normalize - float inv_sum = 1.0f / sum; - for (unsigned int i = 0; i < dim_size; i++) { - output[base + i * stride] *= inv_sum; - } -} - // ============================================================================ // F64 Activation Operations // ============================================================================ @@ -181,96 +80,8 @@ __global__ void gelu_f64(const double* a, double* out, unsigned int n) { } } -__global__ void softmax_f64( - const double* input, double* output, - unsigned int outer_size, unsigned int dim_size -) { - unsigned int outer_idx = blockIdx.x; - if (outer_idx >= outer_size) return; - - extern __shared__ double shared_f64[]; - double* max_val = shared_f64; - double* sum_exp = shared_f64 + blockDim.x; - - const double* row_in = input + outer_idx * dim_size; - double* row_out = output + outer_idx * dim_size; - - // Phase 1: Find max - double thread_max = -INFINITY; - for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { - thread_max = fmax(thread_max, row_in[i]); - } - max_val[threadIdx.x] = thread_max; - __syncthreads(); - - for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { - if (threadIdx.x < s) { - max_val[threadIdx.x] = fmax(max_val[threadIdx.x], max_val[threadIdx.x + s]); - } - __syncthreads(); - } - double row_max = max_val[0]; - __syncthreads(); - - // Phase 2: Compute exp and sum - double thread_sum = 0.0; - for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { - double val = exp(row_in[i] - row_max); - row_out[i] = val; - thread_sum += val; - } - sum_exp[threadIdx.x] = thread_sum; - __syncthreads(); - - for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { - if (threadIdx.x < s) { - sum_exp[threadIdx.x] += sum_exp[threadIdx.x + s]; - } - __syncthreads(); - } - double row_sum = sum_exp[0]; - __syncthreads(); - - // Phase 3: Normalize - double inv_sum = 1.0 / row_sum; - for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { - row_out[i] *= inv_sum; - } -} - -__global__ void softmax_dim_f64( - const double* input, double* output, - unsigned int outer_size, unsigned int dim_size, unsigned int inner_size -) { - unsigned int outer_idx = blockIdx.x; - unsigned int inner_idx = blockIdx.y; - - if (outer_idx >= outer_size || inner_idx >= inner_size) return; - - unsigned int base = outer_idx * dim_size * inner_size + inner_idx; - unsigned int stride = inner_size; - - double max_val = -INFINITY; - for (unsigned int i = 0; i < dim_size; i++) { - max_val = fmax(max_val, input[base + i * stride]); - } - - double sum = 0.0; - for (unsigned int i = 0; i < dim_size; i++) { - double val = exp(input[base + i * stride] - max_val); - output[base + i * stride] = val; - sum += val; - } - - double inv_sum = 1.0 / sum; - for (unsigned int i = 0; i < dim_size; i++) { - output[base + i * stride] *= inv_sum; - } -} - // ============================================================================ -// F16 Activation Operations -// Note: Uses FP32 internally for accuracy where needed +// F16 Activation Operations (FP32 internally for accuracy) // ============================================================================ __global__ void relu_f16(const __half* a, __half* out, unsigned int n) { @@ -284,7 +95,6 @@ __global__ void relu_f16(const __half* a, __half* out, unsigned int n) { __global__ void sigmoid_f16(const __half* a, __half* out, unsigned int n) { unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < n) { - // Use FP32 for accuracy float x = __half2float(a[idx]); out[idx] = __float2half(1.0f / (1.0f + expf(-x))); } @@ -307,98 +117,8 @@ __global__ void gelu_f16(const __half* a, __half* out, unsigned int n) { } } -// F16 Softmax: Uses FP32 accumulation internally for numerical stability -__global__ void softmax_f16( - const __half* input, __half* output, - unsigned int outer_size, unsigned int dim_size -) { - unsigned int outer_idx = blockIdx.x; - if (outer_idx >= outer_size) return; - - extern __shared__ float shared[]; - float* max_val = shared; - float* sum_exp = shared + blockDim.x; - - const __half* row_in = input + outer_idx * dim_size; - __half* row_out = output + outer_idx * dim_size; - - // Phase 1: Find max value (using FP32) - float thread_max = -INFINITY; - for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { - thread_max = fmaxf(thread_max, __half2float(row_in[i])); - } - max_val[threadIdx.x] = thread_max; - __syncthreads(); - - for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { - if (threadIdx.x < s) { - max_val[threadIdx.x] = fmaxf(max_val[threadIdx.x], max_val[threadIdx.x + s]); - } - __syncthreads(); - } - float row_max = max_val[0]; - __syncthreads(); - - // Phase 2: Compute exp(x - max) and sum (FP32 accumulation) - float thread_sum = 0.0f; - for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { - float val = expf(__half2float(row_in[i]) - row_max); - row_out[i] = __float2half(val); - thread_sum += val; - } - sum_exp[threadIdx.x] = thread_sum; - __syncthreads(); - - for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { - if (threadIdx.x < s) { - sum_exp[threadIdx.x] += sum_exp[threadIdx.x + s]; - } - __syncthreads(); - } - float row_sum = sum_exp[0]; - __syncthreads(); - - // Phase 3: Normalize - float inv_sum = 1.0f / row_sum; - for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { - row_out[i] = __float2half(__half2float(row_out[i]) * inv_sum); - } -} - -__global__ void softmax_dim_f16( - const __half* input, __half* output, - unsigned int outer_size, unsigned int dim_size, unsigned int inner_size -) { - unsigned int outer_idx = blockIdx.x; - unsigned int inner_idx = blockIdx.y; - - if (outer_idx >= outer_size || inner_idx >= inner_size) return; - - unsigned int base = outer_idx * dim_size * inner_size + inner_idx; - unsigned int stride = inner_size; - - // FP32 accumulation for stability - float max_val = -INFINITY; - for (unsigned int i = 0; i < dim_size; i++) { - max_val = fmaxf(max_val, __half2float(input[base + i * stride])); - } - - float sum = 0.0f; - for (unsigned int i = 0; i < dim_size; i++) { - float val = expf(__half2float(input[base + i * stride]) - max_val); - output[base + i * stride] = __float2half(val); - sum += val; - } - - float inv_sum = 1.0f / sum; - for (unsigned int i = 0; i < dim_size; i++) { - output[base + i * stride] = __float2half(__half2float(output[base + i * stride]) * inv_sum); - } -} - // ============================================================================ -// BF16 Activation Operations -// Note: Uses FP32 internally for accuracy where needed +// BF16 Activation Operations (FP32 internally for accuracy) // ============================================================================ __global__ void relu_bf16(const __nv_bfloat16* a, __nv_bfloat16* out, unsigned int n) { @@ -434,99 +154,8 @@ __global__ void gelu_bf16(const __nv_bfloat16* a, __nv_bfloat16* out, unsigned i } } -// BF16 Softmax: Uses FP32 accumulation internally for numerical stability -__global__ void softmax_bf16( - const __nv_bfloat16* input, __nv_bfloat16* output, - unsigned int outer_size, unsigned int dim_size -) { - unsigned int outer_idx = blockIdx.x; - if (outer_idx >= outer_size) return; - - extern __shared__ float shared[]; - float* max_val = shared; - float* sum_exp = shared + blockDim.x; - - const __nv_bfloat16* row_in = input + outer_idx * dim_size; - __nv_bfloat16* row_out = output + outer_idx * dim_size; - - // Phase 1: Find max value (using FP32) - float thread_max = -INFINITY; - for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { - thread_max = fmaxf(thread_max, __bfloat162float(row_in[i])); - } - max_val[threadIdx.x] = thread_max; - __syncthreads(); - - for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { - if (threadIdx.x < s) { - max_val[threadIdx.x] = fmaxf(max_val[threadIdx.x], max_val[threadIdx.x + s]); - } - __syncthreads(); - } - float row_max = max_val[0]; - __syncthreads(); - - // Phase 2: Compute exp(x - max) and sum (FP32 accumulation) - float thread_sum = 0.0f; - for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { - float val = expf(__bfloat162float(row_in[i]) - row_max); - row_out[i] = __float2bfloat16(val); - thread_sum += val; - } - sum_exp[threadIdx.x] = thread_sum; - __syncthreads(); - - for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { - if (threadIdx.x < s) { - sum_exp[threadIdx.x] += sum_exp[threadIdx.x + s]; - } - __syncthreads(); - } - float row_sum = sum_exp[0]; - __syncthreads(); - - // Phase 3: Normalize - float inv_sum = 1.0f / row_sum; - for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { - row_out[i] = __float2bfloat16(__bfloat162float(row_out[i]) * inv_sum); - } -} - -__global__ void softmax_dim_bf16( - const __nv_bfloat16* input, __nv_bfloat16* output, - unsigned int outer_size, unsigned int dim_size, unsigned int inner_size -) { - unsigned int outer_idx = blockIdx.x; - unsigned int inner_idx = blockIdx.y; - - if (outer_idx >= outer_size || inner_idx >= inner_size) return; - - unsigned int base = outer_idx * dim_size * inner_size + inner_idx; - unsigned int stride = inner_size; - - // FP32 accumulation for stability - float max_val = -INFINITY; - for (unsigned int i = 0; i < dim_size; i++) { - max_val = fmaxf(max_val, __bfloat162float(input[base + i * stride])); - } - - float sum = 0.0f; - for (unsigned int i = 0; i < dim_size; i++) { - float val = expf(__bfloat162float(input[base + i * stride]) - max_val); - output[base + i * stride] = __float2bfloat16(val); - sum += val; - } - - float inv_sum = 1.0f / sum; - for (unsigned int i = 0; i < dim_size; i++) { - output[base + i * stride] = __float2bfloat16(__bfloat162float(output[base + i * stride]) * inv_sum); - } -} - // ============================================================================ -// FP8 E4M3 Activation Operations -// All computation done in F32, stored back as FP8 -// Uses Hopper PTX intrinsics on SM 8.9+, software emulation on SM 8.0+ +// FP8 E4M3 Activation Operations (FP32 internally) // ============================================================================ __global__ void relu_fp8_e4m3(const numr_fp8_e4m3* a, numr_fp8_e4m3* out, unsigned int n) { @@ -562,98 +191,8 @@ __global__ void gelu_fp8_e4m3(const numr_fp8_e4m3* a, numr_fp8_e4m3* out, unsign } } -// FP8 E4M3 Softmax: Uses FP32 accumulation internally for numerical stability -__global__ void softmax_fp8_e4m3( - const numr_fp8_e4m3* input, numr_fp8_e4m3* output, - unsigned int outer_size, unsigned int dim_size -) { - unsigned int outer_idx = blockIdx.x; - if (outer_idx >= outer_size) return; - - extern __shared__ float shared[]; - float* max_val = shared; - float* sum_exp = shared + blockDim.x; - - const numr_fp8_e4m3* row_in = input + outer_idx * dim_size; - numr_fp8_e4m3* row_out = output + outer_idx * dim_size; - - // Phase 1: Find max value (using FP32) - float thread_max = -INFINITY; - for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { - thread_max = fmaxf(thread_max, fp8_e4m3_to_f32(row_in[i].data)); - } - max_val[threadIdx.x] = thread_max; - __syncthreads(); - - for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { - if (threadIdx.x < s) { - max_val[threadIdx.x] = fmaxf(max_val[threadIdx.x], max_val[threadIdx.x + s]); - } - __syncthreads(); - } - float row_max = max_val[0]; - __syncthreads(); - - // Phase 2: Compute exp(x - max) and sum (FP32 accumulation) - float thread_sum = 0.0f; - for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { - float val = expf(fp8_e4m3_to_f32(row_in[i].data) - row_max); - row_out[i] = numr_fp8_e4m3(f32_to_fp8_e4m3(val)); - thread_sum += val; - } - sum_exp[threadIdx.x] = thread_sum; - __syncthreads(); - - for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { - if (threadIdx.x < s) { - sum_exp[threadIdx.x] += sum_exp[threadIdx.x + s]; - } - __syncthreads(); - } - float row_sum = sum_exp[0]; - __syncthreads(); - - // Phase 3: Normalize - float inv_sum = 1.0f / row_sum; - for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { - row_out[i] = numr_fp8_e4m3(f32_to_fp8_e4m3(fp8_e4m3_to_f32(row_out[i].data) * inv_sum)); - } -} - -__global__ void softmax_dim_fp8_e4m3( - const numr_fp8_e4m3* input, numr_fp8_e4m3* output, - unsigned int outer_size, unsigned int dim_size, unsigned int inner_size -) { - unsigned int outer_idx = blockIdx.x; - unsigned int inner_idx = blockIdx.y; - - if (outer_idx >= outer_size || inner_idx >= inner_size) return; - - unsigned int base = outer_idx * dim_size * inner_size + inner_idx; - unsigned int stride = inner_size; - - // FP32 accumulation for stability - float max_val = -INFINITY; - for (unsigned int i = 0; i < dim_size; i++) { - max_val = fmaxf(max_val, fp8_e4m3_to_f32(input[base + i * stride].data)); - } - - float sum = 0.0f; - for (unsigned int i = 0; i < dim_size; i++) { - float val = expf(fp8_e4m3_to_f32(input[base + i * stride].data) - max_val); - output[base + i * stride] = numr_fp8_e4m3(f32_to_fp8_e4m3(val)); - sum += val; - } - - float inv_sum = 1.0f / sum; - for (unsigned int i = 0; i < dim_size; i++) { - output[base + i * stride] = numr_fp8_e4m3(f32_to_fp8_e4m3( - fp8_e4m3_to_f32(output[base + i * stride].data) * inv_sum)); - } -} - // ============================================================================ -// FP8 E5M2 Activation Operations +// FP8 E5M2 Activation Operations (FP32 internally) // ============================================================================ __global__ void relu_fp8_e5m2(const numr_fp8_e5m2* a, numr_fp8_e5m2* out, unsigned int n) { @@ -689,99 +228,8 @@ __global__ void gelu_fp8_e5m2(const numr_fp8_e5m2* a, numr_fp8_e5m2* out, unsign } } -// FP8 E5M2 Softmax -__global__ void softmax_fp8_e5m2( - const numr_fp8_e5m2* input, numr_fp8_e5m2* output, - unsigned int outer_size, unsigned int dim_size -) { - unsigned int outer_idx = blockIdx.x; - if (outer_idx >= outer_size) return; - - extern __shared__ float shared[]; - float* max_val = shared; - float* sum_exp = shared + blockDim.x; - - const numr_fp8_e5m2* row_in = input + outer_idx * dim_size; - numr_fp8_e5m2* row_out = output + outer_idx * dim_size; - - // Phase 1: Find max value (using FP32) - float thread_max = -INFINITY; - for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { - thread_max = fmaxf(thread_max, fp8_e5m2_to_f32(row_in[i].data)); - } - max_val[threadIdx.x] = thread_max; - __syncthreads(); - - for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { - if (threadIdx.x < s) { - max_val[threadIdx.x] = fmaxf(max_val[threadIdx.x], max_val[threadIdx.x + s]); - } - __syncthreads(); - } - float row_max = max_val[0]; - __syncthreads(); - - // Phase 2: Compute exp(x - max) and sum - float thread_sum = 0.0f; - for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { - float val = expf(fp8_e5m2_to_f32(row_in[i].data) - row_max); - row_out[i] = numr_fp8_e5m2(f32_to_fp8_e5m2(val)); - thread_sum += val; - } - sum_exp[threadIdx.x] = thread_sum; - __syncthreads(); - - for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { - if (threadIdx.x < s) { - sum_exp[threadIdx.x] += sum_exp[threadIdx.x + s]; - } - __syncthreads(); - } - float row_sum = sum_exp[0]; - __syncthreads(); - - // Phase 3: Normalize - float inv_sum = 1.0f / row_sum; - for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { - row_out[i] = numr_fp8_e5m2(f32_to_fp8_e5m2(fp8_e5m2_to_f32(row_out[i].data) * inv_sum)); - } -} - -__global__ void softmax_dim_fp8_e5m2( - const numr_fp8_e5m2* input, numr_fp8_e5m2* output, - unsigned int outer_size, unsigned int dim_size, unsigned int inner_size -) { - unsigned int outer_idx = blockIdx.x; - unsigned int inner_idx = blockIdx.y; - - if (outer_idx >= outer_size || inner_idx >= inner_size) return; - - unsigned int base = outer_idx * dim_size * inner_size + inner_idx; - unsigned int stride = inner_size; - - // FP32 accumulation for stability - float max_val = -INFINITY; - for (unsigned int i = 0; i < dim_size; i++) { - max_val = fmaxf(max_val, fp8_e5m2_to_f32(input[base + i * stride].data)); - } - - float sum = 0.0f; - for (unsigned int i = 0; i < dim_size; i++) { - float val = expf(fp8_e5m2_to_f32(input[base + i * stride].data) - max_val); - output[base + i * stride] = numr_fp8_e5m2(f32_to_fp8_e5m2(val)); - sum += val; - } - - float inv_sum = 1.0f / sum; - for (unsigned int i = 0; i < dim_size; i++) { - output[base + i * stride] = numr_fp8_e5m2(f32_to_fp8_e5m2( - fp8_e5m2_to_f32(output[base + i * stride].data) * inv_sum)); - } -} - // ============================================================================ -// Leaky ReLU Activation Operations -// leaky_relu(x) = max(negative_slope * x, x) +// Leaky ReLU: max(negative_slope * x, x) // ============================================================================ __global__ void leaky_relu_f32(const float* a, float* out, unsigned int n, float negative_slope) { @@ -834,8 +282,7 @@ __global__ void leaky_relu_fp8_e5m2(const numr_fp8_e5m2* a, numr_fp8_e5m2* out, } // ============================================================================ -// ELU (Exponential Linear Unit) Activation Operations -// elu(x) = x if x > 0, else alpha * (exp(x) - 1) +// ELU: x if x > 0, else alpha * (exp(x) - 1) // ============================================================================ __global__ void elu_f32(const float* a, float* out, unsigned int n, float alpha) { diff --git a/src/runtime/cuda/kernels/activation.rs b/src/runtime/cuda/kernels/activation/elementwise.rs similarity index 50% rename from src/runtime/cuda/kernels/activation.rs rename to src/runtime/cuda/kernels/activation/elementwise.rs index 4f3a38b1..b1d93700 100644 --- a/src/runtime/cuda/kernels/activation.rs +++ b/src/runtime/cuda/kernels/activation/elementwise.rs @@ -1,22 +1,17 @@ -//! Activation function CUDA kernel launchers +//! Element-wise activation CUDA kernel launchers //! -//! Provides launchers for activation functions (ReLU, sigmoid, softmax) -//! commonly used in neural networks. +//! Kernel source: activation.cu use cudarc::driver::PushKernelArg; use cudarc::driver::safe::{CudaContext, CudaStream}; use std::sync::Arc; -use super::loader::{ - BLOCK_SIZE, elementwise_launch_config, get_kernel_function, get_or_load_module, kernel_name, - kernel_names, launch_config, launch_unary_kernel, softmax_launch_config, -}; use crate::dtype::DType; use crate::error::{Error, Result}; - -// ============================================================================ -// Element-wise Activations -// ============================================================================ +use crate::runtime::cuda::kernels::loader::{ + BLOCK_SIZE, elementwise_launch_config, get_kernel_function, get_or_load_module, kernel_name, + kernel_names, launch_config, launch_unary_kernel, +}; /// Launch a ReLU (Rectified Linear Unit) kernel. /// @@ -52,9 +47,7 @@ pub unsafe fn launch_relu( /// Launch a SiLU (Swish) kernel. /// -/// Computes: `output[i] = input[i] * sigmoid(input[i]) = input[i] / (1 + exp(-input[i]))` -/// -/// SiLU (Sigmoid Linear Unit) is commonly used in modern architectures like LLaMA. +/// Computes: `output[i] = input[i] / (1 + exp(-input[i]))` /// /// # Safety /// @@ -86,10 +79,7 @@ pub unsafe fn launch_silu( /// Launch a GELU (Gaussian Error Linear Unit) kernel. /// -/// Computes the tanh approximation: -/// `output[i] = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))` -/// -/// GELU is used in models like BERT and GPT. +/// Computes: `output[i] = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))` /// /// # Safety /// @@ -155,8 +145,6 @@ pub unsafe fn launch_sigmoid( /// /// Computes: `output[i] = max(negative_slope * input[i], input[i])` /// -/// Allows small gradients for negative inputs, helping prevent "dying ReLU" problem. -/// /// # Safety /// /// - All pointers must be valid device memory @@ -199,8 +187,6 @@ pub unsafe fn launch_leaky_relu( /// /// Computes: `output[i] = input[i] if input[i] > 0, else alpha * (exp(input[i]) - 1)` /// -/// Smooth approximation to ReLU with negative values saturating to -alpha. -/// /// # Safety /// /// - All pointers must be valid device memory @@ -238,131 +224,3 @@ pub unsafe fn launch_elu( Ok(()) } } - -// ============================================================================ -// Softmax Activations -// ============================================================================ - -/// Launch softmax over the last dimension. -/// -/// For a tensor of shape `[..., D]`, computes softmax independently for each -/// of the `outer_size` vectors of length `dim_size`. -/// -/// The softmax is computed as: -/// ```text -/// softmax(x)[i] = exp(x[i] - max(x)) / sum(exp(x - max(x))) -/// ``` -/// -/// Uses shared memory for parallel reduction of max and sum values. -/// -/// # Safety -/// -/// - All pointers must be valid device memory -/// - `input_ptr` must have `outer_size * dim_size` elements -/// - `output_ptr` must have `outer_size * dim_size` elements -/// -/// # Arguments -/// -/// * `outer_size` - Number of independent softmax computations (product of all dims except last) -/// * `dim_size` - Size of the last dimension (the dimension over which softmax is computed) -pub unsafe fn launch_softmax( - context: &Arc, - stream: &CudaStream, - device_index: usize, - dtype: DType, - input_ptr: u64, - output_ptr: u64, - outer_size: usize, - dim_size: usize, -) -> Result<()> { - unsafe { - let module = get_or_load_module(context, device_index, kernel_names::ACTIVATION_MODULE)?; - let func_name = kernel_name("softmax", dtype); - let func = get_kernel_function(&module, &func_name)?; - - let (grid_size, block_size, shared_mem) = softmax_launch_config(outer_size, dim_size); - let outer = outer_size as u32; - let dim = dim_size as u32; - - // Adjust shared memory for f64 (double the size) - let shared_mem = if dtype == DType::F64 { - shared_mem * 2 - } else { - shared_mem - }; - - let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), shared_mem); - let mut builder = stream.launch_builder(&func); - builder.arg(&input_ptr); - builder.arg(&output_ptr); - builder.arg(&outer); - builder.arg(&dim); - - builder - .launch(cfg) - .map_err(|e| Error::Internal(format!("CUDA softmax kernel launch failed: {:?}", e)))?; - - Ok(()) - } -} - -/// Launch softmax over a non-last dimension. -/// -/// For a tensor of shape `[A, B, C]` with softmax over dimension 1: -/// - `outer_size` = A -/// - `dim_size` = B -/// - `inner_size` = C -/// -/// Each thread handles one (outer, inner) position and sequentially computes -/// softmax across the `dim_size` elements. -/// -/// # Performance Note -/// -/// This kernel uses one thread per (outer, inner) position with sequential -/// processing over dim_size. For large dim_size values, consider using -/// `launch_softmax` by transposing the tensor to put the reduction dimension last. -/// -/// # Safety -/// -/// - All pointers must be valid device memory -/// - `input_ptr` must have `outer_size * dim_size * inner_size` elements -/// - `output_ptr` must have `outer_size * dim_size * inner_size` elements -pub unsafe fn launch_softmax_dim( - context: &Arc, - stream: &CudaStream, - device_index: usize, - dtype: DType, - input_ptr: u64, - output_ptr: u64, - outer_size: usize, - dim_size: usize, - inner_size: usize, -) -> Result<()> { - unsafe { - let module = get_or_load_module(context, device_index, kernel_names::ACTIVATION_MODULE)?; - let func_name = kernel_name("softmax_dim", dtype); - let func = get_kernel_function(&module, &func_name)?; - - // The kernel uses blockIdx.x for outer and blockIdx.y for inner, - // with each thread handling one (outer, inner) pair sequentially over dim_size. - // This is intentionally a 2D grid with 1 thread per block to match the kernel design. - let grid = (outer_size as u32, inner_size as u32, 1); - let outer = outer_size as u32; - let dim = dim_size as u32; - let inner = inner_size as u32; - - let cfg = launch_config(grid, (1, 1, 1), 0); - let mut builder = stream.launch_builder(&func); - builder.arg(&input_ptr); - builder.arg(&output_ptr); - builder.arg(&outer); - builder.arg(&dim); - builder.arg(&inner); - - builder.launch(cfg).map_err(|e| { - Error::Internal(format!("CUDA softmax_dim kernel launch failed: {:?}", e)) - })?; - - Ok(()) - } -} diff --git a/src/runtime/cuda/kernels/activation/mod.rs b/src/runtime/cuda/kernels/activation/mod.rs new file mode 100644 index 00000000..737ab27b --- /dev/null +++ b/src/runtime/cuda/kernels/activation/mod.rs @@ -0,0 +1,11 @@ +//! Activation CUDA kernel launchers +//! +//! Split into submodules: +//! - `elementwise` - relu, sigmoid, silu, gelu, leaky_relu, elu +//! - `softmax` - softmax forward + backward (last-dim and non-last-dim) + +mod elementwise; +mod softmax; + +pub use elementwise::*; +pub use softmax::*; diff --git a/src/runtime/cuda/kernels/activation/softmax.rs b/src/runtime/cuda/kernels/activation/softmax.rs new file mode 100644 index 00000000..2655db08 --- /dev/null +++ b/src/runtime/cuda/kernels/activation/softmax.rs @@ -0,0 +1,202 @@ +//! Softmax CUDA kernel launchers (forward + backward) +//! +//! Kernel source: softmax.cu + +use cudarc::driver::PushKernelArg; +use cudarc::driver::safe::{CudaContext, CudaStream}; +use std::sync::Arc; + +use crate::dtype::DType; +use crate::error::{Error, Result}; +use crate::runtime::cuda::kernels::loader::{ + get_kernel_function, get_or_load_module, kernel_name, kernel_names, launch_config, + softmax_launch_config, +}; + +/// Launch softmax over the last dimension. +/// +/// Uses shared memory for parallel reduction of max and sum values. +/// +/// # Safety +/// +/// - All pointers must be valid device memory +/// - `input_ptr` must have `outer_size * dim_size` elements +/// - `output_ptr` must have `outer_size * dim_size` elements +pub unsafe fn launch_softmax( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + input_ptr: u64, + output_ptr: u64, + outer_size: usize, + dim_size: usize, +) -> Result<()> { + unsafe { + let module = get_or_load_module(context, device_index, kernel_names::SOFTMAX_MODULE)?; + let func_name = kernel_name("softmax", dtype); + let func = get_kernel_function(&module, &func_name)?; + + let (grid_size, block_size, shared_mem) = softmax_launch_config(outer_size, dim_size); + let outer = outer_size as u32; + let dim = dim_size as u32; + + let shared_mem = if dtype == DType::F64 { + shared_mem * 2 + } else { + shared_mem + }; + + let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), shared_mem); + let mut builder = stream.launch_builder(&func); + builder.arg(&input_ptr); + builder.arg(&output_ptr); + builder.arg(&outer); + builder.arg(&dim); + + builder + .launch(cfg) + .map_err(|e| Error::Internal(format!("CUDA softmax kernel launch failed: {:?}", e)))?; + + Ok(()) + } +} + +/// Launch softmax over a non-last dimension. +/// +/// For shape `[A, B, C]` with softmax over dim=1: outer=A, dim=B, inner=C. +/// +/// # Safety +/// +/// - All pointers must be valid device memory +/// - Tensors must have `outer_size * dim_size * inner_size` elements +pub unsafe fn launch_softmax_dim( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + input_ptr: u64, + output_ptr: u64, + outer_size: usize, + dim_size: usize, + inner_size: usize, +) -> Result<()> { + unsafe { + let module = get_or_load_module(context, device_index, kernel_names::SOFTMAX_MODULE)?; + let func_name = kernel_name("softmax_dim", dtype); + let func = get_kernel_function(&module, &func_name)?; + + let grid = (outer_size as u32, inner_size as u32, 1); + let outer = outer_size as u32; + let dim = dim_size as u32; + let inner = inner_size as u32; + + let cfg = launch_config(grid, (1, 1, 1), 0); + let mut builder = stream.launch_builder(&func); + builder.arg(&input_ptr); + builder.arg(&output_ptr); + builder.arg(&outer); + builder.arg(&dim); + builder.arg(&inner); + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!("CUDA softmax_dim kernel launch failed: {:?}", e)) + })?; + + Ok(()) + } +} + +/// Launch softmax backward kernel (last dimension). +/// +/// Computes: d_input = output * (grad - sum(grad * output)) +/// +/// # Safety +/// - All pointers must be valid device memory of `outer_size * dim_size` elements +pub unsafe fn launch_softmax_bwd( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + grad_ptr: u64, + output_ptr: u64, + d_input_ptr: u64, + outer_size: usize, + dim_size: usize, +) -> Result<()> { + unsafe { + let module = get_or_load_module(context, device_index, kernel_names::SOFTMAX_MODULE)?; + let func_name = kernel_name("softmax_bwd", dtype); + let func = get_kernel_function(&module, &func_name)?; + + let (grid_size, block_size, shared_mem) = softmax_launch_config(outer_size, dim_size); + let outer = outer_size as u32; + let dim = dim_size as u32; + + let shared_mem = if dtype == DType::F64 { + shared_mem * 2 + } else { + shared_mem + }; + + let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), shared_mem); + let mut builder = stream.launch_builder(&func); + builder.arg(&grad_ptr); + builder.arg(&output_ptr); + builder.arg(&d_input_ptr); + builder.arg(&outer); + builder.arg(&dim); + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!("CUDA softmax_bwd kernel launch failed: {:?}", e)) + })?; + + Ok(()) + } +} + +/// Launch softmax backward kernel (non-last dimension). +/// +/// # Safety +/// - All pointers must be valid device memory +pub unsafe fn launch_softmax_bwd_dim( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + grad_ptr: u64, + output_ptr: u64, + d_input_ptr: u64, + outer_size: usize, + dim_size: usize, + inner_size: usize, +) -> Result<()> { + unsafe { + let module = get_or_load_module(context, device_index, kernel_names::SOFTMAX_MODULE)?; + let func_name = kernel_name("softmax_bwd_dim", dtype); + let func = get_kernel_function(&module, &func_name)?; + + let grid = (outer_size as u32, inner_size as u32, 1); + let outer = outer_size as u32; + let dim = dim_size as u32; + let inner = inner_size as u32; + + let cfg = launch_config(grid, (1, 1, 1), 0); + let mut builder = stream.launch_builder(&func); + builder.arg(&grad_ptr); + builder.arg(&output_ptr); + builder.arg(&d_input_ptr); + builder.arg(&outer); + builder.arg(&dim); + builder.arg(&inner); + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!( + "CUDA softmax_bwd_dim kernel launch failed: {:?}", + e + )) + })?; + + Ok(()) + } +} diff --git a/src/runtime/cuda/kernels/activation_deriv.cuh b/src/runtime/cuda/kernels/activation_deriv.cuh new file mode 100644 index 00000000..af3187a0 --- /dev/null +++ b/src/runtime/cuda/kernels/activation_deriv.cuh @@ -0,0 +1,155 @@ +// Shared activation derivative and forward helpers for CUDA backward kernels. +// +// Used by: gemm_epilogue_bwd.cu, fused_activation_mul_bwd.cu +// +// Activation type encoding (for switch-based dispatch): +// 0 = None (identity), 1 = ReLU, 2 = GELU, 3 = SiLU, 4 = Sigmoid, 5 = Tanh +// +// GELU tanh-approximation clamping ranges: +// f32: ±15.0 — tanhf(15) saturates to ±1.0f in float32 precision, and +// expf(30) < FLT_MAX so no overflow in tanh's internal exp(2x). +// f64: ±20.0 — tanh(20) saturates to ±1.0 in float64 precision, and +// exp(40) < DBL_MAX so no overflow. Tighter than ±15 would +// lose valid precision for f64. + +#pragma once + +// ============================================================================ +// Per-activation derivative helpers (scalar, __forceinline__) +// ============================================================================ + +__device__ __forceinline__ float relu_deriv_f32(float x) { + return x > 0.0f ? 1.0f : 0.0f; +} + +__device__ __forceinline__ float sigmoid_fwd_f32(float x) { + return 1.0f / (1.0f + expf(-x)); +} + +__device__ __forceinline__ float sigmoid_deriv_f32(float x) { + float sig = sigmoid_fwd_f32(x); + return sig * (1.0f - sig); +} + +__device__ __forceinline__ float tanh_deriv_f32(float x) { + float t = tanhf(x); + return 1.0f - t * t; +} + +__device__ __forceinline__ float silu_deriv_f32(float x) { + float sig = sigmoid_fwd_f32(x); + return sig + x * sig * (1.0f - sig); +} + +__device__ __forceinline__ float gelu_deriv_f32(float x) { + const float c = 0.7978845608f; // sqrt(2/pi) + const float k = 0.044715f; + float inner = c * (x + k * x * x * x); + // Clamp to prevent exp overflow in tanh (see header comment for range rationale) + inner = fminf(fmaxf(inner, -15.0f), 15.0f); + float t = tanhf(inner); + return 0.5f * (1.0f + t) + 0.5f * x * (1.0f - t * t) * c * (1.0f + 3.0f * k * x * x); +} + +// Switch-based dispatcher for f32 +__device__ __forceinline__ float activation_deriv_f32(float x, unsigned int act_type) { + switch (act_type) { + case 0: return 1.0f; + case 1: return relu_deriv_f32(x); + case 2: return gelu_deriv_f32(x); + case 3: return silu_deriv_f32(x); + case 4: return sigmoid_deriv_f32(x); + case 5: return tanh_deriv_f32(x); + default: return 1.0f; + } +} + +// ============================================================================ +// F64 variants +// ============================================================================ + +__device__ __forceinline__ double relu_deriv_f64(double x) { + return x > 0.0 ? 1.0 : 0.0; +} + +__device__ __forceinline__ double sigmoid_fwd_f64(double x) { + return 1.0 / (1.0 + exp(-x)); +} + +__device__ __forceinline__ double sigmoid_deriv_f64(double x) { + double sig = sigmoid_fwd_f64(x); + return sig * (1.0 - sig); +} + +__device__ __forceinline__ double tanh_deriv_f64(double x) { + double t = tanh(x); + return 1.0 - t * t; +} + +__device__ __forceinline__ double silu_deriv_f64(double x) { + double sig = sigmoid_fwd_f64(x); + return sig + x * sig * (1.0 - sig); +} + +__device__ __forceinline__ double gelu_deriv_f64(double x) { + const double c = 0.7978845608028654; // sqrt(2/pi) + const double k = 0.044715; + double inner = c * (x + k * x * x * x); + // Clamp to prevent exp overflow in tanh (see header comment for range rationale) + inner = fmin(fmax(inner, -20.0), 20.0); + double t = tanh(inner); + return 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * c * (1.0 + 3.0 * k * x * x); +} + +// Switch-based dispatcher for f64 +__device__ __forceinline__ double activation_deriv_f64(double x, unsigned int act_type) { + switch (act_type) { + case 0: return 1.0; + case 1: return relu_deriv_f64(x); + case 2: return gelu_deriv_f64(x); + case 3: return silu_deriv_f64(x); + case 4: return sigmoid_deriv_f64(x); + case 5: return tanh_deriv_f64(x); + default: return 1.0; + } +} + +// ============================================================================ +// Forward value helpers (used by fused activation-mul backward) +// ============================================================================ + +__device__ __forceinline__ float relu_fwd_f32(float x) { + return fmaxf(0.0f, x); +} + +__device__ __forceinline__ float silu_fwd_f32(float x) { + return x * sigmoid_fwd_f32(x); +} + +__device__ __forceinline__ float gelu_fwd_f32(float x) { + const float c = 0.7978845608f; + const float k = 0.044715f; + float inner = c * (x + k * x * x * x); + inner = fminf(fmaxf(inner, -15.0f), 15.0f); + return 0.5f * x * (1.0f + tanhf(inner)); +} + +__device__ __forceinline__ double relu_fwd_f64(double x) { + return fmax(0.0, x); +} + +__device__ __forceinline__ double sigmoid_fwd_f64_val(double x) { + return sigmoid_fwd_f64(x); +} + +__device__ __forceinline__ double silu_fwd_f64(double x) { + return x * sigmoid_fwd_f64(x); +} + +__device__ __forceinline__ double gelu_fwd_f64(double x) { + const double c = 0.7978845608028654; + const double k = 0.044715; + double inner = c * (x + k * x * x * x); + inner = fmin(fmax(inner, -20.0), 20.0); + return 0.5 * x * (1.0 + tanh(inner)); +} diff --git a/src/runtime/cuda/kernels/binary.rs b/src/runtime/cuda/kernels/binary.rs index 21f4a6a9..71407c12 100644 --- a/src/runtime/cuda/kernels/binary.rs +++ b/src/runtime/cuda/kernels/binary.rs @@ -344,10 +344,10 @@ pub unsafe fn launch_broadcast_binary_op( let shape_tensor = Tensor::::from_slice(&shape_u32, &[ndim], device); // Get device pointers - let a_strides_ptr = a_strides_tensor.storage().ptr(); - let b_strides_ptr = b_strides_tensor.storage().ptr(); - let out_strides_ptr = out_strides_tensor.storage().ptr(); - let shape_ptr = shape_tensor.storage().ptr(); + let a_strides_ptr = a_strides_tensor.ptr(); + let b_strides_ptr = b_strides_tensor.ptr(); + let out_strides_ptr = out_strides_tensor.ptr(); + let shape_ptr = shape_tensor.ptr(); // Get kernel function let module = get_or_load_module(context, device_index, kernel_names::BINARY_MODULE)?; @@ -386,10 +386,8 @@ pub unsafe fn launch_broadcast_binary_op( })?; } - // Synchronize to ensure the kernel completes before freeing temporary allocations - stream - .synchronize() - .map_err(|e| Error::Internal(format!("Stream sync failed: {:?}", e)))?; + // No sync needed: temporary GPU allocations (strides, shape tensors) are freed via + // cuMemFreeAsync which is stream-ordered — the free happens after the kernel completes. Ok(()) } diff --git a/src/runtime/cuda/kernels/compare.rs b/src/runtime/cuda/kernels/compare.rs index baa31a87..40e97cc4 100644 --- a/src/runtime/cuda/kernels/compare.rs +++ b/src/runtime/cuda/kernels/compare.rs @@ -146,9 +146,9 @@ pub unsafe fn launch_broadcast_compare_op( let shape_tensor = Tensor::::from_slice(&shape_u32, &[ndim], device); // Get device pointers - let a_strides_ptr = a_strides_tensor.storage().ptr(); - let b_strides_ptr = b_strides_tensor.storage().ptr(); - let shape_ptr = shape_tensor.storage().ptr(); + let a_strides_ptr = a_strides_tensor.ptr(); + let b_strides_ptr = b_strides_tensor.ptr(); + let shape_ptr = shape_tensor.ptr(); // Get kernel function let module = get_or_load_module(context, device_index, kernel_names::COMPARE_MODULE)?; @@ -186,10 +186,7 @@ pub unsafe fn launch_broadcast_compare_op( })?; } - // Synchronize to ensure the kernel completes before freeing temporary allocations - stream - .synchronize() - .map_err(|e| Error::Internal(format!("Stream sync failed: {:?}", e)))?; + // No sync needed: temporary GPU allocations freed via cuMemFreeAsync (stream-ordered). Ok(()) } diff --git a/src/runtime/cuda/kernels/distance.cu b/src/runtime/cuda/kernels/distance.cu index d4fa7a54..d499f911 100644 --- a/src/runtime/cuda/kernels/distance.cu +++ b/src/runtime/cuda/kernels/distance.cu @@ -2,184 +2,230 @@ // // Provides efficient pairwise distance computation for various metrics. // All kernels support F32, F64, F16, and BF16 data types. +// +// Precision: F32/F64 accumulate in native precision. +// F16/BF16 accumulate in F32 for accuracy. #include #include #include +#include "dtype_traits.cuh" // ============================================================================ -// Type Conversion Helpers +// Accumulation Type Traits // ============================================================================ -template -__device__ __forceinline__ float to_float(T val) { - return static_cast(val); +// AccT: the type used for accumulation and intermediate computation. +// F32 -> float, F64 -> double, F16/BF16 -> float (compute in F32 for accuracy) +template struct AccType { using type = T; }; +template<> struct AccType<__half> { using type = float; }; +template<> struct AccType<__nv_bfloat16> { using type = float; }; +template<> struct AccType { using type = float; }; +template<> struct AccType { using type = float; }; + +// ============================================================================ +// Type Conversion Helpers (to/from AccT) +// ============================================================================ + +template +__device__ __forceinline__ AccT to_acc(T val) { + return static_cast(val); } template<> -__device__ __forceinline__ float to_float<__half>(__half val) { +__device__ __forceinline__ float to_acc(__half val) { return __half2float(val); } template<> -__device__ __forceinline__ float to_float<__nv_bfloat16>(__nv_bfloat16 val) { +__device__ __forceinline__ float to_acc(__nv_bfloat16 val) { return __bfloat162float(val); } -template -__device__ __forceinline__ T from_float(float val) { +template +__device__ __forceinline__ T from_acc(AccT val) { return static_cast(val); } template<> -__device__ __forceinline__ __half from_float<__half>(float val) { +__device__ __forceinline__ __half from_acc<__half, float>(float val) { return __float2half(val); } template<> -__device__ __forceinline__ __nv_bfloat16 from_float<__nv_bfloat16>(float val) { +__device__ __forceinline__ __nv_bfloat16 from_acc<__nv_bfloat16, float>(float val) { return __float2bfloat16(val); } +template<> +__device__ __forceinline__ float to_acc(numr_fp8_e4m3 val) { + return fp8_e4m3_to_f32(val.data); +} + +template<> +__device__ __forceinline__ float to_acc(numr_fp8_e5m2 val) { + return fp8_e5m2_to_f32(val.data); +} + +template<> +__device__ __forceinline__ numr_fp8_e4m3 from_acc(float val) { + return numr_fp8_e4m3(f32_to_fp8_e4m3(val)); +} + +template<> +__device__ __forceinline__ numr_fp8_e5m2 from_acc(float val) { + return numr_fp8_e5m2(f32_to_fp8_e5m2(val)); +} + // ============================================================================ -// Distance Metric Implementations +// Math helpers — dispatch sqrt/fabs/pow to correct precision // ============================================================================ -// Squared Euclidean distance between two vectors -template -__device__ float sqeuclidean_dist(const T* a, const T* b, unsigned int d) { - float sum = 0.0f; +__device__ __forceinline__ float acc_sqrt(float x) { return sqrtf(x); } +__device__ __forceinline__ double acc_sqrt(double x) { return sqrt(x); } + +__device__ __forceinline__ float acc_fabs(float x) { return fabsf(x); } +__device__ __forceinline__ double acc_fabs(double x) { return fabs(x); } + +__device__ __forceinline__ float acc_pow(float x, float y) { return powf(x, y); } +__device__ __forceinline__ double acc_pow(double x, double y) { return pow(x, y); } + +// ============================================================================ +// Distance Metric Implementations (templated on T and AccT) +// ============================================================================ + +// Squared Euclidean distance +template +__device__ AccT sqeuclidean_dist(const T* a, const T* b, unsigned int d) { + AccT sum = AccT(0); for (unsigned int k = 0; k < d; k++) { - float diff = to_float(a[k]) - to_float(b[k]); + AccT diff = to_acc(a[k]) - to_acc(b[k]); sum += diff * diff; } return sum; } // Euclidean (L2) distance -template -__device__ float euclidean_dist(const T* a, const T* b, unsigned int d) { - return sqrtf(sqeuclidean_dist(a, b, d)); +template +__device__ AccT euclidean_dist(const T* a, const T* b, unsigned int d) { + return acc_sqrt(sqeuclidean_dist(a, b, d)); } // Manhattan (L1) distance -template -__device__ float manhattan_dist(const T* a, const T* b, unsigned int d) { - float sum = 0.0f; +template +__device__ AccT manhattan_dist(const T* a, const T* b, unsigned int d) { + AccT sum = AccT(0); for (unsigned int k = 0; k < d; k++) { - sum += fabsf(to_float(a[k]) - to_float(b[k])); + sum += acc_fabs(to_acc(a[k]) - to_acc(b[k])); } return sum; } // Chebyshev (L-infinity) distance -template -__device__ float chebyshev_dist(const T* a, const T* b, unsigned int d) { - float max_val = 0.0f; +template +__device__ AccT chebyshev_dist(const T* a, const T* b, unsigned int d) { + AccT max_val = AccT(0); for (unsigned int k = 0; k < d; k++) { - float abs_diff = fabsf(to_float(a[k]) - to_float(b[k])); + AccT abs_diff = acc_fabs(to_acc(a[k]) - to_acc(b[k])); if (abs_diff > max_val) max_val = abs_diff; } return max_val; } // Minkowski (Lp) distance -template -__device__ float minkowski_dist(const T* a, const T* b, unsigned int d, float p) { - float sum = 0.0f; +template +__device__ AccT minkowski_dist(const T* a, const T* b, unsigned int d, AccT p) { + AccT sum = AccT(0); for (unsigned int k = 0; k < d; k++) { - sum += powf(fabsf(to_float(a[k]) - to_float(b[k])), p); + sum += acc_pow(acc_fabs(to_acc(a[k]) - to_acc(b[k])), p); } - return powf(sum, 1.0f / p); + return acc_pow(sum, AccT(1) / p); } // Cosine distance: 1 - cos(theta) -template -__device__ float cosine_dist(const T* a, const T* b, unsigned int d) { - float dot = 0.0f; - float norm_a = 0.0f; - float norm_b = 0.0f; +template +__device__ AccT cosine_dist(const T* a, const T* b, unsigned int d) { + AccT dot = AccT(0); + AccT norm_a = AccT(0); + AccT norm_b = AccT(0); for (unsigned int k = 0; k < d; k++) { - float ak = to_float(a[k]); - float bk = to_float(b[k]); + AccT ak = to_acc(a[k]); + AccT bk = to_acc(b[k]); dot += ak * bk; norm_a += ak * ak; norm_b += bk * bk; } - float denom = sqrtf(norm_a * norm_b); - if (denom == 0.0f) return 0.0f; - return 1.0f - dot / denom; + AccT denom = acc_sqrt(norm_a * norm_b); + if (denom == AccT(0)) return AccT(0); + return AccT(1) - dot / denom; } // Correlation distance: 1 - Pearson r -template -__device__ float correlation_dist(const T* a, const T* b, unsigned int d) { - // Compute means - float sum_a = 0.0f; - float sum_b = 0.0f; +template +__device__ AccT correlation_dist(const T* a, const T* b, unsigned int d) { + AccT sum_a = AccT(0); + AccT sum_b = AccT(0); for (unsigned int k = 0; k < d; k++) { - sum_a += to_float(a[k]); - sum_b += to_float(b[k]); + sum_a += to_acc(a[k]); + sum_b += to_acc(b[k]); } - float mean_a = sum_a / d; - float mean_b = sum_b / d; + AccT mean_a = sum_a / AccT(d); + AccT mean_b = sum_b / AccT(d); - // Compute correlation - float cov = 0.0f; - float var_a = 0.0f; - float var_b = 0.0f; + AccT cov = AccT(0); + AccT var_a = AccT(0); + AccT var_b = AccT(0); for (unsigned int k = 0; k < d; k++) { - float da = to_float(a[k]) - mean_a; - float db = to_float(b[k]) - mean_b; + AccT da = to_acc(a[k]) - mean_a; + AccT db = to_acc(b[k]) - mean_b; cov += da * db; var_a += da * da; var_b += db * db; } - float denom = sqrtf(var_a * var_b); - if (denom == 0.0f) return 0.0f; - return 1.0f - cov / denom; + AccT denom = acc_sqrt(var_a * var_b); + if (denom == AccT(0)) return AccT(0); + return AccT(1) - cov / denom; } // Hamming distance: fraction of differing elements -template -__device__ float hamming_dist(const T* a, const T* b, unsigned int d) { - float count = 0.0f; +template +__device__ AccT hamming_dist(const T* a, const T* b, unsigned int d) { + AccT count = AccT(0); for (unsigned int k = 0; k < d; k++) { - if (to_float(a[k]) != to_float(b[k])) { - count += 1.0f; + if (to_acc(a[k]) != to_acc(b[k])) { + count += AccT(1); } } - return count / d; + return count / AccT(d); } // Jaccard distance: 1 - |intersection|/|union| for binary vectors -template -__device__ float jaccard_dist(const T* a, const T* b, unsigned int d) { - float intersection = 0.0f; - float union_count = 0.0f; +template +__device__ AccT jaccard_dist(const T* a, const T* b, unsigned int d) { + AccT intersection = AccT(0); + AccT union_count = AccT(0); for (unsigned int k = 0; k < d; k++) { - float ak = to_float(a[k]); - float bk = to_float(b[k]); - bool a_nonzero = (ak != 0.0f); - bool b_nonzero = (bk != 0.0f); + AccT ak = to_acc(a[k]); + AccT bk = to_acc(b[k]); + bool a_nonzero = (ak != AccT(0)); + bool b_nonzero = (bk != AccT(0)); - if (a_nonzero && b_nonzero) intersection += 1.0f; - if (a_nonzero || b_nonzero) union_count += 1.0f; + if (a_nonzero && b_nonzero) intersection += AccT(1); + if (a_nonzero || b_nonzero) union_count += AccT(1); } - if (union_count == 0.0f) return 0.0f; - return 1.0f - intersection / union_count; + if (union_count == AccT(0)) return AccT(0); + return AccT(1) - intersection / union_count; } // ============================================================================ // Metric Dispatch // ============================================================================ -// Distance metric enum values (must match Rust DistanceMetric) #define METRIC_EUCLIDEAN 0 #define METRIC_SQEUCLIDEAN 1 #define METRIC_MANHATTAN 2 @@ -190,28 +236,28 @@ __device__ float jaccard_dist(const T* a, const T* b, unsigned int d) { #define METRIC_HAMMING 7 #define METRIC_JACCARD 8 -template -__device__ float compute_distance(const T* a, const T* b, unsigned int d, - unsigned int metric, float p) { +template +__device__ AccT compute_distance(const T* a, const T* b, unsigned int d, + unsigned int metric, AccT p) { switch (metric) { - case METRIC_EUCLIDEAN: return euclidean_dist(a, b, d); - case METRIC_SQEUCLIDEAN: return sqeuclidean_dist(a, b, d); - case METRIC_MANHATTAN: return manhattan_dist(a, b, d); - case METRIC_CHEBYSHEV: return chebyshev_dist(a, b, d); - case METRIC_MINKOWSKI: return minkowski_dist(a, b, d, p); - case METRIC_COSINE: return cosine_dist(a, b, d); - case METRIC_CORRELATION: return correlation_dist(a, b, d); - case METRIC_HAMMING: return hamming_dist(a, b, d); - case METRIC_JACCARD: return jaccard_dist(a, b, d); - default: return 0.0f; + case METRIC_EUCLIDEAN: return euclidean_dist(a, b, d); + case METRIC_SQEUCLIDEAN: return sqeuclidean_dist(a, b, d); + case METRIC_MANHATTAN: return manhattan_dist(a, b, d); + case METRIC_CHEBYSHEV: return chebyshev_dist(a, b, d); + case METRIC_MINKOWSKI: return minkowski_dist(a, b, d, p); + case METRIC_COSINE: return cosine_dist(a, b, d); + case METRIC_CORRELATION: return correlation_dist(a, b, d); + case METRIC_HAMMING: return hamming_dist(a, b, d); + case METRIC_JACCARD: return jaccard_dist(a, b, d); + default: return AccT(0); } } // ============================================================================ -// CDIST Device Function - Pairwise distances between two sets +// CDIST Kernel - Pairwise distances between two sets // ============================================================================ -template +template __device__ void cdist_kernel_impl( const T* __restrict__ x, // (n, d) const T* __restrict__ y, // (m, d) @@ -220,48 +266,40 @@ __device__ void cdist_kernel_impl( unsigned int m, unsigned int d, unsigned int metric, - float p + AccT p ) { - // Each thread computes one distance unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; unsigned int total = n * m; if (idx < total) { - unsigned int i = idx / m; // Row in output (index into x) - unsigned int j = idx % m; // Col in output (index into y) + unsigned int i = idx / m; + unsigned int j = idx % m; const T* x_row = x + i * d; const T* y_row = y + j * d; - float dist = compute_distance(x_row, y_row, d, metric, p); - out[idx] = from_float(dist); + AccT dist = compute_distance(x_row, y_row, d, metric, p); + out[idx] = from_acc(dist); } } // ============================================================================ -// PDIST Device Function - Pairwise distances within one set (condensed) +// PDIST Kernel - Pairwise distances within one set (condensed) // ============================================================================ -template +template __device__ void pdist_kernel_impl( const T* __restrict__ x, // (n, d) T* __restrict__ out, // (n*(n-1)/2,) unsigned int n, unsigned int d, unsigned int metric, - float p + AccT p ) { - // Each thread computes one distance from condensed index unsigned int k = blockIdx.x * blockDim.x + threadIdx.x; unsigned int total = n * (n - 1) / 2; if (k < total) { - // Convert condensed index k to (i, j) where i < j - // Using formula: k = n*i - i*(i+1)/2 + j - i - 1 - // Inverse: i = n - 2 - floor(sqrt(-8k + 4n*(n-1) - 7) / 2 - 0.5) - // j = k + i + 1 - n*(n-1)/2 + (n-i)*((n-i)-1)/2 - - // Simpler approach: iterate to find i, j unsigned int i = 0; unsigned int j_start = 1; unsigned int count = 0; @@ -280,19 +318,19 @@ __device__ void pdist_kernel_impl( const T* x_i = x + i * d; const T* x_j = x + j * d; - float dist = compute_distance(x_i, x_j, d, metric, p); - out[k] = from_float(dist); + AccT dist = compute_distance(x_i, x_j, d, metric, p); + out[k] = from_acc(dist); } } // ============================================================================ -// Squareform Device Function - Condensed to square +// Squareform Kernel - Condensed to square // ============================================================================ -template +template __device__ void squareform_kernel_impl( - const T* __restrict__ condensed, // (n*(n-1)/2,) - T* __restrict__ square, // (n, n) + const T* __restrict__ condensed, + T* __restrict__ square, unsigned int n ) { unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; @@ -303,14 +341,11 @@ __device__ void squareform_kernel_impl( unsigned int j = idx % n; if (i == j) { - // Diagonal is zero - square[idx] = from_float(0.0f); + square[idx] = from_acc(AccT(0)); } else if (i < j) { - // Upper triangle: k = n*i - i*(i+1)/2 + j - i - 1 unsigned int k = n * i - i * (i + 1) / 2 + j - i - 1; square[idx] = condensed[k]; } else { - // Lower triangle: mirror from upper unsigned int k = n * j - j * (j + 1) / 2 + i - j - 1; square[idx] = condensed[k]; } @@ -318,20 +353,19 @@ __device__ void squareform_kernel_impl( } // ============================================================================ -// Squareform Inverse Device Function - Square to condensed +// Squareform Inverse Kernel - Square to condensed // ============================================================================ template __device__ void squareform_inverse_kernel_impl( - const T* __restrict__ square, // (n, n) - T* __restrict__ condensed, // (n*(n-1)/2,) + const T* __restrict__ square, + T* __restrict__ condensed, unsigned int n ) { unsigned int k = blockIdx.x * blockDim.x + threadIdx.x; unsigned int total = n * (n - 1) / 2; if (k < total) { - // Convert k to (i, j) where i < j unsigned int i = 0; unsigned int count = 0; @@ -352,29 +386,35 @@ __device__ void squareform_inverse_kernel_impl( // Kernel Instantiations // ============================================================================ -#define INSTANTIATE_DISTANCE_KERNELS(T, suffix) \ +// F32: accumulate in float +// F64: accumulate in double +// F16/BF16: accumulate in float + +#define INSTANTIATE_DISTANCE_KERNELS(T, AccT, suffix) \ extern "C" __global__ void cdist_##suffix( \ const T* x, const T* y, T* out, \ unsigned int n, unsigned int m, unsigned int d, \ - unsigned int metric, float p) { \ - cdist_kernel_impl(x, y, out, n, m, d, metric, p); \ + unsigned int metric, AccT p) { \ + cdist_kernel_impl(x, y, out, n, m, d, metric, p); \ } \ extern "C" __global__ void pdist_##suffix( \ const T* x, T* out, \ unsigned int n, unsigned int d, \ - unsigned int metric, float p) { \ - pdist_kernel_impl(x, out, n, d, metric, p); \ + unsigned int metric, AccT p) { \ + pdist_kernel_impl(x, out, n, d, metric, p); \ } \ extern "C" __global__ void squareform_##suffix( \ const T* condensed, T* square, unsigned int n) { \ - squareform_kernel_impl(condensed, square, n); \ + squareform_kernel_impl(condensed, square, n); \ } \ extern "C" __global__ void squareform_inverse_##suffix( \ const T* square, T* condensed, unsigned int n) { \ squareform_inverse_kernel_impl(square, condensed, n); \ } -INSTANTIATE_DISTANCE_KERNELS(float, f32) -INSTANTIATE_DISTANCE_KERNELS(double, f64) -INSTANTIATE_DISTANCE_KERNELS(__half, f16) -INSTANTIATE_DISTANCE_KERNELS(__nv_bfloat16, bf16) +INSTANTIATE_DISTANCE_KERNELS(float, float, f32) +INSTANTIATE_DISTANCE_KERNELS(double, double, f64) +INSTANTIATE_DISTANCE_KERNELS(__half, float, f16) +INSTANTIATE_DISTANCE_KERNELS(__nv_bfloat16, float, bf16) +INSTANTIATE_DISTANCE_KERNELS(numr_fp8_e4m3, float, fp8_e4m3) +INSTANTIATE_DISTANCE_KERNELS(numr_fp8_e5m2, float, fp8_e5m2) diff --git a/src/runtime/cuda/kernels/distance.rs b/src/runtime/cuda/kernels/distance.rs index f8ad03f5..ad4e444f 100644 --- a/src/runtime/cuda/kernels/distance.rs +++ b/src/runtime/cuda/kernels/distance.rs @@ -32,14 +32,22 @@ fn metric_to_index(metric: DistanceMetric) -> u32 { } } -/// Get Minkowski p value from metric -fn metric_p_value(metric: DistanceMetric) -> f32 { +/// Get Minkowski p value from metric as f32 (for F32/F16/BF16 kernels) +fn metric_p_value_f32(metric: DistanceMetric) -> f32 { match metric { DistanceMetric::Minkowski(p) => p as f32, _ => 2.0, // Default (not used for non-Minkowski) } } +/// Get Minkowski p value from metric as f64 (for F64 kernel) +fn metric_p_value_f64(metric: DistanceMetric) -> f64 { + match metric { + DistanceMetric::Minkowski(p) => p, + _ => 2.0, // Default (not used for non-Minkowski) + } +} + /// Launch cdist kernel - pairwise distances between two point sets. /// /// # Safety @@ -85,10 +93,11 @@ pub unsafe fn launch_cdist( let cfg = launch_config(grid, block, 0); let metric_idx = metric_to_index(metric); - let p_value = metric_p_value(metric); let n_u32 = n as u32; let m_u32 = m as u32; let d_u32 = d as u32; + let p_f32 = metric_p_value_f32(metric); + let p_f64 = metric_p_value_f64(metric); let mut builder = stream.launch_builder(&func); builder.arg(&x_ptr); @@ -98,7 +107,13 @@ pub unsafe fn launch_cdist( builder.arg(&m_u32); builder.arg(&d_u32); builder.arg(&metric_idx); - builder.arg(&p_value); + + // AccT is f64 for F64 dtype, f32 for all others + if dtype == DType::F64 { + builder.arg(&p_f64); + } else { + builder.arg(&p_f32); + } builder .launch(cfg) @@ -149,9 +164,10 @@ pub unsafe fn launch_pdist( let cfg = launch_config(grid, block, 0); let metric_idx = metric_to_index(metric); - let p_value = metric_p_value(metric); let n_u32 = n as u32; let d_u32 = d as u32; + let p_f32 = metric_p_value_f32(metric); + let p_f64 = metric_p_value_f64(metric); let mut builder = stream.launch_builder(&func); builder.arg(&x_ptr); @@ -159,7 +175,12 @@ pub unsafe fn launch_pdist( builder.arg(&n_u32); builder.arg(&d_u32); builder.arg(&metric_idx); - builder.arg(&p_value); + + if dtype == DType::F64 { + builder.arg(&p_f64); + } else { + builder.arg(&p_f32); + } builder .launch(cfg) diff --git a/src/runtime/cuda/kernels/fft.rs b/src/runtime/cuda/kernels/fft.rs index 3d831a89..e9f828eb 100644 --- a/src/runtime/cuda/kernels/fft.rs +++ b/src/runtime/cuda/kernels/fft.rs @@ -209,6 +209,11 @@ pub unsafe fn launch_stockham_fft_stage( } /// Launch scale kernel for complex data +/// +/// # Safety +/// +/// Caller must ensure all raw pointer arguments (`*_ptr`) point to valid GPU memory +/// allocated on `device_index` with sufficient size for the operation. pub unsafe fn launch_scale_complex( context: &Arc, stream: &CudaStream, @@ -270,6 +275,11 @@ pub unsafe fn launch_scale_complex( } /// Launch rfft pack kernel (real -> complex with zero imaginary) +/// +/// # Safety +/// +/// Caller must ensure all raw pointer arguments (`*_ptr`) point to valid GPU memory +/// allocated on `device_index` with sufficient size for the operation. pub unsafe fn launch_rfft_pack( context: &Arc, stream: &CudaStream, @@ -336,6 +346,11 @@ pub unsafe fn launch_rfft_pack( } /// Launch irfft unpack kernel (complex -> real, extracting real parts) +/// +/// # Safety +/// +/// Caller must ensure all raw pointer arguments (`*_ptr`) point to valid GPU memory +/// allocated on `device_index` with sufficient size for the operation. pub unsafe fn launch_irfft_unpack( context: &Arc, stream: &CudaStream, @@ -406,6 +421,11 @@ pub unsafe fn launch_irfft_unpack( } /// Launch Hermitian extension kernel (N/2+1 complex -> N complex) +/// +/// # Safety +/// +/// Caller must ensure all raw pointer arguments (`*_ptr`) point to valid GPU memory +/// allocated on `device_index` with sufficient size for the operation. pub unsafe fn launch_hermitian_extend( context: &Arc, stream: &CudaStream, @@ -483,6 +503,11 @@ pub unsafe fn launch_hermitian_extend( } /// Launch rfft truncation kernel (N complex -> N/2+1 complex) +/// +/// # Safety +/// +/// Caller must ensure all raw pointer arguments (`*_ptr`) point to valid GPU memory +/// allocated on `device_index` with sufficient size for the operation. pub unsafe fn launch_rfft_truncate( context: &Arc, stream: &CudaStream, @@ -554,6 +579,11 @@ pub unsafe fn launch_rfft_truncate( } /// Launch fftshift kernel +/// +/// # Safety +/// +/// Caller must ensure all raw pointer arguments (`*_ptr`) point to valid GPU memory +/// allocated on `device_index` with sufficient size for the operation. pub unsafe fn launch_fftshift( context: &Arc, stream: &CudaStream, @@ -620,6 +650,11 @@ pub unsafe fn launch_fftshift( } /// Launch ifftshift kernel +/// +/// # Safety +/// +/// Caller must ensure all raw pointer arguments (`*_ptr`) point to valid GPU memory +/// allocated on `device_index` with sufficient size for the operation. pub unsafe fn launch_ifftshift( context: &Arc, stream: &CudaStream, @@ -686,6 +721,11 @@ pub unsafe fn launch_ifftshift( } /// Launch copy kernel for complex data +/// +/// # Safety +/// +/// Caller must ensure all raw pointer arguments (`*_ptr`) point to valid GPU memory +/// allocated on `device_index` with sufficient size for the operation. #[allow(dead_code)] pub unsafe fn launch_copy_complex( context: &Arc, diff --git a/src/runtime/cuda/kernels/fp8_matmul.cu b/src/runtime/cuda/kernels/fp8_matmul.cu new file mode 100644 index 00000000..ca8bdfb0 --- /dev/null +++ b/src/runtime/cuda/kernels/fp8_matmul.cu @@ -0,0 +1,539 @@ +// FP8 Matrix Multiplication CUDA Kernels +// +// Computes: C = scale_a * scale_b * (A_fp8 @ B_fp8) +// where A,B are FP8 tensors, accumulation is in FP32, output is F32/F16/BF16. +// +// Variants: +// - E4M3 x E4M3 -> F32/F16/BF16 (forward pass) +// - E5M2 x E4M3 -> F32/F16/BF16 (backward pass: gradients x weights) +// - Batched versions of both +// +// Algorithm: tiled GEMM with shared memory (F32 accumulation), FP8 loads via conversion. + +#include "dtype_traits.cuh" + +// Tile sizes for FP8 GEMM +// FP8 elements are 1 byte, so we can fit more in shared memory +#define FP8_TILE_M 64 +#define FP8_TILE_N 64 +#define FP8_TILE_K 32 +#define FP8_THREAD_M 4 +#define FP8_THREAD_N 4 + +// ============================================================================ +// Helper: store result with dtype conversion and scaling +// ============================================================================ + +__device__ __forceinline__ void store_f32(float* out, unsigned int idx, float val) { + out[idx] = val; +} + +__device__ __forceinline__ void store_f16(__half* out, unsigned int idx, float val) { + out[idx] = __float2half(val); +} + +__device__ __forceinline__ void store_bf16(__nv_bfloat16* out, unsigned int idx, float val) { + out[idx] = __float2bfloat16(val); +} + +// ============================================================================ +// FP8 E4M3 x E4M3 -> output dtype (tiled GEMM with F32 accumulation) +// ============================================================================ + +template +__device__ void fp8_matmul_e4m3_kernel( + const numr_fp8_e4m3* __restrict__ A, + const numr_fp8_e4m3* __restrict__ B, + OutT* __restrict__ C, + float scale_a, + float scale_b, + unsigned int M, + unsigned int N, + unsigned int K +) { + // Shared memory for tiles (store as f32 after conversion) + __shared__ float As[FP8_TILE_M][FP8_TILE_K]; + __shared__ float Bs[FP8_TILE_K][FP8_TILE_N]; + + const unsigned int tx = threadIdx.x; + const unsigned int ty = threadIdx.y; + const unsigned int threads_x = FP8_TILE_N / FP8_THREAD_N; + const unsigned int threads_y = FP8_TILE_M / FP8_THREAD_M; + + const unsigned int block_row = blockIdx.y * FP8_TILE_M; + const unsigned int block_col = blockIdx.x * FP8_TILE_N; + const unsigned int thread_row = ty * FP8_THREAD_M; + const unsigned int thread_col = tx * FP8_THREAD_N; + + // Register accumulators (F32) + float reg_c[FP8_THREAD_M][FP8_THREAD_N]; + #pragma unroll + for (int i = 0; i < FP8_THREAD_M; i++) { + #pragma unroll + for (int j = 0; j < FP8_THREAD_N; j++) { + reg_c[i][j] = 0.0f; + } + } + + const unsigned int thread_id = ty * threads_x + tx; + const unsigned int num_threads = threads_x * threads_y; + const unsigned int num_k_tiles = (K + FP8_TILE_K - 1) / FP8_TILE_K; + const float combined_scale = scale_a * scale_b; + + for (unsigned int bk = 0; bk < num_k_tiles; bk++) { + const unsigned int k_offset = bk * FP8_TILE_K; + + // Cooperative load A tile, convert FP8 -> F32 + unsigned int a_elems = FP8_TILE_M * FP8_TILE_K; + for (unsigned int idx = thread_id; idx < a_elems; idx += num_threads) { + unsigned int r = idx / FP8_TILE_K; + unsigned int c = idx % FP8_TILE_K; + unsigned int gr = block_row + r; + unsigned int gc = k_offset + c; + if (gr < M && gc < K) { + As[r][c] = fp8_e4m3_to_f32(A[gr * K + gc].data); + } else { + As[r][c] = 0.0f; + } + } + + // Cooperative load B tile, convert FP8 -> F32 + unsigned int b_elems = FP8_TILE_K * FP8_TILE_N; + for (unsigned int idx = thread_id; idx < b_elems; idx += num_threads) { + unsigned int r = idx / FP8_TILE_N; + unsigned int c = idx % FP8_TILE_N; + unsigned int gr = k_offset + r; + unsigned int gc = block_col + c; + if (gr < K && gc < N) { + Bs[r][c] = fp8_e4m3_to_f32(B[gr * N + gc].data); + } else { + Bs[r][c] = 0.0f; + } + } + + __syncthreads(); + + // Compute partial products + #pragma unroll + for (unsigned int kk = 0; kk < FP8_TILE_K; kk++) { + float reg_a[FP8_THREAD_M]; + float reg_b[FP8_THREAD_N]; + + #pragma unroll + for (int i = 0; i < FP8_THREAD_M; i++) { + reg_a[i] = As[thread_row + i][kk]; + } + #pragma unroll + for (int j = 0; j < FP8_THREAD_N; j++) { + reg_b[j] = Bs[kk][thread_col + j]; + } + + #pragma unroll + for (int i = 0; i < FP8_THREAD_M; i++) { + #pragma unroll + for (int j = 0; j < FP8_THREAD_N; j++) { + reg_c[i][j] += reg_a[i] * reg_b[j]; + } + } + } + + __syncthreads(); + } + + // Write output with scaling and dtype conversion + #pragma unroll + for (int i = 0; i < FP8_THREAD_M; i++) { + #pragma unroll + for (int j = 0; j < FP8_THREAD_N; j++) { + unsigned int gr = block_row + thread_row + i; + unsigned int gc = block_col + thread_col + j; + if (gr < M && gc < N) { + store_fn(C, gr * N + gc, reg_c[i][j] * combined_scale); + } + } + } +} + +// ============================================================================ +// FP8 E5M2 x E4M3 -> output dtype (backward pass) +// ============================================================================ + +template +__device__ void fp8_matmul_e5m2_kernel( + const numr_fp8_e5m2* __restrict__ A, + const numr_fp8_e4m3* __restrict__ B, + OutT* __restrict__ C, + float scale_a, + float scale_b, + unsigned int M, + unsigned int N, + unsigned int K +) { + __shared__ float As[FP8_TILE_M][FP8_TILE_K]; + __shared__ float Bs[FP8_TILE_K][FP8_TILE_N]; + + const unsigned int tx = threadIdx.x; + const unsigned int ty = threadIdx.y; + const unsigned int threads_x = FP8_TILE_N / FP8_THREAD_N; + const unsigned int threads_y = FP8_TILE_M / FP8_THREAD_M; + + const unsigned int block_row = blockIdx.y * FP8_TILE_M; + const unsigned int block_col = blockIdx.x * FP8_TILE_N; + const unsigned int thread_row = ty * FP8_THREAD_M; + const unsigned int thread_col = tx * FP8_THREAD_N; + + float reg_c[FP8_THREAD_M][FP8_THREAD_N]; + #pragma unroll + for (int i = 0; i < FP8_THREAD_M; i++) { + #pragma unroll + for (int j = 0; j < FP8_THREAD_N; j++) { + reg_c[i][j] = 0.0f; + } + } + + const unsigned int thread_id = ty * threads_x + tx; + const unsigned int num_threads = threads_x * threads_y; + const unsigned int num_k_tiles = (K + FP8_TILE_K - 1) / FP8_TILE_K; + const float combined_scale = scale_a * scale_b; + + for (unsigned int bk = 0; bk < num_k_tiles; bk++) { + const unsigned int k_offset = bk * FP8_TILE_K; + + // Load A (E5M2) -> F32 + unsigned int a_elems = FP8_TILE_M * FP8_TILE_K; + for (unsigned int idx = thread_id; idx < a_elems; idx += num_threads) { + unsigned int r = idx / FP8_TILE_K; + unsigned int c = idx % FP8_TILE_K; + unsigned int gr = block_row + r; + unsigned int gc = k_offset + c; + if (gr < M && gc < K) { + As[r][c] = fp8_e5m2_to_f32(A[gr * K + gc].data); + } else { + As[r][c] = 0.0f; + } + } + + // Load B (E4M3) -> F32 + unsigned int b_elems = FP8_TILE_K * FP8_TILE_N; + for (unsigned int idx = thread_id; idx < b_elems; idx += num_threads) { + unsigned int r = idx / FP8_TILE_N; + unsigned int c = idx % FP8_TILE_N; + unsigned int gr = k_offset + r; + unsigned int gc = block_col + c; + if (gr < K && gc < N) { + Bs[r][c] = fp8_e4m3_to_f32(B[gr * N + gc].data); + } else { + Bs[r][c] = 0.0f; + } + } + + __syncthreads(); + + #pragma unroll + for (unsigned int kk = 0; kk < FP8_TILE_K; kk++) { + float reg_a[FP8_THREAD_M]; + float reg_b[FP8_THREAD_N]; + + #pragma unroll + for (int i = 0; i < FP8_THREAD_M; i++) { + reg_a[i] = As[thread_row + i][kk]; + } + #pragma unroll + for (int j = 0; j < FP8_THREAD_N; j++) { + reg_b[j] = Bs[kk][thread_col + j]; + } + + #pragma unroll + for (int i = 0; i < FP8_THREAD_M; i++) { + #pragma unroll + for (int j = 0; j < FP8_THREAD_N; j++) { + reg_c[i][j] += reg_a[i] * reg_b[j]; + } + } + } + + __syncthreads(); + } + + #pragma unroll + for (int i = 0; i < FP8_THREAD_M; i++) { + #pragma unroll + for (int j = 0; j < FP8_THREAD_N; j++) { + unsigned int gr = block_row + thread_row + i; + unsigned int gc = block_col + thread_col + j; + if (gr < M && gc < N) { + store_fn(C, gr * N + gc, reg_c[i][j] * combined_scale); + } + } + } +} + +// ============================================================================ +// Batched variants +// ============================================================================ + +template +__device__ void fp8_matmul_e4m3_batched_kernel( + const numr_fp8_e4m3* __restrict__ A, + const numr_fp8_e4m3* __restrict__ B, + OutT* __restrict__ C, + float scale_a, + float scale_b, + unsigned int batch, + unsigned int M, + unsigned int N, + unsigned int K +) { + const unsigned int batch_idx = blockIdx.z; + if (batch_idx >= batch) return; + + const numr_fp8_e4m3* A_batch = A + batch_idx * M * K; + const numr_fp8_e4m3* B_batch = B + batch_idx * K * N; + OutT* C_batch = C + batch_idx * M * N; + + __shared__ float As[FP8_TILE_M][FP8_TILE_K]; + __shared__ float Bs[FP8_TILE_K][FP8_TILE_N]; + + const unsigned int tx = threadIdx.x; + const unsigned int ty = threadIdx.y; + const unsigned int threads_x = FP8_TILE_N / FP8_THREAD_N; + const unsigned int threads_y = FP8_TILE_M / FP8_THREAD_M; + + const unsigned int block_row = blockIdx.y * FP8_TILE_M; + const unsigned int block_col = blockIdx.x * FP8_TILE_N; + const unsigned int thread_row = ty * FP8_THREAD_M; + const unsigned int thread_col = tx * FP8_THREAD_N; + + float reg_c[FP8_THREAD_M][FP8_THREAD_N]; + #pragma unroll + for (int i = 0; i < FP8_THREAD_M; i++) + #pragma unroll + for (int j = 0; j < FP8_THREAD_N; j++) + reg_c[i][j] = 0.0f; + + const unsigned int thread_id = ty * threads_x + tx; + const unsigned int num_threads = threads_x * threads_y; + const unsigned int num_k_tiles = (K + FP8_TILE_K - 1) / FP8_TILE_K; + const float combined_scale = scale_a * scale_b; + + for (unsigned int bk = 0; bk < num_k_tiles; bk++) { + const unsigned int k_offset = bk * FP8_TILE_K; + + unsigned int a_elems = FP8_TILE_M * FP8_TILE_K; + for (unsigned int idx = thread_id; idx < a_elems; idx += num_threads) { + unsigned int r = idx / FP8_TILE_K; + unsigned int c = idx % FP8_TILE_K; + unsigned int gr = block_row + r; + unsigned int gc = k_offset + c; + As[r][c] = (gr < M && gc < K) ? fp8_e4m3_to_f32(A_batch[gr * K + gc].data) : 0.0f; + } + + unsigned int b_elems = FP8_TILE_K * FP8_TILE_N; + for (unsigned int idx = thread_id; idx < b_elems; idx += num_threads) { + unsigned int r = idx / FP8_TILE_N; + unsigned int c = idx % FP8_TILE_N; + unsigned int gr = k_offset + r; + unsigned int gc = block_col + c; + Bs[r][c] = (gr < K && gc < N) ? fp8_e4m3_to_f32(B_batch[gr * N + gc].data) : 0.0f; + } + + __syncthreads(); + + #pragma unroll + for (unsigned int kk = 0; kk < FP8_TILE_K; kk++) { + float ra[FP8_THREAD_M], rb[FP8_THREAD_N]; + #pragma unroll + for (int i = 0; i < FP8_THREAD_M; i++) ra[i] = As[thread_row + i][kk]; + #pragma unroll + for (int j = 0; j < FP8_THREAD_N; j++) rb[j] = Bs[kk][thread_col + j]; + #pragma unroll + for (int i = 0; i < FP8_THREAD_M; i++) + #pragma unroll + for (int j = 0; j < FP8_THREAD_N; j++) + reg_c[i][j] += ra[i] * rb[j]; + } + + __syncthreads(); + } + + #pragma unroll + for (int i = 0; i < FP8_THREAD_M; i++) + #pragma unroll + for (int j = 0; j < FP8_THREAD_N; j++) { + unsigned int gr = block_row + thread_row + i; + unsigned int gc = block_col + thread_col + j; + if (gr < M && gc < N) + store_fn(C_batch, gr * N + gc, reg_c[i][j] * combined_scale); + } +} + +template +__device__ void fp8_matmul_e5m2_batched_kernel( + const numr_fp8_e5m2* __restrict__ A, + const numr_fp8_e4m3* __restrict__ B, + OutT* __restrict__ C, + float scale_a, + float scale_b, + unsigned int batch, + unsigned int M, + unsigned int N, + unsigned int K +) { + const unsigned int batch_idx = blockIdx.z; + if (batch_idx >= batch) return; + + const numr_fp8_e5m2* A_batch = A + batch_idx * M * K; + const numr_fp8_e4m3* B_batch = B + batch_idx * K * N; + OutT* C_batch = C + batch_idx * M * N; + + __shared__ float As[FP8_TILE_M][FP8_TILE_K]; + __shared__ float Bs[FP8_TILE_K][FP8_TILE_N]; + + const unsigned int tx = threadIdx.x; + const unsigned int ty = threadIdx.y; + const unsigned int threads_x = FP8_TILE_N / FP8_THREAD_N; + const unsigned int threads_y = FP8_TILE_M / FP8_THREAD_M; + + const unsigned int block_row = blockIdx.y * FP8_TILE_M; + const unsigned int block_col = blockIdx.x * FP8_TILE_N; + const unsigned int thread_row = ty * FP8_THREAD_M; + const unsigned int thread_col = tx * FP8_THREAD_N; + + float reg_c[FP8_THREAD_M][FP8_THREAD_N]; + #pragma unroll + for (int i = 0; i < FP8_THREAD_M; i++) + #pragma unroll + for (int j = 0; j < FP8_THREAD_N; j++) + reg_c[i][j] = 0.0f; + + const unsigned int thread_id = ty * threads_x + tx; + const unsigned int num_threads = threads_x * threads_y; + const unsigned int num_k_tiles = (K + FP8_TILE_K - 1) / FP8_TILE_K; + const float combined_scale = scale_a * scale_b; + + for (unsigned int bk = 0; bk < num_k_tiles; bk++) { + const unsigned int k_offset = bk * FP8_TILE_K; + + unsigned int a_elems = FP8_TILE_M * FP8_TILE_K; + for (unsigned int idx = thread_id; idx < a_elems; idx += num_threads) { + unsigned int r = idx / FP8_TILE_K; + unsigned int c = idx % FP8_TILE_K; + unsigned int gr = block_row + r; + unsigned int gc = k_offset + c; + As[r][c] = (gr < M && gc < K) ? fp8_e5m2_to_f32(A_batch[gr * K + gc].data) : 0.0f; + } + + unsigned int b_elems = FP8_TILE_K * FP8_TILE_N; + for (unsigned int idx = thread_id; idx < b_elems; idx += num_threads) { + unsigned int r = idx / FP8_TILE_N; + unsigned int c = idx % FP8_TILE_N; + unsigned int gr = k_offset + r; + unsigned int gc = block_col + c; + Bs[r][c] = (gr < K && gc < N) ? fp8_e4m3_to_f32(B_batch[gr * N + gc].data) : 0.0f; + } + + __syncthreads(); + + #pragma unroll + for (unsigned int kk = 0; kk < FP8_TILE_K; kk++) { + float ra[FP8_THREAD_M], rb[FP8_THREAD_N]; + #pragma unroll + for (int i = 0; i < FP8_THREAD_M; i++) ra[i] = As[thread_row + i][kk]; + #pragma unroll + for (int j = 0; j < FP8_THREAD_N; j++) rb[j] = Bs[kk][thread_col + j]; + #pragma unroll + for (int i = 0; i < FP8_THREAD_M; i++) + #pragma unroll + for (int j = 0; j < FP8_THREAD_N; j++) + reg_c[i][j] += ra[i] * rb[j]; + } + + __syncthreads(); + } + + #pragma unroll + for (int i = 0; i < FP8_THREAD_M; i++) + #pragma unroll + for (int j = 0; j < FP8_THREAD_N; j++) { + unsigned int gr = block_row + thread_row + i; + unsigned int gc = block_col + thread_col + j; + if (gr < M && gc < N) + store_fn(C_batch, gr * N + gc, reg_c[i][j] * combined_scale); + } +} + +// ============================================================================ +// Extern "C" entry points +// ============================================================================ + +extern "C" { + +// --- E4M3 x E4M3 -> F32 --- +__global__ void fp8_matmul_e4m3_f32( + const numr_fp8_e4m3* A, const numr_fp8_e4m3* B, float* C, + float scale_a, float scale_b, unsigned int M, unsigned int N, unsigned int K +) { fp8_matmul_e4m3_kernel(A, B, C, scale_a, scale_b, M, N, K); } + +// --- E4M3 x E4M3 -> F16 --- +__global__ void fp8_matmul_e4m3_f16( + const numr_fp8_e4m3* A, const numr_fp8_e4m3* B, __half* C, + float scale_a, float scale_b, unsigned int M, unsigned int N, unsigned int K +) { fp8_matmul_e4m3_kernel<__half, store_f16>(A, B, C, scale_a, scale_b, M, N, K); } + +// --- E4M3 x E4M3 -> BF16 --- +__global__ void fp8_matmul_e4m3_bf16( + const numr_fp8_e4m3* A, const numr_fp8_e4m3* B, __nv_bfloat16* C, + float scale_a, float scale_b, unsigned int M, unsigned int N, unsigned int K +) { fp8_matmul_e4m3_kernel<__nv_bfloat16, store_bf16>(A, B, C, scale_a, scale_b, M, N, K); } + +// --- E5M2 x E4M3 -> F32 --- +__global__ void fp8_matmul_e5m2_f32( + const numr_fp8_e5m2* A, const numr_fp8_e4m3* B, float* C, + float scale_a, float scale_b, unsigned int M, unsigned int N, unsigned int K +) { fp8_matmul_e5m2_kernel(A, B, C, scale_a, scale_b, M, N, K); } + +// --- E5M2 x E4M3 -> F16 --- +__global__ void fp8_matmul_e5m2_f16( + const numr_fp8_e5m2* A, const numr_fp8_e4m3* B, __half* C, + float scale_a, float scale_b, unsigned int M, unsigned int N, unsigned int K +) { fp8_matmul_e5m2_kernel<__half, store_f16>(A, B, C, scale_a, scale_b, M, N, K); } + +// --- E5M2 x E4M3 -> BF16 --- +__global__ void fp8_matmul_e5m2_bf16( + const numr_fp8_e5m2* A, const numr_fp8_e4m3* B, __nv_bfloat16* C, + float scale_a, float scale_b, unsigned int M, unsigned int N, unsigned int K +) { fp8_matmul_e5m2_kernel<__nv_bfloat16, store_bf16>(A, B, C, scale_a, scale_b, M, N, K); } + +// --- Batched E4M3 x E4M3 --- +__global__ void fp8_matmul_e4m3_batched_f32( + const numr_fp8_e4m3* A, const numr_fp8_e4m3* B, float* C, + float scale_a, float scale_b, unsigned int batch, unsigned int M, unsigned int N, unsigned int K +) { fp8_matmul_e4m3_batched_kernel(A, B, C, scale_a, scale_b, batch, M, N, K); } + +__global__ void fp8_matmul_e4m3_batched_f16( + const numr_fp8_e4m3* A, const numr_fp8_e4m3* B, __half* C, + float scale_a, float scale_b, unsigned int batch, unsigned int M, unsigned int N, unsigned int K +) { fp8_matmul_e4m3_batched_kernel<__half, store_f16>(A, B, C, scale_a, scale_b, batch, M, N, K); } + +__global__ void fp8_matmul_e4m3_batched_bf16( + const numr_fp8_e4m3* A, const numr_fp8_e4m3* B, __nv_bfloat16* C, + float scale_a, float scale_b, unsigned int batch, unsigned int M, unsigned int N, unsigned int K +) { fp8_matmul_e4m3_batched_kernel<__nv_bfloat16, store_bf16>(A, B, C, scale_a, scale_b, batch, M, N, K); } + +// --- Batched E5M2 x E4M3 --- +__global__ void fp8_matmul_e5m2_batched_f32( + const numr_fp8_e5m2* A, const numr_fp8_e4m3* B, float* C, + float scale_a, float scale_b, unsigned int batch, unsigned int M, unsigned int N, unsigned int K +) { fp8_matmul_e5m2_batched_kernel(A, B, C, scale_a, scale_b, batch, M, N, K); } + +__global__ void fp8_matmul_e5m2_batched_f16( + const numr_fp8_e5m2* A, const numr_fp8_e4m3* B, __half* C, + float scale_a, float scale_b, unsigned int batch, unsigned int M, unsigned int N, unsigned int K +) { fp8_matmul_e5m2_batched_kernel<__half, store_f16>(A, B, C, scale_a, scale_b, batch, M, N, K); } + +__global__ void fp8_matmul_e5m2_batched_bf16( + const numr_fp8_e5m2* A, const numr_fp8_e4m3* B, __nv_bfloat16* C, + float scale_a, float scale_b, unsigned int batch, unsigned int M, unsigned int N, unsigned int K +) { fp8_matmul_e5m2_batched_kernel<__nv_bfloat16, store_bf16>(A, B, C, scale_a, scale_b, batch, M, N, K); } + +} // extern "C" diff --git a/src/runtime/cuda/kernels/fp8_matmul.rs b/src/runtime/cuda/kernels/fp8_matmul.rs new file mode 100644 index 00000000..12f3029e --- /dev/null +++ b/src/runtime/cuda/kernels/fp8_matmul.rs @@ -0,0 +1,250 @@ +//! FP8 matmul CUDA kernel launchers +//! +//! Launches FP8 GEMM kernels with per-tensor scaling and F32 accumulation. +//! Output can be F32, F16, or BF16. + +use cudarc::driver::PushKernelArg; +use cudarc::driver::safe::{CudaContext, CudaStream}; +use std::sync::Arc; + +use super::loader::{get_kernel_function, get_or_load_module, launch_config}; +use crate::dtype::DType; +use crate::error::{Error, Result}; + +const FP8_MATMUL_MODULE: &str = "fp8_matmul"; + +// Tile config matching the .cu defines +const TILE_M: u32 = 64; +const TILE_N: u32 = 64; +const THREAD_M: u32 = 4; +const THREAD_N: u32 = 4; + +fn fp8_matmul_launch_cfg(m: usize, n: usize, batch: usize) -> super::loader::LaunchConfig { + let grid_x = ((n as u32) + TILE_N - 1) / TILE_N; + let grid_y = ((m as u32) + TILE_M - 1) / TILE_M; + let threads_x = TILE_N / THREAD_N; + let threads_y = TILE_M / THREAD_M; + launch_config( + (grid_x, grid_y, (batch as u32).max(1)), + (threads_x, threads_y, 1), + 0, + ) +} + +fn out_dtype_suffix(out_dtype: DType) -> Result<&'static str> { + match out_dtype { + DType::F32 => Ok("f32"), + DType::F16 => Ok("f16"), + DType::BF16 => Ok("bf16"), + _ => Err(Error::UnsupportedDType { + dtype: out_dtype, + op: "fp8_matmul output", + }), + } +} + +/// Launch FP8 E4M3 x E4M3 matmul kernel. +/// +/// # Safety +/// +/// All pointers must be valid device memory with correct sizes. +pub unsafe fn launch_fp8_matmul_e4m3( + context: &Arc, + stream: &CudaStream, + device_index: usize, + out_dtype: DType, + a_ptr: u64, + b_ptr: u64, + c_ptr: u64, + scale_a: f32, + scale_b: f32, + m: usize, + n: usize, + k: usize, +) -> Result<()> { + let module = get_or_load_module(context, device_index, FP8_MATMUL_MODULE)?; + let suffix = out_dtype_suffix(out_dtype)?; + let func_name = format!("fp8_matmul_e4m3_{}", suffix); + let func = get_kernel_function(&module, &func_name)?; + + let cfg = fp8_matmul_launch_cfg(m, n, 1); + let m_u32 = m as u32; + let n_u32 = n as u32; + let k_u32 = k as u32; + + unsafe { + let mut builder = stream.launch_builder(&func); + builder.arg(&a_ptr); + builder.arg(&b_ptr); + builder.arg(&c_ptr); + builder.arg(&scale_a); + builder.arg(&scale_b); + builder.arg(&m_u32); + builder.arg(&n_u32); + builder.arg(&k_u32); + builder.launch(cfg).map_err(|e| { + Error::Internal(format!( + "CUDA fp8_matmul_e4m3 kernel launch failed: {:?}", + e + )) + })?; + } + + Ok(()) +} + +/// Launch FP8 E5M2 x E4M3 matmul kernel (backward pass). +/// +/// # Safety +/// +/// All pointers must be valid device memory with correct sizes. +pub unsafe fn launch_fp8_matmul_e5m2( + context: &Arc, + stream: &CudaStream, + device_index: usize, + out_dtype: DType, + a_ptr: u64, + b_ptr: u64, + c_ptr: u64, + scale_a: f32, + scale_b: f32, + m: usize, + n: usize, + k: usize, +) -> Result<()> { + let module = get_or_load_module(context, device_index, FP8_MATMUL_MODULE)?; + let suffix = out_dtype_suffix(out_dtype)?; + let func_name = format!("fp8_matmul_e5m2_{}", suffix); + let func = get_kernel_function(&module, &func_name)?; + + let cfg = fp8_matmul_launch_cfg(m, n, 1); + let m_u32 = m as u32; + let n_u32 = n as u32; + let k_u32 = k as u32; + + unsafe { + let mut builder = stream.launch_builder(&func); + builder.arg(&a_ptr); + builder.arg(&b_ptr); + builder.arg(&c_ptr); + builder.arg(&scale_a); + builder.arg(&scale_b); + builder.arg(&m_u32); + builder.arg(&n_u32); + builder.arg(&k_u32); + builder.launch(cfg).map_err(|e| { + Error::Internal(format!( + "CUDA fp8_matmul_e5m2 kernel launch failed: {:?}", + e + )) + })?; + } + + Ok(()) +} + +/// Launch batched FP8 E4M3 x E4M3 matmul kernel. +/// +/// # Safety +/// +/// All pointers must be valid device memory with correct sizes. +pub unsafe fn launch_fp8_matmul_e4m3_batched( + context: &Arc, + stream: &CudaStream, + device_index: usize, + out_dtype: DType, + a_ptr: u64, + b_ptr: u64, + c_ptr: u64, + scale_a: f32, + scale_b: f32, + batch: usize, + m: usize, + n: usize, + k: usize, +) -> Result<()> { + let module = get_or_load_module(context, device_index, FP8_MATMUL_MODULE)?; + let suffix = out_dtype_suffix(out_dtype)?; + let func_name = format!("fp8_matmul_e4m3_batched_{}", suffix); + let func = get_kernel_function(&module, &func_name)?; + + let cfg = fp8_matmul_launch_cfg(m, n, batch); + let batch_u32 = batch as u32; + let m_u32 = m as u32; + let n_u32 = n as u32; + let k_u32 = k as u32; + + unsafe { + let mut builder = stream.launch_builder(&func); + builder.arg(&a_ptr); + builder.arg(&b_ptr); + builder.arg(&c_ptr); + builder.arg(&scale_a); + builder.arg(&scale_b); + builder.arg(&batch_u32); + builder.arg(&m_u32); + builder.arg(&n_u32); + builder.arg(&k_u32); + builder.launch(cfg).map_err(|e| { + Error::Internal(format!( + "CUDA fp8_matmul_e4m3_batched kernel launch failed: {:?}", + e + )) + })?; + } + + Ok(()) +} + +/// Launch batched FP8 E5M2 x E4M3 matmul kernel (backward pass). +/// +/// # Safety +/// +/// All pointers must be valid device memory with correct sizes. +pub unsafe fn launch_fp8_matmul_e5m2_batched( + context: &Arc, + stream: &CudaStream, + device_index: usize, + out_dtype: DType, + a_ptr: u64, + b_ptr: u64, + c_ptr: u64, + scale_a: f32, + scale_b: f32, + batch: usize, + m: usize, + n: usize, + k: usize, +) -> Result<()> { + let module = get_or_load_module(context, device_index, FP8_MATMUL_MODULE)?; + let suffix = out_dtype_suffix(out_dtype)?; + let func_name = format!("fp8_matmul_e5m2_batched_{}", suffix); + let func = get_kernel_function(&module, &func_name)?; + + let cfg = fp8_matmul_launch_cfg(m, n, batch); + let batch_u32 = batch as u32; + let m_u32 = m as u32; + let n_u32 = n as u32; + let k_u32 = k as u32; + + unsafe { + let mut builder = stream.launch_builder(&func); + builder.arg(&a_ptr); + builder.arg(&b_ptr); + builder.arg(&c_ptr); + builder.arg(&scale_a); + builder.arg(&scale_b); + builder.arg(&batch_u32); + builder.arg(&m_u32); + builder.arg(&n_u32); + builder.arg(&k_u32); + builder.launch(cfg).map_err(|e| { + Error::Internal(format!( + "CUDA fp8_matmul_e5m2_batched kernel launch failed: {:?}", + e + )) + })?; + } + + Ok(()) +} diff --git a/src/runtime/cuda/kernels/fused_activation_mul.cu b/src/runtime/cuda/kernels/fused_activation_mul.cu new file mode 100644 index 00000000..4b9a27a9 --- /dev/null +++ b/src/runtime/cuda/kernels/fused_activation_mul.cu @@ -0,0 +1,274 @@ +// Fused activation-mul CUDA kernels +// Forward: output = activation(a) * b +// Supports: silu_mul, gelu_mul, relu_mul, sigmoid_mul +// Types: f32, f64, f16, bf16, fp8_e4m3, fp8_e5m2 + +#include +#include +#include "dtype_traits.cuh" + +// ============================================================================ +// Helper device functions (shared across dtypes) +// ============================================================================ + +__device__ __forceinline__ float silu_f(float x) { + return x / (1.0f + expf(-x)); +} + +__device__ __forceinline__ float gelu_f(float x) { + float cdf = 0.5f * (1.0f + tanhf(0.7978845608f * (x + 0.044715f * x * x * x))); + return x * cdf; +} + +__device__ __forceinline__ float relu_f(float x) { + return fmaxf(0.0f, x); +} + +__device__ __forceinline__ float sigmoid_f(float x) { + return 1.0f / (1.0f + expf(-x)); +} + +__device__ __forceinline__ double silu_d(double x) { + return x / (1.0 + exp(-x)); +} + +__device__ __forceinline__ double gelu_d(double x) { + double cdf = 0.5 * (1.0 + tanh(0.7978845608028654 * (x + 0.044715 * x * x * x))); + return x * cdf; +} + +__device__ __forceinline__ double relu_d(double x) { + return fmax(0.0, x); +} + +__device__ __forceinline__ double sigmoid_d(double x) { + return 1.0 / (1.0 + exp(-x)); +} + +extern "C" { + +// ============================================================================ +// F32 Fused Activation-Mul Forward +// ============================================================================ + +__global__ void silu_mul_f32(const float* a, const float* b, float* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + out[idx] = silu_f(a[idx]) * b[idx]; + } +} + +__global__ void gelu_mul_f32(const float* a, const float* b, float* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + out[idx] = gelu_f(a[idx]) * b[idx]; + } +} + +__global__ void relu_mul_f32(const float* a, const float* b, float* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + out[idx] = relu_f(a[idx]) * b[idx]; + } +} + +__global__ void sigmoid_mul_f32(const float* a, const float* b, float* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + out[idx] = sigmoid_f(a[idx]) * b[idx]; + } +} + +// ============================================================================ +// F64 Fused Activation-Mul Forward +// ============================================================================ + +__global__ void silu_mul_f64(const double* a, const double* b, double* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + out[idx] = silu_d(a[idx]) * b[idx]; + } +} + +__global__ void gelu_mul_f64(const double* a, const double* b, double* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + out[idx] = gelu_d(a[idx]) * b[idx]; + } +} + +__global__ void relu_mul_f64(const double* a, const double* b, double* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + out[idx] = relu_d(a[idx]) * b[idx]; + } +} + +__global__ void sigmoid_mul_f64(const double* a, const double* b, double* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + out[idx] = sigmoid_d(a[idx]) * b[idx]; + } +} + +// ============================================================================ +// F16 Fused Activation-Mul Forward (compute in F32) +// ============================================================================ + +__global__ void silu_mul_f16(const __half* a, const __half* b, __half* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float ax = __half2float(a[idx]); + float bx = __half2float(b[idx]); + out[idx] = __float2half(silu_f(ax) * bx); + } +} + +__global__ void gelu_mul_f16(const __half* a, const __half* b, __half* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float ax = __half2float(a[idx]); + float bx = __half2float(b[idx]); + out[idx] = __float2half(gelu_f(ax) * bx); + } +} + +__global__ void relu_mul_f16(const __half* a, const __half* b, __half* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float ax = __half2float(a[idx]); + float bx = __half2float(b[idx]); + out[idx] = __float2half(relu_f(ax) * bx); + } +} + +__global__ void sigmoid_mul_f16(const __half* a, const __half* b, __half* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float ax = __half2float(a[idx]); + float bx = __half2float(b[idx]); + out[idx] = __float2half(sigmoid_f(ax) * bx); + } +} + +// ============================================================================ +// BF16 Fused Activation-Mul Forward (compute in F32) +// ============================================================================ + +__global__ void silu_mul_bf16(const __nv_bfloat16* a, const __nv_bfloat16* b, __nv_bfloat16* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float ax = __bfloat162float(a[idx]); + float bx = __bfloat162float(b[idx]); + out[idx] = __float2bfloat16(silu_f(ax) * bx); + } +} + +__global__ void gelu_mul_bf16(const __nv_bfloat16* a, const __nv_bfloat16* b, __nv_bfloat16* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float ax = __bfloat162float(a[idx]); + float bx = __bfloat162float(b[idx]); + out[idx] = __float2bfloat16(gelu_f(ax) * bx); + } +} + +__global__ void relu_mul_bf16(const __nv_bfloat16* a, const __nv_bfloat16* b, __nv_bfloat16* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float ax = __bfloat162float(a[idx]); + float bx = __bfloat162float(b[idx]); + out[idx] = __float2bfloat16(relu_f(ax) * bx); + } +} + +__global__ void sigmoid_mul_bf16(const __nv_bfloat16* a, const __nv_bfloat16* b, __nv_bfloat16* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float ax = __bfloat162float(a[idx]); + float bx = __bfloat162float(b[idx]); + out[idx] = __float2bfloat16(sigmoid_f(ax) * bx); + } +} + +// ============================================================================ +// FP8 E4M3 Fused Activation-Mul Forward (compute in F32) +// ============================================================================ + +__global__ void silu_mul_fp8_e4m3(const numr_fp8_e4m3* a, const numr_fp8_e4m3* b, numr_fp8_e4m3* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float ax = fp8_e4m3_to_f32(a[idx].data); + float bx = fp8_e4m3_to_f32(b[idx].data); + out[idx] = numr_fp8_e4m3(f32_to_fp8_e4m3(silu_f(ax) * bx)); + } +} + +__global__ void gelu_mul_fp8_e4m3(const numr_fp8_e4m3* a, const numr_fp8_e4m3* b, numr_fp8_e4m3* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float ax = fp8_e4m3_to_f32(a[idx].data); + float bx = fp8_e4m3_to_f32(b[idx].data); + out[idx] = numr_fp8_e4m3(f32_to_fp8_e4m3(gelu_f(ax) * bx)); + } +} + +__global__ void relu_mul_fp8_e4m3(const numr_fp8_e4m3* a, const numr_fp8_e4m3* b, numr_fp8_e4m3* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float ax = fp8_e4m3_to_f32(a[idx].data); + float bx = fp8_e4m3_to_f32(b[idx].data); + out[idx] = numr_fp8_e4m3(f32_to_fp8_e4m3(relu_f(ax) * bx)); + } +} + +__global__ void sigmoid_mul_fp8_e4m3(const numr_fp8_e4m3* a, const numr_fp8_e4m3* b, numr_fp8_e4m3* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float ax = fp8_e4m3_to_f32(a[idx].data); + float bx = fp8_e4m3_to_f32(b[idx].data); + out[idx] = numr_fp8_e4m3(f32_to_fp8_e4m3(sigmoid_f(ax) * bx)); + } +} + +// ============================================================================ +// FP8 E5M2 Fused Activation-Mul Forward (compute in F32) +// ============================================================================ + +__global__ void silu_mul_fp8_e5m2(const numr_fp8_e5m2* a, const numr_fp8_e5m2* b, numr_fp8_e5m2* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float ax = fp8_e5m2_to_f32(a[idx].data); + float bx = fp8_e5m2_to_f32(b[idx].data); + out[idx] = numr_fp8_e5m2(f32_to_fp8_e5m2(silu_f(ax) * bx)); + } +} + +__global__ void gelu_mul_fp8_e5m2(const numr_fp8_e5m2* a, const numr_fp8_e5m2* b, numr_fp8_e5m2* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float ax = fp8_e5m2_to_f32(a[idx].data); + float bx = fp8_e5m2_to_f32(b[idx].data); + out[idx] = numr_fp8_e5m2(f32_to_fp8_e5m2(gelu_f(ax) * bx)); + } +} + +__global__ void relu_mul_fp8_e5m2(const numr_fp8_e5m2* a, const numr_fp8_e5m2* b, numr_fp8_e5m2* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float ax = fp8_e5m2_to_f32(a[idx].data); + float bx = fp8_e5m2_to_f32(b[idx].data); + out[idx] = numr_fp8_e5m2(f32_to_fp8_e5m2(relu_f(ax) * bx)); + } +} + +__global__ void sigmoid_mul_fp8_e5m2(const numr_fp8_e5m2* a, const numr_fp8_e5m2* b, numr_fp8_e5m2* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float ax = fp8_e5m2_to_f32(a[idx].data); + float bx = fp8_e5m2_to_f32(b[idx].data); + out[idx] = numr_fp8_e5m2(f32_to_fp8_e5m2(sigmoid_f(ax) * bx)); + } +} + +} // extern "C" diff --git a/src/runtime/cuda/kernels/fused_activation_mul.rs b/src/runtime/cuda/kernels/fused_activation_mul.rs new file mode 100644 index 00000000..132fe5e7 --- /dev/null +++ b/src/runtime/cuda/kernels/fused_activation_mul.rs @@ -0,0 +1,195 @@ +//! Fused activation-mul CUDA kernel launchers +//! +//! Forward: output = activation(a) * b +//! Backward: d_a = grad * b * activation'(a), d_b = grad * activation(a) + +use cudarc::driver::PushKernelArg; +use cudarc::driver::safe::{CudaContext, CudaStream}; +use std::sync::Arc; + +use super::loader::{ + BLOCK_SIZE, elementwise_launch_config, get_kernel_function, get_or_load_module, kernel_name, + launch_config, +}; +use crate::dtype::DType; +use crate::error::{Error, Result}; + +const FUSED_ACTIVATION_MUL_MODULE: &str = "fused_activation_mul"; +const FUSED_ACTIVATION_MUL_BWD_MODULE: &str = "fused_activation_mul_bwd"; + +/// Launch a fused activation-mul forward kernel. +/// +/// Computes: `output[i] = activation(a[i]) * b[i]` +/// +/// # Safety +/// +/// All pointers must be valid device memory with at least `numel` elements. +unsafe fn launch_fused_activation_mul_fwd( + context: &Arc, + stream: &CudaStream, + device_index: usize, + op: &str, + dtype: DType, + a_ptr: u64, + b_ptr: u64, + output_ptr: u64, + numel: usize, +) -> Result<()> { + let module = get_or_load_module(context, device_index, FUSED_ACTIVATION_MUL_MODULE)?; + let func_name = kernel_name(op, dtype); + let func = get_kernel_function(&module, &func_name)?; + + let grid = elementwise_launch_config(numel); + let block = (BLOCK_SIZE, 1, 1); + let n = numel as u32; + + let cfg = launch_config(grid, block, 0); + let mut builder = stream.launch_builder(&func); + unsafe { + builder.arg(&a_ptr); + builder.arg(&b_ptr); + builder.arg(&output_ptr); + builder.arg(&n); + + builder + .launch(cfg) + .map_err(|e| Error::Internal(format!("CUDA {} kernel launch failed: {:?}", op, e)))?; + } + + Ok(()) +} + +/// Launch a fused activation-mul backward kernel. +/// +/// Computes: `d_b[i] = grad[i] * activation(a[i])`, `d_a[i] = grad[i] * b[i] * activation'(a[i])` +/// +/// # Safety +/// +/// All pointers must be valid device memory with at least `numel` elements. +unsafe fn launch_fused_activation_mul_bwd( + context: &Arc, + stream: &CudaStream, + device_index: usize, + op: &str, + dtype: DType, + grad_ptr: u64, + a_ptr: u64, + b_ptr: u64, + d_a_ptr: u64, + d_b_ptr: u64, + numel: usize, +) -> Result<()> { + let module = get_or_load_module(context, device_index, FUSED_ACTIVATION_MUL_BWD_MODULE)?; + let func_name = kernel_name(op, dtype); + let func = get_kernel_function(&module, &func_name)?; + + let grid = elementwise_launch_config(numel); + let block = (BLOCK_SIZE, 1, 1); + let n = numel as u32; + + let cfg = launch_config(grid, block, 0); + let mut builder = stream.launch_builder(&func); + unsafe { + builder.arg(&grad_ptr); + builder.arg(&a_ptr); + builder.arg(&b_ptr); + builder.arg(&d_a_ptr); + builder.arg(&d_b_ptr); + builder.arg(&n); + + builder + .launch(cfg) + .map_err(|e| Error::Internal(format!("CUDA {} kernel launch failed: {:?}", op, e)))?; + } + + Ok(()) +} + +// ============================================================================ +// Public forward launchers +// ============================================================================ + +macro_rules! fused_activation_mul_fwd { + ($($(#[doc = $doc:expr])* $name:ident => $op:expr),+ $(,)?) => { + $( + $(#[doc = $doc])* + /// + /// # Safety + /// + /// All pointers must be valid device memory with at least `numel` elements. + pub unsafe fn $name( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + a_ptr: u64, + b_ptr: u64, + output_ptr: u64, + numel: usize, + ) -> Result<()> { + unsafe { + launch_fused_activation_mul_fwd( + context, stream, device_index, $op, dtype, a_ptr, b_ptr, output_ptr, numel, + ) + } + } + )+ + }; +} + +fused_activation_mul_fwd! { + /// Launch fused silu_mul: output = silu(a) * b + launch_silu_mul => "silu_mul", + /// Launch fused gelu_mul: output = gelu(a) * b + launch_gelu_mul => "gelu_mul", + /// Launch fused relu_mul: output = relu(a) * b + launch_relu_mul => "relu_mul", + /// Launch fused sigmoid_mul: output = sigmoid(a) * b + launch_sigmoid_mul => "sigmoid_mul", +} + +// ============================================================================ +// Public backward launchers +// ============================================================================ + +macro_rules! fused_activation_mul_bwd { + ($($(#[doc = $doc:expr])* $name:ident => $op:expr),+ $(,)?) => { + $( + $(#[doc = $doc])* + /// + /// # Safety + /// + /// All pointers must be valid device memory with at least `numel` elements. + pub unsafe fn $name( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + grad_ptr: u64, + a_ptr: u64, + b_ptr: u64, + d_a_ptr: u64, + d_b_ptr: u64, + numel: usize, + ) -> Result<()> { + unsafe { + launch_fused_activation_mul_bwd( + context, stream, device_index, $op, dtype, grad_ptr, a_ptr, b_ptr, + d_a_ptr, d_b_ptr, numel, + ) + } + } + )+ + }; +} + +fused_activation_mul_bwd! { + /// Launch fused silu_mul backward + launch_silu_mul_bwd => "silu_mul_bwd", + /// Launch fused gelu_mul backward + launch_gelu_mul_bwd => "gelu_mul_bwd", + /// Launch fused relu_mul backward + launch_relu_mul_bwd => "relu_mul_bwd", + /// Launch fused sigmoid_mul backward + launch_sigmoid_mul_bwd => "sigmoid_mul_bwd", +} diff --git a/src/runtime/cuda/kernels/fused_activation_mul_bwd.cu b/src/runtime/cuda/kernels/fused_activation_mul_bwd.cu new file mode 100644 index 00000000..ddd0ca48 --- /dev/null +++ b/src/runtime/cuda/kernels/fused_activation_mul_bwd.cu @@ -0,0 +1,378 @@ +// Fused activation-mul backward CUDA kernels +// Given forward: output = activation(a) * b +// Backward: d_a = grad * b * activation'(a), d_b = grad * activation(a) +// Fused: computes activation(a), activation'(a), d_a, d_b in single pass +// Types: f32, f64, f16, bf16, fp8_e4m3, fp8_e5m2 + +#include +#include +#include "dtype_traits.cuh" +#include "activation_deriv.cuh" + +extern "C" { + +// ============================================================================ +// F32 Fused Activation-Mul Backward +// ============================================================================ + +// SiLU backward: silu'(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))) +__global__ void silu_mul_bwd_f32( + const float* grad, const float* a, const float* b, + float* d_a, float* d_b, unsigned int n +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float x = a[idx]; + float g = grad[idx]; + float bv = b[idx]; + d_b[idx] = g * silu_fwd_f32(x); + d_a[idx] = g * bv * silu_deriv_f32(x); + } +} + +// GELU backward: uses tanh approximation derivative +__global__ void gelu_mul_bwd_f32( + const float* grad, const float* a, const float* b, + float* d_a, float* d_b, unsigned int n +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float x = a[idx]; + float g = grad[idx]; + float bv = b[idx]; + d_b[idx] = g * gelu_fwd_f32(x); + d_a[idx] = g * bv * gelu_deriv_f32(x); + } +} + +// ReLU backward: relu'(x) = 1 if x > 0 else 0 +__global__ void relu_mul_bwd_f32( + const float* grad, const float* a, const float* b, + float* d_a, float* d_b, unsigned int n +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float x = a[idx]; + float g = grad[idx]; + float bv = b[idx]; + d_b[idx] = g * relu_fwd_f32(x); + d_a[idx] = g * bv * relu_deriv_f32(x); + } +} + +// Sigmoid backward: sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)) +__global__ void sigmoid_mul_bwd_f32( + const float* grad, const float* a, const float* b, + float* d_a, float* d_b, unsigned int n +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float x = a[idx]; + float g = grad[idx]; + float bv = b[idx]; + d_b[idx] = g * sigmoid_fwd_f32(x); + d_a[idx] = g * bv * sigmoid_deriv_f32(x); + } +} + +// ============================================================================ +// F64 Fused Activation-Mul Backward +// ============================================================================ + +__global__ void silu_mul_bwd_f64( + const double* grad, const double* a, const double* b, + double* d_a, double* d_b, unsigned int n +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + double x = a[idx]; + double g = grad[idx]; + double bv = b[idx]; + d_b[idx] = g * silu_fwd_f64(x); + d_a[idx] = g * bv * silu_deriv_f64(x); + } +} + +__global__ void gelu_mul_bwd_f64( + const double* grad, const double* a, const double* b, + double* d_a, double* d_b, unsigned int n +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + double x = a[idx]; + double g = grad[idx]; + double bv = b[idx]; + d_b[idx] = g * gelu_fwd_f64(x); + d_a[idx] = g * bv * gelu_deriv_f64(x); + } +} + +__global__ void relu_mul_bwd_f64( + const double* grad, const double* a, const double* b, + double* d_a, double* d_b, unsigned int n +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + double x = a[idx]; + double g = grad[idx]; + double bv = b[idx]; + d_b[idx] = g * relu_fwd_f64(x); + d_a[idx] = g * bv * relu_deriv_f64(x); + } +} + +__global__ void sigmoid_mul_bwd_f64( + const double* grad, const double* a, const double* b, + double* d_a, double* d_b, unsigned int n +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + double x = a[idx]; + double g = grad[idx]; + double bv = b[idx]; + d_b[idx] = g * sigmoid_fwd_f64(x); + d_a[idx] = g * bv * sigmoid_deriv_f64(x); + } +} + +// ============================================================================ +// F16 Fused Activation-Mul Backward (compute in F32) +// ============================================================================ + +__global__ void silu_mul_bwd_f16( + const __half* grad, const __half* a, const __half* b, + __half* d_a, __half* d_b, unsigned int n +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float x = __half2float(a[idx]); + float g = __half2float(grad[idx]); + float bv = __half2float(b[idx]); + d_b[idx] = __float2half(g * silu_fwd_f32(x)); + d_a[idx] = __float2half(g * bv * silu_deriv_f32(x)); + } +} + +__global__ void gelu_mul_bwd_f16( + const __half* grad, const __half* a, const __half* b, + __half* d_a, __half* d_b, unsigned int n +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float x = __half2float(a[idx]); + float g = __half2float(grad[idx]); + float bv = __half2float(b[idx]); + d_b[idx] = __float2half(g * gelu_fwd_f32(x)); + d_a[idx] = __float2half(g * bv * gelu_deriv_f32(x)); + } +} + +__global__ void relu_mul_bwd_f16( + const __half* grad, const __half* a, const __half* b, + __half* d_a, __half* d_b, unsigned int n +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float x = __half2float(a[idx]); + float g = __half2float(grad[idx]); + float bv = __half2float(b[idx]); + d_b[idx] = __float2half(g * relu_fwd_f32(x)); + d_a[idx] = __float2half(g * bv * relu_deriv_f32(x)); + } +} + +__global__ void sigmoid_mul_bwd_f16( + const __half* grad, const __half* a, const __half* b, + __half* d_a, __half* d_b, unsigned int n +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float x = __half2float(a[idx]); + float g = __half2float(grad[idx]); + float bv = __half2float(b[idx]); + d_b[idx] = __float2half(g * sigmoid_fwd_f32(x)); + d_a[idx] = __float2half(g * bv * sigmoid_deriv_f32(x)); + } +} + +// ============================================================================ +// BF16 Fused Activation-Mul Backward (compute in F32) +// ============================================================================ + +__global__ void silu_mul_bwd_bf16( + const __nv_bfloat16* grad, const __nv_bfloat16* a, const __nv_bfloat16* b, + __nv_bfloat16* d_a, __nv_bfloat16* d_b, unsigned int n +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float x = __bfloat162float(a[idx]); + float g = __bfloat162float(grad[idx]); + float bv = __bfloat162float(b[idx]); + d_b[idx] = __float2bfloat16(g * silu_fwd_f32(x)); + d_a[idx] = __float2bfloat16(g * bv * silu_deriv_f32(x)); + } +} + +__global__ void gelu_mul_bwd_bf16( + const __nv_bfloat16* grad, const __nv_bfloat16* a, const __nv_bfloat16* b, + __nv_bfloat16* d_a, __nv_bfloat16* d_b, unsigned int n +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float x = __bfloat162float(a[idx]); + float g = __bfloat162float(grad[idx]); + float bv = __bfloat162float(b[idx]); + d_b[idx] = __float2bfloat16(g * gelu_fwd_f32(x)); + d_a[idx] = __float2bfloat16(g * bv * gelu_deriv_f32(x)); + } +} + +__global__ void relu_mul_bwd_bf16( + const __nv_bfloat16* grad, const __nv_bfloat16* a, const __nv_bfloat16* b, + __nv_bfloat16* d_a, __nv_bfloat16* d_b, unsigned int n +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float x = __bfloat162float(a[idx]); + float g = __bfloat162float(grad[idx]); + float bv = __bfloat162float(b[idx]); + d_b[idx] = __float2bfloat16(g * relu_fwd_f32(x)); + d_a[idx] = __float2bfloat16(g * bv * relu_deriv_f32(x)); + } +} + +__global__ void sigmoid_mul_bwd_bf16( + const __nv_bfloat16* grad, const __nv_bfloat16* a, const __nv_bfloat16* b, + __nv_bfloat16* d_a, __nv_bfloat16* d_b, unsigned int n +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float x = __bfloat162float(a[idx]); + float g = __bfloat162float(grad[idx]); + float bv = __bfloat162float(b[idx]); + d_b[idx] = __float2bfloat16(g * sigmoid_fwd_f32(x)); + d_a[idx] = __float2bfloat16(g * bv * sigmoid_deriv_f32(x)); + } +} + +// ============================================================================ +// FP8 E4M3 Fused Activation-Mul Backward (compute in F32) +// ============================================================================ + +__global__ void silu_mul_bwd_fp8_e4m3( + const numr_fp8_e4m3* grad, const numr_fp8_e4m3* a, const numr_fp8_e4m3* b, + numr_fp8_e4m3* d_a, numr_fp8_e4m3* d_b, unsigned int n +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float x = fp8_e4m3_to_f32(a[idx].data); + float g = fp8_e4m3_to_f32(grad[idx].data); + float bv = fp8_e4m3_to_f32(b[idx].data); + d_b[idx] = numr_fp8_e4m3(f32_to_fp8_e4m3(g * silu_fwd_f32(x))); + d_a[idx] = numr_fp8_e4m3(f32_to_fp8_e4m3(g * bv * silu_deriv_f32(x))); + } +} + +__global__ void gelu_mul_bwd_fp8_e4m3( + const numr_fp8_e4m3* grad, const numr_fp8_e4m3* a, const numr_fp8_e4m3* b, + numr_fp8_e4m3* d_a, numr_fp8_e4m3* d_b, unsigned int n +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float x = fp8_e4m3_to_f32(a[idx].data); + float g = fp8_e4m3_to_f32(grad[idx].data); + float bv = fp8_e4m3_to_f32(b[idx].data); + d_b[idx] = numr_fp8_e4m3(f32_to_fp8_e4m3(g * gelu_fwd_f32(x))); + d_a[idx] = numr_fp8_e4m3(f32_to_fp8_e4m3(g * bv * gelu_deriv_f32(x))); + } +} + +__global__ void relu_mul_bwd_fp8_e4m3( + const numr_fp8_e4m3* grad, const numr_fp8_e4m3* a, const numr_fp8_e4m3* b, + numr_fp8_e4m3* d_a, numr_fp8_e4m3* d_b, unsigned int n +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float x = fp8_e4m3_to_f32(a[idx].data); + float g = fp8_e4m3_to_f32(grad[idx].data); + float bv = fp8_e4m3_to_f32(b[idx].data); + d_b[idx] = numr_fp8_e4m3(f32_to_fp8_e4m3(g * relu_fwd_f32(x))); + d_a[idx] = numr_fp8_e4m3(f32_to_fp8_e4m3(g * bv * relu_deriv_f32(x))); + } +} + +__global__ void sigmoid_mul_bwd_fp8_e4m3( + const numr_fp8_e4m3* grad, const numr_fp8_e4m3* a, const numr_fp8_e4m3* b, + numr_fp8_e4m3* d_a, numr_fp8_e4m3* d_b, unsigned int n +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float x = fp8_e4m3_to_f32(a[idx].data); + float g = fp8_e4m3_to_f32(grad[idx].data); + float bv = fp8_e4m3_to_f32(b[idx].data); + d_b[idx] = numr_fp8_e4m3(f32_to_fp8_e4m3(g * sigmoid_fwd_f32(x))); + d_a[idx] = numr_fp8_e4m3(f32_to_fp8_e4m3(g * bv * sigmoid_deriv_f32(x))); + } +} + +// ============================================================================ +// FP8 E5M2 Fused Activation-Mul Backward (compute in F32) +// ============================================================================ + +__global__ void silu_mul_bwd_fp8_e5m2( + const numr_fp8_e5m2* grad, const numr_fp8_e5m2* a, const numr_fp8_e5m2* b, + numr_fp8_e5m2* d_a, numr_fp8_e5m2* d_b, unsigned int n +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float x = fp8_e5m2_to_f32(a[idx].data); + float g = fp8_e5m2_to_f32(grad[idx].data); + float bv = fp8_e5m2_to_f32(b[idx].data); + d_b[idx] = numr_fp8_e5m2(f32_to_fp8_e5m2(g * silu_fwd_f32(x))); + d_a[idx] = numr_fp8_e5m2(f32_to_fp8_e5m2(g * bv * silu_deriv_f32(x))); + } +} + +__global__ void gelu_mul_bwd_fp8_e5m2( + const numr_fp8_e5m2* grad, const numr_fp8_e5m2* a, const numr_fp8_e5m2* b, + numr_fp8_e5m2* d_a, numr_fp8_e5m2* d_b, unsigned int n +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float x = fp8_e5m2_to_f32(a[idx].data); + float g = fp8_e5m2_to_f32(grad[idx].data); + float bv = fp8_e5m2_to_f32(b[idx].data); + d_b[idx] = numr_fp8_e5m2(f32_to_fp8_e5m2(g * gelu_fwd_f32(x))); + d_a[idx] = numr_fp8_e5m2(f32_to_fp8_e5m2(g * bv * gelu_deriv_f32(x))); + } +} + +__global__ void relu_mul_bwd_fp8_e5m2( + const numr_fp8_e5m2* grad, const numr_fp8_e5m2* a, const numr_fp8_e5m2* b, + numr_fp8_e5m2* d_a, numr_fp8_e5m2* d_b, unsigned int n +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float x = fp8_e5m2_to_f32(a[idx].data); + float g = fp8_e5m2_to_f32(grad[idx].data); + float bv = fp8_e5m2_to_f32(b[idx].data); + d_b[idx] = numr_fp8_e5m2(f32_to_fp8_e5m2(g * relu_fwd_f32(x))); + d_a[idx] = numr_fp8_e5m2(f32_to_fp8_e5m2(g * bv * relu_deriv_f32(x))); + } +} + +__global__ void sigmoid_mul_bwd_fp8_e5m2( + const numr_fp8_e5m2* grad, const numr_fp8_e5m2* a, const numr_fp8_e5m2* b, + numr_fp8_e5m2* d_a, numr_fp8_e5m2* d_b, unsigned int n +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float x = fp8_e5m2_to_f32(a[idx].data); + float g = fp8_e5m2_to_f32(grad[idx].data); + float bv = fp8_e5m2_to_f32(b[idx].data); + d_b[idx] = numr_fp8_e5m2(f32_to_fp8_e5m2(g * sigmoid_fwd_f32(x))); + d_a[idx] = numr_fp8_e5m2(f32_to_fp8_e5m2(g * bv * sigmoid_deriv_f32(x))); + } +} + +} // extern "C" diff --git a/src/runtime/cuda/kernels/fused_add_norm.cu b/src/runtime/cuda/kernels/fused_add_norm.cu new file mode 100644 index 00000000..5bc0b3ce --- /dev/null +++ b/src/runtime/cuda/kernels/fused_add_norm.cu @@ -0,0 +1,1461 @@ +// Fused Add + Normalization CUDA kernels +// Supports: fused_add_rms_norm, fused_add_layer_norm (forward + backward) +// Types: f32, f64, f16, bf16 +// Note: All half-precision variants use FP32 accumulation for numerical stability + +#include +#include +#include "dtype_traits.cuh" + +extern "C" { + +// ============================================================================ +// Helper: atomicAdd for half-precision types via atomicCAS +// ============================================================================ + +__device__ void atomicAddHalf(__half* address, float val) { + unsigned short int* address_as_us = (unsigned short int*)address; + unsigned short int old = *address_as_us, assumed; + do { + assumed = old; + old = atomicCAS(address_as_us, assumed, + __half_as_ushort(__float2half(__half2float(__ushort_as_half(assumed)) + val))); + } while (assumed != old); +} + +__device__ void atomicAddBf16(__nv_bfloat16* address, float val) { + // Use atomicCAS with bit manipulation for BF16 + unsigned short int* address_as_us = (unsigned short int*)address; + unsigned short int old = *address_as_us, assumed; + do { + assumed = old; + // Extract as uint16, convert to bfloat16, then float, add, convert back + __nv_bfloat16 old_val; + unsigned short int* old_val_ptr = (unsigned short int*)&old_val; + *old_val_ptr = assumed; + float new_float = __bfloat162float(old_val) + val; + __nv_bfloat16 new_val = __float2bfloat16(new_float); + unsigned short int* new_val_ptr = (unsigned short int*)&new_val; + old = atomicCAS(address_as_us, assumed, *new_val_ptr); + } while (assumed != old); +} + +// ============================================================================ +// F32 Fused Add + RMSNorm Forward +// ============================================================================ + +__global__ void fused_add_rms_norm_f32( + const float* input, const float* residual, const float* weight, + float* output, float* pre_norm, + unsigned int batch_size, unsigned int hidden_size, float eps +) { + unsigned int row = blockIdx.x; + if (row >= batch_size) return; + + extern __shared__ float shared[]; + + const float* row_in = input + row * hidden_size; + const float* row_res = residual + row * hidden_size; + float* row_pn = pre_norm + row * hidden_size; + float* row_out = output + row * hidden_size; + + // Phase 1: Add residual + compute sum of squares + float thread_sum = 0.0f; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float pn = row_in[i] + row_res[i]; + row_pn[i] = pn; + thread_sum += pn * pn; + } + shared[threadIdx.x] = thread_sum; + __syncthreads(); + + // Reduce + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) shared[threadIdx.x] += shared[threadIdx.x + s]; + __syncthreads(); + } + + float rms_inv = rsqrtf(shared[0] / hidden_size + eps); + __syncthreads(); + + // Phase 2: Normalize and apply weight + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + row_out[i] = row_pn[i] * rms_inv * weight[i]; + } +} + +// ============================================================================ +// F64 Fused Add + RMSNorm Forward +// ============================================================================ + +__global__ void fused_add_rms_norm_f64( + const double* input, const double* residual, const double* weight, + double* output, double* pre_norm, + unsigned int batch_size, unsigned int hidden_size, double eps +) { + unsigned int row = blockIdx.x; + if (row >= batch_size) return; + + extern __shared__ double shared_f64[]; + + const double* row_in = input + row * hidden_size; + const double* row_res = residual + row * hidden_size; + double* row_pn = pre_norm + row * hidden_size; + double* row_out = output + row * hidden_size; + + double thread_sum = 0.0; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + double pn = row_in[i] + row_res[i]; + row_pn[i] = pn; + thread_sum += pn * pn; + } + shared_f64[threadIdx.x] = thread_sum; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) shared_f64[threadIdx.x] += shared_f64[threadIdx.x + s]; + __syncthreads(); + } + + double rms_inv = rsqrt(shared_f64[0] / hidden_size + eps); + __syncthreads(); + + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + row_out[i] = row_pn[i] * rms_inv * weight[i]; + } +} + +// ============================================================================ +// F16 Fused Add + RMSNorm Forward (FP32 accumulation) +// ============================================================================ + +__global__ void fused_add_rms_norm_f16( + const __half* input, const __half* residual, const __half* weight, + __half* output, __half* pre_norm, + unsigned int batch_size, unsigned int hidden_size, float eps +) { + unsigned int row = blockIdx.x; + if (row >= batch_size) return; + + extern __shared__ float shared[]; + + const __half* row_in = input + row * hidden_size; + const __half* row_res = residual + row * hidden_size; + __half* row_pn = pre_norm + row * hidden_size; + __half* row_out = output + row * hidden_size; + + float thread_sum = 0.0f; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float pn = __half2float(row_in[i]) + __half2float(row_res[i]); + row_pn[i] = __float2half(pn); + thread_sum += pn * pn; + } + shared[threadIdx.x] = thread_sum; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) shared[threadIdx.x] += shared[threadIdx.x + s]; + __syncthreads(); + } + + float rms_inv = rsqrtf(shared[0] / hidden_size + eps); + __syncthreads(); + + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float pn = __half2float(row_pn[i]); + float result = pn * rms_inv * __half2float(weight[i]); + row_out[i] = __float2half(result); + } +} + +// ============================================================================ +// BF16 Fused Add + RMSNorm Forward (FP32 accumulation) +// ============================================================================ + +__global__ void fused_add_rms_norm_bf16( + const __nv_bfloat16* input, const __nv_bfloat16* residual, const __nv_bfloat16* weight, + __nv_bfloat16* output, __nv_bfloat16* pre_norm, + unsigned int batch_size, unsigned int hidden_size, float eps +) { + unsigned int row = blockIdx.x; + if (row >= batch_size) return; + + extern __shared__ float shared[]; + + const __nv_bfloat16* row_in = input + row * hidden_size; + const __nv_bfloat16* row_res = residual + row * hidden_size; + __nv_bfloat16* row_pn = pre_norm + row * hidden_size; + __nv_bfloat16* row_out = output + row * hidden_size; + + float thread_sum = 0.0f; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float pn = __bfloat162float(row_in[i]) + __bfloat162float(row_res[i]); + row_pn[i] = __float2bfloat16(pn); + thread_sum += pn * pn; + } + shared[threadIdx.x] = thread_sum; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) shared[threadIdx.x] += shared[threadIdx.x + s]; + __syncthreads(); + } + + float rms_inv = rsqrtf(shared[0] / hidden_size + eps); + __syncthreads(); + + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float pn = __bfloat162float(row_pn[i]); + float result = pn * rms_inv * __bfloat162float(weight[i]); + row_out[i] = __float2bfloat16(result); + } +} + +// ============================================================================ +// F32 Fused Add + RMSNorm Backward +// ============================================================================ + +__global__ void fused_add_rms_norm_bwd_f32( + const float* grad, const float* pre_norm, const float* weight, + float* d_input_residual, float* d_weight, + unsigned int batch_size, unsigned int hidden_size, float eps +) { + unsigned int row = blockIdx.x; + if (row >= batch_size) return; + + extern __shared__ float shared[]; + float* sum_sq_shared = shared; + float* dot_shared = shared + blockDim.x; + + // Phase 1: Compute sum_sq and dot = sum(grad * weight * pre_norm) + float thread_sq = 0.0f, thread_dot = 0.0f; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float pn = pre_norm[row * hidden_size + i]; + float g = grad[row * hidden_size + i]; + float w = weight[i]; + thread_sq += pn * pn; + thread_dot += g * w * pn; + } + sum_sq_shared[threadIdx.x] = thread_sq; + dot_shared[threadIdx.x] = thread_dot; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + sum_sq_shared[threadIdx.x] += sum_sq_shared[threadIdx.x + s]; + dot_shared[threadIdx.x] += dot_shared[threadIdx.x + s]; + } + __syncthreads(); + } + + float mean_sq = sum_sq_shared[0] / hidden_size; + float inv_rms = rsqrtf(mean_sq + eps); + float dot = dot_shared[0]; + float coeff = dot * inv_rms / (hidden_size * (mean_sq + eps)); + __syncthreads(); + + // Phase 2: Compute d_input_residual and atomicAdd d_weight + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float g = grad[row * hidden_size + i]; + float w = weight[i]; + float pn = pre_norm[row * hidden_size + i]; + d_input_residual[row * hidden_size + i] = (g * w - pn * coeff) * inv_rms; + atomicAdd(&d_weight[i], g * pn * inv_rms); + } +} + +// ============================================================================ +// F64 Fused Add + RMSNorm Backward +// ============================================================================ + +__global__ void fused_add_rms_norm_bwd_f64( + const double* grad, const double* pre_norm, const double* weight, + double* d_input_residual, double* d_weight, + unsigned int batch_size, unsigned int hidden_size, double eps +) { + unsigned int row = blockIdx.x; + if (row >= batch_size) return; + + extern __shared__ double shared_f64[]; + double* sum_sq_shared = shared_f64; + double* dot_shared = shared_f64 + blockDim.x; + + double thread_sq = 0.0, thread_dot = 0.0; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + double pn = pre_norm[row * hidden_size + i]; + double g = grad[row * hidden_size + i]; + double w = weight[i]; + thread_sq += pn * pn; + thread_dot += g * w * pn; + } + sum_sq_shared[threadIdx.x] = thread_sq; + dot_shared[threadIdx.x] = thread_dot; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + sum_sq_shared[threadIdx.x] += sum_sq_shared[threadIdx.x + s]; + dot_shared[threadIdx.x] += dot_shared[threadIdx.x + s]; + } + __syncthreads(); + } + + double mean_sq = sum_sq_shared[0] / hidden_size; + double inv_rms = rsqrt(mean_sq + eps); + double dot = dot_shared[0]; + double coeff = dot * inv_rms / (hidden_size * (mean_sq + eps)); + __syncthreads(); + + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + double g = grad[row * hidden_size + i]; + double w = weight[i]; + double pn = pre_norm[row * hidden_size + i]; + d_input_residual[row * hidden_size + i] = (g * w - pn * coeff) * inv_rms; + atomicAdd(&d_weight[i], g * pn * inv_rms); + } +} + +// ============================================================================ +// F16 Fused Add + RMSNorm Backward (FP32 accumulation) +// ============================================================================ + +__global__ void fused_add_rms_norm_bwd_f16( + const __half* grad, const __half* pre_norm, const __half* weight, + __half* d_input_residual, __half* d_weight, + unsigned int batch_size, unsigned int hidden_size, float eps +) { + unsigned int row = blockIdx.x; + if (row >= batch_size) return; + + extern __shared__ float shared[]; + float* sum_sq_shared = shared; + float* dot_shared = shared + blockDim.x; + + float thread_sq = 0.0f, thread_dot = 0.0f; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float pn = __half2float(pre_norm[row * hidden_size + i]); + float g = __half2float(grad[row * hidden_size + i]); + float w = __half2float(weight[i]); + thread_sq += pn * pn; + thread_dot += g * w * pn; + } + sum_sq_shared[threadIdx.x] = thread_sq; + dot_shared[threadIdx.x] = thread_dot; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + sum_sq_shared[threadIdx.x] += sum_sq_shared[threadIdx.x + s]; + dot_shared[threadIdx.x] += dot_shared[threadIdx.x + s]; + } + __syncthreads(); + } + + float mean_sq = sum_sq_shared[0] / hidden_size; + float inv_rms = rsqrtf(mean_sq + eps); + float dot = dot_shared[0]; + float coeff = dot * inv_rms / (hidden_size * (mean_sq + eps)); + __syncthreads(); + + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float g = __half2float(grad[row * hidden_size + i]); + float w = __half2float(weight[i]); + float pn = __half2float(pre_norm[row * hidden_size + i]); + float dir = (g * w - pn * coeff) * inv_rms; + d_input_residual[row * hidden_size + i] = __float2half(dir); + atomicAddHalf(&d_weight[i], g * pn * inv_rms); + } +} + +// ============================================================================ +// BF16 Fused Add + RMSNorm Backward (FP32 accumulation) +// ============================================================================ + +__global__ void fused_add_rms_norm_bwd_bf16( + const __nv_bfloat16* grad, const __nv_bfloat16* pre_norm, const __nv_bfloat16* weight, + __nv_bfloat16* d_input_residual, __nv_bfloat16* d_weight, + unsigned int batch_size, unsigned int hidden_size, float eps +) { + unsigned int row = blockIdx.x; + if (row >= batch_size) return; + + extern __shared__ float shared[]; + float* sum_sq_shared = shared; + float* dot_shared = shared + blockDim.x; + + float thread_sq = 0.0f, thread_dot = 0.0f; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float pn = __bfloat162float(pre_norm[row * hidden_size + i]); + float g = __bfloat162float(grad[row * hidden_size + i]); + float w = __bfloat162float(weight[i]); + thread_sq += pn * pn; + thread_dot += g * w * pn; + } + sum_sq_shared[threadIdx.x] = thread_sq; + dot_shared[threadIdx.x] = thread_dot; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + sum_sq_shared[threadIdx.x] += sum_sq_shared[threadIdx.x + s]; + dot_shared[threadIdx.x] += dot_shared[threadIdx.x + s]; + } + __syncthreads(); + } + + float mean_sq = sum_sq_shared[0] / hidden_size; + float inv_rms = rsqrtf(mean_sq + eps); + float dot = dot_shared[0]; + float coeff = dot * inv_rms / (hidden_size * (mean_sq + eps)); + __syncthreads(); + + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float g = __bfloat162float(grad[row * hidden_size + i]); + float w = __bfloat162float(weight[i]); + float pn = __bfloat162float(pre_norm[row * hidden_size + i]); + float dir = (g * w - pn * coeff) * inv_rms; + d_input_residual[row * hidden_size + i] = __float2bfloat16(dir); + atomicAddBf16(&d_weight[i], g * pn * inv_rms); + } +} + +// ============================================================================ +// F32 Fused Add + LayerNorm Forward +// ============================================================================ + +__global__ void fused_add_layer_norm_f32( + const float* input, const float* residual, const float* weight, const float* bias, + float* output, float* pre_norm, + unsigned int batch_size, unsigned int hidden_size, float eps +) { + unsigned int row = blockIdx.x; + if (row >= batch_size) return; + + extern __shared__ float shared[]; + float* mean_shared = shared; + float* var_shared = shared + blockDim.x; + + const float* row_in = input + row * hidden_size; + const float* row_res = residual + row * hidden_size; + float* row_pn = pre_norm + row * hidden_size; + float* row_out = output + row * hidden_size; + + // Phase 1: Add residual + compute mean + float thread_sum = 0.0f; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float pn = row_in[i] + row_res[i]; + row_pn[i] = pn; + thread_sum += pn; + } + mean_shared[threadIdx.x] = thread_sum; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) mean_shared[threadIdx.x] += mean_shared[threadIdx.x + s]; + __syncthreads(); + } + float mean = mean_shared[0] / hidden_size; + __syncthreads(); + + // Phase 2: Compute variance + float thread_var = 0.0f; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float diff = row_pn[i] - mean; + thread_var += diff * diff; + } + var_shared[threadIdx.x] = thread_var; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) var_shared[threadIdx.x] += var_shared[threadIdx.x + s]; + __syncthreads(); + } + float inv_std = rsqrtf(var_shared[0] / hidden_size + eps); + __syncthreads(); + + // Phase 3: Normalize and apply affine + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float normalized = (row_pn[i] - mean) * inv_std; + row_out[i] = normalized * weight[i] + bias[i]; + } +} + +// ============================================================================ +// F64 Fused Add + LayerNorm Forward +// ============================================================================ + +__global__ void fused_add_layer_norm_f64( + const double* input, const double* residual, const double* weight, const double* bias, + double* output, double* pre_norm, + unsigned int batch_size, unsigned int hidden_size, double eps +) { + unsigned int row = blockIdx.x; + if (row >= batch_size) return; + + extern __shared__ double shared_f64[]; + double* mean_shared = shared_f64; + double* var_shared = shared_f64 + blockDim.x; + + const double* row_in = input + row * hidden_size; + const double* row_res = residual + row * hidden_size; + double* row_pn = pre_norm + row * hidden_size; + double* row_out = output + row * hidden_size; + + double thread_sum = 0.0; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + double pn = row_in[i] + row_res[i]; + row_pn[i] = pn; + thread_sum += pn; + } + mean_shared[threadIdx.x] = thread_sum; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) mean_shared[threadIdx.x] += mean_shared[threadIdx.x + s]; + __syncthreads(); + } + double mean = mean_shared[0] / hidden_size; + __syncthreads(); + + double thread_var = 0.0; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + double diff = row_pn[i] - mean; + thread_var += diff * diff; + } + var_shared[threadIdx.x] = thread_var; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) var_shared[threadIdx.x] += var_shared[threadIdx.x + s]; + __syncthreads(); + } + double inv_std = rsqrt(var_shared[0] / hidden_size + eps); + __syncthreads(); + + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + double normalized = (row_pn[i] - mean) * inv_std; + row_out[i] = normalized * weight[i] + bias[i]; + } +} + +// ============================================================================ +// F16 Fused Add + LayerNorm Forward (FP32 accumulation) +// ============================================================================ + +__global__ void fused_add_layer_norm_f16( + const __half* input, const __half* residual, const __half* weight, const __half* bias, + __half* output, __half* pre_norm, + unsigned int batch_size, unsigned int hidden_size, float eps +) { + unsigned int row = blockIdx.x; + if (row >= batch_size) return; + + extern __shared__ float shared[]; + float* mean_shared = shared; + float* var_shared = shared + blockDim.x; + + const __half* row_in = input + row * hidden_size; + const __half* row_res = residual + row * hidden_size; + __half* row_pn = pre_norm + row * hidden_size; + __half* row_out = output + row * hidden_size; + + float thread_sum = 0.0f; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float pn = __half2float(row_in[i]) + __half2float(row_res[i]); + row_pn[i] = __float2half(pn); + thread_sum += pn; + } + mean_shared[threadIdx.x] = thread_sum; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) mean_shared[threadIdx.x] += mean_shared[threadIdx.x + s]; + __syncthreads(); + } + float mean = mean_shared[0] / hidden_size; + __syncthreads(); + + float thread_var = 0.0f; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float diff = __half2float(row_pn[i]) - mean; + thread_var += diff * diff; + } + var_shared[threadIdx.x] = thread_var; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) var_shared[threadIdx.x] += var_shared[threadIdx.x + s]; + __syncthreads(); + } + float inv_std = rsqrtf(var_shared[0] / hidden_size + eps); + __syncthreads(); + + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float normalized = (__half2float(row_pn[i]) - mean) * inv_std; + float result = normalized * __half2float(weight[i]) + __half2float(bias[i]); + row_out[i] = __float2half(result); + } +} + +// ============================================================================ +// BF16 Fused Add + LayerNorm Forward (FP32 accumulation) +// ============================================================================ + +__global__ void fused_add_layer_norm_bf16( + const __nv_bfloat16* input, const __nv_bfloat16* residual, const __nv_bfloat16* weight, const __nv_bfloat16* bias, + __nv_bfloat16* output, __nv_bfloat16* pre_norm, + unsigned int batch_size, unsigned int hidden_size, float eps +) { + unsigned int row = blockIdx.x; + if (row >= batch_size) return; + + extern __shared__ float shared[]; + float* mean_shared = shared; + float* var_shared = shared + blockDim.x; + + const __nv_bfloat16* row_in = input + row * hidden_size; + const __nv_bfloat16* row_res = residual + row * hidden_size; + __nv_bfloat16* row_pn = pre_norm + row * hidden_size; + __nv_bfloat16* row_out = output + row * hidden_size; + + float thread_sum = 0.0f; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float pn = __bfloat162float(row_in[i]) + __bfloat162float(row_res[i]); + row_pn[i] = __float2bfloat16(pn); + thread_sum += pn; + } + mean_shared[threadIdx.x] = thread_sum; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) mean_shared[threadIdx.x] += mean_shared[threadIdx.x + s]; + __syncthreads(); + } + float mean = mean_shared[0] / hidden_size; + __syncthreads(); + + float thread_var = 0.0f; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float diff = __bfloat162float(row_pn[i]) - mean; + thread_var += diff * diff; + } + var_shared[threadIdx.x] = thread_var; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) var_shared[threadIdx.x] += var_shared[threadIdx.x + s]; + __syncthreads(); + } + float inv_std = rsqrtf(var_shared[0] / hidden_size + eps); + __syncthreads(); + + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float normalized = (__bfloat162float(row_pn[i]) - mean) * inv_std; + float result = normalized * __bfloat162float(weight[i]) + __bfloat162float(bias[i]); + row_out[i] = __float2bfloat16(result); + } +} + +// ============================================================================ +// F32 Fused Add + LayerNorm Backward +// ============================================================================ + +__global__ void fused_add_layer_norm_bwd_f32( + const float* grad, const float* pre_norm, const float* weight, + float* d_input_residual, float* d_weight, float* d_bias, + unsigned int batch_size, unsigned int hidden_size, float eps +) { + unsigned int row = blockIdx.x; + if (row >= batch_size) return; + + extern __shared__ float shared[]; + float* mean_shared = shared; + float* var_shared = shared + blockDim.x; + float* gs_shared = shared + 2 * blockDim.x; + float* gsn_shared = shared + 3 * blockDim.x; + + // Phase 1: Compute mean + float thread_sum = 0.0f; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + thread_sum += pre_norm[row * hidden_size + i]; + } + mean_shared[threadIdx.x] = thread_sum; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) mean_shared[threadIdx.x] += mean_shared[threadIdx.x + s]; + __syncthreads(); + } + float mean = mean_shared[0] / hidden_size; + __syncthreads(); + + // Phase 2: Compute variance + float thread_var = 0.0f; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float diff = pre_norm[row * hidden_size + i] - mean; + thread_var += diff * diff; + } + var_shared[threadIdx.x] = thread_var; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) var_shared[threadIdx.x] += var_shared[threadIdx.x + s]; + __syncthreads(); + } + float var = var_shared[0] / hidden_size; + float inv_std = rsqrtf(var + eps); + __syncthreads(); + + // Phase 3: mean_gs and mean_gsn + float thread_gs = 0.0f, thread_gsn = 0.0f; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float g = grad[row * hidden_size + i]; + float w = weight[i]; + float normalized = (pre_norm[row * hidden_size + i] - mean) * inv_std; + thread_gs += g * w; + thread_gsn += g * w * normalized; + } + gs_shared[threadIdx.x] = thread_gs; + gsn_shared[threadIdx.x] = thread_gsn; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + gs_shared[threadIdx.x] += gs_shared[threadIdx.x + s]; + gsn_shared[threadIdx.x] += gsn_shared[threadIdx.x + s]; + } + __syncthreads(); + } + float mean_gs = gs_shared[0] / hidden_size; + float mean_gsn = gsn_shared[0] / hidden_size; + __syncthreads(); + + // Phase 4: Compute gradients + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float g = grad[row * hidden_size + i]; + float w = weight[i]; + float normalized = (pre_norm[row * hidden_size + i] - mean) * inv_std; + float d_ir = inv_std * (g * w - mean_gs - normalized * mean_gsn); + d_input_residual[row * hidden_size + i] = d_ir; + atomicAdd(&d_weight[i], g * normalized); + atomicAdd(&d_bias[i], g); + } +} + +// ============================================================================ +// F64 Fused Add + LayerNorm Backward +// ============================================================================ + +__global__ void fused_add_layer_norm_bwd_f64( + const double* grad, const double* pre_norm, const double* weight, + double* d_input_residual, double* d_weight, double* d_bias, + unsigned int batch_size, unsigned int hidden_size, double eps +) { + unsigned int row = blockIdx.x; + if (row >= batch_size) return; + + extern __shared__ double shared_f64[]; + double* mean_shared = shared_f64; + double* var_shared = shared_f64 + blockDim.x; + double* gs_shared = shared_f64 + 2 * blockDim.x; + double* gsn_shared = shared_f64 + 3 * blockDim.x; + + double thread_sum = 0.0; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + thread_sum += pre_norm[row * hidden_size + i]; + } + mean_shared[threadIdx.x] = thread_sum; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) mean_shared[threadIdx.x] += mean_shared[threadIdx.x + s]; + __syncthreads(); + } + double mean = mean_shared[0] / hidden_size; + __syncthreads(); + + double thread_var = 0.0; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + double diff = pre_norm[row * hidden_size + i] - mean; + thread_var += diff * diff; + } + var_shared[threadIdx.x] = thread_var; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) var_shared[threadIdx.x] += var_shared[threadIdx.x + s]; + __syncthreads(); + } + double var = var_shared[0] / hidden_size; + double inv_std = rsqrt(var + eps); + __syncthreads(); + + double thread_gs = 0.0, thread_gsn = 0.0; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + double g = grad[row * hidden_size + i]; + double w = weight[i]; + double normalized = (pre_norm[row * hidden_size + i] - mean) * inv_std; + thread_gs += g * w; + thread_gsn += g * w * normalized; + } + gs_shared[threadIdx.x] = thread_gs; + gsn_shared[threadIdx.x] = thread_gsn; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + gs_shared[threadIdx.x] += gs_shared[threadIdx.x + s]; + gsn_shared[threadIdx.x] += gsn_shared[threadIdx.x + s]; + } + __syncthreads(); + } + double mean_gs = gs_shared[0] / hidden_size; + double mean_gsn = gsn_shared[0] / hidden_size; + __syncthreads(); + + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + double g = grad[row * hidden_size + i]; + double w = weight[i]; + double normalized = (pre_norm[row * hidden_size + i] - mean) * inv_std; + double d_ir = inv_std * (g * w - mean_gs - normalized * mean_gsn); + d_input_residual[row * hidden_size + i] = d_ir; + atomicAdd(&d_weight[i], g * normalized); + atomicAdd(&d_bias[i], g); + } +} + +// ============================================================================ +// F16 Fused Add + LayerNorm Backward (FP32 accumulation) +// ============================================================================ + +__global__ void fused_add_layer_norm_bwd_f16( + const __half* grad, const __half* pre_norm, const __half* weight, + __half* d_input_residual, __half* d_weight, __half* d_bias, + unsigned int batch_size, unsigned int hidden_size, float eps +) { + unsigned int row = blockIdx.x; + if (row >= batch_size) return; + + extern __shared__ float shared[]; + float* mean_shared = shared; + float* var_shared = shared + blockDim.x; + float* gs_shared = shared + 2 * blockDim.x; + float* gsn_shared = shared + 3 * blockDim.x; + + float thread_sum = 0.0f; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + thread_sum += __half2float(pre_norm[row * hidden_size + i]); + } + mean_shared[threadIdx.x] = thread_sum; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) mean_shared[threadIdx.x] += mean_shared[threadIdx.x + s]; + __syncthreads(); + } + float mean = mean_shared[0] / hidden_size; + __syncthreads(); + + float thread_var = 0.0f; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float diff = __half2float(pre_norm[row * hidden_size + i]) - mean; + thread_var += diff * diff; + } + var_shared[threadIdx.x] = thread_var; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) var_shared[threadIdx.x] += var_shared[threadIdx.x + s]; + __syncthreads(); + } + float var = var_shared[0] / hidden_size; + float inv_std = rsqrtf(var + eps); + __syncthreads(); + + float thread_gs = 0.0f, thread_gsn = 0.0f; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float g = __half2float(grad[row * hidden_size + i]); + float w = __half2float(weight[i]); + float normalized = (__half2float(pre_norm[row * hidden_size + i]) - mean) * inv_std; + thread_gs += g * w; + thread_gsn += g * w * normalized; + } + gs_shared[threadIdx.x] = thread_gs; + gsn_shared[threadIdx.x] = thread_gsn; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + gs_shared[threadIdx.x] += gs_shared[threadIdx.x + s]; + gsn_shared[threadIdx.x] += gsn_shared[threadIdx.x + s]; + } + __syncthreads(); + } + float mean_gs = gs_shared[0] / hidden_size; + float mean_gsn = gsn_shared[0] / hidden_size; + __syncthreads(); + + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float g = __half2float(grad[row * hidden_size + i]); + float w = __half2float(weight[i]); + float normalized = (__half2float(pre_norm[row * hidden_size + i]) - mean) * inv_std; + float d_ir = inv_std * (g * w - mean_gs - normalized * mean_gsn); + d_input_residual[row * hidden_size + i] = __float2half(d_ir); + atomicAddHalf(&d_weight[i], g * normalized); + atomicAddHalf(&d_bias[i], g); + } +} + +// ============================================================================ +// BF16 Fused Add + LayerNorm Backward (FP32 accumulation) +// ============================================================================ + +__global__ void fused_add_layer_norm_bwd_bf16( + const __nv_bfloat16* grad, const __nv_bfloat16* pre_norm, const __nv_bfloat16* weight, + __nv_bfloat16* d_input_residual, __nv_bfloat16* d_weight, __nv_bfloat16* d_bias, + unsigned int batch_size, unsigned int hidden_size, float eps +) { + unsigned int row = blockIdx.x; + if (row >= batch_size) return; + + extern __shared__ float shared[]; + float* mean_shared = shared; + float* var_shared = shared + blockDim.x; + float* gs_shared = shared + 2 * blockDim.x; + float* gsn_shared = shared + 3 * blockDim.x; + + float thread_sum = 0.0f; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + thread_sum += __bfloat162float(pre_norm[row * hidden_size + i]); + } + mean_shared[threadIdx.x] = thread_sum; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) mean_shared[threadIdx.x] += mean_shared[threadIdx.x + s]; + __syncthreads(); + } + float mean = mean_shared[0] / hidden_size; + __syncthreads(); + + float thread_var = 0.0f; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float diff = __bfloat162float(pre_norm[row * hidden_size + i]) - mean; + thread_var += diff * diff; + } + var_shared[threadIdx.x] = thread_var; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) var_shared[threadIdx.x] += var_shared[threadIdx.x + s]; + __syncthreads(); + } + float var = var_shared[0] / hidden_size; + float inv_std = rsqrtf(var + eps); + __syncthreads(); + + float thread_gs = 0.0f, thread_gsn = 0.0f; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float g = __bfloat162float(grad[row * hidden_size + i]); + float w = __bfloat162float(weight[i]); + float normalized = (__bfloat162float(pre_norm[row * hidden_size + i]) - mean) * inv_std; + thread_gs += g * w; + thread_gsn += g * w * normalized; + } + gs_shared[threadIdx.x] = thread_gs; + gsn_shared[threadIdx.x] = thread_gsn; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + gs_shared[threadIdx.x] += gs_shared[threadIdx.x + s]; + gsn_shared[threadIdx.x] += gsn_shared[threadIdx.x + s]; + } + __syncthreads(); + } + float mean_gs = gs_shared[0] / hidden_size; + float mean_gsn = gsn_shared[0] / hidden_size; + __syncthreads(); + + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float g = __bfloat162float(grad[row * hidden_size + i]); + float w = __bfloat162float(weight[i]); + float normalized = (__bfloat162float(pre_norm[row * hidden_size + i]) - mean) * inv_std; + float d_ir = inv_std * (g * w - mean_gs - normalized * mean_gsn); + d_input_residual[row * hidden_size + i] = __float2bfloat16(d_ir); + atomicAddBf16(&d_weight[i], g * normalized); + atomicAddBf16(&d_bias[i], g); + } +} + +// ============================================================================ +// Helper: atomicAdd for FP8 types via 32-bit atomicCAS +// ============================================================================ + +__device__ void atomicAddFp8E4M3(numr_fp8_e4m3* address, float val) { + // FP8 is 1 byte — use 32-bit atomicCAS on the containing 4-byte word + unsigned int* base = (unsigned int*)((size_t)address & ~3ULL); + unsigned int byte_offset = (unsigned int)((size_t)address & 3); + unsigned int shift = byte_offset * 8; + unsigned int old_word = *base, assumed; + do { + assumed = old_word; + uint8_t old_byte = (uint8_t)((assumed >> shift) & 0xFF); + float old_float = fp8_e4m3_to_f32(old_byte); + uint8_t new_byte = f32_to_fp8_e4m3(old_float + val); + unsigned int new_word = (assumed & ~(0xFFu << shift)) | ((unsigned int)new_byte << shift); + old_word = atomicCAS(base, assumed, new_word); + } while (assumed != old_word); +} + +__device__ void atomicAddFp8E5M2(numr_fp8_e5m2* address, float val) { + unsigned int* base = (unsigned int*)((size_t)address & ~3ULL); + unsigned int byte_offset = (unsigned int)((size_t)address & 3); + unsigned int shift = byte_offset * 8; + unsigned int old_word = *base, assumed; + do { + assumed = old_word; + uint8_t old_byte = (uint8_t)((assumed >> shift) & 0xFF); + float old_float = fp8_e5m2_to_f32(old_byte); + uint8_t new_byte = f32_to_fp8_e5m2(old_float + val); + unsigned int new_word = (assumed & ~(0xFFu << shift)) | ((unsigned int)new_byte << shift); + old_word = atomicCAS(base, assumed, new_word); + } while (assumed != old_word); +} + +// ============================================================================ +// FP8 E4M3 Fused Add + RMSNorm Forward +// ============================================================================ + +__global__ void fused_add_rms_norm_fp8_e4m3( + const numr_fp8_e4m3* input, const numr_fp8_e4m3* residual, const numr_fp8_e4m3* weight, + numr_fp8_e4m3* output, numr_fp8_e4m3* pre_norm, + unsigned int batch_size, unsigned int hidden_size, float eps +) { + unsigned int row = blockIdx.x; + if (row >= batch_size) return; + + extern __shared__ float shared[]; + + float thread_sum = 0.0f; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float pn = fp8_e4m3_to_f32(input[row * hidden_size + i].data) + + fp8_e4m3_to_f32(residual[row * hidden_size + i].data); + pre_norm[row * hidden_size + i].data = f32_to_fp8_e4m3(pn); + thread_sum += pn * pn; + } + shared[threadIdx.x] = thread_sum; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) shared[threadIdx.x] += shared[threadIdx.x + s]; + __syncthreads(); + } + + float rms_inv = rsqrtf(shared[0] / hidden_size + eps); + __syncthreads(); + + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float pn = fp8_e4m3_to_f32(pre_norm[row * hidden_size + i].data); + float w = fp8_e4m3_to_f32(weight[i].data); + output[row * hidden_size + i].data = f32_to_fp8_e4m3(pn * rms_inv * w); + } +} + +__global__ void fused_add_rms_norm_fp8_e5m2( + const numr_fp8_e5m2* input, const numr_fp8_e5m2* residual, const numr_fp8_e5m2* weight, + numr_fp8_e5m2* output, numr_fp8_e5m2* pre_norm, + unsigned int batch_size, unsigned int hidden_size, float eps +) { + unsigned int row = blockIdx.x; + if (row >= batch_size) return; + + extern __shared__ float shared[]; + + float thread_sum = 0.0f; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float pn = fp8_e5m2_to_f32(input[row * hidden_size + i].data) + + fp8_e5m2_to_f32(residual[row * hidden_size + i].data); + pre_norm[row * hidden_size + i].data = f32_to_fp8_e5m2(pn); + thread_sum += pn * pn; + } + shared[threadIdx.x] = thread_sum; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) shared[threadIdx.x] += shared[threadIdx.x + s]; + __syncthreads(); + } + + float rms_inv = rsqrtf(shared[0] / hidden_size + eps); + __syncthreads(); + + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float pn = fp8_e5m2_to_f32(pre_norm[row * hidden_size + i].data); + float w = fp8_e5m2_to_f32(weight[i].data); + output[row * hidden_size + i].data = f32_to_fp8_e5m2(pn * rms_inv * w); + } +} + +// ============================================================================ +// FP8 E4M3 Fused Add + RMSNorm Backward +// ============================================================================ + +__global__ void fused_add_rms_norm_bwd_fp8_e4m3( + const numr_fp8_e4m3* grad, const numr_fp8_e4m3* pre_norm, const numr_fp8_e4m3* weight, + numr_fp8_e4m3* d_input_residual, numr_fp8_e4m3* d_weight, + unsigned int batch_size, unsigned int hidden_size, float eps +) { + unsigned int row = blockIdx.x; + if (row >= batch_size) return; + + extern __shared__ float shared[]; + float* sum_sq_shared = shared; + float* dot_shared = shared + blockDim.x; + + float thread_sq = 0.0f, thread_dot = 0.0f; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float pn = fp8_e4m3_to_f32(pre_norm[row * hidden_size + i].data); + float g = fp8_e4m3_to_f32(grad[row * hidden_size + i].data); + float w = fp8_e4m3_to_f32(weight[i].data); + thread_sq += pn * pn; + thread_dot += g * w * pn; + } + sum_sq_shared[threadIdx.x] = thread_sq; + dot_shared[threadIdx.x] = thread_dot; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + sum_sq_shared[threadIdx.x] += sum_sq_shared[threadIdx.x + s]; + dot_shared[threadIdx.x] += dot_shared[threadIdx.x + s]; + } + __syncthreads(); + } + + float mean_sq = sum_sq_shared[0] / hidden_size; + float inv_rms = rsqrtf(mean_sq + eps); + float dot = dot_shared[0]; + float coeff = dot * inv_rms / (hidden_size * (mean_sq + eps)); + __syncthreads(); + + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float g = fp8_e4m3_to_f32(grad[row * hidden_size + i].data); + float w = fp8_e4m3_to_f32(weight[i].data); + float pn = fp8_e4m3_to_f32(pre_norm[row * hidden_size + i].data); + float dir = (g * w - pn * coeff) * inv_rms; + d_input_residual[row * hidden_size + i].data = f32_to_fp8_e4m3(dir); + atomicAddFp8E4M3(&d_weight[i], g * pn * inv_rms); + } +} + +__global__ void fused_add_rms_norm_bwd_fp8_e5m2( + const numr_fp8_e5m2* grad, const numr_fp8_e5m2* pre_norm, const numr_fp8_e5m2* weight, + numr_fp8_e5m2* d_input_residual, numr_fp8_e5m2* d_weight, + unsigned int batch_size, unsigned int hidden_size, float eps +) { + unsigned int row = blockIdx.x; + if (row >= batch_size) return; + + extern __shared__ float shared[]; + float* sum_sq_shared = shared; + float* dot_shared = shared + blockDim.x; + + float thread_sq = 0.0f, thread_dot = 0.0f; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float pn = fp8_e5m2_to_f32(pre_norm[row * hidden_size + i].data); + float g = fp8_e5m2_to_f32(grad[row * hidden_size + i].data); + float w = fp8_e5m2_to_f32(weight[i].data); + thread_sq += pn * pn; + thread_dot += g * w * pn; + } + sum_sq_shared[threadIdx.x] = thread_sq; + dot_shared[threadIdx.x] = thread_dot; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + sum_sq_shared[threadIdx.x] += sum_sq_shared[threadIdx.x + s]; + dot_shared[threadIdx.x] += dot_shared[threadIdx.x + s]; + } + __syncthreads(); + } + + float mean_sq = sum_sq_shared[0] / hidden_size; + float inv_rms = rsqrtf(mean_sq + eps); + float dot = dot_shared[0]; + float coeff = dot * inv_rms / (hidden_size * (mean_sq + eps)); + __syncthreads(); + + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float g = fp8_e5m2_to_f32(grad[row * hidden_size + i].data); + float w = fp8_e5m2_to_f32(weight[i].data); + float pn = fp8_e5m2_to_f32(pre_norm[row * hidden_size + i].data); + float dir = (g * w - pn * coeff) * inv_rms; + d_input_residual[row * hidden_size + i].data = f32_to_fp8_e5m2(dir); + atomicAddFp8E5M2(&d_weight[i], g * pn * inv_rms); + } +} + +// ============================================================================ +// FP8 E4M3 Fused Add + LayerNorm Forward +// ============================================================================ + +__global__ void fused_add_layer_norm_fp8_e4m3( + const numr_fp8_e4m3* input, const numr_fp8_e4m3* residual, + const numr_fp8_e4m3* weight, const numr_fp8_e4m3* bias, + numr_fp8_e4m3* output, numr_fp8_e4m3* pre_norm, + unsigned int batch_size, unsigned int hidden_size, float eps +) { + unsigned int row = blockIdx.x; + if (row >= batch_size) return; + + extern __shared__ float shared[]; + float* mean_shared = shared; + float* var_shared = shared + blockDim.x; + + // Phase 1: Add residual + compute mean + float thread_sum = 0.0f; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float pn = fp8_e4m3_to_f32(input[row * hidden_size + i].data) + + fp8_e4m3_to_f32(residual[row * hidden_size + i].data); + pre_norm[row * hidden_size + i].data = f32_to_fp8_e4m3(pn); + thread_sum += pn; + } + mean_shared[threadIdx.x] = thread_sum; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) mean_shared[threadIdx.x] += mean_shared[threadIdx.x + s]; + __syncthreads(); + } + float mean = mean_shared[0] / hidden_size; + __syncthreads(); + + // Phase 2: Compute variance + float thread_var = 0.0f; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float pn = fp8_e4m3_to_f32(pre_norm[row * hidden_size + i].data); + float diff = pn - mean; + thread_var += diff * diff; + } + var_shared[threadIdx.x] = thread_var; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) var_shared[threadIdx.x] += var_shared[threadIdx.x + s]; + __syncthreads(); + } + float inv_std = rsqrtf(var_shared[0] / hidden_size + eps); + __syncthreads(); + + // Phase 3: Normalize + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float pn = fp8_e4m3_to_f32(pre_norm[row * hidden_size + i].data); + float w = fp8_e4m3_to_f32(weight[i].data); + float b = fp8_e4m3_to_f32(bias[i].data); + float normalized = (pn - mean) * inv_std; + output[row * hidden_size + i].data = f32_to_fp8_e4m3(normalized * w + b); + } +} + +__global__ void fused_add_layer_norm_fp8_e5m2( + const numr_fp8_e5m2* input, const numr_fp8_e5m2* residual, + const numr_fp8_e5m2* weight, const numr_fp8_e5m2* bias, + numr_fp8_e5m2* output, numr_fp8_e5m2* pre_norm, + unsigned int batch_size, unsigned int hidden_size, float eps +) { + unsigned int row = blockIdx.x; + if (row >= batch_size) return; + + extern __shared__ float shared[]; + float* mean_shared = shared; + float* var_shared = shared + blockDim.x; + + float thread_sum = 0.0f; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float pn = fp8_e5m2_to_f32(input[row * hidden_size + i].data) + + fp8_e5m2_to_f32(residual[row * hidden_size + i].data); + pre_norm[row * hidden_size + i].data = f32_to_fp8_e5m2(pn); + thread_sum += pn; + } + mean_shared[threadIdx.x] = thread_sum; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) mean_shared[threadIdx.x] += mean_shared[threadIdx.x + s]; + __syncthreads(); + } + float mean = mean_shared[0] / hidden_size; + __syncthreads(); + + float thread_var = 0.0f; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float pn = fp8_e5m2_to_f32(pre_norm[row * hidden_size + i].data); + float diff = pn - mean; + thread_var += diff * diff; + } + var_shared[threadIdx.x] = thread_var; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) var_shared[threadIdx.x] += var_shared[threadIdx.x + s]; + __syncthreads(); + } + float inv_std = rsqrtf(var_shared[0] / hidden_size + eps); + __syncthreads(); + + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float pn = fp8_e5m2_to_f32(pre_norm[row * hidden_size + i].data); + float w = fp8_e5m2_to_f32(weight[i].data); + float b = fp8_e5m2_to_f32(bias[i].data); + float normalized = (pn - mean) * inv_std; + output[row * hidden_size + i].data = f32_to_fp8_e5m2(normalized * w + b); + } +} + +// ============================================================================ +// FP8 E4M3 Fused Add + LayerNorm Backward +// ============================================================================ + +__global__ void fused_add_layer_norm_bwd_fp8_e4m3( + const numr_fp8_e4m3* grad, const numr_fp8_e4m3* pre_norm, + const numr_fp8_e4m3* weight, + numr_fp8_e4m3* d_input_residual, numr_fp8_e4m3* d_weight, numr_fp8_e4m3* d_bias, + unsigned int batch_size, unsigned int hidden_size, float eps +) { + unsigned int row = blockIdx.x; + if (row >= batch_size) return; + + extern __shared__ float shared[]; + float* mean_shared = shared; + float* var_shared = shared + blockDim.x; + float* gs_shared = shared + 2 * blockDim.x; + float* gsn_shared = shared + 3 * blockDim.x; + + // Phase 1: Compute mean + float thread_sum = 0.0f; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + thread_sum += fp8_e4m3_to_f32(pre_norm[row * hidden_size + i].data); + } + mean_shared[threadIdx.x] = thread_sum; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) mean_shared[threadIdx.x] += mean_shared[threadIdx.x + s]; + __syncthreads(); + } + float mean = mean_shared[0] / hidden_size; + __syncthreads(); + + // Phase 2: Compute variance + dot products + float thread_var = 0.0f, thread_gs = 0.0f, thread_gsn = 0.0f; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float pn = fp8_e4m3_to_f32(pre_norm[row * hidden_size + i].data); + float g = fp8_e4m3_to_f32(grad[row * hidden_size + i].data); + float w = fp8_e4m3_to_f32(weight[i].data); + float diff = pn - mean; + thread_var += diff * diff; + thread_gs += g * w; + thread_gsn += g * w * diff; + } + var_shared[threadIdx.x] = thread_var; + gs_shared[threadIdx.x] = thread_gs; + gsn_shared[threadIdx.x] = thread_gsn; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + var_shared[threadIdx.x] += var_shared[threadIdx.x + s]; + gs_shared[threadIdx.x] += gs_shared[threadIdx.x + s]; + gsn_shared[threadIdx.x] += gsn_shared[threadIdx.x + s]; + } + __syncthreads(); + } + + float inv_std = rsqrtf(var_shared[0] / hidden_size + eps); + float mean_gs = gs_shared[0] / hidden_size; + float mean_gsn = gsn_shared[0] / hidden_size; + __syncthreads(); + + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float g = fp8_e4m3_to_f32(grad[row * hidden_size + i].data); + float w = fp8_e4m3_to_f32(weight[i].data); + float normalized = (fp8_e4m3_to_f32(pre_norm[row * hidden_size + i].data) - mean) * inv_std; + float d_ir = inv_std * (g * w - mean_gs - normalized * mean_gsn); + d_input_residual[row * hidden_size + i].data = f32_to_fp8_e4m3(d_ir); + atomicAddFp8E4M3(&d_weight[i], g * normalized); + atomicAddFp8E4M3(&d_bias[i], g); + } +} + +__global__ void fused_add_layer_norm_bwd_fp8_e5m2( + const numr_fp8_e5m2* grad, const numr_fp8_e5m2* pre_norm, + const numr_fp8_e5m2* weight, + numr_fp8_e5m2* d_input_residual, numr_fp8_e5m2* d_weight, numr_fp8_e5m2* d_bias, + unsigned int batch_size, unsigned int hidden_size, float eps +) { + unsigned int row = blockIdx.x; + if (row >= batch_size) return; + + extern __shared__ float shared[]; + float* mean_shared = shared; + float* var_shared = shared + blockDim.x; + float* gs_shared = shared + 2 * blockDim.x; + float* gsn_shared = shared + 3 * blockDim.x; + + float thread_sum = 0.0f; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + thread_sum += fp8_e5m2_to_f32(pre_norm[row * hidden_size + i].data); + } + mean_shared[threadIdx.x] = thread_sum; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) mean_shared[threadIdx.x] += mean_shared[threadIdx.x + s]; + __syncthreads(); + } + float mean = mean_shared[0] / hidden_size; + __syncthreads(); + + float thread_var = 0.0f, thread_gs = 0.0f, thread_gsn = 0.0f; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float pn = fp8_e5m2_to_f32(pre_norm[row * hidden_size + i].data); + float g = fp8_e5m2_to_f32(grad[row * hidden_size + i].data); + float w = fp8_e5m2_to_f32(weight[i].data); + float diff = pn - mean; + thread_var += diff * diff; + thread_gs += g * w; + thread_gsn += g * w * diff; + } + var_shared[threadIdx.x] = thread_var; + gs_shared[threadIdx.x] = thread_gs; + gsn_shared[threadIdx.x] = thread_gsn; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + var_shared[threadIdx.x] += var_shared[threadIdx.x + s]; + gs_shared[threadIdx.x] += gs_shared[threadIdx.x + s]; + gsn_shared[threadIdx.x] += gsn_shared[threadIdx.x + s]; + } + __syncthreads(); + } + + float inv_std = rsqrtf(var_shared[0] / hidden_size + eps); + float mean_gs = gs_shared[0] / hidden_size; + float mean_gsn = gsn_shared[0] / hidden_size; + __syncthreads(); + + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float g = fp8_e5m2_to_f32(grad[row * hidden_size + i].data); + float w = fp8_e5m2_to_f32(weight[i].data); + float normalized = (fp8_e5m2_to_f32(pre_norm[row * hidden_size + i].data) - mean) * inv_std; + float d_ir = inv_std * (g * w - mean_gs - normalized * mean_gsn); + d_input_residual[row * hidden_size + i].data = f32_to_fp8_e5m2(d_ir); + atomicAddFp8E5M2(&d_weight[i], g * normalized); + atomicAddFp8E5M2(&d_bias[i], g); + } +} + +} // extern "C" diff --git a/src/runtime/cuda/kernels/fused_add_norm.rs b/src/runtime/cuda/kernels/fused_add_norm.rs new file mode 100644 index 00000000..c5fe468d --- /dev/null +++ b/src/runtime/cuda/kernels/fused_add_norm.rs @@ -0,0 +1,329 @@ +//! Fused Add + Normalization CUDA kernel launchers +//! +//! Provides launchers for fused operations combining residual addition with normalization. +//! These operations are common in transformer architectures for efficient computation. + +use cudarc::driver::PushKernelArg; +use cudarc::driver::safe::{CudaContext, CudaStream}; +use std::sync::Arc; + +use super::loader::{ + BLOCK_SIZE, get_kernel_function, get_or_load_module, kernel_name, kernel_names, launch_config, +}; +use crate::dtype::DType; +use crate::error::{Error, Result}; + +/// Calculate launch configuration for fused normalization kernels. +/// +/// One block per row (batch element), with threads cooperating to compute statistics. +/// Returns (grid_size, block_size, shared_memory_bytes). +#[inline] +fn fused_norm_launch_config( + batch_size: usize, + hidden_size: usize, + shared_arrays: usize, + dtype: DType, +) -> (u32, u32, u32) { + let block_size = BLOCK_SIZE.min(hidden_size as u32); + let grid_size = batch_size as u32; + let elem_size = match dtype { + DType::F64 => 8u32, + _ => 4u32, // f32, f16, bf16 all use f32 shared memory + }; + let shared_mem = (shared_arrays as u32) * block_size * elem_size; + (grid_size, block_size, shared_mem) +} + +/// Launch a fused_add_rms_norm forward kernel. +/// +/// Computes: `pre_norm = input + residual`, then `output = pre_norm * rsqrt(mean(pre_norm^2) + eps) * weight` +/// +/// # Arguments +/// +/// * `input_ptr` - Device pointer to input tensor of shape [batch_size, hidden_size] +/// * `residual_ptr` - Device pointer to residual tensor of shape [batch_size, hidden_size] +/// * `weight_ptr` - Device pointer to weight tensor of shape [hidden_size] +/// * `output_ptr` - Device pointer to output tensor of shape [batch_size, hidden_size] +/// * `pre_norm_ptr` - Device pointer to pre-normalization tensor of shape [batch_size, hidden_size] +/// * `batch_size` - Number of rows (batch dimension) +/// * `hidden_size` - Size of each row (hidden dimension) +/// * `eps` - Small constant for numerical stability +/// +/// # Safety +/// +/// - All pointers must be valid device memory +/// - All tensors must have `batch_size * hidden_size` elements +pub unsafe fn launch_fused_add_rms_norm( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + input_ptr: u64, + residual_ptr: u64, + weight_ptr: u64, + output_ptr: u64, + pre_norm_ptr: u64, + batch_size: usize, + hidden_size: usize, + eps: f32, +) -> Result<()> { + unsafe { + let module = + get_or_load_module(context, device_index, kernel_names::FUSED_ADD_NORM_MODULE)?; + let func_name = kernel_name("fused_add_rms_norm", dtype); + let func = get_kernel_function(&module, &func_name)?; + + let (grid_size, block_size, shared_mem) = + fused_norm_launch_config(batch_size, hidden_size, 1, dtype); + let batch = batch_size as u32; + let hidden = hidden_size as u32; + let eps_f64 = eps as f64; + + let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), shared_mem); + let mut builder = stream.launch_builder(&func); + builder.arg(&input_ptr); + builder.arg(&residual_ptr); + builder.arg(&weight_ptr); + builder.arg(&output_ptr); + builder.arg(&pre_norm_ptr); + builder.arg(&batch); + builder.arg(&hidden); + if dtype == DType::F64 { + builder.arg(&eps_f64); + } else { + builder.arg(&eps); + } + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!( + "CUDA fused_add_rms_norm kernel launch failed: {:?}", + e + )) + })?; + + Ok(()) + } +} + +/// Launch a fused_add_rms_norm backward kernel. +/// +/// Computes gradients for fused add + RMSNorm operation. +/// +/// # Arguments +/// +/// * `grad_ptr` - Device pointer to gradient tensor of shape [batch_size, hidden_size] +/// * `pre_norm_ptr` - Device pointer to pre-norm tensor from forward pass +/// * `weight_ptr` - Device pointer to weight tensor of shape [hidden_size] +/// * `d_input_residual_ptr` - Device pointer to output gradients for input and residual +/// * `d_weight_ptr` - Device pointer to weight gradients (pre-zeroed, accumulated via atomicAdd) +/// * `batch_size` - Number of rows +/// * `hidden_size` - Size of each row +/// * `eps` - Small constant for numerical stability +/// +/// # Safety +/// +/// - All pointers must be valid device memory +/// - d_weight_ptr must be pre-zeroed with `hidden_size` elements +pub unsafe fn launch_fused_add_rms_norm_bwd( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + grad_ptr: u64, + pre_norm_ptr: u64, + weight_ptr: u64, + d_input_residual_ptr: u64, + d_weight_ptr: u64, + batch_size: usize, + hidden_size: usize, + eps: f32, +) -> Result<()> { + unsafe { + let module = + get_or_load_module(context, device_index, kernel_names::FUSED_ADD_NORM_MODULE)?; + let func_name = kernel_name("fused_add_rms_norm_bwd", dtype); + let func = get_kernel_function(&module, &func_name)?; + + // Backward needs 2 shared arrays: sum_sq and dot + let (grid_size, block_size, shared_mem) = + fused_norm_launch_config(batch_size, hidden_size, 2, dtype); + let batch = batch_size as u32; + let hidden = hidden_size as u32; + let eps_f64 = eps as f64; + + let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), shared_mem); + let mut builder = stream.launch_builder(&func); + builder.arg(&grad_ptr); + builder.arg(&pre_norm_ptr); + builder.arg(&weight_ptr); + builder.arg(&d_input_residual_ptr); + builder.arg(&d_weight_ptr); + builder.arg(&batch); + builder.arg(&hidden); + if dtype == DType::F64 { + builder.arg(&eps_f64); + } else { + builder.arg(&eps); + } + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!( + "CUDA fused_add_rms_norm_bwd kernel launch failed: {:?}", + e + )) + })?; + + Ok(()) + } +} + +/// Launch a fused_add_layer_norm forward kernel. +/// +/// Computes: `pre_norm = input + residual`, then +/// `output = (pre_norm - mean) / sqrt(var + eps) * weight + bias` +/// +/// # Arguments +/// +/// * `input_ptr` - Device pointer to input tensor of shape [batch_size, hidden_size] +/// * `residual_ptr` - Device pointer to residual tensor of shape [batch_size, hidden_size] +/// * `weight_ptr` - Device pointer to weight (gamma) tensor of shape [hidden_size] +/// * `bias_ptr` - Device pointer to bias (beta) tensor of shape [hidden_size] +/// * `output_ptr` - Device pointer to output tensor of shape [batch_size, hidden_size] +/// * `pre_norm_ptr` - Device pointer to pre-normalization tensor of shape [batch_size, hidden_size] +/// * `batch_size` - Number of rows (batch dimension) +/// * `hidden_size` - Size of each row (hidden dimension) +/// * `eps` - Small constant for numerical stability +/// +/// # Safety +/// +/// - All pointers must be valid device memory +/// - All tensors must have `batch_size * hidden_size` elements +pub unsafe fn launch_fused_add_layer_norm( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + input_ptr: u64, + residual_ptr: u64, + weight_ptr: u64, + bias_ptr: u64, + output_ptr: u64, + pre_norm_ptr: u64, + batch_size: usize, + hidden_size: usize, + eps: f32, +) -> Result<()> { + unsafe { + let module = + get_or_load_module(context, device_index, kernel_names::FUSED_ADD_NORM_MODULE)?; + let func_name = kernel_name("fused_add_layer_norm", dtype); + let func = get_kernel_function(&module, &func_name)?; + + // Layer norm needs 2 shared arrays: mean and variance + let (grid_size, block_size, shared_mem) = + fused_norm_launch_config(batch_size, hidden_size, 2, dtype); + let batch = batch_size as u32; + let hidden = hidden_size as u32; + let eps_f64 = eps as f64; + + let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), shared_mem); + let mut builder = stream.launch_builder(&func); + builder.arg(&input_ptr); + builder.arg(&residual_ptr); + builder.arg(&weight_ptr); + builder.arg(&bias_ptr); + builder.arg(&output_ptr); + builder.arg(&pre_norm_ptr); + builder.arg(&batch); + builder.arg(&hidden); + if dtype == DType::F64 { + builder.arg(&eps_f64); + } else { + builder.arg(&eps); + } + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!( + "CUDA fused_add_layer_norm kernel launch failed: {:?}", + e + )) + })?; + + Ok(()) + } +} + +/// Launch a fused_add_layer_norm backward kernel. +/// +/// Computes gradients for fused add + LayerNorm operation. +/// +/// # Arguments +/// +/// * `grad_ptr` - Device pointer to gradient tensor of shape [batch_size, hidden_size] +/// * `pre_norm_ptr` - Device pointer to pre-norm tensor from forward pass +/// * `weight_ptr` - Device pointer to weight tensor of shape [hidden_size] +/// * `d_input_residual_ptr` - Device pointer to output gradients for input and residual +/// * `d_weight_ptr` - Device pointer to weight gradients (pre-zeroed, accumulated via atomicAdd) +/// * `d_bias_ptr` - Device pointer to bias gradients (pre-zeroed, accumulated via atomicAdd) +/// * `batch_size` - Number of rows +/// * `hidden_size` - Size of each row +/// * `eps` - Small constant for numerical stability +/// +/// # Safety +/// +/// - All pointers must be valid device memory +/// - d_weight_ptr and d_bias_ptr must be pre-zeroed with `hidden_size` elements each +pub unsafe fn launch_fused_add_layer_norm_bwd( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + grad_ptr: u64, + pre_norm_ptr: u64, + weight_ptr: u64, + d_input_residual_ptr: u64, + d_weight_ptr: u64, + d_bias_ptr: u64, + batch_size: usize, + hidden_size: usize, + eps: f32, +) -> Result<()> { + unsafe { + let module = + get_or_load_module(context, device_index, kernel_names::FUSED_ADD_NORM_MODULE)?; + let func_name = kernel_name("fused_add_layer_norm_bwd", dtype); + let func = get_kernel_function(&module, &func_name)?; + + // Backward needs 4 shared arrays: mean, var, gs (mean_gs), gsn (mean_gsn) + let (grid_size, block_size, shared_mem) = + fused_norm_launch_config(batch_size, hidden_size, 4, dtype); + let batch = batch_size as u32; + let hidden = hidden_size as u32; + let eps_f64 = eps as f64; + + let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), shared_mem); + let mut builder = stream.launch_builder(&func); + builder.arg(&grad_ptr); + builder.arg(&pre_norm_ptr); + builder.arg(&weight_ptr); + builder.arg(&d_input_residual_ptr); + builder.arg(&d_weight_ptr); + builder.arg(&d_bias_ptr); + builder.arg(&batch); + builder.arg(&hidden); + if dtype == DType::F64 { + builder.arg(&eps_f64); + } else { + builder.arg(&eps); + } + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!( + "CUDA fused_add_layer_norm_bwd kernel launch failed: {:?}", + e + )) + })?; + + Ok(()) + } +} diff --git a/src/runtime/cuda/kernels/fused_elementwise.cu b/src/runtime/cuda/kernels/fused_elementwise.cu new file mode 100644 index 00000000..f06c4eb4 --- /dev/null +++ b/src/runtime/cuda/kernels/fused_elementwise.cu @@ -0,0 +1,191 @@ +// Fused elementwise CUDA kernels +// fused_mul_add: out = a * b + c (FMA) +// fused_add_mul: out = (a + b) * c +// fused_mul_add_scalar: out = a * scale + bias +// Types: f32, f64, f16, bf16 + +#include +#include +#include "dtype_traits.cuh" + +extern "C" { + +// ============================================================================ +// fused_mul_add: out = a * b + c +// ============================================================================ + +__global__ void fused_mul_add_f32(const float* a, const float* b, const float* c, float* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + out[idx] = fmaf(a[idx], b[idx], c[idx]); + } +} + +__global__ void fused_mul_add_f64(const double* a, const double* b, const double* c, double* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + out[idx] = fma(a[idx], b[idx], c[idx]); + } +} + +__global__ void fused_mul_add_f16(const __half* a, const __half* b, const __half* c, __half* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float va = __half2float(a[idx]); + float vb = __half2float(b[idx]); + float vc = __half2float(c[idx]); + out[idx] = __float2half(fmaf(va, vb, vc)); + } +} + +__global__ void fused_mul_add_bf16(const __nv_bfloat16* a, const __nv_bfloat16* b, const __nv_bfloat16* c, __nv_bfloat16* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float va = __bfloat162float(a[idx]); + float vb = __bfloat162float(b[idx]); + float vc = __bfloat162float(c[idx]); + out[idx] = __float2bfloat16(fmaf(va, vb, vc)); + } +} + +// ============================================================================ +// fused_add_mul: out = (a + b) * c +// ============================================================================ + +__global__ void fused_add_mul_f32(const float* a, const float* b, const float* c, float* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + out[idx] = (a[idx] + b[idx]) * c[idx]; + } +} + +__global__ void fused_add_mul_f64(const double* a, const double* b, const double* c, double* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + out[idx] = (a[idx] + b[idx]) * c[idx]; + } +} + +__global__ void fused_add_mul_f16(const __half* a, const __half* b, const __half* c, __half* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float va = __half2float(a[idx]); + float vb = __half2float(b[idx]); + float vc = __half2float(c[idx]); + out[idx] = __float2half((va + vb) * vc); + } +} + +__global__ void fused_add_mul_bf16(const __nv_bfloat16* a, const __nv_bfloat16* b, const __nv_bfloat16* c, __nv_bfloat16* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float va = __bfloat162float(a[idx]); + float vb = __bfloat162float(b[idx]); + float vc = __bfloat162float(c[idx]); + out[idx] = __float2bfloat16((va + vb) * vc); + } +} + +// ============================================================================ +// fused_mul_add_scalar: out = a * scale + bias +// ============================================================================ + +__global__ void fused_mul_add_scalar_f32(const float* a, float* out, unsigned int n, float scale, float bias) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + out[idx] = fmaf(a[idx], scale, bias); + } +} + +__global__ void fused_mul_add_scalar_f64(const double* a, double* out, unsigned int n, double scale, double bias) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + out[idx] = fma(a[idx], scale, bias); + } +} + +__global__ void fused_mul_add_scalar_f16(const __half* a, __half* out, unsigned int n, float scale, float bias) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float va = __half2float(a[idx]); + out[idx] = __float2half(fmaf(va, scale, bias)); + } +} + +__global__ void fused_mul_add_scalar_bf16(const __nv_bfloat16* a, __nv_bfloat16* out, unsigned int n, float scale, float bias) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float va = __bfloat162float(a[idx]); + out[idx] = __float2bfloat16(fmaf(va, scale, bias)); + } +} + +// ============================================================================ +// FP8 fused_mul_add: out = a * b + c +// ============================================================================ + +__global__ void fused_mul_add_fp8_e4m3(const numr_fp8_e4m3* a, const numr_fp8_e4m3* b, const numr_fp8_e4m3* c, numr_fp8_e4m3* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float va = fp8_e4m3_to_f32(a[idx].data); + float vb = fp8_e4m3_to_f32(b[idx].data); + float vc = fp8_e4m3_to_f32(c[idx].data); + out[idx].data = f32_to_fp8_e4m3(fmaf(va, vb, vc)); + } +} + +__global__ void fused_mul_add_fp8_e5m2(const numr_fp8_e5m2* a, const numr_fp8_e5m2* b, const numr_fp8_e5m2* c, numr_fp8_e5m2* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float va = fp8_e5m2_to_f32(a[idx].data); + float vb = fp8_e5m2_to_f32(b[idx].data); + float vc = fp8_e5m2_to_f32(c[idx].data); + out[idx].data = f32_to_fp8_e5m2(fmaf(va, vb, vc)); + } +} + +// ============================================================================ +// FP8 fused_add_mul: out = (a + b) * c +// ============================================================================ + +__global__ void fused_add_mul_fp8_e4m3(const numr_fp8_e4m3* a, const numr_fp8_e4m3* b, const numr_fp8_e4m3* c, numr_fp8_e4m3* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float va = fp8_e4m3_to_f32(a[idx].data); + float vb = fp8_e4m3_to_f32(b[idx].data); + float vc = fp8_e4m3_to_f32(c[idx].data); + out[idx].data = f32_to_fp8_e4m3((va + vb) * vc); + } +} + +__global__ void fused_add_mul_fp8_e5m2(const numr_fp8_e5m2* a, const numr_fp8_e5m2* b, const numr_fp8_e5m2* c, numr_fp8_e5m2* out, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float va = fp8_e5m2_to_f32(a[idx].data); + float vb = fp8_e5m2_to_f32(b[idx].data); + float vc = fp8_e5m2_to_f32(c[idx].data); + out[idx].data = f32_to_fp8_e5m2((va + vb) * vc); + } +} + +// ============================================================================ +// FP8 fused_mul_add_scalar: out = a * scale + bias +// ============================================================================ + +__global__ void fused_mul_add_scalar_fp8_e4m3(const numr_fp8_e4m3* a, numr_fp8_e4m3* out, unsigned int n, float scale, float bias) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float va = fp8_e4m3_to_f32(a[idx].data); + out[idx].data = f32_to_fp8_e4m3(fmaf(va, scale, bias)); + } +} + +__global__ void fused_mul_add_scalar_fp8_e5m2(const numr_fp8_e5m2* a, numr_fp8_e5m2* out, unsigned int n, float scale, float bias) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float va = fp8_e5m2_to_f32(a[idx].data); + out[idx].data = f32_to_fp8_e5m2(fmaf(va, scale, bias)); + } +} + +} // extern "C" diff --git a/src/runtime/cuda/kernels/fused_elementwise.rs b/src/runtime/cuda/kernels/fused_elementwise.rs new file mode 100644 index 00000000..2964c7f5 --- /dev/null +++ b/src/runtime/cuda/kernels/fused_elementwise.rs @@ -0,0 +1,173 @@ +//! Fused elementwise CUDA kernel launchers +//! +//! - fused_mul_add: out = a * b + c +//! - fused_add_mul: out = (a + b) * c +//! - fused_mul_add_scalar: out = a * scale + bias + +use cudarc::driver::PushKernelArg; +use cudarc::driver::safe::{CudaContext, CudaStream}; +use std::sync::Arc; + +use super::loader::{ + BLOCK_SIZE, elementwise_launch_config, get_kernel_function, get_or_load_module, kernel_name, + launch_config, +}; +use crate::dtype::DType; +use crate::error::{Error, Result}; + +const MODULE: &str = "fused_elementwise"; + +/// Launch fused_mul_add: out = a * b + c +/// +/// # Safety +/// All pointers must be valid device memory with at least `numel` elements. +pub unsafe fn launch_fused_mul_add( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + a_ptr: u64, + b_ptr: u64, + c_ptr: u64, + output_ptr: u64, + numel: usize, +) -> Result<()> { + unsafe { + launch_ternary_kernel( + context, + stream, + device_index, + "fused_mul_add", + dtype, + a_ptr, + b_ptr, + c_ptr, + output_ptr, + numel, + ) + } +} + +/// Launch fused_add_mul: out = (a + b) * c +/// +/// # Safety +/// All pointers must be valid device memory with at least `numel` elements. +pub unsafe fn launch_fused_add_mul( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + a_ptr: u64, + b_ptr: u64, + c_ptr: u64, + output_ptr: u64, + numel: usize, +) -> Result<()> { + unsafe { + launch_ternary_kernel( + context, + stream, + device_index, + "fused_add_mul", + dtype, + a_ptr, + b_ptr, + c_ptr, + output_ptr, + numel, + ) + } +} + +/// Launch fused_mul_add_scalar: out = a * scale + bias +/// +/// # Safety +/// All pointers must be valid device memory with at least `numel` elements. +pub unsafe fn launch_fused_mul_add_scalar( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + a_ptr: u64, + output_ptr: u64, + numel: usize, + scale: f64, + bias: f64, +) -> Result<()> { + let module = get_or_load_module(context, device_index, MODULE)?; + let func_name = kernel_name("fused_mul_add_scalar", dtype); + let func = get_kernel_function(&module, &func_name)?; + + let grid = elementwise_launch_config(numel); + let block = (BLOCK_SIZE, 1, 1); + let n = numel as u32; + let cfg = launch_config(grid, block, 0); + + let scale_f32 = scale as f32; + let bias_f32 = bias as f32; + + let mut builder = stream.launch_builder(&func); + unsafe { + builder.arg(&a_ptr); + builder.arg(&output_ptr); + builder.arg(&n); + + match dtype { + DType::F64 => { + builder.arg(&scale); + builder.arg(&bias); + } + _ => { + builder.arg(&scale_f32); + builder.arg(&bias_f32); + } + } + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!( + "CUDA fused_mul_add_scalar kernel launch failed: {:?}", + e + )) + })?; + } + + Ok(()) +} + +/// Internal helper for ternary kernels (a, b, c -> out) +unsafe fn launch_ternary_kernel( + context: &Arc, + stream: &CudaStream, + device_index: usize, + op: &str, + dtype: DType, + a_ptr: u64, + b_ptr: u64, + c_ptr: u64, + output_ptr: u64, + numel: usize, +) -> Result<()> { + let module = get_or_load_module(context, device_index, MODULE)?; + let func_name = kernel_name(op, dtype); + let func = get_kernel_function(&module, &func_name)?; + + let grid = elementwise_launch_config(numel); + let block = (BLOCK_SIZE, 1, 1); + let n = numel as u32; + let cfg = launch_config(grid, block, 0); + + let mut builder = stream.launch_builder(&func); + unsafe { + builder.arg(&a_ptr); + builder.arg(&b_ptr); + builder.arg(&c_ptr); + builder.arg(&output_ptr); + builder.arg(&n); + + builder + .launch(cfg) + .map_err(|e| Error::Internal(format!("CUDA {} kernel launch failed: {:?}", op, e)))?; + } + + Ok(()) +} diff --git a/src/runtime/cuda/kernels/gemm_epilogue.cu b/src/runtime/cuda/kernels/gemm_epilogue.cu new file mode 100644 index 00000000..6616d9c5 --- /dev/null +++ b/src/runtime/cuda/kernels/gemm_epilogue.cu @@ -0,0 +1,1396 @@ +// Fused GEMM epilogue kernels: +// - gemm_bias_act: C = activation(A @ B + bias) +// - gemm_bias_residual: C = A @ B + bias + residual +// +// activation_type: 0=None, 1=ReLU, 2=GELU, 3=SiLU, 4=Sigmoid, 5=Tanh + +#include +#include + +// ============================================================================ +// Activation helpers (device functions) +// ============================================================================ + +__device__ __forceinline__ float apply_activation_f32(float x, unsigned int act_type) { + switch (act_type) { + case 0: return x; // None + case 1: return fmaxf(x, 0.0f); // ReLU + case 2: { // GELU + const float sqrt_2_over_pi = 0.7978845608f; + const float coef = 0.044715f; + float inner = sqrt_2_over_pi * (x + coef * x * x * x); + return 0.5f * x * (1.0f + tanhf(inner)); + } + case 3: { // SiLU + return x / (1.0f + expf(-x)); + } + case 4: { // Sigmoid + return 1.0f / (1.0f + expf(-x)); + } + case 5: { // Tanh + return tanhf(x); + } + default: return x; + } +} + +__device__ __forceinline__ double apply_activation_f64(double x, unsigned int act_type) { + switch (act_type) { + case 0: return x; + case 1: return fmax(x, 0.0); + case 2: { + const double sqrt_2_over_pi = 0.7978845608028654; + const double coef = 0.044715; + double inner = sqrt_2_over_pi * (x + coef * x * x * x); + return 0.5 * x * (1.0 + tanh(inner)); + } + case 3: return x / (1.0 + exp(-x)); + case 4: return 1.0 / (1.0 + exp(-x)); + case 5: return tanh(x); + default: return x; + } +} + +// ============================================================================ +// GEMM + bias + activation: F32 +// ============================================================================ + +extern "C" __global__ void gemm_bias_act_f32( + const float* __restrict__ A, + const float* __restrict__ B, + const float* __restrict__ bias, + float* __restrict__ C, + unsigned int M, + unsigned int N, + unsigned int K, + unsigned int activation_type, + unsigned int block_m, + unsigned int block_n, + unsigned int block_k, + unsigned int thread_m, + unsigned int thread_n +) { + extern __shared__ float shared_mem[]; + float* As = shared_mem; + float* Bs = shared_mem + block_m * block_k; + + const unsigned int tx = threadIdx.x; + const unsigned int ty = threadIdx.y; + const unsigned int threads_x = block_n / thread_n; + + const unsigned int block_row = blockIdx.y * block_m; + const unsigned int block_col = blockIdx.x * block_n; + const unsigned int thread_row = ty * thread_m; + const unsigned int thread_col = tx * thread_n; + + float reg_c[8][8]; + #pragma unroll + for (int i = 0; i < 8; i++) + #pragma unroll + for (int j = 0; j < 8; j++) + reg_c[i][j] = 0.0f; + + float reg_a[8]; + float reg_b[8]; + + const unsigned int num_k_tiles = (K + block_k - 1) / block_k; + const unsigned int thread_id = ty * threads_x + tx; + const unsigned int num_threads = blockDim.x * blockDim.y; + + for (unsigned int bk = 0; bk < num_k_tiles; bk++) { + const unsigned int k_offset = bk * block_k; + + for (unsigned int load_idx = thread_id; load_idx < block_m * block_k; load_idx += num_threads) { + const unsigned int lr = load_idx / block_k, lc = load_idx % block_k; + const unsigned int gr = block_row + lr, gc = k_offset + lc; + As[lr * block_k + lc] = (gr < M && gc < K) ? A[gr * K + gc] : 0.0f; + } + for (unsigned int load_idx = thread_id; load_idx < block_k * block_n; load_idx += num_threads) { + const unsigned int lr = load_idx / block_n, lc = load_idx % block_n; + const unsigned int gr = k_offset + lr, gc = block_col + lc; + Bs[lr * block_n + lc] = (gr < K && gc < N) ? B[gr * N + gc] : 0.0f; + } + __syncthreads(); + + for (unsigned int k = 0; k < block_k; k++) { + for (unsigned int i = 0; i < thread_m; i++) reg_a[i] = As[(thread_row + i) * block_k + k]; + for (unsigned int j = 0; j < thread_n; j++) reg_b[j] = Bs[k * block_n + thread_col + j]; + for (unsigned int i = 0; i < thread_m; i++) + for (unsigned int j = 0; j < thread_n; j++) + reg_c[i][j] += reg_a[i] * reg_b[j]; + } + __syncthreads(); + } + + // EPILOGUE: bias + activation + for (unsigned int i = 0; i < thread_m; i++) { + const unsigned int global_row = block_row + thread_row + i; + if (global_row < M) { + for (unsigned int j = 0; j < thread_n; j++) { + const unsigned int global_col = block_col + thread_col + j; + if (global_col < N) { + float val = reg_c[i][j] + bias[global_col]; + C[global_row * N + global_col] = apply_activation_f32(val, activation_type); + } + } + } + } +} + +extern "C" __global__ void gemm_bias_act_batched_f32( + const float* __restrict__ A, + const float* __restrict__ B, + const float* __restrict__ bias, + float* __restrict__ C, + unsigned int batch_count, + unsigned int M, + unsigned int N, + unsigned int K, + unsigned int activation_type, + unsigned int block_m, + unsigned int block_n, + unsigned int block_k, + unsigned int thread_m, + unsigned int thread_n +) { + const unsigned int batch = blockIdx.z; + if (batch >= batch_count) return; + + extern __shared__ float shared_mem[]; + float* As = shared_mem; + float* Bs = shared_mem + block_m * block_k; + + const float* A_batch = A + batch * M * K; + const float* B_batch = B + batch * K * N; + float* C_batch = C + batch * M * N; + + const unsigned int tx = threadIdx.x; + const unsigned int ty = threadIdx.y; + const unsigned int threads_x = block_n / thread_n; + const unsigned int block_row = blockIdx.y * block_m; + const unsigned int block_col = blockIdx.x * block_n; + const unsigned int thread_row = ty * thread_m; + const unsigned int thread_col = tx * thread_n; + + float reg_c[8][8]; + #pragma unroll + for (int i = 0; i < 8; i++) + for (int j = 0; j < 8; j++) + reg_c[i][j] = 0.0f; + + float reg_a[8], reg_b[8]; + const unsigned int num_k_tiles = (K + block_k - 1) / block_k; + const unsigned int thread_id = ty * threads_x + tx; + const unsigned int num_threads = blockDim.x * blockDim.y; + + for (unsigned int bk = 0; bk < num_k_tiles; bk++) { + const unsigned int k_offset = bk * block_k; + + for (unsigned int load_idx = thread_id; load_idx < block_m * block_k; load_idx += num_threads) { + const unsigned int lr = load_idx / block_k, lc = load_idx % block_k; + const unsigned int gr = block_row + lr, gc = k_offset + lc; + As[lr * block_k + lc] = (gr < M && gc < K) ? A_batch[gr * K + gc] : 0.0f; + } + for (unsigned int load_idx = thread_id; load_idx < block_k * block_n; load_idx += num_threads) { + const unsigned int lr = load_idx / block_n, lc = load_idx % block_n; + const unsigned int gr = k_offset + lr, gc = block_col + lc; + Bs[lr * block_n + lc] = (gr < K && gc < N) ? B_batch[gr * N + gc] : 0.0f; + } + __syncthreads(); + + for (unsigned int k = 0; k < block_k; k++) { + for (unsigned int i = 0; i < thread_m; i++) reg_a[i] = As[(thread_row + i) * block_k + k]; + for (unsigned int j = 0; j < thread_n; j++) reg_b[j] = Bs[k * block_n + thread_col + j]; + for (unsigned int i = 0; i < thread_m; i++) + for (unsigned int j = 0; j < thread_n; j++) + reg_c[i][j] += reg_a[i] * reg_b[j]; + } + __syncthreads(); + } + + for (unsigned int i = 0; i < thread_m; i++) { + const unsigned int global_row = block_row + thread_row + i; + if (global_row < M) { + for (unsigned int j = 0; j < thread_n; j++) { + const unsigned int global_col = block_col + thread_col + j; + if (global_col < N) { + float val = reg_c[i][j] + bias[global_col]; + C_batch[global_row * N + global_col] = apply_activation_f32(val, activation_type); + } + } + } + } +} + +// ============================================================================ +// GEMM + bias + activation: F64 +// ============================================================================ + +extern "C" __global__ void gemm_bias_act_f64( + const double* __restrict__ A, + const double* __restrict__ B, + const double* __restrict__ bias, + double* __restrict__ C, + unsigned int M, + unsigned int N, + unsigned int K, + unsigned int activation_type, + unsigned int block_m, + unsigned int block_n, + unsigned int block_k, + unsigned int thread_m, + unsigned int thread_n +) { + extern __shared__ double shared_mem_f64[]; + double* As = shared_mem_f64; + double* Bs = shared_mem_f64 + block_m * block_k; + + const unsigned int tx = threadIdx.x; + const unsigned int ty = threadIdx.y; + const unsigned int threads_x = block_n / thread_n; + const unsigned int block_row = blockIdx.y * block_m; + const unsigned int block_col = blockIdx.x * block_n; + const unsigned int thread_row = ty * thread_m; + const unsigned int thread_col = tx * thread_n; + + double reg_c[8][8]; + for (int i = 0; i < 8; i++) + for (int j = 0; j < 8; j++) + reg_c[i][j] = 0.0; + + double reg_a[8], reg_b[8]; + const unsigned int num_k_tiles = (K + block_k - 1) / block_k; + const unsigned int thread_id = ty * threads_x + tx; + const unsigned int num_threads = blockDim.x * blockDim.y; + + for (unsigned int bk = 0; bk < num_k_tiles; bk++) { + const unsigned int k_offset = bk * block_k; + + for (unsigned int load_idx = thread_id; load_idx < block_m * block_k; load_idx += num_threads) { + const unsigned int lr = load_idx / block_k, lc = load_idx % block_k; + const unsigned int gr = block_row + lr, gc = k_offset + lc; + As[lr * block_k + lc] = (gr < M && gc < K) ? A[gr * K + gc] : 0.0; + } + for (unsigned int load_idx = thread_id; load_idx < block_k * block_n; load_idx += num_threads) { + const unsigned int lr = load_idx / block_n, lc = load_idx % block_n; + const unsigned int gr = k_offset + lr, gc = block_col + lc; + Bs[lr * block_n + lc] = (gr < K && gc < N) ? B[gr * N + gc] : 0.0; + } + __syncthreads(); + + for (unsigned int k = 0; k < block_k; k++) { + for (unsigned int i = 0; i < thread_m; i++) reg_a[i] = As[(thread_row + i) * block_k + k]; + for (unsigned int j = 0; j < thread_n; j++) reg_b[j] = Bs[k * block_n + thread_col + j]; + for (unsigned int i = 0; i < thread_m; i++) + for (unsigned int j = 0; j < thread_n; j++) + reg_c[i][j] += reg_a[i] * reg_b[j]; + } + __syncthreads(); + } + + for (unsigned int i = 0; i < thread_m; i++) { + const unsigned int global_row = block_row + thread_row + i; + if (global_row < M) { + for (unsigned int j = 0; j < thread_n; j++) { + const unsigned int global_col = block_col + thread_col + j; + if (global_col < N) { + double val = reg_c[i][j] + bias[global_col]; + C[global_row * N + global_col] = apply_activation_f64(val, activation_type); + } + } + } + } +} + +extern "C" __global__ void gemm_bias_act_batched_f64( + const double* __restrict__ A, + const double* __restrict__ B, + const double* __restrict__ bias, + double* __restrict__ C, + unsigned int batch_count, + unsigned int M, + unsigned int N, + unsigned int K, + unsigned int activation_type, + unsigned int block_m, + unsigned int block_n, + unsigned int block_k, + unsigned int thread_m, + unsigned int thread_n +) { + const unsigned int batch = blockIdx.z; + if (batch >= batch_count) return; + + extern __shared__ double shared_mem_f64[]; + double* As = shared_mem_f64; + double* Bs = shared_mem_f64 + block_m * block_k; + + const double* A_batch = A + batch * M * K; + const double* B_batch = B + batch * K * N; + double* C_batch = C + batch * M * N; + + const unsigned int tx = threadIdx.x; + const unsigned int ty = threadIdx.y; + const unsigned int threads_x = block_n / thread_n; + const unsigned int block_row = blockIdx.y * block_m; + const unsigned int block_col = blockIdx.x * block_n; + const unsigned int thread_row = ty * thread_m; + const unsigned int thread_col = tx * thread_n; + + double reg_c[8][8]; + for (int i = 0; i < 8; i++) + for (int j = 0; j < 8; j++) + reg_c[i][j] = 0.0; + + double reg_a[8], reg_b[8]; + const unsigned int num_k_tiles = (K + block_k - 1) / block_k; + const unsigned int thread_id = ty * threads_x + tx; + const unsigned int num_threads = blockDim.x * blockDim.y; + + for (unsigned int bk = 0; bk < num_k_tiles; bk++) { + const unsigned int k_offset = bk * block_k; + + for (unsigned int load_idx = thread_id; load_idx < block_m * block_k; load_idx += num_threads) { + const unsigned int lr = load_idx / block_k, lc = load_idx % block_k; + const unsigned int gr = block_row + lr, gc = k_offset + lc; + As[lr * block_k + lc] = (gr < M && gc < K) ? A_batch[gr * K + gc] : 0.0; + } + for (unsigned int load_idx = thread_id; load_idx < block_k * block_n; load_idx += num_threads) { + const unsigned int lr = load_idx / block_n, lc = load_idx % block_n; + const unsigned int gr = k_offset + lr, gc = block_col + lc; + Bs[lr * block_n + lc] = (gr < K && gc < N) ? B_batch[gr * N + gc] : 0.0; + } + __syncthreads(); + + for (unsigned int k = 0; k < block_k; k++) { + for (unsigned int i = 0; i < thread_m; i++) reg_a[i] = As[(thread_row + i) * block_k + k]; + for (unsigned int j = 0; j < thread_n; j++) reg_b[j] = Bs[k * block_n + thread_col + j]; + for (unsigned int i = 0; i < thread_m; i++) + for (unsigned int j = 0; j < thread_n; j++) + reg_c[i][j] += reg_a[i] * reg_b[j]; + } + __syncthreads(); + } + + for (unsigned int i = 0; i < thread_m; i++) { + const unsigned int global_row = block_row + thread_row + i; + if (global_row < M) { + for (unsigned int j = 0; j < thread_n; j++) { + const unsigned int global_col = block_col + thread_col + j; + if (global_col < N) { + double val = reg_c[i][j] + bias[global_col]; + C_batch[global_row * N + global_col] = apply_activation_f64(val, activation_type); + } + } + } + } +} + +// ============================================================================ +// GEMM + bias + residual: F32 +// ============================================================================ + +extern "C" __global__ void gemm_bias_residual_f32( + const float* __restrict__ A, + const float* __restrict__ B, + const float* __restrict__ bias, + const float* __restrict__ residual, + float* __restrict__ C, + unsigned int M, + unsigned int N, + unsigned int K, + unsigned int block_m, + unsigned int block_n, + unsigned int block_k, + unsigned int thread_m, + unsigned int thread_n +) { + extern __shared__ float shared_mem[]; + float* As = shared_mem; + float* Bs = shared_mem + block_m * block_k; + + const unsigned int tx = threadIdx.x; + const unsigned int ty = threadIdx.y; + const unsigned int threads_x = block_n / thread_n; + const unsigned int block_row = blockIdx.y * block_m; + const unsigned int block_col = blockIdx.x * block_n; + const unsigned int thread_row = ty * thread_m; + const unsigned int thread_col = tx * thread_n; + + float reg_c[8][8]; + #pragma unroll + for (int i = 0; i < 8; i++) + for (int j = 0; j < 8; j++) + reg_c[i][j] = 0.0f; + + float reg_a[8], reg_b[8]; + const unsigned int num_k_tiles = (K + block_k - 1) / block_k; + const unsigned int thread_id = ty * threads_x + tx; + const unsigned int num_threads = blockDim.x * blockDim.y; + + for (unsigned int bk = 0; bk < num_k_tiles; bk++) { + const unsigned int k_offset = bk * block_k; + + for (unsigned int load_idx = thread_id; load_idx < block_m * block_k; load_idx += num_threads) { + const unsigned int lr = load_idx / block_k, lc = load_idx % block_k; + const unsigned int gr = block_row + lr, gc = k_offset + lc; + As[lr * block_k + lc] = (gr < M && gc < K) ? A[gr * K + gc] : 0.0f; + } + for (unsigned int load_idx = thread_id; load_idx < block_k * block_n; load_idx += num_threads) { + const unsigned int lr = load_idx / block_n, lc = load_idx % block_n; + const unsigned int gr = k_offset + lr, gc = block_col + lc; + Bs[lr * block_n + lc] = (gr < K && gc < N) ? B[gr * N + gc] : 0.0f; + } + __syncthreads(); + + for (unsigned int k = 0; k < block_k; k++) { + for (unsigned int i = 0; i < thread_m; i++) reg_a[i] = As[(thread_row + i) * block_k + k]; + for (unsigned int j = 0; j < thread_n; j++) reg_b[j] = Bs[k * block_n + thread_col + j]; + for (unsigned int i = 0; i < thread_m; i++) + for (unsigned int j = 0; j < thread_n; j++) + reg_c[i][j] += reg_a[i] * reg_b[j]; + } + __syncthreads(); + } + + // EPILOGUE: bias + residual + for (unsigned int i = 0; i < thread_m; i++) { + const unsigned int global_row = block_row + thread_row + i; + if (global_row < M) { + for (unsigned int j = 0; j < thread_n; j++) { + const unsigned int global_col = block_col + thread_col + j; + if (global_col < N) { + unsigned int idx = global_row * N + global_col; + C[idx] = reg_c[i][j] + bias[global_col] + residual[idx]; + } + } + } + } +} + +extern "C" __global__ void gemm_bias_residual_batched_f32( + const float* __restrict__ A, + const float* __restrict__ B, + const float* __restrict__ bias, + const float* __restrict__ residual, + float* __restrict__ C, + unsigned int batch_count, + unsigned int M, + unsigned int N, + unsigned int K, + unsigned int block_m, + unsigned int block_n, + unsigned int block_k, + unsigned int thread_m, + unsigned int thread_n +) { + const unsigned int batch = blockIdx.z; + if (batch >= batch_count) return; + + extern __shared__ float shared_mem[]; + float* As = shared_mem; + float* Bs = shared_mem + block_m * block_k; + + const float* A_batch = A + batch * M * K; + const float* B_batch = B + batch * K * N; + const float* res_batch = residual + batch * M * N; + float* C_batch = C + batch * M * N; + + const unsigned int tx = threadIdx.x; + const unsigned int ty = threadIdx.y; + const unsigned int threads_x = block_n / thread_n; + const unsigned int block_row = blockIdx.y * block_m; + const unsigned int block_col = blockIdx.x * block_n; + const unsigned int thread_row = ty * thread_m; + const unsigned int thread_col = tx * thread_n; + + float reg_c[8][8]; + for (int i = 0; i < 8; i++) + for (int j = 0; j < 8; j++) + reg_c[i][j] = 0.0f; + + float reg_a[8], reg_b[8]; + const unsigned int num_k_tiles = (K + block_k - 1) / block_k; + const unsigned int thread_id = ty * threads_x + tx; + const unsigned int num_threads = blockDim.x * blockDim.y; + + for (unsigned int bk = 0; bk < num_k_tiles; bk++) { + const unsigned int k_offset = bk * block_k; + + for (unsigned int load_idx = thread_id; load_idx < block_m * block_k; load_idx += num_threads) { + const unsigned int lr = load_idx / block_k, lc = load_idx % block_k; + const unsigned int gr = block_row + lr, gc = k_offset + lc; + As[lr * block_k + lc] = (gr < M && gc < K) ? A_batch[gr * K + gc] : 0.0f; + } + for (unsigned int load_idx = thread_id; load_idx < block_k * block_n; load_idx += num_threads) { + const unsigned int lr = load_idx / block_n, lc = load_idx % block_n; + const unsigned int gr = k_offset + lr, gc = block_col + lc; + Bs[lr * block_n + lc] = (gr < K && gc < N) ? B_batch[gr * N + gc] : 0.0f; + } + __syncthreads(); + + for (unsigned int k = 0; k < block_k; k++) { + for (unsigned int i = 0; i < thread_m; i++) reg_a[i] = As[(thread_row + i) * block_k + k]; + for (unsigned int j = 0; j < thread_n; j++) reg_b[j] = Bs[k * block_n + thread_col + j]; + for (unsigned int i = 0; i < thread_m; i++) + for (unsigned int j = 0; j < thread_n; j++) + reg_c[i][j] += reg_a[i] * reg_b[j]; + } + __syncthreads(); + } + + for (unsigned int i = 0; i < thread_m; i++) { + const unsigned int global_row = block_row + thread_row + i; + if (global_row < M) { + for (unsigned int j = 0; j < thread_n; j++) { + const unsigned int global_col = block_col + thread_col + j; + if (global_col < N) { + unsigned int idx = global_row * N + global_col; + C_batch[idx] = reg_c[i][j] + bias[global_col] + res_batch[idx]; + } + } + } + } +} +// ============================================================================ +// GEMM + bias + residual: F64 +// ============================================================================ + +extern "C" __global__ void gemm_bias_residual_f64( + const double* __restrict__ A, + const double* __restrict__ B, + const double* __restrict__ bias, + const double* __restrict__ residual, + double* __restrict__ C, + unsigned int M, + unsigned int N, + unsigned int K, + unsigned int block_m, + unsigned int block_n, + unsigned int block_k, + unsigned int thread_m, + unsigned int thread_n +) { + extern __shared__ double shared_mem_f64[]; + double* As = shared_mem_f64; + double* Bs = shared_mem_f64 + block_m * block_k; + + const unsigned int tx = threadIdx.x; + const unsigned int ty = threadIdx.y; + const unsigned int threads_x = block_n / thread_n; + const unsigned int block_row = blockIdx.y * block_m; + const unsigned int block_col = blockIdx.x * block_n; + const unsigned int thread_row = ty * thread_m; + const unsigned int thread_col = tx * thread_n; + + double reg_c[8][8]; + for (int i = 0; i < 8; i++) + for (int j = 0; j < 8; j++) + reg_c[i][j] = 0.0; + + double reg_a[8], reg_b[8]; + const unsigned int num_k_tiles = (K + block_k - 1) / block_k; + const unsigned int thread_id = ty * threads_x + tx; + const unsigned int num_threads = blockDim.x * blockDim.y; + + for (unsigned int bk = 0; bk < num_k_tiles; bk++) { + const unsigned int k_offset = bk * block_k; + + for (unsigned int load_idx = thread_id; load_idx < block_m * block_k; load_idx += num_threads) { + const unsigned int lr = load_idx / block_k, lc = load_idx % block_k; + const unsigned int gr = block_row + lr, gc = k_offset + lc; + As[lr * block_k + lc] = (gr < M && gc < K) ? A[gr * K + gc] : 0.0; + } + for (unsigned int load_idx = thread_id; load_idx < block_k * block_n; load_idx += num_threads) { + const unsigned int lr = load_idx / block_n, lc = load_idx % block_n; + const unsigned int gr = k_offset + lr, gc = block_col + lc; + Bs[lr * block_n + lc] = (gr < K && gc < N) ? B[gr * N + gc] : 0.0; + } + __syncthreads(); + + for (unsigned int k = 0; k < block_k; k++) { + for (unsigned int i = 0; i < thread_m; i++) reg_a[i] = As[(thread_row + i) * block_k + k]; + for (unsigned int j = 0; j < thread_n; j++) reg_b[j] = Bs[k * block_n + thread_col + j]; + for (unsigned int i = 0; i < thread_m; i++) + for (unsigned int j = 0; j < thread_n; j++) + reg_c[i][j] += reg_a[i] * reg_b[j]; + } + __syncthreads(); + } + + for (unsigned int i = 0; i < thread_m; i++) { + const unsigned int global_row = block_row + thread_row + i; + if (global_row < M) { + for (unsigned int j = 0; j < thread_n; j++) { + const unsigned int global_col = block_col + thread_col + j; + if (global_col < N) { + unsigned int idx = global_row * N + global_col; + C[idx] = reg_c[i][j] + bias[global_col] + residual[idx]; + } + } + } + } +} + +extern "C" __global__ void gemm_bias_residual_batched_f64( + const double* __restrict__ A, + const double* __restrict__ B, + const double* __restrict__ bias, + const double* __restrict__ residual, + double* __restrict__ C, + unsigned int batch_count, + unsigned int M, + unsigned int N, + unsigned int K, + unsigned int block_m, + unsigned int block_n, + unsigned int block_k, + unsigned int thread_m, + unsigned int thread_n +) { + const unsigned int batch = blockIdx.z; + if (batch >= batch_count) return; + + extern __shared__ double shared_mem_f64[]; + double* As = shared_mem_f64; + double* Bs = shared_mem_f64 + block_m * block_k; + + const double* A_batch = A + batch * M * K; + const double* B_batch = B + batch * K * N; + const double* res_batch = residual + batch * M * N; + double* C_batch = C + batch * M * N; + + const unsigned int tx = threadIdx.x; + const unsigned int ty = threadIdx.y; + const unsigned int threads_x = block_n / thread_n; + const unsigned int block_row = blockIdx.y * block_m; + const unsigned int block_col = blockIdx.x * block_n; + const unsigned int thread_row = ty * thread_m; + const unsigned int thread_col = tx * thread_n; + + double reg_c[8][8]; + for (int i = 0; i < 8; i++) + for (int j = 0; j < 8; j++) + reg_c[i][j] = 0.0; + + double reg_a[8], reg_b[8]; + const unsigned int num_k_tiles = (K + block_k - 1) / block_k; + const unsigned int thread_id = ty * threads_x + tx; + const unsigned int num_threads = blockDim.x * blockDim.y; + + for (unsigned int bk = 0; bk < num_k_tiles; bk++) { + const unsigned int k_offset = bk * block_k; + + for (unsigned int load_idx = thread_id; load_idx < block_m * block_k; load_idx += num_threads) { + const unsigned int lr = load_idx / block_k, lc = load_idx % block_k; + const unsigned int gr = block_row + lr, gc = k_offset + lc; + As[lr * block_k + lc] = (gr < M && gc < K) ? A_batch[gr * K + gc] : 0.0; + } + for (unsigned int load_idx = thread_id; load_idx < block_k * block_n; load_idx += num_threads) { + const unsigned int lr = load_idx / block_n, lc = load_idx % block_n; + const unsigned int gr = k_offset + lr, gc = block_col + lc; + Bs[lr * block_n + lc] = (gr < K && gc < N) ? B_batch[gr * N + gc] : 0.0; + } + __syncthreads(); + + for (unsigned int k = 0; k < block_k; k++) { + for (unsigned int i = 0; i < thread_m; i++) reg_a[i] = As[(thread_row + i) * block_k + k]; + for (unsigned int j = 0; j < thread_n; j++) reg_b[j] = Bs[k * block_n + thread_col + j]; + for (unsigned int i = 0; i < thread_m; i++) + for (unsigned int j = 0; j < thread_n; j++) + reg_c[i][j] += reg_a[i] * reg_b[j]; + } + __syncthreads(); + } + + for (unsigned int i = 0; i < thread_m; i++) { + const unsigned int global_row = block_row + thread_row + i; + if (global_row < M) { + for (unsigned int j = 0; j < thread_n; j++) { + const unsigned int global_col = block_col + thread_col + j; + if (global_col < N) { + unsigned int idx = global_row * N + global_col; + C_batch[idx] = reg_c[i][j] + bias[global_col] + res_batch[idx]; + } + } + } + } +} + +// ============================================================================ +// GEMM + bias + activation: F16 +// ============================================================================ + +extern "C" __global__ void gemm_bias_act_f16( + const __half* __restrict__ A, + const __half* __restrict__ B, + const __half* __restrict__ bias, + __half* __restrict__ C, + unsigned int M, + unsigned int N, + unsigned int K, + unsigned int activation_type, + unsigned int block_m, + unsigned int block_n, + unsigned int block_k, + unsigned int thread_m, + unsigned int thread_n +) { + extern __shared__ float shared_mem_f16[]; + float* As = shared_mem_f16; + float* Bs = shared_mem_f16 + block_m * block_k; + + const unsigned int tx = threadIdx.x; + const unsigned int ty = threadIdx.y; + const unsigned int threads_x = block_n / thread_n; + const unsigned int block_row = blockIdx.y * block_m; + const unsigned int block_col = blockIdx.x * block_n; + const unsigned int thread_row = ty * thread_m; + const unsigned int thread_col = tx * thread_n; + + float reg_c[8][8]; + #pragma unroll + for (int i = 0; i < 8; i++) + #pragma unroll + for (int j = 0; j < 8; j++) + reg_c[i][j] = 0.0f; + + float reg_a[8]; + float reg_b[8]; + + const unsigned int num_k_tiles = (K + block_k - 1) / block_k; + const unsigned int thread_id = ty * threads_x + tx; + const unsigned int num_threads = blockDim.x * blockDim.y; + + for (unsigned int bk = 0; bk < num_k_tiles; bk++) { + const unsigned int k_offset = bk * block_k; + + for (unsigned int load_idx = thread_id; load_idx < block_m * block_k; load_idx += num_threads) { + const unsigned int lr = load_idx / block_k, lc = load_idx % block_k; + const unsigned int gr = block_row + lr, gc = k_offset + lc; + As[lr * block_k + lc] = (gr < M && gc < K) ? __half2float(A[gr * K + gc]) : 0.0f; + } + for (unsigned int load_idx = thread_id; load_idx < block_k * block_n; load_idx += num_threads) { + const unsigned int lr = load_idx / block_n, lc = load_idx % block_n; + const unsigned int gr = k_offset + lr, gc = block_col + lc; + Bs[lr * block_n + lc] = (gr < K && gc < N) ? __half2float(B[gr * N + gc]) : 0.0f; + } + __syncthreads(); + + for (unsigned int k = 0; k < block_k; k++) { + for (unsigned int i = 0; i < thread_m; i++) reg_a[i] = As[(thread_row + i) * block_k + k]; + for (unsigned int j = 0; j < thread_n; j++) reg_b[j] = Bs[k * block_n + thread_col + j]; + for (unsigned int i = 0; i < thread_m; i++) + for (unsigned int j = 0; j < thread_n; j++) + reg_c[i][j] += reg_a[i] * reg_b[j]; + } + __syncthreads(); + } + + // EPILOGUE: bias + activation + for (unsigned int i = 0; i < thread_m; i++) { + const unsigned int global_row = block_row + thread_row + i; + if (global_row < M) { + for (unsigned int j = 0; j < thread_n; j++) { + const unsigned int global_col = block_col + thread_col + j; + if (global_col < N) { + float val = reg_c[i][j] + __half2float(bias[global_col]); + C[global_row * N + global_col] = __float2half(apply_activation_f32(val, activation_type)); + } + } + } + } +} + +extern "C" __global__ void gemm_bias_act_batched_f16( + const __half* __restrict__ A, + const __half* __restrict__ B, + const __half* __restrict__ bias, + __half* __restrict__ C, + unsigned int batch_count, + unsigned int M, + unsigned int N, + unsigned int K, + unsigned int activation_type, + unsigned int block_m, + unsigned int block_n, + unsigned int block_k, + unsigned int thread_m, + unsigned int thread_n +) { + const unsigned int batch = blockIdx.z; + if (batch >= batch_count) return; + + extern __shared__ float shared_mem_f16[]; + float* As = shared_mem_f16; + float* Bs = shared_mem_f16 + block_m * block_k; + + const __half* A_batch = A + batch * M * K; + const __half* B_batch = B + batch * K * N; + __half* C_batch = C + batch * M * N; + + const unsigned int tx = threadIdx.x; + const unsigned int ty = threadIdx.y; + const unsigned int threads_x = block_n / thread_n; + const unsigned int block_row = blockIdx.y * block_m; + const unsigned int block_col = blockIdx.x * block_n; + const unsigned int thread_row = ty * thread_m; + const unsigned int thread_col = tx * thread_n; + + float reg_c[8][8]; + #pragma unroll + for (int i = 0; i < 8; i++) + for (int j = 0; j < 8; j++) + reg_c[i][j] = 0.0f; + + float reg_a[8], reg_b[8]; + const unsigned int num_k_tiles = (K + block_k - 1) / block_k; + const unsigned int thread_id = ty * threads_x + tx; + const unsigned int num_threads = blockDim.x * blockDim.y; + + for (unsigned int bk = 0; bk < num_k_tiles; bk++) { + const unsigned int k_offset = bk * block_k; + + for (unsigned int load_idx = thread_id; load_idx < block_m * block_k; load_idx += num_threads) { + const unsigned int lr = load_idx / block_k, lc = load_idx % block_k; + const unsigned int gr = block_row + lr, gc = k_offset + lc; + As[lr * block_k + lc] = (gr < M && gc < K) ? __half2float(A_batch[gr * K + gc]) : 0.0f; + } + for (unsigned int load_idx = thread_id; load_idx < block_k * block_n; load_idx += num_threads) { + const unsigned int lr = load_idx / block_n, lc = load_idx % block_n; + const unsigned int gr = k_offset + lr, gc = block_col + lc; + Bs[lr * block_n + lc] = (gr < K && gc < N) ? __half2float(B_batch[gr * N + gc]) : 0.0f; + } + __syncthreads(); + + for (unsigned int k = 0; k < block_k; k++) { + for (unsigned int i = 0; i < thread_m; i++) reg_a[i] = As[(thread_row + i) * block_k + k]; + for (unsigned int j = 0; j < thread_n; j++) reg_b[j] = Bs[k * block_n + thread_col + j]; + for (unsigned int i = 0; i < thread_m; i++) + for (unsigned int j = 0; j < thread_n; j++) + reg_c[i][j] += reg_a[i] * reg_b[j]; + } + __syncthreads(); + } + + for (unsigned int i = 0; i < thread_m; i++) { + const unsigned int global_row = block_row + thread_row + i; + if (global_row < M) { + for (unsigned int j = 0; j < thread_n; j++) { + const unsigned int global_col = block_col + thread_col + j; + if (global_col < N) { + float val = reg_c[i][j] + __half2float(bias[global_col]); + C_batch[global_row * N + global_col] = __float2half(apply_activation_f32(val, activation_type)); + } + } + } + } +} + +// ============================================================================ +// GEMM + bias + residual: F16 +// ============================================================================ + +extern "C" __global__ void gemm_bias_residual_f16( + const __half* __restrict__ A, + const __half* __restrict__ B, + const __half* __restrict__ bias, + const __half* __restrict__ residual, + __half* __restrict__ C, + unsigned int M, + unsigned int N, + unsigned int K, + unsigned int block_m, + unsigned int block_n, + unsigned int block_k, + unsigned int thread_m, + unsigned int thread_n +) { + extern __shared__ float shared_mem_f16[]; + float* As = shared_mem_f16; + float* Bs = shared_mem_f16 + block_m * block_k; + + const unsigned int tx = threadIdx.x; + const unsigned int ty = threadIdx.y; + const unsigned int threads_x = block_n / thread_n; + const unsigned int block_row = blockIdx.y * block_m; + const unsigned int block_col = blockIdx.x * block_n; + const unsigned int thread_row = ty * thread_m; + const unsigned int thread_col = tx * thread_n; + + float reg_c[8][8]; + #pragma unroll + for (int i = 0; i < 8; i++) + for (int j = 0; j < 8; j++) + reg_c[i][j] = 0.0f; + + float reg_a[8], reg_b[8]; + const unsigned int num_k_tiles = (K + block_k - 1) / block_k; + const unsigned int thread_id = ty * threads_x + tx; + const unsigned int num_threads = blockDim.x * blockDim.y; + + for (unsigned int bk = 0; bk < num_k_tiles; bk++) { + const unsigned int k_offset = bk * block_k; + + for (unsigned int load_idx = thread_id; load_idx < block_m * block_k; load_idx += num_threads) { + const unsigned int lr = load_idx / block_k, lc = load_idx % block_k; + const unsigned int gr = block_row + lr, gc = k_offset + lc; + As[lr * block_k + lc] = (gr < M && gc < K) ? __half2float(A[gr * K + gc]) : 0.0f; + } + for (unsigned int load_idx = thread_id; load_idx < block_k * block_n; load_idx += num_threads) { + const unsigned int lr = load_idx / block_n, lc = load_idx % block_n; + const unsigned int gr = k_offset + lr, gc = block_col + lc; + Bs[lr * block_n + lc] = (gr < K && gc < N) ? __half2float(B[gr * N + gc]) : 0.0f; + } + __syncthreads(); + + for (unsigned int k = 0; k < block_k; k++) { + for (unsigned int i = 0; i < thread_m; i++) reg_a[i] = As[(thread_row + i) * block_k + k]; + for (unsigned int j = 0; j < thread_n; j++) reg_b[j] = Bs[k * block_n + thread_col + j]; + for (unsigned int i = 0; i < thread_m; i++) + for (unsigned int j = 0; j < thread_n; j++) + reg_c[i][j] += reg_a[i] * reg_b[j]; + } + __syncthreads(); + } + + // EPILOGUE: bias + residual + for (unsigned int i = 0; i < thread_m; i++) { + const unsigned int global_row = block_row + thread_row + i; + if (global_row < M) { + for (unsigned int j = 0; j < thread_n; j++) { + const unsigned int global_col = block_col + thread_col + j; + if (global_col < N) { + unsigned int idx = global_row * N + global_col; + float val = reg_c[i][j] + __half2float(bias[global_col]) + __half2float(residual[idx]); + C[idx] = __float2half(val); + } + } + } + } +} + +extern "C" __global__ void gemm_bias_residual_batched_f16( + const __half* __restrict__ A, + const __half* __restrict__ B, + const __half* __restrict__ bias, + const __half* __restrict__ residual, + __half* __restrict__ C, + unsigned int batch_count, + unsigned int M, + unsigned int N, + unsigned int K, + unsigned int block_m, + unsigned int block_n, + unsigned int block_k, + unsigned int thread_m, + unsigned int thread_n +) { + const unsigned int batch = blockIdx.z; + if (batch >= batch_count) return; + + extern __shared__ float shared_mem_f16[]; + float* As = shared_mem_f16; + float* Bs = shared_mem_f16 + block_m * block_k; + + const __half* A_batch = A + batch * M * K; + const __half* B_batch = B + batch * K * N; + const __half* res_batch = residual + batch * M * N; + __half* C_batch = C + batch * M * N; + + const unsigned int tx = threadIdx.x; + const unsigned int ty = threadIdx.y; + const unsigned int threads_x = block_n / thread_n; + const unsigned int block_row = blockIdx.y * block_m; + const unsigned int block_col = blockIdx.x * block_n; + const unsigned int thread_row = ty * thread_m; + const unsigned int thread_col = tx * thread_n; + + float reg_c[8][8]; + for (int i = 0; i < 8; i++) + for (int j = 0; j < 8; j++) + reg_c[i][j] = 0.0f; + + float reg_a[8], reg_b[8]; + const unsigned int num_k_tiles = (K + block_k - 1) / block_k; + const unsigned int thread_id = ty * threads_x + tx; + const unsigned int num_threads = blockDim.x * blockDim.y; + + for (unsigned int bk = 0; bk < num_k_tiles; bk++) { + const unsigned int k_offset = bk * block_k; + + for (unsigned int load_idx = thread_id; load_idx < block_m * block_k; load_idx += num_threads) { + const unsigned int lr = load_idx / block_k, lc = load_idx % block_k; + const unsigned int gr = block_row + lr, gc = k_offset + lc; + As[lr * block_k + lc] = (gr < M && gc < K) ? __half2float(A_batch[gr * K + gc]) : 0.0f; + } + for (unsigned int load_idx = thread_id; load_idx < block_k * block_n; load_idx += num_threads) { + const unsigned int lr = load_idx / block_n, lc = load_idx % block_n; + const unsigned int gr = k_offset + lr, gc = block_col + lc; + Bs[lr * block_n + lc] = (gr < K && gc < N) ? __half2float(B_batch[gr * N + gc]) : 0.0f; + } + __syncthreads(); + + for (unsigned int k = 0; k < block_k; k++) { + for (unsigned int i = 0; i < thread_m; i++) reg_a[i] = As[(thread_row + i) * block_k + k]; + for (unsigned int j = 0; j < thread_n; j++) reg_b[j] = Bs[k * block_n + thread_col + j]; + for (unsigned int i = 0; i < thread_m; i++) + for (unsigned int j = 0; j < thread_n; j++) + reg_c[i][j] += reg_a[i] * reg_b[j]; + } + __syncthreads(); + } + + for (unsigned int i = 0; i < thread_m; i++) { + const unsigned int global_row = block_row + thread_row + i; + if (global_row < M) { + for (unsigned int j = 0; j < thread_n; j++) { + const unsigned int global_col = block_col + thread_col + j; + if (global_col < N) { + unsigned int idx = global_row * N + global_col; + float val = reg_c[i][j] + __half2float(bias[global_col]) + __half2float(res_batch[idx]); + C_batch[idx] = __float2half(val); + } + } + } + } +} + +// ============================================================================ +// GEMM + bias + activation: BF16 +// ============================================================================ + +extern "C" __global__ void gemm_bias_act_bf16( + const __nv_bfloat16* __restrict__ A, + const __nv_bfloat16* __restrict__ B, + const __nv_bfloat16* __restrict__ bias, + __nv_bfloat16* __restrict__ C, + unsigned int M, + unsigned int N, + unsigned int K, + unsigned int activation_type, + unsigned int block_m, + unsigned int block_n, + unsigned int block_k, + unsigned int thread_m, + unsigned int thread_n +) { + extern __shared__ float shared_mem_bf16[]; + float* As = shared_mem_bf16; + float* Bs = shared_mem_bf16 + block_m * block_k; + + const unsigned int tx = threadIdx.x; + const unsigned int ty = threadIdx.y; + const unsigned int threads_x = block_n / thread_n; + const unsigned int block_row = blockIdx.y * block_m; + const unsigned int block_col = blockIdx.x * block_n; + const unsigned int thread_row = ty * thread_m; + const unsigned int thread_col = tx * thread_n; + + float reg_c[8][8]; + #pragma unroll + for (int i = 0; i < 8; i++) + #pragma unroll + for (int j = 0; j < 8; j++) + reg_c[i][j] = 0.0f; + + float reg_a[8]; + float reg_b[8]; + + const unsigned int num_k_tiles = (K + block_k - 1) / block_k; + const unsigned int thread_id = ty * threads_x + tx; + const unsigned int num_threads = blockDim.x * blockDim.y; + + for (unsigned int bk = 0; bk < num_k_tiles; bk++) { + const unsigned int k_offset = bk * block_k; + + for (unsigned int load_idx = thread_id; load_idx < block_m * block_k; load_idx += num_threads) { + const unsigned int lr = load_idx / block_k, lc = load_idx % block_k; + const unsigned int gr = block_row + lr, gc = k_offset + lc; + As[lr * block_k + lc] = (gr < M && gc < K) ? __bfloat162float(A[gr * K + gc]) : 0.0f; + } + for (unsigned int load_idx = thread_id; load_idx < block_k * block_n; load_idx += num_threads) { + const unsigned int lr = load_idx / block_n, lc = load_idx % block_n; + const unsigned int gr = k_offset + lr, gc = block_col + lc; + Bs[lr * block_n + lc] = (gr < K && gc < N) ? __bfloat162float(B[gr * N + gc]) : 0.0f; + } + __syncthreads(); + + for (unsigned int k = 0; k < block_k; k++) { + for (unsigned int i = 0; i < thread_m; i++) reg_a[i] = As[(thread_row + i) * block_k + k]; + for (unsigned int j = 0; j < thread_n; j++) reg_b[j] = Bs[k * block_n + thread_col + j]; + for (unsigned int i = 0; i < thread_m; i++) + for (unsigned int j = 0; j < thread_n; j++) + reg_c[i][j] += reg_a[i] * reg_b[j]; + } + __syncthreads(); + } + + // EPILOGUE: bias + activation + for (unsigned int i = 0; i < thread_m; i++) { + const unsigned int global_row = block_row + thread_row + i; + if (global_row < M) { + for (unsigned int j = 0; j < thread_n; j++) { + const unsigned int global_col = block_col + thread_col + j; + if (global_col < N) { + float val = reg_c[i][j] + __bfloat162float(bias[global_col]); + C[global_row * N + global_col] = __float2bfloat16(apply_activation_f32(val, activation_type)); + } + } + } + } +} + +extern "C" __global__ void gemm_bias_act_batched_bf16( + const __nv_bfloat16* __restrict__ A, + const __nv_bfloat16* __restrict__ B, + const __nv_bfloat16* __restrict__ bias, + __nv_bfloat16* __restrict__ C, + unsigned int batch_count, + unsigned int M, + unsigned int N, + unsigned int K, + unsigned int activation_type, + unsigned int block_m, + unsigned int block_n, + unsigned int block_k, + unsigned int thread_m, + unsigned int thread_n +) { + const unsigned int batch = blockIdx.z; + if (batch >= batch_count) return; + + extern __shared__ float shared_mem_bf16[]; + float* As = shared_mem_bf16; + float* Bs = shared_mem_bf16 + block_m * block_k; + + const __nv_bfloat16* A_batch = A + batch * M * K; + const __nv_bfloat16* B_batch = B + batch * K * N; + __nv_bfloat16* C_batch = C + batch * M * N; + + const unsigned int tx = threadIdx.x; + const unsigned int ty = threadIdx.y; + const unsigned int threads_x = block_n / thread_n; + const unsigned int block_row = blockIdx.y * block_m; + const unsigned int block_col = blockIdx.x * block_n; + const unsigned int thread_row = ty * thread_m; + const unsigned int thread_col = tx * thread_n; + + float reg_c[8][8]; + #pragma unroll + for (int i = 0; i < 8; i++) + for (int j = 0; j < 8; j++) + reg_c[i][j] = 0.0f; + + float reg_a[8], reg_b[8]; + const unsigned int num_k_tiles = (K + block_k - 1) / block_k; + const unsigned int thread_id = ty * threads_x + tx; + const unsigned int num_threads = blockDim.x * blockDim.y; + + for (unsigned int bk = 0; bk < num_k_tiles; bk++) { + const unsigned int k_offset = bk * block_k; + + for (unsigned int load_idx = thread_id; load_idx < block_m * block_k; load_idx += num_threads) { + const unsigned int lr = load_idx / block_k, lc = load_idx % block_k; + const unsigned int gr = block_row + lr, gc = k_offset + lc; + As[lr * block_k + lc] = (gr < M && gc < K) ? __bfloat162float(A_batch[gr * K + gc]) : 0.0f; + } + for (unsigned int load_idx = thread_id; load_idx < block_k * block_n; load_idx += num_threads) { + const unsigned int lr = load_idx / block_n, lc = load_idx % block_n; + const unsigned int gr = k_offset + lr, gc = block_col + lc; + Bs[lr * block_n + lc] = (gr < K && gc < N) ? __bfloat162float(B_batch[gr * N + gc]) : 0.0f; + } + __syncthreads(); + + for (unsigned int k = 0; k < block_k; k++) { + for (unsigned int i = 0; i < thread_m; i++) reg_a[i] = As[(thread_row + i) * block_k + k]; + for (unsigned int j = 0; j < thread_n; j++) reg_b[j] = Bs[k * block_n + thread_col + j]; + for (unsigned int i = 0; i < thread_m; i++) + for (unsigned int j = 0; j < thread_n; j++) + reg_c[i][j] += reg_a[i] * reg_b[j]; + } + __syncthreads(); + } + + for (unsigned int i = 0; i < thread_m; i++) { + const unsigned int global_row = block_row + thread_row + i; + if (global_row < M) { + for (unsigned int j = 0; j < thread_n; j++) { + const unsigned int global_col = block_col + thread_col + j; + if (global_col < N) { + float val = reg_c[i][j] + __bfloat162float(bias[global_col]); + C_batch[global_row * N + global_col] = __float2bfloat16(apply_activation_f32(val, activation_type)); + } + } + } + } +} + +// ============================================================================ +// GEMM + bias + residual: BF16 +// ============================================================================ + +extern "C" __global__ void gemm_bias_residual_bf16( + const __nv_bfloat16* __restrict__ A, + const __nv_bfloat16* __restrict__ B, + const __nv_bfloat16* __restrict__ bias, + const __nv_bfloat16* __restrict__ residual, + __nv_bfloat16* __restrict__ C, + unsigned int M, + unsigned int N, + unsigned int K, + unsigned int block_m, + unsigned int block_n, + unsigned int block_k, + unsigned int thread_m, + unsigned int thread_n +) { + extern __shared__ float shared_mem_bf16[]; + float* As = shared_mem_bf16; + float* Bs = shared_mem_bf16 + block_m * block_k; + + const unsigned int tx = threadIdx.x; + const unsigned int ty = threadIdx.y; + const unsigned int threads_x = block_n / thread_n; + const unsigned int block_row = blockIdx.y * block_m; + const unsigned int block_col = blockIdx.x * block_n; + const unsigned int thread_row = ty * thread_m; + const unsigned int thread_col = tx * thread_n; + + float reg_c[8][8]; + #pragma unroll + for (int i = 0; i < 8; i++) + for (int j = 0; j < 8; j++) + reg_c[i][j] = 0.0f; + + float reg_a[8], reg_b[8]; + const unsigned int num_k_tiles = (K + block_k - 1) / block_k; + const unsigned int thread_id = ty * threads_x + tx; + const unsigned int num_threads = blockDim.x * blockDim.y; + + for (unsigned int bk = 0; bk < num_k_tiles; bk++) { + const unsigned int k_offset = bk * block_k; + + for (unsigned int load_idx = thread_id; load_idx < block_m * block_k; load_idx += num_threads) { + const unsigned int lr = load_idx / block_k, lc = load_idx % block_k; + const unsigned int gr = block_row + lr, gc = k_offset + lc; + As[lr * block_k + lc] = (gr < M && gc < K) ? __bfloat162float(A[gr * K + gc]) : 0.0f; + } + for (unsigned int load_idx = thread_id; load_idx < block_k * block_n; load_idx += num_threads) { + const unsigned int lr = load_idx / block_n, lc = load_idx % block_n; + const unsigned int gr = k_offset + lr, gc = block_col + lc; + Bs[lr * block_n + lc] = (gr < K && gc < N) ? __bfloat162float(B[gr * N + gc]) : 0.0f; + } + __syncthreads(); + + for (unsigned int k = 0; k < block_k; k++) { + for (unsigned int i = 0; i < thread_m; i++) reg_a[i] = As[(thread_row + i) * block_k + k]; + for (unsigned int j = 0; j < thread_n; j++) reg_b[j] = Bs[k * block_n + thread_col + j]; + for (unsigned int i = 0; i < thread_m; i++) + for (unsigned int j = 0; j < thread_n; j++) + reg_c[i][j] += reg_a[i] * reg_b[j]; + } + __syncthreads(); + } + + // EPILOGUE: bias + residual + for (unsigned int i = 0; i < thread_m; i++) { + const unsigned int global_row = block_row + thread_row + i; + if (global_row < M) { + for (unsigned int j = 0; j < thread_n; j++) { + const unsigned int global_col = block_col + thread_col + j; + if (global_col < N) { + unsigned int idx = global_row * N + global_col; + float val = reg_c[i][j] + __bfloat162float(bias[global_col]) + __bfloat162float(residual[idx]); + C[idx] = __float2bfloat16(val); + } + } + } + } +} + +extern "C" __global__ void gemm_bias_residual_batched_bf16( + const __nv_bfloat16* __restrict__ A, + const __nv_bfloat16* __restrict__ B, + const __nv_bfloat16* __restrict__ bias, + const __nv_bfloat16* __restrict__ residual, + __nv_bfloat16* __restrict__ C, + unsigned int batch_count, + unsigned int M, + unsigned int N, + unsigned int K, + unsigned int block_m, + unsigned int block_n, + unsigned int block_k, + unsigned int thread_m, + unsigned int thread_n +) { + const unsigned int batch = blockIdx.z; + if (batch >= batch_count) return; + + extern __shared__ float shared_mem_bf16[]; + float* As = shared_mem_bf16; + float* Bs = shared_mem_bf16 + block_m * block_k; + + const __nv_bfloat16* A_batch = A + batch * M * K; + const __nv_bfloat16* B_batch = B + batch * K * N; + const __nv_bfloat16* res_batch = residual + batch * M * N; + __nv_bfloat16* C_batch = C + batch * M * N; + + const unsigned int tx = threadIdx.x; + const unsigned int ty = threadIdx.y; + const unsigned int threads_x = block_n / thread_n; + const unsigned int block_row = blockIdx.y * block_m; + const unsigned int block_col = blockIdx.x * block_n; + const unsigned int thread_row = ty * thread_m; + const unsigned int thread_col = tx * thread_n; + + float reg_c[8][8]; + for (int i = 0; i < 8; i++) + for (int j = 0; j < 8; j++) + reg_c[i][j] = 0.0f; + + float reg_a[8], reg_b[8]; + const unsigned int num_k_tiles = (K + block_k - 1) / block_k; + const unsigned int thread_id = ty * threads_x + tx; + const unsigned int num_threads = blockDim.x * blockDim.y; + + for (unsigned int bk = 0; bk < num_k_tiles; bk++) { + const unsigned int k_offset = bk * block_k; + + for (unsigned int load_idx = thread_id; load_idx < block_m * block_k; load_idx += num_threads) { + const unsigned int lr = load_idx / block_k, lc = load_idx % block_k; + const unsigned int gr = block_row + lr, gc = k_offset + lc; + As[lr * block_k + lc] = (gr < M && gc < K) ? __bfloat162float(A_batch[gr * K + gc]) : 0.0f; + } + for (unsigned int load_idx = thread_id; load_idx < block_k * block_n; load_idx += num_threads) { + const unsigned int lr = load_idx / block_n, lc = load_idx % block_n; + const unsigned int gr = k_offset + lr, gc = block_col + lc; + Bs[lr * block_n + lc] = (gr < K && gc < N) ? __bfloat162float(B_batch[gr * N + gc]) : 0.0f; + } + __syncthreads(); + + for (unsigned int k = 0; k < block_k; k++) { + for (unsigned int i = 0; i < thread_m; i++) reg_a[i] = As[(thread_row + i) * block_k + k]; + for (unsigned int j = 0; j < thread_n; j++) reg_b[j] = Bs[k * block_n + thread_col + j]; + for (unsigned int i = 0; i < thread_m; i++) + for (unsigned int j = 0; j < thread_n; j++) + reg_c[i][j] += reg_a[i] * reg_b[j]; + } + __syncthreads(); + } + + for (unsigned int i = 0; i < thread_m; i++) { + const unsigned int global_row = block_row + thread_row + i; + if (global_row < M) { + for (unsigned int j = 0; j < thread_n; j++) { + const unsigned int global_col = block_col + thread_col + j; + if (global_col < N) { + unsigned int idx = global_row * N + global_col; + float val = reg_c[i][j] + __bfloat162float(bias[global_col]) + __bfloat162float(res_batch[idx]); + C_batch[idx] = __float2bfloat16(val); + } + } + } + } +} diff --git a/src/runtime/cuda/kernels/gemm_epilogue/bwd_launcher.rs b/src/runtime/cuda/kernels/gemm_epilogue/bwd_launcher.rs new file mode 100644 index 00000000..d54784ad --- /dev/null +++ b/src/runtime/cuda/kernels/gemm_epilogue/bwd_launcher.rs @@ -0,0 +1,259 @@ +//! CUDA kernel launchers for GEMM epilogue backward operations. + +use cudarc::driver::PushKernelArg; +use cudarc::driver::safe::{CudaContext, CudaStream}; +use std::sync::Arc; + +use super::super::loader::{get_kernel_function, get_or_load_module, kernel_name, launch_config}; +use crate::dtype::DType; +use crate::error::{Error, Result}; +use crate::ops::GemmActivation; + +const GEMM_EPILOGUE_BWD_MODULE: &str = "gemm_epilogue_bwd"; +const BLOCK_SIZE: u32 = 256; + +fn activation_to_u32(activation: GemmActivation) -> u32 { + match activation { + GemmActivation::None => 0, + GemmActivation::ReLU => 1, + GemmActivation::GELU => 2, + GemmActivation::SiLU => 3, + GemmActivation::Sigmoid => 4, + GemmActivation::Tanh => 5, + } +} + +fn grid_1d(n: u32) -> (u32, u32, u32) { + ((n + BLOCK_SIZE - 1) / BLOCK_SIZE, 1, 1) +} + +fn block_1d() -> (u32, u32, u32) { + (BLOCK_SIZE, 1, 1) +} + +/// Launch a single-batch GEMM backward pass (4 kernel launches). +/// +/// # Safety +/// All pointers must be valid device memory with correct sizes. +/// `grad_pre_ptr` must point to a temporary buffer of size `m * n * dtype.size_in_bytes()`. +#[allow(clippy::too_many_arguments)] +pub unsafe fn launch_gemm_bias_act_bwd_kernel( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + grad_ptr: u64, + a_ptr: u64, + b_ptr: u64, + bias_ptr: u64, + grad_pre_ptr: u64, + d_a_ptr: u64, + d_b_ptr: u64, + d_bias_ptr: u64, + m: usize, + n: usize, + k: usize, + activation: GemmActivation, +) -> Result<()> { + unsafe { + launch_gemm_bwd_kernels( + context, + stream, + device_index, + dtype, + grad_ptr, + a_ptr, + b_ptr, + bias_ptr, + grad_pre_ptr, + d_a_ptr, + d_b_ptr, + d_bias_ptr, + m, + n, + k, + activation, + false, // don't accumulate d_b/d_bias + ) + } +} + +/// Launch batched GEMM backward pass. +/// +/// Batch 0 writes d_b/d_bias, batches 1+ accumulate into d_b/d_bias. +/// d_a is written per-batch at offset. +/// +/// # Safety +/// All pointers must be valid device memory with correct sizes. +/// `grad_pre_ptr` must point to a temporary buffer of size `m * n * dtype.size_in_bytes()`. +#[allow(clippy::too_many_arguments)] +pub unsafe fn launch_gemm_bias_act_bwd_batched_kernel( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + grad_ptr: u64, + a_ptr: u64, + b_ptr: u64, + bias_ptr: u64, + grad_pre_ptr: u64, + d_a_ptr: u64, + d_b_ptr: u64, + d_bias_ptr: u64, + batch: usize, + m: usize, + n: usize, + k: usize, + activation: GemmActivation, +) -> Result<()> { + let elem_size = dtype.size_in_bytes() as u64; + let mn_bytes = (m * n) as u64 * elem_size; + let mk_bytes = (m * k) as u64 * elem_size; + let kn_bytes = (k * n) as u64 * elem_size; + + for batch_idx in 0..batch { + let grad_off = grad_ptr + batch_idx as u64 * mn_bytes; + let a_off = a_ptr + batch_idx as u64 * mk_bytes; + let b_off = b_ptr + batch_idx as u64 * kn_bytes; + let d_a_off = d_a_ptr + batch_idx as u64 * mk_bytes; + let accumulate = batch_idx > 0; + + unsafe { + launch_gemm_bwd_kernels( + context, + stream, + device_index, + dtype, + grad_off, + a_off, + b_off, + bias_ptr, + grad_pre_ptr, + d_a_off, + d_b_ptr, + d_bias_ptr, + m, + n, + k, + activation, + accumulate, + )?; + } + } + + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +unsafe fn launch_gemm_bwd_kernels( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + grad_ptr: u64, + a_ptr: u64, + b_ptr: u64, + bias_ptr: u64, + grad_pre_ptr: u64, + d_a_ptr: u64, + d_b_ptr: u64, + d_bias_ptr: u64, + m: usize, + n: usize, + k: usize, + activation: GemmActivation, + accumulate: bool, +) -> Result<()> { + let module = get_or_load_module(context, device_index, GEMM_EPILOGUE_BWD_MODULE)?; + + let m_u32 = m as u32; + let n_u32 = n as u32; + let k_u32 = k as u32; + let act_u32 = activation_to_u32(activation); + let mn = (m * n) as u32; + let mk = (m * k) as u32; + let kn = (k * n) as u32; + + unsafe { + // Kernel 1: grad_pre = grad * act'(A @ B + bias) + { + let func_name = kernel_name("gemm_bias_act_bwd_grad_pre", dtype); + let func = get_kernel_function(&module, &func_name)?; + let cfg = launch_config(grid_1d(mn), block_1d(), 0); + let mut builder = stream.launch_builder(&func); + builder.arg(&grad_ptr); + builder.arg(&a_ptr); + builder.arg(&b_ptr); + builder.arg(&bias_ptr); + builder.arg(&grad_pre_ptr); + builder.arg(&m_u32); + builder.arg(&n_u32); + builder.arg(&k_u32); + builder.arg(&act_u32); + builder.launch(cfg).map_err(|e| { + Error::Internal(format!("CUDA gemm_bwd_grad_pre launch failed: {:?}", e)) + })?; + } + + // Kernel 2: d_a = grad_pre @ B^T (always write, not accumulate) + { + let func_name = kernel_name("gemm_bwd_da", dtype); + let func = get_kernel_function(&module, &func_name)?; + let cfg = launch_config(grid_1d(mk), block_1d(), 0); + let mut builder = stream.launch_builder(&func); + builder.arg(&grad_pre_ptr); + builder.arg(&b_ptr); + builder.arg(&d_a_ptr); + builder.arg(&m_u32); + builder.arg(&n_u32); + builder.arg(&k_u32); + builder + .launch(cfg) + .map_err(|e| Error::Internal(format!("CUDA gemm_bwd_da launch failed: {:?}", e)))?; + } + + // Kernel 3: d_b = A^T @ grad_pre (or d_b += for accumulate) + { + let base = if accumulate { + "gemm_bwd_db_accum" + } else { + "gemm_bwd_db" + }; + let func_name = kernel_name(base, dtype); + let func = get_kernel_function(&module, &func_name)?; + let cfg = launch_config(grid_1d(kn), block_1d(), 0); + let mut builder = stream.launch_builder(&func); + builder.arg(&a_ptr); + builder.arg(&grad_pre_ptr); + builder.arg(&d_b_ptr); + builder.arg(&m_u32); + builder.arg(&n_u32); + builder.arg(&k_u32); + builder + .launch(cfg) + .map_err(|e| Error::Internal(format!("CUDA gemm_bwd_db launch failed: {:?}", e)))?; + } + + // Kernel 4: d_bias = sum(grad_pre, dim=0) (or += for accumulate) + { + let base = if accumulate { + "gemm_bwd_dbias_accum" + } else { + "gemm_bwd_dbias" + }; + let func_name = kernel_name(base, dtype); + let func = get_kernel_function(&module, &func_name)?; + let cfg = launch_config(grid_1d(n_u32), block_1d(), 0); + let mut builder = stream.launch_builder(&func); + builder.arg(&grad_pre_ptr); + builder.arg(&d_bias_ptr); + builder.arg(&m_u32); + builder.arg(&n_u32); + builder.launch(cfg).map_err(|e| { + Error::Internal(format!("CUDA gemm_bwd_dbias launch failed: {:?}", e)) + })?; + } + } + + Ok(()) +} diff --git a/src/runtime/cuda/kernels/gemm_epilogue/launcher.rs b/src/runtime/cuda/kernels/gemm_epilogue/launcher.rs new file mode 100644 index 00000000..43d99ee9 --- /dev/null +++ b/src/runtime/cuda/kernels/gemm_epilogue/launcher.rs @@ -0,0 +1,319 @@ +//! CUDA kernel launchers for GEMM epilogue operations. + +use cudarc::driver::PushKernelArg; +use cudarc::driver::safe::{CudaContext, CudaStream}; +use std::sync::Arc; + +use super::super::loader::{ + get_kernel_function, get_or_load_module, kernel_name, matmul_batched_launch_config, + matmul_launch_config, +}; +use crate::algorithm::TileConfig; +use crate::dtype::DType; +use crate::error::{Error, Result}; +use crate::ops::GemmActivation; + +const GEMM_EPILOGUE_MODULE: &str = "gemm_epilogue"; + +fn activation_to_u32(activation: GemmActivation) -> u32 { + match activation { + GemmActivation::None => 0, + GemmActivation::ReLU => 1, + GemmActivation::GELU => 2, + GemmActivation::SiLU => 3, + GemmActivation::Sigmoid => 4, + GemmActivation::Tanh => 5, + } +} + +fn default_tile_config(dtype: DType) -> TileConfig { + match dtype { + DType::F64 => TileConfig { + block_m: 32, + block_n: 32, + block_k: 8, + thread_m: 4, + thread_n: 4, + }, + _ => TileConfig { + block_m: 64, + block_n: 64, + block_k: 8, + thread_m: 8, + thread_n: 8, + }, + } +} + +/// Launch fused GEMM + bias + activation kernel. +/// +/// # Safety +/// All pointers must be valid device memory. +#[allow(clippy::too_many_arguments)] +pub unsafe fn launch_gemm_bias_act_kernel( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + a_ptr: u64, + b_ptr: u64, + bias_ptr: u64, + c_ptr: u64, + m: usize, + n: usize, + k: usize, + activation: GemmActivation, +) -> Result<()> { + let tile_cfg = default_tile_config(dtype); + let module = get_or_load_module(context, device_index, GEMM_EPILOGUE_MODULE)?; + let func_name = kernel_name("gemm_bias_act", dtype); + let func = get_kernel_function(&module, &func_name)?; + + let elem_size = dtype.size_in_bytes(); + let shared_elem_size = match dtype { + DType::F16 | DType::BF16 => 4, + _ => elem_size, + }; + + let cfg = matmul_launch_config(m, n, &tile_cfg, shared_elem_size); + let m_u32 = m as u32; + let n_u32 = n as u32; + let k_u32 = k as u32; + let act_u32 = activation_to_u32(activation); + let block_m = tile_cfg.block_m as u32; + let block_n = tile_cfg.block_n as u32; + let block_k = tile_cfg.block_k as u32; + let thread_m = tile_cfg.thread_m as u32; + let thread_n = tile_cfg.thread_n as u32; + + unsafe { + let mut builder = stream.launch_builder(&func); + builder.arg(&a_ptr); + builder.arg(&b_ptr); + builder.arg(&bias_ptr); + builder.arg(&c_ptr); + builder.arg(&m_u32); + builder.arg(&n_u32); + builder.arg(&k_u32); + builder.arg(&act_u32); + builder.arg(&block_m); + builder.arg(&block_n); + builder.arg(&block_k); + builder.arg(&thread_m); + builder.arg(&thread_n); + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!("CUDA gemm_bias_act kernel launch failed: {:?}", e)) + })?; + } + + Ok(()) +} + +/// Launch batched fused GEMM + bias + activation kernel. +/// +/// # Safety +/// All pointers must be valid device memory. +#[allow(clippy::too_many_arguments)] +pub unsafe fn launch_gemm_bias_act_batched_kernel( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + a_ptr: u64, + b_ptr: u64, + bias_ptr: u64, + c_ptr: u64, + batch: usize, + m: usize, + n: usize, + k: usize, + activation: GemmActivation, +) -> Result<()> { + let tile_cfg = default_tile_config(dtype); + let module = get_or_load_module(context, device_index, GEMM_EPILOGUE_MODULE)?; + let func_name = kernel_name("gemm_bias_act_batched", dtype); + let func = get_kernel_function(&module, &func_name)?; + + let elem_size = dtype.size_in_bytes(); + let shared_elem_size = match dtype { + DType::F16 | DType::BF16 => 4, + _ => elem_size, + }; + + let cfg = matmul_batched_launch_config(batch, m, n, &tile_cfg, shared_elem_size); + let batch_u32 = batch as u32; + let m_u32 = m as u32; + let n_u32 = n as u32; + let k_u32 = k as u32; + let act_u32 = activation_to_u32(activation); + let block_m = tile_cfg.block_m as u32; + let block_n = tile_cfg.block_n as u32; + let block_k = tile_cfg.block_k as u32; + let thread_m = tile_cfg.thread_m as u32; + let thread_n = tile_cfg.thread_n as u32; + + unsafe { + let mut builder = stream.launch_builder(&func); + builder.arg(&a_ptr); + builder.arg(&b_ptr); + builder.arg(&bias_ptr); + builder.arg(&c_ptr); + builder.arg(&batch_u32); + builder.arg(&m_u32); + builder.arg(&n_u32); + builder.arg(&k_u32); + builder.arg(&act_u32); + builder.arg(&block_m); + builder.arg(&block_n); + builder.arg(&block_k); + builder.arg(&thread_m); + builder.arg(&thread_n); + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!( + "CUDA gemm_bias_act_batched kernel launch failed: {:?}", + e + )) + })?; + } + + Ok(()) +} + +/// Launch fused GEMM + bias + residual kernel. +/// +/// # Safety +/// All pointers must be valid device memory. +#[allow(clippy::too_many_arguments)] +pub unsafe fn launch_gemm_bias_residual_kernel( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + a_ptr: u64, + b_ptr: u64, + bias_ptr: u64, + residual_ptr: u64, + c_ptr: u64, + m: usize, + n: usize, + k: usize, +) -> Result<()> { + let tile_cfg = default_tile_config(dtype); + let module = get_or_load_module(context, device_index, GEMM_EPILOGUE_MODULE)?; + let func_name = kernel_name("gemm_bias_residual", dtype); + let func = get_kernel_function(&module, &func_name)?; + + let elem_size = dtype.size_in_bytes(); + let shared_elem_size = match dtype { + DType::F16 | DType::BF16 => 4, + _ => elem_size, + }; + + let cfg = matmul_launch_config(m, n, &tile_cfg, shared_elem_size); + let m_u32 = m as u32; + let n_u32 = n as u32; + let k_u32 = k as u32; + let block_m = tile_cfg.block_m as u32; + let block_n = tile_cfg.block_n as u32; + let block_k = tile_cfg.block_k as u32; + let thread_m = tile_cfg.thread_m as u32; + let thread_n = tile_cfg.thread_n as u32; + + unsafe { + let mut builder = stream.launch_builder(&func); + builder.arg(&a_ptr); + builder.arg(&b_ptr); + builder.arg(&bias_ptr); + builder.arg(&residual_ptr); + builder.arg(&c_ptr); + builder.arg(&m_u32); + builder.arg(&n_u32); + builder.arg(&k_u32); + builder.arg(&block_m); + builder.arg(&block_n); + builder.arg(&block_k); + builder.arg(&thread_m); + builder.arg(&thread_n); + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!( + "CUDA gemm_bias_residual kernel launch failed: {:?}", + e + )) + })?; + } + + Ok(()) +} + +/// Launch batched fused GEMM + bias + residual kernel. +/// +/// # Safety +/// All pointers must be valid device memory. +#[allow(clippy::too_many_arguments)] +pub unsafe fn launch_gemm_bias_residual_batched_kernel( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + a_ptr: u64, + b_ptr: u64, + bias_ptr: u64, + residual_ptr: u64, + c_ptr: u64, + batch: usize, + m: usize, + n: usize, + k: usize, +) -> Result<()> { + let tile_cfg = default_tile_config(dtype); + let module = get_or_load_module(context, device_index, GEMM_EPILOGUE_MODULE)?; + let func_name = kernel_name("gemm_bias_residual_batched", dtype); + let func = get_kernel_function(&module, &func_name)?; + + let elem_size = dtype.size_in_bytes(); + let shared_elem_size = match dtype { + DType::F16 | DType::BF16 => 4, + _ => elem_size, + }; + + let cfg = matmul_batched_launch_config(batch, m, n, &tile_cfg, shared_elem_size); + let batch_u32 = batch as u32; + let m_u32 = m as u32; + let n_u32 = n as u32; + let k_u32 = k as u32; + let block_m = tile_cfg.block_m as u32; + let block_n = tile_cfg.block_n as u32; + let block_k = tile_cfg.block_k as u32; + let thread_m = tile_cfg.thread_m as u32; + let thread_n = tile_cfg.thread_n as u32; + + unsafe { + let mut builder = stream.launch_builder(&func); + builder.arg(&a_ptr); + builder.arg(&b_ptr); + builder.arg(&bias_ptr); + builder.arg(&residual_ptr); + builder.arg(&c_ptr); + builder.arg(&batch_u32); + builder.arg(&m_u32); + builder.arg(&n_u32); + builder.arg(&k_u32); + builder.arg(&block_m); + builder.arg(&block_n); + builder.arg(&block_k); + builder.arg(&thread_m); + builder.arg(&thread_n); + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!( + "CUDA gemm_bias_residual_batched kernel launch failed: {:?}", + e + )) + })?; + } + + Ok(()) +} diff --git a/src/runtime/cuda/kernels/gemm_epilogue/mod.rs b/src/runtime/cuda/kernels/gemm_epilogue/mod.rs new file mode 100644 index 00000000..2c365362 --- /dev/null +++ b/src/runtime/cuda/kernels/gemm_epilogue/mod.rs @@ -0,0 +1,10 @@ +//! CUDA GEMM epilogue kernels and launchers. + +mod bwd_launcher; +mod launcher; + +pub use bwd_launcher::{launch_gemm_bias_act_bwd_batched_kernel, launch_gemm_bias_act_bwd_kernel}; +pub use launcher::{ + launch_gemm_bias_act_batched_kernel, launch_gemm_bias_act_kernel, + launch_gemm_bias_residual_batched_kernel, launch_gemm_bias_residual_kernel, +}; diff --git a/src/runtime/cuda/kernels/gemm_epilogue_bwd.cu b/src/runtime/cuda/kernels/gemm_epilogue_bwd.cu new file mode 100644 index 00000000..80e4a8e5 --- /dev/null +++ b/src/runtime/cuda/kernels/gemm_epilogue_bwd.cu @@ -0,0 +1,640 @@ +// Backward kernels for fused GEMM epilogue: activation(A @ B + bias) +// +// Kernels per dtype: +// 1. gemm_bias_act_bwd_grad_pre: grad_pre = grad * act'(A @ B + bias) +// 2. gemm_bwd_da: d_a = grad_pre @ B^T +// 3. gemm_bwd_db: d_b = A^T @ grad_pre (write) +// 4. gemm_bwd_db_accum: d_b += A^T @ grad_pre (accumulate for batched) +// 5. gemm_bwd_dbias: d_bias = sum(grad_pre, dim=0) (write) +// 6. gemm_bwd_dbias_accum: d_bias += sum(grad_pre, dim=0) (accumulate for batched) +// +// activation_type: 0=None, 1=ReLU, 2=GELU, 3=SiLU, 4=Sigmoid, 5=Tanh + +#include +#include +#include "dtype_traits.cuh" +#include "activation_deriv.cuh" + +extern "C" { + +// ============================================================================ +// F32 Backward Kernels +// ============================================================================ + +__global__ void gemm_bias_act_bwd_grad_pre_f32( + const float* __restrict__ grad, + const float* __restrict__ A, + const float* __restrict__ B, + const float* __restrict__ bias, + float* __restrict__ grad_pre, + unsigned int M, unsigned int N, unsigned int K, + unsigned int activation_type +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= M * N) return; + unsigned int i = idx / N; + unsigned int j = idx % N; + double pre_act = (double)bias[j]; + for (unsigned int kk = 0; kk < K; kk++) { + pre_act += (double)A[i * K + kk] * (double)B[kk * N + j]; + } + grad_pre[idx] = grad[idx] * (float)activation_deriv_f64(pre_act, activation_type); +} + +__global__ void gemm_bwd_da_f32( + const float* __restrict__ grad_pre, + const float* __restrict__ B, + float* __restrict__ d_a, + unsigned int M, unsigned int N, unsigned int K +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= M * K) return; + unsigned int i = idx / K; + unsigned int k = idx % K; + double sum = 0.0; + for (unsigned int j = 0; j < N; j++) { + sum += (double)grad_pre[i * N + j] * (double)B[k * N + j]; + } + d_a[idx] = (float)sum; +} + +__global__ void gemm_bwd_db_f32( + const float* __restrict__ A, + const float* __restrict__ grad_pre, + float* __restrict__ d_b, + unsigned int M, unsigned int N, unsigned int K +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= K * N) return; + unsigned int k = idx / N; + unsigned int j = idx % N; + double sum = 0.0; + for (unsigned int i = 0; i < M; i++) { + sum += (double)A[i * K + k] * (double)grad_pre[i * N + j]; + } + d_b[idx] = (float)sum; +} + +__global__ void gemm_bwd_db_accum_f32( + const float* __restrict__ A, + const float* __restrict__ grad_pre, + float* __restrict__ d_b, + unsigned int M, unsigned int N, unsigned int K +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= K * N) return; + unsigned int k = idx / N; + unsigned int j = idx % N; + double sum = 0.0; + for (unsigned int i = 0; i < M; i++) { + sum += (double)A[i * K + k] * (double)grad_pre[i * N + j]; + } + d_b[idx] += (float)sum; +} + +__global__ void gemm_bwd_dbias_f32( + const float* __restrict__ grad_pre, + float* __restrict__ d_bias, + unsigned int M, unsigned int N +) { + unsigned int j = blockIdx.x * blockDim.x + threadIdx.x; + if (j >= N) return; + float sum = 0.0f; + for (unsigned int i = 0; i < M; i++) { + sum += grad_pre[i * N + j]; + } + d_bias[j] = sum; +} + +__global__ void gemm_bwd_dbias_accum_f32( + const float* __restrict__ grad_pre, + float* __restrict__ d_bias, + unsigned int M, unsigned int N +) { + unsigned int j = blockIdx.x * blockDim.x + threadIdx.x; + if (j >= N) return; + float sum = 0.0f; + for (unsigned int i = 0; i < M; i++) { + sum += grad_pre[i * N + j]; + } + d_bias[j] += sum; +} + +// ============================================================================ +// F64 Backward Kernels +// ============================================================================ + +__global__ void gemm_bias_act_bwd_grad_pre_f64( + const double* __restrict__ grad, + const double* __restrict__ A, + const double* __restrict__ B, + const double* __restrict__ bias, + double* __restrict__ grad_pre, + unsigned int M, unsigned int N, unsigned int K, + unsigned int activation_type +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= M * N) return; + unsigned int i = idx / N; + unsigned int j = idx % N; + double pre_act = bias[j]; + for (unsigned int kk = 0; kk < K; kk++) { + pre_act += A[i * K + kk] * B[kk * N + j]; + } + grad_pre[idx] = grad[idx] * activation_deriv_f64(pre_act, activation_type); +} + +__global__ void gemm_bwd_da_f64( + const double* __restrict__ grad_pre, + const double* __restrict__ B, + double* __restrict__ d_a, + unsigned int M, unsigned int N, unsigned int K +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= M * K) return; + unsigned int i = idx / K; + unsigned int k = idx % K; + double sum = 0.0; + for (unsigned int j = 0; j < N; j++) { + sum += grad_pre[i * N + j] * B[k * N + j]; + } + d_a[idx] = sum; +} + +__global__ void gemm_bwd_db_f64( + const double* __restrict__ A, + const double* __restrict__ grad_pre, + double* __restrict__ d_b, + unsigned int M, unsigned int N, unsigned int K +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= K * N) return; + unsigned int k = idx / N; + unsigned int j = idx % N; + double sum = 0.0; + for (unsigned int i = 0; i < M; i++) { + sum += A[i * K + k] * grad_pre[i * N + j]; + } + d_b[idx] = sum; +} + +__global__ void gemm_bwd_db_accum_f64( + const double* __restrict__ A, + const double* __restrict__ grad_pre, + double* __restrict__ d_b, + unsigned int M, unsigned int N, unsigned int K +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= K * N) return; + unsigned int k = idx / N; + unsigned int j = idx % N; + double sum = 0.0; + for (unsigned int i = 0; i < M; i++) { + sum += A[i * K + k] * grad_pre[i * N + j]; + } + d_b[idx] += sum; +} + +__global__ void gemm_bwd_dbias_f64( + const double* __restrict__ grad_pre, + double* __restrict__ d_bias, + unsigned int M, unsigned int N +) { + unsigned int j = blockIdx.x * blockDim.x + threadIdx.x; + if (j >= N) return; + double sum = 0.0; + for (unsigned int i = 0; i < M; i++) { + sum += grad_pre[i * N + j]; + } + d_bias[j] = sum; +} + +__global__ void gemm_bwd_dbias_accum_f64( + const double* __restrict__ grad_pre, + double* __restrict__ d_bias, + unsigned int M, unsigned int N +) { + unsigned int j = blockIdx.x * blockDim.x + threadIdx.x; + if (j >= N) return; + double sum = 0.0; + for (unsigned int i = 0; i < M; i++) { + sum += grad_pre[i * N + j]; + } + d_bias[j] += sum; +} + +// ============================================================================ +// F16 Backward Kernels (compute in F32) +// ============================================================================ + +__global__ void gemm_bias_act_bwd_grad_pre_f16( + const __half* __restrict__ grad, + const __half* __restrict__ A, + const __half* __restrict__ B, + const __half* __restrict__ bias, + __half* __restrict__ grad_pre, + unsigned int M, unsigned int N, unsigned int K, + unsigned int activation_type +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= M * N) return; + unsigned int i = idx / N; + unsigned int j = idx % N; + float pre_act = __half2float(bias[j]); + for (unsigned int kk = 0; kk < K; kk++) { + pre_act += __half2float(A[i * K + kk]) * __half2float(B[kk * N + j]); + } + grad_pre[idx] = __float2half(__half2float(grad[idx]) * activation_deriv_f32(pre_act, activation_type)); +} + +__global__ void gemm_bwd_da_f16( + const __half* __restrict__ grad_pre, + const __half* __restrict__ B, + __half* __restrict__ d_a, + unsigned int M, unsigned int N, unsigned int K +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= M * K) return; + unsigned int i = idx / K; + unsigned int k = idx % K; + float sum = 0.0f; + for (unsigned int j = 0; j < N; j++) { + sum += __half2float(grad_pre[i * N + j]) * __half2float(B[k * N + j]); + } + d_a[idx] = __float2half(sum); +} + +__global__ void gemm_bwd_db_f16( + const __half* __restrict__ A, + const __half* __restrict__ grad_pre, + __half* __restrict__ d_b, + unsigned int M, unsigned int N, unsigned int K +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= K * N) return; + unsigned int k = idx / N; + unsigned int j = idx % N; + float sum = 0.0f; + for (unsigned int i = 0; i < M; i++) { + sum += __half2float(A[i * K + k]) * __half2float(grad_pre[i * N + j]); + } + d_b[idx] = __float2half(sum); +} + +__global__ void gemm_bwd_db_accum_f16( + const __half* __restrict__ A, + const __half* __restrict__ grad_pre, + __half* __restrict__ d_b, + unsigned int M, unsigned int N, unsigned int K +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= K * N) return; + unsigned int k = idx / N; + unsigned int j = idx % N; + float sum = 0.0f; + for (unsigned int i = 0; i < M; i++) { + sum += __half2float(A[i * K + k]) * __half2float(grad_pre[i * N + j]); + } + d_b[idx] = __float2half(__half2float(d_b[idx]) + sum); +} + +__global__ void gemm_bwd_dbias_f16( + const __half* __restrict__ grad_pre, + __half* __restrict__ d_bias, + unsigned int M, unsigned int N +) { + unsigned int j = blockIdx.x * blockDim.x + threadIdx.x; + if (j >= N) return; + float sum = 0.0f; + for (unsigned int i = 0; i < M; i++) { + sum += __half2float(grad_pre[i * N + j]); + } + d_bias[j] = __float2half(sum); +} + +__global__ void gemm_bwd_dbias_accum_f16( + const __half* __restrict__ grad_pre, + __half* __restrict__ d_bias, + unsigned int M, unsigned int N +) { + unsigned int j = blockIdx.x * blockDim.x + threadIdx.x; + if (j >= N) return; + float sum = 0.0f; + for (unsigned int i = 0; i < M; i++) { + sum += __half2float(grad_pre[i * N + j]); + } + d_bias[j] = __float2half(__half2float(d_bias[j]) + sum); +} + +// ============================================================================ +// BF16 Backward Kernels (compute in F32) +// ============================================================================ + +__global__ void gemm_bias_act_bwd_grad_pre_bf16( + const __nv_bfloat16* __restrict__ grad, + const __nv_bfloat16* __restrict__ A, + const __nv_bfloat16* __restrict__ B, + const __nv_bfloat16* __restrict__ bias, + __nv_bfloat16* __restrict__ grad_pre, + unsigned int M, unsigned int N, unsigned int K, + unsigned int activation_type +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= M * N) return; + unsigned int i = idx / N; + unsigned int j = idx % N; + float pre_act = __bfloat162float(bias[j]); + for (unsigned int kk = 0; kk < K; kk++) { + pre_act += __bfloat162float(A[i * K + kk]) * __bfloat162float(B[kk * N + j]); + } + grad_pre[idx] = __float2bfloat16(__bfloat162float(grad[idx]) * activation_deriv_f32(pre_act, activation_type)); +} + +__global__ void gemm_bwd_da_bf16( + const __nv_bfloat16* __restrict__ grad_pre, + const __nv_bfloat16* __restrict__ B, + __nv_bfloat16* __restrict__ d_a, + unsigned int M, unsigned int N, unsigned int K +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= M * K) return; + unsigned int i = idx / K; + unsigned int k = idx % K; + float sum = 0.0f; + for (unsigned int j = 0; j < N; j++) { + sum += __bfloat162float(grad_pre[i * N + j]) * __bfloat162float(B[k * N + j]); + } + d_a[idx] = __float2bfloat16(sum); +} + +__global__ void gemm_bwd_db_bf16( + const __nv_bfloat16* __restrict__ A, + const __nv_bfloat16* __restrict__ grad_pre, + __nv_bfloat16* __restrict__ d_b, + unsigned int M, unsigned int N, unsigned int K +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= K * N) return; + unsigned int k = idx / N; + unsigned int j = idx % N; + float sum = 0.0f; + for (unsigned int i = 0; i < M; i++) { + sum += __bfloat162float(A[i * K + k]) * __bfloat162float(grad_pre[i * N + j]); + } + d_b[idx] = __float2bfloat16(sum); +} + +__global__ void gemm_bwd_db_accum_bf16( + const __nv_bfloat16* __restrict__ A, + const __nv_bfloat16* __restrict__ grad_pre, + __nv_bfloat16* __restrict__ d_b, + unsigned int M, unsigned int N, unsigned int K +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= K * N) return; + unsigned int k = idx / N; + unsigned int j = idx % N; + float sum = 0.0f; + for (unsigned int i = 0; i < M; i++) { + sum += __bfloat162float(A[i * K + k]) * __bfloat162float(grad_pre[i * N + j]); + } + d_b[idx] = __float2bfloat16(__bfloat162float(d_b[idx]) + sum); +} + +__global__ void gemm_bwd_dbias_bf16( + const __nv_bfloat16* __restrict__ grad_pre, + __nv_bfloat16* __restrict__ d_bias, + unsigned int M, unsigned int N +) { + unsigned int j = blockIdx.x * blockDim.x + threadIdx.x; + if (j >= N) return; + float sum = 0.0f; + for (unsigned int i = 0; i < M; i++) { + sum += __bfloat162float(grad_pre[i * N + j]); + } + d_bias[j] = __float2bfloat16(sum); +} + +__global__ void gemm_bwd_dbias_accum_bf16( + const __nv_bfloat16* __restrict__ grad_pre, + __nv_bfloat16* __restrict__ d_bias, + unsigned int M, unsigned int N +) { + unsigned int j = blockIdx.x * blockDim.x + threadIdx.x; + if (j >= N) return; + float sum = 0.0f; + for (unsigned int i = 0; i < M; i++) { + sum += __bfloat162float(grad_pre[i * N + j]); + } + d_bias[j] = __float2bfloat16(__bfloat162float(d_bias[j]) + sum); +} + +// ============================================================================ +// FP8 E4M3 Backward Kernels (compute in F32) +// ============================================================================ + +__global__ void gemm_bias_act_bwd_grad_pre_fp8_e4m3( + const numr_fp8_e4m3* __restrict__ grad, + const numr_fp8_e4m3* __restrict__ A, + const numr_fp8_e4m3* __restrict__ B, + const numr_fp8_e4m3* __restrict__ bias, + numr_fp8_e4m3* __restrict__ grad_pre, + unsigned int M, unsigned int N, unsigned int K, + unsigned int activation_type +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= M * N) return; + unsigned int i = idx / N; + unsigned int j = idx % N; + float pre_act = fp8_e4m3_to_f32(bias[j].data); + for (unsigned int kk = 0; kk < K; kk++) { + pre_act += fp8_e4m3_to_f32(A[i * K + kk].data) * fp8_e4m3_to_f32(B[kk * N + j].data); + } + float g = fp8_e4m3_to_f32(grad[idx].data); + grad_pre[idx] = numr_fp8_e4m3(f32_to_fp8_e4m3(g * activation_deriv_f32(pre_act, activation_type))); +} + +__global__ void gemm_bwd_da_fp8_e4m3( + const numr_fp8_e4m3* __restrict__ grad_pre, + const numr_fp8_e4m3* __restrict__ B, + numr_fp8_e4m3* __restrict__ d_a, + unsigned int M, unsigned int N, unsigned int K +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= M * K) return; + unsigned int i = idx / K; + unsigned int k = idx % K; + float sum = 0.0f; + for (unsigned int j = 0; j < N; j++) { + sum += fp8_e4m3_to_f32(grad_pre[i * N + j].data) * fp8_e4m3_to_f32(B[k * N + j].data); + } + d_a[idx] = numr_fp8_e4m3(f32_to_fp8_e4m3(sum)); +} + +__global__ void gemm_bwd_db_fp8_e4m3( + const numr_fp8_e4m3* __restrict__ A, + const numr_fp8_e4m3* __restrict__ grad_pre, + numr_fp8_e4m3* __restrict__ d_b, + unsigned int M, unsigned int N, unsigned int K +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= K * N) return; + unsigned int k = idx / N; + unsigned int j = idx % N; + float sum = 0.0f; + for (unsigned int i = 0; i < M; i++) { + sum += fp8_e4m3_to_f32(A[i * K + k].data) * fp8_e4m3_to_f32(grad_pre[i * N + j].data); + } + d_b[idx] = numr_fp8_e4m3(f32_to_fp8_e4m3(sum)); +} + +__global__ void gemm_bwd_db_accum_fp8_e4m3( + const numr_fp8_e4m3* __restrict__ A, + const numr_fp8_e4m3* __restrict__ grad_pre, + numr_fp8_e4m3* __restrict__ d_b, + unsigned int M, unsigned int N, unsigned int K +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= K * N) return; + unsigned int k = idx / N; + unsigned int j = idx % N; + float sum = 0.0f; + for (unsigned int i = 0; i < M; i++) { + sum += fp8_e4m3_to_f32(A[i * K + k].data) * fp8_e4m3_to_f32(grad_pre[i * N + j].data); + } + d_b[idx] = numr_fp8_e4m3(f32_to_fp8_e4m3(fp8_e4m3_to_f32(d_b[idx].data) + sum)); +} + +__global__ void gemm_bwd_dbias_fp8_e4m3( + const numr_fp8_e4m3* __restrict__ grad_pre, + numr_fp8_e4m3* __restrict__ d_bias, + unsigned int M, unsigned int N +) { + unsigned int j = blockIdx.x * blockDim.x + threadIdx.x; + if (j >= N) return; + float sum = 0.0f; + for (unsigned int i = 0; i < M; i++) { + sum += fp8_e4m3_to_f32(grad_pre[i * N + j].data); + } + d_bias[j] = numr_fp8_e4m3(f32_to_fp8_e4m3(sum)); +} + +__global__ void gemm_bwd_dbias_accum_fp8_e4m3( + const numr_fp8_e4m3* __restrict__ grad_pre, + numr_fp8_e4m3* __restrict__ d_bias, + unsigned int M, unsigned int N +) { + unsigned int j = blockIdx.x * blockDim.x + threadIdx.x; + if (j >= N) return; + float sum = 0.0f; + for (unsigned int i = 0; i < M; i++) { + sum += fp8_e4m3_to_f32(grad_pre[i * N + j].data); + } + d_bias[j] = numr_fp8_e4m3(f32_to_fp8_e4m3(fp8_e4m3_to_f32(d_bias[j].data) + sum)); +} + +// ============================================================================ +// FP8 E5M2 Backward Kernels (compute in F32) +// ============================================================================ + +__global__ void gemm_bias_act_bwd_grad_pre_fp8_e5m2( + const numr_fp8_e5m2* __restrict__ grad, + const numr_fp8_e5m2* __restrict__ A, + const numr_fp8_e5m2* __restrict__ B, + const numr_fp8_e5m2* __restrict__ bias, + numr_fp8_e5m2* __restrict__ grad_pre, + unsigned int M, unsigned int N, unsigned int K, + unsigned int activation_type +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= M * N) return; + unsigned int i = idx / N; + unsigned int j = idx % N; + float pre_act = fp8_e5m2_to_f32(bias[j].data); + for (unsigned int kk = 0; kk < K; kk++) { + pre_act += fp8_e5m2_to_f32(A[i * K + kk].data) * fp8_e5m2_to_f32(B[kk * N + j].data); + } + float g = fp8_e5m2_to_f32(grad[idx].data); + grad_pre[idx] = numr_fp8_e5m2(f32_to_fp8_e5m2(g * activation_deriv_f32(pre_act, activation_type))); +} + +__global__ void gemm_bwd_da_fp8_e5m2( + const numr_fp8_e5m2* __restrict__ grad_pre, + const numr_fp8_e5m2* __restrict__ B, + numr_fp8_e5m2* __restrict__ d_a, + unsigned int M, unsigned int N, unsigned int K +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= M * K) return; + unsigned int i = idx / K; + unsigned int k = idx % K; + float sum = 0.0f; + for (unsigned int j = 0; j < N; j++) { + sum += fp8_e5m2_to_f32(grad_pre[i * N + j].data) * fp8_e5m2_to_f32(B[k * N + j].data); + } + d_a[idx] = numr_fp8_e5m2(f32_to_fp8_e5m2(sum)); +} + +__global__ void gemm_bwd_db_fp8_e5m2( + const numr_fp8_e5m2* __restrict__ A, + const numr_fp8_e5m2* __restrict__ grad_pre, + numr_fp8_e5m2* __restrict__ d_b, + unsigned int M, unsigned int N, unsigned int K +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= K * N) return; + unsigned int k = idx / N; + unsigned int j = idx % N; + float sum = 0.0f; + for (unsigned int i = 0; i < M; i++) { + sum += fp8_e5m2_to_f32(A[i * K + k].data) * fp8_e5m2_to_f32(grad_pre[i * N + j].data); + } + d_b[idx] = numr_fp8_e5m2(f32_to_fp8_e5m2(sum)); +} + +__global__ void gemm_bwd_db_accum_fp8_e5m2( + const numr_fp8_e5m2* __restrict__ A, + const numr_fp8_e5m2* __restrict__ grad_pre, + numr_fp8_e5m2* __restrict__ d_b, + unsigned int M, unsigned int N, unsigned int K +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= K * N) return; + unsigned int k = idx / N; + unsigned int j = idx % N; + float sum = 0.0f; + for (unsigned int i = 0; i < M; i++) { + sum += fp8_e5m2_to_f32(A[i * K + k].data) * fp8_e5m2_to_f32(grad_pre[i * N + j].data); + } + d_b[idx] = numr_fp8_e5m2(f32_to_fp8_e5m2(fp8_e5m2_to_f32(d_b[idx].data) + sum)); +} + +__global__ void gemm_bwd_dbias_fp8_e5m2( + const numr_fp8_e5m2* __restrict__ grad_pre, + numr_fp8_e5m2* __restrict__ d_bias, + unsigned int M, unsigned int N +) { + unsigned int j = blockIdx.x * blockDim.x + threadIdx.x; + if (j >= N) return; + float sum = 0.0f; + for (unsigned int i = 0; i < M; i++) { + sum += fp8_e5m2_to_f32(grad_pre[i * N + j].data); + } + d_bias[j] = numr_fp8_e5m2(f32_to_fp8_e5m2(sum)); +} + +__global__ void gemm_bwd_dbias_accum_fp8_e5m2( + const numr_fp8_e5m2* __restrict__ grad_pre, + numr_fp8_e5m2* __restrict__ d_bias, + unsigned int M, unsigned int N +) { + unsigned int j = blockIdx.x * blockDim.x + threadIdx.x; + if (j >= N) return; + float sum = 0.0f; + for (unsigned int i = 0; i < M; i++) { + sum += fp8_e5m2_to_f32(grad_pre[i * N + j].data); + } + d_bias[j] = numr_fp8_e5m2(f32_to_fp8_e5m2(fp8_e5m2_to_f32(d_bias[j].data) + sum)); +} + +} // extern "C" diff --git a/src/runtime/cuda/kernels/gemv.cu b/src/runtime/cuda/kernels/gemv.cu new file mode 100644 index 00000000..e046f51d --- /dev/null +++ b/src/runtime/cuda/kernels/gemv.cu @@ -0,0 +1,574 @@ +// GEMV (General Matrix-Vector Multiply) CUDA Kernels +// C[M,N] = A[M,K] @ B[K,N] for small M (M <= 16, typically M=1 for LLM decode) +// +// Two kernel families: +// +// 1. gemv_* : B is [K,N] row-major (non-transposed) +// - One thread per output column, iterates K +// - Coalesced B reads: consecutive threads read B[k*N + col], B[k*N + col+1] +// - Grid: (ceil(N/256), M, batch), block: (256, 1, 1) +// +// 2. gemv_bt_* : B is [N,K] row-major (transposed weight, the common case for nn.Linear) +// - Warp-cooperative: each warp reduces one output column along K +// - Coalesced B reads: lanes read B[col*K + lane], B[col*K + lane+1] (stride-1) +// - Grid: (ceil(N/WARPS_PER_BLOCK), M, batch), block: (256, 1, 1) +// +// The bt (B-transposed) variant avoids a 500MB contiguous copy when Linear +// computes y = x @ W^T by passing the raw [N,K] weight pointer directly. + +#include +#include + +// ============================================================================ +// Non-transposed B: one thread per output, iterate K +// B layout: [K, N] row-major — B[k,n] = B_data[k*N + n] +// ============================================================================ + +extern "C" __global__ void gemv_bf16( + const __nv_bfloat16* __restrict__ A, + const __nv_bfloat16* __restrict__ B, + __nv_bfloat16* __restrict__ C, + unsigned int M, + unsigned int N, + unsigned int K, + unsigned int a_batch_count, + unsigned int b_batch_count +) { + const unsigned int col = blockIdx.x * blockDim.x + threadIdx.x; + const unsigned int m = blockIdx.y; + const unsigned int batch = blockIdx.z; + if (col >= N) return; + + const __nv_bfloat16* a_row = A + (batch % a_batch_count) * M * K + m * K; + const __nv_bfloat16* b_base = B + (batch % b_batch_count) * K * N; + + float acc = 0.0f; + for (unsigned int k = 0; k < K; k++) { + acc += __bfloat162float(a_row[k]) * __bfloat162float(b_base[k * N + col]); + } + + C[batch * M * N + m * N + col] = __float2bfloat16(acc); +} + +extern "C" __global__ void gemv_f32( + const float* __restrict__ A, + const float* __restrict__ B, + float* __restrict__ C, + unsigned int M, + unsigned int N, + unsigned int K, + unsigned int a_batch_count, + unsigned int b_batch_count +) { + const unsigned int col = blockIdx.x * blockDim.x + threadIdx.x; + const unsigned int m = blockIdx.y; + const unsigned int batch = blockIdx.z; + if (col >= N) return; + + const float* a_row = A + (batch % a_batch_count) * M * K + m * K; + const float* b_base = B + (batch % b_batch_count) * K * N; + + float acc = 0.0f; + for (unsigned int k = 0; k < K; k++) { + acc += a_row[k] * b_base[k * N + col]; + } + + C[batch * M * N + m * N + col] = acc; +} + +extern "C" __global__ void gemv_f16( + const half* __restrict__ A, + const half* __restrict__ B, + half* __restrict__ C, + unsigned int M, + unsigned int N, + unsigned int K, + unsigned int a_batch_count, + unsigned int b_batch_count +) { + const unsigned int col = blockIdx.x * blockDim.x + threadIdx.x; + const unsigned int m = blockIdx.y; + const unsigned int batch = blockIdx.z; + if (col >= N) return; + + const half* a_row = A + (batch % a_batch_count) * M * K + m * K; + const half* b_base = B + (batch % b_batch_count) * K * N; + + float acc = 0.0f; + for (unsigned int k = 0; k < K; k++) { + acc += __half2float(a_row[k]) * __half2float(b_base[k * N + col]); + } + + C[batch * M * N + m * N + col] = __float2half(acc); +} + +extern "C" __global__ void gemv_f64( + const double* __restrict__ A, + const double* __restrict__ B, + double* __restrict__ C, + unsigned int M, + unsigned int N, + unsigned int K, + unsigned int a_batch_count, + unsigned int b_batch_count +) { + const unsigned int col = blockIdx.x * blockDim.x + threadIdx.x; + const unsigned int m = blockIdx.y; + const unsigned int batch = blockIdx.z; + if (col >= N) return; + + const double* a_row = A + (batch % a_batch_count) * M * K + m * K; + const double* b_base = B + (batch % b_batch_count) * K * N; + + double acc = 0.0; + for (unsigned int k = 0; k < K; k++) { + acc += a_row[k] * b_base[k * N + col]; + } + + C[batch * M * N + m * N + col] = acc; +} + +// ============================================================================ +// Transposed B: warp-cooperative K-reduction +// B layout: [N, K] row-major (weight matrix) — B_logical[k,n] = B_data[n*K + k] +// +// Each warp handles one output column. Lanes cooperatively reduce along K. +// B_data[col*K + lane_id] reads are stride-1 (coalesced within each warp). +// ============================================================================ + +#define WARP_SIZE 32 +#define WARPS_PER_BLOCK 8 +#define BLOCK_SIZE (WARP_SIZE * WARPS_PER_BLOCK) + +extern "C" __global__ void gemv_bt_bf16( + const __nv_bfloat16* __restrict__ A, + const __nv_bfloat16* __restrict__ B, // stored [N, K] row-major + __nv_bfloat16* __restrict__ C, + unsigned int M, + unsigned int N, + unsigned int K, + unsigned int a_batch_count, + unsigned int b_batch_count +) { + const unsigned int warp_id = threadIdx.x / WARP_SIZE; + const unsigned int lane_id = threadIdx.x % WARP_SIZE; + const unsigned int col = blockIdx.x * WARPS_PER_BLOCK + warp_id; + const unsigned int m = blockIdx.y; + const unsigned int batch = blockIdx.z; + if (col >= N) return; + + const __nv_bfloat16* a_row = A + (batch % a_batch_count) * M * K + m * K; + const __nv_bfloat16* b_row = B + (batch % b_batch_count) * N * K + col * K; // B[col, 0..K] + + float acc = 0.0f; + for (unsigned int k = lane_id; k < K; k += WARP_SIZE) { + acc += __bfloat162float(a_row[k]) * __bfloat162float(b_row[k]); + } + + // Warp-level reduction + #pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) { + acc += __shfl_down_sync(0xFFFFFFFF, acc, offset); + } + + if (lane_id == 0) { + C[batch * M * N + m * N + col] = __float2bfloat16(acc); + } +} + +extern "C" __global__ void gemv_bt_f32( + const float* __restrict__ A, + const float* __restrict__ B, + float* __restrict__ C, + unsigned int M, + unsigned int N, + unsigned int K, + unsigned int a_batch_count, + unsigned int b_batch_count +) { + const unsigned int warp_id = threadIdx.x / WARP_SIZE; + const unsigned int lane_id = threadIdx.x % WARP_SIZE; + const unsigned int col = blockIdx.x * WARPS_PER_BLOCK + warp_id; + const unsigned int m = blockIdx.y; + const unsigned int batch = blockIdx.z; + if (col >= N) return; + + const float* a_row = A + (batch % a_batch_count) * M * K + m * K; + const float* b_row = B + (batch % b_batch_count) * N * K + col * K; + + float acc = 0.0f; + for (unsigned int k = lane_id; k < K; k += WARP_SIZE) { + acc += a_row[k] * b_row[k]; + } + + #pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) { + acc += __shfl_down_sync(0xFFFFFFFF, acc, offset); + } + + if (lane_id == 0) { + C[batch * M * N + m * N + col] = acc; + } +} + +extern "C" __global__ void gemv_bt_f16( + const half* __restrict__ A, + const half* __restrict__ B, + half* __restrict__ C, + unsigned int M, + unsigned int N, + unsigned int K, + unsigned int a_batch_count, + unsigned int b_batch_count +) { + const unsigned int warp_id = threadIdx.x / WARP_SIZE; + const unsigned int lane_id = threadIdx.x % WARP_SIZE; + const unsigned int col = blockIdx.x * WARPS_PER_BLOCK + warp_id; + const unsigned int m = blockIdx.y; + const unsigned int batch = blockIdx.z; + if (col >= N) return; + + const half* a_row = A + (batch % a_batch_count) * M * K + m * K; + const half* b_row = B + (batch % b_batch_count) * N * K + col * K; + + float acc = 0.0f; + for (unsigned int k = lane_id; k < K; k += WARP_SIZE) { + acc += __half2float(a_row[k]) * __half2float(b_row[k]); + } + + #pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) { + acc += __shfl_down_sync(0xFFFFFFFF, acc, offset); + } + + if (lane_id == 0) { + C[batch * M * N + m * N + col] = __float2half(acc); + } +} + +// ============================================================================ +// Multi-Row Transposed B with Vectorized Loads +// +// Each warp computes ROWS_PER_WARP output columns. Activation vector loaded +// once, reused across rows. Vectorized loads (float4 = 16 bytes per load) +// saturate memory bus — 8x fewer transactions for bf16/f16, 4x for f32. +// +// Runtime alignment check: if K is divisible by VEC elements AND pointers are +// 16-byte aligned, use float4 loads. Otherwise fall back to scalar. +// ============================================================================ + +#define ROWS_PER_WARP 2 + +// Helper: check if a pointer is aligned to N bytes +#define IS_ALIGNED(ptr, n) (((unsigned long long)(ptr)) % (n) == 0) + +// --- BF16: float4 = 8 bf16 values per load --- + +extern "C" __global__ void gemv_bt_mr_bf16( + const __nv_bfloat16* __restrict__ A, + const __nv_bfloat16* __restrict__ B, + __nv_bfloat16* __restrict__ C, + unsigned int M, + unsigned int N, + unsigned int K, + unsigned int a_batch_count, + unsigned int b_batch_count +) { + const unsigned int warp_id = threadIdx.x / WARP_SIZE; + const unsigned int lane_id = threadIdx.x % WARP_SIZE; + const unsigned int col_base = (blockIdx.x * WARPS_PER_BLOCK + warp_id) * ROWS_PER_WARP; + const unsigned int m = blockIdx.y; + const unsigned int batch = blockIdx.z; + const unsigned int a_batch = batch % a_batch_count; + const unsigned int b_batch = batch % b_batch_count; + + const __nv_bfloat16* a_row = A + a_batch * M * K + m * K; + + float acc[ROWS_PER_WARP] = {0.0f, 0.0f}; + + // float4 = 16 bytes = 8 bf16. Use vectorized path if K is multiple of 8 + // and both A and B rows are 16-byte aligned. + const unsigned int VEC = 8; + const bool can_vec = (K % VEC == 0) && IS_ALIGNED(a_row, 16); + + if (can_vec) { + const unsigned int K_vec = K / VEC; + const float4* a_vec = reinterpret_cast(a_row); + + for (unsigned int vi = lane_id; vi < K_vec; vi += WARP_SIZE) { + float4 av = a_vec[vi]; + const __nv_bfloat16* a8 = reinterpret_cast(&av); + + #pragma unroll + for (int r = 0; r < ROWS_PER_WARP; r++) { + if (col_base + r < N) { + const float4* b_vec = reinterpret_cast( + B + b_batch * N * K + (col_base + r) * K); + float4 bv = b_vec[vi]; + const __nv_bfloat16* b8 = reinterpret_cast(&bv); + + #pragma unroll + for (int j = 0; j < 8; j++) { + acc[r] += __bfloat162float(a8[j]) * __bfloat162float(b8[j]); + } + } + } + } + } else { + for (unsigned int k = lane_id; k < K; k += WARP_SIZE) { + float a_val = __bfloat162float(a_row[k]); + #pragma unroll + for (int r = 0; r < ROWS_PER_WARP; r++) { + if (col_base + r < N) { + acc[r] += a_val * __bfloat162float( + B[b_batch * N * K + (col_base + r) * K + k]); + } + } + } + } + + #pragma unroll + for (int r = 0; r < ROWS_PER_WARP; r++) { + for (int off = WARP_SIZE / 2; off > 0; off >>= 1) + acc[r] += __shfl_down_sync(0xFFFFFFFF, acc[r], off); + if (lane_id == 0 && col_base + r < N) + C[batch * M * N + m * N + col_base + r] = __float2bfloat16(acc[r]); + } +} + +// --- F32: float4 = 4 f32 values per load --- + +extern "C" __global__ void gemv_bt_mr_f32( + const float* __restrict__ A, + const float* __restrict__ B, + float* __restrict__ C, + unsigned int M, + unsigned int N, + unsigned int K, + unsigned int a_batch_count, + unsigned int b_batch_count +) { + const unsigned int warp_id = threadIdx.x / WARP_SIZE; + const unsigned int lane_id = threadIdx.x % WARP_SIZE; + const unsigned int col_base = (blockIdx.x * WARPS_PER_BLOCK + warp_id) * ROWS_PER_WARP; + const unsigned int m = blockIdx.y; + const unsigned int batch = blockIdx.z; + const unsigned int a_batch = batch % a_batch_count; + const unsigned int b_batch = batch % b_batch_count; + + const float* a_row = A + a_batch * M * K + m * K; + + float acc[ROWS_PER_WARP] = {0.0f, 0.0f}; + + const unsigned int VEC = 4; + const bool can_vec = (K % VEC == 0) && IS_ALIGNED(a_row, 16); + + if (can_vec) { + const unsigned int K_vec = K / VEC; + const float4* a_vec = reinterpret_cast(a_row); + + for (unsigned int vi = lane_id; vi < K_vec; vi += WARP_SIZE) { + float4 av = a_vec[vi]; + + #pragma unroll + for (int r = 0; r < ROWS_PER_WARP; r++) { + if (col_base + r < N) { + const float4* b_vec = reinterpret_cast( + B + b_batch * N * K + (col_base + r) * K); + float4 bv = b_vec[vi]; + acc[r] += av.x * bv.x + av.y * bv.y + av.z * bv.z + av.w * bv.w; + } + } + } + } else { + for (unsigned int k = lane_id; k < K; k += WARP_SIZE) { + float a_val = a_row[k]; + #pragma unroll + for (int r = 0; r < ROWS_PER_WARP; r++) { + if (col_base + r < N) { + acc[r] += a_val * B[b_batch * N * K + (col_base + r) * K + k]; + } + } + } + } + + #pragma unroll + for (int r = 0; r < ROWS_PER_WARP; r++) { + for (int off = WARP_SIZE / 2; off > 0; off >>= 1) + acc[r] += __shfl_down_sync(0xFFFFFFFF, acc[r], off); + if (lane_id == 0 && col_base + r < N) + C[batch * M * N + m * N + col_base + r] = acc[r]; + } +} + +// --- F16: float4 = 8 half values per load --- + +extern "C" __global__ void gemv_bt_mr_f16( + const half* __restrict__ A, + const half* __restrict__ B, + half* __restrict__ C, + unsigned int M, + unsigned int N, + unsigned int K, + unsigned int a_batch_count, + unsigned int b_batch_count +) { + const unsigned int warp_id = threadIdx.x / WARP_SIZE; + const unsigned int lane_id = threadIdx.x % WARP_SIZE; + const unsigned int col_base = (blockIdx.x * WARPS_PER_BLOCK + warp_id) * ROWS_PER_WARP; + const unsigned int m = blockIdx.y; + const unsigned int batch = blockIdx.z; + const unsigned int a_batch = batch % a_batch_count; + const unsigned int b_batch = batch % b_batch_count; + + const half* a_row = A + a_batch * M * K + m * K; + + float acc[ROWS_PER_WARP] = {0.0f, 0.0f}; + + const unsigned int VEC = 8; + const bool can_vec = (K % VEC == 0) && IS_ALIGNED(a_row, 16); + + if (can_vec) { + const unsigned int K_vec = K / VEC; + const float4* a_vec = reinterpret_cast(a_row); + + for (unsigned int vi = lane_id; vi < K_vec; vi += WARP_SIZE) { + float4 av = a_vec[vi]; + const half* a8 = reinterpret_cast(&av); + + #pragma unroll + for (int r = 0; r < ROWS_PER_WARP; r++) { + if (col_base + r < N) { + const float4* b_vec = reinterpret_cast( + B + b_batch * N * K + (col_base + r) * K); + float4 bv = b_vec[vi]; + const half* b8 = reinterpret_cast(&bv); + + #pragma unroll + for (int j = 0; j < 8; j++) { + acc[r] += __half2float(a8[j]) * __half2float(b8[j]); + } + } + } + } + } else { + for (unsigned int k = lane_id; k < K; k += WARP_SIZE) { + float a_val = __half2float(a_row[k]); + #pragma unroll + for (int r = 0; r < ROWS_PER_WARP; r++) { + if (col_base + r < N) { + acc[r] += a_val * __half2float( + B[b_batch * N * K + (col_base + r) * K + k]); + } + } + } + } + + #pragma unroll + for (int r = 0; r < ROWS_PER_WARP; r++) { + for (int off = WARP_SIZE / 2; off > 0; off >>= 1) + acc[r] += __shfl_down_sync(0xFFFFFFFF, acc[r], off); + if (lane_id == 0 && col_base + r < N) + C[batch * M * N + m * N + col_base + r] = __float2half(acc[r]); + } +} + +// --- F64: double2 = 2 f64 values per load --- + +extern "C" __global__ void gemv_bt_mr_f64( + const double* __restrict__ A, + const double* __restrict__ B, + double* __restrict__ C, + unsigned int M, + unsigned int N, + unsigned int K, + unsigned int a_batch_count, + unsigned int b_batch_count +) { + const unsigned int warp_id = threadIdx.x / WARP_SIZE; + const unsigned int lane_id = threadIdx.x % WARP_SIZE; + const unsigned int col_base = (blockIdx.x * WARPS_PER_BLOCK + warp_id) * ROWS_PER_WARP; + const unsigned int m = blockIdx.y; + const unsigned int batch = blockIdx.z; + const unsigned int a_batch = batch % a_batch_count; + const unsigned int b_batch = batch % b_batch_count; + + const double* a_row = A + a_batch * M * K + m * K; + + double acc[ROWS_PER_WARP] = {0.0, 0.0}; + + const unsigned int VEC = 2; + const bool can_vec = (K % VEC == 0) && IS_ALIGNED(a_row, 16); + + if (can_vec) { + const unsigned int K_vec = K / VEC; + const double2* a_vec = reinterpret_cast(a_row); + + for (unsigned int vi = lane_id; vi < K_vec; vi += WARP_SIZE) { + double2 av = a_vec[vi]; + + #pragma unroll + for (int r = 0; r < ROWS_PER_WARP; r++) { + if (col_base + r < N) { + const double2* b_vec = reinterpret_cast( + B + b_batch * N * K + (col_base + r) * K); + double2 bv = b_vec[vi]; + acc[r] += av.x * bv.x + av.y * bv.y; + } + } + } + } else { + for (unsigned int k = lane_id; k < K; k += WARP_SIZE) { + double a_val = a_row[k]; + #pragma unroll + for (int r = 0; r < ROWS_PER_WARP; r++) { + if (col_base + r < N) { + acc[r] += a_val * B[b_batch * N * K + (col_base + r) * K + k]; + } + } + } + } + + #pragma unroll + for (int r = 0; r < ROWS_PER_WARP; r++) { + for (int off = WARP_SIZE / 2; off > 0; off >>= 1) + acc[r] += __shfl_down_sync(0xFFFFFFFF, acc[r], off); + if (lane_id == 0 && col_base + r < N) + C[batch * M * N + m * N + col_base + r] = acc[r]; + } +} + +extern "C" __global__ void gemv_bt_f64( + const double* __restrict__ A, + const double* __restrict__ B, + double* __restrict__ C, + unsigned int M, + unsigned int N, + unsigned int K, + unsigned int a_batch_count, + unsigned int b_batch_count +) { + const unsigned int warp_id = threadIdx.x / WARP_SIZE; + const unsigned int lane_id = threadIdx.x % WARP_SIZE; + const unsigned int col = blockIdx.x * WARPS_PER_BLOCK + warp_id; + const unsigned int m = blockIdx.y; + const unsigned int batch = blockIdx.z; + if (col >= N) return; + + const double* a_row = A + (batch % a_batch_count) * M * K + m * K; + const double* b_row = B + (batch % b_batch_count) * N * K + col * K; + + double acc = 0.0; + for (unsigned int k = lane_id; k < K; k += WARP_SIZE) { + acc += a_row[k] * b_row[k]; + } + + #pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) { + acc += __shfl_down_sync(0xFFFFFFFF, acc, offset); + } + + if (lane_id == 0) { + C[batch * M * N + m * N + col] = acc; + } +} diff --git a/src/runtime/cuda/kernels/index.cu b/src/runtime/cuda/kernels/index.cu index 43c01273..8cb97fc6 100644 --- a/src/runtime/cuda/kernels/index.cu +++ b/src/runtime/cuda/kernels/index.cu @@ -1227,4 +1227,42 @@ __global__ void scatter_reduce_mean_div_##suffix( \ DEFINE_SCATTER_REDUCE_MEAN_DIV_KERNEL(f32, float) DEFINE_SCATTER_REDUCE_MEAN_DIV_KERNEL(f64, double) +// ============================================================================ +// Slice Assign - Copy src into a slice of dst along a dimension +// dst: full destination tensor (outer_size * dst_dim_size * inner_size) +// src: source tensor (outer_size * src_dim_size * inner_size) +// output: pre-copied dst, then src overwrites the slice region +// ============================================================================ + +#define DEFINE_SLICE_ASSIGN_KERNEL(suffix, dtype) \ +__global__ void slice_assign_##suffix( \ + const dtype* __restrict__ src, \ + dtype* __restrict__ output, \ + unsigned int outer_size, \ + unsigned int dst_dim_size, \ + unsigned int src_dim_size, \ + unsigned int inner_size, \ + unsigned int start \ +) { \ + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; \ + unsigned int total = outer_size * src_dim_size * inner_size; \ + if (idx >= total) return; \ + \ + unsigned int inner = idx % inner_size; \ + unsigned int s = (idx / inner_size) % src_dim_size; \ + unsigned int o = idx / (src_dim_size * inner_size); \ + \ + unsigned int dst_offset = o * dst_dim_size * inner_size + (start + s) * inner_size + inner; \ + output[dst_offset] = src[idx]; \ +} + +DEFINE_SLICE_ASSIGN_KERNEL(f32, float) +DEFINE_SLICE_ASSIGN_KERNEL(f64, double) +DEFINE_SLICE_ASSIGN_KERNEL(f16, __half) +DEFINE_SLICE_ASSIGN_KERNEL(bf16, __nv_bfloat16) +DEFINE_SLICE_ASSIGN_KERNEL(i32, int) +DEFINE_SLICE_ASSIGN_KERNEL(i64, long long) +DEFINE_SLICE_ASSIGN_KERNEL(fp8_e4m3, numr_fp8_e4m3) +DEFINE_SLICE_ASSIGN_KERNEL(fp8_e5m2, numr_fp8_e5m2) + } // extern "C" diff --git a/src/runtime/cuda/kernels/index.rs b/src/runtime/cuda/kernels/index.rs deleted file mode 100644 index 73f9b2e5..00000000 --- a/src/runtime/cuda/kernels/index.rs +++ /dev/null @@ -1,1469 +0,0 @@ -//! Indexing CUDA kernel launchers -//! -//! Provides launchers for indexing operations: gather, scatter, index_select, -//! masked_select, and masked_fill. - -use cudarc::driver::PushKernelArg; -use cudarc::driver::safe::{CudaContext, CudaStream}; -use std::sync::Arc; - -use super::loader::{ - BLOCK_SIZE, elementwise_launch_config, get_kernel_function, get_or_load_module, kernel_name, - launch_config, -}; -use crate::dtype::DType; -use crate::error::{Error, Result}; - -/// Module name for indexing operations -pub const INDEX_MODULE: &str = "index"; - -// ============================================================================ -// Gather -// ============================================================================ - -/// Launch gather kernel. -/// -/// Gathers values from input along a dimension specified by indices. -/// `output[i][j][k] = input[i][indices[i][j][k]][k]` (when dim=1) -/// -/// # Safety -/// -/// - All pointers must be valid device memory -/// - Shape and stride arrays must be valid device memory with `ndim` u32 elements -#[allow(clippy::too_many_arguments)] -pub unsafe fn launch_gather( - context: &Arc, - stream: &CudaStream, - device_index: usize, - dtype: DType, - input_ptr: u64, - indices_ptr: u64, - output_ptr: u64, - ndim: usize, - dim: usize, - input_shape_ptr: u64, - input_strides_ptr: u64, - output_shape_ptr: u64, - output_strides_ptr: u64, - total_elements: usize, -) -> Result<()> { - if total_elements == 0 { - return Ok(()); - } - - unsafe { - let module = get_or_load_module(context, device_index, INDEX_MODULE)?; - let func_name = kernel_name("gather", dtype); - let func = get_kernel_function(&module, &func_name)?; - - let grid = elementwise_launch_config(total_elements); - let block = (BLOCK_SIZE, 1, 1); - let cfg = launch_config(grid, block, 0); - - let ndim_u32 = ndim as u32; - let dim_u32 = dim as u32; - let total_u32 = total_elements as u32; - - let mut builder = stream.launch_builder(&func); - builder.arg(&input_ptr); - builder.arg(&indices_ptr); - builder.arg(&output_ptr); - builder.arg(&ndim_u32); - builder.arg(&dim_u32); - builder.arg(&input_shape_ptr); - builder.arg(&input_strides_ptr); - builder.arg(&output_shape_ptr); - builder.arg(&output_strides_ptr); - builder.arg(&total_u32); - - builder - .launch(cfg) - .map_err(|e| Error::Internal(format!("CUDA gather kernel launch failed: {:?}", e)))?; - - Ok(()) - } -} - -// ============================================================================ -// Scatter -// ============================================================================ - -/// Launch scatter kernel. -/// -/// Scatters values from src to output at positions specified by indices. -/// `output[i][indices[i][j][k]][k] = src[i][j][k]` (when dim=1) -/// -/// # Safety -/// -/// - All pointers must be valid device memory -/// - Output must be pre-initialized (typically a copy of input) -#[allow(clippy::too_many_arguments)] -pub unsafe fn launch_scatter( - context: &Arc, - stream: &CudaStream, - device_index: usize, - dtype: DType, - input_ptr: u64, - indices_ptr: u64, - src_ptr: u64, - output_ptr: u64, - ndim: usize, - dim: usize, - output_shape_ptr: u64, - output_strides_ptr: u64, - src_shape_ptr: u64, - src_strides_ptr: u64, - src_total: usize, -) -> Result<()> { - if src_total == 0 { - return Ok(()); - } - - unsafe { - let module = get_or_load_module(context, device_index, INDEX_MODULE)?; - let func_name = kernel_name("scatter", dtype); - let func = get_kernel_function(&module, &func_name)?; - - let grid = elementwise_launch_config(src_total); - let block = (BLOCK_SIZE, 1, 1); - let cfg = launch_config(grid, block, 0); - - let ndim_u32 = ndim as u32; - let dim_u32 = dim as u32; - let src_total_u32 = src_total as u32; - - let mut builder = stream.launch_builder(&func); - builder.arg(&input_ptr); - builder.arg(&indices_ptr); - builder.arg(&src_ptr); - builder.arg(&output_ptr); - builder.arg(&ndim_u32); - builder.arg(&dim_u32); - builder.arg(&output_shape_ptr); - builder.arg(&output_strides_ptr); - builder.arg(&src_shape_ptr); - builder.arg(&src_strides_ptr); - builder.arg(&src_total_u32); - - builder - .launch(cfg) - .map_err(|e| Error::Internal(format!("CUDA scatter kernel launch failed: {:?}", e)))?; - - Ok(()) - } -} - -/// Launch copy kernel for scatter initialization. -/// -/// # Safety -/// -/// - All pointers must be valid device memory -/// - dst must have space for n elements -pub unsafe fn launch_copy( - context: &Arc, - stream: &CudaStream, - device_index: usize, - dtype: DType, - src_ptr: u64, - dst_ptr: u64, - n: usize, -) -> Result<()> { - if n == 0 { - return Ok(()); - } - - unsafe { - let module = get_or_load_module(context, device_index, INDEX_MODULE)?; - let func_name = kernel_name("copy", dtype); - let func = get_kernel_function(&module, &func_name)?; - - let grid = elementwise_launch_config(n); - let block = (BLOCK_SIZE, 1, 1); - let cfg = launch_config(grid, block, 0); - - let n_u32 = n as u32; - - let mut builder = stream.launch_builder(&func); - builder.arg(&src_ptr); - builder.arg(&dst_ptr); - builder.arg(&n_u32); - - builder - .launch(cfg) - .map_err(|e| Error::Internal(format!("CUDA copy kernel launch failed: {:?}", e)))?; - - Ok(()) - } -} - -// ============================================================================ -// Index Select -// ============================================================================ - -/// Launch index_select kernel. -/// -/// Selects elements along a dimension using a 1D index tensor. -/// -/// # Safety -/// -/// - All pointers must be valid device memory -/// - indices must be a 1D tensor of i64 values -#[allow(clippy::too_many_arguments)] -pub unsafe fn launch_index_select( - context: &Arc, - stream: &CudaStream, - device_index: usize, - dtype: DType, - input_ptr: u64, - indices_ptr: u64, - output_ptr: u64, - outer_size: usize, - dim_size: usize, - inner_size: usize, - index_len: usize, -) -> Result<()> { - let total = outer_size * index_len * inner_size; - if total == 0 { - return Ok(()); - } - - unsafe { - let module = get_or_load_module(context, device_index, INDEX_MODULE)?; - let func_name = kernel_name("index_select", dtype); - let func = get_kernel_function(&module, &func_name)?; - - let grid = elementwise_launch_config(total); - let block = (BLOCK_SIZE, 1, 1); - let cfg = launch_config(grid, block, 0); - - let outer_u32 = outer_size as u32; - let dim_u32 = dim_size as u32; - let inner_u32 = inner_size as u32; - let index_len_u32 = index_len as u32; - - let mut builder = stream.launch_builder(&func); - builder.arg(&input_ptr); - builder.arg(&indices_ptr); - builder.arg(&output_ptr); - builder.arg(&outer_u32); - builder.arg(&dim_u32); - builder.arg(&inner_u32); - builder.arg(&index_len_u32); - - builder.launch(cfg).map_err(|e| { - Error::Internal(format!("CUDA index_select kernel launch failed: {:?}", e)) - })?; - - Ok(()) - } -} - -/// Puts values at specified indices along a dimension. -/// -/// # Safety -/// -/// - All pointers must be valid device memory -/// - indices must be a 1D tensor of i64 values -/// - output must already contain a copy of the input tensor -#[allow(clippy::too_many_arguments)] -pub unsafe fn launch_index_put( - context: &Arc, - stream: &CudaStream, - device_index: usize, - dtype: DType, - indices_ptr: u64, - src_ptr: u64, - output_ptr: u64, - outer_size: usize, - dim_size: usize, - inner_size: usize, - index_len: usize, -) -> Result<()> { - let total = outer_size * index_len * inner_size; - if total == 0 { - return Ok(()); - } - - unsafe { - let module = get_or_load_module(context, device_index, INDEX_MODULE)?; - let func_name = kernel_name("index_put", dtype); - let func = get_kernel_function(&module, &func_name)?; - - let grid = elementwise_launch_config(total); - let block = (BLOCK_SIZE, 1, 1); - let cfg = launch_config(grid, block, 0); - - let outer_u32 = outer_size as u32; - let dim_u32 = dim_size as u32; - let inner_u32 = inner_size as u32; - let index_len_u32 = index_len as u32; - - let mut builder = stream.launch_builder(&func); - builder.arg(&indices_ptr); - builder.arg(&src_ptr); - builder.arg(&output_ptr); - builder.arg(&outer_u32); - builder.arg(&dim_u32); - builder.arg(&inner_u32); - builder.arg(&index_len_u32); - - builder.launch(cfg).map_err(|e| { - Error::Internal(format!("CUDA index_put kernel launch failed: {:?}", e)) - })?; - - Ok(()) - } -} - -// ============================================================================ -// Index Bounds Validation -// ============================================================================ - -/// Launch index bounds validation kernel. -/// -/// Validates that all indices are within bounds [0, dim_size). -/// Returns the count of out-of-bounds indices in error_count buffer. -/// -/// # Safety -/// -/// - indices_ptr must be valid device memory with index_len i64 elements -/// - error_count_ptr must be valid device memory with 1 u32 element (initialized to 0) -pub unsafe fn launch_validate_indices( - context: &Arc, - stream: &CudaStream, - device_index: usize, - indices_ptr: u64, - error_count_ptr: u64, - index_len: usize, - dim_size: usize, -) -> Result<()> { - if index_len == 0 { - return Ok(()); - } - - unsafe { - let module = get_or_load_module(context, device_index, INDEX_MODULE)?; - let func = get_kernel_function(&module, "validate_indices_kernel")?; - - let grid = elementwise_launch_config(index_len); - let block = (BLOCK_SIZE, 1, 1); - let cfg = launch_config(grid, block, 0); - - let index_len_u32 = index_len as u32; - let dim_size_u32 = dim_size as u32; - - let mut builder = stream.launch_builder(&func); - builder.arg(&indices_ptr); - builder.arg(&error_count_ptr); - builder.arg(&index_len_u32); - builder.arg(&dim_size_u32); - - builder.launch(cfg).map_err(|e| { - Error::Internal(format!( - "CUDA validate_indices kernel launch failed: {:?}", - e - )) - })?; - - Ok(()) - } -} - -// ============================================================================ -// Masked Select -// ============================================================================ - -/// Launch masked_count kernel to count true elements in mask. -/// -/// # Safety -/// -/// - mask_ptr must be valid device memory with n u8 elements -/// - count_ptr must be valid device memory with 1 u32 element (initialized to 0) -pub unsafe fn launch_masked_count( - context: &Arc, - stream: &CudaStream, - device_index: usize, - mask_ptr: u64, - count_ptr: u64, - n: usize, -) -> Result<()> { - if n == 0 { - return Ok(()); - } - - unsafe { - let module = get_or_load_module(context, device_index, INDEX_MODULE)?; - let func = get_kernel_function(&module, "masked_count_kernel")?; - - let grid = elementwise_launch_config(n); - let block = (BLOCK_SIZE, 1, 1); - let cfg = launch_config(grid, block, 0); - - let n_u32 = n as u32; - - let mut builder = stream.launch_builder(&func); - builder.arg(&mask_ptr); - builder.arg(&count_ptr); - builder.arg(&n_u32); - - builder.launch(cfg).map_err(|e| { - Error::Internal(format!("CUDA masked_count kernel launch failed: {:?}", e)) - })?; - - Ok(()) - } -} - -/// Launch masked_prefix_sum kernel to compute prefix sum of mask. -/// -/// This is a simple sequential kernel for small tensors. For large tensors, -/// a parallel scan algorithm should be used instead. -/// -/// # Safety -/// -/// - mask_ptr must be valid device memory with n u8 elements -/// - prefix_sum_ptr must be valid device memory with n u32 elements -pub unsafe fn launch_masked_prefix_sum( - context: &Arc, - stream: &CudaStream, - device_index: usize, - mask_ptr: u64, - prefix_sum_ptr: u64, - n: usize, -) -> Result<()> { - if n == 0 { - return Ok(()); - } - - unsafe { - let module = get_or_load_module(context, device_index, INDEX_MODULE)?; - let func = get_kernel_function(&module, "masked_prefix_sum_kernel")?; - - // This kernel uses a single thread - let cfg = launch_config((1, 1, 1), (1, 1, 1), 0); - - let n_u32 = n as u32; - - let mut builder = stream.launch_builder(&func); - builder.arg(&mask_ptr); - builder.arg(&prefix_sum_ptr); - builder.arg(&n_u32); - - builder.launch(cfg).map_err(|e| { - Error::Internal(format!( - "CUDA masked_prefix_sum kernel launch failed: {:?}", - e - )) - })?; - - Ok(()) - } -} - -/// Launch masked_select kernel. -/// -/// Selects elements from input where mask is true, using precomputed prefix sum. -/// -/// # Safety -/// -/// - All pointers must be valid device memory -/// - prefix_sum must be precomputed via launch_masked_prefix_sum -/// - output must have space for at least count_true elements -#[allow(clippy::too_many_arguments)] -pub unsafe fn launch_masked_select( - context: &Arc, - stream: &CudaStream, - device_index: usize, - dtype: DType, - input_ptr: u64, - mask_ptr: u64, - output_ptr: u64, - prefix_sum_ptr: u64, - n: usize, -) -> Result<()> { - if n == 0 { - return Ok(()); - } - - unsafe { - let module = get_or_load_module(context, device_index, INDEX_MODULE)?; - let func_name = kernel_name("masked_select", dtype); - let func = get_kernel_function(&module, &func_name)?; - - let grid = elementwise_launch_config(n); - let block = (BLOCK_SIZE, 1, 1); - let cfg = launch_config(grid, block, 0); - - let n_u32 = n as u32; - - let mut builder = stream.launch_builder(&func); - builder.arg(&input_ptr); - builder.arg(&mask_ptr); - builder.arg(&output_ptr); - builder.arg(&prefix_sum_ptr); - builder.arg(&n_u32); - - builder.launch(cfg).map_err(|e| { - Error::Internal(format!("CUDA masked_select kernel launch failed: {:?}", e)) - })?; - - Ok(()) - } -} - -// ============================================================================ -// Masked Fill -// ============================================================================ - -/// Launch masked_fill kernel. -/// -/// Fills elements where mask is true with a scalar value. -/// Dispatches to the appropriate dtype-specific kernel. -/// -/// # Safety -/// -/// - All pointers must be valid device memory -/// - input and output must have n elements -pub unsafe fn launch_masked_fill( - context: &Arc, - stream: &CudaStream, - device_index: usize, - dtype: DType, - input_ptr: u64, - mask_ptr: u64, - output_ptr: u64, - fill_value: f64, - n: usize, -) -> Result<()> { - if n == 0 { - return Ok(()); - } - - let kernel_name = match dtype { - DType::F32 => "masked_fill_f32", - DType::F64 => "masked_fill_f64", - DType::I32 => "masked_fill_i32", - DType::I64 => "masked_fill_i64", - #[cfg(feature = "f16")] - DType::F16 => "masked_fill_f16", - #[cfg(feature = "f16")] - DType::BF16 => "masked_fill_bf16", - #[cfg(feature = "fp8")] - DType::FP8E4M3 => "masked_fill_fp8_e4m3", - #[cfg(feature = "fp8")] - DType::FP8E5M2 => "masked_fill_fp8_e5m2", - _ => { - return Err(Error::UnsupportedDType { - dtype, - op: "masked_fill", - }); - } - }; - - unsafe { - let module = get_or_load_module(context, device_index, INDEX_MODULE)?; - let func = get_kernel_function(&module, kernel_name)?; - - let grid = elementwise_launch_config(n); - let block = (BLOCK_SIZE, 1, 1); - let cfg = launch_config(grid, block, 0); - - let n_u32 = n as u32; - - let mut builder = stream.launch_builder(&func); - builder.arg(&input_ptr); - builder.arg(&mask_ptr); - builder.arg(&output_ptr); - - // Pre-convert fill_value to all possible types to avoid lifetime issues - let fill_f32 = fill_value as f32; - let fill_f64 = fill_value; - let fill_i32 = fill_value as i32; - let fill_i64 = fill_value as i64; - #[cfg(feature = "f16")] - let fill_f16 = half::f16::from_f64(fill_value).to_bits(); - #[cfg(feature = "f16")] - let fill_bf16 = half::bf16::from_f64(fill_value).to_bits(); - #[cfg(feature = "fp8")] - let fill_fp8_e4m3 = crate::dtype::fp8::FP8E4M3::from_f64(fill_value).to_bits(); - #[cfg(feature = "fp8")] - let fill_fp8_e5m2 = crate::dtype::fp8::FP8E5M2::from_f64(fill_value).to_bits(); - - // Pass fill_value with appropriate type - match dtype { - DType::F32 => builder.arg(&fill_f32), - DType::F64 => builder.arg(&fill_f64), - DType::I32 => builder.arg(&fill_i32), - DType::I64 => builder.arg(&fill_i64), - #[cfg(feature = "f16")] - DType::F16 => builder.arg(&fill_f16), - #[cfg(feature = "f16")] - DType::BF16 => builder.arg(&fill_bf16), - #[cfg(feature = "fp8")] - DType::FP8E4M3 => builder.arg(&fill_fp8_e4m3), - #[cfg(feature = "fp8")] - DType::FP8E5M2 => builder.arg(&fill_fp8_e5m2), - _ => unreachable!(), // Already handled above - }; - - builder.arg(&n_u32); - - builder.launch(cfg).map_err(|e| { - Error::Internal(format!("CUDA masked_fill kernel launch failed: {:?}", e)) - })?; - - Ok(()) - } -} - -// ============================================================================ -// Broadcast Masked Operations -// ============================================================================ - -/// Launch broadcast masked_count kernel. -/// -/// Counts true elements in mask when broadcast to output shape. -/// -/// # Safety -/// -/// - mask_ptr must be valid device memory -/// - count_ptr must be valid device memory with 1 u32 element (initialized to 0) -/// - mask_strides_ptr, out_shape_ptr must be valid device memory with ndim u32 elements -#[allow(clippy::too_many_arguments)] -pub unsafe fn launch_masked_count_broadcast( - context: &Arc, - stream: &CudaStream, - device_index: usize, - mask_ptr: u64, - count_ptr: u64, - mask_strides_ptr: u64, - out_shape_ptr: u64, - ndim: usize, - n: usize, -) -> Result<()> { - if n == 0 { - return Ok(()); - } - - unsafe { - let module = get_or_load_module(context, device_index, INDEX_MODULE)?; - let func = get_kernel_function(&module, "masked_count_broadcast_kernel")?; - - let grid = elementwise_launch_config(n); - let block = (BLOCK_SIZE, 1, 1); - let cfg = launch_config(grid, block, 0); - - let ndim_u32 = ndim as u32; - let n_u32 = n as u32; - - let mut builder = stream.launch_builder(&func); - builder.arg(&mask_ptr); - builder.arg(&count_ptr); - builder.arg(&mask_strides_ptr); - builder.arg(&out_shape_ptr); - builder.arg(&ndim_u32); - builder.arg(&n_u32); - - builder.launch(cfg).map_err(|e| { - Error::Internal(format!( - "CUDA masked_count_broadcast kernel launch failed: {:?}", - e - )) - })?; - - Ok(()) - } -} - -/// Launch broadcast masked_prefix_sum kernel. -/// -/// Computes prefix sum of mask values when broadcast to output shape. -/// -/// # Safety -/// -/// - mask_ptr must be valid device memory -/// - prefix_sum_ptr must be valid device memory with n u32 elements -/// - mask_strides_ptr, out_shape_ptr must be valid device memory with ndim u32 elements -#[allow(clippy::too_many_arguments)] -pub unsafe fn launch_masked_prefix_sum_broadcast( - context: &Arc, - stream: &CudaStream, - device_index: usize, - mask_ptr: u64, - prefix_sum_ptr: u64, - mask_strides_ptr: u64, - out_shape_ptr: u64, - ndim: usize, - n: usize, -) -> Result<()> { - if n == 0 { - return Ok(()); - } - - unsafe { - let module = get_or_load_module(context, device_index, INDEX_MODULE)?; - let func = get_kernel_function(&module, "masked_prefix_sum_broadcast_kernel")?; - - // This kernel uses a single thread - let cfg = launch_config((1, 1, 1), (1, 1, 1), 0); - - let ndim_u32 = ndim as u32; - let n_u32 = n as u32; - - let mut builder = stream.launch_builder(&func); - builder.arg(&mask_ptr); - builder.arg(&prefix_sum_ptr); - builder.arg(&mask_strides_ptr); - builder.arg(&out_shape_ptr); - builder.arg(&ndim_u32); - builder.arg(&n_u32); - - builder.launch(cfg).map_err(|e| { - Error::Internal(format!( - "CUDA masked_prefix_sum_broadcast kernel launch failed: {:?}", - e - )) - })?; - - Ok(()) - } -} - -/// Launch broadcast masked_select kernel. -/// -/// Selects elements from input where broadcast mask is true. -/// -/// # Safety -/// -/// - All pointers must be valid device memory -/// - prefix_sum must be precomputed via launch_masked_prefix_sum_broadcast -/// - output must have space for at least count_true elements -#[allow(clippy::too_many_arguments)] -pub unsafe fn launch_masked_select_broadcast( - context: &Arc, - stream: &CudaStream, - device_index: usize, - dtype: DType, - input_ptr: u64, - mask_ptr: u64, - output_ptr: u64, - prefix_sum_ptr: u64, - mask_strides_ptr: u64, - out_shape_ptr: u64, - ndim: usize, - n: usize, -) -> Result<()> { - if n == 0 { - return Ok(()); - } - - unsafe { - let module = get_or_load_module(context, device_index, INDEX_MODULE)?; - let func_name = format!("masked_select_broadcast_{}", dtype_suffix(dtype)?); - let func = get_kernel_function(&module, &func_name)?; - - let grid = elementwise_launch_config(n); - let block = (BLOCK_SIZE, 1, 1); - let cfg = launch_config(grid, block, 0); - - let ndim_u32 = ndim as u32; - let n_u32 = n as u32; - - let mut builder = stream.launch_builder(&func); - builder.arg(&input_ptr); - builder.arg(&mask_ptr); - builder.arg(&output_ptr); - builder.arg(&prefix_sum_ptr); - builder.arg(&mask_strides_ptr); - builder.arg(&out_shape_ptr); - builder.arg(&ndim_u32); - builder.arg(&n_u32); - - builder.launch(cfg).map_err(|e| { - Error::Internal(format!( - "CUDA masked_select_broadcast kernel launch failed: {:?}", - e - )) - })?; - - Ok(()) - } -} - -/// Launch broadcast masked_fill kernel. -/// -/// Fills elements where broadcast mask is true with a scalar value. -/// -/// # Safety -/// -/// - All pointers must be valid device memory -/// - input and output must have n elements -/// - mask_strides_ptr, out_shape_ptr must be valid device memory with ndim u32 elements -#[allow(clippy::too_many_arguments)] -pub unsafe fn launch_masked_fill_broadcast( - context: &Arc, - stream: &CudaStream, - device_index: usize, - dtype: DType, - input_ptr: u64, - mask_ptr: u64, - output_ptr: u64, - fill_value: f64, - mask_strides_ptr: u64, - out_shape_ptr: u64, - ndim: usize, - n: usize, -) -> Result<()> { - if n == 0 { - return Ok(()); - } - - let kernel_name = match dtype { - DType::F32 => "masked_fill_broadcast_f32", - DType::F64 => "masked_fill_broadcast_f64", - DType::I32 => "masked_fill_broadcast_i32", - DType::I64 => "masked_fill_broadcast_i64", - #[cfg(feature = "f16")] - DType::F16 => "masked_fill_broadcast_f16", - #[cfg(feature = "f16")] - DType::BF16 => "masked_fill_broadcast_bf16", - #[cfg(feature = "fp8")] - DType::FP8E4M3 => "masked_fill_broadcast_fp8_e4m3", - #[cfg(feature = "fp8")] - DType::FP8E5M2 => "masked_fill_broadcast_fp8_e5m2", - _ => { - return Err(Error::UnsupportedDType { - dtype, - op: "masked_fill_broadcast", - }); - } - }; - - unsafe { - let module = get_or_load_module(context, device_index, INDEX_MODULE)?; - let func = get_kernel_function(&module, kernel_name)?; - - let grid = elementwise_launch_config(n); - let block = (BLOCK_SIZE, 1, 1); - let cfg = launch_config(grid, block, 0); - - let ndim_u32 = ndim as u32; - let n_u32 = n as u32; - - let mut builder = stream.launch_builder(&func); - builder.arg(&input_ptr); - builder.arg(&mask_ptr); - builder.arg(&output_ptr); - - // Pre-convert fill_value to all possible types to avoid lifetime issues - let fill_f32 = fill_value as f32; - let fill_f64 = fill_value; - let fill_i32 = fill_value as i32; - let fill_i64 = fill_value as i64; - #[cfg(feature = "f16")] - let fill_f16 = half::f16::from_f64(fill_value).to_bits(); - #[cfg(feature = "f16")] - let fill_bf16 = half::bf16::from_f64(fill_value).to_bits(); - #[cfg(feature = "fp8")] - let fill_fp8_e4m3 = crate::dtype::fp8::FP8E4M3::from_f64(fill_value).to_bits(); - #[cfg(feature = "fp8")] - let fill_fp8_e5m2 = crate::dtype::fp8::FP8E5M2::from_f64(fill_value).to_bits(); - - // Pass fill_value with appropriate type - match dtype { - DType::F32 => builder.arg(&fill_f32), - DType::F64 => builder.arg(&fill_f64), - DType::I32 => builder.arg(&fill_i32), - DType::I64 => builder.arg(&fill_i64), - #[cfg(feature = "f16")] - DType::F16 => builder.arg(&fill_f16), - #[cfg(feature = "f16")] - DType::BF16 => builder.arg(&fill_bf16), - #[cfg(feature = "fp8")] - DType::FP8E4M3 => builder.arg(&fill_fp8_e4m3), - #[cfg(feature = "fp8")] - DType::FP8E5M2 => builder.arg(&fill_fp8_e5m2), - _ => unreachable!(), // Already handled above - }; - - builder.arg(&mask_strides_ptr); - builder.arg(&out_shape_ptr); - builder.arg(&ndim_u32); - builder.arg(&n_u32); - - builder.launch(cfg).map_err(|e| { - Error::Internal(format!( - "CUDA masked_fill_broadcast kernel launch failed: {:?}", - e - )) - })?; - - Ok(()) - } -} - -/// Helper to get dtype suffix for kernel name -fn dtype_suffix(dtype: DType) -> Result<&'static str> { - match dtype { - DType::F32 => Ok("f32"), - DType::F64 => Ok("f64"), - DType::I32 => Ok("i32"), - DType::I64 => Ok("i64"), - #[cfg(feature = "f16")] - DType::F16 => Ok("f16"), - #[cfg(feature = "f16")] - DType::BF16 => Ok("bf16"), - #[cfg(feature = "fp8")] - DType::FP8E4M3 => Ok("fp8_e4m3"), - #[cfg(feature = "fp8")] - DType::FP8E5M2 => Ok("fp8_e5m2"), - _ => Err(Error::UnsupportedDType { - dtype, - op: "masked_select_broadcast", - }), - } -} - -// ============================================================================ -// Embedding Lookup -// ============================================================================ - -/// Launch embedding_lookup kernel. -/// -/// Looks up embeddings from an embedding table using indices. -/// This is the industry-standard embedding lookup operation used in neural networks. -/// -/// # Algorithm -/// For each index i in [0, num_indices): -/// output[i, :] = embeddings[indices[i], :] -/// -/// Output shape: [num_indices, embedding_dim] -/// -/// # Safety -/// -/// - All pointers must be valid device memory -/// - embeddings must be 2D [vocab_size, embedding_dim] -/// - indices must contain values in [0, vocab_size) -/// - output must have space for num_indices * embedding_dim elements -#[allow(clippy::too_many_arguments)] -pub unsafe fn launch_embedding_lookup( - context: &Arc, - stream: &CudaStream, - device_index: usize, - dtype: DType, - embeddings_ptr: u64, - indices_ptr: u64, - output_ptr: u64, - num_indices: usize, - vocab_size: usize, - embedding_dim: usize, -) -> Result<()> { - if num_indices == 0 || embedding_dim == 0 { - return Ok(()); - } - - unsafe { - let module = get_or_load_module(context, device_index, INDEX_MODULE)?; - let func_name = kernel_name("embedding_lookup", dtype); - let func = get_kernel_function(&module, &func_name)?; - - // Each thread handles one embedding lookup (one index) - // More efficient than one thread per element because we copy contiguous rows - let grid = elementwise_launch_config(num_indices); - let block = (BLOCK_SIZE, 1, 1); - let cfg = launch_config(grid, block, 0); - - let num_indices_u32 = num_indices as u32; - let vocab_size_u32 = vocab_size as u32; - let embedding_dim_u32 = embedding_dim as u32; - - let mut builder = stream.launch_builder(&func); - builder.arg(&embeddings_ptr); - builder.arg(&indices_ptr); - builder.arg(&output_ptr); - builder.arg(&num_indices_u32); - builder.arg(&vocab_size_u32); - builder.arg(&embedding_dim_u32); - - builder.launch(cfg).map_err(|e| { - Error::Internal(format!( - "CUDA embedding_lookup kernel launch failed: {:?}", - e - )) - })?; - - Ok(()) - } -} - -// ============================================================================ -// Gather ND -// ============================================================================ - -/// Launch gather_nd kernel. -/// -/// Gathers slices from input at positions specified by indices tensor. -/// -/// # Arguments -/// -/// * `input_ptr` - Input tensor data -/// * `indices_ptr` - Indices tensor (num_slices, index_depth) -/// * `output_ptr` - Output tensor (num_slices, remaining_dims...) -/// * `input_shape_ptr` - Device pointer to input shape array -/// * `input_strides_ptr` - Device pointer to input strides array -/// -/// # Safety -/// -/// All pointers must be valid device memory with sufficient size. -#[allow(clippy::too_many_arguments)] -pub unsafe fn launch_gather_nd( - context: &Arc, - stream: &CudaStream, - device_index: usize, - dtype: DType, - input_ptr: u64, - indices_ptr: u64, - output_ptr: u64, - input_shape_ptr: u64, - input_strides_ptr: u64, - num_slices: usize, - slice_size: usize, - index_depth: usize, - ndim: usize, -) -> Result<()> { - let total = num_slices * slice_size; - if total == 0 { - return Ok(()); - } - - unsafe { - let module = get_or_load_module(context, device_index, INDEX_MODULE)?; - let func_name = kernel_name("gather_nd", dtype); - let func = get_kernel_function(&module, &func_name)?; - - let grid = elementwise_launch_config(total); - let block = (BLOCK_SIZE, 1, 1); - let cfg = launch_config(grid, block, 0); - - let num_slices_u32 = num_slices as u32; - let slice_size_u32 = slice_size as u32; - let index_depth_u32 = index_depth as u32; - let ndim_u32 = ndim as u32; - - let mut builder = stream.launch_builder(&func); - builder.arg(&input_ptr); - builder.arg(&indices_ptr); - builder.arg(&output_ptr); - builder.arg(&input_shape_ptr); - builder.arg(&input_strides_ptr); - builder.arg(&num_slices_u32); - builder.arg(&slice_size_u32); - builder.arg(&index_depth_u32); - builder.arg(&ndim_u32); - - builder.launch(cfg).map_err(|e| { - Error::Internal(format!("CUDA gather_nd kernel launch failed: {:?}", e)) - })?; - - Ok(()) - } -} - -// ============================================================================ -// Bincount -// ============================================================================ - -/// Launch bincount kernel. -/// -/// Counts occurrences of each value in an integer tensor, optionally with weights. -/// -/// # Arguments -/// -/// * `input_ptr` - Input tensor of non-negative integers (i32 or i64) -/// * `weights_ptr` - Optional weights tensor -/// * `output_ptr` - Output tensor (initialized to zeros) -/// * `n` - Number of elements in input -/// * `minlength` - Length of output tensor -/// -/// # Safety -/// -/// All pointers must be valid device memory. -#[allow(clippy::too_many_arguments)] -pub unsafe fn launch_bincount_weighted( - context: &Arc, - stream: &CudaStream, - device_index: usize, - input_dtype: DType, - weights_dtype: Option, - input_ptr: u64, - weights_ptr: Option, - output_ptr: u64, - n: usize, - minlength: usize, -) -> Result<()> { - if n == 0 || minlength == 0 { - return Ok(()); - } - - unsafe { - let module = get_or_load_module(context, device_index, INDEX_MODULE)?; - - let func_name = match (input_dtype, weights_ptr, weights_dtype) { - (DType::I32, None, _) => "bincount_i32", - (DType::I64, None, _) => "bincount_i64", - (DType::I32, Some(_), Some(DType::F32)) => "bincount_weighted_f32", - (DType::I32, Some(_), Some(DType::F64)) => "bincount_weighted_f64", - (DType::I64, Some(_), Some(DType::F32)) => "bincount_i64_weighted_f32", - _ => { - return Err(Error::InvalidArgument { - arg: "dtype", - reason: format!("bincount requires i32/i64 input, got {:?}", input_dtype), - }); - } - }; - - let func = get_kernel_function(&module, func_name)?; - - let grid = elementwise_launch_config(n); - let block = (BLOCK_SIZE, 1, 1); - let cfg = launch_config(grid, block, 0); - - let n_u32 = n as u32; - let minlength_u32 = minlength as u32; - - let mut builder = stream.launch_builder(&func); - builder.arg(&input_ptr); - - // Store weights_ptr value outside the if block to extend its lifetime - let weights_ptr_val = weights_ptr.unwrap_or(0); - if weights_ptr.is_some() { - builder.arg(&weights_ptr_val); - } - - builder.arg(&output_ptr); - builder.arg(&n_u32); - builder.arg(&minlength_u32); - - builder - .launch(cfg) - .map_err(|e| Error::Internal(format!("CUDA bincount kernel launch failed: {:?}", e)))?; - - Ok(()) - } -} - -// ============================================================================ -// Scatter Reduce -// ============================================================================ - -/// Scatter reduce operation type. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum ScatterReduceOpCuda { - Sum, - Max, - Min, - Prod, -} - -/// Launch scatter_reduce kernel. -/// -/// Scatters values from src to dst at positions specified by indices with a -/// reduction operation. -/// -/// # Arguments -/// -/// * `src_ptr` - Source tensor data -/// * `indices_ptr` - Indices tensor (1D) -/// * `dst_ptr` - Destination tensor (must be pre-initialized with appropriate values) -/// * `op` - Reduction operation (sum, max, min) -/// -/// # Safety -/// -/// All pointers must be valid device memory. -#[allow(clippy::too_many_arguments)] -pub unsafe fn launch_scatter_reduce( - context: &Arc, - stream: &CudaStream, - device_index: usize, - dtype: DType, - src_ptr: u64, - indices_ptr: u64, - dst_ptr: u64, - dim: usize, - outer_size: usize, - dim_size: usize, - inner_size: usize, - src_dim_size: usize, - op: ScatterReduceOpCuda, -) -> Result<()> { - let total = outer_size * src_dim_size * inner_size; - if total == 0 { - return Ok(()); - } - - unsafe { - let module = get_or_load_module(context, device_index, INDEX_MODULE)?; - - let func_name = match (dtype, op) { - (DType::F32, ScatterReduceOpCuda::Sum) => "scatter_reduce_sum_f32", - (DType::F32, ScatterReduceOpCuda::Max) => "scatter_reduce_max_f32", - (DType::F32, ScatterReduceOpCuda::Min) => "scatter_reduce_min_f32", - (DType::F32, ScatterReduceOpCuda::Prod) => "scatter_reduce_prod_f32", - (DType::F64, ScatterReduceOpCuda::Sum) => "scatter_reduce_sum_f64", - (DType::F64, ScatterReduceOpCuda::Max) => "scatter_reduce_max_f64", - (DType::F64, ScatterReduceOpCuda::Min) => "scatter_reduce_min_f64", - (DType::F64, ScatterReduceOpCuda::Prod) => "scatter_reduce_prod_f64", - (DType::I32, ScatterReduceOpCuda::Sum) => "scatter_reduce_sum_i32", - (DType::I32, ScatterReduceOpCuda::Max) => "scatter_reduce_max_i32", - (DType::I32, ScatterReduceOpCuda::Min) => "scatter_reduce_min_i32", - (DType::I32, ScatterReduceOpCuda::Prod) => "scatter_reduce_prod_i32", - _ => { - return Err(Error::UnsupportedDType { - dtype, - op: "scatter_reduce", - }); - } - }; - - let func = get_kernel_function(&module, func_name)?; - - let grid = elementwise_launch_config(total); - let block = (BLOCK_SIZE, 1, 1); - let cfg = launch_config(grid, block, 0); - - let dim_u32 = dim as u32; - let outer_size_u32 = outer_size as u32; - let dim_size_u32 = dim_size as u32; - let inner_size_u32 = inner_size as u32; - let src_dim_size_u32 = src_dim_size as u32; - - let mut builder = stream.launch_builder(&func); - builder.arg(&src_ptr); - builder.arg(&indices_ptr); - builder.arg(&dst_ptr); - builder.arg(&dim_u32); - builder.arg(&outer_size_u32); - builder.arg(&dim_size_u32); - builder.arg(&inner_size_u32); - builder.arg(&src_dim_size_u32); - - builder.launch(cfg).map_err(|e| { - Error::Internal(format!("CUDA scatter_reduce kernel launch failed: {:?}", e)) - })?; - - Ok(()) - } -} - -// ============================================================================ -// Scatter Reduce Count (for mean) -// ============================================================================ - -/// Launch scatter_reduce_count kernel. -/// -/// Atomically increments count buffer at scattered positions. -/// Used as part of scatter_reduce mean: sum / count. -/// -/// # Safety -/// -/// All pointers must be valid device memory. -#[allow(clippy::too_many_arguments)] -pub unsafe fn launch_scatter_reduce_count( - context: &Arc, - stream: &CudaStream, - device_index: usize, - dtype: DType, - indices_ptr: u64, - count_ptr: u64, - dim: usize, - outer_size: usize, - dim_size: usize, - inner_size: usize, - src_dim_size: usize, -) -> Result<()> { - let total = outer_size * src_dim_size * inner_size; - if total == 0 { - return Ok(()); - } - - unsafe { - let module = get_or_load_module(context, device_index, INDEX_MODULE)?; - - let func_name = match dtype { - DType::F32 => "scatter_reduce_count_f32", - DType::F64 => "scatter_reduce_count_f64", - _ => { - return Err(Error::UnsupportedDType { - dtype, - op: "scatter_reduce_count", - }); - } - }; - - let func = get_kernel_function(&module, func_name)?; - - let grid = elementwise_launch_config(total); - let block = (BLOCK_SIZE, 1, 1); - let cfg = launch_config(grid, block, 0); - - let dim_u32 = dim as u32; - let outer_size_u32 = outer_size as u32; - let dim_size_u32 = dim_size as u32; - let inner_size_u32 = inner_size as u32; - let src_dim_size_u32 = src_dim_size as u32; - - let mut builder = stream.launch_builder(&func); - builder.arg(&indices_ptr); - builder.arg(&count_ptr); - builder.arg(&dim_u32); - builder.arg(&outer_size_u32); - builder.arg(&dim_size_u32); - builder.arg(&inner_size_u32); - builder.arg(&src_dim_size_u32); - - builder.launch(cfg).map_err(|e| { - Error::Internal(format!( - "CUDA scatter_reduce_count kernel launch failed: {:?}", - e - )) - })?; - - Ok(()) - } -} - -// ============================================================================ -// Scatter Reduce Mean Divide -// ============================================================================ - -/// Launch scatter_reduce_mean_div kernel. -/// -/// Element-wise: output[i] = sum[i] / count[i]. -/// If count[i] == 0, output[i] = 0. -/// -/// # Safety -/// -/// All pointers must be valid device memory. -#[allow(clippy::too_many_arguments)] -pub unsafe fn launch_scatter_reduce_mean_div( - context: &Arc, - stream: &CudaStream, - device_index: usize, - dtype: DType, - sum_ptr: u64, - count_ptr: u64, - output_ptr: u64, - n: usize, -) -> Result<()> { - if n == 0 { - return Ok(()); - } - - unsafe { - let module = get_or_load_module(context, device_index, INDEX_MODULE)?; - - let func_name = match dtype { - DType::F32 => "scatter_reduce_mean_div_f32", - DType::F64 => "scatter_reduce_mean_div_f64", - _ => { - return Err(Error::UnsupportedDType { - dtype, - op: "scatter_reduce_mean_div", - }); - } - }; - - let func = get_kernel_function(&module, func_name)?; - - let grid = elementwise_launch_config(n); - let block = (BLOCK_SIZE, 1, 1); - let cfg = launch_config(grid, block, 0); - - let n_u32 = n as u32; - - let mut builder = stream.launch_builder(&func); - builder.arg(&sum_ptr); - builder.arg(&count_ptr); - builder.arg(&output_ptr); - builder.arg(&n_u32); - - builder.launch(cfg).map_err(|e| { - Error::Internal(format!( - "CUDA scatter_reduce_mean_div kernel launch failed: {:?}", - e - )) - })?; - - Ok(()) - } -} - -// ============================================================================ -// Gather 2D -// ============================================================================ - -/// Launch gather_2d kernel. -/// -/// Gathers elements from a 2D matrix at specific (row, col) positions. -/// For each index i: output[i] = input[rows[i], cols[i]] -/// -/// # Arguments -/// -/// * `input_ptr` - 2D input tensor data (row-major) -/// * `rows_ptr` - 1D row indices tensor (i64) -/// * `cols_ptr` - 1D column indices tensor (i64) -/// * `output_ptr` - 1D output tensor -/// * `nrows` - Number of rows in input -/// * `ncols` - Number of columns in input -/// * `num_indices` - Number of (row, col) pairs to gather -/// -/// # Safety -/// -/// All pointers must be valid device memory. -#[allow(clippy::too_many_arguments)] -pub unsafe fn launch_gather_2d( - context: &Arc, - stream: &CudaStream, - device_index: usize, - dtype: DType, - input_ptr: u64, - rows_ptr: u64, - cols_ptr: u64, - output_ptr: u64, - nrows: usize, - ncols: usize, - num_indices: usize, -) -> Result<()> { - if num_indices == 0 { - return Ok(()); - } - - unsafe { - let module = get_or_load_module(context, device_index, INDEX_MODULE)?; - let func_name = kernel_name("gather_2d", dtype); - let func = get_kernel_function(&module, &func_name)?; - - let grid = elementwise_launch_config(num_indices); - let block = (BLOCK_SIZE, 1, 1); - let cfg = launch_config(grid, block, 0); - - let nrows_u32 = nrows as u32; - let ncols_u32 = ncols as u32; - let num_indices_u32 = num_indices as u32; - - let mut builder = stream.launch_builder(&func); - builder.arg(&input_ptr); - builder.arg(&rows_ptr); - builder.arg(&cols_ptr); - builder.arg(&output_ptr); - builder.arg(&nrows_u32); - builder.arg(&ncols_u32); - builder.arg(&num_indices_u32); - - builder.launch(cfg).map_err(|e| { - Error::Internal(format!("CUDA gather_2d kernel launch failed: {:?}", e)) - })?; - - Ok(()) - } -} diff --git a/src/runtime/cuda/kernels/index/embedding.rs b/src/runtime/cuda/kernels/index/embedding.rs new file mode 100644 index 00000000..6d6ad53d --- /dev/null +++ b/src/runtime/cuda/kernels/index/embedding.rs @@ -0,0 +1,142 @@ +//! Embedding lookup and bincount kernel launchers + +use cudarc::driver::PushKernelArg; +use cudarc::driver::safe::{CudaContext, CudaStream}; +use std::sync::Arc; + +use super::super::loader::{ + BLOCK_SIZE, elementwise_launch_config, get_kernel_function, get_or_load_module, kernel_name, + launch_config, +}; +use super::gather::INDEX_MODULE; +use crate::dtype::DType; +use crate::error::{Error, Result}; + +/// Launch embedding_lookup kernel. +/// +/// Looks up embeddings from an embedding table using indices. +/// For each index i: output[i, :] = embeddings[indices[i], :] +/// +/// # Safety +/// +/// - All pointers must be valid device memory +/// - embeddings must be 2D [vocab_size, embedding_dim] +/// - indices must contain values in [0, vocab_size) +#[allow(clippy::too_many_arguments)] +pub unsafe fn launch_embedding_lookup( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + embeddings_ptr: u64, + indices_ptr: u64, + output_ptr: u64, + num_indices: usize, + vocab_size: usize, + embedding_dim: usize, +) -> Result<()> { + if num_indices == 0 || embedding_dim == 0 { + return Ok(()); + } + + unsafe { + let module = get_or_load_module(context, device_index, INDEX_MODULE)?; + let func_name = kernel_name("embedding_lookup", dtype); + let func = get_kernel_function(&module, &func_name)?; + + let grid = elementwise_launch_config(num_indices); + let block = (BLOCK_SIZE, 1, 1); + let cfg = launch_config(grid, block, 0); + + let num_indices_u32 = num_indices as u32; + let vocab_size_u32 = vocab_size as u32; + let embedding_dim_u32 = embedding_dim as u32; + + let mut builder = stream.launch_builder(&func); + builder.arg(&embeddings_ptr); + builder.arg(&indices_ptr); + builder.arg(&output_ptr); + builder.arg(&num_indices_u32); + builder.arg(&vocab_size_u32); + builder.arg(&embedding_dim_u32); + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!( + "CUDA embedding_lookup kernel launch failed: {:?}", + e + )) + })?; + + Ok(()) + } +} + +/// Launch bincount kernel. +/// +/// Counts occurrences of each value in an integer tensor, optionally with weights. +/// +/// # Safety +/// +/// All pointers must be valid device memory. +#[allow(clippy::too_many_arguments)] +pub unsafe fn launch_bincount_weighted( + context: &Arc, + stream: &CudaStream, + device_index: usize, + input_dtype: DType, + weights_dtype: Option, + input_ptr: u64, + weights_ptr: Option, + output_ptr: u64, + n: usize, + minlength: usize, +) -> Result<()> { + if n == 0 || minlength == 0 { + return Ok(()); + } + + unsafe { + let module = get_or_load_module(context, device_index, INDEX_MODULE)?; + + let func_name = match (input_dtype, weights_ptr, weights_dtype) { + (DType::I32, None, _) => "bincount_i32", + (DType::I64, None, _) => "bincount_i64", + (DType::I32, Some(_), Some(DType::F32)) => "bincount_weighted_f32", + (DType::I32, Some(_), Some(DType::F64)) => "bincount_weighted_f64", + (DType::I64, Some(_), Some(DType::F32)) => "bincount_i64_weighted_f32", + _ => { + return Err(Error::InvalidArgument { + arg: "dtype", + reason: format!("bincount requires i32/i64 input, got {:?}", input_dtype), + }); + } + }; + + let func = get_kernel_function(&module, func_name)?; + + let grid = elementwise_launch_config(n); + let block = (BLOCK_SIZE, 1, 1); + let cfg = launch_config(grid, block, 0); + + let n_u32 = n as u32; + let minlength_u32 = minlength as u32; + + let mut builder = stream.launch_builder(&func); + builder.arg(&input_ptr); + + let weights_ptr_val = weights_ptr.unwrap_or(0); + if weights_ptr.is_some() { + builder.arg(&weights_ptr_val); + } + + builder.arg(&output_ptr); + builder.arg(&n_u32); + builder.arg(&minlength_u32); + + builder + .launch(cfg) + .map_err(|e| Error::Internal(format!("CUDA bincount kernel launch failed: {:?}", e)))?; + + Ok(()) + } +} diff --git a/src/runtime/cuda/kernels/index/gather.rs b/src/runtime/cuda/kernels/index/gather.rs new file mode 100644 index 00000000..5f1ee7c7 --- /dev/null +++ b/src/runtime/cuda/kernels/index/gather.rs @@ -0,0 +1,195 @@ +//! Gather kernel launchers (gather, gather_nd, gather_2d) + +use cudarc::driver::PushKernelArg; +use cudarc::driver::safe::{CudaContext, CudaStream}; +use std::sync::Arc; + +use super::super::loader::{ + BLOCK_SIZE, elementwise_launch_config, get_kernel_function, get_or_load_module, kernel_name, + launch_config, +}; +use crate::dtype::DType; +use crate::error::{Error, Result}; + +/// Module name for indexing operations +pub const INDEX_MODULE: &str = "index"; + +/// Launch gather kernel. +/// +/// Gathers values from input along a dimension specified by indices. +/// `output[i][j][k] = input[i][indices[i][j][k]][k]` (when dim=1) +/// +/// # Safety +/// +/// - All pointers must be valid device memory +/// - Shape and stride arrays must be valid device memory with `ndim` u32 elements +#[allow(clippy::too_many_arguments)] +pub unsafe fn launch_gather( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + input_ptr: u64, + indices_ptr: u64, + output_ptr: u64, + ndim: usize, + dim: usize, + input_shape_ptr: u64, + input_strides_ptr: u64, + output_shape_ptr: u64, + output_strides_ptr: u64, + total_elements: usize, +) -> Result<()> { + if total_elements == 0 { + return Ok(()); + } + + unsafe { + let module = get_or_load_module(context, device_index, INDEX_MODULE)?; + let func_name = kernel_name("gather", dtype); + let func = get_kernel_function(&module, &func_name)?; + + let grid = elementwise_launch_config(total_elements); + let block = (BLOCK_SIZE, 1, 1); + let cfg = launch_config(grid, block, 0); + + let ndim_u32 = ndim as u32; + let dim_u32 = dim as u32; + let total_u32 = total_elements as u32; + + let mut builder = stream.launch_builder(&func); + builder.arg(&input_ptr); + builder.arg(&indices_ptr); + builder.arg(&output_ptr); + builder.arg(&ndim_u32); + builder.arg(&dim_u32); + builder.arg(&input_shape_ptr); + builder.arg(&input_strides_ptr); + builder.arg(&output_shape_ptr); + builder.arg(&output_strides_ptr); + builder.arg(&total_u32); + + builder + .launch(cfg) + .map_err(|e| Error::Internal(format!("CUDA gather kernel launch failed: {:?}", e)))?; + + Ok(()) + } +} + +/// Launch gather_nd kernel. +/// +/// Gathers slices from input at positions specified by indices tensor. +/// +/// # Safety +/// +/// All pointers must be valid device memory with sufficient size. +#[allow(clippy::too_many_arguments)] +pub unsafe fn launch_gather_nd( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + input_ptr: u64, + indices_ptr: u64, + output_ptr: u64, + input_shape_ptr: u64, + input_strides_ptr: u64, + num_slices: usize, + slice_size: usize, + index_depth: usize, + ndim: usize, +) -> Result<()> { + let total = num_slices * slice_size; + if total == 0 { + return Ok(()); + } + + unsafe { + let module = get_or_load_module(context, device_index, INDEX_MODULE)?; + let func_name = kernel_name("gather_nd", dtype); + let func = get_kernel_function(&module, &func_name)?; + + let grid = elementwise_launch_config(total); + let block = (BLOCK_SIZE, 1, 1); + let cfg = launch_config(grid, block, 0); + + let num_slices_u32 = num_slices as u32; + let slice_size_u32 = slice_size as u32; + let index_depth_u32 = index_depth as u32; + let ndim_u32 = ndim as u32; + + let mut builder = stream.launch_builder(&func); + builder.arg(&input_ptr); + builder.arg(&indices_ptr); + builder.arg(&output_ptr); + builder.arg(&input_shape_ptr); + builder.arg(&input_strides_ptr); + builder.arg(&num_slices_u32); + builder.arg(&slice_size_u32); + builder.arg(&index_depth_u32); + builder.arg(&ndim_u32); + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!("CUDA gather_nd kernel launch failed: {:?}", e)) + })?; + + Ok(()) + } +} + +/// Launch gather_2d kernel. +/// +/// Gathers elements from a 2D matrix at specific (row, col) positions. +/// For each index i: output[i] = input[rows[i], cols[i]] +/// +/// # Safety +/// +/// All pointers must be valid device memory. +#[allow(clippy::too_many_arguments)] +pub unsafe fn launch_gather_2d( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + input_ptr: u64, + rows_ptr: u64, + cols_ptr: u64, + output_ptr: u64, + nrows: usize, + ncols: usize, + num_indices: usize, +) -> Result<()> { + if num_indices == 0 { + return Ok(()); + } + + unsafe { + let module = get_or_load_module(context, device_index, INDEX_MODULE)?; + let func_name = kernel_name("gather_2d", dtype); + let func = get_kernel_function(&module, &func_name)?; + + let grid = elementwise_launch_config(num_indices); + let block = (BLOCK_SIZE, 1, 1); + let cfg = launch_config(grid, block, 0); + + let nrows_u32 = nrows as u32; + let ncols_u32 = ncols as u32; + let num_indices_u32 = num_indices as u32; + + let mut builder = stream.launch_builder(&func); + builder.arg(&input_ptr); + builder.arg(&rows_ptr); + builder.arg(&cols_ptr); + builder.arg(&output_ptr); + builder.arg(&nrows_u32); + builder.arg(&ncols_u32); + builder.arg(&num_indices_u32); + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!("CUDA gather_2d kernel launch failed: {:?}", e)) + })?; + + Ok(()) + } +} diff --git a/src/runtime/cuda/kernels/index/index_select.rs b/src/runtime/cuda/kernels/index/index_select.rs new file mode 100644 index 00000000..4628579d --- /dev/null +++ b/src/runtime/cuda/kernels/index/index_select.rs @@ -0,0 +1,178 @@ +//! Index select and index bounds validation kernel launchers + +use cudarc::driver::PushKernelArg; +use cudarc::driver::safe::{CudaContext, CudaStream}; +use std::sync::Arc; + +use super::super::loader::{ + BLOCK_SIZE, elementwise_launch_config, get_kernel_function, get_or_load_module, kernel_name, + launch_config, +}; +use super::gather::INDEX_MODULE; +use crate::dtype::DType; +use crate::error::{Error, Result}; + +/// Launch index_select kernel. +/// +/// Selects elements along a dimension using a 1D index tensor. +/// +/// # Safety +/// +/// - All pointers must be valid device memory +/// - indices must be a 1D tensor of i64 values +#[allow(clippy::too_many_arguments)] +pub unsafe fn launch_index_select( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + input_ptr: u64, + indices_ptr: u64, + output_ptr: u64, + outer_size: usize, + dim_size: usize, + inner_size: usize, + index_len: usize, +) -> Result<()> { + let total = outer_size * index_len * inner_size; + if total == 0 { + return Ok(()); + } + + unsafe { + let module = get_or_load_module(context, device_index, INDEX_MODULE)?; + let func_name = kernel_name("index_select", dtype); + let func = get_kernel_function(&module, &func_name)?; + + let grid = elementwise_launch_config(total); + let block = (BLOCK_SIZE, 1, 1); + let cfg = launch_config(grid, block, 0); + + let outer_u32 = outer_size as u32; + let dim_u32 = dim_size as u32; + let inner_u32 = inner_size as u32; + let index_len_u32 = index_len as u32; + + let mut builder = stream.launch_builder(&func); + builder.arg(&input_ptr); + builder.arg(&indices_ptr); + builder.arg(&output_ptr); + builder.arg(&outer_u32); + builder.arg(&dim_u32); + builder.arg(&inner_u32); + builder.arg(&index_len_u32); + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!("CUDA index_select kernel launch failed: {:?}", e)) + })?; + + Ok(()) + } +} + +/// Puts values at specified indices along a dimension. +/// +/// # Safety +/// +/// - All pointers must be valid device memory +/// - indices must be a 1D tensor of i64 values +/// - output must already contain a copy of the input tensor +#[allow(clippy::too_many_arguments)] +pub unsafe fn launch_index_put( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + indices_ptr: u64, + src_ptr: u64, + output_ptr: u64, + outer_size: usize, + dim_size: usize, + inner_size: usize, + index_len: usize, +) -> Result<()> { + let total = outer_size * index_len * inner_size; + if total == 0 { + return Ok(()); + } + + unsafe { + let module = get_or_load_module(context, device_index, INDEX_MODULE)?; + let func_name = kernel_name("index_put", dtype); + let func = get_kernel_function(&module, &func_name)?; + + let grid = elementwise_launch_config(total); + let block = (BLOCK_SIZE, 1, 1); + let cfg = launch_config(grid, block, 0); + + let outer_u32 = outer_size as u32; + let dim_u32 = dim_size as u32; + let inner_u32 = inner_size as u32; + let index_len_u32 = index_len as u32; + + let mut builder = stream.launch_builder(&func); + builder.arg(&indices_ptr); + builder.arg(&src_ptr); + builder.arg(&output_ptr); + builder.arg(&outer_u32); + builder.arg(&dim_u32); + builder.arg(&inner_u32); + builder.arg(&index_len_u32); + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!("CUDA index_put kernel launch failed: {:?}", e)) + })?; + + Ok(()) + } +} + +/// Launch index bounds validation kernel. +/// +/// Validates that all indices are within bounds [0, dim_size). +/// Returns the count of out-of-bounds indices in error_count buffer. +/// +/// # Safety +/// +/// - indices_ptr must be valid device memory with index_len i64 elements +/// - error_count_ptr must be valid device memory with 1 u32 element (initialized to 0) +pub unsafe fn launch_validate_indices( + context: &Arc, + stream: &CudaStream, + device_index: usize, + indices_ptr: u64, + error_count_ptr: u64, + index_len: usize, + dim_size: usize, +) -> Result<()> { + if index_len == 0 { + return Ok(()); + } + + unsafe { + let module = get_or_load_module(context, device_index, INDEX_MODULE)?; + let func = get_kernel_function(&module, "validate_indices_kernel")?; + + let grid = elementwise_launch_config(index_len); + let block = (BLOCK_SIZE, 1, 1); + let cfg = launch_config(grid, block, 0); + + let index_len_u32 = index_len as u32; + let dim_size_u32 = dim_size as u32; + + let mut builder = stream.launch_builder(&func); + builder.arg(&indices_ptr); + builder.arg(&error_count_ptr); + builder.arg(&index_len_u32); + builder.arg(&dim_size_u32); + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!( + "CUDA validate_indices kernel launch failed: {:?}", + e + )) + })?; + + Ok(()) + } +} diff --git a/src/runtime/cuda/kernels/index/masked.rs b/src/runtime/cuda/kernels/index/masked.rs new file mode 100644 index 00000000..66928cef --- /dev/null +++ b/src/runtime/cuda/kernels/index/masked.rs @@ -0,0 +1,548 @@ +//! Masked select, masked fill, and broadcast masked operation kernel launchers + +use cudarc::driver::PushKernelArg; +use cudarc::driver::safe::{CudaContext, CudaStream}; +use std::sync::Arc; + +use super::super::loader::{ + BLOCK_SIZE, elementwise_launch_config, get_kernel_function, get_or_load_module, kernel_name, + launch_config, +}; +use super::gather::INDEX_MODULE; +use crate::dtype::DType; +use crate::error::{Error, Result}; + +// ============================================================================ +// Masked Select +// ============================================================================ + +/// Launch masked_count kernel to count true elements in mask. +/// +/// # Safety +/// +/// - mask_ptr must be valid device memory with n u8 elements +/// - count_ptr must be valid device memory with 1 u32 element (initialized to 0) +pub unsafe fn launch_masked_count( + context: &Arc, + stream: &CudaStream, + device_index: usize, + mask_ptr: u64, + count_ptr: u64, + n: usize, +) -> Result<()> { + if n == 0 { + return Ok(()); + } + + unsafe { + let module = get_or_load_module(context, device_index, INDEX_MODULE)?; + let func = get_kernel_function(&module, "masked_count_kernel")?; + + let grid = elementwise_launch_config(n); + let block = (BLOCK_SIZE, 1, 1); + let cfg = launch_config(grid, block, 0); + + let n_u32 = n as u32; + + let mut builder = stream.launch_builder(&func); + builder.arg(&mask_ptr); + builder.arg(&count_ptr); + builder.arg(&n_u32); + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!("CUDA masked_count kernel launch failed: {:?}", e)) + })?; + + Ok(()) + } +} + +/// Launch masked_prefix_sum kernel to compute prefix sum of mask. +/// +/// # Safety +/// +/// - mask_ptr must be valid device memory with n u8 elements +/// - prefix_sum_ptr must be valid device memory with n u32 elements +pub unsafe fn launch_masked_prefix_sum( + context: &Arc, + stream: &CudaStream, + device_index: usize, + mask_ptr: u64, + prefix_sum_ptr: u64, + n: usize, +) -> Result<()> { + if n == 0 { + return Ok(()); + } + + unsafe { + let module = get_or_load_module(context, device_index, INDEX_MODULE)?; + let func = get_kernel_function(&module, "masked_prefix_sum_kernel")?; + + let cfg = launch_config((1, 1, 1), (1, 1, 1), 0); + + let n_u32 = n as u32; + + let mut builder = stream.launch_builder(&func); + builder.arg(&mask_ptr); + builder.arg(&prefix_sum_ptr); + builder.arg(&n_u32); + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!( + "CUDA masked_prefix_sum kernel launch failed: {:?}", + e + )) + })?; + + Ok(()) + } +} + +/// Launch masked_select kernel. +/// +/// Selects elements from input where mask is true, using precomputed prefix sum. +/// +/// # Safety +/// +/// - All pointers must be valid device memory +/// - prefix_sum must be precomputed via launch_masked_prefix_sum +/// - output must have space for at least count_true elements +#[allow(clippy::too_many_arguments)] +pub unsafe fn launch_masked_select( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + input_ptr: u64, + mask_ptr: u64, + output_ptr: u64, + prefix_sum_ptr: u64, + n: usize, +) -> Result<()> { + if n == 0 { + return Ok(()); + } + + unsafe { + let module = get_or_load_module(context, device_index, INDEX_MODULE)?; + let func_name = kernel_name("masked_select", dtype); + let func = get_kernel_function(&module, &func_name)?; + + let grid = elementwise_launch_config(n); + let block = (BLOCK_SIZE, 1, 1); + let cfg = launch_config(grid, block, 0); + + let n_u32 = n as u32; + + let mut builder = stream.launch_builder(&func); + builder.arg(&input_ptr); + builder.arg(&mask_ptr); + builder.arg(&output_ptr); + builder.arg(&prefix_sum_ptr); + builder.arg(&n_u32); + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!("CUDA masked_select kernel launch failed: {:?}", e)) + })?; + + Ok(()) + } +} + +// ============================================================================ +// Masked Fill +// ============================================================================ + +/// Launch masked_fill kernel. +/// +/// Fills elements where mask is true with a scalar value. +/// +/// # Safety +/// +/// - All pointers must be valid device memory +/// - input and output must have n elements +pub unsafe fn launch_masked_fill( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + input_ptr: u64, + mask_ptr: u64, + output_ptr: u64, + fill_value: f64, + n: usize, +) -> Result<()> { + if n == 0 { + return Ok(()); + } + + let kernel_name = match dtype { + DType::F32 => "masked_fill_f32", + DType::F64 => "masked_fill_f64", + DType::I32 => "masked_fill_i32", + DType::I64 => "masked_fill_i64", + #[cfg(feature = "f16")] + DType::F16 => "masked_fill_f16", + #[cfg(feature = "f16")] + DType::BF16 => "masked_fill_bf16", + #[cfg(feature = "fp8")] + DType::FP8E4M3 => "masked_fill_fp8_e4m3", + #[cfg(feature = "fp8")] + DType::FP8E5M2 => "masked_fill_fp8_e5m2", + _ => { + return Err(Error::UnsupportedDType { + dtype, + op: "masked_fill", + }); + } + }; + + unsafe { + let module = get_or_load_module(context, device_index, INDEX_MODULE)?; + let func = get_kernel_function(&module, kernel_name)?; + + let grid = elementwise_launch_config(n); + let block = (BLOCK_SIZE, 1, 1); + let cfg = launch_config(grid, block, 0); + + let n_u32 = n as u32; + + let mut builder = stream.launch_builder(&func); + builder.arg(&input_ptr); + builder.arg(&mask_ptr); + builder.arg(&output_ptr); + + let fill_f32 = fill_value as f32; + let fill_f64 = fill_value; + let fill_i32 = fill_value as i32; + let fill_i64 = fill_value as i64; + #[cfg(feature = "f16")] + let fill_f16 = half::f16::from_f64(fill_value).to_bits(); + #[cfg(feature = "f16")] + let fill_bf16 = half::bf16::from_f64(fill_value).to_bits(); + #[cfg(feature = "fp8")] + let fill_fp8_e4m3 = crate::dtype::fp8::FP8E4M3::from_f64(fill_value).to_bits(); + #[cfg(feature = "fp8")] + let fill_fp8_e5m2 = crate::dtype::fp8::FP8E5M2::from_f64(fill_value).to_bits(); + + match dtype { + DType::F32 => builder.arg(&fill_f32), + DType::F64 => builder.arg(&fill_f64), + DType::I32 => builder.arg(&fill_i32), + DType::I64 => builder.arg(&fill_i64), + #[cfg(feature = "f16")] + DType::F16 => builder.arg(&fill_f16), + #[cfg(feature = "f16")] + DType::BF16 => builder.arg(&fill_bf16), + #[cfg(feature = "fp8")] + DType::FP8E4M3 => builder.arg(&fill_fp8_e4m3), + #[cfg(feature = "fp8")] + DType::FP8E5M2 => builder.arg(&fill_fp8_e5m2), + _ => unreachable!(), + }; + + builder.arg(&n_u32); + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!("CUDA masked_fill kernel launch failed: {:?}", e)) + })?; + + Ok(()) + } +} + +// ============================================================================ +// Broadcast Masked Operations +// ============================================================================ + +/// Launch broadcast masked_count kernel. +/// +/// # Safety +/// +/// - mask_ptr must be valid device memory +/// - count_ptr must be valid device memory with 1 u32 element (initialized to 0) +/// - mask_strides_ptr, out_shape_ptr must be valid device memory with ndim u32 elements +#[allow(clippy::too_many_arguments)] +pub unsafe fn launch_masked_count_broadcast( + context: &Arc, + stream: &CudaStream, + device_index: usize, + mask_ptr: u64, + count_ptr: u64, + mask_strides_ptr: u64, + out_shape_ptr: u64, + ndim: usize, + n: usize, +) -> Result<()> { + if n == 0 { + return Ok(()); + } + + unsafe { + let module = get_or_load_module(context, device_index, INDEX_MODULE)?; + let func = get_kernel_function(&module, "masked_count_broadcast_kernel")?; + + let grid = elementwise_launch_config(n); + let block = (BLOCK_SIZE, 1, 1); + let cfg = launch_config(grid, block, 0); + + let ndim_u32 = ndim as u32; + let n_u32 = n as u32; + + let mut builder = stream.launch_builder(&func); + builder.arg(&mask_ptr); + builder.arg(&count_ptr); + builder.arg(&mask_strides_ptr); + builder.arg(&out_shape_ptr); + builder.arg(&ndim_u32); + builder.arg(&n_u32); + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!( + "CUDA masked_count_broadcast kernel launch failed: {:?}", + e + )) + })?; + + Ok(()) + } +} + +/// Launch broadcast masked_prefix_sum kernel. +/// +/// # Safety +/// +/// - mask_ptr must be valid device memory +/// - prefix_sum_ptr must be valid device memory with n u32 elements +/// - mask_strides_ptr, out_shape_ptr must be valid device memory with ndim u32 elements +#[allow(clippy::too_many_arguments)] +pub unsafe fn launch_masked_prefix_sum_broadcast( + context: &Arc, + stream: &CudaStream, + device_index: usize, + mask_ptr: u64, + prefix_sum_ptr: u64, + mask_strides_ptr: u64, + out_shape_ptr: u64, + ndim: usize, + n: usize, +) -> Result<()> { + if n == 0 { + return Ok(()); + } + + unsafe { + let module = get_or_load_module(context, device_index, INDEX_MODULE)?; + let func = get_kernel_function(&module, "masked_prefix_sum_broadcast_kernel")?; + + let cfg = launch_config((1, 1, 1), (1, 1, 1), 0); + + let ndim_u32 = ndim as u32; + let n_u32 = n as u32; + + let mut builder = stream.launch_builder(&func); + builder.arg(&mask_ptr); + builder.arg(&prefix_sum_ptr); + builder.arg(&mask_strides_ptr); + builder.arg(&out_shape_ptr); + builder.arg(&ndim_u32); + builder.arg(&n_u32); + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!( + "CUDA masked_prefix_sum_broadcast kernel launch failed: {:?}", + e + )) + })?; + + Ok(()) + } +} + +/// Launch broadcast masked_select kernel. +/// +/// # Safety +/// +/// - All pointers must be valid device memory +/// - prefix_sum must be precomputed via launch_masked_prefix_sum_broadcast +#[allow(clippy::too_many_arguments)] +pub unsafe fn launch_masked_select_broadcast( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + input_ptr: u64, + mask_ptr: u64, + output_ptr: u64, + prefix_sum_ptr: u64, + mask_strides_ptr: u64, + out_shape_ptr: u64, + ndim: usize, + n: usize, +) -> Result<()> { + if n == 0 { + return Ok(()); + } + + unsafe { + let module = get_or_load_module(context, device_index, INDEX_MODULE)?; + let func_name = format!("masked_select_broadcast_{}", dtype_suffix(dtype)?); + let func = get_kernel_function(&module, &func_name)?; + + let grid = elementwise_launch_config(n); + let block = (BLOCK_SIZE, 1, 1); + let cfg = launch_config(grid, block, 0); + + let ndim_u32 = ndim as u32; + let n_u32 = n as u32; + + let mut builder = stream.launch_builder(&func); + builder.arg(&input_ptr); + builder.arg(&mask_ptr); + builder.arg(&output_ptr); + builder.arg(&prefix_sum_ptr); + builder.arg(&mask_strides_ptr); + builder.arg(&out_shape_ptr); + builder.arg(&ndim_u32); + builder.arg(&n_u32); + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!( + "CUDA masked_select_broadcast kernel launch failed: {:?}", + e + )) + })?; + + Ok(()) + } +} + +/// Launch broadcast masked_fill kernel. +/// +/// # Safety +/// +/// - All pointers must be valid device memory +/// - mask_strides_ptr, out_shape_ptr must be valid device memory with ndim u32 elements +#[allow(clippy::too_many_arguments)] +pub unsafe fn launch_masked_fill_broadcast( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + input_ptr: u64, + mask_ptr: u64, + output_ptr: u64, + fill_value: f64, + mask_strides_ptr: u64, + out_shape_ptr: u64, + ndim: usize, + n: usize, +) -> Result<()> { + if n == 0 { + return Ok(()); + } + + let kernel_name = match dtype { + DType::F32 => "masked_fill_broadcast_f32", + DType::F64 => "masked_fill_broadcast_f64", + DType::I32 => "masked_fill_broadcast_i32", + DType::I64 => "masked_fill_broadcast_i64", + #[cfg(feature = "f16")] + DType::F16 => "masked_fill_broadcast_f16", + #[cfg(feature = "f16")] + DType::BF16 => "masked_fill_broadcast_bf16", + #[cfg(feature = "fp8")] + DType::FP8E4M3 => "masked_fill_broadcast_fp8_e4m3", + #[cfg(feature = "fp8")] + DType::FP8E5M2 => "masked_fill_broadcast_fp8_e5m2", + _ => { + return Err(Error::UnsupportedDType { + dtype, + op: "masked_fill_broadcast", + }); + } + }; + + unsafe { + let module = get_or_load_module(context, device_index, INDEX_MODULE)?; + let func = get_kernel_function(&module, kernel_name)?; + + let grid = elementwise_launch_config(n); + let block = (BLOCK_SIZE, 1, 1); + let cfg = launch_config(grid, block, 0); + + let ndim_u32 = ndim as u32; + let n_u32 = n as u32; + + let mut builder = stream.launch_builder(&func); + builder.arg(&input_ptr); + builder.arg(&mask_ptr); + builder.arg(&output_ptr); + + let fill_f32 = fill_value as f32; + let fill_f64 = fill_value; + let fill_i32 = fill_value as i32; + let fill_i64 = fill_value as i64; + #[cfg(feature = "f16")] + let fill_f16 = half::f16::from_f64(fill_value).to_bits(); + #[cfg(feature = "f16")] + let fill_bf16 = half::bf16::from_f64(fill_value).to_bits(); + #[cfg(feature = "fp8")] + let fill_fp8_e4m3 = crate::dtype::fp8::FP8E4M3::from_f64(fill_value).to_bits(); + #[cfg(feature = "fp8")] + let fill_fp8_e5m2 = crate::dtype::fp8::FP8E5M2::from_f64(fill_value).to_bits(); + + match dtype { + DType::F32 => builder.arg(&fill_f32), + DType::F64 => builder.arg(&fill_f64), + DType::I32 => builder.arg(&fill_i32), + DType::I64 => builder.arg(&fill_i64), + #[cfg(feature = "f16")] + DType::F16 => builder.arg(&fill_f16), + #[cfg(feature = "f16")] + DType::BF16 => builder.arg(&fill_bf16), + #[cfg(feature = "fp8")] + DType::FP8E4M3 => builder.arg(&fill_fp8_e4m3), + #[cfg(feature = "fp8")] + DType::FP8E5M2 => builder.arg(&fill_fp8_e5m2), + _ => unreachable!(), + }; + + builder.arg(&mask_strides_ptr); + builder.arg(&out_shape_ptr); + builder.arg(&ndim_u32); + builder.arg(&n_u32); + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!( + "CUDA masked_fill_broadcast kernel launch failed: {:?}", + e + )) + })?; + + Ok(()) + } +} + +/// Helper to get dtype suffix for kernel name +fn dtype_suffix(dtype: DType) -> Result<&'static str> { + match dtype { + DType::F32 => Ok("f32"), + DType::F64 => Ok("f64"), + DType::I32 => Ok("i32"), + DType::I64 => Ok("i64"), + #[cfg(feature = "f16")] + DType::F16 => Ok("f16"), + #[cfg(feature = "f16")] + DType::BF16 => Ok("bf16"), + #[cfg(feature = "fp8")] + DType::FP8E4M3 => Ok("fp8_e4m3"), + #[cfg(feature = "fp8")] + DType::FP8E5M2 => Ok("fp8_e5m2"), + _ => Err(Error::UnsupportedDType { + dtype, + op: "masked_select_broadcast", + }), + } +} diff --git a/src/runtime/cuda/kernels/index/mod.rs b/src/runtime/cuda/kernels/index/mod.rs new file mode 100644 index 00000000..4848f693 --- /dev/null +++ b/src/runtime/cuda/kernels/index/mod.rs @@ -0,0 +1,18 @@ +//! Indexing CUDA kernel launchers +//! +//! Provides launchers for indexing operations: gather, scatter, index_select, +//! masked_select, masked_fill, embedding, and slice_assign. + +mod embedding; +mod gather; +mod index_select; +mod masked; +mod scatter; +mod slice_assign; + +pub use embedding::*; +pub use gather::*; +pub use index_select::*; +pub use masked::*; +pub use scatter::*; +pub use slice_assign::*; diff --git a/src/runtime/cuda/kernels/index/scatter.rs b/src/runtime/cuda/kernels/index/scatter.rs new file mode 100644 index 00000000..3bf7992d --- /dev/null +++ b/src/runtime/cuda/kernels/index/scatter.rs @@ -0,0 +1,352 @@ +//! Scatter kernel launchers (scatter, copy, scatter_reduce) + +use cudarc::driver::PushKernelArg; +use cudarc::driver::safe::{CudaContext, CudaStream}; +use std::sync::Arc; + +use super::super::loader::{ + BLOCK_SIZE, elementwise_launch_config, get_kernel_function, get_or_load_module, kernel_name, + launch_config, +}; +use super::gather::INDEX_MODULE; +use crate::dtype::DType; +use crate::error::{Error, Result}; + +/// Launch scatter kernel. +/// +/// Scatters values from src to output at positions specified by indices. +/// `output[i][indices[i][j][k]][k] = src[i][j][k]` (when dim=1) +/// +/// # Safety +/// +/// - All pointers must be valid device memory +/// - Output must be pre-initialized (typically a copy of input) +#[allow(clippy::too_many_arguments)] +pub unsafe fn launch_scatter( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + input_ptr: u64, + indices_ptr: u64, + src_ptr: u64, + output_ptr: u64, + ndim: usize, + dim: usize, + output_shape_ptr: u64, + output_strides_ptr: u64, + src_shape_ptr: u64, + src_strides_ptr: u64, + src_total: usize, +) -> Result<()> { + if src_total == 0 { + return Ok(()); + } + + unsafe { + let module = get_or_load_module(context, device_index, INDEX_MODULE)?; + let func_name = kernel_name("scatter", dtype); + let func = get_kernel_function(&module, &func_name)?; + + let grid = elementwise_launch_config(src_total); + let block = (BLOCK_SIZE, 1, 1); + let cfg = launch_config(grid, block, 0); + + let ndim_u32 = ndim as u32; + let dim_u32 = dim as u32; + let src_total_u32 = src_total as u32; + + let mut builder = stream.launch_builder(&func); + builder.arg(&input_ptr); + builder.arg(&indices_ptr); + builder.arg(&src_ptr); + builder.arg(&output_ptr); + builder.arg(&ndim_u32); + builder.arg(&dim_u32); + builder.arg(&output_shape_ptr); + builder.arg(&output_strides_ptr); + builder.arg(&src_shape_ptr); + builder.arg(&src_strides_ptr); + builder.arg(&src_total_u32); + + builder + .launch(cfg) + .map_err(|e| Error::Internal(format!("CUDA scatter kernel launch failed: {:?}", e)))?; + + Ok(()) + } +} + +/// Launch copy kernel for scatter initialization. +/// +/// # Safety +/// +/// - All pointers must be valid device memory +/// - dst must have space for n elements +pub unsafe fn launch_copy( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + src_ptr: u64, + dst_ptr: u64, + n: usize, +) -> Result<()> { + if n == 0 { + return Ok(()); + } + + unsafe { + let module = get_or_load_module(context, device_index, INDEX_MODULE)?; + let func_name = kernel_name("copy", dtype); + let func = get_kernel_function(&module, &func_name)?; + + let grid = elementwise_launch_config(n); + let block = (BLOCK_SIZE, 1, 1); + let cfg = launch_config(grid, block, 0); + + let n_u32 = n as u32; + + let mut builder = stream.launch_builder(&func); + builder.arg(&src_ptr); + builder.arg(&dst_ptr); + builder.arg(&n_u32); + + builder + .launch(cfg) + .map_err(|e| Error::Internal(format!("CUDA copy kernel launch failed: {:?}", e)))?; + + Ok(()) + } +} + +/// Scatter reduce operation type. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ScatterReduceOpCuda { + /// Sum reduction: accumulate values by addition. + Sum, + /// Max reduction: keep the maximum value. + Max, + /// Min reduction: keep the minimum value. + Min, + /// Product reduction: accumulate values by multiplication. + Prod, +} + +/// Launch scatter_reduce kernel. +/// +/// Scatters values from src to dst at positions specified by indices with a +/// reduction operation. +/// +/// # Safety +/// +/// All pointers must be valid device memory. +#[allow(clippy::too_many_arguments)] +pub unsafe fn launch_scatter_reduce( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + src_ptr: u64, + indices_ptr: u64, + dst_ptr: u64, + dim: usize, + outer_size: usize, + dim_size: usize, + inner_size: usize, + src_dim_size: usize, + op: ScatterReduceOpCuda, +) -> Result<()> { + let total = outer_size * src_dim_size * inner_size; + if total == 0 { + return Ok(()); + } + + unsafe { + let module = get_or_load_module(context, device_index, INDEX_MODULE)?; + + let func_name = match (dtype, op) { + (DType::F32, ScatterReduceOpCuda::Sum) => "scatter_reduce_sum_f32", + (DType::F32, ScatterReduceOpCuda::Max) => "scatter_reduce_max_f32", + (DType::F32, ScatterReduceOpCuda::Min) => "scatter_reduce_min_f32", + (DType::F32, ScatterReduceOpCuda::Prod) => "scatter_reduce_prod_f32", + (DType::F64, ScatterReduceOpCuda::Sum) => "scatter_reduce_sum_f64", + (DType::F64, ScatterReduceOpCuda::Max) => "scatter_reduce_max_f64", + (DType::F64, ScatterReduceOpCuda::Min) => "scatter_reduce_min_f64", + (DType::F64, ScatterReduceOpCuda::Prod) => "scatter_reduce_prod_f64", + (DType::I32, ScatterReduceOpCuda::Sum) => "scatter_reduce_sum_i32", + (DType::I32, ScatterReduceOpCuda::Max) => "scatter_reduce_max_i32", + (DType::I32, ScatterReduceOpCuda::Min) => "scatter_reduce_min_i32", + (DType::I32, ScatterReduceOpCuda::Prod) => "scatter_reduce_prod_i32", + _ => { + return Err(Error::UnsupportedDType { + dtype, + op: "scatter_reduce", + }); + } + }; + + let func = get_kernel_function(&module, func_name)?; + + let grid = elementwise_launch_config(total); + let block = (BLOCK_SIZE, 1, 1); + let cfg = launch_config(grid, block, 0); + + let dim_u32 = dim as u32; + let outer_size_u32 = outer_size as u32; + let dim_size_u32 = dim_size as u32; + let inner_size_u32 = inner_size as u32; + let src_dim_size_u32 = src_dim_size as u32; + + let mut builder = stream.launch_builder(&func); + builder.arg(&src_ptr); + builder.arg(&indices_ptr); + builder.arg(&dst_ptr); + builder.arg(&dim_u32); + builder.arg(&outer_size_u32); + builder.arg(&dim_size_u32); + builder.arg(&inner_size_u32); + builder.arg(&src_dim_size_u32); + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!("CUDA scatter_reduce kernel launch failed: {:?}", e)) + })?; + + Ok(()) + } +} + +/// Launch scatter_reduce_count kernel. +/// +/// Atomically increments count buffer at scattered positions. +/// Used as part of scatter_reduce mean: sum / count. +/// +/// # Safety +/// +/// All pointers must be valid device memory. +#[allow(clippy::too_many_arguments)] +pub unsafe fn launch_scatter_reduce_count( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + indices_ptr: u64, + count_ptr: u64, + dim: usize, + outer_size: usize, + dim_size: usize, + inner_size: usize, + src_dim_size: usize, +) -> Result<()> { + let total = outer_size * src_dim_size * inner_size; + if total == 0 { + return Ok(()); + } + + unsafe { + let module = get_or_load_module(context, device_index, INDEX_MODULE)?; + + let func_name = match dtype { + DType::F32 => "scatter_reduce_count_f32", + DType::F64 => "scatter_reduce_count_f64", + _ => { + return Err(Error::UnsupportedDType { + dtype, + op: "scatter_reduce_count", + }); + } + }; + + let func = get_kernel_function(&module, func_name)?; + + let grid = elementwise_launch_config(total); + let block = (BLOCK_SIZE, 1, 1); + let cfg = launch_config(grid, block, 0); + + let dim_u32 = dim as u32; + let outer_size_u32 = outer_size as u32; + let dim_size_u32 = dim_size as u32; + let inner_size_u32 = inner_size as u32; + let src_dim_size_u32 = src_dim_size as u32; + + let mut builder = stream.launch_builder(&func); + builder.arg(&indices_ptr); + builder.arg(&count_ptr); + builder.arg(&dim_u32); + builder.arg(&outer_size_u32); + builder.arg(&dim_size_u32); + builder.arg(&inner_size_u32); + builder.arg(&src_dim_size_u32); + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!( + "CUDA scatter_reduce_count kernel launch failed: {:?}", + e + )) + })?; + + Ok(()) + } +} + +/// Launch scatter_reduce_mean_div kernel. +/// +/// Element-wise: output[i] = sum[i] / count[i]. +/// If count[i] == 0, output[i] = 0. +/// +/// # Safety +/// +/// All pointers must be valid device memory. +#[allow(clippy::too_many_arguments)] +pub unsafe fn launch_scatter_reduce_mean_div( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + sum_ptr: u64, + count_ptr: u64, + output_ptr: u64, + n: usize, +) -> Result<()> { + if n == 0 { + return Ok(()); + } + + unsafe { + let module = get_or_load_module(context, device_index, INDEX_MODULE)?; + + let func_name = match dtype { + DType::F32 => "scatter_reduce_mean_div_f32", + DType::F64 => "scatter_reduce_mean_div_f64", + _ => { + return Err(Error::UnsupportedDType { + dtype, + op: "scatter_reduce_mean_div", + }); + } + }; + + let func = get_kernel_function(&module, func_name)?; + + let grid = elementwise_launch_config(n); + let block = (BLOCK_SIZE, 1, 1); + let cfg = launch_config(grid, block, 0); + + let n_u32 = n as u32; + + let mut builder = stream.launch_builder(&func); + builder.arg(&sum_ptr); + builder.arg(&count_ptr); + builder.arg(&output_ptr); + builder.arg(&n_u32); + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!( + "CUDA scatter_reduce_mean_div kernel launch failed: {:?}", + e + )) + })?; + + Ok(()) + } +} diff --git a/src/runtime/cuda/kernels/index/slice_assign.rs b/src/runtime/cuda/kernels/index/slice_assign.rs new file mode 100644 index 00000000..a0c5b8d7 --- /dev/null +++ b/src/runtime/cuda/kernels/index/slice_assign.rs @@ -0,0 +1,72 @@ +//! Slice assign kernel launcher + +use cudarc::driver::PushKernelArg; +use cudarc::driver::safe::{CudaContext, CudaStream}; +use std::sync::Arc; + +use super::super::loader::{ + BLOCK_SIZE, elementwise_launch_config, get_kernel_function, get_or_load_module, kernel_name, + launch_config, +}; +use super::gather::INDEX_MODULE; +use crate::dtype::DType; +use crate::error::{Error, Result}; + +/// Launch slice_assign kernel: copies src into a region of output (pre-copied from dst). +/// +/// Output must already contain a copy of dst. This kernel overwrites the slice region +/// [start..start+src_dim_size] along the specified dimension with src data. +/// +/// # Safety +/// +/// - src_ptr: valid device memory with outer_size * src_dim_size * inner_size elements +/// - output_ptr: valid device memory with outer_size * dst_dim_size * inner_size elements +pub unsafe fn launch_slice_assign( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + src_ptr: u64, + output_ptr: u64, + outer_size: usize, + dst_dim_size: usize, + src_dim_size: usize, + inner_size: usize, + start: usize, +) -> Result<()> { + let total = outer_size * src_dim_size * inner_size; + if total == 0 { + return Ok(()); + } + + unsafe { + let module = get_or_load_module(context, device_index, INDEX_MODULE)?; + let func_name = kernel_name("slice_assign", dtype); + let func = get_kernel_function(&module, &func_name)?; + + let grid = elementwise_launch_config(total); + let block = (BLOCK_SIZE, 1, 1); + let cfg = launch_config(grid, block, 0); + + let outer_u32 = outer_size as u32; + let dst_dim_u32 = dst_dim_size as u32; + let src_dim_u32 = src_dim_size as u32; + let inner_u32 = inner_size as u32; + let start_u32 = start as u32; + + let mut builder = stream.launch_builder(&func); + builder.arg(&src_ptr); + builder.arg(&output_ptr); + builder.arg(&outer_u32); + builder.arg(&dst_dim_u32); + builder.arg(&src_dim_u32); + builder.arg(&inner_u32); + builder.arg(&start_u32); + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!("CUDA slice_assign kernel launch failed: {:?}", e)) + })?; + + Ok(()) + } +} diff --git a/src/runtime/cuda/kernels/loader.rs b/src/runtime/cuda/kernels/loader.rs index e5554f2c..6a4c826a 100644 --- a/src/runtime/cuda/kernels/loader.rs +++ b/src/runtime/cuda/kernels/loader.rs @@ -45,11 +45,6 @@ fn load_ptx(name: &str) -> Ptx { static MODULE_CACHE: OnceLock>>> = OnceLock::new(); -/// Get a reference to the module cache (for cache invalidation during recovery). -pub fn module_cache() -> Option<&'static Mutex>>> { - MODULE_CACHE.get() -} - /// Get or load a CUDA module from PTX. /// /// Modules are cached per-device to avoid repeated loading. This is thread-safe @@ -94,6 +89,21 @@ pub fn get_or_load_module( Ok(module) } +/// Pre-load a list of CUDA modules to avoid JIT compilation latency on first use. +/// +/// This is useful for inference warmup: call this once with all module names +/// that will be used during inference to front-load all PTX→SASS compilation. +pub fn preload_modules( + context: &Arc, + device_index: usize, + module_names: &[&'static str], +) -> Result<()> { + for name in module_names { + get_or_load_module(context, device_index, name)?; + } + Ok(()) +} + /// Get a kernel function from a loaded module. /// /// # Arguments @@ -160,7 +170,9 @@ pub fn reduce_dim_launch_config(outer: usize, inner: usize) -> ((u32, u32, u32), #[inline] pub fn softmax_launch_config(outer: usize, dim_size: usize) -> (u32, u32, u32) { // One block per row, threads handle the dimension - let block_size = BLOCK_SIZE.min(dim_size as u32); + // Block size must be a power of 2 for the shared-memory tree reduction to work correctly + let block_size = BLOCK_SIZE.min(dim_size as u32).next_power_of_two(); + let block_size = block_size.min(BLOCK_SIZE); let grid_size = outer as u32; // Shared memory: 2 arrays of block_size floats (for max and sum reduction) let shared_mem = 2 * block_size * 4; // f32 @@ -213,10 +225,14 @@ pub mod kernel_names { pub const REDUCE_MODULE: &str = "reduce"; /// Comparison operations (eq, ne, lt, le, gt, ge) pub const COMPARE_MODULE: &str = "compare"; - /// Activation functions (relu, sigmoid, softmax, silu, gelu) + /// Element-wise activation functions (relu, sigmoid, silu, gelu, leaky_relu, elu) pub const ACTIVATION_MODULE: &str = "activation"; + /// Softmax forward + backward kernels + pub const SOFTMAX_MODULE: &str = "softmax"; /// Normalization operations (rms_norm, layer_norm) pub const NORM_MODULE: &str = "norm"; + /// Fused add + normalization operations + pub const FUSED_ADD_NORM_MODULE: &str = "fused_add_norm"; /// Type casting operations (cast between dtypes) pub const CAST_MODULE: &str = "cast"; /// Utility operations (fill) @@ -265,6 +281,8 @@ pub mod kernel_names { pub const LINALG_MATRIX_FUNCS_MODULE: &str = "linalg_matrix_funcs"; /// Matrix multiplication operations (native tiled GEMM) pub const MATMUL_MODULE: &str = "matmul"; + /// GEMV operations (matrix-vector multiply for small M) + pub const GEMV_MODULE: &str = "gemv"; /// Cumulative operations (cumsum, cumprod, logsumexp) pub const CUMULATIVE_MODULE: &str = "cumulative"; /// Distribution sampling operations (bernoulli, beta, gamma, etc.) @@ -544,6 +562,27 @@ pub unsafe fn launch_matmul_kernel( n: usize, k: usize, ) -> Result<()> { + // Use GEMV kernel for small M (single-token decode in LLM inference) + // The tiled GEMM wastes 99%+ compute when M < block_m (typically 128) + if m <= 16 { + unsafe { + return launch_gemv_kernel( + context, + stream, + device_index, + dtype, + a_ptr, + b_ptr, + c_ptr, + 1, + m, + n, + k, + 1, + 1, + ); + } + } unsafe { launch_matmul_kernel_with_config( context, @@ -561,6 +600,198 @@ pub unsafe fn launch_matmul_kernel( } } +/// Launch GEMV kernel: C[batch,M,N] = A[batch,M,K] @ B[batch,K,N] for small M +/// +/// B is [K,N] row-major (non-transposed). One thread per output column, iterates K. +/// +/// # Safety +/// +/// All pointers must be valid device memory with correct sizes. +pub unsafe fn launch_gemv_kernel( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + a_ptr: u64, + b_ptr: u64, + c_ptr: u64, + batch: usize, + m: usize, + n: usize, + k: usize, + a_batch: usize, + b_batch: usize, +) -> Result<()> { + let module = get_or_load_module(context, device_index, kernel_names::GEMV_MODULE)?; + let func_name = kernel_name("gemv", dtype); + let func = get_kernel_function(&module, &func_name)?; + + // grid: (ceil(N/256), M, batch), block: (256, 1, 1) + // One thread per output column, each thread iterates over K. + let block_size: u32 = 256; + let grid_x = ((n as u32) + block_size - 1) / block_size; + let grid_y = m as u32; + let grid_z = batch as u32; + let cfg = LaunchConfig { + grid_dim: (grid_x, grid_y, grid_z), + block_dim: (block_size, 1, 1), + shared_mem_bytes: 0, + }; + + let m_u32 = m as u32; + let n_u32 = n as u32; + let k_u32 = k as u32; + let a_batch_u32 = a_batch as u32; + let b_batch_u32 = b_batch as u32; + + unsafe { + let mut builder = stream.launch_builder(&func); + builder.arg(&a_ptr); + builder.arg(&b_ptr); + builder.arg(&c_ptr); + builder.arg(&m_u32); + builder.arg(&n_u32); + builder.arg(&k_u32); + builder.arg(&a_batch_u32); + builder.arg(&b_batch_u32); + builder + .launch(cfg) + .map_err(|e| Error::Internal(format!("CUDA GEMV kernel launch failed: {:?}", e)))?; + } + + Ok(()) +} + +/// Launch GEMV kernel with transposed B: C[batch,M,N] = A[batch,M,K] @ B^T +/// +/// B is stored [N,K] row-major (transposed weight matrix, common for nn.Linear). +/// Warp-cooperative: each warp reduces one output column along K using shuffle. +/// +/// # Safety +/// +/// All pointers must be valid device memory with correct sizes. +/// `b_ptr` points to the raw [N,K] data (NOT the transposed [K,N] view). +pub unsafe fn launch_gemv_kernel_bt( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + a_ptr: u64, + b_ptr: u64, + c_ptr: u64, + batch: usize, + m: usize, + n: usize, + k: usize, + a_batch: usize, + b_batch: usize, +) -> Result<()> { + let module = get_or_load_module(context, device_index, kernel_names::GEMV_MODULE)?; + let func_name = kernel_name("gemv_bt", dtype); + let func = get_kernel_function(&module, &func_name)?; + + // grid: (ceil(N/WARPS_PER_BLOCK), M, batch), block: (256, 1, 1) + // 8 warps per block, each warp handles one output column. + let warps_per_block: u32 = 8; + let grid_x = ((n as u32) + warps_per_block - 1) / warps_per_block; + let grid_y = m as u32; + let grid_z = batch as u32; + let cfg = LaunchConfig { + grid_dim: (grid_x, grid_y, grid_z), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + + let m_u32 = m as u32; + let n_u32 = n as u32; + let k_u32 = k as u32; + let a_batch_u32 = a_batch as u32; + let b_batch_u32 = b_batch as u32; + + unsafe { + let mut builder = stream.launch_builder(&func); + builder.arg(&a_ptr); + builder.arg(&b_ptr); + builder.arg(&c_ptr); + builder.arg(&m_u32); + builder.arg(&n_u32); + builder.arg(&k_u32); + builder.arg(&a_batch_u32); + builder.arg(&b_batch_u32); + builder + .launch(cfg) + .map_err(|e| Error::Internal(format!("CUDA GEMV-BT kernel launch failed: {:?}", e)))?; + } + + Ok(()) +} + +/// Launch multi-row GEMV kernel with transposed B: C[batch,M,N] = A[batch,M,K] @ B^T +/// +/// Each warp computes 2 output columns, sharing the activation vector load across rows. +/// This halves activation memory bandwidth compared to `launch_gemv_kernel_bt`. +/// +/// # Safety +/// +/// All pointers must be valid device memory with correct sizes. +/// `b_ptr` points to the raw [N,K] data (NOT the transposed [K,N] view). +pub unsafe fn launch_gemv_kernel_bt_mr( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + a_ptr: u64, + b_ptr: u64, + c_ptr: u64, + batch: usize, + m: usize, + n: usize, + k: usize, + a_batch: usize, + b_batch: usize, +) -> Result<()> { + let module = get_or_load_module(context, device_index, kernel_names::GEMV_MODULE)?; + let func_name = kernel_name("gemv_bt_mr", dtype); + let func = get_kernel_function(&module, &func_name)?; + + // grid: (ceil(N / (WARPS_PER_BLOCK * ROWS_PER_WARP)), M, batch), block: (256, 1, 1) + // 8 warps per block, each warp handles 2 output columns. + let warps_per_block: u32 = 8; + let rows_per_warp: u32 = 2; + let cols_per_block = warps_per_block * rows_per_warp; // 16 + let grid_x = ((n as u32) + cols_per_block - 1) / cols_per_block; + let grid_y = m as u32; + let grid_z = batch as u32; + let cfg = LaunchConfig { + grid_dim: (grid_x, grid_y, grid_z), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + + let m_u32 = m as u32; + let n_u32 = n as u32; + let k_u32 = k as u32; + let a_batch_u32 = a_batch as u32; + let b_batch_u32 = b_batch as u32; + + unsafe { + let mut builder = stream.launch_builder(&func); + builder.arg(&a_ptr); + builder.arg(&b_ptr); + builder.arg(&c_ptr); + builder.arg(&m_u32); + builder.arg(&n_u32); + builder.arg(&k_u32); + builder.arg(&a_batch_u32); + builder.arg(&b_batch_u32); + builder.launch(cfg).map_err(|e| { + Error::Internal(format!("CUDA GEMV-BT-MR kernel launch failed: {:?}", e)) + })?; + } + + Ok(()) +} + /// Launch native tiled matmul kernel with custom tile configuration. /// /// # Safety @@ -642,7 +873,29 @@ pub unsafe fn launch_matmul_batched_kernel( m: usize, n: usize, k: usize, + a_batch: usize, + b_batch: usize, ) -> Result<()> { + // Use GEMV kernel for small M (batched case) + if m <= 16 { + unsafe { + return launch_gemv_kernel( + context, + stream, + device_index, + dtype, + a_ptr, + b_ptr, + c_ptr, + batch, + m, + n, + k, + a_batch, + b_batch, + ); + } + } unsafe { launch_matmul_batched_kernel_with_config( context, @@ -657,6 +910,8 @@ pub unsafe fn launch_matmul_batched_kernel( n, k, &default_tile_config(dtype), + a_batch, + b_batch, ) } } @@ -679,6 +934,8 @@ pub unsafe fn launch_matmul_batched_kernel_with_config( n: usize, k: usize, tile_cfg: &TileConfig, + a_batch: usize, + b_batch: usize, ) -> Result<()> { let module = get_or_load_module(context, device_index, kernel_names::MATMUL_MODULE)?; let func_name = kernel_name("matmul_batched", dtype); @@ -700,6 +957,8 @@ pub unsafe fn launch_matmul_batched_kernel_with_config( let block_k = tile_cfg.block_k as u32; let thread_m = tile_cfg.thread_m as u32; let thread_n = tile_cfg.thread_n as u32; + let a_batch_u32 = a_batch as u32; + let b_batch_u32 = b_batch as u32; unsafe { let mut builder = stream.launch_builder(&func); @@ -715,6 +974,8 @@ pub unsafe fn launch_matmul_batched_kernel_with_config( builder.arg(&block_k); builder.arg(&thread_m); builder.arg(&thread_n); + builder.arg(&a_batch_u32); + builder.arg(&b_batch_u32); builder.launch(cfg).map_err(|e| { Error::Internal(format!("CUDA batched matmul kernel launch failed: {:?}", e)) @@ -857,6 +1118,8 @@ pub unsafe fn launch_matmul_bias_batched_kernel( m: usize, n: usize, k: usize, + a_batch: usize, + b_batch: usize, ) -> Result<()> { unsafe { launch_matmul_bias_batched_kernel_with_config( @@ -873,6 +1136,8 @@ pub unsafe fn launch_matmul_bias_batched_kernel( n, k, &default_tile_config(dtype), + a_batch, + b_batch, ) } } @@ -896,6 +1161,8 @@ pub unsafe fn launch_matmul_bias_batched_kernel_with_config( n: usize, k: usize, tile_cfg: &TileConfig, + a_batch: usize, + b_batch: usize, ) -> Result<()> { let module = get_or_load_module(context, device_index, kernel_names::MATMUL_MODULE)?; let func_name = kernel_name("matmul_bias_batched", dtype); @@ -917,6 +1184,8 @@ pub unsafe fn launch_matmul_bias_batched_kernel_with_config( let block_k = tile_cfg.block_k as u32; let thread_m = tile_cfg.thread_m as u32; let thread_n = tile_cfg.thread_n as u32; + let a_batch_u32 = a_batch as u32; + let b_batch_u32 = b_batch as u32; unsafe { let mut builder = stream.launch_builder(&func); @@ -933,6 +1202,8 @@ pub unsafe fn launch_matmul_bias_batched_kernel_with_config( builder.arg(&block_k); builder.arg(&thread_m); builder.arg(&thread_n); + builder.arg(&a_batch_u32); + builder.arg(&b_batch_u32); builder.launch(cfg).map_err(|e| { Error::Internal(format!( @@ -1028,6 +1299,8 @@ pub unsafe fn launch_semiring_matmul_batched_kernel( n: usize, k: usize, semiring_op: u32, + a_batch: usize, + b_batch: usize, ) -> Result<()> { let module = get_or_load_module(context, device_index, kernel_names::SEMIRING_MATMUL_MODULE)?; let func_name = kernel_name("semiring_matmul_batched", dtype); @@ -1049,6 +1322,8 @@ pub unsafe fn launch_semiring_matmul_batched_kernel( let n_u32 = n as u32; let k_u32 = k as u32; let batch_u32 = batch as u32; + let a_batch_u32 = a_batch as u32; + let b_batch_u32 = b_batch as u32; unsafe { let mut builder = stream.launch_builder(&func); @@ -1060,6 +1335,8 @@ pub unsafe fn launch_semiring_matmul_batched_kernel( builder.arg(&k_u32); builder.arg(&semiring_op); builder.arg(&batch_u32); + builder.arg(&a_batch_u32); + builder.arg(&b_batch_u32); builder.launch(cfg).map_err(|e| { Error::Internal(format!( diff --git a/src/runtime/cuda/kernels/matmul.cu b/src/runtime/cuda/kernels/matmul.cu index e54c9afc..ea636d95 100644 --- a/src/runtime/cuda/kernels/matmul.cu +++ b/src/runtime/cuda/kernels/matmul.cu @@ -160,7 +160,9 @@ extern "C" __global__ void matmul_batched_f32( unsigned int block_n, unsigned int block_k, unsigned int thread_m, - unsigned int thread_n + unsigned int thread_n, + unsigned int a_batch_count, + unsigned int b_batch_count ) { extern __shared__ float shared_mem[]; float* As = shared_mem; @@ -173,8 +175,8 @@ extern "C" __global__ void matmul_batched_f32( const unsigned int stride_b = K * N; const unsigned int stride_c = M * N; - const float* A_batch = A + b * stride_a; - const float* B_batch = B + b * stride_b; + const float* A_batch = A + (b % a_batch_count) * stride_a; + const float* B_batch = B + (b % b_batch_count) * stride_b; float* C_batch = C + b * stride_c; const unsigned int tx = threadIdx.x; @@ -378,7 +380,9 @@ extern "C" __global__ void matmul_batched_f64( unsigned int block_n, unsigned int block_k, unsigned int thread_m, - unsigned int thread_n + unsigned int thread_n, + unsigned int a_batch_count, + unsigned int b_batch_count ) { extern __shared__ double shared_mem_f64[]; double* As = shared_mem_f64; @@ -391,8 +395,8 @@ extern "C" __global__ void matmul_batched_f64( const unsigned int stride_b = K * N; const unsigned int stride_c = M * N; - const double* A_batch = A + b * stride_a; - const double* B_batch = B + b * stride_b; + const double* A_batch = A + (b % a_batch_count) * stride_a; + const double* B_batch = B + (b % b_batch_count) * stride_b; double* C_batch = C + b * stride_c; const unsigned int tx = threadIdx.x; @@ -597,7 +601,9 @@ extern "C" __global__ void matmul_batched_f16( unsigned int block_n, unsigned int block_k, unsigned int thread_m, - unsigned int thread_n + unsigned int thread_n, + unsigned int a_batch_count, + unsigned int b_batch_count ) { extern __shared__ float shared_mem[]; float* As = shared_mem; @@ -610,8 +616,8 @@ extern "C" __global__ void matmul_batched_f16( const unsigned int stride_b = K * N; const unsigned int stride_c = M * N; - const __half* A_batch = A + b * stride_a; - const __half* B_batch = B + b * stride_b; + const __half* A_batch = A + (b % a_batch_count) * stride_a; + const __half* B_batch = B + (b % b_batch_count) * stride_b; __half* C_batch = C + b * stride_c; const unsigned int tx = threadIdx.x; @@ -815,7 +821,9 @@ extern "C" __global__ void matmul_batched_bf16( unsigned int block_n, unsigned int block_k, unsigned int thread_m, - unsigned int thread_n + unsigned int thread_n, + unsigned int a_batch_count, + unsigned int b_batch_count ) { extern __shared__ float shared_mem[]; float* As = shared_mem; @@ -828,8 +836,8 @@ extern "C" __global__ void matmul_batched_bf16( const unsigned int stride_b = K * N; const unsigned int stride_c = M * N; - const __nv_bfloat16* A_batch = A + b * stride_a; - const __nv_bfloat16* B_batch = B + b * stride_b; + const __nv_bfloat16* A_batch = A + (b % a_batch_count) * stride_a; + const __nv_bfloat16* B_batch = B + (b % b_batch_count) * stride_b; __nv_bfloat16* C_batch = C + b * stride_c; const unsigned int tx = threadIdx.x; @@ -1042,7 +1050,9 @@ extern "C" __global__ void matmul_bias_batched_f32( unsigned int block_n, unsigned int block_k, unsigned int thread_m, - unsigned int thread_n + unsigned int thread_n, + unsigned int a_batch_count, + unsigned int b_batch_count ) { extern __shared__ float shared_mem[]; float* As = shared_mem; @@ -1055,8 +1065,8 @@ extern "C" __global__ void matmul_bias_batched_f32( const unsigned int stride_b = K * N; const unsigned int stride_c = M * N; - const float* A_batch = A + b * stride_a; - const float* B_batch = B + b * stride_b; + const float* A_batch = A + (b % a_batch_count) * stride_a; + const float* B_batch = B + (b % b_batch_count) * stride_b; float* C_batch = C + b * stride_c; const unsigned int tx = threadIdx.x; @@ -1264,7 +1274,9 @@ extern "C" __global__ void matmul_bias_batched_f64( unsigned int block_n, unsigned int block_k, unsigned int thread_m, - unsigned int thread_n + unsigned int thread_n, + unsigned int a_batch_count, + unsigned int b_batch_count ) { extern __shared__ double shared_mem_f64[]; double* As = shared_mem_f64; @@ -1277,8 +1289,8 @@ extern "C" __global__ void matmul_bias_batched_f64( const unsigned int stride_b = K * N; const unsigned int stride_c = M * N; - const double* A_batch = A + b * stride_a; - const double* B_batch = B + b * stride_b; + const double* A_batch = A + (b % a_batch_count) * stride_a; + const double* B_batch = B + (b % b_batch_count) * stride_b; double* C_batch = C + b * stride_c; const unsigned int tx = threadIdx.x; @@ -1487,7 +1499,9 @@ extern "C" __global__ void matmul_bias_batched_f16( unsigned int block_n, unsigned int block_k, unsigned int thread_m, - unsigned int thread_n + unsigned int thread_n, + unsigned int a_batch_count, + unsigned int b_batch_count ) { extern __shared__ float shared_mem[]; float* As = shared_mem; @@ -1500,8 +1514,8 @@ extern "C" __global__ void matmul_bias_batched_f16( const unsigned int stride_b = K * N; const unsigned int stride_c = M * N; - const __half* A_batch = A + b * stride_a; - const __half* B_batch = B + b * stride_b; + const __half* A_batch = A + (b % a_batch_count) * stride_a; + const __half* B_batch = B + (b % b_batch_count) * stride_b; __half* C_batch = C + b * stride_c; const unsigned int tx = threadIdx.x; @@ -1711,7 +1725,9 @@ extern "C" __global__ void matmul_bias_batched_bf16( unsigned int block_n, unsigned int block_k, unsigned int thread_m, - unsigned int thread_n + unsigned int thread_n, + unsigned int a_batch_count, + unsigned int b_batch_count ) { extern __shared__ float shared_mem[]; float* As = shared_mem; @@ -1724,8 +1740,8 @@ extern "C" __global__ void matmul_bias_batched_bf16( const unsigned int stride_b = K * N; const unsigned int stride_c = M * N; - const __nv_bfloat16* A_batch = A + b * stride_a; - const __nv_bfloat16* B_batch = B + b * stride_b; + const __nv_bfloat16* A_batch = A + (b % a_batch_count) * stride_a; + const __nv_bfloat16* B_batch = B + (b % b_batch_count) * stride_b; __nv_bfloat16* C_batch = C + b * stride_c; const unsigned int tx = threadIdx.x; diff --git a/src/runtime/cuda/kernels/mod.rs b/src/runtime/cuda/kernels/mod.rs index a922ad8f..c16a05a3 100644 --- a/src/runtime/cuda/kernels/mod.rs +++ b/src/runtime/cuda/kernels/mod.rs @@ -56,6 +56,12 @@ mod cumulative; mod distance; mod distributions; mod fft; +#[cfg(feature = "fp8")] +mod fp8_matmul; +mod fused_activation_mul; +mod fused_add_norm; +mod fused_elementwise; +mod gemm_epilogue; mod index; mod linalg; pub mod linalg_launchers; @@ -69,6 +75,8 @@ mod scan; mod shape; mod sort; #[cfg(feature = "sparse")] +mod sparse_24_launcher; +#[cfg(feature = "sparse")] mod sparse_convert; #[cfg(feature = "sparse")] mod sparse_coo; @@ -102,6 +110,12 @@ pub use cumulative::*; pub use distance::*; pub use distributions::*; pub use fft::*; +#[cfg(feature = "fp8")] +pub use fp8_matmul::*; +pub use fused_activation_mul::*; +pub use fused_add_norm::*; +pub use fused_elementwise::*; +pub use gemm_epilogue::*; pub use index::*; pub use linalg::*; pub use norm::*; @@ -114,6 +128,8 @@ pub use scan::*; pub use shape::*; pub use sort::*; #[cfg(feature = "sparse")] +pub use sparse_24_launcher::*; +#[cfg(feature = "sparse")] pub use sparse_convert::*; #[cfg(feature = "sparse")] pub use sparse_coo::*; @@ -142,7 +158,8 @@ pub use utility::*; // Re-export commonly used items from loader for advanced users #[allow(unused_imports)] pub use loader::{ - BLOCK_SIZE, LaunchConfig, kernel_names, launch_matmul_batched_kernel, - launch_matmul_bias_batched_kernel, launch_matmul_bias_kernel, launch_matmul_kernel, - launch_semiring_matmul_batched_kernel, launch_semiring_matmul_kernel, + BLOCK_SIZE, LaunchConfig, kernel_names, launch_gemv_kernel_bt, launch_gemv_kernel_bt_mr, + launch_matmul_batched_kernel, launch_matmul_bias_batched_kernel, launch_matmul_bias_kernel, + launch_matmul_kernel, launch_semiring_matmul_batched_kernel, launch_semiring_matmul_kernel, + preload_modules, }; diff --git a/src/runtime/cuda/kernels/norm.cu b/src/runtime/cuda/kernels/norm.cu index b483328a..def90cf4 100644 --- a/src/runtime/cuda/kernels/norm.cu +++ b/src/runtime/cuda/kernels/norm.cu @@ -1,11 +1,97 @@ // Normalization CUDA kernels -// Supports: rms_norm, layer_norm +// Supports: rms_norm, layer_norm, group_norm // Types: f32, f64, f16, bf16 // Note: All half-precision variants use FP32 accumulation for numerical stability +// +// LayerNorm uses single-pass Welford algorithm for numerically stable mean+variance +// computation with warp-level merge via __shfl_down_sync. +// +// Shared memory requirements: +// - rms_norm: blockDim.x * sizeof(T) (e.g., 256 * 4 = 1024 bytes for f32) +// - layer_norm: 3 * ceil(blockDim.x / 32) * sizeof(T) (e.g., 3 * 8 * 4 = 96 bytes for f32) +// - group_norm: 2 * blockDim.x * sizeof(T) (e.g., 2 * 256 * 4 = 2048 bytes for f32) +// +// The kernel launcher MUST allocate at least this much shared memory via the +// launch configuration's third <<< >>> parameter. #include #include +// ============================================================================ +// Welford merge helpers +// ============================================================================ + +// Welford's online algorithm for numerically stable mean+variance. +// Maintains three accumulators per partition: +// count: number of elements seen +// mean: running mean +// M2: sum of squared deviations from the running mean +// Merge formula (combining two partitions a, b): +// delta = mean_b - mean_a +// mean_ab = mean_a + delta * count_b / (count_a + count_b) +// M2_ab = M2_a + M2_b + delta^2 * count_a * count_b / (count_a + count_b) +// This is numerically stable even with extreme value ranges. +__device__ __forceinline__ void welford_merge( + float count_a, float mean_a, float M2_a, + float count_b, float mean_b, float M2_b, + float &count_out, float &mean_out, float &M2_out +) { + float count = count_a + count_b; + if (count == 0.0f) { + count_out = 0.0f; + mean_out = 0.0f; + M2_out = 0.0f; + return; + } + float delta = mean_b - mean_a; + mean_out = mean_a + delta * count_b / count; + M2_out = M2_a + M2_b + delta * delta * count_a * count_b / count; + count_out = count; +} + +__device__ __forceinline__ void welford_merge_f64( + double count_a, double mean_a, double M2_a, + double count_b, double mean_b, double M2_b, + double &count_out, double &mean_out, double &M2_out +) { + double count = count_a + count_b; + if (count == 0.0) { + count_out = 0.0; + mean_out = 0.0; + M2_out = 0.0; + return; + } + double delta = mean_b - mean_a; + mean_out = mean_a + delta * count_b / count; + M2_out = M2_a + M2_b + delta * delta * count_a * count_b / count; + count_out = count; +} + +// Warp-level Welford reduction: merges accumulators across 32 warp lanes +// using shuffle instructions (__shfl_down_sync) to avoid shared memory. +// After this function, lane 0 holds the merged result for the entire warp. +__device__ __forceinline__ void welford_warp_reduce( + float &count, float &mean, float &M2 +) { + for (int offset = 16; offset > 0; offset >>= 1) { + float o_count = __shfl_down_sync(0xffffffff, count, offset); + float o_mean = __shfl_down_sync(0xffffffff, mean, offset); + float o_M2 = __shfl_down_sync(0xffffffff, M2, offset); + welford_merge(count, mean, M2, o_count, o_mean, o_M2, count, mean, M2); + } +} + +__device__ __forceinline__ void welford_warp_reduce_f64( + double &count, double &mean, double &M2 +) { + for (int offset = 16; offset > 0; offset >>= 1) { + double o_count = __shfl_down_sync(0xffffffff, count, offset); + double o_mean = __shfl_down_sync(0xffffffff, mean, offset); + double o_M2 = __shfl_down_sync(0xffffffff, M2, offset); + welford_merge_f64(count, mean, M2, o_count, o_mean, o_M2, count, mean, M2); + } +} + extern "C" { // ============================================================================ @@ -54,6 +140,7 @@ __global__ void rms_norm_f32( } // LayerNorm: (x - mean) / sqrt(var + eps) * weight + bias +// Single-pass Welford algorithm with warp-level merge for numerical stability __global__ void layer_norm_f32( const float* input, const float* weight, const float* bias, float* output, unsigned int batch_size, unsigned int hidden_size, float eps @@ -61,51 +148,62 @@ __global__ void layer_norm_f32( unsigned int row = blockIdx.x; if (row >= batch_size) return; - extern __shared__ float shared[]; - float* mean_shared = shared; - float* var_shared = shared + blockDim.x; - const float* row_in = input + row * hidden_size; float* row_out = output + row * hidden_size; - // Phase 1: Compute mean - float thread_sum = 0.0f; + // Phase 1: Single-pass Welford accumulation + float count = 0.0f, mean = 0.0f, M2 = 0.0f; for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { - thread_sum += row_in[i]; + float x = row_in[i]; + count += 1.0f; + float delta = x - mean; + mean += delta / count; + M2 += delta * (x - mean); } - mean_shared[threadIdx.x] = thread_sum; - __syncthreads(); - for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { - if (threadIdx.x < s) { - mean_shared[threadIdx.x] += mean_shared[threadIdx.x + s]; - } - __syncthreads(); - } - float mean = mean_shared[0] / hidden_size; - __syncthreads(); + // Warp-level Welford merge + welford_warp_reduce(count, mean, M2); - // Phase 2: Compute variance - float thread_var = 0.0f; - for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { - float diff = row_in[i] - mean; - thread_var += diff * diff; + // Block-level merge via shared memory (one entry per warp) + unsigned int warp_id = threadIdx.x / 32; + unsigned int lane_id = threadIdx.x % 32; + unsigned int num_warps = (blockDim.x + 31) / 32; + + extern __shared__ float shared[]; + // Layout: [count0..countN, mean0..meanN, M2_0..M2_N] where N = num_warps + float* s_count = shared; + float* s_mean = shared + num_warps; + float* s_M2 = shared + 2 * num_warps; + + if (lane_id == 0) { + s_count[warp_id] = count; + s_mean[warp_id] = mean; + s_M2[warp_id] = M2; } - var_shared[threadIdx.x] = thread_var; __syncthreads(); - for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { - if (threadIdx.x < s) { - var_shared[threadIdx.x] += var_shared[threadIdx.x + s]; + // Final reduction in first warp + if (warp_id == 0) { + float r_count = (lane_id < num_warps) ? s_count[lane_id] : 0.0f; + float r_mean = (lane_id < num_warps) ? s_mean[lane_id] : 0.0f; + float r_M2 = (lane_id < num_warps) ? s_M2[lane_id] : 0.0f; + + welford_warp_reduce(r_count, r_mean, r_M2); + + if (lane_id == 0) { + s_mean[0] = r_mean; + s_M2[0] = r_M2; + s_count[0] = r_count; } - __syncthreads(); } - float inv_std = rsqrtf(var_shared[0] / hidden_size + eps); __syncthreads(); - // Phase 3: Normalize and apply affine transform + float final_mean = s_mean[0]; + float inv_std = rsqrtf(s_M2[0] / s_count[0] + eps); + + // Phase 2: Normalize and apply affine transform for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { - float normalized = (row_in[i] - mean) * inv_std; + float normalized = (row_in[i] - final_mean) * inv_std; row_out[i] = normalized * weight[i] + bias[i]; } } @@ -156,48 +254,58 @@ __global__ void layer_norm_f64( unsigned int row = blockIdx.x; if (row >= batch_size) return; - extern __shared__ double shared_f64[]; - double* mean_shared = shared_f64; - double* var_shared = shared_f64 + blockDim.x; - const double* row_in = input + row * hidden_size; double* row_out = output + row * hidden_size; - double thread_sum = 0.0; + // Single-pass Welford + double count = 0.0, mean = 0.0, M2 = 0.0; for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { - thread_sum += row_in[i]; + double x = row_in[i]; + count += 1.0; + double delta = x - mean; + mean += delta / count; + M2 += delta * (x - mean); } - mean_shared[threadIdx.x] = thread_sum; - __syncthreads(); - for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { - if (threadIdx.x < s) { - mean_shared[threadIdx.x] += mean_shared[threadIdx.x + s]; - } - __syncthreads(); - } - double mean = mean_shared[0] / hidden_size; - __syncthreads(); + // Warp-level merge + welford_warp_reduce_f64(count, mean, M2); - double thread_var = 0.0; - for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { - double diff = row_in[i] - mean; - thread_var += diff * diff; + unsigned int warp_id = threadIdx.x / 32; + unsigned int lane_id = threadIdx.x % 32; + unsigned int num_warps = (blockDim.x + 31) / 32; + + extern __shared__ double shared_f64[]; + double* s_count = shared_f64; + double* s_mean = shared_f64 + num_warps; + double* s_M2 = shared_f64 + 2 * num_warps; + + if (lane_id == 0) { + s_count[warp_id] = count; + s_mean[warp_id] = mean; + s_M2[warp_id] = M2; } - var_shared[threadIdx.x] = thread_var; __syncthreads(); - for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { - if (threadIdx.x < s) { - var_shared[threadIdx.x] += var_shared[threadIdx.x + s]; + if (warp_id == 0) { + double r_count = (lane_id < num_warps) ? s_count[lane_id] : 0.0; + double r_mean = (lane_id < num_warps) ? s_mean[lane_id] : 0.0; + double r_M2 = (lane_id < num_warps) ? s_M2[lane_id] : 0.0; + + welford_warp_reduce_f64(r_count, r_mean, r_M2); + + if (lane_id == 0) { + s_mean[0] = r_mean; + s_M2[0] = r_M2; + s_count[0] = r_count; } - __syncthreads(); } - double inv_std = rsqrt(var_shared[0] / hidden_size + eps); __syncthreads(); + double final_mean = s_mean[0]; + double inv_std = rsqrt(s_M2[0] / s_count[0] + eps); + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { - double normalized = (row_in[i] - mean) * inv_std; + double normalized = (row_in[i] - final_mean) * inv_std; row_out[i] = normalized * weight[i] + bias[i]; } } @@ -251,50 +359,57 @@ __global__ void layer_norm_f16( unsigned int row = blockIdx.x; if (row >= batch_size) return; - extern __shared__ float shared[]; - float* mean_shared = shared; - float* var_shared = shared + blockDim.x; - const __half* row_in = input + row * hidden_size; __half* row_out = output + row * hidden_size; - // FP32 accumulation for mean - float thread_sum = 0.0f; + // Single-pass Welford with FP32 accumulation + float count = 0.0f, mean = 0.0f, M2 = 0.0f; for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { - thread_sum += __half2float(row_in[i]); + float x = __half2float(row_in[i]); + count += 1.0f; + float delta = x - mean; + mean += delta / count; + M2 += delta * (x - mean); } - mean_shared[threadIdx.x] = thread_sum; - __syncthreads(); - for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { - if (threadIdx.x < s) { - mean_shared[threadIdx.x] += mean_shared[threadIdx.x + s]; - } - __syncthreads(); - } - float mean = mean_shared[0] / hidden_size; - __syncthreads(); + welford_warp_reduce(count, mean, M2); - // FP32 accumulation for variance - float thread_var = 0.0f; - for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { - float diff = __half2float(row_in[i]) - mean; - thread_var += diff * diff; + unsigned int warp_id = threadIdx.x / 32; + unsigned int lane_id = threadIdx.x % 32; + unsigned int num_warps = (blockDim.x + 31) / 32; + + extern __shared__ float shared[]; + float* s_count = shared; + float* s_mean = shared + num_warps; + float* s_M2 = shared + 2 * num_warps; + + if (lane_id == 0) { + s_count[warp_id] = count; + s_mean[warp_id] = mean; + s_M2[warp_id] = M2; } - var_shared[threadIdx.x] = thread_var; __syncthreads(); - for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { - if (threadIdx.x < s) { - var_shared[threadIdx.x] += var_shared[threadIdx.x + s]; + if (warp_id == 0) { + float r_count = (lane_id < num_warps) ? s_count[lane_id] : 0.0f; + float r_mean = (lane_id < num_warps) ? s_mean[lane_id] : 0.0f; + float r_M2 = (lane_id < num_warps) ? s_M2[lane_id] : 0.0f; + + welford_warp_reduce(r_count, r_mean, r_M2); + + if (lane_id == 0) { + s_mean[0] = r_mean; + s_M2[0] = r_M2; + s_count[0] = r_count; } - __syncthreads(); } - float inv_std = rsqrtf(var_shared[0] / hidden_size + eps); __syncthreads(); + float final_mean = s_mean[0]; + float inv_std = rsqrtf(s_M2[0] / s_count[0] + eps); + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { - float normalized = (__half2float(row_in[i]) - mean) * inv_std; + float normalized = (__half2float(row_in[i]) - final_mean) * inv_std; float result = normalized * __half2float(weight[i]) + __half2float(bias[i]); row_out[i] = __float2half(result); } @@ -349,52 +464,362 @@ __global__ void layer_norm_bf16( unsigned int row = blockIdx.x; if (row >= batch_size) return; + const __nv_bfloat16* row_in = input + row * hidden_size; + __nv_bfloat16* row_out = output + row * hidden_size; + + // Single-pass Welford with FP32 accumulation + float count = 0.0f, mean = 0.0f, M2 = 0.0f; + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float x = __bfloat162float(row_in[i]); + count += 1.0f; + float delta = x - mean; + mean += delta / count; + M2 += delta * (x - mean); + } + + welford_warp_reduce(count, mean, M2); + + unsigned int warp_id = threadIdx.x / 32; + unsigned int lane_id = threadIdx.x % 32; + unsigned int num_warps = (blockDim.x + 31) / 32; + + extern __shared__ float shared[]; + float* s_count = shared; + float* s_mean = shared + num_warps; + float* s_M2 = shared + 2 * num_warps; + + if (lane_id == 0) { + s_count[warp_id] = count; + s_mean[warp_id] = mean; + s_M2[warp_id] = M2; + } + __syncthreads(); + + if (warp_id == 0) { + float r_count = (lane_id < num_warps) ? s_count[lane_id] : 0.0f; + float r_mean = (lane_id < num_warps) ? s_mean[lane_id] : 0.0f; + float r_M2 = (lane_id < num_warps) ? s_M2[lane_id] : 0.0f; + + welford_warp_reduce(r_count, r_mean, r_M2); + + if (lane_id == 0) { + s_mean[0] = r_mean; + s_M2[0] = r_M2; + s_count[0] = r_count; + } + } + __syncthreads(); + + float final_mean = s_mean[0]; + float inv_std = rsqrtf(s_M2[0] / s_count[0] + eps); + + for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float normalized = (__bfloat162float(row_in[i]) - final_mean) * inv_std; + float result = normalized * __bfloat162float(weight[i]) + __bfloat162float(bias[i]); + row_out[i] = __float2bfloat16(result); + } +} + +// ============================================================================ +// F32 GroupNorm Operations +// ============================================================================ + +// GroupNorm: Divides channels into num_groups, normalizes each group separately +// Each block handles one (batch, group) pair +// Input shape: [batch, channels, spatial...] +__global__ void group_norm_f32( + const float* input, const float* weight, const float* bias, float* output, + unsigned int batch, unsigned int channels, unsigned int spatial, + unsigned int num_groups, unsigned int channels_per_group, float eps +) { + unsigned int b = blockIdx.x / num_groups; + unsigned int g = blockIdx.x % num_groups; + + if (b >= batch || g >= num_groups) return; + extern __shared__ float shared[]; float* mean_shared = shared; float* var_shared = shared + blockDim.x; - const __nv_bfloat16* row_in = input + row * hidden_size; - __nv_bfloat16* row_out = output + row * hidden_size; + unsigned int group_size = channels_per_group * spatial; + unsigned int c_start = g * channels_per_group; - // FP32 accumulation for mean + // Phase 1: Compute mean float thread_sum = 0.0f; - for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { - thread_sum += __bfloat162float(row_in[i]); + for (unsigned int idx = threadIdx.x; idx < group_size; idx += blockDim.x) { + unsigned int c = c_start + (idx / spatial); + unsigned int s = idx % spatial; + unsigned int offset = (b * channels + c) * spatial + s; + thread_sum += input[offset]; } mean_shared[threadIdx.x] = thread_sum; __syncthreads(); + // Reduce within block for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { if (threadIdx.x < s) { mean_shared[threadIdx.x] += mean_shared[threadIdx.x + s]; } __syncthreads(); } - float mean = mean_shared[0] / hidden_size; + float mean = mean_shared[0] / group_size; __syncthreads(); - // FP32 accumulation for variance + // Phase 2: Compute variance float thread_var = 0.0f; - for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { - float diff = __bfloat162float(row_in[i]) - mean; + for (unsigned int idx = threadIdx.x; idx < group_size; idx += blockDim.x) { + unsigned int c = c_start + (idx / spatial); + unsigned int s = idx % spatial; + unsigned int offset = (b * channels + c) * spatial + s; + float diff = input[offset] - mean; thread_var += diff * diff; } var_shared[threadIdx.x] = thread_var; __syncthreads(); + // Reduce within block for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { if (threadIdx.x < s) { var_shared[threadIdx.x] += var_shared[threadIdx.x + s]; } __syncthreads(); } - float inv_std = rsqrtf(var_shared[0] / hidden_size + eps); + float inv_std = rsqrtf(var_shared[0] / group_size + eps); __syncthreads(); - for (unsigned int i = threadIdx.x; i < hidden_size; i += blockDim.x) { - float normalized = (__bfloat162float(row_in[i]) - mean) * inv_std; - float result = normalized * __bfloat162float(weight[i]) + __bfloat162float(bias[i]); - row_out[i] = __float2bfloat16(result); + // Phase 3: Normalize and apply affine transform + for (unsigned int idx = threadIdx.x; idx < group_size; idx += blockDim.x) { + unsigned int c = c_start + (idx / spatial); + unsigned int s = idx % spatial; + unsigned int offset = (b * channels + c) * spatial + s; + float normalized = (input[offset] - mean) * inv_std; + output[offset] = normalized * weight[c] + bias[c]; + } +} + +// ============================================================================ +// F64 GroupNorm Operations +// ============================================================================ + +__global__ void group_norm_f64( + const double* input, const double* weight, const double* bias, double* output, + unsigned int batch, unsigned int channels, unsigned int spatial, + unsigned int num_groups, unsigned int channels_per_group, double eps +) { + unsigned int b = blockIdx.x / num_groups; + unsigned int g = blockIdx.x % num_groups; + + if (b >= batch || g >= num_groups) return; + + extern __shared__ double shared_f64[]; + double* mean_shared = shared_f64; + double* var_shared = shared_f64 + blockDim.x; + + unsigned int group_size = channels_per_group * spatial; + unsigned int c_start = g * channels_per_group; + + // Phase 1: Compute mean + double thread_sum = 0.0; + for (unsigned int idx = threadIdx.x; idx < group_size; idx += blockDim.x) { + unsigned int c = c_start + (idx / spatial); + unsigned int s = idx % spatial; + unsigned int offset = (b * channels + c) * spatial + s; + thread_sum += input[offset]; + } + mean_shared[threadIdx.x] = thread_sum; + __syncthreads(); + + // Reduce within block + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + mean_shared[threadIdx.x] += mean_shared[threadIdx.x + s]; + } + __syncthreads(); + } + double mean = mean_shared[0] / group_size; + __syncthreads(); + + // Phase 2: Compute variance + double thread_var = 0.0; + for (unsigned int idx = threadIdx.x; idx < group_size; idx += blockDim.x) { + unsigned int c = c_start + (idx / spatial); + unsigned int s = idx % spatial; + unsigned int offset = (b * channels + c) * spatial + s; + double diff = input[offset] - mean; + thread_var += diff * diff; + } + var_shared[threadIdx.x] = thread_var; + __syncthreads(); + + // Reduce within block + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + var_shared[threadIdx.x] += var_shared[threadIdx.x + s]; + } + __syncthreads(); + } + double inv_std = rsqrt(var_shared[0] / group_size + eps); + __syncthreads(); + + // Phase 3: Normalize and apply affine transform + for (unsigned int idx = threadIdx.x; idx < group_size; idx += blockDim.x) { + unsigned int c = c_start + (idx / spatial); + unsigned int s = idx % spatial; + unsigned int offset = (b * channels + c) * spatial + s; + double normalized = (input[offset] - mean) * inv_std; + output[offset] = normalized * weight[c] + bias[c]; + } +} + +// ============================================================================ +// F16 GroupNorm Operations +// Note: Uses FP32 accumulation for numerical stability +// ============================================================================ + +__global__ void group_norm_f16( + const __half* input, const __half* weight, const __half* bias, __half* output, + unsigned int batch, unsigned int channels, unsigned int spatial, + unsigned int num_groups, unsigned int channels_per_group, float eps +) { + unsigned int b = blockIdx.x / num_groups; + unsigned int g = blockIdx.x % num_groups; + + if (b >= batch || g >= num_groups) return; + + extern __shared__ float shared[]; + float* mean_shared = shared; + float* var_shared = shared + blockDim.x; + + unsigned int group_size = channels_per_group * spatial; + unsigned int c_start = g * channels_per_group; + + // Phase 1: Compute mean (FP32 accumulation) + float thread_sum = 0.0f; + for (unsigned int idx = threadIdx.x; idx < group_size; idx += blockDim.x) { + unsigned int c = c_start + (idx / spatial); + unsigned int s = idx % spatial; + unsigned int offset = (b * channels + c) * spatial + s; + thread_sum += __half2float(input[offset]); + } + mean_shared[threadIdx.x] = thread_sum; + __syncthreads(); + + // Reduce within block + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + mean_shared[threadIdx.x] += mean_shared[threadIdx.x + s]; + } + __syncthreads(); + } + float mean = mean_shared[0] / group_size; + __syncthreads(); + + // Phase 2: Compute variance (FP32 accumulation) + float thread_var = 0.0f; + for (unsigned int idx = threadIdx.x; idx < group_size; idx += blockDim.x) { + unsigned int c = c_start + (idx / spatial); + unsigned int s = idx % spatial; + unsigned int offset = (b * channels + c) * spatial + s; + float diff = __half2float(input[offset]) - mean; + thread_var += diff * diff; + } + var_shared[threadIdx.x] = thread_var; + __syncthreads(); + + // Reduce within block + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + var_shared[threadIdx.x] += var_shared[threadIdx.x + s]; + } + __syncthreads(); + } + float inv_std = rsqrtf(var_shared[0] / group_size + eps); + __syncthreads(); + + // Phase 3: Normalize and apply affine transform + for (unsigned int idx = threadIdx.x; idx < group_size; idx += blockDim.x) { + unsigned int c = c_start + (idx / spatial); + unsigned int s = idx % spatial; + unsigned int offset = (b * channels + c) * spatial + s; + float normalized = (__half2float(input[offset]) - mean) * inv_std; + float result = normalized * __half2float(weight[c]) + __half2float(bias[c]); + output[offset] = __float2half(result); + } +} + +// ============================================================================ +// BF16 GroupNorm Operations +// Note: Uses FP32 accumulation for numerical stability +// ============================================================================ + +__global__ void group_norm_bf16( + const __nv_bfloat16* input, const __nv_bfloat16* weight, const __nv_bfloat16* bias, __nv_bfloat16* output, + unsigned int batch, unsigned int channels, unsigned int spatial, + unsigned int num_groups, unsigned int channels_per_group, float eps +) { + unsigned int b = blockIdx.x / num_groups; + unsigned int g = blockIdx.x % num_groups; + + if (b >= batch || g >= num_groups) return; + + extern __shared__ float shared[]; + float* mean_shared = shared; + float* var_shared = shared + blockDim.x; + + unsigned int group_size = channels_per_group * spatial; + unsigned int c_start = g * channels_per_group; + + // Phase 1: Compute mean (FP32 accumulation) + float thread_sum = 0.0f; + for (unsigned int idx = threadIdx.x; idx < group_size; idx += blockDim.x) { + unsigned int c = c_start + (idx / spatial); + unsigned int s = idx % spatial; + unsigned int offset = (b * channels + c) * spatial + s; + thread_sum += __bfloat162float(input[offset]); + } + mean_shared[threadIdx.x] = thread_sum; + __syncthreads(); + + // Reduce within block + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + mean_shared[threadIdx.x] += mean_shared[threadIdx.x + s]; + } + __syncthreads(); + } + float mean = mean_shared[0] / group_size; + __syncthreads(); + + // Phase 2: Compute variance (FP32 accumulation) + float thread_var = 0.0f; + for (unsigned int idx = threadIdx.x; idx < group_size; idx += blockDim.x) { + unsigned int c = c_start + (idx / spatial); + unsigned int s = idx % spatial; + unsigned int offset = (b * channels + c) * spatial + s; + float diff = __bfloat162float(input[offset]) - mean; + thread_var += diff * diff; + } + var_shared[threadIdx.x] = thread_var; + __syncthreads(); + + // Reduce within block + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + var_shared[threadIdx.x] += var_shared[threadIdx.x + s]; + } + __syncthreads(); + } + float inv_std = rsqrtf(var_shared[0] / group_size + eps); + __syncthreads(); + + // Phase 3: Normalize and apply affine transform + for (unsigned int idx = threadIdx.x; idx < group_size; idx += blockDim.x) { + unsigned int c = c_start + (idx / spatial); + unsigned int s = idx % spatial; + unsigned int offset = (b * channels + c) * spatial + s; + float normalized = (__bfloat162float(input[offset]) - mean) * inv_std; + float result = normalized * __bfloat162float(weight[c]) + __bfloat162float(bias[c]); + output[offset] = __float2bfloat16(result); } } diff --git a/src/runtime/cuda/kernels/norm.rs b/src/runtime/cuda/kernels/norm.rs index 25ec3eac..1d8542f9 100644 --- a/src/runtime/cuda/kernels/norm.rs +++ b/src/runtime/cuda/kernels/norm.rs @@ -148,3 +148,88 @@ pub unsafe fn launch_layer_norm( Ok(()) } } + +/// Launch a GroupNorm kernel. +/// +/// Computes: Group normalization across divided channel groups +/// Input shape: [batch, channels, spatial...] +/// Divides channels into num_groups, normalizes each group separately +/// +/// Computes for each (batch, group): +/// - mean and variance over channels_per_group * spatial elements +/// - Then: `output = (input - mean) / sqrt(variance + eps) * weight + bias` +/// +/// # Arguments +/// +/// * `input_ptr` - Device pointer to input tensor of shape [batch, channels, spatial...] +/// * `weight_ptr` - Device pointer to weight tensor of shape [channels] +/// * `bias_ptr` - Device pointer to bias tensor of shape [channels] +/// * `output_ptr` - Device pointer to output tensor of shape [batch, channels, spatial...] +/// * `batch` - Batch size +/// * `channels` - Number of channels +/// * `spatial` - Product of spatial dimensions (height * width for 4D tensors) +/// * `num_groups` - Number of groups to divide channels into +/// * `channels_per_group` - Channels per group (channels / num_groups) +/// * `eps` - Small constant for numerical stability (typically 1e-5) +/// +/// # Safety +/// +/// - All pointers must be valid device memory +/// - Input and output must have batch * channels * spatial elements +/// - Weight and bias must have channels elements +/// - channels must be divisible by num_groups +pub unsafe fn launch_group_norm( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + input_ptr: u64, + weight_ptr: u64, + bias_ptr: u64, + output_ptr: u64, + batch: usize, + channels: usize, + spatial: usize, + num_groups: usize, + channels_per_group: usize, + eps: f32, +) -> Result<()> { + unsafe { + let module = get_or_load_module(context, device_index, kernel_names::NORM_MODULE)?; + let func_name = kernel_name("group_norm", dtype); + let func = get_kernel_function(&module, &func_name)?; + + // One block per (batch, group) pair + let grid_size = (batch * num_groups) as u32; + let group_size = channels_per_group * spatial; + let block_size = BLOCK_SIZE.min(group_size as u32); + + // Shared memory: 2 * block_size floats for mean and variance reduction + let shared_mem = block_size * 2 * 4; // 2 floats per thread for f32 + + let batch_u32 = batch as u32; + let channels_u32 = channels as u32; + let spatial_u32 = spatial as u32; + let num_groups_u32 = num_groups as u32; + let channels_per_group_u32 = channels_per_group as u32; + + let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), shared_mem); + let mut builder = stream.launch_builder(&func); + builder.arg(&input_ptr); + builder.arg(&weight_ptr); + builder.arg(&bias_ptr); + builder.arg(&output_ptr); + builder.arg(&batch_u32); + builder.arg(&channels_u32); + builder.arg(&spatial_u32); + builder.arg(&num_groups_u32); + builder.arg(&channels_per_group_u32); + builder.arg(&eps); + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!("CUDA group_norm kernel launch failed: {:?}", e)) + })?; + + Ok(()) + } +} diff --git a/src/runtime/cuda/kernels/scan.rs b/src/runtime/cuda/kernels/scan.rs index 13501cc6..c0e05102 100644 --- a/src/runtime/cuda/kernels/scan.rs +++ b/src/runtime/cuda/kernels/scan.rs @@ -59,6 +59,14 @@ const MAX_SCAN_RECURSION_DEPTH: usize = 10; /// # Returns /// /// `(output_tensor, total_sum)` where output has size n+1 +/// +/// # Safety +/// +/// - `input` must be a valid `CudaRuntime` tensor of `DType::I32` on the device associated with +/// `context`. Passing a tensor with a different dtype returns an error. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. +/// - A single scalar GPU-to-CPU transfer is performed at the end to read the total sum; this is +/// intentional and documented as acceptable for control-flow purposes. pub unsafe fn exclusive_scan_i32_gpu( context: &Arc, stream: &CudaStream, @@ -78,8 +86,8 @@ pub unsafe fn exclusive_scan_i32_gpu( // Allocate output tensor with size n+1 let output = Tensor::::zeros(&[n + 1], DType::I32, device); - let input_ptr = input.storage().ptr(); - let output_ptr = output.storage().ptr(); + let input_ptr = input.ptr(); + let output_ptr = output.ptr(); if n <= SCAN_BLOCK_SIZE as usize { // Small array: use single-block scan @@ -120,7 +128,7 @@ pub unsafe fn exclusive_scan_i32_gpu( unsafe { cudarc::driver::sys::cuMemcpyDtoH_v2( &mut total_i32 as *mut i32 as *mut std::ffi::c_void, - output.storage().ptr() + offset_bytes as u64, + output.ptr() + offset_bytes as u64, std::mem::size_of::(), ); } @@ -194,7 +202,7 @@ unsafe fn launch_scan_multi_block_i32( // Allocate temporary buffer for block sums let block_sums = Tensor::::zeros(&[num_blocks as usize], DType::I32, device); - let block_sums_ptr = block_sums.storage().ptr(); + let block_sums_ptr = block_sums.ptr(); // Step 1: Scan each block independently let func_step1 = get_kernel_function(&module, "scan_blocks_i32_step1")?; @@ -219,7 +227,7 @@ unsafe fn launch_scan_multi_block_i32( // Allocate buffer for scanned block sums (size num_blocks + 1) let scanned_block_sums = Tensor::::zeros(&[num_blocks as usize + 1], DType::I32, device); - let scanned_block_sums_ptr = scanned_block_sums.storage().ptr(); + let scanned_block_sums_ptr = scanned_block_sums.ptr(); if num_blocks <= SCAN_BLOCK_SIZE { // Block sums fit in single block - use simple scan @@ -316,6 +324,14 @@ unsafe fn launch_scan_multi_block_i32( /// # Returns /// /// `(output_tensor, total_sum)` where output has size n+1 +/// +/// # Safety +/// +/// - `input` must be a valid `CudaRuntime` tensor of `DType::I64` on the device associated with +/// `context`. Passing a tensor with a different dtype returns an error. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. +/// - A single scalar GPU-to-CPU transfer is performed at the end to read the total sum; this is +/// intentional and documented as acceptable for control-flow purposes. pub unsafe fn exclusive_scan_i64_gpu( context: &Arc, stream: &CudaStream, @@ -335,8 +351,8 @@ pub unsafe fn exclusive_scan_i64_gpu( // Allocate output tensor with size n+1 let output = Tensor::::zeros(&[n + 1], DType::I64, device); - let input_ptr = input.storage().ptr(); - let output_ptr = output.storage().ptr(); + let input_ptr = input.ptr(); + let output_ptr = output.ptr(); if n <= SCAN_BLOCK_SIZE as usize { // Small array: use single-block scan @@ -377,7 +393,7 @@ pub unsafe fn exclusive_scan_i64_gpu( unsafe { cudarc::driver::sys::cuMemcpyDtoH_v2( &mut total_i64 as *mut i64 as *mut std::ffi::c_void, - output.storage().ptr() + offset_bytes as u64, + output.ptr() + offset_bytes as u64, std::mem::size_of::(), ); } @@ -451,7 +467,7 @@ unsafe fn launch_scan_multi_block_i64( // Allocate temporary buffer for block sums let block_sums = Tensor::::zeros(&[num_blocks as usize], DType::I64, device); - let block_sums_ptr = block_sums.storage().ptr(); + let block_sums_ptr = block_sums.ptr(); // Step 1: Scan each block independently let func_step1 = get_kernel_function(&module, "scan_blocks_i64_step1")?; @@ -483,7 +499,7 @@ unsafe fn launch_scan_multi_block_i64( // Allocate buffer for scanned block sums (size num_blocks + 1) let scanned_block_sums = Tensor::::zeros(&[num_blocks as usize + 1], DType::I64, device); - let scanned_block_sums_ptr = scanned_block_sums.storage().ptr(); + let scanned_block_sums_ptr = scanned_block_sums.ptr(); if num_blocks <= SCAN_BLOCK_SIZE { // Block sums fit in single block - use simple scan @@ -580,6 +596,9 @@ mod tests { #[test] #[cfg(feature = "cuda")] fn test_exclusive_scan_small() { + if !crate::runtime::cuda::is_cuda_available() { + return; + } let device = CudaDevice::new(0); let client = CudaRuntime::default_client(&device); @@ -606,6 +625,9 @@ mod tests { #[test] #[cfg(feature = "cuda")] fn test_exclusive_scan_large() { + if !crate::runtime::cuda::is_cuda_available() { + return; + } let device = CudaDevice::new(0); let client = CudaRuntime::default_client(&device); @@ -640,6 +662,9 @@ mod tests { #[test] #[cfg(feature = "cuda")] fn test_exclusive_scan_zeros() { + if !crate::runtime::cuda::is_cuda_available() { + return; + } let device = CudaDevice::new(0); let client = CudaRuntime::default_client(&device); @@ -664,6 +689,9 @@ mod tests { #[test] #[cfg(feature = "cuda")] fn test_exclusive_scan_single_element() { + if !crate::runtime::cuda::is_cuda_available() { + return; + } let device = CudaDevice::new(0); let client = CudaRuntime::default_client(&device); @@ -688,6 +716,9 @@ mod tests { #[test] #[cfg(feature = "cuda")] fn test_exclusive_scan_very_large() { + if !crate::runtime::cuda::is_cuda_available() { + return; + } // Test with 500,000 elements (requires recursive multi-level scan) // This exceeds 262,144 = 512^2 which was the previous limit let device = CudaDevice::new(0); @@ -724,6 +755,9 @@ mod tests { #[test] #[cfg(feature = "cuda")] fn test_exclusive_scan_boundary_size() { + if !crate::runtime::cuda::is_cuda_available() { + return; + } // Test at the boundary of single-level multi-block (512 * 512 = 262,144) let device = CudaDevice::new(0); let client = CudaRuntime::default_client(&device); @@ -754,6 +788,9 @@ mod tests { #[test] #[cfg(feature = "cuda")] fn test_exclusive_scan_i64_very_large() { + if !crate::runtime::cuda::is_cuda_available() { + return; + } // Test i64 with large values that would overflow i32 let device = CudaDevice::new(0); let client = CudaRuntime::default_client(&device); diff --git a/src/runtime/cuda/kernels/semiring_matmul.cu b/src/runtime/cuda/kernels/semiring_matmul.cu index aacf84a9..e3a5ed03 100644 --- a/src/runtime/cuda/kernels/semiring_matmul.cu +++ b/src/runtime/cuda/kernels/semiring_matmul.cu @@ -1,6 +1,10 @@ // Semiring Matrix Multiplication CUDA Kernels // C[i,j] = reduce_k( combine(A[i,k], B[k,j]) ) // +#include +#include +#include "dtype_traits.cuh" +// // Semiring operations (passed as op parameter): // 0 = MinPlus: reduce=min, combine=+ // 1 = MaxPlus: reduce=max, combine=+ @@ -110,7 +114,9 @@ extern "C" __global__ void semiring_matmul_batched_f32( unsigned int N, unsigned int K, unsigned int op, - unsigned int batch_size + unsigned int batch_size, + unsigned int a_batch_count, + unsigned int b_batch_count ) { unsigned int batch = blockIdx.z; if (batch >= batch_size) return; @@ -120,8 +126,8 @@ extern "C" __global__ void semiring_matmul_batched_f32( if (row >= M || col >= N) return; - unsigned int a_offset = batch * M * K; - unsigned int b_offset = batch * K * N; + unsigned int a_offset = (batch % a_batch_count) * M * K; + unsigned int b_offset = (batch % b_batch_count) * K * N; unsigned int c_offset = batch * M * N; float acc; @@ -214,7 +220,9 @@ extern "C" __global__ void semiring_matmul_batched_f64( unsigned int N, unsigned int K, unsigned int op, - unsigned int batch_size + unsigned int batch_size, + unsigned int a_batch_count, + unsigned int b_batch_count ) { unsigned int batch = blockIdx.z; if (batch >= batch_size) return; @@ -224,8 +232,8 @@ extern "C" __global__ void semiring_matmul_batched_f64( if (row >= M || col >= N) return; - unsigned int a_offset = batch * M * K; - unsigned int b_offset = batch * K * N; + unsigned int a_offset = (batch % a_batch_count) * M * K; + unsigned int b_offset = (batch % b_batch_count) * K * N; unsigned int c_offset = batch * M * N; double acc; @@ -318,7 +326,9 @@ extern "C" __global__ void semiring_matmul_batched_i32( unsigned int N, unsigned int K, unsigned int op, - unsigned int batch_size + unsigned int batch_size, + unsigned int a_batch_count, + unsigned int b_batch_count ) { unsigned int batch = blockIdx.z; if (batch >= batch_size) return; @@ -328,8 +338,8 @@ extern "C" __global__ void semiring_matmul_batched_i32( if (row >= M || col >= N) return; - unsigned int a_offset = batch * M * K; - unsigned int b_offset = batch * K * N; + unsigned int a_offset = (batch % a_batch_count) * M * K; + unsigned int b_offset = (batch % b_batch_count) * K * N; unsigned int c_offset = batch * M * N; int acc; @@ -363,3 +373,465 @@ extern "C" __global__ void semiring_matmul_batched_i32( C[c_offset + row * N + col] = acc; } + +// ============================================================================ +// U8 (Bool) Kernels — primarily for OrAnd semiring +// ============================================================================ + +extern "C" __global__ void semiring_matmul_u8( + const unsigned char* __restrict__ A, + const unsigned char* __restrict__ B, + unsigned char* __restrict__ C, + unsigned int M, + unsigned int N, + unsigned int K, + unsigned int op +) { + unsigned int row = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int col = blockIdx.x * blockDim.x + threadIdx.x; + + if (row >= M || col >= N) return; + + unsigned char acc; + switch (op) { + case 0: case 3: acc = 255; break; + case 1: case 2: acc = 0; break; + default: acc = 0; break; + } + + for (unsigned int kk = 0; kk < K; kk++) { + unsigned char a_val = A[row * K + kk]; + unsigned char b_val = B[kk * N + col]; + + unsigned char combined; + switch (op) { + case 0: case 1: combined = (unsigned char)(a_val + b_val); break; + case 2: combined = (a_val < b_val) ? a_val : b_val; break; + case 3: case 5: combined = (a_val > b_val) ? a_val : b_val; break; + case 4: combined = (a_val != 0 && b_val != 0) ? 1 : 0; break; + default: combined = (unsigned char)(a_val + b_val); break; + } + + switch (op) { + case 0: case 3: acc = (acc < combined) ? acc : combined; break; + case 1: case 2: acc = (acc > combined) ? acc : combined; break; + case 4: if (combined != 0) acc = 1; break; + case 5: acc = acc + combined; break; + default: acc = (acc < combined) ? acc : combined; break; + } + } + + C[row * N + col] = acc; +} + +extern "C" __global__ void semiring_matmul_batched_u8( + const unsigned char* __restrict__ A, + const unsigned char* __restrict__ B, + unsigned char* __restrict__ C, + unsigned int M, + unsigned int N, + unsigned int K, + unsigned int op, + unsigned int batch_size, + unsigned int a_batch_count, + unsigned int b_batch_count +) { + unsigned int batch = blockIdx.z; + if (batch >= batch_size) return; + + unsigned int row = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int col = blockIdx.x * blockDim.x + threadIdx.x; + + if (row >= M || col >= N) return; + + unsigned int a_offset = (batch % a_batch_count) * M * K; + unsigned int b_offset = (batch % b_batch_count) * K * N; + unsigned int c_offset = batch * M * N; + + unsigned char acc; + switch (op) { + case 0: case 3: acc = 255; break; + case 1: case 2: acc = 0; break; + default: acc = 0; break; + } + + for (unsigned int kk = 0; kk < K; kk++) { + unsigned char a_val = A[a_offset + row * K + kk]; + unsigned char b_val = B[b_offset + kk * N + col]; + + unsigned char combined; + switch (op) { + case 0: case 1: combined = (unsigned char)(a_val + b_val); break; + case 2: combined = (a_val < b_val) ? a_val : b_val; break; + case 3: case 5: combined = (a_val > b_val) ? a_val : b_val; break; + case 4: combined = (a_val != 0 && b_val != 0) ? 1 : 0; break; + default: combined = (unsigned char)(a_val + b_val); break; + } + + switch (op) { + case 0: case 3: acc = (acc < combined) ? acc : combined; break; + case 1: case 2: acc = (acc > combined) ? acc : combined; break; + case 4: if (combined != 0) acc = 1; break; + case 5: acc = acc + combined; break; + default: acc = (acc < combined) ? acc : combined; break; + } + } + + C[c_offset + row * N + col] = acc; +} + +// ============================================================================ +// F16 Kernels (compute in F32) +// ============================================================================ + +extern "C" __global__ void semiring_matmul_f16( + const __half* __restrict__ A, + const __half* __restrict__ B, + __half* __restrict__ C, + unsigned int M, unsigned int N, unsigned int K, unsigned int op +) { + unsigned int row = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int col = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= M || col >= N) return; + + float acc; + switch (op) { + case 0: case 3: acc = __int_as_float(0x7f800000); break; + case 1: case 2: acc = __int_as_float(0xff800000); break; + default: acc = 0.0f; break; + } + + for (unsigned int kk = 0; kk < K; kk++) { + float a_val = __half2float(A[row * K + kk]); + float b_val = __half2float(B[kk * N + col]); + float combined; + switch (op) { + case 0: case 1: combined = a_val + b_val; break; + case 2: combined = fminf(a_val, b_val); break; + case 3: case 5: combined = fmaxf(a_val, b_val); break; + case 4: combined = (a_val != 0.0f && b_val != 0.0f) ? 1.0f : 0.0f; break; + default: combined = a_val + b_val; break; + } + switch (op) { + case 0: case 3: acc = fminf(acc, combined); break; + case 1: case 2: acc = fmaxf(acc, combined); break; + case 4: if (combined != 0.0f) acc = 1.0f; break; + case 5: acc = acc + combined; break; + default: acc = fminf(acc, combined); break; + } + } + C[row * N + col] = __float2half(acc); +} + +extern "C" __global__ void semiring_matmul_batched_f16( + const __half* __restrict__ A, + const __half* __restrict__ B, + __half* __restrict__ C, + unsigned int M, unsigned int N, unsigned int K, unsigned int op, + unsigned int batch_size, unsigned int a_batch_count, unsigned int b_batch_count +) { + unsigned int batch = blockIdx.z; + if (batch >= batch_size) return; + unsigned int row = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int col = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= M || col >= N) return; + + unsigned int a_offset = (batch % a_batch_count) * M * K; + unsigned int b_offset = (batch % b_batch_count) * K * N; + unsigned int c_offset = batch * M * N; + + float acc; + switch (op) { + case 0: case 3: acc = __int_as_float(0x7f800000); break; + case 1: case 2: acc = __int_as_float(0xff800000); break; + default: acc = 0.0f; break; + } + + for (unsigned int kk = 0; kk < K; kk++) { + float a_val = __half2float(A[a_offset + row * K + kk]); + float b_val = __half2float(B[b_offset + kk * N + col]); + float combined; + switch (op) { + case 0: case 1: combined = a_val + b_val; break; + case 2: combined = fminf(a_val, b_val); break; + case 3: case 5: combined = fmaxf(a_val, b_val); break; + case 4: combined = (a_val != 0.0f && b_val != 0.0f) ? 1.0f : 0.0f; break; + default: combined = a_val + b_val; break; + } + switch (op) { + case 0: case 3: acc = fminf(acc, combined); break; + case 1: case 2: acc = fmaxf(acc, combined); break; + case 4: if (combined != 0.0f) acc = 1.0f; break; + case 5: acc = acc + combined; break; + default: acc = fminf(acc, combined); break; + } + } + C[c_offset + row * N + col] = __float2half(acc); +} + +// ============================================================================ +// BF16 Kernels (compute in F32) +// ============================================================================ + +extern "C" __global__ void semiring_matmul_bf16( + const __nv_bfloat16* __restrict__ A, + const __nv_bfloat16* __restrict__ B, + __nv_bfloat16* __restrict__ C, + unsigned int M, unsigned int N, unsigned int K, unsigned int op +) { + unsigned int row = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int col = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= M || col >= N) return; + + float acc; + switch (op) { + case 0: case 3: acc = __int_as_float(0x7f800000); break; + case 1: case 2: acc = __int_as_float(0xff800000); break; + default: acc = 0.0f; break; + } + + for (unsigned int kk = 0; kk < K; kk++) { + float a_val = __bfloat162float(A[row * K + kk]); + float b_val = __bfloat162float(B[kk * N + col]); + float combined; + switch (op) { + case 0: case 1: combined = a_val + b_val; break; + case 2: combined = fminf(a_val, b_val); break; + case 3: case 5: combined = fmaxf(a_val, b_val); break; + case 4: combined = (a_val != 0.0f && b_val != 0.0f) ? 1.0f : 0.0f; break; + default: combined = a_val + b_val; break; + } + switch (op) { + case 0: case 3: acc = fminf(acc, combined); break; + case 1: case 2: acc = fmaxf(acc, combined); break; + case 4: if (combined != 0.0f) acc = 1.0f; break; + case 5: acc = acc + combined; break; + default: acc = fminf(acc, combined); break; + } + } + C[row * N + col] = __float2bfloat16(acc); +} + +extern "C" __global__ void semiring_matmul_batched_bf16( + const __nv_bfloat16* __restrict__ A, + const __nv_bfloat16* __restrict__ B, + __nv_bfloat16* __restrict__ C, + unsigned int M, unsigned int N, unsigned int K, unsigned int op, + unsigned int batch_size, unsigned int a_batch_count, unsigned int b_batch_count +) { + unsigned int batch = blockIdx.z; + if (batch >= batch_size) return; + unsigned int row = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int col = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= M || col >= N) return; + + unsigned int a_offset = (batch % a_batch_count) * M * K; + unsigned int b_offset = (batch % b_batch_count) * K * N; + unsigned int c_offset = batch * M * N; + + float acc; + switch (op) { + case 0: case 3: acc = __int_as_float(0x7f800000); break; + case 1: case 2: acc = __int_as_float(0xff800000); break; + default: acc = 0.0f; break; + } + + for (unsigned int kk = 0; kk < K; kk++) { + float a_val = __bfloat162float(A[a_offset + row * K + kk]); + float b_val = __bfloat162float(B[b_offset + kk * N + col]); + float combined; + switch (op) { + case 0: case 1: combined = a_val + b_val; break; + case 2: combined = fminf(a_val, b_val); break; + case 3: case 5: combined = fmaxf(a_val, b_val); break; + case 4: combined = (a_val != 0.0f && b_val != 0.0f) ? 1.0f : 0.0f; break; + default: combined = a_val + b_val; break; + } + switch (op) { + case 0: case 3: acc = fminf(acc, combined); break; + case 1: case 2: acc = fmaxf(acc, combined); break; + case 4: if (combined != 0.0f) acc = 1.0f; break; + case 5: acc = acc + combined; break; + default: acc = fminf(acc, combined); break; + } + } + C[c_offset + row * N + col] = __float2bfloat16(acc); +} + +// ============================================================================ +// FP8 E4M3 Kernels (compute in F32) +// ============================================================================ + +extern "C" __global__ void semiring_matmul_fp8_e4m3( + const numr_fp8_e4m3* __restrict__ A, + const numr_fp8_e4m3* __restrict__ B, + numr_fp8_e4m3* __restrict__ C, + unsigned int M, unsigned int N, unsigned int K, unsigned int op +) { + unsigned int row = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int col = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= M || col >= N) return; + + float acc; + switch (op) { + case 0: case 3: acc = __int_as_float(0x7f800000); break; + case 1: case 2: acc = __int_as_float(0xff800000); break; + default: acc = 0.0f; break; + } + + for (unsigned int kk = 0; kk < K; kk++) { + float a_val = fp8_e4m3_to_f32(A[row * K + kk].data); + float b_val = fp8_e4m3_to_f32(B[kk * N + col].data); + float combined; + switch (op) { + case 0: case 1: combined = a_val + b_val; break; + case 2: combined = fminf(a_val, b_val); break; + case 3: case 5: combined = fmaxf(a_val, b_val); break; + case 4: combined = (a_val != 0.0f && b_val != 0.0f) ? 1.0f : 0.0f; break; + default: combined = a_val + b_val; break; + } + switch (op) { + case 0: case 3: acc = fminf(acc, combined); break; + case 1: case 2: acc = fmaxf(acc, combined); break; + case 4: if (combined != 0.0f) acc = 1.0f; break; + case 5: acc = acc + combined; break; + default: acc = fminf(acc, combined); break; + } + } + C[row * N + col].data = f32_to_fp8_e4m3(acc); +} + +extern "C" __global__ void semiring_matmul_batched_fp8_e4m3( + const numr_fp8_e4m3* __restrict__ A, + const numr_fp8_e4m3* __restrict__ B, + numr_fp8_e4m3* __restrict__ C, + unsigned int M, unsigned int N, unsigned int K, unsigned int op, + unsigned int batch_size, unsigned int a_batch_count, unsigned int b_batch_count +) { + unsigned int batch = blockIdx.z; + if (batch >= batch_size) return; + unsigned int row = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int col = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= M || col >= N) return; + + unsigned int a_offset = (batch % a_batch_count) * M * K; + unsigned int b_offset = (batch % b_batch_count) * K * N; + unsigned int c_offset = batch * M * N; + + float acc; + switch (op) { + case 0: case 3: acc = __int_as_float(0x7f800000); break; + case 1: case 2: acc = __int_as_float(0xff800000); break; + default: acc = 0.0f; break; + } + + for (unsigned int kk = 0; kk < K; kk++) { + float a_val = fp8_e4m3_to_f32(A[a_offset + row * K + kk].data); + float b_val = fp8_e4m3_to_f32(B[b_offset + kk * N + col].data); + float combined; + switch (op) { + case 0: case 1: combined = a_val + b_val; break; + case 2: combined = fminf(a_val, b_val); break; + case 3: case 5: combined = fmaxf(a_val, b_val); break; + case 4: combined = (a_val != 0.0f && b_val != 0.0f) ? 1.0f : 0.0f; break; + default: combined = a_val + b_val; break; + } + switch (op) { + case 0: case 3: acc = fminf(acc, combined); break; + case 1: case 2: acc = fmaxf(acc, combined); break; + case 4: if (combined != 0.0f) acc = 1.0f; break; + case 5: acc = acc + combined; break; + default: acc = fminf(acc, combined); break; + } + } + C[c_offset + row * N + col].data = f32_to_fp8_e4m3(acc); +} + +// ============================================================================ +// FP8 E5M2 Kernels (compute in F32) +// ============================================================================ + +extern "C" __global__ void semiring_matmul_fp8_e5m2( + const numr_fp8_e5m2* __restrict__ A, + const numr_fp8_e5m2* __restrict__ B, + numr_fp8_e5m2* __restrict__ C, + unsigned int M, unsigned int N, unsigned int K, unsigned int op +) { + unsigned int row = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int col = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= M || col >= N) return; + + float acc; + switch (op) { + case 0: case 3: acc = __int_as_float(0x7f800000); break; + case 1: case 2: acc = __int_as_float(0xff800000); break; + default: acc = 0.0f; break; + } + + for (unsigned int kk = 0; kk < K; kk++) { + float a_val = fp8_e5m2_to_f32(A[row * K + kk].data); + float b_val = fp8_e5m2_to_f32(B[kk * N + col].data); + float combined; + switch (op) { + case 0: case 1: combined = a_val + b_val; break; + case 2: combined = fminf(a_val, b_val); break; + case 3: case 5: combined = fmaxf(a_val, b_val); break; + case 4: combined = (a_val != 0.0f && b_val != 0.0f) ? 1.0f : 0.0f; break; + default: combined = a_val + b_val; break; + } + switch (op) { + case 0: case 3: acc = fminf(acc, combined); break; + case 1: case 2: acc = fmaxf(acc, combined); break; + case 4: if (combined != 0.0f) acc = 1.0f; break; + case 5: acc = acc + combined; break; + default: acc = fminf(acc, combined); break; + } + } + C[row * N + col].data = f32_to_fp8_e5m2(acc); +} + +extern "C" __global__ void semiring_matmul_batched_fp8_e5m2( + const numr_fp8_e5m2* __restrict__ A, + const numr_fp8_e5m2* __restrict__ B, + numr_fp8_e5m2* __restrict__ C, + unsigned int M, unsigned int N, unsigned int K, unsigned int op, + unsigned int batch_size, unsigned int a_batch_count, unsigned int b_batch_count +) { + unsigned int batch = blockIdx.z; + if (batch >= batch_size) return; + unsigned int row = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int col = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= M || col >= N) return; + + unsigned int a_offset = (batch % a_batch_count) * M * K; + unsigned int b_offset = (batch % b_batch_count) * K * N; + unsigned int c_offset = batch * M * N; + + float acc; + switch (op) { + case 0: case 3: acc = __int_as_float(0x7f800000); break; + case 1: case 2: acc = __int_as_float(0xff800000); break; + default: acc = 0.0f; break; + } + + for (unsigned int kk = 0; kk < K; kk++) { + float a_val = fp8_e5m2_to_f32(A[a_offset + row * K + kk].data); + float b_val = fp8_e5m2_to_f32(B[b_offset + kk * N + col].data); + float combined; + switch (op) { + case 0: case 1: combined = a_val + b_val; break; + case 2: combined = fminf(a_val, b_val); break; + case 3: case 5: combined = fmaxf(a_val, b_val); break; + case 4: combined = (a_val != 0.0f && b_val != 0.0f) ? 1.0f : 0.0f; break; + default: combined = a_val + b_val; break; + } + switch (op) { + case 0: case 3: acc = fminf(acc, combined); break; + case 1: case 2: acc = fmaxf(acc, combined); break; + case 4: if (combined != 0.0f) acc = 1.0f; break; + case 5: acc = acc + combined; break; + default: acc = fminf(acc, combined); break; + } + } + C[c_offset + row * N + col].data = f32_to_fp8_e5m2(acc); +} diff --git a/src/runtime/cuda/kernels/softmax.cu b/src/runtime/cuda/kernels/softmax.cu new file mode 100644 index 00000000..6046a9cb --- /dev/null +++ b/src/runtime/cuda/kernels/softmax.cu @@ -0,0 +1,863 @@ +// Softmax CUDA kernels (forward + backward) +// Supports: softmax (last-dim), softmax_dim (non-last-dim), softmax_bwd, softmax_bwd_dim +// Types: f32, f64, f16, bf16, fp8_e4m3, fp8_e5m2 + +#include +#include +#include "dtype_traits.cuh" + +extern "C" { + +// ============================================================================ +// Softmax Forward (Last Dimension) +// ============================================================================ + +__global__ void softmax_f32( + const float* input, float* output, + unsigned int outer_size, unsigned int dim_size +) { + unsigned int outer_idx = blockIdx.x; + if (outer_idx >= outer_size) return; + + extern __shared__ float shared[]; + float* max_val = shared; + float* sum_exp = shared + blockDim.x; + + const float* row_in = input + outer_idx * dim_size; + float* row_out = output + outer_idx * dim_size; + + // Phase 1: Find max value for numerical stability + float thread_max = -INFINITY; + for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { + thread_max = fmaxf(thread_max, row_in[i]); + } + max_val[threadIdx.x] = thread_max; + __syncthreads(); + + // Reduce max across threads + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + max_val[threadIdx.x] = fmaxf(max_val[threadIdx.x], max_val[threadIdx.x + s]); + } + __syncthreads(); + } + float row_max = max_val[0]; + __syncthreads(); + + // Phase 2: Compute exp(x - max) and sum + float thread_sum = 0.0f; + for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { + float val = expf(row_in[i] - row_max); + row_out[i] = val; // Temporarily store exp values + thread_sum += val; + } + sum_exp[threadIdx.x] = thread_sum; + __syncthreads(); + + // Reduce sum across threads + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + sum_exp[threadIdx.x] += sum_exp[threadIdx.x + s]; + } + __syncthreads(); + } + float row_sum = sum_exp[0]; + __syncthreads(); + + // Phase 3: Normalize + float inv_sum = 1.0f / row_sum; + for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { + row_out[i] *= inv_sum; + } +} + +__global__ void softmax_f64( + const double* input, double* output, + unsigned int outer_size, unsigned int dim_size +) { + unsigned int outer_idx = blockIdx.x; + if (outer_idx >= outer_size) return; + + extern __shared__ double shared_f64[]; + double* max_val = shared_f64; + double* sum_exp = shared_f64 + blockDim.x; + + const double* row_in = input + outer_idx * dim_size; + double* row_out = output + outer_idx * dim_size; + + double thread_max = -INFINITY; + for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { + thread_max = fmax(thread_max, row_in[i]); + } + max_val[threadIdx.x] = thread_max; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + max_val[threadIdx.x] = fmax(max_val[threadIdx.x], max_val[threadIdx.x + s]); + } + __syncthreads(); + } + double row_max = max_val[0]; + __syncthreads(); + + double thread_sum = 0.0; + for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { + double val = exp(row_in[i] - row_max); + row_out[i] = val; + thread_sum += val; + } + sum_exp[threadIdx.x] = thread_sum; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + sum_exp[threadIdx.x] += sum_exp[threadIdx.x + s]; + } + __syncthreads(); + } + double row_sum = sum_exp[0]; + __syncthreads(); + + double inv_sum = 1.0 / row_sum; + for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { + row_out[i] *= inv_sum; + } +} + +__global__ void softmax_f16( + const __half* input, __half* output, + unsigned int outer_size, unsigned int dim_size +) { + unsigned int outer_idx = blockIdx.x; + if (outer_idx >= outer_size) return; + + extern __shared__ float shared[]; + float* max_val = shared; + float* sum_exp = shared + blockDim.x; + + const __half* row_in = input + outer_idx * dim_size; + __half* row_out = output + outer_idx * dim_size; + + float thread_max = -INFINITY; + for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { + thread_max = fmaxf(thread_max, __half2float(row_in[i])); + } + max_val[threadIdx.x] = thread_max; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + max_val[threadIdx.x] = fmaxf(max_val[threadIdx.x], max_val[threadIdx.x + s]); + } + __syncthreads(); + } + float row_max = max_val[0]; + __syncthreads(); + + float thread_sum = 0.0f; + for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { + float val = expf(__half2float(row_in[i]) - row_max); + row_out[i] = __float2half(val); + thread_sum += val; + } + sum_exp[threadIdx.x] = thread_sum; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + sum_exp[threadIdx.x] += sum_exp[threadIdx.x + s]; + } + __syncthreads(); + } + float row_sum = sum_exp[0]; + __syncthreads(); + + float inv_sum = 1.0f / row_sum; + for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { + row_out[i] = __float2half(__half2float(row_out[i]) * inv_sum); + } +} + +__global__ void softmax_bf16( + const __nv_bfloat16* input, __nv_bfloat16* output, + unsigned int outer_size, unsigned int dim_size +) { + unsigned int outer_idx = blockIdx.x; + if (outer_idx >= outer_size) return; + + extern __shared__ float shared[]; + float* max_val = shared; + float* sum_exp = shared + blockDim.x; + + const __nv_bfloat16* row_in = input + outer_idx * dim_size; + __nv_bfloat16* row_out = output + outer_idx * dim_size; + + float thread_max = -INFINITY; + for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { + thread_max = fmaxf(thread_max, __bfloat162float(row_in[i])); + } + max_val[threadIdx.x] = thread_max; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + max_val[threadIdx.x] = fmaxf(max_val[threadIdx.x], max_val[threadIdx.x + s]); + } + __syncthreads(); + } + float row_max = max_val[0]; + __syncthreads(); + + float thread_sum = 0.0f; + for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { + float val = expf(__bfloat162float(row_in[i]) - row_max); + row_out[i] = __float2bfloat16(val); + thread_sum += val; + } + sum_exp[threadIdx.x] = thread_sum; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + sum_exp[threadIdx.x] += sum_exp[threadIdx.x + s]; + } + __syncthreads(); + } + float row_sum = sum_exp[0]; + __syncthreads(); + + float inv_sum = 1.0f / row_sum; + for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { + row_out[i] = __float2bfloat16(__bfloat162float(row_out[i]) * inv_sum); + } +} + +__global__ void softmax_fp8_e4m3( + const numr_fp8_e4m3* input, numr_fp8_e4m3* output, + unsigned int outer_size, unsigned int dim_size +) { + unsigned int outer_idx = blockIdx.x; + if (outer_idx >= outer_size) return; + + extern __shared__ float shared[]; + float* max_val = shared; + float* sum_exp = shared + blockDim.x; + + const numr_fp8_e4m3* row_in = input + outer_idx * dim_size; + numr_fp8_e4m3* row_out = output + outer_idx * dim_size; + + float thread_max = -INFINITY; + for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { + thread_max = fmaxf(thread_max, fp8_e4m3_to_f32(row_in[i].data)); + } + max_val[threadIdx.x] = thread_max; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + max_val[threadIdx.x] = fmaxf(max_val[threadIdx.x], max_val[threadIdx.x + s]); + } + __syncthreads(); + } + float row_max = max_val[0]; + __syncthreads(); + + float thread_sum = 0.0f; + for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { + float val = expf(fp8_e4m3_to_f32(row_in[i].data) - row_max); + row_out[i] = numr_fp8_e4m3(f32_to_fp8_e4m3(val)); + thread_sum += val; + } + sum_exp[threadIdx.x] = thread_sum; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + sum_exp[threadIdx.x] += sum_exp[threadIdx.x + s]; + } + __syncthreads(); + } + float row_sum = sum_exp[0]; + __syncthreads(); + + float inv_sum = 1.0f / row_sum; + for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { + row_out[i] = numr_fp8_e4m3(f32_to_fp8_e4m3(fp8_e4m3_to_f32(row_out[i].data) * inv_sum)); + } +} + +__global__ void softmax_fp8_e5m2( + const numr_fp8_e5m2* input, numr_fp8_e5m2* output, + unsigned int outer_size, unsigned int dim_size +) { + unsigned int outer_idx = blockIdx.x; + if (outer_idx >= outer_size) return; + + extern __shared__ float shared[]; + float* max_val = shared; + float* sum_exp = shared + blockDim.x; + + const numr_fp8_e5m2* row_in = input + outer_idx * dim_size; + numr_fp8_e5m2* row_out = output + outer_idx * dim_size; + + float thread_max = -INFINITY; + for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { + thread_max = fmaxf(thread_max, fp8_e5m2_to_f32(row_in[i].data)); + } + max_val[threadIdx.x] = thread_max; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + max_val[threadIdx.x] = fmaxf(max_val[threadIdx.x], max_val[threadIdx.x + s]); + } + __syncthreads(); + } + float row_max = max_val[0]; + __syncthreads(); + + float thread_sum = 0.0f; + for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { + float val = expf(fp8_e5m2_to_f32(row_in[i].data) - row_max); + row_out[i] = numr_fp8_e5m2(f32_to_fp8_e5m2(val)); + thread_sum += val; + } + sum_exp[threadIdx.x] = thread_sum; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + sum_exp[threadIdx.x] += sum_exp[threadIdx.x + s]; + } + __syncthreads(); + } + float row_sum = sum_exp[0]; + __syncthreads(); + + float inv_sum = 1.0f / row_sum; + for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { + row_out[i] = numr_fp8_e5m2(f32_to_fp8_e5m2(fp8_e5m2_to_f32(row_out[i].data) * inv_sum)); + } +} + +// ============================================================================ +// Softmax Forward (Non-Last Dimension) +// For shape [A, B, C] with softmax over dim=1: +// outer_size = A, dim_size = B, inner_size = C +// ============================================================================ + +__global__ void softmax_dim_f32( + const float* input, float* output, + unsigned int outer_size, unsigned int dim_size, unsigned int inner_size +) { + unsigned int outer_idx = blockIdx.x; + unsigned int inner_idx = blockIdx.y; + if (outer_idx >= outer_size || inner_idx >= inner_size) return; + + unsigned int base = outer_idx * dim_size * inner_size + inner_idx; + unsigned int stride = inner_size; + + float max_val = input[base]; + float sum = 1.0f; + for (unsigned int i = 1; i < dim_size; i++) { + float val = input[base + i * stride]; + if (val > max_val) { + sum = sum * expf(max_val - val) + 1.0f; + max_val = val; + } else { + sum += expf(val - max_val); + } + } + + float inv_sum = 1.0f / sum; + for (unsigned int i = 0; i < dim_size; i++) { + output[base + i * stride] = expf(input[base + i * stride] - max_val) * inv_sum; + } +} + +__global__ void softmax_dim_f64( + const double* input, double* output, + unsigned int outer_size, unsigned int dim_size, unsigned int inner_size +) { + unsigned int outer_idx = blockIdx.x; + unsigned int inner_idx = blockIdx.y; + if (outer_idx >= outer_size || inner_idx >= inner_size) return; + + unsigned int base = outer_idx * dim_size * inner_size + inner_idx; + unsigned int stride = inner_size; + + double max_val = input[base]; + double sum = 1.0; + for (unsigned int i = 1; i < dim_size; i++) { + double val = input[base + i * stride]; + if (val > max_val) { + sum = sum * exp(max_val - val) + 1.0; + max_val = val; + } else { + sum += exp(val - max_val); + } + } + + double inv_sum = 1.0 / sum; + for (unsigned int i = 0; i < dim_size; i++) { + output[base + i * stride] = exp(input[base + i * stride] - max_val) * inv_sum; + } +} + +__global__ void softmax_dim_f16( + const __half* input, __half* output, + unsigned int outer_size, unsigned int dim_size, unsigned int inner_size +) { + unsigned int outer_idx = blockIdx.x; + unsigned int inner_idx = blockIdx.y; + if (outer_idx >= outer_size || inner_idx >= inner_size) return; + + unsigned int base = outer_idx * dim_size * inner_size + inner_idx; + unsigned int stride = inner_size; + + float max_val = __half2float(input[base]); + float sum = 1.0f; + for (unsigned int i = 1; i < dim_size; i++) { + float val = __half2float(input[base + i * stride]); + if (val > max_val) { + sum = sum * expf(max_val - val) + 1.0f; + max_val = val; + } else { + sum += expf(val - max_val); + } + } + + float inv_sum = 1.0f / sum; + for (unsigned int i = 0; i < dim_size; i++) { + float val = __half2float(input[base + i * stride]); + output[base + i * stride] = __float2half(expf(val - max_val) * inv_sum); + } +} + +__global__ void softmax_dim_bf16( + const __nv_bfloat16* input, __nv_bfloat16* output, + unsigned int outer_size, unsigned int dim_size, unsigned int inner_size +) { + unsigned int outer_idx = blockIdx.x; + unsigned int inner_idx = blockIdx.y; + if (outer_idx >= outer_size || inner_idx >= inner_size) return; + + unsigned int base = outer_idx * dim_size * inner_size + inner_idx; + unsigned int stride = inner_size; + + float max_val = __bfloat162float(input[base]); + float sum = 1.0f; + for (unsigned int i = 1; i < dim_size; i++) { + float val = __bfloat162float(input[base + i * stride]); + if (val > max_val) { + sum = sum * expf(max_val - val) + 1.0f; + max_val = val; + } else { + sum += expf(val - max_val); + } + } + + float inv_sum = 1.0f / sum; + for (unsigned int i = 0; i < dim_size; i++) { + float val = __bfloat162float(input[base + i * stride]); + output[base + i * stride] = __float2bfloat16(expf(val - max_val) * inv_sum); + } +} + +__global__ void softmax_dim_fp8_e4m3( + const numr_fp8_e4m3* input, numr_fp8_e4m3* output, + unsigned int outer_size, unsigned int dim_size, unsigned int inner_size +) { + unsigned int outer_idx = blockIdx.x; + unsigned int inner_idx = blockIdx.y; + if (outer_idx >= outer_size || inner_idx >= inner_size) return; + + unsigned int base = outer_idx * dim_size * inner_size + inner_idx; + unsigned int stride = inner_size; + + float max_val = fp8_e4m3_to_f32(input[base].data); + float sum = 1.0f; + for (unsigned int i = 1; i < dim_size; i++) { + float val = fp8_e4m3_to_f32(input[base + i * stride].data); + if (val > max_val) { + sum = sum * expf(max_val - val) + 1.0f; + max_val = val; + } else { + sum += expf(val - max_val); + } + } + + float inv_sum = 1.0f / sum; + for (unsigned int i = 0; i < dim_size; i++) { + float val = fp8_e4m3_to_f32(input[base + i * stride].data); + output[base + i * stride] = numr_fp8_e4m3(f32_to_fp8_e4m3(expf(val - max_val) * inv_sum)); + } +} + +__global__ void softmax_dim_fp8_e5m2( + const numr_fp8_e5m2* input, numr_fp8_e5m2* output, + unsigned int outer_size, unsigned int dim_size, unsigned int inner_size +) { + unsigned int outer_idx = blockIdx.x; + unsigned int inner_idx = blockIdx.y; + if (outer_idx >= outer_size || inner_idx >= inner_size) return; + + unsigned int base = outer_idx * dim_size * inner_size + inner_idx; + unsigned int stride = inner_size; + + float max_val = fp8_e5m2_to_f32(input[base].data); + float sum = 1.0f; + for (unsigned int i = 1; i < dim_size; i++) { + float val = fp8_e5m2_to_f32(input[base + i * stride].data); + if (val > max_val) { + sum = sum * expf(max_val - val) + 1.0f; + max_val = val; + } else { + sum += expf(val - max_val); + } + } + + float inv_sum = 1.0f / sum; + for (unsigned int i = 0; i < dim_size; i++) { + float val = fp8_e5m2_to_f32(input[base + i * stride].data); + output[base + i * stride] = numr_fp8_e5m2(f32_to_fp8_e5m2(expf(val - max_val) * inv_sum)); + } +} + +// ============================================================================ +// Softmax Backward (Last Dimension) +// d_input = output * (grad - dot), where dot = sum(grad * output) +// ============================================================================ + +__global__ void softmax_bwd_f32( + const float* grad, const float* output, float* d_input, + unsigned int outer_size, unsigned int dim_size +) { + unsigned int outer_idx = blockIdx.x; + if (outer_idx >= outer_size) return; + + extern __shared__ float shared[]; + + const float* g_row = grad + outer_idx * dim_size; + const float* o_row = output + outer_idx * dim_size; + float* d_row = d_input + outer_idx * dim_size; + + float thread_dot = 0.0f; + for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { + thread_dot += g_row[i] * o_row[i]; + } + shared[threadIdx.x] = thread_dot; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) shared[threadIdx.x] += shared[threadIdx.x + s]; + __syncthreads(); + } + float dot = shared[0]; + __syncthreads(); + + for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { + d_row[i] = o_row[i] * (g_row[i] - dot); + } +} + +__global__ void softmax_bwd_f64( + const double* grad, const double* output, double* d_input, + unsigned int outer_size, unsigned int dim_size +) { + unsigned int outer_idx = blockIdx.x; + if (outer_idx >= outer_size) return; + + extern __shared__ double shared_d[]; + + const double* g_row = grad + outer_idx * dim_size; + const double* o_row = output + outer_idx * dim_size; + double* d_row = d_input + outer_idx * dim_size; + + double thread_dot = 0.0; + for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { + thread_dot += g_row[i] * o_row[i]; + } + shared_d[threadIdx.x] = thread_dot; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) shared_d[threadIdx.x] += shared_d[threadIdx.x + s]; + __syncthreads(); + } + double dot = shared_d[0]; + __syncthreads(); + + for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { + d_row[i] = o_row[i] * (g_row[i] - dot); + } +} + +__global__ void softmax_bwd_f16( + const __half* grad, const __half* output, __half* d_input, + unsigned int outer_size, unsigned int dim_size +) { + unsigned int outer_idx = blockIdx.x; + if (outer_idx >= outer_size) return; + + extern __shared__ float shared_f16[]; + + const __half* g_row = grad + outer_idx * dim_size; + const __half* o_row = output + outer_idx * dim_size; + __half* d_row = d_input + outer_idx * dim_size; + + float thread_dot = 0.0f; + for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { + thread_dot += __half2float(g_row[i]) * __half2float(o_row[i]); + } + shared_f16[threadIdx.x] = thread_dot; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) shared_f16[threadIdx.x] += shared_f16[threadIdx.x + s]; + __syncthreads(); + } + float dot = shared_f16[0]; + __syncthreads(); + + for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { + float g = __half2float(g_row[i]); + float o = __half2float(o_row[i]); + d_row[i] = __float2half(o * (g - dot)); + } +} + +__global__ void softmax_bwd_bf16( + const __nv_bfloat16* grad, const __nv_bfloat16* output, __nv_bfloat16* d_input, + unsigned int outer_size, unsigned int dim_size +) { + unsigned int outer_idx = blockIdx.x; + if (outer_idx >= outer_size) return; + + extern __shared__ float shared_bf16[]; + + const __nv_bfloat16* g_row = grad + outer_idx * dim_size; + const __nv_bfloat16* o_row = output + outer_idx * dim_size; + __nv_bfloat16* d_row = d_input + outer_idx * dim_size; + + float thread_dot = 0.0f; + for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { + thread_dot += __bfloat162float(g_row[i]) * __bfloat162float(o_row[i]); + } + shared_bf16[threadIdx.x] = thread_dot; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) shared_bf16[threadIdx.x] += shared_bf16[threadIdx.x + s]; + __syncthreads(); + } + float dot = shared_bf16[0]; + __syncthreads(); + + for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { + float g = __bfloat162float(g_row[i]); + float o = __bfloat162float(o_row[i]); + d_row[i] = __float2bfloat16(o * (g - dot)); + } +} + +__global__ void softmax_bwd_fp8_e4m3( + const numr_fp8_e4m3* grad, const numr_fp8_e4m3* output, numr_fp8_e4m3* d_input, + unsigned int outer_size, unsigned int dim_size +) { + unsigned int outer_idx = blockIdx.x; + if (outer_idx >= outer_size) return; + + extern __shared__ float shared_fp8[]; + + const numr_fp8_e4m3* g_row = grad + outer_idx * dim_size; + const numr_fp8_e4m3* o_row = output + outer_idx * dim_size; + numr_fp8_e4m3* d_row = d_input + outer_idx * dim_size; + + float thread_dot = 0.0f; + for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { + thread_dot += fp8_e4m3_to_f32(g_row[i].data) * fp8_e4m3_to_f32(o_row[i].data); + } + shared_fp8[threadIdx.x] = thread_dot; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) shared_fp8[threadIdx.x] += shared_fp8[threadIdx.x + s]; + __syncthreads(); + } + float dot = shared_fp8[0]; + __syncthreads(); + + for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { + float g = fp8_e4m3_to_f32(g_row[i].data); + float o = fp8_e4m3_to_f32(o_row[i].data); + d_row[i] = numr_fp8_e4m3(f32_to_fp8_e4m3(o * (g - dot))); + } +} + +__global__ void softmax_bwd_fp8_e5m2( + const numr_fp8_e5m2* grad, const numr_fp8_e5m2* output, numr_fp8_e5m2* d_input, + unsigned int outer_size, unsigned int dim_size +) { + unsigned int outer_idx = blockIdx.x; + if (outer_idx >= outer_size) return; + + extern __shared__ float shared_fp8e5[]; + + const numr_fp8_e5m2* g_row = grad + outer_idx * dim_size; + const numr_fp8_e5m2* o_row = output + outer_idx * dim_size; + numr_fp8_e5m2* d_row = d_input + outer_idx * dim_size; + + float thread_dot = 0.0f; + for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { + thread_dot += fp8_e5m2_to_f32(g_row[i].data) * fp8_e5m2_to_f32(o_row[i].data); + } + shared_fp8e5[threadIdx.x] = thread_dot; + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) shared_fp8e5[threadIdx.x] += shared_fp8e5[threadIdx.x + s]; + __syncthreads(); + } + float dot = shared_fp8e5[0]; + __syncthreads(); + + for (unsigned int i = threadIdx.x; i < dim_size; i += blockDim.x) { + float g = fp8_e5m2_to_f32(g_row[i].data); + float o = fp8_e5m2_to_f32(o_row[i].data); + d_row[i] = numr_fp8_e5m2(f32_to_fp8_e5m2(o * (g - dot))); + } +} + +// ============================================================================ +// Softmax Backward (Non-Last Dimension) +// ============================================================================ + +__global__ void softmax_bwd_dim_f32( + const float* grad, const float* output, float* d_input, + unsigned int outer_size, unsigned int dim_size, unsigned int inner_size +) { + unsigned int outer_idx = blockIdx.x; + unsigned int inner_idx = blockIdx.y; + if (outer_idx >= outer_size || inner_idx >= inner_size) return; + + unsigned int base = outer_idx * dim_size * inner_size + inner_idx; + unsigned int stride = inner_size; + + float dot = 0.0f; + for (unsigned int i = 0; i < dim_size; i++) { + dot += grad[base + i * stride] * output[base + i * stride]; + } + for (unsigned int i = 0; i < dim_size; i++) { + unsigned int idx = base + i * stride; + d_input[idx] = output[idx] * (grad[idx] - dot); + } +} + +__global__ void softmax_bwd_dim_f64( + const double* grad, const double* output, double* d_input, + unsigned int outer_size, unsigned int dim_size, unsigned int inner_size +) { + unsigned int outer_idx = blockIdx.x; + unsigned int inner_idx = blockIdx.y; + if (outer_idx >= outer_size || inner_idx >= inner_size) return; + + unsigned int base = outer_idx * dim_size * inner_size + inner_idx; + unsigned int stride = inner_size; + + double dot = 0.0; + for (unsigned int i = 0; i < dim_size; i++) { + dot += grad[base + i * stride] * output[base + i * stride]; + } + for (unsigned int i = 0; i < dim_size; i++) { + unsigned int idx = base + i * stride; + d_input[idx] = output[idx] * (grad[idx] - dot); + } +} + +__global__ void softmax_bwd_dim_f16( + const __half* grad, const __half* output, __half* d_input, + unsigned int outer_size, unsigned int dim_size, unsigned int inner_size +) { + unsigned int outer_idx = blockIdx.x; + unsigned int inner_idx = blockIdx.y; + if (outer_idx >= outer_size || inner_idx >= inner_size) return; + + unsigned int base = outer_idx * dim_size * inner_size + inner_idx; + unsigned int stride = inner_size; + + float dot = 0.0f; + for (unsigned int i = 0; i < dim_size; i++) { + dot += __half2float(grad[base + i * stride]) * __half2float(output[base + i * stride]); + } + for (unsigned int i = 0; i < dim_size; i++) { + unsigned int idx = base + i * stride; + d_input[idx] = __float2half(__half2float(output[idx]) * (__half2float(grad[idx]) - dot)); + } +} + +__global__ void softmax_bwd_dim_bf16( + const __nv_bfloat16* grad, const __nv_bfloat16* output, __nv_bfloat16* d_input, + unsigned int outer_size, unsigned int dim_size, unsigned int inner_size +) { + unsigned int outer_idx = blockIdx.x; + unsigned int inner_idx = blockIdx.y; + if (outer_idx >= outer_size || inner_idx >= inner_size) return; + + unsigned int base = outer_idx * dim_size * inner_size + inner_idx; + unsigned int stride = inner_size; + + float dot = 0.0f; + for (unsigned int i = 0; i < dim_size; i++) { + dot += __bfloat162float(grad[base + i * stride]) * __bfloat162float(output[base + i * stride]); + } + for (unsigned int i = 0; i < dim_size; i++) { + unsigned int idx = base + i * stride; + d_input[idx] = __float2bfloat16(__bfloat162float(output[idx]) * (__bfloat162float(grad[idx]) - dot)); + } +} + +__global__ void softmax_bwd_dim_fp8_e4m3( + const numr_fp8_e4m3* grad, const numr_fp8_e4m3* output, numr_fp8_e4m3* d_input, + unsigned int outer_size, unsigned int dim_size, unsigned int inner_size +) { + unsigned int outer_idx = blockIdx.x; + unsigned int inner_idx = blockIdx.y; + if (outer_idx >= outer_size || inner_idx >= inner_size) return; + + unsigned int base = outer_idx * dim_size * inner_size + inner_idx; + unsigned int stride = inner_size; + + float dot = 0.0f; + for (unsigned int i = 0; i < dim_size; i++) { + dot += fp8_e4m3_to_f32(grad[base + i * stride].data) * fp8_e4m3_to_f32(output[base + i * stride].data); + } + for (unsigned int i = 0; i < dim_size; i++) { + unsigned int idx = base + i * stride; + d_input[idx] = numr_fp8_e4m3(f32_to_fp8_e4m3(fp8_e4m3_to_f32(output[idx].data) * (fp8_e4m3_to_f32(grad[idx].data) - dot))); + } +} + +__global__ void softmax_bwd_dim_fp8_e5m2( + const numr_fp8_e5m2* grad, const numr_fp8_e5m2* output, numr_fp8_e5m2* d_input, + unsigned int outer_size, unsigned int dim_size, unsigned int inner_size +) { + unsigned int outer_idx = blockIdx.x; + unsigned int inner_idx = blockIdx.y; + if (outer_idx >= outer_size || inner_idx >= inner_size) return; + + unsigned int base = outer_idx * dim_size * inner_size + inner_idx; + unsigned int stride = inner_size; + + float dot = 0.0f; + for (unsigned int i = 0; i < dim_size; i++) { + dot += fp8_e5m2_to_f32(grad[base + i * stride].data) * fp8_e5m2_to_f32(output[base + i * stride].data); + } + for (unsigned int i = 0; i < dim_size; i++) { + unsigned int idx = base + i * stride; + d_input[idx] = numr_fp8_e5m2(f32_to_fp8_e5m2(fp8_e5m2_to_f32(output[idx].data) * (fp8_e5m2_to_f32(grad[idx].data) - dot))); + } +} + +} // extern "C" diff --git a/src/runtime/cuda/kernels/sort.rs b/src/runtime/cuda/kernels/sort.rs index ee450c00..21e8d162 100644 --- a/src/runtime/cuda/kernels/sort.rs +++ b/src/runtime/cuda/kernels/sort.rs @@ -26,6 +26,11 @@ fn sort_shared_mem_size(sort_size: usize, elem_size: usize) -> u32 { } /// Launch sort kernel with indices +/// +/// # Safety +/// +/// Caller must ensure all raw pointer arguments (`*_ptr`) point to valid GPU memory +/// allocated on `device_index` with sufficient size for the operation. pub unsafe fn launch_sort( context: &Arc, stream: &CudaStream, @@ -77,6 +82,11 @@ pub unsafe fn launch_sort( } /// Launch sort kernel (values only, no indices) +/// +/// # Safety +/// +/// Caller must ensure all raw pointer arguments (`*_ptr`) point to valid GPU memory +/// allocated on `device_index` with sufficient size for the operation. pub unsafe fn launch_sort_values_only( context: &Arc, stream: &CudaStream, @@ -128,6 +138,11 @@ pub unsafe fn launch_sort_values_only( } /// Launch argsort kernel (indices only, no values) +/// +/// # Safety +/// +/// Caller must ensure all raw pointer arguments (`*_ptr`) point to valid GPU memory +/// allocated on `device_index` with sufficient size for the operation. pub unsafe fn launch_argsort( context: &Arc, stream: &CudaStream, @@ -176,6 +191,11 @@ pub unsafe fn launch_argsort( } /// Launch topk kernel +/// +/// # Safety +/// +/// Caller must ensure all raw pointer arguments (`*_ptr`) point to valid GPU memory +/// allocated on `device_index` with sufficient size for the operation. pub unsafe fn launch_topk( context: &Arc, stream: &CudaStream, @@ -232,6 +252,11 @@ pub unsafe fn launch_topk( } /// Launch count_nonzero kernel +/// +/// # Safety +/// +/// Caller must ensure all raw pointer arguments (`*_ptr`) point to valid GPU memory +/// allocated on `device_index` with sufficient size for the operation. pub unsafe fn launch_count_nonzero( context: &Arc, stream: &CudaStream, @@ -268,6 +293,11 @@ pub unsafe fn launch_count_nonzero( } /// Launch gather_nonzero kernel +/// +/// # Safety +/// +/// Caller must ensure all raw pointer arguments (`*_ptr`) point to valid GPU memory +/// allocated on `device_index` with sufficient size for the operation. pub unsafe fn launch_gather_nonzero( context: &Arc, stream: &CudaStream, @@ -306,6 +336,11 @@ pub unsafe fn launch_gather_nonzero( } /// Launch flat_to_multi_index kernel +/// +/// # Safety +/// +/// Caller must ensure all raw pointer arguments (`*_ptr`) point to valid GPU memory +/// allocated on `device_index` with sufficient size for the operation. pub unsafe fn launch_flat_to_multi_index( context: &Arc, stream: &CudaStream, @@ -346,6 +381,11 @@ pub unsafe fn launch_flat_to_multi_index( } /// Launch searchsorted kernel +/// +/// # Safety +/// +/// Caller must ensure all raw pointer arguments (`*_ptr`) point to valid GPU memory +/// allocated on `device_index` with sufficient size for the operation. pub unsafe fn launch_searchsorted( context: &Arc, stream: &CudaStream, @@ -388,6 +428,11 @@ pub unsafe fn launch_searchsorted( } /// Launch count_unique kernel +/// +/// # Safety +/// +/// Caller must ensure all raw pointer arguments (`*_ptr`) point to valid GPU memory +/// allocated on `device_index` with sufficient size for the operation. pub unsafe fn launch_count_unique( context: &Arc, stream: &CudaStream, @@ -422,6 +467,11 @@ pub unsafe fn launch_count_unique( } /// Launch extract_unique kernel +/// +/// # Safety +/// +/// Caller must ensure all raw pointer arguments (`*_ptr`) point to valid GPU memory +/// allocated on `device_index` with sufficient size for the operation. pub unsafe fn launch_extract_unique( context: &Arc, stream: &CudaStream, @@ -458,6 +508,11 @@ pub unsafe fn launch_extract_unique( } /// Launch bincount kernel - counts occurrences of each index +/// +/// # Safety +/// +/// Caller must ensure all raw pointer arguments (`*_ptr`) point to valid GPU memory +/// allocated on `device_index` with sufficient size for the operation. pub unsafe fn launch_bincount( context: &Arc, stream: &CudaStream, diff --git a/src/runtime/cuda/kernels/sparse_24.cu b/src/runtime/cuda/kernels/sparse_24.cu new file mode 100644 index 00000000..6159ca62 --- /dev/null +++ b/src/runtime/cuda/kernels/sparse_24.cu @@ -0,0 +1,275 @@ +// 2:4 Structured Sparsity CUDA kernels +// Operations: prune to 2:4, decompress to dense, sparse matmul +// +// Metadata format: 4 bits per group of 4, bitmask with exactly 2 bits set. +// 8 groups packed per U32 (8 × 4 = 32 bits). + +#include +#include +#include "dtype_traits.cuh" + +// ============================================================================ +// Prune to 2:4: For each group of 4 elements, keep the 2 with largest magnitude +// ============================================================================ + +template +__device__ float to_abs_float(T val) { + return fabsf(static_cast(val)); +} + +__device__ float to_abs_float(__half val) { + return fabsf(__half2float(val)); +} + +__device__ float to_abs_float(__nv_bfloat16 val) { + return fabsf(__bfloat162float(val)); +} + +// One thread per group of 4 elements +template +__device__ void prune_to_24_impl( + const T* __restrict__ dense, // [M, K] + T* __restrict__ compressed, // [M, K/2] + unsigned int* __restrict__ metadata, // [M, meta_cols] + unsigned int M, + unsigned int K +) { + unsigned int num_groups_per_row = K / 4; + unsigned int meta_cols = (num_groups_per_row + 7) / 8; + unsigned int half_k = K / 2; + + unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int total_groups = M * num_groups_per_row; + if (tid >= total_groups) return; + + unsigned int row = tid / num_groups_per_row; + unsigned int g = tid % num_groups_per_row; + unsigned int base = row * K + g * 4; + + // Load 4 values + T vals[4]; + vals[0] = dense[base]; + vals[1] = dense[base + 1]; + vals[2] = dense[base + 2]; + vals[3] = dense[base + 3]; + + // Compute magnitudes + float mags[4]; + mags[0] = to_abs_float(vals[0]); + mags[1] = to_abs_float(vals[1]); + mags[2] = to_abs_float(vals[2]); + mags[3] = to_abs_float(vals[3]); + + // Find top-2 by magnitude (stable: prefer earlier indices on tie) + // Simple selection network for 4 elements + int idx0 = 0, idx1 = 1; + float m0 = mags[0], m1 = mags[1]; + + // Ensure m0 >= m1 + if (m1 > m0) { int t = idx0; idx0 = idx1; idx1 = t; float ft = m0; m0 = m1; m1 = ft; } + + // Compare with index 2 + if (mags[2] > m1) { + idx1 = 2; m1 = mags[2]; + if (m1 > m0) { int t = idx0; idx0 = idx1; idx1 = t; float ft = m0; m0 = m1; m1 = ft; } + } + + // Compare with index 3 + if (mags[3] > m1) { + idx1 = 3; m1 = mags[3]; + if (m1 > m0) { int t = idx0; idx0 = idx1; idx1 = t; } + } + + // Sort kept indices so lower index comes first + int first = min(idx0, idx1); + int second = max(idx0, idx1); + + // Write compressed values (2 per group) + unsigned int out_base = row * half_k + g * 2; + compressed[out_base] = vals[first]; + compressed[out_base + 1] = vals[second]; + + // Build 4-bit bitmask + unsigned int mask = (1u << first) | (1u << second); + + // Pack into metadata (atomic OR since multiple threads may write to same U32) + unsigned int word_idx = g / 8; + unsigned int nibble_idx = g % 8; + unsigned int meta_offset = row * meta_cols + word_idx; + atomicOr(&metadata[meta_offset], mask << (nibble_idx * 4)); +} + +// ============================================================================ +// Decompress: Reconstruct dense matrix from 2:4 compressed format +// ============================================================================ + +template +__device__ void decompress_24_impl( + const T* __restrict__ compressed, // [M, K/2] + const unsigned int* __restrict__ metadata, // [M, meta_cols] + T* __restrict__ dense, // [M, K] + unsigned int M, + unsigned int K +) { + unsigned int num_groups_per_row = K / 4; + unsigned int meta_cols = (num_groups_per_row + 7) / 8; + unsigned int half_k = K / 2; + + unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int total_groups = M * num_groups_per_row; + if (tid >= total_groups) return; + + unsigned int row = tid / num_groups_per_row; + unsigned int g = tid % num_groups_per_row; + + // Read metadata + unsigned int word_idx = g / 8; + unsigned int nibble_idx = g % 8; + unsigned int word = metadata[row * meta_cols + word_idx]; + unsigned int mask = (word >> (nibble_idx * 4)) & 0xF; + + // Read 2 compressed values + unsigned int in_base = row * half_k + g * 2; + T v0 = compressed[in_base]; + T v1 = compressed[in_base + 1]; + + // Write to dense (zero all 4 first, then fill kept positions) + unsigned int out_base = row * K + g * 4; + T zero = static_cast(0); + dense[out_base] = zero; + dense[out_base + 1] = zero; + dense[out_base + 2] = zero; + dense[out_base + 3] = zero; + + // Place values at their positions + int val_idx = 0; + for (int bit = 0; bit < 4; bit++) { + if (mask & (1u << bit)) { + dense[out_base + bit] = (val_idx == 0) ? v0 : v1; + val_idx++; + } + } +} + +// ============================================================================ +// Sparse 2:4 MatMul: C = A @ B^T where B is in 2:4 compressed format +// A: [N, K] dense, B: [M, K] compressed as [M, K/2] + metadata → C: [N, M] +// +// Each thread computes one element of C by decompressing B on the fly. +// Tiled with shared memory for better performance. +// ============================================================================ + +#define TILE_SIZE 16 + +template +__device__ void sparse_24_matmul_impl( + const T* __restrict__ A, // [N, K] dense input + const T* __restrict__ B_compressed, // [M, K/2] compressed weights + const unsigned int* __restrict__ B_metadata, // [M, meta_cols] + T* __restrict__ C, // [N, M] output + unsigned int N, + unsigned int M, + unsigned int K +) { + unsigned int num_groups = K / 4; + unsigned int meta_cols = (num_groups + 7) / 8; + unsigned int half_k = K / 2; + + unsigned int row = blockIdx.y * TILE_SIZE + threadIdx.y; // output row (N dim) + unsigned int col = blockIdx.x * TILE_SIZE + threadIdx.x; // output col (M dim) + + if (row >= N || col >= M) return; + + AccT sum = static_cast(0); + + // For each group of 4 in K dimension + for (unsigned int g = 0; g < num_groups; g++) { + // Read A values (dense, 4 consecutive) + unsigned int a_base = row * K + g * 4; + AccT a0 = static_cast(A[a_base]); + AccT a1 = static_cast(A[a_base + 1]); + AccT a2 = static_cast(A[a_base + 2]); + AccT a3 = static_cast(A[a_base + 3]); + + // Read B compressed values (2 per group) + unsigned int b_base = col * half_k + g * 2; + AccT b0 = static_cast(B_compressed[b_base]); + AccT b1 = static_cast(B_compressed[b_base + 1]); + + // Read B metadata + unsigned int word_idx = g / 8; + unsigned int nibble_idx = g % 8; + unsigned int word = B_metadata[col * meta_cols + word_idx]; + unsigned int mask = (word >> (nibble_idx * 4)) & 0xF; + + // Decompress and accumulate on the fly + AccT a_vals[4] = {a0, a1, a2, a3}; + int val_idx = 0; + for (int bit = 0; bit < 4; bit++) { + if (mask & (1u << bit)) { + AccT b_val = (val_idx == 0) ? b0 : b1; + sum += a_vals[bit] * b_val; + val_idx++; + } + } + } + + C[row * M + col] = static_cast(sum); +} + +// ============================================================================ +// F16/BF16 specialization: decompress kernel (same logic, no special accumulation needed) +// ============================================================================ + +// For F16 decompress, the template works directly since we just copy values. +// For F16 matmul, we accumulate in F32. + +// ============================================================================ +// Extern "C" instantiations +// ============================================================================ + +extern "C" { + +// --- Prune --- +__global__ void sparse_24_prune_f32(const float* d, float* c, unsigned int* m, unsigned int M, unsigned int K) { + prune_to_24_impl(d, c, m, M, K); +} +__global__ void sparse_24_prune_f64(const double* d, double* c, unsigned int* m, unsigned int M, unsigned int K) { + prune_to_24_impl(d, c, m, M, K); +} +__global__ void sparse_24_prune_f16(const __half* d, __half* c, unsigned int* m, unsigned int M, unsigned int K) { + prune_to_24_impl<__half>(d, c, m, M, K); +} +__global__ void sparse_24_prune_bf16(const __nv_bfloat16* d, __nv_bfloat16* c, unsigned int* m, unsigned int M, unsigned int K) { + prune_to_24_impl<__nv_bfloat16>(d, c, m, M, K); +} + +// --- Decompress --- +__global__ void sparse_24_decompress_f32(const float* c, const unsigned int* m, float* d, unsigned int M, unsigned int K) { + decompress_24_impl(c, m, d, M, K); +} +__global__ void sparse_24_decompress_f64(const double* c, const unsigned int* m, double* d, unsigned int M, unsigned int K) { + decompress_24_impl(c, m, d, M, K); +} +__global__ void sparse_24_decompress_f16(const __half* c, const unsigned int* m, __half* d, unsigned int M, unsigned int K) { + decompress_24_impl<__half>(c, m, d, M, K); +} +__global__ void sparse_24_decompress_bf16(const __nv_bfloat16* c, const unsigned int* m, __nv_bfloat16* d, unsigned int M, unsigned int K) { + decompress_24_impl<__nv_bfloat16>(c, m, d, M, K); +} + +// --- Matmul (accumulate in appropriate precision) --- +__global__ void sparse_24_matmul_f32(const float* A, const float* Bc, const unsigned int* Bm, float* C, unsigned int N, unsigned int M, unsigned int K) { + sparse_24_matmul_impl(A, Bc, Bm, C, N, M, K); +} +__global__ void sparse_24_matmul_f64(const double* A, const double* Bc, const unsigned int* Bm, double* C, unsigned int N, unsigned int M, unsigned int K) { + sparse_24_matmul_impl(A, Bc, Bm, C, N, M, K); +} +__global__ void sparse_24_matmul_f16(const __half* A, const __half* Bc, const unsigned int* Bm, __half* C, unsigned int N, unsigned int M, unsigned int K) { + sparse_24_matmul_impl<__half, float>(A, Bc, Bm, C, N, M, K); +} +__global__ void sparse_24_matmul_bf16(const __nv_bfloat16* A, const __nv_bfloat16* Bc, const unsigned int* Bm, __nv_bfloat16* C, unsigned int N, unsigned int M, unsigned int K) { + sparse_24_matmul_impl<__nv_bfloat16, float>(A, Bc, Bm, C, N, M, K); +} + +} // extern "C" diff --git a/src/runtime/cuda/kernels/sparse_24_launcher.rs b/src/runtime/cuda/kernels/sparse_24_launcher.rs new file mode 100644 index 00000000..95bc4a34 --- /dev/null +++ b/src/runtime/cuda/kernels/sparse_24_launcher.rs @@ -0,0 +1,149 @@ +//! CUDA kernel launchers for 2:4 structured sparsity +//! +//! Kernel source: sparse_24.cu + +use cudarc::driver::PushKernelArg; +use cudarc::driver::safe::{CudaContext, CudaStream}; +use std::sync::Arc; + +use crate::dtype::DType; +use crate::error::{Error, Result}; +use crate::runtime::cuda::kernels::loader::{ + BLOCK_SIZE, elementwise_launch_config, get_kernel_function, get_or_load_module, kernel_name, + launch_config, +}; + +const MODULE_NAME: &str = "sparse_24"; + +/// Launch prune-to-2:4 kernel. +/// +/// # Safety +/// All pointers must be valid device memory of correct size. +pub unsafe fn launch_sparse_24_prune( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + dense_ptr: u64, + compressed_ptr: u64, + metadata_ptr: u64, + m: usize, + k: usize, +) -> Result<()> { + let module = get_or_load_module(context, device_index, MODULE_NAME)?; + let func_name = kernel_name("sparse_24_prune", dtype); + let func = get_kernel_function(&module, &func_name)?; + + let total_groups = (m * (k / 4)) as u32; + let grid = elementwise_launch_config(total_groups as usize); + let block = (BLOCK_SIZE, 1, 1); + let cfg = launch_config(grid, block, 0); + + let m_u32 = m as u32; + let k_u32 = k as u32; + + unsafe { + let mut builder = stream.launch_builder(&func); + builder.arg(&dense_ptr); + builder.arg(&compressed_ptr); + builder.arg(&metadata_ptr); + builder.arg(&m_u32); + builder.arg(&k_u32); + builder + .launch(cfg) + .map_err(|e| Error::Internal(format!("CUDA sparse_24_prune launch failed: {e:?}")))?; + } + + Ok(()) +} + +/// Launch decompress-from-2:4 kernel. +/// +/// # Safety +/// All pointers must be valid device memory of correct size. +pub unsafe fn launch_sparse_24_decompress( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + compressed_ptr: u64, + metadata_ptr: u64, + dense_ptr: u64, + m: usize, + k: usize, +) -> Result<()> { + let module = get_or_load_module(context, device_index, MODULE_NAME)?; + let func_name = kernel_name("sparse_24_decompress", dtype); + let func = get_kernel_function(&module, &func_name)?; + + let total_groups = (m * (k / 4)) as u32; + let grid = elementwise_launch_config(total_groups as usize); + let block = (BLOCK_SIZE, 1, 1); + let cfg = launch_config(grid, block, 0); + + let m_u32 = m as u32; + let k_u32 = k as u32; + + unsafe { + let mut builder = stream.launch_builder(&func); + builder.arg(&compressed_ptr); + builder.arg(&metadata_ptr); + builder.arg(&dense_ptr); + builder.arg(&m_u32); + builder.arg(&k_u32); + builder.launch(cfg).map_err(|e| { + Error::Internal(format!("CUDA sparse_24_decompress launch failed: {e:?}")) + })?; + } + + Ok(()) +} + +/// Launch 2:4 sparse matmul kernel: C = A @ B^T where B is 2:4 compressed. +/// +/// # Safety +/// All pointers must be valid device memory of correct size. +pub unsafe fn launch_sparse_24_matmul( + context: &Arc, + stream: &CudaStream, + device_index: usize, + dtype: DType, + a_ptr: u64, // [N, K] + b_compressed_ptr: u64, // [M, K/2] + b_metadata_ptr: u64, // [M, meta_cols] + c_ptr: u64, // [N, M] + n: usize, + m: usize, + k: usize, +) -> Result<()> { + let module = get_or_load_module(context, device_index, MODULE_NAME)?; + let func_name = kernel_name("sparse_24_matmul", dtype); + let func = get_kernel_function(&module, &func_name)?; + + let tile_size = 16u32; + let grid_x = (m as u32 + tile_size - 1) / tile_size; + let grid_y = (n as u32 + tile_size - 1) / tile_size; + let grid = (grid_x, grid_y, 1); + let block = (tile_size, tile_size, 1); + let cfg = launch_config(grid, block, 0); + + let n_u32 = n as u32; + let m_u32 = m as u32; + let k_u32 = k as u32; + + unsafe { + let mut builder = stream.launch_builder(&func); + builder.arg(&a_ptr); + builder.arg(&b_compressed_ptr); + builder.arg(&b_metadata_ptr); + builder.arg(&c_ptr); + builder.arg(&n_u32); + builder.arg(&m_u32); + builder.arg(&k_u32); + builder + .launch(cfg) + .map_err(|e| Error::Internal(format!("CUDA sparse_24_matmul launch failed: {e:?}")))?; + } + + Ok(()) +} diff --git a/src/runtime/cuda/kernels/sparse_coo/kernels.rs b/src/runtime/cuda/kernels/sparse_coo/kernels.rs index 85197155..229d4d88 100644 --- a/src/runtime/cuda/kernels/sparse_coo/kernels.rs +++ b/src/runtime/cuda/kernels/sparse_coo/kernels.rs @@ -590,8 +590,15 @@ pub(crate) unsafe fn launch_coo_compact( // GPU Sort using Thrust // ============================================================================ -/// Sort (i64 keys, i32 indices) using Thrust stable_sort_by_key - FULLY ON GPU -/// Sorts IN-PLACE, so keys and indices are both input and output +/// Sort (i64 keys, i32 indices) using Thrust `stable_sort_by_key` - fully on GPU +/// +/// Sorts in-place: both `keys` and `indices` serve as input and output after sorting. +/// +/// # Safety +/// +/// - `keys` must be a valid device memory pointer with at least `n` i64 elements. +/// - `indices` must be a valid device memory pointer with at least `n` i32 elements. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn launch_thrust_sort_pairs_i64_i32( context: &Arc, stream: &CudaStream, @@ -629,7 +636,12 @@ pub unsafe fn launch_thrust_sort_pairs_i64_i32( // Index and Gather Kernel Launchers // ============================================================================ -/// Initialize indices array [0, 1, 2, ..., n-1] +/// Initialize indices array `[0, 1, 2, ..., n-1]` on device +/// +/// # Safety +/// +/// - `indices` must be a valid device memory pointer with at least `n` i32 elements. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn launch_coo_init_indices( context: &Arc, stream: &CudaStream, @@ -659,7 +671,14 @@ pub unsafe fn launch_coo_init_indices( Ok(()) } -/// Gather values using indices (permutation) +/// Gather values using a permutation index: `values_out[i] = values_in[indices[i]]` +/// +/// # Safety +/// +/// - `values_in`, `indices`, and `values_out` must be valid device memory pointers on the device +/// associated with `context`, each with at least `n` elements of their respective types. +/// - All values in `indices` must be valid indices into `values_in` (no out-of-bounds access). +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn launch_coo_gather( context: &Arc, stream: &CudaStream, @@ -736,7 +755,16 @@ pub(crate) unsafe fn launch_coo_gather_i32( Ok(()) } -/// Gather i64 values using indices (for row/col indices) +/// Gather i64 values using a permutation index: `values_out[i] = values_in[indices[i]]` +/// +/// Used for permuting row/col index arrays in COO format. +/// +/// # Safety +/// +/// - `values_in`, `indices`, and `values_out` must be valid device memory pointers on the device +/// associated with `context`, each with at least `n` elements of their respective types. +/// - All values in `indices` (i32) must be valid indices into `values_in` (i64 array). +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn launch_coo_gather_i64( context: &Arc, stream: &CudaStream, diff --git a/src/runtime/cuda/kernels/sparse_coo/merge.rs b/src/runtime/cuda/kernels/sparse_coo/merge.rs index c8b12391..2b742b69 100644 --- a/src/runtime/cuda/kernels/sparse_coo/merge.rs +++ b/src/runtime/cuda/kernels/sparse_coo/merge.rs @@ -16,7 +16,7 @@ use crate::runtime::Runtime; use crate::runtime::cuda::CudaRuntime; use crate::tensor::Tensor; -/// Perform COO add merge (A + B) on GPU +/// Perform COO add merge (A + B) on GPU (union semantics) /// /// Uses the following algorithm: /// 1. Compute composite keys for both matrices @@ -27,6 +27,13 @@ use crate::tensor::Tensor; /// 6. Merge duplicates with addition /// 7. Filter out zeros /// 8. Extract row/col indices from keys +/// +/// # Safety +/// +/// - All tensor arguments must be valid `CudaRuntime` tensors on the device associated with +/// `context`, with consistent COO structure (matching lengths of row/col index and value arrays). +/// - `shape` must match the logical matrix dimensions (`[nrows, ncols]`). +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn coo_add_merge( context: &Arc, stream: &CudaStream, @@ -68,9 +75,9 @@ pub unsafe fn coo_add_merge( context, stream, device_index, - row_indices_a.storage().ptr(), - col_indices_a.storage().ptr(), - keys_a.storage().ptr(), + row_indices_a.ptr(), + col_indices_a.ptr(), + keys_a.ptr(), ncols as i64, nnz_a, )?; @@ -81,9 +88,9 @@ pub unsafe fn coo_add_merge( context, stream, device_index, - row_indices_b.storage().ptr(), - col_indices_b.storage().ptr(), - keys_b.storage().ptr(), + row_indices_b.ptr(), + col_indices_b.ptr(), + keys_b.ptr(), ncols as i64, nnz_b, )?; @@ -98,9 +105,9 @@ pub unsafe fn coo_add_merge( context, stream, device_index, - keys_a.storage().ptr(), - keys_b.storage().ptr(), - concat_keys.storage().ptr(), + keys_a.ptr(), + keys_b.ptr(), + concat_keys.ptr(), nnz_a, nnz_b, )?; @@ -109,23 +116,17 @@ pub unsafe fn coo_add_merge( context, stream, device_index, - values_a.storage().ptr(), - values_b.storage().ptr(), - concat_values.storage().ptr(), - concat_sources.storage().ptr(), + values_a.ptr(), + values_b.ptr(), + concat_values.ptr(), + concat_sources.ptr(), nnz_a, nnz_b, )?; // Step 3: Initialize indices array [0, 1, 2, ..., total-1] on GPU let indices = Tensor::::zeros(&[total], DType::I32, device); - launch_coo_init_indices( - context, - stream, - device_index, - indices.storage().ptr(), - total, - )?; + launch_coo_init_indices(context, stream, device_index, indices.ptr(), total)?; // Step 4: Sort (keys, indices) using Thrust stable_sort_by_key - FULLY ON GPU // Thrust sorts IN-PLACE, so we sort concat_keys and indices directly @@ -134,8 +135,8 @@ pub unsafe fn coo_add_merge( context, stream, device_index, - concat_keys.storage().ptr(), - indices.storage().ptr(), + concat_keys.ptr(), + indices.ptr(), total as u32, )?; } @@ -150,9 +151,9 @@ pub unsafe fn coo_add_merge( context, stream, device_index, - concat_values.storage().ptr(), - indices.storage().ptr(), // indices is now sorted - sorted_values.storage().ptr(), + concat_values.ptr(), + indices.ptr(), // indices is now sorted + sorted_values.ptr(), total, )?; @@ -160,9 +161,9 @@ pub unsafe fn coo_add_merge( context, stream, device_index, - concat_sources.storage().ptr(), - indices.storage().ptr(), // indices is now sorted - sorted_sources.storage().ptr(), + concat_sources.ptr(), + indices.ptr(), // indices is now sorted + sorted_sources.ptr(), total, )?; @@ -172,8 +173,8 @@ pub unsafe fn coo_add_merge( context, stream, device_index, - concat_keys.storage().ptr(), // concat_keys is now sorted - unique_flags.storage().ptr(), + concat_keys.ptr(), // concat_keys is now sorted + unique_flags.ptr(), total, )?; @@ -196,26 +197,26 @@ pub unsafe fn coo_add_merge( context, stream, device_index, - concat_keys.storage().ptr(), // concat_keys is sorted - sorted_values.storage().ptr(), - sorted_sources.storage().ptr(), - unique_flags.storage().ptr(), - output_positions.storage().ptr(), - merged_keys.storage().ptr(), - merged_values.storage().ptr(), + concat_keys.ptr(), // concat_keys is sorted + sorted_values.ptr(), + sorted_sources.ptr(), + unique_flags.ptr(), + output_positions.ptr(), + merged_keys.ptr(), + merged_values.ptr(), total, )?; // Step 9: Filter out zeros - ALL ON GPU (using CUB) - let threshold = crate::runtime::sparse_utils::zero_tolerance::(); + let threshold = crate::runtime::common::sparse_utils::zero_tolerance::(); let nonzero_flags = Tensor::::zeros(&[num_unique], DType::I32, device); launch_coo_mark_nonzero::( context, stream, device_index, - merged_values.storage().ptr(), - nonzero_flags.storage().ptr(), + merged_values.ptr(), + nonzero_flags.ptr(), threshold, num_unique, )?; @@ -239,12 +240,12 @@ pub unsafe fn coo_add_merge( context, stream, device_index, - merged_keys.storage().ptr(), - merged_values.storage().ptr(), - nonzero_flags.storage().ptr(), - compact_positions.storage().ptr(), - final_keys.storage().ptr(), - final_values.storage().ptr(), + merged_keys.ptr(), + merged_values.ptr(), + nonzero_flags.ptr(), + compact_positions.ptr(), + final_keys.ptr(), + final_values.ptr(), num_unique, )?; @@ -256,9 +257,9 @@ pub unsafe fn coo_add_merge( context, stream, device_index, - final_keys.storage().ptr(), - final_row_indices.storage().ptr(), - final_col_indices.storage().ptr(), + final_keys.ptr(), + final_row_indices.ptr(), + final_col_indices.ptr(), ncols as i64, nnz_out, )?; @@ -267,7 +268,7 @@ pub unsafe fn coo_add_merge( Ok((final_row_indices, final_col_indices, final_values)) } -/// Perform COO sub merge (A - B) on GPU +/// Perform COO sub merge (A - B) on GPU (union semantics) /// /// Uses the following algorithm: /// 1. Compute composite keys for both matrices @@ -278,6 +279,13 @@ pub unsafe fn coo_add_merge( /// 6. Merge duplicates with subtraction (union semantics) /// 7. Filter out zeros /// 8. Extract row/col indices from keys +/// +/// # Safety +/// +/// - All tensor arguments must be valid `CudaRuntime` tensors on the device associated with +/// `context`, with consistent COO structure (matching lengths of row/col index and value arrays). +/// - `shape` must match the logical matrix dimensions (`[nrows, ncols]`). +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn coo_sub_merge( context: &Arc, stream: &CudaStream, @@ -319,9 +327,9 @@ pub unsafe fn coo_sub_merge( context, stream, device_index, - row_indices_a.storage().ptr(), - col_indices_a.storage().ptr(), - keys_a.storage().ptr(), + row_indices_a.ptr(), + col_indices_a.ptr(), + keys_a.ptr(), ncols as i64, nnz_a, )?; @@ -332,9 +340,9 @@ pub unsafe fn coo_sub_merge( context, stream, device_index, - row_indices_b.storage().ptr(), - col_indices_b.storage().ptr(), - keys_b.storage().ptr(), + row_indices_b.ptr(), + col_indices_b.ptr(), + keys_b.ptr(), ncols as i64, nnz_b, )?; @@ -349,9 +357,9 @@ pub unsafe fn coo_sub_merge( context, stream, device_index, - keys_a.storage().ptr(), - keys_b.storage().ptr(), - concat_keys.storage().ptr(), + keys_a.ptr(), + keys_b.ptr(), + concat_keys.ptr(), nnz_a, nnz_b, )?; @@ -360,23 +368,17 @@ pub unsafe fn coo_sub_merge( context, stream, device_index, - values_a.storage().ptr(), - values_b.storage().ptr(), - concat_values.storage().ptr(), - concat_sources.storage().ptr(), + values_a.ptr(), + values_b.ptr(), + concat_values.ptr(), + concat_sources.ptr(), nnz_a, nnz_b, )?; // Step 3: Initialize indices array [0, 1, 2, ..., total-1] on GPU let indices = Tensor::::zeros(&[total], DType::I32, device); - launch_coo_init_indices( - context, - stream, - device_index, - indices.storage().ptr(), - total, - )?; + launch_coo_init_indices(context, stream, device_index, indices.ptr(), total)?; // Step 4: Sort (keys, indices) using Thrust stable_sort_by_key - FULLY ON GPU unsafe { @@ -384,8 +386,8 @@ pub unsafe fn coo_sub_merge( context, stream, device_index, - concat_keys.storage().ptr(), - indices.storage().ptr(), + concat_keys.ptr(), + indices.ptr(), total as u32, )?; } @@ -398,9 +400,9 @@ pub unsafe fn coo_sub_merge( context, stream, device_index, - concat_values.storage().ptr(), - indices.storage().ptr(), - sorted_values.storage().ptr(), + concat_values.ptr(), + indices.ptr(), + sorted_values.ptr(), total, )?; @@ -408,9 +410,9 @@ pub unsafe fn coo_sub_merge( context, stream, device_index, - concat_sources.storage().ptr(), - indices.storage().ptr(), - sorted_sources.storage().ptr(), + concat_sources.ptr(), + indices.ptr(), + sorted_sources.ptr(), total, )?; @@ -420,8 +422,8 @@ pub unsafe fn coo_sub_merge( context, stream, device_index, - concat_keys.storage().ptr(), - unique_flags.storage().ptr(), + concat_keys.ptr(), + unique_flags.ptr(), total, )?; @@ -444,26 +446,26 @@ pub unsafe fn coo_sub_merge( context, stream, device_index, - concat_keys.storage().ptr(), - sorted_values.storage().ptr(), - sorted_sources.storage().ptr(), - output_positions.storage().ptr(), - merged_keys.storage().ptr(), - merged_values.storage().ptr(), + concat_keys.ptr(), + sorted_values.ptr(), + sorted_sources.ptr(), + output_positions.ptr(), + merged_keys.ptr(), + merged_values.ptr(), total, num_unique, )?; // Step 9: Filter out zeros - ALL ON GPU (using CUB) - let threshold = crate::runtime::sparse_utils::zero_tolerance::(); + let threshold = crate::runtime::common::sparse_utils::zero_tolerance::(); let nonzero_flags = Tensor::::zeros(&[num_unique], DType::I32, device); launch_coo_mark_nonzero::( context, stream, device_index, - merged_values.storage().ptr(), - nonzero_flags.storage().ptr(), + merged_values.ptr(), + nonzero_flags.ptr(), threshold, num_unique, )?; @@ -487,12 +489,12 @@ pub unsafe fn coo_sub_merge( context, stream, device_index, - merged_keys.storage().ptr(), - merged_values.storage().ptr(), - nonzero_flags.storage().ptr(), - compact_positions.storage().ptr(), - final_keys.storage().ptr(), - final_values.storage().ptr(), + merged_keys.ptr(), + merged_values.ptr(), + nonzero_flags.ptr(), + compact_positions.ptr(), + final_keys.ptr(), + final_values.ptr(), num_unique, )?; @@ -504,9 +506,9 @@ pub unsafe fn coo_sub_merge( context, stream, device_index, - final_keys.storage().ptr(), - final_row_indices.storage().ptr(), - final_col_indices.storage().ptr(), + final_keys.ptr(), + final_row_indices.ptr(), + final_col_indices.ptr(), ncols as i64, nnz_out, )?; @@ -526,6 +528,13 @@ pub unsafe fn coo_sub_merge( /// 6. Merge intersections with multiplication /// 7. Filter out zeros /// 8. Extract row/col indices from keys +/// +/// # Safety +/// +/// - All tensor arguments must be valid `CudaRuntime` tensors on the device associated with +/// `context`, with consistent COO structure (matching lengths of row/col index and value arrays). +/// - `shape` must match the logical matrix dimensions (`[nrows, ncols]`). +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn coo_mul_merge( context: &Arc, stream: &CudaStream, @@ -566,9 +575,9 @@ pub unsafe fn coo_mul_merge( context, stream, device_index, - row_indices_a.storage().ptr(), - col_indices_a.storage().ptr(), - keys_a.storage().ptr(), + row_indices_a.ptr(), + col_indices_a.ptr(), + keys_a.ptr(), ncols as i64, nnz_a, )?; @@ -577,9 +586,9 @@ pub unsafe fn coo_mul_merge( context, stream, device_index, - row_indices_b.storage().ptr(), - col_indices_b.storage().ptr(), - keys_b.storage().ptr(), + row_indices_b.ptr(), + col_indices_b.ptr(), + keys_b.ptr(), ncols as i64, nnz_b, )?; @@ -593,9 +602,9 @@ pub unsafe fn coo_mul_merge( context, stream, device_index, - keys_a.storage().ptr(), - keys_b.storage().ptr(), - concat_keys.storage().ptr(), + keys_a.ptr(), + keys_b.ptr(), + concat_keys.ptr(), nnz_a, nnz_b, )?; @@ -604,23 +613,17 @@ pub unsafe fn coo_mul_merge( context, stream, device_index, - values_a.storage().ptr(), - values_b.storage().ptr(), - concat_values.storage().ptr(), - concat_sources.storage().ptr(), + values_a.ptr(), + values_b.ptr(), + concat_values.ptr(), + concat_sources.ptr(), nnz_a, nnz_b, )?; // Step 3: Initialize indices array [0, 1, 2, ..., total-1] on GPU let indices = Tensor::::zeros(&[total], DType::I32, device); - launch_coo_init_indices( - context, - stream, - device_index, - indices.storage().ptr(), - total, - )?; + launch_coo_init_indices(context, stream, device_index, indices.ptr(), total)?; // Step 4: Sort (keys, indices) using Thrust stable_sort_by_key - FULLY ON GPU unsafe { @@ -628,8 +631,8 @@ pub unsafe fn coo_mul_merge( context, stream, device_index, - concat_keys.storage().ptr(), - indices.storage().ptr(), + concat_keys.ptr(), + indices.ptr(), total as u32, )?; } @@ -642,9 +645,9 @@ pub unsafe fn coo_mul_merge( context, stream, device_index, - concat_values.storage().ptr(), - indices.storage().ptr(), - sorted_values.storage().ptr(), + concat_values.ptr(), + indices.ptr(), + sorted_values.ptr(), total, )?; @@ -652,9 +655,9 @@ pub unsafe fn coo_mul_merge( context, stream, device_index, - concat_sources.storage().ptr(), - indices.storage().ptr(), - sorted_sources.storage().ptr(), + concat_sources.ptr(), + indices.ptr(), + sorted_sources.ptr(), total, )?; @@ -664,9 +667,9 @@ pub unsafe fn coo_mul_merge( context, stream, device_index, - concat_keys.storage().ptr(), - sorted_sources.storage().ptr(), - intersection_flags.storage().ptr(), + concat_keys.ptr(), + sorted_sources.ptr(), + intersection_flags.ptr(), total, )?; @@ -689,26 +692,26 @@ pub unsafe fn coo_mul_merge( context, stream, device_index, - concat_keys.storage().ptr(), - sorted_values.storage().ptr(), - sorted_sources.storage().ptr(), - intersection_flags.storage().ptr(), - output_positions.storage().ptr(), - merged_keys.storage().ptr(), - merged_values.storage().ptr(), + concat_keys.ptr(), + sorted_values.ptr(), + sorted_sources.ptr(), + intersection_flags.ptr(), + output_positions.ptr(), + merged_keys.ptr(), + merged_values.ptr(), total, )?; // Step 9: Filter out zeros - ALL ON GPU (using CUB) - let threshold = crate::runtime::sparse_utils::zero_tolerance::(); + let threshold = crate::runtime::common::sparse_utils::zero_tolerance::(); let nonzero_flags = Tensor::::zeros(&[num_intersections], DType::I32, device); launch_coo_mark_nonzero::( context, stream, device_index, - merged_values.storage().ptr(), - nonzero_flags.storage().ptr(), + merged_values.ptr(), + nonzero_flags.ptr(), threshold, num_intersections, )?; @@ -732,12 +735,12 @@ pub unsafe fn coo_mul_merge( context, stream, device_index, - merged_keys.storage().ptr(), - merged_values.storage().ptr(), - nonzero_flags.storage().ptr(), - compact_positions.storage().ptr(), - final_keys.storage().ptr(), - final_values.storage().ptr(), + merged_keys.ptr(), + merged_values.ptr(), + nonzero_flags.ptr(), + compact_positions.ptr(), + final_keys.ptr(), + final_values.ptr(), num_intersections, )?; @@ -749,9 +752,9 @@ pub unsafe fn coo_mul_merge( context, stream, device_index, - final_keys.storage().ptr(), - final_row_indices.storage().ptr(), - final_col_indices.storage().ptr(), + final_keys.ptr(), + final_row_indices.ptr(), + final_col_indices.ptr(), ncols as i64, nnz_out, )?; @@ -771,6 +774,13 @@ pub unsafe fn coo_mul_merge( /// 6. Merge intersections with division /// 7. Filter out zeros and non-finite values /// 8. Extract row/col indices from keys +/// +/// # Safety +/// +/// - All tensor arguments must be valid `CudaRuntime` tensors on the device associated with +/// `context`, with consistent COO structure (matching lengths of row/col index and value arrays). +/// - `shape` must match the logical matrix dimensions (`[nrows, ncols]`). +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn coo_div_merge( context: &Arc, stream: &CudaStream, @@ -811,9 +821,9 @@ pub unsafe fn coo_div_merge( context, stream, device_index, - row_indices_a.storage().ptr(), - col_indices_a.storage().ptr(), - keys_a.storage().ptr(), + row_indices_a.ptr(), + col_indices_a.ptr(), + keys_a.ptr(), ncols as i64, nnz_a, )?; @@ -822,9 +832,9 @@ pub unsafe fn coo_div_merge( context, stream, device_index, - row_indices_b.storage().ptr(), - col_indices_b.storage().ptr(), - keys_b.storage().ptr(), + row_indices_b.ptr(), + col_indices_b.ptr(), + keys_b.ptr(), ncols as i64, nnz_b, )?; @@ -838,9 +848,9 @@ pub unsafe fn coo_div_merge( context, stream, device_index, - keys_a.storage().ptr(), - keys_b.storage().ptr(), - concat_keys.storage().ptr(), + keys_a.ptr(), + keys_b.ptr(), + concat_keys.ptr(), nnz_a, nnz_b, )?; @@ -849,23 +859,17 @@ pub unsafe fn coo_div_merge( context, stream, device_index, - values_a.storage().ptr(), - values_b.storage().ptr(), - concat_values.storage().ptr(), - concat_sources.storage().ptr(), + values_a.ptr(), + values_b.ptr(), + concat_values.ptr(), + concat_sources.ptr(), nnz_a, nnz_b, )?; // Step 3: Initialize indices array [0, 1, 2, ..., total-1] on GPU let indices = Tensor::::zeros(&[total], DType::I32, device); - launch_coo_init_indices( - context, - stream, - device_index, - indices.storage().ptr(), - total, - )?; + launch_coo_init_indices(context, stream, device_index, indices.ptr(), total)?; // Step 4: Sort (keys, indices) using Thrust stable_sort_by_key - FULLY ON GPU unsafe { @@ -873,8 +877,8 @@ pub unsafe fn coo_div_merge( context, stream, device_index, - concat_keys.storage().ptr(), - indices.storage().ptr(), + concat_keys.ptr(), + indices.ptr(), total as u32, )?; } @@ -887,9 +891,9 @@ pub unsafe fn coo_div_merge( context, stream, device_index, - concat_values.storage().ptr(), - indices.storage().ptr(), - sorted_values.storage().ptr(), + concat_values.ptr(), + indices.ptr(), + sorted_values.ptr(), total, )?; @@ -897,9 +901,9 @@ pub unsafe fn coo_div_merge( context, stream, device_index, - concat_sources.storage().ptr(), - indices.storage().ptr(), - sorted_sources.storage().ptr(), + concat_sources.ptr(), + indices.ptr(), + sorted_sources.ptr(), total, )?; @@ -909,9 +913,9 @@ pub unsafe fn coo_div_merge( context, stream, device_index, - concat_keys.storage().ptr(), - sorted_sources.storage().ptr(), - intersection_flags.storage().ptr(), + concat_keys.ptr(), + sorted_sources.ptr(), + intersection_flags.ptr(), total, )?; @@ -934,26 +938,26 @@ pub unsafe fn coo_div_merge( context, stream, device_index, - concat_keys.storage().ptr(), - sorted_values.storage().ptr(), - sorted_sources.storage().ptr(), - intersection_flags.storage().ptr(), - output_positions.storage().ptr(), - merged_keys.storage().ptr(), - merged_values.storage().ptr(), + concat_keys.ptr(), + sorted_values.ptr(), + sorted_sources.ptr(), + intersection_flags.ptr(), + output_positions.ptr(), + merged_keys.ptr(), + merged_values.ptr(), total, )?; // Step 9: Filter out zeros and non-finite values - ALL ON GPU (using CUB) - let threshold = crate::runtime::sparse_utils::zero_tolerance::(); + let threshold = crate::runtime::common::sparse_utils::zero_tolerance::(); let nonzero_flags = Tensor::::zeros(&[num_intersections], DType::I32, device); launch_coo_mark_nonzero::( context, stream, device_index, - merged_values.storage().ptr(), - nonzero_flags.storage().ptr(), + merged_values.ptr(), + nonzero_flags.ptr(), threshold, num_intersections, )?; @@ -977,12 +981,12 @@ pub unsafe fn coo_div_merge( context, stream, device_index, - merged_keys.storage().ptr(), - merged_values.storage().ptr(), - nonzero_flags.storage().ptr(), - compact_positions.storage().ptr(), - final_keys.storage().ptr(), - final_values.storage().ptr(), + merged_keys.ptr(), + merged_values.ptr(), + nonzero_flags.ptr(), + compact_positions.ptr(), + final_keys.ptr(), + final_values.ptr(), num_intersections, )?; @@ -994,9 +998,9 @@ pub unsafe fn coo_div_merge( context, stream, device_index, - final_keys.storage().ptr(), - final_row_indices.storage().ptr(), - final_col_indices.storage().ptr(), + final_keys.ptr(), + final_row_indices.ptr(), + final_col_indices.ptr(), ncols as i64, nnz_out, )?; diff --git a/src/runtime/cuda/kernels/sparse_linalg.cu b/src/runtime/cuda/kernels/sparse_linalg.cu index 4b7decff..0ec381b2 100644 --- a/src/runtime/cuda/kernels/sparse_linalg.cu +++ b/src/runtime/cuda/kernels/sparse_linalg.cu @@ -1156,4 +1156,285 @@ __global__ void apply_row_perm_f64( y[i] = b[perm[i]]; } +// ============================================================================ +// Sparse QR Factorization Kernels +// ============================================================================ + +// Apply a dense Householder reflector to work vector (fused dot + axpy) +// work[v_start..v_start+v_len] -= tau * (v^T * work[v_start..v_start+v_len]) * v +// Single block launch, 256 threads, shared memory reduction for dot product +__global__ void sparse_qr_apply_reflector_f32( + const float* v, // Dense Householder vector, length v_len + int v_start, // Starting row index in work + int v_len, // Length of v + const float* tau_ptr, // Pointer to tau (single element on GPU) + float* work, // Dense work vector + int m // Length of work (unused but for safety) +) { + __shared__ float partial[256]; + + int tid = threadIdx.x; + float tau = *tau_ptr; + + if (tau == 0.0f) return; + + // Phase 1: dot = v^T * work[v_start..] + float my_sum = 0.0f; + for (int i = tid; i < v_len; i += blockDim.x) { + my_sum += v[i] * work[v_start + i]; + } + partial[tid] = my_sum; + __syncthreads(); + + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) partial[tid] += partial[tid + s]; + __syncthreads(); + } + + float scale = tau * partial[0]; + + // Phase 2: work[v_start + i] -= scale * v[i] + for (int i = tid; i < v_len; i += blockDim.x) { + work[v_start + i] -= scale * v[i]; + } +} + +__global__ void sparse_qr_apply_reflector_f64( + const double* v, + int v_start, + int v_len, + const double* tau_ptr, + double* work, + int m +) { + __shared__ double partial[256]; + + int tid = threadIdx.x; + double tau = *tau_ptr; + + if (tau == 0.0) return; + + double my_sum = 0.0; + for (int i = tid; i < v_len; i += blockDim.x) { + my_sum += v[i] * work[v_start + i]; + } + partial[tid] = my_sum; + __syncthreads(); + + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) partial[tid] += partial[tid + s]; + __syncthreads(); + } + + double scale = tau * partial[0]; + + for (int i = tid; i < v_len; i += blockDim.x) { + work[v_start + i] -= scale * v[i]; + } +} + +// Compute ||work[start..start+count]||^2 via parallel reduction +// Single block, result written to result[0] +__global__ void sparse_qr_norm_f32( + const float* work, + int start, + int count, + float* result +) { + __shared__ float partial[256]; + + int tid = threadIdx.x; + + float my_sum = 0.0f; + for (int i = tid; i < count; i += blockDim.x) { + float val = work[start + i]; + my_sum += val * val; + } + partial[tid] = my_sum; + __syncthreads(); + + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) partial[tid] += partial[tid + s]; + __syncthreads(); + } + + if (tid == 0) result[0] = partial[0]; +} + +__global__ void sparse_qr_norm_f64( + const double* work, + int start, + int count, + double* result +) { + __shared__ double partial[256]; + + int tid = threadIdx.x; + + double my_sum = 0.0; + for (int i = tid; i < count; i += blockDim.x) { + double val = work[start + i]; + my_sum += val * val; + } + partial[tid] = my_sum; + __syncthreads(); + + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) partial[tid] += partial[tid + s]; + __syncthreads(); + } + + if (tid == 0) result[0] = partial[0]; +} + +// Compute Householder vector from work[start..m] +// Reads norm_sq from norm_sq_ptr (computed by norm kernel) +// Writes dense v to out_v, tau to out_tau, R diagonal to out_diag +// Single block, thread 0 computes control values, all threads compute v entries +// +// Tolerance 1e-30: well below machine epsilon for both f32 (~1e-7) and f64 (~2e-16). +// Matches CPU implementation (algorithm.rs:226,238). This threshold detects truly zero +// columns without false positives from normal floating-point roundoff. +__global__ void sparse_qr_householder_f32( + const float* work, + int start, + int m, + const float* norm_sq_ptr, + float* out_v, + float* out_tau, + float* out_diag +) { + __shared__ float ctrl[4]; // [sigma, tau, diag, inv_v_start] + + int tid = threadIdx.x; + int v_len = m - start; + + if (tid == 0) { + float norm_sq = *norm_sq_ptr; + float norm = sqrtf(norm_sq); + + if (norm < 1e-30f) { + ctrl[0] = 0.0f; ctrl[1] = 0.0f; ctrl[2] = 0.0f; ctrl[3] = 0.0f; + } else { + float x0 = work[start]; + float sigma = (x0 >= 0.0f) ? -norm : norm; + float v_start_val = x0 - sigma; + + if (fabsf(v_start_val) < 1e-30f) { + ctrl[0] = sigma; ctrl[1] = 0.0f; ctrl[2] = sigma; ctrl[3] = 0.0f; + } else { + ctrl[0] = sigma; + ctrl[1] = -v_start_val / sigma; + ctrl[2] = sigma; + ctrl[3] = 1.0f / v_start_val; + } + } + } + __syncthreads(); + + float tau = ctrl[1]; + float inv_v_start = ctrl[3]; + + if (tid == 0) { + *out_tau = tau; + *out_diag = ctrl[2]; + } + + if (tau == 0.0f) { + for (int i = tid; i < v_len; i += blockDim.x) { + out_v[i] = (i == 0) ? 1.0f : 0.0f; + } + } else { + for (int i = tid; i < v_len; i += blockDim.x) { + out_v[i] = (i == 0) ? 1.0f : work[start + i] * inv_v_start; + } + } +} + +__global__ void sparse_qr_householder_f64( + const double* work, + int start, + int m, + const double* norm_sq_ptr, + double* out_v, + double* out_tau, + double* out_diag +) { + __shared__ double ctrl[4]; + + int tid = threadIdx.x; + int v_len = m - start; + + if (tid == 0) { + double norm_sq = *norm_sq_ptr; + double norm = sqrt(norm_sq); + + if (norm < 1e-30) { + ctrl[0] = 0.0; ctrl[1] = 0.0; ctrl[2] = 0.0; ctrl[3] = 0.0; + } else { + double x0 = work[start]; + double sigma = (x0 >= 0.0) ? -norm : norm; + double v_start_val = x0 - sigma; + + if (fabs(v_start_val) < 1e-30) { + ctrl[0] = sigma; ctrl[1] = 0.0; ctrl[2] = sigma; ctrl[3] = 0.0; + } else { + ctrl[0] = sigma; + ctrl[1] = -v_start_val / sigma; + ctrl[2] = sigma; + ctrl[3] = 1.0 / v_start_val; + } + } + } + __syncthreads(); + + double tau = ctrl[1]; + double inv_v_start = ctrl[3]; + + if (tid == 0) { + *out_tau = tau; + *out_diag = ctrl[2]; + } + + if (tau == 0.0) { + for (int i = tid; i < v_len; i += blockDim.x) { + out_v[i] = (i == 0) ? 1.0 : 0.0; + } + } else { + for (int i = tid; i < v_len; i += blockDim.x) { + out_v[i] = (i == 0) ? 1.0 : work[start + i] * inv_v_start; + } + } +} + +// Extract R off-diagonal: copy work[0..count] to output buffer +__global__ void sparse_qr_extract_r_f32( + const float* work, + int count, + float* output +) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < count) output[i] = work[i]; +} + +__global__ void sparse_qr_extract_r_f64( + const double* work, + int count, + double* output +) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < count) output[i] = work[i]; +} + +// Clear work vector: work[0..n] = 0 +__global__ void sparse_qr_clear_f32(float* work, int n) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) work[i] = 0.0f; +} + +__global__ void sparse_qr_clear_f64(double* work, int n) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) work[i] = 0.0; +} + } // extern "C" diff --git a/src/runtime/cuda/kernels/sparse_linalg/ilu_ic.rs b/src/runtime/cuda/kernels/sparse_linalg/ilu_ic.rs index 373a881a..13b6565b 100644 --- a/src/runtime/cuda/kernels/sparse_linalg/ilu_ic.rs +++ b/src/runtime/cuda/kernels/sparse_linalg/ilu_ic.rs @@ -14,7 +14,16 @@ use crate::error::Result; // ILU(0) Level Kernel Launchers // ============================================================================ -/// Launch ILU(0) level kernel - f32 +/// Launch ILU(0) factorization level kernel - f32 +/// +/// # Safety +/// +/// - `level_rows`, `row_ptrs`, `col_indices`, `values`, and `diag_indices` must be valid device +/// memory pointers on the device associated with `context`. +/// - `level_rows` must have at least `level_size` elements with valid row indices in `[0, n)`. +/// - `row_ptrs` must have at least `n + 1` elements; `col_indices`, `values`, and `diag_indices` +/// must each have at least `nnz` elements (as encoded in `row_ptrs`). +/// - The stream must be from the same context and must not be destroyed while the kernel runs. #[allow(clippy::too_many_arguments)] pub unsafe fn launch_ilu0_level_f32( context: &Arc, @@ -47,7 +56,16 @@ pub unsafe fn launch_ilu0_level_f32( Ok(()) } -/// Launch ILU(0) level kernel - f64 +/// Launch ILU(0) factorization level kernel - f64 +/// +/// # Safety +/// +/// - `level_rows`, `row_ptrs`, `col_indices`, `values`, and `diag_indices` must be valid device +/// memory pointers on the device associated with `context`. +/// - `level_rows` must have at least `level_size` elements with valid row indices in `[0, n)`. +/// - `row_ptrs` must have at least `n + 1` elements; `col_indices`, `values`, and `diag_indices` +/// must each have at least `nnz` elements (as encoded in `row_ptrs`). +/// - The stream must be from the same context and must not be destroyed while the kernel runs. #[allow(clippy::too_many_arguments)] pub unsafe fn launch_ilu0_level_f64( context: &Arc, @@ -84,7 +102,16 @@ pub unsafe fn launch_ilu0_level_f64( // IC(0) Level Kernel Launchers // ============================================================================ -/// Launch IC(0) level kernel - f32 +/// Launch IC(0) factorization level kernel - f32 +/// +/// # Safety +/// +/// - `level_rows`, `row_ptrs`, `col_indices`, `values`, and `diag_indices` must be valid device +/// memory pointers on the device associated with `context`. +/// - `level_rows` must have at least `level_size` elements with valid row indices in `[0, n)`. +/// - `row_ptrs` must have at least `n + 1` elements; `col_indices`, `values`, and `diag_indices` +/// must each have at least `nnz` elements (as encoded in `row_ptrs`). +/// - The stream must be from the same context and must not be destroyed while the kernel runs. #[allow(clippy::too_many_arguments)] pub unsafe fn launch_ic0_level_f32( context: &Arc, @@ -117,7 +144,16 @@ pub unsafe fn launch_ic0_level_f32( Ok(()) } -/// Launch IC(0) level kernel - f64 +/// Launch IC(0) factorization level kernel - f64 +/// +/// # Safety +/// +/// - `level_rows`, `row_ptrs`, `col_indices`, `values`, and `diag_indices` must be valid device +/// memory pointers on the device associated with `context`. +/// - `level_rows` must have at least `level_size` elements with valid row indices in `[0, n)`. +/// - `row_ptrs` must have at least `n + 1` elements; `col_indices`, `values`, and `diag_indices` +/// must each have at least `nnz` elements (as encoded in `row_ptrs`). +/// - The stream must be from the same context and must not be destroyed while the kernel runs. #[allow(clippy::too_many_arguments)] pub unsafe fn launch_ic0_level_f64( context: &Arc, diff --git a/src/runtime/cuda/kernels/sparse_linalg/levels.rs b/src/runtime/cuda/kernels/sparse_linalg/levels.rs index 9ee9d8f8..092cee65 100644 --- a/src/runtime/cuda/kernels/sparse_linalg/levels.rs +++ b/src/runtime/cuda/kernels/sparse_linalg/levels.rs @@ -24,6 +24,13 @@ use crate::error::Result; // ============================================================================ /// Cast i64 GPU tensor to i32 GPU tensor (no CPU transfer) +/// +/// # Safety +/// +/// - `input` and `output` must be valid device memory pointers on the device associated with +/// `context`, each with at least `n` elements of their respective types. +/// - Values in `input` that exceed `i32::MAX` or are below `i32::MIN` will be truncated. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn launch_cast_i64_to_i32( context: &Arc, stream: &CudaStream, @@ -48,7 +55,16 @@ pub unsafe fn launch_cast_i64_to_i32( // Level Computation // ============================================================================ -/// Compute level schedule for lower triangular (iterative BFS on GPU) +/// Compute level schedule for lower triangular matrix via iterative BFS on GPU +/// +/// # Safety +/// +/// - `row_ptrs`, `col_indices`, `levels`, and `changed` must be valid device memory pointers on +/// the device associated with `context`. +/// - `row_ptrs` must have at least `n + 1` i32 elements; `col_indices` has `nnz` elements. +/// - `levels` must have at least `n` i32 elements (initialized by caller before first call). +/// - `changed` must point to a single i32 flag in device memory. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn launch_compute_levels_lower_iter( context: &Arc, stream: &CudaStream, @@ -73,7 +89,16 @@ pub unsafe fn launch_compute_levels_lower_iter( Ok(()) } -/// Compute level schedule for upper triangular (iterative BFS on GPU) +/// Compute level schedule for upper triangular matrix via iterative BFS on GPU +/// +/// # Safety +/// +/// - `row_ptrs`, `col_indices`, `levels`, and `changed` must be valid device memory pointers on +/// the device associated with `context`. +/// - `row_ptrs` must have at least `n + 1` i32 elements; `col_indices` has `nnz` elements. +/// - `levels` must have at least `n` i32 elements (initialized by caller before first call). +/// - `changed` must point to a single i32 flag in device memory. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn launch_compute_levels_upper_iter( context: &Arc, stream: &CudaStream, @@ -102,7 +127,13 @@ pub unsafe fn launch_compute_levels_upper_iter( // Reduction // ============================================================================ -/// Find maximum level value via reduction +/// Find maximum level value via single-block parallel reduction +/// +/// # Safety +/// +/// - `data` must be a valid device memory pointer with at least `n` i32 elements. +/// - `result` must point to a single i32 element in device memory where the result is written. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn launch_reduce_max_i32( context: &Arc, stream: &CudaStream, @@ -128,7 +159,15 @@ pub unsafe fn launch_reduce_max_i32( // Histogram and Scatter // ============================================================================ -/// Count occurrences of each level +/// Count occurrences of each level via atomic histogram +/// +/// # Safety +/// +/// - `levels` must be a valid device memory pointer with at least `n` i32 elements. +/// - `counts` must be a valid device memory pointer pre-allocated to hold the histogram +/// (size must be at least `max_level + 1` as determined by the caller). +/// - All values in `levels` must be non-negative and within bounds of the `counts` array. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn launch_histogram_levels( context: &Arc, stream: &CudaStream, @@ -149,7 +188,16 @@ pub unsafe fn launch_histogram_levels( Ok(()) } -/// Scatter rows by level into level_rows array +/// Scatter rows by level into the `level_rows` array using atomic counters +/// +/// # Safety +/// +/// - `levels`, `level_ptrs`, `level_rows`, and `level_counters` must be valid device memory +/// pointers on the device associated with `context`. +/// - `levels` and `level_counters` must have at least `n` elements. +/// - `level_ptrs` must have at least `num_levels + 1` elements (prefix sums of level sizes). +/// - `level_rows` must have at least `n` elements. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn launch_scatter_by_level( context: &Arc, stream: &CudaStream, diff --git a/src/runtime/cuda/kernels/sparse_linalg/mod.rs b/src/runtime/cuda/kernels/sparse_linalg/mod.rs index 4938e304..e59144bf 100644 --- a/src/runtime/cuda/kernels/sparse_linalg/mod.rs +++ b/src/runtime/cuda/kernels/sparse_linalg/mod.rs @@ -14,12 +14,14 @@ mod ilu_ic; mod levels; mod primitives; +mod qr; mod trsv; mod utils; pub use ilu_ic::*; pub use levels::*; pub use primitives::*; +pub use qr::*; pub use trsv::*; pub use utils::*; diff --git a/src/runtime/cuda/kernels/sparse_linalg/primitives.rs b/src/runtime/cuda/kernels/sparse_linalg/primitives.rs index 34678198..5502aeb7 100644 --- a/src/runtime/cuda/kernels/sparse_linalg/primitives.rs +++ b/src/runtime/cuda/kernels/sparse_linalg/primitives.rs @@ -17,7 +17,15 @@ use crate::error::Result; // Scatter Operations // ============================================================================ -/// Scatters values into work vector: work[row_indices[i]] = values[i] - f32 +/// Scatters values into work vector: `work[row_indices[i]] = values[i]` - f32 +/// +/// # Safety +/// +/// - `values`, `row_indices`, and `work` must be valid device memory pointers on the device +/// associated with `context`. +/// - `values` and `row_indices` must each have at least `nnz` elements. +/// - All values in `row_indices` must be valid indices into `work` (no out-of-bounds access). +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn launch_sparse_scatter_f32( context: &Arc, stream: &CudaStream, @@ -41,7 +49,15 @@ pub unsafe fn launch_sparse_scatter_f32( Ok(()) } -/// Scatters values into work vector - f64 +/// Scatters values into work vector: `work[row_indices[i]] = values[i]` - f64 +/// +/// # Safety +/// +/// - `values`, `row_indices`, and `work` must be valid device memory pointers on the device +/// associated with `context`. +/// - `values` and `row_indices` must each have at least `nnz` elements. +/// - All values in `row_indices` must be valid indices into `work` (no out-of-bounds access). +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn launch_sparse_scatter_f64( context: &Arc, stream: &CudaStream, @@ -69,7 +85,15 @@ pub unsafe fn launch_sparse_scatter_f64( // AXPY Operations // ============================================================================ -/// Computes: work[row_indices[i]] -= scale * values[i] - f32 +/// Computes: `work[row_indices[i]] -= scale * values[i]` - f32 +/// +/// # Safety +/// +/// - `values`, `row_indices`, and `work` must be valid device memory pointers on the device +/// associated with `context`. +/// - `values` and `row_indices` must each have at least `nnz` elements. +/// - All values in `row_indices` must be valid indices into `work` (no out-of-bounds access). +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn launch_sparse_axpy_f32( context: &Arc, stream: &CudaStream, @@ -95,7 +119,15 @@ pub unsafe fn launch_sparse_axpy_f32( Ok(()) } -/// Computes: work[row_indices[i]] -= scale * values[i] - f64 +/// Computes: `work[row_indices[i]] -= scale * values[i]` - f64 +/// +/// # Safety +/// +/// - `values`, `row_indices`, and `work` must be valid device memory pointers on the device +/// associated with `context`. +/// - `values` and `row_indices` must each have at least `nnz` elements. +/// - All values in `row_indices` must be valid indices into `work` (no out-of-bounds access). +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn launch_sparse_axpy_f64( context: &Arc, stream: &CudaStream, @@ -125,7 +157,15 @@ pub unsafe fn launch_sparse_axpy_f64( // Gather and Clear Operations // ============================================================================ -/// Gathers: output[i] = work[row_indices[i]], then clears work[row_indices[i]] = 0 - f32 +/// Gathers and clears: `output[i] = work[row_indices[i]]`, then sets `work[row_indices[i]] = 0` - f32 +/// +/// # Safety +/// +/// - `work`, `row_indices`, and `output` must be valid device memory pointers on the device +/// associated with `context`. +/// - `row_indices` and `output` must each have at least `nnz` elements. +/// - All values in `row_indices` must be valid indices into `work` (no out-of-bounds access). +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn launch_sparse_gather_clear_f32( context: &Arc, stream: &CudaStream, @@ -149,7 +189,15 @@ pub unsafe fn launch_sparse_gather_clear_f32( Ok(()) } -/// Gathers and clears - f64 +/// Gathers and clears: `output[i] = work[row_indices[i]]`, then sets `work[row_indices[i]] = 0` - f64 +/// +/// # Safety +/// +/// - `work`, `row_indices`, and `output` must be valid device memory pointers on the device +/// associated with `context`. +/// - `row_indices` and `output` must each have at least `nnz` elements. +/// - All values in `row_indices` must be valid indices into `work` (no out-of-bounds access). +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn launch_sparse_gather_clear_f64( context: &Arc, stream: &CudaStream, @@ -177,7 +225,15 @@ pub unsafe fn launch_sparse_gather_clear_f64( // Divide by Pivot Operations // ============================================================================ -/// Computes: work[row_indices[i]] *= inv_pivot - f32 +/// Computes: `work[row_indices[i]] *= inv_pivot` - f32 +/// +/// # Safety +/// +/// - `work` and `row_indices` must be valid device memory pointers on the device associated +/// with `context`. +/// - `row_indices` must have at least `nnz` elements. +/// - All values in `row_indices` must be valid indices into `work` (no out-of-bounds access). +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn launch_sparse_divide_pivot_f32( context: &Arc, stream: &CudaStream, @@ -201,7 +257,15 @@ pub unsafe fn launch_sparse_divide_pivot_f32( Ok(()) } -/// Divide by pivot - f64 +/// Computes: `work[row_indices[i]] *= inv_pivot` - f64 +/// +/// # Safety +/// +/// - `work` and `row_indices` must be valid device memory pointers on the device associated +/// with `context`. +/// - `row_indices` must have at least `nnz` elements. +/// - All values in `row_indices` must be valid indices into `work` (no out-of-bounds access). +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn launch_sparse_divide_pivot_f64( context: &Arc, stream: &CudaStream, @@ -229,7 +293,15 @@ pub unsafe fn launch_sparse_divide_pivot_f64( // Row Permutation Operations // ============================================================================ -/// Applies row permutation: y[i] = b[perm[i]] - f32 +/// Applies row permutation: `y[i] = b[perm[i]]` - f32 +/// +/// # Safety +/// +/// - `b`, `perm`, and `y` must be valid device memory pointers on the device associated +/// with `context`. +/// - `b`, `perm`, and `y` must each have at least `n` elements. +/// - All values in `perm` must be valid indices into `b` (no out-of-bounds access). +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn launch_apply_row_perm_f32( context: &Arc, stream: &CudaStream, @@ -253,7 +325,15 @@ pub unsafe fn launch_apply_row_perm_f32( Ok(()) } -/// Applies row permutation - f64 +/// Applies row permutation: `y[i] = b[perm[i]]` - f64 +/// +/// # Safety +/// +/// - `b`, `perm`, and `y` must be valid device memory pointers on the device associated +/// with `context`. +/// - `b`, `perm`, and `y` must each have at least `n` elements. +/// - All values in `perm` must be valid indices into `b` (no out-of-bounds access). +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn launch_apply_row_perm_f64( context: &Arc, stream: &CudaStream, diff --git a/src/runtime/cuda/kernels/sparse_linalg/qr.rs b/src/runtime/cuda/kernels/sparse_linalg/qr.rs new file mode 100644 index 00000000..2a97ade5 --- /dev/null +++ b/src/runtime/cuda/kernels/sparse_linalg/qr.rs @@ -0,0 +1,355 @@ +//! CUDA kernel launchers for sparse QR factorization +//! +//! Implements Householder QR reduction for sparse matrices on NVIDIA GPUs. +//! Five primitive kernels composed into a column-wise left-looking algorithm: +//! +//! - `apply_reflector`: Fused dot+axpy Householder update (single block, shared mem reduction) +//! - `norm`: Parallel sum-of-squares reduction for ||work[start..start+count]||^2 +//! - `householder`: Householder vector generation with tau and R diagonal computation +//! - `extract_r`: Copy R off-diagonal entries from work vector +//! - `clear`: Zero-initialize work vector +//! +//! All single-block kernels use 256 threads with shared memory reductions. +//! Grid-based kernels (extract_r, clear) scale to arbitrary sizes. + +use cudarc::driver::PushKernelArg; +use cudarc::driver::safe::{CudaContext, CudaStream}; +use std::sync::Arc; + +use super::{ + BLOCK_SIZE, SPARSE_LINALG_MODULE, get_kernel_function, get_or_load_module, grid_size, + launch_config, launch_error, +}; +use crate::error::Result; + +// ============================================================================ +// Apply Householder Reflector (single block, fused dot + axpy) +// ============================================================================ + +/// Applies dense Householder reflector to work vector - f32 +/// +/// Computes: `work[v_start..] -= tau * (v^T * work[v_start..]) * v` +/// Single block of 256 threads with shared memory reduction. +/// +/// # Safety +/// +/// - `v`, `tau_ptr`, and `work` must be valid device memory pointers on the device associated +/// with `context`. +/// - `v` must have at least `v_len` elements starting from index `v_start`. +/// - `work` must have at least `m` elements. +/// - `tau_ptr` must point to a single f32 scalar in device memory. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. +pub unsafe fn launch_sparse_qr_apply_reflector_f32( + context: &Arc, + stream: &CudaStream, + device_index: usize, + v: u64, + v_start: i32, + v_len: i32, + tau_ptr: u64, + work: u64, + m: i32, +) -> Result<()> { + let module = get_or_load_module(context, device_index, SPARSE_LINALG_MODULE)?; + let func = get_kernel_function(&module, "sparse_qr_apply_reflector_f32")?; + let cfg = launch_config((1, 1, 1), (BLOCK_SIZE, 1, 1), 0); + + let mut builder = stream.launch_builder(&func); + builder.arg(&v); + builder.arg(&v_start); + builder.arg(&v_len); + builder.arg(&tau_ptr); + builder.arg(&work); + builder.arg(&m); + unsafe { builder.launch(cfg) }.map_err(|e| launch_error("sparse_qr_apply_reflector_f32", e))?; + Ok(()) +} + +/// Applies dense Householder reflector to work vector - f64 +/// +/// Computes: `work[v_start..] -= tau * (v^T * work[v_start..]) * v` +/// Single block of 256 threads with shared memory reduction. +/// +/// # Safety +/// +/// - `v`, `tau_ptr`, and `work` must be valid device memory pointers on the device associated +/// with `context`. +/// - `v` must have at least `v_len` elements starting from index `v_start`. +/// - `work` must have at least `m` elements. +/// - `tau_ptr` must point to a single f64 scalar in device memory. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. +pub unsafe fn launch_sparse_qr_apply_reflector_f64( + context: &Arc, + stream: &CudaStream, + device_index: usize, + v: u64, + v_start: i32, + v_len: i32, + tau_ptr: u64, + work: u64, + m: i32, +) -> Result<()> { + let module = get_or_load_module(context, device_index, SPARSE_LINALG_MODULE)?; + let func = get_kernel_function(&module, "sparse_qr_apply_reflector_f64")?; + let cfg = launch_config((1, 1, 1), (BLOCK_SIZE, 1, 1), 0); + + let mut builder = stream.launch_builder(&func); + builder.arg(&v); + builder.arg(&v_start); + builder.arg(&v_len); + builder.arg(&tau_ptr); + builder.arg(&work); + builder.arg(&m); + unsafe { builder.launch(cfg) }.map_err(|e| launch_error("sparse_qr_apply_reflector_f64", e))?; + Ok(()) +} + +// ============================================================================ +// Norm (sum of squares reduction, single block) +// ============================================================================ + +/// Computes `||work[start..start+count]||^2` via parallel reduction - f32 +/// +/// # Safety +/// +/// - `work` must be a valid device memory pointer on the device associated with `context`, +/// with at least `start + count` f32 elements. +/// - `result` must point to a single f32 element in device memory where the result will be written. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. +pub unsafe fn launch_sparse_qr_norm_f32( + context: &Arc, + stream: &CudaStream, + device_index: usize, + work: u64, + start: i32, + count: i32, + result: u64, +) -> Result<()> { + let module = get_or_load_module(context, device_index, SPARSE_LINALG_MODULE)?; + let func = get_kernel_function(&module, "sparse_qr_norm_f32")?; + let cfg = launch_config((1, 1, 1), (BLOCK_SIZE, 1, 1), 0); + + let mut builder = stream.launch_builder(&func); + builder.arg(&work); + builder.arg(&start); + builder.arg(&count); + builder.arg(&result); + unsafe { builder.launch(cfg) }.map_err(|e| launch_error("sparse_qr_norm_f32", e))?; + Ok(()) +} + +/// Computes `||work[start..start+count]||^2` via parallel reduction - f64 +/// +/// # Safety +/// +/// - `work` must be a valid device memory pointer on the device associated with `context`, +/// with at least `start + count` f64 elements. +/// - `result` must point to a single f64 element in device memory where the result will be written. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. +pub unsafe fn launch_sparse_qr_norm_f64( + context: &Arc, + stream: &CudaStream, + device_index: usize, + work: u64, + start: i32, + count: i32, + result: u64, +) -> Result<()> { + let module = get_or_load_module(context, device_index, SPARSE_LINALG_MODULE)?; + let func = get_kernel_function(&module, "sparse_qr_norm_f64")?; + let cfg = launch_config((1, 1, 1), (BLOCK_SIZE, 1, 1), 0); + + let mut builder = stream.launch_builder(&func); + builder.arg(&work); + builder.arg(&start); + builder.arg(&count); + builder.arg(&result); + unsafe { builder.launch(cfg) }.map_err(|e| launch_error("sparse_qr_norm_f64", e))?; + Ok(()) +} + +// ============================================================================ +// Householder vector computation (single block) +// ============================================================================ + +/// Computes Householder vector from `work[start..m]` and stores results - f32 +/// +/// # Safety +/// +/// - `work` must be a valid device memory pointer with at least `m` f32 elements. +/// - `norm_sq_ptr` must point to a single f32 scalar in device memory (the precomputed norm²). +/// - `out_v`, `out_tau`, and `out_diag` must be valid device memory pointers with sufficient space. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. +pub unsafe fn launch_sparse_qr_householder_f32( + context: &Arc, + stream: &CudaStream, + device_index: usize, + work: u64, + start: i32, + m: i32, + norm_sq_ptr: u64, + out_v: u64, + out_tau: u64, + out_diag: u64, +) -> Result<()> { + let module = get_or_load_module(context, device_index, SPARSE_LINALG_MODULE)?; + let func = get_kernel_function(&module, "sparse_qr_householder_f32")?; + let cfg = launch_config((1, 1, 1), (BLOCK_SIZE, 1, 1), 0); + + let mut builder = stream.launch_builder(&func); + builder.arg(&work); + builder.arg(&start); + builder.arg(&m); + builder.arg(&norm_sq_ptr); + builder.arg(&out_v); + builder.arg(&out_tau); + builder.arg(&out_diag); + unsafe { builder.launch(cfg) }.map_err(|e| launch_error("sparse_qr_householder_f32", e))?; + Ok(()) +} + +/// Computes Householder vector from `work[start..m]` and stores results - f64 +/// +/// # Safety +/// +/// - `work` must be a valid device memory pointer with at least `m` f64 elements. +/// - `norm_sq_ptr` must point to a single f64 scalar in device memory (the precomputed norm²). +/// - `out_v`, `out_tau`, and `out_diag` must be valid device memory pointers with sufficient space. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. +pub unsafe fn launch_sparse_qr_householder_f64( + context: &Arc, + stream: &CudaStream, + device_index: usize, + work: u64, + start: i32, + m: i32, + norm_sq_ptr: u64, + out_v: u64, + out_tau: u64, + out_diag: u64, +) -> Result<()> { + let module = get_or_load_module(context, device_index, SPARSE_LINALG_MODULE)?; + let func = get_kernel_function(&module, "sparse_qr_householder_f64")?; + let cfg = launch_config((1, 1, 1), (BLOCK_SIZE, 1, 1), 0); + + let mut builder = stream.launch_builder(&func); + builder.arg(&work); + builder.arg(&start); + builder.arg(&m); + builder.arg(&norm_sq_ptr); + builder.arg(&out_v); + builder.arg(&out_tau); + builder.arg(&out_diag); + unsafe { builder.launch(cfg) }.map_err(|e| launch_error("sparse_qr_householder_f64", e))?; + Ok(()) +} + +// ============================================================================ +// Extract R off-diagonal entries +// ============================================================================ + +/// Copies `work[0..count]` to output buffer - f32 +/// +/// # Safety +/// +/// - `work` must be a valid device memory pointer with at least `count` f32 elements. +/// - `output` must be a valid device memory pointer with at least `count` f32 elements. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. +pub unsafe fn launch_sparse_qr_extract_r_f32( + context: &Arc, + stream: &CudaStream, + device_index: usize, + work: u64, + count: i32, + output: u64, +) -> Result<()> { + let module = get_or_load_module(context, device_index, SPARSE_LINALG_MODULE)?; + let func = get_kernel_function(&module, "sparse_qr_extract_r_f32")?; + let cfg = launch_config((grid_size(count as u32), 1, 1), (BLOCK_SIZE, 1, 1), 0); + + let mut builder = stream.launch_builder(&func); + builder.arg(&work); + builder.arg(&count); + builder.arg(&output); + unsafe { builder.launch(cfg) }.map_err(|e| launch_error("sparse_qr_extract_r_f32", e))?; + Ok(()) +} + +/// Copies `work[0..count]` to output buffer - f64 +/// +/// # Safety +/// +/// - `work` must be a valid device memory pointer with at least `count` f64 elements. +/// - `output` must be a valid device memory pointer with at least `count` f64 elements. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. +pub unsafe fn launch_sparse_qr_extract_r_f64( + context: &Arc, + stream: &CudaStream, + device_index: usize, + work: u64, + count: i32, + output: u64, +) -> Result<()> { + let module = get_or_load_module(context, device_index, SPARSE_LINALG_MODULE)?; + let func = get_kernel_function(&module, "sparse_qr_extract_r_f64")?; + let cfg = launch_config((grid_size(count as u32), 1, 1), (BLOCK_SIZE, 1, 1), 0); + + let mut builder = stream.launch_builder(&func); + builder.arg(&work); + builder.arg(&count); + builder.arg(&output); + unsafe { builder.launch(cfg) }.map_err(|e| launch_error("sparse_qr_extract_r_f64", e))?; + Ok(()) +} + +// ============================================================================ +// Clear work vector +// ============================================================================ + +/// Sets `work[0..n]` to zero - f32 +/// +/// # Safety +/// +/// - `work` must be a valid device memory pointer with at least `n` f32 elements. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. +pub unsafe fn launch_sparse_qr_clear_f32( + context: &Arc, + stream: &CudaStream, + device_index: usize, + work: u64, + n: i32, +) -> Result<()> { + let module = get_or_load_module(context, device_index, SPARSE_LINALG_MODULE)?; + let func = get_kernel_function(&module, "sparse_qr_clear_f32")?; + let cfg = launch_config((grid_size(n as u32), 1, 1), (BLOCK_SIZE, 1, 1), 0); + + let mut builder = stream.launch_builder(&func); + builder.arg(&work); + builder.arg(&n); + unsafe { builder.launch(cfg) }.map_err(|e| launch_error("sparse_qr_clear_f32", e))?; + Ok(()) +} + +/// Sets `work[0..n]` to zero - f64 +/// +/// # Safety +/// +/// - `work` must be a valid device memory pointer with at least `n` f64 elements. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. +pub unsafe fn launch_sparse_qr_clear_f64( + context: &Arc, + stream: &CudaStream, + device_index: usize, + work: u64, + n: i32, +) -> Result<()> { + let module = get_or_load_module(context, device_index, SPARSE_LINALG_MODULE)?; + let func = get_kernel_function(&module, "sparse_qr_clear_f64")?; + let cfg = launch_config((grid_size(n as u32), 1, 1), (BLOCK_SIZE, 1, 1), 0); + + let mut builder = stream.launch_builder(&func); + builder.arg(&work); + builder.arg(&n); + unsafe { builder.launch(cfg) }.map_err(|e| launch_error("sparse_qr_clear_f64", e))?; + Ok(()) +} diff --git a/src/runtime/cuda/kernels/sparse_linalg/trsv.rs b/src/runtime/cuda/kernels/sparse_linalg/trsv.rs index 2acf2067..01625c3c 100644 --- a/src/runtime/cuda/kernels/sparse_linalg/trsv.rs +++ b/src/runtime/cuda/kernels/sparse_linalg/trsv.rs @@ -54,7 +54,14 @@ pub unsafe fn launch_sparse_trsv_lower_level_f32( Ok(()) } -/// Launch level-scheduled lower triangular solve kernel - f64 +/// Launch level-scheduled lower triangular solve kernel (forward substitution) - f64 +/// +/// # Safety +/// +/// - All pointer arguments (`level_rows`, `row_ptrs`, `col_indices`, `values`, `b`, `x`) must be +/// valid device memory pointers allocated on the device associated with `context`. +/// - Buffer sizes must match the expected dimensions: `level_size` rows, matrix of size `n x n`. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. #[allow(clippy::too_many_arguments)] pub unsafe fn launch_sparse_trsv_lower_level_f64( context: &Arc, @@ -91,6 +98,13 @@ pub unsafe fn launch_sparse_trsv_lower_level_f64( } /// Launch level-scheduled upper triangular solve kernel (backward substitution) - f32 +/// +/// # Safety +/// +/// - All pointer arguments (`level_rows`, `row_ptrs`, `col_indices`, `values`, `b`, `x`) must be +/// valid device memory pointers allocated on the device associated with `context`. +/// - Buffer sizes must match the expected dimensions: `level_size` rows, matrix of size `n x n`. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. #[allow(clippy::too_many_arguments)] pub unsafe fn launch_sparse_trsv_upper_level_f32( context: &Arc, @@ -123,7 +137,14 @@ pub unsafe fn launch_sparse_trsv_upper_level_f32( Ok(()) } -/// Launch level-scheduled upper triangular solve kernel - f64 +/// Launch level-scheduled upper triangular solve kernel (backward substitution) - f64 +/// +/// # Safety +/// +/// - All pointer arguments (`level_rows`, `row_ptrs`, `col_indices`, `values`, `b`, `x`) must be +/// valid device memory pointers allocated on the device associated with `context`. +/// - Buffer sizes must match the expected dimensions: `level_size` rows, matrix of size `n x n`. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. #[allow(clippy::too_many_arguments)] pub unsafe fn launch_sparse_trsv_upper_level_f64( context: &Arc, @@ -161,6 +182,14 @@ pub unsafe fn launch_sparse_trsv_upper_level_f64( // ============================================================================ /// Launch multi-RHS lower triangular solve kernel (forward substitution) - f32 +/// +/// # Safety +/// +/// - All pointer arguments (`level_rows`, `row_ptrs`, `col_indices`, `values`, `b`, `x`) must be +/// valid device memory pointers on the device associated with `context`. +/// - The `b` buffer must have at least `n * nrhs` elements; `x` must have at least `n * nrhs`. +/// - `level_size * nrhs` must not overflow `u32` when computing the grid size. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. #[allow(clippy::too_many_arguments)] pub unsafe fn launch_sparse_trsv_lower_level_multi_rhs_f32( context: &Arc, @@ -200,7 +229,15 @@ pub unsafe fn launch_sparse_trsv_lower_level_multi_rhs_f32( Ok(()) } -/// Launch multi-RHS lower triangular solve kernel - f64 +/// Launch multi-RHS lower triangular solve kernel (forward substitution) - f64 +/// +/// # Safety +/// +/// - All pointer arguments (`level_rows`, `row_ptrs`, `col_indices`, `values`, `b`, `x`) must be +/// valid device memory pointers on the device associated with `context`. +/// - The `b` buffer must have at least `n * nrhs` elements; `x` must have at least `n * nrhs`. +/// - `level_size * nrhs` must not overflow `u32` when computing the grid size. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. #[allow(clippy::too_many_arguments)] pub unsafe fn launch_sparse_trsv_lower_level_multi_rhs_f64( context: &Arc, @@ -241,6 +278,14 @@ pub unsafe fn launch_sparse_trsv_lower_level_multi_rhs_f64( } /// Launch multi-RHS upper triangular solve kernel (backward substitution) - f32 +/// +/// # Safety +/// +/// - All pointer arguments (`level_rows`, `row_ptrs`, `col_indices`, `values`, `b`, `x`) must be +/// valid device memory pointers on the device associated with `context`. +/// - The `b` buffer must have at least `n * nrhs` elements; `x` must have at least `n * nrhs`. +/// - `level_size * nrhs` must not overflow `u32` when computing the grid size. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. #[allow(clippy::too_many_arguments)] pub unsafe fn launch_sparse_trsv_upper_level_multi_rhs_f32( context: &Arc, @@ -277,7 +322,15 @@ pub unsafe fn launch_sparse_trsv_upper_level_multi_rhs_f32( Ok(()) } -/// Launch multi-RHS upper triangular solve kernel - f64 +/// Launch multi-RHS upper triangular solve kernel (backward substitution) - f64 +/// +/// # Safety +/// +/// - All pointer arguments (`level_rows`, `row_ptrs`, `col_indices`, `values`, `b`, `x`) must be +/// valid device memory pointers on the device associated with `context`. +/// - The `b` buffer must have at least `n * nrhs` elements; `x` must have at least `n * nrhs`. +/// - `level_size * nrhs` must not overflow `u32` when computing the grid size. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. #[allow(clippy::too_many_arguments)] pub unsafe fn launch_sparse_trsv_upper_level_multi_rhs_f64( context: &Arc, @@ -318,7 +371,15 @@ pub unsafe fn launch_sparse_trsv_upper_level_multi_rhs_f64( // CSC Format - Single RHS (for LU solve) // ============================================================================ -/// Launch CSC lower triangular solve kernel - f32 +/// Launch CSC lower triangular solve kernel (forward substitution) - f32 +/// +/// # Safety +/// +/// - All pointer arguments (`level_cols`, `col_ptrs`, `row_indices`, `values`, `diag_ptr`, `b`) +/// must be valid device memory pointers on the device associated with `context`. +/// - Buffer sizes must be consistent: `col_ptrs` has `n+1` entries, `row_indices` and `values` +/// have `nnz` entries, `b` has `n` elements, `diag_ptr` has `n` entries. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. #[allow(clippy::too_many_arguments)] pub unsafe fn launch_sparse_trsv_csc_lower_level_f32( context: &Arc, @@ -355,7 +416,15 @@ pub unsafe fn launch_sparse_trsv_csc_lower_level_f32( Ok(()) } -/// Launch CSC lower triangular solve kernel - f64 +/// Launch CSC lower triangular solve kernel (forward substitution) - f64 +/// +/// # Safety +/// +/// - All pointer arguments (`level_cols`, `col_ptrs`, `row_indices`, `values`, `diag_ptr`, `b`) +/// must be valid device memory pointers on the device associated with `context`. +/// - Buffer sizes must be consistent: `col_ptrs` has `n+1` entries, `row_indices` and `values` +/// have `nnz` entries, `b` has `n` elements, `diag_ptr` has `n` entries. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. #[allow(clippy::too_many_arguments)] pub unsafe fn launch_sparse_trsv_csc_lower_level_f64( context: &Arc, @@ -392,7 +461,15 @@ pub unsafe fn launch_sparse_trsv_csc_lower_level_f64( Ok(()) } -/// Launch CSC upper triangular solve kernel - f32 +/// Launch CSC upper triangular solve kernel (backward substitution) - f32 +/// +/// # Safety +/// +/// - All pointer arguments (`level_cols`, `col_ptrs`, `row_indices`, `values`, `diag_ptr`, `b`) +/// must be valid device memory pointers on the device associated with `context`. +/// - Buffer sizes must be consistent: `col_ptrs` has `n+1` entries, `row_indices` and `values` +/// have `nnz` entries, `b` has `n` elements, `diag_ptr` has `n` entries. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. #[allow(clippy::too_many_arguments)] pub unsafe fn launch_sparse_trsv_csc_upper_level_f32( context: &Arc, @@ -426,7 +503,15 @@ pub unsafe fn launch_sparse_trsv_csc_upper_level_f32( Ok(()) } -/// Launch CSC upper triangular solve kernel - f64 +/// Launch CSC upper triangular solve kernel (backward substitution) - f64 +/// +/// # Safety +/// +/// - All pointer arguments (`level_cols`, `col_ptrs`, `row_indices`, `values`, `diag_ptr`, `b`) +/// must be valid device memory pointers on the device associated with `context`. +/// - Buffer sizes must be consistent: `col_ptrs` has `n+1` entries, `row_indices` and `values` +/// have `nnz` entries, `b` has `n` elements, `diag_ptr` has `n` entries. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. #[allow(clippy::too_many_arguments)] pub unsafe fn launch_sparse_trsv_csc_upper_level_f64( context: &Arc, diff --git a/src/runtime/cuda/kernels/sparse_linalg/utils.rs b/src/runtime/cuda/kernels/sparse_linalg/utils.rs index f8d8b029..d69578d8 100644 --- a/src/runtime/cuda/kernels/sparse_linalg/utils.rs +++ b/src/runtime/cuda/kernels/sparse_linalg/utils.rs @@ -24,6 +24,14 @@ use crate::error::Result; /// /// For each row i, finds the index within that row's entries where col == i (diagonal). /// Stores -1 if no diagonal entry exists. +/// +/// # Safety +/// +/// - `row_ptrs`, `col_indices`, and `diag_indices` must be valid device memory pointers on the +/// device associated with `context`. +/// - `row_ptrs` must have at least `n + 1` elements; `diag_indices` must have at least `n`. +/// - `col_indices` must have at least `nnz` elements (as encoded in `row_ptrs`). +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn launch_find_diag_indices( context: &Arc, stream: &CudaStream, @@ -51,6 +59,14 @@ pub unsafe fn launch_find_diag_indices( /// /// For each column j, finds the index within that column's entries where row == j (diagonal). /// Stores -1 if no diagonal entry exists. +/// +/// # Safety +/// +/// - `col_ptrs`, `row_indices`, and `diag_ptr` must be valid device memory pointers on the +/// device associated with `context`. +/// - `col_ptrs` must have at least `n + 1` elements; `diag_ptr` must have at least `n`. +/// - `row_indices` must have at least `nnz` elements (as encoded in `col_ptrs`). +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn launch_find_diag_indices_csc( context: &Arc, stream: &CudaStream, @@ -78,7 +94,14 @@ pub unsafe fn launch_find_diag_indices_csc( // Copy Operations (may be unused but kept for potential future use) // ============================================================================ -/// Copy kernel - f32 +/// Copy `n` f32 elements from `src` to `dst` on device (GPU kernel) +/// +/// # Safety +/// +/// - `src` and `dst` must be valid device memory pointers on the device associated with `context`. +/// - Both buffers must have at least `n` f32 elements. +/// - `src` and `dst` must not alias. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. #[allow(dead_code)] pub unsafe fn launch_copy_f32( context: &Arc, @@ -101,7 +124,14 @@ pub unsafe fn launch_copy_f32( Ok(()) } -/// Copy kernel - f64 +/// Copy `n` f64 elements from `src` to `dst` on device (GPU kernel) +/// +/// # Safety +/// +/// - `src` and `dst` must be valid device memory pointers on the device associated with `context`. +/// - Both buffers must have at least `n` f64 elements. +/// - `src` and `dst` must not alias. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. #[allow(dead_code)] pub unsafe fn launch_copy_f64( context: &Arc, @@ -137,6 +167,14 @@ pub unsafe fn launch_copy_f64( /// * `l_map` - Mapping: l_map[i] = destination index in l_values, or -1 if not in L /// * `u_map` - Mapping: u_map[i] = destination index in u_values, or -1 if not in U /// * `nnz` - Number of non-zero elements in source +/// +/// # Safety +/// +/// - `src_values`, `l_values`, `u_values`, `l_map`, and `u_map` must be valid device memory +/// pointers on the device associated with `context`, each with at least `nnz` elements. +/// - All mapped indices in `l_map` and `u_map` (excluding -1) must be valid indices into their +/// respective output arrays (no out-of-bounds access). +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn launch_split_lu_scatter_f32( context: &Arc, stream: &CudaStream, @@ -165,6 +203,14 @@ pub unsafe fn launch_split_lu_scatter_f32( } /// Scatter values from factored LU matrix to separate L and U arrays - f64 +/// +/// # Safety +/// +/// - `src_values`, `l_values`, `u_values`, `l_map`, and `u_map` must be valid device memory +/// pointers on the device associated with `context`, each with at least `nnz` elements. +/// - All mapped indices in `l_map` and `u_map` (excluding -1) must be valid indices into their +/// respective output arrays (no out-of-bounds access). +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn launch_split_lu_scatter_f64( context: &Arc, stream: &CudaStream, @@ -203,6 +249,14 @@ pub unsafe fn launch_split_lu_scatter_f64( /// * `dst_values` - Output values array (lower triangular) /// * `lower_map` - Mapping: lower_map[i] = destination index, or -1 if not in lower /// * `nnz` - Number of non-zero elements in source +/// +/// # Safety +/// +/// - `src_values`, `dst_values`, and `lower_map` must be valid device memory pointers on the +/// device associated with `context`, each with at least `nnz` elements. +/// - All mapped indices in `lower_map` (excluding -1) must be valid indices into `dst_values` +/// (no out-of-bounds access). +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn launch_extract_lower_scatter_f32( context: &Arc, stream: &CudaStream, @@ -227,6 +281,14 @@ pub unsafe fn launch_extract_lower_scatter_f32( } /// Scatter values from source to lower triangular output - f64 +/// +/// # Safety +/// +/// - `src_values`, `dst_values`, and `lower_map` must be valid device memory pointers on the +/// device associated with `context`, each with at least `nnz` elements. +/// - All mapped indices in `lower_map` (excluding -1) must be valid indices into `dst_values` +/// (no out-of-bounds access). +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn launch_extract_lower_scatter_f64( context: &Arc, stream: &CudaStream, diff --git a/src/runtime/cuda/kernels/sparse_merge.rs b/src/runtime/cuda/kernels/sparse_merge.rs deleted file mode 100644 index 25d4b558..00000000 --- a/src/runtime/cuda/kernels/sparse_merge.rs +++ /dev/null @@ -1,1280 +0,0 @@ -//! Sparse matrix element-wise merge kernel launchers -//! -//! Two-pass algorithm for CSR element-wise operations: -//! 1. Count output size per row -//! 2. Exclusive scan to get row_ptrs -//! 3. Compute merged output - -#![allow(dead_code)] -#![allow(unsafe_op_in_unsafe_fn)] - -use cudarc::driver::PushKernelArg; -use cudarc::driver::safe::{CudaContext, CudaStream}; -use cudarc::types::CudaTypeName; -use std::sync::Arc; - -use super::loader::{ - BLOCK_SIZE, get_kernel_function, get_or_load_module, kernel_names, launch_config, -}; -use crate::dtype::DType; -use crate::error::{Error, Result}; -use crate::runtime::Runtime; -use crate::runtime::cuda::CudaRuntime; -use crate::tensor::Tensor; - -// ============================================================================ -// Generic Kernel Launcher Helpers (DRY principle) -// ============================================================================ - -/// Get dtype-specific kernel name suffix -fn dtype_suffix() -> Result<&'static str> { - match T::NAME { - "f32" => Ok("f32"), - "f64" => Ok("f64"), - "__half" => Ok("f16"), - "__nv_bfloat16" => Ok("bf16"), - _ => Err(Error::Internal(format!( - "Unsupported dtype for sparse operation: {}", - T::NAME - ))), - } -} - -/// Generic launcher for kernels without dtype template (count kernels) -/// -/// Eliminates duplication across count kernel launchers -unsafe fn launch_count_kernel( - context: &Arc, - stream: &CudaStream, - device_index: usize, - kernel_name: &str, - row_ptrs_a: u64, - col_indices_a: u64, - row_ptrs_b: u64, - col_indices_b: u64, - row_counts: u64, - nrows: usize, - error_context: &str, -) -> Result<()> { - let module = get_or_load_module(context, device_index, kernel_names::SPARSE_MERGE_MODULE)?; - let func = get_kernel_function(&module, kernel_name)?; - - let block_size = BLOCK_SIZE; - let grid_size = (nrows as u32 + block_size - 1) / block_size; - let nrows_i32 = nrows as i32; - - let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), 0); - let mut builder = stream.launch_builder(&func); - builder.arg(&row_ptrs_a); - builder.arg(&col_indices_a); - builder.arg(&row_ptrs_b); - builder.arg(&col_indices_b); - builder.arg(&row_counts); - builder.arg(&nrows_i32); - - builder - .launch(cfg) - .map_err(|e| Error::Internal(format!("{} kernel launch failed: {:?}", error_context, e)))?; - - Ok(()) -} - -/// Generic launcher for dtype-templated compute kernels (CSR format) -/// -/// Eliminates duplication across CSR add/sub/mul/div compute launchers -unsafe fn launch_csr_compute_kernel( - context: &Arc, - stream: &CudaStream, - device_index: usize, - kernel_base_name: &str, - row_ptrs_a: u64, - col_indices_a: u64, - values_a: u64, - row_ptrs_b: u64, - col_indices_b: u64, - values_b: u64, - out_row_ptrs: u64, - out_col_indices: u64, - out_values: u64, - nrows: usize, - error_context: &str, -) -> Result<()> { - let suffix = dtype_suffix::()?; - let kernel_name = format!("{}_{}", kernel_base_name, suffix); - - let module = get_or_load_module(context, device_index, kernel_names::SPARSE_MERGE_MODULE)?; - let func = get_kernel_function(&module, &kernel_name)?; - - let block_size = BLOCK_SIZE; - let grid_size = (nrows as u32 + block_size - 1) / block_size; - let nrows_i32 = nrows as i32; - - let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), 0); - let mut builder = stream.launch_builder(&func); - builder.arg(&row_ptrs_a); - builder.arg(&col_indices_a); - builder.arg(&values_a); - builder.arg(&row_ptrs_b); - builder.arg(&col_indices_b); - builder.arg(&values_b); - builder.arg(&out_row_ptrs); - builder.arg(&out_col_indices); - builder.arg(&out_values); - builder.arg(&nrows_i32); - - builder - .launch(cfg) - .map_err(|e| Error::Internal(format!("{} kernel launch failed: {:?}", error_context, e)))?; - - Ok(()) -} - -/// Generic launcher for dtype-templated compute kernels (CSC format) -/// -/// Eliminates duplication across CSC add/sub/mul/div compute launchers -unsafe fn launch_csc_compute_kernel( - context: &Arc, - stream: &CudaStream, - device_index: usize, - kernel_base_name: &str, - col_ptrs_a: u64, - row_indices_a: u64, - values_a: u64, - col_ptrs_b: u64, - row_indices_b: u64, - values_b: u64, - out_col_ptrs: u64, - out_row_indices: u64, - out_values: u64, - ncols: usize, - error_context: &str, -) -> Result<()> { - let suffix = dtype_suffix::()?; - let kernel_name = format!("{}_{}", kernel_base_name, suffix); - - let module = get_or_load_module(context, device_index, kernel_names::SPARSE_MERGE_MODULE)?; - let func = get_kernel_function(&module, &kernel_name)?; - - let block_size = BLOCK_SIZE; - let grid_size = (ncols as u32 + block_size - 1) / block_size; - let ncols_i32 = ncols as i32; - - let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), 0); - let mut builder = stream.launch_builder(&func); - builder.arg(&col_ptrs_a); - builder.arg(&row_indices_a); - builder.arg(&values_a); - builder.arg(&col_ptrs_b); - builder.arg(&row_indices_b); - builder.arg(&values_b); - builder.arg(&out_col_ptrs); - builder.arg(&out_row_indices); - builder.arg(&out_values); - builder.arg(&ncols_i32); - - builder - .launch(cfg) - .map_err(|e| Error::Internal(format!("{} kernel launch failed: {:?}", error_context, e)))?; - - Ok(()) -} - -// ============================================================================ -// Exclusive Scan (Prefix Sum) -// ============================================================================ - -/// Compute exclusive scan (prefix sum) on GPU tensor -/// -/// Input: [3, 1, 4, 2] -/// Output: [0, 3, 4, 8, 10] (n+1 elements, last is total sum) -/// -/// Uses GPU-native parallel scan (no CPU transfer) -fn exclusive_scan_i32( - context: &Arc, - stream: &CudaStream, - device_index: usize, - input: &Tensor, -) -> Result<(Tensor, usize)> { - let device = input.device(); - - // Use GPU scan (imported from scan module) - unsafe { super::scan::exclusive_scan_i32_gpu(context, stream, device_index, device, input) } -} - -// ============================================================================ -// Count Kernels -// ============================================================================ - -/// Launch CSR merge count kernel (for add/sub operations) -/// -/// Counts output size per row using union semantics -unsafe fn launch_csr_merge_count( - context: &Arc, - stream: &CudaStream, - device_index: usize, - row_ptrs_a: u64, - col_indices_a: u64, - row_ptrs_b: u64, - col_indices_b: u64, - row_counts: u64, - nrows: usize, -) -> Result<()> { - launch_count_kernel( - context, - stream, - device_index, - "csr_merge_count", - row_ptrs_a, - col_indices_a, - row_ptrs_b, - col_indices_b, - row_counts, - nrows, - "CUDA sparse merge count", - ) -} - -/// Launch CSR mul count kernel (intersection semantics) -unsafe fn launch_csr_mul_count( - context: &Arc, - stream: &CudaStream, - device_index: usize, - row_ptrs_a: u64, - col_indices_a: u64, - row_ptrs_b: u64, - col_indices_b: u64, - row_counts: u64, - nrows: usize, -) -> Result<()> { - launch_count_kernel( - context, - stream, - device_index, - "csr_mul_count", - row_ptrs_a, - col_indices_a, - row_ptrs_b, - col_indices_b, - row_counts, - nrows, - "CUDA sparse mul count", - ) -} - -// ============================================================================ -// Compute Kernels -// ============================================================================ - -/// Launch CSR add compute kernel -unsafe fn launch_csr_add_compute( - context: &Arc, - stream: &CudaStream, - device_index: usize, - row_ptrs_a: u64, - col_indices_a: u64, - values_a: u64, - row_ptrs_b: u64, - col_indices_b: u64, - values_b: u64, - out_row_ptrs: u64, - out_col_indices: u64, - out_values: u64, - nrows: usize, -) -> Result<()> { - launch_csr_compute_kernel::( - context, - stream, - device_index, - "csr_add_compute", - row_ptrs_a, - col_indices_a, - values_a, - row_ptrs_b, - col_indices_b, - values_b, - out_row_ptrs, - out_col_indices, - out_values, - nrows, - "CUDA sparse add compute", - ) -} - -/// Launch CSR sub compute kernel -unsafe fn launch_csr_sub_compute( - context: &Arc, - stream: &CudaStream, - device_index: usize, - row_ptrs_a: u64, - col_indices_a: u64, - values_a: u64, - row_ptrs_b: u64, - col_indices_b: u64, - values_b: u64, - out_row_ptrs: u64, - out_col_indices: u64, - out_values: u64, - nrows: usize, -) -> Result<()> { - launch_csr_compute_kernel::( - context, - stream, - device_index, - "csr_sub_compute", - row_ptrs_a, - col_indices_a, - values_a, - row_ptrs_b, - col_indices_b, - values_b, - out_row_ptrs, - out_col_indices, - out_values, - nrows, - "CUDA sparse sub compute", - ) -} - -/// Launch CSR mul compute kernel -unsafe fn launch_csr_mul_compute( - context: &Arc, - stream: &CudaStream, - device_index: usize, - row_ptrs_a: u64, - col_indices_a: u64, - values_a: u64, - row_ptrs_b: u64, - col_indices_b: u64, - values_b: u64, - out_row_ptrs: u64, - out_col_indices: u64, - out_values: u64, - nrows: usize, -) -> Result<()> { - launch_csr_compute_kernel::( - context, - stream, - device_index, - "csr_mul_compute", - row_ptrs_a, - col_indices_a, - values_a, - row_ptrs_b, - col_indices_b, - values_b, - out_row_ptrs, - out_col_indices, - out_values, - nrows, - "CUDA sparse mul compute", - ) -} - -/// Launch CSR div compute kernel -unsafe fn launch_csr_div_compute( - context: &Arc, - stream: &CudaStream, - device_index: usize, - row_ptrs_a: u64, - col_indices_a: u64, - values_a: u64, - row_ptrs_b: u64, - col_indices_b: u64, - values_b: u64, - out_row_ptrs: u64, - out_col_indices: u64, - out_values: u64, - nrows: usize, -) -> Result<()> { - launch_csr_compute_kernel::( - context, - stream, - device_index, - "csr_div_compute", - row_ptrs_a, - col_indices_a, - values_a, - row_ptrs_b, - col_indices_b, - values_b, - out_row_ptrs, - out_col_indices, - out_values, - nrows, - "CUDA sparse div compute", - ) -} - -/// Launch CSC intersect count kernel (for mul/div) -unsafe fn launch_csc_intersect_count( - context: &Arc, - stream: &CudaStream, - device_index: usize, - col_ptrs_a: u64, - row_indices_a: u64, - col_ptrs_b: u64, - row_indices_b: u64, - col_counts: u64, - ncols: usize, -) -> Result<()> { - unsafe { - let module = get_or_load_module(context, device_index, kernel_names::SPARSE_MERGE_MODULE)?; - let func = get_kernel_function(&module, "csc_intersect_count")?; - - let block_size = BLOCK_SIZE; - let grid_size = (ncols as u32 + block_size - 1) / block_size; - let ncols_i32 = ncols as i32; - - let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), 0); - let mut builder = stream.launch_builder(&func); - builder.arg(&col_ptrs_a); - builder.arg(&row_indices_a); - builder.arg(&col_ptrs_b); - builder.arg(&row_indices_b); - builder.arg(&col_counts); - builder.arg(&ncols_i32); - - builder.launch(cfg).map_err(|e| { - Error::Internal(format!( - "CUDA CSC intersect count kernel launch failed: {:?}", - e - )) - })?; - - Ok(()) - } -} - -/// Launch CSC add compute kernel -unsafe fn launch_csc_add_compute( - context: &Arc, - stream: &CudaStream, - device_index: usize, - col_ptrs_a: u64, - row_indices_a: u64, - values_a: u64, - col_ptrs_b: u64, - row_indices_b: u64, - values_b: u64, - out_col_ptrs: u64, - out_row_indices: u64, - out_values: u64, - ncols: usize, -) -> Result<()> { - let kernel_name = match T::NAME { - "f32" => "csc_add_compute_f32", - "f64" => "csc_add_compute_f64", - "__half" => "csc_add_compute_f16", - "__nv_bfloat16" => "csc_add_compute_bf16", - _ => { - return Err(Error::Internal(format!( - "Unsupported dtype for sparse CSC add: {}", - T::NAME - ))); - } - }; - - unsafe { - let module = get_or_load_module(context, device_index, kernel_names::SPARSE_MERGE_MODULE)?; - let func = get_kernel_function(&module, kernel_name)?; - - let block_size = BLOCK_SIZE; - let grid_size = (ncols as u32 + block_size - 1) / block_size; - let ncols_i32 = ncols as i32; - - let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), 0); - let mut builder = stream.launch_builder(&func); - builder.arg(&col_ptrs_a); - builder.arg(&row_indices_a); - builder.arg(&values_a); - builder.arg(&col_ptrs_b); - builder.arg(&row_indices_b); - builder.arg(&values_b); - builder.arg(&out_col_ptrs); - builder.arg(&out_row_indices); - builder.arg(&out_values); - builder.arg(&ncols_i32); - - builder.launch(cfg).map_err(|e| { - Error::Internal(format!( - "CUDA CSC add compute kernel launch failed: {:?}", - e - )) - })?; - - Ok(()) - } -} - -/// Launch CSC sub compute kernel -unsafe fn launch_csc_sub_compute( - context: &Arc, - stream: &CudaStream, - device_index: usize, - col_ptrs_a: u64, - row_indices_a: u64, - values_a: u64, - col_ptrs_b: u64, - row_indices_b: u64, - values_b: u64, - out_col_ptrs: u64, - out_row_indices: u64, - out_values: u64, - ncols: usize, -) -> Result<()> { - let kernel_name = match T::NAME { - "f32" => "csc_sub_compute_f32", - "f64" => "csc_sub_compute_f64", - "__half" => "csc_sub_compute_f16", - "__nv_bfloat16" => "csc_sub_compute_bf16", - _ => { - return Err(Error::Internal(format!( - "Unsupported dtype for sparse CSC sub: {}", - T::NAME - ))); - } - }; - - unsafe { - let module = get_or_load_module(context, device_index, kernel_names::SPARSE_MERGE_MODULE)?; - let func = get_kernel_function(&module, kernel_name)?; - - let block_size = BLOCK_SIZE; - let grid_size = (ncols as u32 + block_size - 1) / block_size; - let ncols_i32 = ncols as i32; - - let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), 0); - let mut builder = stream.launch_builder(&func); - builder.arg(&col_ptrs_a); - builder.arg(&row_indices_a); - builder.arg(&values_a); - builder.arg(&col_ptrs_b); - builder.arg(&row_indices_b); - builder.arg(&values_b); - builder.arg(&out_col_ptrs); - builder.arg(&out_row_indices); - builder.arg(&out_values); - builder.arg(&ncols_i32); - - builder.launch(cfg).map_err(|e| { - Error::Internal(format!( - "CUDA CSC sub compute kernel launch failed: {:?}", - e - )) - })?; - - Ok(()) - } -} - -/// Launch CSC mul compute kernel -unsafe fn launch_csc_mul_compute( - context: &Arc, - stream: &CudaStream, - device_index: usize, - col_ptrs_a: u64, - row_indices_a: u64, - values_a: u64, - col_ptrs_b: u64, - row_indices_b: u64, - values_b: u64, - out_col_ptrs: u64, - out_row_indices: u64, - out_values: u64, - ncols: usize, -) -> Result<()> { - let kernel_name = match T::NAME { - "f32" => "csc_mul_compute_f32", - "f64" => "csc_mul_compute_f64", - "__half" => "csc_mul_compute_f16", - "__nv_bfloat16" => "csc_mul_compute_bf16", - _ => { - return Err(Error::Internal(format!( - "Unsupported dtype for sparse CSC mul: {}", - T::NAME - ))); - } - }; - - unsafe { - let module = get_or_load_module(context, device_index, kernel_names::SPARSE_MERGE_MODULE)?; - let func = get_kernel_function(&module, kernel_name)?; - - let block_size = BLOCK_SIZE; - let grid_size = (ncols as u32 + block_size - 1) / block_size; - let ncols_i32 = ncols as i32; - - let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), 0); - let mut builder = stream.launch_builder(&func); - builder.arg(&col_ptrs_a); - builder.arg(&row_indices_a); - builder.arg(&values_a); - builder.arg(&col_ptrs_b); - builder.arg(&row_indices_b); - builder.arg(&values_b); - builder.arg(&out_col_ptrs); - builder.arg(&out_row_indices); - builder.arg(&out_values); - builder.arg(&ncols_i32); - - builder.launch(cfg).map_err(|e| { - Error::Internal(format!( - "CUDA CSC mul compute kernel launch failed: {:?}", - e - )) - })?; - - Ok(()) - } -} - -/// Launch CSC div compute kernel -unsafe fn launch_csc_div_compute( - context: &Arc, - stream: &CudaStream, - device_index: usize, - col_ptrs_a: u64, - row_indices_a: u64, - values_a: u64, - col_ptrs_b: u64, - row_indices_b: u64, - values_b: u64, - out_col_ptrs: u64, - out_row_indices: u64, - out_values: u64, - ncols: usize, -) -> Result<()> { - let kernel_name = match T::NAME { - "f32" => "csc_div_compute_f32", - "f64" => "csc_div_compute_f64", - "__half" => "csc_div_compute_f16", - "__nv_bfloat16" => "csc_div_compute_bf16", - _ => { - return Err(Error::Internal(format!( - "Unsupported dtype for sparse CSC div: {}", - T::NAME - ))); - } - }; - - unsafe { - let module = get_or_load_module(context, device_index, kernel_names::SPARSE_MERGE_MODULE)?; - let func = get_kernel_function(&module, kernel_name)?; - - let block_size = BLOCK_SIZE; - let grid_size = (ncols as u32 + block_size - 1) / block_size; - let ncols_i32 = ncols as i32; - - let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), 0); - let mut builder = stream.launch_builder(&func); - builder.arg(&col_ptrs_a); - builder.arg(&row_indices_a); - builder.arg(&values_a); - builder.arg(&col_ptrs_b); - builder.arg(&row_indices_b); - builder.arg(&values_b); - builder.arg(&out_col_ptrs); - builder.arg(&out_row_indices); - builder.arg(&out_values); - builder.arg(&ncols_i32); - - builder.launch(cfg).map_err(|e| { - Error::Internal(format!( - "CUDA CSC div compute kernel launch failed: {:?}", - e - )) - })?; - - Ok(()) - } -} - -// ============================================================================ -// High-level Merge Operations -// ============================================================================ - -/// Two-pass CSR addition: C = A + B (union semantics) -/// -/// Now uses generic_csr_merge with AddMerge strategy to eliminate duplication. -pub unsafe fn csr_add_merge( - context: &Arc, - stream: &CudaStream, - device_index: usize, - device: &::Device, - dtype: DType, - row_ptrs_a: &Tensor, - col_indices_a: &Tensor, - values_a: &Tensor, - row_ptrs_b: &Tensor, - col_indices_b: &Tensor, - values_b: &Tensor, - nrows: usize, -) -> Result<( - Tensor, - Tensor, - Tensor, -)> { - use super::sparse_strategy::AddMerge; - generic_csr_merge::( - context, - stream, - device_index, - device, - dtype, - row_ptrs_a, - col_indices_a, - values_a, - row_ptrs_b, - col_indices_b, - values_b, - nrows, - ) -} - -/// Two-pass CSR subtraction: C = A - B (union semantics) -/// -/// Now uses generic_csr_merge with SubMerge strategy to eliminate duplication. -pub unsafe fn csr_sub_merge( - context: &Arc, - stream: &CudaStream, - device_index: usize, - device: &::Device, - dtype: DType, - row_ptrs_a: &Tensor, - col_indices_a: &Tensor, - values_a: &Tensor, - row_ptrs_b: &Tensor, - col_indices_b: &Tensor, - values_b: &Tensor, - nrows: usize, -) -> Result<( - Tensor, - Tensor, - Tensor, -)> { - use super::sparse_strategy::SubMerge; - generic_csr_merge::( - context, - stream, - device_index, - device, - dtype, - row_ptrs_a, - col_indices_a, - values_a, - row_ptrs_b, - col_indices_b, - values_b, - nrows, - ) -} - -/// Two-pass CSR element-wise multiplication: C = A .* B (intersection semantics) -/// -/// Now uses generic_csr_merge with MulMerge strategy to eliminate duplication. -pub unsafe fn csr_mul_merge( - context: &Arc, - stream: &CudaStream, - device_index: usize, - device: &::Device, - dtype: DType, - row_ptrs_a: &Tensor, - col_indices_a: &Tensor, - values_a: &Tensor, - row_ptrs_b: &Tensor, - col_indices_b: &Tensor, - values_b: &Tensor, - nrows: usize, -) -> Result<( - Tensor, - Tensor, - Tensor, -)> { - use super::sparse_strategy::MulMerge; - generic_csr_merge::( - context, - stream, - device_index, - device, - dtype, - row_ptrs_a, - col_indices_a, - values_a, - row_ptrs_b, - col_indices_b, - values_b, - nrows, - ) -} - -/// Two-pass CSR element-wise division: C = A ./ B (intersection semantics) -pub unsafe fn csr_div_merge( - context: &Arc, - stream: &CudaStream, - device_index: usize, - device: &::Device, - dtype: DType, - row_ptrs_a: &Tensor, - col_indices_a: &Tensor, - values_a: &Tensor, - row_ptrs_b: &Tensor, - col_indices_b: &Tensor, - values_b: &Tensor, - nrows: usize, -) -> Result<( - Tensor, - Tensor, - Tensor, -)> { - use super::sparse_strategy::DivMerge; - generic_csr_merge::( - context, - stream, - device_index, - device, - dtype, - row_ptrs_a, - col_indices_a, - values_a, - row_ptrs_b, - col_indices_b, - values_b, - nrows, - ) -} - -// ============================================================================ -// High-level CSC Merge Operations -// ============================================================================ - -/// Two-pass CSC addition: C = A + B (union semantics) -pub unsafe fn csc_add_merge( - context: &Arc, - stream: &CudaStream, - device_index: usize, - device: &::Device, - dtype: DType, - col_ptrs_a: &Tensor, - row_indices_a: &Tensor, - values_a: &Tensor, - col_ptrs_b: &Tensor, - row_indices_b: &Tensor, - values_b: &Tensor, - ncols: usize, -) -> Result<( - Tensor, - Tensor, - Tensor, -)> { - use super::sparse_strategy::AddMerge; - generic_csc_merge::( - context, - stream, - device_index, - device, - dtype, - col_ptrs_a, - row_indices_a, - values_a, - col_ptrs_b, - row_indices_b, - values_b, - ncols, - ) -} - -/// Two-pass CSC subtraction: C = A - B (union semantics) -pub unsafe fn csc_sub_merge( - context: &Arc, - stream: &CudaStream, - device_index: usize, - device: &::Device, - dtype: DType, - col_ptrs_a: &Tensor, - row_indices_a: &Tensor, - values_a: &Tensor, - col_ptrs_b: &Tensor, - row_indices_b: &Tensor, - values_b: &Tensor, - ncols: usize, -) -> Result<( - Tensor, - Tensor, - Tensor, -)> { - use super::sparse_strategy::SubMerge; - generic_csc_merge::( - context, - stream, - device_index, - device, - dtype, - col_ptrs_a, - row_indices_a, - values_a, - col_ptrs_b, - row_indices_b, - values_b, - ncols, - ) -} - -/// Two-pass CSC element-wise multiplication: C = A .* B (intersection semantics) -pub unsafe fn csc_mul_merge( - context: &Arc, - stream: &CudaStream, - device_index: usize, - device: &::Device, - dtype: DType, - col_ptrs_a: &Tensor, - row_indices_a: &Tensor, - values_a: &Tensor, - col_ptrs_b: &Tensor, - row_indices_b: &Tensor, - values_b: &Tensor, - ncols: usize, -) -> Result<( - Tensor, - Tensor, - Tensor, -)> { - use super::sparse_strategy::MulMerge; - generic_csc_merge::( - context, - stream, - device_index, - device, - dtype, - col_ptrs_a, - row_indices_a, - values_a, - col_ptrs_b, - row_indices_b, - values_b, - ncols, - ) -} - -/// Two-pass CSC element-wise division: C = A ./ B (intersection semantics) -pub unsafe fn csc_div_merge( - context: &Arc, - stream: &CudaStream, - device_index: usize, - device: &::Device, - dtype: DType, - col_ptrs_a: &Tensor, - row_indices_a: &Tensor, - values_a: &Tensor, - col_ptrs_b: &Tensor, - row_indices_b: &Tensor, - values_b: &Tensor, - ncols: usize, -) -> Result<( - Tensor, - Tensor, - Tensor, -)> { - use super::sparse_strategy::DivMerge; - generic_csc_merge::( - context, - stream, - device_index, - device, - dtype, - col_ptrs_a, - row_indices_a, - values_a, - col_ptrs_b, - row_indices_b, - values_b, - ncols, - ) -} - -// ============================================================================ -// Generic Merge Implementation (Zero Duplication) -// ============================================================================ - -use super::sparse_strategy::{MergeStrategy, SparseFormat}; - -/// Generic two-pass CSR merge using strategy pattern -/// -/// Eliminates code duplication across add/sub/mul/div operations by abstracting -/// the merge semantics through the MergeStrategy trait. -/// -/// # Type Parameters -/// -/// * `T` - Element type (f32, f64, etc.) -/// * `S` - Merge strategy (AddMerge, SubMerge, MulMerge, DivMerge) -/// -/// # Algorithm -/// -/// 1. **Count**: Determine output size per row using strategy-specific semantics -/// 2. **Scan**: Compute row_ptrs via exclusive prefix sum -/// 3. **Compute**: Merge values using strategy-specific operation -pub unsafe fn generic_csr_merge( - context: &Arc, - stream: &CudaStream, - device_index: usize, - device: &::Device, - dtype: DType, - row_ptrs_a: &Tensor, - col_indices_a: &Tensor, - values_a: &Tensor, - row_ptrs_b: &Tensor, - col_indices_b: &Tensor, - values_b: &Tensor, - nrows: usize, -) -> Result<( - Tensor, - Tensor, - Tensor, -)> { - // Pass 1: Count output size per row - let row_counts = Tensor::::zeros(&[nrows], DType::I32, device); - - // Launch count kernel (union vs intersection semantics determined by strategy) - let count_kernel_name = S::count_kernel_name(SparseFormat::Csr); - let module = get_or_load_module(context, device_index, kernel_names::SPARSE_MERGE_MODULE)?; - let function = get_kernel_function(&module, count_kernel_name)?; - - let block_size = BLOCK_SIZE; - let grid_size = (nrows as u32 + block_size - 1) / block_size; - let nrows_i32 = nrows as i32; - - let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), 0); - let mut builder = stream.launch_builder(&function); - - // Store pointers to avoid temporary value issues - let row_ptrs_a_ptr = row_ptrs_a.storage().ptr(); - let col_indices_a_ptr = col_indices_a.storage().ptr(); - let row_ptrs_b_ptr = row_ptrs_b.storage().ptr(); - let col_indices_b_ptr = col_indices_b.storage().ptr(); - let row_counts_ptr = row_counts.storage().ptr(); - - builder.arg(&row_ptrs_a_ptr); - builder.arg(&col_indices_a_ptr); - builder.arg(&row_ptrs_b_ptr); - builder.arg(&col_indices_b_ptr); - builder.arg(&row_counts_ptr); - builder.arg(&nrows_i32); - - // SAFETY: Kernel launch is unsafe because: - // 1. Raw pointers are passed to CUDA kernel - // 2. Kernel accesses GPU memory - // Safety requirements satisfied: - // - All pointers are valid GPU memory addresses from CudaRuntime tensors - // - Tensor lifetimes ensure memory is valid during kernel execution - // - nrows matches the actual tensor dimensions - // - Stream synchronization ensures no data races - unsafe { - builder.launch(cfg).map_err(|e| { - Error::Internal(format!( - "CUDA {} kernel launch failed (nrows={}, strategy={:?}): {:?}", - count_kernel_name, - nrows, - S::OP, - e - )) - })?; - } - - // Synchronize to ensure counts are ready - stream - .synchronize() - .map_err(|e| Error::Internal(format!("Stream synchronize failed: {:?}", e)))?; - - // Exclusive scan to get row_ptrs and total_nnz - let (out_row_ptrs, total_nnz) = exclusive_scan_i32(context, stream, device_index, &row_counts)?; - - // Pass 2: Allocate output and compute merged result - let out_col_indices = Tensor::::zeros(&[total_nnz], DType::I32, device); - let out_values = Tensor::::zeros(&[total_nnz], dtype, device); - - // Launch compute kernel (operation-specific) - let compute_kernel_name = S::compute_kernel_name(SparseFormat::Csr, T::NAME); - let function = get_kernel_function(&module, &compute_kernel_name)?; - - let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), 0); - let mut builder = stream.launch_builder(&function); - - // Store pointers to avoid temporary value issues - let row_ptrs_a_ptr = row_ptrs_a.storage().ptr(); - let col_indices_a_ptr = col_indices_a.storage().ptr(); - let values_a_ptr = values_a.storage().ptr(); - let row_ptrs_b_ptr = row_ptrs_b.storage().ptr(); - let col_indices_b_ptr = col_indices_b.storage().ptr(); - let values_b_ptr = values_b.storage().ptr(); - let out_row_ptrs_ptr = out_row_ptrs.storage().ptr(); - let out_col_indices_ptr = out_col_indices.storage().ptr(); - let out_values_ptr = out_values.storage().ptr(); - - builder.arg(&row_ptrs_a_ptr); - builder.arg(&col_indices_a_ptr); - builder.arg(&values_a_ptr); - builder.arg(&row_ptrs_b_ptr); - builder.arg(&col_indices_b_ptr); - builder.arg(&values_b_ptr); - builder.arg(&out_row_ptrs_ptr); - builder.arg(&out_col_indices_ptr); - builder.arg(&out_values_ptr); - builder.arg(&nrows_i32); - - // SAFETY: Kernel launch is unsafe because: - // 1. Raw pointers are passed to CUDA kernel - // 2. Kernel writes to output tensors - // Safety requirements satisfied: - // - All input pointers are valid GPU memory from input tensors - // - Output tensors allocated with correct size (total_nnz from scan) - // - Tensor ownership prevents concurrent modification - // - Stream ordering ensures count kernel completed before compute kernel - unsafe { - builder.launch(cfg).map_err(|e| { - Error::Internal(format!( - "CUDA {} kernel launch failed (nrows={}, total_nnz={}, strategy={:?}): {:?}", - compute_kernel_name, - nrows, - total_nnz, - S::OP, - e - )) - })?; - } - - Ok((out_row_ptrs, out_col_indices, out_values)) -} - -/// Generic two-pass CSC merge using strategy pattern -/// -/// CSC variant of generic_csr_merge. See generic_csr_merge for details. -pub unsafe fn generic_csc_merge( - context: &Arc, - stream: &CudaStream, - device_index: usize, - device: &::Device, - dtype: DType, - col_ptrs_a: &Tensor, - row_indices_a: &Tensor, - values_a: &Tensor, - col_ptrs_b: &Tensor, - row_indices_b: &Tensor, - values_b: &Tensor, - ncols: usize, -) -> Result<( - Tensor, - Tensor, - Tensor, -)> { - // Pass 1: Count output size per column - let col_counts = Tensor::::zeros(&[ncols], DType::I32, device); - - // Launch count kernel - let count_kernel_name = S::count_kernel_name(SparseFormat::Csc); - let module = get_or_load_module(context, device_index, kernel_names::SPARSE_MERGE_MODULE)?; - let function = get_kernel_function(&module, count_kernel_name)?; - - let block_size = BLOCK_SIZE; - let grid_size = (ncols as u32 + block_size - 1) / block_size; - let ncols_i32 = ncols as i32; - - let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), 0); - let mut builder = stream.launch_builder(&function); - - // Store pointers to avoid temporary value issues - let col_ptrs_a_ptr = col_ptrs_a.storage().ptr(); - let row_indices_a_ptr = row_indices_a.storage().ptr(); - let col_ptrs_b_ptr = col_ptrs_b.storage().ptr(); - let row_indices_b_ptr = row_indices_b.storage().ptr(); - let col_counts_ptr = col_counts.storage().ptr(); - - builder.arg(&col_ptrs_a_ptr); - builder.arg(&row_indices_a_ptr); - builder.arg(&col_ptrs_b_ptr); - builder.arg(&row_indices_b_ptr); - builder.arg(&col_counts_ptr); - builder.arg(&ncols_i32); - - // SAFETY: Kernel launch is unsafe because: - // 1. Raw pointers are passed to CUDA kernel - // 2. Kernel accesses GPU memory - // Safety requirements satisfied: - // - All pointers are valid GPU memory addresses from CudaRuntime tensors - // - Tensor lifetimes ensure memory is valid during kernel execution - // - ncols matches the actual tensor dimensions - // - Stream synchronization ensures no data races - unsafe { - builder.launch(cfg).map_err(|e| { - Error::Internal(format!( - "CUDA {} kernel launch failed (ncols={}, strategy={:?}): {:?}", - count_kernel_name, - ncols, - S::OP, - e - )) - })?; - } - - // Synchronize to ensure counts are ready - stream - .synchronize() - .map_err(|e| Error::Internal(format!("Stream synchronize failed: {:?}", e)))?; - - // Exclusive scan to get col_ptrs and total_nnz - let (out_col_ptrs, total_nnz) = exclusive_scan_i32(context, stream, device_index, &col_counts)?; - - // Pass 2: Allocate output and compute merged result - let out_row_indices = Tensor::::zeros(&[total_nnz], DType::I32, device); - let out_values = Tensor::::zeros(&[total_nnz], dtype, device); - - // Launch compute kernel - let compute_kernel_name = S::compute_kernel_name(SparseFormat::Csc, T::NAME); - let function = get_kernel_function(&module, &compute_kernel_name)?; - - let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), 0); - let mut builder = stream.launch_builder(&function); - - // Store pointers to avoid temporary value issues - let col_ptrs_a_ptr = col_ptrs_a.storage().ptr(); - let row_indices_a_ptr = row_indices_a.storage().ptr(); - let values_a_ptr = values_a.storage().ptr(); - let col_ptrs_b_ptr = col_ptrs_b.storage().ptr(); - let row_indices_b_ptr = row_indices_b.storage().ptr(); - let values_b_ptr = values_b.storage().ptr(); - let out_col_ptrs_ptr = out_col_ptrs.storage().ptr(); - let out_row_indices_ptr = out_row_indices.storage().ptr(); - let out_values_ptr = out_values.storage().ptr(); - - builder.arg(&col_ptrs_a_ptr); - builder.arg(&row_indices_a_ptr); - builder.arg(&values_a_ptr); - builder.arg(&col_ptrs_b_ptr); - builder.arg(&row_indices_b_ptr); - builder.arg(&values_b_ptr); - builder.arg(&out_col_ptrs_ptr); - builder.arg(&out_row_indices_ptr); - builder.arg(&out_values_ptr); - builder.arg(&ncols_i32); - - // SAFETY: Kernel launch is unsafe because: - // 1. Raw pointers are passed to CUDA kernel - // 2. Kernel writes to output tensors - // Safety requirements satisfied: - // - All input pointers are valid GPU memory from input tensors - // - Output tensors allocated with correct size (total_nnz from scan) - // - Tensor ownership prevents concurrent modification - // - Stream ordering ensures count kernel completed before compute kernel - unsafe { - builder.launch(cfg).map_err(|e| { - Error::Internal(format!( - "CUDA {} kernel launch failed (ncols={}, total_nnz={}, strategy={:?}): {:?}", - compute_kernel_name, - ncols, - total_nnz, - S::OP, - e - )) - })?; - } - - Ok((out_col_ptrs, out_row_indices, out_values)) -} diff --git a/src/runtime/cuda/kernels/sparse_merge/csc.rs b/src/runtime/cuda/kernels/sparse_merge/csc.rs new file mode 100644 index 00000000..ab648f9d --- /dev/null +++ b/src/runtime/cuda/kernels/sparse_merge/csc.rs @@ -0,0 +1,517 @@ +//! CSC (Compressed Sparse Column) merge kernel launchers +//! +//! Low-level count and compute launchers plus high-level public merge operations +//! for CSC format sparse matrices. + +#![allow(dead_code)] +#![allow(unsafe_op_in_unsafe_fn)] + +use cudarc::driver::PushKernelArg; +use cudarc::driver::safe::{CudaContext, CudaStream}; +use cudarc::types::CudaTypeName; +use std::sync::Arc; + +use crate::dtype::DType; +use crate::error::{Error, Result}; +use crate::runtime::Runtime; +use crate::runtime::cuda::CudaRuntime; +use crate::tensor::Tensor; + +use super::super::loader::{ + BLOCK_SIZE, get_kernel_function, get_or_load_module, kernel_names, launch_config, +}; + +// ============================================================================ +// Count Kernels +// ============================================================================ + +/// Launch CSC intersect count kernel (for mul/div) +/// +/// # Safety +/// +/// - `col_ptrs_a`, `row_indices_a`, `col_ptrs_b`, `row_indices_b`, and `col_counts` must be +/// valid device memory pointers on the device associated with `context`. +/// - `ncols` must match the number of columns in both input CSC matrices. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. +pub(super) unsafe fn launch_csc_intersect_count( + context: &Arc, + stream: &CudaStream, + device_index: usize, + col_ptrs_a: u64, + row_indices_a: u64, + col_ptrs_b: u64, + row_indices_b: u64, + col_counts: u64, + ncols: usize, +) -> Result<()> { + unsafe { + let module = get_or_load_module(context, device_index, kernel_names::SPARSE_MERGE_MODULE)?; + let func = get_kernel_function(&module, "csc_intersect_count")?; + + let block_size = BLOCK_SIZE; + let grid_size = (ncols as u32 + block_size - 1) / block_size; + let ncols_i32 = ncols as i32; + + let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), 0); + let mut builder = stream.launch_builder(&func); + builder.arg(&col_ptrs_a); + builder.arg(&row_indices_a); + builder.arg(&col_ptrs_b); + builder.arg(&row_indices_b); + builder.arg(&col_counts); + builder.arg(&ncols_i32); + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!( + "CUDA CSC intersect count kernel launch failed: {:?}", + e + )) + })?; + + Ok(()) + } +} + +// ============================================================================ +// Compute Kernels +// ============================================================================ + +/// Launch CSC add compute kernel +/// +/// # Safety +/// +/// - All pointer arguments must be valid device memory pointers on the device associated +/// with `context`. Output buffers must be pre-allocated to the correct sizes. +/// - `ncols` must match the number of columns in both input CSC matrices. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. +pub(super) unsafe fn launch_csc_add_compute( + context: &Arc, + stream: &CudaStream, + device_index: usize, + col_ptrs_a: u64, + row_indices_a: u64, + values_a: u64, + col_ptrs_b: u64, + row_indices_b: u64, + values_b: u64, + out_col_ptrs: u64, + out_row_indices: u64, + out_values: u64, + ncols: usize, +) -> Result<()> { + let kernel_name = match T::NAME { + "f32" => "csc_add_compute_f32", + "f64" => "csc_add_compute_f64", + "__half" => "csc_add_compute_f16", + "__nv_bfloat16" => "csc_add_compute_bf16", + _ => { + return Err(Error::Internal(format!( + "Unsupported dtype for sparse CSC add: {}", + T::NAME + ))); + } + }; + + unsafe { + let module = get_or_load_module(context, device_index, kernel_names::SPARSE_MERGE_MODULE)?; + let func = get_kernel_function(&module, kernel_name)?; + + let block_size = BLOCK_SIZE; + let grid_size = (ncols as u32 + block_size - 1) / block_size; + let ncols_i32 = ncols as i32; + + let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), 0); + let mut builder = stream.launch_builder(&func); + builder.arg(&col_ptrs_a); + builder.arg(&row_indices_a); + builder.arg(&values_a); + builder.arg(&col_ptrs_b); + builder.arg(&row_indices_b); + builder.arg(&values_b); + builder.arg(&out_col_ptrs); + builder.arg(&out_row_indices); + builder.arg(&out_values); + builder.arg(&ncols_i32); + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!( + "CUDA CSC add compute kernel launch failed: {:?}", + e + )) + })?; + + Ok(()) + } +} + +/// Launch CSC sub compute kernel +/// +/// # Safety +/// +/// - All pointer arguments must be valid device memory pointers on the device associated +/// with `context`. Output buffers must be pre-allocated to the correct sizes. +/// - `ncols` must match the number of columns in both input CSC matrices. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. +pub(super) unsafe fn launch_csc_sub_compute( + context: &Arc, + stream: &CudaStream, + device_index: usize, + col_ptrs_a: u64, + row_indices_a: u64, + values_a: u64, + col_ptrs_b: u64, + row_indices_b: u64, + values_b: u64, + out_col_ptrs: u64, + out_row_indices: u64, + out_values: u64, + ncols: usize, +) -> Result<()> { + let kernel_name = match T::NAME { + "f32" => "csc_sub_compute_f32", + "f64" => "csc_sub_compute_f64", + "__half" => "csc_sub_compute_f16", + "__nv_bfloat16" => "csc_sub_compute_bf16", + _ => { + return Err(Error::Internal(format!( + "Unsupported dtype for sparse CSC sub: {}", + T::NAME + ))); + } + }; + + unsafe { + let module = get_or_load_module(context, device_index, kernel_names::SPARSE_MERGE_MODULE)?; + let func = get_kernel_function(&module, kernel_name)?; + + let block_size = BLOCK_SIZE; + let grid_size = (ncols as u32 + block_size - 1) / block_size; + let ncols_i32 = ncols as i32; + + let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), 0); + let mut builder = stream.launch_builder(&func); + builder.arg(&col_ptrs_a); + builder.arg(&row_indices_a); + builder.arg(&values_a); + builder.arg(&col_ptrs_b); + builder.arg(&row_indices_b); + builder.arg(&values_b); + builder.arg(&out_col_ptrs); + builder.arg(&out_row_indices); + builder.arg(&out_values); + builder.arg(&ncols_i32); + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!( + "CUDA CSC sub compute kernel launch failed: {:?}", + e + )) + })?; + + Ok(()) + } +} + +/// Launch CSC mul compute kernel +/// +/// # Safety +/// +/// - All pointer arguments must be valid device memory pointers on the device associated +/// with `context`. Output buffers must be pre-allocated to the correct sizes. +/// - `ncols` must match the number of columns in both input CSC matrices. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. +pub(super) unsafe fn launch_csc_mul_compute( + context: &Arc, + stream: &CudaStream, + device_index: usize, + col_ptrs_a: u64, + row_indices_a: u64, + values_a: u64, + col_ptrs_b: u64, + row_indices_b: u64, + values_b: u64, + out_col_ptrs: u64, + out_row_indices: u64, + out_values: u64, + ncols: usize, +) -> Result<()> { + let kernel_name = match T::NAME { + "f32" => "csc_mul_compute_f32", + "f64" => "csc_mul_compute_f64", + "__half" => "csc_mul_compute_f16", + "__nv_bfloat16" => "csc_mul_compute_bf16", + _ => { + return Err(Error::Internal(format!( + "Unsupported dtype for sparse CSC mul: {}", + T::NAME + ))); + } + }; + + unsafe { + let module = get_or_load_module(context, device_index, kernel_names::SPARSE_MERGE_MODULE)?; + let func = get_kernel_function(&module, kernel_name)?; + + let block_size = BLOCK_SIZE; + let grid_size = (ncols as u32 + block_size - 1) / block_size; + let ncols_i32 = ncols as i32; + + let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), 0); + let mut builder = stream.launch_builder(&func); + builder.arg(&col_ptrs_a); + builder.arg(&row_indices_a); + builder.arg(&values_a); + builder.arg(&col_ptrs_b); + builder.arg(&row_indices_b); + builder.arg(&values_b); + builder.arg(&out_col_ptrs); + builder.arg(&out_row_indices); + builder.arg(&out_values); + builder.arg(&ncols_i32); + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!( + "CUDA CSC mul compute kernel launch failed: {:?}", + e + )) + })?; + + Ok(()) + } +} + +/// Launch CSC div compute kernel +/// +/// # Safety +/// +/// - All pointer arguments must be valid device memory pointers on the device associated +/// with `context`. Output buffers must be pre-allocated to the correct sizes. +/// - `ncols` must match the number of columns in both input CSC matrices. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. +pub(super) unsafe fn launch_csc_div_compute( + context: &Arc, + stream: &CudaStream, + device_index: usize, + col_ptrs_a: u64, + row_indices_a: u64, + values_a: u64, + col_ptrs_b: u64, + row_indices_b: u64, + values_b: u64, + out_col_ptrs: u64, + out_row_indices: u64, + out_values: u64, + ncols: usize, +) -> Result<()> { + let kernel_name = match T::NAME { + "f32" => "csc_div_compute_f32", + "f64" => "csc_div_compute_f64", + "__half" => "csc_div_compute_f16", + "__nv_bfloat16" => "csc_div_compute_bf16", + _ => { + return Err(Error::Internal(format!( + "Unsupported dtype for sparse CSC div: {}", + T::NAME + ))); + } + }; + + unsafe { + let module = get_or_load_module(context, device_index, kernel_names::SPARSE_MERGE_MODULE)?; + let func = get_kernel_function(&module, kernel_name)?; + + let block_size = BLOCK_SIZE; + let grid_size = (ncols as u32 + block_size - 1) / block_size; + let ncols_i32 = ncols as i32; + + let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), 0); + let mut builder = stream.launch_builder(&func); + builder.arg(&col_ptrs_a); + builder.arg(&row_indices_a); + builder.arg(&values_a); + builder.arg(&col_ptrs_b); + builder.arg(&row_indices_b); + builder.arg(&values_b); + builder.arg(&out_col_ptrs); + builder.arg(&out_row_indices); + builder.arg(&out_values); + builder.arg(&ncols_i32); + + builder.launch(cfg).map_err(|e| { + Error::Internal(format!( + "CUDA CSC div compute kernel launch failed: {:?}", + e + )) + })?; + + Ok(()) + } +} + +// ============================================================================ +// High-level CSC Merge Operations +// ============================================================================ + +/// Two-pass CSC addition: C = A + B (union semantics) +/// +/// # Safety +/// +/// All tensor arguments must contain valid CUDA device pointers with correct sizes +/// for the given sparse CSC format. `ncols` must match the sparse matrix dimensions. +pub unsafe fn csc_add_merge( + context: &Arc, + stream: &CudaStream, + device_index: usize, + device: &::Device, + dtype: DType, + col_ptrs_a: &Tensor, + row_indices_a: &Tensor, + values_a: &Tensor, + col_ptrs_b: &Tensor, + row_indices_b: &Tensor, + values_b: &Tensor, + ncols: usize, +) -> Result<( + Tensor, + Tensor, + Tensor, +)> { + use super::super::sparse_strategy::AddMerge; + super::generic::generic_csc_merge::( + context, + stream, + device_index, + device, + dtype, + col_ptrs_a, + row_indices_a, + values_a, + col_ptrs_b, + row_indices_b, + values_b, + ncols, + ) +} + +/// Two-pass CSC subtraction: C = A - B (union semantics) +/// +/// # Safety +/// +/// All tensor arguments must contain valid CUDA device pointers with correct sizes +/// for the given sparse CSC format. `ncols` must match the sparse matrix dimensions. +pub unsafe fn csc_sub_merge( + context: &Arc, + stream: &CudaStream, + device_index: usize, + device: &::Device, + dtype: DType, + col_ptrs_a: &Tensor, + row_indices_a: &Tensor, + values_a: &Tensor, + col_ptrs_b: &Tensor, + row_indices_b: &Tensor, + values_b: &Tensor, + ncols: usize, +) -> Result<( + Tensor, + Tensor, + Tensor, +)> { + use super::super::sparse_strategy::SubMerge; + super::generic::generic_csc_merge::( + context, + stream, + device_index, + device, + dtype, + col_ptrs_a, + row_indices_a, + values_a, + col_ptrs_b, + row_indices_b, + values_b, + ncols, + ) +} + +/// Two-pass CSC element-wise multiplication: C = A .* B (intersection semantics) +/// +/// # Safety +/// +/// All tensor arguments must contain valid CUDA device pointers with correct sizes +/// for the given sparse CSC format. `ncols` must match the sparse matrix dimensions. +pub unsafe fn csc_mul_merge( + context: &Arc, + stream: &CudaStream, + device_index: usize, + device: &::Device, + dtype: DType, + col_ptrs_a: &Tensor, + row_indices_a: &Tensor, + values_a: &Tensor, + col_ptrs_b: &Tensor, + row_indices_b: &Tensor, + values_b: &Tensor, + ncols: usize, +) -> Result<( + Tensor, + Tensor, + Tensor, +)> { + use super::super::sparse_strategy::MulMerge; + super::generic::generic_csc_merge::( + context, + stream, + device_index, + device, + dtype, + col_ptrs_a, + row_indices_a, + values_a, + col_ptrs_b, + row_indices_b, + values_b, + ncols, + ) +} + +/// Two-pass CSC element-wise division: C = A ./ B (intersection semantics) +/// +/// # Safety +/// +/// All tensor arguments must contain valid CUDA device pointers with correct sizes +/// for the given sparse CSC format. `ncols` must match the sparse matrix dimensions. +pub unsafe fn csc_div_merge( + context: &Arc, + stream: &CudaStream, + device_index: usize, + device: &::Device, + dtype: DType, + col_ptrs_a: &Tensor, + row_indices_a: &Tensor, + values_a: &Tensor, + col_ptrs_b: &Tensor, + row_indices_b: &Tensor, + values_b: &Tensor, + ncols: usize, +) -> Result<( + Tensor, + Tensor, + Tensor, +)> { + use super::super::sparse_strategy::DivMerge; + super::generic::generic_csc_merge::( + context, + stream, + device_index, + device, + dtype, + col_ptrs_a, + row_indices_a, + values_a, + col_ptrs_b, + row_indices_b, + values_b, + ncols, + ) +} diff --git a/src/runtime/cuda/kernels/sparse_merge/csr.rs b/src/runtime/cuda/kernels/sparse_merge/csr.rs new file mode 100644 index 00000000..654d789a --- /dev/null +++ b/src/runtime/cuda/kernels/sparse_merge/csr.rs @@ -0,0 +1,439 @@ +//! CSR (Compressed Sparse Row) merge kernel launchers +//! +//! Low-level count and compute launchers plus high-level public merge operations +//! for CSR format sparse matrices. + +#![allow(dead_code)] +#![allow(unsafe_op_in_unsafe_fn)] + +use cudarc::driver::safe::{CudaContext, CudaStream}; +use cudarc::types::CudaTypeName; +use std::sync::Arc; + +use crate::dtype::DType; +use crate::error::Result; +use crate::runtime::Runtime; +use crate::runtime::cuda::CudaRuntime; +use crate::tensor::Tensor; + +use super::helpers::{launch_count_kernel, launch_csr_compute_kernel}; + +// ============================================================================ +// Count Kernels +// ============================================================================ + +/// Launch CSR merge count kernel (for add/sub operations) +/// +/// Counts output size per row using union semantics +/// +/// # Safety +/// +/// - `row_ptrs_a`, `col_indices_a`, `row_ptrs_b`, `col_indices_b`, and `row_counts` must be +/// valid device memory pointers on the device associated with `context`. +/// - `nrows` must match the number of rows in both input CSR matrices. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. +pub(super) unsafe fn launch_csr_merge_count( + context: &Arc, + stream: &CudaStream, + device_index: usize, + row_ptrs_a: u64, + col_indices_a: u64, + row_ptrs_b: u64, + col_indices_b: u64, + row_counts: u64, + nrows: usize, +) -> Result<()> { + launch_count_kernel( + context, + stream, + device_index, + "csr_merge_count", + row_ptrs_a, + col_indices_a, + row_ptrs_b, + col_indices_b, + row_counts, + nrows, + "CUDA sparse merge count", + ) +} + +/// Launch CSR mul count kernel (intersection semantics) +/// +/// # Safety +/// +/// - `row_ptrs_a`, `col_indices_a`, `row_ptrs_b`, `col_indices_b`, and `row_counts` must be +/// valid device memory pointers on the device associated with `context`. +/// - `nrows` must match the number of rows in both input CSR matrices. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. +pub(super) unsafe fn launch_csr_mul_count( + context: &Arc, + stream: &CudaStream, + device_index: usize, + row_ptrs_a: u64, + col_indices_a: u64, + row_ptrs_b: u64, + col_indices_b: u64, + row_counts: u64, + nrows: usize, +) -> Result<()> { + launch_count_kernel( + context, + stream, + device_index, + "csr_mul_count", + row_ptrs_a, + col_indices_a, + row_ptrs_b, + col_indices_b, + row_counts, + nrows, + "CUDA sparse mul count", + ) +} + +// ============================================================================ +// Compute Kernels +// ============================================================================ + +/// Launch CSR add compute kernel +/// +/// # Safety +/// +/// - All pointer arguments must be valid device memory pointers on the device associated +/// with `context`. Output buffers must be pre-allocated to the correct sizes. +/// - `nrows` must match the number of rows in both input CSR matrices. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. +pub(super) unsafe fn launch_csr_add_compute( + context: &Arc, + stream: &CudaStream, + device_index: usize, + row_ptrs_a: u64, + col_indices_a: u64, + values_a: u64, + row_ptrs_b: u64, + col_indices_b: u64, + values_b: u64, + out_row_ptrs: u64, + out_col_indices: u64, + out_values: u64, + nrows: usize, +) -> Result<()> { + launch_csr_compute_kernel::( + context, + stream, + device_index, + "csr_add_compute", + row_ptrs_a, + col_indices_a, + values_a, + row_ptrs_b, + col_indices_b, + values_b, + out_row_ptrs, + out_col_indices, + out_values, + nrows, + "CUDA sparse add compute", + ) +} + +/// Launch CSR sub compute kernel +/// +/// # Safety +/// +/// - All pointer arguments must be valid device memory pointers on the device associated +/// with `context`. Output buffers must be pre-allocated to the correct sizes. +/// - `nrows` must match the number of rows in both input CSR matrices. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. +pub(super) unsafe fn launch_csr_sub_compute( + context: &Arc, + stream: &CudaStream, + device_index: usize, + row_ptrs_a: u64, + col_indices_a: u64, + values_a: u64, + row_ptrs_b: u64, + col_indices_b: u64, + values_b: u64, + out_row_ptrs: u64, + out_col_indices: u64, + out_values: u64, + nrows: usize, +) -> Result<()> { + launch_csr_compute_kernel::( + context, + stream, + device_index, + "csr_sub_compute", + row_ptrs_a, + col_indices_a, + values_a, + row_ptrs_b, + col_indices_b, + values_b, + out_row_ptrs, + out_col_indices, + out_values, + nrows, + "CUDA sparse sub compute", + ) +} + +/// Launch CSR mul compute kernel +/// +/// # Safety +/// +/// - All pointer arguments must be valid device memory pointers on the device associated +/// with `context`. Output buffers must be pre-allocated to the correct sizes. +/// - `nrows` must match the number of rows in both input CSR matrices. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. +pub(super) unsafe fn launch_csr_mul_compute( + context: &Arc, + stream: &CudaStream, + device_index: usize, + row_ptrs_a: u64, + col_indices_a: u64, + values_a: u64, + row_ptrs_b: u64, + col_indices_b: u64, + values_b: u64, + out_row_ptrs: u64, + out_col_indices: u64, + out_values: u64, + nrows: usize, +) -> Result<()> { + launch_csr_compute_kernel::( + context, + stream, + device_index, + "csr_mul_compute", + row_ptrs_a, + col_indices_a, + values_a, + row_ptrs_b, + col_indices_b, + values_b, + out_row_ptrs, + out_col_indices, + out_values, + nrows, + "CUDA sparse mul compute", + ) +} + +/// Launch CSR div compute kernel +/// +/// # Safety +/// +/// - All pointer arguments must be valid device memory pointers on the device associated +/// with `context`. Output buffers must be pre-allocated to the correct sizes. +/// - `nrows` must match the number of rows in both input CSR matrices. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. +pub(super) unsafe fn launch_csr_div_compute( + context: &Arc, + stream: &CudaStream, + device_index: usize, + row_ptrs_a: u64, + col_indices_a: u64, + values_a: u64, + row_ptrs_b: u64, + col_indices_b: u64, + values_b: u64, + out_row_ptrs: u64, + out_col_indices: u64, + out_values: u64, + nrows: usize, +) -> Result<()> { + launch_csr_compute_kernel::( + context, + stream, + device_index, + "csr_div_compute", + row_ptrs_a, + col_indices_a, + values_a, + row_ptrs_b, + col_indices_b, + values_b, + out_row_ptrs, + out_col_indices, + out_values, + nrows, + "CUDA sparse div compute", + ) +} + +// ============================================================================ +// High-level CSR Merge Operations +// ============================================================================ + +/// Two-pass CSR addition: C = A + B (union semantics) +/// +/// Now uses generic_csr_merge with AddMerge strategy to eliminate duplication. +/// +/// # Safety +/// +/// All tensor arguments must contain valid CUDA device pointers with correct sizes +/// for the given sparse CSR format. `nrows` must match the sparse matrix dimensions. +pub unsafe fn csr_add_merge( + context: &Arc, + stream: &CudaStream, + device_index: usize, + device: &::Device, + dtype: DType, + row_ptrs_a: &Tensor, + col_indices_a: &Tensor, + values_a: &Tensor, + row_ptrs_b: &Tensor, + col_indices_b: &Tensor, + values_b: &Tensor, + nrows: usize, +) -> Result<( + Tensor, + Tensor, + Tensor, +)> { + use super::super::sparse_strategy::AddMerge; + super::generic::generic_csr_merge::( + context, + stream, + device_index, + device, + dtype, + row_ptrs_a, + col_indices_a, + values_a, + row_ptrs_b, + col_indices_b, + values_b, + nrows, + ) +} + +/// Two-pass CSR subtraction: C = A - B (union semantics) +/// +/// Now uses generic_csr_merge with SubMerge strategy to eliminate duplication. +/// +/// # Safety +/// +/// All tensor arguments must contain valid CUDA device pointers with correct sizes +/// for the given sparse CSR format. `nrows` must match the sparse matrix dimensions. +pub unsafe fn csr_sub_merge( + context: &Arc, + stream: &CudaStream, + device_index: usize, + device: &::Device, + dtype: DType, + row_ptrs_a: &Tensor, + col_indices_a: &Tensor, + values_a: &Tensor, + row_ptrs_b: &Tensor, + col_indices_b: &Tensor, + values_b: &Tensor, + nrows: usize, +) -> Result<( + Tensor, + Tensor, + Tensor, +)> { + use super::super::sparse_strategy::SubMerge; + super::generic::generic_csr_merge::( + context, + stream, + device_index, + device, + dtype, + row_ptrs_a, + col_indices_a, + values_a, + row_ptrs_b, + col_indices_b, + values_b, + nrows, + ) +} + +/// Two-pass CSR element-wise multiplication: C = A .* B (intersection semantics) +/// +/// Now uses generic_csr_merge with MulMerge strategy to eliminate duplication. +/// +/// # Safety +/// +/// All tensor arguments must contain valid CUDA device pointers with correct sizes +/// for the given sparse CSR format. `nrows` must match the sparse matrix dimensions. +pub unsafe fn csr_mul_merge( + context: &Arc, + stream: &CudaStream, + device_index: usize, + device: &::Device, + dtype: DType, + row_ptrs_a: &Tensor, + col_indices_a: &Tensor, + values_a: &Tensor, + row_ptrs_b: &Tensor, + col_indices_b: &Tensor, + values_b: &Tensor, + nrows: usize, +) -> Result<( + Tensor, + Tensor, + Tensor, +)> { + use super::super::sparse_strategy::MulMerge; + super::generic::generic_csr_merge::( + context, + stream, + device_index, + device, + dtype, + row_ptrs_a, + col_indices_a, + values_a, + row_ptrs_b, + col_indices_b, + values_b, + nrows, + ) +} + +/// Two-pass CSR element-wise division: C = A ./ B (intersection semantics) +/// +/// # Safety +/// +/// All tensor arguments must contain valid CUDA device pointers with correct sizes +/// for the given sparse CSR format. `nrows` must match the sparse matrix dimensions. +pub unsafe fn csr_div_merge( + context: &Arc, + stream: &CudaStream, + device_index: usize, + device: &::Device, + dtype: DType, + row_ptrs_a: &Tensor, + col_indices_a: &Tensor, + values_a: &Tensor, + row_ptrs_b: &Tensor, + col_indices_b: &Tensor, + values_b: &Tensor, + nrows: usize, +) -> Result<( + Tensor, + Tensor, + Tensor, +)> { + use super::super::sparse_strategy::DivMerge; + super::generic::generic_csr_merge::( + context, + stream, + device_index, + device, + dtype, + row_ptrs_a, + col_indices_a, + values_a, + row_ptrs_b, + col_indices_b, + values_b, + nrows, + ) +} diff --git a/src/runtime/cuda/kernels/sparse_merge/generic.rs b/src/runtime/cuda/kernels/sparse_merge/generic.rs new file mode 100644 index 00000000..db022a2e --- /dev/null +++ b/src/runtime/cuda/kernels/sparse_merge/generic.rs @@ -0,0 +1,318 @@ +//! Generic two-pass merge implementations for sparse matrices +//! +//! Zero-duplication generic merge using the strategy pattern. +//! Both CSR and CSC formats are handled here. + +#![allow(dead_code)] +#![allow(unsafe_op_in_unsafe_fn)] + +use cudarc::driver::PushKernelArg; +use cudarc::driver::safe::{CudaContext, CudaStream}; +use cudarc::types::CudaTypeName; +use std::sync::Arc; + +use crate::dtype::DType; +use crate::error::{Error, Result}; +use crate::runtime::Runtime; +use crate::runtime::cuda::CudaRuntime; +use crate::tensor::Tensor; + +use super::super::loader::{ + BLOCK_SIZE, get_kernel_function, get_or_load_module, kernel_names, launch_config, +}; +use super::super::sparse_strategy::{MergeStrategy, SparseFormat}; +use super::helpers::exclusive_scan_i32; + +/// Generic two-pass CSR merge using strategy pattern +/// +/// Eliminates code duplication across add/sub/mul/div operations by abstracting +/// the merge semantics through the MergeStrategy trait. +/// +/// # Type Parameters +/// +/// * `T` - Element type (f32, f64, etc.) +/// * `S` - Merge strategy (AddMerge, SubMerge, MulMerge, DivMerge) +/// +/// # Algorithm +/// +/// 1. **Count**: Determine output size per row using strategy-specific semantics +/// 2. **Scan**: Compute row_ptrs via exclusive prefix sum +/// 3. **Compute**: Merge values using strategy-specific operation +/// +/// # Safety +/// +/// All tensor arguments must contain valid CUDA device pointers with correct sizes +/// for the given sparse CSR format. `nrows` must match the sparse matrix dimensions. +/// The CUDA stream and context must be valid and associated with the correct device. +pub unsafe fn generic_csr_merge( + context: &Arc, + stream: &CudaStream, + device_index: usize, + device: &::Device, + dtype: DType, + row_ptrs_a: &Tensor, + col_indices_a: &Tensor, + values_a: &Tensor, + row_ptrs_b: &Tensor, + col_indices_b: &Tensor, + values_b: &Tensor, + nrows: usize, +) -> Result<( + Tensor, + Tensor, + Tensor, +)> { + // Pass 1: Count output size per row + let row_counts = Tensor::::zeros(&[nrows], DType::I32, device); + + // Launch count kernel (union vs intersection semantics determined by strategy) + let count_kernel_name = S::count_kernel_name(SparseFormat::Csr); + let module = get_or_load_module(context, device_index, kernel_names::SPARSE_MERGE_MODULE)?; + let function = get_kernel_function(&module, count_kernel_name)?; + + let block_size = BLOCK_SIZE; + let grid_size = (nrows as u32 + block_size - 1) / block_size; + let nrows_i32 = nrows as i32; + + let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), 0); + let mut builder = stream.launch_builder(&function); + + // Store pointers to avoid temporary value issues + let row_ptrs_a_ptr = row_ptrs_a.ptr(); + let col_indices_a_ptr = col_indices_a.ptr(); + let row_ptrs_b_ptr = row_ptrs_b.ptr(); + let col_indices_b_ptr = col_indices_b.ptr(); + let row_counts_ptr = row_counts.ptr(); + + builder.arg(&row_ptrs_a_ptr); + builder.arg(&col_indices_a_ptr); + builder.arg(&row_ptrs_b_ptr); + builder.arg(&col_indices_b_ptr); + builder.arg(&row_counts_ptr); + builder.arg(&nrows_i32); + + // SAFETY: Kernel launch is unsafe because: + // 1. Raw pointers are passed to CUDA kernel + // 2. Kernel accesses GPU memory + // Safety requirements satisfied: + // - All pointers are valid GPU memory addresses from CudaRuntime tensors + // - Tensor lifetimes ensure memory is valid during kernel execution + // - nrows matches the actual tensor dimensions + // - Stream synchronization ensures no data races + unsafe { + builder.launch(cfg).map_err(|e| { + Error::Internal(format!( + "CUDA {} kernel launch failed (nrows={}, strategy={:?}): {:?}", + count_kernel_name, + nrows, + S::OP, + e + )) + })?; + } + + // Synchronize to ensure counts are ready + stream + .synchronize() + .map_err(|e| Error::Internal(format!("Stream synchronize failed: {:?}", e)))?; + + // Exclusive scan to get row_ptrs and total_nnz + let (out_row_ptrs, total_nnz) = exclusive_scan_i32(context, stream, device_index, &row_counts)?; + + // Pass 2: Allocate output and compute merged result + let out_col_indices = Tensor::::zeros(&[total_nnz], DType::I32, device); + let out_values = Tensor::::zeros(&[total_nnz], dtype, device); + + // Launch compute kernel (operation-specific) + let compute_kernel_name = S::compute_kernel_name(SparseFormat::Csr, T::NAME); + let function = get_kernel_function(&module, &compute_kernel_name)?; + + let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), 0); + let mut builder = stream.launch_builder(&function); + + // Store pointers to avoid temporary value issues + let row_ptrs_a_ptr = row_ptrs_a.ptr(); + let col_indices_a_ptr = col_indices_a.ptr(); + let values_a_ptr = values_a.ptr(); + let row_ptrs_b_ptr = row_ptrs_b.ptr(); + let col_indices_b_ptr = col_indices_b.ptr(); + let values_b_ptr = values_b.ptr(); + let out_row_ptrs_ptr = out_row_ptrs.ptr(); + let out_col_indices_ptr = out_col_indices.ptr(); + let out_values_ptr = out_values.ptr(); + + builder.arg(&row_ptrs_a_ptr); + builder.arg(&col_indices_a_ptr); + builder.arg(&values_a_ptr); + builder.arg(&row_ptrs_b_ptr); + builder.arg(&col_indices_b_ptr); + builder.arg(&values_b_ptr); + builder.arg(&out_row_ptrs_ptr); + builder.arg(&out_col_indices_ptr); + builder.arg(&out_values_ptr); + builder.arg(&nrows_i32); + + // SAFETY: Kernel launch is unsafe because: + // 1. Raw pointers are passed to CUDA kernel + // 2. Kernel writes to output tensors + // Safety requirements satisfied: + // - All input pointers are valid GPU memory from input tensors + // - Output tensors allocated with correct size (total_nnz from scan) + // - Tensor ownership prevents concurrent modification + // - Stream ordering ensures count kernel completed before compute kernel + unsafe { + builder.launch(cfg).map_err(|e| { + Error::Internal(format!( + "CUDA {} kernel launch failed (nrows={}, total_nnz={}, strategy={:?}): {:?}", + compute_kernel_name, + nrows, + total_nnz, + S::OP, + e + )) + })?; + } + + Ok((out_row_ptrs, out_col_indices, out_values)) +} + +/// Generic two-pass CSC merge using strategy pattern +/// +/// CSC variant of generic_csr_merge. See generic_csr_merge for details. +/// +/// # Safety +/// +/// All tensor arguments must contain valid CUDA device pointers with correct sizes +/// for the given sparse CSC format. `ncols` must match the sparse matrix dimensions. +/// The CUDA stream and context must be valid and associated with the correct device. +pub unsafe fn generic_csc_merge( + context: &Arc, + stream: &CudaStream, + device_index: usize, + device: &::Device, + dtype: DType, + col_ptrs_a: &Tensor, + row_indices_a: &Tensor, + values_a: &Tensor, + col_ptrs_b: &Tensor, + row_indices_b: &Tensor, + values_b: &Tensor, + ncols: usize, +) -> Result<( + Tensor, + Tensor, + Tensor, +)> { + // Pass 1: Count output size per column + let col_counts = Tensor::::zeros(&[ncols], DType::I32, device); + + // Launch count kernel + let count_kernel_name = S::count_kernel_name(SparseFormat::Csc); + let module = get_or_load_module(context, device_index, kernel_names::SPARSE_MERGE_MODULE)?; + let function = get_kernel_function(&module, count_kernel_name)?; + + let block_size = BLOCK_SIZE; + let grid_size = (ncols as u32 + block_size - 1) / block_size; + let ncols_i32 = ncols as i32; + + let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), 0); + let mut builder = stream.launch_builder(&function); + + // Store pointers to avoid temporary value issues + let col_ptrs_a_ptr = col_ptrs_a.ptr(); + let row_indices_a_ptr = row_indices_a.ptr(); + let col_ptrs_b_ptr = col_ptrs_b.ptr(); + let row_indices_b_ptr = row_indices_b.ptr(); + let col_counts_ptr = col_counts.ptr(); + + builder.arg(&col_ptrs_a_ptr); + builder.arg(&row_indices_a_ptr); + builder.arg(&col_ptrs_b_ptr); + builder.arg(&row_indices_b_ptr); + builder.arg(&col_counts_ptr); + builder.arg(&ncols_i32); + + // SAFETY: Kernel launch is unsafe because: + // 1. Raw pointers are passed to CUDA kernel + // 2. Kernel accesses GPU memory + // Safety requirements satisfied: + // - All pointers are valid GPU memory addresses from CudaRuntime tensors + // - Tensor lifetimes ensure memory is valid during kernel execution + // - ncols matches the actual tensor dimensions + // - Stream synchronization ensures no data races + unsafe { + builder.launch(cfg).map_err(|e| { + Error::Internal(format!( + "CUDA {} kernel launch failed (ncols={}, strategy={:?}): {:?}", + count_kernel_name, + ncols, + S::OP, + e + )) + })?; + } + + // Synchronize to ensure counts are ready + stream + .synchronize() + .map_err(|e| Error::Internal(format!("Stream synchronize failed: {:?}", e)))?; + + // Exclusive scan to get col_ptrs and total_nnz + let (out_col_ptrs, total_nnz) = exclusive_scan_i32(context, stream, device_index, &col_counts)?; + + // Pass 2: Allocate output and compute merged result + let out_row_indices = Tensor::::zeros(&[total_nnz], DType::I32, device); + let out_values = Tensor::::zeros(&[total_nnz], dtype, device); + + // Launch compute kernel + let compute_kernel_name = S::compute_kernel_name(SparseFormat::Csc, T::NAME); + let function = get_kernel_function(&module, &compute_kernel_name)?; + + let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), 0); + let mut builder = stream.launch_builder(&function); + + // Store pointers to avoid temporary value issues + let col_ptrs_a_ptr = col_ptrs_a.ptr(); + let row_indices_a_ptr = row_indices_a.ptr(); + let values_a_ptr = values_a.ptr(); + let col_ptrs_b_ptr = col_ptrs_b.ptr(); + let row_indices_b_ptr = row_indices_b.ptr(); + let values_b_ptr = values_b.ptr(); + let out_col_ptrs_ptr = out_col_ptrs.ptr(); + let out_row_indices_ptr = out_row_indices.ptr(); + let out_values_ptr = out_values.ptr(); + + builder.arg(&col_ptrs_a_ptr); + builder.arg(&row_indices_a_ptr); + builder.arg(&values_a_ptr); + builder.arg(&col_ptrs_b_ptr); + builder.arg(&row_indices_b_ptr); + builder.arg(&values_b_ptr); + builder.arg(&out_col_ptrs_ptr); + builder.arg(&out_row_indices_ptr); + builder.arg(&out_values_ptr); + builder.arg(&ncols_i32); + + // SAFETY: Kernel launch is unsafe because: + // 1. Raw pointers are passed to CUDA kernel + // 2. Kernel writes to output tensors + // Safety requirements satisfied: + // - All input pointers are valid GPU memory from input tensors + // - Output tensors allocated with correct size (total_nnz from scan) + // - Tensor ownership prevents concurrent modification + // - Stream ordering ensures count kernel completed before compute kernel + unsafe { + builder.launch(cfg).map_err(|e| { + Error::Internal(format!( + "CUDA {} kernel launch failed (ncols={}, total_nnz={}, strategy={:?}): {:?}", + compute_kernel_name, + ncols, + total_nnz, + S::OP, + e + )) + })?; + } + + Ok((out_col_ptrs, out_row_indices, out_values)) +} diff --git a/src/runtime/cuda/kernels/sparse_merge/helpers.rs b/src/runtime/cuda/kernels/sparse_merge/helpers.rs new file mode 100644 index 00000000..64018e8d --- /dev/null +++ b/src/runtime/cuda/kernels/sparse_merge/helpers.rs @@ -0,0 +1,233 @@ +//! Helper utilities for sparse merge kernel launchers +//! +//! Shared infrastructure used by CSR and CSC merge operations: +//! - dtype suffix resolution +//! - generic count kernel launcher +//! - generic CSR/CSC compute kernel launchers +//! - exclusive scan wrapper + +#![allow(dead_code)] +#![allow(unsafe_op_in_unsafe_fn)] + +use cudarc::driver::PushKernelArg; +use cudarc::driver::safe::{CudaContext, CudaStream}; +use cudarc::types::CudaTypeName; +use std::sync::Arc; + +use crate::error::{Error, Result}; +use crate::runtime::cuda::CudaRuntime; +use crate::tensor::Tensor; + +use super::super::loader::{ + BLOCK_SIZE, get_kernel_function, get_or_load_module, kernel_names, launch_config, +}; + +// ============================================================================ +// dtype suffix helper +// ============================================================================ + +/// Get dtype-specific kernel name suffix +pub(super) fn dtype_suffix() -> Result<&'static str> { + match T::NAME { + "f32" => Ok("f32"), + "f64" => Ok("f64"), + "__half" => Ok("f16"), + "__nv_bfloat16" => Ok("bf16"), + _ => Err(Error::Internal(format!( + "Unsupported dtype for sparse operation: {}", + T::NAME + ))), + } +} + +// ============================================================================ +// Generic Kernel Launcher Helpers (DRY principle) +// ============================================================================ + +/// Generic launcher for kernels without dtype template (count kernels) +/// +/// Eliminates duplication across count kernel launchers +/// +/// # Safety +/// +/// - `row_ptrs_a`, `col_indices_a`, `row_ptrs_b`, `col_indices_b`, and `row_counts` must be +/// valid device memory pointers on the device associated with `context`. +/// - `nrows` must match the number of rows in both sparse matrices. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. +pub(super) unsafe fn launch_count_kernel( + context: &Arc, + stream: &CudaStream, + device_index: usize, + kernel_name: &str, + row_ptrs_a: u64, + col_indices_a: u64, + row_ptrs_b: u64, + col_indices_b: u64, + row_counts: u64, + nrows: usize, + error_context: &str, +) -> Result<()> { + let module = get_or_load_module(context, device_index, kernel_names::SPARSE_MERGE_MODULE)?; + let func = get_kernel_function(&module, kernel_name)?; + + let block_size = BLOCK_SIZE; + let grid_size = (nrows as u32 + block_size - 1) / block_size; + let nrows_i32 = nrows as i32; + + let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), 0); + let mut builder = stream.launch_builder(&func); + builder.arg(&row_ptrs_a); + builder.arg(&col_indices_a); + builder.arg(&row_ptrs_b); + builder.arg(&col_indices_b); + builder.arg(&row_counts); + builder.arg(&nrows_i32); + + builder + .launch(cfg) + .map_err(|e| Error::Internal(format!("{} kernel launch failed: {:?}", error_context, e)))?; + + Ok(()) +} + +/// Generic launcher for dtype-templated compute kernels (CSR format) +/// +/// Eliminates duplication across CSR add/sub/mul/div compute launchers +/// +/// # Safety +/// +/// - All pointer arguments (`row_ptrs_a`, `col_indices_a`, `values_a`, `row_ptrs_b`, +/// `col_indices_b`, `values_b`, `out_row_ptrs`, `out_col_indices`, `out_values`) must be +/// valid device memory pointers on the device associated with `context`. +/// - Output buffers must be pre-allocated to the correct sizes (determined by a prior count pass). +/// - `nrows` must match the number of rows in both input matrices. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. +pub(super) unsafe fn launch_csr_compute_kernel( + context: &Arc, + stream: &CudaStream, + device_index: usize, + kernel_base_name: &str, + row_ptrs_a: u64, + col_indices_a: u64, + values_a: u64, + row_ptrs_b: u64, + col_indices_b: u64, + values_b: u64, + out_row_ptrs: u64, + out_col_indices: u64, + out_values: u64, + nrows: usize, + error_context: &str, +) -> Result<()> { + let suffix = dtype_suffix::()?; + let kernel_name = format!("{}_{}", kernel_base_name, suffix); + + let module = get_or_load_module(context, device_index, kernel_names::SPARSE_MERGE_MODULE)?; + let func = get_kernel_function(&module, &kernel_name)?; + + let block_size = BLOCK_SIZE; + let grid_size = (nrows as u32 + block_size - 1) / block_size; + let nrows_i32 = nrows as i32; + + let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), 0); + let mut builder = stream.launch_builder(&func); + builder.arg(&row_ptrs_a); + builder.arg(&col_indices_a); + builder.arg(&values_a); + builder.arg(&row_ptrs_b); + builder.arg(&col_indices_b); + builder.arg(&values_b); + builder.arg(&out_row_ptrs); + builder.arg(&out_col_indices); + builder.arg(&out_values); + builder.arg(&nrows_i32); + + builder + .launch(cfg) + .map_err(|e| Error::Internal(format!("{} kernel launch failed: {:?}", error_context, e)))?; + + Ok(()) +} + +/// Generic launcher for dtype-templated compute kernels (CSC format) +/// +/// Eliminates duplication across CSC add/sub/mul/div compute launchers +/// +/// # Safety +/// +/// - All pointer arguments (`col_ptrs_a`, `row_indices_a`, `values_a`, `col_ptrs_b`, +/// `row_indices_b`, `values_b`, `out_col_ptrs`, `out_row_indices`, `out_values`) must be +/// valid device memory pointers on the device associated with `context`. +/// - Output buffers must be pre-allocated to the correct sizes (determined by a prior count pass). +/// - `ncols` must match the number of columns in both input matrices. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. +pub(super) unsafe fn launch_csc_compute_kernel( + context: &Arc, + stream: &CudaStream, + device_index: usize, + kernel_base_name: &str, + col_ptrs_a: u64, + row_indices_a: u64, + values_a: u64, + col_ptrs_b: u64, + row_indices_b: u64, + values_b: u64, + out_col_ptrs: u64, + out_row_indices: u64, + out_values: u64, + ncols: usize, + error_context: &str, +) -> Result<()> { + let suffix = dtype_suffix::()?; + let kernel_name = format!("{}_{}", kernel_base_name, suffix); + + let module = get_or_load_module(context, device_index, kernel_names::SPARSE_MERGE_MODULE)?; + let func = get_kernel_function(&module, &kernel_name)?; + + let block_size = BLOCK_SIZE; + let grid_size = (ncols as u32 + block_size - 1) / block_size; + let ncols_i32 = ncols as i32; + + let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), 0); + let mut builder = stream.launch_builder(&func); + builder.arg(&col_ptrs_a); + builder.arg(&row_indices_a); + builder.arg(&values_a); + builder.arg(&col_ptrs_b); + builder.arg(&row_indices_b); + builder.arg(&values_b); + builder.arg(&out_col_ptrs); + builder.arg(&out_row_indices); + builder.arg(&out_values); + builder.arg(&ncols_i32); + + builder + .launch(cfg) + .map_err(|e| Error::Internal(format!("{} kernel launch failed: {:?}", error_context, e)))?; + + Ok(()) +} + +// ============================================================================ +// Exclusive Scan (Prefix Sum) +// ============================================================================ + +/// Compute exclusive scan (prefix sum) on GPU tensor +/// +/// Input: [3, 1, 4, 2] +/// Output: [0, 3, 4, 8, 10] (n+1 elements, last is total sum) +/// +/// Uses GPU-native parallel scan (no CPU transfer) +pub(super) fn exclusive_scan_i32( + context: &Arc, + stream: &CudaStream, + device_index: usize, + input: &Tensor, +) -> Result<(Tensor, usize)> { + let device = input.device(); + + // Use GPU scan (imported from scan module) + unsafe { + super::super::scan::exclusive_scan_i32_gpu(context, stream, device_index, device, input) + } +} diff --git a/src/runtime/cuda/kernels/sparse_merge/mod.rs b/src/runtime/cuda/kernels/sparse_merge/mod.rs new file mode 100644 index 00000000..fa62e2d1 --- /dev/null +++ b/src/runtime/cuda/kernels/sparse_merge/mod.rs @@ -0,0 +1,15 @@ +//! Sparse matrix element-wise merge kernel launchers +//! +//! Two-pass algorithm for CSR element-wise operations: +//! 1. Count output size per row +//! 2. Exclusive scan to get row_ptrs +//! 3. Compute merged output + +mod csc; +mod csr; +mod generic; +mod helpers; + +pub use csc::*; +pub use csr::*; +pub use generic::*; diff --git a/src/runtime/cuda/kernels/sparse_strategy.rs b/src/runtime/cuda/kernels/sparse_strategy.rs index 8974f27e..1c00ff1f 100644 --- a/src/runtime/cuda/kernels/sparse_strategy.rs +++ b/src/runtime/cuda/kernels/sparse_strategy.rs @@ -26,9 +26,13 @@ /// Sparse element-wise operations #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum SparseMergeOp { + /// Element-wise addition: C[i,j] = A[i,j] + B[i,j] (union semantics) Add, + /// Element-wise subtraction: C[i,j] = A[i,j] - B[i,j] (union semantics) Sub, + /// Element-wise multiplication: C[i,j] = A[i,j] * B[i,j] (intersection semantics) Mul, + /// Element-wise division: C[i,j] = A[i,j] / B[i,j] (intersection semantics) Div, } diff --git a/src/runtime/cuda/kernels/sparse_utils.rs b/src/runtime/cuda/kernels/sparse_utils.rs index d2cfe684..f6c0f49d 100644 --- a/src/runtime/cuda/kernels/sparse_utils.rs +++ b/src/runtime/cuda/kernels/sparse_utils.rs @@ -25,6 +25,7 @@ use crate::tensor::Tensor; // Module name // ============================================================================ +/// CUDA module name for sparse utility kernels (filtering, sums, NNZ counting, conversions). pub const SPARSE_UTILS_MODULE: &str = "sparse_utils"; // ============================================================================ @@ -49,6 +50,12 @@ fn dtype_suffix() -> Result<&'static str> { // ============================================================================ /// Cast I32 tensor to I64 (for row_ptrs after scan) +/// +/// # Safety +/// +/// - `input` must be a valid `CudaRuntime` tensor with `DType::I32` residing on the device +/// associated with `context`. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. unsafe fn cast_i32_to_i64_gpu( context: &Arc, stream: &CudaStream, @@ -62,8 +69,8 @@ unsafe fn cast_i32_to_i64_gpu( // Use cast kernel from cast.rs use super::cast::launch_cast; - let input_ptr = input.storage().ptr(); - let output_ptr = output.storage().ptr(); + let input_ptr = input.ptr(); + let output_ptr = output.ptr(); unsafe { launch_cast( @@ -86,6 +93,14 @@ unsafe fn cast_i32_to_i64_gpu( // ============================================================================ /// Pass 1: Count values above threshold per row +/// +/// # Safety +/// +/// - `row_ptrs`, `values`, and `row_counts` must be valid device memory pointers on the device +/// associated with `context`. +/// - `row_ptrs` must have at least `nrows + 1` elements; `row_counts` must have at least `nrows`. +/// - `values` must have at least as many elements as indicated by `row_ptrs[nrows]`. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. unsafe fn launch_filter_csr_count( context: &Arc, stream: &CudaStream, @@ -125,6 +140,14 @@ unsafe fn launch_filter_csr_count( context: &Arc, stream: &CudaStream, @@ -169,7 +192,17 @@ unsafe fn launch_filter_csr_compute( context: &Arc, stream: &CudaStream, @@ -201,9 +234,9 @@ pub unsafe fn filter_csr_values_gpu( context: &Arc, stream: &CudaStream, @@ -274,9 +314,9 @@ pub unsafe fn csr_sum_rows_gpu( let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), 0); - let row_ptrs_ptr = row_ptrs.storage().ptr(); - let values_ptr = values.storage().ptr(); - let out_ptr = out.storage().ptr(); + let row_ptrs_ptr = row_ptrs.ptr(); + let values_ptr = values.ptr(); + let out_ptr = out.ptr(); let mut builder = stream.launch_builder(&func); builder.arg(&row_ptrs_ptr); @@ -291,7 +331,14 @@ pub unsafe fn csr_sum_rows_gpu( Ok(out) } -/// CSC column-wise sum (GPU kernel) +/// Compute column-wise sum of a CSC sparse matrix (GPU kernel) +/// +/// # Safety +/// +/// - `col_ptrs` and `values` must be valid `CudaRuntime` tensors on the device associated with +/// `context`, with a consistent CSC structure where `col_ptrs` has `ncols + 1` elements. +/// - `ncols` must match the actual number of columns. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn csc_sum_cols_gpu( context: &Arc, stream: &CudaStream, @@ -316,9 +363,9 @@ pub unsafe fn csc_sum_cols_gpu( let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), 0); - let col_ptrs_ptr = col_ptrs.storage().ptr(); - let values_ptr = values.storage().ptr(); - let out_ptr = out.storage().ptr(); + let col_ptrs_ptr = col_ptrs.ptr(); + let values_ptr = values.ptr(); + let out_ptr = out.ptr(); let mut builder = stream.launch_builder(&func); builder.arg(&col_ptrs_ptr); @@ -337,7 +384,14 @@ pub unsafe fn csc_sum_cols_gpu( // NNZ Counting // ============================================================================ -/// Count non-zeros per row (pointer difference) +/// Count non-zeros per row of a CSR matrix using pointer differences (GPU kernel) +/// +/// # Safety +/// +/// - `row_ptrs` must be a valid `CudaRuntime` tensor on the device associated with `context`, +/// with at least `nrows + 1` elements of type I64. +/// - `nrows` must match the actual number of rows. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn csr_nnz_per_row_gpu( context: &Arc, stream: &CudaStream, @@ -357,8 +411,8 @@ pub unsafe fn csr_nnz_per_row_gpu( let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), 0); - let row_ptrs_ptr = row_ptrs.storage().ptr(); - let out_ptr = out.storage().ptr(); + let row_ptrs_ptr = row_ptrs.ptr(); + let out_ptr = out.ptr(); let mut builder = stream.launch_builder(&func); builder.arg(&row_ptrs_ptr); @@ -375,7 +429,14 @@ pub unsafe fn csr_nnz_per_row_gpu( Ok(out) } -/// Count non-zeros per column (pointer difference) +/// Count non-zeros per column of a CSC matrix using pointer differences (GPU kernel) +/// +/// # Safety +/// +/// - `col_ptrs` must be a valid `CudaRuntime` tensor on the device associated with `context`, +/// with at least `ncols + 1` elements of type I64. +/// - `ncols` must match the actual number of columns. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn csc_nnz_per_col_gpu( context: &Arc, stream: &CudaStream, @@ -395,8 +456,8 @@ pub unsafe fn csc_nnz_per_col_gpu( let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), 0); - let col_ptrs_ptr = col_ptrs.storage().ptr(); - let out_ptr = out.storage().ptr(); + let col_ptrs_ptr = col_ptrs.ptr(); + let out_ptr = out.ptr(); let mut builder = stream.launch_builder(&func); builder.arg(&col_ptrs_ptr); @@ -417,7 +478,15 @@ pub unsafe fn csc_nnz_per_col_gpu( // Sparse to Dense Conversion // ============================================================================ -/// Expand CSR to dense matrix (GPU kernel) +/// Expand CSR sparse matrix to a dense matrix (GPU kernel) +/// +/// # Safety +/// +/// - `row_ptrs`, `col_indices`, and `values` must be valid `CudaRuntime` tensors on the device +/// associated with `context` with a consistent CSR structure. +/// - `shape` must match the actual matrix dimensions: `row_ptrs` has `shape[0] + 1` elements, +/// `col_indices` and `values` have `nnz` elements, all column indices are in `[0, shape[1])`. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn csr_to_dense_gpu( context: &Arc, stream: &CudaStream, @@ -446,10 +515,10 @@ pub unsafe fn csr_to_dense_gpu( let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), 0); - let row_ptrs_ptr = row_ptrs.storage().ptr(); - let col_indices_ptr = col_indices.storage().ptr(); - let values_ptr = values.storage().ptr(); - let out_ptr = out.storage().ptr(); + let row_ptrs_ptr = row_ptrs.ptr(); + let col_indices_ptr = col_indices.ptr(); + let values_ptr = values.ptr(); + let out_ptr = out.ptr(); let mut builder = stream.launch_builder(&func); builder.arg(&row_ptrs_ptr); @@ -470,7 +539,14 @@ pub unsafe fn csr_to_dense_gpu( // Dense to COO Conversion (two-pass) // ============================================================================ -/// Pass 1: Count non-zeros per row +/// Pass 1: Count non-zeros per row for dense-to-COO conversion +/// +/// # Safety +/// +/// - `input` must be a valid device memory pointer for a 2D row-major array of at least +/// `nrows * ncols` elements of type `T`. +/// - `row_counts` must be a valid device memory pointer with at least `nrows` i32 elements. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. unsafe fn launch_dense_to_coo_count( context: &Arc, stream: &CudaStream, @@ -510,7 +586,15 @@ unsafe fn launch_dense_to_coo_count( context: &Arc, stream: &CudaStream, @@ -556,7 +640,16 @@ unsafe fn launch_dense_to_coo_extract( context: &Arc, stream: &CudaStream, @@ -572,9 +665,9 @@ pub unsafe fn dense_to_coo_gpu { let shape = input.shape(); if shape.len() != 2 { - return Err(Error::ShapeMismatch { - expected: vec![0, 0], // placeholder - got: shape.to_vec(), + return Err(Error::InvalidArgument { + arg: "input", + reason: format!("dense_to_coo requires a 2D tensor, got {}D", shape.len()), }); } @@ -589,8 +682,8 @@ pub unsafe fn dense_to_coo_gpu( context: &Arc, stream: &CudaStream, diff --git a/src/runtime/cuda/kernels/spgemm.rs b/src/runtime/cuda/kernels/spgemm.rs index 90a5a02c..4c392e77 100644 --- a/src/runtime/cuda/kernels/spgemm.rs +++ b/src/runtime/cuda/kernels/spgemm.rs @@ -15,9 +15,21 @@ use crate::runtime::Runtime; use crate::runtime::cuda::CudaRuntime; use crate::tensor::Tensor; +/// CUDA module name for sparse matrix-matrix multiplication (SpGEMM) kernels. pub const SPGEMM_MODULE: &str = "spgemm"; -/// Phase 1: Symbolic - Count NNZ per output row +/// Phase 1: Symbolic - Count NNZ per output row of C = A * B +/// +/// Uses a bitmap approach per thread to count unique column indices produced by each output row. +/// Allocates dynamic shared memory of `block_size * ceil(n / 8)` bytes. +/// +/// # Safety +/// +/// - All tensor arguments must be valid `CudaRuntime` tensors on the device associated with +/// `context` with consistent CSR structure. +/// - `m` must equal the number of rows in `a`; `n` must equal the number of columns in `b`. +/// - `m * ceil(n / 8)` bytes of shared memory must be available on the device. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn spgemm_symbolic_phase( context: &Arc, stream: &CudaStream, @@ -50,11 +62,11 @@ pub unsafe fn spgemm_symbolic_phase( let cfg = launch_config((grid_size, 1, 1), (block_size, 1, 1), shared_mem_bytes); - let a_row_ptrs_ptr = a_row_ptrs.storage().ptr(); - let a_col_indices_ptr = a_col_indices.storage().ptr(); - let b_row_ptrs_ptr = b_row_ptrs.storage().ptr(); - let b_col_indices_ptr = b_col_indices.storage().ptr(); - let row_nnz_ptr = row_nnz.storage().ptr(); + let a_row_ptrs_ptr = a_row_ptrs.ptr(); + let a_col_indices_ptr = a_col_indices.ptr(); + let b_row_ptrs_ptr = b_row_ptrs.ptr(); + let b_col_indices_ptr = b_col_indices.ptr(); + let row_nnz_ptr = row_nnz.ptr(); let mut builder = stream.launch_builder(&func); builder.arg(&a_row_ptrs_ptr); @@ -80,7 +92,19 @@ pub unsafe fn spgemm_symbolic_phase( Ok(row_nnz) } -/// Phase 2: Numeric - Compute values +/// Phase 2: Numeric - Compute values of C = A * B +/// +/// Fills the pre-allocated output CSR arrays (`c_row_ptrs`, `c_col_indices`, `c_values`) with +/// the computed product. Must be called after `spgemm_symbolic_phase` and exclusive scan. +/// +/// # Safety +/// +/// - All tensor arguments must be valid `CudaRuntime` tensors on the device associated with +/// `context` with consistent CSR structure. +/// - `c_row_ptrs` and `c_col_indices` must be pre-allocated (from the symbolic phase and scan). +/// - `c_values` must be pre-allocated to match the NNZ count from the symbolic phase. +/// - `m` must equal the number of rows in `a`; `n` must equal the number of columns in `b`. +/// - The stream must be from the same context and must not be destroyed while the kernel runs. pub unsafe fn spgemm_numeric_phase( context: &Arc, stream: &CudaStream, @@ -132,15 +156,15 @@ pub unsafe fn spgemm_numeric_phase= 0; d--) { - unsigned int dim_size = (unsigned int)shape[d]; - unsigned int idx = remaining % dim_size; - remaining = remaining / dim_size; - offset += (long long)idx * strides[d]; - } - - return offset; -} - // Generic strided copy kernel - copies element_size bytes per element -// This works for any dtype (f32=4, f64=8, f16=2, etc.) +// Shape and strides are passed by value as fixed-size arrays in kernel args. +// Thread 0 in each block loads them into shared memory to avoid per-thread +// register pressure from 16 scalar args. __global__ void strided_copy( const char* __restrict__ src, char* __restrict__ dst, - const unsigned long long* __restrict__ shape, - const long long* __restrict__ strides, + unsigned long long shape0, + unsigned long long shape1, + unsigned long long shape2, + unsigned long long shape3, + unsigned long long shape4, + unsigned long long shape5, + unsigned long long shape6, + unsigned long long shape7, + long long stride0, + long long stride1, + long long stride2, + long long stride3, + long long stride4, + long long stride5, + long long stride6, + long long stride7, unsigned int numel, unsigned int ndim, unsigned int elem_size, unsigned long long src_byte_offset ) { + // Shared memory: shape and strides loaded once per block by thread 0 + __shared__ unsigned long long s_shape[MAX_DIMS]; + __shared__ long long s_strides[MAX_DIMS]; + + if (threadIdx.x == 0) { + s_shape[0] = shape0; s_shape[1] = shape1; s_shape[2] = shape2; s_shape[3] = shape3; + s_shape[4] = shape4; s_shape[5] = shape5; s_shape[6] = shape6; s_shape[7] = shape7; + s_strides[0] = stride0; s_strides[1] = stride1; s_strides[2] = stride2; s_strides[3] = stride3; + s_strides[4] = stride4; s_strides[5] = stride5; s_strides[6] = stride6; s_strides[7] = stride7; + } + __syncthreads(); + unsigned int gid = blockIdx.x * blockDim.x + threadIdx.x; if (gid >= numel) return; - // Calculate source element offset (in elements) - long long src_elem_offset = get_strided_offset(gid, ndim, shape, strides); + // Convert linear index to strided source offset (row-major) + long long offset = 0; + unsigned int remaining = gid; + for (int d = (int)ndim - 1; d >= 0; d--) { + unsigned int dim_size = (unsigned int)s_shape[d]; + unsigned int idx = remaining % dim_size; + remaining = remaining / dim_size; + offset += (long long)idx * s_strides[d]; + } // Calculate byte addresses - // src_byte_offset is the initial offset into source buffer - // src_elem_offset is the strided offset in elements - unsigned long long src_byte_addr = src_byte_offset + (unsigned long long)((long long)src_elem_offset * (long long)elem_size); + unsigned long long src_byte_addr = src_byte_offset + (unsigned long long)((long long)offset * (long long)elem_size); unsigned long long dst_byte_addr = (unsigned long long)gid * (unsigned long long)elem_size; - // Copy element bytes - // For common element sizes, use optimized paths + // Copy with size-specific optimization if (elem_size == 4) { - // 4-byte elements (f32, i32, u32) *((unsigned int*)(dst + dst_byte_addr)) = *((const unsigned int*)(src + src_byte_addr)); } else if (elem_size == 8) { - // 8-byte elements (f64, i64, u64) *((unsigned long long*)(dst + dst_byte_addr)) = *((const unsigned long long*)(src + src_byte_addr)); } else if (elem_size == 2) { - // 2-byte elements (f16, bf16, i16, u16) *((unsigned short*)(dst + dst_byte_addr)) = *((const unsigned short*)(src + src_byte_addr)); } else if (elem_size == 1) { - // 1-byte elements (i8, u8, bool) dst[dst_byte_addr] = src[src_byte_addr]; } else { - // Generic byte-by-byte copy for unusual element sizes for (unsigned int i = 0; i < elem_size; i++) { dst[dst_byte_addr + i] = src[src_byte_addr + i]; } diff --git a/src/runtime/cuda/kernels/strided_copy.rs b/src/runtime/cuda/kernels/strided_copy.rs index 83df8a5c..d989e83f 100644 --- a/src/runtime/cuda/kernels/strided_copy.rs +++ b/src/runtime/cuda/kernels/strided_copy.rs @@ -3,6 +3,10 @@ //! Provides GPU-accelerated strided-to-contiguous tensor copy operations. //! This replaces the inefficient per-element cuMemcpy approach with a //! parallel CUDA kernel. +//! +//! Shape and strides are passed as kernel arguments (by value), NOT as device +//! memory pointers. This is critical for CUDA graph capture compatibility: +//! device pointers to temporary host-allocated data become stale on graph replay. use cudarc::driver::PushKernelArg; use cudarc::driver::safe::{CudaContext, CudaStream}; @@ -24,12 +28,13 @@ pub const MAX_DIMS: usize = 8; /// Copies non-contiguous (strided) tensor data to a contiguous destination buffer /// using parallel GPU threads. Each thread handles one element. /// +/// Shape and strides are passed as individual kernel arguments (up to MAX_DIMS=8), +/// making this safe for CUDA graph capture/replay. +/// /// # Safety /// /// - `src_ptr` must be valid device memory /// - `dst_ptr` must be valid device memory with space for `numel * elem_size` bytes -/// - `shape_ptr` must point to device memory containing `ndim` u64 values -/// - `strides_ptr` must point to device memory containing `ndim` i64 values /// - All device memory must be allocated on the same device as the stream /// /// # Arguments @@ -39,8 +44,8 @@ pub const MAX_DIMS: usize = 8; /// * `device_index` - Device index for module caching /// * `src_ptr` - Source buffer device pointer /// * `dst_ptr` - Destination buffer device pointer (contiguous) -/// * `shape_ptr` - Device pointer to shape array (u64[ndim]) -/// * `strides_ptr` - Device pointer to strides array (i64[ndim]) +/// * `shape` - Shape array (up to MAX_DIMS elements) +/// * `strides` - Strides array (up to MAX_DIMS elements, in elements) /// * `numel` - Total number of elements /// * `ndim` - Number of dimensions /// * `elem_size` - Size of each element in bytes @@ -51,8 +56,8 @@ pub unsafe fn launch_strided_copy( device_index: usize, src_ptr: u64, dst_ptr: u64, - shape_ptr: u64, - strides_ptr: u64, + shape: &[usize], + strides: &[isize], numel: usize, ndim: usize, elem_size: usize, @@ -69,6 +74,14 @@ pub unsafe fn launch_strided_copy( ))); } + // Pad shape and strides to MAX_DIMS with zeros + let mut shape_args = [0u64; MAX_DIMS]; + let mut stride_args = [0i64; MAX_DIMS]; + for i in 0..ndim { + shape_args[i] = shape[i] as u64; + stride_args[i] = strides[i] as i64; + } + unsafe { let module = get_or_load_module(context, device_index, STRIDED_COPY_MODULE)?; let func = get_kernel_function(&module, "strided_copy")?; @@ -85,8 +98,14 @@ pub unsafe fn launch_strided_copy( let mut builder = stream.launch_builder(&func); builder.arg(&src_ptr); builder.arg(&dst_ptr); - builder.arg(&shape_ptr); - builder.arg(&strides_ptr); + // Pass shape as 8 individual u64 args + for i in 0..MAX_DIMS { + builder.arg(&shape_args[i]); + } + // Pass strides as 8 individual i64 args + for i in 0..MAX_DIMS { + builder.arg(&stride_args[i]); + } builder.arg(&numel_u32); builder.arg(&ndim_u32); builder.arg(&elem_size_u32); @@ -109,19 +128,6 @@ pub unsafe fn launch_strided_copy( /// # Safety /// /// Same requirements as [`launch_strided_copy`]. -/// -/// # Arguments -/// -/// * `context` - CUDA context -/// * `stream` - CUDA stream for async execution -/// * `device_index` - Device index for module caching -/// * `src_ptr` - Source buffer device pointer -/// * `dst_ptr` - Destination buffer device pointer (contiguous) -/// * `outer_size` - Size of outer dimension -/// * `inner_size` - Size of inner (contiguous) dimension -/// * `outer_stride` - Stride of outer dimension (in elements) -/// * `elem_size` - Size of each element in bytes -/// * `src_byte_offset` - Byte offset into source buffer #[allow(dead_code)] // Available for future optimization pub unsafe fn launch_strided_copy_2d( context: &Arc, diff --git a/src/runtime/cuda/kernels/ternary.cu b/src/runtime/cuda/kernels/ternary.cu index c994a644..fb793640 100644 --- a/src/runtime/cuda/kernels/ternary.cu +++ b/src/runtime/cuda/kernels/ternary.cu @@ -45,6 +45,26 @@ __device__ __forceinline__ bool is_nonzero(unsigned int val) { return val != 0; } +template<> +__device__ __forceinline__ bool is_nonzero<__half>(__half val) { + return __half2float(val) != 0.0f; +} + +template<> +__device__ __forceinline__ bool is_nonzero<__nv_bfloat16>(__nv_bfloat16 val) { + return __bfloat162float(val) != 0.0f; +} + +template<> +__device__ __forceinline__ bool is_nonzero(numr_fp8_e4m3 val) { + return fp8_e4m3_to_f32(val.data) != 0.0f; +} + +template<> +__device__ __forceinline__ bool is_nonzero(numr_fp8_e5m2 val) { + return fp8_e5m2_to_f32(val.data) != 0.0f; +} + // ============================================================================ // Where Template (must be outside extern "C") // ============================================================================ @@ -313,6 +333,98 @@ __global__ void where_cond_u32_f64( } } +// ============================================================================ +// F16 condition type +// ============================================================================ + +__global__ void where_cond_f16_f16( + const __half* cond, const __half* x, const __half* y, + __half* out, unsigned int n +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + out[idx] = where_impl_generic<__half, __half>(cond[idx], x[idx], y[idx]); + } +} + +__global__ void where_cond_f16_f32( + const __half* cond, const float* x, const float* y, + float* out, unsigned int n +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + out[idx] = where_impl_generic<__half, float>(cond[idx], x[idx], y[idx]); + } +} + +__global__ void where_cond_f16_f64( + const __half* cond, const double* x, const double* y, + double* out, unsigned int n +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + out[idx] = where_impl_generic<__half, double>(cond[idx], x[idx], y[idx]); + } +} + +// ============================================================================ +// BF16 condition type +// ============================================================================ + +__global__ void where_cond_bf16_bf16( + const __nv_bfloat16* cond, const __nv_bfloat16* x, const __nv_bfloat16* y, + __nv_bfloat16* out, unsigned int n +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + out[idx] = where_impl_generic<__nv_bfloat16, __nv_bfloat16>(cond[idx], x[idx], y[idx]); + } +} + +__global__ void where_cond_bf16_f32( + const __nv_bfloat16* cond, const float* x, const float* y, + float* out, unsigned int n +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + out[idx] = where_impl_generic<__nv_bfloat16, float>(cond[idx], x[idx], y[idx]); + } +} + +__global__ void where_cond_bf16_f64( + const __nv_bfloat16* cond, const double* x, const double* y, + double* out, unsigned int n +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + out[idx] = where_impl_generic<__nv_bfloat16, double>(cond[idx], x[idx], y[idx]); + } +} + +// ============================================================================ +// FP8 condition types +// ============================================================================ + +__global__ void where_cond_fp8_e4m3_fp8_e4m3( + const numr_fp8_e4m3* cond, const numr_fp8_e4m3* x, const numr_fp8_e4m3* y, + numr_fp8_e4m3* out, unsigned int n +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + out[idx] = where_impl_generic(cond[idx], x[idx], y[idx]); + } +} + +__global__ void where_cond_fp8_e5m2_fp8_e5m2( + const numr_fp8_e5m2* cond, const numr_fp8_e5m2* x, const numr_fp8_e5m2* y, + numr_fp8_e5m2* out, unsigned int n +) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + out[idx] = where_impl_generic(cond[idx], x[idx], y[idx]); + } +} + // ============================================================================ // Where Broadcast Operations (different shapes with broadcasting) // ============================================================================ diff --git a/src/runtime/cuda/kernels/ternary.rs b/src/runtime/cuda/kernels/ternary.rs index 515f9b3f..d5e21f4d 100644 --- a/src/runtime/cuda/kernels/ternary.rs +++ b/src/runtime/cuda/kernels/ternary.rs @@ -139,10 +139,10 @@ pub unsafe fn launch_where_broadcast_op( let shape_tensor = Tensor::::from_slice(&shape_u32, &[ndim], device); // Get device pointers - let cond_strides_ptr = cond_strides_tensor.storage().ptr(); - let x_strides_ptr = x_strides_tensor.storage().ptr(); - let y_strides_ptr = y_strides_tensor.storage().ptr(); - let shape_ptr = shape_tensor.storage().ptr(); + let cond_strides_ptr = cond_strides_tensor.ptr(); + let x_strides_ptr = x_strides_tensor.ptr(); + let y_strides_ptr = y_strides_tensor.ptr(); + let shape_ptr = shape_tensor.ptr(); // Get kernel function let module = get_or_load_module(context, device_index, kernel_names::TERNARY_MODULE)?; @@ -178,10 +178,7 @@ pub unsafe fn launch_where_broadcast_op( })?; } - // Synchronize to ensure the kernel completes before freeing temporary allocations - stream - .synchronize() - .map_err(|e| Error::Internal(format!("Stream sync failed: {:?}", e)))?; + // No sync needed: temporary GPU allocations freed via cuMemFreeAsync (stream-ordered). Ok(()) } @@ -319,10 +316,10 @@ pub unsafe fn launch_where_broadcast_generic_op( let shape_tensor = Tensor::::from_slice(&shape_u32, &[ndim], device); // Get device pointers - let cond_strides_ptr = cond_strides_tensor.storage().ptr(); - let x_strides_ptr = x_strides_tensor.storage().ptr(); - let y_strides_ptr = y_strides_tensor.storage().ptr(); - let shape_ptr = shape_tensor.storage().ptr(); + let cond_strides_ptr = cond_strides_tensor.ptr(); + let x_strides_ptr = x_strides_tensor.ptr(); + let y_strides_ptr = y_strides_tensor.ptr(); + let shape_ptr = shape_tensor.ptr(); // Build kernel name: where_broadcast_cond_{cond_dtype}_{out_dtype} let cond_suffix = super::loader::dtype_suffix(cond_dtype); @@ -365,10 +362,7 @@ pub unsafe fn launch_where_broadcast_generic_op( })?; } - // Synchronize to ensure the kernel completes before freeing temporary allocations - stream - .synchronize() - .map_err(|e| Error::Internal(format!("Stream sync failed: {:?}", e)))?; + // No sync needed: temporary GPU allocations freed via cuMemFreeAsync (stream-ordered). Ok(()) } diff --git a/src/runtime/cuda/kernels/utility.cu b/src/runtime/cuda/kernels/utility.cu index 36c2beab..e9a65572 100644 --- a/src/runtime/cuda/kernels/utility.cu +++ b/src/runtime/cuda/kernels/utility.cu @@ -587,6 +587,68 @@ __global__ void eye_u64(unsigned long long* out, unsigned int n, unsigned int m) } } +// ============================================================================ +// FP8 Arange +// ============================================================================ + +__global__ void arange_fp8_e4m3(numr_fp8_e4m3* out, float start, float step, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + out[idx].data = f32_to_fp8_e4m3(start + step * (float)idx); + } +} + +__global__ void arange_fp8_e5m2(numr_fp8_e5m2* out, float start, float step, unsigned int n) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + out[idx].data = f32_to_fp8_e5m2(start + step * (float)idx); + } +} + +// ============================================================================ +// FP8 Linspace +// ============================================================================ + +__global__ void linspace_fp8_e4m3(numr_fp8_e4m3* out, float start, float stop, unsigned int steps) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < steps) { + float t = (float)idx / (float)(steps - 1); + out[idx].data = f32_to_fp8_e4m3(start + (stop - start) * t); + } +} + +__global__ void linspace_fp8_e5m2(numr_fp8_e5m2* out, float start, float stop, unsigned int steps) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < steps) { + float t = (float)idx / (float)(steps - 1); + out[idx].data = f32_to_fp8_e5m2(start + (stop - start) * t); + } +} + +// ============================================================================ +// FP8 Eye +// ============================================================================ + +__global__ void eye_fp8_e4m3(numr_fp8_e4m3* out, unsigned int n, unsigned int m) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int total = n * m; + if (idx < total) { + unsigned int row = idx / m; + unsigned int col = idx % m; + out[idx].data = (row == col) ? f32_to_fp8_e4m3(1.0f) : f32_to_fp8_e4m3(0.0f); + } +} + +__global__ void eye_fp8_e5m2(numr_fp8_e5m2* out, unsigned int n, unsigned int m) { + unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int total = n * m; + if (idx < total) { + unsigned int row = idx / m; + unsigned int col = idx % m; + out[idx].data = (row == col) ? f32_to_fp8_e5m2(1.0f) : f32_to_fp8_e5m2(0.0f); + } +} + } // extern "C" - close before template functions // ============================================================================ diff --git a/src/runtime/cuda/kernels/utility.rs b/src/runtime/cuda/kernels/utility.rs index 38b2165d..fe6a0d6e 100644 --- a/src/runtime/cuda/kernels/utility.rs +++ b/src/runtime/cuda/kernels/utility.rs @@ -22,11 +22,26 @@ use crate::error::{Error, Result}; /// while maintaining type safety at the kernel boundary. #[derive(Debug, Clone, Copy)] pub enum FillValue { + /// 32-bit float fill value. F32(f32), + /// 64-bit float fill value. F64(f64), + /// 32-bit signed integer fill value. I32(i32), + /// 64-bit signed integer fill value. I64(i64), + /// 8-bit unsigned integer fill value (also used for Bool). U8(u8), + /// 16-bit float fill value (raw bits for __half). + #[cfg(feature = "f16")] + F16(u16), + /// 16-bit bfloat fill value (raw bits for __nv_bfloat16). + #[cfg(feature = "f16")] + BF16(u16), + /// FP8 E4M3 fill value (raw bits). + FP8E4M3(u8), + /// FP8 E5M2 fill value (raw bits). + FP8E5M2(u8), } impl FillValue { @@ -39,9 +54,16 @@ impl FillValue { DType::I64 => FillValue::I64(value as i64), DType::U8 | DType::Bool => FillValue::U8(value as u8), #[cfg(feature = "f16")] - DType::F16 | DType::BF16 => FillValue::F32(value as f32), // F16/BF16 kernels use f32 value - DType::FP8E4M3 | DType::FP8E5M2 => FillValue::F32(value as f32), // FP8 kernels use f32 value - _ => FillValue::F64(value), // Default fallback + DType::F16 => FillValue::F16(half::f16::from_f64(value).to_bits()), + #[cfg(feature = "f16")] + DType::BF16 => FillValue::BF16(half::bf16::from_f64(value).to_bits()), + DType::FP8E4M3 => { + FillValue::FP8E4M3(crate::dtype::fp8::FP8E4M3::from_f64(value).to_bits()) + } + DType::FP8E5M2 => { + FillValue::FP8E5M2(crate::dtype::fp8::FP8E5M2::from_f64(value).to_bits()) + } + _ => FillValue::F64(value), } } @@ -53,6 +75,12 @@ impl FillValue { FillValue::I32(_) => DType::I32, FillValue::I64(_) => DType::I64, FillValue::U8(_) => DType::U8, + #[cfg(feature = "f16")] + FillValue::F16(_) => DType::F16, + #[cfg(feature = "f16")] + FillValue::BF16(_) => DType::BF16, + FillValue::FP8E4M3(_) => DType::FP8E4M3, + FillValue::FP8E5M2(_) => DType::FP8E5M2, } } } @@ -146,6 +174,36 @@ pub unsafe fn launch_fill( builder.arg(&n); unsafe { builder.launch(cfg) } } + #[cfg(feature = "f16")] + FillValue::F16(v) => { + let mut builder = stream.launch_builder(&func); + builder.arg(&out_ptr); + builder.arg(&v); + builder.arg(&n); + unsafe { builder.launch(cfg) } + } + #[cfg(feature = "f16")] + FillValue::BF16(v) => { + let mut builder = stream.launch_builder(&func); + builder.arg(&out_ptr); + builder.arg(&v); + builder.arg(&n); + unsafe { builder.launch(cfg) } + } + FillValue::FP8E4M3(v) => { + let mut builder = stream.launch_builder(&func); + builder.arg(&out_ptr); + builder.arg(&v); + builder.arg(&n); + unsafe { builder.launch(cfg) } + } + FillValue::FP8E5M2(v) => { + let mut builder = stream.launch_builder(&func); + builder.arg(&out_ptr); + builder.arg(&v); + builder.arg(&n); + unsafe { builder.launch(cfg) } + } }; launch_result.map_err(|e| { @@ -510,6 +568,23 @@ pub unsafe fn launch_arange( )) })?; }, + #[cfg(feature = "fp8")] + DType::FP8E4M3 | DType::FP8E5M2 => unsafe { + // FP8 kernels take f32 parameters (compute in f32, store as fp8) + let start_f32 = start as f32; + let step_f32 = step as f32; + let mut builder = stream.launch_builder(&func); + builder.arg(&out_ptr); + builder.arg(&start_f32); + builder.arg(&step_f32); + builder.arg(&n); + builder.launch(cfg).map_err(|e| { + Error::Internal(format!( + "CUDA arange kernel '{}' launch failed: {:?}", + func_name, e + )) + })?; + }, _ => { return Err(Error::UnsupportedDType { dtype, @@ -622,6 +697,22 @@ pub unsafe fn launch_linspace( )) })?; }, + #[cfg(feature = "fp8")] + DType::FP8E4M3 | DType::FP8E5M2 => unsafe { + let start_f32 = start as f32; + let stop_f32 = stop as f32; + let mut builder = stream.launch_builder(&func); + builder.arg(&out_ptr); + builder.arg(&start_f32); + builder.arg(&stop_f32); + builder.arg(&n); + builder.launch(cfg).map_err(|e| { + Error::Internal(format!( + "CUDA linspace kernel '{}' launch failed: {:?}", + func_name, e + )) + })?; + }, _ => { return Err(Error::UnsupportedDType { dtype, diff --git a/src/runtime/cuda/linalg/advanced_decompositions.rs b/src/runtime/cuda/linalg/advanced_decompositions.rs index 564108c3..424c4ddb 100644 --- a/src/runtime/cuda/linalg/advanced_decompositions.rs +++ b/src/runtime/cuda/linalg/advanced_decompositions.rs @@ -56,8 +56,8 @@ pub fn rsf2csf_impl( client.stream(), device.index, dtype, - schur.z.storage().ptr(), - schur.t.storage().ptr(), + schur.z.ptr(), + schur.t.ptr(), z_real_ptr, z_imag_ptr, t_real_ptr, @@ -135,8 +135,8 @@ pub fn qz_decompose_impl( let flag_ptr = flag_guard.ptr(); // Copy input matrices to S and T (will be modified in-place) - CudaRuntime::copy_within_device(a.storage().ptr(), s_ptr, matrix_size, device)?; - CudaRuntime::copy_within_device(b.storage().ptr(), t_ptr, matrix_size, device)?; + CudaRuntime::copy_within_device(a.ptr(), s_ptr, matrix_size, device)?; + CudaRuntime::copy_within_device(b.ptr(), t_ptr, matrix_size, device)?; // Initialize converged flag to 0 let zero_flag = [0i32]; diff --git a/src/runtime/cuda/linalg/banded.rs b/src/runtime/cuda/linalg/banded.rs index 3e079595..247c645e 100644 --- a/src/runtime/cuda/linalg/banded.rs +++ b/src/runtime/cuda/linalg/banded.rs @@ -116,8 +116,8 @@ pub fn solve_banded_impl( client.stream(), device.index, dtype, - ab_contig.storage().ptr(), - b_contig.storage().ptr(), + ab_contig.ptr(), + b_contig.ptr(), x_ptr, work_ptr, n, @@ -142,7 +142,7 @@ pub fn solve_banded_impl( client.stream(), device.index, dtype, - b_contig.storage().ptr(), + b_contig.ptr(), b_col_ptr, n, nrhs, @@ -158,7 +158,7 @@ pub fn solve_banded_impl( client.stream(), device.index, dtype, - ab_contig.storage().ptr(), + ab_contig.ptr(), b_col_ptr, x_col_ptr, work_ptr, diff --git a/src/runtime/cuda/linalg/decompositions.rs b/src/runtime/cuda/linalg/decompositions.rs index 96881b8c..c41cf463 100644 --- a/src/runtime/cuda/linalg/decompositions.rs +++ b/src/runtime/cuda/linalg/decompositions.rs @@ -40,7 +40,7 @@ pub fn lu_decompose_impl( let singular_flag_ptr = singular_flag_guard.ptr(); // Copy input to LU buffer - CudaRuntime::copy_within_device(a.storage().ptr(), lu_ptr, lu_size, device)?; + CudaRuntime::copy_within_device(a.ptr(), lu_ptr, lu_size, device)?; // Zero-initialize flags let zero_i32: [u8; 4] = [0; 4]; @@ -114,7 +114,7 @@ pub fn cholesky_decompose_impl( let not_pd_flag_ptr = not_pd_flag_guard.ptr(); // Copy input to L buffer - CudaRuntime::copy_within_device(a.storage().ptr(), l_ptr, l_size, device)?; + CudaRuntime::copy_within_device(a.ptr(), l_ptr, l_size, device)?; // Zero-initialize flag let zero_i32: [u8; 4] = [0; 4]; @@ -179,7 +179,7 @@ pub fn qr_decompose_internal( let workspace_ptr = workspace_guard.ptr(); // Copy A to R (will be modified in place) - CudaRuntime::copy_within_device(a.storage().ptr(), r_ptr, r_size, device)?; + CudaRuntime::copy_within_device(a.ptr(), r_ptr, r_size, device)?; let result = unsafe { kernels::launch_qr_decompose( diff --git a/src/runtime/cuda/linalg/eig_general.rs b/src/runtime/cuda/linalg/eig_general.rs index 709d9014..4ab09c38 100644 --- a/src/runtime/cuda/linalg/eig_general.rs +++ b/src/runtime/cuda/linalg/eig_general.rs @@ -52,7 +52,7 @@ pub fn eig_decompose_impl( let flag_ptr = flag_guard.ptr(); // Copy A to T (working buffer) - CudaRuntime::copy_within_device(a.storage().ptr(), t_ptr, matrix_size, device)?; + CudaRuntime::copy_within_device(a.ptr(), t_ptr, matrix_size, device)?; // Initialize converged flag to 0 let zero_flag = [0i32]; diff --git a/src/runtime/cuda/linalg/eig_symmetric.rs b/src/runtime/cuda/linalg/eig_symmetric.rs index 22601046..bd82c75d 100644 --- a/src/runtime/cuda/linalg/eig_symmetric.rs +++ b/src/runtime/cuda/linalg/eig_symmetric.rs @@ -43,12 +43,7 @@ pub fn eig_decompose_symmetric_impl( let eigenvectors_ptr = client.allocator().allocate(eigenvectors_size)?; // Copy the single element as eigenvalue - CudaRuntime::copy_within_device( - a.storage().ptr(), - eigenvalues_ptr, - eigenvalues_size, - device, - )?; + CudaRuntime::copy_within_device(a.ptr(), eigenvalues_ptr, eigenvalues_size, device)?; // Eigenvector is [1.0] match dtype { @@ -92,7 +87,7 @@ pub fn eig_decompose_symmetric_impl( let converged_flag_ptr = converged_flag_guard.ptr(); // Copy input to working buffer - CudaRuntime::copy_within_device(a.storage().ptr(), work_ptr, work_size, device)?; + CudaRuntime::copy_within_device(a.ptr(), work_ptr, work_size, device)?; // Zero-initialize converged flag let zero_i32: [u8; 4] = [0; 4]; diff --git a/src/runtime/cuda/linalg/matrix_functions.rs b/src/runtime/cuda/linalg/matrix_functions.rs index 1a93aa4b..40125459 100644 --- a/src/runtime/cuda/linalg/matrix_functions.rs +++ b/src/runtime/cuda/linalg/matrix_functions.rs @@ -36,7 +36,7 @@ use crate::tensor::Tensor; /// Get the GPU buffer pointer from a tensor. fn get_tensor_ptr(tensor: &Tensor) -> u64 { - tensor.storage().ptr() + tensor.ptr() } /// Read a single scalar f64 value from GPU tensor using cuMemcpyDtoH_v2. diff --git a/src/runtime/cuda/linalg/matrix_ops.rs b/src/runtime/cuda/linalg/matrix_ops.rs index ce7534b5..61958336 100644 --- a/src/runtime/cuda/linalg/matrix_ops.rs +++ b/src/runtime/cuda/linalg/matrix_ops.rs @@ -82,7 +82,7 @@ pub fn inverse_impl(client: &CudaClient, a: &Tensor) -> Result) -> Result) -> Result) -> Result) -> Result) -> Result) -> Result CudaClient { +fn create_client() -> Option { + if !is_cuda_available() { + return None; + } let device = CudaDevice::new(0); - CudaRuntime::default_client(&device) + Some(CudaRuntime::default_client(&device)) } #[test] fn test_trace() { - let client = create_client(); + let Some(client) = create_client() else { + return; + }; let device = client.device(); // 2x2 matrix: [[1, 2], [3, 4]] @@ -30,7 +35,9 @@ fn test_trace() { #[test] fn test_diag() { - let client = create_client(); + let Some(client) = create_client() else { + return; + }; let device = client.device(); // 2x3 matrix @@ -46,7 +53,9 @@ fn test_diag() { #[test] fn test_diagflat() { - let client = create_client(); + let Some(client) = create_client() else { + return; + }; let device = client.device(); let a = Tensor::::from_slice(&[1.0f32, 2.0, 3.0], &[3], device); @@ -64,7 +73,9 @@ fn test_diagflat() { #[test] fn test_lu_decomposition() { - let client = create_client(); + let Some(client) = create_client() else { + return; + }; let device = client.device(); // 2x2 matrix: [[4, 3], [6, 3]] @@ -78,7 +89,9 @@ fn test_lu_decomposition() { #[test] fn test_cholesky() { - let client = create_client(); + let Some(client) = create_client() else { + return; + }; let device = client.device(); // Symmetric positive definite: [[4, 2], [2, 5]] @@ -95,7 +108,9 @@ fn test_cholesky() { #[test] fn test_det() { - let client = create_client(); + let Some(client) = create_client() else { + return; + }; let device = client.device(); // 2x2 matrix: [[1, 2], [3, 4]] @@ -110,7 +125,9 @@ fn test_det() { #[test] fn test_solve() { - let client = create_client(); + let Some(client) = create_client() else { + return; + }; let device = client.device(); // Solve [[2, 1], [1, 2]] @ x = [3, 3] @@ -127,7 +144,9 @@ fn test_solve() { #[test] fn test_inverse() { - let client = create_client(); + let Some(client) = create_client() else { + return; + }; let device = client.device(); // Test 2x2 matrix: [[4, 7], [2, 6]] @@ -147,7 +166,9 @@ fn test_inverse() { #[test] fn test_inverse_identity() { - let client = create_client(); + let Some(client) = create_client() else { + return; + }; let device = client.device(); // A @ A^-1 should equal I @@ -166,7 +187,9 @@ fn test_inverse_identity() { #[test] fn test_matrix_rank_full() { - let client = create_client(); + let Some(client) = create_client() else { + return; + }; let device = client.device(); // Full rank 2x2 matrix @@ -180,7 +203,9 @@ fn test_matrix_rank_full() { #[test] fn test_matrix_rank_deficient() { - let client = create_client(); + let Some(client) = create_client() else { + return; + }; let device = client.device(); // Rank-deficient 2x2 matrix (rows are linearly dependent) @@ -194,7 +219,9 @@ fn test_matrix_rank_deficient() { #[test] fn test_qr_decomposition() { - let client = create_client(); + let Some(client) = create_client() else { + return; + }; let device = client.device(); // Test QR: A = Q @ R @@ -220,7 +247,9 @@ fn test_qr_decomposition() { #[test] fn test_solve_multi_rhs() { - let client = create_client(); + let Some(client) = create_client() else { + return; + }; let device = client.device(); // Solve A @ X = B where B has multiple columns @@ -259,7 +288,9 @@ fn test_solve_multi_rhs() { #[test] fn test_lstsq_overdetermined() { - let client = create_client(); + let Some(client) = create_client() else { + return; + }; let device = client.device(); // Overdetermined system: A is 3x2, b is 3x1 @@ -283,7 +314,9 @@ fn test_lstsq_overdetermined() { #[test] fn test_lstsq_multi_rhs() { - let client = create_client(); + let Some(client) = create_client() else { + return; + }; let device = client.device(); // Overdetermined system with multiple RHS diff --git a/src/runtime/cuda/mod.rs b/src/runtime/cuda/mod.rs index 88b69ebc..f0f860d4 100644 --- a/src/runtime/cuda/mod.rs +++ b/src/runtime/cuda/mod.rs @@ -24,9 +24,12 @@ mod cache; mod client; +#[cfg(feature = "nccl")] +mod communicator; mod device; mod fft; -pub(crate) mod kernels; +mod graph; +pub mod kernels; mod linalg; mod ops; mod polynomial; @@ -37,5 +40,8 @@ mod special; pub use crate::tensor::Tensor; pub use client::{CudaAllocator, CudaClient, CudaRawHandle}; +#[cfg(feature = "nccl")] +pub use communicator::NcclCommunicator; pub use device::{CudaDevice, CudaError}; +pub use graph::CudaGraph; pub use runtime::{CudaRuntime, cuda_device, cuda_device_id, is_cuda_available}; diff --git a/src/runtime/cuda/ops/helpers.rs b/src/runtime/cuda/ops/helpers.rs index fa2a95ee..c9f13b9c 100644 --- a/src/runtime/cuda/ops/helpers.rs +++ b/src/runtime/cuda/ops/helpers.rs @@ -3,9 +3,9 @@ use super::super::kernels::launch_scalar_op_half; use super::super::kernels::{ AccumulationPrecision, launch_binary_op, launch_broadcast_binary_op, - launch_broadcast_compare_op, launch_compare_op, launch_matmul_batched_kernel, - launch_matmul_bias_batched_kernel, launch_matmul_bias_kernel, launch_matmul_kernel, - launch_reduce_dim_op, launch_scalar_op_f32, launch_scalar_op_f64, + launch_broadcast_compare_op, launch_compare_op, launch_gemv_kernel_bt_mr, + launch_matmul_batched_kernel, launch_matmul_bias_batched_kernel, launch_matmul_bias_kernel, + launch_matmul_kernel, launch_reduce_dim_op, launch_scalar_op_f32, launch_scalar_op_f64, launch_semiring_matmul_batched_kernel, launch_semiring_matmul_kernel, launch_unary_op, }; use super::super::kernels::{ @@ -26,6 +26,21 @@ use crate::tensor::Tensor; /// /// Uses shared memory tiling for cache efficiency. This is the default /// implementation that works without any vendor dependencies. +/// Detect if a 2D tensor is a simple transpose of a contiguous [N,K] matrix. +/// +/// A tensor with shape [K, N] and strides [1, K] is a transpose view of +/// contiguous [N, K] data. We can pass the raw pointer directly to gemv_bt +/// instead of materializing the transpose (which copies the entire matrix). +fn is_simple_transpose_2d(tensor: &Tensor) -> bool { + let shape = tensor.shape(); + let strides = tensor.strides(); + if shape.len() != 2 { + return false; + } + // shape=[K,N], strides=[1,K] means transpose of contiguous [N,K] + strides[0] == 1 && strides[1] == shape[0] as isize +} + pub(crate) fn matmul_native( client: &CudaClient, a: &Tensor, @@ -35,14 +50,41 @@ pub(crate) fn matmul_native( k: usize, n: usize, ) -> Result> { - let a_contig = ensure_contiguous(a); - let b_contig = ensure_contiguous(b); - let out_shape = matmul_output_shape(a.shape(), b.shape()).ok_or(Error::ShapeMismatch { expected: a.shape().to_vec(), got: b.shape().to_vec(), })?; + // Fast path: if B is a transposed view of contiguous [N,K] and M is small, + // use gemv_bt kernel directly — avoids copying the entire weight matrix. + if m <= 16 && is_simple_transpose_2d(b) { + let a_contig = ensure_contiguous(a); + let out = Tensor::::empty(&out_shape, dtype, &client.device); + + unsafe { + launch_gemv_kernel_bt_mr( + &client.context, + &client.stream, + client.device.index, + dtype, + a_contig.ptr(), + b.ptr(), // raw [N,K] pointer — no copy! + out.ptr(), + 1, // batch + m, + n, + k, + 1, // a_batch + 1, // b_batch + )?; + } + + return Ok(out); + } + + let a_contig = ensure_contiguous(a); + let b_contig = ensure_contiguous(b); + let out = Tensor::::empty(&out_shape, dtype, &client.device); unsafe { @@ -51,9 +93,9 @@ pub(crate) fn matmul_native( &client.stream, client.device.index, dtype, - a_contig.storage().ptr(), - b_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + b_contig.ptr(), + out.ptr(), m, n, k, @@ -63,6 +105,37 @@ pub(crate) fn matmul_native( Ok(out) } +/// Detect if the last two dims of a 3D tensor are a simple transpose. +/// Shape [B, K, N] with strides [B_stride, 1, K] means each batch slice +/// is a transpose of contiguous [N, K]. +fn is_batched_transpose_last2(tensor: &Tensor) -> bool { + let shape = tensor.shape(); + let strides = tensor.strides(); + if shape.len() != 3 { + return false; + } + let k = shape[1]; + let n = shape[2]; + // strides: [n*k, 1, k] means transpose of contiguous [batch, N, K] + strides[1] == 1 && strides[2] == k as isize && strides[0] == (n * k) as isize +} + +/// Compute batch count for A and B from their shapes. +/// Returns (a_batch_count, b_batch_count) where each is the product of +/// the leading dimensions (all dims except the last two). +/// Returns 1 for 2D tensors (no batch dimension). +fn compute_batch_counts(a_shape: &[usize], b_shape: &[usize]) -> (usize, usize) { + let a_batch: usize = a_shape + .iter() + .take(a_shape.len().saturating_sub(2)) + .product(); + let b_batch: usize = b_shape + .iter() + .take(b_shape.len().saturating_sub(2)) + .product(); + (a_batch.max(1), b_batch.max(1)) +} + /// Native batched matrix multiplication using tiled CUDA kernel. pub(crate) fn matmul_batched_native( client: &CudaClient, @@ -74,14 +147,42 @@ pub(crate) fn matmul_batched_native( k: usize, n: usize, ) -> Result> { - let a_contig = ensure_contiguous(a); - let b_contig = ensure_contiguous(b); - let out_shape = matmul_output_shape(a.shape(), b.shape()).ok_or(Error::ShapeMismatch { expected: a.shape().to_vec(), got: b.shape().to_vec(), })?; + let (a_batch, b_batch) = compute_batch_counts(a.shape(), b.shape()); + + // Fast path: transposed B with small M → gemv_bt + if m <= 16 && is_batched_transpose_last2(b) { + let a_contig = ensure_contiguous(a); + let out = Tensor::::empty(&out_shape, dtype, &client.device); + + unsafe { + launch_gemv_kernel_bt_mr( + &client.context, + &client.stream, + client.device.index, + dtype, + a_contig.ptr(), + b.ptr(), + out.ptr(), + batch, + m, + n, + k, + a_batch, + b_batch, + )?; + } + + return Ok(out); + } + + let a_contig = ensure_contiguous(a); + let b_contig = ensure_contiguous(b); + let out = Tensor::::empty(&out_shape, dtype, &client.device); unsafe { @@ -90,13 +191,15 @@ pub(crate) fn matmul_batched_native( &client.stream, client.device.index, dtype, - a_contig.storage().ptr(), - b_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + b_contig.ptr(), + out.ptr(), batch, m, n, k, + a_batch, + b_batch, )?; } @@ -140,10 +243,10 @@ pub(crate) fn matmul_bias_native( &client.stream, client.device.index, dtype, - a_contig.storage().ptr(), - b_contig.storage().ptr(), - bias_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + b_contig.ptr(), + bias_contig.ptr(), + out.ptr(), m, n, k, @@ -177,6 +280,8 @@ pub(crate) fn matmul_bias_batched_native( }, )?; + let (a_batch, b_batch) = compute_batch_counts(a.shape(), b.shape()); + let out = Tensor::::empty(&out_shape, dtype, &client.device); unsafe { @@ -185,14 +290,16 @@ pub(crate) fn matmul_bias_batched_native( &client.stream, client.device.index, dtype, - a_contig.storage().ptr(), - b_contig.storage().ptr(), - bias_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + b_contig.ptr(), + bias_contig.ptr(), + out.ptr(), batch, m, n, k, + a_batch, + b_batch, )?; } @@ -234,9 +341,9 @@ pub(crate) fn native_binary_op( client.device.index, op, dtype, - a_contig.storage().ptr(), - b_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + b_contig.ptr(), + out.ptr(), out.numel(), )?; } @@ -257,9 +364,9 @@ pub(crate) fn native_binary_op( &client.device, op, dtype, - a_contig.storage().ptr(), - b_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + b_contig.ptr(), + out.ptr(), a.shape(), b.shape(), &out_shape, @@ -292,8 +399,8 @@ pub(crate) fn native_unary_op( client.device.index, op, dtype, - a_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + out.ptr(), out.numel(), )?; } @@ -334,9 +441,9 @@ pub(crate) fn native_scalar_op( &client.stream, client.device.index, op, - a_contig.storage().ptr(), + a_contig.ptr(), scalar as f32, - out.storage().ptr(), + out.ptr(), out.numel(), )?, DType::F64 => launch_scalar_op_f64( @@ -344,9 +451,9 @@ pub(crate) fn native_scalar_op( &client.stream, client.device.index, op, - a_contig.storage().ptr(), + a_contig.ptr(), scalar, - out.storage().ptr(), + out.ptr(), out.numel(), )?, DType::I32 => launch_scalar_op_i32( @@ -354,9 +461,9 @@ pub(crate) fn native_scalar_op( &client.stream, client.device.index, op, - a_contig.storage().ptr(), + a_contig.ptr(), scalar as i32, - out.storage().ptr(), + out.ptr(), out.numel(), )?, DType::I64 => launch_scalar_op_i64( @@ -364,9 +471,9 @@ pub(crate) fn native_scalar_op( &client.stream, client.device.index, op, - a_contig.storage().ptr(), + a_contig.ptr(), scalar as i64, - out.storage().ptr(), + out.ptr(), out.numel(), )?, #[cfg(feature = "f16")] @@ -376,9 +483,9 @@ pub(crate) fn native_scalar_op( client.device.index, op, dtype, - a_contig.storage().ptr(), + a_contig.ptr(), scalar as f32, - out.storage().ptr(), + out.ptr(), out.numel(), )?, DType::FP8E4M3 | DType::FP8E5M2 => launch_scalar_op_half( @@ -387,9 +494,9 @@ pub(crate) fn native_scalar_op( client.device.index, op, dtype, - a_contig.storage().ptr(), + a_contig.ptr(), scalar as f32, - out.storage().ptr(), + out.ptr(), out.numel(), )?, DType::Complex64 => launch_scalar_op_c64( @@ -397,9 +504,9 @@ pub(crate) fn native_scalar_op( &client.stream, client.device.index, op, - a_contig.storage().ptr(), + a_contig.ptr(), scalar as f32, - out.storage().ptr(), + out.ptr(), out.numel(), )?, DType::Complex128 => launch_scalar_op_c128( @@ -407,9 +514,9 @@ pub(crate) fn native_scalar_op( &client.stream, client.device.index, op, - a_contig.storage().ptr(), + a_contig.ptr(), scalar, - out.storage().ptr(), + out.ptr(), out.numel(), )?, _ => { @@ -469,8 +576,8 @@ pub(crate) fn native_reduce_op( client.device.index, op, dtype, - a_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + out.ptr(), outer_size, reduce_size, inner_size, @@ -538,9 +645,9 @@ pub(crate) fn native_compare_op( client.device.index, op, dtype, - a_contig.storage().ptr(), - b_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + b_contig.ptr(), + out.ptr(), out.numel(), )?; } @@ -561,9 +668,9 @@ pub(crate) fn native_compare_op( &client.device, op, dtype, - a_contig.storage().ptr(), - b_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + b_contig.ptr(), + out.ptr(), a.shape(), b.shape(), &out_shape, @@ -604,9 +711,9 @@ pub(crate) fn semiring_matmul_native( &client.stream, client.device.index, dtype, - a_contig.storage().ptr(), - b_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + b_contig.ptr(), + out.ptr(), m, n, k, @@ -637,6 +744,8 @@ pub(crate) fn semiring_matmul_batched_native( got: b.shape().to_vec(), })?; + let (a_batch, b_batch) = compute_batch_counts(a.shape(), b.shape()); + let out = Tensor::::empty(&out_shape, dtype, &client.device); unsafe { @@ -645,14 +754,16 @@ pub(crate) fn semiring_matmul_batched_native( &client.stream, client.device.index, dtype, - a_contig.storage().ptr(), - b_contig.storage().ptr(), - out.storage().ptr(), + a_contig.ptr(), + b_contig.ptr(), + out.ptr(), batch, m, n, k, semiring_op, + a_batch, + b_batch, )?; } diff --git a/src/runtime/cuda/ops/mod.rs b/src/runtime/cuda/ops/mod.rs index 4451f0a4..77853568 100644 --- a/src/runtime/cuda/ops/mod.rs +++ b/src/runtime/cuda/ops/mod.rs @@ -13,11 +13,14 @@ mod tests { ActivationOps, BinaryOps, IndexingOps, MatmulOps, NormalizationOps, ReduceOps, }; use crate::runtime::Runtime; - use crate::runtime::cuda::{CudaDevice, CudaRuntime}; + use crate::runtime::cuda::{CudaDevice, CudaRuntime, is_cuda_available}; use crate::tensor::Tensor; #[test] fn test_cuda_tensor_add() { + if !is_cuda_available() { + return; + } let device = CudaDevice::new(0); let client = CudaRuntime::default_client(&device); @@ -33,6 +36,9 @@ mod tests { #[test] fn test_cuda_tensor_matmul_2x2() { + if !is_cuda_available() { + return; + } let device = CudaDevice::new(0); let client = CudaRuntime::default_client(&device); @@ -48,6 +54,9 @@ mod tests { #[test] fn test_cuda_tensor_matmul_3x2_2x4() { + if !is_cuda_available() { + return; + } let device = CudaDevice::new(0); let client = CudaRuntime::default_client(&device); @@ -73,6 +82,9 @@ mod tests { #[test] fn test_cuda_tensor_relu() { + if !is_cuda_available() { + return; + } let device = CudaDevice::new(0); let client = CudaRuntime::default_client(&device); @@ -85,6 +97,9 @@ mod tests { #[test] fn test_cuda_tensor_sum() { + if !is_cuda_available() { + return; + } let device = CudaDevice::new(0); let client = CudaRuntime::default_client(&device); @@ -100,6 +115,9 @@ mod tests { #[test] fn test_cuda_tensor_silu() { + if !is_cuda_available() { + return; + } let device = CudaDevice::new(0); let client = CudaRuntime::default_client(&device); @@ -118,6 +136,9 @@ mod tests { #[test] fn test_cuda_tensor_gelu() { + if !is_cuda_available() { + return; + } let device = CudaDevice::new(0); let client = CudaRuntime::default_client(&device); @@ -135,6 +156,9 @@ mod tests { #[test] fn test_cuda_tensor_rms_norm() { + if !is_cuda_available() { + return; + } let device = CudaDevice::new(0); let client = CudaRuntime::default_client(&device); @@ -163,6 +187,9 @@ mod tests { #[test] fn test_cuda_tensor_layer_norm() { + if !is_cuda_available() { + return; + } let device = CudaDevice::new(0); let client = CudaRuntime::default_client(&device); @@ -198,6 +225,9 @@ mod tests { #[test] fn test_cuda_tensor_argmax() { + if !is_cuda_available() { + return; + } let device = CudaDevice::new(0); let client = CudaRuntime::default_client(&device); @@ -226,6 +256,9 @@ mod tests { #[test] fn test_cuda_tensor_argmin() { + if !is_cuda_available() { + return; + } let device = CudaDevice::new(0); let client = CudaRuntime::default_client(&device); diff --git a/src/runtime/cuda/ops/statistics/mod.rs b/src/runtime/cuda/ops/statistics/mod.rs index d4d06773..d19417e7 100644 --- a/src/runtime/cuda/ops/statistics/mod.rs +++ b/src/runtime/cuda/ops/statistics/mod.rs @@ -35,8 +35,8 @@ pub use quantile::{median_impl, percentile_impl, quantile_impl}; use crate::dtype::DType; use crate::error::{Error, Result}; use crate::ops::TypeConversionOps; +use crate::runtime::common::statistics_common::compute_bin_edges_f64; use crate::runtime::cuda::{CudaClient, CudaRuntime}; -use crate::runtime::statistics_common::compute_bin_edges_f64; use crate::tensor::Tensor; /// Create bin edges tensor from computed f64 edges. @@ -93,7 +93,7 @@ pub(crate) fn read_scalar_f64(t: &Tensor) -> Result { }; // Get GPU buffer pointer - let ptr = tensor.storage().ptr(); + let ptr = tensor.ptr(); // Allocate host memory and copy from GPU based on dtype let result = match dtype { diff --git a/src/runtime/cuda/ops/statistics/mode.rs b/src/runtime/cuda/ops/statistics/mode.rs index 31262489..559091aa 100644 --- a/src/runtime/cuda/ops/statistics/mode.rs +++ b/src/runtime/cuda/ops/statistics/mode.rs @@ -94,9 +94,9 @@ pub fn mode_impl( &client.stream, client.device.index, dtype, - sorted_contig.storage().ptr(), - mode_values.storage().ptr(), - mode_counts.storage().ptr(), + sorted_contig.ptr(), + mode_values.ptr(), + mode_counts.ptr(), outer_size, reduce_size, inner_size, diff --git a/src/runtime/cuda/ops/statistics/moments.rs b/src/runtime/cuda/ops/statistics/moments.rs index c34c3338..798ab654 100644 --- a/src/runtime/cuda/ops/statistics/moments.rs +++ b/src/runtime/cuda/ops/statistics/moments.rs @@ -5,8 +5,8 @@ use crate::algorithm::linalg::helpers::{linalg_demote, linalg_promote}; use crate::error::Result; +use crate::runtime::common::statistics_common; use crate::runtime::cuda::{CudaClient, CudaRuntime}; -use crate::runtime::statistics_common; use crate::tensor::Tensor; /// Compute skewness (third standardized moment) using composition. diff --git a/src/runtime/cuda/ops/statistics/quantile.rs b/src/runtime/cuda/ops/statistics/quantile.rs index 9fdc1300..04943cad 100644 --- a/src/runtime/cuda/ops/statistics/quantile.rs +++ b/src/runtime/cuda/ops/statistics/quantile.rs @@ -3,9 +3,9 @@ use crate::dtype::DType; use crate::error::{Error, Result}; use crate::ops::{BinaryOps, IndexingOps, ScalarOps, SortingOps, TypeConversionOps}; +use crate::runtime::common::statistics_common::Interpolation; use crate::runtime::cuda::{CudaClient, CudaRuntime}; use crate::runtime::normalize_dim; -use crate::runtime::statistics_common::Interpolation; use crate::tensor::Tensor; /// Compute quantile along a dimension entirely on GPU. @@ -91,7 +91,7 @@ pub fn quantile_impl( // Calculate quantile indices (small computation, OK on CPU) let (floor_idx, ceil_idx, frac) = - crate::runtime::statistics_common::compute_quantile_indices(q, dim_size); + crate::runtime::common::statistics_common::compute_quantile_indices(q, dim_size); // index_select requires at least 1D indices, so use [1] for scalar output let is_scalar_output = out_shape.is_empty(); diff --git a/src/runtime/cuda/ops/tensor.rs b/src/runtime/cuda/ops/tensor.rs index 733ae97f..988490aa 100644 --- a/src/runtime/cuda/ops/tensor.rs +++ b/src/runtime/cuda/ops/tensor.rs @@ -78,6 +78,9 @@ mod distance; #[path = "../../../ops/cuda/multivariate.rs"] mod multivariate; +#[path = "../../../ops/cuda/gemm_epilogue.rs"] +mod gemm_epilogue; + #[path = "../../../ops/cuda/semiring_matmul.rs"] mod semiring_matmul; @@ -92,3 +95,11 @@ mod logical; #[path = "../../../ops/cuda/einsum.rs"] mod einsum; + +#[cfg(feature = "fp8")] +#[path = "../../../ops/cuda/fp8_matmul.rs"] +mod fp8_matmul; + +#[cfg(feature = "sparse")] +#[path = "../../../ops/cuda/sparse_24.rs"] +mod sparse_24; diff --git a/src/runtime/cuda/runtime.rs b/src/runtime/cuda/runtime.rs index fc7f5023..2804449c 100644 --- a/src/runtime/cuda/runtime.rs +++ b/src/runtime/cuda/runtime.rs @@ -1,7 +1,7 @@ //! CUDA runtime implementation use super::cache::{ - get_or_create_client, is_cuda_context_valid, log_cuda_memory_error, reset_client, + get_or_create_client, is_cuda_context_valid, log_cuda_memory_error, try_get_cached_client, try_get_cached_stream, }; use super::client::CudaAllocator; @@ -9,6 +9,7 @@ use super::client::CudaClient; use super::device::CudaDevice; use super::kernels; use crate::runtime::Runtime; +use crate::runtime::common::Allocator; /// CUDA Runtime adapter /// @@ -21,7 +22,9 @@ impl Runtime for CudaRuntime { type Device = CudaDevice; type Client = CudaClient; type Allocator = CudaAllocator; + type Graph = super::CudaGraph; type RawHandle = super::CudaRawHandle; + type DType = crate::dtype::DType; fn name() -> &'static str { "cuda" @@ -31,82 +34,96 @@ impl Runtime for CudaRuntime { true // CUDA supports graph capture } + fn capture_graph(client: &Self::Client, f: F) -> crate::error::Result<(Self::Graph, T)> + where + F: FnOnce(&Self::Client) -> crate::error::Result, + { + use cudarc::driver::sys::CUstreamCaptureMode; + + // Freeze the caching allocator so all alloc/free calls go directly + // through cuMemAllocAsync/cuMemFreeAsync, creating proper graph nodes. + // Without this, the free-list cache intercepts deallocations (no graph + // free node) and satisfies allocations from cache (no graph alloc node), + // corrupting the graph's internal memory management on replay. + client.allocator.freeze(); + + // Begin stream capture — all ops on this stream are recorded, not executed + client + .stream + .begin_capture(CUstreamCaptureMode::CU_STREAM_CAPTURE_MODE_GLOBAL)?; + + // Execute the closure — ops are recorded into the graph + let result = f(client); + + // End capture — MUST happen even if the closure failed, otherwise the + // stream is left in capture mode and all subsequent operations fail. + // + // AUTO_FREE_ON_LAUNCH: graph-managed memory allocated during capture is + // freed on each launch. For graph capture in training (where we re-run + // the same graph), this is acceptable — each launch re-allocates. + // For inference with stable output pointers, the caller must copy the + // output tensor after each launch before the next launch frees it. + let flags = cudarc::driver::sys::CUgraphInstantiate_flags::CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH; + let graph_result = client.stream.end_capture(flags); + + // Restore caching allocator for normal (non-capture) operations + client.allocator.unfreeze(); + + // Handle closure error: propagate after restoring stream + let closure_result = result?; + + // Handle capture error + let graph_opt = graph_result?; + + let cudarc_graph = graph_opt.ok_or_else(|| { + crate::error::Error::Backend( + "CUDA graph capture produced no operations — closure recorded nothing".into(), + ) + })?; + + Ok((super::CudaGraph::new(cudarc_graph), closure_result)) + } + /// Allocate GPU memory. /// - /// Returns `Err(OutOfMemory)` if CUDA memory allocation fails. + /// Routes through the client's caching allocator (free-list pool) to avoid + /// cuMemAllocAsync driver round-trips for repeated same-size allocations. fn allocate(size_bytes: usize, device: &Self::Device) -> crate::error::Result { if size_bytes == 0 { return Ok(0); } let client = get_or_create_client(device); - - unsafe { - let mut ptr: u64 = 0; - let result = cudarc::driver::sys::cuMemAllocAsync( - &mut ptr, - size_bytes, - client.stream.cu_stream(), - ); - - if result == cudarc::driver::sys::CUresult::CUDA_SUCCESS { - return Ok(ptr); - } - - // First attempt failed - try syncing the stream to flush pending frees - let _ = client.stream.synchronize(); - - let result = cudarc::driver::sys::cuMemAllocAsync( - &mut ptr, - size_bytes, - client.stream.cu_stream(), - ); - - if result == cudarc::driver::sys::CUresult::CUDA_SUCCESS { - return Ok(ptr); - } - - // Stream is likely in a sticky error state (e.g., CUDA_ERROR_MISALIGNED_ADDRESS - // from a previous kernel). Reset the client with a fresh context/stream. - drop(client); - if let Some(new_client) = reset_client(device) { - let result = cudarc::driver::sys::cuMemAllocAsync( - &mut ptr, - size_bytes, - new_client.stream.cu_stream(), - ); - - if result == cudarc::driver::sys::CUresult::CUDA_SUCCESS { - return Ok(ptr); - } - } - - Err(crate::error::Error::OutOfMemory { size: size_bytes }) - } + client.allocator.allocate(size_bytes) } - fn deallocate(ptr: u64, _size_bytes: usize, device: &Self::Device) { + /// Deallocate GPU memory. + /// + /// Routes through the client's caching allocator — buffers are returned to + /// the free-list for reuse instead of calling cuMemFreeAsync. + fn deallocate(ptr: u64, size_bytes: usize, device: &Self::Device) { if ptr == 0 { return; } + // Try to use the client's caching allocator (returns to free-list) + if let Some(client) = try_get_cached_client(device.index) { + client.allocator.deallocate(ptr, size_bytes); + return; + } + + // Client not available (shutdown) — free directly unsafe { - // Check if CUDA context is still valid before attempting free if !is_cuda_context_valid() { - // Context is gone - memory will be reclaimed by driver on context destruction return; } - // Try to use stream-ordered async free if client is available let result = if let Some(stream) = try_get_cached_stream(device.index) { cudarc::driver::sys::cuMemFreeAsync(ptr, stream) } else { - // Fallback to synchronous free cudarc::driver::sys::cuMemFree_v2(ptr) }; - // Log failures but don't panic - deallocation errors are typically benign - // (e.g., double-free attempts, already-freed memory) if result != cudarc::driver::sys::CUresult::CUDA_SUCCESS && result != cudarc::driver::sys::CUresult::CUDA_ERROR_ILLEGAL_ADDRESS { @@ -141,8 +158,10 @@ impl Runtime for CudaRuntime { ))); } - // Synchronize to ensure data is available - let _ = client.stream.synchronize(); + // No explicit sync needed: with pageable (non-pinned) host memory, + // cuMemcpyHtoDAsync is synchronous w.r.t. the host buffer — the call + // returns only after the copy is complete. An explicit stream.synchronize() + // here would also drain ALL pending GPU work, destroying pipeline throughput. } Ok(()) } @@ -177,12 +196,68 @@ impl Runtime for CudaRuntime { ))); } - // Synchronize to ensure data is available on host + // With pageable host memory, cuMemcpyDtoHAsync blocks the host until + // the copy completes. However, we still need to synchronize the stream + // to ensure all prior GPU kernels have finished producing the data. let _ = client.stream.synchronize(); } Ok(()) } + /// Record an event on the compute stream. + fn record_compute_event(device: &Self::Device) -> crate::error::Result { + let client = get_or_create_client(device); + client + .record_event_on_compute() + .map_err(|e| crate::error::Error::Backend(format!("Event record failed: {}", e))) + } + + /// Pipelined D2H copy: copy stream waits on the provided event, copies, + /// and syncs only the copy stream. Compute stream continues concurrently. + fn copy_from_device_pipelined( + src: u64, + dst: &mut [u8], + device: &Self::Device, + event: u64, + ) -> crate::error::Result<()> { + if dst.is_empty() || src == 0 { + return Ok(()); + } + + let client = get_or_create_client(device); + + unsafe { + // 1. Copy stream waits for event (waits for argmax to finish) + client.copy_stream_wait_event(event).map_err(|e| { + client.destroy_event(event); + crate::error::Error::Backend(format!("Stream wait event failed: {}", e)) + })?; + + // 2. Launch D2H copy on copy stream + let result = cudarc::driver::sys::cuMemcpyDtoHAsync_v2( + dst.as_mut_ptr() as *mut std::ffi::c_void, + src, + dst.len(), + client.copy_stream.cu_stream(), + ); + + if result != cudarc::driver::sys::CUresult::CUDA_SUCCESS { + client.destroy_event(event); + return Err(crate::error::Error::Backend(format!( + "[numr::cuda] Pipelined D2H copy failed: {} bytes ({:?})", + dst.len(), + result + ))); + } + + // 3. Sync ONLY the copy stream (compute stream keeps running) + let _ = client.copy_stream.synchronize(); + + client.destroy_event(event); + } + Ok(()) + } + /// Copy data within device memory. /// /// Returns an error if the CUDA copy operation fails. @@ -236,93 +311,26 @@ impl Runtime for CudaRuntime { let ndim = shape.len(); let client = get_or_create_client(device); - let cu_stream = client.stream.cu_stream(); - - // Convert shape and strides to device-compatible types - let shape_u64: Vec = shape.iter().map(|&s| s as u64).collect(); - let strides_i64: Vec = strides.iter().map(|&s| s as i64).collect(); - - // Allocate temporary device memory for shape and strides arrays - let shape_bytes = ndim * std::mem::size_of::(); - let strides_bytes = ndim * std::mem::size_of::(); + // Shape and strides are passed as kernel arguments (by value), not device + // memory pointers. This is critical for CUDA graph capture compatibility: + // H2D copies of temporary host data create graph memcpy nodes that re-read + // from stale host addresses on replay, causing CUDA_ERROR_ILLEGAL_ADDRESS. unsafe { - // Allocate device memory for shape array - let mut shape_ptr: u64 = 0; - let result = - cudarc::driver::sys::cuMemAllocAsync(&mut shape_ptr, shape_bytes, cu_stream); - if result != cudarc::driver::sys::CUresult::CUDA_SUCCESS { - return Err(crate::error::Error::Backend(format!( - "[numr::cuda] Failed to allocate shape array for strided copy ({:?})", - result - ))); - } - - // Allocate device memory for strides array - let mut strides_ptr: u64 = 0; - let result = - cudarc::driver::sys::cuMemAllocAsync(&mut strides_ptr, strides_bytes, cu_stream); - if result != cudarc::driver::sys::CUresult::CUDA_SUCCESS { - // Free shape_ptr before returning error - let _ = cudarc::driver::sys::cuMemFreeAsync(shape_ptr, cu_stream); - return Err(crate::error::Error::Backend(format!( - "[numr::cuda] Failed to allocate strides array for strided copy ({:?})", - result - ))); - } - - // Copy shape to device - let result = cudarc::driver::sys::cuMemcpyHtoDAsync_v2( - shape_ptr, - shape_u64.as_ptr() as *const std::ffi::c_void, - shape_bytes, - cu_stream, - ); - if result != cudarc::driver::sys::CUresult::CUDA_SUCCESS { - let _ = cudarc::driver::sys::cuMemFreeAsync(shape_ptr, cu_stream); - let _ = cudarc::driver::sys::cuMemFreeAsync(strides_ptr, cu_stream); - return Err(crate::error::Error::Backend(format!( - "[numr::cuda] Failed to copy shape to device for strided copy ({:?})", - result - ))); - } - - // Copy strides to device - let result = cudarc::driver::sys::cuMemcpyHtoDAsync_v2( - strides_ptr, - strides_i64.as_ptr() as *const std::ffi::c_void, - strides_bytes, - cu_stream, - ); - if result != cudarc::driver::sys::CUresult::CUDA_SUCCESS { - let _ = cudarc::driver::sys::cuMemFreeAsync(shape_ptr, cu_stream); - let _ = cudarc::driver::sys::cuMemFreeAsync(strides_ptr, cu_stream); - return Err(crate::error::Error::Backend(format!( - "[numr::cuda] Failed to copy strides to device for strided copy ({:?})", - result - ))); - } - - // Launch the strided copy kernel let kernel_result = kernels::launch_strided_copy( &client.context, &client.stream, device.index, src_handle, dst_handle, - shape_ptr, - strides_ptr, + shape, + strides, numel, ndim, elem_size, src_byte_offset, ); - // Free temporary device memory (async, will happen after kernel completes) - let _ = cudarc::driver::sys::cuMemFreeAsync(shape_ptr, cu_stream); - let _ = cudarc::driver::sys::cuMemFreeAsync(strides_ptr, cu_stream); - - // Check kernel launch result if let Err(e) = kernel_result { return Err(crate::error::Error::Backend(format!( "[numr::cuda] Strided copy kernel failed: {} bytes ({} elements × {} bytes/elem) from {} to {} on device {}: {:?}", @@ -335,9 +343,6 @@ impl Runtime for CudaRuntime { e ))); } - - // Synchronize to ensure copy is complete - let _ = client.stream.synchronize(); } Ok(()) } diff --git a/src/runtime/cuda/sparse/conversions.rs b/src/runtime/cuda/sparse/conversions.rs index 7b0859e8..14f48c36 100644 --- a/src/runtime/cuda/sparse/conversions.rs +++ b/src/runtime/cuda/sparse/conversions.rs @@ -49,7 +49,7 @@ impl CudaClient { &self.context, &self.stream, self.device.index, - perm_indices.storage().ptr(), + perm_indices.ptr(), nnz, )?; } @@ -59,8 +59,8 @@ impl CudaClient { unsafe { // Copy row_indices to sorted_rows for in-place sorting CudaRuntime::copy_within_device( - row_indices.storage().ptr(), - sorted_rows.storage().ptr(), + row_indices.ptr(), + sorted_rows.ptr(), row_indices.storage().size_in_bytes(), device, )?; @@ -69,8 +69,8 @@ impl CudaClient { &self.context, &self.stream, self.device.index, - sorted_rows.storage().ptr(), - perm_indices.storage().ptr(), + sorted_rows.ptr(), + perm_indices.ptr(), nnz_u32, )?; } @@ -82,18 +82,18 @@ impl CudaClient { &self.context, &self.stream, self.device.index, - values.storage().ptr(), - perm_indices.storage().ptr(), - sorted_values.storage().ptr(), + values.ptr(), + perm_indices.ptr(), + sorted_values.ptr(), nnz, )?, DType::F64 => kernels::launch_coo_gather::( &self.context, &self.stream, self.device.index, - values.storage().ptr(), - perm_indices.storage().ptr(), - sorted_values.storage().ptr(), + values.ptr(), + perm_indices.ptr(), + sorted_values.ptr(), nnz, )?, #[cfg(feature = "f16")] @@ -101,9 +101,9 @@ impl CudaClient { &self.context, &self.stream, self.device.index, - values.storage().ptr(), - perm_indices.storage().ptr(), - sorted_values.storage().ptr(), + values.ptr(), + perm_indices.ptr(), + sorted_values.ptr(), nnz, )?, #[cfg(feature = "f16")] @@ -111,9 +111,9 @@ impl CudaClient { &self.context, &self.stream, self.device.index, - values.storage().ptr(), - perm_indices.storage().ptr(), - sorted_values.storage().ptr(), + values.ptr(), + perm_indices.ptr(), + sorted_values.ptr(), nnz, )?, _ => { @@ -128,9 +128,9 @@ impl CudaClient { &self.context, &self.stream, self.device.index, - col_indices.storage().ptr(), - perm_indices.storage().ptr(), - sorted_cols.storage().ptr(), + col_indices.ptr(), + perm_indices.ptr(), + sorted_cols.ptr(), nnz, )?; } @@ -142,8 +142,8 @@ impl CudaClient { &self.context, &self.stream, self.device.index, - sorted_rows.storage().ptr(), - row_ptrs.storage().ptr(), + sorted_rows.ptr(), + row_ptrs.ptr(), nnz, nrows, )?; @@ -196,7 +196,7 @@ impl CudaClient { &self.context, &self.stream, self.device.index, - perm_indices.storage().ptr(), + perm_indices.ptr(), nnz, )?; } @@ -204,8 +204,8 @@ impl CudaClient { // Step 3: Sort by column indices using Thrust unsafe { CudaRuntime::copy_within_device( - col_indices.storage().ptr(), - sorted_cols.storage().ptr(), + col_indices.ptr(), + sorted_cols.ptr(), col_indices.storage().size_in_bytes(), device, )?; @@ -214,8 +214,8 @@ impl CudaClient { &self.context, &self.stream, self.device.index, - sorted_cols.storage().ptr(), - perm_indices.storage().ptr(), + sorted_cols.ptr(), + perm_indices.ptr(), nnz_u32, )?; } @@ -227,18 +227,18 @@ impl CudaClient { &self.context, &self.stream, self.device.index, - values.storage().ptr(), - perm_indices.storage().ptr(), - sorted_values.storage().ptr(), + values.ptr(), + perm_indices.ptr(), + sorted_values.ptr(), nnz, )?, DType::F64 => kernels::launch_coo_gather::( &self.context, &self.stream, self.device.index, - values.storage().ptr(), - perm_indices.storage().ptr(), - sorted_values.storage().ptr(), + values.ptr(), + perm_indices.ptr(), + sorted_values.ptr(), nnz, )?, #[cfg(feature = "f16")] @@ -246,9 +246,9 @@ impl CudaClient { &self.context, &self.stream, self.device.index, - values.storage().ptr(), - perm_indices.storage().ptr(), - sorted_values.storage().ptr(), + values.ptr(), + perm_indices.ptr(), + sorted_values.ptr(), nnz, )?, #[cfg(feature = "f16")] @@ -256,9 +256,9 @@ impl CudaClient { &self.context, &self.stream, self.device.index, - values.storage().ptr(), - perm_indices.storage().ptr(), - sorted_values.storage().ptr(), + values.ptr(), + perm_indices.ptr(), + sorted_values.ptr(), nnz, )?, _ => { @@ -273,9 +273,9 @@ impl CudaClient { &self.context, &self.stream, self.device.index, - row_indices.storage().ptr(), - perm_indices.storage().ptr(), - sorted_rows.storage().ptr(), + row_indices.ptr(), + perm_indices.ptr(), + sorted_rows.ptr(), nnz, )?; } @@ -287,8 +287,8 @@ impl CudaClient { &self.context, &self.stream, self.device.index, - sorted_cols.storage().ptr(), - col_ptrs.storage().ptr(), + sorted_cols.ptr(), + col_ptrs.ptr(), nnz, ncols, )?; @@ -323,8 +323,8 @@ impl CudaClient { let row_indices = Tensor::::zeros(&[nnz], crate::dtype::DType::I64, device); // Get device pointers (no data transfer!) - let row_ptrs_ptr = row_ptrs.storage().ptr(); - let row_indices_ptr = row_indices.storage().ptr(); + let row_ptrs_ptr = row_ptrs.ptr(); + let row_indices_ptr = row_indices.ptr(); // Launch pointer expansion kernel unsafe { @@ -369,8 +369,8 @@ impl CudaClient { let col_indices = Tensor::::zeros(&[nnz], crate::dtype::DType::I64, device); // Get device pointers (no data transfer!) - let col_ptrs_ptr = col_ptrs.storage().ptr(); - let col_indices_ptr = col_indices.storage().ptr(); + let col_ptrs_ptr = col_ptrs.ptr(); + let col_indices_ptr = col_indices.ptr(); // Launch pointer expansion kernel unsafe { @@ -419,9 +419,9 @@ impl CudaClient { &self.context, &self.stream, self.device.index, - row_ptrs.storage().ptr(), - col_indices.storage().ptr(), - col_counts.storage().ptr(), + row_ptrs.ptr(), + col_indices.ptr(), + col_counts.ptr(), nrows, )?; } @@ -453,12 +453,12 @@ impl CudaClient { &self.context, &self.stream, self.device.index, - row_ptrs.storage().ptr(), - col_indices.storage().ptr(), - values.storage().ptr(), - col_ptrs_working.storage().ptr(), - row_indices_out.storage().ptr(), - values_out.storage().ptr(), + row_ptrs.ptr(), + col_indices.ptr(), + values.ptr(), + col_ptrs_working.ptr(), + row_indices_out.ptr(), + values_out.ptr(), nrows, ncols, )?; @@ -468,12 +468,12 @@ impl CudaClient { &self.context, &self.stream, self.device.index, - row_ptrs.storage().ptr(), - col_indices.storage().ptr(), - values.storage().ptr(), - col_ptrs_working.storage().ptr(), - row_indices_out.storage().ptr(), - values_out.storage().ptr(), + row_ptrs.ptr(), + col_indices.ptr(), + values.ptr(), + col_ptrs_working.ptr(), + row_indices_out.ptr(), + values_out.ptr(), nrows, ncols, )?; @@ -518,9 +518,9 @@ impl CudaClient { &self.context, &self.stream, self.device.index, - col_ptrs.storage().ptr(), - row_indices.storage().ptr(), - row_counts.storage().ptr(), + col_ptrs.ptr(), + row_indices.ptr(), + row_counts.ptr(), ncols, )?; } @@ -552,12 +552,12 @@ impl CudaClient { &self.context, &self.stream, self.device.index, - col_ptrs.storage().ptr(), - row_indices.storage().ptr(), - values.storage().ptr(), - row_ptrs_working.storage().ptr(), - col_indices_out.storage().ptr(), - values_out.storage().ptr(), + col_ptrs.ptr(), + row_indices.ptr(), + values.ptr(), + row_ptrs_working.ptr(), + col_indices_out.ptr(), + values_out.ptr(), nrows, ncols, )?; @@ -567,12 +567,12 @@ impl CudaClient { &self.context, &self.stream, self.device.index, - col_ptrs.storage().ptr(), - row_indices.storage().ptr(), - values.storage().ptr(), - row_ptrs_working.storage().ptr(), - col_indices_out.storage().ptr(), - values_out.storage().ptr(), + col_ptrs.ptr(), + row_indices.ptr(), + values.ptr(), + row_ptrs_working.ptr(), + col_indices_out.ptr(), + values_out.ptr(), nrows, ncols, )?; diff --git a/src/runtime/cuda/sparse/dsmm.rs b/src/runtime/cuda/sparse/dsmm.rs index 9530256c..1a481c60 100644 --- a/src/runtime/cuda/sparse/dsmm.rs +++ b/src/runtime/cuda/sparse/dsmm.rs @@ -44,11 +44,11 @@ pub(super) fn column_parallel_dsmm( let output = Tensor::::zeros(&[m, n], dtype, device); // Get raw pointers - let a_ptr = a_contig.storage().ptr(); - let col_ptrs_ptr = sparse_b_csc.col_ptrs.storage().ptr(); - let row_indices_ptr = sparse_b_csc.row_indices.storage().ptr(); - let values_ptr = sparse_b_csc.values.storage().ptr(); - let output_ptr = output.storage().ptr(); + let a_ptr = a_contig.ptr(); + let col_ptrs_ptr = sparse_b_csc.col_ptrs.ptr(); + let row_indices_ptr = sparse_b_csc.row_indices.ptr(); + let values_ptr = sparse_b_csc.values.ptr(); + let output_ptr = output.ptr(); // Launch CUDA kernel unsafe { diff --git a/src/runtime/cuda/sparse/esc_spgemm.rs b/src/runtime/cuda/sparse/esc_spgemm.rs index ab0648a5..f81d3473 100644 --- a/src/runtime/cuda/sparse/esc_spgemm.rs +++ b/src/runtime/cuda/sparse/esc_spgemm.rs @@ -64,7 +64,7 @@ impl CudaClient { a_shape: [usize; 2], b_shape: [usize; 2], ) -> Result> { - use crate::runtime::sparse_utils::zero_tolerance; + use crate::runtime::common::sparse_utils::zero_tolerance; let [m, _k] = a_shape; let [_, n] = b_shape; @@ -109,8 +109,8 @@ impl CudaClient { self.device.index, DType::I32, DType::I64, - c_row_ptrs_i32.storage().ptr(), - output.storage().ptr(), + c_row_ptrs_i32.ptr(), + output.ptr(), m + 1, )?; output diff --git a/src/runtime/cuda/sparse/high_level_ops.rs b/src/runtime/cuda/sparse/high_level_ops.rs index 651149a9..e9de93e3 100644 --- a/src/runtime/cuda/sparse/high_level_ops.rs +++ b/src/runtime/cuda/sparse/high_level_ops.rs @@ -452,10 +452,10 @@ impl SparseOps for CudaClient { let out = Tensor::::zeros(&[n], dtype, device); - let row_ptrs_ptr = row_ptrs.storage().ptr(); - let col_indices_ptr = col_indices.storage().ptr(); - let values_ptr = values.storage().ptr(); - let out_ptr = out.storage().ptr(); + let row_ptrs_ptr = row_ptrs.ptr(); + let col_indices_ptr = col_indices.ptr(); + let values_ptr = values.ptr(); + let out_ptr = out.ptr(); match dtype { DType::F32 => unsafe { diff --git a/src/runtime/cuda/sparse/linalg/common.rs b/src/runtime/cuda/sparse/linalg/common.rs index 247ccaac..fcfbeaa7 100644 --- a/src/runtime/cuda/sparse/linalg/common.rs +++ b/src/runtime/cuda/sparse/linalg/common.rs @@ -40,8 +40,8 @@ pub fn cast_i64_to_i32_gpu( &client.context, &client.stream, client.device.index, - tensor.storage().ptr(), - output.storage().ptr(), + tensor.ptr(), + output.ptr(), n, )?; } @@ -76,10 +76,10 @@ pub fn compute_levels_lower_gpu( &client.context, &client.stream, client.device.index, - row_ptrs_i32.storage().ptr(), - col_indices_i32.storage().ptr(), - levels_gpu.storage().ptr(), - changed_gpu.storage().ptr(), + row_ptrs_i32.ptr(), + col_indices_i32.ptr(), + levels_gpu.ptr(), + changed_gpu.ptr(), n as i32, )?; } @@ -100,8 +100,8 @@ pub fn compute_levels_lower_gpu( &client.context, &client.stream, client.device.index, - levels_gpu.storage().ptr(), - max_level_gpu.storage().ptr(), + levels_gpu.ptr(), + max_level_gpu.ptr(), n as i32, )?; } @@ -117,8 +117,8 @@ pub fn compute_levels_lower_gpu( &client.context, &client.stream, client.device.index, - levels_gpu.storage().ptr(), - histogram_gpu.storage().ptr(), + levels_gpu.ptr(), + histogram_gpu.ptr(), n as i32, )?; } @@ -148,10 +148,10 @@ pub fn compute_levels_lower_gpu( &client.context, &client.stream, client.device.index, - levels_gpu.storage().ptr(), - level_ptrs_gpu.storage().ptr(), - level_rows_gpu.storage().ptr(), - level_counters_gpu.storage().ptr(), + levels_gpu.ptr(), + level_ptrs_gpu.ptr(), + level_rows_gpu.ptr(), + level_counters_gpu.ptr(), n as i32, )?; } @@ -186,10 +186,10 @@ pub fn compute_levels_upper_gpu( &client.context, &client.stream, client.device.index, - row_ptrs_i32.storage().ptr(), - col_indices_i32.storage().ptr(), - levels_gpu.storage().ptr(), - changed_gpu.storage().ptr(), + row_ptrs_i32.ptr(), + col_indices_i32.ptr(), + levels_gpu.ptr(), + changed_gpu.ptr(), n as i32, )?; } @@ -210,8 +210,8 @@ pub fn compute_levels_upper_gpu( &client.context, &client.stream, client.device.index, - levels_gpu.storage().ptr(), - max_level_gpu.storage().ptr(), + levels_gpu.ptr(), + max_level_gpu.ptr(), n as i32, )?; } @@ -227,8 +227,8 @@ pub fn compute_levels_upper_gpu( &client.context, &client.stream, client.device.index, - levels_gpu.storage().ptr(), - histogram_gpu.storage().ptr(), + levels_gpu.ptr(), + histogram_gpu.ptr(), n as i32, )?; } @@ -255,10 +255,10 @@ pub fn compute_levels_upper_gpu( &client.context, &client.stream, client.device.index, - levels_gpu.storage().ptr(), - level_ptrs_gpu.storage().ptr(), - level_rows_gpu.storage().ptr(), - level_counters_gpu.storage().ptr(), + levels_gpu.ptr(), + level_ptrs_gpu.ptr(), + level_rows_gpu.ptr(), + level_counters_gpu.ptr(), n as i32, )?; } @@ -342,11 +342,11 @@ pub fn split_lu_cuda( &client.context, &client.stream, client.device.index, - values_gpu.storage().ptr(), - l_values_t.storage().ptr(), - u_values_t.storage().ptr(), - l_map_gpu.storage().ptr(), - u_map_gpu.storage().ptr(), + values_gpu.ptr(), + l_values_t.ptr(), + u_values_t.ptr(), + l_map_gpu.ptr(), + u_map_gpu.ptr(), nnz as i32, )?; } @@ -355,11 +355,11 @@ pub fn split_lu_cuda( &client.context, &client.stream, client.device.index, - values_gpu.storage().ptr(), - l_values_t.storage().ptr(), - u_values_t.storage().ptr(), - l_map_gpu.storage().ptr(), - u_map_gpu.storage().ptr(), + values_gpu.ptr(), + l_values_t.ptr(), + u_values_t.ptr(), + l_map_gpu.ptr(), + u_map_gpu.ptr(), nnz as i32, )?; } @@ -431,9 +431,9 @@ pub fn extract_lower_cuda( &client.context, &client.stream, client.device.index, - values_gpu.storage().ptr(), - l_values_t.storage().ptr(), - lower_map_gpu.storage().ptr(), + values_gpu.ptr(), + l_values_t.ptr(), + lower_map_gpu.ptr(), nnz as i32, )?; } @@ -442,9 +442,9 @@ pub fn extract_lower_cuda( &client.context, &client.stream, client.device.index, - values_gpu.storage().ptr(), - l_values_t.storage().ptr(), - lower_map_gpu.storage().ptr(), + values_gpu.ptr(), + l_values_t.ptr(), + lower_map_gpu.ptr(), nnz as i32, )?; } diff --git a/src/runtime/cuda/sparse/linalg/ic0.rs b/src/runtime/cuda/sparse/linalg/ic0.rs index 7c05c29d..5e19f16f 100644 --- a/src/runtime/cuda/sparse/linalg/ic0.rs +++ b/src/runtime/cuda/sparse/linalg/ic0.rs @@ -43,9 +43,9 @@ pub fn ic0_cuda( &client.context, &client.stream, client.device.index, - row_ptrs_gpu.storage().ptr(), - col_indices_gpu.storage().ptr(), - diag_indices_gpu.storage().ptr(), + row_ptrs_gpu.ptr(), + col_indices_gpu.ptr(), + diag_indices_gpu.ptr(), n as i32, )?; } @@ -61,7 +61,7 @@ pub fn ic0_cuda( } let level_rows_ptr = - level_rows_gpu.storage().ptr() + (level_start * std::mem::size_of::()) as u64; + level_rows_gpu.ptr() + (level_start * std::mem::size_of::()) as u64; match dtype { DType::F32 => unsafe { @@ -71,10 +71,10 @@ pub fn ic0_cuda( client.device.index, level_rows_ptr, level_size, - row_ptrs_gpu.storage().ptr(), - col_indices_gpu.storage().ptr(), - values_gpu.storage().ptr(), - diag_indices_gpu.storage().ptr(), + row_ptrs_gpu.ptr(), + col_indices_gpu.ptr(), + values_gpu.ptr(), + diag_indices_gpu.ptr(), n as i32, options.diagonal_shift as f32, )?; @@ -86,10 +86,10 @@ pub fn ic0_cuda( client.device.index, level_rows_ptr, level_size, - row_ptrs_gpu.storage().ptr(), - col_indices_gpu.storage().ptr(), - values_gpu.storage().ptr(), - diag_indices_gpu.storage().ptr(), + row_ptrs_gpu.ptr(), + col_indices_gpu.ptr(), + values_gpu.ptr(), + diag_indices_gpu.ptr(), n as i32, options.diagonal_shift, )?; diff --git a/src/runtime/cuda/sparse/linalg/ilu0.rs b/src/runtime/cuda/sparse/linalg/ilu0.rs index 998da858..39f53c29 100644 --- a/src/runtime/cuda/sparse/linalg/ilu0.rs +++ b/src/runtime/cuda/sparse/linalg/ilu0.rs @@ -50,9 +50,9 @@ pub fn ilu0_cuda( &client.context, &client.stream, client.device.index, - row_ptrs_gpu.storage().ptr(), - col_indices_gpu.storage().ptr(), - diag_indices_gpu.storage().ptr(), + row_ptrs_gpu.ptr(), + col_indices_gpu.ptr(), + diag_indices_gpu.ptr(), n as i32, )?; } @@ -69,7 +69,7 @@ pub fn ilu0_cuda( // Get pointer to this level's rows let level_rows_ptr = - level_rows_gpu.storage().ptr() + (level_start * std::mem::size_of::()) as u64; + level_rows_gpu.ptr() + (level_start * std::mem::size_of::()) as u64; match dtype { DType::F32 => unsafe { @@ -79,10 +79,10 @@ pub fn ilu0_cuda( client.device.index, level_rows_ptr, level_size, - row_ptrs_gpu.storage().ptr(), - col_indices_gpu.storage().ptr(), - values_gpu.storage().ptr(), - diag_indices_gpu.storage().ptr(), + row_ptrs_gpu.ptr(), + col_indices_gpu.ptr(), + values_gpu.ptr(), + diag_indices_gpu.ptr(), n as i32, options.diagonal_shift as f32, )?; @@ -94,10 +94,10 @@ pub fn ilu0_cuda( client.device.index, level_rows_ptr, level_size, - row_ptrs_gpu.storage().ptr(), - col_indices_gpu.storage().ptr(), - values_gpu.storage().ptr(), - diag_indices_gpu.storage().ptr(), + row_ptrs_gpu.ptr(), + col_indices_gpu.ptr(), + values_gpu.ptr(), + diag_indices_gpu.ptr(), n as i32, options.diagonal_shift, )?; @@ -181,9 +181,9 @@ pub fn ilu0_numeric_cuda( &client.context, &client.stream, client.device.index, - row_ptrs_gpu.storage().ptr(), - col_indices_gpu.storage().ptr(), - diag_indices_gpu.storage().ptr(), + row_ptrs_gpu.ptr(), + col_indices_gpu.ptr(), + diag_indices_gpu.ptr(), n as i32, )?; } @@ -199,7 +199,7 @@ pub fn ilu0_numeric_cuda( } let level_rows_ptr = - level_rows_gpu.storage().ptr() + (level_start * std::mem::size_of::()) as u64; + level_rows_gpu.ptr() + (level_start * std::mem::size_of::()) as u64; match dtype { DType::F32 => unsafe { @@ -209,10 +209,10 @@ pub fn ilu0_numeric_cuda( client.device.index, level_rows_ptr, level_size, - row_ptrs_gpu.storage().ptr(), - col_indices_gpu.storage().ptr(), - values_gpu.storage().ptr(), - diag_indices_gpu.storage().ptr(), + row_ptrs_gpu.ptr(), + col_indices_gpu.ptr(), + values_gpu.ptr(), + diag_indices_gpu.ptr(), n as i32, options.diagonal_shift as f32, )?; @@ -224,10 +224,10 @@ pub fn ilu0_numeric_cuda( client.device.index, level_rows_ptr, level_size, - row_ptrs_gpu.storage().ptr(), - col_indices_gpu.storage().ptr(), - values_gpu.storage().ptr(), - diag_indices_gpu.storage().ptr(), + row_ptrs_gpu.ptr(), + col_indices_gpu.ptr(), + values_gpu.ptr(), + diag_indices_gpu.ptr(), n as i32, options.diagonal_shift, )?; diff --git a/src/runtime/cuda/sparse/linalg/iluk.rs b/src/runtime/cuda/sparse/linalg/iluk.rs index ae23fb53..feeaa42a 100644 --- a/src/runtime/cuda/sparse/linalg/iluk.rs +++ b/src/runtime/cuda/sparse/linalg/iluk.rs @@ -89,9 +89,9 @@ pub fn iluk_numeric_cuda( &client.context, &client.stream, client.device.index, - row_ptrs_gpu.storage().ptr(), - col_indices_gpu.storage().ptr(), - diag_indices_gpu.storage().ptr(), + row_ptrs_gpu.ptr(), + col_indices_gpu.ptr(), + diag_indices_gpu.ptr(), n as i32, )?; } @@ -107,7 +107,7 @@ pub fn iluk_numeric_cuda( } let level_rows_ptr = - level_rows_gpu.storage().ptr() + (level_start * std::mem::size_of::()) as u64; + level_rows_gpu.ptr() + (level_start * std::mem::size_of::()) as u64; match dtype { DType::F32 => unsafe { @@ -117,10 +117,10 @@ pub fn iluk_numeric_cuda( client.device.index, level_rows_ptr, level_size, - row_ptrs_gpu.storage().ptr(), - col_indices_gpu.storage().ptr(), - values_gpu.storage().ptr(), - diag_indices_gpu.storage().ptr(), + row_ptrs_gpu.ptr(), + col_indices_gpu.ptr(), + values_gpu.ptr(), + diag_indices_gpu.ptr(), n as i32, opts.diagonal_shift as f32, )?; @@ -132,10 +132,10 @@ pub fn iluk_numeric_cuda( client.device.index, level_rows_ptr, level_size, - row_ptrs_gpu.storage().ptr(), - col_indices_gpu.storage().ptr(), - values_gpu.storage().ptr(), - diag_indices_gpu.storage().ptr(), + row_ptrs_gpu.ptr(), + col_indices_gpu.ptr(), + values_gpu.ptr(), + diag_indices_gpu.ptr(), n as i32, opts.diagonal_shift, )?; diff --git a/src/runtime/cuda/sparse/linalg/triangular_solve.rs b/src/runtime/cuda/sparse/linalg/triangular_solve.rs index ac992a49..41bd49b3 100644 --- a/src/runtime/cuda/sparse/linalg/triangular_solve.rs +++ b/src/runtime/cuda/sparse/linalg/triangular_solve.rs @@ -56,7 +56,7 @@ pub fn sparse_solve_triangular_cuda( } let level_rows_ptr = - level_rows_gpu.storage().ptr() + (level_start * std::mem::size_of::()) as u64; + level_rows_gpu.ptr() + (level_start * std::mem::size_of::()) as u64; if nrhs == 1 { // Use single RHS kernels for vectors @@ -154,11 +154,11 @@ fn launch_trsv_lower( client.device.index, level_rows_ptr, level_size, - row_ptrs.storage().ptr(), - col_indices.storage().ptr(), - values.storage().ptr(), - b.storage().ptr(), - x.storage().ptr(), + row_ptrs.ptr(), + col_indices.ptr(), + values.ptr(), + b.ptr(), + x.ptr(), n as i32, unit_diagonal, )?; @@ -170,11 +170,11 @@ fn launch_trsv_lower( client.device.index, level_rows_ptr, level_size, - row_ptrs.storage().ptr(), - col_indices.storage().ptr(), - values.storage().ptr(), - b.storage().ptr(), - x.storage().ptr(), + row_ptrs.ptr(), + col_indices.ptr(), + values.ptr(), + b.ptr(), + x.ptr(), n as i32, unit_diagonal, )?; @@ -206,11 +206,11 @@ fn launch_trsv_upper( client.device.index, level_rows_ptr, level_size, - row_ptrs.storage().ptr(), - col_indices.storage().ptr(), - values.storage().ptr(), - b.storage().ptr(), - x.storage().ptr(), + row_ptrs.ptr(), + col_indices.ptr(), + values.ptr(), + b.ptr(), + x.ptr(), n as i32, )?; }, @@ -221,11 +221,11 @@ fn launch_trsv_upper( client.device.index, level_rows_ptr, level_size, - row_ptrs.storage().ptr(), - col_indices.storage().ptr(), - values.storage().ptr(), - b.storage().ptr(), - x.storage().ptr(), + row_ptrs.ptr(), + col_indices.ptr(), + values.ptr(), + b.ptr(), + x.ptr(), n as i32, )?; }, @@ -259,11 +259,11 @@ fn launch_trsv_lower_multi_rhs( level_rows_ptr, level_size, nrhs as i32, - row_ptrs.storage().ptr(), - col_indices.storage().ptr(), - values.storage().ptr(), - b.storage().ptr(), - x.storage().ptr(), + row_ptrs.ptr(), + col_indices.ptr(), + values.ptr(), + b.ptr(), + x.ptr(), n as i32, unit_diagonal, )?; @@ -276,11 +276,11 @@ fn launch_trsv_lower_multi_rhs( level_rows_ptr, level_size, nrhs as i32, - row_ptrs.storage().ptr(), - col_indices.storage().ptr(), - values.storage().ptr(), - b.storage().ptr(), - x.storage().ptr(), + row_ptrs.ptr(), + col_indices.ptr(), + values.ptr(), + b.ptr(), + x.ptr(), n as i32, unit_diagonal, )?; @@ -314,11 +314,11 @@ fn launch_trsv_upper_multi_rhs( level_rows_ptr, level_size, nrhs as i32, - row_ptrs.storage().ptr(), - col_indices.storage().ptr(), - values.storage().ptr(), - b.storage().ptr(), - x.storage().ptr(), + row_ptrs.ptr(), + col_indices.ptr(), + values.ptr(), + b.ptr(), + x.ptr(), n as i32, )?; }, @@ -330,11 +330,11 @@ fn launch_trsv_upper_multi_rhs( level_rows_ptr, level_size, nrhs as i32, - row_ptrs.storage().ptr(), - col_indices.storage().ptr(), - values.storage().ptr(), - b.storage().ptr(), - x.storage().ptr(), + row_ptrs.ptr(), + col_indices.ptr(), + values.ptr(), + b.ptr(), + x.ptr(), n as i32, )?; }, diff --git a/src/runtime/cuda/sparse/spmv.rs b/src/runtime/cuda/sparse/spmv.rs index e661c4b6..64d76d08 100644 --- a/src/runtime/cuda/sparse/spmv.rs +++ b/src/runtime/cuda/sparse/spmv.rs @@ -33,11 +33,11 @@ impl CudaClient { let y = Tensor::::zeros(&[nrows], dtype, device); // Get device pointers (no data transfer!) - let row_ptrs_ptr = row_ptrs.storage().ptr(); - let col_indices_ptr = col_indices.storage().ptr(); - let values_ptr = values.storage().ptr(); - let x_ptr = x.storage().ptr(); - let y_ptr = y.storage().ptr(); + let row_ptrs_ptr = row_ptrs.ptr(); + let col_indices_ptr = col_indices.ptr(); + let values_ptr = values.ptr(); + let x_ptr = x.ptr(); + let y_ptr = y.ptr(); // Choose optimal kernel based on sparsity let nnz = values.numel(); @@ -206,11 +206,11 @@ impl CudaClient { let c = Tensor::::zeros(&[m, n], dtype, device); // Get device pointers (no data transfer!) - let row_ptrs_ptr = row_ptrs.storage().ptr(); - let col_indices_ptr = col_indices.storage().ptr(); - let values_ptr = values.storage().ptr(); - let b_ptr = b.storage().ptr(); - let c_ptr = c.storage().ptr(); + let row_ptrs_ptr = row_ptrs.ptr(); + let col_indices_ptr = col_indices.ptr(); + let values_ptr = values.ptr(); + let b_ptr = b.ptr(); + let c_ptr = c.ptr(); // Dispatch based on dtype (only F32/F64/F16/BF16 supported on CUDA) use crate::dtype::DType; diff --git a/src/runtime/cuda/special.rs b/src/runtime/cuda/special.rs index 25ec016e..a0c16fdc 100644 --- a/src/runtime/cuda/special.rs +++ b/src/runtime/cuda/special.rs @@ -26,8 +26,8 @@ impl SpecialFunctions for CudaClient { self.stream(), device.index, x.dtype(), - x.storage().ptr(), - out.storage().ptr(), + x.ptr(), + out.ptr(), x.numel(), )?; } @@ -46,8 +46,8 @@ impl SpecialFunctions for CudaClient { self.stream(), device.index, x.dtype(), - x.storage().ptr(), - out.storage().ptr(), + x.ptr(), + out.ptr(), x.numel(), )?; } @@ -66,8 +66,8 @@ impl SpecialFunctions for CudaClient { self.stream(), device.index, x.dtype(), - x.storage().ptr(), - out.storage().ptr(), + x.ptr(), + out.ptr(), x.numel(), )?; } @@ -86,8 +86,8 @@ impl SpecialFunctions for CudaClient { self.stream(), device.index, x.dtype(), - x.storage().ptr(), - out.storage().ptr(), + x.ptr(), + out.ptr(), x.numel(), )?; } @@ -106,8 +106,8 @@ impl SpecialFunctions for CudaClient { self.stream(), device.index, x.dtype(), - x.storage().ptr(), - out.storage().ptr(), + x.ptr(), + out.ptr(), x.numel(), )?; } @@ -126,8 +126,8 @@ impl SpecialFunctions for CudaClient { self.stream(), device.index, x.dtype(), - x.storage().ptr(), - out.storage().ptr(), + x.ptr(), + out.ptr(), x.numel(), )?; } @@ -163,9 +163,9 @@ impl SpecialFunctions for CudaClient { self.stream(), device.index, a.dtype(), - a.storage().ptr(), - b.storage().ptr(), - out.storage().ptr(), + a.ptr(), + b.ptr(), + out.ptr(), a.numel(), )?; } @@ -202,10 +202,10 @@ impl SpecialFunctions for CudaClient { self.stream(), device.index, a.dtype(), - a.storage().ptr(), - b.storage().ptr(), - x.storage().ptr(), - out.storage().ptr(), + a.ptr(), + b.ptr(), + x.ptr(), + out.ptr(), a.numel(), )?; } @@ -241,9 +241,9 @@ impl SpecialFunctions for CudaClient { self.stream(), device.index, a.dtype(), - a.storage().ptr(), - x.storage().ptr(), - out.storage().ptr(), + a.ptr(), + x.ptr(), + out.ptr(), a.numel(), )?; } @@ -279,9 +279,9 @@ impl SpecialFunctions for CudaClient { self.stream(), device.index, a.dtype(), - a.storage().ptr(), - x.storage().ptr(), - out.storage().ptr(), + a.ptr(), + x.ptr(), + out.ptr(), a.numel(), )?; } @@ -317,9 +317,9 @@ impl SpecialFunctions for CudaClient { self.stream(), device.index, a.dtype(), - a.storage().ptr(), - p.storage().ptr(), - out.storage().ptr(), + a.ptr(), + p.ptr(), + out.ptr(), a.numel(), )?; } @@ -356,10 +356,10 @@ impl SpecialFunctions for CudaClient { self.stream(), device.index, a.dtype(), - a.storage().ptr(), - b.storage().ptr(), - p.storage().ptr(), - out.storage().ptr(), + a.ptr(), + b.ptr(), + p.ptr(), + out.ptr(), a.numel(), )?; } @@ -378,8 +378,8 @@ impl SpecialFunctions for CudaClient { self.stream(), device.index, x.dtype(), - x.storage().ptr(), - out.storage().ptr(), + x.ptr(), + out.ptr(), x.numel(), )?; } @@ -398,8 +398,8 @@ impl SpecialFunctions for CudaClient { self.stream(), device.index, x.dtype(), - x.storage().ptr(), - out.storage().ptr(), + x.ptr(), + out.ptr(), x.numel(), )?; } @@ -418,8 +418,8 @@ impl SpecialFunctions for CudaClient { self.stream(), device.index, x.dtype(), - x.storage().ptr(), - out.storage().ptr(), + x.ptr(), + out.ptr(), x.numel(), )?; } @@ -438,8 +438,8 @@ impl SpecialFunctions for CudaClient { self.stream(), device.index, x.dtype(), - x.storage().ptr(), - out.storage().ptr(), + x.ptr(), + out.ptr(), x.numel(), )?; } @@ -458,8 +458,8 @@ impl SpecialFunctions for CudaClient { self.stream(), device.index, x.dtype(), - x.storage().ptr(), - out.storage().ptr(), + x.ptr(), + out.ptr(), x.numel(), )?; } @@ -478,8 +478,8 @@ impl SpecialFunctions for CudaClient { self.stream(), device.index, x.dtype(), - x.storage().ptr(), - out.storage().ptr(), + x.ptr(), + out.ptr(), x.numel(), )?; } @@ -498,8 +498,8 @@ impl SpecialFunctions for CudaClient { self.stream(), device.index, x.dtype(), - x.storage().ptr(), - out.storage().ptr(), + x.ptr(), + out.ptr(), x.numel(), )?; } @@ -518,8 +518,8 @@ impl SpecialFunctions for CudaClient { self.stream(), device.index, x.dtype(), - x.storage().ptr(), - out.storage().ptr(), + x.ptr(), + out.ptr(), x.numel(), )?; } @@ -542,8 +542,8 @@ impl SpecialFunctions for CudaClient { self.stream(), device.index, m.dtype(), - m.storage().ptr(), - out.storage().ptr(), + m.ptr(), + out.ptr(), m.numel(), )?; } @@ -562,8 +562,8 @@ impl SpecialFunctions for CudaClient { self.stream(), device.index, m.dtype(), - m.storage().ptr(), - out.storage().ptr(), + m.ptr(), + out.ptr(), m.numel(), )?; } @@ -591,8 +591,8 @@ impl SpecialFunctions for CudaClient { a, b, c, - z.storage().ptr(), - out.storage().ptr(), + z.ptr(), + out.ptr(), z.numel(), )?; } @@ -613,8 +613,8 @@ impl SpecialFunctions for CudaClient { z.dtype(), a, b, - z.storage().ptr(), - out.storage().ptr(), + z.ptr(), + out.ptr(), z.numel(), )?; } @@ -633,8 +633,8 @@ impl SpecialFunctions for CudaClient { self.stream(), device.index, x.dtype(), - x.storage().ptr(), - out.storage().ptr(), + x.ptr(), + out.ptr(), x.numel(), )?; } @@ -653,8 +653,8 @@ impl SpecialFunctions for CudaClient { self.stream(), device.index, x.dtype(), - x.storage().ptr(), - out.storage().ptr(), + x.ptr(), + out.ptr(), x.numel(), )?; } @@ -674,8 +674,8 @@ impl SpecialFunctions for CudaClient { device.index, x.dtype(), n, - x.storage().ptr(), - out.storage().ptr(), + x.ptr(), + out.ptr(), x.numel(), )?; } @@ -701,8 +701,8 @@ impl SpecialFunctions for CudaClient { x.dtype(), n, m, - x.storage().ptr(), - out.storage().ptr(), + x.ptr(), + out.ptr(), x.numel(), )?; } @@ -742,9 +742,9 @@ impl SpecialFunctions for CudaClient { theta.dtype(), n, m, - theta.storage().ptr(), - phi.storage().ptr(), - out.storage().ptr(), + theta.ptr(), + phi.ptr(), + out.ptr(), theta.numel(), )?; } @@ -763,8 +763,8 @@ impl SpecialFunctions for CudaClient { self.stream(), device.index, x.dtype(), - x.storage().ptr(), - out.storage().ptr(), + x.ptr(), + out.ptr(), x.numel(), )?; } @@ -783,8 +783,8 @@ impl SpecialFunctions for CudaClient { self.stream(), device.index, x.dtype(), - x.storage().ptr(), - out.storage().ptr(), + x.ptr(), + out.ptr(), x.numel(), )?; } diff --git a/src/runtime/fallback.rs b/src/runtime/fallback.rs index ef022adf..78ac2dc1 100644 --- a/src/runtime/fallback.rs +++ b/src/runtime/fallback.rs @@ -150,7 +150,7 @@ impl CpuFallbackContext { /// /// This copies the tensor data from GPU memory to CPU memory. #[inline] - pub fn tensor_from_gpu( + pub fn tensor_from_gpu>( &self, tensor: &Tensor, ) -> Tensor { @@ -171,7 +171,10 @@ impl Default for CpuFallbackContext { /// Validate that two tensors have matching dtypes for binary operations. #[inline] -pub fn validate_binary_dtypes(a: &Tensor, b: &Tensor) -> Result { +pub fn validate_binary_dtypes>( + a: &Tensor, + b: &Tensor, +) -> Result { if a.dtype() != b.dtype() { return Err(Error::DTypeMismatch { lhs: a.dtype(), @@ -183,7 +186,10 @@ pub fn validate_binary_dtypes(a: &Tensor, b: &Tensor) -> Resul /// Compute broadcast shape for binary operations. #[inline] -pub fn compute_broadcast_shape(a: &Tensor, b: &Tensor) -> Result> { +pub fn compute_broadcast_shape>( + a: &Tensor, + b: &Tensor, +) -> Result> { broadcast_shape(a.shape(), b.shape()).ok_or_else(|| Error::BroadcastError { lhs: a.shape().to_vec(), rhs: b.shape().to_vec(), @@ -216,7 +222,7 @@ pub fn binary_op_fallback( op_name: &'static str, ) -> Result> where - R: Runtime, + R: Runtime, D: Device + Clone, { let dtype = validate_binary_dtypes(a, b)?; @@ -253,7 +259,7 @@ pub fn unary_op_fallback( op_name: &'static str, ) -> Result> where - R: Runtime, + R: Runtime, D: Device + Clone, { let dtype = a.dtype(); @@ -312,7 +318,7 @@ pub fn scalar_op_fallback( op_name: &'static str, ) -> Result> where - R: Runtime, + R: Runtime, D: Device + Clone, { let dtype = a.dtype(); @@ -347,7 +353,7 @@ pub fn reduce_op_fallback( op_name: &'static str, ) -> Result> where - R: Runtime, + R: Runtime, D: Device + Clone, { let dtype = a.dtype(); @@ -383,7 +389,7 @@ pub fn activation_fallback( op_fn: F, ) -> Result> where - R: Runtime, + R: Runtime, D: Device + Clone, F: Fn(&cpu::CpuClient, &Tensor) -> Result>, { @@ -408,7 +414,7 @@ pub fn softmax_fallback( op_name: &'static str, ) -> Result> where - R: Runtime, + R: Runtime, D: Device + Clone, { let dtype = a.dtype(); @@ -433,7 +439,7 @@ pub fn matmul_fallback( op_name: &'static str, ) -> Result> where - R: Runtime, + R: Runtime, D: Device + Clone, { let dtype = validate_binary_dtypes(a, b)?; @@ -461,7 +467,7 @@ pub fn compare_op_fallback( op_name: &'static str, ) -> Result> where - R: Runtime, + R: Runtime, D: Device + Clone, { let dtype = validate_binary_dtypes(a, b)?; @@ -492,7 +498,7 @@ where /// /// Returns the broadcasted shape of all three tensors. #[inline] -pub fn compute_ternary_broadcast_shape( +pub fn compute_ternary_broadcast_shape>( cond: &Tensor, x: &Tensor, y: &Tensor, @@ -529,7 +535,7 @@ pub fn where_cond_fallback( op_name: &'static str, ) -> Result> where - R: Runtime, + R: Runtime, D: Device + Clone, { // Validate dtypes (x and y must match, cond can be any dtype - non-zero = true) @@ -566,7 +572,7 @@ where #[cfg(feature = "sparse")] /// CSC element-wise operation fallback (GPU → CPU → GPU) #[allow(private_interfaces)] -pub fn csc_elementwise_fallback( +pub fn csc_elementwise_fallback, F, FA, FB>( a_col_ptrs: &Tensor, a_row_indices: &Tensor, a_values: &Tensor, @@ -631,7 +637,7 @@ where #[cfg(feature = "sparse")] /// COO element-wise operation fallback (GPU → CPU → GPU) #[allow(private_interfaces)] -pub fn coo_elementwise_fallback( +pub fn coo_elementwise_fallback, F, FA, FB>( a_row_indices: &Tensor, a_col_indices: &Tensor, a_values: &Tensor, diff --git a/src/runtime/mod.rs b/src/runtime/mod.rs index bb7928a6..83d3dc25 100644 --- a/src/runtime/mod.rs +++ b/src/runtime/mod.rs @@ -13,12 +13,8 @@ //! └── RawHandle (escape hatch for custom kernels) //! ``` -mod allocator; -pub(crate) mod helpers; -pub(crate) mod shape_ops; -#[cfg(feature = "sparse")] -pub(crate) mod sparse_utils; -pub(crate) mod statistics_common; +pub(crate) mod common; +mod communicator; pub mod traits; pub mod cpu; @@ -29,38 +25,32 @@ pub mod cuda; #[cfg(feature = "wgpu")] pub mod wgpu; -// CPU fallback utilities for GPU backends +// CPU fallback utilities for GPU backends (not common - GPU-specific) #[cfg(any(feature = "cuda", feature = "wgpu"))] pub(crate) mod fallback; +// Common re-exports #[cfg(any(feature = "cuda", feature = "wgpu"))] -pub(crate) use allocator::AllocGuard; -pub use allocator::Allocator; -pub(crate) use allocator::DefaultAllocator; -pub(crate) use helpers::{ +pub(crate) use common::AllocGuard; +pub(crate) use common::DefaultAllocator; +#[cfg(any(feature = "cuda", feature = "wgpu"))] +pub(crate) use common::compute_contiguous_strides; +pub use common::{AllocationStats, Allocator, Graph, NoOpGraph, TrackingAllocator}; +pub(crate) use common::{ compute_broadcast_shape, ensure_contiguous, normalize_dim, validate_arange, validate_binary_dtypes, validate_eye, }; -pub use traits::{Device, Runtime, RuntimeClient}; -// ============================================================================ -// Shared Helpers -// ============================================================================ +// Communicator re-exports +#[cfg(feature = "distributed-gpu")] +pub use communicator::HierarchicalCommunicator; +#[cfg(feature = "distributed")] +pub use communicator::NexarNetCommunicator; +pub use communicator::{ + Communicator, CommunicatorGroup, NoOpCommunicator, ParallelDim, ReduceOp, StreamSyncOps, +}; +#[cfg(feature = "nccl")] +pub use cuda::NcclCommunicator; -#[cfg(any(feature = "cuda", feature = "wgpu"))] -/// Compute contiguous (row-major) strides for a given shape. -/// -/// For a shape `[d0, d1, d2, ...]`, the strides are computed as: -/// - `strides[i] = product of dims[i+1..]` -/// - Last dimension always has stride 1 -#[inline] -pub(crate) fn compute_contiguous_strides(shape: &[usize]) -> Vec { - if shape.is_empty() { - return Vec::new(); - } - let mut strides = vec![1usize; shape.len()]; - for i in (0..shape.len().saturating_sub(1)).rev() { - strides[i] = strides[i + 1] * shape[i + 1]; - } - strides -} +// Trait re-exports +pub use traits::{Device, Runtime, RuntimeClient}; diff --git a/src/runtime/traits/client.rs b/src/runtime/traits/client.rs index c956f18d..654938fa 100644 --- a/src/runtime/traits/client.rs +++ b/src/runtime/traits/client.rs @@ -12,4 +12,12 @@ pub trait RuntimeClient: Clone + Send + Sync { /// Get the allocator for this client fn allocator(&self) -> &R::Allocator; + + /// Get the raw CUDA stream handle for compute-communication overlap. + /// + /// Returns `Some(handle)` on CUDA backends where the handle is the + /// `CUstream` pointer cast to `u64`. Returns `None` on CPU/WebGPU. + fn compute_stream_handle(&self) -> Option { + None + } } diff --git a/src/runtime/traits/runtime.rs b/src/runtime/traits/runtime.rs index 29e17d93..a9ebc79b 100644 --- a/src/runtime/traits/runtime.rs +++ b/src/runtime/traits/runtime.rs @@ -10,6 +10,7 @@ /// - `Device`: Identifies a specific compute unit (e.g., GPU 0, GPU 1) /// - `Client`: Handles operation dispatch and synchronization /// - `Allocator`: Memory management with optional freeze support +/// - `Graph`: Captured computation sequence for replay (CUDA Graphs, etc.) /// - `RawHandle`: Escape hatch for custom kernel launching /// /// # Example @@ -30,6 +31,12 @@ pub trait Runtime: Clone + Send + Sync + 'static { /// Memory allocator type type Allocator: crate::runtime::Allocator; + /// Captured computation graph for replay + /// + /// For CPU/WebGPU: `NoOpGraph` (operations execute eagerly, launch is no-op) + /// For CUDA: `CudaGraph` wrapping cudarc's graph types + type Graph: crate::runtime::Graph; + /// Raw handle for custom kernel launching (escape hatch) /// /// For CPU: `()` (no raw handle needed) @@ -37,14 +44,35 @@ pub trait Runtime: Clone + Send + Sync + 'static { /// For WGPU: Access to wgpu::Device/Queue type RawHandle: Send + Sync; + /// Data type enum for tensor elements. + /// + /// numr runtimes use `numr::DType`. Downstream runtimes (e.g. boostr) + /// can specify their own dtype enum with quantized variants. + type DType: crate::dtype::DataType; + /// Human-readable name of this runtime fn name() -> &'static str; /// Does this backend support graph capture (e.g., CUDA Graphs)? + /// + /// Check this BEFORE calling `capture_graph` to avoid unnecessary + /// eager execution on non-capture backends. fn supports_graph_capture() -> bool { false } + /// Capture a sequence of operations as a replayable graph. + /// + /// The closure receives the client so operations are issued on the correct + /// stream/queue. On capture-capable backends (CUDA), ops submitted inside + /// the closure are recorded into a graph. On non-capture backends (CPU, WebGPU), + /// the closure executes eagerly and returns `NoOpGraph`. + /// + /// Returns `(graph, closure_result)`. + fn capture_graph(client: &Self::Client, f: F) -> crate::error::Result<(Self::Graph, T)> + where + F: FnOnce(&Self::Client) -> crate::error::Result; + /// Allocate device memory /// /// Returns a device pointer (u64) that can be used for operations. @@ -102,6 +130,29 @@ pub trait Runtime: Clone + Send + Sync + 'static { device: &Self::Device, ) -> crate::error::Result<()>; + /// Record an event on the compute stream. Returns an opaque handle. + /// On non-CUDA backends, returns 0 (no-op). + fn record_compute_event(_device: &Self::Device) -> crate::error::Result { + Ok(0) + } + + /// Copy data from device to host using a dedicated copy stream, + /// synchronized via a previously recorded event. + /// + /// On CUDA: copy stream waits on the event, performs D2H, syncs only copy stream. + /// The compute stream continues running concurrently. + /// + /// Default: ignores event, falls back to `copy_from_device`. + fn copy_from_device_pipelined( + src: u64, + dst: &mut [u8], + device: &Self::Device, + event: u64, + ) -> crate::error::Result<()> { + let _ = event; + Self::copy_from_device(src, dst, device) + } + /// Get the default device fn default_device() -> Self::Device; diff --git a/src/runtime/wgpu/client.rs b/src/runtime/wgpu/client.rs index 6582d493..e3bc0f62 100644 --- a/src/runtime/wgpu/client.rs +++ b/src/runtime/wgpu/client.rs @@ -303,7 +303,7 @@ fn get_buffer_registry() -> &'static parking_lot::Mutex } /// Get a buffer by its ID. -pub(crate) fn get_buffer(id: u64) -> Option> { +pub fn get_buffer(id: u64) -> Option> { if id == 0 { return None; } diff --git a/src/runtime/wgpu/fft.rs b/src/runtime/wgpu/fft.rs index 4d516941..8f662970 100644 --- a/src/runtime/wgpu/fft.rs +++ b/src/runtime/wgpu/fft.rs @@ -14,7 +14,7 @@ use super::client::get_buffer; use super::shaders::fft as kernels; -use super::shaders::generator::MAX_WORKGROUP_FFT_SIZE; +const MAX_WORKGROUP_FFT_SIZE: usize = 256; use super::{WgpuClient, WgpuRuntime}; use crate::algorithm::fft::{ FftAlgorithms, FftDirection, FftNormalization, complex_dtype_for_real, real_dtype_for_complex, @@ -117,7 +117,7 @@ impl FftAlgorithms for WgpuClient { let output_ptr = output_guard.ptr(); let output_buffer = get_buffer_or_err!(output_ptr, "FFT output"); - let input_buffer = get_buffer_or_err!(input_contig.storage().ptr(), "FFT input"); + let input_buffer = get_buffer_or_err!(input_contig.ptr(), "FFT input"); // If FFT is on last dimension and data is contiguous, we can do batched FFT directly if dim_usize == ndim - 1 { @@ -155,12 +155,7 @@ impl FftAlgorithms for WgpuClient { let temp_buffer = get_buffer_or_err!(temp_ptr, "FFT temp"); // Copy input to temp buffer initially - WgpuRuntime::copy_within_device( - input_contig.storage().ptr(), - temp_ptr, - output_size, - device, - )?; + WgpuRuntime::copy_within_device(input_contig.ptr(), temp_ptr, output_size, device)?; // Run stages let mut use_temp_as_input = true; @@ -306,7 +301,7 @@ impl FftAlgorithms for WgpuClient { let complex_ptr = complex_guard.ptr(); let complex_buffer = get_buffer_or_err!(complex_ptr, "rfft complex"); - let input_buffer = get_buffer_or_err!(input_contig.storage().ptr(), "rfft input"); + let input_buffer = get_buffer_or_err!(input_contig.ptr(), "rfft input"); let pack_params: [u32; 4] = [n as u32, batch_size as u32, 0, 0]; let params_buffer = self.create_uniform_buffer("rfft_params", 16); @@ -342,7 +337,7 @@ impl FftAlgorithms for WgpuClient { let output_ptr = output_guard.ptr(); let output_buffer = get_buffer_or_err!(output_ptr, "rfft output"); - let fft_buffer = get_buffer_or_err!(fft_result.storage().ptr(), "rfft fft result"); + let fft_buffer = get_buffer_or_err!(fft_result.ptr(), "rfft fft result"); let truncate_params: [u32; 4] = [n as u32, out_n as u32, batch_size as u32, 0]; self.write_buffer(¶ms_buffer, &truncate_params); @@ -420,7 +415,7 @@ impl FftAlgorithms for WgpuClient { let extended_ptr = extended_guard.ptr(); let extended_buffer = get_buffer_or_err!(extended_ptr, "irfft extended"); - let input_buffer = get_buffer_or_err!(input_contig.storage().ptr(), "irfft input"); + let input_buffer = get_buffer_or_err!(input_contig.ptr(), "irfft input"); let extend_params: [u32; 4] = [full_n as u32, half_n as u32, batch_size as u32, 0]; let params_buffer = self.create_uniform_buffer("irfft_params", 16); @@ -458,7 +453,7 @@ impl FftAlgorithms for WgpuClient { let output_ptr = output_guard.ptr(); let output_buffer = get_buffer_or_err!(output_ptr, "irfft output"); - let ifft_buffer = get_buffer_or_err!(ifft_result.storage().ptr(), "irfft ifft result"); + let ifft_buffer = get_buffer_or_err!(ifft_result.ptr(), "irfft ifft result"); let unpack_params: [u32; 4] = [full_n as u32, batch_size as u32, 0, 0]; self.write_buffer(¶ms_buffer, &unpack_params); @@ -551,7 +546,7 @@ impl FftAlgorithms for WgpuClient { let output_ptr = output_guard.ptr(); let output_buffer = get_buffer_or_err!(output_ptr, "fftshift output"); - let input_buffer = get_buffer_or_err!(input_contig.storage().ptr(), "fftshift input"); + let input_buffer = get_buffer_or_err!(input_contig.ptr(), "fftshift input"); let params: [u32; 4] = [n as u32, batch_size as u32, 0, 0]; let params_buffer = self.create_uniform_buffer("fftshift_params", 16); @@ -604,7 +599,7 @@ impl FftAlgorithms for WgpuClient { let output_ptr = output_guard.ptr(); let output_buffer = get_buffer_or_err!(output_ptr, "ifftshift output"); - let input_buffer = get_buffer_or_err!(input_contig.storage().ptr(), "ifftshift input"); + let input_buffer = get_buffer_or_err!(input_contig.ptr(), "ifftshift input"); let params: [u32; 4] = [n as u32, batch_size as u32, 0, 0]; let params_buffer = self.create_uniform_buffer("ifftshift_params", 16); diff --git a/src/runtime/wgpu/linalg/advanced_decompositions.rs b/src/runtime/wgpu/linalg/advanced_decompositions.rs index 1b9ddec7..e10850e3 100644 --- a/src/runtime/wgpu/linalg/advanced_decompositions.rs +++ b/src/runtime/wgpu/linalg/advanced_decompositions.rs @@ -50,10 +50,10 @@ pub fn rsf2csf( let elem = dtype.size_in_bytes(); let t_real_guard = AllocGuard::new(client.allocator(), elem)?; let t_real_ptr = t_real_guard.ptr(); - WgpuRuntime::copy_within_device(schur.t.storage().ptr(), t_real_ptr, elem, device)?; + WgpuRuntime::copy_within_device(schur.t.ptr(), t_real_ptr, elem, device)?; let z_real_guard = AllocGuard::new(client.allocator(), elem)?; let z_real_ptr = z_real_guard.ptr(); - WgpuRuntime::copy_within_device(schur.z.storage().ptr(), z_real_ptr, elem, device)?; + WgpuRuntime::copy_within_device(schur.z.ptr(), z_real_ptr, elem, device)?; return Ok(ComplexSchurDecomposition { z_real: unsafe { WgpuClient::tensor_from_raw(z_real_guard.release(), &[1, 1], dtype, device) @@ -87,8 +87,8 @@ pub fn rsf2csf( let z_imag_buffer = get_buffer_or_err!(z_imag_ptr, "Z_imag"); // Copy input T and Z to real buffers - WgpuRuntime::copy_within_device(schur.t.storage().ptr(), t_real_ptr, matrix_size, device)?; - WgpuRuntime::copy_within_device(schur.z.storage().ptr(), z_real_ptr, matrix_size, device)?; + WgpuRuntime::copy_within_device(schur.t.ptr(), t_real_ptr, matrix_size, device)?; + WgpuRuntime::copy_within_device(schur.z.ptr(), z_real_ptr, matrix_size, device)?; // Zero-initialize imaginary buffers let zeros = vec![0.0f32; n * n]; @@ -179,10 +179,10 @@ pub fn qz_decompose( let elem = dtype.size_in_bytes(); let s_guard = AllocGuard::new(client.allocator(), elem)?; let s_ptr = s_guard.ptr(); - WgpuRuntime::copy_within_device(a.storage().ptr(), s_ptr, elem, device)?; + WgpuRuntime::copy_within_device(a.ptr(), s_ptr, elem, device)?; let t_guard = AllocGuard::new(client.allocator(), elem)?; let t_ptr = t_guard.ptr(); - WgpuRuntime::copy_within_device(b.storage().ptr(), t_ptr, elem, device)?; + WgpuRuntime::copy_within_device(b.ptr(), t_ptr, elem, device)?; let s_tensor = unsafe { WgpuClient::tensor_from_raw(s_guard.release(), &[1], dtype, device) }; let t_tensor = @@ -240,8 +240,8 @@ pub fn qz_decompose( let converged_flag_buffer = get_buffer_or_err!(converged_flag_ptr, "QZ convergence flag"); // Copy input matrices - WgpuRuntime::copy_within_device(a.storage().ptr(), s_ptr, matrix_size, device)?; - WgpuRuntime::copy_within_device(b.storage().ptr(), t_ptr, matrix_size, device)?; + WgpuRuntime::copy_within_device(a.ptr(), s_ptr, matrix_size, device)?; + WgpuRuntime::copy_within_device(b.ptr(), t_ptr, matrix_size, device)?; // Zero-initialize converged flag let zero_i32: [i32; 1] = [0]; diff --git a/src/runtime/wgpu/linalg/banded.rs b/src/runtime/wgpu/linalg/banded.rs index 3d3afff8..297b7482 100644 --- a/src/runtime/wgpu/linalg/banded.rs +++ b/src/runtime/wgpu/linalg/banded.rs @@ -96,9 +96,9 @@ pub fn solve_banded_impl( let ab_contig = ab.contiguous(); let b_contig = b.contiguous(); - let ab_buffer = get_buffer(ab_contig.storage().ptr()) + let ab_buffer = get_buffer(ab_contig.ptr()) .ok_or_else(|| Error::Internal("Failed to get ab buffer".to_string()))?; - let b_buffer = get_buffer(b_contig.storage().ptr()) + let b_buffer = get_buffer(b_contig.ptr()) .ok_or_else(|| Error::Internal("Failed to get b buffer".to_string()))?; // Allocate output buffer for all RHS columns stored contiguously diff --git a/src/runtime/wgpu/linalg/decompositions.rs b/src/runtime/wgpu/linalg/decompositions.rs index d8b9d519..169768b0 100644 --- a/src/runtime/wgpu/linalg/decompositions.rs +++ b/src/runtime/wgpu/linalg/decompositions.rs @@ -56,7 +56,7 @@ pub fn lu_decompose( .ok_or_else(|| Error::Internal("Failed to get singular_flag buffer".to_string()))?; // Copy input to LU buffer - WgpuRuntime::copy_within_device(a.storage().ptr(), lu_ptr, lu_size, device)?; + WgpuRuntime::copy_within_device(a.ptr(), lu_ptr, lu_size, device)?; // Create params buffer let params: [u32; 2] = [m as u32, n as u32]; @@ -156,7 +156,7 @@ pub fn cholesky_decompose( .ok_or_else(|| Error::Internal("Failed to get not_pd_flag buffer".to_string()))?; // Copy input to L buffer - WgpuRuntime::copy_within_device(a.storage().ptr(), l_ptr, l_size, device)?; + WgpuRuntime::copy_within_device(a.ptr(), l_ptr, l_size, device)?; // Create params buffer let params: [u32; 1] = [n as u32]; @@ -250,7 +250,7 @@ pub fn qr_decompose_internal( .ok_or_else(|| Error::Internal("Failed to get workspace buffer".to_string()))?; // Copy A to R (will be modified in place) - WgpuRuntime::copy_within_device(a.storage().ptr(), r_ptr, r_size, device)?; + WgpuRuntime::copy_within_device(a.ptr(), r_ptr, r_size, device)?; // Create params buffer let params: [u32; 3] = [m as u32, n as u32, if thin { 1 } else { 0 }]; diff --git a/src/runtime/wgpu/linalg/eig_general.rs b/src/runtime/wgpu/linalg/eig_general.rs index 5c1090f2..8db2c52f 100644 --- a/src/runtime/wgpu/linalg/eig_general.rs +++ b/src/runtime/wgpu/linalg/eig_general.rs @@ -43,7 +43,7 @@ pub fn eig_decompose( let elem = dtype.size_in_bytes(); let eval_guard = AllocGuard::new(client.allocator(), elem)?; let eval_ptr = eval_guard.ptr(); - WgpuRuntime::copy_within_device(a.storage().ptr(), eval_ptr, elem, device)?; + WgpuRuntime::copy_within_device(a.ptr(), eval_ptr, elem, device)?; let eigenvalues_real = unsafe { WgpuClient::tensor_from_raw(eval_guard.release(), &[1], dtype, device) }; return Ok(GeneralEigenDecomposition { @@ -89,7 +89,7 @@ pub fn eig_decompose( get_buffer_or_err!(converged_flag_ptr, "eig_general convergence flag"); // Copy input to T buffer - WgpuRuntime::copy_within_device(a.storage().ptr(), t_ptr, matrix_size, device)?; + WgpuRuntime::copy_within_device(a.ptr(), t_ptr, matrix_size, device)?; // Zero-initialize converged flag let zero_i32: [i32; 1] = [0]; diff --git a/src/runtime/wgpu/linalg/eig_symmetric.rs b/src/runtime/wgpu/linalg/eig_symmetric.rs index c87cfb35..e9cb3f8a 100644 --- a/src/runtime/wgpu/linalg/eig_symmetric.rs +++ b/src/runtime/wgpu/linalg/eig_symmetric.rs @@ -41,7 +41,7 @@ pub fn eig_decompose_symmetric( let elem = dtype.size_in_bytes(); let eval_guard = AllocGuard::new(client.allocator(), elem)?; let eval_ptr = eval_guard.ptr(); - WgpuRuntime::copy_within_device(a.storage().ptr(), eval_ptr, elem, device)?; + WgpuRuntime::copy_within_device(a.ptr(), eval_ptr, elem, device)?; let eigenvalues = unsafe { WgpuClient::tensor_from_raw(eval_guard.release(), &[1], dtype, device) }; let eigenvectors = Tensor::::from_slice(&[1.0f32], &[1, 1], device); @@ -74,7 +74,7 @@ pub fn eig_decompose_symmetric( get_buffer_or_err!(converged_flag_ptr, "eigendecomposition convergence flag"); // Copy input to work buffer - WgpuRuntime::copy_within_device(a.storage().ptr(), work_ptr, work_size, device)?; + WgpuRuntime::copy_within_device(a.ptr(), work_ptr, work_size, device)?; // Zero-initialize converged flag let zero_i32: [i32; 1] = [0]; diff --git a/src/runtime/wgpu/linalg/lstsq.rs b/src/runtime/wgpu/linalg/lstsq.rs index f445ade5..4058c75f 100644 --- a/src/runtime/wgpu/linalg/lstsq.rs +++ b/src/runtime/wgpu/linalg/lstsq.rs @@ -91,7 +91,7 @@ pub fn lstsq( // Q^T @ B gives [m, num_rhs] let qtb = client.matmul(&q_t, &b_mat)?; - let r_buffer = get_buffer(qr.r.storage().ptr()) + let r_buffer = get_buffer(qr.r.ptr()) .ok_or_else(|| Error::Internal("Failed to get R buffer".to_string()))?; // Allocate output X [n, num_rhs] or [n] for vector @@ -106,7 +106,7 @@ pub fn lstsq( // Get first n elements of Q^T @ b using GPU-side slicing let qtb_flat = qtb.reshape(&[m])?; let qtb_n = qtb_flat.narrow(0, 0, n)?.contiguous(); - let qtb_buffer = get_buffer(qtb_n.storage().ptr()) + let qtb_buffer = get_buffer(qtb_n.ptr()) .ok_or_else(|| Error::Internal("Failed to get qtb buffer".to_string()))?; let params: [u32; 1] = [n as u32]; @@ -125,7 +125,7 @@ pub fn lstsq( } else { // Multi-RHS: solve for each column let qtb_contig = qtb.contiguous(); - let qtb_buffer = get_buffer(qtb_contig.storage().ptr()) + let qtb_buffer = get_buffer(qtb_contig.ptr()) .ok_or_else(|| Error::Internal("Failed to get qtb buffer".to_string()))?; let col_size = n * dtype.size_in_bytes(); diff --git a/src/runtime/wgpu/linalg/matrix_functions.rs b/src/runtime/wgpu/linalg/matrix_functions.rs index 44c6e976..42fe2332 100644 --- a/src/runtime/wgpu/linalg/matrix_functions.rs +++ b/src/runtime/wgpu/linalg/matrix_functions.rs @@ -445,8 +445,7 @@ fn compute_norm(client: &WgpuClient, a: &Tensor) -> Result { fn get_tensor_buffer(t: &Tensor) -> Result> { use super::super::client::get_buffer; - get_buffer(t.storage().ptr()) - .ok_or_else(|| Error::Internal("Failed to get tensor buffer".to_string())) + get_buffer(t.ptr()).ok_or_else(|| Error::Internal("Failed to get tensor buffer".to_string())) } /// Compute exp(T) for quasi-triangular matrix T using GPU kernels. diff --git a/src/runtime/wgpu/linalg/matrix_ops.rs b/src/runtime/wgpu/linalg/matrix_ops.rs index 54586952..a1084e75 100644 --- a/src/runtime/wgpu/linalg/matrix_ops.rs +++ b/src/runtime/wgpu/linalg/matrix_ops.rs @@ -58,9 +58,9 @@ pub fn inverse(client: &WgpuClient, a: &Tensor) -> Result) -> Result) -> Result) -> Result) -> Result::from_slice(&[1.0f32], &[1, 1], device); return Ok(SchurDecomposition { z, t }); @@ -62,7 +62,7 @@ pub fn schur_decompose( let converged_flag_buffer = get_buffer_or_err!(converged_flag_ptr, "Schur convergence flag"); // Copy input to T buffer - WgpuRuntime::copy_within_device(a.storage().ptr(), t_ptr, matrix_size, device)?; + WgpuRuntime::copy_within_device(a.ptr(), t_ptr, matrix_size, device)?; // Zero-initialize converged flag let zero_i32: [i32; 1] = [0]; diff --git a/src/runtime/wgpu/linalg/solvers.rs b/src/runtime/wgpu/linalg/solvers.rs index 8f49f430..e66036f9 100644 --- a/src/runtime/wgpu/linalg/solvers.rs +++ b/src/runtime/wgpu/linalg/solvers.rs @@ -73,9 +73,9 @@ pub fn solve( let lu_result = lu_decompose(client, a)?; // Get LU and pivots buffers (both already on GPU, no transfers needed) - let lu_buffer = get_buffer(lu_result.lu.storage().ptr()) + let lu_buffer = get_buffer(lu_result.lu.ptr()) .ok_or_else(|| Error::Internal("Failed to get lu buffer".to_string()))?; - let pivots_buffer = get_buffer(lu_result.pivots.storage().ptr()) + let pivots_buffer = get_buffer(lu_result.pivots.ptr()) .ok_or_else(|| Error::Internal("Failed to get pivots buffer".to_string()))?; // Allocate temporary buffers for single column operations @@ -95,7 +95,7 @@ pub fn solve( // Get b buffer for GPU column extraction let b_contig = b.contiguous(); - let b_buffer = get_buffer(b_contig.storage().ptr()) + let b_buffer = get_buffer(b_contig.ptr()) .ok_or_else(|| Error::Internal("Failed to get b buffer".to_string()))?; // Allocate output buffer for all RHS (column-major: each solved column stored contiguously) diff --git a/src/runtime/wgpu/linalg/svd.rs b/src/runtime/wgpu/linalg/svd.rs index fc72b1e5..a8f84f42 100644 --- a/src/runtime/wgpu/linalg/svd.rs +++ b/src/runtime/wgpu/linalg/svd.rs @@ -19,7 +19,7 @@ use crate::tensor::Tensor; /// Helper to get buffer from tensor, with proper error handling. fn get_tensor_buffer(tensor: &Tensor) -> Result> { - let ptr = tensor.storage().ptr(); + let ptr = tensor.ptr(); get_buffer(ptr).ok_or_else(|| Error::Internal("Failed to get buffer from tensor".to_string())) } diff --git a/src/runtime/wgpu/linalg/triangular_solve.rs b/src/runtime/wgpu/linalg/triangular_solve.rs index b1f85569..d9d93418 100644 --- a/src/runtime/wgpu/linalg/triangular_solve.rs +++ b/src/runtime/wgpu/linalg/triangular_solve.rs @@ -67,10 +67,10 @@ pub fn solve_triangular_lower( ))); }; - let l_buffer = get_buffer(l.storage().ptr()) - .ok_or_else(|| Error::Internal("Failed to get L buffer".to_string()))?; + let l_buffer = + get_buffer(l.ptr()).ok_or_else(|| Error::Internal("Failed to get L buffer".to_string()))?; let b_contig = b.contiguous(); - let b_buffer = get_buffer(b_contig.storage().ptr()) + let b_buffer = get_buffer(b_contig.ptr()) .ok_or_else(|| Error::Internal("Failed to get b buffer".to_string()))?; // Allocate output @@ -227,10 +227,10 @@ pub fn solve_triangular_upper( ))); }; - let u_buffer = get_buffer(u.storage().ptr()) - .ok_or_else(|| Error::Internal("Failed to get U buffer".to_string()))?; + let u_buffer = + get_buffer(u.ptr()).ok_or_else(|| Error::Internal("Failed to get U buffer".to_string()))?; let b_contig = b.contiguous(); - let b_buffer = get_buffer(b_contig.storage().ptr()) + let b_buffer = get_buffer(b_contig.ptr()) .ok_or_else(|| Error::Internal("Failed to get b buffer".to_string()))?; // Allocate output diff --git a/src/runtime/wgpu/mod.rs b/src/runtime/wgpu/mod.rs index fdc479e3..7febed81 100644 --- a/src/runtime/wgpu/mod.rs +++ b/src/runtime/wgpu/mod.rs @@ -40,6 +40,6 @@ mod special; mod statistics; pub use crate::tensor::Tensor; -pub use client::{WgpuAllocator, WgpuClient, WgpuRawHandle}; +pub use client::{WgpuAllocator, WgpuClient, WgpuRawHandle, get_buffer}; pub use device::{WgpuDevice, WgpuError}; pub use runtime::{WgpuRuntime, is_wgpu_available, wgpu_device, wgpu_device_id}; diff --git a/src/runtime/wgpu/ops/helpers.rs b/src/runtime/wgpu/ops/helpers.rs index 3144a8db..0db27635 100644 --- a/src/runtime/wgpu/ops/helpers.rs +++ b/src/runtime/wgpu/ops/helpers.rs @@ -37,7 +37,7 @@ pub(super) fn create_params_buffer( pub(crate) fn get_tensor_buffer( tensor: &Tensor, ) -> Result> { - let ptr = tensor.storage().ptr(); + let ptr = tensor.ptr(); get_buffer(ptr).ok_or_else(|| Error::Internal("Buffer not found in registry".to_string())) } @@ -171,6 +171,19 @@ pub(super) struct LayerNormParams { pub(super) eps: f32, } +#[repr(C)] +#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)] +pub(super) struct GroupNormParams { + pub(super) batch_size: u32, + pub(super) channels: u32, + pub(super) spatial: u32, + pub(super) num_groups: u32, + pub(super) channels_per_group: u32, + pub(super) eps: f32, + pub(super) _pad0: u32, + pub(super) _pad1: u32, +} + #[repr(C)] #[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)] pub(super) struct CatShaderParams { @@ -721,6 +734,20 @@ pub(crate) struct Gather2dParams { pub(crate) _pad: u32, } +/// Params for slice_assign operations +#[repr(C)] +#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)] +pub(crate) struct SliceAssignParams { + pub(crate) outer_size: u32, + pub(crate) dst_dim_size: u32, + pub(crate) src_dim_size: u32, + pub(crate) inner_size: u32, + pub(crate) start: u32, + pub(crate) _pad0: u32, + pub(crate) _pad1: u32, + pub(crate) _pad2: u32, +} + /// Params for unique_with_counts operations #[repr(C)] #[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)] diff --git a/src/runtime/wgpu/ops/native/activation.rs b/src/runtime/wgpu/ops/native/activation.rs index d3af5f4b..da22aec0 100644 --- a/src/runtime/wgpu/ops/native/activation.rs +++ b/src/runtime/wgpu/ops/native/activation.rs @@ -3,7 +3,7 @@ use super::helpers::*; use crate::error::{Error, Result}; use crate::runtime::ensure_contiguous; -use crate::runtime::wgpu::shaders::activation_launcher; +use crate::runtime::wgpu::shaders::{activation_launcher, fused_activation_mul}; use crate::runtime::wgpu::{WgpuClient, WgpuRuntime}; use crate::tensor::Tensor; @@ -65,3 +65,172 @@ pub(crate) fn native_parametric_activation( Ok(out) } + +/// Native fused activation-mul forward: out = activation(a) * b. F32 only. +pub(crate) fn native_fused_activation_mul_fwd( + client: &WgpuClient, + op: &'static str, + a: &Tensor, + b: &Tensor, +) -> Result> { + let dtype = a.dtype(); + if b.dtype() != dtype { + return Err(Error::DTypeMismatch { + lhs: dtype, + rhs: b.dtype(), + }); + } + let a_contig = ensure_contiguous(a); + let b_contig = ensure_contiguous(b); + let numel = a.numel(); + + let out = alloc_output(client, a.shape(), dtype); + + let a_buf = get_tensor_buffer(&a_contig)?; + let b_buf = get_tensor_buffer(&b_contig)?; + let out_buf = get_tensor_buffer(&out)?; + + let params = BinaryParams { + numel: numel as u32, + }; + let params_buf = create_params_buffer(client, ¶ms); + + match op { + "silu_mul" => fused_activation_mul::launch_silu_mul( + client.pipeline_cache(), + client.wgpu_queue(), + &a_buf, + &b_buf, + &out_buf, + ¶ms_buf, + numel, + dtype, + )?, + "gelu_mul" => fused_activation_mul::launch_gelu_mul( + client.pipeline_cache(), + client.wgpu_queue(), + &a_buf, + &b_buf, + &out_buf, + ¶ms_buf, + numel, + dtype, + )?, + "relu_mul" => fused_activation_mul::launch_relu_mul( + client.pipeline_cache(), + client.wgpu_queue(), + &a_buf, + &b_buf, + &out_buf, + ¶ms_buf, + numel, + dtype, + )?, + "sigmoid_mul" => fused_activation_mul::launch_sigmoid_mul( + client.pipeline_cache(), + client.wgpu_queue(), + &a_buf, + &b_buf, + &out_buf, + ¶ms_buf, + numel, + dtype, + )?, + _ => { + return Err(Error::Internal(format!( + "Unknown fused activation-mul op: {}", + op + ))); + } + } + + Ok(out) +} + +/// Native fused activation-mul backward: d_a = grad * b * act'(a), d_b = grad * act(a). F32 only. +pub(crate) fn native_fused_activation_mul_bwd( + client: &WgpuClient, + op: &'static str, + grad: &Tensor, + a: &Tensor, + b: &Tensor, +) -> Result<(Tensor, Tensor)> { + let dtype = a.dtype(); + let grad_contig = ensure_contiguous(grad); + let a_contig = ensure_contiguous(a); + let b_contig = ensure_contiguous(b); + let numel = a.numel(); + + let d_a = alloc_output(client, a.shape(), dtype); + let d_b = alloc_output(client, b.shape(), dtype); + + let grad_buf = get_tensor_buffer(&grad_contig)?; + let a_buf = get_tensor_buffer(&a_contig)?; + let b_buf = get_tensor_buffer(&b_contig)?; + let d_a_buf = get_tensor_buffer(&d_a)?; + let d_b_buf = get_tensor_buffer(&d_b)?; + + let params = BinaryParams { + numel: numel as u32, + }; + let params_buf = create_params_buffer(client, ¶ms); + + match op { + "silu_mul_bwd" => fused_activation_mul::launch_silu_mul_bwd( + client.pipeline_cache(), + client.wgpu_queue(), + &grad_buf, + &a_buf, + &b_buf, + &d_a_buf, + &d_b_buf, + ¶ms_buf, + numel, + dtype, + )?, + "gelu_mul_bwd" => fused_activation_mul::launch_gelu_mul_bwd( + client.pipeline_cache(), + client.wgpu_queue(), + &grad_buf, + &a_buf, + &b_buf, + &d_a_buf, + &d_b_buf, + ¶ms_buf, + numel, + dtype, + )?, + "relu_mul_bwd" => fused_activation_mul::launch_relu_mul_bwd( + client.pipeline_cache(), + client.wgpu_queue(), + &grad_buf, + &a_buf, + &b_buf, + &d_a_buf, + &d_b_buf, + ¶ms_buf, + numel, + dtype, + )?, + "sigmoid_mul_bwd" => fused_activation_mul::launch_sigmoid_mul_bwd( + client.pipeline_cache(), + client.wgpu_queue(), + &grad_buf, + &a_buf, + &b_buf, + &d_a_buf, + &d_b_buf, + ¶ms_buf, + numel, + dtype, + )?, + _ => { + return Err(Error::Internal(format!( + "Unknown fused activation-mul bwd op: {}", + op + ))); + } + } + + Ok((d_a, d_b)) +} diff --git a/src/runtime/wgpu/ops/native/fused_elementwise.rs b/src/runtime/wgpu/ops/native/fused_elementwise.rs new file mode 100644 index 00000000..226198b6 --- /dev/null +++ b/src/runtime/wgpu/ops/native/fused_elementwise.rs @@ -0,0 +1,145 @@ +//! Fused elementwise native GPU operations for WebGPU. + +use super::helpers::*; +use crate::error::{Error, Result}; +use crate::runtime::ensure_contiguous; +use crate::runtime::wgpu::shaders::fused_elementwise; +use crate::runtime::wgpu::{WgpuClient, WgpuRuntime}; +use crate::tensor::Tensor; + +/// Native fused_mul_add: out = a * b + c. F32 only. +pub(crate) fn native_fused_mul_add( + client: &WgpuClient, + a: &Tensor, + b: &Tensor, + c: &Tensor, +) -> Result> { + let dtype = a.dtype(); + if b.dtype() != dtype || c.dtype() != dtype { + return Err(Error::DTypeMismatch { + lhs: dtype, + rhs: if b.dtype() != dtype { + b.dtype() + } else { + c.dtype() + }, + }); + } + if a.shape() != b.shape() || a.shape() != c.shape() { + return Err(Error::ShapeMismatch { + expected: a.shape().to_vec(), + got: if a.shape() != b.shape() { + b.shape().to_vec() + } else { + c.shape().to_vec() + }, + }); + } + + let a_contig = ensure_contiguous(a); + let b_contig = ensure_contiguous(b); + let c_contig = ensure_contiguous(c); + let numel = a.numel(); + let out = alloc_output(client, a.shape(), dtype); + + let a_buf = get_tensor_buffer(&a_contig)?; + let b_buf = get_tensor_buffer(&b_contig)?; + let c_buf = get_tensor_buffer(&c_contig)?; + let out_buf = get_tensor_buffer(&out)?; + + fused_elementwise::launch_fused_mul_add( + client.pipeline_cache(), + client.wgpu_queue(), + &a_buf, + &b_buf, + &c_buf, + &out_buf, + numel, + dtype, + )?; + + Ok(out) +} + +/// Native fused_add_mul: out = (a + b) * c. F32 only. +pub(crate) fn native_fused_add_mul( + client: &WgpuClient, + a: &Tensor, + b: &Tensor, + c: &Tensor, +) -> Result> { + let dtype = a.dtype(); + if b.dtype() != dtype || c.dtype() != dtype { + return Err(Error::DTypeMismatch { + lhs: dtype, + rhs: if b.dtype() != dtype { + b.dtype() + } else { + c.dtype() + }, + }); + } + if a.shape() != b.shape() || a.shape() != c.shape() { + return Err(Error::ShapeMismatch { + expected: a.shape().to_vec(), + got: if a.shape() != b.shape() { + b.shape().to_vec() + } else { + c.shape().to_vec() + }, + }); + } + + let a_contig = ensure_contiguous(a); + let b_contig = ensure_contiguous(b); + let c_contig = ensure_contiguous(c); + let numel = a.numel(); + let out = alloc_output(client, a.shape(), dtype); + + let a_buf = get_tensor_buffer(&a_contig)?; + let b_buf = get_tensor_buffer(&b_contig)?; + let c_buf = get_tensor_buffer(&c_contig)?; + let out_buf = get_tensor_buffer(&out)?; + + fused_elementwise::launch_fused_add_mul( + client.pipeline_cache(), + client.wgpu_queue(), + &a_buf, + &b_buf, + &c_buf, + &out_buf, + numel, + dtype, + )?; + + Ok(out) +} + +/// Native fused_mul_add_scalar: out = a * scale + bias. F32 only. +pub(crate) fn native_fused_mul_add_scalar( + client: &WgpuClient, + a: &Tensor, + scale: f64, + bias: f64, +) -> Result> { + let dtype = a.dtype(); + let a_contig = ensure_contiguous(a); + let numel = a.numel(); + let out = alloc_output(client, a.shape(), dtype); + + let a_buf = get_tensor_buffer(&a_contig)?; + let out_buf = get_tensor_buffer(&out)?; + + fused_elementwise::launch_fused_mul_add_scalar( + client.pipeline_cache(), + client.wgpu_queue(), + &a_buf, + &out_buf, + numel, + dtype, + scale as f32, + bias as f32, + )?; + + Ok(out) +} diff --git a/src/runtime/wgpu/ops/native/gemm_epilogue.rs b/src/runtime/wgpu/ops/native/gemm_epilogue.rs new file mode 100644 index 00000000..949261f0 --- /dev/null +++ b/src/runtime/wgpu/ops/native/gemm_epilogue.rs @@ -0,0 +1,255 @@ +//! Native WGPU GEMM epilogue operations. + +use super::helpers::*; +use crate::error::{Error, Result}; +use crate::ops::{GemmActivation, matmul_bias_output_shape, validate_matmul_bias_dtypes}; +use crate::runtime::ensure_contiguous; +use crate::runtime::wgpu::shaders::gemm_epilogue; +use crate::runtime::wgpu::{WgpuClient, WgpuRuntime}; +use crate::tensor::Tensor; + +pub(crate) fn native_gemm_bias_activation( + client: &WgpuClient, + a: &Tensor, + b: &Tensor, + bias: &Tensor, + activation: GemmActivation, +) -> Result> { + let dtype = validate_matmul_bias_dtypes(a.dtype(), b.dtype(), bias.dtype())?; + let out_shape = matmul_bias_output_shape(a.shape(), b.shape(), bias.shape()) + .ok_or_else(|| Error::shape_mismatch(a.shape(), b.shape()))?; + + let a_shape = a.shape(); + let b_shape = b.shape(); + + if a_shape.len() == 2 && b_shape.len() == 2 { + let m = a_shape[0]; + let k = a_shape[1]; + let n = b_shape[1]; + + let a_c = ensure_contiguous(a); + let b_c = ensure_contiguous(b); + let bias_c = ensure_contiguous(bias); + let out = alloc_output(client, &out_shape, dtype); + + let a_buf = get_tensor_buffer(&a_c)?; + let b_buf = get_tensor_buffer(&b_c)?; + let bias_buf = get_tensor_buffer(&bias_c)?; + let out_buf = get_tensor_buffer(&out)?; + + let params_buf = gemm_epilogue::create_epilogue_params_buffer( + client.pipeline_cache(), + m as u32, + k as u32, + n as u32, + 1, + activation, + ); + + gemm_epilogue::launch_gemm_bias_act( + client.pipeline_cache(), + client.wgpu_queue(), + &a_buf, + &b_buf, + &bias_buf, + &out_buf, + ¶ms_buf, + m, + n, + dtype, + )?; + + return Ok(out); + } + + if a_shape.len() == 3 && b_shape.len() == 3 { + let batch_size = a_shape[0]; + let m = a_shape[1]; + let k = a_shape[2]; + let n = b_shape[2]; + + if b_shape[0] != batch_size { + return Err(Error::ShapeMismatch { + expected: vec![batch_size, m, k], + got: b_shape.to_vec(), + }); + } + + let a_c = ensure_contiguous(a); + let b_c = ensure_contiguous(b); + let bias_c = ensure_contiguous(bias); + let out = alloc_output(client, &out_shape, dtype); + + let a_buf = get_tensor_buffer(&a_c)?; + let b_buf = get_tensor_buffer(&b_c)?; + let bias_buf = get_tensor_buffer(&bias_c)?; + let out_buf = get_tensor_buffer(&out)?; + + let params_buf = gemm_epilogue::create_epilogue_params_buffer( + client.pipeline_cache(), + m as u32, + k as u32, + n as u32, + batch_size as u32, + activation, + ); + + gemm_epilogue::launch_gemm_bias_act_batched( + client.pipeline_cache(), + client.wgpu_queue(), + &a_buf, + &b_buf, + &bias_buf, + &out_buf, + ¶ms_buf, + m, + n, + batch_size, + dtype, + )?; + + return Ok(out); + } + + Err(Error::BackendLimitation { + backend: "WebGPU", + operation: "gemm_bias_activation", + reason: format!( + "only supports 2D and 3D tensors, got shapes {:?} and {:?}", + a.shape(), + b.shape() + ), + }) +} + +pub(crate) fn native_gemm_bias_residual( + client: &WgpuClient, + a: &Tensor, + b: &Tensor, + bias: &Tensor, + residual: &Tensor, +) -> Result> { + let dtype = validate_matmul_bias_dtypes(a.dtype(), b.dtype(), bias.dtype())?; + if residual.dtype() != dtype { + return Err(Error::DTypeMismatch { + lhs: dtype, + rhs: residual.dtype(), + }); + } + + let out_shape = matmul_bias_output_shape(a.shape(), b.shape(), bias.shape()) + .ok_or_else(|| Error::shape_mismatch(a.shape(), b.shape()))?; + + if residual.shape() != out_shape.as_slice() { + return Err(Error::ShapeMismatch { + expected: out_shape.clone(), + got: residual.shape().to_vec(), + }); + } + + let a_shape = a.shape(); + let b_shape = b.shape(); + + if a_shape.len() == 2 && b_shape.len() == 2 { + let m = a_shape[0]; + let k = a_shape[1]; + let n = b_shape[1]; + + let a_c = ensure_contiguous(a); + let b_c = ensure_contiguous(b); + let bias_c = ensure_contiguous(bias); + let res_c = ensure_contiguous(residual); + let out = alloc_output(client, &out_shape, dtype); + + let a_buf = get_tensor_buffer(&a_c)?; + let b_buf = get_tensor_buffer(&b_c)?; + let bias_buf = get_tensor_buffer(&bias_c)?; + let res_buf = get_tensor_buffer(&res_c)?; + let out_buf = get_tensor_buffer(&out)?; + + let params_buf = gemm_epilogue::create_residual_params_buffer( + client.pipeline_cache(), + m as u32, + k as u32, + n as u32, + 1, + ); + + gemm_epilogue::launch_gemm_bias_residual( + client.pipeline_cache(), + client.wgpu_queue(), + &a_buf, + &b_buf, + &bias_buf, + &res_buf, + &out_buf, + ¶ms_buf, + m, + n, + dtype, + )?; + + return Ok(out); + } + + if a_shape.len() == 3 && b_shape.len() == 3 { + let batch_size = a_shape[0]; + let m = a_shape[1]; + let k = a_shape[2]; + let n = b_shape[2]; + + if b_shape[0] != batch_size { + return Err(Error::ShapeMismatch { + expected: vec![batch_size, m, k], + got: b_shape.to_vec(), + }); + } + + let a_c = ensure_contiguous(a); + let b_c = ensure_contiguous(b); + let bias_c = ensure_contiguous(bias); + let res_c = ensure_contiguous(residual); + let out = alloc_output(client, &out_shape, dtype); + + let a_buf = get_tensor_buffer(&a_c)?; + let b_buf = get_tensor_buffer(&b_c)?; + let bias_buf = get_tensor_buffer(&bias_c)?; + let res_buf = get_tensor_buffer(&res_c)?; + let out_buf = get_tensor_buffer(&out)?; + + let params_buf = gemm_epilogue::create_residual_params_buffer( + client.pipeline_cache(), + m as u32, + k as u32, + n as u32, + batch_size as u32, + ); + + gemm_epilogue::launch_gemm_bias_residual_batched( + client.pipeline_cache(), + client.wgpu_queue(), + &a_buf, + &b_buf, + &bias_buf, + &res_buf, + &out_buf, + ¶ms_buf, + m, + n, + batch_size, + dtype, + )?; + + return Ok(out); + } + + Err(Error::BackendLimitation { + backend: "WebGPU", + operation: "gemm_bias_residual", + reason: format!( + "only supports 2D and 3D tensors, got shapes {:?} and {:?}", + a.shape(), + b.shape() + ), + }) +} diff --git a/src/runtime/wgpu/ops/native/indexing.rs b/src/runtime/wgpu/ops/native/indexing.rs index 4d3a9d29..3b6c47b4 100644 --- a/src/runtime/wgpu/ops/native/indexing.rs +++ b/src/runtime/wgpu/ops/native/indexing.rs @@ -2,7 +2,7 @@ use super::helpers::*; use crate::error::{Error, Result}; -use crate::runtime::wgpu::shaders::index; +use crate::runtime::wgpu::shaders::{index, launch_slice_assign}; use crate::runtime::wgpu::{WgpuClient, WgpuRuntime}; use crate::runtime::{compute_contiguous_strides, ensure_contiguous}; use crate::tensor::Tensor; @@ -417,3 +417,109 @@ pub(crate) fn native_scatter( Ok(out) } + +pub(crate) fn native_slice_assign( + client: &WgpuClient, + dst: &Tensor, + src: &Tensor, + dim: usize, + start: usize, +) -> Result> { + let ndim = dst.ndim(); + if dim >= ndim { + return Err(Error::InvalidDimension { + dim: dim as isize, + ndim, + }); + } + + if src.ndim() != ndim { + return Err(Error::ShapeMismatch { + expected: dst.shape().to_vec(), + got: src.shape().to_vec(), + }); + } + for d in 0..ndim { + if d != dim && src.shape()[d] != dst.shape()[d] { + return Err(Error::ShapeMismatch { + expected: dst.shape().to_vec(), + got: src.shape().to_vec(), + }); + } + } + + let src_dim_size = src.shape()[dim]; + let dst_dim_size = dst.shape()[dim]; + if start + src_dim_size > dst_dim_size { + return Err(Error::InvalidArgument { + arg: "start", + reason: format!( + "start ({}) + src dim size ({}) exceeds dst dim size ({})", + start, src_dim_size, dst_dim_size + ), + }); + } + + let dtype = dst.dtype(); + if src.dtype() != dtype { + return Err(Error::DTypeMismatch { + lhs: dtype, + rhs: src.dtype(), + }); + } + + let outer_size: usize = dst.shape()[..dim].iter().product(); + let outer_size = outer_size.max(1); + let inner_size: usize = dst.shape()[dim + 1..].iter().product(); + let inner_size = inner_size.max(1); + let total_src = outer_size * src_dim_size * inner_size; + + let dst_contig = ensure_contiguous(dst); + let src_contig = ensure_contiguous(src); + + let out = alloc_output(client, dst.shape(), dtype); + + let dst_buf = get_tensor_buffer(&dst_contig)?; + let src_buf = get_tensor_buffer(&src_contig)?; + let out_buf = get_tensor_buffer(&out)?; + + // First copy dst → output + let copy_params = CopyParams { + numel: dst.numel() as u32, + }; + let copy_params_buf = create_params_buffer(client, ©_params); + index::launch_copy( + client.pipeline_cache(), + client.wgpu_queue(), + &dst_buf, + &out_buf, + ©_params_buf, + dst.numel(), + dtype, + )?; + + // Then overwrite the slice with src + let params = SliceAssignParams { + outer_size: outer_size as u32, + dst_dim_size: dst_dim_size as u32, + src_dim_size: src_dim_size as u32, + inner_size: inner_size as u32, + start: start as u32, + _pad0: 0, + _pad1: 0, + _pad2: 0, + }; + let params_buf = create_params_buffer(client, ¶ms); + + launch_slice_assign( + client.pipeline_cache(), + client.wgpu_queue(), + &src_buf, + &out_buf, + ¶ms_buf, + total_src.max(1), + dtype, + )?; + + Ok(out) +} diff --git a/src/runtime/wgpu/ops/native/matmul.rs b/src/runtime/wgpu/ops/native/matmul.rs index 311b9341..e22d1d4e 100644 --- a/src/runtime/wgpu/ops/native/matmul.rs +++ b/src/runtime/wgpu/ops/native/matmul.rs @@ -5,10 +5,35 @@ use crate::error::Error; use crate::error::Result; use crate::ops::{matmul_bias_output_shape, matmul_output_shape, validate_matmul_bias_dtypes}; use crate::runtime::ensure_contiguous; -use crate::runtime::wgpu::shaders::matmul; +use crate::runtime::wgpu::shaders::{gemv_bt, matmul}; use crate::runtime::wgpu::{WgpuClient, WgpuRuntime}; use crate::tensor::Tensor; +/// Detect if a 2D tensor is a simple transpose of a contiguous [N,K] matrix. +/// Shape [K, N] with strides [1, K] means it's a transpose view of contiguous [N, K]. +fn is_simple_transpose_2d(tensor: &Tensor) -> bool { + let shape = tensor.shape(); + let strides = tensor.strides(); + if shape.len() != 2 { + return false; + } + strides[0] == 1 && strides[1] == shape[0] as isize +} + +/// Detect if the last two dims of a 3D tensor are a simple transpose. +/// Shape [B, K, N] with strides [N*K, 1, K] means each batch slice +/// is a transpose of contiguous [N, K]. +fn is_batched_transpose_last2(tensor: &Tensor) -> bool { + let shape = tensor.shape(); + let strides = tensor.strides(); + if shape.len() != 3 { + return false; + } + let k = shape[1]; + let n = shape[2]; + strides[1] == 1 && strides[2] == k as isize && strides[0] == (n * k) as isize +} + pub(crate) fn native_matmul( client: &WgpuClient, a: &Tensor, @@ -28,6 +53,38 @@ pub(crate) fn native_matmul( let k = a_shape[1]; let n = b_shape[1]; + // GEMV-BT fast path: transposed B with small M + if m <= 16 && is_simple_transpose_2d(b) { + let a_contig = ensure_contiguous(a); + let out = alloc_output(client, &out_shape, dtype); + + let a_buf = get_tensor_buffer(&a_contig)?; + let b_buf = get_tensor_buffer(b)?; // Use original [N,K] buffer directly + let out_buf = get_tensor_buffer(&out)?; + + let params = MatmulParams { + m: m as u32, + k: k as u32, + n: n as u32, + batch_size: 1, + }; + let params_buf = create_params_buffer(client, ¶ms); + + gemv_bt::launch_gemv_bt( + client.pipeline_cache(), + client.wgpu_queue(), + &a_buf, + &b_buf, + &out_buf, + ¶ms_buf, + m, + n, + dtype, + )?; + + return Ok(out); + } + let a_contig = ensure_contiguous(a); let b_contig = ensure_contiguous(b); @@ -90,6 +147,39 @@ pub(crate) fn native_matmul( }); } + // GEMV-BT fast path: transposed B with small M + if m <= 16 && is_batched_transpose_last2(b) { + let a_contig = ensure_contiguous(a); + let out = alloc_output(client, &out_shape, dtype); + + let a_buf = get_tensor_buffer(&a_contig)?; + let b_buf = get_tensor_buffer(b)?; + let out_buf = get_tensor_buffer(&out)?; + + let params = MatmulParams { + m: m as u32, + k: k as u32, + n: n as u32, + batch_size: batch_size as u32, + }; + let params_buf = create_params_buffer(client, ¶ms); + + gemv_bt::launch_batched_gemv_bt( + client.pipeline_cache(), + client.wgpu_queue(), + &a_buf, + &b_buf, + &out_buf, + ¶ms_buf, + m, + n, + batch_size, + dtype, + )?; + + return Ok(out); + } + let a_contig = ensure_contiguous(a); let b_contig = ensure_contiguous(b); @@ -123,17 +213,90 @@ pub(crate) fn native_matmul( return Ok(out); } - // >3D tensors are not supported - return error instead of silent fallback - // (WebGPU shader dispatch is limited to 3D workgroups) - Err(Error::BackendLimitation { - backend: "WebGPU", - operation: "matmul", - reason: format!( - "only supports 2D and 3D tensors, got shapes {:?} and {:?}", - a.shape(), - b.shape() - ), - }) + // >3D: flatten leading dims into batch, run 3D batched matmul, reshape back. + // Same strategy as CUDA backend (which computes batch_size = product of leading dims). + let ndim_a = a_shape.len(); + let ndim_b = b_shape.len(); + + if ndim_a < 2 || ndim_b < 2 { + return Err(Error::BackendLimitation { + backend: "WebGPU", + operation: "matmul", + reason: format!( + "requires at least 2D tensors, got shapes {:?} and {:?}", + a_shape, b_shape + ), + }); + } + + let m = a_shape[ndim_a - 2]; + let k = a_shape[ndim_a - 1]; + let n = b_shape[ndim_b - 1]; + + let batch_a: usize = a_shape[..ndim_a - 2].iter().product(); + let batch_b: usize = b_shape[..ndim_b - 2].iter().product(); + let batch_size = batch_a.max(batch_b); + + // Flatten to 3D + let a_3d = ensure_contiguous(a) + .reshape(&[batch_a, m, k]) + .map_err(|_| Error::shape_mismatch(a_shape, b_shape))?; + let b_3d = ensure_contiguous(b) + .reshape(&[batch_b, k, n]) + .map_err(|_| Error::shape_mismatch(a_shape, b_shape))?; + + // Broadcast if batch dims differ (one must be 1) + let (a_batched, b_batched) = if batch_a == batch_b { + (a_3d, b_3d) + } else if batch_a == 1 { + ( + a_3d.broadcast_to(&[batch_size, m, k]) + .map_err(|_| Error::shape_mismatch(a_shape, b_shape))? + .contiguous(), + b_3d, + ) + } else if batch_b == 1 { + ( + a_3d, + b_3d.broadcast_to(&[batch_size, k, n]) + .map_err(|_| Error::shape_mismatch(a_shape, b_shape))? + .contiguous(), + ) + } else { + return Err(Error::shape_mismatch(a_shape, b_shape)); + }; + + let a_buf = get_tensor_buffer(&a_batched)?; + let b_buf = get_tensor_buffer(&b_batched)?; + let out_flat = alloc_output(client, &[batch_size, m, n], dtype); + let out_buf = get_tensor_buffer(&out_flat)?; + + let params = MatmulParams { + m: m as u32, + k: k as u32, + n: n as u32, + batch_size: batch_size as u32, + }; + let params_buf = create_params_buffer(client, ¶ms); + + matmul::launch_batched_matmul( + client.pipeline_cache(), + client.wgpu_queue(), + &a_buf, + &b_buf, + &out_buf, + ¶ms_buf, + m, + n, + batch_size, + dtype, + )?; + + // Reshape back to original leading dims + [m, n] + let result = out_flat + .reshape(&out_shape) + .map_err(|_| Error::shape_mismatch(a_shape, b_shape))?; + Ok(result) } /// Native WGPU implementation of fused matrix multiplication with bias. diff --git a/src/runtime/wgpu/ops/native/mod.rs b/src/runtime/wgpu/ops/native/mod.rs index 36233bf0..7db7e403 100644 --- a/src/runtime/wgpu/ops/native/mod.rs +++ b/src/runtime/wgpu/ops/native/mod.rs @@ -11,6 +11,8 @@ mod cast; mod compare; mod conditional; mod cumulative; +mod fused_elementwise; +mod gemm_epilogue; mod indexing; pub(crate) mod logical; mod masking; @@ -21,16 +23,29 @@ mod semiring_matmul; mod unary; // Re-export all native functions for use by ops/wgpu/ implementations -pub(crate) use activation::native_parametric_activation; +pub(crate) use activation::{ + native_fused_activation_mul_bwd, native_fused_activation_mul_fwd, native_parametric_activation, +}; pub(crate) use binary::{native_binary_op, native_scalar_op}; pub(crate) use cast::native_cast_op; pub(crate) use compare::native_compare_op; pub(crate) use conditional::{native_clamp, native_where_cond}; pub(crate) use cumulative::{native_cumprod, native_cumsum, native_logsumexp}; -pub(crate) use indexing::{native_gather, native_index_put, native_index_select, native_scatter}; +pub(crate) use fused_elementwise::{ + native_fused_add_mul, native_fused_mul_add, native_fused_mul_add_scalar, +}; +pub(crate) use gemm_epilogue::{native_gemm_bias_activation, native_gemm_bias_residual}; +pub(crate) use indexing::{ + native_gather, native_index_put, native_index_select, native_scatter, native_slice_assign, +}; pub(crate) use masking::{native_embedding_lookup, native_masked_fill, native_masked_select}; pub(crate) use matmul::{native_matmul, native_matmul_bias}; -pub(crate) use normalization::{native_layer_norm, native_rms_norm}; -pub(crate) use reduce::{native_argreduce_op, native_reduce_op, native_softmax}; +pub(crate) use normalization::{ + native_fused_add_layer_norm, native_fused_add_layer_norm_bwd, native_fused_add_rms_norm, + native_fused_add_rms_norm_bwd, native_group_norm, native_layer_norm, native_rms_norm, +}; +pub(crate) use reduce::{ + native_argreduce_op, native_reduce_op, native_softmax, native_softmax_bwd, +}; pub(crate) use semiring_matmul::native_semiring_matmul; pub(crate) use unary::native_unary_op; diff --git a/src/runtime/wgpu/ops/native/normalization.rs b/src/runtime/wgpu/ops/native/normalization.rs index 548d1b8a..0041985b 100644 --- a/src/runtime/wgpu/ops/native/normalization.rs +++ b/src/runtime/wgpu/ops/native/normalization.rs @@ -3,10 +3,17 @@ use super::helpers::*; use crate::error::{Error, Result}; use crate::runtime::ensure_contiguous; -use crate::runtime::wgpu::shaders::norm; +use crate::runtime::wgpu::shaders::{fused_add_norm, norm}; use crate::runtime::wgpu::{WgpuClient, WgpuRuntime}; use crate::tensor::Tensor; +#[repr(C)] +#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)] +struct ReduceSumParams { + batch_size: u32, + hidden_size: u32, +} + pub(crate) fn native_rms_norm( client: &WgpuClient, a: &Tensor, @@ -106,3 +113,387 @@ pub(crate) fn native_layer_norm( Ok(out) } + +pub(crate) fn native_group_norm( + client: &WgpuClient, + input: &Tensor, + weight: &Tensor, + bias: &Tensor, + num_groups: usize, + eps: f32, +) -> Result> { + let dtype = input.dtype(); + let shape = input.shape(); + + if shape.len() < 2 { + return Err(Error::InvalidArgument { + arg: "input", + reason: "group_norm requires at least 2D input [batch, channels, ...]".into(), + }); + } + + let batch = shape[0]; + let channels = shape[1]; + if !channels.is_multiple_of(num_groups) { + return Err(Error::InvalidArgument { + arg: "num_groups", + reason: format!("channels {channels} not divisible by num_groups {num_groups}"), + }); + } + let channels_per_group = channels / num_groups; + let spatial: usize = shape[2..].iter().product::().max(1); + + if weight.shape() != [channels] || bias.shape() != [channels] { + return Err(Error::ShapeMismatch { + expected: vec![channels], + got: if weight.shape() != [channels] { + weight.shape().to_vec() + } else { + bias.shape().to_vec() + }, + }); + } + + let input_contig = ensure_contiguous(input); + let weight_contig = ensure_contiguous(weight); + let bias_contig = ensure_contiguous(bias); + let out = alloc_output(client, shape, dtype); + + let input_buf = get_tensor_buffer(&input_contig)?; + let weight_buf = get_tensor_buffer(&weight_contig)?; + let bias_buf = get_tensor_buffer(&bias_contig)?; + let out_buf = get_tensor_buffer(&out)?; + + let params = GroupNormParams { + batch_size: batch as u32, + channels: channels as u32, + spatial: spatial as u32, + num_groups: num_groups as u32, + channels_per_group: channels_per_group as u32, + eps, + _pad0: 0, + _pad1: 0, + }; + let params_buf = create_params_buffer(client, ¶ms); + + norm::launch_group_norm( + client.pipeline_cache(), + client.wgpu_queue(), + &input_buf, + &weight_buf, + &bias_buf, + &out_buf, + ¶ms_buf, + batch, + num_groups, + dtype, + )?; + + Ok(out) +} + +// ============================================================================ +// Fused Add + Normalization Operations +// ============================================================================ + +pub(crate) fn native_fused_add_rms_norm( + client: &WgpuClient, + input: &Tensor, + residual: &Tensor, + weight: &Tensor, + eps: f32, +) -> Result<(Tensor, Tensor)> { + let dtype = input.dtype(); + let shape = input.shape(); + + if shape.len() < 1 { + return Err(Error::Internal( + "fused_add_rms_norm requires at least 1D input".to_string(), + )); + } + + if shape != residual.shape() { + return Err(Error::ShapeMismatch { + expected: shape.to_vec(), + got: residual.shape().to_vec(), + }); + } + + let hidden_size = shape[shape.len() - 1]; + let batch_size: usize = shape[..shape.len() - 1].iter().product(); + + let input_contig = ensure_contiguous(input); + let residual_contig = ensure_contiguous(residual); + let weight_contig = ensure_contiguous(weight); + + let output = alloc_output(client, shape, dtype); + let pre_norm = alloc_output(client, shape, dtype); + + let input_buf = get_tensor_buffer(&input_contig)?; + let residual_buf = get_tensor_buffer(&residual_contig)?; + let weight_buf = get_tensor_buffer(&weight_contig)?; + let output_buf = get_tensor_buffer(&output)?; + let pre_norm_buf = get_tensor_buffer(&pre_norm)?; + + let params = RmsNormParams { + batch_size: batch_size.max(1) as u32, + hidden_size: hidden_size as u32, + eps, + }; + let params_buf = create_params_buffer(client, ¶ms); + + fused_add_norm::launch_fused_add_rms_norm( + client.pipeline_cache(), + client.wgpu_queue(), + &input_buf, + &residual_buf, + &weight_buf, + &output_buf, + &pre_norm_buf, + ¶ms_buf, + batch_size.max(1), + dtype, + )?; + + Ok((output, pre_norm)) +} + +pub(crate) fn native_fused_add_layer_norm( + client: &WgpuClient, + input: &Tensor, + residual: &Tensor, + weight: &Tensor, + bias: &Tensor, + eps: f32, +) -> Result<(Tensor, Tensor)> { + let dtype = input.dtype(); + let shape = input.shape(); + + if shape.len() < 1 { + return Err(Error::Internal( + "fused_add_layer_norm requires at least 1D input".to_string(), + )); + } + + if shape != residual.shape() { + return Err(Error::ShapeMismatch { + expected: shape.to_vec(), + got: residual.shape().to_vec(), + }); + } + + let hidden_size = shape[shape.len() - 1]; + let batch_size: usize = shape[..shape.len() - 1].iter().product(); + + let input_contig = ensure_contiguous(input); + let residual_contig = ensure_contiguous(residual); + let weight_contig = ensure_contiguous(weight); + let bias_contig = ensure_contiguous(bias); + + let output = alloc_output(client, shape, dtype); + let pre_norm = alloc_output(client, shape, dtype); + + let input_buf = get_tensor_buffer(&input_contig)?; + let residual_buf = get_tensor_buffer(&residual_contig)?; + let weight_buf = get_tensor_buffer(&weight_contig)?; + let bias_buf = get_tensor_buffer(&bias_contig)?; + let output_buf = get_tensor_buffer(&output)?; + let pre_norm_buf = get_tensor_buffer(&pre_norm)?; + + let params = LayerNormParams { + batch_size: batch_size.max(1) as u32, + hidden_size: hidden_size as u32, + eps, + }; + let params_buf = create_params_buffer(client, ¶ms); + + fused_add_norm::launch_fused_add_layer_norm( + client.pipeline_cache(), + client.wgpu_queue(), + &input_buf, + &residual_buf, + &weight_buf, + &bias_buf, + &output_buf, + &pre_norm_buf, + ¶ms_buf, + batch_size.max(1), + dtype, + )?; + + Ok((output, pre_norm)) +} + +pub(crate) fn native_fused_add_rms_norm_bwd( + client: &WgpuClient, + grad: &Tensor, + pre_norm: &Tensor, + weight: &Tensor, + eps: f32, +) -> Result<(Tensor, Tensor)> { + let dtype = grad.dtype(); + let shape = grad.shape(); + + if shape.len() < 1 { + return Err(Error::Internal( + "fused_add_rms_norm_bwd requires at least 1D input".to_string(), + )); + } + + let hidden_size = shape[shape.len() - 1]; + let batch_size: usize = shape[..shape.len() - 1].iter().product(); + + let grad_contig = ensure_contiguous(grad); + let pn_contig = ensure_contiguous(pre_norm); + let weight_contig = ensure_contiguous(weight); + + let d_input_residual = alloc_output(client, shape, dtype); + let d_weight_scratch = alloc_output(client, &[batch_size, hidden_size], dtype); + let d_weight = alloc_output(client, &[hidden_size], dtype); + + let grad_buf = get_tensor_buffer(&grad_contig)?; + let pn_buf = get_tensor_buffer(&pn_contig)?; + let weight_buf = get_tensor_buffer(&weight_contig)?; + let d_ir_buf = get_tensor_buffer(&d_input_residual)?; + let dws_buf = get_tensor_buffer(&d_weight_scratch)?; + let dw_buf = get_tensor_buffer(&d_weight)?; + + let params = RmsNormParams { + batch_size: batch_size.max(1) as u32, + hidden_size: hidden_size as u32, + eps, + }; + let params_buf = create_params_buffer(client, ¶ms); + + fused_add_norm::launch_fused_add_rms_norm_bwd( + client.pipeline_cache(), + client.wgpu_queue(), + &grad_buf, + &pn_buf, + &weight_buf, + &d_ir_buf, + &dws_buf, + ¶ms_buf, + batch_size.max(1), + dtype, + )?; + + // Launch reduce_sum_rows to sum d_weight_scratch across batch + let reduce_params = ReduceSumParams { + batch_size: batch_size.max(1) as u32, + hidden_size: hidden_size as u32, + }; + let reduce_params_buf = create_params_buffer(client, &reduce_params); + + fused_add_norm::launch_reduce_sum_rows( + client.pipeline_cache(), + client.wgpu_queue(), + &dws_buf, + &dw_buf, + &reduce_params_buf, + hidden_size, + dtype, + )?; + + Ok((d_input_residual, d_weight)) +} + +pub(crate) fn native_fused_add_layer_norm_bwd( + client: &WgpuClient, + grad: &Tensor, + pre_norm: &Tensor, + weight: &Tensor, + bias: &Tensor, + eps: f32, +) -> Result<( + Tensor, + Tensor, + Tensor, +)> { + let dtype = grad.dtype(); + let shape = grad.shape(); + + if shape.len() < 1 { + return Err(Error::Internal( + "fused_add_layer_norm_bwd requires at least 1D input".to_string(), + )); + } + + let hidden_size = shape[shape.len() - 1]; + let batch_size: usize = shape[..shape.len() - 1].iter().product(); + + let grad_contig = ensure_contiguous(grad); + let pn_contig = ensure_contiguous(pre_norm); + let weight_contig = ensure_contiguous(weight); + let bias_contig = ensure_contiguous(bias); + + let d_input_residual = alloc_output(client, shape, dtype); + let d_weight_scratch = alloc_output(client, &[batch_size, hidden_size], dtype); + let d_bias_scratch = alloc_output(client, &[batch_size, hidden_size], dtype); + let d_weight = alloc_output(client, &[hidden_size], dtype); + let d_bias = alloc_output(client, &[hidden_size], dtype); + + let grad_buf = get_tensor_buffer(&grad_contig)?; + let pn_buf = get_tensor_buffer(&pn_contig)?; + let weight_buf = get_tensor_buffer(&weight_contig)?; + let bias_buf = get_tensor_buffer(&bias_contig)?; + let d_ir_buf = get_tensor_buffer(&d_input_residual)?; + let dws_buf = get_tensor_buffer(&d_weight_scratch)?; + let dbs_buf = get_tensor_buffer(&d_bias_scratch)?; + let dw_buf = get_tensor_buffer(&d_weight)?; + let db_buf = get_tensor_buffer(&d_bias)?; + + let params = LayerNormParams { + batch_size: batch_size.max(1) as u32, + hidden_size: hidden_size as u32, + eps, + }; + let params_buf = create_params_buffer(client, ¶ms); + + fused_add_norm::launch_fused_add_layer_norm_bwd( + client.pipeline_cache(), + client.wgpu_queue(), + &grad_buf, + &pn_buf, + &weight_buf, + &bias_buf, + &d_ir_buf, + &dws_buf, + &dbs_buf, + ¶ms_buf, + batch_size.max(1), + dtype, + )?; + + // Launch reduce_sum_rows for d_weight_scratch + let reduce_params = ReduceSumParams { + batch_size: batch_size.max(1) as u32, + hidden_size: hidden_size as u32, + }; + let reduce_params_buf = create_params_buffer(client, &reduce_params); + + fused_add_norm::launch_reduce_sum_rows( + client.pipeline_cache(), + client.wgpu_queue(), + &dws_buf, + &dw_buf, + &reduce_params_buf, + hidden_size, + dtype, + )?; + + // Launch reduce_sum_rows for d_bias_scratch + let reduce_params_buf = create_params_buffer(client, &reduce_params); + + fused_add_norm::launch_reduce_sum_rows( + client.pipeline_cache(), + client.wgpu_queue(), + &dbs_buf, + &db_buf, + &reduce_params_buf, + hidden_size, + dtype, + )?; + + Ok((d_input_residual, d_weight, d_bias)) +} diff --git a/src/runtime/wgpu/ops/native/reduce.rs b/src/runtime/wgpu/ops/native/reduce.rs index 43c38b38..d9b62e7c 100644 --- a/src/runtime/wgpu/ops/native/reduce.rs +++ b/src/runtime/wgpu/ops/native/reduce.rs @@ -308,6 +308,92 @@ fn native_softmax_last_dim( Ok(out) } +/// Softmax backward with dedicated GPU kernel. +/// +/// d_input = output * (grad - sum(grad * output)) +pub(crate) fn native_softmax_bwd( + client: &WgpuClient, + grad: &Tensor, + output: &Tensor, + dim: isize, +) -> Result> { + let shape = grad.shape(); + let ndim = shape.len(); + + let dim = if dim < 0 { + (ndim as isize + dim) as usize + } else { + dim as usize + }; + + if dim >= ndim { + return Err(Error::InvalidDimension { + dim: dim as isize, + ndim, + }); + } + + // For non-last dimension, permute to last, compute, permute back + if dim != ndim - 1 { + let mut perm: Vec = (0..ndim).collect(); + perm.remove(dim); + perm.push(dim); + + let grad_p = grad.permute(&perm)?.contiguous(); + let output_p = output.permute(&perm)?.contiguous(); + let result = native_softmax_bwd_last_dim(client, &grad_p, &output_p)?; + + let mut inv_perm = vec![0; ndim]; + for (i, &p) in perm.iter().enumerate() { + inv_perm[p] = i; + } + return result.permute(&inv_perm); + } + + native_softmax_bwd_last_dim(client, grad, output) +} + +fn native_softmax_bwd_last_dim( + client: &WgpuClient, + grad: &Tensor, + output: &Tensor, +) -> Result> { + let shape = grad.shape(); + let ndim = shape.len(); + let dtype = grad.dtype(); + + let grad_contig = ensure_contiguous(grad); + let output_contig = ensure_contiguous(output); + let dim = ndim - 1; + let batch_size: usize = shape[..dim].iter().product(); + let dim_size = shape[dim]; + + let d_input = alloc_output(client, shape, dtype); + + let grad_buf = get_tensor_buffer(&grad_contig)?; + let output_buf = get_tensor_buffer(&output_contig)?; + let d_input_buf = get_tensor_buffer(&d_input)?; + + let params = SoftmaxParams { + batch_size: batch_size.max(1) as u32, + dim_size: dim_size as u32, + }; + let params_buf = create_params_buffer(client, ¶ms); + + reduce::launch_softmax_bwd_op( + client.pipeline_cache(), + client.wgpu_queue(), + &grad_buf, + &output_buf, + &d_input_buf, + ¶ms_buf, + batch_size.max(1), + dtype, + )?; + + Ok(d_input) +} + pub(crate) fn native_argreduce_op( client: &WgpuClient, op: &'static str, diff --git a/src/runtime/wgpu/ops/tensor.rs b/src/runtime/wgpu/ops/tensor.rs index 42600c71..04e80098 100644 --- a/src/runtime/wgpu/ops/tensor.rs +++ b/src/runtime/wgpu/ops/tensor.rs @@ -78,6 +78,9 @@ mod distance; #[path = "../../../ops/wgpu/multivariate.rs"] mod multivariate; +#[path = "../../../ops/wgpu/gemm_epilogue.rs"] +mod gemm_epilogue; + #[path = "../../../ops/wgpu/semiring_matmul.rs"] mod semiring_matmul; @@ -92,3 +95,10 @@ mod scalar; #[path = "../../../ops/wgpu/einsum.rs"] mod einsum; + +#[path = "../../../ops/wgpu/fp8_matmul.rs"] +mod fp8_matmul; + +#[cfg(feature = "sparse")] +#[path = "../../../ops/wgpu/sparse_24.rs"] +mod sparse_24; diff --git a/src/runtime/wgpu/runtime.rs b/src/runtime/wgpu/runtime.rs index b348fb3a..861f4530 100644 --- a/src/runtime/wgpu/runtime.rs +++ b/src/runtime/wgpu/runtime.rs @@ -9,7 +9,7 @@ fn wgpu_err(e: super::device::WgpuError) -> crate::error::Error { use super::client::WgpuClient; use super::device::WgpuDevice; use super::shaders; -use crate::runtime::{Allocator, Runtime, RuntimeClient}; +use crate::runtime::{Allocator, NoOpGraph, Runtime, RuntimeClient}; use std::time::Duration; /// WebGPU Runtime adapter @@ -23,7 +23,9 @@ impl Runtime for WgpuRuntime { type Device = WgpuDevice; type Client = WgpuClient; type Allocator = super::WgpuAllocator; + type Graph = NoOpGraph; type RawHandle = super::WgpuRawHandle; + type DType = crate::dtype::DType; fn name() -> &'static str { "wgpu" @@ -33,6 +35,15 @@ impl Runtime for WgpuRuntime { false // WebGPU doesn't have CUDA-style graph capture } + fn capture_graph(client: &Self::Client, f: F) -> crate::error::Result<(Self::Graph, T)> + where + F: FnOnce(&Self::Client) -> crate::error::Result, + { + // WebGPU: execute eagerly, return NoOpGraph + let result = f(client)?; + Ok((NoOpGraph, result)) + } + /// Allocate GPU memory (storage buffer). /// /// Returns `Err(OutOfMemory)` if buffer creation fails. diff --git a/src/runtime/wgpu/shaders/activation.wgsl b/src/runtime/wgpu/shaders/activation.wgsl new file mode 100644 index 00000000..0dfc105a --- /dev/null +++ b/src/runtime/wgpu/shaders/activation.wgsl @@ -0,0 +1,22 @@ +// F32 clamp operation + +const WORKGROUP_SIZE: u32 = 256u; + +struct ClampParams { + numel: u32, + min_val: f32, + max_val: f32, + _pad0: u32, +} + +@group(0) @binding(0) var clamp_a: array; +@group(0) @binding(1) var clamp_out: array; +@group(0) @binding(2) var clamp_params: ClampParams; + +@compute @workgroup_size(256) +fn clamp_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < clamp_params.numel) { + clamp_out[idx] = clamp(clamp_a[idx], clamp_params.min_val, clamp_params.max_val); + } +} diff --git a/src/runtime/wgpu/shaders/activation_launcher.rs b/src/runtime/wgpu/shaders/activation_launcher.rs index 030fb831..45ec419e 100644 --- a/src/runtime/wgpu/shaders/activation_launcher.rs +++ b/src/runtime/wgpu/shaders/activation_launcher.rs @@ -1,32 +1,15 @@ -//! Activation and utility WGSL kernel launchers -//! -//! Provides launchers for specialized activation and utility operations: -//! - `launch_leaky_relu` - Leaky ReLU activation -//! - `launch_elu` - ELU (Exponential Linear Unit) activation -//! - `launch_clamp_op` - Value clamping -//! -//! All operations support F32 and F16 dtypes. +//! Activation and utility WGSL kernel launchers. F32 only. use wgpu::{Buffer, Queue}; -use super::generator::{ - dtype_suffix, generate_clamp_shader, generate_scalar_shader, is_wgsl_float, -}; use super::pipeline::{LayoutKey, PipelineCache, workgroup_count}; use crate::dtype::DType; use crate::error::{Error, Result}; -// ============================================================================ -// Parametric Activation Operations -// ============================================================================ +const SCALAR_SHADER: &str = include_str!("scalar.wgsl"); +const ACTIVATION_SHADER: &str = include_str!("activation.wgsl"); -/// Launch Leaky ReLU activation kernel. -/// -/// Computes `out[i] = max(negative_slope * a[i], a[i])` for all elements. -/// -/// Helps prevent "dying ReLU" by allowing small gradients for negative inputs. -/// -/// Supports F32 and F16 dtypes. +/// Launch Leaky ReLU: `out[i] = max(slope * a[i], a[i])`. F32 only. pub fn launch_leaky_relu( cache: &PipelineCache, queue: &Queue, @@ -36,28 +19,20 @@ pub fn launch_leaky_relu( numel: usize, dtype: DType, ) -> Result<()> { - // leaky_relu is float-only - if !is_wgsl_float(dtype) { + if dtype != DType::F32 { return Err(Error::UnsupportedDType { dtype, op: "leaky_relu", }); } - let suffix = dtype_suffix(dtype)?; - let shader_key = format!("scalar_{}", suffix); - let entry_point = format!("leaky_relu_{}", suffix); - - let shader_source = generate_scalar_shader(dtype)?; - let module = cache.get_or_create_module_from_source(&shader_key, &shader_source); + let module = cache.get_or_create_module("scalar_f32", SCALAR_SHADER); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = - cache.get_or_create_dynamic_pipeline(&shader_key, &entry_point, &module, &layout); - + let pipeline = cache.get_or_create_pipeline("scalar_f32", "leaky_relu_f32", &module, &layout); let bind_group = cache.create_bind_group(&layout, &[a, out, params_buffer]); let mut encoder = cache @@ -65,7 +40,6 @@ pub fn launch_leaky_relu( .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("leaky_relu"), }); - { let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some("leaky_relu"), @@ -75,18 +49,11 @@ pub fn launch_leaky_relu( pass.set_bind_group(0, Some(&bind_group), &[]); pass.dispatch_workgroups(workgroup_count(numel), 1, 1); } - queue.submit(std::iter::once(encoder.finish())); Ok(()) } -/// Launch ELU (Exponential Linear Unit) activation kernel. -/// -/// Computes `out[i] = a[i] if a[i] > 0, else alpha * (exp(a[i]) - 1)` for all elements. -/// -/// Smooth approximation to ReLU with negative values saturating to -alpha. -/// -/// Supports F32 and F16 dtypes. +/// Launch ELU: `out[i] = x > 0 ? x : alpha * (exp(x) - 1)`. F32 only. pub fn launch_elu( cache: &PipelineCache, queue: &Queue, @@ -96,31 +63,22 @@ pub fn launch_elu( numel: usize, dtype: DType, ) -> Result<()> { - // elu is float-only - if !is_wgsl_float(dtype) { + if dtype != DType::F32 { return Err(Error::UnsupportedDType { dtype, op: "elu" }); } - let suffix = dtype_suffix(dtype)?; - let shader_key = format!("scalar_{}", suffix); - let entry_point = format!("elu_{}", suffix); - - let shader_source = generate_scalar_shader(dtype)?; - let module = cache.get_or_create_module_from_source(&shader_key, &shader_source); + let module = cache.get_or_create_module("scalar_f32", SCALAR_SHADER); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = - cache.get_or_create_dynamic_pipeline(&shader_key, &entry_point, &module, &layout); - + let pipeline = cache.get_or_create_pipeline("scalar_f32", "elu_f32", &module, &layout); let bind_group = cache.create_bind_group(&layout, &[a, out, params_buffer]); let mut encoder = cache .device() .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("elu") }); - { let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some("elu"), @@ -130,20 +88,11 @@ pub fn launch_elu( pass.set_bind_group(0, Some(&bind_group), &[]); pass.dispatch_workgroups(workgroup_count(numel), 1, 1); } - queue.submit(std::iter::once(encoder.finish())); Ok(()) } -// ============================================================================ -// Clamp Operation -// ============================================================================ - -/// Launch clamp operation kernel. -/// -/// Computes `out[i] = clamp(a[i], min_val, max_val)` for all elements. -/// -/// Supports F32 and F16 dtypes. +/// Launch clamp: `out[i] = clamp(a[i], min_val, max_val)`. F32 only. pub fn launch_clamp_op( cache: &PipelineCache, queue: &Queue, @@ -153,25 +102,17 @@ pub fn launch_clamp_op( numel: usize, dtype: DType, ) -> Result<()> { - // clamp is float-only - if !is_wgsl_float(dtype) { + if dtype != DType::F32 { return Err(Error::UnsupportedDType { dtype, op: "clamp" }); } - let suffix = dtype_suffix(dtype)?; - let shader_key = format!("clamp_{}", suffix); - let entry_point = format!("clamp_{}", suffix); - - let shader_source = generate_clamp_shader(dtype)?; - let module = cache.get_or_create_module_from_source(&shader_key, &shader_source); + let module = cache.get_or_create_module("activation_f32", ACTIVATION_SHADER); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = - cache.get_or_create_dynamic_pipeline(&shader_key, &entry_point, &module, &layout); - + let pipeline = cache.get_or_create_pipeline("activation_f32", "clamp_f32", &module, &layout); let bind_group = cache.create_bind_group(&layout, &[a, out, params_buffer]); let mut encoder = cache @@ -179,7 +120,6 @@ pub fn launch_clamp_op( .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("clamp"), }); - { let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some("clamp"), @@ -189,7 +129,6 @@ pub fn launch_clamp_op( pass.set_bind_group(0, Some(&bind_group), &[]); pass.dispatch_workgroups(workgroup_count(numel), 1, 1); } - queue.submit(std::iter::once(encoder.finish())); Ok(()) } diff --git a/src/runtime/wgpu/shaders/angle_complex64.wgsl b/src/runtime/wgpu/shaders/angle_complex64.wgsl new file mode 100644 index 00000000..6d28bdd9 --- /dev/null +++ b/src/runtime/wgpu/shaders/angle_complex64.wgsl @@ -0,0 +1,19 @@ +// Complex phase angle shader +// entry point: angle_complex64 + +struct Params { + numel: u32, +} + +@group(0) @binding(0) var input: array>; +@group(0) @binding(1) var output: array; +@group(0) @binding(2) var params: Params; + +@compute @workgroup_size(256) +fn angle_complex64(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < params.numel) { + let val = input[idx]; + output[idx] = atan2(val.y, val.x); // Phase angle in radians [-π, π] + } +} diff --git a/src/runtime/wgpu/shaders/angle_real_f32.wgsl b/src/runtime/wgpu/shaders/angle_real_f32.wgsl new file mode 100644 index 00000000..7d8fdf6d --- /dev/null +++ b/src/runtime/wgpu/shaders/angle_real_f32.wgsl @@ -0,0 +1,24 @@ +// Phase angle of real numbers shader +// entry point: angle_real_f32 +// angle(x) = 0 if x >= 0, π if x < 0 + +struct Params { + numel: u32, +} + +@group(0) @binding(0) var input: array; +@group(0) @binding(1) var output: array; +@group(0) @binding(2) var params: Params; + +// PI constant (WGSL has no standard math library, so this is defined literally) +// Value matches std::f32::consts::PI exactly (f32 precision: ~7 significant digits) +const PI: f32 = 3.14159265f; + +@compute @workgroup_size(256) +fn angle_real_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < params.numel) { + let val = input[idx]; + output[idx] = select(0.0, PI, val < 0.0); // 0 if x >= 0, π if x < 0 + } +} diff --git a/src/runtime/wgpu/shaders/arange_f32.wgsl b/src/runtime/wgpu/shaders/arange_f32.wgsl new file mode 100644 index 00000000..51eca620 --- /dev/null +++ b/src/runtime/wgpu/shaders/arange_f32.wgsl @@ -0,0 +1,21 @@ +// Auto-generated arange operation for f32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct ArangeParams { + numel: u32, + start: f32, + step: f32, +} + +@group(0) @binding(0) var arange_out: array; +@group(0) @binding(1) var arange_params: ArangeParams; + +@compute @workgroup_size(256) +fn arange_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < arange_params.numel) { + let value = arange_params.start + arange_params.step * f32(idx); + arange_out[idx] = f32(value); + } +} diff --git a/src/runtime/wgpu/shaders/arange_i32.wgsl b/src/runtime/wgpu/shaders/arange_i32.wgsl new file mode 100644 index 00000000..8abb3058 --- /dev/null +++ b/src/runtime/wgpu/shaders/arange_i32.wgsl @@ -0,0 +1,21 @@ +// Auto-generated arange operation for i32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct ArangeParams { + numel: u32, + start: f32, + step: f32, +} + +@group(0) @binding(0) var arange_out: array; +@group(0) @binding(1) var arange_params: ArangeParams; + +@compute @workgroup_size(256) +fn arange_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < arange_params.numel) { + let value = arange_params.start + arange_params.step * f32(idx); + arange_out[idx] = i32(value); + } +} diff --git a/src/runtime/wgpu/shaders/arange_u32.wgsl b/src/runtime/wgpu/shaders/arange_u32.wgsl new file mode 100644 index 00000000..3cb3473a --- /dev/null +++ b/src/runtime/wgpu/shaders/arange_u32.wgsl @@ -0,0 +1,21 @@ +// Auto-generated arange operation for u32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct ArangeParams { + numel: u32, + start: f32, + step: f32, +} + +@group(0) @binding(0) var arange_out: array; +@group(0) @binding(1) var arange_params: ArangeParams; + +@compute @workgroup_size(256) +fn arange_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < arange_params.numel) { + let value = arange_params.start + arange_params.step * f32(idx); + arange_out[idx] = u32(value); + } +} diff --git a/src/runtime/wgpu/shaders/bernoulli_f32.wgsl b/src/runtime/wgpu/shaders/bernoulli_f32.wgsl new file mode 100644 index 00000000..efdc4b2b --- /dev/null +++ b/src/runtime/wgpu/shaders/bernoulli_f32.wgsl @@ -0,0 +1,39 @@ +// Bernoulli distribution sampling for f32 + +// PCG hash function for random number generation +fn pcg_hash(input: u32) -> u32 { + var state = input * 747796405u + 2891336453u; + var word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u; + return (word >> 22u) ^ word; +} + +fn pcg_init(seed: u32, idx: u32) -> u32 { + return pcg_hash(seed ^ pcg_hash(idx)); +} + +fn pcg_uniform(state: ptr) -> f32 { + *state = pcg_hash(*state); + return f32(*state) / 4294967296.0; +} + +const WORKGROUP_SIZE: u32 = 256u; + +struct BernoulliParams { + numel: u32, + seed: u32, + p: f32, + _pad: u32, +} + +@group(0) @binding(0) var out: array; +@group(0) @binding(1) var params: BernoulliParams; + +@compute @workgroup_size(256) +fn bernoulli_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if idx < params.numel { + var state = pcg_init(params.seed, idx); + let u = pcg_uniform(&state); + out[idx] = select(f32(0.0), f32(1.0), u < params.p); + } +} diff --git a/src/runtime/wgpu/shaders/beta_dist_f32.wgsl b/src/runtime/wgpu/shaders/beta_dist_f32.wgsl new file mode 100644 index 00000000..06b834e9 --- /dev/null +++ b/src/runtime/wgpu/shaders/beta_dist_f32.wgsl @@ -0,0 +1,92 @@ +// Beta distribution sampling for f32 + +// PCG hash function for random number generation +fn pcg_hash(input: u32) -> u32 { + var state = input * 747796405u + 2891336453u; + var word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u; + return (word >> 22u) ^ word; +} + +fn pcg_init(seed: u32, idx: u32) -> u32 { + return pcg_hash(seed ^ pcg_hash(idx)); +} + +fn pcg_uniform(state: ptr) -> f32 { + *state = pcg_hash(*state); + return f32(*state) / 4294967296.0; +} + +// Box-Muller for normal distribution +fn sample_normal(state: ptr) -> f32 { + let u1 = max(pcg_uniform(state), 0.0000001); + let u2 = pcg_uniform(state); + return sqrt(-2.0 * log(u1)) * cos(6.28318530718 * u2); +} + +// Gamma via Marsaglia-Tsang method +fn sample_gamma_mt(state: ptr, shape: f32, scale: f32) -> f32 { + var alpha = shape; + var boost = 1.0; + + // Handle shape < 1 by boosting + if alpha < 1.0 { + boost = pow(pcg_uniform(state), 1.0 / alpha); + alpha = alpha + 1.0; + } + + let d = alpha - 1.0 / 3.0; + let c = 1.0 / sqrt(9.0 * d); + + // Rejection sampling + for (var i = 0u; i < 100u; i = i + 1u) { + var x: f32; + var v: f32; + + // Generate valid v + for (var j = 0u; j < 100u; j = j + 1u) { + x = sample_normal(state); + v = 1.0 + c * x; + if v > 0.0 { + break; + } + } + + v = v * v * v; + let u = pcg_uniform(state); + let x2 = x * x; + + // Accept/reject + if u < 1.0 - 0.0331 * x2 * x2 { + return d * v * boost * scale; + } + if log(u) < 0.5 * x2 + d * (1.0 - v + log(v)) { + return d * v * boost * scale; + } + } + + // Fallback (should rarely reach) + return d * boost * scale; +} + +const WORKGROUP_SIZE: u32 = 256u; + +struct BetaParams { + numel: u32, + seed: u32, + alpha: f32, + beta: f32, +} + +@group(0) @binding(0) var out: array; +@group(0) @binding(1) var params: BetaParams; + +@compute @workgroup_size(256) +fn beta_dist_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if idx < params.numel { + var state = pcg_init(params.seed, idx); + let x = sample_gamma_mt(&state, params.alpha, 1.0); + let y = sample_gamma_mt(&state, params.beta, 1.0); + out[idx] = f32(x / (x + y)); + } +} diff --git a/src/runtime/wgpu/shaders/binary.wgsl b/src/runtime/wgpu/shaders/binary.wgsl new file mode 100644 index 00000000..f71c8d78 --- /dev/null +++ b/src/runtime/wgpu/shaders/binary.wgsl @@ -0,0 +1,76 @@ +// F32 binary operations + +const WORKGROUP_SIZE: u32 = 256u; + +struct BinaryParams { + numel: u32, +} + +@group(0) @binding(0) var binary_a: array; +@group(0) @binding(1) var binary_b: array; +@group(0) @binding(2) var binary_out: array; +@group(0) @binding(3) var binary_params: BinaryParams; + +@compute @workgroup_size(256) +fn add_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < binary_params.numel) { + binary_out[idx] = binary_a[idx] + binary_b[idx]; + } +} + +@compute @workgroup_size(256) +fn sub_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < binary_params.numel) { + binary_out[idx] = binary_a[idx] - binary_b[idx]; + } +} + +@compute @workgroup_size(256) +fn mul_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < binary_params.numel) { + binary_out[idx] = binary_a[idx] * binary_b[idx]; + } +} + +@compute @workgroup_size(256) +fn div_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < binary_params.numel) { + binary_out[idx] = binary_a[idx] / binary_b[idx]; + } +} + +@compute @workgroup_size(256) +fn max_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < binary_params.numel) { + binary_out[idx] = max(binary_a[idx], binary_b[idx]); + } +} + +@compute @workgroup_size(256) +fn min_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < binary_params.numel) { + binary_out[idx] = min(binary_a[idx], binary_b[idx]); + } +} + +@compute @workgroup_size(256) +fn pow_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < binary_params.numel) { + binary_out[idx] = pow(binary_a[idx], binary_b[idx]); + } +} + +@compute @workgroup_size(256) +fn atan2_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < binary_params.numel) { + binary_out[idx] = atan2(binary_a[idx], binary_b[idx]); + } +} diff --git a/src/runtime/wgpu/shaders/binary_broadcast.wgsl b/src/runtime/wgpu/shaders/binary_broadcast.wgsl new file mode 100644 index 00000000..94b18716 --- /dev/null +++ b/src/runtime/wgpu/shaders/binary_broadcast.wgsl @@ -0,0 +1,177 @@ +// F32 broadcast binary operations + +const WORKGROUP_SIZE: u32 = 256u; + +struct BroadcastBinaryParams { + numel: u32, + ndim: u32, +} + +@group(0) @binding(0) var broadcast_a: array; +@group(0) @binding(1) var broadcast_b: array; +@group(0) @binding(2) var broadcast_out: array; +@group(0) @binding(3) var broadcast_a_strides: array; +@group(0) @binding(4) var broadcast_b_strides: array; +@group(0) @binding(5) var broadcast_out_strides: array; +@group(0) @binding(6) var broadcast_params: BroadcastBinaryParams; + +@compute @workgroup_size(256) +fn broadcast_add_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= broadcast_params.numel) { + return; + } + + var remaining = idx; + var a_offset: u32 = 0u; + var b_offset: u32 = 0u; + + for (var d: u32 = 0u; d < broadcast_params.ndim; d = d + 1u) { + let stride = broadcast_out_strides[d]; + let coord = remaining / stride; + remaining = remaining % stride; + + a_offset = a_offset + coord * broadcast_a_strides[d]; + b_offset = b_offset + coord * broadcast_b_strides[d]; + } + + broadcast_out[idx] = broadcast_a[a_offset] + broadcast_b[b_offset]; +} + +@compute @workgroup_size(256) +fn broadcast_sub_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= broadcast_params.numel) { + return; + } + + var remaining = idx; + var a_offset: u32 = 0u; + var b_offset: u32 = 0u; + + for (var d: u32 = 0u; d < broadcast_params.ndim; d = d + 1u) { + let stride = broadcast_out_strides[d]; + let coord = remaining / stride; + remaining = remaining % stride; + + a_offset = a_offset + coord * broadcast_a_strides[d]; + b_offset = b_offset + coord * broadcast_b_strides[d]; + } + + broadcast_out[idx] = broadcast_a[a_offset] - broadcast_b[b_offset]; +} + +@compute @workgroup_size(256) +fn broadcast_mul_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= broadcast_params.numel) { + return; + } + + var remaining = idx; + var a_offset: u32 = 0u; + var b_offset: u32 = 0u; + + for (var d: u32 = 0u; d < broadcast_params.ndim; d = d + 1u) { + let stride = broadcast_out_strides[d]; + let coord = remaining / stride; + remaining = remaining % stride; + + a_offset = a_offset + coord * broadcast_a_strides[d]; + b_offset = b_offset + coord * broadcast_b_strides[d]; + } + + broadcast_out[idx] = broadcast_a[a_offset] * broadcast_b[b_offset]; +} + +@compute @workgroup_size(256) +fn broadcast_div_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= broadcast_params.numel) { + return; + } + + var remaining = idx; + var a_offset: u32 = 0u; + var b_offset: u32 = 0u; + + for (var d: u32 = 0u; d < broadcast_params.ndim; d = d + 1u) { + let stride = broadcast_out_strides[d]; + let coord = remaining / stride; + remaining = remaining % stride; + + a_offset = a_offset + coord * broadcast_a_strides[d]; + b_offset = b_offset + coord * broadcast_b_strides[d]; + } + + broadcast_out[idx] = broadcast_a[a_offset] / broadcast_b[b_offset]; +} + +@compute @workgroup_size(256) +fn broadcast_max_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= broadcast_params.numel) { + return; + } + + var remaining = idx; + var a_offset: u32 = 0u; + var b_offset: u32 = 0u; + + for (var d: u32 = 0u; d < broadcast_params.ndim; d = d + 1u) { + let stride = broadcast_out_strides[d]; + let coord = remaining / stride; + remaining = remaining % stride; + + a_offset = a_offset + coord * broadcast_a_strides[d]; + b_offset = b_offset + coord * broadcast_b_strides[d]; + } + + broadcast_out[idx] = max(broadcast_a[a_offset], broadcast_b[b_offset]); +} + +@compute @workgroup_size(256) +fn broadcast_min_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= broadcast_params.numel) { + return; + } + + var remaining = idx; + var a_offset: u32 = 0u; + var b_offset: u32 = 0u; + + for (var d: u32 = 0u; d < broadcast_params.ndim; d = d + 1u) { + let stride = broadcast_out_strides[d]; + let coord = remaining / stride; + remaining = remaining % stride; + + a_offset = a_offset + coord * broadcast_a_strides[d]; + b_offset = b_offset + coord * broadcast_b_strides[d]; + } + + broadcast_out[idx] = min(broadcast_a[a_offset], broadcast_b[b_offset]); +} + +@compute @workgroup_size(256) +fn broadcast_pow_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= broadcast_params.numel) { + return; + } + + var remaining = idx; + var a_offset: u32 = 0u; + var b_offset: u32 = 0u; + + for (var d: u32 = 0u; d < broadcast_params.ndim; d = d + 1u) { + let stride = broadcast_out_strides[d]; + let coord = remaining / stride; + remaining = remaining % stride; + + a_offset = a_offset + coord * broadcast_a_strides[d]; + b_offset = b_offset + coord * broadcast_b_strides[d]; + } + + broadcast_out[idx] = pow(broadcast_a[a_offset], broadcast_b[b_offset]); +} diff --git a/src/runtime/wgpu/shaders/binary_broadcast_i32.wgsl b/src/runtime/wgpu/shaders/binary_broadcast_i32.wgsl new file mode 100644 index 00000000..3ded637f --- /dev/null +++ b/src/runtime/wgpu/shaders/binary_broadcast_i32.wgsl @@ -0,0 +1,116 @@ +// I32 broadcast binary operations + +struct BroadcastBinaryParams { + numel: u32, + ndim: u32, +} + +@group(0) @binding(0) var broadcast_a: array; +@group(0) @binding(1) var broadcast_b: array; +@group(0) @binding(2) var broadcast_out: array; +@group(0) @binding(3) var broadcast_a_strides: array; +@group(0) @binding(4) var broadcast_b_strides: array; +@group(0) @binding(5) var broadcast_out_strides: array; +@group(0) @binding(6) var broadcast_params: BroadcastBinaryParams; + +@compute @workgroup_size(256) +fn broadcast_add_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= broadcast_params.numel) { return; } + var remaining = idx; + var a_offset: u32 = 0u; + var b_offset: u32 = 0u; + for (var d: u32 = 0u; d < broadcast_params.ndim; d = d + 1u) { + let stride = broadcast_out_strides[d]; + let coord = remaining / stride; + remaining = remaining % stride; + a_offset = a_offset + coord * broadcast_a_strides[d]; + b_offset = b_offset + coord * broadcast_b_strides[d]; + } + broadcast_out[idx] = broadcast_a[a_offset] + broadcast_b[b_offset]; +} + +@compute @workgroup_size(256) +fn broadcast_sub_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= broadcast_params.numel) { return; } + var remaining = idx; + var a_offset: u32 = 0u; + var b_offset: u32 = 0u; + for (var d: u32 = 0u; d < broadcast_params.ndim; d = d + 1u) { + let stride = broadcast_out_strides[d]; + let coord = remaining / stride; + remaining = remaining % stride; + a_offset = a_offset + coord * broadcast_a_strides[d]; + b_offset = b_offset + coord * broadcast_b_strides[d]; + } + broadcast_out[idx] = broadcast_a[a_offset] - broadcast_b[b_offset]; +} + +@compute @workgroup_size(256) +fn broadcast_mul_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= broadcast_params.numel) { return; } + var remaining = idx; + var a_offset: u32 = 0u; + var b_offset: u32 = 0u; + for (var d: u32 = 0u; d < broadcast_params.ndim; d = d + 1u) { + let stride = broadcast_out_strides[d]; + let coord = remaining / stride; + remaining = remaining % stride; + a_offset = a_offset + coord * broadcast_a_strides[d]; + b_offset = b_offset + coord * broadcast_b_strides[d]; + } + broadcast_out[idx] = broadcast_a[a_offset] * broadcast_b[b_offset]; +} + +@compute @workgroup_size(256) +fn broadcast_div_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= broadcast_params.numel) { return; } + var remaining = idx; + var a_offset: u32 = 0u; + var b_offset: u32 = 0u; + for (var d: u32 = 0u; d < broadcast_params.ndim; d = d + 1u) { + let stride = broadcast_out_strides[d]; + let coord = remaining / stride; + remaining = remaining % stride; + a_offset = a_offset + coord * broadcast_a_strides[d]; + b_offset = b_offset + coord * broadcast_b_strides[d]; + } + broadcast_out[idx] = broadcast_a[a_offset] / broadcast_b[b_offset]; +} + +@compute @workgroup_size(256) +fn broadcast_max_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= broadcast_params.numel) { return; } + var remaining = idx; + var a_offset: u32 = 0u; + var b_offset: u32 = 0u; + for (var d: u32 = 0u; d < broadcast_params.ndim; d = d + 1u) { + let stride = broadcast_out_strides[d]; + let coord = remaining / stride; + remaining = remaining % stride; + a_offset = a_offset + coord * broadcast_a_strides[d]; + b_offset = b_offset + coord * broadcast_b_strides[d]; + } + broadcast_out[idx] = max(broadcast_a[a_offset], broadcast_b[b_offset]); +} + +@compute @workgroup_size(256) +fn broadcast_min_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= broadcast_params.numel) { return; } + var remaining = idx; + var a_offset: u32 = 0u; + var b_offset: u32 = 0u; + for (var d: u32 = 0u; d < broadcast_params.ndim; d = d + 1u) { + let stride = broadcast_out_strides[d]; + let coord = remaining / stride; + remaining = remaining % stride; + a_offset = a_offset + coord * broadcast_a_strides[d]; + b_offset = b_offset + coord * broadcast_b_strides[d]; + } + broadcast_out[idx] = min(broadcast_a[a_offset], broadcast_b[b_offset]); +} diff --git a/src/runtime/wgpu/shaders/binary_broadcast_u32.wgsl b/src/runtime/wgpu/shaders/binary_broadcast_u32.wgsl new file mode 100644 index 00000000..60136e9e --- /dev/null +++ b/src/runtime/wgpu/shaders/binary_broadcast_u32.wgsl @@ -0,0 +1,116 @@ +// U32 broadcast binary operations + +struct BroadcastBinaryParams { + numel: u32, + ndim: u32, +} + +@group(0) @binding(0) var broadcast_a: array; +@group(0) @binding(1) var broadcast_b: array; +@group(0) @binding(2) var broadcast_out: array; +@group(0) @binding(3) var broadcast_a_strides: array; +@group(0) @binding(4) var broadcast_b_strides: array; +@group(0) @binding(5) var broadcast_out_strides: array; +@group(0) @binding(6) var broadcast_params: BroadcastBinaryParams; + +@compute @workgroup_size(256) +fn broadcast_add_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= broadcast_params.numel) { return; } + var remaining = idx; + var a_offset: u32 = 0u; + var b_offset: u32 = 0u; + for (var d: u32 = 0u; d < broadcast_params.ndim; d = d + 1u) { + let stride = broadcast_out_strides[d]; + let coord = remaining / stride; + remaining = remaining % stride; + a_offset = a_offset + coord * broadcast_a_strides[d]; + b_offset = b_offset + coord * broadcast_b_strides[d]; + } + broadcast_out[idx] = broadcast_a[a_offset] + broadcast_b[b_offset]; +} + +@compute @workgroup_size(256) +fn broadcast_sub_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= broadcast_params.numel) { return; } + var remaining = idx; + var a_offset: u32 = 0u; + var b_offset: u32 = 0u; + for (var d: u32 = 0u; d < broadcast_params.ndim; d = d + 1u) { + let stride = broadcast_out_strides[d]; + let coord = remaining / stride; + remaining = remaining % stride; + a_offset = a_offset + coord * broadcast_a_strides[d]; + b_offset = b_offset + coord * broadcast_b_strides[d]; + } + broadcast_out[idx] = broadcast_a[a_offset] - broadcast_b[b_offset]; +} + +@compute @workgroup_size(256) +fn broadcast_mul_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= broadcast_params.numel) { return; } + var remaining = idx; + var a_offset: u32 = 0u; + var b_offset: u32 = 0u; + for (var d: u32 = 0u; d < broadcast_params.ndim; d = d + 1u) { + let stride = broadcast_out_strides[d]; + let coord = remaining / stride; + remaining = remaining % stride; + a_offset = a_offset + coord * broadcast_a_strides[d]; + b_offset = b_offset + coord * broadcast_b_strides[d]; + } + broadcast_out[idx] = broadcast_a[a_offset] * broadcast_b[b_offset]; +} + +@compute @workgroup_size(256) +fn broadcast_div_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= broadcast_params.numel) { return; } + var remaining = idx; + var a_offset: u32 = 0u; + var b_offset: u32 = 0u; + for (var d: u32 = 0u; d < broadcast_params.ndim; d = d + 1u) { + let stride = broadcast_out_strides[d]; + let coord = remaining / stride; + remaining = remaining % stride; + a_offset = a_offset + coord * broadcast_a_strides[d]; + b_offset = b_offset + coord * broadcast_b_strides[d]; + } + broadcast_out[idx] = broadcast_a[a_offset] / broadcast_b[b_offset]; +} + +@compute @workgroup_size(256) +fn broadcast_max_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= broadcast_params.numel) { return; } + var remaining = idx; + var a_offset: u32 = 0u; + var b_offset: u32 = 0u; + for (var d: u32 = 0u; d < broadcast_params.ndim; d = d + 1u) { + let stride = broadcast_out_strides[d]; + let coord = remaining / stride; + remaining = remaining % stride; + a_offset = a_offset + coord * broadcast_a_strides[d]; + b_offset = b_offset + coord * broadcast_b_strides[d]; + } + broadcast_out[idx] = max(broadcast_a[a_offset], broadcast_b[b_offset]); +} + +@compute @workgroup_size(256) +fn broadcast_min_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= broadcast_params.numel) { return; } + var remaining = idx; + var a_offset: u32 = 0u; + var b_offset: u32 = 0u; + for (var d: u32 = 0u; d < broadcast_params.ndim; d = d + 1u) { + let stride = broadcast_out_strides[d]; + let coord = remaining / stride; + remaining = remaining % stride; + a_offset = a_offset + coord * broadcast_a_strides[d]; + b_offset = b_offset + coord * broadcast_b_strides[d]; + } + broadcast_out[idx] = min(broadcast_a[a_offset], broadcast_b[b_offset]); +} diff --git a/src/runtime/wgpu/shaders/binary_i32.wgsl b/src/runtime/wgpu/shaders/binary_i32.wgsl new file mode 100644 index 00000000..4f9e984f --- /dev/null +++ b/src/runtime/wgpu/shaders/binary_i32.wgsl @@ -0,0 +1,58 @@ +// I32 binary operations + +struct BinaryParams { + numel: u32, +} + +@group(0) @binding(0) var binary_a: array; +@group(0) @binding(1) var binary_b: array; +@group(0) @binding(2) var binary_out: array; +@group(0) @binding(3) var binary_params: BinaryParams; + +@compute @workgroup_size(256) +fn add_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < binary_params.numel) { + binary_out[idx] = binary_a[idx] + binary_b[idx]; + } +} + +@compute @workgroup_size(256) +fn sub_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < binary_params.numel) { + binary_out[idx] = binary_a[idx] - binary_b[idx]; + } +} + +@compute @workgroup_size(256) +fn mul_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < binary_params.numel) { + binary_out[idx] = binary_a[idx] * binary_b[idx]; + } +} + +@compute @workgroup_size(256) +fn div_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < binary_params.numel) { + binary_out[idx] = binary_a[idx] / binary_b[idx]; + } +} + +@compute @workgroup_size(256) +fn max_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < binary_params.numel) { + binary_out[idx] = max(binary_a[idx], binary_b[idx]); + } +} + +@compute @workgroup_size(256) +fn min_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < binary_params.numel) { + binary_out[idx] = min(binary_a[idx], binary_b[idx]); + } +} diff --git a/src/runtime/wgpu/shaders/binary_u32.wgsl b/src/runtime/wgpu/shaders/binary_u32.wgsl new file mode 100644 index 00000000..01dd2adf --- /dev/null +++ b/src/runtime/wgpu/shaders/binary_u32.wgsl @@ -0,0 +1,58 @@ +// U32 binary operations + +struct BinaryParams { + numel: u32, +} + +@group(0) @binding(0) var binary_a: array; +@group(0) @binding(1) var binary_b: array; +@group(0) @binding(2) var binary_out: array; +@group(0) @binding(3) var binary_params: BinaryParams; + +@compute @workgroup_size(256) +fn add_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < binary_params.numel) { + binary_out[idx] = binary_a[idx] + binary_b[idx]; + } +} + +@compute @workgroup_size(256) +fn sub_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < binary_params.numel) { + binary_out[idx] = binary_a[idx] - binary_b[idx]; + } +} + +@compute @workgroup_size(256) +fn mul_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < binary_params.numel) { + binary_out[idx] = binary_a[idx] * binary_b[idx]; + } +} + +@compute @workgroup_size(256) +fn div_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < binary_params.numel) { + binary_out[idx] = binary_a[idx] / binary_b[idx]; + } +} + +@compute @workgroup_size(256) +fn max_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < binary_params.numel) { + binary_out[idx] = max(binary_a[idx], binary_b[idx]); + } +} + +@compute @workgroup_size(256) +fn min_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < binary_params.numel) { + binary_out[idx] = min(binary_a[idx], binary_b[idx]); + } +} diff --git a/src/runtime/wgpu/shaders/bincount_i32.wgsl b/src/runtime/wgpu/shaders/bincount_i32.wgsl new file mode 100644 index 00000000..8c99a06d --- /dev/null +++ b/src/runtime/wgpu/shaders/bincount_i32.wgsl @@ -0,0 +1,29 @@ +// Auto-generated unweighted bincount + +const WORKGROUP_SIZE: u32 = 256u; + +struct BincountParams { + n: u32, + minlength: u32, + _pad0: u32, + _pad1: u32, +} + +@group(0) @binding(0) var bincount_input: array; +@group(0) @binding(1) var bincount_output: array>; +@group(0) @binding(2) var bincount_params: BincountParams; + +@compute @workgroup_size(256) +fn bincount_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= bincount_params.n) { + return; + } + + let value = bincount_input[idx]; + if (value < 0 || u32(value) >= bincount_params.minlength) { + return; + } + + atomicAdd(&bincount_output[u32(value)], 1u); +} diff --git a/src/runtime/wgpu/shaders/bincount_weighted_f32.wgsl b/src/runtime/wgpu/shaders/bincount_weighted_f32.wgsl new file mode 100644 index 00000000..a9c265f5 --- /dev/null +++ b/src/runtime/wgpu/shaders/bincount_weighted_f32.wgsl @@ -0,0 +1,34 @@ +// Auto-generated weighted bincount for f32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct BincountParams { + n: u32, + minlength: u32, + _pad0: u32, + _pad1: u32, +} + +@group(0) @binding(0) var bincount_input: array; +@group(0) @binding(1) var bincount_weights: array; +@group(0) @binding(2) var bincount_output: array>; +@group(0) @binding(3) var bincount_params: BincountParams; + +@compute @workgroup_size(256) +fn bincount_weighted_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= bincount_params.n) { + return; + } + + let value = bincount_input[idx]; + if (value < 0 || u32(value) >= bincount_params.minlength) { + return; + } + + let weight = bincount_weights[idx]; + // For float weights, we need to use atomic operations + // WebGPU only supports atomic ops on u32/i32, so we use bitcast + let weight_bits = bitcast(weight); + atomicAdd(&bincount_output[u32(value)], weight_bits); +} diff --git a/src/runtime/wgpu/shaders/binomial_f32.wgsl b/src/runtime/wgpu/shaders/binomial_f32.wgsl new file mode 100644 index 00000000..4eab1365 --- /dev/null +++ b/src/runtime/wgpu/shaders/binomial_f32.wgsl @@ -0,0 +1,65 @@ +// Binomial distribution sampling for f32 + +// PCG hash function for random number generation +fn pcg_hash(input: u32) -> u32 { + var state = input * 747796405u + 2891336453u; + var word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u; + return (word >> 22u) ^ word; +} + +fn pcg_init(seed: u32, idx: u32) -> u32 { + return pcg_hash(seed ^ pcg_hash(idx)); +} + +fn pcg_uniform(state: ptr) -> f32 { + *state = pcg_hash(*state); + return f32(*state) / 4294967296.0; +} + +// Box-Muller for normal distribution +fn sample_normal(state: ptr) -> f32 { + let u1 = max(pcg_uniform(state), 0.0000001); + let u2 = pcg_uniform(state); + return sqrt(-2.0 * log(u1)) * cos(6.28318530718 * u2); +} + +const WORKGROUP_SIZE: u32 = 256u; + +struct BinomialParams { + numel: u32, + seed: u32, + n_trials: u32, + p: f32, +} + +@group(0) @binding(0) var out: array; +@group(0) @binding(1) var params: BinomialParams; + +@compute @workgroup_size(256) +fn binomial_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if idx < params.numel { + var state = pcg_init(params.seed, idx); + + let n = params.n_trials; + let p = params.p; + + // Direct simulation for small n + if n <= 64u { + var successes = 0u; + for (var i = 0u; i < n; i = i + 1u) { + if pcg_uniform(&state) < p { + successes = successes + 1u; + } + } + out[idx] = f32(f32(successes)); + } else { + // Normal approximation for large n + let mean = f32(n) * p; + let std_dev = sqrt(mean * (1.0 - p)); + let z = sample_normal(&state); + let result = clamp(round(mean + std_dev * z), 0.0, f32(n)); + out[idx] = f32(result); + } + } +} diff --git a/src/runtime/wgpu/shaders/cast_f32_to_i32.wgsl b/src/runtime/wgpu/shaders/cast_f32_to_i32.wgsl new file mode 100644 index 00000000..bb81a50e --- /dev/null +++ b/src/runtime/wgpu/shaders/cast_f32_to_i32.wgsl @@ -0,0 +1,19 @@ +// F32 to I32 cast operation + +const WORKGROUP_SIZE: u32 = 256u; + +struct CastParams { + numel: u32, +} + +@group(0) @binding(0) var cast_input: array; +@group(0) @binding(1) var cast_output: array; +@group(0) @binding(2) var cast_params: CastParams; + +@compute @workgroup_size(256) +fn cast_f32_to_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < cast_params.numel) { + cast_output[idx] = i32(cast_input[idx]); + } +} diff --git a/src/runtime/wgpu/shaders/cast_f32_to_u32.wgsl b/src/runtime/wgpu/shaders/cast_f32_to_u32.wgsl new file mode 100644 index 00000000..21efd791 --- /dev/null +++ b/src/runtime/wgpu/shaders/cast_f32_to_u32.wgsl @@ -0,0 +1,19 @@ +// F32 to U32 cast operation + +const WORKGROUP_SIZE: u32 = 256u; + +struct CastParams { + numel: u32, +} + +@group(0) @binding(0) var cast_input: array; +@group(0) @binding(1) var cast_output: array; +@group(0) @binding(2) var cast_params: CastParams; + +@compute @workgroup_size(256) +fn cast_f32_to_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < cast_params.numel) { + cast_output[idx] = u32(cast_input[idx]); + } +} diff --git a/src/runtime/wgpu/shaders/cast_i32_to_f32.wgsl b/src/runtime/wgpu/shaders/cast_i32_to_f32.wgsl new file mode 100644 index 00000000..ca6f820a --- /dev/null +++ b/src/runtime/wgpu/shaders/cast_i32_to_f32.wgsl @@ -0,0 +1,19 @@ +// I32 to F32 cast operation + +const WORKGROUP_SIZE: u32 = 256u; + +struct CastParams { + numel: u32, +} + +@group(0) @binding(0) var cast_input: array; +@group(0) @binding(1) var cast_output: array; +@group(0) @binding(2) var cast_params: CastParams; + +@compute @workgroup_size(256) +fn cast_i32_to_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < cast_params.numel) { + cast_output[idx] = f32(cast_input[idx]); + } +} diff --git a/src/runtime/wgpu/shaders/cast_i32_to_u32.wgsl b/src/runtime/wgpu/shaders/cast_i32_to_u32.wgsl new file mode 100644 index 00000000..348c7ac7 --- /dev/null +++ b/src/runtime/wgpu/shaders/cast_i32_to_u32.wgsl @@ -0,0 +1,19 @@ +// I32 to U32 cast operation + +const WORKGROUP_SIZE: u32 = 256u; + +struct CastParams { + numel: u32, +} + +@group(0) @binding(0) var cast_input: array; +@group(0) @binding(1) var cast_output: array; +@group(0) @binding(2) var cast_params: CastParams; + +@compute @workgroup_size(256) +fn cast_i32_to_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < cast_params.numel) { + cast_output[idx] = u32(cast_input[idx]); + } +} diff --git a/src/runtime/wgpu/shaders/cast_u32_to_f32.wgsl b/src/runtime/wgpu/shaders/cast_u32_to_f32.wgsl new file mode 100644 index 00000000..aa097e4e --- /dev/null +++ b/src/runtime/wgpu/shaders/cast_u32_to_f32.wgsl @@ -0,0 +1,19 @@ +// U32 to F32 cast operation + +const WORKGROUP_SIZE: u32 = 256u; + +struct CastParams { + numel: u32, +} + +@group(0) @binding(0) var cast_input: array; +@group(0) @binding(1) var cast_output: array; +@group(0) @binding(2) var cast_params: CastParams; + +@compute @workgroup_size(256) +fn cast_u32_to_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < cast_params.numel) { + cast_output[idx] = f32(cast_input[idx]); + } +} diff --git a/src/runtime/wgpu/shaders/cast_u32_to_i32.wgsl b/src/runtime/wgpu/shaders/cast_u32_to_i32.wgsl new file mode 100644 index 00000000..862bb08a --- /dev/null +++ b/src/runtime/wgpu/shaders/cast_u32_to_i32.wgsl @@ -0,0 +1,19 @@ +// U32 to I32 cast operation + +const WORKGROUP_SIZE: u32 = 256u; + +struct CastParams { + numel: u32, +} + +@group(0) @binding(0) var cast_input: array; +@group(0) @binding(1) var cast_output: array; +@group(0) @binding(2) var cast_params: CastParams; + +@compute @workgroup_size(256) +fn cast_u32_to_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < cast_params.numel) { + cast_output[idx] = i32(cast_input[idx]); + } +} diff --git a/src/runtime/wgpu/shaders/cat_copy_f32.wgsl b/src/runtime/wgpu/shaders/cat_copy_f32.wgsl new file mode 100644 index 00000000..814f3d84 --- /dev/null +++ b/src/runtime/wgpu/shaders/cat_copy_f32.wgsl @@ -0,0 +1,37 @@ +// Auto-generated cat operations for f32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct CatParams { + outer_size: u32, + src_cat_size: u32, + dst_cat_size: u32, + cat_offset: u32, + inner_size: u32, + total_elements: u32, +} + +@group(0) @binding(0) var cat_src: array; +@group(0) @binding(1) var cat_dst: array; +@group(0) @binding(2) var cat_params: CatParams; + +@compute @workgroup_size(256) +fn cat_copy_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= cat_params.total_elements) { + return; + } + + // Decompose idx into (outer, cat_i, inner) for source tensor + let inner = idx % cat_params.inner_size; + let remaining = idx / cat_params.inner_size; + let cat_i = remaining % cat_params.src_cat_size; + let outer = remaining / cat_params.src_cat_size; + + // Compute destination index + let dst_idx = outer * cat_params.dst_cat_size * cat_params.inner_size + + (cat_params.cat_offset + cat_i) * cat_params.inner_size + + inner; + + cat_dst[dst_idx] = cat_src[idx]; +} diff --git a/src/runtime/wgpu/shaders/cat_copy_i32.wgsl b/src/runtime/wgpu/shaders/cat_copy_i32.wgsl new file mode 100644 index 00000000..2a6e114e --- /dev/null +++ b/src/runtime/wgpu/shaders/cat_copy_i32.wgsl @@ -0,0 +1,37 @@ +// Auto-generated cat operations for i32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct CatParams { + outer_size: u32, + src_cat_size: u32, + dst_cat_size: u32, + cat_offset: u32, + inner_size: u32, + total_elements: u32, +} + +@group(0) @binding(0) var cat_src: array; +@group(0) @binding(1) var cat_dst: array; +@group(0) @binding(2) var cat_params: CatParams; + +@compute @workgroup_size(256) +fn cat_copy_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= cat_params.total_elements) { + return; + } + + // Decompose idx into (outer, cat_i, inner) for source tensor + let inner = idx % cat_params.inner_size; + let remaining = idx / cat_params.inner_size; + let cat_i = remaining % cat_params.src_cat_size; + let outer = remaining / cat_params.src_cat_size; + + // Compute destination index + let dst_idx = outer * cat_params.dst_cat_size * cat_params.inner_size + + (cat_params.cat_offset + cat_i) * cat_params.inner_size + + inner; + + cat_dst[dst_idx] = cat_src[idx]; +} diff --git a/src/runtime/wgpu/shaders/cat_copy_u32.wgsl b/src/runtime/wgpu/shaders/cat_copy_u32.wgsl new file mode 100644 index 00000000..232065a7 --- /dev/null +++ b/src/runtime/wgpu/shaders/cat_copy_u32.wgsl @@ -0,0 +1,37 @@ +// Auto-generated cat operations for u32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct CatParams { + outer_size: u32, + src_cat_size: u32, + dst_cat_size: u32, + cat_offset: u32, + inner_size: u32, + total_elements: u32, +} + +@group(0) @binding(0) var cat_src: array; +@group(0) @binding(1) var cat_dst: array; +@group(0) @binding(2) var cat_params: CatParams; + +@compute @workgroup_size(256) +fn cat_copy_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= cat_params.total_elements) { + return; + } + + // Decompose idx into (outer, cat_i, inner) for source tensor + let inner = idx % cat_params.inner_size; + let remaining = idx / cat_params.inner_size; + let cat_i = remaining % cat_params.src_cat_size; + let outer = remaining / cat_params.src_cat_size; + + // Compute destination index + let dst_idx = outer * cat_params.dst_cat_size * cat_params.inner_size + + (cat_params.cat_offset + cat_i) * cat_params.inner_size + + inner; + + cat_dst[dst_idx] = cat_src[idx]; +} diff --git a/src/runtime/wgpu/shaders/chi_squared_f32.wgsl b/src/runtime/wgpu/shaders/chi_squared_f32.wgsl new file mode 100644 index 00000000..1d1f077e --- /dev/null +++ b/src/runtime/wgpu/shaders/chi_squared_f32.wgsl @@ -0,0 +1,91 @@ +// Chi-squared distribution sampling for f32 + +// PCG hash function for random number generation +fn pcg_hash(input: u32) -> u32 { + var state = input * 747796405u + 2891336453u; + var word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u; + return (word >> 22u) ^ word; +} + +fn pcg_init(seed: u32, idx: u32) -> u32 { + return pcg_hash(seed ^ pcg_hash(idx)); +} + +fn pcg_uniform(state: ptr) -> f32 { + *state = pcg_hash(*state); + return f32(*state) / 4294967296.0; +} + +// Box-Muller for normal distribution +fn sample_normal(state: ptr) -> f32 { + let u1 = max(pcg_uniform(state), 0.0000001); + let u2 = pcg_uniform(state); + return sqrt(-2.0 * log(u1)) * cos(6.28318530718 * u2); +} + +// Gamma via Marsaglia-Tsang method +fn sample_gamma_mt(state: ptr, shape: f32, scale: f32) -> f32 { + var alpha = shape; + var boost = 1.0; + + // Handle shape < 1 by boosting + if alpha < 1.0 { + boost = pow(pcg_uniform(state), 1.0 / alpha); + alpha = alpha + 1.0; + } + + let d = alpha - 1.0 / 3.0; + let c = 1.0 / sqrt(9.0 * d); + + // Rejection sampling + for (var i = 0u; i < 100u; i = i + 1u) { + var x: f32; + var v: f32; + + // Generate valid v + for (var j = 0u; j < 100u; j = j + 1u) { + x = sample_normal(state); + v = 1.0 + c * x; + if v > 0.0 { + break; + } + } + + v = v * v * v; + let u = pcg_uniform(state); + let x2 = x * x; + + // Accept/reject + if u < 1.0 - 0.0331 * x2 * x2 { + return d * v * boost * scale; + } + if log(u) < 0.5 * x2 + d * (1.0 - v + log(v)) { + return d * v * boost * scale; + } + } + + // Fallback (should rarely reach) + return d * boost * scale; +} + +const WORKGROUP_SIZE: u32 = 256u; + +struct ChiSquaredParams { + numel: u32, + seed: u32, + df: f32, + _pad: u32, +} + +@group(0) @binding(0) var out: array; +@group(0) @binding(1) var params: ChiSquaredParams; + +@compute @workgroup_size(256) +fn chi_squared_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if idx < params.numel { + var state = pcg_init(params.seed, idx); + // Chi-squared(df) = Gamma(df/2, 2) + out[idx] = f32(sample_gamma_mt(&state, params.df / 2.0, 2.0)); + } +} diff --git a/src/runtime/wgpu/shaders/compare.wgsl b/src/runtime/wgpu/shaders/compare.wgsl new file mode 100644 index 00000000..993998a2 --- /dev/null +++ b/src/runtime/wgpu/shaders/compare.wgsl @@ -0,0 +1,60 @@ +// F32 comparison operations (input F32, output F32: 1.0=true, 0.0=false) + +const WORKGROUP_SIZE: u32 = 256u; + +struct CompareParams { + numel: u32, +} + +@group(0) @binding(0) var compare_a: array; +@group(0) @binding(1) var compare_b: array; +@group(0) @binding(2) var compare_out: array; +@group(0) @binding(3) var compare_params: CompareParams; + +@compute @workgroup_size(256) +fn eq_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < compare_params.numel) { + compare_out[idx] = select(0.0, 1.0, compare_a[idx] == compare_b[idx]); + } +} + +@compute @workgroup_size(256) +fn ne_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < compare_params.numel) { + compare_out[idx] = select(0.0, 1.0, compare_a[idx] != compare_b[idx]); + } +} + +@compute @workgroup_size(256) +fn lt_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < compare_params.numel) { + compare_out[idx] = select(0.0, 1.0, compare_a[idx] < compare_b[idx]); + } +} + +@compute @workgroup_size(256) +fn le_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < compare_params.numel) { + compare_out[idx] = select(0.0, 1.0, compare_a[idx] <= compare_b[idx]); + } +} + +@compute @workgroup_size(256) +fn gt_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < compare_params.numel) { + compare_out[idx] = select(0.0, 1.0, compare_a[idx] > compare_b[idx]); + } +} + +@compute @workgroup_size(256) +fn ge_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < compare_params.numel) { + compare_out[idx] = select(0.0, 1.0, compare_a[idx] >= compare_b[idx]); + } +} diff --git a/src/runtime/wgpu/shaders/compare_i32.wgsl b/src/runtime/wgpu/shaders/compare_i32.wgsl new file mode 100644 index 00000000..960aa3bb --- /dev/null +++ b/src/runtime/wgpu/shaders/compare_i32.wgsl @@ -0,0 +1,60 @@ +// I32 comparison operations (input I32, output F32: 1.0=true, 0.0=false) + +const WORKGROUP_SIZE: u32 = 256u; + +struct CompareParams { + numel: u32, +} + +@group(0) @binding(0) var compare_a: array; +@group(0) @binding(1) var compare_b: array; +@group(0) @binding(2) var compare_out: array; +@group(0) @binding(3) var compare_params: CompareParams; + +@compute @workgroup_size(256) +fn eq_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < compare_params.numel) { + compare_out[idx] = select(0.0, 1.0, compare_a[idx] == compare_b[idx]); + } +} + +@compute @workgroup_size(256) +fn ne_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < compare_params.numel) { + compare_out[idx] = select(0.0, 1.0, compare_a[idx] != compare_b[idx]); + } +} + +@compute @workgroup_size(256) +fn lt_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < compare_params.numel) { + compare_out[idx] = select(0.0, 1.0, compare_a[idx] < compare_b[idx]); + } +} + +@compute @workgroup_size(256) +fn le_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < compare_params.numel) { + compare_out[idx] = select(0.0, 1.0, compare_a[idx] <= compare_b[idx]); + } +} + +@compute @workgroup_size(256) +fn gt_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < compare_params.numel) { + compare_out[idx] = select(0.0, 1.0, compare_a[idx] > compare_b[idx]); + } +} + +@compute @workgroup_size(256) +fn ge_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < compare_params.numel) { + compare_out[idx] = select(0.0, 1.0, compare_a[idx] >= compare_b[idx]); + } +} diff --git a/src/runtime/wgpu/shaders/compare_u32.wgsl b/src/runtime/wgpu/shaders/compare_u32.wgsl new file mode 100644 index 00000000..57e10b15 --- /dev/null +++ b/src/runtime/wgpu/shaders/compare_u32.wgsl @@ -0,0 +1,60 @@ +// U32 comparison operations (input U32, output F32: 1.0=true, 0.0=false) + +const WORKGROUP_SIZE: u32 = 256u; + +struct CompareParams { + numel: u32, +} + +@group(0) @binding(0) var compare_a: array; +@group(0) @binding(1) var compare_b: array; +@group(0) @binding(2) var compare_out: array; +@group(0) @binding(3) var compare_params: CompareParams; + +@compute @workgroup_size(256) +fn eq_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < compare_params.numel) { + compare_out[idx] = select(0.0, 1.0, compare_a[idx] == compare_b[idx]); + } +} + +@compute @workgroup_size(256) +fn ne_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < compare_params.numel) { + compare_out[idx] = select(0.0, 1.0, compare_a[idx] != compare_b[idx]); + } +} + +@compute @workgroup_size(256) +fn lt_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < compare_params.numel) { + compare_out[idx] = select(0.0, 1.0, compare_a[idx] < compare_b[idx]); + } +} + +@compute @workgroup_size(256) +fn le_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < compare_params.numel) { + compare_out[idx] = select(0.0, 1.0, compare_a[idx] <= compare_b[idx]); + } +} + +@compute @workgroup_size(256) +fn gt_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < compare_params.numel) { + compare_out[idx] = select(0.0, 1.0, compare_a[idx] > compare_b[idx]); + } +} + +@compute @workgroup_size(256) +fn ge_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < compare_params.numel) { + compare_out[idx] = select(0.0, 1.0, compare_a[idx] >= compare_b[idx]); + } +} diff --git a/src/runtime/wgpu/shaders/complex.rs b/src/runtime/wgpu/shaders/complex.rs index a4c2eda1..68d3fbc1 100644 --- a/src/runtime/wgpu/shaders/complex.rs +++ b/src/runtime/wgpu/shaders/complex.rs @@ -1,11 +1,34 @@ //! Complex number operation compute shader launchers for WebGPU -use super::generator::complex::get_complex_shader_generator; -use super::pipeline::PipelineCache; +use super::pipeline::{LayoutKey, PipelineCache}; use crate::dtype::DType; use crate::error::{Error, Result}; use wgpu::{Buffer, Queue}; +const CONJ_SHADER: &str = include_str!("conj_complex64.wgsl"); +// entry point: "conj_complex64" + +const REAL_SHADER: &str = include_str!("real_complex64.wgsl"); +// entry point: "real_complex64" + +const IMAG_SHADER: &str = include_str!("imag_complex64.wgsl"); +// entry point: "imag_complex64" + +const ANGLE_SHADER: &str = include_str!("angle_complex64.wgsl"); +// entry point: "angle_complex64" + +const ANGLE_REAL_SHADER: &str = include_str!("angle_real_f32.wgsl"); +// entry point: "angle_real_f32" + +const FROM_REAL_IMAG_SHADER: &str = include_str!("from_real_imag_f32.wgsl"); +// entry point: "from_real_imag_f32" + +const COMPLEX_MUL_REAL_SHADER: &str = include_str!("complex64_mul_real.wgsl"); +// entry point: "complex64_mul_real" + +const COMPLEX_DIV_REAL_SHADER: &str = include_str!("complex64_div_real.wgsl"); +// entry point: "complex64_div_real" + /// Launch a complex operation on the GPU. /// /// # Arguments @@ -43,27 +66,31 @@ pub fn launch_complex_op( }); } - // Get shader generator for this operation - let shader_gen = get_complex_shader_generator(op)?; - let shader_src = shader_gen()?; - - // Entry point name: "conj_complex64", "real_complex64", etc. - let entry_point = format!("{}_{}", op, "complex64"); + let (shader_src, module_name, entry_point): (&str, &'static str, &'static str) = match op { + "conj" => (CONJ_SHADER, "conj_complex64", "conj_complex64"), + "real" => (REAL_SHADER, "real_complex64", "real_complex64"), + "imag" => (IMAG_SHADER, "imag_complex64", "imag_complex64"), + "angle" => (ANGLE_SHADER, "angle_complex64", "angle_complex64"), + _ => { + return Err(Error::Internal(format!( + "Unknown complex operation: {}", + op + ))); + } + }; // Create shader module - let module_name = format!("complex_{}_{}", op, "complex64"); - let module = cache.get_or_create_module_from_source(&module_name, &shader_src); + let module = cache.get_or_create_module(module_name, shader_src); // Create bind group layout (3 buffers: input storage, output storage, params uniform) - let layout = cache.get_or_create_layout(super::pipeline::LayoutKey { + let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, num_readonly_storage: 0, }); // Get or create pipeline - let pipeline = - cache.get_or_create_dynamic_pipeline(&module_name, &entry_point, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_name, entry_point, &module, &layout); // Create bind group let bind_group = cache.create_bind_group(&layout, &[input_buf, output_buf, params_buf]); @@ -118,19 +145,15 @@ pub fn launch_angle_real( params_buf: &Buffer, numel: usize, ) -> Result<()> { - let shader_src = super::generator::complex::generate_angle_real_shader()?; - let entry_point = "angle_real_f32"; - let module_name = "angle_real_f32"; - - let module = cache.get_or_create_module_from_source(&module_name, &shader_src); - let layout = cache.get_or_create_layout(super::pipeline::LayoutKey { + let module = cache.get_or_create_module("angle_real_f32", ANGLE_REAL_SHADER); + let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, num_readonly_storage: 0, }); let pipeline = - cache.get_or_create_dynamic_pipeline(&module_name, &entry_point, &module, &layout); + cache.get_or_create_pipeline("angle_real_f32", "angle_real_f32", &module, &layout); let bind_group = cache.create_bind_group(&layout, &[input_buf, output_buf, params_buf]); let mut encoder = cache @@ -181,19 +204,15 @@ pub fn launch_from_real_imag( params_buf: &Buffer, numel: usize, ) -> Result<()> { - let shader_src = super::generator::complex::generate_from_real_imag_shader()?; - let entry_point = "from_real_imag_f32"; - let module_name = "from_real_imag_f32"; - - let module = cache.get_or_create_module_from_source(&module_name, &shader_src); - let layout = cache.get_or_create_layout(super::pipeline::LayoutKey { + let module = cache.get_or_create_module("from_real_imag_f32", FROM_REAL_IMAG_SHADER); + let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 3, num_uniform_buffers: 1, num_readonly_storage: 0, }); let pipeline = - cache.get_or_create_dynamic_pipeline(&module_name, &entry_point, &module, &layout); + cache.get_or_create_pipeline("from_real_imag_f32", "from_real_imag_f32", &module, &layout); let bind_group = cache.create_bind_group(&layout, &[real_buf, imag_buf, output_buf, params_buf]); @@ -245,19 +264,15 @@ pub fn launch_complex_mul_real( params_buf: &Buffer, numel: usize, ) -> Result<()> { - let shader_src = super::generator::complex::generate_complex_mul_real_shader()?; - let entry_point = "complex64_mul_real"; - let module_name = "complex64_mul_real"; - - let module = cache.get_or_create_module_from_source(&module_name, &shader_src); - let layout = cache.get_or_create_layout(super::pipeline::LayoutKey { + let module = cache.get_or_create_module("complex64_mul_real", COMPLEX_MUL_REAL_SHADER); + let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 3, num_uniform_buffers: 1, num_readonly_storage: 0, }); let pipeline = - cache.get_or_create_dynamic_pipeline(&module_name, &entry_point, &module, &layout); + cache.get_or_create_pipeline("complex64_mul_real", "complex64_mul_real", &module, &layout); let bind_group = cache.create_bind_group(&layout, &[complex_buf, real_buf, output_buf, params_buf]); @@ -309,19 +324,15 @@ pub fn launch_complex_div_real( params_buf: &Buffer, numel: usize, ) -> Result<()> { - let shader_src = super::generator::complex::generate_complex_div_real_shader()?; - let entry_point = "complex64_div_real"; - let module_name = "complex64_div_real"; - - let module = cache.get_or_create_module_from_source(&module_name, &shader_src); - let layout = cache.get_or_create_layout(super::pipeline::LayoutKey { + let module = cache.get_or_create_module("complex64_div_real", COMPLEX_DIV_REAL_SHADER); + let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 3, num_uniform_buffers: 1, num_readonly_storage: 0, }); let pipeline = - cache.get_or_create_dynamic_pipeline(&module_name, &entry_point, &module, &layout); + cache.get_or_create_pipeline("complex64_div_real", "complex64_div_real", &module, &layout); let bind_group = cache.create_bind_group(&layout, &[complex_buf, real_buf, output_buf, params_buf]); diff --git a/src/runtime/wgpu/shaders/complex64_div_real.wgsl b/src/runtime/wgpu/shaders/complex64_div_real.wgsl new file mode 100644 index 00000000..bcb9c799 --- /dev/null +++ b/src/runtime/wgpu/shaders/complex64_div_real.wgsl @@ -0,0 +1,22 @@ +// Complex / real division shader +// entry point: complex64_div_real +// (a + bi) / r = (a/r) + (b/r)*i + +struct Params { + numel: u32, +} + +@group(0) @binding(0) var complex_input: array>; +@group(0) @binding(1) var real_input: array; +@group(0) @binding(2) var output: array>; +@group(0) @binding(3) var params: Params; + +@compute @workgroup_size(256) +fn complex64_div_real(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < params.numel) { + let c = complex_input[idx]; + let r = real_input[idx]; + output[idx] = vec2(c.x / r, c.y / r); + } +} diff --git a/src/runtime/wgpu/shaders/complex64_mul_real.wgsl b/src/runtime/wgpu/shaders/complex64_mul_real.wgsl new file mode 100644 index 00000000..49560397 --- /dev/null +++ b/src/runtime/wgpu/shaders/complex64_mul_real.wgsl @@ -0,0 +1,22 @@ +// Complex × real multiplication shader +// entry point: complex64_mul_real +// (a + bi) * r = ar + br*i + +struct Params { + numel: u32, +} + +@group(0) @binding(0) var complex_input: array>; +@group(0) @binding(1) var real_input: array; +@group(0) @binding(2) var output: array>; +@group(0) @binding(3) var params: Params; + +@compute @workgroup_size(256) +fn complex64_mul_real(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < params.numel) { + let c = complex_input[idx]; + let r = real_input[idx]; + output[idx] = vec2(c.x * r, c.y * r); + } +} diff --git a/src/runtime/wgpu/shaders/conj_complex64.wgsl b/src/runtime/wgpu/shaders/conj_complex64.wgsl new file mode 100644 index 00000000..4db05002 --- /dev/null +++ b/src/runtime/wgpu/shaders/conj_complex64.wgsl @@ -0,0 +1,19 @@ +// Complex conjugate shader +// entry point: conj_complex64 + +struct Params { + numel: u32, +} + +@group(0) @binding(0) var input: array>; +@group(0) @binding(1) var output: array>; +@group(0) @binding(2) var params: Params; + +@compute @workgroup_size(256) +fn conj_complex64(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < params.numel) { + let val = input[idx]; + output[idx] = vec2(val.x, -val.y); // Real stays same, imaginary flips sign + } +} diff --git a/src/runtime/wgpu/shaders/conv.rs b/src/runtime/wgpu/shaders/conv.rs index bb0fe977..1d23d565 100644 --- a/src/runtime/wgpu/shaders/conv.rs +++ b/src/runtime/wgpu/shaders/conv.rs @@ -1,4 +1,4 @@ -//! Convolution WGSL kernel launchers +//! Convolution WGSL kernel launchers (F32 only on WebGPU) //! //! Provides launchers for convolution operations: //! - 1D convolution (conv1d) @@ -9,37 +9,22 @@ use wgpu::{Buffer, Queue}; -use super::generator::{ - generate_conv1d_shader, generate_conv2d_shader, generate_depthwise_conv2d_shader, -}; use super::pipeline::{LayoutKey, PipelineCache, workgroup_count}; use crate::dtype::DType; use crate::error::{Error, Result}; -// ============================================================================ -// Helper Macros -// ============================================================================ +const CONV1D_SHADER: &str = include_str!("conv1d_f32.wgsl"); +// entry point: "conv1d_f32" -macro_rules! check_dtype_float { - ($dtype:expr, $op:expr) => { - if $dtype != DType::F32 && $dtype != DType::F16 { - return Err(Error::UnsupportedDType { - dtype: $dtype, - op: $op, - }); - } - }; -} +const CONV2D_SHADER: &str = include_str!("conv2d_f32.wgsl"); +// entry point: "conv2d_f32" + +const DEPTHWISE_CONV2D_SHADER: &str = include_str!("depthwise_conv2d_f32.wgsl"); +// entry point: "depthwise_conv2d_f32" -/// Get static kernel name for convolution operations. -fn kernel_name(op: &'static str, dtype: DType) -> Result<&'static str> { - match (op, dtype) { - ("conv1d", DType::F32) => Ok("conv1d_f32"), - ("conv1d", DType::F16) => Ok("conv1d_f16"), - ("conv2d", DType::F32) => Ok("conv2d_f32"), - ("conv2d", DType::F16) => Ok("conv2d_f16"), - ("depthwise_conv2d", DType::F32) => Ok("depthwise_conv2d_f32"), - ("depthwise_conv2d", DType::F16) => Ok("depthwise_conv2d_f16"), +fn check_dtype_f32(dtype: DType, op: &'static str) -> Result<()> { + match dtype { + DType::F32 => Ok(()), _ => Err(Error::UnsupportedDType { dtype, op }), } } @@ -71,17 +56,15 @@ pub fn launch_conv1d( total_output: usize, dtype: DType, ) -> Result<()> { - check_dtype_float!(dtype, "conv1d"); + check_dtype_f32(dtype, "conv1d")?; - let name = kernel_name("conv1d", dtype)?; - let shader_source = generate_conv1d_shader(dtype)?; - let module = cache.get_or_create_module(name, &shader_source); + let module = cache.get_or_create_module("conv1d_f32", CONV1D_SHADER); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 4, num_uniform_buffers: 1, num_readonly_storage: 3, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = cache.get_or_create_pipeline("conv1d_f32", "conv1d_f32", &module, &layout); let bind_group = cache.create_bind_group(&layout, &[input, weight, bias, output, params_buffer]); @@ -133,17 +116,15 @@ pub fn launch_conv2d( total_output: usize, dtype: DType, ) -> Result<()> { - check_dtype_float!(dtype, "conv2d"); + check_dtype_f32(dtype, "conv2d")?; - let name = kernel_name("conv2d", dtype)?; - let shader_source = generate_conv2d_shader(dtype)?; - let module = cache.get_or_create_module(name, &shader_source); + let module = cache.get_or_create_module("conv2d_f32", CONV2D_SHADER); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 4, num_uniform_buffers: 1, num_readonly_storage: 3, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = cache.get_or_create_pipeline("conv2d_f32", "conv2d_f32", &module, &layout); let bind_group = cache.create_bind_group(&layout, &[input, weight, bias, output, params_buffer]); @@ -195,17 +176,20 @@ pub fn launch_depthwise_conv2d( total_output: usize, dtype: DType, ) -> Result<()> { - check_dtype_float!(dtype, "depthwise_conv2d"); + check_dtype_f32(dtype, "depthwise_conv2d")?; - let name = kernel_name("depthwise_conv2d", dtype)?; - let shader_source = generate_depthwise_conv2d_shader(dtype)?; - let module = cache.get_or_create_module(name, &shader_source); + let module = cache.get_or_create_module("depthwise_conv2d_f32", DEPTHWISE_CONV2D_SHADER); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 4, num_uniform_buffers: 1, num_readonly_storage: 3, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = cache.get_or_create_pipeline( + "depthwise_conv2d_f32", + "depthwise_conv2d_f32", + &module, + &layout, + ); let bind_group = cache.create_bind_group(&layout, &[input, weight, bias, output, params_buffer]); diff --git a/src/runtime/wgpu/shaders/conv1d_f32.wgsl b/src/runtime/wgpu/shaders/conv1d_f32.wgsl new file mode 100644 index 00000000..7f31b6a7 --- /dev/null +++ b/src/runtime/wgpu/shaders/conv1d_f32.wgsl @@ -0,0 +1,66 @@ +// Conv1d shader for f32 +// Input layout: (N, C_in, L) +// Weight layout: (C_out, C_in/groups, K) +// Output layout: (N, C_out, L_out) + +const WORKGROUP_SIZE: u32 = 256u; + +struct Conv1dParams { + batch: u32, + c_in: u32, + length: u32, + c_out: u32, + kernel_size: u32, + output_length: u32, + stride: u32, + padding: u32, + dilation: u32, + groups: u32, + has_bias: u32, + _pad: u32, +} + +@group(0) @binding(0) var conv1d_input: array; +@group(0) @binding(1) var conv1d_weight: array; +@group(0) @binding(2) var conv1d_bias: array; +@group(0) @binding(3) var conv1d_output: array; +@group(0) @binding(4) var conv1d_params: Conv1dParams; + +@compute @workgroup_size(256) +fn conv1d_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = conv1d_params.batch * conv1d_params.c_out * conv1d_params.output_length; + if (idx >= total) { return; } + + let ox = idx % conv1d_params.output_length; + let oc = (idx / conv1d_params.output_length) % conv1d_params.c_out; + let b = idx / (conv1d_params.c_out * conv1d_params.output_length); + + let c_in_per_group = conv1d_params.c_in / conv1d_params.groups; + let c_out_per_group = conv1d_params.c_out / conv1d_params.groups; + let g = oc / c_out_per_group; + let c_in_start = g * c_in_per_group; + + var sum: f32 = 0.0; + + for (var ic: u32 = 0u; ic < c_in_per_group; ic = ic + 1u) { + let c_in_idx = c_in_start + ic; + + for (var kx: u32 = 0u; kx < conv1d_params.kernel_size; kx = kx + 1u) { + let ix_signed = i32(ox * conv1d_params.stride + kx * conv1d_params.dilation) - i32(conv1d_params.padding); + + if (ix_signed >= 0 && u32(ix_signed) < conv1d_params.length) { + let ix = u32(ix_signed); + let input_idx = b * conv1d_params.c_in * conv1d_params.length + c_in_idx * conv1d_params.length + ix; + let weight_idx = oc * c_in_per_group * conv1d_params.kernel_size + ic * conv1d_params.kernel_size + kx; + sum = sum + conv1d_input[input_idx] * conv1d_weight[weight_idx]; + } + } + } + + if (conv1d_params.has_bias != 0u) { + sum = sum + conv1d_bias[oc]; + } + + conv1d_output[idx] = sum; +} diff --git a/src/runtime/wgpu/shaders/conv2d_f32.wgsl b/src/runtime/wgpu/shaders/conv2d_f32.wgsl new file mode 100644 index 00000000..d74aae1b --- /dev/null +++ b/src/runtime/wgpu/shaders/conv2d_f32.wgsl @@ -0,0 +1,83 @@ +// Conv2d shader for f32 +// Input layout: (N, C_in, H, W) +// Weight layout: (C_out, C_in/groups, K_h, K_w) +// Output layout: (N, C_out, H_out, W_out) + +const WORKGROUP_SIZE: u32 = 256u; + +struct Conv2dParams { + batch: u32, + c_in: u32, + height: u32, + width: u32, + c_out: u32, + kernel_h: u32, + kernel_w: u32, + output_h: u32, + output_w: u32, + stride_h: u32, + stride_w: u32, + pad_h: u32, + pad_w: u32, + dilation_h: u32, + dilation_w: u32, + groups: u32, + has_bias: u32, + _pad: u32, +} + +@group(0) @binding(0) var conv2d_input: array; +@group(0) @binding(1) var conv2d_weight: array; +@group(0) @binding(2) var conv2d_bias: array; +@group(0) @binding(3) var conv2d_output: array; +@group(0) @binding(4) var conv2d_params: Conv2dParams; + +@compute @workgroup_size(256) +fn conv2d_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = conv2d_params.batch * conv2d_params.c_out * conv2d_params.output_h * conv2d_params.output_w; + if (idx >= total) { return; } + + let ox = idx % conv2d_params.output_w; + let oy = (idx / conv2d_params.output_w) % conv2d_params.output_h; + let oc = (idx / (conv2d_params.output_w * conv2d_params.output_h)) % conv2d_params.c_out; + let b = idx / (conv2d_params.c_out * conv2d_params.output_h * conv2d_params.output_w); + + let c_in_per_group = conv2d_params.c_in / conv2d_params.groups; + let c_out_per_group = conv2d_params.c_out / conv2d_params.groups; + let g = oc / c_out_per_group; + let c_in_start = g * c_in_per_group; + + var sum: f32 = 0.0; + + for (var ic: u32 = 0u; ic < c_in_per_group; ic = ic + 1u) { + let c_in_idx = c_in_start + ic; + + for (var ky: u32 = 0u; ky < conv2d_params.kernel_h; ky = ky + 1u) { + for (var kx: u32 = 0u; kx < conv2d_params.kernel_w; kx = kx + 1u) { + let iy_signed = i32(oy * conv2d_params.stride_h + ky * conv2d_params.dilation_h) - i32(conv2d_params.pad_h); + let ix_signed = i32(ox * conv2d_params.stride_w + kx * conv2d_params.dilation_w) - i32(conv2d_params.pad_w); + + if (iy_signed >= 0 && u32(iy_signed) < conv2d_params.height && ix_signed >= 0 && u32(ix_signed) < conv2d_params.width) { + let iy = u32(iy_signed); + let ix = u32(ix_signed); + let input_idx = b * conv2d_params.c_in * conv2d_params.height * conv2d_params.width + + c_in_idx * conv2d_params.height * conv2d_params.width + + iy * conv2d_params.width + + ix; + let weight_idx = oc * c_in_per_group * conv2d_params.kernel_h * conv2d_params.kernel_w + + ic * conv2d_params.kernel_h * conv2d_params.kernel_w + + ky * conv2d_params.kernel_w + + kx; + sum = sum + conv2d_input[input_idx] * conv2d_weight[weight_idx]; + } + } + } + } + + if (conv2d_params.has_bias != 0u) { + sum = sum + conv2d_bias[oc]; + } + + conv2d_output[idx] = sum; +} diff --git a/src/runtime/wgpu/shaders/copy_complex.wgsl b/src/runtime/wgpu/shaders/copy_complex.wgsl new file mode 100644 index 00000000..75893aca --- /dev/null +++ b/src/runtime/wgpu/shaders/copy_complex.wgsl @@ -0,0 +1,26 @@ +// Copy complex array + +const WORKGROUP_SIZE: u32 = 256u; + +struct CopyParams { + n: u32, + _pad1: u32, + _pad2: u32, + _pad3: u32, +} + +@group(0) @binding(0) var copy_input: array>; +@group(0) @binding(1) var copy_output: array>; +@group(0) @binding(2) var copy_params: CopyParams; + +@compute @workgroup_size(WORKGROUP_SIZE) +fn copy_complex( + @builtin(global_invocation_id) gid: vec3 +) { + let idx = gid.x; + let n = copy_params.n; + + if (idx < n) { + copy_output[idx] = copy_input[idx]; + } +} diff --git a/src/runtime/wgpu/shaders/count_nonzero_f32.wgsl b/src/runtime/wgpu/shaders/count_nonzero_f32.wgsl new file mode 100644 index 00000000..5cca3bf8 --- /dev/null +++ b/src/runtime/wgpu/shaders/count_nonzero_f32.wgsl @@ -0,0 +1,48 @@ +// Auto-generated count_nonzero operations for f32 + +const WORKGROUP_SIZE: u32 = 256u; + +var shared_count: array; + +struct CountParams { + numel: u32, +} + +@group(0) @binding(0) var input: array; +@group(0) @binding(1) var count_output: array>; +@group(0) @binding(2) var count_params: CountParams; + +@compute @workgroup_size(256) +fn count_nonzero_f32( + @builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3 +) { + let tid = local_id.x; + let numel = count_params.numel; + + // Each thread counts its elements + var local_count: u32 = 0u; + var idx = global_id.x; + while (idx < numel) { + if (input[idx] != 0.0) { + local_count = local_count + 1u; + } + idx = idx + WORKGROUP_SIZE * 256u; // stride by total threads + } + + shared_count[tid] = local_count; + workgroupBarrier(); + + // Tree reduction + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { + shared_count[tid] = shared_count[tid] + shared_count[tid + s]; + } + workgroupBarrier(); + } + + // Thread 0 adds to global counter + if (tid == 0u) { + atomicAdd(&count_output[0], shared_count[0]); + } +} diff --git a/src/runtime/wgpu/shaders/count_nonzero_i32.wgsl b/src/runtime/wgpu/shaders/count_nonzero_i32.wgsl new file mode 100644 index 00000000..8dc10551 --- /dev/null +++ b/src/runtime/wgpu/shaders/count_nonzero_i32.wgsl @@ -0,0 +1,48 @@ +// Auto-generated count_nonzero operations for i32 + +const WORKGROUP_SIZE: u32 = 256u; + +var shared_count: array; + +struct CountParams { + numel: u32, +} + +@group(0) @binding(0) var input: array; +@group(0) @binding(1) var count_output: array>; +@group(0) @binding(2) var count_params: CountParams; + +@compute @workgroup_size(256) +fn count_nonzero_i32( + @builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3 +) { + let tid = local_id.x; + let numel = count_params.numel; + + // Each thread counts its elements + var local_count: u32 = 0u; + var idx = global_id.x; + while (idx < numel) { + if (input[idx] != 0) { + local_count = local_count + 1u; + } + idx = idx + WORKGROUP_SIZE * 256u; // stride by total threads + } + + shared_count[tid] = local_count; + workgroupBarrier(); + + // Tree reduction + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { + shared_count[tid] = shared_count[tid] + shared_count[tid + s]; + } + workgroupBarrier(); + } + + // Thread 0 adds to global counter + if (tid == 0u) { + atomicAdd(&count_output[0], shared_count[0]); + } +} diff --git a/src/runtime/wgpu/shaders/count_nonzero_u32.wgsl b/src/runtime/wgpu/shaders/count_nonzero_u32.wgsl new file mode 100644 index 00000000..4174ec22 --- /dev/null +++ b/src/runtime/wgpu/shaders/count_nonzero_u32.wgsl @@ -0,0 +1,48 @@ +// Auto-generated count_nonzero operations for u32 + +const WORKGROUP_SIZE: u32 = 256u; + +var shared_count: array; + +struct CountParams { + numel: u32, +} + +@group(0) @binding(0) var input: array; +@group(0) @binding(1) var count_output: array>; +@group(0) @binding(2) var count_params: CountParams; + +@compute @workgroup_size(256) +fn count_nonzero_u32( + @builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3 +) { + let tid = local_id.x; + let numel = count_params.numel; + + // Each thread counts its elements + var local_count: u32 = 0u; + var idx = global_id.x; + while (idx < numel) { + if (input[idx] != 0u) { + local_count = local_count + 1u; + } + idx = idx + WORKGROUP_SIZE * 256u; // stride by total threads + } + + shared_count[tid] = local_count; + workgroupBarrier(); + + // Tree reduction + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { + shared_count[tid] = shared_count[tid] + shared_count[tid + s]; + } + workgroupBarrier(); + } + + // Thread 0 adds to global counter + if (tid == 0u) { + atomicAdd(&count_output[0], shared_count[0]); + } +} diff --git a/src/runtime/wgpu/shaders/count_unique_f32.wgsl b/src/runtime/wgpu/shaders/count_unique_f32.wgsl new file mode 100644 index 00000000..372ad13a --- /dev/null +++ b/src/runtime/wgpu/shaders/count_unique_f32.wgsl @@ -0,0 +1,42 @@ +// Count unique elements in a sorted f32 array + +var shared_count: array; + +struct CountParams { + numel: u32, +} + +@group(0) @binding(0) var sorted_input: array; +@group(0) @binding(1) var count_output: array>; +@group(0) @binding(2) var params: CountParams; + +@compute @workgroup_size(256) +fn count_unique_f32( + @builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, +) { + let tid = local_id.x; + let numel = params.numel; + + var local_count: u32 = 0u; + let idx = global_id.x; + if (idx < numel) { + if (idx == 0u || sorted_input[idx] != sorted_input[idx - 1u]) { + local_count = 1u; + } + } + + shared_count[tid] = local_count; + workgroupBarrier(); + + for (var s: u32 = 128u; s > 0u; s = s >> 1u) { + if (tid < s) { + shared_count[tid] = shared_count[tid] + shared_count[tid + s]; + } + workgroupBarrier(); + } + + if (tid == 0u) { + atomicAdd(&count_output[0], shared_count[0]); + } +} diff --git a/src/runtime/wgpu/shaders/count_unique_i32.wgsl b/src/runtime/wgpu/shaders/count_unique_i32.wgsl new file mode 100644 index 00000000..297df772 --- /dev/null +++ b/src/runtime/wgpu/shaders/count_unique_i32.wgsl @@ -0,0 +1,42 @@ +// Count unique elements in a sorted i32 array + +var shared_count: array; + +struct CountParams { + numel: u32, +} + +@group(0) @binding(0) var sorted_input: array; +@group(0) @binding(1) var count_output: array>; +@group(0) @binding(2) var params: CountParams; + +@compute @workgroup_size(256) +fn count_unique_i32( + @builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, +) { + let tid = local_id.x; + let numel = params.numel; + + var local_count: u32 = 0u; + let idx = global_id.x; + if (idx < numel) { + if (idx == 0u || sorted_input[idx] != sorted_input[idx - 1u]) { + local_count = 1u; + } + } + + shared_count[tid] = local_count; + workgroupBarrier(); + + for (var s: u32 = 128u; s > 0u; s = s >> 1u) { + if (tid < s) { + shared_count[tid] = shared_count[tid] + shared_count[tid + s]; + } + workgroupBarrier(); + } + + if (tid == 0u) { + atomicAdd(&count_output[0], shared_count[0]); + } +} diff --git a/src/runtime/wgpu/shaders/count_unique_u32.wgsl b/src/runtime/wgpu/shaders/count_unique_u32.wgsl new file mode 100644 index 00000000..0b687eb6 --- /dev/null +++ b/src/runtime/wgpu/shaders/count_unique_u32.wgsl @@ -0,0 +1,42 @@ +// Count unique elements in a sorted u32 array + +var shared_count: array; + +struct CountParams { + numel: u32, +} + +@group(0) @binding(0) var sorted_input: array; +@group(0) @binding(1) var count_output: array>; +@group(0) @binding(2) var params: CountParams; + +@compute @workgroup_size(256) +fn count_unique_u32( + @builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, +) { + let tid = local_id.x; + let numel = params.numel; + + var local_count: u32 = 0u; + let idx = global_id.x; + if (idx < numel) { + if (idx == 0u || sorted_input[idx] != sorted_input[idx - 1u]) { + local_count = 1u; + } + } + + shared_count[tid] = local_count; + workgroupBarrier(); + + for (var s: u32 = 128u; s > 0u; s = s >> 1u) { + if (tid < s) { + shared_count[tid] = shared_count[tid] + shared_count[tid + s]; + } + workgroupBarrier(); + } + + if (tid == 0u) { + atomicAdd(&count_output[0], shared_count[0]); + } +} diff --git a/src/runtime/wgpu/shaders/cumprod_f32.wgsl b/src/runtime/wgpu/shaders/cumprod_f32.wgsl new file mode 100644 index 00000000..2af298d8 --- /dev/null +++ b/src/runtime/wgpu/shaders/cumprod_f32.wgsl @@ -0,0 +1,25 @@ +// Cumulative product shader for f32 + +struct CumprodParams { + scan_size: u32, + outer_size: u32, +} + +@group(0) @binding(0) var input: array; +@group(0) @binding(1) var output: array; +@group(0) @binding(2) var params: CumprodParams; + +@compute @workgroup_size(256) +fn cumprod_f32(@builtin(global_invocation_id) global_id: vec3) { + let outer_idx = global_id.x; + if (outer_idx >= params.outer_size) { + return; + } + + let base = outer_idx * params.scan_size; + var acc: f32 = 1.0; + for (var i: u32 = 0u; i < params.scan_size; i = i + 1u) { + acc = acc * input[base + i]; + output[base + i] = acc; + } +} diff --git a/src/runtime/wgpu/shaders/cumprod_i32.wgsl b/src/runtime/wgpu/shaders/cumprod_i32.wgsl new file mode 100644 index 00000000..b5be9df2 --- /dev/null +++ b/src/runtime/wgpu/shaders/cumprod_i32.wgsl @@ -0,0 +1,25 @@ +// Cumulative product shader for i32 + +struct CumprodParams { + scan_size: u32, + outer_size: u32, +} + +@group(0) @binding(0) var input: array; +@group(0) @binding(1) var output: array; +@group(0) @binding(2) var params: CumprodParams; + +@compute @workgroup_size(256) +fn cumprod_i32(@builtin(global_invocation_id) global_id: vec3) { + let outer_idx = global_id.x; + if (outer_idx >= params.outer_size) { + return; + } + + let base = outer_idx * params.scan_size; + var acc: i32 = 1; + for (var i: u32 = 0u; i < params.scan_size; i = i + 1u) { + acc = acc * input[base + i]; + output[base + i] = acc; + } +} diff --git a/src/runtime/wgpu/shaders/cumprod_strided_f32.wgsl b/src/runtime/wgpu/shaders/cumprod_strided_f32.wgsl new file mode 100644 index 00000000..869d770d --- /dev/null +++ b/src/runtime/wgpu/shaders/cumprod_strided_f32.wgsl @@ -0,0 +1,30 @@ +// Strided cumulative product shader for f32 + +struct CumprodStridedParams { + scan_size: u32, + outer_size: u32, + inner_size: u32, +} + +@group(0) @binding(0) var input: array; +@group(0) @binding(1) var output: array; +@group(0) @binding(2) var params: CumprodStridedParams; + +@compute @workgroup_size(256) +fn cumprod_strided_f32(@builtin(global_invocation_id) global_id: vec3) { + let idx = global_id.x; + let total_inner = params.outer_size * params.inner_size; + if (idx >= total_inner) { + return; + } + + let outer_idx = idx / params.inner_size; + let inner_idx = idx % params.inner_size; + + var acc: f32 = 1.0; + for (var s: u32 = 0u; s < params.scan_size; s = s + 1u) { + let offset = outer_idx * params.scan_size * params.inner_size + s * params.inner_size + inner_idx; + acc = acc * input[offset]; + output[offset] = acc; + } +} diff --git a/src/runtime/wgpu/shaders/cumprod_strided_i32.wgsl b/src/runtime/wgpu/shaders/cumprod_strided_i32.wgsl new file mode 100644 index 00000000..5fb006ba --- /dev/null +++ b/src/runtime/wgpu/shaders/cumprod_strided_i32.wgsl @@ -0,0 +1,30 @@ +// Strided cumulative product shader for i32 + +struct CumprodStridedParams { + scan_size: u32, + outer_size: u32, + inner_size: u32, +} + +@group(0) @binding(0) var input: array; +@group(0) @binding(1) var output: array; +@group(0) @binding(2) var params: CumprodStridedParams; + +@compute @workgroup_size(256) +fn cumprod_strided_i32(@builtin(global_invocation_id) global_id: vec3) { + let idx = global_id.x; + let total_inner = params.outer_size * params.inner_size; + if (idx >= total_inner) { + return; + } + + let outer_idx = idx / params.inner_size; + let inner_idx = idx % params.inner_size; + + var acc: i32 = 1; + for (var s: u32 = 0u; s < params.scan_size; s = s + 1u) { + let offset = outer_idx * params.scan_size * params.inner_size + s * params.inner_size + inner_idx; + acc = acc * input[offset]; + output[offset] = acc; + } +} diff --git a/src/runtime/wgpu/shaders/cumprod_strided_u32.wgsl b/src/runtime/wgpu/shaders/cumprod_strided_u32.wgsl new file mode 100644 index 00000000..42e59dd9 --- /dev/null +++ b/src/runtime/wgpu/shaders/cumprod_strided_u32.wgsl @@ -0,0 +1,30 @@ +// Strided cumulative product shader for u32 + +struct CumprodStridedParams { + scan_size: u32, + outer_size: u32, + inner_size: u32, +} + +@group(0) @binding(0) var input: array; +@group(0) @binding(1) var output: array; +@group(0) @binding(2) var params: CumprodStridedParams; + +@compute @workgroup_size(256) +fn cumprod_strided_u32(@builtin(global_invocation_id) global_id: vec3) { + let idx = global_id.x; + let total_inner = params.outer_size * params.inner_size; + if (idx >= total_inner) { + return; + } + + let outer_idx = idx / params.inner_size; + let inner_idx = idx % params.inner_size; + + var acc: u32 = 1u; + for (var s: u32 = 0u; s < params.scan_size; s = s + 1u) { + let offset = outer_idx * params.scan_size * params.inner_size + s * params.inner_size + inner_idx; + acc = acc * input[offset]; + output[offset] = acc; + } +} diff --git a/src/runtime/wgpu/shaders/cumprod_u32.wgsl b/src/runtime/wgpu/shaders/cumprod_u32.wgsl new file mode 100644 index 00000000..834f1e6d --- /dev/null +++ b/src/runtime/wgpu/shaders/cumprod_u32.wgsl @@ -0,0 +1,25 @@ +// Cumulative product shader for u32 + +struct CumprodParams { + scan_size: u32, + outer_size: u32, +} + +@group(0) @binding(0) var input: array; +@group(0) @binding(1) var output: array; +@group(0) @binding(2) var params: CumprodParams; + +@compute @workgroup_size(256) +fn cumprod_u32(@builtin(global_invocation_id) global_id: vec3) { + let outer_idx = global_id.x; + if (outer_idx >= params.outer_size) { + return; + } + + let base = outer_idx * params.scan_size; + var acc: u32 = 1u; + for (var i: u32 = 0u; i < params.scan_size; i = i + 1u) { + acc = acc * input[base + i]; + output[base + i] = acc; + } +} diff --git a/src/runtime/wgpu/shaders/cumsum_f32.wgsl b/src/runtime/wgpu/shaders/cumsum_f32.wgsl new file mode 100644 index 00000000..5a317399 --- /dev/null +++ b/src/runtime/wgpu/shaders/cumsum_f32.wgsl @@ -0,0 +1,25 @@ +// Cumulative sum shader for f32 + +struct CumsumParams { + scan_size: u32, + outer_size: u32, +} + +@group(0) @binding(0) var input: array; +@group(0) @binding(1) var output: array; +@group(0) @binding(2) var params: CumsumParams; + +@compute @workgroup_size(256) +fn cumsum_f32(@builtin(global_invocation_id) global_id: vec3) { + let outer_idx = global_id.x; + if (outer_idx >= params.outer_size) { + return; + } + + let base = outer_idx * params.scan_size; + var acc: f32 = 0.0; + for (var i: u32 = 0u; i < params.scan_size; i = i + 1u) { + acc = acc + input[base + i]; + output[base + i] = acc; + } +} diff --git a/src/runtime/wgpu/shaders/cumsum_i32.wgsl b/src/runtime/wgpu/shaders/cumsum_i32.wgsl new file mode 100644 index 00000000..35bc7dcd --- /dev/null +++ b/src/runtime/wgpu/shaders/cumsum_i32.wgsl @@ -0,0 +1,25 @@ +// Cumulative sum shader for i32 + +struct CumsumParams { + scan_size: u32, + outer_size: u32, +} + +@group(0) @binding(0) var input: array; +@group(0) @binding(1) var output: array; +@group(0) @binding(2) var params: CumsumParams; + +@compute @workgroup_size(256) +fn cumsum_i32(@builtin(global_invocation_id) global_id: vec3) { + let outer_idx = global_id.x; + if (outer_idx >= params.outer_size) { + return; + } + + let base = outer_idx * params.scan_size; + var acc: i32 = 0; + for (var i: u32 = 0u; i < params.scan_size; i = i + 1u) { + acc = acc + input[base + i]; + output[base + i] = acc; + } +} diff --git a/src/runtime/wgpu/shaders/cumsum_strided_f32.wgsl b/src/runtime/wgpu/shaders/cumsum_strided_f32.wgsl new file mode 100644 index 00000000..a42a44a3 --- /dev/null +++ b/src/runtime/wgpu/shaders/cumsum_strided_f32.wgsl @@ -0,0 +1,30 @@ +// Strided cumulative sum shader for f32 + +struct CumsumStridedParams { + scan_size: u32, + outer_size: u32, + inner_size: u32, +} + +@group(0) @binding(0) var input: array; +@group(0) @binding(1) var output: array; +@group(0) @binding(2) var params: CumsumStridedParams; + +@compute @workgroup_size(256) +fn cumsum_strided_f32(@builtin(global_invocation_id) global_id: vec3) { + let idx = global_id.x; + let total_inner = params.outer_size * params.inner_size; + if (idx >= total_inner) { + return; + } + + let outer_idx = idx / params.inner_size; + let inner_idx = idx % params.inner_size; + + var acc: f32 = 0.0; + for (var s: u32 = 0u; s < params.scan_size; s = s + 1u) { + let offset = outer_idx * params.scan_size * params.inner_size + s * params.inner_size + inner_idx; + acc = acc + input[offset]; + output[offset] = acc; + } +} diff --git a/src/runtime/wgpu/shaders/cumsum_strided_i32.wgsl b/src/runtime/wgpu/shaders/cumsum_strided_i32.wgsl new file mode 100644 index 00000000..8a896e5a --- /dev/null +++ b/src/runtime/wgpu/shaders/cumsum_strided_i32.wgsl @@ -0,0 +1,30 @@ +// Strided cumulative sum shader for i32 + +struct CumsumStridedParams { + scan_size: u32, + outer_size: u32, + inner_size: u32, +} + +@group(0) @binding(0) var input: array; +@group(0) @binding(1) var output: array; +@group(0) @binding(2) var params: CumsumStridedParams; + +@compute @workgroup_size(256) +fn cumsum_strided_i32(@builtin(global_invocation_id) global_id: vec3) { + let idx = global_id.x; + let total_inner = params.outer_size * params.inner_size; + if (idx >= total_inner) { + return; + } + + let outer_idx = idx / params.inner_size; + let inner_idx = idx % params.inner_size; + + var acc: i32 = 0; + for (var s: u32 = 0u; s < params.scan_size; s = s + 1u) { + let offset = outer_idx * params.scan_size * params.inner_size + s * params.inner_size + inner_idx; + acc = acc + input[offset]; + output[offset] = acc; + } +} diff --git a/src/runtime/wgpu/shaders/cumulative.rs b/src/runtime/wgpu/shaders/cumulative.rs index 67edf488..d5b034f1 100644 --- a/src/runtime/wgpu/shaders/cumulative.rs +++ b/src/runtime/wgpu/shaders/cumulative.rs @@ -1,29 +1,58 @@ //! Cumulative operation WGSL kernel launchers //! -//! Provides launchers for cumulative operations: -//! - `cumsum` - Cumulative sum along a dimension -//! - `cumprod` - Cumulative product along a dimension -//! - `logsumexp` - Numerically stable log-sum-exp reduction +//! - `cumsum` - F32 and I32 +//! - `cumprod` - F32, I32, U32 +//! - `logsumexp` - F32 only use wgpu::{Buffer, Queue}; -use super::generator::{ - dtype_suffix, generate_cumprod_shader, generate_cumprod_strided_shader, generate_cumsum_shader, - generate_cumsum_strided_shader, generate_logsumexp_shader, generate_logsumexp_strided_shader, -}; use super::pipeline::{LayoutKey, PipelineCache, workgroup_count}; use crate::dtype::DType; -use crate::error::Result; +use crate::error::{Error, Result}; + +const CUMSUM_F32_SHADER: &str = include_str!("cumsum_f32.wgsl"); +const CUMSUM_I32_SHADER: &str = include_str!("cumsum_i32.wgsl"); + +const CUMSUM_STRIDED_F32_SHADER: &str = include_str!("cumsum_strided_f32.wgsl"); +const CUMSUM_STRIDED_I32_SHADER: &str = include_str!("cumsum_strided_i32.wgsl"); + +const CUMPROD_F32_SHADER: &str = include_str!("cumprod_f32.wgsl"); +const CUMPROD_I32_SHADER: &str = include_str!("cumprod_i32.wgsl"); +const CUMPROD_U32_SHADER: &str = include_str!("cumprod_u32.wgsl"); + +const CUMPROD_STRIDED_F32_SHADER: &str = include_str!("cumprod_strided_f32.wgsl"); +const CUMPROD_STRIDED_I32_SHADER: &str = include_str!("cumprod_strided_i32.wgsl"); +const CUMPROD_STRIDED_U32_SHADER: &str = include_str!("cumprod_strided_u32.wgsl"); + +const LOGSUMEXP_SHADER: &str = include_str!("logsumexp_f32.wgsl"); +const LOGSUMEXP_STRIDED_SHADER: &str = include_str!("logsumexp_strided_f32.wgsl"); + +fn check_f32(dtype: DType, op: &'static str) -> Result<()> { + match dtype { + DType::F32 => Ok(()), + _ => Err(Error::UnsupportedDType { dtype, op }), + } +} + +fn check_f32_i32(dtype: DType, op: &'static str) -> Result<()> { + match dtype { + DType::F32 | DType::I32 => Ok(()), + _ => Err(Error::UnsupportedDType { dtype, op }), + } +} + +fn check_f32_i32_u32(dtype: DType, op: &'static str) -> Result<()> { + match dtype { + DType::F32 | DType::I32 | DType::U32 => Ok(()), + _ => Err(Error::UnsupportedDType { dtype, op }), + } +} // ============================================================================ // Cumulative Sum // ============================================================================ -/// Launch cumsum operation kernel (contiguous data). -/// -/// Parameters: -/// - scan_size: Size of the dimension being scanned -/// - outer_size: Number of independent scans +/// Launch cumsum operation kernel (contiguous data). Supports F32 and I32. pub fn launch_cumsum( cache: &PipelineCache, queue: &Queue, @@ -33,22 +62,21 @@ pub fn launch_cumsum( outer_size: usize, dtype: DType, ) -> Result<()> { - let suffix = dtype_suffix(dtype)?; - let entry_point_name = format!("cumsum_{}", suffix); + check_f32_i32(dtype, "cumsum")?; - // Generate shader on-demand - let shader_source = generate_cumsum_shader(dtype)?; + let (module_key, shader, entry_point) = match dtype { + DType::F32 => ("cumsum_f32", CUMSUM_F32_SHADER, "cumsum_f32"), + DType::I32 => ("cumsum_i32", CUMSUM_I32_SHADER, "cumsum_i32"), + _ => unreachable!(), + }; - let module_name = format!("cumsum_{}", suffix); - let module = cache.get_or_create_module_from_source(&module_name, &shader_source); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = - cache.get_or_create_dynamic_pipeline("cumsum", &entry_point_name, &module, &layout); - + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[input, output, params_buffer]); let mut encoder = cache @@ -56,7 +84,6 @@ pub fn launch_cumsum( .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("cumsum"), }); - { let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some("cumsum"), @@ -66,12 +93,11 @@ pub fn launch_cumsum( pass.set_bind_group(0, Some(&bind_group), &[]); pass.dispatch_workgroups(workgroup_count(outer_size), 1, 1); } - queue.submit(std::iter::once(encoder.finish())); Ok(()) } -/// Launch strided cumsum operation kernel. +/// Launch strided cumsum operation kernel. Supports F32 and I32. pub fn launch_cumsum_strided( cache: &PipelineCache, queue: &Queue, @@ -81,21 +107,29 @@ pub fn launch_cumsum_strided( total_inner: usize, dtype: DType, ) -> Result<()> { - let suffix = dtype_suffix(dtype)?; - let entry_point_name = format!("cumsum_strided_{}", suffix); - - let shader_source = generate_cumsum_strided_shader(dtype)?; - - let module = cache - .get_or_create_module_from_source(&format!("cumsum_strided_{}", suffix), &shader_source); + check_f32_i32(dtype, "cumsum_strided")?; + + let (module_key, shader, entry_point) = match dtype { + DType::F32 => ( + "cumsum_strided_f32", + CUMSUM_STRIDED_F32_SHADER, + "cumsum_strided_f32", + ), + DType::I32 => ( + "cumsum_strided_i32", + CUMSUM_STRIDED_I32_SHADER, + "cumsum_strided_i32", + ), + _ => unreachable!(), + }; + + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = - cache.get_or_create_dynamic_pipeline("cumsum_strided", &entry_point_name, &module, &layout); - + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[input, output, params_buffer]); let mut encoder = cache @@ -103,7 +137,6 @@ pub fn launch_cumsum_strided( .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("cumsum_strided"), }); - { let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some("cumsum_strided"), @@ -113,7 +146,6 @@ pub fn launch_cumsum_strided( pass.set_bind_group(0, Some(&bind_group), &[]); pass.dispatch_workgroups(workgroup_count(total_inner), 1, 1); } - queue.submit(std::iter::once(encoder.finish())); Ok(()) } @@ -122,7 +154,7 @@ pub fn launch_cumsum_strided( // Cumulative Product // ============================================================================ -/// Launch cumprod operation kernel (contiguous data). +/// Launch cumprod operation kernel (contiguous data). Supports F32, I32, U32. pub fn launch_cumprod( cache: &PipelineCache, queue: &Queue, @@ -132,21 +164,22 @@ pub fn launch_cumprod( outer_size: usize, dtype: DType, ) -> Result<()> { - let suffix = dtype_suffix(dtype)?; - let entry_point_name = format!("cumprod_{}", suffix); + check_f32_i32_u32(dtype, "cumprod")?; - let shader_source = generate_cumprod_shader(dtype)?; + let (module_key, shader, entry_point) = match dtype { + DType::F32 => ("cumprod_f32", CUMPROD_F32_SHADER, "cumprod_f32"), + DType::I32 => ("cumprod_i32", CUMPROD_I32_SHADER, "cumprod_i32"), + DType::U32 => ("cumprod_u32", CUMPROD_U32_SHADER, "cumprod_u32"), + _ => unreachable!(), + }; - let module = - cache.get_or_create_module_from_source(&format!("cumprod_{}", suffix), &shader_source); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = - cache.get_or_create_dynamic_pipeline("cumprod", &entry_point_name, &module, &layout); - + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[input, output, params_buffer]); let mut encoder = cache @@ -154,7 +187,6 @@ pub fn launch_cumprod( .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("cumprod"), }); - { let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some("cumprod"), @@ -164,12 +196,11 @@ pub fn launch_cumprod( pass.set_bind_group(0, Some(&bind_group), &[]); pass.dispatch_workgroups(workgroup_count(outer_size), 1, 1); } - queue.submit(std::iter::once(encoder.finish())); Ok(()) } -/// Launch strided cumprod operation kernel. +/// Launch strided cumprod operation kernel. Supports F32, I32, U32. pub fn launch_cumprod_strided( cache: &PipelineCache, queue: &Queue, @@ -179,25 +210,34 @@ pub fn launch_cumprod_strided( total_inner: usize, dtype: DType, ) -> Result<()> { - let suffix = dtype_suffix(dtype)?; - let entry_point_name = format!("cumprod_strided_{}", suffix); - - let shader_source = generate_cumprod_strided_shader(dtype)?; - - let module = cache - .get_or_create_module_from_source(&format!("cumprod_strided_{}", suffix), &shader_source); + check_f32_i32_u32(dtype, "cumprod_strided")?; + + let (module_key, shader, entry_point) = match dtype { + DType::F32 => ( + "cumprod_strided_f32", + CUMPROD_STRIDED_F32_SHADER, + "cumprod_strided_f32", + ), + DType::I32 => ( + "cumprod_strided_i32", + CUMPROD_STRIDED_I32_SHADER, + "cumprod_strided_i32", + ), + DType::U32 => ( + "cumprod_strided_u32", + CUMPROD_STRIDED_U32_SHADER, + "cumprod_strided_u32", + ), + _ => unreachable!(), + }; + + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_dynamic_pipeline( - "cumprod_strided", - &entry_point_name, - &module, - &layout, - ); - + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[input, output, params_buffer]); let mut encoder = cache @@ -205,7 +245,6 @@ pub fn launch_cumprod_strided( .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("cumprod_strided"), }); - { let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some("cumprod_strided"), @@ -215,7 +254,6 @@ pub fn launch_cumprod_strided( pass.set_bind_group(0, Some(&bind_group), &[]); pass.dispatch_workgroups(workgroup_count(total_inner), 1, 1); } - queue.submit(std::iter::once(encoder.finish())); Ok(()) } @@ -234,20 +272,15 @@ pub fn launch_logsumexp( outer_size: usize, dtype: DType, ) -> Result<()> { - let suffix = dtype_suffix(dtype)?; - let entry_point_name = format!("logsumexp_{}", suffix); + check_f32(dtype, "logsumexp")?; - let shader_source = generate_logsumexp_shader(dtype)?; - - let module = - cache.get_or_create_module_from_source(&format!("logsumexp_{}", suffix), &shader_source); + let module = cache.get_or_create_module("logsumexp_f32", LOGSUMEXP_SHADER); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = - cache.get_or_create_dynamic_pipeline("logsumexp", &entry_point_name, &module, &layout); + let pipeline = cache.get_or_create_pipeline("logsumexp_f32", "logsumexp_f32", &module, &layout); let bind_group = cache.create_bind_group(&layout, &[input, output, params_buffer]); @@ -281,21 +314,17 @@ pub fn launch_logsumexp_strided( total_inner: usize, dtype: DType, ) -> Result<()> { - let suffix = dtype_suffix(dtype)?; - let entry_point_name = format!("logsumexp_strided_{}", suffix); - - let shader_source = generate_logsumexp_strided_shader(dtype)?; + check_f32(dtype, "logsumexp_strided")?; - let module = cache - .get_or_create_module_from_source(&format!("logsumexp_strided_{}", suffix), &shader_source); + let module = cache.get_or_create_module("logsumexp_strided_f32", LOGSUMEXP_STRIDED_SHADER); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_dynamic_pipeline( - "logsumexp_strided", - &entry_point_name, + let pipeline = cache.get_or_create_pipeline( + "logsumexp_strided_f32", + "logsumexp_strided_f32", &module, &layout, ); diff --git a/src/runtime/wgpu/shaders/depthwise_conv2d_f32.wgsl b/src/runtime/wgpu/shaders/depthwise_conv2d_f32.wgsl new file mode 100644 index 00000000..84359764 --- /dev/null +++ b/src/runtime/wgpu/shaders/depthwise_conv2d_f32.wgsl @@ -0,0 +1,69 @@ +// Depthwise conv2d shader for f32 +// Input layout: (N, C, H, W) +// Weight layout: (C, 1, K_h, K_w) +// Output layout: (N, C, H_out, W_out) + +const WORKGROUP_SIZE: u32 = 256u; + +struct DepthwiseConv2dParams { + batch: u32, + channels: u32, + height: u32, + width: u32, + kernel_h: u32, + kernel_w: u32, + output_h: u32, + output_w: u32, + stride_h: u32, + stride_w: u32, + pad_h: u32, + pad_w: u32, + dilation_h: u32, + dilation_w: u32, + has_bias: u32, + _pad: u32, +} + +@group(0) @binding(0) var depthwise_input: array; +@group(0) @binding(1) var depthwise_weight: array; +@group(0) @binding(2) var depthwise_bias: array; +@group(0) @binding(3) var depthwise_output: array; +@group(0) @binding(4) var depthwise_params: DepthwiseConv2dParams; + +@compute @workgroup_size(256) +fn depthwise_conv2d_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = depthwise_params.batch * depthwise_params.channels * depthwise_params.output_h * depthwise_params.output_w; + if (idx >= total) { return; } + + let ox = idx % depthwise_params.output_w; + let oy = (idx / depthwise_params.output_w) % depthwise_params.output_h; + let c = (idx / (depthwise_params.output_w * depthwise_params.output_h)) % depthwise_params.channels; + let b = idx / (depthwise_params.channels * depthwise_params.output_h * depthwise_params.output_w); + + var sum: f32 = 0.0; + + for (var ky: u32 = 0u; ky < depthwise_params.kernel_h; ky = ky + 1u) { + for (var kx: u32 = 0u; kx < depthwise_params.kernel_w; kx = kx + 1u) { + let iy_signed = i32(oy * depthwise_params.stride_h + ky * depthwise_params.dilation_h) - i32(depthwise_params.pad_h); + let ix_signed = i32(ox * depthwise_params.stride_w + kx * depthwise_params.dilation_w) - i32(depthwise_params.pad_w); + + if (iy_signed >= 0 && u32(iy_signed) < depthwise_params.height && ix_signed >= 0 && u32(ix_signed) < depthwise_params.width) { + let iy = u32(iy_signed); + let ix = u32(ix_signed); + let input_idx = b * depthwise_params.channels * depthwise_params.height * depthwise_params.width + + c * depthwise_params.height * depthwise_params.width + + iy * depthwise_params.width + + ix; + let weight_idx = c * depthwise_params.kernel_h * depthwise_params.kernel_w + ky * depthwise_params.kernel_w + kx; + sum = sum + depthwise_input[input_idx] * depthwise_weight[weight_idx]; + } + } + } + + if (depthwise_params.has_bias != 0u) { + sum = sum + depthwise_bias[c]; + } + + depthwise_output[idx] = sum; +} diff --git a/src/runtime/wgpu/shaders/diagonal_exp_f32.wgsl b/src/runtime/wgpu/shaders/diagonal_exp_f32.wgsl new file mode 100644 index 00000000..30ac59d2 --- /dev/null +++ b/src/runtime/wgpu/shaders/diagonal_exp_f32.wgsl @@ -0,0 +1,102 @@ +// Diagonal block function application for f32 - exp + +const WORKGROUP_SIZE: u32 = 256u; + +struct Params { + n: u32, + eps: f32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var input_t: array; +@group(0) @binding(1) var output_f: array; +@group(0) @binding(2) var params: Params; + +// Apply exp to 2x2 block +fn apply_2x2_block(a: f32, b: f32, c: f32, d: f32, + f11: ptr, f12: ptr, + f21: ptr, f22: ptr) { + // For 2x2 block with complex eigenvalues a ± bi: + // exp(a ± bi) = exp(a) * (cos(b) ± i*sin(b)) + // Result is [[exp(a)*cos(b), -exp(a)*sin(b)], [exp(a)*sin(b), exp(a)*cos(b)]] + // after similarity transform + let trace = a + d; + let det = a * d - b * c; + let disc = trace * trace - 4.0 * det; + + if disc >= 0.0 { + // Real eigenvalues - diagonalize and apply exp + let sqrt_disc = sqrt(disc); + let lambda1 = (trace + sqrt_disc) / 2.0; + let lambda2 = (trace - sqrt_disc) / 2.0; + let exp1 = exp(lambda1); + let exp2 = exp(lambda2); + + // Simple case: return diagonal exp values + // This is approximate but handles most cases + *f11 = (exp1 + exp2) / 2.0; + *f22 = (exp1 + exp2) / 2.0; + *f12 = (exp1 - exp2) / 2.0 * sign(b); + *f21 = (exp1 - exp2) / 2.0 * sign(c); + } else { + // Complex eigenvalues + let real_part = trace / 2.0; + let imag_part = sqrt(-disc) / 2.0; + let exp_real = exp(real_part); + let cos_imag = cos(imag_part); + let sin_imag = sin(imag_part); + + *f11 = exp_real * cos_imag; + *f22 = exp_real * cos_imag; + // Off-diagonal scaling based on original block structure + let scale = exp_real * sin_imag / imag_part; + *f12 = scale * b; + *f21 = scale * c; + } +} + +@compute @workgroup_size(1) +fn diagonal_exp_f32(@builtin(global_invocation_id) gid: vec3) { + let n = params.n; + let eps = f32(params.eps); + + // Initialize output to zero + for (var idx: u32 = 0u; idx < n * n; idx = idx + 1u) { + output_f[idx] = 0.0; + } + + var i: u32 = 0u; + while i < n { + // Check if this is a 2x2 block + if i + 1u < n { + let sub_diag = abs(input_t[(i + 1u) * n + i]); + if sub_diag > eps { + // 2x2 block + let a = input_t[i * n + i]; + let b = input_t[i * n + (i + 1u)]; + let c = input_t[(i + 1u) * n + i]; + let d = input_t[(i + 1u) * n + (i + 1u)]; + + var f11: f32; + var f12: f32; + var f21: f32; + var f22: f32; + apply_2x2_block(a, b, c, d, &f11, &f12, &f21, &f22); + + output_f[i * n + i] = f11; + output_f[i * n + (i + 1u)] = f12; + output_f[(i + 1u) * n + i] = f21; + output_f[(i + 1u) * n + (i + 1u)] = f22; + + i = i + 2u; + continue; + } + } + + // 1x1 block + let x = input_t[i * n + i]; + output_f[i * n + i] = exp(x); + i = i + 1u; + } +} diff --git a/src/runtime/wgpu/shaders/diagonal_log_f32.wgsl b/src/runtime/wgpu/shaders/diagonal_log_f32.wgsl new file mode 100644 index 00000000..5a83f472 --- /dev/null +++ b/src/runtime/wgpu/shaders/diagonal_log_f32.wgsl @@ -0,0 +1,94 @@ +// Diagonal block function application for f32 - log + +const WORKGROUP_SIZE: u32 = 256u; + +struct Params { + n: u32, + eps: f32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var input_t: array; +@group(0) @binding(1) var output_f: array; +@group(0) @binding(2) var params: Params; + +// Apply log to 2x2 block +fn apply_2x2_block(a: f32, b: f32, c: f32, d: f32, + f11: ptr, f12: ptr, + f21: ptr, f22: ptr) { + let trace = a + d; + let det = a * d - b * c; + let disc = trace * trace - 4.0 * det; + + if disc >= 0.0 { + // Real eigenvalues + let sqrt_disc = sqrt(disc); + let lambda1 = (trace + sqrt_disc) / 2.0; + let lambda2 = (trace - sqrt_disc) / 2.0; + let log1 = log(lambda1); + let log2 = log(lambda2); + + *f11 = (log1 + log2) / 2.0; + *f22 = (log1 + log2) / 2.0; + *f12 = (log1 - log2) / (lambda1 - lambda2) * b; + *f21 = (log1 - log2) / (lambda1 - lambda2) * c; + } else { + // Complex eigenvalues: log(r * e^(i*theta)) = log(r) + i*theta + let real_part = trace / 2.0; + let imag_part = sqrt(-disc) / 2.0; + let r = sqrt(det); // |lambda| = sqrt(det) for conjugate pair + let theta = atan2(imag_part, real_part); + + *f11 = log(r); + *f22 = log(r); + let scale = theta / imag_part; + *f12 = scale * b; + *f21 = scale * c; + } +} + +@compute @workgroup_size(1) +fn diagonal_log_f32(@builtin(global_invocation_id) gid: vec3) { + let n = params.n; + let eps = f32(params.eps); + + // Initialize output to zero + for (var idx: u32 = 0u; idx < n * n; idx = idx + 1u) { + output_f[idx] = 0.0; + } + + var i: u32 = 0u; + while i < n { + // Check if this is a 2x2 block + if i + 1u < n { + let sub_diag = abs(input_t[(i + 1u) * n + i]); + if sub_diag > eps { + // 2x2 block + let a = input_t[i * n + i]; + let b = input_t[i * n + (i + 1u)]; + let c = input_t[(i + 1u) * n + i]; + let d = input_t[(i + 1u) * n + (i + 1u)]; + + var f11: f32; + var f12: f32; + var f21: f32; + var f22: f32; + apply_2x2_block(a, b, c, d, &f11, &f12, &f21, &f22); + + output_f[i * n + i] = f11; + output_f[i * n + (i + 1u)] = f12; + output_f[(i + 1u) * n + i] = f21; + output_f[(i + 1u) * n + (i + 1u)] = f22; + + i = i + 2u; + continue; + } + } + + // 1x1 block + let x = input_t[i * n + i]; + output_f[i * n + i] = log(x); + i = i + 1u; + } +} diff --git a/src/runtime/wgpu/shaders/diagonal_sqrt_f32.wgsl b/src/runtime/wgpu/shaders/diagonal_sqrt_f32.wgsl new file mode 100644 index 00000000..41a88782 --- /dev/null +++ b/src/runtime/wgpu/shaders/diagonal_sqrt_f32.wgsl @@ -0,0 +1,101 @@ +// Diagonal block function application for f32 - sqrt + +const WORKGROUP_SIZE: u32 = 256u; + +struct Params { + n: u32, + eps: f32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var input_t: array; +@group(0) @binding(1) var output_f: array; +@group(0) @binding(2) var params: Params; + +// Apply sqrt to 2x2 block +fn apply_2x2_block(a: f32, b: f32, c: f32, d: f32, + f11: ptr, f12: ptr, + f21: ptr, f22: ptr) { + let trace = a + d; + let det = a * d - b * c; + let disc = trace * trace - 4.0 * det; + + if disc >= 0.0 { + // Real eigenvalues + let sqrt_disc = sqrt(disc); + let lambda1 = (trace + sqrt_disc) / 2.0; + let lambda2 = (trace - sqrt_disc) / 2.0; + let sqrt1 = sqrt(lambda1); + let sqrt2 = sqrt(lambda2); + + *f11 = (sqrt1 + sqrt2) / 2.0; + *f22 = (sqrt1 + sqrt2) / 2.0; + let denom = sqrt1 + sqrt2; + if abs(denom) > 1e-10 { + *f12 = b / denom; + *f21 = c / denom; + } else { + *f12 = 0.0; + *f21 = 0.0; + } + } else { + // Complex eigenvalues + let r = sqrt(det); + let theta = atan2(sqrt(-disc) / 2.0, trace / 2.0); + let sqrt_r = sqrt(r); + let half_theta = theta / 2.0; + + *f11 = sqrt_r * cos(half_theta); + *f22 = sqrt_r * cos(half_theta); + let imag_part = sqrt(-disc) / 2.0; + let scale = sqrt_r * sin(half_theta) / imag_part; + *f12 = scale * b; + *f21 = scale * c; + } +} + +@compute @workgroup_size(1) +fn diagonal_sqrt_f32(@builtin(global_invocation_id) gid: vec3) { + let n = params.n; + let eps = f32(params.eps); + + // Initialize output to zero + for (var idx: u32 = 0u; idx < n * n; idx = idx + 1u) { + output_f[idx] = 0.0; + } + + var i: u32 = 0u; + while i < n { + // Check if this is a 2x2 block + if i + 1u < n { + let sub_diag = abs(input_t[(i + 1u) * n + i]); + if sub_diag > eps { + // 2x2 block + let a = input_t[i * n + i]; + let b = input_t[i * n + (i + 1u)]; + let c = input_t[(i + 1u) * n + i]; + let d = input_t[(i + 1u) * n + (i + 1u)]; + + var f11: f32; + var f12: f32; + var f21: f32; + var f22: f32; + apply_2x2_block(a, b, c, d, &f11, &f12, &f21, &f22); + + output_f[i * n + i] = f11; + output_f[i * n + (i + 1u)] = f12; + output_f[(i + 1u) * n + i] = f21; + output_f[(i + 1u) * n + (i + 1u)] = f22; + + i = i + 2u; + continue; + } + } + + // 1x1 block + let x = input_t[i * n + i]; + output_f[i * n + i] = sqrt(x); + i = i + 1u; + } +} diff --git a/src/runtime/wgpu/shaders/distance.rs b/src/runtime/wgpu/shaders/distance.rs index ee3f1eee..039d93ae 100644 --- a/src/runtime/wgpu/shaders/distance.rs +++ b/src/runtime/wgpu/shaders/distance.rs @@ -4,11 +4,17 @@ use wgpu::{Buffer, Queue}; -use super::pipeline::{LayoutKey, PipelineCache, WORKGROUP_SIZE, workgroup_count}; +use super::pipeline::{LayoutKey, PipelineCache, workgroup_count}; use crate::dtype::DType; use crate::error::{Error, Result}; use crate::ops::DistanceMetric; +// Static WGSL shader code +const CDIST_F32: &str = include_str!("distance_cdist_f32.wgsl"); +const PDIST_F32: &str = include_str!("distance_pdist_f32.wgsl"); +const SQUAREFORM_F32: &str = include_str!("distance_squareform_f32.wgsl"); +const SQUAREFORM_INVERSE_F32: &str = include_str!("distance_squareform_inverse_f32.wgsl"); + fn check_float_dtype(dtype: DType, op: &'static str) -> Result<()> { match dtype { DType::F32 => Ok(()), @@ -39,507 +45,6 @@ pub fn metric_p_value(metric: DistanceMetric) -> f32 { } } -/// Generate WGSL shader for cdist operation -fn generate_cdist_shader() -> String { - format!( - r#" -const WORKGROUP_SIZE: u32 = {workgroup_size}u; - -// Distance metric constants -const METRIC_EUCLIDEAN: u32 = 0u; -const METRIC_SQEUCLIDEAN: u32 = 1u; -const METRIC_MANHATTAN: u32 = 2u; -const METRIC_CHEBYSHEV: u32 = 3u; -const METRIC_MINKOWSKI: u32 = 4u; -const METRIC_COSINE: u32 = 5u; -const METRIC_CORRELATION: u32 = 6u; -const METRIC_HAMMING: u32 = 7u; -const METRIC_JACCARD: u32 = 8u; - -struct Params {{ - n: u32, - m: u32, - d: u32, - metric: u32, - p: f32, -}} - -@group(0) @binding(0) var x: array; -@group(0) @binding(1) var y: array; -@group(0) @binding(2) var out: array; -@group(0) @binding(3) var params: Params; - -fn sqeuclidean_dist(x_offset: u32, y_offset: u32, d: u32) -> f32 {{ - var sum: f32 = 0.0; - for (var k: u32 = 0u; k < d; k++) {{ - let diff = x[x_offset + k] - y[y_offset + k]; - sum += diff * diff; - }} - return sum; -}} - -fn manhattan_dist(x_offset: u32, y_offset: u32, d: u32) -> f32 {{ - var sum: f32 = 0.0; - for (var k: u32 = 0u; k < d; k++) {{ - sum += abs(x[x_offset + k] - y[y_offset + k]); - }} - return sum; -}} - -fn chebyshev_dist(x_offset: u32, y_offset: u32, d: u32) -> f32 {{ - var max_val: f32 = 0.0; - for (var k: u32 = 0u; k < d; k++) {{ - let abs_diff = abs(x[x_offset + k] - y[y_offset + k]); - if (abs_diff > max_val) {{ - max_val = abs_diff; - }} - }} - return max_val; -}} - -fn minkowski_dist(x_offset: u32, y_offset: u32, d: u32, p: f32) -> f32 {{ - var sum: f32 = 0.0; - for (var k: u32 = 0u; k < d; k++) {{ - sum += pow(abs(x[x_offset + k] - y[y_offset + k]), p); - }} - return pow(sum, 1.0 / p); -}} - -fn cosine_dist(x_offset: u32, y_offset: u32, d: u32) -> f32 {{ - var dot: f32 = 0.0; - var norm_a: f32 = 0.0; - var norm_b: f32 = 0.0; - for (var k: u32 = 0u; k < d; k++) {{ - let ak = x[x_offset + k]; - let bk = y[y_offset + k]; - dot += ak * bk; - norm_a += ak * ak; - norm_b += bk * bk; - }} - let denom = sqrt(norm_a * norm_b); - if (denom == 0.0) {{ - return 0.0; - }} - return 1.0 - dot / denom; -}} - -fn correlation_dist(x_offset: u32, y_offset: u32, d: u32) -> f32 {{ - var sum_a: f32 = 0.0; - var sum_b: f32 = 0.0; - for (var k: u32 = 0u; k < d; k++) {{ - sum_a += x[x_offset + k]; - sum_b += y[y_offset + k]; - }} - let mean_a = sum_a / f32(d); - let mean_b = sum_b / f32(d); - - var cov: f32 = 0.0; - var var_a: f32 = 0.0; - var var_b: f32 = 0.0; - for (var k: u32 = 0u; k < d; k++) {{ - let da = x[x_offset + k] - mean_a; - let db = y[y_offset + k] - mean_b; - cov += da * db; - var_a += da * da; - var_b += db * db; - }} - let denom = sqrt(var_a * var_b); - if (denom == 0.0) {{ - return 0.0; - }} - return 1.0 - cov / denom; -}} - -fn hamming_dist(x_offset: u32, y_offset: u32, d: u32) -> f32 {{ - var count: f32 = 0.0; - for (var k: u32 = 0u; k < d; k++) {{ - if (x[x_offset + k] != y[y_offset + k]) {{ - count += 1.0; - }} - }} - return count / f32(d); -}} - -fn jaccard_dist(x_offset: u32, y_offset: u32, d: u32) -> f32 {{ - var intersection: f32 = 0.0; - var union_count: f32 = 0.0; - for (var k: u32 = 0u; k < d; k++) {{ - let a_nonzero = x[x_offset + k] != 0.0; - let b_nonzero = y[y_offset + k] != 0.0; - if (a_nonzero && b_nonzero) {{ - intersection += 1.0; - }} - if (a_nonzero || b_nonzero) {{ - union_count += 1.0; - }} - }} - if (union_count == 0.0) {{ - return 0.0; - }} - return 1.0 - intersection / union_count; -}} - -fn compute_distance(x_offset: u32, y_offset: u32, d: u32, metric: u32, p: f32) -> f32 {{ - switch (metric) {{ - case METRIC_EUCLIDEAN: {{ - return sqrt(sqeuclidean_dist(x_offset, y_offset, d)); - }} - case METRIC_SQEUCLIDEAN: {{ - return sqeuclidean_dist(x_offset, y_offset, d); - }} - case METRIC_MANHATTAN: {{ - return manhattan_dist(x_offset, y_offset, d); - }} - case METRIC_CHEBYSHEV: {{ - return chebyshev_dist(x_offset, y_offset, d); - }} - case METRIC_MINKOWSKI: {{ - return minkowski_dist(x_offset, y_offset, d, p); - }} - case METRIC_COSINE: {{ - return cosine_dist(x_offset, y_offset, d); - }} - case METRIC_CORRELATION: {{ - return correlation_dist(x_offset, y_offset, d); - }} - case METRIC_HAMMING: {{ - return hamming_dist(x_offset, y_offset, d); - }} - case METRIC_JACCARD: {{ - return jaccard_dist(x_offset, y_offset, d); - }} - default: {{ - return 0.0; - }} - }} -}} - -@compute @workgroup_size(WORKGROUP_SIZE) -fn main(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - let total = params.n * params.m; - if (idx >= total) {{ - return; - }} - - let i = idx / params.m; - let j = idx % params.m; - - let x_offset = i * params.d; - let y_offset = j * params.d; - - let dist = compute_distance(x_offset, y_offset, params.d, params.metric, params.p); - out[idx] = dist; -}} -"#, - workgroup_size = WORKGROUP_SIZE - ) -} - -/// Generate WGSL shader for pdist operation -fn generate_pdist_shader() -> String { - format!( - r#" -const WORKGROUP_SIZE: u32 = {workgroup_size}u; - -// Distance metric constants (same as cdist) -const METRIC_EUCLIDEAN: u32 = 0u; -const METRIC_SQEUCLIDEAN: u32 = 1u; -const METRIC_MANHATTAN: u32 = 2u; -const METRIC_CHEBYSHEV: u32 = 3u; -const METRIC_MINKOWSKI: u32 = 4u; -const METRIC_COSINE: u32 = 5u; -const METRIC_CORRELATION: u32 = 6u; -const METRIC_HAMMING: u32 = 7u; -const METRIC_JACCARD: u32 = 8u; - -struct Params {{ - n: u32, - d: u32, - metric: u32, - p: f32, -}} - -@group(0) @binding(0) var x: array; -@group(0) @binding(1) var out: array; -@group(0) @binding(2) var params: Params; - -fn sqeuclidean_dist(i_offset: u32, j_offset: u32, d: u32) -> f32 {{ - var sum: f32 = 0.0; - for (var k: u32 = 0u; k < d; k++) {{ - let diff = x[i_offset + k] - x[j_offset + k]; - sum += diff * diff; - }} - return sum; -}} - -fn manhattan_dist(i_offset: u32, j_offset: u32, d: u32) -> f32 {{ - var sum: f32 = 0.0; - for (var k: u32 = 0u; k < d; k++) {{ - sum += abs(x[i_offset + k] - x[j_offset + k]); - }} - return sum; -}} - -fn chebyshev_dist(i_offset: u32, j_offset: u32, d: u32) -> f32 {{ - var max_val: f32 = 0.0; - for (var k: u32 = 0u; k < d; k++) {{ - let abs_diff = abs(x[i_offset + k] - x[j_offset + k]); - if (abs_diff > max_val) {{ - max_val = abs_diff; - }} - }} - return max_val; -}} - -fn minkowski_dist(i_offset: u32, j_offset: u32, d: u32, p: f32) -> f32 {{ - var sum: f32 = 0.0; - for (var k: u32 = 0u; k < d; k++) {{ - sum += pow(abs(x[i_offset + k] - x[j_offset + k]), p); - }} - return pow(sum, 1.0 / p); -}} - -fn cosine_dist(i_offset: u32, j_offset: u32, d: u32) -> f32 {{ - var dot: f32 = 0.0; - var norm_a: f32 = 0.0; - var norm_b: f32 = 0.0; - for (var k: u32 = 0u; k < d; k++) {{ - let ak = x[i_offset + k]; - let bk = x[j_offset + k]; - dot += ak * bk; - norm_a += ak * ak; - norm_b += bk * bk; - }} - let denom = sqrt(norm_a * norm_b); - if (denom == 0.0) {{ - return 0.0; - }} - return 1.0 - dot / denom; -}} - -fn correlation_dist(i_offset: u32, j_offset: u32, d: u32) -> f32 {{ - var sum_a: f32 = 0.0; - var sum_b: f32 = 0.0; - for (var k: u32 = 0u; k < d; k++) {{ - sum_a += x[i_offset + k]; - sum_b += x[j_offset + k]; - }} - let mean_a = sum_a / f32(d); - let mean_b = sum_b / f32(d); - - var cov: f32 = 0.0; - var var_a: f32 = 0.0; - var var_b: f32 = 0.0; - for (var k: u32 = 0u; k < d; k++) {{ - let da = x[i_offset + k] - mean_a; - let db = x[j_offset + k] - mean_b; - cov += da * db; - var_a += da * da; - var_b += db * db; - }} - let denom = sqrt(var_a * var_b); - if (denom == 0.0) {{ - return 0.0; - }} - return 1.0 - cov / denom; -}} - -fn hamming_dist(i_offset: u32, j_offset: u32, d: u32) -> f32 {{ - var count: f32 = 0.0; - for (var k: u32 = 0u; k < d; k++) {{ - if (x[i_offset + k] != x[j_offset + k]) {{ - count += 1.0; - }} - }} - return count / f32(d); -}} - -fn jaccard_dist(i_offset: u32, j_offset: u32, d: u32) -> f32 {{ - var intersection: f32 = 0.0; - var union_count: f32 = 0.0; - for (var k: u32 = 0u; k < d; k++) {{ - let a_nonzero = x[i_offset + k] != 0.0; - let b_nonzero = x[j_offset + k] != 0.0; - if (a_nonzero && b_nonzero) {{ - intersection += 1.0; - }} - if (a_nonzero || b_nonzero) {{ - union_count += 1.0; - }} - }} - if (union_count == 0.0) {{ - return 0.0; - }} - return 1.0 - intersection / union_count; -}} - -fn compute_distance(i_offset: u32, j_offset: u32, d: u32, metric: u32, p: f32) -> f32 {{ - switch (metric) {{ - case METRIC_EUCLIDEAN: {{ - return sqrt(sqeuclidean_dist(i_offset, j_offset, d)); - }} - case METRIC_SQEUCLIDEAN: {{ - return sqeuclidean_dist(i_offset, j_offset, d); - }} - case METRIC_MANHATTAN: {{ - return manhattan_dist(i_offset, j_offset, d); - }} - case METRIC_CHEBYSHEV: {{ - return chebyshev_dist(i_offset, j_offset, d); - }} - case METRIC_MINKOWSKI: {{ - return minkowski_dist(i_offset, j_offset, d, p); - }} - case METRIC_COSINE: {{ - return cosine_dist(i_offset, j_offset, d); - }} - case METRIC_CORRELATION: {{ - return correlation_dist(i_offset, j_offset, d); - }} - case METRIC_HAMMING: {{ - return hamming_dist(i_offset, j_offset, d); - }} - case METRIC_JACCARD: {{ - return jaccard_dist(i_offset, j_offset, d); - }} - default: {{ - return 0.0; - }} - }} -}} - -// Convert condensed index k to (i, j) where i < j -fn condensed_to_ij(k: u32, n: u32) -> vec2 {{ - var i: u32 = 0u; - var count: u32 = 0u; - loop {{ - let row_count = n - 1u - i; - if (count + row_count > k) {{ - let j = k - count + i + 1u; - return vec2(i, j); - }} - count += row_count; - i++; - }} - return vec2(0u, 0u); // Should never reach -}} - -@compute @workgroup_size(WORKGROUP_SIZE) -fn main(@builtin(global_invocation_id) gid: vec3) {{ - let k = gid.x; - let total = params.n * (params.n - 1u) / 2u; - if (k >= total) {{ - return; - }} - - let ij = condensed_to_ij(k, params.n); - let i = ij.x; - let j = ij.y; - - let i_offset = i * params.d; - let j_offset = j * params.d; - - let dist = compute_distance(i_offset, j_offset, params.d, params.metric, params.p); - out[k] = dist; -}} -"#, - workgroup_size = WORKGROUP_SIZE - ) -} - -/// Generate WGSL shader for squareform operation -fn generate_squareform_shader() -> String { - format!( - r#" -const WORKGROUP_SIZE: u32 = {workgroup_size}u; - -struct Params {{ - n: u32, -}} - -@group(0) @binding(0) var condensed: array; -@group(0) @binding(1) var square: array; -@group(0) @binding(2) var params: Params; - -@compute @workgroup_size(WORKGROUP_SIZE) -fn main(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - let total = params.n * params.n; - if (idx >= total) {{ - return; - }} - - let i = idx / params.n; - let j = idx % params.n; - - if (i == j) {{ - // Diagonal is zero - square[idx] = 0.0; - }} else if (i < j) {{ - // Upper triangle: k = n*i - i*(i+1)/2 + j - i - 1 - let k = params.n * i - i * (i + 1u) / 2u + j - i - 1u; - square[idx] = condensed[k]; - }} else {{ - // Lower triangle: mirror from upper - let k = params.n * j - j * (j + 1u) / 2u + i - j - 1u; - square[idx] = condensed[k]; - }} -}} -"#, - workgroup_size = WORKGROUP_SIZE - ) -} - -/// Generate WGSL shader for squareform_inverse operation -fn generate_squareform_inverse_shader() -> String { - format!( - r#" -const WORKGROUP_SIZE: u32 = {workgroup_size}u; - -struct Params {{ - n: u32, -}} - -@group(0) @binding(0) var square: array; -@group(0) @binding(1) var condensed: array; -@group(0) @binding(2) var params: Params; - -// Convert condensed index k to (i, j) where i < j -fn condensed_to_ij(k: u32, n: u32) -> vec2 {{ - var i: u32 = 0u; - var count: u32 = 0u; - loop {{ - let row_count = n - 1u - i; - if (count + row_count > k) {{ - let j = k - count + i + 1u; - return vec2(i, j); - }} - count += row_count; - i++; - }} - return vec2(0u, 0u); -}} - -@compute @workgroup_size(WORKGROUP_SIZE) -fn main(@builtin(global_invocation_id) gid: vec3) {{ - let k = gid.x; - let total = params.n * (params.n - 1u) / 2u; - if (k >= total) {{ - return; - }} - - let ij = condensed_to_ij(k, params.n); - let i = ij.x; - let j = ij.y; - - condensed[k] = square[i * params.n + j]; -}} -"#, - workgroup_size = WORKGROUP_SIZE - ) -} - /// Launch cdist kernel - pairwise distances between two point sets. pub fn launch_cdist( cache: &PipelineCache, @@ -557,8 +62,7 @@ pub fn launch_cdist( check_float_dtype(dtype, "cdist")?; let name = "cdist_f32"; - let shader = generate_cdist_shader(); - let module = cache.get_or_create_module(name, &shader); + let module = cache.get_or_create_module(name, CDIST_F32); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 3, num_uniform_buffers: 1, @@ -601,8 +105,7 @@ pub fn launch_pdist( check_float_dtype(dtype, "pdist")?; let name = "pdist_f32"; - let shader = generate_pdist_shader(); - let module = cache.get_or_create_module(name, &shader); + let module = cache.get_or_create_module(name, PDIST_F32); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, @@ -645,8 +148,7 @@ pub fn launch_squareform( check_float_dtype(dtype, "squareform")?; let name = "squareform_f32"; - let shader = generate_squareform_shader(); - let module = cache.get_or_create_module(name, &shader); + let module = cache.get_or_create_module(name, SQUAREFORM_F32); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, @@ -689,8 +191,7 @@ pub fn launch_squareform_inverse( check_float_dtype(dtype, "squareform_inverse")?; let name = "squareform_inverse_f32"; - let shader = generate_squareform_inverse_shader(); - let module = cache.get_or_create_module(name, &shader); + let module = cache.get_or_create_module(name, SQUAREFORM_INVERSE_F32); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, diff --git a/src/runtime/wgpu/shaders/distance_cdist_f32.wgsl b/src/runtime/wgpu/shaders/distance_cdist_f32.wgsl new file mode 100644 index 00000000..4bd89880 --- /dev/null +++ b/src/runtime/wgpu/shaders/distance_cdist_f32.wgsl @@ -0,0 +1,188 @@ +const WORKGROUP_SIZE: u32 = 256u; + +// Distance metric constants +const METRIC_EUCLIDEAN: u32 = 0u; +const METRIC_SQEUCLIDEAN: u32 = 1u; +const METRIC_MANHATTAN: u32 = 2u; +const METRIC_CHEBYSHEV: u32 = 3u; +const METRIC_MINKOWSKI: u32 = 4u; +const METRIC_COSINE: u32 = 5u; +const METRIC_CORRELATION: u32 = 6u; +const METRIC_HAMMING: u32 = 7u; +const METRIC_JACCARD: u32 = 8u; + +struct Params { + n: u32, + m: u32, + d: u32, + metric: u32, + p: f32, +} + +@group(0) @binding(0) var x: array; +@group(0) @binding(1) var y: array; +@group(0) @binding(2) var out: array; +@group(0) @binding(3) var params: Params; + +fn sqeuclidean_dist(x_offset: u32, y_offset: u32, d: u32) -> f32 { + var sum: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + let diff = x[x_offset + k] - y[y_offset + k]; + sum += diff * diff; + } + return sum; +} + +fn manhattan_dist(x_offset: u32, y_offset: u32, d: u32) -> f32 { + var sum: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + sum += abs(x[x_offset + k] - y[y_offset + k]); + } + return sum; +} + +fn chebyshev_dist(x_offset: u32, y_offset: u32, d: u32) -> f32 { + var max_val: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + let abs_diff = abs(x[x_offset + k] - y[y_offset + k]); + if (abs_diff > max_val) { + max_val = abs_diff; + } + } + return max_val; +} + +fn minkowski_dist(x_offset: u32, y_offset: u32, d: u32, p: f32) -> f32 { + var sum: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + sum += pow(abs(x[x_offset + k] - y[y_offset + k]), p); + } + return pow(sum, 1.0 / p); +} + +fn cosine_dist(x_offset: u32, y_offset: u32, d: u32) -> f32 { + var dot: f32 = 0.0; + var norm_a: f32 = 0.0; + var norm_b: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + let ak = x[x_offset + k]; + let bk = y[y_offset + k]; + dot += ak * bk; + norm_a += ak * ak; + norm_b += bk * bk; + } + let denom = sqrt(norm_a * norm_b); + if (denom == 0.0) { + return 0.0; + } + return 1.0 - dot / denom; +} + +fn correlation_dist(x_offset: u32, y_offset: u32, d: u32) -> f32 { + var sum_a: f32 = 0.0; + var sum_b: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + sum_a += x[x_offset + k]; + sum_b += y[y_offset + k]; + } + let mean_a = sum_a / f32(d); + let mean_b = sum_b / f32(d); + + var cov: f32 = 0.0; + var var_a: f32 = 0.0; + var var_b: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + let da = x[x_offset + k] - mean_a; + let db = y[y_offset + k] - mean_b; + cov += da * db; + var_a += da * da; + var_b += db * db; + } + let denom = sqrt(var_a * var_b); + if (denom == 0.0) { + return 0.0; + } + return 1.0 - cov / denom; +} + +fn hamming_dist(x_offset: u32, y_offset: u32, d: u32) -> f32 { + var count: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + if (x[x_offset + k] != y[y_offset + k]) { + count += 1.0; + } + } + return count / f32(d); +} + +fn jaccard_dist(x_offset: u32, y_offset: u32, d: u32) -> f32 { + var intersection: f32 = 0.0; + var union_count: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + let a_nonzero = x[x_offset + k] != 0.0; + let b_nonzero = y[y_offset + k] != 0.0; + if (a_nonzero && b_nonzero) { + intersection += 1.0; + } + if (a_nonzero || b_nonzero) { + union_count += 1.0; + } + } + if (union_count == 0.0) { + return 0.0; + } + return 1.0 - intersection / union_count; +} + +fn compute_distance(x_offset: u32, y_offset: u32, d: u32, metric: u32, p: f32) -> f32 { + switch (metric) { + case METRIC_EUCLIDEAN: { + return sqrt(sqeuclidean_dist(x_offset, y_offset, d)); + } + case METRIC_SQEUCLIDEAN: { + return sqeuclidean_dist(x_offset, y_offset, d); + } + case METRIC_MANHATTAN: { + return manhattan_dist(x_offset, y_offset, d); + } + case METRIC_CHEBYSHEV: { + return chebyshev_dist(x_offset, y_offset, d); + } + case METRIC_MINKOWSKI: { + return minkowski_dist(x_offset, y_offset, d, p); + } + case METRIC_COSINE: { + return cosine_dist(x_offset, y_offset, d); + } + case METRIC_CORRELATION: { + return correlation_dist(x_offset, y_offset, d); + } + case METRIC_HAMMING: { + return hamming_dist(x_offset, y_offset, d); + } + case METRIC_JACCARD: { + return jaccard_dist(x_offset, y_offset, d); + } + default: { + return 0.0; + } + } +} + +@compute @workgroup_size(WORKGROUP_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = params.n * params.m; + if (idx >= total) { + return; + } + + let i = idx / params.m; + let j = idx % params.m; + + let x_offset = i * params.d; + let y_offset = j * params.d; + + let dist = compute_distance(x_offset, y_offset, params.d, params.metric, params.p); + out[idx] = dist; +} diff --git a/src/runtime/wgpu/shaders/distance_f32.wgsl b/src/runtime/wgpu/shaders/distance_f32.wgsl new file mode 100644 index 00000000..5a339af6 --- /dev/null +++ b/src/runtime/wgpu/shaders/distance_f32.wgsl @@ -0,0 +1,473 @@ +// Distance computation shaders - F32 +// +// cdist_f32: Pairwise distances between two point sets +// pdist_f32: Pairwise distances within one point set (condensed) +// squareform_f32: Condensed to square distance matrix +// squareform_inverse_f32: Square to condensed distance matrix + +const WORKGROUP_SIZE: u32 = 256u; + +// Distance metric constants +const METRIC_EUCLIDEAN: u32 = 0u; +const METRIC_SQEUCLIDEAN: u32 = 1u; +const METRIC_MANHATTAN: u32 = 2u; +const METRIC_CHEBYSHEV: u32 = 3u; +const METRIC_MINKOWSKI: u32 = 4u; +const METRIC_COSINE: u32 = 5u; +const METRIC_CORRELATION: u32 = 6u; +const METRIC_HAMMING: u32 = 7u; +const METRIC_JACCARD: u32 = 8u; + +// ============================================================================ +// cdist_f32 +// ============================================================================ + +struct CdistParams { + n: u32, + m: u32, + d: u32, + metric: u32, + p: f32, +} + +@group(0) @binding(0) var cdist_x: array; +@group(0) @binding(1) var cdist_y: array; +@group(0) @binding(2) var cdist_out: array; +@group(0) @binding(3) var cdist_params: CdistParams; + +fn cdist_sqeuclidean(x_offset: u32, y_offset: u32, d: u32) -> f32 { + var sum: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + let diff = cdist_x[x_offset + k] - cdist_y[y_offset + k]; + sum += diff * diff; + } + return sum; +} + +fn cdist_manhattan(x_offset: u32, y_offset: u32, d: u32) -> f32 { + var sum: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + sum += abs(cdist_x[x_offset + k] - cdist_y[y_offset + k]); + } + return sum; +} + +fn cdist_chebyshev(x_offset: u32, y_offset: u32, d: u32) -> f32 { + var max_val: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + let abs_diff = abs(cdist_x[x_offset + k] - cdist_y[y_offset + k]); + if (abs_diff > max_val) { + max_val = abs_diff; + } + } + return max_val; +} + +fn cdist_minkowski(x_offset: u32, y_offset: u32, d: u32, p: f32) -> f32 { + var sum: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + sum += pow(abs(cdist_x[x_offset + k] - cdist_y[y_offset + k]), p); + } + return pow(sum, 1.0 / p); +} + +fn cdist_cosine(x_offset: u32, y_offset: u32, d: u32) -> f32 { + var dot: f32 = 0.0; + var norm_a: f32 = 0.0; + var norm_b: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + let ak = cdist_x[x_offset + k]; + let bk = cdist_y[y_offset + k]; + dot += ak * bk; + norm_a += ak * ak; + norm_b += bk * bk; + } + let denom = sqrt(norm_a * norm_b); + if (denom == 0.0) { + return 0.0; + } + return 1.0 - dot / denom; +} + +fn cdist_correlation(x_offset: u32, y_offset: u32, d: u32) -> f32 { + var sum_a: f32 = 0.0; + var sum_b: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + sum_a += cdist_x[x_offset + k]; + sum_b += cdist_y[y_offset + k]; + } + let mean_a = sum_a / f32(d); + let mean_b = sum_b / f32(d); + + var cov: f32 = 0.0; + var var_a: f32 = 0.0; + var var_b: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + let da = cdist_x[x_offset + k] - mean_a; + let db = cdist_y[y_offset + k] - mean_b; + cov += da * db; + var_a += da * da; + var_b += db * db; + } + let denom = sqrt(var_a * var_b); + if (denom == 0.0) { + return 0.0; + } + return 1.0 - cov / denom; +} + +fn cdist_hamming(x_offset: u32, y_offset: u32, d: u32) -> f32 { + var count: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + if (cdist_x[x_offset + k] != cdist_y[y_offset + k]) { + count += 1.0; + } + } + return count / f32(d); +} + +fn cdist_jaccard(x_offset: u32, y_offset: u32, d: u32) -> f32 { + var intersection: f32 = 0.0; + var union_count: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + let a_nonzero = cdist_x[x_offset + k] != 0.0; + let b_nonzero = cdist_y[y_offset + k] != 0.0; + if (a_nonzero && b_nonzero) { + intersection += 1.0; + } + if (a_nonzero || b_nonzero) { + union_count += 1.0; + } + } + if (union_count == 0.0) { + return 0.0; + } + return 1.0 - intersection / union_count; +} + +fn cdist_compute_distance(x_offset: u32, y_offset: u32, d: u32, metric: u32, p: f32) -> f32 { + switch (metric) { + case METRIC_EUCLIDEAN: { + return sqrt(cdist_sqeuclidean(x_offset, y_offset, d)); + } + case METRIC_SQEUCLIDEAN: { + return cdist_sqeuclidean(x_offset, y_offset, d); + } + case METRIC_MANHATTAN: { + return cdist_manhattan(x_offset, y_offset, d); + } + case METRIC_CHEBYSHEV: { + return cdist_chebyshev(x_offset, y_offset, d); + } + case METRIC_MINKOWSKI: { + return cdist_minkowski(x_offset, y_offset, d, p); + } + case METRIC_COSINE: { + return cdist_cosine(x_offset, y_offset, d); + } + case METRIC_CORRELATION: { + return cdist_correlation(x_offset, y_offset, d); + } + case METRIC_HAMMING: { + return cdist_hamming(x_offset, y_offset, d); + } + case METRIC_JACCARD: { + return cdist_jaccard(x_offset, y_offset, d); + } + default: { + return 0.0; + } + } +} + +@compute @workgroup_size(WORKGROUP_SIZE) +fn cdist_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = cdist_params.n * cdist_params.m; + if (idx >= total) { + return; + } + + let i = idx / cdist_params.m; + let j = idx % cdist_params.m; + + let x_offset = i * cdist_params.d; + let y_offset = j * cdist_params.d; + + let dist = cdist_compute_distance(x_offset, y_offset, cdist_params.d, cdist_params.metric, cdist_params.p); + cdist_out[idx] = dist; +} + +// ============================================================================ +// pdist_f32 +// ============================================================================ + +struct PdistParams { + n: u32, + d: u32, + metric: u32, + p: f32, +} + +@group(0) @binding(0) var pdist_x: array; +@group(0) @binding(1) var pdist_out: array; +@group(0) @binding(2) var pdist_params: PdistParams; + +fn pdist_sqeuclidean(i_offset: u32, j_offset: u32, d: u32) -> f32 { + var sum: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + let diff = pdist_x[i_offset + k] - pdist_x[j_offset + k]; + sum += diff * diff; + } + return sum; +} + +fn pdist_manhattan(i_offset: u32, j_offset: u32, d: u32) -> f32 { + var sum: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + sum += abs(pdist_x[i_offset + k] - pdist_x[j_offset + k]); + } + return sum; +} + +fn pdist_chebyshev(i_offset: u32, j_offset: u32, d: u32) -> f32 { + var max_val: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + let abs_diff = abs(pdist_x[i_offset + k] - pdist_x[j_offset + k]); + if (abs_diff > max_val) { + max_val = abs_diff; + } + } + return max_val; +} + +fn pdist_minkowski(i_offset: u32, j_offset: u32, d: u32, p: f32) -> f32 { + var sum: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + sum += pow(abs(pdist_x[i_offset + k] - pdist_x[j_offset + k]), p); + } + return pow(sum, 1.0 / p); +} + +fn pdist_cosine(i_offset: u32, j_offset: u32, d: u32) -> f32 { + var dot: f32 = 0.0; + var norm_a: f32 = 0.0; + var norm_b: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + let ak = pdist_x[i_offset + k]; + let bk = pdist_x[j_offset + k]; + dot += ak * bk; + norm_a += ak * ak; + norm_b += bk * bk; + } + let denom = sqrt(norm_a * norm_b); + if (denom == 0.0) { + return 0.0; + } + return 1.0 - dot / denom; +} + +fn pdist_correlation(i_offset: u32, j_offset: u32, d: u32) -> f32 { + var sum_a: f32 = 0.0; + var sum_b: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + sum_a += pdist_x[i_offset + k]; + sum_b += pdist_x[j_offset + k]; + } + let mean_a = sum_a / f32(d); + let mean_b = sum_b / f32(d); + + var cov: f32 = 0.0; + var var_a: f32 = 0.0; + var var_b: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + let da = pdist_x[i_offset + k] - mean_a; + let db = pdist_x[j_offset + k] - mean_b; + cov += da * db; + var_a += da * da; + var_b += db * db; + } + let denom = sqrt(var_a * var_b); + if (denom == 0.0) { + return 0.0; + } + return 1.0 - cov / denom; +} + +fn pdist_hamming(i_offset: u32, j_offset: u32, d: u32) -> f32 { + var count: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + if (pdist_x[i_offset + k] != pdist_x[j_offset + k]) { + count += 1.0; + } + } + return count / f32(d); +} + +fn pdist_jaccard(i_offset: u32, j_offset: u32, d: u32) -> f32 { + var intersection: f32 = 0.0; + var union_count: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + let a_nonzero = pdist_x[i_offset + k] != 0.0; + let b_nonzero = pdist_x[j_offset + k] != 0.0; + if (a_nonzero && b_nonzero) { + intersection += 1.0; + } + if (a_nonzero || b_nonzero) { + union_count += 1.0; + } + } + if (union_count == 0.0) { + return 0.0; + } + return 1.0 - intersection / union_count; +} + +fn pdist_compute_distance(i_offset: u32, j_offset: u32, d: u32, metric: u32, p: f32) -> f32 { + switch (metric) { + case METRIC_EUCLIDEAN: { + return sqrt(pdist_sqeuclidean(i_offset, j_offset, d)); + } + case METRIC_SQEUCLIDEAN: { + return pdist_sqeuclidean(i_offset, j_offset, d); + } + case METRIC_MANHATTAN: { + return pdist_manhattan(i_offset, j_offset, d); + } + case METRIC_CHEBYSHEV: { + return pdist_chebyshev(i_offset, j_offset, d); + } + case METRIC_MINKOWSKI: { + return pdist_minkowski(i_offset, j_offset, d, p); + } + case METRIC_COSINE: { + return pdist_cosine(i_offset, j_offset, d); + } + case METRIC_CORRELATION: { + return pdist_correlation(i_offset, j_offset, d); + } + case METRIC_HAMMING: { + return pdist_hamming(i_offset, j_offset, d); + } + case METRIC_JACCARD: { + return pdist_jaccard(i_offset, j_offset, d); + } + default: { + return 0.0; + } + } +} + +// Convert condensed index k to (i, j) where i < j +fn pdist_condensed_to_ij(k: u32, n: u32) -> vec2 { + var i: u32 = 0u; + var count: u32 = 0u; + loop { + let row_count = n - 1u - i; + if (count + row_count > k) { + let j = k - count + i + 1u; + return vec2(i, j); + } + count += row_count; + i++; + } + return vec2(0u, 0u); // Should never reach +} + +@compute @workgroup_size(WORKGROUP_SIZE) +fn pdist_f32(@builtin(global_invocation_id) gid: vec3) { + let k = gid.x; + let total = pdist_params.n * (pdist_params.n - 1u) / 2u; + if (k >= total) { + return; + } + + let ij = pdist_condensed_to_ij(k, pdist_params.n); + let i = ij.x; + let j = ij.y; + + let i_offset = i * pdist_params.d; + let j_offset = j * pdist_params.d; + + let dist = pdist_compute_distance(i_offset, j_offset, pdist_params.d, pdist_params.metric, pdist_params.p); + pdist_out[k] = dist; +} + +// ============================================================================ +// squareform_f32 +// ============================================================================ + +struct SquareformParams { + n: u32, +} + +@group(0) @binding(0) var sqf_condensed: array; +@group(0) @binding(1) var sqf_square: array; +@group(0) @binding(2) var sqf_params: SquareformParams; + +@compute @workgroup_size(WORKGROUP_SIZE) +fn squareform_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = sqf_params.n * sqf_params.n; + if (idx >= total) { + return; + } + + let i = idx / sqf_params.n; + let j = idx % sqf_params.n; + + if (i == j) { + // Diagonal is zero + sqf_square[idx] = 0.0; + } else if (i < j) { + // Upper triangle: k = n*i - i*(i+1)/2 + j - i - 1 + let k = sqf_params.n * i - i * (i + 1u) / 2u + j - i - 1u; + sqf_square[idx] = sqf_condensed[k]; + } else { + // Lower triangle: mirror from upper + let k = sqf_params.n * j - j * (j + 1u) / 2u + i - j - 1u; + sqf_square[idx] = sqf_condensed[k]; + } +} + +// ============================================================================ +// squareform_inverse_f32 +// ============================================================================ + +struct SquareformInverseParams { + n: u32, +} + +@group(0) @binding(0) var sqfi_square: array; +@group(0) @binding(1) var sqfi_condensed: array; +@group(0) @binding(2) var sqfi_params: SquareformInverseParams; + +fn sqfi_condensed_to_ij(k: u32, n: u32) -> vec2 { + var i: u32 = 0u; + var count: u32 = 0u; + loop { + let row_count = n - 1u - i; + if (count + row_count > k) { + let j = k - count + i + 1u; + return vec2(i, j); + } + count += row_count; + i++; + } + return vec2(0u, 0u); +} + +@compute @workgroup_size(WORKGROUP_SIZE) +fn squareform_inverse_f32(@builtin(global_invocation_id) gid: vec3) { + let k = gid.x; + let total = sqfi_params.n * (sqfi_params.n - 1u) / 2u; + if (k >= total) { + return; + } + + let ij = sqfi_condensed_to_ij(k, sqfi_params.n); + let i = ij.x; + let j = ij.y; + + sqfi_condensed[k] = sqfi_square[i * sqfi_params.n + j]; +} diff --git a/src/runtime/wgpu/shaders/distance_pdist_f32.wgsl b/src/runtime/wgpu/shaders/distance_pdist_f32.wgsl new file mode 100644 index 00000000..3ff19b93 --- /dev/null +++ b/src/runtime/wgpu/shaders/distance_pdist_f32.wgsl @@ -0,0 +1,203 @@ +const WORKGROUP_SIZE: u32 = 256u; + +// Distance metric constants (same as cdist) +const METRIC_EUCLIDEAN: u32 = 0u; +const METRIC_SQEUCLIDEAN: u32 = 1u; +const METRIC_MANHATTAN: u32 = 2u; +const METRIC_CHEBYSHEV: u32 = 3u; +const METRIC_MINKOWSKI: u32 = 4u; +const METRIC_COSINE: u32 = 5u; +const METRIC_CORRELATION: u32 = 6u; +const METRIC_HAMMING: u32 = 7u; +const METRIC_JACCARD: u32 = 8u; + +struct Params { + n: u32, + d: u32, + metric: u32, + p: f32, +} + +@group(0) @binding(0) var x: array; +@group(0) @binding(1) var out: array; +@group(0) @binding(2) var params: Params; + +fn sqeuclidean_dist(i_offset: u32, j_offset: u32, d: u32) -> f32 { + var sum: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + let diff = x[i_offset + k] - x[j_offset + k]; + sum += diff * diff; + } + return sum; +} + +fn manhattan_dist(i_offset: u32, j_offset: u32, d: u32) -> f32 { + var sum: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + sum += abs(x[i_offset + k] - x[j_offset + k]); + } + return sum; +} + +fn chebyshev_dist(i_offset: u32, j_offset: u32, d: u32) -> f32 { + var max_val: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + let abs_diff = abs(x[i_offset + k] - x[j_offset + k]); + if (abs_diff > max_val) { + max_val = abs_diff; + } + } + return max_val; +} + +fn minkowski_dist(i_offset: u32, j_offset: u32, d: u32, p: f32) -> f32 { + var sum: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + sum += pow(abs(x[i_offset + k] - x[j_offset + k]), p); + } + return pow(sum, 1.0 / p); +} + +fn cosine_dist(i_offset: u32, j_offset: u32, d: u32) -> f32 { + var dot: f32 = 0.0; + var norm_a: f32 = 0.0; + var norm_b: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + let ak = x[i_offset + k]; + let bk = x[j_offset + k]; + dot += ak * bk; + norm_a += ak * ak; + norm_b += bk * bk; + } + let denom = sqrt(norm_a * norm_b); + if (denom == 0.0) { + return 0.0; + } + return 1.0 - dot / denom; +} + +fn correlation_dist(i_offset: u32, j_offset: u32, d: u32) -> f32 { + var sum_a: f32 = 0.0; + var sum_b: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + sum_a += x[i_offset + k]; + sum_b += x[j_offset + k]; + } + let mean_a = sum_a / f32(d); + let mean_b = sum_b / f32(d); + + var cov: f32 = 0.0; + var var_a: f32 = 0.0; + var var_b: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + let da = x[i_offset + k] - mean_a; + let db = x[j_offset + k] - mean_b; + cov += da * db; + var_a += da * da; + var_b += db * db; + } + let denom = sqrt(var_a * var_b); + if (denom == 0.0) { + return 0.0; + } + return 1.0 - cov / denom; +} + +fn hamming_dist(i_offset: u32, j_offset: u32, d: u32) -> f32 { + var count: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + if (x[i_offset + k] != x[j_offset + k]) { + count += 1.0; + } + } + return count / f32(d); +} + +fn jaccard_dist(i_offset: u32, j_offset: u32, d: u32) -> f32 { + var intersection: f32 = 0.0; + var union_count: f32 = 0.0; + for (var k: u32 = 0u; k < d; k++) { + let a_nonzero = x[i_offset + k] != 0.0; + let b_nonzero = x[j_offset + k] != 0.0; + if (a_nonzero && b_nonzero) { + intersection += 1.0; + } + if (a_nonzero || b_nonzero) { + union_count += 1.0; + } + } + if (union_count == 0.0) { + return 0.0; + } + return 1.0 - intersection / union_count; +} + +fn compute_distance(i_offset: u32, j_offset: u32, d: u32, metric: u32, p: f32) -> f32 { + switch (metric) { + case METRIC_EUCLIDEAN: { + return sqrt(sqeuclidean_dist(i_offset, j_offset, d)); + } + case METRIC_SQEUCLIDEAN: { + return sqeuclidean_dist(i_offset, j_offset, d); + } + case METRIC_MANHATTAN: { + return manhattan_dist(i_offset, j_offset, d); + } + case METRIC_CHEBYSHEV: { + return chebyshev_dist(i_offset, j_offset, d); + } + case METRIC_MINKOWSKI: { + return minkowski_dist(i_offset, j_offset, d, p); + } + case METRIC_COSINE: { + return cosine_dist(i_offset, j_offset, d); + } + case METRIC_CORRELATION: { + return correlation_dist(i_offset, j_offset, d); + } + case METRIC_HAMMING: { + return hamming_dist(i_offset, j_offset, d); + } + case METRIC_JACCARD: { + return jaccard_dist(i_offset, j_offset, d); + } + default: { + return 0.0; + } + } +} + +// Convert condensed index k to (i, j) where i < j +fn condensed_to_ij(k: u32, n: u32) -> vec2 { + var i: u32 = 0u; + var count: u32 = 0u; + loop { + let row_count = n - 1u - i; + if (count + row_count > k) { + let j = k - count + i + 1u; + return vec2(i, j); + } + count += row_count; + i++; + } + return vec2(0u, 0u); // Should never reach +} + +@compute @workgroup_size(WORKGROUP_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3) { + let k = gid.x; + let total = params.n * (params.n - 1u) / 2u; + if (k >= total) { + return; + } + + let ij = condensed_to_ij(k, params.n); + let i = ij.x; + let j = ij.y; + + let i_offset = i * params.d; + let j_offset = j * params.d; + + let dist = compute_distance(i_offset, j_offset, params.d, params.metric, params.p); + out[k] = dist; +} diff --git a/src/runtime/wgpu/shaders/distance_squareform_f32.wgsl b/src/runtime/wgpu/shaders/distance_squareform_f32.wgsl new file mode 100644 index 00000000..3fef8fa6 --- /dev/null +++ b/src/runtime/wgpu/shaders/distance_squareform_f32.wgsl @@ -0,0 +1,34 @@ +const WORKGROUP_SIZE: u32 = 256u; + +struct Params { + n: u32, +} + +@group(0) @binding(0) var condensed: array; +@group(0) @binding(1) var square: array; +@group(0) @binding(2) var params: Params; + +@compute @workgroup_size(WORKGROUP_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = params.n * params.n; + if (idx >= total) { + return; + } + + let i = idx / params.n; + let j = idx % params.n; + + if (i == j) { + // Diagonal is zero + square[idx] = 0.0; + } else if (i < j) { + // Upper triangle: k = n*i - i*(i+1)/2 + j - i - 1 + let k = params.n * i - i * (i + 1u) / 2u + j - i - 1u; + square[idx] = condensed[k]; + } else { + // Lower triangle: mirror from upper + let k = params.n * j - j * (j + 1u) / 2u + i - j - 1u; + square[idx] = condensed[k]; + } +} diff --git a/src/runtime/wgpu/shaders/distance_squareform_inverse_f32.wgsl b/src/runtime/wgpu/shaders/distance_squareform_inverse_f32.wgsl new file mode 100644 index 00000000..d374cde0 --- /dev/null +++ b/src/runtime/wgpu/shaders/distance_squareform_inverse_f32.wgsl @@ -0,0 +1,40 @@ +const WORKGROUP_SIZE: u32 = 256u; + +struct Params { + n: u32, +} + +@group(0) @binding(0) var square: array; +@group(0) @binding(1) var condensed: array; +@group(0) @binding(2) var params: Params; + +// Convert condensed index k to (i, j) where i < j +fn condensed_to_ij(k: u32, n: u32) -> vec2 { + var i: u32 = 0u; + var count: u32 = 0u; + loop { + let row_count = n - 1u - i; + if (count + row_count > k) { + let j = k - count + i + 1u; + return vec2(i, j); + } + count += row_count; + i++; + } + return vec2(0u, 0u); +} + +@compute @workgroup_size(WORKGROUP_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3) { + let k = gid.x; + let total = params.n * (params.n - 1u) / 2u; + if (k >= total) { + return; + } + + let ij = condensed_to_ij(k, params.n); + let i = ij.x; + let j = ij.y; + + condensed[k] = square[i * params.n + j]; +} diff --git a/src/runtime/wgpu/shaders/distributions.rs b/src/runtime/wgpu/shaders/distributions.rs index d7144a44..3844eaf2 100644 --- a/src/runtime/wgpu/shaders/distributions.rs +++ b/src/runtime/wgpu/shaders/distributions.rs @@ -1,4 +1,4 @@ -//! Distribution sampling WGSL kernel launchers +//! Distribution sampling WGSL kernel launchers (F32 only on WebGPU) //! //! Provides launchers for probability distribution sampling: //! - Bernoulli, Beta, Gamma, Exponential, Poisson @@ -6,16 +6,43 @@ use wgpu::{Buffer, Queue}; -use super::generator::{ - generate_bernoulli_shader, generate_beta_dist_shader, generate_binomial_shader, - generate_chi_squared_shader, generate_exponential_shader, generate_f_distribution_shader, - generate_gamma_dist_shader, generate_laplace_shader, generate_multinomial_count_shader, - generate_poisson_shader, generate_student_t_shader, -}; use super::pipeline::{LayoutKey, PipelineCache, workgroup_count}; use crate::dtype::DType; use crate::error::{Error, Result}; +const BERNOULLI_SHADER: &str = include_str!("bernoulli_f32.wgsl"); +// entry point: "bernoulli_f32" + +const BETA_DIST_SHADER: &str = include_str!("beta_dist_f32.wgsl"); +// entry point: "beta_dist_f32" + +const GAMMA_DIST_SHADER: &str = include_str!("gamma_dist_f32.wgsl"); +// entry point: "gamma_dist_f32" + +const EXPONENTIAL_SHADER: &str = include_str!("exponential_f32.wgsl"); +// entry point: "exponential_f32" + +const POISSON_SHADER: &str = include_str!("poisson_f32.wgsl"); +// entry point: "poisson_f32" + +const BINOMIAL_SHADER: &str = include_str!("binomial_f32.wgsl"); +// entry point: "binomial_f32" + +const LAPLACE_SHADER: &str = include_str!("laplace_f32.wgsl"); +// entry point: "laplace_f32" + +const CHI_SQUARED_SHADER: &str = include_str!("chi_squared_f32.wgsl"); +// entry point: "chi_squared_f32" + +const STUDENT_T_SHADER: &str = include_str!("student_t_f32.wgsl"); +// entry point: "student_t_f32" + +const F_DISTRIBUTION_SHADER: &str = include_str!("f_distribution_f32.wgsl"); +// entry point: "f_distribution_f32" + +const MULTINOMIAL_COUNT_SHADER: &str = include_str!("multinomial_count_f32.wgsl"); +// entry point: "multinomial_count_f32" + fn check_float_dtype(dtype: DType, op: &'static str) -> Result<()> { match dtype { DType::F32 => Ok(()), @@ -37,15 +64,13 @@ pub fn launch_bernoulli( } check_float_dtype(dtype, "bernoulli")?; - let name = "bernoulli_f32"; - let shader = generate_bernoulli_shader(dtype)?; - let module = cache.get_or_create_module(name, &shader); + let module = cache.get_or_create_module("bernoulli_f32", BERNOULLI_SHADER); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 1, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = cache.get_or_create_pipeline("bernoulli_f32", "bernoulli_f32", &module, &layout); let bind_group = cache.create_bind_group(&layout, &[out, params]); let mut encoder = cache @@ -80,15 +105,13 @@ pub fn launch_beta_dist( } check_float_dtype(dtype, "beta")?; - let name = "beta_dist_f32"; - let shader = generate_beta_dist_shader(dtype)?; - let module = cache.get_or_create_module(name, &shader); + let module = cache.get_or_create_module("beta_dist_f32", BETA_DIST_SHADER); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 1, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = cache.get_or_create_pipeline("beta_dist_f32", "beta_dist_f32", &module, &layout); let bind_group = cache.create_bind_group(&layout, &[out, params]); let mut encoder = cache @@ -123,15 +146,14 @@ pub fn launch_gamma_dist( } check_float_dtype(dtype, "gamma")?; - let name = "gamma_dist_f32"; - let shader = generate_gamma_dist_shader(dtype)?; - let module = cache.get_or_create_module(name, &shader); + let module = cache.get_or_create_module("gamma_dist_f32", GAMMA_DIST_SHADER); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 1, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = + cache.get_or_create_pipeline("gamma_dist_f32", "gamma_dist_f32", &module, &layout); let bind_group = cache.create_bind_group(&layout, &[out, params]); let mut encoder = cache @@ -166,15 +188,14 @@ pub fn launch_exponential( } check_float_dtype(dtype, "exponential")?; - let name = "exponential_f32"; - let shader = generate_exponential_shader(dtype)?; - let module = cache.get_or_create_module(name, &shader); + let module = cache.get_or_create_module("exponential_f32", EXPONENTIAL_SHADER); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 1, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = + cache.get_or_create_pipeline("exponential_f32", "exponential_f32", &module, &layout); let bind_group = cache.create_bind_group(&layout, &[out, params]); let mut encoder = cache @@ -209,15 +230,13 @@ pub fn launch_poisson( } check_float_dtype(dtype, "poisson")?; - let name = "poisson_f32"; - let shader = generate_poisson_shader(dtype)?; - let module = cache.get_or_create_module(name, &shader); + let module = cache.get_or_create_module("poisson_f32", POISSON_SHADER); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 1, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = cache.get_or_create_pipeline("poisson_f32", "poisson_f32", &module, &layout); let bind_group = cache.create_bind_group(&layout, &[out, params]); let mut encoder = cache @@ -252,15 +271,13 @@ pub fn launch_binomial( } check_float_dtype(dtype, "binomial")?; - let name = "binomial_f32"; - let shader = generate_binomial_shader(dtype)?; - let module = cache.get_or_create_module(name, &shader); + let module = cache.get_or_create_module("binomial_f32", BINOMIAL_SHADER); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 1, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = cache.get_or_create_pipeline("binomial_f32", "binomial_f32", &module, &layout); let bind_group = cache.create_bind_group(&layout, &[out, params]); let mut encoder = cache @@ -295,15 +312,13 @@ pub fn launch_laplace( } check_float_dtype(dtype, "laplace")?; - let name = "laplace_f32"; - let shader = generate_laplace_shader(dtype)?; - let module = cache.get_or_create_module(name, &shader); + let module = cache.get_or_create_module("laplace_f32", LAPLACE_SHADER); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 1, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = cache.get_or_create_pipeline("laplace_f32", "laplace_f32", &module, &layout); let bind_group = cache.create_bind_group(&layout, &[out, params]); let mut encoder = cache @@ -338,15 +353,14 @@ pub fn launch_chi_squared( } check_float_dtype(dtype, "chi_squared")?; - let name = "chi_squared_f32"; - let shader = generate_chi_squared_shader(dtype)?; - let module = cache.get_or_create_module(name, &shader); + let module = cache.get_or_create_module("chi_squared_f32", CHI_SQUARED_SHADER); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 1, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = + cache.get_or_create_pipeline("chi_squared_f32", "chi_squared_f32", &module, &layout); let bind_group = cache.create_bind_group(&layout, &[out, params]); let mut encoder = cache @@ -381,15 +395,13 @@ pub fn launch_student_t( } check_float_dtype(dtype, "student_t")?; - let name = "student_t_f32"; - let shader = generate_student_t_shader(dtype)?; - let module = cache.get_or_create_module(name, &shader); + let module = cache.get_or_create_module("student_t_f32", STUDENT_T_SHADER); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 1, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = cache.get_or_create_pipeline("student_t_f32", "student_t_f32", &module, &layout); let bind_group = cache.create_bind_group(&layout, &[out, params]); let mut encoder = cache @@ -424,15 +436,14 @@ pub fn launch_f_distribution( } check_float_dtype(dtype, "f_distribution")?; - let name = "f_distribution_f32"; - let shader = generate_f_distribution_shader(dtype)?; - let module = cache.get_or_create_module(name, &shader); + let module = cache.get_or_create_module("f_distribution_f32", F_DISTRIBUTION_SHADER); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 1, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = + cache.get_or_create_pipeline("f_distribution_f32", "f_distribution_f32", &module, &layout); let bind_group = cache.create_bind_group(&layout, &[out, params]); let mut encoder = cache @@ -494,15 +505,18 @@ pub fn launch_multinomial_count( } check_float_dtype(dtype, "multinomial_count")?; - let name = "multinomial_count_f32"; - let shader = generate_multinomial_count_shader(dtype)?; - let module = cache.get_or_create_module(name, &shader); + let module = cache.get_or_create_module("multinomial_count_f32", MULTINOMIAL_COUNT_SHADER); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 3, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = cache.get_or_create_pipeline( + "multinomial_count_f32", + "multinomial_count_f32", + &module, + &layout, + ); let bind_group = cache.create_bind_group(&layout, &[cdf, uniforms, counts, params]); let mut encoder = cache diff --git a/src/runtime/wgpu/shaders/dtype_support.rs b/src/runtime/wgpu/shaders/dtype_support.rs index 424c8992..39c3f009 100644 --- a/src/runtime/wgpu/shaders/dtype_support.rs +++ b/src/runtime/wgpu/shaders/dtype_support.rs @@ -1,112 +1,44 @@ -//! DType support validation for WebGPU operations +//! DType support for WebGPU operations. //! -//! This module defines which operations support which dtypes and provides -//! validation functions to ensure operations are called with supported types. +//! WebGPU is a 32-bit compute backend. All element-wise, scalar, comparison, +//! and activation operations are F32 only. Cast supports F32 ↔ I32 ↔ U32 +//! because type conversions are necessary for indexing interop. use crate::dtype::DType; use crate::error::{Error, Result}; -// ============================================================================ -// Unary Operations Support -// ============================================================================ - -/// Operations that work for all dtypes (F32, I32, U32) -const UNIVERSAL_UNARY_OPS: &[&str] = &["abs", "square", "sign"]; - -/// Operations that work for signed types only (F32, I32) -const SIGNED_UNARY_OPS: &[&str] = &["neg"]; - -/// Operations that require floating point (F32 only) -const FLOAT_ONLY_UNARY_OPS: &[&str] = &[ - "sqrt", "exp", "log", "sin", "cos", "tan", "tanh", "recip", "floor", "ceil", "round", "relu", - "sigmoid", "silu", "gelu", "isnan", "isinf", -]; - -/// Check if a unary operation supports the given dtype -pub fn is_unary_op_supported(op: &str, dtype: DType) -> bool { - // Universal ops work for all types - if UNIVERSAL_UNARY_OPS.contains(&op) { - return matches!(dtype, DType::F32 | DType::I32 | DType::U32); - } - - // Signed ops don't work for U32 - if SIGNED_UNARY_OPS.contains(&op) { - return matches!(dtype, DType::F32 | DType::I32); - } - - // Float-only ops - if FLOAT_ONLY_UNARY_OPS.contains(&op) { - return dtype == DType::F32; - } - - // Default: assume F32 only for unknown ops +/// Returns true only for F32 (all WebGPU compute ops are F32-only). +pub fn is_wgpu_compute_supported(dtype: DType) -> bool { dtype == DType::F32 } -/// Validate that a unary operation supports the given dtype +/// Validate F32 for unary operations. pub fn check_unary_dtype_support(op: &'static str, dtype: DType) -> Result<()> { - if !is_unary_op_supported(op, dtype) { + if dtype != DType::F32 { return Err(Error::UnsupportedDType { dtype, op }); } Ok(()) } -// ============================================================================ -// Binary Operations Support -// ============================================================================ - -/// All binary operations support F32, I32, U32 -const BINARY_OPS: &[&str] = &["add", "sub", "mul", "div", "max", "min"]; - -/// Pow operation (requires special handling for integers) -const POW_OP: &str = "pow"; - -/// Check if a binary operation supports the given dtype -pub fn is_binary_op_supported(op: &str, dtype: DType) -> bool { - if BINARY_OPS.contains(&op) || op == POW_OP { - return matches!(dtype, DType::F32 | DType::I32 | DType::U32); - } - // Default: assume F32 only - dtype == DType::F32 -} - -/// Validate that a binary operation supports the given dtype +/// Validate F32 for binary operations. pub fn check_binary_dtype_support(op: &'static str, dtype: DType) -> Result<()> { - if !is_binary_op_supported(op, dtype) { + if dtype != DType::F32 { return Err(Error::UnsupportedDType { dtype, op }); } Ok(()) } -// ============================================================================ -// Scalar Operations Support -// ============================================================================ - -/// All scalar operations support F32, I32, U32 -pub fn is_scalar_op_supported(_op: &str, dtype: DType) -> bool { - matches!(dtype, DType::F32 | DType::I32 | DType::U32) -} - -/// Validate that a scalar operation supports the given dtype +/// Validate F32 for scalar operations. pub fn check_scalar_dtype_support(op: &'static str, dtype: DType) -> Result<()> { - if !is_scalar_op_supported(op, dtype) { + if dtype != DType::F32 { return Err(Error::UnsupportedDType { dtype, op }); } Ok(()) } -// ============================================================================ -// Comparison Operations Support -// ============================================================================ - -/// All comparison operations support F32, I32, U32 -pub fn is_compare_op_supported(dtype: DType) -> bool { - matches!(dtype, DType::F32 | DType::I32 | DType::U32) -} - -/// Validate that comparison operations support the given dtype +/// Validate F32 for comparison operations. pub fn check_compare_dtype_support(op: &'static str, dtype: DType) -> Result<()> { - if !is_compare_op_supported(dtype) { + if dtype != DType::F32 { return Err(Error::UnsupportedDType { dtype, op }); } Ok(()) @@ -117,59 +49,18 @@ mod tests { use super::*; #[test] - fn test_universal_unary_ops() { - // abs works for all types - assert!(is_unary_op_supported("abs", DType::F32)); - assert!(is_unary_op_supported("abs", DType::I32)); - assert!(is_unary_op_supported("abs", DType::U32)); - - // square works for all types - assert!(is_unary_op_supported("square", DType::F32)); - assert!(is_unary_op_supported("square", DType::I32)); - assert!(is_unary_op_supported("square", DType::U32)); - } - - #[test] - fn test_signed_unary_ops() { - // neg works for F32 and I32, not U32 - assert!(is_unary_op_supported("neg", DType::F32)); - assert!(is_unary_op_supported("neg", DType::I32)); - assert!(!is_unary_op_supported("neg", DType::U32)); - } - - #[test] - fn test_float_only_unary_ops() { - // sqrt is F32 only - assert!(is_unary_op_supported("sqrt", DType::F32)); - assert!(!is_unary_op_supported("sqrt", DType::I32)); - assert!(!is_unary_op_supported("sqrt", DType::U32)); - - // relu is F32 only - assert!(is_unary_op_supported("relu", DType::F32)); - assert!(!is_unary_op_supported("relu", DType::I32)); - assert!(!is_unary_op_supported("relu", DType::U32)); - } - - #[test] - fn test_binary_ops_all_dtypes() { - for &op in &["add", "sub", "mul", "div", "max", "min", "pow"] { - assert!(is_binary_op_supported(op, DType::F32)); - assert!(is_binary_op_supported(op, DType::I32)); - assert!(is_binary_op_supported(op, DType::U32)); - } - } - - #[test] - fn test_scalar_ops_all_dtypes() { - assert!(is_scalar_op_supported("add_scalar", DType::F32)); - assert!(is_scalar_op_supported("add_scalar", DType::I32)); - assert!(is_scalar_op_supported("add_scalar", DType::U32)); + fn test_f32_supported() { + assert!(check_unary_dtype_support("neg", DType::F32).is_ok()); + assert!(check_binary_dtype_support("add", DType::F32).is_ok()); + assert!(check_scalar_dtype_support("add_scalar", DType::F32).is_ok()); + assert!(check_compare_dtype_support("eq", DType::F32).is_ok()); } #[test] - fn test_compare_ops_all_dtypes() { - assert!(is_compare_op_supported(DType::F32)); - assert!(is_compare_op_supported(DType::I32)); - assert!(is_compare_op_supported(DType::U32)); + fn test_non_f32_rejected() { + assert!(check_unary_dtype_support("abs", DType::I32).is_err()); + assert!(check_binary_dtype_support("add", DType::U32).is_err()); + assert!(check_scalar_dtype_support("mul_scalar", DType::I32).is_err()); + assert!(check_compare_dtype_support("lt", DType::U32).is_err()); } } diff --git a/src/runtime/wgpu/shaders/elementwise.rs b/src/runtime/wgpu/shaders/elementwise.rs index e4655f36..7d9faf02 100644 --- a/src/runtime/wgpu/shaders/elementwise.rs +++ b/src/runtime/wgpu/shaders/elementwise.rs @@ -1,159 +1,48 @@ //! Element-wise WGSL kernel launchers //! -//! Provides launchers for element-wise operations including: -//! - Binary operations (add, sub, mul, div, pow, max, min) -//! - Unary operations (neg, abs, sqrt, exp, log, sin, cos, tan, tanh, etc.) -//! - Scalar operations (add_scalar, sub_scalar, mul_scalar, div_scalar, pow_scalar) -//! - Comparison operations (eq, ne, lt, le, gt, ge) -//! - Activation functions (relu, sigmoid, silu, gelu) -//! - Utility operations (clamp, isnan, isinf, where) -//! -//! Multi-dtype support: F32, I32, U32 (F16 requires shader-f16 extension) -//! All operations run entirely on GPU with no CPU fallback. - -use std::collections::HashMap; -use std::sync::{OnceLock, RwLock, RwLockReadGuard, RwLockWriteGuard}; - -// ============================================================================ -// Lock Helpers (Handle Poisoned Locks Gracefully) -// ============================================================================ - -/// Acquire read lock, recovering from poison if necessary. -/// Cache data remains valid even after a panic in another thread. -fn read_lock(lock: &RwLock) -> RwLockReadGuard<'_, T> { - lock.read().unwrap_or_else(|poisoned| poisoned.into_inner()) -} - -/// Acquire write lock, recovering from poison if necessary. -/// Cache data remains valid even after a panic in another thread. -fn write_lock(lock: &RwLock) -> RwLockWriteGuard<'_, T> { - lock.write() - .unwrap_or_else(|poisoned| poisoned.into_inner()) -} +//! Binary and broadcast-binary ops support F32, I32, U32. +//! Unary ops: most are F32 only; neg/abs support I32, abs supports U32. +//! Scalar ops: F32, I32, U32 (no pow for integers). +//! Compare ops: F32, I32, U32. use wgpu::{Buffer, Queue}; -use super::dtype_support; -use super::generator::{ - dtype_suffix, generate_binary_shader, generate_cast_shader, generate_compare_shader, - generate_scalar_shader, generate_unary_shader, -}; use super::pipeline::{LayoutKey, PipelineCache, workgroup_count}; use crate::dtype::DType; use crate::error::{Error, Result}; // ============================================================================ -// Shader Module Cache +// Static Shader Sources // ============================================================================ -/// Cache for leaked shader references (leaked once per dtype+op_type combination) -/// Key: (DType, operation_type), Value: &'static str to leaked shader source -static SHADER_CACHE: OnceLock>> = - OnceLock::new(); - -/// Cache for leaked module key references -static MODULE_KEY_CACHE: OnceLock>> = - OnceLock::new(); - -/// Get or generate shader for a specific dtype and operation type. -/// Generates shader once, leaks it once, caches the leaked reference. -/// Subsequent calls return the cached &'static str without leaking. -fn get_or_leak_shader(dtype: DType, op_type: &'static str) -> Result<&'static str> { - let cache = SHADER_CACHE.get_or_init(|| RwLock::new(HashMap::new())); - - // Check if already cached - { - let read_guard = read_lock(cache); - if let Some(&shader_ref) = read_guard.get(&(dtype, op_type)) { - return Ok(shader_ref); - } - } - - // Generate shader based on operation type - let shader = match op_type { - "binary" => generate_binary_shader(dtype)?, - "unary" => generate_unary_shader(dtype)?, - "scalar" => generate_scalar_shader(dtype)?, - "compare" => generate_compare_shader(dtype)?, - _ => return Err(Error::Internal(format!("Unknown op type: {}", op_type))), - }; - - // Leak ONCE and cache the reference - let leaked: &'static str = Box::leak(shader.into_boxed_str()); - - let mut write_guard = write_lock(cache); - write_guard.insert((dtype, op_type), leaked); - - Ok(leaked) -} - -/// Get the module key for a dtype and operation type. -/// Generates key once, leaks it once, caches the leaked reference. -fn get_or_leak_module_key(dtype: DType, op_type: &'static str) -> Result<&'static str> { - let cache = MODULE_KEY_CACHE.get_or_init(|| RwLock::new(HashMap::new())); - - // Check if already cached - { - let read_guard = read_lock(cache); - if let Some(&key_ref) = read_guard.get(&(dtype, op_type)) { - return Ok(key_ref); - } - } - - // Generate module key - let suffix = dtype_suffix(dtype)?; - let key = format!("{}_{}", op_type, suffix); - - // Leak ONCE and cache the reference - let leaked: &'static str = Box::leak(key.into_boxed_str()); - - let mut write_guard = write_lock(cache); - write_guard.insert((dtype, op_type), leaked); - - Ok(leaked) -} - -/// Cache for leaked entry point references -static ENTRY_POINT_CACHE: OnceLock>> = - OnceLock::new(); - -/// Get entry point name for an operation. -/// Generates once per (op, dtype), leaks once, caches the leaked reference. -fn get_or_leak_entry_point(op: &str, dtype: DType) -> Result<&'static str> { - let cache = ENTRY_POINT_CACHE.get_or_init(|| RwLock::new(HashMap::new())); - - let key = (op.to_string(), dtype); - - // Check if already cached - { - let read_guard = read_lock(cache); - if let Some(&entry_ref) = read_guard.get(&key) { - return Ok(entry_ref); - } - } - - // Generate entry point - let suffix = dtype_suffix(dtype)?; - let entry = format!("{}_{}", op, suffix); - - // Leak ONCE and cache the reference - let leaked: &'static str = Box::leak(entry.into_boxed_str()); - - let mut write_guard = write_lock(cache); - write_guard.insert(key, leaked); - - Ok(leaked) -} +const BINARY_F32_SHADER: &str = include_str!("binary.wgsl"); +const BINARY_I32_SHADER: &str = include_str!("binary_i32.wgsl"); +const BINARY_U32_SHADER: &str = include_str!("binary_u32.wgsl"); +const BINARY_BROADCAST_F32_SHADER: &str = include_str!("binary_broadcast.wgsl"); +const BINARY_BROADCAST_I32_SHADER: &str = include_str!("binary_broadcast_i32.wgsl"); +const BINARY_BROADCAST_U32_SHADER: &str = include_str!("binary_broadcast_u32.wgsl"); +const UNARY_SHADER: &str = include_str!("unary.wgsl"); +const UNARY_I32_SHADER: &str = include_str!("unary_i32.wgsl"); +const UNARY_U32_SHADER: &str = include_str!("unary_u32.wgsl"); +const SCALAR_SHADER: &str = include_str!("scalar.wgsl"); +const SCALAR_I32_SHADER: &str = include_str!("scalar_i32.wgsl"); +const SCALAR_U32_SHADER: &str = include_str!("scalar_u32.wgsl"); +const COMPARE_SHADER: &str = include_str!("compare.wgsl"); +const COMPARE_I32_SHADER: &str = include_str!("compare_i32.wgsl"); +const COMPARE_U32_SHADER: &str = include_str!("compare_u32.wgsl"); + +const CAST_F32_TO_I32_SHADER: &str = include_str!("cast_f32_to_i32.wgsl"); +const CAST_F32_TO_U32_SHADER: &str = include_str!("cast_f32_to_u32.wgsl"); +const CAST_I32_TO_F32_SHADER: &str = include_str!("cast_i32_to_f32.wgsl"); +const CAST_I32_TO_U32_SHADER: &str = include_str!("cast_i32_to_u32.wgsl"); +const CAST_U32_TO_F32_SHADER: &str = include_str!("cast_u32_to_f32.wgsl"); +const CAST_U32_TO_I32_SHADER: &str = include_str!("cast_u32_to_i32.wgsl"); // ============================================================================ // Binary Operations // ============================================================================ -/// Launch a binary element-wise operation kernel. -/// -/// Computes `out[i] = a[i] op b[i]` for all elements. -/// -/// Supports F32, I32, U32 dtypes. +/// Launch a binary element-wise operation: `out[i] = a[i] op b[i]`. F32, I32, U32. pub fn launch_binary_op( cache: &PipelineCache, queue: &Queue, @@ -165,38 +54,37 @@ pub fn launch_binary_op( numel: usize, dtype: DType, ) -> Result<()> { - // Validate dtype support for this operation - dtype_support::check_binary_dtype_support(op, dtype)?; - - // Normalize operation name let op_name = match op { "maximum" => "max", "minimum" => "min", _ => op, }; - // Get entry point name based on dtype (cached, leaked once per op+dtype) - let entry_point = get_or_leak_entry_point(op_name, dtype)?; + let (module_key, shader, suffix) = match dtype { + DType::F32 => ("binary_f32", BINARY_F32_SHADER, "f32"), + DType::I32 => ("binary_i32", BINARY_I32_SHADER, "i32"), + DType::U32 => ("binary_u32", BINARY_U32_SHADER, "u32"), + _ => return Err(Error::UnsupportedDType { dtype, op }), + }; - // Use generated shader for all dtypes to keep op coverage consistent. - let shader = get_or_leak_shader(dtype, "binary")?; - let module_key = get_or_leak_module_key(dtype, "binary")?; - let (module_name, shader_source): (&str, &str) = (module_key, shader); + // pow and atan2 are float-only + if matches!(op_name, "pow" | "atan2") && dtype != DType::F32 { + return Err(Error::UnsupportedDType { dtype, op }); + } - let module = cache.get_or_create_module(module_name, shader_source); + let entry_point: String = format!("{}_{}", op_name, suffix); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 3, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(module_name, entry_point, &module, &layout); - + let pipeline = cache.get_or_create_dynamic_pipeline(module_key, &entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[a, b, out, params_buffer]); let mut encoder = cache .device() .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(op) }); - { let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some(op), @@ -206,17 +94,11 @@ pub fn launch_binary_op( pass.set_bind_group(0, Some(&bind_group), &[]); pass.dispatch_workgroups(workgroup_count(numel), 1, 1); } - queue.submit(std::iter::once(encoder.finish())); Ok(()) } -/// Launch a broadcast binary element-wise operation kernel. -/// -/// Computes `out[i] = a[broadcast_idx_a] op b[broadcast_idx_b]` for all elements, -/// where broadcast indices are computed from strides (0 for broadcast dimensions). -/// -/// Supports F32, I32, U32 dtypes. +/// Launch a broadcast binary operation. F32, I32, U32. #[allow(clippy::too_many_arguments)] pub fn launch_broadcast_binary_op( cache: &PipelineCache, @@ -232,54 +114,32 @@ pub fn launch_broadcast_binary_op( numel: usize, dtype: DType, ) -> Result<()> { - // Validate dtype support for this operation - dtype_support::check_binary_dtype_support(op, dtype)?; - - // Normalize operation name let op_name = match op { "maximum" => "max", "minimum" => "min", _ => op, }; - // Generate entry point name - let suffix = super::generator::dtype_suffix(dtype)?; - let entry_point_str = format!("broadcast_{}_{}", op_name, suffix); - let entry_point: &'static str = Box::leak(entry_point_str.into_boxed_str()); - - // Generate broadcast shader (cached per dtype) - let shader = { - use super::generator::generate_broadcast_binary_shader; - let shader_cache = - SHADER_CACHE.get_or_init(|| std::sync::RwLock::new(std::collections::HashMap::new())); - - let cache_key = (dtype, "broadcast_binary"); - { - let read_guard = read_lock(shader_cache); - if let Some(&cached) = read_guard.get(&cache_key) { - cached - } else { - drop(read_guard); - let generated = generate_broadcast_binary_shader(dtype)?; - let leaked: &'static str = Box::leak(generated.into_boxed_str()); - let mut write_guard = write_lock(shader_cache); - write_guard.insert(cache_key, leaked); - leaked - } - } + let (module_key, shader, suffix) = match dtype { + DType::F32 => ("binary_broadcast_f32", BINARY_BROADCAST_F32_SHADER, "f32"), + DType::I32 => ("binary_broadcast_i32", BINARY_BROADCAST_I32_SHADER, "i32"), + DType::U32 => ("binary_broadcast_u32", BINARY_BROADCAST_U32_SHADER, "u32"), + _ => return Err(Error::UnsupportedDType { dtype, op }), }; - let module_key = format!("broadcast_binary_{}", suffix); - let module_key: &'static str = Box::leak(module_key.into_boxed_str()); + // pow is float-only + if op_name == "pow" && dtype != DType::F32 { + return Err(Error::UnsupportedDType { dtype, op }); + } + let entry_point: String = format!("broadcast_{}_{}", op_name, suffix); let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 6, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); - + let pipeline = cache.get_or_create_dynamic_pipeline(module_key, &entry_point, &module, &layout); let bind_group = cache.create_bind_group( &layout, &[a, b, out, a_strides, b_strides, out_strides, params_buffer], @@ -290,7 +150,6 @@ pub fn launch_broadcast_binary_op( .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("broadcast_binary"), }); - { let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some("broadcast_binary"), @@ -300,7 +159,6 @@ pub fn launch_broadcast_binary_op( pass.set_bind_group(0, Some(&bind_group), &[]); pass.dispatch_workgroups(workgroup_count(numel), 1, 1); } - queue.submit(std::iter::once(encoder.finish())); Ok(()) } @@ -309,11 +167,8 @@ pub fn launch_broadcast_binary_op( // Unary Operations // ============================================================================ -/// Launch a unary element-wise operation kernel. -/// -/// Computes `out[i] = op(a[i])` for all elements. -/// -/// Supports F32, I32, U32 dtypes (operation-dependent). +/// Launch a unary operation: `out[i] = op(a[i])`. +/// Most ops are F32 only. neg/abs support I32, abs supports U32. pub fn launch_unary_op( cache: &PipelineCache, queue: &Queue, @@ -324,31 +179,83 @@ pub fn launch_unary_op( numel: usize, dtype: DType, ) -> Result<()> { - // Validate dtype support for this operation - dtype_support::check_unary_dtype_support(op, dtype)?; - - // Get entry point name based on dtype (cached, leaked once per op+dtype) - let entry_point = get_or_leak_entry_point(op, dtype)?; + // For I32/U32, only neg and abs are supported + match dtype { + DType::F32 => {} + DType::I32 => { + if !matches!(op, "neg" | "abs") { + return Err(Error::UnsupportedDType { dtype, op }); + } + } + DType::U32 => { + if op != "abs" { + return Err(Error::UnsupportedDType { dtype, op }); + } + } + _ => return Err(Error::UnsupportedDType { dtype, op }), + } - // Use generated shader for all dtypes to keep op coverage consistent. - let shader = get_or_leak_shader(dtype, "unary")?; - let module_key = get_or_leak_module_key(dtype, "unary")?; - let (module_name, shader_source): (&str, &str) = (module_key, shader); + let (module_key, shader, entry_point): (&str, &str, String) = match dtype { + DType::I32 => ("unary_i32", UNARY_I32_SHADER, format!("{}_i32", op)), + DType::U32 => ("unary_u32", UNARY_U32_SHADER, format!("{}_u32", op)), + DType::F32 => { + let ep: &'static str = match op { + "neg" => "neg_f32", + "abs" => "abs_f32", + "sqrt" => "sqrt_f32", + "exp" => "exp_f32", + "log" => "log_f32", + "sin" => "sin_f32", + "cos" => "cos_f32", + "tan" => "tan_f32", + "atan" => "atan_f32", + "tanh" => "tanh_f32", + "recip" => "recip_f32", + "floor" => "floor_f32", + "ceil" => "ceil_f32", + "round" => "round_f32", + "trunc" => "trunc_f32", + "rsqrt" => "rsqrt_f32", + "cbrt" => "cbrt_f32", + "exp2" => "exp2_f32", + "expm1" => "expm1_f32", + "log2" => "log2_f32", + "log10" => "log10_f32", + "log1p" => "log1p_f32", + "asin" => "asin_f32", + "acos" => "acos_f32", + "sinh" => "sinh_f32", + "cosh" => "cosh_f32", + "asinh" => "asinh_f32", + "acosh" => "acosh_f32", + "atanh" => "atanh_f32", + "square" => "square_f32", + "sign" => "sign_f32", + "relu" => "relu_f32", + "sigmoid" => "sigmoid_f32", + "silu" => "silu_f32", + "gelu" => "gelu_f32", + "isnan" => "isnan_f32", + "isinf" => "isinf_f32", + _ => return Err(Error::Internal(format!("Unknown unary op: {}", op))), + }; + ("unary_f32", UNARY_SHADER, ep.to_string()) + } + _ => unreachable!(), + }; - let module = cache.get_or_create_module(module_name, shader_source); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(module_name, entry_point, &module, &layout); - + let pipeline = cache.get_or_create_dynamic_pipeline(module_key, &entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[a, out, params_buffer]); let mut encoder = cache .device() .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(op) }); - { let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some(op), @@ -358,7 +265,6 @@ pub fn launch_unary_op( pass.set_bind_group(0, Some(&bind_group), &[]); pass.dispatch_workgroups(workgroup_count(numel), 1, 1); } - queue.submit(std::iter::once(encoder.finish())); Ok(()) } @@ -367,11 +273,8 @@ pub fn launch_unary_op( // Scalar Operations // ============================================================================ -/// Launch a scalar element-wise operation kernel. -/// -/// Computes `out[i] = a[i] op scalar` for all elements. -/// -/// Supports F32, I32, U32 dtypes. +/// Launch a scalar operation: `out[i] = a[i] op scalar`. F32, I32, U32. +/// pow_scalar, leaky_relu, elu are F32-only. pub fn launch_scalar_op( cache: &PipelineCache, queue: &Queue, @@ -382,31 +285,57 @@ pub fn launch_scalar_op( numel: usize, dtype: DType, ) -> Result<()> { - // Validate dtype support for this operation - dtype_support::check_scalar_dtype_support(op, dtype)?; + // pow_scalar, leaky_relu, elu are F32-only + if matches!(op, "pow_scalar" | "leaky_relu" | "elu") && dtype != DType::F32 { + return Err(Error::UnsupportedDType { dtype, op }); + } - // Get entry point name based on dtype (cached, leaked once per op+dtype) - let entry_point = get_or_leak_entry_point(op, dtype)?; + let (module_key, shader, suffix) = match dtype { + DType::F32 => ("scalar_f32", SCALAR_SHADER, "f32"), + DType::I32 => ("scalar_i32", SCALAR_I32_SHADER, "i32"), + DType::U32 => ("scalar_u32", SCALAR_U32_SHADER, "u32"), + _ => return Err(Error::UnsupportedDType { dtype, op }), + }; - // Use generated shader for all dtypes to keep op coverage consistent. - let shader = get_or_leak_shader(dtype, "scalar")?; - let module_key = get_or_leak_module_key(dtype, "scalar")?; - let (module_name, shader_source): (&str, &str) = (module_key, shader); + let entry_point: String = match dtype { + DType::F32 => { + // F32 uses static entry points + let ep: &'static str = match op { + "add_scalar" => "add_scalar_f32", + "sub_scalar" => "sub_scalar_f32", + "rsub_scalar" => "rsub_scalar_f32", + "mul_scalar" => "mul_scalar_f32", + "div_scalar" => "div_scalar_f32", + "pow_scalar" => "pow_scalar_f32", + "leaky_relu" => "leaky_relu_f32", + "elu" => "elu_f32", + _ => return Err(Error::Internal(format!("Unknown scalar op: {}", op))), + }; + ep.to_string() + } + _ => { + // I32/U32: format entry point + match op { + "add_scalar" | "sub_scalar" | "rsub_scalar" | "mul_scalar" | "div_scalar" => { + format!("{}_{}", op, suffix) + } + _ => return Err(Error::Internal(format!("Unknown scalar op: {}", op))), + } + } + }; - let module = cache.get_or_create_module(module_name, shader_source); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(module_name, entry_point, &module, &layout); - + let pipeline = cache.get_or_create_dynamic_pipeline(module_key, &entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[a, out, params_buffer]); let mut encoder = cache .device() .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(op) }); - { let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some(op), @@ -416,7 +345,6 @@ pub fn launch_scalar_op( pass.set_bind_group(0, Some(&bind_group), &[]); pass.dispatch_workgroups(workgroup_count(numel), 1, 1); } - queue.submit(std::iter::once(encoder.finish())); Ok(()) } @@ -425,11 +353,8 @@ pub fn launch_scalar_op( // Comparison Operations // ============================================================================ -/// Launch a comparison element-wise operation kernel. -/// -/// Computes `out[i] = (a[i] op b[i]) ? 1.0 : 0.0` for all elements. -/// -/// Supports F32, I32, U32 dtypes. Output is always F32. +/// Launch a comparison operation: `out[i] = (a[i] op b[i]) ? 1.0 : 0.0`. F32, I32, U32. +/// Output is always F32. pub fn launch_compare_op( cache: &PipelineCache, queue: &Queue, @@ -441,31 +366,30 @@ pub fn launch_compare_op( numel: usize, dtype: DType, ) -> Result<()> { - // Validate dtype support for this operation - dtype_support::check_compare_dtype_support(op, dtype)?; - - // Get entry point name based on dtype (cached, leaked once per op+dtype) - let entry_point = get_or_leak_entry_point(op, dtype)?; + let (module_key, shader, suffix) = match dtype { + DType::F32 => ("compare_f32", COMPARE_SHADER, "f32"), + DType::I32 => ("compare_i32", COMPARE_I32_SHADER, "i32"), + DType::U32 => ("compare_u32", COMPARE_U32_SHADER, "u32"), + _ => return Err(Error::UnsupportedDType { dtype, op }), + }; - // Use generated shader for all dtypes to keep op coverage consistent. - let shader = get_or_leak_shader(dtype, "compare")?; - let module_key = get_or_leak_module_key(dtype, "compare")?; - let (module_name, shader_source): (&str, &str) = (module_key, shader); + let entry_point: String = match op { + "eq" | "ne" | "lt" | "le" | "gt" | "ge" => format!("{}_{}", op, suffix), + _ => return Err(Error::Internal(format!("Unknown compare op: {}", op))), + }; - let module = cache.get_or_create_module(module_name, shader_source); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 3, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(module_name, entry_point, &module, &layout); - + let pipeline = cache.get_or_create_dynamic_pipeline(module_key, &entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[a, b, out, params_buffer]); let mut encoder = cache .device() .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(op) }); - { let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some(op), @@ -475,37 +399,15 @@ pub fn launch_compare_op( pass.set_bind_group(0, Some(&bind_group), &[]); pass.dispatch_workgroups(workgroup_count(numel), 1, 1); } - queue.submit(std::iter::once(encoder.finish())); Ok(()) } // ============================================================================ -// Cast Operation (uses generator for DRY) +// Cast Operations // ============================================================================ -/// Get static module name and entry point for a cast operation. -/// -/// Returns (module_name, entry_point) for caching purposes. -/// The shader source is generated dynamically via `generate_cast_shader()`. -fn cast_info(src: DType, dst: DType) -> Option<(&'static str, &'static str)> { - match (src, dst) { - (DType::F32, DType::I32) => Some(("cast_f32_i32", "cast_f32_to_i32")), - (DType::F32, DType::U32) => Some(("cast_f32_u32", "cast_f32_to_u32")), - (DType::I32, DType::F32) => Some(("cast_i32_f32", "cast_i32_to_f32")), - (DType::I32, DType::U32) => Some(("cast_i32_u32", "cast_i32_to_u32")), - (DType::U32, DType::F32) => Some(("cast_u32_f32", "cast_u32_to_f32")), - (DType::U32, DType::I32) => Some(("cast_u32_i32", "cast_u32_to_i32")), - _ => None, - } -} - -/// Launch cast operation kernel. -/// -/// Converts `out[i] = dst_dtype(a[i])` for all elements. -/// Supports F32 ↔ I32 ↔ U32 conversions. -/// -/// Uses `generate_cast_shader()` from the generator module for DRY shader generation. +/// Launch a cast operation: `out[i] = DstType(a[i])`. Supports F32 ↔ I32 ↔ U32. pub fn launch_cast_op( cache: &PipelineCache, queue: &Queue, @@ -516,29 +418,33 @@ pub fn launch_cast_op( src_dtype: DType, dst_dtype: DType, ) -> Result<()> { - // Same-type cast is a no-op (should be caught earlier, but handle here too) if src_dtype == dst_dtype { return Ok(()); } - // Get static names for caching - let (module_name, entry_point) = - cast_info(src_dtype, dst_dtype).ok_or_else(|| Error::UnsupportedDType { - dtype: src_dtype, - op: "cast (unsupported dtype combination)", - })?; - - // Generate shader source dynamically (DRY - single source of truth in generator.rs) - let shader_source = generate_cast_shader(src_dtype, dst_dtype)?; + let (module_name, entry_point, shader_source): (&'static str, &'static str, &'static str) = + match (src_dtype, dst_dtype) { + (DType::F32, DType::I32) => ("cast_f32_i32", "cast_f32_to_i32", CAST_F32_TO_I32_SHADER), + (DType::F32, DType::U32) => ("cast_f32_u32", "cast_f32_to_u32", CAST_F32_TO_U32_SHADER), + (DType::I32, DType::F32) => ("cast_i32_f32", "cast_i32_to_f32", CAST_I32_TO_F32_SHADER), + (DType::I32, DType::U32) => ("cast_i32_u32", "cast_i32_to_u32", CAST_I32_TO_U32_SHADER), + (DType::U32, DType::F32) => ("cast_u32_f32", "cast_u32_to_f32", CAST_U32_TO_F32_SHADER), + (DType::U32, DType::I32) => ("cast_u32_i32", "cast_u32_to_i32", CAST_U32_TO_I32_SHADER), + _ => { + return Err(Error::UnsupportedDType { + dtype: src_dtype, + op: "cast", + }); + } + }; - let module = cache.get_or_create_module(module_name, &shader_source); + let module = cache.get_or_create_module(module_name, shader_source); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, num_readonly_storage: 0, }); let pipeline = cache.get_or_create_pipeline(module_name, entry_point, &module, &layout); - let bind_group = cache.create_bind_group(&layout, &[a, out, params_buffer]); let mut encoder = cache @@ -546,7 +452,6 @@ pub fn launch_cast_op( .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("cast"), }); - { let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some("cast"), @@ -556,7 +461,6 @@ pub fn launch_cast_op( pass.set_bind_group(0, Some(&bind_group), &[]); pass.dispatch_workgroups(workgroup_count(numel), 1, 1); } - queue.submit(std::iter::once(encoder.finish())); Ok(()) } diff --git a/src/runtime/wgpu/shaders/embedding_lookup_f32.wgsl b/src/runtime/wgpu/shaders/embedding_lookup_f32.wgsl new file mode 100644 index 00000000..88f8ca0e --- /dev/null +++ b/src/runtime/wgpu/shaders/embedding_lookup_f32.wgsl @@ -0,0 +1,44 @@ +// Auto-generated embedding_lookup operation for f32 +// Industry-standard embedding table lookup used in neural networks. +// Each thread handles one index lookup and copies the full embedding row. + +const WORKGROUP_SIZE: u32 = 256u; + +struct EmbeddingLookupParams { + num_indices: u32, + vocab_size: u32, + embedding_dim: u32, + _pad0: u32, +} + +@group(0) @binding(0) var embeddings: array; +@group(0) @binding(1) var indices: array; +@group(0) @binding(2) var output: array; +@group(0) @binding(3) var params: EmbeddingLookupParams; + +@compute @workgroup_size(256) +fn embedding_lookup_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= params.num_indices) { + return; + } + + let index_val = indices[idx]; + + // Check bounds + if (index_val < 0 || u32(index_val) >= params.vocab_size) { + // Out of bounds - fill with zeros + let out_start = idx * params.embedding_dim; + for (var i: u32 = 0u; i < params.embedding_dim; i = i + 1u) { + output[out_start + i] = 0.0; + } + return; + } + + // Copy the entire embedding row to output + let emb_start = u32(index_val) * params.embedding_dim; + let out_start = idx * params.embedding_dim; + for (var i: u32 = 0u; i < params.embedding_dim; i = i + 1u) { + output[out_start + i] = embeddings[emb_start + i]; + } +} diff --git a/src/runtime/wgpu/shaders/embedding_lookup_i32.wgsl b/src/runtime/wgpu/shaders/embedding_lookup_i32.wgsl new file mode 100644 index 00000000..0a7ae9cc --- /dev/null +++ b/src/runtime/wgpu/shaders/embedding_lookup_i32.wgsl @@ -0,0 +1,44 @@ +// Auto-generated embedding_lookup operation for i32 +// Industry-standard embedding table lookup used in neural networks. +// Each thread handles one index lookup and copies the full embedding row. + +const WORKGROUP_SIZE: u32 = 256u; + +struct EmbeddingLookupParams { + num_indices: u32, + vocab_size: u32, + embedding_dim: u32, + _pad0: u32, +} + +@group(0) @binding(0) var embeddings: array; +@group(0) @binding(1) var indices: array; +@group(0) @binding(2) var output: array; +@group(0) @binding(3) var params: EmbeddingLookupParams; + +@compute @workgroup_size(256) +fn embedding_lookup_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= params.num_indices) { + return; + } + + let index_val = indices[idx]; + + // Check bounds + if (index_val < 0 || u32(index_val) >= params.vocab_size) { + // Out of bounds - fill with zeros + let out_start = idx * params.embedding_dim; + for (var i: u32 = 0u; i < params.embedding_dim; i = i + 1u) { + output[out_start + i] = 0; + } + return; + } + + // Copy the entire embedding row to output + let emb_start = u32(index_val) * params.embedding_dim; + let out_start = idx * params.embedding_dim; + for (var i: u32 = 0u; i < params.embedding_dim; i = i + 1u) { + output[out_start + i] = embeddings[emb_start + i]; + } +} diff --git a/src/runtime/wgpu/shaders/embedding_lookup_u32.wgsl b/src/runtime/wgpu/shaders/embedding_lookup_u32.wgsl new file mode 100644 index 00000000..fcf4486a --- /dev/null +++ b/src/runtime/wgpu/shaders/embedding_lookup_u32.wgsl @@ -0,0 +1,44 @@ +// Auto-generated embedding_lookup operation for u32 +// Industry-standard embedding table lookup used in neural networks. +// Each thread handles one index lookup and copies the full embedding row. + +const WORKGROUP_SIZE: u32 = 256u; + +struct EmbeddingLookupParams { + num_indices: u32, + vocab_size: u32, + embedding_dim: u32, + _pad0: u32, +} + +@group(0) @binding(0) var embeddings: array; +@group(0) @binding(1) var indices: array; +@group(0) @binding(2) var output: array; +@group(0) @binding(3) var params: EmbeddingLookupParams; + +@compute @workgroup_size(256) +fn embedding_lookup_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= params.num_indices) { + return; + } + + let index_val = indices[idx]; + + // Check bounds + if (index_val < 0 || u32(index_val) >= params.vocab_size) { + // Out of bounds - fill with zeros + let out_start = idx * params.embedding_dim; + for (var i: u32 = 0u; i < params.embedding_dim; i = i + 1u) { + output[out_start + i] = 0u; + } + return; + } + + // Copy the entire embedding row to output + let emb_start = u32(index_val) * params.embedding_dim; + let out_start = idx * params.embedding_dim; + for (var i: u32 = 0u; i < params.embedding_dim; i = i + 1u) { + output[out_start + i] = embeddings[emb_start + i]; + } +} diff --git a/src/runtime/wgpu/shaders/exponential_f32.wgsl b/src/runtime/wgpu/shaders/exponential_f32.wgsl new file mode 100644 index 00000000..bd13602d --- /dev/null +++ b/src/runtime/wgpu/shaders/exponential_f32.wgsl @@ -0,0 +1,39 @@ +// Exponential distribution sampling for f32 + +// PCG hash function for random number generation +fn pcg_hash(input: u32) -> u32 { + var state = input * 747796405u + 2891336453u; + var word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u; + return (word >> 22u) ^ word; +} + +fn pcg_init(seed: u32, idx: u32) -> u32 { + return pcg_hash(seed ^ pcg_hash(idx)); +} + +fn pcg_uniform(state: ptr) -> f32 { + *state = pcg_hash(*state); + return f32(*state) / 4294967296.0; +} + +const WORKGROUP_SIZE: u32 = 256u; + +struct ExponentialParams { + numel: u32, + seed: u32, + rate: f32, + _pad: u32, +} + +@group(0) @binding(0) var out: array; +@group(0) @binding(1) var params: ExponentialParams; + +@compute @workgroup_size(256) +fn exponential_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if idx < params.numel { + var state = pcg_init(params.seed, idx); + let u = max(pcg_uniform(&state), 0.0000001); + out[idx] = f32(-log(u) / params.rate); + } +} diff --git a/src/runtime/wgpu/shaders/extract_unique_f32.wgsl b/src/runtime/wgpu/shaders/extract_unique_f32.wgsl new file mode 100644 index 00000000..06cd67c9 --- /dev/null +++ b/src/runtime/wgpu/shaders/extract_unique_f32.wgsl @@ -0,0 +1,22 @@ +// Extract unique elements from a sorted f32 array using atomic counter + +struct CountParams { + numel: u32, +} + +@group(0) @binding(0) var sorted_input: array; +@group(0) @binding(1) var unique_output: array; +@group(0) @binding(2) var counter: array>; +@group(0) @binding(3) var params: CountParams; + +@compute @workgroup_size(256) +fn extract_unique_f32(@builtin(global_invocation_id) global_id: vec3) { + let idx = global_id.x; + if (idx >= params.numel) { + return; + } + if (idx == 0u || sorted_input[idx] != sorted_input[idx - 1u]) { + let out_idx = atomicAdd(&counter[0], 1u); + unique_output[out_idx] = sorted_input[idx]; + } +} diff --git a/src/runtime/wgpu/shaders/extract_unique_i32.wgsl b/src/runtime/wgpu/shaders/extract_unique_i32.wgsl new file mode 100644 index 00000000..8970d06a --- /dev/null +++ b/src/runtime/wgpu/shaders/extract_unique_i32.wgsl @@ -0,0 +1,22 @@ +// Extract unique elements from a sorted i32 array using atomic counter + +struct CountParams { + numel: u32, +} + +@group(0) @binding(0) var sorted_input: array; +@group(0) @binding(1) var unique_output: array; +@group(0) @binding(2) var counter: array>; +@group(0) @binding(3) var params: CountParams; + +@compute @workgroup_size(256) +fn extract_unique_i32(@builtin(global_invocation_id) global_id: vec3) { + let idx = global_id.x; + if (idx >= params.numel) { + return; + } + if (idx == 0u || sorted_input[idx] != sorted_input[idx - 1u]) { + let out_idx = atomicAdd(&counter[0], 1u); + unique_output[out_idx] = sorted_input[idx]; + } +} diff --git a/src/runtime/wgpu/shaders/extract_unique_u32.wgsl b/src/runtime/wgpu/shaders/extract_unique_u32.wgsl new file mode 100644 index 00000000..97fbda53 --- /dev/null +++ b/src/runtime/wgpu/shaders/extract_unique_u32.wgsl @@ -0,0 +1,22 @@ +// Extract unique elements from a sorted u32 array using atomic counter + +struct CountParams { + numel: u32, +} + +@group(0) @binding(0) var sorted_input: array; +@group(0) @binding(1) var unique_output: array; +@group(0) @binding(2) var counter: array>; +@group(0) @binding(3) var params: CountParams; + +@compute @workgroup_size(256) +fn extract_unique_u32(@builtin(global_invocation_id) global_id: vec3) { + let idx = global_id.x; + if (idx >= params.numel) { + return; + } + if (idx == 0u || sorted_input[idx] != sorted_input[idx - 1u]) { + let out_idx = atomicAdd(&counter[0], 1u); + unique_output[out_idx] = sorted_input[idx]; + } +} diff --git a/src/runtime/wgpu/shaders/eye_f32.wgsl b/src/runtime/wgpu/shaders/eye_f32.wgsl new file mode 100644 index 00000000..73ba2ca0 --- /dev/null +++ b/src/runtime/wgpu/shaders/eye_f32.wgsl @@ -0,0 +1,26 @@ +// Auto-generated eye (identity matrix) operation for f32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct EyeParams { + n: u32, // rows + m: u32, // cols + numel: u32, // n * m +} + +@group(0) @binding(0) var eye_out: array; +@group(0) @binding(1) var eye_params: EyeParams; + +@compute @workgroup_size(256) +fn eye_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < eye_params.numel) { + let row = idx / eye_params.m; + let col = idx % eye_params.m; + if (row == col) { + eye_out[idx] = f32(1.0); + } else { + eye_out[idx] = f32(0.0); + } + } +} diff --git a/src/runtime/wgpu/shaders/eye_i32.wgsl b/src/runtime/wgpu/shaders/eye_i32.wgsl new file mode 100644 index 00000000..b9ce696b --- /dev/null +++ b/src/runtime/wgpu/shaders/eye_i32.wgsl @@ -0,0 +1,26 @@ +// Auto-generated eye (identity matrix) operation for i32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct EyeParams { + n: u32, // rows + m: u32, // cols + numel: u32, // n * m +} + +@group(0) @binding(0) var eye_out: array; +@group(0) @binding(1) var eye_params: EyeParams; + +@compute @workgroup_size(256) +fn eye_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < eye_params.numel) { + let row = idx / eye_params.m; + let col = idx % eye_params.m; + if (row == col) { + eye_out[idx] = i32(1); + } else { + eye_out[idx] = i32(0); + } + } +} diff --git a/src/runtime/wgpu/shaders/eye_u32.wgsl b/src/runtime/wgpu/shaders/eye_u32.wgsl new file mode 100644 index 00000000..89c25468 --- /dev/null +++ b/src/runtime/wgpu/shaders/eye_u32.wgsl @@ -0,0 +1,26 @@ +// Auto-generated eye (identity matrix) operation for u32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct EyeParams { + n: u32, // rows + m: u32, // cols + numel: u32, // n * m +} + +@group(0) @binding(0) var eye_out: array; +@group(0) @binding(1) var eye_params: EyeParams; + +@compute @workgroup_size(256) +fn eye_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < eye_params.numel) { + let row = idx / eye_params.m; + let col = idx % eye_params.m; + if (row == col) { + eye_out[idx] = u32(1); + } else { + eye_out[idx] = u32(0); + } + } +} diff --git a/src/runtime/wgpu/shaders/f_distribution_f32.wgsl b/src/runtime/wgpu/shaders/f_distribution_f32.wgsl new file mode 100644 index 00000000..8e6d2ca1 --- /dev/null +++ b/src/runtime/wgpu/shaders/f_distribution_f32.wgsl @@ -0,0 +1,92 @@ +// F distribution sampling for f32 + +// PCG hash function for random number generation +fn pcg_hash(input: u32) -> u32 { + var state = input * 747796405u + 2891336453u; + var word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u; + return (word >> 22u) ^ word; +} + +fn pcg_init(seed: u32, idx: u32) -> u32 { + return pcg_hash(seed ^ pcg_hash(idx)); +} + +fn pcg_uniform(state: ptr) -> f32 { + *state = pcg_hash(*state); + return f32(*state) / 4294967296.0; +} + +// Box-Muller for normal distribution +fn sample_normal(state: ptr) -> f32 { + let u1 = max(pcg_uniform(state), 0.0000001); + let u2 = pcg_uniform(state); + return sqrt(-2.0 * log(u1)) * cos(6.28318530718 * u2); +} + +// Gamma via Marsaglia-Tsang method +fn sample_gamma_mt(state: ptr, shape: f32, scale: f32) -> f32 { + var alpha = shape; + var boost = 1.0; + + // Handle shape < 1 by boosting + if alpha < 1.0 { + boost = pow(pcg_uniform(state), 1.0 / alpha); + alpha = alpha + 1.0; + } + + let d = alpha - 1.0 / 3.0; + let c = 1.0 / sqrt(9.0 * d); + + // Rejection sampling + for (var i = 0u; i < 100u; i = i + 1u) { + var x: f32; + var v: f32; + + // Generate valid v + for (var j = 0u; j < 100u; j = j + 1u) { + x = sample_normal(state); + v = 1.0 + c * x; + if v > 0.0 { + break; + } + } + + v = v * v * v; + let u = pcg_uniform(state); + let x2 = x * x; + + // Accept/reject + if u < 1.0 - 0.0331 * x2 * x2 { + return d * v * boost * scale; + } + if log(u) < 0.5 * x2 + d * (1.0 - v + log(v)) { + return d * v * boost * scale; + } + } + + // Fallback (should rarely reach) + return d * boost * scale; +} + +const WORKGROUP_SIZE: u32 = 256u; + +struct FDistributionParams { + numel: u32, + seed: u32, + df1: f32, + df2: f32, +} + +@group(0) @binding(0) var out: array; +@group(0) @binding(1) var params: FDistributionParams; + +@compute @workgroup_size(256) +fn f_distribution_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if idx < params.numel { + var state = pcg_init(params.seed, idx); + let chi2_1 = sample_gamma_mt(&state, params.df1 / 2.0, 2.0); + let chi2_2 = sample_gamma_mt(&state, params.df2 / 2.0, 2.0); + out[idx] = f32((chi2_1 / params.df1) / (chi2_2 / params.df2)); + } +} diff --git a/src/runtime/wgpu/shaders/fft.rs b/src/runtime/wgpu/shaders/fft.rs index 35612d94..8e192b29 100644 --- a/src/runtime/wgpu/shaders/fft.rs +++ b/src/runtime/wgpu/shaders/fft.rs @@ -1,16 +1,36 @@ //! FFT kernel launchers for WebGPU //! -//! Provides dispatch functions for FFT compute shaders. +//! Provides dispatch functions for FFT compute shaders (F32 only on WebGPU). -use super::generator::{ - MAX_WORKGROUP_FFT_SIZE, generate_fftshift_shader, generate_hermitian_extend_shader, - generate_irfft_unpack_shader, generate_rfft_pack_shader, generate_rfft_truncate_shader, - generate_stockham_fft_shader, -}; use super::pipeline::{LayoutKey, PipelineCache, workgroup_count}; use crate::error::Result; use wgpu::{Buffer, Queue}; +/// Maximum FFT size for shared memory (workgroup) implementation. +/// Matches the shared memory array size in stockham_fft.wgsl. +pub const MAX_WORKGROUP_FFT_SIZE: usize = 256; + +const STOCKHAM_FFT_SHADER: &str = include_str!("stockham_fft.wgsl"); +// entry points: "stockham_fft_small", "stockham_fft_stage", "scale_complex" + +const FFTSHIFT_SHADER: &str = include_str!("fftshift.wgsl"); +// entry points: "fftshift", "ifftshift" + +const RFFT_PACK_SHADER: &str = include_str!("rfft_pack.wgsl"); +// entry point: "rfft_pack" + +const IRFFT_UNPACK_SHADER: &str = include_str!("irfft_unpack.wgsl"); +// entry point: "irfft_unpack" + +const HERMITIAN_EXTEND_SHADER: &str = include_str!("hermitian_extend.wgsl"); +// entry point: "hermitian_extend" + +const RFFT_TRUNCATE_SHADER: &str = include_str!("rfft_truncate.wgsl"); +// entry point: "rfft_truncate" + +const COPY_COMPLEX_SHADER: &str = include_str!("copy_complex.wgsl"); +// entry point: "copy_complex" + /// Launch batched Stockham FFT for small transforms (N <= MAX_WORKGROUP_FFT_SIZE) /// /// Each workgroup processes one FFT using shared memory. @@ -30,8 +50,7 @@ pub fn launch_stockham_fft_batched( ))); } - let shader = generate_stockham_fft_shader()?; - let module = pipeline_cache.get_or_create_module_from_source("stockham_fft", &shader); + let module = pipeline_cache.get_or_create_module("stockham_fft", STOCKHAM_FFT_SHADER); let layout = pipeline_cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, @@ -39,7 +58,7 @@ pub fn launch_stockham_fft_batched( num_readonly_storage: 0, }); - let pipeline = pipeline_cache.get_or_create_dynamic_pipeline( + let pipeline = pipeline_cache.get_or_create_pipeline( "stockham_fft", "stockham_fft_small", &module, @@ -80,8 +99,7 @@ pub fn launch_stockham_fft_stage( n: usize, batch_size: usize, ) -> Result<()> { - let shader = generate_stockham_fft_shader()?; - let module = pipeline_cache.get_or_create_module_from_source("stockham_fft", &shader); + let module = pipeline_cache.get_or_create_module("stockham_fft", STOCKHAM_FFT_SHADER); let layout = pipeline_cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, @@ -89,7 +107,7 @@ pub fn launch_stockham_fft_stage( num_readonly_storage: 0, }); - let pipeline = pipeline_cache.get_or_create_dynamic_pipeline( + let pipeline = pipeline_cache.get_or_create_pipeline( "stockham_fft", "stockham_fft_stage", &module, @@ -130,8 +148,7 @@ pub fn launch_scale_complex( params: &Buffer, n: usize, ) -> Result<()> { - let shader = generate_stockham_fft_shader()?; - let module = pipeline_cache.get_or_create_module_from_source("stockham_fft", &shader); + let module = pipeline_cache.get_or_create_module("stockham_fft", STOCKHAM_FFT_SHADER); let layout = pipeline_cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, @@ -139,12 +156,8 @@ pub fn launch_scale_complex( num_readonly_storage: 0, }); - let pipeline = pipeline_cache.get_or_create_dynamic_pipeline( - "stockham_fft", - "scale_complex", - &module, - &layout, - ); + let pipeline = + pipeline_cache.get_or_create_pipeline("stockham_fft", "scale_complex", &module, &layout); let bind_group = pipeline_cache.create_bind_group(&layout, &[input, output, params]); @@ -179,8 +192,7 @@ pub fn launch_fftshift( n: usize, batch_size: usize, ) -> Result<()> { - let shader = generate_fftshift_shader()?; - let module = pipeline_cache.get_or_create_module_from_source("fftshift", &shader); + let module = pipeline_cache.get_or_create_module("fftshift", FFTSHIFT_SHADER); let layout = pipeline_cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, @@ -188,8 +200,7 @@ pub fn launch_fftshift( num_readonly_storage: 0, }); - let pipeline = - pipeline_cache.get_or_create_dynamic_pipeline("fftshift", "fftshift", &module, &layout); + let pipeline = pipeline_cache.get_or_create_pipeline("fftshift", "fftshift", &module, &layout); let bind_group = pipeline_cache.create_bind_group(&layout, &[input, output, params]); @@ -224,8 +235,7 @@ pub fn launch_ifftshift( n: usize, batch_size: usize, ) -> Result<()> { - let shader = generate_fftshift_shader()?; - let module = pipeline_cache.get_or_create_module_from_source("fftshift", &shader); + let module = pipeline_cache.get_or_create_module("fftshift", FFTSHIFT_SHADER); let layout = pipeline_cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, @@ -233,8 +243,7 @@ pub fn launch_ifftshift( num_readonly_storage: 0, }); - let pipeline = - pipeline_cache.get_or_create_dynamic_pipeline("fftshift", "ifftshift", &module, &layout); + let pipeline = pipeline_cache.get_or_create_pipeline("fftshift", "ifftshift", &module, &layout); let bind_group = pipeline_cache.create_bind_group(&layout, &[input, output, params]); @@ -269,8 +278,7 @@ pub fn launch_rfft_pack( n: usize, batch_size: usize, ) -> Result<()> { - let shader = generate_rfft_pack_shader()?; - let module = pipeline_cache.get_or_create_module_from_source("rfft_pack", &shader); + let module = pipeline_cache.get_or_create_module("rfft_pack", RFFT_PACK_SHADER); let layout = pipeline_cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, @@ -279,7 +287,7 @@ pub fn launch_rfft_pack( }); let pipeline = - pipeline_cache.get_or_create_dynamic_pipeline("rfft_pack", "rfft_pack", &module, &layout); + pipeline_cache.get_or_create_pipeline("rfft_pack", "rfft_pack", &module, &layout); let bind_group = pipeline_cache.create_bind_group(&layout, &[input, output, params]); @@ -314,8 +322,7 @@ pub fn launch_irfft_unpack( n: usize, batch_size: usize, ) -> Result<()> { - let shader = generate_irfft_unpack_shader()?; - let module = pipeline_cache.get_or_create_module_from_source("irfft_unpack", &shader); + let module = pipeline_cache.get_or_create_module("irfft_unpack", IRFFT_UNPACK_SHADER); let layout = pipeline_cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, @@ -323,12 +330,8 @@ pub fn launch_irfft_unpack( num_readonly_storage: 0, }); - let pipeline = pipeline_cache.get_or_create_dynamic_pipeline( - "irfft_unpack", - "irfft_unpack", - &module, - &layout, - ); + let pipeline = + pipeline_cache.get_or_create_pipeline("irfft_unpack", "irfft_unpack", &module, &layout); let bind_group = pipeline_cache.create_bind_group(&layout, &[input, output, params]); @@ -363,8 +366,7 @@ pub fn launch_hermitian_extend( n: usize, batch_size: usize, ) -> Result<()> { - let shader = generate_hermitian_extend_shader()?; - let module = pipeline_cache.get_or_create_module_from_source("hermitian_extend", &shader); + let module = pipeline_cache.get_or_create_module("hermitian_extend", HERMITIAN_EXTEND_SHADER); let layout = pipeline_cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, @@ -372,7 +374,7 @@ pub fn launch_hermitian_extend( num_readonly_storage: 0, }); - let pipeline = pipeline_cache.get_or_create_dynamic_pipeline( + let pipeline = pipeline_cache.get_or_create_pipeline( "hermitian_extend", "hermitian_extend", &module, @@ -412,8 +414,7 @@ pub fn launch_rfft_truncate( half_n: usize, batch_size: usize, ) -> Result<()> { - let shader = generate_rfft_truncate_shader()?; - let module = pipeline_cache.get_or_create_module_from_source("rfft_truncate", &shader); + let module = pipeline_cache.get_or_create_module("rfft_truncate", RFFT_TRUNCATE_SHADER); let layout = pipeline_cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, @@ -421,12 +422,8 @@ pub fn launch_rfft_truncate( num_readonly_storage: 0, }); - let pipeline = pipeline_cache.get_or_create_dynamic_pipeline( - "rfft_truncate", - "rfft_truncate", - &module, - &layout, - ); + let pipeline = + pipeline_cache.get_or_create_pipeline("rfft_truncate", "rfft_truncate", &module, &layout); let bind_group = pipeline_cache.create_bind_group(&layout, &[input, output, params]); @@ -450,3 +447,46 @@ pub fn launch_rfft_truncate( queue.submit(std::iter::once(encoder.finish())); Ok(()) } + +/// Launch copy_complex shader +pub fn launch_copy_complex( + pipeline_cache: &PipelineCache, + queue: &Queue, + input: &Buffer, + output: &Buffer, + params: &Buffer, + n: usize, +) -> Result<()> { + let module = pipeline_cache.get_or_create_module("copy_complex", COPY_COMPLEX_SHADER); + + let layout = pipeline_cache.get_or_create_layout(LayoutKey { + num_storage_buffers: 2, + num_uniform_buffers: 1, + num_readonly_storage: 0, + }); + + let pipeline = + pipeline_cache.get_or_create_pipeline("copy_complex", "copy_complex", &module, &layout); + + let bind_group = pipeline_cache.create_bind_group(&layout, &[input, output, params]); + + let mut encoder = + pipeline_cache + .device() + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("copy_complex_encoder"), + }); + + { + let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some("copy_complex_pass"), + timestamp_writes: None, + }); + pass.set_pipeline(&pipeline); + pass.set_bind_group(0, Some(&bind_group), &[]); + pass.dispatch_workgroups(workgroup_count(n), 1, 1); + } + + queue.submit(std::iter::once(encoder.finish())); + Ok(()) +} diff --git a/src/runtime/wgpu/shaders/fftshift.wgsl b/src/runtime/wgpu/shaders/fftshift.wgsl new file mode 100644 index 00000000..ac5e1b47 --- /dev/null +++ b/src/runtime/wgpu/shaders/fftshift.wgsl @@ -0,0 +1,92 @@ +// FFT shift shader - shifts zero-frequency to center + +const WORKGROUP_SIZE: u32 = 256u; + +struct ShiftParams { + n: u32, + batch_size: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var shift_input: array>; +@group(0) @binding(1) var shift_output: array>; +@group(0) @binding(2) var shift_params: ShiftParams; + +// Complex number helpers (vec2: x=real, y=imag) +fn cmul(a: vec2, b: vec2) -> vec2 { + return vec2(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x); +} + +fn cadd(a: vec2, b: vec2) -> vec2 { + return a + b; +} + +fn csub(a: vec2, b: vec2) -> vec2 { + return a - b; +} + +fn cscale(a: vec2, s: f32) -> vec2 { + return vec2(a.x * s, a.y * s); +} + +fn cconj(a: vec2) -> vec2 { + return vec2(a.x, -a.y); +} + +// Compute e^(i*theta) = cos(theta) + i*sin(theta) +fn cexp_i(theta: f32) -> vec2 { + return vec2(cos(theta), sin(theta)); +} + +@compute @workgroup_size(WORKGROUP_SIZE) +fn fftshift( + @builtin(global_invocation_id) gid: vec3 +) { + let idx = gid.x; + let batch_idx = gid.y; + let n = shift_params.n; + + if (idx >= n) { + return; + } + + let base_offset = batch_idx * n; + let half_n = n / 2u; + + // Swap first half with second half + var src_idx: u32; + if (idx < half_n) { + src_idx = idx + half_n; + } else { + src_idx = idx - half_n; + } + + shift_output[base_offset + idx] = shift_input[base_offset + src_idx]; +} + +@compute @workgroup_size(WORKGROUP_SIZE) +fn ifftshift( + @builtin(global_invocation_id) gid: vec3 +) { + let idx = gid.x; + let batch_idx = gid.y; + let n = shift_params.n; + + if (idx >= n) { + return; + } + + let base_offset = batch_idx * n; + let half_n = (n + 1u) / 2u; // Ceiling division for odd n + + // Inverse shift + var src_idx: u32; + if (idx < n - half_n) { + src_idx = idx + half_n; + } else { + src_idx = idx - (n - half_n); + } + + shift_output[base_offset + idx] = shift_input[base_offset + src_idx]; +} diff --git a/src/runtime/wgpu/shaders/fill.wgsl b/src/runtime/wgpu/shaders/fill.wgsl new file mode 100644 index 00000000..f993a232 --- /dev/null +++ b/src/runtime/wgpu/shaders/fill.wgsl @@ -0,0 +1,19 @@ +// F32 fill operation + +const WORKGROUP_SIZE: u32 = 256u; + +struct FillParams { + numel: u32, + value: f32, +} + +@group(0) @binding(0) var fill_out: array; +@group(0) @binding(1) var fill_params: FillParams; + +@compute @workgroup_size(256) +fn fill_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < fill_params.numel) { + fill_out[idx] = fill_params.value; + } +} diff --git a/src/runtime/wgpu/shaders/flat_to_multi_index.wgsl b/src/runtime/wgpu/shaders/flat_to_multi_index.wgsl new file mode 100644 index 00000000..107050a0 --- /dev/null +++ b/src/runtime/wgpu/shaders/flat_to_multi_index.wgsl @@ -0,0 +1,44 @@ +// Convert flat indices to multi-dimensional indices + +const WORKGROUP_SIZE: u32 = 256u; +const MAX_DIMS: u32 = 8u; + +struct FlatToMultiParams { + nnz: u32, + ndim: u32, + _pad0: u32, + _pad1: u32, + shape: array, 2>, +} + +@group(0) @binding(0) var flat_indices: array; +@group(0) @binding(1) var multi_indices: array; +@group(0) @binding(2) var params: FlatToMultiParams; + +fn get_shape_dim(d: u32) -> u32 { + return params.shape[d / 4u][d % 4u]; +} + +@compute @workgroup_size(256) +fn flat_to_multi_index(@builtin(global_invocation_id) global_id: vec3) { + let idx = global_id.x; + + if (idx >= params.nnz) { + return; + } + + var flat_idx = u32(flat_indices[idx]); + let ndim = params.ndim; + + // Compute strides on the fly (row-major) + // and convert flat index to multi-index + for (var d: u32 = ndim; d > 0u; d = d - 1u) { + let dim = d - 1u; + let dim_size = get_shape_dim(dim); + let coord = flat_idx % dim_size; + flat_idx = flat_idx / dim_size; + + // Store: multi_indices[idx * ndim + dim] = coord + multi_indices[idx * ndim + dim] = i32(coord); + } +} diff --git a/src/runtime/wgpu/shaders/from_real_imag_f32.wgsl b/src/runtime/wgpu/shaders/from_real_imag_f32.wgsl new file mode 100644 index 00000000..5a0da839 --- /dev/null +++ b/src/runtime/wgpu/shaders/from_real_imag_f32.wgsl @@ -0,0 +1,19 @@ +// Construct Complex64 from real and imaginary parts +// entry point: from_real_imag_f32 + +struct Params { + numel: u32, +} + +@group(0) @binding(0) var real_input: array; +@group(0) @binding(1) var imag_input: array; +@group(0) @binding(2) var output: array>; +@group(0) @binding(3) var params: Params; + +@compute @workgroup_size(256) +fn from_real_imag_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < params.numel) { + output[idx] = vec2(real_input[idx], imag_input[idx]); + } +} diff --git a/src/runtime/wgpu/shaders/fused_activation_mul.rs b/src/runtime/wgpu/shaders/fused_activation_mul.rs new file mode 100644 index 00000000..4986c7e9 --- /dev/null +++ b/src/runtime/wgpu/shaders/fused_activation_mul.rs @@ -0,0 +1,325 @@ +//! Fused activation-mul WGSL kernel launchers. F32 only. + +use wgpu::{Buffer, Queue}; + +use super::pipeline::{LayoutKey, PipelineCache, workgroup_count}; +use crate::dtype::DType; +use crate::error::{Error, Result}; + +const FUSED_ACTIVATION_MUL_SHADER: &str = include_str!("fused_activation_mul.wgsl"); + +// ============================================================================ +// Forward launchers: (a, b) -> out +// ============================================================================ + +fn launch_fused_fwd( + cache: &PipelineCache, + queue: &Queue, + entry_point: &'static str, + op_name: &'static str, + a: &Buffer, + b: &Buffer, + out: &Buffer, + params_buffer: &Buffer, + numel: usize, + dtype: DType, +) -> Result<()> { + if dtype != DType::F32 { + return Err(Error::UnsupportedDType { dtype, op: op_name }); + } + + let module = + cache.get_or_create_module("fused_activation_mul_f32", FUSED_ACTIVATION_MUL_SHADER); + let layout = cache.get_or_create_layout(LayoutKey { + num_storage_buffers: 3, + num_uniform_buffers: 1, + num_readonly_storage: 0, + }); + let pipeline = + cache.get_or_create_pipeline("fused_activation_mul_f32", entry_point, &module, &layout); + let bind_group = cache.create_bind_group(&layout, &[a, b, out, params_buffer]); + + let mut encoder = cache + .device() + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some(op_name), + }); + { + let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some(op_name), + timestamp_writes: None, + }); + pass.set_pipeline(&pipeline); + pass.set_bind_group(0, Some(&bind_group), &[]); + pass.dispatch_workgroups(workgroup_count(numel), 1, 1); + } + queue.submit(std::iter::once(encoder.finish())); + Ok(()) +} + +/// Launch fused SiLU-mul forward: `out = silu(a) * b`. F32 only. +pub fn launch_silu_mul( + cache: &PipelineCache, + queue: &Queue, + a: &Buffer, + b: &Buffer, + out: &Buffer, + params_buffer: &Buffer, + numel: usize, + dtype: DType, +) -> Result<()> { + launch_fused_fwd( + cache, + queue, + "silu_mul_f32", + "silu_mul", + a, + b, + out, + params_buffer, + numel, + dtype, + ) +} + +/// Launch fused GELU-mul forward: `out = gelu(a) * b`. F32 only. +pub fn launch_gelu_mul( + cache: &PipelineCache, + queue: &Queue, + a: &Buffer, + b: &Buffer, + out: &Buffer, + params_buffer: &Buffer, + numel: usize, + dtype: DType, +) -> Result<()> { + launch_fused_fwd( + cache, + queue, + "gelu_mul_f32", + "gelu_mul", + a, + b, + out, + params_buffer, + numel, + dtype, + ) +} + +/// Launch fused ReLU-mul forward: `out = relu(a) * b`. F32 only. +pub fn launch_relu_mul( + cache: &PipelineCache, + queue: &Queue, + a: &Buffer, + b: &Buffer, + out: &Buffer, + params_buffer: &Buffer, + numel: usize, + dtype: DType, +) -> Result<()> { + launch_fused_fwd( + cache, + queue, + "relu_mul_f32", + "relu_mul", + a, + b, + out, + params_buffer, + numel, + dtype, + ) +} + +/// Launch fused sigmoid-mul forward: `out = sigmoid(a) * b`. F32 only. +pub fn launch_sigmoid_mul( + cache: &PipelineCache, + queue: &Queue, + a: &Buffer, + b: &Buffer, + out: &Buffer, + params_buffer: &Buffer, + numel: usize, + dtype: DType, +) -> Result<()> { + launch_fused_fwd( + cache, + queue, + "sigmoid_mul_f32", + "sigmoid_mul", + a, + b, + out, + params_buffer, + numel, + dtype, + ) +} + +// ============================================================================ +// Backward launchers: (grad, a, b) -> (d_a, d_b) +// ============================================================================ + +fn launch_fused_bwd( + cache: &PipelineCache, + queue: &Queue, + entry_point: &'static str, + op_name: &'static str, + grad: &Buffer, + a: &Buffer, + b: &Buffer, + d_a: &Buffer, + d_b: &Buffer, + params_buffer: &Buffer, + numel: usize, + dtype: DType, +) -> Result<()> { + if dtype != DType::F32 { + return Err(Error::UnsupportedDType { dtype, op: op_name }); + } + + let module = + cache.get_or_create_module("fused_activation_mul_f32", FUSED_ACTIVATION_MUL_SHADER); + let layout = cache.get_or_create_layout(LayoutKey { + num_storage_buffers: 5, + num_uniform_buffers: 1, + num_readonly_storage: 0, + }); + let pipeline = + cache.get_or_create_pipeline("fused_activation_mul_f32", entry_point, &module, &layout); + let bind_group = cache.create_bind_group(&layout, &[grad, a, b, d_a, d_b, params_buffer]); + + let mut encoder = cache + .device() + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some(op_name), + }); + { + let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some(op_name), + timestamp_writes: None, + }); + pass.set_pipeline(&pipeline); + pass.set_bind_group(0, Some(&bind_group), &[]); + pass.dispatch_workgroups(workgroup_count(numel), 1, 1); + } + queue.submit(std::iter::once(encoder.finish())); + Ok(()) +} + +/// Launch fused SiLU-mul backward. F32 only. +pub fn launch_silu_mul_bwd( + cache: &PipelineCache, + queue: &Queue, + grad: &Buffer, + a: &Buffer, + b: &Buffer, + d_a: &Buffer, + d_b: &Buffer, + params_buffer: &Buffer, + numel: usize, + dtype: DType, +) -> Result<()> { + launch_fused_bwd( + cache, + queue, + "silu_mul_bwd_f32", + "silu_mul_bwd", + grad, + a, + b, + d_a, + d_b, + params_buffer, + numel, + dtype, + ) +} + +/// Launch fused GELU-mul backward. F32 only. +pub fn launch_gelu_mul_bwd( + cache: &PipelineCache, + queue: &Queue, + grad: &Buffer, + a: &Buffer, + b: &Buffer, + d_a: &Buffer, + d_b: &Buffer, + params_buffer: &Buffer, + numel: usize, + dtype: DType, +) -> Result<()> { + launch_fused_bwd( + cache, + queue, + "gelu_mul_bwd_f32", + "gelu_mul_bwd", + grad, + a, + b, + d_a, + d_b, + params_buffer, + numel, + dtype, + ) +} + +/// Launch fused ReLU-mul backward. F32 only. +pub fn launch_relu_mul_bwd( + cache: &PipelineCache, + queue: &Queue, + grad: &Buffer, + a: &Buffer, + b: &Buffer, + d_a: &Buffer, + d_b: &Buffer, + params_buffer: &Buffer, + numel: usize, + dtype: DType, +) -> Result<()> { + launch_fused_bwd( + cache, + queue, + "relu_mul_bwd_f32", + "relu_mul_bwd", + grad, + a, + b, + d_a, + d_b, + params_buffer, + numel, + dtype, + ) +} + +/// Launch fused sigmoid-mul backward. F32 only. +pub fn launch_sigmoid_mul_bwd( + cache: &PipelineCache, + queue: &Queue, + grad: &Buffer, + a: &Buffer, + b: &Buffer, + d_a: &Buffer, + d_b: &Buffer, + params_buffer: &Buffer, + numel: usize, + dtype: DType, +) -> Result<()> { + launch_fused_bwd( + cache, + queue, + "sigmoid_mul_bwd_f32", + "sigmoid_mul_bwd", + grad, + a, + b, + d_a, + d_b, + params_buffer, + numel, + dtype, + ) +} diff --git a/src/runtime/wgpu/shaders/fused_activation_mul.wgsl b/src/runtime/wgpu/shaders/fused_activation_mul.wgsl new file mode 100644 index 00000000..a8ca4b6a --- /dev/null +++ b/src/runtime/wgpu/shaders/fused_activation_mul.wgsl @@ -0,0 +1,136 @@ +// Fused activation-mul WGSL shaders (F32 only) +// Forward: out = activation(a) * b +// Backward: d_a = grad * b * activation'(a), d_b = grad * activation(a) + +// ============================================================================ +// Forward kernels: 2 inputs (a, b), 1 output, uniform params +// ============================================================================ + +struct FusedFwdParams { + numel: u32, +} + +@group(0) @binding(0) var fwd_a: array; +@group(0) @binding(1) var fwd_b: array; +@group(0) @binding(2) var fwd_out: array; +@group(0) @binding(3) var fwd_params: FusedFwdParams; + +@compute @workgroup_size(256) +fn silu_mul_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < fwd_params.numel) { + let x = fwd_a[idx]; + let sig = 1.0 / (1.0 + exp(-x)); + fwd_out[idx] = x * sig * fwd_b[idx]; + } +} + +@compute @workgroup_size(256) +fn gelu_mul_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < fwd_params.numel) { + let x = fwd_a[idx]; + let c = 0.7978845608; + let k = 0.044715; + let inner = c * (x + k * x * x * x); + let t = tanh(inner); + fwd_out[idx] = 0.5 * x * (1.0 + t) * fwd_b[idx]; + } +} + +@compute @workgroup_size(256) +fn relu_mul_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < fwd_params.numel) { + fwd_out[idx] = max(0.0, fwd_a[idx]) * fwd_b[idx]; + } +} + +@compute @workgroup_size(256) +fn sigmoid_mul_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < fwd_params.numel) { + let sig = 1.0 / (1.0 + exp(-fwd_a[idx])); + fwd_out[idx] = sig * fwd_b[idx]; + } +} + +// ============================================================================ +// Backward kernels: 3 inputs (grad, a, b), 2 outputs (d_a, d_b), uniform params +// ============================================================================ + +struct FusedBwdParams { + numel: u32, +} + +@group(0) @binding(0) var bwd_grad: array; +@group(0) @binding(1) var bwd_a: array; +@group(0) @binding(2) var bwd_b: array; +@group(0) @binding(3) var bwd_d_a: array; +@group(0) @binding(4) var bwd_d_b: array; +@group(0) @binding(5) var bwd_params: FusedBwdParams; + +// silu'(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))) +@compute @workgroup_size(256) +fn silu_mul_bwd_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < bwd_params.numel) { + let x = bwd_a[idx]; + let g = bwd_grad[idx]; + let bv = bwd_b[idx]; + let sig = 1.0 / (1.0 + exp(-x)); + let silu_val = x * sig; + let silu_deriv = sig * (1.0 + x * (1.0 - sig)); + bwd_d_b[idx] = g * silu_val; + bwd_d_a[idx] = g * bv * silu_deriv; + } +} + +// gelu'(x) = 0.5 * (1 + t) + 0.5 * x * (1 - t*t) * c * (1 + 3*k*x*x) +@compute @workgroup_size(256) +fn gelu_mul_bwd_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < bwd_params.numel) { + let x = bwd_a[idx]; + let g = bwd_grad[idx]; + let bv = bwd_b[idx]; + let c = 0.7978845608; + let k = 0.044715; + let inner = c * (x + k * x * x * x); + let t = tanh(inner); + let gelu_val = 0.5 * x * (1.0 + t); + let gelu_deriv = 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * c * (1.0 + 3.0 * k * x * x); + bwd_d_b[idx] = g * gelu_val; + bwd_d_a[idx] = g * bv * gelu_deriv; + } +} + +// relu'(x) = 1 if x > 0, else 0 +@compute @workgroup_size(256) +fn relu_mul_bwd_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < bwd_params.numel) { + let x = bwd_a[idx]; + let g = bwd_grad[idx]; + let bv = bwd_b[idx]; + let relu_val = max(0.0, x); + let relu_deriv = select(0.0, 1.0, x > 0.0); + bwd_d_b[idx] = g * relu_val; + bwd_d_a[idx] = g * bv * relu_deriv; + } +} + +// sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)) +@compute @workgroup_size(256) +fn sigmoid_mul_bwd_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < bwd_params.numel) { + let x = bwd_a[idx]; + let g = bwd_grad[idx]; + let bv = bwd_b[idx]; + let sig = 1.0 / (1.0 + exp(-x)); + let sig_deriv = sig * (1.0 - sig); + bwd_d_b[idx] = g * sig; + bwd_d_a[idx] = g * bv * sig_deriv; + } +} diff --git a/src/runtime/wgpu/shaders/fused_add_norm.rs b/src/runtime/wgpu/shaders/fused_add_norm.rs new file mode 100644 index 00000000..3fc2bc09 --- /dev/null +++ b/src/runtime/wgpu/shaders/fused_add_norm.rs @@ -0,0 +1,356 @@ +//! Fused add + normalization WGSL kernel launchers +//! +//! Provides launchers for fused add+norm operations: +//! - Fused add + RMS normalization (forward and backward) +//! - Fused add + Layer normalization (forward and backward) +//! - Helper reduction kernel for backward passes +//! +//! All operations run entirely on GPU with no CPU fallback. + +use wgpu::{Buffer, Queue}; + +use super::pipeline::{LayoutKey, PipelineCache}; +use crate::dtype::DType; +use crate::error::{Error, Result}; + +const FUSED_ADD_NORM_SHADER: &str = include_str!("fused_add_norm.wgsl"); + +// ============================================================================ +// Helper Macros +// ============================================================================ + +macro_rules! check_dtype_f32 { + ($dtype:expr, $op:expr) => { + if $dtype != DType::F32 { + return Err(Error::UnsupportedDType { + dtype: $dtype, + op: $op, + }); + } + }; +} + +// ============================================================================ +// Fused Add + RMS Normalization (Forward) +// ============================================================================ + +/// Launch fused add + RMS normalization kernel. +/// +/// Computes: pre_norm = input + residual +/// output = pre_norm / sqrt(mean(pre_norm^2) + eps) * weight +pub fn launch_fused_add_rms_norm( + cache: &PipelineCache, + queue: &Queue, + input: &Buffer, + residual: &Buffer, + weight: &Buffer, + output: &Buffer, + pre_norm: &Buffer, + params_buffer: &Buffer, + batch_size: usize, + dtype: DType, +) -> Result<()> { + check_dtype_f32!(dtype, "fused_add_rms_norm"); + + let module = cache.get_or_create_module("fused_add_norm", FUSED_ADD_NORM_SHADER); + let layout = cache.get_or_create_layout(LayoutKey { + num_storage_buffers: 5, + num_uniform_buffers: 1, + num_readonly_storage: 0, + }); + let pipeline = + cache.get_or_create_pipeline("fused_add_norm", "fused_add_rms_norm_f32", &module, &layout); + + let bind_group = cache.create_bind_group( + &layout, + &[input, residual, weight, output, pre_norm, params_buffer], + ); + + let mut encoder = cache + .device() + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("fused_add_rms_norm"), + }); + + { + let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some("fused_add_rms_norm"), + timestamp_writes: None, + }); + pass.set_pipeline(&pipeline); + pass.set_bind_group(0, Some(&bind_group), &[]); + // One workgroup per batch element + pass.dispatch_workgroups(batch_size as u32, 1, 1); + } + + queue.submit(std::iter::once(encoder.finish())); + Ok(()) +} + +// ============================================================================ +// Fused Add + Layer Normalization (Forward) +// ============================================================================ + +/// Launch fused add + layer normalization kernel. +/// +/// Computes: pre_norm = input + residual +/// output = (pre_norm - mean) / sqrt(var + eps) * weight + bias +pub fn launch_fused_add_layer_norm( + cache: &PipelineCache, + queue: &Queue, + input: &Buffer, + residual: &Buffer, + weight: &Buffer, + bias: &Buffer, + output: &Buffer, + pre_norm: &Buffer, + params_buffer: &Buffer, + batch_size: usize, + dtype: DType, +) -> Result<()> { + check_dtype_f32!(dtype, "fused_add_layer_norm"); + + let module = cache.get_or_create_module("fused_add_norm", FUSED_ADD_NORM_SHADER); + let layout = cache.get_or_create_layout(LayoutKey { + num_storage_buffers: 6, + num_uniform_buffers: 1, + num_readonly_storage: 0, + }); + let pipeline = cache.get_or_create_pipeline( + "fused_add_norm", + "fused_add_layer_norm_f32", + &module, + &layout, + ); + + let bind_group = cache.create_bind_group( + &layout, + &[ + input, + residual, + weight, + bias, + output, + pre_norm, + params_buffer, + ], + ); + + let mut encoder = cache + .device() + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("fused_add_layer_norm"), + }); + + { + let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some("fused_add_layer_norm"), + timestamp_writes: None, + }); + pass.set_pipeline(&pipeline); + pass.set_bind_group(0, Some(&bind_group), &[]); + // One workgroup per batch element + pass.dispatch_workgroups(batch_size as u32, 1, 1); + } + + queue.submit(std::iter::once(encoder.finish())); + Ok(()) +} + +// ============================================================================ +// Fused Add + RMS Normalization (Backward) +// ============================================================================ + +/// Launch fused add + RMS normalization backward kernel. +/// +/// Computes: +/// d_input_residual = (grad * weight - pre_norm * coeff) * inv_rms +/// d_weight_scratch[batch_idx * hidden + i] = grad[batch_idx * hidden + i] * pre_norm[...] / rms +/// +/// Caller must launch reduce_sum_rows to sum d_weight_scratch across batch dimension. +pub fn launch_fused_add_rms_norm_bwd( + cache: &PipelineCache, + queue: &Queue, + grad: &Buffer, + pre_norm: &Buffer, + weight: &Buffer, + d_input_residual: &Buffer, + d_weight_scratch: &Buffer, + params_buffer: &Buffer, + batch_size: usize, + dtype: DType, +) -> Result<()> { + check_dtype_f32!(dtype, "fused_add_rms_norm_bwd"); + + let module = cache.get_or_create_module("fused_add_norm", FUSED_ADD_NORM_SHADER); + let layout = cache.get_or_create_layout(LayoutKey { + num_storage_buffers: 5, + num_uniform_buffers: 1, + num_readonly_storage: 0, + }); + let pipeline = cache.get_or_create_pipeline( + "fused_add_norm", + "fused_add_rms_norm_bwd_f32", + &module, + &layout, + ); + + let bind_group = cache.create_bind_group( + &layout, + &[ + grad, + pre_norm, + weight, + d_input_residual, + d_weight_scratch, + params_buffer, + ], + ); + + let mut encoder = cache + .device() + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("fused_add_rms_norm_bwd"), + }); + + { + let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some("fused_add_rms_norm_bwd"), + timestamp_writes: None, + }); + pass.set_pipeline(&pipeline); + pass.set_bind_group(0, Some(&bind_group), &[]); + // One workgroup per batch element + pass.dispatch_workgroups(batch_size as u32, 1, 1); + } + + queue.submit(std::iter::once(encoder.finish())); + Ok(()) +} + +// ============================================================================ +// Fused Add + Layer Normalization (Backward) +// ============================================================================ + +/// Launch fused add + layer normalization backward kernel. +/// +/// Computes: +/// d_input_residual = inv_std * (grad - mean_grad - normalized * mean_grad_normalized) +/// d_weight_scratch[batch_idx * hidden + i] = grad[...] * normalized +/// d_bias_scratch[batch_idx * hidden + i] = grad[...] +/// +/// Caller must launch reduce_sum_rows twice to sum d_weight_scratch and d_bias_scratch. +pub fn launch_fused_add_layer_norm_bwd( + cache: &PipelineCache, + queue: &Queue, + grad: &Buffer, + pre_norm: &Buffer, + weight: &Buffer, + bias: &Buffer, + d_input_residual: &Buffer, + d_weight_scratch: &Buffer, + d_bias_scratch: &Buffer, + params_buffer: &Buffer, + batch_size: usize, + dtype: DType, +) -> Result<()> { + check_dtype_f32!(dtype, "fused_add_layer_norm_bwd"); + + let module = cache.get_or_create_module("fused_add_norm", FUSED_ADD_NORM_SHADER); + let layout = cache.get_or_create_layout(LayoutKey { + num_storage_buffers: 7, + num_uniform_buffers: 1, + num_readonly_storage: 0, + }); + let pipeline = cache.get_or_create_pipeline( + "fused_add_norm", + "fused_add_layer_norm_bwd_f32", + &module, + &layout, + ); + + let bind_group = cache.create_bind_group( + &layout, + &[ + grad, + pre_norm, + weight, + bias, + d_input_residual, + d_weight_scratch, + d_bias_scratch, + params_buffer, + ], + ); + + let mut encoder = cache + .device() + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("fused_add_layer_norm_bwd"), + }); + + { + let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some("fused_add_layer_norm_bwd"), + timestamp_writes: None, + }); + pass.set_pipeline(&pipeline); + pass.set_bind_group(0, Some(&bind_group), &[]); + // One workgroup per batch element + pass.dispatch_workgroups(batch_size as u32, 1, 1); + } + + queue.submit(std::iter::once(encoder.finish())); + Ok(()) +} + +// ============================================================================ +// Reduce Sum Rows (Helper for backward) +// ============================================================================ + +/// Launch reduce sum rows kernel to sum a [batch_size, hidden_size] array across batch dimension. +/// +/// Reduces input [batch_size, hidden_size] to output [hidden_size] by summing across batch. +pub fn launch_reduce_sum_rows( + cache: &PipelineCache, + queue: &Queue, + input: &Buffer, + output: &Buffer, + params_buffer: &Buffer, + hidden_size: usize, + dtype: DType, +) -> Result<()> { + check_dtype_f32!(dtype, "reduce_sum_rows"); + + let module = cache.get_or_create_module("fused_add_norm", FUSED_ADD_NORM_SHADER); + let layout = cache.get_or_create_layout(LayoutKey { + num_storage_buffers: 2, + num_uniform_buffers: 1, + num_readonly_storage: 0, + }); + let pipeline = + cache.get_or_create_pipeline("fused_add_norm", "reduce_sum_rows_f32", &module, &layout); + + let bind_group = cache.create_bind_group(&layout, &[input, output, params_buffer]); + + let mut encoder = cache + .device() + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("reduce_sum_rows"), + }); + + { + let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some("reduce_sum_rows"), + timestamp_writes: None, + }); + pass.set_pipeline(&pipeline); + pass.set_bind_group(0, Some(&bind_group), &[]); + // Dispatch enough workgroups to cover hidden_size elements + let num_workgroups = (hidden_size as u32 + 255) / 256; + pass.dispatch_workgroups(num_workgroups, 1, 1); + } + + queue.submit(std::iter::once(encoder.finish())); + Ok(()) +} diff --git a/src/runtime/wgpu/shaders/fused_add_norm.wgsl b/src/runtime/wgpu/shaders/fused_add_norm.wgsl new file mode 100644 index 00000000..f922565b --- /dev/null +++ b/src/runtime/wgpu/shaders/fused_add_norm.wgsl @@ -0,0 +1,402 @@ +// Fused add + normalization operations. F32 only. +// Entry points: +// - fused_add_rms_norm_f32: Add residual, then RMS normalize +// - fused_add_layer_norm_f32: Add residual, then layer normalize +// - fused_add_rms_norm_bwd_f32: Backward pass for fused add RMS norm +// - fused_add_layer_norm_bwd_f32: Backward pass for fused add layer norm +// - reduce_sum_rows_f32: Reduce d_weight/d_bias scratch buffers across batch dimension + +// ============================================================================ +// Workgroup Configuration +// ============================================================================ + +const WORKGROUP_SIZE: u32 = 256u; + +// ============================================================================ +// RMS Normalization Structs +// ============================================================================ + +struct RmsNormParams { + batch_size: u32, + hidden_size: u32, + eps: f32, +} + +struct LayerNormParams { + batch_size: u32, + hidden_size: u32, + eps: f32, +} + +struct ReduceSumParams { + batch_size: u32, + hidden_size: u32, +} + +// ============================================================================ +// Fused Add + RMS Norm (Forward) +// ============================================================================ + +@group(0) @binding(0) var farn_input: array; +@group(0) @binding(1) var farn_residual: array; +@group(0) @binding(2) var farn_weight: array; +@group(0) @binding(3) var farn_output: array; +@group(0) @binding(4) var farn_pre_norm: array; +@group(0) @binding(5) var farn_params: RmsNormParams; + +var farn_shared: array; + +@compute @workgroup_size(256) +fn fused_add_rms_norm_f32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let tid = local_id.x; + let batch_idx = group_id.x; + + if (batch_idx >= farn_params.batch_size) { + return; + } + + let hidden_size = farn_params.hidden_size; + let eps = farn_params.eps; + let base_offset = batch_idx * hidden_size; + + // Step 1: Add input + residual -> pre_norm, compute sum of squares + var sum_sq: f32 = 0.0; + var i: u32 = tid; + while (i < hidden_size) { + let pre_val = farn_input[base_offset + i] + farn_residual[base_offset + i]; + farn_pre_norm[base_offset + i] = pre_val; + sum_sq = sum_sq + pre_val * pre_val; + i = i + WORKGROUP_SIZE; + } + + farn_shared[tid] = sum_sq; + workgroupBarrier(); + + // Reduce to get total sum of squares + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { + farn_shared[tid] = farn_shared[tid] + farn_shared[tid + s]; + } + workgroupBarrier(); + } + + // Compute RMS: sqrt(mean(x^2) + eps) + let rms = sqrt(farn_shared[0] / f32(hidden_size) + eps); + workgroupBarrier(); + + // Step 2: Normalize and apply weight + i = tid; + while (i < hidden_size) { + farn_output[base_offset + i] = farn_pre_norm[base_offset + i] / rms * farn_weight[i]; + i = i + WORKGROUP_SIZE; + } +} + +// ============================================================================ +// Fused Add + Layer Norm (Forward) +// ============================================================================ + +@group(0) @binding(0) var faln_input: array; +@group(0) @binding(1) var faln_residual: array; +@group(0) @binding(2) var faln_weight: array; +@group(0) @binding(3) var faln_bias: array; +@group(0) @binding(4) var faln_output: array; +@group(0) @binding(5) var faln_pre_norm: array; +@group(0) @binding(6) var faln_params: LayerNormParams; + +var faln_shared_mean: array; +var faln_shared_var: array; + +@compute @workgroup_size(256) +fn fused_add_layer_norm_f32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let tid = local_id.x; + let batch_idx = group_id.x; + + if (batch_idx >= faln_params.batch_size) { + return; + } + + let hidden_size = faln_params.hidden_size; + let eps = faln_params.eps; + let base_offset = batch_idx * hidden_size; + + // Step 1: Add input + residual -> pre_norm, compute sum for mean + var sum: f32 = 0.0; + var i: u32 = tid; + while (i < hidden_size) { + let pre_val = faln_input[base_offset + i] + faln_residual[base_offset + i]; + faln_pre_norm[base_offset + i] = pre_val; + sum = sum + pre_val; + i = i + WORKGROUP_SIZE; + } + + faln_shared_mean[tid] = sum; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { + faln_shared_mean[tid] = faln_shared_mean[tid] + faln_shared_mean[tid + s]; + } + workgroupBarrier(); + } + + let mean = faln_shared_mean[0] / f32(hidden_size); + workgroupBarrier(); + + // Step 2: Compute variance + var var_sum: f32 = 0.0; + i = tid; + while (i < hidden_size) { + let diff = faln_pre_norm[base_offset + i] - mean; + var_sum = var_sum + diff * diff; + i = i + WORKGROUP_SIZE; + } + + faln_shared_var[tid] = var_sum; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { + faln_shared_var[tid] = faln_shared_var[tid] + faln_shared_var[tid + s]; + } + workgroupBarrier(); + } + + let variance = faln_shared_var[0] / f32(hidden_size); + let inv_std = 1.0 / sqrt(variance + eps); + workgroupBarrier(); + + // Step 3: Normalize and apply affine transformation + i = tid; + while (i < hidden_size) { + let normalized = (faln_pre_norm[base_offset + i] - mean) * inv_std; + faln_output[base_offset + i] = normalized * faln_weight[i] + faln_bias[i]; + i = i + WORKGROUP_SIZE; + } +} + +// ============================================================================ +// Fused Add + RMS Norm (Backward) +// ============================================================================ + +@group(0) @binding(0) var farnb_grad: array; +@group(0) @binding(1) var farnb_pre_norm: array; +@group(0) @binding(2) var farnb_weight: array; +@group(0) @binding(3) var farnb_d_input_residual: array; +@group(0) @binding(4) var farnb_d_weight_scratch: array; +@group(0) @binding(5) var farnb_params: RmsNormParams; + +var farnb_shared_sum_sq: array; +var farnb_shared_dot: array; + +@compute @workgroup_size(256) +fn fused_add_rms_norm_bwd_f32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let tid = local_id.x; + let batch_idx = group_id.x; + + if (batch_idx >= farnb_params.batch_size) { + return; + } + + let hidden_size = farnb_params.hidden_size; + let eps = farnb_params.eps; + let base_offset = batch_idx * hidden_size; + + // Phase 1: Compute sum_sq and dot(grad, weight, pre_norm) + var sum_sq: f32 = 0.0; + var dot: f32 = 0.0; + var i: u32 = tid; + while (i < hidden_size) { + let pre_val = farnb_pre_norm[base_offset + i]; + sum_sq = sum_sq + pre_val * pre_val; + dot = dot + farnb_grad[base_offset + i] * farnb_weight[i] * pre_val; + i = i + WORKGROUP_SIZE; + } + + farnb_shared_sum_sq[tid] = sum_sq; + farnb_shared_dot[tid] = dot; + workgroupBarrier(); + + // Reduce both sums + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { + farnb_shared_sum_sq[tid] = farnb_shared_sum_sq[tid] + farnb_shared_sum_sq[tid + s]; + farnb_shared_dot[tid] = farnb_shared_dot[tid] + farnb_shared_dot[tid + s]; + } + workgroupBarrier(); + } + + let total_sum_sq = farnb_shared_sum_sq[0]; + let total_dot = farnb_shared_dot[0]; + let rms = sqrt(total_sum_sq / f32(hidden_size) + eps); + let inv_rms = 1.0 / rms; + let inv_rms_cubed = inv_rms * inv_rms * inv_rms; + let coeff = total_dot * inv_rms_cubed / f32(hidden_size); + workgroupBarrier(); + + // Phase 2: Compute d_input_residual and accumulate d_weight + i = tid; + while (i < hidden_size) { + // d_input_residual = (grad * weight - pre_norm * coeff) * inv_rms + farnb_d_input_residual[base_offset + i] = + (farnb_grad[base_offset + i] * farnb_weight[i] - farnb_pre_norm[base_offset + i] * coeff) * inv_rms; + + // d_weight contribution: sum(grad * pre_norm / rms) per element + // Each workgroup writes its per-row contribution to scratch + farnb_d_weight_scratch[base_offset + i] = farnb_grad[base_offset + i] * farnb_pre_norm[base_offset + i] * inv_rms; + + i = i + WORKGROUP_SIZE; + } +} + +// ============================================================================ +// Fused Add + Layer Norm (Backward) +// ============================================================================ + +@group(0) @binding(0) var falnb_grad: array; +@group(0) @binding(1) var falnb_pre_norm: array; +@group(0) @binding(2) var falnb_weight: array; +@group(0) @binding(3) var falnb_bias: array; +@group(0) @binding(4) var falnb_d_input_residual: array; +@group(0) @binding(5) var falnb_d_weight_scratch: array; +@group(0) @binding(6) var falnb_d_bias_scratch: array; +@group(0) @binding(7) var falnb_params: LayerNormParams; + +var falnb_shared_mean: array; +var falnb_shared_var: array; + +@compute @workgroup_size(256) +fn fused_add_layer_norm_bwd_f32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let tid = local_id.x; + let batch_idx = group_id.x; + + if (batch_idx >= falnb_params.batch_size) { + return; + } + + let hidden_size = falnb_params.hidden_size; + let eps = falnb_params.eps; + let base_offset = batch_idx * hidden_size; + + // Phase 1: Compute mean of pre_norm + var sum: f32 = 0.0; + var i: u32 = tid; + while (i < hidden_size) { + sum = sum + falnb_pre_norm[base_offset + i]; + i = i + WORKGROUP_SIZE; + } + + falnb_shared_mean[tid] = sum; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { + falnb_shared_mean[tid] = falnb_shared_mean[tid] + falnb_shared_mean[tid + s]; + } + workgroupBarrier(); + } + + let mean = falnb_shared_mean[0] / f32(hidden_size); + workgroupBarrier(); + + // Phase 2: Compute variance + var var_sum: f32 = 0.0; + i = tid; + while (i < hidden_size) { + let diff = falnb_pre_norm[base_offset + i] - mean; + var_sum = var_sum + diff * diff; + i = i + WORKGROUP_SIZE; + } + + falnb_shared_var[tid] = var_sum; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { + falnb_shared_var[tid] = falnb_shared_var[tid] + falnb_shared_var[tid + s]; + } + workgroupBarrier(); + } + + let variance = falnb_shared_var[0] / f32(hidden_size); + let inv_std = 1.0 / sqrt(variance + eps); + + // Compute grad_scaled = grad * weight sums + var sum_gs: f32 = 0.0; + var sum_gs_n: f32 = 0.0; + i = tid; + while (i < hidden_size) { + let normalized = (falnb_pre_norm[base_offset + i] - mean) * inv_std; + let gs = falnb_grad[base_offset + i] * falnb_weight[i]; + sum_gs = sum_gs + gs; + sum_gs_n = sum_gs_n + gs * normalized; + i = i + WORKGROUP_SIZE; + } + + falnb_shared_mean[tid] = sum_gs; + falnb_shared_var[tid] = sum_gs_n; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { + falnb_shared_mean[tid] = falnb_shared_mean[tid] + falnb_shared_mean[tid + s]; + falnb_shared_var[tid] = falnb_shared_var[tid] + falnb_shared_var[tid + s]; + } + workgroupBarrier(); + } + + let total_sum_gs = falnb_shared_mean[0]; + let total_sum_gs_n = falnb_shared_var[0]; + workgroupBarrier(); + + // Phase 3: Compute d_input_residual, d_weight_scratch, d_bias_scratch + i = tid; + while (i < hidden_size) { + let normalized = (falnb_pre_norm[base_offset + i] - mean) * inv_std; + + // d_input_residual = inv_std * (grad*weight - mean_gs - normalized * mean_gs_n) + let mean_gs_val = total_sum_gs / f32(hidden_size); + let mean_gs_n_val = total_sum_gs_n / f32(hidden_size); + let gs = falnb_grad[base_offset + i] * falnb_weight[i]; + falnb_d_input_residual[base_offset + i] = inv_std * + (gs - mean_gs_val - normalized * mean_gs_n_val); + + // d_weight: sum(grad * normalized) per element + falnb_d_weight_scratch[base_offset + i] = falnb_grad[base_offset + i] * normalized; + + // d_bias: sum(grad) per element + falnb_d_bias_scratch[base_offset + i] = falnb_grad[base_offset + i]; + + i = i + WORKGROUP_SIZE; + } +} + +// ============================================================================ +// Reduce Sum Rows (helper for backward) +// ============================================================================ + +@group(0) @binding(0) var rsr_input: array; +@group(0) @binding(1) var rsr_output: array; +@group(0) @binding(2) var rsr_params: ReduceSumParams; + +@compute @workgroup_size(256) +fn reduce_sum_rows_f32(@builtin(global_invocation_id) gid: vec3) { + let i = gid.x; + if (i >= rsr_params.hidden_size) { + return; + } + + var sum: f32 = 0.0; + for (var b: u32 = 0u; b < rsr_params.batch_size; b = b + 1u) { + sum = sum + rsr_input[b * rsr_params.hidden_size + i]; + } + rsr_output[i] = sum; +} diff --git a/src/runtime/wgpu/shaders/fused_elementwise.rs b/src/runtime/wgpu/shaders/fused_elementwise.rs new file mode 100644 index 00000000..e983e87f --- /dev/null +++ b/src/runtime/wgpu/shaders/fused_elementwise.rs @@ -0,0 +1,196 @@ +//! Fused elementwise WGSL kernel launchers. F32 only. + +use wgpu::{Buffer, Queue}; + +use super::pipeline::{LayoutKey, PipelineCache, workgroup_count}; +use crate::dtype::DType; +use crate::error::{Error, Result}; + +const TERNARY_SHADER: &str = include_str!("fused_elementwise.wgsl"); +const SCALAR_SHADER: &str = include_str!("fused_elementwise_scalar.wgsl"); + +/// Params for ternary ops (matches TernaryParams in WGSL) +#[repr(C)] +#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)] +struct TernaryParams { + numel: u32, +} + +/// Params for scalar FMA (matches ScalarFmaParams in WGSL) +#[repr(C)] +#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)] +struct ScalarFmaParams { + numel: u32, + scale: f32, + bias: f32, + _pad: u32, +} + +fn launch_ternary( + cache: &PipelineCache, + queue: &Queue, + entry_point: &'static str, + op_name: &'static str, + a: &Buffer, + b: &Buffer, + c: &Buffer, + out: &Buffer, + numel: usize, + dtype: DType, +) -> Result<()> { + if dtype != DType::F32 { + return Err(Error::UnsupportedDType { dtype, op: op_name }); + } + + let params = TernaryParams { + numel: numel as u32, + }; + let params_buf = cache.device().create_buffer(&wgpu::BufferDescriptor { + label: Some("fused_elem_params"), + size: std::mem::size_of::() as u64, + usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + queue.write_buffer(¶ms_buf, 0, bytemuck::bytes_of(¶ms)); + + let module = cache.get_or_create_module("fused_elementwise_f32", TERNARY_SHADER); + let layout = cache.get_or_create_layout(LayoutKey { + num_storage_buffers: 4, + num_uniform_buffers: 1, + num_readonly_storage: 0, + }); + let pipeline = + cache.get_or_create_pipeline("fused_elementwise_f32", entry_point, &module, &layout); + let bind_group = cache.create_bind_group(&layout, &[a, b, c, out, ¶ms_buf]); + + let mut encoder = cache + .device() + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some(op_name), + }); + { + let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some(op_name), + timestamp_writes: None, + }); + pass.set_pipeline(&pipeline); + pass.set_bind_group(0, Some(&bind_group), &[]); + pass.dispatch_workgroups(workgroup_count(numel), 1, 1); + } + queue.submit(std::iter::once(encoder.finish())); + Ok(()) +} + +/// Launch fused_mul_add: out = a * b + c. F32 only. +pub fn launch_fused_mul_add( + cache: &PipelineCache, + queue: &Queue, + a: &Buffer, + b: &Buffer, + c: &Buffer, + out: &Buffer, + numel: usize, + dtype: DType, +) -> Result<()> { + launch_ternary( + cache, + queue, + "fused_mul_add_f32", + "fused_mul_add", + a, + b, + c, + out, + numel, + dtype, + ) +} + +/// Launch fused_add_mul: out = (a + b) * c. F32 only. +pub fn launch_fused_add_mul( + cache: &PipelineCache, + queue: &Queue, + a: &Buffer, + b: &Buffer, + c: &Buffer, + out: &Buffer, + numel: usize, + dtype: DType, +) -> Result<()> { + launch_ternary( + cache, + queue, + "fused_add_mul_f32", + "fused_add_mul", + a, + b, + c, + out, + numel, + dtype, + ) +} + +/// Launch fused_mul_add_scalar: out = a * scale + bias. F32 only. +pub fn launch_fused_mul_add_scalar( + cache: &PipelineCache, + queue: &Queue, + a: &Buffer, + out: &Buffer, + numel: usize, + dtype: DType, + scale: f32, + bias: f32, +) -> Result<()> { + if dtype != DType::F32 { + return Err(Error::UnsupportedDType { + dtype, + op: "fused_mul_add_scalar", + }); + } + + let params = ScalarFmaParams { + numel: numel as u32, + scale, + bias, + _pad: 0, + }; + let params_buf = cache.device().create_buffer(&wgpu::BufferDescriptor { + label: Some("fused_elem_scalar_params"), + size: std::mem::size_of::() as u64, + usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + queue.write_buffer(¶ms_buf, 0, bytemuck::bytes_of(¶ms)); + + let module = cache.get_or_create_module("fused_elementwise_scalar_f32", SCALAR_SHADER); + let layout = cache.get_or_create_layout(LayoutKey { + num_storage_buffers: 2, + num_uniform_buffers: 1, + num_readonly_storage: 0, + }); + let pipeline = cache.get_or_create_pipeline( + "fused_elementwise_scalar_f32", + "fused_mul_add_scalar_f32", + &module, + &layout, + ); + let bind_group = cache.create_bind_group(&layout, &[a, out, ¶ms_buf]); + + let mut encoder = cache + .device() + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("fused_mul_add_scalar"), + }); + { + let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some("fused_mul_add_scalar"), + timestamp_writes: None, + }); + pass.set_pipeline(&pipeline); + pass.set_bind_group(0, Some(&bind_group), &[]); + pass.dispatch_workgroups(workgroup_count(numel), 1, 1); + } + queue.submit(std::iter::once(encoder.finish())); + Ok(()) +} diff --git a/src/runtime/wgpu/shaders/fused_elementwise.wgsl b/src/runtime/wgpu/shaders/fused_elementwise.wgsl new file mode 100644 index 00000000..d08d739d --- /dev/null +++ b/src/runtime/wgpu/shaders/fused_elementwise.wgsl @@ -0,0 +1,41 @@ +// Fused elementwise WGSL shaders (F32 only) +// fused_mul_add: out = a * b + c +// fused_add_mul: out = (a + b) * c +// fused_mul_add_scalar: out = a * scale + bias + +struct TernaryParams { + numel: u32, +} + +struct ScalarFmaParams { + numel: u32, + scale: f32, + bias: f32, + _pad: u32, +} + +// ============================================================================ +// Ternary ops: 3 inputs (a, b, c), 1 output +// ============================================================================ + +@group(0) @binding(0) var tern_a: array; +@group(0) @binding(1) var tern_b: array; +@group(0) @binding(2) var tern_c: array; +@group(0) @binding(3) var tern_out: array; +@group(0) @binding(4) var tern_params: TernaryParams; + +@compute @workgroup_size(256) +fn fused_mul_add_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < tern_params.numel) { + tern_out[idx] = fma(tern_a[idx], tern_b[idx], tern_c[idx]); + } +} + +@compute @workgroup_size(256) +fn fused_add_mul_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < tern_params.numel) { + tern_out[idx] = (tern_a[idx] + tern_b[idx]) * tern_c[idx]; + } +} diff --git a/src/runtime/wgpu/shaders/fused_elementwise_scalar.wgsl b/src/runtime/wgpu/shaders/fused_elementwise_scalar.wgsl new file mode 100644 index 00000000..ac33f0fb --- /dev/null +++ b/src/runtime/wgpu/shaders/fused_elementwise_scalar.wgsl @@ -0,0 +1,21 @@ +// Fused elementwise scalar WGSL shader (F32 only) +// fused_mul_add_scalar: out = a * scale + bias + +struct ScalarFmaParams { + numel: u32, + scale: f32, + bias: f32, + _pad: u32, +} + +@group(0) @binding(0) var sfma_a: array; +@group(0) @binding(1) var sfma_out: array; +@group(0) @binding(2) var sfma_params: ScalarFmaParams; + +@compute @workgroup_size(256) +fn fused_mul_add_scalar_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < sfma_params.numel) { + sfma_out[idx] = fma(sfma_a[idx], sfma_params.scale, sfma_params.bias); + } +} diff --git a/src/runtime/wgpu/shaders/gamma_dist_f32.wgsl b/src/runtime/wgpu/shaders/gamma_dist_f32.wgsl new file mode 100644 index 00000000..c72f36a4 --- /dev/null +++ b/src/runtime/wgpu/shaders/gamma_dist_f32.wgsl @@ -0,0 +1,90 @@ +// Gamma distribution sampling for f32 + +// PCG hash function for random number generation +fn pcg_hash(input: u32) -> u32 { + var state = input * 747796405u + 2891336453u; + var word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u; + return (word >> 22u) ^ word; +} + +fn pcg_init(seed: u32, idx: u32) -> u32 { + return pcg_hash(seed ^ pcg_hash(idx)); +} + +fn pcg_uniform(state: ptr) -> f32 { + *state = pcg_hash(*state); + return f32(*state) / 4294967296.0; +} + +// Box-Muller for normal distribution +fn sample_normal(state: ptr) -> f32 { + let u1 = max(pcg_uniform(state), 0.0000001); + let u2 = pcg_uniform(state); + return sqrt(-2.0 * log(u1)) * cos(6.28318530718 * u2); +} + +// Gamma via Marsaglia-Tsang method +fn sample_gamma_mt(state: ptr, shape: f32, scale: f32) -> f32 { + var alpha = shape; + var boost = 1.0; + + // Handle shape < 1 by boosting + if alpha < 1.0 { + boost = pow(pcg_uniform(state), 1.0 / alpha); + alpha = alpha + 1.0; + } + + let d = alpha - 1.0 / 3.0; + let c = 1.0 / sqrt(9.0 * d); + + // Rejection sampling + for (var i = 0u; i < 100u; i = i + 1u) { + var x: f32; + var v: f32; + + // Generate valid v + for (var j = 0u; j < 100u; j = j + 1u) { + x = sample_normal(state); + v = 1.0 + c * x; + if v > 0.0 { + break; + } + } + + v = v * v * v; + let u = pcg_uniform(state); + let x2 = x * x; + + // Accept/reject + if u < 1.0 - 0.0331 * x2 * x2 { + return d * v * boost * scale; + } + if log(u) < 0.5 * x2 + d * (1.0 - v + log(v)) { + return d * v * boost * scale; + } + } + + // Fallback (should rarely reach) + return d * boost * scale; +} + +const WORKGROUP_SIZE: u32 = 256u; + +struct GammaParams { + numel: u32, + seed: u32, + shape: f32, + scale: f32, +} + +@group(0) @binding(0) var out: array; +@group(0) @binding(1) var params: GammaParams; + +@compute @workgroup_size(256) +fn gamma_dist_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if idx < params.numel { + var state = pcg_init(params.seed, idx); + out[idx] = f32(sample_gamma_mt(&state, params.shape, params.scale)); + } +} diff --git a/src/runtime/wgpu/shaders/gather_2d_f32.wgsl b/src/runtime/wgpu/shaders/gather_2d_f32.wgsl new file mode 100644 index 00000000..43ec5288 --- /dev/null +++ b/src/runtime/wgpu/shaders/gather_2d_f32.wgsl @@ -0,0 +1,38 @@ +// Auto-generated gather_2d operation for f32 +// Gathers elements from a 2D matrix at (row, col) positions. + +const WORKGROUP_SIZE: u32 = 256u; + +struct Gather2dParams { + nrows: u32, + ncols: u32, + num_indices: u32, + _pad: u32, +} + +@group(0) @binding(0) var input: array; +@group(0) @binding(1) var rows: array; +@group(0) @binding(2) var cols: array; +@group(0) @binding(3) var output: array; +@group(0) @binding(4) var params: Gather2dParams; + +@compute @workgroup_size(256) +fn gather_2d_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= params.num_indices) { + return; + } + + let r = rows[idx]; + let c = cols[idx]; + + // Bounds checking + if (r < 0 || u32(r) >= params.nrows || c < 0 || u32(c) >= params.ncols) { + output[idx] = 0.0; + return; + } + + // Row-major indexing: input[r, c] = input[r * ncols + c] + let input_idx = u32(r) * params.ncols + u32(c); + output[idx] = input[input_idx]; +} diff --git a/src/runtime/wgpu/shaders/gather_2d_i32.wgsl b/src/runtime/wgpu/shaders/gather_2d_i32.wgsl new file mode 100644 index 00000000..c7b8b837 --- /dev/null +++ b/src/runtime/wgpu/shaders/gather_2d_i32.wgsl @@ -0,0 +1,38 @@ +// Auto-generated gather_2d operation for i32 +// Gathers elements from a 2D matrix at (row, col) positions. + +const WORKGROUP_SIZE: u32 = 256u; + +struct Gather2dParams { + nrows: u32, + ncols: u32, + num_indices: u32, + _pad: u32, +} + +@group(0) @binding(0) var input: array; +@group(0) @binding(1) var rows: array; +@group(0) @binding(2) var cols: array; +@group(0) @binding(3) var output: array; +@group(0) @binding(4) var params: Gather2dParams; + +@compute @workgroup_size(256) +fn gather_2d_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= params.num_indices) { + return; + } + + let r = rows[idx]; + let c = cols[idx]; + + // Bounds checking + if (r < 0 || u32(r) >= params.nrows || c < 0 || u32(c) >= params.ncols) { + output[idx] = 0; + return; + } + + // Row-major indexing: input[r, c] = input[r * ncols + c] + let input_idx = u32(r) * params.ncols + u32(c); + output[idx] = input[input_idx]; +} diff --git a/src/runtime/wgpu/shaders/gather_2d_u32.wgsl b/src/runtime/wgpu/shaders/gather_2d_u32.wgsl new file mode 100644 index 00000000..43210456 --- /dev/null +++ b/src/runtime/wgpu/shaders/gather_2d_u32.wgsl @@ -0,0 +1,38 @@ +// Auto-generated gather_2d operation for u32 +// Gathers elements from a 2D matrix at (row, col) positions. + +const WORKGROUP_SIZE: u32 = 256u; + +struct Gather2dParams { + nrows: u32, + ncols: u32, + num_indices: u32, + _pad: u32, +} + +@group(0) @binding(0) var input: array; +@group(0) @binding(1) var rows: array; +@group(0) @binding(2) var cols: array; +@group(0) @binding(3) var output: array; +@group(0) @binding(4) var params: Gather2dParams; + +@compute @workgroup_size(256) +fn gather_2d_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= params.num_indices) { + return; + } + + let r = rows[idx]; + let c = cols[idx]; + + // Bounds checking + if (r < 0 || u32(r) >= params.nrows || c < 0 || u32(c) >= params.ncols) { + output[idx] = 0u; + return; + } + + // Row-major indexing: input[r, c] = input[r * ncols + c] + let input_idx = u32(r) * params.ncols + u32(c); + output[idx] = input[input_idx]; +} diff --git a/src/runtime/wgpu/shaders/gather_f32.wgsl b/src/runtime/wgpu/shaders/gather_f32.wgsl new file mode 100644 index 00000000..3a9cbb97 --- /dev/null +++ b/src/runtime/wgpu/shaders/gather_f32.wgsl @@ -0,0 +1,59 @@ +// Auto-generated gather operations for f32 + +const WORKGROUP_SIZE: u32 = 256u; +const MAX_DIMS: u32 = 4u; + +struct GatherParams { + ndim: u32, + dim: u32, + total_elements: u32, + _padding: u32, + // Shape and strides packed: [input_shape[0..4], input_strides[0..4], output_shape[0..4], output_strides[0..4]] + input_shape: vec4, + input_strides: vec4, + output_shape: vec4, + output_strides: vec4, +} + +@group(0) @binding(0) var input: array; +@group(0) @binding(1) var indices: array; +@group(0) @binding(2) var output: array; +@group(0) @binding(3) var params: GatherParams; + +fn get_shape(arr: vec4, d: u32) -> u32 { + if (d == 0u) { return arr.x; } + else if (d == 1u) { return arr.y; } + else if (d == 2u) { return arr.z; } + else { return arr.w; } +} + +@compute @workgroup_size(256) +fn gather_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= params.total_elements) { + return; + } + + var remaining = idx; + var src_offset: u32 = 0u; + + for (var d: u32 = 0u; d < params.ndim; d = d + 1u) { + let out_stride = get_shape(params.output_strides, d); + let coord = remaining / out_stride; + remaining = remaining % out_stride; + + if (d == params.dim) { + let index_val = indices[idx]; + let dim_size = get_shape(params.input_shape, d); + if (index_val < 0 || u32(index_val) >= dim_size) { + output[idx] = 0.0; + return; + } + src_offset = src_offset + u32(index_val) * get_shape(params.input_strides, d); + } else { + src_offset = src_offset + coord * get_shape(params.input_strides, d); + } + } + + output[idx] = input[src_offset]; +} diff --git a/src/runtime/wgpu/shaders/gather_i32.wgsl b/src/runtime/wgpu/shaders/gather_i32.wgsl new file mode 100644 index 00000000..6b7a167b --- /dev/null +++ b/src/runtime/wgpu/shaders/gather_i32.wgsl @@ -0,0 +1,59 @@ +// Auto-generated gather operations for i32 + +const WORKGROUP_SIZE: u32 = 256u; +const MAX_DIMS: u32 = 4u; + +struct GatherParams { + ndim: u32, + dim: u32, + total_elements: u32, + _padding: u32, + // Shape and strides packed: [input_shape[0..4], input_strides[0..4], output_shape[0..4], output_strides[0..4]] + input_shape: vec4, + input_strides: vec4, + output_shape: vec4, + output_strides: vec4, +} + +@group(0) @binding(0) var input: array; +@group(0) @binding(1) var indices: array; +@group(0) @binding(2) var output: array; +@group(0) @binding(3) var params: GatherParams; + +fn get_shape(arr: vec4, d: u32) -> u32 { + if (d == 0u) { return arr.x; } + else if (d == 1u) { return arr.y; } + else if (d == 2u) { return arr.z; } + else { return arr.w; } +} + +@compute @workgroup_size(256) +fn gather_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= params.total_elements) { + return; + } + + var remaining = idx; + var src_offset: u32 = 0u; + + for (var d: u32 = 0u; d < params.ndim; d = d + 1u) { + let out_stride = get_shape(params.output_strides, d); + let coord = remaining / out_stride; + remaining = remaining % out_stride; + + if (d == params.dim) { + let index_val = indices[idx]; + let dim_size = get_shape(params.input_shape, d); + if (index_val < 0 || u32(index_val) >= dim_size) { + output[idx] = 0; + return; + } + src_offset = src_offset + u32(index_val) * get_shape(params.input_strides, d); + } else { + src_offset = src_offset + coord * get_shape(params.input_strides, d); + } + } + + output[idx] = input[src_offset]; +} diff --git a/src/runtime/wgpu/shaders/gather_nd_f32.wgsl b/src/runtime/wgpu/shaders/gather_nd_f32.wgsl new file mode 100644 index 00000000..aa0bb412 --- /dev/null +++ b/src/runtime/wgpu/shaders/gather_nd_f32.wgsl @@ -0,0 +1,56 @@ +// Auto-generated gather_nd operations for f32 + +const WORKGROUP_SIZE: u32 = 256u; +const MAX_DIMS: u32 = 8u; + +struct GatherNdParams { + num_slices: u32, + slice_size: u32, + index_depth: u32, + ndim: u32, + input_shape: array, + input_strides: array, +} + +@group(0) @binding(0) var gather_nd_input: array; +@group(0) @binding(1) var gather_nd_indices: array; +@group(0) @binding(2) var gather_nd_output: array; +@group(0) @binding(3) var gather_nd_params: GatherNdParams; + +@compute @workgroup_size(256) +fn gather_nd_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = gather_nd_params.num_slices * gather_nd_params.slice_size; + if (idx >= total) { + return; + } + + let slice_idx = idx / gather_nd_params.slice_size; + let element_in_slice = idx % gather_nd_params.slice_size; + + // Compute input offset from indices + var input_offset: u32 = 0u; + let indices_offset = slice_idx * gather_nd_params.index_depth; + + for (var d: u32 = 0u; d < gather_nd_params.index_depth; d = d + 1u) { + let coord = gather_nd_indices[indices_offset + d]; + if (coord < 0 || u32(coord) >= gather_nd_params.input_shape[d]) { + gather_nd_output[idx] = 0.0; + return; + } + input_offset = input_offset + u32(coord) * gather_nd_params.input_strides[d]; + } + + // Add offset for element within slice + if (gather_nd_params.slice_size > 1u) { + var remaining = element_in_slice; + for (var d: u32 = gather_nd_params.index_depth; d < gather_nd_params.ndim; d = d + 1u) { + let dim_size = gather_nd_params.input_shape[d]; + let coord = remaining / gather_nd_params.input_strides[d]; + remaining = remaining % gather_nd_params.input_strides[d]; + input_offset = input_offset + coord * gather_nd_params.input_strides[d]; + } + } + + gather_nd_output[idx] = gather_nd_input[input_offset + element_in_slice]; +} diff --git a/src/runtime/wgpu/shaders/gather_nd_i32.wgsl b/src/runtime/wgpu/shaders/gather_nd_i32.wgsl new file mode 100644 index 00000000..6e236513 --- /dev/null +++ b/src/runtime/wgpu/shaders/gather_nd_i32.wgsl @@ -0,0 +1,56 @@ +// Auto-generated gather_nd operations for i32 + +const WORKGROUP_SIZE: u32 = 256u; +const MAX_DIMS: u32 = 8u; + +struct GatherNdParams { + num_slices: u32, + slice_size: u32, + index_depth: u32, + ndim: u32, + input_shape: array, + input_strides: array, +} + +@group(0) @binding(0) var gather_nd_input: array; +@group(0) @binding(1) var gather_nd_indices: array; +@group(0) @binding(2) var gather_nd_output: array; +@group(0) @binding(3) var gather_nd_params: GatherNdParams; + +@compute @workgroup_size(256) +fn gather_nd_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = gather_nd_params.num_slices * gather_nd_params.slice_size; + if (idx >= total) { + return; + } + + let slice_idx = idx / gather_nd_params.slice_size; + let element_in_slice = idx % gather_nd_params.slice_size; + + // Compute input offset from indices + var input_offset: u32 = 0u; + let indices_offset = slice_idx * gather_nd_params.index_depth; + + for (var d: u32 = 0u; d < gather_nd_params.index_depth; d = d + 1u) { + let coord = gather_nd_indices[indices_offset + d]; + if (coord < 0 || u32(coord) >= gather_nd_params.input_shape[d]) { + gather_nd_output[idx] = 0; + return; + } + input_offset = input_offset + u32(coord) * gather_nd_params.input_strides[d]; + } + + // Add offset for element within slice + if (gather_nd_params.slice_size > 1u) { + var remaining = element_in_slice; + for (var d: u32 = gather_nd_params.index_depth; d < gather_nd_params.ndim; d = d + 1u) { + let dim_size = gather_nd_params.input_shape[d]; + let coord = remaining / gather_nd_params.input_strides[d]; + remaining = remaining % gather_nd_params.input_strides[d]; + input_offset = input_offset + coord * gather_nd_params.input_strides[d]; + } + } + + gather_nd_output[idx] = gather_nd_input[input_offset + element_in_slice]; +} diff --git a/src/runtime/wgpu/shaders/gather_nd_u32.wgsl b/src/runtime/wgpu/shaders/gather_nd_u32.wgsl new file mode 100644 index 00000000..d3405a69 --- /dev/null +++ b/src/runtime/wgpu/shaders/gather_nd_u32.wgsl @@ -0,0 +1,56 @@ +// Auto-generated gather_nd operations for u32 + +const WORKGROUP_SIZE: u32 = 256u; +const MAX_DIMS: u32 = 8u; + +struct GatherNdParams { + num_slices: u32, + slice_size: u32, + index_depth: u32, + ndim: u32, + input_shape: array, + input_strides: array, +} + +@group(0) @binding(0) var gather_nd_input: array; +@group(0) @binding(1) var gather_nd_indices: array; +@group(0) @binding(2) var gather_nd_output: array; +@group(0) @binding(3) var gather_nd_params: GatherNdParams; + +@compute @workgroup_size(256) +fn gather_nd_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = gather_nd_params.num_slices * gather_nd_params.slice_size; + if (idx >= total) { + return; + } + + let slice_idx = idx / gather_nd_params.slice_size; + let element_in_slice = idx % gather_nd_params.slice_size; + + // Compute input offset from indices + var input_offset: u32 = 0u; + let indices_offset = slice_idx * gather_nd_params.index_depth; + + for (var d: u32 = 0u; d < gather_nd_params.index_depth; d = d + 1u) { + let coord = gather_nd_indices[indices_offset + d]; + if (coord < 0 || u32(coord) >= gather_nd_params.input_shape[d]) { + gather_nd_output[idx] = 0u; + return; + } + input_offset = input_offset + u32(coord) * gather_nd_params.input_strides[d]; + } + + // Add offset for element within slice + if (gather_nd_params.slice_size > 1u) { + var remaining = element_in_slice; + for (var d: u32 = gather_nd_params.index_depth; d < gather_nd_params.ndim; d = d + 1u) { + let dim_size = gather_nd_params.input_shape[d]; + let coord = remaining / gather_nd_params.input_strides[d]; + remaining = remaining % gather_nd_params.input_strides[d]; + input_offset = input_offset + coord * gather_nd_params.input_strides[d]; + } + } + + gather_nd_output[idx] = gather_nd_input[input_offset + element_in_slice]; +} diff --git a/src/runtime/wgpu/shaders/gather_nonzero_f32.wgsl b/src/runtime/wgpu/shaders/gather_nonzero_f32.wgsl new file mode 100644 index 00000000..a07fc222 --- /dev/null +++ b/src/runtime/wgpu/shaders/gather_nonzero_f32.wgsl @@ -0,0 +1,26 @@ +// Auto-generated gather_nonzero operations for f32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct CountParams { + numel: u32, +} + +@group(0) @binding(0) var input: array; +@group(0) @binding(1) var indices_output: array; +@group(0) @binding(2) var counter: array>; +@group(0) @binding(3) var count_params: CountParams; + +@compute @workgroup_size(256) +fn gather_nonzero_f32(@builtin(global_invocation_id) global_id: vec3) { + let numel = count_params.numel; + var idx = global_id.x; + + while (idx < numel) { + if (input[idx] != 0.0) { + let out_idx = atomicAdd(&counter[0], 1u); + indices_output[out_idx] = i32(idx); + } + idx = idx + WORKGROUP_SIZE * 256u; + } +} diff --git a/src/runtime/wgpu/shaders/gather_nonzero_i32.wgsl b/src/runtime/wgpu/shaders/gather_nonzero_i32.wgsl new file mode 100644 index 00000000..d28dbaca --- /dev/null +++ b/src/runtime/wgpu/shaders/gather_nonzero_i32.wgsl @@ -0,0 +1,26 @@ +// Auto-generated gather_nonzero operations for i32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct CountParams { + numel: u32, +} + +@group(0) @binding(0) var input: array; +@group(0) @binding(1) var indices_output: array; +@group(0) @binding(2) var counter: array>; +@group(0) @binding(3) var count_params: CountParams; + +@compute @workgroup_size(256) +fn gather_nonzero_i32(@builtin(global_invocation_id) global_id: vec3) { + let numel = count_params.numel; + var idx = global_id.x; + + while (idx < numel) { + if (input[idx] != 0) { + let out_idx = atomicAdd(&counter[0], 1u); + indices_output[out_idx] = i32(idx); + } + idx = idx + WORKGROUP_SIZE * 256u; + } +} diff --git a/src/runtime/wgpu/shaders/gather_nonzero_u32.wgsl b/src/runtime/wgpu/shaders/gather_nonzero_u32.wgsl new file mode 100644 index 00000000..890cee20 --- /dev/null +++ b/src/runtime/wgpu/shaders/gather_nonzero_u32.wgsl @@ -0,0 +1,26 @@ +// Auto-generated gather_nonzero operations for u32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct CountParams { + numel: u32, +} + +@group(0) @binding(0) var input: array; +@group(0) @binding(1) var indices_output: array; +@group(0) @binding(2) var counter: array>; +@group(0) @binding(3) var count_params: CountParams; + +@compute @workgroup_size(256) +fn gather_nonzero_u32(@builtin(global_invocation_id) global_id: vec3) { + let numel = count_params.numel; + var idx = global_id.x; + + while (idx < numel) { + if (input[idx] != 0u) { + let out_idx = atomicAdd(&counter[0], 1u); + indices_output[out_idx] = i32(idx); + } + idx = idx + WORKGROUP_SIZE * 256u; + } +} diff --git a/src/runtime/wgpu/shaders/gather_u32.wgsl b/src/runtime/wgpu/shaders/gather_u32.wgsl new file mode 100644 index 00000000..ce65415f --- /dev/null +++ b/src/runtime/wgpu/shaders/gather_u32.wgsl @@ -0,0 +1,59 @@ +// Auto-generated gather operations for u32 + +const WORKGROUP_SIZE: u32 = 256u; +const MAX_DIMS: u32 = 4u; + +struct GatherParams { + ndim: u32, + dim: u32, + total_elements: u32, + _padding: u32, + // Shape and strides packed: [input_shape[0..4], input_strides[0..4], output_shape[0..4], output_strides[0..4]] + input_shape: vec4, + input_strides: vec4, + output_shape: vec4, + output_strides: vec4, +} + +@group(0) @binding(0) var input: array; +@group(0) @binding(1) var indices: array; +@group(0) @binding(2) var output: array; +@group(0) @binding(3) var params: GatherParams; + +fn get_shape(arr: vec4, d: u32) -> u32 { + if (d == 0u) { return arr.x; } + else if (d == 1u) { return arr.y; } + else if (d == 2u) { return arr.z; } + else { return arr.w; } +} + +@compute @workgroup_size(256) +fn gather_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= params.total_elements) { + return; + } + + var remaining = idx; + var src_offset: u32 = 0u; + + for (var d: u32 = 0u; d < params.ndim; d = d + 1u) { + let out_stride = get_shape(params.output_strides, d); + let coord = remaining / out_stride; + remaining = remaining % out_stride; + + if (d == params.dim) { + let index_val = indices[idx]; + let dim_size = get_shape(params.input_shape, d); + if (index_val < 0 || u32(index_val) >= dim_size) { + output[idx] = 0u; + return; + } + src_offset = src_offset + u32(index_val) * get_shape(params.input_strides, d); + } else { + src_offset = src_offset + coord * get_shape(params.input_strides, d); + } + } + + output[idx] = input[src_offset]; +} diff --git a/src/runtime/wgpu/shaders/gemm_epilogue.rs b/src/runtime/wgpu/shaders/gemm_epilogue.rs new file mode 100644 index 00000000..7d36f5a1 --- /dev/null +++ b/src/runtime/wgpu/shaders/gemm_epilogue.rs @@ -0,0 +1,334 @@ +//! WGSL kernel launchers for GEMM epilogue operations. F32 only. + +use wgpu::{Buffer, Queue}; + +use super::pipeline::{LayoutKey, PipelineCache}; +use crate::dtype::DType; +use crate::error::{Error, Result}; +use crate::ops::GemmActivation; + +const GEMM_EPILOGUE_SHADER: &str = include_str!("gemm_epilogue_f32.wgsl"); +const GEMM_EPILOGUE_RESIDUAL_SHADER: &str = include_str!("gemm_epilogue_residual_f32.wgsl"); + +const TILE_SIZE: u32 = 16; + +fn activation_to_u32(act: GemmActivation) -> u32 { + match act { + GemmActivation::None => 0, + GemmActivation::ReLU => 1, + GemmActivation::GELU => 2, + GemmActivation::SiLU => 3, + GemmActivation::Sigmoid => 4, + GemmActivation::Tanh => 5, + } +} + +/// Params struct for the activation shader (8 u32s for alignment). +#[repr(C)] +#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)] +pub struct GemmEpilogueParams { + /// Number of rows of A / output. + pub m: u32, + /// Inner dimension (columns of A, rows of B). + pub k: u32, + /// Number of columns of B / output. + pub n: u32, + /// Number of batches (1 for non-batched). + pub batch_size: u32, + /// Activation function index (0=None, 1=ReLU, 2=GELU, 3=SiLU, 4=Sigmoid, 5=Tanh). + pub activation_type: u32, + /// Padding for 32-byte alignment. + pub _pad0: u32, + /// Padding for 32-byte alignment. + pub _pad1: u32, + /// Padding for 32-byte alignment. + pub _pad2: u32, +} + +/// Params struct for the residual shader. +#[repr(C)] +#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)] +pub struct GemmResidualParams { + /// Number of rows of A / output. + pub m: u32, + /// Inner dimension (columns of A, rows of B). + pub k: u32, + /// Number of columns of B / output. + pub n: u32, + /// Number of batches (1 for non-batched). + pub batch_size: u32, +} + +fn check_f32(dtype: DType, op: &'static str) -> Result<()> { + if dtype != DType::F32 { + return Err(Error::UnsupportedDType { dtype, op }); + } + Ok(()) +} + +/// Launch fused GEMM + bias + activation (2D). +#[allow(clippy::too_many_arguments)] +pub fn launch_gemm_bias_act( + cache: &PipelineCache, + queue: &Queue, + a: &Buffer, + b: &Buffer, + bias: &Buffer, + c: &Buffer, + params_buffer: &Buffer, + m: usize, + n: usize, + dtype: DType, +) -> Result<()> { + check_f32(dtype, "gemm_bias_act")?; + + let module = cache.get_or_create_module("gemm_epilogue_f32", GEMM_EPILOGUE_SHADER); + let layout = cache.get_or_create_layout(LayoutKey { + num_storage_buffers: 4, + num_uniform_buffers: 1, + num_readonly_storage: 0, + }); + let pipeline = + cache.get_or_create_pipeline("gemm_bias_act_f32", "gemm_bias_act_f32", &module, &layout); + + let bind_group = cache.create_bind_group(&layout, &[a, b, bias, c, params_buffer]); + + let mut encoder = cache + .device() + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("gemm_bias_act"), + }); + + { + let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some("gemm_bias_act"), + timestamp_writes: None, + }); + pass.set_pipeline(&pipeline); + pass.set_bind_group(0, Some(&bind_group), &[]); + let gx = (n as u32 + TILE_SIZE - 1) / TILE_SIZE; + let gy = (m as u32 + TILE_SIZE - 1) / TILE_SIZE; + pass.dispatch_workgroups(gx, gy, 1); + } + + queue.submit(std::iter::once(encoder.finish())); + Ok(()) +} + +/// Launch batched fused GEMM + bias + activation (3D). +#[allow(clippy::too_many_arguments)] +pub fn launch_gemm_bias_act_batched( + cache: &PipelineCache, + queue: &Queue, + a: &Buffer, + b: &Buffer, + bias: &Buffer, + c: &Buffer, + params_buffer: &Buffer, + m: usize, + n: usize, + batch_size: usize, + dtype: DType, +) -> Result<()> { + check_f32(dtype, "gemm_bias_act_batched")?; + + let module = cache.get_or_create_module("gemm_epilogue_f32", GEMM_EPILOGUE_SHADER); + let layout = cache.get_or_create_layout(LayoutKey { + num_storage_buffers: 4, + num_uniform_buffers: 1, + num_readonly_storage: 0, + }); + let pipeline = cache.get_or_create_pipeline( + "gemm_bias_act_batched_f32", + "gemm_bias_act_batched_f32", + &module, + &layout, + ); + + let bind_group = cache.create_bind_group(&layout, &[a, b, bias, c, params_buffer]); + + let mut encoder = cache + .device() + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("gemm_bias_act_batched"), + }); + + { + let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some("gemm_bias_act_batched"), + timestamp_writes: None, + }); + pass.set_pipeline(&pipeline); + pass.set_bind_group(0, Some(&bind_group), &[]); + let gx = (n as u32 + TILE_SIZE - 1) / TILE_SIZE; + let gy = (m as u32 + TILE_SIZE - 1) / TILE_SIZE; + pass.dispatch_workgroups(gx, gy, batch_size as u32); + } + + queue.submit(std::iter::once(encoder.finish())); + Ok(()) +} + +/// Launch fused GEMM + bias + residual (2D). +#[allow(clippy::too_many_arguments)] +pub fn launch_gemm_bias_residual( + cache: &PipelineCache, + queue: &Queue, + a: &Buffer, + b: &Buffer, + bias: &Buffer, + residual: &Buffer, + c: &Buffer, + params_buffer: &Buffer, + m: usize, + n: usize, + dtype: DType, +) -> Result<()> { + check_f32(dtype, "gemm_bias_residual")?; + + let module = + cache.get_or_create_module("gemm_epilogue_residual_f32", GEMM_EPILOGUE_RESIDUAL_SHADER); + let layout = cache.get_or_create_layout(LayoutKey { + num_storage_buffers: 5, + num_uniform_buffers: 1, + num_readonly_storage: 0, + }); + let pipeline = cache.get_or_create_pipeline( + "gemm_bias_residual_f32", + "gemm_bias_residual_f32", + &module, + &layout, + ); + + let bind_group = cache.create_bind_group(&layout, &[a, b, bias, residual, c, params_buffer]); + + let mut encoder = cache + .device() + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("gemm_bias_residual"), + }); + + { + let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some("gemm_bias_residual"), + timestamp_writes: None, + }); + pass.set_pipeline(&pipeline); + pass.set_bind_group(0, Some(&bind_group), &[]); + let gx = (n as u32 + TILE_SIZE - 1) / TILE_SIZE; + let gy = (m as u32 + TILE_SIZE - 1) / TILE_SIZE; + pass.dispatch_workgroups(gx, gy, 1); + } + + queue.submit(std::iter::once(encoder.finish())); + Ok(()) +} + +/// Launch batched fused GEMM + bias + residual (3D). +#[allow(clippy::too_many_arguments)] +pub fn launch_gemm_bias_residual_batched( + cache: &PipelineCache, + queue: &Queue, + a: &Buffer, + b: &Buffer, + bias: &Buffer, + residual: &Buffer, + c: &Buffer, + params_buffer: &Buffer, + m: usize, + n: usize, + batch_size: usize, + dtype: DType, +) -> Result<()> { + check_f32(dtype, "gemm_bias_residual_batched")?; + + let module = + cache.get_or_create_module("gemm_epilogue_residual_f32", GEMM_EPILOGUE_RESIDUAL_SHADER); + let layout = cache.get_or_create_layout(LayoutKey { + num_storage_buffers: 5, + num_uniform_buffers: 1, + num_readonly_storage: 0, + }); + let pipeline = cache.get_or_create_pipeline( + "gemm_bias_residual_batched_f32", + "gemm_bias_residual_batched_f32", + &module, + &layout, + ); + + let bind_group = cache.create_bind_group(&layout, &[a, b, bias, residual, c, params_buffer]); + + let mut encoder = cache + .device() + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("gemm_bias_residual_batched"), + }); + + { + let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some("gemm_bias_residual_batched"), + timestamp_writes: None, + }); + pass.set_pipeline(&pipeline); + pass.set_bind_group(0, Some(&bind_group), &[]); + let gx = (n as u32 + TILE_SIZE - 1) / TILE_SIZE; + let gy = (m as u32 + TILE_SIZE - 1) / TILE_SIZE; + pass.dispatch_workgroups(gx, gy, batch_size as u32); + } + + queue.submit(std::iter::once(encoder.finish())); + Ok(()) +} + +/// Create a uniform buffer for the activation params. +pub fn create_epilogue_params_buffer( + cache: &PipelineCache, + m: u32, + k: u32, + n: u32, + batch_size: u32, + activation: GemmActivation, +) -> Buffer { + let params = GemmEpilogueParams { + m, + k, + n, + batch_size, + activation_type: activation_to_u32(activation), + _pad0: 0, + _pad1: 0, + _pad2: 0, + }; + use wgpu::util::DeviceExt; + cache + .device() + .create_buffer_init(&wgpu::util::BufferInitDescriptor { + label: Some("gemm_epilogue_params"), + contents: bytemuck::bytes_of(¶ms), + usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST, + }) +} + +/// Create a uniform buffer for the residual params. +pub fn create_residual_params_buffer( + cache: &PipelineCache, + m: u32, + k: u32, + n: u32, + batch_size: u32, +) -> Buffer { + let params = GemmResidualParams { + m, + k, + n, + batch_size, + }; + use wgpu::util::DeviceExt; + cache + .device() + .create_buffer_init(&wgpu::util::BufferInitDescriptor { + label: Some("gemm_residual_params"), + contents: bytemuck::bytes_of(¶ms), + usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST, + }) +} diff --git a/src/runtime/wgpu/shaders/gemm_epilogue_f32.wgsl b/src/runtime/wgpu/shaders/gemm_epilogue_f32.wgsl new file mode 100644 index 00000000..b2923ee5 --- /dev/null +++ b/src/runtime/wgpu/shaders/gemm_epilogue_f32.wgsl @@ -0,0 +1,131 @@ +// Fused GEMM + bias + activation. F32 only. +// C = activation(A @ B + bias) +// activation_type in params: 0=None, 1=ReLU, 2=GELU, 3=SiLU, 4=Sigmoid, 5=Tanh + +const TILE_SIZE: u32 = 16u; + +var tile_a: array, 16>; +var tile_b: array, 16>; + +struct GemmEpilogueParams { + M: u32, + K: u32, + N: u32, + batch_size: u32, + activation_type: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var a: array; +@group(0) @binding(1) var b: array; +@group(0) @binding(2) var bias: array; +@group(0) @binding(3) var c: array; +@group(0) @binding(4) var params: GemmEpilogueParams; + +fn apply_activation(x: f32, act_type: u32) -> f32 { + switch act_type { + case 1u: { + return max(x, 0.0); + } + case 2u: { + let s = 0.7978845608; + let co = 0.044715; + let inner = s * (x + co * x * x * x); + return 0.5 * x * (1.0 + tanh(inner)); + } + case 3u: { + return x / (1.0 + exp(-x)); + } + case 4u: { + return 1.0 / (1.0 + exp(-x)); + } + case 5u: { + return tanh(x); + } + default: { + return x; + } + } +} + +@compute @workgroup_size(16, 16, 1) +fn gemm_bias_act_f32(@builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let M = params.M; + let K = params.K; + let N = params.N; + let row = group_id.y * TILE_SIZE + local_id.y; + let col = group_id.x * TILE_SIZE + local_id.x; + + var sum: f32 = 0.0; + let num_tiles = (K + TILE_SIZE - 1u) / TILE_SIZE; + + for (var t: u32 = 0u; t < num_tiles; t = t + 1u) { + let a_col = t * TILE_SIZE + local_id.x; + if (row < M && a_col < K) { + tile_a[local_id.y][local_id.x] = a[row * K + a_col]; + } else { + tile_a[local_id.y][local_id.x] = 0.0; + } + let b_row = t * TILE_SIZE + local_id.y; + if (b_row < K && col < N) { + tile_b[local_id.y][local_id.x] = b[b_row * N + col]; + } else { + tile_b[local_id.y][local_id.x] = 0.0; + } + workgroupBarrier(); + for (var k: u32 = 0u; k < TILE_SIZE; k = k + 1u) { + sum = sum + tile_a[local_id.y][k] * tile_b[k][local_id.x]; + } + workgroupBarrier(); + } + + if (row < M && col < N) { + c[row * N + col] = apply_activation(sum + bias[col], params.activation_type); + } +} + +@compute @workgroup_size(16, 16, 1) +fn gemm_bias_act_batched_f32(@builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let M = params.M; + let K = params.K; + let N = params.N; + let batch = group_id.z; + if (batch >= params.batch_size) { return; } + + let row = group_id.y * TILE_SIZE + local_id.y; + let col = group_id.x * TILE_SIZE + local_id.x; + let a_off = batch * M * K; + let b_off = batch * K * N; + let c_off = batch * M * N; + + var sum: f32 = 0.0; + let num_tiles = (K + TILE_SIZE - 1u) / TILE_SIZE; + + for (var t: u32 = 0u; t < num_tiles; t = t + 1u) { + let a_col = t * TILE_SIZE + local_id.x; + if (row < M && a_col < K) { + tile_a[local_id.y][local_id.x] = a[a_off + row * K + a_col]; + } else { + tile_a[local_id.y][local_id.x] = 0.0; + } + let b_row = t * TILE_SIZE + local_id.y; + if (b_row < K && col < N) { + tile_b[local_id.y][local_id.x] = b[b_off + b_row * N + col]; + } else { + tile_b[local_id.y][local_id.x] = 0.0; + } + workgroupBarrier(); + for (var k: u32 = 0u; k < TILE_SIZE; k = k + 1u) { + sum = sum + tile_a[local_id.y][k] * tile_b[k][local_id.x]; + } + workgroupBarrier(); + } + + if (row < M && col < N) { + c[c_off + row * N + col] = apply_activation(sum + bias[col], params.activation_type); + } +} diff --git a/src/runtime/wgpu/shaders/gemm_epilogue_residual_f32.wgsl b/src/runtime/wgpu/shaders/gemm_epilogue_residual_f32.wgsl new file mode 100644 index 00000000..a39c4e34 --- /dev/null +++ b/src/runtime/wgpu/shaders/gemm_epilogue_residual_f32.wgsl @@ -0,0 +1,103 @@ +// Fused GEMM + bias + residual. F32 only. +// C = A @ B + bias + residual + +const TILE_SIZE: u32 = 16u; + +var tile_a: array, 16>; +var tile_b: array, 16>; + +struct GemmResidualParams { + M: u32, + K: u32, + N: u32, + batch_size: u32, +} + +@group(0) @binding(0) var a: array; +@group(0) @binding(1) var b: array; +@group(0) @binding(2) var bias: array; +@group(0) @binding(3) var residual: array; +@group(0) @binding(4) var c: array; +@group(0) @binding(5) var params: GemmResidualParams; + +@compute @workgroup_size(16, 16, 1) +fn gemm_bias_residual_f32(@builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let M = params.M; + let K = params.K; + let N = params.N; + let row = group_id.y * TILE_SIZE + local_id.y; + let col = group_id.x * TILE_SIZE + local_id.x; + + var sum: f32 = 0.0; + let num_tiles = (K + TILE_SIZE - 1u) / TILE_SIZE; + + for (var t: u32 = 0u; t < num_tiles; t = t + 1u) { + let a_col = t * TILE_SIZE + local_id.x; + if (row < M && a_col < K) { + tile_a[local_id.y][local_id.x] = a[row * K + a_col]; + } else { + tile_a[local_id.y][local_id.x] = 0.0; + } + let b_row = t * TILE_SIZE + local_id.y; + if (b_row < K && col < N) { + tile_b[local_id.y][local_id.x] = b[b_row * N + col]; + } else { + tile_b[local_id.y][local_id.x] = 0.0; + } + workgroupBarrier(); + for (var k: u32 = 0u; k < TILE_SIZE; k = k + 1u) { + sum = sum + tile_a[local_id.y][k] * tile_b[k][local_id.x]; + } + workgroupBarrier(); + } + + if (row < M && col < N) { + let idx = row * N + col; + c[idx] = sum + bias[col] + residual[idx]; + } +} + +@compute @workgroup_size(16, 16, 1) +fn gemm_bias_residual_batched_f32(@builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let M = params.M; + let K = params.K; + let N = params.N; + let batch = group_id.z; + if (batch >= params.batch_size) { return; } + + let row = group_id.y * TILE_SIZE + local_id.y; + let col = group_id.x * TILE_SIZE + local_id.x; + let a_off = batch * M * K; + let b_off = batch * K * N; + let c_off = batch * M * N; + + var sum: f32 = 0.0; + let num_tiles = (K + TILE_SIZE - 1u) / TILE_SIZE; + + for (var t: u32 = 0u; t < num_tiles; t = t + 1u) { + let a_col = t * TILE_SIZE + local_id.x; + if (row < M && a_col < K) { + tile_a[local_id.y][local_id.x] = a[a_off + row * K + a_col]; + } else { + tile_a[local_id.y][local_id.x] = 0.0; + } + let b_row = t * TILE_SIZE + local_id.y; + if (b_row < K && col < N) { + tile_b[local_id.y][local_id.x] = b[b_off + b_row * N + col]; + } else { + tile_b[local_id.y][local_id.x] = 0.0; + } + workgroupBarrier(); + for (var k: u32 = 0u; k < TILE_SIZE; k = k + 1u) { + sum = sum + tile_a[local_id.y][k] * tile_b[k][local_id.x]; + } + workgroupBarrier(); + } + + if (row < M && col < N) { + let idx = c_off + row * N + col; + c[idx] = sum + bias[col] + residual[idx]; + } +} diff --git a/src/runtime/wgpu/shaders/gemv_bt.rs b/src/runtime/wgpu/shaders/gemv_bt.rs new file mode 100644 index 00000000..6e630a44 --- /dev/null +++ b/src/runtime/wgpu/shaders/gemv_bt.rs @@ -0,0 +1,117 @@ +//! GEMV-BT WGSL kernel launchers: C[M,N] = A[M,K] @ B^T where B is [N,K]. +//! +//! Avoids the GPU-side contiguous copy of transposed weight matrices by +//! reading B in its native [N,K] layout. Each output element is a dot product +//! of contiguous A and B row vectors, computed via parallel reduction. + +use wgpu::{Buffer, Queue}; + +use super::pipeline::{LayoutKey, PipelineCache}; +use crate::dtype::DType; +use crate::error::{Error, Result}; + +const GEMV_BT_SHADER: &str = include_str!("gemv_bt.wgsl"); + +/// Launch 2D GEMV-BT kernel. +/// +/// Computes C[M,N] = A[M,K] @ B^T where B is stored as [N,K] row-major. +/// Dispatch: (N, M, 1) workgroups, each with 256 threads for K-reduction. +pub fn launch_gemv_bt( + cache: &PipelineCache, + queue: &Queue, + a: &Buffer, + b_nk: &Buffer, + c: &Buffer, + params_buffer: &Buffer, + m: usize, + n: usize, + dtype: DType, +) -> Result<()> { + if dtype != DType::F32 { + return Err(Error::UnsupportedDType { + dtype, + op: "gemv_bt", + }); + } + + let module = cache.get_or_create_module("gemv_bt", GEMV_BT_SHADER); + let layout = cache.get_or_create_layout(LayoutKey { + num_storage_buffers: 3, + num_uniform_buffers: 1, + num_readonly_storage: 0, + }); + let pipeline = cache.get_or_create_pipeline("gemv_bt", "gemv_bt_f32", &module, &layout); + + let bind_group = cache.create_bind_group(&layout, &[a, b_nk, c, params_buffer]); + + let mut encoder = cache + .device() + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("gemv_bt"), + }); + + { + let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some("gemv_bt"), + timestamp_writes: None, + }); + pass.set_pipeline(&pipeline); + pass.set_bind_group(0, Some(&bind_group), &[]); + pass.dispatch_workgroups(n as u32, m as u32, 1); + } + + queue.submit(std::iter::once(encoder.finish())); + Ok(()) +} + +/// Launch batched GEMV-BT kernel. +/// +/// Computes C[b,M,N] = A[b,M,K] @ B[b]^T where each B[b] is stored [N,K]. +pub fn launch_batched_gemv_bt( + cache: &PipelineCache, + queue: &Queue, + a: &Buffer, + b_nk: &Buffer, + c: &Buffer, + params_buffer: &Buffer, + m: usize, + n: usize, + batch_size: usize, + dtype: DType, +) -> Result<()> { + if dtype != DType::F32 { + return Err(Error::UnsupportedDType { + dtype, + op: "batched_gemv_bt", + }); + } + + let module = cache.get_or_create_module("gemv_bt", GEMV_BT_SHADER); + let layout = cache.get_or_create_layout(LayoutKey { + num_storage_buffers: 3, + num_uniform_buffers: 1, + num_readonly_storage: 0, + }); + let pipeline = cache.get_or_create_pipeline("gemv_bt", "batched_gemv_bt_f32", &module, &layout); + + let bind_group = cache.create_bind_group(&layout, &[a, b_nk, c, params_buffer]); + + let mut encoder = cache + .device() + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("batched_gemv_bt"), + }); + + { + let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some("batched_gemv_bt"), + timestamp_writes: None, + }); + pass.set_pipeline(&pipeline); + pass.set_bind_group(0, Some(&bind_group), &[]); + pass.dispatch_workgroups(n as u32, m as u32, batch_size as u32); + } + + queue.submit(std::iter::once(encoder.finish())); + Ok(()) +} diff --git a/src/runtime/wgpu/shaders/gemv_bt.wgsl b/src/runtime/wgpu/shaders/gemv_bt.wgsl new file mode 100644 index 00000000..97549d41 --- /dev/null +++ b/src/runtime/wgpu/shaders/gemv_bt.wgsl @@ -0,0 +1,107 @@ +// GEMV-BT: C[M,N] = A[M,K] @ B^T where B is stored as [N,K] row-major. +// +// Each output C[m,n] = dot(A[m,:], B[n,:]) where both vectors are contiguous. +// This avoids copying transposed weight matrices to make them contiguous. +// +// Dispatch: workgroups(N, M, batch_size) with workgroup_size(256, 1, 1) +// Each workgroup computes one output element using parallel reduction. + +struct GemvBtParams { + M: u32, + K: u32, + N: u32, + batch_size: u32, +} + +@group(0) @binding(0) var gemv_a: array; +@group(0) @binding(1) var gemv_b: array; +@group(0) @binding(2) var gemv_c: array; +@group(0) @binding(3) var gemv_params: GemvBtParams; + +var gemv_shared: array; + +// 2D GEMV-BT: one workgroup per output element +// workgroup_id.x = output column (n), workgroup_id.y = output row (m) +@compute @workgroup_size(256, 1, 1) +fn gemv_bt_f32(@builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let M = gemv_params.M; + let K = gemv_params.K; + let N = gemv_params.N; + let tid = local_id.x; + let m = group_id.y; + let n = group_id.x; + + if (m >= M || n >= N) { + return; + } + + // A is [M, K] row-major, B is [N, K] row-major + let a_offset = m * K; + let b_offset = n * K; + + // Each thread computes partial dot product + var sum: f32 = 0.0; + var i: u32 = tid; + while (i < K) { + sum = sum + gemv_a[a_offset + i] * gemv_b[b_offset + i]; + i = i + 256u; + } + + gemv_shared[tid] = sum; + workgroupBarrier(); + + // Parallel reduction + for (var s: u32 = 128u; s > 0u; s = s >> 1u) { + if (tid < s) { + gemv_shared[tid] = gemv_shared[tid] + gemv_shared[tid + s]; + } + workgroupBarrier(); + } + + if (tid == 0u) { + gemv_c[m * N + n] = gemv_shared[0]; + } +} + +// Batched GEMV-BT: workgroup_id.z = batch index +@compute @workgroup_size(256, 1, 1) +fn batched_gemv_bt_f32(@builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let M = gemv_params.M; + let K = gemv_params.K; + let N = gemv_params.N; + let batch_size = gemv_params.batch_size; + let tid = local_id.x; + let m = group_id.y; + let n = group_id.x; + let batch = group_id.z; + + if (m >= M || n >= N || batch >= batch_size) { + return; + } + + let a_offset = batch * M * K + m * K; + let b_offset = batch * N * K + n * K; + + var sum: f32 = 0.0; + var i: u32 = tid; + while (i < K) { + sum = sum + gemv_a[a_offset + i] * gemv_b[b_offset + i]; + i = i + 256u; + } + + gemv_shared[tid] = sum; + workgroupBarrier(); + + for (var s: u32 = 128u; s > 0u; s = s >> 1u) { + if (tid < s) { + gemv_shared[tid] = gemv_shared[tid] + gemv_shared[tid + s]; + } + workgroupBarrier(); + } + + if (tid == 0u) { + gemv_c[batch * M * N + m * N + n] = gemv_shared[0]; + } +} diff --git a/src/runtime/wgpu/shaders/generator/activation.rs b/src/runtime/wgpu/shaders/generator/activation.rs deleted file mode 100644 index c856842e..00000000 --- a/src/runtime/wgpu/shaders/generator/activation.rs +++ /dev/null @@ -1,49 +0,0 @@ -//! WGSL shader generation for parameterized activation operations -//! -//! Handles activation functions that require more than one parameter, -//! like clamp (min, max). - -use super::common::{dtype_suffix, is_wgsl_float, wgsl_type}; -use crate::dtype::DType; -use crate::error::Result; - -/// Generate WGSL shader for clamp operation -/// -/// Clamp requires two parameters (min, max) so uses a dedicated params struct. -pub fn generate_clamp_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - // Only float types support clamp with float bounds - if !is_wgsl_float(dtype) { - return Ok(String::new()); - } - - Ok(format!( - r#"// Auto-generated clamp operation for {t} - -const WORKGROUP_SIZE: u32 = 256u; - -struct ClampParams {{ - numel: u32, - min_val: f32, - max_val: f32, - _pad0: u32, -}} - -@group(0) @binding(0) var clamp_a: array<{t}>; -@group(0) @binding(1) var clamp_out: array<{t}>; -@group(0) @binding(2) var clamp_params: ClampParams; - -@compute @workgroup_size(256) -fn clamp_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < clamp_params.numel) {{ - clamp_out[idx] = clamp(clamp_a[idx], {t}(clamp_params.min_val), {t}(clamp_params.max_val)); - }} -}} -"#, - t = t, - suffix = suffix - )) -} diff --git a/src/runtime/wgpu/shaders/generator/binary.rs b/src/runtime/wgpu/shaders/generator/binary.rs deleted file mode 100644 index cf41e5b6..00000000 --- a/src/runtime/wgpu/shaders/generator/binary.rs +++ /dev/null @@ -1,280 +0,0 @@ -//! WGSL shader generation for binary element-wise operations - -use super::common::{dtype_suffix, is_wgsl_float, wgsl_type}; -use crate::dtype::DType; -use crate::error::Result; - -/// Generate WGSL shader for binary element-wise operations -pub fn generate_binary_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - let float_ops = if is_wgsl_float(dtype) { - format!( - r#" -@compute @workgroup_size(256) -fn pow_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < binary_params.numel) {{ - binary_out[idx] = pow(binary_a[idx], binary_b[idx]); - }} -}} - -@compute @workgroup_size(256) -fn atan2_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < binary_params.numel) {{ - binary_out[idx] = atan2(binary_a[idx], binary_b[idx]); - }} -}} -"#, - suffix = suffix - ) - } else { - // Integer pow requires loop implementation - format!( - r#" -@compute @workgroup_size(256) -fn pow_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < binary_params.numel) {{ - var base = binary_a[idx]; - var exp = binary_b[idx]; - var result: {t} = 1; - // Simple integer power loop - for (var i: {t} = 0; i < exp; i = i + 1) {{ - result = result * base; - }} - binary_out[idx] = result; - }} -}} -"#, - suffix = suffix, - t = t - ) - }; - - Ok(format!( - r#"// Auto-generated binary operations for {t} - -const WORKGROUP_SIZE: u32 = 256u; - -struct BinaryParams {{ - numel: u32, -}} - -@group(0) @binding(0) var binary_a: array<{t}>; -@group(0) @binding(1) var binary_b: array<{t}>; -@group(0) @binding(2) var binary_out: array<{t}>; -@group(0) @binding(3) var binary_params: BinaryParams; - -@compute @workgroup_size(256) -fn add_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < binary_params.numel) {{ - binary_out[idx] = binary_a[idx] + binary_b[idx]; - }} -}} - -@compute @workgroup_size(256) -fn sub_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < binary_params.numel) {{ - binary_out[idx] = binary_a[idx] - binary_b[idx]; - }} -}} - -@compute @workgroup_size(256) -fn mul_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < binary_params.numel) {{ - binary_out[idx] = binary_a[idx] * binary_b[idx]; - }} -}} - -@compute @workgroup_size(256) -fn div_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < binary_params.numel) {{ - binary_out[idx] = binary_a[idx] / binary_b[idx]; - }} -}} - -@compute @workgroup_size(256) -fn max_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < binary_params.numel) {{ - binary_out[idx] = max(binary_a[idx], binary_b[idx]); - }} -}} - -@compute @workgroup_size(256) -fn min_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < binary_params.numel) {{ - binary_out[idx] = min(binary_a[idx], binary_b[idx]); - }} -}} - -{float_ops} -"#, - t = t, - suffix = suffix, - float_ops = float_ops - )) -} - -/// Generate WGSL shader for broadcast binary element-wise operations. -/// -/// This shader handles tensors with different shapes that need broadcasting. -/// Strides are passed as storage buffers with 0 for broadcast dimensions. -pub fn generate_broadcast_binary_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - let float_ops = if is_wgsl_float(dtype) { - format!( - r#" -@compute @workgroup_size(256) -fn broadcast_pow_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx >= broadcast_params.numel) {{ - return; - }} - - var remaining = idx; - var a_offset: u32 = 0u; - var b_offset: u32 = 0u; - - for (var d: u32 = 0u; d < broadcast_params.ndim; d = d + 1u) {{ - let stride = broadcast_out_strides[d]; - let coord = remaining / stride; - remaining = remaining % stride; - - a_offset = a_offset + coord * broadcast_a_strides[d]; - b_offset = b_offset + coord * broadcast_b_strides[d]; - }} - - broadcast_out[idx] = pow(broadcast_a[a_offset], broadcast_b[b_offset]); -}} -"#, - suffix = suffix - ) - } else { - String::new() // Integer pow not commonly needed for broadcast - }; - - // Define all broadcast binary operations - let ops = [("add", "+"), ("sub", "-"), ("mul", "*"), ("div", "/")]; - - let mut op_shaders = String::new(); - for (op_name, op_sym) in ops.iter() { - op_shaders.push_str(&format!( - r#" -@compute @workgroup_size(256) -fn broadcast_{op_name}_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx >= broadcast_params.numel) {{ - return; - }} - - var remaining = idx; - var a_offset: u32 = 0u; - var b_offset: u32 = 0u; - - for (var d: u32 = 0u; d < broadcast_params.ndim; d = d + 1u) {{ - let stride = broadcast_out_strides[d]; - let coord = remaining / stride; - remaining = remaining % stride; - - a_offset = a_offset + coord * broadcast_a_strides[d]; - b_offset = b_offset + coord * broadcast_b_strides[d]; - }} - - broadcast_out[idx] = broadcast_a[a_offset] {op_sym} broadcast_b[b_offset]; -}} -"#, - op_name = op_name, - suffix = suffix, - op_sym = op_sym, - )); - } - - // max/min use built-in functions - op_shaders.push_str(&format!( - r#" -@compute @workgroup_size(256) -fn broadcast_max_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx >= broadcast_params.numel) {{ - return; - }} - - var remaining = idx; - var a_offset: u32 = 0u; - var b_offset: u32 = 0u; - - for (var d: u32 = 0u; d < broadcast_params.ndim; d = d + 1u) {{ - let stride = broadcast_out_strides[d]; - let coord = remaining / stride; - remaining = remaining % stride; - - a_offset = a_offset + coord * broadcast_a_strides[d]; - b_offset = b_offset + coord * broadcast_b_strides[d]; - }} - - broadcast_out[idx] = max(broadcast_a[a_offset], broadcast_b[b_offset]); -}} - -@compute @workgroup_size(256) -fn broadcast_min_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx >= broadcast_params.numel) {{ - return; - }} - - var remaining = idx; - var a_offset: u32 = 0u; - var b_offset: u32 = 0u; - - for (var d: u32 = 0u; d < broadcast_params.ndim; d = d + 1u) {{ - let stride = broadcast_out_strides[d]; - let coord = remaining / stride; - remaining = remaining % stride; - - a_offset = a_offset + coord * broadcast_a_strides[d]; - b_offset = b_offset + coord * broadcast_b_strides[d]; - }} - - broadcast_out[idx] = min(broadcast_a[a_offset], broadcast_b[b_offset]); -}} -"#, - suffix = suffix - )); - - Ok(format!( - r#"// Auto-generated broadcast binary operations for {t} - -const WORKGROUP_SIZE: u32 = 256u; - -struct BroadcastBinaryParams {{ - numel: u32, - ndim: u32, -}} - -@group(0) @binding(0) var broadcast_a: array<{t}>; -@group(0) @binding(1) var broadcast_b: array<{t}>; -@group(0) @binding(2) var broadcast_out: array<{t}>; -@group(0) @binding(3) var broadcast_a_strides: array; -@group(0) @binding(4) var broadcast_b_strides: array; -@group(0) @binding(5) var broadcast_out_strides: array; -@group(0) @binding(6) var broadcast_params: BroadcastBinaryParams; - -{op_shaders} -{float_ops} -"#, - t = t, - op_shaders = op_shaders, - float_ops = float_ops - )) -} diff --git a/src/runtime/wgpu/shaders/generator/cast.rs b/src/runtime/wgpu/shaders/generator/cast.rs deleted file mode 100644 index 0b759d07..00000000 --- a/src/runtime/wgpu/shaders/generator/cast.rs +++ /dev/null @@ -1,111 +0,0 @@ -//! WGSL shader generation for dtype cast operations - -use super::common::{dtype_suffix, wgsl_type}; -use crate::dtype::DType; -use crate::error::{Error, Result}; - -/// Generate WGSL shader for dtype cast operations -/// -/// WebGPU-supported casts: -/// - F32 ↔ I32 ↔ U32 -/// -/// Each cast direction requires a separate entry point since WGSL -/// doesn't support templates. -pub fn generate_cast_shader(src_dtype: DType, dst_dtype: DType) -> Result { - let src_t = wgsl_type(src_dtype)?; - let dst_t = wgsl_type(dst_dtype)?; - let src_suffix = dtype_suffix(src_dtype)?; - let dst_suffix = dtype_suffix(dst_dtype)?; - - // For same-type cast, just return a no-op shader (shouldn't be called) - if src_dtype == dst_dtype { - return Ok(format!( - r#"// No-op cast shader for {src_t} -> {dst_t} -// This should be optimized away at dispatch time -"# - )); - } - - Ok(format!( - r#"// Auto-generated cast operation: {src_t} -> {dst_t} - -const WORKGROUP_SIZE: u32 = 256u; - -struct CastParams {{ - numel: u32, -}} - -@group(0) @binding(0) var cast_input: array<{src_t}>; -@group(0) @binding(1) var cast_output: array<{dst_t}>; -@group(0) @binding(2) var cast_params: CastParams; - -@compute @workgroup_size(256) -fn cast_{src_suffix}_to_{dst_suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < cast_params.numel) {{ - cast_output[idx] = {dst_t}(cast_input[idx]); - }} -}} -"#, - src_t = src_t, - dst_t = dst_t, - src_suffix = src_suffix, - dst_suffix = dst_suffix - )) -} - -/// Generate all cast shaders for a given source dtype -/// -/// Returns a combined shader with all casts from the source type. -pub fn generate_all_casts_from(src_dtype: DType) -> Result { - let src_t = wgsl_type(src_dtype)?; - let src_suffix = dtype_suffix(src_dtype)?; - - let targets: &[DType] = match src_dtype { - DType::F32 => &[DType::I32, DType::U32], - DType::I32 => &[DType::F32, DType::U32], - DType::U32 => &[DType::F32, DType::I32], - _ => { - return Err(Error::UnsupportedDType { - dtype: src_dtype, - op: "cast", - }); - } - }; - - let mut shader = format!( - r#"// Auto-generated cast operations from {src_t} - -const WORKGROUP_SIZE: u32 = 256u; - -struct CastParams {{ - numel: u32, -}} - -@group(0) @binding(0) var cast_input: array<{src_t}>; -"# - ); - - for &dst_dtype in targets { - let dst_t = wgsl_type(dst_dtype)?; - let dst_suffix = dtype_suffix(dst_dtype)?; - - shader.push_str(&format!( - r#" -// Cast {src_t} -> {dst_t} -@group(0) @binding(1) var cast_output_{dst_suffix}: array<{dst_t}>; -@group(0) @binding(2) var cast_params_{dst_suffix}: CastParams; - -@compute @workgroup_size(256) -fn cast_{src_suffix}_to_{dst_suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < cast_params_{dst_suffix}.numel) {{ - cast_output_{dst_suffix}[idx] = {dst_t}(cast_input[idx]); - }} -}} -"# - )); - } - - Ok(shader) -} diff --git a/src/runtime/wgpu/shaders/generator/cat.rs b/src/runtime/wgpu/shaders/generator/cat.rs deleted file mode 100644 index 5913f661..00000000 --- a/src/runtime/wgpu/shaders/generator/cat.rs +++ /dev/null @@ -1,281 +0,0 @@ -//! WGSL shader generation for shape operations (cat, repeat, pad, roll) - -use super::common::{dtype_suffix, wgsl_type}; -use crate::dtype::DType; -use crate::error::Result; - -/// WGSL helper function to access packed `` `array, 2>` `` by index. -/// -/// WGSL uniform buffers require 16-byte alignment for array elements. We pack 8 u32 values -/// into `` `2 vec4` `` to meet this requirement. This helper extracts individual values. -const WGSL_GET_PACKED_VALUE_HELPER: &str = r#"// Helper to access packed array, 2> by index -fn get_packed_value(arr: array, 2>, d: i32) -> u32 { - let vec_idx = u32(d) / 4u; - let comp_idx = u32(d) % 4u; - if (vec_idx == 0u) { - if (comp_idx == 0u) { return arr[0].x; } - else if (comp_idx == 1u) { return arr[0].y; } - else if (comp_idx == 2u) { return arr[0].z; } - else { return arr[0].w; } - } else { - if (comp_idx == 0u) { return arr[1].x; } - else if (comp_idx == 1u) { return arr[1].y; } - else if (comp_idx == 2u) { return arr[1].z; } - else { return arr[1].w; } - } -} -"#; - -/// Generate WGSL shader for cat_copy operation (one tensor at a time) -/// -/// This kernel copies data from a source tensor to the appropriate position -/// in the concatenated output tensor. It's called once per input tensor. -pub fn generate_cat_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Auto-generated cat operations for {t} - -const WORKGROUP_SIZE: u32 = 256u; - -struct CatParams {{ - outer_size: u32, - src_cat_size: u32, - dst_cat_size: u32, - cat_offset: u32, - inner_size: u32, - total_elements: u32, -}} - -@group(0) @binding(0) var cat_src: array<{t}>; -@group(0) @binding(1) var cat_dst: array<{t}>; -@group(0) @binding(2) var cat_params: CatParams; - -@compute @workgroup_size(256) -fn cat_copy_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx >= cat_params.total_elements) {{ - return; - }} - - // Decompose idx into (outer, cat_i, inner) for source tensor - let inner = idx % cat_params.inner_size; - let remaining = idx / cat_params.inner_size; - let cat_i = remaining % cat_params.src_cat_size; - let outer = remaining / cat_params.src_cat_size; - - // Compute destination index - let dst_idx = outer * cat_params.dst_cat_size * cat_params.inner_size - + (cat_params.cat_offset + cat_i) * cat_params.inner_size - + inner; - - cat_dst[dst_idx] = cat_src[idx]; -}} -"#, - t = t, - suffix = suffix - )) -} - -/// Generate WGSL shader for repeat operation (tile tensor along all dimensions) -/// -/// This kernel tiles the source tensor by the given repeat factors. -pub fn generate_repeat_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Auto-generated repeat operation for {t} - -const WORKGROUP_SIZE: u32 = 256u; -const MAX_DIMS: u32 = 8u; - -// Use vec4 for 16-byte alignment in uniform buffer -struct RepeatParams {{ - ndim: u32, - total_elements: u32, - _pad0: u32, - _pad1: u32, - src_shape: array, 2>, // 8 u32 values packed into 2 vec4 - out_shape: array, 2>, -}} - -{helper} - -@group(0) @binding(0) var repeat_src: array<{t}>; -@group(0) @binding(1) var repeat_dst: array<{t}>; -@group(0) @binding(2) var repeat_params: RepeatParams; - -@compute @workgroup_size(256) -fn repeat_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx >= repeat_params.total_elements) {{ - return; - }} - - // Decompose idx into multi-dimensional output coordinates - var remaining = idx; - var src_idx = 0u; - - // Compute source strides first (row-major) - var src_strides: array; - var stride = 1u; - for (var d = i32(repeat_params.ndim) - 1; d >= 0; d = d - 1) {{ - src_strides[d] = stride; - stride = stride * get_packed_value(repeat_params.src_shape, d); - }} - - // Process dimensions from last to first - for (var d = i32(repeat_params.ndim) - 1; d >= 0; d = d - 1) {{ - let out_dim = get_packed_value(repeat_params.out_shape, d); - let coord = remaining % out_dim; - remaining = remaining / out_dim; - - // Map to source coordinate using modulo - let src_shape_d = get_packed_value(repeat_params.src_shape, d); - let src_coord = coord % src_shape_d; - src_idx = src_idx + src_coord * src_strides[d]; - }} - - repeat_dst[idx] = repeat_src[src_idx]; -}} -"#, - t = t, - suffix = suffix, - helper = WGSL_GET_PACKED_VALUE_HELPER - )) -} - -/// Generate WGSL shader for pad operation (add padding around tensor) -/// -/// This kernel adds padding to a tensor with a fill value. -pub fn generate_pad_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Auto-generated pad operation for {t} - -const WORKGROUP_SIZE: u32 = 256u; -const MAX_DIMS: u32 = 8u; - -// Use vec4 for 16-byte alignment in uniform buffer -struct PadParams {{ - ndim: u32, - total_elements: u32, - fill_value: {t}, - _pad0: u32, - src_shape: array, 2>, // 8 u32 values packed into 2 vec4 - out_shape: array, 2>, - pad_before: array, 2>, -}} - -{helper} - -@group(0) @binding(0) var pad_src: array<{t}>; -@group(0) @binding(1) var pad_dst: array<{t}>; -@group(0) @binding(2) var pad_params: PadParams; - -@compute @workgroup_size(256) -fn pad_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx >= pad_params.total_elements) {{ - return; - }} - - // Decompose idx into multi-dimensional output coordinates - var remaining = idx; - var coords: array; - var in_bounds = true; - - // Process dimensions from last to first - for (var d = i32(pad_params.ndim) - 1; d >= 0; d = d - 1) {{ - let out_dim = get_packed_value(pad_params.out_shape, d); - coords[d] = remaining % out_dim; - remaining = remaining / out_dim; - - // Check if coordinate is in original tensor region - let pb = get_packed_value(pad_params.pad_before, d); - let ss = get_packed_value(pad_params.src_shape, d); - if (coords[d] < pb || coords[d] >= pb + ss) {{ - in_bounds = false; - }} - }} - - if (in_bounds) {{ - // Compute source index - var src_idx = 0u; - var src_stride = 1u; - for (var d = i32(pad_params.ndim) - 1; d >= 0; d = d - 1) {{ - let src_coord = coords[d] - get_packed_value(pad_params.pad_before, d); - src_idx = src_idx + src_coord * src_stride; - src_stride = src_stride * get_packed_value(pad_params.src_shape, d); - }} - pad_dst[idx] = pad_src[src_idx]; - }} else {{ - pad_dst[idx] = pad_params.fill_value; - }} -}} -"#, - t = t, - suffix = suffix, - helper = WGSL_GET_PACKED_VALUE_HELPER - )) -} - -/// Generate WGSL shader for roll operation (circular shift along dimension) -/// -/// This kernel shifts elements along a dimension with wrapping. -pub fn generate_roll_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Auto-generated roll operation for {t} - -const WORKGROUP_SIZE: u32 = 256u; - -struct RollParams {{ - outer_size: u32, - dim_size: u32, - inner_size: u32, - shift: u32, - total_elements: u32, - _pad0: u32, - _pad1: u32, - _pad2: u32, -}} - -@group(0) @binding(0) var roll_src: array<{t}>; -@group(0) @binding(1) var roll_dst: array<{t}>; -@group(0) @binding(2) var roll_params: RollParams; - -@compute @workgroup_size(256) -fn roll_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx >= roll_params.total_elements) {{ - return; - }} - - // Decompose idx into (outer, dim_coord, inner) - let inner = idx % roll_params.inner_size; - let remaining = idx / roll_params.inner_size; - let dim_coord = remaining % roll_params.dim_size; - let outer = remaining / roll_params.dim_size; - - // Compute source coordinate with roll (shift goes right, so source is shift positions left) - let src_dim_coord = (dim_coord + roll_params.dim_size - roll_params.shift) % roll_params.dim_size; - - // Compute source linear index - let src_idx = outer * roll_params.dim_size * roll_params.inner_size - + src_dim_coord * roll_params.inner_size - + inner; - - roll_dst[idx] = roll_src[src_idx]; -}} -"#, - t = t, - suffix = suffix - )) -} diff --git a/src/runtime/wgpu/shaders/generator/common.rs b/src/runtime/wgpu/shaders/generator/common.rs deleted file mode 100644 index 2cd89d44..00000000 --- a/src/runtime/wgpu/shaders/generator/common.rs +++ /dev/null @@ -1,47 +0,0 @@ -//! Common helper functions for WGSL shader generation - -use crate::dtype::DType; -use crate::error::{Error, Result}; - -/// WGSL type name for a given DType -pub fn wgsl_type(dtype: DType) -> Result<&'static str> { - match dtype { - DType::F32 => Ok("f32"), - DType::I32 => Ok("i32"), - DType::U32 => Ok("u32"), - DType::F16 => Ok("f16"), // Requires extension - _ => Err(Error::UnsupportedDType { - dtype, - op: "wgpu_shader", - }), - } -} - -/// Short suffix for entry point names (e.g., "add_f32", "add_i32") -pub fn dtype_suffix(dtype: DType) -> Result<&'static str> { - match dtype { - DType::F32 => Ok("f32"), - DType::I32 => Ok("i32"), - DType::U32 => Ok("u32"), - DType::F16 => Ok("f16"), - _ => Err(Error::UnsupportedDType { - dtype, - op: "wgpu_shader", - }), - } -} - -/// Check if dtype is supported by WebGPU -pub fn is_wgpu_supported(dtype: DType) -> bool { - matches!(dtype, DType::F32 | DType::I32 | DType::U32 | DType::F16) -} - -/// Check if dtype is a float type in WGSL -pub fn is_wgsl_float(dtype: DType) -> bool { - matches!(dtype, DType::F32 | DType::F16) -} - -/// Check if dtype is an integer type in WGSL -pub fn is_wgsl_int(dtype: DType) -> bool { - matches!(dtype, DType::I32 | DType::U32) -} diff --git a/src/runtime/wgpu/shaders/generator/compare.rs b/src/runtime/wgpu/shaders/generator/compare.rs deleted file mode 100644 index cd944daf..00000000 --- a/src/runtime/wgpu/shaders/generator/compare.rs +++ /dev/null @@ -1,78 +0,0 @@ -//! WGSL shader generation for comparison operations - -use super::common::{dtype_suffix, wgsl_type}; -use crate::dtype::DType; -use crate::error::Result; - -/// Generate WGSL shader for comparison operations -pub fn generate_compare_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - // Output is always f32 for consistency (1.0 = true, 0.0 = false) - Ok(format!( - r#"// Auto-generated compare operations for {t} - -const WORKGROUP_SIZE: u32 = 256u; - -struct CompareParams {{ - numel: u32, -}} - -@group(0) @binding(0) var compare_a: array<{t}>; -@group(0) @binding(1) var compare_b: array<{t}>; -@group(0) @binding(2) var compare_out: array; -@group(0) @binding(3) var compare_params: CompareParams; - -@compute @workgroup_size(256) -fn eq_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < compare_params.numel) {{ - compare_out[idx] = select(0.0, 1.0, compare_a[idx] == compare_b[idx]); - }} -}} - -@compute @workgroup_size(256) -fn ne_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < compare_params.numel) {{ - compare_out[idx] = select(0.0, 1.0, compare_a[idx] != compare_b[idx]); - }} -}} - -@compute @workgroup_size(256) -fn lt_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < compare_params.numel) {{ - compare_out[idx] = select(0.0, 1.0, compare_a[idx] < compare_b[idx]); - }} -}} - -@compute @workgroup_size(256) -fn le_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < compare_params.numel) {{ - compare_out[idx] = select(0.0, 1.0, compare_a[idx] <= compare_b[idx]); - }} -}} - -@compute @workgroup_size(256) -fn gt_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < compare_params.numel) {{ - compare_out[idx] = select(0.0, 1.0, compare_a[idx] > compare_b[idx]); - }} -}} - -@compute @workgroup_size(256) -fn ge_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < compare_params.numel) {{ - compare_out[idx] = select(0.0, 1.0, compare_a[idx] >= compare_b[idx]); - }} -}} -"#, - t = t, - suffix = suffix - )) -} diff --git a/src/runtime/wgpu/shaders/generator/complex.rs b/src/runtime/wgpu/shaders/generator/complex.rs deleted file mode 100644 index 6179cb5f..00000000 --- a/src/runtime/wgpu/shaders/generator/complex.rs +++ /dev/null @@ -1,285 +0,0 @@ -//! WGSL shader generation for complex number operations -//! -//! Complex64 is represented as `vec2` where: -//! - .x = real part -//! - .y = imaginary part - -use crate::dtype::DType; -use crate::error::{Error, Result}; - -/// Generate WGSL shader for complex conjugate operation. -/// -/// Input: Complex64 (`vec2`) -/// Output: Complex64 (`vec2`) -/// Operation: conj(a + bi) = a - bi -pub fn generate_conj_shader() -> Result { - Ok(r#" -struct Params { - numel: u32, -} - -@group(0) @binding(0) var input: array>; -@group(0) @binding(1) var output: array>; -@group(0) @binding(2) var params: Params; - -@compute @workgroup_size(256) -fn conj_complex64(@builtin(global_invocation_id) gid: vec3) { - let idx = gid.x; - if (idx < params.numel) { - let val = input[idx]; - output[idx] = vec2(val.x, -val.y); // Real stays same, imaginary flips sign - } -} -"# - .to_string()) -} - -/// Generate WGSL shader for extracting real part. -/// -/// Input: Complex64 (`vec2`) -/// Output: F32 (`f32`) -/// Operation: real(a + bi) = a -pub fn generate_real_shader() -> Result { - Ok(r#" -struct Params { - numel: u32, -} - -@group(0) @binding(0) var input: array>; -@group(0) @binding(1) var output: array; -@group(0) @binding(2) var params: Params; - -@compute @workgroup_size(256) -fn real_complex64(@builtin(global_invocation_id) gid: vec3) { - let idx = gid.x; - if (idx < params.numel) { - output[idx] = input[idx].x; // Extract real component - } -} -"# - .to_string()) -} - -/// Generate WGSL shader for extracting imaginary part. -/// -/// Input: Complex64 (`vec2`) -/// Output: F32 (`f32`) -/// Operation: imag(a + bi) = b -pub fn generate_imag_shader() -> Result { - Ok(r#" -struct Params { - numel: u32, -} - -@group(0) @binding(0) var input: array>; -@group(0) @binding(1) var output: array; -@group(0) @binding(2) var params: Params; - -@compute @workgroup_size(256) -fn imag_complex64(@builtin(global_invocation_id) gid: vec3) { - let idx = gid.x; - if (idx < params.numel) { - output[idx] = input[idx].y; // Extract imaginary component - } -} -"# - .to_string()) -} - -/// Generate WGSL shader for computing phase angle. -/// -/// Input: Complex64 (`vec2`) -/// Output: F32 (`f32`) -/// Operation: angle(a + bi) = atan2(b, a) -pub fn generate_angle_shader() -> Result { - Ok(r#" -struct Params { - numel: u32, -} - -@group(0) @binding(0) var input: array>; -@group(0) @binding(1) var output: array; -@group(0) @binding(2) var params: Params; - -@compute @workgroup_size(256) -fn angle_complex64(@builtin(global_invocation_id) gid: vec3) { - let idx = gid.x; - if (idx < params.numel) { - let val = input[idx]; - output[idx] = atan2(val.y, val.x); // Phase angle in radians [-π, π] - } -} -"# - .to_string()) -} - -/// Generate WGSL shader for computing phase angle of real numbers. -/// -/// Input: F32 (real numbers) -/// Output: F32 -/// Operation: angle(x) = 0 if x >= 0, π if x < 0 -/// -/// Note: WGSL does not have a standard library with mathematical constants, -/// so PI must be defined as a literal constant in the shader source. -/// This matches Rust's std::f32::consts::PI value. -pub fn generate_angle_real_shader() -> Result { - Ok(r#" -struct Params { - numel: u32, -} - -@group(0) @binding(0) var input: array; -@group(0) @binding(1) var output: array; -@group(0) @binding(2) var params: Params; - -// PI constant (WGSL has no standard math library, so this is defined literally) -// Value matches std::f32::consts::PI exactly (f32 precision: ~7 significant digits) -const PI: f32 = 3.14159265f; - -@compute @workgroup_size(256) -fn angle_real_f32(@builtin(global_invocation_id) gid: vec3) { - let idx = gid.x; - if (idx < params.numel) { - let val = input[idx]; - output[idx] = select(0.0, PI, val < 0.0); // 0 if x >= 0, π if x < 0 - } -} -"# - .to_string()) -} - -/// Get the shader generator for a complex operation. -pub fn get_complex_shader_generator(op: &str) -> Result Result> { - match op { - "conj" => Ok(generate_conj_shader), - "real" => Ok(generate_real_shader), - "imag" => Ok(generate_imag_shader), - "angle" => Ok(generate_angle_shader), - _ => Err(Error::Internal(format!( - "Unknown complex operation: {}", - op - ))), - } -} - -/// Validate dtype for complex operations. -pub fn validate_complex_dtype(dtype: DType, op: &str) -> Result<()> { - // WebGPU only supports Complex64 (no F64 support) - if dtype != DType::Complex64 { - let op_static: &'static str = match op { - "conj" => "conj", - "real" => "real", - "imag" => "imag", - "angle" => "angle", - _ => "complex_op", - }; - return Err(Error::UnsupportedDType { - dtype, - op: op_static, - }); - } - Ok(()) -} - -/// Get output dtype for complex operation. -pub fn complex_output_dtype(input_dtype: DType, op: &str) -> Result { - validate_complex_dtype(input_dtype, op)?; - - match op { - "conj" => Ok(DType::Complex64), // Same as input - "real" | "imag" | "angle" => Ok(DType::F32), // Extract float component - _ => Err(Error::Internal(format!( - "Unknown complex operation: {}", - op - ))), - } -} - -/// Generate WGSL shader for constructing complex from real and imaginary parts. -/// -/// Input: F32 arrays for real and imaginary parts -/// Output: Complex64 (`` `vec2` ``) -/// Operation: `from_real_imag(real, imag)[i] = vec2(real[i], imag[i])` -pub fn generate_from_real_imag_shader() -> Result { - // Note: All storage bindings use read_write to match the pipeline layout - // (PipelineCache creates all storage buffers as read_write) - Ok(r#" -struct Params { - numel: u32, -} - -@group(0) @binding(0) var real_input: array; -@group(0) @binding(1) var imag_input: array; -@group(0) @binding(2) var output: array>; -@group(0) @binding(3) var params: Params; - -@compute @workgroup_size(256) -fn from_real_imag_f32(@builtin(global_invocation_id) gid: vec3) { - let idx = gid.x; - if (idx < params.numel) { - output[idx] = vec2(real_input[idx], imag_input[idx]); - } -} -"# - .to_string()) -} - -/// Generate WGSL shader for complex × real multiplication. -/// -/// Input: Complex64 (`vec2`) and F32 (real coefficient) -/// Output: Complex64 (`vec2`) -/// Operation: (a+bi) * r = ar + br*i -pub fn generate_complex_mul_real_shader() -> Result { - // Note: All storage bindings use read_write to match the pipeline layout - Ok(r#" -struct Params { - numel: u32, -} - -@group(0) @binding(0) var complex_input: array>; -@group(0) @binding(1) var real_input: array; -@group(0) @binding(2) var output: array>; -@group(0) @binding(3) var params: Params; - -@compute @workgroup_size(256) -fn complex64_mul_real(@builtin(global_invocation_id) gid: vec3) { - let idx = gid.x; - if (idx < params.numel) { - let c = complex_input[idx]; - let r = real_input[idx]; - output[idx] = vec2(c.x * r, c.y * r); - } -} -"# - .to_string()) -} - -/// Generate WGSL shader for complex / real division. -/// -/// Input: Complex64 (`vec2`) and F32 (real divisor) -/// Output: Complex64 (`vec2`) -/// Operation: (a+bi) / r = (a/r) + (b/r)*i -pub fn generate_complex_div_real_shader() -> Result { - // Note: All storage bindings use read_write to match the pipeline layout - Ok(r#" -struct Params { - numel: u32, -} - -@group(0) @binding(0) var complex_input: array>; -@group(0) @binding(1) var real_input: array; -@group(0) @binding(2) var output: array>; -@group(0) @binding(3) var params: Params; - -@compute @workgroup_size(256) -fn complex64_div_real(@builtin(global_invocation_id) gid: vec3) { - let idx = gid.x; - if (idx < params.numel) { - let c = complex_input[idx]; - let r = real_input[idx]; - output[idx] = vec2(c.x / r, c.y / r); - } -} -"# - .to_string()) -} diff --git a/src/runtime/wgpu/shaders/generator/conv.rs b/src/runtime/wgpu/shaders/generator/conv.rs deleted file mode 100644 index 37df0be3..00000000 --- a/src/runtime/wgpu/shaders/generator/conv.rs +++ /dev/null @@ -1,343 +0,0 @@ -//! WGSL shader generation for convolution operations - -use super::common::{dtype_suffix, is_wgsl_float, wgsl_type}; -use crate::dtype::DType; -use crate::error::{Error, Result}; - -/// Generate WGSL shader for conv1d operation. -/// -/// Input layout: (N, C_in, L) -/// Weight layout: (C_out, C_in/groups, K) -/// Output layout: (N, C_out, L_out) -pub fn generate_conv1d_shader(dtype: DType) -> Result { - if !is_wgsl_float(dtype) { - return Err(Error::UnsupportedDType { - dtype, - op: "conv1d", - }); - } - - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - let zero = if dtype == DType::F16 { "0.0h" } else { "0.0" }; - - Ok(format!( - r#"// Auto-generated conv1d shader for {t} - -const WORKGROUP_SIZE: u32 = 256u; - -struct Conv1dParams {{ - batch: u32, - c_in: u32, - length: u32, - c_out: u32, - kernel_size: u32, - output_length: u32, - stride: u32, - padding: u32, - dilation: u32, - groups: u32, - has_bias: u32, - _pad: u32, -}} - -@group(0) @binding(0) var conv1d_input: array<{t}>; -@group(0) @binding(1) var conv1d_weight: array<{t}>; -@group(0) @binding(2) var conv1d_bias: array<{t}>; -@group(0) @binding(3) var conv1d_output: array<{t}>; -@group(0) @binding(4) var conv1d_params: Conv1dParams; - -@compute @workgroup_size(256) -fn conv1d_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - let total = conv1d_params.batch * conv1d_params.c_out * conv1d_params.output_length; - if (idx >= total) {{ return; }} - - let ox = idx % conv1d_params.output_length; - let oc = (idx / conv1d_params.output_length) % conv1d_params.c_out; - let b = idx / (conv1d_params.c_out * conv1d_params.output_length); - - let c_in_per_group = conv1d_params.c_in / conv1d_params.groups; - let c_out_per_group = conv1d_params.c_out / conv1d_params.groups; - let g = oc / c_out_per_group; - let c_in_start = g * c_in_per_group; - - var sum: {t} = {zero}; - - for (var ic: u32 = 0u; ic < c_in_per_group; ic = ic + 1u) {{ - let c_in_idx = c_in_start + ic; - - for (var kx: u32 = 0u; kx < conv1d_params.kernel_size; kx = kx + 1u) {{ - let ix_signed = i32(ox * conv1d_params.stride + kx * conv1d_params.dilation) - i32(conv1d_params.padding); - - if (ix_signed >= 0 && u32(ix_signed) < conv1d_params.length) {{ - let ix = u32(ix_signed); - let input_idx = b * conv1d_params.c_in * conv1d_params.length + c_in_idx * conv1d_params.length + ix; - let weight_idx = oc * c_in_per_group * conv1d_params.kernel_size + ic * conv1d_params.kernel_size + kx; - sum = sum + conv1d_input[input_idx] * conv1d_weight[weight_idx]; - }} - }} - }} - - if (conv1d_params.has_bias != 0u) {{ - sum = sum + conv1d_bias[oc]; - }} - - conv1d_output[idx] = sum; -}} -"#, - t = t, - suffix = suffix, - zero = zero, - )) -} - -/// Generate WGSL shader for conv2d operation. -/// -/// Input layout: (N, C_in, H, W) -/// Weight layout: (C_out, C_in/groups, K_h, K_w) -/// Output layout: (N, C_out, H_out, W_out) -pub fn generate_conv2d_shader(dtype: DType) -> Result { - if !is_wgsl_float(dtype) { - return Err(Error::UnsupportedDType { - dtype, - op: "conv2d", - }); - } - - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - let zero = if dtype == DType::F16 { "0.0h" } else { "0.0" }; - - Ok(format!( - r#"// Auto-generated conv2d shader for {t} - -const WORKGROUP_SIZE: u32 = 256u; - -struct Conv2dParams {{ - batch: u32, - c_in: u32, - height: u32, - width: u32, - c_out: u32, - kernel_h: u32, - kernel_w: u32, - output_h: u32, - output_w: u32, - stride_h: u32, - stride_w: u32, - pad_h: u32, - pad_w: u32, - dilation_h: u32, - dilation_w: u32, - groups: u32, - has_bias: u32, - _pad: u32, -}} - -@group(0) @binding(0) var conv2d_input: array<{t}>; -@group(0) @binding(1) var conv2d_weight: array<{t}>; -@group(0) @binding(2) var conv2d_bias: array<{t}>; -@group(0) @binding(3) var conv2d_output: array<{t}>; -@group(0) @binding(4) var conv2d_params: Conv2dParams; - -@compute @workgroup_size(256) -fn conv2d_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - let total = conv2d_params.batch * conv2d_params.c_out * conv2d_params.output_h * conv2d_params.output_w; - if (idx >= total) {{ return; }} - - let ox = idx % conv2d_params.output_w; - let oy = (idx / conv2d_params.output_w) % conv2d_params.output_h; - let oc = (idx / (conv2d_params.output_w * conv2d_params.output_h)) % conv2d_params.c_out; - let b = idx / (conv2d_params.c_out * conv2d_params.output_h * conv2d_params.output_w); - - let c_in_per_group = conv2d_params.c_in / conv2d_params.groups; - let c_out_per_group = conv2d_params.c_out / conv2d_params.groups; - let g = oc / c_out_per_group; - let c_in_start = g * c_in_per_group; - - var sum: {t} = {zero}; - - for (var ic: u32 = 0u; ic < c_in_per_group; ic = ic + 1u) {{ - let c_in_idx = c_in_start + ic; - - for (var ky: u32 = 0u; ky < conv2d_params.kernel_h; ky = ky + 1u) {{ - for (var kx: u32 = 0u; kx < conv2d_params.kernel_w; kx = kx + 1u) {{ - let iy_signed = i32(oy * conv2d_params.stride_h + ky * conv2d_params.dilation_h) - i32(conv2d_params.pad_h); - let ix_signed = i32(ox * conv2d_params.stride_w + kx * conv2d_params.dilation_w) - i32(conv2d_params.pad_w); - - if (iy_signed >= 0 && u32(iy_signed) < conv2d_params.height && ix_signed >= 0 && u32(ix_signed) < conv2d_params.width) {{ - let iy = u32(iy_signed); - let ix = u32(ix_signed); - let input_idx = b * conv2d_params.c_in * conv2d_params.height * conv2d_params.width - + c_in_idx * conv2d_params.height * conv2d_params.width - + iy * conv2d_params.width - + ix; - let weight_idx = oc * c_in_per_group * conv2d_params.kernel_h * conv2d_params.kernel_w - + ic * conv2d_params.kernel_h * conv2d_params.kernel_w - + ky * conv2d_params.kernel_w - + kx; - sum = sum + conv2d_input[input_idx] * conv2d_weight[weight_idx]; - }} - }} - }} - }} - - if (conv2d_params.has_bias != 0u) {{ - sum = sum + conv2d_bias[oc]; - }} - - conv2d_output[idx] = sum; -}} -"#, - t = t, - suffix = suffix, - zero = zero, - )) -} - -/// Generate WGSL shader for depthwise conv2d operation. -/// -/// Input layout: (N, C, H, W) -/// Weight layout: (C, 1, K_h, K_w) -/// Output layout: (N, C, H_out, W_out) -pub fn generate_depthwise_conv2d_shader(dtype: DType) -> Result { - if !is_wgsl_float(dtype) { - return Err(Error::UnsupportedDType { - dtype, - op: "depthwise_conv2d", - }); - } - - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - let zero = if dtype == DType::F16 { "0.0h" } else { "0.0" }; - - Ok(format!( - r#"// Auto-generated depthwise conv2d shader for {t} - -const WORKGROUP_SIZE: u32 = 256u; - -struct DepthwiseConv2dParams {{ - batch: u32, - channels: u32, - height: u32, - width: u32, - kernel_h: u32, - kernel_w: u32, - output_h: u32, - output_w: u32, - stride_h: u32, - stride_w: u32, - pad_h: u32, - pad_w: u32, - dilation_h: u32, - dilation_w: u32, - has_bias: u32, - _pad: u32, -}} - -@group(0) @binding(0) var depthwise_input: array<{t}>; -@group(0) @binding(1) var depthwise_weight: array<{t}>; -@group(0) @binding(2) var depthwise_bias: array<{t}>; -@group(0) @binding(3) var depthwise_output: array<{t}>; -@group(0) @binding(4) var depthwise_params: DepthwiseConv2dParams; - -@compute @workgroup_size(256) -fn depthwise_conv2d_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - let total = depthwise_params.batch * depthwise_params.channels * depthwise_params.output_h * depthwise_params.output_w; - if (idx >= total) {{ return; }} - - let ox = idx % depthwise_params.output_w; - let oy = (idx / depthwise_params.output_w) % depthwise_params.output_h; - let c = (idx / (depthwise_params.output_w * depthwise_params.output_h)) % depthwise_params.channels; - let b = idx / (depthwise_params.channels * depthwise_params.output_h * depthwise_params.output_w); - - var sum: {t} = {zero}; - - for (var ky: u32 = 0u; ky < depthwise_params.kernel_h; ky = ky + 1u) {{ - for (var kx: u32 = 0u; kx < depthwise_params.kernel_w; kx = kx + 1u) {{ - let iy_signed = i32(oy * depthwise_params.stride_h + ky * depthwise_params.dilation_h) - i32(depthwise_params.pad_h); - let ix_signed = i32(ox * depthwise_params.stride_w + kx * depthwise_params.dilation_w) - i32(depthwise_params.pad_w); - - if (iy_signed >= 0 && u32(iy_signed) < depthwise_params.height && ix_signed >= 0 && u32(ix_signed) < depthwise_params.width) {{ - let iy = u32(iy_signed); - let ix = u32(ix_signed); - let input_idx = b * depthwise_params.channels * depthwise_params.height * depthwise_params.width - + c * depthwise_params.height * depthwise_params.width - + iy * depthwise_params.width - + ix; - let weight_idx = c * depthwise_params.kernel_h * depthwise_params.kernel_w + ky * depthwise_params.kernel_w + kx; - sum = sum + depthwise_input[input_idx] * depthwise_weight[weight_idx]; - }} - }} - }} - - if (depthwise_params.has_bias != 0u) {{ - sum = sum + depthwise_bias[c]; - }} - - depthwise_output[idx] = sum; -}} -"#, - t = t, - suffix = suffix, - zero = zero, - )) -} - -#[cfg(test)] -mod tests { - use super::*; - - fn validate_wgsl_syntax(source: &str) -> std::result::Result<(), String> { - use wgpu::naga::front::wgsl; - let mut frontend = wgsl::Frontend::new(); - frontend - .parse(source) - .map(|_| ()) - .map_err(|e| format!("WGSL parse error: {e}")) - } - - #[test] - fn test_conv1d_shader_syntax() { - let shader = generate_conv1d_shader(DType::F32).unwrap(); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for conv1d shader:\n{}\n\nShader:\n{}", - e, shader - ) - }); - } - - #[test] - fn test_conv2d_shader_syntax() { - let shader = generate_conv2d_shader(DType::F32).unwrap(); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for conv2d shader:\n{}\n\nShader:\n{}", - e, shader - ) - }); - } - - #[test] - fn test_depthwise_conv2d_shader_syntax() { - let shader = generate_depthwise_conv2d_shader(DType::F32).unwrap(); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for depthwise_conv2d shader:\n{}\n\nShader:\n{}", - e, shader - ) - }); - } - - #[test] - fn test_conv_shaders_int_fails() { - assert!(generate_conv1d_shader(DType::I32).is_err()); - assert!(generate_conv2d_shader(DType::I32).is_err()); - assert!(generate_depthwise_conv2d_shader(DType::I32).is_err()); - } -} diff --git a/src/runtime/wgpu/shaders/generator/cumulative.rs b/src/runtime/wgpu/shaders/generator/cumulative.rs deleted file mode 100644 index 994dc4e6..00000000 --- a/src/runtime/wgpu/shaders/generator/cumulative.rs +++ /dev/null @@ -1,348 +0,0 @@ -//! WGSL shader generation for cumulative operations -//! -//! Generates shaders for: -//! - cumsum: cumulative sum along a dimension -//! - cumprod: cumulative product along a dimension -//! - logsumexp: numerically stable log-sum-exp reduction - -use super::common::{dtype_suffix, is_wgsl_float, wgsl_type}; -use crate::dtype::DType; -use crate::error::{Error, Result}; - -/// Generate WGSL shader for cumulative sum operation (simple/contiguous) -pub fn generate_cumsum_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - let zero = match dtype { - DType::F32 | DType::F16 => "0.0", - DType::I32 => "0", - DType::U32 => "0u", - _ => { - return Err(Error::UnsupportedDType { - dtype, - op: "cumsum", - }); - } - }; - - Ok(format!( - r#"// Auto-generated cumsum shader for {t} - -struct CumsumParams {{ - scan_size: u32, - outer_size: u32, -}} - -@group(0) @binding(0) var input: array<{t}>; -@group(0) @binding(1) var output: array<{t}>; -@group(0) @binding(2) var params: CumsumParams; - -@compute @workgroup_size(256) -fn cumsum_{suffix}(@builtin(global_invocation_id) global_id: vec3) {{ - let outer_idx = global_id.x; - if (outer_idx >= params.outer_size) {{ - return; - }} - - let base = outer_idx * params.scan_size; - var acc: {t} = {zero}; - for (var i: u32 = 0u; i < params.scan_size; i = i + 1u) {{ - acc = acc + input[base + i]; - output[base + i] = acc; - }} -}} -"#, - t = t, - suffix = suffix, - zero = zero, - )) -} - -/// Generate WGSL shader for strided cumulative sum -pub fn generate_cumsum_strided_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - let zero = match dtype { - DType::F32 | DType::F16 => "0.0", - DType::I32 => "0", - DType::U32 => "0u", - _ => { - return Err(Error::UnsupportedDType { - dtype, - op: "cumsum_strided", - }); - } - }; - - Ok(format!( - r#"// Auto-generated strided cumsum shader for {t} - -struct CumsumStridedParams {{ - scan_size: u32, - outer_size: u32, - inner_size: u32, -}} - -@group(0) @binding(0) var input: array<{t}>; -@group(0) @binding(1) var output: array<{t}>; -@group(0) @binding(2) var params: CumsumStridedParams; - -@compute @workgroup_size(256) -fn cumsum_strided_{suffix}(@builtin(global_invocation_id) global_id: vec3) {{ - let idx = global_id.x; - let total_inner = params.outer_size * params.inner_size; - if (idx >= total_inner) {{ - return; - }} - - let outer_idx = idx / params.inner_size; - let inner_idx = idx % params.inner_size; - - var acc: {t} = {zero}; - for (var s: u32 = 0u; s < params.scan_size; s = s + 1u) {{ - let offset = outer_idx * params.scan_size * params.inner_size + s * params.inner_size + inner_idx; - acc = acc + input[offset]; - output[offset] = acc; - }} -}} -"#, - t = t, - suffix = suffix, - zero = zero, - )) -} - -/// Generate WGSL shader for cumulative product operation (simple/contiguous) -pub fn generate_cumprod_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - let one = match dtype { - DType::F32 | DType::F16 => "1.0", - DType::I32 => "1", - DType::U32 => "1u", - _ => { - return Err(Error::UnsupportedDType { - dtype, - op: "cumprod", - }); - } - }; - - Ok(format!( - r#"// Auto-generated cumprod shader for {t} - -struct CumprodParams {{ - scan_size: u32, - outer_size: u32, -}} - -@group(0) @binding(0) var input: array<{t}>; -@group(0) @binding(1) var output: array<{t}>; -@group(0) @binding(2) var params: CumprodParams; - -@compute @workgroup_size(256) -fn cumprod_{suffix}(@builtin(global_invocation_id) global_id: vec3) {{ - let outer_idx = global_id.x; - if (outer_idx >= params.outer_size) {{ - return; - }} - - let base = outer_idx * params.scan_size; - var acc: {t} = {one}; - for (var i: u32 = 0u; i < params.scan_size; i = i + 1u) {{ - acc = acc * input[base + i]; - output[base + i] = acc; - }} -}} -"#, - t = t, - suffix = suffix, - one = one, - )) -} - -/// Generate WGSL shader for strided cumulative product -pub fn generate_cumprod_strided_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - let one = match dtype { - DType::F32 | DType::F16 => "1.0", - DType::I32 => "1", - DType::U32 => "1u", - _ => { - return Err(Error::UnsupportedDType { - dtype, - op: "cumprod_strided", - }); - } - }; - - Ok(format!( - r#"// Auto-generated strided cumprod shader for {t} - -struct CumprodStridedParams {{ - scan_size: u32, - outer_size: u32, - inner_size: u32, -}} - -@group(0) @binding(0) var input: array<{t}>; -@group(0) @binding(1) var output: array<{t}>; -@group(0) @binding(2) var params: CumprodStridedParams; - -@compute @workgroup_size(256) -fn cumprod_strided_{suffix}(@builtin(global_invocation_id) global_id: vec3) {{ - let idx = global_id.x; - let total_inner = params.outer_size * params.inner_size; - if (idx >= total_inner) {{ - return; - }} - - let outer_idx = idx / params.inner_size; - let inner_idx = idx % params.inner_size; - - var acc: {t} = {one}; - for (var s: u32 = 0u; s < params.scan_size; s = s + 1u) {{ - let offset = outer_idx * params.scan_size * params.inner_size + s * params.inner_size + inner_idx; - acc = acc * input[offset]; - output[offset] = acc; - }} -}} -"#, - t = t, - suffix = suffix, - one = one, - )) -} - -/// Generate WGSL shader for log-sum-exp reduction (simple/contiguous) -/// -/// Computes log(sum(exp(x))) in a numerically stable way: -/// logsumexp(x) = max(x) + log(sum(exp(x - max(x)))) -pub fn generate_logsumexp_shader(dtype: DType) -> Result { - // logsumexp only supported for float types - if !is_wgsl_float(dtype) { - return Err(Error::UnsupportedDType { - dtype, - op: "logsumexp", - }); - } - - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - let min_val = match dtype { - DType::F32 => "-3.402823e+38", - DType::F16 => "-65504.0", - _ => "-3.402823e+38", - }; - - Ok(format!( - r#"// Auto-generated logsumexp shader for {t} - -struct LogsumexpParams {{ - reduce_size: u32, - outer_size: u32, -}} - -@group(0) @binding(0) var input: array<{t}>; -@group(0) @binding(1) var output: array<{t}>; -@group(0) @binding(2) var params: LogsumexpParams; - -@compute @workgroup_size(256) -fn logsumexp_{suffix}(@builtin(global_invocation_id) global_id: vec3) {{ - let outer_idx = global_id.x; - if (outer_idx >= params.outer_size) {{ - return; - }} - - let base = outer_idx * params.reduce_size; - - // Step 1: Find max value - var max_val: {t} = {min_val}; - for (var i: u32 = 0u; i < params.reduce_size; i = i + 1u) {{ - let val = input[base + i]; - max_val = max(max_val, val); - }} - - // Step 2: Compute sum(exp(x - max)) - var sum_exp: {t} = 0.0; - for (var i: u32 = 0u; i < params.reduce_size; i = i + 1u) {{ - sum_exp = sum_exp + exp(input[base + i] - max_val); - }} - - // Step 3: Result = max + log(sum) - output[outer_idx] = max_val + log(sum_exp); -}} -"#, - t = t, - suffix = suffix, - min_val = min_val, - )) -} - -/// Generate WGSL shader for strided log-sum-exp reduction -pub fn generate_logsumexp_strided_shader(dtype: DType) -> Result { - // logsumexp only supported for float types - if !is_wgsl_float(dtype) { - return Err(Error::UnsupportedDType { - dtype, - op: "logsumexp_strided", - }); - } - - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - let min_val = match dtype { - DType::F32 => "-3.402823e+38", - DType::F16 => "-65504.0", - _ => "-3.402823e+38", - }; - - Ok(format!( - r#"// Auto-generated strided logsumexp shader for {t} - -struct LogsumexpStridedParams {{ - reduce_size: u32, - outer_size: u32, - inner_size: u32, -}} - -@group(0) @binding(0) var input: array<{t}>; -@group(0) @binding(1) var output: array<{t}>; -@group(0) @binding(2) var params: LogsumexpStridedParams; - -@compute @workgroup_size(256) -fn logsumexp_strided_{suffix}(@builtin(global_invocation_id) global_id: vec3) {{ - let idx = global_id.x; - let total_inner = params.outer_size * params.inner_size; - if (idx >= total_inner) {{ - return; - }} - - let outer_idx = idx / params.inner_size; - let inner_idx = idx % params.inner_size; - - // Step 1: Find max value along reduce dimension - let first_offset = outer_idx * params.reduce_size * params.inner_size + inner_idx; - var max_val: {t} = {min_val}; - for (var r: u32 = 0u; r < params.reduce_size; r = r + 1u) {{ - let offset = outer_idx * params.reduce_size * params.inner_size + r * params.inner_size + inner_idx; - max_val = max(max_val, input[offset]); - }} - - // Step 2: Compute sum(exp(x - max)) - var sum_exp: {t} = 0.0; - for (var r: u32 = 0u; r < params.reduce_size; r = r + 1u) {{ - let offset = outer_idx * params.reduce_size * params.inner_size + r * params.inner_size + inner_idx; - sum_exp = sum_exp + exp(input[offset] - max_val); - }} - - // Step 3: Write result - output[outer_idx * params.inner_size + inner_idx] = max_val + log(sum_exp); -}} -"#, - t = t, - suffix = suffix, - min_val = min_val, - )) -} diff --git a/src/runtime/wgpu/shaders/generator/distributions.rs b/src/runtime/wgpu/shaders/generator/distributions.rs deleted file mode 100644 index 44118c7f..00000000 --- a/src/runtime/wgpu/shaders/generator/distributions.rs +++ /dev/null @@ -1,578 +0,0 @@ -//! WGSL shader generation for probability distribution sampling operations -//! -//! Provides shaders for: -//! - Bernoulli: Binary outcomes with probability p -//! - Beta: Continuous on [0, 1] with shape parameters -//! - Gamma: Continuous on [0, inf) with shape/scale -//! - Exponential: Continuous on [0, inf) with rate -//! - Poisson: Discrete counts with rate lambda -//! - Binomial: Discrete successes in n trials -//! - Laplace: Double exponential distribution -//! - Chi-squared: Sum of squared normals -//! - Student's t: Heavy-tailed distribution -//! - F: Ratio of chi-squared variates - -use super::common::{dtype_suffix, wgsl_type}; -use crate::dtype::DType; -use crate::error::{Error, Result}; - -/// PCG random number generator for WGSL with distribution helpers -const DISTRIBUTION_RNG_WGSL: &str = r#" -// PCG hash function for random number generation -fn pcg_hash(input: u32) -> u32 { - var state = input * 747796405u + 2891336453u; - var word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u; - return (word >> 22u) ^ word; -} - -fn pcg_init(seed: u32, idx: u32) -> u32 { - return pcg_hash(seed ^ pcg_hash(idx)); -} - -fn pcg_uniform(state: ptr) -> f32 { - *state = pcg_hash(*state); - return f32(*state) / 4294967296.0; -} - -// Box-Muller for normal distribution -fn sample_normal(state: ptr) -> f32 { - let u1 = max(pcg_uniform(state), 0.0000001); - let u2 = pcg_uniform(state); - return sqrt(-2.0 * log(u1)) * cos(6.28318530718 * u2); -} - -// Gamma via Marsaglia-Tsang method -fn sample_gamma_mt(state: ptr, shape: f32, scale: f32) -> f32 { - var alpha = shape; - var boost = 1.0; - - // Handle shape < 1 by boosting - if alpha < 1.0 { - boost = pow(pcg_uniform(state), 1.0 / alpha); - alpha = alpha + 1.0; - } - - let d = alpha - 1.0 / 3.0; - let c = 1.0 / sqrt(9.0 * d); - - // Rejection sampling - for (var i = 0u; i < 100u; i = i + 1u) { - var x: f32; - var v: f32; - - // Generate valid v - for (var j = 0u; j < 100u; j = j + 1u) { - x = sample_normal(state); - v = 1.0 + c * x; - if v > 0.0 { - break; - } - } - - v = v * v * v; - let u = pcg_uniform(state); - let x2 = x * x; - - // Accept/reject - if u < 1.0 - 0.0331 * x2 * x2 { - return d * v * boost * scale; - } - if log(u) < 0.5 * x2 + d * (1.0 - v + log(v)) { - return d * v * boost * scale; - } - } - - // Fallback (should rarely reach) - return d * boost * scale; -} -"#; - -fn check_float_dtype(dtype: DType, op: &'static str) -> Result<()> { - match dtype { - DType::F32 => Ok(()), - _ => Err(Error::UnsupportedDType { dtype, op }), - } -} - -/// Generate WGSL shader for Bernoulli distribution sampling -pub fn generate_bernoulli_shader(dtype: DType) -> Result { - check_float_dtype(dtype, "bernoulli")?; - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Bernoulli distribution sampling for {t} -{rng} -const WORKGROUP_SIZE: u32 = 256u; - -struct BernoulliParams {{ - numel: u32, - seed: u32, - p: f32, - _pad: u32, -}} - -@group(0) @binding(0) var out: array<{t}>; -@group(0) @binding(1) var params: BernoulliParams; - -@compute @workgroup_size(256) -fn bernoulli_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if idx < params.numel {{ - var state = pcg_init(params.seed, idx); - let u = pcg_uniform(&state); - out[idx] = select({t}(0.0), {t}(1.0), u < params.p); - }} -}} -"#, - t = t, - suffix = suffix, - rng = DISTRIBUTION_RNG_WGSL - )) -} - -/// Generate WGSL shader for Beta distribution sampling -pub fn generate_beta_dist_shader(dtype: DType) -> Result { - check_float_dtype(dtype, "beta")?; - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Beta distribution sampling for {t} -{rng} -const WORKGROUP_SIZE: u32 = 256u; - -struct BetaParams {{ - numel: u32, - seed: u32, - alpha: f32, - beta: f32, -}} - -@group(0) @binding(0) var out: array<{t}>; -@group(0) @binding(1) var params: BetaParams; - -@compute @workgroup_size(256) -fn beta_dist_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if idx < params.numel {{ - var state = pcg_init(params.seed, idx); - let x = sample_gamma_mt(&state, params.alpha, 1.0); - let y = sample_gamma_mt(&state, params.beta, 1.0); - out[idx] = {t}(x / (x + y)); - }} -}} -"#, - t = t, - suffix = suffix, - rng = DISTRIBUTION_RNG_WGSL - )) -} - -/// Generate WGSL shader for Gamma distribution sampling -pub fn generate_gamma_dist_shader(dtype: DType) -> Result { - check_float_dtype(dtype, "gamma")?; - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Gamma distribution sampling for {t} -{rng} -const WORKGROUP_SIZE: u32 = 256u; - -struct GammaParams {{ - numel: u32, - seed: u32, - shape: f32, - scale: f32, -}} - -@group(0) @binding(0) var out: array<{t}>; -@group(0) @binding(1) var params: GammaParams; - -@compute @workgroup_size(256) -fn gamma_dist_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if idx < params.numel {{ - var state = pcg_init(params.seed, idx); - out[idx] = {t}(sample_gamma_mt(&state, params.shape, params.scale)); - }} -}} -"#, - t = t, - suffix = suffix, - rng = DISTRIBUTION_RNG_WGSL - )) -} - -/// Generate WGSL shader for Exponential distribution sampling -pub fn generate_exponential_shader(dtype: DType) -> Result { - check_float_dtype(dtype, "exponential")?; - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Exponential distribution sampling for {t} -{rng} -const WORKGROUP_SIZE: u32 = 256u; - -struct ExponentialParams {{ - numel: u32, - seed: u32, - rate: f32, - _pad: u32, -}} - -@group(0) @binding(0) var out: array<{t}>; -@group(0) @binding(1) var params: ExponentialParams; - -@compute @workgroup_size(256) -fn exponential_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if idx < params.numel {{ - var state = pcg_init(params.seed, idx); - let u = max(pcg_uniform(&state), 0.0000001); - out[idx] = {t}(-log(u) / params.rate); - }} -}} -"#, - t = t, - suffix = suffix, - rng = DISTRIBUTION_RNG_WGSL - )) -} - -/// Generate WGSL shader for Poisson distribution sampling -pub fn generate_poisson_shader(dtype: DType) -> Result { - check_float_dtype(dtype, "poisson")?; - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Poisson distribution sampling for {t} -{rng} -const WORKGROUP_SIZE: u32 = 256u; - -struct PoissonParams {{ - numel: u32, - seed: u32, - lambda: f32, - _pad: u32, -}} - -@group(0) @binding(0) var out: array<{t}>; -@group(0) @binding(1) var params: PoissonParams; - -@compute @workgroup_size(256) -fn poisson_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if idx < params.numel {{ - var state = pcg_init(params.seed, idx); - - // Knuth's algorithm for small lambda - if params.lambda < 30.0 {{ - let L = exp(-params.lambda); - var k = 0u; - var p = 1.0; - - for (var i = 0u; i < 1000u; i = i + 1u) {{ - p = p * pcg_uniform(&state); - if p <= L {{ - break; - }} - k = k + 1u; - }} - out[idx] = {t}(f32(k)); - }} else {{ - // Normal approximation for large lambda - let z = sample_normal(&state); - let result = max(0.0, round(params.lambda + sqrt(params.lambda) * z)); - out[idx] = {t}(result); - }} - }} -}} -"#, - t = t, - suffix = suffix, - rng = DISTRIBUTION_RNG_WGSL - )) -} - -/// Generate WGSL shader for Binomial distribution sampling -pub fn generate_binomial_shader(dtype: DType) -> Result { - check_float_dtype(dtype, "binomial")?; - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Binomial distribution sampling for {t} -{rng} -const WORKGROUP_SIZE: u32 = 256u; - -struct BinomialParams {{ - numel: u32, - seed: u32, - n_trials: u32, - p: f32, -}} - -@group(0) @binding(0) var out: array<{t}>; -@group(0) @binding(1) var params: BinomialParams; - -@compute @workgroup_size(256) -fn binomial_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if idx < params.numel {{ - var state = pcg_init(params.seed, idx); - - let n = params.n_trials; - let p = params.p; - - // Direct simulation for small n - if n <= 64u {{ - var successes = 0u; - for (var i = 0u; i < n; i = i + 1u) {{ - if pcg_uniform(&state) < p {{ - successes = successes + 1u; - }} - }} - out[idx] = {t}(f32(successes)); - }} else {{ - // Normal approximation for large n - let mean = f32(n) * p; - let std_dev = sqrt(mean * (1.0 - p)); - let z = sample_normal(&state); - let result = clamp(round(mean + std_dev * z), 0.0, f32(n)); - out[idx] = {t}(result); - }} - }} -}} -"#, - t = t, - suffix = suffix, - rng = DISTRIBUTION_RNG_WGSL - )) -} - -/// Generate WGSL shader for Laplace distribution sampling -pub fn generate_laplace_shader(dtype: DType) -> Result { - check_float_dtype(dtype, "laplace")?; - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Laplace distribution sampling for {t} -{rng} -const WORKGROUP_SIZE: u32 = 256u; - -struct LaplaceParams {{ - numel: u32, - seed: u32, - loc: f32, - scale: f32, -}} - -@group(0) @binding(0) var out: array<{t}>; -@group(0) @binding(1) var params: LaplaceParams; - -@compute @workgroup_size(256) -fn laplace_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if idx < params.numel {{ - var state = pcg_init(params.seed, idx); - let u = pcg_uniform(&state) - 0.5; - let result = params.loc - params.scale * sign(u) * log(1.0 - 2.0 * abs(u)); - out[idx] = {t}(result); - }} -}} -"#, - t = t, - suffix = suffix, - rng = DISTRIBUTION_RNG_WGSL - )) -} - -/// Generate WGSL shader for Chi-squared distribution sampling -pub fn generate_chi_squared_shader(dtype: DType) -> Result { - check_float_dtype(dtype, "chi_squared")?; - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Chi-squared distribution sampling for {t} -{rng} -const WORKGROUP_SIZE: u32 = 256u; - -struct ChiSquaredParams {{ - numel: u32, - seed: u32, - df: f32, - _pad: u32, -}} - -@group(0) @binding(0) var out: array<{t}>; -@group(0) @binding(1) var params: ChiSquaredParams; - -@compute @workgroup_size(256) -fn chi_squared_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if idx < params.numel {{ - var state = pcg_init(params.seed, idx); - // Chi-squared(df) = Gamma(df/2, 2) - out[idx] = {t}(sample_gamma_mt(&state, params.df / 2.0, 2.0)); - }} -}} -"#, - t = t, - suffix = suffix, - rng = DISTRIBUTION_RNG_WGSL - )) -} - -/// Generate WGSL shader for Student's t distribution sampling -pub fn generate_student_t_shader(dtype: DType) -> Result { - check_float_dtype(dtype, "student_t")?; - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Student's t distribution sampling for {t} -{rng} -const WORKGROUP_SIZE: u32 = 256u; - -struct StudentTParams {{ - numel: u32, - seed: u32, - df: f32, - _pad: u32, -}} - -@group(0) @binding(0) var out: array<{t}>; -@group(0) @binding(1) var params: StudentTParams; - -@compute @workgroup_size(256) -fn student_t_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if idx < params.numel {{ - var state = pcg_init(params.seed, idx); - let z = sample_normal(&state); - let chi2 = sample_gamma_mt(&state, params.df / 2.0, 2.0); - out[idx] = {t}(z / sqrt(chi2 / params.df)); - }} -}} -"#, - t = t, - suffix = suffix, - rng = DISTRIBUTION_RNG_WGSL - )) -} - -/// Generate WGSL shader for F distribution sampling -pub fn generate_f_distribution_shader(dtype: DType) -> Result { - check_float_dtype(dtype, "f_distribution")?; - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// F distribution sampling for {t} -{rng} -const WORKGROUP_SIZE: u32 = 256u; - -struct FDistributionParams {{ - numel: u32, - seed: u32, - df1: f32, - df2: f32, -}} - -@group(0) @binding(0) var out: array<{t}>; -@group(0) @binding(1) var params: FDistributionParams; - -@compute @workgroup_size(256) -fn f_distribution_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if idx < params.numel {{ - var state = pcg_init(params.seed, idx); - let chi2_1 = sample_gamma_mt(&state, params.df1 / 2.0, 2.0); - let chi2_2 = sample_gamma_mt(&state, params.df2 / 2.0, 2.0); - out[idx] = {t}((chi2_1 / params.df1) / (chi2_2 / params.df2)); - }} -}} -"#, - t = t, - suffix = suffix, - rng = DISTRIBUTION_RNG_WGSL - )) -} - -/// Generate WGSL shader for multinomial count operation -/// -/// Performs CDF lookup for uniform samples and counts occurrences per category. -/// Used for multinomial sampling: given uniform samples and a CDF, counts how -/// many samples fall into each category. -pub fn generate_multinomial_count_shader(dtype: DType) -> Result { - check_float_dtype(dtype, "multinomial_count")?; - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Multinomial count shader for {t} -// Performs CDF lookup for uniform samples and counts occurrences per category - -const WORKGROUP_SIZE: u32 = 256u; - -struct MultinomialCountParams {{ - k: u32, // Number of categories - n_trials: u32, // Number of trials per sample - n_samples: u32, // Number of samples - _pad: u32, -}} - -@group(0) @binding(0) var cdf: array<{t}>; -@group(0) @binding(1) var uniforms: array<{t}>; -@group(0) @binding(2) var counts: array<{t}>; -@group(0) @binding(3) var params: MultinomialCountParams; - -// Binary search to find category for uniform sample -fn find_category(u: {t}, k: u32) -> u32 {{ - var lo: u32 = 0u; - var hi: u32 = k; - while (lo < hi) {{ - let mid = lo + (hi - lo) / 2u; - if (cdf[mid] <= u) {{ - lo = mid + 1u; - }} else {{ - hi = mid; - }} - }} - return min(lo, k - 1u); -}} - -@compute @workgroup_size(256) -fn multinomial_count_{suffix}(@builtin(global_invocation_id) global_id: vec3) {{ - let sample_idx = global_id.x; - let k = params.k; - let n_trials = params.n_trials; - let n_samples = params.n_samples; - - if (sample_idx >= n_samples) {{ - return; - }} - - // Initialize counts for this sample to zero - for (var c: u32 = 0u; c < k; c++) {{ - counts[sample_idx * k + c] = {t}(0.0); - }} - - // Process each trial - for (var t_idx: u32 = 0u; t_idx < n_trials; t_idx++) {{ - let u = uniforms[sample_idx * n_trials + t_idx]; - let category = find_category(u, k); - counts[sample_idx * k + category] += {t}(1.0); - }} -}} -"#, - t = t, - suffix = suffix, - )) -} diff --git a/src/runtime/wgpu/shaders/generator/fft.rs b/src/runtime/wgpu/shaders/generator/fft.rs deleted file mode 100644 index 46be33c9..00000000 --- a/src/runtime/wgpu/shaders/generator/fft.rs +++ /dev/null @@ -1,485 +0,0 @@ -//! WGSL shader generation for FFT operations -//! -//! Generates Stockham FFT shaders using `vec2` for complex numbers. -//! WGSL doesn't have native complex type, so we use vec2 (re, im). - -use crate::error::Result; - -/// Maximum FFT size for shared memory implementation -pub const MAX_WORKGROUP_FFT_SIZE: usize = 256; - -/// Generate complex arithmetic helper functions -fn complex_helpers() -> &'static str { - r#" -// Complex number helpers (vec2: x=real, y=imag) -fn cmul(a: vec2, b: vec2) -> vec2 { - return vec2(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x); -} - -fn cadd(a: vec2, b: vec2) -> vec2 { - return a + b; -} - -fn csub(a: vec2, b: vec2) -> vec2 { - return a - b; -} - -fn cscale(a: vec2, s: f32) -> vec2 { - return vec2(a.x * s, a.y * s); -} - -fn cconj(a: vec2) -> vec2 { - return vec2(a.x, -a.y); -} - -// Compute e^(i*theta) = cos(theta) + i*sin(theta) -fn cexp_i(theta: f32) -> vec2 { - return vec2(cos(theta), sin(theta)); -} -"# -} - -/// Generate batched Stockham FFT shader for small transforms -/// -/// Each workgroup processes one FFT. Uses workgroup shared memory for ping-pong. -pub fn generate_stockham_fft_shader() -> Result { - Ok(format!( - r#"// Stockham FFT shader for WebGPU -// Complex numbers as vec2 (re, im) - -const PI: f32 = 3.14159265358979323846; -const WORKGROUP_SIZE: u32 = 256u; - -struct FftParams {{ - n: u32, - log_n: u32, - inverse: i32, - scale: f32, - batch_size: u32, - _pad1: u32, - _pad2: u32, - _pad3: u32, -}} - -@group(0) @binding(0) var fft_input: array>; -@group(0) @binding(1) var fft_output: array>; -@group(0) @binding(2) var fft_params: FftParams; - -// Workgroup shared memory for ping-pong -var smem_a: array, {max_size}>; -var smem_b: array, {max_size}>; -{complex_helpers} - -@compute @workgroup_size(WORKGROUP_SIZE) -fn stockham_fft_small( - @builtin(workgroup_id) wg_id: vec3, - @builtin(local_invocation_id) local_id: vec3 -) {{ - let batch_idx = wg_id.x; - let tid = local_id.x; - let n = fft_params.n; - let log_n = fft_params.log_n; - let inverse = fft_params.inverse; - let scale_factor = fft_params.scale; - - // Sign for twiddle factor - let sign = select(-1.0, 1.0, inverse != 0); - - // Load input to shared memory - let base_offset = batch_idx * n; - for (var i = tid; i < n; i = i + WORKGROUP_SIZE) {{ - smem_a[i] = fft_input[base_offset + i]; - }} - workgroupBarrier(); - - // Perform Stockham FFT stages - var use_a = true; - for (var stage: u32 = 0u; stage < log_n; stage = stage + 1u) {{ - let m = 1u << (stage + 1u); - let half_m = 1u << stage; - - for (var i = tid; i < n / 2u; i = i + WORKGROUP_SIZE) {{ - let group = i / half_m; - let pair = i % half_m; - - let even_idx = group * half_m + pair; - let odd_idx = even_idx + n / 2u; - - let out_even_idx = group * m + pair; - let out_odd_idx = out_even_idx + half_m; - - // Twiddle factor - let theta = sign * 2.0 * PI * f32(pair) / f32(m); - let twiddle = cexp_i(theta); - - var even_val: vec2; - var odd_val: vec2; - - if (use_a) {{ - even_val = smem_a[even_idx]; - odd_val = cmul(smem_a[odd_idx], twiddle); - }} else {{ - even_val = smem_b[even_idx]; - odd_val = cmul(smem_b[odd_idx], twiddle); - }} - - let sum = cadd(even_val, odd_val); - let diff = csub(even_val, odd_val); - - if (use_a) {{ - smem_b[out_even_idx] = sum; - smem_b[out_odd_idx] = diff; - }} else {{ - smem_a[out_even_idx] = sum; - smem_a[out_odd_idx] = diff; - }} - }} - - workgroupBarrier(); - use_a = !use_a; - }} - - // Write output with scaling - for (var i = tid; i < n; i = i + WORKGROUP_SIZE) {{ - var result: vec2; - if (use_a) {{ - result = smem_a[i]; - }} else {{ - result = smem_b[i]; - }} - fft_output[base_offset + i] = cscale(result, scale_factor); - }} -}} - -// Single stage kernel for large FFTs (N > workgroup FFT size) -@compute @workgroup_size(WORKGROUP_SIZE) -fn stockham_fft_stage( - @builtin(global_invocation_id) gid: vec3 -) {{ - let n = fft_params.n; - let stage = fft_params.log_n; // Reuse log_n as current stage - let inverse = fft_params.inverse; - let batch_idx = gid.y; - - let sign = select(-1.0, 1.0, inverse != 0); - - let m = 1u << (stage + 1u); - let half_m = 1u << stage; - - let i = gid.x; - if (i >= n / 2u) {{ - return; - }} - - let group = i / half_m; - let pair = i % half_m; - - let base_offset = batch_idx * n; - let even_idx = base_offset + group * half_m + pair; - let odd_idx = even_idx + n / 2u; - - let out_even_idx = base_offset + group * m + pair; - let out_odd_idx = out_even_idx + half_m; - - // Twiddle factor - let theta = sign * 2.0 * PI * f32(pair) / f32(m); - let twiddle = cexp_i(theta); - - let even_val = fft_input[even_idx]; - let odd_val = cmul(fft_input[odd_idx], twiddle); - - fft_output[out_even_idx] = cadd(even_val, odd_val); - fft_output[out_odd_idx] = csub(even_val, odd_val); -}} - -// Scale complex array -@compute @workgroup_size(WORKGROUP_SIZE) -fn scale_complex( - @builtin(global_invocation_id) gid: vec3 -) {{ - let idx = gid.x; - let n = fft_params.n; - let scale_factor = fft_params.scale; - - if (idx < n) {{ - fft_output[idx] = cscale(fft_input[idx], scale_factor); - }} -}} -"#, - max_size = MAX_WORKGROUP_FFT_SIZE, - complex_helpers = complex_helpers() - )) -} - -/// Generate FFT shift shader -pub fn generate_fftshift_shader() -> Result { - Ok(format!( - r#"// FFT shift shader - shifts zero-frequency to center - -const WORKGROUP_SIZE: u32 = 256u; - -struct ShiftParams {{ - n: u32, - batch_size: u32, - _pad1: u32, - _pad2: u32, -}} - -@group(0) @binding(0) var shift_input: array>; -@group(0) @binding(1) var shift_output: array>; -@group(0) @binding(2) var shift_params: ShiftParams; -{complex_helpers} - -@compute @workgroup_size(WORKGROUP_SIZE) -fn fftshift( - @builtin(global_invocation_id) gid: vec3 -) {{ - let idx = gid.x; - let batch_idx = gid.y; - let n = shift_params.n; - - if (idx >= n) {{ - return; - }} - - let base_offset = batch_idx * n; - let half_n = n / 2u; - - // Swap first half with second half - var src_idx: u32; - if (idx < half_n) {{ - src_idx = idx + half_n; - }} else {{ - src_idx = idx - half_n; - }} - - shift_output[base_offset + idx] = shift_input[base_offset + src_idx]; -}} - -@compute @workgroup_size(WORKGROUP_SIZE) -fn ifftshift( - @builtin(global_invocation_id) gid: vec3 -) {{ - let idx = gid.x; - let batch_idx = gid.y; - let n = shift_params.n; - - if (idx >= n) {{ - return; - }} - - let base_offset = batch_idx * n; - let half_n = (n + 1u) / 2u; // Ceiling division for odd n - - // Inverse shift - var src_idx: u32; - if (idx < n - half_n) {{ - src_idx = idx + half_n; - }} else {{ - src_idx = idx - (n - half_n); - }} - - shift_output[base_offset + idx] = shift_input[base_offset + src_idx]; -}} -"#, - complex_helpers = complex_helpers() - )) -} - -/// Generate rfft pack shader (real to complex) -pub fn generate_rfft_pack_shader() -> Result { - Ok(r#"// rfft pack shader - converts real input to complex - -const WORKGROUP_SIZE: u32 = 256u; - -struct PackParams { - n: u32, - batch_size: u32, - _pad1: u32, - _pad2: u32, -} - -@group(0) @binding(0) var pack_input: array; -@group(0) @binding(1) var pack_output: array>; -@group(0) @binding(2) var pack_params: PackParams; - -@compute @workgroup_size(WORKGROUP_SIZE) -fn rfft_pack( - @builtin(global_invocation_id) gid: vec3 -) { - let idx = gid.x; - let batch_idx = gid.y; - let n = pack_params.n; - - if (idx >= n) { - return; - } - - let in_offset = batch_idx * n; - let out_offset = batch_idx * n; - - pack_output[out_offset + idx] = vec2(pack_input[in_offset + idx], 0.0); -} -"# - .to_string()) -} - -/// Generate irfft unpack shader (complex to real) -pub fn generate_irfft_unpack_shader() -> Result { - Ok(r#"// irfft unpack shader - extracts real part from complex - -const WORKGROUP_SIZE: u32 = 256u; - -struct UnpackParams { - n: u32, - batch_size: u32, - _pad1: u32, - _pad2: u32, -} - -@group(0) @binding(0) var unpack_input: array>; -@group(0) @binding(1) var unpack_output: array; -@group(0) @binding(2) var unpack_params: UnpackParams; - -@compute @workgroup_size(WORKGROUP_SIZE) -fn irfft_unpack( - @builtin(global_invocation_id) gid: vec3 -) { - let idx = gid.x; - let batch_idx = gid.y; - let n = unpack_params.n; - - if (idx >= n) { - return; - } - - let in_offset = batch_idx * n; - let out_offset = batch_idx * n; - - unpack_output[out_offset + idx] = unpack_input[in_offset + idx].x; -} -"# - .to_string()) -} - -/// Generate Hermitian extend shader for rfft -pub fn generate_hermitian_extend_shader() -> Result { - Ok( - r#"// Hermitian extend shader - extends N/2+1 complex to N complex using symmetry - -const WORKGROUP_SIZE: u32 = 256u; - -struct ExtendParams { - n: u32, // Full FFT size - half_n: u32, // N/2 + 1 (input size) - batch_size: u32, - _pad: u32, -} - -@group(0) @binding(0) var extend_input: array>; -@group(0) @binding(1) var extend_output: array>; -@group(0) @binding(2) var extend_params: ExtendParams; - -@compute @workgroup_size(WORKGROUP_SIZE) -fn hermitian_extend( - @builtin(global_invocation_id) gid: vec3 -) { - let idx = gid.x; - let batch_idx = gid.y; - let n = extend_params.n; - let half_n = extend_params.half_n; - - if (idx >= n) { - return; - } - - let in_offset = batch_idx * half_n; - let out_offset = batch_idx * n; - - if (idx < half_n) { - // Direct copy for first half - extend_output[out_offset + idx] = extend_input[in_offset + idx]; - } else { - // Conjugate symmetry for second half: X[N-k] = conj(X[k]) - let k = n - idx; - let val = extend_input[in_offset + k]; - extend_output[out_offset + idx] = vec2(val.x, -val.y); - } -} -"# - .to_string(), - ) -} - -/// Generate rfft truncate shader -pub fn generate_rfft_truncate_shader() -> Result { - Ok( - r#"// rfft truncate shader - keeps only N/2+1 complex values from full FFT - -const WORKGROUP_SIZE: u32 = 256u; - -struct TruncateParams { - n: u32, // Full FFT size (input) - half_n: u32, // N/2 + 1 (output size) - batch_size: u32, - _pad: u32, -} - -@group(0) @binding(0) var truncate_input: array>; -@group(0) @binding(1) var truncate_output: array>; -@group(0) @binding(2) var truncate_params: TruncateParams; - -@compute @workgroup_size(WORKGROUP_SIZE) -fn rfft_truncate( - @builtin(global_invocation_id) gid: vec3 -) { - let idx = gid.x; - let batch_idx = gid.y; - let n = truncate_params.n; - let half_n = truncate_params.half_n; - - if (idx >= half_n) { - return; - } - - let in_offset = batch_idx * n; - let out_offset = batch_idx * half_n; - - truncate_output[out_offset + idx] = truncate_input[in_offset + idx]; -} -"# - .to_string(), - ) -} - -/// Generate copy complex shader -pub fn generate_copy_complex_shader() -> Result { - Ok(r#"// Copy complex array - -const WORKGROUP_SIZE: u32 = 256u; - -struct CopyParams { - n: u32, - _pad1: u32, - _pad2: u32, - _pad3: u32, -} - -@group(0) @binding(0) var copy_input: array>; -@group(0) @binding(1) var copy_output: array>; -@group(0) @binding(2) var copy_params: CopyParams; - -@compute @workgroup_size(WORKGROUP_SIZE) -fn copy_complex( - @builtin(global_invocation_id) gid: vec3 -) { - let idx = gid.x; - let n = copy_params.n; - - if (idx < n) { - copy_output[idx] = copy_input[idx]; - } -} -"# - .to_string()) -} diff --git a/src/runtime/wgpu/shaders/generator/index.rs b/src/runtime/wgpu/shaders/generator/index.rs deleted file mode 100644 index 9236c6c1..00000000 --- a/src/runtime/wgpu/shaders/generator/index.rs +++ /dev/null @@ -1,1033 +0,0 @@ -//! WGSL shader generation for index, gather, and scatter operations - -use super::common::{dtype_suffix, wgsl_type}; -use crate::dtype::DType; -use crate::error::Result; - -/// Generate WGSL shader for index_select operation -pub fn generate_index_select_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Auto-generated index_select operations for {t} - -const WORKGROUP_SIZE: u32 = 256u; - -struct IndexSelectParams {{ - outer_size: u32, - dim_size: u32, - inner_size: u32, - index_len: u32, -}} - -@group(0) @binding(0) var input: array<{t}>; -@group(0) @binding(1) var indices: array; -@group(0) @binding(2) var output: array<{t}>; -@group(0) @binding(3) var params: IndexSelectParams; - -@compute @workgroup_size(256) -fn index_select_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - let total = params.outer_size * params.index_len * params.inner_size; - if (idx >= total) {{ - return; - }} - - let inner = idx % params.inner_size; - let sel_idx = (idx / params.inner_size) % params.index_len; - let outer = idx / (params.index_len * params.inner_size); - - let index_val = indices[sel_idx]; - if (index_val < 0 || u32(index_val) >= params.dim_size) {{ - output[idx] = {zero}; - return; - }} - - let src_offset = outer * params.dim_size * params.inner_size + u32(index_val) * params.inner_size + inner; - output[idx] = input[src_offset]; -}} -"#, - t = t, - suffix = suffix, - zero = match dtype { - DType::F32 | DType::F16 => "0.0", - _ => "0", - }, - )) -} - -/// Generate WGSL shader for gather operation -pub fn generate_gather_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - // For simplicity, we implement gather with max 4 dimensions - // This is sufficient for most use cases - Ok(format!( - r#"// Auto-generated gather operations for {t} - -const WORKGROUP_SIZE: u32 = 256u; -const MAX_DIMS: u32 = 4u; - -struct GatherParams {{ - ndim: u32, - dim: u32, - total_elements: u32, - _padding: u32, - // Shape and strides packed: [input_shape[0..4], input_strides[0..4], output_shape[0..4], output_strides[0..4]] - input_shape: vec4, - input_strides: vec4, - output_shape: vec4, - output_strides: vec4, -}} - -@group(0) @binding(0) var input: array<{t}>; -@group(0) @binding(1) var indices: array; -@group(0) @binding(2) var output: array<{t}>; -@group(0) @binding(3) var params: GatherParams; - -fn get_shape(arr: vec4, d: u32) -> u32 {{ - if (d == 0u) {{ return arr.x; }} - else if (d == 1u) {{ return arr.y; }} - else if (d == 2u) {{ return arr.z; }} - else {{ return arr.w; }} -}} - -@compute @workgroup_size(256) -fn gather_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx >= params.total_elements) {{ - return; - }} - - var remaining = idx; - var src_offset: u32 = 0u; - - for (var d: u32 = 0u; d < params.ndim; d = d + 1u) {{ - let out_stride = get_shape(params.output_strides, d); - let coord = remaining / out_stride; - remaining = remaining % out_stride; - - if (d == params.dim) {{ - let index_val = indices[idx]; - let dim_size = get_shape(params.input_shape, d); - if (index_val < 0 || u32(index_val) >= dim_size) {{ - output[idx] = {zero}; - return; - }} - src_offset = src_offset + u32(index_val) * get_shape(params.input_strides, d); - }} else {{ - src_offset = src_offset + coord * get_shape(params.input_strides, d); - }} - }} - - output[idx] = input[src_offset]; -}} -"#, - t = t, - suffix = suffix, - zero = match dtype { - DType::F32 | DType::F16 => "0.0", - _ => "0", - }, - )) -} - -/// Generate WGSL shader for scatter operation -pub fn generate_scatter_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Auto-generated scatter operations for {t} - -const WORKGROUP_SIZE: u32 = 256u; - -struct ScatterParams {{ - ndim: u32, - dim: u32, - src_total: u32, - _padding: u32, - output_shape: vec4, - output_strides: vec4, - src_shape: vec4, - src_strides: vec4, -}} - -@group(0) @binding(0) var src: array<{t}>; -@group(0) @binding(1) var indices: array; -@group(0) @binding(2) var output: array<{t}>; -@group(0) @binding(3) var params: ScatterParams; - -fn get_shape(arr: vec4, d: u32) -> u32 {{ - if (d == 0u) {{ return arr.x; }} - else if (d == 1u) {{ return arr.y; }} - else if (d == 2u) {{ return arr.z; }} - else {{ return arr.w; }} -}} - -@compute @workgroup_size(256) -fn scatter_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx >= params.src_total) {{ - return; - }} - - var remaining = idx; - var dst_offset: u32 = 0u; - - for (var d: u32 = 0u; d < params.ndim; d = d + 1u) {{ - let src_stride = get_shape(params.src_strides, d); - let coord = remaining / src_stride; - remaining = remaining % src_stride; - - if (d == params.dim) {{ - let index_val = indices[idx]; - let dim_size = get_shape(params.output_shape, d); - if (index_val < 0 || u32(index_val) >= dim_size) {{ - return; - }} - dst_offset = dst_offset + u32(index_val) * get_shape(params.output_strides, d); - }} else {{ - dst_offset = dst_offset + coord * get_shape(params.output_strides, d); - }} - }} - - output[dst_offset] = src[idx]; -}} - -// Copy kernel for initializing output from input -@group(0) @binding(0) var copy_src: array<{t}>; -@group(0) @binding(1) var copy_dst: array<{t}>; - -struct CopyParams {{ - numel: u32, -}} - -@group(0) @binding(2) var copy_params: CopyParams; - -@compute @workgroup_size(256) -fn copy_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < copy_params.numel) {{ - copy_dst[idx] = copy_src[idx]; - }} -}} -"#, - t = t, - suffix = suffix, - )) -} - -/// Generate WGSL shader for index_put operation -/// -/// This is the inverse of index_select: puts values from src at positions -/// specified by indices along a dimension. Output should be pre-initialized -/// with a copy of the input tensor. -pub fn generate_index_put_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Auto-generated index_put operations for {t} - -const WORKGROUP_SIZE: u32 = 256u; - -struct IndexPutParams {{ - outer_size: u32, - dim_size: u32, - inner_size: u32, - index_len: u32, -}} - -@group(0) @binding(0) var indices: array; -@group(0) @binding(1) var src: array<{t}>; -@group(0) @binding(2) var output: array<{t}>; -@group(0) @binding(3) var params: IndexPutParams; - -@compute @workgroup_size(256) -fn index_put_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - let total = params.outer_size * params.index_len * params.inner_size; - if (idx >= total) {{ - return; - }} - - let inner = idx % params.inner_size; - let sel_idx = (idx / params.inner_size) % params.index_len; - let outer = idx / (params.index_len * params.inner_size); - - let index_val = indices[sel_idx]; - if (index_val < 0 || u32(index_val) >= params.dim_size) {{ - return; // Out of bounds - skip - }} - - let dst_offset = outer * params.dim_size * params.inner_size + u32(index_val) * params.inner_size + inner; - output[dst_offset] = src[idx]; -}} -"#, - t = t, - suffix = suffix, - )) -} - -/// Generate WGSL shader for embedding_lookup operation -/// -/// This is the industry-standard embedding lookup operation used in neural networks -/// for word embeddings, entity embeddings, etc. -/// -/// Input: embeddings `[vocab_size, embedding_dim]`, indices `[num_indices]` -/// Output: output `[num_indices, embedding_dim]` -pub fn generate_embedding_lookup_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Auto-generated embedding_lookup operation for {t} -// Industry-standard embedding table lookup used in neural networks. -// Each thread handles one index lookup and copies the full embedding row. - -const WORKGROUP_SIZE: u32 = 256u; - -struct EmbeddingLookupParams {{ - num_indices: u32, - vocab_size: u32, - embedding_dim: u32, - _pad0: u32, -}} - -@group(0) @binding(0) var embeddings: array<{t}>; -@group(0) @binding(1) var indices: array; -@group(0) @binding(2) var output: array<{t}>; -@group(0) @binding(3) var params: EmbeddingLookupParams; - -@compute @workgroup_size(256) -fn embedding_lookup_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx >= params.num_indices) {{ - return; - }} - - let index_val = indices[idx]; - - // Check bounds - if (index_val < 0 || u32(index_val) >= params.vocab_size) {{ - // Out of bounds - fill with zeros - let out_start = idx * params.embedding_dim; - for (var i: u32 = 0u; i < params.embedding_dim; i = i + 1u) {{ - output[out_start + i] = {zero}; - }} - return; - }} - - // Copy the entire embedding row to output - let emb_start = u32(index_val) * params.embedding_dim; - let out_start = idx * params.embedding_dim; - for (var i: u32 = 0u; i < params.embedding_dim; i = i + 1u) {{ - output[out_start + i] = embeddings[emb_start + i]; - }} -}} -"#, - t = t, - suffix = suffix, - zero = match dtype { - DType::F32 | DType::F16 => "0.0", - _ => "0", - }, - )) -} - -/// Generate WGSL shader for gather_nd operation. -/// -/// Gathers slices from input using N-dimensional indices. -pub fn generate_gather_nd_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Auto-generated gather_nd operations for {t} - -const WORKGROUP_SIZE: u32 = 256u; -const MAX_DIMS: u32 = 8u; - -struct GatherNdParams {{ - num_slices: u32, - slice_size: u32, - index_depth: u32, - ndim: u32, - input_shape: array, - input_strides: array, -}} - -@group(0) @binding(0) var gather_nd_input: array<{t}>; -@group(0) @binding(1) var gather_nd_indices: array; -@group(0) @binding(2) var gather_nd_output: array<{t}>; -@group(0) @binding(3) var gather_nd_params: GatherNdParams; - -@compute @workgroup_size(256) -fn gather_nd_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - let total = gather_nd_params.num_slices * gather_nd_params.slice_size; - if (idx >= total) {{ - return; - }} - - let slice_idx = idx / gather_nd_params.slice_size; - let element_in_slice = idx % gather_nd_params.slice_size; - - // Compute input offset from indices - var input_offset: u32 = 0u; - let indices_offset = slice_idx * gather_nd_params.index_depth; - - for (var d: u32 = 0u; d < gather_nd_params.index_depth; d = d + 1u) {{ - let coord = gather_nd_indices[indices_offset + d]; - if (coord < 0 || u32(coord) >= gather_nd_params.input_shape[d]) {{ - gather_nd_output[idx] = {zero}; - return; - }} - input_offset = input_offset + u32(coord) * gather_nd_params.input_strides[d]; - }} - - // Add offset for element within slice - if (gather_nd_params.slice_size > 1u) {{ - var remaining = element_in_slice; - for (var d: u32 = gather_nd_params.index_depth; d < gather_nd_params.ndim; d = d + 1u) {{ - let dim_size = gather_nd_params.input_shape[d]; - let coord = remaining / gather_nd_params.input_strides[d]; - remaining = remaining % gather_nd_params.input_strides[d]; - input_offset = input_offset + coord * gather_nd_params.input_strides[d]; - }} - }} - - gather_nd_output[idx] = gather_nd_input[input_offset + element_in_slice]; -}} -"#, - t = t, - suffix = suffix, - zero = match dtype { - DType::F32 | DType::F16 => "0.0", - _ => "0", - }, - )) -} - -/// Generate WGSL shader for bincount operation. -/// -/// Counts occurrences of each value in an integer tensor, optionally with weights. -/// Note: Uses atomic operations for accumulation. -pub fn generate_bincount_shader(weights_dtype: Option) -> Result { - if let Some(dtype) = weights_dtype { - // Weighted bincount - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Auto-generated weighted bincount for {t} - -const WORKGROUP_SIZE: u32 = 256u; - -struct BincountParams {{ - n: u32, - minlength: u32, - _pad0: u32, - _pad1: u32, -}} - -@group(0) @binding(0) var bincount_input: array; -@group(0) @binding(1) var bincount_weights: array<{t}>; -@group(0) @binding(2) var bincount_output: array>; -@group(0) @binding(3) var bincount_params: BincountParams; - -@compute @workgroup_size(256) -fn bincount_weighted_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx >= bincount_params.n) {{ - return; - }} - - let value = bincount_input[idx]; - if (value < 0 || u32(value) >= bincount_params.minlength) {{ - return; - }} - - let weight = bincount_weights[idx]; - // For float weights, we need to use atomic operations - // WebGPU only supports atomic ops on u32/i32, so we use bitcast - let weight_bits = bitcast(weight); - atomicAdd(&bincount_output[u32(value)], weight_bits); -}} -"#, - t = t, - suffix = suffix, - )) - } else { - // Unweighted bincount - Ok(r#"// Auto-generated unweighted bincount - -const WORKGROUP_SIZE: u32 = 256u; - -struct BincountParams { - n: u32, - minlength: u32, - _pad0: u32, - _pad1: u32, -} - -@group(0) @binding(0) var bincount_input: array; -@group(0) @binding(1) var bincount_output: array>; -@group(0) @binding(2) var bincount_params: BincountParams; - -@compute @workgroup_size(256) -fn bincount_i32(@builtin(global_invocation_id) gid: vec3) { - let idx = gid.x; - if (idx >= bincount_params.n) { - return; - } - - let value = bincount_input[idx]; - if (value < 0 || u32(value) >= bincount_params.minlength) { - return; - } - - atomicAdd(&bincount_output[u32(value)], 1u); -} -"# - .to_string()) - } -} - -/// Generate WGSL shader for scatter_reduce operation. -/// -/// Scatters values with reduction (sum, max, min). -/// Note: Uses atomic operations. -pub fn generate_scatter_reduce_shader(dtype: DType, op: &str) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - let atomic_op = match op { - "sum" => "atomicAdd", - "max" => "atomicMax", - "min" => "atomicMin", - _ => { - return Err(crate::error::Error::InvalidArgument { - arg: "op", - reason: format!("scatter_reduce op must be sum, max, or min, got {}", op), - }); - } - }; - - // For f32, we need CAS loops since atomicMax/Min only work on integers - let is_float = matches!(dtype, DType::F32 | DType::F16); - - if is_float && op != "sum" { - // Float max/min requires CAS loop - Ok(format!( - r#"// Auto-generated scatter_reduce_{op} for {t} - -const WORKGROUP_SIZE: u32 = 256u; - -struct ScatterReduceParams {{ - dim: u32, - outer_size: u32, - dim_size: u32, - inner_size: u32, - src_dim_size: u32, - _pad0: u32, - _pad1: u32, - _pad2: u32, -}} - -// Note: All storage buffers use read_write to match the pipeline cache layout. -// The actual access pattern is: src (read), indices (read), dst (read_write). -@group(0) @binding(0) var scatter_src: array<{t}>; -@group(0) @binding(1) var scatter_indices: array; -@group(0) @binding(2) var scatter_dst: array>; -@group(0) @binding(3) var scatter_params: ScatterReduceParams; - -@compute @workgroup_size(256) -fn scatter_reduce_{op}_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - let total = scatter_params.outer_size * scatter_params.src_dim_size * scatter_params.inner_size; - if (idx >= total) {{ - return; - }} - - let inner = idx % scatter_params.inner_size; - let src_dim_idx = (idx / scatter_params.inner_size) % scatter_params.src_dim_size; - let outer = idx / (scatter_params.src_dim_size * scatter_params.inner_size); - - let index_val = scatter_indices[src_dim_idx]; - if (index_val < 0 || u32(index_val) >= scatter_params.dim_size) {{ - return; - }} - - let src_val = scatter_src[idx]; - let dst_idx = outer * scatter_params.dim_size * scatter_params.inner_size + u32(index_val) * scatter_params.inner_size + inner; - - // CAS loop for {op} - var old_bits: u32; - var new_bits: u32; - loop {{ - old_bits = atomicLoad(&scatter_dst[dst_idx]); - let old_val = bitcast(old_bits); - let new_val = {cmp_expr}; - new_bits = bitcast(new_val); - let result = atomicCompareExchangeWeak(&scatter_dst[dst_idx], old_bits, new_bits); - if (result.exchanged) {{ - break; - }} - }} -}} -"#, - t = t, - suffix = suffix, - op = op, - cmp_expr = if op == "max" { - "max(old_val, src_val)" - } else { - "min(old_val, src_val)" - }, - )) - } else if is_float { - // Float sum uses atomicAdd with bitcast - Ok(format!( - r#"// Auto-generated scatter_reduce_sum for {t} - -const WORKGROUP_SIZE: u32 = 256u; - -struct ScatterReduceParams {{ - dim: u32, - outer_size: u32, - dim_size: u32, - inner_size: u32, - src_dim_size: u32, - _pad0: u32, - _pad1: u32, - _pad2: u32, -}} - -// Note: All storage buffers use read_write to match the pipeline cache layout. -// The actual access pattern is: src (read), indices (read), dst (read_write). -@group(0) @binding(0) var scatter_src: array<{t}>; -@group(0) @binding(1) var scatter_indices: array; -@group(0) @binding(2) var scatter_dst: array>; -@group(0) @binding(3) var scatter_params: ScatterReduceParams; - -@compute @workgroup_size(256) -fn scatter_reduce_sum_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - let total = scatter_params.outer_size * scatter_params.src_dim_size * scatter_params.inner_size; - if (idx >= total) {{ - return; - }} - - let inner = idx % scatter_params.inner_size; - let src_dim_idx = (idx / scatter_params.inner_size) % scatter_params.src_dim_size; - let outer = idx / (scatter_params.src_dim_size * scatter_params.inner_size); - - let index_val = scatter_indices[src_dim_idx]; - if (index_val < 0 || u32(index_val) >= scatter_params.dim_size) {{ - return; - }} - - let src_val = scatter_src[idx]; - let dst_idx = outer * scatter_params.dim_size * scatter_params.inner_size + u32(index_val) * scatter_params.inner_size + inner; - - // CAS loop for atomic float add - var old_bits: u32; - var new_bits: u32; - loop {{ - old_bits = atomicLoad(&scatter_dst[dst_idx]); - let old_val = bitcast(old_bits); - let new_val = old_val + src_val; - new_bits = bitcast(new_val); - let result = atomicCompareExchangeWeak(&scatter_dst[dst_idx], old_bits, new_bits); - if (result.exchanged) {{ - break; - }} - }} -}} -"#, - t = t, - suffix = suffix, - )) - } else { - // Integer types can use native atomic ops - Ok(format!( - r#"// Auto-generated scatter_reduce_{op} for {t} - -const WORKGROUP_SIZE: u32 = 256u; - -struct ScatterReduceParams {{ - dim: u32, - outer_size: u32, - dim_size: u32, - inner_size: u32, - src_dim_size: u32, - _pad0: u32, - _pad1: u32, - _pad2: u32, -}} - -@group(0) @binding(0) var scatter_src: array<{t}>; -@group(0) @binding(1) var scatter_indices: array; -@group(0) @binding(2) var scatter_dst: array>; -@group(0) @binding(3) var scatter_params: ScatterReduceParams; - -@compute @workgroup_size(256) -fn scatter_reduce_{op}_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - let total = scatter_params.outer_size * scatter_params.src_dim_size * scatter_params.inner_size; - if (idx >= total) {{ - return; - }} - - let inner = idx % scatter_params.inner_size; - let src_dim_idx = (idx / scatter_params.inner_size) % scatter_params.src_dim_size; - let outer = idx / (scatter_params.src_dim_size * scatter_params.inner_size); - - let index_val = scatter_indices[src_dim_idx]; - if (index_val < 0 || u32(index_val) >= scatter_params.dim_size) {{ - return; - }} - - let src_val = scatter_src[idx]; - let dst_idx = outer * scatter_params.dim_size * scatter_params.inner_size + u32(index_val) * scatter_params.inner_size + inner; - - {atomic_op}(&scatter_dst[dst_idx], src_val); -}} -"#, - t = t, - suffix = suffix, - op = op, - atomic_t = if dtype == DType::I32 { "i32" } else { "u32" }, - atomic_op = atomic_op, - )) - } -} - -/// Generate WGSL shader for scatter_reduce prod operation. -/// -/// Uses CAS loop for atomic multiply (no native atomicMul in WGSL). -/// Only supports F32 (uses bitcast to u32 for atomics). -pub fn generate_scatter_reduce_prod_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - let is_float = matches!(dtype, DType::F32); - - if is_float { - Ok(format!( - r#"// Auto-generated scatter_reduce_prod for {t} - -const WORKGROUP_SIZE: u32 = 256u; - -struct ScatterReduceParams {{ - dim: u32, - outer_size: u32, - dim_size: u32, - inner_size: u32, - src_dim_size: u32, - _pad0: u32, - _pad1: u32, - _pad2: u32, -}} - -@group(0) @binding(0) var scatter_src: array<{t}>; -@group(0) @binding(1) var scatter_indices: array; -@group(0) @binding(2) var scatter_dst: array>; -@group(0) @binding(3) var scatter_params: ScatterReduceParams; - -@compute @workgroup_size(256) -fn scatter_reduce_prod_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - let total = scatter_params.outer_size * scatter_params.src_dim_size * scatter_params.inner_size; - if (idx >= total) {{ - return; - }} - - let inner = idx % scatter_params.inner_size; - let src_dim_idx = (idx / scatter_params.inner_size) % scatter_params.src_dim_size; - let outer = idx / (scatter_params.src_dim_size * scatter_params.inner_size); - - let index_val = scatter_indices[src_dim_idx]; - if (index_val < 0 || u32(index_val) >= scatter_params.dim_size) {{ - return; - }} - - let src_val = scatter_src[idx]; - let dst_idx = outer * scatter_params.dim_size * scatter_params.inner_size + u32(index_val) * scatter_params.inner_size + inner; - - // CAS loop for atomic multiply - var old_bits: u32; - var new_bits: u32; - loop {{ - old_bits = atomicLoad(&scatter_dst[dst_idx]); - let old_val = bitcast(old_bits); - let new_val = old_val * src_val; - new_bits = bitcast(new_val); - let result = atomicCompareExchangeWeak(&scatter_dst[dst_idx], old_bits, new_bits); - if (result.exchanged) {{ - break; - }} - }} -}} -"#, - t = t, - suffix = suffix, - )) - } else { - // Integer prod using CAS loop - let atomic_t = if dtype == DType::I32 { "i32" } else { "u32" }; - Ok(format!( - r#"// Auto-generated scatter_reduce_prod for {t} - -const WORKGROUP_SIZE: u32 = 256u; - -struct ScatterReduceParams {{ - dim: u32, - outer_size: u32, - dim_size: u32, - inner_size: u32, - src_dim_size: u32, - _pad0: u32, - _pad1: u32, - _pad2: u32, -}} - -@group(0) @binding(0) var scatter_src: array<{t}>; -@group(0) @binding(1) var scatter_indices: array; -@group(0) @binding(2) var scatter_dst: array>; -@group(0) @binding(3) var scatter_params: ScatterReduceParams; - -@compute @workgroup_size(256) -fn scatter_reduce_prod_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - let total = scatter_params.outer_size * scatter_params.src_dim_size * scatter_params.inner_size; - if (idx >= total) {{ - return; - }} - - let inner = idx % scatter_params.inner_size; - let src_dim_idx = (idx / scatter_params.inner_size) % scatter_params.src_dim_size; - let outer = idx / (scatter_params.src_dim_size * scatter_params.inner_size); - - let index_val = scatter_indices[src_dim_idx]; - if (index_val < 0 || u32(index_val) >= scatter_params.dim_size) {{ - return; - }} - - let src_val = scatter_src[idx]; - let dst_idx = outer * scatter_params.dim_size * scatter_params.inner_size + u32(index_val) * scatter_params.inner_size + inner; - - // CAS loop for atomic multiply - loop {{ - let old_val = atomicLoad(&scatter_dst[dst_idx]); - let new_val = old_val * src_val; - let result = atomicCompareExchangeWeak(&scatter_dst[dst_idx], old_val, new_val); - if (result.exchanged) {{ - break; - }} - }} -}} -"#, - t = t, - suffix = suffix, - atomic_t = atomic_t, - )) - } -} - -/// Generate WGSL shader for scatter_reduce count (for mean computation). -/// -/// Atomically increments count buffer at scattered positions. -pub fn generate_scatter_reduce_count_shader(dtype: DType) -> Result { - let suffix = dtype_suffix(dtype)?; - - // Count buffer is always u32 (atomic) - Ok(format!( - r#"// Auto-generated scatter_reduce_count for mean computation - -const WORKGROUP_SIZE: u32 = 256u; - -struct ScatterReduceParams {{ - dim: u32, - outer_size: u32, - dim_size: u32, - inner_size: u32, - src_dim_size: u32, - _pad0: u32, - _pad1: u32, - _pad2: u32, -}} - -@group(0) @binding(0) var scatter_indices: array; -@group(0) @binding(1) var scatter_count: array>; -@group(0) @binding(2) var scatter_params: ScatterReduceParams; - -@compute @workgroup_size(256) -fn scatter_reduce_count_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - let total = scatter_params.outer_size * scatter_params.src_dim_size * scatter_params.inner_size; - if (idx >= total) {{ - return; - }} - - let inner = idx % scatter_params.inner_size; - let src_dim_idx = (idx / scatter_params.inner_size) % scatter_params.src_dim_size; - let outer = idx / (scatter_params.src_dim_size * scatter_params.inner_size); - - let index_val = scatter_indices[src_dim_idx]; - if (index_val < 0 || u32(index_val) >= scatter_params.dim_size) {{ - return; - }} - - let dst_idx = outer * scatter_params.dim_size * scatter_params.inner_size + u32(index_val) * scatter_params.inner_size + inner; - - atomicAdd(&scatter_count[dst_idx], 1u); -}} -"#, - suffix = suffix, - )) -} - -/// Generate WGSL shader for scatter_reduce mean divide. -/// -/// Element-wise: output[i] = sum[i] / f32(count[i]). If count == 0, output = 0. -pub fn generate_scatter_reduce_mean_div_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Auto-generated scatter_reduce_mean_div for {t} - -const WORKGROUP_SIZE: u32 = 256u; - -struct MeanDivParams {{ - n: u32, - _pad0: u32, - _pad1: u32, - _pad2: u32, -}} - -@group(0) @binding(0) var mean_sum: array<{t}>; -@group(0) @binding(1) var mean_count: array; -@group(0) @binding(2) var mean_output: array<{t}>; -@group(0) @binding(3) var mean_params: MeanDivParams; - -@compute @workgroup_size(256) -fn scatter_reduce_mean_div_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx >= mean_params.n) {{ - return; - }} - - let c = mean_count[idx]; - if (c > 0u) {{ - mean_output[idx] = mean_sum[idx] / {t}(c); - }} else {{ - mean_output[idx] = {t}(0); - }} -}} -"#, - t = t, - suffix = suffix, - )) -} - -/// Generate WGSL shader for index bounds validation. -/// -/// Validates that all indices are within bounds `[0, dim_size)`. -/// Atomically counts the number of out-of-bounds indices. -/// Returns count in `error_count[0]`. If count > 0, some indices are invalid. -pub fn generate_validate_indices_shader() -> String { - r#"// Auto-generated index bounds validation kernel - -const WORKGROUP_SIZE: u32 = 256u; - -struct ValidateIndicesParams { - index_len: u32, - dim_size: u32, - _pad0: u32, - _pad1: u32, -} - -@group(0) @binding(0) var indices: array; -@group(0) @binding(1) var error_count: atomic; -@group(0) @binding(2) var params: ValidateIndicesParams; - -@compute @workgroup_size(256) -fn validate_indices(@builtin(global_invocation_id) gid: vec3) { - let idx = gid.x; - if (idx >= params.index_len) { - return; - } - - let index_val = indices[idx]; - if (index_val < 0 || u32(index_val) >= params.dim_size) { - atomicAdd(&error_count, 1u); - } -} -"# - .to_string() -} - -/// Generate WGSL shader for gather_2d operation. -/// -/// Gathers elements from a 2D matrix at specific (row, col) positions. -/// Input: input `[nrows, ncols]`, rows `[num_indices]`, cols `[num_indices]` -/// Output: output `[num_indices]` -/// -/// For each index i: `output[i] = input[rows[i], cols[i]]` -pub fn generate_gather_2d_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Auto-generated gather_2d operation for {t} -// Gathers elements from a 2D matrix at (row, col) positions. - -const WORKGROUP_SIZE: u32 = 256u; - -struct Gather2dParams {{ - nrows: u32, - ncols: u32, - num_indices: u32, - _pad: u32, -}} - -@group(0) @binding(0) var input: array<{t}>; -@group(0) @binding(1) var rows: array; -@group(0) @binding(2) var cols: array; -@group(0) @binding(3) var output: array<{t}>; -@group(0) @binding(4) var params: Gather2dParams; - -@compute @workgroup_size(256) -fn gather_2d_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx >= params.num_indices) {{ - return; - }} - - let r = rows[idx]; - let c = cols[idx]; - - // Bounds checking - if (r < 0 || u32(r) >= params.nrows || c < 0 || u32(c) >= params.ncols) {{ - output[idx] = {zero}; - return; - }} - - // Row-major indexing: input[r, c] = input[r * ncols + c] - let input_idx = u32(r) * params.ncols + u32(c); - output[idx] = input[input_idx]; -}} -"#, - t = t, - suffix = suffix, - zero = match dtype { - DType::F32 | DType::F16 => "0.0", - _ => "0", - }, - )) -} diff --git a/src/runtime/wgpu/shaders/generator/masked.rs b/src/runtime/wgpu/shaders/generator/masked.rs deleted file mode 100644 index 0b112bbf..00000000 --- a/src/runtime/wgpu/shaders/generator/masked.rs +++ /dev/null @@ -1,147 +0,0 @@ -//! WGSL shader generation for masked operations (masked_fill and masked_select) - -use super::common::{dtype_suffix, wgsl_type}; -use crate::dtype::DType; -use crate::error::Result; - -/// Generate WGSL shader for masked_fill operation -pub fn generate_masked_fill_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Auto-generated masked_fill operations for {t} - -const WORKGROUP_SIZE: u32 = 256u; - -struct MaskedFillParams {{ - numel: u32, - fill_value: f32, -}} - -@group(0) @binding(0) var input: array<{t}>; -@group(0) @binding(1) var mask: array; -@group(0) @binding(2) var output: array<{t}>; -@group(0) @binding(3) var params: MaskedFillParams; - -@compute @workgroup_size(256) -fn masked_fill_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx >= params.numel) {{ - return; - }} - - if (mask[idx] != 0u) {{ - output[idx] = {t}(params.fill_value); - }} else {{ - output[idx] = input[idx]; - }} -}} -"#, - t = t, - suffix = suffix, - )) -} - -/// Generate WGSL shader for masked_select operation -/// This is a two-phase operation: -/// 1. Count phase: count how many elements are selected (uses atomic) -/// 2. Prefix sum phase: compute exclusive prefix sum of mask -/// 3. Gather phase: copy selected elements to output -pub fn generate_masked_select_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Auto-generated masked_select operations for {t} - -const WORKGROUP_SIZE: u32 = 256u; - -// Phase 1: Count masked elements -struct CountParams {{ - numel: u32, -}} - -@group(0) @binding(0) var count_mask: array; -@group(0) @binding(1) var count_result: atomic; -@group(0) @binding(2) var count_params: CountParams; - -var shared_count: atomic; - -@compute @workgroup_size(256) -fn masked_count(@builtin(global_invocation_id) gid: vec3, - @builtin(local_invocation_id) lid: vec3) {{ - if (lid.x == 0u) {{ - atomicStore(&shared_count, 0u); - }} - workgroupBarrier(); - - var local_count: u32 = 0u; - var i = gid.x; - while (i < count_params.numel) {{ - if (count_mask[i] != 0u) {{ - local_count = local_count + 1u; - }} - i = i + 256u * 256u; // Grid stride - }} - - atomicAdd(&shared_count, local_count); - workgroupBarrier(); - - if (lid.x == 0u) {{ - atomicAdd(&count_result, atomicLoad(&shared_count)); - }} -}} - -// Phase 2: Compute prefix sum (sequential - for small arrays) -struct PrefixSumParams {{ - numel: u32, -}} - -@group(0) @binding(0) var prefix_mask: array; -@group(0) @binding(1) var prefix_sum: array; -@group(0) @binding(2) var prefix_params: PrefixSumParams; - -@compute @workgroup_size(1) -fn masked_prefix_sum(@builtin(global_invocation_id) gid: vec3) {{ - if (gid.x != 0u) {{ - return; - }} - - var sum: u32 = 0u; - for (var i: u32 = 0u; i < prefix_params.numel; i = i + 1u) {{ - prefix_sum[i] = sum; - if (prefix_mask[i] != 0u) {{ - sum = sum + 1u; - }} - }} -}} - -// Phase 3: Gather selected elements -struct SelectParams {{ - numel: u32, -}} - -@group(0) @binding(0) var select_input: array<{t}>; -@group(0) @binding(1) var select_mask: array; -@group(0) @binding(2) var select_prefix: array; -@group(0) @binding(3) var select_output: array<{t}>; -@group(0) @binding(4) var select_params: SelectParams; - -@compute @workgroup_size(256) -fn masked_select_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx >= select_params.numel) {{ - return; - }} - - if (select_mask[idx] != 0u) {{ - let out_idx = select_prefix[idx]; - select_output[out_idx] = select_input[idx]; - }} -}} -"#, - t = t, - suffix = suffix, - )) -} diff --git a/src/runtime/wgpu/shaders/generator/matmul.rs b/src/runtime/wgpu/shaders/generator/matmul.rs deleted file mode 100644 index 0a641465..00000000 --- a/src/runtime/wgpu/shaders/generator/matmul.rs +++ /dev/null @@ -1,282 +0,0 @@ -//! WGSL shader generation for matrix multiplication operations - -use super::common::{dtype_suffix, wgsl_type}; -use crate::dtype::DType; -use crate::error::Result; - -/// Generate WGSL shader for matrix multiplication -pub fn generate_matmul_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Auto-generated matmul operations for {t} - -const TILE_SIZE: u32 = 16u; - -var tile_a: array, 16>; -var tile_b: array, 16>; - -struct MatmulParams {{ - M: u32, - K: u32, - N: u32, - batch_size: u32, -}} - -@group(0) @binding(0) var matmul_a: array<{t}>; -@group(0) @binding(1) var matmul_b: array<{t}>; -@group(0) @binding(2) var matmul_c: array<{t}>; -@group(0) @binding(3) var matmul_params: MatmulParams; - -@compute @workgroup_size(16, 16, 1) -fn matmul_{suffix}(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3) {{ - let M = matmul_params.M; - let K = matmul_params.K; - let N = matmul_params.N; - - let row = group_id.y * TILE_SIZE + local_id.y; - let col = group_id.x * TILE_SIZE + local_id.x; - - var sum: {t} = {zero}; - - let num_tiles = (K + TILE_SIZE - 1u) / TILE_SIZE; - - for (var t_idx: u32 = 0u; t_idx < num_tiles; t_idx = t_idx + 1u) {{ - let a_col = t_idx * TILE_SIZE + local_id.x; - if (row < M && a_col < K) {{ - tile_a[local_id.y][local_id.x] = matmul_a[row * K + a_col]; - }} else {{ - tile_a[local_id.y][local_id.x] = {zero}; - }} - - let b_row = t_idx * TILE_SIZE + local_id.y; - if (b_row < K && col < N) {{ - tile_b[local_id.y][local_id.x] = matmul_b[b_row * N + col]; - }} else {{ - tile_b[local_id.y][local_id.x] = {zero}; - }} - - workgroupBarrier(); - - for (var k: u32 = 0u; k < TILE_SIZE; k = k + 1u) {{ - sum = sum + tile_a[local_id.y][k] * tile_b[k][local_id.x]; - }} - - workgroupBarrier(); - }} - - if (row < M && col < N) {{ - matmul_c[row * N + col] = sum; - }} -}} - -@compute @workgroup_size(16, 16, 1) -fn batched_matmul_{suffix}(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3) {{ - let M = matmul_params.M; - let K = matmul_params.K; - let N = matmul_params.N; - let batch_size = matmul_params.batch_size; - - let batch = group_id.z; - if (batch >= batch_size) {{ - return; - }} - - let row = group_id.y * TILE_SIZE + local_id.y; - let col = group_id.x * TILE_SIZE + local_id.x; - - let a_batch_offset = batch * M * K; - let b_batch_offset = batch * K * N; - let c_batch_offset = batch * M * N; - - var sum: {t} = {zero}; - - let num_tiles = (K + TILE_SIZE - 1u) / TILE_SIZE; - - for (var t_idx: u32 = 0u; t_idx < num_tiles; t_idx = t_idx + 1u) {{ - let a_col = t_idx * TILE_SIZE + local_id.x; - if (row < M && a_col < K) {{ - tile_a[local_id.y][local_id.x] = matmul_a[a_batch_offset + row * K + a_col]; - }} else {{ - tile_a[local_id.y][local_id.x] = {zero}; - }} - - let b_row = t_idx * TILE_SIZE + local_id.y; - if (b_row < K && col < N) {{ - tile_b[local_id.y][local_id.x] = matmul_b[b_batch_offset + b_row * N + col]; - }} else {{ - tile_b[local_id.y][local_id.x] = {zero}; - }} - - workgroupBarrier(); - - for (var k: u32 = 0u; k < TILE_SIZE; k = k + 1u) {{ - sum = sum + tile_a[local_id.y][k] * tile_b[k][local_id.x]; - }} - - workgroupBarrier(); - }} - - if (row < M && col < N) {{ - matmul_c[c_batch_offset + row * N + col] = sum; - }} -}} -"#, - t = t, - suffix = suffix, - zero = match dtype { - DType::F32 | DType::F16 => "0.0", - _ => "0", - }, - )) -} - -/// Generate WGSL shader for fused matrix multiplication with bias addition -/// -/// This implements C = A @ B + bias where: -/// - A has shape `[M, K]` or `[batch, M, K]` -/// - B has shape `[K, N]` or `[batch, K, N]` -/// - bias has shape `[N]` (1D, broadcast across all rows and batches) -/// - C has shape `[M, N]` or `[batch, M, N]` -/// -/// The bias addition is fused into the GEMM epilogue for efficiency, -/// avoiding an extra memory round-trip. -pub fn generate_matmul_bias_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Auto-generated matmul_bias operations for {t} -// C = A @ B + bias (fused epilogue) - -const TILE_SIZE: u32 = 16u; - -var tile_a: array, 16>; -var tile_b: array, 16>; - -struct MatmulBiasParams {{ - M: u32, - K: u32, - N: u32, - batch_size: u32, -}} - -@group(0) @binding(0) var matmul_a: array<{t}>; -@group(0) @binding(1) var matmul_b: array<{t}>; -@group(0) @binding(2) var matmul_bias: array<{t}>; -@group(0) @binding(3) var matmul_c: array<{t}>; -@group(0) @binding(4) var matmul_params: MatmulBiasParams; - -@compute @workgroup_size(16, 16, 1) -fn matmul_bias_{suffix}(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3) {{ - let M = matmul_params.M; - let K = matmul_params.K; - let N = matmul_params.N; - - let row = group_id.y * TILE_SIZE + local_id.y; - let col = group_id.x * TILE_SIZE + local_id.x; - - var sum: {t} = {zero}; - - let num_tiles = (K + TILE_SIZE - 1u) / TILE_SIZE; - - for (var t_idx: u32 = 0u; t_idx < num_tiles; t_idx = t_idx + 1u) {{ - let a_col = t_idx * TILE_SIZE + local_id.x; - if (row < M && a_col < K) {{ - tile_a[local_id.y][local_id.x] = matmul_a[row * K + a_col]; - }} else {{ - tile_a[local_id.y][local_id.x] = {zero}; - }} - - let b_row = t_idx * TILE_SIZE + local_id.y; - if (b_row < K && col < N) {{ - tile_b[local_id.y][local_id.x] = matmul_b[b_row * N + col]; - }} else {{ - tile_b[local_id.y][local_id.x] = {zero}; - }} - - workgroupBarrier(); - - for (var k: u32 = 0u; k < TILE_SIZE; k = k + 1u) {{ - sum = sum + tile_a[local_id.y][k] * tile_b[k][local_id.x]; - }} - - workgroupBarrier(); - }} - - // Fused epilogue: add bias and write result - if (row < M && col < N) {{ - matmul_c[row * N + col] = sum + matmul_bias[col]; - }} -}} - -@compute @workgroup_size(16, 16, 1) -fn batched_matmul_bias_{suffix}(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3) {{ - let M = matmul_params.M; - let K = matmul_params.K; - let N = matmul_params.N; - let batch_size = matmul_params.batch_size; - - let batch = group_id.z; - if (batch >= batch_size) {{ - return; - }} - - let row = group_id.y * TILE_SIZE + local_id.y; - let col = group_id.x * TILE_SIZE + local_id.x; - - let a_batch_offset = batch * M * K; - let b_batch_offset = batch * K * N; - let c_batch_offset = batch * M * N; - - var sum: {t} = {zero}; - - let num_tiles = (K + TILE_SIZE - 1u) / TILE_SIZE; - - for (var t_idx: u32 = 0u; t_idx < num_tiles; t_idx = t_idx + 1u) {{ - let a_col = t_idx * TILE_SIZE + local_id.x; - if (row < M && a_col < K) {{ - tile_a[local_id.y][local_id.x] = matmul_a[a_batch_offset + row * K + a_col]; - }} else {{ - tile_a[local_id.y][local_id.x] = {zero}; - }} - - let b_row = t_idx * TILE_SIZE + local_id.y; - if (b_row < K && col < N) {{ - tile_b[local_id.y][local_id.x] = matmul_b[b_batch_offset + b_row * N + col]; - }} else {{ - tile_b[local_id.y][local_id.x] = {zero}; - }} - - workgroupBarrier(); - - for (var k: u32 = 0u; k < TILE_SIZE; k = k + 1u) {{ - sum = sum + tile_a[local_id.y][k] * tile_b[k][local_id.x]; - }} - - workgroupBarrier(); - }} - - // Fused epilogue: add bias (same bias for all batches) and write result - if (row < M && col < N) {{ - matmul_c[c_batch_offset + row * N + col] = sum + matmul_bias[col]; - }} -}} -"#, - t = t, - suffix = suffix, - zero = match dtype { - DType::F32 | DType::F16 => "0.0", - _ => "0", - }, - )) -} diff --git a/src/runtime/wgpu/shaders/generator/matrix_funcs.rs b/src/runtime/wgpu/shaders/generator/matrix_funcs.rs deleted file mode 100644 index ba84f767..00000000 --- a/src/runtime/wgpu/shaders/generator/matrix_funcs.rs +++ /dev/null @@ -1,397 +0,0 @@ -//! WGSL shader generation for matrix function operations on quasi-triangular matrices. -//! -//! These shaders operate on the Schur form T of a matrix A, where A = Z @ T @ Z^T. -//! The quasi-triangular form has 1x1 blocks (real eigenvalues) and 2x2 blocks -//! (complex conjugate pairs) on the diagonal. - -use super::common::{dtype_suffix, wgsl_type}; -use crate::dtype::DType; -use crate::error::Result; - -/// Generate shader for validating Schur eigenvalues (checking for non-positive real eigenvalues). -/// -/// Returns a tensor with validation results: -/// - `output[0]` = 1.0 if any non-positive real eigenvalue found, 0.0 otherwise -/// - `output[1]` = the first problematic eigenvalue value (if any) -pub fn generate_validate_eigenvalues_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Schur eigenvalue validation for {t} - -const WORKGROUP_SIZE: u32 = 256u; - -struct Params {{ - n: u32, - eps: f32, - _pad1: u32, - _pad2: u32, -}} - -@group(0) @binding(0) var matrix_t: array<{t}>; -@group(0) @binding(1) var result: array<{t}>; // [has_error, error_value] -@group(0) @binding(2) var params: Params; - -// Check if a real eigenvalue is non-positive -fn check_real_eigenvalue(val: {t}, eps: {t}) -> bool {{ - return val <= eps; -}} - -// Check if a 2x2 block represents non-positive real eigenvalues -// For 2x2 block [[a, b], [c, d]], eigenvalues are (a+d)/2 ± sqrt((a-d)²/4 + bc) -// If discriminant < 0, eigenvalues are complex (ok) -// If discriminant >= 0, check if real part is non-positive -fn check_2x2_block(a: {t}, b: {t}, c: {t}, d: {t}, eps: {t}) -> bool {{ - let trace = a + d; - let det = a * d - b * c; - let disc = trace * trace - 4.0 * det; - - if disc < 0.0 {{ - // Complex eigenvalues - check real part - let real_part = trace / 2.0; - return real_part <= eps; - }} else {{ - // Real eigenvalues - let sqrt_disc = sqrt(disc); - let lambda1 = (trace + sqrt_disc) / 2.0; - let lambda2 = (trace - sqrt_disc) / 2.0; - return lambda1 <= eps || lambda2 <= eps; - }} -}} - -@compute @workgroup_size(1) -fn validate_eigenvalues_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let n = params.n; - let eps = {t}(params.eps); - - // Initialize result to "no error" - result[0] = 0.0; - result[1] = 0.0; - - var i: u32 = 0u; - while i < n {{ - let diag_idx = i * n + i; - - // Check if this is a 2x2 block (non-zero sub-diagonal) - if i + 1u < n {{ - let sub_diag = abs(matrix_t[(i + 1u) * n + i]); - if sub_diag > eps {{ - // 2x2 block - let a = matrix_t[i * n + i]; - let b = matrix_t[i * n + (i + 1u)]; - let c = matrix_t[(i + 1u) * n + i]; - let d = matrix_t[(i + 1u) * n + (i + 1u)]; - - if check_2x2_block(a, b, c, d, eps) {{ - result[0] = 1.0; - result[1] = (a + d) / 2.0; // Report real part - return; - }} - i = i + 2u; - continue; - }} - }} - - // 1x1 block (real eigenvalue) - let eigenvalue = matrix_t[diag_idx]; - if check_real_eigenvalue(eigenvalue, eps) {{ - result[0] = 1.0; - result[1] = eigenvalue; - return; - }} - i = i + 1u; - }} -}} -"#, - t = t, - suffix = suffix - )) -} - -/// Generate shader for applying a scalar function to diagonal blocks of quasi-triangular matrix. -/// -/// This handles both 1x1 blocks (real eigenvalues) and 2x2 blocks (complex pairs). -/// The function is specified by `func_type`: "exp", "log", "sqrt". -pub fn generate_diagonal_func_shader(dtype: DType, func_type: &str) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - // Generate the scalar function application - let scalar_func = match func_type { - "exp" => "exp(x)", - "log" => "log(x)", - "sqrt" => "sqrt(x)", - _ => { - return Err(crate::error::Error::InvalidArgument { - arg: "func_type", - reason: format!("Unknown function type: {}", func_type), - }); - } - }; - - // For 2x2 blocks with complex eigenvalues, we need special handling - let block_2x2_func = match func_type { - "exp" => { - r#" - // For 2x2 block with complex eigenvalues a ± bi: - // exp(a ± bi) = exp(a) * (cos(b) ± i*sin(b)) - // Result is [[exp(a)*cos(b), -exp(a)*sin(b)], [exp(a)*sin(b), exp(a)*cos(b)]] - // after similarity transform - let trace = a + d; - let det = a * d - b * c; - let disc = trace * trace - 4.0 * det; - - if disc >= 0.0 { - // Real eigenvalues - diagonalize and apply exp - let sqrt_disc = sqrt(disc); - let lambda1 = (trace + sqrt_disc) / 2.0; - let lambda2 = (trace - sqrt_disc) / 2.0; - let exp1 = exp(lambda1); - let exp2 = exp(lambda2); - - // Simple case: return diagonal exp values - // This is approximate but handles most cases - *f11 = (exp1 + exp2) / 2.0; - *f22 = (exp1 + exp2) / 2.0; - *f12 = (exp1 - exp2) / 2.0 * sign(b); - *f21 = (exp1 - exp2) / 2.0 * sign(c); - } else { - // Complex eigenvalues - let real_part = trace / 2.0; - let imag_part = sqrt(-disc) / 2.0; - let exp_real = exp(real_part); - let cos_imag = cos(imag_part); - let sin_imag = sin(imag_part); - - *f11 = exp_real * cos_imag; - *f22 = exp_real * cos_imag; - // Off-diagonal scaling based on original block structure - let scale = exp_real * sin_imag / imag_part; - *f12 = scale * b; - *f21 = scale * c; - } -"# - } - "log" => { - r#" - let trace = a + d; - let det = a * d - b * c; - let disc = trace * trace - 4.0 * det; - - if disc >= 0.0 { - // Real eigenvalues - let sqrt_disc = sqrt(disc); - let lambda1 = (trace + sqrt_disc) / 2.0; - let lambda2 = (trace - sqrt_disc) / 2.0; - let log1 = log(lambda1); - let log2 = log(lambda2); - - *f11 = (log1 + log2) / 2.0; - *f22 = (log1 + log2) / 2.0; - *f12 = (log1 - log2) / (lambda1 - lambda2) * b; - *f21 = (log1 - log2) / (lambda1 - lambda2) * c; - } else { - // Complex eigenvalues: log(r * e^(i*theta)) = log(r) + i*theta - let real_part = trace / 2.0; - let imag_part = sqrt(-disc) / 2.0; - let r = sqrt(det); // |lambda| = sqrt(det) for conjugate pair - let theta = atan2(imag_part, real_part); - - *f11 = log(r); - *f22 = log(r); - let scale = theta / imag_part; - *f12 = scale * b; - *f21 = scale * c; - } -"# - } - "sqrt" => { - r#" - let trace = a + d; - let det = a * d - b * c; - let disc = trace * trace - 4.0 * det; - - if disc >= 0.0 { - // Real eigenvalues - let sqrt_disc = sqrt(disc); - let lambda1 = (trace + sqrt_disc) / 2.0; - let lambda2 = (trace - sqrt_disc) / 2.0; - let sqrt1 = sqrt(lambda1); - let sqrt2 = sqrt(lambda2); - - *f11 = (sqrt1 + sqrt2) / 2.0; - *f22 = (sqrt1 + sqrt2) / 2.0; - let denom = sqrt1 + sqrt2; - if abs(denom) > 1e-10 { - *f12 = b / denom; - *f21 = c / denom; - } else { - *f12 = 0.0; - *f21 = 0.0; - } - } else { - // Complex eigenvalues - let r = sqrt(det); - let theta = atan2(sqrt(-disc) / 2.0, trace / 2.0); - let sqrt_r = sqrt(r); - let half_theta = theta / 2.0; - - *f11 = sqrt_r * cos(half_theta); - *f22 = sqrt_r * cos(half_theta); - let imag_part = sqrt(-disc) / 2.0; - let scale = sqrt_r * sin(half_theta) / imag_part; - *f12 = scale * b; - *f21 = scale * c; - } -"# - } - _ => unreachable!(), - }; - - Ok(format!( - r#"// Diagonal block function application for {t} - {func_type} - -const WORKGROUP_SIZE: u32 = 256u; - -struct Params {{ - n: u32, - eps: f32, - _pad1: u32, - _pad2: u32, -}} - -@group(0) @binding(0) var input_t: array<{t}>; -@group(0) @binding(1) var output_f: array<{t}>; -@group(0) @binding(2) var params: Params; - -// Apply function to 2x2 block -fn apply_2x2_block(a: {t}, b: {t}, c: {t}, d: {t}, - f11: ptr, f12: ptr, - f21: ptr, f22: ptr) {{ -{block_2x2_func} -}} - -@compute @workgroup_size(1) -fn diagonal_{func_type}_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let n = params.n; - let eps = {t}(params.eps); - - // Initialize output to zero - for (var idx: u32 = 0u; idx < n * n; idx = idx + 1u) {{ - output_f[idx] = 0.0; - }} - - var i: u32 = 0u; - while i < n {{ - // Check if this is a 2x2 block - if i + 1u < n {{ - let sub_diag = abs(input_t[(i + 1u) * n + i]); - if sub_diag > eps {{ - // 2x2 block - let a = input_t[i * n + i]; - let b = input_t[i * n + (i + 1u)]; - let c = input_t[(i + 1u) * n + i]; - let d = input_t[(i + 1u) * n + (i + 1u)]; - - var f11: {t}; - var f12: {t}; - var f21: {t}; - var f22: {t}; - apply_2x2_block(a, b, c, d, &f11, &f12, &f21, &f22); - - output_f[i * n + i] = f11; - output_f[i * n + (i + 1u)] = f12; - output_f[(i + 1u) * n + i] = f21; - output_f[(i + 1u) * n + (i + 1u)] = f22; - - i = i + 2u; - continue; - }} - }} - - // 1x1 block - let x = input_t[i * n + i]; - output_f[i * n + i] = {scalar_func}; - i = i + 1u; - }} -}} -"#, - t = t, - suffix = suffix, - func_type = func_type, - block_2x2_func = block_2x2_func, - scalar_func = scalar_func, - )) -} - -/// Generate shader for computing off-diagonal elements using Parlett's recurrence. -/// -/// For column j, processes rows i < j: -/// `F[i,j] = (T[i,i] - T[j,j])^(-1) * (F[i,j] * T[i,j] - sum_{k=i+1}^{j-1} F[i,k]*T[k,j] + T[i,k]*F[k,j])` -/// -/// This kernel processes one column at a time (called n times). -pub fn generate_parlett_column_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Parlett recurrence for off-diagonal elements - {t} - -const WORKGROUP_SIZE: u32 = 256u; - -struct Params {{ - n: u32, - col: u32, // Current column being processed - eps: f32, - _pad: u32, -}} - -@group(0) @binding(0) var input_t: array<{t}>; -@group(0) @binding(1) var output_f: array<{t}>; -@group(0) @binding(2) var params: Params; - -@compute @workgroup_size(WORKGROUP_SIZE) -fn parlett_column_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let n = params.n; - let j = params.col; - let eps = {t}(params.eps); - - // Each thread handles one row i < j - let i = gid.x; - if i >= j {{ - return; - }} - - let t_ii = input_t[i * n + i]; - let t_jj = input_t[j * n + j]; - let t_ij = input_t[i * n + j]; - - let denom = t_ii - t_jj; - - // Compute the sum term - var sum: {t} = 0.0; - for (var k: u32 = i + 1u; k < j; k = k + 1u) {{ - let f_ik = output_f[i * n + k]; - let t_kj = input_t[k * n + j]; - let t_ik = input_t[i * n + k]; - let f_kj = output_f[k * n + j]; - sum = sum + f_ik * t_kj - t_ik * f_kj; - }} - - let f_ii = output_f[i * n + i]; - let f_jj = output_f[j * n + j]; - - // F[i,j] = (T[i,j] * (F[i,i] - F[j,j]) + sum) / (T[i,i] - T[j,j]) - if abs(denom) > eps {{ - output_f[i * n + j] = (t_ij * (f_ii - f_jj) + sum) / denom; - }} else {{ - // Eigenvalues too close - use limit formula - output_f[i * n + j] = t_ij * f_ii; // Simplified fallback - }} -}} -"#, - t = t, - suffix = suffix, - )) -} diff --git a/src/runtime/wgpu/shaders/generator/mod.rs b/src/runtime/wgpu/shaders/generator/mod.rs deleted file mode 100644 index 36b43740..00000000 --- a/src/runtime/wgpu/shaders/generator/mod.rs +++ /dev/null @@ -1,706 +0,0 @@ -//! WGSL shader generation for multi-dtype support -//! -//! WebGPU's WGSL does not support templates like CUDA/C++. -//! This module generates WGSL shader source code for each dtype. -//! -//! # Supported DTypes -//! -//! | DType | WGSL Type | Notes | -//! |-------|-----------|-------| -//! | F32 | f32 | Always available | -//! | I32 | i32 | Always available | -//! | U32 | u32 | Always available | -//! | F16 | f16 | Requires WebGPU f16 extension | -//! -//! # Architecture -//! -//! ```text -//! generate_binary_shader(DType::F32, "add") → WGSL source with f32 types -//! generate_binary_shader(DType::I32, "add") → WGSL source with i32 types -//! generate_binary_shader(DType::U32, "add") → WGSL source with u32 types -//! ``` -//! -//! Shaders are cached by `(dtype, operation)` key in the pipeline cache. - -pub mod activation; -pub mod binary; -pub mod cast; -pub mod cat; -pub mod common; -pub mod compare; -pub mod complex; -pub mod conv; -pub mod cumulative; -pub mod distributions; -pub mod fft; -pub mod index; -pub mod masked; -pub mod matmul; -pub mod matrix_funcs; -pub mod norm; -pub mod reduce; -pub mod scalar; -pub mod semiring_matmul; -pub mod sort; -#[cfg(feature = "sparse")] -pub mod sparse_algorithms; -#[cfg(feature = "sparse")] -pub mod sparse_conversions; -#[cfg(feature = "sparse")] -pub mod sparse_factorize; -#[cfg(feature = "sparse")] -pub mod sparse_linalg; -#[cfg(feature = "sparse")] -pub mod sparse_merge; -#[cfg(feature = "sparse")] -pub mod sparse_split; -#[cfg(feature = "sparse")] -pub mod sparse_trsv; -#[cfg(feature = "sparse")] -pub mod sparse_utils; -pub mod special; -#[cfg(feature = "sparse")] -pub mod spmv; -pub mod unary; -pub mod utility; -pub mod where_cond; - -pub use activation::generate_clamp_shader; -pub use binary::{generate_binary_shader, generate_broadcast_binary_shader}; -pub use cast::{generate_all_casts_from, generate_cast_shader}; -pub use cat::{ - generate_cat_shader, generate_pad_shader, generate_repeat_shader, generate_roll_shader, -}; -pub use common::{dtype_suffix, is_wgpu_supported, is_wgsl_float, is_wgsl_int, wgsl_type}; -pub use compare::generate_compare_shader; -pub use complex::{ - complex_output_dtype, generate_angle_shader, generate_conj_shader, generate_imag_shader, - generate_real_shader, get_complex_shader_generator, validate_complex_dtype, -}; -pub use conv::{generate_conv1d_shader, generate_conv2d_shader, generate_depthwise_conv2d_shader}; -pub use cumulative::{ - generate_cumprod_shader, generate_cumprod_strided_shader, generate_cumsum_shader, - generate_cumsum_strided_shader, generate_logsumexp_shader, generate_logsumexp_strided_shader, -}; -pub use distributions::{ - generate_bernoulli_shader, generate_beta_dist_shader, generate_binomial_shader, - generate_chi_squared_shader, generate_exponential_shader, generate_f_distribution_shader, - generate_gamma_dist_shader, generate_laplace_shader, generate_multinomial_count_shader, - generate_poisson_shader, generate_student_t_shader, -}; -pub use fft::{ - MAX_WORKGROUP_FFT_SIZE, generate_copy_complex_shader, generate_fftshift_shader, - generate_hermitian_extend_shader, generate_irfft_unpack_shader, generate_rfft_pack_shader, - generate_rfft_truncate_shader, generate_stockham_fft_shader, -}; -pub use index::{ - generate_bincount_shader, generate_embedding_lookup_shader, generate_gather_2d_shader, - generate_gather_nd_shader, generate_gather_shader, generate_index_put_shader, - generate_index_select_shader, generate_scatter_reduce_count_shader, - generate_scatter_reduce_mean_div_shader, generate_scatter_reduce_prod_shader, - generate_scatter_reduce_shader, generate_scatter_shader, generate_validate_indices_shader, -}; -pub use masked::{generate_masked_fill_shader, generate_masked_select_shader}; -pub use matmul::{generate_matmul_bias_shader, generate_matmul_shader}; -pub use matrix_funcs::{ - generate_diagonal_func_shader, generate_parlett_column_shader, - generate_validate_eigenvalues_shader, -}; -pub use norm::generate_norm_shader; -pub use reduce::generate_reduce_shader; -pub use scalar::{generate_fill_shader, generate_scalar_shader}; -pub use sort::{ - MAX_SHARED_SORT_SIZE, generate_count_nonzero_shader, generate_flat_to_multi_index_shader, - generate_gather_nonzero_shader, generate_searchsorted_shader, generate_sort_shader, - generate_topk_shader, generate_unique_shader, generate_unique_with_counts_shader, -}; -// Sparse linear algebra exports from split modules -#[cfg(feature = "sparse")] -pub use sparse_algorithms::{ - generate_dsmm_csc_shader, generate_spgemm_accumulate_shader, generate_spgemm_scatter_shader, - generate_spgemm_symbolic_shader, -}; -#[cfg(feature = "sparse")] -pub use sparse_conversions::{ - generate_coo_to_csc_scatter_shader, generate_coo_to_csr_scatter_shader, - generate_copy_ptrs_shader, generate_count_nonzeros_shader, generate_csc_to_csr_scatter_shader, - generate_csr_to_csc_scatter_shader, generate_csr_to_dense_shader, - generate_dense_to_coo_scatter_shader, generate_expand_col_ptrs_shader, - generate_expand_row_ptrs_shader, generate_histogram_shader, -}; -#[cfg(feature = "sparse")] -pub use sparse_factorize::{generate_ic0_level_shader, generate_ilu0_level_shader}; -#[cfg(feature = "sparse")] -pub use sparse_merge::{ - generate_csc_add_compute_shader, generate_csc_div_compute_shader, - generate_csc_merge_count_shader, generate_csc_mul_compute_shader, - generate_csc_mul_count_shader, generate_csc_sub_compute_shader, - generate_csr_add_compute_shader, generate_csr_div_compute_shader, - generate_csr_merge_count_shader, generate_csr_mul_compute_shader, - generate_csr_mul_count_shader, generate_csr_sub_compute_shader, generate_exclusive_scan_shader, -}; -#[cfg(feature = "sparse")] -pub use sparse_split::{ - generate_extract_lower_count_shader, generate_extract_lower_scatter_shader, - generate_split_lu_count_shader, generate_split_lu_scatter_l_shader, - generate_split_lu_scatter_shader, generate_split_lu_scatter_u_shader, -}; -#[cfg(feature = "sparse")] -pub use sparse_trsv::{generate_sparse_trsv_lower_shader, generate_sparse_trsv_upper_shader}; -#[cfg(feature = "sparse")] -pub use sparse_utils::{generate_copy_shader, generate_find_diag_indices_shader}; -pub use special::{ - generate_special_binary_shader, generate_special_ternary_shader, generate_special_unary_shader, -}; -#[cfg(feature = "sparse")] -pub use spmv::{ - generate_csr_extract_diagonal_shader, generate_csr_spmm_shader, generate_csr_spmv_shader, -}; -pub use unary::generate_unary_shader; -pub use utility::{ - generate_arange_shader, generate_eye_shader, generate_linspace_shader, - generate_multinomial_with_replacement_shader, generate_multinomial_without_replacement_shader, - generate_rand_shader, generate_randint_shader, generate_randn_shader, -}; -pub use where_cond::generate_where_cond_shader; - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_wgsl_type() { - assert_eq!(wgsl_type(crate::dtype::DType::F32).unwrap(), "f32"); - assert_eq!(wgsl_type(crate::dtype::DType::I32).unwrap(), "i32"); - assert_eq!(wgsl_type(crate::dtype::DType::U32).unwrap(), "u32"); - assert!(wgsl_type(crate::dtype::DType::F64).is_err()); // Not supported - } - - #[test] - fn test_generate_binary_shader() { - let shader = generate_binary_shader(crate::dtype::DType::F32).unwrap(); - assert!(shader.contains("fn add_f32")); - assert!(shader.contains("fn sub_f32")); - assert!(shader.contains("fn mul_f32")); - assert!(shader.contains("array")); - } - - #[test] - fn test_generate_binary_shader_i32() { - let shader = generate_binary_shader(crate::dtype::DType::I32).unwrap(); - assert!(shader.contains("fn add_i32")); - assert!(shader.contains("array")); - } - - #[test] - fn test_generate_unary_shader_float() { - let shader = generate_unary_shader(crate::dtype::DType::F32).unwrap(); - assert!(shader.contains("fn sqrt_f32")); - assert!(shader.contains("fn exp_f32")); - assert!(shader.contains("fn relu_f32")); - } - - #[test] - fn test_generate_unary_shader_int() { - let shader = generate_unary_shader(crate::dtype::DType::I32).unwrap(); - assert!(shader.contains("fn neg_i32")); - assert!(shader.contains("fn abs_i32")); - // Float ops should not be present - assert!(!shader.contains("fn sqrt_i32")); - assert!(!shader.contains("fn exp_i32")); - } - - #[test] - fn test_generate_reduce_shader() { - let shader = generate_reduce_shader(crate::dtype::DType::F32).unwrap(); - assert!(shader.contains("fn reduce_sum_f32")); - assert!(shader.contains("fn reduce_max_f32")); - assert!(shader.contains("fn reduce_min_f32")); - } - - #[test] - fn test_generate_matmul_shader() { - let shader = generate_matmul_shader(crate::dtype::DType::F32).unwrap(); - assert!(shader.contains("fn matmul_f32")); - assert!(shader.contains("fn batched_matmul_f32")); - assert!(shader.contains("tile_a")); - assert!(shader.contains("tile_b")); - } - - #[test] - fn test_generate_matmul_bias_shader() { - let shader = generate_matmul_bias_shader(crate::dtype::DType::F32).unwrap(); - assert!(shader.contains("fn matmul_bias_f32")); - assert!(shader.contains("fn batched_matmul_bias_f32")); - assert!(shader.contains("matmul_bias")); // bias buffer binding - assert!(shader.contains("tile_a")); - assert!(shader.contains("tile_b")); - // Verify fused epilogue pattern - assert!(shader.contains("sum + matmul_bias[col]")); - } - - #[test] - fn test_generate_norm_shader() { - let shader = generate_norm_shader(crate::dtype::DType::F32).unwrap(); - assert!(shader.contains("fn rms_norm_f32")); - assert!(shader.contains("fn layer_norm_f32")); - } - - #[test] - fn test_generate_norm_shader_int_fails() { - // Normalization is only for float types - assert!(generate_norm_shader(crate::dtype::DType::I32).is_err()); - } - - #[test] - fn test_generate_compare_shader() { - let shader = generate_compare_shader(crate::dtype::DType::F32).unwrap(); - assert!(shader.contains("fn eq_f32")); - assert!(shader.contains("fn lt_f32")); - assert!(shader.contains("array")); // Output is f32 - } - - // ======================================================================== - // Multi-DType WGSL Syntax Validation Tests - // - // These tests validate that generated shaders are syntactically correct - // WGSL by parsing them with naga. This catches issues like: - // - Float literals in integer contexts (0.0 vs 0) - // - Invalid type casts - // - Missing/incorrect array types - // ======================================================================== - - /// Helper to validate WGSL shader syntax using naga parser (re-exported by wgpu) - fn validate_wgsl_syntax(source: &str) -> std::result::Result<(), String> { - use wgpu::naga::front::wgsl; - let mut frontend = wgsl::Frontend::new(); - frontend - .parse(source) - .map(|_| ()) - .map_err(|e| format!("WGSL parse error: {e}")) - } - - /// All dtypes that WebGPU supports - const WGPU_DTYPES: &[crate::dtype::DType] = &[ - crate::dtype::DType::F32, - crate::dtype::DType::I32, - crate::dtype::DType::U32, - ]; - - #[test] - fn test_binary_shader_syntax_all_dtypes() { - for &dtype in WGPU_DTYPES { - let shader = generate_binary_shader(dtype) - .unwrap_or_else(|_| panic!("Failed to generate binary shader for {:?}", dtype)); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for binary shader {:?}:\n{}\n\nShader:\n{}", - dtype, e, shader - ) - }); - } - } - - #[test] - fn test_unary_shader_syntax_all_dtypes() { - for &dtype in WGPU_DTYPES { - let shader = generate_unary_shader(dtype) - .unwrap_or_else(|_| panic!("Failed to generate unary shader for {:?}", dtype)); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for unary shader {:?}:\n{}\n\nShader:\n{}", - dtype, e, shader - ) - }); - } - } - - #[test] - fn test_scalar_shader_syntax_all_dtypes() { - for &dtype in WGPU_DTYPES { - let shader = generate_scalar_shader(dtype) - .unwrap_or_else(|_| panic!("Failed to generate scalar shader for {:?}", dtype)); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for scalar shader {:?}:\n{}\n\nShader:\n{}", - dtype, e, shader - ) - }); - } - } - - #[test] - fn test_reduce_shader_syntax_all_dtypes() { - for &dtype in WGPU_DTYPES { - let shader = generate_reduce_shader(dtype) - .unwrap_or_else(|_| panic!("Failed to generate reduce shader for {:?}", dtype)); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for reduce shader {:?}:\n{}\n\nShader:\n{}", - dtype, e, shader - ) - }); - } - } - - #[test] - fn test_compare_shader_syntax_all_dtypes() { - for &dtype in WGPU_DTYPES { - let shader = generate_compare_shader(dtype) - .unwrap_or_else(|_| panic!("Failed to generate compare shader for {:?}", dtype)); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for compare shader {:?}:\n{}\n\nShader:\n{}", - dtype, e, shader - ) - }); - } - } - - #[test] - fn test_matmul_shader_syntax_all_dtypes() { - for &dtype in WGPU_DTYPES { - let shader = generate_matmul_shader(dtype) - .unwrap_or_else(|_| panic!("Failed to generate matmul shader for {:?}", dtype)); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for matmul shader {:?}:\n{}\n\nShader:\n{}", - dtype, e, shader - ) - }); - } - } - - #[test] - fn test_matmul_bias_shader_syntax_all_dtypes() { - for &dtype in WGPU_DTYPES { - let shader = generate_matmul_bias_shader(dtype).unwrap_or_else(|_| { - panic!("Failed to generate matmul_bias shader for {:?}", dtype) - }); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for matmul_bias shader {:?}:\n{}\n\nShader:\n{}", - dtype, e, shader - ) - }); - } - } - - #[test] - fn test_norm_shader_syntax_float_only() { - // Norm operations only support float types - let shader = generate_norm_shader(crate::dtype::DType::F32).unwrap(); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for norm shader F32:\n{}\n\nShader:\n{}", - e, shader - ) - }); - } - - #[test] - fn test_fill_shader_syntax_all_dtypes() { - for &dtype in WGPU_DTYPES { - let shader = generate_fill_shader(dtype) - .unwrap_or_else(|_| panic!("Failed to generate fill shader for {:?}", dtype)); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for fill shader {:?}:\n{}\n\nShader:\n{}", - dtype, e, shader - ) - }); - } - } - - #[test] - fn test_integer_shaders_no_float_literals() { - // Verify integer shaders don't contain float literals that would cause type errors - for dtype in [crate::dtype::DType::I32, crate::dtype::DType::U32] { - let unary = generate_unary_shader(dtype).unwrap(); - // Integer shaders should not contain standalone float operations - // The float ops (sqrt, exp, etc.) should be excluded for integers - assert!( - !unary.contains("fn sqrt_"), - "Integer unary shader should not contain sqrt for {:?}", - dtype - ); - assert!( - !unary.contains("fn exp_"), - "Integer unary shader should not contain exp for {:?}", - dtype - ); - } - } - - #[test] - fn test_generate_cast_shader() { - // F32 -> I32 - let shader = - generate_cast_shader(crate::dtype::DType::F32, crate::dtype::DType::I32).unwrap(); - assert!(shader.contains("fn cast_f32_to_i32")); - assert!(shader.contains("array")); - assert!(shader.contains("array")); - - // I32 -> F32 - let shader = - generate_cast_shader(crate::dtype::DType::I32, crate::dtype::DType::F32).unwrap(); - assert!(shader.contains("fn cast_i32_to_f32")); - - // U32 -> F32 - let shader = - generate_cast_shader(crate::dtype::DType::U32, crate::dtype::DType::F32).unwrap(); - assert!(shader.contains("fn cast_u32_to_f32")); - } - - #[test] - fn test_cast_shader_syntax_all_combinations() { - let dtypes = [ - crate::dtype::DType::F32, - crate::dtype::DType::I32, - crate::dtype::DType::U32, - ]; - - for &src in &dtypes { - for &dst in &dtypes { - if src == dst { - continue; - } - - let shader = generate_cast_shader(src, dst).unwrap_or_else(|_| { - panic!("Failed to generate cast shader for {:?} -> {:?}", src, dst) - }); - - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for cast {:?} -> {:?}:\n{}\n\nShader:\n{}", - src, dst, e, shader - ) - }); - } - } - } - - #[test] - fn test_cast_shader_same_type_is_noop() { - let shader = - generate_cast_shader(crate::dtype::DType::F32, crate::dtype::DType::F32).unwrap(); - assert!(shader.contains("No-op")); - assert!(!shader.contains("@compute")); - } - - // ======================================================================== - // Utility Operation Shader Tests (arange, linspace, eye) - // ======================================================================== - - #[test] - fn test_generate_arange_shader_f32() { - let shader = generate_arange_shader(crate::dtype::DType::F32).unwrap(); - assert!(shader.contains("fn arange_f32")); - assert!(shader.contains("array")); - assert!(shader.contains("arange_params")); - } - - #[test] - fn test_arange_shader_syntax_all_dtypes() { - for &dtype in WGPU_DTYPES { - let shader = generate_arange_shader(dtype) - .unwrap_or_else(|_| panic!("Failed to generate arange shader for {:?}", dtype)); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for arange shader {:?}:\n{}\n\nShader:\n{}", - dtype, e, shader - ) - }); - } - } - - #[test] - fn test_generate_linspace_shader_f32() { - let shader = generate_linspace_shader(crate::dtype::DType::F32).unwrap(); - assert!(shader.contains("fn linspace_f32")); - assert!(shader.contains("array")); - assert!(shader.contains("linspace_params")); - } - - #[test] - fn test_linspace_shader_syntax() { - // linspace only supports float types - let shader = generate_linspace_shader(crate::dtype::DType::F32).unwrap(); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for linspace shader F32:\n{}\n\nShader:\n{}", - e, shader - ) - }); - } - - #[test] - fn test_linspace_shader_int_fails() { - // linspace should fail for integer types - assert!(generate_linspace_shader(crate::dtype::DType::I32).is_err()); - assert!(generate_linspace_shader(crate::dtype::DType::U32).is_err()); - } - - #[test] - fn test_generate_eye_shader_f32() { - let shader = generate_eye_shader(crate::dtype::DType::F32).unwrap(); - assert!(shader.contains("fn eye_f32")); - assert!(shader.contains("array")); - assert!(shader.contains("eye_params")); - } - - #[test] - fn test_eye_shader_syntax_all_dtypes() { - for &dtype in WGPU_DTYPES { - let shader = generate_eye_shader(dtype) - .unwrap_or_else(|_| panic!("Failed to generate eye shader for {:?}", dtype)); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for eye shader {:?}:\n{}\n\nShader:\n{}", - dtype, e, shader - ) - }); - } - } - - // ======================================================================== - // Random Operation Shader Tests (rand, randn, randint) - // ======================================================================== - - #[test] - fn test_generate_rand_shader_f32() { - let shader = generate_rand_shader(crate::dtype::DType::F32).unwrap(); - assert!(shader.contains("fn rand_f32")); - assert!(shader.contains("array")); - assert!(shader.contains("rand_params")); - assert!(shader.contains("pcg_hash")); - } - - #[test] - fn test_rand_shader_syntax() { - // rand only supports F32 on WebGPU - let shader = generate_rand_shader(crate::dtype::DType::F32).unwrap(); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for rand shader F32:\n{}\n\nShader:\n{}", - e, shader - ) - }); - } - - #[test] - fn test_rand_shader_int_fails() { - // rand should fail for integer types - assert!(generate_rand_shader(crate::dtype::DType::I32).is_err()); - assert!(generate_rand_shader(crate::dtype::DType::U32).is_err()); - } - - #[test] - fn test_generate_randn_shader_f32() { - let shader = generate_randn_shader(crate::dtype::DType::F32).unwrap(); - assert!(shader.contains("fn randn_f32")); - assert!(shader.contains("array")); - assert!(shader.contains("randn_params")); - assert!(shader.contains("box_muller")); - } - - #[test] - fn test_randn_shader_syntax() { - // randn only supports F32 on WebGPU - let shader = generate_randn_shader(crate::dtype::DType::F32).unwrap(); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for randn shader F32:\n{}\n\nShader:\n{}", - e, shader - ) - }); - } - - #[test] - fn test_randn_shader_int_fails() { - // randn should fail for integer types - assert!(generate_randn_shader(crate::dtype::DType::I32).is_err()); - assert!(generate_randn_shader(crate::dtype::DType::U32).is_err()); - } - - #[test] - fn test_generate_randint_shader_i32() { - let shader = generate_randint_shader(crate::dtype::DType::I32).unwrap(); - assert!(shader.contains("fn randint_i32")); - assert!(shader.contains("array")); - assert!(shader.contains("randint_params")); - } - - #[test] - fn test_generate_randint_shader_u32() { - let shader = generate_randint_shader(crate::dtype::DType::U32).unwrap(); - assert!(shader.contains("fn randint_u32")); - assert!(shader.contains("array")); - } - - #[test] - fn test_randint_shader_syntax_int_dtypes() { - // randint supports I32 and U32 - for dtype in [crate::dtype::DType::I32, crate::dtype::DType::U32] { - let shader = generate_randint_shader(dtype) - .unwrap_or_else(|_| panic!("Failed to generate randint shader for {:?}", dtype)); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for randint shader {:?}:\n{}\n\nShader:\n{}", - dtype, e, shader - ) - }); - } - } - - #[test] - fn test_randint_shader_float_fails() { - // randint should fail for float types - assert!(generate_randint_shader(crate::dtype::DType::F32).is_err()); - } - - // ======================================================================== - // Special Function Shader Tests - // ======================================================================== - - #[test] - fn test_special_unary_shader_syntax() { - let shader = generate_special_unary_shader(crate::dtype::DType::F32).unwrap(); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for special unary shader F32:\n{}\n\nShader:\n{}", - e, shader - ) - }); - } - - #[test] - fn test_special_binary_shader_syntax() { - let shader = generate_special_binary_shader(crate::dtype::DType::F32).unwrap(); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for special binary shader F32:\n{}\n\nShader:\n{}", - e, shader - ) - }); - } - - #[test] - fn test_special_ternary_shader_syntax() { - let shader = generate_special_ternary_shader(crate::dtype::DType::F32).unwrap(); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for special ternary shader F32:\n{}\n\nShader:\n{}", - e, shader - ) - }); - } - - #[test] - fn test_special_shaders_f64_fails() { - // Special functions only support F32 on WebGPU (no F64) - assert!(generate_special_unary_shader(crate::dtype::DType::F64).is_err()); - assert!(generate_special_binary_shader(crate::dtype::DType::F64).is_err()); - assert!(generate_special_ternary_shader(crate::dtype::DType::F64).is_err()); - } -} diff --git a/src/runtime/wgpu/shaders/generator/norm.rs b/src/runtime/wgpu/shaders/generator/norm.rs deleted file mode 100644 index 137985ea..00000000 --- a/src/runtime/wgpu/shaders/generator/norm.rs +++ /dev/null @@ -1,167 +0,0 @@ -//! WGSL shader generation for normalization operations - -use super::common::{dtype_suffix, is_wgsl_float, wgsl_type}; -use crate::dtype::DType; -use crate::error::{Error, Result}; - -/// Generate WGSL shader for normalization operations (float types only) -pub fn generate_norm_shader(dtype: DType) -> Result { - if !is_wgsl_float(dtype) { - return Err(Error::UnsupportedDType { - dtype, - op: "normalization (requires float type)", - }); - } - - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Auto-generated normalization operations for {t} - -const WORKGROUP_SIZE: u32 = 256u; - -var norm_shared: array<{t}, 256>; -var ln_shared_mean: array<{t}, 256>; -var ln_shared_var: array<{t}, 256>; - -struct RmsNormParams {{ - batch_size: u32, - hidden_size: u32, - eps: f32, -}} - -@group(0) @binding(0) var rms_input: array<{t}>; -@group(0) @binding(1) var rms_weight: array<{t}>; -@group(0) @binding(2) var rms_output: array<{t}>; -@group(0) @binding(3) var rms_params: RmsNormParams; - -@compute @workgroup_size(256) -fn rms_norm_{suffix}(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3) {{ - let tid = local_id.x; - let batch_idx = group_id.x; - - if (batch_idx >= rms_params.batch_size) {{ - return; - }} - - let hidden_size = rms_params.hidden_size; - let eps = {t}(rms_params.eps); - let base_offset = batch_idx * hidden_size; - - // Compute sum of squares - var sum_sq: {t} = 0.0; - var i: u32 = tid; - while (i < hidden_size) {{ - let val = rms_input[base_offset + i]; - sum_sq = sum_sq + val * val; - i = i + WORKGROUP_SIZE; - }} - - norm_shared[tid] = sum_sq; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) {{ - if (tid < s) {{ - norm_shared[tid] = norm_shared[tid] + norm_shared[tid + s]; - }} - workgroupBarrier(); - }} - - let rms = sqrt(norm_shared[0] / {t}(hidden_size) + eps); - workgroupBarrier(); - - // Normalize and apply weight - i = tid; - while (i < hidden_size) {{ - rms_output[base_offset + i] = rms_input[base_offset + i] / rms * rms_weight[i]; - i = i + WORKGROUP_SIZE; - }} -}} - -struct LayerNormParams {{ - batch_size: u32, - hidden_size: u32, - eps: f32, -}} - -@group(0) @binding(0) var ln_input: array<{t}>; -@group(0) @binding(1) var ln_weight: array<{t}>; -@group(0) @binding(2) var ln_bias: array<{t}>; -@group(0) @binding(3) var ln_output: array<{t}>; -@group(0) @binding(4) var ln_params: LayerNormParams; - -@compute @workgroup_size(256) -fn layer_norm_{suffix}(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3) {{ - let tid = local_id.x; - let batch_idx = group_id.x; - - if (batch_idx >= ln_params.batch_size) {{ - return; - }} - - let hidden_size = ln_params.hidden_size; - let eps = {t}(ln_params.eps); - let base_offset = batch_idx * hidden_size; - - // Compute mean - var sum: {t} = 0.0; - var i: u32 = tid; - while (i < hidden_size) {{ - sum = sum + ln_input[base_offset + i]; - i = i + WORKGROUP_SIZE; - }} - - ln_shared_mean[tid] = sum; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) {{ - if (tid < s) {{ - ln_shared_mean[tid] = ln_shared_mean[tid] + ln_shared_mean[tid + s]; - }} - workgroupBarrier(); - }} - - let mean_val = ln_shared_mean[0] / {t}(hidden_size); - workgroupBarrier(); - - // Compute variance - var var_sum: {t} = 0.0; - i = tid; - while (i < hidden_size) {{ - let diff = ln_input[base_offset + i] - mean_val; - var_sum = var_sum + diff * diff; - i = i + WORKGROUP_SIZE; - }} - - ln_shared_var[tid] = var_sum; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) {{ - if (tid < s) {{ - ln_shared_var[tid] = ln_shared_var[tid] + ln_shared_var[tid + s]; - }} - workgroupBarrier(); - }} - - let variance = ln_shared_var[0] / {t}(hidden_size); - let inv_std = 1.0 / sqrt(variance + eps); - workgroupBarrier(); - - // Normalize and apply affine - i = tid; - while (i < hidden_size) {{ - let normalized = (ln_input[base_offset + i] - mean_val) * inv_std; - ln_output[base_offset + i] = normalized * ln_weight[i] + ln_bias[i]; - i = i + WORKGROUP_SIZE; - }} -}} -"#, - t = t, - suffix = suffix - )) -} diff --git a/src/runtime/wgpu/shaders/generator/reduce.rs b/src/runtime/wgpu/shaders/generator/reduce.rs deleted file mode 100644 index d57d3a40..00000000 --- a/src/runtime/wgpu/shaders/generator/reduce.rs +++ /dev/null @@ -1,162 +0,0 @@ -//! WGSL shader generation for reduction operations - -use super::common::{dtype_suffix, wgsl_type}; -use crate::dtype::DType; -use crate::error::Result; - -/// Generate WGSL shader for reduction operations -pub fn generate_reduce_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - // Workgroup shared memory for reductions - Ok(format!( - r#"// Auto-generated reduce operations for {t} - -const WORKGROUP_SIZE: u32 = 256u; - -var reduce_shared: array<{t}, 256>; - -struct ReduceParams {{ - reduce_size: u32, - outer_size: u32, - inner_size: u32, - numel_out: u32, -}} - -@group(0) @binding(0) var reduce_input: array<{t}>; -@group(0) @binding(1) var reduce_output: array<{t}>; -@group(0) @binding(2) var reduce_params: ReduceParams; - -@compute @workgroup_size(256) -fn reduce_sum_{suffix}(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3) {{ - let tid = local_id.x; - let outer_idx = group_id.x; - - if (outer_idx >= reduce_params.outer_size) {{ - return; - }} - - let reduce_size = reduce_params.reduce_size; - let base_offset = outer_idx * reduce_size; - - // Each thread accumulates multiple elements - var sum: {t} = {zero}; - var i: u32 = tid; - while (i < reduce_size) {{ - sum = sum + reduce_input[base_offset + i]; - i = i + WORKGROUP_SIZE; - }} - - reduce_shared[tid] = sum; - workgroupBarrier(); - - // Tree reduction in shared memory - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) {{ - if (tid < s) {{ - reduce_shared[tid] = reduce_shared[tid] + reduce_shared[tid + s]; - }} - workgroupBarrier(); - }} - - if (tid == 0u) {{ - reduce_output[outer_idx] = reduce_shared[0]; - }} -}} - -@compute @workgroup_size(256) -fn reduce_max_{suffix}(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3) {{ - let tid = local_id.x; - let outer_idx = group_id.x; - - if (outer_idx >= reduce_params.outer_size) {{ - return; - }} - - let reduce_size = reduce_params.reduce_size; - let base_offset = outer_idx * reduce_size; - - var max_val: {t} = {min_val}; - var i: u32 = tid; - while (i < reduce_size) {{ - max_val = max(max_val, reduce_input[base_offset + i]); - i = i + WORKGROUP_SIZE; - }} - - reduce_shared[tid] = max_val; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) {{ - if (tid < s) {{ - reduce_shared[tid] = max(reduce_shared[tid], reduce_shared[tid + s]); - }} - workgroupBarrier(); - }} - - if (tid == 0u) {{ - reduce_output[outer_idx] = reduce_shared[0]; - }} -}} - -@compute @workgroup_size(256) -fn reduce_min_{suffix}(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3) {{ - let tid = local_id.x; - let outer_idx = group_id.x; - - if (outer_idx >= reduce_params.outer_size) {{ - return; - }} - - let reduce_size = reduce_params.reduce_size; - let base_offset = outer_idx * reduce_size; - - var min_val: {t} = {max_val}; - var i: u32 = tid; - while (i < reduce_size) {{ - min_val = min(min_val, reduce_input[base_offset + i]); - i = i + WORKGROUP_SIZE; - }} - - reduce_shared[tid] = min_val; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) {{ - if (tid < s) {{ - reduce_shared[tid] = min(reduce_shared[tid], reduce_shared[tid + s]); - }} - workgroupBarrier(); - }} - - if (tid == 0u) {{ - reduce_output[outer_idx] = reduce_shared[0]; - }} -}} -"#, - t = t, - suffix = suffix, - zero = match dtype { - DType::F32 | DType::F16 => "0.0", - _ => "0", - }, - min_val = match dtype { - DType::F32 => "-3.402823e+38", // -FLT_MAX - DType::F16 => "-65504.0", - DType::I32 => "-2147483648", - DType::U32 => "0u", - _ => "0", - }, - max_val = match dtype { - DType::F32 => "3.402823e+38", // FLT_MAX - DType::F16 => "65504.0", - DType::I32 => "2147483647", - DType::U32 => "4294967295u", - _ => "0", - }, - )) -} diff --git a/src/runtime/wgpu/shaders/generator/scalar.rs b/src/runtime/wgpu/shaders/generator/scalar.rs deleted file mode 100644 index f234fe7d..00000000 --- a/src/runtime/wgpu/shaders/generator/scalar.rs +++ /dev/null @@ -1,162 +0,0 @@ -//! WGSL shader generation for scalar element-wise operations and fill operations - -use super::common::{dtype_suffix, is_wgsl_float, wgsl_type}; -use crate::dtype::DType; -use crate::error::Result; - -/// Generate WGSL shader for scalar element-wise operations -pub fn generate_scalar_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - let float_ops = if is_wgsl_float(dtype) { - format!( - r#" -@compute @workgroup_size(256) -fn pow_scalar_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < scalar_params.numel) {{ - scalar_out[idx] = pow(scalar_a[idx], {t}(scalar_params.scalar)); - }} -}} - -// Leaky ReLU: max(negative_slope * x, x) -@compute @workgroup_size(256) -fn leaky_relu_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < scalar_params.numel) {{ - let x = scalar_a[idx]; - let slope = {t}(scalar_params.scalar); - scalar_out[idx] = max(slope * x, x); - }} -}} - -// ELU: x if x > 0, else alpha * (exp(x) - 1) -@compute @workgroup_size(256) -fn elu_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < scalar_params.numel) {{ - let x = scalar_a[idx]; - let alpha = {t}(scalar_params.scalar); - scalar_out[idx] = select(alpha * (exp(x) - 1.0), x, x > 0.0); - }} -}} -"#, - suffix = suffix, - t = t - ) - } else { - // Integer pow_scalar - format!( - r#" -@compute @workgroup_size(256) -fn pow_scalar_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < scalar_params.numel) {{ - var base = scalar_a[idx]; - var exp = {t}(scalar_params.scalar); - var result: {t} = 1; - for (var i: {t} = 0; i < exp; i = i + 1) {{ - result = result * base; - }} - scalar_out[idx] = result; - }} -}} -"#, - suffix = suffix, - t = t - ) - }; - - Ok(format!( - r#"// Auto-generated scalar operations for {t} - -const WORKGROUP_SIZE: u32 = 256u; - -struct ScalarParams {{ - numel: u32, - scalar: f32, // Always f32 for uniform, cast in shader -}} - -@group(0) @binding(0) var scalar_a: array<{t}>; -@group(0) @binding(1) var scalar_out: array<{t}>; -@group(0) @binding(2) var scalar_params: ScalarParams; - -@compute @workgroup_size(256) -fn add_scalar_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < scalar_params.numel) {{ - scalar_out[idx] = scalar_a[idx] + {t}(scalar_params.scalar); - }} -}} - -@compute @workgroup_size(256) -fn sub_scalar_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < scalar_params.numel) {{ - scalar_out[idx] = scalar_a[idx] - {t}(scalar_params.scalar); - }} -}} - -@compute @workgroup_size(256) -fn rsub_scalar_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < scalar_params.numel) {{ - scalar_out[idx] = {t}(scalar_params.scalar) - scalar_a[idx]; - }} -}} - -@compute @workgroup_size(256) -fn mul_scalar_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < scalar_params.numel) {{ - scalar_out[idx] = scalar_a[idx] * {t}(scalar_params.scalar); - }} -}} - -@compute @workgroup_size(256) -fn div_scalar_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < scalar_params.numel) {{ - scalar_out[idx] = scalar_a[idx] / {t}(scalar_params.scalar); - }} -}} - -{float_ops} -"#, - t = t, - suffix = suffix, - float_ops = float_ops - )) -} - -/// Generate WGSL shader for fill operation (set all elements to a constant value) -pub fn generate_fill_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Auto-generated fill operation for {t} - -const WORKGROUP_SIZE: u32 = 256u; - -struct FillParams {{ - numel: u32, - value: f32, // Always f32 for uniform, cast in shader -}} - -@group(0) @binding(0) var fill_out: array<{t}>; -@group(0) @binding(1) var fill_params: FillParams; - -@compute @workgroup_size(256) -fn fill_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < fill_params.numel) {{ - fill_out[idx] = {t}(fill_params.value); - }} -}} -"#, - t = t, - suffix = suffix - )) -} diff --git a/src/runtime/wgpu/shaders/generator/semiring_matmul.rs b/src/runtime/wgpu/shaders/generator/semiring_matmul.rs deleted file mode 100644 index 835c4a96..00000000 --- a/src/runtime/wgpu/shaders/generator/semiring_matmul.rs +++ /dev/null @@ -1,197 +0,0 @@ -//! WGSL shader generation for semiring matrix multiplication - -use super::common::{dtype_suffix, wgsl_type}; -use crate::dtype::DType; -use crate::error::Result; -use crate::ops::semiring::SemiringOp; - -/// Generate WGSL shader for semiring matrix multiplication. -/// -/// Unlike standard matmul which uses (+, ×), semiring matmul uses -/// a custom (reduce, combine) pair. The shader is generated per (dtype, op) -/// combination with the operations baked in as WGSL functions. -/// -/// Uses a simple one-thread-per-output-element approach (no shared-memory -/// tiling) because semiring operations don't distribute like (+, ×). -pub fn generate_semiring_matmul_shader(dtype: DType, op: SemiringOp) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - let op_name = semiring_op_name(op); - - let is_float = matches!(dtype, DType::F32 | DType::F16); - - let (identity, combine_expr, reduce_expr) = semiring_wgsl_ops(op, is_float); - - Ok(format!( - r#"// Auto-generated semiring matmul: {op_name} for {t} -// C[i,j] = reduce_k( combine(A[i,k], B[k,j]) ) - -struct SemiringMatmulParams {{ - M: u32, - K: u32, - N: u32, - batch_size: u32, -}} - -@group(0) @binding(0) var sr_a: array<{t}>; -@group(0) @binding(1) var sr_b: array<{t}>; -@group(0) @binding(2) var sr_c: array<{t}>; -@group(0) @binding(3) var sr_params: SemiringMatmulParams; - -fn sr_combine(a: {t}, b: {t}) -> {t} {{ - {combine_expr} -}} - -fn sr_reduce(acc: {t}, val: {t}) -> {t} {{ - {reduce_expr} -}} - -@compute @workgroup_size(16, 16, 1) -fn semiring_matmul_{op_name}_{suffix}( - @builtin(global_invocation_id) global_id: vec3 -) {{ - let M = sr_params.M; - let K = sr_params.K; - let N = sr_params.N; - - let row = global_id.y; - let col = global_id.x; - - if (row >= M || col >= N) {{ - return; - }} - - var acc: {t} = {identity}; - - for (var kk: u32 = 0u; kk < K; kk = kk + 1u) {{ - let a_val = sr_a[row * K + kk]; - let b_val = sr_b[kk * N + col]; - acc = sr_reduce(acc, sr_combine(a_val, b_val)); - }} - - sr_c[row * N + col] = acc; -}} - -@compute @workgroup_size(16, 16, 1) -fn batched_semiring_matmul_{op_name}_{suffix}( - @builtin(global_invocation_id) global_id: vec3 -) {{ - let M = sr_params.M; - let K = sr_params.K; - let N = sr_params.N; - let batch_size = sr_params.batch_size; - - let batch = global_id.z; - if (batch >= batch_size) {{ - return; - }} - - let row = global_id.y; - let col = global_id.x; - - if (row >= M || col >= N) {{ - return; - }} - - let a_offset = batch * M * K; - let b_offset = batch * K * N; - let c_offset = batch * M * N; - - var acc: {t} = {identity}; - - for (var kk: u32 = 0u; kk < K; kk = kk + 1u) {{ - let a_val = sr_a[a_offset + row * K + kk]; - let b_val = sr_b[b_offset + kk * N + col]; - acc = sr_reduce(acc, sr_combine(a_val, b_val)); - }} - - sr_c[c_offset + row * N + col] = acc; -}} -"#, - t = t, - suffix = suffix, - op_name = op_name, - identity = identity, - combine_expr = combine_expr, - reduce_expr = reduce_expr, - )) -} - -fn semiring_op_name(op: SemiringOp) -> &'static str { - match op { - SemiringOp::MinPlus => "min_plus", - SemiringOp::MaxPlus => "max_plus", - SemiringOp::MaxMin => "max_min", - SemiringOp::MinMax => "min_max", - SemiringOp::OrAnd => "or_and", - SemiringOp::PlusMax => "plus_max", - } -} - -/// Returns (identity, combine_expr, reduce_expr) as WGSL code strings. -fn semiring_wgsl_ops(op: SemiringOp, is_float: bool) -> (&'static str, &'static str, &'static str) { - match op { - // KEEP IN SYNC: ops/semiring.rs reduce_identity_f64(), cuda/kernels/semiring_matmul.cu - SemiringOp::MinPlus => { - // reduce=min, identity=+inf - let identity = if is_float { - "bitcast(0x7f800000u)" - } else { - "2147483647" - }; - (identity, "return a + b;", "return min(acc, val);") - } - SemiringOp::MaxPlus => { - // reduce=max, identity=-inf - let identity = if is_float { - "bitcast(0xff800000u)" - } else { - "-2147483647" - }; - (identity, "return a + b;", "return max(acc, val);") - } - SemiringOp::MaxMin => { - // reduce=max, identity=-inf - let identity = if is_float { - "bitcast(0xff800000u)" - } else { - "-2147483647" - }; - (identity, "return min(a, b);", "return max(acc, val);") - } - SemiringOp::MinMax => { - // reduce=min, identity=+inf - let identity = if is_float { - "bitcast(0x7f800000u)" - } else { - "2147483647" - }; - (identity, "return max(a, b);", "return min(acc, val);") - } - SemiringOp::OrAnd => { - let zero = if is_float { "0.0" } else { "0" }; - // OrAnd: combine=AND, reduce=OR - // We inline the logic since we need conditional expressions - // combine: (a != 0 && b != 0) ? 1 : 0 - // reduce: (acc != 0 || val != 0) ? 1 : 0 - // But WGSL doesn't have ternary, so we use select() - ( - zero, - if is_float { - "return select(0.0, 1.0, a != 0.0 && b != 0.0);" - } else { - "return select(0, 1, a != 0 && b != 0);" - }, - if is_float { - "return select(0.0, 1.0, acc != 0.0 || val != 0.0);" - } else { - "return select(0, 1, acc != 0 || val != 0);" - }, - ) - } - SemiringOp::PlusMax => { - let zero = if is_float { "0.0" } else { "0" }; - (zero, "return max(a, b);", "return acc + val;") - } - } -} diff --git a/src/runtime/wgpu/shaders/generator/sort.rs b/src/runtime/wgpu/shaders/generator/sort.rs deleted file mode 100644 index 79b94a93..00000000 --- a/src/runtime/wgpu/shaders/generator/sort.rs +++ /dev/null @@ -1,864 +0,0 @@ -//! WGSL shader generation for sorting operations -//! -//! Provides bitonic sort implementation for GPU-accelerated sorting. -//! Supports sort, argsort, topk, unique, nonzero, and searchsorted operations. - -use super::common::{dtype_suffix, wgsl_type}; -use crate::dtype::DType; -use crate::error::{Error, Result}; - -/// Maximum sort size for shared memory (power of 2) -pub const MAX_SHARED_SORT_SIZE: usize = 512; - -/// Generate WGSL shader for sort operations -pub fn generate_sort_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - let (min_val, max_val) = match dtype { - DType::F32 => ("-3.402823e+38", "3.402823e+38"), - DType::I32 => ("-2147483648", "2147483647"), - DType::U32 => ("0u", "4294967295u"), - _ => return Err(Error::UnsupportedDType { dtype, op: "sort" }), - }; - - // Comparison function depends on type - let cmp_less = match dtype { - DType::F32 => "a < b", - DType::I32 => "a < b", - DType::U32 => "a < b", - _ => "a < b", - }; - - Ok(format!( - r#"// Auto-generated sort operations for {t} - -const WORKGROUP_SIZE: u32 = 256u; -const MAX_SORT_SIZE: u32 = 512u; - -var shared_vals: array<{t}, 512>; -var shared_idxs: array; - -struct SortParams {{ - outer_size: u32, - sort_size: u32, - inner_size: u32, - descending: u32, -}} - -struct TopkParams {{ - outer_size: u32, - sort_size: u32, - inner_size: u32, - k: u32, - largest: u32, - sorted: u32, -}} - -struct SearchsortedParams {{ - seq_len: u32, - num_values: u32, - right: u32, - _pad: u32, -}} - -struct CountParams {{ - numel: u32, -}} - -@group(0) @binding(0) var sort_input: array<{t}>; -@group(0) @binding(1) var sort_output: array<{t}>; -@group(0) @binding(2) var sort_indices: array; -@group(0) @binding(3) var sort_params: SortParams; - -// Comparison helper -fn compare_less_{suffix}(a: {t}, b: {t}) -> bool {{ - return {cmp_less}; -}} - -// Bitonic compare and swap for sort with indices -fn bitonic_cas_{suffix}(i: u32, j: u32, dir: bool) {{ - let vi = shared_vals[i]; - let vj = shared_vals[j]; - let swap = select(compare_less_{suffix}(vi, vj), compare_less_{suffix}(vj, vi), dir); - if (swap) {{ - shared_vals[i] = vj; - shared_vals[j] = vi; - let ti = shared_idxs[i]; - shared_idxs[i] = shared_idxs[j]; - shared_idxs[j] = ti; - }} -}} - -// Bitonic compare and swap for sort values only -fn bitonic_cas_values_{suffix}(i: u32, j: u32, dir: bool) {{ - let vi = shared_vals[i]; - let vj = shared_vals[j]; - let swap = select(compare_less_{suffix}(vi, vj), compare_less_{suffix}(vj, vi), dir); - if (swap) {{ - shared_vals[i] = vj; - shared_vals[j] = vi; - }} -}} - -// Sort with indices - returns both sorted values and original indices -@compute @workgroup_size(256) -fn sort_{suffix}( - @builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3 -) {{ - let outer_idx = group_id.x; - let inner_idx = group_id.y; - let tid = local_id.x; - - let outer_size = sort_params.outer_size; - let sort_size = sort_params.sort_size; - let inner_size = sort_params.inner_size; - let descending = sort_params.descending != 0u; - - if (outer_idx >= outer_size || inner_idx >= inner_size) {{ - return; - }} - - // Pad to next power of 2 - var n = sort_size; - var p: u32 = 1u; - while (p < n) {{ - p = p << 1u; - }} - n = min(p, MAX_SORT_SIZE); - - // Load data into shared memory - let base_offset = outer_idx * sort_size * inner_size + inner_idx; - for (var i = tid; i < n; i = i + WORKGROUP_SIZE) {{ - if (i < sort_size) {{ - let idx = base_offset + i * inner_size; - shared_vals[i] = sort_input[idx]; - shared_idxs[i] = i32(i); - }} else {{ - // Pad with max/min based on sort direction - shared_vals[i] = select({t}({max_val}), {t}({min_val}), descending); - shared_idxs[i] = i32(i); - }} - }} - workgroupBarrier(); - - // Bitonic sort - for (var k: u32 = 2u; k <= n; k = k << 1u) {{ - for (var j: u32 = k >> 1u; j > 0u; j = j >> 1u) {{ - for (var i = tid; i < n / 2u; i = i + WORKGROUP_SIZE) {{ - // Calculate bitonic network indices - let ij = (i / j) * 2u * j + (i % j); - let ij_pair = ij + j; - - // Direction depends on which half of the network we're in - let ascending_local = ((ij / k) % 2u == 0u) != descending; - - if (ij_pair < n) {{ - bitonic_cas_{suffix}(ij, ij_pair, ascending_local); - }} - }} - workgroupBarrier(); - }} - }} - - // Write sorted values and indices - for (var i = tid; i < sort_size; i = i + WORKGROUP_SIZE) {{ - let out_idx = base_offset + i * inner_size; - sort_output[out_idx] = shared_vals[i]; - sort_indices[out_idx] = shared_idxs[i]; - }} -}} - -// Sort values only (no indices) -@compute @workgroup_size(256) -fn sort_values_only_{suffix}( - @builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3 -) {{ - let outer_idx = group_id.x; - let inner_idx = group_id.y; - let tid = local_id.x; - - let outer_size = sort_params.outer_size; - let sort_size = sort_params.sort_size; - let inner_size = sort_params.inner_size; - let descending = sort_params.descending != 0u; - - if (outer_idx >= outer_size || inner_idx >= inner_size) {{ - return; - }} - - var n = sort_size; - var p: u32 = 1u; - while (p < n) {{ - p = p << 1u; - }} - n = min(p, MAX_SORT_SIZE); - - let base_offset = outer_idx * sort_size * inner_size + inner_idx; - for (var i = tid; i < n; i = i + WORKGROUP_SIZE) {{ - if (i < sort_size) {{ - let idx = base_offset + i * inner_size; - shared_vals[i] = sort_input[idx]; - }} else {{ - shared_vals[i] = select({t}({max_val}), {t}({min_val}), descending); - }} - }} - workgroupBarrier(); - - // Bitonic sort - for (var k: u32 = 2u; k <= n; k = k << 1u) {{ - for (var j: u32 = k >> 1u; j > 0u; j = j >> 1u) {{ - for (var i = tid; i < n / 2u; i = i + WORKGROUP_SIZE) {{ - // Calculate bitonic network indices - let ij = (i / j) * 2u * j + (i % j); - let ij_pair = ij + j; - - // Direction depends on which half of the network we're in - let ascending_local = ((ij / k) % 2u == 0u) != descending; - - if (ij_pair < n) {{ - bitonic_cas_values_{suffix}(ij, ij_pair, ascending_local); - }} - }} - workgroupBarrier(); - }} - }} - - for (var i = tid; i < sort_size; i = i + WORKGROUP_SIZE) {{ - let out_idx = base_offset + i * inner_size; - sort_output[out_idx] = shared_vals[i]; - }} -}} - -// Argsort - returns indices only -@compute @workgroup_size(256) -fn argsort_{suffix}( - @builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3 -) {{ - let outer_idx = group_id.x; - let inner_idx = group_id.y; - let tid = local_id.x; - - let outer_size = sort_params.outer_size; - let sort_size = sort_params.sort_size; - let inner_size = sort_params.inner_size; - let descending = sort_params.descending != 0u; - - if (outer_idx >= outer_size || inner_idx >= inner_size) {{ - return; - }} - - var n = sort_size; - var p: u32 = 1u; - while (p < n) {{ - p = p << 1u; - }} - n = min(p, MAX_SORT_SIZE); - - let base_offset = outer_idx * sort_size * inner_size + inner_idx; - for (var i = tid; i < n; i = i + WORKGROUP_SIZE) {{ - if (i < sort_size) {{ - let idx = base_offset + i * inner_size; - shared_vals[i] = sort_input[idx]; - shared_idxs[i] = i32(i); - }} else {{ - shared_vals[i] = select({t}({max_val}), {t}({min_val}), descending); - shared_idxs[i] = i32(i); - }} - }} - workgroupBarrier(); - - // Bitonic sort - for (var k: u32 = 2u; k <= n; k = k << 1u) {{ - for (var j: u32 = k >> 1u; j > 0u; j = j >> 1u) {{ - for (var i = tid; i < n / 2u; i = i + WORKGROUP_SIZE) {{ - // Calculate bitonic network indices - let ij = (i / j) * 2u * j + (i % j); - let ij_pair = ij + j; - - // Direction depends on which half of the network we're in - let ascending_local = ((ij / k) % 2u == 0u) != descending; - - if (ij_pair < n) {{ - bitonic_cas_{suffix}(ij, ij_pair, ascending_local); - }} - }} - workgroupBarrier(); - }} - }} - - // Write indices only - for (var i = tid; i < sort_size; i = i + WORKGROUP_SIZE) {{ - let out_idx = base_offset + i * inner_size; - sort_indices[out_idx] = shared_idxs[i]; - }} -}} -"#, - t = t, - suffix = suffix, - min_val = min_val, - max_val = max_val, - cmp_less = cmp_less, - )) -} - -/// Generate WGSL shader for topk operation -pub fn generate_topk_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - let (min_val, max_val) = match dtype { - DType::F32 => ("-3.402823e+38", "3.402823e+38"), - DType::I32 => ("-2147483648", "2147483647"), - DType::U32 => ("0u", "4294967295u"), - _ => return Err(Error::UnsupportedDType { dtype, op: "topk" }), - }; - - let cmp_less = match dtype { - DType::F32 => "a < b", - DType::I32 => "a < b", - DType::U32 => "a < b", - _ => "a < b", - }; - - Ok(format!( - r#"// Auto-generated topk operations for {t} - -const WORKGROUP_SIZE: u32 = 256u; -const MAX_SORT_SIZE: u32 = 512u; - -var shared_vals: array<{t}, 512>; -var shared_idxs: array; - -struct TopkParams {{ - outer_size: u32, - sort_size: u32, - inner_size: u32, - k: u32, - largest: u32, - sorted: u32, -}} - -@group(0) @binding(0) var topk_input: array<{t}>; -@group(0) @binding(1) var topk_values: array<{t}>; -@group(0) @binding(2) var topk_indices: array; -@group(0) @binding(3) var topk_params: TopkParams; - -fn compare_less_{suffix}(a: {t}, b: {t}) -> bool {{ - return {cmp_less}; -}} - -fn bitonic_cas_{suffix}(i: u32, j: u32, dir: bool) {{ - let vi = shared_vals[i]; - let vj = shared_vals[j]; - let swap = select(compare_less_{suffix}(vi, vj), compare_less_{suffix}(vj, vi), dir); - if (swap) {{ - shared_vals[i] = vj; - shared_vals[j] = vi; - let ti = shared_idxs[i]; - shared_idxs[i] = shared_idxs[j]; - shared_idxs[j] = ti; - }} -}} - -@compute @workgroup_size(256) -fn topk_{suffix}( - @builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3 -) {{ - let outer_idx = group_id.x; - let inner_idx = group_id.y; - let tid = local_id.x; - - let outer_size = topk_params.outer_size; - let sort_size = topk_params.sort_size; - let inner_size = topk_params.inner_size; - let k = topk_params.k; - let largest = topk_params.largest != 0u; - - if (outer_idx >= outer_size || inner_idx >= inner_size) {{ - return; - }} - - var n = sort_size; - var p: u32 = 1u; - while (p < n) {{ - p = p << 1u; - }} - n = min(p, MAX_SORT_SIZE); - - let base_offset = outer_idx * sort_size * inner_size + inner_idx; - for (var i = tid; i < n; i = i + WORKGROUP_SIZE) {{ - if (i < sort_size) {{ - let idx = base_offset + i * inner_size; - shared_vals[i] = topk_input[idx]; - shared_idxs[i] = i32(i); - }} else {{ - shared_vals[i] = select({t}({max_val}), {t}({min_val}), largest); - shared_idxs[i] = i32(i); - }} - }} - workgroupBarrier(); - - // Bitonic sort (descending if largest, ascending if smallest) - for (var k_: u32 = 2u; k_ <= n; k_ = k_ << 1u) {{ - for (var j: u32 = k_ >> 1u; j > 0u; j = j >> 1u) {{ - for (var i = tid; i < n / 2u; i = i + WORKGROUP_SIZE) {{ - // Calculate bitonic network indices - let ij = (i / j) * 2u * j + (i % j); - let ij_pair = ij + j; - - // Direction depends on which half of the network we're in - // For largest: descending (true), for smallest: ascending (false) - let ascending_local = ((ij / k_) % 2u == 0u) != largest; - - if (ij_pair < n) {{ - bitonic_cas_{suffix}(ij, ij_pair, ascending_local); - }} - }} - workgroupBarrier(); - }} - }} - - // Write top-k values and indices - let out_base = outer_idx * k * inner_size + inner_idx; - for (var i = tid; i < k; i = i + WORKGROUP_SIZE) {{ - let out_idx = out_base + i * inner_size; - topk_values[out_idx] = shared_vals[i]; - topk_indices[out_idx] = shared_idxs[i]; - }} -}} -"#, - t = t, - suffix = suffix, - min_val = min_val, - max_val = max_val, - cmp_less = cmp_less, - )) -} - -/// Generate WGSL shader for searchsorted operation -pub fn generate_searchsorted_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Auto-generated searchsorted operations for {t} - -const WORKGROUP_SIZE: u32 = 256u; - -struct SearchsortedParams {{ - seq_len: u32, - num_values: u32, - right: u32, - _pad: u32, -}} - -@group(0) @binding(0) var ss_seq: array<{t}>; -@group(0) @binding(1) var ss_values: array<{t}>; -@group(0) @binding(2) var ss_output: array; -@group(0) @binding(3) var ss_params: SearchsortedParams; - -@compute @workgroup_size(256) -fn searchsorted_{suffix}(@builtin(global_invocation_id) global_id: vec3) {{ - let idx = global_id.x; - - if (idx >= ss_params.num_values) {{ - return; - }} - - let value = ss_values[idx]; - let seq_len = ss_params.seq_len; - let right = ss_params.right != 0u; - - // Binary search - var lo: u32 = 0u; - var hi: u32 = seq_len; - - while (lo < hi) {{ - let mid = lo + (hi - lo) / 2u; - let seq_val = ss_seq[mid]; - - var go_right: bool; - if (right) {{ - go_right = seq_val <= value; - }} else {{ - go_right = seq_val < value; - }} - - if (go_right) {{ - lo = mid + 1u; - }} else {{ - hi = mid; - }} - }} - - ss_output[idx] = i32(lo); -}} -"#, - t = t, - suffix = suffix, - )) -} - -/// Generate WGSL shader for nonzero counting (phase 1) -pub fn generate_count_nonzero_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - let zero_check = match dtype { - DType::F32 => "input[idx] != 0.0", - DType::I32 => "input[idx] != 0", - DType::U32 => "input[idx] != 0u", - _ => { - return Err(Error::UnsupportedDType { - dtype, - op: "count_nonzero", - }); - } - }; - - Ok(format!( - r#"// Auto-generated count_nonzero operations for {t} - -const WORKGROUP_SIZE: u32 = 256u; - -var shared_count: array; - -struct CountParams {{ - numel: u32, -}} - -@group(0) @binding(0) var input: array<{t}>; -@group(0) @binding(1) var count_output: array>; -@group(0) @binding(2) var count_params: CountParams; - -@compute @workgroup_size(256) -fn count_nonzero_{suffix}( - @builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3 -) {{ - let tid = local_id.x; - let numel = count_params.numel; - - // Each thread counts its elements - var local_count: u32 = 0u; - var idx = global_id.x; - while (idx < numel) {{ - if ({zero_check}) {{ - local_count = local_count + 1u; - }} - idx = idx + WORKGROUP_SIZE * 256u; // stride by total threads - }} - - shared_count[tid] = local_count; - workgroupBarrier(); - - // Tree reduction - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) {{ - if (tid < s) {{ - shared_count[tid] = shared_count[tid] + shared_count[tid + s]; - }} - workgroupBarrier(); - }} - - // Thread 0 adds to global counter - if (tid == 0u) {{ - atomicAdd(&count_output[0], shared_count[0]); - }} -}} -"#, - t = t, - suffix = suffix, - zero_check = zero_check, - )) -} - -/// Generate WGSL shader for gathering nonzero indices (phase 2) -pub fn generate_gather_nonzero_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - let zero_check = match dtype { - DType::F32 => "input[idx] != 0.0", - DType::I32 => "input[idx] != 0", - DType::U32 => "input[idx] != 0u", - _ => { - return Err(Error::UnsupportedDType { - dtype, - op: "gather_nonzero", - }); - } - }; - - Ok(format!( - r#"// Auto-generated gather_nonzero operations for {t} - -const WORKGROUP_SIZE: u32 = 256u; - -struct CountParams {{ - numel: u32, -}} - -@group(0) @binding(0) var input: array<{t}>; -@group(0) @binding(1) var indices_output: array; -@group(0) @binding(2) var counter: array>; -@group(0) @binding(3) var count_params: CountParams; - -@compute @workgroup_size(256) -fn gather_nonzero_{suffix}(@builtin(global_invocation_id) global_id: vec3) {{ - let numel = count_params.numel; - var idx = global_id.x; - - while (idx < numel) {{ - if ({zero_check}) {{ - let out_idx = atomicAdd(&counter[0], 1u); - indices_output[out_idx] = i32(idx); - }} - idx = idx + WORKGROUP_SIZE * 256u; - }} -}} -"#, - t = t, - suffix = suffix, - zero_check = zero_check, - )) -} - -/// Generate WGSL shader for flat_to_multi_index -pub fn generate_flat_to_multi_index_shader() -> Result { - Ok(r#"// Convert flat indices to multi-dimensional indices - -const WORKGROUP_SIZE: u32 = 256u; -const MAX_DIMS: u32 = 8u; - -struct FlatToMultiParams { - nnz: u32, - ndim: u32, - _pad0: u32, - _pad1: u32, - shape: array, 2>, -} - -@group(0) @binding(0) var flat_indices: array; -@group(0) @binding(1) var multi_indices: array; -@group(0) @binding(2) var params: FlatToMultiParams; - -fn get_shape_dim(d: u32) -> u32 { - return params.shape[d / 4u][d % 4u]; -} - -@compute @workgroup_size(256) -fn flat_to_multi_index(@builtin(global_invocation_id) global_id: vec3) { - let idx = global_id.x; - - if (idx >= params.nnz) { - return; - } - - var flat_idx = u32(flat_indices[idx]); - let ndim = params.ndim; - - // Compute strides on the fly (row-major) - // and convert flat index to multi-index - for (var d: u32 = ndim; d > 0u; d = d - 1u) { - let dim = d - 1u; - let dim_size = get_shape_dim(dim); - let coord = flat_idx % dim_size; - flat_idx = flat_idx / dim_size; - - // Store: multi_indices[idx * ndim + dim] = coord - multi_indices[idx * ndim + dim] = i32(coord); - } -} -"# - .to_string()) -} - -/// Generate WGSL shader for unique operations -pub fn generate_unique_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Auto-generated unique operations for {t} - -const WORKGROUP_SIZE: u32 = 256u; - -var shared_count: array; - -struct UniqueParams {{ - numel: u32, -}} - -@group(0) @binding(0) var sorted_input: array<{t}>; -@group(0) @binding(1) var unique_output: array<{t}>; -@group(0) @binding(2) var unique_counter: array>; -@group(0) @binding(3) var unique_params: UniqueParams; - -// Count unique elements (on sorted input) -@compute @workgroup_size(256) -fn count_unique_{suffix}( - @builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3 -) {{ - let tid = local_id.x; - let numel = unique_params.numel; - - var local_count: u32 = 0u; - var idx = global_id.x; - - while (idx < numel) {{ - // Count if first element or different from previous - if (idx == 0u || sorted_input[idx] != sorted_input[idx - 1u]) {{ - local_count = local_count + 1u; - }} - idx = idx + WORKGROUP_SIZE * 256u; - }} - - shared_count[tid] = local_count; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) {{ - if (tid < s) {{ - shared_count[tid] = shared_count[tid] + shared_count[tid + s]; - }} - workgroupBarrier(); - }} - - if (tid == 0u) {{ - atomicAdd(&unique_counter[0], shared_count[0]); - }} -}} - -// Extract unique elements -@compute @workgroup_size(256) -fn extract_unique_{suffix}(@builtin(global_invocation_id) global_id: vec3) {{ - let numel = unique_params.numel; - var idx = global_id.x; - - while (idx < numel) {{ - // Write if first element or different from previous - if (idx == 0u || sorted_input[idx] != sorted_input[idx - 1u]) {{ - let out_idx = atomicAdd(&unique_counter[0], 1u); - unique_output[out_idx] = sorted_input[idx]; - }} - idx = idx + WORKGROUP_SIZE * 256u; - }} -}} -"#, - t = t, - suffix = suffix, - )) -} - -/// Generate WGSL shader for unique_with_counts operations -pub fn generate_unique_with_counts_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Auto-generated unique_with_counts operations for {t} - -const WORKGROUP_SIZE: u32 = 256u; - -struct UniqueCountsParams {{ - numel: u32, - num_unique: u32, - _pad0: u32, - _pad1: u32, -}} - -// Mark boundaries in sorted array (where value changes) -// Output: flags[i] = 1 if sorted[i] != sorted[i-1] (or i=0), else 0 -@group(0) @binding(0) var sorted_input: array<{t}>; -@group(0) @binding(1) var boundary_flags: array; -@group(0) @binding(2) var params: UniqueCountsParams; - -@compute @workgroup_size(256) -fn mark_boundaries_{suffix}(@builtin(global_invocation_id) global_id: vec3) {{ - let idx = global_id.x; - let numel = params.numel; - - if (idx >= numel) {{ - return; - }} - - // Mark boundary: first element or different from previous - if (idx == 0u || sorted_input[idx] != sorted_input[idx - 1u]) {{ - boundary_flags[idx] = 1u; - }} else {{ - boundary_flags[idx] = 0u; - }} -}} - -// Scatter unique values and compute counts using prefix sum indices -// prefix_sum[i] contains the output index for element at position i (if it's a boundary) -// We write: unique_values[prefix_sum[i]-1] = sorted[i] when flags[i] == 1 -// counts[prefix_sum[i]-1] = (next boundary position - i) computed from adjacent prefix sums -@group(0) @binding(0) var scatter_sorted: array<{t}>; -@group(0) @binding(1) var prefix_sum: array; -@group(0) @binding(2) var unique_values: array<{t}>; -@group(0) @binding(3) var inverse_indices: array; -@group(0) @binding(4) var counts: array; -@group(0) @binding(5) var scatter_params: UniqueCountsParams; - -@compute @workgroup_size(256) -fn scatter_unique_with_counts_{suffix}(@builtin(global_invocation_id) global_id: vec3) {{ - let idx = global_id.x; - let numel = scatter_params.numel; - let num_unique = scatter_params.num_unique; - - if (idx >= numel) {{ - return; - }} - - // The prefix sum gives us 1-based output indices - let out_idx_plus1 = prefix_sum[idx]; - - // Check if this is a boundary by comparing with previous prefix sum - let is_boundary = (idx == 0u) || (prefix_sum[idx] != prefix_sum[idx - 1u]); - - // Write inverse index: which unique element does this sorted element map to - inverse_indices[idx] = i32(out_idx_plus1 - 1u); - - if (is_boundary) {{ - let out_idx = out_idx_plus1 - 1u; - unique_values[out_idx] = scatter_sorted[idx]; - - // Compute count: find next boundary position - // The count is (next_boundary_position - idx) - // If we're the last unique, count to numel - if (out_idx + 1u >= num_unique) {{ - // Last unique element - counts[out_idx] = i32(numel - idx); - }} else {{ - // Find next boundary: it's where prefix_sum increases next - // We need to find the smallest j > idx where prefix_sum[j] > out_idx_plus1 - // Actually, we can compute this differently: - // The run length is the distance to the next boundary - // For efficiency, we'll use a second pass or a different approach - - // For now, scan forward (not ideal but correct) - var run_len: u32 = 1u; - var j = idx + 1u; - while (j < numel && prefix_sum[j] == out_idx_plus1) {{ - run_len = run_len + 1u; - j = j + 1u; - }} - counts[out_idx] = i32(run_len); - }} - }} -}} -"#, - t = t, - suffix = suffix, - )) -} diff --git a/src/runtime/wgpu/shaders/generator/sparse_algorithms.rs b/src/runtime/wgpu/shaders/generator/sparse_algorithms.rs deleted file mode 100644 index fa278842..00000000 --- a/src/runtime/wgpu/shaders/generator/sparse_algorithms.rs +++ /dev/null @@ -1,353 +0,0 @@ -//! WGSL shader generation for sparse matrix algorithms. -//! -//! Implements: -//! - Column-Parallel DSMM: Dense × Sparse Matrix Multiplication -//! - Row-Parallel SpGEMM: Sparse × Sparse Matrix Multiplication (simplified GPU version) - -use super::common::{dtype_suffix, wgsl_type}; -use crate::dtype::DType; -use crate::error::Result; - -/// Generate WGSL shader for column-parallel DSMM: C = A * B -/// -/// Dense A [M, K] × Sparse B CSC [K, N] → Dense C [M, N] -/// -/// Algorithm: -/// For each column j in B: -/// For each non-zero B[k, j]: -/// C[:, j] += A[:, k] * B[k, j] -/// -/// GPU parallelization: -/// - Each thread computes one element C[row, col] -/// - Thread reads A[row, :] and accumulates with sparse column of B -pub fn generate_dsmm_csc_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Column-Parallel Dense × Sparse Matrix Multiplication: C = A * B -// Dense A [M, K] × Sparse B CSC [K, N] → Dense C [M, N] -// Each thread computes one element C[row, col] - -const WORKGROUP_SIZE: u32 = 256u; - -struct DsmmParams {{ - m: u32, // Number of rows in A (and C) - k: u32, // Number of columns in A (and rows in B) - n: u32, // Number of columns in B (and C) - _pad: u32, -}} - -// Dense matrix A (m x k, row-major) -@group(0) @binding(0) var a: array<{t}>; -// CSC format for B -@group(0) @binding(1) var col_ptrs: array; -@group(0) @binding(2) var row_indices: array; -@group(0) @binding(3) var b_values: array<{t}>; -// Output matrix C (m x n, row-major) -@group(0) @binding(4) var c: array<{t}>; -// Parameters -@group(0) @binding(5) var params: DsmmParams; - -@compute @workgroup_size(256) -fn dsmm_csc_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - let total = params.m * params.n; - if (idx >= total) {{ - return; - }} - - let row = idx / params.n; - let col = idx % params.n; - - // Accumulate C[row, col] = sum over non-zeros in column 'col' of B - // For each B[k, col], add A[row, k] * B[k, col] - let col_start = col_ptrs[col]; - let col_end = col_ptrs[col + 1u]; - - var sum: {t} = {zero}; - for (var j: i32 = col_start; j < col_end; j = j + 1) {{ - let k = row_indices[j]; // row index in B = column index in A - let b_val = b_values[j]; - // A is row-major: A[row, k] = a[row * k_dim + k] - let a_idx = row * params.k + u32(k); - sum = sum + a[a_idx] * b_val; - }} - - // C is row-major: C[row, col] = c[row * n + col] - c[idx] = sum; -}} -"#, - t = t, - suffix = suffix, - zero = zero_literal(dtype), - )) -} - -/// Generate WGSL shader for SpGEMM symbolic phase: count NNZ per output row. -/// -/// CSR A `[M, K]` × CSR B `[K, N]` → `row_nnz[M]` -/// -/// For small N (< 4096), uses a bitmap to track unique columns. -/// Each workgroup processes one row of the output. -pub fn generate_spgemm_symbolic_shader(dtype: DType) -> Result { - let suffix = dtype_suffix(dtype)?; - let _ = wgsl_type(dtype)?; // validate dtype - - Ok(format!( - r#"// SpGEMM Symbolic Phase: Count NNZ per output row -// CSR A [M, K] × CSR B [K, N] → row_nnz[M] -// Uses bitmap in workgroup memory for small N - -const WORKGROUP_SIZE: u32 = 256u; -const MAX_BITMAP_SIZE: u32 = 4096u; // Max columns we can handle with bitmap - -struct SymbolicParams {{ - m: u32, // Number of rows in A (and output) - n: u32, // Number of columns in B (and output) - _pad0: u32, - _pad1: u32, -}} - -// CSR format for A -@group(0) @binding(0) var a_row_ptrs: array; -@group(0) @binding(1) var a_col_indices: array; -// CSR format for B -@group(0) @binding(2) var b_row_ptrs: array; -@group(0) @binding(3) var b_col_indices: array; -// Output: NNZ per row -@group(0) @binding(4) var row_nnz: array; -// Global bitmap storage (one bitmap per row, M * ((N+31)/32) u32 words) -@group(0) @binding(5) var bitmap: array>; -// Parameters (uniforms are placed after storage buffers in LayoutKey layouts) -@group(0) @binding(6) var params: SymbolicParams; - -@compute @workgroup_size(256) -fn spgemm_symbolic_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let row = gid.x; - if (row >= params.m) {{ - return; - }} - - // Calculate bitmap offset for this row - let words_per_row = (params.n + 31u) / 32u; - let bitmap_offset = row * words_per_row; - - // Clear this row's bitmap - for (var w: u32 = 0u; w < words_per_row; w = w + 1u) {{ - atomicStore(&bitmap[bitmap_offset + w], 0u); - }} - - // For each non-zero in row 'row' of A - let a_start = a_row_ptrs[row]; - let a_end = a_row_ptrs[row + 1u]; - - for (var ai: i32 = a_start; ai < a_end; ai = ai + 1) {{ - let k = a_col_indices[ai]; // column in A = row in B - - // For each non-zero in row k of B - let b_start = b_row_ptrs[k]; - let b_end = b_row_ptrs[k + 1]; - - for (var bi: i32 = b_start; bi < b_end; bi = bi + 1) {{ - let j = b_col_indices[bi]; // column in B = column in output - - // Set bit j in bitmap - let word_idx = bitmap_offset + u32(j) / 32u; - let bit_idx = u32(j) % 32u; - atomicOr(&bitmap[word_idx], 1u << bit_idx); - }} - }} - - // Count set bits (popcount) - var count: i32 = 0; - for (var w: u32 = 0u; w < words_per_row; w = w + 1u) {{ - let word = atomicLoad(&bitmap[bitmap_offset + w]); - count = count + i32(countOneBits(word)); - }} - - row_nnz[row] = count; -}} -"#, - suffix = suffix, - )) -} - -/// Generate WGSL shader for SpGEMM accumulate phase. -/// -/// Each thread handles one output row, clears accum/flags for that row, and accumulates -/// contributions from A(row,:) * B(:,col). -pub fn generate_spgemm_accumulate_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// SpGEMM Accumulate Phase -// CSR A [M, K] × CSR B [K, N] -> dense row accumulators -// Uses dense accumulator array per row - -const WORKGROUP_SIZE: u32 = 256u; - -struct SpgemmParams {{ - m: u32, - n: u32, - _pad0: u32, - _pad1: u32, -}} - -// CSR format for A -@group(0) @binding(0) var a_row_ptrs: array; -@group(0) @binding(1) var a_col_indices: array; -@group(0) @binding(2) var a_values: array<{t}>; -// CSR format for B -@group(0) @binding(3) var b_row_ptrs: array; -@group(0) @binding(4) var b_col_indices: array; -@group(0) @binding(5) var b_values: array<{t}>; -// Dense accumulator (M * N elements, used as temporary per-row storage) -@group(0) @binding(6) var accum: array<{t}>; -// Flag array to track which columns have values (M * N elements) -@group(0) @binding(7) var flags: array; -// Parameters (uniforms are placed after storage buffers in LayoutKey layouts) -@group(0) @binding(8) var params: SpgemmParams; - -@compute @workgroup_size(256) -fn spgemm_accumulate_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let row = gid.x; - if (row >= params.m) {{ - return; - }} - - let accum_offset = row * params.n; - - // Clear accumulator and flags for this row - for (var col: u32 = 0u; col < params.n; col = col + 1u) {{ - accum[accum_offset + col] = {zero}; - flags[accum_offset + col] = 0u; - }} - - // Accumulate: C[row, :] = sum over k of A[row, k] * B[k, :] - let a_start = a_row_ptrs[row]; - let a_end = a_row_ptrs[row + 1u]; - - for (var ai: i32 = a_start; ai < a_end; ai = ai + 1) {{ - let k = a_col_indices[ai]; - let a_val = a_values[ai]; - - let b_start = b_row_ptrs[k]; - let b_end = b_row_ptrs[k + 1]; - - for (var bi: i32 = b_start; bi < b_end; bi = bi + 1) {{ - let j = b_col_indices[bi]; - let b_val = b_values[bi]; - let idx = accum_offset + u32(j); - accum[idx] = accum[idx] + a_val * b_val; - flags[idx] = 1u; // Mark column as having a value - }} - }} -}} -"#, - t = t, - suffix = suffix, - zero = zero_literal(dtype) - )) -} - -/// Generate WGSL shader for SpGEMM scatter phase. -/// -/// Compacts per-row `accum/flags` into CSR `col_indices/values` using row_ptrs. -pub fn generate_spgemm_scatter_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// SpGEMM Scatter Phase -// Compacts dense row accumulators into CSR output arrays. - -const WORKGROUP_SIZE: u32 = 256u; - -struct SpgemmParams {{ - m: u32, - n: u32, - _pad0: u32, - _pad1: u32, -}} - -@group(0) @binding(0) var c_row_ptrs: array; -@group(0) @binding(1) var accum: array<{t}>; -@group(0) @binding(2) var flags: array; -@group(0) @binding(3) var c_col_indices: array; -@group(0) @binding(4) var c_values: array<{t}>; -@group(0) @binding(5) var params: SpgemmParams; - -@compute @workgroup_size(256) -fn spgemm_scatter_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let row = gid.x; - if (row >= params.m) {{ - return; - }} - - let accum_offset = row * params.n; - var write_idx: i32 = c_row_ptrs[row]; - - for (var col: u32 = 0u; col < params.n; col = col + 1u) {{ - let idx = accum_offset + col; - if (flags[idx] != 0u) {{ - c_col_indices[write_idx] = i32(col); - c_values[write_idx] = accum[idx]; - write_idx = write_idx + 1; - }} - }} -}} -"#, - t = t, - suffix = suffix - )) -} - -/// Get zero literal for dtype -fn zero_literal(dtype: DType) -> &'static str { - match dtype { - DType::F32 | DType::F16 => "0.0", - DType::I32 => "0", - DType::U32 => "0u", - _ => "0.0", - } -} - -#[cfg(test)] -mod tests { - use super::*; - - fn validate_wgsl_syntax(source: &str) -> std::result::Result<(), String> { - use wgpu::naga::front::wgsl; - let mut frontend = wgsl::Frontend::new(); - frontend - .parse(source) - .map(|_| ()) - .map_err(|e| format!("WGSL parse error: {e}")) - } - - #[test] - fn test_dsmm_csc_shader_syntax_f32() { - let shader = generate_dsmm_csc_shader(DType::F32).unwrap(); - validate_wgsl_syntax(&shader).expect("DSMM shader should be valid WGSL"); - } - - #[test] - fn test_spgemm_symbolic_shader_syntax_f32() { - let shader = generate_spgemm_symbolic_shader(DType::F32).unwrap(); - validate_wgsl_syntax(&shader).expect("SpGEMM symbolic shader should be valid WGSL"); - } - - #[test] - fn test_spgemm_accumulate_shader_syntax_f32() { - let shader = generate_spgemm_accumulate_shader(DType::F32).unwrap(); - validate_wgsl_syntax(&shader).expect("SpGEMM accumulate shader should be valid WGSL"); - } - - #[test] - fn test_spgemm_scatter_shader_syntax_f32() { - let shader = generate_spgemm_scatter_shader(DType::F32).unwrap(); - validate_wgsl_syntax(&shader).expect("SpGEMM scatter shader should be valid WGSL"); - } -} diff --git a/src/runtime/wgpu/shaders/generator/sparse_conversions.rs b/src/runtime/wgpu/shaders/generator/sparse_conversions.rs deleted file mode 100644 index 0fbcd3b3..00000000 --- a/src/runtime/wgpu/shaders/generator/sparse_conversions.rs +++ /dev/null @@ -1,644 +0,0 @@ -//! WGSL shader generators for sparse format conversions. -//! -//! Generates shaders for converting between COO, CSR, and CSC formats. -//! Algorithms: -//! - CSR/CSC → COO: Expand pointers to explicit indices -//! - COO → CSR/CSC: Histogram + scan + scatter (counting sort) -//! - CSR ↔ CSC: Direct transpose via histogram + scan + scatter - -use crate::dtype::DType; -use crate::error::Result; - -use super::common::wgsl_type; - -/// Generate shader for expanding CSR row pointers to explicit row indices (CSR → COO). -/// -/// Input: `row_ptrs[nrows+1]`, nnz elements total -/// Output: `row_indices[nnz]` where each element i gets the row index it belongs to -pub fn generate_expand_row_ptrs_shader() -> Result { - Ok(r#" -// Expand CSR row pointers to explicit row indices -// One thread per row - -struct ExpandParams { - nrows: u32, - _pad0: u32, - _pad1: u32, - _pad2: u32, -} - -@group(0) @binding(0) var row_ptrs: array; -@group(0) @binding(1) var row_indices: array; -@group(0) @binding(2) var params: ExpandParams; - -@compute @workgroup_size(256) -fn expand_row_ptrs(@builtin(global_invocation_id) gid: vec3) { - let row = gid.x; - if (row >= params.nrows) { - return; - } - - let start = row_ptrs[row]; - let end = row_ptrs[row + 1u]; - - // Fill all indices in this row with the row number - for (var i = start; i < end; i = i + 1) { - row_indices[i] = i32(row); - } -} -"# - .to_string()) -} - -/// Generate shader for expanding CSC column pointers to explicit column indices (CSC → COO). -pub fn generate_expand_col_ptrs_shader() -> Result { - Ok(r#" -// Expand CSC column pointers to explicit column indices -// One thread per column - -struct ExpandParams { - ncols: u32, - _pad0: u32, - _pad1: u32, - _pad2: u32, -} - -@group(0) @binding(0) var col_ptrs: array; -@group(0) @binding(1) var col_indices: array; -@group(0) @binding(2) var params: ExpandParams; - -@compute @workgroup_size(256) -fn expand_col_ptrs(@builtin(global_invocation_id) gid: vec3) { - let col = gid.x; - if (col >= params.ncols) { - return; - } - - let start = col_ptrs[col]; - let end = col_ptrs[col + 1u]; - - // Fill all indices in this column with the column number - for (var i = start; i < end; i = i + 1) { - col_indices[i] = i32(col); - } -} -"# - .to_string()) -} - -/// Generate histogram shader for counting elements per row/column. -/// -/// Used by COO→CSR/CSC and CSR↔CSC conversions. -pub fn generate_histogram_shader() -> Result { - Ok(r#" -// Count elements per bucket (row or column) -// One thread per element - -struct HistogramParams { - nnz: u32, - _pad0: u32, - _pad1: u32, - _pad2: u32, -} - -@group(0) @binding(0) var indices: array; -@group(0) @binding(1) var counts: array>; -@group(0) @binding(2) var params: HistogramParams; - -@compute @workgroup_size(256) -fn histogram(@builtin(global_invocation_id) gid: vec3) { - let idx = gid.x; - if (idx >= params.nnz) { - return; - } - - let bucket = indices[idx]; - atomicAdd(&counts[bucket], 1); -} -"# - .to_string()) -} - -/// Generate shader for COO→CSR scatter operation. -/// -/// Given sorted row indices and their scatter positions, place elements -/// at their correct positions in the CSR output. -pub fn generate_coo_to_csr_scatter_shader(dtype: DType) -> Result { - let wgsl_t = wgsl_type(dtype)?; - - Ok(format!( - r#" -// Scatter COO elements to CSR format using atomic position tracking -// One thread per element - -struct ScatterParams {{ - nnz: u32, - _pad0: u32, - _pad1: u32, - _pad2: u32, -}} - -@group(0) @binding(0) var in_row_indices: array; -@group(0) @binding(1) var in_col_indices: array; -@group(0) @binding(2) var in_values: array<{wgsl_t}>; -@group(0) @binding(3) var row_ptrs_atomic: array>; -@group(0) @binding(4) var out_col_indices: array; -@group(0) @binding(5) var out_values: array<{wgsl_t}>; -@group(0) @binding(6) var params: ScatterParams; - -@compute @workgroup_size(256) -fn coo_to_csr_scatter(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx >= params.nnz) {{ - return; - }} - - let row = in_row_indices[idx]; - let col = in_col_indices[idx]; - let val = in_values[idx]; - - // Atomically get position within this row's segment - let pos = atomicAdd(&row_ptrs_atomic[row], 1); - - out_col_indices[pos] = col; - out_values[pos] = val; -}} -"#, - wgsl_t = wgsl_t - )) -} - -/// Generate shader for COO→CSC scatter operation. -pub fn generate_coo_to_csc_scatter_shader(dtype: DType) -> Result { - let wgsl_t = wgsl_type(dtype)?; - - Ok(format!( - r#" -// Scatter COO elements to CSC format using atomic position tracking -// One thread per element - -struct ScatterParams {{ - nnz: u32, - _pad0: u32, - _pad1: u32, - _pad2: u32, -}} - -@group(0) @binding(0) var in_row_indices: array; -@group(0) @binding(1) var in_col_indices: array; -@group(0) @binding(2) var in_values: array<{wgsl_t}>; -@group(0) @binding(3) var col_ptrs_atomic: array>; -@group(0) @binding(4) var out_row_indices: array; -@group(0) @binding(5) var out_values: array<{wgsl_t}>; -@group(0) @binding(6) var params: ScatterParams; - -@compute @workgroup_size(256) -fn coo_to_csc_scatter(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx >= params.nnz) {{ - return; - }} - - let row = in_row_indices[idx]; - let col = in_col_indices[idx]; - let val = in_values[idx]; - - // Atomically get position within this column's segment - let pos = atomicAdd(&col_ptrs_atomic[col], 1); - - out_row_indices[pos] = row; - out_values[pos] = val; -}} -"#, - wgsl_t = wgsl_t - )) -} - -/// Generate shader for CSR→CSC transpose scatter operation. -/// -/// Directly converts CSR to CSC without going through COO. -pub fn generate_csr_to_csc_scatter_shader(dtype: DType) -> Result { - let wgsl_t = wgsl_type(dtype)?; - - Ok(format!( - r#" -// Scatter CSR elements to CSC format (transpose) -// One thread per row, iterates over row's elements - -struct TransposeParams {{ - nrows: u32, - _pad0: u32, - _pad1: u32, - _pad2: u32, -}} - -@group(0) @binding(0) var in_row_ptrs: array; -@group(0) @binding(1) var in_col_indices: array; -@group(0) @binding(2) var in_values: array<{wgsl_t}>; -@group(0) @binding(3) var col_ptrs_atomic: array>; -@group(0) @binding(4) var out_row_indices: array; -@group(0) @binding(5) var out_values: array<{wgsl_t}>; -@group(0) @binding(6) var params: TransposeParams; - -@compute @workgroup_size(256) -fn csr_to_csc_scatter(@builtin(global_invocation_id) gid: vec3) {{ - let row = gid.x; - if (row >= params.nrows) {{ - return; - }} - - let start = in_row_ptrs[row]; - let end = in_row_ptrs[row + 1u]; - - for (var i = start; i < end; i = i + 1) {{ - let col = in_col_indices[i]; - let val = in_values[i]; - - // Atomically get position within this column's segment - let pos = atomicAdd(&col_ptrs_atomic[col], 1); - - out_row_indices[pos] = i32(row); - out_values[pos] = val; - }} -}} -"#, - wgsl_t = wgsl_t - )) -} - -/// Generate shader for CSC→CSR transpose scatter operation. -pub fn generate_csc_to_csr_scatter_shader(dtype: DType) -> Result { - let wgsl_t = wgsl_type(dtype)?; - - Ok(format!( - r#" -// Scatter CSC elements to CSR format (transpose) -// One thread per column, iterates over column's elements - -struct TransposeParams {{ - ncols: u32, - _pad0: u32, - _pad1: u32, - _pad2: u32, -}} - -@group(0) @binding(0) var in_col_ptrs: array; -@group(0) @binding(1) var in_row_indices: array; -@group(0) @binding(2) var in_values: array<{wgsl_t}>; -@group(0) @binding(3) var row_ptrs_atomic: array>; -@group(0) @binding(4) var out_col_indices: array; -@group(0) @binding(5) var out_values: array<{wgsl_t}>; -@group(0) @binding(6) var params: TransposeParams; - -@compute @workgroup_size(256) -fn csc_to_csr_scatter(@builtin(global_invocation_id) gid: vec3) {{ - let col = gid.x; - if (col >= params.ncols) {{ - return; - }} - - let start = in_col_ptrs[col]; - let end = in_col_ptrs[col + 1u]; - - for (var i = start; i < end; i = i + 1) {{ - let row = in_row_indices[i]; - let val = in_values[i]; - - // Atomically get position within this row's segment - let pos = atomicAdd(&row_ptrs_atomic[row], 1); - - out_col_indices[pos] = i32(col); - out_values[pos] = val; - }} -}} -"#, - wgsl_t = wgsl_t - )) -} - -/// Generate shader to copy row_ptrs before scatter (since scatter modifies them atomically). -pub fn generate_copy_ptrs_shader() -> Result { - Ok(r#" -// Copy pointers array (preserves original before scatter) - -struct CopyParams { - n: u32, - _pad0: u32, - _pad1: u32, - _pad2: u32, -} - -@group(0) @binding(0) var src: array; -@group(0) @binding(1) var dst: array; -@group(0) @binding(2) var params: CopyParams; - -@compute @workgroup_size(256) -fn copy_ptrs(@builtin(global_invocation_id) gid: vec3) { - let idx = gid.x; - if (idx >= params.n) { - return; - } - dst[idx] = src[idx]; -} -"# - .to_string()) -} - -/// Generate shader for CSR to dense conversion. -/// -/// Each thread handles one row, scattering values into the dense output. -pub fn generate_csr_to_dense_shader(dtype: DType) -> Result { - let wgsl_t = wgsl_type(dtype)?; - - Ok(format!( - r#" -// Convert CSR sparse matrix to dense format -// One thread per row - -struct CsrToDenseParams {{ - nrows: u32, - ncols: u32, - _pad0: u32, - _pad1: u32, -}} - -@group(0) @binding(0) var row_ptrs: array; -@group(0) @binding(1) var col_indices: array; -@group(0) @binding(2) var values: array<{wgsl_t}>; -@group(0) @binding(3) var dense: array<{wgsl_t}>; -@group(0) @binding(4) var params: CsrToDenseParams; - -@compute @workgroup_size(256) -fn csr_to_dense(@builtin(global_invocation_id) gid: vec3) {{ - let row = gid.x; - if (row >= params.nrows) {{ - return; - }} - - let start = row_ptrs[row]; - let end = row_ptrs[row + 1u]; - let ncols = params.ncols; - - // Scatter this row's values into the dense matrix - for (var i = start; i < end; i = i + 1) {{ - let col = u32(col_indices[i]); - let val = values[i]; - // Dense matrix is row-major: index = row * ncols + col - dense[row * ncols + col] = val; - }} -}} -"#, - wgsl_t = wgsl_t - )) -} - -/// Generate shader to count non-zero elements in dense matrix. -/// -/// Each thread counts non-zeros in a chunk, atomically adds to global counter. -pub fn generate_count_nonzeros_shader(dtype: DType) -> Result { - let wgsl_t = wgsl_type(dtype)?; - let zero_check = match dtype { - DType::F32 | DType::F64 => "abs(val) >= threshold", - _ => "val != zero_val", - }; - - Ok(format!( - r#" -// Count non-zero elements in dense matrix -// Returns total count via atomic counter - -struct CountParams {{ - total_elems: u32, - threshold_bits: u32, - _pad0: u32, - _pad1: u32, -}} - -@group(0) @binding(0) var dense: array<{wgsl_t}>; -@group(0) @binding(1) var count: atomic; -@group(0) @binding(2) var params: CountParams; - -@compute @workgroup_size(256) -fn count_nonzeros(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx >= params.total_elems) {{ - return; - }} - - let val = dense[idx]; - let threshold = bitcast<{wgsl_t}>(params.threshold_bits); - let zero_val = {wgsl_t}(0); - - if ({zero_check}) {{ - atomicAdd(&count, 1u); - }} -}} -"#, - wgsl_t = wgsl_t, - zero_check = zero_check - )) -} - -/// Generate shader for dense to COO conversion (scatter pass). -/// -/// Each thread checks one element, if non-zero, atomically gets position and writes to COO. -pub fn generate_dense_to_coo_scatter_shader(dtype: DType) -> Result { - let wgsl_t = wgsl_type(dtype)?; - let zero_check = match dtype { - DType::F32 | DType::F64 => "abs(val) >= threshold", - _ => "val != zero_val", - }; - - Ok(format!( - r#" -// Scatter non-zero elements from dense matrix to COO format -// One thread per element, atomic position tracking - -struct DenseToCooParams {{ - nrows: u32, - ncols: u32, - threshold_bits: u32, - _pad0: u32, -}} - -@group(0) @binding(0) var dense: array<{wgsl_t}>; -@group(0) @binding(1) var row_indices: array; -@group(0) @binding(2) var col_indices: array; -@group(0) @binding(3) var values: array<{wgsl_t}>; -@group(0) @binding(4) var write_pos: atomic; -@group(0) @binding(5) var params: DenseToCooParams; - -@compute @workgroup_size(256) -fn dense_to_coo_scatter(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - let total = params.nrows * params.ncols; - if (idx >= total) {{ - return; - }} - - let val = dense[idx]; - let threshold = bitcast<{wgsl_t}>(params.threshold_bits); - let zero_val = {wgsl_t}(0); - - if ({zero_check}) {{ - // Compute row and column from linear index - let row = idx / params.ncols; - let col = idx % params.ncols; - - // Atomically get write position - let pos = atomicAdd(&write_pos, 1u); - - // Write COO entry - row_indices[pos] = i32(row); - col_indices[pos] = i32(col); - values[pos] = val; - }} -}} -"#, - wgsl_t = wgsl_t, - zero_check = zero_check - )) -} - -#[cfg(test)] -mod tests { - use super::*; - - fn validate_wgsl_syntax(source: &str) -> std::result::Result<(), String> { - use wgpu::naga::front::wgsl; - let mut frontend = wgsl::Frontend::new(); - frontend - .parse(source) - .map(|_| ()) - .map_err(|e| format!("WGSL parse error: {e}")) - } - - #[test] - fn test_expand_row_ptrs_shader_syntax() { - let shader = generate_expand_row_ptrs_shader().unwrap(); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for expand_row_ptrs:\n{}\n\nShader:\n{}", - e, shader - ) - }); - } - - #[test] - fn test_expand_col_ptrs_shader_syntax() { - let shader = generate_expand_col_ptrs_shader().unwrap(); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for expand_col_ptrs:\n{}\n\nShader:\n{}", - e, shader - ) - }); - } - - #[test] - fn test_histogram_shader_syntax() { - let shader = generate_histogram_shader().unwrap(); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!("Invalid WGSL for histogram:\n{}\n\nShader:\n{}", e, shader) - }); - } - - #[test] - fn test_coo_to_csr_scatter_shader_syntax() { - for dtype in [DType::F32, DType::I32, DType::U32] { - let shader = generate_coo_to_csr_scatter_shader(dtype).unwrap(); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for coo_to_csr_scatter {:?}:\n{}\n\nShader:\n{}", - dtype, e, shader - ) - }); - } - } - - #[test] - fn test_coo_to_csc_scatter_shader_syntax() { - for dtype in [DType::F32, DType::I32, DType::U32] { - let shader = generate_coo_to_csc_scatter_shader(dtype).unwrap(); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for coo_to_csc_scatter {:?}:\n{}\n\nShader:\n{}", - dtype, e, shader - ) - }); - } - } - - #[test] - fn test_csr_to_csc_scatter_shader_syntax() { - for dtype in [DType::F32, DType::I32, DType::U32] { - let shader = generate_csr_to_csc_scatter_shader(dtype).unwrap(); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for csr_to_csc_scatter {:?}:\n{}\n\nShader:\n{}", - dtype, e, shader - ) - }); - } - } - - #[test] - fn test_csc_to_csr_scatter_shader_syntax() { - for dtype in [DType::F32, DType::I32, DType::U32] { - let shader = generate_csc_to_csr_scatter_shader(dtype).unwrap(); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for csc_to_csr_scatter {:?}:\n{}\n\nShader:\n{}", - dtype, e, shader - ) - }); - } - } - - #[test] - fn test_copy_ptrs_shader_syntax() { - let shader = generate_copy_ptrs_shader().unwrap(); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!("Invalid WGSL for copy_ptrs:\n{}\n\nShader:\n{}", e, shader) - }); - } - - #[test] - fn test_csr_to_dense_shader_syntax() { - for dtype in [DType::F32, DType::I32, DType::U32] { - let shader = generate_csr_to_dense_shader(dtype).unwrap(); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for csr_to_dense {:?}:\n{}\n\nShader:\n{}", - dtype, e, shader - ) - }); - } - } - - #[test] - fn test_count_nonzeros_shader_syntax() { - for dtype in [DType::F32, DType::I32, DType::U32] { - let shader = generate_count_nonzeros_shader(dtype).unwrap(); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for count_nonzeros {:?}:\n{}\n\nShader:\n{}", - dtype, e, shader - ) - }); - } - } - - #[test] - fn test_dense_to_coo_scatter_shader_syntax() { - for dtype in [DType::F32, DType::I32, DType::U32] { - let shader = generate_dense_to_coo_scatter_shader(dtype).unwrap(); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for dense_to_coo_scatter {:?}:\n{}\n\nShader:\n{}", - dtype, e, shader - ) - }); - } - } -} diff --git a/src/runtime/wgpu/shaders/generator/sparse_factorize.rs b/src/runtime/wgpu/shaders/generator/sparse_factorize.rs deleted file mode 100644 index 6e0b639e..00000000 --- a/src/runtime/wgpu/shaders/generator/sparse_factorize.rs +++ /dev/null @@ -1,252 +0,0 @@ -//! WGSL shader generation for sparse factorization operations. -//! -//! Level-scheduled ILU(0) and IC(0) incomplete factorization. - -use crate::dtype::DType; -use crate::error::{Error, Result}; - -use super::common::{is_wgpu_supported, wgsl_type}; - -/// Generate WGSL shader for ILU(0) level kernel -pub fn generate_ilu0_level_shader(dtype: DType) -> Result { - if !is_wgpu_supported(dtype) { - return Err(Error::UnsupportedDType { - dtype, - op: "ilu0_level", - }); - } - - let t = wgsl_type(dtype)?; - let suffix = match dtype { - DType::F32 => "f32", - _ => { - return Err(Error::UnsupportedDType { - dtype, - op: "ilu0_level", - }); - } - }; - - Ok(format!( - r#"// Level-scheduled ILU(0) factorization kernel - -struct Ilu0Params {{ - level_size: u32, - n: u32, - diagonal_shift: {t}, - level_start: u32, -}} - -@group(0) @binding(0) var level_rows: array; -@group(0) @binding(1) var row_ptrs: array; -@group(0) @binding(2) var col_indices: array; -@group(0) @binding(3) var values: array<{t}>; -@group(0) @binding(4) var diag_indices: array; -@group(0) @binding(5) var params: Ilu0Params; - -@compute @workgroup_size(256) -fn ilu0_level_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let tid = gid.x; - if (tid >= params.level_size) {{ - return; - }} - - let i = level_rows[params.level_start + tid]; - let row_start = row_ptrs[i]; - let row_end = row_ptrs[i + 1]; - - // Process columns k < i (for L factor) - for (var idx_ik = row_start; idx_ik < row_end; idx_ik = idx_ik + 1) {{ - let k = col_indices[idx_ik]; - if (k >= i) {{ - break; - }} - - // Get diagonal U[k,k] - let diag_k = diag_indices[k]; - var diag_val = values[diag_k]; - - // Handle zero pivot - if (abs(diag_val) < 1e-15) {{ - if (params.diagonal_shift > 0.0) {{ - values[diag_k] = params.diagonal_shift; - diag_val = params.diagonal_shift; - }} - }} - - // L[i,k] = A[i,k] / U[k,k] - let l_ik = values[idx_ik] / diag_val; - values[idx_ik] = l_ik; - - // Update row i for columns j > k - let k_start = row_ptrs[k]; - let k_end = row_ptrs[k + 1]; - - for (var idx_kj = k_start; idx_kj < k_end; idx_kj = idx_kj + 1) {{ - let j = col_indices[idx_kj]; - if (j <= k) {{ - continue; - }} - - // Find A[i,j] if it exists (zero fill-in constraint) - for (var idx_ij = row_start; idx_ij < row_end; idx_ij = idx_ij + 1) {{ - if (col_indices[idx_ij] == j) {{ - values[idx_ij] = values[idx_ij] - l_ik * values[idx_kj]; - break; - }} - if (col_indices[idx_ij] > j) {{ - break; - }} - }} - }} - }} -}} -"#, - t = t, - suffix = suffix - )) -} - -/// Generate WGSL shader for IC(0) level kernel -pub fn generate_ic0_level_shader(dtype: DType) -> Result { - if !is_wgpu_supported(dtype) { - return Err(Error::UnsupportedDType { - dtype, - op: "ic0_level", - }); - } - - let t = wgsl_type(dtype)?; - let suffix = match dtype { - DType::F32 => "f32", - _ => { - return Err(Error::UnsupportedDType { - dtype, - op: "ic0_level", - }); - } - }; - - Ok(format!( - r#"// Level-scheduled IC(0) factorization kernel - -struct Ic0Params {{ - level_size: u32, - n: u32, - diagonal_shift: {t}, - level_start: u32, -}} - -@group(0) @binding(0) var level_rows: array; -@group(0) @binding(1) var row_ptrs: array; -@group(0) @binding(2) var col_indices: array; -@group(0) @binding(3) var values: array<{t}>; -@group(0) @binding(4) var diag_indices: array; -@group(0) @binding(5) var params: Ic0Params; - -@compute @workgroup_size(256) -fn ic0_level_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let tid = gid.x; - if (tid >= params.level_size) {{ - return; - }} - - let i = level_rows[params.level_start + tid]; - let i_start = row_ptrs[i]; - let i_end = row_ptrs[i + 1]; - - // Process off-diagonal entries in row i (columns k < i) - for (var idx_ik = i_start; idx_ik < i_end; idx_ik = idx_ik + 1) {{ - let k = col_indices[idx_ik]; - if (k >= i) {{ - break; - }} - - let k_start = row_ptrs[k]; - let k_end = row_ptrs[k + 1]; - - // Compute inner product contribution - var sum = values[idx_ik]; - - for (var idx_kj = k_start; idx_kj < k_end; idx_kj = idx_kj + 1) {{ - let j = col_indices[idx_kj]; - if (j >= k) {{ - break; - }} - - // Check if L[i,j] exists - for (var idx_ij = i_start; idx_ij < i_end; idx_ij = idx_ij + 1) {{ - if (col_indices[idx_ij] == j) {{ - sum = sum - values[idx_ij] * values[idx_kj]; - break; - }} - if (col_indices[idx_ij] > j) {{ - break; - }} - }} - }} - - // Divide by L[k,k] - let diag_k = diag_indices[k]; - values[idx_ik] = sum / values[diag_k]; - }} - - // Compute diagonal L[i,i] - let diag_i = diag_indices[i]; - var diag_sum = values[diag_i] + params.diagonal_shift; - - for (var idx_ij = i_start; idx_ij < i_end; idx_ij = idx_ij + 1) {{ - let j = col_indices[idx_ij]; - if (j >= i) {{ - break; - }} - diag_sum = diag_sum - values[idx_ij] * values[idx_ij]; - }} - - if (diag_sum <= 0.0) {{ - diag_sum = select(1e-10, params.diagonal_shift, params.diagonal_shift > 0.0); - }} - - values[diag_i] = sqrt(diag_sum); -}} -"#, - t = t, - suffix = suffix - )) -} - -#[cfg(test)] -mod tests { - use super::*; - - fn validate_wgsl_syntax(source: &str) -> std::result::Result<(), String> { - use wgpu::naga::front::wgsl; - let mut frontend = wgsl::Frontend::new(); - frontend - .parse(source) - .map(|_| ()) - .map_err(|e| format!("WGSL parse error: {e}")) - } - - #[test] - fn test_ilu0_level_shader_syntax() { - let shader = generate_ilu0_level_shader(DType::F32).unwrap(); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!("Invalid WGSL for ilu0_level:\n{}\n\nShader:\n{}", e, shader) - }); - } - - #[test] - fn test_ic0_level_shader_syntax() { - let shader = generate_ic0_level_shader(DType::F32).unwrap(); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!("Invalid WGSL for ic0_level:\n{}\n\nShader:\n{}", e, shader) - }); - } - - #[test] - fn test_f64_not_supported() { - assert!(generate_ilu0_level_shader(DType::F64).is_err()); - assert!(generate_ic0_level_shader(DType::F64).is_err()); - } -} diff --git a/src/runtime/wgpu/shaders/generator/sparse_linalg.rs b/src/runtime/wgpu/shaders/generator/sparse_linalg.rs deleted file mode 100644 index 1a146c4d..00000000 --- a/src/runtime/wgpu/shaders/generator/sparse_linalg.rs +++ /dev/null @@ -1,21 +0,0 @@ -//! WGSL shader generation for sparse linear algebra operations. -//! -//! This module re-exports from the split submodules for backward compatibility. -//! The actual implementations are in: -//! - `sparse_trsv.rs` - Sparse triangular solve shaders -//! - `sparse_factorize.rs` - ILU(0) and IC(0) factorization shaders -//! - `sparse_utils.rs` - Utility shaders (find_diag, copy) -//! - `sparse_split.rs` - Split LU and extract lower triangle shaders - -// Re-export from split modules for backward compatibility -pub use super::sparse_factorize::{generate_ic0_level_shader, generate_ilu0_level_shader}; -pub use super::sparse_split::{ - generate_extract_lower_count_shader, generate_extract_lower_scatter_shader, - generate_split_lu_count_shader, generate_split_lu_scatter_l_shader, - generate_split_lu_scatter_shader, generate_split_lu_scatter_u_shader, -}; -pub use super::sparse_trsv::{ - generate_sparse_trsv_lower_multi_rhs_shader, generate_sparse_trsv_lower_shader, - generate_sparse_trsv_upper_multi_rhs_shader, generate_sparse_trsv_upper_shader, -}; -pub use super::sparse_utils::{generate_copy_shader, generate_find_diag_indices_shader}; diff --git a/src/runtime/wgpu/shaders/generator/sparse_merge.rs b/src/runtime/wgpu/shaders/generator/sparse_merge.rs deleted file mode 100644 index f1782ec1..00000000 --- a/src/runtime/wgpu/shaders/generator/sparse_merge.rs +++ /dev/null @@ -1,765 +0,0 @@ -//! WGSL shader generation for sparse matrix element-wise merge operations -//! -//! Implements two-pass algorithms for CSR/CSC/COO element-wise operations: -//! - add, sub: union semantics (output has nonzeros from both A and B) -//! - mul, div: intersection semantics (output only where both A and B have nonzeros) -//! -//! Each format requires: -//! 1. Count kernel: count output elements per row/column/entry -//! 2. Compute kernel: perform merge and operation - -use super::common::{dtype_suffix, wgsl_type}; -use crate::dtype::DType; -use crate::error::Result; - -// ============================================================================ -// CSR Format Shaders -// ============================================================================ - -/// Generate WGSL shader for CSR merge count (add/sub - union semantics) -/// -/// Counts output nonzeros per row for operations that produce union of sparsity patterns. -pub fn generate_csr_merge_count_shader() -> String { - r#"// CSR merge count kernel (union semantics for add/sub) - -const WORKGROUP_SIZE: u32 = 256u; - -struct CountParams { - nrows: u32, -} - -@group(0) @binding(0) var a_row_ptrs: array; -@group(0) @binding(1) var a_col_indices: array; -@group(0) @binding(2) var b_row_ptrs: array; -@group(0) @binding(3) var b_col_indices: array; -@group(0) @binding(4) var row_counts: array; -@group(0) @binding(5) var params: CountParams; - -@compute @workgroup_size(256) -fn csr_merge_count(@builtin(global_invocation_id) gid: vec3) { - let row = gid.x; - if (row >= params.nrows) { - return; - } - - let a_start = a_row_ptrs[row]; - let a_end = a_row_ptrs[row + 1u]; - let b_start = b_row_ptrs[row]; - let b_end = b_row_ptrs[row + 1u]; - - var count: i32 = 0; - var i: i32 = a_start; - var j: i32 = b_start; - - // Merge sorted column indices, count unique columns - while (i < a_end && j < b_end) { - let a_col = a_col_indices[i]; - let b_col = b_col_indices[j]; - - count = count + 1; - if (a_col < b_col) { - i = i + 1; - } else if (a_col > b_col) { - j = j + 1; - } else { - i = i + 1; - j = j + 1; - } - } - - // Add remaining elements from A - count = count + (a_end - i); - // Add remaining elements from B - count = count + (b_end - j); - - row_counts[row] = count; -} -"# - .to_string() -} - -/// Generate WGSL shader for CSR mul count (intersection semantics) -/// -/// Counts output nonzeros per row for operations that produce intersection of sparsity patterns. -pub fn generate_csr_mul_count_shader() -> String { - r#"// CSR mul count kernel (intersection semantics for mul/div) - -const WORKGROUP_SIZE: u32 = 256u; - -struct CountParams { - nrows: u32, -} - -@group(0) @binding(0) var a_row_ptrs: array; -@group(0) @binding(1) var a_col_indices: array; -@group(0) @binding(2) var b_row_ptrs: array; -@group(0) @binding(3) var b_col_indices: array; -@group(0) @binding(4) var row_counts: array; -@group(0) @binding(5) var params: CountParams; - -@compute @workgroup_size(256) -fn csr_mul_count(@builtin(global_invocation_id) gid: vec3) { - let row = gid.x; - if (row >= params.nrows) { - return; - } - - let a_start = a_row_ptrs[row]; - let a_end = a_row_ptrs[row + 1u]; - let b_start = b_row_ptrs[row]; - let b_end = b_row_ptrs[row + 1u]; - - var count: i32 = 0; - var i: i32 = a_start; - var j: i32 = b_start; - - // Count matching column indices only (intersection) - while (i < a_end && j < b_end) { - let a_col = a_col_indices[i]; - let b_col = b_col_indices[j]; - - if (a_col < b_col) { - i = i + 1; - } else if (a_col > b_col) { - j = j + 1; - } else { - count = count + 1; - i = i + 1; - j = j + 1; - } - } - - row_counts[row] = count; -} -"# - .to_string() -} - -/// Generate WGSL shader for CSR add compute -pub fn generate_csr_add_compute_shader(dtype: DType) -> Result { - generate_csr_binary_compute_shader(dtype, "add", "a_val + b_val", "a_val", "b_val") -} - -/// Generate WGSL shader for CSR sub compute -pub fn generate_csr_sub_compute_shader(dtype: DType) -> Result { - generate_csr_binary_compute_shader(dtype, "sub", "a_val - b_val", "a_val", "-b_val") -} - -/// Generate WGSL shader for CSR mul compute -pub fn generate_csr_mul_compute_shader(dtype: DType) -> Result { - generate_csr_intersection_compute_shader(dtype, "mul", "a_val * b_val") -} - -/// Generate WGSL shader for CSR div compute -pub fn generate_csr_div_compute_shader(dtype: DType) -> Result { - generate_csr_intersection_compute_shader(dtype, "div", "a_val / b_val") -} - -/// Internal helper for CSR add/sub compute (union semantics) -fn generate_csr_binary_compute_shader( - dtype: DType, - op_name: &str, - both_expr: &str, - a_only_expr: &str, - b_only_expr: &str, -) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// CSR {op_name} compute kernel (union semantics) - -const WORKGROUP_SIZE: u32 = 256u; - -struct ComputeParams {{ - nrows: u32, -}} - -@group(0) @binding(0) var a_row_ptrs: array; -@group(0) @binding(1) var a_col_indices: array; -@group(0) @binding(2) var a_values: array<{t}>; -@group(0) @binding(3) var b_row_ptrs: array; -@group(0) @binding(4) var b_col_indices: array; -@group(0) @binding(5) var b_values: array<{t}>; -@group(0) @binding(6) var out_row_ptrs: array; -@group(0) @binding(7) var out_col_indices: array; -@group(0) @binding(8) var out_values: array<{t}>; -@group(0) @binding(9) var params: ComputeParams; - -@compute @workgroup_size(256) -fn csr_{op_name}_compute_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let row = gid.x; - if (row >= params.nrows) {{ - return; - }} - - let a_start = a_row_ptrs[row]; - let a_end = a_row_ptrs[row + 1u]; - let b_start = b_row_ptrs[row]; - let b_end = b_row_ptrs[row + 1u]; - - var out_idx = out_row_ptrs[row]; - var i: i32 = a_start; - var j: i32 = b_start; - - // Merge sorted column indices - while (i < a_end && j < b_end) {{ - let a_col = a_col_indices[i]; - let b_col = b_col_indices[j]; - let a_val = a_values[i]; - let b_val = b_values[j]; - - if (a_col < b_col) {{ - out_col_indices[out_idx] = a_col; - out_values[out_idx] = {a_only_expr}; - out_idx = out_idx + 1; - i = i + 1; - }} else if (a_col > b_col) {{ - out_col_indices[out_idx] = b_col; - out_values[out_idx] = {b_only_expr}; - out_idx = out_idx + 1; - j = j + 1; - }} else {{ - out_col_indices[out_idx] = a_col; - out_values[out_idx] = {both_expr}; - out_idx = out_idx + 1; - i = i + 1; - j = j + 1; - }} - }} - - // Copy remaining from A - while (i < a_end) {{ - out_col_indices[out_idx] = a_col_indices[i]; - out_values[out_idx] = a_values[i]; - out_idx = out_idx + 1; - i = i + 1; - }} - - // Copy remaining from B - while (j < b_end) {{ - out_col_indices[out_idx] = b_col_indices[j]; - out_values[out_idx] = {b_only_expr_for_b}; - out_idx = out_idx + 1; - j = j + 1; - }} -}} -"#, - t = t, - op_name = op_name, - suffix = suffix, - both_expr = both_expr, - a_only_expr = a_only_expr, - b_only_expr = b_only_expr, - b_only_expr_for_b = if op_name == "sub" { - "-b_values[j]" - } else { - "b_values[j]" - }, - )) -} - -/// Internal helper for CSR mul/div compute (intersection semantics) -fn generate_csr_intersection_compute_shader( - dtype: DType, - op_name: &str, - expr: &str, -) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// CSR {op_name} compute kernel (intersection semantics) - -const WORKGROUP_SIZE: u32 = 256u; - -struct ComputeParams {{ - nrows: u32, -}} - -@group(0) @binding(0) var a_row_ptrs: array; -@group(0) @binding(1) var a_col_indices: array; -@group(0) @binding(2) var a_values: array<{t}>; -@group(0) @binding(3) var b_row_ptrs: array; -@group(0) @binding(4) var b_col_indices: array; -@group(0) @binding(5) var b_values: array<{t}>; -@group(0) @binding(6) var out_row_ptrs: array; -@group(0) @binding(7) var out_col_indices: array; -@group(0) @binding(8) var out_values: array<{t}>; -@group(0) @binding(9) var params: ComputeParams; - -@compute @workgroup_size(256) -fn csr_{op_name}_compute_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let row = gid.x; - if (row >= params.nrows) {{ - return; - }} - - let a_start = a_row_ptrs[row]; - let a_end = a_row_ptrs[row + 1u]; - let b_start = b_row_ptrs[row]; - let b_end = b_row_ptrs[row + 1u]; - - var out_idx = out_row_ptrs[row]; - var i: i32 = a_start; - var j: i32 = b_start; - - // Only output where both A and B have nonzeros (intersection) - while (i < a_end && j < b_end) {{ - let a_col = a_col_indices[i]; - let b_col = b_col_indices[j]; - - if (a_col < b_col) {{ - i = i + 1; - }} else if (a_col > b_col) {{ - j = j + 1; - }} else {{ - let a_val = a_values[i]; - let b_val = b_values[j]; - out_col_indices[out_idx] = a_col; - out_values[out_idx] = {expr}; - out_idx = out_idx + 1; - i = i + 1; - j = j + 1; - }} - }} -}} -"#, - t = t, - op_name = op_name, - suffix = suffix, - expr = expr, - )) -} - -// ============================================================================ -// CSC Format Shaders (analogous to CSR but operates on columns) -// ============================================================================ - -/// Generate WGSL shader for CSC merge count (union semantics) -pub fn generate_csc_merge_count_shader() -> String { - r#"// CSC merge count kernel (union semantics for add/sub) - -const WORKGROUP_SIZE: u32 = 256u; - -struct CountParams { - ncols: u32, -} - -@group(0) @binding(0) var a_col_ptrs: array; -@group(0) @binding(1) var a_row_indices: array; -@group(0) @binding(2) var b_col_ptrs: array; -@group(0) @binding(3) var b_row_indices: array; -@group(0) @binding(4) var col_counts: array; -@group(0) @binding(5) var params: CountParams; - -@compute @workgroup_size(256) -fn csc_merge_count(@builtin(global_invocation_id) gid: vec3) { - let col = gid.x; - if (col >= params.ncols) { - return; - } - - let a_start = a_col_ptrs[col]; - let a_end = a_col_ptrs[col + 1u]; - let b_start = b_col_ptrs[col]; - let b_end = b_col_ptrs[col + 1u]; - - var count: i32 = 0; - var i: i32 = a_start; - var j: i32 = b_start; - - while (i < a_end && j < b_end) { - let a_row = a_row_indices[i]; - let b_row = b_row_indices[j]; - - count = count + 1; - if (a_row < b_row) { - i = i + 1; - } else if (a_row > b_row) { - j = j + 1; - } else { - i = i + 1; - j = j + 1; - } - } - - count = count + (a_end - i); - count = count + (b_end - j); - - col_counts[col] = count; -} -"# - .to_string() -} - -/// Generate WGSL shader for CSC mul count (intersection semantics) -pub fn generate_csc_mul_count_shader() -> String { - r#"// CSC mul count kernel (intersection semantics for mul/div) - -const WORKGROUP_SIZE: u32 = 256u; - -struct CountParams { - ncols: u32, -} - -@group(0) @binding(0) var a_col_ptrs: array; -@group(0) @binding(1) var a_row_indices: array; -@group(0) @binding(2) var b_col_ptrs: array; -@group(0) @binding(3) var b_row_indices: array; -@group(0) @binding(4) var col_counts: array; -@group(0) @binding(5) var params: CountParams; - -@compute @workgroup_size(256) -fn csc_mul_count(@builtin(global_invocation_id) gid: vec3) { - let col = gid.x; - if (col >= params.ncols) { - return; - } - - let a_start = a_col_ptrs[col]; - let a_end = a_col_ptrs[col + 1u]; - let b_start = b_col_ptrs[col]; - let b_end = b_col_ptrs[col + 1u]; - - var count: i32 = 0; - var i: i32 = a_start; - var j: i32 = b_start; - - while (i < a_end && j < b_end) { - let a_row = a_row_indices[i]; - let b_row = b_row_indices[j]; - - if (a_row < b_row) { - i = i + 1; - } else if (a_row > b_row) { - j = j + 1; - } else { - count = count + 1; - i = i + 1; - j = j + 1; - } - } - - col_counts[col] = count; -} -"# - .to_string() -} - -/// Generate WGSL shader for CSC add compute -pub fn generate_csc_add_compute_shader(dtype: DType) -> Result { - generate_csc_binary_compute_shader(dtype, "add", "a_val + b_val", "a_val", "b_val") -} - -/// Generate WGSL shader for CSC sub compute -pub fn generate_csc_sub_compute_shader(dtype: DType) -> Result { - generate_csc_binary_compute_shader(dtype, "sub", "a_val - b_val", "a_val", "-b_val") -} - -/// Generate WGSL shader for CSC mul compute -pub fn generate_csc_mul_compute_shader(dtype: DType) -> Result { - generate_csc_intersection_compute_shader(dtype, "mul", "a_val * b_val") -} - -/// Generate WGSL shader for CSC div compute -pub fn generate_csc_div_compute_shader(dtype: DType) -> Result { - generate_csc_intersection_compute_shader(dtype, "div", "a_val / b_val") -} - -/// Internal helper for CSC add/sub compute (union semantics) -fn generate_csc_binary_compute_shader( - dtype: DType, - op_name: &str, - both_expr: &str, - a_only_expr: &str, - b_only_expr: &str, -) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// CSC {op_name} compute kernel (union semantics) - -const WORKGROUP_SIZE: u32 = 256u; - -struct ComputeParams {{ - ncols: u32, -}} - -@group(0) @binding(0) var a_col_ptrs: array; -@group(0) @binding(1) var a_row_indices: array; -@group(0) @binding(2) var a_values: array<{t}>; -@group(0) @binding(3) var b_col_ptrs: array; -@group(0) @binding(4) var b_row_indices: array; -@group(0) @binding(5) var b_values: array<{t}>; -@group(0) @binding(6) var out_col_ptrs: array; -@group(0) @binding(7) var out_row_indices: array; -@group(0) @binding(8) var out_values: array<{t}>; -@group(0) @binding(9) var params: ComputeParams; - -@compute @workgroup_size(256) -fn csc_{op_name}_compute_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let col = gid.x; - if (col >= params.ncols) {{ - return; - }} - - let a_start = a_col_ptrs[col]; - let a_end = a_col_ptrs[col + 1u]; - let b_start = b_col_ptrs[col]; - let b_end = b_col_ptrs[col + 1u]; - - var out_idx = out_col_ptrs[col]; - var i: i32 = a_start; - var j: i32 = b_start; - - while (i < a_end && j < b_end) {{ - let a_row = a_row_indices[i]; - let b_row = b_row_indices[j]; - let a_val = a_values[i]; - let b_val = b_values[j]; - - if (a_row < b_row) {{ - out_row_indices[out_idx] = a_row; - out_values[out_idx] = {a_only_expr}; - out_idx = out_idx + 1; - i = i + 1; - }} else if (a_row > b_row) {{ - out_row_indices[out_idx] = b_row; - out_values[out_idx] = {b_only_expr}; - out_idx = out_idx + 1; - j = j + 1; - }} else {{ - out_row_indices[out_idx] = a_row; - out_values[out_idx] = {both_expr}; - out_idx = out_idx + 1; - i = i + 1; - j = j + 1; - }} - }} - - while (i < a_end) {{ - out_row_indices[out_idx] = a_row_indices[i]; - out_values[out_idx] = a_values[i]; - out_idx = out_idx + 1; - i = i + 1; - }} - - while (j < b_end) {{ - out_row_indices[out_idx] = b_row_indices[j]; - out_values[out_idx] = {b_only_expr_for_b}; - out_idx = out_idx + 1; - j = j + 1; - }} -}} -"#, - t = t, - op_name = op_name, - suffix = suffix, - both_expr = both_expr, - a_only_expr = a_only_expr, - b_only_expr = b_only_expr, - b_only_expr_for_b = if op_name == "sub" { - "-b_values[j]" - } else { - "b_values[j]" - }, - )) -} - -/// Internal helper for CSC mul/div compute (intersection semantics) -fn generate_csc_intersection_compute_shader( - dtype: DType, - op_name: &str, - expr: &str, -) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// CSC {op_name} compute kernel (intersection semantics) - -const WORKGROUP_SIZE: u32 = 256u; - -struct ComputeParams {{ - ncols: u32, -}} - -@group(0) @binding(0) var a_col_ptrs: array; -@group(0) @binding(1) var a_row_indices: array; -@group(0) @binding(2) var a_values: array<{t}>; -@group(0) @binding(3) var b_col_ptrs: array; -@group(0) @binding(4) var b_row_indices: array; -@group(0) @binding(5) var b_values: array<{t}>; -@group(0) @binding(6) var out_col_ptrs: array; -@group(0) @binding(7) var out_row_indices: array; -@group(0) @binding(8) var out_values: array<{t}>; -@group(0) @binding(9) var params: ComputeParams; - -@compute @workgroup_size(256) -fn csc_{op_name}_compute_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let col = gid.x; - if (col >= params.ncols) {{ - return; - }} - - let a_start = a_col_ptrs[col]; - let a_end = a_col_ptrs[col + 1u]; - let b_start = b_col_ptrs[col]; - let b_end = b_col_ptrs[col + 1u]; - - var out_idx = out_col_ptrs[col]; - var i: i32 = a_start; - var j: i32 = b_start; - - while (i < a_end && j < b_end) {{ - let a_row = a_row_indices[i]; - let b_row = b_row_indices[j]; - - if (a_row < b_row) {{ - i = i + 1; - }} else if (a_row > b_row) {{ - j = j + 1; - }} else {{ - let a_val = a_values[i]; - let b_val = b_values[j]; - out_row_indices[out_idx] = a_row; - out_values[out_idx] = {expr}; - out_idx = out_idx + 1; - i = i + 1; - j = j + 1; - }} - }} -}} -"#, - t = t, - op_name = op_name, - suffix = suffix, - expr = expr, - )) -} - -// ============================================================================ -// COO Format Shaders -// ============================================================================ - -// COO merge is more complex since entries aren't sorted by row/col. -// For simplicity, we convert COO to CSR, perform the merge, then optionally convert back. -// This is the standard approach since COO doesn't have efficient merge algorithms. - -// ============================================================================ -// Exclusive Scan (Prefix Sum) Shader -// ============================================================================ - -/// Generate WGSL shader for sequential exclusive scan (for small arrays) -/// -/// This is a simple sequential scan that works for the row_ptrs/col_ptrs arrays -/// which are typically small (O(nrows) or O(ncols)). -pub fn generate_exclusive_scan_shader() -> String { - r#"// Sequential exclusive scan for small arrays - -const WORKGROUP_SIZE: u32 = 256u; - -struct ScanParams { - n: u32, -} - -@group(0) @binding(0) var input: array; -@group(0) @binding(1) var output: array; -@group(0) @binding(2) var params: ScanParams; - -// Sequential exclusive scan - only first thread does work -// For parallel scan on larger arrays, use work-efficient parallel scan -@compute @workgroup_size(1) -fn exclusive_scan_i32(@builtin(global_invocation_id) gid: vec3) { - if (gid.x != 0u) { - return; - } - - var sum: i32 = 0; - for (var i: u32 = 0u; i < params.n; i = i + 1u) { - let val = input[i]; - output[i] = sum; - sum = sum + val; - } - // Final element is total sum - output[params.n] = sum; -} -"# - .to_string() -} - -#[cfg(test)] -mod tests { - use super::*; - - fn validate_wgsl_syntax(source: &str) -> std::result::Result<(), String> { - use wgpu::naga::front::wgsl; - let mut frontend = wgsl::Frontend::new(); - frontend - .parse(source) - .map(|_| ()) - .map_err(|e| format!("WGSL parse error: {e}")) - } - - #[test] - fn test_csr_merge_count_shader_syntax() { - let shader = generate_csr_merge_count_shader(); - validate_wgsl_syntax(&shader).expect("CSR merge count shader should be valid WGSL"); - } - - #[test] - fn test_csr_mul_count_shader_syntax() { - let shader = generate_csr_mul_count_shader(); - validate_wgsl_syntax(&shader).expect("CSR mul count shader should be valid WGSL"); - } - - #[test] - fn test_csr_add_compute_shader_syntax_f32() { - let shader = generate_csr_add_compute_shader(DType::F32).unwrap(); - validate_wgsl_syntax(&shader).expect("CSR add compute shader should be valid WGSL"); - } - - #[test] - fn test_csr_sub_compute_shader_syntax_f32() { - let shader = generate_csr_sub_compute_shader(DType::F32).unwrap(); - validate_wgsl_syntax(&shader).expect("CSR sub compute shader should be valid WGSL"); - } - - #[test] - fn test_csr_mul_compute_shader_syntax_f32() { - let shader = generate_csr_mul_compute_shader(DType::F32).unwrap(); - validate_wgsl_syntax(&shader).expect("CSR mul compute shader should be valid WGSL"); - } - - #[test] - fn test_csr_div_compute_shader_syntax_f32() { - let shader = generate_csr_div_compute_shader(DType::F32).unwrap(); - validate_wgsl_syntax(&shader).expect("CSR div compute shader should be valid WGSL"); - } - - #[test] - fn test_csc_merge_count_shader_syntax() { - let shader = generate_csc_merge_count_shader(); - validate_wgsl_syntax(&shader).expect("CSC merge count shader should be valid WGSL"); - } - - #[test] - fn test_csc_mul_count_shader_syntax() { - let shader = generate_csc_mul_count_shader(); - validate_wgsl_syntax(&shader).expect("CSC mul count shader should be valid WGSL"); - } - - #[test] - fn test_csc_add_compute_shader_syntax_f32() { - let shader = generate_csc_add_compute_shader(DType::F32).unwrap(); - validate_wgsl_syntax(&shader).expect("CSC add compute shader should be valid WGSL"); - } - - #[test] - fn test_exclusive_scan_shader_syntax() { - let shader = generate_exclusive_scan_shader(); - validate_wgsl_syntax(&shader).expect("Exclusive scan shader should be valid WGSL"); - } -} diff --git a/src/runtime/wgpu/shaders/generator/sparse_split.rs b/src/runtime/wgpu/shaders/generator/sparse_split.rs deleted file mode 100644 index d014a0c8..00000000 --- a/src/runtime/wgpu/shaders/generator/sparse_split.rs +++ /dev/null @@ -1,459 +0,0 @@ -//! WGSL shader generation for sparse matrix splitting operations. -//! -//! Split LU and extract lower triangle operations. - -use crate::dtype::DType; -use crate::error::{Error, Result}; - -use super::common::{is_wgpu_supported, wgsl_type}; - -/// Generate WGSL shader for counting L and U non-zeros per row (split_lu step 1) -pub fn generate_split_lu_count_shader() -> String { - r#"// Count L and U non-zeros per row for split_lu - -struct SplitLuCountParams { - n: u32, - _padding0: u32, - _padding1: u32, - _padding2: u32, -} - -// Note: All buffers use read_write due to LayoutKey-based pipeline layout -@group(0) @binding(0) var row_ptrs: array; -@group(0) @binding(1) var col_indices: array; -@group(0) @binding(2) var l_counts: array; -@group(0) @binding(3) var u_counts: array; -@group(0) @binding(4) var params: SplitLuCountParams; - -@compute @workgroup_size(256) -fn split_lu_count(@builtin(global_invocation_id) gid: vec3) { - let row = i32(gid.x); - if (u32(row) >= params.n) { - return; - } - - let start = row_ptrs[row]; - let end = row_ptrs[row + 1]; - - var l_count = 0i; - var u_count = 0i; - - for (var idx = start; idx < end; idx = idx + 1) { - let col = col_indices[idx]; - if (col < row) { - l_count = l_count + 1; - } else { - u_count = u_count + 1; - } - } - - l_counts[row] = l_count; - u_counts[row] = u_count; -} -"# - .to_string() -} - -/// Generate WGSL shader for scattering values into L and U (split_lu step 2) -pub fn generate_split_lu_scatter_shader(dtype: DType) -> Result { - if !is_wgpu_supported(dtype) { - return Err(Error::UnsupportedDType { - dtype, - op: "split_lu_scatter", - }); - } - - let t = wgsl_type(dtype)?; - let suffix = match dtype { - DType::F32 => "f32", - _ => { - return Err(Error::UnsupportedDType { - dtype, - op: "split_lu_scatter", - }); - } - }; - - Ok(format!( - r#"// Scatter values into L and U matrices - -struct SplitLuScatterParams {{ - n: u32, - _padding0: u32, - _padding1: u32, - _padding2: u32, -}} - -// Note: All buffers use read_write due to LayoutKey-based pipeline layout -@group(0) @binding(0) var row_ptrs: array; -@group(0) @binding(1) var col_indices: array; -@group(0) @binding(2) var values: array<{t}>; -@group(0) @binding(3) var l_row_ptrs: array; -@group(0) @binding(4) var l_col_indices: array; -@group(0) @binding(5) var l_values: array<{t}>; -@group(0) @binding(6) var u_row_ptrs: array; -@group(0) @binding(7) var u_col_indices: array; -@group(0) @binding(8) var u_values: array<{t}>; -@group(0) @binding(9) var params: SplitLuScatterParams; - -@compute @workgroup_size(256) -fn split_lu_scatter_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let row = i32(gid.x); - if (u32(row) >= params.n) {{ - return; - }} - - let src_start = row_ptrs[row]; - let src_end = row_ptrs[row + 1]; - - var l_write_pos = l_row_ptrs[row]; - var u_write_pos = u_row_ptrs[row]; - - for (var idx = src_start; idx < src_end; idx = idx + 1) {{ - let col = col_indices[idx]; - let val = values[idx]; - - if (col < row) {{ - // Lower triangle - l_col_indices[l_write_pos] = col; - l_values[l_write_pos] = val; - l_write_pos = l_write_pos + 1; - }} else {{ - // Upper triangle (includes diagonal) - u_col_indices[u_write_pos] = col; - u_values[u_write_pos] = val; - u_write_pos = u_write_pos + 1; - }} - }} -}} -"#, - t = t, - suffix = suffix - )) -} - -/// Generate WGSL shader for scattering values into L matrix only (split_lu part 1) -pub fn generate_split_lu_scatter_l_shader(dtype: DType) -> Result { - if !is_wgpu_supported(dtype) { - return Err(Error::UnsupportedDType { - dtype, - op: "split_lu_scatter_l", - }); - } - - let t = wgsl_type(dtype)?; - let suffix = match dtype { - DType::F32 => "f32", - _ => { - return Err(Error::UnsupportedDType { - dtype, - op: "split_lu_scatter_l", - }); - } - }; - - Ok(format!( - r#"// Scatter values into L matrix (lower triangle) - -struct SplitLuScatterParams {{ - n: u32, - _padding0: u32, - _padding1: u32, - _padding2: u32, -}} - -// Note: All buffers use read_write due to LayoutKey-based pipeline layout -@group(0) @binding(0) var row_ptrs: array; -@group(0) @binding(1) var col_indices: array; -@group(0) @binding(2) var values: array<{t}>; -@group(0) @binding(3) var l_row_ptrs: array; -@group(0) @binding(4) var l_col_indices: array; -@group(0) @binding(5) var l_values: array<{t}>; -@group(0) @binding(6) var params: SplitLuScatterParams; - -@compute @workgroup_size(256) -fn split_lu_scatter_l_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let row = i32(gid.x); - if (u32(row) >= params.n) {{ - return; - }} - - let src_start = row_ptrs[row]; - let src_end = row_ptrs[row + 1]; - var l_write_pos = l_row_ptrs[row]; - - for (var idx = src_start; idx < src_end; idx = idx + 1) {{ - let col = col_indices[idx]; - if (col < row) {{ - l_col_indices[l_write_pos] = col; - l_values[l_write_pos] = values[idx]; - l_write_pos = l_write_pos + 1; - }} - }} -}} -"#, - t = t, - suffix = suffix - )) -} - -/// Generate WGSL shader for scattering values into U matrix only (split_lu part 2) -pub fn generate_split_lu_scatter_u_shader(dtype: DType) -> Result { - if !is_wgpu_supported(dtype) { - return Err(Error::UnsupportedDType { - dtype, - op: "split_lu_scatter_u", - }); - } - - let t = wgsl_type(dtype)?; - let suffix = match dtype { - DType::F32 => "f32", - _ => { - return Err(Error::UnsupportedDType { - dtype, - op: "split_lu_scatter_u", - }); - } - }; - - Ok(format!( - r#"// Scatter values into U matrix (upper triangle + diagonal) - -struct SplitLuScatterParams {{ - n: u32, - _padding0: u32, - _padding1: u32, - _padding2: u32, -}} - -// Note: All buffers use read_write due to LayoutKey-based pipeline layout -@group(0) @binding(0) var row_ptrs: array; -@group(0) @binding(1) var col_indices: array; -@group(0) @binding(2) var values: array<{t}>; -@group(0) @binding(3) var u_row_ptrs: array; -@group(0) @binding(4) var u_col_indices: array; -@group(0) @binding(5) var u_values: array<{t}>; -@group(0) @binding(6) var params: SplitLuScatterParams; - -@compute @workgroup_size(256) -fn split_lu_scatter_u_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let row = i32(gid.x); - if (u32(row) >= params.n) {{ - return; - }} - - let src_start = row_ptrs[row]; - let src_end = row_ptrs[row + 1]; - var u_write_pos = u_row_ptrs[row]; - - for (var idx = src_start; idx < src_end; idx = idx + 1) {{ - let col = col_indices[idx]; - if (col >= row) {{ - u_col_indices[u_write_pos] = col; - u_values[u_write_pos] = values[idx]; - u_write_pos = u_write_pos + 1; - }} - }} -}} -"#, - t = t, - suffix = suffix - )) -} - -/// Generate WGSL shader for counting lower triangle non-zeros per row -pub fn generate_extract_lower_count_shader() -> String { - r#"// Count lower triangle non-zeros per row - -struct ExtractLowerCountParams { - n: u32, - _padding0: u32, - _padding1: u32, - _padding2: u32, -} - -// Note: All buffers use read_write due to LayoutKey-based pipeline layout -@group(0) @binding(0) var row_ptrs: array; -@group(0) @binding(1) var col_indices: array; -@group(0) @binding(2) var l_counts: array; -@group(0) @binding(3) var params: ExtractLowerCountParams; - -@compute @workgroup_size(256) -fn extract_lower_count(@builtin(global_invocation_id) gid: vec3) { - let row = i32(gid.x); - if (u32(row) >= params.n) { - return; - } - - let start = row_ptrs[row]; - let end = row_ptrs[row + 1]; - - var count = 0i; - - for (var idx = start; idx < end; idx = idx + 1) { - let col = col_indices[idx]; - if (col <= row) { - count = count + 1; - } - } - - l_counts[row] = count; -} -"# - .to_string() -} - -/// Generate WGSL shader for scattering lower triangle values -pub fn generate_extract_lower_scatter_shader(dtype: DType) -> Result { - if !is_wgpu_supported(dtype) { - return Err(Error::UnsupportedDType { - dtype, - op: "extract_lower_scatter", - }); - } - - let t = wgsl_type(dtype)?; - let suffix = match dtype { - DType::F32 => "f32", - _ => { - return Err(Error::UnsupportedDType { - dtype, - op: "extract_lower_scatter", - }); - } - }; - - Ok(format!( - r#"// Scatter lower triangle values - -struct ExtractLowerScatterParams {{ - n: u32, - _padding0: u32, - _padding1: u32, - _padding2: u32, -}} - -// Note: All buffers use read_write due to LayoutKey-based pipeline layout -@group(0) @binding(0) var row_ptrs: array; -@group(0) @binding(1) var col_indices: array; -@group(0) @binding(2) var values: array<{t}>; -@group(0) @binding(3) var l_row_ptrs: array; -@group(0) @binding(4) var l_col_indices: array; -@group(0) @binding(5) var l_values: array<{t}>; -@group(0) @binding(6) var params: ExtractLowerScatterParams; - -@compute @workgroup_size(256) -fn extract_lower_scatter_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let row = i32(gid.x); - if (u32(row) >= params.n) {{ - return; - }} - - let src_start = row_ptrs[row]; - let src_end = row_ptrs[row + 1]; - - var write_pos = l_row_ptrs[row]; - - for (var idx = src_start; idx < src_end; idx = idx + 1) {{ - let col = col_indices[idx]; - if (col <= row) {{ - l_col_indices[write_pos] = col; - l_values[write_pos] = values[idx]; - write_pos = write_pos + 1; - }} - }} -}} -"#, - t = t, - suffix = suffix - )) -} - -#[cfg(test)] -mod tests { - use super::*; - - fn validate_wgsl_syntax(source: &str) -> std::result::Result<(), String> { - use wgpu::naga::front::wgsl; - let mut frontend = wgsl::Frontend::new(); - frontend - .parse(source) - .map(|_| ()) - .map_err(|e| format!("WGSL parse error: {e}")) - } - - #[test] - fn test_split_lu_count_shader_syntax() { - let shader = generate_split_lu_count_shader(); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for split_lu_count:\n{}\n\nShader:\n{}", - e, shader - ) - }); - } - - #[test] - fn test_split_lu_scatter_shader_syntax() { - let shader = generate_split_lu_scatter_shader(DType::F32).unwrap(); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for split_lu_scatter:\n{}\n\nShader:\n{}", - e, shader - ) - }); - } - - #[test] - fn test_split_lu_scatter_l_shader_syntax() { - let shader = generate_split_lu_scatter_l_shader(DType::F32).unwrap(); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for split_lu_scatter_l:\n{}\n\nShader:\n{}", - e, shader - ) - }); - } - - #[test] - fn test_split_lu_scatter_u_shader_syntax() { - let shader = generate_split_lu_scatter_u_shader(DType::F32).unwrap(); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for split_lu_scatter_u:\n{}\n\nShader:\n{}", - e, shader - ) - }); - } - - #[test] - fn test_extract_lower_count_shader_syntax() { - let shader = generate_extract_lower_count_shader(); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for extract_lower_count:\n{}\n\nShader:\n{}", - e, shader - ) - }); - } - - #[test] - fn test_extract_lower_scatter_shader_syntax() { - let shader = generate_extract_lower_scatter_shader(DType::F32).unwrap(); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for extract_lower_scatter:\n{}\n\nShader:\n{}", - e, shader - ) - }); - } - - #[test] - fn test_f64_not_supported() { - assert!(generate_split_lu_scatter_shader(DType::F64).is_err()); - assert!(generate_split_lu_scatter_l_shader(DType::F64).is_err()); - assert!(generate_split_lu_scatter_u_shader(DType::F64).is_err()); - assert!(generate_extract_lower_scatter_shader(DType::F64).is_err()); - } -} diff --git a/src/runtime/wgpu/shaders/generator/sparse_trsv.rs b/src/runtime/wgpu/shaders/generator/sparse_trsv.rs deleted file mode 100644 index 1e223e36..00000000 --- a/src/runtime/wgpu/shaders/generator/sparse_trsv.rs +++ /dev/null @@ -1,353 +0,0 @@ -//! WGSL shader generation for sparse triangular solve operations. -//! -//! Level-scheduled sparse triangular solve (forward and backward substitution). - -use crate::dtype::DType; -use crate::error::{Error, Result}; - -use super::common::{is_wgpu_supported, wgsl_type}; - -/// Generate WGSL shader for level-scheduled sparse lower triangular solve -pub fn generate_sparse_trsv_lower_shader(dtype: DType) -> Result { - if !is_wgpu_supported(dtype) { - return Err(Error::UnsupportedDType { - dtype, - op: "sparse_trsv_lower", - }); - } - - let t = wgsl_type(dtype)?; - let suffix = match dtype { - DType::F32 => "f32", - _ => { - return Err(Error::UnsupportedDType { - dtype, - op: "sparse_trsv_lower", - }); - } - }; - - Ok(format!( - r#"// Level-scheduled sparse lower triangular solve (forward substitution) -// Processes all rows in a single level in parallel - -struct TrsvParams {{ - level_size: u32, - n: u32, - unit_diagonal: u32, - level_start: u32, -}} - -@group(0) @binding(0) var level_rows: array; -@group(0) @binding(1) var row_ptrs: array; -@group(0) @binding(2) var col_indices: array; -@group(0) @binding(3) var values: array<{t}>; -@group(0) @binding(4) var b: array<{t}>; -@group(0) @binding(5) var x: array<{t}>; -@group(0) @binding(6) var params: TrsvParams; - -@compute @workgroup_size(256) -fn sparse_trsv_lower_level_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let tid = gid.x; - if (tid >= params.level_size) {{ - return; - }} - - let row = level_rows[params.level_start + tid]; - let start = row_ptrs[row]; - let end = row_ptrs[row + 1]; - - var sum = b[row]; - var diag = {t}(1.0); - - for (var idx = start; idx < end; idx = idx + 1) {{ - let col = col_indices[idx]; - if (col < row) {{ - sum = sum - values[idx] * x[col]; - }} else if (col == row && params.unit_diagonal == 0u) {{ - diag = values[idx]; - }} - }} - - if (params.unit_diagonal == 0u) {{ - sum = sum / diag; - }} - - x[row] = sum; -}} -"#, - t = t, - suffix = suffix - )) -} - -/// Generate WGSL shader for level-scheduled sparse upper triangular solve -pub fn generate_sparse_trsv_upper_shader(dtype: DType) -> Result { - if !is_wgpu_supported(dtype) { - return Err(Error::UnsupportedDType { - dtype, - op: "sparse_trsv_upper", - }); - } - - let t = wgsl_type(dtype)?; - let suffix = match dtype { - DType::F32 => "f32", - _ => { - return Err(Error::UnsupportedDType { - dtype, - op: "sparse_trsv_upper", - }); - } - }; - - Ok(format!( - r#"// Level-scheduled sparse upper triangular solve (backward substitution) - -struct TrsvParams {{ - level_size: u32, - n: u32, - _pad0: u32, - level_start: u32, -}} - -@group(0) @binding(0) var level_rows: array; -@group(0) @binding(1) var row_ptrs: array; -@group(0) @binding(2) var col_indices: array; -@group(0) @binding(3) var values: array<{t}>; -@group(0) @binding(4) var b: array<{t}>; -@group(0) @binding(5) var x: array<{t}>; -@group(0) @binding(6) var params: TrsvParams; - -@compute @workgroup_size(256) -fn sparse_trsv_upper_level_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let tid = gid.x; - if (tid >= params.level_size) {{ - return; - }} - - let row = level_rows[params.level_start + tid]; - let start = row_ptrs[row]; - let end = row_ptrs[row + 1]; - - var sum = b[row]; - var diag = {t}(1.0); - - for (var idx = start; idx < end; idx = idx + 1) {{ - let col = col_indices[idx]; - if (col > row) {{ - sum = sum - values[idx] * x[col]; - }} else if (col == row) {{ - diag = values[idx]; - }} - }} - - x[row] = sum / diag; -}} -"#, - t = t, - suffix = suffix - )) -} - -/// Generate WGSL shader for multi-RHS level-scheduled sparse lower triangular solve -/// Handles b and x with shape [n, nrhs] in row-major order -pub fn generate_sparse_trsv_lower_multi_rhs_shader(dtype: DType) -> Result { - if !is_wgpu_supported(dtype) { - return Err(Error::UnsupportedDType { - dtype, - op: "sparse_trsv_lower_multi_rhs", - }); - } - - let t = wgsl_type(dtype)?; - let suffix = match dtype { - DType::F32 => "f32", - _ => { - return Err(Error::UnsupportedDType { - dtype, - op: "sparse_trsv_lower_multi_rhs", - }); - } - }; - - Ok(format!( - r#"// Multi-RHS level-scheduled sparse lower triangular solve (forward substitution) -// Processes all (row, rhs_column) pairs in a single level in parallel - -struct TrsvMultiRhsParams {{ - level_size: u32, - nrhs: u32, - n: u32, - unit_diagonal: u32, - level_start: u32, - _pad0: u32, - _pad1: u32, - _pad2: u32, -}} - -@group(0) @binding(0) var level_rows: array; -@group(0) @binding(1) var row_ptrs: array; -@group(0) @binding(2) var col_indices: array; -@group(0) @binding(3) var values: array<{t}>; -@group(0) @binding(4) var b: array<{t}>; -@group(0) @binding(5) var x: array<{t}>; -@group(0) @binding(6) var params: TrsvMultiRhsParams; - -@compute @workgroup_size(256) -fn sparse_trsv_lower_level_multi_rhs_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let tid = gid.x; - let total_work = params.level_size * params.nrhs; - if (tid >= total_work) {{ - return; - }} - - let row_idx = tid / params.nrhs; - let rhs_col = tid % params.nrhs; - let row = level_rows[params.level_start + row_idx]; - - let start = row_ptrs[row]; - let end = row_ptrs[row + 1]; - - var sum = b[u32(row) * params.nrhs + rhs_col]; - var diag = {t}(1.0); - - for (var idx = start; idx < end; idx = idx + 1) {{ - let col = col_indices[idx]; - if (col < row) {{ - sum = sum - values[idx] * x[u32(col) * params.nrhs + rhs_col]; - }} else if (col == row && params.unit_diagonal == 0u) {{ - diag = values[idx]; - }} - }} - - if (params.unit_diagonal == 0u) {{ - sum = sum / diag; - }} - - x[u32(row) * params.nrhs + rhs_col] = sum; -}} -"#, - t = t, - suffix = suffix - )) -} - -/// Generate WGSL shader for multi-RHS level-scheduled sparse upper triangular solve -pub fn generate_sparse_trsv_upper_multi_rhs_shader(dtype: DType) -> Result { - if !is_wgpu_supported(dtype) { - return Err(Error::UnsupportedDType { - dtype, - op: "sparse_trsv_upper_multi_rhs", - }); - } - - let t = wgsl_type(dtype)?; - let suffix = match dtype { - DType::F32 => "f32", - _ => { - return Err(Error::UnsupportedDType { - dtype, - op: "sparse_trsv_upper_multi_rhs", - }); - } - }; - - Ok(format!( - r#"// Multi-RHS level-scheduled sparse upper triangular solve (backward substitution) - -struct TrsvMultiRhsParams {{ - level_size: u32, - nrhs: u32, - n: u32, - _pad0: u32, - level_start: u32, - _pad1: u32, - _pad2: u32, - _pad3: u32, -}} - -@group(0) @binding(0) var level_rows: array; -@group(0) @binding(1) var row_ptrs: array; -@group(0) @binding(2) var col_indices: array; -@group(0) @binding(3) var values: array<{t}>; -@group(0) @binding(4) var b: array<{t}>; -@group(0) @binding(5) var x: array<{t}>; -@group(0) @binding(6) var params: TrsvMultiRhsParams; - -@compute @workgroup_size(256) -fn sparse_trsv_upper_level_multi_rhs_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let tid = gid.x; - let total_work = params.level_size * params.nrhs; - if (tid >= total_work) {{ - return; - }} - - let row_idx = tid / params.nrhs; - let rhs_col = tid % params.nrhs; - let row = level_rows[params.level_start + row_idx]; - - let start = row_ptrs[row]; - let end = row_ptrs[row + 1]; - - var sum = b[u32(row) * params.nrhs + rhs_col]; - var diag = {t}(1.0); - - for (var idx = start; idx < end; idx = idx + 1) {{ - let col = col_indices[idx]; - if (col > row) {{ - sum = sum - values[idx] * x[u32(col) * params.nrhs + rhs_col]; - }} else if (col == row) {{ - diag = values[idx]; - }} - }} - - x[u32(row) * params.nrhs + rhs_col] = sum / diag; -}} -"#, - t = t, - suffix = suffix - )) -} - -#[cfg(test)] -mod tests { - use super::*; - - fn validate_wgsl_syntax(source: &str) -> std::result::Result<(), String> { - use wgpu::naga::front::wgsl; - let mut frontend = wgsl::Frontend::new(); - frontend - .parse(source) - .map(|_| ()) - .map_err(|e| format!("WGSL parse error: {e}")) - } - - #[test] - fn test_sparse_trsv_lower_shader_syntax() { - let shader = generate_sparse_trsv_lower_shader(DType::F32).unwrap(); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for sparse_trsv_lower:\n{}\n\nShader:\n{}", - e, shader - ) - }); - } - - #[test] - fn test_sparse_trsv_upper_shader_syntax() { - let shader = generate_sparse_trsv_upper_shader(DType::F32).unwrap(); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for sparse_trsv_upper:\n{}\n\nShader:\n{}", - e, shader - ) - }); - } - - #[test] - fn test_f64_not_supported() { - assert!(generate_sparse_trsv_lower_shader(DType::F64).is_err()); - assert!(generate_sparse_trsv_upper_shader(DType::F64).is_err()); - } -} diff --git a/src/runtime/wgpu/shaders/generator/sparse_utils.rs b/src/runtime/wgpu/shaders/generator/sparse_utils.rs deleted file mode 100644 index 417e9bc4..00000000 --- a/src/runtime/wgpu/shaders/generator/sparse_utils.rs +++ /dev/null @@ -1,124 +0,0 @@ -//! WGSL shader generation for sparse utility operations. -//! -//! Finding diagonal indices and copying vectors. - -use crate::dtype::DType; -use crate::error::{Error, Result}; - -use super::common::{is_wgpu_supported, wgsl_type}; - -/// Generate WGSL shader for finding diagonal indices -pub fn generate_find_diag_indices_shader() -> String { - r#"// Find diagonal indices in CSR matrix - -struct DiagParams { - n: u32, - _padding0: u32, - _padding1: u32, - _padding2: u32, -} - -@group(0) @binding(0) var row_ptrs: array; -@group(0) @binding(1) var col_indices: array; -@group(0) @binding(2) var diag_indices: array; -@group(0) @binding(3) var params: DiagParams; - -@compute @workgroup_size(256) -fn find_diag_indices(@builtin(global_invocation_id) gid: vec3) { - let row = i32(gid.x); - if (u32(row) >= params.n) { - return; - } - - let start = row_ptrs[row]; - let end = row_ptrs[row + 1]; - - diag_indices[row] = -1; // Default: no diagonal found - - for (var idx = start; idx < end; idx = idx + 1) { - if (col_indices[idx] == row) { - diag_indices[row] = idx; - break; - } - } -} -"# - .to_string() -} - -/// Generate WGSL shader for copying vectors -pub fn generate_copy_shader(dtype: DType) -> Result { - if !is_wgpu_supported(dtype) { - return Err(Error::UnsupportedDType { dtype, op: "copy" }); - } - - let t = wgsl_type(dtype)?; - let suffix = match dtype { - DType::F32 => "f32", - _ => return Err(Error::UnsupportedDType { dtype, op: "copy" }), - }; - - Ok(format!( - r#"// Copy vector - -struct CopyParams {{ - n: u32, - _padding0: u32, - _padding1: u32, - _padding2: u32, -}} - -// Note: All buffers use read_write for compatibility with LayoutKey-based layouts -@group(0) @binding(0) var src: array<{t}>; -@group(0) @binding(1) var dst: array<{t}>; -@group(0) @binding(2) var params: CopyParams; - -@compute @workgroup_size(256) -fn copy_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < params.n) {{ - dst[idx] = src[idx]; - }} -}} -"#, - t = t, - suffix = suffix - )) -} - -#[cfg(test)] -mod tests { - use super::*; - - fn validate_wgsl_syntax(source: &str) -> std::result::Result<(), String> { - use wgpu::naga::front::wgsl; - let mut frontend = wgsl::Frontend::new(); - frontend - .parse(source) - .map(|_| ()) - .map_err(|e| format!("WGSL parse error: {e}")) - } - - #[test] - fn test_find_diag_indices_shader_syntax() { - let shader = generate_find_diag_indices_shader(); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for find_diag_indices:\n{}\n\nShader:\n{}", - e, shader - ) - }); - } - - #[test] - fn test_copy_shader_syntax() { - let shader = generate_copy_shader(DType::F32).unwrap(); - validate_wgsl_syntax(&shader) - .unwrap_or_else(|e| panic!("Invalid WGSL for copy:\n{}\n\nShader:\n{}", e, shader)); - } - - #[test] - fn test_f64_not_supported() { - assert!(generate_copy_shader(DType::F64).is_err()); - } -} diff --git a/src/runtime/wgpu/shaders/generator/special/binary.rs b/src/runtime/wgpu/shaders/generator/special/binary.rs deleted file mode 100644 index d864edd8..00000000 --- a/src/runtime/wgpu/shaders/generator/special/binary.rs +++ /dev/null @@ -1,158 +0,0 @@ -//! WGSL shader generation for special binary functions -//! -//! Generates shaders for: beta, gammainc, gammaincc - -use super::super::common::{dtype_suffix, wgsl_type}; -use super::{common_constants, lgamma_helpers}; -use crate::dtype::DType; -use crate::error::{Error, Result}; - -/// Generate WGSL shader for special binary functions (beta, gammainc, gammaincc) -pub fn generate_special_binary_shader(dtype: DType) -> Result { - if dtype != DType::F32 { - return Err(Error::UnsupportedDType { - dtype, - op: "special functions (WebGPU requires F32)", - }); - } - - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Auto-generated special binary functions for {t} - -{constants} - -struct SpecialBinaryParams {{ - numel: u32, -}} - -@group(0) @binding(0) var special_a: array<{t}>; -@group(0) @binding(1) var special_b: array<{t}>; -@group(0) @binding(2) var special_out: array<{t}>; -@group(0) @binding(3) var special_params: SpecialBinaryParams; - -// ============================================================================ -// Helper Functions (shared lgamma) -// ============================================================================ -{lgamma_helpers} - -// Lower incomplete gamma series -fn gammainc_series(a: f32, x: f32) -> f32 {{ - if (x == 0.0) {{ - return 0.0; - }} - - var term = 1.0 / a; - var sum = term; - - for (var n = 1; n < MAX_ITER; n = n + 1) {{ - term = term * x / (a + f32(n)); - sum = sum + term; - if (abs(term) < abs(sum) * EPSILON) {{ - break; - }} - }} - - return exp(-x + a * log(x) - lgamma_impl(a)) * sum; -}} - -// Upper incomplete gamma continued fraction -fn gammaincc_cf(a: f32, x: f32) -> f32 {{ - var f = 1e30; - var c = 1e30; - var d = 0.0; - - for (var n = 1; n < MAX_ITER; n = n + 1) {{ - var an: f32; - if (n % 2 == 1) {{ - an = f32((n + 1) / 2); - }} else {{ - an = a - f32(n / 2); - }} - let bn = x + f32(n) - a; - - d = bn + an * d; - if (abs(d) < TINY) {{ - d = TINY; - }} - c = bn + an / c; - if (abs(c) < TINY) {{ - c = TINY; - }} - - d = 1.0 / d; - let delta = c * d; - f = f * delta; - - if (abs(delta - 1.0) < EPSILON) {{ - break; - }} - }} - - return exp(-x + a * log(x) - lgamma_impl(a)) / f; -}} - -fn gammainc_impl(a: f32, x: f32) -> f32 {{ - if (x < 0.0 || a <= 0.0) {{ - return bitcast(0x7FC00000u); // NaN - }} - if (x == 0.0) {{ - return 0.0; - }} - if (x < a + 1.0) {{ - return gammainc_series(a, x); - }} - return 1.0 - gammaincc_cf(a, x); -}} - -fn gammaincc_impl(a: f32, x: f32) -> f32 {{ - if (x < 0.0 || a <= 0.0) {{ - return bitcast(0x7FC00000u); // NaN - }} - if (x == 0.0) {{ - return 1.0; - }} - if (x < a + 1.0) {{ - return 1.0 - gammainc_series(a, x); - }} - return gammaincc_cf(a, x); -}} - -// ============================================================================ -// Compute Kernels -// ============================================================================ - -@compute @workgroup_size(256) -fn beta_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < special_params.numel) {{ - let a = special_a[idx]; - let b = special_b[idx]; - special_out[idx] = exp(lgamma_impl(a) + lgamma_impl(b) - lgamma_impl(a + b)); - }} -}} - -@compute @workgroup_size(256) -fn gammainc_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < special_params.numel) {{ - special_out[idx] = gammainc_impl(special_a[idx], special_b[idx]); - }} -}} - -@compute @workgroup_size(256) -fn gammaincc_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < special_params.numel) {{ - special_out[idx] = gammaincc_impl(special_a[idx], special_b[idx]); - }} -}} -"#, - t = t, - suffix = suffix, - constants = common_constants(), - lgamma_helpers = lgamma_helpers() - )) -} diff --git a/src/runtime/wgpu/shaders/generator/special/mod.rs b/src/runtime/wgpu/shaders/generator/special/mod.rs deleted file mode 100644 index 36c39a68..00000000 --- a/src/runtime/wgpu/shaders/generator/special/mod.rs +++ /dev/null @@ -1,90 +0,0 @@ -//! WGSL shader generation for special mathematical functions -//! -//! Implements erf, erfc, erfinv, gamma, lgamma, digamma, beta, -//! betainc, gammainc, gammaincc using numerical algorithms in WGSL. -//! -//! # Module Structure -//! -//! - `common` - Shared constants and helper functions -//! - `unary` - Unary function shaders (erf, erfc, erfinv, gamma, lgamma, digamma) -//! - `binary` - Binary function shaders (beta, gammainc, gammaincc) -//! - `ternary` - Ternary function shaders (betainc) - -mod binary; -mod ternary; -mod unary; - -pub use binary::generate_special_binary_shader; -pub use ternary::generate_special_ternary_shader; -pub use unary::generate_special_unary_shader; - -// ============================================================================ -// Shared Constants and Helpers -// ============================================================================ - -/// Generate WGSL constants used by all special function shaders. -pub(super) fn common_constants() -> &'static str { - r#"const WORKGROUP_SIZE: u32 = 256u; -const PI: f32 = 3.14159265358979323846; -const SQRT_PI: f32 = 1.7724538509055159; -const EULER_GAMMA: f32 = 0.5772156649015329; -const LN_SQRT_2PI: f32 = 0.9189385332046727; -const LANCZOS_G: f32 = 7.0; -const MAX_ITER: i32 = 100; -const EPSILON: f32 = 1e-6; -const TINY: f32 = 1e-30;"# -} - -/// Generate the common lgamma helper functions used by multiple shaders. -/// -/// These functions are shared between unary, binary, and ternary shaders -/// to avoid code duplication (~50 lines saved per shader). -pub(super) fn lgamma_helpers() -> &'static str { - r#" -// Lanczos computation for positive x only (no recursion) -fn lgamma_positive(x: f32) -> f32 { - // Lanczos coefficients (g=7, n=9) - let c0 = 0.99999999999980993; - let c1 = 676.5203681218851; - let c2 = -1259.1392167224028; - let c3 = 771.32342877765313; - let c4 = -176.61502916214059; - let c5 = 12.507343278686905; - let c6 = -0.13857109526572012; - let c7 = 9.9843695780195716e-6; - let c8 = 1.5056327351493116e-7; - - let z = x - 1.0; - var ag = c0; - ag = ag + c1 / (z + 1.0); - ag = ag + c2 / (z + 2.0); - ag = ag + c3 / (z + 3.0); - ag = ag + c4 / (z + 4.0); - ag = ag + c5 / (z + 5.0); - ag = ag + c6 / (z + 6.0); - ag = ag + c7 / (z + 7.0); - ag = ag + c8 / (z + 8.0); - - let t = z + LANCZOS_G + 0.5; - return LN_SQRT_2PI + (z + 0.5) * log(t) - t + log(ag); -} - -// Log-gamma using Lanczos approximation (non-recursive) -fn lgamma_impl(x: f32) -> f32 { - if (x <= 0.0) { - // Use reflection formula for negative values - if (x == floor(x)) { - return 1e30; // Pole at non-positive integers - } - // lgamma(x) = log(pi / sin(pi*x)) - lgamma(1-x) - // Since 1-x > 0 for x <= 0, we call lgamma_positive directly - let sinpix = sin(PI * x); - if (sinpix == 0.0) { - return 1e30; - } - return log(PI / abs(sinpix)) - lgamma_positive(1.0 - x); - } - - return lgamma_positive(x); -}"# -} diff --git a/src/runtime/wgpu/shaders/generator/special/ternary.rs b/src/runtime/wgpu/shaders/generator/special/ternary.rs deleted file mode 100644 index ef5d03f4..00000000 --- a/src/runtime/wgpu/shaders/generator/special/ternary.rs +++ /dev/null @@ -1,127 +0,0 @@ -//! WGSL shader generation for special ternary functions -//! -//! Generates shaders for: betainc - -use super::super::common::{dtype_suffix, wgsl_type}; -use super::{common_constants, lgamma_helpers}; -use crate::dtype::DType; -use crate::error::{Error, Result}; - -/// Generate WGSL shader for betainc (ternary: a, b, x) -pub fn generate_special_ternary_shader(dtype: DType) -> Result { - if dtype != DType::F32 { - return Err(Error::UnsupportedDType { - dtype, - op: "special functions (WebGPU requires F32)", - }); - } - - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Auto-generated special ternary functions for {t} - -{constants} - -struct SpecialTernaryParams {{ - numel: u32, -}} - -@group(0) @binding(0) var special_a: array<{t}>; -@group(0) @binding(1) var special_b: array<{t}>; -@group(0) @binding(2) var special_x: array<{t}>; -@group(0) @binding(3) var special_out: array<{t}>; -@group(0) @binding(4) var special_params: SpecialTernaryParams; - -// ============================================================================ -// Helper Functions (shared lgamma) -// ============================================================================ -{lgamma_helpers} - -// Regularized incomplete beta using continued fraction -fn betainc_cf(a: f32, b: f32, x: f32) -> f32 {{ - let qab = a + b; - let qap = a + 1.0; - let qam = a - 1.0; - - var c = 1.0; - var d = 1.0 - qab * x / qap; - if (abs(d) < TINY) {{ - d = TINY; - }} - d = 1.0 / d; - var h = d; - - for (var m = 1; m < MAX_ITER; m = m + 1) {{ - let m2 = 2 * m; - - var aa = f32(m) * (b - f32(m)) * x / ((qam + f32(m2)) * (a + f32(m2))); - d = 1.0 + aa * d; - if (abs(d) < TINY) {{ - d = TINY; - }} - c = 1.0 + aa / c; - if (abs(c) < TINY) {{ - c = TINY; - }} - d = 1.0 / d; - h = h * d * c; - - aa = -(a + f32(m)) * (qab + f32(m)) * x / ((a + f32(m2)) * (qap + f32(m2))); - d = 1.0 + aa * d; - if (abs(d) < TINY) {{ - d = TINY; - }} - c = 1.0 + aa / c; - if (abs(c) < TINY) {{ - c = TINY; - }} - d = 1.0 / d; - let delta = d * c; - h = h * delta; - - if (abs(delta - 1.0) < EPSILON) {{ - break; - }} - }} - - let lnbeta = lgamma_impl(a) + lgamma_impl(b) - lgamma_impl(a + b); - return exp(a * log(x) + b * log(1.0 - x) - lnbeta) * h / a; -}} - -fn betainc_impl(a: f32, b: f32, x: f32) -> f32 {{ - if (x <= 0.0) {{ - return 0.0; - }} - if (x >= 1.0) {{ - return 1.0; - }} - - // Use symmetry for better convergence (non-recursive version) - if (x > (a + 1.0) / (a + b + 2.0)) {{ - // Compute directly without recursion using symmetry - return 1.0 - betainc_cf(b, a, 1.0 - x); - }} - - return betainc_cf(a, b, x); -}} - -// ============================================================================ -// Compute Kernels -// ============================================================================ - -@compute @workgroup_size(256) -fn betainc_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < special_params.numel) {{ - special_out[idx] = betainc_impl(special_a[idx], special_b[idx], special_x[idx]); - }} -}} -"#, - t = t, - suffix = suffix, - constants = common_constants(), - lgamma_helpers = lgamma_helpers() - )) -} diff --git a/src/runtime/wgpu/shaders/generator/spmv.rs b/src/runtime/wgpu/shaders/generator/spmv.rs deleted file mode 100644 index 1facb379..00000000 --- a/src/runtime/wgpu/shaders/generator/spmv.rs +++ /dev/null @@ -1,218 +0,0 @@ -//! WGSL shader generation for sparse matrix-vector and matrix-matrix multiplication. -//! -//! SpMV (y = A * x) and SpMM (C = A * B) for CSR format matrices. -//! Row-parallel implementation that doesn't require atomics. - -use super::common::{dtype_suffix, wgsl_type}; -use crate::dtype::DType; -use crate::error::Result; - -/// Generate WGSL shader for CSR SpMV: y = A * x -/// -/// Each workgroup thread processes one row of the sparse matrix. -pub fn generate_csr_spmv_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// CSR Sparse Matrix-Vector Multiplication: y = A * x -// Row-parallel implementation: one thread per row - -const WORKGROUP_SIZE: u32 = 256u; - -struct SpmvParams {{ - nrows: u32, - ncols: u32, - _pad0: u32, - _pad1: u32, -}} - -// CSR format -@group(0) @binding(0) var row_ptrs: array; -@group(0) @binding(1) var col_indices: array; -@group(0) @binding(2) var values: array<{t}>; -// Dense vector x -@group(0) @binding(3) var x: array<{t}>; -// Output vector y -@group(0) @binding(4) var y: array<{t}>; -// Parameters -@group(0) @binding(5) var params: SpmvParams; - -@compute @workgroup_size(256) -fn csr_spmv_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let row = gid.x; - if (row >= params.nrows) {{ - return; - }} - - let row_start = row_ptrs[row]; - let row_end = row_ptrs[row + 1u]; - - var sum: {t} = {zero}; - for (var j: i32 = row_start; j < row_end; j = j + 1) {{ - let col = col_indices[j]; - sum = sum + values[j] * x[col]; - }} - - y[row] = sum; -}} -"#, - t = t, - suffix = suffix, - zero = zero_literal(dtype), - )) -} - -/// Generate WGSL shader for CSR SpMM: C = A * B -/// -/// Row-parallel implementation where each thread computes one element of C. -/// Thread (row, col) computes C[row, col] = sum(A[row, :] * B[:, col]) -pub fn generate_csr_spmm_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// CSR Sparse Matrix-Dense Matrix Multiplication: C = A * B -// Each thread computes one output element C[row, col] - -const WORKGROUP_SIZE: u32 = 256u; - -struct SpmmParams {{ - m: u32, // Number of rows in A (and C) - k: u32, // Number of columns in A (and rows in B) - n: u32, // Number of columns in B (and C) - _pad: u32, -}} - -// CSR format for A -@group(0) @binding(0) var row_ptrs: array; -@group(0) @binding(1) var col_indices: array; -@group(0) @binding(2) var a_values: array<{t}>; -// Dense matrix B (k x n, row-major) -@group(0) @binding(3) var b: array<{t}>; -// Output matrix C (m x n, row-major) -@group(0) @binding(4) var c: array<{t}>; -// Parameters -@group(0) @binding(5) var params: SpmmParams; - -@compute @workgroup_size(256) -fn csr_spmm_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - let total = params.m * params.n; - if (idx >= total) {{ - return; - }} - - let row = idx / params.n; - let col = idx % params.n; - - let row_start = row_ptrs[row]; - let row_end = row_ptrs[row + 1u]; - - var sum: {t} = {zero}; - for (var j: i32 = row_start; j < row_end; j = j + 1) {{ - let a_col = col_indices[j]; - let a_val = a_values[j]; - // B is row-major: B[a_col, col] = b[a_col * n + col] - let b_idx = u32(a_col) * params.n + col; - sum = sum + a_val * b[b_idx]; - }} - - // C is row-major: C[row, col] = c[row * n + col] - c[idx] = sum; -}} -"#, - t = t, - suffix = suffix, - zero = zero_literal(dtype), - )) -} - -/// Generate WGSL shader for CSR diagonal extraction: `diag[i] = A[i,i]` -/// -/// Thread-per-row: each thread scans its row for the diagonal entry. -pub fn generate_csr_extract_diagonal_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// CSR Extract Diagonal: diag[i] = A[i,i] -// Thread-per-row: each thread scans one row for col_index == row_index - -const WORKGROUP_SIZE: u32 = 256u; - -struct DiagParams {{ - n: u32, - _pad0: u32, - _pad1: u32, - _pad2: u32, -}} - -@group(0) @binding(0) var row_ptrs: array; -@group(0) @binding(1) var col_indices: array; -@group(0) @binding(2) var values: array<{t}>; -@group(0) @binding(3) var diag: array<{t}>; -@group(0) @binding(4) var params: DiagParams; - -@compute @workgroup_size(256) -fn csr_extract_diagonal_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let row = gid.x; - if (row >= params.n) {{ - return; - }} - - let row_start = row_ptrs[row]; - let row_end = row_ptrs[row + 1u]; - - var val: {t} = {zero}; - for (var j: i32 = row_start; j < row_end; j = j + 1) {{ - if (col_indices[j] == i32(row)) {{ - val = values[j]; - break; - }} - }} - - diag[row] = val; -}} -"#, - t = t, - suffix = suffix, - zero = zero_literal(dtype), - )) -} - -/// Get zero literal for dtype -fn zero_literal(dtype: DType) -> &'static str { - match dtype { - DType::F32 | DType::F16 => "0.0", - DType::I32 => "0", - DType::U32 => "0u", - _ => "0.0", - } -} - -#[cfg(test)] -mod tests { - use super::*; - - fn validate_wgsl_syntax(source: &str) -> std::result::Result<(), String> { - use wgpu::naga::front::wgsl; - let mut frontend = wgsl::Frontend::new(); - frontend - .parse(source) - .map(|_| ()) - .map_err(|e| format!("WGSL parse error: {e}")) - } - - #[test] - fn test_csr_spmv_shader_syntax_f32() { - let shader = generate_csr_spmv_shader(DType::F32).unwrap(); - validate_wgsl_syntax(&shader).expect("SpMV shader should be valid WGSL"); - } - - #[test] - fn test_csr_spmm_shader_syntax_f32() { - let shader = generate_csr_spmm_shader(DType::F32).unwrap(); - validate_wgsl_syntax(&shader).expect("SpMM shader should be valid WGSL"); - } -} diff --git a/src/runtime/wgpu/shaders/generator/unary.rs b/src/runtime/wgpu/shaders/generator/unary.rs deleted file mode 100644 index ed9db45d..00000000 --- a/src/runtime/wgpu/shaders/generator/unary.rs +++ /dev/null @@ -1,374 +0,0 @@ -//! WGSL shader generation for unary element-wise operations - -use super::common::{dtype_suffix, is_wgsl_float, wgsl_type}; -use crate::dtype::DType; -use crate::error::Result; - -/// Generate WGSL shader for unary element-wise operations -pub fn generate_unary_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - // Signed-only operations (F32, I32 - not U32) - let signed_ops = if dtype != DType::U32 { - format!( - r#" -@compute @workgroup_size(256) -fn neg_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - unary_out[idx] = -unary_a[idx]; - }} -}} -"#, - suffix = suffix - ) - } else { - // U32 doesn't support negation - String::new() - }; - - // Float-only operations - let float_ops = if is_wgsl_float(dtype) { - format!( - r#" -@compute @workgroup_size(256) -fn sqrt_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - unary_out[idx] = sqrt(unary_a[idx]); - }} -}} - -@compute @workgroup_size(256) -fn exp_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - unary_out[idx] = exp(unary_a[idx]); - }} -}} - -@compute @workgroup_size(256) -fn log_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - unary_out[idx] = log(unary_a[idx]); - }} -}} - -@compute @workgroup_size(256) -fn sin_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - unary_out[idx] = sin(unary_a[idx]); - }} -}} - -@compute @workgroup_size(256) -fn cos_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - unary_out[idx] = cos(unary_a[idx]); - }} -}} - -@compute @workgroup_size(256) -fn tan_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - unary_out[idx] = tan(unary_a[idx]); - }} -}} - -@compute @workgroup_size(256) -fn atan_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - unary_out[idx] = atan(unary_a[idx]); - }} -}} - -@compute @workgroup_size(256) -fn tanh_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - unary_out[idx] = tanh(unary_a[idx]); - }} -}} - -@compute @workgroup_size(256) -fn recip_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - unary_out[idx] = 1.0 / unary_a[idx]; - }} -}} - -@compute @workgroup_size(256) -fn floor_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - unary_out[idx] = floor(unary_a[idx]); - }} -}} - -@compute @workgroup_size(256) -fn ceil_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - unary_out[idx] = ceil(unary_a[idx]); - }} -}} - -@compute @workgroup_size(256) -fn round_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - // Match CPU/CUDA behavior: ties round away from zero. - let x = unary_a[idx]; - unary_out[idx] = select(ceil(x - 0.5), floor(x + 0.5), x >= 0.0); - }} -}} - -@compute @workgroup_size(256) -fn trunc_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - unary_out[idx] = trunc(unary_a[idx]); - }} -}} - -@compute @workgroup_size(256) -fn rsqrt_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - unary_out[idx] = inverseSqrt(unary_a[idx]); - }} -}} - -@compute @workgroup_size(256) -fn cbrt_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - let x = unary_a[idx]; - // cbrt(x) = sign(x) * pow(abs(x), 1/3) - unary_out[idx] = sign(x) * pow(abs(x), 1.0 / 3.0); - }} -}} - -@compute @workgroup_size(256) -fn exp2_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - unary_out[idx] = exp2(unary_a[idx]); - }} -}} - -@compute @workgroup_size(256) -fn expm1_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - unary_out[idx] = exp(unary_a[idx]) - 1.0; - }} -}} - -@compute @workgroup_size(256) -fn log2_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - unary_out[idx] = log2(unary_a[idx]); - }} -}} - -@compute @workgroup_size(256) -fn log10_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - // log10(x) = log(x) / log(10) = log(x) * 0.4342944819032518 - unary_out[idx] = log(unary_a[idx]) * 0.4342944819032518; - }} -}} - -@compute @workgroup_size(256) -fn log1p_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - unary_out[idx] = log(1.0 + unary_a[idx]); - }} -}} - -@compute @workgroup_size(256) -fn asin_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - let x = unary_a[idx]; - let y = sqrt(max(0.0, 1.0 - x * x)); - unary_out[idx] = atan2(x, y); - }} -}} - -@compute @workgroup_size(256) -fn acos_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - let x = unary_a[idx]; - let y = sqrt(max(0.0, 1.0 - x * x)); - unary_out[idx] = atan2(y, x); - }} -}} - -@compute @workgroup_size(256) -fn sinh_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - unary_out[idx] = sinh(unary_a[idx]); - }} -}} - -@compute @workgroup_size(256) -fn cosh_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - unary_out[idx] = cosh(unary_a[idx]); - }} -}} - -@compute @workgroup_size(256) -fn asinh_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - unary_out[idx] = asinh(unary_a[idx]); - }} -}} - -@compute @workgroup_size(256) -fn acosh_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - unary_out[idx] = acosh(unary_a[idx]); - }} -}} - -@compute @workgroup_size(256) -fn atanh_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - unary_out[idx] = atanh(unary_a[idx]); - }} -}} - -@compute @workgroup_size(256) -fn relu_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - unary_out[idx] = max(unary_a[idx], 0.0); - }} -}} - -@compute @workgroup_size(256) -fn sigmoid_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - unary_out[idx] = 1.0 / (1.0 + exp(-unary_a[idx])); - }} -}} - -@compute @workgroup_size(256) -fn silu_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - let x = unary_a[idx]; - unary_out[idx] = x / (1.0 + exp(-x)); - }} -}} - -@compute @workgroup_size(256) -fn gelu_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - let x = unary_a[idx]; - let c = 0.7978845608028654; // sqrt(2/pi) - unary_out[idx] = 0.5 * x * (1.0 + tanh(c * (x + 0.044715 * x * x * x))); - }} -}} - -@compute @workgroup_size(256) -fn isnan_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - let x = unary_a[idx]; - let bits = bitcast(f32(x)); - let exp = bits & 0x7f800000u; - let mant = bits & 0x007fffffu; - let is_nan = (exp == 0x7f800000u) && (mant != 0u); - unary_out[idx] = select(0.0, 1.0, is_nan); - }} -}} - -@compute @workgroup_size(256) -fn isinf_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - let x = unary_a[idx]; - let bits = bitcast(f32(x)); - let exp = bits & 0x7f800000u; - let mant = bits & 0x007fffffu; - let is_inf = (exp == 0x7f800000u) && (mant == 0u); - unary_out[idx] = select(0.0, 1.0, is_inf); - }} -}} -"#, - suffix = suffix - ) - } else { - // Integer types don't have these operations - String::new() - }; - - Ok(format!( - r#"// Auto-generated unary operations for {t} - -const WORKGROUP_SIZE: u32 = 256u; - -struct UnaryParams {{ - numel: u32, -}} - -@group(0) @binding(0) var unary_a: array<{t}>; -@group(0) @binding(1) var unary_out: array<{t}>; -@group(0) @binding(2) var unary_params: UnaryParams; - -{signed_ops} -@compute @workgroup_size(256) -fn abs_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - unary_out[idx] = abs(unary_a[idx]); - }} -}} - -@compute @workgroup_size(256) -fn square_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - let x = unary_a[idx]; - unary_out[idx] = x * x; - }} -}} - -@compute @workgroup_size(256) -fn sign_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < unary_params.numel) {{ - unary_out[idx] = sign(unary_a[idx]); - }} -}} - -{float_ops} -"#, - t = t, - suffix = suffix, - signed_ops = signed_ops, - float_ops = float_ops - )) -} diff --git a/src/runtime/wgpu/shaders/generator/utility.rs b/src/runtime/wgpu/shaders/generator/utility.rs deleted file mode 100644 index bf7f1bb8..00000000 --- a/src/runtime/wgpu/shaders/generator/utility.rs +++ /dev/null @@ -1,497 +0,0 @@ -//! WGSL shader generation for utility operations: arange, linspace, eye, rand, randn, randint - -use super::common::{dtype_suffix, is_wgsl_float, is_wgsl_int, wgsl_type}; -use crate::dtype::DType; -use crate::error::{Error, Result}; - -/// Generate WGSL shader for arange operation -pub fn generate_arange_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Auto-generated arange operation for {t} - -const WORKGROUP_SIZE: u32 = 256u; - -struct ArangeParams {{ - numel: u32, - start: f32, - step: f32, -}} - -@group(0) @binding(0) var arange_out: array<{t}>; -@group(0) @binding(1) var arange_params: ArangeParams; - -@compute @workgroup_size(256) -fn arange_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < arange_params.numel) {{ - let value = arange_params.start + arange_params.step * f32(idx); - arange_out[idx] = {t}(value); - }} -}} -"#, - t = t, - suffix = suffix - )) -} - -/// Generate WGSL shader for linspace operation -pub fn generate_linspace_shader(dtype: DType) -> Result { - // linspace only makes sense for float types - if !is_wgsl_float(dtype) { - return Err(Error::UnsupportedDType { - dtype, - op: "linspace", - }); - } - - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Auto-generated linspace operation for {t} - -const WORKGROUP_SIZE: u32 = 256u; - -struct LinspaceParams {{ - steps: u32, - start: f32, - stop: f32, -}} - -@group(0) @binding(0) var linspace_out: array<{t}>; -@group(0) @binding(1) var linspace_params: LinspaceParams; - -@compute @workgroup_size(256) -fn linspace_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < linspace_params.steps) {{ - let t_val = f32(idx) / f32(linspace_params.steps - 1u); - let value = linspace_params.start + (linspace_params.stop - linspace_params.start) * t_val; - linspace_out[idx] = {t}(value); - }} -}} -"#, - t = t, - suffix = suffix - )) -} - -/// Generate WGSL shader for eye operation (identity matrix) -pub fn generate_eye_shader(dtype: DType) -> Result { - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - // Determine the correct "one" and "zero" values based on type - let (one_val, zero_val) = if is_wgsl_float(dtype) { - ("1.0", "0.0") - } else { - ("1", "0") - }; - - Ok(format!( - r#"// Auto-generated eye (identity matrix) operation for {t} - -const WORKGROUP_SIZE: u32 = 256u; - -struct EyeParams {{ - n: u32, // rows - m: u32, // cols - numel: u32, // n * m -}} - -@group(0) @binding(0) var eye_out: array<{t}>; -@group(0) @binding(1) var eye_params: EyeParams; - -@compute @workgroup_size(256) -fn eye_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < eye_params.numel) {{ - let row = idx / eye_params.m; - let col = idx % eye_params.m; - if (row == col) {{ - eye_out[idx] = {t}({one_val}); - }} else {{ - eye_out[idx] = {t}({zero_val}); - }} - }} -}} -"#, - t = t, - suffix = suffix, - one_val = one_val, - zero_val = zero_val - )) -} - -// ============================================================================ -// Random Number Generation Shaders -// ============================================================================ - -/// WGSL implementation of PCG hash for random number generation -/// This produces high-quality random numbers suitable for most applications. -const PCG_HASH_WGSL: &str = r#" -// PCG hash function for random number generation -// Based on PCG Random Number Generation by Melissa O'Neill -fn pcg_hash(input: u32) -> u32 { - var state = input * 747796405u + 2891336453u; - var word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u; - return (word >> 22u) ^ word; -} - -// Initialize PCG state from seed and index -fn pcg_init(seed: u32, idx: u32) -> u32 { - return pcg_hash(seed ^ pcg_hash(idx)); -} - -// Generate uniform float in [0, 1) -fn pcg_uniform(state: ptr) -> f32 { - *state = pcg_hash(*state); - return f32(*state) / 4294967296.0; // Divide by 2^32 -} - -// Box-Muller transform for normal distribution -// Generates one normal value, requires two uniform values -fn box_muller(u1: f32, u2: f32) -> f32 { - let u1_safe = max(u1, 0.0000001); // Avoid log(0) - let r = sqrt(-2.0 * log(u1_safe)); - let theta = 6.28318530718 * u2; // 2 * PI - return r * cos(theta); -} -"#; - -/// Generate WGSL shader for rand operation (uniform [0, 1)) -pub fn generate_rand_shader(dtype: DType) -> Result { - // rand only makes sense for float types - if !is_wgsl_float(dtype) { - return Err(Error::UnsupportedDType { dtype, op: "rand" }); - } - - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Auto-generated rand operation for {t} -{pcg_hash} -const WORKGROUP_SIZE: u32 = 256u; - -struct RandParams {{ - numel: u32, - seed: u32, - _pad1: u32, - _pad2: u32, -}} - -@group(0) @binding(0) var rand_out: array<{t}>; -@group(0) @binding(1) var rand_params: RandParams; - -@compute @workgroup_size(256) -fn rand_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < rand_params.numel) {{ - var state = pcg_init(rand_params.seed, idx); - let value = pcg_uniform(&state); - rand_out[idx] = {t}(value); - }} -}} -"#, - t = t, - suffix = suffix, - pcg_hash = PCG_HASH_WGSL - )) -} - -/// Generate WGSL shader for randn operation (standard normal N(0, 1)) -pub fn generate_randn_shader(dtype: DType) -> Result { - // randn only makes sense for float types - if !is_wgsl_float(dtype) { - return Err(Error::UnsupportedDType { dtype, op: "randn" }); - } - - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Auto-generated randn operation for {t} -{pcg_hash} -const WORKGROUP_SIZE: u32 = 256u; - -struct RandnParams {{ - numel: u32, - seed: u32, - _pad1: u32, - _pad2: u32, -}} - -@group(0) @binding(0) var randn_out: array<{t}>; -@group(0) @binding(1) var randn_params: RandnParams; - -@compute @workgroup_size(256) -fn randn_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < randn_params.numel) {{ - // Use two uniform random values for Box-Muller - var state = pcg_init(randn_params.seed, idx); - let u1 = pcg_uniform(&state); - let u2 = pcg_uniform(&state); - let value = box_muller(u1, u2); - randn_out[idx] = {t}(value); - }} -}} -"#, - t = t, - suffix = suffix, - pcg_hash = PCG_HASH_WGSL - )) -} - -/// Generate WGSL shader for randint operation (uniform integers in [low, high)) -/// -/// For signed integers (I32): low is stored as i32, arithmetic done in i32 -/// For unsigned integers (U32): low is stored as u32, arithmetic done in u32 -/// -/// This ensures correct handling of negative bounds for signed types and -/// avoids overflow issues with large unsigned ranges. -pub fn generate_randint_shader(dtype: DType) -> Result { - // randint only makes sense for integer types - if !is_wgsl_int(dtype) { - return Err(Error::UnsupportedDType { - dtype, - op: "randint", - }); - } - - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - // Generate completely separate shaders for signed vs unsigned - // This avoids type casting issues and overflow problems - let is_signed = matches!(dtype, DType::I32); - - if is_signed { - // Signed integer version: low stored as i32, arithmetic in i32 - Ok(format!( - r#"// Auto-generated randint operation for {t} (signed) -{pcg_hash} -const WORKGROUP_SIZE: u32 = 256u; - -struct RandintParams {{ - numel: u32, - low: i32, // Low bound as signed integer - range: u32, // high - low (always positive, fits in u32) - seed: u32, -}} - -@group(0) @binding(0) var randint_out: array<{t}>; -@group(0) @binding(1) var randint_params: RandintParams; - -@compute @workgroup_size(256) -fn randint_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < randint_params.numel) {{ - var state = pcg_init(randint_params.seed, idx); - let r = pcg_hash(state); - // Compute offset in unsigned space, then add to signed low - let offset = r % randint_params.range; - // Safe: offset < range, so low + offset won't overflow if inputs are valid - randint_out[idx] = randint_params.low + i32(offset); - }} -}} -"#, - t = t, - suffix = suffix, - pcg_hash = PCG_HASH_WGSL - )) - } else { - // Unsigned integer version: all arithmetic in u32 - Ok(format!( - r#"// Auto-generated randint operation for {t} (unsigned) -{pcg_hash} -const WORKGROUP_SIZE: u32 = 256u; - -struct RandintParams {{ - numel: u32, - low: u32, // Low bound as unsigned integer - range: u32, // high - low - seed: u32, -}} - -@group(0) @binding(0) var randint_out: array<{t}>; -@group(0) @binding(1) var randint_params: RandintParams; - -@compute @workgroup_size(256) -fn randint_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < randint_params.numel) {{ - var state = pcg_init(randint_params.seed, idx); - let r = pcg_hash(state); - // Pure unsigned arithmetic - no overflow for valid inputs - let offset = r % randint_params.range; - randint_out[idx] = randint_params.low + offset; - }} -}} -"#, - t = t, - suffix = suffix, - pcg_hash = PCG_HASH_WGSL - )) - } -} - -/// Generate WGSL shader for multinomial sampling with replacement -/// -/// Uses inverse transform sampling (CDF method): -/// 1. Compute cumulative sum of normalized probabilities -/// 2. For each sample, draw uniform random u ∈ `[0, 1)` -/// 3. Find smallest index i where `CDF[i]` ≥ u (linear search) -pub fn generate_multinomial_with_replacement_shader() -> Result { - Ok(format!( - r#"// Auto-generated multinomial_with_replacement operation for f32 -{pcg_hash} -const WORKGROUP_SIZE: u32 = 256u; - -struct MultinomialParams {{ - num_distributions: u32, - num_categories: u32, - num_samples: u32, - seed: u32, -}} - -@group(0) @binding(0) var probs: array; -@group(0) @binding(1) var multinomial_out: array; -@group(0) @binding(2) var multinomial_params: MultinomialParams; - -@compute @workgroup_size(256) -fn multinomial_with_replacement_f32(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - let total = multinomial_params.num_distributions * multinomial_params.num_samples; - if (idx >= total) {{ - return; - }} - - let dist = idx / multinomial_params.num_samples; - let sample = idx % multinomial_params.num_samples; - - // Initialize RNG for this thread - var state = pcg_init(multinomial_params.seed, idx); - - // Get pointer to this distribution's probabilities - let prob_offset = dist * multinomial_params.num_categories; - - // Compute sum of probabilities for normalization - var sum: f32 = 0.0; - for (var i: u32 = 0u; i < multinomial_params.num_categories; i = i + 1u) {{ - sum = sum + probs[prob_offset + i]; - }} - - // Generate uniform random value - let u = pcg_uniform(&state); - - // Linear search using CDF (on-the-fly computation) - // Find smallest index where cumsum/sum >= u - var cumsum: f32 = 0.0; - var result: u32 = multinomial_params.num_categories - 1u; // Default to last category - for (var i: u32 = 0u; i < multinomial_params.num_categories; i = i + 1u) {{ - cumsum = cumsum + probs[prob_offset + i]; - if (cumsum / sum >= u) {{ - result = i; - break; - }} - }} - - multinomial_out[dist * multinomial_params.num_samples + sample] = i32(result); -}} -"#, - pcg_hash = PCG_HASH_WGSL - )) -} - -/// Generate WGSL shader for multinomial sampling without replacement -/// -/// Uses sequential sampling within each distribution. Each workgroup handles -/// one distribution. Selected categories are zeroed out in shared memory to -/// prevent resampling. -/// -/// Note: This kernel is less parallelizable than with-replacement because -/// samples within a distribution must be sequential to ensure uniqueness. -pub fn generate_multinomial_without_replacement_shader() -> Result { - Ok(format!( - r#"// Auto-generated multinomial_without_replacement operation for f32 -{pcg_hash} -const WORKGROUP_SIZE: u32 = 256u; -const MAX_CATEGORIES: u32 = 1024u; // Maximum supported categories - -struct MultinomialParams {{ - num_distributions: u32, - num_categories: u32, - num_samples: u32, - seed: u32, -}} - -@group(0) @binding(0) var probs: array; -@group(0) @binding(1) var multinomial_out: array; -@group(0) @binding(2) var multinomial_params: MultinomialParams; - -var shared_probs: array; - -@compute @workgroup_size(256) -fn multinomial_without_replacement_f32(@builtin(global_invocation_id) gid: vec3, @builtin(local_invocation_id) lid: vec3) {{ - let dist = gid.x / WORKGROUP_SIZE; - if (dist >= multinomial_params.num_distributions) {{ - return; - }} - - // Copy probabilities to shared memory (each thread copies some elements) - let prob_offset = dist * multinomial_params.num_categories; - let elements_per_thread = (multinomial_params.num_categories + WORKGROUP_SIZE - 1u) / WORKGROUP_SIZE; - for (var i: u32 = 0u; i < elements_per_thread; i = i + 1u) {{ - let idx = lid.x * elements_per_thread + i; - if (idx < multinomial_params.num_categories) {{ - shared_probs[idx] = probs[prob_offset + idx]; - }} - }} - - workgroupBarrier(); - - // Only thread 0 does the sequential sampling - if (lid.x != 0u) {{ - return; - }} - - // Initialize RNG - var state = pcg_init(multinomial_params.seed, dist); - - // Sample without replacement - for (var s: u32 = 0u; s < multinomial_params.num_samples; s = s + 1u) {{ - // Compute sum of remaining probabilities - var sum: f32 = 0.0; - for (var i: u32 = 0u; i < multinomial_params.num_categories; i = i + 1u) {{ - sum = sum + shared_probs[i]; - }} - - // Generate uniform random value - let u = pcg_uniform(&state); - - // Linear search using CDF - var cumsum: f32 = 0.0; - var result: u32 = multinomial_params.num_categories - 1u; - for (var i: u32 = 0u; i < multinomial_params.num_categories; i = i + 1u) {{ - cumsum = cumsum + shared_probs[i]; - if (cumsum / sum >= u) {{ - result = i; - break; - }} - }} - - multinomial_out[dist * multinomial_params.num_samples + s] = i32(result); - - // Zero out selected category - shared_probs[result] = 0.0; - }} -}} -"#, - pcg_hash = PCG_HASH_WGSL - )) -} diff --git a/src/runtime/wgpu/shaders/generator/where_cond.rs b/src/runtime/wgpu/shaders/generator/where_cond.rs deleted file mode 100644 index b4235a9f..00000000 --- a/src/runtime/wgpu/shaders/generator/where_cond.rs +++ /dev/null @@ -1,206 +0,0 @@ -//! WGSL shader generation for where_cond (ternary conditional select) -//! -//! Generates shaders for: where_cond(condition, x, y) → output -//! where `output[i] = condition[i] != 0 ? x[i] : y[i]` -//! -//! Supports multiple condition dtypes (F32, I32, U32) and multiple output dtypes. - -use super::common::{dtype_suffix, wgsl_type}; -use crate::dtype::DType; -use crate::error::{Error, Result}; - -/// Generate WGSL shader for where_cond operation. -/// -/// Creates kernels for both element-wise and broadcast where operations. -/// The condition is tested for non-zero: any non-zero value is treated as true. -/// -/// # Arguments -/// -/// * `cond_dtype` - Data type of condition tensor (F32, I32, U32) -/// * `out_dtype` - Data type of x, y, and output tensors -/// -/// # Entry Points -/// -/// * `where_cond_{cond_suffix}_{out_suffix}` - Element-wise where -/// * `where_broadcast_cond_{cond_suffix}_{out_suffix}` - Broadcast where -pub fn generate_where_cond_shader(cond_dtype: DType, out_dtype: DType) -> Result { - let cond_t = wgsl_type(cond_dtype)?; - let out_t = wgsl_type(out_dtype)?; - let cond_suffix = dtype_suffix(cond_dtype)?; - let out_suffix = dtype_suffix(out_dtype)?; - - // Generate zero literal for comparison - let zero_cmp = match cond_dtype { - DType::F32 | DType::F16 => "0.0", - DType::I32 | DType::U32 => "0", - _ => { - return Err(Error::UnsupportedDType { - dtype: cond_dtype, - op: "where_cond (condition dtype)", - }); - } - }; - - Ok(format!( - r#"// Auto-generated where_cond shader for condition={cond_t}, output={out_t} - -const WORKGROUP_SIZE: u32 = 256u; -const MAX_DIMS: u32 = 8u; - -// ============================================================================ -// Element-wise where_cond -// ============================================================================ - -struct WhereParams {{ - numel: u32, -}} - -@group(0) @binding(0) var where_cond_arr: array<{cond_t}>; -@group(0) @binding(1) var where_x: array<{out_t}>; -@group(0) @binding(2) var where_y: array<{out_t}>; -@group(0) @binding(3) var where_out: array<{out_t}>; -@group(0) @binding(4) var where_params: WhereParams; - -@compute @workgroup_size(256) -fn where_cond_{cond_suffix}_{out_suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx < where_params.numel) {{ - // Condition is true if non-zero - let cond_val = where_cond_arr[idx] != {zero_cmp}; - where_out[idx] = select(where_y[idx], where_x[idx], cond_val); - }} -}} - -// ============================================================================ -// Broadcast where_cond -// ============================================================================ - -struct WhereBroadcastParams {{ - numel: u32, - ndim: u32, - _pad0: u32, - _pad1: u32, -}} - -@group(0) @binding(0) var bc_cond: array<{cond_t}>; -@group(0) @binding(1) var bc_x: array<{out_t}>; -@group(0) @binding(2) var bc_y: array<{out_t}>; -@group(0) @binding(3) var bc_out: array<{out_t}>; -@group(0) @binding(4) var cond_strides: array; -@group(0) @binding(5) var x_strides: array; -@group(0) @binding(6) var y_strides: array; -@group(0) @binding(7) var out_shape: array; -@group(0) @binding(8) var bc_params: WhereBroadcastParams; - -@compute @workgroup_size(256) -fn where_broadcast_cond_{cond_suffix}_{out_suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let idx = gid.x; - if (idx >= bc_params.numel) {{ - return; - }} - - // Convert linear index to multi-dimensional coords and compute offsets - var remaining = idx; - var cond_offset: u32 = 0u; - var x_offset: u32 = 0u; - var y_offset: u32 = 0u; - - for (var d: u32 = 0u; d < bc_params.ndim; d = d + 1u) {{ - let dim_size = out_shape[d]; - let coord = remaining / compute_out_stride(d, bc_params.ndim); - remaining = remaining % compute_out_stride(d, bc_params.ndim); - - cond_offset = cond_offset + coord * cond_strides[d]; - x_offset = x_offset + coord * x_strides[d]; - y_offset = y_offset + coord * y_strides[d]; - }} - - // Apply condition - let cond_val = bc_cond[cond_offset] != {zero_cmp}; - bc_out[idx] = select(bc_y[y_offset], bc_x[x_offset], cond_val); -}} - -// Helper function to compute output stride at dimension d -fn compute_out_stride(d: u32, ndim: u32) -> u32 {{ - var stride: u32 = 1u; - for (var i: u32 = d + 1u; i < ndim; i = i + 1u) {{ - stride = stride * out_shape[i]; - }} - return stride; -}} -"#, - cond_t = cond_t, - out_t = out_t, - cond_suffix = cond_suffix, - out_suffix = out_suffix, - zero_cmp = zero_cmp, - )) -} - -#[cfg(test)] -mod tests { - use super::*; - - /// Helper to validate WGSL shader syntax using naga parser - fn validate_wgsl_syntax(source: &str) -> std::result::Result<(), String> { - use wgpu::naga::front::wgsl; - let mut frontend = wgsl::Frontend::new(); - frontend - .parse(source) - .map(|_| ()) - .map_err(|e| format!("WGSL parse error: {e}")) - } - - #[test] - fn test_where_cond_shader_f32_f32() { - let shader = generate_where_cond_shader(DType::F32, DType::F32).unwrap(); - assert!(shader.contains("fn where_cond_f32_f32")); - assert!(shader.contains("fn where_broadcast_cond_f32_f32")); - assert!(shader.contains("array")); - validate_wgsl_syntax(&shader).unwrap(); - } - - #[test] - fn test_where_cond_shader_i32_f32() { - let shader = generate_where_cond_shader(DType::I32, DType::F32).unwrap(); - assert!(shader.contains("fn where_cond_i32_f32")); - assert!(shader.contains("fn where_broadcast_cond_i32_f32")); - validate_wgsl_syntax(&shader).unwrap(); - } - - #[test] - fn test_where_cond_shader_u32_f32() { - let shader = generate_where_cond_shader(DType::U32, DType::F32).unwrap(); - assert!(shader.contains("fn where_cond_u32_f32")); - validate_wgsl_syntax(&shader).unwrap(); - } - - #[test] - fn test_where_cond_shader_f32_i32() { - let shader = generate_where_cond_shader(DType::F32, DType::I32).unwrap(); - assert!(shader.contains("fn where_cond_f32_i32")); - validate_wgsl_syntax(&shader).unwrap(); - } - - #[test] - fn test_where_cond_shader_all_combinations() { - let dtypes = [DType::F32, DType::I32, DType::U32]; - for cond_dtype in &dtypes { - for out_dtype in &dtypes { - let shader = - generate_where_cond_shader(*cond_dtype, *out_dtype).unwrap_or_else(|e| { - panic!( - "Failed to generate where_cond shader for {:?}/{:?}: {}", - cond_dtype, out_dtype, e - ) - }); - validate_wgsl_syntax(&shader).unwrap_or_else(|e| { - panic!( - "Invalid WGSL for where_cond {:?}/{:?}:\n{}\n\nShader:\n{}", - cond_dtype, out_dtype, e, shader - ) - }); - } - } - } -} diff --git a/src/runtime/wgpu/shaders/hermitian_extend.wgsl b/src/runtime/wgpu/shaders/hermitian_extend.wgsl new file mode 100644 index 00000000..99827f82 --- /dev/null +++ b/src/runtime/wgpu/shaders/hermitian_extend.wgsl @@ -0,0 +1,41 @@ +// Hermitian extend shader - extends N/2+1 complex to N complex using symmetry + +const WORKGROUP_SIZE: u32 = 256u; + +struct ExtendParams { + n: u32, // Full FFT size + half_n: u32, // N/2 + 1 (input size) + batch_size: u32, + _pad: u32, +} + +@group(0) @binding(0) var extend_input: array>; +@group(0) @binding(1) var extend_output: array>; +@group(0) @binding(2) var extend_params: ExtendParams; + +@compute @workgroup_size(WORKGROUP_SIZE) +fn hermitian_extend( + @builtin(global_invocation_id) gid: vec3 +) { + let idx = gid.x; + let batch_idx = gid.y; + let n = extend_params.n; + let half_n = extend_params.half_n; + + if (idx >= n) { + return; + } + + let in_offset = batch_idx * half_n; + let out_offset = batch_idx * n; + + if (idx < half_n) { + // Direct copy for first half + extend_output[out_offset + idx] = extend_input[in_offset + idx]; + } else { + // Conjugate symmetry for second half: X[N-k] = conj(X[k]) + let k = n - idx; + let val = extend_input[in_offset + k]; + extend_output[out_offset + idx] = vec2(val.x, -val.y); + } +} diff --git a/src/runtime/wgpu/shaders/imag_complex64.wgsl b/src/runtime/wgpu/shaders/imag_complex64.wgsl new file mode 100644 index 00000000..a045af16 --- /dev/null +++ b/src/runtime/wgpu/shaders/imag_complex64.wgsl @@ -0,0 +1,18 @@ +// Complex imaginary-part extraction shader +// entry point: imag_complex64 + +struct Params { + numel: u32, +} + +@group(0) @binding(0) var input: array>; +@group(0) @binding(1) var output: array; +@group(0) @binding(2) var params: Params; + +@compute @workgroup_size(256) +fn imag_complex64(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < params.numel) { + output[idx] = input[idx].y; // Extract imaginary component + } +} diff --git a/src/runtime/wgpu/shaders/index.rs b/src/runtime/wgpu/shaders/index.rs index 66c9f9b5..bbb5bce7 100644 --- a/src/runtime/wgpu/shaders/index.rs +++ b/src/runtime/wgpu/shaders/index.rs @@ -11,83 +11,287 @@ use wgpu::{Buffer, Queue}; -use super::generator::{ - generate_embedding_lookup_shader, generate_gather_shader, generate_index_put_shader, - generate_index_select_shader, generate_masked_fill_shader, generate_masked_select_shader, - generate_scatter_shader, generate_validate_indices_shader, -}; use super::pipeline::{LayoutKey, PipelineCache, workgroup_count}; use crate::dtype::DType; use crate::error::{Error, Result}; // ============================================================================ -// Helper Functions +// Static shaders — data-movement ops (F32 / I32 / U32) // ============================================================================ -/// Check if dtype is supported for index operations on WebGPU. -fn check_dtype_supported(dtype: DType, op: &'static str) -> Result<()> { - match dtype { - DType::F32 | DType::I32 | DType::U32 => Ok(()), - _ => Err(Error::UnsupportedDType { dtype, op }), - } -} +const INDEX_SELECT_SHADER_F32: &str = include_str!("index_select_f32.wgsl"); +const INDEX_SELECT_SHADER_I32: &str = include_str!("index_select_i32.wgsl"); +const INDEX_SELECT_SHADER_U32: &str = include_str!("index_select_u32.wgsl"); -/// Get the static module/entry point name for an index operation. -/// -/// Returns the kernel name in format `{op}_{dtype_suffix}`. -/// For WebGPU index operations, module name and entry point are identical. -fn kernel_name(op: &'static str, dtype: DType) -> Result<&'static str> { - match (op, dtype) { - ("index_select", DType::F32) => Ok("index_select_f32"), - ("index_select", DType::I32) => Ok("index_select_i32"), - ("index_select", DType::U32) => Ok("index_select_u32"), - ("index_put", DType::F32) => Ok("index_put_f32"), - ("index_put", DType::I32) => Ok("index_put_i32"), - ("index_put", DType::U32) => Ok("index_put_u32"), - ("gather", DType::F32) => Ok("gather_f32"), - ("gather", DType::I32) => Ok("gather_i32"), - ("gather", DType::U32) => Ok("gather_u32"), - ("scatter", DType::F32) => Ok("scatter_f32"), - ("scatter", DType::I32) => Ok("scatter_i32"), - ("scatter", DType::U32) => Ok("scatter_u32"), - ("copy", DType::F32) => Ok("copy_f32"), - ("copy", DType::I32) => Ok("copy_i32"), - ("copy", DType::U32) => Ok("copy_u32"), - ("masked_fill", DType::F32) => Ok("masked_fill_f32"), - ("masked_fill", DType::I32) => Ok("masked_fill_i32"), - ("masked_fill", DType::U32) => Ok("masked_fill_u32"), - ("masked_select", DType::F32) => Ok("masked_select_f32"), - ("masked_select", DType::I32) => Ok("masked_select_i32"), - ("masked_select", DType::U32) => Ok("masked_select_u32"), - ("embedding_lookup", DType::F32) => Ok("embedding_lookup_f32"), - ("embedding_lookup", DType::I32) => Ok("embedding_lookup_i32"), - ("embedding_lookup", DType::U32) => Ok("embedding_lookup_u32"), - ("gather_nd", DType::F32) => Ok("gather_nd_f32"), - ("gather_nd", DType::I32) => Ok("gather_nd_i32"), - ("gather_nd", DType::U32) => Ok("gather_nd_u32"), - ("bincount", DType::F32) => Ok("bincount_weighted_f32"), - ("bincount", DType::I32) => Ok("bincount_weighted_i32"), - ("bincount", DType::U32) => Ok("bincount_weighted_u32"), - ("bincount_unweighted", _) => Ok("bincount_i32"), - ("scatter_reduce_sum", DType::F32) => Ok("scatter_reduce_sum_f32"), - ("scatter_reduce_sum", DType::I32) => Ok("scatter_reduce_sum_i32"), - ("scatter_reduce_sum", DType::U32) => Ok("scatter_reduce_sum_u32"), - ("scatter_reduce_max", DType::F32) => Ok("scatter_reduce_max_f32"), - ("scatter_reduce_max", DType::I32) => Ok("scatter_reduce_max_i32"), - ("scatter_reduce_max", DType::U32) => Ok("scatter_reduce_max_u32"), - ("scatter_reduce_min", DType::F32) => Ok("scatter_reduce_min_f32"), - ("scatter_reduce_min", DType::I32) => Ok("scatter_reduce_min_i32"), - ("scatter_reduce_min", DType::U32) => Ok("scatter_reduce_min_u32"), - ("scatter_reduce_prod", DType::F32) => Ok("scatter_reduce_prod_f32"), - ("scatter_reduce_prod", DType::I32) => Ok("scatter_reduce_prod_i32"), - ("scatter_reduce_prod", DType::U32) => Ok("scatter_reduce_prod_u32"), - ("scatter_reduce_count", DType::F32) => Ok("scatter_reduce_count_f32"), - ("scatter_reduce_mean_div", DType::F32) => Ok("scatter_reduce_mean_div_f32"), - ("gather_2d", DType::F32) => Ok("gather_2d_f32"), - ("gather_2d", DType::I32) => Ok("gather_2d_i32"), - ("gather_2d", DType::U32) => Ok("gather_2d_u32"), - _ => Err(Error::UnsupportedDType { dtype, op }), - } +const INDEX_PUT_SHADER_F32: &str = include_str!("index_put_f32.wgsl"); +const INDEX_PUT_SHADER_I32: &str = include_str!("index_put_i32.wgsl"); +const INDEX_PUT_SHADER_U32: &str = include_str!("index_put_u32.wgsl"); + +const GATHER_SHADER_F32: &str = include_str!("gather_f32.wgsl"); +const GATHER_SHADER_I32: &str = include_str!("gather_i32.wgsl"); +const GATHER_SHADER_U32: &str = include_str!("gather_u32.wgsl"); + +const SCATTER_SHADER_F32: &str = include_str!("scatter_f32.wgsl"); +const SCATTER_SHADER_I32: &str = include_str!("scatter_i32.wgsl"); +const SCATTER_SHADER_U32: &str = include_str!("scatter_u32.wgsl"); + +const MASKED_FILL_SHADER_F32: &str = include_str!("masked_fill_f32.wgsl"); +const MASKED_FILL_SHADER_I32: &str = include_str!("masked_fill_i32.wgsl"); +const MASKED_FILL_SHADER_U32: &str = include_str!("masked_fill_u32.wgsl"); + +const MASKED_SELECT_SHADER_F32: &str = include_str!("masked_select_f32.wgsl"); +const MASKED_SELECT_SHADER_I32: &str = include_str!("masked_select_i32.wgsl"); +const MASKED_SELECT_SHADER_U32: &str = include_str!("masked_select_u32.wgsl"); + +const EMBEDDING_LOOKUP_SHADER_F32: &str = include_str!("embedding_lookup_f32.wgsl"); +const EMBEDDING_LOOKUP_SHADER_I32: &str = include_str!("embedding_lookup_i32.wgsl"); +const EMBEDDING_LOOKUP_SHADER_U32: &str = include_str!("embedding_lookup_u32.wgsl"); + +const GATHER_ND_SHADER_F32: &str = include_str!("gather_nd_f32.wgsl"); +const GATHER_ND_SHADER_I32: &str = include_str!("gather_nd_i32.wgsl"); +const GATHER_ND_SHADER_U32: &str = include_str!("gather_nd_u32.wgsl"); + +const SCATTER_REDUCE_SUM_SHADER_F32: &str = include_str!("scatter_reduce_sum_f32.wgsl"); +const SCATTER_REDUCE_SUM_SHADER_I32: &str = include_str!("scatter_reduce_sum_i32.wgsl"); +const SCATTER_REDUCE_SUM_SHADER_U32: &str = include_str!("scatter_reduce_sum_u32.wgsl"); + +const SCATTER_REDUCE_MAX_SHADER_F32: &str = include_str!("scatter_reduce_max_f32.wgsl"); +const SCATTER_REDUCE_MAX_SHADER_I32: &str = include_str!("scatter_reduce_max_i32.wgsl"); +const SCATTER_REDUCE_MAX_SHADER_U32: &str = include_str!("scatter_reduce_max_u32.wgsl"); + +const SCATTER_REDUCE_MIN_SHADER_F32: &str = include_str!("scatter_reduce_min_f32.wgsl"); +const SCATTER_REDUCE_MIN_SHADER_I32: &str = include_str!("scatter_reduce_min_i32.wgsl"); +const SCATTER_REDUCE_MIN_SHADER_U32: &str = include_str!("scatter_reduce_min_u32.wgsl"); + +const SCATTER_REDUCE_PROD_SHADER_F32: &str = include_str!("scatter_reduce_prod_f32.wgsl"); +const SCATTER_REDUCE_PROD_SHADER_I32: &str = include_str!("scatter_reduce_prod_i32.wgsl"); +const SCATTER_REDUCE_PROD_SHADER_U32: &str = include_str!("scatter_reduce_prod_u32.wgsl"); + +const SCATTER_REDUCE_COUNT_SHADER_F32: &str = include_str!("scatter_reduce_count_f32.wgsl"); +const SCATTER_REDUCE_MEAN_DIV_SHADER_F32: &str = include_str!("scatter_reduce_mean_div_f32.wgsl"); + +const SLICE_ASSIGN_SHADER_F32: &str = include_str!("slice_assign_f32.wgsl"); +const SLICE_ASSIGN_SHADER_I32: &str = include_str!("slice_assign_i32.wgsl"); +const SLICE_ASSIGN_SHADER_U32: &str = include_str!("slice_assign_u32.wgsl"); + +const GATHER_2D_SHADER_F32: &str = include_str!("gather_2d_f32.wgsl"); +const GATHER_2D_SHADER_I32: &str = include_str!("gather_2d_i32.wgsl"); +const GATHER_2D_SHADER_U32: &str = include_str!("gather_2d_u32.wgsl"); + +// ============================================================================ +// Static shaders — dtype-agnostic ops +// ============================================================================ + +const VALIDATE_INDICES_SHADER: &str = include_str!("validate_indices.wgsl"); +const BINCOUNT_UNWEIGHTED_SHADER: &str = include_str!("bincount_i32.wgsl"); + +// ============================================================================ +// Static shaders — F32-only ops +// ============================================================================ + +const BINCOUNT_WEIGHTED_SHADER_F32: &str = include_str!("bincount_weighted_f32.wgsl"); + +// ============================================================================ +// Helpers +// ============================================================================ + +/// Returns (shader, module_key, entry_point) for standard index/scatter/gather ops. +fn shader_info( + op: &'static str, + dtype: DType, +) -> Result<(&'static str, &'static str, &'static str)> { + Ok(match (op, dtype) { + ("index_select", DType::F32) => ( + INDEX_SELECT_SHADER_F32, + "index_select_f32", + "index_select_f32", + ), + ("index_select", DType::I32) => ( + INDEX_SELECT_SHADER_I32, + "index_select_i32", + "index_select_i32", + ), + ("index_select", DType::U32) => ( + INDEX_SELECT_SHADER_U32, + "index_select_u32", + "index_select_u32", + ), + ("index_put", DType::F32) => (INDEX_PUT_SHADER_F32, "index_put_f32", "index_put_f32"), + ("index_put", DType::I32) => (INDEX_PUT_SHADER_I32, "index_put_i32", "index_put_i32"), + ("index_put", DType::U32) => (INDEX_PUT_SHADER_U32, "index_put_u32", "index_put_u32"), + ("gather", DType::F32) => (GATHER_SHADER_F32, "gather_f32", "gather_f32"), + ("gather", DType::I32) => (GATHER_SHADER_I32, "gather_i32", "gather_i32"), + ("gather", DType::U32) => (GATHER_SHADER_U32, "gather_u32", "gather_u32"), + ("scatter", DType::F32) => (SCATTER_SHADER_F32, "scatter_f32", "scatter_f32"), + ("scatter", DType::I32) => (SCATTER_SHADER_I32, "scatter_i32", "scatter_i32"), + ("scatter", DType::U32) => (SCATTER_SHADER_U32, "scatter_u32", "scatter_u32"), + // copy shares the scatter shader module but uses a different entry point + ("copy", DType::F32) => (SCATTER_SHADER_F32, "scatter_f32", "copy_f32"), + ("copy", DType::I32) => (SCATTER_SHADER_I32, "scatter_i32", "copy_i32"), + ("copy", DType::U32) => (SCATTER_SHADER_U32, "scatter_u32", "copy_u32"), + ("masked_fill", DType::F32) => { + (MASKED_FILL_SHADER_F32, "masked_fill_f32", "masked_fill_f32") + } + ("masked_fill", DType::I32) => { + (MASKED_FILL_SHADER_I32, "masked_fill_i32", "masked_fill_i32") + } + ("masked_fill", DType::U32) => { + (MASKED_FILL_SHADER_U32, "masked_fill_u32", "masked_fill_u32") + } + ("masked_select", DType::F32) => ( + MASKED_SELECT_SHADER_F32, + "masked_select_f32", + "masked_select_f32", + ), + ("masked_select", DType::I32) => ( + MASKED_SELECT_SHADER_I32, + "masked_select_i32", + "masked_select_i32", + ), + ("masked_select", DType::U32) => ( + MASKED_SELECT_SHADER_U32, + "masked_select_u32", + "masked_select_u32", + ), + // masked_count and masked_prefix_sum share the masked_select shader module + ("masked_count", DType::F32) => ( + MASKED_SELECT_SHADER_F32, + "masked_select_f32", + "masked_count", + ), + ("masked_count", DType::I32) => ( + MASKED_SELECT_SHADER_I32, + "masked_select_i32", + "masked_count", + ), + ("masked_count", DType::U32) => ( + MASKED_SELECT_SHADER_U32, + "masked_select_u32", + "masked_count", + ), + ("masked_prefix_sum", DType::F32) => ( + MASKED_SELECT_SHADER_F32, + "masked_select_f32", + "masked_prefix_sum", + ), + ("masked_prefix_sum", DType::I32) => ( + MASKED_SELECT_SHADER_I32, + "masked_select_i32", + "masked_prefix_sum", + ), + ("masked_prefix_sum", DType::U32) => ( + MASKED_SELECT_SHADER_U32, + "masked_select_u32", + "masked_prefix_sum", + ), + ("embedding_lookup", DType::F32) => ( + EMBEDDING_LOOKUP_SHADER_F32, + "embedding_lookup_f32", + "embedding_lookup_f32", + ), + ("embedding_lookup", DType::I32) => ( + EMBEDDING_LOOKUP_SHADER_I32, + "embedding_lookup_i32", + "embedding_lookup_i32", + ), + ("embedding_lookup", DType::U32) => ( + EMBEDDING_LOOKUP_SHADER_U32, + "embedding_lookup_u32", + "embedding_lookup_u32", + ), + ("gather_nd", DType::F32) => (GATHER_ND_SHADER_F32, "gather_nd_f32", "gather_nd_f32"), + ("gather_nd", DType::I32) => (GATHER_ND_SHADER_I32, "gather_nd_i32", "gather_nd_i32"), + ("gather_nd", DType::U32) => (GATHER_ND_SHADER_U32, "gather_nd_u32", "gather_nd_u32"), + ("scatter_reduce_sum", DType::F32) => ( + SCATTER_REDUCE_SUM_SHADER_F32, + "scatter_reduce_sum_f32", + "scatter_reduce_sum_f32", + ), + ("scatter_reduce_sum", DType::I32) => ( + SCATTER_REDUCE_SUM_SHADER_I32, + "scatter_reduce_sum_i32", + "scatter_reduce_sum_i32", + ), + ("scatter_reduce_sum", DType::U32) => ( + SCATTER_REDUCE_SUM_SHADER_U32, + "scatter_reduce_sum_u32", + "scatter_reduce_sum_u32", + ), + ("scatter_reduce_max", DType::F32) => ( + SCATTER_REDUCE_MAX_SHADER_F32, + "scatter_reduce_max_f32", + "scatter_reduce_max_f32", + ), + ("scatter_reduce_max", DType::I32) => ( + SCATTER_REDUCE_MAX_SHADER_I32, + "scatter_reduce_max_i32", + "scatter_reduce_max_i32", + ), + ("scatter_reduce_max", DType::U32) => ( + SCATTER_REDUCE_MAX_SHADER_U32, + "scatter_reduce_max_u32", + "scatter_reduce_max_u32", + ), + ("scatter_reduce_min", DType::F32) => ( + SCATTER_REDUCE_MIN_SHADER_F32, + "scatter_reduce_min_f32", + "scatter_reduce_min_f32", + ), + ("scatter_reduce_min", DType::I32) => ( + SCATTER_REDUCE_MIN_SHADER_I32, + "scatter_reduce_min_i32", + "scatter_reduce_min_i32", + ), + ("scatter_reduce_min", DType::U32) => ( + SCATTER_REDUCE_MIN_SHADER_U32, + "scatter_reduce_min_u32", + "scatter_reduce_min_u32", + ), + ("scatter_reduce_prod", DType::F32) => ( + SCATTER_REDUCE_PROD_SHADER_F32, + "scatter_reduce_prod_f32", + "scatter_reduce_prod_f32", + ), + ("scatter_reduce_prod", DType::I32) => ( + SCATTER_REDUCE_PROD_SHADER_I32, + "scatter_reduce_prod_i32", + "scatter_reduce_prod_i32", + ), + ("scatter_reduce_prod", DType::U32) => ( + SCATTER_REDUCE_PROD_SHADER_U32, + "scatter_reduce_prod_u32", + "scatter_reduce_prod_u32", + ), + ("scatter_reduce_count", DType::F32) => ( + SCATTER_REDUCE_COUNT_SHADER_F32, + "scatter_reduce_count_f32", + "scatter_reduce_count_f32", + ), + ("scatter_reduce_mean_div", DType::F32) => ( + SCATTER_REDUCE_MEAN_DIV_SHADER_F32, + "scatter_reduce_mean_div_f32", + "scatter_reduce_mean_div_f32", + ), + ("slice_assign", DType::F32) => ( + SLICE_ASSIGN_SHADER_F32, + "slice_assign_f32", + "slice_assign_f32", + ), + ("slice_assign", DType::I32) => ( + SLICE_ASSIGN_SHADER_I32, + "slice_assign_i32", + "slice_assign_i32", + ), + ("slice_assign", DType::U32) => ( + SLICE_ASSIGN_SHADER_U32, + "slice_assign_u32", + "slice_assign_u32", + ), + ("gather_2d", DType::F32) => (GATHER_2D_SHADER_F32, "gather_2d_f32", "gather_2d_f32"), + ("gather_2d", DType::I32) => (GATHER_2D_SHADER_I32, "gather_2d_i32", "gather_2d_i32"), + ("gather_2d", DType::U32) => (GATHER_2D_SHADER_U32, "gather_2d_u32", "gather_2d_u32"), + _ => return Err(Error::UnsupportedDType { dtype, op }), + }) } // ============================================================================ @@ -108,17 +312,15 @@ pub fn launch_index_select( total_output: usize, dtype: DType, ) -> Result<()> { - check_dtype_supported(dtype, "index_select")?; + let (shader, module_key, entry_point) = shader_info("index_select", dtype)?; - let name = kernel_name("index_select", dtype)?; - let shader_source = generate_index_select_shader(dtype)?; - let module = cache.get_or_create_module(name, &shader_source); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 3, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[input, indices, output, params_buffer]); @@ -160,17 +362,15 @@ pub fn launch_index_put( total_src: usize, dtype: DType, ) -> Result<()> { - check_dtype_supported(dtype, "index_put")?; + let (shader, module_key, entry_point) = shader_info("index_put", dtype)?; - let name = kernel_name("index_put", dtype)?; - let shader_source = generate_index_put_shader(dtype)?; - let module = cache.get_or_create_module(name, &shader_source); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 3, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[indices, src, output, params_buffer]); @@ -215,15 +415,14 @@ pub fn launch_validate_indices( return Ok(()); } - let name = "validate_indices"; - let shader_source = generate_validate_indices_shader(); - let module = cache.get_or_create_module(name, &shader_source); + let module = cache.get_or_create_module("validate_indices", VALIDATE_INDICES_SHADER); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = + cache.get_or_create_pipeline("validate_indices", "validate_indices", &module, &layout); let bind_group = cache.create_bind_group(&layout, &[indices, error_count, params_buffer]); @@ -264,17 +463,15 @@ pub fn launch_gather( total_elements: usize, dtype: DType, ) -> Result<()> { - check_dtype_supported(dtype, "gather")?; + let (shader, module_key, entry_point) = shader_info("gather", dtype)?; - let name = kernel_name("gather", dtype)?; - let shader_source = generate_gather_shader(dtype)?; - let module = cache.get_or_create_module(name, &shader_source); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 3, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[input, indices, output, params_buffer]); @@ -312,20 +509,15 @@ pub fn launch_copy( numel: usize, dtype: DType, ) -> Result<()> { - check_dtype_supported(dtype, "copy")?; - - // Copy kernel is defined in the scatter shader module - let mod_name = kernel_name("scatter", dtype)?; - let entry_point = kernel_name("copy", dtype)?; + let (shader, module_key, entry_point) = shader_info("copy", dtype)?; - let shader_source = generate_scatter_shader(dtype)?; - let module = cache.get_or_create_module(mod_name, &shader_source); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(mod_name, entry_point, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[src, dst, params_buffer]); @@ -362,17 +554,15 @@ pub fn launch_scatter( src_total: usize, dtype: DType, ) -> Result<()> { - check_dtype_supported(dtype, "scatter")?; + let (shader, module_key, entry_point) = shader_info("scatter", dtype)?; - let name = kernel_name("scatter", dtype)?; - let shader_source = generate_scatter_shader(dtype)?; - let module = cache.get_or_create_module(name, &shader_source); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 3, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[src, indices, output, params_buffer]); @@ -413,17 +603,15 @@ pub fn launch_masked_fill( numel: usize, dtype: DType, ) -> Result<()> { - check_dtype_supported(dtype, "masked_fill")?; + let (shader, module_key, entry_point) = shader_info("masked_fill", dtype)?; - let name = kernel_name("masked_fill", dtype)?; - let shader_source = generate_masked_fill_shader(dtype)?; - let module = cache.get_or_create_module(name, &shader_source); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 3, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[input, mask, output, params_buffer]); @@ -463,11 +651,9 @@ pub fn launch_masked_count( numel: usize, dtype: DType, ) -> Result<()> { - check_dtype_supported(dtype, "masked_count")?; + let (shader, module_key, entry_point) = shader_info("masked_count", dtype)?; - let mod_name = kernel_name("masked_select", dtype)?; - let shader_source = generate_masked_select_shader(dtype)?; - let module = cache.get_or_create_module(mod_name, &shader_source); + let module = cache.get_or_create_module(module_key, shader); // For count: mask (read), count_result (atomic), params let layout = cache.get_or_create_layout(LayoutKey { @@ -475,7 +661,7 @@ pub fn launch_masked_count( num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(mod_name, "masked_count", &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[mask, count_result, params_buffer]); @@ -511,18 +697,16 @@ pub fn launch_masked_prefix_sum( _numel: usize, dtype: DType, ) -> Result<()> { - check_dtype_supported(dtype, "masked_prefix_sum")?; + let (shader, module_key, entry_point) = shader_info("masked_prefix_sum", dtype)?; - let mod_name = kernel_name("masked_select", dtype)?; - let shader_source = generate_masked_select_shader(dtype)?; - let module = cache.get_or_create_module(mod_name, &shader_source); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(mod_name, "masked_prefix_sum", &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[mask, prefix_sum, params_buffer]); @@ -561,20 +745,16 @@ pub fn launch_masked_select( numel: usize, dtype: DType, ) -> Result<()> { - check_dtype_supported(dtype, "masked_select")?; - - let mod_name = kernel_name("masked_select", dtype)?; - let entry_point = kernel_name("masked_select", dtype)?; + let (shader, module_key, entry_point) = shader_info("masked_select", dtype)?; - let shader_source = generate_masked_select_shader(dtype)?; - let module = cache.get_or_create_module(mod_name, &shader_source); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 4, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(mod_name, entry_point, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[input, mask, prefix_sum, output, params_buffer]); @@ -618,17 +798,15 @@ pub fn launch_gather_nd( total_output: usize, dtype: DType, ) -> Result<()> { - check_dtype_supported(dtype, "gather_nd")?; + let (shader, module_key, entry_point) = shader_info("gather_nd", dtype)?; - let name = kernel_name("gather_nd", dtype)?; - let shader_source = super::generator::generate_gather_nd_shader(dtype)?; - let module = cache.get_or_create_module(name, &shader_source); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 4, num_uniform_buffers: 0, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[input, indices, output, params_buffer]); @@ -671,17 +849,20 @@ pub fn launch_bincount( n: usize, weights_dtype: Option, ) -> Result<()> { - let (name, shader_source) = if let Some(dtype) = weights_dtype { - let name = kernel_name("bincount", dtype)?; - let source = super::generator::generate_bincount_shader(Some(dtype))?; - (name, source) + let (name, shader) = if let Some(dtype) = weights_dtype { + // bincount_weighted is F32 only (uses float atomics) + if dtype != DType::F32 { + return Err(Error::UnsupportedDType { + dtype, + op: "bincount_weighted", + }); + } + ("bincount_weighted_f32", BINCOUNT_WEIGHTED_SHADER_F32) } else { - let name = kernel_name("bincount_unweighted", DType::I32)?; - let source = super::generator::generate_bincount_shader(None)?; - (name, source) + ("bincount_i32", BINCOUNT_UNWEIGHTED_SHADER) }; - let module = cache.get_or_create_module(name, &shader_source); + let module = cache.get_or_create_module(name, shader); let (layout, bind_group) = if let Some(weights_buf) = weights { let layout = cache.get_or_create_layout(LayoutKey { @@ -743,8 +924,6 @@ pub fn launch_scatter_reduce( dtype: DType, op: &str, ) -> Result<()> { - check_dtype_supported(dtype, "scatter_reduce")?; - // Get static kernel name based on op type let op_name: &'static str = match op { "sum" => "scatter_reduce_sum", @@ -758,15 +937,15 @@ pub fn launch_scatter_reduce( } }; - let name = kernel_name(op_name, dtype)?; - let shader_source = super::generator::generate_scatter_reduce_shader(dtype, op)?; - let module = cache.get_or_create_module(name, &shader_source); + let (shader, module_key, entry_point) = shader_info(op_name, dtype)?; + + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 3, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[src, indices, dst, params_buffer]); @@ -807,17 +986,15 @@ pub fn launch_scatter_reduce_prod( total_src: usize, dtype: DType, ) -> Result<()> { - check_dtype_supported(dtype, "scatter_reduce_prod")?; + let (shader, module_key, entry_point) = shader_info("scatter_reduce_prod", dtype)?; - let name = kernel_name("scatter_reduce_prod", dtype)?; - let shader_source = super::generator::generate_scatter_reduce_prod_shader(dtype)?; - let module = cache.get_or_create_module(name, &shader_source); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 3, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[src, indices, dst, params_buffer]); @@ -857,15 +1034,15 @@ pub fn launch_scatter_reduce_count( total_src: usize, dtype: DType, ) -> Result<()> { - let name = kernel_name("scatter_reduce_count", dtype)?; - let shader_source = super::generator::generate_scatter_reduce_count_shader(dtype)?; - let module = cache.get_or_create_module(name, &shader_source); + let (shader, module_key, entry_point) = shader_info("scatter_reduce_count", dtype)?; + + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[indices, count, params_buffer]); @@ -904,17 +1081,15 @@ pub fn launch_scatter_reduce_mean_div( n: usize, dtype: DType, ) -> Result<()> { - check_dtype_supported(dtype, "scatter_reduce_mean_div")?; + let (shader, module_key, entry_point) = shader_info("scatter_reduce_mean_div", dtype)?; - let name = kernel_name("scatter_reduce_mean_div", dtype)?; - let shader_source = super::generator::generate_scatter_reduce_mean_div_shader(dtype)?; - let module = cache.get_or_create_module(name, &shader_source); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 3, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[sum_buf, count_buf, output, params_buffer]); @@ -960,17 +1135,15 @@ pub fn launch_embedding_lookup( num_indices: usize, dtype: DType, ) -> Result<()> { - check_dtype_supported(dtype, "embedding_lookup")?; + let (shader, module_key, entry_point) = shader_info("embedding_lookup", dtype)?; - let name = kernel_name("embedding_lookup", dtype)?; - let shader_source = generate_embedding_lookup_shader(dtype)?; - let module = cache.get_or_create_module(name, &shader_source); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 3, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[embeddings, indices, output, params_buffer]); @@ -995,6 +1168,55 @@ pub fn launch_embedding_lookup( Ok(()) } +// ============================================================================ +// Slice Assign Operation +// ============================================================================ + +/// Launch a slice_assign operation kernel. +/// +/// Overwrites a slice of the output tensor with src values along a dimension. +/// Output should already contain a copy of dst data. +pub fn launch_slice_assign( + cache: &PipelineCache, + queue: &Queue, + src: &Buffer, + output: &Buffer, + params_buffer: &Buffer, + total_src: usize, + dtype: DType, +) -> Result<()> { + let (shader, module_key, entry_point) = shader_info("slice_assign", dtype)?; + + let module = cache.get_or_create_module(module_key, shader); + let layout = cache.get_or_create_layout(LayoutKey { + num_storage_buffers: 2, + num_uniform_buffers: 1, + num_readonly_storage: 0, + }); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); + + let bind_group = cache.create_bind_group(&layout, &[src, output, params_buffer]); + + let mut encoder = cache + .device() + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("slice_assign"), + }); + + { + let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some("slice_assign"), + timestamp_writes: None, + }); + pass.set_pipeline(&pipeline); + pass.set_bind_group(0, Some(&bind_group), &[]); + pass.dispatch_workgroups(workgroup_count(total_src), 1, 1); + } + + queue.submit(std::iter::once(encoder.finish())); + Ok(()) +} + // ============================================================================ // Gather 2D Operation // ============================================================================ @@ -1018,17 +1240,15 @@ pub fn launch_gather_2d( num_indices: usize, dtype: DType, ) -> Result<()> { - check_dtype_supported(dtype, "gather_2d")?; + let (shader, module_key, entry_point) = shader_info("gather_2d", dtype)?; - let name = kernel_name("gather_2d", dtype)?; - let shader_source = super::generator::generate_gather_2d_shader(dtype)?; - let module = cache.get_or_create_module(name, &shader_source); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 4, num_uniform_buffers: 1, num_readonly_storage: 3, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[input, rows, cols, output, params_buffer]); diff --git a/src/runtime/wgpu/shaders/index_put_f32.wgsl b/src/runtime/wgpu/shaders/index_put_f32.wgsl new file mode 100644 index 00000000..5489374f --- /dev/null +++ b/src/runtime/wgpu/shaders/index_put_f32.wgsl @@ -0,0 +1,36 @@ +// Auto-generated index_put operations for f32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct IndexPutParams { + outer_size: u32, + dim_size: u32, + inner_size: u32, + index_len: u32, +} + +@group(0) @binding(0) var indices: array; +@group(0) @binding(1) var src: array; +@group(0) @binding(2) var output: array; +@group(0) @binding(3) var params: IndexPutParams; + +@compute @workgroup_size(256) +fn index_put_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = params.outer_size * params.index_len * params.inner_size; + if (idx >= total) { + return; + } + + let inner = idx % params.inner_size; + let sel_idx = (idx / params.inner_size) % params.index_len; + let outer = idx / (params.index_len * params.inner_size); + + let index_val = indices[sel_idx]; + if (index_val < 0 || u32(index_val) >= params.dim_size) { + return; // Out of bounds - skip + } + + let dst_offset = outer * params.dim_size * params.inner_size + u32(index_val) * params.inner_size + inner; + output[dst_offset] = src[idx]; +} diff --git a/src/runtime/wgpu/shaders/index_put_i32.wgsl b/src/runtime/wgpu/shaders/index_put_i32.wgsl new file mode 100644 index 00000000..ad4c4931 --- /dev/null +++ b/src/runtime/wgpu/shaders/index_put_i32.wgsl @@ -0,0 +1,36 @@ +// Auto-generated index_put operations for i32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct IndexPutParams { + outer_size: u32, + dim_size: u32, + inner_size: u32, + index_len: u32, +} + +@group(0) @binding(0) var indices: array; +@group(0) @binding(1) var src: array; +@group(0) @binding(2) var output: array; +@group(0) @binding(3) var params: IndexPutParams; + +@compute @workgroup_size(256) +fn index_put_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = params.outer_size * params.index_len * params.inner_size; + if (idx >= total) { + return; + } + + let inner = idx % params.inner_size; + let sel_idx = (idx / params.inner_size) % params.index_len; + let outer = idx / (params.index_len * params.inner_size); + + let index_val = indices[sel_idx]; + if (index_val < 0 || u32(index_val) >= params.dim_size) { + return; // Out of bounds - skip + } + + let dst_offset = outer * params.dim_size * params.inner_size + u32(index_val) * params.inner_size + inner; + output[dst_offset] = src[idx]; +} diff --git a/src/runtime/wgpu/shaders/index_put_u32.wgsl b/src/runtime/wgpu/shaders/index_put_u32.wgsl new file mode 100644 index 00000000..8dae1b7b --- /dev/null +++ b/src/runtime/wgpu/shaders/index_put_u32.wgsl @@ -0,0 +1,36 @@ +// Auto-generated index_put operations for u32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct IndexPutParams { + outer_size: u32, + dim_size: u32, + inner_size: u32, + index_len: u32, +} + +@group(0) @binding(0) var indices: array; +@group(0) @binding(1) var src: array; +@group(0) @binding(2) var output: array; +@group(0) @binding(3) var params: IndexPutParams; + +@compute @workgroup_size(256) +fn index_put_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = params.outer_size * params.index_len * params.inner_size; + if (idx >= total) { + return; + } + + let inner = idx % params.inner_size; + let sel_idx = (idx / params.inner_size) % params.index_len; + let outer = idx / (params.index_len * params.inner_size); + + let index_val = indices[sel_idx]; + if (index_val < 0 || u32(index_val) >= params.dim_size) { + return; // Out of bounds - skip + } + + let dst_offset = outer * params.dim_size * params.inner_size + u32(index_val) * params.inner_size + inner; + output[dst_offset] = src[idx]; +} diff --git a/src/runtime/wgpu/shaders/index_select_f32.wgsl b/src/runtime/wgpu/shaders/index_select_f32.wgsl new file mode 100644 index 00000000..13add251 --- /dev/null +++ b/src/runtime/wgpu/shaders/index_select_f32.wgsl @@ -0,0 +1,37 @@ +// Auto-generated index_select operations for f32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct IndexSelectParams { + outer_size: u32, + dim_size: u32, + inner_size: u32, + index_len: u32, +} + +@group(0) @binding(0) var input: array; +@group(0) @binding(1) var indices: array; +@group(0) @binding(2) var output: array; +@group(0) @binding(3) var params: IndexSelectParams; + +@compute @workgroup_size(256) +fn index_select_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = params.outer_size * params.index_len * params.inner_size; + if (idx >= total) { + return; + } + + let inner = idx % params.inner_size; + let sel_idx = (idx / params.inner_size) % params.index_len; + let outer = idx / (params.index_len * params.inner_size); + + let index_val = indices[sel_idx]; + if (index_val < 0 || u32(index_val) >= params.dim_size) { + output[idx] = 0.0; + return; + } + + let src_offset = outer * params.dim_size * params.inner_size + u32(index_val) * params.inner_size + inner; + output[idx] = input[src_offset]; +} diff --git a/src/runtime/wgpu/shaders/index_select_i32.wgsl b/src/runtime/wgpu/shaders/index_select_i32.wgsl new file mode 100644 index 00000000..c677544d --- /dev/null +++ b/src/runtime/wgpu/shaders/index_select_i32.wgsl @@ -0,0 +1,37 @@ +// Auto-generated index_select operations for i32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct IndexSelectParams { + outer_size: u32, + dim_size: u32, + inner_size: u32, + index_len: u32, +} + +@group(0) @binding(0) var input: array; +@group(0) @binding(1) var indices: array; +@group(0) @binding(2) var output: array; +@group(0) @binding(3) var params: IndexSelectParams; + +@compute @workgroup_size(256) +fn index_select_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = params.outer_size * params.index_len * params.inner_size; + if (idx >= total) { + return; + } + + let inner = idx % params.inner_size; + let sel_idx = (idx / params.inner_size) % params.index_len; + let outer = idx / (params.index_len * params.inner_size); + + let index_val = indices[sel_idx]; + if (index_val < 0 || u32(index_val) >= params.dim_size) { + output[idx] = 0; + return; + } + + let src_offset = outer * params.dim_size * params.inner_size + u32(index_val) * params.inner_size + inner; + output[idx] = input[src_offset]; +} diff --git a/src/runtime/wgpu/shaders/index_select_u32.wgsl b/src/runtime/wgpu/shaders/index_select_u32.wgsl new file mode 100644 index 00000000..1b8dcde1 --- /dev/null +++ b/src/runtime/wgpu/shaders/index_select_u32.wgsl @@ -0,0 +1,37 @@ +// Auto-generated index_select operations for u32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct IndexSelectParams { + outer_size: u32, + dim_size: u32, + inner_size: u32, + index_len: u32, +} + +@group(0) @binding(0) var input: array; +@group(0) @binding(1) var indices: array; +@group(0) @binding(2) var output: array; +@group(0) @binding(3) var params: IndexSelectParams; + +@compute @workgroup_size(256) +fn index_select_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = params.outer_size * params.index_len * params.inner_size; + if (idx >= total) { + return; + } + + let inner = idx % params.inner_size; + let sel_idx = (idx / params.inner_size) % params.index_len; + let outer = idx / (params.index_len * params.inner_size); + + let index_val = indices[sel_idx]; + if (index_val < 0 || u32(index_val) >= params.dim_size) { + output[idx] = 0u; + return; + } + + let src_offset = outer * params.dim_size * params.inner_size + u32(index_val) * params.inner_size + inner; + output[idx] = input[src_offset]; +} diff --git a/src/runtime/wgpu/shaders/irfft_unpack.wgsl b/src/runtime/wgpu/shaders/irfft_unpack.wgsl new file mode 100644 index 00000000..55787538 --- /dev/null +++ b/src/runtime/wgpu/shaders/irfft_unpack.wgsl @@ -0,0 +1,32 @@ +// irfft unpack shader - extracts real part from complex + +const WORKGROUP_SIZE: u32 = 256u; + +struct UnpackParams { + n: u32, + batch_size: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var unpack_input: array>; +@group(0) @binding(1) var unpack_output: array; +@group(0) @binding(2) var unpack_params: UnpackParams; + +@compute @workgroup_size(WORKGROUP_SIZE) +fn irfft_unpack( + @builtin(global_invocation_id) gid: vec3 +) { + let idx = gid.x; + let batch_idx = gid.y; + let n = unpack_params.n; + + if (idx >= n) { + return; + } + + let in_offset = batch_idx * n; + let out_offset = batch_idx * n; + + unpack_output[out_offset + idx] = unpack_input[in_offset + idx].x; +} diff --git a/src/runtime/wgpu/shaders/laplace_f32.wgsl b/src/runtime/wgpu/shaders/laplace_f32.wgsl new file mode 100644 index 00000000..42a52813 --- /dev/null +++ b/src/runtime/wgpu/shaders/laplace_f32.wgsl @@ -0,0 +1,40 @@ +// Laplace distribution sampling for f32 + +// PCG hash function for random number generation +fn pcg_hash(input: u32) -> u32 { + var state = input * 747796405u + 2891336453u; + var word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u; + return (word >> 22u) ^ word; +} + +fn pcg_init(seed: u32, idx: u32) -> u32 { + return pcg_hash(seed ^ pcg_hash(idx)); +} + +fn pcg_uniform(state: ptr) -> f32 { + *state = pcg_hash(*state); + return f32(*state) / 4294967296.0; +} + +const WORKGROUP_SIZE: u32 = 256u; + +struct LaplaceParams { + numel: u32, + seed: u32, + loc: f32, + scale: f32, +} + +@group(0) @binding(0) var out: array; +@group(0) @binding(1) var params: LaplaceParams; + +@compute @workgroup_size(256) +fn laplace_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if idx < params.numel { + var state = pcg_init(params.seed, idx); + let u = pcg_uniform(&state) - 0.5; + let result = params.loc - params.scale * sign(u) * log(1.0 - 2.0 * abs(u)); + out[idx] = f32(result); + } +} diff --git a/src/runtime/wgpu/shaders/linalg_wgsl.rs b/src/runtime/wgpu/shaders/linalg_wgsl.rs deleted file mode 100644 index afab7477..00000000 --- a/src/runtime/wgpu/shaders/linalg_wgsl.rs +++ /dev/null @@ -1,26 +0,0 @@ -//! WGSL shader source code for linear algebra operations -//! -//! This module provides the combined linear algebra shader used by all linalg operations. -//! The shader source is maintained in `linalg_combined.wgsl` which contains all operations: -//! -//! - Basic ops: trace, diagonal, identity -//! - Solvers: forward/backward substitution -//! - Decompositions: LU, Cholesky, QR -//! - Utilities: determinant, permutation, column operations -//! - SVD: Singular value decomposition (Jacobi) -//! - Eigendecomposition: Symmetric and general cases -//! - Schur decomposition -//! - Matrix functions: expm, sqrtm, logm -//! -//! # Future Work -//! -//! Individual shader modules exist in `linalg_shaders/` for potential fine-grained -//! compilation, but are not currently used. This could reduce shader compilation time -//! for specialized applications that only need specific operations. - -/// Combined linear algebra shader containing all operations. -/// -/// This shader is used by all linear algebra launchers and includes all operations -/// from basic matrix ops to advanced decompositions. -#[allow(dead_code)] -pub const LINALG_SHADER: &str = include_str!("linalg_combined.wgsl"); diff --git a/src/runtime/wgpu/shaders/linspace_f32.wgsl b/src/runtime/wgpu/shaders/linspace_f32.wgsl new file mode 100644 index 00000000..d8abb948 --- /dev/null +++ b/src/runtime/wgpu/shaders/linspace_f32.wgsl @@ -0,0 +1,22 @@ +// Auto-generated linspace operation for f32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct LinspaceParams { + steps: u32, + start: f32, + stop: f32, +} + +@group(0) @binding(0) var linspace_out: array; +@group(0) @binding(1) var linspace_params: LinspaceParams; + +@compute @workgroup_size(256) +fn linspace_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < linspace_params.steps) { + let t_val = f32(idx) / f32(linspace_params.steps - 1u); + let value = linspace_params.start + (linspace_params.stop - linspace_params.start) * t_val; + linspace_out[idx] = f32(value); + } +} diff --git a/src/runtime/wgpu/shaders/logsumexp_f32.wgsl b/src/runtime/wgpu/shaders/logsumexp_f32.wgsl new file mode 100644 index 00000000..4e21e8e4 --- /dev/null +++ b/src/runtime/wgpu/shaders/logsumexp_f32.wgsl @@ -0,0 +1,39 @@ +// Log-sum-exp shader for f32 +// +// Computes log(sum(exp(x))) in a numerically stable way: +// logsumexp(x) = max(x) + log(sum(exp(x - max(x)))) + +struct LogsumexpParams { + reduce_size: u32, + outer_size: u32, +} + +@group(0) @binding(0) var input: array; +@group(0) @binding(1) var output: array; +@group(0) @binding(2) var params: LogsumexpParams; + +@compute @workgroup_size(256) +fn logsumexp_f32(@builtin(global_invocation_id) global_id: vec3) { + let outer_idx = global_id.x; + if (outer_idx >= params.outer_size) { + return; + } + + let base = outer_idx * params.reduce_size; + + // Step 1: Find max value + var max_val: f32 = -3.402823e+38; + for (var i: u32 = 0u; i < params.reduce_size; i = i + 1u) { + let val = input[base + i]; + max_val = max(max_val, val); + } + + // Step 2: Compute sum(exp(x - max)) + var sum_exp: f32 = 0.0; + for (var i: u32 = 0u; i < params.reduce_size; i = i + 1u) { + sum_exp = sum_exp + exp(input[base + i] - max_val); + } + + // Step 3: Result = max + log(sum) + output[outer_idx] = max_val + log(sum_exp); +} diff --git a/src/runtime/wgpu/shaders/logsumexp_strided_f32.wgsl b/src/runtime/wgpu/shaders/logsumexp_strided_f32.wgsl new file mode 100644 index 00000000..4c5c2d82 --- /dev/null +++ b/src/runtime/wgpu/shaders/logsumexp_strided_f32.wgsl @@ -0,0 +1,40 @@ +// Strided log-sum-exp shader for f32 + +struct LogsumexpStridedParams { + reduce_size: u32, + outer_size: u32, + inner_size: u32, +} + +@group(0) @binding(0) var input: array; +@group(0) @binding(1) var output: array; +@group(0) @binding(2) var params: LogsumexpStridedParams; + +@compute @workgroup_size(256) +fn logsumexp_strided_f32(@builtin(global_invocation_id) global_id: vec3) { + let idx = global_id.x; + let total_inner = params.outer_size * params.inner_size; + if (idx >= total_inner) { + return; + } + + let outer_idx = idx / params.inner_size; + let inner_idx = idx % params.inner_size; + + // Step 1: Find max value along reduce dimension + var max_val: f32 = -3.402823e+38; + for (var r: u32 = 0u; r < params.reduce_size; r = r + 1u) { + let offset = outer_idx * params.reduce_size * params.inner_size + r * params.inner_size + inner_idx; + max_val = max(max_val, input[offset]); + } + + // Step 2: Compute sum(exp(x - max)) + var sum_exp: f32 = 0.0; + for (var r: u32 = 0u; r < params.reduce_size; r = r + 1u) { + let offset = outer_idx * params.reduce_size * params.inner_size + r * params.inner_size + inner_idx; + sum_exp = sum_exp + exp(input[offset] - max_val); + } + + // Step 3: Write result + output[outer_idx * params.inner_size + inner_idx] = max_val + log(sum_exp); +} diff --git a/src/runtime/wgpu/shaders/masked_fill_f32.wgsl b/src/runtime/wgpu/shaders/masked_fill_f32.wgsl new file mode 100644 index 00000000..41a07bde --- /dev/null +++ b/src/runtime/wgpu/shaders/masked_fill_f32.wgsl @@ -0,0 +1,27 @@ +// Auto-generated masked_fill operations for f32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct MaskedFillParams { + numel: u32, + fill_value: f32, +} + +@group(0) @binding(0) var input: array; +@group(0) @binding(1) var mask: array; +@group(0) @binding(2) var output: array; +@group(0) @binding(3) var params: MaskedFillParams; + +@compute @workgroup_size(256) +fn masked_fill_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= params.numel) { + return; + } + + if (mask[idx] != 0u) { + output[idx] = f32(params.fill_value); + } else { + output[idx] = input[idx]; + } +} diff --git a/src/runtime/wgpu/shaders/masked_fill_i32.wgsl b/src/runtime/wgpu/shaders/masked_fill_i32.wgsl new file mode 100644 index 00000000..5daa0fb4 --- /dev/null +++ b/src/runtime/wgpu/shaders/masked_fill_i32.wgsl @@ -0,0 +1,27 @@ +// Auto-generated masked_fill operations for i32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct MaskedFillParams { + numel: u32, + fill_value: f32, +} + +@group(0) @binding(0) var input: array; +@group(0) @binding(1) var mask: array; +@group(0) @binding(2) var output: array; +@group(0) @binding(3) var params: MaskedFillParams; + +@compute @workgroup_size(256) +fn masked_fill_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= params.numel) { + return; + } + + if (mask[idx] != 0u) { + output[idx] = i32(params.fill_value); + } else { + output[idx] = input[idx]; + } +} diff --git a/src/runtime/wgpu/shaders/masked_fill_u32.wgsl b/src/runtime/wgpu/shaders/masked_fill_u32.wgsl new file mode 100644 index 00000000..d5d791fc --- /dev/null +++ b/src/runtime/wgpu/shaders/masked_fill_u32.wgsl @@ -0,0 +1,27 @@ +// Auto-generated masked_fill operations for u32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct MaskedFillParams { + numel: u32, + fill_value: f32, +} + +@group(0) @binding(0) var input: array; +@group(0) @binding(1) var mask: array; +@group(0) @binding(2) var output: array; +@group(0) @binding(3) var params: MaskedFillParams; + +@compute @workgroup_size(256) +fn masked_fill_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= params.numel) { + return; + } + + if (mask[idx] != 0u) { + output[idx] = u32(params.fill_value); + } else { + output[idx] = input[idx]; + } +} diff --git a/src/runtime/wgpu/shaders/masked_select_f32.wgsl b/src/runtime/wgpu/shaders/masked_select_f32.wgsl new file mode 100644 index 00000000..b73e7f56 --- /dev/null +++ b/src/runtime/wgpu/shaders/masked_select_f32.wgsl @@ -0,0 +1,87 @@ +// Auto-generated masked_select operations for f32 + +const WORKGROUP_SIZE: u32 = 256u; + +// Phase 1: Count masked elements +struct CountParams { + numel: u32, +} + +@group(0) @binding(0) var count_mask: array; +@group(0) @binding(1) var count_result: atomic; +@group(0) @binding(2) var count_params: CountParams; + +var shared_count: atomic; + +@compute @workgroup_size(256) +fn masked_count(@builtin(global_invocation_id) gid: vec3, + @builtin(local_invocation_id) lid: vec3) { + if (lid.x == 0u) { + atomicStore(&shared_count, 0u); + } + workgroupBarrier(); + + var local_count: u32 = 0u; + var i = gid.x; + while (i < count_params.numel) { + if (count_mask[i] != 0u) { + local_count = local_count + 1u; + } + i = i + 256u * 256u; // Grid stride + } + + atomicAdd(&shared_count, local_count); + workgroupBarrier(); + + if (lid.x == 0u) { + atomicAdd(&count_result, atomicLoad(&shared_count)); + } +} + +// Phase 2: Compute prefix sum (sequential - for small arrays) +struct PrefixSumParams { + numel: u32, +} + +@group(0) @binding(0) var prefix_mask: array; +@group(0) @binding(1) var prefix_sum: array; +@group(0) @binding(2) var prefix_params: PrefixSumParams; + +@compute @workgroup_size(1) +fn masked_prefix_sum(@builtin(global_invocation_id) gid: vec3) { + if (gid.x != 0u) { + return; + } + + var sum: u32 = 0u; + for (var i: u32 = 0u; i < prefix_params.numel; i = i + 1u) { + prefix_sum[i] = sum; + if (prefix_mask[i] != 0u) { + sum = sum + 1u; + } + } +} + +// Phase 3: Gather selected elements +struct SelectParams { + numel: u32, +} + +@group(0) @binding(0) var select_input: array; +@group(0) @binding(1) var select_mask: array; +@group(0) @binding(2) var select_prefix: array; +@group(0) @binding(3) var select_output: array; +@group(0) @binding(4) var select_params: SelectParams; + +@compute @workgroup_size(256) +fn masked_select_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= select_params.numel) { + return; + } + + if (select_mask[idx] != 0u) { + let out_idx = select_prefix[idx]; + select_output[out_idx] = select_input[idx]; + } +} diff --git a/src/runtime/wgpu/shaders/masked_select_i32.wgsl b/src/runtime/wgpu/shaders/masked_select_i32.wgsl new file mode 100644 index 00000000..d6618e8a --- /dev/null +++ b/src/runtime/wgpu/shaders/masked_select_i32.wgsl @@ -0,0 +1,87 @@ +// Auto-generated masked_select operations for i32 + +const WORKGROUP_SIZE: u32 = 256u; + +// Phase 1: Count masked elements +struct CountParams { + numel: u32, +} + +@group(0) @binding(0) var count_mask: array; +@group(0) @binding(1) var count_result: atomic; +@group(0) @binding(2) var count_params: CountParams; + +var shared_count: atomic; + +@compute @workgroup_size(256) +fn masked_count(@builtin(global_invocation_id) gid: vec3, + @builtin(local_invocation_id) lid: vec3) { + if (lid.x == 0u) { + atomicStore(&shared_count, 0u); + } + workgroupBarrier(); + + var local_count: u32 = 0u; + var i = gid.x; + while (i < count_params.numel) { + if (count_mask[i] != 0u) { + local_count = local_count + 1u; + } + i = i + 256u * 256u; // Grid stride + } + + atomicAdd(&shared_count, local_count); + workgroupBarrier(); + + if (lid.x == 0u) { + atomicAdd(&count_result, atomicLoad(&shared_count)); + } +} + +// Phase 2: Compute prefix sum (sequential - for small arrays) +struct PrefixSumParams { + numel: u32, +} + +@group(0) @binding(0) var prefix_mask: array; +@group(0) @binding(1) var prefix_sum: array; +@group(0) @binding(2) var prefix_params: PrefixSumParams; + +@compute @workgroup_size(1) +fn masked_prefix_sum(@builtin(global_invocation_id) gid: vec3) { + if (gid.x != 0u) { + return; + } + + var sum: u32 = 0u; + for (var i: u32 = 0u; i < prefix_params.numel; i = i + 1u) { + prefix_sum[i] = sum; + if (prefix_mask[i] != 0u) { + sum = sum + 1u; + } + } +} + +// Phase 3: Gather selected elements +struct SelectParams { + numel: u32, +} + +@group(0) @binding(0) var select_input: array; +@group(0) @binding(1) var select_mask: array; +@group(0) @binding(2) var select_prefix: array; +@group(0) @binding(3) var select_output: array; +@group(0) @binding(4) var select_params: SelectParams; + +@compute @workgroup_size(256) +fn masked_select_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= select_params.numel) { + return; + } + + if (select_mask[idx] != 0u) { + let out_idx = select_prefix[idx]; + select_output[out_idx] = select_input[idx]; + } +} diff --git a/src/runtime/wgpu/shaders/masked_select_u32.wgsl b/src/runtime/wgpu/shaders/masked_select_u32.wgsl new file mode 100644 index 00000000..7d6eaeb9 --- /dev/null +++ b/src/runtime/wgpu/shaders/masked_select_u32.wgsl @@ -0,0 +1,87 @@ +// Auto-generated masked_select operations for u32 + +const WORKGROUP_SIZE: u32 = 256u; + +// Phase 1: Count masked elements +struct CountParams { + numel: u32, +} + +@group(0) @binding(0) var count_mask: array; +@group(0) @binding(1) var count_result: atomic; +@group(0) @binding(2) var count_params: CountParams; + +var shared_count: atomic; + +@compute @workgroup_size(256) +fn masked_count(@builtin(global_invocation_id) gid: vec3, + @builtin(local_invocation_id) lid: vec3) { + if (lid.x == 0u) { + atomicStore(&shared_count, 0u); + } + workgroupBarrier(); + + var local_count: u32 = 0u; + var i = gid.x; + while (i < count_params.numel) { + if (count_mask[i] != 0u) { + local_count = local_count + 1u; + } + i = i + 256u * 256u; // Grid stride + } + + atomicAdd(&shared_count, local_count); + workgroupBarrier(); + + if (lid.x == 0u) { + atomicAdd(&count_result, atomicLoad(&shared_count)); + } +} + +// Phase 2: Compute prefix sum (sequential - for small arrays) +struct PrefixSumParams { + numel: u32, +} + +@group(0) @binding(0) var prefix_mask: array; +@group(0) @binding(1) var prefix_sum: array; +@group(0) @binding(2) var prefix_params: PrefixSumParams; + +@compute @workgroup_size(1) +fn masked_prefix_sum(@builtin(global_invocation_id) gid: vec3) { + if (gid.x != 0u) { + return; + } + + var sum: u32 = 0u; + for (var i: u32 = 0u; i < prefix_params.numel; i = i + 1u) { + prefix_sum[i] = sum; + if (prefix_mask[i] != 0u) { + sum = sum + 1u; + } + } +} + +// Phase 3: Gather selected elements +struct SelectParams { + numel: u32, +} + +@group(0) @binding(0) var select_input: array; +@group(0) @binding(1) var select_mask: array; +@group(0) @binding(2) var select_prefix: array; +@group(0) @binding(3) var select_output: array; +@group(0) @binding(4) var select_params: SelectParams; + +@compute @workgroup_size(256) +fn masked_select_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= select_params.numel) { + return; + } + + if (select_mask[idx] != 0u) { + let out_idx = select_prefix[idx]; + select_output[out_idx] = select_input[idx]; + } +} diff --git a/src/runtime/wgpu/shaders/matmul.rs b/src/runtime/wgpu/shaders/matmul.rs index 2898f6fb..a7d35c0d 100644 --- a/src/runtime/wgpu/shaders/matmul.rs +++ b/src/runtime/wgpu/shaders/matmul.rs @@ -1,21 +1,17 @@ -//! Matrix multiplication WGSL kernel launchers -//! -//! Provides launchers for matrix multiplication operations: -//! - 2D matrix multiplication (C = A @ B) -//! - Batched matrix multiplication -//! - Matrix-vector multiplication -//! - Fused matmul with bias (C = A @ B + bias) -//! -//! All operations run entirely on GPU with no CPU fallback. +//! Matrix multiplication WGSL kernel launchers. F32 only. use wgpu::{Buffer, Queue}; -use super::generator::generate_matmul_bias_shader; -use super::matmul_wgsl::MATMUL_SHADER; use super::pipeline::{LayoutKey, PipelineCache}; use crate::dtype::DType; use crate::error::{Error, Result}; +const MATMUL_SHADER: &str = include_str!("matmul.wgsl"); +const MATMUL_BIAS_SHADER: &str = include_str!("matmul_bias_f32.wgsl"); + +/// Tile size for tiled matrix multiplication (must match shader constant) +const TILE_SIZE: u32 = 16; + // ============================================================================ // Helper Macros // ============================================================================ @@ -31,9 +27,6 @@ macro_rules! check_dtype_f32 { }; } -/// Tile size for tiled matrix multiplication (must match shader constant) -const TILE_SIZE: u32 = 16; - // ============================================================================ // 2D Matrix Multiplication // ============================================================================ @@ -77,7 +70,6 @@ pub fn launch_matmul( }); pass.set_pipeline(&pipeline); pass.set_bind_group(0, Some(&bind_group), &[]); - // Number of workgroups in x (columns) and y (rows) dimensions let num_groups_x = (n as u32 + TILE_SIZE - 1) / TILE_SIZE; let num_groups_y = (m as u32 + TILE_SIZE - 1) / TILE_SIZE; pass.dispatch_workgroups(num_groups_x, num_groups_y, 1); @@ -126,7 +118,6 @@ pub fn launch_matmul_simple( }); pass.set_pipeline(&pipeline); pass.set_bind_group(0, Some(&bind_group), &[]); - // One thread per output element let total = m * n; let num_groups = (total as u32 + 255) / 256; pass.dispatch_workgroups(num_groups, 1, 1); @@ -231,7 +222,6 @@ pub fn launch_matvec( }); pass.set_pipeline(&pipeline); pass.set_bind_group(0, Some(&bind_group), &[]); - // One workgroup per output row pass.dispatch_workgroups(m as u32, 1, 1); } @@ -243,45 +233,9 @@ pub fn launch_matvec( // Fused Matrix Multiplication with Bias // ============================================================================ -/// Helper to get static module key and entry point for matmul_bias -fn matmul_bias_keys(dtype: DType) -> Result<(&'static str, &'static str, &'static str)> { - match dtype { - DType::F32 => Ok(( - "matmul_bias_f32", - "matmul_bias_f32", - "batched_matmul_bias_f32", - )), - DType::I32 => Ok(( - "matmul_bias_i32", - "matmul_bias_i32", - "batched_matmul_bias_i32", - )), - DType::U32 => Ok(( - "matmul_bias_u32", - "matmul_bias_u32", - "batched_matmul_bias_u32", - )), - DType::F16 => Ok(( - "matmul_bias_f16", - "matmul_bias_f16", - "batched_matmul_bias_f16", - )), - _ => Err(Error::UnsupportedDType { - dtype, - op: "matmul_bias", - }), - } -} - /// Launch tiled matrix multiplication with fused bias addition. /// -/// Computes C = A @ B + bias where: -/// - A is `[M, K]` -/// - B is `[K, N]` -/// - bias is `[N]` (broadcast across rows) -/// - C is `[M, N]` -/// -/// The bias addition is fused into the GEMM epilogue for efficiency. +/// Computes C = A @ B + bias where bias is `[N]` (broadcast across rows). pub fn launch_matmul_bias( cache: &PipelineCache, queue: &Queue, @@ -294,19 +248,17 @@ pub fn launch_matmul_bias( n: usize, dtype: DType, ) -> Result<()> { - // Get static keys and generate shader - let (module_key, entry_point, _) = matmul_bias_keys(dtype)?; - let shader_source = generate_matmul_bias_shader(dtype)?; + check_dtype_f32!(dtype, "matmul_bias"); - let module = cache.get_or_create_module(module_key, &shader_source); + let module = cache.get_or_create_module("matmul_bias_f32", MATMUL_BIAS_SHADER); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 4, // a, b, bias, c num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); + let pipeline = + cache.get_or_create_pipeline("matmul_bias_f32", "matmul_bias_f32", &module, &layout); - // Bind buffers: a, b, bias, c, params let bind_group = cache.create_bind_group(&layout, &[a, b, bias, c, params_buffer]); let mut encoder = cache @@ -322,7 +274,6 @@ pub fn launch_matmul_bias( }); pass.set_pipeline(&pipeline); pass.set_bind_group(0, Some(&bind_group), &[]); - // Number of workgroups in x (columns) and y (rows) dimensions let num_groups_x = (n as u32 + TILE_SIZE - 1) / TILE_SIZE; let num_groups_y = (m as u32 + TILE_SIZE - 1) / TILE_SIZE; pass.dispatch_workgroups(num_groups_x, num_groups_y, 1); @@ -335,7 +286,6 @@ pub fn launch_matmul_bias( /// Launch batched matrix multiplication with fused bias addition. /// /// Computes `C[b] = A[b] @ B[b] + bias` for each batch b. -/// The same bias vector is used for all batches. pub fn launch_batched_matmul_bias( cache: &PipelineCache, queue: &Queue, @@ -349,19 +299,21 @@ pub fn launch_batched_matmul_bias( batch_size: usize, dtype: DType, ) -> Result<()> { - // Get static keys and generate shader - let (module_key, _, batched_entry_point) = matmul_bias_keys(dtype)?; - let shader_source = generate_matmul_bias_shader(dtype)?; + check_dtype_f32!(dtype, "batched_matmul_bias"); - let module = cache.get_or_create_module(module_key, &shader_source); + let module = cache.get_or_create_module("matmul_bias_f32", MATMUL_BIAS_SHADER); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 4, // a, b, bias, c num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(module_key, batched_entry_point, &module, &layout); + let pipeline = cache.get_or_create_pipeline( + "matmul_bias_f32", + "batched_matmul_bias_f32", + &module, + &layout, + ); - // Bind buffers: a, b, bias, c, params let bind_group = cache.create_bind_group(&layout, &[a, b, bias, c, params_buffer]); let mut encoder = cache diff --git a/src/runtime/wgpu/shaders/matmul_wgsl.rs b/src/runtime/wgpu/shaders/matmul.wgsl similarity index 96% rename from src/runtime/wgpu/shaders/matmul_wgsl.rs rename to src/runtime/wgpu/shaders/matmul.wgsl index 8a74afcd..393de23c 100644 --- a/src/runtime/wgpu/shaders/matmul_wgsl.rs +++ b/src/runtime/wgpu/shaders/matmul.wgsl @@ -1,10 +1,6 @@ -//! WGSL shader source code for matrix multiplication -//! -//! Implements tiled matrix multiplication for better memory access patterns. -//! Supports 2D and batched matrix multiplication. +// Matrix multiplication operations. F32 only. +// Entry points: matmul_f32, batched_matmul_f32, matmul_simple_f32, matvec_f32 -/// Matrix multiplication shader module source (F32 only) -pub const MATMUL_SHADER: &str = r#" // ============================================================================ // Workgroup Configuration // ============================================================================ @@ -233,4 +229,3 @@ fn matvec_f32(@builtin(global_invocation_id) global_id: vec3, matvec_y[row] = matvec_shared[0]; } } -"#; diff --git a/src/runtime/wgpu/shaders/matmul_bias_f32.wgsl b/src/runtime/wgpu/shaders/matmul_bias_f32.wgsl new file mode 100644 index 00000000..4d6b7b5d --- /dev/null +++ b/src/runtime/wgpu/shaders/matmul_bias_f32.wgsl @@ -0,0 +1,121 @@ +// Fused matmul+bias operations. F32 only. +// C = A @ B + bias (fused epilogue) +// Entry points: matmul_bias_f32, batched_matmul_bias_f32 + +const TILE_SIZE: u32 = 16u; + +var tile_a: array, 16>; +var tile_b: array, 16>; + +struct MatmulBiasParams { + M: u32, + K: u32, + N: u32, + batch_size: u32, +} + +@group(0) @binding(0) var matmul_a: array; +@group(0) @binding(1) var matmul_b: array; +@group(0) @binding(2) var matmul_bias: array; +@group(0) @binding(3) var matmul_c: array; +@group(0) @binding(4) var matmul_params: MatmulBiasParams; + +@compute @workgroup_size(16, 16, 1) +fn matmul_bias_f32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let M = matmul_params.M; + let K = matmul_params.K; + let N = matmul_params.N; + + let row = group_id.y * TILE_SIZE + local_id.y; + let col = group_id.x * TILE_SIZE + local_id.x; + + var sum: f32 = 0.0; + + let num_tiles = (K + TILE_SIZE - 1u) / TILE_SIZE; + + for (var t: u32 = 0u; t < num_tiles; t = t + 1u) { + let a_col = t * TILE_SIZE + local_id.x; + if (row < M && a_col < K) { + tile_a[local_id.y][local_id.x] = matmul_a[row * K + a_col]; + } else { + tile_a[local_id.y][local_id.x] = 0.0; + } + + let b_row = t * TILE_SIZE + local_id.y; + if (b_row < K && col < N) { + tile_b[local_id.y][local_id.x] = matmul_b[b_row * N + col]; + } else { + tile_b[local_id.y][local_id.x] = 0.0; + } + + workgroupBarrier(); + + for (var k: u32 = 0u; k < TILE_SIZE; k = k + 1u) { + sum = sum + tile_a[local_id.y][k] * tile_b[k][local_id.x]; + } + + workgroupBarrier(); + } + + // Fused epilogue: add bias and write result + if (row < M && col < N) { + matmul_c[row * N + col] = sum + matmul_bias[col]; + } +} + +@compute @workgroup_size(16, 16, 1) +fn batched_matmul_bias_f32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let M = matmul_params.M; + let K = matmul_params.K; + let N = matmul_params.N; + let batch_size = matmul_params.batch_size; + + let batch = group_id.z; + if (batch >= batch_size) { + return; + } + + let row = group_id.y * TILE_SIZE + local_id.y; + let col = group_id.x * TILE_SIZE + local_id.x; + + let a_batch_offset = batch * M * K; + let b_batch_offset = batch * K * N; + let c_batch_offset = batch * M * N; + + var sum: f32 = 0.0; + + let num_tiles = (K + TILE_SIZE - 1u) / TILE_SIZE; + + for (var t: u32 = 0u; t < num_tiles; t = t + 1u) { + let a_col = t * TILE_SIZE + local_id.x; + if (row < M && a_col < K) { + tile_a[local_id.y][local_id.x] = matmul_a[a_batch_offset + row * K + a_col]; + } else { + tile_a[local_id.y][local_id.x] = 0.0; + } + + let b_row = t * TILE_SIZE + local_id.y; + if (b_row < K && col < N) { + tile_b[local_id.y][local_id.x] = matmul_b[b_batch_offset + b_row * N + col]; + } else { + tile_b[local_id.y][local_id.x] = 0.0; + } + + workgroupBarrier(); + + for (var k: u32 = 0u; k < TILE_SIZE; k = k + 1u) { + sum = sum + tile_a[local_id.y][k] * tile_b[k][local_id.x]; + } + + workgroupBarrier(); + } + + // Fused epilogue: add bias (same bias for all batches) and write result + if (row < M && col < N) { + matmul_c[c_batch_offset + row * N + col] = sum + matmul_bias[col]; + } +} diff --git a/src/runtime/wgpu/shaders/matrix_funcs_launcher.rs b/src/runtime/wgpu/shaders/matrix_funcs_launcher.rs index 3009c511..5b0e3acc 100644 --- a/src/runtime/wgpu/shaders/matrix_funcs_launcher.rs +++ b/src/runtime/wgpu/shaders/matrix_funcs_launcher.rs @@ -2,13 +2,31 @@ use wgpu::{Buffer, Queue}; -use super::generator::{ - dtype_suffix, generate_diagonal_func_shader, generate_parlett_column_shader, - generate_validate_eigenvalues_shader, -}; use super::pipeline::{LayoutKey, PipelineCache}; use crate::dtype::DType; -use crate::error::Result; +use crate::error::{Error, Result}; + +const VALIDATE_EIGENVALUES_SHADER: &str = include_str!("validate_eigenvalues_f32.wgsl"); +// entry point: "validate_eigenvalues_f32" + +const DIAGONAL_EXP_SHADER: &str = include_str!("diagonal_exp_f32.wgsl"); +// entry point: "diagonal_exp_f32" + +const DIAGONAL_LOG_SHADER: &str = include_str!("diagonal_log_f32.wgsl"); +// entry point: "diagonal_log_f32" + +const DIAGONAL_SQRT_SHADER: &str = include_str!("diagonal_sqrt_f32.wgsl"); +// entry point: "diagonal_sqrt_f32" + +const PARLETT_COLUMN_SHADER: &str = include_str!("parlett_column_f32.wgsl"); +// entry point: "parlett_column_f32" + +fn check_dtype_f32(dtype: DType, op: &'static str) -> Result<()> { + match dtype { + DType::F32 => Ok(()), + _ => Err(Error::UnsupportedDType { dtype, op }), + } +} /// Launch eigenvalue validation on Schur form. /// @@ -24,19 +42,21 @@ pub fn launch_validate_eigenvalues( eps: f32, dtype: DType, ) -> Result<()> { - let suffix = dtype_suffix(dtype)?; - let shader_key = format!("validate_eigenvalues_{}", suffix); - let entry_point = format!("validate_eigenvalues_{}", suffix); + check_dtype_f32(dtype, "validate_eigenvalues")?; - let shader_source = generate_validate_eigenvalues_shader(dtype)?; - let module = cache.get_or_create_module_from_source(&shader_key, &shader_source); + let module = + cache.get_or_create_module("validate_eigenvalues_f32", VALIDATE_EIGENVALUES_SHADER); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = - cache.get_or_create_dynamic_pipeline(&shader_key, &entry_point, &module, &layout); + let pipeline = cache.get_or_create_pipeline( + "validate_eigenvalues_f32", + "validate_eigenvalues_f32", + &module, + &layout, + ); // Create params buffer let params: [u32; 4] = [n as u32, eps.to_bits(), 0, 0]; @@ -83,19 +103,32 @@ pub fn launch_diagonal_func( func_type: &str, dtype: DType, ) -> Result<()> { - let suffix = dtype_suffix(dtype)?; - let shader_key = format!("diagonal_{}_{}", func_type, suffix); - let entry_point = format!("diagonal_{}_{}", func_type, suffix); + check_dtype_f32(dtype, "diagonal_func")?; - let shader_source = generate_diagonal_func_shader(dtype, func_type)?; - let module = cache.get_or_create_module_from_source(&shader_key, &shader_source); + let (shader_src, module_name, entry_point): (&str, &'static str, &'static str) = match func_type + { + "exp" => (DIAGONAL_EXP_SHADER, "diagonal_exp_f32", "diagonal_exp_f32"), + "log" => (DIAGONAL_LOG_SHADER, "diagonal_log_f32", "diagonal_log_f32"), + "sqrt" => ( + DIAGONAL_SQRT_SHADER, + "diagonal_sqrt_f32", + "diagonal_sqrt_f32", + ), + _ => { + return Err(Error::Internal(format!( + "Unknown diagonal func type: {}", + func_type + ))); + } + }; + + let module = cache.get_or_create_module(module_name, shader_src); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = - cache.get_or_create_dynamic_pipeline(&shader_key, &entry_point, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_name, entry_point, &module, &layout); // Create params buffer let params: [u32; 4] = [n as u32, eps.to_bits(), 0, 0]; @@ -142,19 +175,16 @@ pub fn launch_parlett_column( eps: f32, dtype: DType, ) -> Result<()> { - let suffix = dtype_suffix(dtype)?; - let shader_key = format!("parlett_column_{}", suffix); - let entry_point = format!("parlett_column_{}", suffix); + check_dtype_f32(dtype, "parlett_column")?; - let shader_source = generate_parlett_column_shader(dtype)?; - let module = cache.get_or_create_module_from_source(&shader_key, &shader_source); + let module = cache.get_or_create_module("parlett_column_f32", PARLETT_COLUMN_SHADER); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, num_readonly_storage: 0, }); let pipeline = - cache.get_or_create_dynamic_pipeline(&shader_key, &entry_point, &module, &layout); + cache.get_or_create_pipeline("parlett_column_f32", "parlett_column_f32", &module, &layout); // Create params buffer let params: [u32; 4] = [n as u32, col as u32, eps.to_bits(), 0]; diff --git a/src/runtime/wgpu/shaders/mod.rs b/src/runtime/wgpu/shaders/mod.rs index f96ea7f4..aa38a2d4 100644 --- a/src/runtime/wgpu/shaders/mod.rs +++ b/src/runtime/wgpu/shaders/mod.rs @@ -2,23 +2,8 @@ //! //! This module provides native WGSL compute shaders for tensor operations. //! All operations run entirely on the GPU without CPU fallback. -//! -//! # Multi-DType Support -//! -//! Shaders are generated per-dtype using the `generator` module: -//! - F32, I32, U32 are always supported -//! - F16 requires WebGPU f16 extension -//! -//! # Module Structure -//! -//! - `generator` - WGSL shader source generation per dtype -//! - `pipeline` - Pipeline caching and dispatch utilities -//! - `elementwise` - Element-wise operation launchers -//! - `reduce` - Reduction operation launchers -//! - `matmul` - Matrix multiplication launchers -//! - `norm` - Normalization operation launchers -//! - `linalg` - Linear algebra kernel launchers -//! - `copy` - Copy operation shaders (strided to contiguous) +//! Shaders are static `.wgsl` files embedded at compile time via `include_str!()`. +//! WebGPU supports F32, I32, U32 only (no F64/F16/BF16). pub mod advanced_random; pub mod complex; @@ -29,7 +14,6 @@ pub mod distance; pub mod distributions; pub mod dtype_support; pub mod fft; -pub mod generator; pub mod index; pub mod linalg; pub mod logical; @@ -42,12 +26,17 @@ pub mod statistics; // Operation launchers pub mod activation_launcher; pub mod elementwise; +pub mod fused_add_norm; +pub mod gemm_epilogue; +pub mod gemv_bt; pub mod matmul; pub mod matrix_funcs_launcher; pub mod norm; pub mod reduce; pub mod semiring_matmul; #[cfg(feature = "sparse")] +pub mod sparse_24; +#[cfg(feature = "sparse")] pub mod sparse_algorithms_launcher; #[cfg(feature = "sparse")] pub mod sparse_conversions_launcher; @@ -63,11 +52,7 @@ pub mod where_launcher; mod linalg_launchers; mod linalg_shaders; -mod linalg_wgsl; -mod matmul_wgsl; -mod norm_wgsl; mod pipeline; -mod reduce_wgsl; #[cfg(feature = "sparse")] /// GPU-native level computation kernels for sparse factorization @@ -79,6 +64,8 @@ pub mod sparse_level_compute { } pub use activation_launcher::{launch_clamp_op, launch_elu, launch_leaky_relu}; +pub mod fused_activation_mul; +pub mod fused_elementwise; pub use advanced_random::{ launch_pcg64_randn, launch_pcg64_uniform, launch_philox_randn, launch_philox_uniform, launch_threefry_randn, launch_threefry_uniform, launch_xoshiro256_randn, @@ -102,24 +89,21 @@ pub use distributions::{ launch_chi_squared, launch_exponential, launch_f_distribution, launch_gamma_dist, launch_laplace, launch_multinomial_count, launch_poisson, launch_student_t, }; -pub use generator::{ - dtype_suffix, generate_all_casts_from, generate_arange_shader, generate_binary_shader, - generate_bincount_shader, generate_cast_shader, generate_cat_shader, generate_compare_shader, - generate_conv1d_shader, generate_conv2d_shader, generate_cumprod_shader, - generate_cumprod_strided_shader, generate_cumsum_shader, generate_cumsum_strided_shader, - generate_depthwise_conv2d_shader, generate_eye_shader, generate_fill_shader, - generate_gather_nd_shader, generate_gather_shader, generate_index_select_shader, - generate_linspace_shader, generate_logsumexp_shader, generate_logsumexp_strided_shader, - generate_masked_fill_shader, generate_masked_select_shader, generate_matmul_shader, - generate_norm_shader, generate_reduce_shader, generate_scalar_shader, - generate_scatter_reduce_shader, generate_scatter_shader, generate_unary_shader, - is_wgpu_supported, is_wgsl_float, is_wgsl_int, wgsl_type, +pub use fused_activation_mul::{ + launch_gelu_mul, launch_gelu_mul_bwd, launch_relu_mul, launch_relu_mul_bwd, launch_sigmoid_mul, + launch_sigmoid_mul_bwd, launch_silu_mul, launch_silu_mul_bwd, +}; +pub use fused_add_norm::{ + launch_fused_add_layer_norm, launch_fused_add_layer_norm_bwd, launch_fused_add_rms_norm, + launch_fused_add_rms_norm_bwd, launch_reduce_sum_rows, +}; +pub use fused_elementwise::{ + launch_fused_add_mul, launch_fused_mul_add, launch_fused_mul_add_scalar, }; -#[cfg(feature = "sparse")] -pub use generator::{generate_csr_spmm_shader, generate_csr_spmv_shader}; pub use index::{ launch_bincount, launch_gather_2d, launch_gather_nd, launch_scatter_reduce, launch_scatter_reduce_count, launch_scatter_reduce_mean_div, launch_scatter_reduce_prod, + launch_slice_assign, }; pub use logical::{launch_logical_and, launch_logical_not, launch_logical_or, launch_logical_xor}; pub use matrix_funcs_launcher::{ @@ -129,6 +113,8 @@ pub use matrix_funcs_launcher::{ pub use pipeline::{LayoutKey, PipelineCache, WORKGROUP_SIZE, workgroup_count}; pub use quasirandom::{launch_halton, launch_latin_hypercube, launch_sobol}; #[cfg(feature = "sparse")] +pub use sparse_24::{Sparse24Params, launch_sparse_24_decompress, launch_sparse_24_prune}; +#[cfg(feature = "sparse")] pub use sparse_algorithms_launcher::{ launch_dsmm_csc, launch_spgemm_accumulate, launch_spgemm_scatter, launch_spgemm_symbolic, }; diff --git a/src/runtime/wgpu/shaders/multinomial_count_f32.wgsl b/src/runtime/wgpu/shaders/multinomial_count_f32.wgsl new file mode 100644 index 00000000..51beffad --- /dev/null +++ b/src/runtime/wgpu/shaders/multinomial_count_f32.wgsl @@ -0,0 +1,55 @@ +// Multinomial count shader for f32 +// Performs CDF lookup for uniform samples and counts occurrences per category + +const WORKGROUP_SIZE: u32 = 256u; + +struct MultinomialCountParams { + k: u32, // Number of categories + n_trials: u32, // Number of trials per sample + n_samples: u32, // Number of samples + _pad: u32, +} + +@group(0) @binding(0) var cdf: array; +@group(0) @binding(1) var uniforms: array; +@group(0) @binding(2) var counts: array; +@group(0) @binding(3) var params: MultinomialCountParams; + +// Binary search to find category for uniform sample +fn find_category(u: f32, k: u32) -> u32 { + var lo: u32 = 0u; + var hi: u32 = k; + while (lo < hi) { + let mid = lo + (hi - lo) / 2u; + if (cdf[mid] <= u) { + lo = mid + 1u; + } else { + hi = mid; + } + } + return min(lo, k - 1u); +} + +@compute @workgroup_size(256) +fn multinomial_count_f32(@builtin(global_invocation_id) global_id: vec3) { + let sample_idx = global_id.x; + let k = params.k; + let n_trials = params.n_trials; + let n_samples = params.n_samples; + + if (sample_idx >= n_samples) { + return; + } + + // Initialize counts for this sample to zero + for (var c: u32 = 0u; c < k; c++) { + counts[sample_idx * k + c] = f32(0.0); + } + + // Process each trial + for (var t_idx: u32 = 0u; t_idx < n_trials; t_idx++) { + let u = uniforms[sample_idx * n_trials + t_idx]; + let category = find_category(u, k); + counts[sample_idx * k + category] += f32(1.0); + } +} diff --git a/src/runtime/wgpu/shaders/multinomial_with_replacement_f32.wgsl b/src/runtime/wgpu/shaders/multinomial_with_replacement_f32.wgsl new file mode 100644 index 00000000..d00e01ab --- /dev/null +++ b/src/runtime/wgpu/shaders/multinomial_with_replacement_f32.wgsl @@ -0,0 +1,83 @@ +// Auto-generated multinomial_with_replacement operation for f32 + +// PCG hash function for random number generation +// Based on PCG Random Number Generation by Melissa O'Neill +fn pcg_hash(input: u32) -> u32 { + var state = input * 747796405u + 2891336453u; + var word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u; + return (word >> 22u) ^ word; +} + +// Initialize PCG state from seed and index +fn pcg_init(seed: u32, idx: u32) -> u32 { + return pcg_hash(seed ^ pcg_hash(idx)); +} + +// Generate uniform float in [0, 1) +fn pcg_uniform(state: ptr) -> f32 { + *state = pcg_hash(*state); + return f32(*state) / 4294967296.0; // Divide by 2^32 +} + +// Box-Muller transform for normal distribution +// Generates one normal value, requires two uniform values +fn box_muller(u1: f32, u2: f32) -> f32 { + let u1_safe = max(u1, 0.0000001); // Avoid log(0) + let r = sqrt(-2.0 * log(u1_safe)); + let theta = 6.28318530718 * u2; // 2 * PI + return r * cos(theta); +} + +const WORKGROUP_SIZE: u32 = 256u; + +struct MultinomialParams { + num_distributions: u32, + num_categories: u32, + num_samples: u32, + seed: u32, +} + +@group(0) @binding(0) var probs: array; +@group(0) @binding(1) var multinomial_out: array; +@group(0) @binding(2) var multinomial_params: MultinomialParams; + +@compute @workgroup_size(256) +fn multinomial_with_replacement_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = multinomial_params.num_distributions * multinomial_params.num_samples; + if (idx >= total) { + return; + } + + let dist = idx / multinomial_params.num_samples; + let sample = idx % multinomial_params.num_samples; + + // Initialize RNG for this thread + var state = pcg_init(multinomial_params.seed, idx); + + // Get pointer to this distribution's probabilities + let prob_offset = dist * multinomial_params.num_categories; + + // Compute sum of probabilities for normalization + var sum: f32 = 0.0; + for (var i: u32 = 0u; i < multinomial_params.num_categories; i = i + 1u) { + sum = sum + probs[prob_offset + i]; + } + + // Generate uniform random value + let u = pcg_uniform(&state); + + // Linear search using CDF (on-the-fly computation) + // Find smallest index where cumsum/sum >= u + var cumsum: f32 = 0.0; + var result: u32 = multinomial_params.num_categories - 1u; // Default to last category + for (var i: u32 = 0u; i < multinomial_params.num_categories; i = i + 1u) { + cumsum = cumsum + probs[prob_offset + i]; + if (cumsum / sum >= u) { + result = i; + break; + } + } + + multinomial_out[dist * multinomial_params.num_samples + sample] = i32(result); +} diff --git a/src/runtime/wgpu/shaders/multinomial_without_replacement_f32.wgsl b/src/runtime/wgpu/shaders/multinomial_without_replacement_f32.wgsl new file mode 100644 index 00000000..a7b562ea --- /dev/null +++ b/src/runtime/wgpu/shaders/multinomial_without_replacement_f32.wgsl @@ -0,0 +1,101 @@ +// Auto-generated multinomial_without_replacement operation for f32 + +// PCG hash function for random number generation +// Based on PCG Random Number Generation by Melissa O'Neill +fn pcg_hash(input: u32) -> u32 { + var state = input * 747796405u + 2891336453u; + var word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u; + return (word >> 22u) ^ word; +} + +// Initialize PCG state from seed and index +fn pcg_init(seed: u32, idx: u32) -> u32 { + return pcg_hash(seed ^ pcg_hash(idx)); +} + +// Generate uniform float in [0, 1) +fn pcg_uniform(state: ptr) -> f32 { + *state = pcg_hash(*state); + return f32(*state) / 4294967296.0; // Divide by 2^32 +} + +// Box-Muller transform for normal distribution +// Generates one normal value, requires two uniform values +fn box_muller(u1: f32, u2: f32) -> f32 { + let u1_safe = max(u1, 0.0000001); // Avoid log(0) + let r = sqrt(-2.0 * log(u1_safe)); + let theta = 6.28318530718 * u2; // 2 * PI + return r * cos(theta); +} + +const WORKGROUP_SIZE: u32 = 256u; +const MAX_CATEGORIES: u32 = 1024u; // Maximum supported categories + +struct MultinomialParams { + num_distributions: u32, + num_categories: u32, + num_samples: u32, + seed: u32, +} + +@group(0) @binding(0) var probs: array; +@group(0) @binding(1) var multinomial_out: array; +@group(0) @binding(2) var multinomial_params: MultinomialParams; + +var shared_probs: array; + +@compute @workgroup_size(256) +fn multinomial_without_replacement_f32(@builtin(global_invocation_id) gid: vec3, @builtin(local_invocation_id) lid: vec3) { + let dist = gid.x / WORKGROUP_SIZE; + if (dist >= multinomial_params.num_distributions) { + return; + } + + // Copy probabilities to shared memory (each thread copies some elements) + let prob_offset = dist * multinomial_params.num_categories; + let elements_per_thread = (multinomial_params.num_categories + WORKGROUP_SIZE - 1u) / WORKGROUP_SIZE; + for (var i: u32 = 0u; i < elements_per_thread; i = i + 1u) { + let idx = lid.x * elements_per_thread + i; + if (idx < multinomial_params.num_categories) { + shared_probs[idx] = probs[prob_offset + idx]; + } + } + + workgroupBarrier(); + + // Only thread 0 does the sequential sampling + if (lid.x != 0u) { + return; + } + + // Initialize RNG + var state = pcg_init(multinomial_params.seed, dist); + + // Sample without replacement + for (var s: u32 = 0u; s < multinomial_params.num_samples; s = s + 1u) { + // Compute sum of remaining probabilities + var sum: f32 = 0.0; + for (var i: u32 = 0u; i < multinomial_params.num_categories; i = i + 1u) { + sum = sum + shared_probs[i]; + } + + // Generate uniform random value + let u = pcg_uniform(&state); + + // Linear search using CDF + var cumsum: f32 = 0.0; + var result: u32 = multinomial_params.num_categories - 1u; + for (var i: u32 = 0u; i < multinomial_params.num_categories; i = i + 1u) { + cumsum = cumsum + shared_probs[i]; + if (cumsum / sum >= u) { + result = i; + break; + } + } + + multinomial_out[dist * multinomial_params.num_samples + s] = i32(result); + + // Zero out selected category + shared_probs[result] = 0.0; + } +} diff --git a/src/runtime/wgpu/shaders/norm.rs b/src/runtime/wgpu/shaders/norm.rs index 39922c49..c6b927fe 100644 --- a/src/runtime/wgpu/shaders/norm.rs +++ b/src/runtime/wgpu/shaders/norm.rs @@ -8,11 +8,12 @@ use wgpu::{Buffer, Queue}; -use super::norm_wgsl::NORM_SHADER; use super::pipeline::{LayoutKey, PipelineCache}; use crate::dtype::DType; use crate::error::{Error, Result}; +const NORM_SHADER: &str = include_str!("norm.wgsl"); + // ============================================================================ // Helper Macros // ============================================================================ @@ -175,3 +176,57 @@ pub fn launch_layer_norm_no_bias( queue.submit(std::iter::once(encoder.finish())); Ok(()) } + +// ============================================================================ +// Group Normalization +// ============================================================================ + +/// Launch group normalization kernel. +/// +/// Computes: output = (input - mean) / sqrt(var + eps) * weight + bias +/// Normalizes over groups of channels. +pub fn launch_group_norm( + cache: &PipelineCache, + queue: &Queue, + input: &Buffer, + weight: &Buffer, + bias: &Buffer, + output: &Buffer, + params_buffer: &Buffer, + batch_size: usize, + num_groups: usize, + dtype: DType, +) -> Result<()> { + check_dtype_f32!(dtype, "group_norm"); + + let module = cache.get_or_create_module("norm", NORM_SHADER); + let layout = cache.get_or_create_layout(LayoutKey { + num_storage_buffers: 4, + num_uniform_buffers: 1, + num_readonly_storage: 0, + }); + let pipeline = cache.get_or_create_pipeline("norm", "group_norm_f32", &module, &layout); + + let bind_group = + cache.create_bind_group(&layout, &[input, weight, bias, output, params_buffer]); + + let mut encoder = cache + .device() + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("group_norm"), + }); + + { + let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some("group_norm"), + timestamp_writes: None, + }); + pass.set_pipeline(&pipeline); + pass.set_bind_group(0, Some(&bind_group), &[]); + // One workgroup per (batch, group) pair + pass.dispatch_workgroups((batch_size * num_groups) as u32, 1, 1); + } + + queue.submit(std::iter::once(encoder.finish())); + Ok(()) +} diff --git a/src/runtime/wgpu/shaders/norm.wgsl b/src/runtime/wgpu/shaders/norm.wgsl new file mode 100644 index 00000000..cb7589b4 --- /dev/null +++ b/src/runtime/wgpu/shaders/norm.wgsl @@ -0,0 +1,371 @@ +// Normalization operations. F32 only. +// Entry points: rms_norm_f32, layer_norm_f32, layer_norm_no_bias_f32, group_norm_f32 +// +// Welford's online algorithm is used for LayerNorm and GroupNorm to compute +// mean and variance in a single pass with numerical stability. Each thread +// accumulates its own (count, mean, M2) triple, then a tree reduction merges +// accumulators across the workgroup using the parallel Welford merge formula: +// delta = mean_b - mean_a +// mean_ab = mean_a + delta * count_b / (count_a + count_b) +// M2_ab = M2_a + M2_b + delta^2 * count_a * count_b / (count_a + count_b) +// +// Shared memory is sized to WORKGROUP_SIZE (256). All workgroup_size attributes +// and shared memory array sizes MUST be kept in sync with this constant. + +// ============================================================================ +// Workgroup Configuration +// ============================================================================ + +const WORKGROUP_SIZE: u32 = 256u; + +var norm_shared: array; + +// ============================================================================ +// RMS Normalization +// ============================================================================ +// rms_norm(x, weight, eps) = x / sqrt(mean(x^2) + eps) * weight +// Applied to last dimension + +struct RmsNormParams { + batch_size: u32, // Product of all dims except the last + hidden_size: u32, // Size of the last dimension + eps: f32, +} + +@group(0) @binding(0) var rms_input: array; +@group(0) @binding(1) var rms_weight: array; +@group(0) @binding(2) var rms_output: array; +@group(0) @binding(3) var rms_params: RmsNormParams; + +@compute @workgroup_size(256) +fn rms_norm_f32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let tid = local_id.x; + let batch_idx = group_id.x; + + if (batch_idx >= rms_params.batch_size) { + return; + } + + let hidden_size = rms_params.hidden_size; + let eps = rms_params.eps; + let base_offset = batch_idx * hidden_size; + + // Step 1: Compute sum of squares + var sum_sq: f32 = 0.0; + var i: u32 = tid; + while (i < hidden_size) { + let val = rms_input[base_offset + i]; + sum_sq = sum_sq + val * val; + i = i + WORKGROUP_SIZE; + } + + norm_shared[tid] = sum_sq; + workgroupBarrier(); + + // Reduce to get total sum of squares + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { + norm_shared[tid] = norm_shared[tid] + norm_shared[tid + s]; + } + workgroupBarrier(); + } + + // Compute RMS: sqrt(mean(x^2) + eps) + let rms = sqrt(norm_shared[0] / f32(hidden_size) + eps); + workgroupBarrier(); + + // Step 2: Normalize and apply weight + i = tid; + while (i < hidden_size) { + rms_output[base_offset + i] = rms_input[base_offset + i] / rms * rms_weight[i]; + i = i + WORKGROUP_SIZE; + } +} + +// ============================================================================ +// Layer Normalization +// ============================================================================ +// layer_norm(x, weight, bias, eps) = (x - mean(x)) / sqrt(var(x) + eps) * weight + bias +// Applied to last dimension + +struct LayerNormParams { + batch_size: u32, + hidden_size: u32, + eps: f32, +} + +@group(0) @binding(0) var ln_input: array; +@group(0) @binding(1) var ln_weight: array; +@group(0) @binding(2) var ln_bias: array; +@group(0) @binding(3) var ln_output: array; +@group(0) @binding(4) var ln_params: LayerNormParams; + +// Welford shared memory: count, mean, M2 per thread +var ln_shared_count: array; +var ln_shared_mean: array; +var ln_shared_m2: array; + +@compute @workgroup_size(256) +fn layer_norm_f32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let tid = local_id.x; + let batch_idx = group_id.x; + + if (batch_idx >= ln_params.batch_size) { + return; + } + + let hidden_size = ln_params.hidden_size; + let eps = ln_params.eps; + let base_offset = batch_idx * hidden_size; + + // Step 1: Per-thread Welford accumulation (single pass over input) + var count: f32 = 0.0; + var mean: f32 = 0.0; + var m2: f32 = 0.0; + var i: u32 = tid; + while (i < hidden_size) { + let x = ln_input[base_offset + i]; + count = count + 1.0; + let delta = x - mean; + mean = mean + delta / count; + m2 = m2 + delta * (x - mean); + i = i + WORKGROUP_SIZE; + } + + ln_shared_count[tid] = count; + ln_shared_mean[tid] = mean; + ln_shared_m2[tid] = m2; + workgroupBarrier(); + + // Step 2: Tree reduction with Welford merge + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { + let count_a = ln_shared_count[tid]; + let mean_a = ln_shared_mean[tid]; + let m2_a = ln_shared_m2[tid]; + let count_b = ln_shared_count[tid + s]; + let mean_b = ln_shared_mean[tid + s]; + let m2_b = ln_shared_m2[tid + s]; + + let merged_count = count_a + count_b; + if (merged_count > 0.0) { + let delta = mean_b - mean_a; + let merged_mean = mean_a + delta * count_b / merged_count; + let merged_m2 = m2_a + m2_b + delta * delta * count_a * count_b / merged_count; + ln_shared_count[tid] = merged_count; + ln_shared_mean[tid] = merged_mean; + ln_shared_m2[tid] = merged_m2; + } + } + workgroupBarrier(); + } + + let final_mean = ln_shared_mean[0]; + let variance = ln_shared_m2[0] / f32(hidden_size); + let inv_std = 1.0 / sqrt(variance + eps); + workgroupBarrier(); + + // Step 3: Normalize and apply affine transformation (second pass over input) + i = tid; + while (i < hidden_size) { + let normalized = (ln_input[base_offset + i] - final_mean) * inv_std; + ln_output[base_offset + i] = normalized * ln_weight[i] + ln_bias[i]; + i = i + WORKGROUP_SIZE; + } +} + +// ============================================================================ +// Layer Normalization without bias +// ============================================================================ + +@group(0) @binding(0) var ln_nb_input: array; +@group(0) @binding(1) var ln_nb_weight: array; +@group(0) @binding(2) var ln_nb_output: array; +@group(0) @binding(3) var ln_nb_params: LayerNormParams; + +@compute @workgroup_size(256) +fn layer_norm_no_bias_f32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let tid = local_id.x; + let batch_idx = group_id.x; + + if (batch_idx >= ln_nb_params.batch_size) { + return; + } + + let hidden_size = ln_nb_params.hidden_size; + let eps = ln_nb_params.eps; + let base_offset = batch_idx * hidden_size; + + // Step 1: Per-thread Welford accumulation (single pass) + var count: f32 = 0.0; + var mean: f32 = 0.0; + var m2: f32 = 0.0; + var i: u32 = tid; + while (i < hidden_size) { + let x = ln_nb_input[base_offset + i]; + count = count + 1.0; + let delta = x - mean; + mean = mean + delta / count; + m2 = m2 + delta * (x - mean); + i = i + WORKGROUP_SIZE; + } + + // Reuse layer_norm shared memory for reduction + ln_shared_count[tid] = count; + ln_shared_mean[tid] = mean; + ln_shared_m2[tid] = m2; + workgroupBarrier(); + + // Step 2: Tree reduction with Welford merge + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { + let count_a = ln_shared_count[tid]; + let mean_a = ln_shared_mean[tid]; + let m2_a = ln_shared_m2[tid]; + let count_b = ln_shared_count[tid + s]; + let mean_b = ln_shared_mean[tid + s]; + let m2_b = ln_shared_m2[tid + s]; + + let merged_count = count_a + count_b; + if (merged_count > 0.0) { + let delta = mean_b - mean_a; + ln_shared_count[tid] = merged_count; + ln_shared_mean[tid] = mean_a + delta * count_b / merged_count; + ln_shared_m2[tid] = m2_a + m2_b + delta * delta * count_a * count_b / merged_count; + } + } + workgroupBarrier(); + } + + let final_mean = ln_shared_mean[0]; + let variance = ln_shared_m2[0] / f32(hidden_size); + let inv_std = 1.0 / sqrt(variance + eps); + workgroupBarrier(); + + // Step 3: Normalize and apply weight only (second pass) + i = tid; + while (i < hidden_size) { + let normalized = (ln_nb_input[base_offset + i] - final_mean) * inv_std; + ln_nb_output[base_offset + i] = normalized * ln_nb_weight[i]; + i = i + WORKGROUP_SIZE; + } +} + +// ============================================================================ +// Group Normalization +// ============================================================================ +// group_norm(x, weight, bias, num_groups) normalizes over groups of channels + +struct GroupNormParams { + batch_size: u32, + channels: u32, + spatial: u32, + num_groups: u32, + channels_per_group: u32, + eps: f32, + _pad0: u32, + _pad1: u32, +} + +@group(0) @binding(0) var gn_input: array; +@group(0) @binding(1) var gn_weight: array; +@group(0) @binding(2) var gn_bias: array; +@group(0) @binding(3) var gn_output: array; +@group(0) @binding(4) var gn_params: GroupNormParams; + +var gn_shared_count: array; +var gn_shared_mean: array; +var gn_shared_m2: array; + +@compute @workgroup_size(256) +fn group_norm_f32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let tid = local_id.x; + let bg_id = group_id.x; // batch_id * num_groups + group_id + + let batch_size = gn_params.batch_size; + let channels = gn_params.channels; + let spatial = gn_params.spatial; + let num_groups = gn_params.num_groups; + let channels_per_group = gn_params.channels_per_group; + let eps = gn_params.eps; + + if (bg_id >= batch_size * num_groups) { + return; + } + + let batch_id = bg_id / num_groups; + let group_id_val = bg_id % num_groups; + let c_start = group_id_val * channels_per_group; + let group_size = channels_per_group * spatial; + + let batch_offset = batch_id * channels * spatial; + let group_offset = batch_offset + c_start * spatial; + + // Step 1: Per-thread Welford accumulation (single pass) + var count: f32 = 0.0; + var mean: f32 = 0.0; + var m2: f32 = 0.0; + var i: u32 = tid; + while (i < group_size) { + let c_offset = i / spatial; + let s_offset = i % spatial; + let idx = group_offset + c_offset * spatial + s_offset; + let x = gn_input[idx]; + count = count + 1.0; + let delta = x - mean; + mean = mean + delta / count; + m2 = m2 + delta * (x - mean); + i = i + WORKGROUP_SIZE; + } + + gn_shared_count[tid] = count; + gn_shared_mean[tid] = mean; + gn_shared_m2[tid] = m2; + workgroupBarrier(); + + // Step 2: Tree reduction with Welford merge + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { + let count_a = gn_shared_count[tid]; + let mean_a = gn_shared_mean[tid]; + let m2_a = gn_shared_m2[tid]; + let count_b = gn_shared_count[tid + s]; + let mean_b = gn_shared_mean[tid + s]; + let m2_b = gn_shared_m2[tid + s]; + + let merged_count = count_a + count_b; + if (merged_count > 0.0) { + let delta = mean_b - mean_a; + gn_shared_count[tid] = merged_count; + gn_shared_mean[tid] = mean_a + delta * count_b / merged_count; + gn_shared_m2[tid] = m2_a + m2_b + delta * delta * count_a * count_b / merged_count; + } + } + workgroupBarrier(); + } + + let final_mean = gn_shared_mean[0]; + let variance = gn_shared_m2[0] / f32(group_size); + let inv_std = 1.0 / sqrt(variance + eps); + workgroupBarrier(); + + // Step 3: Normalize and apply per-channel weight and bias (second pass) + i = tid; + while (i < group_size) { + let c_offset = i / spatial; + let s_offset = i % spatial; + let idx = group_offset + c_offset * spatial + s_offset; + let channel = c_start + c_offset; + let normalized = (gn_input[idx] - final_mean) * inv_std; + gn_output[idx] = normalized * gn_weight[channel] + gn_bias[channel]; + i = i + WORKGROUP_SIZE; + } +} diff --git a/src/runtime/wgpu/shaders/norm_wgsl.rs b/src/runtime/wgpu/shaders/norm_wgsl.rs deleted file mode 100644 index 8815a2f8..00000000 --- a/src/runtime/wgpu/shaders/norm_wgsl.rs +++ /dev/null @@ -1,245 +0,0 @@ -//! WGSL shader source code for normalization operations -//! -//! Includes RMS normalization and Layer normalization. -//! Both use workgroup-level parallel reductions for efficiency. - -/// Normalization shader module source (F32 only) -pub const NORM_SHADER: &str = r#" -// ============================================================================ -// Workgroup Configuration -// ============================================================================ - -const WORKGROUP_SIZE: u32 = 256u; - -var norm_shared: array; - -// ============================================================================ -// RMS Normalization -// ============================================================================ -// rms_norm(x, weight, eps) = x / sqrt(mean(x^2) + eps) * weight -// Applied to last dimension - -struct RmsNormParams { - batch_size: u32, // Product of all dims except the last - hidden_size: u32, // Size of the last dimension - eps: f32, -} - -@group(0) @binding(0) var rms_input: array; -@group(0) @binding(1) var rms_weight: array; -@group(0) @binding(2) var rms_output: array; -@group(0) @binding(3) var rms_params: RmsNormParams; - -@compute @workgroup_size(256) -fn rms_norm_f32(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3) { - let tid = local_id.x; - let batch_idx = group_id.x; - - if (batch_idx >= rms_params.batch_size) { - return; - } - - let hidden_size = rms_params.hidden_size; - let eps = rms_params.eps; - let base_offset = batch_idx * hidden_size; - - // Step 1: Compute sum of squares - var sum_sq: f32 = 0.0; - var i: u32 = tid; - while (i < hidden_size) { - let val = rms_input[base_offset + i]; - sum_sq = sum_sq + val * val; - i = i + WORKGROUP_SIZE; - } - - norm_shared[tid] = sum_sq; - workgroupBarrier(); - - // Reduce to get total sum of squares - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { - if (tid < s) { - norm_shared[tid] = norm_shared[tid] + norm_shared[tid + s]; - } - workgroupBarrier(); - } - - // Compute RMS: sqrt(mean(x^2) + eps) - let rms = sqrt(norm_shared[0] / f32(hidden_size) + eps); - workgroupBarrier(); - - // Step 2: Normalize and apply weight - i = tid; - while (i < hidden_size) { - rms_output[base_offset + i] = rms_input[base_offset + i] / rms * rms_weight[i]; - i = i + WORKGROUP_SIZE; - } -} - -// ============================================================================ -// Layer Normalization -// ============================================================================ -// layer_norm(x, weight, bias, eps) = (x - mean(x)) / sqrt(var(x) + eps) * weight + bias -// Applied to last dimension - -struct LayerNormParams { - batch_size: u32, - hidden_size: u32, - eps: f32, -} - -@group(0) @binding(0) var ln_input: array; -@group(0) @binding(1) var ln_weight: array; -@group(0) @binding(2) var ln_bias: array; -@group(0) @binding(3) var ln_output: array; -@group(0) @binding(4) var ln_params: LayerNormParams; - -var ln_shared_mean: array; -var ln_shared_var: array; - -@compute @workgroup_size(256) -fn layer_norm_f32(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3) { - let tid = local_id.x; - let batch_idx = group_id.x; - - if (batch_idx >= ln_params.batch_size) { - return; - } - - let hidden_size = ln_params.hidden_size; - let eps = ln_params.eps; - let base_offset = batch_idx * hidden_size; - - // Step 1: Compute mean - var sum: f32 = 0.0; - var i: u32 = tid; - while (i < hidden_size) { - sum = sum + ln_input[base_offset + i]; - i = i + WORKGROUP_SIZE; - } - - ln_shared_mean[tid] = sum; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { - if (tid < s) { - ln_shared_mean[tid] = ln_shared_mean[tid] + ln_shared_mean[tid + s]; - } - workgroupBarrier(); - } - - let mean = ln_shared_mean[0] / f32(hidden_size); - workgroupBarrier(); - - // Step 2: Compute variance - var var_sum: f32 = 0.0; - i = tid; - while (i < hidden_size) { - let diff = ln_input[base_offset + i] - mean; - var_sum = var_sum + diff * diff; - i = i + WORKGROUP_SIZE; - } - - ln_shared_var[tid] = var_sum; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { - if (tid < s) { - ln_shared_var[tid] = ln_shared_var[tid] + ln_shared_var[tid + s]; - } - workgroupBarrier(); - } - - let variance = ln_shared_var[0] / f32(hidden_size); - let inv_std = 1.0 / sqrt(variance + eps); - workgroupBarrier(); - - // Step 3: Normalize and apply affine transformation - i = tid; - while (i < hidden_size) { - let normalized = (ln_input[base_offset + i] - mean) * inv_std; - ln_output[base_offset + i] = normalized * ln_weight[i] + ln_bias[i]; - i = i + WORKGROUP_SIZE; - } -} - -// ============================================================================ -// Layer Normalization without bias -// ============================================================================ - -@group(0) @binding(0) var ln_nb_input: array; -@group(0) @binding(1) var ln_nb_weight: array; -@group(0) @binding(2) var ln_nb_output: array; -@group(0) @binding(3) var ln_nb_params: LayerNormParams; - -@compute @workgroup_size(256) -fn layer_norm_no_bias_f32(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3) { - let tid = local_id.x; - let batch_idx = group_id.x; - - if (batch_idx >= ln_nb_params.batch_size) { - return; - } - - let hidden_size = ln_nb_params.hidden_size; - let eps = ln_nb_params.eps; - let base_offset = batch_idx * hidden_size; - - // Step 1: Compute mean - var sum: f32 = 0.0; - var i: u32 = tid; - while (i < hidden_size) { - sum = sum + ln_nb_input[base_offset + i]; - i = i + WORKGROUP_SIZE; - } - - ln_shared_mean[tid] = sum; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { - if (tid < s) { - ln_shared_mean[tid] = ln_shared_mean[tid] + ln_shared_mean[tid + s]; - } - workgroupBarrier(); - } - - let mean = ln_shared_mean[0] / f32(hidden_size); - workgroupBarrier(); - - // Step 2: Compute variance - var var_sum: f32 = 0.0; - i = tid; - while (i < hidden_size) { - let diff = ln_nb_input[base_offset + i] - mean; - var_sum = var_sum + diff * diff; - i = i + WORKGROUP_SIZE; - } - - ln_shared_var[tid] = var_sum; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { - if (tid < s) { - ln_shared_var[tid] = ln_shared_var[tid] + ln_shared_var[tid + s]; - } - workgroupBarrier(); - } - - let variance = ln_shared_var[0] / f32(hidden_size); - let inv_std = 1.0 / sqrt(variance + eps); - workgroupBarrier(); - - // Step 3: Normalize and apply weight only - i = tid; - while (i < hidden_size) { - let normalized = (ln_nb_input[base_offset + i] - mean) * inv_std; - ln_nb_output[base_offset + i] = normalized * ln_nb_weight[i]; - i = i + WORKGROUP_SIZE; - } -} -"#; diff --git a/src/runtime/wgpu/shaders/pad_f32.wgsl b/src/runtime/wgpu/shaders/pad_f32.wgsl new file mode 100644 index 00000000..ec5bf2a9 --- /dev/null +++ b/src/runtime/wgpu/shaders/pad_f32.wgsl @@ -0,0 +1,77 @@ +// Auto-generated pad operation for f32 + +const WORKGROUP_SIZE: u32 = 256u; +const MAX_DIMS: u32 = 8u; + +// Use vec4 for 16-byte alignment in uniform buffer +struct PadParams { + ndim: u32, + total_elements: u32, + fill_value: f32, + _pad0: u32, + src_shape: array, 2>, // 8 u32 values packed into 2 vec4 + out_shape: array, 2>, + pad_before: array, 2>, +} + +// Helper to access packed array, 2> by index +fn get_packed_value(arr: array, 2>, d: i32) -> u32 { + let vec_idx = u32(d) / 4u; + let comp_idx = u32(d) % 4u; + if (vec_idx == 0u) { + if (comp_idx == 0u) { return arr[0].x; } + else if (comp_idx == 1u) { return arr[0].y; } + else if (comp_idx == 2u) { return arr[0].z; } + else { return arr[0].w; } + } else { + if (comp_idx == 0u) { return arr[1].x; } + else if (comp_idx == 1u) { return arr[1].y; } + else if (comp_idx == 2u) { return arr[1].z; } + else { return arr[1].w; } + } +} + +@group(0) @binding(0) var pad_src: array; +@group(0) @binding(1) var pad_dst: array; +@group(0) @binding(2) var pad_params: PadParams; + +@compute @workgroup_size(256) +fn pad_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= pad_params.total_elements) { + return; + } + + // Decompose idx into multi-dimensional output coordinates + var remaining = idx; + var coords: array; + var in_bounds = true; + + // Process dimensions from last to first + for (var d = i32(pad_params.ndim) - 1; d >= 0; d = d - 1) { + let out_dim = get_packed_value(pad_params.out_shape, d); + coords[d] = remaining % out_dim; + remaining = remaining / out_dim; + + // Check if coordinate is in original tensor region + let pb = get_packed_value(pad_params.pad_before, d); + let ss = get_packed_value(pad_params.src_shape, d); + if (coords[d] < pb || coords[d] >= pb + ss) { + in_bounds = false; + } + } + + if (in_bounds) { + // Compute source index + var src_idx = 0u; + var src_stride = 1u; + for (var d = i32(pad_params.ndim) - 1; d >= 0; d = d - 1) { + let src_coord = coords[d] - get_packed_value(pad_params.pad_before, d); + src_idx = src_idx + src_coord * src_stride; + src_stride = src_stride * get_packed_value(pad_params.src_shape, d); + } + pad_dst[idx] = pad_src[src_idx]; + } else { + pad_dst[idx] = pad_params.fill_value; + } +} diff --git a/src/runtime/wgpu/shaders/pad_i32.wgsl b/src/runtime/wgpu/shaders/pad_i32.wgsl new file mode 100644 index 00000000..386428f3 --- /dev/null +++ b/src/runtime/wgpu/shaders/pad_i32.wgsl @@ -0,0 +1,77 @@ +// Auto-generated pad operation for i32 + +const WORKGROUP_SIZE: u32 = 256u; +const MAX_DIMS: u32 = 8u; + +// Use vec4 for 16-byte alignment in uniform buffer +struct PadParams { + ndim: u32, + total_elements: u32, + fill_value: i32, + _pad0: u32, + src_shape: array, 2>, // 8 u32 values packed into 2 vec4 + out_shape: array, 2>, + pad_before: array, 2>, +} + +// Helper to access packed array, 2> by index +fn get_packed_value(arr: array, 2>, d: i32) -> u32 { + let vec_idx = u32(d) / 4u; + let comp_idx = u32(d) % 4u; + if (vec_idx == 0u) { + if (comp_idx == 0u) { return arr[0].x; } + else if (comp_idx == 1u) { return arr[0].y; } + else if (comp_idx == 2u) { return arr[0].z; } + else { return arr[0].w; } + } else { + if (comp_idx == 0u) { return arr[1].x; } + else if (comp_idx == 1u) { return arr[1].y; } + else if (comp_idx == 2u) { return arr[1].z; } + else { return arr[1].w; } + } +} + +@group(0) @binding(0) var pad_src: array; +@group(0) @binding(1) var pad_dst: array; +@group(0) @binding(2) var pad_params: PadParams; + +@compute @workgroup_size(256) +fn pad_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= pad_params.total_elements) { + return; + } + + // Decompose idx into multi-dimensional output coordinates + var remaining = idx; + var coords: array; + var in_bounds = true; + + // Process dimensions from last to first + for (var d = i32(pad_params.ndim) - 1; d >= 0; d = d - 1) { + let out_dim = get_packed_value(pad_params.out_shape, d); + coords[d] = remaining % out_dim; + remaining = remaining / out_dim; + + // Check if coordinate is in original tensor region + let pb = get_packed_value(pad_params.pad_before, d); + let ss = get_packed_value(pad_params.src_shape, d); + if (coords[d] < pb || coords[d] >= pb + ss) { + in_bounds = false; + } + } + + if (in_bounds) { + // Compute source index + var src_idx = 0u; + var src_stride = 1u; + for (var d = i32(pad_params.ndim) - 1; d >= 0; d = d - 1) { + let src_coord = coords[d] - get_packed_value(pad_params.pad_before, d); + src_idx = src_idx + src_coord * src_stride; + src_stride = src_stride * get_packed_value(pad_params.src_shape, d); + } + pad_dst[idx] = pad_src[src_idx]; + } else { + pad_dst[idx] = pad_params.fill_value; + } +} diff --git a/src/runtime/wgpu/shaders/pad_u32.wgsl b/src/runtime/wgpu/shaders/pad_u32.wgsl new file mode 100644 index 00000000..a9f34f80 --- /dev/null +++ b/src/runtime/wgpu/shaders/pad_u32.wgsl @@ -0,0 +1,77 @@ +// Auto-generated pad operation for u32 + +const WORKGROUP_SIZE: u32 = 256u; +const MAX_DIMS: u32 = 8u; + +// Use vec4 for 16-byte alignment in uniform buffer +struct PadParams { + ndim: u32, + total_elements: u32, + fill_value: u32, + _pad0: u32, + src_shape: array, 2>, // 8 u32 values packed into 2 vec4 + out_shape: array, 2>, + pad_before: array, 2>, +} + +// Helper to access packed array, 2> by index +fn get_packed_value(arr: array, 2>, d: i32) -> u32 { + let vec_idx = u32(d) / 4u; + let comp_idx = u32(d) % 4u; + if (vec_idx == 0u) { + if (comp_idx == 0u) { return arr[0].x; } + else if (comp_idx == 1u) { return arr[0].y; } + else if (comp_idx == 2u) { return arr[0].z; } + else { return arr[0].w; } + } else { + if (comp_idx == 0u) { return arr[1].x; } + else if (comp_idx == 1u) { return arr[1].y; } + else if (comp_idx == 2u) { return arr[1].z; } + else { return arr[1].w; } + } +} + +@group(0) @binding(0) var pad_src: array; +@group(0) @binding(1) var pad_dst: array; +@group(0) @binding(2) var pad_params: PadParams; + +@compute @workgroup_size(256) +fn pad_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= pad_params.total_elements) { + return; + } + + // Decompose idx into multi-dimensional output coordinates + var remaining = idx; + var coords: array; + var in_bounds = true; + + // Process dimensions from last to first + for (var d = i32(pad_params.ndim) - 1; d >= 0; d = d - 1) { + let out_dim = get_packed_value(pad_params.out_shape, d); + coords[d] = remaining % out_dim; + remaining = remaining / out_dim; + + // Check if coordinate is in original tensor region + let pb = get_packed_value(pad_params.pad_before, d); + let ss = get_packed_value(pad_params.src_shape, d); + if (coords[d] < pb || coords[d] >= pb + ss) { + in_bounds = false; + } + } + + if (in_bounds) { + // Compute source index + var src_idx = 0u; + var src_stride = 1u; + for (var d = i32(pad_params.ndim) - 1; d >= 0; d = d - 1) { + let src_coord = coords[d] - get_packed_value(pad_params.pad_before, d); + src_idx = src_idx + src_coord * src_stride; + src_stride = src_stride * get_packed_value(pad_params.src_shape, d); + } + pad_dst[idx] = pad_src[src_idx]; + } else { + pad_dst[idx] = pad_params.fill_value; + } +} diff --git a/src/runtime/wgpu/shaders/parlett_column_f32.wgsl b/src/runtime/wgpu/shaders/parlett_column_f32.wgsl new file mode 100644 index 00000000..ef77f6f3 --- /dev/null +++ b/src/runtime/wgpu/shaders/parlett_column_f32.wgsl @@ -0,0 +1,54 @@ +// Parlett recurrence for off-diagonal elements - f32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct Params { + n: u32, + col: u32, // Current column being processed + eps: f32, + _pad: u32, +} + +@group(0) @binding(0) var input_t: array; +@group(0) @binding(1) var output_f: array; +@group(0) @binding(2) var params: Params; + +@compute @workgroup_size(WORKGROUP_SIZE) +fn parlett_column_f32(@builtin(global_invocation_id) gid: vec3) { + let n = params.n; + let j = params.col; + let eps = f32(params.eps); + + // Each thread handles one row i < j + let i = gid.x; + if i >= j { + return; + } + + let t_ii = input_t[i * n + i]; + let t_jj = input_t[j * n + j]; + let t_ij = input_t[i * n + j]; + + let denom = t_ii - t_jj; + + // Compute the sum term + var sum: f32 = 0.0; + for (var k: u32 = i + 1u; k < j; k = k + 1u) { + let f_ik = output_f[i * n + k]; + let t_kj = input_t[k * n + j]; + let t_ik = input_t[i * n + k]; + let f_kj = output_f[k * n + j]; + sum = sum + f_ik * t_kj - t_ik * f_kj; + } + + let f_ii = output_f[i * n + i]; + let f_jj = output_f[j * n + j]; + + // F[i,j] = (T[i,j] * (F[i,i] - F[j,j]) + sum) / (T[i,i] - T[j,j]) + if abs(denom) > eps { + output_f[i * n + j] = (t_ij * (f_ii - f_jj) + sum) / denom; + } else { + // Eigenvalues too close - use limit formula + output_f[i * n + j] = t_ij * f_ii; // Simplified fallback + } +} diff --git a/src/runtime/wgpu/shaders/poisson_f32.wgsl b/src/runtime/wgpu/shaders/poisson_f32.wgsl new file mode 100644 index 00000000..0f670f5c --- /dev/null +++ b/src/runtime/wgpu/shaders/poisson_f32.wgsl @@ -0,0 +1,65 @@ +// Poisson distribution sampling for f32 + +// PCG hash function for random number generation +fn pcg_hash(input: u32) -> u32 { + var state = input * 747796405u + 2891336453u; + var word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u; + return (word >> 22u) ^ word; +} + +fn pcg_init(seed: u32, idx: u32) -> u32 { + return pcg_hash(seed ^ pcg_hash(idx)); +} + +fn pcg_uniform(state: ptr) -> f32 { + *state = pcg_hash(*state); + return f32(*state) / 4294967296.0; +} + +// Box-Muller for normal distribution +fn sample_normal(state: ptr) -> f32 { + let u1 = max(pcg_uniform(state), 0.0000001); + let u2 = pcg_uniform(state); + return sqrt(-2.0 * log(u1)) * cos(6.28318530718 * u2); +} + +const WORKGROUP_SIZE: u32 = 256u; + +struct PoissonParams { + numel: u32, + seed: u32, + lambda: f32, + _pad: u32, +} + +@group(0) @binding(0) var out: array; +@group(0) @binding(1) var params: PoissonParams; + +@compute @workgroup_size(256) +fn poisson_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if idx < params.numel { + var state = pcg_init(params.seed, idx); + + // Knuth's algorithm for small lambda + if params.lambda < 30.0 { + let L = exp(-params.lambda); + var k = 0u; + var p = 1.0; + + for (var i = 0u; i < 1000u; i = i + 1u) { + p = p * pcg_uniform(&state); + if p <= L { + break; + } + k = k + 1u; + } + out[idx] = f32(f32(k)); + } else { + // Normal approximation for large lambda + let z = sample_normal(&state); + let result = max(0.0, round(params.lambda + sqrt(params.lambda) * z)); + out[idx] = f32(result); + } + } +} diff --git a/src/runtime/wgpu/shaders/rand_f32.wgsl b/src/runtime/wgpu/shaders/rand_f32.wgsl new file mode 100644 index 00000000..f096cc8f --- /dev/null +++ b/src/runtime/wgpu/shaders/rand_f32.wgsl @@ -0,0 +1,51 @@ +// Auto-generated rand operation for f32 + +// PCG hash function for random number generation +// Based on PCG Random Number Generation by Melissa O'Neill +fn pcg_hash(input: u32) -> u32 { + var state = input * 747796405u + 2891336453u; + var word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u; + return (word >> 22u) ^ word; +} + +// Initialize PCG state from seed and index +fn pcg_init(seed: u32, idx: u32) -> u32 { + return pcg_hash(seed ^ pcg_hash(idx)); +} + +// Generate uniform float in [0, 1) +fn pcg_uniform(state: ptr) -> f32 { + *state = pcg_hash(*state); + return f32(*state) / 4294967296.0; // Divide by 2^32 +} + +// Box-Muller transform for normal distribution +// Generates one normal value, requires two uniform values +fn box_muller(u1: f32, u2: f32) -> f32 { + let u1_safe = max(u1, 0.0000001); // Avoid log(0) + let r = sqrt(-2.0 * log(u1_safe)); + let theta = 6.28318530718 * u2; // 2 * PI + return r * cos(theta); +} + +const WORKGROUP_SIZE: u32 = 256u; + +struct RandParams { + numel: u32, + seed: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var rand_out: array; +@group(0) @binding(1) var rand_params: RandParams; + +@compute @workgroup_size(256) +fn rand_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < rand_params.numel) { + var state = pcg_init(rand_params.seed, idx); + let value = pcg_uniform(&state); + rand_out[idx] = f32(value); + } +} diff --git a/src/runtime/wgpu/shaders/randint_i32.wgsl b/src/runtime/wgpu/shaders/randint_i32.wgsl new file mode 100644 index 00000000..4028687d --- /dev/null +++ b/src/runtime/wgpu/shaders/randint_i32.wgsl @@ -0,0 +1,54 @@ +// Auto-generated randint operation for i32 (signed) + +// PCG hash function for random number generation +// Based on PCG Random Number Generation by Melissa O'Neill +fn pcg_hash(input: u32) -> u32 { + var state = input * 747796405u + 2891336453u; + var word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u; + return (word >> 22u) ^ word; +} + +// Initialize PCG state from seed and index +fn pcg_init(seed: u32, idx: u32) -> u32 { + return pcg_hash(seed ^ pcg_hash(idx)); +} + +// Generate uniform float in [0, 1) +fn pcg_uniform(state: ptr) -> f32 { + *state = pcg_hash(*state); + return f32(*state) / 4294967296.0; // Divide by 2^32 +} + +// Box-Muller transform for normal distribution +// Generates one normal value, requires two uniform values +fn box_muller(u1: f32, u2: f32) -> f32 { + let u1_safe = max(u1, 0.0000001); // Avoid log(0) + let r = sqrt(-2.0 * log(u1_safe)); + let theta = 6.28318530718 * u2; // 2 * PI + return r * cos(theta); +} + +const WORKGROUP_SIZE: u32 = 256u; + +struct RandintParams { + numel: u32, + low: i32, // Low bound as signed integer + range: u32, // high - low (always positive, fits in u32) + seed: u32, +} + +@group(0) @binding(0) var randint_out: array; +@group(0) @binding(1) var randint_params: RandintParams; + +@compute @workgroup_size(256) +fn randint_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < randint_params.numel) { + var state = pcg_init(randint_params.seed, idx); + let r = pcg_hash(state); + // Compute offset in unsigned space, then add to signed low + let offset = r % randint_params.range; + // Safe: offset < range, so low + offset won't overflow if inputs are valid + randint_out[idx] = randint_params.low + i32(offset); + } +} diff --git a/src/runtime/wgpu/shaders/randint_u32.wgsl b/src/runtime/wgpu/shaders/randint_u32.wgsl new file mode 100644 index 00000000..f75e9e65 --- /dev/null +++ b/src/runtime/wgpu/shaders/randint_u32.wgsl @@ -0,0 +1,53 @@ +// Auto-generated randint operation for u32 (unsigned) + +// PCG hash function for random number generation +// Based on PCG Random Number Generation by Melissa O'Neill +fn pcg_hash(input: u32) -> u32 { + var state = input * 747796405u + 2891336453u; + var word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u; + return (word >> 22u) ^ word; +} + +// Initialize PCG state from seed and index +fn pcg_init(seed: u32, idx: u32) -> u32 { + return pcg_hash(seed ^ pcg_hash(idx)); +} + +// Generate uniform float in [0, 1) +fn pcg_uniform(state: ptr) -> f32 { + *state = pcg_hash(*state); + return f32(*state) / 4294967296.0; // Divide by 2^32 +} + +// Box-Muller transform for normal distribution +// Generates one normal value, requires two uniform values +fn box_muller(u1: f32, u2: f32) -> f32 { + let u1_safe = max(u1, 0.0000001); // Avoid log(0) + let r = sqrt(-2.0 * log(u1_safe)); + let theta = 6.28318530718 * u2; // 2 * PI + return r * cos(theta); +} + +const WORKGROUP_SIZE: u32 = 256u; + +struct RandintParams { + numel: u32, + low: u32, // Low bound as unsigned integer + range: u32, // high - low + seed: u32, +} + +@group(0) @binding(0) var randint_out: array; +@group(0) @binding(1) var randint_params: RandintParams; + +@compute @workgroup_size(256) +fn randint_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < randint_params.numel) { + var state = pcg_init(randint_params.seed, idx); + let r = pcg_hash(state); + // Pure unsigned arithmetic - no overflow for valid inputs + let offset = r % randint_params.range; + randint_out[idx] = randint_params.low + offset; + } +} diff --git a/src/runtime/wgpu/shaders/randn_f32.wgsl b/src/runtime/wgpu/shaders/randn_f32.wgsl new file mode 100644 index 00000000..d6c54af6 --- /dev/null +++ b/src/runtime/wgpu/shaders/randn_f32.wgsl @@ -0,0 +1,54 @@ +// Auto-generated randn operation for f32 + +// PCG hash function for random number generation +// Based on PCG Random Number Generation by Melissa O'Neill +fn pcg_hash(input: u32) -> u32 { + var state = input * 747796405u + 2891336453u; + var word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u; + return (word >> 22u) ^ word; +} + +// Initialize PCG state from seed and index +fn pcg_init(seed: u32, idx: u32) -> u32 { + return pcg_hash(seed ^ pcg_hash(idx)); +} + +// Generate uniform float in [0, 1) +fn pcg_uniform(state: ptr) -> f32 { + *state = pcg_hash(*state); + return f32(*state) / 4294967296.0; // Divide by 2^32 +} + +// Box-Muller transform for normal distribution +// Generates one normal value, requires two uniform values +fn box_muller(u1: f32, u2: f32) -> f32 { + let u1_safe = max(u1, 0.0000001); // Avoid log(0) + let r = sqrt(-2.0 * log(u1_safe)); + let theta = 6.28318530718 * u2; // 2 * PI + return r * cos(theta); +} + +const WORKGROUP_SIZE: u32 = 256u; + +struct RandnParams { + numel: u32, + seed: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var randn_out: array; +@group(0) @binding(1) var randn_params: RandnParams; + +@compute @workgroup_size(256) +fn randn_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < randn_params.numel) { + // Use two uniform random values for Box-Muller + var state = pcg_init(randn_params.seed, idx); + let u1 = pcg_uniform(&state); + let u2 = pcg_uniform(&state); + let value = box_muller(u1, u2); + randn_out[idx] = f32(value); + } +} diff --git a/src/runtime/wgpu/shaders/real_complex64.wgsl b/src/runtime/wgpu/shaders/real_complex64.wgsl new file mode 100644 index 00000000..33763e65 --- /dev/null +++ b/src/runtime/wgpu/shaders/real_complex64.wgsl @@ -0,0 +1,18 @@ +// Complex real-part extraction shader +// entry point: real_complex64 + +struct Params { + numel: u32, +} + +@group(0) @binding(0) var input: array>; +@group(0) @binding(1) var output: array; +@group(0) @binding(2) var params: Params; + +@compute @workgroup_size(256) +fn real_complex64(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < params.numel) { + output[idx] = input[idx].x; // Extract real component + } +} diff --git a/src/runtime/wgpu/shaders/reduce.rs b/src/runtime/wgpu/shaders/reduce.rs index b9858fc6..976cd020 100644 --- a/src/runtime/wgpu/shaders/reduce.rs +++ b/src/runtime/wgpu/shaders/reduce.rs @@ -1,141 +1,22 @@ -//! Reduction WGSL kernel launchers -//! -//! Provides launchers for reduction operations including: -//! - Sum, Mean, Max, Min, Prod, Any, All reductions along specified dimensions -//! - Argmax, Argmin (returns indices) -//! - Softmax (numerically stable) -//! -//! Multi-dtype support: F32, I32, U32 (F16 requires shader-f16 extension) -//! All operations run entirely on GPU with no CPU fallback. - -use std::collections::HashMap; -use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; - -// ============================================================================ -// Lock Helpers (Handle Poisoned Locks Gracefully) -// ============================================================================ - -/// Acquire read lock, recovering from poison if necessary. -fn read_lock(lock: &RwLock) -> RwLockReadGuard<'_, T> { - lock.read().unwrap_or_else(|poisoned| poisoned.into_inner()) -} - -/// Acquire write lock, recovering from poison if necessary. -fn write_lock(lock: &RwLock) -> RwLockWriteGuard<'_, T> { - lock.write() - .unwrap_or_else(|poisoned| poisoned.into_inner()) -} +//! Reduction WGSL kernel launchers. F32, I32, U32. use wgpu::{Buffer, Queue}; use super::pipeline::{LayoutKey, PipelineCache, workgroup_count}; -use super::reduce_wgsl::{ - REDUCE_SHADER, generate_reduce_shader, is_float_only_op, is_supported_dtype, -}; use crate::dtype::DType; use crate::error::{Error, Result}; -// ============================================================================ -// Shader Module Cache -// ============================================================================ - -/// Cache for dtype-specific shader modules -/// Key: (dtype suffix), Value: generated shader source -static SHADER_CACHE: RwLock>> = RwLock::new(None); - -/// Get or generate shader for a specific dtype -fn get_shader_for_dtype(dtype: DType) -> String { - // Check cache first - { - let cache = read_lock(&SHADER_CACHE); - if let Some(ref map) = *cache - && let Some(shader) = map.get(&dtype) - { - return shader.clone(); - } - } - - // Generate and cache - let shader = generate_reduce_shader(dtype); - { - let mut cache = write_lock(&SHADER_CACHE); - let map = cache.get_or_insert_with(HashMap::new); - map.insert(dtype, shader.clone()); - } - shader -} - -/// Get the module key for a dtype -fn module_key(dtype: DType) -> String { - match dtype { - DType::F32 => "reduce_f32".to_string(), - DType::I32 => "reduce_i32".to_string(), - DType::U32 => "reduce_u32".to_string(), - _ => "reduce_f32".to_string(), // Fallback - } -} - -// ============================================================================ -// Helper Functions -// ============================================================================ - -/// Check if dtype is supported, returning appropriate error if not -fn check_dtype_supported(dtype: DType, op: &'static str) -> Result<()> { - if !is_supported_dtype(dtype) { - return Err(Error::UnsupportedDType { dtype, op }); - } - // Float-only operations (mean, softmax) require F32 - if is_float_only_op(op) && dtype != DType::F32 { - return Err(Error::UnsupportedDType { dtype, op }); - } - Ok(()) -} - -/// Get entry point name for reduce operation -fn reduce_entry_point(op: &str, dtype: DType) -> String { - let suffix = match dtype { - DType::F32 => "f32", - DType::I32 => "i32", - DType::U32 => "u32", - _ => "f32", - }; - format!("reduce_{}_{}", op, suffix) -} - -/// Get entry point name for full reduce operation -fn full_reduce_entry_point(op: &str, dtype: DType) -> String { - let suffix = match dtype { - DType::F32 => "f32", - DType::I32 => "i32", - DType::U32 => "u32", - _ => "f32", - }; - format!("full_reduce_{}_{}", op, suffix) -} - -/// Get entry point name for argreduce operation -fn argreduce_entry_point(op: &str, dtype: DType) -> String { - let suffix = match dtype { - DType::F32 => "f32", - DType::I32 => "i32", - DType::U32 => "u32", - _ => "f32", - }; - format!("{}_{}", op, suffix) -} +const REDUCE_F32_SHADER: &str = include_str!("reduce.wgsl"); +const REDUCE_I32_SHADER: &str = include_str!("reduce_i32.wgsl"); +const REDUCE_U32_SHADER: &str = include_str!("reduce_u32.wgsl"); // ============================================================================ // Single-Dimension Reduction // ============================================================================ -/// Launch a reduction operation kernel along a single dimension. -/// -/// Supports F32, I32, U32 dtypes. Mean is F32-only. +/// Launch a reduction operation along a single dimension. F32, I32, U32. /// -/// Parameters: -/// - reduce_size: Size of the dimension being reduced -/// - outer_size: Product of dimensions before the reduce dimension -/// - inner_size: Product of dimensions after the reduce dimension +/// Supported ops: "sum", "mean" (F32 only), "max", "min", "prod", "any", "all" pub fn launch_reduce_op( cache: &PipelineCache, queue: &Queue, @@ -146,41 +27,37 @@ pub fn launch_reduce_op( numel_out: usize, dtype: DType, ) -> Result<()> { - check_dtype_supported(dtype, op)?; + // mean is F32-only + if op == "mean" && dtype != DType::F32 { + return Err(Error::UnsupportedDType { dtype, op }); + } - let entry_point = reduce_entry_point(op, dtype); - // Leak entry_point to get static reference (cached, so leak is acceptable) - let static_entry_point: &'static str = Box::leak(entry_point.into_boxed_str()); + let (module_key, shader, suffix) = match dtype { + DType::F32 => ("reduce_f32", REDUCE_F32_SHADER, "f32"), + DType::I32 => ("reduce_i32", REDUCE_I32_SHADER, "i32"), + DType::U32 => ("reduce_u32", REDUCE_U32_SHADER, "u32"), + _ => return Err(Error::UnsupportedDType { dtype, op }), + }; - // Use F32 shader for backward compatibility, or dtype-specific for I32/U32 - let (module_name, shader_source): (&str, &str) = if dtype == DType::F32 { - ("reduce", REDUCE_SHADER) - } else { - // For I32/U32, we need to use the generated shader - // But since we can't easily pass owned String to get_or_create_module, - // we'll use a static approach with leaked strings (acceptable for caching) - let shader = get_shader_for_dtype(dtype); - let key = module_key(dtype); - // Leak the strings to get static references (these are cached, so leak is acceptable) - let static_key: &'static str = Box::leak(key.into_boxed_str()); - let static_shader: &'static str = Box::leak(shader.into_boxed_str()); - (static_key, static_shader) + let entry_point: String = match op { + "sum" | "mean" | "max" | "min" | "prod" | "any" | "all" => { + format!("reduce_{}_{}", op, suffix) + } + _ => return Err(Error::Internal(format!("Unknown reduce op: {}", op))), }; - let module = cache.get_or_create_module(module_name, shader_source); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(module_name, static_entry_point, &module, &layout); - + let pipeline = cache.get_or_create_dynamic_pipeline(module_key, &entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[input, output, params_buffer]); let mut encoder = cache .device() .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(op) }); - { let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some(op), @@ -188,10 +65,8 @@ pub fn launch_reduce_op( }); pass.set_pipeline(&pipeline); pass.set_bind_group(0, Some(&bind_group), &[]); - // One workgroup per output element pass.dispatch_workgroups(numel_out as u32, 1, 1); } - queue.submit(std::iter::once(encoder.finish())); Ok(()) } @@ -200,10 +75,9 @@ pub fn launch_reduce_op( // Full Reduction (all elements to single value) // ============================================================================ -/// Launch a full reduction operation kernel. +/// Launch a full reduction kernel (reduce all elements). F32, I32, U32. /// -/// Supports F32, I32, U32 dtypes. -/// Reduces all elements to a single value using two-pass reduction. +/// Supported ops: "sum", "max", "min", "prod" pub fn launch_full_reduce_op( cache: &PipelineCache, queue: &Queue, @@ -214,36 +88,30 @@ pub fn launch_full_reduce_op( numel: usize, dtype: DType, ) -> Result<()> { - check_dtype_supported(dtype, op)?; - - let entry_point = full_reduce_entry_point(op, dtype); - // Leak entry_point to get static reference (cached, so leak is acceptable) - let static_entry_point: &'static str = Box::leak(entry_point.into_boxed_str()); + let (module_key, shader, suffix) = match dtype { + DType::F32 => ("reduce_f32", REDUCE_F32_SHADER, "f32"), + DType::I32 => ("reduce_i32", REDUCE_I32_SHADER, "i32"), + DType::U32 => ("reduce_u32", REDUCE_U32_SHADER, "u32"), + _ => return Err(Error::UnsupportedDType { dtype, op }), + }; - let (module_name, shader_source): (&str, &str) = if dtype == DType::F32 { - ("reduce", REDUCE_SHADER) - } else { - let shader = get_shader_for_dtype(dtype); - let key = module_key(dtype); - let static_key: &'static str = Box::leak(key.into_boxed_str()); - let static_shader: &'static str = Box::leak(shader.into_boxed_str()); - (static_key, static_shader) + let entry_point: String = match op { + "sum" | "max" | "min" | "prod" => format!("full_reduce_{}_{}", op, suffix), + _ => return Err(Error::Internal(format!("Unknown full reduce op: {}", op))), }; - let module = cache.get_or_create_module(module_name, shader_source); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(module_name, static_entry_point, &module, &layout); - + let pipeline = cache.get_or_create_dynamic_pipeline(module_key, &entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[input, output, params_buffer]); let mut encoder = cache .device() .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(op) }); - { let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some(op), @@ -251,10 +119,8 @@ pub fn launch_full_reduce_op( }); pass.set_pipeline(&pipeline); pass.set_bind_group(0, Some(&bind_group), &[]); - // Use enough workgroups to cover all elements pass.dispatch_workgroups(workgroup_count(numel), 1, 1); } - queue.submit(std::iter::once(encoder.finish())); Ok(()) } @@ -263,10 +129,9 @@ pub fn launch_full_reduce_op( // Argmax / Argmin // ============================================================================ -/// Launch argmax/argmin operation kernel. +/// Launch argmax/argmin kernel. F32, I32, U32. /// -/// Supports F32, I32, U32 dtypes. -/// Returns indices of max/min values along specified dimension. +/// Supported ops: "argmax", "argmin" pub fn launch_argreduce_op( cache: &PipelineCache, queue: &Queue, @@ -277,36 +142,30 @@ pub fn launch_argreduce_op( numel_out: usize, dtype: DType, ) -> Result<()> { - check_dtype_supported(dtype, op)?; - - let entry_point = argreduce_entry_point(op, dtype); - // Leak entry_point to get static reference (cached, so leak is acceptable) - let static_entry_point: &'static str = Box::leak(entry_point.into_boxed_str()); + let (module_key, shader, suffix) = match dtype { + DType::F32 => ("reduce_f32", REDUCE_F32_SHADER, "f32"), + DType::I32 => ("reduce_i32", REDUCE_I32_SHADER, "i32"), + DType::U32 => ("reduce_u32", REDUCE_U32_SHADER, "u32"), + _ => return Err(Error::UnsupportedDType { dtype, op }), + }; - let (module_name, shader_source): (&str, &str) = if dtype == DType::F32 { - ("reduce", REDUCE_SHADER) - } else { - let shader = get_shader_for_dtype(dtype); - let key = module_key(dtype); - let static_key: &'static str = Box::leak(key.into_boxed_str()); - let static_shader: &'static str = Box::leak(shader.into_boxed_str()); - (static_key, static_shader) + let entry_point: String = match op { + "argmax" | "argmin" => format!("{}_{}", op, suffix), + _ => return Err(Error::Internal(format!("Unknown argreduce op: {}", op))), }; - let module = cache.get_or_create_module(module_name, shader_source); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(module_name, static_entry_point, &module, &layout); - + let pipeline = cache.get_or_create_dynamic_pipeline(module_key, &entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[input, output, params_buffer]); let mut encoder = cache .device() .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(op) }); - { let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some(op), @@ -314,10 +173,8 @@ pub fn launch_argreduce_op( }); pass.set_pipeline(&pipeline); pass.set_bind_group(0, Some(&bind_group), &[]); - // One workgroup per output element pass.dispatch_workgroups(numel_out as u32, 1, 1); } - queue.submit(std::iter::once(encoder.finish())); Ok(()) } @@ -326,10 +183,7 @@ pub fn launch_argreduce_op( // Softmax // ============================================================================ -/// Launch softmax operation kernel. -/// -/// F32 only - softmax is a floating-point operation. -/// Computes numerically stable softmax over the last dimension. +/// Launch softmax kernel. F32 only. pub fn launch_softmax_op( cache: &PipelineCache, queue: &Queue, @@ -339,16 +193,20 @@ pub fn launch_softmax_op( batch_size: usize, dtype: DType, ) -> Result<()> { - check_dtype_supported(dtype, "softmax")?; + if dtype != DType::F32 { + return Err(Error::UnsupportedDType { + dtype, + op: "softmax", + }); + } - let module = cache.get_or_create_module("reduce", REDUCE_SHADER); + let module = cache.get_or_create_module("reduce_f32", REDUCE_F32_SHADER); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline("reduce", "softmax_f32", &module, &layout); - + let pipeline = cache.get_or_create_pipeline("reduce_f32", "softmax_f32", &module, &layout); let bind_group = cache.create_bind_group(&layout, &[input, output, params_buffer]); let mut encoder = cache @@ -356,7 +214,6 @@ pub fn launch_softmax_op( .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("softmax"), }); - { let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some("softmax"), @@ -364,10 +221,56 @@ pub fn launch_softmax_op( }); pass.set_pipeline(&pipeline); pass.set_bind_group(0, Some(&bind_group), &[]); - // One workgroup per batch element pass.dispatch_workgroups(batch_size as u32, 1, 1); } + queue.submit(std::iter::once(encoder.finish())); + Ok(()) +} +/// Launch softmax backward kernel. F32 only. +/// +/// d_input = output * (grad - sum(grad * output)) +pub fn launch_softmax_bwd_op( + cache: &PipelineCache, + queue: &Queue, + grad: &Buffer, + output: &Buffer, + d_input: &Buffer, + params_buffer: &Buffer, + batch_size: usize, + dtype: DType, +) -> Result<()> { + if dtype != DType::F32 { + return Err(Error::UnsupportedDType { + dtype, + op: "softmax_bwd", + }); + } + + let module = cache.get_or_create_module("reduce_f32", REDUCE_F32_SHADER); + // 2 read-only storage (grad, output) + 1 read-write (d_input) + 1 uniform + let layout = cache.get_or_create_layout(LayoutKey { + num_storage_buffers: 3, + num_uniform_buffers: 1, + num_readonly_storage: 2, + }); + let pipeline = cache.get_or_create_pipeline("reduce_f32", "softmax_bwd_f32", &module, &layout); + let bind_group = cache.create_bind_group(&layout, &[grad, output, d_input, params_buffer]); + + let mut encoder = cache + .device() + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("softmax_bwd"), + }); + { + let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some("softmax_bwd"), + timestamp_writes: None, + }); + pass.set_pipeline(&pipeline); + pass.set_bind_group(0, Some(&bind_group), &[]); + pass.dispatch_workgroups(batch_size as u32, 1, 1); + } queue.submit(std::iter::once(encoder.finish())); Ok(()) } diff --git a/src/runtime/wgpu/shaders/reduce.wgsl b/src/runtime/wgpu/shaders/reduce.wgsl new file mode 100644 index 00000000..c17ac618 --- /dev/null +++ b/src/runtime/wgpu/shaders/reduce.wgsl @@ -0,0 +1,749 @@ +// Reduction operations. F32 only. +// Entry points: reduce_sum_f32, reduce_mean_f32, reduce_max_f32, reduce_min_f32, +// reduce_prod_f32, reduce_any_f32, reduce_all_f32, +// full_reduce_sum_f32, full_reduce_max_f32, full_reduce_min_f32, full_reduce_prod_f32, +// argmax_f32, argmin_f32, softmax_f32 + +// ============================================================================ +// Workgroup Configuration +// ============================================================================ + +const WORKGROUP_SIZE: u32 = 256u; + +// Shared memory for parallel reduction +var reduce_shared: array; + +// ============================================================================ +// Reduction Parameters +// ============================================================================ + +struct ReduceParams { + reduce_size: u32, + outer_size: u32, + inner_size: u32, + numel_out: u32, +} + +@group(0) @binding(0) var reduce_input: array; +@group(0) @binding(1) var reduce_output: array; +@group(0) @binding(2) var reduce_params: ReduceParams; + +// ============================================================================ +// Sum Reduction +// ============================================================================ + +@compute @workgroup_size(256) +fn reduce_sum_f32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let tid = local_id.x; + let output_idx = group_id.x; + + if (output_idx >= reduce_params.numel_out) { + return; + } + + let reduce_size = reduce_params.reduce_size; + let inner_size = reduce_params.inner_size; + + let outer = output_idx / inner_size; + let inner = output_idx % inner_size; + let base_offset = outer * reduce_size * inner_size + inner; + + var sum: f32 = 0.0; + var i: u32 = tid; + while (i < reduce_size) { + let input_idx = base_offset + i * inner_size; + sum = sum + reduce_input[input_idx]; + i = i + WORKGROUP_SIZE; + } + + reduce_shared[tid] = sum; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { + reduce_shared[tid] = reduce_shared[tid] + reduce_shared[tid + s]; + } + workgroupBarrier(); + } + + if (tid == 0u) { + reduce_output[output_idx] = reduce_shared[0]; + } +} + +// ============================================================================ +// Mean Reduction +// ============================================================================ + +@compute @workgroup_size(256) +fn reduce_mean_f32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let tid = local_id.x; + let output_idx = group_id.x; + + if (output_idx >= reduce_params.numel_out) { + return; + } + + let reduce_size = reduce_params.reduce_size; + let inner_size = reduce_params.inner_size; + + let outer = output_idx / inner_size; + let inner = output_idx % inner_size; + let base_offset = outer * reduce_size * inner_size + inner; + + var sum: f32 = 0.0; + var i: u32 = tid; + while (i < reduce_size) { + let input_idx = base_offset + i * inner_size; + sum = sum + reduce_input[input_idx]; + i = i + WORKGROUP_SIZE; + } + + reduce_shared[tid] = sum; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { + reduce_shared[tid] = reduce_shared[tid] + reduce_shared[tid + s]; + } + workgroupBarrier(); + } + + if (tid == 0u) { + reduce_output[output_idx] = reduce_shared[0] / f32(reduce_size); + } +} + +// ============================================================================ +// Max Reduction +// ============================================================================ + +@compute @workgroup_size(256) +fn reduce_max_f32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let tid = local_id.x; + let output_idx = group_id.x; + + if (output_idx >= reduce_params.numel_out) { + return; + } + + let reduce_size = reduce_params.reduce_size; + let inner_size = reduce_params.inner_size; + + let outer = output_idx / inner_size; + let inner = output_idx % inner_size; + let base_offset = outer * reduce_size * inner_size + inner; + + var max_val: f32 = -3.40282346638528859812e+38; + var i: u32 = tid; + while (i < reduce_size) { + let input_idx = base_offset + i * inner_size; + max_val = max(max_val, reduce_input[input_idx]); + i = i + WORKGROUP_SIZE; + } + + reduce_shared[tid] = max_val; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { + reduce_shared[tid] = max(reduce_shared[tid], reduce_shared[tid + s]); + } + workgroupBarrier(); + } + + if (tid == 0u) { + reduce_output[output_idx] = reduce_shared[0]; + } +} + +// ============================================================================ +// Min Reduction +// ============================================================================ + +@compute @workgroup_size(256) +fn reduce_min_f32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let tid = local_id.x; + let output_idx = group_id.x; + + if (output_idx >= reduce_params.numel_out) { + return; + } + + let reduce_size = reduce_params.reduce_size; + let inner_size = reduce_params.inner_size; + + let outer = output_idx / inner_size; + let inner = output_idx % inner_size; + let base_offset = outer * reduce_size * inner_size + inner; + + var min_val: f32 = 3.40282346638528859812e+38; + var i: u32 = tid; + while (i < reduce_size) { + let input_idx = base_offset + i * inner_size; + min_val = min(min_val, reduce_input[input_idx]); + i = i + WORKGROUP_SIZE; + } + + reduce_shared[tid] = min_val; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { + reduce_shared[tid] = min(reduce_shared[tid], reduce_shared[tid + s]); + } + workgroupBarrier(); + } + + if (tid == 0u) { + reduce_output[output_idx] = reduce_shared[0]; + } +} + +// ============================================================================ +// Product Reduction +// ============================================================================ + +@compute @workgroup_size(256) +fn reduce_prod_f32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let tid = local_id.x; + let output_idx = group_id.x; + + if (output_idx >= reduce_params.numel_out) { + return; + } + + let reduce_size = reduce_params.reduce_size; + let inner_size = reduce_params.inner_size; + + let outer = output_idx / inner_size; + let inner = output_idx % inner_size; + let base_offset = outer * reduce_size * inner_size + inner; + + var prod: f32 = 1.0; + var i: u32 = tid; + while (i < reduce_size) { + let input_idx = base_offset + i * inner_size; + prod = prod * reduce_input[input_idx]; + i = i + WORKGROUP_SIZE; + } + + reduce_shared[tid] = prod; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { + reduce_shared[tid] = reduce_shared[tid] * reduce_shared[tid + s]; + } + workgroupBarrier(); + } + + if (tid == 0u) { + reduce_output[output_idx] = reduce_shared[0]; + } +} + +// ============================================================================ +// Any Reduction (returns 1.0 if any element is non-zero, 0.0 otherwise) +// ============================================================================ + +@compute @workgroup_size(256) +fn reduce_any_f32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let tid = local_id.x; + let output_idx = group_id.x; + + if (output_idx >= reduce_params.numel_out) { + return; + } + + let reduce_size = reduce_params.reduce_size; + let inner_size = reduce_params.inner_size; + + let outer = output_idx / inner_size; + let inner = output_idx % inner_size; + let base_offset = outer * reduce_size * inner_size + inner; + + var found_nonzero: f32 = 0.0; + var i: u32 = tid; + while (i < reduce_size) { + let input_idx = base_offset + i * inner_size; + if (reduce_input[input_idx] != 0.0) { + found_nonzero = 1.0; + } + i = i + WORKGROUP_SIZE; + } + + reduce_shared[tid] = found_nonzero; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { + reduce_shared[tid] = max(reduce_shared[tid], reduce_shared[tid + s]); + } + workgroupBarrier(); + } + + if (tid == 0u) { + reduce_output[output_idx] = reduce_shared[0]; + } +} + +// ============================================================================ +// All Reduction (returns 1.0 if all elements are non-zero, 0.0 otherwise) +// ============================================================================ + +@compute @workgroup_size(256) +fn reduce_all_f32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let tid = local_id.x; + let output_idx = group_id.x; + + if (output_idx >= reduce_params.numel_out) { + return; + } + + let reduce_size = reduce_params.reduce_size; + let inner_size = reduce_params.inner_size; + + let outer = output_idx / inner_size; + let inner = output_idx % inner_size; + let base_offset = outer * reduce_size * inner_size + inner; + + var all_nonzero: f32 = 1.0; + var i: u32 = tid; + while (i < reduce_size) { + let input_idx = base_offset + i * inner_size; + if (reduce_input[input_idx] == 0.0) { + all_nonzero = 0.0; + } + i = i + WORKGROUP_SIZE; + } + + reduce_shared[tid] = all_nonzero; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { + reduce_shared[tid] = min(reduce_shared[tid], reduce_shared[tid + s]); + } + workgroupBarrier(); + } + + if (tid == 0u) { + reduce_output[output_idx] = reduce_shared[0]; + } +} + +// ============================================================================ +// Full Reduction (reduce all elements to single value) +// ============================================================================ + +struct FullReduceParams { + numel: u32, +} + +@group(0) @binding(0) var full_reduce_input: array; +@group(0) @binding(1) var full_reduce_output: array; +@group(0) @binding(2) var full_reduce_params: FullReduceParams; + +@compute @workgroup_size(256) +fn full_reduce_sum_f32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3, + @builtin(num_workgroups) num_groups: vec3) { + let tid = local_id.x; + let wid = group_id.x; + let numel = full_reduce_params.numel; + + var sum: f32 = 0.0; + var i: u32 = wid * WORKGROUP_SIZE + tid; + let stride = num_groups.x * WORKGROUP_SIZE; + + while (i < numel) { + sum = sum + full_reduce_input[i]; + i = i + stride; + } + + reduce_shared[tid] = sum; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { + reduce_shared[tid] = reduce_shared[tid] + reduce_shared[tid + s]; + } + workgroupBarrier(); + } + + if (tid == 0u) { + full_reduce_output[wid] = reduce_shared[0]; + } +} + +@compute @workgroup_size(256) +fn full_reduce_max_f32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3, + @builtin(num_workgroups) num_groups: vec3) { + let tid = local_id.x; + let wid = group_id.x; + let numel = full_reduce_params.numel; + + var max_val: f32 = -3.40282346638528859812e+38; + var i: u32 = wid * WORKGROUP_SIZE + tid; + let stride = num_groups.x * WORKGROUP_SIZE; + + while (i < numel) { + max_val = max(max_val, full_reduce_input[i]); + i = i + stride; + } + + reduce_shared[tid] = max_val; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { + reduce_shared[tid] = max(reduce_shared[tid], reduce_shared[tid + s]); + } + workgroupBarrier(); + } + + if (tid == 0u) { + full_reduce_output[wid] = reduce_shared[0]; + } +} + +@compute @workgroup_size(256) +fn full_reduce_min_f32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3, + @builtin(num_workgroups) num_groups: vec3) { + let tid = local_id.x; + let wid = group_id.x; + let numel = full_reduce_params.numel; + + var min_val: f32 = 3.40282346638528859812e+38; + var i: u32 = wid * WORKGROUP_SIZE + tid; + let stride = num_groups.x * WORKGROUP_SIZE; + + while (i < numel) { + min_val = min(min_val, full_reduce_input[i]); + i = i + stride; + } + + reduce_shared[tid] = min_val; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { + reduce_shared[tid] = min(reduce_shared[tid], reduce_shared[tid + s]); + } + workgroupBarrier(); + } + + if (tid == 0u) { + full_reduce_output[wid] = reduce_shared[0]; + } +} + +@compute @workgroup_size(256) +fn full_reduce_prod_f32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3, + @builtin(num_workgroups) num_groups: vec3) { + let tid = local_id.x; + let wid = group_id.x; + let numel = full_reduce_params.numel; + + var prod: f32 = 1.0; + var i: u32 = wid * WORKGROUP_SIZE + tid; + let stride = num_groups.x * WORKGROUP_SIZE; + + while (i < numel) { + prod = prod * full_reduce_input[i]; + i = i + stride; + } + + reduce_shared[tid] = prod; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { + reduce_shared[tid] = reduce_shared[tid] * reduce_shared[tid + s]; + } + workgroupBarrier(); + } + + if (tid == 0u) { + full_reduce_output[wid] = reduce_shared[0]; + } +} + +// ============================================================================ +// Argmax / Argmin (returns index of max/min value) +// ============================================================================ + +var argmax_shared_val: array; +var argmax_shared_idx: array; + +struct ArgReduceParams { + reduce_size: u32, + outer_size: u32, + inner_size: u32, + numel_out: u32, +} + +@group(0) @binding(0) var argreduce_input: array; +@group(0) @binding(1) var argreduce_output: array; +@group(0) @binding(2) var argreduce_params: ArgReduceParams; + +@compute @workgroup_size(256) +fn argmax_f32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let tid = local_id.x; + let output_idx = group_id.x; + + if (output_idx >= argreduce_params.numel_out) { + return; + } + + let reduce_size = argreduce_params.reduce_size; + let inner_size = argreduce_params.inner_size; + + let outer = output_idx / inner_size; + let inner = output_idx % inner_size; + let base_offset = outer * reduce_size * inner_size + inner; + + var max_val: f32 = -3.40282346638528859812e+38; + var max_idx: u32 = 0u; + var i: u32 = tid; + + while (i < reduce_size) { + let input_idx = base_offset + i * inner_size; + let val = argreduce_input[input_idx]; + if (val > max_val) { + max_val = val; + max_idx = i; + } + i = i + WORKGROUP_SIZE; + } + + argmax_shared_val[tid] = max_val; + argmax_shared_idx[tid] = max_idx; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { + if (argmax_shared_val[tid + s] > argmax_shared_val[tid]) { + argmax_shared_val[tid] = argmax_shared_val[tid + s]; + argmax_shared_idx[tid] = argmax_shared_idx[tid + s]; + } + } + workgroupBarrier(); + } + + if (tid == 0u) { + argreduce_output[output_idx] = argmax_shared_idx[0]; + } +} + +@compute @workgroup_size(256) +fn argmin_f32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let tid = local_id.x; + let output_idx = group_id.x; + + if (output_idx >= argreduce_params.numel_out) { + return; + } + + let reduce_size = argreduce_params.reduce_size; + let inner_size = argreduce_params.inner_size; + + let outer = output_idx / inner_size; + let inner = output_idx % inner_size; + let base_offset = outer * reduce_size * inner_size + inner; + + var min_val: f32 = 3.40282346638528859812e+38; + var min_idx: u32 = 0u; + var i: u32 = tid; + + while (i < reduce_size) { + let input_idx = base_offset + i * inner_size; + let val = argreduce_input[input_idx]; + if (val < min_val) { + min_val = val; + min_idx = i; + } + i = i + WORKGROUP_SIZE; + } + + argmax_shared_val[tid] = min_val; + argmax_shared_idx[tid] = min_idx; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { + if (argmax_shared_val[tid + s] < argmax_shared_val[tid]) { + argmax_shared_val[tid] = argmax_shared_val[tid + s]; + argmax_shared_idx[tid] = argmax_shared_idx[tid + s]; + } + } + workgroupBarrier(); + } + + if (tid == 0u) { + argreduce_output[output_idx] = argmax_shared_idx[0]; + } +} + +// ============================================================================ +// Softmax (numerically stable) +// ============================================================================ + +struct SoftmaxParams { + batch_size: u32, + dim_size: u32, +} + +@group(0) @binding(0) var softmax_input: array; +@group(0) @binding(1) var softmax_output: array; +@group(0) @binding(2) var softmax_params: SoftmaxParams; + +var softmax_shared: array; + +@compute @workgroup_size(256) +fn softmax_f32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let tid = local_id.x; + let batch_idx = group_id.x; + + if (batch_idx >= softmax_params.batch_size) { + return; + } + + let dim_size = softmax_params.dim_size; + let base_offset = batch_idx * dim_size; + + // Step 1: Find max for numerical stability + var max_val: f32 = -3.40282346638528859812e+38; + var i: u32 = tid; + while (i < dim_size) { + max_val = max(max_val, softmax_input[base_offset + i]); + i = i + WORKGROUP_SIZE; + } + + softmax_shared[tid] = max_val; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { + softmax_shared[tid] = max(softmax_shared[tid], softmax_shared[tid + s]); + } + workgroupBarrier(); + } + + let global_max = softmax_shared[0]; + workgroupBarrier(); + + // Step 2: Compute sum of exp(x - max) + var sum: f32 = 0.0; + i = tid; + while (i < dim_size) { + sum = sum + exp(softmax_input[base_offset + i] - global_max); + i = i + WORKGROUP_SIZE; + } + + softmax_shared[tid] = sum; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { + softmax_shared[tid] = softmax_shared[tid] + softmax_shared[tid + s]; + } + workgroupBarrier(); + } + + let global_sum = softmax_shared[0]; + workgroupBarrier(); + + // Step 3: Compute output = exp(x - max) / sum + i = tid; + while (i < dim_size) { + softmax_output[base_offset + i] = exp(softmax_input[base_offset + i] - global_max) / global_sum; + i = i + WORKGROUP_SIZE; + } +} + +// ============================================================================ +// Softmax Backward +// d_input = output * (grad - dot), where dot = sum(grad * output) +// Uses same SoftmaxParams (batch_size, dim_size) +// Bindings: 0=grad(read), 1=output(read), 2=d_input(write), 3=params +// ============================================================================ + +@group(0) @binding(0) var sbwd_grad: array; +@group(0) @binding(1) var sbwd_output: array; +@group(0) @binding(2) var sbwd_d_input: array; +@group(0) @binding(3) var sbwd_params: SoftmaxParams; + +var sbwd_shared: array; + +@compute @workgroup_size(256) +fn softmax_bwd_f32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let tid = local_id.x; + let batch_idx = group_id.x; + + if (batch_idx >= sbwd_params.batch_size) { + return; + } + + let dim_size = sbwd_params.dim_size; + let base_offset = batch_idx * dim_size; + + // Pass 1: dot = sum(grad * output) + var dot: f32 = 0.0; + var i: u32 = tid; + while (i < dim_size) { + dot = dot + sbwd_grad[base_offset + i] * sbwd_output[base_offset + i]; + i = i + WORKGROUP_SIZE; + } + + sbwd_shared[tid] = dot; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { + sbwd_shared[tid] = sbwd_shared[tid] + sbwd_shared[tid + s]; + } + workgroupBarrier(); + } + + let global_dot = sbwd_shared[0]; + workgroupBarrier(); + + // Pass 2: d_input = output * (grad - dot) + i = tid; + while (i < dim_size) { + let idx = base_offset + i; + sbwd_d_input[idx] = sbwd_output[idx] * (sbwd_grad[idx] - global_dot); + i = i + WORKGROUP_SIZE; + } +} diff --git a/src/runtime/wgpu/shaders/reduce_i32.wgsl b/src/runtime/wgpu/shaders/reduce_i32.wgsl new file mode 100644 index 00000000..4f2c62fc --- /dev/null +++ b/src/runtime/wgpu/shaders/reduce_i32.wgsl @@ -0,0 +1,414 @@ +// Reduction operations for I32. +// Entry points: reduce_sum_i32, reduce_max_i32, reduce_min_i32, +// reduce_prod_i32, reduce_any_i32, reduce_all_i32, +// full_reduce_sum_i32, full_reduce_max_i32, full_reduce_min_i32, full_reduce_prod_i32, +// argmax_i32, argmin_i32 + +const WORKGROUP_SIZE: u32 = 256u; + +var reduce_shared: array; + +struct ReduceParams { + reduce_size: u32, + outer_size: u32, + inner_size: u32, + numel_out: u32, +} + +@group(0) @binding(0) var reduce_input: array; +@group(0) @binding(1) var reduce_output: array; +@group(0) @binding(2) var reduce_params: ReduceParams; + +@compute @workgroup_size(256) +fn reduce_sum_i32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let tid = local_id.x; + let output_idx = group_id.x; + if (output_idx >= reduce_params.numel_out) { return; } + + let reduce_size = reduce_params.reduce_size; + let inner_size = reduce_params.inner_size; + let outer = output_idx / inner_size; + let inner = output_idx % inner_size; + let base_offset = outer * reduce_size * inner_size + inner; + + var sum: i32 = 0; + var i: u32 = tid; + while (i < reduce_size) { + sum = sum + reduce_input[base_offset + i * inner_size]; + i = i + WORKGROUP_SIZE; + } + + reduce_shared[tid] = sum; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { reduce_shared[tid] = reduce_shared[tid] + reduce_shared[tid + s]; } + workgroupBarrier(); + } + + if (tid == 0u) { reduce_output[output_idx] = reduce_shared[0]; } +} + +@compute @workgroup_size(256) +fn reduce_max_i32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let tid = local_id.x; + let output_idx = group_id.x; + if (output_idx >= reduce_params.numel_out) { return; } + + let reduce_size = reduce_params.reduce_size; + let inner_size = reduce_params.inner_size; + let outer = output_idx / inner_size; + let inner = output_idx % inner_size; + let base_offset = outer * reduce_size * inner_size + inner; + + var max_val: i32 = (-2147483647i - 1i); + var i: u32 = tid; + while (i < reduce_size) { + max_val = max(max_val, reduce_input[base_offset + i * inner_size]); + i = i + WORKGROUP_SIZE; + } + + reduce_shared[tid] = max_val; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { reduce_shared[tid] = max(reduce_shared[tid], reduce_shared[tid + s]); } + workgroupBarrier(); + } + + if (tid == 0u) { reduce_output[output_idx] = reduce_shared[0]; } +} + +@compute @workgroup_size(256) +fn reduce_min_i32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let tid = local_id.x; + let output_idx = group_id.x; + if (output_idx >= reduce_params.numel_out) { return; } + + let reduce_size = reduce_params.reduce_size; + let inner_size = reduce_params.inner_size; + let outer = output_idx / inner_size; + let inner = output_idx % inner_size; + let base_offset = outer * reduce_size * inner_size + inner; + + var min_val: i32 = 2147483647i; + var i: u32 = tid; + while (i < reduce_size) { + min_val = min(min_val, reduce_input[base_offset + i * inner_size]); + i = i + WORKGROUP_SIZE; + } + + reduce_shared[tid] = min_val; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { reduce_shared[tid] = min(reduce_shared[tid], reduce_shared[tid + s]); } + workgroupBarrier(); + } + + if (tid == 0u) { reduce_output[output_idx] = reduce_shared[0]; } +} + +@compute @workgroup_size(256) +fn reduce_prod_i32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let tid = local_id.x; + let output_idx = group_id.x; + if (output_idx >= reduce_params.numel_out) { return; } + + let reduce_size = reduce_params.reduce_size; + let inner_size = reduce_params.inner_size; + let outer = output_idx / inner_size; + let inner = output_idx % inner_size; + let base_offset = outer * reduce_size * inner_size + inner; + + var prod: i32 = 1; + var i: u32 = tid; + while (i < reduce_size) { + prod = prod * reduce_input[base_offset + i * inner_size]; + i = i + WORKGROUP_SIZE; + } + + reduce_shared[tid] = prod; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { reduce_shared[tid] = reduce_shared[tid] * reduce_shared[tid + s]; } + workgroupBarrier(); + } + + if (tid == 0u) { reduce_output[output_idx] = reduce_shared[0]; } +} + +@compute @workgroup_size(256) +fn reduce_any_i32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let tid = local_id.x; + let output_idx = group_id.x; + if (output_idx >= reduce_params.numel_out) { return; } + + let reduce_size = reduce_params.reduce_size; + let inner_size = reduce_params.inner_size; + let outer = output_idx / inner_size; + let inner = output_idx % inner_size; + let base_offset = outer * reduce_size * inner_size + inner; + + var found_nonzero: i32 = 0; + var i: u32 = tid; + while (i < reduce_size) { + if (reduce_input[base_offset + i * inner_size] != 0) { found_nonzero = 1; } + i = i + WORKGROUP_SIZE; + } + + reduce_shared[tid] = found_nonzero; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { reduce_shared[tid] = max(reduce_shared[tid], reduce_shared[tid + s]); } + workgroupBarrier(); + } + + if (tid == 0u) { reduce_output[output_idx] = reduce_shared[0]; } +} + +@compute @workgroup_size(256) +fn reduce_all_i32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let tid = local_id.x; + let output_idx = group_id.x; + if (output_idx >= reduce_params.numel_out) { return; } + + let reduce_size = reduce_params.reduce_size; + let inner_size = reduce_params.inner_size; + let outer = output_idx / inner_size; + let inner = output_idx % inner_size; + let base_offset = outer * reduce_size * inner_size + inner; + + var all_nonzero: i32 = 1; + var i: u32 = tid; + while (i < reduce_size) { + if (reduce_input[base_offset + i * inner_size] == 0) { all_nonzero = 0; } + i = i + WORKGROUP_SIZE; + } + + reduce_shared[tid] = all_nonzero; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { reduce_shared[tid] = min(reduce_shared[tid], reduce_shared[tid + s]); } + workgroupBarrier(); + } + + if (tid == 0u) { reduce_output[output_idx] = reduce_shared[0]; } +} + +// ============================================================================ +// Full Reduction +// ============================================================================ + +struct FullReduceParams { + numel: u32, +} + +@group(0) @binding(0) var full_reduce_input: array; +@group(0) @binding(1) var full_reduce_output: array; +@group(0) @binding(2) var full_reduce_params: FullReduceParams; + +@compute @workgroup_size(256) +fn full_reduce_sum_i32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3, + @builtin(num_workgroups) num_groups: vec3) { + let tid = local_id.x; + let wid = group_id.x; + let numel = full_reduce_params.numel; + + var sum: i32 = 0; + var i: u32 = wid * WORKGROUP_SIZE + tid; + let stride = num_groups.x * WORKGROUP_SIZE; + while (i < numel) { sum = sum + full_reduce_input[i]; i = i + stride; } + + reduce_shared[tid] = sum; + workgroupBarrier(); + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { reduce_shared[tid] = reduce_shared[tid] + reduce_shared[tid + s]; } + workgroupBarrier(); + } + if (tid == 0u) { full_reduce_output[wid] = reduce_shared[0]; } +} + +@compute @workgroup_size(256) +fn full_reduce_max_i32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3, + @builtin(num_workgroups) num_groups: vec3) { + let tid = local_id.x; + let wid = group_id.x; + let numel = full_reduce_params.numel; + + var max_val: i32 = (-2147483647i - 1i); + var i: u32 = wid * WORKGROUP_SIZE + tid; + let stride = num_groups.x * WORKGROUP_SIZE; + while (i < numel) { max_val = max(max_val, full_reduce_input[i]); i = i + stride; } + + reduce_shared[tid] = max_val; + workgroupBarrier(); + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { reduce_shared[tid] = max(reduce_shared[tid], reduce_shared[tid + s]); } + workgroupBarrier(); + } + if (tid == 0u) { full_reduce_output[wid] = reduce_shared[0]; } +} + +@compute @workgroup_size(256) +fn full_reduce_min_i32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3, + @builtin(num_workgroups) num_groups: vec3) { + let tid = local_id.x; + let wid = group_id.x; + let numel = full_reduce_params.numel; + + var min_val: i32 = 2147483647i; + var i: u32 = wid * WORKGROUP_SIZE + tid; + let stride = num_groups.x * WORKGROUP_SIZE; + while (i < numel) { min_val = min(min_val, full_reduce_input[i]); i = i + stride; } + + reduce_shared[tid] = min_val; + workgroupBarrier(); + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { reduce_shared[tid] = min(reduce_shared[tid], reduce_shared[tid + s]); } + workgroupBarrier(); + } + if (tid == 0u) { full_reduce_output[wid] = reduce_shared[0]; } +} + +@compute @workgroup_size(256) +fn full_reduce_prod_i32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3, + @builtin(num_workgroups) num_groups: vec3) { + let tid = local_id.x; + let wid = group_id.x; + let numel = full_reduce_params.numel; + + var prod: i32 = 1; + var i: u32 = wid * WORKGROUP_SIZE + tid; + let stride = num_groups.x * WORKGROUP_SIZE; + while (i < numel) { prod = prod * full_reduce_input[i]; i = i + stride; } + + reduce_shared[tid] = prod; + workgroupBarrier(); + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { reduce_shared[tid] = reduce_shared[tid] * reduce_shared[tid + s]; } + workgroupBarrier(); + } + if (tid == 0u) { full_reduce_output[wid] = reduce_shared[0]; } +} + +// ============================================================================ +// Argmax / Argmin +// ============================================================================ + +var argmax_shared_val: array; +var argmax_shared_idx: array; + +struct ArgReduceParams { + reduce_size: u32, + outer_size: u32, + inner_size: u32, + numel_out: u32, +} + +@group(0) @binding(0) var argreduce_input: array; +@group(0) @binding(1) var argreduce_output: array; +@group(0) @binding(2) var argreduce_params: ArgReduceParams; + +@compute @workgroup_size(256) +fn argmax_i32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let tid = local_id.x; + let output_idx = group_id.x; + if (output_idx >= argreduce_params.numel_out) { return; } + + let reduce_size = argreduce_params.reduce_size; + let inner_size = argreduce_params.inner_size; + let outer = output_idx / inner_size; + let inner = output_idx % inner_size; + let base_offset = outer * reduce_size * inner_size + inner; + + var max_val: i32 = (-2147483647i - 1i); + var max_idx: u32 = 0u; + var i: u32 = tid; + while (i < reduce_size) { + let val = argreduce_input[base_offset + i * inner_size]; + if (val > max_val) { max_val = val; max_idx = i; } + i = i + WORKGROUP_SIZE; + } + + argmax_shared_val[tid] = max_val; + argmax_shared_idx[tid] = max_idx; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { + if (argmax_shared_val[tid + s] > argmax_shared_val[tid]) { + argmax_shared_val[tid] = argmax_shared_val[tid + s]; + argmax_shared_idx[tid] = argmax_shared_idx[tid + s]; + } + } + workgroupBarrier(); + } + + if (tid == 0u) { argreduce_output[output_idx] = argmax_shared_idx[0]; } +} + +@compute @workgroup_size(256) +fn argmin_i32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let tid = local_id.x; + let output_idx = group_id.x; + if (output_idx >= argreduce_params.numel_out) { return; } + + let reduce_size = argreduce_params.reduce_size; + let inner_size = argreduce_params.inner_size; + let outer = output_idx / inner_size; + let inner = output_idx % inner_size; + let base_offset = outer * reduce_size * inner_size + inner; + + var min_val: i32 = 2147483647i; + var min_idx: u32 = 0u; + var i: u32 = tid; + while (i < reduce_size) { + let val = argreduce_input[base_offset + i * inner_size]; + if (val < min_val) { min_val = val; min_idx = i; } + i = i + WORKGROUP_SIZE; + } + + argmax_shared_val[tid] = min_val; + argmax_shared_idx[tid] = min_idx; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { + if (argmax_shared_val[tid + s] < argmax_shared_val[tid]) { + argmax_shared_val[tid] = argmax_shared_val[tid + s]; + argmax_shared_idx[tid] = argmax_shared_idx[tid + s]; + } + } + workgroupBarrier(); + } + + if (tid == 0u) { argreduce_output[output_idx] = argmax_shared_idx[0]; } +} diff --git a/src/runtime/wgpu/shaders/reduce_u32.wgsl b/src/runtime/wgpu/shaders/reduce_u32.wgsl new file mode 100644 index 00000000..a312eb51 --- /dev/null +++ b/src/runtime/wgpu/shaders/reduce_u32.wgsl @@ -0,0 +1,414 @@ +// Reduction operations for U32. +// Entry points: reduce_sum_u32, reduce_max_u32, reduce_min_u32, +// reduce_prod_u32, reduce_any_u32, reduce_all_u32, +// full_reduce_sum_u32, full_reduce_max_u32, full_reduce_min_u32, full_reduce_prod_u32, +// argmax_u32, argmin_u32 + +const WORKGROUP_SIZE: u32 = 256u; + +var reduce_shared: array; + +struct ReduceParams { + reduce_size: u32, + outer_size: u32, + inner_size: u32, + numel_out: u32, +} + +@group(0) @binding(0) var reduce_input: array; +@group(0) @binding(1) var reduce_output: array; +@group(0) @binding(2) var reduce_params: ReduceParams; + +@compute @workgroup_size(256) +fn reduce_sum_u32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let tid = local_id.x; + let output_idx = group_id.x; + if (output_idx >= reduce_params.numel_out) { return; } + + let reduce_size = reduce_params.reduce_size; + let inner_size = reduce_params.inner_size; + let outer = output_idx / inner_size; + let inner = output_idx % inner_size; + let base_offset = outer * reduce_size * inner_size + inner; + + var sum: u32 = 0u; + var i: u32 = tid; + while (i < reduce_size) { + sum = sum + reduce_input[base_offset + i * inner_size]; + i = i + WORKGROUP_SIZE; + } + + reduce_shared[tid] = sum; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { reduce_shared[tid] = reduce_shared[tid] + reduce_shared[tid + s]; } + workgroupBarrier(); + } + + if (tid == 0u) { reduce_output[output_idx] = reduce_shared[0]; } +} + +@compute @workgroup_size(256) +fn reduce_max_u32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let tid = local_id.x; + let output_idx = group_id.x; + if (output_idx >= reduce_params.numel_out) { return; } + + let reduce_size = reduce_params.reduce_size; + let inner_size = reduce_params.inner_size; + let outer = output_idx / inner_size; + let inner = output_idx % inner_size; + let base_offset = outer * reduce_size * inner_size + inner; + + var max_val: u32 = 0u; + var i: u32 = tid; + while (i < reduce_size) { + max_val = max(max_val, reduce_input[base_offset + i * inner_size]); + i = i + WORKGROUP_SIZE; + } + + reduce_shared[tid] = max_val; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { reduce_shared[tid] = max(reduce_shared[tid], reduce_shared[tid + s]); } + workgroupBarrier(); + } + + if (tid == 0u) { reduce_output[output_idx] = reduce_shared[0]; } +} + +@compute @workgroup_size(256) +fn reduce_min_u32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let tid = local_id.x; + let output_idx = group_id.x; + if (output_idx >= reduce_params.numel_out) { return; } + + let reduce_size = reduce_params.reduce_size; + let inner_size = reduce_params.inner_size; + let outer = output_idx / inner_size; + let inner = output_idx % inner_size; + let base_offset = outer * reduce_size * inner_size + inner; + + var min_val: u32 = 4294967295u; + var i: u32 = tid; + while (i < reduce_size) { + min_val = min(min_val, reduce_input[base_offset + i * inner_size]); + i = i + WORKGROUP_SIZE; + } + + reduce_shared[tid] = min_val; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { reduce_shared[tid] = min(reduce_shared[tid], reduce_shared[tid + s]); } + workgroupBarrier(); + } + + if (tid == 0u) { reduce_output[output_idx] = reduce_shared[0]; } +} + +@compute @workgroup_size(256) +fn reduce_prod_u32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let tid = local_id.x; + let output_idx = group_id.x; + if (output_idx >= reduce_params.numel_out) { return; } + + let reduce_size = reduce_params.reduce_size; + let inner_size = reduce_params.inner_size; + let outer = output_idx / inner_size; + let inner = output_idx % inner_size; + let base_offset = outer * reduce_size * inner_size + inner; + + var prod: u32 = 1u; + var i: u32 = tid; + while (i < reduce_size) { + prod = prod * reduce_input[base_offset + i * inner_size]; + i = i + WORKGROUP_SIZE; + } + + reduce_shared[tid] = prod; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { reduce_shared[tid] = reduce_shared[tid] * reduce_shared[tid + s]; } + workgroupBarrier(); + } + + if (tid == 0u) { reduce_output[output_idx] = reduce_shared[0]; } +} + +@compute @workgroup_size(256) +fn reduce_any_u32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let tid = local_id.x; + let output_idx = group_id.x; + if (output_idx >= reduce_params.numel_out) { return; } + + let reduce_size = reduce_params.reduce_size; + let inner_size = reduce_params.inner_size; + let outer = output_idx / inner_size; + let inner = output_idx % inner_size; + let base_offset = outer * reduce_size * inner_size + inner; + + var found_nonzero: u32 = 0u; + var i: u32 = tid; + while (i < reduce_size) { + if (reduce_input[base_offset + i * inner_size] != 0u) { found_nonzero = 1u; } + i = i + WORKGROUP_SIZE; + } + + reduce_shared[tid] = found_nonzero; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { reduce_shared[tid] = max(reduce_shared[tid], reduce_shared[tid + s]); } + workgroupBarrier(); + } + + if (tid == 0u) { reduce_output[output_idx] = reduce_shared[0]; } +} + +@compute @workgroup_size(256) +fn reduce_all_u32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let tid = local_id.x; + let output_idx = group_id.x; + if (output_idx >= reduce_params.numel_out) { return; } + + let reduce_size = reduce_params.reduce_size; + let inner_size = reduce_params.inner_size; + let outer = output_idx / inner_size; + let inner = output_idx % inner_size; + let base_offset = outer * reduce_size * inner_size + inner; + + var all_nonzero: u32 = 1u; + var i: u32 = tid; + while (i < reduce_size) { + if (reduce_input[base_offset + i * inner_size] == 0u) { all_nonzero = 0u; } + i = i + WORKGROUP_SIZE; + } + + reduce_shared[tid] = all_nonzero; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { reduce_shared[tid] = min(reduce_shared[tid], reduce_shared[tid + s]); } + workgroupBarrier(); + } + + if (tid == 0u) { reduce_output[output_idx] = reduce_shared[0]; } +} + +// ============================================================================ +// Full Reduction +// ============================================================================ + +struct FullReduceParams { + numel: u32, +} + +@group(0) @binding(0) var full_reduce_input: array; +@group(0) @binding(1) var full_reduce_output: array; +@group(0) @binding(2) var full_reduce_params: FullReduceParams; + +@compute @workgroup_size(256) +fn full_reduce_sum_u32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3, + @builtin(num_workgroups) num_groups: vec3) { + let tid = local_id.x; + let wid = group_id.x; + let numel = full_reduce_params.numel; + + var sum: u32 = 0u; + var i: u32 = wid * WORKGROUP_SIZE + tid; + let stride = num_groups.x * WORKGROUP_SIZE; + while (i < numel) { sum = sum + full_reduce_input[i]; i = i + stride; } + + reduce_shared[tid] = sum; + workgroupBarrier(); + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { reduce_shared[tid] = reduce_shared[tid] + reduce_shared[tid + s]; } + workgroupBarrier(); + } + if (tid == 0u) { full_reduce_output[wid] = reduce_shared[0]; } +} + +@compute @workgroup_size(256) +fn full_reduce_max_u32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3, + @builtin(num_workgroups) num_groups: vec3) { + let tid = local_id.x; + let wid = group_id.x; + let numel = full_reduce_params.numel; + + var max_val: u32 = 0u; + var i: u32 = wid * WORKGROUP_SIZE + tid; + let stride = num_groups.x * WORKGROUP_SIZE; + while (i < numel) { max_val = max(max_val, full_reduce_input[i]); i = i + stride; } + + reduce_shared[tid] = max_val; + workgroupBarrier(); + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { reduce_shared[tid] = max(reduce_shared[tid], reduce_shared[tid + s]); } + workgroupBarrier(); + } + if (tid == 0u) { full_reduce_output[wid] = reduce_shared[0]; } +} + +@compute @workgroup_size(256) +fn full_reduce_min_u32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3, + @builtin(num_workgroups) num_groups: vec3) { + let tid = local_id.x; + let wid = group_id.x; + let numel = full_reduce_params.numel; + + var min_val: u32 = 4294967295u; + var i: u32 = wid * WORKGROUP_SIZE + tid; + let stride = num_groups.x * WORKGROUP_SIZE; + while (i < numel) { min_val = min(min_val, full_reduce_input[i]); i = i + stride; } + + reduce_shared[tid] = min_val; + workgroupBarrier(); + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { reduce_shared[tid] = min(reduce_shared[tid], reduce_shared[tid + s]); } + workgroupBarrier(); + } + if (tid == 0u) { full_reduce_output[wid] = reduce_shared[0]; } +} + +@compute @workgroup_size(256) +fn full_reduce_prod_u32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3, + @builtin(num_workgroups) num_groups: vec3) { + let tid = local_id.x; + let wid = group_id.x; + let numel = full_reduce_params.numel; + + var prod: u32 = 1u; + var i: u32 = wid * WORKGROUP_SIZE + tid; + let stride = num_groups.x * WORKGROUP_SIZE; + while (i < numel) { prod = prod * full_reduce_input[i]; i = i + stride; } + + reduce_shared[tid] = prod; + workgroupBarrier(); + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { reduce_shared[tid] = reduce_shared[tid] * reduce_shared[tid + s]; } + workgroupBarrier(); + } + if (tid == 0u) { full_reduce_output[wid] = reduce_shared[0]; } +} + +// ============================================================================ +// Argmax / Argmin +// ============================================================================ + +var argmax_shared_val: array; +var argmax_shared_idx: array; + +struct ArgReduceParams { + reduce_size: u32, + outer_size: u32, + inner_size: u32, + numel_out: u32, +} + +@group(0) @binding(0) var argreduce_input: array; +@group(0) @binding(1) var argreduce_output: array; +@group(0) @binding(2) var argreduce_params: ArgReduceParams; + +@compute @workgroup_size(256) +fn argmax_u32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let tid = local_id.x; + let output_idx = group_id.x; + if (output_idx >= argreduce_params.numel_out) { return; } + + let reduce_size = argreduce_params.reduce_size; + let inner_size = argreduce_params.inner_size; + let outer = output_idx / inner_size; + let inner = output_idx % inner_size; + let base_offset = outer * reduce_size * inner_size + inner; + + var max_val: u32 = 0u; + var max_idx: u32 = 0u; + var i: u32 = tid; + while (i < reduce_size) { + let val = argreduce_input[base_offset + i * inner_size]; + if (val > max_val) { max_val = val; max_idx = i; } + i = i + WORKGROUP_SIZE; + } + + argmax_shared_val[tid] = max_val; + argmax_shared_idx[tid] = max_idx; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { + if (argmax_shared_val[tid + s] > argmax_shared_val[tid]) { + argmax_shared_val[tid] = argmax_shared_val[tid + s]; + argmax_shared_idx[tid] = argmax_shared_idx[tid + s]; + } + } + workgroupBarrier(); + } + + if (tid == 0u) { argreduce_output[output_idx] = argmax_shared_idx[0]; } +} + +@compute @workgroup_size(256) +fn argmin_u32(@builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3) { + let tid = local_id.x; + let output_idx = group_id.x; + if (output_idx >= argreduce_params.numel_out) { return; } + + let reduce_size = argreduce_params.reduce_size; + let inner_size = argreduce_params.inner_size; + let outer = output_idx / inner_size; + let inner = output_idx % inner_size; + let base_offset = outer * reduce_size * inner_size + inner; + + var min_val: u32 = 4294967295u; + var min_idx: u32 = 0u; + var i: u32 = tid; + while (i < reduce_size) { + let val = argreduce_input[base_offset + i * inner_size]; + if (val < min_val) { min_val = val; min_idx = i; } + i = i + WORKGROUP_SIZE; + } + + argmax_shared_val[tid] = min_val; + argmax_shared_idx[tid] = min_idx; + workgroupBarrier(); + + for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { + if (tid < s) { + if (argmax_shared_val[tid + s] < argmax_shared_val[tid]) { + argmax_shared_val[tid] = argmax_shared_val[tid + s]; + argmax_shared_idx[tid] = argmax_shared_idx[tid + s]; + } + } + workgroupBarrier(); + } + + if (tid == 0u) { argreduce_output[output_idx] = argmax_shared_idx[0]; } +} diff --git a/src/runtime/wgpu/shaders/reduce_wgsl.rs b/src/runtime/wgpu/shaders/reduce_wgsl.rs deleted file mode 100644 index f3c9581a..00000000 --- a/src/runtime/wgpu/shaders/reduce_wgsl.rs +++ /dev/null @@ -1,1525 +0,0 @@ -//! WGSL shader source code for reduction operations -//! -//! Includes sum, mean, max, min, prod, any, all reductions along specified dimensions. -//! Uses workgroup-level parallel reduction for efficiency. -//! -//! Multi-dtype support: F32, I32, U32 (F16 requires shader-f16 extension) - -use crate::dtype::DType; - -/// Get WGSL type name for a dtype -fn wgsl_type(dtype: DType) -> &'static str { - match dtype { - DType::F32 => "f32", - DType::I32 => "i32", - DType::U32 => "u32", - // F16 requires extension, so we fallback to f32 accumulation - _ => "f32", - } -} - -/// Get dtype suffix for kernel naming -fn dtype_suffix(dtype: DType) -> &'static str { - match dtype { - DType::F32 => "f32", - DType::I32 => "i32", - DType::U32 => "u32", - _ => "f32", - } -} - -/// Get the identity value for sum (zero) -fn zero_value(dtype: DType) -> &'static str { - match dtype { - DType::F32 => "0.0", - DType::I32 => "0", - DType::U32 => "0u", - _ => "0.0", - } -} - -/// Get the identity value for prod (one) -fn one_value(dtype: DType) -> &'static str { - match dtype { - DType::F32 => "1.0", - DType::I32 => "1", - DType::U32 => "1u", - _ => "1.0", - } -} - -/// Get the minimum value for max reduction initialization -fn neg_inf_value(dtype: DType) -> &'static str { - match dtype { - DType::F32 => "-3.40282346638528859812e+38", - DType::I32 => "-2147483648", // i32::MIN - DType::U32 => "0u", // u32 has no negative, use 0 - _ => "-3.40282346638528859812e+38", - } -} - -/// Get the maximum value for min reduction initialization -fn pos_inf_value(dtype: DType) -> &'static str { - match dtype { - DType::F32 => "3.40282346638528859812e+38", - DType::I32 => "2147483647", // i32::MAX - DType::U32 => "4294967295u", // u32::MAX - _ => "3.40282346638528859812e+38", - } -} - -/// Generate the reduce shader for a specific dtype -pub fn generate_reduce_shader(dtype: DType) -> String { - let wgsl_t = wgsl_type(dtype); - let suffix = dtype_suffix(dtype); - let zero = zero_value(dtype); - let one = one_value(dtype); - let neg_inf = neg_inf_value(dtype); - let pos_inf = pos_inf_value(dtype); - - // Use f32 for reduction accumulation for better precision (integers use native) - let acc_type = match dtype { - DType::F32 => "f32", - DType::I32 => "i32", - DType::U32 => "u32", - _ => "f32", - }; - - format!( - r#" -// ============================================================================ -// Workgroup Configuration -// ============================================================================ - -const WORKGROUP_SIZE: u32 = 256u; - -// Shared memory for parallel reduction -var reduce_shared: array<{acc_type}, 256>; - -// ============================================================================ -// Reduction Parameters -// ============================================================================ - -struct ReduceParams {{ - reduce_size: u32, - outer_size: u32, - inner_size: u32, - numel_out: u32, -}} - -@group(0) @binding(0) var reduce_input: array<{wgsl_t}>; -@group(0) @binding(1) var reduce_output: array<{wgsl_t}>; -@group(0) @binding(2) var reduce_params: ReduceParams; - -// ============================================================================ -// Sum Reduction -// ============================================================================ - -@compute @workgroup_size(256) -fn reduce_sum_{suffix}(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3) {{ - let tid = local_id.x; - let output_idx = group_id.x; - - if (output_idx >= reduce_params.numel_out) {{ - return; - }} - - let reduce_size = reduce_params.reduce_size; - let inner_size = reduce_params.inner_size; - - let outer = output_idx / inner_size; - let inner = output_idx % inner_size; - let base_offset = outer * reduce_size * inner_size + inner; - - var sum: {acc_type} = {zero}; - var i: u32 = tid; - while (i < reduce_size) {{ - let input_idx = base_offset + i * inner_size; - sum = sum + {acc_type}(reduce_input[input_idx]); - i = i + WORKGROUP_SIZE; - }} - - reduce_shared[tid] = sum; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) {{ - if (tid < s) {{ - reduce_shared[tid] = reduce_shared[tid] + reduce_shared[tid + s]; - }} - workgroupBarrier(); - }} - - if (tid == 0u) {{ - reduce_output[output_idx] = {wgsl_t}(reduce_shared[0]); - }} -}} - -// ============================================================================ -// Max Reduction -// ============================================================================ - -@compute @workgroup_size(256) -fn reduce_max_{suffix}(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3) {{ - let tid = local_id.x; - let output_idx = group_id.x; - - if (output_idx >= reduce_params.numel_out) {{ - return; - }} - - let reduce_size = reduce_params.reduce_size; - let inner_size = reduce_params.inner_size; - - let outer = output_idx / inner_size; - let inner = output_idx % inner_size; - let base_offset = outer * reduce_size * inner_size + inner; - - var max_val: {acc_type} = {neg_inf}; - var i: u32 = tid; - while (i < reduce_size) {{ - let input_idx = base_offset + i * inner_size; - max_val = max(max_val, {acc_type}(reduce_input[input_idx])); - i = i + WORKGROUP_SIZE; - }} - - reduce_shared[tid] = max_val; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) {{ - if (tid < s) {{ - reduce_shared[tid] = max(reduce_shared[tid], reduce_shared[tid + s]); - }} - workgroupBarrier(); - }} - - if (tid == 0u) {{ - reduce_output[output_idx] = {wgsl_t}(reduce_shared[0]); - }} -}} - -// ============================================================================ -// Min Reduction -// ============================================================================ - -@compute @workgroup_size(256) -fn reduce_min_{suffix}(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3) {{ - let tid = local_id.x; - let output_idx = group_id.x; - - if (output_idx >= reduce_params.numel_out) {{ - return; - }} - - let reduce_size = reduce_params.reduce_size; - let inner_size = reduce_params.inner_size; - - let outer = output_idx / inner_size; - let inner = output_idx % inner_size; - let base_offset = outer * reduce_size * inner_size + inner; - - var min_val: {acc_type} = {pos_inf}; - var i: u32 = tid; - while (i < reduce_size) {{ - let input_idx = base_offset + i * inner_size; - min_val = min(min_val, {acc_type}(reduce_input[input_idx])); - i = i + WORKGROUP_SIZE; - }} - - reduce_shared[tid] = min_val; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) {{ - if (tid < s) {{ - reduce_shared[tid] = min(reduce_shared[tid], reduce_shared[tid + s]); - }} - workgroupBarrier(); - }} - - if (tid == 0u) {{ - reduce_output[output_idx] = {wgsl_t}(reduce_shared[0]); - }} -}} - -// ============================================================================ -// Product Reduction -// ============================================================================ - -@compute @workgroup_size(256) -fn reduce_prod_{suffix}(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3) {{ - let tid = local_id.x; - let output_idx = group_id.x; - - if (output_idx >= reduce_params.numel_out) {{ - return; - }} - - let reduce_size = reduce_params.reduce_size; - let inner_size = reduce_params.inner_size; - - let outer = output_idx / inner_size; - let inner = output_idx % inner_size; - let base_offset = outer * reduce_size * inner_size + inner; - - var prod: {acc_type} = {one}; - var i: u32 = tid; - while (i < reduce_size) {{ - let input_idx = base_offset + i * inner_size; - prod = prod * {acc_type}(reduce_input[input_idx]); - i = i + WORKGROUP_SIZE; - }} - - reduce_shared[tid] = prod; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) {{ - if (tid < s) {{ - reduce_shared[tid] = reduce_shared[tid] * reduce_shared[tid + s]; - }} - workgroupBarrier(); - }} - - if (tid == 0u) {{ - reduce_output[output_idx] = {wgsl_t}(reduce_shared[0]); - }} -}} - -// ============================================================================ -// Any Reduction (returns 1 if any element is non-zero, 0 otherwise) -// ============================================================================ - -@compute @workgroup_size(256) -fn reduce_any_{suffix}(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3) {{ - let tid = local_id.x; - let output_idx = group_id.x; - - if (output_idx >= reduce_params.numel_out) {{ - return; - }} - - let reduce_size = reduce_params.reduce_size; - let inner_size = reduce_params.inner_size; - - let outer = output_idx / inner_size; - let inner = output_idx % inner_size; - let base_offset = outer * reduce_size * inner_size + inner; - - var found_nonzero: {acc_type} = {zero}; - var i: u32 = tid; - while (i < reduce_size) {{ - let input_idx = base_offset + i * inner_size; - if (reduce_input[input_idx] != {zero}) {{ - found_nonzero = {one}; - }} - i = i + WORKGROUP_SIZE; - }} - - reduce_shared[tid] = found_nonzero; - workgroupBarrier(); - - // OR logic via max (0 or 1) - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) {{ - if (tid < s) {{ - reduce_shared[tid] = max(reduce_shared[tid], reduce_shared[tid + s]); - }} - workgroupBarrier(); - }} - - if (tid == 0u) {{ - reduce_output[output_idx] = {wgsl_t}(reduce_shared[0]); - }} -}} - -// ============================================================================ -// All Reduction (returns 1 if all elements are non-zero, 0 otherwise) -// ============================================================================ - -@compute @workgroup_size(256) -fn reduce_all_{suffix}(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3) {{ - let tid = local_id.x; - let output_idx = group_id.x; - - if (output_idx >= reduce_params.numel_out) {{ - return; - }} - - let reduce_size = reduce_params.reduce_size; - let inner_size = reduce_params.inner_size; - - let outer = output_idx / inner_size; - let inner = output_idx % inner_size; - let base_offset = outer * reduce_size * inner_size + inner; - - var all_nonzero: {acc_type} = {one}; - var i: u32 = tid; - while (i < reduce_size) {{ - let input_idx = base_offset + i * inner_size; - if (reduce_input[input_idx] == {zero}) {{ - all_nonzero = {zero}; - }} - i = i + WORKGROUP_SIZE; - }} - - reduce_shared[tid] = all_nonzero; - workgroupBarrier(); - - // AND logic via min (0 or 1) - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) {{ - if (tid < s) {{ - reduce_shared[tid] = min(reduce_shared[tid], reduce_shared[tid + s]); - }} - workgroupBarrier(); - }} - - if (tid == 0u) {{ - reduce_output[output_idx] = {wgsl_t}(reduce_shared[0]); - }} -}} - -// ============================================================================ -// Full Reduction (reduce all elements to single value) -// ============================================================================ - -struct FullReduceParams {{ - numel: u32, -}} - -@group(0) @binding(0) var full_reduce_input: array<{wgsl_t}>; -@group(0) @binding(1) var full_reduce_output: array<{wgsl_t}>; -@group(0) @binding(2) var full_reduce_params: FullReduceParams; - -@compute @workgroup_size(256) -fn full_reduce_sum_{suffix}(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3, - @builtin(num_workgroups) num_groups: vec3) {{ - let tid = local_id.x; - let wid = group_id.x; - let numel = full_reduce_params.numel; - - var sum: {acc_type} = {zero}; - var i: u32 = wid * WORKGROUP_SIZE + tid; - let stride = num_groups.x * WORKGROUP_SIZE; - - while (i < numel) {{ - sum = sum + {acc_type}(full_reduce_input[i]); - i = i + stride; - }} - - reduce_shared[tid] = sum; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) {{ - if (tid < s) {{ - reduce_shared[tid] = reduce_shared[tid] + reduce_shared[tid + s]; - }} - workgroupBarrier(); - }} - - if (tid == 0u) {{ - full_reduce_output[wid] = {wgsl_t}(reduce_shared[0]); - }} -}} - -@compute @workgroup_size(256) -fn full_reduce_max_{suffix}(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3, - @builtin(num_workgroups) num_groups: vec3) {{ - let tid = local_id.x; - let wid = group_id.x; - let numel = full_reduce_params.numel; - - var max_val: {acc_type} = {neg_inf}; - var i: u32 = wid * WORKGROUP_SIZE + tid; - let stride = num_groups.x * WORKGROUP_SIZE; - - while (i < numel) {{ - max_val = max(max_val, {acc_type}(full_reduce_input[i])); - i = i + stride; - }} - - reduce_shared[tid] = max_val; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) {{ - if (tid < s) {{ - reduce_shared[tid] = max(reduce_shared[tid], reduce_shared[tid + s]); - }} - workgroupBarrier(); - }} - - if (tid == 0u) {{ - full_reduce_output[wid] = {wgsl_t}(reduce_shared[0]); - }} -}} - -@compute @workgroup_size(256) -fn full_reduce_min_{suffix}(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3, - @builtin(num_workgroups) num_groups: vec3) {{ - let tid = local_id.x; - let wid = group_id.x; - let numel = full_reduce_params.numel; - - var min_val: {acc_type} = {pos_inf}; - var i: u32 = wid * WORKGROUP_SIZE + tid; - let stride = num_groups.x * WORKGROUP_SIZE; - - while (i < numel) {{ - min_val = min(min_val, {acc_type}(full_reduce_input[i])); - i = i + stride; - }} - - reduce_shared[tid] = min_val; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) {{ - if (tid < s) {{ - reduce_shared[tid] = min(reduce_shared[tid], reduce_shared[tid + s]); - }} - workgroupBarrier(); - }} - - if (tid == 0u) {{ - full_reduce_output[wid] = {wgsl_t}(reduce_shared[0]); - }} -}} - -@compute @workgroup_size(256) -fn full_reduce_prod_{suffix}(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3, - @builtin(num_workgroups) num_groups: vec3) {{ - let tid = local_id.x; - let wid = group_id.x; - let numel = full_reduce_params.numel; - - var prod: {acc_type} = {one}; - var i: u32 = wid * WORKGROUP_SIZE + tid; - let stride = num_groups.x * WORKGROUP_SIZE; - - while (i < numel) {{ - prod = prod * {acc_type}(full_reduce_input[i]); - i = i + stride; - }} - - reduce_shared[tid] = prod; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) {{ - if (tid < s) {{ - reduce_shared[tid] = reduce_shared[tid] * reduce_shared[tid + s]; - }} - workgroupBarrier(); - }} - - if (tid == 0u) {{ - full_reduce_output[wid] = {wgsl_t}(reduce_shared[0]); - }} -}} - -// ============================================================================ -// Argmax / Argmin (returns index of max/min value) -// ============================================================================ - -var argmax_shared_val: array<{acc_type}, 256>; -var argmax_shared_idx: array; - -struct ArgReduceParams {{ - reduce_size: u32, - outer_size: u32, - inner_size: u32, - numel_out: u32, -}} - -@group(0) @binding(0) var argreduce_input: array<{wgsl_t}>; -@group(0) @binding(1) var argreduce_output: array; -@group(0) @binding(2) var argreduce_params: ArgReduceParams; - -@compute @workgroup_size(256) -fn argmax_{suffix}(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3) {{ - let tid = local_id.x; - let output_idx = group_id.x; - - if (output_idx >= argreduce_params.numel_out) {{ - return; - }} - - let reduce_size = argreduce_params.reduce_size; - let inner_size = argreduce_params.inner_size; - - let outer = output_idx / inner_size; - let inner = output_idx % inner_size; - let base_offset = outer * reduce_size * inner_size + inner; - - var max_val: {acc_type} = {neg_inf}; - var max_idx: u32 = 0u; - var i: u32 = tid; - - while (i < reduce_size) {{ - let input_idx = base_offset + i * inner_size; - let val = {acc_type}(argreduce_input[input_idx]); - if (val > max_val) {{ - max_val = val; - max_idx = i; - }} - i = i + WORKGROUP_SIZE; - }} - - argmax_shared_val[tid] = max_val; - argmax_shared_idx[tid] = max_idx; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) {{ - if (tid < s) {{ - if (argmax_shared_val[tid + s] > argmax_shared_val[tid]) {{ - argmax_shared_val[tid] = argmax_shared_val[tid + s]; - argmax_shared_idx[tid] = argmax_shared_idx[tid + s]; - }} - }} - workgroupBarrier(); - }} - - if (tid == 0u) {{ - argreduce_output[output_idx] = argmax_shared_idx[0]; - }} -}} - -@compute @workgroup_size(256) -fn argmin_{suffix}(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3) {{ - let tid = local_id.x; - let output_idx = group_id.x; - - if (output_idx >= argreduce_params.numel_out) {{ - return; - }} - - let reduce_size = argreduce_params.reduce_size; - let inner_size = argreduce_params.inner_size; - - let outer = output_idx / inner_size; - let inner = output_idx % inner_size; - let base_offset = outer * reduce_size * inner_size + inner; - - var min_val: {acc_type} = {pos_inf}; - var min_idx: u32 = 0u; - var i: u32 = tid; - - while (i < reduce_size) {{ - let input_idx = base_offset + i * inner_size; - let val = {acc_type}(argreduce_input[input_idx]); - if (val < min_val) {{ - min_val = val; - min_idx = i; - }} - i = i + WORKGROUP_SIZE; - }} - - argmax_shared_val[tid] = min_val; - argmax_shared_idx[tid] = min_idx; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) {{ - if (tid < s) {{ - if (argmax_shared_val[tid + s] < argmax_shared_val[tid]) {{ - argmax_shared_val[tid] = argmax_shared_val[tid + s]; - argmax_shared_idx[tid] = argmax_shared_idx[tid + s]; - }} - }} - workgroupBarrier(); - }} - - if (tid == 0u) {{ - argreduce_output[output_idx] = argmax_shared_idx[0]; - }} -}} -"#, - wgsl_t = wgsl_t, - suffix = suffix, - acc_type = acc_type, - zero = zero, - one = one, - neg_inf = neg_inf, - pos_inf = pos_inf - ) -} - -/// Generate F32-only mean and softmax shader (float-specific operations) -#[allow(dead_code)] -pub fn generate_float_reduce_shader() -> &'static str { - r#" -// ============================================================================ -// Float-only operations (mean, softmax) -// These operations only make sense for floating-point types -// ============================================================================ - -const WORKGROUP_SIZE: u32 = 256u; - -var reduce_shared: array; - -struct ReduceParams { - reduce_size: u32, - outer_size: u32, - inner_size: u32, - numel_out: u32, -} - -@group(0) @binding(0) var reduce_input: array; -@group(0) @binding(1) var reduce_output: array; -@group(0) @binding(2) var reduce_params: ReduceParams; - -// ============================================================================ -// Mean Reduction (F32 only) -// ============================================================================ - -@compute @workgroup_size(256) -fn reduce_mean_f32(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3) { - let tid = local_id.x; - let output_idx = group_id.x; - - if (output_idx >= reduce_params.numel_out) { - return; - } - - let reduce_size = reduce_params.reduce_size; - let inner_size = reduce_params.inner_size; - - let outer = output_idx / inner_size; - let inner = output_idx % inner_size; - let base_offset = outer * reduce_size * inner_size + inner; - - var sum: f32 = 0.0; - var i: u32 = tid; - while (i < reduce_size) { - let input_idx = base_offset + i * inner_size; - sum = sum + reduce_input[input_idx]; - i = i + WORKGROUP_SIZE; - } - - reduce_shared[tid] = sum; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { - if (tid < s) { - reduce_shared[tid] = reduce_shared[tid] + reduce_shared[tid + s]; - } - workgroupBarrier(); - } - - if (tid == 0u) { - reduce_output[output_idx] = reduce_shared[0] / f32(reduce_size); - } -} - -// ============================================================================ -// Softmax (F32 only - numerically stable) -// ============================================================================ - -struct SoftmaxParams { - batch_size: u32, - dim_size: u32, -} - -@group(0) @binding(0) var softmax_input: array; -@group(0) @binding(1) var softmax_output: array; -@group(0) @binding(2) var softmax_params: SoftmaxParams; - -var softmax_shared: array; - -@compute @workgroup_size(256) -fn softmax_f32(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3) { - let tid = local_id.x; - let batch_idx = group_id.x; - - if (batch_idx >= softmax_params.batch_size) { - return; - } - - let dim_size = softmax_params.dim_size; - let base_offset = batch_idx * dim_size; - - // Step 1: Find max for numerical stability - var max_val: f32 = -3.40282346638528859812e+38; - var i: u32 = tid; - while (i < dim_size) { - max_val = max(max_val, softmax_input[base_offset + i]); - i = i + WORKGROUP_SIZE; - } - - softmax_shared[tid] = max_val; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { - if (tid < s) { - softmax_shared[tid] = max(softmax_shared[tid], softmax_shared[tid + s]); - } - workgroupBarrier(); - } - - let global_max = softmax_shared[0]; - workgroupBarrier(); - - // Step 2: Compute sum of exp(x - max) - var sum: f32 = 0.0; - i = tid; - while (i < dim_size) { - sum = sum + exp(softmax_input[base_offset + i] - global_max); - i = i + WORKGROUP_SIZE; - } - - softmax_shared[tid] = sum; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { - if (tid < s) { - softmax_shared[tid] = softmax_shared[tid] + softmax_shared[tid + s]; - } - workgroupBarrier(); - } - - let global_sum = softmax_shared[0]; - workgroupBarrier(); - - // Step 3: Compute output = exp(x - max) / sum - i = tid; - while (i < dim_size) { - softmax_output[base_offset + i] = exp(softmax_input[base_offset + i] - global_max) / global_sum; - i = i + WORKGROUP_SIZE; - } -} -"# -} - -/// Get the entry point name for a reduce operation and dtype -#[allow(dead_code)] -pub fn get_entry_point(op: &str, dtype: DType) -> String { - let suffix = dtype_suffix(dtype); - format!("{}_{}", op, suffix) -} - -/// Get the full reduce entry point name -#[allow(dead_code)] -pub fn get_full_reduce_entry_point(op: &str, dtype: DType) -> String { - let suffix = dtype_suffix(dtype); - format!("full_reduce_{}_{}", op, suffix) -} - -/// Check if dtype is supported for WebGPU reduce operations -pub fn is_supported_dtype(dtype: DType) -> bool { - matches!(dtype, DType::F32 | DType::I32 | DType::U32) -} - -/// Check if the operation is float-only -pub fn is_float_only_op(op: &str) -> bool { - matches!(op, "mean" | "softmax") -} - -// Keep the old constant for backward compatibility during migration -pub const REDUCE_SHADER: &str = r#" -// ============================================================================ -// Workgroup Configuration -// ============================================================================ - -const WORKGROUP_SIZE: u32 = 256u; - -// Shared memory for parallel reduction -var reduce_shared: array; - -// ============================================================================ -// Reduction Parameters -// ============================================================================ - -struct ReduceParams { - reduce_size: u32, - outer_size: u32, - inner_size: u32, - numel_out: u32, -} - -@group(0) @binding(0) var reduce_input: array; -@group(0) @binding(1) var reduce_output: array; -@group(0) @binding(2) var reduce_params: ReduceParams; - -// ============================================================================ -// Sum Reduction -// ============================================================================ - -@compute @workgroup_size(256) -fn reduce_sum_f32(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3) { - let tid = local_id.x; - let output_idx = group_id.x; - - if (output_idx >= reduce_params.numel_out) { - return; - } - - let reduce_size = reduce_params.reduce_size; - let inner_size = reduce_params.inner_size; - - let outer = output_idx / inner_size; - let inner = output_idx % inner_size; - let base_offset = outer * reduce_size * inner_size + inner; - - var sum: f32 = 0.0; - var i: u32 = tid; - while (i < reduce_size) { - let input_idx = base_offset + i * inner_size; - sum = sum + reduce_input[input_idx]; - i = i + WORKGROUP_SIZE; - } - - reduce_shared[tid] = sum; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { - if (tid < s) { - reduce_shared[tid] = reduce_shared[tid] + reduce_shared[tid + s]; - } - workgroupBarrier(); - } - - if (tid == 0u) { - reduce_output[output_idx] = reduce_shared[0]; - } -} - -// ============================================================================ -// Mean Reduction -// ============================================================================ - -@compute @workgroup_size(256) -fn reduce_mean_f32(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3) { - let tid = local_id.x; - let output_idx = group_id.x; - - if (output_idx >= reduce_params.numel_out) { - return; - } - - let reduce_size = reduce_params.reduce_size; - let inner_size = reduce_params.inner_size; - - let outer = output_idx / inner_size; - let inner = output_idx % inner_size; - let base_offset = outer * reduce_size * inner_size + inner; - - var sum: f32 = 0.0; - var i: u32 = tid; - while (i < reduce_size) { - let input_idx = base_offset + i * inner_size; - sum = sum + reduce_input[input_idx]; - i = i + WORKGROUP_SIZE; - } - - reduce_shared[tid] = sum; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { - if (tid < s) { - reduce_shared[tid] = reduce_shared[tid] + reduce_shared[tid + s]; - } - workgroupBarrier(); - } - - if (tid == 0u) { - reduce_output[output_idx] = reduce_shared[0] / f32(reduce_size); - } -} - -// ============================================================================ -// Max Reduction -// ============================================================================ - -@compute @workgroup_size(256) -fn reduce_max_f32(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3) { - let tid = local_id.x; - let output_idx = group_id.x; - - if (output_idx >= reduce_params.numel_out) { - return; - } - - let reduce_size = reduce_params.reduce_size; - let inner_size = reduce_params.inner_size; - - let outer = output_idx / inner_size; - let inner = output_idx % inner_size; - let base_offset = outer * reduce_size * inner_size + inner; - - var max_val: f32 = -3.40282346638528859812e+38; - var i: u32 = tid; - while (i < reduce_size) { - let input_idx = base_offset + i * inner_size; - max_val = max(max_val, reduce_input[input_idx]); - i = i + WORKGROUP_SIZE; - } - - reduce_shared[tid] = max_val; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { - if (tid < s) { - reduce_shared[tid] = max(reduce_shared[tid], reduce_shared[tid + s]); - } - workgroupBarrier(); - } - - if (tid == 0u) { - reduce_output[output_idx] = reduce_shared[0]; - } -} - -// ============================================================================ -// Min Reduction -// ============================================================================ - -@compute @workgroup_size(256) -fn reduce_min_f32(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3) { - let tid = local_id.x; - let output_idx = group_id.x; - - if (output_idx >= reduce_params.numel_out) { - return; - } - - let reduce_size = reduce_params.reduce_size; - let inner_size = reduce_params.inner_size; - - let outer = output_idx / inner_size; - let inner = output_idx % inner_size; - let base_offset = outer * reduce_size * inner_size + inner; - - var min_val: f32 = 3.40282346638528859812e+38; - var i: u32 = tid; - while (i < reduce_size) { - let input_idx = base_offset + i * inner_size; - min_val = min(min_val, reduce_input[input_idx]); - i = i + WORKGROUP_SIZE; - } - - reduce_shared[tid] = min_val; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { - if (tid < s) { - reduce_shared[tid] = min(reduce_shared[tid], reduce_shared[tid + s]); - } - workgroupBarrier(); - } - - if (tid == 0u) { - reduce_output[output_idx] = reduce_shared[0]; - } -} - -// ============================================================================ -// Product Reduction -// ============================================================================ - -@compute @workgroup_size(256) -fn reduce_prod_f32(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3) { - let tid = local_id.x; - let output_idx = group_id.x; - - if (output_idx >= reduce_params.numel_out) { - return; - } - - let reduce_size = reduce_params.reduce_size; - let inner_size = reduce_params.inner_size; - - let outer = output_idx / inner_size; - let inner = output_idx % inner_size; - let base_offset = outer * reduce_size * inner_size + inner; - - var prod: f32 = 1.0; - var i: u32 = tid; - while (i < reduce_size) { - let input_idx = base_offset + i * inner_size; - prod = prod * reduce_input[input_idx]; - i = i + WORKGROUP_SIZE; - } - - reduce_shared[tid] = prod; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { - if (tid < s) { - reduce_shared[tid] = reduce_shared[tid] * reduce_shared[tid + s]; - } - workgroupBarrier(); - } - - if (tid == 0u) { - reduce_output[output_idx] = reduce_shared[0]; - } -} - -// ============================================================================ -// Any Reduction (returns 1.0 if any element is non-zero, 0.0 otherwise) -// ============================================================================ - -@compute @workgroup_size(256) -fn reduce_any_f32(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3) { - let tid = local_id.x; - let output_idx = group_id.x; - - if (output_idx >= reduce_params.numel_out) { - return; - } - - let reduce_size = reduce_params.reduce_size; - let inner_size = reduce_params.inner_size; - - let outer = output_idx / inner_size; - let inner = output_idx % inner_size; - let base_offset = outer * reduce_size * inner_size + inner; - - var found_nonzero: f32 = 0.0; - var i: u32 = tid; - while (i < reduce_size) { - let input_idx = base_offset + i * inner_size; - if (reduce_input[input_idx] != 0.0) { - found_nonzero = 1.0; - } - i = i + WORKGROUP_SIZE; - } - - reduce_shared[tid] = found_nonzero; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { - if (tid < s) { - reduce_shared[tid] = max(reduce_shared[tid], reduce_shared[tid + s]); - } - workgroupBarrier(); - } - - if (tid == 0u) { - reduce_output[output_idx] = reduce_shared[0]; - } -} - -// ============================================================================ -// All Reduction (returns 1.0 if all elements are non-zero, 0.0 otherwise) -// ============================================================================ - -@compute @workgroup_size(256) -fn reduce_all_f32(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3) { - let tid = local_id.x; - let output_idx = group_id.x; - - if (output_idx >= reduce_params.numel_out) { - return; - } - - let reduce_size = reduce_params.reduce_size; - let inner_size = reduce_params.inner_size; - - let outer = output_idx / inner_size; - let inner = output_idx % inner_size; - let base_offset = outer * reduce_size * inner_size + inner; - - var all_nonzero: f32 = 1.0; - var i: u32 = tid; - while (i < reduce_size) { - let input_idx = base_offset + i * inner_size; - if (reduce_input[input_idx] == 0.0) { - all_nonzero = 0.0; - } - i = i + WORKGROUP_SIZE; - } - - reduce_shared[tid] = all_nonzero; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { - if (tid < s) { - reduce_shared[tid] = min(reduce_shared[tid], reduce_shared[tid + s]); - } - workgroupBarrier(); - } - - if (tid == 0u) { - reduce_output[output_idx] = reduce_shared[0]; - } -} - -// ============================================================================ -// Full Reduction (reduce all elements to single value) -// ============================================================================ - -struct FullReduceParams { - numel: u32, -} - -@group(0) @binding(0) var full_reduce_input: array; -@group(0) @binding(1) var full_reduce_output: array; -@group(0) @binding(2) var full_reduce_params: FullReduceParams; - -@compute @workgroup_size(256) -fn full_reduce_sum_f32(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3, - @builtin(num_workgroups) num_groups: vec3) { - let tid = local_id.x; - let wid = group_id.x; - let numel = full_reduce_params.numel; - - var sum: f32 = 0.0; - var i: u32 = wid * WORKGROUP_SIZE + tid; - let stride = num_groups.x * WORKGROUP_SIZE; - - while (i < numel) { - sum = sum + full_reduce_input[i]; - i = i + stride; - } - - reduce_shared[tid] = sum; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { - if (tid < s) { - reduce_shared[tid] = reduce_shared[tid] + reduce_shared[tid + s]; - } - workgroupBarrier(); - } - - if (tid == 0u) { - full_reduce_output[wid] = reduce_shared[0]; - } -} - -@compute @workgroup_size(256) -fn full_reduce_max_f32(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3, - @builtin(num_workgroups) num_groups: vec3) { - let tid = local_id.x; - let wid = group_id.x; - let numel = full_reduce_params.numel; - - var max_val: f32 = -3.40282346638528859812e+38; - var i: u32 = wid * WORKGROUP_SIZE + tid; - let stride = num_groups.x * WORKGROUP_SIZE; - - while (i < numel) { - max_val = max(max_val, full_reduce_input[i]); - i = i + stride; - } - - reduce_shared[tid] = max_val; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { - if (tid < s) { - reduce_shared[tid] = max(reduce_shared[tid], reduce_shared[tid + s]); - } - workgroupBarrier(); - } - - if (tid == 0u) { - full_reduce_output[wid] = reduce_shared[0]; - } -} - -@compute @workgroup_size(256) -fn full_reduce_min_f32(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3, - @builtin(num_workgroups) num_groups: vec3) { - let tid = local_id.x; - let wid = group_id.x; - let numel = full_reduce_params.numel; - - var min_val: f32 = 3.40282346638528859812e+38; - var i: u32 = wid * WORKGROUP_SIZE + tid; - let stride = num_groups.x * WORKGROUP_SIZE; - - while (i < numel) { - min_val = min(min_val, full_reduce_input[i]); - i = i + stride; - } - - reduce_shared[tid] = min_val; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { - if (tid < s) { - reduce_shared[tid] = min(reduce_shared[tid], reduce_shared[tid + s]); - } - workgroupBarrier(); - } - - if (tid == 0u) { - full_reduce_output[wid] = reduce_shared[0]; - } -} - -@compute @workgroup_size(256) -fn full_reduce_prod_f32(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3, - @builtin(num_workgroups) num_groups: vec3) { - let tid = local_id.x; - let wid = group_id.x; - let numel = full_reduce_params.numel; - - var prod: f32 = 1.0; - var i: u32 = wid * WORKGROUP_SIZE + tid; - let stride = num_groups.x * WORKGROUP_SIZE; - - while (i < numel) { - prod = prod * full_reduce_input[i]; - i = i + stride; - } - - reduce_shared[tid] = prod; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { - if (tid < s) { - reduce_shared[tid] = reduce_shared[tid] * reduce_shared[tid + s]; - } - workgroupBarrier(); - } - - if (tid == 0u) { - full_reduce_output[wid] = reduce_shared[0]; - } -} - -// ============================================================================ -// Argmax / Argmin (returns index of max/min value) -// ============================================================================ - -var argmax_shared_val: array; -var argmax_shared_idx: array; - -struct ArgReduceParams { - reduce_size: u32, - outer_size: u32, - inner_size: u32, - numel_out: u32, -} - -@group(0) @binding(0) var argreduce_input: array; -@group(0) @binding(1) var argreduce_output: array; -@group(0) @binding(2) var argreduce_params: ArgReduceParams; - -@compute @workgroup_size(256) -fn argmax_f32(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3) { - let tid = local_id.x; - let output_idx = group_id.x; - - if (output_idx >= argreduce_params.numel_out) { - return; - } - - let reduce_size = argreduce_params.reduce_size; - let inner_size = argreduce_params.inner_size; - - let outer = output_idx / inner_size; - let inner = output_idx % inner_size; - let base_offset = outer * reduce_size * inner_size + inner; - - var max_val: f32 = -3.40282346638528859812e+38; - var max_idx: u32 = 0u; - var i: u32 = tid; - - while (i < reduce_size) { - let input_idx = base_offset + i * inner_size; - let val = argreduce_input[input_idx]; - if (val > max_val) { - max_val = val; - max_idx = i; - } - i = i + WORKGROUP_SIZE; - } - - argmax_shared_val[tid] = max_val; - argmax_shared_idx[tid] = max_idx; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { - if (tid < s) { - if (argmax_shared_val[tid + s] > argmax_shared_val[tid]) { - argmax_shared_val[tid] = argmax_shared_val[tid + s]; - argmax_shared_idx[tid] = argmax_shared_idx[tid + s]; - } - } - workgroupBarrier(); - } - - if (tid == 0u) { - argreduce_output[output_idx] = argmax_shared_idx[0]; - } -} - -@compute @workgroup_size(256) -fn argmin_f32(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3) { - let tid = local_id.x; - let output_idx = group_id.x; - - if (output_idx >= argreduce_params.numel_out) { - return; - } - - let reduce_size = argreduce_params.reduce_size; - let inner_size = argreduce_params.inner_size; - - let outer = output_idx / inner_size; - let inner = output_idx % inner_size; - let base_offset = outer * reduce_size * inner_size + inner; - - var min_val: f32 = 3.40282346638528859812e+38; - var min_idx: u32 = 0u; - var i: u32 = tid; - - while (i < reduce_size) { - let input_idx = base_offset + i * inner_size; - let val = argreduce_input[input_idx]; - if (val < min_val) { - min_val = val; - min_idx = i; - } - i = i + WORKGROUP_SIZE; - } - - argmax_shared_val[tid] = min_val; - argmax_shared_idx[tid] = min_idx; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { - if (tid < s) { - if (argmax_shared_val[tid + s] < argmax_shared_val[tid]) { - argmax_shared_val[tid] = argmax_shared_val[tid + s]; - argmax_shared_idx[tid] = argmax_shared_idx[tid + s]; - } - } - workgroupBarrier(); - } - - if (tid == 0u) { - argreduce_output[output_idx] = argmax_shared_idx[0]; - } -} - -// ============================================================================ -// Softmax (numerically stable) -// ============================================================================ - -struct SoftmaxParams { - batch_size: u32, - dim_size: u32, -} - -@group(0) @binding(0) var softmax_input: array; -@group(0) @binding(1) var softmax_output: array; -@group(0) @binding(2) var softmax_params: SoftmaxParams; - -var softmax_shared: array; - -@compute @workgroup_size(256) -fn softmax_f32(@builtin(global_invocation_id) global_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) group_id: vec3) { - let tid = local_id.x; - let batch_idx = group_id.x; - - if (batch_idx >= softmax_params.batch_size) { - return; - } - - let dim_size = softmax_params.dim_size; - let base_offset = batch_idx * dim_size; - - // Step 1: Find max for numerical stability - var max_val: f32 = -3.40282346638528859812e+38; - var i: u32 = tid; - while (i < dim_size) { - max_val = max(max_val, softmax_input[base_offset + i]); - i = i + WORKGROUP_SIZE; - } - - softmax_shared[tid] = max_val; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { - if (tid < s) { - softmax_shared[tid] = max(softmax_shared[tid], softmax_shared[tid + s]); - } - workgroupBarrier(); - } - - let global_max = softmax_shared[0]; - workgroupBarrier(); - - // Step 2: Compute sum of exp(x - max) - var sum: f32 = 0.0; - i = tid; - while (i < dim_size) { - sum = sum + exp(softmax_input[base_offset + i] - global_max); - i = i + WORKGROUP_SIZE; - } - - softmax_shared[tid] = sum; - workgroupBarrier(); - - for (var s: u32 = WORKGROUP_SIZE / 2u; s > 0u; s = s >> 1u) { - if (tid < s) { - softmax_shared[tid] = softmax_shared[tid] + softmax_shared[tid + s]; - } - workgroupBarrier(); - } - - let global_sum = softmax_shared[0]; - workgroupBarrier(); - - // Step 3: Compute output = exp(x - max) / sum - i = tid; - while (i < dim_size) { - softmax_output[base_offset + i] = exp(softmax_input[base_offset + i] - global_max) / global_sum; - i = i + WORKGROUP_SIZE; - } -} -"#; diff --git a/src/runtime/wgpu/shaders/repeat_f32.wgsl b/src/runtime/wgpu/shaders/repeat_f32.wgsl new file mode 100644 index 00000000..f7401294 --- /dev/null +++ b/src/runtime/wgpu/shaders/repeat_f32.wgsl @@ -0,0 +1,69 @@ +// Auto-generated repeat operation for f32 + +const WORKGROUP_SIZE: u32 = 256u; +const MAX_DIMS: u32 = 8u; + +// Use vec4 for 16-byte alignment in uniform buffer +struct RepeatParams { + ndim: u32, + total_elements: u32, + _pad0: u32, + _pad1: u32, + src_shape: array, 2>, // 8 u32 values packed into 2 vec4 + out_shape: array, 2>, +} + +// Helper to access packed array, 2> by index +fn get_packed_value(arr: array, 2>, d: i32) -> u32 { + let vec_idx = u32(d) / 4u; + let comp_idx = u32(d) % 4u; + if (vec_idx == 0u) { + if (comp_idx == 0u) { return arr[0].x; } + else if (comp_idx == 1u) { return arr[0].y; } + else if (comp_idx == 2u) { return arr[0].z; } + else { return arr[0].w; } + } else { + if (comp_idx == 0u) { return arr[1].x; } + else if (comp_idx == 1u) { return arr[1].y; } + else if (comp_idx == 2u) { return arr[1].z; } + else { return arr[1].w; } + } +} + +@group(0) @binding(0) var repeat_src: array; +@group(0) @binding(1) var repeat_dst: array; +@group(0) @binding(2) var repeat_params: RepeatParams; + +@compute @workgroup_size(256) +fn repeat_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= repeat_params.total_elements) { + return; + } + + // Decompose idx into multi-dimensional output coordinates + var remaining = idx; + var src_idx = 0u; + + // Compute source strides first (row-major) + var src_strides: array; + var stride = 1u; + for (var d = i32(repeat_params.ndim) - 1; d >= 0; d = d - 1) { + src_strides[d] = stride; + stride = stride * get_packed_value(repeat_params.src_shape, d); + } + + // Process dimensions from last to first + for (var d = i32(repeat_params.ndim) - 1; d >= 0; d = d - 1) { + let out_dim = get_packed_value(repeat_params.out_shape, d); + let coord = remaining % out_dim; + remaining = remaining / out_dim; + + // Map to source coordinate using modulo + let src_shape_d = get_packed_value(repeat_params.src_shape, d); + let src_coord = coord % src_shape_d; + src_idx = src_idx + src_coord * src_strides[d]; + } + + repeat_dst[idx] = repeat_src[src_idx]; +} diff --git a/src/runtime/wgpu/shaders/repeat_i32.wgsl b/src/runtime/wgpu/shaders/repeat_i32.wgsl new file mode 100644 index 00000000..fa240b76 --- /dev/null +++ b/src/runtime/wgpu/shaders/repeat_i32.wgsl @@ -0,0 +1,69 @@ +// Auto-generated repeat operation for i32 + +const WORKGROUP_SIZE: u32 = 256u; +const MAX_DIMS: u32 = 8u; + +// Use vec4 for 16-byte alignment in uniform buffer +struct RepeatParams { + ndim: u32, + total_elements: u32, + _pad0: u32, + _pad1: u32, + src_shape: array, 2>, // 8 u32 values packed into 2 vec4 + out_shape: array, 2>, +} + +// Helper to access packed array, 2> by index +fn get_packed_value(arr: array, 2>, d: i32) -> u32 { + let vec_idx = u32(d) / 4u; + let comp_idx = u32(d) % 4u; + if (vec_idx == 0u) { + if (comp_idx == 0u) { return arr[0].x; } + else if (comp_idx == 1u) { return arr[0].y; } + else if (comp_idx == 2u) { return arr[0].z; } + else { return arr[0].w; } + } else { + if (comp_idx == 0u) { return arr[1].x; } + else if (comp_idx == 1u) { return arr[1].y; } + else if (comp_idx == 2u) { return arr[1].z; } + else { return arr[1].w; } + } +} + +@group(0) @binding(0) var repeat_src: array; +@group(0) @binding(1) var repeat_dst: array; +@group(0) @binding(2) var repeat_params: RepeatParams; + +@compute @workgroup_size(256) +fn repeat_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= repeat_params.total_elements) { + return; + } + + // Decompose idx into multi-dimensional output coordinates + var remaining = idx; + var src_idx = 0u; + + // Compute source strides first (row-major) + var src_strides: array; + var stride = 1u; + for (var d = i32(repeat_params.ndim) - 1; d >= 0; d = d - 1) { + src_strides[d] = stride; + stride = stride * get_packed_value(repeat_params.src_shape, d); + } + + // Process dimensions from last to first + for (var d = i32(repeat_params.ndim) - 1; d >= 0; d = d - 1) { + let out_dim = get_packed_value(repeat_params.out_shape, d); + let coord = remaining % out_dim; + remaining = remaining / out_dim; + + // Map to source coordinate using modulo + let src_shape_d = get_packed_value(repeat_params.src_shape, d); + let src_coord = coord % src_shape_d; + src_idx = src_idx + src_coord * src_strides[d]; + } + + repeat_dst[idx] = repeat_src[src_idx]; +} diff --git a/src/runtime/wgpu/shaders/repeat_u32.wgsl b/src/runtime/wgpu/shaders/repeat_u32.wgsl new file mode 100644 index 00000000..c4acebf9 --- /dev/null +++ b/src/runtime/wgpu/shaders/repeat_u32.wgsl @@ -0,0 +1,69 @@ +// Auto-generated repeat operation for u32 + +const WORKGROUP_SIZE: u32 = 256u; +const MAX_DIMS: u32 = 8u; + +// Use vec4 for 16-byte alignment in uniform buffer +struct RepeatParams { + ndim: u32, + total_elements: u32, + _pad0: u32, + _pad1: u32, + src_shape: array, 2>, // 8 u32 values packed into 2 vec4 + out_shape: array, 2>, +} + +// Helper to access packed array, 2> by index +fn get_packed_value(arr: array, 2>, d: i32) -> u32 { + let vec_idx = u32(d) / 4u; + let comp_idx = u32(d) % 4u; + if (vec_idx == 0u) { + if (comp_idx == 0u) { return arr[0].x; } + else if (comp_idx == 1u) { return arr[0].y; } + else if (comp_idx == 2u) { return arr[0].z; } + else { return arr[0].w; } + } else { + if (comp_idx == 0u) { return arr[1].x; } + else if (comp_idx == 1u) { return arr[1].y; } + else if (comp_idx == 2u) { return arr[1].z; } + else { return arr[1].w; } + } +} + +@group(0) @binding(0) var repeat_src: array; +@group(0) @binding(1) var repeat_dst: array; +@group(0) @binding(2) var repeat_params: RepeatParams; + +@compute @workgroup_size(256) +fn repeat_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= repeat_params.total_elements) { + return; + } + + // Decompose idx into multi-dimensional output coordinates + var remaining = idx; + var src_idx = 0u; + + // Compute source strides first (row-major) + var src_strides: array; + var stride = 1u; + for (var d = i32(repeat_params.ndim) - 1; d >= 0; d = d - 1) { + src_strides[d] = stride; + stride = stride * get_packed_value(repeat_params.src_shape, d); + } + + // Process dimensions from last to first + for (var d = i32(repeat_params.ndim) - 1; d >= 0; d = d - 1) { + let out_dim = get_packed_value(repeat_params.out_shape, d); + let coord = remaining % out_dim; + remaining = remaining / out_dim; + + // Map to source coordinate using modulo + let src_shape_d = get_packed_value(repeat_params.src_shape, d); + let src_coord = coord % src_shape_d; + src_idx = src_idx + src_coord * src_strides[d]; + } + + repeat_dst[idx] = repeat_src[src_idx]; +} diff --git a/src/runtime/wgpu/shaders/rfft_pack.wgsl b/src/runtime/wgpu/shaders/rfft_pack.wgsl new file mode 100644 index 00000000..9510c0cc --- /dev/null +++ b/src/runtime/wgpu/shaders/rfft_pack.wgsl @@ -0,0 +1,32 @@ +// rfft pack shader - converts real input to complex + +const WORKGROUP_SIZE: u32 = 256u; + +struct PackParams { + n: u32, + batch_size: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var pack_input: array; +@group(0) @binding(1) var pack_output: array>; +@group(0) @binding(2) var pack_params: PackParams; + +@compute @workgroup_size(WORKGROUP_SIZE) +fn rfft_pack( + @builtin(global_invocation_id) gid: vec3 +) { + let idx = gid.x; + let batch_idx = gid.y; + let n = pack_params.n; + + if (idx >= n) { + return; + } + + let in_offset = batch_idx * n; + let out_offset = batch_idx * n; + + pack_output[out_offset + idx] = vec2(pack_input[in_offset + idx], 0.0); +} diff --git a/src/runtime/wgpu/shaders/rfft_truncate.wgsl b/src/runtime/wgpu/shaders/rfft_truncate.wgsl new file mode 100644 index 00000000..ef865b89 --- /dev/null +++ b/src/runtime/wgpu/shaders/rfft_truncate.wgsl @@ -0,0 +1,33 @@ +// rfft truncate shader - keeps only N/2+1 complex values from full FFT + +const WORKGROUP_SIZE: u32 = 256u; + +struct TruncateParams { + n: u32, // Full FFT size (input) + half_n: u32, // N/2 + 1 (output size) + batch_size: u32, + _pad: u32, +} + +@group(0) @binding(0) var truncate_input: array>; +@group(0) @binding(1) var truncate_output: array>; +@group(0) @binding(2) var truncate_params: TruncateParams; + +@compute @workgroup_size(WORKGROUP_SIZE) +fn rfft_truncate( + @builtin(global_invocation_id) gid: vec3 +) { + let idx = gid.x; + let batch_idx = gid.y; + let n = truncate_params.n; + let half_n = truncate_params.half_n; + + if (idx >= half_n) { + return; + } + + let in_offset = batch_idx * n; + let out_offset = batch_idx * half_n; + + truncate_output[out_offset + idx] = truncate_input[in_offset + idx]; +} diff --git a/src/runtime/wgpu/shaders/roll_f32.wgsl b/src/runtime/wgpu/shaders/roll_f32.wgsl new file mode 100644 index 00000000..4596a5f9 --- /dev/null +++ b/src/runtime/wgpu/shaders/roll_f32.wgsl @@ -0,0 +1,42 @@ +// Auto-generated roll operation for f32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct RollParams { + outer_size: u32, + dim_size: u32, + inner_size: u32, + shift: u32, + total_elements: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var roll_src: array; +@group(0) @binding(1) var roll_dst: array; +@group(0) @binding(2) var roll_params: RollParams; + +@compute @workgroup_size(256) +fn roll_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= roll_params.total_elements) { + return; + } + + // Decompose idx into (outer, dim_coord, inner) + let inner = idx % roll_params.inner_size; + let remaining = idx / roll_params.inner_size; + let dim_coord = remaining % roll_params.dim_size; + let outer = remaining / roll_params.dim_size; + + // Compute source coordinate with roll (shift goes right, so source is shift positions left) + let src_dim_coord = (dim_coord + roll_params.dim_size - roll_params.shift) % roll_params.dim_size; + + // Compute source linear index + let src_idx = outer * roll_params.dim_size * roll_params.inner_size + + src_dim_coord * roll_params.inner_size + + inner; + + roll_dst[idx] = roll_src[src_idx]; +} diff --git a/src/runtime/wgpu/shaders/roll_i32.wgsl b/src/runtime/wgpu/shaders/roll_i32.wgsl new file mode 100644 index 00000000..2c9dba98 --- /dev/null +++ b/src/runtime/wgpu/shaders/roll_i32.wgsl @@ -0,0 +1,42 @@ +// Auto-generated roll operation for i32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct RollParams { + outer_size: u32, + dim_size: u32, + inner_size: u32, + shift: u32, + total_elements: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var roll_src: array; +@group(0) @binding(1) var roll_dst: array; +@group(0) @binding(2) var roll_params: RollParams; + +@compute @workgroup_size(256) +fn roll_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= roll_params.total_elements) { + return; + } + + // Decompose idx into (outer, dim_coord, inner) + let inner = idx % roll_params.inner_size; + let remaining = idx / roll_params.inner_size; + let dim_coord = remaining % roll_params.dim_size; + let outer = remaining / roll_params.dim_size; + + // Compute source coordinate with roll (shift goes right, so source is shift positions left) + let src_dim_coord = (dim_coord + roll_params.dim_size - roll_params.shift) % roll_params.dim_size; + + // Compute source linear index + let src_idx = outer * roll_params.dim_size * roll_params.inner_size + + src_dim_coord * roll_params.inner_size + + inner; + + roll_dst[idx] = roll_src[src_idx]; +} diff --git a/src/runtime/wgpu/shaders/roll_u32.wgsl b/src/runtime/wgpu/shaders/roll_u32.wgsl new file mode 100644 index 00000000..5c59f16b --- /dev/null +++ b/src/runtime/wgpu/shaders/roll_u32.wgsl @@ -0,0 +1,42 @@ +// Auto-generated roll operation for u32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct RollParams { + outer_size: u32, + dim_size: u32, + inner_size: u32, + shift: u32, + total_elements: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var roll_src: array; +@group(0) @binding(1) var roll_dst: array; +@group(0) @binding(2) var roll_params: RollParams; + +@compute @workgroup_size(256) +fn roll_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= roll_params.total_elements) { + return; + } + + // Decompose idx into (outer, dim_coord, inner) + let inner = idx % roll_params.inner_size; + let remaining = idx / roll_params.inner_size; + let dim_coord = remaining % roll_params.dim_size; + let outer = remaining / roll_params.dim_size; + + // Compute source coordinate with roll (shift goes right, so source is shift positions left) + let src_dim_coord = (dim_coord + roll_params.dim_size - roll_params.shift) % roll_params.dim_size; + + // Compute source linear index + let src_idx = outer * roll_params.dim_size * roll_params.inner_size + + src_dim_coord * roll_params.inner_size + + inner; + + roll_dst[idx] = roll_src[src_idx]; +} diff --git a/src/runtime/wgpu/shaders/scalar.wgsl b/src/runtime/wgpu/shaders/scalar.wgsl new file mode 100644 index 00000000..a82ac86b --- /dev/null +++ b/src/runtime/wgpu/shaders/scalar.wgsl @@ -0,0 +1,80 @@ +// F32 scalar operations + +const WORKGROUP_SIZE: u32 = 256u; + +struct ScalarParams { + numel: u32, + scalar: f32, +} + +@group(0) @binding(0) var scalar_a: array; +@group(0) @binding(1) var scalar_out: array; +@group(0) @binding(2) var scalar_params: ScalarParams; + +@compute @workgroup_size(256) +fn add_scalar_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < scalar_params.numel) { + scalar_out[idx] = scalar_a[idx] + scalar_params.scalar; + } +} + +@compute @workgroup_size(256) +fn sub_scalar_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < scalar_params.numel) { + scalar_out[idx] = scalar_a[idx] - scalar_params.scalar; + } +} + +@compute @workgroup_size(256) +fn rsub_scalar_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < scalar_params.numel) { + scalar_out[idx] = scalar_params.scalar - scalar_a[idx]; + } +} + +@compute @workgroup_size(256) +fn mul_scalar_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < scalar_params.numel) { + scalar_out[idx] = scalar_a[idx] * scalar_params.scalar; + } +} + +@compute @workgroup_size(256) +fn div_scalar_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < scalar_params.numel) { + scalar_out[idx] = scalar_a[idx] / scalar_params.scalar; + } +} + +@compute @workgroup_size(256) +fn pow_scalar_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < scalar_params.numel) { + scalar_out[idx] = pow(scalar_a[idx], scalar_params.scalar); + } +} + +@compute @workgroup_size(256) +fn leaky_relu_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < scalar_params.numel) { + let x = scalar_a[idx]; + let slope = scalar_params.scalar; + scalar_out[idx] = max(slope * x, x); + } +} + +@compute @workgroup_size(256) +fn elu_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < scalar_params.numel) { + let x = scalar_a[idx]; + let alpha = scalar_params.scalar; + scalar_out[idx] = select(alpha * (exp(x) - 1.0), x, x > 0.0); + } +} diff --git a/src/runtime/wgpu/shaders/scalar_i32.wgsl b/src/runtime/wgpu/shaders/scalar_i32.wgsl new file mode 100644 index 00000000..bbde6a2a --- /dev/null +++ b/src/runtime/wgpu/shaders/scalar_i32.wgsl @@ -0,0 +1,52 @@ +// I32 scalar operations + +const WORKGROUP_SIZE: u32 = 256u; + +struct ScalarParams { + numel: u32, + scalar: i32, +} + +@group(0) @binding(0) var scalar_a: array; +@group(0) @binding(1) var scalar_out: array; +@group(0) @binding(2) var scalar_params: ScalarParams; + +@compute @workgroup_size(256) +fn add_scalar_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < scalar_params.numel) { + scalar_out[idx] = scalar_a[idx] + scalar_params.scalar; + } +} + +@compute @workgroup_size(256) +fn sub_scalar_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < scalar_params.numel) { + scalar_out[idx] = scalar_a[idx] - scalar_params.scalar; + } +} + +@compute @workgroup_size(256) +fn rsub_scalar_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < scalar_params.numel) { + scalar_out[idx] = scalar_params.scalar - scalar_a[idx]; + } +} + +@compute @workgroup_size(256) +fn mul_scalar_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < scalar_params.numel) { + scalar_out[idx] = scalar_a[idx] * scalar_params.scalar; + } +} + +@compute @workgroup_size(256) +fn div_scalar_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < scalar_params.numel) { + scalar_out[idx] = scalar_a[idx] / scalar_params.scalar; + } +} diff --git a/src/runtime/wgpu/shaders/scalar_u32.wgsl b/src/runtime/wgpu/shaders/scalar_u32.wgsl new file mode 100644 index 00000000..fe84e80d --- /dev/null +++ b/src/runtime/wgpu/shaders/scalar_u32.wgsl @@ -0,0 +1,52 @@ +// U32 scalar operations + +const WORKGROUP_SIZE: u32 = 256u; + +struct ScalarParams { + numel: u32, + scalar: u32, +} + +@group(0) @binding(0) var scalar_a: array; +@group(0) @binding(1) var scalar_out: array; +@group(0) @binding(2) var scalar_params: ScalarParams; + +@compute @workgroup_size(256) +fn add_scalar_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < scalar_params.numel) { + scalar_out[idx] = scalar_a[idx] + scalar_params.scalar; + } +} + +@compute @workgroup_size(256) +fn sub_scalar_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < scalar_params.numel) { + scalar_out[idx] = scalar_a[idx] - scalar_params.scalar; + } +} + +@compute @workgroup_size(256) +fn rsub_scalar_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < scalar_params.numel) { + scalar_out[idx] = scalar_params.scalar - scalar_a[idx]; + } +} + +@compute @workgroup_size(256) +fn mul_scalar_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < scalar_params.numel) { + scalar_out[idx] = scalar_a[idx] * scalar_params.scalar; + } +} + +@compute @workgroup_size(256) +fn div_scalar_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < scalar_params.numel) { + scalar_out[idx] = scalar_a[idx] / scalar_params.scalar; + } +} diff --git a/src/runtime/wgpu/shaders/scatter_f32.wgsl b/src/runtime/wgpu/shaders/scatter_f32.wgsl new file mode 100644 index 00000000..99b4306e --- /dev/null +++ b/src/runtime/wgpu/shaders/scatter_f32.wgsl @@ -0,0 +1,74 @@ +// Auto-generated scatter operations for f32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct ScatterParams { + ndim: u32, + dim: u32, + src_total: u32, + _padding: u32, + output_shape: vec4, + output_strides: vec4, + src_shape: vec4, + src_strides: vec4, +} + +@group(0) @binding(0) var src: array; +@group(0) @binding(1) var indices: array; +@group(0) @binding(2) var output: array; +@group(0) @binding(3) var params: ScatterParams; + +fn get_shape(arr: vec4, d: u32) -> u32 { + if (d == 0u) { return arr.x; } + else if (d == 1u) { return arr.y; } + else if (d == 2u) { return arr.z; } + else { return arr.w; } +} + +@compute @workgroup_size(256) +fn scatter_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= params.src_total) { + return; + } + + var remaining = idx; + var dst_offset: u32 = 0u; + + for (var d: u32 = 0u; d < params.ndim; d = d + 1u) { + let src_stride = get_shape(params.src_strides, d); + let coord = remaining / src_stride; + remaining = remaining % src_stride; + + if (d == params.dim) { + let index_val = indices[idx]; + let dim_size = get_shape(params.output_shape, d); + if (index_val < 0 || u32(index_val) >= dim_size) { + return; + } + dst_offset = dst_offset + u32(index_val) * get_shape(params.output_strides, d); + } else { + dst_offset = dst_offset + coord * get_shape(params.output_strides, d); + } + } + + output[dst_offset] = src[idx]; +} + +// Copy kernel for initializing output from input +@group(0) @binding(0) var copy_src: array; +@group(0) @binding(1) var copy_dst: array; + +struct CopyParams { + numel: u32, +} + +@group(0) @binding(2) var copy_params: CopyParams; + +@compute @workgroup_size(256) +fn copy_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < copy_params.numel) { + copy_dst[idx] = copy_src[idx]; + } +} diff --git a/src/runtime/wgpu/shaders/scatter_i32.wgsl b/src/runtime/wgpu/shaders/scatter_i32.wgsl new file mode 100644 index 00000000..29e68baf --- /dev/null +++ b/src/runtime/wgpu/shaders/scatter_i32.wgsl @@ -0,0 +1,74 @@ +// Auto-generated scatter operations for i32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct ScatterParams { + ndim: u32, + dim: u32, + src_total: u32, + _padding: u32, + output_shape: vec4, + output_strides: vec4, + src_shape: vec4, + src_strides: vec4, +} + +@group(0) @binding(0) var src: array; +@group(0) @binding(1) var indices: array; +@group(0) @binding(2) var output: array; +@group(0) @binding(3) var params: ScatterParams; + +fn get_shape(arr: vec4, d: u32) -> u32 { + if (d == 0u) { return arr.x; } + else if (d == 1u) { return arr.y; } + else if (d == 2u) { return arr.z; } + else { return arr.w; } +} + +@compute @workgroup_size(256) +fn scatter_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= params.src_total) { + return; + } + + var remaining = idx; + var dst_offset: u32 = 0u; + + for (var d: u32 = 0u; d < params.ndim; d = d + 1u) { + let src_stride = get_shape(params.src_strides, d); + let coord = remaining / src_stride; + remaining = remaining % src_stride; + + if (d == params.dim) { + let index_val = indices[idx]; + let dim_size = get_shape(params.output_shape, d); + if (index_val < 0 || u32(index_val) >= dim_size) { + return; + } + dst_offset = dst_offset + u32(index_val) * get_shape(params.output_strides, d); + } else { + dst_offset = dst_offset + coord * get_shape(params.output_strides, d); + } + } + + output[dst_offset] = src[idx]; +} + +// Copy kernel for initializing output from input +@group(0) @binding(0) var copy_src: array; +@group(0) @binding(1) var copy_dst: array; + +struct CopyParams { + numel: u32, +} + +@group(0) @binding(2) var copy_params: CopyParams; + +@compute @workgroup_size(256) +fn copy_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < copy_params.numel) { + copy_dst[idx] = copy_src[idx]; + } +} diff --git a/src/runtime/wgpu/shaders/scatter_reduce_count_f32.wgsl b/src/runtime/wgpu/shaders/scatter_reduce_count_f32.wgsl new file mode 100644 index 00000000..77306d72 --- /dev/null +++ b/src/runtime/wgpu/shaders/scatter_reduce_count_f32.wgsl @@ -0,0 +1,40 @@ +// Auto-generated scatter_reduce_count for mean computation + +const WORKGROUP_SIZE: u32 = 256u; + +struct ScatterReduceParams { + dim: u32, + outer_size: u32, + dim_size: u32, + inner_size: u32, + src_dim_size: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var scatter_indices: array; +@group(0) @binding(1) var scatter_count: array>; +@group(0) @binding(2) var scatter_params: ScatterReduceParams; + +@compute @workgroup_size(256) +fn scatter_reduce_count_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = scatter_params.outer_size * scatter_params.src_dim_size * scatter_params.inner_size; + if (idx >= total) { + return; + } + + let inner = idx % scatter_params.inner_size; + let src_dim_idx = (idx / scatter_params.inner_size) % scatter_params.src_dim_size; + let outer = idx / (scatter_params.src_dim_size * scatter_params.inner_size); + + let index_val = scatter_indices[src_dim_idx]; + if (index_val < 0 || u32(index_val) >= scatter_params.dim_size) { + return; + } + + let dst_idx = outer * scatter_params.dim_size * scatter_params.inner_size + u32(index_val) * scatter_params.inner_size + inner; + + atomicAdd(&scatter_count[dst_idx], 1u); +} diff --git a/src/runtime/wgpu/shaders/scatter_reduce_max_f32.wgsl b/src/runtime/wgpu/shaders/scatter_reduce_max_f32.wgsl new file mode 100644 index 00000000..75e5eed1 --- /dev/null +++ b/src/runtime/wgpu/shaders/scatter_reduce_max_f32.wgsl @@ -0,0 +1,56 @@ +// Auto-generated scatter_reduce_max for f32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct ScatterReduceParams { + dim: u32, + outer_size: u32, + dim_size: u32, + inner_size: u32, + src_dim_size: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +// Note: All storage buffers use read_write to match the pipeline cache layout. +// The actual access pattern is: src (read), indices (read), dst (read_write). +@group(0) @binding(0) var scatter_src: array; +@group(0) @binding(1) var scatter_indices: array; +@group(0) @binding(2) var scatter_dst: array>; +@group(0) @binding(3) var scatter_params: ScatterReduceParams; + +@compute @workgroup_size(256) +fn scatter_reduce_max_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = scatter_params.outer_size * scatter_params.src_dim_size * scatter_params.inner_size; + if (idx >= total) { + return; + } + + let inner = idx % scatter_params.inner_size; + let src_dim_idx = (idx / scatter_params.inner_size) % scatter_params.src_dim_size; + let outer = idx / (scatter_params.src_dim_size * scatter_params.inner_size); + + let index_val = scatter_indices[src_dim_idx]; + if (index_val < 0 || u32(index_val) >= scatter_params.dim_size) { + return; + } + + let src_val = scatter_src[idx]; + let dst_idx = outer * scatter_params.dim_size * scatter_params.inner_size + u32(index_val) * scatter_params.inner_size + inner; + + // CAS loop for max + var old_bits: u32; + var new_bits: u32; + loop { + old_bits = atomicLoad(&scatter_dst[dst_idx]); + let old_val = bitcast(old_bits); + let new_val = max(old_val, src_val); + new_bits = bitcast(new_val); + let result = atomicCompareExchangeWeak(&scatter_dst[dst_idx], old_bits, new_bits); + if (result.exchanged) { + break; + } + } +} diff --git a/src/runtime/wgpu/shaders/scatter_reduce_max_i32.wgsl b/src/runtime/wgpu/shaders/scatter_reduce_max_i32.wgsl new file mode 100644 index 00000000..2ddeb0e2 --- /dev/null +++ b/src/runtime/wgpu/shaders/scatter_reduce_max_i32.wgsl @@ -0,0 +1,42 @@ +// Auto-generated scatter_reduce_max for i32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct ScatterReduceParams { + dim: u32, + outer_size: u32, + dim_size: u32, + inner_size: u32, + src_dim_size: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var scatter_src: array; +@group(0) @binding(1) var scatter_indices: array; +@group(0) @binding(2) var scatter_dst: array>; +@group(0) @binding(3) var scatter_params: ScatterReduceParams; + +@compute @workgroup_size(256) +fn scatter_reduce_max_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = scatter_params.outer_size * scatter_params.src_dim_size * scatter_params.inner_size; + if (idx >= total) { + return; + } + + let inner = idx % scatter_params.inner_size; + let src_dim_idx = (idx / scatter_params.inner_size) % scatter_params.src_dim_size; + let outer = idx / (scatter_params.src_dim_size * scatter_params.inner_size); + + let index_val = scatter_indices[src_dim_idx]; + if (index_val < 0 || u32(index_val) >= scatter_params.dim_size) { + return; + } + + let src_val = scatter_src[idx]; + let dst_idx = outer * scatter_params.dim_size * scatter_params.inner_size + u32(index_val) * scatter_params.inner_size + inner; + + atomicMax(&scatter_dst[dst_idx], src_val); +} diff --git a/src/runtime/wgpu/shaders/scatter_reduce_max_u32.wgsl b/src/runtime/wgpu/shaders/scatter_reduce_max_u32.wgsl new file mode 100644 index 00000000..d1fb5ddd --- /dev/null +++ b/src/runtime/wgpu/shaders/scatter_reduce_max_u32.wgsl @@ -0,0 +1,42 @@ +// Auto-generated scatter_reduce_max for u32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct ScatterReduceParams { + dim: u32, + outer_size: u32, + dim_size: u32, + inner_size: u32, + src_dim_size: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var scatter_src: array; +@group(0) @binding(1) var scatter_indices: array; +@group(0) @binding(2) var scatter_dst: array>; +@group(0) @binding(3) var scatter_params: ScatterReduceParams; + +@compute @workgroup_size(256) +fn scatter_reduce_max_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = scatter_params.outer_size * scatter_params.src_dim_size * scatter_params.inner_size; + if (idx >= total) { + return; + } + + let inner = idx % scatter_params.inner_size; + let src_dim_idx = (idx / scatter_params.inner_size) % scatter_params.src_dim_size; + let outer = idx / (scatter_params.src_dim_size * scatter_params.inner_size); + + let index_val = scatter_indices[src_dim_idx]; + if (index_val < 0 || u32(index_val) >= scatter_params.dim_size) { + return; + } + + let src_val = scatter_src[idx]; + let dst_idx = outer * scatter_params.dim_size * scatter_params.inner_size + u32(index_val) * scatter_params.inner_size + inner; + + atomicMax(&scatter_dst[dst_idx], src_val); +} diff --git a/src/runtime/wgpu/shaders/scatter_reduce_mean_div_f32.wgsl b/src/runtime/wgpu/shaders/scatter_reduce_mean_div_f32.wgsl new file mode 100644 index 00000000..24134d33 --- /dev/null +++ b/src/runtime/wgpu/shaders/scatter_reduce_mean_div_f32.wgsl @@ -0,0 +1,30 @@ +// Auto-generated scatter_reduce_mean_div for f32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct MeanDivParams { + n: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var mean_sum: array; +@group(0) @binding(1) var mean_count: array; +@group(0) @binding(2) var mean_output: array; +@group(0) @binding(3) var mean_params: MeanDivParams; + +@compute @workgroup_size(256) +fn scatter_reduce_mean_div_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= mean_params.n) { + return; + } + + let c = mean_count[idx]; + if (c > 0u) { + mean_output[idx] = mean_sum[idx] / f32(c); + } else { + mean_output[idx] = f32(0); + } +} diff --git a/src/runtime/wgpu/shaders/scatter_reduce_min_f32.wgsl b/src/runtime/wgpu/shaders/scatter_reduce_min_f32.wgsl new file mode 100644 index 00000000..ad3dc19e --- /dev/null +++ b/src/runtime/wgpu/shaders/scatter_reduce_min_f32.wgsl @@ -0,0 +1,56 @@ +// Auto-generated scatter_reduce_min for f32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct ScatterReduceParams { + dim: u32, + outer_size: u32, + dim_size: u32, + inner_size: u32, + src_dim_size: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +// Note: All storage buffers use read_write to match the pipeline cache layout. +// The actual access pattern is: src (read), indices (read), dst (read_write). +@group(0) @binding(0) var scatter_src: array; +@group(0) @binding(1) var scatter_indices: array; +@group(0) @binding(2) var scatter_dst: array>; +@group(0) @binding(3) var scatter_params: ScatterReduceParams; + +@compute @workgroup_size(256) +fn scatter_reduce_min_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = scatter_params.outer_size * scatter_params.src_dim_size * scatter_params.inner_size; + if (idx >= total) { + return; + } + + let inner = idx % scatter_params.inner_size; + let src_dim_idx = (idx / scatter_params.inner_size) % scatter_params.src_dim_size; + let outer = idx / (scatter_params.src_dim_size * scatter_params.inner_size); + + let index_val = scatter_indices[src_dim_idx]; + if (index_val < 0 || u32(index_val) >= scatter_params.dim_size) { + return; + } + + let src_val = scatter_src[idx]; + let dst_idx = outer * scatter_params.dim_size * scatter_params.inner_size + u32(index_val) * scatter_params.inner_size + inner; + + // CAS loop for min + var old_bits: u32; + var new_bits: u32; + loop { + old_bits = atomicLoad(&scatter_dst[dst_idx]); + let old_val = bitcast(old_bits); + let new_val = min(old_val, src_val); + new_bits = bitcast(new_val); + let result = atomicCompareExchangeWeak(&scatter_dst[dst_idx], old_bits, new_bits); + if (result.exchanged) { + break; + } + } +} diff --git a/src/runtime/wgpu/shaders/scatter_reduce_min_i32.wgsl b/src/runtime/wgpu/shaders/scatter_reduce_min_i32.wgsl new file mode 100644 index 00000000..eedb9431 --- /dev/null +++ b/src/runtime/wgpu/shaders/scatter_reduce_min_i32.wgsl @@ -0,0 +1,42 @@ +// Auto-generated scatter_reduce_min for i32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct ScatterReduceParams { + dim: u32, + outer_size: u32, + dim_size: u32, + inner_size: u32, + src_dim_size: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var scatter_src: array; +@group(0) @binding(1) var scatter_indices: array; +@group(0) @binding(2) var scatter_dst: array>; +@group(0) @binding(3) var scatter_params: ScatterReduceParams; + +@compute @workgroup_size(256) +fn scatter_reduce_min_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = scatter_params.outer_size * scatter_params.src_dim_size * scatter_params.inner_size; + if (idx >= total) { + return; + } + + let inner = idx % scatter_params.inner_size; + let src_dim_idx = (idx / scatter_params.inner_size) % scatter_params.src_dim_size; + let outer = idx / (scatter_params.src_dim_size * scatter_params.inner_size); + + let index_val = scatter_indices[src_dim_idx]; + if (index_val < 0 || u32(index_val) >= scatter_params.dim_size) { + return; + } + + let src_val = scatter_src[idx]; + let dst_idx = outer * scatter_params.dim_size * scatter_params.inner_size + u32(index_val) * scatter_params.inner_size + inner; + + atomicMin(&scatter_dst[dst_idx], src_val); +} diff --git a/src/runtime/wgpu/shaders/scatter_reduce_min_u32.wgsl b/src/runtime/wgpu/shaders/scatter_reduce_min_u32.wgsl new file mode 100644 index 00000000..15d19cc6 --- /dev/null +++ b/src/runtime/wgpu/shaders/scatter_reduce_min_u32.wgsl @@ -0,0 +1,42 @@ +// Auto-generated scatter_reduce_min for u32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct ScatterReduceParams { + dim: u32, + outer_size: u32, + dim_size: u32, + inner_size: u32, + src_dim_size: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var scatter_src: array; +@group(0) @binding(1) var scatter_indices: array; +@group(0) @binding(2) var scatter_dst: array>; +@group(0) @binding(3) var scatter_params: ScatterReduceParams; + +@compute @workgroup_size(256) +fn scatter_reduce_min_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = scatter_params.outer_size * scatter_params.src_dim_size * scatter_params.inner_size; + if (idx >= total) { + return; + } + + let inner = idx % scatter_params.inner_size; + let src_dim_idx = (idx / scatter_params.inner_size) % scatter_params.src_dim_size; + let outer = idx / (scatter_params.src_dim_size * scatter_params.inner_size); + + let index_val = scatter_indices[src_dim_idx]; + if (index_val < 0 || u32(index_val) >= scatter_params.dim_size) { + return; + } + + let src_val = scatter_src[idx]; + let dst_idx = outer * scatter_params.dim_size * scatter_params.inner_size + u32(index_val) * scatter_params.inner_size + inner; + + atomicMin(&scatter_dst[dst_idx], src_val); +} diff --git a/src/runtime/wgpu/shaders/scatter_reduce_prod_f32.wgsl b/src/runtime/wgpu/shaders/scatter_reduce_prod_f32.wgsl new file mode 100644 index 00000000..edcef918 --- /dev/null +++ b/src/runtime/wgpu/shaders/scatter_reduce_prod_f32.wgsl @@ -0,0 +1,54 @@ +// Auto-generated scatter_reduce_prod for f32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct ScatterReduceParams { + dim: u32, + outer_size: u32, + dim_size: u32, + inner_size: u32, + src_dim_size: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var scatter_src: array; +@group(0) @binding(1) var scatter_indices: array; +@group(0) @binding(2) var scatter_dst: array>; +@group(0) @binding(3) var scatter_params: ScatterReduceParams; + +@compute @workgroup_size(256) +fn scatter_reduce_prod_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = scatter_params.outer_size * scatter_params.src_dim_size * scatter_params.inner_size; + if (idx >= total) { + return; + } + + let inner = idx % scatter_params.inner_size; + let src_dim_idx = (idx / scatter_params.inner_size) % scatter_params.src_dim_size; + let outer = idx / (scatter_params.src_dim_size * scatter_params.inner_size); + + let index_val = scatter_indices[src_dim_idx]; + if (index_val < 0 || u32(index_val) >= scatter_params.dim_size) { + return; + } + + let src_val = scatter_src[idx]; + let dst_idx = outer * scatter_params.dim_size * scatter_params.inner_size + u32(index_val) * scatter_params.inner_size + inner; + + // CAS loop for atomic multiply + var old_bits: u32; + var new_bits: u32; + loop { + old_bits = atomicLoad(&scatter_dst[dst_idx]); + let old_val = bitcast(old_bits); + let new_val = old_val * src_val; + new_bits = bitcast(new_val); + let result = atomicCompareExchangeWeak(&scatter_dst[dst_idx], old_bits, new_bits); + if (result.exchanged) { + break; + } + } +} diff --git a/src/runtime/wgpu/shaders/scatter_reduce_prod_i32.wgsl b/src/runtime/wgpu/shaders/scatter_reduce_prod_i32.wgsl new file mode 100644 index 00000000..abaf343a --- /dev/null +++ b/src/runtime/wgpu/shaders/scatter_reduce_prod_i32.wgsl @@ -0,0 +1,50 @@ +// Auto-generated scatter_reduce_prod for i32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct ScatterReduceParams { + dim: u32, + outer_size: u32, + dim_size: u32, + inner_size: u32, + src_dim_size: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var scatter_src: array; +@group(0) @binding(1) var scatter_indices: array; +@group(0) @binding(2) var scatter_dst: array>; +@group(0) @binding(3) var scatter_params: ScatterReduceParams; + +@compute @workgroup_size(256) +fn scatter_reduce_prod_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = scatter_params.outer_size * scatter_params.src_dim_size * scatter_params.inner_size; + if (idx >= total) { + return; + } + + let inner = idx % scatter_params.inner_size; + let src_dim_idx = (idx / scatter_params.inner_size) % scatter_params.src_dim_size; + let outer = idx / (scatter_params.src_dim_size * scatter_params.inner_size); + + let index_val = scatter_indices[src_dim_idx]; + if (index_val < 0 || u32(index_val) >= scatter_params.dim_size) { + return; + } + + let src_val = scatter_src[idx]; + let dst_idx = outer * scatter_params.dim_size * scatter_params.inner_size + u32(index_val) * scatter_params.inner_size + inner; + + // CAS loop for atomic multiply + loop { + let old_val = atomicLoad(&scatter_dst[dst_idx]); + let new_val = old_val * src_val; + let result = atomicCompareExchangeWeak(&scatter_dst[dst_idx], old_val, new_val); + if (result.exchanged) { + break; + } + } +} diff --git a/src/runtime/wgpu/shaders/scatter_reduce_prod_u32.wgsl b/src/runtime/wgpu/shaders/scatter_reduce_prod_u32.wgsl new file mode 100644 index 00000000..c17e62bc --- /dev/null +++ b/src/runtime/wgpu/shaders/scatter_reduce_prod_u32.wgsl @@ -0,0 +1,50 @@ +// Auto-generated scatter_reduce_prod for u32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct ScatterReduceParams { + dim: u32, + outer_size: u32, + dim_size: u32, + inner_size: u32, + src_dim_size: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var scatter_src: array; +@group(0) @binding(1) var scatter_indices: array; +@group(0) @binding(2) var scatter_dst: array>; +@group(0) @binding(3) var scatter_params: ScatterReduceParams; + +@compute @workgroup_size(256) +fn scatter_reduce_prod_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = scatter_params.outer_size * scatter_params.src_dim_size * scatter_params.inner_size; + if (idx >= total) { + return; + } + + let inner = idx % scatter_params.inner_size; + let src_dim_idx = (idx / scatter_params.inner_size) % scatter_params.src_dim_size; + let outer = idx / (scatter_params.src_dim_size * scatter_params.inner_size); + + let index_val = scatter_indices[src_dim_idx]; + if (index_val < 0 || u32(index_val) >= scatter_params.dim_size) { + return; + } + + let src_val = scatter_src[idx]; + let dst_idx = outer * scatter_params.dim_size * scatter_params.inner_size + u32(index_val) * scatter_params.inner_size + inner; + + // CAS loop for atomic multiply + loop { + let old_val = atomicLoad(&scatter_dst[dst_idx]); + let new_val = old_val * src_val; + let result = atomicCompareExchangeWeak(&scatter_dst[dst_idx], old_val, new_val); + if (result.exchanged) { + break; + } + } +} diff --git a/src/runtime/wgpu/shaders/scatter_reduce_sum_f32.wgsl b/src/runtime/wgpu/shaders/scatter_reduce_sum_f32.wgsl new file mode 100644 index 00000000..3e922f04 --- /dev/null +++ b/src/runtime/wgpu/shaders/scatter_reduce_sum_f32.wgsl @@ -0,0 +1,56 @@ +// Auto-generated scatter_reduce_sum for f32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct ScatterReduceParams { + dim: u32, + outer_size: u32, + dim_size: u32, + inner_size: u32, + src_dim_size: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +// Note: All storage buffers use read_write to match the pipeline cache layout. +// The actual access pattern is: src (read), indices (read), dst (read_write). +@group(0) @binding(0) var scatter_src: array; +@group(0) @binding(1) var scatter_indices: array; +@group(0) @binding(2) var scatter_dst: array>; +@group(0) @binding(3) var scatter_params: ScatterReduceParams; + +@compute @workgroup_size(256) +fn scatter_reduce_sum_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = scatter_params.outer_size * scatter_params.src_dim_size * scatter_params.inner_size; + if (idx >= total) { + return; + } + + let inner = idx % scatter_params.inner_size; + let src_dim_idx = (idx / scatter_params.inner_size) % scatter_params.src_dim_size; + let outer = idx / (scatter_params.src_dim_size * scatter_params.inner_size); + + let index_val = scatter_indices[src_dim_idx]; + if (index_val < 0 || u32(index_val) >= scatter_params.dim_size) { + return; + } + + let src_val = scatter_src[idx]; + let dst_idx = outer * scatter_params.dim_size * scatter_params.inner_size + u32(index_val) * scatter_params.inner_size + inner; + + // CAS loop for atomic float add + var old_bits: u32; + var new_bits: u32; + loop { + old_bits = atomicLoad(&scatter_dst[dst_idx]); + let old_val = bitcast(old_bits); + let new_val = old_val + src_val; + new_bits = bitcast(new_val); + let result = atomicCompareExchangeWeak(&scatter_dst[dst_idx], old_bits, new_bits); + if (result.exchanged) { + break; + } + } +} diff --git a/src/runtime/wgpu/shaders/scatter_reduce_sum_i32.wgsl b/src/runtime/wgpu/shaders/scatter_reduce_sum_i32.wgsl new file mode 100644 index 00000000..93a169a5 --- /dev/null +++ b/src/runtime/wgpu/shaders/scatter_reduce_sum_i32.wgsl @@ -0,0 +1,42 @@ +// Auto-generated scatter_reduce_sum for i32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct ScatterReduceParams { + dim: u32, + outer_size: u32, + dim_size: u32, + inner_size: u32, + src_dim_size: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var scatter_src: array; +@group(0) @binding(1) var scatter_indices: array; +@group(0) @binding(2) var scatter_dst: array>; +@group(0) @binding(3) var scatter_params: ScatterReduceParams; + +@compute @workgroup_size(256) +fn scatter_reduce_sum_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = scatter_params.outer_size * scatter_params.src_dim_size * scatter_params.inner_size; + if (idx >= total) { + return; + } + + let inner = idx % scatter_params.inner_size; + let src_dim_idx = (idx / scatter_params.inner_size) % scatter_params.src_dim_size; + let outer = idx / (scatter_params.src_dim_size * scatter_params.inner_size); + + let index_val = scatter_indices[src_dim_idx]; + if (index_val < 0 || u32(index_val) >= scatter_params.dim_size) { + return; + } + + let src_val = scatter_src[idx]; + let dst_idx = outer * scatter_params.dim_size * scatter_params.inner_size + u32(index_val) * scatter_params.inner_size + inner; + + atomicAdd(&scatter_dst[dst_idx], src_val); +} diff --git a/src/runtime/wgpu/shaders/scatter_reduce_sum_u32.wgsl b/src/runtime/wgpu/shaders/scatter_reduce_sum_u32.wgsl new file mode 100644 index 00000000..05b8cc35 --- /dev/null +++ b/src/runtime/wgpu/shaders/scatter_reduce_sum_u32.wgsl @@ -0,0 +1,42 @@ +// Auto-generated scatter_reduce_sum for u32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct ScatterReduceParams { + dim: u32, + outer_size: u32, + dim_size: u32, + inner_size: u32, + src_dim_size: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var scatter_src: array; +@group(0) @binding(1) var scatter_indices: array; +@group(0) @binding(2) var scatter_dst: array>; +@group(0) @binding(3) var scatter_params: ScatterReduceParams; + +@compute @workgroup_size(256) +fn scatter_reduce_sum_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = scatter_params.outer_size * scatter_params.src_dim_size * scatter_params.inner_size; + if (idx >= total) { + return; + } + + let inner = idx % scatter_params.inner_size; + let src_dim_idx = (idx / scatter_params.inner_size) % scatter_params.src_dim_size; + let outer = idx / (scatter_params.src_dim_size * scatter_params.inner_size); + + let index_val = scatter_indices[src_dim_idx]; + if (index_val < 0 || u32(index_val) >= scatter_params.dim_size) { + return; + } + + let src_val = scatter_src[idx]; + let dst_idx = outer * scatter_params.dim_size * scatter_params.inner_size + u32(index_val) * scatter_params.inner_size + inner; + + atomicAdd(&scatter_dst[dst_idx], src_val); +} diff --git a/src/runtime/wgpu/shaders/scatter_u32.wgsl b/src/runtime/wgpu/shaders/scatter_u32.wgsl new file mode 100644 index 00000000..12634bd5 --- /dev/null +++ b/src/runtime/wgpu/shaders/scatter_u32.wgsl @@ -0,0 +1,74 @@ +// Auto-generated scatter operations for u32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct ScatterParams { + ndim: u32, + dim: u32, + src_total: u32, + _padding: u32, + output_shape: vec4, + output_strides: vec4, + src_shape: vec4, + src_strides: vec4, +} + +@group(0) @binding(0) var src: array; +@group(0) @binding(1) var indices: array; +@group(0) @binding(2) var output: array; +@group(0) @binding(3) var params: ScatterParams; + +fn get_shape(arr: vec4, d: u32) -> u32 { + if (d == 0u) { return arr.x; } + else if (d == 1u) { return arr.y; } + else if (d == 2u) { return arr.z; } + else { return arr.w; } +} + +@compute @workgroup_size(256) +fn scatter_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= params.src_total) { + return; + } + + var remaining = idx; + var dst_offset: u32 = 0u; + + for (var d: u32 = 0u; d < params.ndim; d = d + 1u) { + let src_stride = get_shape(params.src_strides, d); + let coord = remaining / src_stride; + remaining = remaining % src_stride; + + if (d == params.dim) { + let index_val = indices[idx]; + let dim_size = get_shape(params.output_shape, d); + if (index_val < 0 || u32(index_val) >= dim_size) { + return; + } + dst_offset = dst_offset + u32(index_val) * get_shape(params.output_strides, d); + } else { + dst_offset = dst_offset + coord * get_shape(params.output_strides, d); + } + } + + output[dst_offset] = src[idx]; +} + +// Copy kernel for initializing output from input +@group(0) @binding(0) var copy_src: array; +@group(0) @binding(1) var copy_dst: array; + +struct CopyParams { + numel: u32, +} + +@group(0) @binding(2) var copy_params: CopyParams; + +@compute @workgroup_size(256) +fn copy_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < copy_params.numel) { + copy_dst[idx] = copy_src[idx]; + } +} diff --git a/src/runtime/wgpu/shaders/searchsorted_f32.wgsl b/src/runtime/wgpu/shaders/searchsorted_f32.wgsl new file mode 100644 index 00000000..4243f212 --- /dev/null +++ b/src/runtime/wgpu/shaders/searchsorted_f32.wgsl @@ -0,0 +1,52 @@ +// Auto-generated searchsorted operations for f32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct SearchsortedParams { + seq_len: u32, + num_values: u32, + right: u32, + _pad: u32, +} + +@group(0) @binding(0) var ss_seq: array; +@group(0) @binding(1) var ss_values: array; +@group(0) @binding(2) var ss_output: array; +@group(0) @binding(3) var ss_params: SearchsortedParams; + +@compute @workgroup_size(256) +fn searchsorted_f32(@builtin(global_invocation_id) global_id: vec3) { + let idx = global_id.x; + + if (idx >= ss_params.num_values) { + return; + } + + let value = ss_values[idx]; + let seq_len = ss_params.seq_len; + let right = ss_params.right != 0u; + + // Binary search + var lo: u32 = 0u; + var hi: u32 = seq_len; + + while (lo < hi) { + let mid = lo + (hi - lo) / 2u; + let seq_val = ss_seq[mid]; + + var go_right: bool; + if (right) { + go_right = seq_val <= value; + } else { + go_right = seq_val < value; + } + + if (go_right) { + lo = mid + 1u; + } else { + hi = mid; + } + } + + ss_output[idx] = i32(lo); +} diff --git a/src/runtime/wgpu/shaders/semiring_matmul.rs b/src/runtime/wgpu/shaders/semiring_matmul.rs index b833fc84..8173e477 100644 --- a/src/runtime/wgpu/shaders/semiring_matmul.rs +++ b/src/runtime/wgpu/shaders/semiring_matmul.rs @@ -1,122 +1,69 @@ -//! Semiring matrix multiplication WGSL kernel launchers +//! Semiring matrix multiplication WGSL kernel launchers. F32 only. use wgpu::{Buffer, Queue}; -use super::generator::semiring_matmul::generate_semiring_matmul_shader; use super::pipeline::{LayoutKey, PipelineCache}; use crate::dtype::DType; use crate::error::{Error, Result}; use crate::ops::semiring::SemiringOp; +const SR_MIN_PLUS_SHADER: &str = include_str!("semiring_matmul_min_plus_f32.wgsl"); +const SR_MAX_PLUS_SHADER: &str = include_str!("semiring_matmul_max_plus_f32.wgsl"); +const SR_MAX_MIN_SHADER: &str = include_str!("semiring_matmul_max_min_f32.wgsl"); +const SR_MIN_MAX_SHADER: &str = include_str!("semiring_matmul_min_max_f32.wgsl"); +const SR_OR_AND_SHADER: &str = include_str!("semiring_matmul_or_and_f32.wgsl"); +const SR_PLUS_MAX_SHADER: &str = include_str!("semiring_matmul_plus_max_f32.wgsl"); + const TILE_SIZE: u32 = 16; -/// Returns (module_key, entry_point, batched_entry_point) as &'static str. -/// The pipeline cache requires 'static lifetimes for keys. -fn semiring_keys( +fn semiring_shader_info( op: SemiringOp, dtype: DType, -) -> Result<(&'static str, &'static str, &'static str)> { - use DType::*; - use SemiringOp::*; - match (op, dtype) { - (MinPlus, F32) => Ok(( +) -> Result<(&'static str, &'static str, &'static str, &'static str)> { + if dtype != DType::F32 { + return Err(Error::UnsupportedDType { + dtype, + op: "semiring_matmul (WebGPU)", + }); + } + Ok(match op { + SemiringOp::MinPlus => ( + SR_MIN_PLUS_SHADER, "sr_min_plus_f32", "semiring_matmul_min_plus_f32", "batched_semiring_matmul_min_plus_f32", - )), - (MaxPlus, F32) => Ok(( + ), + SemiringOp::MaxPlus => ( + SR_MAX_PLUS_SHADER, "sr_max_plus_f32", "semiring_matmul_max_plus_f32", "batched_semiring_matmul_max_plus_f32", - )), - (MaxMin, F32) => Ok(( + ), + SemiringOp::MaxMin => ( + SR_MAX_MIN_SHADER, "sr_max_min_f32", "semiring_matmul_max_min_f32", "batched_semiring_matmul_max_min_f32", - )), - (MinMax, F32) => Ok(( + ), + SemiringOp::MinMax => ( + SR_MIN_MAX_SHADER, "sr_min_max_f32", "semiring_matmul_min_max_f32", "batched_semiring_matmul_min_max_f32", - )), - (OrAnd, F32) => Ok(( + ), + SemiringOp::OrAnd => ( + SR_OR_AND_SHADER, "sr_or_and_f32", "semiring_matmul_or_and_f32", "batched_semiring_matmul_or_and_f32", - )), - (PlusMax, F32) => Ok(( + ), + SemiringOp::PlusMax => ( + SR_PLUS_MAX_SHADER, "sr_plus_max_f32", "semiring_matmul_plus_max_f32", "batched_semiring_matmul_plus_max_f32", - )), - - (MinPlus, I32) => Ok(( - "sr_min_plus_i32", - "semiring_matmul_min_plus_i32", - "batched_semiring_matmul_min_plus_i32", - )), - (MaxPlus, I32) => Ok(( - "sr_max_plus_i32", - "semiring_matmul_max_plus_i32", - "batched_semiring_matmul_max_plus_i32", - )), - (MaxMin, I32) => Ok(( - "sr_max_min_i32", - "semiring_matmul_max_min_i32", - "batched_semiring_matmul_max_min_i32", - )), - (MinMax, I32) => Ok(( - "sr_min_max_i32", - "semiring_matmul_min_max_i32", - "batched_semiring_matmul_min_max_i32", - )), - (OrAnd, I32) => Ok(( - "sr_or_and_i32", - "semiring_matmul_or_and_i32", - "batched_semiring_matmul_or_and_i32", - )), - (PlusMax, I32) => Ok(( - "sr_plus_max_i32", - "semiring_matmul_plus_max_i32", - "batched_semiring_matmul_plus_max_i32", - )), - - (MinPlus, U32) => Ok(( - "sr_min_plus_u32", - "semiring_matmul_min_plus_u32", - "batched_semiring_matmul_min_plus_u32", - )), - (MaxPlus, U32) => Ok(( - "sr_max_plus_u32", - "semiring_matmul_max_plus_u32", - "batched_semiring_matmul_max_plus_u32", - )), - (MaxMin, U32) => Ok(( - "sr_max_min_u32", - "semiring_matmul_max_min_u32", - "batched_semiring_matmul_max_min_u32", - )), - (MinMax, U32) => Ok(( - "sr_min_max_u32", - "semiring_matmul_min_max_u32", - "batched_semiring_matmul_min_max_u32", - )), - (OrAnd, U32) => Ok(( - "sr_or_and_u32", - "semiring_matmul_or_and_u32", - "batched_semiring_matmul_or_and_u32", - )), - (PlusMax, U32) => Ok(( - "sr_plus_max_u32", - "semiring_matmul_plus_max_u32", - "batched_semiring_matmul_plus_max_u32", - )), - - _ => Err(Error::UnsupportedDType { - dtype, - op: "semiring_matmul (WebGPU)", - }), - } + ), + }) } /// Launch semiring matrix multiplication kernel. @@ -132,10 +79,9 @@ pub fn launch_semiring_matmul( op: SemiringOp, dtype: DType, ) -> Result<()> { - let (module_key, entry_point, _) = semiring_keys(op, dtype)?; - let shader_source = generate_semiring_matmul_shader(dtype, op)?; + let (shader, module_key, entry_point, _) = semiring_shader_info(op, dtype)?; - let module = cache.get_or_create_module(module_key, &shader_source); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 3, num_uniform_buffers: 1, @@ -181,10 +127,9 @@ pub fn launch_batched_semiring_matmul( op: SemiringOp, dtype: DType, ) -> Result<()> { - let (module_key, _, batched_entry_point) = semiring_keys(op, dtype)?; - let shader_source = generate_semiring_matmul_shader(dtype, op)?; + let (shader, module_key, _, batched_entry_point) = semiring_shader_info(op, dtype)?; - let module = cache.get_or_create_module(module_key, &shader_source); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 3, num_uniform_buffers: 1, diff --git a/src/runtime/wgpu/shaders/semiring_matmul_max_min_f32.wgsl b/src/runtime/wgpu/shaders/semiring_matmul_max_min_f32.wgsl new file mode 100644 index 00000000..95714b46 --- /dev/null +++ b/src/runtime/wgpu/shaders/semiring_matmul_max_min_f32.wgsl @@ -0,0 +1,85 @@ +// Semiring matmul: max_min for f32 +// C[i,j] = max_k( min(A[i,k], B[k,j]) ) +// Entry points: semiring_matmul_max_min_f32, batched_semiring_matmul_max_min_f32 + +struct SemiringMatmulParams { + M: u32, + K: u32, + N: u32, + batch_size: u32, +} + +@group(0) @binding(0) var sr_a: array; +@group(0) @binding(1) var sr_b: array; +@group(0) @binding(2) var sr_c: array; +@group(0) @binding(3) var sr_params: SemiringMatmulParams; + +fn sr_combine(a: f32, b: f32) -> f32 { + return min(a, b); +} + +fn sr_reduce(acc: f32, val: f32) -> f32 { + return max(acc, val); +} + +@compute @workgroup_size(16, 16, 1) +fn semiring_matmul_max_min_f32( + @builtin(global_invocation_id) global_id: vec3 +) { + let M = sr_params.M; + let K = sr_params.K; + let N = sr_params.N; + + let row = global_id.y; + let col = global_id.x; + + if (row >= M || col >= N) { + return; + } + + var acc: f32 = bitcast(0xff800000u); + + for (var kk: u32 = 0u; kk < K; kk = kk + 1u) { + let a_val = sr_a[row * K + kk]; + let b_val = sr_b[kk * N + col]; + acc = sr_reduce(acc, sr_combine(a_val, b_val)); + } + + sr_c[row * N + col] = acc; +} + +@compute @workgroup_size(16, 16, 1) +fn batched_semiring_matmul_max_min_f32( + @builtin(global_invocation_id) global_id: vec3 +) { + let M = sr_params.M; + let K = sr_params.K; + let N = sr_params.N; + let batch_size = sr_params.batch_size; + + let batch = global_id.z; + if (batch >= batch_size) { + return; + } + + let row = global_id.y; + let col = global_id.x; + + if (row >= M || col >= N) { + return; + } + + let a_offset = batch * M * K; + let b_offset = batch * K * N; + let c_offset = batch * M * N; + + var acc: f32 = bitcast(0xff800000u); + + for (var kk: u32 = 0u; kk < K; kk = kk + 1u) { + let a_val = sr_a[a_offset + row * K + kk]; + let b_val = sr_b[b_offset + kk * N + col]; + acc = sr_reduce(acc, sr_combine(a_val, b_val)); + } + + sr_c[c_offset + row * N + col] = acc; +} diff --git a/src/runtime/wgpu/shaders/semiring_matmul_max_plus_f32.wgsl b/src/runtime/wgpu/shaders/semiring_matmul_max_plus_f32.wgsl new file mode 100644 index 00000000..f2c7d682 --- /dev/null +++ b/src/runtime/wgpu/shaders/semiring_matmul_max_plus_f32.wgsl @@ -0,0 +1,85 @@ +// Semiring matmul: max_plus for f32 +// C[i,j] = max_k( A[i,k] + B[k,j] ) +// Entry points: semiring_matmul_max_plus_f32, batched_semiring_matmul_max_plus_f32 + +struct SemiringMatmulParams { + M: u32, + K: u32, + N: u32, + batch_size: u32, +} + +@group(0) @binding(0) var sr_a: array; +@group(0) @binding(1) var sr_b: array; +@group(0) @binding(2) var sr_c: array; +@group(0) @binding(3) var sr_params: SemiringMatmulParams; + +fn sr_combine(a: f32, b: f32) -> f32 { + return a + b; +} + +fn sr_reduce(acc: f32, val: f32) -> f32 { + return max(acc, val); +} + +@compute @workgroup_size(16, 16, 1) +fn semiring_matmul_max_plus_f32( + @builtin(global_invocation_id) global_id: vec3 +) { + let M = sr_params.M; + let K = sr_params.K; + let N = sr_params.N; + + let row = global_id.y; + let col = global_id.x; + + if (row >= M || col >= N) { + return; + } + + var acc: f32 = bitcast(0xff800000u); + + for (var kk: u32 = 0u; kk < K; kk = kk + 1u) { + let a_val = sr_a[row * K + kk]; + let b_val = sr_b[kk * N + col]; + acc = sr_reduce(acc, sr_combine(a_val, b_val)); + } + + sr_c[row * N + col] = acc; +} + +@compute @workgroup_size(16, 16, 1) +fn batched_semiring_matmul_max_plus_f32( + @builtin(global_invocation_id) global_id: vec3 +) { + let M = sr_params.M; + let K = sr_params.K; + let N = sr_params.N; + let batch_size = sr_params.batch_size; + + let batch = global_id.z; + if (batch >= batch_size) { + return; + } + + let row = global_id.y; + let col = global_id.x; + + if (row >= M || col >= N) { + return; + } + + let a_offset = batch * M * K; + let b_offset = batch * K * N; + let c_offset = batch * M * N; + + var acc: f32 = bitcast(0xff800000u); + + for (var kk: u32 = 0u; kk < K; kk = kk + 1u) { + let a_val = sr_a[a_offset + row * K + kk]; + let b_val = sr_b[b_offset + kk * N + col]; + acc = sr_reduce(acc, sr_combine(a_val, b_val)); + } + + sr_c[c_offset + row * N + col] = acc; +} diff --git a/src/runtime/wgpu/shaders/semiring_matmul_min_max_f32.wgsl b/src/runtime/wgpu/shaders/semiring_matmul_min_max_f32.wgsl new file mode 100644 index 00000000..81dd52f3 --- /dev/null +++ b/src/runtime/wgpu/shaders/semiring_matmul_min_max_f32.wgsl @@ -0,0 +1,85 @@ +// Semiring matmul: min_max for f32 +// C[i,j] = min_k( max(A[i,k], B[k,j]) ) +// Entry points: semiring_matmul_min_max_f32, batched_semiring_matmul_min_max_f32 + +struct SemiringMatmulParams { + M: u32, + K: u32, + N: u32, + batch_size: u32, +} + +@group(0) @binding(0) var sr_a: array; +@group(0) @binding(1) var sr_b: array; +@group(0) @binding(2) var sr_c: array; +@group(0) @binding(3) var sr_params: SemiringMatmulParams; + +fn sr_combine(a: f32, b: f32) -> f32 { + return max(a, b); +} + +fn sr_reduce(acc: f32, val: f32) -> f32 { + return min(acc, val); +} + +@compute @workgroup_size(16, 16, 1) +fn semiring_matmul_min_max_f32( + @builtin(global_invocation_id) global_id: vec3 +) { + let M = sr_params.M; + let K = sr_params.K; + let N = sr_params.N; + + let row = global_id.y; + let col = global_id.x; + + if (row >= M || col >= N) { + return; + } + + var acc: f32 = bitcast(0x7f800000u); + + for (var kk: u32 = 0u; kk < K; kk = kk + 1u) { + let a_val = sr_a[row * K + kk]; + let b_val = sr_b[kk * N + col]; + acc = sr_reduce(acc, sr_combine(a_val, b_val)); + } + + sr_c[row * N + col] = acc; +} + +@compute @workgroup_size(16, 16, 1) +fn batched_semiring_matmul_min_max_f32( + @builtin(global_invocation_id) global_id: vec3 +) { + let M = sr_params.M; + let K = sr_params.K; + let N = sr_params.N; + let batch_size = sr_params.batch_size; + + let batch = global_id.z; + if (batch >= batch_size) { + return; + } + + let row = global_id.y; + let col = global_id.x; + + if (row >= M || col >= N) { + return; + } + + let a_offset = batch * M * K; + let b_offset = batch * K * N; + let c_offset = batch * M * N; + + var acc: f32 = bitcast(0x7f800000u); + + for (var kk: u32 = 0u; kk < K; kk = kk + 1u) { + let a_val = sr_a[a_offset + row * K + kk]; + let b_val = sr_b[b_offset + kk * N + col]; + acc = sr_reduce(acc, sr_combine(a_val, b_val)); + } + + sr_c[c_offset + row * N + col] = acc; +} diff --git a/src/runtime/wgpu/shaders/semiring_matmul_min_plus_f32.wgsl b/src/runtime/wgpu/shaders/semiring_matmul_min_plus_f32.wgsl new file mode 100644 index 00000000..446a078a --- /dev/null +++ b/src/runtime/wgpu/shaders/semiring_matmul_min_plus_f32.wgsl @@ -0,0 +1,85 @@ +// Semiring matmul: min_plus for f32 +// C[i,j] = min_k( A[i,k] + B[k,j] ) +// Entry points: semiring_matmul_min_plus_f32, batched_semiring_matmul_min_plus_f32 + +struct SemiringMatmulParams { + M: u32, + K: u32, + N: u32, + batch_size: u32, +} + +@group(0) @binding(0) var sr_a: array; +@group(0) @binding(1) var sr_b: array; +@group(0) @binding(2) var sr_c: array; +@group(0) @binding(3) var sr_params: SemiringMatmulParams; + +fn sr_combine(a: f32, b: f32) -> f32 { + return a + b; +} + +fn sr_reduce(acc: f32, val: f32) -> f32 { + return min(acc, val); +} + +@compute @workgroup_size(16, 16, 1) +fn semiring_matmul_min_plus_f32( + @builtin(global_invocation_id) global_id: vec3 +) { + let M = sr_params.M; + let K = sr_params.K; + let N = sr_params.N; + + let row = global_id.y; + let col = global_id.x; + + if (row >= M || col >= N) { + return; + } + + var acc: f32 = bitcast(0x7f800000u); + + for (var kk: u32 = 0u; kk < K; kk = kk + 1u) { + let a_val = sr_a[row * K + kk]; + let b_val = sr_b[kk * N + col]; + acc = sr_reduce(acc, sr_combine(a_val, b_val)); + } + + sr_c[row * N + col] = acc; +} + +@compute @workgroup_size(16, 16, 1) +fn batched_semiring_matmul_min_plus_f32( + @builtin(global_invocation_id) global_id: vec3 +) { + let M = sr_params.M; + let K = sr_params.K; + let N = sr_params.N; + let batch_size = sr_params.batch_size; + + let batch = global_id.z; + if (batch >= batch_size) { + return; + } + + let row = global_id.y; + let col = global_id.x; + + if (row >= M || col >= N) { + return; + } + + let a_offset = batch * M * K; + let b_offset = batch * K * N; + let c_offset = batch * M * N; + + var acc: f32 = bitcast(0x7f800000u); + + for (var kk: u32 = 0u; kk < K; kk = kk + 1u) { + let a_val = sr_a[a_offset + row * K + kk]; + let b_val = sr_b[b_offset + kk * N + col]; + acc = sr_reduce(acc, sr_combine(a_val, b_val)); + } + + sr_c[c_offset + row * N + col] = acc; +} diff --git a/src/runtime/wgpu/shaders/semiring_matmul_or_and_f32.wgsl b/src/runtime/wgpu/shaders/semiring_matmul_or_and_f32.wgsl new file mode 100644 index 00000000..bd021d2b --- /dev/null +++ b/src/runtime/wgpu/shaders/semiring_matmul_or_and_f32.wgsl @@ -0,0 +1,85 @@ +// Semiring matmul: or_and for f32 +// C[i,j] = OR_k( A[i,k] AND B[k,j] ) (logical, mapped to float 0.0/1.0) +// Entry points: semiring_matmul_or_and_f32, batched_semiring_matmul_or_and_f32 + +struct SemiringMatmulParams { + M: u32, + K: u32, + N: u32, + batch_size: u32, +} + +@group(0) @binding(0) var sr_a: array; +@group(0) @binding(1) var sr_b: array; +@group(0) @binding(2) var sr_c: array; +@group(0) @binding(3) var sr_params: SemiringMatmulParams; + +fn sr_combine(a: f32, b: f32) -> f32 { + return select(0.0, 1.0, a != 0.0 && b != 0.0); +} + +fn sr_reduce(acc: f32, val: f32) -> f32 { + return select(0.0, 1.0, acc != 0.0 || val != 0.0); +} + +@compute @workgroup_size(16, 16, 1) +fn semiring_matmul_or_and_f32( + @builtin(global_invocation_id) global_id: vec3 +) { + let M = sr_params.M; + let K = sr_params.K; + let N = sr_params.N; + + let row = global_id.y; + let col = global_id.x; + + if (row >= M || col >= N) { + return; + } + + var acc: f32 = 0.0; + + for (var kk: u32 = 0u; kk < K; kk = kk + 1u) { + let a_val = sr_a[row * K + kk]; + let b_val = sr_b[kk * N + col]; + acc = sr_reduce(acc, sr_combine(a_val, b_val)); + } + + sr_c[row * N + col] = acc; +} + +@compute @workgroup_size(16, 16, 1) +fn batched_semiring_matmul_or_and_f32( + @builtin(global_invocation_id) global_id: vec3 +) { + let M = sr_params.M; + let K = sr_params.K; + let N = sr_params.N; + let batch_size = sr_params.batch_size; + + let batch = global_id.z; + if (batch >= batch_size) { + return; + } + + let row = global_id.y; + let col = global_id.x; + + if (row >= M || col >= N) { + return; + } + + let a_offset = batch * M * K; + let b_offset = batch * K * N; + let c_offset = batch * M * N; + + var acc: f32 = 0.0; + + for (var kk: u32 = 0u; kk < K; kk = kk + 1u) { + let a_val = sr_a[a_offset + row * K + kk]; + let b_val = sr_b[b_offset + kk * N + col]; + acc = sr_reduce(acc, sr_combine(a_val, b_val)); + } + + sr_c[c_offset + row * N + col] = acc; +} diff --git a/src/runtime/wgpu/shaders/semiring_matmul_plus_max_f32.wgsl b/src/runtime/wgpu/shaders/semiring_matmul_plus_max_f32.wgsl new file mode 100644 index 00000000..00f6c5c7 --- /dev/null +++ b/src/runtime/wgpu/shaders/semiring_matmul_plus_max_f32.wgsl @@ -0,0 +1,85 @@ +// Semiring matmul: plus_max for f32 +// C[i,j] = sum_k( max(A[i,k], B[k,j]) ) +// Entry points: semiring_matmul_plus_max_f32, batched_semiring_matmul_plus_max_f32 + +struct SemiringMatmulParams { + M: u32, + K: u32, + N: u32, + batch_size: u32, +} + +@group(0) @binding(0) var sr_a: array; +@group(0) @binding(1) var sr_b: array; +@group(0) @binding(2) var sr_c: array; +@group(0) @binding(3) var sr_params: SemiringMatmulParams; + +fn sr_combine(a: f32, b: f32) -> f32 { + return max(a, b); +} + +fn sr_reduce(acc: f32, val: f32) -> f32 { + return acc + val; +} + +@compute @workgroup_size(16, 16, 1) +fn semiring_matmul_plus_max_f32( + @builtin(global_invocation_id) global_id: vec3 +) { + let M = sr_params.M; + let K = sr_params.K; + let N = sr_params.N; + + let row = global_id.y; + let col = global_id.x; + + if (row >= M || col >= N) { + return; + } + + var acc: f32 = 0.0; + + for (var kk: u32 = 0u; kk < K; kk = kk + 1u) { + let a_val = sr_a[row * K + kk]; + let b_val = sr_b[kk * N + col]; + acc = sr_reduce(acc, sr_combine(a_val, b_val)); + } + + sr_c[row * N + col] = acc; +} + +@compute @workgroup_size(16, 16, 1) +fn batched_semiring_matmul_plus_max_f32( + @builtin(global_invocation_id) global_id: vec3 +) { + let M = sr_params.M; + let K = sr_params.K; + let N = sr_params.N; + let batch_size = sr_params.batch_size; + + let batch = global_id.z; + if (batch >= batch_size) { + return; + } + + let row = global_id.y; + let col = global_id.x; + + if (row >= M || col >= N) { + return; + } + + let a_offset = batch * M * K; + let b_offset = batch * K * N; + let c_offset = batch * M * N; + + var acc: f32 = 0.0; + + for (var kk: u32 = 0u; kk < K; kk = kk + 1u) { + let a_val = sr_a[a_offset + row * K + kk]; + let b_val = sr_b[b_offset + kk * N + col]; + acc = sr_reduce(acc, sr_combine(a_val, b_val)); + } + + sr_c[c_offset + row * N + col] = acc; +} diff --git a/src/runtime/wgpu/shaders/shape.rs b/src/runtime/wgpu/shaders/shape.rs index 49674c94..e0a935eb 100644 --- a/src/runtime/wgpu/shaders/shape.rs +++ b/src/runtime/wgpu/shaders/shape.rs @@ -9,37 +9,140 @@ //! - split/chunk: Zero-copy views using narrow (no kernel needed) //! //! All copy operations run entirely on GPU with no CPU fallback. +//! +//! dtype policy (Option C): +//! - cat, repeat, pad, roll → DATA-MOVEMENT → support F32, I32, U32 +//! - arange, eye → can produce F32 / I32 / U32 +//! - linspace → F32 only (interpolation math) +//! - rand, randn → F32 only (math) +//! - randint → I32 / U32 only +//! - multinomial → F32 only (math) use wgpu::{Buffer, Queue}; -use super::generator::{ - generate_arange_shader, generate_cat_shader, generate_eye_shader, generate_linspace_shader, - generate_multinomial_with_replacement_shader, generate_multinomial_without_replacement_shader, - generate_pad_shader, generate_rand_shader, generate_randint_shader, generate_randn_shader, - generate_repeat_shader, generate_roll_shader, -}; use super::pipeline::{LayoutKey, PipelineCache, workgroup_count}; use crate::dtype::DType; use crate::error::{Error, Result}; // ============================================================================ -// Helper Functions +// Static shaders — cat (data-movement: F32 / I32 / U32) // ============================================================================ -/// Check if dtype is supported for shape operations on WebGPU. -fn check_dtype_supported(dtype: DType, op: &'static str) -> Result<()> { - match dtype { - DType::F32 | DType::I32 | DType::U32 => Ok(()), - _ => Err(Error::UnsupportedDType { dtype, op }), - } -} +const CAT_COPY_SHADER_F32: &str = include_str!("cat_copy_f32.wgsl"); +const CAT_COPY_SHADER_I32: &str = include_str!("cat_copy_i32.wgsl"); +const CAT_COPY_SHADER_U32: &str = include_str!("cat_copy_u32.wgsl"); + +// ============================================================================ +// Static shaders — repeat (data-movement: F32 / I32 / U32) +// ============================================================================ + +const REPEAT_SHADER_F32: &str = include_str!("repeat_f32.wgsl"); +const REPEAT_SHADER_I32: &str = include_str!("repeat_i32.wgsl"); +const REPEAT_SHADER_U32: &str = include_str!("repeat_u32.wgsl"); + +// ============================================================================ +// Static shaders — pad (data-movement: F32 / I32 / U32) +// ============================================================================ + +const PAD_SHADER_F32: &str = include_str!("pad_f32.wgsl"); +const PAD_SHADER_I32: &str = include_str!("pad_i32.wgsl"); +const PAD_SHADER_U32: &str = include_str!("pad_u32.wgsl"); + +// ============================================================================ +// Static shaders — roll (data-movement: F32 / I32 / U32) +// ============================================================================ + +const ROLL_SHADER_F32: &str = include_str!("roll_f32.wgsl"); +const ROLL_SHADER_I32: &str = include_str!("roll_i32.wgsl"); +const ROLL_SHADER_U32: &str = include_str!("roll_u32.wgsl"); + +// ============================================================================ +// Static shaders — arange (F32 / I32 / U32) +// ============================================================================ + +const ARANGE_SHADER_F32: &str = include_str!("arange_f32.wgsl"); +const ARANGE_SHADER_I32: &str = include_str!("arange_i32.wgsl"); +const ARANGE_SHADER_U32: &str = include_str!("arange_u32.wgsl"); + +// ============================================================================ +// Static shaders — linspace (F32 only) +// ============================================================================ + +const LINSPACE_SHADER_F32: &str = include_str!("linspace_f32.wgsl"); + +// ============================================================================ +// Static shaders — eye (F32 / I32 / U32) +// ============================================================================ + +const EYE_SHADER_F32: &str = include_str!("eye_f32.wgsl"); +const EYE_SHADER_I32: &str = include_str!("eye_i32.wgsl"); +const EYE_SHADER_U32: &str = include_str!("eye_u32.wgsl"); + +// ============================================================================ +// Static shaders — rand / randn (F32 only) +// ============================================================================ + +const RAND_SHADER_F32: &str = include_str!("rand_f32.wgsl"); +const RANDN_SHADER_F32: &str = include_str!("randn_f32.wgsl"); + +// ============================================================================ +// Static shaders — randint (I32 / U32 only) +// ============================================================================ + +const RANDINT_SHADER_I32: &str = include_str!("randint_i32.wgsl"); +const RANDINT_SHADER_U32: &str = include_str!("randint_u32.wgsl"); + +// ============================================================================ +// Static shaders — multinomial (F32 only) +// ============================================================================ + +const MULTINOMIAL_WITH_REPLACEMENT_SHADER_F32: &str = + include_str!("multinomial_with_replacement_f32.wgsl"); +const MULTINOMIAL_WITHOUT_REPLACEMENT_SHADER_F32: &str = + include_str!("multinomial_without_replacement_f32.wgsl"); + +// ============================================================================ +// Helper: shader_info returns (shader_source, module_key, entry_point) +// ============================================================================ -/// Get the static module/entry point name for a shape operation. -fn kernel_name(op: &'static str, dtype: DType) -> Result<&'static str> { +fn shader_info( + op: &'static str, + dtype: DType, +) -> Result<(&'static str, &'static str, &'static str)> { match (op, dtype) { - ("cat_copy", DType::F32) => Ok("cat_copy_f32"), - ("cat_copy", DType::I32) => Ok("cat_copy_i32"), - ("cat_copy", DType::U32) => Ok("cat_copy_u32"), + // cat_copy + ("cat_copy", DType::F32) => Ok((CAT_COPY_SHADER_F32, "cat_copy_f32", "cat_copy_f32")), + ("cat_copy", DType::I32) => Ok((CAT_COPY_SHADER_I32, "cat_copy_i32", "cat_copy_i32")), + ("cat_copy", DType::U32) => Ok((CAT_COPY_SHADER_U32, "cat_copy_u32", "cat_copy_u32")), + // repeat + ("repeat", DType::F32) => Ok((REPEAT_SHADER_F32, "repeat_f32", "repeat_f32")), + ("repeat", DType::I32) => Ok((REPEAT_SHADER_I32, "repeat_i32", "repeat_i32")), + ("repeat", DType::U32) => Ok((REPEAT_SHADER_U32, "repeat_u32", "repeat_u32")), + // pad + ("pad", DType::F32) => Ok((PAD_SHADER_F32, "pad_f32", "pad_f32")), + ("pad", DType::I32) => Ok((PAD_SHADER_I32, "pad_i32", "pad_i32")), + ("pad", DType::U32) => Ok((PAD_SHADER_U32, "pad_u32", "pad_u32")), + // roll + ("roll", DType::F32) => Ok((ROLL_SHADER_F32, "roll_f32", "roll_f32")), + ("roll", DType::I32) => Ok((ROLL_SHADER_I32, "roll_i32", "roll_i32")), + ("roll", DType::U32) => Ok((ROLL_SHADER_U32, "roll_u32", "roll_u32")), + // arange + ("arange", DType::F32) => Ok((ARANGE_SHADER_F32, "arange_f32", "arange_f32")), + ("arange", DType::I32) => Ok((ARANGE_SHADER_I32, "arange_i32", "arange_i32")), + ("arange", DType::U32) => Ok((ARANGE_SHADER_U32, "arange_u32", "arange_u32")), + // linspace + ("linspace", DType::F32) => Ok((LINSPACE_SHADER_F32, "linspace_f32", "linspace_f32")), + // eye + ("eye", DType::F32) => Ok((EYE_SHADER_F32, "eye_f32", "eye_f32")), + ("eye", DType::I32) => Ok((EYE_SHADER_I32, "eye_i32", "eye_i32")), + ("eye", DType::U32) => Ok((EYE_SHADER_U32, "eye_u32", "eye_u32")), + // rand + ("rand", DType::F32) => Ok((RAND_SHADER_F32, "rand_f32", "rand_f32")), + // randn + ("randn", DType::F32) => Ok((RANDN_SHADER_F32, "randn_f32", "randn_f32")), + // randint + ("randint", DType::I32) => Ok((RANDINT_SHADER_I32, "randint_i32", "randint_i32")), + ("randint", DType::U32) => Ok((RANDINT_SHADER_U32, "randint_u32", "randint_u32")), _ => Err(Error::UnsupportedDType { dtype, op }), } } @@ -76,17 +179,14 @@ pub fn launch_cat_copy( return Ok(()); } - check_dtype_supported(dtype, "cat_copy")?; - - let name = kernel_name("cat_copy", dtype)?; - let shader_source = generate_cat_shader(dtype)?; - let module = cache.get_or_create_module(name, &shader_source); + let (shader, module_key, entry_point) = shader_info("cat_copy", dtype)?; + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[src, dst, params_buffer]); @@ -114,19 +214,6 @@ pub fn launch_cat_copy( // Arange Operation // ============================================================================ -/// Get the kernel name for arange operation. -fn arange_kernel_name(dtype: DType) -> Result<&'static str> { - match dtype { - DType::F32 => Ok("arange_f32"), - DType::I32 => Ok("arange_i32"), - DType::U32 => Ok("arange_u32"), - _ => Err(Error::UnsupportedDType { - dtype, - op: "arange", - }), - } -} - /// Launch an arange operation kernel. /// /// # Arguments @@ -149,15 +236,14 @@ pub fn launch_arange( return Ok(()); } - let name = arange_kernel_name(dtype)?; - let shader_source = generate_arange_shader(dtype)?; - let module = cache.get_or_create_module(name, &shader_source); + let (shader, module_key, entry_point) = shader_info("arange", dtype)?; + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 1, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[out, params_buffer]); @@ -185,17 +271,6 @@ pub fn launch_arange( // Linspace Operation // ============================================================================ -/// Get the kernel name for linspace operation. -fn linspace_kernel_name(dtype: DType) -> Result<&'static str> { - match dtype { - DType::F32 => Ok("linspace_f32"), - _ => Err(Error::UnsupportedDType { - dtype, - op: "linspace", - }), - } -} - /// Launch a linspace operation kernel. /// /// # Arguments @@ -218,15 +293,14 @@ pub fn launch_linspace( return Ok(()); } - let name = linspace_kernel_name(dtype)?; - let shader_source = generate_linspace_shader(dtype)?; - let module = cache.get_or_create_module(name, &shader_source); + let (shader, module_key, entry_point) = shader_info("linspace", dtype)?; + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 1, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[out, params_buffer]); @@ -254,16 +328,6 @@ pub fn launch_linspace( // Eye Operation // ============================================================================ -/// Get the kernel name for eye operation. -fn eye_kernel_name(dtype: DType) -> Result<&'static str> { - match dtype { - DType::F32 => Ok("eye_f32"), - DType::I32 => Ok("eye_i32"), - DType::U32 => Ok("eye_u32"), - _ => Err(Error::UnsupportedDType { dtype, op: "eye" }), - } -} - /// Launch an eye (identity matrix) operation kernel. /// /// # Arguments @@ -286,15 +350,14 @@ pub fn launch_eye( return Ok(()); } - let name = eye_kernel_name(dtype)?; - let shader_source = generate_eye_shader(dtype)?; - let module = cache.get_or_create_module(name, &shader_source); + let (shader, module_key, entry_point) = shader_info("eye", dtype)?; + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 1, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[out, params_buffer]); @@ -320,14 +383,6 @@ pub fn launch_eye( // Random Operations // ============================================================================ -/// Get the kernel name for rand operation. -fn rand_kernel_name(dtype: DType) -> Result<&'static str> { - match dtype { - DType::F32 => Ok("rand_f32"), - _ => Err(Error::UnsupportedDType { dtype, op: "rand" }), - } -} - /// Launch a rand operation kernel (uniform [0, 1)). /// /// # Arguments @@ -350,15 +405,14 @@ pub fn launch_rand( return Ok(()); } - let name = rand_kernel_name(dtype)?; - let shader_source = generate_rand_shader(dtype)?; - let module = cache.get_or_create_module(name, &shader_source); + let (shader, module_key, entry_point) = shader_info("rand", dtype)?; + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 1, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[out, params_buffer]); @@ -382,14 +436,6 @@ pub fn launch_rand( Ok(()) } -/// Get the kernel name for randn operation. -fn randn_kernel_name(dtype: DType) -> Result<&'static str> { - match dtype { - DType::F32 => Ok("randn_f32"), - _ => Err(Error::UnsupportedDType { dtype, op: "randn" }), - } -} - /// Launch a randn operation kernel (standard normal N(0, 1)). /// /// # Arguments @@ -412,15 +458,14 @@ pub fn launch_randn( return Ok(()); } - let name = randn_kernel_name(dtype)?; - let shader_source = generate_randn_shader(dtype)?; - let module = cache.get_or_create_module(name, &shader_source); + let (shader, module_key, entry_point) = shader_info("randn", dtype)?; + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 1, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[out, params_buffer]); @@ -444,18 +489,6 @@ pub fn launch_randn( Ok(()) } -/// Get the kernel name for randint operation. -fn randint_kernel_name(dtype: DType) -> Result<&'static str> { - match dtype { - DType::I32 => Ok("randint_i32"), - DType::U32 => Ok("randint_u32"), - _ => Err(Error::UnsupportedDType { - dtype, - op: "randint", - }), - } -} - /// Launch a randint operation kernel (uniform integers in [low, high)). /// /// # Arguments @@ -478,15 +511,14 @@ pub fn launch_randint( return Ok(()); } - let name = randint_kernel_name(dtype)?; - let shader_source = generate_randint_shader(dtype)?; - let module = cache.get_or_create_module(name, &shader_source); + let (shader, module_key, entry_point) = shader_info("randint", dtype)?; + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 1, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[out, params_buffer]); @@ -514,19 +546,6 @@ pub fn launch_randint( // Repeat Operation // ============================================================================ -/// Get the kernel name for repeat operation. -fn repeat_kernel_name(dtype: DType) -> Result<&'static str> { - match dtype { - DType::F32 => Ok("repeat_f32"), - DType::I32 => Ok("repeat_i32"), - DType::U32 => Ok("repeat_u32"), - _ => Err(Error::UnsupportedDType { - dtype, - op: "repeat", - }), - } -} - /// Launch a repeat operation kernel. /// /// # Arguments @@ -551,15 +570,14 @@ pub fn launch_repeat( return Ok(()); } - let name = repeat_kernel_name(dtype)?; - let shader_source = generate_repeat_shader(dtype)?; - let module = cache.get_or_create_module(name, &shader_source); + let (shader, module_key, entry_point) = shader_info("repeat", dtype)?; + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[src, dst, params_buffer]); @@ -587,16 +605,6 @@ pub fn launch_repeat( // Pad Operation // ============================================================================ -/// Get the kernel name for pad operation. -fn pad_kernel_name(dtype: DType) -> Result<&'static str> { - match dtype { - DType::F32 => Ok("pad_f32"), - DType::I32 => Ok("pad_i32"), - DType::U32 => Ok("pad_u32"), - _ => Err(Error::UnsupportedDType { dtype, op: "pad" }), - } -} - /// Launch a pad operation kernel. /// /// # Arguments @@ -621,15 +629,14 @@ pub fn launch_pad( return Ok(()); } - let name = pad_kernel_name(dtype)?; - let shader_source = generate_pad_shader(dtype)?; - let module = cache.get_or_create_module(name, &shader_source); + let (shader, module_key, entry_point) = shader_info("pad", dtype)?; + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[src, dst, params_buffer]); @@ -655,16 +662,6 @@ pub fn launch_pad( // Roll Operation // ============================================================================ -/// Get the kernel name for roll operation. -fn roll_kernel_name(dtype: DType) -> Result<&'static str> { - match dtype { - DType::F32 => Ok("roll_f32"), - DType::I32 => Ok("roll_i32"), - DType::U32 => Ok("roll_u32"), - _ => Err(Error::UnsupportedDType { dtype, op: "roll" }), - } -} - /// Launch a roll operation kernel. /// /// # Arguments @@ -689,15 +686,14 @@ pub fn launch_roll( return Ok(()); } - let name = roll_kernel_name(dtype)?; - let shader_source = generate_roll_shader(dtype)?; - let module = cache.get_or_create_module(name, &shader_source); + let (shader, module_key, entry_point) = shader_info("roll", dtype)?; + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[src, dst, params_buffer]); @@ -762,8 +758,7 @@ pub fn launch_multinomial_with_replacement( } let name = "multinomial_with_replacement_f32"; - let shader_source = generate_multinomial_with_replacement_shader()?; - let module = cache.get_or_create_module(name, &shader_source); + let module = cache.get_or_create_module(name, MULTINOMIAL_WITH_REPLACEMENT_SHADER_F32); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, @@ -831,8 +826,7 @@ pub fn launch_multinomial_without_replacement( } let name = "multinomial_without_replacement_f32"; - let shader_source = generate_multinomial_without_replacement_shader()?; - let module = cache.get_or_create_module(name, &shader_source); + let module = cache.get_or_create_module(name, MULTINOMIAL_WITHOUT_REPLACEMENT_SHADER_F32); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, diff --git a/src/runtime/wgpu/shaders/slice_assign_f32.wgsl b/src/runtime/wgpu/shaders/slice_assign_f32.wgsl new file mode 100644 index 00000000..74884ead --- /dev/null +++ b/src/runtime/wgpu/shaders/slice_assign_f32.wgsl @@ -0,0 +1,34 @@ +// Auto-generated slice_assign operations for f32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct SliceAssignParams { + outer_size: u32, + dst_dim_size: u32, + src_dim_size: u32, + inner_size: u32, + start: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var src: array; +@group(0) @binding(1) var output: array; +@group(0) @binding(2) var params: SliceAssignParams; + +@compute @workgroup_size(256) +fn slice_assign_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = params.outer_size * params.src_dim_size * params.inner_size; + if (idx >= total) { + return; + } + + let inner_idx = idx % params.inner_size; + let src_dim_idx = (idx / params.inner_size) % params.src_dim_size; + let outer = idx / (params.src_dim_size * params.inner_size); + + let dst_offset = outer * params.dst_dim_size * params.inner_size + (params.start + src_dim_idx) * params.inner_size + inner_idx; + output[dst_offset] = src[idx]; +} diff --git a/src/runtime/wgpu/shaders/slice_assign_i32.wgsl b/src/runtime/wgpu/shaders/slice_assign_i32.wgsl new file mode 100644 index 00000000..cf7b1a92 --- /dev/null +++ b/src/runtime/wgpu/shaders/slice_assign_i32.wgsl @@ -0,0 +1,34 @@ +// Auto-generated slice_assign operations for i32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct SliceAssignParams { + outer_size: u32, + dst_dim_size: u32, + src_dim_size: u32, + inner_size: u32, + start: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var src: array; +@group(0) @binding(1) var output: array; +@group(0) @binding(2) var params: SliceAssignParams; + +@compute @workgroup_size(256) +fn slice_assign_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = params.outer_size * params.src_dim_size * params.inner_size; + if (idx >= total) { + return; + } + + let inner_idx = idx % params.inner_size; + let src_dim_idx = (idx / params.inner_size) % params.src_dim_size; + let outer = idx / (params.src_dim_size * params.inner_size); + + let dst_offset = outer * params.dst_dim_size * params.inner_size + (params.start + src_dim_idx) * params.inner_size + inner_idx; + output[dst_offset] = src[idx]; +} diff --git a/src/runtime/wgpu/shaders/slice_assign_u32.wgsl b/src/runtime/wgpu/shaders/slice_assign_u32.wgsl new file mode 100644 index 00000000..6172fe37 --- /dev/null +++ b/src/runtime/wgpu/shaders/slice_assign_u32.wgsl @@ -0,0 +1,34 @@ +// Auto-generated slice_assign operations for u32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct SliceAssignParams { + outer_size: u32, + dst_dim_size: u32, + src_dim_size: u32, + inner_size: u32, + start: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var src: array; +@group(0) @binding(1) var output: array; +@group(0) @binding(2) var params: SliceAssignParams; + +@compute @workgroup_size(256) +fn slice_assign_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = params.outer_size * params.src_dim_size * params.inner_size; + if (idx >= total) { + return; + } + + let inner_idx = idx % params.inner_size; + let src_dim_idx = (idx / params.inner_size) % params.src_dim_size; + let outer = idx / (params.src_dim_size * params.inner_size); + + let dst_offset = outer * params.dst_dim_size * params.inner_size + (params.start + src_dim_idx) * params.inner_size + inner_idx; + output[dst_offset] = src[idx]; +} diff --git a/src/runtime/wgpu/shaders/sort.rs b/src/runtime/wgpu/shaders/sort.rs index 53cc4497..671232bd 100644 --- a/src/runtime/wgpu/shaders/sort.rs +++ b/src/runtime/wgpu/shaders/sort.rs @@ -1,108 +1,181 @@ -//! Sort operation WGSL kernel launchers +//! Sort operation WGSL kernel launchers. //! -//! Provides launchers for sorting operations including: -//! - Sort, argsort (bitonic sort) -//! - Topk (top-k values and indices) -//! - Searchsorted (binary search) -//! - Nonzero (two-phase: count + gather) -//! - Unique (two-phase: count + extract on sorted input) -//! -//! Multi-dtype support: F32, I32, U32 +//! dtype policy: +//! - sort, sort_values_only, argsort: F32 / I32 / U32 +//! - topk, searchsorted: F32 only +//! - unique, unique_with_counts: F32 / I32 / U32 +//! - nonzero, flat_to_multi_index: F32 / I32 / U32 -use std::collections::HashMap; -use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; +use wgpu::{Buffer, Queue}; + +use super::pipeline::{LayoutKey, PipelineCache, workgroup_count}; +use crate::dtype::DType; +use crate::error::{Error, Result}; // ============================================================================ -// Lock Helpers (Handle Poisoned Locks Gracefully) +// Static shaders — sort ops (F32 / I32 / U32) // ============================================================================ -/// Acquire read lock, recovering from poison if necessary. -fn read_lock(lock: &RwLock) -> RwLockReadGuard<'_, T> { - lock.read().unwrap_or_else(|poisoned| poisoned.into_inner()) -} +const SORT_SHADER_F32: &str = include_str!("sort_f32.wgsl"); +const SORT_SHADER_I32: &str = include_str!("sort_i32.wgsl"); +const SORT_SHADER_U32: &str = include_str!("sort_u32.wgsl"); -/// Acquire write lock, recovering from poison if necessary. -fn write_lock(lock: &RwLock) -> RwLockWriteGuard<'_, T> { - lock.write() - .unwrap_or_else(|poisoned| poisoned.into_inner()) -} +// ============================================================================ +// Static shaders — topk/searchsorted (F32 only) +// ============================================================================ -use wgpu::{Buffer, Queue}; - -use super::generator::{ - generate_count_nonzero_shader, generate_flat_to_multi_index_shader, - generate_gather_nonzero_shader, generate_searchsorted_shader, generate_sort_shader, - generate_topk_shader, generate_unique_shader, generate_unique_with_counts_shader, -}; -use super::pipeline::{LayoutKey, PipelineCache, workgroup_count}; -use crate::dtype::DType; -use crate::error::{Error, Result}; +const TOPK_SHADER_F32: &str = include_str!("topk_f32.wgsl"); +const SEARCHSORTED_SHADER_F32: &str = include_str!("searchsorted_f32.wgsl"); // ============================================================================ -// Shader Module Cache +// Static shaders — data-movement ops (F32 / I32 / U32) // ============================================================================ -static SORT_SHADER_CACHE: RwLock>> = - RwLock::new(None); +const COUNT_NONZERO_SHADER_F32: &str = include_str!("count_nonzero_f32.wgsl"); +const COUNT_NONZERO_SHADER_I32: &str = include_str!("count_nonzero_i32.wgsl"); +const COUNT_NONZERO_SHADER_U32: &str = include_str!("count_nonzero_u32.wgsl"); -fn get_shader(dtype: DType, op: &'static str) -> Result { - // Check cache - { - let cache = read_lock(&SORT_SHADER_CACHE); - if let Some(ref map) = *cache - && let Some(shader) = map.get(&(dtype, op)) - { - return Ok(shader.clone()); - } - } +const GATHER_NONZERO_SHADER_F32: &str = include_str!("gather_nonzero_f32.wgsl"); +const GATHER_NONZERO_SHADER_I32: &str = include_str!("gather_nonzero_i32.wgsl"); +const GATHER_NONZERO_SHADER_U32: &str = include_str!("gather_nonzero_u32.wgsl"); - // Generate shader - let shader = match op { - "sort" => generate_sort_shader(dtype)?, - "topk" => generate_topk_shader(dtype)?, - "searchsorted" => generate_searchsorted_shader(dtype)?, - "count_nonzero" => generate_count_nonzero_shader(dtype)?, - "gather_nonzero" => generate_gather_nonzero_shader(dtype)?, - "unique" => generate_unique_shader(dtype)?, - "flat_to_multi_index" => generate_flat_to_multi_index_shader()?, - _ => { - return Err(Error::InvalidArgument { - arg: "op", - reason: format!("Unknown sort operation: {}", op), - }); - } - }; +const FLAT_TO_MULTI_INDEX_SHADER: &str = include_str!("flat_to_multi_index.wgsl"); - // Cache and return - { - let mut cache = write_lock(&SORT_SHADER_CACHE); - let map = cache.get_or_insert_with(HashMap::new); - map.insert((dtype, op), shader.clone()); - } - Ok(shader) -} +const UNIQUE_WITH_COUNTS_SHADER_F32: &str = include_str!("unique_with_counts_f32.wgsl"); +const UNIQUE_WITH_COUNTS_SHADER_I32: &str = include_str!("unique_with_counts_i32.wgsl"); +const UNIQUE_WITH_COUNTS_SHADER_U32: &str = include_str!("unique_with_counts_u32.wgsl"); -fn module_key(dtype: DType, op: &'static str) -> String { - let suffix = match dtype { - DType::F32 => "f32", - DType::I32 => "i32", - DType::U32 => "u32", - _ => "f32", - }; - format!("{}_{}", op, suffix) +const COUNT_UNIQUE_SHADER_F32: &str = include_str!("count_unique_f32.wgsl"); +const COUNT_UNIQUE_SHADER_I32: &str = include_str!("count_unique_i32.wgsl"); +const COUNT_UNIQUE_SHADER_U32: &str = include_str!("count_unique_u32.wgsl"); + +const EXTRACT_UNIQUE_SHADER_F32: &str = include_str!("extract_unique_f32.wgsl"); +const EXTRACT_UNIQUE_SHADER_I32: &str = include_str!("extract_unique_i32.wgsl"); +const EXTRACT_UNIQUE_SHADER_U32: &str = include_str!("extract_unique_u32.wgsl"); + +// ============================================================================ +// Helpers +// ============================================================================ + +/// Returns (shader, module_key, entry_point) for sort ops. +/// Supports F32/I32/U32 for sort/sort_values_only/argsort, F32 only for topk/searchsorted. +fn sort_math_info( + op: &'static str, + dtype: DType, +) -> Result<(&'static str, &'static str, &'static str)> { + match op { + "sort" | "sort_values_only" | "argsort" => { + let (shader, module_key, _suffix) = match dtype { + DType::F32 => (SORT_SHADER_F32, "sort_f32", "f32"), + DType::I32 => (SORT_SHADER_I32, "sort_i32", "i32"), + DType::U32 => (SORT_SHADER_U32, "sort_u32", "u32"), + _ => return Err(Error::UnsupportedDType { dtype, op }), + }; + let entry_point: &'static str = match (op, dtype) { + ("sort", DType::F32) => "sort_f32", + ("sort", DType::I32) => "sort_i32", + ("sort", DType::U32) => "sort_u32", + ("sort_values_only", DType::F32) => "sort_values_only_f32", + ("sort_values_only", DType::I32) => "sort_values_only_i32", + ("sort_values_only", DType::U32) => "sort_values_only_u32", + ("argsort", DType::F32) => "argsort_f32", + ("argsort", DType::I32) => "argsort_i32", + ("argsort", DType::U32) => "argsort_u32", + _ => unreachable!(), + }; + Ok((shader, module_key, entry_point)) + } + "topk" => { + if dtype != DType::F32 { + return Err(Error::UnsupportedDType { dtype, op }); + } + Ok((TOPK_SHADER_F32, "topk_f32", "topk_f32")) + } + "searchsorted" => { + if dtype != DType::F32 { + return Err(Error::UnsupportedDType { dtype, op }); + } + Ok(( + SEARCHSORTED_SHADER_F32, + "searchsorted_f32", + "searchsorted_f32", + )) + } + _ => Err(Error::UnsupportedDType { dtype, op }), + } } -fn entry_point(op: &str, dtype: DType) -> String { - let suffix = match dtype { - DType::F32 => "f32", - DType::I32 => "i32", - DType::U32 => "u32", - _ => "f32", - }; - format!("{}_{}", op, suffix) +/// Returns (shader, module_key, entry_point) for data-movement ops. F32/I32/U32. +fn sort_data_info( + op: &'static str, + dtype: DType, +) -> Result<(&'static str, &'static str, &'static str)> { + Ok(match (op, dtype) { + ("count_nonzero", DType::F32) => ( + COUNT_NONZERO_SHADER_F32, + "count_nonzero_f32", + "count_nonzero_f32", + ), + ("count_nonzero", DType::I32) => ( + COUNT_NONZERO_SHADER_I32, + "count_nonzero_i32", + "count_nonzero_i32", + ), + ("count_nonzero", DType::U32) => ( + COUNT_NONZERO_SHADER_U32, + "count_nonzero_u32", + "count_nonzero_u32", + ), + ("gather_nonzero", DType::F32) => ( + GATHER_NONZERO_SHADER_F32, + "gather_nonzero_f32", + "gather_nonzero_f32", + ), + ("gather_nonzero", DType::I32) => ( + GATHER_NONZERO_SHADER_I32, + "gather_nonzero_i32", + "gather_nonzero_i32", + ), + ("gather_nonzero", DType::U32) => ( + GATHER_NONZERO_SHADER_U32, + "gather_nonzero_u32", + "gather_nonzero_u32", + ), + ("unique_with_counts", DType::F32) => ( + UNIQUE_WITH_COUNTS_SHADER_F32, + "unique_with_counts_f32", + "mark_boundaries_f32", + ), + ("unique_with_counts", DType::I32) => ( + UNIQUE_WITH_COUNTS_SHADER_I32, + "unique_with_counts_i32", + "mark_boundaries_i32", + ), + ("unique_with_counts", DType::U32) => ( + UNIQUE_WITH_COUNTS_SHADER_U32, + "unique_with_counts_u32", + "mark_boundaries_u32", + ), + ("scatter_unique_with_counts", DType::F32) => ( + UNIQUE_WITH_COUNTS_SHADER_F32, + "unique_with_counts_f32", + "scatter_unique_with_counts_f32", + ), + ("scatter_unique_with_counts", DType::I32) => ( + UNIQUE_WITH_COUNTS_SHADER_I32, + "unique_with_counts_i32", + "scatter_unique_with_counts_i32", + ), + ("scatter_unique_with_counts", DType::U32) => ( + UNIQUE_WITH_COUNTS_SHADER_U32, + "unique_with_counts_u32", + "scatter_unique_with_counts_u32", + ), + _ => return Err(Error::UnsupportedDType { dtype, op }), + }) } -fn check_dtype_supported(dtype: DType, op: &'static str) -> Result<()> { +fn check_data_dtype(dtype: DType, op: &'static str) -> Result<()> { if !matches!(dtype, DType::F32 | DType::I32 | DType::U32) { return Err(Error::UnsupportedDType { dtype, op }); } @@ -125,23 +198,15 @@ pub fn launch_sort( inner_size: usize, dtype: DType, ) -> Result<()> { - check_dtype_supported(dtype, "sort")?; + let (shader, module_key, entry_point) = sort_math_info("sort", dtype)?; - let shader = get_shader(dtype, "sort")?; - let module_name = module_key(dtype, "sort"); - let ep = entry_point("sort", dtype); - - let static_module: &'static str = Box::leak(module_name.into_boxed_str()); - let static_shader: &'static str = Box::leak(shader.into_boxed_str()); - let static_ep: &'static str = Box::leak(ep.into_boxed_str()); - - let module = cache.get_or_create_module(static_module, static_shader); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 3, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(static_module, static_ep, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group( &layout, @@ -179,17 +244,9 @@ pub fn launch_sort_values_only( inner_size: usize, dtype: DType, ) -> Result<()> { - check_dtype_supported(dtype, "sort")?; - - let shader = get_shader(dtype, "sort")?; - let module_name = module_key(dtype, "sort"); - let ep = entry_point("sort_values_only", dtype); + let (shader, module_key, entry_point) = sort_math_info("sort_values_only", dtype)?; - let static_module: &'static str = Box::leak(module_name.into_boxed_str()); - let static_shader: &'static str = Box::leak(shader.into_boxed_str()); - let static_ep: &'static str = Box::leak(ep.into_boxed_str()); - - let module = cache.get_or_create_module(static_module, static_shader); + let module = cache.get_or_create_module(module_key, shader); // Need a 4-buffer layout but only use 3 (input, output, dummy_indices, params) // Actually for values_only we need different layout let layout = cache.get_or_create_layout(LayoutKey { @@ -197,7 +254,7 @@ pub fn launch_sort_values_only( num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(static_module, static_ep, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); // Create dummy indices buffer for the binding let dummy_buf = cache.device().create_buffer(&wgpu::BufferDescriptor { @@ -240,23 +297,15 @@ pub fn launch_argsort( inner_size: usize, dtype: DType, ) -> Result<()> { - check_dtype_supported(dtype, "argsort")?; - - let shader = get_shader(dtype, "sort")?; - let module_name = module_key(dtype, "sort"); - let ep = entry_point("argsort", dtype); - - let static_module: &'static str = Box::leak(module_name.into_boxed_str()); - let static_shader: &'static str = Box::leak(shader.into_boxed_str()); - let static_ep: &'static str = Box::leak(ep.into_boxed_str()); + let (shader, module_key, entry_point) = sort_math_info("argsort", dtype)?; - let module = cache.get_or_create_module(static_module, static_shader); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 3, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(static_module, static_ep, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); // Create dummy values buffer for the binding let dummy_buf = cache.device().create_buffer(&wgpu::BufferDescriptor { @@ -305,23 +354,22 @@ pub fn launch_topk( inner_size: usize, dtype: DType, ) -> Result<()> { - check_dtype_supported(dtype, "topk")?; - - let shader = get_shader(dtype, "topk")?; - let module_name = module_key(dtype, "topk"); - let ep = entry_point("topk", dtype); + if dtype != DType::F32 { + return Err(Error::UnsupportedDType { + dtype, + op: "topk (WebGPU)", + }); + } - let static_module: &'static str = Box::leak(module_name.into_boxed_str()); - let static_shader: &'static str = Box::leak(shader.into_boxed_str()); - let static_ep: &'static str = Box::leak(ep.into_boxed_str()); + let (shader, module_key, entry_point) = sort_math_info("topk", dtype)?; - let module = cache.get_or_create_module(static_module, static_shader); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 3, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(static_module, static_ep, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group( &layout, @@ -363,23 +411,22 @@ pub fn launch_searchsorted( num_values: usize, dtype: DType, ) -> Result<()> { - check_dtype_supported(dtype, "searchsorted")?; - - let shader = get_shader(dtype, "searchsorted")?; - let module_name = module_key(dtype, "searchsorted"); - let ep = entry_point("searchsorted", dtype); + if dtype != DType::F32 { + return Err(Error::UnsupportedDType { + dtype, + op: "searchsorted (WebGPU)", + }); + } - let static_module: &'static str = Box::leak(module_name.into_boxed_str()); - let static_shader: &'static str = Box::leak(shader.into_boxed_str()); - let static_ep: &'static str = Box::leak(ep.into_boxed_str()); + let (shader, module_key, entry_point) = sort_math_info("searchsorted", dtype)?; - let module = cache.get_or_create_module(static_module, static_shader); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 3, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(static_module, static_ep, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[sorted_seq, values, output, params_buffer]); @@ -417,23 +464,17 @@ pub fn launch_count_nonzero( numel: usize, dtype: DType, ) -> Result<()> { - check_dtype_supported(dtype, "count_nonzero")?; + check_data_dtype(dtype, "count_nonzero")?; - let shader = get_shader(dtype, "count_nonzero")?; - let module_name = module_key(dtype, "count_nonzero"); - let ep = entry_point("count_nonzero", dtype); + let (shader, module_key, entry_point) = sort_data_info("count_nonzero", dtype)?; - let static_module: &'static str = Box::leak(module_name.into_boxed_str()); - let static_shader: &'static str = Box::leak(shader.into_boxed_str()); - let static_ep: &'static str = Box::leak(ep.into_boxed_str()); - - let module = cache.get_or_create_module(static_module, static_shader); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(static_module, static_ep, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[input, count_output, params_buffer]); @@ -468,23 +509,17 @@ pub fn launch_gather_nonzero( numel: usize, dtype: DType, ) -> Result<()> { - check_dtype_supported(dtype, "gather_nonzero")?; - - let shader = get_shader(dtype, "gather_nonzero")?; - let module_name = module_key(dtype, "gather_nonzero"); - let ep = entry_point("gather_nonzero", dtype); + check_data_dtype(dtype, "gather_nonzero")?; - let static_module: &'static str = Box::leak(module_name.into_boxed_str()); - let static_shader: &'static str = Box::leak(shader.into_boxed_str()); - let static_ep: &'static str = Box::leak(ep.into_boxed_str()); + let (shader, module_key, entry_point) = sort_data_info("gather_nonzero", dtype)?; - let module = cache.get_or_create_module(static_module, static_shader); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 3, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(static_module, static_ep, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[input, indices_output, counter, params_buffer]); @@ -518,19 +553,18 @@ pub fn launch_flat_to_multi_index( params_buffer: &Buffer, nnz: usize, ) -> Result<()> { - let shader = get_shader(DType::I32, "flat_to_multi_index")?; - - let static_module: &'static str = "flat_to_multi_index"; - let static_shader: &'static str = Box::leak(shader.into_boxed_str()); - - let module = cache.get_or_create_module(static_module, static_shader); + let module = cache.get_or_create_module("flat_to_multi_index", FLAT_TO_MULTI_INDEX_SHADER); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = - cache.get_or_create_pipeline(static_module, "flat_to_multi_index", &module, &layout); + let pipeline = cache.get_or_create_pipeline( + "flat_to_multi_index", + "flat_to_multi_index", + &module, + &layout, + ); let bind_group = cache.create_bind_group(&layout, &[flat_indices, multi_indices, params_buffer]); @@ -569,43 +603,44 @@ pub fn launch_count_unique( numel: usize, dtype: DType, ) -> Result<()> { - check_dtype_supported(dtype, "unique")?; - - let shader = get_shader(dtype, "unique")?; - let module_name = module_key(dtype, "unique"); - let ep = entry_point("count_unique", dtype); - - let static_module: &'static str = Box::leak(module_name.into_boxed_str()); - let static_shader: &'static str = Box::leak(shader.into_boxed_str()); - let static_ep: &'static str = Box::leak(ep.into_boxed_str()); + let (module_key, shader, entry_point) = match dtype { + DType::F32 => ( + "count_unique_f32", + COUNT_UNIQUE_SHADER_F32, + "count_unique_f32", + ), + DType::I32 => ( + "count_unique_i32", + COUNT_UNIQUE_SHADER_I32, + "count_unique_i32", + ), + DType::U32 => ( + "count_unique_u32", + COUNT_UNIQUE_SHADER_U32, + "count_unique_u32", + ), + _ => { + return Err(Error::UnsupportedDType { + dtype, + op: "count_unique", + }); + } + }; - let module = cache.get_or_create_module(static_module, static_shader); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { - num_storage_buffers: 3, + num_storage_buffers: 2, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(static_module, static_ep, &module, &layout); - - // Create dummy output buffer for the binding - let dummy_buf = cache.device().create_buffer(&wgpu::BufferDescriptor { - label: Some("dummy_unique_output"), - size: 4, - usage: wgpu::BufferUsages::STORAGE, - mapped_at_creation: false, - }); - - let bind_group = cache.create_bind_group( - &layout, - &[sorted_input, &dummy_buf, count_output, params_buffer], - ); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); + let bind_group = cache.create_bind_group(&layout, &[sorted_input, count_output, params_buffer]); let mut encoder = cache .device() .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("count_unique"), }); - { let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some("count_unique"), @@ -615,7 +650,6 @@ pub fn launch_count_unique( pass.set_bind_group(0, Some(&bind_group), &[]); pass.dispatch_workgroups(workgroup_count(numel), 1, 1); } - queue.submit(std::iter::once(encoder.finish())); Ok(()) } @@ -631,24 +665,37 @@ pub fn launch_extract_unique( numel: usize, dtype: DType, ) -> Result<()> { - check_dtype_supported(dtype, "unique")?; - - let shader = get_shader(dtype, "unique")?; - let module_name = module_key(dtype, "unique"); - let ep = entry_point("extract_unique", dtype); - - let static_module: &'static str = Box::leak(module_name.into_boxed_str()); - let static_shader: &'static str = Box::leak(shader.into_boxed_str()); - let static_ep: &'static str = Box::leak(ep.into_boxed_str()); + let (module_key, shader, entry_point) = match dtype { + DType::F32 => ( + "extract_unique_f32", + EXTRACT_UNIQUE_SHADER_F32, + "extract_unique_f32", + ), + DType::I32 => ( + "extract_unique_i32", + EXTRACT_UNIQUE_SHADER_I32, + "extract_unique_i32", + ), + DType::U32 => ( + "extract_unique_u32", + EXTRACT_UNIQUE_SHADER_U32, + "extract_unique_u32", + ), + _ => { + return Err(Error::UnsupportedDType { + dtype, + op: "extract_unique", + }); + } + }; - let module = cache.get_or_create_module(static_module, static_shader); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 3, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(static_module, static_ep, &module, &layout); - + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group( &layout, &[sorted_input, unique_output, counter, params_buffer], @@ -659,7 +706,6 @@ pub fn launch_extract_unique( .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("extract_unique"), }); - { let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some("extract_unique"), @@ -669,7 +715,6 @@ pub fn launch_extract_unique( pass.set_bind_group(0, Some(&bind_group), &[]); pass.dispatch_workgroups(workgroup_count(numel), 1, 1); } - queue.submit(std::iter::once(encoder.finish())); Ok(()) } @@ -688,23 +733,17 @@ pub fn launch_mark_boundaries( numel: usize, dtype: DType, ) -> Result<()> { - check_dtype_supported(dtype, "unique_with_counts")?; + check_data_dtype(dtype, "unique_with_counts")?; - let shader = get_shader_unique_with_counts(dtype)?; - let module_name = module_key_unique_with_counts(dtype); - let ep = entry_point("mark_boundaries", dtype); + let (shader, module_key, entry_point) = sort_data_info("unique_with_counts", dtype)?; - let static_module: &'static str = Box::leak(module_name.into_boxed_str()); - let static_shader: &'static str = Box::leak(shader.into_boxed_str()); - let static_ep: &'static str = Box::leak(ep.into_boxed_str()); - - let module = cache.get_or_create_module(static_module, static_shader); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(static_module, static_ep, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[sorted_input, boundary_flags, params_buffer]); @@ -742,23 +781,17 @@ pub fn launch_scatter_unique_with_counts( numel: usize, dtype: DType, ) -> Result<()> { - check_dtype_supported(dtype, "unique_with_counts")?; - - let shader = get_shader_unique_with_counts(dtype)?; - let module_name = module_key_unique_with_counts(dtype); - let ep = entry_point("scatter_unique_with_counts", dtype); + check_data_dtype(dtype, "unique_with_counts")?; - let static_module: &'static str = Box::leak(module_name.into_boxed_str()); - let static_shader: &'static str = Box::leak(shader.into_boxed_str()); - let static_ep: &'static str = Box::leak(ep.into_boxed_str()); + let (shader, module_key, entry_point) = sort_data_info("scatter_unique_with_counts", dtype)?; - let module = cache.get_or_create_module(static_module, static_shader); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 5, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline(static_module, static_ep, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group( &layout, @@ -791,39 +824,3 @@ pub fn launch_scatter_unique_with_counts( queue.submit(std::iter::once(encoder.finish())); Ok(()) } - -// Cache for unique_with_counts shaders -static UNIQUE_COUNTS_SHADER_CACHE: RwLock>> = RwLock::new(None); - -fn get_shader_unique_with_counts(dtype: DType) -> Result { - // Check cache - { - let cache = read_lock(&UNIQUE_COUNTS_SHADER_CACHE); - if let Some(ref map) = *cache - && let Some(shader) = map.get(&dtype) - { - return Ok(shader.clone()); - } - } - - // Generate shader - let shader = generate_unique_with_counts_shader(dtype)?; - - // Cache and return - { - let mut cache = write_lock(&UNIQUE_COUNTS_SHADER_CACHE); - let map = cache.get_or_insert_with(HashMap::new); - map.insert(dtype, shader.clone()); - } - Ok(shader) -} - -fn module_key_unique_with_counts(dtype: DType) -> String { - let suffix = match dtype { - DType::F32 => "f32", - DType::I32 => "i32", - DType::U32 => "u32", - _ => "f32", - }; - format!("unique_with_counts_{}", suffix) -} diff --git a/src/runtime/wgpu/shaders/sort_f32.wgsl b/src/runtime/wgpu/shaders/sort_f32.wgsl new file mode 100644 index 00000000..6111247b --- /dev/null +++ b/src/runtime/wgpu/shaders/sort_f32.wgsl @@ -0,0 +1,277 @@ +// Auto-generated sort operations for f32 + +const WORKGROUP_SIZE: u32 = 256u; +const MAX_SORT_SIZE: u32 = 512u; + +var shared_vals: array; +var shared_idxs: array; + +struct SortParams { + outer_size: u32, + sort_size: u32, + inner_size: u32, + descending: u32, +} + +struct TopkParams { + outer_size: u32, + sort_size: u32, + inner_size: u32, + k: u32, + largest: u32, + sorted: u32, +} + +struct SearchsortedParams { + seq_len: u32, + num_values: u32, + right: u32, + _pad: u32, +} + +struct CountParams { + numel: u32, +} + +@group(0) @binding(0) var sort_input: array; +@group(0) @binding(1) var sort_output: array; +@group(0) @binding(2) var sort_indices: array; +@group(0) @binding(3) var sort_params: SortParams; + +// Comparison helper +fn compare_less_f32(a: f32, b: f32) -> bool { + return a < b; +} + +// Stable comparison: use original index as tiebreaker for equal values +fn compare_less_stable_f32(a: f32, b: f32, idx_a: i32, idx_b: i32) -> bool { + if (a == b) { + return idx_a < idx_b; + } + return a < b; +} + +// Bitonic compare and swap for sort with indices (stable) +fn bitonic_cas_f32(i: u32, j: u32, dir: bool) { + let vi = shared_vals[i]; + let vj = shared_vals[j]; + let ii = shared_idxs[i]; + let ij = shared_idxs[j]; + let swap = select(compare_less_stable_f32(vi, vj, ii, ij), compare_less_stable_f32(vj, vi, ij, ii), dir); + if (swap) { + shared_vals[i] = vj; + shared_vals[j] = vi; + shared_idxs[i] = ij; + shared_idxs[j] = ii; + } +} + +// Bitonic compare and swap for sort values only +fn bitonic_cas_values_f32(i: u32, j: u32, dir: bool) { + let vi = shared_vals[i]; + let vj = shared_vals[j]; + let swap = select(compare_less_f32(vi, vj), compare_less_f32(vj, vi), dir); + if (swap) { + shared_vals[i] = vj; + shared_vals[j] = vi; + } +} + +// Sort with indices - returns both sorted values and original indices +@compute @workgroup_size(256) +fn sort_f32( + @builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3 +) { + let outer_idx = group_id.x; + let inner_idx = group_id.y; + let tid = local_id.x; + + let outer_size = sort_params.outer_size; + let sort_size = sort_params.sort_size; + let inner_size = sort_params.inner_size; + let descending = sort_params.descending != 0u; + + if (outer_idx >= outer_size || inner_idx >= inner_size) { + return; + } + + // Pad to next power of 2 + var n = sort_size; + var p: u32 = 1u; + while (p < n) { + p = p << 1u; + } + n = min(p, MAX_SORT_SIZE); + + // Load data into shared memory + let base_offset = outer_idx * sort_size * inner_size + inner_idx; + for (var i = tid; i < n; i = i + WORKGROUP_SIZE) { + if (i < sort_size) { + let idx = base_offset + i * inner_size; + shared_vals[i] = sort_input[idx]; + shared_idxs[i] = i32(i); + } else { + // Pad with max/min based on sort direction + shared_vals[i] = select(f32(3.402823e+38), f32(-3.402823e+38), descending); + shared_idxs[i] = i32(i); + } + } + workgroupBarrier(); + + // Bitonic sort + for (var k: u32 = 2u; k <= n; k = k << 1u) { + for (var j: u32 = k >> 1u; j > 0u; j = j >> 1u) { + for (var i = tid; i < n / 2u; i = i + WORKGROUP_SIZE) { + // Calculate bitonic network indices + let ij = (i / j) * 2u * j + (i % j); + let ij_pair = ij + j; + + // Direction depends on which half of the network we're in + let ascending_local = ((ij / k) % 2u == 0u) != descending; + + if (ij_pair < n) { + bitonic_cas_f32(ij, ij_pair, ascending_local); + } + } + workgroupBarrier(); + } + } + + // Write sorted values and indices + for (var i = tid; i < sort_size; i = i + WORKGROUP_SIZE) { + let out_idx = base_offset + i * inner_size; + sort_output[out_idx] = shared_vals[i]; + sort_indices[out_idx] = shared_idxs[i]; + } +} + +// Sort values only (no indices) +@compute @workgroup_size(256) +fn sort_values_only_f32( + @builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3 +) { + let outer_idx = group_id.x; + let inner_idx = group_id.y; + let tid = local_id.x; + + let outer_size = sort_params.outer_size; + let sort_size = sort_params.sort_size; + let inner_size = sort_params.inner_size; + let descending = sort_params.descending != 0u; + + if (outer_idx >= outer_size || inner_idx >= inner_size) { + return; + } + + var n = sort_size; + var p: u32 = 1u; + while (p < n) { + p = p << 1u; + } + n = min(p, MAX_SORT_SIZE); + + let base_offset = outer_idx * sort_size * inner_size + inner_idx; + for (var i = tid; i < n; i = i + WORKGROUP_SIZE) { + if (i < sort_size) { + let idx = base_offset + i * inner_size; + shared_vals[i] = sort_input[idx]; + } else { + shared_vals[i] = select(f32(3.402823e+38), f32(-3.402823e+38), descending); + } + } + workgroupBarrier(); + + // Bitonic sort + for (var k: u32 = 2u; k <= n; k = k << 1u) { + for (var j: u32 = k >> 1u; j > 0u; j = j >> 1u) { + for (var i = tid; i < n / 2u; i = i + WORKGROUP_SIZE) { + // Calculate bitonic network indices + let ij = (i / j) * 2u * j + (i % j); + let ij_pair = ij + j; + + // Direction depends on which half of the network we're in + let ascending_local = ((ij / k) % 2u == 0u) != descending; + + if (ij_pair < n) { + bitonic_cas_values_f32(ij, ij_pair, ascending_local); + } + } + workgroupBarrier(); + } + } + + for (var i = tid; i < sort_size; i = i + WORKGROUP_SIZE) { + let out_idx = base_offset + i * inner_size; + sort_output[out_idx] = shared_vals[i]; + } +} + +// Argsort - returns indices only +@compute @workgroup_size(256) +fn argsort_f32( + @builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3 +) { + let outer_idx = group_id.x; + let inner_idx = group_id.y; + let tid = local_id.x; + + let outer_size = sort_params.outer_size; + let sort_size = sort_params.sort_size; + let inner_size = sort_params.inner_size; + let descending = sort_params.descending != 0u; + + if (outer_idx >= outer_size || inner_idx >= inner_size) { + return; + } + + var n = sort_size; + var p: u32 = 1u; + while (p < n) { + p = p << 1u; + } + n = min(p, MAX_SORT_SIZE); + + let base_offset = outer_idx * sort_size * inner_size + inner_idx; + for (var i = tid; i < n; i = i + WORKGROUP_SIZE) { + if (i < sort_size) { + let idx = base_offset + i * inner_size; + shared_vals[i] = sort_input[idx]; + shared_idxs[i] = i32(i); + } else { + shared_vals[i] = select(f32(3.402823e+38), f32(-3.402823e+38), descending); + shared_idxs[i] = i32(i); + } + } + workgroupBarrier(); + + // Bitonic sort + for (var k: u32 = 2u; k <= n; k = k << 1u) { + for (var j: u32 = k >> 1u; j > 0u; j = j >> 1u) { + for (var i = tid; i < n / 2u; i = i + WORKGROUP_SIZE) { + // Calculate bitonic network indices + let ij = (i / j) * 2u * j + (i % j); + let ij_pair = ij + j; + + // Direction depends on which half of the network we're in + let ascending_local = ((ij / k) % 2u == 0u) != descending; + + if (ij_pair < n) { + bitonic_cas_f32(ij, ij_pair, ascending_local); + } + } + workgroupBarrier(); + } + } + + // Write indices only + for (var i = tid; i < sort_size; i = i + WORKGROUP_SIZE) { + let out_idx = base_offset + i * inner_size; + sort_indices[out_idx] = shared_idxs[i]; + } +} diff --git a/src/runtime/wgpu/shaders/sort_i32.wgsl b/src/runtime/wgpu/shaders/sort_i32.wgsl new file mode 100644 index 00000000..8276a560 --- /dev/null +++ b/src/runtime/wgpu/shaders/sort_i32.wgsl @@ -0,0 +1,257 @@ +// Auto-generated sort operations for i32 + +const WORKGROUP_SIZE: u32 = 256u; +const MAX_SORT_SIZE: u32 = 512u; + +var shared_vals: array; +var shared_idxs: array; + +struct SortParams { + outer_size: u32, + sort_size: u32, + inner_size: u32, + descending: u32, +} + +@group(0) @binding(0) var sort_input: array; +@group(0) @binding(1) var sort_output: array; +@group(0) @binding(2) var sort_indices: array; +@group(0) @binding(3) var sort_params: SortParams; + +// Comparison helper +fn compare_less_i32(a: i32, b: i32) -> bool { + return a < b; +} + +// Stable comparison: use original index as tiebreaker for equal values +fn compare_less_stable_i32(a: i32, b: i32, idx_a: i32, idx_b: i32) -> bool { + if (a == b) { + return idx_a < idx_b; + } + return a < b; +} + +// Bitonic compare and swap for sort with indices (stable) +fn bitonic_cas_i32(i: u32, j: u32, dir: bool) { + let vi = shared_vals[i]; + let vj = shared_vals[j]; + let ii = shared_idxs[i]; + let ij = shared_idxs[j]; + let swap = select(compare_less_stable_i32(vi, vj, ii, ij), compare_less_stable_i32(vj, vi, ij, ii), dir); + if (swap) { + shared_vals[i] = vj; + shared_vals[j] = vi; + shared_idxs[i] = ij; + shared_idxs[j] = ii; + } +} + +// Bitonic compare and swap for sort values only +fn bitonic_cas_values_i32(i: u32, j: u32, dir: bool) { + let vi = shared_vals[i]; + let vj = shared_vals[j]; + let swap = select(compare_less_i32(vi, vj), compare_less_i32(vj, vi), dir); + if (swap) { + shared_vals[i] = vj; + shared_vals[j] = vi; + } +} + +// Sort with indices - returns both sorted values and original indices +@compute @workgroup_size(256) +fn sort_i32( + @builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3 +) { + let outer_idx = group_id.x; + let inner_idx = group_id.y; + let tid = local_id.x; + + let outer_size = sort_params.outer_size; + let sort_size = sort_params.sort_size; + let inner_size = sort_params.inner_size; + let descending = sort_params.descending != 0u; + + if (outer_idx >= outer_size || inner_idx >= inner_size) { + return; + } + + // Pad to next power of 2 + var n = sort_size; + var p: u32 = 1u; + while (p < n) { + p = p << 1u; + } + n = min(p, MAX_SORT_SIZE); + + // Load data into shared memory + let base_offset = outer_idx * sort_size * inner_size + inner_idx; + for (var i = tid; i < n; i = i + WORKGROUP_SIZE) { + if (i < sort_size) { + let idx = base_offset + i * inner_size; + shared_vals[i] = sort_input[idx]; + shared_idxs[i] = i32(i); + } else { + // Pad with max/min based on sort direction + shared_vals[i] = select(2147483647i, (-2147483647i - 1i), descending); + shared_idxs[i] = i32(i); + } + } + workgroupBarrier(); + + // Bitonic sort + for (var k: u32 = 2u; k <= n; k = k << 1u) { + for (var j: u32 = k >> 1u; j > 0u; j = j >> 1u) { + for (var i = tid; i < n / 2u; i = i + WORKGROUP_SIZE) { + // Calculate bitonic network indices + let ij = (i / j) * 2u * j + (i % j); + let ij_pair = ij + j; + + // Direction depends on which half of the network we're in + let ascending_local = ((ij / k) % 2u == 0u) != descending; + + if (ij_pair < n) { + bitonic_cas_i32(ij, ij_pair, ascending_local); + } + } + workgroupBarrier(); + } + } + + // Write sorted values and indices + for (var i = tid; i < sort_size; i = i + WORKGROUP_SIZE) { + let out_idx = base_offset + i * inner_size; + sort_output[out_idx] = shared_vals[i]; + sort_indices[out_idx] = shared_idxs[i]; + } +} + +// Sort values only (no indices) +@compute @workgroup_size(256) +fn sort_values_only_i32( + @builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3 +) { + let outer_idx = group_id.x; + let inner_idx = group_id.y; + let tid = local_id.x; + + let outer_size = sort_params.outer_size; + let sort_size = sort_params.sort_size; + let inner_size = sort_params.inner_size; + let descending = sort_params.descending != 0u; + + if (outer_idx >= outer_size || inner_idx >= inner_size) { + return; + } + + var n = sort_size; + var p: u32 = 1u; + while (p < n) { + p = p << 1u; + } + n = min(p, MAX_SORT_SIZE); + + let base_offset = outer_idx * sort_size * inner_size + inner_idx; + for (var i = tid; i < n; i = i + WORKGROUP_SIZE) { + if (i < sort_size) { + let idx = base_offset + i * inner_size; + shared_vals[i] = sort_input[idx]; + } else { + shared_vals[i] = select(2147483647i, (-2147483647i - 1i), descending); + } + } + workgroupBarrier(); + + // Bitonic sort + for (var k: u32 = 2u; k <= n; k = k << 1u) { + for (var j: u32 = k >> 1u; j > 0u; j = j >> 1u) { + for (var i = tid; i < n / 2u; i = i + WORKGROUP_SIZE) { + // Calculate bitonic network indices + let ij = (i / j) * 2u * j + (i % j); + let ij_pair = ij + j; + + // Direction depends on which half of the network we're in + let ascending_local = ((ij / k) % 2u == 0u) != descending; + + if (ij_pair < n) { + bitonic_cas_values_i32(ij, ij_pair, ascending_local); + } + } + workgroupBarrier(); + } + } + + for (var i = tid; i < sort_size; i = i + WORKGROUP_SIZE) { + let out_idx = base_offset + i * inner_size; + sort_output[out_idx] = shared_vals[i]; + } +} + +// Argsort - returns indices only +@compute @workgroup_size(256) +fn argsort_i32( + @builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3 +) { + let outer_idx = group_id.x; + let inner_idx = group_id.y; + let tid = local_id.x; + + let outer_size = sort_params.outer_size; + let sort_size = sort_params.sort_size; + let inner_size = sort_params.inner_size; + let descending = sort_params.descending != 0u; + + if (outer_idx >= outer_size || inner_idx >= inner_size) { + return; + } + + var n = sort_size; + var p: u32 = 1u; + while (p < n) { + p = p << 1u; + } + n = min(p, MAX_SORT_SIZE); + + let base_offset = outer_idx * sort_size * inner_size + inner_idx; + for (var i = tid; i < n; i = i + WORKGROUP_SIZE) { + if (i < sort_size) { + let idx = base_offset + i * inner_size; + shared_vals[i] = sort_input[idx]; + shared_idxs[i] = i32(i); + } else { + shared_vals[i] = select(2147483647i, (-2147483647i - 1i), descending); + shared_idxs[i] = i32(i); + } + } + workgroupBarrier(); + + // Bitonic sort + for (var k: u32 = 2u; k <= n; k = k << 1u) { + for (var j: u32 = k >> 1u; j > 0u; j = j >> 1u) { + for (var i = tid; i < n / 2u; i = i + WORKGROUP_SIZE) { + // Calculate bitonic network indices + let ij = (i / j) * 2u * j + (i % j); + let ij_pair = ij + j; + + // Direction depends on which half of the network we're in + let ascending_local = ((ij / k) % 2u == 0u) != descending; + + if (ij_pair < n) { + bitonic_cas_i32(ij, ij_pair, ascending_local); + } + } + workgroupBarrier(); + } + } + + // Write indices only + for (var i = tid; i < sort_size; i = i + WORKGROUP_SIZE) { + let out_idx = base_offset + i * inner_size; + sort_indices[out_idx] = shared_idxs[i]; + } +} diff --git a/src/runtime/wgpu/shaders/sort_u32.wgsl b/src/runtime/wgpu/shaders/sort_u32.wgsl new file mode 100644 index 00000000..35b18d99 --- /dev/null +++ b/src/runtime/wgpu/shaders/sort_u32.wgsl @@ -0,0 +1,257 @@ +// Auto-generated sort operations for u32 + +const WORKGROUP_SIZE: u32 = 256u; +const MAX_SORT_SIZE: u32 = 512u; + +var shared_vals: array; +var shared_idxs: array; + +struct SortParams { + outer_size: u32, + sort_size: u32, + inner_size: u32, + descending: u32, +} + +@group(0) @binding(0) var sort_input: array; +@group(0) @binding(1) var sort_output: array; +@group(0) @binding(2) var sort_indices: array; +@group(0) @binding(3) var sort_params: SortParams; + +// Comparison helper +fn compare_less_u32(a: u32, b: u32) -> bool { + return a < b; +} + +// Stable comparison: use original index as tiebreaker for equal values +fn compare_less_stable_u32(a: u32, b: u32, idx_a: i32, idx_b: i32) -> bool { + if (a == b) { + return idx_a < idx_b; + } + return a < b; +} + +// Bitonic compare and swap for sort with indices (stable) +fn bitonic_cas_u32(i: u32, j: u32, dir: bool) { + let vi = shared_vals[i]; + let vj = shared_vals[j]; + let ii = shared_idxs[i]; + let ij = shared_idxs[j]; + let swap = select(compare_less_stable_u32(vi, vj, ii, ij), compare_less_stable_u32(vj, vi, ij, ii), dir); + if (swap) { + shared_vals[i] = vj; + shared_vals[j] = vi; + shared_idxs[i] = ij; + shared_idxs[j] = ii; + } +} + +// Bitonic compare and swap for sort values only +fn bitonic_cas_values_u32(i: u32, j: u32, dir: bool) { + let vi = shared_vals[i]; + let vj = shared_vals[j]; + let swap = select(compare_less_u32(vi, vj), compare_less_u32(vj, vi), dir); + if (swap) { + shared_vals[i] = vj; + shared_vals[j] = vi; + } +} + +// Sort with indices - returns both sorted values and original indices +@compute @workgroup_size(256) +fn sort_u32( + @builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3 +) { + let outer_idx = group_id.x; + let inner_idx = group_id.y; + let tid = local_id.x; + + let outer_size = sort_params.outer_size; + let sort_size = sort_params.sort_size; + let inner_size = sort_params.inner_size; + let descending = sort_params.descending != 0u; + + if (outer_idx >= outer_size || inner_idx >= inner_size) { + return; + } + + // Pad to next power of 2 + var n = sort_size; + var p: u32 = 1u; + while (p < n) { + p = p << 1u; + } + n = min(p, MAX_SORT_SIZE); + + // Load data into shared memory + let base_offset = outer_idx * sort_size * inner_size + inner_idx; + for (var i = tid; i < n; i = i + WORKGROUP_SIZE) { + if (i < sort_size) { + let idx = base_offset + i * inner_size; + shared_vals[i] = sort_input[idx]; + shared_idxs[i] = i32(i); + } else { + // Pad with max/min based on sort direction + shared_vals[i] = select(4294967295u, 0u, descending); + shared_idxs[i] = i32(i); + } + } + workgroupBarrier(); + + // Bitonic sort + for (var k: u32 = 2u; k <= n; k = k << 1u) { + for (var j: u32 = k >> 1u; j > 0u; j = j >> 1u) { + for (var i = tid; i < n / 2u; i = i + WORKGROUP_SIZE) { + // Calculate bitonic network indices + let ij = (i / j) * 2u * j + (i % j); + let ij_pair = ij + j; + + // Direction depends on which half of the network we're in + let ascending_local = ((ij / k) % 2u == 0u) != descending; + + if (ij_pair < n) { + bitonic_cas_u32(ij, ij_pair, ascending_local); + } + } + workgroupBarrier(); + } + } + + // Write sorted values and indices + for (var i = tid; i < sort_size; i = i + WORKGROUP_SIZE) { + let out_idx = base_offset + i * inner_size; + sort_output[out_idx] = shared_vals[i]; + sort_indices[out_idx] = shared_idxs[i]; + } +} + +// Sort values only (no indices) +@compute @workgroup_size(256) +fn sort_values_only_u32( + @builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3 +) { + let outer_idx = group_id.x; + let inner_idx = group_id.y; + let tid = local_id.x; + + let outer_size = sort_params.outer_size; + let sort_size = sort_params.sort_size; + let inner_size = sort_params.inner_size; + let descending = sort_params.descending != 0u; + + if (outer_idx >= outer_size || inner_idx >= inner_size) { + return; + } + + var n = sort_size; + var p: u32 = 1u; + while (p < n) { + p = p << 1u; + } + n = min(p, MAX_SORT_SIZE); + + let base_offset = outer_idx * sort_size * inner_size + inner_idx; + for (var i = tid; i < n; i = i + WORKGROUP_SIZE) { + if (i < sort_size) { + let idx = base_offset + i * inner_size; + shared_vals[i] = sort_input[idx]; + } else { + shared_vals[i] = select(4294967295u, 0u, descending); + } + } + workgroupBarrier(); + + // Bitonic sort + for (var k: u32 = 2u; k <= n; k = k << 1u) { + for (var j: u32 = k >> 1u; j > 0u; j = j >> 1u) { + for (var i = tid; i < n / 2u; i = i + WORKGROUP_SIZE) { + // Calculate bitonic network indices + let ij = (i / j) * 2u * j + (i % j); + let ij_pair = ij + j; + + // Direction depends on which half of the network we're in + let ascending_local = ((ij / k) % 2u == 0u) != descending; + + if (ij_pair < n) { + bitonic_cas_values_u32(ij, ij_pair, ascending_local); + } + } + workgroupBarrier(); + } + } + + for (var i = tid; i < sort_size; i = i + WORKGROUP_SIZE) { + let out_idx = base_offset + i * inner_size; + sort_output[out_idx] = shared_vals[i]; + } +} + +// Argsort - returns indices only +@compute @workgroup_size(256) +fn argsort_u32( + @builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3 +) { + let outer_idx = group_id.x; + let inner_idx = group_id.y; + let tid = local_id.x; + + let outer_size = sort_params.outer_size; + let sort_size = sort_params.sort_size; + let inner_size = sort_params.inner_size; + let descending = sort_params.descending != 0u; + + if (outer_idx >= outer_size || inner_idx >= inner_size) { + return; + } + + var n = sort_size; + var p: u32 = 1u; + while (p < n) { + p = p << 1u; + } + n = min(p, MAX_SORT_SIZE); + + let base_offset = outer_idx * sort_size * inner_size + inner_idx; + for (var i = tid; i < n; i = i + WORKGROUP_SIZE) { + if (i < sort_size) { + let idx = base_offset + i * inner_size; + shared_vals[i] = sort_input[idx]; + shared_idxs[i] = i32(i); + } else { + shared_vals[i] = select(4294967295u, 0u, descending); + shared_idxs[i] = i32(i); + } + } + workgroupBarrier(); + + // Bitonic sort + for (var k: u32 = 2u; k <= n; k = k << 1u) { + for (var j: u32 = k >> 1u; j > 0u; j = j >> 1u) { + for (var i = tid; i < n / 2u; i = i + WORKGROUP_SIZE) { + // Calculate bitonic network indices + let ij = (i / j) * 2u * j + (i % j); + let ij_pair = ij + j; + + // Direction depends on which half of the network we're in + let ascending_local = ((ij / k) % 2u == 0u) != descending; + + if (ij_pair < n) { + bitonic_cas_u32(ij, ij_pair, ascending_local); + } + } + workgroupBarrier(); + } + } + + // Write indices only + for (var i = tid; i < sort_size; i = i + WORKGROUP_SIZE) { + let out_idx = base_offset + i * inner_size; + sort_indices[out_idx] = shared_idxs[i]; + } +} diff --git a/src/runtime/wgpu/shaders/sparse_24.rs b/src/runtime/wgpu/shaders/sparse_24.rs new file mode 100644 index 00000000..63b6f0a4 --- /dev/null +++ b/src/runtime/wgpu/shaders/sparse_24.rs @@ -0,0 +1,117 @@ +//! WGSL shader launchers for 2:4 structured sparsity operations + +use wgpu::{Buffer, Queue}; + +use super::pipeline::{LayoutKey, PipelineCache, workgroup_count}; +use crate::error::Result; + +const PRUNE_SHADER: &str = include_str!("sparse_24_prune.wgsl"); +const DECOMPRESS_SHADER: &str = include_str!("sparse_24_decompress.wgsl"); + +/// Parameters for 2:4 sparse operations (matches WGSL Params struct) +#[repr(C)] +#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)] +pub struct Sparse24Params { + /// Total number of 2:4 groups across all rows (m * num_groups_per_row). + pub total_groups: u32, + /// Number of groups per row (k / 4). + pub num_groups_per_row: u32, + /// Number of metadata columns per row. + pub meta_cols: u32, + /// Half of the K dimension (k / 2), i.e. number of non-zero values per row. + pub half_k: u32, + /// Full K dimension of the dense matrix. + pub k: u32, + /// Padding to satisfy WGSL 16-byte uniform alignment. + pub _pad0: u32, + /// Padding to satisfy WGSL 16-byte uniform alignment. + pub _pad1: u32, + /// Padding to satisfy WGSL 16-byte uniform alignment. + pub _pad2: u32, +} + +/// Launch prune-to-2:4 shader. +pub fn launch_sparse_24_prune( + cache: &PipelineCache, + queue: &Queue, + dense: &Buffer, + compressed: &Buffer, + metadata: &Buffer, + params_buffer: &Buffer, + total_groups: usize, +) -> Result<()> { + let module = cache.get_or_create_module("sparse_24_prune", PRUNE_SHADER); + let layout = cache.get_or_create_layout(LayoutKey { + num_storage_buffers: 3, + num_uniform_buffers: 1, + num_readonly_storage: 1, + }); + let pipeline = cache.get_or_create_dynamic_pipeline( + "sparse_24_prune", + "sparse_24_prune_f32", + &module, + &layout, + ); + let bind_group = + cache.create_bind_group(&layout, &[dense, compressed, metadata, params_buffer]); + + let mut encoder = cache + .device() + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("sparse_24_prune"), + }); + { + let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some("sparse_24_prune"), + timestamp_writes: None, + }); + pass.set_pipeline(&pipeline); + pass.set_bind_group(0, Some(&bind_group), &[]); + pass.dispatch_workgroups(workgroup_count(total_groups), 1, 1); + } + queue.submit(std::iter::once(encoder.finish())); + Ok(()) +} + +/// Launch decompress-from-2:4 shader. +pub fn launch_sparse_24_decompress( + cache: &PipelineCache, + queue: &Queue, + compressed: &Buffer, + metadata: &Buffer, + dense: &Buffer, + params_buffer: &Buffer, + total_groups: usize, +) -> Result<()> { + let module = cache.get_or_create_module("sparse_24_decompress", DECOMPRESS_SHADER); + let layout = cache.get_or_create_layout(LayoutKey { + num_storage_buffers: 3, + num_uniform_buffers: 1, + num_readonly_storage: 2, + }); + let pipeline = cache.get_or_create_dynamic_pipeline( + "sparse_24_decompress", + "sparse_24_decompress_f32", + &module, + &layout, + ); + let bind_group = + cache.create_bind_group(&layout, &[compressed, metadata, dense, params_buffer]); + + let mut encoder = cache + .device() + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("sparse_24_decompress"), + }); + { + let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some("sparse_24_decompress"), + timestamp_writes: None, + }); + pass.set_pipeline(&pipeline); + pass.set_bind_group(0, Some(&bind_group), &[]); + pass.dispatch_workgroups(workgroup_count(total_groups), 1, 1); + } + queue.submit(std::iter::once(encoder.finish())); + Ok(()) +} diff --git a/src/runtime/wgpu/shaders/sparse_24_decompress.wgsl b/src/runtime/wgpu/shaders/sparse_24_decompress.wgsl new file mode 100644 index 00000000..7115d3e8 --- /dev/null +++ b/src/runtime/wgpu/shaders/sparse_24_decompress.wgsl @@ -0,0 +1,61 @@ +// 2:4 Structured Sparsity: Decompress to dense format (F32 only) +// +// Reconstructs dense matrix from compressed 2:4 format. +// One workgroup thread per group of 4 output elements. + +struct Params { + total_groups: u32, + num_groups_per_row: u32, + meta_cols: u32, + half_k: u32, + k: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var compressed: array; +@group(0) @binding(1) var metadata: array; +@group(0) @binding(2) var dense: array; +@group(0) @binding(3) var params: Params; + +@compute @workgroup_size(256) +fn sparse_24_decompress_f32(@builtin(global_invocation_id) gid: vec3) { + let tid = gid.x; + if (tid >= params.total_groups) { + return; + } + + let row = tid / params.num_groups_per_row; + let g = tid % params.num_groups_per_row; + + // Read metadata + let word_idx = g / 8u; + let nibble_idx = g % 8u; + let word = metadata[row * params.meta_cols + word_idx]; + let mask = (word >> (nibble_idx * 4u)) & 0xFu; + + // Read 2 compressed values + let in_base = row * params.half_k + g * 2u; + let v0 = compressed[in_base]; + let v1 = compressed[in_base + 1u]; + + // Write to dense + let out_base = row * params.k + g * 4u; + dense[out_base] = 0.0; + dense[out_base + 1u] = 0.0; + dense[out_base + 2u] = 0.0; + dense[out_base + 3u] = 0.0; + + var val_idx: u32 = 0u; + for (var bit: u32 = 0u; bit < 4u; bit = bit + 1u) { + if ((mask & (1u << bit)) != 0u) { + if (val_idx == 0u) { + dense[out_base + bit] = v0; + } else { + dense[out_base + bit] = v1; + } + val_idx = val_idx + 1u; + } + } +} diff --git a/src/runtime/wgpu/shaders/sparse_24_prune.wgsl b/src/runtime/wgpu/shaders/sparse_24_prune.wgsl new file mode 100644 index 00000000..d5718de7 --- /dev/null +++ b/src/runtime/wgpu/shaders/sparse_24_prune.wgsl @@ -0,0 +1,86 @@ +// 2:4 Structured Sparsity: Prune to 2:4 format (F32 only) +// +// For each group of 4 consecutive elements along K, keeps the 2 with largest magnitude. +// One workgroup thread per group. + +struct Params { + total_groups: u32, + num_groups_per_row: u32, + meta_cols: u32, + half_k: u32, + k: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var dense: array; +@group(0) @binding(1) var compressed: array; +@group(0) @binding(2) var metadata: array>; +@group(0) @binding(3) var params: Params; + +@compute @workgroup_size(256) +fn sparse_24_prune_f32(@builtin(global_invocation_id) gid: vec3) { + let tid = gid.x; + if (tid >= params.total_groups) { + return; + } + + let row = tid / params.num_groups_per_row; + let g = tid % params.num_groups_per_row; + let base = row * params.k + g * 4u; + + // Load 4 values + let v0 = dense[base]; + let v1 = dense[base + 1u]; + let v2 = dense[base + 2u]; + let v3 = dense[base + 3u]; + + // Compute magnitudes + let m0 = abs(v0); + let m1 = abs(v1); + let m2 = abs(v2); + let m3 = abs(v3); + + // Find top-2 by magnitude using selection network + var idx0: u32 = 0u; + var idx1: u32 = 1u; + var mag0 = m0; + var mag1 = m1; + + if (mag1 > mag0) { + let ti = idx0; idx0 = idx1; idx1 = ti; + let tf = mag0; mag0 = mag1; mag1 = tf; + } + + if (m2 > mag1) { + idx1 = 2u; mag1 = m2; + if (mag1 > mag0) { + let ti = idx0; idx0 = idx1; idx1 = ti; + let tf = mag0; mag0 = mag1; mag1 = tf; + } + } + + if (m3 > mag1) { + idx1 = 3u; mag1 = m3; + if (mag1 > mag0) { + let ti = idx0; idx0 = idx1; idx1 = ti; + } + } + + let first = min(idx0, idx1); + let second = max(idx0, idx1); + + // Write compressed values + let out_base = row * params.half_k + g * 2u; + let vals = array(v0, v1, v2, v3); + compressed[out_base] = vals[first]; + compressed[out_base + 1u] = vals[second]; + + // Build 4-bit bitmask and atomically OR into metadata + let mask = (1u << first) | (1u << second); + let word_idx = g / 8u; + let nibble_idx = g % 8u; + let meta_offset = row * params.meta_cols + word_idx; + atomicOr(&metadata[meta_offset], mask << (nibble_idx * 4u)); +} diff --git a/src/runtime/wgpu/shaders/sparse_algorithms_f32.wgsl b/src/runtime/wgpu/shaders/sparse_algorithms_f32.wgsl new file mode 100644 index 00000000..aed60f0c --- /dev/null +++ b/src/runtime/wgpu/shaders/sparse_algorithms_f32.wgsl @@ -0,0 +1,197 @@ +// Sparse Algorithm Shaders - F32 +// +// Column-Parallel Dense x Sparse Matrix Multiplication (DSMM) +// Sparse x Sparse Matrix Multiplication (SpGEMM) - symbolic, accumulate, scatter phases + +// ============================================================================ +// DSMM: C = A * B (Dense A [M,K] x Sparse B CSC [K,N] -> Dense C [M,N]) +// Each thread computes one element C[row, col] +// ============================================================================ + +struct DsmmParams { + m: u32, + k: u32, + n: u32, + _pad: u32, +} + +@group(0) @binding(0) var dsmm_a: array; +@group(0) @binding(1) var dsmm_col_ptrs: array; +@group(0) @binding(2) var dsmm_row_indices: array; +@group(0) @binding(3) var dsmm_b_values: array; +@group(0) @binding(4) var dsmm_c: array; +@group(0) @binding(5) var dsmm_params: DsmmParams; + +@compute @workgroup_size(256) +fn dsmm_csc_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = dsmm_params.m * dsmm_params.n; + if (idx >= total) { + return; + } + + let row = idx / dsmm_params.n; + let col = idx % dsmm_params.n; + + let col_start = dsmm_col_ptrs[col]; + let col_end = dsmm_col_ptrs[col + 1u]; + + var sum: f32 = 0.0; + for (var j: i32 = col_start; j < col_end; j = j + 1) { + let k = dsmm_row_indices[j]; + let b_val = dsmm_b_values[j]; + let a_idx = row * dsmm_params.k + u32(k); + sum = sum + dsmm_a[a_idx] * b_val; + } + + dsmm_c[idx] = sum; +} + +// ============================================================================ +// SpGEMM Symbolic Phase: count NNZ per output row +// CSR A [M,K] x CSR B [K,N] -> row_nnz[M] +// Uses bitmap for small N +// ============================================================================ + +struct SymbolicParams { + m: u32, + n: u32, + _pad0: u32, + _pad1: u32, +} + +@group(0) @binding(0) var sym_a_row_ptrs: array; +@group(0) @binding(1) var sym_a_col_indices: array; +@group(0) @binding(2) var sym_b_row_ptrs: array; +@group(0) @binding(3) var sym_b_col_indices: array; +@group(0) @binding(4) var sym_row_nnz: array; +@group(0) @binding(5) var sym_bitmap: array>; +@group(0) @binding(6) var sym_params: SymbolicParams; + +@compute @workgroup_size(256) +fn spgemm_symbolic_f32(@builtin(global_invocation_id) gid: vec3) { + let row = gid.x; + if (row >= sym_params.m) { + return; + } + + let words_per_row = (sym_params.n + 31u) / 32u; + let bitmap_offset = row * words_per_row; + + for (var w: u32 = 0u; w < words_per_row; w = w + 1u) { + atomicStore(&sym_bitmap[bitmap_offset + w], 0u); + } + + let a_start = sym_a_row_ptrs[row]; + let a_end = sym_a_row_ptrs[row + 1u]; + + for (var ai: i32 = a_start; ai < a_end; ai = ai + 1) { + let k = sym_a_col_indices[ai]; + + let b_start = sym_b_row_ptrs[k]; + let b_end = sym_b_row_ptrs[k + 1]; + + for (var bi: i32 = b_start; bi < b_end; bi = bi + 1) { + let j = sym_b_col_indices[bi]; + let word_idx = bitmap_offset + u32(j) / 32u; + let bit_idx = u32(j) % 32u; + atomicOr(&sym_bitmap[word_idx], 1u << bit_idx); + } + } + + var count: i32 = 0; + for (var w: u32 = 0u; w < words_per_row; w = w + 1u) { + let word = atomicLoad(&sym_bitmap[bitmap_offset + w]); + count = count + i32(countOneBits(word)); + } + + sym_row_nnz[row] = count; +} + +// ============================================================================ +// SpGEMM Accumulate Phase +// CSR A [M,K] x CSR B [K,N] -> dense row accumulators +// ============================================================================ + +struct SpgemmParams { + m: u32, + n: u32, + _pad0: u32, + _pad1: u32, +} + +@group(0) @binding(0) var accum_a_row_ptrs: array; +@group(0) @binding(1) var accum_a_col_indices: array; +@group(0) @binding(2) var accum_a_values: array; +@group(0) @binding(3) var accum_b_row_ptrs: array; +@group(0) @binding(4) var accum_b_col_indices: array; +@group(0) @binding(5) var accum_b_values: array; +@group(0) @binding(6) var accum_dense: array; +@group(0) @binding(7) var accum_flags: array; +@group(0) @binding(8) var accum_params: SpgemmParams; + +@compute @workgroup_size(256) +fn spgemm_accumulate_f32(@builtin(global_invocation_id) gid: vec3) { + let row = gid.x; + if (row >= accum_params.m) { + return; + } + + let accum_offset = row * accum_params.n; + + for (var col: u32 = 0u; col < accum_params.n; col = col + 1u) { + accum_dense[accum_offset + col] = 0.0; + accum_flags[accum_offset + col] = 0u; + } + + let a_start = accum_a_row_ptrs[row]; + let a_end = accum_a_row_ptrs[row + 1u]; + + for (var ai: i32 = a_start; ai < a_end; ai = ai + 1) { + let k = accum_a_col_indices[ai]; + let a_val = accum_a_values[ai]; + + let b_start = accum_b_row_ptrs[k]; + let b_end = accum_b_row_ptrs[k + 1]; + + for (var bi: i32 = b_start; bi < b_end; bi = bi + 1) { + let j = accum_b_col_indices[bi]; + let b_val = accum_b_values[bi]; + let idx = accum_offset + u32(j); + accum_dense[idx] = accum_dense[idx] + a_val * b_val; + accum_flags[idx] = 1u; + } + } +} + +// ============================================================================ +// SpGEMM Scatter Phase +// Compact dense row accumulators into CSR output arrays +// ============================================================================ + +@group(0) @binding(0) var scatter_c_row_ptrs: array; +@group(0) @binding(1) var scatter_accum: array; +@group(0) @binding(2) var scatter_flags: array; +@group(0) @binding(3) var scatter_c_col_indices: array; +@group(0) @binding(4) var scatter_c_values: array; +@group(0) @binding(5) var scatter_params: SpgemmParams; + +@compute @workgroup_size(256) +fn spgemm_scatter_f32(@builtin(global_invocation_id) gid: vec3) { + let row = gid.x; + if (row >= scatter_params.m) { + return; + } + + let accum_offset = row * scatter_params.n; + var write_idx: i32 = scatter_c_row_ptrs[row]; + + for (var col: u32 = 0u; col < scatter_params.n; col = col + 1u) { + let idx = accum_offset + col; + if (scatter_flags[idx] != 0u) { + scatter_c_col_indices[write_idx] = i32(col); + scatter_c_values[write_idx] = scatter_accum[idx]; + write_idx = write_idx + 1; + } + } +} diff --git a/src/runtime/wgpu/shaders/sparse_algorithms_launcher.rs b/src/runtime/wgpu/shaders/sparse_algorithms_launcher.rs index 758985b3..fcefefc3 100644 --- a/src/runtime/wgpu/shaders/sparse_algorithms_launcher.rs +++ b/src/runtime/wgpu/shaders/sparse_algorithms_launcher.rs @@ -6,14 +6,21 @@ use wgpu::{Buffer, Queue}; -use super::generator::dtype_suffix; -use super::generator::sparse_algorithms::{ - generate_dsmm_csc_shader, generate_spgemm_accumulate_shader, generate_spgemm_scatter_shader, - generate_spgemm_symbolic_shader, -}; use super::pipeline::{LayoutKey, PipelineCache, workgroup_count}; use crate::dtype::DType; -use crate::error::Result; +use crate::error::{Error, Result}; + +const SPARSE_ALGORITHMS_F32: &str = include_str!("sparse_algorithms_f32.wgsl"); + +fn algorithms_shader_info(dtype: DType) -> Result<(&'static str, &'static str)> { + match dtype { + DType::F32 => Ok((SPARSE_ALGORITHMS_F32, "sparse_algorithms_f32")), + _ => Err(Error::UnsupportedDType { + dtype, + op: "sparse_algorithms (WebGPU)", + }), + } +} /// Launch DSMM (Dense × Sparse) kernel: C = A * B /// @@ -40,12 +47,9 @@ pub fn launch_dsmm_csc( n: usize, dtype: DType, ) -> Result<()> { - let suffix = dtype_suffix(dtype)?; - let entry_point = format!("dsmm_csc_{}", suffix); + let (shader, module_name) = algorithms_shader_info(dtype)?; - let shader_source = generate_dsmm_csc_shader(dtype)?; - let module_name = format!("dsmm_csc_{}", suffix); - let module = cache.get_or_create_module_from_source(&module_name, &shader_source); + let module = cache.get_or_create_module(module_name, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 5, // a, col_ptrs, row_indices, b_values, c @@ -53,7 +57,7 @@ pub fn launch_dsmm_csc( num_readonly_storage: 4, // a, col_ptrs, row_indices, b_values }); - let pipeline = cache.get_or_create_dynamic_pipeline("dsmm_csc", &entry_point, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_name, "dsmm_csc_f32", &module, &layout); let bind_group = cache.create_bind_group( &layout, @@ -106,12 +110,9 @@ pub fn launch_spgemm_symbolic( m: usize, dtype: DType, ) -> Result<()> { - let suffix = dtype_suffix(dtype)?; - let entry_point = format!("spgemm_symbolic_{}", suffix); + let (shader, module_name) = algorithms_shader_info(dtype)?; - let shader_source = generate_spgemm_symbolic_shader(dtype)?; - let module_name = format!("spgemm_symbolic_{}", suffix); - let module = cache.get_or_create_module_from_source(&module_name, &shader_source); + let module = cache.get_or_create_module(module_name, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 6, // a_row_ptrs, a_col_indices, b_row_ptrs, b_col_indices, row_nnz, bitmap @@ -120,7 +121,7 @@ pub fn launch_spgemm_symbolic( }); let pipeline = - cache.get_or_create_dynamic_pipeline("spgemm_symbolic", &entry_point, &module, &layout); + cache.get_or_create_pipeline(module_name, "spgemm_symbolic_f32", &module, &layout); let bind_group = cache.create_bind_group( &layout, @@ -183,12 +184,9 @@ pub fn launch_spgemm_accumulate( m: usize, dtype: DType, ) -> Result<()> { - let suffix = dtype_suffix(dtype)?; - let entry_point = format!("spgemm_accumulate_{}", suffix); + let (shader, module_name) = algorithms_shader_info(dtype)?; - let shader_source = generate_spgemm_accumulate_shader(dtype)?; - let module_name = format!("spgemm_accumulate_{}", suffix); - let module = cache.get_or_create_module_from_source(&module_name, &shader_source); + let module = cache.get_or_create_module(module_name, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 8, // a_row_ptrs, a_col_indices, a_values, b_row_ptrs, b_col_indices, b_values, accum, flags @@ -197,7 +195,7 @@ pub fn launch_spgemm_accumulate( }); let pipeline = - cache.get_or_create_dynamic_pipeline("spgemm_accumulate", &entry_point, &module, &layout); + cache.get_or_create_pipeline(module_name, "spgemm_accumulate_f32", &module, &layout); let bind_group = cache.create_bind_group( &layout, @@ -247,12 +245,9 @@ pub fn launch_spgemm_scatter( m: usize, dtype: DType, ) -> Result<()> { - let suffix = dtype_suffix(dtype)?; - let entry_point = format!("spgemm_scatter_{}", suffix); + let (shader, module_name) = algorithms_shader_info(dtype)?; - let shader_source = generate_spgemm_scatter_shader(dtype)?; - let module_name = format!("spgemm_scatter_{}", suffix); - let module = cache.get_or_create_module_from_source(&module_name, &shader_source); + let module = cache.get_or_create_module(module_name, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 5, // c_row_ptrs, accum, flags, c_col_indices, c_values @@ -261,7 +256,7 @@ pub fn launch_spgemm_scatter( }); let pipeline = - cache.get_or_create_dynamic_pipeline("spgemm_scatter", &entry_point, &module, &layout); + cache.get_or_create_pipeline(module_name, "spgemm_scatter_f32", &module, &layout); let bind_group = cache.create_bind_group( &layout, @@ -294,45 +289,3 @@ pub fn launch_spgemm_scatter( queue.submit(std::iter::once(encoder.finish())); Ok(()) } - -#[cfg(test)] -mod tests { - use super::super::generator::sparse_algorithms::{ - generate_dsmm_csc_shader, generate_spgemm_accumulate_shader, - generate_spgemm_scatter_shader, generate_spgemm_symbolic_shader, - }; - use super::*; - - fn validate_wgsl_syntax(source: &str) -> std::result::Result<(), String> { - use wgpu::naga::front::wgsl; - let mut frontend = wgsl::Frontend::new(); - frontend - .parse(source) - .map(|_| ()) - .map_err(|e| format!("WGSL parse error: {e}")) - } - - #[test] - fn test_dsmm_csc_shader_syntax_f32() { - let shader = generate_dsmm_csc_shader(DType::F32).unwrap(); - validate_wgsl_syntax(&shader).expect("DSMM shader should be valid WGSL"); - } - - #[test] - fn test_spgemm_symbolic_shader_syntax_f32() { - let shader = generate_spgemm_symbolic_shader(DType::F32).unwrap(); - validate_wgsl_syntax(&shader).expect("SpGEMM symbolic shader should be valid WGSL"); - } - - #[test] - fn test_spgemm_accumulate_shader_syntax_f32() { - let shader = generate_spgemm_accumulate_shader(DType::F32).unwrap(); - validate_wgsl_syntax(&shader).expect("SpGEMM accumulate shader should be valid WGSL"); - } - - #[test] - fn test_spgemm_scatter_shader_syntax_f32() { - let shader = generate_spgemm_scatter_shader(DType::F32).unwrap(); - validate_wgsl_syntax(&shader).expect("SpGEMM scatter shader should be valid WGSL"); - } -} diff --git a/src/runtime/wgpu/shaders/sparse_conversions_f32.wgsl b/src/runtime/wgpu/shaders/sparse_conversions_f32.wgsl new file mode 100644 index 00000000..95809f85 --- /dev/null +++ b/src/runtime/wgpu/shaders/sparse_conversions_f32.wgsl @@ -0,0 +1,252 @@ +// Sparse format conversion shaders - F32 typed operations + +// ============================================================================ +// coo_to_csr_scatter +// ============================================================================ + +struct ScatterParams { + nnz: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var c2r_in_row_indices: array; +@group(0) @binding(1) var c2r_in_col_indices: array; +@group(0) @binding(2) var c2r_in_values: array; +@group(0) @binding(3) var c2r_row_ptrs_atomic: array>; +@group(0) @binding(4) var c2r_out_col_indices: array; +@group(0) @binding(5) var c2r_out_values: array; +@group(0) @binding(6) var c2r_params: ScatterParams; + +@compute @workgroup_size(256) +fn coo_to_csr_scatter(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= c2r_params.nnz) { + return; + } + + let row = c2r_in_row_indices[idx]; + let col = c2r_in_col_indices[idx]; + let val = c2r_in_values[idx]; + + let pos = atomicAdd(&c2r_row_ptrs_atomic[row], 1); + + c2r_out_col_indices[pos] = col; + c2r_out_values[pos] = val; +} + +// ============================================================================ +// coo_to_csc_scatter +// ============================================================================ + +@group(0) @binding(0) var c2c_in_row_indices: array; +@group(0) @binding(1) var c2c_in_col_indices: array; +@group(0) @binding(2) var c2c_in_values: array; +@group(0) @binding(3) var c2c_col_ptrs_atomic: array>; +@group(0) @binding(4) var c2c_out_row_indices: array; +@group(0) @binding(5) var c2c_out_values: array; +@group(0) @binding(6) var c2c_params: ScatterParams; + +@compute @workgroup_size(256) +fn coo_to_csc_scatter(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= c2c_params.nnz) { + return; + } + + let row = c2c_in_row_indices[idx]; + let col = c2c_in_col_indices[idx]; + let val = c2c_in_values[idx]; + + let pos = atomicAdd(&c2c_col_ptrs_atomic[col], 1); + + c2c_out_row_indices[pos] = row; + c2c_out_values[pos] = val; +} + +// ============================================================================ +// csr_to_csc_scatter (transpose) +// ============================================================================ + +struct TransposeRowParams { + nrows: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var r2c_in_row_ptrs: array; +@group(0) @binding(1) var r2c_in_col_indices: array; +@group(0) @binding(2) var r2c_in_values: array; +@group(0) @binding(3) var r2c_col_ptrs_atomic: array>; +@group(0) @binding(4) var r2c_out_row_indices: array; +@group(0) @binding(5) var r2c_out_values: array; +@group(0) @binding(6) var r2c_params: TransposeRowParams; + +@compute @workgroup_size(256) +fn csr_to_csc_scatter(@builtin(global_invocation_id) gid: vec3) { + let row = gid.x; + if (row >= r2c_params.nrows) { + return; + } + + let start = r2c_in_row_ptrs[row]; + let end = r2c_in_row_ptrs[row + 1u]; + + for (var i = start; i < end; i = i + 1) { + let col = r2c_in_col_indices[i]; + let val = r2c_in_values[i]; + + let pos = atomicAdd(&r2c_col_ptrs_atomic[col], 1); + + r2c_out_row_indices[pos] = i32(row); + r2c_out_values[pos] = val; + } +} + +// ============================================================================ +// csc_to_csr_scatter (transpose) +// ============================================================================ + +struct TransposeColParams { + ncols: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var c2r2_in_col_ptrs: array; +@group(0) @binding(1) var c2r2_in_row_indices: array; +@group(0) @binding(2) var c2r2_in_values: array; +@group(0) @binding(3) var c2r2_row_ptrs_atomic: array>; +@group(0) @binding(4) var c2r2_out_col_indices: array; +@group(0) @binding(5) var c2r2_out_values: array; +@group(0) @binding(6) var c2r2_params: TransposeColParams; + +@compute @workgroup_size(256) +fn csc_to_csr_scatter(@builtin(global_invocation_id) gid: vec3) { + let col = gid.x; + if (col >= c2r2_params.ncols) { + return; + } + + let start = c2r2_in_col_ptrs[col]; + let end = c2r2_in_col_ptrs[col + 1u]; + + for (var i = start; i < end; i = i + 1) { + let row = c2r2_in_row_indices[i]; + let val = c2r2_in_values[i]; + + let pos = atomicAdd(&c2r2_row_ptrs_atomic[row], 1); + + c2r2_out_col_indices[pos] = i32(col); + c2r2_out_values[pos] = val; + } +} + +// ============================================================================ +// csr_to_dense +// ============================================================================ + +struct CsrToDenseParams { + nrows: u32, + ncols: u32, + _pad0: u32, + _pad1: u32, +} + +@group(0) @binding(0) var ctd_row_ptrs: array; +@group(0) @binding(1) var ctd_col_indices: array; +@group(0) @binding(2) var ctd_values: array; +@group(0) @binding(3) var ctd_dense: array; +@group(0) @binding(4) var ctd_params: CsrToDenseParams; + +@compute @workgroup_size(256) +fn csr_to_dense(@builtin(global_invocation_id) gid: vec3) { + let row = gid.x; + if (row >= ctd_params.nrows) { + return; + } + + let start = ctd_row_ptrs[row]; + let end = ctd_row_ptrs[row + 1u]; + let ncols = ctd_params.ncols; + + for (var i = start; i < end; i = i + 1) { + let col = u32(ctd_col_indices[i]); + ctd_dense[row * ncols + col] = ctd_values[i]; + } +} + +// ============================================================================ +// count_nonzeros +// ============================================================================ + +struct CountNzParams { + total_elems: u32, + threshold_bits: u32, + _pad0: u32, + _pad1: u32, +} + +@group(0) @binding(0) var cnz_dense: array; +@group(0) @binding(1) var cnz_count: atomic; +@group(0) @binding(2) var cnz_params: CountNzParams; + +@compute @workgroup_size(256) +fn count_nonzeros(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= cnz_params.total_elems) { + return; + } + + let val = cnz_dense[idx]; + let threshold = bitcast(cnz_params.threshold_bits); + let zero_val = f32(0); + + if (abs(val) >= threshold) { + atomicAdd(&cnz_count, 1u); + } +} + +// ============================================================================ +// dense_to_coo_scatter +// ============================================================================ + +struct DenseToCooParams { + nrows: u32, + ncols: u32, + threshold_bits: u32, + _pad0: u32, +} + +@group(0) @binding(0) var dtc_dense: array; +@group(0) @binding(1) var dtc_row_indices: array; +@group(0) @binding(2) var dtc_col_indices: array; +@group(0) @binding(3) var dtc_values: array; +@group(0) @binding(4) var dtc_write_pos: atomic; +@group(0) @binding(5) var dtc_params: DenseToCooParams; + +@compute @workgroup_size(256) +fn dense_to_coo_scatter(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = dtc_params.nrows * dtc_params.ncols; + if (idx >= total) { + return; + } + + let val = dtc_dense[idx]; + let threshold = bitcast(dtc_params.threshold_bits); + + if (abs(val) >= threshold) { + let row = idx / dtc_params.ncols; + let col = idx % dtc_params.ncols; + + let pos = atomicAdd(&dtc_write_pos, 1u); + + dtc_row_indices[pos] = i32(row); + dtc_col_indices[pos] = i32(col); + dtc_values[pos] = val; + } +} diff --git a/src/runtime/wgpu/shaders/sparse_conversions_i32.wgsl b/src/runtime/wgpu/shaders/sparse_conversions_i32.wgsl new file mode 100644 index 00000000..283251ff --- /dev/null +++ b/src/runtime/wgpu/shaders/sparse_conversions_i32.wgsl @@ -0,0 +1,251 @@ +// Sparse format conversion shaders - I32 typed operations + +struct ScatterParams { + nnz: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +// ============================================================================ +// coo_to_csr_scatter +// ============================================================================ + +@group(0) @binding(0) var c2r_in_row_indices: array; +@group(0) @binding(1) var c2r_in_col_indices: array; +@group(0) @binding(2) var c2r_in_values: array; +@group(0) @binding(3) var c2r_row_ptrs_atomic: array>; +@group(0) @binding(4) var c2r_out_col_indices: array; +@group(0) @binding(5) var c2r_out_values: array; +@group(0) @binding(6) var c2r_params: ScatterParams; + +@compute @workgroup_size(256) +fn coo_to_csr_scatter(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= c2r_params.nnz) { + return; + } + + let row = c2r_in_row_indices[idx]; + let col = c2r_in_col_indices[idx]; + let val = c2r_in_values[idx]; + + let pos = atomicAdd(&c2r_row_ptrs_atomic[row], 1); + + c2r_out_col_indices[pos] = col; + c2r_out_values[pos] = val; +} + +// ============================================================================ +// coo_to_csc_scatter +// ============================================================================ + +@group(0) @binding(0) var c2c_in_row_indices: array; +@group(0) @binding(1) var c2c_in_col_indices: array; +@group(0) @binding(2) var c2c_in_values: array; +@group(0) @binding(3) var c2c_col_ptrs_atomic: array>; +@group(0) @binding(4) var c2c_out_row_indices: array; +@group(0) @binding(5) var c2c_out_values: array; +@group(0) @binding(6) var c2c_params: ScatterParams; + +@compute @workgroup_size(256) +fn coo_to_csc_scatter(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= c2c_params.nnz) { + return; + } + + let row = c2c_in_row_indices[idx]; + let col = c2c_in_col_indices[idx]; + let val = c2c_in_values[idx]; + + let pos = atomicAdd(&c2c_col_ptrs_atomic[col], 1); + + c2c_out_row_indices[pos] = row; + c2c_out_values[pos] = val; +} + +// ============================================================================ +// csr_to_csc_scatter (transpose) +// ============================================================================ + +struct TransposeRowParams { + nrows: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var r2c_in_row_ptrs: array; +@group(0) @binding(1) var r2c_in_col_indices: array; +@group(0) @binding(2) var r2c_in_values: array; +@group(0) @binding(3) var r2c_col_ptrs_atomic: array>; +@group(0) @binding(4) var r2c_out_row_indices: array; +@group(0) @binding(5) var r2c_out_values: array; +@group(0) @binding(6) var r2c_params: TransposeRowParams; + +@compute @workgroup_size(256) +fn csr_to_csc_scatter(@builtin(global_invocation_id) gid: vec3) { + let row = gid.x; + if (row >= r2c_params.nrows) { + return; + } + + let start = r2c_in_row_ptrs[row]; + let end = r2c_in_row_ptrs[row + 1u]; + + for (var i = start; i < end; i = i + 1) { + let col = r2c_in_col_indices[i]; + let val = r2c_in_values[i]; + + let pos = atomicAdd(&r2c_col_ptrs_atomic[col], 1); + + r2c_out_row_indices[pos] = i32(row); + r2c_out_values[pos] = val; + } +} + +// ============================================================================ +// csc_to_csr_scatter (transpose) +// ============================================================================ + +struct TransposeColParams { + ncols: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var c2r2_in_col_ptrs: array; +@group(0) @binding(1) var c2r2_in_row_indices: array; +@group(0) @binding(2) var c2r2_in_values: array; +@group(0) @binding(3) var c2r2_row_ptrs_atomic: array>; +@group(0) @binding(4) var c2r2_out_col_indices: array; +@group(0) @binding(5) var c2r2_out_values: array; +@group(0) @binding(6) var c2r2_params: TransposeColParams; + +@compute @workgroup_size(256) +fn csc_to_csr_scatter(@builtin(global_invocation_id) gid: vec3) { + let col = gid.x; + if (col >= c2r2_params.ncols) { + return; + } + + let start = c2r2_in_col_ptrs[col]; + let end = c2r2_in_col_ptrs[col + 1u]; + + for (var i = start; i < end; i = i + 1) { + let row = c2r2_in_row_indices[i]; + let val = c2r2_in_values[i]; + + let pos = atomicAdd(&c2r2_row_ptrs_atomic[row], 1); + + c2r2_out_col_indices[pos] = i32(col); + c2r2_out_values[pos] = val; + } +} + +// ============================================================================ +// csr_to_dense +// ============================================================================ + +struct CsrToDenseParams { + nrows: u32, + ncols: u32, + _pad0: u32, + _pad1: u32, +} + +@group(0) @binding(0) var ctd_row_ptrs: array; +@group(0) @binding(1) var ctd_col_indices: array; +@group(0) @binding(2) var ctd_values: array; +@group(0) @binding(3) var ctd_dense: array; +@group(0) @binding(4) var ctd_params: CsrToDenseParams; + +@compute @workgroup_size(256) +fn csr_to_dense(@builtin(global_invocation_id) gid: vec3) { + let row = gid.x; + if (row >= ctd_params.nrows) { + return; + } + + let start = ctd_row_ptrs[row]; + let end = ctd_row_ptrs[row + 1u]; + let ncols = ctd_params.ncols; + + for (var i = start; i < end; i = i + 1) { + let col = u32(ctd_col_indices[i]); + ctd_dense[row * ncols + col] = ctd_values[i]; + } +} + +// ============================================================================ +// count_nonzeros +// ============================================================================ + +struct CountNzParams { + total_elems: u32, + threshold_bits: u32, + _pad0: u32, + _pad1: u32, +} + +@group(0) @binding(0) var cnz_dense: array; +@group(0) @binding(1) var cnz_count: atomic; +@group(0) @binding(2) var cnz_params: CountNzParams; + +@compute @workgroup_size(256) +fn count_nonzeros(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= cnz_params.total_elems) { + return; + } + + let val = cnz_dense[idx]; + let zero_val = i32(0); + + if (val != zero_val) { + atomicAdd(&cnz_count, 1u); + } +} + +// ============================================================================ +// dense_to_coo_scatter +// ============================================================================ + +struct DenseToCooParams { + nrows: u32, + ncols: u32, + threshold_bits: u32, + _pad0: u32, +} + +@group(0) @binding(0) var dtc_dense: array; +@group(0) @binding(1) var dtc_row_indices: array; +@group(0) @binding(2) var dtc_col_indices: array; +@group(0) @binding(3) var dtc_values: array; +@group(0) @binding(4) var dtc_write_pos: atomic; +@group(0) @binding(5) var dtc_params: DenseToCooParams; + +@compute @workgroup_size(256) +fn dense_to_coo_scatter(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = dtc_params.nrows * dtc_params.ncols; + if (idx >= total) { + return; + } + + let val = dtc_dense[idx]; + let zero_val = i32(0); + + if (val != zero_val) { + let row = idx / dtc_params.ncols; + let col = idx % dtc_params.ncols; + + let pos = atomicAdd(&dtc_write_pos, 1u); + + dtc_row_indices[pos] = i32(row); + dtc_col_indices[pos] = i32(col); + dtc_values[pos] = val; + } +} diff --git a/src/runtime/wgpu/shaders/sparse_conversions_indices.wgsl b/src/runtime/wgpu/shaders/sparse_conversions_indices.wgsl new file mode 100644 index 00000000..40250c63 --- /dev/null +++ b/src/runtime/wgpu/shaders/sparse_conversions_indices.wgsl @@ -0,0 +1,116 @@ +// Sparse format conversion shaders - index-only (type-independent) +// +// expand_row_ptrs: CSR row pointers -> explicit row indices +// expand_col_ptrs: CSC col pointers -> explicit col indices +// histogram: count elements per bucket +// copy_ptrs: copy a pointer array + +// ============================================================================ +// expand_row_ptrs +// ============================================================================ + +struct ExpandRowParams { + nrows: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var erp_row_ptrs: array; +@group(0) @binding(1) var erp_row_indices: array; +@group(0) @binding(2) var erp_params: ExpandRowParams; + +@compute @workgroup_size(256) +fn expand_row_ptrs(@builtin(global_invocation_id) gid: vec3) { + let row = gid.x; + if (row >= erp_params.nrows) { + return; + } + + let start = erp_row_ptrs[row]; + let end = erp_row_ptrs[row + 1u]; + + for (var i = start; i < end; i = i + 1) { + erp_row_indices[i] = i32(row); + } +} + +// ============================================================================ +// expand_col_ptrs +// ============================================================================ + +struct ExpandColParams { + ncols: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var ecp_col_ptrs: array; +@group(0) @binding(1) var ecp_col_indices: array; +@group(0) @binding(2) var ecp_params: ExpandColParams; + +@compute @workgroup_size(256) +fn expand_col_ptrs(@builtin(global_invocation_id) gid: vec3) { + let col = gid.x; + if (col >= ecp_params.ncols) { + return; + } + + let start = ecp_col_ptrs[col]; + let end = ecp_col_ptrs[col + 1u]; + + for (var i = start; i < end; i = i + 1) { + ecp_col_indices[i] = i32(col); + } +} + +// ============================================================================ +// histogram +// ============================================================================ + +struct HistogramParams { + nnz: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var hist_indices: array; +@group(0) @binding(1) var hist_counts: array>; +@group(0) @binding(2) var hist_params: HistogramParams; + +@compute @workgroup_size(256) +fn histogram(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= hist_params.nnz) { + return; + } + + let bucket = hist_indices[idx]; + atomicAdd(&hist_counts[bucket], 1); +} + +// ============================================================================ +// copy_ptrs +// ============================================================================ + +struct CopyPtrsParams { + n: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var cp_src: array; +@group(0) @binding(1) var cp_dst: array; +@group(0) @binding(2) var cp_params: CopyPtrsParams; + +@compute @workgroup_size(256) +fn copy_ptrs(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= cp_params.n) { + return; + } + cp_dst[idx] = cp_src[idx]; +} diff --git a/src/runtime/wgpu/shaders/sparse_conversions_launcher.rs b/src/runtime/wgpu/shaders/sparse_conversions_launcher.rs index ea916954..4b88a87a 100644 --- a/src/runtime/wgpu/shaders/sparse_conversions_launcher.rs +++ b/src/runtime/wgpu/shaders/sparse_conversions_launcher.rs @@ -7,16 +7,28 @@ use wgpu::{Buffer, Queue}; -use super::generator::dtype_suffix; -use super::generator::sparse_conversions::{ - generate_coo_to_csc_scatter_shader, generate_coo_to_csr_scatter_shader, - generate_copy_ptrs_shader, generate_csc_to_csr_scatter_shader, - generate_csr_to_csc_scatter_shader, generate_expand_col_ptrs_shader, - generate_expand_row_ptrs_shader, generate_histogram_shader, -}; use super::pipeline::{LayoutKey, PipelineCache, workgroup_count}; use crate::dtype::DType; -use crate::error::Result; +use crate::error::{Error, Result}; + +// Static WGSL shader sources +const SPARSE_CONVERSIONS_INDICES: &str = include_str!("sparse_conversions_indices.wgsl"); +const SPARSE_CONVERSIONS_F32: &str = include_str!("sparse_conversions_f32.wgsl"); +const SPARSE_CONVERSIONS_I32: &str = include_str!("sparse_conversions_i32.wgsl"); +const SPARSE_CONVERSIONS_U32: &str = include_str!("sparse_conversions_u32.wgsl"); + +/// Return (module_key, shader_source) for a dtype-specific conversions shader. +fn typed_shader(dtype: DType) -> Result<(&'static str, &'static str)> { + match dtype { + DType::F32 => Ok(("sparse_conversions_f32", SPARSE_CONVERSIONS_F32)), + DType::I32 => Ok(("sparse_conversions_i32", SPARSE_CONVERSIONS_I32)), + DType::U32 => Ok(("sparse_conversions_u32", SPARSE_CONVERSIONS_U32)), + _ => Err(Error::UnsupportedDType { + dtype, + op: "sparse_conversions (WebGPU)", + }), + } +} /// Launch kernel to expand CSR row_ptrs to explicit row_indices. pub fn launch_expand_row_ptrs( @@ -27,8 +39,8 @@ pub fn launch_expand_row_ptrs( params: &Buffer, nrows: usize, ) -> Result<()> { - let source = generate_expand_row_ptrs_shader()?; - let module = cache.get_or_create_module_from_source("expand_row_ptrs", &source); + let module = + cache.get_or_create_module("sparse_conversions_indices", SPARSE_CONVERSIONS_INDICES); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, // row_ptrs, row_indices @@ -36,8 +48,8 @@ pub fn launch_expand_row_ptrs( num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_dynamic_pipeline( - "expand_row_ptrs", + let pipeline = cache.get_or_create_pipeline( + "sparse_conversions_indices", "expand_row_ptrs", &module, &layout, @@ -74,8 +86,8 @@ pub fn launch_expand_col_ptrs( params: &Buffer, ncols: usize, ) -> Result<()> { - let source = generate_expand_col_ptrs_shader()?; - let module = cache.get_or_create_module_from_source("expand_col_ptrs", &source); + let module = + cache.get_or_create_module("sparse_conversions_indices", SPARSE_CONVERSIONS_INDICES); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, // col_ptrs, col_indices @@ -83,8 +95,8 @@ pub fn launch_expand_col_ptrs( num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_dynamic_pipeline( - "expand_col_ptrs", + let pipeline = cache.get_or_create_pipeline( + "sparse_conversions_indices", "expand_col_ptrs", &module, &layout, @@ -121,8 +133,8 @@ pub fn launch_histogram( params: &Buffer, nnz: usize, ) -> Result<()> { - let source = generate_histogram_shader()?; - let module = cache.get_or_create_module_from_source("histogram", &source); + let module = + cache.get_or_create_module("sparse_conversions_indices", SPARSE_CONVERSIONS_INDICES); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, // indices, counts @@ -130,7 +142,8 @@ pub fn launch_histogram( num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_dynamic_pipeline("histogram", "histogram", &module, &layout); + let pipeline = + cache.get_or_create_pipeline("sparse_conversions_indices", "histogram", &module, &layout); let bind_group = cache.create_bind_group(&layout, &[indices, counts, params]); @@ -168,9 +181,8 @@ pub fn launch_coo_to_csr_scatter( nnz: usize, dtype: DType, ) -> Result<()> { - let source = generate_coo_to_csr_scatter_shader(dtype)?; - let key = format!("coo_to_csr_scatter_{}", dtype_suffix(dtype)?); - let module = cache.get_or_create_module_from_source(&key, &source); + let (module_key, shader) = typed_shader(dtype)?; + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 6, // in_row, in_col, in_val, row_ptrs_atomic, out_col, out_val @@ -178,8 +190,7 @@ pub fn launch_coo_to_csr_scatter( num_readonly_storage: 0, }); - let pipeline = - cache.get_or_create_dynamic_pipeline(&key, "coo_to_csr_scatter", &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, "coo_to_csr_scatter", &module, &layout); let bind_group = cache.create_bind_group( &layout, @@ -228,9 +239,8 @@ pub fn launch_coo_to_csc_scatter( nnz: usize, dtype: DType, ) -> Result<()> { - let source = generate_coo_to_csc_scatter_shader(dtype)?; - let key = format!("coo_to_csc_scatter_{}", dtype_suffix(dtype)?); - let module = cache.get_or_create_module_from_source(&key, &source); + let (module_key, shader) = typed_shader(dtype)?; + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 6, @@ -238,8 +248,7 @@ pub fn launch_coo_to_csc_scatter( num_readonly_storage: 0, }); - let pipeline = - cache.get_or_create_dynamic_pipeline(&key, "coo_to_csc_scatter", &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, "coo_to_csc_scatter", &module, &layout); let bind_group = cache.create_bind_group( &layout, @@ -288,9 +297,8 @@ pub fn launch_csr_to_csc_scatter( nrows: usize, dtype: DType, ) -> Result<()> { - let source = generate_csr_to_csc_scatter_shader(dtype)?; - let key = format!("csr_to_csc_scatter_{}", dtype_suffix(dtype)?); - let module = cache.get_or_create_module_from_source(&key, &source); + let (module_key, shader) = typed_shader(dtype)?; + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 6, @@ -298,8 +306,7 @@ pub fn launch_csr_to_csc_scatter( num_readonly_storage: 0, }); - let pipeline = - cache.get_or_create_dynamic_pipeline(&key, "csr_to_csc_scatter", &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, "csr_to_csc_scatter", &module, &layout); let bind_group = cache.create_bind_group( &layout, @@ -348,9 +355,8 @@ pub fn launch_csc_to_csr_scatter( ncols: usize, dtype: DType, ) -> Result<()> { - let source = generate_csc_to_csr_scatter_shader(dtype)?; - let key = format!("csc_to_csr_scatter_{}", dtype_suffix(dtype)?); - let module = cache.get_or_create_module_from_source(&key, &source); + let (module_key, shader) = typed_shader(dtype)?; + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 6, @@ -358,8 +364,7 @@ pub fn launch_csc_to_csr_scatter( num_readonly_storage: 0, }); - let pipeline = - cache.get_or_create_dynamic_pipeline(&key, "csc_to_csr_scatter", &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, "csc_to_csr_scatter", &module, &layout); let bind_group = cache.create_bind_group( &layout, @@ -403,8 +408,8 @@ pub fn launch_copy_ptrs( params: &Buffer, n: usize, ) -> Result<()> { - let source = generate_copy_ptrs_shader()?; - let module = cache.get_or_create_module_from_source("copy_ptrs", &source); + let module = + cache.get_or_create_module("sparse_conversions_indices", SPARSE_CONVERSIONS_INDICES); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, // src, dst @@ -412,7 +417,8 @@ pub fn launch_copy_ptrs( num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_dynamic_pipeline("copy_ptrs", "copy_ptrs", &module, &layout); + let pipeline = + cache.get_or_create_pipeline("sparse_conversions_indices", "copy_ptrs", &module, &layout); let bind_group = cache.create_bind_group(&layout, &[src, dst, params]); @@ -448,9 +454,8 @@ pub fn launch_csr_to_dense( nrows: usize, dtype: DType, ) -> Result<()> { - let source = super::generator::generate_csr_to_dense_shader(dtype)?; - let key = format!("csr_to_dense_{}", dtype_suffix(dtype)?); - let module = cache.get_or_create_module_from_source(&key, &source); + let (module_key, shader) = typed_shader(dtype)?; + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 4, // row_ptrs, col_indices, values, dense @@ -458,7 +463,7 @@ pub fn launch_csr_to_dense( num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_dynamic_pipeline(&key, "csr_to_dense", &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, "csr_to_dense", &module, &layout); let bind_group = cache.create_bind_group(&layout, &[row_ptrs, col_indices, values, dense, params]); @@ -493,9 +498,8 @@ pub fn launch_count_nonzeros( total_elems: usize, dtype: DType, ) -> Result<()> { - let source = super::generator::generate_count_nonzeros_shader(dtype)?; - let key = format!("count_nonzeros_{}", dtype_suffix(dtype)?); - let module = cache.get_or_create_module_from_source(&key, &source); + let (module_key, shader) = typed_shader(dtype)?; + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, // dense, count @@ -503,7 +507,7 @@ pub fn launch_count_nonzeros( num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_dynamic_pipeline(&key, "count_nonzeros", &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, "count_nonzeros", &module, &layout); let bind_group = cache.create_bind_group(&layout, &[dense, count, params]); @@ -540,9 +544,8 @@ pub fn launch_dense_to_coo_scatter( total_elems: usize, dtype: DType, ) -> Result<()> { - let source = super::generator::generate_dense_to_coo_scatter_shader(dtype)?; - let key = format!("dense_to_coo_scatter_{}", dtype_suffix(dtype)?); - let module = cache.get_or_create_module_from_source(&key, &source); + let (module_key, shader) = typed_shader(dtype)?; + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 5, // dense, row_indices, col_indices, values, write_pos @@ -551,7 +554,7 @@ pub fn launch_dense_to_coo_scatter( }); let pipeline = - cache.get_or_create_dynamic_pipeline(&key, "dense_to_coo_scatter", &module, &layout); + cache.get_or_create_pipeline(module_key, "dense_to_coo_scatter", &module, &layout); let bind_group = cache.create_bind_group( &layout, @@ -577,33 +580,3 @@ pub fn launch_dense_to_coo_scatter( queue.submit(std::iter::once(encoder.finish())); Ok(()) } - -#[cfg(test)] -mod tests { - use super::*; - - fn validate_wgsl_syntax(source: &str) -> std::result::Result<(), String> { - use wgpu::naga::front::wgsl; - let mut frontend = wgsl::Frontend::new(); - frontend - .parse(source) - .map(|_| ()) - .map_err(|e| format!("WGSL parse error: {e}")) - } - - #[test] - fn test_all_conversion_shaders_valid() { - // Validate all generated shaders are syntactically correct - validate_wgsl_syntax(&generate_expand_row_ptrs_shader().unwrap()).unwrap(); - validate_wgsl_syntax(&generate_expand_col_ptrs_shader().unwrap()).unwrap(); - validate_wgsl_syntax(&generate_histogram_shader().unwrap()).unwrap(); - validate_wgsl_syntax(&generate_copy_ptrs_shader().unwrap()).unwrap(); - - for dtype in [DType::F32, DType::I32, DType::U32] { - validate_wgsl_syntax(&generate_coo_to_csr_scatter_shader(dtype).unwrap()).unwrap(); - validate_wgsl_syntax(&generate_coo_to_csc_scatter_shader(dtype).unwrap()).unwrap(); - validate_wgsl_syntax(&generate_csr_to_csc_scatter_shader(dtype).unwrap()).unwrap(); - validate_wgsl_syntax(&generate_csc_to_csr_scatter_shader(dtype).unwrap()).unwrap(); - } - } -} diff --git a/src/runtime/wgpu/shaders/sparse_conversions_u32.wgsl b/src/runtime/wgpu/shaders/sparse_conversions_u32.wgsl new file mode 100644 index 00000000..b6ba7e3f --- /dev/null +++ b/src/runtime/wgpu/shaders/sparse_conversions_u32.wgsl @@ -0,0 +1,251 @@ +// Sparse format conversion shaders - U32 typed operations + +struct ScatterParams { + nnz: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +// ============================================================================ +// coo_to_csr_scatter +// ============================================================================ + +@group(0) @binding(0) var c2r_in_row_indices: array; +@group(0) @binding(1) var c2r_in_col_indices: array; +@group(0) @binding(2) var c2r_in_values: array; +@group(0) @binding(3) var c2r_row_ptrs_atomic: array>; +@group(0) @binding(4) var c2r_out_col_indices: array; +@group(0) @binding(5) var c2r_out_values: array; +@group(0) @binding(6) var c2r_params: ScatterParams; + +@compute @workgroup_size(256) +fn coo_to_csr_scatter(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= c2r_params.nnz) { + return; + } + + let row = c2r_in_row_indices[idx]; + let col = c2r_in_col_indices[idx]; + let val = c2r_in_values[idx]; + + let pos = atomicAdd(&c2r_row_ptrs_atomic[row], 1); + + c2r_out_col_indices[pos] = col; + c2r_out_values[pos] = val; +} + +// ============================================================================ +// coo_to_csc_scatter +// ============================================================================ + +@group(0) @binding(0) var c2c_in_row_indices: array; +@group(0) @binding(1) var c2c_in_col_indices: array; +@group(0) @binding(2) var c2c_in_values: array; +@group(0) @binding(3) var c2c_col_ptrs_atomic: array>; +@group(0) @binding(4) var c2c_out_row_indices: array; +@group(0) @binding(5) var c2c_out_values: array; +@group(0) @binding(6) var c2c_params: ScatterParams; + +@compute @workgroup_size(256) +fn coo_to_csc_scatter(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= c2c_params.nnz) { + return; + } + + let row = c2c_in_row_indices[idx]; + let col = c2c_in_col_indices[idx]; + let val = c2c_in_values[idx]; + + let pos = atomicAdd(&c2c_col_ptrs_atomic[col], 1); + + c2c_out_row_indices[pos] = row; + c2c_out_values[pos] = val; +} + +// ============================================================================ +// csr_to_csc_scatter (transpose) +// ============================================================================ + +struct TransposeRowParams { + nrows: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var r2c_in_row_ptrs: array; +@group(0) @binding(1) var r2c_in_col_indices: array; +@group(0) @binding(2) var r2c_in_values: array; +@group(0) @binding(3) var r2c_col_ptrs_atomic: array>; +@group(0) @binding(4) var r2c_out_row_indices: array; +@group(0) @binding(5) var r2c_out_values: array; +@group(0) @binding(6) var r2c_params: TransposeRowParams; + +@compute @workgroup_size(256) +fn csr_to_csc_scatter(@builtin(global_invocation_id) gid: vec3) { + let row = gid.x; + if (row >= r2c_params.nrows) { + return; + } + + let start = r2c_in_row_ptrs[row]; + let end = r2c_in_row_ptrs[row + 1u]; + + for (var i = start; i < end; i = i + 1) { + let col = r2c_in_col_indices[i]; + let val = r2c_in_values[i]; + + let pos = atomicAdd(&r2c_col_ptrs_atomic[col], 1); + + r2c_out_row_indices[pos] = i32(row); + r2c_out_values[pos] = val; + } +} + +// ============================================================================ +// csc_to_csr_scatter (transpose) +// ============================================================================ + +struct TransposeColParams { + ncols: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var c2r2_in_col_ptrs: array; +@group(0) @binding(1) var c2r2_in_row_indices: array; +@group(0) @binding(2) var c2r2_in_values: array; +@group(0) @binding(3) var c2r2_row_ptrs_atomic: array>; +@group(0) @binding(4) var c2r2_out_col_indices: array; +@group(0) @binding(5) var c2r2_out_values: array; +@group(0) @binding(6) var c2r2_params: TransposeColParams; + +@compute @workgroup_size(256) +fn csc_to_csr_scatter(@builtin(global_invocation_id) gid: vec3) { + let col = gid.x; + if (col >= c2r2_params.ncols) { + return; + } + + let start = c2r2_in_col_ptrs[col]; + let end = c2r2_in_col_ptrs[col + 1u]; + + for (var i = start; i < end; i = i + 1) { + let row = c2r2_in_row_indices[i]; + let val = c2r2_in_values[i]; + + let pos = atomicAdd(&c2r2_row_ptrs_atomic[row], 1); + + c2r2_out_col_indices[pos] = i32(col); + c2r2_out_values[pos] = val; + } +} + +// ============================================================================ +// csr_to_dense +// ============================================================================ + +struct CsrToDenseParams { + nrows: u32, + ncols: u32, + _pad0: u32, + _pad1: u32, +} + +@group(0) @binding(0) var ctd_row_ptrs: array; +@group(0) @binding(1) var ctd_col_indices: array; +@group(0) @binding(2) var ctd_values: array; +@group(0) @binding(3) var ctd_dense: array; +@group(0) @binding(4) var ctd_params: CsrToDenseParams; + +@compute @workgroup_size(256) +fn csr_to_dense(@builtin(global_invocation_id) gid: vec3) { + let row = gid.x; + if (row >= ctd_params.nrows) { + return; + } + + let start = ctd_row_ptrs[row]; + let end = ctd_row_ptrs[row + 1u]; + let ncols = ctd_params.ncols; + + for (var i = start; i < end; i = i + 1) { + let col = u32(ctd_col_indices[i]); + ctd_dense[row * ncols + col] = ctd_values[i]; + } +} + +// ============================================================================ +// count_nonzeros +// ============================================================================ + +struct CountNzParams { + total_elems: u32, + threshold_bits: u32, + _pad0: u32, + _pad1: u32, +} + +@group(0) @binding(0) var cnz_dense: array; +@group(0) @binding(1) var cnz_count: atomic; +@group(0) @binding(2) var cnz_params: CountNzParams; + +@compute @workgroup_size(256) +fn count_nonzeros(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= cnz_params.total_elems) { + return; + } + + let val = cnz_dense[idx]; + let zero_val = u32(0); + + if (val != zero_val) { + atomicAdd(&cnz_count, 1u); + } +} + +// ============================================================================ +// dense_to_coo_scatter +// ============================================================================ + +struct DenseToCooParams { + nrows: u32, + ncols: u32, + threshold_bits: u32, + _pad0: u32, +} + +@group(0) @binding(0) var dtc_dense: array; +@group(0) @binding(1) var dtc_row_indices: array; +@group(0) @binding(2) var dtc_col_indices: array; +@group(0) @binding(3) var dtc_values: array; +@group(0) @binding(4) var dtc_write_pos: atomic; +@group(0) @binding(5) var dtc_params: DenseToCooParams; + +@compute @workgroup_size(256) +fn dense_to_coo_scatter(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = dtc_params.nrows * dtc_params.ncols; + if (idx >= total) { + return; + } + + let val = dtc_dense[idx]; + let zero_val = u32(0); + + if (val != zero_val) { + let row = idx / dtc_params.ncols; + let col = idx % dtc_params.ncols; + + let pos = atomicAdd(&dtc_write_pos, 1u); + + dtc_row_indices[pos] = i32(row); + dtc_col_indices[pos] = i32(col); + dtc_values[pos] = val; + } +} diff --git a/src/runtime/wgpu/shaders/sparse_find_diag_indices.wgsl b/src/runtime/wgpu/shaders/sparse_find_diag_indices.wgsl new file mode 100644 index 00000000..86884a45 --- /dev/null +++ b/src/runtime/wgpu/shaders/sparse_find_diag_indices.wgsl @@ -0,0 +1,33 @@ +// Find diagonal indices in CSR matrix + +struct DiagParams { + n: u32, + _padding0: u32, + _padding1: u32, + _padding2: u32, +} + +@group(0) @binding(0) var row_ptrs: array; +@group(0) @binding(1) var col_indices: array; +@group(0) @binding(2) var diag_indices: array; +@group(0) @binding(3) var params: DiagParams; + +@compute @workgroup_size(256) +fn find_diag_indices(@builtin(global_invocation_id) gid: vec3) { + let row = i32(gid.x); + if (u32(row) >= params.n) { + return; + } + + let start = row_ptrs[row]; + let end = row_ptrs[row + 1]; + + diag_indices[row] = -1; // Default: no diagonal found + + for (var idx = start; idx < end; idx = idx + 1) { + if (col_indices[idx] == row) { + diag_indices[row] = idx; + break; + } + } +} diff --git a/src/runtime/wgpu/shaders/sparse_ic0_level_f32.wgsl b/src/runtime/wgpu/shaders/sparse_ic0_level_f32.wgsl new file mode 100644 index 00000000..c43c4f86 --- /dev/null +++ b/src/runtime/wgpu/shaders/sparse_ic0_level_f32.wgsl @@ -0,0 +1,81 @@ +// Level-scheduled IC(0) factorization kernel + +struct Ic0Params { + level_size: u32, + n: u32, + diagonal_shift: f32, + level_start: u32, +} + +@group(0) @binding(0) var level_rows: array; +@group(0) @binding(1) var row_ptrs: array; +@group(0) @binding(2) var col_indices: array; +@group(0) @binding(3) var values: array; +@group(0) @binding(4) var diag_indices: array; +@group(0) @binding(5) var params: Ic0Params; + +@compute @workgroup_size(256) +fn ic0_level_f32(@builtin(global_invocation_id) gid: vec3) { + let tid = gid.x; + if (tid >= params.level_size) { + return; + } + + let i = level_rows[params.level_start + tid]; + let i_start = row_ptrs[i]; + let i_end = row_ptrs[i + 1]; + + // Process off-diagonal entries in row i (columns k < i) + for (var idx_ik = i_start; idx_ik < i_end; idx_ik = idx_ik + 1) { + let k = col_indices[idx_ik]; + if (k >= i) { + break; + } + + let k_start = row_ptrs[k]; + let k_end = row_ptrs[k + 1]; + + // Compute inner product contribution + var sum = values[idx_ik]; + + for (var idx_kj = k_start; idx_kj < k_end; idx_kj = idx_kj + 1) { + let j = col_indices[idx_kj]; + if (j >= k) { + break; + } + + // Check if L[i,j] exists + for (var idx_ij = i_start; idx_ij < i_end; idx_ij = idx_ij + 1) { + if (col_indices[idx_ij] == j) { + sum = sum - values[idx_ij] * values[idx_kj]; + break; + } + if (col_indices[idx_ij] > j) { + break; + } + } + } + + // Divide by L[k,k] + let diag_k = diag_indices[k]; + values[idx_ik] = sum / values[diag_k]; + } + + // Compute diagonal L[i,i] + let diag_i = diag_indices[i]; + var diag_sum = values[diag_i] + params.diagonal_shift; + + for (var idx_ij = i_start; idx_ij < i_end; idx_ij = idx_ij + 1) { + let j = col_indices[idx_ij]; + if (j >= i) { + break; + } + diag_sum = diag_sum - values[idx_ij] * values[idx_ij]; + } + + if (diag_sum <= 0.0) { + diag_sum = select(1e-10, params.diagonal_shift, params.diagonal_shift > 0.0); + } + + values[diag_i] = sqrt(diag_sum); +} diff --git a/src/runtime/wgpu/shaders/sparse_ilu0_level_f32.wgsl b/src/runtime/wgpu/shaders/sparse_ilu0_level_f32.wgsl new file mode 100644 index 00000000..f1674758 --- /dev/null +++ b/src/runtime/wgpu/shaders/sparse_ilu0_level_f32.wgsl @@ -0,0 +1,73 @@ +// Level-scheduled ILU(0) factorization kernel + +struct Ilu0Params { + level_size: u32, + n: u32, + diagonal_shift: f32, + level_start: u32, +} + +@group(0) @binding(0) var level_rows: array; +@group(0) @binding(1) var row_ptrs: array; +@group(0) @binding(2) var col_indices: array; +@group(0) @binding(3) var values: array; +@group(0) @binding(4) var diag_indices: array; +@group(0) @binding(5) var params: Ilu0Params; + +@compute @workgroup_size(256) +fn ilu0_level_f32(@builtin(global_invocation_id) gid: vec3) { + let tid = gid.x; + if (tid >= params.level_size) { + return; + } + + let i = level_rows[params.level_start + tid]; + let row_start = row_ptrs[i]; + let row_end = row_ptrs[i + 1]; + + // Process columns k < i (for L factor) + for (var idx_ik = row_start; idx_ik < row_end; idx_ik = idx_ik + 1) { + let k = col_indices[idx_ik]; + if (k >= i) { + break; + } + + // Get diagonal U[k,k] + let diag_k = diag_indices[k]; + var diag_val = values[diag_k]; + + // Handle zero pivot + if (abs(diag_val) < 1e-15) { + if (params.diagonal_shift > 0.0) { + values[diag_k] = params.diagonal_shift; + diag_val = params.diagonal_shift; + } + } + + // L[i,k] = A[i,k] / U[k,k] + let l_ik = values[idx_ik] / diag_val; + values[idx_ik] = l_ik; + + // Update row i for columns j > k + let k_start = row_ptrs[k]; + let k_end = row_ptrs[k + 1]; + + for (var idx_kj = k_start; idx_kj < k_end; idx_kj = idx_kj + 1) { + let j = col_indices[idx_kj]; + if (j <= k) { + continue; + } + + // Find A[i,j] if it exists (zero fill-in constraint) + for (var idx_ij = row_start; idx_ij < row_end; idx_ij = idx_ij + 1) { + if (col_indices[idx_ij] == j) { + values[idx_ij] = values[idx_ij] - l_ik * values[idx_kj]; + break; + } + if (col_indices[idx_ij] > j) { + break; + } + } + } + } +} diff --git a/src/runtime/wgpu/shaders/sparse_linalg.wgsl b/src/runtime/wgpu/shaders/sparse_linalg.wgsl index 0ec9b9ba..b92ee5cc 100644 --- a/src/runtime/wgpu/shaders/sparse_linalg.wgsl +++ b/src/runtime/wgpu/shaders/sparse_linalg.wgsl @@ -19,10 +19,10 @@ struct ScatterParams { count: u32, } -@group(0) @binding(0) var scatter_params: ScatterParams; -@group(0) @binding(1) var scatter_values_f32: array; -@group(0) @binding(2) var scatter_row_indices: array; -@group(0) @binding(3) var scatter_work_f32: array; +@group(0) @binding(0) var scatter_values_f32: array; +@group(0) @binding(1) var scatter_row_indices: array; +@group(0) @binding(2) var scatter_work_f32: array; +@group(0) @binding(3) var scatter_params: ScatterParams; @compute @workgroup_size(256) fn sparse_scatter_offset_f32(@builtin(global_invocation_id) gid: vec3) { @@ -315,13 +315,13 @@ struct TrsvCscUpperParams { _pad: u32, } -@group(0) @binding(0) var trsv_upper_params: TrsvCscUpperParams; -@group(0) @binding(1) var trsv_upper_level_cols: array; -@group(0) @binding(2) var trsv_upper_col_ptrs: array; -@group(0) @binding(3) var trsv_upper_row_indices: array; -@group(0) @binding(4) var trsv_upper_values: array; -@group(0) @binding(5) var trsv_upper_diag_ptr: array; -@group(0) @binding(6) var trsv_upper_b: array; +@group(0) @binding(0) var trsv_upper_level_cols: array; +@group(0) @binding(1) var trsv_upper_col_ptrs: array; +@group(0) @binding(2) var trsv_upper_row_indices: array; +@group(0) @binding(3) var trsv_upper_values: array; +@group(0) @binding(4) var trsv_upper_diag_ptr: array; +@group(0) @binding(5) var trsv_upper_b: array; +@group(0) @binding(6) var trsv_upper_params: TrsvCscUpperParams; @compute @workgroup_size(256) fn sparse_trsv_csc_upper_level_f32(@builtin(global_invocation_id) gid: vec3) { @@ -367,10 +367,10 @@ struct FindDiagCscParams { _pad3: u32, } -@group(0) @binding(0) var find_diag_csc_params: FindDiagCscParams; -@group(0) @binding(1) var find_diag_csc_col_ptrs: array; -@group(0) @binding(2) var find_diag_csc_row_indices: array; -@group(0) @binding(3) var find_diag_csc_diag_ptr: array; +@group(0) @binding(0) var find_diag_csc_col_ptrs: array; +@group(0) @binding(1) var find_diag_csc_row_indices: array; +@group(0) @binding(2) var find_diag_csc_diag_ptr: array; +@group(0) @binding(3) var find_diag_csc_params: FindDiagCscParams; @compute @workgroup_size(256) fn find_diag_indices_csc_f32(@builtin(global_invocation_id) gid: vec3) { @@ -400,10 +400,10 @@ struct ApplyPermParams { _pad3: u32, } -@group(0) @binding(0) var apply_perm_params: ApplyPermParams; -@group(0) @binding(1) var apply_perm_b: array; -@group(0) @binding(2) var apply_perm_perm: array; -@group(0) @binding(3) var apply_perm_y: array; +@group(0) @binding(0) var apply_perm_b: array; +@group(0) @binding(1) var apply_perm_perm: array; +@group(0) @binding(2) var apply_perm_y: array; +@group(0) @binding(3) var apply_perm_params: ApplyPermParams; @compute @workgroup_size(256) fn apply_row_perm_f32(@builtin(global_invocation_id) gid: vec3) { @@ -497,3 +497,216 @@ fn sparse_swap_rows(@builtin(global_invocation_id) gid: vec3) { swap_perm[swap_params.row_b] = tmp_perm; } } + +// ============================================================================ +// Sparse QR Factorization Kernels (F32 only) +// ============================================================================ + +// Apply Householder reflector: fused dot + axpy +// work[v_start..v_start+v_len] -= tau * (v^T * work[v_start..]) * v +// Single workgroup, shared memory reduction for dot product +struct QrReflectorParams { + v_start: u32, + v_len: u32, +} + +@group(0) @binding(0) var qr_reflector_v: array; +@group(0) @binding(1) var qr_reflector_tau: array; +@group(0) @binding(2) var qr_reflector_work: array; +@group(0) @binding(3) var qr_reflector_params: QrReflectorParams; + +var qr_dot_partial: array; + +@compute @workgroup_size(256) +fn sparse_qr_apply_reflector_f32(@builtin(local_invocation_id) lid: vec3) { + let tid = lid.x; + let v_start = qr_reflector_params.v_start; + let v_len = qr_reflector_params.v_len; + let tau = qr_reflector_tau[0]; + + if (tau == 0.0) { return; } + + // Phase 1: dot product + var my_sum: f32 = 0.0; + var i = tid; + loop { + if (i >= v_len) { break; } + my_sum += qr_reflector_v[i] * qr_reflector_work[v_start + i]; + i += 256u; + } + qr_dot_partial[tid] = my_sum; + workgroupBarrier(); + + // Reduction + var s = 128u; + loop { + if (s == 0u) { break; } + if (tid < s) { + qr_dot_partial[tid] += qr_dot_partial[tid + s]; + } + workgroupBarrier(); + s = s >> 1u; + } + + let scale = tau * qr_dot_partial[0]; + + // Phase 2: axpy + i = tid; + loop { + if (i >= v_len) { break; } + qr_reflector_work[v_start + i] -= scale * qr_reflector_v[i]; + i += 256u; + } +} + +// Norm: compute ||work[start..start+count]||^2 +struct QrNormParams { + start: u32, + count: u32, +} + +@group(0) @binding(0) var qr_norm_work: array; +@group(0) @binding(1) var qr_norm_result: array; +@group(0) @binding(2) var qr_norm_params: QrNormParams; + +var qr_norm_partial: array; + +@compute @workgroup_size(256) +fn sparse_qr_norm_f32(@builtin(local_invocation_id) lid: vec3) { + let tid = lid.x; + let start = qr_norm_params.start; + let count = qr_norm_params.count; + + var my_sum: f32 = 0.0; + var i = tid; + loop { + if (i >= count) { break; } + let val = qr_norm_work[start + i]; + my_sum += val * val; + i += 256u; + } + qr_norm_partial[tid] = my_sum; + workgroupBarrier(); + + var s = 128u; + loop { + if (s == 0u) { break; } + if (tid < s) { + qr_norm_partial[tid] += qr_norm_partial[tid + s]; + } + workgroupBarrier(); + s = s >> 1u; + } + + if (tid == 0u) { + qr_norm_result[0] = qr_norm_partial[0]; + } +} + +// Householder: compute Householder vector from work[start..m] +// +// Tolerance 1e-30: well below f32 machine epsilon (~1e-7). Matches CPU +// implementation (algorithm.rs:226,238). Detects truly zero columns without +// false positives from normal floating-point roundoff. +struct QrHouseholderParams { + start: u32, + m: u32, +} + +@group(0) @binding(0) var qr_hh_work: array; +@group(0) @binding(1) var qr_hh_norm_sq: array; +@group(0) @binding(2) var qr_hh_out_v: array; +@group(0) @binding(3) var qr_hh_out_tau: array; +@group(0) @binding(4) var qr_hh_out_diag: array; +@group(0) @binding(5) var qr_hh_params: QrHouseholderParams; + +var qr_hh_ctrl: array; // [sigma, tau, diag, inv_v_start] + +@compute @workgroup_size(256) +fn sparse_qr_householder_f32(@builtin(local_invocation_id) lid: vec3) { + let tid = lid.x; + let start = qr_hh_params.start; + let m = qr_hh_params.m; + let v_len = m - start; + + if (tid == 0u) { + let norm_sq = qr_hh_norm_sq[0]; + let norm = sqrt(norm_sq); + + if (norm < 1e-30) { + qr_hh_ctrl[0] = 0.0; qr_hh_ctrl[1] = 0.0; + qr_hh_ctrl[2] = 0.0; qr_hh_ctrl[3] = 0.0; + } else { + let x0 = qr_hh_work[start]; + var sigma: f32; + if (x0 >= 0.0) { sigma = -norm; } else { sigma = norm; } + let v_start_val = x0 - sigma; + + if (abs(v_start_val) < 1e-30) { + qr_hh_ctrl[0] = sigma; qr_hh_ctrl[1] = 0.0; + qr_hh_ctrl[2] = sigma; qr_hh_ctrl[3] = 0.0; + } else { + qr_hh_ctrl[0] = sigma; + qr_hh_ctrl[1] = -v_start_val / sigma; + qr_hh_ctrl[2] = sigma; + qr_hh_ctrl[3] = 1.0 / v_start_val; + } + } + } + workgroupBarrier(); + + let tau = qr_hh_ctrl[1]; + let inv_v_start = qr_hh_ctrl[3]; + + if (tid == 0u) { + qr_hh_out_tau[0] = tau; + qr_hh_out_diag[0] = qr_hh_ctrl[2]; + } + workgroupBarrier(); // Ensure scalar writes complete before output loop + + var i = tid; + loop { + if (i >= v_len) { break; } + if (tau == 0.0) { + if (i == 0u) { qr_hh_out_v[i] = 1.0; } else { qr_hh_out_v[i] = 0.0; } + } else { + if (i == 0u) { qr_hh_out_v[i] = 1.0; } else { qr_hh_out_v[i] = qr_hh_work[start + i] * inv_v_start; } + } + i += 256u; + } +} + +// Extract R off-diagonal: copy work[0..count] to output +struct QrExtractRParams { + count: u32, + _alignment: u32, // WGSL uniform buffer 8-byte minimum alignment +} + +@group(0) @binding(0) var qr_extract_work: array; +@group(0) @binding(1) var qr_extract_output: array; +@group(0) @binding(2) var qr_extract_params: QrExtractRParams; + +@compute @workgroup_size(256) +fn sparse_qr_extract_r_f32(@builtin(global_invocation_id) gid: vec3) { + let i = gid.x; + if (i < qr_extract_params.count) { + qr_extract_output[i] = qr_extract_work[i]; + } +} + +// Clear work vector: work[0..n] = 0 +struct QrClearParams { + n: u32, + _alignment: u32, // WGSL uniform buffer 8-byte minimum alignment +} + +@group(0) @binding(0) var qr_clear_work: array; +@group(0) @binding(1) var qr_clear_params: QrClearParams; + +@compute @workgroup_size(256) +fn sparse_qr_clear_f32(@builtin(global_invocation_id) gid: vec3) { + let i = gid.x; + if (i < qr_clear_params.n) { + qr_clear_work[i] = 0.0; + } +} diff --git a/src/runtime/wgpu/shaders/sparse_linalg_launcher.rs b/src/runtime/wgpu/shaders/sparse_linalg_launcher.rs index eb0b30a9..ade434a3 100644 --- a/src/runtime/wgpu/shaders/sparse_linalg_launcher.rs +++ b/src/runtime/wgpu/shaders/sparse_linalg_launcher.rs @@ -8,15 +8,13 @@ use wgpu::{Buffer, Queue}; -use super::generator::dtype_suffix; -use super::generator::sparse_linalg::{ - generate_extract_lower_count_shader, generate_extract_lower_scatter_shader, - generate_split_lu_count_shader, generate_split_lu_scatter_l_shader, - generate_split_lu_scatter_u_shader, -}; use super::pipeline::{LayoutKey, PipelineCache, workgroup_count}; use crate::dtype::DType; -use crate::error::Result; +use crate::error::{Error, Result}; + +// Static WGSL shader sources +const SPARSE_LINALG: &str = include_str!("sparse_linalg.wgsl"); +const SPARSE_LINALG_SPLIT_F32: &str = include_str!("sparse_linalg_split_f32.wgsl"); // ============================================================================ // Split LU Operations @@ -40,15 +38,18 @@ pub fn launch_split_lu_count( params_buffer: &Buffer, n: usize, ) -> Result<()> { - let shader_source = generate_split_lu_count_shader(); - let module = cache.get_or_create_module_from_source("split_lu_count", &shader_source); + let module = cache.get_or_create_module("sparse_linalg_split_f32", SPARSE_LINALG_SPLIT_F32); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 4, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = - cache.get_or_create_pipeline("split_lu_count", "split_lu_count", &module, &layout); + let pipeline = cache.get_or_create_pipeline( + "sparse_linalg_split_f32", + "split_lu_count", + &module, + &layout, + ); let bind_group = cache.create_bind_group( &layout, @@ -98,19 +99,25 @@ pub fn launch_split_lu_scatter_l( n: usize, dtype: DType, ) -> Result<()> { - let suffix = dtype_suffix(dtype)?; - let entry_point = format!("split_lu_scatter_l_{}", suffix); + if dtype != DType::F32 { + return Err(Error::UnsupportedDType { + dtype, + op: "split_lu_scatter_l (WebGPU)", + }); + } - let shader_source = generate_split_lu_scatter_l_shader(dtype)?; - let module_name = format!("split_lu_scatter_l_{}", suffix); - let module = cache.get_or_create_module_from_source(&module_name, &shader_source); + let module = cache.get_or_create_module("sparse_linalg_split_f32", SPARSE_LINALG_SPLIT_F32); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 6, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = - cache.get_or_create_dynamic_pipeline("split_lu_scatter_l", &entry_point, &module, &layout); + let pipeline = cache.get_or_create_pipeline( + "sparse_linalg_split_f32", + "split_lu_scatter_l_f32", + &module, + &layout, + ); let bind_group = cache.create_bind_group( &layout, @@ -168,19 +175,25 @@ pub fn launch_split_lu_scatter_u( n: usize, dtype: DType, ) -> Result<()> { - let suffix = dtype_suffix(dtype)?; - let entry_point = format!("split_lu_scatter_u_{}", suffix); + if dtype != DType::F32 { + return Err(Error::UnsupportedDType { + dtype, + op: "split_lu_scatter_u (WebGPU)", + }); + } - let shader_source = generate_split_lu_scatter_u_shader(dtype)?; - let module_name = format!("split_lu_scatter_u_{}", suffix); - let module = cache.get_or_create_module_from_source(&module_name, &shader_source); + let module = cache.get_or_create_module("sparse_linalg_split_f32", SPARSE_LINALG_SPLIT_F32); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 6, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = - cache.get_or_create_dynamic_pipeline("split_lu_scatter_u", &entry_point, &module, &layout); + let pipeline = cache.get_or_create_pipeline( + "sparse_linalg_split_f32", + "split_lu_scatter_u_f32", + &module, + &layout, + ); let bind_group = cache.create_bind_group( &layout, @@ -235,15 +248,14 @@ pub fn launch_extract_lower_count( params_buffer: &Buffer, n: usize, ) -> Result<()> { - let shader_source = generate_extract_lower_count_shader(); - let module = cache.get_or_create_module_from_source("extract_lower_count", &shader_source); + let module = cache.get_or_create_module("sparse_linalg_split_f32", SPARSE_LINALG_SPLIT_F32); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 3, num_uniform_buffers: 1, num_readonly_storage: 0, }); let pipeline = cache.get_or_create_pipeline( - "extract_lower_count", + "sparse_linalg_split_f32", "extract_lower_count", &module, &layout, @@ -295,20 +307,22 @@ pub fn launch_extract_lower_scatter( n: usize, dtype: DType, ) -> Result<()> { - let suffix = dtype_suffix(dtype)?; - let entry_point = format!("extract_lower_scatter_{}", suffix); + if dtype != DType::F32 { + return Err(Error::UnsupportedDType { + dtype, + op: "extract_lower_scatter (WebGPU)", + }); + } - let shader_source = generate_extract_lower_scatter_shader(dtype)?; - let module_name = format!("extract_lower_scatter_{}", suffix); - let module = cache.get_or_create_module_from_source(&module_name, &shader_source); + let module = cache.get_or_create_module("sparse_linalg_split_f32", SPARSE_LINALG_SPLIT_F32); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 6, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_dynamic_pipeline( - "extract_lower_scatter", - &entry_point, + let pipeline = cache.get_or_create_pipeline( + "sparse_linalg_split_f32", + "extract_lower_scatter_f32", &module, &layout, ); @@ -371,15 +385,14 @@ pub fn launch_sparse_scatter_f32( work: &Buffer, nnz: usize, ) -> Result<()> { - let shader_source = include_str!("sparse_linalg.wgsl"); - let module = cache.get_or_create_module_from_source("sparse_scatter_f32", shader_source); + let module = cache.get_or_create_module("sparse_linalg", SPARSE_LINALG); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 3, num_uniform_buffers: 0, num_readonly_storage: 0, }); let pipeline = - cache.get_or_create_pipeline("sparse_scatter_f32", "sparse_scatter_f32", &module, &layout); + cache.get_or_create_pipeline("sparse_linalg", "sparse_scatter_f32", &module, &layout); let bind_group = cache.create_bind_group(&layout, &[values, row_indices, work]); @@ -415,15 +428,14 @@ pub fn launch_sparse_axpy_f32( work: &Buffer, nnz: usize, ) -> Result<()> { - let shader_source = include_str!("sparse_linalg.wgsl"); - let module = cache.get_or_create_module_from_source("sparse_axpy_f32", shader_source); + let module = cache.get_or_create_module("sparse_linalg", SPARSE_LINALG); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 3, num_uniform_buffers: 1, num_readonly_storage: 0, }); let pipeline = - cache.get_or_create_pipeline("sparse_axpy_f32", "sparse_axpy_f32", &module, &layout); + cache.get_or_create_pipeline("sparse_linalg", "sparse_axpy_f32", &module, &layout); let bind_group = cache.create_bind_group(&layout, &[params_buffer, values, row_indices, work]); @@ -458,19 +470,14 @@ pub fn launch_sparse_gather_clear_f32( output: &Buffer, nnz: usize, ) -> Result<()> { - let shader_source = include_str!("sparse_linalg.wgsl"); - let module = cache.get_or_create_module_from_source("sparse_gather_clear_f32", shader_source); + let module = cache.get_or_create_module("sparse_linalg", SPARSE_LINALG); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 3, num_uniform_buffers: 0, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline( - "sparse_gather_clear_f32", - "sparse_gather_clear_f32", - &module, - &layout, - ); + let pipeline = + cache.get_or_create_pipeline("sparse_linalg", "sparse_gather_clear_f32", &module, &layout); let bind_group = cache.create_bind_group(&layout, &[work, row_indices, output]); @@ -515,19 +522,14 @@ pub fn launch_sparse_divide_pivot_f32( row_indices: &Buffer, nnz: usize, ) -> Result<()> { - let shader_source = include_str!("sparse_linalg.wgsl"); - let module = cache.get_or_create_module_from_source("sparse_divide_pivot_f32", shader_source); + let module = cache.get_or_create_module("sparse_linalg", SPARSE_LINALG); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_pipeline( - "sparse_divide_pivot_f32", - "sparse_divide_pivot_f32", - &module, - &layout, - ); + let pipeline = + cache.get_or_create_pipeline("sparse_linalg", "sparse_divide_pivot_f32", &module, &layout); let bind_group = cache.create_bind_group(&layout, &[params_buffer, work, row_indices]); @@ -561,15 +563,14 @@ pub fn launch_sparse_clear_f32( row_indices: &Buffer, nnz: usize, ) -> Result<()> { - let shader_source = include_str!("sparse_linalg.wgsl"); - let module = cache.get_or_create_module_from_source("sparse_clear_f32", shader_source); + let module = cache.get_or_create_module("sparse_linalg", SPARSE_LINALG); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, num_uniform_buffers: 0, num_readonly_storage: 0, }); let pipeline = - cache.get_or_create_pipeline("sparse_clear_f32", "sparse_clear_f32", &module, &layout); + cache.get_or_create_pipeline("sparse_linalg", "sparse_clear_f32", &module, &layout); let bind_group = cache.create_bind_group(&layout, &[work, row_indices]); diff --git a/src/runtime/wgpu/shaders/sparse_linalg_split_f32.wgsl b/src/runtime/wgpu/shaders/sparse_linalg_split_f32.wgsl new file mode 100644 index 00000000..e9b3a0c0 --- /dev/null +++ b/src/runtime/wgpu/shaders/sparse_linalg_split_f32.wgsl @@ -0,0 +1,214 @@ +// Sparse LU split and lower triangle extraction shaders - F32 +// +// split_lu_count: Count L and U non-zeros per row +// split_lu_scatter_l_f32: Scatter values into L matrix (lower triangle) +// split_lu_scatter_u_f32: Scatter values into U matrix (upper triangle + diagonal) +// extract_lower_count: Count lower triangle non-zeros per row +// extract_lower_scatter_f32: Scatter lower triangle values + +// ============================================================================ +// split_lu_count +// ============================================================================ + +struct SplitLuCountParams { + n: u32, + _padding0: u32, + _padding1: u32, + _padding2: u32, +} + +// Note: All buffers use read_write due to LayoutKey-based pipeline layout +@group(0) @binding(0) var slc_row_ptrs: array; +@group(0) @binding(1) var slc_col_indices: array; +@group(0) @binding(2) var slc_l_counts: array; +@group(0) @binding(3) var slc_u_counts: array; +@group(0) @binding(4) var slc_params: SplitLuCountParams; + +@compute @workgroup_size(256) +fn split_lu_count(@builtin(global_invocation_id) gid: vec3) { + let row = i32(gid.x); + if (u32(row) >= slc_params.n) { + return; + } + + let start = slc_row_ptrs[row]; + let end = slc_row_ptrs[row + 1]; + + var l_count = 0i; + var u_count = 0i; + + for (var idx = start; idx < end; idx = idx + 1) { + let col = slc_col_indices[idx]; + if (col < row) { + l_count = l_count + 1; + } else { + u_count = u_count + 1; + } + } + + slc_l_counts[row] = l_count; + slc_u_counts[row] = u_count; +} + +// ============================================================================ +// split_lu_scatter_l_f32 +// ============================================================================ + +struct SplitLuScatterLParams { + n: u32, + _padding0: u32, + _padding1: u32, + _padding2: u32, +} + +// Note: All buffers use read_write due to LayoutKey-based pipeline layout +@group(0) @binding(0) var sll_row_ptrs: array; +@group(0) @binding(1) var sll_col_indices: array; +@group(0) @binding(2) var sll_values: array; +@group(0) @binding(3) var sll_l_row_ptrs: array; +@group(0) @binding(4) var sll_l_col_indices: array; +@group(0) @binding(5) var sll_l_values: array; +@group(0) @binding(6) var sll_params: SplitLuScatterLParams; + +@compute @workgroup_size(256) +fn split_lu_scatter_l_f32(@builtin(global_invocation_id) gid: vec3) { + let row = i32(gid.x); + if (u32(row) >= sll_params.n) { + return; + } + + let src_start = sll_row_ptrs[row]; + let src_end = sll_row_ptrs[row + 1]; + var l_write_pos = sll_l_row_ptrs[row]; + + for (var idx = src_start; idx < src_end; idx = idx + 1) { + let col = sll_col_indices[idx]; + if (col < row) { + sll_l_col_indices[l_write_pos] = col; + sll_l_values[l_write_pos] = sll_values[idx]; + l_write_pos = l_write_pos + 1; + } + } +} + +// ============================================================================ +// split_lu_scatter_u_f32 +// ============================================================================ + +struct SplitLuScatterUParams { + n: u32, + _padding0: u32, + _padding1: u32, + _padding2: u32, +} + +// Note: All buffers use read_write due to LayoutKey-based pipeline layout +@group(0) @binding(0) var slu_row_ptrs: array; +@group(0) @binding(1) var slu_col_indices: array; +@group(0) @binding(2) var slu_values: array; +@group(0) @binding(3) var slu_u_row_ptrs: array; +@group(0) @binding(4) var slu_u_col_indices: array; +@group(0) @binding(5) var slu_u_values: array; +@group(0) @binding(6) var slu_params: SplitLuScatterUParams; + +@compute @workgroup_size(256) +fn split_lu_scatter_u_f32(@builtin(global_invocation_id) gid: vec3) { + let row = i32(gid.x); + if (u32(row) >= slu_params.n) { + return; + } + + let src_start = slu_row_ptrs[row]; + let src_end = slu_row_ptrs[row + 1]; + var u_write_pos = slu_u_row_ptrs[row]; + + for (var idx = src_start; idx < src_end; idx = idx + 1) { + let col = slu_col_indices[idx]; + if (col >= row) { + slu_u_col_indices[u_write_pos] = col; + slu_u_values[u_write_pos] = slu_values[idx]; + u_write_pos = u_write_pos + 1; + } + } +} + +// ============================================================================ +// extract_lower_count +// ============================================================================ + +struct ExtractLowerCountParams { + n: u32, + _padding0: u32, + _padding1: u32, + _padding2: u32, +} + +// Note: All buffers use read_write due to LayoutKey-based pipeline layout +@group(0) @binding(0) var elc_row_ptrs: array; +@group(0) @binding(1) var elc_col_indices: array; +@group(0) @binding(2) var elc_l_counts: array; +@group(0) @binding(3) var elc_params: ExtractLowerCountParams; + +@compute @workgroup_size(256) +fn extract_lower_count(@builtin(global_invocation_id) gid: vec3) { + let row = i32(gid.x); + if (u32(row) >= elc_params.n) { + return; + } + + let start = elc_row_ptrs[row]; + let end = elc_row_ptrs[row + 1]; + + var count = 0i; + + for (var idx = start; idx < end; idx = idx + 1) { + let col = elc_col_indices[idx]; + if (col <= row) { + count = count + 1; + } + } + + elc_l_counts[row] = count; +} + +// ============================================================================ +// extract_lower_scatter_f32 +// ============================================================================ + +struct ExtractLowerScatterParams { + n: u32, + _padding0: u32, + _padding1: u32, + _padding2: u32, +} + +// Note: All buffers use read_write due to LayoutKey-based pipeline layout +@group(0) @binding(0) var els_row_ptrs: array; +@group(0) @binding(1) var els_col_indices: array; +@group(0) @binding(2) var els_values: array; +@group(0) @binding(3) var els_l_row_ptrs: array; +@group(0) @binding(4) var els_l_col_indices: array; +@group(0) @binding(5) var els_l_values: array; +@group(0) @binding(6) var els_params: ExtractLowerScatterParams; + +@compute @workgroup_size(256) +fn extract_lower_scatter_f32(@builtin(global_invocation_id) gid: vec3) { + let row = i32(gid.x); + if (u32(row) >= els_params.n) { + return; + } + + let src_start = els_row_ptrs[row]; + let src_end = els_row_ptrs[row + 1]; + + var write_pos = els_l_row_ptrs[row]; + + for (var idx = src_start; idx < src_end; idx = idx + 1) { + let col = els_col_indices[idx]; + if (col <= row) { + els_l_col_indices[write_pos] = col; + els_l_values[write_pos] = els_values[idx]; + write_pos = write_pos + 1; + } + } +} diff --git a/src/runtime/wgpu/shaders/sparse_merge_count.wgsl b/src/runtime/wgpu/shaders/sparse_merge_count.wgsl new file mode 100644 index 00000000..7505ade0 --- /dev/null +++ b/src/runtime/wgpu/shaders/sparse_merge_count.wgsl @@ -0,0 +1,244 @@ +// Sparse merge count shaders - type-independent +// +// csr_merge_count: Count output NNZ per row for CSR add/sub (union semantics) +// csr_mul_count: Count output NNZ per row for CSR mul/div (intersection semantics) +// csc_merge_count: Count output NNZ per col for CSC add/sub (union semantics) +// csc_mul_count: Count output NNZ per col for CSC mul/div (intersection semantics) +// exclusive_scan_i32: Sequential exclusive prefix sum + +const WORKGROUP_SIZE: u32 = 256u; + +// ============================================================================ +// csr_merge_count +// ============================================================================ + +struct CsrMergeCountParams { + nrows: u32, +} + +@group(0) @binding(0) var cmc_a_row_ptrs: array; +@group(0) @binding(1) var cmc_a_col_indices: array; +@group(0) @binding(2) var cmc_b_row_ptrs: array; +@group(0) @binding(3) var cmc_b_col_indices: array; +@group(0) @binding(4) var cmc_row_counts: array; +@group(0) @binding(5) var cmc_params: CsrMergeCountParams; + +@compute @workgroup_size(256) +fn csr_merge_count(@builtin(global_invocation_id) gid: vec3) { + let row = gid.x; + if (row >= cmc_params.nrows) { + return; + } + + let a_start = cmc_a_row_ptrs[row]; + let a_end = cmc_a_row_ptrs[row + 1u]; + let b_start = cmc_b_row_ptrs[row]; + let b_end = cmc_b_row_ptrs[row + 1u]; + + var count: i32 = 0; + var i: i32 = a_start; + var j: i32 = b_start; + + // Merge sorted column indices, count unique columns + while (i < a_end && j < b_end) { + let a_col = cmc_a_col_indices[i]; + let b_col = cmc_b_col_indices[j]; + + count = count + 1; + if (a_col < b_col) { + i = i + 1; + } else if (a_col > b_col) { + j = j + 1; + } else { + i = i + 1; + j = j + 1; + } + } + + // Add remaining elements from A + count = count + (a_end - i); + // Add remaining elements from B + count = count + (b_end - j); + + cmc_row_counts[row] = count; +} + +// ============================================================================ +// csr_mul_count +// ============================================================================ + +struct CsrMulCountParams { + nrows: u32, +} + +@group(0) @binding(0) var cmmc_a_row_ptrs: array; +@group(0) @binding(1) var cmmc_a_col_indices: array; +@group(0) @binding(2) var cmmc_b_row_ptrs: array; +@group(0) @binding(3) var cmmc_b_col_indices: array; +@group(0) @binding(4) var cmmc_row_counts: array; +@group(0) @binding(5) var cmmc_params: CsrMulCountParams; + +@compute @workgroup_size(256) +fn csr_mul_count(@builtin(global_invocation_id) gid: vec3) { + let row = gid.x; + if (row >= cmmc_params.nrows) { + return; + } + + let a_start = cmmc_a_row_ptrs[row]; + let a_end = cmmc_a_row_ptrs[row + 1u]; + let b_start = cmmc_b_row_ptrs[row]; + let b_end = cmmc_b_row_ptrs[row + 1u]; + + var count: i32 = 0; + var i: i32 = a_start; + var j: i32 = b_start; + + // Count matching column indices only (intersection) + while (i < a_end && j < b_end) { + let a_col = cmmc_a_col_indices[i]; + let b_col = cmmc_b_col_indices[j]; + + if (a_col < b_col) { + i = i + 1; + } else if (a_col > b_col) { + j = j + 1; + } else { + count = count + 1; + i = i + 1; + j = j + 1; + } + } + + cmmc_row_counts[row] = count; +} + +// ============================================================================ +// csc_merge_count +// ============================================================================ + +struct CscMergeCountParams { + ncols: u32, +} + +@group(0) @binding(0) var csmc_a_col_ptrs: array; +@group(0) @binding(1) var csmc_a_row_indices: array; +@group(0) @binding(2) var csmc_b_col_ptrs: array; +@group(0) @binding(3) var csmc_b_row_indices: array; +@group(0) @binding(4) var csmc_col_counts: array; +@group(0) @binding(5) var csmc_params: CscMergeCountParams; + +@compute @workgroup_size(256) +fn csc_merge_count(@builtin(global_invocation_id) gid: vec3) { + let col = gid.x; + if (col >= csmc_params.ncols) { + return; + } + + let a_start = csmc_a_col_ptrs[col]; + let a_end = csmc_a_col_ptrs[col + 1u]; + let b_start = csmc_b_col_ptrs[col]; + let b_end = csmc_b_col_ptrs[col + 1u]; + + var count: i32 = 0; + var i: i32 = a_start; + var j: i32 = b_start; + + while (i < a_end && j < b_end) { + let a_row = csmc_a_row_indices[i]; + let b_row = csmc_b_row_indices[j]; + + count = count + 1; + if (a_row < b_row) { + i = i + 1; + } else if (a_row > b_row) { + j = j + 1; + } else { + i = i + 1; + j = j + 1; + } + } + + count = count + (a_end - i); + count = count + (b_end - j); + + csmc_col_counts[col] = count; +} + +// ============================================================================ +// csc_mul_count +// ============================================================================ + +struct CscMulCountParams { + ncols: u32, +} + +@group(0) @binding(0) var csmmc_a_col_ptrs: array; +@group(0) @binding(1) var csmmc_a_row_indices: array; +@group(0) @binding(2) var csmmc_b_col_ptrs: array; +@group(0) @binding(3) var csmmc_b_row_indices: array; +@group(0) @binding(4) var csmmc_col_counts: array; +@group(0) @binding(5) var csmmc_params: CscMulCountParams; + +@compute @workgroup_size(256) +fn csc_mul_count(@builtin(global_invocation_id) gid: vec3) { + let col = gid.x; + if (col >= csmmc_params.ncols) { + return; + } + + let a_start = csmmc_a_col_ptrs[col]; + let a_end = csmmc_a_col_ptrs[col + 1u]; + let b_start = csmmc_b_col_ptrs[col]; + let b_end = csmmc_b_col_ptrs[col + 1u]; + + var count: i32 = 0; + var i: i32 = a_start; + var j: i32 = b_start; + + while (i < a_end && j < b_end) { + let a_row = csmmc_a_row_indices[i]; + let b_row = csmmc_b_row_indices[j]; + + if (a_row < b_row) { + i = i + 1; + } else if (a_row > b_row) { + j = j + 1; + } else { + count = count + 1; + i = i + 1; + j = j + 1; + } + } + + csmmc_col_counts[col] = count; +} + +// ============================================================================ +// exclusive_scan_i32 +// ============================================================================ + +struct ScanParams { + n: u32, +} + +@group(0) @binding(0) var scan_input: array; +@group(0) @binding(1) var scan_output: array; +@group(0) @binding(2) var scan_params: ScanParams; + +// Sequential exclusive scan - only first thread does work +@compute @workgroup_size(1) +fn exclusive_scan_i32(@builtin(global_invocation_id) gid: vec3) { + if (gid.x != 0u) { + return; + } + + var sum: i32 = 0; + for (var i: u32 = 0u; i < scan_params.n; i = i + 1u) { + let val = scan_input[i]; + scan_output[i] = sum; + sum = sum + val; + } + // Final element is total sum + scan_output[scan_params.n] = sum; +} diff --git a/src/runtime/wgpu/shaders/sparse_merge_f32.wgsl b/src/runtime/wgpu/shaders/sparse_merge_f32.wgsl new file mode 100644 index 00000000..9182a36a --- /dev/null +++ b/src/runtime/wgpu/shaders/sparse_merge_f32.wgsl @@ -0,0 +1,524 @@ +// Sparse merge compute shaders - F32 +// +// CSR: csr_add_compute_f32, csr_sub_compute_f32, csr_mul_compute_f32, csr_div_compute_f32 +// CSC: csc_add_compute_f32, csc_sub_compute_f32, csc_mul_compute_f32, csc_div_compute_f32 + +// ============================================================================ +// csr_add_compute_f32 (union semantics) +// ============================================================================ + +struct CsrAddParams { + nrows: u32, +} + +@group(0) @binding(0) var csr_add_a_row_ptrs: array; +@group(0) @binding(1) var csr_add_a_col_indices: array; +@group(0) @binding(2) var csr_add_a_values: array; +@group(0) @binding(3) var csr_add_b_row_ptrs: array; +@group(0) @binding(4) var csr_add_b_col_indices: array; +@group(0) @binding(5) var csr_add_b_values: array; +@group(0) @binding(6) var csr_add_out_row_ptrs: array; +@group(0) @binding(7) var csr_add_out_col_indices: array; +@group(0) @binding(8) var csr_add_out_values: array; +@group(0) @binding(9) var csr_add_params: CsrAddParams; + +@compute @workgroup_size(256) +fn csr_add_compute_f32(@builtin(global_invocation_id) gid: vec3) { + let row = gid.x; + if (row >= csr_add_params.nrows) { + return; + } + + let a_start = csr_add_a_row_ptrs[row]; + let a_end = csr_add_a_row_ptrs[row + 1u]; + let b_start = csr_add_b_row_ptrs[row]; + let b_end = csr_add_b_row_ptrs[row + 1u]; + + var out_idx = csr_add_out_row_ptrs[row]; + var i: i32 = a_start; + var j: i32 = b_start; + + while (i < a_end && j < b_end) { + let a_col = csr_add_a_col_indices[i]; + let b_col = csr_add_b_col_indices[j]; + let a_val = csr_add_a_values[i]; + let b_val = csr_add_b_values[j]; + + if (a_col < b_col) { + csr_add_out_col_indices[out_idx] = a_col; + csr_add_out_values[out_idx] = a_val; + out_idx = out_idx + 1; + i = i + 1; + } else if (a_col > b_col) { + csr_add_out_col_indices[out_idx] = b_col; + csr_add_out_values[out_idx] = b_val; + out_idx = out_idx + 1; + j = j + 1; + } else { + csr_add_out_col_indices[out_idx] = a_col; + csr_add_out_values[out_idx] = a_val + b_val; + out_idx = out_idx + 1; + i = i + 1; + j = j + 1; + } + } + + while (i < a_end) { + csr_add_out_col_indices[out_idx] = csr_add_a_col_indices[i]; + csr_add_out_values[out_idx] = csr_add_a_values[i]; + out_idx = out_idx + 1; + i = i + 1; + } + + while (j < b_end) { + csr_add_out_col_indices[out_idx] = csr_add_b_col_indices[j]; + csr_add_out_values[out_idx] = csr_add_b_values[j]; + out_idx = out_idx + 1; + j = j + 1; + } +} + +// ============================================================================ +// csr_sub_compute_f32 (union semantics) +// ============================================================================ + +struct CsrSubParams { + nrows: u32, +} + +@group(0) @binding(0) var csr_sub_a_row_ptrs: array; +@group(0) @binding(1) var csr_sub_a_col_indices: array; +@group(0) @binding(2) var csr_sub_a_values: array; +@group(0) @binding(3) var csr_sub_b_row_ptrs: array; +@group(0) @binding(4) var csr_sub_b_col_indices: array; +@group(0) @binding(5) var csr_sub_b_values: array; +@group(0) @binding(6) var csr_sub_out_row_ptrs: array; +@group(0) @binding(7) var csr_sub_out_col_indices: array; +@group(0) @binding(8) var csr_sub_out_values: array; +@group(0) @binding(9) var csr_sub_params: CsrSubParams; + +@compute @workgroup_size(256) +fn csr_sub_compute_f32(@builtin(global_invocation_id) gid: vec3) { + let row = gid.x; + if (row >= csr_sub_params.nrows) { + return; + } + + let a_start = csr_sub_a_row_ptrs[row]; + let a_end = csr_sub_a_row_ptrs[row + 1u]; + let b_start = csr_sub_b_row_ptrs[row]; + let b_end = csr_sub_b_row_ptrs[row + 1u]; + + var out_idx = csr_sub_out_row_ptrs[row]; + var i: i32 = a_start; + var j: i32 = b_start; + + while (i < a_end && j < b_end) { + let a_col = csr_sub_a_col_indices[i]; + let b_col = csr_sub_b_col_indices[j]; + let a_val = csr_sub_a_values[i]; + let b_val = csr_sub_b_values[j]; + + if (a_col < b_col) { + csr_sub_out_col_indices[out_idx] = a_col; + csr_sub_out_values[out_idx] = a_val; + out_idx = out_idx + 1; + i = i + 1; + } else if (a_col > b_col) { + csr_sub_out_col_indices[out_idx] = b_col; + csr_sub_out_values[out_idx] = -b_val; + out_idx = out_idx + 1; + j = j + 1; + } else { + csr_sub_out_col_indices[out_idx] = a_col; + csr_sub_out_values[out_idx] = a_val - b_val; + out_idx = out_idx + 1; + i = i + 1; + j = j + 1; + } + } + + while (i < a_end) { + csr_sub_out_col_indices[out_idx] = csr_sub_a_col_indices[i]; + csr_sub_out_values[out_idx] = csr_sub_a_values[i]; + out_idx = out_idx + 1; + i = i + 1; + } + + while (j < b_end) { + csr_sub_out_col_indices[out_idx] = csr_sub_b_col_indices[j]; + csr_sub_out_values[out_idx] = -csr_sub_b_values[j]; + out_idx = out_idx + 1; + j = j + 1; + } +} + +// ============================================================================ +// csr_mul_compute_f32 (intersection semantics) +// ============================================================================ + +struct CsrMulParams { + nrows: u32, +} + +@group(0) @binding(0) var csr_mul_a_row_ptrs: array; +@group(0) @binding(1) var csr_mul_a_col_indices: array; +@group(0) @binding(2) var csr_mul_a_values: array; +@group(0) @binding(3) var csr_mul_b_row_ptrs: array; +@group(0) @binding(4) var csr_mul_b_col_indices: array; +@group(0) @binding(5) var csr_mul_b_values: array; +@group(0) @binding(6) var csr_mul_out_row_ptrs: array; +@group(0) @binding(7) var csr_mul_out_col_indices: array; +@group(0) @binding(8) var csr_mul_out_values: array; +@group(0) @binding(9) var csr_mul_params: CsrMulParams; + +@compute @workgroup_size(256) +fn csr_mul_compute_f32(@builtin(global_invocation_id) gid: vec3) { + let row = gid.x; + if (row >= csr_mul_params.nrows) { + return; + } + + let a_start = csr_mul_a_row_ptrs[row]; + let a_end = csr_mul_a_row_ptrs[row + 1u]; + let b_start = csr_mul_b_row_ptrs[row]; + let b_end = csr_mul_b_row_ptrs[row + 1u]; + + var out_idx = csr_mul_out_row_ptrs[row]; + var i: i32 = a_start; + var j: i32 = b_start; + + while (i < a_end && j < b_end) { + let a_col = csr_mul_a_col_indices[i]; + let b_col = csr_mul_b_col_indices[j]; + + if (a_col < b_col) { + i = i + 1; + } else if (a_col > b_col) { + j = j + 1; + } else { + let a_val = csr_mul_a_values[i]; + let b_val = csr_mul_b_values[j]; + csr_mul_out_col_indices[out_idx] = a_col; + csr_mul_out_values[out_idx] = a_val * b_val; + out_idx = out_idx + 1; + i = i + 1; + j = j + 1; + } + } +} + +// ============================================================================ +// csr_div_compute_f32 (intersection semantics) +// ============================================================================ + +struct CsrDivParams { + nrows: u32, +} + +@group(0) @binding(0) var csr_div_a_row_ptrs: array; +@group(0) @binding(1) var csr_div_a_col_indices: array; +@group(0) @binding(2) var csr_div_a_values: array; +@group(0) @binding(3) var csr_div_b_row_ptrs: array; +@group(0) @binding(4) var csr_div_b_col_indices: array; +@group(0) @binding(5) var csr_div_b_values: array; +@group(0) @binding(6) var csr_div_out_row_ptrs: array; +@group(0) @binding(7) var csr_div_out_col_indices: array; +@group(0) @binding(8) var csr_div_out_values: array; +@group(0) @binding(9) var csr_div_params: CsrDivParams; + +@compute @workgroup_size(256) +fn csr_div_compute_f32(@builtin(global_invocation_id) gid: vec3) { + let row = gid.x; + if (row >= csr_div_params.nrows) { + return; + } + + let a_start = csr_div_a_row_ptrs[row]; + let a_end = csr_div_a_row_ptrs[row + 1u]; + let b_start = csr_div_b_row_ptrs[row]; + let b_end = csr_div_b_row_ptrs[row + 1u]; + + var out_idx = csr_div_out_row_ptrs[row]; + var i: i32 = a_start; + var j: i32 = b_start; + + while (i < a_end && j < b_end) { + let a_col = csr_div_a_col_indices[i]; + let b_col = csr_div_b_col_indices[j]; + + if (a_col < b_col) { + i = i + 1; + } else if (a_col > b_col) { + j = j + 1; + } else { + let a_val = csr_div_a_values[i]; + let b_val = csr_div_b_values[j]; + csr_div_out_col_indices[out_idx] = a_col; + csr_div_out_values[out_idx] = a_val / b_val; + out_idx = out_idx + 1; + i = i + 1; + j = j + 1; + } + } +} + +// ============================================================================ +// csc_add_compute_f32 (union semantics) +// ============================================================================ + +struct CscAddParams { + ncols: u32, +} + +@group(0) @binding(0) var csc_add_a_col_ptrs: array; +@group(0) @binding(1) var csc_add_a_row_indices: array; +@group(0) @binding(2) var csc_add_a_values: array; +@group(0) @binding(3) var csc_add_b_col_ptrs: array; +@group(0) @binding(4) var csc_add_b_row_indices: array; +@group(0) @binding(5) var csc_add_b_values: array; +@group(0) @binding(6) var csc_add_out_col_ptrs: array; +@group(0) @binding(7) var csc_add_out_row_indices: array; +@group(0) @binding(8) var csc_add_out_values: array; +@group(0) @binding(9) var csc_add_params: CscAddParams; + +@compute @workgroup_size(256) +fn csc_add_compute_f32(@builtin(global_invocation_id) gid: vec3) { + let col = gid.x; + if (col >= csc_add_params.ncols) { + return; + } + + let a_start = csc_add_a_col_ptrs[col]; + let a_end = csc_add_a_col_ptrs[col + 1u]; + let b_start = csc_add_b_col_ptrs[col]; + let b_end = csc_add_b_col_ptrs[col + 1u]; + + var out_idx = csc_add_out_col_ptrs[col]; + var i: i32 = a_start; + var j: i32 = b_start; + + while (i < a_end && j < b_end) { + let a_row = csc_add_a_row_indices[i]; + let b_row = csc_add_b_row_indices[j]; + let a_val = csc_add_a_values[i]; + let b_val = csc_add_b_values[j]; + + if (a_row < b_row) { + csc_add_out_row_indices[out_idx] = a_row; + csc_add_out_values[out_idx] = a_val; + out_idx = out_idx + 1; + i = i + 1; + } else if (a_row > b_row) { + csc_add_out_row_indices[out_idx] = b_row; + csc_add_out_values[out_idx] = b_val; + out_idx = out_idx + 1; + j = j + 1; + } else { + csc_add_out_row_indices[out_idx] = a_row; + csc_add_out_values[out_idx] = a_val + b_val; + out_idx = out_idx + 1; + i = i + 1; + j = j + 1; + } + } + + while (i < a_end) { + csc_add_out_row_indices[out_idx] = csc_add_a_row_indices[i]; + csc_add_out_values[out_idx] = csc_add_a_values[i]; + out_idx = out_idx + 1; + i = i + 1; + } + + while (j < b_end) { + csc_add_out_row_indices[out_idx] = csc_add_b_row_indices[j]; + csc_add_out_values[out_idx] = csc_add_b_values[j]; + out_idx = out_idx + 1; + j = j + 1; + } +} + +// ============================================================================ +// csc_sub_compute_f32 (union semantics) +// ============================================================================ + +struct CscSubParams { + ncols: u32, +} + +@group(0) @binding(0) var csc_sub_a_col_ptrs: array; +@group(0) @binding(1) var csc_sub_a_row_indices: array; +@group(0) @binding(2) var csc_sub_a_values: array; +@group(0) @binding(3) var csc_sub_b_col_ptrs: array; +@group(0) @binding(4) var csc_sub_b_row_indices: array; +@group(0) @binding(5) var csc_sub_b_values: array; +@group(0) @binding(6) var csc_sub_out_col_ptrs: array; +@group(0) @binding(7) var csc_sub_out_row_indices: array; +@group(0) @binding(8) var csc_sub_out_values: array; +@group(0) @binding(9) var csc_sub_params: CscSubParams; + +@compute @workgroup_size(256) +fn csc_sub_compute_f32(@builtin(global_invocation_id) gid: vec3) { + let col = gid.x; + if (col >= csc_sub_params.ncols) { + return; + } + + let a_start = csc_sub_a_col_ptrs[col]; + let a_end = csc_sub_a_col_ptrs[col + 1u]; + let b_start = csc_sub_b_col_ptrs[col]; + let b_end = csc_sub_b_col_ptrs[col + 1u]; + + var out_idx = csc_sub_out_col_ptrs[col]; + var i: i32 = a_start; + var j: i32 = b_start; + + while (i < a_end && j < b_end) { + let a_row = csc_sub_a_row_indices[i]; + let b_row = csc_sub_b_row_indices[j]; + let a_val = csc_sub_a_values[i]; + let b_val = csc_sub_b_values[j]; + + if (a_row < b_row) { + csc_sub_out_row_indices[out_idx] = a_row; + csc_sub_out_values[out_idx] = a_val; + out_idx = out_idx + 1; + i = i + 1; + } else if (a_row > b_row) { + csc_sub_out_row_indices[out_idx] = b_row; + csc_sub_out_values[out_idx] = -b_val; + out_idx = out_idx + 1; + j = j + 1; + } else { + csc_sub_out_row_indices[out_idx] = a_row; + csc_sub_out_values[out_idx] = a_val - b_val; + out_idx = out_idx + 1; + i = i + 1; + j = j + 1; + } + } + + while (i < a_end) { + csc_sub_out_row_indices[out_idx] = csc_sub_a_row_indices[i]; + csc_sub_out_values[out_idx] = csc_sub_a_values[i]; + out_idx = out_idx + 1; + i = i + 1; + } + + while (j < b_end) { + csc_sub_out_row_indices[out_idx] = csc_sub_b_row_indices[j]; + csc_sub_out_values[out_idx] = -csc_sub_b_values[j]; + out_idx = out_idx + 1; + j = j + 1; + } +} + +// ============================================================================ +// csc_mul_compute_f32 (intersection semantics) +// ============================================================================ + +struct CscMulParams { + ncols: u32, +} + +@group(0) @binding(0) var csc_mul_a_col_ptrs: array; +@group(0) @binding(1) var csc_mul_a_row_indices: array; +@group(0) @binding(2) var csc_mul_a_values: array; +@group(0) @binding(3) var csc_mul_b_col_ptrs: array; +@group(0) @binding(4) var csc_mul_b_row_indices: array; +@group(0) @binding(5) var csc_mul_b_values: array; +@group(0) @binding(6) var csc_mul_out_col_ptrs: array; +@group(0) @binding(7) var csc_mul_out_row_indices: array; +@group(0) @binding(8) var csc_mul_out_values: array; +@group(0) @binding(9) var csc_mul_params: CscMulParams; + +@compute @workgroup_size(256) +fn csc_mul_compute_f32(@builtin(global_invocation_id) gid: vec3) { + let col = gid.x; + if (col >= csc_mul_params.ncols) { + return; + } + + let a_start = csc_mul_a_col_ptrs[col]; + let a_end = csc_mul_a_col_ptrs[col + 1u]; + let b_start = csc_mul_b_col_ptrs[col]; + let b_end = csc_mul_b_col_ptrs[col + 1u]; + + var out_idx = csc_mul_out_col_ptrs[col]; + var i: i32 = a_start; + var j: i32 = b_start; + + while (i < a_end && j < b_end) { + let a_row = csc_mul_a_row_indices[i]; + let b_row = csc_mul_b_row_indices[j]; + + if (a_row < b_row) { + i = i + 1; + } else if (a_row > b_row) { + j = j + 1; + } else { + let a_val = csc_mul_a_values[i]; + let b_val = csc_mul_b_values[j]; + csc_mul_out_row_indices[out_idx] = a_row; + csc_mul_out_values[out_idx] = a_val * b_val; + out_idx = out_idx + 1; + i = i + 1; + j = j + 1; + } + } +} + +// ============================================================================ +// csc_div_compute_f32 (intersection semantics) +// ============================================================================ + +struct CscDivParams { + ncols: u32, +} + +@group(0) @binding(0) var csc_div_a_col_ptrs: array; +@group(0) @binding(1) var csc_div_a_row_indices: array; +@group(0) @binding(2) var csc_div_a_values: array; +@group(0) @binding(3) var csc_div_b_col_ptrs: array; +@group(0) @binding(4) var csc_div_b_row_indices: array; +@group(0) @binding(5) var csc_div_b_values: array; +@group(0) @binding(6) var csc_div_out_col_ptrs: array; +@group(0) @binding(7) var csc_div_out_row_indices: array; +@group(0) @binding(8) var csc_div_out_values: array; +@group(0) @binding(9) var csc_div_params: CscDivParams; + +@compute @workgroup_size(256) +fn csc_div_compute_f32(@builtin(global_invocation_id) gid: vec3) { + let col = gid.x; + if (col >= csc_div_params.ncols) { + return; + } + + let a_start = csc_div_a_col_ptrs[col]; + let a_end = csc_div_a_col_ptrs[col + 1u]; + let b_start = csc_div_b_col_ptrs[col]; + let b_end = csc_div_b_col_ptrs[col + 1u]; + + var out_idx = csc_div_out_col_ptrs[col]; + var i: i32 = a_start; + var j: i32 = b_start; + + while (i < a_end && j < b_end) { + let a_row = csc_div_a_row_indices[i]; + let b_row = csc_div_b_row_indices[j]; + + if (a_row < b_row) { + i = i + 1; + } else if (a_row > b_row) { + j = j + 1; + } else { + let a_val = csc_div_a_values[i]; + let b_val = csc_div_b_values[j]; + csc_div_out_row_indices[out_idx] = a_row; + csc_div_out_values[out_idx] = a_val / b_val; + out_idx = out_idx + 1; + i = i + 1; + j = j + 1; + } + } +} diff --git a/src/runtime/wgpu/shaders/sparse_merge_i32.wgsl b/src/runtime/wgpu/shaders/sparse_merge_i32.wgsl new file mode 100644 index 00000000..9eae9c4e --- /dev/null +++ b/src/runtime/wgpu/shaders/sparse_merge_i32.wgsl @@ -0,0 +1,524 @@ +// Sparse merge compute shaders - I32 +// +// CSR: csr_add_compute_i32, csr_sub_compute_i32, csr_mul_compute_i32, csr_div_compute_i32 +// CSC: csc_add_compute_i32, csc_sub_compute_i32, csc_mul_compute_i32, csc_div_compute_i32 + +// ============================================================================ +// csr_add_compute_i32 (union semantics) +// ============================================================================ + +struct CsrAddI32Params { + nrows: u32, +} + +@group(0) @binding(0) var csr_add_i32_a_row_ptrs: array; +@group(0) @binding(1) var csr_add_i32_a_col_indices: array; +@group(0) @binding(2) var csr_add_i32_a_values: array; +@group(0) @binding(3) var csr_add_i32_b_row_ptrs: array; +@group(0) @binding(4) var csr_add_i32_b_col_indices: array; +@group(0) @binding(5) var csr_add_i32_b_values: array; +@group(0) @binding(6) var csr_add_i32_out_row_ptrs: array; +@group(0) @binding(7) var csr_add_i32_out_col_indices: array; +@group(0) @binding(8) var csr_add_i32_out_values: array; +@group(0) @binding(9) var csr_add_i32_params: CsrAddI32Params; + +@compute @workgroup_size(256) +fn csr_add_compute_i32(@builtin(global_invocation_id) gid: vec3) { + let row = gid.x; + if (row >= csr_add_i32_params.nrows) { + return; + } + + let a_start = csr_add_i32_a_row_ptrs[row]; + let a_end = csr_add_i32_a_row_ptrs[row + 1u]; + let b_start = csr_add_i32_b_row_ptrs[row]; + let b_end = csr_add_i32_b_row_ptrs[row + 1u]; + + var out_idx = csr_add_i32_out_row_ptrs[row]; + var i: i32 = a_start; + var j: i32 = b_start; + + while (i < a_end && j < b_end) { + let a_col = csr_add_i32_a_col_indices[i]; + let b_col = csr_add_i32_b_col_indices[j]; + let a_val = csr_add_i32_a_values[i]; + let b_val = csr_add_i32_b_values[j]; + + if (a_col < b_col) { + csr_add_i32_out_col_indices[out_idx] = a_col; + csr_add_i32_out_values[out_idx] = a_val; + out_idx = out_idx + 1; + i = i + 1; + } else if (a_col > b_col) { + csr_add_i32_out_col_indices[out_idx] = b_col; + csr_add_i32_out_values[out_idx] = b_val; + out_idx = out_idx + 1; + j = j + 1; + } else { + csr_add_i32_out_col_indices[out_idx] = a_col; + csr_add_i32_out_values[out_idx] = a_val + b_val; + out_idx = out_idx + 1; + i = i + 1; + j = j + 1; + } + } + + while (i < a_end) { + csr_add_i32_out_col_indices[out_idx] = csr_add_i32_a_col_indices[i]; + csr_add_i32_out_values[out_idx] = csr_add_i32_a_values[i]; + out_idx = out_idx + 1; + i = i + 1; + } + + while (j < b_end) { + csr_add_i32_out_col_indices[out_idx] = csr_add_i32_b_col_indices[j]; + csr_add_i32_out_values[out_idx] = csr_add_i32_b_values[j]; + out_idx = out_idx + 1; + j = j + 1; + } +} + +// ============================================================================ +// csr_sub_compute_i32 (union semantics) +// ============================================================================ + +struct CsrSubI32Params { + nrows: u32, +} + +@group(0) @binding(0) var csr_sub_i32_a_row_ptrs: array; +@group(0) @binding(1) var csr_sub_i32_a_col_indices: array; +@group(0) @binding(2) var csr_sub_i32_a_values: array; +@group(0) @binding(3) var csr_sub_i32_b_row_ptrs: array; +@group(0) @binding(4) var csr_sub_i32_b_col_indices: array; +@group(0) @binding(5) var csr_sub_i32_b_values: array; +@group(0) @binding(6) var csr_sub_i32_out_row_ptrs: array; +@group(0) @binding(7) var csr_sub_i32_out_col_indices: array; +@group(0) @binding(8) var csr_sub_i32_out_values: array; +@group(0) @binding(9) var csr_sub_i32_params: CsrSubI32Params; + +@compute @workgroup_size(256) +fn csr_sub_compute_i32(@builtin(global_invocation_id) gid: vec3) { + let row = gid.x; + if (row >= csr_sub_i32_params.nrows) { + return; + } + + let a_start = csr_sub_i32_a_row_ptrs[row]; + let a_end = csr_sub_i32_a_row_ptrs[row + 1u]; + let b_start = csr_sub_i32_b_row_ptrs[row]; + let b_end = csr_sub_i32_b_row_ptrs[row + 1u]; + + var out_idx = csr_sub_i32_out_row_ptrs[row]; + var i: i32 = a_start; + var j: i32 = b_start; + + while (i < a_end && j < b_end) { + let a_col = csr_sub_i32_a_col_indices[i]; + let b_col = csr_sub_i32_b_col_indices[j]; + let a_val = csr_sub_i32_a_values[i]; + let b_val = csr_sub_i32_b_values[j]; + + if (a_col < b_col) { + csr_sub_i32_out_col_indices[out_idx] = a_col; + csr_sub_i32_out_values[out_idx] = a_val; + out_idx = out_idx + 1; + i = i + 1; + } else if (a_col > b_col) { + csr_sub_i32_out_col_indices[out_idx] = b_col; + csr_sub_i32_out_values[out_idx] = -b_val; + out_idx = out_idx + 1; + j = j + 1; + } else { + csr_sub_i32_out_col_indices[out_idx] = a_col; + csr_sub_i32_out_values[out_idx] = a_val - b_val; + out_idx = out_idx + 1; + i = i + 1; + j = j + 1; + } + } + + while (i < a_end) { + csr_sub_i32_out_col_indices[out_idx] = csr_sub_i32_a_col_indices[i]; + csr_sub_i32_out_values[out_idx] = csr_sub_i32_a_values[i]; + out_idx = out_idx + 1; + i = i + 1; + } + + while (j < b_end) { + csr_sub_i32_out_col_indices[out_idx] = csr_sub_i32_b_col_indices[j]; + csr_sub_i32_out_values[out_idx] = -csr_sub_i32_b_values[j]; + out_idx = out_idx + 1; + j = j + 1; + } +} + +// ============================================================================ +// csr_mul_compute_i32 (intersection semantics) +// ============================================================================ + +struct CsrMulI32Params { + nrows: u32, +} + +@group(0) @binding(0) var csr_mul_i32_a_row_ptrs: array; +@group(0) @binding(1) var csr_mul_i32_a_col_indices: array; +@group(0) @binding(2) var csr_mul_i32_a_values: array; +@group(0) @binding(3) var csr_mul_i32_b_row_ptrs: array; +@group(0) @binding(4) var csr_mul_i32_b_col_indices: array; +@group(0) @binding(5) var csr_mul_i32_b_values: array; +@group(0) @binding(6) var csr_mul_i32_out_row_ptrs: array; +@group(0) @binding(7) var csr_mul_i32_out_col_indices: array; +@group(0) @binding(8) var csr_mul_i32_out_values: array; +@group(0) @binding(9) var csr_mul_i32_params: CsrMulI32Params; + +@compute @workgroup_size(256) +fn csr_mul_compute_i32(@builtin(global_invocation_id) gid: vec3) { + let row = gid.x; + if (row >= csr_mul_i32_params.nrows) { + return; + } + + let a_start = csr_mul_i32_a_row_ptrs[row]; + let a_end = csr_mul_i32_a_row_ptrs[row + 1u]; + let b_start = csr_mul_i32_b_row_ptrs[row]; + let b_end = csr_mul_i32_b_row_ptrs[row + 1u]; + + var out_idx = csr_mul_i32_out_row_ptrs[row]; + var i: i32 = a_start; + var j: i32 = b_start; + + while (i < a_end && j < b_end) { + let a_col = csr_mul_i32_a_col_indices[i]; + let b_col = csr_mul_i32_b_col_indices[j]; + + if (a_col < b_col) { + i = i + 1; + } else if (a_col > b_col) { + j = j + 1; + } else { + let a_val = csr_mul_i32_a_values[i]; + let b_val = csr_mul_i32_b_values[j]; + csr_mul_i32_out_col_indices[out_idx] = a_col; + csr_mul_i32_out_values[out_idx] = a_val * b_val; + out_idx = out_idx + 1; + i = i + 1; + j = j + 1; + } + } +} + +// ============================================================================ +// csr_div_compute_i32 (intersection semantics) +// ============================================================================ + +struct CsrDivI32Params { + nrows: u32, +} + +@group(0) @binding(0) var csr_div_i32_a_row_ptrs: array; +@group(0) @binding(1) var csr_div_i32_a_col_indices: array; +@group(0) @binding(2) var csr_div_i32_a_values: array; +@group(0) @binding(3) var csr_div_i32_b_row_ptrs: array; +@group(0) @binding(4) var csr_div_i32_b_col_indices: array; +@group(0) @binding(5) var csr_div_i32_b_values: array; +@group(0) @binding(6) var csr_div_i32_out_row_ptrs: array; +@group(0) @binding(7) var csr_div_i32_out_col_indices: array; +@group(0) @binding(8) var csr_div_i32_out_values: array; +@group(0) @binding(9) var csr_div_i32_params: CsrDivI32Params; + +@compute @workgroup_size(256) +fn csr_div_compute_i32(@builtin(global_invocation_id) gid: vec3) { + let row = gid.x; + if (row >= csr_div_i32_params.nrows) { + return; + } + + let a_start = csr_div_i32_a_row_ptrs[row]; + let a_end = csr_div_i32_a_row_ptrs[row + 1u]; + let b_start = csr_div_i32_b_row_ptrs[row]; + let b_end = csr_div_i32_b_row_ptrs[row + 1u]; + + var out_idx = csr_div_i32_out_row_ptrs[row]; + var i: i32 = a_start; + var j: i32 = b_start; + + while (i < a_end && j < b_end) { + let a_col = csr_div_i32_a_col_indices[i]; + let b_col = csr_div_i32_b_col_indices[j]; + + if (a_col < b_col) { + i = i + 1; + } else if (a_col > b_col) { + j = j + 1; + } else { + let a_val = csr_div_i32_a_values[i]; + let b_val = csr_div_i32_b_values[j]; + csr_div_i32_out_col_indices[out_idx] = a_col; + csr_div_i32_out_values[out_idx] = a_val / b_val; + out_idx = out_idx + 1; + i = i + 1; + j = j + 1; + } + } +} + +// ============================================================================ +// csc_add_compute_i32 (union semantics) +// ============================================================================ + +struct CscAddI32Params { + ncols: u32, +} + +@group(0) @binding(0) var csc_add_i32_a_col_ptrs: array; +@group(0) @binding(1) var csc_add_i32_a_row_indices: array; +@group(0) @binding(2) var csc_add_i32_a_values: array; +@group(0) @binding(3) var csc_add_i32_b_col_ptrs: array; +@group(0) @binding(4) var csc_add_i32_b_row_indices: array; +@group(0) @binding(5) var csc_add_i32_b_values: array; +@group(0) @binding(6) var csc_add_i32_out_col_ptrs: array; +@group(0) @binding(7) var csc_add_i32_out_row_indices: array; +@group(0) @binding(8) var csc_add_i32_out_values: array; +@group(0) @binding(9) var csc_add_i32_params: CscAddI32Params; + +@compute @workgroup_size(256) +fn csc_add_compute_i32(@builtin(global_invocation_id) gid: vec3) { + let col = gid.x; + if (col >= csc_add_i32_params.ncols) { + return; + } + + let a_start = csc_add_i32_a_col_ptrs[col]; + let a_end = csc_add_i32_a_col_ptrs[col + 1u]; + let b_start = csc_add_i32_b_col_ptrs[col]; + let b_end = csc_add_i32_b_col_ptrs[col + 1u]; + + var out_idx = csc_add_i32_out_col_ptrs[col]; + var i: i32 = a_start; + var j: i32 = b_start; + + while (i < a_end && j < b_end) { + let a_row = csc_add_i32_a_row_indices[i]; + let b_row = csc_add_i32_b_row_indices[j]; + let a_val = csc_add_i32_a_values[i]; + let b_val = csc_add_i32_b_values[j]; + + if (a_row < b_row) { + csc_add_i32_out_row_indices[out_idx] = a_row; + csc_add_i32_out_values[out_idx] = a_val; + out_idx = out_idx + 1; + i = i + 1; + } else if (a_row > b_row) { + csc_add_i32_out_row_indices[out_idx] = b_row; + csc_add_i32_out_values[out_idx] = b_val; + out_idx = out_idx + 1; + j = j + 1; + } else { + csc_add_i32_out_row_indices[out_idx] = a_row; + csc_add_i32_out_values[out_idx] = a_val + b_val; + out_idx = out_idx + 1; + i = i + 1; + j = j + 1; + } + } + + while (i < a_end) { + csc_add_i32_out_row_indices[out_idx] = csc_add_i32_a_row_indices[i]; + csc_add_i32_out_values[out_idx] = csc_add_i32_a_values[i]; + out_idx = out_idx + 1; + i = i + 1; + } + + while (j < b_end) { + csc_add_i32_out_row_indices[out_idx] = csc_add_i32_b_row_indices[j]; + csc_add_i32_out_values[out_idx] = csc_add_i32_b_values[j]; + out_idx = out_idx + 1; + j = j + 1; + } +} + +// ============================================================================ +// csc_sub_compute_i32 (union semantics) +// ============================================================================ + +struct CscSubI32Params { + ncols: u32, +} + +@group(0) @binding(0) var csc_sub_i32_a_col_ptrs: array; +@group(0) @binding(1) var csc_sub_i32_a_row_indices: array; +@group(0) @binding(2) var csc_sub_i32_a_values: array; +@group(0) @binding(3) var csc_sub_i32_b_col_ptrs: array; +@group(0) @binding(4) var csc_sub_i32_b_row_indices: array; +@group(0) @binding(5) var csc_sub_i32_b_values: array; +@group(0) @binding(6) var csc_sub_i32_out_col_ptrs: array; +@group(0) @binding(7) var csc_sub_i32_out_row_indices: array; +@group(0) @binding(8) var csc_sub_i32_out_values: array; +@group(0) @binding(9) var csc_sub_i32_params: CscSubI32Params; + +@compute @workgroup_size(256) +fn csc_sub_compute_i32(@builtin(global_invocation_id) gid: vec3) { + let col = gid.x; + if (col >= csc_sub_i32_params.ncols) { + return; + } + + let a_start = csc_sub_i32_a_col_ptrs[col]; + let a_end = csc_sub_i32_a_col_ptrs[col + 1u]; + let b_start = csc_sub_i32_b_col_ptrs[col]; + let b_end = csc_sub_i32_b_col_ptrs[col + 1u]; + + var out_idx = csc_sub_i32_out_col_ptrs[col]; + var i: i32 = a_start; + var j: i32 = b_start; + + while (i < a_end && j < b_end) { + let a_row = csc_sub_i32_a_row_indices[i]; + let b_row = csc_sub_i32_b_row_indices[j]; + let a_val = csc_sub_i32_a_values[i]; + let b_val = csc_sub_i32_b_values[j]; + + if (a_row < b_row) { + csc_sub_i32_out_row_indices[out_idx] = a_row; + csc_sub_i32_out_values[out_idx] = a_val; + out_idx = out_idx + 1; + i = i + 1; + } else if (a_row > b_row) { + csc_sub_i32_out_row_indices[out_idx] = b_row; + csc_sub_i32_out_values[out_idx] = -b_val; + out_idx = out_idx + 1; + j = j + 1; + } else { + csc_sub_i32_out_row_indices[out_idx] = a_row; + csc_sub_i32_out_values[out_idx] = a_val - b_val; + out_idx = out_idx + 1; + i = i + 1; + j = j + 1; + } + } + + while (i < a_end) { + csc_sub_i32_out_row_indices[out_idx] = csc_sub_i32_a_row_indices[i]; + csc_sub_i32_out_values[out_idx] = csc_sub_i32_a_values[i]; + out_idx = out_idx + 1; + i = i + 1; + } + + while (j < b_end) { + csc_sub_i32_out_row_indices[out_idx] = csc_sub_i32_b_row_indices[j]; + csc_sub_i32_out_values[out_idx] = -csc_sub_i32_b_values[j]; + out_idx = out_idx + 1; + j = j + 1; + } +} + +// ============================================================================ +// csc_mul_compute_i32 (intersection semantics) +// ============================================================================ + +struct CscMulI32Params { + ncols: u32, +} + +@group(0) @binding(0) var csc_mul_i32_a_col_ptrs: array; +@group(0) @binding(1) var csc_mul_i32_a_row_indices: array; +@group(0) @binding(2) var csc_mul_i32_a_values: array; +@group(0) @binding(3) var csc_mul_i32_b_col_ptrs: array; +@group(0) @binding(4) var csc_mul_i32_b_row_indices: array; +@group(0) @binding(5) var csc_mul_i32_b_values: array; +@group(0) @binding(6) var csc_mul_i32_out_col_ptrs: array; +@group(0) @binding(7) var csc_mul_i32_out_row_indices: array; +@group(0) @binding(8) var csc_mul_i32_out_values: array; +@group(0) @binding(9) var csc_mul_i32_params: CscMulI32Params; + +@compute @workgroup_size(256) +fn csc_mul_compute_i32(@builtin(global_invocation_id) gid: vec3) { + let col = gid.x; + if (col >= csc_mul_i32_params.ncols) { + return; + } + + let a_start = csc_mul_i32_a_col_ptrs[col]; + let a_end = csc_mul_i32_a_col_ptrs[col + 1u]; + let b_start = csc_mul_i32_b_col_ptrs[col]; + let b_end = csc_mul_i32_b_col_ptrs[col + 1u]; + + var out_idx = csc_mul_i32_out_col_ptrs[col]; + var i: i32 = a_start; + var j: i32 = b_start; + + while (i < a_end && j < b_end) { + let a_row = csc_mul_i32_a_row_indices[i]; + let b_row = csc_mul_i32_b_row_indices[j]; + + if (a_row < b_row) { + i = i + 1; + } else if (a_row > b_row) { + j = j + 1; + } else { + let a_val = csc_mul_i32_a_values[i]; + let b_val = csc_mul_i32_b_values[j]; + csc_mul_i32_out_row_indices[out_idx] = a_row; + csc_mul_i32_out_values[out_idx] = a_val * b_val; + out_idx = out_idx + 1; + i = i + 1; + j = j + 1; + } + } +} + +// ============================================================================ +// csc_div_compute_i32 (intersection semantics) +// ============================================================================ + +struct CscDivI32Params { + ncols: u32, +} + +@group(0) @binding(0) var csc_div_i32_a_col_ptrs: array; +@group(0) @binding(1) var csc_div_i32_a_row_indices: array; +@group(0) @binding(2) var csc_div_i32_a_values: array; +@group(0) @binding(3) var csc_div_i32_b_col_ptrs: array; +@group(0) @binding(4) var csc_div_i32_b_row_indices: array; +@group(0) @binding(5) var csc_div_i32_b_values: array; +@group(0) @binding(6) var csc_div_i32_out_col_ptrs: array; +@group(0) @binding(7) var csc_div_i32_out_row_indices: array; +@group(0) @binding(8) var csc_div_i32_out_values: array; +@group(0) @binding(9) var csc_div_i32_params: CscDivI32Params; + +@compute @workgroup_size(256) +fn csc_div_compute_i32(@builtin(global_invocation_id) gid: vec3) { + let col = gid.x; + if (col >= csc_div_i32_params.ncols) { + return; + } + + let a_start = csc_div_i32_a_col_ptrs[col]; + let a_end = csc_div_i32_a_col_ptrs[col + 1u]; + let b_start = csc_div_i32_b_col_ptrs[col]; + let b_end = csc_div_i32_b_col_ptrs[col + 1u]; + + var out_idx = csc_div_i32_out_col_ptrs[col]; + var i: i32 = a_start; + var j: i32 = b_start; + + while (i < a_end && j < b_end) { + let a_row = csc_div_i32_a_row_indices[i]; + let b_row = csc_div_i32_b_row_indices[j]; + + if (a_row < b_row) { + i = i + 1; + } else if (a_row > b_row) { + j = j + 1; + } else { + let a_val = csc_div_i32_a_values[i]; + let b_val = csc_div_i32_b_values[j]; + csc_div_i32_out_row_indices[out_idx] = a_row; + csc_div_i32_out_values[out_idx] = a_val / b_val; + out_idx = out_idx + 1; + i = i + 1; + j = j + 1; + } + } +} diff --git a/src/runtime/wgpu/shaders/sparse_merge_launcher.rs b/src/runtime/wgpu/shaders/sparse_merge_launcher.rs index c940ecac..8198d675 100644 --- a/src/runtime/wgpu/shaders/sparse_merge_launcher.rs +++ b/src/runtime/wgpu/shaders/sparse_merge_launcher.rs @@ -7,18 +7,41 @@ use wgpu::{Buffer, Queue}; -use super::generator::dtype_suffix; -use super::generator::sparse_merge::{ - generate_csc_add_compute_shader, generate_csc_div_compute_shader, - generate_csc_merge_count_shader, generate_csc_mul_compute_shader, - generate_csc_mul_count_shader, generate_csc_sub_compute_shader, - generate_csr_add_compute_shader, generate_csr_div_compute_shader, - generate_csr_merge_count_shader, generate_csr_mul_compute_shader, - generate_csr_mul_count_shader, generate_csr_sub_compute_shader, generate_exclusive_scan_shader, -}; use super::pipeline::{LayoutKey, PipelineCache, workgroup_count}; use crate::dtype::DType; -use crate::error::Result; +use crate::error::{Error, Result}; + +// Static WGSL shader sources +const SPARSE_MERGE_COUNT: &str = include_str!("sparse_merge_count.wgsl"); +const SPARSE_MERGE_F32: &str = include_str!("sparse_merge_f32.wgsl"); +const SPARSE_MERGE_I32: &str = include_str!("sparse_merge_i32.wgsl"); +const SPARSE_MERGE_U32: &str = include_str!("sparse_merge_u32.wgsl"); + +/// Return (module_key, shader_source) for a dtype-specific merge shader. +fn typed_merge_shader(dtype: DType) -> Result<(&'static str, &'static str)> { + match dtype { + DType::F32 => Ok(("sparse_merge_f32", SPARSE_MERGE_F32)), + DType::I32 => Ok(("sparse_merge_i32", SPARSE_MERGE_I32)), + DType::U32 => Ok(("sparse_merge_u32", SPARSE_MERGE_U32)), + _ => Err(Error::UnsupportedDType { + dtype, + op: "sparse_merge (WebGPU)", + }), + } +} + +/// Return the dtype suffix string for entry point names. +fn dtype_suffix(dtype: DType) -> Result<&'static str> { + match dtype { + DType::F32 => Ok("f32"), + DType::I32 => Ok("i32"), + DType::U32 => Ok("u32"), + _ => Err(Error::UnsupportedDType { + dtype, + op: "sparse_merge (WebGPU)", + }), + } +} // ============================================================================ // CSR Count Kernels @@ -36,8 +59,7 @@ pub fn launch_csr_merge_count( params_buffer: &Buffer, nrows: usize, ) -> Result<()> { - let shader_source = generate_csr_merge_count_shader(); - let module = cache.get_or_create_module_from_source("csr_merge_count", &shader_source); + let module = cache.get_or_create_module("sparse_merge_count", SPARSE_MERGE_COUNT); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 5, // a_row_ptrs, a_col_indices, b_row_ptrs, b_col_indices, row_counts @@ -45,12 +67,8 @@ pub fn launch_csr_merge_count( num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_dynamic_pipeline( - "csr_merge_count", - "csr_merge_count", - &module, - &layout, - ); + let pipeline = + cache.get_or_create_pipeline("sparse_merge_count", "csr_merge_count", &module, &layout); let bind_group = cache.create_bind_group( &layout, @@ -96,8 +114,7 @@ pub fn launch_csr_mul_count( params_buffer: &Buffer, nrows: usize, ) -> Result<()> { - let shader_source = generate_csr_mul_count_shader(); - let module = cache.get_or_create_module_from_source("csr_mul_count", &shader_source); + let module = cache.get_or_create_module("sparse_merge_count", SPARSE_MERGE_COUNT); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 5, @@ -106,7 +123,7 @@ pub fn launch_csr_mul_count( }); let pipeline = - cache.get_or_create_dynamic_pipeline("csr_mul_count", "csr_mul_count", &module, &layout); + cache.get_or_create_pipeline("sparse_merge_count", "csr_mul_count", &module, &layout); let bind_group = cache.create_bind_group( &layout, @@ -161,12 +178,16 @@ pub fn launch_csr_add_compute( nrows: usize, dtype: DType, ) -> Result<()> { + let (module_key, shader) = typed_merge_shader(dtype)?; let suffix = dtype_suffix(dtype)?; - let entry_point = format!("csr_add_compute_{}", suffix); + let entry_point: &'static str = match suffix { + "f32" => "csr_add_compute_f32", + "i32" => "csr_add_compute_i32", + "u32" => "csr_add_compute_u32", + _ => unreachable!(), + }; - let shader_source = generate_csr_add_compute_shader(dtype)?; - let module_name = format!("csr_add_compute_{}", suffix); - let module = cache.get_or_create_module_from_source(&module_name, &shader_source); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 9, // 3 for A, 3 for B, 3 for output @@ -174,8 +195,7 @@ pub fn launch_csr_add_compute( num_readonly_storage: 0, }); - let pipeline = - cache.get_or_create_dynamic_pipeline("csr_add_compute", &entry_point, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group( &layout, @@ -230,12 +250,16 @@ pub fn launch_csr_sub_compute( nrows: usize, dtype: DType, ) -> Result<()> { + let (module_key, shader) = typed_merge_shader(dtype)?; let suffix = dtype_suffix(dtype)?; - let entry_point = format!("csr_sub_compute_{}", suffix); + let entry_point: &'static str = match suffix { + "f32" => "csr_sub_compute_f32", + "i32" => "csr_sub_compute_i32", + "u32" => "csr_sub_compute_u32", + _ => unreachable!(), + }; - let shader_source = generate_csr_sub_compute_shader(dtype)?; - let module_name = format!("csr_sub_compute_{}", suffix); - let module = cache.get_or_create_module_from_source(&module_name, &shader_source); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 9, @@ -243,8 +267,7 @@ pub fn launch_csr_sub_compute( num_readonly_storage: 0, }); - let pipeline = - cache.get_or_create_dynamic_pipeline("csr_sub_compute", &entry_point, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group( &layout, @@ -299,12 +322,16 @@ pub fn launch_csr_mul_compute( nrows: usize, dtype: DType, ) -> Result<()> { + let (module_key, shader) = typed_merge_shader(dtype)?; let suffix = dtype_suffix(dtype)?; - let entry_point = format!("csr_mul_compute_{}", suffix); + let entry_point: &'static str = match suffix { + "f32" => "csr_mul_compute_f32", + "i32" => "csr_mul_compute_i32", + "u32" => "csr_mul_compute_u32", + _ => unreachable!(), + }; - let shader_source = generate_csr_mul_compute_shader(dtype)?; - let module_name = format!("csr_mul_compute_{}", suffix); - let module = cache.get_or_create_module_from_source(&module_name, &shader_source); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 9, @@ -312,8 +339,7 @@ pub fn launch_csr_mul_compute( num_readonly_storage: 0, }); - let pipeline = - cache.get_or_create_dynamic_pipeline("csr_mul_compute", &entry_point, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group( &layout, @@ -368,12 +394,16 @@ pub fn launch_csr_div_compute( nrows: usize, dtype: DType, ) -> Result<()> { + let (module_key, shader) = typed_merge_shader(dtype)?; let suffix = dtype_suffix(dtype)?; - let entry_point = format!("csr_div_compute_{}", suffix); + let entry_point: &'static str = match suffix { + "f32" => "csr_div_compute_f32", + "i32" => "csr_div_compute_i32", + "u32" => "csr_div_compute_u32", + _ => unreachable!(), + }; - let shader_source = generate_csr_div_compute_shader(dtype)?; - let module_name = format!("csr_div_compute_{}", suffix); - let module = cache.get_or_create_module_from_source(&module_name, &shader_source); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 9, @@ -381,8 +411,7 @@ pub fn launch_csr_div_compute( num_readonly_storage: 0, }); - let pipeline = - cache.get_or_create_dynamic_pipeline("csr_div_compute", &entry_point, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group( &layout, @@ -436,8 +465,7 @@ pub fn launch_csc_merge_count( params_buffer: &Buffer, ncols: usize, ) -> Result<()> { - let shader_source = generate_csc_merge_count_shader(); - let module = cache.get_or_create_module_from_source("csc_merge_count", &shader_source); + let module = cache.get_or_create_module("sparse_merge_count", SPARSE_MERGE_COUNT); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 5, @@ -445,12 +473,8 @@ pub fn launch_csc_merge_count( num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_dynamic_pipeline( - "csc_merge_count", - "csc_merge_count", - &module, - &layout, - ); + let pipeline = + cache.get_or_create_pipeline("sparse_merge_count", "csc_merge_count", &module, &layout); let bind_group = cache.create_bind_group( &layout, @@ -496,8 +520,7 @@ pub fn launch_csc_mul_count( params_buffer: &Buffer, ncols: usize, ) -> Result<()> { - let shader_source = generate_csc_mul_count_shader(); - let module = cache.get_or_create_module_from_source("csc_mul_count", &shader_source); + let module = cache.get_or_create_module("sparse_merge_count", SPARSE_MERGE_COUNT); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 5, @@ -506,7 +529,7 @@ pub fn launch_csc_mul_count( }); let pipeline = - cache.get_or_create_dynamic_pipeline("csc_mul_count", "csc_mul_count", &module, &layout); + cache.get_or_create_pipeline("sparse_merge_count", "csc_mul_count", &module, &layout); let bind_group = cache.create_bind_group( &layout, @@ -561,12 +584,16 @@ pub fn launch_csc_add_compute( ncols: usize, dtype: DType, ) -> Result<()> { + let (module_key, shader) = typed_merge_shader(dtype)?; let suffix = dtype_suffix(dtype)?; - let entry_point = format!("csc_add_compute_{}", suffix); + let entry_point: &'static str = match suffix { + "f32" => "csc_add_compute_f32", + "i32" => "csc_add_compute_i32", + "u32" => "csc_add_compute_u32", + _ => unreachable!(), + }; - let shader_source = generate_csc_add_compute_shader(dtype)?; - let module_name = format!("csc_add_compute_{}", suffix); - let module = cache.get_or_create_module_from_source(&module_name, &shader_source); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 9, @@ -574,8 +601,7 @@ pub fn launch_csc_add_compute( num_readonly_storage: 0, }); - let pipeline = - cache.get_or_create_dynamic_pipeline("csc_add_compute", &entry_point, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group( &layout, @@ -630,12 +656,16 @@ pub fn launch_csc_sub_compute( ncols: usize, dtype: DType, ) -> Result<()> { + let (module_key, shader) = typed_merge_shader(dtype)?; let suffix = dtype_suffix(dtype)?; - let entry_point = format!("csc_sub_compute_{}", suffix); + let entry_point: &'static str = match suffix { + "f32" => "csc_sub_compute_f32", + "i32" => "csc_sub_compute_i32", + "u32" => "csc_sub_compute_u32", + _ => unreachable!(), + }; - let shader_source = generate_csc_sub_compute_shader(dtype)?; - let module_name = format!("csc_sub_compute_{}", suffix); - let module = cache.get_or_create_module_from_source(&module_name, &shader_source); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 9, @@ -643,8 +673,7 @@ pub fn launch_csc_sub_compute( num_readonly_storage: 0, }); - let pipeline = - cache.get_or_create_dynamic_pipeline("csc_sub_compute", &entry_point, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group( &layout, @@ -699,12 +728,16 @@ pub fn launch_csc_mul_compute( ncols: usize, dtype: DType, ) -> Result<()> { + let (module_key, shader) = typed_merge_shader(dtype)?; let suffix = dtype_suffix(dtype)?; - let entry_point = format!("csc_mul_compute_{}", suffix); + let entry_point: &'static str = match suffix { + "f32" => "csc_mul_compute_f32", + "i32" => "csc_mul_compute_i32", + "u32" => "csc_mul_compute_u32", + _ => unreachable!(), + }; - let shader_source = generate_csc_mul_compute_shader(dtype)?; - let module_name = format!("csc_mul_compute_{}", suffix); - let module = cache.get_or_create_module_from_source(&module_name, &shader_source); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 9, @@ -712,8 +745,7 @@ pub fn launch_csc_mul_compute( num_readonly_storage: 0, }); - let pipeline = - cache.get_or_create_dynamic_pipeline("csc_mul_compute", &entry_point, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group( &layout, @@ -768,12 +800,16 @@ pub fn launch_csc_div_compute( ncols: usize, dtype: DType, ) -> Result<()> { + let (module_key, shader) = typed_merge_shader(dtype)?; let suffix = dtype_suffix(dtype)?; - let entry_point = format!("csc_div_compute_{}", suffix); + let entry_point: &'static str = match suffix { + "f32" => "csc_div_compute_f32", + "i32" => "csc_div_compute_i32", + "u32" => "csc_div_compute_u32", + _ => unreachable!(), + }; - let shader_source = generate_csc_div_compute_shader(dtype)?; - let module_name = format!("csc_div_compute_{}", suffix); - let module = cache.get_or_create_module_from_source(&module_name, &shader_source); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 9, @@ -781,8 +817,7 @@ pub fn launch_csc_div_compute( num_readonly_storage: 0, }); - let pipeline = - cache.get_or_create_dynamic_pipeline("csc_div_compute", &entry_point, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group( &layout, @@ -835,8 +870,7 @@ pub fn launch_exclusive_scan_i32( output: &Buffer, params_buffer: &Buffer, ) -> Result<()> { - let shader_source = generate_exclusive_scan_shader(); - let module = cache.get_or_create_module_from_source("exclusive_scan_i32", &shader_source); + let module = cache.get_or_create_module("sparse_merge_count", SPARSE_MERGE_COUNT); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, @@ -844,12 +878,8 @@ pub fn launch_exclusive_scan_i32( num_readonly_storage: 1, }); - let pipeline = cache.get_or_create_dynamic_pipeline( - "exclusive_scan_i32", - "exclusive_scan_i32", - &module, - &layout, - ); + let pipeline = + cache.get_or_create_pipeline("sparse_merge_count", "exclusive_scan_i32", &module, &layout); let bind_group = cache.create_bind_group(&layout, &[input, output, params_buffer]); @@ -873,44 +903,3 @@ pub fn launch_exclusive_scan_i32( queue.submit(std::iter::once(encoder.finish())); Ok(()) } - -#[cfg(test)] -mod tests { - use super::*; - - fn validate_wgsl_syntax(source: &str) -> std::result::Result<(), String> { - use wgpu::naga::front::wgsl; - let mut frontend = wgsl::Frontend::new(); - frontend - .parse(source) - .map(|_| ()) - .map_err(|e| format!("WGSL parse error: {e}")) - } - - #[test] - fn test_generated_shaders_are_valid() { - // Test all generated shaders have valid syntax - validate_wgsl_syntax(&generate_csr_merge_count_shader()) - .expect("CSR merge count should be valid"); - validate_wgsl_syntax(&generate_csr_mul_count_shader()) - .expect("CSR mul count should be valid"); - validate_wgsl_syntax(&generate_csc_merge_count_shader()) - .expect("CSC merge count should be valid"); - validate_wgsl_syntax(&generate_csc_mul_count_shader()) - .expect("CSC mul count should be valid"); - validate_wgsl_syntax(&generate_exclusive_scan_shader()) - .expect("Exclusive scan should be valid"); - - // Test compute shaders for F32 - validate_wgsl_syntax(&generate_csr_add_compute_shader(DType::F32).unwrap()) - .expect("CSR add compute should be valid"); - validate_wgsl_syntax(&generate_csr_sub_compute_shader(DType::F32).unwrap()) - .expect("CSR sub compute should be valid"); - validate_wgsl_syntax(&generate_csr_mul_compute_shader(DType::F32).unwrap()) - .expect("CSR mul compute should be valid"); - validate_wgsl_syntax(&generate_csr_div_compute_shader(DType::F32).unwrap()) - .expect("CSR div compute should be valid"); - validate_wgsl_syntax(&generate_csc_add_compute_shader(DType::F32).unwrap()) - .expect("CSC add compute should be valid"); - } -} diff --git a/src/runtime/wgpu/shaders/sparse_merge_u32.wgsl b/src/runtime/wgpu/shaders/sparse_merge_u32.wgsl new file mode 100644 index 00000000..a9551c19 --- /dev/null +++ b/src/runtime/wgpu/shaders/sparse_merge_u32.wgsl @@ -0,0 +1,526 @@ +// Sparse merge compute shaders - U32 +// +// CSR: csr_add_compute_u32, csr_sub_compute_u32, csr_mul_compute_u32, csr_div_compute_u32 +// CSC: csc_add_compute_u32, csc_sub_compute_u32, csc_mul_compute_u32, csc_div_compute_u32 +// +// Note: U32 subtraction uses wrapping arithmetic. Sub b-only case emits 0u - b_val. + +// ============================================================================ +// csr_add_compute_u32 (union semantics) +// ============================================================================ + +struct CsrAddU32Params { + nrows: u32, +} + +@group(0) @binding(0) var csr_add_u32_a_row_ptrs: array; +@group(0) @binding(1) var csr_add_u32_a_col_indices: array; +@group(0) @binding(2) var csr_add_u32_a_values: array; +@group(0) @binding(3) var csr_add_u32_b_row_ptrs: array; +@group(0) @binding(4) var csr_add_u32_b_col_indices: array; +@group(0) @binding(5) var csr_add_u32_b_values: array; +@group(0) @binding(6) var csr_add_u32_out_row_ptrs: array; +@group(0) @binding(7) var csr_add_u32_out_col_indices: array; +@group(0) @binding(8) var csr_add_u32_out_values: array; +@group(0) @binding(9) var csr_add_u32_params: CsrAddU32Params; + +@compute @workgroup_size(256) +fn csr_add_compute_u32(@builtin(global_invocation_id) gid: vec3) { + let row = gid.x; + if (row >= csr_add_u32_params.nrows) { + return; + } + + let a_start = csr_add_u32_a_row_ptrs[row]; + let a_end = csr_add_u32_a_row_ptrs[row + 1u]; + let b_start = csr_add_u32_b_row_ptrs[row]; + let b_end = csr_add_u32_b_row_ptrs[row + 1u]; + + var out_idx = csr_add_u32_out_row_ptrs[row]; + var i: i32 = a_start; + var j: i32 = b_start; + + while (i < a_end && j < b_end) { + let a_col = csr_add_u32_a_col_indices[i]; + let b_col = csr_add_u32_b_col_indices[j]; + let a_val = csr_add_u32_a_values[i]; + let b_val = csr_add_u32_b_values[j]; + + if (a_col < b_col) { + csr_add_u32_out_col_indices[out_idx] = a_col; + csr_add_u32_out_values[out_idx] = a_val; + out_idx = out_idx + 1; + i = i + 1; + } else if (a_col > b_col) { + csr_add_u32_out_col_indices[out_idx] = b_col; + csr_add_u32_out_values[out_idx] = b_val; + out_idx = out_idx + 1; + j = j + 1; + } else { + csr_add_u32_out_col_indices[out_idx] = a_col; + csr_add_u32_out_values[out_idx] = a_val + b_val; + out_idx = out_idx + 1; + i = i + 1; + j = j + 1; + } + } + + while (i < a_end) { + csr_add_u32_out_col_indices[out_idx] = csr_add_u32_a_col_indices[i]; + csr_add_u32_out_values[out_idx] = csr_add_u32_a_values[i]; + out_idx = out_idx + 1; + i = i + 1; + } + + while (j < b_end) { + csr_add_u32_out_col_indices[out_idx] = csr_add_u32_b_col_indices[j]; + csr_add_u32_out_values[out_idx] = csr_add_u32_b_values[j]; + out_idx = out_idx + 1; + j = j + 1; + } +} + +// ============================================================================ +// csr_sub_compute_u32 (union semantics, wrapping subtraction) +// ============================================================================ + +struct CsrSubU32Params { + nrows: u32, +} + +@group(0) @binding(0) var csr_sub_u32_a_row_ptrs: array; +@group(0) @binding(1) var csr_sub_u32_a_col_indices: array; +@group(0) @binding(2) var csr_sub_u32_a_values: array; +@group(0) @binding(3) var csr_sub_u32_b_row_ptrs: array; +@group(0) @binding(4) var csr_sub_u32_b_col_indices: array; +@group(0) @binding(5) var csr_sub_u32_b_values: array; +@group(0) @binding(6) var csr_sub_u32_out_row_ptrs: array; +@group(0) @binding(7) var csr_sub_u32_out_col_indices: array; +@group(0) @binding(8) var csr_sub_u32_out_values: array; +@group(0) @binding(9) var csr_sub_u32_params: CsrSubU32Params; + +@compute @workgroup_size(256) +fn csr_sub_compute_u32(@builtin(global_invocation_id) gid: vec3) { + let row = gid.x; + if (row >= csr_sub_u32_params.nrows) { + return; + } + + let a_start = csr_sub_u32_a_row_ptrs[row]; + let a_end = csr_sub_u32_a_row_ptrs[row + 1u]; + let b_start = csr_sub_u32_b_row_ptrs[row]; + let b_end = csr_sub_u32_b_row_ptrs[row + 1u]; + + var out_idx = csr_sub_u32_out_row_ptrs[row]; + var i: i32 = a_start; + var j: i32 = b_start; + + while (i < a_end && j < b_end) { + let a_col = csr_sub_u32_a_col_indices[i]; + let b_col = csr_sub_u32_b_col_indices[j]; + let a_val = csr_sub_u32_a_values[i]; + let b_val = csr_sub_u32_b_values[j]; + + if (a_col < b_col) { + csr_sub_u32_out_col_indices[out_idx] = a_col; + csr_sub_u32_out_values[out_idx] = a_val; + out_idx = out_idx + 1; + i = i + 1; + } else if (a_col > b_col) { + csr_sub_u32_out_col_indices[out_idx] = b_col; + csr_sub_u32_out_values[out_idx] = 0u - b_val; + out_idx = out_idx + 1; + j = j + 1; + } else { + csr_sub_u32_out_col_indices[out_idx] = a_col; + csr_sub_u32_out_values[out_idx] = a_val - b_val; + out_idx = out_idx + 1; + i = i + 1; + j = j + 1; + } + } + + while (i < a_end) { + csr_sub_u32_out_col_indices[out_idx] = csr_sub_u32_a_col_indices[i]; + csr_sub_u32_out_values[out_idx] = csr_sub_u32_a_values[i]; + out_idx = out_idx + 1; + i = i + 1; + } + + while (j < b_end) { + csr_sub_u32_out_col_indices[out_idx] = csr_sub_u32_b_col_indices[j]; + csr_sub_u32_out_values[out_idx] = 0u - csr_sub_u32_b_values[j]; + out_idx = out_idx + 1; + j = j + 1; + } +} + +// ============================================================================ +// csr_mul_compute_u32 (intersection semantics) +// ============================================================================ + +struct CsrMulU32Params { + nrows: u32, +} + +@group(0) @binding(0) var csr_mul_u32_a_row_ptrs: array; +@group(0) @binding(1) var csr_mul_u32_a_col_indices: array; +@group(0) @binding(2) var csr_mul_u32_a_values: array; +@group(0) @binding(3) var csr_mul_u32_b_row_ptrs: array; +@group(0) @binding(4) var csr_mul_u32_b_col_indices: array; +@group(0) @binding(5) var csr_mul_u32_b_values: array; +@group(0) @binding(6) var csr_mul_u32_out_row_ptrs: array; +@group(0) @binding(7) var csr_mul_u32_out_col_indices: array; +@group(0) @binding(8) var csr_mul_u32_out_values: array; +@group(0) @binding(9) var csr_mul_u32_params: CsrMulU32Params; + +@compute @workgroup_size(256) +fn csr_mul_compute_u32(@builtin(global_invocation_id) gid: vec3) { + let row = gid.x; + if (row >= csr_mul_u32_params.nrows) { + return; + } + + let a_start = csr_mul_u32_a_row_ptrs[row]; + let a_end = csr_mul_u32_a_row_ptrs[row + 1u]; + let b_start = csr_mul_u32_b_row_ptrs[row]; + let b_end = csr_mul_u32_b_row_ptrs[row + 1u]; + + var out_idx = csr_mul_u32_out_row_ptrs[row]; + var i: i32 = a_start; + var j: i32 = b_start; + + while (i < a_end && j < b_end) { + let a_col = csr_mul_u32_a_col_indices[i]; + let b_col = csr_mul_u32_b_col_indices[j]; + + if (a_col < b_col) { + i = i + 1; + } else if (a_col > b_col) { + j = j + 1; + } else { + let a_val = csr_mul_u32_a_values[i]; + let b_val = csr_mul_u32_b_values[j]; + csr_mul_u32_out_col_indices[out_idx] = a_col; + csr_mul_u32_out_values[out_idx] = a_val * b_val; + out_idx = out_idx + 1; + i = i + 1; + j = j + 1; + } + } +} + +// ============================================================================ +// csr_div_compute_u32 (intersection semantics) +// ============================================================================ + +struct CsrDivU32Params { + nrows: u32, +} + +@group(0) @binding(0) var csr_div_u32_a_row_ptrs: array; +@group(0) @binding(1) var csr_div_u32_a_col_indices: array; +@group(0) @binding(2) var csr_div_u32_a_values: array; +@group(0) @binding(3) var csr_div_u32_b_row_ptrs: array; +@group(0) @binding(4) var csr_div_u32_b_col_indices: array; +@group(0) @binding(5) var csr_div_u32_b_values: array; +@group(0) @binding(6) var csr_div_u32_out_row_ptrs: array; +@group(0) @binding(7) var csr_div_u32_out_col_indices: array; +@group(0) @binding(8) var csr_div_u32_out_values: array; +@group(0) @binding(9) var csr_div_u32_params: CsrDivU32Params; + +@compute @workgroup_size(256) +fn csr_div_compute_u32(@builtin(global_invocation_id) gid: vec3) { + let row = gid.x; + if (row >= csr_div_u32_params.nrows) { + return; + } + + let a_start = csr_div_u32_a_row_ptrs[row]; + let a_end = csr_div_u32_a_row_ptrs[row + 1u]; + let b_start = csr_div_u32_b_row_ptrs[row]; + let b_end = csr_div_u32_b_row_ptrs[row + 1u]; + + var out_idx = csr_div_u32_out_row_ptrs[row]; + var i: i32 = a_start; + var j: i32 = b_start; + + while (i < a_end && j < b_end) { + let a_col = csr_div_u32_a_col_indices[i]; + let b_col = csr_div_u32_b_col_indices[j]; + + if (a_col < b_col) { + i = i + 1; + } else if (a_col > b_col) { + j = j + 1; + } else { + let a_val = csr_div_u32_a_values[i]; + let b_val = csr_div_u32_b_values[j]; + csr_div_u32_out_col_indices[out_idx] = a_col; + csr_div_u32_out_values[out_idx] = a_val / b_val; + out_idx = out_idx + 1; + i = i + 1; + j = j + 1; + } + } +} + +// ============================================================================ +// csc_add_compute_u32 (union semantics) +// ============================================================================ + +struct CscAddU32Params { + ncols: u32, +} + +@group(0) @binding(0) var csc_add_u32_a_col_ptrs: array; +@group(0) @binding(1) var csc_add_u32_a_row_indices: array; +@group(0) @binding(2) var csc_add_u32_a_values: array; +@group(0) @binding(3) var csc_add_u32_b_col_ptrs: array; +@group(0) @binding(4) var csc_add_u32_b_row_indices: array; +@group(0) @binding(5) var csc_add_u32_b_values: array; +@group(0) @binding(6) var csc_add_u32_out_col_ptrs: array; +@group(0) @binding(7) var csc_add_u32_out_row_indices: array; +@group(0) @binding(8) var csc_add_u32_out_values: array; +@group(0) @binding(9) var csc_add_u32_params: CscAddU32Params; + +@compute @workgroup_size(256) +fn csc_add_compute_u32(@builtin(global_invocation_id) gid: vec3) { + let col = gid.x; + if (col >= csc_add_u32_params.ncols) { + return; + } + + let a_start = csc_add_u32_a_col_ptrs[col]; + let a_end = csc_add_u32_a_col_ptrs[col + 1u]; + let b_start = csc_add_u32_b_col_ptrs[col]; + let b_end = csc_add_u32_b_col_ptrs[col + 1u]; + + var out_idx = csc_add_u32_out_col_ptrs[col]; + var i: i32 = a_start; + var j: i32 = b_start; + + while (i < a_end && j < b_end) { + let a_row = csc_add_u32_a_row_indices[i]; + let b_row = csc_add_u32_b_row_indices[j]; + let a_val = csc_add_u32_a_values[i]; + let b_val = csc_add_u32_b_values[j]; + + if (a_row < b_row) { + csc_add_u32_out_row_indices[out_idx] = a_row; + csc_add_u32_out_values[out_idx] = a_val; + out_idx = out_idx + 1; + i = i + 1; + } else if (a_row > b_row) { + csc_add_u32_out_row_indices[out_idx] = b_row; + csc_add_u32_out_values[out_idx] = b_val; + out_idx = out_idx + 1; + j = j + 1; + } else { + csc_add_u32_out_row_indices[out_idx] = a_row; + csc_add_u32_out_values[out_idx] = a_val + b_val; + out_idx = out_idx + 1; + i = i + 1; + j = j + 1; + } + } + + while (i < a_end) { + csc_add_u32_out_row_indices[out_idx] = csc_add_u32_a_row_indices[i]; + csc_add_u32_out_values[out_idx] = csc_add_u32_a_values[i]; + out_idx = out_idx + 1; + i = i + 1; + } + + while (j < b_end) { + csc_add_u32_out_row_indices[out_idx] = csc_add_u32_b_row_indices[j]; + csc_add_u32_out_values[out_idx] = csc_add_u32_b_values[j]; + out_idx = out_idx + 1; + j = j + 1; + } +} + +// ============================================================================ +// csc_sub_compute_u32 (union semantics, wrapping subtraction) +// ============================================================================ + +struct CscSubU32Params { + ncols: u32, +} + +@group(0) @binding(0) var csc_sub_u32_a_col_ptrs: array; +@group(0) @binding(1) var csc_sub_u32_a_row_indices: array; +@group(0) @binding(2) var csc_sub_u32_a_values: array; +@group(0) @binding(3) var csc_sub_u32_b_col_ptrs: array; +@group(0) @binding(4) var csc_sub_u32_b_row_indices: array; +@group(0) @binding(5) var csc_sub_u32_b_values: array; +@group(0) @binding(6) var csc_sub_u32_out_col_ptrs: array; +@group(0) @binding(7) var csc_sub_u32_out_row_indices: array; +@group(0) @binding(8) var csc_sub_u32_out_values: array; +@group(0) @binding(9) var csc_sub_u32_params: CscSubU32Params; + +@compute @workgroup_size(256) +fn csc_sub_compute_u32(@builtin(global_invocation_id) gid: vec3) { + let col = gid.x; + if (col >= csc_sub_u32_params.ncols) { + return; + } + + let a_start = csc_sub_u32_a_col_ptrs[col]; + let a_end = csc_sub_u32_a_col_ptrs[col + 1u]; + let b_start = csc_sub_u32_b_col_ptrs[col]; + let b_end = csc_sub_u32_b_col_ptrs[col + 1u]; + + var out_idx = csc_sub_u32_out_col_ptrs[col]; + var i: i32 = a_start; + var j: i32 = b_start; + + while (i < a_end && j < b_end) { + let a_row = csc_sub_u32_a_row_indices[i]; + let b_row = csc_sub_u32_b_row_indices[j]; + let a_val = csc_sub_u32_a_values[i]; + let b_val = csc_sub_u32_b_values[j]; + + if (a_row < b_row) { + csc_sub_u32_out_row_indices[out_idx] = a_row; + csc_sub_u32_out_values[out_idx] = a_val; + out_idx = out_idx + 1; + i = i + 1; + } else if (a_row > b_row) { + csc_sub_u32_out_row_indices[out_idx] = b_row; + csc_sub_u32_out_values[out_idx] = 0u - b_val; + out_idx = out_idx + 1; + j = j + 1; + } else { + csc_sub_u32_out_row_indices[out_idx] = a_row; + csc_sub_u32_out_values[out_idx] = a_val - b_val; + out_idx = out_idx + 1; + i = i + 1; + j = j + 1; + } + } + + while (i < a_end) { + csc_sub_u32_out_row_indices[out_idx] = csc_sub_u32_a_row_indices[i]; + csc_sub_u32_out_values[out_idx] = csc_sub_u32_a_values[i]; + out_idx = out_idx + 1; + i = i + 1; + } + + while (j < b_end) { + csc_sub_u32_out_row_indices[out_idx] = csc_sub_u32_b_row_indices[j]; + csc_sub_u32_out_values[out_idx] = 0u - csc_sub_u32_b_values[j]; + out_idx = out_idx + 1; + j = j + 1; + } +} + +// ============================================================================ +// csc_mul_compute_u32 (intersection semantics) +// ============================================================================ + +struct CscMulU32Params { + ncols: u32, +} + +@group(0) @binding(0) var csc_mul_u32_a_col_ptrs: array; +@group(0) @binding(1) var csc_mul_u32_a_row_indices: array; +@group(0) @binding(2) var csc_mul_u32_a_values: array; +@group(0) @binding(3) var csc_mul_u32_b_col_ptrs: array; +@group(0) @binding(4) var csc_mul_u32_b_row_indices: array; +@group(0) @binding(5) var csc_mul_u32_b_values: array; +@group(0) @binding(6) var csc_mul_u32_out_col_ptrs: array; +@group(0) @binding(7) var csc_mul_u32_out_row_indices: array; +@group(0) @binding(8) var csc_mul_u32_out_values: array; +@group(0) @binding(9) var csc_mul_u32_params: CscMulU32Params; + +@compute @workgroup_size(256) +fn csc_mul_compute_u32(@builtin(global_invocation_id) gid: vec3) { + let col = gid.x; + if (col >= csc_mul_u32_params.ncols) { + return; + } + + let a_start = csc_mul_u32_a_col_ptrs[col]; + let a_end = csc_mul_u32_a_col_ptrs[col + 1u]; + let b_start = csc_mul_u32_b_col_ptrs[col]; + let b_end = csc_mul_u32_b_col_ptrs[col + 1u]; + + var out_idx = csc_mul_u32_out_col_ptrs[col]; + var i: i32 = a_start; + var j: i32 = b_start; + + while (i < a_end && j < b_end) { + let a_row = csc_mul_u32_a_row_indices[i]; + let b_row = csc_mul_u32_b_row_indices[j]; + + if (a_row < b_row) { + i = i + 1; + } else if (a_row > b_row) { + j = j + 1; + } else { + let a_val = csc_mul_u32_a_values[i]; + let b_val = csc_mul_u32_b_values[j]; + csc_mul_u32_out_row_indices[out_idx] = a_row; + csc_mul_u32_out_values[out_idx] = a_val * b_val; + out_idx = out_idx + 1; + i = i + 1; + j = j + 1; + } + } +} + +// ============================================================================ +// csc_div_compute_u32 (intersection semantics) +// ============================================================================ + +struct CscDivU32Params { + ncols: u32, +} + +@group(0) @binding(0) var csc_div_u32_a_col_ptrs: array; +@group(0) @binding(1) var csc_div_u32_a_row_indices: array; +@group(0) @binding(2) var csc_div_u32_a_values: array; +@group(0) @binding(3) var csc_div_u32_b_col_ptrs: array; +@group(0) @binding(4) var csc_div_u32_b_row_indices: array; +@group(0) @binding(5) var csc_div_u32_b_values: array; +@group(0) @binding(6) var csc_div_u32_out_col_ptrs: array; +@group(0) @binding(7) var csc_div_u32_out_row_indices: array; +@group(0) @binding(8) var csc_div_u32_out_values: array; +@group(0) @binding(9) var csc_div_u32_params: CscDivU32Params; + +@compute @workgroup_size(256) +fn csc_div_compute_u32(@builtin(global_invocation_id) gid: vec3) { + let col = gid.x; + if (col >= csc_div_u32_params.ncols) { + return; + } + + let a_start = csc_div_u32_a_col_ptrs[col]; + let a_end = csc_div_u32_a_col_ptrs[col + 1u]; + let b_start = csc_div_u32_b_col_ptrs[col]; + let b_end = csc_div_u32_b_col_ptrs[col + 1u]; + + var out_idx = csc_div_u32_out_col_ptrs[col]; + var i: i32 = a_start; + var j: i32 = b_start; + + while (i < a_end && j < b_end) { + let a_row = csc_div_u32_a_row_indices[i]; + let b_row = csc_div_u32_b_row_indices[j]; + + if (a_row < b_row) { + i = i + 1; + } else if (a_row > b_row) { + j = j + 1; + } else { + let a_val = csc_div_u32_a_values[i]; + let b_val = csc_div_u32_b_values[j]; + csc_div_u32_out_row_indices[out_idx] = a_row; + csc_div_u32_out_values[out_idx] = a_val / b_val; + out_idx = out_idx + 1; + i = i + 1; + j = j + 1; + } + } +} diff --git a/src/runtime/wgpu/shaders/sparse_spmv_f32.wgsl b/src/runtime/wgpu/shaders/sparse_spmv_f32.wgsl new file mode 100644 index 00000000..a01b0d8f --- /dev/null +++ b/src/runtime/wgpu/shaders/sparse_spmv_f32.wgsl @@ -0,0 +1,124 @@ +// CSR Sparse Matrix-Vector Multiplication: y = A * x +// Row-parallel implementation: one thread per row + +const WORKGROUP_SIZE: u32 = 256u; + +struct SpmvParams { + nrows: u32, + ncols: u32, + _pad0: u32, + _pad1: u32, +} + +// CSR format +@group(0) @binding(0) var spmv_row_ptrs: array; +@group(0) @binding(1) var spmv_col_indices: array; +@group(0) @binding(2) var spmv_values: array; +// Dense vector x +@group(0) @binding(3) var spmv_x: array; +// Output vector y +@group(0) @binding(4) var spmv_y: array; +// Parameters +@group(0) @binding(5) var spmv_params: SpmvParams; + +@compute @workgroup_size(256) +fn csr_spmv_f32(@builtin(global_invocation_id) gid: vec3) { + let row = gid.x; + if (row >= spmv_params.nrows) { + return; + } + + let row_start = spmv_row_ptrs[row]; + let row_end = spmv_row_ptrs[row + 1u]; + + var sum: f32 = 0.0; + for (var j: i32 = row_start; j < row_end; j = j + 1) { + let col = spmv_col_indices[j]; + sum = sum + spmv_values[j] * spmv_x[col]; + } + + spmv_y[row] = sum; +} + +// CSR Sparse Matrix-Dense Matrix Multiplication: C = A * B +// Each thread computes one output element C[row, col] + +struct SpmmParams { + m: u32, + k: u32, + n: u32, + _pad: u32, +} + +// CSR format for A +@group(0) @binding(0) var spmm_row_ptrs: array; +@group(0) @binding(1) var spmm_col_indices: array; +@group(0) @binding(2) var spmm_a_values: array; +// Dense matrix B (k x n, row-major) +@group(0) @binding(3) var spmm_b: array; +// Output matrix C (m x n, row-major) +@group(0) @binding(4) var spmm_c: array; +// Parameters +@group(0) @binding(5) var spmm_params: SpmmParams; + +@compute @workgroup_size(256) +fn csr_spmm_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + let total = spmm_params.m * spmm_params.n; + if (idx >= total) { + return; + } + + let row = idx / spmm_params.n; + let col = idx % spmm_params.n; + + let row_start = spmm_row_ptrs[row]; + let row_end = spmm_row_ptrs[row + 1u]; + + var sum: f32 = 0.0; + for (var j: i32 = row_start; j < row_end; j = j + 1) { + let a_col = spmm_col_indices[j]; + let a_val = spmm_a_values[j]; + let b_idx = u32(a_col) * spmm_params.n + col; + sum = sum + a_val * spmm_b[b_idx]; + } + + spmm_c[idx] = sum; +} + +// CSR Extract Diagonal: diag[i] = A[i,i] +// Thread-per-row: each thread scans one row for col_index == row_index + +struct DiagParams { + n: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var diag_row_ptrs: array; +@group(0) @binding(1) var diag_col_indices: array; +@group(0) @binding(2) var diag_values: array; +@group(0) @binding(3) var diag_out: array; +@group(0) @binding(4) var diag_params: DiagParams; + +@compute @workgroup_size(256) +fn csr_extract_diagonal_f32(@builtin(global_invocation_id) gid: vec3) { + let row = gid.x; + if (row >= diag_params.n) { + return; + } + + let row_start = diag_row_ptrs[row]; + let row_end = diag_row_ptrs[row + 1u]; + + var val: f32 = 0.0; + for (var j: i32 = row_start; j < row_end; j = j + 1) { + if (diag_col_indices[j] == i32(row)) { + val = diag_values[j]; + break; + } + } + + diag_out[row] = val; +} diff --git a/src/runtime/wgpu/shaders/sparse_spmv_launcher.rs b/src/runtime/wgpu/shaders/sparse_spmv_launcher.rs index d69d340d..3fe8f697 100644 --- a/src/runtime/wgpu/shaders/sparse_spmv_launcher.rs +++ b/src/runtime/wgpu/shaders/sparse_spmv_launcher.rs @@ -3,16 +3,25 @@ //! Provides launchers for CSR format SpMV and SpMM operations: //! - `launch_csr_spmv` - Sparse matrix-vector multiplication: y = A * x //! - `launch_csr_spmm` - Sparse matrix-dense matrix multiplication: C = A * B +//! - `launch_csr_extract_diagonal` - Extract diagonal: diag[i] = A[i,i] use wgpu::{Buffer, Queue}; -use super::generator::dtype_suffix; -use super::generator::spmv::{ - generate_csr_extract_diagonal_shader, generate_csr_spmm_shader, generate_csr_spmv_shader, -}; use super::pipeline::{LayoutKey, PipelineCache, workgroup_count}; use crate::dtype::DType; -use crate::error::Result; +use crate::error::{Error, Result}; + +const SPARSE_SPMV_F32: &str = include_str!("sparse_spmv_f32.wgsl"); + +fn spmv_shader_info(dtype: DType) -> Result<(&'static str, &'static str)> { + match dtype { + DType::F32 => Ok((SPARSE_SPMV_F32, "sparse_spmv_f32")), + _ => Err(Error::UnsupportedDType { + dtype, + op: "csr_spmv (WebGPU)", + }), + } +} /// Launch CSR SpMV kernel: y = A * x /// @@ -38,12 +47,9 @@ pub fn launch_csr_spmv( nrows: usize, dtype: DType, ) -> Result<()> { - let suffix = dtype_suffix(dtype)?; - let entry_point = format!("csr_spmv_{}", suffix); + let (shader, module_name) = spmv_shader_info(dtype)?; - let shader_source = generate_csr_spmv_shader(dtype)?; - let module_name = format!("csr_spmv_{}", suffix); - let module = cache.get_or_create_module_from_source(&module_name, &shader_source); + let module = cache.get_or_create_module(module_name, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 5, // row_ptrs, col_indices, values, x, y @@ -51,7 +57,7 @@ pub fn launch_csr_spmv( num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_dynamic_pipeline("csr_spmv", &entry_point, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_name, "csr_spmv_f32", &module, &layout); let bind_group = cache.create_bind_group( &layout, @@ -103,12 +109,9 @@ pub fn launch_csr_spmm( n: usize, dtype: DType, ) -> Result<()> { - let suffix = dtype_suffix(dtype)?; - let entry_point = format!("csr_spmm_{}", suffix); + let (shader, module_name) = spmv_shader_info(dtype)?; - let shader_source = generate_csr_spmm_shader(dtype)?; - let module_name = format!("csr_spmm_{}", suffix); - let module = cache.get_or_create_module_from_source(&module_name, &shader_source); + let module = cache.get_or_create_module(module_name, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 5, // row_ptrs, col_indices, a_values, b, c @@ -116,7 +119,7 @@ pub fn launch_csr_spmm( num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_dynamic_pipeline("csr_spmm", &entry_point, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_name, "csr_spmm_f32", &module, &layout); let bind_group = cache.create_bind_group( &layout, @@ -165,12 +168,9 @@ pub fn launch_csr_extract_diagonal( n: usize, dtype: DType, ) -> Result<()> { - let suffix = dtype_suffix(dtype)?; - let entry_point = format!("csr_extract_diagonal_{}", suffix); + let (shader, module_name) = spmv_shader_info(dtype)?; - let shader_source = generate_csr_extract_diagonal_shader(dtype)?; - let module_name = format!("csr_extract_diagonal_{}", suffix); - let module = cache.get_or_create_module_from_source(&module_name, &shader_source); + let module = cache.get_or_create_module(module_name, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 4, // row_ptrs, col_indices, values, diag @@ -178,12 +178,8 @@ pub fn launch_csr_extract_diagonal( num_readonly_storage: 0, }); - let pipeline = cache.get_or_create_dynamic_pipeline( - "csr_extract_diagonal", - &entry_point, - &module, - &layout, - ); + let pipeline = + cache.get_or_create_pipeline(module_name, "csr_extract_diagonal_f32", &module, &layout); let bind_group = cache.create_bind_group( &layout, @@ -209,29 +205,3 @@ pub fn launch_csr_extract_diagonal( queue.submit(std::iter::once(encoder.finish())); Ok(()) } - -#[cfg(test)] -mod tests { - use super::*; - - fn validate_wgsl_syntax(source: &str) -> std::result::Result<(), String> { - use wgpu::naga::front::wgsl; - let mut frontend = wgsl::Frontend::new(); - frontend - .parse(source) - .map(|_| ()) - .map_err(|e| format!("WGSL parse error: {e}")) - } - - #[test] - fn test_csr_spmv_shader_syntax_f32() { - let shader = generate_csr_spmv_shader(DType::F32).unwrap(); - validate_wgsl_syntax(&shader).expect("SpMV shader should be valid WGSL"); - } - - #[test] - fn test_csr_spmm_shader_syntax_f32() { - let shader = generate_csr_spmm_shader(DType::F32).unwrap(); - validate_wgsl_syntax(&shader).expect("SpMM shader should be valid WGSL"); - } -} diff --git a/src/runtime/wgpu/shaders/sparse_trsv_lower_f32.wgsl b/src/runtime/wgpu/shaders/sparse_trsv_lower_f32.wgsl new file mode 100644 index 00000000..d13f6cd4 --- /dev/null +++ b/src/runtime/wgpu/shaders/sparse_trsv_lower_f32.wgsl @@ -0,0 +1,47 @@ +// Level-scheduled sparse lower triangular solve (forward substitution) +// Processes all rows in a single level in parallel + +struct TrsvParams { + level_size: u32, + n: u32, + unit_diagonal: u32, + level_start: u32, +} + +@group(0) @binding(0) var level_rows: array; +@group(0) @binding(1) var row_ptrs: array; +@group(0) @binding(2) var col_indices: array; +@group(0) @binding(3) var values: array; +@group(0) @binding(4) var b: array; +@group(0) @binding(5) var x: array; +@group(0) @binding(6) var params: TrsvParams; + +@compute @workgroup_size(256) +fn sparse_trsv_lower_level_f32(@builtin(global_invocation_id) gid: vec3) { + let tid = gid.x; + if (tid >= params.level_size) { + return; + } + + let row = level_rows[params.level_start + tid]; + let start = row_ptrs[row]; + let end = row_ptrs[row + 1]; + + var sum = b[row]; + var diag = f32(1.0); + + for (var idx = start; idx < end; idx = idx + 1) { + let col = col_indices[idx]; + if (col < row) { + sum = sum - values[idx] * x[col]; + } else if (col == row && params.unit_diagonal == 0u) { + diag = values[idx]; + } + } + + if (params.unit_diagonal == 0u) { + sum = sum / diag; + } + + x[row] = sum; +} diff --git a/src/runtime/wgpu/shaders/sparse_trsv_lower_multi_rhs_f32.wgsl b/src/runtime/wgpu/shaders/sparse_trsv_lower_multi_rhs_f32.wgsl new file mode 100644 index 00000000..c2cf4c48 --- /dev/null +++ b/src/runtime/wgpu/shaders/sparse_trsv_lower_multi_rhs_f32.wgsl @@ -0,0 +1,55 @@ +// Multi-RHS level-scheduled sparse lower triangular solve (forward substitution) +// Processes all (row, rhs_column) pairs in a single level in parallel + +struct TrsvMultiRhsParams { + level_size: u32, + nrhs: u32, + n: u32, + unit_diagonal: u32, + level_start: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var level_rows: array; +@group(0) @binding(1) var row_ptrs: array; +@group(0) @binding(2) var col_indices: array; +@group(0) @binding(3) var values: array; +@group(0) @binding(4) var b: array; +@group(0) @binding(5) var x: array; +@group(0) @binding(6) var params: TrsvMultiRhsParams; + +@compute @workgroup_size(256) +fn sparse_trsv_lower_level_multi_rhs_f32(@builtin(global_invocation_id) gid: vec3) { + let tid = gid.x; + let total_work = params.level_size * params.nrhs; + if (tid >= total_work) { + return; + } + + let row_idx = tid / params.nrhs; + let rhs_col = tid % params.nrhs; + let row = level_rows[params.level_start + row_idx]; + + let start = row_ptrs[row]; + let end = row_ptrs[row + 1]; + + var sum = b[u32(row) * params.nrhs + rhs_col]; + var diag = f32(1.0); + + for (var idx = start; idx < end; idx = idx + 1) { + let col = col_indices[idx]; + if (col < row) { + sum = sum - values[idx] * x[u32(col) * params.nrhs + rhs_col]; + } else if (col == row && params.unit_diagonal == 0u) { + diag = values[idx]; + } + } + + if (params.unit_diagonal == 0u) { + sum = sum / diag; + } + + x[u32(row) * params.nrhs + rhs_col] = sum; +} diff --git a/src/runtime/wgpu/shaders/sparse_trsv_upper_f32.wgsl b/src/runtime/wgpu/shaders/sparse_trsv_upper_f32.wgsl new file mode 100644 index 00000000..bef5d65f --- /dev/null +++ b/src/runtime/wgpu/shaders/sparse_trsv_upper_f32.wgsl @@ -0,0 +1,42 @@ +// Level-scheduled sparse upper triangular solve (backward substitution) + +struct TrsvParams { + level_size: u32, + n: u32, + _pad0: u32, + level_start: u32, +} + +@group(0) @binding(0) var level_rows: array; +@group(0) @binding(1) var row_ptrs: array; +@group(0) @binding(2) var col_indices: array; +@group(0) @binding(3) var values: array; +@group(0) @binding(4) var b: array; +@group(0) @binding(5) var x: array; +@group(0) @binding(6) var params: TrsvParams; + +@compute @workgroup_size(256) +fn sparse_trsv_upper_level_f32(@builtin(global_invocation_id) gid: vec3) { + let tid = gid.x; + if (tid >= params.level_size) { + return; + } + + let row = level_rows[params.level_start + tid]; + let start = row_ptrs[row]; + let end = row_ptrs[row + 1]; + + var sum = b[row]; + var diag = f32(1.0); + + for (var idx = start; idx < end; idx = idx + 1) { + let col = col_indices[idx]; + if (col > row) { + sum = sum - values[idx] * x[col]; + } else if (col == row) { + diag = values[idx]; + } + } + + x[row] = sum / diag; +} diff --git a/src/runtime/wgpu/shaders/sparse_trsv_upper_multi_rhs_f32.wgsl b/src/runtime/wgpu/shaders/sparse_trsv_upper_multi_rhs_f32.wgsl new file mode 100644 index 00000000..18c9a7fc --- /dev/null +++ b/src/runtime/wgpu/shaders/sparse_trsv_upper_multi_rhs_f32.wgsl @@ -0,0 +1,50 @@ +// Multi-RHS level-scheduled sparse upper triangular solve (backward substitution) + +struct TrsvMultiRhsParams { + level_size: u32, + nrhs: u32, + n: u32, + _pad0: u32, + level_start: u32, + _pad1: u32, + _pad2: u32, + _pad3: u32, +} + +@group(0) @binding(0) var level_rows: array; +@group(0) @binding(1) var row_ptrs: array; +@group(0) @binding(2) var col_indices: array; +@group(0) @binding(3) var values: array; +@group(0) @binding(4) var b: array; +@group(0) @binding(5) var x: array; +@group(0) @binding(6) var params: TrsvMultiRhsParams; + +@compute @workgroup_size(256) +fn sparse_trsv_upper_level_multi_rhs_f32(@builtin(global_invocation_id) gid: vec3) { + let tid = gid.x; + let total_work = params.level_size * params.nrhs; + if (tid >= total_work) { + return; + } + + let row_idx = tid / params.nrhs; + let rhs_col = tid % params.nrhs; + let row = level_rows[params.level_start + row_idx]; + + let start = row_ptrs[row]; + let end = row_ptrs[row + 1]; + + var sum = b[u32(row) * params.nrhs + rhs_col]; + var diag = f32(1.0); + + for (var idx = start; idx < end; idx = idx + 1) { + let col = col_indices[idx]; + if (col > row) { + sum = sum - values[idx] * x[u32(col) * params.nrhs + rhs_col]; + } else if (col == row) { + diag = values[idx]; + } + } + + x[u32(row) * params.nrhs + rhs_col] = sum / diag; +} diff --git a/src/runtime/wgpu/shaders/special.rs b/src/runtime/wgpu/shaders/special.rs index 937b38b1..31826570 100644 --- a/src/runtime/wgpu/shaders/special.rs +++ b/src/runtime/wgpu/shaders/special.rs @@ -3,99 +3,20 @@ //! Provides native GPU implementations for erf, erfc, erfinv, gamma, //! lgamma, digamma, beta, betainc, gammainc, gammaincc. -use std::collections::HashMap; -use std::sync::{OnceLock, RwLock, RwLockReadGuard, RwLockWriteGuard}; - -// ============================================================================ -// Lock Helpers (Handle Poisoned Locks Gracefully) -// ============================================================================ - -/// Acquire read lock, recovering from poison if necessary. -fn read_lock(lock: &RwLock) -> RwLockReadGuard<'_, T> { - lock.read().unwrap_or_else(|poisoned| poisoned.into_inner()) -} - -/// Acquire write lock, recovering from poison if necessary. -fn write_lock(lock: &RwLock) -> RwLockWriteGuard<'_, T> { - lock.write() - .unwrap_or_else(|poisoned| poisoned.into_inner()) -} - use wgpu::util::DeviceExt; use wgpu::{Buffer, Queue}; -use super::generator::{ - dtype_suffix, generate_special_binary_shader, generate_special_ternary_shader, - generate_special_unary_shader, -}; use super::pipeline::{LayoutKey, PipelineCache, workgroup_count}; use crate::dtype::DType; -use crate::error::Result; +use crate::error::{Error, Result}; // ============================================================================ -// Shader Module Cache +// Static WGSL Shader Sources // ============================================================================ -static SPECIAL_UNARY_CACHE: OnceLock>> = OnceLock::new(); -static SPECIAL_BINARY_CACHE: OnceLock>> = OnceLock::new(); -static SPECIAL_TERNARY_CACHE: OnceLock>> = OnceLock::new(); - -fn get_or_leak_special_unary_shader(dtype: DType) -> Result<&'static str> { - let cache = SPECIAL_UNARY_CACHE.get_or_init(|| RwLock::new(HashMap::new())); - - { - let read_guard = read_lock(cache); - if let Some(&shader_ref) = read_guard.get(&dtype) { - return Ok(shader_ref); - } - } - - let shader = generate_special_unary_shader(dtype)?; - let leaked: &'static str = Box::leak(shader.into_boxed_str()); - - let mut write_guard = write_lock(cache); - write_guard.insert(dtype, leaked); - - Ok(leaked) -} - -fn get_or_leak_special_binary_shader(dtype: DType) -> Result<&'static str> { - let cache = SPECIAL_BINARY_CACHE.get_or_init(|| RwLock::new(HashMap::new())); - - { - let read_guard = read_lock(cache); - if let Some(&shader_ref) = read_guard.get(&dtype) { - return Ok(shader_ref); - } - } - - let shader = generate_special_binary_shader(dtype)?; - let leaked: &'static str = Box::leak(shader.into_boxed_str()); - - let mut write_guard = write_lock(cache); - write_guard.insert(dtype, leaked); - - Ok(leaked) -} - -fn get_or_leak_special_ternary_shader(dtype: DType) -> Result<&'static str> { - let cache = SPECIAL_TERNARY_CACHE.get_or_init(|| RwLock::new(HashMap::new())); - - { - let read_guard = read_lock(cache); - if let Some(&shader_ref) = read_guard.get(&dtype) { - return Ok(shader_ref); - } - } - - let shader = generate_special_ternary_shader(dtype)?; - let leaked: &'static str = Box::leak(shader.into_boxed_str()); - - let mut write_guard = write_lock(cache); - write_guard.insert(dtype, leaked); - - Ok(leaked) -} +const SPECIAL_UNARY_F32: &str = include_str!("special_unary_f32.wgsl"); +const SPECIAL_BINARY_F32: &str = include_str!("special_binary_f32.wgsl"); +const SPECIAL_TERNARY_F32: &str = include_str!("special_ternary_f32.wgsl"); // ============================================================================ // Unary Special Functions (erf, erfc, erfinv, gamma, lgamma, digamma) @@ -111,12 +32,16 @@ pub fn launch_special_unary( numel: u32, dtype: DType, ) -> Result<()> { - let shader = get_or_leak_special_unary_shader(dtype)?; - let suffix = dtype_suffix(dtype)?; - let entry_point = format!("{}_{}", op, suffix); - let module_key = format!("special_unary_{}", suffix); + if dtype != DType::F32 { + return Err(Error::UnsupportedDType { + dtype, + op: "special_unary", + }); + } + let entry_point = format!("{}_f32", op); + let module_key = "special_unary_f32"; - let module = pipeline_cache.get_or_create_module_from_source(&module_key, shader); + let module = pipeline_cache.get_or_create_module(module_key, SPECIAL_UNARY_F32); // Layout: 2 storage buffers (input, output) + 1 uniform (params) let layout = pipeline_cache.get_or_create_layout(LayoutKey { @@ -126,7 +51,7 @@ pub fn launch_special_unary( }); let pipeline = - pipeline_cache.get_or_create_dynamic_pipeline(&module_key, &entry_point, &module, &layout); + pipeline_cache.get_or_create_dynamic_pipeline(module_key, &entry_point, &module, &layout); // Create params buffer let params_data = [numel]; @@ -180,12 +105,16 @@ pub fn launch_special_binary( numel: u32, dtype: DType, ) -> Result<()> { - let shader = get_or_leak_special_binary_shader(dtype)?; - let suffix = dtype_suffix(dtype)?; - let entry_point = format!("{}_{}", op, suffix); - let module_key = format!("special_binary_{}", suffix); + if dtype != DType::F32 { + return Err(Error::UnsupportedDType { + dtype, + op: "special_binary", + }); + } + let entry_point = format!("{}_f32", op); + let module_key = "special_binary_f32"; - let module = pipeline_cache.get_or_create_module_from_source(&module_key, shader); + let module = pipeline_cache.get_or_create_module(module_key, SPECIAL_BINARY_F32); // Layout: 3 storage buffers (input_a, input_b, output) + 1 uniform (params) let layout = pipeline_cache.get_or_create_layout(LayoutKey { @@ -195,7 +124,7 @@ pub fn launch_special_binary( }); let pipeline = - pipeline_cache.get_or_create_dynamic_pipeline(&module_key, &entry_point, &module, &layout); + pipeline_cache.get_or_create_dynamic_pipeline(module_key, &entry_point, &module, &layout); // Create params buffer let params_data = [numel]; @@ -251,12 +180,16 @@ pub fn launch_special_ternary( numel: u32, dtype: DType, ) -> Result<()> { - let shader = get_or_leak_special_ternary_shader(dtype)?; - let suffix = dtype_suffix(dtype)?; - let entry_point = format!("{}_{}", op, suffix); - let module_key = format!("special_ternary_{}", suffix); + if dtype != DType::F32 { + return Err(Error::UnsupportedDType { + dtype, + op: "special_ternary", + }); + } + let entry_point = format!("{}_f32", op); + let module_key = "special_ternary_f32"; - let module = pipeline_cache.get_or_create_module_from_source(&module_key, shader); + let module = pipeline_cache.get_or_create_module(module_key, SPECIAL_TERNARY_F32); // Layout: 4 storage buffers (input_a, input_b, input_x, output) + 1 uniform (params) let layout = pipeline_cache.get_or_create_layout(LayoutKey { @@ -266,7 +199,7 @@ pub fn launch_special_ternary( }); let pipeline = - pipeline_cache.get_or_create_dynamic_pipeline(&module_key, &entry_point, &module, &layout); + pipeline_cache.get_or_create_dynamic_pipeline(module_key, &entry_point, &module, &layout); // Create params buffer let params_data = [numel]; @@ -323,12 +256,16 @@ pub fn launch_special_unary_with_int( n: i32, dtype: DType, ) -> Result<()> { - let shader = get_or_leak_special_unary_shader(dtype)?; - let suffix = dtype_suffix(dtype)?; - let entry_point = format!("{}_{}", op, suffix); - let module_key = format!("special_unary_{}", suffix); + if dtype != DType::F32 { + return Err(Error::UnsupportedDType { + dtype, + op: "special_unary_with_int", + }); + } + let entry_point = format!("{}_f32", op); + let module_key = "special_unary_f32"; - let module = pipeline_cache.get_or_create_module_from_source(&module_key, shader); + let module = pipeline_cache.get_or_create_module(module_key, SPECIAL_UNARY_F32); // Layout: 2 storage buffers + 1 uniform (params with numel and n) let layout = pipeline_cache.get_or_create_layout(LayoutKey { @@ -338,7 +275,7 @@ pub fn launch_special_unary_with_int( }); let pipeline = - pipeline_cache.get_or_create_dynamic_pipeline(&module_key, &entry_point, &module, &layout); + pipeline_cache.get_or_create_dynamic_pipeline(module_key, &entry_point, &module, &layout); // Create params buffer with numel and n let params_data = [numel, n as u32]; @@ -387,12 +324,16 @@ pub fn launch_special_unary_with_two_ints( m: i32, dtype: DType, ) -> Result<()> { - let shader = get_or_leak_special_unary_shader(dtype)?; - let suffix = dtype_suffix(dtype)?; - let entry_point = format!("{}_{}", op, suffix); - let module_key = format!("special_unary_{}", suffix); + if dtype != DType::F32 { + return Err(Error::UnsupportedDType { + dtype, + op: "special_unary_with_two_ints", + }); + } + let entry_point = format!("{}_f32", op); + let module_key = "special_unary_f32"; - let module = pipeline_cache.get_or_create_module_from_source(&module_key, shader); + let module = pipeline_cache.get_or_create_module(module_key, SPECIAL_UNARY_F32); let layout = pipeline_cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, @@ -401,7 +342,7 @@ pub fn launch_special_unary_with_two_ints( }); let pipeline = - pipeline_cache.get_or_create_dynamic_pipeline(&module_key, &entry_point, &module, &layout); + pipeline_cache.get_or_create_dynamic_pipeline(module_key, &entry_point, &module, &layout); // Create params buffer with numel, n, m let params_data = [numel, n as u32, m as u32, 0u32]; // Pad to 16 bytes @@ -451,12 +392,16 @@ pub fn launch_special_binary_with_two_ints( m: i32, dtype: DType, ) -> Result<()> { - let shader = get_or_leak_special_binary_shader(dtype)?; - let suffix = dtype_suffix(dtype)?; - let entry_point = format!("{}_{}", op, suffix); - let module_key = format!("special_binary_{}", suffix); + if dtype != DType::F32 { + return Err(Error::UnsupportedDType { + dtype, + op: "special_binary_with_two_ints", + }); + } + let entry_point = format!("{}_f32", op); + let module_key = "special_binary_f32"; - let module = pipeline_cache.get_or_create_module_from_source(&module_key, shader); + let module = pipeline_cache.get_or_create_module(module_key, SPECIAL_BINARY_F32); let layout = pipeline_cache.get_or_create_layout(LayoutKey { num_storage_buffers: 3, @@ -465,7 +410,7 @@ pub fn launch_special_binary_with_two_ints( }); let pipeline = - pipeline_cache.get_or_create_dynamic_pipeline(&module_key, &entry_point, &module, &layout); + pipeline_cache.get_or_create_dynamic_pipeline(module_key, &entry_point, &module, &layout); // Create params buffer with numel, n, m let params_data = [numel, n as u32, m as u32, 0u32]; // Pad to 16 bytes @@ -515,12 +460,16 @@ pub fn launch_special_unary_with_2f32( b: f32, dtype: DType, ) -> Result<()> { - let shader = get_or_leak_special_unary_shader(dtype)?; - let suffix = dtype_suffix(dtype)?; - let entry_point = format!("{}_{}", op, suffix); - let module_key = format!("special_unary_{}", suffix); + if dtype != DType::F32 { + return Err(Error::UnsupportedDType { + dtype, + op: "special_unary_with_2f32", + }); + } + let entry_point = format!("{}_f32", op); + let module_key = "special_unary_f32"; - let module = pipeline_cache.get_or_create_module_from_source(&module_key, shader); + let module = pipeline_cache.get_or_create_module(module_key, SPECIAL_UNARY_F32); let layout = pipeline_cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, @@ -529,7 +478,7 @@ pub fn launch_special_unary_with_2f32( }); let pipeline = - pipeline_cache.get_or_create_dynamic_pipeline(&module_key, &entry_point, &module, &layout); + pipeline_cache.get_or_create_dynamic_pipeline(module_key, &entry_point, &module, &layout); // Create params buffer with numel, a, b (use u32 + 2 f32s) let numel_bits = numel; @@ -580,12 +529,16 @@ pub fn launch_special_unary_with_3f32( c: f32, dtype: DType, ) -> Result<()> { - let shader = get_or_leak_special_unary_shader(dtype)?; - let suffix = dtype_suffix(dtype)?; - let entry_point = format!("{}_{}", op, suffix); - let module_key = format!("special_unary_{}", suffix); + if dtype != DType::F32 { + return Err(Error::UnsupportedDType { + dtype, + op: "special_unary_with_3f32", + }); + } + let entry_point = format!("{}_f32", op); + let module_key = "special_unary_f32"; - let module = pipeline_cache.get_or_create_module_from_source(&module_key, shader); + let module = pipeline_cache.get_or_create_module(module_key, SPECIAL_UNARY_F32); let layout = pipeline_cache.get_or_create_layout(LayoutKey { num_storage_buffers: 2, @@ -594,7 +547,7 @@ pub fn launch_special_unary_with_3f32( }); let pipeline = - pipeline_cache.get_or_create_dynamic_pipeline(&module_key, &entry_point, &module, &layout); + pipeline_cache.get_or_create_dynamic_pipeline(module_key, &entry_point, &module, &layout); // Create params buffer with numel, a, b, c let params_data: [u32; 6] = [numel, 0, a.to_bits(), b.to_bits(), c.to_bits(), 0]; diff --git a/src/runtime/wgpu/shaders/special_binary_f32.wgsl b/src/runtime/wgpu/shaders/special_binary_f32.wgsl new file mode 100644 index 00000000..b0770b54 --- /dev/null +++ b/src/runtime/wgpu/shaders/special_binary_f32.wgsl @@ -0,0 +1,183 @@ +// Auto-generated special binary functions for f32 + +const WORKGROUP_SIZE: u32 = 256u; +const PI: f32 = 3.14159265358979323846; +const SQRT_PI: f32 = 1.7724538509055159; +const EULER_GAMMA: f32 = 0.5772156649015329; +const LN_SQRT_2PI: f32 = 0.9189385332046727; +const LANCZOS_G: f32 = 7.0; +const MAX_ITER: i32 = 100; +const EPSILON: f32 = 1e-6; +const TINY: f32 = 1e-30; + +struct SpecialBinaryParams { + numel: u32, +} + +@group(0) @binding(0) var special_a: array; +@group(0) @binding(1) var special_b: array; +@group(0) @binding(2) var special_out: array; +@group(0) @binding(3) var special_params: SpecialBinaryParams; + +// ============================================================================ +// Helper Functions (shared lgamma) +// ============================================================================ + +// Lanczos computation for positive x only (no recursion) +fn lgamma_positive(x: f32) -> f32 { + // Lanczos coefficients (g=7, n=9) + let c0 = 0.99999999999980993; + let c1 = 676.5203681218851; + let c2 = -1259.1392167224028; + let c3 = 771.32342877765313; + let c4 = -176.61502916214059; + let c5 = 12.507343278686905; + let c6 = -0.13857109526572012; + let c7 = 9.9843695780195716e-6; + let c8 = 1.5056327351493116e-7; + + let z = x - 1.0; + var ag = c0; + ag = ag + c1 / (z + 1.0); + ag = ag + c2 / (z + 2.0); + ag = ag + c3 / (z + 3.0); + ag = ag + c4 / (z + 4.0); + ag = ag + c5 / (z + 5.0); + ag = ag + c6 / (z + 6.0); + ag = ag + c7 / (z + 7.0); + ag = ag + c8 / (z + 8.0); + + let t = z + LANCZOS_G + 0.5; + return LN_SQRT_2PI + (z + 0.5) * log(t) - t + log(ag); +} + +// Log-gamma using Lanczos approximation (non-recursive) +fn lgamma_impl(x: f32) -> f32 { + if (x <= 0.0) { + // Use reflection formula for negative values + if (x == floor(x)) { + return 1e30; // Pole at non-positive integers + } + // lgamma(x) = log(pi / sin(pi*x)) - lgamma(1-x) + // Since 1-x > 0 for x <= 0, we call lgamma_positive directly + let sinpix = sin(PI * x); + if (sinpix == 0.0) { + return 1e30; + } + return log(PI / abs(sinpix)) - lgamma_positive(1.0 - x); + } + + return lgamma_positive(x); +} + +// Lower incomplete gamma series +fn gammainc_series(a: f32, x: f32) -> f32 { + if (x == 0.0) { + return 0.0; + } + + var term = 1.0 / a; + var sum = term; + + for (var n = 1; n < MAX_ITER; n = n + 1) { + term = term * x / (a + f32(n)); + sum = sum + term; + if (abs(term) < abs(sum) * EPSILON) { + break; + } + } + + return exp(-x + a * log(x) - lgamma_impl(a)) * sum; +} + +// Upper incomplete gamma continued fraction +fn gammaincc_cf(a: f32, x: f32) -> f32 { + var f = 1e30; + var c = 1e30; + var d = 0.0; + + for (var n = 1; n < MAX_ITER; n = n + 1) { + var an: f32; + if (n % 2 == 1) { + an = f32((n + 1) / 2); + } else { + an = a - f32(n / 2); + } + let bn = x + f32(n) - a; + + d = bn + an * d; + if (abs(d) < TINY) { + d = TINY; + } + c = bn + an / c; + if (abs(c) < TINY) { + c = TINY; + } + + d = 1.0 / d; + let delta = c * d; + f = f * delta; + + if (abs(delta - 1.0) < EPSILON) { + break; + } + } + + return exp(-x + a * log(x) - lgamma_impl(a)) / f; +} + +fn gammainc_impl(a: f32, x: f32) -> f32 { + if (x < 0.0 || a <= 0.0) { + return bitcast(0x7FC00000u); // NaN + } + if (x == 0.0) { + return 0.0; + } + if (x < a + 1.0) { + return gammainc_series(a, x); + } + return 1.0 - gammaincc_cf(a, x); +} + +fn gammaincc_impl(a: f32, x: f32) -> f32 { + if (x < 0.0 || a <= 0.0) { + return bitcast(0x7FC00000u); // NaN + } + if (x == 0.0) { + return 1.0; + } + if (x < a + 1.0) { + return 1.0 - gammainc_series(a, x); + } + return gammaincc_cf(a, x); +} + +// ============================================================================ +// Compute Kernels +// ============================================================================ + +@compute @workgroup_size(256) +fn beta_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < special_params.numel) { + let a = special_a[idx]; + let b = special_b[idx]; + special_out[idx] = exp(lgamma_impl(a) + lgamma_impl(b) - lgamma_impl(a + b)); + } +} + +@compute @workgroup_size(256) +fn gammainc_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < special_params.numel) { + special_out[idx] = gammainc_impl(special_a[idx], special_b[idx]); + } +} + +@compute @workgroup_size(256) +fn gammaincc_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < special_params.numel) { + special_out[idx] = gammaincc_impl(special_a[idx], special_b[idx]); + } +} diff --git a/src/runtime/wgpu/shaders/special_ternary_f32.wgsl b/src/runtime/wgpu/shaders/special_ternary_f32.wgsl new file mode 100644 index 00000000..0d536d02 --- /dev/null +++ b/src/runtime/wgpu/shaders/special_ternary_f32.wgsl @@ -0,0 +1,152 @@ +// Auto-generated special ternary functions for f32 + +const WORKGROUP_SIZE: u32 = 256u; +const PI: f32 = 3.14159265358979323846; +const SQRT_PI: f32 = 1.7724538509055159; +const EULER_GAMMA: f32 = 0.5772156649015329; +const LN_SQRT_2PI: f32 = 0.9189385332046727; +const LANCZOS_G: f32 = 7.0; +const MAX_ITER: i32 = 100; +const EPSILON: f32 = 1e-6; +const TINY: f32 = 1e-30; + +struct SpecialTernaryParams { + numel: u32, +} + +@group(0) @binding(0) var special_a: array; +@group(0) @binding(1) var special_b: array; +@group(0) @binding(2) var special_x: array; +@group(0) @binding(3) var special_out: array; +@group(0) @binding(4) var special_params: SpecialTernaryParams; + +// ============================================================================ +// Helper Functions (shared lgamma) +// ============================================================================ + +// Lanczos computation for positive x only (no recursion) +fn lgamma_positive(x: f32) -> f32 { + // Lanczos coefficients (g=7, n=9) + let c0 = 0.99999999999980993; + let c1 = 676.5203681218851; + let c2 = -1259.1392167224028; + let c3 = 771.32342877765313; + let c4 = -176.61502916214059; + let c5 = 12.507343278686905; + let c6 = -0.13857109526572012; + let c7 = 9.9843695780195716e-6; + let c8 = 1.5056327351493116e-7; + + let z = x - 1.0; + var ag = c0; + ag = ag + c1 / (z + 1.0); + ag = ag + c2 / (z + 2.0); + ag = ag + c3 / (z + 3.0); + ag = ag + c4 / (z + 4.0); + ag = ag + c5 / (z + 5.0); + ag = ag + c6 / (z + 6.0); + ag = ag + c7 / (z + 7.0); + ag = ag + c8 / (z + 8.0); + + let t = z + LANCZOS_G + 0.5; + return LN_SQRT_2PI + (z + 0.5) * log(t) - t + log(ag); +} + +// Log-gamma using Lanczos approximation (non-recursive) +fn lgamma_impl(x: f32) -> f32 { + if (x <= 0.0) { + // Use reflection formula for negative values + if (x == floor(x)) { + return 1e30; // Pole at non-positive integers + } + // lgamma(x) = log(pi / sin(pi*x)) - lgamma(1-x) + // Since 1-x > 0 for x <= 0, we call lgamma_positive directly + let sinpix = sin(PI * x); + if (sinpix == 0.0) { + return 1e30; + } + return log(PI / abs(sinpix)) - lgamma_positive(1.0 - x); + } + + return lgamma_positive(x); +} + +// Regularized incomplete beta using continued fraction +fn betainc_cf(a: f32, b: f32, x: f32) -> f32 { + let qab = a + b; + let qap = a + 1.0; + let qam = a - 1.0; + + var c = 1.0; + var d = 1.0 - qab * x / qap; + if (abs(d) < TINY) { + d = TINY; + } + d = 1.0 / d; + var h = d; + + for (var m = 1; m < MAX_ITER; m = m + 1) { + let m2 = 2 * m; + + var aa = f32(m) * (b - f32(m)) * x / ((qam + f32(m2)) * (a + f32(m2))); + d = 1.0 + aa * d; + if (abs(d) < TINY) { + d = TINY; + } + c = 1.0 + aa / c; + if (abs(c) < TINY) { + c = TINY; + } + d = 1.0 / d; + h = h * d * c; + + aa = -(a + f32(m)) * (qab + f32(m)) * x / ((a + f32(m2)) * (qap + f32(m2))); + d = 1.0 + aa * d; + if (abs(d) < TINY) { + d = TINY; + } + c = 1.0 + aa / c; + if (abs(c) < TINY) { + c = TINY; + } + d = 1.0 / d; + let delta = d * c; + h = h * delta; + + if (abs(delta - 1.0) < EPSILON) { + break; + } + } + + let lnbeta = lgamma_impl(a) + lgamma_impl(b) - lgamma_impl(a + b); + return exp(a * log(x) + b * log(1.0 - x) - lnbeta) * h / a; +} + +fn betainc_impl(a: f32, b: f32, x: f32) -> f32 { + if (x <= 0.0) { + return 0.0; + } + if (x >= 1.0) { + return 1.0; + } + + // Use symmetry for better convergence (non-recursive version) + if (x > (a + 1.0) / (a + b + 2.0)) { + // Compute directly without recursion using symmetry + return 1.0 - betainc_cf(b, a, 1.0 - x); + } + + return betainc_cf(a, b, x); +} + +// ============================================================================ +// Compute Kernels +// ============================================================================ + +@compute @workgroup_size(256) +fn betainc_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < special_params.numel) { + special_out[idx] = betainc_impl(special_a[idx], special_b[idx], special_x[idx]); + } +} diff --git a/src/runtime/wgpu/shaders/generator/special/unary.rs b/src/runtime/wgpu/shaders/special_unary_f32.wgsl similarity index 74% rename from src/runtime/wgpu/shaders/generator/special/unary.rs rename to src/runtime/wgpu/shaders/special_unary_f32.wgsl index 6d358898..6f7730de 100644 --- a/src/runtime/wgpu/shaders/generator/special/unary.rs +++ b/src/runtime/wgpu/shaders/special_unary_f32.wgsl @@ -1,36 +1,22 @@ -//! WGSL shader generation for special unary functions -//! -//! Generates shaders for: erf, erfc, erfinv, gamma, lgamma, digamma - -use super::super::common::{dtype_suffix, wgsl_type}; -use super::{common_constants, lgamma_helpers}; -use crate::dtype::DType; -use crate::error::{Error, Result}; - -/// Generate WGSL shader for special unary functions (erf, erfc, erfinv, gamma, lgamma, digamma) -pub fn generate_special_unary_shader(dtype: DType) -> Result { - if dtype != DType::F32 { - return Err(Error::UnsupportedDType { - dtype, - op: "special functions (WebGPU requires F32)", - }); - } - - let t = wgsl_type(dtype)?; - let suffix = dtype_suffix(dtype)?; - - Ok(format!( - r#"// Auto-generated special functions for {t} +// Auto-generated special functions for f32 // Algorithms: A&S for erf, Lanczos for gamma, asymptotic for digamma -{constants} - -struct SpecialParams {{ +const WORKGROUP_SIZE: u32 = 256u; +const PI: f32 = 3.14159265358979323846; +const SQRT_PI: f32 = 1.7724538509055159; +const EULER_GAMMA: f32 = 0.5772156649015329; +const LN_SQRT_2PI: f32 = 0.9189385332046727; +const LANCZOS_G: f32 = 7.0; +const MAX_ITER: i32 = 100; +const EPSILON: f32 = 1e-6; +const TINY: f32 = 1e-30; + +struct SpecialParams { numel: u32, -}} +} -@group(0) @binding(0) var special_a: array<{t}>; -@group(0) @binding(1) var special_out: array<{t}>; +@group(0) @binding(0) var special_a: array; +@group(0) @binding(1) var special_out: array; @group(0) @binding(2) var special_params: SpecialParams; // ============================================================================ @@ -38,10 +24,10 @@ struct SpecialParams {{ // ============================================================================ // Error function using Abramowitz & Stegun approximation 7.1.26 -fn erf_impl(x: f32) -> f32 {{ - if (x == 0.0) {{ +fn erf_impl(x: f32) -> f32 { + if (x == 0.0) { return 0.0; - }} + } let sgn = select(-1.0, 1.0, x >= 0.0); let ax = abs(x); @@ -63,62 +49,108 @@ fn erf_impl(x: f32) -> f32 {{ let y = 1.0 - (a1 * t + a2 * t2 + a3 * t3 + a4 * t4 + a5 * t5) * exp(-ax * ax); return sgn * y; -}} +} // Complementary error function -fn erfc_impl(x: f32) -> f32 {{ +fn erfc_impl(x: f32) -> f32 { return 1.0 - erf_impl(x); -}} +} // Inverse error function using rational approximation -fn erfinv_impl(x: f32) -> f32 {{ - if (x <= -1.0) {{ +fn erfinv_impl(x: f32) -> f32 { + if (x <= -1.0) { return -1e30; // -inf approximation - }} - if (x >= 1.0) {{ + } + if (x >= 1.0) { return 1e30; // +inf approximation - }} - if (x == 0.0) {{ + } + if (x == 0.0) { return 0.0; - }} + } let sgn = select(-1.0, 1.0, x >= 0.0); let ax = abs(x); // Rational approximation for central region - if (ax <= 0.7) {{ + if (ax <= 0.7) { let x2 = ax * ax; let r = ax * ((((-0.140543331 * x2 + 0.914624893) * x2 - 1.645349621) * x2 + 0.886226899) / ((((0.012229801 * x2 - 0.329097515) * x2 + 1.442710462) * x2 - 2.118377725) * x2 + 1.0)); return sgn * r; - }} + } // Tail approximation let z = sqrt(-log((1.0 - ax) / 2.0)); let r = (((1.641345311 * z + 3.429567803) * z - 1.624906493) * z - 1.970840454) / ((1.637067800 * z + 3.543889200) * z + 1.0); return sgn * r; -}} -{lgamma_helpers} +} + +// Lanczos computation for positive x only (no recursion) +fn lgamma_positive(x: f32) -> f32 { + // Lanczos coefficients (g=7, n=9) + let c0 = 0.99999999999980993; + let c1 = 676.5203681218851; + let c2 = -1259.1392167224028; + let c3 = 771.32342877765313; + let c4 = -176.61502916214059; + let c5 = 12.507343278686905; + let c6 = -0.13857109526572012; + let c7 = 9.9843695780195716e-6; + let c8 = 1.5056327351493116e-7; + + let z = x - 1.0; + var ag = c0; + ag = ag + c1 / (z + 1.0); + ag = ag + c2 / (z + 2.0); + ag = ag + c3 / (z + 3.0); + ag = ag + c4 / (z + 4.0); + ag = ag + c5 / (z + 5.0); + ag = ag + c6 / (z + 6.0); + ag = ag + c7 / (z + 7.0); + ag = ag + c8 / (z + 8.0); + + let t = z + LANCZOS_G + 0.5; + return LN_SQRT_2PI + (z + 0.5) * log(t) - t + log(ag); +} + +// Log-gamma using Lanczos approximation (non-recursive) +fn lgamma_impl(x: f32) -> f32 { + if (x <= 0.0) { + // Use reflection formula for negative values + if (x == floor(x)) { + return 1e30; // Pole at non-positive integers + } + // lgamma(x) = log(pi / sin(pi*x)) - lgamma(1-x) + // Since 1-x > 0 for x <= 0, we call lgamma_positive directly + let sinpix = sin(PI * x); + if (sinpix == 0.0) { + return 1e30; + } + return log(PI / abs(sinpix)) - lgamma_positive(1.0 - x); + } + + return lgamma_positive(x); +} // Gamma function -fn gamma_impl(x: f32) -> f32 {{ - if (x <= 0.0 && x == floor(x)) {{ +fn gamma_impl(x: f32) -> f32 { + if (x <= 0.0 && x == floor(x)) { return 1e30; // Pole - }} + } return exp(lgamma_impl(x)); -}} +} // Digamma for positive x using asymptotic expansion (no recursion) -fn digamma_positive(x: f32) -> f32 {{ +fn digamma_positive(x: f32) -> f32 { var result = 0.0; var xx = x; // Recurrence to shift to large x where asymptotic works - while (xx < 6.0) {{ + while (xx < 6.0) { result = result - 1.0 / xx; xx = xx + 1.0; - }} + } // Asymptotic expansion let x2 = 1.0 / (xx * xx); @@ -126,84 +158,84 @@ fn digamma_positive(x: f32) -> f32 {{ result = result - x2 * (1.0/12.0 - x2 * (1.0/120.0 - x2 * (1.0/252.0))); return result; -}} +} // Digamma function (non-recursive) -fn digamma_impl(x: f32) -> f32 {{ - if (x <= 0.0 && x == floor(x)) {{ +fn digamma_impl(x: f32) -> f32 { + if (x <= 0.0 && x == floor(x)) { return 1e30; // Pole at non-positive integers - }} + } // Reflection formula for negative x (non-recursive) - if (x < 0.0) {{ + if (x < 0.0) { // For negative x, 1-x > 0, so we can call digamma_positive directly return digamma_positive(1.0 - x) - PI / tan(PI * x); - }} + } return digamma_positive(x); -}} +} // ============================================================================ // Compute Kernels // ============================================================================ @compute @workgroup_size(256) -fn erf_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ +fn erf_f32(@builtin(global_invocation_id) gid: vec3) { let idx = gid.x; - if (idx < special_params.numel) {{ + if (idx < special_params.numel) { special_out[idx] = erf_impl(special_a[idx]); - }} -}} + } +} @compute @workgroup_size(256) -fn erfc_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ +fn erfc_f32(@builtin(global_invocation_id) gid: vec3) { let idx = gid.x; - if (idx < special_params.numel) {{ + if (idx < special_params.numel) { special_out[idx] = erfc_impl(special_a[idx]); - }} -}} + } +} @compute @workgroup_size(256) -fn erfinv_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ +fn erfinv_f32(@builtin(global_invocation_id) gid: vec3) { let idx = gid.x; - if (idx < special_params.numel) {{ + if (idx < special_params.numel) { special_out[idx] = erfinv_impl(special_a[idx]); - }} -}} + } +} @compute @workgroup_size(256) -fn gamma_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ +fn gamma_f32(@builtin(global_invocation_id) gid: vec3) { let idx = gid.x; - if (idx < special_params.numel) {{ + if (idx < special_params.numel) { special_out[idx] = gamma_impl(special_a[idx]); - }} -}} + } +} @compute @workgroup_size(256) -fn lgamma_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ +fn lgamma_f32(@builtin(global_invocation_id) gid: vec3) { let idx = gid.x; - if (idx < special_params.numel) {{ + if (idx < special_params.numel) { special_out[idx] = lgamma_impl(special_a[idx]); - }} -}} + } +} @compute @workgroup_size(256) -fn digamma_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ +fn digamma_f32(@builtin(global_invocation_id) gid: vec3) { let idx = gid.x; - if (idx < special_params.numel) {{ + if (idx < special_params.numel) { special_out[idx] = digamma_impl(special_a[idx]); - }} -}} + } +} // ============================================================================ // Bessel Functions // ============================================================================ // J0: Bessel function of the first kind, order 0 (Numerical Recipes style) -fn bessel_j0_impl(x: f32) -> f32 {{ +fn bessel_j0_impl(x: f32) -> f32 { let ax = abs(x); - if (ax < 8.0) {{ + if (ax < 8.0) { let y = x * x; // Numerator polynomial @@ -226,7 +258,7 @@ fn bessel_j0_impl(x: f32) -> f32 {{ let den = q1 + y * (q2 + y * (q3 + y * (q4 + y * (q5 + y * q6)))); return num / den; - }} else {{ + } else { // Asymptotic expansion let z = 8.0 / ax; let y = z * z; @@ -248,15 +280,15 @@ fn bessel_j0_impl(x: f32) -> f32 {{ let q0 = z * (q1 + y * (q2 + y * (q3 + y * (q4 + y * q5)))); return sqrt(0.636619772 / ax) * (cos(xx) * p0 - sin(xx) * q0); - }} -}} + } +} // J1: Bessel function of the first kind, order 1 -fn bessel_j1_impl(x: f32) -> f32 {{ +fn bessel_j1_impl(x: f32) -> f32 { let ax = abs(x); var result: f32; - if (ax < 8.0) {{ + if (ax < 8.0) { let y = x * x; // Numerator polynomial @@ -279,7 +311,7 @@ fn bessel_j1_impl(x: f32) -> f32 {{ let den = q1 + y * (q2 + y * (q3 + y * (q4 + y * (q5 + y * q6)))); result = num / den; - }} else {{ + } else { let z = 8.0 / ax; let y = z * z; let xx = ax - 2.356194490; // ax - 3π/4 @@ -301,18 +333,18 @@ fn bessel_j1_impl(x: f32) -> f32 {{ let sign = select(-1.0, 1.0, x >= 0.0); result = sign * sqrt(0.636619772 / ax) * (cos(xx) * p0 - sin(xx) * q0); - }} + } return result; -}} +} // Y0: Bessel function of the second kind, order 0 (Numerical Recipes style) -fn bessel_y0_impl(x: f32) -> f32 {{ - if (x <= 0.0) {{ +fn bessel_y0_impl(x: f32) -> f32 { + if (x <= 0.0) { return 1e30; // NaN approximation for WGSL - }} + } - if (x < 8.0) {{ + if (x < 8.0) { let y = x * x; // Numerator polynomial @@ -335,7 +367,7 @@ fn bessel_y0_impl(x: f32) -> f32 {{ let den = q1 + y * (q2 + y * (q3 + y * (q4 + y * (q5 + y * q6)))); return num / den + 0.636619772 * bessel_j0_impl(x) * log(x); - }} else {{ + } else { // Asymptotic expansion for x >= 8 let z = 8.0 / x; let y = z * z; @@ -359,16 +391,16 @@ fn bessel_y0_impl(x: f32) -> f32 {{ let q0 = z * (q1 + y * (q2 + y * (q3 + y * (q4 + y * q5)))); return sqrt(0.636619772 / x) * (sin(xx) * p0 + cos(xx) * q0); - }} -}} + } +} // Y1: Bessel function of the second kind, order 1 (Numerical Recipes style) -fn bessel_y1_impl(x: f32) -> f32 {{ - if (x <= 0.0) {{ +fn bessel_y1_impl(x: f32) -> f32 { + if (x <= 0.0) { return 1e30; // NaN approximation - }} + } - if (x < 8.0) {{ + if (x < 8.0) { let y = x * x; // Numerator polynomial (Numerical Recipes coefficients) @@ -392,7 +424,7 @@ fn bessel_y1_impl(x: f32) -> f32 {{ let den = q1 + y * (q2 + y * (q3 + y * (q4 + y * (q5 + y * (q6 + y * q7))))); return num / den + 0.636619772 * (bessel_j1_impl(x) * log(x) - 1.0 / x); - }} else {{ + } else { // Asymptotic expansion for x >= 8 let z = 8.0 / x; let y = z * z; @@ -416,30 +448,30 @@ fn bessel_y1_impl(x: f32) -> f32 {{ let q0 = z * (q1 + y * (q2 + y * (q3 + y * (q4 + y * q5)))); return sqrt(0.636619772 / x) * (sin(xx) * p0 + cos(xx) * q0); - }} -}} + } +} // I0: Modified Bessel function of the first kind, order 0 -fn bessel_i0_impl(x: f32) -> f32 {{ +fn bessel_i0_impl(x: f32) -> f32 { let ax = abs(x); - if (ax <= 15.0) {{ + if (ax <= 15.0) { // Power series let z = ax * ax; var sum = 1.0; var term = 1.0; - for (var k = 1; k < 25; k++) {{ + for (var k = 1; k < 25; k++) { let kf = f32(k); term = term * z / (4.0 * kf * kf); sum = sum + term; - if (abs(term) < abs(sum) * 1e-7) {{ + if (abs(term) < abs(sum) * 1e-7) { break; - }} - }} + } + } return sum; - }} else {{ + } else { // Asymptotic expansion let z = 1.0 / ax; @@ -453,31 +485,31 @@ fn bessel_i0_impl(x: f32) -> f32 {{ let poly = ((((p5 * z + p4) * z + p3) * z + p2) * z + p1) * z + p0; return exp(ax) / sqrt(2.0 * PI * ax) * poly; - }} -}} + } +} // I1: Modified Bessel function of the first kind, order 1 -fn bessel_i1_impl(x: f32) -> f32 {{ +fn bessel_i1_impl(x: f32) -> f32 { let ax = abs(x); var result: f32; - if (ax <= 15.0) {{ + if (ax <= 15.0) { // Power series let z = ax * ax; var sum = 0.5; var term = 0.5; - for (var k = 1; k < 25; k++) {{ + for (var k = 1; k < 25; k++) { let kf = f32(k); term = term * z / (4.0 * kf * (kf + 1.0)); sum = sum + term; - if (abs(term) < abs(sum) * 1e-7) {{ + if (abs(term) < abs(sum) * 1e-7) { break; - }} - }} + } + } result = ax * sum; - }} else {{ + } else { // Asymptotic expansion let z = 1.0 / ax; @@ -491,19 +523,19 @@ fn bessel_i1_impl(x: f32) -> f32 {{ let poly = ((((q5 * z + q4) * z + q3) * z + q2) * z + q1) * z + q0; result = exp(ax) / sqrt(2.0 * PI * ax) * poly; - }} + } // I1 is an odd function return select(-result, result, x >= 0.0); -}} +} // K0: Modified Bessel function of the second kind, order 0 -fn bessel_k0_impl(x: f32) -> f32 {{ - if (x <= 0.0) {{ +fn bessel_k0_impl(x: f32) -> f32 { + if (x <= 0.0) { return 1e30; // NaN approximation - }} + } - if (x <= 2.0) {{ + if (x <= 2.0) { let z = x * x / 4.0; let i0 = bessel_i0_impl(x); @@ -518,7 +550,7 @@ fn bessel_k0_impl(x: f32) -> f32 {{ let poly = (((((p6 * z + p5) * z + p4) * z + p3) * z + p2) * z + p1) * z + p0; return -log(x / 2.0) * i0 + poly; - }} else {{ + } else { let z = 2.0 / x; let p0 = 1.25331414; @@ -532,16 +564,16 @@ fn bessel_k0_impl(x: f32) -> f32 {{ let poly = (((((p6 * z + p5) * z + p4) * z + p3) * z + p2) * z + p1) * z + p0; return exp(-x) / sqrt(x) * poly; - }} -}} + } +} // K1: Modified Bessel function of the second kind, order 1 -fn bessel_k1_impl(x: f32) -> f32 {{ - if (x <= 0.0) {{ +fn bessel_k1_impl(x: f32) -> f32 { + if (x <= 0.0) { return 1e30; // NaN approximation - }} + } - if (x <= 2.0) {{ + if (x <= 2.0) { let z = x * x / 4.0; let i1 = bessel_i1_impl(x); @@ -556,7 +588,7 @@ fn bessel_k1_impl(x: f32) -> f32 {{ let poly = x * (((((p6 * z + p5) * z + p4) * z + p3) * z + p2) * z + p1) * z + p0; return log(x / 2.0) * i1 + poly / x; - }} else {{ + } else { let z = 2.0 / x; let q0 = 1.25331414; @@ -570,76 +602,69 @@ fn bessel_k1_impl(x: f32) -> f32 {{ let poly = (((((q6 * z + q5) * z + q4) * z + q3) * z + q2) * z + q1) * z + q0; return exp(-x) / sqrt(x) * poly; - }} -}} + } +} @compute @workgroup_size(256) -fn bessel_j0_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ +fn bessel_j0_f32(@builtin(global_invocation_id) gid: vec3) { let idx = gid.x; - if (idx < special_params.numel) {{ + if (idx < special_params.numel) { special_out[idx] = bessel_j0_impl(special_a[idx]); - }} -}} + } +} @compute @workgroup_size(256) -fn bessel_j1_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ +fn bessel_j1_f32(@builtin(global_invocation_id) gid: vec3) { let idx = gid.x; - if (idx < special_params.numel) {{ + if (idx < special_params.numel) { special_out[idx] = bessel_j1_impl(special_a[idx]); - }} -}} + } +} @compute @workgroup_size(256) -fn bessel_y0_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ +fn bessel_y0_f32(@builtin(global_invocation_id) gid: vec3) { let idx = gid.x; - if (idx < special_params.numel) {{ + if (idx < special_params.numel) { special_out[idx] = bessel_y0_impl(special_a[idx]); - }} -}} + } +} @compute @workgroup_size(256) -fn bessel_y1_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ +fn bessel_y1_f32(@builtin(global_invocation_id) gid: vec3) { let idx = gid.x; - if (idx < special_params.numel) {{ + if (idx < special_params.numel) { special_out[idx] = bessel_y1_impl(special_a[idx]); - }} -}} + } +} @compute @workgroup_size(256) -fn bessel_i0_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ +fn bessel_i0_f32(@builtin(global_invocation_id) gid: vec3) { let idx = gid.x; - if (idx < special_params.numel) {{ + if (idx < special_params.numel) { special_out[idx] = bessel_i0_impl(special_a[idx]); - }} -}} + } +} @compute @workgroup_size(256) -fn bessel_i1_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ +fn bessel_i1_f32(@builtin(global_invocation_id) gid: vec3) { let idx = gid.x; - if (idx < special_params.numel) {{ + if (idx < special_params.numel) { special_out[idx] = bessel_i1_impl(special_a[idx]); - }} -}} + } +} @compute @workgroup_size(256) -fn bessel_k0_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ +fn bessel_k0_f32(@builtin(global_invocation_id) gid: vec3) { let idx = gid.x; - if (idx < special_params.numel) {{ + if (idx < special_params.numel) { special_out[idx] = bessel_k0_impl(special_a[idx]); - }} -}} + } +} @compute @workgroup_size(256) -fn bessel_k1_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ +fn bessel_k1_f32(@builtin(global_invocation_id) gid: vec3) { let idx = gid.x; - if (idx < special_params.numel) {{ + if (idx < special_params.numel) { special_out[idx] = bessel_k1_impl(special_a[idx]); - }} -}} -"#, - t = t, - suffix = suffix, - constants = common_constants(), - lgamma_helpers = lgamma_helpers() - )) + } } diff --git a/src/runtime/wgpu/shaders/statistics.rs b/src/runtime/wgpu/shaders/statistics.rs index 149f96e3..23c425a4 100644 --- a/src/runtime/wgpu/shaders/statistics.rs +++ b/src/runtime/wgpu/shaders/statistics.rs @@ -7,115 +7,34 @@ use wgpu::{Buffer, Queue}; -use super::generator::is_wgpu_supported; use super::pipeline::{LayoutKey, PipelineCache}; use crate::dtype::DType; use crate::error::{Error, Result}; // ============================================================================ -// Mode Shader Generation +// Static shaders // ============================================================================ -/// Get WGSL type string for dtype -fn wgsl_type_str(dtype: DType) -> &'static str { - match dtype { - DType::F32 => "f32", - DType::I32 => "i32", - DType::U32 => "u32", - _ => "f32", // Fallback, should be validated before calling - } -} - -/// Get suffix for kernel names -fn dtype_suffix_str(dtype: DType) -> &'static str { - match dtype { - DType::F32 => "f32", - DType::I32 => "i32", - DType::U32 => "u32", - _ => "f32", // Fallback, should be validated before calling - } -} +const MODE_F32: &str = include_str!("statistics_f32.wgsl"); +const MODE_I32: &str = include_str!("statistics_i32.wgsl"); +const MODE_U32: &str = include_str!("statistics_u32.wgsl"); -/// Generate WGSL shader for mode operation -fn generate_mode_shader(dtype: DType) -> String { - let wgsl_t = wgsl_type_str(dtype); - let suffix = dtype_suffix_str(dtype); - - format!( - r#" -// Mode shader for {wgsl_t} -// Finds most frequent value in sorted data along reduce dimension - -struct ModeParams {{ - outer_size: u32, - reduce_size: u32, - inner_size: u32, - _pad: u32, -}} - -@group(0) @binding(0) var sorted: array<{wgsl_t}>; -@group(0) @binding(1) var mode_values: array<{wgsl_t}>; -@group(0) @binding(2) var mode_counts: array; -@group(0) @binding(3) var params: ModeParams; - -@compute @workgroup_size(1) -fn mode_dim_{suffix}(@builtin(global_invocation_id) gid: vec3) {{ - let out_idx = gid.x; - let total_outputs = params.outer_size * params.inner_size; - - if (out_idx >= total_outputs) {{ - return; - }} - - let outer = out_idx / params.inner_size; - let inner = out_idx % params.inner_size; - let base = outer * params.reduce_size * params.inner_size + inner; - - if (params.reduce_size == 0u) {{ - return; - }} - - // Initialize with first element - var best_val = sorted[base]; - var best_count: i32 = 1; - var curr_val = best_val; - var curr_count: i32 = 1; - - // Scan through sorted slice - for (var r: u32 = 1u; r < params.reduce_size; r = r + 1u) {{ - let idx = base + r * params.inner_size; - let val = sorted[idx]; - - if (val == curr_val) {{ - curr_count = curr_count + 1; - }} else {{ - if (curr_count > best_count) {{ - best_val = curr_val; - best_count = curr_count; - }} - curr_val = val; - curr_count = 1; - }} - }} - - // Check final run - if (curr_count > best_count) {{ - best_val = curr_val; - best_count = curr_count; - }} - - mode_values[out_idx] = best_val; - mode_counts[out_idx] = best_count; -}} -"#, - wgsl_t = wgsl_t, - suffix = suffix - ) -} +// ============================================================================ +// Shader dispatch helper +// ============================================================================ -/// Get module key for caching -fn mode_module_key(dtype: DType) -> String { - format!("mode_{}", dtype_suffix_str(dtype)) +fn mode_shader_info(dtype: DType) -> Result<(&'static str, &'static str, &'static str)> { + Ok(match dtype { + DType::F32 => (MODE_F32, "statistics_f32", "mode_dim_f32"), + DType::I32 => (MODE_I32, "statistics_i32", "mode_dim_i32"), + DType::U32 => (MODE_U32, "statistics_u32", "mode_dim_u32"), + _ => { + return Err(Error::UnsupportedDType { + dtype, + op: "mode (WebGPU)", + }); + } + }) } // ============================================================================ @@ -143,40 +62,18 @@ pub fn launch_mode_dim( num_outputs: usize, dtype: DType, ) -> Result<()> { - if !is_wgpu_supported(dtype) { - return Err(Error::UnsupportedDType { dtype, op: "mode" }); - } + let (shader, module_key, entry_point) = mode_shader_info(dtype)?; - let suffix = dtype_suffix_str(dtype); - let entry_point = format!("mode_dim_{}", suffix); - // Leak entry_point to get static reference (cached, so leak is acceptable) - let static_entry_point: &'static str = Box::leak(entry_point.into_boxed_str()); - - // Generate shader and module key - let shader = generate_mode_shader(dtype); - let module_key = mode_module_key(dtype); - let static_module_key: &'static str = Box::leak(module_key.into_boxed_str()); - let static_shader: &'static str = Box::leak(shader.into_boxed_str()); - - // Get or create shader module - let module = cache.get_or_create_module(static_module_key, static_shader); - - // Layout: 3 storage buffers + 1 uniform buffer + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 3, num_uniform_buffers: 1, num_readonly_storage: 0, }); - - // Get or create pipeline - let pipeline = - cache.get_or_create_pipeline(static_module_key, static_entry_point, &module, &layout); - - // Create bind group + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[sorted, mode_values, mode_counts, params_buffer]); - // Create command encoder and dispatch let mut encoder = cache .device() .create_command_encoder(&wgpu::CommandEncoderDescriptor { @@ -190,7 +87,6 @@ pub fn launch_mode_dim( }); pass.set_pipeline(&pipeline); pass.set_bind_group(0, Some(&bind_group), &[]); - // One workgroup per output element pass.dispatch_workgroups(num_outputs as u32, 1, 1); } @@ -199,7 +95,7 @@ pub fn launch_mode_dim( } /// Launch full mode operation (reduce entire tensor to single value). -#[allow(dead_code)] // May be used in future for full tensor mode +#[allow(dead_code)] pub fn launch_mode_full( cache: &PipelineCache, queue: &Queue, @@ -209,27 +105,15 @@ pub fn launch_mode_full( numel_buffer: &Buffer, dtype: DType, ) -> Result<()> { - if !is_wgpu_supported(dtype) { - return Err(Error::UnsupportedDType { dtype, op: "mode" }); - } - - let suffix = dtype_suffix_str(dtype); - let entry_point = format!("mode_full_{}", suffix); - let static_entry_point: &'static str = Box::leak(entry_point.into_boxed_str()); - - let shader = generate_mode_shader(dtype); - let module_key = format!("mode_full_{}", suffix); - let static_module_key: &'static str = Box::leak(module_key.into_boxed_str()); - let static_shader: &'static str = Box::leak(shader.into_boxed_str()); + let (shader, module_key, entry_point) = mode_shader_info(dtype)?; - let module = cache.get_or_create_module(static_module_key, static_shader); + let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { num_storage_buffers: 3, num_uniform_buffers: 1, num_readonly_storage: 0, }); - let pipeline = - cache.get_or_create_pipeline(static_module_key, static_entry_point, &module, &layout); + let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout); let bind_group = cache.create_bind_group(&layout, &[sorted, mode_value, mode_count, numel_buffer]); diff --git a/src/runtime/wgpu/shaders/statistics_f32.wgsl b/src/runtime/wgpu/shaders/statistics_f32.wgsl new file mode 100644 index 00000000..f7f7fea3 --- /dev/null +++ b/src/runtime/wgpu/shaders/statistics_f32.wgsl @@ -0,0 +1,64 @@ +// Statistics shaders - F32 +// mode_dim_f32: Find most frequent value in sorted data along reduce dimension + +struct ModeParams { + outer_size: u32, + reduce_size: u32, + inner_size: u32, + _pad: u32, +} + +@group(0) @binding(0) var sorted: array; +@group(0) @binding(1) var mode_values: array; +@group(0) @binding(2) var mode_counts: array; +@group(0) @binding(3) var params: ModeParams; + +@compute @workgroup_size(1) +fn mode_dim_f32(@builtin(global_invocation_id) gid: vec3) { + let out_idx = gid.x; + let total_outputs = params.outer_size * params.inner_size; + + if (out_idx >= total_outputs) { + return; + } + + let outer = out_idx / params.inner_size; + let inner = out_idx % params.inner_size; + let base = outer * params.reduce_size * params.inner_size + inner; + + if (params.reduce_size == 0u) { + return; + } + + // Initialize with first element + var best_val = sorted[base]; + var best_count: i32 = 1; + var curr_val = best_val; + var curr_count: i32 = 1; + + // Scan through sorted slice + for (var r: u32 = 1u; r < params.reduce_size; r = r + 1u) { + let idx = base + r * params.inner_size; + let val = sorted[idx]; + + if (val == curr_val) { + curr_count = curr_count + 1; + } else { + if (curr_count > best_count) { + best_val = curr_val; + best_count = curr_count; + } + curr_val = val; + curr_count = 1; + } + } + + // Check final run + if (curr_count > best_count) { + best_val = curr_val; + best_count = curr_count; + } + + mode_values[out_idx] = best_val; + mode_counts[out_idx] = best_count; +} diff --git a/src/runtime/wgpu/shaders/statistics_i32.wgsl b/src/runtime/wgpu/shaders/statistics_i32.wgsl new file mode 100644 index 00000000..165ec25c --- /dev/null +++ b/src/runtime/wgpu/shaders/statistics_i32.wgsl @@ -0,0 +1,64 @@ +// Statistics shaders - I32 +// mode_dim_i32: Find most frequent value in sorted data along reduce dimension + +struct ModeParams { + outer_size: u32, + reduce_size: u32, + inner_size: u32, + _pad: u32, +} + +@group(0) @binding(0) var sorted: array; +@group(0) @binding(1) var mode_values: array; +@group(0) @binding(2) var mode_counts: array; +@group(0) @binding(3) var params: ModeParams; + +@compute @workgroup_size(1) +fn mode_dim_i32(@builtin(global_invocation_id) gid: vec3) { + let out_idx = gid.x; + let total_outputs = params.outer_size * params.inner_size; + + if (out_idx >= total_outputs) { + return; + } + + let outer = out_idx / params.inner_size; + let inner = out_idx % params.inner_size; + let base = outer * params.reduce_size * params.inner_size + inner; + + if (params.reduce_size == 0u) { + return; + } + + // Initialize with first element + var best_val = sorted[base]; + var best_count: i32 = 1; + var curr_val = best_val; + var curr_count: i32 = 1; + + // Scan through sorted slice + for (var r: u32 = 1u; r < params.reduce_size; r = r + 1u) { + let idx = base + r * params.inner_size; + let val = sorted[idx]; + + if (val == curr_val) { + curr_count = curr_count + 1; + } else { + if (curr_count > best_count) { + best_val = curr_val; + best_count = curr_count; + } + curr_val = val; + curr_count = 1; + } + } + + // Check final run + if (curr_count > best_count) { + best_val = curr_val; + best_count = curr_count; + } + + mode_values[out_idx] = best_val; + mode_counts[out_idx] = best_count; +} diff --git a/src/runtime/wgpu/shaders/statistics_u32.wgsl b/src/runtime/wgpu/shaders/statistics_u32.wgsl new file mode 100644 index 00000000..eef39f66 --- /dev/null +++ b/src/runtime/wgpu/shaders/statistics_u32.wgsl @@ -0,0 +1,64 @@ +// Statistics shaders - U32 +// mode_dim_u32: Find most frequent value in sorted data along reduce dimension + +struct ModeParams { + outer_size: u32, + reduce_size: u32, + inner_size: u32, + _pad: u32, +} + +@group(0) @binding(0) var sorted: array; +@group(0) @binding(1) var mode_values: array; +@group(0) @binding(2) var mode_counts: array; +@group(0) @binding(3) var params: ModeParams; + +@compute @workgroup_size(1) +fn mode_dim_u32(@builtin(global_invocation_id) gid: vec3) { + let out_idx = gid.x; + let total_outputs = params.outer_size * params.inner_size; + + if (out_idx >= total_outputs) { + return; + } + + let outer = out_idx / params.inner_size; + let inner = out_idx % params.inner_size; + let base = outer * params.reduce_size * params.inner_size + inner; + + if (params.reduce_size == 0u) { + return; + } + + // Initialize with first element + var best_val = sorted[base]; + var best_count: i32 = 1; + var curr_val = best_val; + var curr_count: i32 = 1; + + // Scan through sorted slice + for (var r: u32 = 1u; r < params.reduce_size; r = r + 1u) { + let idx = base + r * params.inner_size; + let val = sorted[idx]; + + if (val == curr_val) { + curr_count = curr_count + 1; + } else { + if (curr_count > best_count) { + best_val = curr_val; + best_count = curr_count; + } + curr_val = val; + curr_count = 1; + } + } + + // Check final run + if (curr_count > best_count) { + best_val = curr_val; + best_count = curr_count; + } + + mode_values[out_idx] = best_val; + mode_counts[out_idx] = best_count; +} diff --git a/src/runtime/wgpu/shaders/stockham_fft.wgsl b/src/runtime/wgpu/shaders/stockham_fft.wgsl new file mode 100644 index 00000000..896a8573 --- /dev/null +++ b/src/runtime/wgpu/shaders/stockham_fft.wgsl @@ -0,0 +1,186 @@ +// Stockham FFT shader for WebGPU +// Complex numbers as vec2 (re, im) + +const PI: f32 = 3.14159265358979323846; +const WORKGROUP_SIZE: u32 = 256u; + +struct FftParams { + n: u32, + log_n: u32, + inverse: i32, + scale: f32, + batch_size: u32, + _pad1: u32, + _pad2: u32, + _pad3: u32, +} + +@group(0) @binding(0) var fft_input: array>; +@group(0) @binding(1) var fft_output: array>; +@group(0) @binding(2) var fft_params: FftParams; + +// Workgroup shared memory for ping-pong +var smem_a: array, 256>; +var smem_b: array, 256>; + +// Complex number helpers (vec2: x=real, y=imag) +fn cmul(a: vec2, b: vec2) -> vec2 { + return vec2(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x); +} + +fn cadd(a: vec2, b: vec2) -> vec2 { + return a + b; +} + +fn csub(a: vec2, b: vec2) -> vec2 { + return a - b; +} + +fn cscale(a: vec2, s: f32) -> vec2 { + return vec2(a.x * s, a.y * s); +} + +fn cconj(a: vec2) -> vec2 { + return vec2(a.x, -a.y); +} + +// Compute e^(i*theta) = cos(theta) + i*sin(theta) +fn cexp_i(theta: f32) -> vec2 { + return vec2(cos(theta), sin(theta)); +} + +@compute @workgroup_size(WORKGROUP_SIZE) +fn stockham_fft_small( + @builtin(workgroup_id) wg_id: vec3, + @builtin(local_invocation_id) local_id: vec3 +) { + let batch_idx = wg_id.x; + let tid = local_id.x; + let n = fft_params.n; + let log_n = fft_params.log_n; + let inverse = fft_params.inverse; + let scale_factor = fft_params.scale; + + // Sign for twiddle factor + let sign = select(-1.0, 1.0, inverse != 0); + + // Load input to shared memory + let base_offset = batch_idx * n; + for (var i = tid; i < n; i = i + WORKGROUP_SIZE) { + smem_a[i] = fft_input[base_offset + i]; + } + workgroupBarrier(); + + // Perform Stockham FFT stages + var use_a = true; + for (var stage: u32 = 0u; stage < log_n; stage = stage + 1u) { + let m = 1u << (stage + 1u); + let half_m = 1u << stage; + + for (var i = tid; i < n / 2u; i = i + WORKGROUP_SIZE) { + let group = i / half_m; + let pair = i % half_m; + + let even_idx = group * half_m + pair; + let odd_idx = even_idx + n / 2u; + + let out_even_idx = group * m + pair; + let out_odd_idx = out_even_idx + half_m; + + // Twiddle factor + let theta = sign * 2.0 * PI * f32(pair) / f32(m); + let twiddle = cexp_i(theta); + + var even_val: vec2; + var odd_val: vec2; + + if (use_a) { + even_val = smem_a[even_idx]; + odd_val = cmul(smem_a[odd_idx], twiddle); + } else { + even_val = smem_b[even_idx]; + odd_val = cmul(smem_b[odd_idx], twiddle); + } + + let sum = cadd(even_val, odd_val); + let diff = csub(even_val, odd_val); + + if (use_a) { + smem_b[out_even_idx] = sum; + smem_b[out_odd_idx] = diff; + } else { + smem_a[out_even_idx] = sum; + smem_a[out_odd_idx] = diff; + } + } + + workgroupBarrier(); + use_a = !use_a; + } + + // Write output with scaling + for (var i = tid; i < n; i = i + WORKGROUP_SIZE) { + var result: vec2; + if (use_a) { + result = smem_a[i]; + } else { + result = smem_b[i]; + } + fft_output[base_offset + i] = cscale(result, scale_factor); + } +} + +// Single stage kernel for large FFTs (N > workgroup FFT size) +@compute @workgroup_size(WORKGROUP_SIZE) +fn stockham_fft_stage( + @builtin(global_invocation_id) gid: vec3 +) { + let n = fft_params.n; + let stage = fft_params.log_n; // Reuse log_n as current stage + let inverse = fft_params.inverse; + let batch_idx = gid.y; + + let sign = select(-1.0, 1.0, inverse != 0); + + let m = 1u << (stage + 1u); + let half_m = 1u << stage; + + let i = gid.x; + if (i >= n / 2u) { + return; + } + + let group = i / half_m; + let pair = i % half_m; + + let base_offset = batch_idx * n; + let even_idx = base_offset + group * half_m + pair; + let odd_idx = even_idx + n / 2u; + + let out_even_idx = base_offset + group * m + pair; + let out_odd_idx = out_even_idx + half_m; + + // Twiddle factor + let theta = sign * 2.0 * PI * f32(pair) / f32(m); + let twiddle = cexp_i(theta); + + let even_val = fft_input[even_idx]; + let odd_val = cmul(fft_input[odd_idx], twiddle); + + fft_output[out_even_idx] = cadd(even_val, odd_val); + fft_output[out_odd_idx] = csub(even_val, odd_val); +} + +// Scale complex array +@compute @workgroup_size(WORKGROUP_SIZE) +fn scale_complex( + @builtin(global_invocation_id) gid: vec3 +) { + let idx = gid.x; + let n = fft_params.n; + let scale_factor = fft_params.scale; + + if (idx < n) { + fft_output[idx] = cscale(fft_input[idx], scale_factor); + } +} diff --git a/src/runtime/wgpu/shaders/student_t_f32.wgsl b/src/runtime/wgpu/shaders/student_t_f32.wgsl new file mode 100644 index 00000000..1c7ca35a --- /dev/null +++ b/src/runtime/wgpu/shaders/student_t_f32.wgsl @@ -0,0 +1,92 @@ +// Student's t distribution sampling for f32 + +// PCG hash function for random number generation +fn pcg_hash(input: u32) -> u32 { + var state = input * 747796405u + 2891336453u; + var word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u; + return (word >> 22u) ^ word; +} + +fn pcg_init(seed: u32, idx: u32) -> u32 { + return pcg_hash(seed ^ pcg_hash(idx)); +} + +fn pcg_uniform(state: ptr) -> f32 { + *state = pcg_hash(*state); + return f32(*state) / 4294967296.0; +} + +// Box-Muller for normal distribution +fn sample_normal(state: ptr) -> f32 { + let u1 = max(pcg_uniform(state), 0.0000001); + let u2 = pcg_uniform(state); + return sqrt(-2.0 * log(u1)) * cos(6.28318530718 * u2); +} + +// Gamma via Marsaglia-Tsang method +fn sample_gamma_mt(state: ptr, shape: f32, scale: f32) -> f32 { + var alpha = shape; + var boost = 1.0; + + // Handle shape < 1 by boosting + if alpha < 1.0 { + boost = pow(pcg_uniform(state), 1.0 / alpha); + alpha = alpha + 1.0; + } + + let d = alpha - 1.0 / 3.0; + let c = 1.0 / sqrt(9.0 * d); + + // Rejection sampling + for (var i = 0u; i < 100u; i = i + 1u) { + var x: f32; + var v: f32; + + // Generate valid v + for (var j = 0u; j < 100u; j = j + 1u) { + x = sample_normal(state); + v = 1.0 + c * x; + if v > 0.0 { + break; + } + } + + v = v * v * v; + let u = pcg_uniform(state); + let x2 = x * x; + + // Accept/reject + if u < 1.0 - 0.0331 * x2 * x2 { + return d * v * boost * scale; + } + if log(u) < 0.5 * x2 + d * (1.0 - v + log(v)) { + return d * v * boost * scale; + } + } + + // Fallback (should rarely reach) + return d * boost * scale; +} + +const WORKGROUP_SIZE: u32 = 256u; + +struct StudentTParams { + numel: u32, + seed: u32, + df: f32, + _pad: u32, +} + +@group(0) @binding(0) var out: array; +@group(0) @binding(1) var params: StudentTParams; + +@compute @workgroup_size(256) +fn student_t_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if idx < params.numel { + var state = pcg_init(params.seed, idx); + let z = sample_normal(&state); + let chi2 = sample_gamma_mt(&state, params.df / 2.0, 2.0); + out[idx] = f32(z / sqrt(chi2 / params.df)); + } +} diff --git a/src/runtime/wgpu/shaders/topk_f32.wgsl b/src/runtime/wgpu/shaders/topk_f32.wgsl new file mode 100644 index 00000000..9df2e0a7 --- /dev/null +++ b/src/runtime/wgpu/shaders/topk_f32.wgsl @@ -0,0 +1,107 @@ +// Auto-generated topk operations for f32 + +const WORKGROUP_SIZE: u32 = 256u; +const MAX_SORT_SIZE: u32 = 512u; + +var shared_vals: array; +var shared_idxs: array; + +struct TopkParams { + outer_size: u32, + sort_size: u32, + inner_size: u32, + k: u32, + largest: u32, + sorted: u32, +} + +@group(0) @binding(0) var topk_input: array; +@group(0) @binding(1) var topk_values: array; +@group(0) @binding(2) var topk_indices: array; +@group(0) @binding(3) var topk_params: TopkParams; + +fn compare_less_f32(a: f32, b: f32) -> bool { + return a < b; +} + +fn bitonic_cas_f32(i: u32, j: u32, dir: bool) { + let vi = shared_vals[i]; + let vj = shared_vals[j]; + let swap = select(compare_less_f32(vi, vj), compare_less_f32(vj, vi), dir); + if (swap) { + shared_vals[i] = vj; + shared_vals[j] = vi; + let ti = shared_idxs[i]; + shared_idxs[i] = shared_idxs[j]; + shared_idxs[j] = ti; + } +} + +@compute @workgroup_size(256) +fn topk_f32( + @builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) group_id: vec3 +) { + let outer_idx = group_id.x; + let inner_idx = group_id.y; + let tid = local_id.x; + + let outer_size = topk_params.outer_size; + let sort_size = topk_params.sort_size; + let inner_size = topk_params.inner_size; + let k = topk_params.k; + let largest = topk_params.largest != 0u; + + if (outer_idx >= outer_size || inner_idx >= inner_size) { + return; + } + + var n = sort_size; + var p: u32 = 1u; + while (p < n) { + p = p << 1u; + } + n = min(p, MAX_SORT_SIZE); + + let base_offset = outer_idx * sort_size * inner_size + inner_idx; + for (var i = tid; i < n; i = i + WORKGROUP_SIZE) { + if (i < sort_size) { + let idx = base_offset + i * inner_size; + shared_vals[i] = topk_input[idx]; + shared_idxs[i] = i32(i); + } else { + shared_vals[i] = select(f32(3.402823e+38), f32(-3.402823e+38), largest); + shared_idxs[i] = i32(i); + } + } + workgroupBarrier(); + + // Bitonic sort (descending if largest, ascending if smallest) + for (var k_: u32 = 2u; k_ <= n; k_ = k_ << 1u) { + for (var j: u32 = k_ >> 1u; j > 0u; j = j >> 1u) { + for (var i = tid; i < n / 2u; i = i + WORKGROUP_SIZE) { + // Calculate bitonic network indices + let ij = (i / j) * 2u * j + (i % j); + let ij_pair = ij + j; + + // Direction depends on which half of the network we're in + // For largest: descending (true), for smallest: ascending (false) + let ascending_local = ((ij / k_) % 2u == 0u) != largest; + + if (ij_pair < n) { + bitonic_cas_f32(ij, ij_pair, ascending_local); + } + } + workgroupBarrier(); + } + } + + // Write top-k values and indices + let out_base = outer_idx * k * inner_size + inner_idx; + for (var i = tid; i < k; i = i + WORKGROUP_SIZE) { + let out_idx = out_base + i * inner_size; + topk_values[out_idx] = shared_vals[i]; + topk_indices[out_idx] = shared_idxs[i]; + } +} diff --git a/src/runtime/wgpu/shaders/unary.wgsl b/src/runtime/wgpu/shaders/unary.wgsl new file mode 100644 index 00000000..84a58358 --- /dev/null +++ b/src/runtime/wgpu/shaders/unary.wgsl @@ -0,0 +1,327 @@ +// F32 unary operations + +const WORKGROUP_SIZE: u32 = 256u; + +struct UnaryParams { + numel: u32, +} + +@group(0) @binding(0) var unary_a: array; +@group(0) @binding(1) var unary_out: array; +@group(0) @binding(2) var unary_params: UnaryParams; + +@compute @workgroup_size(256) +fn neg_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + unary_out[idx] = -unary_a[idx]; + } +} + +@compute @workgroup_size(256) +fn abs_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + unary_out[idx] = abs(unary_a[idx]); + } +} + +@compute @workgroup_size(256) +fn sqrt_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + unary_out[idx] = sqrt(unary_a[idx]); + } +} + +@compute @workgroup_size(256) +fn exp_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + unary_out[idx] = exp(unary_a[idx]); + } +} + +@compute @workgroup_size(256) +fn log_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + unary_out[idx] = log(unary_a[idx]); + } +} + +@compute @workgroup_size(256) +fn sin_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + unary_out[idx] = sin(unary_a[idx]); + } +} + +@compute @workgroup_size(256) +fn cos_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + unary_out[idx] = cos(unary_a[idx]); + } +} + +@compute @workgroup_size(256) +fn tan_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + unary_out[idx] = tan(unary_a[idx]); + } +} + +@compute @workgroup_size(256) +fn atan_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + unary_out[idx] = atan(unary_a[idx]); + } +} + +@compute @workgroup_size(256) +fn tanh_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + unary_out[idx] = tanh(unary_a[idx]); + } +} + +@compute @workgroup_size(256) +fn recip_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + unary_out[idx] = 1.0 / unary_a[idx]; + } +} + +@compute @workgroup_size(256) +fn floor_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + unary_out[idx] = floor(unary_a[idx]); + } +} + +@compute @workgroup_size(256) +fn ceil_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + unary_out[idx] = ceil(unary_a[idx]); + } +} + +@compute @workgroup_size(256) +fn round_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + let x = unary_a[idx]; + unary_out[idx] = select(ceil(x - 0.5), floor(x + 0.5), x >= 0.0); + } +} + +@compute @workgroup_size(256) +fn trunc_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + unary_out[idx] = trunc(unary_a[idx]); + } +} + +@compute @workgroup_size(256) +fn rsqrt_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + unary_out[idx] = inverseSqrt(unary_a[idx]); + } +} + +@compute @workgroup_size(256) +fn cbrt_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + let x = unary_a[idx]; + unary_out[idx] = sign(x) * pow(abs(x), 1.0 / 3.0); + } +} + +@compute @workgroup_size(256) +fn exp2_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + unary_out[idx] = exp2(unary_a[idx]); + } +} + +@compute @workgroup_size(256) +fn expm1_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + unary_out[idx] = exp(unary_a[idx]) - 1.0; + } +} + +@compute @workgroup_size(256) +fn log2_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + unary_out[idx] = log2(unary_a[idx]); + } +} + +@compute @workgroup_size(256) +fn log10_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + unary_out[idx] = log(unary_a[idx]) * 0.4342944819032518; + } +} + +@compute @workgroup_size(256) +fn log1p_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + unary_out[idx] = log(1.0 + unary_a[idx]); + } +} + +@compute @workgroup_size(256) +fn asin_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + let x = unary_a[idx]; + let y = sqrt(max(0.0, 1.0 - x * x)); + unary_out[idx] = atan2(x, y); + } +} + +@compute @workgroup_size(256) +fn acos_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + let x = unary_a[idx]; + let y = sqrt(max(0.0, 1.0 - x * x)); + unary_out[idx] = atan2(y, x); + } +} + +@compute @workgroup_size(256) +fn sinh_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + unary_out[idx] = sinh(unary_a[idx]); + } +} + +@compute @workgroup_size(256) +fn cosh_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + unary_out[idx] = cosh(unary_a[idx]); + } +} + +@compute @workgroup_size(256) +fn asinh_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + unary_out[idx] = asinh(unary_a[idx]); + } +} + +@compute @workgroup_size(256) +fn acosh_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + unary_out[idx] = acosh(unary_a[idx]); + } +} + +@compute @workgroup_size(256) +fn atanh_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + unary_out[idx] = atanh(unary_a[idx]); + } +} + +@compute @workgroup_size(256) +fn square_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + let x = unary_a[idx]; + unary_out[idx] = x * x; + } +} + +@compute @workgroup_size(256) +fn sign_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + unary_out[idx] = sign(unary_a[idx]); + } +} + +@compute @workgroup_size(256) +fn relu_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + unary_out[idx] = max(unary_a[idx], 0.0); + } +} + +@compute @workgroup_size(256) +fn sigmoid_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + unary_out[idx] = 1.0 / (1.0 + exp(-unary_a[idx])); + } +} + +@compute @workgroup_size(256) +fn silu_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + let x = unary_a[idx]; + unary_out[idx] = x / (1.0 + exp(-x)); + } +} + +@compute @workgroup_size(256) +fn gelu_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + let x = unary_a[idx]; + let c = 0.7978845608028654; + unary_out[idx] = 0.5 * x * (1.0 + tanh(c * (x + 0.044715 * x * x * x))); + } +} + +@compute @workgroup_size(256) +fn isnan_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + let x = unary_a[idx]; + let bits = bitcast(f32(x)); + let exp = bits & 0x7f800000u; + let mant = bits & 0x007fffffu; + let is_nan = (exp == 0x7f800000u) && (mant != 0u); + unary_out[idx] = select(0.0, 1.0, is_nan); + } +} + +@compute @workgroup_size(256) +fn isinf_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + let x = unary_a[idx]; + let bits = bitcast(f32(x)); + let exp = bits & 0x7f800000u; + let mant = bits & 0x007fffffu; + let is_inf = (exp == 0x7f800000u) && (mant == 0u); + unary_out[idx] = select(0.0, 1.0, is_inf); + } +} diff --git a/src/runtime/wgpu/shaders/unary_i32.wgsl b/src/runtime/wgpu/shaders/unary_i32.wgsl new file mode 100644 index 00000000..6cbbcaed --- /dev/null +++ b/src/runtime/wgpu/shaders/unary_i32.wgsl @@ -0,0 +1,27 @@ +// I32 unary operations + +const WORKGROUP_SIZE: u32 = 256u; + +struct UnaryParams { + numel: u32, +} + +@group(0) @binding(0) var unary_a: array; +@group(0) @binding(1) var unary_out: array; +@group(0) @binding(2) var unary_params: UnaryParams; + +@compute @workgroup_size(256) +fn neg_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + unary_out[idx] = -unary_a[idx]; + } +} + +@compute @workgroup_size(256) +fn abs_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + unary_out[idx] = abs(unary_a[idx]); + } +} diff --git a/src/runtime/wgpu/shaders/unary_u32.wgsl b/src/runtime/wgpu/shaders/unary_u32.wgsl new file mode 100644 index 00000000..240d0aa8 --- /dev/null +++ b/src/runtime/wgpu/shaders/unary_u32.wgsl @@ -0,0 +1,19 @@ +// U32 unary operations + +const WORKGROUP_SIZE: u32 = 256u; + +struct UnaryParams { + numel: u32, +} + +@group(0) @binding(0) var unary_a: array; +@group(0) @binding(1) var unary_out: array; +@group(0) @binding(2) var unary_params: UnaryParams; + +@compute @workgroup_size(256) +fn abs_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < unary_params.numel) { + unary_out[idx] = unary_a[idx]; + } +} diff --git a/src/runtime/wgpu/shaders/unique_with_counts_f32.wgsl b/src/runtime/wgpu/shaders/unique_with_counts_f32.wgsl new file mode 100644 index 00000000..72022d9a --- /dev/null +++ b/src/runtime/wgpu/shaders/unique_with_counts_f32.wgsl @@ -0,0 +1,92 @@ +// Auto-generated unique_with_counts operations for f32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct UniqueCountsParams { + numel: u32, + num_unique: u32, + _pad0: u32, + _pad1: u32, +} + +// Mark boundaries in sorted array (where value changes) +// Output: flags[i] = 1 if sorted[i] != sorted[i-1] (or i=0), else 0 +@group(0) @binding(0) var sorted_input: array; +@group(0) @binding(1) var boundary_flags: array; +@group(0) @binding(2) var params: UniqueCountsParams; + +@compute @workgroup_size(256) +fn mark_boundaries_f32(@builtin(global_invocation_id) global_id: vec3) { + let idx = global_id.x; + let numel = params.numel; + + if (idx >= numel) { + return; + } + + // Mark boundary: first element or different from previous + if (idx == 0u || sorted_input[idx] != sorted_input[idx - 1u]) { + boundary_flags[idx] = 1u; + } else { + boundary_flags[idx] = 0u; + } +} + +// Scatter unique values and compute counts using prefix sum indices +// prefix_sum[i] contains the output index for element at position i (if it's a boundary) +// We write: unique_values[prefix_sum[i]-1] = sorted[i] when flags[i] == 1 +// counts[prefix_sum[i]-1] = (next boundary position - i) computed from adjacent prefix sums +@group(0) @binding(0) var scatter_sorted: array; +@group(0) @binding(1) var prefix_sum: array; +@group(0) @binding(2) var unique_values: array; +@group(0) @binding(3) var inverse_indices: array; +@group(0) @binding(4) var counts: array; +@group(0) @binding(5) var scatter_params: UniqueCountsParams; + +@compute @workgroup_size(256) +fn scatter_unique_with_counts_f32(@builtin(global_invocation_id) global_id: vec3) { + let idx = global_id.x; + let numel = scatter_params.numel; + let num_unique = scatter_params.num_unique; + + if (idx >= numel) { + return; + } + + // The prefix sum gives us 1-based output indices + let out_idx_plus1 = prefix_sum[idx]; + + // Check if this is a boundary by comparing with previous prefix sum + let is_boundary = (idx == 0u) || (prefix_sum[idx] != prefix_sum[idx - 1u]); + + // Write inverse index: which unique element does this sorted element map to + inverse_indices[idx] = i32(out_idx_plus1 - 1u); + + if (is_boundary) { + let out_idx = out_idx_plus1 - 1u; + unique_values[out_idx] = scatter_sorted[idx]; + + // Compute count: find next boundary position + // The count is (next_boundary_position - idx) + // If we're the last unique, count to numel + if (out_idx + 1u >= num_unique) { + // Last unique element + counts[out_idx] = i32(numel - idx); + } else { + // Find next boundary: it's where prefix_sum increases next + // We need to find the smallest j > idx where prefix_sum[j] > out_idx_plus1 + // Actually, we can compute this differently: + // The run length is the distance to the next boundary + // For efficiency, we'll use a second pass or a different approach + + // For now, scan forward (not ideal but correct) + var run_len: u32 = 1u; + var j = idx + 1u; + while (j < numel && prefix_sum[j] == out_idx_plus1) { + run_len = run_len + 1u; + j = j + 1u; + } + counts[out_idx] = i32(run_len); + } + } +} diff --git a/src/runtime/wgpu/shaders/unique_with_counts_i32.wgsl b/src/runtime/wgpu/shaders/unique_with_counts_i32.wgsl new file mode 100644 index 00000000..765d1e21 --- /dev/null +++ b/src/runtime/wgpu/shaders/unique_with_counts_i32.wgsl @@ -0,0 +1,92 @@ +// Auto-generated unique_with_counts operations for i32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct UniqueCountsParams { + numel: u32, + num_unique: u32, + _pad0: u32, + _pad1: u32, +} + +// Mark boundaries in sorted array (where value changes) +// Output: flags[i] = 1 if sorted[i] != sorted[i-1] (or i=0), else 0 +@group(0) @binding(0) var sorted_input: array; +@group(0) @binding(1) var boundary_flags: array; +@group(0) @binding(2) var params: UniqueCountsParams; + +@compute @workgroup_size(256) +fn mark_boundaries_i32(@builtin(global_invocation_id) global_id: vec3) { + let idx = global_id.x; + let numel = params.numel; + + if (idx >= numel) { + return; + } + + // Mark boundary: first element or different from previous + if (idx == 0u || sorted_input[idx] != sorted_input[idx - 1u]) { + boundary_flags[idx] = 1u; + } else { + boundary_flags[idx] = 0u; + } +} + +// Scatter unique values and compute counts using prefix sum indices +// prefix_sum[i] contains the output index for element at position i (if it's a boundary) +// We write: unique_values[prefix_sum[i]-1] = sorted[i] when flags[i] == 1 +// counts[prefix_sum[i]-1] = (next boundary position - i) computed from adjacent prefix sums +@group(0) @binding(0) var scatter_sorted: array; +@group(0) @binding(1) var prefix_sum: array; +@group(0) @binding(2) var unique_values: array; +@group(0) @binding(3) var inverse_indices: array; +@group(0) @binding(4) var counts: array; +@group(0) @binding(5) var scatter_params: UniqueCountsParams; + +@compute @workgroup_size(256) +fn scatter_unique_with_counts_i32(@builtin(global_invocation_id) global_id: vec3) { + let idx = global_id.x; + let numel = scatter_params.numel; + let num_unique = scatter_params.num_unique; + + if (idx >= numel) { + return; + } + + // The prefix sum gives us 1-based output indices + let out_idx_plus1 = prefix_sum[idx]; + + // Check if this is a boundary by comparing with previous prefix sum + let is_boundary = (idx == 0u) || (prefix_sum[idx] != prefix_sum[idx - 1u]); + + // Write inverse index: which unique element does this sorted element map to + inverse_indices[idx] = i32(out_idx_plus1 - 1u); + + if (is_boundary) { + let out_idx = out_idx_plus1 - 1u; + unique_values[out_idx] = scatter_sorted[idx]; + + // Compute count: find next boundary position + // The count is (next_boundary_position - idx) + // If we're the last unique, count to numel + if (out_idx + 1u >= num_unique) { + // Last unique element + counts[out_idx] = i32(numel - idx); + } else { + // Find next boundary: it's where prefix_sum increases next + // We need to find the smallest j > idx where prefix_sum[j] > out_idx_plus1 + // Actually, we can compute this differently: + // The run length is the distance to the next boundary + // For efficiency, we'll use a second pass or a different approach + + // For now, scan forward (not ideal but correct) + var run_len: u32 = 1u; + var j = idx + 1u; + while (j < numel && prefix_sum[j] == out_idx_plus1) { + run_len = run_len + 1u; + j = j + 1u; + } + counts[out_idx] = i32(run_len); + } + } +} diff --git a/src/runtime/wgpu/shaders/unique_with_counts_u32.wgsl b/src/runtime/wgpu/shaders/unique_with_counts_u32.wgsl new file mode 100644 index 00000000..f1c57395 --- /dev/null +++ b/src/runtime/wgpu/shaders/unique_with_counts_u32.wgsl @@ -0,0 +1,92 @@ +// Auto-generated unique_with_counts operations for u32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct UniqueCountsParams { + numel: u32, + num_unique: u32, + _pad0: u32, + _pad1: u32, +} + +// Mark boundaries in sorted array (where value changes) +// Output: flags[i] = 1 if sorted[i] != sorted[i-1] (or i=0), else 0 +@group(0) @binding(0) var sorted_input: array; +@group(0) @binding(1) var boundary_flags: array; +@group(0) @binding(2) var params: UniqueCountsParams; + +@compute @workgroup_size(256) +fn mark_boundaries_u32(@builtin(global_invocation_id) global_id: vec3) { + let idx = global_id.x; + let numel = params.numel; + + if (idx >= numel) { + return; + } + + // Mark boundary: first element or different from previous + if (idx == 0u || sorted_input[idx] != sorted_input[idx - 1u]) { + boundary_flags[idx] = 1u; + } else { + boundary_flags[idx] = 0u; + } +} + +// Scatter unique values and compute counts using prefix sum indices +// prefix_sum[i] contains the output index for element at position i (if it's a boundary) +// We write: unique_values[prefix_sum[i]-1] = sorted[i] when flags[i] == 1 +// counts[prefix_sum[i]-1] = (next boundary position - i) computed from adjacent prefix sums +@group(0) @binding(0) var scatter_sorted: array; +@group(0) @binding(1) var prefix_sum: array; +@group(0) @binding(2) var unique_values: array; +@group(0) @binding(3) var inverse_indices: array; +@group(0) @binding(4) var counts: array; +@group(0) @binding(5) var scatter_params: UniqueCountsParams; + +@compute @workgroup_size(256) +fn scatter_unique_with_counts_u32(@builtin(global_invocation_id) global_id: vec3) { + let idx = global_id.x; + let numel = scatter_params.numel; + let num_unique = scatter_params.num_unique; + + if (idx >= numel) { + return; + } + + // The prefix sum gives us 1-based output indices + let out_idx_plus1 = prefix_sum[idx]; + + // Check if this is a boundary by comparing with previous prefix sum + let is_boundary = (idx == 0u) || (prefix_sum[idx] != prefix_sum[idx - 1u]); + + // Write inverse index: which unique element does this sorted element map to + inverse_indices[idx] = i32(out_idx_plus1 - 1u); + + if (is_boundary) { + let out_idx = out_idx_plus1 - 1u; + unique_values[out_idx] = scatter_sorted[idx]; + + // Compute count: find next boundary position + // The count is (next_boundary_position - idx) + // If we're the last unique, count to numel + if (out_idx + 1u >= num_unique) { + // Last unique element + counts[out_idx] = i32(numel - idx); + } else { + // Find next boundary: it's where prefix_sum increases next + // We need to find the smallest j > idx where prefix_sum[j] > out_idx_plus1 + // Actually, we can compute this differently: + // The run length is the distance to the next boundary + // For efficiency, we'll use a second pass or a different approach + + // For now, scan forward (not ideal but correct) + var run_len: u32 = 1u; + var j = idx + 1u; + while (j < numel && prefix_sum[j] == out_idx_plus1) { + run_len = run_len + 1u; + j = j + 1u; + } + counts[out_idx] = i32(run_len); + } + } +} diff --git a/src/runtime/wgpu/shaders/validate_eigenvalues_f32.wgsl b/src/runtime/wgpu/shaders/validate_eigenvalues_f32.wgsl new file mode 100644 index 00000000..1ae7906d --- /dev/null +++ b/src/runtime/wgpu/shaders/validate_eigenvalues_f32.wgsl @@ -0,0 +1,85 @@ +// Schur eigenvalue validation for f32 + +const WORKGROUP_SIZE: u32 = 256u; + +struct Params { + n: u32, + eps: f32, + _pad1: u32, + _pad2: u32, +} + +@group(0) @binding(0) var matrix_t: array; +@group(0) @binding(1) var result: array; // [has_error, error_value] +@group(0) @binding(2) var params: Params; + +// Check if a real eigenvalue is non-positive +fn check_real_eigenvalue(val: f32, eps: f32) -> bool { + return val <= eps; +} + +// Check if a 2x2 block represents non-positive real eigenvalues +// For 2x2 block [[a, b], [c, d]], eigenvalues are (a+d)/2 ± sqrt((a-d)²/4 + bc) +// If discriminant < 0, eigenvalues are complex (ok) +// If discriminant >= 0, check if real part is non-positive +fn check_2x2_block(a: f32, b: f32, c: f32, d: f32, eps: f32) -> bool { + let trace = a + d; + let det = a * d - b * c; + let disc = trace * trace - 4.0 * det; + + if disc < 0.0 { + // Complex eigenvalues - check real part + let real_part = trace / 2.0; + return real_part <= eps; + } else { + // Real eigenvalues + let sqrt_disc = sqrt(disc); + let lambda1 = (trace + sqrt_disc) / 2.0; + let lambda2 = (trace - sqrt_disc) / 2.0; + return lambda1 <= eps || lambda2 <= eps; + } +} + +@compute @workgroup_size(1) +fn validate_eigenvalues_f32(@builtin(global_invocation_id) gid: vec3) { + let n = params.n; + let eps = f32(params.eps); + + // Initialize result to "no error" + result[0] = 0.0; + result[1] = 0.0; + + var i: u32 = 0u; + while i < n { + let diag_idx = i * n + i; + + // Check if this is a 2x2 block (non-zero sub-diagonal) + if i + 1u < n { + let sub_diag = abs(matrix_t[(i + 1u) * n + i]); + if sub_diag > eps { + // 2x2 block + let a = matrix_t[i * n + i]; + let b = matrix_t[i * n + (i + 1u)]; + let c = matrix_t[(i + 1u) * n + i]; + let d = matrix_t[(i + 1u) * n + (i + 1u)]; + + if check_2x2_block(a, b, c, d, eps) { + result[0] = 1.0; + result[1] = (a + d) / 2.0; // Report real part + return; + } + i = i + 2u; + continue; + } + } + + // 1x1 block (real eigenvalue) + let eigenvalue = matrix_t[diag_idx]; + if check_real_eigenvalue(eigenvalue, eps) { + result[0] = 1.0; + result[1] = eigenvalue; + return; + } + i = i + 1u; + } +} diff --git a/src/runtime/wgpu/shaders/validate_indices.wgsl b/src/runtime/wgpu/shaders/validate_indices.wgsl new file mode 100644 index 00000000..49da5ae4 --- /dev/null +++ b/src/runtime/wgpu/shaders/validate_indices.wgsl @@ -0,0 +1,27 @@ +// Auto-generated index bounds validation kernel + +const WORKGROUP_SIZE: u32 = 256u; + +struct ValidateIndicesParams { + index_len: u32, + dim_size: u32, + _pad0: u32, + _pad1: u32, +} + +@group(0) @binding(0) var indices: array; +@group(0) @binding(1) var error_count: atomic; +@group(0) @binding(2) var params: ValidateIndicesParams; + +@compute @workgroup_size(256) +fn validate_indices(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= params.index_len) { + return; + } + + let index_val = indices[idx]; + if (index_val < 0 || u32(index_val) >= params.dim_size) { + atomicAdd(&error_count, 1u); + } +} diff --git a/src/runtime/wgpu/shaders/where_broadcast_cond_f32_f32.wgsl b/src/runtime/wgpu/shaders/where_broadcast_cond_f32_f32.wgsl new file mode 100644 index 00000000..64951f66 --- /dev/null +++ b/src/runtime/wgpu/shaders/where_broadcast_cond_f32_f32.wgsl @@ -0,0 +1,52 @@ +// where_broadcast_cond: condition=f32, output=f32 +// out[i] = cond[cond_offset] != 0.0 ? x[x_offset] : y[y_offset] (with broadcasting) + +struct WhereBroadcastParams { + numel: u32, + ndim: u32, + _pad0: u32, + _pad1: u32, +} + +@group(0) @binding(0) var bc_cond: array; +@group(0) @binding(1) var bc_x: array; +@group(0) @binding(2) var bc_y: array; +@group(0) @binding(3) var bc_out: array; +@group(0) @binding(4) var cond_strides: array; +@group(0) @binding(5) var x_strides: array; +@group(0) @binding(6) var y_strides: array; +@group(0) @binding(7) var out_shape: array; +@group(0) @binding(8) var bc_params: WhereBroadcastParams; + +fn compute_out_stride(d: u32, ndim: u32) -> u32 { + var stride: u32 = 1u; + for (var i: u32 = d + 1u; i < ndim; i = i + 1u) { + stride = stride * out_shape[i]; + } + return stride; +} + +@compute @workgroup_size(256) +fn where_broadcast_cond_f32_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= bc_params.numel) { + return; + } + + var remaining = idx; + var cond_offset: u32 = 0u; + var x_offset: u32 = 0u; + var y_offset: u32 = 0u; + + for (var d: u32 = 0u; d < bc_params.ndim; d = d + 1u) { + let s = compute_out_stride(d, bc_params.ndim); + let coord = remaining / s; + remaining = remaining % s; + cond_offset = cond_offset + coord * cond_strides[d]; + x_offset = x_offset + coord * x_strides[d]; + y_offset = y_offset + coord * y_strides[d]; + } + + let cond_val = bc_cond[cond_offset] != 0.0; + bc_out[idx] = select(bc_y[y_offset], bc_x[x_offset], cond_val); +} diff --git a/src/runtime/wgpu/shaders/where_broadcast_cond_f32_i32.wgsl b/src/runtime/wgpu/shaders/where_broadcast_cond_f32_i32.wgsl new file mode 100644 index 00000000..114593da --- /dev/null +++ b/src/runtime/wgpu/shaders/where_broadcast_cond_f32_i32.wgsl @@ -0,0 +1,52 @@ +// where_broadcast_cond: condition=f32, output=i32 +// out[i] = cond[cond_offset] != 0.0 ? x[x_offset] : y[y_offset] (with broadcasting) + +struct WhereBroadcastParams { + numel: u32, + ndim: u32, + _pad0: u32, + _pad1: u32, +} + +@group(0) @binding(0) var bc_cond: array; +@group(0) @binding(1) var bc_x: array; +@group(0) @binding(2) var bc_y: array; +@group(0) @binding(3) var bc_out: array; +@group(0) @binding(4) var cond_strides: array; +@group(0) @binding(5) var x_strides: array; +@group(0) @binding(6) var y_strides: array; +@group(0) @binding(7) var out_shape: array; +@group(0) @binding(8) var bc_params: WhereBroadcastParams; + +fn compute_out_stride(d: u32, ndim: u32) -> u32 { + var stride: u32 = 1u; + for (var i: u32 = d + 1u; i < ndim; i = i + 1u) { + stride = stride * out_shape[i]; + } + return stride; +} + +@compute @workgroup_size(256) +fn where_broadcast_cond_f32_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= bc_params.numel) { + return; + } + + var remaining = idx; + var cond_offset: u32 = 0u; + var x_offset: u32 = 0u; + var y_offset: u32 = 0u; + + for (var d: u32 = 0u; d < bc_params.ndim; d = d + 1u) { + let s = compute_out_stride(d, bc_params.ndim); + let coord = remaining / s; + remaining = remaining % s; + cond_offset = cond_offset + coord * cond_strides[d]; + x_offset = x_offset + coord * x_strides[d]; + y_offset = y_offset + coord * y_strides[d]; + } + + let cond_val = bc_cond[cond_offset] != 0.0; + bc_out[idx] = select(bc_y[y_offset], bc_x[x_offset], cond_val); +} diff --git a/src/runtime/wgpu/shaders/where_broadcast_cond_f32_u32.wgsl b/src/runtime/wgpu/shaders/where_broadcast_cond_f32_u32.wgsl new file mode 100644 index 00000000..1b58b0c6 --- /dev/null +++ b/src/runtime/wgpu/shaders/where_broadcast_cond_f32_u32.wgsl @@ -0,0 +1,52 @@ +// where_broadcast_cond: condition=f32, output=u32 +// out[i] = cond[cond_offset] != 0.0 ? x[x_offset] : y[y_offset] (with broadcasting) + +struct WhereBroadcastParams { + numel: u32, + ndim: u32, + _pad0: u32, + _pad1: u32, +} + +@group(0) @binding(0) var bc_cond: array; +@group(0) @binding(1) var bc_x: array; +@group(0) @binding(2) var bc_y: array; +@group(0) @binding(3) var bc_out: array; +@group(0) @binding(4) var cond_strides: array; +@group(0) @binding(5) var x_strides: array; +@group(0) @binding(6) var y_strides: array; +@group(0) @binding(7) var out_shape: array; +@group(0) @binding(8) var bc_params: WhereBroadcastParams; + +fn compute_out_stride(d: u32, ndim: u32) -> u32 { + var stride: u32 = 1u; + for (var i: u32 = d + 1u; i < ndim; i = i + 1u) { + stride = stride * out_shape[i]; + } + return stride; +} + +@compute @workgroup_size(256) +fn where_broadcast_cond_f32_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= bc_params.numel) { + return; + } + + var remaining = idx; + var cond_offset: u32 = 0u; + var x_offset: u32 = 0u; + var y_offset: u32 = 0u; + + for (var d: u32 = 0u; d < bc_params.ndim; d = d + 1u) { + let s = compute_out_stride(d, bc_params.ndim); + let coord = remaining / s; + remaining = remaining % s; + cond_offset = cond_offset + coord * cond_strides[d]; + x_offset = x_offset + coord * x_strides[d]; + y_offset = y_offset + coord * y_strides[d]; + } + + let cond_val = bc_cond[cond_offset] != 0.0; + bc_out[idx] = select(bc_y[y_offset], bc_x[x_offset], cond_val); +} diff --git a/src/runtime/wgpu/shaders/where_broadcast_cond_i32_f32.wgsl b/src/runtime/wgpu/shaders/where_broadcast_cond_i32_f32.wgsl new file mode 100644 index 00000000..8d13a0d1 --- /dev/null +++ b/src/runtime/wgpu/shaders/where_broadcast_cond_i32_f32.wgsl @@ -0,0 +1,52 @@ +// where_broadcast_cond: condition=i32, output=f32 +// out[i] = cond[cond_offset] != 0 ? x[x_offset] : y[y_offset] (with broadcasting) + +struct WhereBroadcastParams { + numel: u32, + ndim: u32, + _pad0: u32, + _pad1: u32, +} + +@group(0) @binding(0) var bc_cond: array; +@group(0) @binding(1) var bc_x: array; +@group(0) @binding(2) var bc_y: array; +@group(0) @binding(3) var bc_out: array; +@group(0) @binding(4) var cond_strides: array; +@group(0) @binding(5) var x_strides: array; +@group(0) @binding(6) var y_strides: array; +@group(0) @binding(7) var out_shape: array; +@group(0) @binding(8) var bc_params: WhereBroadcastParams; + +fn compute_out_stride(d: u32, ndim: u32) -> u32 { + var stride: u32 = 1u; + for (var i: u32 = d + 1u; i < ndim; i = i + 1u) { + stride = stride * out_shape[i]; + } + return stride; +} + +@compute @workgroup_size(256) +fn where_broadcast_cond_i32_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= bc_params.numel) { + return; + } + + var remaining = idx; + var cond_offset: u32 = 0u; + var x_offset: u32 = 0u; + var y_offset: u32 = 0u; + + for (var d: u32 = 0u; d < bc_params.ndim; d = d + 1u) { + let s = compute_out_stride(d, bc_params.ndim); + let coord = remaining / s; + remaining = remaining % s; + cond_offset = cond_offset + coord * cond_strides[d]; + x_offset = x_offset + coord * x_strides[d]; + y_offset = y_offset + coord * y_strides[d]; + } + + let cond_val = bc_cond[cond_offset] != 0; + bc_out[idx] = select(bc_y[y_offset], bc_x[x_offset], cond_val); +} diff --git a/src/runtime/wgpu/shaders/where_broadcast_cond_i32_i32.wgsl b/src/runtime/wgpu/shaders/where_broadcast_cond_i32_i32.wgsl new file mode 100644 index 00000000..166f4b93 --- /dev/null +++ b/src/runtime/wgpu/shaders/where_broadcast_cond_i32_i32.wgsl @@ -0,0 +1,52 @@ +// where_broadcast_cond: condition=i32, output=i32 +// out[i] = cond[cond_offset] != 0 ? x[x_offset] : y[y_offset] (with broadcasting) + +struct WhereBroadcastParams { + numel: u32, + ndim: u32, + _pad0: u32, + _pad1: u32, +} + +@group(0) @binding(0) var bc_cond: array; +@group(0) @binding(1) var bc_x: array; +@group(0) @binding(2) var bc_y: array; +@group(0) @binding(3) var bc_out: array; +@group(0) @binding(4) var cond_strides: array; +@group(0) @binding(5) var x_strides: array; +@group(0) @binding(6) var y_strides: array; +@group(0) @binding(7) var out_shape: array; +@group(0) @binding(8) var bc_params: WhereBroadcastParams; + +fn compute_out_stride(d: u32, ndim: u32) -> u32 { + var stride: u32 = 1u; + for (var i: u32 = d + 1u; i < ndim; i = i + 1u) { + stride = stride * out_shape[i]; + } + return stride; +} + +@compute @workgroup_size(256) +fn where_broadcast_cond_i32_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= bc_params.numel) { + return; + } + + var remaining = idx; + var cond_offset: u32 = 0u; + var x_offset: u32 = 0u; + var y_offset: u32 = 0u; + + for (var d: u32 = 0u; d < bc_params.ndim; d = d + 1u) { + let s = compute_out_stride(d, bc_params.ndim); + let coord = remaining / s; + remaining = remaining % s; + cond_offset = cond_offset + coord * cond_strides[d]; + x_offset = x_offset + coord * x_strides[d]; + y_offset = y_offset + coord * y_strides[d]; + } + + let cond_val = bc_cond[cond_offset] != 0; + bc_out[idx] = select(bc_y[y_offset], bc_x[x_offset], cond_val); +} diff --git a/src/runtime/wgpu/shaders/where_broadcast_cond_i32_u32.wgsl b/src/runtime/wgpu/shaders/where_broadcast_cond_i32_u32.wgsl new file mode 100644 index 00000000..0a75178e --- /dev/null +++ b/src/runtime/wgpu/shaders/where_broadcast_cond_i32_u32.wgsl @@ -0,0 +1,52 @@ +// where_broadcast_cond: condition=i32, output=u32 +// out[i] = cond[cond_offset] != 0 ? x[x_offset] : y[y_offset] (with broadcasting) + +struct WhereBroadcastParams { + numel: u32, + ndim: u32, + _pad0: u32, + _pad1: u32, +} + +@group(0) @binding(0) var bc_cond: array; +@group(0) @binding(1) var bc_x: array; +@group(0) @binding(2) var bc_y: array; +@group(0) @binding(3) var bc_out: array; +@group(0) @binding(4) var cond_strides: array; +@group(0) @binding(5) var x_strides: array; +@group(0) @binding(6) var y_strides: array; +@group(0) @binding(7) var out_shape: array; +@group(0) @binding(8) var bc_params: WhereBroadcastParams; + +fn compute_out_stride(d: u32, ndim: u32) -> u32 { + var stride: u32 = 1u; + for (var i: u32 = d + 1u; i < ndim; i = i + 1u) { + stride = stride * out_shape[i]; + } + return stride; +} + +@compute @workgroup_size(256) +fn where_broadcast_cond_i32_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= bc_params.numel) { + return; + } + + var remaining = idx; + var cond_offset: u32 = 0u; + var x_offset: u32 = 0u; + var y_offset: u32 = 0u; + + for (var d: u32 = 0u; d < bc_params.ndim; d = d + 1u) { + let s = compute_out_stride(d, bc_params.ndim); + let coord = remaining / s; + remaining = remaining % s; + cond_offset = cond_offset + coord * cond_strides[d]; + x_offset = x_offset + coord * x_strides[d]; + y_offset = y_offset + coord * y_strides[d]; + } + + let cond_val = bc_cond[cond_offset] != 0; + bc_out[idx] = select(bc_y[y_offset], bc_x[x_offset], cond_val); +} diff --git a/src/runtime/wgpu/shaders/where_broadcast_cond_u32_f32.wgsl b/src/runtime/wgpu/shaders/where_broadcast_cond_u32_f32.wgsl new file mode 100644 index 00000000..1fcf6f5b --- /dev/null +++ b/src/runtime/wgpu/shaders/where_broadcast_cond_u32_f32.wgsl @@ -0,0 +1,52 @@ +// where_broadcast_cond: condition=u32, output=f32 +// out[i] = cond[cond_offset] != 0 ? x[x_offset] : y[y_offset] (with broadcasting) + +struct WhereBroadcastParams { + numel: u32, + ndim: u32, + _pad0: u32, + _pad1: u32, +} + +@group(0) @binding(0) var bc_cond: array; +@group(0) @binding(1) var bc_x: array; +@group(0) @binding(2) var bc_y: array; +@group(0) @binding(3) var bc_out: array; +@group(0) @binding(4) var cond_strides: array; +@group(0) @binding(5) var x_strides: array; +@group(0) @binding(6) var y_strides: array; +@group(0) @binding(7) var out_shape: array; +@group(0) @binding(8) var bc_params: WhereBroadcastParams; + +fn compute_out_stride(d: u32, ndim: u32) -> u32 { + var stride: u32 = 1u; + for (var i: u32 = d + 1u; i < ndim; i = i + 1u) { + stride = stride * out_shape[i]; + } + return stride; +} + +@compute @workgroup_size(256) +fn where_broadcast_cond_u32_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= bc_params.numel) { + return; + } + + var remaining = idx; + var cond_offset: u32 = 0u; + var x_offset: u32 = 0u; + var y_offset: u32 = 0u; + + for (var d: u32 = 0u; d < bc_params.ndim; d = d + 1u) { + let s = compute_out_stride(d, bc_params.ndim); + let coord = remaining / s; + remaining = remaining % s; + cond_offset = cond_offset + coord * cond_strides[d]; + x_offset = x_offset + coord * x_strides[d]; + y_offset = y_offset + coord * y_strides[d]; + } + + let cond_val = bc_cond[cond_offset] != 0u; + bc_out[idx] = select(bc_y[y_offset], bc_x[x_offset], cond_val); +} diff --git a/src/runtime/wgpu/shaders/where_broadcast_cond_u32_i32.wgsl b/src/runtime/wgpu/shaders/where_broadcast_cond_u32_i32.wgsl new file mode 100644 index 00000000..2de4db24 --- /dev/null +++ b/src/runtime/wgpu/shaders/where_broadcast_cond_u32_i32.wgsl @@ -0,0 +1,52 @@ +// where_broadcast_cond: condition=u32, output=i32 +// out[i] = cond[cond_offset] != 0 ? x[x_offset] : y[y_offset] (with broadcasting) + +struct WhereBroadcastParams { + numel: u32, + ndim: u32, + _pad0: u32, + _pad1: u32, +} + +@group(0) @binding(0) var bc_cond: array; +@group(0) @binding(1) var bc_x: array; +@group(0) @binding(2) var bc_y: array; +@group(0) @binding(3) var bc_out: array; +@group(0) @binding(4) var cond_strides: array; +@group(0) @binding(5) var x_strides: array; +@group(0) @binding(6) var y_strides: array; +@group(0) @binding(7) var out_shape: array; +@group(0) @binding(8) var bc_params: WhereBroadcastParams; + +fn compute_out_stride(d: u32, ndim: u32) -> u32 { + var stride: u32 = 1u; + for (var i: u32 = d + 1u; i < ndim; i = i + 1u) { + stride = stride * out_shape[i]; + } + return stride; +} + +@compute @workgroup_size(256) +fn where_broadcast_cond_u32_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= bc_params.numel) { + return; + } + + var remaining = idx; + var cond_offset: u32 = 0u; + var x_offset: u32 = 0u; + var y_offset: u32 = 0u; + + for (var d: u32 = 0u; d < bc_params.ndim; d = d + 1u) { + let s = compute_out_stride(d, bc_params.ndim); + let coord = remaining / s; + remaining = remaining % s; + cond_offset = cond_offset + coord * cond_strides[d]; + x_offset = x_offset + coord * x_strides[d]; + y_offset = y_offset + coord * y_strides[d]; + } + + let cond_val = bc_cond[cond_offset] != 0u; + bc_out[idx] = select(bc_y[y_offset], bc_x[x_offset], cond_val); +} diff --git a/src/runtime/wgpu/shaders/where_broadcast_cond_u32_u32.wgsl b/src/runtime/wgpu/shaders/where_broadcast_cond_u32_u32.wgsl new file mode 100644 index 00000000..736f6371 --- /dev/null +++ b/src/runtime/wgpu/shaders/where_broadcast_cond_u32_u32.wgsl @@ -0,0 +1,52 @@ +// where_broadcast_cond: condition=u32, output=u32 +// out[i] = cond[cond_offset] != 0 ? x[x_offset] : y[y_offset] (with broadcasting) + +struct WhereBroadcastParams { + numel: u32, + ndim: u32, + _pad0: u32, + _pad1: u32, +} + +@group(0) @binding(0) var bc_cond: array; +@group(0) @binding(1) var bc_x: array; +@group(0) @binding(2) var bc_y: array; +@group(0) @binding(3) var bc_out: array; +@group(0) @binding(4) var cond_strides: array; +@group(0) @binding(5) var x_strides: array; +@group(0) @binding(6) var y_strides: array; +@group(0) @binding(7) var out_shape: array; +@group(0) @binding(8) var bc_params: WhereBroadcastParams; + +fn compute_out_stride(d: u32, ndim: u32) -> u32 { + var stride: u32 = 1u; + for (var i: u32 = d + 1u; i < ndim; i = i + 1u) { + stride = stride * out_shape[i]; + } + return stride; +} + +@compute @workgroup_size(256) +fn where_broadcast_cond_u32_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx >= bc_params.numel) { + return; + } + + var remaining = idx; + var cond_offset: u32 = 0u; + var x_offset: u32 = 0u; + var y_offset: u32 = 0u; + + for (var d: u32 = 0u; d < bc_params.ndim; d = d + 1u) { + let s = compute_out_stride(d, bc_params.ndim); + let coord = remaining / s; + remaining = remaining % s; + cond_offset = cond_offset + coord * cond_strides[d]; + x_offset = x_offset + coord * x_strides[d]; + y_offset = y_offset + coord * y_strides[d]; + } + + let cond_val = bc_cond[cond_offset] != 0u; + bc_out[idx] = select(bc_y[y_offset], bc_x[x_offset], cond_val); +} diff --git a/src/runtime/wgpu/shaders/where_cond_f32_f32.wgsl b/src/runtime/wgpu/shaders/where_cond_f32_f32.wgsl new file mode 100644 index 00000000..1867addc --- /dev/null +++ b/src/runtime/wgpu/shaders/where_cond_f32_f32.wgsl @@ -0,0 +1,21 @@ +// where_cond: condition=f32, output=f32 +// out[i] = cond[i] != 0.0 ? x[i] : y[i] + +struct WhereParams { + numel: u32, +} + +@group(0) @binding(0) var where_cond_arr: array; +@group(0) @binding(1) var where_x: array; +@group(0) @binding(2) var where_y: array; +@group(0) @binding(3) var where_out: array; +@group(0) @binding(4) var where_params: WhereParams; + +@compute @workgroup_size(256) +fn where_cond_f32_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < where_params.numel) { + let cond_val = where_cond_arr[idx] != 0.0; + where_out[idx] = select(where_y[idx], where_x[idx], cond_val); + } +} diff --git a/src/runtime/wgpu/shaders/where_cond_f32_i32.wgsl b/src/runtime/wgpu/shaders/where_cond_f32_i32.wgsl new file mode 100644 index 00000000..0dcd1930 --- /dev/null +++ b/src/runtime/wgpu/shaders/where_cond_f32_i32.wgsl @@ -0,0 +1,21 @@ +// where_cond: condition=f32, output=i32 +// out[i] = cond[i] != 0.0 ? x[i] : y[i] + +struct WhereParams { + numel: u32, +} + +@group(0) @binding(0) var where_cond_arr: array; +@group(0) @binding(1) var where_x: array; +@group(0) @binding(2) var where_y: array; +@group(0) @binding(3) var where_out: array; +@group(0) @binding(4) var where_params: WhereParams; + +@compute @workgroup_size(256) +fn where_cond_f32_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < where_params.numel) { + let cond_val = where_cond_arr[idx] != 0.0; + where_out[idx] = select(where_y[idx], where_x[idx], cond_val); + } +} diff --git a/src/runtime/wgpu/shaders/where_cond_f32_u32.wgsl b/src/runtime/wgpu/shaders/where_cond_f32_u32.wgsl new file mode 100644 index 00000000..ba0e94da --- /dev/null +++ b/src/runtime/wgpu/shaders/where_cond_f32_u32.wgsl @@ -0,0 +1,21 @@ +// where_cond: condition=f32, output=u32 +// out[i] = cond[i] != 0.0 ? x[i] : y[i] + +struct WhereParams { + numel: u32, +} + +@group(0) @binding(0) var where_cond_arr: array; +@group(0) @binding(1) var where_x: array; +@group(0) @binding(2) var where_y: array; +@group(0) @binding(3) var where_out: array; +@group(0) @binding(4) var where_params: WhereParams; + +@compute @workgroup_size(256) +fn where_cond_f32_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < where_params.numel) { + let cond_val = where_cond_arr[idx] != 0.0; + where_out[idx] = select(where_y[idx], where_x[idx], cond_val); + } +} diff --git a/src/runtime/wgpu/shaders/where_cond_i32_f32.wgsl b/src/runtime/wgpu/shaders/where_cond_i32_f32.wgsl new file mode 100644 index 00000000..70a23214 --- /dev/null +++ b/src/runtime/wgpu/shaders/where_cond_i32_f32.wgsl @@ -0,0 +1,21 @@ +// where_cond: condition=i32, output=f32 +// out[i] = cond[i] != 0 ? x[i] : y[i] + +struct WhereParams { + numel: u32, +} + +@group(0) @binding(0) var where_cond_arr: array; +@group(0) @binding(1) var where_x: array; +@group(0) @binding(2) var where_y: array; +@group(0) @binding(3) var where_out: array; +@group(0) @binding(4) var where_params: WhereParams; + +@compute @workgroup_size(256) +fn where_cond_i32_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < where_params.numel) { + let cond_val = where_cond_arr[idx] != 0; + where_out[idx] = select(where_y[idx], where_x[idx], cond_val); + } +} diff --git a/src/runtime/wgpu/shaders/where_cond_i32_i32.wgsl b/src/runtime/wgpu/shaders/where_cond_i32_i32.wgsl new file mode 100644 index 00000000..15633cc7 --- /dev/null +++ b/src/runtime/wgpu/shaders/where_cond_i32_i32.wgsl @@ -0,0 +1,21 @@ +// where_cond: condition=i32, output=i32 +// out[i] = cond[i] != 0 ? x[i] : y[i] + +struct WhereParams { + numel: u32, +} + +@group(0) @binding(0) var where_cond_arr: array; +@group(0) @binding(1) var where_x: array; +@group(0) @binding(2) var where_y: array; +@group(0) @binding(3) var where_out: array; +@group(0) @binding(4) var where_params: WhereParams; + +@compute @workgroup_size(256) +fn where_cond_i32_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < where_params.numel) { + let cond_val = where_cond_arr[idx] != 0; + where_out[idx] = select(where_y[idx], where_x[idx], cond_val); + } +} diff --git a/src/runtime/wgpu/shaders/where_cond_i32_u32.wgsl b/src/runtime/wgpu/shaders/where_cond_i32_u32.wgsl new file mode 100644 index 00000000..5be675e3 --- /dev/null +++ b/src/runtime/wgpu/shaders/where_cond_i32_u32.wgsl @@ -0,0 +1,21 @@ +// where_cond: condition=i32, output=u32 +// out[i] = cond[i] != 0 ? x[i] : y[i] + +struct WhereParams { + numel: u32, +} + +@group(0) @binding(0) var where_cond_arr: array; +@group(0) @binding(1) var where_x: array; +@group(0) @binding(2) var where_y: array; +@group(0) @binding(3) var where_out: array; +@group(0) @binding(4) var where_params: WhereParams; + +@compute @workgroup_size(256) +fn where_cond_i32_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < where_params.numel) { + let cond_val = where_cond_arr[idx] != 0; + where_out[idx] = select(where_y[idx], where_x[idx], cond_val); + } +} diff --git a/src/runtime/wgpu/shaders/where_cond_u32_f32.wgsl b/src/runtime/wgpu/shaders/where_cond_u32_f32.wgsl new file mode 100644 index 00000000..ee9c7adf --- /dev/null +++ b/src/runtime/wgpu/shaders/where_cond_u32_f32.wgsl @@ -0,0 +1,21 @@ +// where_cond: condition=u32, output=f32 +// out[i] = cond[i] != 0 ? x[i] : y[i] + +struct WhereParams { + numel: u32, +} + +@group(0) @binding(0) var where_cond_arr: array; +@group(0) @binding(1) var where_x: array; +@group(0) @binding(2) var where_y: array; +@group(0) @binding(3) var where_out: array; +@group(0) @binding(4) var where_params: WhereParams; + +@compute @workgroup_size(256) +fn where_cond_u32_f32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < where_params.numel) { + let cond_val = where_cond_arr[idx] != 0u; + where_out[idx] = select(where_y[idx], where_x[idx], cond_val); + } +} diff --git a/src/runtime/wgpu/shaders/where_cond_u32_i32.wgsl b/src/runtime/wgpu/shaders/where_cond_u32_i32.wgsl new file mode 100644 index 00000000..c9d5d330 --- /dev/null +++ b/src/runtime/wgpu/shaders/where_cond_u32_i32.wgsl @@ -0,0 +1,21 @@ +// where_cond: condition=u32, output=i32 +// out[i] = cond[i] != 0 ? x[i] : y[i] + +struct WhereParams { + numel: u32, +} + +@group(0) @binding(0) var where_cond_arr: array; +@group(0) @binding(1) var where_x: array; +@group(0) @binding(2) var where_y: array; +@group(0) @binding(3) var where_out: array; +@group(0) @binding(4) var where_params: WhereParams; + +@compute @workgroup_size(256) +fn where_cond_u32_i32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < where_params.numel) { + let cond_val = where_cond_arr[idx] != 0u; + where_out[idx] = select(where_y[idx], where_x[idx], cond_val); + } +} diff --git a/src/runtime/wgpu/shaders/where_cond_u32_u32.wgsl b/src/runtime/wgpu/shaders/where_cond_u32_u32.wgsl new file mode 100644 index 00000000..0563c632 --- /dev/null +++ b/src/runtime/wgpu/shaders/where_cond_u32_u32.wgsl @@ -0,0 +1,21 @@ +// where_cond: condition=u32, output=u32 +// out[i] = cond[i] != 0 ? x[i] : y[i] + +struct WhereParams { + numel: u32, +} + +@group(0) @binding(0) var where_cond_arr: array; +@group(0) @binding(1) var where_x: array; +@group(0) @binding(2) var where_y: array; +@group(0) @binding(3) var where_out: array; +@group(0) @binding(4) var where_params: WhereParams; + +@compute @workgroup_size(256) +fn where_cond_u32_u32(@builtin(global_invocation_id) gid: vec3) { + let idx = gid.x; + if (idx < where_params.numel) { + let cond_val = where_cond_arr[idx] != 0u; + where_out[idx] = select(where_y[idx], where_x[idx], cond_val); + } +} diff --git a/src/runtime/wgpu/shaders/where_launcher.rs b/src/runtime/wgpu/shaders/where_launcher.rs index 65bc4ede..6b35d7ab 100644 --- a/src/runtime/wgpu/shaders/where_launcher.rs +++ b/src/runtime/wgpu/shaders/where_launcher.rs @@ -1,132 +1,171 @@ -//! Where (conditional select) WGSL kernel launchers -//! -//! Provides launchers for where_cond operations with multi-dtype support: -//! - `launch_where_op` - Legacy F32-only version for backward compatibility -//! - `launch_where_generic_op` - Generic condition dtype support (F32, I32, U32) -//! - `launch_where_broadcast_op` - Broadcast support with generic condition dtype - -use std::collections::HashMap; -use std::sync::{OnceLock, RwLock, RwLockReadGuard, RwLockWriteGuard}; - -// ============================================================================ -// Lock Helpers (Handle Poisoned Locks Gracefully) -// ============================================================================ - -/// Acquire read lock, recovering from poison if necessary. -fn read_lock(lock: &RwLock) -> RwLockReadGuard<'_, T> { - lock.read().unwrap_or_else(|poisoned| poisoned.into_inner()) -} - -/// Acquire write lock, recovering from poison if necessary. -fn write_lock(lock: &RwLock) -> RwLockWriteGuard<'_, T> { - lock.write() - .unwrap_or_else(|poisoned| poisoned.into_inner()) -} +//! Where (conditional select) WGSL kernel launchers. F32/I32/U32 supported. use wgpu::{Buffer, Queue}; -use super::generator::{dtype_suffix, generate_where_cond_shader}; use super::pipeline::{LayoutKey, PipelineCache, workgroup_count}; use crate::dtype::DType; -use crate::error::Result; +use crate::error::{Error, Result}; // ============================================================================ -// Shader Caching +// Static shaders — element-wise (4 storage + 1 uniform) // ============================================================================ -/// Cache for where_cond shader references (leaked once per cond_dtype+out_dtype combination) -static WHERE_SHADER_CACHE: OnceLock>> = - OnceLock::new(); - -/// Cache for where_cond module key references -static WHERE_MODULE_KEY_CACHE: OnceLock>> = - OnceLock::new(); +const WHERE_COND_F32_F32: &str = include_str!("where_cond_f32_f32.wgsl"); +const WHERE_COND_F32_I32: &str = include_str!("where_cond_f32_i32.wgsl"); +const WHERE_COND_F32_U32: &str = include_str!("where_cond_f32_u32.wgsl"); +const WHERE_COND_I32_F32: &str = include_str!("where_cond_i32_f32.wgsl"); +const WHERE_COND_I32_I32: &str = include_str!("where_cond_i32_i32.wgsl"); +const WHERE_COND_I32_U32: &str = include_str!("where_cond_i32_u32.wgsl"); +const WHERE_COND_U32_F32: &str = include_str!("where_cond_u32_f32.wgsl"); +const WHERE_COND_U32_I32: &str = include_str!("where_cond_u32_i32.wgsl"); +const WHERE_COND_U32_U32: &str = include_str!("where_cond_u32_u32.wgsl"); -/// Cache for where_cond entry point references -static WHERE_ENTRY_CACHE: OnceLock>> = - OnceLock::new(); - -/// Get or generate where_cond shader for specific cond_dtype and out_dtype. -fn get_or_leak_where_shader(cond_dtype: DType, out_dtype: DType) -> Result<&'static str> { - let cache = WHERE_SHADER_CACHE.get_or_init(|| RwLock::new(HashMap::new())); - - { - let read_guard = read_lock(cache); - if let Some(&shader_ref) = read_guard.get(&(cond_dtype, out_dtype)) { - return Ok(shader_ref); - } - } - - let shader = generate_where_cond_shader(cond_dtype, out_dtype)?; - let leaked: &'static str = Box::leak(shader.into_boxed_str()); +// ============================================================================ +// Static shaders — broadcast (8 storage + 1 uniform) +// ============================================================================ - let mut write_guard = write_lock(cache); - write_guard.insert((cond_dtype, out_dtype), leaked); +const WHERE_BC_F32_F32: &str = include_str!("where_broadcast_cond_f32_f32.wgsl"); +const WHERE_BC_F32_I32: &str = include_str!("where_broadcast_cond_f32_i32.wgsl"); +const WHERE_BC_F32_U32: &str = include_str!("where_broadcast_cond_f32_u32.wgsl"); +const WHERE_BC_I32_F32: &str = include_str!("where_broadcast_cond_i32_f32.wgsl"); +const WHERE_BC_I32_I32: &str = include_str!("where_broadcast_cond_i32_i32.wgsl"); +const WHERE_BC_I32_U32: &str = include_str!("where_broadcast_cond_i32_u32.wgsl"); +const WHERE_BC_U32_F32: &str = include_str!("where_broadcast_cond_u32_f32.wgsl"); +const WHERE_BC_U32_I32: &str = include_str!("where_broadcast_cond_u32_i32.wgsl"); +const WHERE_BC_U32_U32: &str = include_str!("where_broadcast_cond_u32_u32.wgsl"); - Ok(leaked) -} - -/// Get module key for where_cond shader. -fn get_or_leak_where_module_key(cond_dtype: DType, out_dtype: DType) -> Result<&'static str> { - let cache = WHERE_MODULE_KEY_CACHE.get_or_init(|| RwLock::new(HashMap::new())); +// ============================================================================ +// Shader dispatch helpers +// ============================================================================ - { - let read_guard = read_lock(cache); - if let Some(&key_ref) = read_guard.get(&(cond_dtype, out_dtype)) { - return Ok(key_ref); +/// Returns (shader, module_key, entry_point) for element-wise where_cond. +fn where_shader_info( + cond_dtype: DType, + out_dtype: DType, +) -> Result<(&'static str, &'static str, &'static str)> { + Ok(match (cond_dtype, out_dtype) { + (DType::F32, DType::F32) => ( + WHERE_COND_F32_F32, + "where_cond_f32_f32", + "where_cond_f32_f32", + ), + (DType::F32, DType::I32) => ( + WHERE_COND_F32_I32, + "where_cond_f32_i32", + "where_cond_f32_i32", + ), + (DType::F32, DType::U32) => ( + WHERE_COND_F32_U32, + "where_cond_f32_u32", + "where_cond_f32_u32", + ), + (DType::I32, DType::F32) => ( + WHERE_COND_I32_F32, + "where_cond_i32_f32", + "where_cond_i32_f32", + ), + (DType::I32, DType::I32) => ( + WHERE_COND_I32_I32, + "where_cond_i32_i32", + "where_cond_i32_i32", + ), + (DType::I32, DType::U32) => ( + WHERE_COND_I32_U32, + "where_cond_i32_u32", + "where_cond_i32_u32", + ), + (DType::U32, DType::F32) => ( + WHERE_COND_U32_F32, + "where_cond_u32_f32", + "where_cond_u32_f32", + ), + (DType::U32, DType::I32) => ( + WHERE_COND_U32_I32, + "where_cond_u32_i32", + "where_cond_u32_i32", + ), + (DType::U32, DType::U32) => ( + WHERE_COND_U32_U32, + "where_cond_u32_u32", + "where_cond_u32_u32", + ), + _ => { + return Err(Error::UnsupportedDType { + dtype: cond_dtype, + op: "where_cond (WebGPU)", + }); } - } - - let cond_suffix = dtype_suffix(cond_dtype)?; - let out_suffix = dtype_suffix(out_dtype)?; - let key = format!("where_cond_{}_{}", cond_suffix, out_suffix); - let leaked: &'static str = Box::leak(key.into_boxed_str()); - - let mut write_guard = write_lock(cache); - write_guard.insert((cond_dtype, out_dtype), leaked); - - Ok(leaked) + }) } -/// Get entry point name for where_cond operation. -fn get_or_leak_where_entry( +/// Returns (shader, module_key, entry_point) for broadcast where_cond. +fn where_broadcast_shader_info( cond_dtype: DType, out_dtype: DType, - broadcast: bool, -) -> Result<&'static str> { - let cache = WHERE_ENTRY_CACHE.get_or_init(|| RwLock::new(HashMap::new())); - - { - let read_guard = read_lock(cache); - if let Some(&entry_ref) = read_guard.get(&(cond_dtype, out_dtype, broadcast)) { - return Ok(entry_ref); +) -> Result<(&'static str, &'static str, &'static str)> { + Ok(match (cond_dtype, out_dtype) { + (DType::F32, DType::F32) => ( + WHERE_BC_F32_F32, + "where_broadcast_cond_f32_f32", + "where_broadcast_cond_f32_f32", + ), + (DType::F32, DType::I32) => ( + WHERE_BC_F32_I32, + "where_broadcast_cond_f32_i32", + "where_broadcast_cond_f32_i32", + ), + (DType::F32, DType::U32) => ( + WHERE_BC_F32_U32, + "where_broadcast_cond_f32_u32", + "where_broadcast_cond_f32_u32", + ), + (DType::I32, DType::F32) => ( + WHERE_BC_I32_F32, + "where_broadcast_cond_i32_f32", + "where_broadcast_cond_i32_f32", + ), + (DType::I32, DType::I32) => ( + WHERE_BC_I32_I32, + "where_broadcast_cond_i32_i32", + "where_broadcast_cond_i32_i32", + ), + (DType::I32, DType::U32) => ( + WHERE_BC_I32_U32, + "where_broadcast_cond_i32_u32", + "where_broadcast_cond_i32_u32", + ), + (DType::U32, DType::F32) => ( + WHERE_BC_U32_F32, + "where_broadcast_cond_u32_f32", + "where_broadcast_cond_u32_f32", + ), + (DType::U32, DType::I32) => ( + WHERE_BC_U32_I32, + "where_broadcast_cond_u32_i32", + "where_broadcast_cond_u32_i32", + ), + (DType::U32, DType::U32) => ( + WHERE_BC_U32_U32, + "where_broadcast_cond_u32_u32", + "where_broadcast_cond_u32_u32", + ), + _ => { + return Err(Error::UnsupportedDType { + dtype: cond_dtype, + op: "where_broadcast_cond (WebGPU)", + }); } - } - - let cond_suffix = dtype_suffix(cond_dtype)?; - let out_suffix = dtype_suffix(out_dtype)?; - let prefix = if broadcast { - "where_broadcast_cond" - } else { - "where_cond" - }; - let entry = format!("{}_{}_{}", prefix, cond_suffix, out_suffix); - let leaked: &'static str = Box::leak(entry.into_boxed_str()); - - let mut write_guard = write_lock(cache); - write_guard.insert((cond_dtype, out_dtype, broadcast), leaked); - - Ok(leaked) + }) } // ============================================================================ // Kernel Launchers // ============================================================================ -/// Launch where conditional operation kernel. +/// Launch where conditional operation kernel (F32-only legacy wrapper). /// -/// Computes `out[i] = cond[i] ? x[i] : y[i]` for all elements. -/// This is the legacy F32-only version for backward compatibility. +/// Computes `out[i] = cond[i] != 0 ? x[i] : y[i]` for all elements. +/// Delegates to `launch_where_generic_op` with F32 condition dtype. #[allow(clippy::too_many_arguments)] pub fn launch_where_op( cache: &PipelineCache, @@ -139,7 +178,6 @@ pub fn launch_where_op( numel: usize, dtype: DType, ) -> Result<()> { - // Delegate to generic version with F32 condition launch_where_generic_op( cache, queue, @@ -171,9 +209,7 @@ pub fn launch_where_generic_op( cond_dtype: DType, out_dtype: DType, ) -> Result<()> { - let shader = get_or_leak_where_shader(cond_dtype, out_dtype)?; - let module_key = get_or_leak_where_module_key(cond_dtype, out_dtype)?; - let entry_point = get_or_leak_where_entry(cond_dtype, out_dtype, false)?; + let (shader, module_key, entry_point) = where_shader_info(cond_dtype, out_dtype)?; let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { @@ -208,7 +244,7 @@ pub fn launch_where_generic_op( /// Launch broadcast where conditional operation kernel. /// /// Computes `out[i] = cond[cond_offset] != 0 ? x[x_offset] : y[y_offset]` -/// with broadcasting support. +/// with broadcasting support via per-dimension stride buffers. #[allow(clippy::too_many_arguments)] pub fn launch_where_broadcast_op( cache: &PipelineCache, @@ -226,9 +262,7 @@ pub fn launch_where_broadcast_op( cond_dtype: DType, out_dtype: DType, ) -> Result<()> { - let shader = get_or_leak_where_shader(cond_dtype, out_dtype)?; - let module_key = get_or_leak_where_module_key(cond_dtype, out_dtype)?; - let entry_point = get_or_leak_where_entry(cond_dtype, out_dtype, true)?; + let (shader, module_key, entry_point) = where_broadcast_shader_info(cond_dtype, out_dtype)?; let module = cache.get_or_create_module(module_key, shader); let layout = cache.get_or_create_layout(LayoutKey { diff --git a/src/runtime/wgpu/sparse/ic0.rs b/src/runtime/wgpu/sparse/ic0.rs index fb08cc42..40299b34 100644 --- a/src/runtime/wgpu/sparse/ic0.rs +++ b/src/runtime/wgpu/sparse/ic0.rs @@ -3,7 +3,6 @@ use wgpu::{BindGroupDescriptor, BindGroupEntry, BufferUsages}; use super::super::ops::helpers::get_tensor_buffer; -use super::super::shaders::generator::sparse_linalg::generate_ic0_level_shader; use super::super::{WgpuClient, WgpuRuntime}; use super::common::{ WORKGROUP_SIZE, cast_i64_to_i32_gpu, create_ilu_ic_layout, extract_lower_wgpu, @@ -17,6 +16,8 @@ use crate::error::Result; use crate::sparse::CsrData; use crate::tensor::Tensor; +const IC0_LEVEL_F32: &str = include_str!("../shaders/sparse_ic0_level_f32.wgsl"); + /// IC(0) factorization for WebGPU. pub fn ic0_wgpu( client: &WgpuClient, @@ -111,14 +112,13 @@ fn launch_ic0_level( n: usize, diagonal_shift: f32, ) -> Result<()> { - let shader_source = generate_ic0_level_shader(DType::F32)?; let module = client .pipeline_cache - .get_or_create_module_from_source("ic0_level_f32", &shader_source); + .get_or_create_module("ic0_level_f32", IC0_LEVEL_F32); let layout = create_ilu_ic_layout(&client.wgpu_device); - let pipeline = client.pipeline_cache.get_or_create_dynamic_pipeline( + let pipeline = client.pipeline_cache.get_or_create_pipeline( "ic0_level_f32", "ic0_level_f32", &module, diff --git a/src/runtime/wgpu/sparse/ilu0.rs b/src/runtime/wgpu/sparse/ilu0.rs index f9f76047..e016a125 100644 --- a/src/runtime/wgpu/sparse/ilu0.rs +++ b/src/runtime/wgpu/sparse/ilu0.rs @@ -3,9 +3,6 @@ use wgpu::{BindGroupDescriptor, BindGroupEntry, BufferUsages}; use super::super::ops::helpers::get_tensor_buffer; -use super::super::shaders::generator::sparse_linalg::{ - generate_find_diag_indices_shader, generate_ilu0_level_shader, -}; use super::super::{WgpuClient, WgpuRuntime}; use super::common::{ WORKGROUP_SIZE, cast_i64_to_i32_gpu, create_ilu_ic_layout, split_lu_wgpu, validate_wgpu_dtype, @@ -19,6 +16,9 @@ use crate::error::{Error, Result}; use crate::sparse::CsrData; use crate::tensor::Tensor; +const FIND_DIAG_INDICES: &str = include_str!("../shaders/sparse_find_diag_indices.wgsl"); +const ILU0_LEVEL_F32: &str = include_str!("../shaders/sparse_ilu0_level_f32.wgsl"); + /// ILU(0) factorization for WebGPU. pub fn ilu0_wgpu( client: &WgpuClient, @@ -224,10 +224,9 @@ pub(super) fn launch_find_diag_indices( diag_indices: &Tensor, n: usize, ) -> Result<()> { - let shader_source = generate_find_diag_indices_shader(); let module = client .pipeline_cache - .get_or_create_module_from_source("find_diag_indices", &shader_source); + .get_or_create_module("find_diag_indices", FIND_DIAG_INDICES); // Create bind group layout let layout = client @@ -281,7 +280,7 @@ pub(super) fn launch_find_diag_indices( ], }); - let pipeline = client.pipeline_cache.get_or_create_dynamic_pipeline( + let pipeline = client.pipeline_cache.get_or_create_pipeline( "find_diag_indices", "find_diag_indices", &module, @@ -363,14 +362,13 @@ pub(super) fn launch_ilu0_level( n: usize, diagonal_shift: f32, ) -> Result<()> { - let shader_source = generate_ilu0_level_shader(DType::F32)?; let module = client .pipeline_cache - .get_or_create_module_from_source("ilu0_level_f32", &shader_source); + .get_or_create_module("ilu0_level_f32", ILU0_LEVEL_F32); let layout = create_ilu_ic_layout(&client.wgpu_device); - let pipeline = client.pipeline_cache.get_or_create_dynamic_pipeline( + let pipeline = client.pipeline_cache.get_or_create_pipeline( "ilu0_level_f32", "ilu0_level_f32", &module, diff --git a/src/runtime/wgpu/sparse/triangular_solve.rs b/src/runtime/wgpu/sparse/triangular_solve.rs index 0653cb16..10966142 100644 --- a/src/runtime/wgpu/sparse/triangular_solve.rs +++ b/src/runtime/wgpu/sparse/triangular_solve.rs @@ -3,20 +3,22 @@ use wgpu::{BindGroupDescriptor, BindGroupEntry, BufferUsages}; use super::super::ops::helpers::get_tensor_buffer; -use super::super::shaders::generator::sparse_linalg::{ - generate_sparse_trsv_lower_multi_rhs_shader, generate_sparse_trsv_lower_shader, - generate_sparse_trsv_upper_multi_rhs_shader, generate_sparse_trsv_upper_shader, -}; use super::super::{WgpuClient, WgpuRuntime}; use super::common::{WORKGROUP_SIZE, cast_i64_to_i32_gpu, create_trsv_layout, validate_wgpu_dtype}; use crate::algorithm::sparse_linalg::validate_triangular_solve_dims; use crate::algorithm::sparse_linalg::{compute_levels_lower, compute_levels_upper, flatten_levels}; -use crate::dtype::DType; use crate::error::{Error, Result}; use crate::runtime::Runtime; use crate::sparse::CsrData; use crate::tensor::Tensor; +const TRSV_LOWER_F32: &str = include_str!("../shaders/sparse_trsv_lower_f32.wgsl"); +const TRSV_UPPER_F32: &str = include_str!("../shaders/sparse_trsv_upper_f32.wgsl"); +const TRSV_LOWER_MULTI_RHS_F32: &str = + include_str!("../shaders/sparse_trsv_lower_multi_rhs_f32.wgsl"); +const TRSV_UPPER_MULTI_RHS_F32: &str = + include_str!("../shaders/sparse_trsv_upper_multi_rhs_f32.wgsl"); + /// Sparse triangular solve for WebGPU. /// Supports both single RHS (b is 1D vector) and multi-RHS (b is 2D matrix [n, nrhs]). pub fn sparse_solve_triangular_wgpu( @@ -72,12 +74,7 @@ pub fn sparse_solve_triangular_wgpu( // Allocate output and copy b into it on GPU (must be separate buffer) let x = Tensor::::zeros(b.shape(), dtype, &client.device_id); let copy_size = b.numel() * dtype.size_in_bytes(); - WgpuRuntime::copy_within_device( - b.storage().ptr(), - x.storage().ptr(), - copy_size, - &client.device_id, - )?; + WgpuRuntime::copy_within_device(b.ptr(), x.ptr(), copy_size, &client.device_id)?; // Process each level for level in 0..schedule.num_levels { @@ -174,14 +171,13 @@ fn launch_sparse_trsv_lower( n: usize, unit_diagonal: bool, ) -> Result<()> { - let shader_source = generate_sparse_trsv_lower_shader(DType::F32)?; let module = client .pipeline_cache - .get_or_create_module_from_source("sparse_trsv_lower_f32", &shader_source); + .get_or_create_module("sparse_trsv_lower_f32", TRSV_LOWER_F32); let layout = create_trsv_layout(&client.wgpu_device); - let pipeline = client.pipeline_cache.get_or_create_dynamic_pipeline( + let pipeline = client.pipeline_cache.get_or_create_pipeline( "sparse_trsv_lower_f32", "sparse_trsv_lower_level_f32", &module, @@ -279,14 +275,13 @@ fn launch_sparse_trsv_upper( x: &Tensor, n: usize, ) -> Result<()> { - let shader_source = generate_sparse_trsv_upper_shader(DType::F32)?; let module = client .pipeline_cache - .get_or_create_module_from_source("sparse_trsv_upper_f32", &shader_source); + .get_or_create_module("sparse_trsv_upper_f32", TRSV_UPPER_F32); let layout = create_trsv_layout(&client.wgpu_device); - let pipeline = client.pipeline_cache.get_or_create_dynamic_pipeline( + let pipeline = client.pipeline_cache.get_or_create_pipeline( "sparse_trsv_upper_f32", "sparse_trsv_upper_level_f32", &module, @@ -381,14 +376,13 @@ fn launch_sparse_trsv_lower_multi_rhs( n: usize, unit_diagonal: bool, ) -> Result<()> { - let shader_source = generate_sparse_trsv_lower_multi_rhs_shader(DType::F32)?; let module = client .pipeline_cache - .get_or_create_module_from_source("sparse_trsv_lower_multi_rhs_f32", &shader_source); + .get_or_create_module("sparse_trsv_lower_multi_rhs_f32", TRSV_LOWER_MULTI_RHS_F32); let layout = create_trsv_layout(&client.wgpu_device); - let pipeline = client.pipeline_cache.get_or_create_dynamic_pipeline( + let pipeline = client.pipeline_cache.get_or_create_pipeline( "sparse_trsv_lower_multi_rhs_f32", "sparse_trsv_lower_level_multi_rhs_f32", &module, @@ -493,14 +487,13 @@ fn launch_sparse_trsv_upper_multi_rhs( x: &Tensor, n: usize, ) -> Result<()> { - let shader_source = generate_sparse_trsv_upper_multi_rhs_shader(DType::F32)?; let module = client .pipeline_cache - .get_or_create_module_from_source("sparse_trsv_upper_multi_rhs_f32", &shader_source); + .get_or_create_module("sparse_trsv_upper_multi_rhs_f32", TRSV_UPPER_MULTI_RHS_F32); let layout = create_trsv_layout(&client.wgpu_device); - let pipeline = client.pipeline_cache.get_or_create_dynamic_pipeline( + let pipeline = client.pipeline_cache.get_or_create_pipeline( "sparse_trsv_upper_multi_rhs_f32", "sparse_trsv_upper_level_multi_rhs_f32", &module, diff --git a/src/runtime/wgpu/statistics/mod.rs b/src/runtime/wgpu/statistics/mod.rs index e05038b6..02ba823a 100644 --- a/src/runtime/wgpu/statistics/mod.rs +++ b/src/runtime/wgpu/statistics/mod.rs @@ -36,7 +36,7 @@ use crate::dtype::DType; use crate::error::{Error, Result}; use crate::ops::TypeConversionOps; use crate::runtime::RuntimeClient; -use crate::runtime::statistics_common::compute_bin_edges_f64; +use crate::runtime::common::statistics_common::compute_bin_edges_f64; use crate::runtime::wgpu::{WgpuClient, WgpuRuntime}; use crate::tensor::Tensor; @@ -94,7 +94,7 @@ pub(crate) fn tensor_to_f64(client: &WgpuClient, t: &Tensor) -> Res } // Get buffer from tensor - let src_buffer = get_buffer(t.storage().ptr()) + let src_buffer = get_buffer(t.ptr()) .ok_or_else(|| Error::Internal("Failed to get tensor buffer".to_string()))?; // Create staging buffer and copy diff --git a/src/runtime/wgpu/statistics/mode.rs b/src/runtime/wgpu/statistics/mode.rs index 0cb1a0e1..4269a6bc 100644 --- a/src/runtime/wgpu/statistics/mode.rs +++ b/src/runtime/wgpu/statistics/mode.rs @@ -4,7 +4,6 @@ use crate::dtype::DType; use crate::error::{Error, Result}; use crate::ops::{SortingOps, TypeConversionOps, compute_reduce_strides, reduce_dim_output_shape}; use crate::runtime::wgpu::client::get_buffer; -use crate::runtime::wgpu::shaders::generator::is_wgpu_supported; use crate::runtime::wgpu::shaders::launch_mode_dim; use crate::runtime::wgpu::{WgpuClient, WgpuRuntime}; use crate::runtime::{RuntimeClient, ensure_contiguous, normalize_dim}; @@ -27,7 +26,7 @@ pub fn mode_impl( let dtype = a.dtype(); // Validate dtype is supported by native shader - let native_supported = is_wgpu_supported(dtype); + let native_supported = matches!(dtype, DType::F32 | DType::I32 | DType::U32); if !native_supported { // For unsupported dtypes (F64, F16, BF16, I64, etc.), cast to F32, compute, cast back @@ -88,11 +87,11 @@ pub fn mode_impl( let mode_counts = Tensor::::empty(&out_shape, DType::I32, client.device()); // Get wgpu buffers - let sorted_buf = get_buffer(sorted_contig.storage().ptr()) + let sorted_buf = get_buffer(sorted_contig.ptr()) .ok_or_else(|| Error::Internal("Failed to get sorted buffer".to_string()))?; - let values_buf = get_buffer(mode_values.storage().ptr()) + let values_buf = get_buffer(mode_values.ptr()) .ok_or_else(|| Error::Internal("Failed to get mode_values buffer".to_string()))?; - let counts_buf = get_buffer(mode_counts.storage().ptr()) + let counts_buf = get_buffer(mode_counts.ptr()) .ok_or_else(|| Error::Internal("Failed to get mode_counts buffer".to_string()))?; // Create params buffer: [outer_size, reduce_size, inner_size, pad] diff --git a/src/runtime/wgpu/statistics/moments.rs b/src/runtime/wgpu/statistics/moments.rs index ff9fab9f..12f29203 100644 --- a/src/runtime/wgpu/statistics/moments.rs +++ b/src/runtime/wgpu/statistics/moments.rs @@ -1,7 +1,7 @@ //! Higher-order moment statistics for WebGPU runtime (skewness, kurtosis) use crate::error::Result; -use crate::runtime::statistics_common; +use crate::runtime::common::statistics_common; use crate::runtime::wgpu::{WgpuClient, WgpuRuntime}; use crate::tensor::Tensor; diff --git a/src/runtime/wgpu/statistics/quantile.rs b/src/runtime/wgpu/statistics/quantile.rs index cad10c72..edcee325 100644 --- a/src/runtime/wgpu/statistics/quantile.rs +++ b/src/runtime/wgpu/statistics/quantile.rs @@ -5,7 +5,7 @@ use crate::error::{Error, Result}; use crate::ops::{ BinaryOps, IndexingOps, ScalarOps, SortingOps, TypeConversionOps, reduce_dim_output_shape, }; -use crate::runtime::statistics_common::Interpolation; +use crate::runtime::common::statistics_common::Interpolation; use crate::runtime::wgpu::{WgpuClient, WgpuRuntime}; use crate::runtime::{RuntimeClient, normalize_dim}; use crate::tensor::Tensor; @@ -94,7 +94,7 @@ pub fn quantile_impl( // Calculate quantile indices using shared logic let (floor_idx, ceil_idx, frac) = - crate::runtime::statistics_common::compute_quantile_indices(q, dim_size); + crate::runtime::common::statistics_common::compute_quantile_indices(q, dim_size); // Check for empty output let out_numel = out_shape.iter().product::(); diff --git a/src/sparse/coo/conversion.rs b/src/sparse/coo/conversion.rs index 1438b2b5..8fca2502 100644 --- a/src/sparse/coo/conversion.rs +++ b/src/sparse/coo/conversion.rs @@ -1,12 +1,13 @@ //! COO format conversion: to_csr, to_csc use super::CooData; +use crate::dtype::DType; use crate::error::Result; use crate::runtime::Runtime; use crate::sparse::{CscData, CsrData, SparseStorage}; use crate::tensor::Tensor; -impl CooData { +impl> CooData { /// Convert to CSR format /// /// This is an efficient conversion that: diff --git a/src/sparse/coo/core.rs b/src/sparse/coo/core.rs index 607bfe87..c2832046 100644 --- a/src/sparse/coo/core.rs +++ b/src/sparse/coo/core.rs @@ -17,7 +17,7 @@ pub struct CooData { pub(crate) sorted: bool, } -impl CooData { +impl> CooData { /// Create a new COO matrix from components /// /// # Arguments @@ -122,7 +122,7 @@ impl CooData { self.sorted = sorted; } } -impl SparseStorage for CooData { +impl> SparseStorage for CooData { fn format(&self) -> SparseFormat { SparseFormat::Coo } @@ -148,7 +148,7 @@ impl SparseStorage for CooData { } /// Create COO data from host arrays (CPU) -impl CooData { +impl> CooData { /// Create COO matrix from host slices /// /// # Arguments diff --git a/src/sparse/coo/elementwise/add.rs b/src/sparse/coo/elementwise/add.rs index 3495d56c..3a4c4921 100644 --- a/src/sparse/coo/elementwise/add.rs +++ b/src/sparse/coo/elementwise/add.rs @@ -1,11 +1,12 @@ //! Element-wise addition for COO matrices use super::super::CooData; +use crate::dtype::DType; use crate::error::{Error, Result}; use crate::runtime::Runtime; use crate::sparse::{SparseOps, SparseStorage}; -impl CooData { +impl> CooData { /// Element-wise addition: C = A + B /// /// Computes the sum of two sparse matrices with the same shape. diff --git a/src/sparse/coo/elementwise/div.rs b/src/sparse/coo/elementwise/div.rs index 94e158c0..2f76d760 100644 --- a/src/sparse/coo/elementwise/div.rs +++ b/src/sparse/coo/elementwise/div.rs @@ -1,13 +1,13 @@ //! Element-wise division for COO matrices use super::super::CooData; -use crate::dtype::Element; +use crate::dtype::{DType, Element}; use crate::error::{Error, Result}; use crate::runtime::Runtime; use crate::sparse::SparseStorage; use crate::tensor::Tensor; -impl CooData { +impl> CooData { /// Element-wise division: C = A ./ B /// /// Computes the element-wise quotient of two sparse matrices with the same shape. diff --git a/src/sparse/coo/elementwise/mul.rs b/src/sparse/coo/elementwise/mul.rs index 1381b20c..75694920 100644 --- a/src/sparse/coo/elementwise/mul.rs +++ b/src/sparse/coo/elementwise/mul.rs @@ -1,11 +1,12 @@ //! Element-wise multiplication (Hadamard product) for COO matrices use super::super::CooData; +use crate::dtype::DType; use crate::error::{Error, Result}; use crate::runtime::Runtime; use crate::sparse::{SparseOps, SparseStorage}; -impl CooData { +impl> CooData { /// Element-wise multiplication (Hadamard product): C = A .* B /// /// Computes the element-wise product of two sparse matrices with the same shape. diff --git a/src/sparse/coo/elementwise/sub.rs b/src/sparse/coo/elementwise/sub.rs index ffe00f64..29346f19 100644 --- a/src/sparse/coo/elementwise/sub.rs +++ b/src/sparse/coo/elementwise/sub.rs @@ -1,11 +1,12 @@ //! Element-wise subtraction for COO matrices use super::super::CooData; +use crate::dtype::DType; use crate::error::{Error, Result}; use crate::runtime::Runtime; use crate::sparse::{SparseOps, SparseStorage}; -impl CooData { +impl> CooData { /// Element-wise subtraction: C = A - B /// /// Computes the difference of two sparse matrices with the same shape. diff --git a/src/sparse/coo/matmul.rs b/src/sparse/coo/matmul.rs index 1d4d1f2f..793d9d38 100644 --- a/src/sparse/coo/matmul.rs +++ b/src/sparse/coo/matmul.rs @@ -1,11 +1,12 @@ //! COO matrix multiplication: spmv, spmm, transpose use super::CooData; +use crate::dtype::DType; use crate::error::Result; use crate::runtime::Runtime; use crate::tensor::Tensor; -impl CooData { +impl> CooData { /// Sparse matrix-vector multiplication: y = A * x /// /// Converts to CSR format (optimal for SpMV) and performs the multiplication. diff --git a/src/sparse/csc/conversion.rs b/src/sparse/csc/conversion.rs index 4e0a6512..b148af76 100644 --- a/src/sparse/csc/conversion.rs +++ b/src/sparse/csc/conversion.rs @@ -1,12 +1,13 @@ //! CSC format conversion: to_coo, to_csr use super::CscData; +use crate::dtype::DType; use crate::error::Result; use crate::runtime::Runtime; use crate::sparse::{CooData, CsrData, SparseStorage}; use crate::tensor::Tensor; -impl CscData { +impl> CscData { /// Convert to COO format /// /// Expands the compressed column pointers into explicit column indices. diff --git a/src/sparse/csc/core.rs b/src/sparse/csc/core.rs index f1851ad1..4cf4debb 100644 --- a/src/sparse/csc/core.rs +++ b/src/sparse/csc/core.rs @@ -17,7 +17,7 @@ pub struct CscData { pub(crate) shape: [usize; 2], } -impl CscData { +impl> CscData { /// Create a new CSC matrix from components pub fn new( col_ptrs: Tensor, @@ -225,7 +225,7 @@ impl CscData { } } -impl SparseStorage for CscData { +impl> SparseStorage for CscData { fn format(&self) -> SparseFormat { SparseFormat::Csc } @@ -250,7 +250,7 @@ impl SparseStorage for CscData { } } -impl CscData { +impl> CscData { /// Create CSC matrix from host slices pub fn from_slices( col_ptrs: &[i64], @@ -307,7 +307,7 @@ impl CscData { // SparseScaling Implementation for CscData // ============================================================================ -impl SparseScaling for CscData { +impl> SparseScaling for CscData { fn row_norms(&self, norm: NormType) -> Result> { let [nrows, ncols] = self.shape; let device = self.values.device(); diff --git a/src/sparse/csc/elementwise/add.rs b/src/sparse/csc/elementwise/add.rs index 8ef43e94..f950f8da 100644 --- a/src/sparse/csc/elementwise/add.rs +++ b/src/sparse/csc/elementwise/add.rs @@ -1,11 +1,12 @@ //! Element-wise addition for CSC matrices use super::super::CscData; +use crate::dtype::DType; use crate::error::{Error, Result}; use crate::runtime::Runtime; use crate::sparse::{SparseOps, SparseStorage}; -impl CscData { +impl> CscData { /// Element-wise addition: C = A + B /// /// Computes the sum of two sparse matrices with the same shape. diff --git a/src/sparse/csc/elementwise/div.rs b/src/sparse/csc/elementwise/div.rs index 3a8f14ce..51b3c749 100644 --- a/src/sparse/csc/elementwise/div.rs +++ b/src/sparse/csc/elementwise/div.rs @@ -1,13 +1,13 @@ //! Element-wise division for CSC matrices use super::super::CscData; -use crate::dtype::Element; +use crate::dtype::{DType, Element}; use crate::error::{Error, Result}; use crate::runtime::Runtime; use crate::sparse::SparseStorage; use crate::tensor::Tensor; -impl CscData { +impl> CscData { /// Element-wise division: C = A ./ B /// /// Computes the element-wise quotient of two sparse matrices with the same shape. diff --git a/src/sparse/csc/elementwise/mul.rs b/src/sparse/csc/elementwise/mul.rs index 474a49b2..125067c2 100644 --- a/src/sparse/csc/elementwise/mul.rs +++ b/src/sparse/csc/elementwise/mul.rs @@ -1,11 +1,12 @@ //! Element-wise multiplication (Hadamard product) for CSC matrices use super::super::CscData; +use crate::dtype::DType; use crate::error::{Error, Result}; use crate::runtime::Runtime; use crate::sparse::{SparseOps, SparseStorage}; -impl CscData { +impl> CscData { /// Element-wise multiplication (Hadamard product): C = A .* B /// /// Computes the element-wise product of two sparse matrices with the same shape. diff --git a/src/sparse/csc/elementwise/sub.rs b/src/sparse/csc/elementwise/sub.rs index f9ebe1ad..7cfff308 100644 --- a/src/sparse/csc/elementwise/sub.rs +++ b/src/sparse/csc/elementwise/sub.rs @@ -1,11 +1,12 @@ //! Element-wise subtraction for CSC matrices use super::super::CscData; +use crate::dtype::DType; use crate::error::{Error, Result}; use crate::runtime::Runtime; use crate::sparse::{SparseOps, SparseStorage}; -impl CscData { +impl> CscData { /// Element-wise subtraction: C = A - B /// /// Computes the difference of two sparse matrices with the same shape. diff --git a/src/sparse/csc/matmul.rs b/src/sparse/csc/matmul.rs index b22f9f3d..75ab76ec 100644 --- a/src/sparse/csc/matmul.rs +++ b/src/sparse/csc/matmul.rs @@ -1,12 +1,13 @@ //! CSC matrix multiplication: spmv, spmm use super::CscData; +use crate::dtype::DType; use crate::error::Result; use crate::runtime::Runtime; use crate::sparse::CsrData; use crate::tensor::Tensor; -impl CscData { +impl> CscData { /// Sparse matrix-vector multiplication: y = A * x /// /// Converts to CSR format (optimal for SpMV) and performs the multiplication. diff --git a/src/sparse/csr/conversion.rs b/src/sparse/csr/conversion.rs index 9a99a19d..2b9864ba 100644 --- a/src/sparse/csr/conversion.rs +++ b/src/sparse/csr/conversion.rs @@ -1,12 +1,13 @@ //! CSR format conversion: to_coo, to_csc use super::CsrData; +use crate::dtype::DType; use crate::error::Result; use crate::runtime::Runtime; use crate::sparse::{CooData, CscData, SparseStorage}; use crate::tensor::Tensor; -impl CsrData { +impl> CsrData { /// Convert to COO format /// /// Expands the compressed row pointers into explicit row indices. diff --git a/src/sparse/csr/core.rs b/src/sparse/csr/core.rs index 4b91bdd0..1dfd650f 100644 --- a/src/sparse/csr/core.rs +++ b/src/sparse/csr/core.rs @@ -16,7 +16,7 @@ pub struct CsrData { pub(crate) shape: [usize; 2], } -impl CsrData { +impl> CsrData { /// Create a new CSR matrix from components /// /// # Arguments @@ -288,7 +288,7 @@ impl CsrData { } } -impl SparseStorage for CsrData { +impl> SparseStorage for CsrData { fn format(&self) -> SparseFormat { SparseFormat::Csr } @@ -315,7 +315,7 @@ impl SparseStorage for CsrData { } /// Create CSR data from host arrays -impl CsrData { +impl> CsrData { /// Create CSR matrix from host slices /// /// # Arguments diff --git a/src/sparse/csr/elementwise.rs b/src/sparse/csr/elementwise.rs index 2d9b6969..4103e7ce 100644 --- a/src/sparse/csr/elementwise.rs +++ b/src/sparse/csr/elementwise.rs @@ -4,13 +4,13 @@ //! via the SparseOps trait, enabling GPU acceleration when available. use super::CsrData; -use crate::dtype::Element; +use crate::dtype::{DType, Element}; use crate::error::{Error, Result}; use crate::ops::ScalarOps; use crate::runtime::Runtime; use crate::sparse::{SparseOps, SparseStorage}; -impl CsrData { +impl> CsrData { /// Element-wise addition: C = A + B /// /// Computes the sum of two sparse matrices with the same shape. diff --git a/src/sparse/csr/matmul.rs b/src/sparse/csr/matmul.rs index fb29802a..951c99c0 100644 --- a/src/sparse/csr/matmul.rs +++ b/src/sparse/csr/matmul.rs @@ -1,13 +1,13 @@ //! CSR matrix multiplication: spmv, spmm use super::CsrData; -use crate::dtype::Element; +use crate::dtype::{DType, Element}; use crate::error::{Error, Result}; use crate::runtime::Runtime; use crate::sparse::{CscData, SparseStorage}; use crate::tensor::Tensor; -impl CsrData { +impl> CsrData { /// Sparse matrix-vector multiplication: y = A * x /// /// Computes the product of this sparse matrix with a dense vector. diff --git a/src/sparse/mod.rs b/src/sparse/mod.rs index fef2c685..b1261045 100644 --- a/src/sparse/mod.rs +++ b/src/sparse/mod.rs @@ -62,11 +62,14 @@ mod csc; mod csr; mod format; mod ops; +pub mod structured; mod tensor; +pub use crate::ops::traits::Sparse24Ops; pub use coo::CooData; pub use csc::CscData; pub use csr::CsrData; pub use format::{SparseFormat, SparseStorage}; pub use ops::{NormType, SparseOps, SparseScaling}; +pub use structured::Sparse24Tensor; pub use tensor::SparseTensor; diff --git a/src/sparse/ops.rs b/src/sparse/ops.rs index ec2cb5a7..ec13c770 100644 --- a/src/sparse/ops.rs +++ b/src/sparse/ops.rs @@ -2,6 +2,7 @@ //! //! Defines the interface for sparse tensor operations that backends implement. +use crate::dtype::DType; use crate::error::Result; use crate::runtime::Runtime; use crate::tensor::Tensor; @@ -58,7 +59,7 @@ use super::{CscData, CsrData, SparseTensor}; /// # } /// # Ok::<(), numr::error::Error>(()) /// ``` -pub trait SparseOps: Sized { +pub trait SparseOps>: Sized { // ========================================================================= // Low-Level Format-Specific Operations (Backend Implementation Required) // ========================================================================= @@ -888,7 +889,7 @@ mod tests { #[test] fn test_sparse_ops_trait_exists() { // Trait compiles correctly - fn _accepts_sparse_ops>(_: &T) {} + fn _accepts_sparse_ops, T: SparseOps>(_: &T) {} } #[test] diff --git a/src/sparse/structured.rs b/src/sparse/structured.rs new file mode 100644 index 00000000..46abadbc --- /dev/null +++ b/src/sparse/structured.rs @@ -0,0 +1,231 @@ +//! 2:4 Structured sparsity format +//! +//! NVIDIA Ampere+ format where exactly 2 of every 4 consecutive elements are zero, +//! enabling 2x GEMM throughput via sparse tensor cores. +//! +//! The compressed representation stores only the 2 non-zero values per group of 4, +//! plus 2-bit metadata indicating which positions were kept. + +use crate::dtype::DType; +use crate::error::{Error, Result}; +use crate::runtime::Runtime; +use crate::tensor::Tensor; + +/// 2:4 structured sparse tensor +/// +/// Stores a matrix in compressed 2:4 format where exactly 2 out of every 4 +/// consecutive elements along the K dimension are non-zero. +/// +/// # Layout +/// +/// For an `[M, K]` dense matrix: +/// - `compressed_values`: `[M, K/2]` — the 2 kept values per group of 4 +/// - `metadata`: `[M, K/16]` as U32 — 2-bit indices packed into 32-bit words +/// (each U32 holds metadata for 16 groups of 4 = 64 elements) +/// +/// # Metadata encoding +/// +/// For each group of 4 elements, 2 bits encode which 2 of 4 positions are kept. +/// There are C(4,2) = 6 valid patterns, encoded as: +/// - 0b00: positions 0,1 +/// - 0b01: positions 0,2 +/// - 0b10: positions 0,3 +/// - 0b11: positions 1,2 +/// - 0b100: positions 1,3 (but we only use 2 bits, so we need a different encoding) +/// +/// Actually, NVIDIA uses a different encoding: each group stores a 4-bit mask where +/// exactly 2 bits are set, indicating which positions are kept. We pack 8 such masks +/// per U32 (8 × 4 bits = 32 bits), so metadata shape is `[M, ceil(K/4/8)]` = `[M, K/32]`. +/// +/// Revised: We use 4 bits per group (bitmask with exactly 2 bits set). +/// 8 groups per U32 → metadata shape `[M, K/32]` (since K/4 groups, 8 groups per U32). +/// If K is not divisible by 32, the last U32 is partially used. +#[derive(Debug, Clone)] +pub struct Sparse24Tensor { + /// Compressed non-zero values, shape [M, K/2] + pub(crate) compressed_values: Tensor, + /// Packed metadata bitmasks, shape [M, ceil(K/4 / 8)] as U32 + /// Each U32 contains 8 groups × 4 bits = 32 bits + pub(crate) metadata: Tensor, + /// Original dense shape [M, K] + pub(crate) original_shape: [usize; 2], + /// Data type of the compressed values + pub(crate) dtype: DType, +} + +impl> Sparse24Tensor { + /// Create a Sparse24Tensor from pre-built components + /// + /// # Arguments + /// * `compressed_values` - Shape [M, K/2], the non-zero values + /// * `metadata` - Shape [M, meta_cols] as U32, packed bitmasks + /// * `original_shape` - The original dense shape [M, K] + pub fn new( + compressed_values: Tensor, + metadata: Tensor, + original_shape: [usize; 2], + ) -> Result { + let [m, k] = original_shape; + + // K must be divisible by 4 + if k % 4 != 0 { + return Err(Error::InvalidArgument { + arg: "original_shape", + reason: format!("K dimension ({k}) must be divisible by 4 for 2:4 sparsity"), + }); + } + + // Validate compressed_values shape + let expected_val_shape = [m, k / 2]; + if compressed_values.shape() != expected_val_shape { + return Err(Error::ShapeMismatch { + expected: expected_val_shape.to_vec(), + got: compressed_values.shape().to_vec(), + }); + } + + // Validate metadata shape + let num_groups = k / 4; + let meta_cols = (num_groups + 7) / 8; // ceil(num_groups / 8) + let expected_meta_shape = [m, meta_cols]; + if metadata.shape() != expected_meta_shape { + return Err(Error::ShapeMismatch { + expected: expected_meta_shape.to_vec(), + got: metadata.shape().to_vec(), + }); + } + + // Metadata must be U32 + if metadata.dtype() != DType::U32 { + return Err(Error::DTypeMismatch { + lhs: DType::U32, + rhs: metadata.dtype(), + }); + } + + let dtype = compressed_values.dtype(); + + Ok(Self { + compressed_values, + metadata, + original_shape, + dtype, + }) + } + + /// Returns the original dense shape [M, K] + #[inline] + pub fn shape(&self) -> [usize; 2] { + self.original_shape + } + + /// Returns M (number of rows) + #[inline] + pub fn nrows(&self) -> usize { + self.original_shape[0] + } + + /// Returns K (original number of columns) + #[inline] + pub fn ncols(&self) -> usize { + self.original_shape[1] + } + + /// Returns the data type + #[inline] + pub fn dtype(&self) -> DType { + self.dtype + } + + /// Returns a reference to the compressed values tensor [M, K/2] + #[inline] + pub fn compressed_values(&self) -> &Tensor { + &self.compressed_values + } + + /// Returns a reference to the metadata tensor [M, meta_cols] as U32 + #[inline] + pub fn metadata(&self) -> &Tensor { + &self.metadata + } + + /// Returns the number of non-zero elements (always M * K/2) + #[inline] + pub fn nnz(&self) -> usize { + self.original_shape[0] * (self.original_shape[1] / 2) + } + + /// Returns the compression ratio (always 2.0 for 2:4) + #[inline] + pub fn compression_ratio(&self) -> f64 { + 2.0 + } + + /// Number of groups of 4 per row + #[inline] + pub fn groups_per_row(&self) -> usize { + self.original_shape[1] / 4 + } + + /// Number of U32 metadata words per row + #[inline] + pub fn meta_cols(&self) -> usize { + (self.groups_per_row() + 7) / 8 + } + + /// Validate that the 2:4 structure is correct: + /// each metadata group has exactly 2 bits set in its 4-bit nibble + pub fn is_valid(&self) -> bool + where + R: Runtime, + { + let meta_data: Vec = self.metadata.to_vec(); + let num_groups = self.groups_per_row(); + + for row in 0..self.nrows() { + for g in 0..num_groups { + let word_idx = g / 8; + let nibble_idx = g % 8; + let word = meta_data[row * self.meta_cols() + word_idx]; + let nibble = (word >> (nibble_idx * 4)) & 0xF; + if nibble.count_ones() != 2 { + return false; + } + } + } + true + } +} + +/// Compute the metadata column count for a given K dimension +#[inline] +pub fn meta_cols_for_k(k: usize) -> usize { + let num_groups = k / 4; + (num_groups + 7) / 8 +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_meta_cols_for_k() { + assert_eq!(meta_cols_for_k(4), 1); // 1 group, 1 word + assert_eq!(meta_cols_for_k(8), 1); // 2 groups, 1 word + assert_eq!(meta_cols_for_k(32), 1); // 8 groups, 1 word + assert_eq!(meta_cols_for_k(36), 2); // 9 groups, 2 words + assert_eq!(meta_cols_for_k(64), 2); // 16 groups, 2 words + } + + #[test] + fn test_k_must_be_divisible_by_4() { + use crate::runtime::cpu::{CpuDevice, CpuRuntime}; + let device = CpuDevice::new(); + + // K=5 should fail + let vals = Tensor::::from_slice(&[1.0f32, 2.0], &[1, 2], &device); + let meta = Tensor::::from_slice(&[0u32], &[1, 1], &device); + let result = Sparse24Tensor::new(vals, meta, [1, 5]); + assert!(result.is_err()); + } +} diff --git a/src/sparse/tensor/conversion.rs b/src/sparse/tensor/conversion.rs index 563adc68..dc60cf74 100644 --- a/src/sparse/tensor/conversion.rs +++ b/src/sparse/tensor/conversion.rs @@ -1,13 +1,13 @@ //! SparseTensor format conversion: to_coo, to_csr, to_csc use super::SparseTensor; -use crate::dtype::Element; +use crate::dtype::{DType, Element}; use crate::error::{Error, Result}; use crate::runtime::Runtime; use crate::sparse::SparseFormat; use crate::tensor::Tensor; -impl SparseTensor { +impl> SparseTensor { // ========================================================================= // Format Conversion // ========================================================================= diff --git a/src/sparse/tensor/core.rs b/src/sparse/tensor/core.rs index c26742bb..7b49d90e 100644 --- a/src/sparse/tensor/core.rs +++ b/src/sparse/tensor/core.rs @@ -65,7 +65,7 @@ pub enum SparseTensor { Csc(CscData), } -impl SparseTensor { +impl> SparseTensor { // ========================================================================= // Constructors // ========================================================================= @@ -302,7 +302,7 @@ impl SparseTensor { } } -impl std::fmt::Display for SparseTensor { +impl> std::fmt::Display for SparseTensor { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, diff --git a/src/sparse/tensor/elementwise/add.rs b/src/sparse/tensor/elementwise/add.rs index 3ff12520..e93f58ae 100644 --- a/src/sparse/tensor/elementwise/add.rs +++ b/src/sparse/tensor/elementwise/add.rs @@ -1,10 +1,11 @@ //! Element-wise addition operation for sparse tensors +use crate::dtype::DType; use crate::error::{Error, Result}; use crate::runtime::Runtime; use crate::sparse::{SparseOps, SparseTensor}; -impl SparseTensor { +impl> SparseTensor { /// Element-wise addition: C = A + B /// /// Computes the sum of two sparse tensors with the same shape. diff --git a/src/sparse/tensor/elementwise/div.rs b/src/sparse/tensor/elementwise/div.rs index 0e19d03a..15318bd3 100644 --- a/src/sparse/tensor/elementwise/div.rs +++ b/src/sparse/tensor/elementwise/div.rs @@ -1,10 +1,11 @@ //! Element-wise division operation for sparse tensors +use crate::dtype::DType; use crate::error::{Error, Result}; use crate::runtime::Runtime; use crate::sparse::{SparseOps, SparseTensor}; -impl SparseTensor { +impl> SparseTensor { /// Element-wise division: C = A ./ B /// /// Computes the element-wise quotient of two sparse tensors with the same shape. diff --git a/src/sparse/tensor/elementwise/mul.rs b/src/sparse/tensor/elementwise/mul.rs index 6d0d3fa2..f845647c 100644 --- a/src/sparse/tensor/elementwise/mul.rs +++ b/src/sparse/tensor/elementwise/mul.rs @@ -1,10 +1,11 @@ //! Element-wise multiplication operation for sparse tensors +use crate::dtype::DType; use crate::error::{Error, Result}; use crate::runtime::Runtime; use crate::sparse::{SparseOps, SparseTensor}; -impl SparseTensor { +impl> SparseTensor { /// Element-wise multiplication (Hadamard product): C = A .* B /// /// Computes the element-wise product of two sparse tensors with the same shape. diff --git a/src/sparse/tensor/elementwise/scalar.rs b/src/sparse/tensor/elementwise/scalar.rs index a6bd61f4..4df49cc4 100644 --- a/src/sparse/tensor/elementwise/scalar.rs +++ b/src/sparse/tensor/elementwise/scalar.rs @@ -1,11 +1,12 @@ //! Scalar operations for sparse tensors +use crate::dtype::DType; use crate::error::Result; use crate::ops::ScalarOps; use crate::runtime::Runtime; use crate::sparse::SparseTensor; -impl SparseTensor { +impl> SparseTensor { /// Scalar multiplication: C = A * scalar /// /// Multiplies all non-zero values by a scalar constant. diff --git a/src/sparse/tensor/elementwise/sub.rs b/src/sparse/tensor/elementwise/sub.rs index 33fbd748..896afccf 100644 --- a/src/sparse/tensor/elementwise/sub.rs +++ b/src/sparse/tensor/elementwise/sub.rs @@ -1,11 +1,12 @@ //! Element-wise subtraction operation for sparse tensors +use crate::dtype::DType; use crate::error::{Error, Result}; use crate::ops::ScalarOps; use crate::runtime::Runtime; use crate::sparse::{SparseOps, SparseTensor}; -impl SparseTensor { +impl> SparseTensor { /// Element-wise subtraction: C = A - B /// /// Computes the difference of two sparse tensors with the same shape. diff --git a/src/sparse/tensor/matmul.rs b/src/sparse/tensor/matmul.rs index 7a33bbf2..55af299f 100644 --- a/src/sparse/tensor/matmul.rs +++ b/src/sparse/tensor/matmul.rs @@ -1,11 +1,12 @@ //! SparseTensor matrix multiplication: spmv, spmm use super::SparseTensor; +use crate::dtype::DType; use crate::error::Result; use crate::runtime::Runtime; use crate::tensor::Tensor; -impl SparseTensor { +impl> SparseTensor { /// Sparse matrix-vector multiplication: y = A * x /// /// Computes the product of this sparse matrix with a dense vector. diff --git a/src/tensor/core.rs b/src/tensor/core.rs index 8add1c1a..9af23014 100644 --- a/src/tensor/core.rs +++ b/src/tensor/core.rs @@ -1,7 +1,7 @@ //! Core Tensor type use super::{Layout, Storage, TensorId}; -use crate::dtype::{DType, Element}; +use crate::dtype::{DType, DataType, Element}; use crate::error::{Error, Result}; use crate::runtime::Runtime; use std::fmt; @@ -38,6 +38,10 @@ pub struct Tensor { layout: Layout, } +// ============================================================================ +// Generic methods — work with ANY R::DType via DataType trait +// ============================================================================ + impl Tensor { /// Create a tensor from storage and layout pub fn from_parts(storage: Storage, layout: Layout) -> Self { @@ -48,63 +52,6 @@ impl Tensor { } } - /// Create a tensor from a slice of data - /// - /// # Panics - /// - /// Panics if `data.len()` does not equal the product of the `shape` dimensions. - /// For a fallible alternative, use [`Self::try_from_slice`]. - /// - /// # Example - /// - /// ``` - /// # use numr::prelude::*; - /// # let device = CpuDevice::new(); - /// let tensor = Tensor::::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2], &device); - /// # Ok::<(), numr::error::Error>(()) - /// ``` - #[track_caller] - pub fn from_slice(data: &[T], shape: &[usize], device: &R::Device) -> Self { - Self::try_from_slice(data, shape, device) - .unwrap_or_else(|e| panic!("Tensor::from_slice failed: {e}")) - } - - /// Create a tensor from a slice of data (fallible version) - /// - /// Returns an error if `data.len()` does not equal the product of the `shape` dimensions, - /// or if memory allocation fails. - /// - /// # Example - /// - /// ``` - /// # use numr::prelude::*; - /// # let device = CpuDevice::new(); - /// let tensor = Tensor::::try_from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2], &device)?; - /// # Ok::<(), numr::error::Error>(()) - /// ``` - pub fn try_from_slice( - data: &[T], - shape: &[usize], - device: &R::Device, - ) -> Result { - let expected_len: usize = shape.iter().product(); - if data.len() != expected_len { - return Err(Error::ShapeMismatch { - expected: shape.to_vec(), - got: vec![data.len()], - }); - } - - let storage = Storage::from_slice(data, device)?; - let layout = Layout::contiguous(shape); - - Ok(Self { - id: TensorId::new(), - storage, - layout, - }) - } - /// Create an uninitialized tensor /// /// # Safety @@ -113,13 +60,13 @@ impl Tensor { /// # Panics /// Panics if allocation fails. Use [`Self::try_empty`] in fallible contexts. #[track_caller] - pub fn empty(shape: &[usize], dtype: DType, device: &R::Device) -> Self { + pub fn empty(shape: &[usize], dtype: R::DType, device: &R::Device) -> Self { Self::try_empty(shape, dtype, device) .unwrap_or_else(|e| panic!("Tensor::empty failed: {e}")) } /// Create an uninitialized tensor (fallible version) - pub fn try_empty(shape: &[usize], dtype: DType, device: &R::Device) -> Result { + pub fn try_empty(shape: &[usize], dtype: R::DType, device: &R::Device) -> Result { let len: usize = shape.iter().product(); let storage = Storage::new(len, dtype, device)?; let layout = Layout::contiguous(shape); @@ -131,128 +78,11 @@ impl Tensor { }) } - /// Create a tensor filled with zeros - /// - /// This properly initializes memory to zero on all backends (CPU and GPU). - #[track_caller] - pub fn zeros(shape: &[usize], dtype: DType, device: &R::Device) -> Self { - Self::try_zeros(shape, dtype, device) - .unwrap_or_else(|e| panic!("Tensor::zeros failed: {e}")) - } - - /// Create a tensor filled with zeros (fallible version) - pub fn try_zeros(shape: &[usize], dtype: DType, device: &R::Device) -> Result { - Self::try_full_scalar(shape, dtype, 0.0, device) - } - - /// Create a tensor filled with ones - #[track_caller] - pub fn ones(shape: &[usize], dtype: DType, device: &R::Device) -> Self { - Self::try_ones(shape, dtype, device).unwrap_or_else(|e| panic!("Tensor::ones failed: {e}")) - } - - /// Create a tensor filled with ones (fallible version) - pub fn try_ones(shape: &[usize], dtype: DType, device: &R::Device) -> Result { - Self::try_full_scalar(shape, dtype, 1.0, device) - } - - /// Create a tensor filled with a scalar value - /// - /// The scalar is converted to the target dtype. - #[track_caller] - pub fn full_scalar(shape: &[usize], dtype: DType, value: f64, device: &R::Device) -> Self { - Self::try_full_scalar(shape, dtype, value, device) - .unwrap_or_else(|e| panic!("Tensor::full_scalar failed: {e}")) - } - - /// Create a tensor filled with a scalar value (fallible version) - pub fn try_full_scalar( - shape: &[usize], - dtype: DType, - value: f64, - device: &R::Device, - ) -> Result { - // Helper to convert a typed Vec to bytes safely. - // Allocates with correct alignment for T, then copies to u8 vec. - #[inline] - fn typed_to_bytes(v: Vec) -> Vec { - bytemuck::cast_slice::(&v).to_vec() - } - - let len: usize = shape.iter().product(); - if len == 0 { - return Self::try_empty(shape, dtype, device); - } - - // Allocate with correct type alignment, then convert to bytes. - // This avoids alignment violations that would occur if we allocated - // a Vec and cast to stricter-aligned types like f64/i64. - let bytes: Vec = match dtype { - DType::F64 => typed_to_bytes(vec![value; len]), - DType::F32 => typed_to_bytes(vec![value as f32; len]), - DType::F16 => { - #[cfg(feature = "f16")] - { - use half::f16; - typed_to_bytes(vec![f16::from_f64(value); len]) - } - #[cfg(not(feature = "f16"))] - { - let half_bits = half_from_f32(value as f32, dtype); - typed_to_bytes(vec![half_bits; len]) - } - } - DType::BF16 => { - #[cfg(feature = "f16")] - { - use half::bf16; - typed_to_bytes(vec![bf16::from_f64(value); len]) - } - #[cfg(not(feature = "f16"))] - { - let half_bits = half_from_f32(value as f32, dtype); - typed_to_bytes(vec![half_bits; len]) - } - } - DType::FP8E4M3 => { - vec![crate::dtype::FP8E4M3::from_f32(value as f32).to_bits(); len] - } - DType::FP8E5M2 => { - vec![crate::dtype::FP8E5M2::from_f32(value as f32).to_bits(); len] - } - DType::I64 => typed_to_bytes(vec![value as i64; len]), - DType::I32 => typed_to_bytes(vec![value as i32; len]), - DType::I16 => typed_to_bytes(vec![value as i16; len]), - DType::I8 => typed_to_bytes(vec![value as i8; len]), - DType::U64 => typed_to_bytes(vec![value as u64; len]), - DType::U32 => typed_to_bytes(vec![value as u32; len]), - DType::U16 => typed_to_bytes(vec![value as u16; len]), - DType::U8 => vec![value as u8; len], - DType::Bool => vec![if value != 0.0 { 1u8 } else { 0u8 }; len], - DType::Complex64 => { - typed_to_bytes(vec![crate::dtype::Complex64::new(value as f32, 0.0); len]) - } - DType::Complex128 => { - typed_to_bytes(vec![crate::dtype::Complex128::new(value, 0.0); len]) - } - }; - - // Allocate and copy to device - let storage = Storage::from_bytes(&bytes, dtype, device)?; - let layout = Layout::contiguous(shape); - - Ok(Self { - id: TensorId::new(), - storage, - layout, - }) - } - // ===== Accessors ===== /// Get the internal tensor ID for autograd graph tracking. #[inline] - pub(crate) fn id(&self) -> TensorId { + pub fn id(&self) -> TensorId { self.id } @@ -294,7 +124,7 @@ impl Tensor { /// Get the element type #[inline] - pub fn dtype(&self) -> DType { + pub fn dtype(&self) -> R::DType { self.storage.dtype() } @@ -321,6 +151,153 @@ impl Tensor { self.layout.dim(dim) } + /// Get size along a dimension, returning error on invalid index + pub fn dim(&self, index: isize) -> Result { + self.layout.dim(index).ok_or(Error::InvalidDimension { + dim: index, + ndim: self.ndim(), + }) + } + + // ===== Aliases (common across tensor libraries) ===== + + /// Number of dimensions (alias for `ndim`) + #[inline] + pub fn rank(&self) -> usize { + self.layout.ndim() + } + + /// Total number of elements (alias for `numel`) + #[inline] + pub fn elem_count(&self) -> usize { + self.layout.elem_count() + } + + /// Shape as slice (alias for `shape`) + #[inline] + pub fn dims(&self) -> &[usize] { + self.layout.shape() + } + + /// Total number of elements (alias for `numel`) + #[inline] + pub fn len(&self) -> usize { + self.layout.elem_count() + } + + /// Whether the tensor has zero elements + #[inline] + pub fn is_empty(&self) -> bool { + self.layout.elem_count() == 0 + } + + /// Layout offset into storage (in elements) + #[inline] + pub fn offset(&self) -> usize { + self.layout.offset() + } + + /// Data pointer adjusted for layout offset. + /// This is the pointer to the first element of this tensor's view. + #[inline] + pub fn ptr(&self) -> u64 { + self.storage.ptr() + (self.layout.offset() * self.dtype().size_in_bytes()) as u64 + } + + /// Whether the underlying storage is owned (will deallocate on drop) + #[inline] + pub fn owns_memory(&self) -> bool { + self.storage.is_owned() + } + + /// Check if two tensors share the same storage + pub fn shares_storage_with(&self, other: &Tensor) -> bool { + self.storage.ptr() == other.storage.ptr() + } + + /// Storage reference count + pub fn ref_count(&self) -> usize { + self.storage.ref_count() + } + + // ===== Dimension Unpacking ===== + + /// Unpack shape of a 1D tensor + pub fn dims1(&self) -> Result { + let s = self.shape(); + if s.len() == 1 { + Ok(s[0]) + } else { + Err(Error::ShapeMismatch { + expected: vec![0], + got: s.to_vec(), + }) + } + } + + /// Unpack shape of a 2D tensor + pub fn dims2(&self) -> Result<(usize, usize)> { + let s = self.shape(); + if s.len() == 2 { + Ok((s[0], s[1])) + } else { + Err(Error::ShapeMismatch { + expected: vec![0, 0], + got: s.to_vec(), + }) + } + } + + /// Unpack shape of a 3D tensor + pub fn dims3(&self) -> Result<(usize, usize, usize)> { + let s = self.shape(); + if s.len() == 3 { + Ok((s[0], s[1], s[2])) + } else { + Err(Error::ShapeMismatch { + expected: vec![0, 0, 0], + got: s.to_vec(), + }) + } + } + + /// Unpack shape of a 4D tensor + pub fn dims4(&self) -> Result<(usize, usize, usize, usize)> { + let s = self.shape(); + if s.len() == 4 { + Ok((s[0], s[1], s[2], s[3])) + } else { + Err(Error::ShapeMismatch { + expected: vec![0, 0, 0, 0], + got: s.to_vec(), + }) + } + } + + /// Unpack shape of a 5D tensor + pub fn dims5(&self) -> Result<(usize, usize, usize, usize, usize)> { + let s = self.shape(); + if s.len() == 5 { + Ok((s[0], s[1], s[2], s[3], s[4])) + } else { + Err(Error::ShapeMismatch { + expected: vec![0, 0, 0, 0, 0], + got: s.to_vec(), + }) + } + } + + // ===== Construction Helpers ===== + + /// Create tensor from storage and contiguous layout + pub fn from_storage_contiguous(storage: Storage, shape: &[usize]) -> Self { + Self { + id: TensorId::new(), + storage, + layout: Layout::contiguous(shape), + } + } + // ===== View Operations (Zero-Copy) ===== /// Transpose two dimensions (zero-copy) @@ -561,7 +538,7 @@ impl Tensor { /// - CPU/CUDA: Uses pointer arithmetic (handles can be offset directly) /// - WGPU: Uses compute shader (buffer IDs don't support arithmetic) pub fn contiguous(&self) -> Self { - if self.is_contiguous() { + if self.is_contiguous() && self.layout.offset() == 0 { self.clone() } else { // Need to copy data to a new contiguous storage @@ -636,6 +613,39 @@ impl Tensor { result } + /// Record an event on the compute stream for this tensor's device. + /// + /// Call this BEFORE launching additional compute work, then pass the event + /// to `to_vec_pipelined` AFTER launching the compute work. This allows the + /// copy to proceed as soon as the event fires, while compute continues. + pub fn record_event(&self) -> crate::error::Result { + R::record_compute_event(self.storage.device()) + } + + /// Copy tensor data to a Vec using the pipelined copy stream, synchronized + /// via a previously recorded event. + /// + /// On CUDA, syncs only the copy stream — compute stream keeps running. + pub fn to_vec_pipelined(&self, event: u64) -> crate::error::Result> { + if !self.is_contiguous() { + return Err(crate::error::Error::ShapeMismatch { + expected: vec![self.numel()], + got: self.shape().to_vec(), + }); + } + + let numel = self.numel(); + let offset = self.layout.offset(); + let elem_size = std::mem::size_of::(); + let byte_offset = offset * elem_size; + + let mut result = vec![T::zeroed(); numel]; + let bytes: &mut [u8] = bytemuck::cast_slice_mut(&mut result); + let src_ptr = self.storage.ptr() as usize + byte_offset; + R::copy_from_device_pipelined(src_ptr as u64, bytes, self.storage.device(), event)?; + Ok(result) + } + /// Extract the scalar value from a single-element tensor /// /// This is the idiomatic way to get a scalar value from a tensor for use @@ -687,6 +697,278 @@ impl Tensor { } } +// ============================================================================ +// Generic constructors (work with ANY R::DType via DataType trait) +// ============================================================================ + +impl Tensor { + /// Create a tensor filled with zeros (generic, works with any DType) + pub fn try_zeros_generic(shape: &[usize], dtype: R::DType, device: &R::Device) -> Result { + Self::try_full_scalar_generic(shape, dtype, 0.0, device) + } + + /// Create a tensor filled with ones (generic, works with any DType) + pub fn try_ones_generic(shape: &[usize], dtype: R::DType, device: &R::Device) -> Result { + Self::try_full_scalar_generic(shape, dtype, 1.0, device) + } + + /// Create a tensor filled with a scalar value (generic, works with any DType) + /// + /// Uses `DataType::fill_bytes` to generate the fill pattern, so it works + /// with any DType that implements the trait (including boostr's quantized types). + pub fn try_full_scalar_generic( + shape: &[usize], + dtype: R::DType, + value: f64, + device: &R::Device, + ) -> Result { + let len: usize = shape.iter().product(); + if len == 0 { + return Self::try_empty(shape, dtype, device); + } + + let bytes = dtype.fill_bytes(value, len).ok_or_else(|| { + Error::Msg(format!( + "fill not supported for dtype {}", + dtype.short_name() + )) + })?; + + let storage = Storage::from_bytes(&bytes, dtype, device)?; + let layout = Layout::contiguous(shape); + + Ok(Self { + id: TensorId::new(), + storage, + layout, + }) + } + + /// Create a tensor from raw bytes with specified dtype (generic) + pub fn try_from_bytes( + bytes: &[u8], + shape: &[usize], + dtype: R::DType, + device: &R::Device, + ) -> Result { + let storage = Storage::from_bytes(bytes, dtype, device)?; + let layout = Layout::contiguous(shape); + Ok(Self { + id: TensorId::new(), + storage, + layout, + }) + } +} + +// ============================================================================ +// Constructors that require numr's standard DType (for variant matching) +// ============================================================================ + +impl> Tensor { + /// Create a tensor from a slice of data + /// + /// # Panics + /// + /// Panics if `data.len()` does not equal the product of the `shape` dimensions. + /// For a fallible alternative, use [`Self::try_from_slice`]. + /// + /// # Example + /// + /// ``` + /// # use numr::prelude::*; + /// # let device = CpuDevice::new(); + /// let tensor = Tensor::::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2], &device); + /// # Ok::<(), numr::error::Error>(()) + /// ``` + #[track_caller] + pub fn from_slice(data: &[T], shape: &[usize], device: &R::Device) -> Self { + Self::try_from_slice(data, shape, device) + .unwrap_or_else(|e| panic!("Tensor::from_slice failed: {e}")) + } + + /// Create a tensor from a slice of data (fallible version) + /// + /// Returns an error if `data.len()` does not equal the product of the `shape` dimensions, + /// or if memory allocation fails. + /// + /// # Example + /// + /// ``` + /// # use numr::prelude::*; + /// # let device = CpuDevice::new(); + /// let tensor = Tensor::::try_from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2], &device)?; + /// # Ok::<(), numr::error::Error>(()) + /// ``` + pub fn try_from_slice( + data: &[T], + shape: &[usize], + device: &R::Device, + ) -> Result { + let expected_len: usize = shape.iter().product(); + if data.len() != expected_len { + return Err(Error::ShapeMismatch { + expected: shape.to_vec(), + got: vec![data.len()], + }); + } + + let storage = Storage::from_slice(data, device)?; + let layout = Layout::contiguous(shape); + + Ok(Self { + id: TensorId::new(), + storage, + layout, + }) + } + + /// Create a tensor filled with zeros + /// + /// This properly initializes memory to zero on all backends (CPU and GPU). + #[track_caller] + pub fn zeros(shape: &[usize], dtype: DType, device: &R::Device) -> Self { + Self::try_zeros(shape, dtype, device) + .unwrap_or_else(|e| panic!("Tensor::zeros failed: {e}")) + } + + /// Create a tensor filled with zeros (fallible version) + pub fn try_zeros(shape: &[usize], dtype: DType, device: &R::Device) -> Result { + Self::try_full_scalar(shape, dtype, 0.0, device) + } + + /// Create a tensor filled with ones + #[track_caller] + pub fn ones(shape: &[usize], dtype: DType, device: &R::Device) -> Self { + Self::try_ones(shape, dtype, device).unwrap_or_else(|e| panic!("Tensor::ones failed: {e}")) + } + + /// Create a tensor filled with ones (fallible version) + pub fn try_ones(shape: &[usize], dtype: DType, device: &R::Device) -> Result { + Self::try_full_scalar(shape, dtype, 1.0, device) + } + + /// Create a tensor filled with a scalar value + /// + /// The scalar is converted to the target dtype. + #[track_caller] + pub fn full_scalar(shape: &[usize], dtype: DType, value: f64, device: &R::Device) -> Self { + Self::try_full_scalar(shape, dtype, value, device) + .unwrap_or_else(|e| panic!("Tensor::full_scalar failed: {e}")) + } + + /// Create a tensor filled with a scalar value (fallible version) + pub fn try_full_scalar( + shape: &[usize], + dtype: DType, + value: f64, + device: &R::Device, + ) -> Result { + // Helper to convert a typed Vec to bytes safely. + // Allocates with correct alignment for T, then copies to u8 vec. + #[inline] + fn typed_to_bytes(v: Vec) -> Vec { + bytemuck::cast_slice::(&v).to_vec() + } + + let len: usize = shape.iter().product(); + if len == 0 { + return Self::try_empty(shape, dtype, device); + } + + // Allocate with correct type alignment, then convert to bytes. + // This avoids alignment violations that would occur if we allocated + // a Vec and cast to stricter-aligned types like f64/i64. + let bytes: Vec = match dtype { + DType::F64 => typed_to_bytes(vec![value; len]), + DType::F32 => typed_to_bytes(vec![value as f32; len]), + DType::F16 => { + #[cfg(feature = "f16")] + { + use half::f16; + typed_to_bytes(vec![f16::from_f64(value); len]) + } + #[cfg(not(feature = "f16"))] + { + let half_bits = half_from_f32(value as f32, dtype); + typed_to_bytes(vec![half_bits; len]) + } + } + DType::BF16 => { + #[cfg(feature = "f16")] + { + use half::bf16; + typed_to_bytes(vec![bf16::from_f64(value); len]) + } + #[cfg(not(feature = "f16"))] + { + let half_bits = half_from_f32(value as f32, dtype); + typed_to_bytes(vec![half_bits; len]) + } + } + DType::FP8E4M3 => { + vec![crate::dtype::FP8E4M3::from_f32(value as f32).to_bits(); len] + } + DType::FP8E5M2 => { + vec![crate::dtype::FP8E5M2::from_f32(value as f32).to_bits(); len] + } + DType::I64 => typed_to_bytes(vec![value as i64; len]), + DType::I32 => typed_to_bytes(vec![value as i32; len]), + DType::I16 => typed_to_bytes(vec![value as i16; len]), + DType::I8 => typed_to_bytes(vec![value as i8; len]), + DType::U64 => typed_to_bytes(vec![value as u64; len]), + DType::U32 => typed_to_bytes(vec![value as u32; len]), + DType::U16 => typed_to_bytes(vec![value as u16; len]), + DType::U8 => vec![value as u8; len], + DType::Bool => vec![if value != 0.0 { 1u8 } else { 0u8 }; len], + DType::Complex64 => { + typed_to_bytes(vec![crate::dtype::Complex64::new(value as f32, 0.0); len]) + } + DType::Complex128 => { + typed_to_bytes(vec![crate::dtype::Complex128::new(value, 0.0); len]) + } + }; + + // Allocate and copy to device + let storage = Storage::from_bytes(&bytes, dtype, device)?; + let layout = Layout::contiguous(shape); + + Ok(Self { + id: TensorId::new(), + storage, + layout, + }) + } +} + +// ============================================================================ +// Foundational ops (generic — work with any R::DType) +// ============================================================================ + +impl Tensor { + /// Serialize tensor data to raw bytes + /// + /// Makes tensor contiguous first if needed, then copies raw bytes from device. + pub fn to_bytes(&self) -> Result> { + let tensor = if self.is_contiguous() && self.offset() == 0 { + std::borrow::Cow::Borrowed(self) + } else { + std::borrow::Cow::Owned(self.contiguous()) + }; + let size = tensor.numel() * tensor.dtype().size_in_bytes(); + let mut data = vec![0u8; size]; + R::copy_from_device(tensor.storage().ptr(), &mut data, tensor.storage().device()) + .map_err(|e| Error::Msg(format!("to_bytes copy failed: {}", e)))?; + Ok(data) + } + + /// Clone tensor with new storage (deep copy) + pub fn clone_deep(&self) -> Result { + let bytes = self.to_bytes()?; + Self::try_from_bytes(&bytes, self.shape(), self.dtype(), self.device()) + } +} + impl Clone for Tensor { /// Clone creates a new tensor sharing the same storage (zero-copy) fn clone(&self) -> Self { @@ -711,7 +993,12 @@ impl fmt::Debug for Tensor { impl fmt::Display for Tensor { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Tensor({:?}, dtype={})", self.shape(), self.dtype()) + write!( + f, + "Tensor({:?}, dtype={})", + self.shape(), + self.dtype().short_name() + ) } } diff --git a/src/tensor/id.rs b/src/tensor/id.rs index f0129c83..d386ce3a 100644 --- a/src/tensor/id.rs +++ b/src/tensor/id.rs @@ -25,6 +25,12 @@ impl TensorId { self.0 } + /// Get the raw ID value as u64 (alias for raw) + #[inline] + pub fn as_u64(self) -> u64 { + self.0 + } + /// Create from raw value (for testing/serialization only) #[inline] pub const fn from_raw(id: u64) -> Self { diff --git a/src/tensor/layout.rs b/src/tensor/layout.rs index 02bfd1b3..3b0ed723 100644 --- a/src/tensor/layout.rs +++ b/src/tensor/layout.rs @@ -13,7 +13,7 @@ use std::fmt; /// /// Address of element at indices `[i0, i1, ..., in]`: /// offset + i0 * `strides[0]` + i1 * `strides[1]` + ... + in * `strides[n]` -#[derive(Clone, PartialEq, Eq)] +#[derive(Clone, PartialEq, Eq, Hash)] pub struct Layout { /// Shape: size along each dimension shape: Shape, @@ -120,13 +120,27 @@ impl Layout { } /// Check if memory is contiguous (row-major order) + /// + /// A layout is contiguous if its strides match row-major order. + /// Size-1 dimensions are ignored since their stride doesn't affect + /// memory layout (only one element along that axis). + /// The offset does not affect contiguity (a narrowed view can still + /// be contiguous in its stride pattern). pub fn is_contiguous(&self) -> bool { if self.is_scalar() { return true; } let expected = Self::compute_contiguous_strides(&self.shape); - self.strides == expected && self.offset == 0 + if self.strides == expected { + return true; + } + + // Lenient check: strides for size-1 dims don't matter + self.shape + .iter() + .zip(self.strides.iter().zip(expected.iter())) + .all(|(&s, (&actual, &expect))| s == 1 || actual == expect) } /// Get size along a specific dimension @@ -206,7 +220,14 @@ impl Layout { return None; } - Some(Self::contiguous(new_shape)) + // Preserve offset for views (e.g., from narrow) + let shape: Shape = new_shape.iter().copied().collect(); + let strides = Self::compute_contiguous_strides(&shape); + Some(Self { + shape, + strides, + offset: self.offset, + }) } /// Create a squeezed layout (remove dimensions of size 1) @@ -415,6 +436,253 @@ impl Layout { Some(result) } + /// Create layout from usize strides (convenience for existing code) + /// + /// Converts usize strides to isize. All values must fit in isize. + #[inline] + pub fn new_unsigned(shape: &[usize], strides: &[usize], offset: usize) -> Self { + let strides_isize: Strides = strides.iter().map(|&s| s as isize).collect(); + Self { + shape: shape.into(), + strides: strides_isize, + offset, + } + } + + /// Get the rank (number of dimensions) - alias for `ndim()` + #[inline] + pub fn rank(&self) -> usize { + self.shape.len() + } + + /// Returns true if the tensor has zero elements + #[inline] + pub fn is_empty(&self) -> bool { + self.elem_count() == 0 + } + + /// Transpose last two dimensions (for matrix operations) + /// + /// Common operation for matmul: transpose(-2, -1) + #[inline] + pub fn t(&self) -> Option { + if self.ndim() < 2 { + return None; + } + let n = self.ndim(); + self.transpose_axes(n - 2, n - 1) + } + + /// Transpose two dimensions by axis index (usize version) + /// + /// Unlike `transpose()` which takes isize for negative indexing support, + /// this takes usize indices directly. + pub fn transpose_axes(&self, dim0: usize, dim1: usize) -> Option { + if dim0 >= self.ndim() || dim1 >= self.ndim() { + return None; + } + + let mut new_shape = self.shape.clone(); + let mut new_strides = self.strides.clone(); + new_shape.swap(dim0, dim1); + new_strides.swap(dim0, dim1); + + Some(Self { + shape: new_shape, + strides: new_strides, + offset: self.offset, + }) + } + + /// Squeeze a specific dimension (remove if size is 1) + /// + /// Returns None if dim is out of bounds or dimension size is not 1. + pub fn squeeze_dim(&self, dim: usize) -> Option { + if dim >= self.ndim() || self.shape[dim] != 1 { + return None; + } + + let mut new_shape = self.shape.clone(); + let mut new_strides = self.strides.clone(); + new_shape.remove(dim); + new_strides.remove(dim); + + Some(Self::new(new_shape, new_strides, self.offset)) + } + + /// Squeeze all dimensions of size 1 + pub fn squeeze_all(&self) -> Self { + self.squeeze(None) + } + + /// Unsqueeze (add dimension of size 1) at a usize index + pub fn unsqueeze_at(&self, dim: usize) -> Option { + if dim > self.ndim() { + return None; + } + + let mut new_shape = self.shape.clone(); + let mut new_strides = self.strides.clone(); + + let new_stride = if dim < self.ndim() { + new_strides[dim] * new_shape[dim] as isize + } else { + 1 + }; + + new_shape.insert(dim, 1); + new_strides.insert(dim, new_stride); + + Some(Self::new(new_shape, new_strides, self.offset)) + } + + /// Permute dimensions according to the given order + /// + /// Alias provided for API compatibility. See `permute()`. + #[inline] + pub fn permute_dims(&self, dims: &[usize]) -> Option { + self.permute(dims) + } + + /// Flatten dimensions [start_dim, end_dim] into a single dimension + pub fn flatten(&self, start_dim: usize, end_dim: usize) -> Option { + if start_dim > end_dim || end_dim >= self.ndim() { + return None; + } + + // Must be contiguous in the flattened range + for i in start_dim..end_dim { + if self.strides[i] != self.strides[i + 1] * self.shape[i + 1] as isize { + return None; + } + } + + let flat_size: usize = self.shape[start_dim..=end_dim].iter().product(); + let mut new_shape = Shape::new(); + let mut new_strides = Strides::new(); + + for i in 0..start_dim { + new_shape.push(self.shape[i]); + new_strides.push(self.strides[i]); + } + + new_shape.push(flat_size); + new_strides.push(self.strides[end_dim]); + + for i in (end_dim + 1)..self.ndim() { + new_shape.push(self.shape[i]); + new_strides.push(self.strides[i]); + } + + Some(Self::new(new_shape, new_strides, self.offset)) + } + + /// Create a strided view with arbitrary shape, strides, and offset + /// + /// Low-level operation for advanced indexing. The offset is relative + /// to the current layout's offset. + pub fn as_strided(&self, shape: &[usize], strides: &[isize], offset: usize) -> Self { + Self { + shape: shape.into(), + strides: strides.into(), + offset: self.offset + offset, + } + } + + /// Compute the minimum storage size required for this layout (in elements) + /// + /// For contiguous layouts: elem_count() + offset + /// For strided layouts: max reachable offset + 1 + pub fn storage_size(&self) -> usize { + if self.shape.is_empty() { + return if self.offset > 0 { self.offset + 1 } else { 1 }; + } + + let mut max_offset = self.offset as isize; + for (&dim, &stride) in self.shape.iter().zip(self.strides.iter()) { + if dim > 0 && stride > 0 { + max_offset += (dim as isize - 1) * stride; + } + } + debug_assert!( + max_offset >= 0, + "storage_size: negative max_offset {}", + max_offset + ); + (max_offset as usize) + 1 + } + + /// Compute linear offset for a multi-dimensional index + /// + /// Alias for `index()` for API compatibility. + #[inline] + pub fn index_to_offset(&self, indices: &[usize]) -> Option { + self.index(indices) + } + + /// Compute linear offset without bounds checking + /// + /// # Safety + /// Caller must ensure index is within bounds. + #[inline] + pub unsafe fn index_to_offset_unchecked(&self, index: &[usize]) -> usize { + let mut offset = self.offset as isize; + for (&i, &stride) in index.iter().zip(self.strides.iter()) { + offset += i as isize * stride; + } + offset as usize + } + + /// Convert linear offset back to multi-dimensional index + /// + /// Only works correctly for contiguous layouts. + pub fn offset_to_index(&self, mut offset: usize) -> Option> { + if !self.is_contiguous() || offset >= self.elem_count() { + return None; + } + + let mut index = Vec::with_capacity(self.ndim()); + for &stride in self.strides.iter() { + if stride > 0 { + let s = stride as usize; + index.push(offset / s); + offset %= s; + } else { + index.push(0); + } + } + + Some(index) + } + + /// Compute broadcast shape between this layout and another + pub fn broadcast_shape(&self, other: &Layout) -> Option> { + Self::broadcast_shapes(self.shape(), other.shape()) + } + + /// Compute broadcast shape between two shapes + pub fn broadcast_shapes(a: &[usize], b: &[usize]) -> Option> { + let max_rank = a.len().max(b.len()); + let mut result = vec![0usize; max_rank]; + + for i in 0..max_rank { + let dim_a = if i < a.len() { a[a.len() - 1 - i] } else { 1 }; + let dim_b = if i < b.len() { b[b.len() - 1 - i] } else { 1 }; + + if dim_a == dim_b { + result[max_rank - 1 - i] = dim_a; + } else if dim_a == 1 { + result[max_rank - 1 - i] = dim_b; + } else if dim_b == 1 { + result[max_rank - 1 - i] = dim_a; + } else { + return None; + } + } + + Some(result) + } + /// Create a broadcast layout to a target shape /// /// Returns None if shapes are not broadcastable @@ -471,6 +739,67 @@ impl fmt::Display for Layout { } } +// Convenient From implementations +impl From> for Layout { + fn from(dims: Vec) -> Self { + Layout::contiguous(&dims) + } +} + +impl From<&[usize]> for Layout { + fn from(dims: &[usize]) -> Self { + Layout::contiguous(dims) + } +} + +impl From<[usize; N]> for Layout { + fn from(dims: [usize; N]) -> Self { + Layout::contiguous(&dims) + } +} + +impl From for Layout { + fn from(dim: usize) -> Self { + Layout::contiguous(&[dim]) + } +} + +impl From<(usize,)> for Layout { + fn from((d,): (usize,)) -> Self { + Layout::contiguous(&[d]) + } +} + +impl From<(usize, usize)> for Layout { + fn from((d1, d2): (usize, usize)) -> Self { + Layout::contiguous(&[d1, d2]) + } +} + +impl From<(usize, usize, usize)> for Layout { + fn from((d1, d2, d3): (usize, usize, usize)) -> Self { + Layout::contiguous(&[d1, d2, d3]) + } +} + +impl From<(usize, usize, usize, usize)> for Layout { + fn from((d1, d2, d3, d4): (usize, usize, usize, usize)) -> Self { + Layout::contiguous(&[d1, d2, d3, d4]) + } +} + +impl From<(usize, usize, usize, usize, usize)> for Layout { + fn from((d1, d2, d3, d4, d5): (usize, usize, usize, usize, usize)) -> Self { + Layout::contiguous(&[d1, d2, d3, d4, d5]) + } +} + +impl From<(usize, usize, usize, usize, usize, usize)> for Layout { + fn from((d1, d2, d3, d4, d5, d6): (usize, usize, usize, usize, usize, usize)) -> Self { + Layout::contiguous(&[d1, d2, d3, d4, d5, d6]) + } +} + // Note: broadcast_shape is implemented in crate::ops::arithmetic and is the canonical version. // Use crate::ops::broadcast_shape for broadcasting logic. diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs index 662d5d86..4bc5258e 100644 --- a/src/tensor/mod.rs +++ b/src/tensor/mod.rs @@ -4,13 +4,14 @@ //! array stored on a compute device (CPU, GPU, etc.). mod core; -pub(crate) mod id; +pub mod id; mod layout; +mod ops; pub(crate) mod shape; mod storage; mod strides; pub use core::Tensor; -pub(crate) use id::TensorId; +pub use id::TensorId; pub use layout::{Layout, Shape, Strides}; pub use storage::Storage; diff --git a/src/tensor/ops.rs b/src/tensor/ops.rs new file mode 100644 index 00000000..b2d23c97 --- /dev/null +++ b/src/tensor/ops.rs @@ -0,0 +1,488 @@ +//! Convenience methods on Tensor that delegate to Client ops +//! +//! These methods provide ergonomic `tensor.add(&other)` style calls +//! that internally get the client and delegate to the appropriate trait. + +use crate::dtype::DType; +use crate::error::Result; +use crate::ops::traits::{ + ActivationOps, BinaryOps, CompareOps, ConvOps, CumulativeOps, IndexingOps, MatmulOps, + NormalizationOps, PaddingMode, ReduceOps, ScalarOps, ShapeOps, TypeConversionOps, UnaryOps, + UtilityOps, +}; +use crate::runtime::Runtime; +use crate::tensor::Tensor; + +// ============================================================================ +// Binary arithmetic +// ============================================================================ + +impl Tensor +where + R::Client: BinaryOps, +{ + /// Element-wise addition: self + other + pub fn add(&self, other: &Tensor) -> Result> { + let client = R::default_client(self.device()); + client.add(self, other) + } + + /// Element-wise subtraction: self - other + pub fn sub(&self, other: &Tensor) -> Result> { + let client = R::default_client(self.device()); + client.sub(self, other) + } + + /// Element-wise multiplication: self * other + pub fn mul(&self, other: &Tensor) -> Result> { + let client = R::default_client(self.device()); + client.mul(self, other) + } + + /// Element-wise division: self / other + pub fn div(&self, other: &Tensor) -> Result> { + let client = R::default_client(self.device()); + client.div(self, other) + } + + /// Element-wise power: self ^ other + pub fn pow(&self, other: &Tensor) -> Result> { + let client = R::default_client(self.device()); + client.pow(self, other) + } + + /// Element-wise maximum: max(self, other) + pub fn maximum(&self, other: &Tensor) -> Result> { + let client = R::default_client(self.device()); + client.maximum(self, other) + } + + /// Element-wise minimum: min(self, other) + pub fn minimum(&self, other: &Tensor) -> Result> { + let client = R::default_client(self.device()); + client.minimum(self, other) + } +} + +// ============================================================================ +// Unary operations +// ============================================================================ + +impl Tensor +where + R::Client: UnaryOps, +{ + /// Element-wise negation + pub fn neg(&self) -> Result> { + let client = R::default_client(self.device()); + client.neg(self) + } + + /// Element-wise absolute value + pub fn abs(&self) -> Result> { + let client = R::default_client(self.device()); + client.abs(self) + } + + /// Element-wise square root + pub fn sqrt(&self) -> Result> { + let client = R::default_client(self.device()); + client.sqrt(self) + } + + /// Element-wise exponential + pub fn exp(&self) -> Result> { + let client = R::default_client(self.device()); + client.exp(self) + } + + /// Element-wise natural log + pub fn log(&self) -> Result> { + let client = R::default_client(self.device()); + client.log(self) + } + + /// Element-wise sine + pub fn sin(&self) -> Result> { + let client = R::default_client(self.device()); + client.sin(self) + } + + /// Element-wise cosine + pub fn cos(&self) -> Result> { + let client = R::default_client(self.device()); + client.cos(self) + } + + /// Element-wise tangent + pub fn tan(&self) -> Result> { + let client = R::default_client(self.device()); + client.tan(self) + } + + /// Element-wise hyperbolic tangent + pub fn tanh(&self) -> Result> { + let client = R::default_client(self.device()); + client.tanh(self) + } + + /// Element-wise reciprocal (1/x) + pub fn recip(&self) -> Result> { + let client = R::default_client(self.device()); + client.recip(self) + } + + /// Element-wise floor + pub fn floor(&self) -> Result> { + let client = R::default_client(self.device()); + client.floor(self) + } + + /// Element-wise ceil + pub fn ceil(&self) -> Result> { + let client = R::default_client(self.device()); + client.ceil(self) + } + + /// Element-wise round + pub fn round(&self) -> Result> { + let client = R::default_client(self.device()); + client.round(self) + } +} + +// ============================================================================ +// Scalar operations +// ============================================================================ + +impl Tensor +where + R::Client: ScalarOps, +{ + /// Add scalar: self + scalar + pub fn add_scalar(&self, scalar: f64) -> Result> { + let client = R::default_client(self.device()); + client.add_scalar(self, scalar) + } + + /// Multiply by scalar: self * scalar + pub fn mul_scalar(&self, scalar: f64) -> Result> { + let client = R::default_client(self.device()); + client.mul_scalar(self, scalar) + } + + /// Scale alias for mul_scalar + pub fn scale(&self, scalar: f64) -> Result> { + self.mul_scalar(scalar) + } +} + +// ============================================================================ +// Activation functions +// ============================================================================ + +impl Tensor +where + R::Client: ActivationOps, +{ + /// ReLU activation: max(0, x) + pub fn relu(&self) -> Result> { + let client = R::default_client(self.device()); + client.relu(self) + } + + /// Sigmoid activation: 1 / (1 + exp(-x)) + pub fn sigmoid(&self) -> Result> { + let client = R::default_client(self.device()); + client.sigmoid(self) + } + + /// GELU activation + pub fn gelu(&self) -> Result> { + let client = R::default_client(self.device()); + client.gelu(self) + } + + /// SiLU/Swish activation: x * sigmoid(x) + pub fn silu(&self) -> Result> { + let client = R::default_client(self.device()); + client.silu(self) + } + + /// Softmax along dimension + pub fn softmax(&self, dim: isize) -> Result> { + let client = R::default_client(self.device()); + client.softmax(self, dim) + } + + /// Log-softmax along dimension: log(softmax(x, dim)) + pub fn log_softmax(&self, dim: isize) -> Result> { + let client = R::default_client(self.device()); + client.log_softmax(self, dim) + } + + /// Dropout: randomly zero elements with probability `p` during training + pub fn dropout(&self, p: f64, training: bool) -> Result> { + let client = R::default_client(self.device()); + client.dropout(self, p, training) + } +} + +// ============================================================================ +// Reduction operations +// ============================================================================ + +impl Tensor +where + R::Client: ReduceOps, +{ + /// Sum along dimensions + pub fn sum(&self, dims: &[usize], keepdim: bool) -> Result> { + let client = R::default_client(self.device()); + client.sum(self, dims, keepdim) + } + + /// Mean along dimensions + pub fn mean(&self, dims: &[usize], keepdim: bool) -> Result> { + let client = R::default_client(self.device()); + client.mean(self, dims, keepdim) + } + + /// Max along dimensions + pub fn max(&self, dims: &[usize], keepdim: bool) -> Result> { + let client = R::default_client(self.device()); + client.max(self, dims, keepdim) + } + + /// Min along dimensions + pub fn min(&self, dims: &[usize], keepdim: bool) -> Result> { + let client = R::default_client(self.device()); + client.min(self, dims, keepdim) + } +} + +// ============================================================================ +// Matrix operations +// ============================================================================ + +impl Tensor +where + R::Client: MatmulOps, +{ + /// Matrix multiplication: self @ other + pub fn matmul(&self, other: &Tensor) -> Result> { + let client = R::default_client(self.device()); + client.matmul(self, other) + } +} + +// ============================================================================ +// Normalization +// ============================================================================ + +impl Tensor +where + R::Client: NormalizationOps, +{ + /// RMS normalization: x / RMS(x) * weight + pub fn rms_norm(&self, weight: &Tensor, eps: f32) -> Result> { + let client = R::default_client(self.device()); + client.rms_norm(self, weight, eps) + } + + /// Layer normalization: (x - mean) / sqrt(var + eps) * weight + bias + pub fn layer_norm(&self, weight: &Tensor, bias: &Tensor, eps: f32) -> Result> { + let client = R::default_client(self.device()); + client.layer_norm(self, weight, bias, eps) + } +} + +// ============================================================================ +// Comparison operations +// ============================================================================ + +impl Tensor +where + R::Client: CompareOps, +{ + /// Element-wise equality + pub fn eq(&self, other: &Tensor) -> Result> { + let client = R::default_client(self.device()); + client.eq(self, other) + } + + /// Element-wise greater than + pub fn gt(&self, other: &Tensor) -> Result> { + let client = R::default_client(self.device()); + client.gt(self, other) + } + + /// Element-wise less than + pub fn lt(&self, other: &Tensor) -> Result> { + let client = R::default_client(self.device()); + client.lt(self, other) + } +} + +// ============================================================================ +// Indexing operations +// ============================================================================ + +impl Tensor +where + R::Client: IndexingOps, +{ + /// Select elements along a dimension using indices + pub fn index_select(&self, dim: usize, indices: &Tensor) -> Result> { + let client = R::default_client(self.device()); + client.index_select(self, dim, indices) + } + + /// Argmax along a dimension + pub fn argmax(&self, dim: usize, keepdim: bool) -> Result> { + let client = R::default_client(self.device()); + client.argmax(self, dim, keepdim) + } + + /// Argmin along a dimension + pub fn argmin(&self, dim: usize, keepdim: bool) -> Result> { + let client = R::default_client(self.device()); + client.argmin(self, dim, keepdim) + } + + /// Fill tensor with value where mask is true + pub fn masked_fill(&self, mask: &Tensor, value: f64) -> Result> { + let client = R::default_client(self.device()); + client.masked_fill(self, mask, value) + } + + /// Assign `src` into a slice of `self` along `dim` starting at `start`. + /// + /// Returns a new tensor with the slice region replaced by `src`. + pub fn slice_assign(&self, src: &Tensor, dim: usize, start: usize) -> Result> { + let client = R::default_client(self.device()); + client.slice_assign(self, src, dim, start) + } +} + +// ============================================================================ +// Shape operations +// ============================================================================ + +impl Tensor +where + R::Client: ShapeOps, +{ + /// Concatenate tensors along a dimension + pub fn cat(tensors: &[&Tensor], dim: isize) -> Result> { + if tensors.is_empty() { + return Err(crate::error::Error::InvalidArgument { + arg: "tensors", + reason: "cannot concatenate empty list".into(), + }); + } + let client = R::default_client(tensors[0].device()); + client.cat(tensors, dim) + } + + /// Stack tensors along a new dimension + pub fn stack(tensors: &[&Tensor], dim: isize) -> Result> { + if tensors.is_empty() { + return Err(crate::error::Error::InvalidArgument { + arg: "tensors", + reason: "cannot stack empty list".into(), + }); + } + let client = R::default_client(tensors[0].device()); + client.stack(tensors, dim) + } +} + +// ============================================================================ +// Cumulative operations +// ============================================================================ + +impl Tensor +where + R::Client: CumulativeOps, +{ + /// Cumulative sum along a dimension + pub fn cumsum(&self, dim: isize) -> Result> { + let client = R::default_client(self.device()); + client.cumsum(self, dim) + } + + /// Cumulative product along a dimension + pub fn cumprod(&self, dim: isize) -> Result> { + let client = R::default_client(self.device()); + client.cumprod(self, dim) + } + + /// Log-sum-exp along specified dimensions (numerically stable) + /// + /// Computes `log(sum(exp(x)))` in a numerically stable way: + /// `logsumexp(x) = max(x) + log(sum(exp(x - max(x))))` + pub fn logsumexp(&self, dims: &[usize], keepdim: bool) -> Result> { + let client = R::default_client(self.device()); + client.logsumexp(self, dims, keepdim) + } +} + +// ============================================================================ +// Type conversion +// ============================================================================ + +impl Tensor +where + R::Client: TypeConversionOps, +{ + /// Convert tensor to a different dtype + pub fn to_dtype(&self, dtype: DType) -> Result> { + let client = R::default_client(self.device()); + client.cast(self, dtype) + } +} + +// ============================================================================ +// Utility operations +// ============================================================================ + +impl Tensor +where + R::Client: UtilityOps, +{ + /// Clamp values to [min, max] + pub fn clamp(&self, min: f64, max: f64) -> Result> { + let client = R::default_client(self.device()); + client.clamp(self, min, max) + } + + /// One-hot encode indices + pub fn one_hot(&self, num_classes: usize) -> Result> { + let client = R::default_client(self.device()); + client.one_hot(self, num_classes) + } +} + +// ============================================================================ +// Convolution operations +// ============================================================================ + +impl Tensor +where + R::Client: ConvOps, +{ + /// 1D convolution + pub fn conv1d( + &self, + weight: &Tensor, + bias: Option<&Tensor>, + stride: usize, + padding: PaddingMode, + dilation: usize, + groups: usize, + ) -> Result> { + let client = R::default_client(self.device()); + client.conv1d(self, weight, bias, stride, padding, dilation, groups) + } +} diff --git a/src/tensor/shape.rs b/src/tensor/shape.rs index a678968c..ce7afa75 100644 --- a/src/tensor/shape.rs +++ b/src/tensor/shape.rs @@ -10,7 +10,7 @@ use std::ops::{Deref, DerefMut}; pub(crate) const STACK_DIMS: usize = 4; /// Shape type: dimensions of a tensor -#[derive(Clone, PartialEq, Eq, Default)] +#[derive(Clone, PartialEq, Eq, Default, Hash)] pub struct Shape(SmallVec<[usize; STACK_DIMS]>); impl Shape { diff --git a/src/tensor/storage.rs b/src/tensor/storage.rs index 7aaca994..d57e9574 100644 --- a/src/tensor/storage.rs +++ b/src/tensor/storage.rs @@ -1,6 +1,6 @@ //! Storage: device memory management with Arc-based sharing -use crate::dtype::{DType, Element}; +use crate::dtype::{DType, DataType, Element}; use crate::error::Result; use crate::runtime::Runtime; use std::sync::Arc; @@ -21,7 +21,7 @@ struct StorageInner { /// Number of elements (not bytes) len: usize, /// Element type - dtype: DType, + dtype: R::DType, /// Device where memory is allocated device: R::Device, /// If true, we own this memory and should deallocate on drop @@ -32,8 +32,8 @@ impl Storage { /// Create new storage with allocated memory /// /// Allocates `len` elements of type `dtype` on the specified device. - pub fn new(len: usize, dtype: DType, device: &R::Device) -> Result { - let size_bytes = len * dtype.size_in_bytes(); + pub fn new(len: usize, dtype: R::DType, device: &R::Device) -> Result { + let size_bytes = dtype.storage_bytes(len); let ptr = R::allocate(size_bytes, device)?; Ok(Self { @@ -47,19 +47,14 @@ impl Storage { }) } - /// Create storage from existing data with inferred dtype + /// Create storage from raw bytes with explicit dtype /// - /// Copies `data` to the device. The dtype is inferred from the Element type. - pub fn from_slice(data: &[T], device: &R::Device) -> Result { - let dtype = T::DTYPE; - let len = data.len(); - - // Copy data to device - let bytes = bytemuck::cast_slice(data); - let size_bytes = bytes.len(); - let ptr = R::allocate(size_bytes, device)?; + /// Use this when you have raw bytes and know the dtype. + pub fn from_bytes(data: &[u8], dtype: R::DType, device: &R::Device) -> Result { + let len = data.len() / dtype.size_in_bytes(); + let ptr = R::allocate(data.len(), device)?; - R::copy_to_device(bytes, ptr, device)?; + R::copy_to_device(data, ptr, device)?; Ok(Self { inner: Arc::new(StorageInner { @@ -72,16 +67,37 @@ impl Storage { }) } - /// Create storage from raw bytes with explicit dtype + /// Wrap existing device memory without taking ownership /// - /// Use this when you have raw bytes and know the dtype. - pub fn from_bytes(data: &[u8], dtype: DType, device: &R::Device) -> Result { - let len = data.len() / dtype.size_in_bytes(); - let ptr = R::allocate(data.len(), device)?; - - R::copy_to_device(data, ptr, device)?; + /// # Safety + /// - `ptr` must point to valid device memory + /// - The memory must remain valid for the lifetime of this Storage + /// - Caller is responsible for eventual deallocation + pub unsafe fn from_ptr(ptr: u64, len: usize, dtype: R::DType, device: &R::Device) -> Self { + Self { + inner: Arc::new(StorageInner { + ptr, + len, + dtype, + device: device.clone(), + owned: false, + }), + } + } - Ok(Self { + /// Wrap existing device memory and take ownership (will deallocate on drop) + /// + /// # Safety + /// - `ptr` must point to valid device memory allocated by this runtime + /// - `len` must match the actual allocation size (in elements) + /// - No other code will deallocate this memory + pub unsafe fn from_ptr_owned( + ptr: u64, + len: usize, + dtype: R::DType, + device: &R::Device, + ) -> Self { + Self { inner: Arc::new(StorageInner { ptr, len, @@ -89,27 +105,62 @@ impl Storage { device: device.clone(), owned: true, }), - }) + } } - /// Wrap existing device memory without taking ownership + /// Wrap existing device memory with explicit ownership flag /// /// # Safety /// - `ptr` must point to valid device memory - /// - The memory must remain valid for the lifetime of this Storage - /// - Caller is responsible for eventual deallocation - pub unsafe fn from_ptr(ptr: u64, len: usize, dtype: DType, device: &R::Device) -> Self { + /// - If `owned` is true, the memory must have been allocated by this runtime + /// - If `owned` is false, the memory must remain valid for the Storage's lifetime + pub unsafe fn from_raw( + ptr: u64, + len: usize, + dtype: R::DType, + device: &R::Device, + owned: bool, + ) -> Self { Self { inner: Arc::new(StorageInner { ptr, len, dtype, device: device.clone(), - owned: false, + owned, }), } } + /// Create storage from existing data with inferred dtype + /// + /// Copies `data` to the device. The dtype is inferred from the Element type. + /// Only available when the runtime uses numr's standard `DType`. + pub fn from_slice(data: &[T], device: &R::Device) -> Result + where + R: Runtime, + { + let dtype = T::DTYPE; + let len = data.len(); + + // Copy data to device + let bytes = bytemuck::cast_slice(data); + let size_bytes = bytes.len(); + let ptr = R::allocate(size_bytes, device)?; + + R::copy_to_device(bytes, ptr, device)?; + + Ok(Self { + inner: Arc::new(StorageInner { + ptr, + len, + dtype, + device: device.clone(), + owned: true, + }), + }) + } + /// Get the raw device pointer #[inline] pub fn ptr(&self) -> u64 { @@ -130,7 +181,7 @@ impl Storage { /// Get the element type #[inline] - pub fn dtype(&self) -> DType { + pub fn dtype(&self) -> R::DType { self.inner.dtype } @@ -143,7 +194,7 @@ impl Storage { /// Get size in bytes #[inline] pub fn size_in_bytes(&self) -> usize { - self.inner.len * self.inner.dtype.size_in_bytes() + self.inner.dtype.storage_bytes(self.inner.len) } /// Get the reference count @@ -158,9 +209,20 @@ impl Storage { Arc::strong_count(&self.inner) == 1 } - /// Get as raw buffer for passing to operations + /// Check if this storage owns its memory (will deallocate on drop) #[inline] - pub fn as_raw(&self) -> RawBuffer { + pub fn is_owned(&self) -> bool { + self.inner.owned + } + + /// Get as raw buffer for passing to operations. + /// + /// Only available when the runtime uses numr's standard `DType`. + #[inline] + pub fn as_raw(&self) -> RawBuffer + where + R: Runtime, + { RawBuffer { ptr: self.inner.ptr, len: self.inner.len, @@ -168,6 +230,39 @@ impl Storage { } } + /// View storage as a host slice without copying. + /// + /// # Safety + /// + /// The caller must ensure: + /// - The storage pointer is a valid host (CPU) pointer + /// - This is only safe for CPU-backed storage; calling on GPU storage is UB + /// - The returned slice borrows the storage, preventing deallocation + #[inline] + pub unsafe fn as_host_slice(&self) -> &[T] { + if self.inner.len == 0 { + return &[]; + } + let ptr = self.inner.ptr as *const T; + unsafe { std::slice::from_raw_parts(ptr, self.inner.len) } + } + + /// View storage as a mutable host slice without copying. + /// + /// # Safety + /// + /// Same as [`as_host_slice`], plus: + /// - The storage must be uniquely owned (no other references) + /// - The caller must ensure no aliasing + #[inline] + pub unsafe fn as_host_slice_mut(&mut self) -> &mut [T] { + if self.inner.len == 0 { + return &mut []; + } + let ptr = self.inner.ptr as *mut T; + unsafe { std::slice::from_raw_parts_mut(ptr, self.inner.len) } + } + /// Copy data from device to host pub fn to_vec(&self) -> Vec { // Allocate with correct alignment for T, then cast to bytes for copy. @@ -193,11 +288,7 @@ impl Clone for Storage { impl Drop for StorageInner { fn drop(&mut self) { if self.owned && self.ptr != 0 { - R::deallocate( - self.ptr, - self.len * self.dtype.size_in_bytes(), - &self.device, - ); + R::deallocate(self.ptr, self.dtype.storage_bytes(self.len), &self.device); } } } diff --git a/src/tensor/strides.rs b/src/tensor/strides.rs index 9e68c5e7..e3ed2220 100644 --- a/src/tensor/strides.rs +++ b/src/tensor/strides.rs @@ -9,7 +9,7 @@ use std::ops::{Deref, DerefMut}; /// Strides type: element offsets between consecutive elements along each dimension /// Signed to support negative strides (e.g., for flip operations) /// NOTE: Strides are in ELEMENTS, not bytes -#[derive(Clone, PartialEq, Eq, Default)] +#[derive(Clone, PartialEq, Eq, Default, Hash)] pub struct Strides(SmallVec<[isize; STACK_DIMS]>); impl Strides { diff --git a/tests/backend_parity/activation.rs b/tests/backend_parity/activation.rs new file mode 100644 index 00000000..5f7cb87f --- /dev/null +++ b/tests/backend_parity/activation.rs @@ -0,0 +1,512 @@ +// Backend parity tests for fused activation-mul operations (ActivationOps trait) +// +// Tests: silu_mul, gelu_mul, relu_mul, sigmoid_mul (forward) +// silu_mul_bwd, gelu_mul_bwd, relu_mul_bwd, sigmoid_mul_bwd (backward) +// +// Dtype-parameterized: each test runs for all supported dtypes across all backends. + +use numr::dtype::DType; +use numr::ops::ActivationOps; +use numr::runtime::Runtime; +use numr::tensor::Tensor; + +use crate::backend_parity::dtype_helpers::tensor_from_f64; +#[cfg(feature = "cuda")] +use crate::backend_parity::helpers::with_cuda_backend; +#[cfg(feature = "wgpu")] +use crate::backend_parity::helpers::with_wgpu_backend; +use crate::common::{ + assert_tensor_allclose, create_cpu_client, is_dtype_supported, supported_dtypes, +}; + +// ============================================================================ +// Test Utilities +// ============================================================================ + +#[derive(Clone)] +struct FusedTestCase { + a: Vec, + b: Vec, + shape: Vec, +} + +impl FusedTestCase { + fn new(a: Vec, b: Vec, shape: Vec) -> Self { + Self { a, b, shape } + } +} + +#[derive(Clone, Copy, Debug)] +#[allow(clippy::enum_variant_names)] +enum FusedActivationOp { + SiluMul, + GeluMul, + ReluMul, + SigmoidMul, +} + +fn apply_fused_fwd( + client: &impl ActivationOps, + op: FusedActivationOp, + a: &Tensor, + b: &Tensor, +) -> numr::error::Result> { + match op { + FusedActivationOp::SiluMul => client.silu_mul(a, b), + FusedActivationOp::GeluMul => client.gelu_mul(a, b), + FusedActivationOp::ReluMul => client.relu_mul(a, b), + FusedActivationOp::SigmoidMul => client.sigmoid_mul(a, b), + } +} + +fn apply_fused_bwd( + client: &impl ActivationOps, + op: FusedActivationOp, + grad: &Tensor, + a: &Tensor, + b: &Tensor, +) -> numr::error::Result<(Tensor, Tensor)> { + match op { + FusedActivationOp::SiluMul => client.silu_mul_bwd(grad, a, b), + FusedActivationOp::GeluMul => client.gelu_mul_bwd(grad, a, b), + FusedActivationOp::ReluMul => client.relu_mul_bwd(grad, a, b), + FusedActivationOp::SigmoidMul => client.sigmoid_mul_bwd(grad, a, b), + } +} + +// ============================================================================ +// Forward parity tests +// ============================================================================ + +fn test_fused_fwd_parity(op: FusedActivationOp, test_cases: &[FusedTestCase], dtype: DType) { + let (cpu_client, cpu_device) = create_cpu_client(); + + let cpu_results: Vec> = test_cases + .iter() + .map(|tc| { + let a = tensor_from_f64(&tc.a, &tc.shape, dtype, &cpu_device, &cpu_client) + .unwrap_or_else(|e| panic!("CPU tensor_from_f64 failed for {dtype:?}: {e}")); + let b = tensor_from_f64(&tc.b, &tc.shape, dtype, &cpu_device, &cpu_client) + .unwrap_or_else(|e| panic!("CPU tensor_from_f64 failed for {dtype:?}: {e}")); + apply_fused_fwd(&cpu_client, op, &a, &b) + .unwrap_or_else(|e| panic!("CPU {op:?} failed for {dtype:?}: {e}")) + }) + .collect(); + + #[cfg(feature = "cuda")] + if is_dtype_supported("cuda", dtype) { + with_cuda_backend(|cuda_client, cuda_device| { + for (idx, tc) in test_cases.iter().enumerate() { + let a = tensor_from_f64(&tc.a, &tc.shape, dtype, &cuda_device, &cuda_client) + .unwrap_or_else(|e| panic!("CUDA tensor_from_f64 failed for {dtype:?}: {e}")); + let b = tensor_from_f64(&tc.b, &tc.shape, dtype, &cuda_device, &cuda_client) + .unwrap_or_else(|e| panic!("CUDA tensor_from_f64 failed for {dtype:?}: {e}")); + let result = apply_fused_fwd(&cuda_client, op, &a, &b) + .unwrap_or_else(|e| panic!("CUDA {op:?} failed for {dtype:?}: {e}")); + assert_tensor_allclose( + &result, + &cpu_results[idx], + dtype, + &format!("{op:?} CUDA vs CPU [{dtype:?}] case {idx}"), + ); + } + }); + } + + #[cfg(feature = "wgpu")] + if is_dtype_supported("wgpu", dtype) { + with_wgpu_backend(|wgpu_client, wgpu_device| { + for (idx, tc) in test_cases.iter().enumerate() { + let a = tensor_from_f64(&tc.a, &tc.shape, dtype, &wgpu_device, &wgpu_client) + .unwrap_or_else(|e| panic!("WebGPU tensor_from_f64 failed for {dtype:?}: {e}")); + let b = tensor_from_f64(&tc.b, &tc.shape, dtype, &wgpu_device, &wgpu_client) + .unwrap_or_else(|e| panic!("WebGPU tensor_from_f64 failed for {dtype:?}: {e}")); + let result = apply_fused_fwd(&wgpu_client, op, &a, &b) + .unwrap_or_else(|e| panic!("WebGPU {op:?} failed for {dtype:?}: {e}")); + assert_tensor_allclose( + &result, + &cpu_results[idx], + dtype, + &format!("{op:?} WebGPU vs CPU [{dtype:?}] case {idx}"), + ); + } + }); + } +} + +// ============================================================================ +// Backward parity tests +// ============================================================================ + +fn test_fused_bwd_parity(op: FusedActivationOp, test_cases: &[FusedTestCase], dtype: DType) { + let (cpu_client, cpu_device) = create_cpu_client(); + + // Use the same data for grad as a simple ones-like pattern + let cpu_results: Vec<( + Tensor, + Tensor, + )> = test_cases + .iter() + .map(|tc| { + let a = tensor_from_f64(&tc.a, &tc.shape, dtype, &cpu_device, &cpu_client) + .unwrap_or_else(|e| panic!("CPU tensor_from_f64 failed for {dtype:?}: {e}")); + let b = tensor_from_f64(&tc.b, &tc.shape, dtype, &cpu_device, &cpu_client) + .unwrap_or_else(|e| panic!("CPU tensor_from_f64 failed for {dtype:?}: {e}")); + // Use ones as grad for simplicity + let grad_data: Vec = vec![1.0; tc.a.len()]; + let grad = tensor_from_f64(&grad_data, &tc.shape, dtype, &cpu_device, &cpu_client) + .unwrap_or_else(|e| panic!("CPU tensor_from_f64 failed for {dtype:?}: {e}")); + apply_fused_bwd(&cpu_client, op, &grad, &a, &b) + .unwrap_or_else(|e| panic!("CPU {op:?}_bwd failed for {dtype:?}: {e}")) + }) + .collect(); + + #[cfg(feature = "cuda")] + if is_dtype_supported("cuda", dtype) { + with_cuda_backend(|cuda_client, cuda_device| { + for (idx, tc) in test_cases.iter().enumerate() { + let a = tensor_from_f64(&tc.a, &tc.shape, dtype, &cuda_device, &cuda_client) + .unwrap_or_else(|e| panic!("CUDA tensor_from_f64 failed for {dtype:?}: {e}")); + let b = tensor_from_f64(&tc.b, &tc.shape, dtype, &cuda_device, &cuda_client) + .unwrap_or_else(|e| panic!("CUDA tensor_from_f64 failed for {dtype:?}: {e}")); + let grad_data: Vec = vec![1.0; tc.a.len()]; + let grad = + tensor_from_f64(&grad_data, &tc.shape, dtype, &cuda_device, &cuda_client) + .unwrap_or_else(|e| { + panic!("CUDA tensor_from_f64 failed for {dtype:?}: {e}") + }); + let (d_a, d_b) = apply_fused_bwd(&cuda_client, op, &grad, &a, &b) + .unwrap_or_else(|e| panic!("CUDA {op:?}_bwd failed for {dtype:?}: {e}")); + assert_tensor_allclose( + &d_a, + &cpu_results[idx].0, + dtype, + &format!("{op:?}_bwd d_a CUDA vs CPU [{dtype:?}] case {idx}"), + ); + assert_tensor_allclose( + &d_b, + &cpu_results[idx].1, + dtype, + &format!("{op:?}_bwd d_b CUDA vs CPU [{dtype:?}] case {idx}"), + ); + } + }); + } + + #[cfg(feature = "wgpu")] + if is_dtype_supported("wgpu", dtype) { + with_wgpu_backend(|wgpu_client, wgpu_device| { + for (idx, tc) in test_cases.iter().enumerate() { + let a = tensor_from_f64(&tc.a, &tc.shape, dtype, &wgpu_device, &wgpu_client) + .unwrap_or_else(|e| panic!("WebGPU tensor_from_f64 failed for {dtype:?}: {e}")); + let b = tensor_from_f64(&tc.b, &tc.shape, dtype, &wgpu_device, &wgpu_client) + .unwrap_or_else(|e| panic!("WebGPU tensor_from_f64 failed for {dtype:?}: {e}")); + let grad_data: Vec = vec![1.0; tc.a.len()]; + let grad = + tensor_from_f64(&grad_data, &tc.shape, dtype, &wgpu_device, &wgpu_client) + .unwrap_or_else(|e| { + panic!("WebGPU tensor_from_f64 failed for {dtype:?}: {e}") + }); + let (d_a, d_b) = apply_fused_bwd(&wgpu_client, op, &grad, &a, &b) + .unwrap_or_else(|e| panic!("WebGPU {op:?}_bwd failed for {dtype:?}: {e}")); + assert_tensor_allclose( + &d_a, + &cpu_results[idx].0, + dtype, + &format!("{op:?}_bwd d_a WebGPU vs CPU [{dtype:?}] case {idx}"), + ); + assert_tensor_allclose( + &d_b, + &cpu_results[idx].1, + dtype, + &format!("{op:?}_bwd d_b WebGPU vs CPU [{dtype:?}] case {idx}"), + ); + } + }); + } +} + +// ============================================================================ +// Test data +// ============================================================================ + +fn standard_test_cases() -> Vec { + vec![ + // Small 1D + FusedTestCase::new( + vec![-2.0, -1.0, 0.0, 1.0, 2.0], + vec![0.5, 1.0, 1.5, 2.0, 0.3], + vec![5], + ), + // 2D matrix + FusedTestCase::new( + vec![-1.0, 0.5, 1.5, -0.5, 2.0, -2.0], + vec![1.0, 2.0, 0.5, 1.5, 0.3, 1.0], + vec![2, 3], + ), + // Values near zero (important for derivative accuracy) + FusedTestCase::new( + vec![0.01, -0.01, 0.1, -0.1], + vec![1.0, 1.0, 1.0, 1.0], + vec![4], + ), + // Larger values (tests saturation behavior) + FusedTestCase::new( + vec![5.0, -5.0, 10.0, -10.0], + vec![0.1, 0.2, 0.3, 0.4], + vec![4], + ), + // Single element (edge case) + FusedTestCase::new(vec![1.5], vec![2.0], vec![1]), + // All zeros + FusedTestCase::new(vec![0.0, 0.0, 0.0, 0.0], vec![1.0, 1.0, 1.0, 1.0], vec![4]), + // Very large values (overflow risk for exp) + FusedTestCase::new( + vec![80.0, -80.0, 50.0, -50.0], + vec![1.0, 1.0, 1.0, 1.0], + vec![4], + ), + // Very small values (subnormal territory) + FusedTestCase::new( + vec![1e-7, -1e-7, 1e-6, -1e-6], + vec![1.0, 1.0, 1.0, 1.0], + vec![4], + ), + // Mixed signs in both operands + FusedTestCase::new( + vec![-3.0, 2.0, -1.0, 4.0], + vec![-1.0, -0.5, 2.0, -2.0], + vec![2, 2], + ), + ] +} + +// ============================================================================ +// Forward tests +// ============================================================================ + +#[test] +fn test_silu_mul_parity() { + let cases = standard_test_cases(); + for dtype in supported_dtypes("cpu") { + test_fused_fwd_parity(FusedActivationOp::SiluMul, &cases, dtype); + } +} + +#[test] +fn test_gelu_mul_parity() { + let cases = standard_test_cases(); + for dtype in supported_dtypes("cpu") { + test_fused_fwd_parity(FusedActivationOp::GeluMul, &cases, dtype); + } +} + +#[test] +fn test_relu_mul_parity() { + let cases = standard_test_cases(); + for dtype in supported_dtypes("cpu") { + test_fused_fwd_parity(FusedActivationOp::ReluMul, &cases, dtype); + } +} + +#[test] +fn test_sigmoid_mul_parity() { + let cases = standard_test_cases(); + for dtype in supported_dtypes("cpu") { + test_fused_fwd_parity(FusedActivationOp::SigmoidMul, &cases, dtype); + } +} + +// ============================================================================ +// Backward tests +// ============================================================================ + +#[test] +fn test_silu_mul_bwd_parity() { + let cases = standard_test_cases(); + for dtype in supported_dtypes("cpu") { + test_fused_bwd_parity(FusedActivationOp::SiluMul, &cases, dtype); + } +} + +#[test] +fn test_gelu_mul_bwd_parity() { + let cases = standard_test_cases(); + for dtype in supported_dtypes("cpu") { + test_fused_bwd_parity(FusedActivationOp::GeluMul, &cases, dtype); + } +} + +#[test] +fn test_relu_mul_bwd_parity() { + let cases = standard_test_cases(); + for dtype in supported_dtypes("cpu") { + test_fused_bwd_parity(FusedActivationOp::ReluMul, &cases, dtype); + } +} + +#[test] +fn test_sigmoid_mul_bwd_parity() { + let cases = standard_test_cases(); + for dtype in supported_dtypes("cpu") { + test_fused_bwd_parity(FusedActivationOp::SigmoidMul, &cases, dtype); + } +} + +// ============================================================================ +// Softmax parity tests +// ============================================================================ + +fn softmax_test_shapes() -> Vec<(Vec, Vec, isize)> { + vec![ + // (data, shape, dim) + (vec![1.0, 2.0, 3.0], vec![3], -1), + (vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3], -1), + (vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3], 0), + ( + (0..24).map(|i| i as f64 * 0.1 - 1.0).collect(), + vec![2, 3, 4], + 1, + ), + ( + (0..24).map(|i| i as f64 * 0.1 - 1.0).collect(), + vec![2, 3, 4], + -1, + ), + // Single element (should produce 1.0) + (vec![5.0], vec![1], -1), + // All identical values (uniform distribution) + (vec![1.0, 1.0, 1.0, 1.0], vec![4], -1), + // Very large values (overflow risk without max subtraction) + (vec![100.0, 200.0, 300.0], vec![3], -1), + // Very negative values + (vec![-100.0, -200.0, -50.0], vec![3], -1), + // Mixed extreme values (tests numerical stability) + (vec![-80.0, 0.0, 80.0], vec![3], -1), + // All zeros + (vec![0.0, 0.0, 0.0], vec![3], -1), + // 2D with dim=0 single row + (vec![1.0, 2.0, 3.0], vec![1, 3], 0), + ] +} + +fn test_softmax_parity_for_dtype(dtype: DType) { + if !is_dtype_supported("cpu", dtype) { + return; + } + + let (cpu_client, cpu_device) = create_cpu_client(); + + for (data, shape, dim) in softmax_test_shapes() { + let input_cpu = tensor_from_f64(&data, &shape, dtype, &cpu_device, &cpu_client).unwrap(); + let result_cpu = cpu_client.softmax(&input_cpu, dim).unwrap().contiguous(); + + #[cfg(feature = "wgpu")] + if is_dtype_supported("wgpu", dtype) { + with_wgpu_backend(|wgpu_client, wgpu_device| { + let input_wgpu = + tensor_from_f64(&data, &shape, dtype, &wgpu_device, &wgpu_client).unwrap(); + let result_wgpu = wgpu_client.softmax(&input_wgpu, dim).unwrap().contiguous(); + assert_tensor_allclose(&result_wgpu, &result_cpu, dtype, "softmax wgpu vs cpu"); + }); + } + + #[cfg(feature = "cuda")] + if is_dtype_supported("cuda", dtype) { + with_cuda_backend(|cuda_client, cuda_device| { + let input_cuda = + tensor_from_f64(&data, &shape, dtype, &cuda_device, &cuda_client).unwrap(); + let result_cuda = cuda_client.softmax(&input_cuda, dim).unwrap().contiguous(); + assert_tensor_allclose(&result_cuda, &result_cpu, dtype, "softmax cuda vs cpu"); + }); + } + } +} + +#[test] +fn test_softmax_parity() { + for dtype in &[DType::F32, DType::F64] { + test_softmax_parity_for_dtype(*dtype); + } +} + +fn test_softmax_bwd_parity_for_dtype(dtype: DType) { + if !is_dtype_supported("cpu", dtype) { + return; + } + + let (cpu_client, cpu_device) = create_cpu_client(); + + for (data, shape, dim) in softmax_test_shapes() { + let input_cpu = tensor_from_f64(&data, &shape, dtype, &cpu_device, &cpu_client).unwrap(); + let output_cpu = cpu_client.softmax(&input_cpu, dim).unwrap().contiguous(); + + let grad_data: Vec = (0..data.len()).map(|i| (i as f64) * 0.1 - 0.5).collect(); + let grad_cpu = + tensor_from_f64(&grad_data, &shape, dtype, &cpu_device, &cpu_client).unwrap(); + let d_input_cpu = cpu_client + .softmax_bwd(&grad_cpu, &output_cpu, dim) + .unwrap() + .contiguous(); + + // Get CPU output as f64 for creating GPU tensors + let output_f64: Vec = if dtype == DType::F64 { + output_cpu.to_vec::() + } else { + output_cpu + .to_vec::() + .iter() + .map(|&x| x as f64) + .collect() + }; + + #[cfg(feature = "wgpu")] + if is_dtype_supported("wgpu", dtype) { + with_wgpu_backend(|wgpu_client, wgpu_device| { + let output_wgpu = + tensor_from_f64(&output_f64, &shape, dtype, &wgpu_device, &wgpu_client) + .unwrap(); + let grad_wgpu = + tensor_from_f64(&grad_data, &shape, dtype, &wgpu_device, &wgpu_client).unwrap(); + let d_input_wgpu = wgpu_client + .softmax_bwd(&grad_wgpu, &output_wgpu, dim) + .unwrap() + .contiguous(); + assert_tensor_allclose( + &d_input_wgpu, + &d_input_cpu, + dtype, + "softmax_bwd wgpu vs cpu", + ); + }); + } + + #[cfg(feature = "cuda")] + if is_dtype_supported("cuda", dtype) { + with_cuda_backend(|cuda_client, cuda_device| { + let output_cuda = + tensor_from_f64(&output_f64, &shape, dtype, &cuda_device, &cuda_client) + .unwrap(); + let grad_cuda = + tensor_from_f64(&grad_data, &shape, dtype, &cuda_device, &cuda_client).unwrap(); + let d_input_cuda = cuda_client + .softmax_bwd(&grad_cuda, &output_cuda, dim) + .unwrap() + .contiguous(); + assert_tensor_allclose( + &d_input_cuda, + &d_input_cpu, + dtype, + "softmax_bwd cuda vs cpu", + ); + }); + } + } +} + +#[test] +fn test_softmax_bwd_parity() { + for dtype in &[DType::F32, DType::F64] { + test_softmax_bwd_parity_for_dtype(*dtype); + } +} diff --git a/tests/backend_parity/compare.rs b/tests/backend_parity/compare.rs index de9d9b14..8a05aacf 100644 --- a/tests/backend_parity/compare.rs +++ b/tests/backend_parity/compare.rs @@ -60,7 +60,7 @@ fn apply_compare_op( /// Read back a compare result as Vec regardless of backend output dtype. /// Some backends return Bool (u8), some U32, some keep the input dtype /// where nonzero = true, zero = false. -fn readback_as_u32(tensor: &Tensor) -> Vec { +fn readback_as_u32>(tensor: &Tensor) -> Vec { use crate::common::ToF64; macro_rules! via_f64 { diff --git a/tests/backend_parity/conditional.rs b/tests/backend_parity/conditional.rs new file mode 100644 index 00000000..737ff5a0 --- /dev/null +++ b/tests/backend_parity/conditional.rs @@ -0,0 +1,259 @@ +// Backend parity tests for ConditionalOps trait (where_cond) +// +// Dtype-parameterized: each test runs for all supported dtypes. +// CPU is the reference implementation; CUDA and WebGPU must match. + +use numr::dtype::DType; +use numr::ops::{CompareOps, ConditionalOps}; +use numr::tensor::Tensor; + +use crate::backend_parity::dtype_helpers::tensor_from_f64; +#[cfg(feature = "cuda")] +use crate::backend_parity::helpers::with_cuda_backend; +#[cfg(feature = "wgpu")] +use crate::backend_parity::helpers::with_wgpu_backend; +use crate::common::{ + assert_tensor_allclose, create_cpu_client, is_dtype_supported, supported_dtypes, +}; + +struct WhereTestCase { + cond: Vec, + cond_shape: Vec, + x: Vec, + x_shape: Vec, + y: Vec, + y_shape: Vec, +} + +impl WhereTestCase { + fn new( + cond: Vec, + cond_shape: Vec, + x: Vec, + x_shape: Vec, + y: Vec, + y_shape: Vec, + ) -> Self { + Self { + cond, + cond_shape, + x, + x_shape, + y, + y_shape, + } + } +} + +fn test_where_cond_parity(test_cases: &[WhereTestCase], dtype: DType) { + let (cpu_client, cpu_device) = create_cpu_client(); + + let cpu_results: Vec> = test_cases + .iter() + .map(|tc| { + let cond = tensor_from_f64(&tc.cond, &tc.cond_shape, dtype, &cpu_device, &cpu_client) + .unwrap_or_else(|e| panic!("CPU cond tensor failed for {dtype:?}: {e}")); + let x = tensor_from_f64(&tc.x, &tc.x_shape, dtype, &cpu_device, &cpu_client) + .unwrap_or_else(|e| panic!("CPU x tensor failed for {dtype:?}: {e}")); + let y = tensor_from_f64(&tc.y, &tc.y_shape, dtype, &cpu_device, &cpu_client) + .unwrap_or_else(|e| panic!("CPU y tensor failed for {dtype:?}: {e}")); + + cpu_client + .where_cond(&cond, &x, &y) + .unwrap_or_else(|e| panic!("CPU where_cond failed for {dtype:?}: {e}")) + }) + .collect(); + + #[cfg(feature = "cuda")] + if is_dtype_supported("cuda", dtype) { + with_cuda_backend(|cuda_client, cuda_device| { + for (idx, tc) in test_cases.iter().enumerate() { + let cond = + tensor_from_f64(&tc.cond, &tc.cond_shape, dtype, &cuda_device, &cuda_client) + .unwrap_or_else(|e| panic!("CUDA cond tensor failed for {dtype:?}: {e}")); + let x = tensor_from_f64(&tc.x, &tc.x_shape, dtype, &cuda_device, &cuda_client) + .unwrap_or_else(|e| panic!("CUDA x tensor failed for {dtype:?}: {e}")); + let y = tensor_from_f64(&tc.y, &tc.y_shape, dtype, &cuda_device, &cuda_client) + .unwrap_or_else(|e| panic!("CUDA y tensor failed for {dtype:?}: {e}")); + + let result = cuda_client + .where_cond(&cond, &x, &y) + .unwrap_or_else(|e| panic!("CUDA where_cond failed for {dtype:?}: {e}")); + + assert_tensor_allclose( + &result, + &cpu_results[idx], + dtype, + &format!("where_cond CUDA vs CPU [{dtype:?}] case {idx}"), + ); + } + }); + } + + #[cfg(feature = "wgpu")] + if is_dtype_supported("wgpu", dtype) { + with_wgpu_backend(|wgpu_client, wgpu_device| { + for (idx, tc) in test_cases.iter().enumerate() { + let cond = + tensor_from_f64(&tc.cond, &tc.cond_shape, dtype, &wgpu_device, &wgpu_client) + .unwrap_or_else(|e| panic!("WebGPU cond tensor failed for {dtype:?}: {e}")); + let x = tensor_from_f64(&tc.x, &tc.x_shape, dtype, &wgpu_device, &wgpu_client) + .unwrap_or_else(|e| panic!("WebGPU x tensor failed for {dtype:?}: {e}")); + let y = tensor_from_f64(&tc.y, &tc.y_shape, dtype, &wgpu_device, &wgpu_client) + .unwrap_or_else(|e| panic!("WebGPU y tensor failed for {dtype:?}: {e}")); + + let result = wgpu_client + .where_cond(&cond, &x, &y) + .unwrap_or_else(|e| panic!("WebGPU where_cond failed for {dtype:?}: {e}")); + + assert_tensor_allclose( + &result, + &cpu_results[idx], + dtype, + &format!("where_cond WebGPU vs CPU [{dtype:?}] case {idx}"), + ); + } + }); + } +} + +fn where_test_cases() -> Vec { + vec![ + // 1D: simple mask + WhereTestCase::new( + vec![1.0, 0.0, 1.0, 0.0], + vec![4], + vec![10.0, 20.0, 30.0, 40.0], + vec![4], + vec![100.0, 200.0, 300.0, 400.0], + vec![4], + ), + // 2D: all true + WhereTestCase::new( + vec![1.0, 1.0, 1.0, 1.0], + vec![2, 2], + vec![1.0, 2.0, 3.0, 4.0], + vec![2, 2], + vec![5.0, 6.0, 7.0, 8.0], + vec![2, 2], + ), + // 2D: all false + WhereTestCase::new( + vec![0.0, 0.0, 0.0, 0.0], + vec![2, 2], + vec![1.0, 2.0, 3.0, 4.0], + vec![2, 2], + vec![5.0, 6.0, 7.0, 8.0], + vec![2, 2], + ), + // 1D: alternating + WhereTestCase::new( + vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0], + vec![6], + vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0], + vec![6], + vec![100.0, 200.0, 300.0, 400.0, 500.0, 600.0], + vec![6], + ), + // 3D tensor + WhereTestCase::new( + vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0], + vec![2, 2, 2], + vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], + vec![2, 2, 2], + vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0], + vec![2, 2, 2], + ), + ] +} + +#[test] +fn test_where_cond_parity_all_dtypes() { + let cases = where_test_cases(); + for dtype in supported_dtypes("cpu") { + test_where_cond_parity(&cases, dtype); + } +} + +// Test where_cond with condition from comparison ops +#[test] +fn test_where_cond_from_compare_parity() { + let (cpu_client, cpu_device) = create_cpu_client(); + let dtype = DType::F32; + + let a = tensor_from_f64(&[1.0, 5.0, 3.0, 7.0], &[4], dtype, &cpu_device, &cpu_client) + .expect("tensor creation failed"); + let threshold = tensor_from_f64(&[3.0, 3.0, 3.0, 3.0], &[4], dtype, &cpu_device, &cpu_client) + .expect("tensor creation failed"); + let x = tensor_from_f64( + &[10.0, 20.0, 30.0, 40.0], + &[4], + dtype, + &cpu_device, + &cpu_client, + ) + .expect("tensor creation failed"); + let y = tensor_from_f64( + &[100.0, 200.0, 300.0, 400.0], + &[4], + dtype, + &cpu_device, + &cpu_client, + ) + .expect("tensor creation failed"); + + let mask = cpu_client.gt(&a, &threshold).expect("gt failed"); + let _cpu_result = cpu_client + .where_cond(&mask, &x, &y) + .expect("where_cond failed"); + + #[cfg(feature = "wgpu")] + { + with_wgpu_backend(|wgpu_client, wgpu_device| { + let a_w = tensor_from_f64( + &[1.0, 5.0, 3.0, 7.0], + &[4], + dtype, + &wgpu_device, + &wgpu_client, + ) + .expect("tensor creation failed"); + let t_w = tensor_from_f64( + &[3.0, 3.0, 3.0, 3.0], + &[4], + dtype, + &wgpu_device, + &wgpu_client, + ) + .expect("tensor creation failed"); + let x_w = tensor_from_f64( + &[10.0, 20.0, 30.0, 40.0], + &[4], + dtype, + &wgpu_device, + &wgpu_client, + ) + .expect("tensor creation failed"); + let y_w = tensor_from_f64( + &[100.0, 200.0, 300.0, 400.0], + &[4], + dtype, + &wgpu_device, + &wgpu_client, + ) + .expect("tensor creation failed"); + + let mask_w = wgpu_client.gt(&a_w, &t_w).expect("gt failed"); + let result = wgpu_client + .where_cond(&mask_w, &x_w, &y_w) + .expect("where_cond failed"); + + assert_tensor_allclose( + &result, + &_cpu_result, + dtype, + "where_cond(gt mask) WebGPU vs CPU", + ); + }); + } +} diff --git a/tests/backend_parity/conv.rs b/tests/backend_parity/conv.rs index f658f894..ccae8f11 100644 --- a/tests/backend_parity/conv.rs +++ b/tests/backend_parity/conv.rs @@ -3,10 +3,7 @@ // Dtype-parameterized: each test runs for all supported dtypes across all backends. // Comparison reads back in native dtype via assert_tensor_allclose. -use numr::dtype::DType; use numr::ops::{ConvOps, PaddingMode}; -use numr::runtime::cpu::CpuRuntime; -use numr::tensor::Tensor; use crate::backend_parity::dtype_helpers::tensor_from_f64; #[cfg(feature = "cuda")] diff --git a/tests/backend_parity/distance.rs b/tests/backend_parity/distance.rs new file mode 100644 index 00000000..7e98d015 --- /dev/null +++ b/tests/backend_parity/distance.rs @@ -0,0 +1,306 @@ +// Backend parity tests for DistanceOps trait +// +// Tests: cdist, pdist, squareform, squareform_inverse +// CPU is the reference implementation; CUDA and WebGPU must match. + +use numr::dtype::DType; +use numr::ops::{DistanceMetric, DistanceOps}; + +use crate::backend_parity::dtype_helpers::tensor_from_f64; +#[cfg(feature = "cuda")] +use crate::backend_parity::helpers::with_cuda_backend; +#[cfg(feature = "wgpu")] +use crate::backend_parity::helpers::with_wgpu_backend; +use crate::common::{ + assert_tensor_allclose, create_cpu_client, is_dtype_supported, supported_dtypes, +}; + +// ============================================================================ +// cdist +// ============================================================================ + +struct CdistCase { + x: Vec, + x_shape: Vec, + y: Vec, + y_shape: Vec, + metric: DistanceMetric, +} + +impl CdistCase { + fn new( + x: Vec, + x_shape: Vec, + y: Vec, + y_shape: Vec, + metric: DistanceMetric, + ) -> Self { + Self { + x, + x_shape, + y, + y_shape, + metric, + } + } +} + +fn cdist_test_cases() -> Vec { + // Points in 2D + let x = vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0]; // 3 points in 2D + let y = vec![1.0, 1.0, 2.0, 0.0]; // 2 points in 2D + + vec![ + CdistCase::new( + x.clone(), + vec![3, 2], + y.clone(), + vec![2, 2], + DistanceMetric::Euclidean, + ), + CdistCase::new( + x.clone(), + vec![3, 2], + y.clone(), + vec![2, 2], + DistanceMetric::SquaredEuclidean, + ), + CdistCase::new( + x.clone(), + vec![3, 2], + y.clone(), + vec![2, 2], + DistanceMetric::Manhattan, + ), + CdistCase::new( + x.clone(), + vec![3, 2], + y.clone(), + vec![2, 2], + DistanceMetric::Chebyshev, + ), + // 3D points + CdistCase::new( + vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + vec![2, 3], + vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0], + vec![3, 3], + DistanceMetric::Euclidean, + ), + ] +} + +fn test_cdist_parity(dtype: DType) { + let cases = cdist_test_cases(); + let (cpu_client, cpu_device) = create_cpu_client(); + + for (idx, tc) in cases.iter().enumerate() { + let cpu_x = tensor_from_f64(&tc.x, &tc.x_shape, dtype, &cpu_device, &cpu_client) + .expect("CPU x tensor failed"); + let cpu_y = tensor_from_f64(&tc.y, &tc.y_shape, dtype, &cpu_device, &cpu_client) + .expect("CPU y tensor failed"); + let cpu_result = cpu_client + .cdist(&cpu_x, &cpu_y, tc.metric) + .unwrap_or_else(|e| panic!("CPU cdist {:?} failed for {dtype:?}: {e}", tc.metric)); + + #[cfg(feature = "cuda")] + if is_dtype_supported("cuda", dtype) { + with_cuda_backend(|cuda_client, cuda_device| { + let x = tensor_from_f64(&tc.x, &tc.x_shape, dtype, &cuda_device, &cuda_client) + .expect("CUDA x tensor failed"); + let y = tensor_from_f64(&tc.y, &tc.y_shape, dtype, &cuda_device, &cuda_client) + .expect("CUDA y tensor failed"); + let result = cuda_client + .cdist(&x, &y, tc.metric) + .unwrap_or_else(|e| panic!("CUDA cdist failed: {e}")); + assert_tensor_allclose( + &result, + &cpu_result, + dtype, + &format!("cdist {:?} CUDA vs CPU [{dtype:?}] case {idx}", tc.metric), + ); + }); + } + + #[cfg(feature = "wgpu")] + if is_dtype_supported("wgpu", dtype) { + with_wgpu_backend(|wgpu_client, wgpu_device| { + let x = tensor_from_f64(&tc.x, &tc.x_shape, dtype, &wgpu_device, &wgpu_client) + .expect("WebGPU x tensor failed"); + let y = tensor_from_f64(&tc.y, &tc.y_shape, dtype, &wgpu_device, &wgpu_client) + .expect("WebGPU y tensor failed"); + let result = wgpu_client + .cdist(&x, &y, tc.metric) + .unwrap_or_else(|e| panic!("WebGPU cdist failed: {e}")); + assert_tensor_allclose( + &result, + &cpu_result, + dtype, + &format!("cdist {:?} WebGPU vs CPU [{dtype:?}] case {idx}", tc.metric), + ); + }); + } + } +} + +#[test] +fn test_cdist_parity_all_dtypes() { + for dtype in supported_dtypes("cpu") { + test_cdist_parity(dtype); + } +} + +// ============================================================================ +// pdist +// ============================================================================ + +fn test_pdist_parity(dtype: DType) { + let (cpu_client, cpu_device) = create_cpu_client(); + + // 4 points in 2D + let data = vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0]; + let shape = vec![4, 2]; + + let metrics = vec![ + DistanceMetric::Euclidean, + DistanceMetric::SquaredEuclidean, + DistanceMetric::Manhattan, + DistanceMetric::Chebyshev, + ]; + + for metric in &metrics { + let cpu_x = tensor_from_f64(&data, &shape, dtype, &cpu_device, &cpu_client) + .expect("CPU tensor failed"); + let cpu_result = cpu_client + .pdist(&cpu_x, *metric) + .unwrap_or_else(|e| panic!("CPU pdist {metric:?} failed: {e}")); + + #[cfg(feature = "cuda")] + if is_dtype_supported("cuda", dtype) { + with_cuda_backend(|cuda_client, cuda_device| { + let x = tensor_from_f64(&data, &shape, dtype, &cuda_device, &cuda_client) + .expect("CUDA tensor failed"); + let result = cuda_client + .pdist(&x, *metric) + .unwrap_or_else(|e| panic!("CUDA pdist failed: {e}")); + assert_tensor_allclose( + &result, + &cpu_result, + dtype, + &format!("pdist {metric:?} CUDA vs CPU [{dtype:?}]"), + ); + }); + } + + #[cfg(feature = "wgpu")] + if is_dtype_supported("wgpu", dtype) { + with_wgpu_backend(|wgpu_client, wgpu_device| { + let x = tensor_from_f64(&data, &shape, dtype, &wgpu_device, &wgpu_client) + .expect("WebGPU tensor failed"); + let result = wgpu_client + .pdist(&x, *metric) + .unwrap_or_else(|e| panic!("WebGPU pdist failed: {e}")); + assert_tensor_allclose( + &result, + &cpu_result, + dtype, + &format!("pdist {metric:?} WebGPU vs CPU [{dtype:?}]"), + ); + }); + } + } +} + +#[test] +fn test_pdist_parity_all_dtypes() { + for dtype in supported_dtypes("cpu") { + test_pdist_parity(dtype); + } +} + +// ============================================================================ +// squareform roundtrip +// ============================================================================ + +#[test] +fn test_squareform_roundtrip_parity() { + let dtype = DType::F32; + let (cpu_client, cpu_device) = create_cpu_client(); + + // 4 points in 2D + let data = vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0]; + let shape = vec![4, 2]; + let n = 4usize; + + let cpu_x = + tensor_from_f64(&data, &shape, dtype, &cpu_device, &cpu_client).expect("tensor failed"); + let cpu_condensed = cpu_client + .pdist(&cpu_x, DistanceMetric::Euclidean) + .expect("pdist failed"); + let cpu_square = cpu_client + .squareform(&cpu_condensed, n) + .expect("squareform failed"); + let cpu_back = cpu_client + .squareform_inverse(&cpu_square) + .expect("squareform_inverse failed"); + + // Verify roundtrip: condensed -> square -> condensed + assert_tensor_allclose(&cpu_back, &cpu_condensed, dtype, "squareform roundtrip CPU"); + + #[cfg(feature = "wgpu")] + with_wgpu_backend(|wgpu_client, wgpu_device| { + let x = tensor_from_f64(&data, &shape, dtype, &wgpu_device, &wgpu_client) + .expect("tensor failed"); + let condensed = wgpu_client + .pdist(&x, DistanceMetric::Euclidean) + .expect("pdist failed"); + let square = wgpu_client + .squareform(&condensed, n) + .expect("squareform failed"); + + assert_tensor_allclose(&square, &cpu_square, dtype, "squareform WebGPU vs CPU"); + + let back = wgpu_client + .squareform_inverse(&square) + .expect("squareform_inverse failed"); + assert_tensor_allclose( + &back, + &cpu_condensed, + dtype, + "squareform_inverse WebGPU vs CPU", + ); + }); +} + +// ============================================================================ +// cosine distance +// ============================================================================ + +#[test] +fn test_cdist_cosine_parity() { + let dtype = DType::F32; + let (cpu_client, cpu_device) = create_cpu_client(); + + let x = vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0]; // 3 points in 2D + let y = vec![1.0, 0.0, 0.0, 1.0]; // 2 points in 2D + + let cpu_x = + tensor_from_f64(&x, &[3, 2], dtype, &cpu_device, &cpu_client).expect("tensor failed"); + let cpu_y = + tensor_from_f64(&y, &[2, 2], dtype, &cpu_device, &cpu_client).expect("tensor failed"); + let _cpu_result = cpu_client + .cdist(&cpu_x, &cpu_y, DistanceMetric::Cosine) + .expect("CPU cosine cdist failed"); + + #[cfg(feature = "wgpu")] + with_wgpu_backend(|wgpu_client, wgpu_device| { + let wx = + tensor_from_f64(&x, &[3, 2], dtype, &wgpu_device, &wgpu_client).expect("tensor failed"); + let wy = + tensor_from_f64(&y, &[2, 2], dtype, &wgpu_device, &wgpu_client).expect("tensor failed"); + let result = wgpu_client + .cdist(&wx, &wy, DistanceMetric::Cosine) + .expect("WebGPU cosine cdist failed"); + assert_tensor_allclose(&result, &_cpu_result, dtype, "cdist Cosine WebGPU vs CPU"); + }); +} diff --git a/tests/backend_parity/dtype_helpers.rs b/tests/backend_parity/dtype_helpers.rs index 592940e6..4e44ee6e 100644 --- a/tests/backend_parity/dtype_helpers.rs +++ b/tests/backend_parity/dtype_helpers.rs @@ -50,7 +50,7 @@ use numr::tensor::Tensor; /// let tensor = tensor_from_f64(&data, &[2, 2], DType::F32, &device, &client)?; /// assert_eq!(tensor.dtype(), DType::F32); /// ``` -pub fn tensor_from_f64( +pub fn tensor_from_f64>( data: &[f64], shape: &[usize], dtype: DType, @@ -89,7 +89,7 @@ pub fn tensor_from_f64( /// let tensor = tensor_from_f32(&[1.0, 2.0], &[2], DType::F16, &device, &client)?; /// assert_eq!(tensor.dtype(), DType::F16); /// ``` -pub fn tensor_from_f32( +pub fn tensor_from_f32>( data: &[f32], shape: &[usize], dtype: DType, @@ -116,7 +116,7 @@ pub fn tensor_from_f32( /// let tensor = tensor_from_i32(&[1, 2, 3], &[3], DType::U32, &device, &client)?; /// assert_eq!(tensor.dtype(), DType::U32); /// ``` -pub fn tensor_from_i32( +pub fn tensor_from_i32>( data: &[i32], shape: &[usize], dtype: DType, diff --git a/tests/backend_parity/fp8_matmul.rs b/tests/backend_parity/fp8_matmul.rs new file mode 100644 index 00000000..35b0c0ba --- /dev/null +++ b/tests/backend_parity/fp8_matmul.rs @@ -0,0 +1,360 @@ +//! Backend parity tests for FP8 matrix multiplication operations. +//! +//! Tests verify that CUDA FP8 matmul produces results matching CPU reference +//! (cast FP8→F32, matmul, scale, cast to output dtype) within FP tolerance. + +use crate::common::create_cpu_client; +use numr::dtype::DType; +use numr::ops::{Fp8MatmulOps, TypeConversionOps}; +use numr::runtime::cpu::CpuRuntime; +use numr::tensor::Tensor; + +/// Create FP8E4M3 tensor from f32 data on the given backend. +fn create_fp8e4m3_tensor>( + data: &[f32], + shape: &[usize], + device: &R::Device, + client: &impl TypeConversionOps, +) -> numr::error::Result> { + let f32_tensor = Tensor::from_slice(data, shape, device); + client.cast(&f32_tensor, DType::FP8E4M3) +} + +/// Create FP8E5M2 tensor from f32 data on the given backend. +fn create_fp8e5m2_tensor>( + data: &[f32], + shape: &[usize], + device: &R::Device, + client: &impl TypeConversionOps, +) -> numr::error::Result> { + let f32_tensor = Tensor::from_slice(data, shape, device); + client.cast(&f32_tensor, DType::FP8E5M2) +} + +/// Compare f32 results with relaxed tolerance for FP8 (limited precision). +fn assert_fp8_parity(cpu: &[f32], other: &[f32], op: &str) { + let rtol = 0.1f32; // FP8 has very low precision, ~10% relative tolerance + let atol = 0.5f32; // Absolute tolerance for small values + assert_eq!( + cpu.len(), + other.len(), + "fp8_parity[{}]: length mismatch: {} vs {}", + op, + cpu.len(), + other.len() + ); + for (i, (c, o)) in cpu.iter().zip(other.iter()).enumerate() { + let diff = (c - o).abs(); + let tol = atol + rtol * c.abs(); + if diff > tol { + panic!( + "fp8_parity[{}] at index {}: cpu={} vs other={} (diff={}, tol={})", + op, i, c, o, diff, tol + ); + } + } +} + +// ============================================================================ +// CPU Tests (baseline) +// ============================================================================ + +#[test] +fn test_fp8_matmul_e4m3_cpu_f32_output() { + let (client, device) = create_cpu_client(); + // Small values to stay within FP8E4M3 range + let a_data: Vec = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; + let b_data: Vec = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; + + let a = create_fp8e4m3_tensor::(&a_data, &[2, 3], &device, &client).unwrap(); + let b = create_fp8e4m3_tensor::(&b_data, &[3, 2], &device, &client).unwrap(); + + let result = client.fp8_matmul(&a, &b, 1.0, 1.0, DType::F32).unwrap(); + assert_eq!(result.dtype(), DType::F32); + assert_eq!(result.shape(), &[2, 2]); + + let vals = result.to_vec::(); + // Expected: [1*1+2*3+3*5, 1*2+2*4+3*6, 4*1+5*3+6*5, 4*2+5*4+6*6] + // = [22, 28, 49, 64] + assert_fp8_parity(&[22.0, 28.0, 49.0, 64.0], &vals, "fp8_e4m3_cpu_f32"); +} + +#[test] +fn test_fp8_matmul_e4m3_cpu_with_scaling() { + let (client, device) = create_cpu_client(); + let a_data: Vec = vec![1.0, 2.0, 3.0, 4.0]; + let b_data: Vec = vec![1.0, 2.0, 3.0, 4.0]; + + let a = create_fp8e4m3_tensor::(&a_data, &[2, 2], &device, &client).unwrap(); + let b = create_fp8e4m3_tensor::(&b_data, &[2, 2], &device, &client).unwrap(); + + let result = client.fp8_matmul(&a, &b, 2.0, 0.5, DType::F32).unwrap(); + let vals = result.to_vec::(); + // scale_a * scale_b = 1.0, so same as unscaled + // [1*1+2*3, 1*2+2*4, 3*1+4*3, 3*2+4*4] = [7, 10, 15, 22] + assert_fp8_parity(&[7.0, 10.0, 15.0, 22.0], &vals, "fp8_e4m3_cpu_scaled"); +} + +#[test] +fn test_fp8_matmul_e5m2_cpu() { + let (client, device) = create_cpu_client(); + let a_data: Vec = vec![1.0, 2.0, 3.0, 4.0]; + let b_data: Vec = vec![1.0, 2.0, 3.0, 4.0]; + + let a = create_fp8e5m2_tensor::(&a_data, &[2, 2], &device, &client).unwrap(); + let b = create_fp8e4m3_tensor::(&b_data, &[2, 2], &device, &client).unwrap(); + + let result = client + .fp8_matmul_e5m2(&a, &b, 1.0, 1.0, DType::F32) + .unwrap(); + assert_eq!(result.dtype(), DType::F32); + assert_eq!(result.shape(), &[2, 2]); +} + +#[test] +fn test_fp8_matmul_e4m3_cpu_f16_output() { + let (client, device) = create_cpu_client(); + let a_data: Vec = vec![1.0, 2.0, 3.0, 4.0]; + let b_data: Vec = vec![1.0, 0.0, 0.0, 1.0]; + + let a = create_fp8e4m3_tensor::(&a_data, &[2, 2], &device, &client).unwrap(); + let b = create_fp8e4m3_tensor::(&b_data, &[2, 2], &device, &client).unwrap(); + + let result = client.fp8_matmul(&a, &b, 1.0, 1.0, DType::F16).unwrap(); + assert_eq!(result.dtype(), DType::F16); + assert_eq!(result.shape(), &[2, 2]); +} + +#[test] +fn test_fp8_matmul_e4m3_cpu_bf16_output() { + let (client, device) = create_cpu_client(); + let a_data: Vec = vec![1.0, 2.0, 3.0, 4.0]; + let b_data: Vec = vec![1.0, 0.0, 0.0, 1.0]; + + let a = create_fp8e4m3_tensor::(&a_data, &[2, 2], &device, &client).unwrap(); + let b = create_fp8e4m3_tensor::(&b_data, &[2, 2], &device, &client).unwrap(); + + let result = client.fp8_matmul(&a, &b, 1.0, 1.0, DType::BF16).unwrap(); + assert_eq!(result.dtype(), DType::BF16); + assert_eq!(result.shape(), &[2, 2]); +} + +#[test] +fn test_fp8_matmul_dtype_validation() { + let (client, device) = create_cpu_client(); + let a = Tensor::::from_slice(&[1.0f32, 2.0], &[1, 2], &device); + let b_data: Vec = vec![1.0, 2.0]; + let b = create_fp8e4m3_tensor::(&b_data, &[2, 1], &device, &client).unwrap(); + + // a is F32, not FP8E4M3 — should fail + let result = client.fp8_matmul(&a, &b, 1.0, 1.0, DType::F32); + assert!(result.is_err()); +} + +#[test] +fn test_fp8_matmul_invalid_output_dtype() { + let (client, device) = create_cpu_client(); + let a_data: Vec = vec![1.0, 2.0]; + let b_data: Vec = vec![1.0, 2.0]; + + let a = create_fp8e4m3_tensor::(&a_data, &[1, 2], &device, &client).unwrap(); + let b = create_fp8e4m3_tensor::(&b_data, &[2, 1], &device, &client).unwrap(); + + // I32 is not a valid output dtype for FP8 matmul + let result = client.fp8_matmul(&a, &b, 1.0, 1.0, DType::I32); + assert!(result.is_err()); +} + +// ============================================================================ +// CUDA Parity Tests +// ============================================================================ + +#[cfg(feature = "cuda")] +mod cuda_parity { + use super::*; + use crate::backend_parity::helpers::with_cuda_backend; + use numr::ops::TypeConversionOps; + use numr::runtime::cuda::CudaRuntime; + + #[test] + fn test_fp8_matmul_e4m3_cuda_parity_f32() { + let (cpu_client, cpu_device) = create_cpu_client(); + with_cuda_backend(|cuda_client, cuda_device| { + let a_data: Vec = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; + let b_data: Vec = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; + + let a_cpu = + create_fp8e4m3_tensor::(&a_data, &[2, 3], &cpu_device, &cpu_client) + .unwrap(); + let b_cpu = + create_fp8e4m3_tensor::(&b_data, &[3, 2], &cpu_device, &cpu_client) + .unwrap(); + let cpu_result = cpu_client + .fp8_matmul(&a_cpu, &b_cpu, 1.0, 1.0, DType::F32) + .unwrap(); + + let a_cuda = + create_fp8e4m3_tensor::(&a_data, &[2, 3], &cuda_device, &cuda_client) + .unwrap(); + let b_cuda = + create_fp8e4m3_tensor::(&b_data, &[3, 2], &cuda_device, &cuda_client) + .unwrap(); + let cuda_result = cuda_client + .fp8_matmul(&a_cuda, &b_cuda, 1.0, 1.0, DType::F32) + .unwrap(); + + let cpu_vals = cpu_result.to_vec::(); + let cuda_f32 = cuda_client.cast(&cuda_result, DType::F32).unwrap(); + let cuda_vals = cuda_f32.to_vec::(); + assert_fp8_parity(&cpu_vals, &cuda_vals, "fp8_e4m3_cuda_f32"); + }); + } + + #[test] + fn test_fp8_matmul_e4m3_cuda_parity_with_scaling() { + let (cpu_client, cpu_device) = create_cpu_client(); + with_cuda_backend(|cuda_client, cuda_device| { + let a_data: Vec = vec![1.0, 2.0, 3.0, 4.0]; + let b_data: Vec = vec![1.0, 2.0, 3.0, 4.0]; + let scale_a = 2.0f32; + let scale_b = 0.5f32; + + let a_cpu = + create_fp8e4m3_tensor::(&a_data, &[2, 2], &cpu_device, &cpu_client) + .unwrap(); + let b_cpu = + create_fp8e4m3_tensor::(&b_data, &[2, 2], &cpu_device, &cpu_client) + .unwrap(); + let cpu_result = cpu_client + .fp8_matmul(&a_cpu, &b_cpu, scale_a, scale_b, DType::F32) + .unwrap(); + + let a_cuda = + create_fp8e4m3_tensor::(&a_data, &[2, 2], &cuda_device, &cuda_client) + .unwrap(); + let b_cuda = + create_fp8e4m3_tensor::(&b_data, &[2, 2], &cuda_device, &cuda_client) + .unwrap(); + let cuda_result = cuda_client + .fp8_matmul(&a_cuda, &b_cuda, scale_a, scale_b, DType::F32) + .unwrap(); + + let cpu_vals = cpu_result.to_vec::(); + let cuda_f32 = cuda_client.cast(&cuda_result, DType::F32).unwrap(); + let cuda_vals = cuda_f32.to_vec::(); + assert_fp8_parity(&cpu_vals, &cuda_vals, "fp8_e4m3_cuda_scaled"); + }); + } + + #[test] + fn test_fp8_matmul_e5m2_cuda_parity() { + let (cpu_client, cpu_device) = create_cpu_client(); + with_cuda_backend(|cuda_client, cuda_device| { + let a_data: Vec = vec![1.0, 2.0, 3.0, 4.0]; + let b_data: Vec = vec![1.0, 2.0, 3.0, 4.0]; + + let a_cpu = + create_fp8e5m2_tensor::(&a_data, &[2, 2], &cpu_device, &cpu_client) + .unwrap(); + let b_cpu = + create_fp8e4m3_tensor::(&b_data, &[2, 2], &cpu_device, &cpu_client) + .unwrap(); + let cpu_result = cpu_client + .fp8_matmul_e5m2(&a_cpu, &b_cpu, 1.0, 1.0, DType::F32) + .unwrap(); + + let a_cuda = + create_fp8e5m2_tensor::(&a_data, &[2, 2], &cuda_device, &cuda_client) + .unwrap(); + let b_cuda = + create_fp8e4m3_tensor::(&b_data, &[2, 2], &cuda_device, &cuda_client) + .unwrap(); + let cuda_result = cuda_client + .fp8_matmul_e5m2(&a_cuda, &b_cuda, 1.0, 1.0, DType::F32) + .unwrap(); + + let cpu_vals = cpu_result.to_vec::(); + let cuda_f32 = cuda_client.cast(&cuda_result, DType::F32).unwrap(); + let cuda_vals = cuda_f32.to_vec::(); + assert_fp8_parity(&cpu_vals, &cuda_vals, "fp8_e5m2_cuda"); + }); + } + + #[test] + fn test_fp8_matmul_e4m3_cuda_parity_f16_output() { + let (cpu_client, cpu_device) = create_cpu_client(); + with_cuda_backend(|cuda_client, cuda_device| { + let a_data: Vec = vec![1.0, 2.0, 3.0, 4.0]; + let b_data: Vec = vec![1.0, 0.0, 0.0, 1.0]; + + let a_cpu = + create_fp8e4m3_tensor::(&a_data, &[2, 2], &cpu_device, &cpu_client) + .unwrap(); + let b_cpu = + create_fp8e4m3_tensor::(&b_data, &[2, 2], &cpu_device, &cpu_client) + .unwrap(); + let cpu_result = cpu_client + .fp8_matmul(&a_cpu, &b_cpu, 1.0, 1.0, DType::F16) + .unwrap(); + let cpu_f32 = cpu_client.cast(&cpu_result, DType::F32).unwrap(); + + let a_cuda = + create_fp8e4m3_tensor::(&a_data, &[2, 2], &cuda_device, &cuda_client) + .unwrap(); + let b_cuda = + create_fp8e4m3_tensor::(&b_data, &[2, 2], &cuda_device, &cuda_client) + .unwrap(); + let cuda_result = cuda_client + .fp8_matmul(&a_cuda, &b_cuda, 1.0, 1.0, DType::F16) + .unwrap(); + let cuda_f32 = cuda_client.cast(&cuda_result, DType::F32).unwrap(); + + let cpu_vals = cpu_f32.to_vec::(); + let cuda_vals = cuda_f32.to_vec::(); + assert_fp8_parity(&cpu_vals, &cuda_vals, "fp8_e4m3_cuda_f16"); + }); + } + + #[test] + fn test_fp8_matmul_e4m3_cuda_batched_parity() { + let (cpu_client, cpu_device) = create_cpu_client(); + with_cuda_backend(|cuda_client, cuda_device| { + // [2, 2, 2] x [2, 2, 2] batched matmul + let a_data: Vec = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; + let b_data: Vec = vec![1.0, 0.0, 0.0, 1.0, 1.0, 2.0, 3.0, 4.0]; + + let a_cpu = + create_fp8e4m3_tensor::(&a_data, &[2, 2, 2], &cpu_device, &cpu_client) + .unwrap(); + let b_cpu = + create_fp8e4m3_tensor::(&b_data, &[2, 2, 2], &cpu_device, &cpu_client) + .unwrap(); + let cpu_result = cpu_client + .fp8_matmul(&a_cpu, &b_cpu, 1.0, 1.0, DType::F32) + .unwrap(); + + let a_cuda = create_fp8e4m3_tensor::( + &a_data, + &[2, 2, 2], + &cuda_device, + &cuda_client, + ) + .unwrap(); + let b_cuda = create_fp8e4m3_tensor::( + &b_data, + &[2, 2, 2], + &cuda_device, + &cuda_client, + ) + .unwrap(); + let cuda_result = cuda_client + .fp8_matmul(&a_cuda, &b_cuda, 1.0, 1.0, DType::F32) + .unwrap(); + + let cpu_vals = cpu_result.to_vec::(); + let cuda_f32 = cuda_client.cast(&cuda_result, DType::F32).unwrap(); + let cuda_vals = cuda_f32.to_vec::(); + assert_fp8_parity(&cpu_vals, &cuda_vals, "fp8_e4m3_cuda_batched"); + }); + } +} diff --git a/tests/backend_parity/fused_elementwise.rs b/tests/backend_parity/fused_elementwise.rs new file mode 100644 index 00000000..f1385caf --- /dev/null +++ b/tests/backend_parity/fused_elementwise.rs @@ -0,0 +1,285 @@ +// Backend parity tests for fused elementwise operations +// +// Tests: fused_mul_add, fused_add_mul (BinaryOps), fused_mul_add_scalar (ScalarOps) +// Dtype-parameterized: runs for all supported dtypes across all backends. + +use numr::dtype::DType; +use numr::ops::{BinaryOps, ScalarOps}; +use numr::tensor::Tensor; + +use crate::backend_parity::dtype_helpers::tensor_from_f64; +#[cfg(feature = "cuda")] +use crate::backend_parity::helpers::with_cuda_backend; +#[cfg(feature = "wgpu")] +use crate::backend_parity::helpers::with_wgpu_backend; +use crate::common::{ + assert_tensor_allclose, create_cpu_client, is_dtype_supported, supported_dtypes, +}; + +// ============================================================================ +// Ternary test cases (a, b, c) +// ============================================================================ + +#[derive(Clone)] +struct TernaryCase { + a: Vec, + b: Vec, + c: Vec, + shape: Vec, +} + +impl TernaryCase { + fn new(a: Vec, b: Vec, c: Vec, shape: Vec) -> Self { + Self { a, b, c, shape } + } +} + +fn ternary_cases() -> Vec { + vec![ + TernaryCase::new( + vec![1.0, 2.0, 3.0, 4.0], + vec![2.0, 3.0, 4.0, 5.0], + vec![0.5, 1.0, 1.5, 2.0], + vec![4], + ), + TernaryCase::new( + vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + vec![0.01, 0.02, 0.03, 0.04, 0.05, 0.06], + vec![2, 3], + ), + TernaryCase::new( + vec![-1.0, 0.0, 1.0, 2.0], + vec![3.0, 3.0, 3.0, 3.0], + vec![10.0, 20.0, 30.0, 40.0], + vec![2, 2], + ), + ] +} + +// ============================================================================ +// fused_mul_add: out = a * b + c +// ============================================================================ + +fn test_fused_mul_add_parity(dtype: DType) { + let (cpu_client, cpu_device) = create_cpu_client(); + let cases = ternary_cases(); + + let cpu_results: Vec> = cases + .iter() + .map(|tc| { + let a = tensor_from_f64(&tc.a, &tc.shape, dtype, &cpu_device, &cpu_client).unwrap(); + let b = tensor_from_f64(&tc.b, &tc.shape, dtype, &cpu_device, &cpu_client).unwrap(); + let c = tensor_from_f64(&tc.c, &tc.shape, dtype, &cpu_device, &cpu_client).unwrap(); + cpu_client.fused_mul_add(&a, &b, &c).unwrap() + }) + .collect(); + + #[cfg(feature = "cuda")] + if is_dtype_supported("cuda", dtype) { + with_cuda_backend(|cuda_client, cuda_device| { + for (idx, tc) in cases.iter().enumerate() { + let a = + tensor_from_f64(&tc.a, &tc.shape, dtype, &cuda_device, &cuda_client).unwrap(); + let b = + tensor_from_f64(&tc.b, &tc.shape, dtype, &cuda_device, &cuda_client).unwrap(); + let c = + tensor_from_f64(&tc.c, &tc.shape, dtype, &cuda_device, &cuda_client).unwrap(); + let result = cuda_client.fused_mul_add(&a, &b, &c).unwrap(); + assert_tensor_allclose( + &result, + &cpu_results[idx], + dtype, + &format!("fused_mul_add CUDA vs CPU [{dtype:?}] case {idx}"), + ); + } + }); + } + + #[cfg(feature = "wgpu")] + if is_dtype_supported("wgpu", dtype) { + with_wgpu_backend(|wgpu_client, wgpu_device| { + for (idx, tc) in cases.iter().enumerate() { + let a = + tensor_from_f64(&tc.a, &tc.shape, dtype, &wgpu_device, &wgpu_client).unwrap(); + let b = + tensor_from_f64(&tc.b, &tc.shape, dtype, &wgpu_device, &wgpu_client).unwrap(); + let c = + tensor_from_f64(&tc.c, &tc.shape, dtype, &wgpu_device, &wgpu_client).unwrap(); + let result = wgpu_client.fused_mul_add(&a, &b, &c).unwrap(); + assert_tensor_allclose( + &result, + &cpu_results[idx], + dtype, + &format!("fused_mul_add WebGPU vs CPU [{dtype:?}] case {idx}"), + ); + } + }); + } +} + +#[test] +fn test_fused_mul_add_all_dtypes() { + for dtype in supported_dtypes("cpu") { + test_fused_mul_add_parity(dtype); + } +} + +// ============================================================================ +// fused_add_mul: out = (a + b) * c +// ============================================================================ + +fn test_fused_add_mul_parity(dtype: DType) { + let (cpu_client, cpu_device) = create_cpu_client(); + let cases = ternary_cases(); + + let cpu_results: Vec> = cases + .iter() + .map(|tc| { + let a = tensor_from_f64(&tc.a, &tc.shape, dtype, &cpu_device, &cpu_client).unwrap(); + let b = tensor_from_f64(&tc.b, &tc.shape, dtype, &cpu_device, &cpu_client).unwrap(); + let c = tensor_from_f64(&tc.c, &tc.shape, dtype, &cpu_device, &cpu_client).unwrap(); + cpu_client.fused_add_mul(&a, &b, &c).unwrap() + }) + .collect(); + + #[cfg(feature = "cuda")] + if is_dtype_supported("cuda", dtype) { + with_cuda_backend(|cuda_client, cuda_device| { + for (idx, tc) in cases.iter().enumerate() { + let a = + tensor_from_f64(&tc.a, &tc.shape, dtype, &cuda_device, &cuda_client).unwrap(); + let b = + tensor_from_f64(&tc.b, &tc.shape, dtype, &cuda_device, &cuda_client).unwrap(); + let c = + tensor_from_f64(&tc.c, &tc.shape, dtype, &cuda_device, &cuda_client).unwrap(); + let result = cuda_client.fused_add_mul(&a, &b, &c).unwrap(); + assert_tensor_allclose( + &result, + &cpu_results[idx], + dtype, + &format!("fused_add_mul CUDA vs CPU [{dtype:?}] case {idx}"), + ); + } + }); + } + + #[cfg(feature = "wgpu")] + if is_dtype_supported("wgpu", dtype) { + with_wgpu_backend(|wgpu_client, wgpu_device| { + for (idx, tc) in cases.iter().enumerate() { + let a = + tensor_from_f64(&tc.a, &tc.shape, dtype, &wgpu_device, &wgpu_client).unwrap(); + let b = + tensor_from_f64(&tc.b, &tc.shape, dtype, &wgpu_device, &wgpu_client).unwrap(); + let c = + tensor_from_f64(&tc.c, &tc.shape, dtype, &wgpu_device, &wgpu_client).unwrap(); + let result = wgpu_client.fused_add_mul(&a, &b, &c).unwrap(); + assert_tensor_allclose( + &result, + &cpu_results[idx], + dtype, + &format!("fused_add_mul WebGPU vs CPU [{dtype:?}] case {idx}"), + ); + } + }); + } +} + +#[test] +fn test_fused_add_mul_all_dtypes() { + for dtype in supported_dtypes("cpu") { + test_fused_add_mul_parity(dtype); + } +} + +// ============================================================================ +// fused_mul_add_scalar: out = a * scale + bias +// ============================================================================ + +#[derive(Clone)] +struct ScalarFmaCase { + data: Vec, + shape: Vec, + scale: f64, + bias: f64, +} + +impl ScalarFmaCase { + fn new(data: Vec, shape: Vec, scale: f64, bias: f64) -> Self { + Self { + data, + shape, + scale, + bias, + } + } +} + +fn scalar_fma_cases() -> Vec { + vec![ + ScalarFmaCase::new(vec![1.0, 2.0, 3.0, 4.0], vec![4], 2.5, -1.0), + ScalarFmaCase::new(vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6], vec![2, 3], 10.0, 0.5), + ScalarFmaCase::new(vec![-2.0, -1.0, 0.0, 1.0], vec![2, 2], 0.5, 3.0), + ] +} + +fn test_fused_mul_add_scalar_parity(dtype: DType) { + let (cpu_client, cpu_device) = create_cpu_client(); + let cases = scalar_fma_cases(); + + let cpu_results: Vec> = cases + .iter() + .map(|tc| { + let a = tensor_from_f64(&tc.data, &tc.shape, dtype, &cpu_device, &cpu_client).unwrap(); + cpu_client + .fused_mul_add_scalar(&a, tc.scale, tc.bias) + .unwrap() + }) + .collect(); + + #[cfg(feature = "cuda")] + if is_dtype_supported("cuda", dtype) { + with_cuda_backend(|cuda_client, cuda_device| { + for (idx, tc) in cases.iter().enumerate() { + let a = tensor_from_f64(&tc.data, &tc.shape, dtype, &cuda_device, &cuda_client) + .unwrap(); + let result = cuda_client + .fused_mul_add_scalar(&a, tc.scale, tc.bias) + .unwrap(); + assert_tensor_allclose( + &result, + &cpu_results[idx], + dtype, + &format!("fused_mul_add_scalar CUDA vs CPU [{dtype:?}] case {idx}"), + ); + } + }); + } + + #[cfg(feature = "wgpu")] + if is_dtype_supported("wgpu", dtype) { + with_wgpu_backend(|wgpu_client, wgpu_device| { + for (idx, tc) in cases.iter().enumerate() { + let a = tensor_from_f64(&tc.data, &tc.shape, dtype, &wgpu_device, &wgpu_client) + .unwrap(); + let result = wgpu_client + .fused_mul_add_scalar(&a, tc.scale, tc.bias) + .unwrap(); + assert_tensor_allclose( + &result, + &cpu_results[idx], + dtype, + &format!("fused_mul_add_scalar WebGPU vs CPU [{dtype:?}] case {idx}"), + ); + } + }); + } +} + +#[test] +fn test_fused_mul_add_scalar_all_dtypes() { + for dtype in supported_dtypes("cpu") { + test_fused_mul_add_scalar_parity(dtype); + } +} diff --git a/tests/backend_parity/gemm_epilogue.rs b/tests/backend_parity/gemm_epilogue.rs new file mode 100644 index 00000000..ce310fec --- /dev/null +++ b/tests/backend_parity/gemm_epilogue.rs @@ -0,0 +1,668 @@ +// Backend parity tests for GemmEpilogueOps +// +// This module tests matmul_bias_activation, matmul_bias_residual, and +// matmul_bias_activation_bwd across all supported dtypes and backends, +// ensuring numerical consistency across CPU, CUDA, and WebGPU. + +use numr::ops::{ActivationOps, BinaryOps, GemmActivation, GemmEpilogueOps, MatmulOps}; + +use crate::backend_parity::dtype_helpers::tensor_from_f64; +#[cfg(feature = "cuda")] +use crate::backend_parity::helpers::with_cuda_backend; +#[cfg(feature = "wgpu")] +use crate::backend_parity::helpers::with_wgpu_backend; +use crate::common::{ + assert_tensor_allclose, create_cpu_client, is_dtype_supported, supported_dtypes, +}; + +// ============================================================================ +// matmul_bias_activation: 2D parity across activations, dtypes, backends +// ============================================================================ + +#[test] +fn test_gemm_bias_activation_none_2d_parity() { + gemm_bias_activation_2d_parity(GemmActivation::None, "gemm_bias_act_none_2d"); +} + +#[test] +fn test_gemm_bias_activation_relu_2d_parity() { + gemm_bias_activation_2d_parity(GemmActivation::ReLU, "gemm_bias_act_relu_2d"); +} + +#[test] +fn test_gemm_bias_activation_gelu_2d_parity() { + gemm_bias_activation_2d_parity(GemmActivation::GELU, "gemm_bias_act_gelu_2d"); +} + +#[test] +fn test_gemm_bias_activation_silu_2d_parity() { + gemm_bias_activation_2d_parity(GemmActivation::SiLU, "gemm_bias_act_silu_2d"); +} + +#[test] +fn test_gemm_bias_activation_sigmoid_2d_parity() { + gemm_bias_activation_2d_parity(GemmActivation::Sigmoid, "gemm_bias_act_sigmoid_2d"); +} + +#[test] +fn test_gemm_bias_activation_tanh_2d_parity() { + gemm_bias_activation_2d_parity(GemmActivation::Tanh, "gemm_bias_act_tanh_2d"); +} + +fn gemm_bias_activation_2d_parity(activation: GemmActivation, label: &str) { + // [2, 3] @ [3, 2] + [2] -> [2, 2] + let a = vec![1.0f64, 2.0, -1.0, 3.0, -2.0, 4.0]; + let b = vec![0.5f64, -0.3, 0.1, 0.7, -0.2, 0.4]; + let bias = vec![-0.1f64, 0.2]; + + for dtype in supported_dtypes("cpu") { + let (cpu_client, cpu_device) = create_cpu_client(); + let a_t = tensor_from_f64(&a, &[2, 3], dtype, &cpu_device, &cpu_client).unwrap(); + let b_t = tensor_from_f64(&b, &[3, 2], dtype, &cpu_device, &cpu_client).unwrap(); + let bias_t = tensor_from_f64(&bias, &[2], dtype, &cpu_device, &cpu_client).unwrap(); + let cpu_result = cpu_client + .matmul_bias_activation(&a_t, &b_t, &bias_t, activation) + .unwrap(); + + #[cfg(feature = "cuda")] + if is_dtype_supported("cuda", dtype) { + with_cuda_backend(|cuda_client, cuda_device| { + let a_t = tensor_from_f64(&a, &[2, 3], dtype, &cuda_device, &cuda_client).unwrap(); + let b_t = tensor_from_f64(&b, &[3, 2], dtype, &cuda_device, &cuda_client).unwrap(); + let bias_t = + tensor_from_f64(&bias, &[2], dtype, &cuda_device, &cuda_client).unwrap(); + let result = cuda_client + .matmul_bias_activation(&a_t, &b_t, &bias_t, activation) + .unwrap(); + assert_tensor_allclose( + &result, + &cpu_result, + dtype, + &format!("{label} CUDA vs CPU [{dtype:?}]"), + ); + }); + } + + #[cfg(feature = "wgpu")] + if is_dtype_supported("wgpu", dtype) { + with_wgpu_backend(|wgpu_client, wgpu_device| { + let a_t = tensor_from_f64(&a, &[2, 3], dtype, &wgpu_device, &wgpu_client).unwrap(); + let b_t = tensor_from_f64(&b, &[3, 2], dtype, &wgpu_device, &wgpu_client).unwrap(); + let bias_t = + tensor_from_f64(&bias, &[2], dtype, &wgpu_device, &wgpu_client).unwrap(); + let result = wgpu_client + .matmul_bias_activation(&a_t, &b_t, &bias_t, activation) + .unwrap(); + assert_tensor_allclose( + &result, + &cpu_result, + dtype, + &format!("{label} WebGPU vs CPU [{dtype:?}]"), + ); + }); + } + } +} + +// ============================================================================ +// matmul_bias_activation: batched 3D parity +// ============================================================================ + +#[test] +fn test_gemm_bias_activation_batched_3d_parity() { + // [2, 2, 3] @ [2, 3, 2] + [2] -> [2, 2, 2] + let a = vec![ + 1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + ]; + let b = vec![ + 0.1f64, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, + ]; + let bias = vec![0.01f64, 0.02]; + + for dtype in supported_dtypes("cpu") { + let (cpu_client, cpu_device) = create_cpu_client(); + let a_t = tensor_from_f64(&a, &[2, 2, 3], dtype, &cpu_device, &cpu_client).unwrap(); + let b_t = tensor_from_f64(&b, &[2, 3, 2], dtype, &cpu_device, &cpu_client).unwrap(); + let bias_t = tensor_from_f64(&bias, &[2], dtype, &cpu_device, &cpu_client).unwrap(); + let cpu_result = cpu_client + .matmul_bias_activation(&a_t, &b_t, &bias_t, GemmActivation::ReLU) + .unwrap(); + + #[cfg(feature = "cuda")] + if is_dtype_supported("cuda", dtype) { + with_cuda_backend(|cuda_client, cuda_device| { + let a_t = + tensor_from_f64(&a, &[2, 2, 3], dtype, &cuda_device, &cuda_client).unwrap(); + let b_t = + tensor_from_f64(&b, &[2, 3, 2], dtype, &cuda_device, &cuda_client).unwrap(); + let bias_t = + tensor_from_f64(&bias, &[2], dtype, &cuda_device, &cuda_client).unwrap(); + let result = cuda_client + .matmul_bias_activation(&a_t, &b_t, &bias_t, GemmActivation::ReLU) + .unwrap(); + assert_tensor_allclose( + &result, + &cpu_result, + dtype, + &format!("gemm_bias_act_batched CUDA vs CPU [{dtype:?}]"), + ); + }); + } + + #[cfg(feature = "wgpu")] + if is_dtype_supported("wgpu", dtype) { + with_wgpu_backend(|wgpu_client, wgpu_device| { + let a_t = + tensor_from_f64(&a, &[2, 2, 3], dtype, &wgpu_device, &wgpu_client).unwrap(); + let b_t = + tensor_from_f64(&b, &[2, 3, 2], dtype, &wgpu_device, &wgpu_client).unwrap(); + let bias_t = + tensor_from_f64(&bias, &[2], dtype, &wgpu_device, &wgpu_client).unwrap(); + let result = wgpu_client + .matmul_bias_activation(&a_t, &b_t, &bias_t, GemmActivation::ReLU) + .unwrap(); + assert_tensor_allclose( + &result, + &cpu_result, + dtype, + &format!("gemm_bias_act_batched WebGPU vs CPU [{dtype:?}]"), + ); + }); + } + } +} + +// ============================================================================ +// matmul_bias_residual: 2D parity across dtypes and backends +// ============================================================================ + +#[test] +fn test_gemm_bias_residual_2d_parity() { + let a = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0]; + let b = vec![0.5f64, -0.3, 0.1, 0.7, -0.2, 0.4]; + let bias = vec![-0.1f64, 0.2]; + let residual = vec![1.0f64, 2.0, 3.0, 4.0]; + + for dtype in supported_dtypes("cpu") { + let (cpu_client, cpu_device) = create_cpu_client(); + let a_t = tensor_from_f64(&a, &[2, 3], dtype, &cpu_device, &cpu_client).unwrap(); + let b_t = tensor_from_f64(&b, &[3, 2], dtype, &cpu_device, &cpu_client).unwrap(); + let bias_t = tensor_from_f64(&bias, &[2], dtype, &cpu_device, &cpu_client).unwrap(); + let res_t = tensor_from_f64(&residual, &[2, 2], dtype, &cpu_device, &cpu_client).unwrap(); + let cpu_result = cpu_client + .matmul_bias_residual(&a_t, &b_t, &bias_t, &res_t) + .unwrap(); + + #[cfg(feature = "cuda")] + if is_dtype_supported("cuda", dtype) { + with_cuda_backend(|cuda_client, cuda_device| { + let a_t = tensor_from_f64(&a, &[2, 3], dtype, &cuda_device, &cuda_client).unwrap(); + let b_t = tensor_from_f64(&b, &[3, 2], dtype, &cuda_device, &cuda_client).unwrap(); + let bias_t = + tensor_from_f64(&bias, &[2], dtype, &cuda_device, &cuda_client).unwrap(); + let res_t = + tensor_from_f64(&residual, &[2, 2], dtype, &cuda_device, &cuda_client).unwrap(); + let result = cuda_client + .matmul_bias_residual(&a_t, &b_t, &bias_t, &res_t) + .unwrap(); + assert_tensor_allclose( + &result, + &cpu_result, + dtype, + &format!("gemm_bias_residual_2d CUDA vs CPU [{dtype:?}]"), + ); + }); + } + + #[cfg(feature = "wgpu")] + if is_dtype_supported("wgpu", dtype) { + with_wgpu_backend(|wgpu_client, wgpu_device| { + let a_t = tensor_from_f64(&a, &[2, 3], dtype, &wgpu_device, &wgpu_client).unwrap(); + let b_t = tensor_from_f64(&b, &[3, 2], dtype, &wgpu_device, &wgpu_client).unwrap(); + let bias_t = + tensor_from_f64(&bias, &[2], dtype, &wgpu_device, &wgpu_client).unwrap(); + let res_t = + tensor_from_f64(&residual, &[2, 2], dtype, &wgpu_device, &wgpu_client).unwrap(); + let result = wgpu_client + .matmul_bias_residual(&a_t, &b_t, &bias_t, &res_t) + .unwrap(); + assert_tensor_allclose( + &result, + &cpu_result, + dtype, + &format!("gemm_bias_residual_2d WebGPU vs CPU [{dtype:?}]"), + ); + }); + } + } +} + +// ============================================================================ +// matmul_bias_activation_bwd: parity across dtypes and backends +// ============================================================================ + +#[test] +fn test_gemm_bias_activation_bwd_none_parity() { + gemm_bias_activation_bwd_parity(GemmActivation::None, "gemm_bias_act_bwd_none"); +} + +#[test] +fn test_gemm_bias_activation_bwd_relu_parity() { + gemm_bias_activation_bwd_parity(GemmActivation::ReLU, "gemm_bias_act_bwd_relu"); +} + +#[test] +fn test_gemm_bias_activation_bwd_sigmoid_parity() { + gemm_bias_activation_bwd_parity(GemmActivation::Sigmoid, "gemm_bias_act_bwd_sigmoid"); +} + +#[test] +fn test_gemm_bias_activation_bwd_tanh_parity() { + gemm_bias_activation_bwd_parity(GemmActivation::Tanh, "gemm_bias_act_bwd_tanh"); +} + +#[test] +fn test_gemm_bias_activation_bwd_silu_parity() { + gemm_bias_activation_bwd_parity(GemmActivation::SiLU, "gemm_bias_act_bwd_silu"); +} + +#[test] +fn test_gemm_bias_activation_bwd_gelu_parity() { + gemm_bias_activation_bwd_parity(GemmActivation::GELU, "gemm_bias_act_bwd_gelu"); +} + +fn gemm_bias_activation_bwd_parity(activation: GemmActivation, label: &str) { + let a = vec![1.0f64, 2.0, 3.0, 4.0]; + let b = vec![0.5f64, 0.3, -0.1, 0.7]; + let bias = vec![0.0f64, 0.0]; + let grad = vec![1.0f64, 1.0, 1.0, 1.0]; + + for dtype in supported_dtypes("cpu") { + let (cpu_client, cpu_device) = create_cpu_client(); + let a_t = tensor_from_f64(&a, &[2, 2], dtype, &cpu_device, &cpu_client).unwrap(); + let b_t = tensor_from_f64(&b, &[2, 2], dtype, &cpu_device, &cpu_client).unwrap(); + let bias_t = tensor_from_f64(&bias, &[2], dtype, &cpu_device, &cpu_client).unwrap(); + let grad_t = tensor_from_f64(&grad, &[2, 2], dtype, &cpu_device, &cpu_client).unwrap(); + let (cpu_da, cpu_db, cpu_dbias) = cpu_client + .matmul_bias_activation_bwd(&grad_t, &a_t, &b_t, &bias_t, activation) + .unwrap(); + + #[cfg(feature = "cuda")] + if is_dtype_supported("cuda", dtype) { + with_cuda_backend(|cuda_client, cuda_device| { + let a_t = tensor_from_f64(&a, &[2, 2], dtype, &cuda_device, &cuda_client).unwrap(); + let b_t = tensor_from_f64(&b, &[2, 2], dtype, &cuda_device, &cuda_client).unwrap(); + let bias_t = + tensor_from_f64(&bias, &[2], dtype, &cuda_device, &cuda_client).unwrap(); + let grad_t = + tensor_from_f64(&grad, &[2, 2], dtype, &cuda_device, &cuda_client).unwrap(); + let (da, db, dbias) = cuda_client + .matmul_bias_activation_bwd(&grad_t, &a_t, &b_t, &bias_t, activation) + .unwrap(); + assert_tensor_allclose( + &da, + &cpu_da, + dtype, + &format!("{label} d_a CUDA vs CPU [{dtype:?}]"), + ); + assert_tensor_allclose( + &db, + &cpu_db, + dtype, + &format!("{label} d_b CUDA vs CPU [{dtype:?}]"), + ); + assert_tensor_allclose( + &dbias, + &cpu_dbias, + dtype, + &format!("{label} d_bias CUDA vs CPU [{dtype:?}]"), + ); + }); + } + + #[cfg(feature = "wgpu")] + if is_dtype_supported("wgpu", dtype) { + with_wgpu_backend(|wgpu_client, wgpu_device| { + let a_t = tensor_from_f64(&a, &[2, 2], dtype, &wgpu_device, &wgpu_client).unwrap(); + let b_t = tensor_from_f64(&b, &[2, 2], dtype, &wgpu_device, &wgpu_client).unwrap(); + let bias_t = + tensor_from_f64(&bias, &[2], dtype, &wgpu_device, &wgpu_client).unwrap(); + let grad_t = + tensor_from_f64(&grad, &[2, 2], dtype, &wgpu_device, &wgpu_client).unwrap(); + let (da, db, dbias) = wgpu_client + .matmul_bias_activation_bwd(&grad_t, &a_t, &b_t, &bias_t, activation) + .unwrap(); + assert_tensor_allclose( + &da, + &cpu_da, + dtype, + &format!("{label} d_a WebGPU vs CPU [{dtype:?}]"), + ); + assert_tensor_allclose( + &db, + &cpu_db, + dtype, + &format!("{label} d_b WebGPU vs CPU [{dtype:?}]"), + ); + assert_tensor_allclose( + &dbias, + &cpu_dbias, + dtype, + &format!("{label} d_bias WebGPU vs CPU [{dtype:?}]"), + ); + }); + } + } +} + +// ============================================================================ +// matmul_bias_activation_bwd: batched 3D parity +// ============================================================================ + +#[test] +fn test_gemm_bias_activation_bwd_batched_3d_parity() { + let a = vec![ + 1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + ]; + let b = vec![ + 0.1f64, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, + ]; + let bias = vec![0.01f64, 0.02]; + let grad = vec![1.0f64; 8]; + + for activation in [ + GemmActivation::None, + GemmActivation::ReLU, + GemmActivation::SiLU, + ] { + for dtype in supported_dtypes("cpu") { + let (cpu_client, cpu_device) = create_cpu_client(); + let a_t = tensor_from_f64(&a, &[2, 2, 3], dtype, &cpu_device, &cpu_client).unwrap(); + let b_t = tensor_from_f64(&b, &[2, 3, 2], dtype, &cpu_device, &cpu_client).unwrap(); + let bias_t = tensor_from_f64(&bias, &[2], dtype, &cpu_device, &cpu_client).unwrap(); + let grad_t = + tensor_from_f64(&grad, &[2, 2, 2], dtype, &cpu_device, &cpu_client).unwrap(); + let (cpu_da, cpu_db, cpu_dbias) = cpu_client + .matmul_bias_activation_bwd(&grad_t, &a_t, &b_t, &bias_t, activation) + .unwrap(); + + assert_eq!(cpu_da.shape(), &[2, 2, 3]); + assert_eq!(cpu_db.shape(), &[2, 3, 2]); + assert_eq!(cpu_dbias.shape(), &[2]); + + #[cfg(feature = "cuda")] + if is_dtype_supported("cuda", dtype) { + with_cuda_backend(|cuda_client, cuda_device| { + let a_t = + tensor_from_f64(&a, &[2, 2, 3], dtype, &cuda_device, &cuda_client).unwrap(); + let b_t = + tensor_from_f64(&b, &[2, 3, 2], dtype, &cuda_device, &cuda_client).unwrap(); + let bias_t = + tensor_from_f64(&bias, &[2], dtype, &cuda_device, &cuda_client).unwrap(); + let grad_t = + tensor_from_f64(&grad, &[2, 2, 2], dtype, &cuda_device, &cuda_client) + .unwrap(); + let label = format!("bwd_batched_{activation:?}"); + let (da, db, dbias) = cuda_client + .matmul_bias_activation_bwd(&grad_t, &a_t, &b_t, &bias_t, activation) + .unwrap(); + assert_tensor_allclose( + &da, + &cpu_da, + dtype, + &format!("{label} d_a CUDA vs CPU [{dtype:?}]"), + ); + assert_tensor_allclose( + &db, + &cpu_db, + dtype, + &format!("{label} d_b CUDA vs CPU [{dtype:?}]"), + ); + assert_tensor_allclose( + &dbias, + &cpu_dbias, + dtype, + &format!("{label} d_bias CUDA vs CPU [{dtype:?}]"), + ); + }); + } + + #[cfg(feature = "wgpu")] + if is_dtype_supported("wgpu", dtype) { + with_wgpu_backend(|wgpu_client, wgpu_device| { + let a_t = + tensor_from_f64(&a, &[2, 2, 3], dtype, &wgpu_device, &wgpu_client).unwrap(); + let b_t = + tensor_from_f64(&b, &[2, 3, 2], dtype, &wgpu_device, &wgpu_client).unwrap(); + let bias_t = + tensor_from_f64(&bias, &[2], dtype, &wgpu_device, &wgpu_client).unwrap(); + let grad_t = + tensor_from_f64(&grad, &[2, 2, 2], dtype, &wgpu_device, &wgpu_client) + .unwrap(); + let label = format!("bwd_batched_{activation:?}"); + let (da, db, dbias) = wgpu_client + .matmul_bias_activation_bwd(&grad_t, &a_t, &b_t, &bias_t, activation) + .unwrap(); + assert_tensor_allclose( + &da, + &cpu_da, + dtype, + &format!("{label} d_a WebGPU vs CPU [{dtype:?}]"), + ); + assert_tensor_allclose( + &db, + &cpu_db, + dtype, + &format!("{label} d_b WebGPU vs CPU [{dtype:?}]"), + ); + assert_tensor_allclose( + &dbias, + &cpu_dbias, + dtype, + &format!("{label} d_bias WebGPU vs CPU [{dtype:?}]"), + ); + }); + } + } + } +} + +// ============================================================================ +// matmul_bias_activation_bwd: negative values / edge cases +// ============================================================================ + +#[test] +fn test_gemm_bias_activation_bwd_negative_values_parity() { + let a = vec![-1.0f64, 2.0, 3.0, -4.0]; + let b = vec![-1.0f64, 0.5, 0.5, -1.0]; + let bias = vec![-0.5f64, 0.5]; + let grad = vec![1.0f64, 1.0, 1.0, 1.0]; + + for activation in [ + GemmActivation::None, + GemmActivation::ReLU, + GemmActivation::Sigmoid, + GemmActivation::Tanh, + GemmActivation::SiLU, + GemmActivation::GELU, + ] { + for dtype in supported_dtypes("cpu") { + let (cpu_client, cpu_device) = create_cpu_client(); + let a_t = tensor_from_f64(&a, &[2, 2], dtype, &cpu_device, &cpu_client).unwrap(); + let b_t = tensor_from_f64(&b, &[2, 2], dtype, &cpu_device, &cpu_client).unwrap(); + let bias_t = tensor_from_f64(&bias, &[2], dtype, &cpu_device, &cpu_client).unwrap(); + let grad_t = tensor_from_f64(&grad, &[2, 2], dtype, &cpu_device, &cpu_client).unwrap(); + let (cpu_da, cpu_db, cpu_dbias) = cpu_client + .matmul_bias_activation_bwd(&grad_t, &a_t, &b_t, &bias_t, activation) + .unwrap(); + + // Verify finiteness on CPU reference + for val in cpu_da.to_vec::().iter() { + assert!( + val.is_finite(), + "non-finite d_a for {activation:?} [{dtype:?}]" + ); + } + for val in cpu_db.to_vec::().iter() { + assert!( + val.is_finite(), + "non-finite d_b for {activation:?} [{dtype:?}]" + ); + } + for val in cpu_dbias.to_vec::().iter() { + assert!( + val.is_finite(), + "non-finite d_bias for {activation:?} [{dtype:?}]" + ); + } + + #[cfg(feature = "cuda")] + if is_dtype_supported("cuda", dtype) { + with_cuda_backend(|cuda_client, cuda_device| { + let a_t = + tensor_from_f64(&a, &[2, 2], dtype, &cuda_device, &cuda_client).unwrap(); + let b_t = + tensor_from_f64(&b, &[2, 2], dtype, &cuda_device, &cuda_client).unwrap(); + let bias_t = + tensor_from_f64(&bias, &[2], dtype, &cuda_device, &cuda_client).unwrap(); + let grad_t = + tensor_from_f64(&grad, &[2, 2], dtype, &cuda_device, &cuda_client).unwrap(); + let label = format!("bwd_neg_{activation:?}"); + let (da, db, dbias) = cuda_client + .matmul_bias_activation_bwd(&grad_t, &a_t, &b_t, &bias_t, activation) + .unwrap(); + assert_tensor_allclose( + &da, + &cpu_da, + dtype, + &format!("{label} d_a CUDA vs CPU [{dtype:?}]"), + ); + assert_tensor_allclose( + &db, + &cpu_db, + dtype, + &format!("{label} d_b CUDA vs CPU [{dtype:?}]"), + ); + assert_tensor_allclose( + &dbias, + &cpu_dbias, + dtype, + &format!("{label} d_bias CUDA vs CPU [{dtype:?}]"), + ); + }); + } + + #[cfg(feature = "wgpu")] + if is_dtype_supported("wgpu", dtype) { + with_wgpu_backend(|wgpu_client, wgpu_device| { + let a_t = + tensor_from_f64(&a, &[2, 2], dtype, &wgpu_device, &wgpu_client).unwrap(); + let b_t = + tensor_from_f64(&b, &[2, 2], dtype, &wgpu_device, &wgpu_client).unwrap(); + let bias_t = + tensor_from_f64(&bias, &[2], dtype, &wgpu_device, &wgpu_client).unwrap(); + let grad_t = + tensor_from_f64(&grad, &[2, 2], dtype, &wgpu_device, &wgpu_client).unwrap(); + let label = format!("bwd_neg_{activation:?}"); + let (da, db, dbias) = wgpu_client + .matmul_bias_activation_bwd(&grad_t, &a_t, &b_t, &bias_t, activation) + .unwrap(); + assert_tensor_allclose( + &da, + &cpu_da, + dtype, + &format!("{label} d_a WebGPU vs CPU [{dtype:?}]"), + ); + assert_tensor_allclose( + &db, + &cpu_db, + dtype, + &format!("{label} d_b WebGPU vs CPU [{dtype:?}]"), + ); + assert_tensor_allclose( + &dbias, + &cpu_dbias, + dtype, + &format!("{label} d_bias WebGPU vs CPU [{dtype:?}]"), + ); + }); + } + } + } +} + +// ============================================================================ +// CPU-only reference tests: fused == unfused +// ============================================================================ + +#[test] +fn test_gemm_bias_activation_none_matches_matmul_bias() { + use numr::runtime::cpu::CpuRuntime; + use numr::tensor::Tensor; + + let (client, dev) = create_cpu_client(); + let a = Tensor::::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], &dev); + let b = Tensor::::from_slice(&[0.5f32, -0.3, 0.1, 0.7, -0.2, 0.4], &[3, 2], &dev); + let bias = Tensor::::from_slice(&[-0.1f32, 0.2], &[2], &dev); + + let fused: Vec = client + .matmul_bias_activation(&a, &b, &bias, GemmActivation::None) + .unwrap() + .to_vec(); + let reference: Vec = client.matmul_bias(&a, &b, &bias).unwrap().to_vec(); + + crate::backend_parity::helpers::assert_parity_f32( + &fused, + &reference, + "gemm_bias_act_none_matches_matmul_bias", + ); +} + +#[test] +fn test_gemm_bias_activation_relu_matches_unfused() { + use numr::runtime::cpu::CpuRuntime; + use numr::tensor::Tensor; + + let (client, dev) = create_cpu_client(); + let a = Tensor::::from_slice(&[1.0f32, 2.0, -1.0, 3.0, -2.0, 4.0], &[2, 3], &dev); + let b = Tensor::::from_slice(&[0.5f32, -0.3, 0.1, 0.7, -0.2, 0.4], &[3, 2], &dev); + let bias = Tensor::::from_slice(&[-0.5f32, 0.3], &[2], &dev); + + let fused: Vec = client + .matmul_bias_activation(&a, &b, &bias, GemmActivation::ReLU) + .unwrap() + .to_vec(); + let pre = client.matmul_bias(&a, &b, &bias).unwrap(); + let unfused: Vec = client.relu(&pre).unwrap().to_vec(); + + crate::backend_parity::helpers::assert_parity_f32( + &fused, + &unfused, + "gemm_bias_act_relu_matches_unfused", + ); +} + +#[test] +fn test_gemm_bias_residual_matches_unfused() { + use numr::runtime::cpu::CpuRuntime; + use numr::tensor::Tensor; + + let (client, dev) = create_cpu_client(); + let a = Tensor::::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], &dev); + let b = Tensor::::from_slice(&[0.5f32, -0.3, 0.1, 0.7, -0.2, 0.4], &[3, 2], &dev); + let bias = Tensor::::from_slice(&[-0.1f32, 0.2], &[2], &dev); + let residual = Tensor::::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2], &dev); + + let fused: Vec = client + .matmul_bias_residual(&a, &b, &bias, &residual) + .unwrap() + .to_vec(); + let pre = client.matmul_bias(&a, &b, &bias).unwrap(); + let unfused: Vec = client.add(&pre, &residual).unwrap().to_vec(); + + crate::backend_parity::helpers::assert_parity_f32( + &fused, + &unfused, + "gemm_bias_residual_matches_unfused", + ); +} diff --git a/tests/backend_parity/helpers.rs b/tests/backend_parity/helpers.rs index 34c7fd34..546215b8 100644 --- a/tests/backend_parity/helpers.rs +++ b/tests/backend_parity/helpers.rs @@ -130,12 +130,15 @@ pub fn with_cuda_backend(mut f: F) where F: FnMut(numr::runtime::cuda::CudaClient, numr::runtime::cuda::CudaDevice), { + use numr::runtime::RuntimeClient; let _guard = CUDA_BACKEND_LOCK .get_or_init(|| Mutex::new(())) .lock() .unwrap_or_else(|poisoned| poisoned.into_inner()); let (client, device) = create_cuda_client_checked() .expect("CUDA feature is enabled but CUDA runtime is unavailable"); + // Sync before test to clear any pending errors from a prior panicked test + client.synchronize(); f(client, device); } @@ -152,21 +155,3 @@ where .expect("WGPU feature is enabled but WGPU runtime is unavailable"); f(client, device); } - -pub fn assert_case_parity_f32( - cpu_results: &[Vec], - idx: usize, - backend_result: &[f32], - op: &str, - backend: &str, -) { - assert_parity_f32( - &cpu_results[idx], - backend_result, - &format!("{op}_{backend}_case_{idx}"), - ); -} - -pub fn assert_single_parity_f32(cpu: &[f32], backend_result: &[f32], op: &str, backend: &str) { - assert_parity_f32(cpu, backend_result, &format!("{op}_{backend}")); -} diff --git a/tests/backend_parity/indexing.rs b/tests/backend_parity/indexing.rs index 407b40b0..39aa566b 100644 --- a/tests/backend_parity/indexing.rs +++ b/tests/backend_parity/indexing.rs @@ -3,10 +3,8 @@ // Dtype-parameterized: each test runs for all supported dtypes across all backends. // Index tensors remain as I32/I64 (not parameterized), only data tensors vary by dtype. -use numr::dtype::DType; use numr::error::Error; use numr::ops::IndexingOps; -use numr::runtime::Runtime; use numr::tensor::Tensor; use crate::backend_parity::dtype_helpers::tensor_from_f64; diff --git a/tests/backend_parity/indexing_advanced.rs b/tests/backend_parity/indexing_advanced.rs index 0c19cbde..92d0eb57 100644 --- a/tests/backend_parity/indexing_advanced.rs +++ b/tests/backend_parity/indexing_advanced.rs @@ -3,10 +3,8 @@ // Dtype-parameterized: each test runs for all supported dtypes across all backends. // Index tensors remain as I32 (not parameterized), only data tensors are dtype-parameterized. -use numr::dtype::DType; use numr::ops::IndexingOps; use numr::ops::ScatterReduceOp; -use numr::runtime::cpu::CpuRuntime; use numr::tensor::Tensor; use crate::backend_parity::dtype_helpers::tensor_from_f64; diff --git a/tests/backend_parity/linalg.rs b/tests/backend_parity/linalg.rs index 99f75244..4f169380 100644 --- a/tests/backend_parity/linalg.rs +++ b/tests/backend_parity/linalg.rs @@ -4,10 +4,6 @@ // Comparison reads back in native dtype via assert_tensor_allclose. use numr::algorithm::linalg::LinearAlgebraAlgorithms; -use numr::dtype::DType; -use numr::runtime::Runtime; -use numr::runtime::cpu::CpuRuntime; -use numr::tensor::Tensor; use crate::backend_parity::dtype_helpers::tensor_from_f64; #[cfg(feature = "cuda")] diff --git a/tests/backend_parity/logical.rs b/tests/backend_parity/logical.rs new file mode 100644 index 00000000..a1514707 --- /dev/null +++ b/tests/backend_parity/logical.rs @@ -0,0 +1,200 @@ +// Backend parity tests for LogicalOps trait +// +// Logical ops work on U8 tensors (0 = false, non-zero = true). +// CPU is the reference implementation; CUDA and WebGPU must match. + +use numr::ops::LogicalOps; +use numr::runtime::Runtime; +use numr::tensor::Tensor; + +#[cfg(feature = "cuda")] +use crate::backend_parity::helpers::with_cuda_backend; +#[cfg(feature = "wgpu")] +use crate::backend_parity::helpers::with_wgpu_backend; +use crate::common::{create_cpu_client, readback_as_bool}; + +#[derive(Clone, Copy, Debug)] +enum LogicalOp { + And, + Or, + Xor, +} + +fn apply_logical_op( + client: &impl LogicalOps, + op: LogicalOp, + a: &Tensor, + b: &Tensor, +) -> numr::error::Result> { + match op { + LogicalOp::And => client.logical_and(a, b), + LogicalOp::Or => client.logical_or(a, b), + LogicalOp::Xor => client.logical_xor(a, b), + } +} + +struct BinaryLogicalCase { + a: Vec, + b: Vec, + shape: Vec, +} + +impl BinaryLogicalCase { + fn new(a: Vec, b: Vec, shape: Vec) -> Self { + Self { a, b, shape } + } +} + +fn binary_logical_cases() -> Vec { + vec![ + // Basic 1D + BinaryLogicalCase::new(vec![1, 0, 1, 0], vec![1, 1, 0, 0], vec![4]), + // All true + BinaryLogicalCase::new(vec![1, 1, 1, 1], vec![1, 1, 1, 1], vec![4]), + // All false + BinaryLogicalCase::new(vec![0, 0, 0, 0], vec![0, 0, 0, 0], vec![4]), + // 2D + BinaryLogicalCase::new(vec![1, 0, 0, 1, 1, 0], vec![0, 1, 1, 0, 1, 1], vec![2, 3]), + // Non-zero values treated as true + BinaryLogicalCase::new(vec![5, 0, 255, 0], vec![0, 3, 0, 1], vec![4]), + ] +} + +fn test_binary_logical_parity(op: LogicalOp) { + let cases = binary_logical_cases(); + let (cpu_client, cpu_device) = create_cpu_client(); + + let cpu_results: Vec> = cases + .iter() + .map(|tc| { + let a = + Tensor::::from_slice(&tc.a, &tc.shape, &cpu_device); + let b = + Tensor::::from_slice(&tc.b, &tc.shape, &cpu_device); + let result = apply_logical_op(&cpu_client, op, &a, &b) + .unwrap_or_else(|e| panic!("CPU {op:?} failed: {e}")); + readback_as_bool(&result) + }) + .collect(); + + #[cfg(feature = "cuda")] + with_cuda_backend(|cuda_client, cuda_device| { + for (idx, tc) in cases.iter().enumerate() { + let a = Tensor::::from_slice( + &tc.a, + &tc.shape, + &cuda_device, + ); + let b = Tensor::::from_slice( + &tc.b, + &tc.shape, + &cuda_device, + ); + let result = apply_logical_op(&cuda_client, op, &a, &b) + .unwrap_or_else(|e| panic!("CUDA {op:?} failed: {e}")); + let cuda_bools = readback_as_bool(&result); + assert_eq!( + cuda_bools, cpu_results[idx], + "{op:?} CUDA vs CPU case {idx}" + ); + } + }); + + #[cfg(feature = "wgpu")] + with_wgpu_backend(|wgpu_client, wgpu_device| { + for (idx, tc) in cases.iter().enumerate() { + // WebGPU uses U32 for bool-like tensors + let a_u32: Vec = tc.a.iter().map(|&v| v as u32).collect(); + let b_u32: Vec = tc.b.iter().map(|&v| v as u32).collect(); + let a = Tensor::::from_slice( + &a_u32, + &tc.shape, + &wgpu_device, + ); + let b = Tensor::::from_slice( + &b_u32, + &tc.shape, + &wgpu_device, + ); + let result = apply_logical_op(&wgpu_client, op, &a, &b) + .unwrap_or_else(|e| panic!("WebGPU {op:?} failed: {e}")); + let wgpu_bools = readback_as_bool(&result); + assert_eq!( + wgpu_bools, cpu_results[idx], + "{op:?} WebGPU vs CPU case {idx}" + ); + } + }); +} + +fn test_not_parity() { + let cases: Vec<(Vec, Vec)> = vec![ + (vec![1, 0, 1, 0], vec![4]), + (vec![0, 0, 0, 0], vec![4]), + (vec![1, 1, 1, 1], vec![4]), + (vec![5, 0, 255, 0, 1, 0], vec![2, 3]), + ]; + + let (cpu_client, cpu_device) = create_cpu_client(); + + let cpu_results: Vec> = cases + .iter() + .map(|(data, shape)| { + let a = Tensor::::from_slice(data, shape, &cpu_device); + let result = cpu_client + .logical_not(&a) + .unwrap_or_else(|e| panic!("CPU NOT failed: {e}")); + readback_as_bool(&result) + }) + .collect(); + + #[cfg(feature = "cuda")] + with_cuda_backend(|cuda_client, cuda_device| { + for (idx, (data, shape)) in cases.iter().enumerate() { + let a = + Tensor::::from_slice(data, shape, &cuda_device); + let result = cuda_client + .logical_not(&a) + .unwrap_or_else(|e| panic!("CUDA NOT failed: {e}")); + let cuda_bools = readback_as_bool(&result); + assert_eq!(cuda_bools, cpu_results[idx], "NOT CUDA vs CPU case {idx}"); + } + }); + + #[cfg(feature = "wgpu")] + with_wgpu_backend(|wgpu_client, wgpu_device| { + for (idx, (data, shape)) in cases.iter().enumerate() { + let data_u32: Vec = data.iter().map(|&v| v as u32).collect(); + let a = Tensor::::from_slice( + &data_u32, + shape, + &wgpu_device, + ); + let result = wgpu_client + .logical_not(&a) + .unwrap_or_else(|e| panic!("WebGPU NOT failed: {e}")); + let wgpu_bools = readback_as_bool(&result); + assert_eq!(wgpu_bools, cpu_results[idx], "NOT WebGPU vs CPU case {idx}"); + } + }); +} + +#[test] +fn test_logical_and_parity() { + test_binary_logical_parity(LogicalOp::And); +} + +#[test] +fn test_logical_or_parity() { + test_binary_logical_parity(LogicalOp::Or); +} + +#[test] +fn test_logical_xor_parity() { + test_binary_logical_parity(LogicalOp::Xor); +} + +#[test] +fn test_logical_not_parity() { + test_not_parity(); +} diff --git a/tests/backend_parity/mod.rs b/tests/backend_parity/mod.rs index 22536aea..f925a114 100644 --- a/tests/backend_parity/mod.rs +++ b/tests/backend_parity/mod.rs @@ -1,16 +1,23 @@ pub mod dtype_helpers; pub mod helpers; +pub mod activation; pub mod advanced_random; pub mod binary; pub mod cast; pub mod compare; pub mod complex; +pub mod conditional; pub mod conv; pub mod cumulative; +pub mod distance; pub mod eigen; pub mod einsum; pub mod fft; +#[cfg(feature = "fp8")] +pub mod fp8_matmul; +pub mod fused_elementwise; +pub mod gemm_epilogue; pub mod indexing; pub mod indexing_advanced; #[cfg(feature = "sparse")] @@ -20,23 +27,30 @@ pub mod iterative_solvers; #[cfg(feature = "sparse")] pub mod iterative_solvers_advanced; pub mod linalg; +pub mod logical; pub mod matmul; pub mod matmul_bias; pub mod matrix_functions_expm; pub mod matrix_functions_logm; pub mod matrix_functions_other; pub mod matrix_functions_sqrtm; +pub mod multivariate; +pub mod normalization; pub mod polynomial; pub mod random; pub mod reduce; pub mod scalar; +pub mod semiring_matmul; pub mod shape; pub mod sort; #[cfg(feature = "sparse")] pub mod sparse; #[cfg(feature = "sparse")] +pub mod sparse_24; +#[cfg(feature = "sparse")] pub mod sparse_ops; pub mod special; pub mod statistics; pub mod svd; pub mod unary; +pub mod utility; diff --git a/tests/backend_parity/multivariate.rs b/tests/backend_parity/multivariate.rs new file mode 100644 index 00000000..eef4e0b3 --- /dev/null +++ b/tests/backend_parity/multivariate.rs @@ -0,0 +1,483 @@ +// Backend parity tests for MultivariateRandomOps trait +// +// Multivariate distributions produce stochastic samples - we validate: +// - Shape correctness +// - Dtype correctness +// - Statistical properties (mean, variance, sum constraints) +// - Consistency with the mathematical definition + +use numr::dtype::DType; +use numr::ops::MultivariateRandomOps; +use numr::runtime::cpu::CpuRuntime; +use numr::tensor::Tensor; + +#[cfg(feature = "cuda")] +use crate::backend_parity::helpers::with_cuda_backend; +#[cfg(feature = "wgpu")] +use crate::backend_parity::helpers::with_wgpu_backend; +use crate::common::{create_cpu_client, is_dtype_supported}; + +// ============================================================================ +// Helpers +// ============================================================================ + +/// Check that all values in a slice are finite (no NaN/Inf) +fn assert_all_finite_f32(vals: &[f32], name: &str) { + for (i, &v) in vals.iter().enumerate() { + assert!( + v.is_finite(), + "{name} value at index {i} is not finite: {v}" + ); + } +} + +/// Check that the rows of a 2D slice (n_samples × k) each sum to approximately `expected_sum` +fn assert_rows_sum_to_f32(vals: &[f32], k: usize, expected_sum: f32, tol: f32, name: &str) { + let n = vals.len() / k; + for i in 0..n { + let row_sum: f32 = vals[i * k..(i + 1) * k].iter().sum(); + assert!( + (row_sum - expected_sum).abs() < tol, + "{name} row {i} sum = {row_sum}, expected {expected_sum} ± {tol}" + ); + } +} + +/// Check that all values are non-negative +fn assert_all_non_negative_f32(vals: &[f32], name: &str) { + for (i, &v) in vals.iter().enumerate() { + assert!(v >= 0.0, "{name} value at index {i} is negative: {v}"); + } +} + +/// Check approximate mean across columns of a 2D matrix (n_samples × k) +fn check_column_mean_f32(vals: &[f32], k: usize, expected_means: &[f32], tol: f32, name: &str) { + let n = (vals.len() / k) as f32; + for (j, &expected) in expected_means.iter().enumerate().take(k) { + let col_mean: f32 = vals.iter().skip(j).step_by(k).sum::() / n; + assert!( + (col_mean - expected).abs() < tol, + "{name} column {j} mean = {col_mean}, expected {expected} ± {tol}" + ); + } +} + +// ============================================================================ +// multivariate_normal tests +// ============================================================================ + +/// Test multivariate_normal produces correct shape, dtype, and finite values on all backends +#[test] +fn test_multivariate_normal_shape_and_dtype() { + let (cpu_client, cpu_device) = create_cpu_client(); + let mean = Tensor::::from_slice(&[0.0f32, 0.0], &[2], &cpu_device); + let cov = Tensor::::from_slice(&[1.0f32, 0.0, 0.0, 1.0], &[2, 2], &cpu_device); + let n_samples = 100usize; + + let result = cpu_client + .multivariate_normal(&mean, &cov, n_samples) + .unwrap_or_else(|e| panic!("CPU multivariate_normal failed: {e}")); + + assert_eq!( + result.shape(), + &[100, 2], + "multivariate_normal shape mismatch" + ); + assert_eq!( + result.dtype(), + DType::F32, + "multivariate_normal dtype mismatch" + ); + let vals: Vec = result.to_vec(); + assert_all_finite_f32(&vals, "multivariate_normal CPU"); + + #[cfg(feature = "cuda")] + if is_dtype_supported("cuda", DType::F32) { + with_cuda_backend(|cuda_client, cuda_device| { + use numr::runtime::cuda::CudaRuntime; + let mean_cuda = Tensor::::from_slice(&[0.0f32, 0.0], &[2], &cuda_device); + let cov_cuda = + Tensor::::from_slice(&[1.0f32, 0.0, 0.0, 1.0], &[2, 2], &cuda_device); + let result = cuda_client + .multivariate_normal(&mean_cuda, &cov_cuda, n_samples) + .unwrap_or_else(|e| panic!("CUDA multivariate_normal failed: {e}")); + assert_eq!(result.shape(), &[100, 2]); + assert_eq!(result.dtype(), DType::F32); + let vals: Vec = result.to_vec(); + assert_all_finite_f32(&vals, "multivariate_normal CUDA"); + }); + } + + #[cfg(feature = "wgpu")] + if is_dtype_supported("wgpu", DType::F32) { + with_wgpu_backend(|wgpu_client, wgpu_device| { + use numr::runtime::wgpu::WgpuRuntime; + let mean_wgpu = Tensor::::from_slice(&[0.0f32, 0.0], &[2], &wgpu_device); + let cov_wgpu = + Tensor::::from_slice(&[1.0f32, 0.0, 0.0, 1.0], &[2, 2], &wgpu_device); + let result = wgpu_client + .multivariate_normal(&mean_wgpu, &cov_wgpu, n_samples) + .unwrap_or_else(|e| panic!("WebGPU multivariate_normal failed: {e}")); + assert_eq!(result.shape(), &[100, 2]); + assert_eq!(result.dtype(), DType::F32); + let vals: Vec = result.to_vec(); + assert_all_finite_f32(&vals, "multivariate_normal WebGPU"); + }); + } +} + +/// Test multivariate_normal statistical properties: sample mean converges to true mean +#[test] +fn test_multivariate_normal_statistical_properties() { + let true_mean = [2.0f32, -1.0f32]; + // With 5000 samples and identity cov, sample mean should be within ~0.1 of true mean + + let (cpu_client, cpu_device) = create_cpu_client(); + let mean = Tensor::::from_slice(&true_mean, &[2], &cpu_device); + let cov = Tensor::::from_slice(&[1.0f32, 0.0, 0.0, 1.0], &[2, 2], &cpu_device); + let result = cpu_client + .multivariate_normal(&mean, &cov, 5000) + .unwrap_or_else(|e| panic!("CPU multivariate_normal statistical test failed: {e}")); + let vals: Vec = result.to_vec(); + check_column_mean_f32(&vals, 2, &true_mean, 0.1, "multivariate_normal CPU"); + + #[cfg(feature = "cuda")] + if is_dtype_supported("cuda", DType::F32) { + with_cuda_backend(|cuda_client, cuda_device| { + use numr::runtime::cuda::CudaRuntime; + let mean_cuda = Tensor::::from_slice(&true_mean, &[2], &cuda_device); + let cov_cuda = + Tensor::::from_slice(&[1.0f32, 0.0, 0.0, 1.0], &[2, 2], &cuda_device); + let result = cuda_client + .multivariate_normal(&mean_cuda, &cov_cuda, 5000) + .unwrap_or_else(|e| { + panic!("CUDA multivariate_normal statistical test failed: {e}") + }); + let vals: Vec = result.to_vec(); + check_column_mean_f32(&vals, 2, &true_mean, 0.1, "multivariate_normal CUDA"); + }); + } +} + +/// Test multivariate_normal with F64 dtype +#[test] +fn test_multivariate_normal_f64() { + let (cpu_client, cpu_device) = create_cpu_client(); + let mean = Tensor::::from_slice(&[0.0f64, 0.0], &[2], &cpu_device); + let cov = Tensor::::from_slice(&[1.0f64, 0.0, 0.0, 1.0], &[2, 2], &cpu_device); + + let result = cpu_client + .multivariate_normal(&mean, &cov, 100) + .unwrap_or_else(|e| panic!("CPU multivariate_normal F64 failed: {e}")); + assert_eq!(result.shape(), &[100, 2]); + assert_eq!(result.dtype(), DType::F64); + let vals: Vec = result.to_vec(); + for (i, &v) in vals.iter().enumerate() { + assert!(v.is_finite(), "f64 value at index {i} is not finite: {v}"); + } + + #[cfg(feature = "cuda")] + if is_dtype_supported("cuda", DType::F64) { + with_cuda_backend(|cuda_client, cuda_device| { + use numr::runtime::cuda::CudaRuntime; + let mean_cuda = Tensor::::from_slice(&[0.0f64, 0.0], &[2], &cuda_device); + let cov_cuda = + Tensor::::from_slice(&[1.0f64, 0.0, 0.0, 1.0], &[2, 2], &cuda_device); + let result = cuda_client + .multivariate_normal(&mean_cuda, &cov_cuda, 100) + .unwrap_or_else(|e| panic!("CUDA multivariate_normal F64 failed: {e}")); + assert_eq!(result.shape(), &[100, 2]); + assert_eq!(result.dtype(), DType::F64); + let vals: Vec = result.to_vec(); + for (i, &v) in vals.iter().enumerate() { + assert!( + v.is_finite(), + "CUDA f64 value at index {i} is not finite: {v}" + ); + } + }); + } +} + +// ============================================================================ +// dirichlet tests +// ============================================================================ + +/// Test dirichlet produces correct shape, dtype, non-negativity, and row sums on all backends +#[test] +fn test_dirichlet_shape_and_constraints() { + let n_samples = 200usize; + + let (cpu_client, cpu_device) = create_cpu_client(); + let alpha = Tensor::::from_slice(&[1.0f32, 1.0, 1.0], &[3], &cpu_device); + + let result = cpu_client + .dirichlet(&alpha, n_samples) + .unwrap_or_else(|e| panic!("CPU dirichlet failed: {e}")); + + assert_eq!(result.shape(), &[200, 3], "dirichlet shape mismatch"); + assert_eq!(result.dtype(), DType::F32); + let vals: Vec = result.to_vec(); + assert_all_finite_f32(&vals, "dirichlet CPU"); + assert_all_non_negative_f32(&vals, "dirichlet CPU"); + assert_rows_sum_to_f32(&vals, 3, 1.0, 1e-5, "dirichlet CPU"); + + #[cfg(feature = "cuda")] + if is_dtype_supported("cuda", DType::F32) { + with_cuda_backend(|cuda_client, cuda_device| { + use numr::runtime::cuda::CudaRuntime; + let alpha_cuda = + Tensor::::from_slice(&[1.0f32, 1.0, 1.0], &[3], &cuda_device); + let result = cuda_client + .dirichlet(&alpha_cuda, n_samples) + .unwrap_or_else(|e| panic!("CUDA dirichlet failed: {e}")); + assert_eq!(result.shape(), &[200, 3]); + assert_eq!(result.dtype(), DType::F32); + let vals: Vec = result.to_vec(); + assert_all_finite_f32(&vals, "dirichlet CUDA"); + assert_all_non_negative_f32(&vals, "dirichlet CUDA"); + assert_rows_sum_to_f32(&vals, 3, 1.0, 1e-5, "dirichlet CUDA"); + }); + } + + #[cfg(feature = "wgpu")] + if is_dtype_supported("wgpu", DType::F32) { + with_wgpu_backend(|wgpu_client, wgpu_device| { + use numr::runtime::wgpu::WgpuRuntime; + let alpha_wgpu = + Tensor::::from_slice(&[1.0f32, 1.0, 1.0], &[3], &wgpu_device); + let result = wgpu_client + .dirichlet(&alpha_wgpu, n_samples) + .unwrap_or_else(|e| panic!("WebGPU dirichlet failed: {e}")); + assert_eq!(result.shape(), &[200, 3]); + assert_eq!(result.dtype(), DType::F32); + let vals: Vec = result.to_vec(); + assert_all_finite_f32(&vals, "dirichlet WebGPU"); + assert_all_non_negative_f32(&vals, "dirichlet WebGPU"); + assert_rows_sum_to_f32(&vals, 3, 1.0, 1e-5, "dirichlet WebGPU"); + }); + } +} + +/// Test dirichlet statistical properties: sample mean converges to alpha_i / sum(alpha) +#[test] +fn test_dirichlet_concentrated_mean() { + // alpha = [10, 10, 10] -> symmetric, expected mean [1/3, 1/3, 1/3] + let expected_means = [1.0f32 / 3.0, 1.0 / 3.0, 1.0 / 3.0]; + + let (cpu_client, cpu_device) = create_cpu_client(); + let alpha = Tensor::::from_slice(&[10.0f32, 10.0, 10.0], &[3], &cpu_device); + let result = cpu_client + .dirichlet(&alpha, 2000) + .unwrap_or_else(|e| panic!("CPU dirichlet concentrated mean test failed: {e}")); + let vals: Vec = result.to_vec(); + check_column_mean_f32( + &vals, + 3, + &expected_means, + 0.05, + "dirichlet CPU concentrated", + ); + + #[cfg(feature = "cuda")] + if is_dtype_supported("cuda", DType::F32) { + with_cuda_backend(|cuda_client, cuda_device| { + use numr::runtime::cuda::CudaRuntime; + let alpha_cuda = + Tensor::::from_slice(&[10.0f32, 10.0, 10.0], &[3], &cuda_device); + let result = cuda_client + .dirichlet(&alpha_cuda, 2000) + .unwrap_or_else(|e| panic!("CUDA dirichlet concentrated mean test failed: {e}")); + let vals: Vec = result.to_vec(); + check_column_mean_f32( + &vals, + 3, + &expected_means, + 0.05, + "dirichlet CUDA concentrated", + ); + }); + } +} + +// ============================================================================ +// multinomial_samples tests +// ============================================================================ + +/// Test multinomial_samples produces correct shape, dtype, non-negativity, and row sums on all backends +#[test] +fn test_multinomial_samples_shape_and_constraints() { + let n_trials = 50usize; + let n_samples = 100usize; + + let (cpu_client, cpu_device) = create_cpu_client(); + let probs = Tensor::::from_slice(&[0.5f32, 0.3, 0.2], &[3], &cpu_device); + + let result = cpu_client + .multinomial_samples(&probs, n_trials, n_samples) + .unwrap_or_else(|e| panic!("CPU multinomial_samples failed: {e}")); + + assert_eq!(result.shape(), &[100, 3], "multinomial shape mismatch"); + assert_eq!(result.dtype(), DType::F32); + let vals: Vec = result.to_vec(); + assert_all_finite_f32(&vals, "multinomial CPU"); + assert_all_non_negative_f32(&vals, "multinomial CPU"); + assert_rows_sum_to_f32(&vals, 3, n_trials as f32, 1e-4, "multinomial CPU"); + + #[cfg(feature = "cuda")] + if is_dtype_supported("cuda", DType::F32) { + with_cuda_backend(|cuda_client, cuda_device| { + use numr::runtime::cuda::CudaRuntime; + let probs_cuda = + Tensor::::from_slice(&[0.5f32, 0.3, 0.2], &[3], &cuda_device); + let result = cuda_client + .multinomial_samples(&probs_cuda, n_trials, n_samples) + .unwrap_or_else(|e| panic!("CUDA multinomial_samples failed: {e}")); + assert_eq!(result.shape(), &[100, 3]); + assert_eq!(result.dtype(), DType::F32); + let vals: Vec = result.to_vec(); + assert_all_finite_f32(&vals, "multinomial CUDA"); + assert_all_non_negative_f32(&vals, "multinomial CUDA"); + assert_rows_sum_to_f32(&vals, 3, n_trials as f32, 1e-4, "multinomial CUDA"); + }); + } + + #[cfg(feature = "wgpu")] + if is_dtype_supported("wgpu", DType::F32) { + with_wgpu_backend(|wgpu_client, wgpu_device| { + use numr::runtime::wgpu::WgpuRuntime; + let probs_wgpu = + Tensor::::from_slice(&[0.5f32, 0.3, 0.2], &[3], &wgpu_device); + let result = wgpu_client + .multinomial_samples(&probs_wgpu, n_trials, n_samples) + .unwrap_or_else(|e| panic!("WebGPU multinomial_samples failed: {e}")); + assert_eq!(result.shape(), &[100, 3]); + assert_eq!(result.dtype(), DType::F32); + let vals: Vec = result.to_vec(); + assert_all_finite_f32(&vals, "multinomial WebGPU"); + assert_all_non_negative_f32(&vals, "multinomial WebGPU"); + assert_rows_sum_to_f32(&vals, 3, n_trials as f32, 1e-4, "multinomial WebGPU"); + }); + } +} + +/// Test multinomial_samples statistical properties: mean counts proportional to probs +#[test] +fn test_multinomial_mean_proportional_to_probs() { + // Expected mean for each category = n_trials * p_i + let n_trials = 100usize; + let expected_means = [50.0f32, 30.0, 20.0]; + + let (cpu_client, cpu_device) = create_cpu_client(); + let probs = Tensor::::from_slice(&[0.5f32, 0.3, 0.2], &[3], &cpu_device); + let result = cpu_client + .multinomial_samples(&probs, n_trials, 2000) + .unwrap_or_else(|e| panic!("CPU multinomial mean test failed: {e}")); + let vals: Vec = result.to_vec(); + check_column_mean_f32(&vals, 3, &expected_means, 2.0, "multinomial CPU mean"); + + #[cfg(feature = "cuda")] + if is_dtype_supported("cuda", DType::F32) { + with_cuda_backend(|cuda_client, cuda_device| { + use numr::runtime::cuda::CudaRuntime; + let probs_cuda = + Tensor::::from_slice(&[0.5f32, 0.3, 0.2], &[3], &cuda_device); + let result = cuda_client + .multinomial_samples(&probs_cuda, n_trials, 2000) + .unwrap_or_else(|e| panic!("CUDA multinomial mean test failed: {e}")); + let vals: Vec = result.to_vec(); + check_column_mean_f32(&vals, 3, &expected_means, 2.0, "multinomial CUDA mean"); + }); + } +} + +// ============================================================================ +// wishart tests +// ============================================================================ + +/// Test wishart produces correct shape, dtype, and positive diagonal elements on all backends +#[test] +fn test_wishart_shape_and_positivity() { + let df = 5usize; + let n_samples = 50usize; + + let (cpu_client, cpu_device) = create_cpu_client(); + let scale = Tensor::::from_slice(&[1.0f32, 0.0, 0.0, 1.0], &[2, 2], &cpu_device); + + let result = cpu_client + .wishart(&scale, df, n_samples) + .unwrap_or_else(|e| panic!("CPU wishart failed: {e}")); + + assert_eq!(result.shape(), &[50, 2, 2], "wishart shape mismatch"); + assert_eq!(result.dtype(), DType::F32); + let vals: Vec = result.to_vec(); + assert_all_finite_f32(&vals, "wishart CPU"); + // Diagonal elements (variances) must be positive + for i in 0..n_samples { + let base = i * 4; // 2x2 matrix + assert!( + vals[base] > 0.0, + "wishart CPU sample {i}: [0,0] diagonal not positive: {}", + vals[base] + ); + assert!( + vals[base + 3] > 0.0, + "wishart CPU sample {i}: [1,1] diagonal not positive: {}", + vals[base + 3] + ); + } + + #[cfg(feature = "cuda")] + if is_dtype_supported("cuda", DType::F32) { + with_cuda_backend(|cuda_client, cuda_device| { + use numr::runtime::cuda::CudaRuntime; + let scale_cuda = + Tensor::::from_slice(&[1.0f32, 0.0, 0.0, 1.0], &[2, 2], &cuda_device); + let result = cuda_client + .wishart(&scale_cuda, df, n_samples) + .unwrap_or_else(|e| panic!("CUDA wishart failed: {e}")); + assert_eq!(result.shape(), &[50, 2, 2]); + assert_eq!(result.dtype(), DType::F32); + let vals: Vec = result.to_vec(); + assert_all_finite_f32(&vals, "wishart CUDA"); + for i in 0..n_samples { + let base = i * 4; + assert!( + vals[base] > 0.0, + "wishart CUDA sample {i}: [0,0] not positive" + ); + assert!( + vals[base + 3] > 0.0, + "wishart CUDA sample {i}: [1,1] not positive" + ); + } + }); + } + + #[cfg(feature = "wgpu")] + if is_dtype_supported("wgpu", DType::F32) { + with_wgpu_backend(|wgpu_client, wgpu_device| { + use numr::runtime::wgpu::WgpuRuntime; + let scale_wgpu = + Tensor::::from_slice(&[1.0f32, 0.0, 0.0, 1.0], &[2, 2], &wgpu_device); + let result = wgpu_client + .wishart(&scale_wgpu, df, n_samples) + .unwrap_or_else(|e| panic!("WebGPU wishart failed: {e}")); + assert_eq!(result.shape(), &[50, 2, 2]); + assert_eq!(result.dtype(), DType::F32); + let vals: Vec = result.to_vec(); + assert_all_finite_f32(&vals, "wishart WebGPU"); + for i in 0..n_samples { + let base = i * 4; + assert!( + vals[base] > 0.0, + "wishart WebGPU sample {i}: [0,0] not positive" + ); + assert!( + vals[base + 3] > 0.0, + "wishart WebGPU sample {i}: [1,1] not positive" + ); + } + }); + } +} diff --git a/tests/backend_parity/normalization.rs b/tests/backend_parity/normalization.rs new file mode 100644 index 00000000..f7fce60f --- /dev/null +++ b/tests/backend_parity/normalization.rs @@ -0,0 +1,618 @@ +// Backend parity tests for fused add+normalization operations (NormalizationOps trait) +// +// Tests: fused_add_rms_norm, fused_add_layer_norm (forward) +// fused_add_rms_norm_bwd, fused_add_layer_norm_bwd (backward) +// +// Dtype-parameterized: each test runs for all supported dtypes across all backends. + +use numr::dtype::DType; +use numr::ops::NormalizationOps; +use numr::tensor::Tensor; + +use crate::backend_parity::dtype_helpers::tensor_from_f64; +#[cfg(feature = "cuda")] +use crate::backend_parity::helpers::with_cuda_backend; +#[cfg(feature = "wgpu")] +use crate::backend_parity::helpers::with_wgpu_backend; +use crate::common::{ + assert_tensor_allclose, create_cpu_client, is_dtype_supported, supported_dtypes, +}; + +// ============================================================================ +// Test Data +// ============================================================================ + +struct FusedNormTestCase { + x: Vec, + residual: Vec, + weight: Vec, + bias: Vec, + shape: Vec, + hidden_size: usize, +} + +fn test_cases() -> Vec { + vec![ + // [4, 8] - simple 2D + FusedNormTestCase { + x: (0..32).map(|i| (i as f64) * 0.1 - 1.6).collect(), + residual: (0..32).map(|i| (i as f64) * 0.05 + 0.1).collect(), + weight: vec![1.0, 0.5, 2.0, 1.5, 0.8, 1.2, 0.7, 1.1], + bias: vec![0.1, -0.1, 0.2, 0.0, -0.2, 0.3, 0.0, 0.1], + shape: vec![4, 8], + hidden_size: 8, + }, + // [2, 3, 16] - 3D batched + FusedNormTestCase { + x: (0..96).map(|i| ((i as f64) * 0.07 - 3.0).sin()).collect(), + residual: (0..96).map(|i| ((i as f64) * 0.13 + 1.0).cos()).collect(), + weight: (0..16).map(|i| 0.5 + (i as f64) * 0.1).collect(), + bias: (0..16).map(|i| -0.5 + (i as f64) * 0.05).collect(), + shape: vec![2, 3, 16], + hidden_size: 16, + }, + // [1, 64] - single batch, larger hidden + FusedNormTestCase { + x: (0..64).map(|i| (i as f64) * 0.03 - 1.0).collect(), + residual: (0..64).map(|i| (i as f64) * 0.02 + 0.5).collect(), + weight: vec![1.0; 64], + bias: vec![0.0; 64], + shape: vec![1, 64], + hidden_size: 64, + }, + ] +} + +// ============================================================================ +// Fused Add + RMS Norm Forward +// ============================================================================ + +fn test_fused_add_rms_norm_parity_impl(dtype: DType) { + let (cpu_client, cpu_device) = create_cpu_client(); + let cases = test_cases(); + let eps = 1e-5f32; + + let cpu_results: Vec<( + Tensor, + Tensor, + )> = cases + .iter() + .map(|tc| { + let x = tensor_from_f64(&tc.x, &tc.shape, dtype, &cpu_device, &cpu_client).unwrap(); + let res = + tensor_from_f64(&tc.residual, &tc.shape, dtype, &cpu_device, &cpu_client).unwrap(); + let w = tensor_from_f64( + &tc.weight, + &[tc.hidden_size], + dtype, + &cpu_device, + &cpu_client, + ) + .unwrap(); + cpu_client.fused_add_rms_norm(&x, &res, &w, eps).unwrap() + }) + .collect(); + + #[cfg(feature = "cuda")] + if is_dtype_supported("cuda", dtype) { + with_cuda_backend(|cuda_client, cuda_device| { + for (idx, tc) in cases.iter().enumerate() { + let x = + tensor_from_f64(&tc.x, &tc.shape, dtype, &cuda_device, &cuda_client).unwrap(); + let res = + tensor_from_f64(&tc.residual, &tc.shape, dtype, &cuda_device, &cuda_client) + .unwrap(); + let w = tensor_from_f64( + &tc.weight, + &[tc.hidden_size], + dtype, + &cuda_device, + &cuda_client, + ) + .unwrap(); + let (out, pre_norm) = cuda_client.fused_add_rms_norm(&x, &res, &w, eps).unwrap(); + assert_tensor_allclose( + &out, + &cpu_results[idx].0, + dtype, + &format!("fused_add_rms_norm output CUDA vs CPU [{dtype:?}] case {idx}"), + ); + assert_tensor_allclose( + &pre_norm, + &cpu_results[idx].1, + dtype, + &format!("fused_add_rms_norm pre_norm CUDA vs CPU [{dtype:?}] case {idx}"), + ); + } + }); + } + + #[cfg(feature = "wgpu")] + if is_dtype_supported("wgpu", dtype) { + with_wgpu_backend(|wgpu_client, wgpu_device| { + for (idx, tc) in cases.iter().enumerate() { + let x = + tensor_from_f64(&tc.x, &tc.shape, dtype, &wgpu_device, &wgpu_client).unwrap(); + let res = + tensor_from_f64(&tc.residual, &tc.shape, dtype, &wgpu_device, &wgpu_client) + .unwrap(); + let w = tensor_from_f64( + &tc.weight, + &[tc.hidden_size], + dtype, + &wgpu_device, + &wgpu_client, + ) + .unwrap(); + let (out, pre_norm) = wgpu_client.fused_add_rms_norm(&x, &res, &w, eps).unwrap(); + assert_tensor_allclose( + &out, + &cpu_results[idx].0, + dtype, + &format!("fused_add_rms_norm output WebGPU vs CPU [{dtype:?}] case {idx}"), + ); + assert_tensor_allclose( + &pre_norm, + &cpu_results[idx].1, + dtype, + &format!("fused_add_rms_norm pre_norm WebGPU vs CPU [{dtype:?}] case {idx}"), + ); + } + }); + } +} + +#[test] +fn test_fused_add_rms_norm_parity() { + for dtype in supported_dtypes("cpu") { + test_fused_add_rms_norm_parity_impl(dtype); + } +} + +// ============================================================================ +// Fused Add + Layer Norm Forward +// ============================================================================ + +fn test_fused_add_layer_norm_parity_impl(dtype: DType) { + let (cpu_client, cpu_device) = create_cpu_client(); + let cases = test_cases(); + let eps = 1e-5f32; + + let cpu_results: Vec<( + Tensor, + Tensor, + )> = cases + .iter() + .map(|tc| { + let x = tensor_from_f64(&tc.x, &tc.shape, dtype, &cpu_device, &cpu_client).unwrap(); + let res = + tensor_from_f64(&tc.residual, &tc.shape, dtype, &cpu_device, &cpu_client).unwrap(); + let w = tensor_from_f64( + &tc.weight, + &[tc.hidden_size], + dtype, + &cpu_device, + &cpu_client, + ) + .unwrap(); + let b = tensor_from_f64(&tc.bias, &[tc.hidden_size], dtype, &cpu_device, &cpu_client) + .unwrap(); + cpu_client + .fused_add_layer_norm(&x, &res, &w, &b, eps) + .unwrap() + }) + .collect(); + + #[cfg(feature = "cuda")] + if is_dtype_supported("cuda", dtype) { + with_cuda_backend(|cuda_client, cuda_device| { + for (idx, tc) in cases.iter().enumerate() { + let x = + tensor_from_f64(&tc.x, &tc.shape, dtype, &cuda_device, &cuda_client).unwrap(); + let res = + tensor_from_f64(&tc.residual, &tc.shape, dtype, &cuda_device, &cuda_client) + .unwrap(); + let w = tensor_from_f64( + &tc.weight, + &[tc.hidden_size], + dtype, + &cuda_device, + &cuda_client, + ) + .unwrap(); + let b = tensor_from_f64( + &tc.bias, + &[tc.hidden_size], + dtype, + &cuda_device, + &cuda_client, + ) + .unwrap(); + let (out, pre_norm) = cuda_client + .fused_add_layer_norm(&x, &res, &w, &b, eps) + .unwrap(); + assert_tensor_allclose( + &out, + &cpu_results[idx].0, + dtype, + &format!("fused_add_layer_norm output CUDA vs CPU [{dtype:?}] case {idx}"), + ); + assert_tensor_allclose( + &pre_norm, + &cpu_results[idx].1, + dtype, + &format!("fused_add_layer_norm pre_norm CUDA vs CPU [{dtype:?}] case {idx}"), + ); + } + }); + } + + #[cfg(feature = "wgpu")] + if is_dtype_supported("wgpu", dtype) { + with_wgpu_backend(|wgpu_client, wgpu_device| { + for (idx, tc) in cases.iter().enumerate() { + let x = + tensor_from_f64(&tc.x, &tc.shape, dtype, &wgpu_device, &wgpu_client).unwrap(); + let res = + tensor_from_f64(&tc.residual, &tc.shape, dtype, &wgpu_device, &wgpu_client) + .unwrap(); + let w = tensor_from_f64( + &tc.weight, + &[tc.hidden_size], + dtype, + &wgpu_device, + &wgpu_client, + ) + .unwrap(); + let b = tensor_from_f64( + &tc.bias, + &[tc.hidden_size], + dtype, + &wgpu_device, + &wgpu_client, + ) + .unwrap(); + let (out, pre_norm) = wgpu_client + .fused_add_layer_norm(&x, &res, &w, &b, eps) + .unwrap(); + assert_tensor_allclose( + &out, + &cpu_results[idx].0, + dtype, + &format!("fused_add_layer_norm output WebGPU vs CPU [{dtype:?}] case {idx}"), + ); + assert_tensor_allclose( + &pre_norm, + &cpu_results[idx].1, + dtype, + &format!("fused_add_layer_norm pre_norm WebGPU vs CPU [{dtype:?}] case {idx}"), + ); + } + }); + } +} + +#[test] +fn test_fused_add_layer_norm_parity() { + for dtype in supported_dtypes("cpu") { + test_fused_add_layer_norm_parity_impl(dtype); + } +} + +// ============================================================================ +// Fused Add + RMS Norm Backward +// ============================================================================ + +fn test_fused_add_rms_norm_bwd_parity_impl(dtype: DType) { + let (cpu_client, cpu_device) = create_cpu_client(); + let cases = test_cases(); + let eps = 1e-5f32; + + // First compute pre_norm via forward, then test backward + let cpu_results: Vec<( + Tensor, + Tensor, + )> = cases + .iter() + .map(|tc| { + let x = tensor_from_f64(&tc.x, &tc.shape, dtype, &cpu_device, &cpu_client).unwrap(); + let res = + tensor_from_f64(&tc.residual, &tc.shape, dtype, &cpu_device, &cpu_client).unwrap(); + let w = tensor_from_f64( + &tc.weight, + &[tc.hidden_size], + dtype, + &cpu_device, + &cpu_client, + ) + .unwrap(); + let (_out, pre_norm) = cpu_client.fused_add_rms_norm(&x, &res, &w, eps).unwrap(); + let grad_data: Vec = (0..tc.x.len()) + .map(|i| ((i as f64) * 0.1).sin() + 0.5) + .collect(); + let grad = + tensor_from_f64(&grad_data, &tc.shape, dtype, &cpu_device, &cpu_client).unwrap(); + cpu_client + .fused_add_rms_norm_bwd(&grad, &pre_norm, &w, eps) + .unwrap() + }) + .collect(); + + #[cfg(feature = "cuda")] + if is_dtype_supported("cuda", dtype) { + with_cuda_backend(|cuda_client, cuda_device| { + for (idx, tc) in cases.iter().enumerate() { + let x = + tensor_from_f64(&tc.x, &tc.shape, dtype, &cuda_device, &cuda_client).unwrap(); + let res = + tensor_from_f64(&tc.residual, &tc.shape, dtype, &cuda_device, &cuda_client) + .unwrap(); + let w = tensor_from_f64( + &tc.weight, + &[tc.hidden_size], + dtype, + &cuda_device, + &cuda_client, + ) + .unwrap(); + let (_out, pre_norm) = cuda_client.fused_add_rms_norm(&x, &res, &w, eps).unwrap(); + let grad_data: Vec = (0..tc.x.len()) + .map(|i| ((i as f64) * 0.1).sin() + 0.5) + .collect(); + let grad = + tensor_from_f64(&grad_data, &tc.shape, dtype, &cuda_device, &cuda_client) + .unwrap(); + let (d_input_res, d_weight) = cuda_client + .fused_add_rms_norm_bwd(&grad, &pre_norm, &w, eps) + .unwrap(); + assert_tensor_allclose( + &d_input_res, + &cpu_results[idx].0, + dtype, + &format!( + "fused_add_rms_norm_bwd d_input_residual CUDA vs CPU [{dtype:?}] case {idx}" + ), + ); + assert_tensor_allclose( + &d_weight, + &cpu_results[idx].1, + dtype, + &format!("fused_add_rms_norm_bwd d_weight CUDA vs CPU [{dtype:?}] case {idx}"), + ); + } + }); + } + + #[cfg(feature = "wgpu")] + if is_dtype_supported("wgpu", dtype) { + with_wgpu_backend(|wgpu_client, wgpu_device| { + for (idx, tc) in cases.iter().enumerate() { + let x = + tensor_from_f64(&tc.x, &tc.shape, dtype, &wgpu_device, &wgpu_client).unwrap(); + let res = + tensor_from_f64(&tc.residual, &tc.shape, dtype, &wgpu_device, &wgpu_client) + .unwrap(); + let w = tensor_from_f64( + &tc.weight, + &[tc.hidden_size], + dtype, + &wgpu_device, + &wgpu_client, + ) + .unwrap(); + let (_out, pre_norm) = wgpu_client.fused_add_rms_norm(&x, &res, &w, eps).unwrap(); + let grad_data: Vec = (0..tc.x.len()) + .map(|i| ((i as f64) * 0.1).sin() + 0.5) + .collect(); + let grad = + tensor_from_f64(&grad_data, &tc.shape, dtype, &wgpu_device, &wgpu_client) + .unwrap(); + let (d_input_res, d_weight) = wgpu_client + .fused_add_rms_norm_bwd(&grad, &pre_norm, &w, eps) + .unwrap(); + assert_tensor_allclose( + &d_input_res, + &cpu_results[idx].0, + dtype, + &format!( + "fused_add_rms_norm_bwd d_input_residual WebGPU vs CPU [{dtype:?}] case {idx}" + ), + ); + assert_tensor_allclose( + &d_weight, + &cpu_results[idx].1, + dtype, + &format!( + "fused_add_rms_norm_bwd d_weight WebGPU vs CPU [{dtype:?}] case {idx}" + ), + ); + } + }); + } +} + +#[test] +fn test_fused_add_rms_norm_bwd_parity() { + for dtype in supported_dtypes("cpu") { + test_fused_add_rms_norm_bwd_parity_impl(dtype); + } +} + +// ============================================================================ +// Fused Add + Layer Norm Backward +// ============================================================================ + +fn test_fused_add_layer_norm_bwd_parity_impl(dtype: DType) { + let (cpu_client, cpu_device) = create_cpu_client(); + let cases = test_cases(); + let eps = 1e-5f32; + + let cpu_results: Vec<( + Tensor, + Tensor, + Tensor, + )> = cases + .iter() + .map(|tc| { + let x = tensor_from_f64(&tc.x, &tc.shape, dtype, &cpu_device, &cpu_client).unwrap(); + let res = + tensor_from_f64(&tc.residual, &tc.shape, dtype, &cpu_device, &cpu_client).unwrap(); + let w = tensor_from_f64( + &tc.weight, + &[tc.hidden_size], + dtype, + &cpu_device, + &cpu_client, + ) + .unwrap(); + let b = tensor_from_f64(&tc.bias, &[tc.hidden_size], dtype, &cpu_device, &cpu_client) + .unwrap(); + let (_out, pre_norm) = cpu_client + .fused_add_layer_norm(&x, &res, &w, &b, eps) + .unwrap(); + let grad_data: Vec = (0..tc.x.len()) + .map(|i| ((i as f64) * 0.1).sin() + 0.5) + .collect(); + let grad = + tensor_from_f64(&grad_data, &tc.shape, dtype, &cpu_device, &cpu_client).unwrap(); + cpu_client + .fused_add_layer_norm_bwd(&grad, &pre_norm, &w, &b, eps) + .unwrap() + }) + .collect(); + + #[cfg(feature = "cuda")] + if is_dtype_supported("cuda", dtype) { + with_cuda_backend(|cuda_client, cuda_device| { + for (idx, tc) in cases.iter().enumerate() { + let x = + tensor_from_f64(&tc.x, &tc.shape, dtype, &cuda_device, &cuda_client).unwrap(); + let res = + tensor_from_f64(&tc.residual, &tc.shape, dtype, &cuda_device, &cuda_client) + .unwrap(); + let w = tensor_from_f64( + &tc.weight, + &[tc.hidden_size], + dtype, + &cuda_device, + &cuda_client, + ) + .unwrap(); + let b = tensor_from_f64( + &tc.bias, + &[tc.hidden_size], + dtype, + &cuda_device, + &cuda_client, + ) + .unwrap(); + let (_out, pre_norm) = cuda_client + .fused_add_layer_norm(&x, &res, &w, &b, eps) + .unwrap(); + let grad_data: Vec = (0..tc.x.len()) + .map(|i| ((i as f64) * 0.1).sin() + 0.5) + .collect(); + let grad = + tensor_from_f64(&grad_data, &tc.shape, dtype, &cuda_device, &cuda_client) + .unwrap(); + let (d_input_res, d_weight, d_bias) = cuda_client + .fused_add_layer_norm_bwd(&grad, &pre_norm, &w, &b, eps) + .unwrap(); + assert_tensor_allclose( + &d_input_res, + &cpu_results[idx].0, + dtype, + &format!( + "fused_add_layer_norm_bwd d_input_residual CUDA vs CPU [{dtype:?}] case {idx}" + ), + ); + assert_tensor_allclose( + &d_weight, + &cpu_results[idx].1, + dtype, + &format!( + "fused_add_layer_norm_bwd d_weight CUDA vs CPU [{dtype:?}] case {idx}" + ), + ); + assert_tensor_allclose( + &d_bias, + &cpu_results[idx].2, + dtype, + &format!("fused_add_layer_norm_bwd d_bias CUDA vs CPU [{dtype:?}] case {idx}"), + ); + } + }); + } + + #[cfg(feature = "wgpu")] + if is_dtype_supported("wgpu", dtype) { + with_wgpu_backend(|wgpu_client, wgpu_device| { + for (idx, tc) in cases.iter().enumerate() { + let x = + tensor_from_f64(&tc.x, &tc.shape, dtype, &wgpu_device, &wgpu_client).unwrap(); + let res = + tensor_from_f64(&tc.residual, &tc.shape, dtype, &wgpu_device, &wgpu_client) + .unwrap(); + let w = tensor_from_f64( + &tc.weight, + &[tc.hidden_size], + dtype, + &wgpu_device, + &wgpu_client, + ) + .unwrap(); + let b = tensor_from_f64( + &tc.bias, + &[tc.hidden_size], + dtype, + &wgpu_device, + &wgpu_client, + ) + .unwrap(); + let (_out, pre_norm) = wgpu_client + .fused_add_layer_norm(&x, &res, &w, &b, eps) + .unwrap(); + let grad_data: Vec = (0..tc.x.len()) + .map(|i| ((i as f64) * 0.1).sin() + 0.5) + .collect(); + let grad = + tensor_from_f64(&grad_data, &tc.shape, dtype, &wgpu_device, &wgpu_client) + .unwrap(); + let (d_input_res, d_weight, d_bias) = wgpu_client + .fused_add_layer_norm_bwd(&grad, &pre_norm, &w, &b, eps) + .unwrap(); + assert_tensor_allclose( + &d_input_res, + &cpu_results[idx].0, + dtype, + &format!( + "fused_add_layer_norm_bwd d_input_residual WebGPU vs CPU [{dtype:?}] case {idx}" + ), + ); + assert_tensor_allclose( + &d_weight, + &cpu_results[idx].1, + dtype, + &format!( + "fused_add_layer_norm_bwd d_weight WebGPU vs CPU [{dtype:?}] case {idx}" + ), + ); + assert_tensor_allclose( + &d_bias, + &cpu_results[idx].2, + dtype, + &format!( + "fused_add_layer_norm_bwd d_bias WebGPU vs CPU [{dtype:?}] case {idx}" + ), + ); + } + }); + } +} + +#[test] +fn test_fused_add_layer_norm_bwd_parity() { + for dtype in supported_dtypes("cpu") { + test_fused_add_layer_norm_bwd_parity_impl(dtype); + } +} diff --git a/tests/backend_parity/polynomial.rs b/tests/backend_parity/polynomial.rs index 7fd2a978..4db9f363 100644 --- a/tests/backend_parity/polynomial.rs +++ b/tests/backend_parity/polynomial.rs @@ -5,7 +5,6 @@ use numr::algorithm::polynomial::PolynomialAlgorithms; use numr::dtype::DType; -use numr::runtime::Runtime; use numr::runtime::cpu::CpuRuntime; use numr::tensor::Tensor; diff --git a/tests/backend_parity/random.rs b/tests/backend_parity/random.rs index 71a2c4fc..7fab3fe3 100644 --- a/tests/backend_parity/random.rs +++ b/tests/backend_parity/random.rs @@ -145,9 +145,8 @@ fn test_rand_invariants_all_backends() { }}; } - match dtype { - DType::F32 => check_wgpu!(f32), // WebGPU: F32 only - _ => {} + if dtype == DType::F32 { + check_wgpu!(f32); // WebGPU: F32 only } }); } @@ -244,9 +243,8 @@ fn test_randn_invariants_all_backends() { }}; } - match dtype { - DType::F32 => check_wgpu!(f32), // WebGPU: F32 only - _ => {} + if dtype == DType::F32 { + check_wgpu!(f32); // WebGPU: F32 only } }); } @@ -341,3 +339,75 @@ fn test_rand_shape_dtype_all_backends() { } } } + +// ============================================================ +// rand_seeded reproducibility tests +// ============================================================ + +#[test] +fn test_rand_seeded_reproducibility_cpu() { + let (client, _device) = create_cpu_client(); + + // Same seed → same output + let a = client.rand_seeded(&[100], DType::F32, 42).unwrap(); + let b = client.rand_seeded(&[100], DType::F32, 42).unwrap(); + let a_vec: Vec = a.to_vec(); + let b_vec: Vec = b.to_vec(); + assert_eq!(a_vec, b_vec, "same seed must produce same output"); + + // Different seed → different output + let c = client.rand_seeded(&[100], DType::F32, 99).unwrap(); + let c_vec: Vec = c.to_vec(); + assert_ne!( + a_vec, c_vec, + "different seeds must produce different output" + ); + + // Values in [0, 1) + for &v in &a_vec { + assert!((0.0..1.0).contains(&v), "value out of range: {v}"); + } +} + +#[cfg(feature = "cuda")] +#[test] +fn test_rand_seeded_reproducibility_cuda() { + with_cuda_backend(|client, _device| { + let a = client.rand_seeded(&[100], DType::F32, 42).unwrap(); + let b = client.rand_seeded(&[100], DType::F32, 42).unwrap(); + let a_vec: Vec = a.to_vec(); + let b_vec: Vec = b.to_vec(); + assert_eq!(a_vec, b_vec, "same seed must produce same output on CUDA"); + + let c = client.rand_seeded(&[100], DType::F32, 99).unwrap(); + let c_vec: Vec = c.to_vec(); + assert_ne!( + a_vec, c_vec, + "different seeds must produce different output on CUDA" + ); + }); +} + +#[cfg(feature = "wgpu")] +#[test] +fn test_rand_seeded_reproducibility_wgpu() { + with_wgpu_backend(|client, _device| { + let a = client.rand_seeded(&[100], DType::F32, 42).unwrap(); + let b = client.rand_seeded(&[100], DType::F32, 42).unwrap(); + let a_vec: Vec = a.to_vec(); + let b_vec: Vec = b.to_vec(); + assert_eq!(a_vec, b_vec, "same seed must produce same output on WebGPU"); + + let c = client.rand_seeded(&[100], DType::F32, 99).unwrap(); + let c_vec: Vec = c.to_vec(); + assert_ne!( + a_vec, c_vec, + "different seeds must produce different output on WebGPU" + ); + + // Values in [0, 1) + for &v in &a_vec { + assert!((0.0..1.0).contains(&v), "value out of range: {v}"); + } + }); +} diff --git a/tests/backend_parity/semiring_matmul.rs b/tests/backend_parity/semiring_matmul.rs new file mode 100644 index 00000000..97d8b929 --- /dev/null +++ b/tests/backend_parity/semiring_matmul.rs @@ -0,0 +1,191 @@ +// Backend parity tests for SemiringMatmulOps trait +// +// Tests: semiring_matmul with MinPlus, MaxPlus, MaxMin, MinMax, OrAnd +// CPU is the reference implementation; CUDA and WebGPU must match. + +use crate::backend_parity::dtype_helpers::tensor_from_f64; +#[cfg(feature = "cuda")] +use crate::backend_parity::helpers::with_cuda_backend; +#[cfg(feature = "wgpu")] +use crate::backend_parity::helpers::with_wgpu_backend; +use crate::common::{ + assert_tensor_allclose, create_cpu_client, is_dtype_supported, supported_dtypes, +}; +use numr::dtype::DType; +use numr::ops::{SemiringMatmulOps, SemiringOp}; + +struct SemiringCase { + a: Vec, + a_shape: Vec, + b: Vec, + b_shape: Vec, + op: SemiringOp, +} + +impl SemiringCase { + fn new( + a: Vec, + a_shape: Vec, + b: Vec, + b_shape: Vec, + op: SemiringOp, + ) -> Self { + Self { + a, + a_shape, + b, + b_shape, + op, + } + } +} + +fn semiring_test_cases() -> Vec { + // 2x3 @ 3x2 matrices + let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; + let b = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]; + + vec![ + // MinPlus: shortest path semantics + SemiringCase::new( + a.clone(), + vec![2, 3], + b.clone(), + vec![3, 2], + SemiringOp::MinPlus, + ), + // MaxPlus: longest path semantics + SemiringCase::new( + a.clone(), + vec![2, 3], + b.clone(), + vec![3, 2], + SemiringOp::MaxPlus, + ), + // MaxMin: bottleneck path + SemiringCase::new( + a.clone(), + vec![2, 3], + b.clone(), + vec![3, 2], + SemiringOp::MaxMin, + ), + // MinMax: fuzzy relations + SemiringCase::new( + a.clone(), + vec![2, 3], + b.clone(), + vec![3, 2], + SemiringOp::MinMax, + ), + // Smaller matrices + SemiringCase::new( + vec![1.0, 2.0, 3.0, 4.0], + vec![2, 2], + vec![5.0, 6.0, 7.0, 8.0], + vec![2, 2], + SemiringOp::MinPlus, + ), + // 1x4 @ 4x1 (vector inner product) + SemiringCase::new( + vec![1.0, 2.0, 3.0, 4.0], + vec![1, 4], + vec![5.0, 6.0, 7.0, 8.0], + vec![4, 1], + SemiringOp::MaxPlus, + ), + ] +} + +fn test_semiring_parity(dtype: DType) { + let cases = semiring_test_cases(); + let (cpu_client, cpu_device) = create_cpu_client(); + + for (idx, tc) in cases.iter().enumerate() { + let cpu_a = tensor_from_f64(&tc.a, &tc.a_shape, dtype, &cpu_device, &cpu_client) + .expect("CPU a tensor failed"); + let cpu_b = tensor_from_f64(&tc.b, &tc.b_shape, dtype, &cpu_device, &cpu_client) + .expect("CPU b tensor failed"); + let cpu_result = cpu_client + .semiring_matmul(&cpu_a, &cpu_b, tc.op) + .unwrap_or_else(|e| panic!("CPU semiring {:?} failed for {dtype:?}: {e}", tc.op)); + + #[cfg(feature = "cuda")] + if is_dtype_supported("cuda", dtype) { + with_cuda_backend(|cuda_client, cuda_device| { + let a = tensor_from_f64(&tc.a, &tc.a_shape, dtype, &cuda_device, &cuda_client) + .expect("CUDA a tensor failed"); + let b = tensor_from_f64(&tc.b, &tc.b_shape, dtype, &cuda_device, &cuda_client) + .expect("CUDA b tensor failed"); + let result = cuda_client + .semiring_matmul(&a, &b, tc.op) + .unwrap_or_else(|e| panic!("CUDA semiring failed: {e}")); + assert_tensor_allclose( + &result, + &cpu_result, + dtype, + &format!("semiring {:?} CUDA vs CPU [{dtype:?}] case {idx}", tc.op), + ); + }); + } + + #[cfg(feature = "wgpu")] + if is_dtype_supported("wgpu", dtype) { + with_wgpu_backend(|wgpu_client, wgpu_device| { + let a = tensor_from_f64(&tc.a, &tc.a_shape, dtype, &wgpu_device, &wgpu_client) + .expect("WebGPU a tensor failed"); + let b = tensor_from_f64(&tc.b, &tc.b_shape, dtype, &wgpu_device, &wgpu_client) + .expect("WebGPU b tensor failed"); + let result = wgpu_client + .semiring_matmul(&a, &b, tc.op) + .unwrap_or_else(|e| panic!("WebGPU semiring failed: {e}")); + assert_tensor_allclose( + &result, + &cpu_result, + dtype, + &format!("semiring {:?} WebGPU vs CPU [{dtype:?}] case {idx}", tc.op), + ); + }); + } + } +} + +#[test] +fn test_semiring_matmul_parity_all_dtypes() { + for dtype in supported_dtypes("cpu") { + test_semiring_parity(dtype); + } +} + +// OrAnd operates on Bool tensors (u8: 0/1 values) +#[test] +fn test_semiring_or_and_parity() { + use numr::tensor::Tensor; + + let (cpu_client, cpu_device) = create_cpu_client(); + + // Boolean adjacency matrices + let a: Vec = vec![1, 0, 1, 0, 1, 1, 0, 0, 1]; + let b: Vec = vec![0, 1, 0, 1, 0, 1, 1, 1, 0]; + + let cpu_a = Tensor::::from_slice(&a, &[3, 3], &cpu_device); + let cpu_b = Tensor::::from_slice(&b, &[3, 3], &cpu_device); + #[allow(unused_variables)] + let cpu_result = cpu_client + .semiring_matmul(&cpu_a, &cpu_b, SemiringOp::OrAnd) + .expect("CPU OrAnd failed"); + + // WebGPU skipped: OrAnd requires Bool dtype, WebGPU is 32-bit only + + #[cfg(feature = "cuda")] + with_cuda_backend(|cuda_client, cuda_device| { + let cpu_vals = cpu_result.to_vec::(); + let ca = Tensor::::from_slice(&a, &[3, 3], &cuda_device); + let cb = Tensor::::from_slice(&b, &[3, 3], &cuda_device); + let result = cuda_client + .semiring_matmul(&ca, &cb, SemiringOp::OrAnd) + .expect("CUDA OrAnd failed"); + let cuda_vals = result.to_vec::(); + assert_eq!(cpu_vals, cuda_vals, "OrAnd CUDA vs CPU"); + }); +} diff --git a/tests/backend_parity/sort.rs b/tests/backend_parity/sort.rs index 6bbad29a..cf0c63cc 100644 --- a/tests/backend_parity/sort.rs +++ b/tests/backend_parity/sort.rs @@ -3,11 +3,7 @@ // Dtype-parameterized: each test runs for all supported dtypes across all backends. // Comparison reads back in native dtype via assert_tensor_allclose. -use numr::dtype::DType; use numr::ops::SortingOps; -use numr::runtime::Runtime; -use numr::runtime::cpu::{CpuDevice, CpuRuntime}; -use numr::tensor::Tensor; use crate::backend_parity::dtype_helpers::tensor_from_f64; #[cfg(feature = "cuda")] diff --git a/tests/backend_parity/sparse.rs b/tests/backend_parity/sparse.rs index 31ddcb04..1799bc6e 100644 --- a/tests/backend_parity/sparse.rs +++ b/tests/backend_parity/sparse.rs @@ -17,7 +17,7 @@ use numr::sparse::{CsrData, SparseOps, SparseStorage}; use numr::tensor::Tensor; /// Helper to assert sparse matrices are close within tolerance -fn assert_sparse_allclose( +fn assert_sparse_allclose, B: Runtime>( a: &CsrData, b: &CsrData, _rtol: f64, @@ -68,7 +68,10 @@ fn assert_sparse_allclose( } /// Helper to create a simple test sparse matrix in CSR format -fn create_test_csr_3x3(device: &R::Device, dtype: DType) -> Result> { +fn create_test_csr_3x3>( + device: &R::Device, + dtype: DType, +) -> Result> { // Matrix: // [1.0, 0.0, 2.0] // [0.0, 3.0, 0.0] diff --git a/tests/backend_parity/sparse_24.rs b/tests/backend_parity/sparse_24.rs new file mode 100644 index 00000000..315ec092 --- /dev/null +++ b/tests/backend_parity/sparse_24.rs @@ -0,0 +1,296 @@ +//! Backend parity tests for 2:4 structured sparsity operations. +//! +//! Tests verify that CPU, CUDA, and WebGPU backends produce identical results +//! for prune, decompress, and sparse matmul operations. + +use crate::backend_parity::helpers::assert_parity_f32; +use crate::common::create_cpu_client; +use numr::runtime::cpu::CpuRuntime; +use numr::sparse::Sparse24Ops; +use numr::tensor::Tensor; + +// ============================================================================ +// CPU-only correctness tests +// ============================================================================ + +#[test] +fn test_prune_to_24_correctness() { + let (client, device) = create_cpu_client(); + + // Matrix: 2x8, each row has 2 groups of 4 + let data: Vec = vec![ + 1.0, -3.0, 2.0, 0.5, // group 0: top-2 = -3.0 (1), 2.0 (2) + 0.1, 0.2, 0.3, 0.4, // group 1: top-2 = 0.3 (2), 0.4 (3) + 4.0, 1.0, -5.0, 3.0, // group 2: top-2 = 4.0 (0), -5.0 (2) + 0.0, 0.0, 0.0, 0.0, // group 3: all zero, keeps (0), (1) + ]; + let dense = Tensor::::from_slice(&data, &[2, 8], &device); + let sparse = client.prune_to_24(&dense).unwrap(); + + assert_eq!(sparse.shape(), [2, 8]); + assert_eq!(sparse.nnz(), 2 * 4); // 2 rows * 4 non-zeros per row + assert!(sparse.is_valid()); + + // Verify compressed values + let vals: Vec = sparse.compressed_values().to_vec(); + // Row 0, group 0: -3.0 (idx 1), 2.0 (idx 2) → sorted by index + assert_eq!(vals[0], -3.0); + assert_eq!(vals[1], 2.0); + // Row 0, group 1: 0.3 (idx 2), 0.4 (idx 3) → sorted by index + assert_eq!(vals[2], 0.3); + assert_eq!(vals[3], 0.4); +} + +#[test] +fn test_sparse_24_roundtrip() { + let (client, device) = create_cpu_client(); + + let data: Vec = vec![ + 1.0, -3.0, 2.0, 0.5, 0.1, 0.2, 0.3, 0.4, 4.0, 1.0, -5.0, 3.0, 0.0, 0.0, 0.0, 0.0, + ]; + let dense = Tensor::::from_slice(&data, &[2, 8], &device); + let sparse = client.prune_to_24(&dense).unwrap(); + let reconstructed = client.sparse_24_to_dense(&sparse).unwrap(); + + let recon_data: Vec = reconstructed.to_vec(); + + // After pruning and reconstruction, only top-2 per group survive + // Row 0, group 0: kept idx 1,2 → [0, -3, 2, 0] + assert_eq!(recon_data[0], 0.0); + assert_eq!(recon_data[1], -3.0); + assert_eq!(recon_data[2], 2.0); + assert_eq!(recon_data[3], 0.0); + + // Row 0, group 1: kept idx 2,3 → [0, 0, 0.3, 0.4] + assert_eq!(recon_data[4], 0.0); + assert_eq!(recon_data[5], 0.0); + assert_eq!(recon_data[6], 0.3); + assert_eq!(recon_data[7], 0.4); +} + +#[test] +fn test_sparse_24_matmul_matches_dense() { + use numr::prelude::MatmulOps; + + let (client, device) = create_cpu_client(); + + // Weight: [4, 8], Input: [2, 8] + let weight_data: Vec = vec![ + 1.0, -3.0, 2.0, 0.5, 0.1, 0.2, 0.3, 0.4, 4.0, 1.0, -5.0, 3.0, 0.5, 0.5, 0.5, 0.5, 2.0, 0.0, + 1.0, 0.0, 0.0, 3.0, 0.0, 1.0, 0.5, 1.5, 0.5, 1.5, 2.0, 0.0, 2.0, 0.0, + ]; + let weight = Tensor::::from_slice(&weight_data, &[4, 8], &device); + + let input_data: Vec = vec![ + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, + ]; + let input = Tensor::::from_slice(&input_data, &[2, 8], &device); + + // Prune weight + let sparse_weight = client.prune_to_24(&weight).unwrap(); + + // Sparse matmul + let sparse_result = client.sparse_24_matmul(&input, &sparse_weight).unwrap(); + + // Dense matmul with pruned weight + let dense_pruned = client.sparse_24_to_dense(&sparse_weight).unwrap(); + let dense_pruned_t = dense_pruned.t().unwrap(); + let dense_result = client.matmul(&input, &dense_pruned_t).unwrap(); + + let sparse_out: Vec = sparse_result.to_vec(); + let dense_out: Vec = dense_result.to_vec(); + + assert_parity_f32(&sparse_out, &dense_out, "sparse_24_matmul vs dense"); +} + +#[test] +fn test_sparse_24_matmul_larger() { + use numr::prelude::MatmulOps; + + let (client, device) = create_cpu_client(); + + // Larger: weight [16, 32], input [8, 32] + let weight_data: Vec = (0..16 * 32).map(|i| (i as f32 * 0.1).sin() * 3.0).collect(); + let weight = Tensor::::from_slice(&weight_data, &[16, 32], &device); + + let input_data: Vec = (0..8 * 32).map(|i| (i as f32 * 0.07).cos() * 2.0).collect(); + let input = Tensor::::from_slice(&input_data, &[8, 32], &device); + + let sparse_weight = client.prune_to_24(&weight).unwrap(); + let sparse_result = client.sparse_24_matmul(&input, &sparse_weight).unwrap(); + + let dense_pruned = client.sparse_24_to_dense(&sparse_weight).unwrap(); + let dense_pruned_t = dense_pruned.t().unwrap(); + let dense_result = client.matmul(&input, &dense_pruned_t).unwrap(); + + let sparse_out: Vec = sparse_result.to_vec(); + let dense_out: Vec = dense_result.to_vec(); + + assert_parity_f32(&sparse_out, &dense_out, "sparse_24_matmul_larger"); +} + +// ============================================================================ +// CUDA backend parity tests +// ============================================================================ + +#[cfg(feature = "cuda")] +mod cuda_parity { + use super::*; + use crate::backend_parity::helpers::{assert_parity_f32, with_cuda_backend}; + use numr::runtime::cuda::CudaRuntime; + use numr::sparse::Sparse24Ops; + + #[test] + fn test_prune_to_24_parity_cuda() { + let (cpu_client, cpu_device) = create_cpu_client(); + + let data: Vec = (0..4 * 16).map(|i| (i as f32 * 0.13).sin() * 5.0).collect(); + let cpu_dense = Tensor::::from_slice(&data, &[4, 16], &cpu_device); + let cpu_sparse = cpu_client.prune_to_24(&cpu_dense).unwrap(); + let cpu_vals: Vec = cpu_sparse.compressed_values().to_vec(); + let cpu_meta: Vec = cpu_sparse.metadata().to_vec(); + + with_cuda_backend(|cuda_client, cuda_device| { + let cuda_dense = Tensor::::from_slice(&data, &[4, 16], &cuda_device); + let cuda_sparse = cuda_client.prune_to_24(&cuda_dense).unwrap(); + let cuda_vals: Vec = cuda_sparse.compressed_values().to_vec(); + let cuda_meta: Vec = cuda_sparse.metadata().to_vec(); + + assert_parity_f32(&cuda_vals, &cpu_vals, "prune_to_24 values CUDA vs CPU"); + assert_eq!(cuda_meta, cpu_meta, "prune_to_24 metadata CUDA vs CPU"); + }); + } + + #[test] + fn test_sparse_24_roundtrip_parity_cuda() { + let (cpu_client, cpu_device) = create_cpu_client(); + + let data: Vec = (0..4 * 16).map(|i| (i as f32 * 0.13).sin() * 5.0).collect(); + let cpu_dense = Tensor::::from_slice(&data, &[4, 16], &cpu_device); + let cpu_sparse = cpu_client.prune_to_24(&cpu_dense).unwrap(); + let cpu_recon: Vec = cpu_client.sparse_24_to_dense(&cpu_sparse).unwrap().to_vec(); + + with_cuda_backend(|cuda_client, cuda_device| { + let cuda_dense = Tensor::::from_slice(&data, &[4, 16], &cuda_device); + let cuda_sparse = cuda_client.prune_to_24(&cuda_dense).unwrap(); + let cuda_recon: Vec = cuda_client + .sparse_24_to_dense(&cuda_sparse) + .unwrap() + .to_vec(); + + assert_parity_f32(&cuda_recon, &cpu_recon, "roundtrip CUDA vs CPU"); + }); + } + + #[test] + fn test_sparse_24_matmul_parity_cuda() { + let (cpu_client, cpu_device) = create_cpu_client(); + + let weight_data: Vec = (0..8 * 16).map(|i| (i as f32 * 0.1).sin() * 3.0).collect(); + let input_data: Vec = (0..4 * 16).map(|i| (i as f32 * 0.07).cos() * 2.0).collect(); + + let cpu_weight = Tensor::::from_slice(&weight_data, &[8, 16], &cpu_device); + let cpu_input = Tensor::::from_slice(&input_data, &[4, 16], &cpu_device); + let cpu_sparse = cpu_client.prune_to_24(&cpu_weight).unwrap(); + let cpu_result: Vec = cpu_client + .sparse_24_matmul(&cpu_input, &cpu_sparse) + .unwrap() + .to_vec(); + + with_cuda_backend(|cuda_client, cuda_device| { + let cuda_weight = + Tensor::::from_slice(&weight_data, &[8, 16], &cuda_device); + let cuda_input = Tensor::::from_slice(&input_data, &[4, 16], &cuda_device); + let cuda_sparse = cuda_client.prune_to_24(&cuda_weight).unwrap(); + let cuda_result: Vec = cuda_client + .sparse_24_matmul(&cuda_input, &cuda_sparse) + .unwrap() + .to_vec(); + + assert_parity_f32(&cuda_result, &cpu_result, "sparse_24_matmul CUDA vs CPU"); + }); + } +} + +// ============================================================================ +// WebGPU backend parity tests +// ============================================================================ + +#[cfg(feature = "wgpu")] +mod wgpu_parity { + use super::*; + use crate::backend_parity::helpers::{assert_parity_f32, with_wgpu_backend}; + use numr::runtime::wgpu::WgpuRuntime; + use numr::sparse::Sparse24Ops; + + #[test] + fn test_prune_to_24_parity_wgpu() { + let (cpu_client, cpu_device) = create_cpu_client(); + + let data: Vec = (0..4 * 16).map(|i| (i as f32 * 0.13).sin() * 5.0).collect(); + let cpu_dense = Tensor::::from_slice(&data, &[4, 16], &cpu_device); + let cpu_sparse = cpu_client.prune_to_24(&cpu_dense).unwrap(); + let cpu_vals: Vec = cpu_sparse.compressed_values().to_vec(); + let cpu_meta: Vec = cpu_sparse.metadata().to_vec(); + + with_wgpu_backend(|wgpu_client, wgpu_device| { + let wgpu_dense = Tensor::::from_slice(&data, &[4, 16], &wgpu_device); + let wgpu_sparse = wgpu_client.prune_to_24(&wgpu_dense).unwrap(); + let wgpu_vals: Vec = wgpu_sparse.compressed_values().to_vec(); + let wgpu_meta: Vec = wgpu_sparse.metadata().to_vec(); + + assert_parity_f32(&wgpu_vals, &cpu_vals, "prune_to_24 values WGPU vs CPU"); + assert_eq!(wgpu_meta, cpu_meta, "prune_to_24 metadata WGPU vs CPU"); + }); + } + + #[test] + fn test_sparse_24_roundtrip_parity_wgpu() { + let (cpu_client, cpu_device) = create_cpu_client(); + + let data: Vec = (0..4 * 16).map(|i| (i as f32 * 0.13).sin() * 5.0).collect(); + let cpu_dense = Tensor::::from_slice(&data, &[4, 16], &cpu_device); + let cpu_sparse = cpu_client.prune_to_24(&cpu_dense).unwrap(); + let cpu_recon: Vec = cpu_client.sparse_24_to_dense(&cpu_sparse).unwrap().to_vec(); + + with_wgpu_backend(|wgpu_client, wgpu_device| { + let wgpu_dense = Tensor::::from_slice(&data, &[4, 16], &wgpu_device); + let wgpu_sparse = wgpu_client.prune_to_24(&wgpu_dense).unwrap(); + let wgpu_recon: Vec = wgpu_client + .sparse_24_to_dense(&wgpu_sparse) + .unwrap() + .to_vec(); + + assert_parity_f32(&wgpu_recon, &cpu_recon, "roundtrip WGPU vs CPU"); + }); + } + + #[test] + fn test_sparse_24_matmul_parity_wgpu() { + let (cpu_client, cpu_device) = create_cpu_client(); + + let weight_data: Vec = (0..8 * 16).map(|i| (i as f32 * 0.1).sin() * 3.0).collect(); + let input_data: Vec = (0..4 * 16).map(|i| (i as f32 * 0.07).cos() * 2.0).collect(); + + let cpu_weight = Tensor::::from_slice(&weight_data, &[8, 16], &cpu_device); + let cpu_input = Tensor::::from_slice(&input_data, &[4, 16], &cpu_device); + let cpu_sparse = cpu_client.prune_to_24(&cpu_weight).unwrap(); + let cpu_result: Vec = cpu_client + .sparse_24_matmul(&cpu_input, &cpu_sparse) + .unwrap() + .to_vec(); + + with_wgpu_backend(|wgpu_client, wgpu_device| { + let wgpu_weight = + Tensor::::from_slice(&weight_data, &[8, 16], &wgpu_device); + let wgpu_input = Tensor::::from_slice(&input_data, &[4, 16], &wgpu_device); + let wgpu_sparse = wgpu_client.prune_to_24(&wgpu_weight).unwrap(); + let wgpu_result: Vec = wgpu_client + .sparse_24_matmul(&wgpu_input, &wgpu_sparse) + .unwrap() + .to_vec(); + + assert_parity_f32(&wgpu_result, &cpu_result, "sparse_24_matmul WGPU vs CPU"); + }); + } +} diff --git a/tests/backend_parity/special.rs b/tests/backend_parity/special.rs index eca7e142..7846db0a 100644 --- a/tests/backend_parity/special.rs +++ b/tests/backend_parity/special.rs @@ -6,7 +6,6 @@ use numr::dtype::DType; use numr::ops::SpecialFunctions; use numr::runtime::Runtime; -use numr::runtime::cpu::CpuRuntime; use numr::tensor::Tensor; use crate::backend_parity::dtype_helpers::tensor_from_f64; diff --git a/tests/backend_parity/statistics.rs b/tests/backend_parity/statistics.rs index 7655c41f..7454d538 100644 --- a/tests/backend_parity/statistics.rs +++ b/tests/backend_parity/statistics.rs @@ -5,7 +5,6 @@ use numr::dtype::DType; use numr::ops::StatisticalOps; -use numr::runtime::Runtime; use numr::tensor::Tensor; use crate::backend_parity::dtype_helpers::tensor_from_f64; diff --git a/tests/backend_parity/utility.rs b/tests/backend_parity/utility.rs new file mode 100644 index 00000000..27e078a7 --- /dev/null +++ b/tests/backend_parity/utility.rs @@ -0,0 +1,356 @@ +// Backend parity tests for UtilityOps trait +// +// Tests: clamp, fill, arange, linspace, eye, one_hot +// CPU is the reference implementation; CUDA and WebGPU must match. + +use crate::backend_parity::dtype_helpers::tensor_from_f64; +#[cfg(feature = "cuda")] +use crate::backend_parity::helpers::with_cuda_backend; +#[cfg(feature = "wgpu")] +use crate::backend_parity::helpers::with_wgpu_backend; +use crate::common::{ + assert_tensor_allclose, create_cpu_client, is_dtype_supported, supported_dtypes, +}; +use numr::dtype::DType; +use numr::ops::UtilityOps; +use numr::tensor::Tensor; + +// ============================================================================ +// clamp +// ============================================================================ + +fn test_clamp_parity(dtype: DType) { + let (cpu_client, cpu_device) = create_cpu_client(); + + let data = vec![-2.0, -1.0, 0.0, 0.5, 1.0, 2.0, 3.0, 5.0]; + let shape = vec![8]; + let min_val = 0.0; + let max_val = 3.0; + + let cpu_input = tensor_from_f64(&data, &shape, dtype, &cpu_device, &cpu_client) + .expect("CPU tensor creation failed"); + let cpu_result = cpu_client + .clamp(&cpu_input, min_val, max_val) + .expect("CPU clamp failed"); + + #[cfg(feature = "cuda")] + if is_dtype_supported("cuda", dtype) { + with_cuda_backend(|cuda_client, cuda_device| { + let input = tensor_from_f64(&data, &shape, dtype, &cuda_device, &cuda_client) + .expect("CUDA tensor creation failed"); + let result = cuda_client + .clamp(&input, min_val, max_val) + .expect("CUDA clamp failed"); + assert_tensor_allclose( + &result, + &cpu_result, + dtype, + &format!("clamp CUDA vs CPU [{dtype:?}]"), + ); + }); + } + + #[cfg(feature = "wgpu")] + if is_dtype_supported("wgpu", dtype) { + with_wgpu_backend(|wgpu_client, wgpu_device| { + let input = tensor_from_f64(&data, &shape, dtype, &wgpu_device, &wgpu_client) + .expect("WebGPU tensor creation failed"); + let result = wgpu_client + .clamp(&input, min_val, max_val) + .expect("WebGPU clamp failed"); + assert_tensor_allclose( + &result, + &cpu_result, + dtype, + &format!("clamp WebGPU vs CPU [{dtype:?}]"), + ); + }); + } +} + +#[test] +fn test_clamp_parity_all_dtypes() { + for dtype in supported_dtypes("cpu") { + test_clamp_parity(dtype); + } +} + +// ============================================================================ +// fill +// ============================================================================ + +fn test_fill_parity(dtype: DType) { + let (cpu_client, _cpu_device) = create_cpu_client(); + + let shapes: Vec> = vec![vec![4], vec![2, 3], vec![2, 2, 2]]; + let values = vec![0.0, 1.0, 42.0, -3.5]; + + for shape in &shapes { + for &value in &values { + let cpu_result = cpu_client + .fill(shape, value, dtype) + .expect("CPU fill failed"); + + #[cfg(feature = "cuda")] + if is_dtype_supported("cuda", dtype) { + with_cuda_backend(|cuda_client, _cuda_device| { + let result = cuda_client + .fill(shape, value, dtype) + .expect("CUDA fill failed"); + assert_tensor_allclose( + &result, + &cpu_result, + dtype, + &format!("fill({value}) CUDA vs CPU [{dtype:?}] shape {shape:?}"), + ); + }); + } + + #[cfg(feature = "wgpu")] + if is_dtype_supported("wgpu", dtype) { + with_wgpu_backend(|wgpu_client, _wgpu_device| { + let result = wgpu_client + .fill(shape, value, dtype) + .expect("WebGPU fill failed"); + assert_tensor_allclose( + &result, + &cpu_result, + dtype, + &format!("fill({value}) WebGPU vs CPU [{dtype:?}] shape {shape:?}"), + ); + }); + } + } + } +} + +#[test] +fn test_fill_parity_all_dtypes() { + for dtype in supported_dtypes("cpu") { + test_fill_parity(dtype); + } +} + +// ============================================================================ +// arange +// ============================================================================ + +fn test_arange_parity(dtype: DType) { + let (cpu_client, _cpu_device) = create_cpu_client(); + + let cases: Vec<(f64, f64, f64)> = vec![ + (0.0, 5.0, 1.0), // [0, 1, 2, 3, 4] + (0.0, 6.0, 2.0), // [0, 2, 4] + (1.0, 10.0, 3.0), // [1, 4, 7] + (5.0, 0.0, -1.0), // [5, 4, 3, 2, 1] + (0.0, 1.0, 0.25), // [0, 0.25, 0.5, 0.75] + ]; + + for (start, stop, step) in &cases { + let cpu_result = cpu_client + .arange(*start, *stop, *step, dtype) + .expect("CPU arange failed"); + + #[cfg(feature = "cuda")] + if is_dtype_supported("cuda", dtype) { + with_cuda_backend(|cuda_client, _cuda_device| { + let result = cuda_client + .arange(*start, *stop, *step, dtype) + .expect("CUDA arange failed"); + assert_tensor_allclose( + &result, + &cpu_result, + dtype, + &format!("arange({start},{stop},{step}) CUDA vs CPU [{dtype:?}]"), + ); + }); + } + + #[cfg(feature = "wgpu")] + if is_dtype_supported("wgpu", dtype) { + with_wgpu_backend(|wgpu_client, _wgpu_device| { + let result = wgpu_client + .arange(*start, *stop, *step, dtype) + .expect("WebGPU arange failed"); + assert_tensor_allclose( + &result, + &cpu_result, + dtype, + &format!("arange({start},{stop},{step}) WebGPU vs CPU [{dtype:?}]"), + ); + }); + } + } +} + +#[test] +fn test_arange_parity_all_dtypes() { + for dtype in supported_dtypes("cpu") { + test_arange_parity(dtype); + } +} + +// ============================================================================ +// linspace +// ============================================================================ + +fn test_linspace_parity(dtype: DType) { + let (cpu_client, _cpu_device) = create_cpu_client(); + + let cases: Vec<(f64, f64, usize)> = vec![ + (0.0, 10.0, 5), // [0, 2.5, 5, 7.5, 10] + (0.0, 1.0, 3), // [0, 0.5, 1] + (-1.0, 1.0, 5), // [-1, -0.5, 0, 0.5, 1] + (0.0, 100.0, 11), // [0, 10, 20, ..., 100] + (5.0, 5.0, 3), // [5, 5, 5] + ]; + + for (start, stop, steps) in &cases { + let cpu_result = cpu_client + .linspace(*start, *stop, *steps, dtype) + .expect("CPU linspace failed"); + + #[cfg(feature = "cuda")] + if is_dtype_supported("cuda", dtype) { + with_cuda_backend(|cuda_client, _cuda_device| { + let result = cuda_client + .linspace(*start, *stop, *steps, dtype) + .expect("CUDA linspace failed"); + assert_tensor_allclose( + &result, + &cpu_result, + dtype, + &format!("linspace({start},{stop},{steps}) CUDA vs CPU [{dtype:?}]"), + ); + }); + } + + #[cfg(feature = "wgpu")] + if is_dtype_supported("wgpu", dtype) { + with_wgpu_backend(|wgpu_client, _wgpu_device| { + let result = wgpu_client + .linspace(*start, *stop, *steps, dtype) + .expect("WebGPU linspace failed"); + assert_tensor_allclose( + &result, + &cpu_result, + dtype, + &format!("linspace({start},{stop},{steps}) WebGPU vs CPU [{dtype:?}]"), + ); + }); + } + } +} + +#[test] +fn test_linspace_parity_all_dtypes() { + for dtype in supported_dtypes("cpu") { + test_linspace_parity(dtype); + } +} + +// ============================================================================ +// eye +// ============================================================================ + +fn test_eye_parity(dtype: DType) { + let (cpu_client, _cpu_device) = create_cpu_client(); + + let cases: Vec<(usize, Option)> = vec![ + (3, None), // 3x3 identity + (4, None), // 4x4 identity + (2, Some(4)), // 2x4 rectangular + (4, Some(2)), // 4x2 rectangular + (1, None), // 1x1 identity + ]; + + for (n, m) in &cases { + let cpu_result = cpu_client.eye(*n, *m, dtype).expect("CPU eye failed"); + + #[cfg(feature = "cuda")] + if is_dtype_supported("cuda", dtype) { + with_cuda_backend(|cuda_client, _cuda_device| { + let result = cuda_client.eye(*n, *m, dtype).expect("CUDA eye failed"); + assert_tensor_allclose( + &result, + &cpu_result, + dtype, + &format!("eye({n},{m:?}) CUDA vs CPU [{dtype:?}]"), + ); + }); + } + + #[cfg(feature = "wgpu")] + if is_dtype_supported("wgpu", dtype) { + with_wgpu_backend(|wgpu_client, _wgpu_device| { + let result = wgpu_client.eye(*n, *m, dtype).expect("WebGPU eye failed"); + assert_tensor_allclose( + &result, + &cpu_result, + dtype, + &format!("eye({n},{m:?}) WebGPU vs CPU [{dtype:?}]"), + ); + }); + } + } +} + +#[test] +fn test_eye_parity_all_dtypes() { + for dtype in supported_dtypes("cpu") { + test_eye_parity(dtype); + } +} + +// ============================================================================ +// one_hot +// ============================================================================ + +#[test] +fn test_one_hot_parity() { + let (cpu_client, cpu_device) = create_cpu_client(); + + let cases: Vec<(Vec, Vec, usize)> = vec![ + (vec![0, 1, 2], vec![3], 3), // Simple 1D + (vec![0, 2, 1, 3], vec![4], 5), // With num_classes > max index + (vec![0, 1, 2, 3], vec![2, 2], 4), // 2D indices + ]; + + for (data, shape, num_classes) in &cases { + let cpu_indices = + Tensor::::from_slice(data, shape, &cpu_device); + let cpu_result = cpu_client + .one_hot(&cpu_indices, *num_classes) + .expect("CPU one_hot failed"); + + #[cfg(feature = "cuda")] + with_cuda_backend(|cuda_client, cuda_device| { + let indices = + Tensor::::from_slice(data, shape, &cuda_device); + let result = cuda_client + .one_hot(&indices, *num_classes) + .expect("CUDA one_hot failed"); + assert_tensor_allclose( + &result, + &cpu_result, + DType::F32, + &format!("one_hot CUDA vs CPU shape {shape:?} classes {num_classes}"), + ); + }); + + #[cfg(feature = "wgpu")] + with_wgpu_backend(|wgpu_client, wgpu_device| { + let indices = + Tensor::::from_slice(data, shape, &wgpu_device); + let result = wgpu_client + .one_hot(&indices, *num_classes) + .expect("WebGPU one_hot failed"); + assert_tensor_allclose( + &result, + &cpu_result, + DType::F32, + &format!("one_hot WebGPU vs CPU shape {shape:?} classes {num_classes}"), + ); + }); + } +} diff --git a/tests/common/mod.rs b/tests/common/mod.rs index d144c5ea..7f611126 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -161,7 +161,7 @@ pub fn tolerance_for_dtype(dtype: DType) -> (f64, f64) { DType::F64 => (1e-12, 1e-14), // Machine epsilon-level tolerance DType::F16 => (0.01, 0.1), // 1% relative tolerance for half-precision DType::BF16 => (0.01, 0.1), // 1% relative tolerance for BF16 - DType::FP8E4M3 => (0.1, 1.0), // 10% relative — 4-bit mantissa; atol=1.0 because floor/trunc can differ by 1 ULP + DType::FP8E4M3 => (0.3, 2.5), // 30% relative — 4-bit mantissa; atol=2.5 for compound ops (norm bwd, gemm) DType::FP8E5M2 => (1.0, 2.5), // Very coarse — 2-bit mantissa; atol=2.5 because scatter_reduce/cov accumulate rounding error _ => (1e-5, 1e-6), // Default tolerance } @@ -333,7 +333,7 @@ impl ToF64 for numr::dtype::FP8E5M2 { /// This function normalizes all of them to Vec for uniform comparison. /// /// Nonzero = true, zero = false. -pub fn readback_as_bool(tensor: &numr::tensor::Tensor) -> Vec { +pub fn readback_as_bool>(tensor: &numr::tensor::Tensor) -> Vec { macro_rules! nonzero { ($T:ty) => { tensor @@ -345,7 +345,7 @@ pub fn readback_as_bool(tensor: &numr::tensor::Tensor) -> Vec tensor.to_vec::().iter().map(|&x| x != 0).collect(), + DType::Bool | DType::U8 => tensor.to_vec::().iter().map(|&x| x != 0).collect(), DType::U32 => tensor.to_vec::().iter().map(|&x| x != 0).collect(), DType::I32 => tensor.to_vec::().iter().map(|&x| x != 0).collect(), DType::F32 => nonzero!(f32), diff --git a/tests/external_backend_api.rs b/tests/external_backend_api.rs index c30ccba3..467d6eb5 100644 --- a/tests/external_backend_api.rs +++ b/tests/external_backend_api.rs @@ -41,12 +41,22 @@ impl Runtime for MockRuntime { type Device = MockDevice; type Client = MockClient; type Allocator = MockAllocator; + type Graph = numr::runtime::NoOpGraph; type RawHandle = (); + type DType = numr::dtype::DType; fn name() -> &'static str { "mock" } + fn capture_graph(client: &Self::Client, f: F) -> error::Result<(Self::Graph, T)> + where + F: FnOnce(&Self::Client) -> error::Result, + { + let result = f(client)?; + Ok((numr::runtime::NoOpGraph, result)) + } + fn allocate(_size_bytes: usize, _device: &Self::Device) -> error::Result { Ok(0) } diff --git a/tests/index_ops.rs b/tests/index_ops.rs deleted file mode 100644 index cb950155..00000000 --- a/tests/index_ops.rs +++ /dev/null @@ -1,19 +0,0 @@ -//! Integration tests for index operations (embedding_lookup, gather, scatter, index_select) -//! -//! Tests verify correctness across: -//! - Different dtypes (f32, f64, i32) -//! - Various embedding dimensions -//! - Boundary conditions -//! - Edge cases (single element, out of bounds handling) - -#[path = "index_ops/advanced.rs"] -mod advanced; - -#[path = "index_ops/embedding.rs"] -mod embedding; - -#[path = "index_ops/gather_scatter.rs"] -mod gather_scatter; - -#[path = "index_ops/masked.rs"] -mod masked; diff --git a/tests/index_ops/advanced.rs b/tests/index_ops/advanced.rs deleted file mode 100644 index ae351714..00000000 --- a/tests/index_ops/advanced.rs +++ /dev/null @@ -1,2 +0,0 @@ -//! Advanced indexing integration tests have moved to `tests/backend_parity/indexing_advanced.rs`. -//! Keep this file as a migration marker for old test paths. diff --git a/tests/index_ops/embedding.rs b/tests/index_ops/embedding.rs deleted file mode 100644 index df942b24..00000000 --- a/tests/index_ops/embedding.rs +++ /dev/null @@ -1,2 +0,0 @@ -//! Embedding integration tests have moved to `tests/backend_parity/indexing_advanced.rs`. -//! Keep this file as a migration marker for old test paths. diff --git a/tests/index_ops/gather_scatter.rs b/tests/index_ops/gather_scatter.rs deleted file mode 100644 index 00712486..00000000 --- a/tests/index_ops/gather_scatter.rs +++ /dev/null @@ -1,2 +0,0 @@ -//! Gather/scatter integration tests have moved to `tests/backend_parity/indexing_advanced.rs`. -//! Keep this file as a migration marker for old test paths. diff --git a/tests/index_ops/masked.rs b/tests/index_ops/masked.rs deleted file mode 100644 index 78a6d1cf..00000000 --- a/tests/index_ops/masked.rs +++ /dev/null @@ -1,2 +0,0 @@ -//! Masked indexing integration tests have moved to `tests/backend_parity/indexing.rs`. -//! Keep this file as a migration marker for old test paths. diff --git a/tests/ml_dtype_audit.rs b/tests/ml_dtype_audit.rs new file mode 100644 index 00000000..42d3a6b4 --- /dev/null +++ b/tests/ml_dtype_audit.rs @@ -0,0 +1,230 @@ +//! DType Audit for ML Workloads +//! +//! Tests F16, BF16, FP8E4M3, FP8E5M2 support across ML-critical operations. +//! All helpers are feature-gated so they only compile when the relevant dtype +//! features are enabled. + +#[cfg(any(feature = "f16", feature = "fp8"))] +mod common; + +#[cfg(any(feature = "f16", feature = "fp8"))] +use common::create_cpu_client; +#[cfg(any(feature = "f16", feature = "fp8"))] +use numr::dtype::DType; +#[cfg(any(feature = "f16", feature = "fp8"))] +use numr::error::Result; +#[cfg(any(feature = "f16", feature = "fp8"))] +use numr::ops::*; +#[cfg(any(feature = "f16", feature = "fp8"))] +use numr::runtime::cpu::CpuRuntime; +#[cfg(any(feature = "f16", feature = "fp8"))] +use numr::tensor::Tensor; + +#[cfg(any(feature = "f16", feature = "fp8"))] +fn make_tensor( + data: &[f32], + shape: &[usize], + dtype: DType, + device: &::Device, + client: &impl TypeConversionOps, +) -> Result> { + let t = Tensor::from_slice(data, shape, device); + if dtype == DType::F32 { + Ok(t) + } else { + client.cast(&t, dtype) + } +} + +#[cfg(any(feature = "f16", feature = "fp8"))] +macro_rules! audit_op { + ($name:expr, $body:expr) => {{ + let result: Result<()> = (|| { + $body; + Ok(()) + })(); + match &result { + Ok(()) => println!(" PASS: {}", $name), + Err(e) => println!(" FAIL: {} - {}", $name, e), + } + result.is_ok() + }}; +} + +#[cfg(any(feature = "f16", feature = "fp8"))] +fn audit_dtype(dtype: DType) { + println!("\n=== Auditing {:?} ===", dtype); + let (client, device) = create_cpu_client(); + let mut pass = 0u32; + let mut fail = 0u32; + + macro_rules! tally { + ($ok:expr) => { + if $ok { + pass += 1; + } else { + fail += 1; + } + }; + } + + // Cast F32 -> target + let cast_ok = audit_op!("cast F32 -> target", { + let t = Tensor::::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4], &device); + let _ = client.cast(&t, dtype)?; + }); + tally!(cast_ok); + + if !cast_ok { + println!(" SKIP remaining (cast failed)"); + println!("\n Summary for {:?}: {} pass, {} fail", dtype, pass, fail); + return; + } + + let t1 = |d: &[f32], s: &[usize]| make_tensor(d, s, dtype, &device, &client); + + // Binary ops + tally!(audit_op!("add", { + let a = t1(&[1.0, 2.0, 3.0, 4.0], &[4])?; + let b = t1(&[5.0, 6.0, 7.0, 8.0], &[4])?; + let _ = client.add(&a, &b)?; + })); + tally!(audit_op!("sub", { + let a = t1(&[1.0, 2.0, 3.0, 4.0], &[4])?; + let b = t1(&[5.0, 6.0, 7.0, 8.0], &[4])?; + let _ = client.sub(&a, &b)?; + })); + tally!(audit_op!("mul", { + let a = t1(&[1.0, 2.0, 3.0, 4.0], &[4])?; + let b = t1(&[5.0, 6.0, 7.0, 8.0], &[4])?; + let _ = client.mul(&a, &b)?; + })); + tally!(audit_op!("div", { + let a = t1(&[1.0, 2.0, 3.0, 4.0], &[4])?; + let b = t1(&[5.0, 6.0, 7.0, 8.0], &[4])?; + let _ = client.div(&a, &b)?; + })); + + // Scalar ops + tally!(audit_op!("mul_scalar", { + let a = t1(&[1.0, 2.0, 3.0, 4.0], &[4])?; + let _ = client.mul_scalar(&a, 2.0)?; + })); + tally!(audit_op!("add_scalar", { + let a = t1(&[1.0, 2.0, 3.0, 4.0], &[4])?; + let _ = client.add_scalar(&a, 1.0)?; + })); + + // Unary ops + tally!(audit_op!("exp", { + let a = t1(&[0.0, 0.5, 1.0, 1.5], &[4])?; + let _ = client.exp(&a)?; + })); + tally!(audit_op!("log", { + let a = t1(&[1.0, 2.0, 3.0, 4.0], &[4])?; + let _ = client.log(&a)?; + })); + tally!(audit_op!("sqrt", { + let a = t1(&[1.0, 4.0, 9.0, 16.0], &[4])?; + let _ = client.sqrt(&a)?; + })); + tally!(audit_op!("tanh", { + let a = t1(&[0.0, 0.5, 1.0, -1.0], &[4])?; + let _ = client.tanh(&a)?; + })); + tally!(audit_op!("neg", { + let a = t1(&[1.0, 2.0, 3.0, 4.0], &[4])?; + let _ = client.neg(&a)?; + })); + + // Reduce ops (dims are usize, use last dim = 1 for [2,2]) + tally!(audit_op!("sum", { + let a = t1(&[1.0, 2.0, 3.0, 4.0], &[2, 2])?; + let _ = client.sum(&a, &[1], false)?; + })); + tally!(audit_op!("max", { + let a = t1(&[1.0, 2.0, 3.0, 4.0], &[2, 2])?; + let _ = client.max(&a, &[1], false)?; + })); + tally!(audit_op!("mean", { + let a = t1(&[1.0, 2.0, 3.0, 4.0], &[2, 2])?; + let _ = client.mean(&a, &[1], false)?; + })); + tally!(audit_op!("argmax", { + let a = t1(&[1.0, 2.0, 3.0, 4.0], &[2, 2])?; + let _ = client.argmax(&a, 1, false)?; + })); + + // Matmul (disambiguate) + tally!(audit_op!("matmul", { + let a = t1(&[1.0, 2.0, 3.0, 4.0], &[2, 2])?; + let b = t1(&[5.0, 6.0, 7.0, 8.0], &[2, 2])?; + let _ = MatmulOps::matmul(&client, &a, &b)?; + })); + + // Activation ops + tally!(audit_op!("softmax", { + let a = t1(&[1.0, 2.0, 3.0, 4.0], &[2, 2])?; + let _ = client.softmax(&a, -1)?; + })); + tally!(audit_op!("relu", { + let a = t1(&[-1.0, 0.0, 1.0, 2.0], &[4])?; + let _ = client.relu(&a)?; + })); + tally!(audit_op!("gelu", { + let a = t1(&[-1.0, 0.0, 1.0, 2.0], &[4])?; + let _ = client.gelu(&a)?; + })); + tally!(audit_op!("silu", { + let a = t1(&[-1.0, 0.0, 1.0, 2.0], &[4])?; + let _ = client.silu(&a)?; + })); + + // Normalization ops (require weight/bias tensors) + tally!(audit_op!("rms_norm", { + let a = t1(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])?; + let w = t1(&[1.0, 1.0, 1.0], &[3])?; + let _ = client.rms_norm(&a, &w, 1e-5)?; + })); + tally!(audit_op!("layer_norm", { + let a = t1(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])?; + let w = t1(&[1.0, 1.0, 1.0], &[3])?; + let b = t1(&[0.0, 0.0, 0.0], &[3])?; + let _ = client.layer_norm(&a, &w, &b, 1e-5)?; + })); + + // Cast back + tally!(audit_op!("cast target -> F32", { + let a = t1(&[1.0, 2.0, 3.0, 4.0], &[4])?; + let _ = client.cast(&a, DType::F32)?; + })); + + println!("\n Summary for {:?}: {} pass, {} fail", dtype, pass, fail); + if fail > 0 { + panic!("{:?} has {} failures", dtype, fail); + } +} + +#[test] +#[cfg(feature = "f16")] +fn audit_f16() { + audit_dtype(DType::F16); +} + +#[test] +#[cfg(feature = "f16")] +fn audit_bf16() { + audit_dtype(DType::BF16); +} + +#[test] +#[cfg(feature = "fp8")] +fn audit_fp8e4m3() { + audit_dtype(DType::FP8E4M3); +} + +#[test] +#[cfg(feature = "fp8")] +fn audit_fp8e5m2() { + audit_dtype(DType::FP8E5M2); +} diff --git a/tests/wgpu_integer_ops.rs b/tests/wgpu_integer_ops.rs index 3d71e599..87f064a7 100644 --- a/tests/wgpu_integer_ops.rs +++ b/tests/wgpu_integer_ops.rs @@ -119,7 +119,7 @@ fn test_u32_mul() { // ============================================================================ #[test] -fn test_i32_neg() { +fn test_f32_neg() { if !numr::runtime::wgpu::is_wgpu_available() { println!("WebGPU not available, skipping"); return; @@ -128,16 +128,16 @@ fn test_i32_neg() { let device = WgpuDevice::new(0); let client = WgpuRuntime::default_client(&device); - let a = Tensor::::from_slice(&[1i32, -2, 3, -4], &[4], &device); + let a = Tensor::::from_slice(&[1.0f32, -2.0, 3.0, -4.0], &[4], &device); let result = client.neg(&a).unwrap(); - let data: Vec = result.to_vec(); - assert_eq!(data, vec![-1, 2, -3, 4]); + let data: Vec = result.to_vec(); + assert_eq!(data, vec![-1.0, 2.0, -3.0, 4.0]); } #[test] -fn test_i32_abs() { +fn test_f32_abs() { if !numr::runtime::wgpu::is_wgpu_available() { println!("WebGPU not available, skipping"); return; @@ -146,34 +146,12 @@ fn test_i32_abs() { let device = WgpuDevice::new(0); let client = WgpuRuntime::default_client(&device); - let a = Tensor::::from_slice(&[1i32, -2, 3, -4], &[4], &device); + let a = Tensor::::from_slice(&[1.0f32, -2.0, 3.0, -4.0], &[4], &device); let result = client.abs(&a).unwrap(); - let data: Vec = result.to_vec(); - assert_eq!(data, vec![1, 2, 3, 4]); -} - -// ============================================================================ -// Unary Operations (U32) -// ============================================================================ - -#[test] -fn test_u32_abs() { - if !numr::runtime::wgpu::is_wgpu_available() { - println!("WebGPU not available, skipping"); - return; - } - - let device = WgpuDevice::new(0); - let client = WgpuRuntime::default_client(&device); - - let a = Tensor::::from_slice(&[1u32, 2, 3, 4], &[4], &device); - - let result = client.abs(&a).unwrap(); - - let data: Vec = result.to_vec(); - assert_eq!(data, vec![1, 2, 3, 4]); // abs of unsigned is identity + let data: Vec = result.to_vec(); + assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0]); } // ============================================================================ @@ -237,7 +215,7 @@ fn test_i32_exp_should_fail() { // ============================================================================ #[test] -fn test_i32_eq() { +fn test_f32_eq() { if !numr::runtime::wgpu::is_wgpu_available() { println!("WebGPU not available, skipping"); return; @@ -246,13 +224,11 @@ fn test_i32_eq() { let device = WgpuDevice::new(0); let client = WgpuRuntime::default_client(&device); - let a = Tensor::::from_slice(&[1i32, 2, 3, 4], &[4], &device); - let b = Tensor::::from_slice(&[1i32, 0, 3, 0], &[4], &device); + let a = Tensor::::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4], &device); + let b = Tensor::::from_slice(&[1.0f32, 0.0, 3.0, 0.0], &[4], &device); let result = client.eq(&a, &b).unwrap(); - // Note: WebGPU compare ops currently output F32 (0.0 or 1.0) - assert_eq!(result.dtype(), DType::F32); let data: Vec = result.to_vec(); assert_eq!(data, vec![1.0, 0.0, 1.0, 0.0]); } @@ -262,7 +238,7 @@ fn test_i32_eq() { // ============================================================================ #[test] -fn test_i32_sum() { +fn test_f32_sum() { if !numr::runtime::wgpu::is_wgpu_available() { println!("WebGPU not available, skipping"); return; @@ -271,16 +247,16 @@ fn test_i32_sum() { let device = WgpuDevice::new(0); let client = WgpuRuntime::default_client(&device); - let a = Tensor::::from_slice(&[1i32, 2, 3, 4], &[4], &device); + let a = Tensor::::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4], &device); let result = client.sum(&a, &[], false).unwrap(); - let data: Vec = result.to_vec(); - assert_eq!(data, vec![10]); + let data: Vec = result.to_vec(); + assert_eq!(data, vec![10.0]); } #[test] -fn test_i32_max() { +fn test_f32_max() { if !numr::runtime::wgpu::is_wgpu_available() { println!("WebGPU not available, skipping"); return; @@ -289,16 +265,16 @@ fn test_i32_max() { let device = WgpuDevice::new(0); let client = WgpuRuntime::default_client(&device); - let a = Tensor::::from_slice(&[1i32, 20, 3, 40, 5], &[5], &device); + let a = Tensor::::from_slice(&[1.0f32, 20.0, 3.0, 40.0, 5.0], &[5], &device); let result = client.max(&a, &[], false).unwrap(); - let data: Vec = result.to_vec(); - assert_eq!(data, vec![40]); + let data: Vec = result.to_vec(); + assert_eq!(data, vec![40.0]); } #[test] -fn test_i32_min() { +fn test_f32_min() { if !numr::runtime::wgpu::is_wgpu_available() { println!("WebGPU not available, skipping"); return; @@ -307,12 +283,12 @@ fn test_i32_min() { let device = WgpuDevice::new(0); let client = WgpuRuntime::default_client(&device); - let a = Tensor::::from_slice(&[10i32, 2, 30, 4, 50], &[5], &device); + let a = Tensor::::from_slice(&[10.0f32, 2.0, 30.0, 4.0, 50.0], &[5], &device); let result = client.min(&a, &[], false).unwrap(); - let data: Vec = result.to_vec(); - assert_eq!(data, vec![2]); + let data: Vec = result.to_vec(); + assert_eq!(data, vec![2.0]); } // ============================================================================