diff --git a/kernels/src/kernels/utils.py b/kernels/src/kernels/utils.py index eea224e9..8dd7bf2e 100644 --- a/kernels/src/kernels/utils.py +++ b/kernels/src/kernels/utils.py @@ -201,11 +201,11 @@ def install_kernel( package_name = package_name_from_repo_id(repo_id) variants = get_variants(api, repo_id=repo_id, revision=revision) - variant = resolve_variant(variants, backend) - - if variant is None: + try: + variant = resolve_variant(variants, backend) + except FileNotFoundError as e: raise FileNotFoundError( - f"Cannot find a build variant for this system in {repo_id} (revision: {revision}). Available variants: {', '.join([variant.variant_str for variant in variants])}" + f"Cannot find a build variant: {e.filename} for this system in {repo_id} (revision: {revision}). Available variants: {', '.join([variant.variant_str for variant in variants])}" ) allow_patterns = [f"build/{variant.variant_str}/*"] @@ -478,9 +478,11 @@ def load_kernel( variants = get_variants(api, repo_id=repo_id, revision=locked_sha) variant = resolve_variant(variants, backend) - if variant is None: + try: + variant = resolve_variant(variants, backend) + except FileNotFoundError as e: raise FileNotFoundError( - f"Cannot find a build variant for this system in {repo_id} (revision: {locked_sha}). Available variants: {', '.join([variant.variant_str for variant in variants])}" + f"Cannot find a build variant: {e.filename} for this system in {repo_id} (revision: {locked_sha}). Available variants: {', '.join([variant.variant_str for variant in variants])}" ) allow_patterns = [f"build/{variant.variant_str}/*"] diff --git a/kernels/src/kernels/variants.py b/kernels/src/kernels/variants.py index 9434b4c2..e9b454c9 100644 --- a/kernels/src/kernels/variants.py +++ b/kernels/src/kernels/variants.py @@ -344,7 +344,7 @@ def resolve_variants(variants: list[Variant], backend: str | None = None) -> lis tvm_ffi_version = parse(tvm_ffi.__version__) tvm_ffi_version = Version(f"{tvm_ffi_version.major}.{tvm_ffi_version.minor}") - return _resolve_variant_for_system( + variants = _resolve_variant_for_system( variants=variants, selected_backend=selected_backend, cpu=cpu, @@ -353,6 +353,11 @@ def resolve_variants(variants: list[Variant], backend: str | None = None) -> lis torch_cxx11_abi=torch_cxx11_abi, tvm_ffi_version=tvm_ffi_version, ) + if not variants: + missing_variant = FileNotFoundError( "Variant not found.") + missing_variant.filename = f"torch{torch_version.major}{torch_version.minor}-{'cxx11' if torch_cxx11_abi else 'cxx98'}-cu{selected_backend.version.major}{selected_backend.version.minor}-{cpu}-{os}" + raise missing_variant + def _resolve_variant_for_system(