From 9fc4ccf63e4ff815b9717170ef57f0d6575400a8 Mon Sep 17 00:00:00 2001 From: deanban <3989225+deanban@users.noreply.github.com> Date: Thu, 7 May 2026 13:26:08 -0400 Subject: [PATCH] feat: generic mapping planner contract types and graph storage Signed-off-by: deanban <3989225+deanban@users.noreply.github.com> --- .gitignore | 1 + src/sema/graph/loader.py | 20 +- src/sema/graph/loader_utils.py | 22 +- src/sema/graph/materializer_utils.py | 2 +- src/sema/graph/planner_loader.py | 451 ++++++++++++ src/sema/graph/planner_migrations.py | 266 +++++++ src/sema/models/graph_nodes.py | 36 +- src/sema/models/planner/__init__.py | 4 + src/sema/models/planner/_enums.py | 29 + src/sema/models/planner/_refs.py | 55 ++ src/sema/models/planner/_role_validation.py | 64 ++ src/sema/models/planner/field_map.py | 61 ++ src/sema/models/planner/lifecycle.py | 90 +++ src/sema/models/planner/lifecycle_utils.py | 141 ++++ src/sema/models/planner/mapping_plan.py | 137 ++++ src/sema/models/planner/patterns.py | 229 ++++++ src/sema/models/planner/provenance.py | 197 ++++++ src/sema/models/planner/resolution.py | 170 +++++ src/sema/models/planner/risk.py | 119 ++++ src/sema/models/planner/target_model.py | 104 +++ tests/integration/test_planner_round_trip.py | 418 +++++++++++ tests/unit/models/__init__.py | 0 tests/unit/models/planner/__init__.py | 0 .../models/planner/test_lifecycle_and_pins.py | 318 +++++++++ .../models/planner/test_mapping_patterns.py | 306 ++++++++ .../unit/models/planner/test_mapping_plan.py | 368 ++++++++++ .../models/planner/test_package_surface.py | 27 + .../models/planner/test_planner_storage.py | 665 ++++++++++++++++++ .../planner/test_provenance_and_caching.py | 246 +++++++ .../models/planner/test_resolution_planner.py | 263 +++++++ .../models/planner/test_risk_and_evidence.py | 234 ++++++ .../unit/models/planner/test_target_model.py | 352 +++++++++ tests/unit/test_graph_loader.py | 38 + tests/unit/test_graph_nodes.py | 11 +- tests/unit/test_graph_source_schema.py | 
40 ++ 35 files changed, 5471 insertions(+), 13 deletions(-) create mode 100644 src/sema/graph/planner_loader.py create mode 100644 src/sema/graph/planner_migrations.py create mode 100644 src/sema/models/planner/__init__.py create mode 100644 src/sema/models/planner/_enums.py create mode 100644 src/sema/models/planner/_refs.py create mode 100644 src/sema/models/planner/_role_validation.py create mode 100644 src/sema/models/planner/field_map.py create mode 100644 src/sema/models/planner/lifecycle.py create mode 100644 src/sema/models/planner/lifecycle_utils.py create mode 100644 src/sema/models/planner/mapping_plan.py create mode 100644 src/sema/models/planner/patterns.py create mode 100644 src/sema/models/planner/provenance.py create mode 100644 src/sema/models/planner/resolution.py create mode 100644 src/sema/models/planner/risk.py create mode 100644 src/sema/models/planner/target_model.py create mode 100644 tests/integration/test_planner_round_trip.py create mode 100644 tests/unit/models/__init__.py create mode 100644 tests/unit/models/planner/__init__.py create mode 100644 tests/unit/models/planner/test_lifecycle_and_pins.py create mode 100644 tests/unit/models/planner/test_mapping_patterns.py create mode 100644 tests/unit/models/planner/test_mapping_plan.py create mode 100644 tests/unit/models/planner/test_package_surface.py create mode 100644 tests/unit/models/planner/test_planner_storage.py create mode 100644 tests/unit/models/planner/test_provenance_and_caching.py create mode 100644 tests/unit/models/planner/test_resolution_planner.py create mode 100644 tests/unit/models/planner/test_risk_and_evidence.py create mode 100644 tests/unit/models/planner/test_target_model.py diff --git a/.gitignore b/.gitignore index b675a5c..c7f8f9f 100644 --- a/.gitignore +++ b/.gitignore @@ -44,3 +44,4 @@ backups/ # Eval run artifacts — local by default (previously-tracked runs stay tracked) eval-runs/ +.runs/ diff --git a/src/sema/graph/loader.py b/src/sema/graph/loader.py 
index f1d9d26..a2c85d4 100644 --- a/src/sema/graph/loader.py +++ b/src/sema/graph/loader.py @@ -137,9 +137,10 @@ def batch_upsert_properties( def batch_upsert_terms( self, terms: list[dict[str, Any]], + source_schema: str | None = None, ) -> None: from sema.graph import loader_utils as _lu - _lu.batch_upsert_terms(self, terms) + _lu.batch_upsert_terms(self, terms, source_schema=source_schema) def batch_upsert_aliases( self, aliases: list[dict[str, Any]], parent_label: str, @@ -170,7 +171,9 @@ def upsert_entity( "ON CREATE SET e.id = $id " "SET e.description = $description, e.source = $source, " "e.confidence = $confidence, " - "e.resolved_at = $resolved_at " + "e.resolved_at = $resolved_at, " + "e.model_role = coalesce(e.model_role, 'SOURCE'), " + "e.source_id = coalesce(e.source_id, $source_schema, $source) " "WITH e " "MERGE (t:Table {name: $table_name, " "schema_name: $schema_name, catalog: $catalog}) " @@ -197,9 +200,13 @@ def upsert_property( "SET p.semantic_type = $semantic_type, " "p.source = $source, " "p.confidence = $confidence, " - "p.resolved_at = $resolved_at " + "p.resolved_at = $resolved_at, " + "p.model_role = coalesce(p.model_role, 'SOURCE'), " + "p.source_id = coalesce(p.source_id, $source_schema, $source) " "WITH p " "MERGE (e:Entity {name: $entity_name}) " + "SET e.model_role = coalesce(e.model_role, 'SOURCE'), " + "e.source_id = coalesce(e.source_id, $source_schema, $source) " "MERGE (e)-[hp:HAS_PROPERTY " "{source_schema: $source_schema}]->(p) " "WITH p " @@ -219,6 +226,7 @@ def upsert_property( def upsert_term( self, code: str, label: str, source: str, confidence: float, + source_schema: str | None = None, ) -> None: id_ = str(uuid.uuid4()) self._run( @@ -226,11 +234,13 @@ def upsert_term( "ON CREATE SET t.id = $id " "SET t.label = $label, t.source = $source, " "t.confidence = $confidence, " - "t.resolved_at = $resolved_at", + "t.resolved_at = $resolved_at, " + "t.model_role = coalesce(t.model_role, 'SOURCE'), " + "t.source_id = 
coalesce(t.source_id, $source_schema, $source)", code=code, label=label, source=source, confidence=confidence, resolved_at=datetime.now(timezone.utc).isoformat(), - id=id_, + id=id_, source_schema=source_schema, ) def upsert_value_set( diff --git a/src/sema/graph/loader_utils.py b/src/sema/graph/loader_utils.py index 21ea372..826ab3f 100644 --- a/src/sema/graph/loader_utils.py +++ b/src/sema/graph/loader_utils.py @@ -82,7 +82,9 @@ def batch_upsert_entities( "SET e.description = r.description, e.source = r.source, " "e.confidence = r.confidence, " "e.status = 'ACTIVE', " - "e.resolved_at = r.resolved_at " + "e.resolved_at = r.resolved_at, " + "e.model_role = coalesce(e.model_role, 'SOURCE'), " + "e.source_id = coalesce(e.source_id, r.source_schema, r.source) " "WITH e, r " "MERGE (t:Table {name: r.table_name, " "schema_name: r.schema_name, catalog: r.catalog}) " @@ -108,9 +110,13 @@ def batch_upsert_properties( "p.source = r.source, " "p.confidence = r.confidence, " "p.status = 'ACTIVE', " - "p.resolved_at = r.resolved_at " + "p.resolved_at = r.resolved_at, " + "p.model_role = coalesce(p.model_role, 'SOURCE'), " + "p.source_id = coalesce(p.source_id, r.source_schema, r.source) " "WITH p, r " "MERGE (e:Entity {name: r.entity_name}) " + "SET e.model_role = coalesce(e.model_role, 'SOURCE'), " + "e.source_id = coalesce(e.source_id, r.source_schema, r.source) " "MERGE (e)-[hp:HAS_PROPERTY " "{source_schema: r.source_schema}]->(p) " "WITH p, r " @@ -125,12 +131,18 @@ def batch_upsert_properties( def batch_upsert_terms( loader: GraphLoader, terms: list[dict[str, Any]], + source_schema: str | None = None, ) -> None: if not terms: return resolved_at = datetime.now(timezone.utc).isoformat() rows = [ - {**t, "resolved_at": resolved_at, "id": str(uuid.uuid4())} + { + **t, + "resolved_at": resolved_at, + "id": str(uuid.uuid4()), + "source_schema": source_schema, + } for t in terms ] loader._run( @@ -141,7 +153,9 @@ def batch_upsert_terms( "t.confidence = r.confidence, " 
"t.vocabulary_name = r.vocabulary_name, " "t.status = 'ACTIVE', " - "t.resolved_at = r.resolved_at", + "t.resolved_at = r.resolved_at, " + "t.model_role = coalesce(t.model_role, 'SOURCE'), " + "t.source_id = coalesce(t.source_id, r.source_schema, r.source)", rows=rows, ) diff --git a/src/sema/graph/materializer_utils.py b/src/sema/graph/materializer_utils.py index 3f010a0..a4646b1 100644 --- a/src/sema/graph/materializer_utils.py +++ b/src/sema/graph/materializer_utils.py @@ -278,7 +278,7 @@ def upsert_decoded_values( batch_upsert_value_sets( loader, vs_batch, source_schema=source_schema, ) - batch_upsert_terms(loader, term_batch) + batch_upsert_terms(loader, term_batch, source_schema=source_schema) def _collect_alias_batch( diff --git a/src/sema/graph/planner_loader.py b/src/sema/graph/planner_loader.py new file mode 100644 index 0000000..02eea12 --- /dev/null +++ b/src/sema/graph/planner_loader.py @@ -0,0 +1,451 @@ +"""Round-trip helpers between Pydantic planner models and Neo4j properties. + +This module owns the native-property storage layout for `Provenance`: +prefix `prov_run_*` for `RunProvenance`, `prov_source_*` for `SourceScope`, +and `prov_timestamp` for the per-call timestamp. Same layout applies to +`MappingAssertion`, `ResolutionPlan`, `RiskFlag`, and +`HumanPin.confirmed_under`. + +It also owns the app-layer guard for relationship-target model_role rules +declared in planner-graph-storage spec 8.5. Neo4j Community lacks +relationship-target property constraints; APOC triggers fill the gap when +available (see `planner_migrations.cypher_up(apoc=True)`), and these write +helpers enforce the same rules from Python regardless of edition. 
+""" + +from __future__ import annotations + +import json +from datetime import datetime +from typing import Any + +from sema.models.planner._enums import ModelRole +from sema.models.planner._role_validation import ( + require_property_role_for_relationship, +) +from sema.models.planner.field_map import FieldMap, RowIdentity +from sema.models.planner.lifecycle import HumanPin, PinState, Status +from sema.models.planner.mapping_plan import MappingAssertion, MappingPlan +from sema.models.planner.patterns import ( + MappingPattern, + PatternPayload, + expected_payload_type, +) +from sema.models.planner.provenance import ( + Provenance, + RunProvenance, + SourceScope, +) +from sema.models.planner.resolution import ResolutionPlan +from sema.models.planner.risk import RiskFlag +from sema.models.planner.target_model import TargetObligation + + +_RUN_FIELDS = ( + "run_id", + "target_model_version", + "target_schema_snapshot_hash", + "vocab_release", + "context_card_version", + "prompt_template_version", + "few_shot_set_version", + "constraint_version", + "llm_model", + "embedding_model", +) +_SOURCE_FIELDS = ("source_id", "source_schema_hash", "source_profile_hash") + + +def provenance_to_properties(prov: Provenance) -> dict[str, Any]: + props: dict[str, Any] = {} + for f in _RUN_FIELDS: + props[f"prov_run_{f}"] = getattr(prov.run, f) + for f in _SOURCE_FIELDS: + props[f"prov_source_{f}"] = getattr(prov.source, f) + props["prov_timestamp"] = prov.timestamp.isoformat() + return props + + +def properties_to_provenance(props: dict[str, Any]) -> Provenance: + run = RunProvenance(**{f: props[f"prov_run_{f}"] for f in _RUN_FIELDS}) + source = SourceScope(**{f: props[f"prov_source_{f}"] for f in _SOURCE_FIELDS}) + timestamp = datetime.fromisoformat(props["prov_timestamp"]) + return Provenance(run=run, source=source, timestamp=timestamp) + + +def confirmed_under_to_properties( + run: RunProvenance, source: SourceScope +) -> dict[str, Any]: + props: dict[str, Any] = {} + for f in 
_RUN_FIELDS: + props[f"prov_run_{f}"] = getattr(run, f) + for f in _SOURCE_FIELDS: + props[f"prov_source_{f}"] = getattr(source, f) + return props + + +def properties_to_confirmed_under( + props: dict[str, Any], +) -> tuple[RunProvenance, SourceScope]: + run = RunProvenance(**{f: props[f"prov_run_{f}"] for f in _RUN_FIELDS}) + source = SourceScope(**{f: props[f"prov_source_{f}"] for f in _SOURCE_FIELDS}) + return run, source + + +def cypher_create_field_map_maps_to( + field_map_id: str, target_property_id: str, target_role: ModelRole +) -> tuple[str, dict[str, str]]: + require_property_role_for_relationship("MAPS_TO", target_role) + stmt = ( + "MATCH (f:FieldMap {id: $fm_id}), (p:Property {id: $p_id}) " + "MERGE (f)-[:MAPS_TO]->(p)" + ) + return stmt, {"fm_id": field_map_id, "p_id": target_property_id} + + +def cypher_create_field_map_derived_from( + field_map_id: str, source_property_id: str, source_role: ModelRole +) -> tuple[str, dict[str, str]]: + require_property_role_for_relationship("DERIVED_FROM", source_role) + stmt = ( + "MATCH (f:FieldMap {id: $fm_id}), (p:Property {id: $p_id}) " + "MERGE (f)-[:DERIVED_FROM]->(p)" + ) + return stmt, {"fm_id": field_map_id, "p_id": source_property_id} + + +def cypher_create_plan_has_lineage( + plan_id: str, source_property_id: str, source_role: ModelRole +) -> tuple[str, dict[str, str]]: + require_property_role_for_relationship("HAS_LINEAGE", source_role) + stmt = ( + "MATCH (m:MappingPlan {id: $plan_id}), (p:Property {id: $p_id}) " + "MERGE (m)-[:HAS_LINEAGE]->(p)" + ) + return stmt, {"plan_id": plan_id, "p_id": source_property_id} + + +def cypher_create_resolution_input( + resolution_plan_id: str, source_property_id: str, source_role: ModelRole +) -> tuple[str, dict[str, str]]: + require_property_role_for_relationship("RESOLUTION_INPUT", source_role) + stmt = ( + "MATCH (r:ResolutionPlan {id: $rp_id}), (p:Property {id: $p_id}) " + "MERGE (r)-[:RESOLUTION_INPUT]->(p)" + ) + return stmt, {"rp_id": resolution_plan_id, 
"p_id": source_property_id} + + +def mapping_assertion_to_properties(a: MappingAssertion) -> dict[str, Any]: + props = provenance_to_properties(a.provenance) + props.update( + id=a.id, + source_field_ref=a.source_field_ref, + target_property_ref=a.target_property_ref, + pattern=a.pattern.value, + payload_json=a.payload.model_dump_json(), + confidence=a.confidence, + status=a.status.value, + risk_flags_json=json.dumps( + [rf.model_dump(mode="json") for rf in a.risk_flags] + ), + concerns_text=a.concerns_text, + ) + return props + + +def properties_to_mapping_assertion(props: dict[str, Any]) -> MappingAssertion: + from typing import cast + + pattern = MappingPattern(props["pattern"]) + payload_cls = expected_payload_type(pattern) + payload = cast( + PatternPayload, payload_cls.model_validate_json(props["payload_json"]) + ) + risk_flags = [ + RiskFlag.model_validate(rf) + for rf in json.loads(props.get("risk_flags_json") or "[]") + ] + return MappingAssertion( + id=props["id"], + source_field_ref=props["source_field_ref"], + target_property_ref=props["target_property_ref"], + pattern=pattern, + payload=payload, + confidence=props["confidence"], + status=Status(props["status"]), + risk_flags=risk_flags, + provenance=properties_to_provenance(props), + concerns_text=props.get("concerns_text"), + ) + + +def field_map_to_properties(fm: FieldMap) -> dict[str, Any]: + return { + "target_field_ref": fm.target_field_ref, + "pattern": fm.pattern.value, + "payload_json": fm.payload.model_dump_json(), + } + + +def properties_to_field_map(props: dict[str, Any]) -> FieldMap: + return FieldMap.model_validate( + { + "target_field_ref": props["target_field_ref"], + "pattern": props["pattern"], + "payload": json.loads(props["payload_json"]), + } + ) + + +def target_obligation_to_properties(o: TargetObligation) -> dict[str, Any]: + return { + "target_entity": o.target_entity, + "primary_key": o.primary_key.value, + "obligation_json": o.model_dump_json(), + } + + +def 
properties_to_target_obligation(props: dict[str, Any]) -> TargetObligation: + return TargetObligation.model_validate_json(props["obligation_json"]) + + +def mapping_plan_to_properties(p: MappingPlan) -> dict[str, Any]: + return { + "id": p.id, + "source_scope_ref": p.source_scope_ref, + "plan_verdict": p.derive_verdict().value, + "obligation_json": p.obligation.model_dump_json(), + "row_identity_json": p.row_identity.model_dump_json(), + "field_maps_json": json.dumps( + [fm.model_dump(mode="json") for fm in p.field_maps] + ), + "risk_flags_json": json.dumps( + [rf.model_dump(mode="json") for rf in p.risk_flags] + ), + "lineage_json": json.dumps(p.lineage), + } + + +def properties_to_mapping_plan(props: dict[str, Any]) -> MappingPlan: + return MappingPlan( + id=props["id"], + source_scope_ref=props["source_scope_ref"], + obligation=TargetObligation.model_validate_json(props["obligation_json"]), + row_identity=RowIdentity.model_validate_json(props["row_identity_json"]), + field_maps=[ + FieldMap.model_validate(fm) + for fm in json.loads(props.get("field_maps_json") or "[]") + ], + risk_flags=[ + RiskFlag.model_validate(rf) + for rf in json.loads(props.get("risk_flags_json") or "[]") + ], + lineage=json.loads(props.get("lineage_json") or "[]"), + ) + + +def resolution_plan_to_properties(rp: ResolutionPlan) -> dict[str, Any]: + props: dict[str, Any] = { + f"prov_run_{f}": getattr(rp.provenance_run, f) for f in _RUN_FIELDS + } + props.update( + id=rp.id, + target_identity_ref=rp.target_identity_ref, + strategy=rp.strategy.value, + confidence=rp.confidence, + status=rp.status.value, + plan_json=rp.model_dump_json(), + prov_timestamp=rp.timestamp.isoformat(), + ) + return props + + +def properties_to_resolution_plan(props: dict[str, Any]) -> ResolutionPlan: + return ResolutionPlan.model_validate_json(props["plan_json"]) + + +def risk_flag_to_properties(rf: RiskFlag) -> dict[str, Any]: + return { + "code": rf.code.value, + "severity": rf.severity.value, + "source_stage": 
rf.source_stage.value, + "suggested_action": rf.suggested_action.value, + "evidence_json": json.dumps( + [e.model_dump(mode="json") for e in rf.evidence] + ), + "flag_json": rf.model_dump_json(), + } + + +def properties_to_risk_flag(props: dict[str, Any]) -> RiskFlag: + return RiskFlag.model_validate_json(props["flag_json"]) + + +def human_pin_to_properties(pin: HumanPin) -> dict[str, Any]: + props = confirmed_under_to_properties( + pin.confirmed_under_run, pin.confirmed_under_source + ) + props.update( + id=pin.pin_id, + pin_id=pin.pin_id, + assertion_id=pin.assertion_id, + resolution_plan_id=pin.resolution_plan_id, + pinned_at=pin.pinned_at.isoformat(), + pinned_by=pin.pinned_by, + pin_state=pin.pin_state.value, + expires_on_change_of=list(pin.expires_on_change_of), + ) + return props + + +def properties_to_human_pin(props: dict[str, Any]) -> HumanPin: + run, source = properties_to_confirmed_under(props) + return HumanPin( + pin_id=props["pin_id"], + assertion_id=props.get("assertion_id"), + resolution_plan_id=props.get("resolution_plan_id"), + pinned_at=datetime.fromisoformat(props["pinned_at"]), + pinned_by=props["pinned_by"], + confirmed_under_run=run, + confirmed_under_source=source, + pin_state=PinState(props["pin_state"]), + expires_on_change_of=list(props["expires_on_change_of"]), + ) + + +# --- Cypher write/read helpers (driver-aware wrappers around serializers). +# +# These wrap a Neo4j ``Session`` so callers don't have to know labels or +# write Cypher inline. Each ``write_*`` / ``read_*`` pair satisfies spec 8.8 +# "write helpers, read-back helpers, and round-trip serialization for each +# planner node kind." 
+ + +def write_mapping_assertion(session: Any, a: MappingAssertion) -> None: + session.run( + "MERGE (n:MappingAssertion {id: $props.id}) SET n = $props", + props=mapping_assertion_to_properties(a), + ) + + +def read_mapping_assertion(session: Any, assertion_id: str) -> MappingAssertion: + row = session.run( + "MATCH (n:MappingAssertion {id: $id}) RETURN properties(n) AS p", + id=assertion_id, + ).single() + if row is None: + raise LookupError(f"MappingAssertion id={assertion_id!r} not found") + return properties_to_mapping_assertion(row["p"]) + + +def write_mapping_plan(session: Any, p: MappingPlan) -> None: + session.run( + "MERGE (n:MappingPlan {id: $props.id}) SET n = $props", + props=mapping_plan_to_properties(p), + ) + + +def read_mapping_plan(session: Any, plan_id: str) -> MappingPlan: + row = session.run( + "MATCH (n:MappingPlan {id: $id}) RETURN properties(n) AS p", + id=plan_id, + ).single() + if row is None: + raise LookupError(f"MappingPlan id={plan_id!r} not found") + return properties_to_mapping_plan(row["p"]) + + +def write_resolution_plan(session: Any, rp: ResolutionPlan) -> None: + session.run( + "MERGE (n:ResolutionPlan {id: $props.id}) SET n = $props", + props=resolution_plan_to_properties(rp), + ) + + +def read_resolution_plan(session: Any, plan_id: str) -> ResolutionPlan: + row = session.run( + "MATCH (n:ResolutionPlan {id: $id}) RETURN properties(n) AS p", + id=plan_id, + ).single() + if row is None: + raise LookupError(f"ResolutionPlan id={plan_id!r} not found") + return properties_to_resolution_plan(row["p"]) + + +def write_human_pin(session: Any, pin: HumanPin) -> None: + session.run( + "MERGE (n:HumanPin {id: $props.id}) SET n = $props", + props=human_pin_to_properties(pin), + ) + + +def read_human_pin(session: Any, pin_id: str) -> HumanPin: + row = session.run( + "MATCH (n:HumanPin {id: $id}) RETURN properties(n) AS p", + id=pin_id, + ).single() + if row is None: + raise LookupError(f"HumanPin id={pin_id!r} not found") + return 
properties_to_human_pin(row["p"]) + + +def write_target_obligation( + session: Any, obligation: TargetObligation, *, obligation_id: str +) -> None: + props = target_obligation_to_properties(obligation) + props["id"] = obligation_id + session.run( + "MERGE (n:TargetObligation {id: $props.id}) SET n = $props", + props=props, + ) + + +def read_target_obligation( + session: Any, obligation_id: str +) -> TargetObligation: + row = session.run( + "MATCH (n:TargetObligation {id: $id}) RETURN properties(n) AS p", + id=obligation_id, + ).single() + if row is None: + raise LookupError(f"TargetObligation id={obligation_id!r} not found") + return properties_to_target_obligation(row["p"]) + + +def write_risk_flag(session: Any, rf: RiskFlag, *, flag_id: str) -> None: + props = risk_flag_to_properties(rf) + props["id"] = flag_id + session.run( + "MERGE (n:RiskFlag {id: $props.id}) SET n = $props", + props=props, + ) + + +def read_risk_flag(session: Any, flag_id: str) -> RiskFlag: + row = session.run( + "MATCH (n:RiskFlag {id: $id}) RETURN properties(n) AS p", + id=flag_id, + ).single() + if row is None: + raise LookupError(f"RiskFlag id={flag_id!r} not found") + return properties_to_risk_flag(row["p"]) + + +def write_field_map(session: Any, fm: FieldMap, *, field_map_id: str) -> None: + props = field_map_to_properties(fm) + props["id"] = field_map_id + session.run( + "MERGE (n:FieldMap {id: $props.id}) SET n = $props", + props=props, + ) + + +def read_field_map(session: Any, field_map_id: str) -> FieldMap: + row = session.run( + "MATCH (n:FieldMap {id: $id}) RETURN properties(n) AS p", + id=field_map_id, + ).single() + if row is None: + raise LookupError(f"FieldMap id={field_map_id!r} not found") + return properties_to_field_map(row["p"]) diff --git a/src/sema/graph/planner_migrations.py b/src/sema/graph/planner_migrations.py new file mode 100644 index 0000000..bf0097b --- /dev/null +++ b/src/sema/graph/planner_migrations.py @@ -0,0 +1,266 @@ +"""Cypher migrations for the 
generic-mapping-planner-contract.""" + +from __future__ import annotations + + +PLANNER_NODE_LABELS = ( + "MappingAssertion", + "MappingPlan", + "FieldMap", + "ResolutionPlan", + "TargetObligation", + "RiskFlag", + "HumanPin", +) + +PLANNER_RELATIONSHIPS = ( + "HAS_OBLIGATION", + "ASSEMBLED_INTO", + "FIELD_MAP_OF", + "MAPS_TO", + "DERIVED_FROM", + "RESOLVED_BY", + "HAS_LINEAGE", + "RAISED_FLAG", + "PINNED", + "CONFLICT_LOSER", + "RESOLUTION_INPUT", +) + + +def cypher_up(*, enterprise: bool = False, apoc: bool = False) -> list[str]: + """Forward migration: add labels, relationships, indexes, model_role backfill. + + Property-existence constraints require Neo4j Enterprise; they are emitted + only when ``enterprise=True``. On Community editions, model_role presence + is enforced at the application layer (Pydantic validators) plus the + backfill statements below. + + Relationship-target role rules from spec 8.5 + (`MAPS_TO`→TARGET, `DERIVED_FROM`/`HAS_LINEAGE`/`RESOLUTION_INPUT`→SOURCE) + cannot be expressed as native Cypher constraints. When ``apoc=True`` the + migration emits APOC `before` triggers that abort transactions creating + role-mismatched relationships. When APOC is unavailable, the same rules + are enforced from Python via `planner_loader.cypher_create_*` helpers. + """ + statements: list[str] = [] + statements.extend(_role_backfill()) + statements.extend(_uniqueness_constraints()) + if enterprise: + statements.extend(_role_existence_constraints()) + statements.extend(_indexes()) + if apoc: + statements.extend(_apoc_relationship_role_triggers()) + statements.extend(_apoc_scoping_id_triggers()) + return statements + + +def cypher_down(*, apoc: bool = False) -> list[str]: + """Reverse migration: drop planner labels, relationships, indexes, triggers. + + When ``apoc=True`` the migration also removes the planner APOC triggers + (the symmetric inverse of ``cypher_up(apoc=True)``). 
Pass ``apoc=False`` + when the upgrade did not install triggers; otherwise the + ``apoc.trigger.remove`` calls fail on a vanilla Neo4j Community node. + """ + drops: list[str] = [] + if apoc: + all_triggers = _APOC_TRIGGER_NAMES + _SCOPING_TRIGGER_NAMES + for trigger_name in all_triggers: + drops.append( + f"CALL apoc.trigger.remove('{trigger_name}') YIELD name " + "RETURN name" + ) + for label in PLANNER_NODE_LABELS: + drops.append(f"DROP CONSTRAINT {label}_id_unique IF EXISTS") + drops.append( + f"MATCH (n:{label}) DETACH DELETE n" + ) + drops.extend( + [ + "DROP INDEX mapping_assertion_run_id IF EXISTS", + "DROP INDEX mapping_assertion_source_id IF EXISTS", + "DROP INDEX mapping_assertion_status IF EXISTS", + "DROP INDEX mapping_plan_verdict IF EXISTS", + "DROP INDEX resolution_plan_status IF EXISTS", + "DROP INDEX resolution_plan_verdict IF EXISTS", + "DROP INDEX risk_flag_code IF EXISTS", + "DROP INDEX human_pin_state IF EXISTS", + "DROP INDEX human_pin_assertion_id IF EXISTS", + "DROP INDEX human_pin_resolution_plan_id IF EXISTS", + "DROP INDEX entity_model_role IF EXISTS", + "DROP INDEX property_model_role IF EXISTS", + ] + ) + return drops + + +def _role_backfill() -> list[str]: + """Backfill model_role + source_id on pre-planner nodes. + + `model_role` defaults to SOURCE. `source_id` is backfilled from any + `source_schema` stamped on the node's incident edges (one representative + edge per node) so the discriminator rule "SOURCE node has source_id" + holds for already-loaded graphs. Nodes without any source_schema-stamped + edge fall back to the legacy `source` field. 
+ """ + backfills: list[str] = [] + for label in ("Entity", "Property", "Term", "Constraint"): + backfills.append( + f"MATCH (n:{label}) WHERE n.model_role IS NULL " + "SET n.model_role = 'SOURCE'" + ) + backfills.append( + f"MATCH (n:{label}) WHERE n.source_id IS NULL " + "AND n.model_role = 'SOURCE' " + "OPTIONAL MATCH (n)-[r]-() WHERE r.source_schema IS NOT NULL " + "WITH n, head(collect(r.source_schema)) AS scope " + "WITH n, coalesce(scope, n.source) AS resolved " + "WHERE resolved IS NOT NULL " + "SET n.source_id = resolved" + ) + return backfills + + +def _uniqueness_constraints() -> list[str]: + return [ + f"CREATE CONSTRAINT {label}_id_unique IF NOT EXISTS " + f"FOR (n:{label}) REQUIRE n.id IS UNIQUE" + for label in PLANNER_NODE_LABELS + ] + + +def _role_existence_constraints() -> list[str]: + return [ + "CREATE CONSTRAINT entity_model_role_exists IF NOT EXISTS " + "FOR (n:Entity) REQUIRE n.model_role IS NOT NULL", + "CREATE CONSTRAINT property_model_role_exists IF NOT EXISTS " + "FOR (n:Property) REQUIRE n.model_role IS NOT NULL", + "CREATE CONSTRAINT term_model_role_exists IF NOT EXISTS " + "FOR (n:Term) REQUIRE n.model_role IS NOT NULL", + "CREATE CONSTRAINT constraint_model_role_exists IF NOT EXISTS " + "FOR (n:Constraint) REQUIRE n.model_role IS NOT NULL", + ] + + +_SCOPING_TRIGGER_NAMES = ( + "planner_no_role_id_collision", + "planner_source_role_requires_source_id", + "planner_target_role_requires_target_model_id", +) + + +def _apoc_scoping_id_triggers() -> list[str]: + """APOC triggers enforcing the source_id/target_model_id discriminator. + + Spec 2.2 + 2.1: a node MUST NOT carry both `source_id` and + `target_model_id`; SOURCE-role nodes MUST carry source_id; TARGET-role + nodes MUST carry target_model_id. The Pydantic models enforce this at + construction; these triggers enforce it for nodes written through any + path, including legacy loaders. 
+ """ + labels = "['Entity','Property','Term','Constraint']" + return [ + f"CALL apoc.trigger.add('planner_no_role_id_collision', \"" + f"UNWIND $createdNodes AS n " + f"WITH n WHERE any(l IN labels(n) WHERE l IN {labels}) " + f"AND n.source_id IS NOT NULL AND n.target_model_id IS NOT NULL " + f"CALL apoc.util.validate(true, " + f"'node MUST NOT carry both source_id and target_model_id', []) " + f"RETURN 1\", {{phase: 'before'}})", + f"CALL apoc.trigger.add('planner_source_role_requires_source_id', \"" + f"UNWIND $createdNodes AS n " + f"WITH n WHERE any(l IN labels(n) WHERE l IN {labels}) " + f"AND coalesce(n.model_role, 'SOURCE') = 'SOURCE' " + f"AND n.source_id IS NULL " + f"CALL apoc.util.validate(true, " + f"'model_role=SOURCE requires source_id', []) " + f"RETURN 1\", {{phase: 'before'}})", + f"CALL apoc.trigger.add('planner_target_role_requires_target_model_id', " + f"\"UNWIND $createdNodes AS n " + f"WITH n WHERE any(l IN labels(n) WHERE l IN {labels}) " + f"AND n.model_role = 'TARGET' " + f"AND n.target_model_id IS NULL " + f"CALL apoc.util.validate(true, " + f"'model_role=TARGET requires target_model_id', []) " + f"RETURN 1\", {{phase: 'before'}})", + ] + + +_APOC_TRIGGER_NAMES = ( + "planner_maps_to_requires_target_property", + "planner_derived_from_requires_source_property", + "planner_has_lineage_requires_source_property", + "planner_resolution_input_requires_source_property", +) + + +def _apoc_relationship_role_triggers() -> list[str]: + """APOC `before` triggers enforcing spec 8.5 relationship-target role rules.""" + return [ + _apoc_role_trigger( + name="planner_maps_to_requires_target_property", + rel_type="MAPS_TO", + required_role="TARGET", + ), + _apoc_role_trigger( + name="planner_derived_from_requires_source_property", + rel_type="DERIVED_FROM", + required_role="SOURCE", + ), + _apoc_role_trigger( + name="planner_has_lineage_requires_source_property", + rel_type="HAS_LINEAGE", + required_role="SOURCE", + ), + _apoc_role_trigger( + 
name="planner_resolution_input_requires_source_property", + rel_type="RESOLUTION_INPUT", + required_role="SOURCE", + ), + ] + + +def _apoc_role_trigger(*, name: str, rel_type: str, required_role: str) -> str: + body = ( + "UNWIND $createdRelationships AS r " + f"WITH r WHERE type(r) = '{rel_type}' " + "WITH r, endNode(r) AS p " + "WHERE p:Property AND coalesce(p.model_role, 'SOURCE') <> " + f"'{required_role}' " + "CALL apoc.util.validate(true, " + f"'{rel_type} requires Property.model_role={required_role}', []) " + "RETURN 1" + ) + return ( + f"CALL apoc.trigger.add('{name}', \"{body}\", {{phase: 'before'}})" + ) + + +def _indexes() -> list[str]: + return [ + "CREATE INDEX mapping_assertion_run_id IF NOT EXISTS " + "FOR (n:MappingAssertion) ON (n.prov_run_run_id)", + "CREATE INDEX mapping_assertion_source_id IF NOT EXISTS " + "FOR (n:MappingAssertion) ON (n.prov_source_source_id)", + "CREATE INDEX mapping_assertion_status IF NOT EXISTS " + "FOR (n:MappingAssertion) ON (n.status)", + "CREATE INDEX mapping_plan_verdict IF NOT EXISTS " + "FOR (n:MappingPlan) ON (n.plan_verdict)", + "CREATE INDEX resolution_plan_status IF NOT EXISTS " + "FOR (n:ResolutionPlan) ON (n.status)", + "CREATE INDEX resolution_plan_verdict IF NOT EXISTS " + "FOR (n:ResolutionPlan) ON (n.resolution_verdict)", + "CREATE INDEX risk_flag_code IF NOT EXISTS " + "FOR (n:RiskFlag) ON (n.code)", + "CREATE INDEX human_pin_state IF NOT EXISTS " + "FOR (n:HumanPin) ON (n.pin_state)", + "CREATE INDEX human_pin_assertion_id IF NOT EXISTS " + "FOR (n:HumanPin) ON (n.assertion_id)", + "CREATE INDEX human_pin_resolution_plan_id IF NOT EXISTS " + "FOR (n:HumanPin) ON (n.resolution_plan_id)", + "CREATE INDEX entity_model_role IF NOT EXISTS " + "FOR (n:Entity) ON (n.model_role)", + "CREATE INDEX property_model_role IF NOT EXISTS " + "FOR (n:Property) ON (n.model_role)", + ] diff --git a/src/sema/models/graph_nodes.py b/src/sema/models/graph_nodes.py index 71b4d0b..0491d02 100644 --- 
a/src/sema/models/graph_nodes.py +++ b/src/sema/models/graph_nodes.py @@ -2,9 +2,15 @@ from datetime import datetime from enum import Enum -from typing import Any +from typing import Any, Self -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator + +from sema.models.planner._enums import ModelRole, TargetArtifactKind +from sema.models.planner._role_validation import ( + require_kind_matches_role, + require_role_identifier, +) class SemanticType(str, Enum): @@ -63,6 +69,16 @@ class Entity(BaseModel): confidence: float = Field(ge=0.0, le=1.0) resolved_at: datetime | None = None embedding_updated_at: datetime | None = None + model_role: ModelRole = ModelRole.SOURCE + source_id: str | None = None + target_model_id: str | None = None + kind: TargetArtifactKind | None = None + + @model_validator(mode="after") + def _validate_role(self) -> Self: + require_role_identifier(self.model_role, self.source_id, self.target_model_id) + require_kind_matches_role(self.model_role, self.kind) + return self class Property(BaseModel): @@ -74,6 +90,14 @@ class Property(BaseModel): confidence: float = Field(ge=0.0, le=1.0) resolved_at: datetime | None = None embedding_updated_at: datetime | None = None + model_role: ModelRole = ModelRole.SOURCE + source_id: str | None = None + target_model_id: str | None = None + + @model_validator(mode="after") + def _validate_role(self) -> Self: + require_role_identifier(self.model_role, self.source_id, self.target_model_id) + return self class Metric(BaseModel): @@ -96,6 +120,14 @@ class Term(BaseModel): confidence: float = Field(ge=0.0, le=1.0) resolved_at: datetime | None = None embedding_updated_at: datetime | None = None + model_role: ModelRole = ModelRole.SOURCE + source_id: str | None = None + target_model_id: str | None = None + + @model_validator(mode="after") + def _validate_role(self) -> Self: + require_role_identifier(self.model_role, self.source_id, self.target_model_id) + return self class 
class ModelRole(str, Enum):
    """Role discriminator for graph nodes: source-side or target-side.

    The string values are what gets compared and stored, so they double as
    the serialized form.
    """

    SOURCE = "SOURCE"
    TARGET = "TARGET"
def require_role_identifier(
    role: ModelRole, source_id: str | None, target_model_id: str | None
) -> None:
    """Check that a node carries exactly the identifier its role demands.

    A SOURCE-role node must have ``source_id`` and no ``target_model_id``;
    a TARGET-role node must have ``target_model_id`` and no ``source_id``.

    Args:
        role: The node's ``model_role``.
        source_id: Source identifier on the node, if any.
        target_model_id: Target-model identifier on the node, if any.

    Raises:
        ValueError: If both identifiers are set, or if the identifier
            required by ``role`` is missing, or if a SOURCE node carries
            ``target_model_id``.
    """
    if source_id is not None and target_model_id is not None:
        raise ValueError(
            "node MUST NOT carry both source_id and target_model_id"
        )
    # Past this point at most one identifier is set. For TARGET that means a
    # present target_model_id already implies source_id is None, so no
    # separate "rejects source_id" check can ever fire.
    if role is ModelRole.TARGET:
        if target_model_id is None:
            raise ValueError("model_role=TARGET requires target_model_id")
        return
    if role is ModelRole.SOURCE:
        if target_model_id is not None:
            raise ValueError("model_role=SOURCE rejects target_model_id")
        if source_id is None:
            raise ValueError("model_role=SOURCE requires source_id")
def coerce_pattern_payload(data: Any) -> Any:
    """Turn a raw ``payload`` dict into the payload model its ``pattern`` names.

    Intended for ``mode="before"`` validators. Input that is not a dict, or
    whose ``pattern`` is not a string, or whose ``payload`` is not a dict, is
    returned untouched so regular field validation can deal with it.

    Args:
        data: Raw validator input.

    Returns:
        ``data`` itself when nothing needs coercing, otherwise a shallow copy
        with ``payload`` replaced by the validated payload model.

    Raises:
        ValueError: If ``pattern`` is not a valid ``MappingPattern`` value
            (raised by the enum lookup) or the payload fails validation.
    """
    if not isinstance(data, dict):
        return data
    pattern = data.get("pattern")
    payload = data.get("payload")
    if not (isinstance(payload, dict) and isinstance(pattern, str)):
        return data
    target = expected_payload_type(MappingPattern(pattern))
    # Build a new mapping instead of assigning into ``data``: the dict belongs
    # to the caller, and rewriting it in place would leave a model instance
    # inside what the caller still believes is plain input data.
    return {**data, "payload": target.model_validate(payload)}
class HumanPin(BaseModel):
    """A human confirmation attached to one assertion or one resolution plan.

    Exactly one of ``assertion_id`` / ``resolution_plan_id`` must be set.
    ``confirmed_under_run`` and ``confirmed_under_source`` record the
    provenance the pin was confirmed under; ``expires_on_change_of`` names
    the provenance dimensions tracked for this pin.
    """

    pin_id: str = Field(min_length=1)
    assertion_id: str | None = None
    resolution_plan_id: str | None = None
    pinned_at: datetime
    pinned_by: str = Field(min_length=1)
    confirmed_under_run: RunProvenance
    confirmed_under_source: SourceScope
    expires_on_change_of: list[str] = Field(
        default_factory=lambda: [*_DEFAULT_EXPIRES_ON_CHANGE_OF]
    )
    pin_state: PinState = PinState.active

    @model_validator(mode="after")
    def _validate_target(self) -> Self:
        # Count how many of the two possible targets are populated; anything
        # other than exactly one (zero or both) is invalid.
        populated = sum(
            ref is not None
            for ref in (self.assertion_id, self.resolution_plan_id)
        )
        if populated != 1:
            raise ValueError("HumanPin MUST reference exactly one of assertion_id or resolution_plan_id")
        return self
def _read_dim(run: RunProvenance, source: SourceScope, dim: str) -> object:
    """Read one provenance dimension, looking at ``run`` first, then ``source``.

    Args:
        run: Run-level provenance.
        source: Source-level provenance.
        dim: Name of a declared field on either model.

    Returns:
        The field's current value.

    Raises:
        ValueError: If ``dim`` is not a declared field of either model.
    """
    # Consult the declared pydantic fields rather than ``hasattr``: every
    # BaseModel also exposes methods and config attributes (``model_dump``,
    # ``model_config``, ...), which ``hasattr`` would accept as "dimensions".
    for holder in (run, source):
        if dim in type(holder).model_fields:
            return getattr(holder, dim)
    raise ValueError(f"unknown provenance dimension: {dim}")
class BuildContext(BaseModel):
    """Pins and rejected pairs consulted while building.

    NOTE(review): ``HumanPin`` is only imported under ``TYPE_CHECKING`` in
    this module — confirm the forward reference is resolved (e.g. via
    ``model_rebuild``) before instances are created.
    """

    pins: list["HumanPin"] = Field(default_factory=list)
    rejected_pairs: list[tuple[str, str]] = Field(default_factory=list)

    def dispatch_decision(self, pin: "HumanPin") -> DispatchDecision:
        """Map a pin's state to a dispatch decision.

        ``active`` and ``revalidated`` pins yield ``SKIP``, ``stale`` pins
        yield ``REVALIDATE``, and every other state yields ``DISPATCH``.
        """
        # Imported here because the lifecycle module imports this one.
        from sema.models.planner.lifecycle import PinState

        if pin.pin_state in (PinState.active, PinState.revalidated):
            return DispatchDecision.SKIP
        if pin.pin_state is PinState.stale:
            return DispatchDecision.REVALIDATE
        return DispatchDecision.DISPATCH

    def is_rejected(self, pair: tuple[str, str]) -> bool:
        """Return whether ``pair`` is listed in ``rejected_pairs``."""
        # Test membership on the list directly; building a throwaway set on
        # every call costs a full pass plus hashing for a single lookup.
        return pair in self.rejected_pairs
class MappingPlan(BaseModel):
    """Field maps, row identity and risk flags gathered for one obligation."""

    id: str = Field(min_length=1)
    source_scope_ref: str = Field(min_length=1)
    obligation: TargetObligation
    row_identity: RowIdentity
    field_maps: list[FieldMap] = Field(default_factory=list)
    risk_flags: list[RiskFlag] = Field(default_factory=list)
    lineage: list[RefStr] = Field(default_factory=list)

    def covered_required_fields(self) -> set[str]:
        """Return the ``target_field_ref`` of every field map in the plan."""
        return {fm.target_field_ref for fm in self.field_maps}

    def derive_verdict(self, *, any_review_pending: bool = False) -> PlanVerdict:
        """Derive the plan verdict from field coverage and risk-flag codes.

        Args:
            any_review_pending: Forwarded to ``derive_plan_verdict``. The plan
                holds no assertion statuses itself, so a caller that knows
                about pending reviews passes ``True`` here; the default keeps
                the previous behaviour.

        Returns:
            The verdict computed by ``derive_plan_verdict``.
        """
        required = set(self.obligation.required_fields)
        covered = self.covered_required_fields()
        codes = {rf.code.value for rf in self.risk_flags}
        return derive_plan_verdict(
            risk_codes=codes,
            # Missing either by actual coverage or by an explicit risk flag.
            obligation_required_missing=(
                not required.issubset(covered)
                or "RISK_OBLIGATION_REQUIRED_FIELD_MISSING" in codes
            ),
            fk_unsatisfied="RISK_OBLIGATION_FK_UNSATISFIED" in codes,
            minimum_viable_row_violated=(
                "RISK_OBLIGATION_MINIMUM_VIABLE_ROW_VIOLATED" in codes
            ),
            any_review_pending=any_review_pending,
            any_resolution_dependency_missing=(
                "RISK_RESOLUTION_DEPENDENCY_MISSING" in codes
            ),
        )
def select_winner(
    assertions: Iterable[MappingAssertion],
    policy: ConflictResolutionPolicy,  # noqa: ARG001 (default policy is the only impl)
) -> MappingAssertion:
    """Pick the highest-ranked non-rejected assertion.

    Ranking is given by ``_sort_key``. When several assertions share the
    highest key, the one that appears first in ``assertions`` wins.

    Args:
        assertions: Competing assertions.
        policy: Accepted for interface stability; not consulted.

    Returns:
        The winning assertion.

    Raises:
        ValueError: If every assertion is rejected or none are given.
    """
    candidates = [a for a in assertions if a.status is not Status.rejected]
    if not candidates:
        raise ValueError("no non-rejected assertions to resolve")
    # ``max`` returns the first maximal item, which is the same element a
    # stable descending sort would place at index 0, without sorting the rest.
    return max(candidates, key=_sort_key)
class MappingPattern(str, Enum):
    """Closed set of mapping patterns.

    Each member has a dedicated payload model; ``expected_payload_type``
    returns it.
    """

    DIRECT_COPY = "DIRECT_COPY"
    CONSTANT = "CONSTANT"
    DERIVED = "DERIVED"
    VOCAB_LOOKUP = "VOCAB_LOOKUP"
    JOIN_LOOKUP = "JOIN_LOOKUP"
    PIVOT = "PIVOT"
    UNPIVOT = "UNPIVOT"
    SPLIT = "SPLIT"
    AGGREGATE = "AGGREGATE"
    ROW_GENERATION = "ROW_GENERATION"
    NO_MAP = "NO_MAP"
class SplitRule(_StrictPayload):
    """Split rule carried by ``SplitMapping.split_rule``.

    NOTE(review): no validator ties ``kind`` to the optional fields below, so
    e.g. ``kind=regex`` with ``pattern=None`` is accepted — confirm whether
    that is intended.
    """

    kind: SplitRuleKind
    pattern: str | None = None
    delimiter: str | None = None
    positions: list[str] | None = None
class RowGenerationMapping(_StrictPayload):
    """Payload for ROW_GENERATION: a generation rule plus the maps it fills.

    ``populated_field_maps`` is typed ``list[Any]`` and normalised to
    ``FieldMap`` instances after validation; ``FieldMap`` is imported inside
    the validator rather than at module level.
    """

    source_scope_ref: RefStr
    generation_rule: GenerationRule
    populated_field_maps: list[Any] = Field(min_length=1)

    @model_validator(mode="after")
    def _validate_field_maps(self) -> Self:
        from sema.models.planner.field_map import FieldMap

        def _as_field_map(position: int, item: Any) -> Any:
            # Instances pass through, dicts are validated, the rest is refused.
            if isinstance(item, FieldMap):
                return item
            if isinstance(item, dict):
                return FieldMap.model_validate(item)
            raise ValueError(
                f"populated_field_maps[{position}] must be a FieldMap (got {type(item).__name__})"
            )

        normalised = [
            _as_field_map(position, item)
            for position, item in enumerate(self.populated_field_maps)
        ]
        # Plain object.__setattr__ stores the list without going through the
        # model's own attribute-assignment handling.
        object.__setattr__(self, "populated_field_maps", normalised)
        return self
class RunVersionLock:
    """Remembers the first RunProvenance bound and rejects any differing one."""

    def __init__(self) -> None:
        self._bound: RunProvenance | None = None

    def bind(self, rp: RunProvenance) -> None:
        """Bind ``rp``, or verify it equals the provenance already bound.

        Raises:
            ValueError: If a provenance is already bound and its dumped
                fields differ from those of ``rp``.
        """
        locked = self._bound
        if locked is None:
            # First bind wins and becomes the reference for later calls.
            self._bound = rp
        elif locked.model_dump() != rp.model_dump():
            raise ValueError(
                "RunProvenance fields must remain constant within run_id; "
                "increment run_id to change run-locked context"
            )
def __init__(self, run_id: str) -> None: + self.run_id = run_id + self._bound: dict[str, SourceScope] = {} + + def bind(self, scope: SourceScope) -> None: + prior = self._bound.get(scope.source_id) + if prior is None: + self._bound[scope.source_id] = scope + return + if prior.model_dump() != scope.model_dump(): + raise ValueError( + f"SourceScope drift for ({self.run_id}, {scope.source_id})" + ) + + +def _stable_digest(text: str) -> str: + return hashlib.sha256(text.encode("utf-8")).hexdigest() + + +class PromptArtifact(BaseModel): + prefix_text: str + prefix_hash: str + suffix_text: str + versions: dict[str, str] = Field(default_factory=dict) + + @model_validator(mode="after") + def _validate_hash(self) -> Self: + expected = _stable_digest(self.prefix_text) + if self.prefix_hash != expected: + raise ValueError("prefix_hash must be sha256(prefix_text)") + return self + + @classmethod + def build( + cls, + prefix_text: str, + suffix_text: str, + versions: dict[str, str] | None = None, + ) -> PromptArtifact: + return cls( + prefix_text=prefix_text, + prefix_hash=_stable_digest(prefix_text), + suffix_text=suffix_text, + versions=versions or {}, + ) + + def assert_source_isolated(self, source_field_ref: str) -> None: + """Reject prefixes contaminated by source-similar few-shots. + + Spec 6.6: prefix_text MUST contain only target-Entity-stable few-shots; + any reference to the per-call source field belongs in suffix_text. If + the source ref leaks into prefix_text, the prefix_hash diverges per + call and breaks LLM prompt caching. 
+ """ + if not source_field_ref: + raise ValueError("source_field_ref must be non-empty") + if source_field_ref in self.prefix_text: + raise ValueError( + f"prefix_text contains source_field_ref={source_field_ref!r}; " + "move source-similar content to suffix_text" + ) + + +_CACHE_KEY_FIELDS = ( + "target_model_version", + "context_card_version", + "prompt_template_version", + "few_shot_set_version", + "llm_model", +) + + +def derive_cache_key( + artifact: PromptArtifact, + run: RunProvenance, + source: SourceScope | None = None, # noqa: ARG001 (intentional: scope is excluded) +) -> str: + parts = [artifact.prefix_hash] + parts.extend(getattr(run, name) for name in _CACHE_KEY_FIELDS) + return _stable_digest("|".join(parts)) + + +def compute_source_profile_hash(signature: dict[str, Any]) -> str: + canonical = json.dumps(signature, sort_keys=True, separators=(",", ":")) + return _stable_digest(canonical) + + +class LLMResponse(BaseModel): + text: str + cache_hit: bool = False + raw_meta: dict[str, Any] = Field(default_factory=dict) + + +@runtime_checkable +class LLMRuntime(Protocol): + name: str + dialect: str + + def call(self, artifact: PromptArtifact) -> LLMResponse: ... + + def cache_directives(self, artifact: PromptArtifact) -> dict[str, str]: ... 
+ + +class _AdapterBase: + dialect: str = "" + + def __init__(self, name: str) -> None: + self.name = name + + def call(self, artifact: PromptArtifact) -> LLMResponse: + return LLMResponse( + text="", + cache_hit=False, + raw_meta={ + "adapter": self.dialect, + "model": self.name, + "prefix_hash": artifact.prefix_hash, + }, + ) + + +class AnthropicCachingAdapter(_AdapterBase): + dialect = "anthropic" + + def cache_directives(self, artifact: PromptArtifact) -> dict[str, str]: + return {"cache_control": "ephemeral", "prefix_hash": artifact.prefix_hash} + + +class MosaicAIAdapter(_AdapterBase): + dialect = "mosaic" + + def cache_directives(self, artifact: PromptArtifact) -> dict[str, str]: + return {"prefix_hash": artifact.prefix_hash} + + +class DeepSeekAdapter(_AdapterBase): + dialect = "deepseek" + + def cache_directives(self, artifact: PromptArtifact) -> dict[str, str]: + return {"prefix_hash": artifact.prefix_hash} diff --git a/src/sema/models/planner/resolution.py b/src/sema/models/planner/resolution.py new file mode 100644 index 0000000..320f114 --- /dev/null +++ b/src/sema/models/planner/resolution.py @@ -0,0 +1,170 @@ +"""resolution-planner capability.""" + +from __future__ import annotations + +from datetime import datetime +from enum import Enum +from typing import Any, Self + +from pydantic import BaseModel, Field, model_validator + +from sema.models.planner._refs import RefStr +from sema.models.planner.lifecycle import Status +from sema.models.planner.provenance import RunProvenance, SourceScope +from sema.models.planner.risk import RiskFlag + + +class ResolutionStrategy(str, Enum): + DETERMINISTIC_HASH = "DETERMINISTIC_HASH" + FUZZY_BLOCK_AND_SCORE = "FUZZY_BLOCK_AND_SCORE" + GRAPH_CLOSURE = "GRAPH_CLOSURE" + MULTI_KEY_UNION = "MULTI_KEY_UNION" + + +class CycleHandling(str, Enum): + REJECT = "REJECT" + BREAK_AT_DEPTH = "BREAK_AT_DEPTH" + MARK_AND_CONTINUE = "MARK_AND_CONTINUE" + + +class ResolutionVerdict(str, Enum): + resolved = "resolved" + ambiguous = 
"ambiguous" + unresolved = "unresolved" + awaiting_review = "awaiting_review" + + +class DeterministicHashPayload(BaseModel): + source_key_refs: list[RefStr] = Field(min_length=1) + + +class FuzzyBlockAndScorePayload(BaseModel): + blocking_keys: list[RefStr] = Field(min_length=1) + similarity_features: list[RefStr] = Field(min_length=1) + + +class GraphClosurePayload(BaseModel): + walk_relationship: str = Field(min_length=1) + max_depth: int | None = Field(default=None, gt=0) + + +class MultiKeyUnionPayload(BaseModel): + source_key_refs: list[RefStr] = Field(min_length=2) + + +ResolutionPayload = ( + DeterministicHashPayload + | FuzzyBlockAndScorePayload + | GraphClosurePayload + | MultiKeyUnionPayload +) + + +_STRATEGY_PAYLOAD_TYPES: dict[ResolutionStrategy, type[BaseModel]] = { + ResolutionStrategy.DETERMINISTIC_HASH: DeterministicHashPayload, + ResolutionStrategy.FUZZY_BLOCK_AND_SCORE: FuzzyBlockAndScorePayload, + ResolutionStrategy.GRAPH_CLOSURE: GraphClosurePayload, + ResolutionStrategy.MULTI_KEY_UNION: MultiKeyUnionPayload, +} + + +class CycleHandlingRule(BaseModel): + handling: CycleHandling + depth: int | None = Field(default=None, gt=0) + + @model_validator(mode="after") + def _validate_depth(self) -> Self: + if self.handling is CycleHandling.BREAK_AT_DEPTH and self.depth is None: + raise ValueError("BREAK_AT_DEPTH requires depth") + return self + + +class ResolutionPlan(BaseModel): + id: str = Field(min_length=1) + sources: list[SourceScope] = Field(min_length=1) + target_identity_ref: RefStr + strategy: ResolutionStrategy + payload: ResolutionPayload + transitive_closure: bool = False + cycle_handling: CycleHandlingRule | None = None + confidence: float = Field(ge=0.0, le=1.0) + risk_flags: list[RiskFlag] = Field(default_factory=list) + provenance_run: RunProvenance + timestamp: datetime + status: Status = Status.candidate + + @model_validator(mode="before") + @classmethod + def _coerce_payload_type(cls, data: Any) -> Any: + if not isinstance(data, 
dict): + return data + strategy = data.get("strategy") + payload = data.get("payload") + if isinstance(payload, dict) and isinstance(strategy, str): + target = _STRATEGY_PAYLOAD_TYPES.get(ResolutionStrategy(strategy)) + if target is not None: + data["payload"] = target.model_validate(payload) + return data + + @model_validator(mode="after") + def _validate_strategy_invariants(self) -> Self: + _validate_payload(self.strategy, self.payload) + _validate_closure(self.strategy, self.transitive_closure, self.cycle_handling) + _validate_confidence(self.strategy, self.confidence) + return self + + +def _validate_payload( + strategy: ResolutionStrategy, payload: ResolutionPayload +) -> None: + expected = _STRATEGY_PAYLOAD_TYPES[strategy] + if not isinstance(payload, expected): + raise ValueError( + f"strategy {strategy.value} requires payload of {expected.__name__}" + ) + + +def _validate_closure( + strategy: ResolutionStrategy, + transitive: bool, + cycle: CycleHandlingRule | None, +) -> None: + if strategy is ResolutionStrategy.GRAPH_CLOSURE and not transitive: + raise ValueError("GRAPH_CLOSURE requires transitive_closure=True") + if transitive and cycle is None: + raise ValueError("transitive_closure=True requires cycle_handling") + if not transitive and cycle is not None: + raise ValueError("cycle_handling rejected when transitive_closure=False") + + +def _validate_confidence( + strategy: ResolutionStrategy, confidence: float +) -> None: + if strategy is ResolutionStrategy.DETERMINISTIC_HASH and confidence != 1.0: + raise ValueError("DETERMINISTIC_HASH requires confidence=1.0") + + +def derive_resolution_verdict( + *, + produced_for_every_input: bool, + ambiguous_assignments: bool, + cycle_blocked: bool, + any_block_flag: bool, + plan_review_pending: bool, +) -> ResolutionVerdict: + if plan_review_pending: + return ResolutionVerdict.awaiting_review + if any_block_flag and cycle_blocked: + return ResolutionVerdict.unresolved + if ambiguous_assignments: + return 
ResolutionVerdict.ambiguous + if not produced_for_every_input: + return ResolutionVerdict.unresolved + if any_block_flag: + return ResolutionVerdict.awaiting_review + return ResolutionVerdict.resolved + + +class ResolutionDependency(BaseModel): + upstream_plan_id: str = Field(min_length=1) + canonical_identity_column: RefStr diff --git a/src/sema/models/planner/risk.py b/src/sema/models/planner/risk.py new file mode 100644 index 0000000..db83f36 --- /dev/null +++ b/src/sema/models/planner/risk.py @@ -0,0 +1,119 @@ +"""risk-and-evidence capability: structured RiskFlag + typed Evidence.""" + +from __future__ import annotations + +from enum import Enum +from typing import Any, Self + +from pydantic import BaseModel, Field, model_validator + + +class Severity(str, Enum): + info = "info" + warn = "warn" + block = "block" + + +class SourceStage(str, Enum): + candidate_gen = "candidate_gen" + producer = "producer" + constraint = "constraint" + verify = "verify" + transform = "transform" + + +class SuggestedAction(str, Enum): + review = "review" + request_more_samples = "request_more_samples" + reject = "reject" + ignore_with_reason = "ignore_with_reason" + + +class EvidenceMode(str, Enum): + RAW = "RAW" + CATEGORICAL = "CATEGORICAL" + HASH = "HASH" + COUNT_ONLY = "COUNT_ONLY" + EXCERPT = "EXCERPT" + + +class SensitivityClass(str, Enum): + PUBLIC = "PUBLIC" + PII = "PII" + PHI = "PHI" + FINANCIAL_RESTRICTED = "FINANCIAL_RESTRICTED" + CONFIDENTIAL = "CONFIDENTIAL" + + +_DEFAULT_MODE: dict[SensitivityClass, EvidenceMode] = { + SensitivityClass.PUBLIC: EvidenceMode.RAW, + SensitivityClass.PII: EvidenceMode.HASH, + SensitivityClass.PHI: EvidenceMode.CATEGORICAL, + SensitivityClass.FINANCIAL_RESTRICTED: EvidenceMode.HASH, + SensitivityClass.CONFIDENTIAL: EvidenceMode.CATEGORICAL, +} + + +def default_evidence_mode(sensitivity: SensitivityClass) -> EvidenceMode: + return _DEFAULT_MODE[sensitivity] + + +class RiskCode(str, Enum): + RISK_VOCAB_DOMAIN_MISMATCH = 
"RISK_VOCAB_DOMAIN_MISMATCH" + RISK_PIVOT_CARDINALITY_UNVERIFIED = "RISK_PIVOT_CARDINALITY_UNVERIFIED" + RISK_TEMPORAL_LOST = "RISK_TEMPORAL_LOST" + RISK_AMBIGUOUS_TARGET = "RISK_AMBIGUOUS_TARGET" + RISK_OBLIGATION_REQUIRED_FIELD_MISSING = "RISK_OBLIGATION_REQUIRED_FIELD_MISSING" + RISK_OBLIGATION_FK_UNSATISFIED = "RISK_OBLIGATION_FK_UNSATISFIED" + RISK_OBLIGATION_MINIMUM_VIABLE_ROW_VIOLATED = ( + "RISK_OBLIGATION_MINIMUM_VIABLE_ROW_VIOLATED" + ) + RISK_DEFAULT_APPLIED = "RISK_DEFAULT_APPLIED" + RISK_RESOLUTION_DEPENDENCY_MISSING = "RISK_RESOLUTION_DEPENDENCY_MISSING" + RISK_LLC_CYCLE_DETECTED = "RISK_LLC_CYCLE_DETECTED" + RISK_ASSEMBLER_CONFLICT_RESOLVED = "RISK_ASSEMBLER_CONFLICT_RESOLVED" + + +class Evidence(BaseModel): + mode: EvidenceMode | None = None + payload: dict[str, Any] + sensitivity_class: SensitivityClass + source_ref: str = Field(min_length=1) + explicit_raw_override: bool = False + + @model_validator(mode="after") + def _validate_mode_payload(self) -> Self: + if self.mode is None: + object.__setattr__( + self, "mode", default_evidence_mode(self.sensitivity_class) + ) + assert self.mode is not None + _validate_mode_payload(self.mode, self.payload) + _validate_phi_raw(self.mode, self.sensitivity_class, self.explicit_raw_override) + return self + + +def _validate_mode_payload(mode: EvidenceMode, payload: dict[str, Any]) -> None: + if mode is EvidenceMode.COUNT_ONLY: + keys = set(payload.keys()) + if not keys.issubset({"count", "distinct"}): + raise ValueError("COUNT_ONLY rejects literal-value payloads") + if mode is EvidenceMode.HASH and "hash" not in payload and "digest" not in payload: + raise ValueError("HASH mode requires a hash/digest payload field") + + +def _validate_phi_raw( + mode: EvidenceMode, sens: SensitivityClass, override: bool +) -> None: + if mode is EvidenceMode.RAW and sens is SensitivityClass.PHI and not override: + raise ValueError( + "RAW mode against PHI requires explicit_raw_override=True" + ) + + +class RiskFlag(BaseModel): + 
code: RiskCode + severity: Severity + evidence: list[Evidence] = Field(default_factory=list) + source_stage: SourceStage + suggested_action: SuggestedAction diff --git a/src/sema/models/planner/target_model.py b/src/sema/models/planner/target_model.py new file mode 100644 index 0000000..1b5b697 --- /dev/null +++ b/src/sema/models/planner/target_model.py @@ -0,0 +1,104 @@ +"""target-model capability: target-side schema graph + obligations.""" + +from __future__ import annotations + +from typing import Any, Literal, Self + +from pydantic import BaseModel, Field, model_validator + +from sema.models.planner._enums import ModelRole, PrimaryKeyStrategy +from sema.models.planner._role_validation import require_role_identifier + + +class Constraint(BaseModel): + id: str + name: str + rule_kind: str + payload: dict[str, Any] = Field(default_factory=dict) + model_role: ModelRole = ModelRole.SOURCE + source_id: str | None = None + target_model_id: str | None = None + + @model_validator(mode="after") + def _validate_role(self) -> Self: + require_role_identifier(self.model_role, self.source_id, self.target_model_id) + return self + + +class ForeignKeyObligation(BaseModel): + referenced_entity: str = Field(min_length=1) + join_keys: list[tuple[str, str]] = Field(min_length=1) + same_build_required: bool = True + + +class DomainConstraint(BaseModel): + property_name: str = Field(min_length=1) + domain_id: str = Field(min_length=1) + + +class FieldPresence(BaseModel): + kind: Literal["presence"] = "presence" + field: str = Field(min_length=1) + + +class FieldEquality(BaseModel): + kind: Literal["equality"] = "equality" + field: str = Field(min_length=1) + value: Any + + +RowClause = FieldPresence | FieldEquality + + +class RowPredicate(BaseModel): + op: Literal["AND", "OR"] + clauses: list[RowClause] = Field(min_length=1) + + def evaluate( + self, + present_fields: set[str], + values: dict[str, Any] | None = None, + ) -> bool: + results = [self._evaluate_clause(c, present_fields, 
values) for c in self.clauses] + return all(results) if self.op == "AND" else any(results) + + @staticmethod + def _evaluate_clause( + clause: RowClause, + present_fields: set[str], + values: dict[str, Any] | None, + ) -> bool: + if isinstance(clause, FieldPresence): + return clause.field in present_fields + if values is None or clause.field not in values: + return False + return bool(values[clause.field] == clause.value) + + +class ExternalSequenceMappingTable(BaseModel): + mapping_table_name: str = Field(min_length=1) + canonical_identity_column: str = Field(min_length=1) + sequence_column: str = Field(min_length=1) + + +class TargetObligation(BaseModel): + target_entity: str = Field(min_length=1) + required_fields: list[str] = Field(min_length=1) + nullable_fields: list[str] = Field(default_factory=list) + primary_key: PrimaryKeyStrategy + external_sequence: ExternalSequenceMappingTable | None = None + foreign_keys: list[ForeignKeyObligation] = Field(default_factory=list) + domain_constraints: list[DomainConstraint] = Field(default_factory=list) + allowed_defaults: dict[str, Any] = Field(default_factory=dict) + minimum_viable_row: RowPredicate | None = None + + @model_validator(mode="after") + def _validate_pk(self) -> Self: + is_external = self.primary_key is PrimaryKeyStrategy.EXTERNAL_SEQUENCE + if is_external and self.external_sequence is None: + raise ValueError("EXTERNAL_SEQUENCE requires external_sequence") + if not is_external and self.external_sequence is not None: + raise ValueError( + f"primary_key={self.primary_key.value} rejects external_sequence" + ) + return self diff --git a/tests/integration/test_planner_round_trip.py b/tests/integration/test_planner_round_trip.py new file mode 100644 index 0000000..4befbed --- /dev/null +++ b/tests/integration/test_planner_round_trip.py @@ -0,0 +1,418 @@ +"""Round-trip integration tests for the planner contract storage layer. 
+ +Covers tasks 8.10-8.15 and 9.1-9.9: persists each planner Pydantic model to +Neo4j via the migration + loader helpers and asserts byte-identity on +structured fields. Skipped automatically without a running Neo4j. +""" + +from __future__ import annotations + +import json +from datetime import datetime, timezone + +import pytest + +pytestmark = pytest.mark.integration + + +from sema.graph.planner_loader import ( + confirmed_under_to_properties, + human_pin_to_properties, + mapping_assertion_to_properties, + properties_to_confirmed_under, + properties_to_human_pin, + properties_to_mapping_assertion, + properties_to_provenance, + provenance_to_properties, + read_human_pin, + read_mapping_assertion, + read_mapping_plan, + read_resolution_plan, + write_human_pin, + write_mapping_assertion, + write_mapping_plan, + write_resolution_plan, +) +from sema.graph.planner_migrations import cypher_down, cypher_up +from sema.models.planner._enums import ( + MaterializationMode, + ModelRole, + PrimaryKeyStrategy, +) +from sema.models.planner.field_map import FieldMap, RowIdentity +from sema.models.planner.lifecycle import ( + HumanPin, + PinState, + PlanVerdict, + Status, +) +from sema.models.planner.mapping_plan import ( + ConflictResolutionPolicy, + MappingAssertion, + MappingPlan, + select_winner, +) +from sema.models.planner.patterns import ( + DirectCopyPayload, + MappingPattern, + VocabLookup, +) +from sema.models.planner.provenance import ( + PromptArtifact, + Provenance, + RunProvenance, + SourceScope, + derive_cache_key, +) +from sema.models.planner.resolution import ( + DeterministicHashPayload, + MultiKeyUnionPayload, + ResolutionPlan, + ResolutionStrategy, + ResolutionVerdict, + derive_resolution_verdict, +) +from sema.models.planner.risk import ( + Evidence, + EvidenceMode, + RiskCode, + RiskFlag, + SensitivityClass, + Severity, + SourceStage, + SuggestedAction, +) +from sema.models.planner.target_model import ( + FieldPresence, + RowPredicate, + TargetObligation, 
+) + + +def _run_prov(**overrides: object) -> RunProvenance: + base: dict[str, object] = dict( + run_id="run-1", + target_model_version="omop-cdm-5.4", + target_schema_snapshot_hash="t", + vocab_release="v1", + context_card_version="cards-v3", + prompt_template_version="tpl-7", + few_shot_set_version="fs-12", + constraint_version="rules-v2", + llm_model="claude-opus-4.7", + embedding_model="bge-large", + ) + base.update(overrides) + return RunProvenance(**base) + + +def _src(source_id: str = "cbioportal_gbm") -> SourceScope: + return SourceScope( + source_id=source_id, + source_schema_hash=f"s-{source_id}", + source_profile_hash=f"p-{source_id}", + ) + + +def _provenance(source_id: str = "cbioportal_gbm") -> Provenance: + return Provenance( + run=_run_prov(), + source=_src(source_id), + timestamp=datetime(2026, 1, 1, tzinfo=timezone.utc), + ) + + +@pytest.fixture +def migrated_neo4j(clean_neo4j): + with clean_neo4j.session() as session: + for stmt in cypher_up(enterprise=False, apoc=False): + session.run(stmt) + yield clean_neo4j + with clean_neo4j.session() as session: + for stmt in cypher_down(apoc=False): + session.run(stmt) + + +def test_migration_creates_constraints(migrated_neo4j) -> None: + with migrated_neo4j.session() as s: + rows = list(s.run("SHOW CONSTRAINTS")) + names = {r["name"] for r in rows} + assert "MappingAssertion_id_unique" in names + assert "HumanPin_id_unique" in names + + +def test_mapping_assertion_round_trip(migrated_neo4j) -> None: + payload = VocabLookup( + vocabulary_ref="omop.SNOMED", + source_value_ref="cbio.cancer_type", + domain_constraint_ref="omop.domain.Condition", + require_standard=True, + allow_zero_default=False, + resolver_policy_ref="omop.snomed.condition.v1", + effective_date_ref="cbio.diagnosis_date", + ) + risk = RiskFlag( + code=RiskCode.RISK_VOCAB_DOMAIN_MISMATCH, + severity=Severity.warn, + evidence=[ + Evidence( + mode=EvidenceMode.CATEGORICAL, + payload={"shape": "alpha"}, + sensitivity_class=SensitivityClass.PHI, 
+ source_ref="cbio.cancer_type", + ), + Evidence( + mode=EvidenceMode.COUNT_ONLY, + payload={"count": 7}, + sensitivity_class=SensitivityClass.PHI, + source_ref="cbio.cancer_type", + ), + ], + source_stage=SourceStage.producer, + suggested_action=SuggestedAction.review, + ) + assertion = MappingAssertion( + id="a-rt", + source_field_ref="cbio.cancer_type", + target_property_ref="omop.condition_occurrence.condition_concept_id", + pattern=MappingPattern.VOCAB_LOOKUP, + payload=payload, + confidence=0.91, + risk_flags=[risk], + provenance=_provenance(), + status=Status.candidate, + ) + with migrated_neo4j.session() as s: + write_mapping_assertion(s, assertion) + rt = read_mapping_assertion(s, assertion.id) + assert rt == assertion + + +def test_human_pin_round_trip_each_state(migrated_neo4j) -> None: + for state in PinState: + pin = HumanPin( + pin_id=f"pin-{state.value}", + assertion_id="a-1", + pinned_at=datetime(2026, 1, 1, tzinfo=timezone.utc), + pinned_by="reviewer@x", + confirmed_under_run=_run_prov(), + confirmed_under_source=_src(), + pin_state=state, + ) + with migrated_neo4j.session() as s: + write_human_pin(s, pin) + rt = read_human_pin(s, pin.pin_id) + assert rt == pin + + +def test_resolution_plan_round_trip_multi_source(migrated_neo4j) -> None: + plan = ResolutionPlan( + id="r-mu", + sources=[_src("acris.deeds"), _src("dof.parcels")], + target_identity_ref="canonical.property_id", + strategy=ResolutionStrategy.MULTI_KEY_UNION, + payload=MultiKeyUnionPayload( + source_key_refs=["acris.bbl", "dof.parcel_id"] + ), + confidence=0.85, + provenance_run=_run_prov(), + timestamp=datetime(2026, 1, 1, tzinfo=timezone.utc), + ) + with migrated_neo4j.session() as s: + write_resolution_plan(s, plan) + rt = read_resolution_plan(s, plan.id) + assert rt == plan + assert {s.source_id for s in rt.sources} == {"acris.deeds", "dof.parcels"} + + +def test_conflict_loser_relationship(migrated_neo4j) -> None: + a_winner = MappingAssertion( + id="a-win", + 
source_field_ref="cbio.x", + target_property_ref="omop.y", + pattern=MappingPattern.DIRECT_COPY, + payload=DirectCopyPayload(source_field_ref="cbio.x"), + confidence=0.99, + provenance=_provenance(), + status=Status.auto_accepted, + ) + a_loser = a_winner.model_copy(update={"id": "a-lose", "confidence": 0.7}) + winner = select_winner( + [a_winner, a_loser], ConflictResolutionPolicy.default() + ) + assert winner.id == "a-win" + plan = MappingPlan( + id="plan-conflict", + source_scope_ref="cbio", + obligation=TargetObligation( + target_entity="omop.person", + required_fields=["omop.y"], + primary_key=PrimaryKeyStrategy.NATURAL_KEY, + ), + row_identity=RowIdentity( + target_row_key_rule="hash", + source_lineage=["cbio.x"], + materialization_mode=MaterializationMode.MERGE, + ), + field_maps=[ + FieldMap( + target_field_ref="omop.y", + pattern=MappingPattern.DIRECT_COPY, + payload=DirectCopyPayload(source_field_ref="cbio.x"), + ) + ], + risk_flags=[ + RiskFlag( + code=RiskCode.RISK_ASSEMBLER_CONFLICT_RESOLVED, + severity=Severity.info, + evidence=[ + Evidence( + mode=EvidenceMode.COUNT_ONLY, + payload={"count": 1}, + sensitivity_class=SensitivityClass.PUBLIC, + source_ref="a-lose", + ) + ], + source_stage=SourceStage.constraint, + suggested_action=SuggestedAction.ignore_with_reason, + ) + ], + lineage=["cbio.x"], + ) + assert plan.derive_verdict() == PlanVerdict.compilable + with migrated_neo4j.session() as s: + write_mapping_plan(s, plan) + s.run( + "CREATE (w:MappingAssertion {id: $win})\n" + "CREATE (l:MappingAssertion {id: $lose})\n" + "WITH w, l\n" + "MATCH (p:MappingPlan {id: $plan_id})\n" + "MERGE (w)-[:ASSEMBLED_INTO]->(p)\n" + "MERGE (p)-[:CONFLICT_LOSER]->(l)", + plan_id=plan.id, + win=a_winner.id, + lose=a_loser.id, + ) + rt_plan = read_mapping_plan(s, plan.id) + loser_row = s.run( + "MATCH (p:MappingPlan {id: $id})-[:CONFLICT_LOSER]->(l) " + "RETURN l.id AS loser", + id=plan.id, + ).single() + assert rt_plan == plan + assert loser_row["loser"] == "a-lose" + 
+ +def test_pin_staleness_query_uses_pin_state_index(migrated_neo4j) -> None: + pin = HumanPin( + pin_id="pin-q", + assertion_id="a-1", + pinned_at=datetime(2026, 1, 1, tzinfo=timezone.utc), + pinned_by="reviewer@x", + confirmed_under_run=_run_prov(), + confirmed_under_source=_src(), + pin_state=PinState.stale, + ) + props = confirmed_under_to_properties( + pin.confirmed_under_run, pin.confirmed_under_source + ) + props["id"] = pin.pin_id + props["pin_id"] = pin.pin_id + props["assertion_id"] = pin.assertion_id + props["pinned_at"] = pin.pinned_at.isoformat() + props["pinned_by"] = pin.pinned_by + props["pin_state"] = pin.pin_state.value + props["expires_on_change_of"] = list(pin.expires_on_change_of) + with migrated_neo4j.session() as s: + s.run("CREATE (n:HumanPin) SET n = $props", props=props) + plan_root = s.run( + "EXPLAIN MATCH (h:HumanPin) WHERE h.pin_state IN $states " + "RETURN h.pin_id", + states=[PinState.active.value, PinState.stale.value], + ).consume().plan + operators = _collect_plan_operators(plan_root) + assert any( + "NodeIndexSeek" in op or "IndexSeek" in op for op in operators + ), f"expected NodeIndexSeek operator on human_pin_state, got {operators}" + + +def _collect_plan_operators(plan: dict) -> list[str]: + operators = [plan.get("operatorType", "")] + for child in plan.get("children", []) or []: + operators.extend(_collect_plan_operators(child)) + return [op for op in operators if op] + + +def test_cache_key_changes_across_runs() -> None: + art = PromptArtifact.build( + prefix_text="prefix", + suffix_text="s", + versions={"context_card_version": "v1"}, + ) + rp1 = _run_prov(context_card_version="cards-v3") + rp2 = _run_prov(context_card_version="cards-v4") + assert derive_cache_key(art, rp1) != derive_cache_key(art, rp2) + + +def test_resolution_verdict_derivation_matrix() -> None: + assert ( + derive_resolution_verdict( + produced_for_every_input=True, + ambiguous_assignments=False, + cycle_blocked=False, + any_block_flag=False, + 
plan_review_pending=False, + ) + == ResolutionVerdict.resolved + ) + assert ( + derive_resolution_verdict( + produced_for_every_input=True, + ambiguous_assignments=True, + cycle_blocked=False, + any_block_flag=False, + plan_review_pending=False, + ) + == ResolutionVerdict.ambiguous + ) + + +def test_multi_source_assertions_share_run_id(migrated_neo4j) -> None: + a_cbio = MappingAssertion( + id="a-cbio", + source_field_ref="cbio.gender", + target_property_ref="omop.gender_concept_id", + pattern=MappingPattern.DIRECT_COPY, + payload=DirectCopyPayload(source_field_ref="cbio.gender"), + confidence=0.9, + provenance=Provenance( + run=_run_prov(), + source=_src("cbioportal_gbm"), + timestamp=datetime(2026, 1, 1, tzinfo=timezone.utc), + ), + status=Status.candidate, + ) + a_msk = a_cbio.model_copy( + update={ + "id": "a-msk", + "provenance": Provenance( + run=_run_prov(), + source=_src("msk_chord"), + timestamp=datetime(2026, 1, 1, tzinfo=timezone.utc), + ), + } + ) + for a in (a_cbio, a_msk): + with migrated_neo4j.session() as s: + write_mapping_assertion(s, a) + with migrated_neo4j.session() as s: + rows = list( + s.run( + "MATCH (n:MappingAssertion) " + "RETURN n.prov_run_run_id AS run, n.prov_source_source_id AS src" + ) + ) + runs = {r["run"] for r in rows} + sources = {r["src"] for r in rows} + assert runs == {"run-1"} + assert sources == {"cbioportal_gbm", "msk_chord"} diff --git a/tests/unit/models/__init__.py b/tests/unit/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/models/planner/__init__.py b/tests/unit/models/planner/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/models/planner/test_lifecycle_and_pins.py b/tests/unit/models/planner/test_lifecycle_and_pins.py new file mode 100644 index 0000000..8833920 --- /dev/null +++ b/tests/unit/models/planner/test_lifecycle_and_pins.py @@ -0,0 +1,318 @@ +"""Tests for the lifecycle-and-pins capability.""" + +from __future__ import annotations + +from 
datetime import datetime, timezone + +import pytest +from pydantic import ValidationError + +pytestmark = pytest.mark.unit + + +def _run_prov(**overrides: object) -> object: + from sema.models.planner.provenance import RunProvenance + + base = dict( + run_id="run-1", + target_model_version="omop-cdm-5.4", + target_schema_snapshot_hash="t-abc", + vocab_release="omop-2026-q1", + context_card_version="cards-v3", + prompt_template_version="tpl-7", + few_shot_set_version="fs-12", + constraint_version="rules-v2", + llm_model="claude-opus-4.7", + embedding_model="bge-large", + ) + base.update(overrides) + return RunProvenance(**base) + + +def _source(**overrides: object) -> object: + from sema.models.planner.provenance import SourceScope + + base = dict( + source_id="cbioportal_gbm", + source_schema_hash="s-abc", + source_profile_hash="p-abc", + ) + base.update(overrides) + return SourceScope(**base) + + +def test_status_values() -> None: + from sema.models.planner.lifecycle import Status + + assert {s.value for s in Status} == { + "candidate", + "auto_accepted", + "review_pending", + "human_pinned", + "rejected", + } + + +def test_status_transitions_allowed() -> None: + from sema.models.planner.lifecycle import Status, transition_status + + assert transition_status(Status.candidate, Status.auto_accepted) == Status.auto_accepted + assert transition_status(Status.candidate, Status.review_pending) == Status.review_pending + assert transition_status(Status.review_pending, Status.human_pinned) == Status.human_pinned + assert transition_status(Status.auto_accepted, Status.human_pinned) == Status.human_pinned + assert transition_status(Status.review_pending, Status.rejected) == Status.rejected + + +def test_status_transitions_forbidden() -> None: + from sema.models.planner.lifecycle import Status, transition_status + + with pytest.raises(ValueError): + transition_status(Status.human_pinned, Status.auto_accepted) + with pytest.raises(ValueError): + 
transition_status(Status.rejected, Status.candidate) + + +def test_plan_verdict_values() -> None: + from sema.models.planner.lifecycle import PlanVerdict + + assert {v.value for v in PlanVerdict} == { + "compilable", + "blocked_by_obligation", + "blocked_by_constraint", + "blocked_by_resolution", + "blocked_by_fk", + "awaiting_review", + } + + +def test_plan_verdict_compilable_when_clean() -> None: + from sema.models.planner.lifecycle import PlanVerdict, derive_plan_verdict + + v = derive_plan_verdict( + risk_codes=[], + obligation_required_missing=False, + fk_unsatisfied=False, + minimum_viable_row_violated=False, + any_review_pending=False, + any_resolution_dependency_missing=False, + ) + assert v == PlanVerdict.compilable + + +def test_plan_verdict_blocked_by_obligation() -> None: + from sema.models.planner.lifecycle import PlanVerdict, derive_plan_verdict + + v = derive_plan_verdict( + risk_codes=["RISK_OBLIGATION_REQUIRED_FIELD_MISSING"], + obligation_required_missing=True, + fk_unsatisfied=False, + minimum_viable_row_violated=False, + any_review_pending=False, + any_resolution_dependency_missing=False, + ) + assert v == PlanVerdict.blocked_by_obligation + + +def test_plan_verdict_awaiting_review() -> None: + from sema.models.planner.lifecycle import PlanVerdict, derive_plan_verdict + + v = derive_plan_verdict( + risk_codes=[], + obligation_required_missing=False, + fk_unsatisfied=False, + minimum_viable_row_violated=False, + any_review_pending=True, + any_resolution_dependency_missing=False, + ) + assert v == PlanVerdict.awaiting_review + + +def test_pin_state_values() -> None: + from sema.models.planner.lifecycle import PinState + + assert {s.value for s in PinState} == { + "active", + "stale", + "revalidated", + "invalidated", + } + + +def test_human_pin_default_expires() -> None: + from sema.models.planner.lifecycle import HumanPin, PinState + + pin = HumanPin( + pin_id="pin-1", + assertion_id="a-1", + pinned_at=datetime(2026, 1, 1, tzinfo=timezone.utc), + 
pinned_by="reviewer@x", + confirmed_under_run=_run_prov(), + confirmed_under_source=_source(), + ) + assert pin.pin_state == PinState.active + assert "vocab_release" in pin.expires_on_change_of + assert "source_profile_hash" in pin.expires_on_change_of + + +def test_human_pin_must_reference_one_target() -> None: + from sema.models.planner.lifecycle import HumanPin + + with pytest.raises(ValidationError): + HumanPin( + pin_id="p", + pinned_at=datetime.now(timezone.utc), + pinned_by="x", + confirmed_under_run=_run_prov(), + confirmed_under_source=_source(), + ) + + with pytest.raises(ValidationError): + HumanPin( + pin_id="p", + assertion_id="a-1", + resolution_plan_id="r-1", + pinned_at=datetime.now(timezone.utc), + pinned_by="x", + confirmed_under_run=_run_prov(), + confirmed_under_source=_source(), + ) + + +def test_pin_stales_on_tracked_dim_drift() -> None: + from sema.models.planner.lifecycle import HumanPin, PinState, detect_pin_stale + + pin = HumanPin( + pin_id="p", + assertion_id="a-1", + pinned_at=datetime(2026, 1, 1, tzinfo=timezone.utc), + pinned_by="x", + confirmed_under_run=_run_prov(), + confirmed_under_source=_source(), + ) + new_pin = detect_pin_stale(pin, _run_prov(vocab_release="omop-2026-q2"), _source()) + assert new_pin.pin_state == PinState.stale + + +def test_pin_does_not_stale_on_untracked_dim() -> None: + from sema.models.planner.lifecycle import HumanPin, PinState, detect_pin_stale + + pin = HumanPin( + pin_id="p", + assertion_id="a-1", + pinned_at=datetime(2026, 1, 1, tzinfo=timezone.utc), + pinned_by="x", + expires_on_change_of=["target_model_version"], + confirmed_under_run=_run_prov(), + confirmed_under_source=_source(), + ) + new_pin = detect_pin_stale( + pin, _run_prov(prompt_template_version="tpl-99"), _source() + ) + assert new_pin.pin_state == PinState.active + + +def test_pin_stales_on_source_profile_drift() -> None: + from sema.models.planner.lifecycle import HumanPin, PinState, detect_pin_stale + + pin = HumanPin( + pin_id="p", + 
assertion_id="a-1", + pinned_at=datetime(2026, 1, 1, tzinfo=timezone.utc), + pinned_by="x", + confirmed_under_run=_run_prov(), + confirmed_under_source=_source(), + ) + new_pin = detect_pin_stale( + pin, _run_prov(), _source(source_profile_hash="DIFFERENT") + ) + assert new_pin.pin_state == PinState.stale + + +def test_revalidation_success_transitions_to_revalidated() -> None: + from sema.models.planner.lifecycle import ( + HumanPin, + PinState, + revalidate, + ) + + pin = HumanPin( + pin_id="p", + assertion_id="a-1", + pinned_at=datetime(2026, 1, 1, tzinfo=timezone.utc), + pinned_by="x", + confirmed_under_run=_run_prov(), + confirmed_under_source=_source(), + pin_state=PinState.stale, + ) + new_run = _run_prov(vocab_release="omop-2026-q2") + new_src = _source(source_profile_hash="DIFFERENT") + revalidated = revalidate(pin, holds=True, current_run=new_run, current_source=new_src) + assert revalidated.pin_state == PinState.revalidated + assert revalidated.confirmed_under_run.vocab_release == "omop-2026-q2" + + +def test_revalidation_failure_transitions_to_invalidated() -> None: + from sema.models.planner.lifecycle import ( + HumanPin, + PinState, + revalidate, + ) + + pin = HumanPin( + pin_id="p", + assertion_id="a-1", + pinned_at=datetime(2026, 1, 1, tzinfo=timezone.utc), + pinned_by="x", + confirmed_under_run=_run_prov(), + confirmed_under_source=_source(), + pin_state=PinState.stale, + ) + invalidated = revalidate(pin, holds=False, current_run=_run_prov(), current_source=_source()) + assert invalidated.pin_state == PinState.invalidated + + +def test_reviewer_revocation_invalidates() -> None: + from sema.models.planner.lifecycle import ( + HumanPin, + PinState, + revoke_pin, + ) + + pin = HumanPin( + pin_id="p", + assertion_id="a-1", + pinned_at=datetime(2026, 1, 1, tzinfo=timezone.utc), + pinned_by="x", + confirmed_under_run=_run_prov(), + confirmed_under_source=_source(), + ) + revoked = revoke_pin(pin) + assert revoked.pin_state == PinState.invalidated + + +def 
test_build_context_dispatch_rules() -> None: + from sema.models.planner.lifecycle import ( + BuildContext, + DispatchDecision, + HumanPin, + PinState, + ) + + pin_active = HumanPin( + pin_id="p1", + assertion_id="a1", + pinned_at=datetime(2026, 1, 1, tzinfo=timezone.utc), + pinned_by="x", + confirmed_under_run=_run_prov(), + confirmed_under_source=_source(), + ) + pin_stale = pin_active.model_copy(update={"pin_id": "p2", "pin_state": PinState.stale}) + pin_invalid = pin_active.model_copy(update={"pin_id": "p3", "pin_state": PinState.invalidated}) + + ctx = BuildContext(pins=[pin_active, pin_stale, pin_invalid], rejected_pairs=[("s", "t")]) + + assert ctx.dispatch_decision(pin_active) == DispatchDecision.SKIP + assert ctx.dispatch_decision(pin_stale) == DispatchDecision.REVALIDATE + assert ctx.dispatch_decision(pin_invalid) == DispatchDecision.DISPATCH + assert ctx.is_rejected(("s", "t")) + assert not ctx.is_rejected(("u", "v")) diff --git a/tests/unit/models/planner/test_mapping_patterns.py b/tests/unit/models/planner/test_mapping_patterns.py new file mode 100644 index 0000000..7172b70 --- /dev/null +++ b/tests/unit/models/planner/test_mapping_patterns.py @@ -0,0 +1,306 @@ +"""Tests for the mapping-planner pattern enum and per-pattern payloads.""" + +from __future__ import annotations + +import pytest +from pydantic import ValidationError + +pytestmark = pytest.mark.unit + + +def test_mapping_pattern_has_eleven_values() -> None: + from sema.models.planner.patterns import MappingPattern + + expected = { + "DIRECT_COPY", + "CONSTANT", + "DERIVED", + "VOCAB_LOOKUP", + "JOIN_LOOKUP", + "PIVOT", + "UNPIVOT", + "SPLIT", + "AGGREGATE", + "ROW_GENERATION", + "NO_MAP", + } + assert {p.value for p in MappingPattern} == expected + + +def test_unknown_pattern_rejected() -> None: + from sema.models.planner.patterns import MappingPattern + + with pytest.raises(ValueError): + MappingPattern("FUZZY_LOOKUP") + + +def test_no_map_scope_values() -> None: + from 
sema.models.planner.patterns import NoMapScope + + assert {s.value for s in NoMapScope} == { + "GLOBAL", + "TARGET_ENTITY", + "TARGET_PROPERTY", + } + + +def test_direct_copy_payload() -> None: + from sema.models.planner.patterns import DirectCopyPayload + + p = DirectCopyPayload(source_field_ref="cbio.patient.gender") + assert p.source_field_ref == "cbio.patient.gender" + + +def test_constant_payload_rejects_null_for_required() -> None: + from sema.models.planner.patterns import ConstantValue + + p = ConstantValue(literal_value=42, target_type="int") + assert p.literal_value == 42 + + +def test_derived_expression_requires_inputs() -> None: + from sema.models.planner.patterns import DerivedExpression + + with pytest.raises(ValidationError): + DerivedExpression(source_field_refs=[], expression_ast={"op": "year"}) + + +def test_vocab_lookup_requires_all_hooks() -> None: + from sema.models.planner.patterns import VocabLookup + + with pytest.raises(ValidationError): + VocabLookup( + vocabulary_ref="omop.SNOMED", + source_value_ref="x", + domain_constraint_ref="d", + require_standard=True, + allow_zero_default=False, + ) + + p = VocabLookup( + vocabulary_ref="omop.SNOMED", + source_value_ref="cbio.cancer_type", + domain_constraint_ref="omop.condition_concept_id.domain=Condition", + require_standard=True, + allow_zero_default=False, + resolver_policy_ref="omop.snomed.condition.v1", + ) + assert p.effective_date_ref is None + + +def test_join_lookup_requires_keys() -> None: + from sema.models.planner.patterns import JoinKeyPair, JoinLookup + + with pytest.raises(ValidationError): + JoinLookup( + from_source_ref="cbio.a", + to_source_ref="cbio.b", + join_keys=[], + select_field_ref="cbio.x", + ) + + p = JoinLookup( + from_source_ref="cbio.a", + to_source_ref="cbio.b", + join_keys=[ + JoinKeyPair(from_field_ref="cbio.a.id", to_field_ref="cbio.b.id") + ], + select_field_ref="cbio.b.value", + ) + assert len(p.join_keys) == 1 + + +def test_pivot_partition_keys_required() -> 
None: + from sema.models.planner.patterns import PivotMapping + + with pytest.raises(ValidationError): + PivotMapping( + source_table_ref="cbio.t", + key_field_ref="cbio.t.k", + value_field_ref="cbio.t.v", + partition_keys=[], + expansion_mode="multi_column", + ) + + +def test_unpivot_requires_columns() -> None: + from sema.models.planner.patterns import UnpivotMapping + + with pytest.raises(ValidationError): + UnpivotMapping( + source_table_ref="cbio.t", + key_columns=[], + key_name_target_field="omop.t.name", + value_target_field="omop.t.val", + ) + + +def test_split_outputs_required() -> None: + from sema.models.planner.patterns import SplitMapping, SplitRule + + with pytest.raises(ValidationError): + SplitMapping( + source_field_ref="cbio.t.x", + split_rule=SplitRule(kind="regex", pattern="(?.*)"), + output_target_fields={}, + ) + + +def test_aggregate_function_closed_set() -> None: + from sema.models.planner.patterns import ( + AggregateFunction, + AggregateMapping, + AggregateOp, + ) + + with pytest.raises(ValueError): + AggregateFunction("MEDIAN") + + p = AggregateMapping( + source_table_ref="cbio.t", + group_by_keys=["cbio.t.patient_id"], + aggregations=[ + AggregateOp( + target_field_ref="omop.t.count_obs", + aggregate_function=AggregateFunction.COUNT, + source_field_ref="cbio.t.obs_id", + ) + ], + ) + assert p.aggregations[0].aggregate_function == AggregateFunction.COUNT + + +def test_row_generation_requires_field_maps() -> None: + from sema.models.planner.patterns import RowGenerationMapping + + with pytest.raises(ValidationError): + RowGenerationMapping( + source_scope_ref="cbio.x", + generation_rule={ + "kind": "distinct_keys", + "keys": ["cbio.t.patient_id"], + }, + populated_field_maps=[], + ) + + +def test_no_map_payload_scope_required() -> None: + from sema.models.planner.patterns import NoMapPayload, NoMapScope + + p = NoMapPayload(reason="INTERNAL_BOOKKEEPING_FIELD", scope=NoMapScope.GLOBAL) + assert p.scope == NoMapScope.GLOBAL + + +def 
test_no_map_target_property_requires_property_ref() -> None: + from sema.models.planner.patterns import NoMapPayload, NoMapScope + + with pytest.raises(ValidationError): + NoMapPayload(reason="x", scope=NoMapScope.TARGET_PROPERTY) + + +def test_no_map_target_entity_requires_entity_ref() -> None: + from sema.models.planner.patterns import NoMapPayload, NoMapScope + + with pytest.raises(ValidationError): + NoMapPayload(reason="x", scope=NoMapScope.TARGET_ENTITY) + + +def test_payload_polymorphism_no_kind_flag() -> None: + from sema.models.planner.patterns import DirectCopyPayload, PivotMapping + + direct_fields = set(DirectCopyPayload.model_fields.keys()) + pivot_fields = set(PivotMapping.model_fields.keys()) + assert "kind" not in direct_fields + assert "kind" not in pivot_fields + assert "target_artifact_kind" not in direct_fields + assert "target_artifact_kind" not in pivot_fields + + +def test_direct_copy_rejects_extra_target_artifact_kind() -> None: + from sema.models.planner.patterns import DirectCopyPayload + + with pytest.raises(ValidationError): + DirectCopyPayload( + source_field_ref="cbio.x", target_artifact_kind="GRAPH_NODE" + ) + + +def test_pivot_rejects_unknown_field() -> None: + from sema.models.planner.patterns import PivotExpansionMode, PivotMapping + + with pytest.raises(ValidationError): + PivotMapping( + source_table_ref="cbio.t", + key_field_ref="cbio.t.k", + value_field_ref="cbio.t.v", + partition_keys=["cbio.t.p"], + expansion_mode=PivotExpansionMode.multi_column, + secret_flag=True, + ) + + +def test_row_generation_rejects_non_field_map_entries() -> None: + from sema.models.planner.patterns import RowGenerationMapping + + with pytest.raises(ValidationError): + RowGenerationMapping( + source_scope_ref="cbio.events_per_patient", + generation_rule={ + "kind": "window_envelope", + "partition": ["cbio.events.patient_id"], + "min_field": "cbio.events.event_date", + "max_field": "cbio.events.event_date", + }, + 
populated_field_maps=["not-a-field-map"], + ) + + +def test_row_generation_accepts_field_map_dicts() -> None: + from sema.models.planner.patterns import ( + MappingPattern, + RowGenerationMapping, + ) + + rg = RowGenerationMapping( + source_scope_ref="cbio.events_per_patient", + generation_rule={ + "kind": "window_envelope", + "partition": ["cbio.events.patient_id"], + "min_field": "cbio.events.event_date", + "max_field": "cbio.events.event_date", + }, + populated_field_maps=[ + { + "target_field_ref": "omop.person.person_id", + "pattern": "DIRECT_COPY", + "payload": {"source_field_ref": "cbio.patient_id"}, + } + ], + ) + assert rg.populated_field_maps[0].pattern == MappingPattern.DIRECT_COPY + + +def test_row_generation_rule_must_be_typed() -> None: + from sema.models.planner.patterns import RowGenerationMapping + + with pytest.raises(ValidationError): + RowGenerationMapping( + source_scope_ref="cbio.x", + generation_rule={"kind": "unknown_kind"}, + populated_field_maps=[ + { + "target_field_ref": "omop.t.f", + "pattern": "DIRECT_COPY", + "payload": {"source_field_ref": "cbio.s"}, + } + ], + ) + + +def test_bare_string_ref_rejected() -> None: + from sema.models.planner.patterns import DirectCopyPayload + + with pytest.raises(ValidationError): + DirectCopyPayload(source_field_ref="x") + with pytest.raises(ValidationError): + DirectCopyPayload(source_field_ref="bare_identifier") diff --git a/tests/unit/models/planner/test_mapping_plan.py b/tests/unit/models/planner/test_mapping_plan.py new file mode 100644 index 0000000..324fa44 --- /dev/null +++ b/tests/unit/models/planner/test_mapping_plan.py @@ -0,0 +1,368 @@ +"""Tests for MappingAssertion / MappingPlan / ConflictResolutionPolicy.""" + +from __future__ import annotations + +from datetime import datetime, timedelta, timezone + +import pytest +from pydantic import ValidationError + +pytestmark = pytest.mark.unit + + +def _provenance(ts: datetime | None = None) -> object: + from sema.models.planner.provenance import 
( + Provenance, + RunProvenance, + SourceScope, + ) + + rp = RunProvenance( + run_id="run-1", + target_model_version="omop-cdm-5.4", + target_schema_snapshot_hash="t-abc", + vocab_release="omop-2026-q1", + context_card_version="cards-v3", + prompt_template_version="tpl-7", + few_shot_set_version="fs-12", + constraint_version="rules-v2", + llm_model="claude-opus-4.7", + embedding_model="bge-large", + ) + src = SourceScope( + source_id="cbioportal_gbm", + source_schema_hash="s-abc", + source_profile_hash="p-abc", + ) + return Provenance( + run=rp, + source=src, + timestamp=ts or datetime(2026, 1, 1, tzinfo=timezone.utc), + ) + + +def _make_assertion(**overrides: object) -> object: + from sema.models.planner.mapping_plan import MappingAssertion + from sema.models.planner.patterns import DirectCopyPayload, MappingPattern + from sema.models.planner.lifecycle import Status + + base: dict[str, object] = dict( + id="a-1", + source_field_ref="cbio.patient.gender", + target_property_ref="omop.person.gender_concept_id", + pattern=MappingPattern.DIRECT_COPY, + payload=DirectCopyPayload(source_field_ref="cbio.patient.gender"), + confidence=0.9, + risk_flags=[], + provenance=_provenance(), + status=Status.candidate, + ) + base.update(overrides) + return MappingAssertion(**base) + + +def test_field_map_requires_matching_payload() -> None: + from sema.models.planner.field_map import FieldMap + from sema.models.planner.patterns import ( + ConstantValue, + DirectCopyPayload, + MappingPattern, + ) + + fm = FieldMap( + target_field_ref="omop.person.gender_concept_id", + pattern=MappingPattern.DIRECT_COPY, + payload=DirectCopyPayload(source_field_ref="cbio.patient.gender"), + ) + assert fm.pattern is MappingPattern.DIRECT_COPY + + with pytest.raises(ValidationError): + FieldMap( + target_field_ref="omop.person.x", + pattern=MappingPattern.DIRECT_COPY, + payload=ConstantValue(literal_value=1, target_type="int"), + ) + + +def test_row_identity_requires_lineage() -> None: + from 
sema.models.planner._enums import MaterializationMode + from sema.models.planner.field_map import RowIdentity + + with pytest.raises(ValidationError): + RowIdentity( + target_row_key_rule="hash", + source_lineage=[], + materialization_mode=MaterializationMode.MERGE, + ) + + +def test_row_identity_stable_under_same_inputs() -> None: + from sema.models.planner._enums import MaterializationMode + from sema.models.planner.field_map import RowIdentity, derive_row_key + + ri = RowIdentity( + target_row_key_rule="hash", + source_lineage=["cbio.patient.patient_id", "cbio.study_id"], + materialization_mode=MaterializationMode.MERGE, + ) + a = derive_row_key( + ri, {"cbio.patient.patient_id": "P1", "cbio.study_id": "GBM"} + ) + b = derive_row_key( + ri, {"cbio.patient.patient_id": "P1", "cbio.study_id": "GBM"} + ) + c = derive_row_key( + ri, {"cbio.patient.patient_id": "P2", "cbio.study_id": "GBM"} + ) + assert a == b + assert a != c + + +def test_mapping_assertion_round_trip() -> None: + a = _make_assertion() + payload = a.model_dump(mode="json") + rt = type(a).model_validate(payload) + assert rt.id == "a-1" + from sema.models.planner.patterns import DirectCopyPayload + + assert isinstance(rt.payload, DirectCopyPayload) + + +def test_mapping_assertion_pattern_payload_mismatch_rejected() -> None: + from sema.models.planner.lifecycle import Status + from sema.models.planner.mapping_plan import MappingAssertion + from sema.models.planner.patterns import ConstantValue, MappingPattern + + with pytest.raises(ValidationError): + MappingAssertion( + id="a-bad", + source_field_ref="cbio.x", + target_property_ref="omop.y", + pattern=MappingPattern.DIRECT_COPY, + payload=ConstantValue(literal_value=1, target_type="int"), + confidence=0.9, + risk_flags=[], + provenance=_provenance(), + status=Status.candidate, + ) + + +def test_mapping_plan_round_trip() -> None: + from sema.models.planner._enums import MaterializationMode, PrimaryKeyStrategy + from sema.models.planner.field_map import 
FieldMap, RowIdentity + from sema.models.planner.mapping_plan import MappingPlan + from sema.models.planner.patterns import DirectCopyPayload, MappingPattern + from sema.models.planner.target_model import ( + FieldPresence, + RowPredicate, + TargetObligation, + ) + + obligation = TargetObligation( + target_entity="omop.person", + required_fields=["person_id"], + primary_key=PrimaryKeyStrategy.NATURAL_KEY, + minimum_viable_row=RowPredicate( + op="AND", clauses=[FieldPresence(field="person_id")] + ), + ) + plan = MappingPlan( + id="plan-1", + source_scope_ref="cbioportal_gbm", + obligation=obligation, + row_identity=RowIdentity( + target_row_key_rule="hash", + source_lineage=["cbio.patient.patient_id"], + materialization_mode=MaterializationMode.MERGE, + ), + field_maps=[ + FieldMap( + target_field_ref="omop.person.person_id", + pattern=MappingPattern.DIRECT_COPY, + payload=DirectCopyPayload(source_field_ref="cbio.patient.patient_id"), + ) + ], + risk_flags=[], + lineage=["cbio.patient.patient_id"], + ) + rt = MappingPlan.model_validate(plan.model_dump(mode="json")) + assert rt.id == "plan-1" + + +def test_mapping_plan_blocked_when_required_missing() -> None: + from sema.models.planner._enums import MaterializationMode, PrimaryKeyStrategy + from sema.models.planner.field_map import RowIdentity + from sema.models.planner.lifecycle import PlanVerdict + from sema.models.planner.mapping_plan import MappingPlan + from sema.models.planner.target_model import TargetObligation + + obligation = TargetObligation( + target_entity="omop.person", + required_fields=["person_id", "gender_concept_id"], + primary_key=PrimaryKeyStrategy.NATURAL_KEY, + ) + plan = MappingPlan( + id="plan-2", + source_scope_ref="cbio", + obligation=obligation, + row_identity=RowIdentity( + target_row_key_rule="hash", + source_lineage=["x.y"], + materialization_mode=MaterializationMode.MERGE, + ), + field_maps=[], + risk_flags=[], + lineage=["x.y"], + ) + assert plan.derive_verdict() == 
PlanVerdict.blocked_by_obligation + + +def _compilable_plan_factory(): + from sema.models.planner._enums import MaterializationMode, PrimaryKeyStrategy + from sema.models.planner.field_map import FieldMap, RowIdentity + from sema.models.planner.mapping_plan import MappingPlan + from sema.models.planner.patterns import DirectCopyPayload, MappingPattern + from sema.models.planner.target_model import TargetObligation + + def make(risk_flags=None): + return MappingPlan( + id="plan-rf", + source_scope_ref="cbio", + obligation=TargetObligation( + target_entity="omop.person", + required_fields=["omop.person.person_id"], + primary_key=PrimaryKeyStrategy.NATURAL_KEY, + ), + row_identity=RowIdentity( + target_row_key_rule="hash", + source_lineage=["cbio.x"], + materialization_mode=MaterializationMode.MERGE, + ), + field_maps=[ + FieldMap( + target_field_ref="omop.person.person_id", + pattern=MappingPattern.DIRECT_COPY, + payload=DirectCopyPayload(source_field_ref="cbio.patient_id"), + ) + ], + risk_flags=risk_flags or [], + lineage=["cbio.x"], + ) + + return make + + +def _risk(code): + from sema.models.planner.risk import ( + Evidence, + EvidenceMode, + RiskFlag, + SensitivityClass, + Severity, + SourceStage, + SuggestedAction, + ) + + return RiskFlag( + code=code, + severity=Severity.block, + evidence=[ + Evidence( + mode=EvidenceMode.COUNT_ONLY, + payload={"count": 1}, + sensitivity_class=SensitivityClass.PUBLIC, + source_ref="cbio.x", + ) + ], + source_stage=SourceStage.constraint, + suggested_action=SuggestedAction.review, + ) + + +def test_plan_verdict_blocked_by_fk_when_risk_present() -> None: + from sema.models.planner.lifecycle import PlanVerdict + from sema.models.planner.risk import RiskCode + + make = _compilable_plan_factory() + plan = make(risk_flags=[_risk(RiskCode.RISK_OBLIGATION_FK_UNSATISFIED)]) + assert plan.derive_verdict() == PlanVerdict.blocked_by_fk + + +def test_plan_verdict_blocked_by_obligation_when_min_viable_row_violated() -> None: + from 
sema.models.planner.lifecycle import PlanVerdict + from sema.models.planner.risk import RiskCode + + make = _compilable_plan_factory() + plan = make( + risk_flags=[ + _risk(RiskCode.RISK_OBLIGATION_MINIMUM_VIABLE_ROW_VIOLATED) + ] + ) + assert plan.derive_verdict() == PlanVerdict.blocked_by_obligation + + +def test_plan_verdict_blocked_by_resolution_dependency_missing() -> None: + from sema.models.planner.lifecycle import PlanVerdict + from sema.models.planner.risk import RiskCode + + make = _compilable_plan_factory() + plan = make(risk_flags=[_risk(RiskCode.RISK_RESOLUTION_DEPENDENCY_MISSING)]) + assert plan.derive_verdict() == PlanVerdict.blocked_by_resolution + + +def test_plan_verdict_compilable_with_no_blocking_risk() -> None: + from sema.models.planner.lifecycle import PlanVerdict + + make = _compilable_plan_factory() + assert make().derive_verdict() == PlanVerdict.compilable + + +def test_plan_verdict_blocked_by_obligation_when_required_field_missing_risk() -> None: + from sema.models.planner.lifecycle import PlanVerdict + from sema.models.planner.risk import RiskCode + + make = _compilable_plan_factory() + plan = make( + risk_flags=[_risk(RiskCode.RISK_OBLIGATION_REQUIRED_FIELD_MISSING)] + ) + assert plan.derive_verdict() == PlanVerdict.blocked_by_obligation + + +def test_conflict_policy_pin_wins() -> None: + from sema.models.planner.lifecycle import Status + from sema.models.planner.mapping_plan import ( + ConflictResolutionPolicy, + select_winner, + ) + + a1 = _make_assertion(id="a1", confidence=0.6, status=Status.human_pinned) + a2 = _make_assertion(id="a2", confidence=0.99, status=Status.auto_accepted) + winner = select_winner([a1, a2], ConflictResolutionPolicy.default()) + assert winner.id == "a1" + + +def test_conflict_policy_confidence_then_recency_then_template() -> None: + from sema.models.planner.lifecycle import Status + from sema.models.planner.mapping_plan import ( + ConflictResolutionPolicy, + select_winner, + ) + + base_ts = datetime(2026, 1, 
1, tzinfo=timezone.utc) + a1 = _make_assertion(id="a1", confidence=0.91, status=Status.auto_accepted) + a2 = _make_assertion(id="a2", confidence=0.86, status=Status.auto_accepted) + assert select_winner([a1, a2], ConflictResolutionPolicy.default()).id == "a1" + + a3 = _make_assertion( + id="a3", + confidence=0.91, + status=Status.auto_accepted, + provenance=_provenance(ts=base_ts), + ) + a4 = _make_assertion( + id="a4", + confidence=0.91, + status=Status.auto_accepted, + provenance=_provenance(ts=base_ts + timedelta(seconds=10)), + ) + assert select_winner([a3, a4], ConflictResolutionPolicy.default()).id == "a4" diff --git a/tests/unit/models/planner/test_package_surface.py b/tests/unit/models/planner/test_package_surface.py new file mode 100644 index 0000000..b696a44 --- /dev/null +++ b/tests/unit/models/planner/test_package_surface.py @@ -0,0 +1,27 @@ +import pytest + +pytestmark = pytest.mark.unit + + +def test_planner_package_importable() -> None: + import sema.models.planner as planner + + assert planner is not None + + +def test_models_package_reexports_planner() -> None: + from sema.models import planner + + assert planner is not None + + +def test_shared_enums_module_exists() -> None: + from sema.models.planner import _enums + + assert _enums is not None + + +def test_shared_refs_module_exists() -> None: + from sema.models.planner import _refs + + assert _refs is not None diff --git a/tests/unit/models/planner/test_planner_storage.py b/tests/unit/models/planner/test_planner_storage.py new file mode 100644 index 0000000..e953e73 --- /dev/null +++ b/tests/unit/models/planner/test_planner_storage.py @@ -0,0 +1,665 @@ +"""Unit-level tests for the planner storage layout (round-trip in memory).""" + +from __future__ import annotations + +from datetime import datetime, timezone + +import pytest + +pytestmark = pytest.mark.unit + + +def _provenance() -> object: + from sema.models.planner.provenance import ( + Provenance, + RunProvenance, + SourceScope, + ) + + 
return Provenance( + run=RunProvenance( + run_id="run-1", + target_model_version="omop-cdm-5.4", + target_schema_snapshot_hash="t", + vocab_release="v", + context_card_version="c", + prompt_template_version="t1", + few_shot_set_version="f", + constraint_version="cv", + llm_model="m", + embedding_model="e", + ), + source=SourceScope( + source_id="cbioportal_gbm", + source_schema_hash="s", + source_profile_hash="p", + ), + timestamp=datetime(2026, 1, 1, tzinfo=timezone.utc), + ) + + +def test_planner_migrations_emit_required_constraints() -> None: + from sema.graph.planner_migrations import cypher_up + + statements = cypher_up(enterprise=True) + assert any("MappingAssertion" in s for s in statements) + assert any("HumanPin" in s for s in statements) + assert any("model_role IS NOT NULL" in s for s in statements) + assert any("prov_run_run_id" in s for s in statements) + + +def test_community_migration_omits_existence_constraints() -> None: + from sema.graph.planner_migrations import cypher_up + + statements = cypher_up(enterprise=False) + assert not any("model_role IS NOT NULL" in s for s in statements) + assert any("MappingAssertion_id_unique" in s for s in statements) + + +def test_planner_migrations_down_drops_planner_data() -> None: + from sema.graph.planner_migrations import cypher_down + + statements = cypher_down() + assert any("DETACH DELETE" in s for s in statements) + assert any("MappingAssertion" in s for s in statements) + + +def test_apoc_triggers_emitted_when_requested() -> None: + from sema.graph.planner_migrations import cypher_up + + statements = cypher_up(apoc=True) + triggers = [s for s in statements if "apoc.trigger.add" in s] + assert len(triggers) == 7 + assert any("'MAPS_TO'" in s and "'TARGET'" in s for s in triggers) + assert any("'DERIVED_FROM'" in s and "'SOURCE'" in s for s in triggers) + assert any("'HAS_LINEAGE'" in s and "'SOURCE'" in s for s in triggers) + assert any( + "'RESOLUTION_INPUT'" in s and "'SOURCE'" in s for s in triggers + 
) + assert any("planner_no_role_id_collision" in s for s in triggers) + assert any("planner_source_role_requires_source_id" in s for s in triggers) + assert any( + "planner_target_role_requires_target_model_id" in s for s in triggers + ) + + +def test_role_backfill_includes_source_id_derivation() -> None: + from sema.graph.planner_migrations import cypher_up + + statements = cypher_up(enterprise=False) + backfill_stmts = [s for s in statements if "model_role" in s or "source_id" in s] + assert any("source_id IS NULL" in s for s in backfill_stmts) + assert any("source_schema" in s for s in backfill_stmts) + assert any("model_role = 'SOURCE'" in s for s in backfill_stmts) + + +def test_apoc_triggers_omitted_by_default() -> None: + from sema.graph.planner_migrations import cypher_up + + assert not any( + "apoc.trigger.add" in s for s in cypher_up(enterprise=False) + ) + + +def test_cypher_down_apoc_true_removes_triggers() -> None: + from sema.graph.planner_migrations import cypher_down + + statements = cypher_down(apoc=True) + assert any( + "apoc.trigger.remove('planner_maps_to_requires_target_property')" in s + for s in statements + ) + + +def test_cypher_down_apoc_false_skips_trigger_removal() -> None: + from sema.graph.planner_migrations import cypher_down + + statements = cypher_down(apoc=False) + assert not any("apoc.trigger.remove" in s for s in statements) + assert any("MappingAssertion" in s for s in statements) + + +def test_provenance_round_trip_via_native_properties() -> None: + from sema.graph.planner_loader import ( + properties_to_provenance, + provenance_to_properties, + ) + + prov = _provenance() + props = provenance_to_properties(prov) + assert props["prov_run_run_id"] == "run-1" + assert props["prov_source_source_id"] == "cbioportal_gbm" + assert props["prov_timestamp"].startswith("2026-01-01") + rt = properties_to_provenance(props) + assert rt.run.run_id == prov.run.run_id + assert rt.source.source_profile_hash == prov.source.source_profile_hash + + 
+def test_confirmed_under_round_trip() -> None: + from sema.graph.planner_loader import ( + confirmed_under_to_properties, + properties_to_confirmed_under, + ) + + prov = _provenance() + props = confirmed_under_to_properties(prov.run, prov.source) + rt_run, rt_source = properties_to_confirmed_under(props) + assert rt_run == prov.run + assert rt_source == prov.source + + +def test_maps_to_requires_target_property() -> None: + from sema.graph.planner_loader import cypher_create_field_map_maps_to + from sema.models.planner._enums import ModelRole + + stmt, params = cypher_create_field_map_maps_to( + "fm-1", "p-1", ModelRole.TARGET + ) + assert "MAPS_TO" in stmt + assert params == {"fm_id": "fm-1", "p_id": "p-1"} + with pytest.raises(ValueError, match="MAPS_TO"): + cypher_create_field_map_maps_to("fm-1", "p-1", ModelRole.SOURCE) + + +def test_derived_from_requires_source_property() -> None: + from sema.graph.planner_loader import cypher_create_field_map_derived_from + from sema.models.planner._enums import ModelRole + + stmt, _ = cypher_create_field_map_derived_from( + "fm-1", "p-1", ModelRole.SOURCE + ) + assert "DERIVED_FROM" in stmt + with pytest.raises(ValueError, match="DERIVED_FROM"): + cypher_create_field_map_derived_from("fm-1", "p-1", ModelRole.TARGET) + + +def test_has_lineage_requires_source_property() -> None: + from sema.graph.planner_loader import cypher_create_plan_has_lineage + from sema.models.planner._enums import ModelRole + + stmt, _ = cypher_create_plan_has_lineage("plan-1", "p-1", ModelRole.SOURCE) + assert "HAS_LINEAGE" in stmt + with pytest.raises(ValueError, match="HAS_LINEAGE"): + cypher_create_plan_has_lineage("plan-1", "p-1", ModelRole.TARGET) + + +def test_resolution_input_requires_source_property() -> None: + from sema.graph.planner_loader import cypher_create_resolution_input + from sema.models.planner._enums import ModelRole + + stmt, _ = cypher_create_resolution_input("rp-1", "p-1", ModelRole.SOURCE) + assert "RESOLUTION_INPUT" in stmt 
+ with pytest.raises(ValueError, match="RESOLUTION_INPUT"): + cypher_create_resolution_input("rp-1", "p-1", ModelRole.TARGET) + + +def test_required_property_role_lookup_unknown_returns_none() -> None: + from sema.models.planner._role_validation import required_property_role + + assert required_property_role("ASSEMBLED_INTO") is None + assert required_property_role("MAPS_TO") is not None + + +def _mapping_assertion(**overrides): + from sema.models.planner.mapping_plan import MappingAssertion + from sema.models.planner.patterns import ( + DirectCopyPayload, + MappingPattern, + ) + from sema.models.planner.lifecycle import Status + + base = dict( + id="a-1", + source_field_ref="cbio.gender", + target_property_ref="omop.gender_concept_id", + pattern=MappingPattern.DIRECT_COPY, + payload=DirectCopyPayload(source_field_ref="cbio.gender"), + confidence=0.92, + provenance=_provenance(), + status=Status.candidate, + ) + base.update(overrides) + return MappingAssertion(**base) + + +def test_mapping_assertion_round_trip_via_properties() -> None: + from sema.graph.planner_loader import ( + mapping_assertion_to_properties, + properties_to_mapping_assertion, + ) + + a = _mapping_assertion() + rt = properties_to_mapping_assertion(mapping_assertion_to_properties(a)) + assert rt == a + + +def test_field_map_round_trip_via_properties() -> None: + from sema.graph.planner_loader import ( + field_map_to_properties, + properties_to_field_map, + ) + from sema.models.planner.field_map import FieldMap + from sema.models.planner.patterns import DirectCopyPayload, MappingPattern + + fm = FieldMap( + target_field_ref="omop.person.gender_concept_id", + pattern=MappingPattern.DIRECT_COPY, + payload=DirectCopyPayload(source_field_ref="cbio.gender"), + ) + rt = properties_to_field_map(field_map_to_properties(fm)) + assert rt == fm + + +def test_target_obligation_round_trip_via_properties() -> None: + from sema.graph.planner_loader import ( + properties_to_target_obligation, + 
target_obligation_to_properties, + ) + from sema.models.planner._enums import PrimaryKeyStrategy + from sema.models.planner.target_model import TargetObligation + + o = TargetObligation( + target_entity="omop.person", + required_fields=["person_id"], + primary_key=PrimaryKeyStrategy.NATURAL_KEY, + ) + rt = properties_to_target_obligation(target_obligation_to_properties(o)) + assert rt == o + + +def test_mapping_plan_round_trip_via_properties() -> None: + from sema.graph.planner_loader import ( + mapping_plan_to_properties, + properties_to_mapping_plan, + ) + from sema.models.planner._enums import ( + MaterializationMode, + PrimaryKeyStrategy, + ) + from sema.models.planner.field_map import FieldMap, RowIdentity + from sema.models.planner.mapping_plan import MappingPlan + from sema.models.planner.patterns import DirectCopyPayload, MappingPattern + from sema.models.planner.target_model import TargetObligation + + plan = MappingPlan( + id="plan-1", + source_scope_ref="cbio", + obligation=TargetObligation( + target_entity="omop.person", + required_fields=["gender_concept_id"], + primary_key=PrimaryKeyStrategy.NATURAL_KEY, + ), + row_identity=RowIdentity( + target_row_key_rule="hash", + source_lineage=["cbio.patient_id"], + materialization_mode=MaterializationMode.MERGE, + ), + field_maps=[ + FieldMap( + target_field_ref="omop.person.gender_concept_id", + pattern=MappingPattern.DIRECT_COPY, + payload=DirectCopyPayload(source_field_ref="cbio.gender"), + ) + ], + lineage=["cbio.patient", "cbio.gender"], + ) + rt = properties_to_mapping_plan(mapping_plan_to_properties(plan)) + assert rt == plan + + +def test_resolution_plan_round_trip_via_properties() -> None: + from datetime import datetime, timezone + + from sema.graph.planner_loader import ( + properties_to_resolution_plan, + resolution_plan_to_properties, + ) + from sema.models.planner.provenance import SourceScope + from sema.models.planner.resolution import ( + DeterministicHashPayload, + ResolutionPlan, + 
ResolutionStrategy, + ) + + rp = ResolutionPlan( + id="rp-1", + sources=[ + SourceScope( + source_id="cbio", source_schema_hash="s", source_profile_hash="p" + ) + ], + target_identity_ref="canonical.patient_id", + strategy=ResolutionStrategy.DETERMINISTIC_HASH, + payload=DeterministicHashPayload(source_key_refs=["cbio.patient_id"]), + confidence=1.0, + provenance_run=_provenance().run, + timestamp=datetime(2026, 1, 1, tzinfo=timezone.utc), + ) + props = resolution_plan_to_properties(rp) + assert props["prov_run_run_id"] == "run-1" + assert props["strategy"] == "DETERMINISTIC_HASH" + rt = properties_to_resolution_plan(props) + assert rt == rp + + +def test_risk_flag_round_trip_via_properties() -> None: + from sema.graph.planner_loader import ( + properties_to_risk_flag, + risk_flag_to_properties, + ) + from sema.models.planner.risk import ( + Evidence, + EvidenceMode, + RiskCode, + RiskFlag, + SensitivityClass, + Severity, + SourceStage, + SuggestedAction, + ) + + rf = RiskFlag( + code=RiskCode.RISK_VOCAB_DOMAIN_MISMATCH, + severity=Severity.warn, + evidence=[ + Evidence( + mode=EvidenceMode.COUNT_ONLY, + payload={"count": 3}, + sensitivity_class=SensitivityClass.PHI, + source_ref="cbio.x", + ) + ], + source_stage=SourceStage.producer, + suggested_action=SuggestedAction.review, + ) + rt = properties_to_risk_flag(risk_flag_to_properties(rf)) + assert rt == rf + + +def test_human_pin_round_trip_via_properties_assertion_pin() -> None: + from datetime import datetime, timezone + + from sema.graph.planner_loader import ( + human_pin_to_properties, + properties_to_human_pin, + ) + from sema.models.planner.lifecycle import HumanPin, PinState + + pin = HumanPin( + pin_id="pin-a1", + assertion_id="a-1", + pinned_at=datetime(2026, 1, 1, tzinfo=timezone.utc), + pinned_by="reviewer@x", + confirmed_under_run=_provenance().run, + confirmed_under_source=_provenance().source, + pin_state=PinState.active, + ) + rt = properties_to_human_pin(human_pin_to_properties(pin)) + assert rt 
== pin + + +class _FakeRow(dict): + pass + + +class _FakeResult: + def __init__(self, props: dict | None) -> None: + self._props = props + + def single(self) -> _FakeRow | None: + if self._props is None: + return None + return _FakeRow(p=self._props) + + +class _FakeSession: + """Minimal in-memory session double for write/read helper unit tests.""" + + def __init__(self) -> None: + self.store: dict[tuple[str, str], dict] = {} + + def run(self, cypher: str, **params): + if "MERGE" in cypher and "SET n = $props" in cypher: + label = cypher.split(":", 1)[1].split(" ")[0] + props = params["props"] + self.store[(label, props["id"])] = dict(props) + return _FakeResult(None) + if "MATCH" in cypher and "RETURN properties(n) AS p" in cypher: + label = cypher.split(":", 1)[1].split(" ")[0] + node_id = params["id"] + return _FakeResult(self.store.get((label, node_id))) + return _FakeResult(None) + + +def test_write_read_mapping_assertion_via_helpers() -> None: + from sema.graph.planner_loader import ( + read_mapping_assertion, + write_mapping_assertion, + ) + + session = _FakeSession() + a = _mapping_assertion() + write_mapping_assertion(session, a) + rt = read_mapping_assertion(session, a.id) + assert rt == a + + +def test_read_mapping_assertion_missing_raises() -> None: + from sema.graph.planner_loader import read_mapping_assertion + + with pytest.raises(LookupError): + read_mapping_assertion(_FakeSession(), "nope") + + +def test_write_read_human_pin_via_helpers() -> None: + from datetime import datetime, timezone + + from sema.graph.planner_loader import read_human_pin, write_human_pin + from sema.models.planner.lifecycle import HumanPin, PinState + + session = _FakeSession() + pin = HumanPin( + pin_id="pin-w", + assertion_id="a-1", + pinned_at=datetime(2026, 1, 1, tzinfo=timezone.utc), + pinned_by="reviewer@x", + confirmed_under_run=_provenance().run, + confirmed_under_source=_provenance().source, + pin_state=PinState.active, + ) + write_human_pin(session, pin) + rt = 
read_human_pin(session, "pin-w") + assert rt == pin + + +def test_write_read_target_obligation_via_helpers() -> None: + from sema.graph.planner_loader import ( + read_target_obligation, + write_target_obligation, + ) + from sema.models.planner._enums import PrimaryKeyStrategy + from sema.models.planner.target_model import TargetObligation + + session = _FakeSession() + o = TargetObligation( + target_entity="omop.person", + required_fields=["omop.person.person_id"], + primary_key=PrimaryKeyStrategy.NATURAL_KEY, + ) + write_target_obligation(session, o, obligation_id="ob-1") + rt = read_target_obligation(session, "ob-1") + assert rt == o + + +def test_write_read_field_map_via_helpers() -> None: + from sema.graph.planner_loader import read_field_map, write_field_map + from sema.models.planner.field_map import FieldMap + from sema.models.planner.patterns import ( + DirectCopyPayload, + MappingPattern, + ) + + session = _FakeSession() + fm = FieldMap( + target_field_ref="omop.person.gender_concept_id", + pattern=MappingPattern.DIRECT_COPY, + payload=DirectCopyPayload(source_field_ref="cbio.gender"), + ) + write_field_map(session, fm, field_map_id="fm-1") + rt = read_field_map(session, "fm-1") + assert rt == fm + + +def test_write_read_mapping_plan_via_helpers() -> None: + from sema.graph.planner_loader import read_mapping_plan, write_mapping_plan + from sema.models.planner._enums import ( + MaterializationMode, + PrimaryKeyStrategy, + ) + from sema.models.planner.field_map import FieldMap, RowIdentity + from sema.models.planner.mapping_plan import MappingPlan + from sema.models.planner.patterns import ( + DirectCopyPayload, + MappingPattern, + ) + from sema.models.planner.target_model import TargetObligation + + plan = MappingPlan( + id="plan-w", + source_scope_ref="cbio", + obligation=TargetObligation( + target_entity="omop.person", + required_fields=["omop.person.person_id"], + primary_key=PrimaryKeyStrategy.NATURAL_KEY, + ), + row_identity=RowIdentity( + 
target_row_key_rule="hash", + source_lineage=["cbio.x"], + materialization_mode=MaterializationMode.MERGE, + ), + field_maps=[ + FieldMap( + target_field_ref="omop.person.person_id", + pattern=MappingPattern.DIRECT_COPY, + payload=DirectCopyPayload(source_field_ref="cbio.patient_id"), + ) + ], + ) + session = _FakeSession() + write_mapping_plan(session, plan) + rt = read_mapping_plan(session, plan.id) + assert rt == plan + + +def test_write_read_resolution_plan_via_helpers() -> None: + from datetime import datetime, timezone + + from sema.graph.planner_loader import ( + read_resolution_plan, + write_resolution_plan, + ) + from sema.models.planner.provenance import SourceScope + from sema.models.planner.resolution import ( + DeterministicHashPayload, + ResolutionPlan, + ResolutionStrategy, + ) + + rp = ResolutionPlan( + id="rp-w", + sources=[ + SourceScope( + source_id="cbio", + source_schema_hash="s", + source_profile_hash="p", + ) + ], + target_identity_ref="canonical.patient_id", + strategy=ResolutionStrategy.DETERMINISTIC_HASH, + payload=DeterministicHashPayload(source_key_refs=["cbio.patient_id"]), + confidence=1.0, + provenance_run=_provenance().run, + timestamp=datetime(2026, 1, 1, tzinfo=timezone.utc), + ) + session = _FakeSession() + write_resolution_plan(session, rp) + rt = read_resolution_plan(session, rp.id) + assert rt == rp + + +def test_write_read_risk_flag_via_helpers() -> None: + from sema.graph.planner_loader import read_risk_flag, write_risk_flag + from sema.models.planner.risk import ( + Evidence, + EvidenceMode, + RiskCode, + RiskFlag, + SensitivityClass, + Severity, + SourceStage, + SuggestedAction, + ) + + rf = RiskFlag( + code=RiskCode.RISK_OBLIGATION_FK_UNSATISFIED, + severity=Severity.block, + evidence=[ + Evidence( + mode=EvidenceMode.COUNT_ONLY, + payload={"count": 1}, + sensitivity_class=SensitivityClass.PUBLIC, + source_ref="cbio.x", + ) + ], + source_stage=SourceStage.constraint, + suggested_action=SuggestedAction.review, + ) + session 
= _FakeSession() + write_risk_flag(session, rf, flag_id="rf-w") + rt = read_risk_flag(session, "rf-w") + assert rt == rf + + +def test_read_helpers_raise_on_missing() -> None: + from sema.graph.planner_loader import ( + read_field_map, + read_human_pin, + read_mapping_plan, + read_resolution_plan, + read_risk_flag, + read_target_obligation, + ) + + s = _FakeSession() + for fn in ( + read_field_map, + read_human_pin, + read_mapping_plan, + read_resolution_plan, + read_risk_flag, + read_target_obligation, + ): + with pytest.raises(LookupError): + fn(s, "missing-id") + + +def test_human_pin_round_trip_via_properties_resolution_pin() -> None: + from datetime import datetime, timezone + + from sema.graph.planner_loader import ( + human_pin_to_properties, + properties_to_human_pin, + ) + from sema.models.planner.lifecycle import HumanPin, PinState + + pin = HumanPin( + pin_id="pin-r1", + resolution_plan_id="rp-1", + pinned_at=datetime(2026, 1, 1, tzinfo=timezone.utc), + pinned_by="reviewer@x", + confirmed_under_run=_provenance().run, + confirmed_under_source=_provenance().source, + pin_state=PinState.stale, + ) + rt = properties_to_human_pin(human_pin_to_properties(pin)) + assert rt == pin diff --git a/tests/unit/models/planner/test_provenance_and_caching.py b/tests/unit/models/planner/test_provenance_and_caching.py new file mode 100644 index 0000000..f8e0518 --- /dev/null +++ b/tests/unit/models/planner/test_provenance_and_caching.py @@ -0,0 +1,246 @@ +"""Tests for the provenance-and-caching capability.""" + +from __future__ import annotations + +from datetime import datetime, timezone + +import pytest +from pydantic import ValidationError + +pytestmark = pytest.mark.unit + + +def _run_prov(**overrides: object) -> object: + from sema.models.planner.provenance import RunProvenance + + base = dict( + run_id="run-1", + target_model_version="omop-cdm-5.4", + target_schema_snapshot_hash="t-abc", + vocab_release="omop-2026-q1", + context_card_version="cards-v3", + 
prompt_template_version="tpl-7", + few_shot_set_version="fs-12", + constraint_version="rules-v2", + llm_model="claude-opus-4.7", + embedding_model="bge-large", + ) + base.update(overrides) + return RunProvenance(**base) + + +def _source_scope(**overrides: object) -> object: + from sema.models.planner.provenance import SourceScope + + base = dict( + source_id="cbioportal_gbm", + source_schema_hash="s-abc", + source_profile_hash="p-abc", + ) + base.update(overrides) + return SourceScope(**base) + + +def test_run_provenance_round_trip() -> None: + from sema.models.planner.provenance import RunProvenance + + rp = _run_prov() + payload = rp.model_dump(mode="json") + rt = RunProvenance.model_validate(payload) + assert rt.run_id == "run-1" + assert rt.llm_model == "claude-opus-4.7" + + +def test_run_provenance_required_fields() -> None: + from sema.models.planner.provenance import RunProvenance + + with pytest.raises(ValidationError): + RunProvenance( # type: ignore[call-arg] + run_id="r", + target_model_version="v", + target_schema_snapshot_hash="h", + context_card_version="c", + prompt_template_version="t", + few_shot_set_version="f", + llm_model="m", + ) + + +def test_source_scope_round_trip() -> None: + from sema.models.planner.provenance import SourceScope + + s = _source_scope() + rt = SourceScope.model_validate(s.model_dump(mode="json")) + assert rt.source_id == "cbioportal_gbm" + + +def test_provenance_composes_run_and_source() -> None: + from sema.models.planner.provenance import Provenance + + p = Provenance( + run=_run_prov(), + source=_source_scope(), + timestamp=datetime(2026, 1, 1, tzinfo=timezone.utc), + ) + assert p.run.run_id == "run-1" + assert p.source.source_id == "cbioportal_gbm" + + +def test_run_version_lock_detects_drift() -> None: + from sema.models.planner.provenance import RunVersionLock + + lock = RunVersionLock() + rp1 = _run_prov() + rp2 = _run_prov(prompt_template_version="tpl-8") + lock.bind(rp1) + with pytest.raises(ValueError): + 
lock.bind(rp2) + + +def test_run_version_lock_allows_same() -> None: + from sema.models.planner.provenance import RunVersionLock + + lock = RunVersionLock() + lock.bind(_run_prov()) + lock.bind(_run_prov()) + + +def test_source_scope_lock_per_run_id() -> None: + from sema.models.planner.provenance import SourceScopeLock + + lock = SourceScopeLock(run_id="run-1") + lock.bind(_source_scope()) + lock.bind(_source_scope(source_id="msk_chord", source_schema_hash="h2", source_profile_hash="p2")) + with pytest.raises(ValueError): + lock.bind(_source_scope(source_profile_hash="p-DRIFT")) + + +def test_prompt_artifact_prefix_hash_deterministic() -> None: + from sema.models.planner.provenance import PromptArtifact + + a = PromptArtifact.build( + prefix_text="system\ntarget cards\nfew-shot A\n", + suffix_text="source: cbio.patient.gender\n", + versions={"target_model_version": "omop-cdm-5.4"}, + ) + b = PromptArtifact.build( + prefix_text="system\ntarget cards\nfew-shot A\n", + suffix_text="source: cbio.patient.race\n", + versions={"target_model_version": "omop-cdm-5.4"}, + ) + assert a.prefix_hash == b.prefix_hash + assert a.suffix_text != b.suffix_text + + +def test_prompt_artifact_explicit_hash_must_match() -> None: + from sema.models.planner.provenance import PromptArtifact + + with pytest.raises(ValidationError): + PromptArtifact( + prefix_text="x", + prefix_hash="not-a-real-digest", + suffix_text="y", + versions={}, + ) + + +def test_prompt_artifact_source_isolation_passes_clean_prefix() -> None: + from sema.models.planner.provenance import PromptArtifact + + art = PromptArtifact.build( + prefix_text="system\nOMOP cards\nfew-shot for omop.person\n", + suffix_text="source: cbio.patient.gender\n", + versions={}, + ) + art.assert_source_isolated("cbio.patient.gender") + + +def test_prompt_artifact_source_isolation_rejects_leak() -> None: + from sema.models.planner.provenance import PromptArtifact + + art = PromptArtifact.build( + prefix_text="system\nfew-shot referencing 
cbio.patient.gender\n", + suffix_text="trailer\n", + versions={}, + ) + with pytest.raises(ValueError, match="cbio.patient.gender"): + art.assert_source_isolated("cbio.patient.gender") + + +def test_prompt_artifact_source_isolation_empty_ref_rejected() -> None: + from sema.models.planner.provenance import PromptArtifact + + art = PromptArtifact.build(prefix_text="p", suffix_text="s", versions={}) + with pytest.raises(ValueError, match="non-empty"): + art.assert_source_isolated("") + + +def test_cache_key_changes_with_tracked_dimension() -> None: + from sema.models.planner.provenance import ( + PromptArtifact, + derive_cache_key, + ) + + art = PromptArtifact.build( + prefix_text="prefix", + suffix_text="s1", + versions={"context_card_version": "v1"}, + ) + rp1 = _run_prov(context_card_version="cards-v3") + rp2 = _run_prov(context_card_version="cards-v4") + assert derive_cache_key(art, rp1) != derive_cache_key(art, rp2) + + +def test_cache_key_ignores_source_scope() -> None: + from sema.models.planner.provenance import ( + PromptArtifact, + derive_cache_key, + ) + + art = PromptArtifact.build(prefix_text="prefix", suffix_text="s", versions={}) + rp = _run_prov() + src1 = _source_scope() + src2 = _source_scope(source_id="msk_chord", source_schema_hash="x", source_profile_hash="y") + assert derive_cache_key(art, rp, src1) == derive_cache_key(art, rp, src2) + + +def test_source_profile_hash_stable() -> None: + from sema.models.planner.provenance import compute_source_profile_hash + + sig = { + "columns": [{"name": "gender", "samples": ["M", "F"], "distinct": 2, "null_rate": 0.0}], + } + a = compute_source_profile_hash(sig) + b = compute_source_profile_hash(sig) + sig_drift = { + "columns": [{"name": "gender", "samples": ["M", "F"], "distinct": 2, "null_rate": 0.05}], + } + c = compute_source_profile_hash(sig_drift) + assert a == b + assert a != c + + +def test_llm_runtime_protocol_anthropic_caches_prefix() -> None: + from sema.models.planner.provenance import ( + 
AnthropicCachingAdapter, + PromptArtifact, + ) + + adapter = AnthropicCachingAdapter(name="claude-opus-4.7") + art = PromptArtifact.build(prefix_text="prefix", suffix_text="s", versions={}) + headers = adapter.cache_directives(art) + assert headers.get("cache_control") == "ephemeral" + assert adapter.dialect == "anthropic" + + +def test_llm_runtime_protocol_mosaic_no_cache() -> None: + from sema.models.planner.provenance import ( + MosaicAIAdapter, + PromptArtifact, + ) + + adapter = MosaicAIAdapter(name="dbrx-instruct") + art = PromptArtifact.build(prefix_text="prefix", suffix_text="s", versions={}) + headers = adapter.cache_directives(art) + assert "cache_control" not in headers + assert adapter.dialect == "mosaic" diff --git a/tests/unit/models/planner/test_resolution_planner.py b/tests/unit/models/planner/test_resolution_planner.py new file mode 100644 index 0000000..9935d6e --- /dev/null +++ b/tests/unit/models/planner/test_resolution_planner.py @@ -0,0 +1,263 @@ +"""Tests for the resolution-planner capability.""" + +from __future__ import annotations + +from datetime import datetime, timezone + +import pytest +from pydantic import ValidationError + +pytestmark = pytest.mark.unit + + +def _run_prov() -> object: + from sema.models.planner.provenance import RunProvenance + + return RunProvenance( + run_id="run-1", + target_model_version="omop-cdm-5.4", + target_schema_snapshot_hash="t", + vocab_release="v", + context_card_version="c", + prompt_template_version="t1", + few_shot_set_version="f", + constraint_version="cv", + llm_model="m", + embedding_model="e", + ) + + +def _src(source_id: str = "cbioportal_gbm") -> object: + from sema.models.planner.provenance import SourceScope + + return SourceScope( + source_id=source_id, + source_schema_hash=f"s-{source_id}", + source_profile_hash=f"p-{source_id}", + ) + + +def test_resolution_strategy_values() -> None: + from sema.models.planner.resolution import ResolutionStrategy + + assert {s.value for s in 
ResolutionStrategy} == { + "DETERMINISTIC_HASH", + "FUZZY_BLOCK_AND_SCORE", + "GRAPH_CLOSURE", + "MULTI_KEY_UNION", + } + + +def test_cycle_handling_values() -> None: + from sema.models.planner.resolution import CycleHandling + + assert {c.value for c in CycleHandling} == { + "REJECT", + "BREAK_AT_DEPTH", + "MARK_AND_CONTINUE", + } + + +def test_resolution_verdict_values() -> None: + from sema.models.planner.resolution import ResolutionVerdict + + assert {v.value for v in ResolutionVerdict} == { + "resolved", + "ambiguous", + "unresolved", + "awaiting_review", + } + + +def test_deterministic_hash_payload() -> None: + from sema.models.planner.resolution import ( + DeterministicHashPayload, + ) + + p = DeterministicHashPayload( + source_key_refs=["cbio.study_id", "cbio.patient.patient_id"], + ) + assert len(p.source_key_refs) == 2 + + +def test_fuzzy_payload_requires_features() -> None: + from sema.models.planner.resolution import FuzzyBlockAndScorePayload + + with pytest.raises(ValidationError): + FuzzyBlockAndScorePayload(blocking_keys=[], similarity_features=[]) + with pytest.raises(ValidationError): + FuzzyBlockAndScorePayload( + blocking_keys=["addr.norm_name_prefix"], + similarity_features=[], + ) + + +def test_multi_key_union_requires_two() -> None: + from sema.models.planner.resolution import MultiKeyUnionPayload + + with pytest.raises(ValidationError): + MultiKeyUnionPayload(source_key_refs=["acris.bbl"]) + + p = MultiKeyUnionPayload( + source_key_refs=["acris.bbl", "acris.address", "dof.parcel_id"] + ) + assert len(p.source_key_refs) == 3 + + +def test_resolution_plan_graph_closure_requires_cycle_handling() -> None: + from sema.models.planner.resolution import ( + GraphClosurePayload, + ResolutionPlan, + ResolutionStrategy, + ) + + with pytest.raises(ValidationError): + ResolutionPlan( + id="r-1", + sources=[_src()], + target_identity_ref="canonical.llc_id", + strategy=ResolutionStrategy.GRAPH_CLOSURE, + 
payload=GraphClosurePayload(walk_relationship="OWNS"), + transitive_closure=False, + confidence=0.9, + provenance_run=_run_prov(), + timestamp=datetime(2026, 1, 1, tzinfo=timezone.utc), + ) + + +def test_resolution_plan_transitive_requires_cycle_handling() -> None: + from sema.models.planner.resolution import ( + DeterministicHashPayload, + ResolutionPlan, + ResolutionStrategy, + ) + + with pytest.raises(ValidationError): + ResolutionPlan( + id="r-2", + sources=[_src()], + target_identity_ref="canonical.id", + strategy=ResolutionStrategy.DETERMINISTIC_HASH, + payload=DeterministicHashPayload(source_key_refs=["cbio.x"]), + transitive_closure=True, + confidence=1.0, + provenance_run=_run_prov(), + timestamp=datetime(2026, 1, 1, tzinfo=timezone.utc), + ) + + +def test_deterministic_hash_implies_confidence_one() -> None: + from sema.models.planner.resolution import ( + DeterministicHashPayload, + ResolutionPlan, + ResolutionStrategy, + ) + + with pytest.raises(ValidationError): + ResolutionPlan( + id="r-3", + sources=[_src()], + target_identity_ref="canonical.id", + strategy=ResolutionStrategy.DETERMINISTIC_HASH, + payload=DeterministicHashPayload(source_key_refs=["cbio.x"]), + confidence=0.9, + provenance_run=_run_prov(), + timestamp=datetime(2026, 1, 1, tzinfo=timezone.utc), + ) + + +def test_multi_source_resolution_records_each_scope() -> None: + from sema.models.planner.resolution import ( + MultiKeyUnionPayload, + ResolutionPlan, + ResolutionStrategy, + ) + + plan = ResolutionPlan( + id="r-4", + sources=[_src("acris.deeds"), _src("dof.parcels")], + target_identity_ref="canonical.property_id", + strategy=ResolutionStrategy.MULTI_KEY_UNION, + payload=MultiKeyUnionPayload( + source_key_refs=["acris.bbl", "dof.parcel_id"] + ), + confidence=0.85, + provenance_run=_run_prov(), + timestamp=datetime(2026, 1, 1, tzinfo=timezone.utc), + ) + assert {s.source_id for s in plan.sources} == {"acris.deeds", "dof.parcels"} + + +def test_resolution_verdict_resolved_for_clean() -> 
None: + from sema.models.planner.resolution import ( + ResolutionVerdict, + derive_resolution_verdict, + ) + + v = derive_resolution_verdict( + produced_for_every_input=True, + ambiguous_assignments=False, + cycle_blocked=False, + any_block_flag=False, + plan_review_pending=False, + ) + assert v == ResolutionVerdict.resolved + + +def test_resolution_verdict_ambiguous_on_fuzzy_tie() -> None: + from sema.models.planner.resolution import ( + ResolutionVerdict, + derive_resolution_verdict, + ) + + v = derive_resolution_verdict( + produced_for_every_input=True, + ambiguous_assignments=True, + cycle_blocked=False, + any_block_flag=False, + plan_review_pending=False, + ) + assert v == ResolutionVerdict.ambiguous + + +def test_resolution_verdict_unresolved_on_cycle_block() -> None: + from sema.models.planner.resolution import ( + ResolutionVerdict, + derive_resolution_verdict, + ) + + v = derive_resolution_verdict( + produced_for_every_input=False, + ambiguous_assignments=False, + cycle_blocked=True, + any_block_flag=True, + plan_review_pending=False, + ) + assert v == ResolutionVerdict.unresolved + + +def test_resolution_verdict_awaiting_review() -> None: + from sema.models.planner.resolution import ( + ResolutionVerdict, + derive_resolution_verdict, + ) + + v = derive_resolution_verdict( + produced_for_every_input=True, + ambiguous_assignments=False, + cycle_blocked=False, + any_block_flag=True, + plan_review_pending=True, + ) + assert v == ResolutionVerdict.awaiting_review + + +def test_resolution_dependency_round_trip() -> None: + from sema.models.planner.resolution import ResolutionDependency + + rd = ResolutionDependency( + upstream_plan_id="r-4", + canonical_identity_column="canonical.property_id", + ) + rt = ResolutionDependency.model_validate(rd.model_dump(mode="json")) + assert rt.upstream_plan_id == "r-4" diff --git a/tests/unit/models/planner/test_risk_and_evidence.py b/tests/unit/models/planner/test_risk_and_evidence.py new file mode 100644 index 
0000000..9684754 --- /dev/null +++ b/tests/unit/models/planner/test_risk_and_evidence.py @@ -0,0 +1,234 @@ +"""Tests for the risk-and-evidence capability.""" + +from __future__ import annotations + +import pytest +from pydantic import ValidationError + +pytestmark = pytest.mark.unit + + +def test_severity_values() -> None: + from sema.models.planner.risk import Severity + + assert {s.value for s in Severity} == {"info", "warn", "block"} + + +def test_source_stage_values() -> None: + from sema.models.planner.risk import SourceStage + + assert {s.value for s in SourceStage} == { + "candidate_gen", + "producer", + "constraint", + "verify", + "transform", + } + + +def test_suggested_action_values() -> None: + from sema.models.planner.risk import SuggestedAction + + assert {a.value for a in SuggestedAction} == { + "review", + "request_more_samples", + "reject", + "ignore_with_reason", + } + + +def test_evidence_mode_values() -> None: + from sema.models.planner.risk import EvidenceMode + + assert {m.value for m in EvidenceMode} == { + "RAW", + "CATEGORICAL", + "HASH", + "COUNT_ONLY", + "EXCERPT", + } + + +def test_sensitivity_class_values() -> None: + from sema.models.planner.risk import SensitivityClass + + assert "PUBLIC" in {c.value for c in SensitivityClass} + assert "PHI" in {c.value for c in SensitivityClass} + assert "PII" in {c.value for c in SensitivityClass} + + +def test_default_evidence_mode_for_phi() -> None: + from sema.models.planner.risk import ( + SensitivityClass, + default_evidence_mode, + EvidenceMode, + ) + + assert default_evidence_mode(SensitivityClass.PHI) == EvidenceMode.CATEGORICAL + assert default_evidence_mode(SensitivityClass.PII) == EvidenceMode.HASH + assert default_evidence_mode(SensitivityClass.PUBLIC) == EvidenceMode.RAW + + +def test_evidence_count_only_rejects_literal() -> None: + from sema.models.planner.risk import ( + Evidence, + EvidenceMode, + SensitivityClass, + ) + + with pytest.raises(ValidationError): + Evidence( + 
mode=EvidenceMode.COUNT_ONLY, + payload={"value": "literal"}, + sensitivity_class=SensitivityClass.PUBLIC, + source_ref="cbio.patient.gender", + ) + + +def test_evidence_count_only_with_count_payload() -> None: + from sema.models.planner.risk import ( + Evidence, + EvidenceMode, + SensitivityClass, + ) + + e = Evidence( + mode=EvidenceMode.COUNT_ONLY, + payload={"count": 42}, + sensitivity_class=SensitivityClass.PHI, + source_ref="cbio.patient.gender", + ) + assert e.mode == EvidenceMode.COUNT_ONLY + + +def test_evidence_raw_against_phi_requires_override() -> None: + from sema.models.planner.risk import ( + Evidence, + EvidenceMode, + SensitivityClass, + ) + + with pytest.raises(ValidationError): + Evidence( + mode=EvidenceMode.RAW, + payload={"value": "Mr. Smith"}, + sensitivity_class=SensitivityClass.PHI, + source_ref="cbio.patient.name", + ) + + e = Evidence( + mode=EvidenceMode.RAW, + payload={"value": "Mr. Smith"}, + sensitivity_class=SensitivityClass.PHI, + source_ref="cbio.patient.name", + explicit_raw_override=True, + ) + assert e.explicit_raw_override is True + + +def test_risk_code_registered() -> None: + from sema.models.planner.risk import RiskCode + + expected = { + "RISK_VOCAB_DOMAIN_MISMATCH", + "RISK_PIVOT_CARDINALITY_UNVERIFIED", + "RISK_TEMPORAL_LOST", + "RISK_AMBIGUOUS_TARGET", + "RISK_OBLIGATION_REQUIRED_FIELD_MISSING", + "RISK_OBLIGATION_FK_UNSATISFIED", + "RISK_OBLIGATION_MINIMUM_VIABLE_ROW_VIOLATED", + "RISK_DEFAULT_APPLIED", + "RISK_RESOLUTION_DEPENDENCY_MISSING", + "RISK_LLC_CYCLE_DETECTED", + "RISK_ASSEMBLER_CONFLICT_RESOLVED", + } + actual = {c.value for c in RiskCode} + assert expected.issubset(actual) + + +def test_risk_flag_construction() -> None: + from sema.models.planner.risk import ( + Evidence, + EvidenceMode, + RiskCode, + RiskFlag, + SensitivityClass, + Severity, + SourceStage, + SuggestedAction, + ) + + rf = RiskFlag( + code=RiskCode.RISK_VOCAB_DOMAIN_MISMATCH, + severity=Severity.warn, + evidence=[ + Evidence( + 
mode=EvidenceMode.CATEGORICAL, + payload={"distinct": 3, "pattern": "[A-Z]+"}, + sensitivity_class=SensitivityClass.PHI, + source_ref="cbio.patient.gender", + ) + ], + source_stage=SourceStage.producer, + suggested_action=SuggestedAction.review, + ) + assert rf.code == RiskCode.RISK_VOCAB_DOMAIN_MISMATCH + assert len(rf.evidence) == 1 + + +def test_risk_flag_evidence_must_be_list() -> None: + from sema.models.planner.risk import ( + RiskCode, + RiskFlag, + Severity, + SourceStage, + SuggestedAction, + ) + + with pytest.raises(ValidationError): + RiskFlag( + code=RiskCode.RISK_AMBIGUOUS_TARGET, + severity=Severity.warn, + evidence="bare string", + source_stage=SourceStage.producer, + suggested_action=SuggestedAction.review, + ) + + +def test_risk_flag_round_trip() -> None: + from sema.models.planner.risk import ( + Evidence, + EvidenceMode, + RiskCode, + RiskFlag, + SensitivityClass, + Severity, + SourceStage, + SuggestedAction, + ) + + rf = RiskFlag( + code=RiskCode.RISK_AMBIGUOUS_TARGET, + severity=Severity.warn, + evidence=[ + Evidence( + mode=EvidenceMode.COUNT_ONLY, + payload={"count": 7}, + sensitivity_class=SensitivityClass.PHI, + source_ref="x", + ), + Evidence( + mode=EvidenceMode.CATEGORICAL, + payload={"shape": "alpha"}, + sensitivity_class=SensitivityClass.PHI, + source_ref="x", + ), + ], + source_stage=SourceStage.constraint, + suggested_action=SuggestedAction.review, + ) + payload = rf.model_dump(mode="json") + rt = RiskFlag.model_validate(payload) + assert len(rt.evidence) == 2 + assert rt.evidence[0].mode == EvidenceMode.COUNT_ONLY + assert rt.evidence[1].mode == EvidenceMode.CATEGORICAL diff --git a/tests/unit/models/planner/test_target_model.py b/tests/unit/models/planner/test_target_model.py new file mode 100644 index 0000000..987dc87 --- /dev/null +++ b/tests/unit/models/planner/test_target_model.py @@ -0,0 +1,352 @@ +"""Tests for the target-model capability.""" + +from __future__ import annotations + +import pytest +from pydantic import 
ValidationError + +pytestmark = pytest.mark.unit + + +def test_model_role_enum_values() -> None: + from sema.models.planner._enums import ModelRole + + assert {r.value for r in ModelRole} == {"SOURCE", "TARGET"} + + +def test_target_artifact_kind_values() -> None: + from sema.models.planner._enums import TargetArtifactKind + + assert {k.value for k in TargetArtifactKind} == { + "TABLE_ROW", + "GRAPH_NODE", + "GRAPH_EDGE", + } + + +def test_primary_key_strategy_values() -> None: + from sema.models.planner._enums import PrimaryKeyStrategy + + assert {s.value for s in PrimaryKeyStrategy} == { + "DETERMINISTIC_HASH", + "EXTERNAL_SEQUENCE", + "NATURAL_KEY", + "COMPOUND", + } + + +def test_materialization_mode_values() -> None: + from sema.models.planner._enums import MaterializationMode + + assert {m.value for m in MaterializationMode} == { + "INSERT_ONLY", + "MERGE", + "REPLACE_PARTITION", + } + + +def test_entity_default_role_is_source() -> None: + from sema.models.graph_nodes import Entity + from sema.models.planner._enums import ModelRole + + e = Entity( + id="e1", + name="cbio.patient", + source="cbio", + confidence=0.9, + source_id="cbio", + ) + assert e.model_role == ModelRole.SOURCE + assert e.target_model_id is None + assert e.source_id == "cbio" + + +def test_entity_source_role_requires_source_id() -> None: + from sema.models.graph_nodes import Entity + + with pytest.raises(ValidationError): + Entity(id="e", name="x", source="s", confidence=0.9) + + +def test_entity_target_role_with_kind() -> None: + from sema.models.graph_nodes import Entity + from sema.models.planner._enums import ModelRole, TargetArtifactKind + + e = Entity( + id="e2", + name="omop.person", + source="omop_loader", + confidence=1.0, + model_role=ModelRole.TARGET, + target_model_id="omop-cdm-5.4", + kind=TargetArtifactKind.TABLE_ROW, + ) + assert e.kind == TargetArtifactKind.TABLE_ROW + + +def test_entity_target_must_declare_kind() -> None: + from sema.models.graph_nodes import Entity + from 
sema.models.planner._enums import ModelRole + + with pytest.raises(ValidationError): + Entity( + id="e3", + name="omop.person", + source="omop_loader", + confidence=1.0, + model_role=ModelRole.TARGET, + target_model_id="omop-cdm-5.4", + ) + + +def test_entity_source_kind_forbidden() -> None: + from sema.models.graph_nodes import Entity + from sema.models.planner._enums import ModelRole, TargetArtifactKind + + with pytest.raises(ValidationError): + Entity( + id="e4", + name="cbio.patient", + source="cbio", + confidence=0.9, + model_role=ModelRole.SOURCE, + kind=TargetArtifactKind.TABLE_ROW, + ) + + +def test_entity_role_collision_rejected() -> None: + from sema.models.graph_nodes import Entity + from sema.models.planner._enums import ModelRole + + with pytest.raises(ValidationError): + Entity( + id="e5", + name="x", + source="s", + confidence=0.9, + model_role=ModelRole.SOURCE, + source_id="cbio", + target_model_id="omop-cdm-5.4", + ) + + +def test_entity_target_requires_target_model_id() -> None: + from sema.models.graph_nodes import Entity + from sema.models.planner._enums import ModelRole, TargetArtifactKind + + with pytest.raises(ValidationError): + Entity( + id="e6", + name="omop.person", + source="loader", + confidence=1.0, + model_role=ModelRole.TARGET, + kind=TargetArtifactKind.TABLE_ROW, + ) + + +def test_property_role_default() -> None: + from sema.models.graph_nodes import Property, SemanticType + from sema.models.planner._enums import ModelRole + + p = Property( + id="p1", + name="gender", + semantic_type=SemanticType.CATEGORICAL, + source="cbio", + confidence=0.9, + source_id="cbio", + ) + assert p.model_role == ModelRole.SOURCE + + +def test_term_role_collision_rejected() -> None: + from sema.models.graph_nodes import Term + from sema.models.planner._enums import ModelRole + + with pytest.raises(ValidationError): + Term( + id="t1", + code="X", + label="X", + source="x", + confidence=0.9, + model_role=ModelRole.TARGET, + target_model_id="omop-cdm-5.4", 
+ source_id="cbio", + ) + + +def test_constraint_default_role() -> None: + from sema.models.planner.target_model import Constraint + from sema.models.planner._enums import ModelRole + + c = Constraint( + id="c1", name="not_null", rule_kind="NULLABILITY", source_id="cbio" + ) + assert c.model_role == ModelRole.SOURCE + + +def test_foreign_key_obligation() -> None: + from sema.models.planner.target_model import ForeignKeyObligation + + fk = ForeignKeyObligation( + referenced_entity="omop.person", + join_keys=[("person_id", "person_id")], + same_build_required=True, + ) + assert fk.referenced_entity == "omop.person" + assert fk.same_build_required is True + + +def test_domain_constraint() -> None: + from sema.models.planner.target_model import DomainConstraint + + dc = DomainConstraint( + property_name="gender_concept_id", + domain_id="Gender", + ) + assert dc.domain_id == "Gender" + + +def test_row_predicate_and_clause_evaluates() -> None: + from sema.models.planner.target_model import ( + FieldPresence, + RowPredicate, + ) + + pred = RowPredicate( + op="AND", + clauses=[ + FieldPresence(field="person_id"), + FieldPresence(field="measurement_concept_id"), + FieldPresence(field="measurement_date"), + ], + ) + assert pred.evaluate({"person_id", "measurement_concept_id", "measurement_date"}) + assert not pred.evaluate({"person_id", "measurement_concept_id"}) + + +def test_row_predicate_or_clause() -> None: + from sema.models.planner.target_model import ( + FieldPresence, + RowPredicate, + ) + + pred = RowPredicate( + op="OR", + clauses=[FieldPresence(field="bbl"), FieldPresence(field="parcel_id")], + ) + assert pred.evaluate({"bbl"}) + assert pred.evaluate({"parcel_id"}) + assert not pred.evaluate(set()) + + +def test_row_predicate_field_equality() -> None: + from sema.models.planner.target_model import ( + FieldEquality, + RowPredicate, + ) + + pred = RowPredicate( + op="AND", + clauses=[FieldEquality(field="status", value="active")], + ) + assert 
pred.evaluate({"status"}, values={"status": "active"}) + assert not pred.evaluate({"status"}, values={"status": "archived"}) + + +def test_target_obligation_round_trip() -> None: + from sema.models.planner._enums import PrimaryKeyStrategy + from sema.models.planner.target_model import ( + ExternalSequenceMappingTable, + FieldPresence, + ForeignKeyObligation, + RowPredicate, + TargetObligation, + ) + + ob = TargetObligation( + target_entity="omop.person", + required_fields=["person_id", "gender_concept_id", "year_of_birth"], + nullable_fields=["race_concept_id"], + primary_key=PrimaryKeyStrategy.EXTERNAL_SEQUENCE, + external_sequence=ExternalSequenceMappingTable( + mapping_table_name="cbio_patient_to_omop_person", + canonical_identity_column="canonical_patient_id", + sequence_column="person_id", + ), + foreign_keys=[ + ForeignKeyObligation( + referenced_entity="omop.person", + join_keys=[("person_id", "person_id")], + ) + ], + allowed_defaults={"race_concept_id": 0}, + minimum_viable_row=RowPredicate( + op="AND", + clauses=[FieldPresence(field="person_id")], + ), + ) + payload = ob.model_dump(mode="json") + rt = TargetObligation.model_validate(payload) + assert rt.target_entity == ob.target_entity + assert rt.required_fields == ob.required_fields + assert rt.primary_key == PrimaryKeyStrategy.EXTERNAL_SEQUENCE + + +def test_target_obligation_minimum_viable_row_eval() -> None: + from sema.models.planner._enums import PrimaryKeyStrategy + from sema.models.planner.target_model import ( + FieldPresence, + RowPredicate, + TargetObligation, + ) + + ob = TargetObligation( + target_entity="omop.measurement", + required_fields=["person_id", "measurement_concept_id", "measurement_date"], + primary_key=PrimaryKeyStrategy.NATURAL_KEY, + minimum_viable_row=RowPredicate( + op="AND", + clauses=[ + FieldPresence(field="person_id"), + FieldPresence(field="measurement_concept_id"), + FieldPresence(field="measurement_date"), + ], + ), + ) + assert 
ob.minimum_viable_row.evaluate(set(ob.required_fields)) + assert not ob.minimum_viable_row.evaluate({"person_id", "measurement_concept_id"}) + + +def test_external_sequence_requires_mapping_table() -> None: + from sema.models.planner._enums import PrimaryKeyStrategy + from sema.models.planner.target_model import TargetObligation + + with pytest.raises(ValidationError): + TargetObligation( + target_entity="omop.person", + required_fields=["person_id"], + primary_key=PrimaryKeyStrategy.EXTERNAL_SEQUENCE, + ) + + +def test_non_external_sequence_rejects_mapping_table() -> None: + from sema.models.planner._enums import PrimaryKeyStrategy + from sema.models.planner.target_model import ( + ExternalSequenceMappingTable, + TargetObligation, + ) + + with pytest.raises(ValidationError): + TargetObligation( + target_entity="omop.person", + required_fields=["person_id"], + primary_key=PrimaryKeyStrategy.NATURAL_KEY, + external_sequence=ExternalSequenceMappingTable( + mapping_table_name="x", + canonical_identity_column="x", + sequence_column="x", + ), + ) diff --git a/tests/unit/test_graph_loader.py b/tests/unit/test_graph_loader.py index 9561e56..f86f55a 100644 --- a/tests/unit/test_graph_loader.py +++ b/tests/unit/test_graph_loader.py @@ -145,6 +145,28 @@ def test_upsert_property_uses_property_on_column( assert "IMPLEMENTED_BY" not in cypher assert "ON CREATE SET" in cypher + def test_upsert_property_stamps_implicit_entity_role( + self, loader, mock_driver, + ): + _, session = mock_driver + loader.upsert_property( + "Diagnosis Type", semantic_type="categorical", + source="llm", confidence=0.8, + entity_name="Cancer Diagnosis", + column_name="dx_type_cd", + table_name="cancer_diagnosis", + schema_name="clinical", catalog="cdm", + ) + cypher = session.run.call_args[0][0] + merge_idx = cypher.index("MERGE (e:Entity {name: $entity_name})") + has_property_idx = cypher.index("HAS_PROPERTY") + entity_block = cypher[merge_idx:has_property_idx] + assert "e.model_role = 
coalesce(e.model_role, 'SOURCE')" in entity_block + assert ( + "e.source_id = coalesce(e.source_id, $source_schema, $source)" + in entity_block + ) + def test_upsert_term(self, loader, mock_driver): _, session = mock_driver loader.upsert_term( @@ -157,6 +179,22 @@ def test_upsert_term(self, loader, mock_driver): assert ":Term" in cypher assert "ON CREATE SET" in cypher + def test_upsert_term_stamps_source_id_from_schema( + self, loader, mock_driver, + ): + _, session = mock_driver + loader.upsert_term( + "CRC", "Colorectal Cancer", source="llm", + confidence=0.85, source_schema="cbioportal_brca", + ) + cypher = session.run.call_args[0][0] + params = session.run.call_args[1] + assert ( + "t.source_id = coalesce(t.source_id, $source_schema, $source)" + in cypher + ) + assert params["source_schema"] == "cbioportal_brca" + def test_upsert_value_set(self, loader, mock_driver): _, session = mock_driver loader.upsert_value_set( diff --git a/tests/unit/test_graph_nodes.py b/tests/unit/test_graph_nodes.py index 60a04e8..b36750c 100644 --- a/tests/unit/test_graph_nodes.py +++ b/tests/unit/test_graph_nodes.py @@ -111,6 +111,7 @@ def test_entity_has_id(self): source="llm_interpretation", confidence=0.75, resolved_at=datetime.now(timezone.utc), + source_id="cbio", ) assert e.id == "ent-1" assert e.name == "Cancer Diagnosis" @@ -124,11 +125,14 @@ def test_entity_embedding_updated_at(self): source="test", confidence=0.9, embedding_updated_at=now, + source_id="test", ) assert e.embedding_updated_at == now def test_entity_embedding_updated_at_default_none(self): - e = Entity(id="ent-3", name="Test", source="test", confidence=0.9) + e = Entity( + id="ent-3", name="Test", source="test", confidence=0.9, source_id="test" + ) assert e.embedding_updated_at is None def test_property_has_id(self): @@ -138,6 +142,7 @@ def test_property_has_id(self): semantic_type=SemanticType.CATEGORICAL, source="llm_interpretation", confidence=0.8, + source_id="cbio", ) assert p.id == "prop-1" assert 
p.semantic_type == SemanticType.CATEGORICAL @@ -150,6 +155,7 @@ def test_property_embedding_updated_at(self): source="test", confidence=0.9, embedding_updated_at=datetime.now(timezone.utc), + source_id="test", ) assert p.embedding_updated_at is not None @@ -198,6 +204,7 @@ def test_term_has_id(self): label="Colorectal Cancer", source="llm_interpretation", confidence=0.85, + source_id="cbio", ) assert t.id == "term-1" assert t.code == "CRC" @@ -210,6 +217,7 @@ def test_term_embedding_updated_at(self): source="test", confidence=0.9, embedding_updated_at=datetime.now(timezone.utc), + source_id="test", ) assert t.embedding_updated_at is not None @@ -402,6 +410,7 @@ def test_entity_roundtrip(self): source="test", confidence=0.9, resolved_at=datetime(2026, 1, 1, tzinfo=timezone.utc), + source_id="test", ) data = e.model_dump(mode="json") roundtrip = Entity.model_validate(data) diff --git a/tests/unit/test_graph_source_schema.py b/tests/unit/test_graph_source_schema.py index 44dd4e9..f28daed 100644 --- a/tests/unit/test_graph_source_schema.py +++ b/tests/unit/test_graph_source_schema.py @@ -14,6 +14,7 @@ batch_upsert_entities, batch_upsert_join_paths, batch_upsert_properties, + batch_upsert_terms, batch_upsert_value_sets, ) from sema.models.assertions import ( @@ -133,6 +134,45 @@ def test_has_property_and_property_on_column_stamped(self, loader): assert "PROPERTY_ON_COLUMN" in cypher assert cypher.count("source_schema: r.source_schema") == 2 + def test_batch_upsert_properties_stamps_implicit_entity_role( + self, loader, + ): + loader._run = MagicMock() + batch_upsert_properties( + loader, [self._row()], source_schema=SCHEMA_BRCA, + ) + cypher = loader._run.call_args[0][0] + assert "MERGE (e:Entity {name: r.entity_name})" in cypher + merge_idx = cypher.index("MERGE (e:Entity {name: r.entity_name})") + has_property_idx = cypher.index("HAS_PROPERTY") + entity_block = cypher[merge_idx:has_property_idx] + assert "e.model_role = coalesce(e.model_role, 'SOURCE')" in 
entity_block + assert ( + "e.source_id = coalesce(e.source_id, r.source_schema, r.source)" + in entity_block + ) + + def test_batch_upsert_terms_stamps_source_id_from_schema( + self, loader, + ): + loader._run = MagicMock() + batch_upsert_terms( + loader, + [{ + "code": "0", "label": "neutral", + "vocabulary_name": "cna_call", + "source": "llm_interpretation", "confidence": 0.9, + }], + source_schema=SCHEMA_BRCA, + ) + cypher = loader._run.call_args[0][0] + assert ( + "t.source_id = coalesce(t.source_id, r.source_schema, r.source)" + in cypher + ) + rows = loader._run.call_args[1]["rows"] + assert rows[0]["source_schema"] == SCHEMA_BRCA + def test_has_value_set_stamped(self, loader): loader._run = MagicMock() batch_upsert_value_sets(