diff --git a/CHANGELOG.md b/CHANGELOG.md index c7fbaa7..26796d7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,8 +7,14 @@ and the project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0. ## [Unreleased] +### Changed + - Drop the explicit TensorBoard dependency ([#12](https://github.com/microsoft/retrochimera/pull/12)) ([@kmaziarz]) +### Added + +- Expose setting `num_processes` for template-based models ([#13](https://github.com/microsoft/retrochimera/pull/13)) ([@kmaziarz]) + ## [1.1.0] - 2026-03-12 ### Changed diff --git a/retrochimera/chem/rules.py b/retrochimera/chem/rules.py index ccfe82c..2b821f3 100644 --- a/retrochimera/chem/rules.py +++ b/retrochimera/chem/rules.py @@ -182,6 +182,7 @@ def __init__( max_cumulative_prob: float = 1.0, apply_rules_timeout: Optional[float] = None, include_all_metadata: bool = False, + num_processes: Optional[int] = None, **kwargs, ) -> None: """Initialize a rule-based model. @@ -199,6 +200,8 @@ def __init__( not set, we fallback to the default setting in the server. include_all_metadata: If set, model outputs will include detailed metadata (may require substantial disk space for saving the results). + num_processes: Number of parallel processes for the rule application server. If not set, + defaults to the number of available CPUs. """ super().__init__(**kwargs) # In case this is one of several base classes in the MRO. @@ -208,6 +211,7 @@ def __init__( self._max_cumulative_prob = max_cumulative_prob self._apply_rules_timeout = apply_rules_timeout self._include_all_metadata = include_all_metadata + self._num_processes = num_processes def start_server(self, rulebase_dir: Union[str, Path]) -> None: """Instantiate the rule application server lazily.""" @@ -215,9 +219,15 @@ def start_server(self, rulebase_dir: Union[str, Path]) -> None: # Local to avoid circular import. from retrochimera.chem.rule_application_server import RuleApplicationServer - self._server = RuleApplicationServer( - rulebase_dir=rulebase_dir, rule_application_kwargs=self._rule_application_server_kwargs - ) + server_kwargs: dict[str, Any] = { + "rulebase_dir": rulebase_dir, + "rule_application_kwargs": self._rule_application_server_kwargs, + } + + if self._num_processes is not None: + server_kwargs["num_processes"] = self._num_processes + + self._server = RuleApplicationServer(**server_kwargs) def _get_server(self): if self._server is None: