diff --git a/src/op_system/specs.py b/src/op_system/specs.py index 6eac7c8..60e662d 100644 --- a/src/op_system/specs.py +++ b/src/op_system/specs.py @@ -53,7 +53,8 @@ re.DOTALL, ) _SUM_OVER_RE = re.compile( - r"sum_over\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*=\s*([A-Za-z_][A-Za-z0-9_]*)\s*,\s*(.*?)\)", + r"sum_over\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*=\s*([A-Za-z_][A-Za-z0-9_]*)" + r"(?:\s+IN\s+\[([^\[\]]*)\])?\s*,\s*(.*?)\)", re.DOTALL, ) _PLACEHOLDER_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_]*\[(.*?)\]") @@ -1413,6 +1414,49 @@ def _template_replacer( return out +def _apply_coord_filter( + filter_str: str | None, + *, + axis_name: str, + all_coords: list[str], +) -> list[str]: + """Return the coord subset to iterate over, or all coords if no filter. + + Args: + filter_str: Raw ``IN [...]`` content (e.g. ``"v, w"``), or ``None``. + axis_name: Axis name — used in error messages only. + all_coords: All valid coords for the axis. + + Returns: + Filtered coord list (preserves order of the filter, not the axis). + """ + if filter_str is None: + return all_coords + requested = [c.strip() for c in filter_str.split(",") if c.strip()] + if not requested: + _raise_invalid_rhs_spec( + detail=f"sum_over IN filter for axis {axis_name!r} is empty", + ) + seen: set[str] = set() + for coord in requested: + if coord in seen: + _raise_invalid_rhs_spec( + detail=( + f"sum_over IN filter for axis {axis_name!r} " + f"contains duplicate coord {coord!r}" + ), + ) + seen.add(coord) + if coord not in all_coords: + _raise_invalid_rhs_spec( + detail=( + f"sum_over IN filter references unknown coord {coord!r} " + f"for axis {axis_name!r} (valid: {all_coords})" + ), + ) + return requested + + def _expand_sum_over(expr: str, *, axes: list[dict[str, Any]]) -> str: """Expand sum_over(axis=var, inner_expr) for categorical axes. @@ -1447,8 +1491,12 @@ def _axis_coords(ax_name: str) -> list[str]: break axis_name = m.group(1) var_name = m.group(2) - inner = m.group(3) - coords = _axis_coords(axis_name) + inner = m.group(4) + coords = _apply_coord_filter( + m.group(3), + axis_name=axis_name, + all_coords=_axis_coords(axis_name), + ) terms: list[str] = [] for coord in coords: # Replace var_name occurrences with coord (as identifier-safe string) diff --git a/tests/op_system/test_op_system_compile.py b/tests/op_system/test_op_system_compile.py index a278ea1..5b53a52 100644 --- a/tests/op_system/test_op_system_compile.py +++ b/tests/op_system/test_op_system_compile.py @@ -370,3 +370,31 @@ def test_compile_spec_preserves_meta() -> None: assert "axes" in compiled.meta assert compiled.meta["axes"][0]["name"] == "space" + + +def test_sum_over_in_filter_evaluates_correctly() -> None: + """sum_over IN filter compiles and evaluates correctly end-to-end.""" + spec = { + "kind": "expr", + "axes": [{"name": "vax", "coords": ["u", "v", "w"]}], + "state": ["S[vax]"], + "equations": { + # Constant-rate decay so eval_fn output is predictable + "S[vax]": "-S[vax]", + }, + "aliases": { + "covered": "sum_over(vax=j IN [v, w], S[vax=j])", + }, + } + rhs = normalize_rhs(spec) + + # NormalizedRhs.aliases holds the expanded string + assert "S__vax_v" in rhs.aliases["covered"] + assert "S__vax_w" in rhs.aliases["covered"] + assert "S__vax_u" not in rhs.aliases["covered"] + + # Compile and run eval_fn: state = [u=10, v=3, w=7], dS/dt = -S + compiled = compile_rhs(rhs) + state = np.array([10.0, 3.0, 7.0], dtype=np.float64) + derivs = compiled.eval_fn(np.float64(0.0), state) + assert np.allclose(derivs, -state) diff --git a/tests/op_system/test_op_system_specs.py b/tests/op_system/test_op_system_specs.py index 2cc148a..85ab7f7 100644 --- a/tests/op_system/test_op_system_specs.py +++ b/tests/op_system/test_op_system_specs.py @@ -1109,3 +1109,99 @@ def test_coord_shift_rejects_bad_arrow_syntax() -> None: } with pytest.raises(ValueError, match="from_coord -> to_coord"): normalize_transitions_rhs(spec) + + +# sum_over IN filter tests +# --------------------------------------------------------------------------- + + +def test_sum_over_in_filter_subsets_coords() -> None: + """sum_over with IN filter sums only the listed coords.""" + spec = { + "kind": "expr", + "axes": [{"name": "vax", "coords": ["u", "v", "w"]}], + "state": ["S[vax]"], + "equations": { + "S[vax]": "-S[vax]", + }, + "aliases": { + "covered": "sum_over(vax=j IN [v, w], S[vax=j])", + }, + } + out = normalize_expr_rhs(spec) + covered_eq = out.aliases["covered"] + # Only v and w should appear, not u + assert "S__vax_v" in covered_eq + assert "S__vax_w" in covered_eq + assert "S__vax_u" not in covered_eq + + +def test_sum_over_no_filter_unchanged() -> None: + """sum_over without IN filter still expands over all coords.""" + spec = { + "kind": "expr", + "axes": [{"name": "vax", "coords": ["u", "v", "w"]}], + "state": ["S[vax]"], + "equations": {"S[vax]": "-S[vax]"}, + "aliases": {"N": "sum_over(vax=j, S[vax=j])"}, + } + out = normalize_expr_rhs(spec) + n_eq = out.aliases["N"] + assert "S__vax_u" in n_eq + assert "S__vax_v" in n_eq + assert "S__vax_w" in n_eq + + +def test_sum_over_in_filter_single_coord() -> None: + """sum_over IN filter with a single coord produces a single term.""" + spec = { + "kind": "expr", + "axes": [{"name": "vax", "coords": ["u", "v", "w"]}], + "state": ["S[vax]"], + "equations": {"S[vax]": "-S[vax]"}, + "aliases": {"just_u": "sum_over(vax=j IN [u], S[vax=j])"}, + } + out = normalize_expr_rhs(spec) + eq = out.aliases["just_u"] + assert "S__vax_u" in eq + assert "S__vax_v" not in eq + assert "S__vax_w" not in eq + + +def test_sum_over_in_filter_rejects_unknown_coord() -> None: + """sum_over IN filter with an unknown coord raises.""" + spec = { + "kind": "expr", + "axes": [{"name": "vax", "coords": ["u", "v", "w"]}], + "state": ["S[vax]"], + "equations": {"S[vax]": "-S[vax]"}, + "aliases": {"bad": "sum_over(vax=j IN [v, z], S[vax=j])"}, + } + with pytest.raises(ValueError, match=r"unknown coord.*z.*vax"): + normalize_expr_rhs(spec) + + +def test_sum_over_in_filter_rejects_empty_filter() -> None: + """sum_over IN [] with no coords raises.""" + spec = { + "kind": "expr", + "axes": [{"name": "vax", "coords": ["u", "v", "w"]}], + "state": ["S[vax]"], + "equations": {"S[vax]": "-S[vax]"}, + "aliases": {"bad": "sum_over(vax=j IN [], S[vax=j])"}, + } + with pytest.raises(ValueError, match=r"IN filter.*empty"): + normalize_expr_rhs(spec) + + +def test_sum_over_in_filter_rejects_duplicate_coord() -> None: + """sum_over IN filter with a repeated coord raises.""" + spec = { + "kind": "expr", + "axes": [{"name": "vax", "coords": ["u", "v", "w"]}], + "state": ["S[vax]"], + "equations": {"S[vax]": "-S[vax]"}, + "aliases": {"bad": "sum_over(vax=j IN [v, v], S[vax=j])"}, + } + with pytest.raises(ValueError, match=r"duplicate coord"): + normalize_expr_rhs(spec)