diff --git a/src/cwl_loader/__init__.py b/src/cwl_loader/__init__.py index 865d83c..bed0832 100644 --- a/src/cwl_loader/__init__.py +++ b/src/cwl_loader/__init__.py @@ -24,8 +24,9 @@ from loguru import logger from ruamel.yaml import YAML from ruamel.yaml.comments import CommentedMap -from typing import Any, List, Mapping, TextIO +from typing import Any, List, Mapping, Optional, TextIO, Tuple from urllib.parse import urlparse, urldefrag +import copy import requests import os @@ -37,6 +38,190 @@ _yaml = YAML() _global_loader = default_loader() +# Module-level caches storing context from the most recent load. +# Cleared at the start of each top-level load (depth == 0) to prevent leaking +# state between successive loads. +_custom_requirements_cache: dict = {} +_original_namespaces: dict = {} +_load_depth: int = 0 + + +def _extract_custom_reqs_from_item(item: dict, item_id: str, req_cache: dict) -> None: + """ + Remove custom namespaced requirements from ``item['requirements']`` (and + ``item['hints']`` as fallback) in-place, storing them in *req_cache* keyed + by *item_id*. + + Custom requirements found in ``hints`` are also extracted so that they can + be re-injected into ``requirements`` by ``_inject_custom_reqs_into_item`` + (Calrissian's ``make_job_runner`` uses ``get_requirement()`` which searches + hints, but ``KubernetesDaskPodBuilder`` only reads ``requirements``). + + Handles both dict-form and list-form requirements/hints. + """ + collected: list = [] + + # --- process requirements --- + reqs = item.get("requirements") + if isinstance(reqs, dict): + custom_reqs: dict = {} + standard_reqs: dict = {} + for req_name, req_value in reqs.items(): + if ":" in str(req_name): + logger.debug(f"Storing custom requirement for {item_id}: {req_name}") + custom_reqs[req_name] = req_value + else: + standard_reqs[req_name] = req_value + if custom_reqs: + collected.append(("dict", custom_reqs)) + item["requirements"] = standard_reqs + elif isinstance(reqs, list): + custom_reqs_list: list = [] + standard_reqs_list: list = [] + for req in reqs: + if isinstance(req, dict): + req_class = req.get("class", "") + if ":" in str(req_class): + logger.debug( + f"Storing custom requirement for {item_id}: {req_class}" + ) + custom_reqs_list.append(req) + else: + standard_reqs_list.append(req) + else: + standard_reqs_list.append(req) + if custom_reqs_list: + collected.append(("list", custom_reqs_list)) + item["requirements"] = standard_reqs_list + + # --- process hints (fallback: custom reqs may have landed here) --- + hints = item.get("hints") + if isinstance(hints, list): + custom_hints: list = [] + standard_hints: list = [] + for hint in hints: + if isinstance(hint, dict): + hint_class = hint.get("class", "") + if ":" in str(hint_class): + logger.debug( + f"Storing custom hint as requirement for {item_id}: {hint_class}" + ) + custom_hints.append(hint) + else: + standard_hints.append(hint) + else: + standard_hints.append(hint) + if custom_hints: + collected.append(("list", custom_hints)) + item["hints"] = standard_hints + + # Merge all collected custom reqs into a single list for this item + if collected: + merged: list = [] + for form, data in collected: + if form == "list": + merged.extend(data) + else: # dict form → convert to list form for uniform injection + for req_name, req_value in data.items(): + entry: dict = {"class": req_name} + if isinstance(req_value, dict): + entry.update(req_value) + merged.append(entry) + req_cache[item_id] = merged + + +def _clean_custom_namespaces( + raw_process: Mapping[str, Any], +) -> Tuple[Mapping[str, Any], dict, dict]: + """ + Extract custom namespaced requirements and record ``$namespaces`` for later + restoration. + + Custom requirements - those whose dict key (dict-form) or ``class`` value + (list-form) contains a colon - are removed so that the standard CWL parser + does not reject them. Both ``$graph`` documents and single top-level process + documents are handled. + + The function never mutates *raw_process* or any of its nested objects. + + Args: + raw_process: The raw CWL document as a plain dict or CommentedMap. + + Returns: + A 3-tuple ``(cleaned_doc, req_cache, ns_store)`` where: + + * ``cleaned_doc`` – a (deep-)copy of *raw_process* with custom namespaced + requirements removed from every process item. + * ``req_cache`` – ``{item_id: custom_reqs}`` mapping; *custom_reqs* is a + dict (dict-form source) or list (list-form source). + * ``ns_store`` – ``{'__root__': {…}}`` if ``$namespaces`` was present, + empty dict otherwise. + """ + # Shallow-copy the top level so we do not mutate the caller's mapping. + cleaned: Any = ( + raw_process.copy() + if isinstance(raw_process, dict) + else CommentedMap(raw_process) + ) + req_cache: dict = {} + ns_store: dict = {} + + if "$namespaces" in cleaned: + ns_store["__root__"] = dict(cleaned["$namespaces"]) + logger.debug(f"Saved original $namespaces: {ns_store['__root__']}") + + if "$graph" in cleaned and isinstance(cleaned["$graph"], list): + # Rebuild the $graph list using deep copies of each item so that we can + # mutate requirements without touching the caller's original objects. + new_graph = [] + for item in cleaned["$graph"]: + if isinstance(item, dict): + item = copy.deepcopy(item) + item_id = item.get("id", "unknown") + _extract_custom_reqs_from_item(item, item_id, req_cache) + new_graph.append(item) + cleaned["$graph"] = new_graph + elif "requirements" in cleaned: + # Single top-level process (CommandLineTool / Workflow / …). + # Deep-copy the entire cleaned document before mutating it. + cleaned = copy.deepcopy(cleaned) + item_id = cleaned.get("id", "__top__") + _extract_custom_reqs_from_item(cleaned, item_id, req_cache) + + return cleaned, req_cache, ns_store + + +def _lookup_in_cache(item_id: Optional[str], cache: Mapping[str, Any]) -> Optional[Any]: + """ + Find *item_id* in *cache*, trying the full string first then progressively + shorter forms (fragment after ``#``, last path segment after ``/``). + + Returns the cached value or ``None`` if not found. + """ + if item_id is None: + return None + if item_id in cache: + return cache[item_id] + for sep in ("#", "/"): + if sep in str(item_id): + short = str(item_id).split(sep)[-1] + if short in cache: + return cache[short] + return None + + +def get_custom_requirements(item_id: str) -> List[Any] | Mapping[str, Any]: + """ + Retrieve custom requirements for a given item ID from the global cache. + + Args: + item_id: The ID of the CWL item (CommandLineTool, Workflow, etc.) + + Returns: + Custom requirements (list or dict) or empty list if none found + """ + return _custom_requirements_cache.get(item_id, []) + def _is_url(path_or_url: str) -> bool: try: @@ -97,71 +282,98 @@ def load_cwl_from_yaml( `raw_process` (`dict`): The dictionary representing the CWL document `uri` (`Optional[str]`): The CWL document URI. Default to `io://` `cwl_version` (`Optional[str]`): The CWL document version. Default to `v1.2` + `sort` (`Optional[bool]`): Sort processes by dependencies. Default to `True` Returns: `Processes`: The parsed CWL Process or Processes (if the CWL document is a `$graph`). """ - updated_process = raw_process + global _load_depth - if cwl_version != raw_process[__CWL_VERSION__]: - logger.debug( - f"Updating the model from version '{raw_process[__CWL_VERSION__]}' to version '{cwl_version}'..." - ) + # At the top-level load (not a recursive call from _dereference_steps) clear + # the caches so that state from a previous load does not bleed into the + # current one. + if _load_depth == 0: + _custom_requirements_cache.clear() + _original_namespaces.clear() - updated_process = update( - doc=raw_process - if isinstance(raw_process, CommentedMap) - else CommentedMap(OrderedDict(raw_process)), - loader=_global_loader, - baseuri=uri, - enable_dev=False, - metadata=CommentedMap(OrderedDict({"cwlVersion": cwl_version})), - update_to=cwl_version, + _load_depth += 1 + try: + # Clean custom namespaces and requirements before processing. + # _clean_custom_namespaces never mutates raw_process and returns local dicts. + cleaned_process, local_req_cache, local_ns = _clean_custom_namespaces( + raw_process ) - logger.debug(f"Raw CWL document successfully updated to {cwl_version}!") - else: - logger.debug( - f"No needs to update the Raw CWL document since it targets already the {cwl_version}" - ) + # Merge per-document caches into the module globals so they are accessible + # via get_custom_requirements / extract_dask_config without an explicit arg. + _custom_requirements_cache.update(local_req_cache) + _original_namespaces.update(local_ns) + + updated_process = cleaned_process + + if cwl_version != cleaned_process[__CWL_VERSION__]: + logger.debug( + f"Updating the model from version '{cleaned_process[__CWL_VERSION__]}' to version '{cwl_version}'..." + ) + + updated_process = update( + doc=cleaned_process + if isinstance(cleaned_process, CommentedMap) + else CommentedMap(OrderedDict(cleaned_process)), + loader=_global_loader, + baseuri=uri, + enable_dev=False, + metadata=CommentedMap(OrderedDict({"cwlVersion": cwl_version})), + update_to=cwl_version, + ) + + logger.debug(f"Raw CWL document successfully updated to {cwl_version}!") + else: + logger.debug( + f"No needs to update the Raw CWL document since it targets already the {cwl_version}" + ) - logger.debug("Parsing the raw CWL document to the CWL Utils DOM...") + logger.debug("Parsing the raw CWL document to the CWL Utils DOM...") - clean_uri, fragment = urldefrag(uri) + clean_uri, fragment = urldefrag(uri) - if fragment: - logger.debug(f"Ignoring fragment #{fragment} from URI {clean_uri}") + if fragment: + logger.debug(f"Ignoring fragment #{fragment} from URI {clean_uri}") - process = load_document_by_yaml(yaml=updated_process, uri=clean_uri, load_all=True) + process = load_document_by_yaml( + yaml=updated_process, uri=clean_uri, load_all=True + ) - logger.debug("Raw CWL document successfully parsed to the CWL Utils DOM!") + logger.debug("Raw CWL document successfully parsed to the CWL Utils DOM!") - logger.debug("Dereferencing the steps[].run...") + logger.debug("Dereferencing the steps[].run...") - dereferenced_process = _dereference_steps(process=process, uri=uri) + dereferenced_process = _dereference_steps(process=process, uri=uri) - logger.debug("steps[].run successfully dereferenced! Dereferencing the FQNs...") + logger.debug("steps[].run successfully dereferenced! Dereferencing the FQNs...") - remove_refs(dereferenced_process) + remove_refs(dereferenced_process) - logger.debug( - "CWL document successfully dereferenced! Now verifying steps[].run integrity..." - ) + logger.debug( + "CWL document successfully dereferenced! Now verifying steps[].run integrity..." + ) - assert_connected_graph(dereferenced_process) + assert_connected_graph(dereferenced_process) - logger.debug("All steps[].run link are resolvable! ") + logger.debug("All steps[].run link are resolvable! ") - if sort: - logger.debug("Sorting Process instances by dependencies....") - dereferenced_process = order_graph_by_dependencies(dereferenced_process) - logger.debug("Sorting process is over.") + if sort: + logger.debug("Sorting Process instances by dependencies....") + dereferenced_process = order_graph_by_dependencies(dereferenced_process) + logger.debug("Sorting process is over.") - return ( - dereferenced_process - if len(dereferenced_process) > 1 - else dereferenced_process[0] - ) + return ( + dereferenced_process + if len(dereferenced_process) > 1 + else dereferenced_process[0] + ) + finally: + _load_depth -= 1 def load_cwl_from_stream( @@ -282,3 +494,134 @@ def dump_cwl(process: Process | List[Process], stream: TextIO): ) _yaml.dump(data=data, stream=stream) + + +def _inject_custom_reqs_into_item(item: dict, custom_reqs: Any) -> None: + """ + Reinject *custom_reqs* (list or dict form) into ``item['requirements']``. + + All custom requirements (including calrissian:DaskGatewayRequirement) are + injected into ``requirements`` so that Calrissian can find them - it reads + DaskGatewayRequirement from ``requirements``, not ``hints``. + """ + if "requirements" not in item or not isinstance(item["requirements"], list): + item["requirements"] = [] + + if isinstance(custom_reqs, list): + for custom_req in custom_reqs: + item["requirements"].append(custom_req) + elif isinstance(custom_reqs, dict): + for req_name, req_value in custom_reqs.items(): + custom_req_entry: dict = {"class": req_name} + if isinstance(req_value, dict): + custom_req_entry.update(req_value) + elif req_value is not None: + logger.warning( + f"Custom requirement '{req_name}' has a non-mapping value " + f"{req_value!r}; only the 'class' key will be emitted in the " + "serialised output." + ) + item["requirements"].append(custom_req_entry) + + +def dump_cwl_with_custom_requirements( + process: Process | List[Process], + stream: TextIO, + custom_requirements_cache: Optional[Mapping[str, Any]] = None, + original_namespaces: Optional[Mapping[str, Any]] = None, +): + """ + Serializes a CWL document with custom requirements properly reinjected into the requirements section. + + This function ensures that custom namespaced requirements (like calrissian:DaskGatewayRequirement) + are placed in the correct location within the 'requirements' section. + + Args: + `process` (`Processes`): The CWL Process or Processes (if the CWL document is a `$graph`) + `stream` (`Stream`): The stream where serializing the CWL document + `custom_requirements_cache` (`Mapping[str, Any]`, optional): Cache of custom requirements. + If None, uses the module-level cache. + `original_namespaces` (`Mapping[str, Any]`, optional): Saved $namespaces mapping. + If None, uses the module-level cache. + + Returns: + `None`: none. + """ + if custom_requirements_cache is None: + custom_requirements_cache = _custom_requirements_cache + if original_namespaces is None: + original_namespaces = _original_namespaces + + data = save( + val=process, # type: ignore + relative_uris=False, + ) + + if "__root__" in original_namespaces: + data["$namespaces"] = original_namespaces["__root__"] + logger.debug(f"Restored original $namespaces: {data['$namespaces']}") + + if "$graph" in data and isinstance(data["$graph"], list): + for item in data["$graph"]: + if isinstance(item, dict): + item_id = item.get("id") + + if "cwlVersion" in item: + del item["cwlVersion"] + if "$namespaces" in item: + del item["$namespaces"] + + custom_reqs = _lookup_in_cache(item_id, custom_requirements_cache) + if custom_reqs is not None: + _inject_custom_reqs_into_item(item, custom_reqs) + else: + # Single top-level process (no $graph wrapper). + item_id = data.get("id") if isinstance(data, dict) else None + custom_reqs = _lookup_in_cache(item_id, custom_requirements_cache) + if custom_reqs is None: + # Fallback: top-level processes without an id were cached under '__top__'. + custom_reqs = custom_requirements_cache.get("__top__") + if custom_reqs is not None and isinstance(data, dict): + _inject_custom_reqs_into_item(data, custom_reqs) + + _yaml.dump(data=data, stream=stream) + + +def extract_dask_config( + custom_requirements_cache: Optional[Mapping[str, Any]] = None, +) -> Mapping[str, Any]: + """ + Extracts Dask Gateway configuration from custom requirements cache. + + This utility function searches for DaskGatewayRequirement in the custom requirements + and returns a dictionary with the Dask configuration parameters. + + Args: + `custom_requirements_cache` (`Mapping[str, Any]`, optional): Cache of custom requirements. + If None, uses the module-level cache. + + Returns: + `Mapping[str, Any]`: Dictionary containing all fields found in the + DaskGatewayRequirement (except the `class` key when + the requirement is represented as a list item). + Returns empty dict if no DaskGatewayRequirement found. + """ + if custom_requirements_cache is None: + custom_requirements_cache = _custom_requirements_cache + + for item_id, reqs in custom_requirements_cache.items(): + if isinstance(reqs, dict): + for req_name, req_value in reqs.items(): + if "DaskGatewayRequirement" in req_name: + logger.debug(f"Found DaskGatewayRequirement in {item_id}") + return dict(req_value) if isinstance(req_value, dict) else {} + elif isinstance(reqs, list): + for req in reqs: + if isinstance(req, dict): + req_class = req.get("class", "") + if "DaskGatewayRequirement" in req_class: + logger.debug(f"Found DaskGatewayRequirement in {item_id}") + return {k: v for k, v in req.items() if k != "class"} + + logger.debug("No DaskGatewayRequirement found in custom requirements cache") + return {} diff --git a/src/cwl_loader/utils.py b/src/cwl_loader/utils.py index 4a460b8..3ca30d9 100644 --- a/src/cwl_loader/utils.py +++ b/src/cwl_loader/utils.py @@ -96,7 +96,8 @@ def remove_refs(process: Process | List[Process]): step.run = step.run[step.run.rfind("#") :] if getattr(step, "scatter", None): - step.scatter = _clean_values(step.scatter, f"#{process.id}/") + cleaned_scatter = _clean_values(step.scatter, f"#{process.id}/") + step.scatter = _clean_values(cleaned_scatter, f"{step.id}/") if process.extension_fields and ORIGINAL_CWLVERSION in process.extension_fields: process.extension_fields.pop(ORIGINAL_CWLVERSION) diff --git a/tests/test_custom_requirements.py b/tests/test_custom_requirements.py new file mode 100644 index 0000000..a30ecff --- /dev/null +++ b/tests/test_custom_requirements.py @@ -0,0 +1,206 @@ +# Copyright 2026 Terradue +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from io import StringIO +from unittest import TestCase + +from cwl_utils.parser import Process +from ruamel.yaml import YAML + +from cwl_loader import ( + dump_cwl_with_custom_requirements, + extract_dask_config, + load_cwl_from_string_content, + _custom_requirements_cache, + _original_namespaces, +) + +_yaml = YAML() + +# --------------------------------------------------------------------------- +# Minimal valid CWL fixtures +# --------------------------------------------------------------------------- + +_SIMPLE_CWL = """\ +cwlVersion: v1.2 +class: CommandLineTool +id: simple-tool +baseCommand: echo +inputs: + msg: + type: string + inputBinding: + position: 1 +outputs: + out: + type: stdout +""" + +_GRAPH_CWL_WITH_CUSTOM_REQ = """\ +cwlVersion: v1.2 +$namespaces: + calrissian: "http://calrissian.example.com/" +$graph: +- class: CommandLineTool + id: echo-tool + baseCommand: echo + requirements: + DockerRequirement: + dockerPull: "alpine:latest" + calrissian:DaskGatewayRequirement: + gateway_url: "http://gateway.example.com" + worker_cores: 2 + worker_memory: "4G" + inputs: + message: + type: string + inputBinding: + position: 1 + outputs: + out: + type: stdout +- class: Workflow + id: main + inputs: + message: + type: string + outputs: + result: + type: File + outputSource: step1/out + steps: + step1: + run: "#echo-tool" + in: + message: message + out: [out] +""" + +_SINGLE_CWL_WITH_CUSTOM_REQ = """\ +cwlVersion: v1.2 +$namespaces: + calrissian: "http://calrissian.example.com/" +class: CommandLineTool +id: echo-tool +baseCommand: echo +requirements: + DockerRequirement: + dockerPull: "alpine:latest" + calrissian:DaskGatewayRequirement: + gateway_url: "http://gateway.example.com" + worker_cores: 2 + worker_memory: "4G" +inputs: + message: + type: string + inputBinding: + position: 1 +outputs: + out: + type: stdout +""" + + +class TestCustomRequirements(TestCase): + @classmethod + def setUpClass(cls): + process = load_cwl_from_string_content(_GRAPH_CWL_WITH_CUSTOM_REQ) + out = StringIO() + dump_cwl_with_custom_requirements(process, out) + cls._graph_data = _yaml.load(out.getvalue()) + + process = load_cwl_from_string_content(_SINGLE_CWL_WITH_CUSTOM_REQ) + out = StringIO() + dump_cwl_with_custom_requirements(process, out) + cls._single_data = _yaml.load(out.getvalue()) + + def test_graph_with_custom_req_parses_successfully(self): + result = load_cwl_from_string_content(_GRAPH_CWL_WITH_CUSTOM_REQ) + self.assertIsNotNone(result) + self.assertIsInstance(result, list) + self.assertEqual(2, len(result)) + + def test_single_process_with_custom_req_parses_successfully(self): + result = load_cwl_from_string_content(_SINGLE_CWL_WITH_CUSTOM_REQ) + self.assertIsNotNone(result) + self.assertIsInstance(result, Process) + + def test_graph_dump_roundtrip_restores_namespaces(self): + self.assertIn("$namespaces", self._graph_data) + self.assertIn("calrissian", self._graph_data["$namespaces"]) + + def test_graph_dump_roundtrip_reinjects_custom_req(self): + tool_item = next( + ( + item + for item in self._graph_data["$graph"] + if item.get("class") == "CommandLineTool" + ), + None, + ) + self.assertIsNotNone(tool_item) + reqs = tool_item.get("requirements", []) + dask_req = [r for r in reqs if "DaskGatewayRequirement" in r.get("class", "")] + self.assertEqual(1, len(dask_req)) + + def test_single_process_dump_roundtrip_reinjects_custom_req(self): + reqs = self._single_data.get("requirements", []) + dask_req = [r for r in reqs if "DaskGatewayRequirement" in r.get("class", "")] + self.assertEqual(1, len(dask_req)) + + def test_extract_dask_config_from_graph(self): + load_cwl_from_string_content(_GRAPH_CWL_WITH_CUSTOM_REQ) + config = extract_dask_config() + self.assertEqual("http://gateway.example.com", config.get("gateway_url")) + self.assertEqual(2, config.get("worker_cores")) + self.assertEqual("4G", config.get("worker_memory")) + + def test_extract_dask_config_from_single_process(self): + load_cwl_from_string_content(_SINGLE_CWL_WITH_CUSTOM_REQ) + config = extract_dask_config() + self.assertEqual("http://gateway.example.com", config.get("gateway_url")) + + def test_extract_dask_config_with_explicit_cache(self): + explicit_cache = { + "my-tool": { + "calrissian:DaskGatewayRequirement": { + "gateway_url": "http://explicit.example.com" + } + } + } + config = extract_dask_config(custom_requirements_cache=explicit_cache) + self.assertEqual("http://explicit.example.com", config.get("gateway_url")) + + def test_extract_dask_config_returns_empty_when_absent(self): + config = extract_dask_config(custom_requirements_cache={}) + self.assertEqual({}, config) + + def test_successive_loads_do_not_leak_cache(self): + load_cwl_from_string_content(_SINGLE_CWL_WITH_CUSTOM_REQ) + self.assertTrue( + len(_custom_requirements_cache) > 0, + "Cache should be populated after first load", + ) + + load_cwl_from_string_content(_SIMPLE_CWL) + self.assertEqual( + {}, + dict(_custom_requirements_cache), + "Cache must be empty after loading a doc without custom reqs", + ) + self.assertEqual( + {}, + dict(_original_namespaces), + "Namespace store must be empty after loading a doc without $namespaces", + )