Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions parse_errors/source_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,23 @@ def detect_format(path: Path) -> str | None:
suffix = path.suffix.lower()
return {
".toml": "toml",
".yaml": "yaml",
".yml": "yaml",
}.get(suffix)


def build_source_map(source: str | bytes, fmt: str) -> TSourceMap:
"""Build a source map for the given source in the given format."""
if fmt == "toml":
from .toml_source_map import calculate
from . import toml_source_map

return calculate(source)
return toml_source_map.calculate(source)
elif fmt in ("yaml", "yml"):
from . import yaml_source_map

return yaml_source_map.calculate(
source.decode("utf-8") if isinstance(source, bytes) else source
)
else:
raise ValueError(f"Unknown format: {fmt!r}")

Expand Down
62 changes: 62 additions & 0 deletions parse_errors/yaml_source_map/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""Calculate the source map for a YAML document."""

from __future__ import annotations

import yaml
from ..source_map import Entry, Location, TSourceMap


def calculate(source: str) -> TSourceMap:
"""Calculate the source map for a YAML document.

Args:
source: The YAML document as a string.

Returns:
A dict mapping JSON Pointer paths to Entry objects with location info.
"""
loader = yaml.SafeLoader(source)
node = loader.get_single_node()
if node is None:
return {}
result: TSourceMap = {}
_walk(node, "", result, loader)
return result


def _location(mark: yaml.Mark) -> Location:
return Location(line=mark.line, column=mark.column, position=mark.index)


def _walk(
node: yaml.Node, path: str, result: TSourceMap, loader: yaml.SafeLoader
) -> None:
value_start = _location(node.start_mark)
value_end = _location(node.end_mark)

if isinstance(node, yaml.MappingNode):
result[path] = Entry(value_start=value_start, value_end=value_end)
for key_node, value_node in node.value:
key = loader.construct_scalar(key_node)
child_path = f"{path}/{_escape(str(key))}"
key_start = _location(key_node.start_mark)
key_end = _location(key_node.end_mark)
_walk(value_node, child_path, result, loader)
existing = result[child_path]
result[child_path] = Entry(
value_start=existing.value_start,
value_end=existing.value_end,
key_start=key_start,
key_end=key_end,
)
elif isinstance(node, yaml.SequenceNode):
result[path] = Entry(value_start=value_start, value_end=value_end)
for i, item_node in enumerate(node.value):
_walk(item_node, f"{path}/{i}", result, loader)
else:
result[path] = Entry(value_start=value_start, value_end=value_end)


def _escape(key: str) -> str:
"""Escape a key for use in a JSON Pointer (RFC 6901)."""
return key.replace("~", "~0").replace("/", "~1")
7 changes: 7 additions & 0 deletions parse_errors/yaml_source_map/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
if __name__ == "__main__": # pragma: no cover
import sys
from . import calculate

source = open(sys.argv[1]).read()
for pointer, entry in calculate(source).items():
print(f"{pointer!r:40s} {entry}")
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ setup_requires =
setuptools >= 65
include_package_data = true
install_requires =
pyyaml
tree-sitter
tree-sitter-toml

Expand All @@ -26,9 +27,11 @@ dev =
ruff == 0.15.6
tox == 4.50.0
tox-uv == 1.33.4
types-pyyaml
test =
coverage >= 6
pytest >= 8
msgspec

[options.entry_points]
# console_scripts =
Expand Down
48 changes: 48 additions & 0 deletions tests/test_parse_context_yaml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import pytest
import msgspec

from parse_errors import ParseContext, ParseError

from ._types import Config, Nested

YAML_SOURCE = """\
host: localhost
port: not-an-int
"""

YAML_NESTED_SOURCE = """\
server:
host: localhost
port: not-an-int
"""


def test_yaml_raises_parse_error():
with pytest.raises(ParseError) as exc_info:
with ParseContext("config.yaml", data=YAML_SOURCE):
msgspec.yaml.decode(YAML_SOURCE.encode(), type=Config)

err = exc_info.value
assert err.filename == "config.yaml"
assert err.line == 2
assert str(err) == "config.yaml:2:7: Expected `int`, got `str` - at `$.port`"


def test_yaml_bytes_data():
with pytest.raises(ParseError) as exc_info:
with ParseContext("config.yaml", data=YAML_SOURCE.encode()):
msgspec.yaml.decode(YAML_SOURCE.encode(), type=Config)

assert (
str(exc_info.value)
== "config.yaml:2:7: Expected `int`, got `str` - at `$.port`"
)


def test_yaml_nested_raises_parse_error():
with pytest.raises(ParseError) as exc_info:
with ParseContext("config.yaml", data=YAML_NESTED_SOURCE):
msgspec.yaml.decode(YAML_NESTED_SOURCE.encode(), type=Nested)

err = exc_info.value
assert str(err) == "config.yaml:3:9: Expected `int`, got `str` - at `$.server.port`"
Loading