From 27d14a5b787c6eb9340097e6fce71707e645f45d Mon Sep 17 00:00:00 2001 From: Alex Santisteban Date: Wed, 26 Feb 2025 11:13:09 +0100 Subject: [PATCH 01/10] feat: add line numbers to validation errors --- craft_application/application.py | 12 ++++--- craft_application/errors.py | 5 +-- craft_application/models/base.py | 5 +-- craft_application/util/__init__.py | 3 +- craft_application/util/error_formatting.py | 22 ++++++++++--- craft_application/util/yaml.py | 38 ++++++++++++++++++++-- 6 files changed, 68 insertions(+), 17 deletions(-) diff --git a/craft_application/application.py b/craft_application/application.py index 71fa7ff2b..73367a5ad 100644 --- a/craft_application/application.py +++ b/craft_application/application.py @@ -29,6 +29,7 @@ from functools import cached_property from importlib import metadata from typing import TYPE_CHECKING, Any, cast, final +from pydantic.v1.utils import deep_update import craft_cli import craft_parts @@ -380,7 +381,8 @@ def get_project( craft_cli.emit.debug(f"Loading project file '{project_path!s}'") with project_path.open() as file: - yaml_data = util.safe_yaml_load(file) + yaml_data = util.safe_yaml_load(file, include_line_nums=True) + flattened_yaml_data = util.flatten_yaml_data(yaml_data) host_arch = util.get_host_architecture() build_planner = self.app.BuildPlannerClass.from_yaml_data( @@ -390,7 +392,7 @@ def get_project( self._build_plan = filter_plan( self._full_build_plan, platform, build_for, host_arch ) - + if not build_for: # get the build-for arch from the platform if platform: @@ -405,13 +407,13 @@ def get_project( build_for = self._build_plan[0].build_for # validate project grammar - GrammarAwareProject.validate_grammar(yaml_data) + GrammarAwareProject.validate_grammar(flattened_yaml_data) build_on = host_arch # Setup partitions, some projects require the yaml data, most will not - self._partitions = self._setup_partitions(yaml_data) - yaml_data = self._transform_project_yaml(yaml_data, build_on, build_for) + self._partitions = self._setup_partitions(flattened_yaml_data) + yaml_data = deep_update(yaml_data, self._transform_project_yaml(flattened_yaml_data, build_on, build_for)) self.__project = self.app.ProjectClass.from_yaml_data(yaml_data, project_path) # check if mandatory adoptable fields exist if adopt-info not used diff --git a/craft_application/errors.py b/craft_application/errors.py index a0635730e..1438bf314 100644 --- a/craft_application/errors.py +++ b/craft_application/errors.py @@ -21,7 +21,7 @@ import os from collections.abc import Sequence -from typing import TYPE_CHECKING +from typing import Any, TYPE_CHECKING import yaml from craft_cli import CraftError @@ -71,6 +71,7 @@ def from_pydantic( error: pydantic.ValidationError, *, file_name: str = "yaml file", + validated_object: dict[str, Any] | None = None, **kwargs: str | bool | int | None, ) -> Self: """Convert this error from a pydantic ValidationError. @@ -80,7 +81,7 @@ def from_pydantic( :param doc_slug: The optional slug to this error's docs. :param kwargs: additional keyword arguments get passed to CraftError """ - message = format_pydantic_errors(error.errors(), file_name=file_name) + message = format_pydantic_errors(error.errors(), file_name=file_name, validated_object=validated_object) return cls(message, **kwargs) # type: ignore[arg-type] diff --git a/craft_application/models/base.py b/craft_application/models/base.py index dcb01bf51..b2f777d30 100644 --- a/craft_application/models/base.py +++ b/craft_application/models/base.py @@ -75,7 +75,7 @@ def from_yaml_data(cls, data: dict[str, Any], filepath: pathlib.Path) -> Self: :param filepath: The filepath corresponding to ``data``, for error reporting. """ try: - return cls.unmarshal(data) + return cls.unmarshal(util.flatten_yaml_data(data)) except pydantic.ValidationError as err: cls.transform_pydantic_error(err) raise errors.CraftValidationError.from_pydantic( @@ -83,6 +83,7 @@ def from_yaml_data(cls, data: dict[str, Any], filepath: pathlib.Path) -> Self: file_name=filepath.name, doc_slug=cls.model_reference_slug(), logpath_report=False, + validated_object=data ) from None def to_yaml_file(self, path: pathlib.Path) -> None: @@ -108,4 +109,4 @@ def transform_pydantic_error(cls, error: pydantic.ValidationError) -> None: @classmethod def model_reference_slug(cls) -> str | None: """Get the slug to this model class' reference docs.""" - return None + return None \ No newline at end of file diff --git a/craft_application/util/__init__.py b/craft_application/util/__init__.py index c95cb71f0..c02be4cd3 100644 --- a/craft_application/util/__init__.py +++ b/craft_application/util/__init__.py @@ -35,7 +35,7 @@ ) from craft_application.util.string import humanize_list, strtobool from craft_application.util.system import get_parallel_build_count -from craft_application.util.yaml import dump_yaml, safe_yaml_load +from craft_application.util.yaml import dump_yaml, safe_yaml_load, flatten_yaml_data from craft_application.util.cli import format_timestamp __all__ = [ @@ -55,6 +55,7 @@ "get_host_base", "dump_yaml", "safe_yaml_load", + "flatten_yaml_data", "retry", "get_parallel_build_count", "get_hostname", diff --git a/craft_application/util/error_formatting.py b/craft_application/util/error_formatting.py index c6fdbf5b1..c952d212c 100644 --- a/craft_application/util/error_formatting.py +++ b/craft_application/util/error_formatting.py @@ -17,7 +17,7 @@ from __future__ import annotations from collections.abc import Iterable -from typing import NamedTuple +from typing import Any, NamedTuple from pydantic import error_wrappers @@ -44,7 +44,7 @@ def from_str(cls, loc_str: str) -> FieldLocationTuple: return cls(field, location) -def format_pydantic_error(loc: Iterable[str | int], message: str) -> str: +def format_pydantic_error(loc: Iterable[str | int], message: str, validated_object : dict[str, Any] | None = None) -> str: """Format a single pydantic ErrorDict as a string. :param loc: An iterable of strings and integers determining the error location. @@ -53,11 +53,23 @@ def format_pydantic_error(loc: Iterable[str | int], message: str) -> str: Can be pulled from the "msg" field of a pydantic ErrorDict. :returns: A formatted error. """ + line_num = None + if validated_object is not None: + for i,l in enumerate(loc): + if i == len(loc) - 1 and f"__line__{l}" in validated_object: + line_num = validated_object[f"__line__{l}"] + elif type(validated_object) == dict and l in validated_object: + validated_object = validated_object.get(l) + elif type(validated_object) == list and type(l) == int: + validated_object = validated_object[l] + field_path = _format_pydantic_error_location(loc) message = _format_pydantic_error_message(message) field_name, location = FieldLocationTuple.from_str(field_path) if location != "top-level": location = repr(location) + if line_num is not None: + location += f" - line {line_num}" if message == "field required": return f"- field {field_name!r} required in {location} configuration" @@ -67,11 +79,11 @@ def format_pydantic_error(loc: Iterable[str | int], message: str) -> str: return f"- duplicate {field_name!r} entry not permitted in {location} configuration" if field_path in ("__root__", ""): return f"- {message}" - return f"- {message} (in field {field_path!r})" + return f"- {message} (in field {field_path!r}" + (f" - line {line_num})" if line_num else ")") def format_pydantic_errors( - errors: Iterable[error_wrappers.ErrorDict], *, file_name: str = "yaml file" + errors: Iterable[error_wrappers.ErrorDict], *, file_name: str = "yaml file", validated_object: dict[str, Any] | None ) -> str: """Format errors. @@ -86,7 +98,7 @@ def format_pydantic_errors( - field: reason: . """ - messages = (format_pydantic_error(error["loc"], error["msg"]) for error in errors) + messages = (format_pydantic_error(error["loc"], error["msg"], validated_object) for error in errors) return "\n".join((f"Bad {file_name} content:", *messages)) diff --git a/craft_application/util/yaml.py b/craft_application/util/yaml.py index 1111da5a1..5f04cb4db 100644 --- a/craft_application/util/yaml.py +++ b/craft_application/util/yaml.py @@ -21,6 +21,10 @@ from typing import TYPE_CHECKING, Any, TextIO, cast, overload import yaml +from yaml.composer import Composer +from yaml.constructor import Constructor +from yaml.nodes import ScalarNode +from yaml.resolver import BaseResolver from craft_application import errors @@ -95,8 +99,28 @@ def __init__(self, stream: TextIO) -> None: yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, _dict_constructor ) + def compose_node(self, parent, index): + # the line number where the previous token has ended (plus empty lines) + line = self.line + node = Composer.compose_node(self, parent, index) + node.__line__ = line + 1 + return node -def safe_yaml_load(stream: TextIO) -> Any: # noqa: ANN401 - The YAML could be anything + def construct_mapping(self, node, deep=False): + node_pair_lst = node.value + node_pair_lst_for_appending = [] + + for key_node, _ in node_pair_lst: + shadow_key_node = ScalarNode(tag=BaseResolver.DEFAULT_SCALAR_TAG, value='__line__' + key_node.value) + shadow_value_node = ScalarNode(tag=BaseResolver.DEFAULT_SCALAR_TAG, value=key_node.__line__) + node_pair_lst_for_appending.append((shadow_key_node, shadow_value_node)) + + node.value = node_pair_lst + node_pair_lst_for_appending + mapping = Constructor.construct_mapping(self, node, deep=deep) + return mapping + + +def safe_yaml_load(stream: TextIO, include_line_nums = False) -> Any: # noqa: ANN401 - The YAML could be anything """Equivalent to pyyaml's safe_load function, but constraining duplicate keys. :param stream: Any text-like IO object. @@ -105,7 +129,8 @@ def safe_yaml_load(stream: TextIO) -> Any: # noqa: ANN401 - The YAML could be a try: # Silencing S506 ("probable use of unsafe loader") because we override it by # using our own safe loader. - return yaml.load(stream, Loader=_SafeYamlLoader) # noqa: S506 + result = yaml.load(stream, Loader=_SafeYamlLoader) # noqa: S506 + return result if include_line_nums else flatten_yaml_data(result) except yaml.YAMLError as error: filename = pathlib.Path(stream.name).name raise errors.YamlError.from_yaml_error(filename, error) from error @@ -143,3 +168,12 @@ def dump_yaml(data: Any, stream: TextIO | None = None, **kwargs: Any) -> str | N return cast( # This cast is needed for pyright but not mypy str | None, yaml.dump(data, stream, Dumper=yaml.SafeDumper, **kwargs) ) + + +def flatten_yaml_data(data : dict[str, Any]) -> dict[str, Any]: + """ + Recursively flattens a nested dictionary by removing the '__line__' fields. + """ + if type(data) is not dict: + return data + return { k:flatten_yaml_data(v) for k,v in data.items() if "__line__" not in k} \ No newline at end of file From 8ce42b6d99541eaf468cb6d7c4edd935fd96f22b Mon Sep 17 00:00:00 2001 From: Alex Santisteban Date: Wed, 26 Feb 2025 11:13:32 +0100 Subject: [PATCH 02/10] feat: add line nos to grammar validation errors --- craft_application/application.py | 18 ++++++++++++++---- craft_application/models/base.py | 9 +++++++-- craft_application/models/grammar.py | 2 ++ craft_application/util/error_formatting.py | 2 +- 4 files changed, 24 insertions(+), 7 deletions(-) diff --git a/craft_application/application.py b/craft_application/application.py index 73367a5ad..054207632 100644 --- a/craft_application/application.py +++ b/craft_application/application.py @@ -21,6 +21,7 @@ import importlib import os import pathlib +import pydantic import signal import subprocess import sys @@ -382,7 +383,6 @@ def get_project( with project_path.open() as file: yaml_data = util.safe_yaml_load(file, include_line_nums=True) - flattened_yaml_data = util.flatten_yaml_data(yaml_data) host_arch = util.get_host_architecture() build_planner = self.app.BuildPlannerClass.from_yaml_data( @@ -407,13 +407,23 @@ def get_project( build_for = self._build_plan[0].build_for # validate project grammar - GrammarAwareProject.validate_grammar(flattened_yaml_data) + try: + GrammarAwareProject.validate_grammar(yaml_data) + except pydantic.ValidationError as err: + raise errors.CraftValidationError.from_pydantic( + err, + file_name=project_path.name, + doc_slug="common/craft-parts/reference/part_properties", + logpath_report=False, + validated_object=yaml_data + ) from None + build_on = host_arch # Setup partitions, some projects require the yaml data, most will not - self._partitions = self._setup_partitions(flattened_yaml_data) - yaml_data = deep_update(yaml_data, self._transform_project_yaml(flattened_yaml_data, build_on, build_for)) + self._partitions = self._setup_partitions(yaml_data) + yaml_data = deep_update(yaml_data, self._transform_project_yaml(yaml_data, build_on, build_for)) self.__project = self.app.ProjectClass.from_yaml_data(yaml_data, project_path) # check if mandatory adoptable fields exist if adopt-info not used diff --git a/craft_application/models/base.py b/craft_application/models/base.py index b2f777d30..cc6e61347 100644 --- a/craft_application/models/base.py +++ b/craft_application/models/base.py @@ -17,9 +17,10 @@ from __future__ import annotations import pathlib -from typing import Any +from typing import Any, Dict import pydantic +from pydantic import model_validator from typing_extensions import Self from craft_application import errors, util @@ -41,6 +42,10 @@ class CraftBaseModel(pydantic.BaseModel): coerce_numbers_to_str=True, ) + @model_validator(mode="before") + def flatten(cls, values: Dict[str, Any]) -> Dict[str, Any]: + return util.flatten_yaml_data(values) + def marshal(self) -> dict[str, str | list[str] | dict[str, Any]]: """Convert to a dictionary.""" return self.model_dump(mode="json", by_alias=True, exclude_unset=True) @@ -75,7 +80,7 @@ def from_yaml_data(cls, data: dict[str, Any], filepath: pathlib.Path) -> Self: :param filepath: The filepath corresponding to ``data``, for error reporting. """ try: - return cls.unmarshal(util.flatten_yaml_data(data)) + return cls.unmarshal(data) except pydantic.ValidationError as err: cls.transform_pydantic_error(err) raise errors.CraftValidationError.from_pydantic( diff --git a/craft_application/models/grammar.py b/craft_application/models/grammar.py index 794d0f8fa..be323d3b8 100644 --- a/craft_application/models/grammar.py +++ b/craft_application/models/grammar.py @@ -18,6 +18,7 @@ from typing import Any import pydantic +from craft_application import util from craft_grammar.models import Grammar # type: ignore[import-untyped] from pydantic import ConfigDict @@ -91,6 +92,7 @@ def _ensure_parts(cls, data: dict[str, Any]) -> dict[str, Any]: item defined, set it to an empty dictionary. This is distinct from having `parts` be invalid, which is not coerced here. """ + data = util.flatten_yaml_data(data) data.setdefault("parts", {}) return data diff --git a/craft_application/util/error_formatting.py b/craft_application/util/error_formatting.py index c952d212c..29d734ea1 100644 --- a/craft_application/util/error_formatting.py +++ b/craft_application/util/error_formatting.py @@ -83,7 +83,7 @@ def format_pydantic_error(loc: Iterable[str | int], message: str, validated_obje def format_pydantic_errors( - errors: Iterable[error_wrappers.ErrorDict], *, file_name: str = "yaml file", validated_object: dict[str, Any] | None + errors: Iterable[error_wrappers.ErrorDict], *, file_name: str = "yaml file", validated_object: dict[str, Any] | None = None ) -> str: """Format errors. From b44e78efd90db9b8da7b323604376439a0189afe Mon Sep 17 00:00:00 2001 From: Alex Santisteban Date: Wed, 26 Feb 2025 12:34:23 +0100 Subject: [PATCH 03/10] style: run formatters --- craft_application/application.py | 9 +++++---- craft_application/errors.py | 4 +++- craft_application/models/base.py | 6 +++--- craft_application/util/error_formatting.py | 22 +++++++++++++++++----- craft_application/util/yaml.py | 16 +++++++++++----- 5 files changed, 39 insertions(+), 18 deletions(-) diff --git a/craft_application/application.py b/craft_application/application.py index 054207632..e9df90013 100644 --- a/craft_application/application.py +++ b/craft_application/application.py @@ -392,7 +392,7 @@ def get_project( self._build_plan = filter_plan( self._full_build_plan, platform, build_for, host_arch ) - + if not build_for: # get the build-for arch from the platform if platform: @@ -415,15 +415,16 @@ def get_project( file_name=project_path.name, doc_slug="common/craft-parts/reference/part_properties", logpath_report=False, - validated_object=yaml_data + validated_object=yaml_data, ) from None - build_on = host_arch # Setup partitions, some projects require the yaml data, most will not self._partitions = self._setup_partitions(yaml_data) - yaml_data = deep_update(yaml_data, self._transform_project_yaml(yaml_data, build_on, build_for)) + yaml_data = deep_update( + yaml_data, self._transform_project_yaml(yaml_data, build_on, build_for) + ) self.__project = self.app.ProjectClass.from_yaml_data(yaml_data, project_path) # check if mandatory adoptable fields exist if adopt-info not used diff --git a/craft_application/errors.py b/craft_application/errors.py index 1438bf314..fcb40bf78 100644 --- a/craft_application/errors.py +++ b/craft_application/errors.py @@ -81,7 +81,9 @@ def from_pydantic( :param doc_slug: The optional slug to this error's docs. :param kwargs: additional keyword arguments get passed to CraftError """ - message = format_pydantic_errors(error.errors(), file_name=file_name, validated_object=validated_object) + message = format_pydantic_errors( + error.errors(), file_name=file_name, validated_object=validated_object + ) return cls(message, **kwargs) # type: ignore[arg-type] diff --git a/craft_application/models/base.py b/craft_application/models/base.py index cc6e61347..01d67ecac 100644 --- a/craft_application/models/base.py +++ b/craft_application/models/base.py @@ -45,7 +45,7 @@ class CraftBaseModel(pydantic.BaseModel): @model_validator(mode="before") def flatten(cls, values: Dict[str, Any]) -> Dict[str, Any]: return util.flatten_yaml_data(values) - + def marshal(self) -> dict[str, str | list[str] | dict[str, Any]]: """Convert to a dictionary.""" return self.model_dump(mode="json", by_alias=True, exclude_unset=True) @@ -88,7 +88,7 @@ def from_yaml_data(cls, data: dict[str, Any], filepath: pathlib.Path) -> Self: file_name=filepath.name, doc_slug=cls.model_reference_slug(), logpath_report=False, - validated_object=data + validated_object=data, ) from None def to_yaml_file(self, path: pathlib.Path) -> None: @@ -114,4 +114,4 @@ def transform_pydantic_error(cls, error: pydantic.ValidationError) -> None: @classmethod def model_reference_slug(cls) -> str | None: """Get the slug to this model class' reference docs.""" - return None \ No newline at end of file + return None diff --git a/craft_application/util/error_formatting.py b/craft_application/util/error_formatting.py index 29d734ea1..92ef5f72e 100644 --- a/craft_application/util/error_formatting.py +++ b/craft_application/util/error_formatting.py @@ -44,7 +44,11 @@ def from_str(cls, loc_str: str) -> FieldLocationTuple: return cls(field, location) -def format_pydantic_error(loc: Iterable[str | int], message: str, validated_object : dict[str, Any] | None = None) -> str: +def format_pydantic_error( + loc: Iterable[str | int], + message: str, + validated_object: dict[str, Any] | None = None, +) -> str: """Format a single pydantic ErrorDict as a string. :param loc: An iterable of strings and integers determining the error location. @@ -55,7 +59,7 @@ def format_pydantic_error(loc: Iterable[str | int], message: str, validated_obje """ line_num = None if validated_object is not None: - for i,l in enumerate(loc): + for i, l in enumerate(loc): if i == len(loc) - 1 and f"__line__{l}" in validated_object: line_num = validated_object[f"__line__{l}"] elif type(validated_object) == dict and l in validated_object: @@ -79,11 +83,16 @@ def format_pydantic_error(loc: Iterable[str | int], message: str, validated_obje return f"- duplicate {field_name!r} entry not permitted in {location} configuration" if field_path in ("__root__", ""): return f"- {message}" - return f"- {message} (in field {field_path!r}" + (f" - line {line_num})" if line_num else ")") + return f"- {message} (in field {field_path!r}" + ( + f" - line {line_num})" if line_num else ")" + ) def format_pydantic_errors( - errors: Iterable[error_wrappers.ErrorDict], *, file_name: str = "yaml file", validated_object: dict[str, Any] | None = None + errors: Iterable[error_wrappers.ErrorDict], + *, + file_name: str = "yaml file", + validated_object: dict[str, Any] | None = None, ) -> str: """Format errors. @@ -98,7 +107,10 @@ def format_pydantic_errors( - field: reason: . """ - messages = (format_pydantic_error(error["loc"], error["msg"], validated_object) for error in errors) + messages = ( + format_pydantic_error(error["loc"], error["msg"], validated_object) + for error in errors + ) return "\n".join((f"Bad {file_name} content:", *messages)) diff --git a/craft_application/util/yaml.py b/craft_application/util/yaml.py index 5f04cb4db..4b286c059 100644 --- a/craft_application/util/yaml.py +++ b/craft_application/util/yaml.py @@ -111,8 +111,12 @@ def construct_mapping(self, node, deep=False): node_pair_lst_for_appending = [] for key_node, _ in node_pair_lst: - shadow_key_node = ScalarNode(tag=BaseResolver.DEFAULT_SCALAR_TAG, value='__line__' + key_node.value) - shadow_value_node = ScalarNode(tag=BaseResolver.DEFAULT_SCALAR_TAG, value=key_node.__line__) + shadow_key_node = ScalarNode( + tag=BaseResolver.DEFAULT_SCALAR_TAG, value="__line__" + key_node.value + ) + shadow_value_node = ScalarNode( + tag=BaseResolver.DEFAULT_SCALAR_TAG, value=key_node.__line__ + ) node_pair_lst_for_appending.append((shadow_key_node, shadow_value_node)) node.value = node_pair_lst + node_pair_lst_for_appending @@ -120,7 +124,9 @@ def construct_mapping(self, node, deep=False): return mapping -def safe_yaml_load(stream: TextIO, include_line_nums = False) -> Any: # noqa: ANN401 - The YAML could be anything +def safe_yaml_load( + stream: TextIO, include_line_nums=False +) -> Any: # noqa: ANN401 - The YAML could be anything """Equivalent to pyyaml's safe_load function, but constraining duplicate keys. :param stream: Any text-like IO object. @@ -170,10 +176,10 @@ def dump_yaml(data: Any, stream: TextIO | None = None, **kwargs: Any) -> str | N ) -def flatten_yaml_data(data : dict[str, Any]) -> dict[str, Any]: +def flatten_yaml_data(data: dict[str, Any]) -> dict[str, Any]: """ Recursively flattens a nested dictionary by removing the '__line__' fields. """ if type(data) is not dict: return data - return { k:flatten_yaml_data(v) for k,v in data.items() if "__line__" not in k} \ No newline at end of file + return {k: flatten_yaml_data(v) for k, v in data.items() if "__line__" not in k} From 5be6cf2663eee387150fdac9018a12dee542df1d Mon Sep 17 00:00:00 2001 From: Alex Santisteban Date: Wed, 26 Feb 2025 13:03:07 +0100 Subject: [PATCH 04/10] fix: prevent crash due to line numbers in grammar check --- craft_application/grammar.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/craft_application/grammar.py b/craft_application/grammar.py index 3af035b81..ffa86d29b 100644 --- a/craft_application/grammar.py +++ b/craft_application/grammar.py @@ -121,6 +121,9 @@ def self_check(value: Any) -> bool: # noqa: ANN401 processor = GrammarProcessor(arch=arch, target_arch=target_arch, checker=self_check) for part_name, part_data in parts_yaml_data.items(): + # Ignore line numbers coming from yaml reader + if part_name.startswith("__line__"): + continue parts_yaml_data[part_name] = process_part( part_yaml_data=part_data, processor=processor ) From d12536086a80b5deb25317b68b7d7cf6d5274524 Mon Sep 17 00:00:00 2001 From: Alex Santisteban Date: Wed, 26 Feb 2025 13:31:23 +0100 Subject: [PATCH 05/10] style: run ruff linter --- craft_application/application.py | 9 ++-- craft_application/errors.py | 2 +- craft_application/models/base.py | 7 ++-- craft_application/models/grammar.py | 2 +- craft_application/util/__init__.py | 8 +++- craft_application/util/error_formatting.py | 14 +++---- craft_application/util/yaml.py | 48 ++++++++++++++++------ 7 files changed, 58 insertions(+), 32 deletions(-) diff --git a/craft_application/application.py b/craft_application/application.py index b642161c0..6d9c41fd2 100644 --- a/craft_application/application.py +++ b/craft_application/application.py @@ -21,7 +21,6 @@ import importlib import os import pathlib -import pydantic import signal import subprocess import sys @@ -31,11 +30,11 @@ from functools import cached_property from importlib import metadata from typing import TYPE_CHECKING, Any, cast, final -from pydantic.v1.utils import deep_update import craft_cli import craft_parts import craft_providers +import pydantic from craft_parts.plugins.plugins import PluginType from platformdirs import user_cache_path @@ -381,7 +380,7 @@ def get_project( craft_cli.emit.debug(f"Loading project file '{project_path!s}'") with project_path.open() as file: - yaml_data = util.safe_yaml_load(file, include_line_nums=True) + yaml_data = util.safe_yaml_load_with_lines(file) host_arch = util.get_host_architecture() build_planner = self.app.BuildPlannerClass.from_yaml_data( @@ -421,9 +420,7 @@ def get_project( # Setup partitions, some projects require the yaml data, most will not self._partitions = self._setup_partitions(yaml_data) - yaml_data = deep_update( - yaml_data, self._transform_project_yaml(yaml_data, build_on, build_for) - ) + yaml_data = self._transform_project_yaml(yaml_data, build_on, build_for) self.__project = self.app.ProjectClass.from_yaml_data(yaml_data, project_path) # check if mandatory adoptable fields exist if adopt-info not used diff --git a/craft_application/errors.py b/craft_application/errors.py index 4933ce30d..4b7e10eb2 100644 --- a/craft_application/errors.py +++ b/craft_application/errors.py @@ -22,7 +22,7 @@ import os from collections.abc import Sequence -from typing import Any, TYPE_CHECKING +from typing import TYPE_CHECKING, Any import yaml from craft_cli import CraftError diff --git a/craft_application/models/base.py b/craft_application/models/base.py index 435e6d0db..57901d88d 100644 --- a/craft_application/models/base.py +++ b/craft_application/models/base.py @@ -18,7 +18,7 @@ from __future__ import annotations import pathlib -from typing import Any, Dict +from typing import Any import pydantic from pydantic import model_validator @@ -42,9 +42,10 @@ class CraftBaseModel(pydantic.BaseModel): alias_generator=alias_generator, coerce_numbers_to_str=True, ) - + @model_validator(mode="before") - def flatten(cls, values: Dict[str, Any]) -> Dict[str, Any]: + @classmethod + def _flatten(cls, values: dict[str, Any]) -> dict[str, Any]: return util.flatten_yaml_data(values) def marshal(self) -> dict[str, str | list[str] | dict[str, Any]]: diff --git a/craft_application/models/grammar.py b/craft_application/models/grammar.py index 0710756bc..495c9f91e 100644 --- a/craft_application/models/grammar.py +++ b/craft_application/models/grammar.py @@ -19,10 +19,10 @@ from typing import Any import pydantic -from craft_application import util from craft_grammar.models import Grammar # type: ignore[import-untyped] from pydantic import ConfigDict +from craft_application import util from craft_application.models.base import alias_generator from craft_application.models.constraints import SingleEntryDict diff --git a/craft_application/util/__init__.py b/craft_application/util/__init__.py index c02be4cd3..b790ac36b 100644 --- a/craft_application/util/__init__.py +++ b/craft_application/util/__init__.py @@ -35,7 +35,12 @@ ) from craft_application.util.string import humanize_list, strtobool from craft_application.util.system import get_parallel_build_count -from craft_application.util.yaml import dump_yaml, safe_yaml_load, flatten_yaml_data +from craft_application.util.yaml import ( + dump_yaml, + safe_yaml_load, + safe_yaml_load_with_lines, + flatten_yaml_data, +) from craft_application.util.cli import format_timestamp __all__ = [ @@ -55,6 +60,7 @@ "get_host_base", "dump_yaml", "safe_yaml_load", + "safe_yaml_load_with_lines", "flatten_yaml_data", "retry", "get_parallel_build_count", diff --git a/craft_application/util/error_formatting.py b/craft_application/util/error_formatting.py index 0dfbeff04..93790a0f8 100644 --- a/craft_application/util/error_formatting.py +++ b/craft_application/util/error_formatting.py @@ -60,13 +60,13 @@ def format_pydantic_error( """ line_num = None if validated_object is not None: - for i, l in enumerate(loc): - if i == len(loc) - 1 and f"__line__{l}" in validated_object: - line_num = validated_object[f"__line__{l}"] - elif type(validated_object) == dict and l in validated_object: - validated_object = validated_object.get(l) - elif type(validated_object) == list and type(l) == int: - validated_object = validated_object[l] + for i, location in enumerate(loc): + if i == len(loc) - 1 and f"__line__{location}" in validated_object: + line_num = validated_object[f"__line__{location}"] + elif type(validated_object) is dict and location in validated_object: + validated_object = validated_object.get(location) + elif type(validated_object) is list and type(location) is int: + validated_object = validated_object[location] field_path = _format_pydantic_error_location(loc) message = _format_pydantic_error_message(message) diff --git a/craft_application/util/yaml.py b/craft_application/util/yaml.py index ddc0d0749..45bc57d40 100644 --- a/craft_application/util/yaml.py +++ b/craft_application/util/yaml.py @@ -24,7 +24,7 @@ import yaml from yaml.composer import Composer from yaml.constructor import Constructor -from yaml.nodes import ScalarNode +from yaml.nodes import MappingNode, Node, ScalarNode from yaml.resolver import BaseResolver from craft_application import errors @@ -100,14 +100,27 @@ def __init__(self, stream: TextIO) -> None: yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, _dict_constructor ) - def compose_node(self, parent, index): + +class _SafeLineNoLoader(_SafeYamlLoader): + def __init__(self, stream: TextIO) -> None: + super().__init__(stream) + + self.add_constructor( + yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, _dict_constructor + ) + + def compose_node(self, parent: Node | None, index: int) -> Node: # the line number where the previous token has ended (plus empty lines) line = self.line node = Composer.compose_node(self, parent, index) node.__line__ = line + 1 return node - def construct_mapping(self, node, deep=False): + def construct_mapping( + self, + node: MappingNode, + deep: bool = False, # noqa: FBT001, FBT002 - used internally by yaml.SafeLoader + ) -> dict[Hashable, Any]: node_pair_lst = node.value node_pair_lst_for_appending = [] @@ -121,13 +134,10 @@ def construct_mapping(self, node, deep=False): node_pair_lst_for_appending.append((shadow_key_node, shadow_value_node)) node.value = node_pair_lst + node_pair_lst_for_appending - mapping = Constructor.construct_mapping(self, node, deep=deep) - return mapping + return Constructor.construct_mapping(self, node, deep=deep) -def safe_yaml_load( - stream: TextIO, include_line_nums=False -) -> Any: # noqa: ANN401 - The YAML could be anything +def safe_yaml_load(stream: TextIO) -> Any: # noqa: ANN401 - The YAML could be anything """Equivalent to pyyaml's safe_load function, but constraining duplicate keys. :param stream: Any text-like IO object. @@ -136,8 +146,22 @@ def safe_yaml_load( try: # Silencing S506 ("probable use of unsafe loader") because we override it by # using our own safe loader. - result = yaml.load(stream, Loader=_SafeYamlLoader) # noqa: S506 - return result if include_line_nums else flatten_yaml_data(result) + return yaml.load(stream, Loader=_SafeYamlLoader) # noqa: S506 + except yaml.YAMLError as error: + filename = pathlib.Path(stream.name).name + raise errors.YamlError.from_yaml_error(filename, error) from error + + +def safe_yaml_load_with_lines(stream: TextIO) -> Any: # noqa: ANN401 - The YAML could be anything + """Equivalent to pyyaml's safe_load function, but constraining duplicate keys and including line numbers. + + :param stream: Any text-like IO object. + :returns: A dict object mapping the yaml. + """ + try: + # Silencing S506 ("probable use of unsafe loader") because we override it by + # using our own safe loader. + return yaml.load(stream, Loader=_SafeLineNoLoader) # noqa: S506 except yaml.YAMLError as error: filename = pathlib.Path(stream.name).name raise errors.YamlError.from_yaml_error(filename, error) from error @@ -180,9 +204,7 @@ def dump_yaml(data: Any, stream: TextIO | None = None, **kwargs: Any) -> str | N def flatten_yaml_data(data: dict[str, Any]) -> dict[str, Any]: - """ - Recursively flattens a nested dictionary by removing the '__line__' fields. - """ + """Recursively flattens a nested dictionary by removing the '__line__' fields.""" if type(data) is not dict: return data return {k: flatten_yaml_data(v) for k, v in data.items() if "__line__" not in k} From 8c270b5b172250e61e22368b997eb1ea4b79ca49 Mon Sep 17 00:00:00 2001 From: Alex Santisteban Date: Wed, 26 Feb 2025 13:34:11 +0100 Subject: [PATCH 06/10] style: rename flatten_yaml_data function to remove_yaml_lines --- craft_application/models/base.py | 4 ++-- craft_application/models/grammar.py | 2 +- craft_application/util/__init__.py | 4 ++-- craft_application/util/yaml.py | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/craft_application/models/base.py b/craft_application/models/base.py index 57901d88d..caa614481 100644 --- a/craft_application/models/base.py +++ b/craft_application/models/base.py @@ -42,11 +42,11 @@ class CraftBaseModel(pydantic.BaseModel): alias_generator=alias_generator, coerce_numbers_to_str=True, ) - + @model_validator(mode="before") @classmethod def _flatten(cls, values: dict[str, Any]) -> dict[str, Any]: - return util.flatten_yaml_data(values) + return util.remove_yaml_lines(values) def marshal(self) -> dict[str, str | list[str] | dict[str, Any]]: """Convert to a dictionary.""" diff --git a/craft_application/models/grammar.py b/craft_application/models/grammar.py index 495c9f91e..3b3342ef0 100644 --- a/craft_application/models/grammar.py +++ b/craft_application/models/grammar.py @@ -93,7 +93,7 @@ def _ensure_parts(cls, data: dict[str, Any]) -> dict[str, Any]: item defined, set it to an empty dictionary. This is distinct from having `parts` be invalid, which is not coerced here. """ - data = util.flatten_yaml_data(data) + data = util.remove_yaml_lines(data) data.setdefault("parts", {}) return data diff --git a/craft_application/util/__init__.py b/craft_application/util/__init__.py index b790ac36b..87ab102be 100644 --- a/craft_application/util/__init__.py +++ b/craft_application/util/__init__.py @@ -39,7 +39,7 @@ dump_yaml, safe_yaml_load, safe_yaml_load_with_lines, - flatten_yaml_data, + remove_yaml_lines, ) from craft_application.util.cli import format_timestamp @@ -61,7 +61,7 @@ "dump_yaml", "safe_yaml_load", "safe_yaml_load_with_lines", - "flatten_yaml_data", + "remove_yaml_lines", "retry", "get_parallel_build_count", "get_hostname", diff --git a/craft_application/util/yaml.py b/craft_application/util/yaml.py index 45bc57d40..aa3454f4a 100644 --- a/craft_application/util/yaml.py +++ b/craft_application/util/yaml.py @@ -203,8 +203,8 @@ def dump_yaml(data: Any, stream: TextIO | None = None, **kwargs: Any) -> str | N ) -def flatten_yaml_data(data: dict[str, Any]) -> dict[str, Any]: +def remove_yaml_lines(data: dict[str, Any]) -> dict[str, Any]: """Recursively flattens a nested dictionary by removing the '__line__' fields.""" if type(data) is not dict: return data - return {k: flatten_yaml_data(v) for k, v in data.items() if "__line__" not in k} + return {k: remove_yaml_lines(v) for k, v in data.items() if "__line__" not in k} From 6e476fb4c5d576fd6dde410dda316b6caeb2ef7e Mon Sep 17 00:00:00 2001 From: Alex Santisteban Date: Thu, 27 Feb 2025 11:20:58 +0100 Subject: [PATCH 07/10] style: fix typings to comply with linter rules --- craft_application/util/error_formatting.py | 36 ++++++++++++++-------- craft_application/util/yaml.py | 6 ++-- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/craft_application/util/error_formatting.py b/craft_application/util/error_formatting.py index 93790a0f8..61e3c854d 100644 --- a/craft_application/util/error_formatting.py +++ b/craft_application/util/error_formatting.py @@ -17,7 +17,7 @@ from __future__ import annotations -from collections.abc import Iterable +from collections.abc import Iterable, Sequence from typing import Any, NamedTuple from pydantic import error_wrappers @@ -46,7 +46,7 @@ def from_str(cls, loc_str: str) -> FieldLocationTuple: def format_pydantic_error( - loc: Iterable[str | int], + loc: Sequence[str | int], message: str, validated_object: dict[str, Any] | None = None, ) -> str: @@ -58,18 +58,9 @@ def format_pydantic_error( Can be pulled from the "msg" field of a pydantic ErrorDict. :returns: A formatted error. """ - line_num = None - if validated_object is not None: - for i, location in enumerate(loc): - if i == len(loc) - 1 and f"__line__{location}" in validated_object: - line_num = validated_object[f"__line__{location}"] - elif type(validated_object) is dict and location in validated_object: - validated_object = validated_object.get(location) - elif type(validated_object) is list and type(location) is int: - validated_object = validated_object[location] - field_path = _format_pydantic_error_location(loc) message = _format_pydantic_error_message(message) + line_num = _get_line_number(loc, validated_object) field_name, location = FieldLocationTuple.from_str(field_path) if location != "top-level": location = repr(location) @@ -139,3 +130,24 @@ def _format_pydantic_error_message(msg: str) -> str: if msg: msg = msg[0].lower() + msg[1:] return msg + + +def _get_line_number( + loc: Sequence[str | int], validated_object: dict[str, Any] | None +) -> int | None: + """Return the line number of a key based on its location.""" + if validated_object is None: + return None + + object_value: dict[str, Any] | Sequence[Any] = validated_object + line_number: int | None = None + + for i, location in enumerate(loc): + if isinstance(location, int) and isinstance(object_value, Sequence): + object_value = object_value[location] # type: ignore[arg-type] + elif isinstance(location, str) and isinstance(object_value, dict): + if i == len(loc) - 1 and f"__line__{location}" in object_value: + line_number = object_value[f"__line__{location}"] + elif location in object_value: + object_value = object_value[location] + return line_number diff --git a/craft_application/util/yaml.py b/craft_application/util/yaml.py index aa3454f4a..e35852bc8 100644 --- a/craft_application/util/yaml.py +++ b/craft_application/util/yaml.py @@ -109,11 +109,11 @@ def __init__(self, stream: TextIO) -> None: yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, _dict_constructor ) - def compose_node(self, parent: Node | None, index: int) -> Node: + def compose_node(self, parent: Node | None, index: int) -> Node | None: # the line number where the previous token has ended (plus empty lines) line = self.line node = Composer.compose_node(self, parent, index) - node.__line__ = line + 1 + setattr(node, "__line__", line + 1) # noqa: B010 - used internally, prevent mypy error return node def construct_mapping( @@ -134,7 +134,7 @@ def construct_mapping( node_pair_lst_for_appending.append((shadow_key_node, shadow_value_node)) node.value = node_pair_lst + node_pair_lst_for_appending - return Constructor.construct_mapping(self, node, deep=deep) + return Constructor.construct_mapping(self, node, deep=deep) # type: ignore[arg-type] def safe_yaml_load(stream: TextIO) -> Any: # noqa: ANN401 - The YAML could be anything From a412f74432a8bfbc24ba4e8d7731b58bc0d78efc Mon Sep 17 00:00:00 2001 From: Alex Santisteban Date: Thu, 27 Feb 2025 14:03:54 +0100 Subject: [PATCH 08/10] fix: apply transformations to base projectv yaml --- craft_application/application.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/craft_application/application.py b/craft_application/application.py index 6d9c41fd2..bc2ba528c 100644 --- a/craft_application/application.py +++ b/craft_application/application.py @@ -37,6 +37,7 @@ import pydantic from craft_parts.plugins.plugins import PluginType from platformdirs import user_cache_path +from pydantic.v1.utils import deep_update from craft_application import _config, commands, errors, grammar, models, secrets, util from craft_application.errors import PathInvalidError @@ -420,7 +421,12 @@ def get_project( # Setup partitions, some projects require the yaml data, most will not self._partitions = self._setup_partitions(yaml_data) - yaml_data = self._transform_project_yaml(yaml_data, build_on, build_for) + + # Apply transformations to base yaml, then update to preserve line numbers + yaml_base = util.remove_yaml_lines(yaml_data) + yaml_update = self._transform_project_yaml(yaml_base, build_on, build_for) + yaml_data = deep_update(yaml_data, yaml_update) + self.__project = self.app.ProjectClass.from_yaml_data(yaml_data, project_path) # check if mandatory adoptable fields exist if adopt-info not used From 91c93881683da9531da2fea36708313c9972eb0e Mon Sep 17 00:00:00 2001 From: Alex Santisteban Date: Thu, 27 Feb 2025 14:04:30 +0100 Subject: [PATCH 09/10] fix: remove lines from lists in remove_yaml_lines function --- craft_application/util/yaml.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/craft_application/util/yaml.py b/craft_application/util/yaml.py index e35852bc8..646f17505 100644 --- a/craft_application/util/yaml.py +++ b/craft_application/util/yaml.py @@ -203,8 +203,15 @@ def dump_yaml(data: Any, stream: TextIO | None = None, **kwargs: Any) -> str | N ) -def remove_yaml_lines(data: dict[str, Any]) -> dict[str, Any]: +def remove_yaml_lines(data: dict[str, Any] | list[Any]) -> dict[str, Any]: """Recursively flattens a nested dictionary by removing the '__line__' fields.""" + if type(data) is list: + return [remove_yaml_lines(v) for v in data] # type: ignore[return-value] if type(data) is not dict: - return data - return {k: remove_yaml_lines(v) for k, v in data.items() if "__line__" not in k} + return data # type: ignore[return-value] + # k is only None in one test case + return { + k: remove_yaml_lines(v) + for k, v in data.items() + if k is None or "__line__" not in k # type: ignore[reportUnnecessaryComparison] + } From 5ed293e215f7ea12788b0ae8f94f0cff77ce3908 Mon Sep 17 00:00:00 2001 From: Alex Santisteban Date: Thu, 27 Feb 2025 14:04:51 +0100 Subject: [PATCH 10/10] test: modify new error messages in tests --- tests/unit/test_application.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_application.py b/tests/unit/test_application.py index 898f1f75f..f78291383 100644 --- a/tests/unit/test_application.py +++ b/tests/unit/test_application.py @@ -2141,8 +2141,8 @@ def test_build_planner_errors(tmp_path, monkeypatch, fake_services): expected = ( "Bad testcraft.yaml content:\n" - "- bad value1: 10 (in field 'value1')\n" - "- bad value2: banana (in field 'value2')" + "- bad value1: 10 (in field 'value1' - line 3)\n" + "- bad value2: banana (in field 'value2' - line 4)" ) assert str(err.value) == expected