Skip to content
This repository was archived by the owner on Jun 30, 2025. It is now read-only.
Merged
102 changes: 68 additions & 34 deletions medcat/utils/envsnapshot.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pkg_resources
import platform
import logging
import importlib.metadata
Expand Down Expand Up @@ -37,30 +36,6 @@ def get_direct_dependencies(include_extras: bool) -> list[str]:
return reqs


def _update_installed_dependencies_recursive(
        gathered: dict[str, str],
        package: pkg_resources.Distribution) -> dict[str, str]:
    """Collect the installed requirements of a package, depth first.

    Every requirement of `package` that can be located is explored first
    and then recorded in `gathered` under its lower-cased project name.

    Args:
        gathered (dict[str, str]): Mapping from lower-cased project name
            to version. It is updated in place.
        package (pkg_resources.Distribution): The distribution whose
            requirements are explored.

    Returns:
        dict[str, str]: The same `gathered` mapping that was passed in.
    """
    if package.project_name.lower() in gathered:
        # '%s' (not a bare '%') is required for lazy log formatting, and
        # the project name is what the rest of this function keys on
        logger.debug("Trying to update already found transitive dependency "
                     "'%s'", package.project_name)
        return gathered
    for req in package.requires():
        if req.project_name.lower() in gathered:
            logger.debug("Trying to look up already found transitive "
                         "dependency '%s'", req.project_name)
            continue  # don't look for it again
        try:
            dep = pkg_resources.get_distribution(req.project_name)
        except pkg_resources.DistributionNotFound as e:
            logger.warning("Unable to locate requirement '%s':",
                           req.project_name, exc_info=e)
            continue
        # NOTE: a requirement is only recorded once its own requirements
        #       have been explored, so there is no guard against packages
        #       that require each other
        _update_installed_dependencies_recursive(gathered, dep)
        # do this after so its dependencies get explored
        gathered[dep.project_name.lower()] = dep.version
    return gathered


def get_transitive_deps(direct_deps: list[str]) -> dict[str, str]:
    """Get the transitive dependencies of the direct dependencies.

    Starting from the direct dependencies, the requirements of every
    installed distribution are followed until no new ones are found.
    Requirements that are not installed are ignored.

    Args:
        direct_deps (list[str]): The names of the direct dependencies.

    Returns:
        dict[str, str]: The dependency names and their corresponding versions.

    Raises:
        ValueError: If a requirement string does not match
            `DEP_NAME_PATTERN`.
    """
    all_deps: dict[str, str] = {}
    to_process = set(direct_deps)
    processed: set[str] = set()
    # list installed packages for ease of use
    # (distributions without a name in their metadata are skipped)
    installed_packages = {
        name.lower()
        for name in (dist.metadata['name']
                     for dist in importlib.metadata.distributions())
        if name}

    while to_process:
        package = to_process.pop()
        if package in processed:
            continue

        processed.add(package)

        try:
            dist = importlib.metadata.distribution(package)
        except importlib.metadata.PackageNotFoundError:
            # NOTE: if not installed, we won't bother
            #       after all, if we can save the model, clearly
            #       everything is working
            continue

        # `requires` is None when a distribution declares no requirements
        for req in dist.requires or []:
            match = DEP_NAME_PATTERN.match(req)
            if match is None:
                raise ValueError(f"Malformed dependency: {req}")
            dep_name = match.group(0).lower()
            if (dep_name and dep_name not in processed and
                    dep_name in installed_packages):
                all_deps[dep_name] = importlib.metadata.distribution(
                    dep_name).version
                to_process.add(dep_name)

    for direct in direct_deps:
        # remove direct dependencies if they were added
        # (keys are lower-cased, so the lookup has to be as well)
        all_deps.pop(direct.lower(), None)
    return all_deps


def get_installed_dependencies(include_extras: bool) -> dict[str, str]:
    """Get the installed versions of the direct dependencies.

    Args:
        include_extras (bool): Passed on to `get_direct_dependencies`
            when listing the direct dependencies.

    Returns:
        dict[str, str]: The lower-cased distribution names of the installed
            direct dependencies and their corresponding versions.
    """
    # normalise once so that each check below is a single set lookup
    direct_deps = {
        dep.lower().replace("_", "-")
        for dep in get_direct_dependencies(include_extras)}
    installed_packages: dict[str, str] = {}
    for package in importlib.metadata.distributions():
        raw_name = package.metadata["name"]
        if not raw_name:
            # no name in the metadata means there is nothing to match on
            continue
        req_name = raw_name.lower()
        # NOTE: '-' and '_' are used interchangeably in distribution
        #       names, so both sides are compared in their '-' form
        if req_name.replace("_", "-") not in direct_deps:
            continue
        installed_packages[req_name] = package.version
    return installed_packages


def is_dependency_installed(dependency: str) -> bool:
    """Checks whether a dependency is installed.

    This takes into account changes such as '-' vs '_'.
    For example, `typing-extensions` is a direct dependency,
    but it may be listed as `typing_extensions` among the
    installed dependencies, and that's how we can find it.
    The comparison is done on the lower-cased name.

    Args:
        dependency (str): The dependency in question.

    Returns:
        bool: Whether the dependency has been installed.
    """
    installed_deps = get_installed_dependencies(True)
    dep_name = dependency.lower()
    # derive every variant from the lower-cased name, since the installed
    # dependencies are keyed by lower-cased names
    options = {
        dep_name,
        dep_name.replace("-", "_"),
        dep_name.replace("_", "-"),
    }
    return any(option in installed_deps for option in options)


class Environment(BaseModel, AbstractSerialisable):
dependencies: dict[str, str]
transitive_deps: dict[str, str]
Expand Down
2 changes: 1 addition & 1 deletion tests/utils/test_envsnapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_dir_deps_have_no_version(self):
def test_all_dir_deps_have_been_installed(self):
    # A plain membership check against the installed names fails for
    # dependencies whose name differs only by '-' vs '_', so the check
    # goes through the helper that accounts for both spellings.
    for dep in self.dir_deps:
        with self.subTest(dep):
            self.assertTrue(envsnapshot.is_dependency_installed(dep))

def test_all_deps_add_to_correct(self):
# NOTE: didn't account for test/dev deps
Expand Down