From d9302c3395fa144b7e97a4ce0cd74b1d339816a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=A9rald=20Fenoy?= Date: Tue, 10 Feb 2026 17:32:02 +0100 Subject: [PATCH 1/4] Preserve namespaces and custom requirements Add support for calrissian:DaskGatewayRequirements --- src/cwl_loader/__init__.py | 200 ++++++++++++++++++++++++++++++++++++- src/cwl_loader/utils.py | 3 +- 2 files changed, 198 insertions(+), 5 deletions(-) diff --git a/src/cwl_loader/__init__.py b/src/cwl_loader/__init__.py index 4317fb4..7b9270a 100644 --- a/src/cwl_loader/__init__.py +++ b/src/cwl_loader/__init__.py @@ -59,6 +59,87 @@ _yaml = YAML() _global_loader = default_loader() +_custom_requirements_cache = {} +_original_namespaces = {} + +def _clean_custom_namespaces(raw_process: Mapping[str, Any] | CommentedMap) -> Mapping[str, Any] | CommentedMap: + """ + Remove $namespaces, schemas and custom requirements with namespace prefixes. + Store custom requirements in a cache for later retrieval. + + Note: $namespaces and schemas declarations don't cause validation errors by themselves. + Only the use of custom requirements (like calrissian:DaskGatewayRequirement) causes errors + in cwl_utils parser because it doesn't recognize them. + + Args: + raw_process: The raw CWL document as dict or CommentedMap + + Returns: + Cleaned CWL document without custom namespaced requirements + """ + cleaned = raw_process.copy() if isinstance(raw_process, dict) else CommentedMap(raw_process) + + if '$namespaces' in cleaned: + _original_namespaces['__root__'] = dict(cleaned['$namespaces']) + logger.debug(f"Saved original $namespaces: {_original_namespaces['__root__']}") + + # The custom requirements in $graph items + # We need to remove them for validation, but preserve them for Calrissian/custom runners + if '$graph' in cleaned and isinstance(cleaned['$graph'], list): + for item in cleaned['$graph']: + if isinstance(item, dict): + item_id = item.get('id', 'unknown') + + if 'requirements' in item and isinstance(item['requirements'], dict): + custom_reqs = {} + standard_reqs = {} + + for req_name, req_value in item['requirements'].items(): + if ':' in str(req_name): + logger.debug(f"Storing custom requirement for {item_id}: {req_name}") + custom_reqs[req_name] = req_value + else: + standard_reqs[req_name] = req_value + + if custom_reqs: + _custom_requirements_cache[item_id] = custom_reqs + + item['requirements'] = standard_reqs + + elif 'requirements' in item and isinstance(item['requirements'], list): + custom_reqs = [] + standard_reqs = [] + + for req in item['requirements']: + if isinstance(req, dict): + req_class = req.get('class', '') + if ':' in str(req_class): + logger.debug(f"Storing custom requirement for {item_id}: {req_class}") + custom_reqs.append(req) + else: + standard_reqs.append(req) + else: + standard_reqs.append(req) + + if custom_reqs: + _custom_requirements_cache[item_id] = custom_reqs + + item['requirements'] = standard_reqs + + return cleaned + +def get_custom_requirements(item_id: str) -> List[Any] | Mapping[str, Any]: + """ + Retrieve custom requirements for a given item ID. + + Args: + item_id: The ID of the CWL item (CommandLineTool, Workflow, etc.) + + Returns: + Custom requirements (list or dict) or empty list if none found + """ + return _custom_requirements_cache.get(item_id, []) + def _is_url(path_or_url: str) -> bool: try: result = urlparse(path_or_url) @@ -120,17 +201,21 @@ def load_cwl_from_yaml( `raw_process` (`dict`): The dictionary representing the CWL document `uri` (`Optional[str]`): The CWL document URI. Default to `io://` `cwl_version` (`Optional[str]`): The CWL document version. Default to `v1.2` + `sort` (`Optional[bool]`): Sort processes by dependencies. Default to `True` Returns: `Processes`: The parsed CWL Process or Processes (if the CWL document is a `$graph`). ''' - updated_process = raw_process + # Clean custom namespaces and requirements before processing + cleaned_process = _clean_custom_namespaces(raw_process) + + updated_process = cleaned_process - if cwl_version != raw_process[__CWL_VERSION__]: - logger.debug(f"Updating the model from version '{raw_process[__CWL_VERSION__]}' to version '{cwl_version}'...") + if cwl_version != cleaned_process[__CWL_VERSION__]: + logger.debug(f"Updating the model from version '{cleaned_process[__CWL_VERSION__]}' to version '{cwl_version}'...") updated_process = update( - doc=raw_process if isinstance(raw_process, CommentedMap) else CommentedMap(OrderedDict(raw_process)), + doc=cleaned_process if isinstance(cleaned_process, CommentedMap) else CommentedMap(OrderedDict(cleaned_process)), loader=_global_loader, baseuri=uri, enable_dev=False, @@ -306,3 +391,110 @@ def dump_cwl( ) _yaml.dump(data=data, stream=stream) + +def dump_cwl_with_custom_requirements( + process: Process | List[Process], + stream: TextIO, + custom_requirements_cache: Mapping[str, Any] | None = None +): + ''' + Serializes a CWL document with custom requirements properly reinjected into the requirements section. + + This function ensures that custom namespaced requirements (like calrissian:DaskGatewayRequirement) + are placed in the correct location within the 'requirements' section. + + Args: + `process` (`Processes`): The CWL Process or Processes (if the CWL document is a `$graph`) + `stream` (`Stream`): The stream where serializing the CWL document + `custom_requirements_cache` (`Mapping[str, Any]`, optional): Cache of custom requirements. + If None, uses the global _custom_requirements_cache. + + Returns: + `None`: none. + ''' + if custom_requirements_cache is None: + custom_requirements_cache = _custom_requirements_cache + + data = save( + val=process, # type: ignore + relative_uris=False + ) + + if '__root__' in _original_namespaces: + data['$namespaces'] = _original_namespaces['__root__'] + logger.debug(f"Restored original $namespaces: {data['$namespaces']}") + + if '$graph' in data and isinstance(data['$graph'], list): + for item in data['$graph']: + if isinstance(item, dict): + item_id = item.get('id') + + if 'cwlVersion' in item: + del item['cwlVersion'] + if '$namespaces' in item: + del item['$namespaces'] + + if item_id and item_id in custom_requirements_cache: + custom_reqs = custom_requirements_cache[item_id] + + if 'requirements' not in item: + item['requirements'] = [] + + if not isinstance(item['requirements'], list): + item['requirements'] = [] + + # Add custom requirements in the same format as standard requirements + if isinstance(custom_reqs, list): + for custom_req in custom_reqs: + item['requirements'].append(custom_req) + elif isinstance(custom_reqs, dict): + for req_name, req_value in custom_reqs.items(): + # Create a requirement dict in CWL format with 'class' field + custom_req_entry = {'class': req_name} + + # Add all properties from req_value + if isinstance(req_value, dict): + custom_req_entry.update(req_value) + + item['requirements'].append(custom_req_entry) + + _yaml.dump(data=data, stream=stream) + +def extract_dask_config( + custom_requirements_cache: Mapping[str, Any] | None = None +) -> Mapping[str, Any]: + ''' + Extracts Dask Gateway configuration from custom requirements cache. + + This utility function searches for DaskGatewayRequirement in the custom requirements + and returns a dictionary with the Dask configuration parameters. + + Args: + `custom_requirements_cache` (`Mapping[str, Any]`, optional): Cache of custom requirements. + If None, uses the global _custom_requirements_cache. + + Returns: + `Mapping[str, Any]`: Dictionary containing all fields found in the + DaskGatewayRequirement (except the `class` key when + the requirement is represented as a list item). + Returns empty dict if no DaskGatewayRequirement found. + ''' + if custom_requirements_cache is None: + custom_requirements_cache = _custom_requirements_cache + + for item_id, reqs in custom_requirements_cache.items(): + if isinstance(reqs, dict): + for req_name, req_value in reqs.items(): + if 'DaskGatewayRequirement' in req_name: + logger.debug(f"Found DaskGatewayRequirement in {item_id}") + return dict(req_value) if isinstance(req_value, dict) else {} + elif isinstance(reqs, list): + for req in reqs: + if isinstance(req, dict): + req_class = req.get('class', '') + if 'DaskGatewayRequirement' in req_class: + logger.debug(f"Found DaskGatewayRequirement in {item_id}") + return {k: v for k, v in req.items() if k != 'class'} + + logger.debug("No DaskGatewayRequirement found in custom requirements cache") + return {} diff --git a/src/cwl_loader/utils.py b/src/cwl_loader/utils.py index b160abe..c90d81e 100644 --- a/src/cwl_loader/utils.py +++ b/src/cwl_loader/utils.py @@ -118,7 +118,8 @@ def remove_refs( step.run = step.run[step.run.rfind('#'):] if getattr(step, 'scatter', None): - step.scatter = _clean_values(step.scatter, f"#{process.id}/") + cleaned_scatter = _clean_values(step.scatter, f"#{process.id}/") + step.scatter = _clean_values(cleaned_scatter, f"{step.id}/") if process.extension_fields and ORIGINAL_CWLVERSION in process.extension_fields: process.extension_fields.pop(ORIGINAL_CWLVERSION) From 4b05d8aa325d91c6461099bd9629b66798eb3adc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=A9rald=20Fenoy?= Date: Fri, 3 Apr 2026 10:02:15 +0200 Subject: [PATCH 2/4] Address Copilot feedbacks --- src/cwl_loader/__init__.py | 390 ++++++++++++++++++++---------- tests/test_custom_requirements.py | 198 +++++++++++++++ 2 files changed, 465 insertions(+), 123 deletions(-) create mode 100644 tests/test_custom_requirements.py diff --git a/src/cwl_loader/__init__.py b/src/cwl_loader/__init__.py index 7b9270a..4513bfe 100644 --- a/src/cwl_loader/__init__.py +++ b/src/cwl_loader/__init__.py @@ -42,12 +42,15 @@ Any, List, Mapping, - TextIO + Optional, + TextIO, + Tuple ) from urllib.parse import ( urlparse, urldefrag ) +import copy import requests import os @@ -59,78 +62,177 @@ _yaml = YAML() _global_loader = default_loader() -_custom_requirements_cache = {} -_original_namespaces = {} +# Module-level caches storing context from the most recent load. +# Cleared at the start of each top-level load (depth == 0) to prevent leaking +# state between successive loads. +_custom_requirements_cache: dict = {} +_original_namespaces: dict = {} +_load_depth: int = 0 -def _clean_custom_namespaces(raw_process: Mapping[str, Any] | CommentedMap) -> Mapping[str, Any] | CommentedMap: + +def _extract_custom_reqs_from_item( + item: dict, + item_id: str, + req_cache: dict +) -> None: + """ + Remove custom namespaced requirements from ``item['requirements']`` (and + ``item['hints']`` as fallback) in-place, storing them in *req_cache* keyed + by *item_id*. + + Custom requirements found in ``hints`` are also extracted so that they can + be re-injected into ``requirements`` by ``_inject_custom_reqs_into_item`` + (Calrissian's ``make_job_runner`` uses ``get_requirement()`` which searches + hints, but ``KubernetesDaskPodBuilder`` only reads ``requirements``). + + Handles both dict-form and list-form requirements/hints. + """ + collected: list = [] + + # --- process requirements --- + reqs = item.get('requirements') + if isinstance(reqs, dict): + custom_reqs: dict = {} + standard_reqs: dict = {} + for req_name, req_value in reqs.items(): + if ':' in str(req_name): + logger.debug(f"Storing custom requirement for {item_id}: {req_name}") + custom_reqs[req_name] = req_value + else: + standard_reqs[req_name] = req_value + if custom_reqs: + collected.append(('dict', custom_reqs)) + item['requirements'] = standard_reqs + elif isinstance(reqs, list): + custom_reqs_list: list = [] + standard_reqs_list: list = [] + for req in reqs: + if isinstance(req, dict): + req_class = req.get('class', '') + if ':' in str(req_class): + logger.debug(f"Storing custom requirement for {item_id}: {req_class}") + custom_reqs_list.append(req) + else: + standard_reqs_list.append(req) + else: + standard_reqs_list.append(req) + if custom_reqs_list: + collected.append(('list', custom_reqs_list)) + item['requirements'] = standard_reqs_list + + # --- process hints (fallback: custom reqs may have landed here) --- + hints = item.get('hints') + if isinstance(hints, list): + custom_hints: list = [] + standard_hints: list = [] + for hint in hints: + if isinstance(hint, dict): + hint_class = hint.get('class', '') + if ':' in str(hint_class): + logger.debug(f"Storing custom hint as requirement for {item_id}: {hint_class}") + custom_hints.append(hint) + else: + standard_hints.append(hint) + else: + standard_hints.append(hint) + if custom_hints: + collected.append(('list', custom_hints)) + item['hints'] = standard_hints + + # Merge all collected custom reqs into a single list for this item + if collected: + merged: list = [] + for form, data in collected: + if form == 'list': + merged.extend(data) + else: # dict form → convert to list form for uniform injection + for req_name, req_value in data.items(): + entry: dict = {'class': req_name} + if isinstance(req_value, dict): + entry.update(req_value) + merged.append(entry) + req_cache[item_id] = merged + + +def _clean_custom_namespaces( + raw_process: Mapping[str, Any] +) -> Tuple[Mapping[str, Any], dict, dict]: """ - Remove $namespaces, schemas and custom requirements with namespace prefixes. - Store custom requirements in a cache for later retrieval. + Extract custom namespaced requirements and record ``$namespaces`` for later + restoration. + + Custom requirements — those whose dict key (dict-form) or ``class`` value + (list-form) contains a colon — are removed so that the standard CWL parser + does not reject them. Both ``$graph`` documents and single top-level process + documents are handled. - Note: $namespaces and schemas declarations don't cause validation errors by themselves. - Only the use of custom requirements (like calrissian:DaskGatewayRequirement) causes errors - in cwl_utils parser because it doesn't recognize them. + The function never mutates *raw_process* or any of its nested objects. Args: - raw_process: The raw CWL document as dict or CommentedMap + raw_process: The raw CWL document as a plain dict or CommentedMap. Returns: - Cleaned CWL document without custom namespaced requirements + A 3-tuple ``(cleaned_doc, req_cache, ns_store)`` where: + + * ``cleaned_doc`` – a (deep-)copy of *raw_process* with custom namespaced + requirements removed from every process item. + * ``req_cache`` – ``{item_id: custom_reqs}`` mapping; *custom_reqs* is a + dict (dict-form source) or list (list-form source). + * ``ns_store`` – ``{'__root__': {…}}`` if ``$namespaces`` was present, + empty dict otherwise. """ - cleaned = raw_process.copy() if isinstance(raw_process, dict) else CommentedMap(raw_process) + # Shallow-copy the top level so we do not mutate the caller's mapping. + cleaned: Any = raw_process.copy() if isinstance(raw_process, dict) else CommentedMap(raw_process) + req_cache: dict = {} + ns_store: dict = {} if '$namespaces' in cleaned: - _original_namespaces['__root__'] = dict(cleaned['$namespaces']) - logger.debug(f"Saved original $namespaces: {_original_namespaces['__root__']}") + ns_store['__root__'] = dict(cleaned['$namespaces']) + logger.debug(f"Saved original $namespaces: {ns_store['__root__']}") - # The custom requirements in $graph items - # We need to remove them for validation, but preserve them for Calrissian/custom runners if '$graph' in cleaned and isinstance(cleaned['$graph'], list): + # Rebuild the $graph list using deep copies of each item so that we can + # mutate requirements without touching the caller's original objects. + new_graph = [] for item in cleaned['$graph']: if isinstance(item, dict): + item = copy.deepcopy(item) item_id = item.get('id', 'unknown') + _extract_custom_reqs_from_item(item, item_id, req_cache) + new_graph.append(item) + cleaned['$graph'] = new_graph + elif 'requirements' in cleaned: + # Single top-level process (CommandLineTool / Workflow / …). + # Deep-copy the entire cleaned document before mutating it. + cleaned = copy.deepcopy(cleaned) + item_id = cleaned.get('id', '__top__') + _extract_custom_reqs_from_item(cleaned, item_id, req_cache) - if 'requirements' in item and isinstance(item['requirements'], dict): - custom_reqs = {} - standard_reqs = {} - - for req_name, req_value in item['requirements'].items(): - if ':' in str(req_name): - logger.debug(f"Storing custom requirement for {item_id}: {req_name}") - custom_reqs[req_name] = req_value - else: - standard_reqs[req_name] = req_value - - if custom_reqs: - _custom_requirements_cache[item_id] = custom_reqs + return cleaned, req_cache, ns_store - item['requirements'] = standard_reqs - elif 'requirements' in item and isinstance(item['requirements'], list): - custom_reqs = [] - standard_reqs = [] - - for req in item['requirements']: - if isinstance(req, dict): - req_class = req.get('class', '') - if ':' in str(req_class): - logger.debug(f"Storing custom requirement for {item_id}: {req_class}") - custom_reqs.append(req) - else: - standard_reqs.append(req) - else: - standard_reqs.append(req) - - if custom_reqs: - _custom_requirements_cache[item_id] = custom_reqs +def _lookup_in_cache(item_id: Optional[str], cache: Mapping[str, Any]) -> Optional[Any]: + """ + Find *item_id* in *cache*, trying the full string first then progressively + shorter forms (fragment after ``#``, last path segment after ``/``). - item['requirements'] = standard_reqs + Returns the cached value or ``None`` if not found. + """ + if item_id is None: + return None + if item_id in cache: + return cache[item_id] + for sep in ('#', '/'): + if sep in str(item_id): + short = str(item_id).split(sep)[-1] + if short in cache: + return cache[short] + return None - return cleaned def get_custom_requirements(item_id: str) -> List[Any] | Mapping[str, Any]: """ - Retrieve custom requirements for a given item ID. + Retrieve custom requirements for a given item ID from the global cache. Args: item_id: The ID of the CWL item (CommandLineTool, Workflow, etc.) @@ -140,6 +242,7 @@ def get_custom_requirements(item_id: str) -> List[Any] | Mapping[str, Any]: """ return _custom_requirements_cache.get(item_id, []) + def _is_url(path_or_url: str) -> bool: try: result = urlparse(path_or_url) @@ -164,7 +267,7 @@ def _on_process( if run_url and not uri == run_url: referenced = load_cwl_from_location(run_url) - + if isinstance(referenced, list): accumulator += referenced @@ -206,65 +309,84 @@ def load_cwl_from_yaml( Returns: `Processes`: The parsed CWL Process or Processes (if the CWL document is a `$graph`). ''' - # Clean custom namespaces and requirements before processing - cleaned_process = _clean_custom_namespaces(raw_process) - - updated_process = cleaned_process + global _load_depth - if cwl_version != cleaned_process[__CWL_VERSION__]: - logger.debug(f"Updating the model from version '{cleaned_process[__CWL_VERSION__]}' to version '{cwl_version}'...") - - updated_process = update( - doc=cleaned_process if isinstance(cleaned_process, CommentedMap) else CommentedMap(OrderedDict(cleaned_process)), - loader=_global_loader, - baseuri=uri, - enable_dev=False, - metadata=CommentedMap(OrderedDict({'cwlVersion': cwl_version})), - update_to=cwl_version - ) + # At the top-level load (not a recursive call from _dereference_steps) clear + # the caches so that state from a previous load does not bleed into the + # current one. + if _load_depth == 0: + _custom_requirements_cache.clear() + _original_namespaces.clear() - logger.debug(f"Raw CWL document successfully updated to {cwl_version}!") - else: - logger.debug(f"No needs to update the Raw CWL document since it targets already the {cwl_version}") + _load_depth += 1 + try: + # Clean custom namespaces and requirements before processing. + # _clean_custom_namespaces never mutates raw_process and returns local dicts. + cleaned_process, local_req_cache, local_ns = _clean_custom_namespaces(raw_process) + + # Merge per-document caches into the module globals so they are accessible + # via get_custom_requirements / extract_dask_config without an explicit arg. + _custom_requirements_cache.update(local_req_cache) + _original_namespaces.update(local_ns) + + updated_process = cleaned_process + + if cwl_version != cleaned_process[__CWL_VERSION__]: + logger.debug(f"Updating the model from version '{cleaned_process[__CWL_VERSION__]}' to version '{cwl_version}'...") + + updated_process = update( + doc=cleaned_process if isinstance(cleaned_process, CommentedMap) else CommentedMap(OrderedDict(cleaned_process)), + loader=_global_loader, + baseuri=uri, + enable_dev=False, + metadata=CommentedMap(OrderedDict({'cwlVersion': cwl_version})), + update_to=cwl_version + ) + + logger.debug(f"Raw CWL document successfully updated to {cwl_version}!") + else: + logger.debug(f"No needs to update the Raw CWL document since it targets already the {cwl_version}") - logger.debug('Parsing the raw CWL document to the CWL Utils DOM...') + logger.debug('Parsing the raw CWL document to the CWL Utils DOM...') - clean_uri, fragment = urldefrag(uri) + clean_uri, fragment = urldefrag(uri) - if fragment: - logger.debug(f"Ignoring fragment #{fragment} from URI {clean_uri}") + if fragment: + logger.debug(f"Ignoring fragment #{fragment} from URI {clean_uri}") - process = load_document_by_yaml( - yaml=updated_process, - uri=clean_uri, - load_all=True - ) + process = load_document_by_yaml( + yaml=updated_process, + uri=clean_uri, + load_all=True + ) - logger.debug('Raw CWL document successfully parsed to the CWL Utils DOM!') + logger.debug('Raw CWL document successfully parsed to the CWL Utils DOM!') - logger.debug('Dereferencing the steps[].run...') + logger.debug('Dereferencing the steps[].run...') - dereferenced_process = _dereference_steps( - process=process, - uri=uri - ) + dereferenced_process = _dereference_steps( + process=process, + uri=uri + ) - logger.debug('steps[].run successfully dereferenced! Dereferencing the FQNs...') + logger.debug('steps[].run successfully dereferenced! Dereferencing the FQNs...') - remove_refs(dereferenced_process) + remove_refs(dereferenced_process) - logger.debug('CWL document successfully dereferenced! Now verifying steps[].run integrity...') + logger.debug('CWL document successfully dereferenced! Now verifying steps[].run integrity...') - assert_connected_graph(dereferenced_process) + assert_connected_graph(dereferenced_process) - logger.debug('All steps[].run link are resolvable! ') + logger.debug('All steps[].run link are resolvable! ') - if sort: - logger.debug('Sorting Process instances by dependencies....') - dereferenced_process = order_graph_by_dependencies(dereferenced_process) - logger.debug('Sorting process is over.') + if sort: + logger.debug('Sorting Process instances by dependencies....') + dereferenced_process = order_graph_by_dependencies(dereferenced_process) + logger.debug('Sorting process is over.') - return dereferenced_process if len(dereferenced_process) > 1 else dereferenced_process[0] + return dereferenced_process if len(dereferenced_process) > 1 else dereferenced_process[0] + finally: + _load_depth -= 1 def load_cwl_from_stream( content: TextIO, @@ -392,10 +514,39 @@ def dump_cwl( _yaml.dump(data=data, stream=stream) +def _inject_custom_reqs_into_item(item: dict, custom_reqs: Any) -> None: + """ + Reinject *custom_reqs* (list or dict form) into ``item['requirements']``. + + All custom requirements (including calrissian:DaskGatewayRequirement) are + injected into ``requirements`` so that Calrissian can find them — it reads + DaskGatewayRequirement from ``requirements``, not ``hints``. + """ + if 'requirements' not in item or not isinstance(item['requirements'], list): + item['requirements'] = [] + + if isinstance(custom_reqs, list): + for custom_req in custom_reqs: + item['requirements'].append(custom_req) + elif isinstance(custom_reqs, dict): + for req_name, req_value in custom_reqs.items(): + custom_req_entry: dict = {'class': req_name} + if isinstance(req_value, dict): + custom_req_entry.update(req_value) + elif req_value is not None: + logger.warning( + f"Custom requirement '{req_name}' has a non-mapping value " + f"{req_value!r}; only the 'class' key will be emitted in the " + "serialised output." + ) + item['requirements'].append(custom_req_entry) + + def dump_cwl_with_custom_requirements( process: Process | List[Process], stream: TextIO, - custom_requirements_cache: Mapping[str, Any] | None = None + custom_requirements_cache: Optional[Mapping[str, Any]] = None, + original_namespaces: Optional[Mapping[str, Any]] = None ): ''' Serializes a CWL document with custom requirements properly reinjected into the requirements section. @@ -407,21 +558,25 @@ def dump_cwl_with_custom_requirements( `process` (`Processes`): The CWL Process or Processes (if the CWL document is a `$graph`) `stream` (`Stream`): The stream where serializing the CWL document `custom_requirements_cache` (`Mapping[str, Any]`, optional): Cache of custom requirements. - If None, uses the global _custom_requirements_cache. + If None, uses the module-level cache. + `original_namespaces` (`Mapping[str, Any]`, optional): Saved $namespaces mapping. + If None, uses the module-level cache. Returns: `None`: none. ''' if custom_requirements_cache is None: custom_requirements_cache = _custom_requirements_cache + if original_namespaces is None: + original_namespaces = _original_namespaces data = save( val=process, # type: ignore relative_uris=False ) - if '__root__' in _original_namespaces: - data['$namespaces'] = _original_namespaces['__root__'] + if '__root__' in original_namespaces: + data['$namespaces'] = original_namespaces['__root__'] logger.debug(f"Restored original $namespaces: {data['$namespaces']}") if '$graph' in data and isinstance(data['$graph'], list): @@ -434,34 +589,23 @@ def dump_cwl_with_custom_requirements( if '$namespaces' in item: del item['$namespaces'] - if item_id and item_id in custom_requirements_cache: - custom_reqs = custom_requirements_cache[item_id] - - if 'requirements' not in item: - item['requirements'] = [] - - if not isinstance(item['requirements'], list): - item['requirements'] = [] - - # Add custom requirements in the same format as standard requirements - if isinstance(custom_reqs, list): - for custom_req in custom_reqs: - item['requirements'].append(custom_req) - elif isinstance(custom_reqs, dict): - for req_name, req_value in custom_reqs.items(): - # Create a requirement dict in CWL format with 'class' field - custom_req_entry = {'class': req_name} - - # Add all properties from req_value - if isinstance(req_value, dict): - custom_req_entry.update(req_value) - - item['requirements'].append(custom_req_entry) + custom_reqs = _lookup_in_cache(item_id, custom_requirements_cache) + if custom_reqs is not None: + _inject_custom_reqs_into_item(item, custom_reqs) + else: + # Single top-level process (no $graph wrapper). + item_id = data.get('id') if isinstance(data, dict) else None + custom_reqs = _lookup_in_cache(item_id, custom_requirements_cache) + if custom_reqs is None: + # Fallback: top-level processes without an id were cached under '__top__'. + custom_reqs = custom_requirements_cache.get('__top__') + if custom_reqs is not None and isinstance(data, dict): + _inject_custom_reqs_into_item(data, custom_reqs) _yaml.dump(data=data, stream=stream) def extract_dask_config( - custom_requirements_cache: Mapping[str, Any] | None = None + custom_requirements_cache: Optional[Mapping[str, Any]] = None ) -> Mapping[str, Any]: ''' Extracts Dask Gateway configuration from custom requirements cache. @@ -471,7 +615,7 @@ def extract_dask_config( Args: `custom_requirements_cache` (`Mapping[str, Any]`, optional): Cache of custom requirements. - If None, uses the global _custom_requirements_cache. + If None, uses the module-level cache. Returns: `Mapping[str, Any]`: Dictionary containing all fields found in the diff --git a/tests/test_custom_requirements.py b/tests/test_custom_requirements.py new file mode 100644 index 0000000..307feb7 --- /dev/null +++ b/tests/test_custom_requirements.py @@ -0,0 +1,198 @@ +# Copyright 2026 Terradue +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from io import StringIO +from unittest import TestCase + +from cwl_utils.parser import Process +from ruamel.yaml import YAML + +from cwl_loader import ( + dump_cwl_with_custom_requirements, + extract_dask_config, + load_cwl_from_string_content, + _custom_requirements_cache, + _original_namespaces, +) + +_yaml = YAML() + +# --------------------------------------------------------------------------- +# Minimal valid CWL fixtures +# --------------------------------------------------------------------------- + +_SIMPLE_CWL = """\ +cwlVersion: v1.2 +class: CommandLineTool +id: simple-tool +baseCommand: echo +inputs: + msg: + type: string + inputBinding: + position: 1 +outputs: + out: + type: stdout +""" + +_GRAPH_CWL_WITH_CUSTOM_REQ = """\ +cwlVersion: v1.2 +$namespaces: + calrissian: "http://calrissian.example.com/" +$graph: +- class: CommandLineTool + id: echo-tool + baseCommand: echo + requirements: + DockerRequirement: + dockerPull: "alpine:latest" + calrissian:DaskGatewayRequirement: + gateway_url: "http://gateway.example.com" + worker_cores: 2 + worker_memory: "4G" + inputs: + message: + type: string + inputBinding: + position: 1 + outputs: + out: + type: stdout +- class: Workflow + id: main + inputs: + message: + type: string + outputs: + result: + type: File + outputSource: step1/out + steps: + step1: + run: "#echo-tool" + in: + message: message + out: [out] +""" + +_SINGLE_CWL_WITH_CUSTOM_REQ = """\ +cwlVersion: v1.2 +$namespaces: + calrissian: "http://calrissian.example.com/" +class: CommandLineTool +id: echo-tool +baseCommand: echo +requirements: + DockerRequirement: + dockerPull: "alpine:latest" + calrissian:DaskGatewayRequirement: + gateway_url: "http://gateway.example.com" + worker_cores: 2 + worker_memory: "4G" +inputs: + message: + type: string + inputBinding: + position: 1 +outputs: + out: + type: stdout +""" + + +class TestCustomRequirements(TestCase): + @classmethod + def setUpClass(cls): + process = load_cwl_from_string_content(_GRAPH_CWL_WITH_CUSTOM_REQ) + out = StringIO() + dump_cwl_with_custom_requirements(process, out) + cls._graph_data = _yaml.load(out.getvalue()) + + process = load_cwl_from_string_content(_SINGLE_CWL_WITH_CUSTOM_REQ) + out = StringIO() + dump_cwl_with_custom_requirements(process, out) + cls._single_data = _yaml.load(out.getvalue()) + + def test_graph_with_custom_req_parses_successfully(self): + result = load_cwl_from_string_content(_GRAPH_CWL_WITH_CUSTOM_REQ) + self.assertIsNotNone(result) + self.assertIsInstance(result, list) + self.assertEqual(2, len(result)) + + def test_single_process_with_custom_req_parses_successfully(self): + result = load_cwl_from_string_content(_SINGLE_CWL_WITH_CUSTOM_REQ) + self.assertIsNotNone(result) + self.assertIsInstance(result, Process) + + def test_graph_dump_roundtrip_restores_namespaces(self): + self.assertIn("$namespaces", self._graph_data) + self.assertIn("calrissian", self._graph_data["$namespaces"]) + + def test_graph_dump_roundtrip_reinjects_custom_req(self): + tool_item = next( + (item for item in self._graph_data["$graph"] if item.get("class") == "CommandLineTool"), + None, + ) + self.assertIsNotNone(tool_item) + reqs = tool_item.get("requirements", []) + dask_req = [r for r in reqs if "DaskGatewayRequirement" in r.get("class", "")] + self.assertEqual(1, len(dask_req)) + + def test_single_process_dump_roundtrip_reinjects_custom_req(self): + reqs = self._single_data.get("requirements", []) + dask_req = [r for r in reqs if "DaskGatewayRequirement" in r.get("class", "")] + self.assertEqual(1, len(dask_req)) + + def test_extract_dask_config_from_graph(self): + load_cwl_from_string_content(_GRAPH_CWL_WITH_CUSTOM_REQ) + config = extract_dask_config() + self.assertEqual("http://gateway.example.com", config.get("gateway_url")) + self.assertEqual(2, config.get("worker_cores")) + self.assertEqual("4G", config.get("worker_memory")) + + def test_extract_dask_config_from_single_process(self): + load_cwl_from_string_content(_SINGLE_CWL_WITH_CUSTOM_REQ) + config = extract_dask_config() + self.assertEqual("http://gateway.example.com", config.get("gateway_url")) + + def test_extract_dask_config_with_explicit_cache(self): + explicit_cache = { + "my-tool": {"calrissian:DaskGatewayRequirement": {"gateway_url": "http://explicit.example.com"}} + } + config = extract_dask_config(custom_requirements_cache=explicit_cache) + self.assertEqual("http://explicit.example.com", config.get("gateway_url")) + + def test_extract_dask_config_returns_empty_when_absent(self): + config = extract_dask_config(custom_requirements_cache={}) + self.assertEqual({}, config) + + def test_successive_loads_do_not_leak_cache(self): + load_cwl_from_string_content(_SINGLE_CWL_WITH_CUSTOM_REQ) + self.assertTrue( + len(_custom_requirements_cache) > 0, + "Cache should be populated after first load", + ) + + load_cwl_from_string_content(_SIMPLE_CWL) + self.assertEqual( + {}, + dict(_custom_requirements_cache), + "Cache must be empty after loading a doc without custom reqs", + ) + self.assertEqual( + {}, + dict(_original_namespaces), + "Namespace store must be empty after loading a doc without $namespaces", + ) From 5ed8bcc571a3e82fe6ef8a2a06cb833b5d74e4ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=A9rald=20Fenoy?= Date: Fri, 3 Apr 2026 10:20:05 +0200 Subject: [PATCH 3/4] Pass ruff test --- src/cwl_loader/__init__.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/src/cwl_loader/__init__.py b/src/cwl_loader/__init__.py index d4d4001..e509ab3 100644 --- a/src/cwl_loader/__init__.py +++ b/src/cwl_loader/__init__.py @@ -24,10 +24,6 @@ from loguru import logger from ruamel.yaml import YAML from ruamel.yaml.comments import CommentedMap -from ruamel.yaml.scalarstring import ScalarString -from ruamel.yaml.scalarfloat import ScalarFloat -from ruamel.yaml.scalarint import ScalarInt -from pathlib import Path from typing import ( Any, List, @@ -151,8 +147,8 @@ def _clean_custom_namespaces( Extract custom namespaced requirements and record ``$namespaces`` for later restoration. - Custom requirements — those whose dict key (dict-form) or ``class`` value - (list-form) contains a colon — are removed so that the standard CWL parser + Custom requirements - those whose dict key (dict-form) or ``class`` value + (list-form) contains a colon - are removed so that the standard CWL parser does not reject them. Both ``$graph`` documents and single top-level process documents are handled. @@ -296,7 +292,7 @@ def load_cwl_from_yaml( Returns: `Processes`: The parsed CWL Process or Processes (if the CWL document is a `$graph`). - ''' + """ global _load_depth # At the top-level load (not a recursive call from _dereference_steps) clear @@ -500,7 +496,7 @@ def _inject_custom_reqs_into_item(item: dict, custom_reqs: Any) -> None: Reinject *custom_reqs* (list or dict form) into ``item['requirements']``. All custom requirements (including calrissian:DaskGatewayRequirement) are - injected into ``requirements`` so that Calrissian can find them — it reads + injected into ``requirements`` so that Calrissian can find them - it reads DaskGatewayRequirement from ``requirements``, not ``hints``. """ if 'requirements' not in item or not isinstance(item['requirements'], list): @@ -529,7 +525,7 @@ def dump_cwl_with_custom_requirements( custom_requirements_cache: Optional[Mapping[str, Any]] = None, original_namespaces: Optional[Mapping[str, Any]] = None ): - ''' + """ Serializes a CWL document with custom requirements properly reinjected into the requirements section. This function ensures that custom namespaced requirements (like calrissian:DaskGatewayRequirement) @@ -545,7 +541,7 @@ def dump_cwl_with_custom_requirements( Returns: `None`: none. - ''' + """ if custom_requirements_cache is None: custom_requirements_cache = _custom_requirements_cache if original_namespaces is None: @@ -588,7 +584,7 @@ def dump_cwl_with_custom_requirements( def extract_dask_config( custom_requirements_cache: Optional[Mapping[str, Any]] = None ) -> Mapping[str, Any]: - ''' + """ Extracts Dask Gateway configuration from custom requirements cache. This utility function searches for DaskGatewayRequirement in the custom requirements @@ -603,7 +599,7 @@ def extract_dask_config( DaskGatewayRequirement (except the `class` key when the requirement is represented as a list item). Returns empty dict if no DaskGatewayRequirement found. - ''' + """ if custom_requirements_cache is None: custom_requirements_cache = _custom_requirements_cache From d318023ab106f1275cd78e8513e893b6fa798f33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=A9rald=20Fenoy?= Date: Fri, 3 Apr 2026 10:24:36 +0200 Subject: [PATCH 4/4] Use ruff lint before pushing --- src/cwl_loader/__init__.py | 188 +++++++++++++++--------------- src/cwl_loader/utils.py | 4 +- tests/test_custom_requirements.py | 12 +- 3 files changed, 109 insertions(+), 95 deletions(-) diff --git a/src/cwl_loader/__init__.py b/src/cwl_loader/__init__.py index e509ab3..bed0832 100644 --- a/src/cwl_loader/__init__.py +++ b/src/cwl_loader/__init__.py @@ -24,18 +24,8 @@ from loguru import logger from ruamel.yaml import YAML from ruamel.yaml.comments import CommentedMap -from typing import ( - Any, - List, - Mapping, - Optional, - TextIO, - Tuple -) -from urllib.parse import ( - urlparse, - urldefrag -) +from typing import Any, List, Mapping, Optional, TextIO, Tuple +from urllib.parse import urlparse, urldefrag import copy import requests import os @@ -56,11 +46,7 @@ _load_depth: int = 0 -def _extract_custom_reqs_from_item( - item: dict, - item_id: str, - req_cache: dict -) -> None: +def _extract_custom_reqs_from_item(item: dict, item_id: str, req_cache: dict) -> None: """ Remove custom namespaced requirements from ``item['requirements']`` (and ``item['hints']`` as fallback) in-place, storing them in *req_cache* keyed @@ -76,64 +62,68 @@ def _extract_custom_reqs_from_item( collected: list = [] # --- process requirements --- - reqs = item.get('requirements') + reqs = item.get("requirements") if isinstance(reqs, dict): custom_reqs: dict = {} standard_reqs: dict = {} for req_name, req_value in reqs.items(): - if ':' in str(req_name): + if ":" in str(req_name): logger.debug(f"Storing custom requirement for {item_id}: {req_name}") custom_reqs[req_name] = req_value else: standard_reqs[req_name] = req_value if custom_reqs: - collected.append(('dict', custom_reqs)) - item['requirements'] = standard_reqs + collected.append(("dict", custom_reqs)) + item["requirements"] = standard_reqs elif isinstance(reqs, list): custom_reqs_list: list = [] standard_reqs_list: list = [] for req in reqs: if isinstance(req, dict): - req_class = req.get('class', '') - if ':' in str(req_class): - logger.debug(f"Storing custom requirement for {item_id}: {req_class}") + req_class = req.get("class", "") + if ":" in str(req_class): + logger.debug( + f"Storing custom requirement for {item_id}: {req_class}" + ) custom_reqs_list.append(req) else: standard_reqs_list.append(req) else: standard_reqs_list.append(req) if custom_reqs_list: - collected.append(('list', custom_reqs_list)) - item['requirements'] = standard_reqs_list + collected.append(("list", custom_reqs_list)) + item["requirements"] = standard_reqs_list # --- process hints (fallback: custom reqs may have landed here) --- - hints = item.get('hints') + hints = item.get("hints") if isinstance(hints, list): custom_hints: list = [] standard_hints: list = [] for hint in hints: if isinstance(hint, dict): - hint_class = hint.get('class', '') - if ':' in str(hint_class): - logger.debug(f"Storing custom hint as requirement for {item_id}: {hint_class}") + hint_class = hint.get("class", "") + if ":" in str(hint_class): + logger.debug( + f"Storing custom hint as requirement for {item_id}: {hint_class}" + ) custom_hints.append(hint) else: standard_hints.append(hint) else: standard_hints.append(hint) if custom_hints: - collected.append(('list', custom_hints)) - item['hints'] = standard_hints + collected.append(("list", custom_hints)) + item["hints"] = standard_hints # Merge all collected custom reqs into a single list for this item if collected: merged: list = [] for form, data in collected: - if form == 'list': + if form == "list": merged.extend(data) else: # dict form → convert to list form for uniform injection for req_name, req_value in data.items(): - entry: dict = {'class': req_name} + entry: dict = {"class": req_name} if isinstance(req_value, dict): entry.update(req_value) merged.append(entry) @@ -141,7 +131,7 @@ def _extract_custom_reqs_from_item( def _clean_custom_namespaces( - raw_process: Mapping[str, Any] + raw_process: Mapping[str, Any], ) -> Tuple[Mapping[str, Any], dict, dict]: """ Extract custom namespaced requirements and record ``$namespaces`` for later @@ -168,30 +158,34 @@ def _clean_custom_namespaces( empty dict otherwise. """ # Shallow-copy the top level so we do not mutate the caller's mapping. - cleaned: Any = raw_process.copy() if isinstance(raw_process, dict) else CommentedMap(raw_process) + cleaned: Any = ( + raw_process.copy() + if isinstance(raw_process, dict) + else CommentedMap(raw_process) + ) req_cache: dict = {} ns_store: dict = {} - if '$namespaces' in cleaned: - ns_store['__root__'] = dict(cleaned['$namespaces']) + if "$namespaces" in cleaned: + ns_store["__root__"] = dict(cleaned["$namespaces"]) logger.debug(f"Saved original $namespaces: {ns_store['__root__']}") - if '$graph' in cleaned and isinstance(cleaned['$graph'], list): + if "$graph" in cleaned and isinstance(cleaned["$graph"], list): # Rebuild the $graph list using deep copies of each item so that we can # mutate requirements without touching the caller's original objects. new_graph = [] - for item in cleaned['$graph']: + for item in cleaned["$graph"]: if isinstance(item, dict): item = copy.deepcopy(item) - item_id = item.get('id', 'unknown') + item_id = item.get("id", "unknown") _extract_custom_reqs_from_item(item, item_id, req_cache) new_graph.append(item) - cleaned['$graph'] = new_graph - elif 'requirements' in cleaned: + cleaned["$graph"] = new_graph + elif "requirements" in cleaned: # Single top-level process (CommandLineTool / Workflow / …). # Deep-copy the entire cleaned document before mutating it. cleaned = copy.deepcopy(cleaned) - item_id = cleaned.get('id', '__top__') + item_id = cleaned.get("id", "__top__") _extract_custom_reqs_from_item(cleaned, item_id, req_cache) return cleaned, req_cache, ns_store @@ -208,7 +202,7 @@ def _lookup_in_cache(item_id: Optional[str], cache: Mapping[str, Any]) -> Option return None if item_id in cache: return cache[item_id] - for sep in ('#', '/'): + for sep in ("#", "/"): if sep in str(item_id): short = str(item_id).split(sep)[-1] if short in cache: @@ -306,7 +300,9 @@ def load_cwl_from_yaml( try: # Clean custom namespaces and requirements before processing. # _clean_custom_namespaces never mutates raw_process and returns local dicts. - cleaned_process, local_req_cache, local_ns = _clean_custom_namespaces(raw_process) + cleaned_process, local_req_cache, local_ns = _clean_custom_namespaces( + raw_process + ) # Merge per-document caches into the module globals so they are accessible # via get_custom_requirements / extract_dask_config without an explicit arg. @@ -316,22 +312,28 @@ def load_cwl_from_yaml( updated_process = cleaned_process if cwl_version != cleaned_process[__CWL_VERSION__]: - logger.debug(f"Updating the model from version '{cleaned_process[__CWL_VERSION__]}' to version '{cwl_version}'...") + logger.debug( + f"Updating the model from version '{cleaned_process[__CWL_VERSION__]}' to version '{cwl_version}'..." + ) updated_process = update( - doc=cleaned_process if isinstance(cleaned_process, CommentedMap) else CommentedMap(OrderedDict(cleaned_process)), + doc=cleaned_process + if isinstance(cleaned_process, CommentedMap) + else CommentedMap(OrderedDict(cleaned_process)), loader=_global_loader, baseuri=uri, enable_dev=False, - metadata=CommentedMap(OrderedDict({'cwlVersion': cwl_version})), - update_to=cwl_version + metadata=CommentedMap(OrderedDict({"cwlVersion": cwl_version})), + update_to=cwl_version, ) logger.debug(f"Raw CWL document successfully updated to {cwl_version}!") else: - logger.debug(f"No needs to update the Raw CWL document since it targets already the {cwl_version}") + logger.debug( + f"No needs to update the Raw CWL document since it targets already the {cwl_version}" + ) - logger.debug('Parsing the raw CWL document to the CWL Utils DOM...') + logger.debug("Parsing the raw CWL document to the CWL Utils DOM...") clean_uri, fragment = urldefrag(uri) @@ -339,39 +341,41 @@ def load_cwl_from_yaml( logger.debug(f"Ignoring fragment #{fragment} from URI {clean_uri}") process = load_document_by_yaml( - yaml=updated_process, - uri=clean_uri, - load_all=True + yaml=updated_process, uri=clean_uri, load_all=True ) - logger.debug('Raw CWL document successfully parsed to the CWL Utils DOM!') + logger.debug("Raw CWL document successfully parsed to the CWL Utils DOM!") - logger.debug('Dereferencing the steps[].run...') + logger.debug("Dereferencing the steps[].run...") - dereferenced_process = _dereference_steps( - process=process, - uri=uri - ) + dereferenced_process = _dereference_steps(process=process, uri=uri) - logger.debug('steps[].run successfully dereferenced! Dereferencing the FQNs...') + logger.debug("steps[].run successfully dereferenced! Dereferencing the FQNs...") remove_refs(dereferenced_process) - logger.debug('CWL document successfully dereferenced! Now verifying steps[].run integrity...') + logger.debug( + "CWL document successfully dereferenced! Now verifying steps[].run integrity..." + ) assert_connected_graph(dereferenced_process) - logger.debug('All steps[].run link are resolvable! ') + logger.debug("All steps[].run link are resolvable! ") if sort: - logger.debug('Sorting Process instances by dependencies....') + logger.debug("Sorting Process instances by dependencies....") dereferenced_process = order_graph_by_dependencies(dereferenced_process) - logger.debug('Sorting process is over.') + logger.debug("Sorting process is over.") - return dereferenced_process if len(dereferenced_process) > 1 else dereferenced_process[0] + return ( + dereferenced_process + if len(dereferenced_process) > 1 + else dereferenced_process[0] + ) finally: _load_depth -= 1 + def load_cwl_from_stream( content: TextIO, uri: str = __DEFAULT_BASE_URI__, @@ -491,6 +495,7 @@ def dump_cwl(process: Process | List[Process], stream: TextIO): _yaml.dump(data=data, stream=stream) + def _inject_custom_reqs_into_item(item: dict, custom_reqs: Any) -> None: """ Reinject *custom_reqs* (list or dict form) into ``item['requirements']``. @@ -499,15 +504,15 @@ def _inject_custom_reqs_into_item(item: dict, custom_reqs: Any) -> None: injected into ``requirements`` so that Calrissian can find them - it reads DaskGatewayRequirement from ``requirements``, not ``hints``. """ - if 'requirements' not in item or not isinstance(item['requirements'], list): - item['requirements'] = [] + if "requirements" not in item or not isinstance(item["requirements"], list): + item["requirements"] = [] if isinstance(custom_reqs, list): for custom_req in custom_reqs: - item['requirements'].append(custom_req) + item["requirements"].append(custom_req) elif isinstance(custom_reqs, dict): for req_name, req_value in custom_reqs.items(): - custom_req_entry: dict = {'class': req_name} + custom_req_entry: dict = {"class": req_name} if isinstance(req_value, dict): custom_req_entry.update(req_value) elif req_value is not None: @@ -516,14 +521,14 @@ def _inject_custom_reqs_into_item(item: dict, custom_reqs: Any) -> None: f"{req_value!r}; only the 'class' key will be emitted in the " "serialised output." ) - item['requirements'].append(custom_req_entry) + item["requirements"].append(custom_req_entry) def dump_cwl_with_custom_requirements( process: Process | List[Process], stream: TextIO, custom_requirements_cache: Optional[Mapping[str, Any]] = None, - original_namespaces: Optional[Mapping[str, Any]] = None + original_namespaces: Optional[Mapping[str, Any]] = None, ): """ Serializes a CWL document with custom requirements properly reinjected into the requirements section. @@ -548,41 +553,42 @@ def dump_cwl_with_custom_requirements( original_namespaces = _original_namespaces data = save( - val=process, # type: ignore - relative_uris=False + val=process, # type: ignore + relative_uris=False, ) - if '__root__' in original_namespaces: - data['$namespaces'] = original_namespaces['__root__'] + if "__root__" in original_namespaces: + data["$namespaces"] = original_namespaces["__root__"] logger.debug(f"Restored original $namespaces: {data['$namespaces']}") - if '$graph' in data and isinstance(data['$graph'], list): - for item in data['$graph']: + if "$graph" in data and isinstance(data["$graph"], list): + for item in data["$graph"]: if isinstance(item, dict): - item_id = item.get('id') + item_id = item.get("id") - if 'cwlVersion' in item: - del item['cwlVersion'] - if '$namespaces' in item: - del item['$namespaces'] + if "cwlVersion" in item: + del item["cwlVersion"] + if "$namespaces" in item: + del item["$namespaces"] custom_reqs = _lookup_in_cache(item_id, custom_requirements_cache) if custom_reqs is not None: _inject_custom_reqs_into_item(item, custom_reqs) else: # Single top-level process (no $graph wrapper). - item_id = data.get('id') if isinstance(data, dict) else None + item_id = data.get("id") if isinstance(data, dict) else None custom_reqs = _lookup_in_cache(item_id, custom_requirements_cache) if custom_reqs is None: # Fallback: top-level processes without an id were cached under '__top__'. - custom_reqs = custom_requirements_cache.get('__top__') + custom_reqs = custom_requirements_cache.get("__top__") if custom_reqs is not None and isinstance(data, dict): _inject_custom_reqs_into_item(data, custom_reqs) _yaml.dump(data=data, stream=stream) + def extract_dask_config( - custom_requirements_cache: Optional[Mapping[str, Any]] = None + custom_requirements_cache: Optional[Mapping[str, Any]] = None, ) -> Mapping[str, Any]: """ Extracts Dask Gateway configuration from custom requirements cache. @@ -606,16 +612,16 @@ def extract_dask_config( for item_id, reqs in custom_requirements_cache.items(): if isinstance(reqs, dict): for req_name, req_value in reqs.items(): - if 'DaskGatewayRequirement' in req_name: + if "DaskGatewayRequirement" in req_name: logger.debug(f"Found DaskGatewayRequirement in {item_id}") return dict(req_value) if isinstance(req_value, dict) else {} elif isinstance(reqs, list): for req in reqs: if isinstance(req, dict): - req_class = req.get('class', '') - if 'DaskGatewayRequirement' in req_class: + req_class = req.get("class", "") + if "DaskGatewayRequirement" in req_class: logger.debug(f"Found DaskGatewayRequirement in {item_id}") - return {k: v for k, v in req.items() if k != 'class'} + return {k: v for k, v in req.items() if k != "class"} logger.debug("No DaskGatewayRequirement found in custom requirements cache") return {} diff --git a/src/cwl_loader/utils.py b/src/cwl_loader/utils.py index 8b59f75..3ca30d9 100644 --- a/src/cwl_loader/utils.py +++ b/src/cwl_loader/utils.py @@ -95,10 +95,10 @@ def remove_refs(process: Process | List[Process]): if getattr(step, "run", None): step.run = step.run[step.run.rfind("#") :] - if getattr(step, 'scatter', None): + if getattr(step, "scatter", None): cleaned_scatter = _clean_values(step.scatter, f"#{process.id}/") step.scatter = _clean_values(cleaned_scatter, f"{step.id}/") - + if process.extension_fields and ORIGINAL_CWLVERSION in process.extension_fields: process.extension_fields.pop(ORIGINAL_CWLVERSION) diff --git a/tests/test_custom_requirements.py b/tests/test_custom_requirements.py index 307feb7..a30ecff 100644 --- a/tests/test_custom_requirements.py +++ b/tests/test_custom_requirements.py @@ -142,7 +142,11 @@ def test_graph_dump_roundtrip_restores_namespaces(self): def test_graph_dump_roundtrip_reinjects_custom_req(self): tool_item = next( - (item for item in self._graph_data["$graph"] if item.get("class") == "CommandLineTool"), + ( + item + for item in self._graph_data["$graph"] + if item.get("class") == "CommandLineTool" + ), None, ) self.assertIsNotNone(tool_item) @@ -169,7 +173,11 @@ def test_extract_dask_config_from_single_process(self): def test_extract_dask_config_with_explicit_cache(self): explicit_cache = { - "my-tool": {"calrissian:DaskGatewayRequirement": {"gateway_url": "http://explicit.example.com"}} + "my-tool": { + "calrissian:DaskGatewayRequirement": { + "gateway_url": "http://explicit.example.com" + } + } } config = extract_dask_config(custom_requirements_cache=explicit_cache) self.assertEqual("http://explicit.example.com", config.get("gateway_url"))