diff --git a/flash_mla/__init__.py b/flash_mla/__init__.py
index e37cd195..352d51dd 100644
--- a/flash_mla/__init__.py
+++ b/flash_mla/__init__.py
@@ -1,7 +1,51 @@
 # Adapted from deepseek-ai/FlashMLA(https://github.com/deepseek-ai/FlashMLA)
 __version__ = "1.0.1"
 
-from flash_mla.flash_mla_interface import(
-    get_mla_metadata,
-    flash_mla_with_kvcache
-)
+__all__ = ["__version__", "get_mla_metadata", "flash_mla_with_kvcache"]
+
+
+def _load_interface():
+    """Import and return the FlashMLA kernel entry points on demand.
+
+    The import is deferred to call time so that ``import flash_mla`` does
+    not itself import ``flash_mla.flash_mla_interface``.
+
+    Returns:
+        The tuple ``(get_mla_metadata, flash_mla_with_kvcache)``.
+
+    Raises:
+        ImportError: If ``flash_mla.flash_mla_interface`` cannot be imported;
+            the original error is chained as ``__cause__``.
+    """
+    try:
+        from flash_mla.flash_mla_interface import flash_mla_with_kvcache, get_mla_metadata
+    except ImportError as exc:
+        raise ImportError(
+            "flash_mla_cuda is not available. Build and install FlashMLA from source "
+            "in a configured MACA environment before calling FlashMLA kernels."
+        ) from exc
+    return get_mla_metadata, flash_mla_with_kvcache
+
+
+def get_mla_metadata(*args, **kwargs):
+    """Lazily load and call ``flash_mla_interface.get_mla_metadata``.
+
+    All positional and keyword arguments are forwarded unchanged.
+
+    Raises:
+        ImportError: If ``flash_mla.flash_mla_interface`` cannot be imported.
+    """
+    metadata_func, _ = _load_interface()
+    return metadata_func(*args, **kwargs)
+
+
+def flash_mla_with_kvcache(*args, **kwargs):
+    """Lazily load and call ``flash_mla_interface.flash_mla_with_kvcache``.
+
+    All positional and keyword arguments are forwarded unchanged.
+
+    Raises:
+        ImportError: If ``flash_mla.flash_mla_interface`` cannot be imported.
+    """
+    _, flash_func = _load_interface()
+    return flash_func(*args, **kwargs)
diff --git a/tests/test_lazy_import.py b/tests/test_lazy_import.py
new file mode 100644
index 00000000..41530435
--- /dev/null
+++ b/tests/test_lazy_import.py
@@ -0,0 +1,15 @@
+import importlib
+import unittest
+
+
+class LazyImportTest(unittest.TestCase):
+    def test_import_package_without_compiled_extension(self):
+        module = importlib.import_module("flash_mla")
+
+        self.assertEqual(module.__version__, "1.0.1")
+        self.assertTrue(callable(module.get_mla_metadata))
+        self.assertTrue(callable(module.flash_mla_with_kvcache))
+
+
+if __name__ == "__main__":
+    unittest.main()