Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions flash_mla/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,25 @@
# Adapted from deepseek-ai/FlashMLA(https://github.com/deepseek-ai/FlashMLA)
__version__ = "1.0.1"

from flash_mla.flash_mla_interface import(
get_mla_metadata,
flash_mla_with_kvcache
)
__all__ = ["__version__", "get_mla_metadata", "flash_mla_with_kvcache"]


# Resolved ``(get_mla_metadata, flash_mla_with_kvcache)`` pair. Stays ``None``
# until the interface module has been imported successfully once.
_interface_cache = None


def _load_interface():
    """Import the FlashMLA interface on first use and return its entry points.

    The import is performed inside this function rather than at module import
    time. A successful lookup is stored in ``_interface_cache`` so later calls
    skip the import statement entirely; a failed import is not cached and is
    attempted again on the next call.

    Returns:
        A ``(get_mla_metadata, flash_mla_with_kvcache)`` tuple holding the
        callables defined in ``flash_mla.flash_mla_interface``.

    Raises:
        ImportError: If ``flash_mla.flash_mla_interface`` or either of the two
            names cannot be imported. The original error is chained as
            ``__cause__``.
    """
    global _interface_cache
    if _interface_cache is not None:
        return _interface_cache

    try:
        from flash_mla.flash_mla_interface import flash_mla_with_kvcache, get_mla_metadata
    except ImportError as exc:
        raise ImportError(
            "flash_mla_cuda is not available. Build and install FlashMLA from source "
            "in a configured MACA environment before calling FlashMLA kernels."
        ) from exc

    # The names bound by the import above are locals of this function; they
    # do not rebind the module-level wrappers of the same name.
    _interface_cache = (get_mla_metadata, flash_mla_with_kvcache)
    return _interface_cache


def get_mla_metadata(*args, **kwargs):
    """Forward the call to ``flash_mla.flash_mla_interface.get_mla_metadata``.

    All positional and keyword arguments are passed through unchanged and the
    underlying function's result is returned as-is.

    Raises:
        ImportError: If the interface module cannot be imported.
    """
    metadata_func, _ = _load_interface()
    return metadata_func(*args, **kwargs)


def flash_mla_with_kvcache(*args, **kwargs):
    """Forward the call to ``flash_mla.flash_mla_interface.flash_mla_with_kvcache``.

    All positional and keyword arguments are passed through unchanged and the
    underlying function's result is returned as-is.

    Raises:
        ImportError: If the interface module cannot be imported.
    """
    _, flash_func = _load_interface()
    return flash_func(*args, **kwargs)
Comment on lines +7 to +25

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

问题分析

  1. 性能开销(Efficiency):在当前的设计中,每次调用 `get_mla_metadata` 或 `flash_mla_with_kvcache` 都会执行 `_load_interface()`。虽然 Python 会缓存已导入的模块,但在高频调用的 CUDA 算子核心路径(hot path)中,频繁执行 `try-except` 和 `import` 语句仍会带来不必要的 CPU 额外开销。
  2. 易用性与类型提示(Maintainability):使用 *args, **kwargs 包装函数会导致 IDE(如 VS Code、PyCharm)和静态类型检查工具(如 mypy)无法获取原始函数的参数签名和 Docstring,降低了开发体验。

解决方案

建议使用模块级缓存来避免重复导入的开销。同时,为了保留参数签名和文档,可以直接在包装函数中显式声明参数。

# Holds the resolved ``(get_mla_metadata, flash_mla_with_kvcache)`` pair once
# the interface module has been imported successfully; ``None`` before that.
_funcs_cache = None


def _load_interface():
    """Return the interface callables, importing them only on the first call.

    Returns:
        The ``(get_mla_metadata, flash_mla_with_kvcache)`` tuple taken from
        ``flash_mla.flash_mla_interface``.

    Raises:
        ImportError: If the interface module cannot be imported. Nothing is
            cached in that case, so the next call attempts the import again.
    """
    global _funcs_cache
    if _funcs_cache is not None:
        return _funcs_cache

    try:
        from flash_mla.flash_mla_interface import flash_mla_with_kvcache, get_mla_metadata
    except ImportError as import_error:
        raise ImportError(
            "flash_mla_cuda is not available. Build and install FlashMLA from source "
            "in a configured MACA environment before calling FlashMLA kernels."
        ) from import_error
    else:
        _funcs_cache = (get_mla_metadata, flash_mla_with_kvcache)
    return _funcs_cache


def get_mla_metadata(cache_seqlens, num_heads_per_head_k: int, num_heads_k: int):
    """Call the interface's ``get_mla_metadata`` with the given arguments.

    The three arguments are forwarded positionally, in declaration order, and
    the underlying result is returned unchanged.

    Raises:
        ImportError: If the interface module cannot be imported.
    """
    impl = _load_interface()[0]
    return impl(cache_seqlens, num_heads_per_head_k, num_heads_k)


def flash_mla_with_kvcache(
    q,
    k_cache,
    block_table,
    cache_seqlens,
    head_dim_v: int,
    tile_scheduler_metadata,
    num_splits,
    softmax_scale=None,
    causal: bool = False,
):
    """Call the interface's ``flash_mla_with_kvcache`` with the given arguments.

    All nine parameters, including ``softmax_scale`` and ``causal``, are
    forwarded positionally in declaration order, and the underlying result is
    returned unchanged.

    Raises:
        ImportError: If the interface module cannot be imported.
    """
    impl = _load_interface()[1]
    forwarded = (
        q,
        k_cache,
        block_table,
        cache_seqlens,
        head_dim_v,
        tile_scheduler_metadata,
        num_splits,
        softmax_scale,
        causal,
    )
    return impl(*forwarded)

14 changes: 14 additions & 0 deletions tests/test_lazy_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import importlib
import unittest


class LazyImportTest(unittest.TestCase):
    """Checks that the ``flash_mla`` package imports and exposes its public API."""

    def test_import_package_without_compiled_extension(self):
        """Import ``flash_mla`` and verify its version and both entry points.

        Per its name, this test is meant to pass when the compiled extension
        is not installed. It does not hide or remove the extension itself; it
        only requires that importing the package succeeds.
        """
        module = importlib.import_module("flash_mla")

        self.assertEqual(module.__version__, "1.0.1")
        # Check every public callable, not just one of them, so a regression
        # in either wrapper is reported by name.
        for name in ("get_mla_metadata", "flash_mla_with_kvcache"):
            with self.subTest(name=name):
                self.assertTrue(callable(getattr(module, name)))


if __name__ == "__main__":
    unittest.main()