diff --git a/medcat/utils/envsnapshot.py b/medcat/utils/envsnapshot.py index f3418f5..c345ad5 100644 --- a/medcat/utils/envsnapshot.py +++ b/medcat/utils/envsnapshot.py @@ -1,4 +1,3 @@ -import pkg_resources import platform import logging import importlib.metadata @@ -37,30 +36,6 @@ def get_direct_dependencies(include_extras: bool) -> list[str]: return reqs -def _update_installed_dependencies_recursive( - gathered: dict[str, str], - package: pkg_resources.Distribution) -> dict[str, str]: - if package.project_name.lower() in gathered: - logger.debug("Trying to update already found transitive dependency " - "'%'", package.egg_name) - return gathered - for req in package.requires(): - if req.project_name.lower() in gathered: - logger.debug("Trying to look up already found transitive " - "dependency '%'", req.project_name) - continue # don't look for it again - try: - dep = pkg_resources.get_distribution(req.project_name) - except pkg_resources.DistributionNotFound as e: - logger.warning("Unable to locate requirement '%s':", - req.project_name, exc_info=e) - continue - _update_installed_dependencies_recursive(gathered, dep) - # do this after so its dependencies get explored - gathered[dep.project_name.lower()] = dep.version - return gathered - - def get_transitive_deps(direct_deps: list[str]) -> dict[str, str]: """Get the transitive dependencies of the direct dependencies. @@ -70,12 +45,45 @@ def get_transitive_deps(direct_deps: list[str]) -> dict[str, str]: Returns: dict[str, str]: The dependency names and their corresponding versions. """ - # map from name to version so as to avoid multiples of the same package - all_transitive_deps: dict[str, str] = {} - for dep in direct_deps: - package = pkg_resources.get_distribution(dep) - _update_installed_dependencies_recursive(all_transitive_deps, package) - return all_transitive_deps + all_deps: dict[str, str] = {} + to_process = set(direct_deps) + processed = set() + # list installed packages for ease of use + installed_packages = { + dist.metadata['name'].lower() + for dist in importlib.metadata.distributions()} + + while to_process: + package = to_process.pop() + if package in processed: + continue + + processed.add(package) + + try: + dist = importlib.metadata.distribution(package) + except importlib.metadata.PackageNotFoundError: + # NOTE: if not installed, we won't bother + # after all, if we can save the model, clearly + # everything is working + continue + requires = dist.requires or [] + + for req in requires: + match = DEP_NAME_PATTERN.match(req) + if match is None: + raise ValueError(f"Malformed dependency: {req}") + dep_name = match.group(0).lower() + if (dep_name and dep_name not in processed and + dep_name in installed_packages): + all_deps[dep_name] = importlib.metadata.distribution( + dep_name).version + to_process.add(dep_name) + + for direct in direct_deps: + # remove direct dependencies if they were added + all_deps.pop(direct, None) + return all_deps def get_installed_dependencies(include_extras: bool) -> dict[str, str]: @@ -89,13 +97,39 @@ def get_installed_dependencies(include_extras: bool) -> dict[str, str]: """ direct_deps = get_direct_dependencies(include_extras) installed_packages: dict[str, str] = {} - for package in pkg_resources.working_set: - if package.project_name.lower() not in direct_deps: + for package in importlib.metadata.distributions(): + req_name = package.metadata["name"].lower() + # NOTE: we're checking against the '-' typed package name not + # the import name (which will have _ instead) + req_name_dashes = req_name.replace("_", "-") + if all(cn not in direct_deps for cn in + [req_name, req_name_dashes]): continue - installed_packages[package.project_name.lower()] = package.version + installed_packages[req_name] = package.version return installed_packages +def is_dependency_installed(dependency: str) -> bool: + """Checks whether a dependency is installed. + + This takes into account changes such as '-' vs '_'. + For example, `typing-extensions` is a direct dependency, + but its module path will be `typing_extension` and that's + how we can find it as an installed dependency. + + Args: + dependency (str): The dependency in question. + + Returns: + bool: Whether the depedency has been installed. + """ + installed_deps = get_installed_dependencies(True) + dep_name = dependency.lower() + dep_name_underscores = dependency.replace("-", "_") + options = [dep_name, dep_name_underscores] + return any(option in installed_deps for option in options) + + class Environment(BaseModel, AbstractSerialisable): dependencies: dict[str, str] transitive_deps: dict[str, str] diff --git a/tests/utils/test_envsnapshot.py b/tests/utils/test_envsnapshot.py index fac7825..393c0b0 100644 --- a/tests/utils/test_envsnapshot.py +++ b/tests/utils/test_envsnapshot.py @@ -30,7 +30,7 @@ def test_dir_deps_have_no_version(self): def test_all_dir_deps_have_been_installed(self): for dep in self.dir_deps: with self.subTest(dep): - self.assertIn(dep, self.installed_deps) + self.assertTrue(envsnapshot.is_dependency_installed(dep)) def test_all_deps_add_to_correct(self): # NOTE: didn't account for test/dev deps