diff --git a/src/extensions/CMakeLists.txt b/src/extensions/CMakeLists.txt index fccadd61..2f74f71d 100644 --- a/src/extensions/CMakeLists.txt +++ b/src/extensions/CMakeLists.txt @@ -36,6 +36,7 @@ target_sources( tiny_llm_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}/src/axpby.cpp + ${CMAKE_CURRENT_LIST_DIR}/src/flash_attention.cpp ${CMAKE_CURRENT_LIST_DIR}/src/utils.cpp ) @@ -58,6 +59,7 @@ if(MLX_BUILD_METAL) tiny_llm_ext SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/axpby.metal + ${CMAKE_CURRENT_LIST_DIR}/src/flash_attention.metal INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS} diff --git a/src/extensions/bindings.cpp b/src/extensions/bindings.cpp index 24437b03..918c32e6 100644 --- a/src/extensions/bindings.cpp +++ b/src/extensions/bindings.cpp @@ -31,4 +31,20 @@ NB_MODULE(_ext, m) { Returns: array: ``alpha * x + beta * y`` )"); + + m.def("flash_attention", &tiny_llm_ext::flash_attention, "query"_a, "key"_a, "value"_a, "mask"_a, "scale"_a = 1.0, + "is_causal"_a = false, "num_kv_heads"_a, "num_heads"_a, "stream"_a = nb::none(), R"( + Flash attention layer (student implementation) + + Args: + query (array): Query array. + key (array): Key array. + value (array): Value array. + mask (array): Mask array. + scale (float): Scaling factor. + is_causal (bool): Enable causal-mask fast path. + + Returns: + array: ``softmax(query @ key.T * scale) @ value`` + )"); } diff --git a/src/extensions/src/flash_attention.cpp b/src/extensions/src/flash_attention.cpp new file mode 100644 index 00000000..4c1e1f66 --- /dev/null +++ b/src/extensions/src/flash_attention.cpp @@ -0,0 +1,51 @@ +// Copyright © 2023-2025 Apple Inc. + +#include + +#include "tiny_llm_ext.h" + +namespace tiny_llm_ext { + +mx::array flash_attention(const mx::array &q, const mx::array &k, const mx::array &v, const mx::array &mask, + const float scale, const bool is_causal, const int num_kv_heads, const int num_heads, + mx::StreamOrDevice s /* = {} */) { + // TODO(student): implement flash attention. + (void)q; + (void)k; + (void)v; + (void)mask; + (void)scale; + (void)is_causal; + (void)num_kv_heads; + (void)num_heads; + (void)s; + throw std::runtime_error("flash_attention is not implemented."); +} + +void FlashAttention::eval_cpu(const std::vector &inputs, std::vector &outputs) { + // TODO(student): implement CPU kernel. + (void)inputs; + (void)outputs; + throw std::runtime_error("FlashAttention::eval_cpu is not implemented."); +} + +#ifdef _METAL_ + +void FlashAttention::eval_gpu(const std::vector &inputs, std::vector &outputs) { + // TODO(student): implement Metal kernel dispatch. + (void)inputs; + (void)outputs; + throw std::runtime_error("FlashAttention::eval_gpu is not implemented."); +} + +#else + +void FlashAttention::eval_gpu(const std::vector &inputs, std::vector &outputs) { + (void)inputs; + (void)outputs; + throw std::runtime_error("FlashAttention has no GPU implementation."); +} + +#endif + +} // namespace tiny_llm_ext diff --git a/src/extensions/src/flash_attention.metal b/src/extensions/src/flash_attention.metal new file mode 100644 index 00000000..d6c56949 --- /dev/null +++ b/src/extensions/src/flash_attention.metal @@ -0,0 +1,52 @@ +#include +#include "mlx/backend/metal/kernels/utils.h" + +using namespace metal; + +[[kernel]] void flash_attention_f32_e128( + device const float *q [[buffer(0)]], + device const float *k [[buffer(1)]], + device const float *v [[buffer(2)]], + device const float *mask [[buffer(3)]], + device float *out [[buffer(4)]], + constant const int *mask_shape [[buffer(5)]], + constant const int64_t *mask_strides [[buffer(6)]], + device const int &is_causal [[buffer(7)]], + device const int &N [[buffer(8)]], + device const int &L [[buffer(9)]], + device const int &S [[buffer(10)]], + device const int &E [[buffer(11)]], + device const int &num_kv_heads [[buffer(12)]], + device const int &num_heads [[buffer(13)]], + device const float &scale [[buffer(14)]], + device const int &Br [[buffer(15)]], + device const int &Bc [[buffer(16)]], + device const int &Tr [[buffer(17)]], + device const int &Tc [[buffer(18)]], + uint2 group_id [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + // TODO(student): implement flash attention kernel. + (void)q; + (void)k; + (void)v; + (void)mask; + (void)out; + (void)mask_shape; + (void)mask_strides; + (void)is_causal; + (void)N; + (void)L; + (void)S; + (void)E; + (void)num_kv_heads; + (void)num_heads; + (void)scale; + (void)Br; + (void)Bc; + (void)Tr; + (void)Tc; + (void)group_id; + (void)simd_gid; + (void)simd_lid; +} diff --git a/src/extensions/src/tiny_llm_ext.h b/src/extensions/src/tiny_llm_ext.h index bbca1936..2117baf2 100644 --- a/src/extensions/src/tiny_llm_ext.h +++ b/src/extensions/src/tiny_llm_ext.h @@ -1,7 +1,11 @@ #pragma once +#include +#include + #include "mlx/ops.h" #include "mlx/primitives.h" +#include "mlx/utils.h" namespace mx = mlx::core; @@ -9,4 +13,35 @@ namespace tiny_llm_ext { void load_library(mx::Device d, const char *path); +/////////////////////////////////////////////////////////////////////////////// +// Flash Attention (student implementation) +/////////////////////////////////////////////////////////////////////////////// + +mx::array flash_attention(const mx::array &q, const mx::array &k, const mx::array &v, const mx::array &mask, + const float scale, const bool is_causal, const int num_kv_heads, const int num_heads, + mx::StreamOrDevice s = {}); + +class FlashAttention : public mx::Primitive { +public: + explicit FlashAttention(mx::Stream stream, const float scale, const bool is_causal, const int num_kv_heads, + const int num_heads) + : mx::Primitive(stream), scale_(scale), is_causal_(is_causal), num_kv_heads_(num_kv_heads), num_heads_(num_heads) {}; + + void eval_cpu(const std::vector &inputs, std::vector &outputs) override; + void eval_gpu(const std::vector &inputs, std::vector &outputs) override; + + std::pair, std::vector> vmap(const std::vector &inputs, + const std::vector &axes) override { + throw std::runtime_error("FlashAttention has no vmap implementation."); + } + + const char *name() const override { return "FlashAttention"; } + +private: + float scale_; + bool is_causal_; + int num_kv_heads_; + int num_heads_; +}; + } // namespace tiny_llm_ext