Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 51 additions & 3 deletions src/op_system/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@
re.DOTALL,
)
_SUM_OVER_RE = re.compile(
r"sum_over\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*=\s*([A-Za-z_][A-Za-z0-9_]*)\s*,\s*(.*?)\)",
r"sum_over\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*=\s*([A-Za-z_][A-Za-z0-9_]*)"
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note for future work: desire to avoid needing regular expressions. I think the ideal here is figuring out how to rely on the equation / symbol parsing libs.

r"(?:\s+IN\s+\[([^\[\]]*)\])?\s*,\s*(.*?)\)",
re.DOTALL,
)
_PLACEHOLDER_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_]*\[(.*?)\]")
Expand Down Expand Up @@ -1413,6 +1414,49 @@ def _template_replacer(
return out


def _apply_coord_filter(
filter_str: str | None,
*,
axis_name: str,
all_coords: list[str],
) -> list[str]:
"""Return the coord subset to iterate over, or all coords if no filter.

Args:
filter_str: Raw ``IN [...]`` content (e.g. ``"v, w"``), or ``None``.
axis_name: Axis name — used in error messages only.
all_coords: All valid coords for the axis.

Returns:
Filtered coord list (preserves order of the filter, not the axis).
"""
if filter_str is None:
return all_coords
requested = [c.strip() for c in filter_str.split(",") if c.strip()]
if not requested:
_raise_invalid_rhs_spec(
detail=f"sum_over IN filter for axis {axis_name!r} is empty",
)
seen: set[str] = set()
for coord in requested:
if coord in seen:
_raise_invalid_rhs_spec(
detail=(
f"sum_over IN filter for axis {axis_name!r} "
f"contains duplicate coord {coord!r}"
),
)
seen.add(coord)
if coord not in all_coords:
_raise_invalid_rhs_spec(
detail=(
f"sum_over IN filter references unknown coord {coord!r} "
f"for axis {axis_name!r} (valid: {all_coords})"
),
)
return requested


def _expand_sum_over(expr: str, *, axes: list[dict[str, Any]]) -> str:
"""Expand sum_over(axis=var, inner_expr) for categorical axes.

Expand Down Expand Up @@ -1447,8 +1491,12 @@ def _axis_coords(ax_name: str) -> list[str]:
break
axis_name = m.group(1)
var_name = m.group(2)
inner = m.group(3)
coords = _axis_coords(axis_name)
inner = m.group(4)
coords = _apply_coord_filter(
m.group(3),
axis_name=axis_name,
all_coords=_axis_coords(axis_name),
)
terms: list[str] = []
for coord in coords:
# Replace var_name occurrences with coord (as identifier-safe string)
Expand Down
28 changes: 28 additions & 0 deletions tests/op_system/test_op_system_compile.py
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's error like for, say, a missing close ]? we should figure out (perhaps not in this PR) how to handle that sort of problem gracefully and with useful feedback.

Original file line number Diff line number Diff line change
Expand Up @@ -370,3 +370,31 @@ def test_compile_spec_preserves_meta() -> None:

assert "axes" in compiled.meta
assert compiled.meta["axes"][0]["name"] == "space"


def test_sum_over_in_filter_evaluates_correctly() -> None:
"""sum_over IN filter compiles and evaluates correctly end-to-end."""
spec = {
"kind": "expr",
"axes": [{"name": "vax", "coords": ["u", "v", "w"]}],
"state": ["S[vax]"],
"equations": {
# Constant-rate decay so eval_fn output is predictable
"S[vax]": "-S[vax]",
},
"aliases": {
"covered": "sum_over(vax=j IN [v, w], S[vax=j])",
},
}
rhs = normalize_rhs(spec)

# NormalizedRhs.aliases holds the expanded string
assert "S__vax_v" in rhs.aliases["covered"]
assert "S__vax_w" in rhs.aliases["covered"]
assert "S__vax_u" not in rhs.aliases["covered"]

# Compile and run eval_fn: state = [u=10, v=3, w=7], dS/dt = -S
compiled = compile_rhs(rhs)
state = np.array([10.0, 3.0, 7.0], dtype=np.float64)
derivs = compiled.eval_fn(np.float64(0.0), state)
assert np.allclose(derivs, -state)
96 changes: 96 additions & 0 deletions tests/op_system/test_op_system_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1109,3 +1109,99 @@ def test_coord_shift_rejects_bad_arrow_syntax() -> None:
}
with pytest.raises(ValueError, match="from_coord -> to_coord"):
normalize_transitions_rhs(spec)


# sum_over IN filter tests
# ---------------------------------------------------------------------------


def test_sum_over_in_filter_subsets_coords() -> None:
"""sum_over with IN filter sums only the listed coords."""
spec = {
"kind": "expr",
"axes": [{"name": "vax", "coords": ["u", "v", "w"]}],
"state": ["S[vax]"],
"equations": {
"S[vax]": "-S[vax]",
},
"aliases": {
"covered": "sum_over(vax=j IN [v, w], S[vax=j])",
},
}
out = normalize_expr_rhs(spec)
covered_eq = out.aliases["covered"]
# Only v and w should appear, not u
assert "S__vax_v" in covered_eq
assert "S__vax_w" in covered_eq
assert "S__vax_u" not in covered_eq


def test_sum_over_no_filter_unchanged() -> None:
"""sum_over without IN filter still expands over all coords."""
spec = {
"kind": "expr",
"axes": [{"name": "vax", "coords": ["u", "v", "w"]}],
"state": ["S[vax]"],
"equations": {"S[vax]": "-S[vax]"},
"aliases": {"N": "sum_over(vax=j, S[vax=j])"},
}
out = normalize_expr_rhs(spec)
n_eq = out.aliases["N"]
assert "S__vax_u" in n_eq
assert "S__vax_v" in n_eq
assert "S__vax_w" in n_eq


def test_sum_over_in_filter_single_coord() -> None:
"""sum_over IN filter with a single coord produces a single term."""
spec = {
"kind": "expr",
"axes": [{"name": "vax", "coords": ["u", "v", "w"]}],
"state": ["S[vax]"],
"equations": {"S[vax]": "-S[vax]"},
"aliases": {"just_u": "sum_over(vax=j IN [u], S[vax=j])"},
}
out = normalize_expr_rhs(spec)
eq = out.aliases["just_u"]
assert "S__vax_u" in eq
assert "S__vax_v" not in eq
assert "S__vax_w" not in eq


def test_sum_over_in_filter_rejects_unknown_coord() -> None:
"""sum_over IN filter with an unknown coord raises."""
spec = {
"kind": "expr",
"axes": [{"name": "vax", "coords": ["u", "v", "w"]}],
"state": ["S[vax]"],
"equations": {"S[vax]": "-S[vax]"},
"aliases": {"bad": "sum_over(vax=j IN [v, z], S[vax=j])"},
}
with pytest.raises(ValueError, match=r"unknown coord.*z.*vax"):
normalize_expr_rhs(spec)


def test_sum_over_in_filter_rejects_empty_filter() -> None:
"""sum_over IN [] with no coords raises."""
spec = {
"kind": "expr",
"axes": [{"name": "vax", "coords": ["u", "v", "w"]}],
"state": ["S[vax]"],
"equations": {"S[vax]": "-S[vax]"},
"aliases": {"bad": "sum_over(vax=j IN [], S[vax=j])"},
}
with pytest.raises(ValueError, match=r"IN filter.*empty"):
normalize_expr_rhs(spec)


def test_sum_over_in_filter_rejects_duplicate_coord() -> None:
"""sum_over IN filter with a repeated coord raises."""
spec = {
"kind": "expr",
"axes": [{"name": "vax", "coords": ["u", "v", "w"]}],
"state": ["S[vax]"],
"equations": {"S[vax]": "-S[vax]"},
"aliases": {"bad": "sum_over(vax=j IN [v, v], S[vax=j])"},
}
with pytest.raises(ValueError, match=r"duplicate coord"):
normalize_expr_rhs(spec)
Loading