diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1e14644..d15deca 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -30,6 +30,13 @@ jobs: --cov-fail-under=85 \ -q + - name: Verify holdout slice contamination + run: | + uv run python scripts/check_slice_contamination.py \ + --holdout showcase/cbioportal_to_omop/slices/msk_chord_holdout.yaml \ + --against showcase/cbioportal_to_omop/slices/contamination_map.yaml \ + --against showcase/cbioportal_to_omop/slices/msk_chord_dev.yaml + - name: Upload coverage artifact if: always() uses: actions/upload-artifact@v4 diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/check_slice_contamination.py b/scripts/check_slice_contamination.py new file mode 100644 index 0000000..35ef315 --- /dev/null +++ b/scripts/check_slice_contamination.py @@ -0,0 +1,89 @@ +"""CLI to verify a holdout slice has zero overlap with few-shot sources +and any other slices passed as references. + +Usage: + uv run python scripts/check_slice_contamination.py \ + --holdout showcase/cbioportal_to_omop/slices/msk_chord_holdout.yaml \ + --against showcase/cbioportal_to_omop/slices/contamination_map.yaml \ + --against showcase/cbioportal_to_omop/slices/msk_chord_dev.yaml + +Exits non-zero if overlap is detected. 
+""" +from __future__ import annotations + +import argparse +import sys +from pathlib import Path +from typing import Iterable + +import yaml + + +class ContaminationError(RuntimeError): + pass + + +def load_contamination_map(path: Path) -> set[str]: + with path.open("r", encoding="utf-8") as fh: + data = yaml.safe_load(fh) or {} + contaminated = data.get("contaminated_tables") or [] + return {str(t) for t in contaminated} + + +def load_slice_tables(path: Path) -> set[str]: + with path.open("r", encoding="utf-8") as fh: + data = yaml.safe_load(fh) or {} + tables = data.get("tables") or [] + names: set[str] = set() + for entry in tables: + if isinstance(entry, dict) and entry.get("table_name"): + names.add(str(entry["table_name"])) + return names + + +def _load_reference(path: Path) -> set[str]: + contam = load_contamination_map(path) + if contam: + return contam + return load_slice_tables(path) + + +def check_contamination(holdout: Path, references: Iterable[Path]) -> None: + holdout_tables = load_slice_tables(holdout) + reference_tables: set[str] = set() + for ref in references: + reference_tables.update(_load_reference(ref)) + overlap = holdout_tables & reference_tables + if overlap: + raise ContaminationError( + f"Holdout slice {holdout} overlaps with reference tables: " + f"{sorted(overlap)}" + ) + + +def _parse_args(argv: list[str] | None = None) -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--holdout", type=Path, required=True, + help="Path to the holdout slice YAML", + ) + parser.add_argument( + "--against", type=Path, action="append", default=[], + help="Path to a contamination_map or dev slice YAML; repeatable", + ) + return parser.parse_args(argv) + + +def main(argv: list[str] | None = None) -> int: + args = _parse_args(argv) + try: + check_contamination(args.holdout, args.against) + except ContaminationError as e: + print(f"FAIL: {e}", file=sys.stderr) + return 1 + print(f"OK: {args.holdout} is 
disjoint from {len(args.against)} references") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/migrate_cbioportal_to_namespaced.py b/scripts/migrate_cbioportal_to_namespaced.py new file mode 100644 index 0000000..5a85ac0 --- /dev/null +++ b/scripts/migrate_cbioportal_to_namespaced.py @@ -0,0 +1,218 @@ +"""DuckDB migration: rename legacy `cbioportal` schema to namespaced equivalent. + +Renames the local DuckDB staging schema and registers the renamed schema in +`_sema_study_registry` so push discovery picks it up by default. + +Idempotent: safe to re-run. No-ops if the legacy schema is already absent +or the namespaced schema already exists. + +Usage: + uv run python scripts/migrate_cbioportal_to_namespaced.py \\ + --duckdb-path ~/.sema/staging.duckdb \\ + --study-id gbm_tcga_pan_can_atlas_2018 +""" +from __future__ import annotations + +import sys +from pathlib import Path +from typing import Callable + +import click + +from sema.ingest.comment_recovery import ParsedTableComments +from sema.ingest.duckdb_staging import Staging +from sema.ingest.duckdb_staging_utils import ( + build_column_comment_sql, + build_table_comment_sql, +) +from sema.ingest.naming import sanitize_schema_name +from sema.ingest.study_registry import StudyCollisionError, StudyRegistry +from sema.log import logger +from sema.models.config import IngestConfig + +LEGACY_SCHEMA = "cbioportal" +PREFIX = "cbioportal" + +CommentSource = Callable[[str], ParsedTableComments] + + +def _has_schema(staging: Staging, name: str) -> bool: + row = staging.execute( + "SELECT count(*) FROM information_schema.schemata WHERE schema_name = ?", + [name], + ).fetchone() + return bool(row and row[0]) + + +def _rename_schema( + staging: Staging, + src: str, + dst: str, + *, + comment_source: CommentSource | None = None, +) -> None: + """DuckDB has no ALTER SCHEMA RENAME — emulate via copy + drop. 
+ + `comment_source` is an optional callable that returns + `ParsedTableComments` for a given table name. When provided, parser + comments are re-applied after each `CREATE TABLE ... AS SELECT *` + so the rename round-trip preserves comments. + """ + staging.execute(f'CREATE SCHEMA IF NOT EXISTS "{dst}"') + rows = staging.execute( + "SELECT table_name FROM information_schema.tables " + "WHERE table_schema = ?", + [src], + ).fetchall() + for (table,) in rows: + staging.execute( + f'CREATE TABLE "{dst}"."{table}" AS SELECT * FROM "{src}"."{table}"' + ) + staging.execute(f'DROP TABLE "{src}"."{table}"') + _reapply_comments(staging, dst, table, comment_source) + staging.execute(f'DROP SCHEMA "{src}"') + + +def _reapply_comments( + staging: Staging, schema: str, table: str, + comment_source: CommentSource | None, +) -> None: + if comment_source is None: + return + try: + parsed = comment_source(table) + except Exception as err: # noqa: BLE001 + logger.warning( + "Comment source unavailable for {}.{}: {}; " + "run `sema ingest recover-comments` to restore later.", + schema, table, err, + ) + return + for column, comment in parsed.column_comments.items(): + if comment: + staging.execute( + build_column_comment_sql(schema, table, column, comment) + ) + if parsed.table_comment: + staging.execute( + build_table_comment_sql(schema, table, parsed.table_comment) + ) + + +def _default_comment_source( + cache_dir: Path, study_id: str, +) -> CommentSource | None: + study_dir = cache_dir / study_id + if not study_dir.exists(): + logger.warning( + "cBioPortal source cache not found at {}; migration will run " + "without comments. 
Run `sema ingest recover-comments --study {}` " + "later to restore.", + study_dir, study_id, + ) + return None + try: + from showcase.cbioportal_to_omop.comment_extract import ( + extract_study_comments, + ) + except ImportError as err: + logger.warning( + "showcase parser not importable ({}); skipping comment recovery.", + err, + ) + return None + parsed = extract_study_comments(study_dir) + return lambda table: parsed.get( + table, ParsedTableComments(table_comment=None, column_comments={}) + ) + + +def _resolve_target(study_id: str) -> str: + return sanitize_schema_name(PREFIX, study_id) + + +def _backfill_registry(staging: Staging, schema: str, study_id: str) -> None: + registry = StudyRegistry(staging) + try: + registry.register( + schema_name=schema, + original_study_id=study_id, + source_type="cbioportal", + ) + except StudyCollisionError as err: + raise click.ClickException(str(err)) from err + + +@click.command() +@click.option("--duckdb-path", "duckdb_path", default=None, help="DuckDB staging file path.") +@click.option( + "--study-id", + "study_id", + required=True, + help="Original cBioPortal study_id whose data lives in the legacy `cbioportal` schema.", +) +@click.option( + "--dry-run", + is_flag=True, + default=False, + help="Report what would change without executing the rename.", +) +def main(duckdb_path: str | None, study_id: str, dry_run: bool) -> None: + """Rename `cbioportal` -> namespaced schema in DuckDB and backfill registry.""" + config = IngestConfig() + if duckdb_path: + config.duckdb_path = duckdb_path + target_schema = _resolve_target(study_id) + logger.info( + "Migration target: schema {} -> {} in {}", + LEGACY_SCHEMA, target_schema, config.duckdb_path, + ) + + staging = Staging(config.duckdb_path) + try: + legacy_present = _has_schema(staging, LEGACY_SCHEMA) + target_present = _has_schema(staging, target_schema) + + if not legacy_present and target_present: + logger.info( + "Already migrated: {} absent, {} present. 
Backfilling registry only.", + LEGACY_SCHEMA, target_schema, + ) + if not dry_run: + _backfill_registry(staging, target_schema, study_id) + return + + if not legacy_present: + logger.info("Nothing to migrate: schema {} not present.", LEGACY_SCHEMA) + if not dry_run: + _backfill_registry(staging, target_schema, study_id) + return + + if target_present: + raise click.ClickException( + f"Both {LEGACY_SCHEMA} and {target_schema} schemas exist; " + "manual reconciliation required (the migration cannot merge schemas)." + ) + + if dry_run: + logger.info( + "DRY RUN: would ALTER SCHEMA {} RENAME TO {} and register {}", + LEGACY_SCHEMA, target_schema, study_id, + ) + return + + comment_source = _default_comment_source( + Path(config.cache_dir).expanduser(), study_id, + ) + _rename_schema( + staging, LEGACY_SCHEMA, target_schema, + comment_source=comment_source, + ) + _backfill_registry(staging, target_schema, study_id) + logger.info("Migration complete: {} -> {}", LEGACY_SCHEMA, target_schema) + finally: + staging.close() + + +if __name__ == "__main__": + sys.exit(main(standalone_mode=True)) diff --git a/scripts/verify_cross_study_collapse.py b/scripts/verify_cross_study_collapse.py new file mode 100644 index 0000000..8b58f6f --- /dev/null +++ b/scripts/verify_cross_study_collapse.py @@ -0,0 +1,101 @@ +"""Verify shared concepts collapse across studies in the Neo4j graph. + +Queries Neo4j for `:Term` nodes whose downstream paths reach columns in more +than one `source_schema`. Emits a JSON report listing each shared term, the +contributing schemas, and the connecting edge sample. Exits non-zero if a +specifically-required term (passed via `--require-code`) is not multi-study. 
+ +Usage: + uv run python scripts/verify_cross_study_collapse.py \\ + --output eval-runs/msk-chord-full-C/cross-study-collapse.json \\ + --require-code HGNC:TP53 +""" +from __future__ import annotations + +import json +import sys +from pathlib import Path +from typing import Any + +import click + +from sema.cli_factories import _get_neo4j_driver +from sema.log import logger +from sema.models.config import Neo4jConfig + +SHARED_TERM_QUERY = """ +MATCH (t:Term) +OPTIONAL MATCH (t)-[m:MEMBER_OF]->(:ValueSet)<-[hvs:HAS_VALUE_SET]-(c:Column) +WITH t, + collect(DISTINCT m.source_schema) + + collect(DISTINCT hvs.source_schema) AS schemas +WITH t, [s IN schemas WHERE s IS NOT NULL] AS schemas +WHERE size(schemas) > 1 +RETURN t.code AS code, + t.label AS label, + schemas +ORDER BY t.code +""" + + +def _query_shared_terms(driver: Any) -> list[dict[str, Any]]: + with driver.session() as session: + result = session.run(SHARED_TERM_QUERY) + return [ + { + "code": rec["code"], + "label": rec["label"], + "source_schemas": sorted(set(rec["schemas"])), + } + for rec in result + ] + + +def _enforce_required( + shared: list[dict[str, Any]], required_codes: tuple[str, ...] 
+) -> list[str]: + found = {entry["code"] for entry in shared} + return [code for code in required_codes if code not in found] + + +@click.command() +@click.option( + "--output", + "output_path", + required=True, + type=click.Path(path_type=Path), + help="JSON report destination path.", +) +@click.option( + "--require-code", + "required_codes", + multiple=True, + default=(), + help="Term code that must appear with multiple source_schemas (repeatable).", +) +def main(output_path: Path, required_codes: tuple[str, ...]) -> None: + """Run shared-term verification and write JSON report.""" + driver = _get_neo4j_driver(Neo4jConfig()) + try: + shared = _query_shared_terms(driver) + finally: + driver.close() + + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text( + json.dumps( + {"shared_terms": shared, "count": len(shared)}, + indent=2, + sort_keys=True, + ) + ) + logger.info("Wrote {} shared terms to {}", len(shared), output_path) + + missing = _enforce_required(shared, required_codes) + if missing: + logger.error("Required terms not multi-study: {}", missing) + sys.exit(1) + + +if __name__ == "__main__": + sys.exit(main(standalone_mode=True)) diff --git a/showcase/cbioportal_to_omop/README.md b/showcase/cbioportal_to_omop/README.md index c4e0328..517cd7e 100644 --- a/showcase/cbioportal_to_omop/README.md +++ b/showcase/cbioportal_to_omop/README.md @@ -11,32 +11,101 @@ End-to-end demo of the sema pipeline: - `parsers.py` — cBioPortal source parsers (clinical, MAF, SV, CNA, gene panel matrix, resources, timelines) - `cbioportal_utils.py` — download/type/IO helpers for `parsers.py` +- `comment_extract.py` — pure metadata-only walker that returns per-table `ParsedTableComments` for the cBioPortal cache; used by both the `sema ingest recover-comments` command and the DuckDB rename migration - `slices/dev_slice.yaml` — 13-table dev slice for prompt tuning - `slices/dev_slice_poc.yaml` — 12-table subset matching the current Databricks POC ingest - 
`slices/holdout.yaml` — 10-table held-out slice for bias checks -## Run from a source checkout +## Multi-study workflow (post-namespacing, 2026-04-24) + +Each cBioPortal `study_id` produces its own DuckDB schema and own Databricks +schema, named `cbioportal_<sanitized_study_id>`. The ingest path sanitizes the study id, +records the schema in `_sema_study_registry` (DuckDB), and is idempotent on +re-ingest. Push (default) discovers schemas from the registry plus the +known shared schemas (`ontology_omop`, `vocabulary_omop`); scratch schemas +in DuckDB are NOT published unless `--discover-all-schemas` is set. ```bash -# Step 1 — stage cBioPortal study into DuckDB -uv run sema ingest cbioportal gbm_tcga_pan_can_atlas_2018 \ - --cache-dir ~/.cache/sema/cbioportal \ - --duckdb-path ./poc.duckdb +# Step 1a — stage two cBioPortal studies into DuckDB (each lands in its own schema) +PYTHONPATH=. uv run sema ingest cbioportal gbm_tcga_pan_can_atlas_2018 \ + --cache-dir ~/.sema/cache/cbioportal \ + --duckdb-path ~/.sema/poc.duckdb +PYTHONPATH=. uv run sema ingest cbioportal msk_chord_2024 \ + --cache-dir ~/.sema/cache/cbioportal \ + --duckdb-path ~/.sema/poc.duckdb + +# Step 1b — OMOP CDM + vocabulary +uv run sema ingest omop --vocab-path ~/data/omop/athena_2026_04 -# Step 2 — push to Databricks (requires DATABRICKS_* env) -uv run sema push --target databricks --duckdb-path ./poc.duckdb +# Step 2 — push to Databricks. Default mode discovers from +# `_sema_study_registry` ∪ {ontology_omop, vocabulary_omop}; both study +# schemas land alongside the shared ontology/vocab schemas. 
+uv run sema push --target databricks --duckdb-path ~/.sema/poc.duckdb -# Step 3 — run the staged L2 pipeline against the catalog +# Step 3 — run the staged L2 pipeline against ONE study's catalog uv run sema build \ - --catalog workspace --schemas cbioportal_omop \ - --domain healthcare \ - --table-workers 1 --skip-embeddings --verbose + --catalog workspace --schemas cbioportal_msk_chord_2024 \ + --domain healthcare --table-workers 1 --skip-embeddings --verbose -# Step 4 — evaluate against a slice +# Step 4 — evaluate against a slice (slice file references the namespaced schema) uv run sema eval run \ - --slice showcase/cbioportal_to_omop/slices/dev_slice_poc.yaml \ - --label post-showcase-refactor \ - --output-dir eval-runs/post-showcase-refactor + --slice showcase/cbioportal_to_omop/slices/msk_chord_dev.yaml \ + --label baseline-A \ + --output-dir eval-runs/msk-chord-baseline-A +``` + +### Scoped re-build per study + +Each `sema build --schemas X` run begins with a scoped-delete that removes +every relationship stamped with `source_schema = X` plus `:Assertion` and +`:JoinPath` nodes whose `source_schema = X`. Other studies' assertions and +provenance edges are untouched; shared concept and physical nodes are +never deleted. Re-running BRCA(GBM) does not affect MSK CHORD's slice of +the graph, and vice versa. + +### Legacy schema deprecation + +The flat `workspace.cbioportal` schema (pre-2026-04-24) is **deprecated** +and tagged with a `COMMENT ON SCHEMA` describing the migration. Its +contents (GBM TCGA Pan-Can Atlas 2018) live at +`workspace.cbioportal_gbm_tcga_pan_can_atlas_2018` post-migration. The +flat schema will be dropped in a follow-up change after a one-milestone +deprecation window. 
To migrate a local DuckDB staging file, run: + +```bash +uv run python scripts/migrate_cbioportal_to_namespaced.py \ + --duckdb-path ~/.sema/poc.duckdb \ + --study-id gbm_tcga_pan_can_atlas_2018 +``` + +The DuckDB rename emulator (CREATE TABLE … AS SELECT \*) drops column and +table comments. The migration script now re-applies parser-extracted +comments after each table rebuild when the local source cache +(`IngestConfig.cache_dir / <study_id>`) is present. When the cache is +absent, the migration logs a WARN and the operator can restore comments +later via `sema ingest recover-comments`. + +### Recovering lost column / table comments + +`sema ingest recover-comments` re-parses cBioPortal source files for a +study and applies `ALTER TABLE … ALTER COLUMN … COMMENT '…'` / +`COMMENT ON TABLE … IS '…'` to the corresponding namespaced Databricks +schema. It does NOT re-stage data, is idempotent (skips columns that +already have a comment), and supports `--dry-run` and `--json`. + +```bash +# Dry-run to inspect the SQL plan +uv run sema ingest recover-comments \ + --study gbm_tcga_pan_can_atlas_2018 --dry-run --json + +# Apply +uv run sema ingest recover-comments --study gbm_tcga_pan_can_atlas_2018 + +# Bypass the registry entirely +uv run sema ingest recover-comments \ + --source-cache /path/to/cache \ + --target-catalog workspace \ + --target-schema cbioportal_x ``` ## Packaging note diff --git a/showcase/cbioportal_to_omop/cbioportal_fetch_utils.py b/showcase/cbioportal_to_omop/cbioportal_fetch_utils.py new file mode 100644 index 0000000..d173e9d --- /dev/null +++ b/showcase/cbioportal_to_omop/cbioportal_fetch_utils.py @@ -0,0 +1,72 @@ +"""Network fetch helpers for the cBioPortal datahub. + +Pulled out of parsers.py to keep that module under 400 lines. 
+""" +from __future__ import annotations + +import json +from pathlib import Path +from urllib.request import Request, urlopen + +from showcase.cbioportal_to_omop.cbioportal_utils import ( + GITHUB_API_TEMPLATE, + MEDIA_URL_TEMPLATE, + RAW_URL_TEMPLATE, +) +from sema.log import logger + + +def fetch_study_files(study_id: str, cache_dir: Path) -> Path: + study_cache = cache_dir / study_id + done_marker = study_cache / ".done" + if done_marker.exists(): + logger.info("Using cached cBioPortal study files at {}", study_cache) + return study_cache + study_cache.mkdir(parents=True, exist_ok=True) + entries = list_study_entries(study_id) + downloaded = 0 + from showcase.cbioportal_to_omop.parsers import _should_download + for entry in entries: + if entry.get("type") != "file": + continue + name = entry["name"] + if not _should_download(name): + continue + fetch_lfs_or_raw(study_id, name, study_cache / name) + downloaded += 1 + done_marker.touch() + logger.info("Fetched {} cBioPortal files for {}", downloaded, study_id) + return study_cache + + +def fetch_lfs_or_raw(study_id: str, filename: str, target: Path) -> None: + media_url = MEDIA_URL_TEMPLATE.format(study_id=study_id, filename=filename) + try: + download_url_to(media_url, target) + return + except Exception as media_err: + logger.debug("media URL failed for {}: {}; falling back to raw", filename, media_err) + raw_url = RAW_URL_TEMPLATE.format(study_id=study_id, filename=filename) + download_url_to(raw_url, target) + + +def list_study_entries(study_id: str) -> list[dict[str, str]]: + url = GITHUB_API_TEMPLATE.format(study_id=study_id) + req = Request(url, headers={"Accept": "application/vnd.github+json"}) + with urlopen(req) as resp: + data: bytes = resp.read() + parsed = json.loads(data.decode("utf-8")) + if not isinstance(parsed, list): + raise RuntimeError(f"Unexpected GitHub API response for {study_id}: {parsed!r}") + return parsed + + +def download_url_to(url: str, target: Path) -> None: + 
logger.info("Downloading {} -> {}", url, target) + with urlopen(url) as resp: + with target.open("wb") as out: + while True: + chunk = resp.read(1024 * 1024) + if not chunk: + break + out.write(chunk) diff --git a/showcase/cbioportal_to_omop/cbioportal_utils.py b/showcase/cbioportal_to_omop/cbioportal_utils.py index 3286575..e6270c1 100644 --- a/showcase/cbioportal_to_omop/cbioportal_utils.py +++ b/showcase/cbioportal_to_omop/cbioportal_utils.py @@ -46,6 +46,28 @@ "data_gene_panel_matrix.txt", }) +DOWNLOAD_EXTENSIONS: tuple[str, ...] = (".seg",) + +LAB_TIMELINE_REQUIRED_COLUMNS: frozenset[str] = frozenset({"TEST", "VALUE", "UNITS"}) + +SEG_COLUMN_RENAMES: dict[str, str] = { + "ID": "sample_id", + "chrom": "chrom", + "loc.start": "loc_start", + "loc.end": "loc_end", + "num.mark": "num_mark", + "seg.mean": "seg_mean", +} + +SEG_COLUMN_TYPES: dict[str, str] = { + "sample_id": "VARCHAR", + "chrom": "VARCHAR", + "loc_start": "BIGINT", + "loc_end": "BIGINT", + "num_mark": "BIGINT", + "seg_mean": "DOUBLE", +} + DOWNLOAD_PREFIXES: tuple[str, ...] = ( "data_clinical_", "data_timeline_", @@ -72,7 +94,7 @@ re.compile(r"^data_rppa.*\.txt$"), ) -TIMELINE_PATTERN = re.compile(r"^data_timeline_(?P[a-zA-Z0-9_]+)\.txt$") +TIMELINE_PATTERN = re.compile(r"^data_timeline_(?P[a-zA-Z0-9_-]+)\.txt$") @dataclass @@ -124,6 +146,47 @@ def sv_column_type(name: str) -> str: return "VARCHAR" +def normalize_seg_header(header: list[str]) -> list[str]: + return [SEG_COLUMN_RENAMES.get(h.strip(), h.strip().replace(".", "_")) for h in header] + + +def gene_panel_long_rows( + header: list[str], data_rows: list[list[str]], +) -> tuple[list[str], list[list[str | None]]]: + """Pivot gene panel matrix from wide (sample × assay) to long. + + Wide: SAMPLE_ID | mutations | cna | sv ... 
+ Long: sample_id, panel_id, assay (one row per non-blank cell) + """ + if not header: + raise ValueError("gene_panel_matrix header is empty") + sample_col = _find_sample_column_index(header) + out_header = ["sample_id", "panel_id", "assay"] + out_rows: list[list[str | None]] = [] + for row in data_rows: + sample = row[sample_col] if sample_col < len(row) else "" + for idx, col_name in enumerate(header): + if idx == sample_col: + continue + value = row[idx].strip() if idx < len(row) and row[idx] else "" + if not value: + continue + out_rows.append([sample, value, col_name]) + return out_header, out_rows + + +def _find_sample_column_index(header: list[str]) -> int: + for i, name in enumerate(header): + if name.strip().lower() in {"sample_id", "sample id"}: + return i + return 0 + + +def is_lab_timeline_header(column_names: list[str]) -> bool: + upper = {c.upper() for c in column_names} + return LAB_TIMELINE_REQUIRED_COLUMNS.issubset(upper) + + def cna_long_format_rows( header: list[str], data_rows: list[list[str]], ) -> tuple[list[str], list[list[str | None]]]: @@ -171,6 +234,27 @@ def open_text_defensive(path: Path) -> IO[str]: return path.open("r", encoding="utf-8", errors="replace") +def dedupe_header_case_insensitive(header: list[str]) -> list[str]: + """Suffix duplicate column names so DuckDB's case-insensitive identifier + rule does not reject the CREATE TABLE. + + cBioPortal MAF files in some studies (e.g. MSK CHORD 2024) ship with + case-variant duplicates such as `Comments` / `comments`. The first + occurrence keeps its original casing; subsequent collisions append + `_2`, `_3`, etc., based on lower-cased identity. 
+ """ + seen: dict[str, int] = {} + deduped: list[str] = [] + for name in header: + key = name.lower() + seen[key] = seen.get(key, 0) + 1 + if seen[key] == 1: + deduped.append(name) + else: + deduped.append(f"{name}_{seen[key]}") + return deduped + + def read_tsv_rows( path: Path, skip_comment_prefix: bool = True ) -> tuple[list[str], list[list[str]], int]: diff --git a/showcase/cbioportal_to_omop/comment_extract.py b/showcase/cbioportal_to_omop/comment_extract.py new file mode 100644 index 0000000..0b238a6 --- /dev/null +++ b/showcase/cbioportal_to_omop/comment_extract.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +from pathlib import Path + +from sema.ingest.comment_recovery import ParsedTableComments + +from showcase.cbioportal_to_omop.parsers import ( + iter_timeline_files, + parse_clinical_file, + timeline_table_name, +) + + +_FIXED_TABLE_DESCRIPTORS: tuple[tuple[str, str, str], ...] = ( + ("data_sv.txt", "structural_variant", + "cBioPortal structural variants from data_sv.txt"), + ("data_cna.txt", "cna", + "cBioPortal copy-number alterations from data_cna.txt " + "(pivoted to long format: one row per sample×gene)"), + ("data_gene_panel_matrix.txt", "gene_panel_matrix", + "cBioPortal gene panel matrix from data_gene_panel_matrix.txt"), +) + +_MAF_FILENAMES: tuple[str, ...] 
= ("data_mutations.txt", "data_mutations_extended.txt") + + +def extract_study_comments(study_dir: Path) -> dict[str, ParsedTableComments]: + out: dict[str, ParsedTableComments] = {} + _extract_clinical(study_dir, out, "data_clinical_patient.txt", "patient") + _extract_clinical(study_dir, out, "data_clinical_sample.txt", "sample") + _extract_supp_clinical(study_dir, out) + _extract_maf(study_dir, out) + _extract_fixed_tables(study_dir, out) + _extract_segmented_cna(study_dir, out) + _extract_resource_files(study_dir, out) + _extract_timelines(study_dir, out) + return out + + +def _extract_clinical( + study_dir: Path, + out: dict[str, ParsedTableComments], + filename: str, + table: str, +) -> None: + path = study_dir / filename + if not path.exists(): + return + _, _, column_comments = parse_clinical_file(path) + out[table] = ParsedTableComments( + table_comment=f"cBioPortal clinical {table} from {filename}", + column_comments=column_comments, + ) + + +def _extract_supp_clinical( + study_dir: Path, out: dict[str, ParsedTableComments], +) -> None: + for entry in sorted(study_dir.iterdir()): + if not (entry.is_file() and entry.name.startswith("data_clinical_supp_")): + continue + if not entry.name.endswith(".txt"): + continue + table = entry.stem.removeprefix("data_") + _, _, column_comments = parse_clinical_file(entry) + out[table] = ParsedTableComments( + table_comment=f"cBioPortal {table} from {entry.name}", + column_comments=column_comments, + ) + + +def _extract_maf( + study_dir: Path, out: dict[str, ParsedTableComments], +) -> None: + for candidate in _MAF_FILENAMES: + path = study_dir / candidate + if path.exists(): + out["mutation"] = ParsedTableComments( + table_comment=f"cBioPortal MAF mutations from {candidate}", + column_comments={}, + ) + return + + +def _extract_fixed_tables( + study_dir: Path, out: dict[str, ParsedTableComments], +) -> None: + for filename, table, comment in _FIXED_TABLE_DESCRIPTORS: + if not (study_dir / filename).exists(): + continue 
+ out[table] = ParsedTableComments( + table_comment=comment, column_comments={}, + ) + + +def _extract_segmented_cna( + study_dir: Path, out: dict[str, ParsedTableComments], +) -> None: + seg_files = sorted(study_dir.glob("*.seg")) + if not seg_files: + return + seg = seg_files[0] + out["cna_segmented"] = ParsedTableComments( + table_comment=f"cBioPortal segmented CNA from {seg.name}", + column_comments={}, + ) + + +def _extract_resource_files( + study_dir: Path, out: dict[str, ParsedTableComments], +) -> None: + for entry in sorted(study_dir.iterdir()): + if not (entry.is_file() and entry.name.startswith("data_resource_")): + continue + if not entry.name.endswith(".txt"): + continue + table = entry.stem.removeprefix("data_") + out[table] = ParsedTableComments( + table_comment=f"cBioPortal {table} from {entry.name}", + column_comments={}, + ) + + +def _extract_timelines( + study_dir: Path, out: dict[str, ParsedTableComments], +) -> None: + for kind, path in iter_timeline_files(study_dir): + table = timeline_table_name(kind) + out[table] = ParsedTableComments( + table_comment=f"cBioPortal {table} from {path.name}", + column_comments={}, + ) diff --git a/showcase/cbioportal_to_omop/parsers.py b/showcase/cbioportal_to_omop/parsers.py index b906efb..089fd99 100644 --- a/showcase/cbioportal_to_omop/parsers.py +++ b/showcase/cbioportal_to_omop/parsers.py @@ -1,25 +1,30 @@ from __future__ import annotations -import json from pathlib import Path from typing import Any, Iterator -from urllib.request import Request, urlopen import pyarrow as pa +from showcase.cbioportal_to_omop.cbioportal_fetch_utils import ( + fetch_study_files, +) + from showcase.cbioportal_to_omop.cbioportal_utils import ( DOWNLOAD_EXACT_FILENAMES, + DOWNLOAD_EXTENSIONS, DOWNLOAD_PREFIXES, EXCLUDED_DOWNLOAD_PREFIXES, - GITHUB_API_TEMPLATE, - MEDIA_URL_TEMPLATE, - RAW_URL_TEMPLATE, + SEG_COLUMN_TYPES, SKIP_FILENAME_PATTERNS, TIMELINE_PATTERN, ClinicalHeader, cbioportal_type_to_duckdb, cna_long_format_rows, 
+ dedupe_header_case_insensitive, + gene_panel_long_rows, + is_lab_timeline_header, maf_column_type, + normalize_seg_header, parse_clinical_header, read_clinical_data_rows, read_header_block, @@ -28,8 +33,12 @@ sv_column_type, ) from sema.ingest.duckdb_staging import Staging +from sema.ingest.naming import sanitize_schema_name +from sema.ingest.study_registry import StudyRegistry from sema.log import logger +CBIOPORTAL_SCHEMA_PREFIX = "cbioportal" + __all__ = [ "ClinicalHeader", "fetch_study_files", @@ -39,77 +48,27 @@ "parse_clinical_header", "parse_cna_file", "parse_gene_panel_matrix", + "parse_lab_timeline", "parse_maf", "parse_resource_file", + "parse_segmented_cna", "parse_sv_file", "parse_timeline_file", + "timeline_table_name", ] -def fetch_study_files(study_id: str, cache_dir: Path) -> Path: - study_cache = cache_dir / study_id - done_marker = study_cache / ".done" - if done_marker.exists(): - logger.info("Using cached cBioPortal study files at {}", study_cache) - return study_cache - study_cache.mkdir(parents=True, exist_ok=True) - entries = _list_study_entries(study_id) - downloaded = 0 - for entry in entries: - if entry.get("type") != "file": - continue - name = entry["name"] - if not _should_download(name): - continue - _fetch_lfs_or_raw(study_id, name, study_cache / name) - downloaded += 1 - done_marker.touch() - logger.info("Fetched {} cBioPortal files for {}", downloaded, study_id) - return study_cache - - -def _fetch_lfs_or_raw(study_id: str, filename: str, target: Path) -> None: - media_url = MEDIA_URL_TEMPLATE.format(study_id=study_id, filename=filename) - try: - _download_url_to(media_url, target) - return - except Exception as media_err: - logger.debug("media URL failed for {}: {}; falling back to raw", filename, media_err) - raw_url = RAW_URL_TEMPLATE.format(study_id=study_id, filename=filename) - _download_url_to(raw_url, target) - - -def _list_study_entries(study_id: str) -> list[dict[str, str]]: - url = 
GITHUB_API_TEMPLATE.format(study_id=study_id) - req = Request(url, headers={"Accept": "application/vnd.github+json"}) - with urlopen(req) as resp: - data: bytes = resp.read() - parsed = json.loads(data.decode("utf-8")) - if not isinstance(parsed, list): - raise RuntimeError(f"Unexpected GitHub API response for {study_id}: {parsed!r}") - return parsed - - def _should_download(filename: str) -> bool: lowered = filename.lower() if filename in DOWNLOAD_EXACT_FILENAMES: return True + if any(lowered.endswith(ext) for ext in DOWNLOAD_EXTENSIONS): + return True if any(lowered.startswith(p) for p in EXCLUDED_DOWNLOAD_PREFIXES): return False return any(filename.startswith(p) for p in DOWNLOAD_PREFIXES) -def _download_url_to(url: str, target: Path) -> None: - logger.info("Downloading {} -> {}", url, target) - with urlopen(url) as resp: - with target.open("wb") as out: - while True: - chunk = resp.read(1024 * 1024) - if not chunk: - break - out.write(chunk) - - def parse_clinical_file( path: Path, ) -> tuple[pa.Table, dict[str, str], dict[str, str]]: @@ -143,6 +102,7 @@ def _clinical_column_comments(header: ClinicalHeader) -> dict[str, str]: def parse_maf(path: Path) -> tuple[pa.Table, dict[str, str], dict[str, str]]: column_names, data_rows, _ = read_tsv_rows(path, skip_comment_prefix=True) + column_names = dedupe_header_case_insensitive(column_names) column_types = {name: maf_column_type(name) for name in column_names} return rows_to_arrow(column_names, data_rows, column_types), column_types, {} @@ -192,15 +152,46 @@ def parse_cna_file( def parse_gene_panel_matrix( path: Path, ) -> tuple[pa.Table, dict[str, str], dict[str, str]]: - column_names, data_rows, _ = read_tsv_rows(path, skip_comment_prefix=True) - column_types = {name: "VARCHAR" for name in column_names} + header, data_rows, _ = read_tsv_rows(path, skip_comment_prefix=True) + long_header, long_rows = gene_panel_long_rows(header, data_rows) + column_types = {name: "VARCHAR" for name in long_header} + rows_as_str: 
list[list[str]] = [ + ["" if v is None else str(v) for v in row] + for row in long_rows + ] return ( - rows_to_arrow(column_names, data_rows, column_types), + rows_to_arrow(long_header, rows_as_str, column_types), + column_types, + {}, + ) + + +def parse_segmented_cna( + path: Path, +) -> tuple[pa.Table, dict[str, str], dict[str, str]]: + header, data_rows, _ = read_tsv_rows(path, skip_comment_prefix=True) + normalized = normalize_seg_header(header) + column_types = { + name: SEG_COLUMN_TYPES.get(name, "VARCHAR") for name in normalized + } + return ( + rows_to_arrow(normalized, data_rows, column_types), column_types, {}, ) +def parse_lab_timeline( + path: Path, +) -> tuple[pa.Table, dict[str, str], dict[str, str]]: + column_names, data_rows, _ = read_tsv_rows(path, skip_comment_prefix=True) + column_types: dict[str, str] = {name: "VARCHAR" for name in column_names} + for name in column_names: + if name.upper() == "VALUE": + column_types[name] = "DOUBLE" + return rows_to_arrow(column_names, data_rows, column_types), column_types, {} + + def parse_resource_file( path: Path, ) -> tuple[pa.Table, dict[str, str], dict[str, str]]: @@ -222,6 +213,10 @@ def iter_timeline_files(directory: Path) -> Iterator[tuple[str, Path]]: yield match.group("kind"), entry +def timeline_table_name(kind: str) -> str: + return f"timeline_{kind.replace('-', '_')}" + + def _list_skipped_files(directory: Path) -> list[Path]: skipped: list[Path] = [] for entry in sorted(directory.iterdir()): @@ -242,33 +237,73 @@ def ingest_study( staging: Staging, cache_dir: Path, ) -> None: + schema_name = sanitize_schema_name(CBIOPORTAL_SCHEMA_PREFIX, study_id) + StudyRegistry(staging).register( + schema_name=schema_name, + original_study_id=study_id, + source_type="cbioportal", + ) + staging.execute(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"') study_dir = fetch_study_files(study_id, cache_dir) - _ingest_study_dir(study_id, study_dir, staging) + _ingest_study_dir(study_id, study_dir, staging, 
schema_name=schema_name) -def _ingest_study_dir(study_id: str, study_dir: Path, staging: Staging) -> None: +def _ingest_study_dir( + study_id: str, + study_dir: Path, + staging: Staging, + schema_name: str, +) -> None: for skipped in _list_skipped_files(study_dir): logger.info("Skipping unsupported cBioPortal file: {}", skipped.name) - _try_ingest_clinical(study_dir, staging, "data_clinical_patient.txt", "patient") - _try_ingest_clinical(study_dir, staging, "data_clinical_sample.txt", "sample") - _try_ingest_maf(study_dir, staging) - _try_ingest_fixed_files(study_dir, staging) + _try_ingest_clinical( + study_dir, staging, "data_clinical_patient.txt", "patient", + schema_name=schema_name, + ) + _try_ingest_clinical( + study_dir, staging, "data_clinical_sample.txt", "sample", + schema_name=schema_name, + ) + _try_ingest_maf(study_dir, staging, schema_name=schema_name) + _try_ingest_fixed_files(study_dir, staging, schema_name=schema_name) + _try_ingest_seg_files(study_dir, staging, schema_name=schema_name) _ingest_prefix_matched_files( study_dir, staging, prefix="data_resource_", parser=parse_resource_file, + schema_name=schema_name, ) _ingest_prefix_matched_files( study_dir, staging, prefix="data_clinical_supp_", parser=parse_clinical_file, uses_clinical_comments=True, + schema_name=schema_name, ) - _ingest_timelines(study_dir, staging) + _ingest_timelines(study_dir, staging, schema_name=schema_name) logger.info("Finished ingesting cBioPortal study {}", study_id) +def _try_ingest_seg_files( + study_dir: Path, staging: Staging, schema_name: str, +) -> None: + seg_files = sorted(study_dir.glob("*.seg")) + if not seg_files: + return + for seg in seg_files: + rows, column_types, _ = parse_segmented_cna(seg) + staging.write_table( + schema=schema_name, + table="cna_segmented", + rows=rows, + column_types=column_types, + column_comments={}, + table_comment=f"cBioPortal segmented CNA from {seg.name}", + ) + + def _try_ingest_clinical( - study_dir: Path, staging: Staging, 
filename: str, table: str + study_dir: Path, staging: Staging, filename: str, table: str, + schema_name: str, ) -> None: path = study_dir / filename if not path.exists(): @@ -276,7 +311,7 @@ def _try_ingest_clinical( return rows, column_types, column_comments = parse_clinical_file(path) staging.write_table( - schema="cbioportal", + schema=schema_name, table=table, rows=rows, column_types=column_types, @@ -302,7 +337,9 @@ def _try_ingest_clinical( } -def _try_ingest_fixed_files(study_dir: Path, staging: Staging) -> None: +def _try_ingest_fixed_files( + study_dir: Path, staging: Staging, schema_name: str, +) -> None: for filename, table_name, comment in _SIMPLE_FIXED_FILES: path = study_dir / filename if not path.exists(): @@ -310,7 +347,7 @@ def _try_ingest_fixed_files(study_dir: Path, staging: Staging) -> None: parser = _PARSERS_BY_FILENAME[filename] rows, column_types, _ = parser(path) staging.write_table( - schema="cbioportal", + schema=schema_name, table=table_name, rows=rows, column_types=column_types, @@ -326,6 +363,7 @@ def _ingest_prefix_matched_files( prefix: str, parser: Any, uses_clinical_comments: bool = False, + schema_name: str, ) -> None: for entry in sorted(study_dir.iterdir()): if not (entry.is_file() and entry.name.startswith(prefix)): @@ -337,7 +375,7 @@ def _ingest_prefix_matched_files( rows, column_types = result[0], result[1] comments = result[2] if uses_clinical_comments else {} staging.write_table( - schema="cbioportal", + schema=schema_name, table=table_name, rows=rows, column_types=column_types, @@ -346,13 +384,15 @@ def _ingest_prefix_matched_files( ) -def _try_ingest_maf(study_dir: Path, staging: Staging) -> None: +def _try_ingest_maf( + study_dir: Path, staging: Staging, schema_name: str, +) -> None: for candidate in ("data_mutations.txt", "data_mutations_extended.txt"): path = study_dir / candidate if path.exists(): rows, column_types, _ = parse_maf(path) staging.write_table( - schema="cbioportal", + schema=schema_name, 
table="mutation", rows=rows, column_types=column_types, @@ -363,20 +403,36 @@ def _try_ingest_maf(study_dir: Path, staging: Staging) -> None: logger.info("cBioPortal MAF missing in {}, skipping", study_dir.name) -def _ingest_timelines(study_dir: Path, staging: Staging) -> None: +def _ingest_timelines( + study_dir: Path, staging: Staging, schema_name: str, +) -> None: timelines = list(iter_timeline_files(study_dir)) if not timelines: logger.info("No cBioPortal timeline files present in {}", study_dir.name) return for kind, path in timelines: - rows, column_types, _ = parse_timeline_file(path) + parser = _select_timeline_parser(path) + rows, column_types, _ = parser(path) + table = timeline_table_name(kind) staging.write_table( - schema="cbioportal", - table=f"timeline_{kind}", + schema=schema_name, + table=table, rows=rows, column_types=column_types, column_comments={}, - table_comment=f"cBioPortal timeline_{kind} from {path.name}", + table_comment=f"cBioPortal {table} from {path.name}", ) +def _select_timeline_parser(path: Path) -> Any: + header_lines = read_header_block(path, max_lines=4) + for line in header_lines: + if line.startswith("#"): + continue + column_names = line.rstrip("\n").split("\t") + if is_lab_timeline_header(column_names): + return parse_lab_timeline + return parse_timeline_file + return parse_timeline_file + + diff --git a/showcase/cbioportal_to_omop/slices/README.md b/showcase/cbioportal_to_omop/slices/README.md new file mode 100644 index 0000000..96a558a --- /dev/null +++ b/showcase/cbioportal_to_omop/slices/README.md @@ -0,0 +1,45 @@ +# cBioPortal Eval Slices + +YAML slice configs that select tables from a study's namespaced staging +schema for the eval runner. 
+ +## Slice files + +| File | Schema | Role | +|---|---|---| +| `dev_slice.yaml` | `cbioportal_gbm_tcga_pan_can_atlas_2018` (post-Phase 2) | Original BRCA dev slice | +| `dev_slice_poc.yaml` | `cbioportal_gbm_tcga_pan_can_atlas_2018` (post-Phase 2) | BRCA POC regression-guard slice — small, cheap, run on every change | +| `holdout.yaml` | `cbioportal_gbm_tcga_pan_can_atlas_2018` (post-Phase 2) | Original BRCA holdout | +| `msk_chord_dev.yaml` | `cbioportal_msk_chord_2024` | MSK CHORD 2024 dev slice (12 tables) — overlaps with few-shots intentionally | +| `msk_chord_holdout.yaml` | `cbioportal_msk_chord_2024` | MSK CHORD 2024 holdout (8 tables) — disjoint from few-shot sources + dev | +| `contamination_map.yaml` | n/a | Lists every table referenced by a few-shot example. Holdouts MUST NOT include any of these. | + +## Contamination policy + +A table that appears in any few-shot prompt example becomes +"contaminated" — the LLM has been shown its expected output, so its +performance on that table is not unbiased. Holdout slices are designed to +measure pipeline generalization and therefore MUST exclude every +contaminated table. + +`scripts/check_slice_contamination.py` enforces this: + +```bash +uv run python scripts/check_slice_contamination.py \ + --holdout showcase/cbioportal_to_omop/slices/msk_chord_holdout.yaml \ + --against showcase/cbioportal_to_omop/slices/contamination_map.yaml \ + --against showcase/cbioportal_to_omop/slices/msk_chord_dev.yaml +``` + +The check runs in CI; PRs that contaminate a holdout fail. + +When adding a new few-shot example to +`src/sema/engine/few_shot_healthcare_stage_{a,b,c}.py`, add the example's +`table_name` to `contamination_map.yaml` if it isn't already there. + +## BRCA POC regression guard + +`dev_slice_poc.yaml` is intentionally retained as a small, cheap slice +that catches regressions on the patterns the original +`source-semantic-hardening` change established.
Run it before/after any +change touching prompts, materializer, or graph loader. diff --git a/showcase/cbioportal_to_omop/slices/contamination_map.yaml b/showcase/cbioportal_to_omop/slices/contamination_map.yaml new file mode 100644 index 0000000..23998e6 --- /dev/null +++ b/showcase/cbioportal_to_omop/slices/contamination_map.yaml @@ -0,0 +1,26 @@ +# Contamination map: tables used as few-shot example sources. +# Holdout slices MUST NOT include any of these tables. +# Updated whenever new tables are referenced from +# src/sema/engine/few_shot_healthcare_stage_{a,b,c}.py. +# +# Sources are referenced by table_name only — schema-agnostic, so adding +# the same table from a different study still trips contamination. + +version: 1 +last_updated: "2026-04-24" + +contaminated_tables: + # Stage A few-shot sources + - patient + - sample + - mutation + - treatment + - structural_variant + - timeline_labtest + - timeline_surgery + - timeline_performance_status + - cna_segmented + + # Stage B / Stage C additional sources (not also in Stage A) + - biomarker + - progression diff --git a/showcase/cbioportal_to_omop/slices/dev_slice.yaml b/showcase/cbioportal_to_omop/slices/dev_slice.yaml index 80024de..32ed3c9 100644 --- a/showcase/cbioportal_to_omop/slices/dev_slice.yaml +++ b/showcase/cbioportal_to_omop/slices/dev_slice.yaml @@ -9,11 +9,13 @@ # # Version log: # v1 (2026-04-13): Initial selection based on known cBioPortal GENIE corpus. +# v2 (2026-04-24): namespaced schema migration — schema renamed from +# `cbioportal` to `cbioportal_gbm_tcga_pan_can_atlas_2018`. 
-version: 1 -created: "2026-04-13" +version: 2 +created: "2026-04-24" catalog: unity -schema: cbioportal +schema: cbioportal_gbm_tcga_pan_can_atlas_2018 tables: - table_name: patient diff --git a/showcase/cbioportal_to_omop/slices/dev_slice_poc.yaml b/showcase/cbioportal_to_omop/slices/dev_slice_poc.yaml index 2442153..643a5a2 100644 --- a/showcase/cbioportal_to_omop/slices/dev_slice_poc.yaml +++ b/showcase/cbioportal_to_omop/slices/dev_slice_poc.yaml @@ -1,5 +1,6 @@ # Dev Slice POC — matches current Databricks ingest from -# gbm_tcga_pan_can_atlas_2018 (12 tables in workspace.cbioportal). +# gbm_tcga_pan_can_atlas_2018 (12 tables in +# workspace.cbioportal_gbm_tcga_pan_can_atlas_2018, post-namespacing). # # Covers 9 of the 13 tables in the original eval/dev_slice.yaml spec # plus 3 extras (timeline_*, clinical_supp_hypoxia, resource_patient @@ -10,11 +11,13 @@ # (patient/sample/mutation/3 timelines) to gbm_tcga_pan_can_atlas_2018 # (adds structural_variant, cna, gene_panel_matrix, # resource_definition, resource_patient, clinical_supp_hypoxia). +# v3 (2026-04-24): namespaced schema migration — schema renamed from +# `cbioportal` to `cbioportal_gbm_tcga_pan_can_atlas_2018`. -version: 2 -created: "2026-04-19" +version: 3 +created: "2026-04-24" catalog: workspace -schema: cbioportal +schema: cbioportal_gbm_tcga_pan_can_atlas_2018 tables: - table_name: patient diff --git a/showcase/cbioportal_to_omop/slices/holdout.yaml b/showcase/cbioportal_to_omop/slices/holdout.yaml index 53cb779..822d417 100644 --- a/showcase/cbioportal_to_omop/slices/holdout.yaml +++ b/showcase/cbioportal_to_omop/slices/holdout.yaml @@ -10,12 +10,14 @@ # # Version log: # v1 (2026-04-13): Initial selection. Frozen before Step 3 (domain-aware prompts). +# v2 (2026-04-24): namespaced schema migration — schema renamed from +# `cbioportal` to `cbioportal_gbm_tcga_pan_can_atlas_2018`. Table list unchanged. 
-version: 1 -created: "2026-04-13" +version: 2 +created: "2026-04-24" frozen_before_step: 3 catalog: unity -schema: cbioportal +schema: cbioportal_gbm_tcga_pan_can_atlas_2018 tables: - table_name: resource_patient diff --git a/showcase/cbioportal_to_omop/slices/msk_chord_dev.yaml b/showcase/cbioportal_to_omop/slices/msk_chord_dev.yaml new file mode 100644 index 0000000..b876167 --- /dev/null +++ b/showcase/cbioportal_to_omop/slices/msk_chord_dev.yaml @@ -0,0 +1,75 @@ +# MSK CHORD 2024 dev slice — 12 tables for iteration / regression. +# +# Schema reference: cbioportal_msk_chord_2024 (the namespaced staging schema +# created by `sema ingest cbioportal msk_chord_2024`). +# +# This slice INCLUDES tables that appear in healthcare few-shot sources; +# that is intentional — dev runs verify the pipeline reproduces the +# patterns it has been taught. Holdout slice (msk_chord_holdout.yaml) is +# disjoint from few-shot sources for unbiased eval. + +version: 1 +created: "2026-04-24" +catalog: workspace +schema: cbioportal_msk_chord_2024 + +tables: + - table_name: patient + reason: "MSK CHORD demographics; baseline cross-study collapse target" + failure_mode: encoded_categorical + tier: sanity + + - table_name: sample + reason: "Specimen-level biomarkers, MSI / TMB / cancer types" + failure_mode: mixed_semantic_types + tier: sanity + + - table_name: mutation + reason: "Wide MAF table; multi-batch Stage B; large row count" + failure_mode: wide_table_batching + tier: stress + + - table_name: structural_variant + reason: "SV/fusion calls; site1/site2 column structure" + failure_mode: abbreviation_ambiguity + tier: standard + + - table_name: cna_segmented + reason: "Genomic-segment-level CNA from .seg file (new format)" + failure_mode: mixed_semantic_types + tier: standard + + - table_name: gene_panel_matrix + reason: "Long-format sample×assay panel assignments (new shape)" + failure_mode: identifier_only + tier: standard + + - table_name: timeline_cea_labs + reason: "CEA lab 
measurements with TEST/VALUE/UNITS — new few-shot family" + failure_mode: mixed_semantic_types + tier: standard + + - table_name: timeline_surgery + reason: "Procedure events with PROCEDURE_CODE column" + failure_mode: encoded_categorical + tier: standard + + - table_name: timeline_performance_status + reason: "ECOG / Karnofsky scores; ordinal decoding" + failure_mode: encoded_categorical + tier: standard + + - table_name: timeline_treatment + reason: "Treatment events with agent / regimen vocabulary" + failure_mode: encoded_categorical + tier: standard + + - table_name: timeline_specimen + reason: "Specimen collection timeline — encoded dates, sample-of-collection" + failure_mode: encoded_categorical + tier: standard + + - table_name: timeline_progression + reason: "Disease progression events — date encoding, status flags" + failure_mode: encoded_categorical + tier: standard diff --git a/showcase/cbioportal_to_omop/slices/msk_chord_holdout.yaml b/showcase/cbioportal_to_omop/slices/msk_chord_holdout.yaml new file mode 100644 index 0000000..219a5ec --- /dev/null +++ b/showcase/cbioportal_to_omop/slices/msk_chord_holdout.yaml @@ -0,0 +1,56 @@ +# MSK CHORD 2024 holdout slice — 8 tables disjoint from dev + few-shot sources. +# +# Holdout reserves tables NOT used as few-shot sources (per +# contamination_map.yaml) AND NOT in the dev slice. Used for unbiased +# evaluation of the pipeline's ability to generalize. +# +# Adjust this list once MSK CHORD has actually been ingested and the real +# table inventory is known. The names below are speculative based on +# typical cBioPortal study layouts — names that don't materialize in the +# DuckDB schema after ingest will simply produce no rows in the slice run.
+ +version: 1 +created: "2026-04-24" +catalog: workspace +schema: cbioportal_msk_chord_2024 + +tables: + - table_name: timeline_radiation_therapy + reason: "Radiation timeline; not used as few-shot source" + failure_mode: encoded_categorical + tier: standard + + - table_name: timeline_specimen_collection + reason: "Specimen sampling timeline; not seeded by few-shots" + failure_mode: encoded_categorical + tier: standard + + - table_name: timeline_sequencing + reason: "Sequencing event timeline; new shape for the eval" + failure_mode: encoded_categorical + tier: standard + + - table_name: timeline_diagnosis + reason: "Diagnosis timeline events" + failure_mode: mixed_semantic_types + tier: standard + + - table_name: timeline_brain_imaging + reason: "Sub-modality imaging timeline; modality-specific" + failure_mode: encoded_categorical + tier: standard + + - table_name: resource_definition + reason: "Resource metadata; non-clinical entity" + failure_mode: non_clinical_entity + tier: edge + + - table_name: resource_patient + reason: "Resource-to-patient bridge" + failure_mode: identifier_only + tier: edge + + - table_name: clinical_supp_panel + reason: "Supplemental clinical attributes (panel-specific)" + failure_mode: mixed_semantic_types + tier: standard diff --git a/src/sema/cli.py b/src/sema/cli.py index f415ee6..56175ff 100644 --- a/src/sema/cli.py +++ b/src/sema/cli.py @@ -21,6 +21,7 @@ ) from sema.cli_ingest import ingest as _ingest_group, push_cmd as _push_cmd from sema.cli_eval import eval_group as _eval_group +from sema.cli_utils import build_config_from_args as _build_config_from_args @click.group() @@ -43,74 +44,6 @@ def emit(self, record: logging.LogRecord) -> None: logging.getLogger("databricks.sql").setLevel(logging.WARNING) -def _build_config_from_args( - *, - source: str | None, - catalog: str | None, - schemas: str | None, - table_pattern: str | None, - domain: str | None, - table_workers: int | None, - neo4j_uri: str | None, - neo4j_user: str | None, - 
neo4j_password: str | None, - llm_provider: str | None, - llm_model: str | None, - llm_timeout: int | None, - config_file: str | None, - skip_embeddings: bool, - resume: bool, - verbose: bool, -) -> BuildConfig: - """Assemble a :class:`BuildConfig` from CLI arguments.""" - overrides: dict[str, Any] = {} - if source is not None: - overrides["source"] = source - if catalog is not None: - overrides["catalog"] = catalog - if schemas is not None: - overrides["schemas"] = [s.strip() for s in schemas.split(",")] - if table_pattern is not None: - overrides["table_pattern"] = table_pattern - if domain is not None: - overrides["domain"] = domain - overrides["domain_from_cli"] = True - if table_workers is not None: - overrides["table_workers"] = table_workers - if skip_embeddings: - overrides["skip_embeddings"] = True - if resume: - overrides["resume"] = True - if verbose: - overrides["verbose"] = True - - if config_file: - build_config = BuildConfig.from_file(config_file, overrides=overrides) - else: - build_config = BuildConfig(**overrides) - - if neo4j_uri is not None: - build_config.neo4j = Neo4jConfig( - uri=neo4j_uri, - user=neo4j_user or build_config.neo4j.user, - password=neo4j_password or build_config.neo4j.password, # type: ignore[arg-type] - ) - - llm_overrides: dict[str, Any] = {} - if llm_provider is not None: - llm_overrides["provider"] = llm_provider - if llm_model is not None: - llm_overrides["model"] = llm_model - if llm_timeout is not None: - llm_overrides["request_timeout"] = llm_timeout - if llm_overrides: - build_config.llm = build_config.llm.model_copy( - update=llm_overrides, - ) - - return build_config - - @cli.command() @click.option("--source", default="databricks", help="Data source connector type") @click.option("--catalog", default=None, help="Catalog name to extract from") @@ -127,6 +60,15 @@ def _build_config_from_args( @click.option("--config", "config_file", default=None, help="Path to config YAML file") @click.option("--skip-embeddings", 
is_flag=True, default=False, help="Create indexes only, skip embedding computation") @click.option("--resume", is_flag=True, default=False, help="Skip tables that already have assertions in the graph") +@click.option( + "--enable-fk-detection/--no-enable-fk-detection", + "enable_fk_detection", default=True, + help="Run the FK detector after extraction (default: on)", +) +@click.option( + "--materialize-structural-fk", is_flag=True, default=False, + help="Lower materialization threshold to 0.70 so Tier-3 (name+type only) FK candidates promote to JoinPath nodes", +) @click.option("--verbose", is_flag=True, default=False, help="Enable verbose output") def build( source: str | None, @@ -144,6 +86,8 @@ def build( config_file: str | None, skip_embeddings: bool, resume: bool, + enable_fk_detection: bool, + materialize_structural_fk: bool, verbose: bool, ) -> None: """Build the knowledge graph from a data source.""" @@ -155,7 +99,10 @@ def build( neo4j_password=neo4j_password, llm_provider=llm_provider, llm_model=llm_model, llm_timeout=llm_timeout, config_file=config_file, - skip_embeddings=skip_embeddings, resume=resume, verbose=verbose, + skip_embeddings=skip_embeddings, resume=resume, + enable_fk_detection=enable_fk_detection, + materialize_structural_fk=materialize_structural_fk, + verbose=verbose, ) try: report = run_build(build_config) diff --git a/src/sema/cli_ingest.py b/src/sema/cli_ingest.py index ec1076b..a5c3ad3 100644 --- a/src/sema/cli_ingest.py +++ b/src/sema/cli_ingest.py @@ -1,12 +1,27 @@ from __future__ import annotations +import json from pathlib import Path +from typing import Any, Callable, Sequence import click +from sema.ingest.comment_recovery import ( + LiveTableComments, + ParsedTableComments, + PartialOverrideError, + RecoveryReport, + StudyCacheMissingError, + StudyNotRegisteredError, + build_recovery_plan, + execute_recovery_plan, + read_databricks_comments, + resolve_recovery_context, +) from sema.ingest.databricks_push import Bridge, 
PushError from sema.ingest.duckdb_staging import Staging from sema.ingest.omop import ingest_cdm_schema, ingest_vocabulary +from sema.ingest.study_registry import StudyRegistry from sema.log import logger from sema.models.config import IngestConfig @@ -99,6 +114,178 @@ def ingest_omop_cmd( staging.close() +def _extract_study_comments_lazy( + study_dir: Path, +) -> dict[str, ParsedTableComments]: + try: + from showcase.cbioportal_to_omop.comment_extract import ( + extract_study_comments, + ) + except ImportError as err: + raise click.ClickException( + "The cBioPortal showcase is not importable. Run from a source " + "checkout where the 'showcase/' directory is on sys.path." + ) from err + return extract_study_comments(study_dir) + + +def _open_recovery_executor( + config: IngestConfig, +) -> tuple[Callable[[str], None], Callable[[str, list[str]], Sequence[Sequence[Any]]]]: + creds = config.databricks_creds + from databricks import sql as databricks_sql + conn = databricks_sql.connect( + server_hostname=creds.host.replace("https://", ""), + http_path=creds.http_path, + access_token=creds.token.get_secret_value(), + ) + + def execute(sql: str) -> None: + cursor = conn.cursor() + try: + cursor.execute(sql) + finally: + cursor.close() + + def query(sql: str, params: list[str]) -> Sequence[Sequence[Any]]: + cursor = conn.cursor() + try: + cursor.execute(sql, params) + return cursor.fetchall() + finally: + cursor.close() + + return execute, query + + +def _emit_summary( + study_id: str | None, + target_catalog: str, + target_schema: str, + report: RecoveryReport, + as_json: bool, +) -> None: + if as_json: + payload = { + "study_id": study_id, + "target_catalog": target_catalog, + "target_schema": target_schema, + "tables_visited": _count_tables_visited(report), + "columns_updated": report.columns_updated, + "columns_skipped": report.columns_skipped, + "columns_failed": report.columns_failed, + "table_comments_updated": report.table_comments_updated, + } + 
click.echo(json.dumps(payload)) + return + click.echo("\nRecovery Report") + click.echo("=" * 40) + click.echo(f" Study: {study_id or ''}") + click.echo(f" Target: {target_catalog}.{target_schema}") + click.echo(f" Columns updated: {report.columns_updated}") + click.echo(f" Table comments updated: {report.table_comments_updated}") + click.echo(f" Columns skipped: {report.columns_skipped}") + click.echo(f" Columns failed: {report.columns_failed}") + if report.failed: + click.echo(" Failures:") + for f in report.failed: + click.echo(f" {f.table}.{f.column}: {f.error}") + + +def _count_tables_visited(report: RecoveryReport) -> int: + tables: set[str] = set() + for s in report.skipped: + tables.add(s.table) + for f in report.failed: + tables.add(f.table) + return len(tables) + + +@ingest.command("recover-comments") +@click.option("--study", "study_id", default=None, help="cBioPortal study_id.") +@click.option("--source-cache", "source_cache", default=None, + help="Override path to the local cBioPortal cache.") +@click.option("--target-catalog", "target_catalog", default=None, + help="Override Databricks catalog.") +@click.option("--target-schema", "target_schema", default=None, + help="Override Databricks schema.") +@click.option("--cache-dir", "cache_dir", default=None, + help="Override IngestConfig.cache_dir for this run.") +@click.option("--duckdb-path", "duckdb_path", default=None, + help="Override DuckDB staging file path.") +@click.option("--dry-run", is_flag=True, default=False, + help="Print SQL without executing.") +@click.option("--force", is_flag=True, default=False, + help="Overwrite existing comments.") +@click.option("--json", "as_json", is_flag=True, default=False, + help="Emit a JSON summary on stdout.") +def recover_comments_cmd( + study_id: str | None, + source_cache: str | None, + target_catalog: str | None, + target_schema: str | None, + cache_dir: str | None, + duckdb_path: str | None, + dry_run: bool, + force: bool, + as_json: bool, +) -> None: + 
"""Re-apply parser-extracted column and table comments to Databricks.""" + config = _load_ingest_config(duckdb_path) + if cache_dir: + config.cache_dir = cache_dir + staging = Staging(config.duckdb_path) + try: + registry = StudyRegistry(staging) + try: + ctx = resolve_recovery_context( + study_id=study_id, registry=registry, ingest_config=config, + source_cache_override=Path(source_cache) if source_cache else None, + target_catalog_override=target_catalog, + target_schema_override=target_schema, + ) + except StudyNotRegisteredError as err: + raise click.ClickException(str(err)) from err + except StudyCacheMissingError as err: + raise click.ClickException(str(err)) from err + except PartialOverrideError as err: + raise click.UsageError(str(err)) from err + + parsed = _extract_study_comments_lazy(ctx.source_cache) + if dry_run: + executor: Callable[[str], None] = lambda _sql: None + query_fn: Callable[..., Sequence[Sequence[Any]]] = ( + lambda *_a, **_k: [] + ) + try: + executor, query_fn = _open_recovery_executor(config) + except Exception as exc: # noqa: BLE001 + logger.warning( + "Skipping live read in --dry-run (no Databricks): {}", exc + ) + live: dict[str, LiveTableComments] = {} + else: + live = read_databricks_comments( + ctx.target_catalog, ctx.target_schema, query_fn, + ) + else: + executor, query_fn = _open_recovery_executor(config) + live = read_databricks_comments( + ctx.target_catalog, ctx.target_schema, query_fn, + ) + plan = build_recovery_plan(ctx, parsed, live, force=force) + report = execute_recovery_plan(plan, executor, dry_run=dry_run) + _emit_summary( + ctx.study_id, ctx.target_catalog, ctx.target_schema, report, as_json, + ) + if report.columns_failed > 0: + raise click.ClickException( + f"{report.columns_failed} column(s) failed during recovery." 
+ ) + finally: + staging.close() + + @click.command("push") @click.option("--target", default="databricks", help="Push target (databricks only).") @click.option( diff --git a/src/sema/cli_utils.py b/src/sema/cli_utils.py new file mode 100644 index 0000000..db01f02 --- /dev/null +++ b/src/sema/cli_utils.py @@ -0,0 +1,122 @@ +"""Helpers for CLI command implementations. + +Extracted from ``cli.py`` to keep that module under the size budget. +""" +from __future__ import annotations + +from typing import Any + +from sema.models.config import BuildConfig, Neo4jConfig + + +def build_config_from_args( + *, + source: str | None, + catalog: str | None, + schemas: str | None, + table_pattern: str | None, + domain: str | None, + table_workers: int | None, + neo4j_uri: str | None, + neo4j_user: str | None, + neo4j_password: str | None, + llm_provider: str | None, + llm_model: str | None, + llm_timeout: int | None, + config_file: str | None, + skip_embeddings: bool, + resume: bool, + verbose: bool, + enable_fk_detection: bool = True, + materialize_structural_fk: bool = False, +) -> BuildConfig: + """Assemble a :class:`BuildConfig` from CLI arguments.""" + overrides = _gather_overrides( + source=source, catalog=catalog, schemas=schemas, + table_pattern=table_pattern, domain=domain, + table_workers=table_workers, + skip_embeddings=skip_embeddings, resume=resume, verbose=verbose, + enable_fk_detection=enable_fk_detection, + materialize_structural_fk=materialize_structural_fk, + ) + + if config_file: + build_config = BuildConfig.from_file( + config_file, overrides=overrides, + ) + else: + build_config = BuildConfig(**overrides) + + if neo4j_uri is not None: + build_config.neo4j = Neo4jConfig( + uri=neo4j_uri, + user=neo4j_user or build_config.neo4j.user, + password=neo4j_password or build_config.neo4j.password, # type: ignore[arg-type] + ) + + _apply_llm_overrides( + build_config, + llm_provider=llm_provider, + llm_model=llm_model, + llm_timeout=llm_timeout, + ) + return 
build_config + + +def _gather_overrides( + *, + source: str | None, + catalog: str | None, + schemas: str | None, + table_pattern: str | None, + domain: str | None, + table_workers: int | None, + skip_embeddings: bool, + resume: bool, + verbose: bool, + enable_fk_detection: bool, + materialize_structural_fk: bool, +) -> dict[str, Any]: + overrides: dict[str, Any] = {} + if source is not None: + overrides["source"] = source + if catalog is not None: + overrides["catalog"] = catalog + if schemas is not None: + overrides["schemas"] = [s.strip() for s in schemas.split(",")] + if table_pattern is not None: + overrides["table_pattern"] = table_pattern + if domain is not None: + overrides["domain"] = domain + overrides["domain_from_cli"] = True + if table_workers is not None: + overrides["table_workers"] = table_workers + if skip_embeddings: + overrides["skip_embeddings"] = True + if resume: + overrides["resume"] = True + if verbose: + overrides["verbose"] = True + overrides["enable_fk_detection"] = enable_fk_detection + overrides["materialize_structural_fk"] = materialize_structural_fk + return overrides + + +def _apply_llm_overrides( + build_config: BuildConfig, + *, + llm_provider: str | None, + llm_model: str | None, + llm_timeout: int | None, +) -> None: + llm_overrides: dict[str, Any] = {} + if llm_provider is not None: + llm_overrides["provider"] = llm_provider + if llm_model is not None: + llm_overrides["model"] = llm_model + if llm_timeout is not None: + llm_overrides["request_timeout"] = llm_timeout + if llm_overrides: + build_config.llm = build_config.llm.model_copy( + update=llm_overrides, + ) diff --git a/src/sema/engine/few_shot_healthcare.py b/src/sema/engine/few_shot_healthcare.py index 994a74f..9419209 100644 --- a/src/sema/engine/few_shot_healthcare.py +++ b/src/sema/engine/few_shot_healthcare.py @@ -1,476 +1,19 @@ """Healthcare few-shot examples for staged L2 prompts. 
-Examples sourced from timbr-poc cBioPortal analyst questions; cover the +Examples sourced from cBioPortal analyst questions; cover the clinical-genomics shapes (patient, sample, mutation, treatment, structural -variant) and the value encodings common to TCGA-style warehouses. +variant) plus MSK CHORD-introduced shapes (lab timelines, procedures, +performance status, biomarker states, segmented CNA). Stage arrays live in +sibling modules to keep each file under the per-file line cap. """ from __future__ import annotations -from typing import Any +from sema.engine.few_shot_healthcare_stage_a import HEALTHCARE_STAGE_A +from sema.engine.few_shot_healthcare_stage_b import HEALTHCARE_STAGE_B +from sema.engine.few_shot_healthcare_stage_c import HEALTHCARE_STAGE_C -HEALTHCARE_STAGE_A: list[dict[str, Any]] = [ - { - "input": { - "table_name": "patient", - "columns": "patient_id (STRING), gender (STRING), " - "current_age (INT), os_status (STRING), os_months (DOUBLE), " - "dfs_status (STRING), dfs_months (DOUBLE), " - "smoking_status (STRING), stage_highest (STRING)", - }, - "output": { - "primary_entity": "Patient", - "grain_hypothesis": "one row per patient", - "secondary_entity_hints": [ - "cancer diagnosis", "survival outcome", - ], - "ambiguity_flags": [], - "confidence": 0.95, - }, - }, - { - "input": { - "table_name": "sample", - "columns": "sample_id (STRING), patient_id (STRING), " - "cancer_type (STRING), cancer_type_detailed (STRING), " - "sample_type (STRING), tmb (DOUBLE), msi_type (STRING), " - "oncotree_code (STRING), sample_class (STRING)", - }, - "output": { - "primary_entity": "Biospecimen/Sample", - "grain_hypothesis": "one row per tumor sample " - "(multiple samples per patient)", - "secondary_entity_hints": ["tumor characterization"], - "ambiguity_flags": [], - "confidence": 0.9, - }, - }, - { - "input": { - "table_name": "mutation", - "columns": "sample_id (STRING), hugo_symbol (STRING), " - "variant_classification (STRING), hgvsp_short (STRING), " - 
"chromosome (STRING), start_position (INT), " - "end_position (INT), reference_allele (STRING), " - "tumor_seq_allele2 (STRING), mutation_status (STRING)", - }, - "output": { - "primary_entity": "Somatic Mutation", - "grain_hypothesis": "one row per variant call per sample", - "secondary_entity_hints": [ - "gene", "protein change", - ], - "ambiguity_flags": [], - "confidence": 0.9, - }, - }, - { - "input": { - "table_name": "treatment", - "columns": "patient_id (STRING), treatment_subtype (STRING), " - "agent (STRING), start_date (INT), stop_date (INT)", - }, - "output": { - "primary_entity": "Treatment Event", - "grain_hypothesis": "one row per treatment event " - "(multiple events per patient)", - "secondary_entity_hints": ["drug/agent", "regimen"], - "ambiguity_flags": [], - "confidence": 0.85, - }, - }, - { - "input": { - "table_name": "structural_variant", - "columns": "sample_id (STRING), site1_gene (STRING), " - "site2_gene (STRING), sv_class (STRING), " - "event_info (STRING), annotation (STRING)", - }, - "output": { - "primary_entity": "Structural Variant", - "grain_hypothesis": "one row per structural variant " - "call per sample", - "secondary_entity_hints": ["fusion partner genes"], - "ambiguity_flags": [], - "confidence": 0.85, - }, - }, -] - -HEALTHCARE_STAGE_B: list[dict[str, Any]] = [ - { - "input": { - "table_name": "patient", - "column": "patient_id", - "data_type": "STRING", - "entity_context": "Patient", - }, - "output": { - "canonical_property_label": "patient identifier", - "semantic_type": "patient identifier", - "synonyms": ["subject id", "case id", "participant id"], - "candidate_vocab_families": [], - "entity_role": "primary_key", - "needs_stage_c": False, - }, - }, - { - "input": { - "table_name": "mutation", - "column": "sample_id", - "data_type": "STRING", - "entity_context": "Somatic Mutation", - }, - "output": { - "canonical_property_label": "sample identifier", - "semantic_type": "specimen/sample identifier", - "synonyms": ["specimen 
id", "biospecimen id", "tumor sample id"], - "candidate_vocab_families": [], - "entity_role": "foreign_key", - "needs_stage_c": False, - }, - }, - { - "input": { - "table_name": "patient", - "column": "gender", - "data_type": "STRING", - "top_values": "Male, Female, Other", - "entity_context": "Patient", - }, - "output": { - "canonical_property_label": "biological sex", - "semantic_type": "demographic", - "synonyms": ["sex", "biological sex"], - "candidate_vocab_families": [], - "entity_role": "attribute", - "needs_stage_c": True, - }, - }, - { - "input": { - "table_name": "treatment", - "column": "start_date", - "data_type": "INT", - "entity_context": "Treatment Event", - }, - "output": { - "canonical_property_label": "treatment start date", - "semantic_type": "temporal field", - "candidate_vocab_families": [ - "days-from-epoch encoding", - ], - "entity_role": "attribute", - "needs_stage_c": True, - }, - }, - { - "input": { - "table_name": "sample", - "column": "cancer_type", - "data_type": "STRING", - "top_values": "Non-Small Cell Lung Cancer, " - "Colorectal Cancer, Breast Cancer", - "entity_context": "Biospecimen/Sample", - }, - "output": { - "canonical_property_label": "cancer type", - "semantic_type": "diagnosis/condition", - "candidate_vocab_families": [ - "cancer classification system", - ], - "entity_role": "attribute", - "needs_stage_c": False, - }, - }, - { - "input": { - "table_name": "sample", - "column": "cancer_type_detailed", - "data_type": "STRING", - "entity_context": "Biospecimen/Sample", - }, - "output": { - "canonical_property_label": "cancer subtype", - "semantic_type": "diagnosis/condition", - "candidate_vocab_families": [ - "cancer subtype classification", - ], - "entity_role": "attribute", - "needs_stage_c": False, - }, - }, - { - "input": { - "table_name": "sample", - "column": "tmb", - "data_type": "DOUBLE", - "entity_context": "Biospecimen/Sample", - }, - "output": { - "canonical_property_label": "tumor mutational burden", - 
"semantic_type": "biomarker/gene/variant", - "synonyms": ["tmb", "mutations per megabase", "mutation burden"], - "candidate_vocab_families": [], - "entity_role": "attribute", - "needs_stage_c": False, - }, - }, - { - "input": { - "table_name": "sample", - "column": "msi_type", - "data_type": "STRING", - "top_values": "Instable, Stable", - "entity_context": "Biospecimen/Sample", - }, - "output": { - "canonical_property_label": "microsatellite instability", - "semantic_type": "biomarker/gene/variant", - "synonyms": ["MSI status", "MSI type", "MSI"], - "candidate_vocab_families": [], - "entity_role": "attribute", - "needs_stage_c": True, - }, - }, - { - "input": { - "table_name": "mutation", - "column": "hugo_symbol", - "data_type": "STRING", - "top_values": "TP53, KRAS, EGFR, PIK3CA", - "entity_context": "Somatic Mutation", - }, - "output": { - "canonical_property_label": "gene symbol", - "semantic_type": "biomarker/gene/variant", - "synonyms": ["gene name", "HGNC symbol", "gene"], - "candidate_vocab_families": [ - "gene symbol namespace", - ], - "entity_role": "attribute", - "needs_stage_c": False, - }, - }, - { - "input": { - "table_name": "mutation", - "column": "variant_classification", - "data_type": "STRING", - "top_values": "Missense_Mutation, Silent, " - "Frame_Shift_Del, Nonsense_Mutation", - "entity_context": "Somatic Mutation", - }, - "output": { - "canonical_property_label": "variant effect", - "semantic_type": "biomarker/gene/variant", - "synonyms": ["mutation type", "variant type", "mutation effect"], - "candidate_vocab_families": [ - "variant effect classification", - ], - "entity_role": "attribute", - "needs_stage_c": False, - }, - }, - { - "input": { - "table_name": "treatment", - "column": "agent", - "data_type": "STRING", - "top_values": "PACLITAXEL, CAPECITABINE, LETROZOLE", - "entity_context": "Treatment Event", - }, - "output": { - "canonical_property_label": "drug/agent name", - "semantic_type": "therapy/drug/regimen", - "synonyms": ["drug", 
"therapeutic agent", "medication"], - "candidate_vocab_families": [ - "drug naming system", - ], - "entity_role": "attribute", - "needs_stage_c": False, - }, - }, - { - "input": { - "table_name": "patient", - "column": "stage_highest", - "data_type": "STRING", - "top_values": "I, II, III, IV, IA, IIB", - "entity_context": "Patient", - }, - "output": { - "canonical_property_label": "highest cancer stage", - "semantic_type": "diagnosis/condition", - "synonyms": ["overall stage", "pathologic stage", "tumor stage"], - "candidate_vocab_families": [ - "cancer staging system", - ], - "entity_role": "attribute", - "needs_stage_c": True, - }, - }, -] - -HEALTHCARE_STAGE_C: list[dict[str, Any]] = [ - { - "input": { - "table_name": "sample", - "column": "msi_type", - "values": ["Instable (45%)", "Stable (55%)"], - }, - "output": { - "decoded_categories": [ - {"raw": "Instable", - "label": "microsatellite instability high (MSI-H)"}, - {"raw": "Stable", - "label": "microsatellite stable (MSS)"}, - ], - "uncertainty": 0.1, - "codebook_lookup_needed": False, - }, - }, - { - "input": { - "table_name": "treatment", - "column": "treatment_subtype", - "values": [ - "Immuno (15%)", "Chemo (40%)", "Hormone (20%)", - "Targeted (15%)", "Other (10%)", - ], - }, - "output": { - "decoded_categories": [ - {"raw": "Immuno", "label": "immunotherapy"}, - {"raw": "Chemo", "label": "chemotherapy"}, - {"raw": "Hormone", "label": "hormonal therapy"}, - {"raw": "Targeted", "label": "targeted therapy"}, - {"raw": "Other", "label": "other therapy type"}, - ], - "uncertainty": 0.05, - "codebook_lookup_needed": False, - }, - }, - { - "input": { - "table_name": "patient", - "column": "os_status", - "values": ["1:DECEASED (40%)", "0:LIVING (60%)"], - }, - "output": { - "decoded_categories": [ - {"raw": "1:DECEASED", "label": "patient died"}, - {"raw": "0:LIVING", "label": "patient alive"}, - ], - "uncertainty": 0.0, - "codebook_lookup_needed": False, - }, - }, - { - "input": { - "table_name": "sample", - 
"column": "sample_type", - "values": [ - "Primary (60%)", "Metastasis (30%)", - "Normal (5%)", "Unknown (5%)", - ], - }, - "output": { - "decoded_categories": [ - {"raw": "Primary", - "label": "primary tumor site"}, - {"raw": "Metastasis", - "label": "metastatic site"}, - {"raw": "Normal", - "label": "normal tissue"}, - {"raw": "Unknown", - "label": "unknown sample origin"}, - ], - "uncertainty": 0.05, - "codebook_lookup_needed": False, - }, - }, - { - "input": { - "table_name": "patient", - "column": "gender", - "values": ["Male (55%)", "Female (43%)", "Other (2%)"], - }, - "output": { - "decoded_categories": [ - {"raw": "Male", "label": "male biological sex"}, - {"raw": "Female", "label": "female biological sex"}, - {"raw": "Other", - "label": "other/unspecified biological sex"}, - ], - "uncertainty": 0.1, - "codebook_lookup_needed": False, - }, - }, - { - "input": { - "table_name": "patient", - "column": "stage_highest", - "values": [ - "IV (25%)", "III (20%)", "II (20%)", - "I (15%)", "IA (8%)", "IIB (7%)", "IIIA (5%)", - ], - }, - "output": { - "decoded_categories": [ - {"raw": "I", "label": "AJCC stage I"}, - {"raw": "IA", "label": "AJCC stage IA"}, - {"raw": "II", "label": "AJCC stage II"}, - {"raw": "IIB", "label": "AJCC stage IIB"}, - {"raw": "III", "label": "AJCC stage III"}, - {"raw": "IIIA", "label": "AJCC stage IIIA"}, - {"raw": "IV", "label": "AJCC stage IV"}, - ], - "uncertainty": 0.15, - "codebook_lookup_needed": False, - }, - }, - { - "input": { - "table_name": "mutation", - "column": "variant_classification", - "values": [ - "Silent (30%)", "Missense_Mutation (45%)", - "Nonsense_Mutation (10%)", "Frame_Shift_Del (8%)", - "Splice_Site (7%)", - ], - }, - "output": { - "decoded_categories": [ - {"raw": "Silent", - "label": "synonymous, no protein change"}, - {"raw": "Missense_Mutation", - "label": "single amino acid change"}, - {"raw": "Nonsense_Mutation", - "label": "premature stop codon"}, - {"raw": "Frame_Shift_Del", - "label": "frameshift 
deletion"}, - {"raw": "Splice_Site", - "label": "splice site disruption"}, - ], - "uncertainty": 0.05, - "codebook_lookup_needed": False, - }, - }, - { - "input": { - "table_name": "progression", - "column": "progression", - "values": ["Y (35%)", "N (65%)"], - }, - "output": { - "decoded_categories": [ - {"raw": "Y", - "label": "disease progressed"}, - {"raw": "N", - "label": "no disease progression"}, - ], - "uncertainty": 0.0, - "codebook_lookup_needed": False, - }, - }, +__all__ = [ + "HEALTHCARE_STAGE_A", + "HEALTHCARE_STAGE_B", + "HEALTHCARE_STAGE_C", ] diff --git a/src/sema/engine/few_shot_healthcare_stage_a.py b/src/sema/engine/few_shot_healthcare_stage_a.py new file mode 100644 index 0000000..328cab8 --- /dev/null +++ b/src/sema/engine/few_shot_healthcare_stage_a.py @@ -0,0 +1,165 @@ +"""Stage A healthcare few-shot examples (table-level entity classification).""" +from __future__ import annotations + +from typing import Any + +HEALTHCARE_STAGE_A: list[dict[str, Any]] = [ + { + "input": { + "table_name": "patient", + "columns": "patient_id (STRING), gender (STRING), " + "current_age (INT), os_status (STRING), os_months (DOUBLE), " + "dfs_status (STRING), dfs_months (DOUBLE), " + "smoking_status (STRING), stage_highest (STRING)", + }, + "output": { + "primary_entity": "Patient", + "grain_hypothesis": "one row per patient", + "secondary_entity_hints": [ + "cancer diagnosis", "survival outcome", + ], + "ambiguity_flags": [], + "confidence": 0.95, + }, + }, + { + "input": { + "table_name": "sample", + "columns": "sample_id (STRING), patient_id (STRING), " + "cancer_type (STRING), cancer_type_detailed (STRING), " + "sample_type (STRING), tmb (DOUBLE), msi_type (STRING), " + "oncotree_code (STRING), sample_class (STRING)", + }, + "output": { + "primary_entity": "Biospecimen/Sample", + "grain_hypothesis": "one row per tumor sample " + "(multiple samples per patient)", + "secondary_entity_hints": ["tumor characterization"], + "ambiguity_flags": [], + "confidence": 
0.9, + }, + }, + { + "input": { + "table_name": "mutation", + "columns": "sample_id (STRING), hugo_symbol (STRING), " + "variant_classification (STRING), hgvsp_short (STRING), " + "chromosome (STRING), start_position (INT), " + "end_position (INT), reference_allele (STRING), " + "tumor_seq_allele2 (STRING), mutation_status (STRING)", + }, + "output": { + "primary_entity": "Somatic Mutation", + "grain_hypothesis": "one row per variant call per sample", + "secondary_entity_hints": [ + "gene", "protein change", + ], + "ambiguity_flags": [], + "confidence": 0.9, + }, + }, + { + "input": { + "table_name": "treatment", + "columns": "patient_id (STRING), treatment_subtype (STRING), " + "agent (STRING), start_date (INT), stop_date (INT)", + }, + "output": { + "primary_entity": "Treatment Event", + "grain_hypothesis": "one row per treatment event " + "(multiple events per patient)", + "secondary_entity_hints": ["drug/agent", "regimen"], + "ambiguity_flags": [], + "confidence": 0.85, + }, + }, + { + "input": { + "table_name": "structural_variant", + "columns": "sample_id (STRING), site1_gene (STRING), " + "site2_gene (STRING), sv_class (STRING), " + "event_info (STRING), annotation (STRING)", + }, + "output": { + "primary_entity": "Structural Variant", + "grain_hypothesis": "one row per structural variant " + "call per sample", + "secondary_entity_hints": ["fusion partner genes"], + "ambiguity_flags": [], + "confidence": 0.85, + }, + }, + { + "input": { + "table_name": "timeline_labtest", + "columns": "PATIENT_ID (STRING), START_DATE (INT), " + "STOP_DATE (INT), EVENT_TYPE (STRING), TEST (STRING), " + "VALUE (DOUBLE), UNITS (STRING)", + }, + "output": { + "primary_entity": "Lab Measurement", + "grain_hypothesis": "one row per lab measurement event " + "(many per patient over time)", + "secondary_entity_hints": [ + "analyte", "unit of measure", "longitudinal observation", + ], + "ambiguity_flags": [], + "confidence": 0.9, + }, + }, + { + "input": { + "table_name": 
"timeline_surgery", + "columns": "PATIENT_ID (STRING), START_DATE (INT), " + "STOP_DATE (INT), EVENT_TYPE (STRING), " + "PROCEDURE (STRING), PROCEDURE_CODE (STRING), " + "SURGERY_DETAILS (STRING)", + }, + "output": { + "primary_entity": "Procedure Event", + "grain_hypothesis": "one row per procedure (surgery, " + "radiation, specimen collection) per patient over time", + "secondary_entity_hints": [ + "procedure code", "clinical event timeline", + ], + "ambiguity_flags": [], + "confidence": 0.85, + }, + }, + { + "input": { + "table_name": "timeline_performance_status", + "columns": "PATIENT_ID (STRING), START_DATE (INT), " + "EVENT_TYPE (STRING), ECOG_SCORE (STRING), " + "KARNOFSKY_SCORE (STRING)", + }, + "output": { + "primary_entity": "Performance Status Assessment", + "grain_hypothesis": "one row per performance status " + "assessment per patient over time", + "secondary_entity_hints": [ + "functional status score", "longitudinal observation", + ], + "ambiguity_flags": [], + "confidence": 0.85, + }, + }, + { + "input": { + "table_name": "cna_segmented", + "columns": "sample_id (STRING), chrom (STRING), " + "loc_start (BIGINT), loc_end (BIGINT), " + "num_mark (BIGINT), seg_mean (DOUBLE)", + }, + "output": { + "primary_entity": "Copy Number Segment", + "grain_hypothesis": "one row per genomic segment per sample " + "(many segments per sample, segment-level CNA calls)", + "secondary_entity_hints": [ + "genomic interval", "copy number signal", + ], + "ambiguity_flags": [], + "confidence": 0.85, + }, + }, +] diff --git a/src/sema/engine/few_shot_healthcare_stage_b.py b/src/sema/engine/few_shot_healthcare_stage_b.py new file mode 100644 index 0000000..fa8d7f6 --- /dev/null +++ b/src/sema/engine/few_shot_healthcare_stage_b.py @@ -0,0 +1,372 @@ +"""Stage B healthcare few-shot examples (column-level semantic typing).""" +from __future__ import annotations + +from typing import Any + +HEALTHCARE_STAGE_B: list[dict[str, Any]] = [ + { + "input": { + "table_name": "patient", 
+ "column": "patient_id", + "data_type": "STRING", + "entity_context": "Patient", + }, + "output": { + "canonical_property_label": "patient identifier", + "semantic_type": "patient identifier", + "synonyms": ["subject id", "case id", "participant id"], + "candidate_vocab_families": [], + "entity_role": "primary_key", + "needs_stage_c": False, + }, + }, + { + "input": { + "table_name": "mutation", + "column": "sample_id", + "data_type": "STRING", + "entity_context": "Somatic Mutation", + }, + "output": { + "canonical_property_label": "sample identifier", + "semantic_type": "specimen/sample identifier", + "synonyms": ["specimen id", "biospecimen id", "tumor sample id"], + "candidate_vocab_families": [], + "entity_role": "foreign_key", + "needs_stage_c": False, + }, + }, + { + "input": { + "table_name": "patient", + "column": "gender", + "data_type": "STRING", + "top_values": "Male, Female, Other", + "entity_context": "Patient", + }, + "output": { + "canonical_property_label": "biological sex", + "semantic_type": "demographic", + "synonyms": ["sex", "biological sex"], + "candidate_vocab_families": [], + "entity_role": "attribute", + "needs_stage_c": True, + }, + }, + { + "input": { + "table_name": "treatment", + "column": "start_date", + "data_type": "INT", + "entity_context": "Treatment Event", + }, + "output": { + "canonical_property_label": "treatment start date", + "semantic_type": "temporal field", + "candidate_vocab_families": [ + "days-from-epoch encoding", + ], + "entity_role": "attribute", + "needs_stage_c": True, + }, + }, + { + "input": { + "table_name": "sample", + "column": "cancer_type", + "data_type": "STRING", + "top_values": "Non-Small Cell Lung Cancer, " + "Colorectal Cancer, Breast Cancer", + "entity_context": "Biospecimen/Sample", + }, + "output": { + "canonical_property_label": "cancer type", + "semantic_type": "diagnosis/condition", + "candidate_vocab_families": [ + "cancer classification system", + ], + "entity_role": "attribute", + 
"needs_stage_c": False, + }, + }, + { + "input": { + "table_name": "sample", + "column": "cancer_type_detailed", + "data_type": "STRING", + "entity_context": "Biospecimen/Sample", + }, + "output": { + "canonical_property_label": "cancer subtype", + "semantic_type": "diagnosis/condition", + "candidate_vocab_families": [ + "cancer subtype classification", + ], + "entity_role": "attribute", + "needs_stage_c": False, + }, + }, + { + "input": { + "table_name": "sample", + "column": "tmb", + "data_type": "DOUBLE", + "entity_context": "Biospecimen/Sample", + }, + "output": { + "canonical_property_label": "tumor mutational burden", + "semantic_type": "biomarker/gene/variant", + "synonyms": ["tmb", "mutations per megabase", "mutation burden"], + "candidate_vocab_families": [], + "entity_role": "attribute", + "needs_stage_c": False, + }, + }, + { + "input": { + "table_name": "sample", + "column": "msi_type", + "data_type": "STRING", + "top_values": "Instable, Stable", + "entity_context": "Biospecimen/Sample", + }, + "output": { + "canonical_property_label": "microsatellite instability", + "semantic_type": "biomarker/gene/variant", + "synonyms": ["MSI status", "MSI type", "MSI"], + "candidate_vocab_families": [], + "entity_role": "attribute", + "needs_stage_c": True, + }, + }, + { + "input": { + "table_name": "mutation", + "column": "hugo_symbol", + "data_type": "STRING", + "top_values": "TP53, KRAS, EGFR, PIK3CA", + "entity_context": "Somatic Mutation", + }, + "output": { + "canonical_property_label": "gene symbol", + "semantic_type": "biomarker/gene/variant", + "synonyms": ["gene name", "HGNC symbol", "gene"], + "candidate_vocab_families": [ + "gene symbol namespace", + ], + "entity_role": "attribute", + "needs_stage_c": False, + }, + }, + { + "input": { + "table_name": "mutation", + "column": "variant_classification", + "data_type": "STRING", + "top_values": "Missense_Mutation, Silent, " + "Frame_Shift_Del, Nonsense_Mutation", + "entity_context": "Somatic Mutation", + }, + 
"output": { + "canonical_property_label": "variant effect", + "semantic_type": "biomarker/gene/variant", + "synonyms": ["mutation type", "variant type", "mutation effect"], + "candidate_vocab_families": [ + "variant effect classification", + ], + "entity_role": "attribute", + "needs_stage_c": False, + }, + }, + { + "input": { + "table_name": "treatment", + "column": "agent", + "data_type": "STRING", + "top_values": "PACLITAXEL, CAPECITABINE, LETROZOLE", + "entity_context": "Treatment Event", + }, + "output": { + "canonical_property_label": "drug/agent name", + "semantic_type": "therapy/drug/regimen", + "synonyms": ["drug", "therapeutic agent", "medication"], + "candidate_vocab_families": [ + "drug naming system", + ], + "entity_role": "attribute", + "needs_stage_c": False, + }, + }, + { + "input": { + "table_name": "patient", + "column": "stage_highest", + "data_type": "STRING", + "top_values": "I, II, III, IV, IA, IIB", + "entity_context": "Patient", + }, + "output": { + "canonical_property_label": "highest cancer stage", + "semantic_type": "diagnosis/condition", + "synonyms": ["overall stage", "pathologic stage", "tumor stage"], + "candidate_vocab_families": [ + "cancer staging system", + ], + "entity_role": "attribute", + "needs_stage_c": True, + }, + }, + { + "input": { + "table_name": "timeline_labtest", + "column": "VALUE", + "data_type": "DOUBLE", + "top_values": "13.5, 8.2, 1.1, 110.0", + "entity_context": "Lab Measurement", + }, + "output": { + "canonical_property_label": "lab value", + "semantic_type": "measurement", + "synonyms": ["result value", "test result", "measurement value"], + "candidate_vocab_families": [], + "entity_role": "attribute", + "needs_stage_c": False, + }, + }, + { + "input": { + "table_name": "timeline_labtest", + "column": "UNITS", + "data_type": "STRING", + "top_values": "g/dL, mg/dL, mmol/L, %", + "entity_context": "Lab Measurement", + }, + "output": { + "canonical_property_label": "unit of measure", + "semantic_type": "unit of 
measure", + "synonyms": ["units", "uom", "measurement unit"], + "candidate_vocab_families": ["unit of measure system"], + "entity_role": "attribute", + "needs_stage_c": False, + }, + }, + { + "input": { + "table_name": "timeline_labtest", + "column": "TEST", + "data_type": "STRING", + "top_values": "Hemoglobin, Creatinine, A1c, Sodium", + "entity_context": "Lab Measurement", + }, + "output": { + "canonical_property_label": "lab test name", + "semantic_type": "biomarker/gene/variant", + "synonyms": ["analyte", "lab test", "assay name"], + "candidate_vocab_families": ["clinical observation naming system"], + "entity_role": "attribute", + "needs_stage_c": False, + }, + }, + { + "input": { + "table_name": "biomarker", + "column": "pd_l1_status", + "data_type": "STRING", + "top_values": "POSITIVE, NEGATIVE, EQUIVOCAL", + "entity_context": "Biomarker State", + }, + "output": { + "canonical_property_label": "PD-L1 expression class", + "semantic_type": "biomarker/gene/variant", + "synonyms": ["pd-l1 expression", "pdl1 status", "pd-l1 class"], + "candidate_vocab_families": [], + "entity_role": "attribute", + "needs_stage_c": True, + }, + }, + { + "input": { + "table_name": "biomarker", + "column": "mmr_status", + "data_type": "STRING", + "top_values": "Proficient, Deficient", + "entity_context": "Biomarker State", + }, + "output": { + "canonical_property_label": "mismatch repair status", + "semantic_type": "biomarker/gene/variant", + "synonyms": ["mmr proficiency", "mismatch repair", "mmr"], + "candidate_vocab_families": [], + "entity_role": "attribute", + "needs_stage_c": True, + }, + }, + { + "input": { + "table_name": "biomarker", + "column": "gleason_score", + "data_type": "STRING", + "top_values": "6, 7, 8, 9, 10", + "entity_context": "Biomarker State", + }, + "output": { + "canonical_property_label": "gleason grade", + "semantic_type": "biomarker/gene/variant", + "synonyms": ["gleason sum", "gleason grade", "prostate grade"], + "candidate_vocab_families": [], + 
"entity_role": "attribute", + "needs_stage_c": True, + }, + }, + { + "input": { + "table_name": "timeline_performance_status", + "column": "ECOG_SCORE", + "data_type": "STRING", + "top_values": "0, 1, 2, 3, 4", + "entity_context": "Performance Status Assessment", + }, + "output": { + "canonical_property_label": "ECOG performance score", + "semantic_type": "clinical assessment score", + "synonyms": ["ECOG", "ECOG PS", "performance status"], + "candidate_vocab_families": [], + "entity_role": "attribute", + "needs_stage_c": True, + }, + }, + { + "input": { + "table_name": "timeline_performance_status", + "column": "KARNOFSKY_SCORE", + "data_type": "STRING", + "top_values": "100, 90, 80, 70, 60", + "entity_context": "Performance Status Assessment", + }, + "output": { + "canonical_property_label": "Karnofsky performance score", + "semantic_type": "clinical assessment score", + "synonyms": ["KPS", "karnofsky index", "performance index"], + "candidate_vocab_families": [], + "entity_role": "attribute", + "needs_stage_c": True, + }, + }, + { + "input": { + "table_name": "timeline_surgery", + "column": "PROCEDURE_CODE", + "data_type": "STRING", + "top_values": "0DTJ0ZZ, 0DBJ0ZZ, 0FTG0ZZ", + "entity_context": "Procedure Event", + }, + "output": { + "canonical_property_label": "procedure code", + "semantic_type": "procedure code", + "synonyms": ["surgery code", "intervention code", "procedure code"], + "candidate_vocab_families": ["procedure code system"], + "entity_role": "attribute", + "needs_stage_c": False, + }, + }, +] diff --git a/src/sema/engine/few_shot_healthcare_stage_c.py b/src/sema/engine/few_shot_healthcare_stage_c.py new file mode 100644 index 0000000..9da7d3b --- /dev/null +++ b/src/sema/engine/few_shot_healthcare_stage_c.py @@ -0,0 +1,268 @@ +"""Stage C healthcare few-shot examples (value decoding).""" +from __future__ import annotations + +from typing import Any + +HEALTHCARE_STAGE_C: list[dict[str, Any]] = [ + { + "input": { + "table_name": "sample", + 
"column": "msi_type", + "values": ["Instable (45%)", "Stable (55%)"], + }, + "output": { + "decoded_categories": [ + {"raw": "Instable", + "label": "microsatellite instability high (MSI-H)"}, + {"raw": "Stable", + "label": "microsatellite stable (MSS)"}, + ], + "uncertainty": 0.1, + "codebook_lookup_needed": False, + }, + }, + { + "input": { + "table_name": "treatment", + "column": "treatment_subtype", + "values": [ + "Immuno (15%)", "Chemo (40%)", "Hormone (20%)", + "Targeted (15%)", "Other (10%)", + ], + }, + "output": { + "decoded_categories": [ + {"raw": "Immuno", "label": "immunotherapy"}, + {"raw": "Chemo", "label": "chemotherapy"}, + {"raw": "Hormone", "label": "hormonal therapy"}, + {"raw": "Targeted", "label": "targeted therapy"}, + {"raw": "Other", "label": "other therapy type"}, + ], + "uncertainty": 0.05, + "codebook_lookup_needed": False, + }, + }, + { + "input": { + "table_name": "patient", + "column": "os_status", + "values": ["1:DECEASED (40%)", "0:LIVING (60%)"], + }, + "output": { + "decoded_categories": [ + {"raw": "1:DECEASED", "label": "patient died"}, + {"raw": "0:LIVING", "label": "patient alive"}, + ], + "uncertainty": 0.0, + "codebook_lookup_needed": False, + }, + }, + { + "input": { + "table_name": "sample", + "column": "sample_type", + "values": [ + "Primary (60%)", "Metastasis (30%)", + "Normal (5%)", "Unknown (5%)", + ], + }, + "output": { + "decoded_categories": [ + {"raw": "Primary", + "label": "primary tumor site"}, + {"raw": "Metastasis", + "label": "metastatic site"}, + {"raw": "Normal", + "label": "normal tissue"}, + {"raw": "Unknown", + "label": "unknown sample origin"}, + ], + "uncertainty": 0.05, + "codebook_lookup_needed": False, + }, + }, + { + "input": { + "table_name": "patient", + "column": "gender", + "values": ["Male (55%)", "Female (43%)", "Other (2%)"], + }, + "output": { + "decoded_categories": [ + {"raw": "Male", "label": "male biological sex"}, + {"raw": "Female", "label": "female biological sex"}, + {"raw": 
"Other", + "label": "other/unspecified biological sex"}, + ], + "uncertainty": 0.1, + "codebook_lookup_needed": False, + }, + }, + { + "input": { + "table_name": "patient", + "column": "stage_highest", + "values": [ + "IV (25%)", "III (20%)", "II (20%)", + "I (15%)", "IA (8%)", "IIB (7%)", "IIIA (5%)", + ], + }, + "output": { + "decoded_categories": [ + {"raw": "I", "label": "AJCC stage I"}, + {"raw": "IA", "label": "AJCC stage IA"}, + {"raw": "II", "label": "AJCC stage II"}, + {"raw": "IIB", "label": "AJCC stage IIB"}, + {"raw": "III", "label": "AJCC stage III"}, + {"raw": "IIIA", "label": "AJCC stage IIIA"}, + {"raw": "IV", "label": "AJCC stage IV"}, + ], + "uncertainty": 0.15, + "codebook_lookup_needed": False, + }, + }, + { + "input": { + "table_name": "mutation", + "column": "variant_classification", + "values": [ + "Silent (30%)", "Missense_Mutation (45%)", + "Nonsense_Mutation (10%)", "Frame_Shift_Del (8%)", + "Splice_Site (7%)", + ], + }, + "output": { + "decoded_categories": [ + {"raw": "Silent", + "label": "synonymous, no protein change"}, + {"raw": "Missense_Mutation", + "label": "single amino acid change"}, + {"raw": "Nonsense_Mutation", + "label": "premature stop codon"}, + {"raw": "Frame_Shift_Del", + "label": "frameshift deletion"}, + {"raw": "Splice_Site", + "label": "splice site disruption"}, + ], + "uncertainty": 0.05, + "codebook_lookup_needed": False, + }, + }, + { + "input": { + "table_name": "progression", + "column": "progression", + "values": ["Y (35%)", "N (65%)"], + }, + "output": { + "decoded_categories": [ + {"raw": "Y", + "label": "disease progressed"}, + {"raw": "N", + "label": "no disease progression"}, + ], + "uncertainty": 0.0, + "codebook_lookup_needed": False, + }, + }, + { + "input": { + "table_name": "biomarker", + "column": "pd_l1_status", + "values": ["POSITIVE (45%)", "NEGATIVE (50%)", "EQUIVOCAL (5%)"], + }, + "output": { + "decoded_categories": [ + {"raw": "POSITIVE", + "label": "PD-L1 positive expression (TPS or CPS above 
threshold)"}, + {"raw": "NEGATIVE", + "label": "PD-L1 negative (below assay threshold)"}, + {"raw": "EQUIVOCAL", + "label": "PD-L1 equivocal/indeterminate result"}, + ], + "uncertainty": 0.1, + "codebook_lookup_needed": False, + }, + }, + { + "input": { + "table_name": "biomarker", + "column": "mmr_status", + "values": ["Proficient (80%)", "Deficient (20%)"], + }, + "output": { + "decoded_categories": [ + {"raw": "Proficient", + "label": "mismatch repair proficient (pMMR)"}, + {"raw": "Deficient", + "label": "mismatch repair deficient (dMMR / MSI-H surrogate)"}, + ], + "uncertainty": 0.05, + "codebook_lookup_needed": False, + }, + }, + { + "input": { + "table_name": "biomarker", + "column": "gleason_score", + "values": ["6 (10%)", "7 (40%)", "8 (25%)", "9 (15%)", "10 (10%)"], + }, + "output": { + "decoded_categories": [ + {"raw": "6", "label": "Gleason 6 (Grade Group 1, low-grade)"}, + {"raw": "7", "label": "Gleason 7 (Grade Group 2 or 3, intermediate)"}, + {"raw": "8", "label": "Gleason 8 (Grade Group 4, high-grade)"}, + {"raw": "9", "label": "Gleason 9 (Grade Group 5, very high-grade)"}, + {"raw": "10", "label": "Gleason 10 (Grade Group 5, very high-grade)"}, + ], + "uncertainty": 0.1, + "codebook_lookup_needed": False, + }, + }, + { + "input": { + "table_name": "timeline_performance_status", + "column": "ECOG_SCORE", + "values": ["0 (40%)", "1 (35%)", "2 (15%)", "3 (8%)", "4 (2%)"], + }, + "output": { + "decoded_categories": [ + {"raw": "0", "label": "ECOG 0 — fully active, no restriction"}, + {"raw": "1", "label": "ECOG 1 — restricted in strenuous activity"}, + {"raw": "2", "label": "ECOG 2 — ambulatory, self-care, <50% in bed"}, + {"raw": "3", "label": "ECOG 3 — limited self-care, >50% in bed"}, + {"raw": "4", "label": "ECOG 4 — completely disabled, no self-care"}, + {"raw": "5", "label": "ECOG 5 — dead"}, + ], + "uncertainty": 0.05, + "codebook_lookup_needed": False, + }, + }, + { + "input": { + "table_name": "timeline_performance_status", + "column": 
"KARNOFSKY_SCORE", + "values": [ + "100 (20%)", "90 (25%)", "80 (20%)", "70 (15%)", + "60 (10%)", "50 (5%)", "40 (3%)", "30 (1%)", "20 (1%)", + ], + }, + "output": { + "decoded_categories": [ + {"raw": "100", "label": "Karnofsky 100 — normal, no complaints"}, + {"raw": "90", "label": "Karnofsky 90 — normal activity, minor symptoms"}, + {"raw": "80", "label": "Karnofsky 80 — normal activity with effort"}, + {"raw": "70", "label": "Karnofsky 70 — cares for self, can't do normal activity"}, + {"raw": "60", "label": "Karnofsky 60 — needs occasional assistance"}, + {"raw": "50", "label": "Karnofsky 50 — needs considerable assistance"}, + {"raw": "40", "label": "Karnofsky 40 — disabled, requires special care"}, + {"raw": "30", "label": "Karnofsky 30 — severely disabled, hospitalization indicated"}, + {"raw": "20", "label": "Karnofsky 20 — very sick, active supportive treatment"}, + {"raw": "10", "label": "Karnofsky 10 — moribund"}, + {"raw": "0", "label": "Karnofsky 0 — dead"}, + ], + "uncertainty": 0.1, + "codebook_lookup_needed": False, + }, + }, +] diff --git a/src/sema/engine/join_detector.py b/src/sema/engine/join_detector.py new file mode 100644 index 0000000..abc373e --- /dev/null +++ b/src/sema/engine/join_detector.py @@ -0,0 +1,201 @@ +"""Domain-agnostic FK / join detector. + +Enumerates intra-schema FK candidates from column metadata, verifies +each via tiered evidence, and emits `FK_TO` assertions tagged with +`source_schema`. Verification uses bounded distinct-value samples +when available; the detector NEVER issues unbounded +`COUNT(... NOT IN ...)` referential-integrity scans — when sample-based +verification is inconclusive, the detector downgrades confidence +rather than escalating cost. 
+""" +from __future__ import annotations + +import uuid +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Callable + +from sema.engine.join_detector_utils import ( + FKCandidate, + coverage_ratio, + enumerate_candidates_from_metadata, + verify_cardinality, +) +from sema.models.assertions import ( + Assertion, + AssertionPredicate, + AssertionStatus, +) +from sema.models.extraction import ExtractedColumn + +ColumnKey = tuple[str, str, str] +SamplerFn = Callable[[ColumnKey], set[str] | None] +DEFAULT_SAMPLE_CAP = 500 +DEFAULT_MATERIALIZE_THRESHOLD = 0.80 +TIER_1 = 0.95 +TIER_2 = 0.80 +TIER_3 = 0.70 + + +@dataclass(frozen=True) +class FKAssertion: + """Detector output: candidate + tier + confidence + provenance.""" + candidate: FKCandidate + confidence: float + tier: int + source_schema: str + + +def _column_key(table: str, column: str, schema: str) -> ColumnKey: + return (schema, table, column) + + +@dataclass +class JoinDetector: + """FK / join candidate detector. + + `sample_cap` bounds detector-owned distinct-value samples. When FK + distinct cardinality exceeds this cap the detector downgrades to + Tier 2 if cardinality metadata supports the FK relation, else + Tier 3. The detector NEVER scales up to unbounded RI queries. 
+ """ + sample_cap: int = DEFAULT_SAMPLE_CAP + materialization_threshold: float = DEFAULT_MATERIALIZE_THRESHOLD + + def detect( + self, + *, + columns: list[ExtractedColumn], + source_schema: str, + profiles: dict[ColumnKey, tuple[int, int]] | None = None, + samples: dict[ColumnKey, set[str]] | None = None, + sampler: SamplerFn | None = None, + ) -> list[FKAssertion]: + candidates = enumerate_candidates_from_metadata(columns) + return [ + self._classify(c, source_schema, profiles, samples, sampler) + for c in candidates + ] + + def should_materialize(self, fk: FKAssertion) -> bool: + return fk.confidence >= self.materialization_threshold + + def _classify( + self, + candidate: FKCandidate, + source_schema: str, + profiles: dict[ColumnKey, tuple[int, int]] | None, + samples: dict[ColumnKey, set[str]] | None, + sampler: SamplerFn | None, + ) -> FKAssertion: + pk_key = _column_key( + candidate.pk_table, candidate.pk_column, candidate.schema_name, + ) + fk_key = _column_key( + candidate.fk_table, candidate.fk_column, candidate.schema_name, + ) + + pk_sample, fk_sample, sample_origin = self._collect_samples( + pk_key, fk_key, samples, sampler, + ) + if pk_sample is not None and fk_sample is not None: + tier1 = self._try_tier_1(pk_sample, fk_sample) + if tier1 is not None: + return FKAssertion( + candidate, tier1, 1, source_schema, + ) + + tier2 = self._try_tier_2(pk_key, fk_key, profiles) + if tier2 is not None: + return FKAssertion(candidate, tier2, 2, source_schema) + + return FKAssertion(candidate, TIER_3, 3, source_schema) + + def _collect_samples( + self, + pk_key: ColumnKey, + fk_key: ColumnKey, + samples: dict[ColumnKey, set[str]] | None, + sampler: SamplerFn | None, + ) -> tuple[set[str] | None, set[str] | None, str]: + if samples and pk_key in samples and fk_key in samples: + return samples[pk_key], samples[fk_key], "profiler" + if sampler is None: + return None, None, "missing" + try: + pk = sampler(pk_key) + fk = sampler(fk_key) + except Exception: + return 
None, None, "warehouse_error" + if pk is None or fk is None: + return None, None, "missing" + return pk, fk, "detector" + + def _try_tier_1( + self, pk_sample: set[str], fk_sample: set[str], + ) -> float | None: + """Tier 1 sample subset test. + + When FK distinct cardinality has hit the sample cap the subset + relation is unprovable — return None so the caller falls + through to Tier 2 / Tier 3 without escalating to RI scans. + """ + if len(fk_sample) >= self.sample_cap: + return None + if not fk_sample: + return None + if coverage_ratio(fk_sample, pk_sample) < 0.80: + return None + return TIER_1 + + def _try_tier_2( + self, + pk_key: ColumnKey, + fk_key: ColumnKey, + profiles: dict[ColumnKey, tuple[int, int]] | None, + ) -> float | None: + if not profiles: + return None + pk_stats = profiles.get(pk_key) + fk_stats = profiles.get(fk_key) + if pk_stats is None or fk_stats is None: + return None + pk_distinct, pk_rows = pk_stats + fk_distinct, _ = fk_stats + if not verify_cardinality(pk_distinct, pk_rows, fk_distinct): + return None + return TIER_2 + + +def to_fk_assertion( + fk: FKAssertion, run_id: str, +) -> Assertion: + cat = fk.candidate.catalog or "" + schema = fk.candidate.schema_name + fk_ref = ( + f"databricks://workspace/{cat}/{schema}/" + f"{fk.candidate.fk_table}/{fk.candidate.fk_column}" + ) + pk_ref = ( + f"databricks://workspace/{cat}/{schema}/" + f"{fk.candidate.pk_table}/{fk.candidate.pk_column}" + ) + return Assertion( + id=str(uuid.uuid4()), + subject_ref=fk_ref, + predicate=AssertionPredicate.FK_TO, + payload={ + "pk_table": fk.candidate.pk_table, + "pk_column": fk.candidate.pk_column, + "fk_table": fk.candidate.fk_table, + "fk_column": fk.candidate.fk_column, + "tier": fk.tier, + }, + object_ref=pk_ref, + source="fk_detector", + confidence=fk.confidence, + status=AssertionStatus.AUTO, + run_id=run_id, + observed_at=datetime.now(timezone.utc), + source_schema=fk.source_schema, + ) diff --git a/src/sema/engine/join_detector_utils.py 
b/src/sema/engine/join_detector_utils.py new file mode 100644 index 0000000..288a695 --- /dev/null +++ b/src/sema/engine/join_detector_utils.py @@ -0,0 +1,158 @@ +"""Helpers for the FK / join detector. + +Pure-functional name-pattern matching, type compatibility, and +sample-set verification. The detector itself orchestrates these +helpers; the helpers never touch I/O. +""" +from __future__ import annotations + +import re +from dataclasses import dataclass + +from sema.models.extraction import ExtractedColumn + +_INTEGER_TYPES = frozenset({ + "int", "integer", "bigint", "smallint", "tinyint", "long", +}) +_TEXT_TYPES = frozenset({ + "string", "varchar", "char", "text", "uuid", +}) +_FK_NAME_SUFFIX_RE = re.compile(r"^(?P<entity>.+?)_(id|key|code)$") + + +@dataclass(frozen=True) +class FKCandidate: + """A potential foreign-key relationship between two columns. + + PK / FK semantics: the FK column REFERENCES the PK column. + Both columns are constrained to the same `schema_name`. + """ + pk_table: str + pk_column: str + fk_table: str + fk_column: str + pk_type: str + fk_type: str + schema_name: str + catalog: str = "" + + +def normalize_type(data_type: str) -> str: + """Lowercase, strip parameter list (e.g., `VARCHAR(64)` → `varchar`).""" + base = data_type.split("(", 1)[0].strip().lower() + return base + + +def types_compatible(left: str, right: str) -> bool: + a = normalize_type(left) + b = normalize_type(right) + if a == b: + return True + if a in _INTEGER_TYPES and b in _INTEGER_TYPES: + return True + if a in _TEXT_TYPES and b in _TEXT_TYPES: + return True + return False + + +def fk_name_root(column_name: str) -> str | None: + """Return the entity root from a FK-style column name. + + `patient_id` → `patient`, `sample_key` → `sample`, `gene_code` → `gene`. + Returns None when the name does not match the FK suffix pattern.
+ """ + match = _FK_NAME_SUFFIX_RE.match(column_name.lower()) + return match.group("entity") if match else None + + +def is_pk_match(pk_table: str, pk_column: str, fk_root: str) -> bool: + """A PK candidate matches the FK's entity root. + + A column qualifies as a PK candidate when (a) its table name equals + the FK root, or (b) the column name itself equals `<fk_root>_id` / + `<fk_root>_key` / `<fk_root>_code` — the same FK suffix pattern. + """ + if pk_table.lower() == fk_root: + return True + pk_root = fk_name_root(pk_column) + return pk_root == fk_root + + +def enumerate_candidates_from_metadata( + columns: list[ExtractedColumn], +) -> list[FKCandidate]: + """Enumerate intra-schema FK candidates by name pattern + type.""" + by_schema: dict[str, list[ExtractedColumn]] = {} + for col in columns: + by_schema.setdefault(col.schema, []).append(col) + + candidates: list[FKCandidate] = [] + for schema, cols in by_schema.items(): + candidates.extend(_candidates_within_schema(cols)) + return candidates + + +def _candidates_within_schema( + cols: list[ExtractedColumn], +) -> list[FKCandidate]: + out: list[FKCandidate] = [] + for fk_col in cols: + root = fk_name_root(fk_col.name) + if not root: + continue + if fk_col.table_name.lower() == root: + continue + for pk_col in cols: + if pk_col is fk_col: + continue + if pk_col.table_name == fk_col.table_name: + continue + if pk_col.table_name.lower() != root: + continue + if not types_compatible(pk_col.data_type, fk_col.data_type): + continue + out.append(FKCandidate( + pk_table=pk_col.table_name, + pk_column=pk_col.name, + fk_table=fk_col.table_name, + fk_column=fk_col.name, + pk_type=pk_col.data_type, + fk_type=fk_col.data_type, + schema_name=fk_col.schema, + catalog=fk_col.catalog, + )) + return out + + +def coverage_ratio( + fk_values: set[str], pk_values: set[str], +) -> float: + """Fraction of FK distinct values that appear in PK distinct values.""" + if not fk_values: + return 0.0 + return len(fk_values & pk_values) / len(fk_values) + + +def
verify_data_subset( + fk_values: set[str], + pk_values: set[str], + *, + coverage_threshold: float = 0.80, +) -> bool: + """FK ⊆ PK with at least `coverage_threshold` of FK values matched.""" + if not fk_values: + return False + return coverage_ratio(fk_values, pk_values) >= coverage_threshold + + +def verify_cardinality( + pk_distinct: int, + pk_rows: int, + fk_distinct: int, +) -> bool: + """PK uniquely valued AND FK distinct count ≤ PK distinct count.""" + if pk_distinct <= 0 or pk_rows <= 0: + return False + if pk_distinct != pk_rows: + return False + return fk_distinct <= pk_distinct diff --git a/src/sema/engine/warehouse_lookup.py b/src/sema/engine/warehouse_lookup.py new file mode 100644 index 0000000..13037f5 --- /dev/null +++ b/src/sema/engine/warehouse_lookup.py @@ -0,0 +1,98 @@ +"""Lazy warehouse-backed lookups for `JoinDetector`. + +`WarehouseSampler` returns bounded distinct-value samples for a column; +`WarehouseProfileLookup` returns `(approx_distinct, row_count)`. Both +cache per column (and the profile lookup also caches `row_count` per +table to avoid redundant `COUNT(*)` scans). Both fail closed: a +warehouse error returns `None` and is cached, so the detector +downgrades confidence rather than retrying failing queries. 
+""" +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Callable + +ColumnKey = tuple[str, str, str] +TableKey = tuple[str, str] +QueryFn = Callable[[str], list[tuple[Any, ...]]] + + +@dataclass +class WarehouseSampler: + query_fn: QueryFn + catalog: str + sample_cap: int = 500 + _cache: dict[ColumnKey, set[str] | None] = field( + default_factory=dict, + ) + + def __call__(self, key: ColumnKey) -> set[str] | None: + if key in self._cache: + return self._cache[key] + result = self._fetch(key) + self._cache[key] = result + return result + + def _fetch(self, key: ColumnKey) -> set[str] | None: + schema, table, column = key + sql = ( + f"SELECT DISTINCT `{column}` FROM " + f"`{self.catalog}`.`{schema}`.`{table}` " + f"LIMIT {self.sample_cap}" + ) + try: + rows = self.query_fn(sql) + except Exception: + return None + return {str(row[0]) for row in rows if row[0] is not None} + + +@dataclass +class WarehouseProfileLookup: + query_fn: QueryFn + catalog: str + _cache: dict[ColumnKey, tuple[int, int] | None] = field( + default_factory=dict, + ) + _row_cache: dict[TableKey, int | None] = field( + default_factory=dict, + ) + + def __call__(self, key: ColumnKey) -> tuple[int, int] | None: + if key in self._cache: + return self._cache[key] + result = self._fetch(key) + self._cache[key] = result + return result + + def _fetch(self, key: ColumnKey) -> tuple[int, int] | None: + schema, table, column = key + rows = self._row_count(schema, table) + if rows is None: + return None + try: + res = self.query_fn( + f"SELECT COUNT(DISTINCT `{column}`) " + f"FROM `{self.catalog}`.`{schema}`.`{table}`" + ) + except Exception: + return None + if not res or not res[0]: + return None + return (int(res[0][0]), rows) + + def _row_count(self, schema: str, table: str) -> int | None: + tk = (schema, table) + if tk in self._row_cache: + return self._row_cache[tk] + try: + res = self.query_fn( + f"SELECT COUNT(*) FROM " + 
f"`{self.catalog}`.`{schema}`.`{table}`" + ) + except Exception: + self._row_cache[tk] = None + return None + rows = int(res[0][0]) if res and res[0] else None + self._row_cache[tk] = rows + return rows diff --git a/src/sema/graph/join_materializer.py b/src/sema/graph/join_materializer.py index 904afc6..f237dd6 100644 --- a/src/sema/graph/join_materializer.py +++ b/src/sema/graph/join_materializer.py @@ -1,6 +1,6 @@ """Join path materialization helpers. -Normalizes JOINS_TO + HAS_JOIN_EVIDENCE into JoinPath nodes. +Normalizes JOINS_TO + HAS_JOIN_EVIDENCE + FK_TO into JoinPath nodes. Extracted from materializer_utils.py to keep files under 400 lines. """ @@ -8,12 +8,19 @@ from typing import TYPE_CHECKING, Any -from sema.models.assertions import Assertion, AssertionPredicate +from sema.graph.join_materializer_utils import normalize_fk_to_assertion from sema.graph.loader_utils import batch_upsert_join_paths +from sema.models.assertions import Assertion, AssertionPredicate if TYPE_CHECKING: from sema.graph.loader import GraphLoader +_JOIN_PREDICATES = ( + AssertionPredicate.HAS_JOIN_EVIDENCE.value, + AssertionPredicate.JOINS_TO.value, + AssertionPredicate.FK_TO.value, +) + def _derive_join_path_name( join_predicates: list[dict[str, str]], @@ -57,48 +64,64 @@ def _build_join_path_records( def _wire_join_path_edges( loader: GraphLoader, records: list[dict[str, Any]], + source_schema: str | None, ) -> None: + if source_schema is None: + return for rec in records: name = rec["name"] for jp in rec["join_predicates"]: if jp.get("left_table"): - loader.add_join_path_uses(name, jp["left_table"]) + loader.add_join_path_uses( + name, jp["left_table"], + source_schema=source_schema, + ) if jp.get("left_column") and jp.get("left_table"): loader.add_join_path_uses( name, jp["left_table"], jp["left_column"], + source_schema=source_schema, ) if jp.get("right_table"): - loader.add_join_path_uses(name, jp["right_table"]) + loader.add_join_path_uses( + name, jp["right_table"], + 
source_schema=source_schema, + ) if jp.get("right_column") and jp.get("right_table"): loader.add_join_path_uses( name, jp["right_table"], jp["right_column"], + source_schema=source_schema, ) if rec.get("from_table") or rec.get("to_table"): loader.add_join_path_entity_links( name, rec.get("from_table", ""), rec.get("to_table", ""), + source_schema=source_schema, ) def materialize_join_paths( loader: GraphLoader, groups: dict[tuple[str, str], list[Assertion]], + source_schema: str | None = None, ) -> None: - """Normalize JOINS_TO + HAS_JOIN_EVIDENCE into JoinPath nodes.""" + """Normalize JOINS_TO + HAS_JOIN_EVIDENCE + FK_TO into JoinPath nodes.""" join_groups: dict[str, list[Assertion]] = {} for (subj, pred), group in groups.items(): - if pred in ( - AssertionPredicate.HAS_JOIN_EVIDENCE.value, - AssertionPredicate.JOINS_TO.value, - ): - if subj in join_groups: - join_groups[subj].extend(group) - else: - join_groups[subj] = list(group) + if pred not in _JOIN_PREDICATES: + continue + normalized = ( + [normalize_fk_to_assertion(a) for a in group] + if pred == AssertionPredicate.FK_TO.value + else list(group) + ) + if subj in join_groups: + join_groups[subj].extend(normalized) + else: + join_groups[subj] = normalized records = _build_join_path_records(join_groups) batch = [ {k: v for k, v in r.items() if k not in ("from_table", "to_table")} for r in records ] - batch_upsert_join_paths(loader, batch) - _wire_join_path_edges(loader, records) + batch_upsert_join_paths(loader, batch, source_schema=source_schema) + _wire_join_path_edges(loader, records, source_schema) diff --git a/src/sema/graph/join_materializer_utils.py b/src/sema/graph/join_materializer_utils.py new file mode 100644 index 0000000..93f5476 --- /dev/null +++ b/src/sema/graph/join_materializer_utils.py @@ -0,0 +1,38 @@ +"""Helpers for `join_materializer.py` — kept thin per the engine-file rule.""" +from __future__ import annotations + +from sema.models.assertions import Assertion, AssertionPredicate + + +def 
_strip_trailing_segment(ref: str) -> str: + return ref.rsplit("/", 1)[0] if ref else "" + + +def normalize_fk_to_assertion(a: Assertion) -> Assertion: + """Translate an `FK_TO` assertion into the legacy join-evidence shape. + + The legacy materializer pipeline reads `join_predicates` / + `from_table` / `to_table` from the winner's payload. FK_TO emits + `pk_table` / `pk_column` / `fk_table` / `fk_column` with subject_ref + pointing at the FK column and object_ref at the PK column. This + helper translates one shape into the other so a single materializer + path serves both predicate families. + """ + p = a.payload + pk_table, pk_column = p["pk_table"], p["pk_column"] + fk_table, fk_column = p["fk_table"], p["fk_column"] + new_payload = { + "join_predicates": [{ + "left_table": pk_table, "left_column": pk_column, + "right_table": fk_table, "right_column": fk_column, + "operator": "=", + }], + "hop_count": 1, + "from_table": _strip_trailing_segment(a.object_ref or ""), + "to_table": _strip_trailing_segment(a.subject_ref), + "tier": p.get("tier"), + } + return a.model_copy(update={ + "payload": new_payload, + "predicate": AssertionPredicate.HAS_JOIN_EVIDENCE, + }) diff --git a/src/sema/graph/loader.py b/src/sema/graph/loader.py index 598e345..f1d9d26 100644 --- a/src/sema/graph/loader.py +++ b/src/sema/graph/loader.py @@ -162,7 +162,7 @@ def batch_upsert_join_paths( def upsert_entity( self, name: str, description: str | None, source: str, confidence: float, table_name: str, schema_name: str, - catalog: str, + catalog: str, source_schema: str | None = None, ) -> None: id_ = str(uuid.uuid4()) self._run( @@ -174,23 +174,25 @@ def upsert_entity( "WITH e " "MERGE (t:Table {name: $table_name, " "schema_name: $schema_name, catalog: $catalog}) " - "MERGE (e)-[:ENTITY_ON_TABLE]->(t)", + "MERGE (e)-[link:ENTITY_ON_TABLE " + "{source_schema: $source_schema}]->(t)", name=name, description=description, source=source, confidence=confidence, 
resolved_at=datetime.now(timezone.utc).isoformat(), table_name=table_name, schema_name=schema_name, - catalog=catalog, id=id_, + catalog=catalog, id=id_, source_schema=source_schema, ) def upsert_property( self, name: str, semantic_type: str, source: str, confidence: float, entity_name: str, column_name: str, table_name: str, schema_name: str, catalog: str, + source_schema: str | None = None, ) -> None: id_ = str(uuid.uuid4()) self._run( - "MERGE (p:Property {name: $name, " - "entity_name: $entity_name}) " + "MERGE (p:Property {entity_name: $entity_name, " + "name: $name}) " "ON CREATE SET p.id = $id " "SET p.semantic_type = $semantic_type, " "p.source = $source, " @@ -198,18 +200,20 @@ def upsert_property( "p.resolved_at = $resolved_at " "WITH p " "MERGE (e:Entity {name: $entity_name}) " - "MERGE (e)-[:HAS_PROPERTY]->(p) " + "MERGE (e)-[hp:HAS_PROPERTY " + "{source_schema: $source_schema}]->(p) " "WITH p " "MERGE (c:Column {name: $column_name, " "table_name: $table_name, " "schema_name: $schema_name, catalog: $catalog}) " - "MERGE (p)-[:PROPERTY_ON_COLUMN]->(c)", + "MERGE (p)-[poc:PROPERTY_ON_COLUMN " + "{source_schema: $source_schema}]->(c)", name=name, semantic_type=semantic_type, source=source, confidence=confidence, resolved_at=datetime.now(timezone.utc).isoformat(), entity_name=entity_name, column_name=column_name, table_name=table_name, schema_name=schema_name, - catalog=catalog, id=id_, + catalog=catalog, id=id_, source_schema=source_schema, ) def upsert_term( @@ -232,62 +236,89 @@ def upsert_term( def upsert_value_set( self, name: str, column_name: str, table_name: str, schema_name: str, catalog: str, + source_schema: str | None = None, + column_ref: str | None = None, ) -> None: id_ = str(uuid.uuid4()) + ref = column_ref or ( + f"{catalog}.{schema_name}.{table_name}.{column_name}" + ) self._run( - "MERGE (vs:ValueSet {name: $name}) " + "MERGE (vs:ValueSet {column_ref: $column_ref}) " "ON CREATE SET vs.id = $id " + "SET vs.name = $name " "WITH vs " "MERGE 
(c:Column {name: $column_name, " "table_name: $table_name, " "schema_name: $schema_name, catalog: $catalog}) " - "MERGE (c)-[:HAS_VALUE_SET]->(vs)", + "MERGE (c)-[hvs:HAS_VALUE_SET " + "{source_schema: $source_schema}]->(vs)", name=name, column_name=column_name, table_name=table_name, schema_name=schema_name, - catalog=catalog, id=id_, + catalog=catalog, id=id_, column_ref=ref, + source_schema=source_schema, ) def add_term_to_value_set( self, term_code: str, value_set_name: str, + source_schema: str | None = None, ) -> None: self._run( "MERGE (t:Term {code: $term_code}) " "MERGE (vs:ValueSet {name: $value_set_name}) " - "MERGE (t)-[:MEMBER_OF]->(vs)", + "MERGE (t)-[m:MEMBER_OF " + "{source_schema: $source_schema}]->(vs)", term_code=term_code, value_set_name=value_set_name, + source_schema=source_schema, ) def add_term_hierarchy( self, parent_code: str, child_code: str, + source_schema: str | None = None, ) -> None: self._run( "MERGE (p:Term {code: $parent_code}) " "MERGE (c:Term {code: $child_code}) " - "MERGE (p)-[:PARENT_OF]->(c)", + "MERGE (p)-[po:PARENT_OF " + "{source_schema: $source_schema}]->(c)", parent_code=parent_code, child_code=child_code, + source_schema=source_schema, ) def upsert_alias( self, text: str, parent_label: str, parent_name: str, source: str, confidence: float, is_preferred: bool = False, description: str | None = None, + source_schema: str | None = None, + parent_entity_name: str | None = None, ) -> None: id_ = str(uuid.uuid4()) + if parent_label == ":Property": + parent_match = ( + "MERGE (p:Property {entity_name: $parent_entity_name, " + "name: $parent_name})" + ) + else: + parent_match = ( + f"MERGE (p{parent_label} {{name: $parent_name}})" + ) self._run( - f"MERGE (a:Alias {{text: $text}}) " - f"ON CREATE SET a.id = $id " - f"SET a.source = $source, a.confidence = $confidence, " - f"a.resolved_at = $resolved_at, " - f"a.is_preferred = $is_preferred, " - f"a.description = $description " - f"WITH a " - f"MERGE (p{parent_label} {{name: 
$parent_name}}) " - f"MERGE (a)-[:REFERS_TO]->(p)", + "MERGE (a:Alias {text: $text}) " + "ON CREATE SET a.id = $id " + "SET a.source = $source, a.confidence = $confidence, " + "a.resolved_at = $resolved_at, " + "a.is_preferred = $is_preferred, " + "a.description = $description " + "WITH a " + f"{parent_match} " + "MERGE (a)-[ref:REFERS_TO " + "{source_schema: $source_schema}]->(p)", text=text, parent_name=parent_name, source=source, confidence=confidence, resolved_at=datetime.now(timezone.utc).isoformat(), is_preferred=is_preferred, description=description, - id=id_, + id=id_, source_schema=source_schema, + parent_entity_name=parent_entity_name, ) def upsert_join_path( @@ -295,10 +326,12 @@ def upsert_join_path( hop_count: int, source: str, confidence: float, sql_snippet: str | None = None, cardinality_hint: str | None = None, + source_schema: str | None = None, ) -> None: id_ = str(uuid.uuid4()) self._run( - "MERGE (jp:JoinPath {name: $name}) " + "MERGE (jp:JoinPath {name: $name, " + "source_schema: $source_schema}) " "ON CREATE SET jp.id = $id " "SET jp.join_predicates = $join_predicates, " "jp.hop_count = $hop_count, jp.source = $source, " @@ -312,53 +345,75 @@ def upsert_join_path( confidence=confidence, sql_snippet=sql_snippet, cardinality_hint=cardinality_hint, resolved_at=datetime.now(timezone.utc).isoformat(), - id=id_, + id=id_, source_schema=source_schema, ) def add_join_path_uses( self, join_path_name: str, table_ref: str, column_name: str | None = None, + source_schema: str | None = None, ) -> None: + if source_schema is None: + raise ValueError( + "add_join_path_uses requires source_schema to scope " + "the JoinPath match by {name, source_schema}" + ) if column_name: self._run( - "MATCH (jp:JoinPath {name: $jp_name}) " + "MATCH (jp:JoinPath {name: $jp_name, " + "source_schema: $source_schema}) " "MATCH (c:Column {ref: $ref}) " - "MERGE (jp)-[:USES]->(c)", + "MERGE (jp)-[u:USES " + "{source_schema: $source_schema}]->(c)", jp_name=join_path_name, 
ref=table_ref, + source_schema=source_schema, ) else: self._run( - "MATCH (jp:JoinPath {name: $jp_name}) " + "MATCH (jp:JoinPath {name: $jp_name, " + "source_schema: $source_schema}) " "MATCH (t:Table {ref: $ref}) " - "MERGE (jp)-[:USES]->(t)", + "MERGE (jp)-[u:USES " + "{source_schema: $source_schema}]->(t)", jp_name=join_path_name, ref=table_ref, + source_schema=source_schema, ) def add_join_path_entity_links( self, join_path_name: str, from_table_ref: str, to_table_ref: str, + source_schema: str | None = None, ) -> None: + if source_schema is None: + raise ValueError( + "add_join_path_entity_links requires source_schema " + "to scope the JoinPath match by {name, source_schema}" + ) self._run( - "MATCH (jp:JoinPath {name: $jp_name}) " + "MATCH (jp:JoinPath {name: $jp_name, " + "source_schema: $source_schema}) " "OPTIONAL MATCH (fe:Entity)-[:ENTITY_ON_TABLE]->" "(:Table {ref: $from_ref}) " "OPTIONAL MATCH (te:Entity)-[:ENTITY_ON_TABLE]->" "(:Table {ref: $to_ref}) " "FOREACH (_ IN CASE WHEN fe IS NOT NULL " "THEN [1] ELSE [] END | " - "MERGE (jp)-[:FROM_ENTITY]->(fe)) " + "MERGE (jp)-[fr:FROM_ENTITY " + "{source_schema: $source_schema}]->(fe)) " "FOREACH (_ IN CASE WHEN te IS NOT NULL " "THEN [1] ELSE [] END | " - "MERGE (jp)-[:TO_ENTITY]->(te))", + "MERGE (jp)-[to_:TO_ENTITY " + "{source_schema: $source_schema}]->(te))", jp_name=join_path_name, from_ref=from_table_ref, to_ref=to_table_ref, + source_schema=source_schema, ) - def store_assertion(self, assertion: Assertion) -> None: - # NOTE: No longer mutating prior assertions to 'superseded'. - # Status transitions are now handled via StatusEvent log. - # Prior assertions remain as immutable history. 
+ def store_assertion( + self, assertion: Assertion, source_schema: str | None = None, + ) -> None: + schema = source_schema or assertion.source_schema self._run( "CREATE (a:Assertion {" " id: $id, subject_ref: $subject_ref," @@ -368,7 +423,8 @@ def store_assertion(self, assertion: Assertion) -> None: " object_id: $object_id," " source: $source, confidence: $confidence," " status: $status, run_id: $run_id," - " observed_at: $observed_at" + " observed_at: $observed_at," + " source_schema: $source_schema" "})", id=assertion.id, subject_ref=assertion.subject_ref, @@ -382,16 +438,19 @@ def store_assertion(self, assertion: Assertion) -> None: status="auto", run_id=assertion.run_id, observed_at=assertion.observed_at.isoformat(), + source_schema=schema, ) def batch_store_assertions( self, assertions: list[Assertion], + source_schema: str | None = None, ) -> None: for assertion in assertions: - self.store_assertion(assertion) + self.store_assertion(assertion, source_schema=source_schema) def _build_assertion_dicts( self, assertions: list[Assertion], + source_schema: str | None = None, ) -> list[dict[str, Any]]: return [ { @@ -407,20 +466,20 @@ def _build_assertion_dicts( "status": "auto", "run_id": a.run_id, "observed_at": a.observed_at.isoformat(), + "source_schema": source_schema or a.source_schema, } for a in assertions ] def commit_table_assertions( self, assertions: list[Assertion], + source_schema: str | None = None, ) -> None: with self._driver.session() as session: tx = session.begin_transaction() try: - # NOTE: No longer mutating prior assertions to 'superseded'. - # Status transitions are now handled via StatusEvent log. 
assertion_dicts = self._build_assertion_dicts( - assertions + assertions, source_schema=source_schema, ) tx.run( "UNWIND $assertions AS a " @@ -436,7 +495,8 @@ def commit_table_assertions( " confidence: a.confidence," " status: a.status," " run_id: a.run_id," - " observed_at: a.observed_at" + " observed_at: a.observed_at," + " source_schema: a.source_schema" "})", assertions=assertion_dicts, ) @@ -512,6 +572,31 @@ def set_embedding( updated_at=datetime.now(timezone.utc).isoformat(), ) + def delete_study_scoped(self, schema_name: str) -> None: + """Remove every graph element stamped with this study's schema. + + Edge sweep is type-agnostic (matches by `source_schema` property). + Assertion / JoinPath nodes are detach-deleted by `source_schema`, + which transitively removes provenance edges (`:SUBJECT` / + `:OBJECT`). Shared concept nodes (`:Entity`, `:Term`, + `:ValueSet`, `:Property`, `:SemanticType`) and physical nodes + (`:Table`, `:Column`, `:Schema`) are never touched. + """ + self._run( + "MATCH ()-[r {source_schema: $schema}]-() DELETE r", + schema=schema_name, + ) + self._run( + "MATCH (a:Assertion {source_schema: $schema}) " + "DETACH DELETE a", + schema=schema_name, + ) + self._run( + "MATCH (jp:JoinPath {source_schema: $schema}) " + "DETACH DELETE jp", + schema=schema_name, + ) + def has_assertions(self, table_ref: str) -> bool: table_ref_slash = table_ref + "/" with self._driver.session() as session: diff --git a/src/sema/graph/loader_utils.py b/src/sema/graph/loader_utils.py index d265597..21ea372 100644 --- a/src/sema/graph/loader_utils.py +++ b/src/sema/graph/loader_utils.py @@ -2,6 +2,12 @@ Each function takes a ``loader`` (GraphLoader) as its first argument and delegates to ``loader._run`` for Cypher execution. 
+ +Study-derived edges (`:ENTITY_ON_TABLE`, `:HAS_PROPERTY`, +`:PROPERTY_ON_COLUMN`, `:HAS_VALUE_SET`, `:REFERS_TO`) include +`source_schema` IN THE MERGE MATCH KEY so two studies emitting the same +logical edge produce two distinct relationships, each independently +scoped-deletable. """ from __future__ import annotations @@ -10,69 +16,109 @@ from datetime import datetime, timezone from typing import TYPE_CHECKING, Any +from sema.models.extraction import ExtractedColumn + if TYPE_CHECKING: from sema.graph.loader import GraphLoader +_FETCH_COLUMNS_QUERY = ( + "MATCH (c:Column) WHERE c.schema_name = $schema_name " + "RETURN c.name AS name, c.table_name AS table_name, " + "c.catalog AS catalog, c.schema_name AS schema_name, " + "c.data_type AS data_type, c.nullable AS nullable, " + "c.comment AS comment" +) + + +def fetch_columns_by_schema( + loader: GraphLoader, schema_name: str, +) -> list[ExtractedColumn]: + rows = loader._run_read( + _FETCH_COLUMNS_QUERY, schema_name=schema_name, + ) + return [ + ExtractedColumn( + name=row["name"], + table_name=row["table_name"], + catalog=row.get("catalog", ""), + schema=row["schema_name"], + data_type=row.get("data_type", "UNKNOWN"), + nullable=bool(row.get("nullable", True)), + comment=row.get("comment"), + ) + for row in rows + ] + + +def _annotate_rows( + items: list[dict[str, Any]], + source_schema: str | None, +) -> list[dict[str, Any]]: + resolved_at = datetime.now(timezone.utc).isoformat() + return [ + { + **item, + "resolved_at": resolved_at, + "id": str(uuid.uuid4()), + "source_schema": source_schema, + } + for item in items + ] + + def batch_upsert_entities( - loader: GraphLoader, entities: list[dict[str, Any]], + loader: GraphLoader, + entities: list[dict[str, Any]], + source_schema: str | None = None, ) -> None: if not entities: return - resolved_at = datetime.now(timezone.utc).isoformat() - rows = [ - {**e, "resolved_at": resolved_at, "id": str(uuid.uuid4())} - for e in entities - ] + rows = 
_annotate_rows(entities, source_schema) loader._run( "UNWIND $rows AS r " - "MERGE (e:Entity {datasource_id: r.datasource_id, " - "table_key: r.table_key}) " + "MERGE (e:Entity {name: r.name}) " "ON CREATE SET e.id = r.id " - "SET e.name = r.name, " - "e.description = r.description, e.source = r.source, " + "SET e.description = r.description, e.source = r.source, " "e.confidence = r.confidence, " "e.status = 'ACTIVE', " "e.resolved_at = r.resolved_at " "WITH e, r " "MERGE (t:Table {name: r.table_name, " "schema_name: r.schema_name, catalog: r.catalog}) " - "MERGE (e)-[:ENTITY_ON_TABLE]->(t)", + "MERGE (e)-[link:ENTITY_ON_TABLE " + "{source_schema: r.source_schema}]->(t)", rows=rows, ) def batch_upsert_properties( - loader: GraphLoader, properties: list[dict[str, Any]], + loader: GraphLoader, + properties: list[dict[str, Any]], + source_schema: str | None = None, ) -> None: if not properties: return - resolved_at = datetime.now(timezone.utc).isoformat() - rows = [ - {**p, "resolved_at": resolved_at, "id": str(uuid.uuid4())} - for p in properties - ] + rows = _annotate_rows(properties, source_schema) loader._run( "UNWIND $rows AS r " - "MERGE (p:Property {datasource_id: r.datasource_id, " - "column_key: r.column_key}) " + "MERGE (p:Property {entity_name: r.entity_name, name: r.name}) " "ON CREATE SET p.id = r.id " - "SET p.name = r.name, " - "p.entity_name = r.entity_name, " - "p.semantic_type = r.semantic_type, " + "SET p.semantic_type = r.semantic_type, " "p.source = r.source, " "p.confidence = r.confidence, " "p.status = 'ACTIVE', " "p.resolved_at = r.resolved_at " "WITH p, r " - "MERGE (e:Entity {datasource_id: r.datasource_id, " - "table_key: r.table_key}) " - "MERGE (e)-[:HAS_PROPERTY]->(p) " + "MERGE (e:Entity {name: r.entity_name}) " + "MERGE (e)-[hp:HAS_PROPERTY " + "{source_schema: r.source_schema}]->(p) " "WITH p, r " "MERGE (c:Column {name: r.column_name, " "table_name: r.table_name, " "schema_name: r.schema_name, catalog: r.catalog}) " - "MERGE 
(p)-[:PROPERTY_ON_COLUMN]->(c)", + "MERGE (p)-[poc:PROPERTY_ON_COLUMN " + "{source_schema: r.source_schema}]->(c)", rows=rows, ) @@ -89,11 +135,11 @@ def batch_upsert_terms( ] loader._run( "UNWIND $rows AS r " - "MERGE (t:Term {vocabulary_name: r.vocabulary_name, " - "code: r.code}) " + "MERGE (t:Term {code: r.code}) " "ON CREATE SET t.id = r.id " "SET t.label = r.label, t.source = r.source, " "t.confidence = r.confidence, " + "t.vocabulary_name = r.vocabulary_name, " "t.status = 'ACTIVE', " "t.resolved_at = r.resolved_at", rows=rows, @@ -104,42 +150,46 @@ def batch_upsert_aliases( loader: GraphLoader, aliases: list[dict[str, Any]], parent_label: str, + source_schema: str | None = None, ) -> None: if not aliases: return - resolved_at = datetime.now(timezone.utc).isoformat() - rows = [ - {**a, "resolved_at": resolved_at, "id": str(uuid.uuid4())} - for a in aliases - ] + rows = _annotate_rows(aliases, source_schema) + if parent_label == ":Property": + parent_match = ( + "MERGE (p:Property {entity_name: r.parent_entity_name, " + "name: r.parent_name})" + ) + else: + parent_match = f"MERGE (p{parent_label} {{name: r.parent_name}})" loader._run( - f"UNWIND $rows AS r " - f"MERGE (a:Alias {{target_key: r.target_key, text: r.text}}) " - f"ON CREATE SET a.id = r.id " - f"SET a.source = r.source, a.confidence = r.confidence, " - f"a.resolved_at = r.resolved_at, " - f"a.status = 'ACTIVE', " - f"a.is_preferred = r.is_preferred, " - f"a.description = r.description " - f"WITH a, r " - f"MERGE (p{parent_label} {{name: r.parent_name}}) " - f"MERGE (a)-[:REFERS_TO]->(p)", + "UNWIND $rows AS r " + "MERGE (a:Alias {target_key: r.target_key, text: r.text}) " + "ON CREATE SET a.id = r.id " + "SET a.source = r.source, a.confidence = r.confidence, " + "a.resolved_at = r.resolved_at, " + "a.status = 'ACTIVE', " + "a.is_preferred = r.is_preferred, " + "a.description = r.description " + "WITH a, r " + f"{parent_match} " + "MERGE (a)-[ref:REFERS_TO " + "{source_schema: 
r.source_schema}]->(p)", rows=rows, ) def batch_upsert_value_sets( - loader: GraphLoader, value_sets: list[dict[str, Any]], + loader: GraphLoader, + value_sets: list[dict[str, Any]], + source_schema: str | None = None, ) -> None: if not value_sets: return - rows = [ - {**vs, "id": str(uuid.uuid4())} for vs in value_sets - ] + rows = _annotate_rows(value_sets, source_schema) loader._run( "UNWIND $rows AS r " - "MERGE (vs:ValueSet {datasource_id: r.datasource_id, " - "column_key: r.column_key}) " + "MERGE (vs:ValueSet {column_ref: r.column_ref}) " "ON CREATE SET vs.id = r.id " "SET vs.name = r.name, " "vs.status = 'ACTIVE' " @@ -147,13 +197,16 @@ def batch_upsert_value_sets( "MERGE (c:Column {name: r.column_name, " "table_name: r.table_name, " "schema_name: r.schema_name, catalog: r.catalog}) " - "MERGE (c)-[:HAS_VALUE_SET]->(vs)", + "MERGE (c)-[hvs:HAS_VALUE_SET " + "{source_schema: r.source_schema}]->(vs)", rows=rows, ) def batch_upsert_join_paths( - loader: GraphLoader, join_paths: list[dict[str, Any]], + loader: GraphLoader, + join_paths: list[dict[str, Any]], + source_schema: str | None = None, ) -> None: if not join_paths: return @@ -163,6 +216,7 @@ def batch_upsert_join_paths( **jp, "resolved_at": resolved_at, "id": str(uuid.uuid4()), + "source_schema": source_schema, "join_predicates_json": json.dumps( jp["join_predicates"] ), @@ -171,12 +225,10 @@ def batch_upsert_join_paths( ] loader._run( "UNWIND $rows AS r " - "MERGE (jp:JoinPath {datasource_id: r.datasource_id, " - "from_table: r.from_table, to_table: r.to_table, " - "join_columns_hash: r.join_columns_hash}) " + "MERGE (jp:JoinPath {name: r.name, " + "source_schema: r.source_schema}) " "ON CREATE SET jp.id = r.id " - "SET jp.name = r.name, " - "jp.join_predicates = r.join_predicates_json, " + "SET jp.join_predicates = r.join_predicates_json, " "jp.hop_count = r.hop_count, " "jp.source = r.source, " "jp.confidence = r.confidence, " @@ -218,8 +270,7 @@ def batch_create_classified_as( return loader._run( 
"UNWIND $rows AS r " - "MATCH (p:Property {datasource_id: r.datasource_id, " - "column_key: r.column_key}) " + "MATCH (p:Property {entity_name: r.entity_name, name: r.name}) " "MERGE (v:Vocabulary {name: r.vocabulary_name}) " "MERGE (p)-[:CLASSIFIED_AS]->(v)", rows=edges, @@ -235,8 +286,7 @@ def batch_create_in_vocabulary( return loader._run( "UNWIND $rows AS r " - "MATCH (t:Term {vocabulary_name: r.vocabulary_name, " - "code: r.code}) " + "MATCH (t:Term {code: r.code}) " "MERGE (v:Vocabulary {name: r.vocabulary_name}) " "MERGE (t)-[:IN_VOCABULARY]->(v)", rows=edges, diff --git a/src/sema/graph/materializer.py b/src/sema/graph/materializer.py index 4fccde0..ffeb350 100644 --- a/src/sema/graph/materializer.py +++ b/src/sema/graph/materializer.py @@ -35,6 +35,7 @@ def materialize_unified( loader: GraphLoader, assertions: list[Assertion], + source_schema: str | None = None, ) -> None: """Unified materializer: single assertion-to-graph path. @@ -52,21 +53,16 @@ def materialize_unified( by_subject[a.subject_ref].append(a) groups[(a.subject_ref, a.predicate.value)].append(a) - # Phase 1: Physical nodes upsert_physical_nodes(loader, by_subject) upsert_column_nodes(loader, by_subject) - - # Phase 2: Semantic nodes - upsert_semantic_nodes(loader, by_subject, groups) - - # Phase 3: Bridge edges - apply_resolution_edges(loader, groups) + upsert_semantic_nodes( + loader, by_subject, groups, source_schema=source_schema, + ) + apply_resolution_edges( + loader, groups, source_schema=source_schema, + ) materialize_vocabulary_edges(loader, groups) - - # Phase 4: Provenance loader.materialize_provenance_edges(assertions) - - # Phase 5: Lifecycle run_lifecycle_phase(loader, assertions) logger.info( diff --git a/src/sema/graph/materializer_utils.py b/src/sema/graph/materializer_utils.py index 1698d7d..3f010a0 100644 --- a/src/sema/graph/materializer_utils.py +++ b/src/sema/graph/materializer_utils.py @@ -133,6 +133,7 @@ def upsert_column_nodes( def upsert_entities( loader: 
GraphLoader, groups: dict[tuple[str, str], list[Assertion]], + source_schema: str | None = None, ) -> None: entity_groups = { subj: group @@ -150,13 +151,11 @@ def upsert_entities( "description": winner.payload.get("description"), "source": winner.source, "confidence": winner.confidence, - "datasource_id": pk.datasource_id, - "table_key": pk.table_key, "table_name": pk.table, "schema_name": pk.schema or "", "catalog": pk.catalog_or_db, }) - batch_upsert_entities(loader, batch) + batch_upsert_entities(loader, batch, source_schema=source_schema) def _resolve_property_details( @@ -200,9 +199,6 @@ def _resolve_property_details( "source": winner.source, "confidence": winner.confidence, "entity_name": entity_name, - "datasource_id": pk.datasource_id, - "table_key": pk.table_key, - "column_key": pk.column_key, "column_name": pk.column, "table_name": pk.table, "schema_name": pk.schema or "", @@ -213,6 +209,7 @@ def _resolve_property_details( def upsert_properties( loader: GraphLoader, groups: dict[tuple[str, str], list[Assertion]], + source_schema: str | None = None, ) -> None: prop_groups = { subj: group @@ -224,12 +221,20 @@ def upsert_properties( details = _resolve_property_details(col_ref, group, groups) if details: batch.append(details) - batch_upsert_properties(loader, batch) + batch_upsert_properties(loader, batch, source_schema=source_schema) + + +def _column_ref(pk: Any) -> str: + return ( + f"{pk.catalog_or_db}.{pk.schema or ''}" + f".{pk.table}.{pk.column}" + ) def upsert_decoded_values( loader: GraphLoader, groups: dict[tuple[str, str], list[Assertion]], + source_schema: str | None = None, ) -> None: decoded_groups: dict[str, list[Assertion]] = defaultdict(list) for (subj, pred), group in groups.items(): @@ -247,10 +252,10 @@ def upsert_decoded_values( if not pk.column: continue vs_name = f"{pk.table}_{pk.column}_values" + ref = _column_ref(pk) vs_batch.append({ "name": vs_name, - "datasource_id": pk.datasource_id, - "column_key": pk.column_key, + "column_ref": 
ref, "column_name": pk.column, "table_name": pk.table, "schema_name": pk.schema or "", "catalog": pk.catalog_or_db, }) @@ -266,9 +271,13 @@ def upsert_decoded_values( "vocabulary_name": vs_name, "source": a.source, "confidence": a.confidence, }) - loader.add_term_to_value_set(raw, vs_name) + loader.add_term_to_value_set( + raw, vs_name, source_schema=source_schema, + ) - batch_upsert_value_sets(loader, vs_batch) + batch_upsert_value_sets( + loader, vs_batch, source_schema=source_schema, + ) batch_upsert_terms(loader, term_batch) @@ -284,6 +293,7 @@ def _collect_alias_batch( return [], ":Entity" batch: list[dict[str, Any]] = [] + parent_entity_name: str | None = None if pk.column: prop_group = groups.get( (subject_ref, AssertionPredicate.HAS_PROPERTY_NAME.value), [], @@ -294,7 +304,16 @@ def _collect_alias_batch( if prop_winner else pk.column ) parent_label = ":Property" - target_key = pk.column_key or subject_ref + target_key = subject_ref + table_ref = subject_ref.rsplit("/", 1)[0] + ent_group = groups.get( + (table_ref, AssertionPredicate.HAS_ENTITY_NAME.value), [], + ) + ent_winner = pick_winner(ent_group) + parent_entity_name = ( + ent_winner.payload.get("value", pk.table) + if ent_winner else pk.table + ) else: entity_group = groups.get( (subject_ref, AssertionPredicate.HAS_ENTITY_NAME.value), [], @@ -305,7 +324,7 @@ def _collect_alias_batch( if entity_winner else pk.table ) parent_label = ":Entity" - target_key = pk.table_key + target_key = subject_ref for a in group: if a.status in ( @@ -316,6 +335,7 @@ def _collect_alias_batch( "text": a.payload.get("value", ""), "target_key": target_key, "parent_name": parent_name, + "parent_entity_name": parent_entity_name, "source": a.source, "confidence": a.confidence, "is_preferred": a.payload.get("is_preferred", False), "description": a.payload.get("description"), @@ -327,6 +347,7 @@ def _collect_alias_batch( def upsert_aliases( loader: GraphLoader, groups: dict[tuple[str, str], list[Assertion]], + source_schema: str 
| None = None, ) -> None: alias_groups = { subj: group @@ -348,26 +369,34 @@ def upsert_aliases( else: property_aliases.extend(batch) - batch_upsert_aliases(loader, entity_aliases, ":Entity") - batch_upsert_aliases(loader, property_aliases, ":Property") + batch_upsert_aliases( + loader, entity_aliases, ":Entity", + source_schema=source_schema, + ) + batch_upsert_aliases( + loader, property_aliases, ":Property", + source_schema=source_schema, + ) def upsert_semantic_nodes( loader: GraphLoader, by_subject: dict[str, list[Assertion]], groups: dict[tuple[str, str], list[Assertion]], + source_schema: str | None = None, ) -> None: - upsert_entities(loader, groups) - upsert_properties(loader, groups) - upsert_decoded_values(loader, groups) - upsert_aliases(loader, groups) + upsert_entities(loader, groups, source_schema=source_schema) + upsert_properties(loader, groups, source_schema=source_schema) + upsert_decoded_values(loader, groups, source_schema=source_schema) + upsert_aliases(loader, groups, source_schema=source_schema) from sema.graph.join_materializer import materialize_join_paths - materialize_join_paths(loader, groups) + materialize_join_paths(loader, groups, source_schema=source_schema) def apply_resolution_edges( loader: GraphLoader, groups: dict[tuple[str, str], list[Assertion]], + source_schema: str | None = None, ) -> None: for (subj, pred), group in groups.items(): if pred == AssertionPredicate.PARENT_OF.value: @@ -379,6 +408,7 @@ def apply_resolution_edges( loader.add_term_hierarchy( parent_code=a.payload.get("parent", ""), child_code=a.payload.get("child", ""), + source_schema=source_schema, ) diff --git a/src/sema/graph/vocabulary_materializer.py b/src/sema/graph/vocabulary_materializer.py index 0b6dfea..3ffd055 100644 --- a/src/sema/graph/vocabulary_materializer.py +++ b/src/sema/graph/vocabulary_materializer.py @@ -53,16 +53,33 @@ def materialize_vocabulary_edges( try: pk = CanonicalRef.parse(subj) - col_key = pk.column_key - ds_id = pk.datasource_id 
except ValueError: continue - if not col_key: + if not pk.column: continue + prop_group = groups.get( + (subj, AssertionPredicate.HAS_PROPERTY_NAME.value), [], + ) + prop_winner = pick_winner(prop_group) + prop_name = ( + prop_winner.payload.get("value", pk.column) + if prop_winner else pk.column + ) + + table_ref = subj.rsplit("/", 1)[0] + ent_group = groups.get( + (table_ref, AssertionPredicate.HAS_ENTITY_NAME.value), [], + ) + ent_winner = pick_winner(ent_group) + entity_name = ( + ent_winner.payload.get("value", pk.table) + if ent_winner else pk.table + ) + classified_edges.append({ - "datasource_id": ds_id, - "column_key": col_key, + "entity_name": entity_name, + "name": prop_name, "vocabulary_name": vocab_name, }) diff --git a/src/sema/ingest/comment_recovery.py b/src/sema/ingest/comment_recovery.py new file mode 100644 index 0000000..78ee59d --- /dev/null +++ b/src/sema/ingest/comment_recovery.py @@ -0,0 +1,314 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Callable, Sequence + +from sema.ingest.databricks_push_utils import ( + build_alter_column_comment_sql, + build_alter_table_comment_sql, +) +from sema.ingest.study_registry import StudyRegistry +from sema.log import logger +from sema.models.config import IngestConfig + +QueryFn = Callable[[str, list[str]], Sequence[Sequence[object]]] + + +class StudyNotRegisteredError(LookupError): + """The given study_id is not present in the _sema_study_registry.""" + + +class StudyCacheMissingError(LookupError): + """The study is registered but its cache directory does not exist.""" + + +class PartialOverrideError(ValueError): + """Partial override flags supplied; bypass mode requires the full set.""" + + +@dataclass(frozen=True) +class ParsedTableComments: + table_comment: str | None + column_comments: dict[str, str] + + +@dataclass(frozen=True) +class LiveTableComments: + table_comment: str | None + column_comments: dict[str, str | 
None] + + +@dataclass(frozen=True) +class RecoveryContext: + study_id: str | None + source_cache: Path + target_catalog: str + target_schema: str + + +@dataclass(frozen=True) +class ColumnUpdate: + table: str + column: str + new_comment: str + + +@dataclass(frozen=True) +class TableUpdate: + table: str + new_comment: str + + +@dataclass(frozen=True) +class SkippedColumn: + table: str + column: str + reason: str + + +@dataclass(frozen=True) +class FailedColumn: + table: str + column: str + error: str + + +@dataclass(frozen=True) +class RecoveryPlan: + catalog: str + schema: str + column_updates: list[ColumnUpdate] + table_updates: list[TableUpdate] + skipped_columns: list[SkippedColumn] + + +@dataclass(frozen=True) +class RecoveryReport: + columns_updated: int + columns_skipped: int + columns_failed: int + table_comments_updated: int + failed: list[FailedColumn] = field(default_factory=list) + skipped: list[SkippedColumn] = field(default_factory=list) + + +def build_recovery_plan( + ctx: RecoveryContext, + parsed: dict[str, ParsedTableComments], + live: dict[str, LiveTableComments], + *, + force: bool = False, +) -> RecoveryPlan: + column_updates: list[ColumnUpdate] = [] + table_updates: list[TableUpdate] = [] + skipped: list[SkippedColumn] = [] + for table in sorted(parsed): + parsed_t = parsed[table] + live_t = live.get(table) + if live_t is None: + for col in parsed_t.column_comments: + skipped.append(SkippedColumn(table, col, "table_not_found")) + continue + _plan_column_updates( + table, parsed_t, live_t, force, column_updates, skipped, + ) + _plan_table_update(table, parsed_t, live_t, force, table_updates) + return RecoveryPlan( + catalog=ctx.target_catalog, + schema=ctx.target_schema, + column_updates=column_updates, + table_updates=table_updates, + skipped_columns=skipped, + ) + + +def _plan_column_updates( + table: str, + parsed_t: ParsedTableComments, + live_t: LiveTableComments, + force: bool, + out_updates: list[ColumnUpdate], + out_skipped: 
list[SkippedColumn], +) -> None: + for column, new_text in parsed_t.column_comments.items(): + if not new_text: + continue + if column not in live_t.column_comments: + out_skipped.append(SkippedColumn(table, column, "column_not_found")) + continue + existing = live_t.column_comments.get(column) or "" + if existing and not force: + out_skipped.append( + SkippedColumn(table, column, "already_commented") + ) + continue + out_updates.append( + ColumnUpdate(table=table, column=column, new_comment=new_text) + ) + + +def _plan_table_update( + table: str, + parsed_t: ParsedTableComments, + live_t: LiveTableComments, + force: bool, + out_updates: list[TableUpdate], +) -> None: + new_text = parsed_t.table_comment + if not new_text: + return + existing = live_t.table_comment or "" + if existing and not force: + return + out_updates.append(TableUpdate(table=table, new_comment=new_text)) + + +def execute_recovery_plan( + plan: RecoveryPlan, + executor_fn: Callable[[str], None], + *, + dry_run: bool = False, +) -> RecoveryReport: + failed: list[FailedColumn] = [] + columns_updated = 0 + table_comments_updated = 0 + for upd in plan.column_updates: + sql = build_alter_column_comment_sql( + plan.catalog, plan.schema, upd.table, upd.column, upd.new_comment, + ) + if dry_run: + print(sql) + columns_updated += 1 + continue + try: + executor_fn(sql) + columns_updated += 1 + except Exception as err: # noqa: BLE001 + logger.warning( + "ALTER COLUMN failed for {}.{}: {}", upd.table, upd.column, err, + ) + failed.append( + FailedColumn(table=upd.table, column=upd.column, error=str(err)) + ) + for tupd in plan.table_updates: + sql = build_alter_table_comment_sql( + plan.catalog, plan.schema, tupd.table, tupd.new_comment, + ) + if dry_run: + print(sql) + table_comments_updated += 1 + continue + try: + executor_fn(sql) + table_comments_updated += 1 + except Exception as err: # noqa: BLE001 + logger.warning("COMMENT ON TABLE failed for {}: {}", tupd.table, err) + failed.append( + 
FailedColumn(table=tupd.table, column="", error=str(err)) + ) + return RecoveryReport( + columns_updated=columns_updated, + columns_skipped=len(plan.skipped_columns), + columns_failed=len(failed), + table_comments_updated=table_comments_updated, + failed=failed, + skipped=list(plan.skipped_columns), + ) + + +def read_databricks_comments( + catalog: str, schema: str, query_fn: QueryFn, +) -> dict[str, LiveTableComments]: + column_rows = query_fn( + "SELECT table_name, column_name, comment " + "FROM system.information_schema.columns " + "WHERE table_catalog = ? AND table_schema = ? " + "ORDER BY table_name, ordinal_position", + [catalog, schema], + ) + table_rows = query_fn( + "SELECT table_name, comment " + "FROM system.information_schema.tables " + "WHERE table_catalog = ? AND table_schema = ?", + [catalog, schema], + ) + columns_by_table: dict[str, dict[str, str | None]] = {} + for row in column_rows: + table, column, comment = row[0], row[1], row[2] + columns_by_table.setdefault(str(table), {})[str(column)] = ( + None if comment is None else str(comment) + ) + table_comments: dict[str, str | None] = { + str(row[0]): (None if row[1] is None else str(row[1])) + for row in table_rows + } + out: dict[str, LiveTableComments] = {} + table_names = set(columns_by_table) | set(table_comments) + for name in table_names: + out[name] = LiveTableComments( + table_comment=table_comments.get(name), + column_comments=columns_by_table.get(name, {}), + ) + return out + + +def resolve_recovery_context( + *, + study_id: str | None, + registry: StudyRegistry, + ingest_config: IngestConfig, + source_cache_override: Path | None, + target_catalog_override: str | None, + target_schema_override: str | None, +) -> RecoveryContext: + if ( + source_cache_override is not None + and target_catalog_override is not None + and target_schema_override is not None + ): + return RecoveryContext( + study_id=study_id, + source_cache=source_cache_override, + target_catalog=target_catalog_override, + 
target_schema=target_schema_override, + ) + if study_id is None: + raise PartialOverrideError( + "Recovery requires either --study or the full override set " + "(--source-cache, --target-catalog, --target-schema)." + ) + schema = ( + target_schema_override + if target_schema_override is not None + else registry.find_schema_for_study(study_id) + ) + if schema is None: + raise StudyNotRegisteredError( + f"Study {study_id!r} not in `_sema_study_registry`. " + f"Run `sema ingest cbioportal --study {study_id}` first, or pass " + "`--source-cache`, `--target-catalog`, and `--target-schema` " + "explicitly." + ) + cache_path = ( + source_cache_override + if source_cache_override is not None + else Path(ingest_config.cache_dir).expanduser() / study_id + ) + if not cache_path.exists(): + raise StudyCacheMissingError( + f"Source cache for study {study_id!r} not found at {cache_path}. " + "Re-fetch the study via `sema ingest cbioportal` or pass " + "`--source-cache `." + ) + catalog = ( + target_catalog_override + if target_catalog_override is not None + else ingest_config.databricks.catalog + ) + return RecoveryContext( + study_id=study_id, + source_cache=cache_path, + target_catalog=catalog, + target_schema=schema, + ) diff --git a/src/sema/ingest/databricks_push.py b/src/sema/ingest/databricks_push.py index dc9a433..a609b16 100644 --- a/src/sema/ingest/databricks_push.py +++ b/src/sema/ingest/databricks_push.py @@ -1,11 +1,13 @@ from __future__ import annotations +import tempfile from dataclasses import dataclass from pathlib import Path from typing import Any, Iterator from urllib.parse import urlparse from databricks import sql as databricks_sql +from databricks.sdk import WorkspaceClient from sema.ingest.databricks_push_utils import ( build_copy_into_sql, @@ -17,6 +19,7 @@ copy_into_staging_path, duckdb_to_databricks_type, format_sql_value, + is_uc_volume_path, should_route_via_copy_into, ) from sema.ingest.duckdb_staging import Staging @@ -65,13 +68,20 @@ def 
_open_connection(self) -> Any: except Exception as exc: raise ConnectionError(f"Failed to connect to Databricks: {exc}") from exc - def ensure_schemas(self) -> None: - for schema in self._schemas: + def ensure_schemas(self, schemas: list[str] | None = None) -> None: + targets = schemas if schemas is not None else self._resolve_targets(None, False) + for schema in targets: self._execute(build_create_schema_sql(self._catalog, schema)) - def push_schemas(self, schemas: list[str] | None = None) -> list[PushResult]: - self.ensure_schemas() - targets = schemas or self._schemas + def push_schemas( + self, + schemas: list[str] | None = None, + discover_all: bool = False, + ) -> list[PushResult]: + targets = self._resolve_targets(schemas, discover_all) + if not targets: + return [] + self.ensure_schemas(schemas=targets) results: list[PushResult] = [] failures: list[tuple[str, str, str]] = [] for schema in targets: @@ -80,6 +90,24 @@ def push_schemas(self, schemas: list[str] | None = None) -> list[PushResult]: raise PushError(failures) return results + def _resolve_targets( + self, explicit: list[str] | None, discover_all: bool + ) -> list[str]: + if explicit: + return list(dict.fromkeys(explicit)) + if discover_all: + discovered = self._staging.list_all_schemas() + logger.warning( + "--discover-all-schemas: pushing every non-system DuckDB schema: {}", + ", ".join(discovered), + ) + else: + discovered = self._staging.list_registered_schemas() + if self._schemas: + allowlist = set(self._schemas) + return [s for s in discovered if s in allowlist] + return discovered + def _push_schema_collect( self, schema: str, failures: list[tuple[str, str, str]] ) -> list[PushResult]: @@ -112,7 +140,9 @@ def push_table(self, schema: str, table: str) -> PushResult: return result def _dispatch_push(self, schema: str, table: str) -> tuple[str, int]: - if should_route_via_copy_into(schema, table) and self._cloud_uri: + row_count = self._count_source_rows(schema, table) + wants_copy = 
should_route_via_copy_into(schema, table, row_count) + if wants_copy and self._cloud_uri: try: return "copy_into", self._push_via_copy_into(schema, table) except Exception as exc: @@ -120,13 +150,20 @@ def _dispatch_push(self, schema: str, table: str) -> tuple[str, int]: "COPY INTO failed for {}.{}: {}; falling back to INSERT", schema, table, exc, ) - if should_route_via_copy_into(schema, table) and not self._cloud_uri: + if wants_copy and not self._cloud_uri: logger.warning( - "No cloud_staging_uri configured; {}.{} will be loaded via INSERT (slow).", - schema, table, + "No cloud_staging_uri configured; {}.{} ({} rows) will be loaded " + "via INSERT (slow).", + schema, table, row_count, ) return "insert", self._push_via_insert(schema, table) + def _count_source_rows(self, schema: str, table: str) -> int: + row = self._staging.execute( + f'SELECT COUNT(*) FROM "{schema}"."{table}"' + ).fetchone() + return int(row[0]) if row else 0 + def _recreate_target_table(self, schema: str, table: str) -> None: info = self._staging.describe(schema, table) self._execute(build_drop_table_sql(self._catalog, schema, table)) @@ -184,13 +221,15 @@ def _push_via_copy_into(self, schema: str, table: str) -> int: def _export_to_parquet(self, schema: str, table: str, staging_uri: str) -> int: target_dir = copy_into_staging_path(staging_uri, schema, table) + source = f'"{schema}"."{table}"' + if is_uc_volume_path(staging_uri): + return self._export_via_volume_upload(source, target_dir) local_dir = _local_path_for_uri(target_dir) if local_dir is not None: local_dir.mkdir(parents=True, exist_ok=True) duckdb_target = str(local_dir / "data.parquet") else: duckdb_target = target_dir.rstrip("/") + "/data.parquet" - source = f'"{schema}"."{table}"' escaped_target = duckdb_target.replace("'", "''") self._staging.execute( f"COPY (SELECT * FROM {source}) TO '{escaped_target}' (FORMAT 'parquet')" @@ -198,6 +237,29 @@ def _export_to_parquet(self, schema: str, table: str, staging_uri: str) -> int: row 
= self._staging.execute(f"SELECT COUNT(*) FROM {source}").fetchone() return int(row[0]) if row else 0 + def _export_via_volume_upload(self, source: str, target_dir: str) -> int: + with tempfile.TemporaryDirectory() as tmp: + local_parquet = Path(tmp) / "data.parquet" + escaped_local = str(local_parquet).replace("'", "''") + self._staging.execute( + f"COPY (SELECT * FROM {source}) TO '{escaped_local}' (FORMAT 'parquet')" + ) + row = self._staging.execute(f"SELECT COUNT(*) FROM {source}").fetchone() + row_count = int(row[0]) if row else 0 + volume_path = target_dir.rstrip("/") + "/data.parquet" + with open(local_parquet, "rb") as fh: + self._workspace_client().files.upload( + file_path=volume_path, + contents=fh, + overwrite=True, + ) + return row_count + + def _workspace_client(self) -> WorkspaceClient: + creds = self._config.databricks_creds + host = creds.host if creds.host.startswith("http") else f"https://{creds.host}" + return WorkspaceClient(host=host, token=creds.token.get_secret_value()) + def _count_target(self, schema: str, table: str) -> int: cursor = self._cursor() try: diff --git a/src/sema/ingest/databricks_push_utils.py b/src/sema/ingest/databricks_push_utils.py index 67136ba..eab0139 100644 --- a/src/sema/ingest/databricks_push_utils.py +++ b/src/sema/ingest/databricks_push_utils.py @@ -10,6 +10,8 @@ } ) +COPY_INTO_ROW_THRESHOLD = 100_000 + DUCKDB_TO_DATABRICKS_TYPE: dict[str, str] = { "INTEGER": "INT", "BIGINT": "BIGINT", @@ -25,6 +27,9 @@ } +_FORBIDDEN_IDENT_CHARS = (";", "`", "\n", "\r", "\x00") + + def back_quote(name: str) -> str: return "`" + name.replace("`", "``") + "`" @@ -37,14 +42,33 @@ def escape_sql_literal(value: str) -> str: return value.replace("'", "''") +def validate_identifier(name: str, kind: str) -> None: + if not name: + raise ValueError(f"{kind} identifier must be non-empty") + for ch in _FORBIDDEN_IDENT_CHARS: + if ch in name: + raise ValueError( + f"{kind} identifier contains forbidden character " + f"{ch!r}: {name!r}" + ) + 
+ def duckdb_to_databricks_type(duckdb_type: str) -> str: upper = duckdb_type.strip().upper() base = re.sub(r"\s*\([^)]+\)", "", upper) return DUCKDB_TO_DATABRICKS_TYPE.get(base, "STRING") -def should_route_via_copy_into(schema: str, table: str) -> bool: - return (schema.lower(), table.lower()) in COPY_INTO_TABLES +def should_route_via_copy_into( + schema: str, table: str, row_count: int = 0, +) -> bool: + if (schema.lower(), table.lower()) in COPY_INTO_TABLES: + return True + return row_count >= COPY_INTO_ROW_THRESHOLD + + +def is_uc_volume_path(uri: str) -> bool: + return uri.startswith("/Volumes/") def build_create_schema_sql(catalog: str, schema: str) -> str: @@ -110,3 +134,29 @@ def build_copy_into_sql(catalog: str, schema: str, table: str, staging_uri: str) def build_count_sql(catalog: str, schema: str, table: str) -> str: return f"SELECT COUNT(*) FROM {qualified(catalog, schema, table)}" + + +def build_alter_column_comment_sql( + catalog: str, schema: str, table: str, column: str, comment: str, +) -> str: + validate_identifier(catalog, "catalog") + validate_identifier(schema, "schema") + validate_identifier(table, "table") + validate_identifier(column, "column") + return ( + f"ALTER TABLE {qualified(catalog, schema, table)} " + f"ALTER COLUMN {back_quote(column)} " + f"COMMENT '{escape_sql_literal(comment)}'" + ) + + +def build_alter_table_comment_sql( + catalog: str, schema: str, table: str, comment: str, +) -> str: + validate_identifier(catalog, "catalog") + validate_identifier(schema, "schema") + validate_identifier(table, "table") + return ( + f"COMMENT ON TABLE {qualified(catalog, schema, table)} " + f"IS '{escape_sql_literal(comment)}'" + ) diff --git a/src/sema/ingest/duckdb_staging.py b/src/sema/ingest/duckdb_staging.py index de06bb3..527a0da 100644 --- a/src/sema/ingest/duckdb_staging.py +++ b/src/sema/ingest/duckdb_staging.py @@ -47,6 +47,30 @@ def list_schemas(self) -> list[str]: ).fetchall() return [r[0] for r in rows] + def 
list_registered_schemas(self) -> list[str]: + from sema.ingest.duckdb_staging_utils import KNOWN_SHARED_SCHEMAS + + names: set[str] = set(KNOWN_SHARED_SCHEMAS) + registry_exists = self._conn.execute( + "SELECT count(*) FROM duckdb_tables() " + "WHERE schema_name = '_sema' AND table_name = '_sema_study_registry'" + ).fetchone() + if registry_exists and registry_exists[0]: + rows = self._conn.execute( + 'SELECT schema_name FROM "_sema"."_sema_study_registry"' + ).fetchall() + for row in rows: + names.add(row[0]) + return sorted(names) + + def list_all_schemas(self) -> list[str]: + from sema.ingest.duckdb_staging_utils import SYSTEM_SCHEMAS + + rows = self._conn.execute( + "SELECT schema_name FROM information_schema.schemata" + ).fetchall() + return sorted(r[0] for r in rows if r[0] not in SYSTEM_SCHEMAS) + def drop_table(self, schema: str, table: str) -> None: self._conn.execute(f"DROP TABLE IF EXISTS {qualified(schema, table)}") diff --git a/src/sema/ingest/duckdb_staging_utils.py b/src/sema/ingest/duckdb_staging_utils.py index 5e6122f..4e61721 100644 --- a/src/sema/ingest/duckdb_staging_utils.py +++ b/src/sema/ingest/duckdb_staging_utils.py @@ -2,7 +2,13 @@ from pathlib import Path -DEFAULT_SCHEMAS: tuple[str, ...] = ("cbioportal", "ontology_omop", "vocabulary_omop") +DEFAULT_SCHEMAS: tuple[str, ...] = ("ontology_omop", "vocabulary_omop") + +KNOWN_SHARED_SCHEMAS: tuple[str, ...] 
= ("ontology_omop", "vocabulary_omop") + +SYSTEM_SCHEMAS: frozenset[str] = frozenset( + {"main", "information_schema", "pg_catalog", "_sema"} +) def resolve_db_path(raw: str) -> Path: diff --git a/src/sema/ingest/naming.py b/src/sema/ingest/naming.py new file mode 100644 index 0000000..a88d71f --- /dev/null +++ b/src/sema/ingest/naming.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import hashlib +import re + +_MAX_IDENT_LEN = 63 +_HASH_LEN = 10 +_NON_IDENT_CHAR = re.compile(r"[^a-z0-9_]") +_RUN_OF_UNDERSCORES = re.compile(r"_+") + + +def sanitize_schema_name(prefix: str, study_id: str) -> str: + if not study_id: + raise ValueError("study_id must be non-empty") + + sanitized = _sanitize_identifier(study_id) + if not sanitized: + raise ValueError(f"study_id {study_id!r} produces empty identifier after sanitization") + + full = f"{prefix}_{sanitized}" + if len(full) <= _MAX_IDENT_LEN: + return full + + return _truncate_with_hash(prefix, sanitized, study_id) + + +def _sanitize_identifier(raw: str) -> str: + lowered = raw.lower() + replaced = _NON_IDENT_CHAR.sub("_", lowered) + collapsed = _RUN_OF_UNDERSCORES.sub("_", replaced) + return collapsed.strip("_") + + +def _truncate_with_hash(prefix: str, sanitized: str, original: str) -> str: + digest = hashlib.sha256(original.encode("utf-8")).hexdigest()[:_HASH_LEN] + suffix = f"_{digest}" + available = _MAX_IDENT_LEN - len(prefix) - 1 - len(suffix) + if available <= 0: + raise ValueError( + f"prefix {prefix!r} leaves no room for sanitized study_id within {_MAX_IDENT_LEN} chars" + ) + truncated = sanitized[:available].rstrip("_") + if not truncated: + raise ValueError( + f"sanitized study_id collapses to empty after truncation for prefix {prefix!r}" + ) + return f"{prefix}_{truncated}{suffix}" diff --git a/src/sema/ingest/study_registry.py b/src/sema/ingest/study_registry.py new file mode 100644 index 0000000..dae9440 --- /dev/null +++ b/src/sema/ingest/study_registry.py @@ -0,0 +1,66 @@ +from __future__ import 
annotations + +from sema.ingest.duckdb_staging import Staging + +_REGISTRY_SCHEMA = "_sema" +_REGISTRY_TABLE = "_sema_study_registry" +_QUALIFIED = f'"{_REGISTRY_SCHEMA}"."{_REGISTRY_TABLE}"' + + +class StudyCollisionError(RuntimeError): + """Raised when two original study IDs sanitize to the same schema name.""" + + +class StudyRegistry: + def __init__(self, staging: Staging) -> None: + self._staging = staging + self._ensure_table() + + def _ensure_table(self) -> None: + self._staging.execute(f'CREATE SCHEMA IF NOT EXISTS "{_REGISTRY_SCHEMA}"') + self._staging.execute( + f""" + CREATE TABLE IF NOT EXISTS {_QUALIFIED} ( + schema_name TEXT PRIMARY KEY, + original_study_id TEXT NOT NULL, + source_type TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """ + ) + + def register(self, schema_name: str, original_study_id: str, source_type: str) -> None: + existing = self._lookup(schema_name) + if existing is not None: + if existing == original_study_id: + return + raise StudyCollisionError( + f"Schema name {schema_name!r} is already registered to study " + f"{existing!r}; cannot also register study {original_study_id!r}. " + f"Rename one of the studies or choose disambiguating IDs." 
+ ) + self._staging.execute( + f"INSERT INTO {_QUALIFIED} (schema_name, original_study_id, source_type) " + "VALUES (?, ?, ?)", + [schema_name, original_study_id, source_type], + ) + + def _lookup(self, schema_name: str) -> str | None: + row = self._staging.execute( + f"SELECT original_study_id FROM {_QUALIFIED} WHERE schema_name = ?", + [schema_name], + ).fetchone() + return row[0] if row else None + + def list_schemas(self) -> list[str]: + rows = self._staging.execute( + f"SELECT schema_name FROM {_QUALIFIED} ORDER BY schema_name" + ).fetchall() + return [r[0] for r in rows] + + def find_schema_for_study(self, study_id: str) -> str | None: + row = self._staging.execute( + f"SELECT schema_name FROM {_QUALIFIED} WHERE original_study_id = ?", + [study_id], + ).fetchone() + return row[0] if row else None diff --git a/src/sema/llm_client.py b/src/sema/llm_client.py index ddc1d33..bb6b29d 100644 --- a/src/sema/llm_client.py +++ b/src/sema/llm_client.py @@ -89,6 +89,9 @@ def __init__( TRANSIENT_STATUS_CODES: Final[frozenset[int]] = frozenset({429, 500, 502, 503, 504}) NON_RETRYABLE_STATUS_CODES: Final[frozenset[int]] = frozenset({401, 403}) +RATE_LIMIT_BASE_DELAY: Final[float] = 10.0 +RATE_LIMIT_MULTIPLIER: Final[float] = 3.0 + def _is_transient_error(exc: Exception) -> bool: status = getattr(exc, "status_code", None) or getattr(exc, "http_status", None) @@ -102,6 +105,23 @@ def _is_transient_error(exc: Exception) -> bool: return False +def _is_rate_limit_error(exc: Exception) -> bool: + status = getattr(exc, "status_code", None) or getattr(exc, "http_status", None) + if status is not None and int(status) == 429: + return True + msg = str(exc).lower() + return any( + kw in msg + for kw in ("rate limit", "429", "too many requests", "request_limit_exceeded") + ) + + +def _all_step_errors_rate_limited(stage_err: "LLMStageError") -> bool: + if not stage_err.step_errors: + return False + return all(_is_rate_limit_error(err) for _, err in stage_err.step_errors) + + def 
parse_llm_response(raw: str, schema: type[T]) -> T: """Universal fallback parser for raw LLM text responses.""" text = raw.strip() @@ -158,6 +178,8 @@ def __init__( retry_base_delay: float = 2.0, retry_multiplier: float = 2.0, retry_jitter: float = 0.5, + rate_limit_base_delay: float = RATE_LIMIT_BASE_DELAY, + rate_limit_multiplier: float = RATE_LIMIT_MULTIPLIER, use_structured_output: str = "auto", circuit_breaker: Any | None = None, ): @@ -166,6 +188,8 @@ def __init__( self._retry_base_delay = retry_base_delay self._retry_multiplier = retry_multiplier self._retry_jitter = retry_jitter + self._rate_limit_base_delay = rate_limit_base_delay + self._rate_limit_multiplier = rate_limit_multiplier self._circuit_breaker = circuit_breaker self._supports_structured = resolve_structured_support( @@ -200,8 +224,10 @@ def invoke( result = self._invoke_fallback_chain( prompt, schema, table_ref, stage_name, simplified_prompt, ) - except LLMStageError: - if self._circuit_breaker is not None: + except LLMStageError as stage_err: + if self._circuit_breaker is not None and not _all_step_errors_rate_limited( + stage_err, + ): self._circuit_breaker.record_failure() self._record_stats(start, prompt, "") raise @@ -330,7 +356,13 @@ def _invoke_with_retry(self, fn: Callable[..., Any], step_name: str = "") -> Any last_error = e if not _is_transient_error(e) or attempt == self._retry_max_attempts - 1: raise - delay = self._retry_base_delay * (self._retry_multiplier ** attempt) + if _is_rate_limit_error(e): + base = self._rate_limit_base_delay + mult = self._rate_limit_multiplier + else: + base = self._retry_base_delay + mult = self._retry_multiplier + delay = base * (mult ** attempt) jitter = random.uniform(-self._retry_jitter, self._retry_jitter) sleep_time = max(0, delay + jitter) logger.debug( diff --git a/src/sema/models/assertions.py b/src/sema/models/assertions.py index 4df0b8b..332a511 100644 --- a/src/sema/models/assertions.py +++ b/src/sema/models/assertions.py @@ -37,6 +37,7 @@ 
class AssertionPredicate(str, Enum): HAS_JOIN_EVIDENCE = "has_join_evidence" ENTITY_ON_TABLE = "entity_on_table" PROPERTY_ON_COLUMN = "property_on_column" + FK_TO = "fk_to" class AssertionStatus(str, Enum): @@ -64,3 +65,4 @@ class Assertion(BaseModel): status: AssertionStatus = AssertionStatus.AUTO run_id: str observed_at: datetime + source_schema: str | None = None diff --git a/src/sema/models/config.py b/src/sema/models/config.py index 720be7b..7d0f060 100644 --- a/src/sema/models/config.py +++ b/src/sema/models/config.py @@ -101,6 +101,9 @@ class BuildConfig(BaseSettings): enable_few_shot: bool = True enable_stage_c: bool = True + enable_fk_detection: bool = True + materialize_structural_fk: bool = False + eval_dump_dir: str | None = None eval_config_label: str = "run" slice_tables: list[str] = [] @@ -120,6 +123,10 @@ class BuildConfig(BaseSettings): embedding: EmbeddingConfig = Field(default_factory=EmbeddingConfig) profiling: ProfilingConfig = Field(default_factory=ProfilingConfig) + @property + def fk_materialization_threshold(self) -> float: + return 0.70 if self.materialize_structural_fk else 0.80 + @classmethod def from_file(cls, path: str, overrides: dict[str, Any] | None = None) -> BuildConfig: with open(path) as f: @@ -133,9 +140,7 @@ class IngestDatabricksTargetConfig(BaseSettings): model_config = SettingsConfigDict(env_prefix="INGEST_DATABRICKS_") catalog: str = "workspace" - schemas: list[str] = Field( - default_factory=lambda: ["cbioportal", "ontology_omop", "vocabulary_omop"] - ) + schemas: list[str] = Field(default_factory=list) class IngestOmopConfig(BaseSettings): diff --git a/src/sema/pipeline/build_utils.py b/src/sema/pipeline/build_utils.py index 84f64b3..b46f834 100644 --- a/src/sema/pipeline/build_utils.py +++ b/src/sema/pipeline/build_utils.py @@ -392,12 +392,16 @@ def _commit_and_materialize( f"[{work_item.table_name}] Committing " f"{len(all_assertions)} assertions..." 
) - loader.commit_table_assertions(all_assertions) + loader.commit_table_assertions( + all_assertions, source_schema=work_item.schema, + ) logger.info( f"[{work_item.table_name}] Materializing graph..." ) from sema.graph.materializer import materialize_unified - materialize_unified(loader, all_assertions) + materialize_unified( + loader, all_assertions, source_schema=work_item.schema, + ) logger.info( f"[{work_item.table_name}] Done" ) diff --git a/src/sema/pipeline/orchestrate.py b/src/sema/pipeline/orchestrate.py index 4b737ca..48a9418 100644 --- a/src/sema/pipeline/orchestrate.py +++ b/src/sema/pipeline/orchestrate.py @@ -25,6 +25,7 @@ _retrieve_context, _spawn_workers, _spawn_workers_parallel, + run_fk_detection, ) @@ -61,6 +62,9 @@ def run_build(config: BuildConfig) -> dict[str, Any]: driver.close() return aggregate_report([]) + for schema in sorted({wi.schema for wi in work_items}): + loader.delete_study_scoped(schema) + circuit_breaker = CircuitBreaker( failure_threshold=config.circuit_breaker_threshold, recovery_timeout=float(config.circuit_breaker_timeout), @@ -108,6 +112,11 @@ def run_build(config: BuildConfig) -> dict[str, Any]: report = _collect_results(results) + schemas = sorted({wi.schema for wi in work_items}) + run_fk_detection( + loader, discovery_connector, config, schemas, run_id=run_id, + ) + _compute_embeddings(config, loader, skip_embeddings=config.skip_embeddings) driver.close() diff --git a/src/sema/pipeline/orchestrate_utils.py b/src/sema/pipeline/orchestrate_utils.py index c612dab..aa2e1ae 100644 --- a/src/sema/pipeline/orchestrate_utils.py +++ b/src/sema/pipeline/orchestrate_utils.py @@ -15,6 +15,16 @@ _get_embedder, _get_neo4j_driver, ) +from sema.engine.join_detector import JoinDetector, to_fk_assertion +from sema.engine.join_detector_utils import ( + enumerate_candidates_from_metadata, +) +from sema.engine.warehouse_lookup import ( + WarehouseProfileLookup, + WarehouseSampler, +) +from sema.graph.join_materializer import 
materialize_join_paths +from sema.graph.loader_utils import fetch_columns_by_schema from sema.models.config import ( BuildConfig, QueryConfig, @@ -297,6 +307,74 @@ def _embed_label_nodes( ) +def run_fk_detection( + loader: Any, + connector: Any, + config: BuildConfig, + schemas: list[str], + run_id: str, +) -> None: + """Detect FK candidates per schema and materialize JoinPaths. + + Skips entirely when `enable_fk_detection=False`. Threshold derives + from `materialize_structural_fk` (0.70 if opt-in, else 0.80) so + Tier-3 structural-only matches stay gated by an explicit user flag. + The detector receives a lazy `sampler` (Tier 1 sample-set + verification) and a pre-built `profiles` dict for candidate columns + (Tier 2 cardinality verification). + """ + if not config.enable_fk_detection: + return + detector = JoinDetector( + materialization_threshold=config.fk_materialization_threshold, + ) + sampler = WarehouseSampler( + query_fn=connector._execute, catalog=config.catalog, + ) + profile_lookup = WarehouseProfileLookup( + query_fn=connector._execute, catalog=config.catalog, + ) + for schema in schemas: + columns = fetch_columns_by_schema(loader, schema) + if not columns: + continue + profiles = _prebuild_profiles_for_candidates( + columns, profile_lookup, + ) + fks = detector.detect( + columns=columns, source_schema=schema, + profiles=profiles, sampler=sampler, + ) + keep = [fk for fk in fks if detector.should_materialize(fk)] + if not keep: + continue + assertions = [to_fk_assertion(fk, run_id) for fk in keep] + groups = { + (a.subject_ref, a.predicate.value): [a] + for a in assertions + } + materialize_join_paths( + loader, groups, source_schema=schema, + ) + + +def _prebuild_profiles_for_candidates( + columns: list[Any], + profile_lookup: Any, +) -> dict[tuple[str, str, str], tuple[int, int]]: + candidates = enumerate_candidates_from_metadata(columns) + keys: set[tuple[str, str, str]] = set() + for c in candidates: + keys.add((c.schema_name, c.pk_table, 
c.pk_column)) + keys.add((c.schema_name, c.fk_table, c.fk_column)) + profiles: dict[tuple[str, str, str], tuple[int, int]] = {} + for key in keys: + stats = profile_lookup(key) + if stats is not None: + profiles[key] = stats + return profiles + + def _retrieve_context(config: QueryConfig) -> Any: driver = _get_neo4j_driver(config.neo4j) diff --git a/tests/integration/test_join_path_source_schema.py b/tests/integration/test_join_path_source_schema.py new file mode 100644 index 0000000..a386619 --- /dev/null +++ b/tests/integration/test_join_path_source_schema.py @@ -0,0 +1,179 @@ +"""Integration test 12.15: patient↔sample JoinPath wired by {name, source_schema}. + +Asserts that materialize_join_paths produces a `:JoinPath` whose `:USES`, +`:FROM_ENTITY`, and `:TO_ENTITY` edges all match the JoinPath by the full +{name, source_schema} match key — guaranteeing two studies emitting the +same logical join produce two distinct paths. +""" +from __future__ import annotations + +from datetime import datetime, timezone + +import pytest + +from sema.graph.join_materializer import materialize_join_paths +from sema.graph.loader import GraphLoader +from sema.models.assertions import ( + Assertion, + AssertionPredicate, + AssertionStatus, +) + +pytestmark = pytest.mark.integration + +CATALOG = "workspace" +SCHEMA_A = "cbioportal_brca_tcga_pan_can_atlas_2018" +SCHEMA_B = "cbioportal_msk_chord_2024" + + +@pytest.fixture +def loader(clean_neo4j): + return GraphLoader(clean_neo4j) + + +def _ref(catalog: str, schema: str, table: str, column: str | None = None) -> str: + base = f"databricks://workspace/{catalog}/{schema}/{table}" + return f"{base}/{column}" if column else base + + +def _seed_physical(loader: GraphLoader, schema: str) -> None: + loader.upsert_table("patient", schema, CATALOG) + loader.upsert_table("sample", schema, CATALOG) + loader.upsert_column("patient_id", "patient", schema, CATALOG, "STRING") + loader.upsert_column("patient_id", "sample", schema, CATALOG, 
"STRING") + loader._run( + "MATCH (t:Table {name: 'patient', schema_name: $s, catalog: $c}) " + "SET t.ref = $ref", + s=schema, c=CATALOG, ref=_ref(CATALOG, schema, "patient"), + ) + loader._run( + "MATCH (t:Table {name: 'sample', schema_name: $s, catalog: $c}) " + "SET t.ref = $ref", + s=schema, c=CATALOG, ref=_ref(CATALOG, schema, "sample"), + ) + loader._run( + "MATCH (c:Column {name: 'patient_id', table_name: $t, " + "schema_name: $s, catalog: $cat}) SET c.ref = $ref", + t="patient", s=schema, cat=CATALOG, + ref=_ref(CATALOG, schema, "patient", "patient_id"), + ) + loader._run( + "MATCH (c:Column {name: 'patient_id', table_name: $t, " + "schema_name: $s, catalog: $cat}) SET c.ref = $ref", + t="sample", s=schema, cat=CATALOG, + ref=_ref(CATALOG, schema, "sample", "patient_id"), + ) + loader.upsert_entity( + name="Patient", description=None, source="test", confidence=0.9, + table_name="patient", schema_name=schema, catalog=CATALOG, + source_schema=schema, + ) + loader.upsert_entity( + name="Sample", description=None, source="test", confidence=0.9, + table_name="sample", schema_name=schema, catalog=CATALOG, + source_schema=schema, + ) + + +def _join_assertion(schema: str) -> Assertion: + return Assertion( + id=f"j-{schema}", + subject_ref=f"join_evidence://{schema}/patient_sample", + predicate=AssertionPredicate.HAS_JOIN_EVIDENCE, + payload={ + "join_predicates": [{ + "left_table": _ref(CATALOG, schema, "sample"), + "left_column": _ref(CATALOG, schema, "sample", "patient_id"), + "right_table": _ref(CATALOG, schema, "patient"), + "right_column": _ref(CATALOG, schema, "patient", "patient_id"), + }], + "hop_count": 1, + "from_table": _ref(CATALOG, schema, "sample"), + "to_table": _ref(CATALOG, schema, "patient"), + }, + source="fk_detector", + confidence=0.95, + status=AssertionStatus.AUTO, + run_id="run-1", + observed_at=datetime(2026, 1, 1, tzinfo=timezone.utc), + ) + + +def _materialize(loader: GraphLoader, assertion: Assertion, schema: str) -> None: + groups = 
{(assertion.subject_ref, assertion.predicate.value): [assertion]} + materialize_join_paths(loader, groups, source_schema=schema) + + +class TestJoinPathSourceSchemaScoping: + def test_two_studies_produce_two_distinct_join_paths( + self, loader, clean_neo4j, + ): + _seed_physical(loader, SCHEMA_A) + _seed_physical(loader, SCHEMA_B) + _materialize(loader, _join_assertion(SCHEMA_A), SCHEMA_A) + _materialize(loader, _join_assertion(SCHEMA_B), SCHEMA_B) + + with clean_neo4j.session() as s: + jp_count = s.run( + "MATCH (jp:JoinPath) RETURN count(jp) AS c" + ).single()["c"] + schemas = sorted(r["s"] for r in s.run( + "MATCH (jp:JoinPath) RETURN jp.source_schema AS s" + )) + assert jp_count == 2 + assert schemas == [SCHEMA_A, SCHEMA_B] + + def test_uses_and_entity_edges_scoped_by_source_schema( + self, loader, clean_neo4j, + ): + _seed_physical(loader, SCHEMA_A) + _materialize(loader, _join_assertion(SCHEMA_A), SCHEMA_A) + + with clean_neo4j.session() as s: + uses = s.run( + "MATCH (jp:JoinPath {source_schema: $sc})" + "-[u:USES]->(t:Table) RETURN count(u) AS c, " + "collect(DISTINCT u.source_schema) AS ss", + sc=SCHEMA_A, + ).single() + from_entity = s.run( + "MATCH (jp:JoinPath {source_schema: $sc})" + "-[r:FROM_ENTITY]->(:Entity {name: 'Sample'}) " + "RETURN count(r) AS c, " + "collect(DISTINCT r.source_schema) AS ss", + sc=SCHEMA_A, + ).single() + to_entity = s.run( + "MATCH (jp:JoinPath {source_schema: $sc})" + "-[r:TO_ENTITY]->(:Entity {name: 'Patient'}) " + "RETURN count(r) AS c, " + "collect(DISTINCT r.source_schema) AS ss", + sc=SCHEMA_A, + ).single() + assert uses["c"] >= 2 + assert uses["ss"] == [SCHEMA_A] + assert from_entity["c"] == 1 + assert from_entity["ss"] == [SCHEMA_A] + assert to_entity["c"] == 1 + assert to_entity["ss"] == [SCHEMA_A] + + def test_scoped_delete_removes_only_one_studys_join_path( + self, loader, clean_neo4j, + ): + _seed_physical(loader, SCHEMA_A) + _seed_physical(loader, SCHEMA_B) + _materialize(loader, _join_assertion(SCHEMA_A), 
SCHEMA_A) + _materialize(loader, _join_assertion(SCHEMA_B), SCHEMA_B) + + loader.delete_study_scoped(SCHEMA_A) + + with clean_neo4j.session() as s: + remaining = sorted(r["s"] for r in s.run( + "MATCH (jp:JoinPath) RETURN jp.source_schema AS s" + )) + uses_left = s.run( + "MATCH (:JoinPath)-[u:USES]->() RETURN count(u) AS c, " + "collect(DISTINCT u.source_schema) AS ss" + ).single() + assert remaining == [SCHEMA_B] + assert uses_left["ss"] == [SCHEMA_B] diff --git a/tests/integration/test_multi_study_lifecycle.py b/tests/integration/test_multi_study_lifecycle.py new file mode 100644 index 0000000..19bef2d --- /dev/null +++ b/tests/integration/test_multi_study_lifecycle.py @@ -0,0 +1,268 @@ +"""Integration tests for multi-study graph lifecycle. + +Covers tasks 2.4c, 3.8, 4.6, 4.7, 4.8 of expand-healthcare-eval-coverage: +- ontology-preloaded edges survive scoped-delete (2.4c) +- no :Entity carries `datasource_id` property after a build (3.8) +- A→B→A reload yields unchanged A counts and untouched B (4.6) +- shared `:Term {code: HGNC:TP53}` survives a single-study rebuild (4.7) +- prior edges not re-emitted by the rebuild are removed (4.8) +""" +from __future__ import annotations + +from datetime import datetime, timezone + +import pytest + +from sema.graph.loader import GraphLoader +from sema.models.assertions import ( + Assertion, + AssertionPredicate, + AssertionStatus, +) + +pytestmark = pytest.mark.integration + +SCHEMA_A = "cbioportal_brca_tcga_pan_can_atlas_2018" +SCHEMA_B = "cbioportal_msk_chord_2024" + + +@pytest.fixture +def loader(clean_neo4j): + return GraphLoader(clean_neo4j) + + +def _count(driver, cypher: str, **params) -> int: + with driver.session() as s: + rec = s.run(cypher, **params).single() + return int(rec["c"]) if rec else 0 + + +def _make_assertion(subject_ref: str, run_id: str) -> Assertion: + return Assertion( + id=f"a-{abs(hash(subject_ref + run_id)) % 100000}", + subject_ref=subject_ref, + predicate=AssertionPredicate.HAS_LABEL, + 
payload={"value": "v"}, + source="test", + confidence=0.9, + status=AssertionStatus.AUTO, + run_id=run_id, + observed_at=datetime(2026, 1, 1, tzinfo=timezone.utc), + ) + + +def _seed_study( + loader: GraphLoader, + schema: str, + table: str, + *, + entity: str, + property_name: str, + term_code: str, + run_id: str, +) -> None: + catalog = "workspace" + loader.upsert_table(table, schema, catalog) + loader.upsert_column(property_name, table, schema, catalog, data_type="STRING") + loader.upsert_entity( + name=entity, description=None, source="test", confidence=0.9, + table_name=table, schema_name=schema, catalog=catalog, + source_schema=schema, + ) + loader.upsert_property( + name=property_name, semantic_type="genomic_id", + source="test", confidence=0.9, + entity_name=entity, column_name=property_name, + table_name=table, schema_name=schema, catalog=catalog, + source_schema=schema, + ) + loader.upsert_term( + code=term_code, label=term_code, source="test", confidence=0.95, + ) + loader.upsert_value_set( + name=property_name, column_name=property_name, table_name=table, + schema_name=schema, catalog=catalog, source_schema=schema, + ) + loader.add_term_to_value_set( + term_code=term_code, value_set_name=property_name, + source_schema=schema, + ) + a = _make_assertion( + subject_ref=f"databricks://{catalog}/{schema}/{table}", run_id=run_id, + ) + loader.store_assertion(a, source_schema=schema) + + +def _add_ontology_parent_of(driver) -> None: + """Pre-load a hierarchy edge with no source_schema (ontology-global).""" + with driver.session() as s: + s.run( + "MERGE (p:Term {code: 'HGNC:TP53'}) " + "MERGE (c:Term {code: 'HGNC:TP53_VARIANT'}) " + "MERGE (p)-[:PARENT_OF]->(c)" + ) + + +class TestOntologyEdgeSurvivesScopedDelete: + def test_preloaded_parent_of_not_swept(self, loader, clean_neo4j): + _seed_study( + loader, SCHEMA_A, "patient", + entity="Patient", property_name="hugo_symbol", + term_code="HGNC:TP53", run_id="run-A", + ) + _add_ontology_parent_of(clean_neo4j) + 
+ before = _count( + clean_neo4j, + "MATCH ()-[r:PARENT_OF]->() WHERE r.source_schema IS NULL " + "RETURN count(r) AS c", + ) + assert before == 1 + + loader.delete_study_scoped(SCHEMA_A) + + after = _count( + clean_neo4j, + "MATCH ()-[r:PARENT_OF]->() WHERE r.source_schema IS NULL " + "RETURN count(r) AS c", + ) + assert after == 1, "ontology-preloaded PARENT_OF must survive sweep" + + +class TestEntityHasNoDatasourceId: + def test_no_entity_carries_datasource_id_after_build(self, loader, clean_neo4j): + _seed_study( + loader, SCHEMA_A, "patient", + entity="Patient", property_name="hugo_symbol", + term_code="HGNC:TP53", run_id="run-A", + ) + _seed_study( + loader, SCHEMA_B, "patient", + entity="Patient", property_name="hugo_symbol", + term_code="HGNC:TP53", run_id="run-B", + ) + with_ds = _count( + clean_neo4j, + "MATCH (e:Entity) WHERE e.datasource_id IS NOT NULL " + "RETURN count(e) AS c", + ) + assert with_ds == 0 + + +class TestReloadIdempotence: + def test_reload_a_leaves_a_count_unchanged_and_b_untouched( + self, loader, clean_neo4j, + ): + _seed_study( + loader, SCHEMA_A, "patient", + entity="Patient", property_name="hugo_symbol", + term_code="HGNC:TP53", run_id="run-A1", + ) + _seed_study( + loader, SCHEMA_B, "patient", + entity="Patient", property_name="hugo_symbol", + term_code="HGNC:TP53", run_id="run-B1", + ) + a_first = _count( + clean_neo4j, + "MATCH (a:Assertion {source_schema: $s}) RETURN count(a) AS c", + s=SCHEMA_A, + ) + b_first = _count( + clean_neo4j, + "MATCH (a:Assertion {source_schema: $s}) RETURN count(a) AS c", + s=SCHEMA_B, + ) + assert a_first > 0 and b_first > 0 + + loader.delete_study_scoped(SCHEMA_A) + _seed_study( + loader, SCHEMA_A, "patient", + entity="Patient", property_name="hugo_symbol", + term_code="HGNC:TP53", run_id="run-A2", + ) + + a_third = _count( + clean_neo4j, + "MATCH (a:Assertion {source_schema: $s}) RETURN count(a) AS c", + s=SCHEMA_A, + ) + b_third = _count( + clean_neo4j, + "MATCH (a:Assertion {source_schema: $s}) 
RETURN count(a) AS c", + s=SCHEMA_B, + ) + assert a_third == a_first + assert b_third == b_first + + +class TestSharedTermSurvivesRebuild: + def test_tp53_term_remains_after_one_study_rebuild( + self, loader, clean_neo4j, + ): + _seed_study( + loader, SCHEMA_A, "patient", + entity="Patient", property_name="hugo_symbol", + term_code="HGNC:TP53", run_id="run-A", + ) + _seed_study( + loader, SCHEMA_B, "patient", + entity="Patient", property_name="hugo_symbol", + term_code="HGNC:TP53", run_id="run-B", + ) + before_a = _count( + clean_neo4j, + "MATCH (t:Term {code: 'HGNC:TP53'})-[m:MEMBER_OF " + "{source_schema: $s}]->(:ValueSet) RETURN count(m) AS c", + s=SCHEMA_A, + ) + assert before_a > 0 + + loader.delete_study_scoped(SCHEMA_A) + + term_count = _count( + clean_neo4j, + "MATCH (t:Term {code: 'HGNC:TP53'}) RETURN count(t) AS c", + ) + assert term_count == 1 + + a_after = _count( + clean_neo4j, + "MATCH (t:Term {code: 'HGNC:TP53'})-[m:MEMBER_OF " + "{source_schema: $s}]->(:ValueSet) RETURN count(m) AS c", + s=SCHEMA_A, + ) + assert a_after == 0 + + b_after = _count( + clean_neo4j, + "MATCH (t:Term {code: 'HGNC:TP53'})-[m:MEMBER_OF " + "{source_schema: $s}]->(:ValueSet) RETURN count(m) AS c", + s=SCHEMA_B, + ) + assert b_after >= 1 + + +class TestPriorEdgesRemovedOnRebuild: + def test_prior_edges_not_emitted_in_rebuild_are_removed( + self, loader, clean_neo4j, + ): + _seed_study( + loader, SCHEMA_A, "patient", + entity="Patient", property_name="hugo_symbol", + term_code="HGNC:TP53", run_id="run-A1", + ) + first = _count( + clean_neo4j, + "MATCH ()-[r {source_schema: $s}]-() RETURN count(r) AS c", + s=SCHEMA_A, + ) + assert first > 0 + + loader.delete_study_scoped(SCHEMA_A) + after_delete = _count( + clean_neo4j, + "MATCH ()-[r {source_schema: $s}]-() RETURN count(r) AS c", + s=SCHEMA_A, + ) + assert after_delete == 0 diff --git a/tests/scripts/__init__.py b/tests/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git 
a/tests/scripts/test_migrate_cbioportal.py b/tests/scripts/test_migrate_cbioportal.py new file mode 100644 index 0000000..3b11619 --- /dev/null +++ b/tests/scripts/test_migrate_cbioportal.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +from pathlib import Path +from unittest.mock import patch + +import pytest +from click.testing import CliRunner + +from scripts.migrate_cbioportal_to_namespaced import _rename_schema, main +from sema.ingest.comment_recovery import ParsedTableComments +from sema.ingest.duckdb_staging import Staging + +pytestmark = pytest.mark.unit + + +def _staging_with_legacy(tmp_path: Path) -> Staging: + s = Staging(str(tmp_path / "stg.duckdb")) + s.execute('CREATE SCHEMA IF NOT EXISTS "cbioportal"') + s.execute( + 'CREATE TABLE "cbioportal"."patient" ' + '(PATIENT_ID VARCHAR, AGE INTEGER)' + ) + s.execute("INSERT INTO \"cbioportal\".\"patient\" VALUES ('P-001', 42)") + return s + + +def test_rename_schema_reapplies_column_and_table_comments(tmp_path: Path) -> None: + staging = _staging_with_legacy(tmp_path) + + def comment_source(table: str) -> ParsedTableComments: + if table == "patient": + return ParsedTableComments( + table_comment="patient table comment", + column_comments={ + "PATIENT_ID": "Patient identifier.", + "AGE": "Age at diagnosis.", + }, + ) + return ParsedTableComments(table_comment=None, column_comments={}) + + _rename_schema( + staging, "cbioportal", "cbioportal_x", comment_source=comment_source, + ) + info = staging.describe("cbioportal_x", "patient") + assert info.columns["PATIENT_ID"].comment == "Patient identifier." + assert info.columns["AGE"].comment == "Age at diagnosis." 
+ assert info.table_comment == "patient table comment" + staging.close() + + +def test_rename_schema_without_comment_source_completes(tmp_path: Path) -> None: + staging = _staging_with_legacy(tmp_path) + _rename_schema(staging, "cbioportal", "cbioportal_x") + info = staging.describe("cbioportal_x", "patient") + assert "PATIENT_ID" in info.columns + staging.close() + + +def test_rename_schema_warns_when_comment_source_raises( + tmp_path: Path, caplog: pytest.LogCaptureFixture, +) -> None: + staging = _staging_with_legacy(tmp_path) + + def failing_source(table: str) -> ParsedTableComments: + raise FileNotFoundError(f"no cache for {table}") + + _rename_schema( + staging, "cbioportal", "cbioportal_x", + comment_source=failing_source, + ) + info = staging.describe("cbioportal_x", "patient") + assert "PATIENT_ID" in info.columns + assert info.columns["PATIENT_ID"].comment in (None, "") + staging.close() + + +def test_migration_idempotent_when_target_already_exists(tmp_path: Path) -> None: + duckdb_path = tmp_path / "stg.duckdb" + staging = Staging(str(duckdb_path)) + staging.execute('CREATE SCHEMA IF NOT EXISTS "cbioportal_x"') + staging.execute( + 'CREATE TABLE "cbioportal_x"."patient" (PATIENT_ID VARCHAR)' + ) + staging.close() + + runner = CliRunner() + cache_root = tmp_path / "cache" + (cache_root / "study_x").mkdir(parents=True) + with patch.dict("os.environ", { + "INGEST_DUCKDB_PATH": str(duckdb_path), + "INGEST_CACHE_DIR": str(cache_root), + }): + result = runner.invoke( + main, + ["--duckdb-path", str(duckdb_path), "--study-id", "study_x"], + ) + assert result.exit_code == 0, result.output + + +def test_migration_with_full_flow_preserves_comments(tmp_path: Path) -> None: + duckdb_path = tmp_path / "stg.duckdb" + staging = Staging(str(duckdb_path)) + staging.execute('CREATE SCHEMA IF NOT EXISTS "cbioportal"') + staging.execute( + 'CREATE TABLE "cbioportal"."patient" (PATIENT_ID VARCHAR)' + ) + staging.close() + + cache_root = tmp_path / "cache" + (cache_root / 
"study_x").mkdir(parents=True) + (cache_root / "study_x" / "data_clinical_patient.txt").write_text( + "#Patient Identifier\n" + "#Identifier to uniquely specify a patient.\n" + "#STRING\n" + "#1\n" + "PATIENT_ID\n" + "P-001\n", + encoding="utf-8", + ) + + runner = CliRunner() + with patch.dict("os.environ", { + "INGEST_DUCKDB_PATH": str(duckdb_path), + "INGEST_CACHE_DIR": str(cache_root), + }): + result = runner.invoke( + main, + ["--duckdb-path", str(duckdb_path), "--study-id", "study_x"], + ) + assert result.exit_code == 0, result.output + + staging = Staging(str(duckdb_path)) + info = staging.describe("cbioportal_study_x", "patient") + assert ( + info.columns["PATIENT_ID"].comment + == "Identifier to uniquely specify a patient." + ) + staging.close() diff --git a/tests/showcase/cbioportal_to_omop/test_comment_extract.py b/tests/showcase/cbioportal_to_omop/test_comment_extract.py new file mode 100644 index 0000000..bccf5b7 --- /dev/null +++ b/tests/showcase/cbioportal_to_omop/test_comment_extract.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +from showcase.cbioportal_to_omop.comment_extract import extract_study_comments + +pytestmark = pytest.mark.unit + + +def _write(path: Path, content: str) -> None: + path.write_text(content, encoding="utf-8") + + +def test_extract_clinical_patient_threads_column_and_table_comments( + tmp_path: Path, +) -> None: + _write( + tmp_path / "data_clinical_patient.txt", + "#Patient Identifier\tAge at Diagnosis\n" + "#Identifier to uniquely specify a patient.\tAge at first diagnosis.\n" + "#STRING\tNUMBER\n" + "#1\t1\n" + "PATIENT_ID\tAGE\n" + "P-001\t42\n", + ) + out = extract_study_comments(tmp_path) + assert "patient" in out + patient = out["patient"] + assert ( + patient.column_comments["PATIENT_ID"] + == "Identifier to uniquely specify a patient." + ) + assert patient.column_comments["AGE"] == "Age at first diagnosis." 
+ assert patient.table_comment == ( + "cBioPortal clinical patient from data_clinical_patient.txt" + ) + + +def test_extract_clinical_supp_files(tmp_path: Path) -> None: + _write( + tmp_path / "data_clinical_supp_hypoxia.txt", + "#Hypoxia Score\n" + "#Buffa hypoxia score.\n" + "#NUMBER\n" + "#1\n" + "BUFFA_HYPOXIA_SCORE\n" + "0.5\n", + ) + out = extract_study_comments(tmp_path) + assert "clinical_supp_hypoxia" in out + assert ( + out["clinical_supp_hypoxia"].column_comments["BUFFA_HYPOXIA_SCORE"] + == "Buffa hypoxia score." + ) + + +def test_extract_non_clinical_table_has_empty_column_comments( + tmp_path: Path, +) -> None: + _write( + tmp_path / "data_mutations.txt", + "Hugo_Symbol\tStart_Position\n" "TP53\t7571720\n", + ) + out = extract_study_comments(tmp_path) + assert "mutation" in out + assert out["mutation"].column_comments == {} + assert out["mutation"].table_comment is not None + assert "MAF" in out["mutation"].table_comment + + +def test_extract_returns_empty_dict_for_empty_dir(tmp_path: Path) -> None: + assert extract_study_comments(tmp_path) == {} + + +def test_extract_timeline_table(tmp_path: Path) -> None: + _write( + tmp_path / "data_timeline_treatment.txt", + "PATIENT_ID\tSTART_DATE\n" "P-001\t0\n", + ) + out = extract_study_comments(tmp_path) + assert "timeline_treatment" in out + assert out["timeline_treatment"].column_comments == {} diff --git a/tests/showcase/cbioportal_to_omop/test_extended_parsers.py b/tests/showcase/cbioportal_to_omop/test_extended_parsers.py index 067ae4c..50b6155 100644 --- a/tests/showcase/cbioportal_to_omop/test_extended_parsers.py +++ b/tests/showcase/cbioportal_to_omop/test_extended_parsers.py @@ -110,9 +110,9 @@ def test_parses_sample_to_panel_assignments(self, tmp_path: Path) -> None: "SAMPLE-2\tIMPACT410\tIMPACT410\tIMPACT410\n", ) rows, types, _ = parse_gene_panel_matrix(path) - assert rows.num_rows == 2 - assert "SAMPLE_ID" in rows.column_names - assert types["mutations"] == "VARCHAR" + assert 
set(rows.column_names) == {"sample_id", "panel_id", "assay"} + assert rows.num_rows == 6 # 2 samples × 3 assays + assert types["assay"] == "VARCHAR" class TestParseResourceFile: @@ -192,7 +192,7 @@ def test_ingests_sv_cna_panel_matrix_and_resources( db_path=str(tmp_path / "db.duckdb"), schemas=("cbioportal",), ) - _ingest_study_dir("test_study", study_dir, staging) + _ingest_study_dir("test_study", study_dir, staging, schema_name="cbioportal") for tbl in ( "structural_variant", "cna", "gene_panel_matrix", "resource_definition", "resource_patient", @@ -200,3 +200,35 @@ def test_ingests_sv_cna_panel_matrix_and_resources( info = staging.describe("cbioportal", tbl) assert info.columns, f"{tbl} should have columns" staging.close() + + +class TestParseMafCaseInsensitiveDedupe: + def test_duplicate_columns_are_suffixed(self, tmp_path: Path) -> None: + from showcase.cbioportal_to_omop.parsers import parse_maf + + path = _write( + tmp_path / "data_mutations.txt", + "Hugo_Symbol\tComments\tcomments\tCOMMENTS\tcDNA_Change\tcdna_change\n" + "TP53\ta\tb\tc\tCGA\tcga\n", + ) + rows, types, _ = parse_maf(path) + cols = list(types.keys()) + assert cols.count("Comments") == 1 + assert "comments_2" in cols + assert "COMMENTS_3" in cols + assert "cDNA_Change" in cols + assert "cdna_change_2" in cols + assert rows.num_rows == 1 + + def test_first_occurrence_keeps_original_casing(self, tmp_path: Path) -> None: + from showcase.cbioportal_to_omop.parsers import parse_maf + + path = _write( + tmp_path / "data_mutations.txt", + "Hugo_Symbol\ttranscript\tTranscript\n" + "TP53\tNM_1\tNM_2\n", + ) + _, types, _ = parse_maf(path) + cols = list(types.keys()) + assert "transcript" in cols + assert "Transcript_2" in cols diff --git a/tests/showcase/cbioportal_to_omop/test_msk_chord_parsers.py b/tests/showcase/cbioportal_to_omop/test_msk_chord_parsers.py new file mode 100644 index 0000000..f970d43 --- /dev/null +++ b/tests/showcase/cbioportal_to_omop/test_msk_chord_parsers.py @@ -0,0 +1,228 @@ 
+"""Tests for MSK CHORD-specific parsers: segmented CNA, gene panel matrix (long), lab timelines.""" +from __future__ import annotations + +from pathlib import Path + +import pytest + +from showcase.cbioportal_to_omop.parsers import ( + parse_gene_panel_matrix, + parse_lab_timeline, + parse_segmented_cna, + parse_timeline_file, +) + +pytestmark = pytest.mark.unit + + +def _write(path: Path, content: str) -> Path: + path.write_text(content, encoding="utf-8") + return path + + +class TestParseSegmentedCna: + def test_parses_six_column_seg_file(self, tmp_path: Path) -> None: + path = _write( + tmp_path / "data_cna_hg19.seg", + "ID\tchrom\tloc.start\tloc.end\tnum.mark\tseg.mean\n" + "P-001\t1\t10000\t250000\t150\t0.4\n" + "P-001\t2\t300000\t500000\t80\t-0.2\n" + "P-002\t1\t10000\t250000\t150\t0.0\n", + ) + rows, types, _ = parse_segmented_cna(path) + assert rows.num_rows == 3 + col_set = set(rows.column_names) + assert col_set == {"sample_id", "chrom", "loc_start", "loc_end", "num_mark", "seg_mean"} + + def test_numeric_columns_typed_correctly(self, tmp_path: Path) -> None: + path = _write( + tmp_path / "x.seg", + "ID\tchrom\tloc.start\tloc.end\tnum.mark\tseg.mean\n" + "S-1\tX\t1\t1000\t10\t0.3\n", + ) + _, types, _ = parse_segmented_cna(path) + assert types["loc_start"] == "BIGINT" + assert types["loc_end"] == "BIGINT" + assert types["num_mark"] == "BIGINT" + assert types["seg_mean"] == "DOUBLE" + assert types["sample_id"] == "VARCHAR" + assert types["chrom"] == "VARCHAR" + + def test_handles_blank_seg_mean_as_null(self, tmp_path: Path) -> None: + path = _write( + tmp_path / "x.seg", + "ID\tchrom\tloc.start\tloc.end\tnum.mark\tseg.mean\n" + "S-1\t1\t1\t1000\t10\t\n", + ) + rows, _, _ = parse_segmented_cna(path) + assert rows.column("seg_mean").to_pylist() == [None] + + +class TestParseGenePanelMatrixLong: + def test_pivots_wide_to_long(self, tmp_path: Path) -> None: + path = _write( + tmp_path / "data_gene_panel_matrix.txt", + 
"SAMPLE_ID\tmutations\tcna\tstructural_variants\n" + "S-1\tIMPACT341\tIMPACT341\tIMPACT341\n" + "S-2\tIMPACT410\tIMPACT410\tIMPACT410\n", + ) + rows, types, _ = parse_gene_panel_matrix(path) + col_set = set(rows.column_names) + assert col_set == {"sample_id", "panel_id", "assay"} + # 2 samples x 3 assays = 6 rows + assert rows.num_rows == 6 + assert types["sample_id"] == "VARCHAR" + assert types["panel_id"] == "VARCHAR" + assert types["assay"] == "VARCHAR" + + def test_skips_blank_panel_assignments(self, tmp_path: Path) -> None: + path = _write( + tmp_path / "data_gene_panel_matrix.txt", + "SAMPLE_ID\tmutations\tcna\n" + "S-1\tIMPACT341\t\n" + "S-2\t\tIMPACT410\n", + ) + rows, _, _ = parse_gene_panel_matrix(path) + # Blank cells skipped: 2 non-blank assignments only + assert rows.num_rows == 2 + pairs = list(zip( + rows.column("sample_id").to_pylist(), + rows.column("assay").to_pylist(), + )) + assert ("S-1", "mutations") in pairs + assert ("S-2", "cna") in pairs + + def test_emits_one_row_per_sample_assay_pair(self, tmp_path: Path) -> None: + path = _write( + tmp_path / "data_gene_panel_matrix.txt", + "SAMPLE_ID\tmutations\n" + "S-1\tIMPACT341\n", + ) + rows, _, _ = parse_gene_panel_matrix(path) + assert rows.num_rows == 1 + assert rows.column("panel_id").to_pylist() == ["IMPACT341"] + assert rows.column("assay").to_pylist() == ["mutations"] + + +class TestParseLabTimeline: + def test_types_value_as_double_when_test_value_units_present( + self, tmp_path: Path + ) -> None: + path = _write( + tmp_path / "data_timeline_labtest.txt", + "PATIENT_ID\tSTART_DATE\tSTOP_DATE\tEVENT_TYPE\tTEST\tVALUE\tUNITS\n" + "P-1\t10\t10\tLAB_TEST\tHemoglobin\t13.5\tg/dL\n" + "P-2\t20\t20\tLAB_TEST\tCreatinine\t1.1\tmg/dL\n", + ) + rows, types, _ = parse_lab_timeline(path) + assert types["VALUE"] == "DOUBLE" + assert types["TEST"] == "VARCHAR" + assert types["UNITS"] == "VARCHAR" + assert rows.column("VALUE").to_pylist() == [13.5, 1.1] + + def 
test_falls_back_to_varchar_when_value_unparseable( + self, tmp_path: Path + ) -> None: + path = _write( + tmp_path / "data_timeline_labtest.txt", + "PATIENT_ID\tSTART_DATE\tEVENT_TYPE\tTEST\tVALUE\tUNITS\n" + "P-1\t10\tLAB_TEST\tA1c\t<6.0\t%\n", + ) + rows, types, _ = parse_lab_timeline(path) + assert types["VALUE"] == "DOUBLE" + assert rows.column("VALUE").to_pylist() == [None] + + +class TestParseTimelineFileGeneric: + def test_non_lab_timeline_keeps_varchar(self, tmp_path: Path) -> None: + path = _write( + tmp_path / "data_timeline_treatment.txt", + "PATIENT_ID\tSTART_DATE\tSTOP_DATE\tEVENT_TYPE\tTREATMENT_TYPE\tAGENT\n" + "P-1\t0\t30\tTREATMENT\tChemotherapy\tCisplatin\n", + ) + _, types, _ = parse_timeline_file(path) + assert types["AGENT"] == "VARCHAR" + + +class TestIngestStudyDirDispatchesSeg: + def test_seg_file_creates_cna_segmented_table(self, tmp_path: Path) -> None: + from sema.ingest.duckdb_staging import Staging + from showcase.cbioportal_to_omop.parsers import _ingest_study_dir + + study_dir = tmp_path / "msk_chord" + study_dir.mkdir() + _write( + study_dir / "msk_chord_2024_data_cna_hg19.seg", + "ID\tchrom\tloc.start\tloc.end\tnum.mark\tseg.mean\n" + "S-1\t1\t1\t1000\t10\t0.3\n", + ) + + staging = Staging( + db_path=str(tmp_path / "db.duckdb"), + schemas=("cbioportal",), + ) + _ingest_study_dir("msk_chord", study_dir, staging, schema_name="cbioportal") + info = staging.describe("cbioportal", "cna_segmented") + assert "sample_id" in info.columns + assert "seg_mean" in info.columns + staging.close() + + def test_lab_timeline_picked_up_by_dispatch(self, tmp_path: Path) -> None: + from sema.ingest.duckdb_staging import Staging + from showcase.cbioportal_to_omop.parsers import _ingest_study_dir + + study_dir = tmp_path / "msk_chord" + study_dir.mkdir() + _write( + study_dir / "data_timeline_labtest.txt", + "PATIENT_ID\tSTART_DATE\tEVENT_TYPE\tTEST\tVALUE\tUNITS\n" + "P-1\t10\tLAB_TEST\tHb\t13.5\tg/dL\n", + ) + + staging = Staging( + db_path=str(tmp_path 
/ "db.duckdb"), + schemas=("cbioportal",), + ) + _ingest_study_dir("msk_chord", study_dir, staging, schema_name="cbioportal") + info = staging.describe("cbioportal", "timeline_labtest") + assert info.columns["VALUE"].type.upper().startswith("DOUB") + staging.close() + + +class TestRoundTripMixedFiles: + def test_ingest_mixed_seg_and_panel_matrix_and_lab( + self, tmp_path: Path + ) -> None: + from sema.ingest.duckdb_staging import Staging + from showcase.cbioportal_to_omop.parsers import _ingest_study_dir + + study_dir = tmp_path / "study" + study_dir.mkdir() + _write( + study_dir / "msk_data_cna_hg19.seg", + "ID\tchrom\tloc.start\tloc.end\tnum.mark\tseg.mean\nS-1\t1\t1\t100\t5\t0.1\n", + ) + _write( + study_dir / "data_gene_panel_matrix.txt", + "SAMPLE_ID\tmutations\nS-1\tIMPACT341\n", + ) + _write( + study_dir / "data_timeline_labtest.txt", + "PATIENT_ID\tSTART_DATE\tEVENT_TYPE\tTEST\tVALUE\tUNITS\n" + "P-1\t10\tLAB_TEST\tHb\t13.5\tg/dL\n", + ) + + staging = Staging( + db_path=str(tmp_path / "rt.duckdb"), + schemas=("cbioportal",), + ) + _ingest_study_dir("msk", study_dir, staging, schema_name="cbioportal") + + for tbl in ("cna_segmented", "gene_panel_matrix", "timeline_labtest"): + info = staging.describe("cbioportal", tbl) + assert info.columns, f"{tbl} should be ingested" + # gene_panel_matrix should be long format + gpm = staging.describe("cbioportal", "gene_panel_matrix") + assert "panel_id" in gpm.columns + assert "assay" in gpm.columns + staging.close() diff --git a/tests/showcase/cbioportal_to_omop/test_parsers.py b/tests/showcase/cbioportal_to_omop/test_parsers.py index f79f1ad..0940ccd 100644 --- a/tests/showcase/cbioportal_to_omop/test_parsers.py +++ b/tests/showcase/cbioportal_to_omop/test_parsers.py @@ -163,9 +163,23 @@ def test_iter_timeline_files_emits_one_kind_per_file(self, tmp_path: Path) -> No assert set(kinds.keys()) == {"treatment", "status"} assert "timeline_treatment" not in kinds + def test_iter_timeline_files_accepts_hyphen_in_kind(self, 
tmp_path: Path) -> None: + _write(tmp_path / "data_timeline_ca_15-3_labs.txt", "PATIENT_ID\n") + _write(tmp_path / "data_timeline_ca_19-9_labs.txt", "PATIENT_ID\n") + + kinds = {kind for kind, _ in iter_timeline_files(tmp_path)} + assert kinds == {"ca_15-3_labs", "ca_19-9_labs"} + def test_iter_timeline_files_returns_empty_when_no_files(self, tmp_path: Path) -> None: assert list(iter_timeline_files(tmp_path)) == [] + def test_timeline_table_name_sanitizes_hyphen(self) -> None: + from showcase.cbioportal_to_omop.parsers import timeline_table_name + + assert timeline_table_name("treatment") == "timeline_treatment" + assert timeline_table_name("ca_15-3_labs") == "timeline_ca_15_3_labs" + assert timeline_table_name("ca_19-9_labs") == "timeline_ca_19_9_labs" + @pytest.mark.unit class TestIngestStudySkipsMatrixFiles: @@ -198,7 +212,7 @@ def test_reuses_cache_when_done_marker_present(self, tmp_path: Path) -> None: study_dir.mkdir(parents=True) (study_dir / ".done").touch() - with patch("showcase.cbioportal_to_omop.parsers.urlopen") as mock_urlopen: + with patch("showcase.cbioportal_to_omop.cbioportal_fetch_utils.urlopen") as mock_urlopen: result = fetch_study_files("brca_tcga", cache_dir=cache) mock_urlopen.assert_not_called() assert result == study_dir @@ -230,7 +244,7 @@ def test_lists_and_downloads_expected_files(self, tmp_path: Path) -> None: download_responses.append(dl) urlopen_mock = MagicMock(side_effect=[api_resp, *download_responses]) - with patch("showcase.cbioportal_to_omop.parsers.urlopen", urlopen_mock): + with patch("showcase.cbioportal_to_omop.cbioportal_fetch_utils.urlopen", urlopen_mock): result = fetch_study_files("brca_tcga", cache_dir=cache) downloaded = {p.name for p in result.iterdir() if p.is_file() and p.name != ".done"} diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index 4fd3f50..5ac05fe 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -99,6 +99,55 @@ def test_build_flags_override_config_file(self, runner, 
tmp_path): call_config = mock_build.call_args[0][0] assert call_config.catalog == "override" + def test_build_fk_detection_default_on(self, runner): + with patch("sema.cli.run_build") as mock_build: + mock_build.return_value = { + "tables_processed": 0, "entities_created": 0, + "properties_created": 0, "value_sets_created": 0, + "terms_created": 0, "joins_inferred": 0, + "confidence_distribution": {}, + } + result = runner.invoke(cli, [ + "build", "--source", "databricks", "--catalog", "test", + ]) + assert result.exit_code == 0 + call_config = mock_build.call_args[0][0] + assert call_config.enable_fk_detection is True + assert call_config.materialize_structural_fk is False + + def test_build_no_enable_fk_detection_flag(self, runner): + with patch("sema.cli.run_build") as mock_build: + mock_build.return_value = { + "tables_processed": 0, "entities_created": 0, + "properties_created": 0, "value_sets_created": 0, + "terms_created": 0, "joins_inferred": 0, + "confidence_distribution": {}, + } + result = runner.invoke(cli, [ + "build", "--source", "databricks", "--catalog", "test", + "--no-enable-fk-detection", + ]) + assert result.exit_code == 0 + call_config = mock_build.call_args[0][0] + assert call_config.enable_fk_detection is False + + def test_build_materialize_structural_fk_flag(self, runner): + with patch("sema.cli.run_build") as mock_build: + mock_build.return_value = { + "tables_processed": 0, "entities_created": 0, + "properties_created": 0, "value_sets_created": 0, + "terms_created": 0, "joins_inferred": 0, + "confidence_distribution": {}, + } + result = runner.invoke(cli, [ + "build", "--source", "databricks", "--catalog", "test", + "--materialize-structural-fk", + ]) + assert result.exit_code == 0 + call_config = mock_build.call_args[0][0] + assert call_config.materialize_structural_fk is True + assert call_config.fk_materialization_threshold == 0.70 + class TestContextCommand: def test_context_produces_sco_json(self, runner): diff --git 
a/tests/unit/test_cli_recover_comments.py b/tests/unit/test_cli_recover_comments.py new file mode 100644 index 0000000..5a870db --- /dev/null +++ b/tests/unit/test_cli_recover_comments.py @@ -0,0 +1,166 @@ +from __future__ import annotations + +import json +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest +from click.testing import CliRunner + +from sema.cli import cli +from sema.ingest.comment_recovery import ( + LiveTableComments, + ParsedTableComments, +) + +pytestmark = pytest.mark.unit + + +def _stub_dependencies( + *, + registered_study: str | None = "study_x", + target_schema: str = "cbioportal_x", + parsed: dict[str, ParsedTableComments] | None = None, + live: dict[str, LiveTableComments] | None = None, + cache_exists: bool = True, +) -> dict[str, MagicMock]: + parsed = parsed if parsed is not None else { + "patient": ParsedTableComments( + table_comment="cBioPortal clinical patient", + column_comments={"PATIENT_ID": "Identifier."}, + ), + } + live = live if live is not None else { + "patient": LiveTableComments( + table_comment=None, column_comments={"PATIENT_ID": ""}, + ), + } + extract_mock = MagicMock(return_value=parsed) + read_mock = MagicMock(return_value=live) + executor_mock = MagicMock() + bridge_mock = MagicMock() + bridge_mock._execute = executor_mock # not directly used; see patches below + return { + "extract": extract_mock, "read": read_mock, + "executor": executor_mock, "bridge": bridge_mock, + "registered_study": registered_study, + "target_schema": target_schema, + "cache_exists": cache_exists, + } + + +def _registry_mock(stubs: dict[str, MagicMock]) -> MagicMock: + r = MagicMock() + if stubs["registered_study"] is not None: + r.find_schema_for_study.return_value = stubs["target_schema"] + else: + r.find_schema_for_study.return_value = None + return r + + +def _invoke( + args: list[str], stubs: dict[str, MagicMock], tmp_path: Path, +) -> object: + cache_root = tmp_path / "cache" + cache_root.mkdir() + 
if stubs["cache_exists"] and stubs["registered_study"] is not None: + (cache_root / stubs["registered_study"]).mkdir() + runner = CliRunner() + with patch("sema.cli_ingest.Staging") as staging_cls, patch( + "sema.cli_ingest.StudyRegistry", return_value=_registry_mock(stubs), + ), patch( + "sema.cli_ingest._extract_study_comments_lazy", + stubs["extract"], + ), patch( + "sema.cli_ingest.read_databricks_comments", stubs["read"], + ), patch( + "sema.cli_ingest._open_recovery_executor", + return_value=(stubs["executor"], MagicMock()), + ): + staging_cls.return_value = MagicMock() + return runner.invoke( + cli, + [ + "ingest", "recover-comments", + "--duckdb-path", str(tmp_path / "stg.duckdb"), + "--cache-dir", str(cache_root), + *args, + ], + ) + + +def test_registered_study_runs_with_stubbed_executor(tmp_path: Path) -> None: + stubs = _stub_dependencies() + result = _invoke(["--study", "study_x"], stubs, tmp_path) + assert result.exit_code == 0, result.output + stubs["executor"].assert_called() + assert "Columns updated: 1" in result.output + assert "Table comments updated: 1" in result.output + + +def test_unregistered_study_exits_nonzero_with_message(tmp_path: Path) -> None: + stubs = _stub_dependencies(registered_study=None) + result = _invoke(["--study", "ghost"], stubs, tmp_path) + assert result.exit_code != 0 + assert "_sema_study_registry" in result.output + + +def test_dry_run_does_not_execute(tmp_path: Path) -> None: + stubs = _stub_dependencies() + result = _invoke(["--study", "study_x", "--dry-run"], stubs, tmp_path) + assert result.exit_code == 0, result.output + stubs["executor"].assert_not_called() + + +def test_force_flag_is_threaded_through(tmp_path: Path) -> None: + stubs = _stub_dependencies( + live={ + "patient": LiveTableComments( + table_comment="existing", + column_comments={"PATIENT_ID": "existing"}, + ), + }, + ) + result = _invoke(["--study", "study_x", "--force"], stubs, tmp_path) + assert result.exit_code == 0, result.output + 
stubs["executor"].assert_called() + + +def test_json_output_is_parseable(tmp_path: Path) -> None: + stubs = _stub_dependencies() + result = _invoke(["--study", "study_x", "--json"], stubs, tmp_path) + assert result.exit_code == 0, result.output + payload = json.loads(result.output.strip().splitlines()[-1]) + assert payload["study_id"] == "study_x" + assert payload["target_schema"] == "cbioportal_x" + assert payload["columns_updated"] >= 1 + + +def test_explicit_overrides_bypass_registry(tmp_path: Path) -> None: + stubs = _stub_dependencies(registered_study=None) + cache = tmp_path / "alt_cache" + cache.mkdir() + runner = CliRunner() + with patch("sema.cli_ingest.Staging") as staging_cls, patch( + "sema.cli_ingest.StudyRegistry", return_value=_registry_mock(stubs), + ), patch( + "sema.cli_ingest._extract_study_comments_lazy", stubs["extract"], + ), patch( + "sema.cli_ingest.read_databricks_comments", stubs["read"], + ), patch( + "sema.cli_ingest._open_recovery_executor", + return_value=(stubs["executor"], MagicMock()), + ): + staging_cls.return_value = MagicMock() + result = runner.invoke( + cli, + [ + "ingest", "recover-comments", + "--source-cache", str(cache), + "--target-catalog", "workspace", + "--target-schema", "cbioportal_x", + "--duckdb-path", str(tmp_path / "stg.duckdb"), + ], + ) + assert result.exit_code == 0, result.output + stubs["executor"].assert_called() diff --git a/tests/unit/test_comment_recovery.py b/tests/unit/test_comment_recovery.py new file mode 100644 index 0000000..f973468 --- /dev/null +++ b/tests/unit/test_comment_recovery.py @@ -0,0 +1,400 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +from sema.ingest.comment_recovery import ( + ColumnUpdate, + LiveTableComments, + ParsedTableComments, + PartialOverrideError, + RecoveryContext, + StudyCacheMissingError, + StudyNotRegisteredError, + TableUpdate, + build_recovery_plan, + execute_recovery_plan, + read_databricks_comments, + resolve_recovery_context, 
+) +from sema.ingest.duckdb_staging import Staging +from sema.ingest.study_registry import StudyRegistry +from sema.models.config import ( + IngestConfig, + IngestDatabricksTargetConfig, +) + +pytestmark = pytest.mark.unit + + +def _ctx() -> RecoveryContext: + return RecoveryContext( + study_id="study_x", + source_cache=Path("/tmp/cache/study_x"), + target_catalog="workspace", + target_schema="cbioportal_x", + ) + + +def _parsed_clinical() -> dict[str, ParsedTableComments]: + return { + "patient": ParsedTableComments( + table_comment="cBioPortal clinical patient", + column_comments={ + "PATIENT_ID": "Identifier to uniquely specify a patient.", + "OS_STATUS": "Overall survival status.", + }, + ), + "clinical_supp_hypoxia": ParsedTableComments( + table_comment="cBioPortal clinical_supp_hypoxia", + column_comments={"BUFFA_HYPOXIA_SCORE": "Buffa hypoxia score."}, + ), + } + + +def _live_blank() -> dict[str, LiveTableComments]: + return { + "patient": LiveTableComments( + table_comment=None, + column_comments={"PATIENT_ID": "", "OS_STATUS": None}, # type: ignore[dict-item] + ), + "clinical_supp_hypoxia": LiveTableComments( + table_comment="", + column_comments={"BUFFA_HYPOXIA_SCORE": ""}, + ), + } + + +def test_plan_emits_one_alter_per_parsed_column_and_table_comment() -> None: + plan = build_recovery_plan(_ctx(), _parsed_clinical(), _live_blank()) + assert len(plan.column_updates) == 3 + assert len(plan.table_updates) == 2 + assert plan.catalog == "workspace" + assert plan.schema == "cbioportal_x" + pairs = {(u.table, u.column) for u in plan.column_updates} + assert pairs == { + ("patient", "PATIENT_ID"), + ("patient", "OS_STATUS"), + ("clinical_supp_hypoxia", "BUFFA_HYPOXIA_SCORE"), + } + + +def test_plan_idempotent_when_live_already_commented() -> None: + parsed = _parsed_clinical() + live = { + "patient": LiveTableComments( + table_comment="cBioPortal clinical patient", + column_comments={ + "PATIENT_ID": "Identifier to uniquely specify a patient.", + "OS_STATUS": 
"Overall survival status.", + }, + ), + "clinical_supp_hypoxia": LiveTableComments( + table_comment="cBioPortal clinical_supp_hypoxia", + column_comments={"BUFFA_HYPOXIA_SCORE": "Buffa hypoxia score."}, + ), + } + plan = build_recovery_plan(_ctx(), parsed, live) + assert plan.column_updates == [] + assert plan.table_updates == [] + + +def test_plan_force_overrides_existing_column_and_table_comments() -> None: + parsed = _parsed_clinical() + live = { + "patient": LiveTableComments( + table_comment="manual operator edit", + column_comments={ + "PATIENT_ID": "manual edit", + "OS_STATUS": "manual edit", + }, + ), + "clinical_supp_hypoxia": LiveTableComments( + table_comment="manual edit", + column_comments={"BUFFA_HYPOXIA_SCORE": "manual edit"}, + ), + } + plan = build_recovery_plan(_ctx(), parsed, live, force=True) + assert len(plan.column_updates) == 3 + assert len(plan.table_updates) == 2 + + +def test_plan_marks_missing_column_as_skipped_with_reason() -> None: + parsed = { + "patient": ParsedTableComments( + table_comment=None, + column_comments={ + "PATIENT_ID": "ok", + "LEGACY_COL": "no longer in databricks", + }, + ), + } + live = { + "patient": LiveTableComments( + table_comment=None, + column_comments={"PATIENT_ID": ""}, + ), + } + plan = build_recovery_plan(_ctx(), parsed, live) + assert ColumnUpdate( + table="patient", column="PATIENT_ID", new_comment="ok", + ) in plan.column_updates + assert any( + s.column == "LEGACY_COL" and s.reason == "column_not_found" + for s in plan.skipped_columns + ) + + +def test_plan_skips_table_unknown_to_live() -> None: + parsed = { + "ghost": ParsedTableComments( + table_comment="t", column_comments={"X": "x"}, + ), + } + plan = build_recovery_plan(_ctx(), parsed, {}) + assert plan.column_updates == [] + assert plan.table_updates == [] + assert any( + s.table == "ghost" and s.reason == "table_not_found" + for s in plan.skipped_columns + ) + + +def test_plan_preserves_existing_table_comment_unless_force() -> None: + parsed = { 
+ "patient": ParsedTableComments( + table_comment="parser value", + column_comments={}, + ), + } + live = { + "patient": LiveTableComments( + table_comment="existing", + column_comments={}, + ), + } + plan = build_recovery_plan(_ctx(), parsed, live) + assert plan.table_updates == [] + forced = build_recovery_plan(_ctx(), parsed, live, force=True) + assert forced.table_updates == [ + TableUpdate(table="patient", new_comment="parser value"), + ] + + +def test_plan_skips_empty_parser_comment() -> None: + parsed = { + "patient": ParsedTableComments( + table_comment=None, + column_comments={"BLANK": ""}, + ), + } + live = { + "patient": LiveTableComments( + table_comment=None, column_comments={"BLANK": ""}, + ), + } + plan = build_recovery_plan(_ctx(), parsed, live) + assert plan.column_updates == [] + + +def test_execute_runs_executor_per_statement() -> None: + plan = build_recovery_plan(_ctx(), _parsed_clinical(), _live_blank()) + executed: list[str] = [] + + def executor(sql: str) -> None: + executed.append(sql) + + report = execute_recovery_plan(plan, executor) + assert report.columns_updated == 3 + assert report.table_comments_updated == 2 + assert report.columns_failed == 0 + assert len(executed) == 5 + assert all("ALTER TABLE" in s or "COMMENT ON TABLE" in s for s in executed) + + +def test_execute_dry_run_skips_executor() -> None: + plan = build_recovery_plan(_ctx(), _parsed_clinical(), _live_blank()) + executed: list[str] = [] + + def executor(sql: str) -> None: + executed.append(sql) + + report = execute_recovery_plan(plan, executor, dry_run=True) + assert executed == [] + assert report.columns_updated == 3 + assert report.table_comments_updated == 2 + + +def test_execute_continues_after_per_statement_failure() -> None: + plan = build_recovery_plan(_ctx(), _parsed_clinical(), _live_blank()) + calls: list[str] = [] + + def executor(sql: str) -> None: + calls.append(sql) + if "OS_STATUS" in sql: + raise RuntimeError("simulated transient") + + report = 
execute_recovery_plan(plan, executor) + assert report.columns_failed == 1 + assert report.columns_updated == 2 + assert len(calls) == 5 + assert any(f.column == "OS_STATUS" for f in report.failed) + + +def test_execute_records_skipped_from_plan() -> None: + parsed = { + "patient": ParsedTableComments( + table_comment=None, + column_comments={"PATIENT_ID": "ok", "LEGACY_COL": "x"}, + ), + } + live = { + "patient": LiveTableComments( + table_comment=None, column_comments={"PATIENT_ID": ""}, + ), + } + plan = build_recovery_plan(_ctx(), parsed, live) + report = execute_recovery_plan(plan, lambda sql: None) + assert report.columns_skipped == 1 + assert report.columns_updated == 1 + + +def test_read_databricks_comments_combines_columns_and_tables() -> None: + column_rows = [ + ("patient", "PATIENT_ID", "Existing comment"), + ("patient", "AGE", None), + ("sample", "SAMPLE_ID", ""), + ] + table_rows = [("patient", "old patient comment"), ("sample", None)] + + def query_fn(sql: str, params: list[str]) -> list[tuple[str, ...]]: + if "information_schema.columns" in sql: + return column_rows # type: ignore[return-value] + if "information_schema.tables" in sql: + return table_rows # type: ignore[return-value] + raise AssertionError(f"unexpected query: {sql}") + + result = read_databricks_comments("workspace", "cbioportal_x", query_fn) + assert result["patient"].table_comment == "old patient comment" + assert result["patient"].column_comments == { + "PATIENT_ID": "Existing comment", + "AGE": None, + } + assert result["sample"].column_comments == {"SAMPLE_ID": ""} + assert result["sample"].table_comment is None + + +def _ingest_config(tmp_path: Path) -> IngestConfig: + return IngestConfig( + cache_dir=str(tmp_path / "cache"), + databricks=IngestDatabricksTargetConfig(catalog="workspace"), + ) + + +def _registry_with(study_to_schema: dict[str, str], staging: Staging) -> StudyRegistry: + registry = StudyRegistry(staging) + for study, schema in study_to_schema.items(): + 
registry.register( + schema_name=schema, + original_study_id=study, + source_type="cbioportal", + ) + return registry + + +@pytest.fixture +def staging_db(tmp_path: Path) -> Staging: + return Staging(str(tmp_path / "stg.duckdb")) + + +def test_resolve_context_registered_study_with_cache( + tmp_path: Path, staging_db: Staging, +) -> None: + registry = _registry_with({"study_x": "cbioportal_x"}, staging_db) + config = _ingest_config(tmp_path) + cache = Path(config.cache_dir).expanduser() / "study_x" + cache.mkdir(parents=True) + + ctx = resolve_recovery_context( + study_id="study_x", registry=registry, ingest_config=config, + source_cache_override=None, target_catalog_override=None, + target_schema_override=None, + ) + assert ctx.study_id == "study_x" + assert ctx.source_cache == cache + assert ctx.target_catalog == "workspace" + assert ctx.target_schema == "cbioportal_x" + + +def test_resolve_context_unregistered_study_raises( + tmp_path: Path, staging_db: Staging, +) -> None: + registry = _registry_with({}, staging_db) + config = _ingest_config(tmp_path) + with pytest.raises(StudyNotRegisteredError): + resolve_recovery_context( + study_id="ghost", registry=registry, ingest_config=config, + source_cache_override=None, target_catalog_override=None, + target_schema_override=None, + ) + + +def test_resolve_context_missing_cache_raises_distinct_error( + tmp_path: Path, staging_db: Staging, +) -> None: + registry = _registry_with({"study_x": "cbioportal_x"}, staging_db) + config = _ingest_config(tmp_path) + with pytest.raises(StudyCacheMissingError): + resolve_recovery_context( + study_id="study_x", registry=registry, ingest_config=config, + source_cache_override=None, target_catalog_override=None, + target_schema_override=None, + ) + + +def test_resolve_context_full_overrides_bypass_registry( + tmp_path: Path, staging_db: Staging, +) -> None: + registry = _registry_with({}, staging_db) + config = _ingest_config(tmp_path) + cache = tmp_path / "alt" + cache.mkdir() + + 
ctx = resolve_recovery_context( + study_id=None, registry=registry, ingest_config=config, + source_cache_override=cache, + target_catalog_override="custom_catalog", + target_schema_override="cbioportal_alt", + ) + assert ctx.source_cache == cache + assert ctx.target_catalog == "custom_catalog" + assert ctx.target_schema == "cbioportal_alt" + assert ctx.study_id is None + + +def test_resolve_context_partial_overrides_without_study_raise( + tmp_path: Path, staging_db: Staging, +) -> None: + registry = _registry_with({}, staging_db) + config = _ingest_config(tmp_path) + with pytest.raises(PartialOverrideError): + resolve_recovery_context( + study_id=None, registry=registry, ingest_config=config, + source_cache_override=tmp_path, + target_catalog_override=None, + target_schema_override="cbioportal_alt", + ) + + +def test_resolve_context_no_study_no_overrides_raises( + tmp_path: Path, staging_db: Staging, +) -> None: + registry = _registry_with({}, staging_db) + config = _ingest_config(tmp_path) + with pytest.raises(PartialOverrideError): + resolve_recovery_context( + study_id=None, registry=registry, ingest_config=config, + source_cache_override=None, target_catalog_override=None, + target_schema_override=None, + ) diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 8d65fca..c555f23 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -187,6 +187,21 @@ def test_cli_flags_override_file(self, tmp_path, monkeypatch): config = BuildConfig.from_file(str(config_file), overrides={"catalog": "override_catalog"}) assert config.catalog == "override_catalog" + def test_fk_detection_defaults_on(self): + config = BuildConfig() + assert config.enable_fk_detection is True + assert config.materialize_structural_fk is False + + def test_fk_materialization_threshold_default_080(self): + config = BuildConfig() + assert config.fk_materialization_threshold == 0.80 + + def test_fk_materialization_threshold_drops_to_070_when_structural_opt_in( + 
self, + ): + config = BuildConfig(materialize_structural_fk=True) + assert config.fk_materialization_threshold == 0.70 + class TestQueryConfig: def test_defaults(self): diff --git a/tests/unit/test_databricks_bridge.py b/tests/unit/test_databricks_bridge.py index bc4689a..3e48689 100644 --- a/tests/unit/test_databricks_bridge.py +++ b/tests/unit/test_databricks_bridge.py @@ -8,6 +8,7 @@ from sema.ingest.databricks_push import Bridge, PushError, PushResult from sema.ingest.duckdb_staging import Staging +from sema.ingest.study_registry import StudyRegistry from sema.models.config import ( DatabricksConfig, IngestConfig, @@ -46,6 +47,7 @@ def _mock_connection(cursor: MagicMock) -> MagicMock: @pytest.fixture def staging(tmp_path: Path) -> Staging: s = Staging(str(tmp_path / "bridge.duckdb")) + s.execute('CREATE SCHEMA IF NOT EXISTS "cbioportal"') rows = pa.table({"patient_id": ["P-1", "P-2"], "age": [40, 50]}) s.write_table( schema="cbioportal", @@ -60,7 +62,11 @@ def staging(tmp_path: Path) -> Staging: @pytest.mark.unit class TestBridgeProvisioning: - def test_ensure_schemas_issues_create_if_not_exists(self, staging: Staging) -> None: + def test_ensure_schemas_issues_create_for_registered_and_shared( + self, staging: Staging + ) -> None: + registry = StudyRegistry(staging) + registry.register("cbioportal_msk_chord_2024", "msk_chord_2024", "cbioportal") cursor = _mock_cursor() conn = _mock_connection(cursor) with patch("sema.ingest.databricks_push.sql_connect", return_value=conn): @@ -68,7 +74,10 @@ def test_ensure_schemas_issues_create_if_not_exists(self, staging: Staging) -> N bridge.ensure_schemas() executed = [call.args[0] for call in cursor.execute.call_args_list] - assert any("CREATE SCHEMA IF NOT EXISTS `workspace`.`cbioportal`" in sql for sql in executed) + assert any( + "CREATE SCHEMA IF NOT EXISTS `workspace`.`cbioportal_msk_chord_2024`" in sql + for sql in executed + ) assert any("CREATE SCHEMA IF NOT EXISTS `workspace`.`ontology_omop`" in sql for sql in 
executed) assert any("CREATE SCHEMA IF NOT EXISTS `workspace`.`vocabulary_omop`" in sql for sql in executed) @@ -167,6 +176,7 @@ def test_falls_back_to_insert_when_no_staging_uri(self, tmp_path: Path) -> None: class TestPushSchemasErrorHandling: def test_one_table_failure_continues_with_others(self, tmp_path: Path) -> None: staging = Staging(str(tmp_path / "errors.duckdb")) + staging.execute('CREATE SCHEMA IF NOT EXISTS "cbioportal"') for name in ["patient", "sample"]: staging.write_table( schema="cbioportal", @@ -210,3 +220,82 @@ def test_logs_warning_on_count_mismatch(self, staging: Staging, caplog: pytest.L assert result.rows_pushed == 2 assert result.target_count == 99 assert result.count_mismatch is True + + +@pytest.mark.unit +class TestUcVolumeDetection: + def test_is_uc_volume_path_true_for_volumes_prefix(self) -> None: + from sema.ingest.databricks_push_utils import is_uc_volume_path + + assert is_uc_volume_path("/Volumes/workspace/default/sema_staging") is True + assert is_uc_volume_path("/Volumes/workspace/default/sema_staging/") is True + + def test_is_uc_volume_path_false_for_other_uris(self) -> None: + from sema.ingest.databricks_push_utils import is_uc_volume_path + + assert is_uc_volume_path("file:///tmp/foo") is False + assert is_uc_volume_path("s3://bucket/key") is False + assert is_uc_volume_path("/tmp/foo") is False + assert is_uc_volume_path("dbfs:/tmp/foo") is False + + +@pytest.mark.unit +class TestSizeBasedRouting: + def test_route_via_copy_into_when_above_threshold(self) -> None: + from sema.ingest.databricks_push_utils import ( + COPY_INTO_ROW_THRESHOLD, + should_route_via_copy_into, + ) + + assert should_route_via_copy_into( + "cbioportal_msk_chord_2024", "cna", + row_count=COPY_INTO_ROW_THRESHOLD, + ) is True + + def test_no_copy_into_when_below_threshold_and_not_in_allowlist(self) -> None: + from sema.ingest.databricks_push_utils import should_route_via_copy_into + + assert should_route_via_copy_into( + "cbioportal_msk_chord_2024", 
"patient", row_count=24_950, + ) is False + + def test_allowlist_entries_route_regardless_of_row_count(self) -> None: + from sema.ingest.databricks_push_utils import should_route_via_copy_into + + assert should_route_via_copy_into( + "vocabulary_omop", "concept_ancestor", row_count=0, + ) is True + + +@pytest.mark.unit +class TestUcVolumeUpload: + def test_uc_volume_uri_uploads_via_workspace_client(self, tmp_path: Path) -> None: + staging = Staging(str(tmp_path / "uc.duckdb")) + staging.execute('CREATE SCHEMA IF NOT EXISTS "cbioportal_msk_chord_2024"') + rows = pa.table({"id": list(range(150_000))}) + staging.write_table( + schema="cbioportal_msk_chord_2024", + table="cna", + rows=rows, + column_types={"id": "INTEGER"}, + column_comments={}, + table_comment=None, + ) + cursor = _mock_cursor() + cursor.fetchone.return_value = (150_000,) + conn = _mock_connection(cursor) + + ws_client = MagicMock() + with patch("sema.ingest.databricks_push.sql_connect", return_value=conn), \ + patch("sema.ingest.databricks_push.WorkspaceClient", return_value=ws_client): + cfg = _config(cloud_uri="/Volumes/workspace/default/sema_staging") + bridge = Bridge(cfg, staging=staging) + result = bridge.push_table("cbioportal_msk_chord_2024", "cna") + + assert result.mechanism == "copy_into" + assert ws_client.files.upload.call_count == 1 + upload_kwargs = ws_client.files.upload.call_args.kwargs + assert upload_kwargs["file_path"].startswith( + "/Volumes/workspace/default/sema_staging/cbioportal_msk_chord_2024/cna/" + ) + assert upload_kwargs["overwrite"] is True diff --git a/tests/unit/test_databricks_push_utils.py b/tests/unit/test_databricks_push_utils.py new file mode 100644 index 0000000..da68c83 --- /dev/null +++ b/tests/unit/test_databricks_push_utils.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +import pytest + +from sema.ingest.databricks_push_utils import ( + build_alter_column_comment_sql, + build_alter_table_comment_sql, +) + +pytestmark = pytest.mark.unit + + +def 
test_build_alter_column_comment_sql_quotes_identifiers() -> None: + sql = build_alter_column_comment_sql( + "workspace", "cbioportal_x", "patient", "PATIENT_ID", + "Identifier to uniquely specify a patient.", + ) + assert sql == ( + "ALTER TABLE `workspace`.`cbioportal_x`.`patient` " + "ALTER COLUMN `PATIENT_ID` " + "COMMENT 'Identifier to uniquely specify a patient.'" + ) + + +def test_build_alter_column_comment_sql_escapes_single_quotes() -> None: + sql = build_alter_column_comment_sql( + "workspace", "cbioportal_x", "patient", "PATIENT_ID", + "Patient's identifier", + ) + assert "'Patient''s identifier'" in sql + + +def test_build_alter_column_comment_sql_empty_comment_clears() -> None: + sql = build_alter_column_comment_sql( + "workspace", "cbioportal_x", "patient", "PATIENT_ID", "", + ) + assert sql.endswith("COMMENT ''") + + +def test_build_alter_column_comment_sql_rejects_semicolon_in_identifier() -> None: + with pytest.raises(ValueError): + build_alter_column_comment_sql( + "workspace", "x; DROP TABLE y", "patient", "PATIENT_ID", "x", + ) + + +def test_build_alter_column_comment_sql_rejects_backtick_in_identifier() -> None: + with pytest.raises(ValueError): + build_alter_column_comment_sql( + "workspace", "cbioportal", "patient", "col`name", "x", + ) + + +def test_build_alter_column_comment_sql_rejects_empty_identifier() -> None: + with pytest.raises(ValueError): + build_alter_column_comment_sql( + "workspace", "cbioportal_x", "patient", "", "x", + ) + + +def test_build_alter_table_comment_sql_quotes_identifiers() -> None: + sql = build_alter_table_comment_sql( + "workspace", "cbioportal_x", "patient", + "cBioPortal clinical patient from data_clinical_patient.txt", + ) + assert sql == ( + "COMMENT ON TABLE `workspace`.`cbioportal_x`.`patient` " + "IS 'cBioPortal clinical patient from data_clinical_patient.txt'" + ) + + +def test_build_alter_table_comment_sql_escapes_single_quotes() -> None: + sql = build_alter_table_comment_sql( + "workspace", 
"cbioportal_x", "patient", "Patient's table", + ) + assert "'Patient''s table'" in sql + + +def test_build_alter_table_comment_sql_empty_comment_clears() -> None: + sql = build_alter_table_comment_sql( + "workspace", "cbioportal_x", "patient", "", + ) + assert sql.endswith("IS ''") + + +def test_build_alter_table_comment_sql_rejects_semicolon_in_identifier() -> None: + with pytest.raises(ValueError): + build_alter_table_comment_sql( + "workspace", "cbioportal_x", "pa;tient", "x", + ) diff --git a/tests/unit/test_duckdb_staging.py b/tests/unit/test_duckdb_staging.py index 7bfe1fd..2b012e3 100644 --- a/tests/unit/test_duckdb_staging.py +++ b/tests/unit/test_duckdb_staging.py @@ -11,7 +11,9 @@ @pytest.fixture def staging(tmp_path: Path) -> Staging: db_path = tmp_path / "test.duckdb" - return Staging(str(db_path)) + s = Staging(str(db_path)) + s.execute('CREATE SCHEMA IF NOT EXISTS "cbioportal"') + return s @pytest.mark.unit @@ -24,9 +26,9 @@ def test_creates_file_and_schemas_on_init(self, tmp_path: Path) -> None: assert db_path.exists() schemas = staging.list_schemas() - assert "cbioportal" in schemas assert "ontology_omop" in schemas assert "vocabulary_omop" in schemas + assert "cbioportal" not in schemas def test_expands_home_in_path(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("HOME", str(tmp_path)) @@ -38,7 +40,7 @@ def test_reopen_is_idempotent(self, tmp_path: Path) -> None: db_path = tmp_path / "reopen.duckdb" Staging(str(db_path)).close() reopened = Staging(str(db_path)) - assert "cbioportal" in reopened.list_schemas() + assert "ontology_omop" in reopened.list_schemas() @pytest.mark.unit diff --git a/tests/unit/test_e2e_pipeline.py b/tests/unit/test_e2e_pipeline.py index d57f54a..5969f3c 100644 --- a/tests/unit/test_e2e_pipeline.py +++ b/tests/unit/test_e2e_pipeline.py @@ -142,7 +142,7 @@ def process_one(t): ] loader = MagicMock() - def capture(assertions): + def capture(assertions, source_schema=None): 
committed_counts.append(len(assertions)) loader.commit_table_assertions.side_effect = capture diff --git a/tests/unit/test_few_shot.py b/tests/unit/test_few_shot.py index f270207..a71d44c 100644 --- a/tests/unit/test_few_shot.py +++ b/tests/unit/test_few_shot.py @@ -11,22 +11,22 @@ def test_lookup_healthcare_stage_a(self) -> None: from sema.engine.few_shot import get_examples examples = get_examples(domain="healthcare", stage="A") - assert len(examples) >= 3 - assert len(examples) <= 5 + assert len(examples) >= 5 + assert len(examples) <= 12 def test_lookup_healthcare_stage_b(self) -> None: from sema.engine.few_shot import get_examples examples = get_examples(domain="healthcare", stage="B") - assert len(examples) >= 8 - assert len(examples) <= 12 + assert len(examples) >= 12 + assert len(examples) <= 25 def test_lookup_healthcare_stage_c(self) -> None: from sema.engine.few_shot import get_examples examples = get_examples(domain="healthcare", stage="C") - assert len(examples) >= 6 - assert len(examples) <= 10 + assert len(examples) >= 8 + assert len(examples) <= 18 def test_unknown_domain_returns_empty(self) -> None: from sema.engine.few_shot import get_examples diff --git a/tests/unit/test_few_shot_quality.py b/tests/unit/test_few_shot_quality.py index d473343..787ab3b 100644 --- a/tests/unit/test_few_shot_quality.py +++ b/tests/unit/test_few_shot_quality.py @@ -49,15 +49,15 @@ def test_uses_compact_json_without_indent(self) -> None: ) def test_block_stays_under_token_budget(self) -> None: - """Stage B composed (generic + healthcare) must stay under 2100 tokens. + """Stage B composed (generic + healthcare) must stay under 3500 tokens. - Budget raised from the 1200 healthcare-only ceiling after adding the - 8-example generic base layer. Target ~90 tokens/example at compact - JSON, 20 composed examples, plus framing. + Budget raised again after MSK CHORD expansion: lab/biomarker/procedure/ + performance examples added 9 entries. 
~30 composed examples × ~100 + tokens compact + framing. """ block = format_examples("healthcare", "B") approx_tokens = len(block) // 4 - assert approx_tokens <= 2100, ( + assert approx_tokens <= 3500, ( f"Stage B few-shot block is {approx_tokens} tokens — " - f"budget is 2100." + f"budget is 3500." ) diff --git a/tests/unit/test_graph_loader.py b/tests/unit/test_graph_loader.py index acc1d85..9561e56 100644 --- a/tests/unit/test_graph_loader.py +++ b/tests/unit/test_graph_loader.py @@ -223,6 +223,7 @@ def test_add_join_path_entity_links( "t1/c1=t2/c2", from_table_ref="databricks://ws/cat/sch/t1", to_table_ref="databricks://ws/cat/sch/t2", + source_schema="sch", ) session.run.assert_called() cypher = session.run.call_args[0][0] @@ -230,6 +231,20 @@ def test_add_join_path_entity_links( assert "TO_ENTITY" in cypher assert "ENTITY_ON_TABLE" in cypher + def test_add_join_path_entity_links_requires_source_schema( + self, loader, + ): + with pytest.raises(ValueError, match="source_schema"): + loader.add_join_path_entity_links( + "t1/c1=t2/c2", + from_table_ref="t1", + to_table_ref="t2", + ) + + def test_add_join_path_uses_requires_source_schema(self, loader): + with pytest.raises(ValueError, match="source_schema"): + loader.add_join_path_uses("jp", "t1") + class TestAssertionStorage: def test_store_assertion_includes_subject_id( diff --git a/tests/unit/test_graph_source_schema.py b/tests/unit/test_graph_source_schema.py new file mode 100644 index 0000000..44dd4e9 --- /dev/null +++ b/tests/unit/test_graph_source_schema.py @@ -0,0 +1,325 @@ +"""Tests for Section 2-4 of `expand-healthcare-eval-coverage`: +study-scoped stamping, MERGE-key corrections, scoped-delete. 
+""" +from __future__ import annotations + +from datetime import datetime, timezone +from unittest.mock import MagicMock + +import pytest + +from sema.graph.loader import GraphLoader +from sema.graph.loader_utils import ( + batch_upsert_aliases, + batch_upsert_entities, + batch_upsert_join_paths, + batch_upsert_properties, + batch_upsert_value_sets, +) +from sema.models.assertions import ( + Assertion, + AssertionPredicate, + AssertionStatus, +) + +pytestmark = pytest.mark.unit + +SCHEMA_BRCA = "cbioportal_brca_tcga_pan_can_atlas_2018" +SCHEMA_MSK = "cbioportal_msk_chord_2024" + + +@pytest.fixture +def mock_driver(): + driver = MagicMock() + session = MagicMock() + driver.session.return_value.__enter__ = MagicMock( + return_value=session, + ) + driver.session.return_value.__exit__ = MagicMock(return_value=False) + return driver, session + + +@pytest.fixture +def loader(mock_driver): + driver, _ = mock_driver + return GraphLoader(driver) + + +def _assertion(predicate=AssertionPredicate.HAS_LABEL, **overrides): + base = dict( + id="a-1", + subject_ref="databricks://ws/cat/sch/tbl", + predicate=predicate, + payload={"value": "x"}, + source="llm_interpretation", + confidence=0.9, + status=AssertionStatus.AUTO, + run_id="run-1", + observed_at=datetime(2026, 1, 1, tzinfo=timezone.utc), + ) + base.update(overrides) + return Assertion(**base) + + +class TestAssertionSourceSchemaStamp: + def test_assertion_model_has_source_schema_field(self): + a = _assertion(source_schema=SCHEMA_BRCA) + assert a.source_schema == SCHEMA_BRCA + + def test_assertion_source_schema_optional(self): + a = _assertion() + assert a.source_schema is None + + def test_store_assertion_writes_source_schema( + self, loader, mock_driver, + ): + _, session = mock_driver + a = _assertion() + loader.store_assertion(a, source_schema=SCHEMA_BRCA) + cypher = session.run.call_args[0][0] + params = session.run.call_args[1] + assert "source_schema" in cypher + assert params["source_schema"] == SCHEMA_BRCA + + def 
test_commit_table_assertions_threads_source_schema( + self, loader, mock_driver, + ): + driver, session = mock_driver + tx = MagicMock() + session.begin_transaction.return_value = tx + loader.commit_table_assertions( + [_assertion()], source_schema=SCHEMA_MSK, + ) + cypher = tx.run.call_args[0][0] + kwargs = tx.run.call_args[1] + assert "source_schema: a.source_schema" in cypher + assert kwargs["assertions"][0]["source_schema"] == SCHEMA_MSK + + +class TestEdgeStamping: + def _row(self, **overrides): + base = dict( + name="Patient", + description=None, + source="llm_interpretation", + confidence=0.9, + entity_name="Patient", + column_name="age", + table_name="patient", + schema_name="sch", + catalog="cat", + ) + base.update(overrides) + return base + + def test_entity_on_table_carries_source_schema(self, loader): + loader._run = MagicMock() + batch_upsert_entities( + loader, [self._row()], source_schema=SCHEMA_BRCA, + ) + cypher = loader._run.call_args[0][0] + assert "ENTITY_ON_TABLE" in cypher + assert "source_schema: r.source_schema" in cypher + rows = loader._run.call_args[1]["rows"] + assert rows[0]["source_schema"] == SCHEMA_BRCA + + def test_has_property_and_property_on_column_stamped(self, loader): + loader._run = MagicMock() + batch_upsert_properties( + loader, [self._row()], source_schema=SCHEMA_MSK, + ) + cypher = loader._run.call_args[0][0] + assert "HAS_PROPERTY" in cypher + assert "PROPERTY_ON_COLUMN" in cypher + assert cypher.count("source_schema: r.source_schema") == 2 + + def test_has_value_set_stamped(self, loader): + loader._run = MagicMock() + batch_upsert_value_sets( + loader, + [{ + "name": "vs", "column_ref": "cat.sch.tbl.col", + "column_name": "col", "table_name": "tbl", + "schema_name": "sch", "catalog": "cat", + }], + source_schema=SCHEMA_BRCA, + ) + cypher = loader._run.call_args[0][0] + assert "HAS_VALUE_SET" in cypher + assert "source_schema: r.source_schema" in cypher + + def test_member_of_stamped(self, loader, mock_driver): + _, 
session = mock_driver + loader.add_term_to_value_set( + "TP53", "patient_status_values", + source_schema=SCHEMA_BRCA, + ) + cypher = session.run.call_args[0][0] + params = session.run.call_args[1] + assert "MEMBER_OF" in cypher + assert "source_schema: $source_schema" in cypher + assert params["source_schema"] == SCHEMA_BRCA + + def test_parent_of_stamped(self, loader, mock_driver): + _, session = mock_driver + loader.add_term_hierarchy( + "NEOPLASM", "CARCINOMA", source_schema=SCHEMA_MSK, + ) + cypher = session.run.call_args[0][0] + params = session.run.call_args[1] + assert "PARENT_OF" in cypher + assert params["source_schema"] == SCHEMA_MSK + + def test_refers_to_stamped(self, loader): + loader._run = MagicMock() + batch_upsert_aliases( + loader, + [{ + "text": "colon cancer", + "target_key": "ref", + "parent_name": "Cancer Diagnosis", + "parent_entity_name": None, + "source": "llm", "confidence": 0.8, + "is_preferred": False, "description": None, + }], + ":Entity", + source_schema=SCHEMA_BRCA, + ) + cypher = loader._run.call_args[0][0] + assert "REFERS_TO" in cypher + assert "source_schema: r.source_schema" in cypher + + +class TestMergeKeys: + def test_entity_merge_is_name_only_no_datasource_id(self, loader): + loader._run = MagicMock() + batch_upsert_entities( + loader, + [{ + "name": "Patient", "description": None, + "source": "llm", "confidence": 0.9, + "table_name": "p", "schema_name": "s", "catalog": "c", + }], + source_schema=SCHEMA_BRCA, + ) + cypher = loader._run.call_args[0][0] + assert "MERGE (e:Entity {name: r.name})" in cypher + assert "datasource_id" not in cypher + assert "table_key" not in cypher + + def test_property_merge_is_entity_name_plus_name(self, loader): + loader._run = MagicMock() + batch_upsert_properties( + loader, + [{ + "name": "age", "entity_name": "Patient", + "semantic_type": "numeric", + "source": "llm", "confidence": 0.9, + "column_name": "age", "table_name": "p", + "schema_name": "s", "catalog": "c", + }], + 
source_schema=SCHEMA_BRCA, + ) + cypher = loader._run.call_args[0][0] + assert ( + "MERGE (p:Property {entity_name: r.entity_name, " + "name: r.name})" + ) in cypher + + def test_value_set_merge_is_column_ref(self, loader): + loader._run = MagicMock() + batch_upsert_value_sets( + loader, + [{ + "name": "vs", "column_ref": "c.s.t.col", + "column_name": "col", "table_name": "t", + "schema_name": "s", "catalog": "c", + }], + source_schema=SCHEMA_BRCA, + ) + cypher = loader._run.call_args[0][0] + assert "MERGE (vs:ValueSet {column_ref: r.column_ref})" in cypher + + def test_join_path_merge_keyed_by_name_and_source_schema( + self, loader, + ): + loader._run = MagicMock() + batch_upsert_join_paths( + loader, + [{ + "name": "patient_to_sample", + "join_predicates": [{"left_table": "patient"}], + "hop_count": 1, "source": "fk_detector", + "confidence": 0.95, + "sql_snippet": None, "cardinality_hint": None, + }], + source_schema=SCHEMA_BRCA, + ) + cypher = loader._run.call_args[0][0] + assert ( + "MERGE (jp:JoinPath {name: r.name, " + "source_schema: r.source_schema})" + ) in cypher + + +class TestJoinPathEdgeRequiresSourceSchema: + def test_uses_requires_source_schema(self, loader): + with pytest.raises(ValueError, match="source_schema"): + loader.add_join_path_uses("jp", "tbl_ref") + + def test_entity_links_requires_source_schema(self, loader): + with pytest.raises(ValueError, match="source_schema"): + loader.add_join_path_entity_links( + "jp", "from", "to", + ) + + def test_uses_with_schema_matches_by_name_and_schema( + self, loader, mock_driver, + ): + _, session = mock_driver + loader.add_join_path_uses( + "jp", "tbl_ref", source_schema=SCHEMA_BRCA, + ) + cypher = session.run.call_args[0][0] + assert ( + "MATCH (jp:JoinPath {name: $jp_name, " + "source_schema: $source_schema})" + ) in cypher + + +class TestScopedDelete: + def test_delete_runs_three_queries(self, loader, mock_driver): + _, session = mock_driver + loader.delete_study_scoped(SCHEMA_BRCA) + calls = 
[c[0][0] for c in session.run.call_args_list] + assert any( + "MATCH ()-[r {source_schema: $schema}]-() DELETE r" in c + for c in calls + ) + assert any( + "MATCH (a:Assertion {source_schema: $schema})" in c + and "DETACH DELETE a" in c + for c in calls + ) + assert any( + "MATCH (jp:JoinPath {source_schema: $schema})" in c + and "DETACH DELETE jp" in c + for c in calls + ) + for c in session.run.call_args_list: + assert c[1]["schema"] == SCHEMA_BRCA + + def test_delete_does_not_touch_shared_concept_nodes( + self, loader, mock_driver, + ): + _, session = mock_driver + loader.delete_study_scoped(SCHEMA_BRCA) + joined = " ".join( + c[0][0] for c in session.run.call_args_list + ) + for shared in ( + ":Entity", ":Term", ":ValueSet", ":Property", + ":SemanticType", ":Table", ":Column", ":Schema", + ): + assert f"DELETE n.{shared}" not in joined + assert f"DETACH DELETE {shared}" not in joined diff --git a/tests/unit/test_healthcare_few_shot_msk_chord.py b/tests/unit/test_healthcare_few_shot_msk_chord.py new file mode 100644 index 0000000..0709d41 --- /dev/null +++ b/tests/unit/test_healthcare_few_shot_msk_chord.py @@ -0,0 +1,165 @@ +"""MSK CHORD healthcare few-shot expansion: Stage A/B/C examples for labs, procedures, +biomarker states, performance status, CNA matrices.""" +from __future__ import annotations + +import pytest + +from sema.engine.few_shot import compose_examples, get_examples +from sema.engine.few_shot_healthcare import ( + HEALTHCARE_STAGE_A, + HEALTHCARE_STAGE_B, + HEALTHCARE_STAGE_C, +) + +pytestmark = pytest.mark.unit + + +def _table_names_in_stage(stage: list[dict[str, object]]) -> set[str]: + names: set[str] = set() + for ex in stage: + inp = ex["input"] + assert isinstance(inp, dict) + name = inp.get("table_name") + if isinstance(name, str): + names.add(name) + return names + + +class TestStageAHealthcareBreadth: + def test_includes_lab_timeline(self) -> None: + names = _table_names_in_stage(HEALTHCARE_STAGE_A) + assert "timeline_labtest" in names 
+ + def test_includes_procedure_event(self) -> None: + names = _table_names_in_stage(HEALTHCARE_STAGE_A) + assert any(n.startswith("timeline_") and "surg" in n.lower() for n in names) or \ + "procedure" in names or "timeline_surgery" in names + + def test_includes_performance_status(self) -> None: + names = _table_names_in_stage(HEALTHCARE_STAGE_A) + assert "timeline_performance_status" in names + + def test_includes_cna_segmented(self) -> None: + names = _table_names_in_stage(HEALTHCARE_STAGE_A) + assert "cna_segmented" in names + + +class TestStageBHealthcareBreadth: + def _by_col(self) -> dict[tuple[str, str], dict[str, object]]: + index: dict[tuple[str, str], dict[str, object]] = {} + for ex in HEALTHCARE_STAGE_B: + inp = ex["input"] + assert isinstance(inp, dict) + tbl = inp.get("table_name") + col = inp.get("column") + assert isinstance(tbl, str) and isinstance(col, str) + index[(tbl, col)] = ex["output"] # type: ignore[assignment] + return index + + def test_lab_value_with_units_present(self) -> None: + index = self._by_col() + assert ("timeline_labtest", "VALUE") in index + assert ("timeline_labtest", "UNITS") in index + + def test_pdl1_biomarker_state(self) -> None: + index = self._by_col() + keys = list(index.keys()) + assert any("pd_l1" in c.lower() or "pdl1" in c.lower() for _t, c in keys) + + def test_mmr_status(self) -> None: + index = self._by_col() + assert any("mmr" in c.lower() for _t, c in index) + + def test_ecog_performance_score(self) -> None: + index = self._by_col() + assert any("ecog" in c.lower() for _t, c in index) + + def test_karnofsky_performance_score(self) -> None: + index = self._by_col() + assert any("karnofsky" in c.lower() for _t, c in index) + + def test_procedure_code(self) -> None: + index = self._by_col() + assert any( + "procedure" in t.lower() or "procedure_code" in c.lower() + for t, c in index + ) + + +class TestStageCHealthcareDecodings: + def _columns(self) -> set[str]: + cols: set[str] = set() + for ex in 
HEALTHCARE_STAGE_C: + inp = ex["input"] + assert isinstance(inp, dict) + col = inp.get("column") + if isinstance(col, str): + cols.add(col.lower()) + return cols + + def test_pdl1_decoding(self) -> None: + assert any("pd_l1" in c or "pdl1" in c for c in self._columns()) + + def test_mmr_decoding(self) -> None: + assert any("mmr" in c for c in self._columns()) + + def test_gleason_decoding(self) -> None: + assert any("gleason" in c for c in self._columns()) + + def test_ecog_decoding(self) -> None: + assert any("ecog" in c for c in self._columns()) + + def test_karnofsky_decoding(self) -> None: + assert any("karnofsky" in c for c in self._columns()) + + +class TestSynonymsCoverageOnNewExamples: + def test_lab_biomarker_procedure_examples_have_two_or_more_synonyms(self) -> None: + targets: list[tuple[str, str]] = [] + for ex in HEALTHCARE_STAGE_B: + inp = ex["input"] + out = ex["output"] + assert isinstance(inp, dict) and isinstance(out, dict) + tbl = inp.get("table_name", "") + col = inp.get("column", "") + assert isinstance(tbl, str) and isinstance(col, str) + sem_type = out.get("semantic_type", "") + assert isinstance(sem_type, str) + is_lab = "lab" in tbl.lower() or col.upper() in {"VALUE", "UNITS", "TEST"} + is_biomarker = ( + "biomarker" in sem_type.lower() + or "pd_l1" in col.lower() + or "pdl1" in col.lower() + or "mmr" in col.lower() + or "gleason" in col.lower() + ) + is_procedure = "procedure" in tbl.lower() or "procedure_code" in col.lower() + if is_lab or is_biomarker or is_procedure: + synonyms = out.get("synonyms", []) + assert isinstance(synonyms, list) + if not synonyms or len(synonyms) < 2: + targets.append((tbl, col)) + assert not targets, ( + f"Lab/biomarker/procedure B examples missing ≥2 synonyms: {targets}" + ) + + +class TestRegistryFallbackBehavior: + def test_healthcare_domain_returns_healthcare_examples(self) -> None: + composed_b = compose_examples("healthcare", "B") + # composed is generic + healthcare + healthcare_only = 
get_examples("healthcare", "B") + assert healthcare_only, "healthcare B examples should be registered" + # All healthcare entries appear in composed + for hc_ex in healthcare_only: + assert hc_ex in composed_b + + def test_unknown_domain_falls_back_to_generic_only(self) -> None: + composed = compose_examples("nonexistent_domain", "B") + from sema.engine.few_shot_generic import GENERIC_STAGE_B + assert composed == list(GENERIC_STAGE_B) + + def test_zero_shot_when_neither_registered(self) -> None: + # Stage Z is not registered for any domain + composed = compose_examples("healthcare", "Z") + assert composed == [] diff --git a/tests/unit/test_join_detector.py b/tests/unit/test_join_detector.py new file mode 100644 index 0000000..5e5f28c --- /dev/null +++ b/tests/unit/test_join_detector.py @@ -0,0 +1,263 @@ +"""Unit tests for the FK / join detector (Section 12).""" +from __future__ import annotations + +import pytest + +from sema.engine.join_detector import ( + DEFAULT_MATERIALIZE_THRESHOLD, + JoinDetector, + TIER_1, + TIER_2, + TIER_3, + to_fk_assertion, +) +from sema.engine.join_detector_utils import ( + enumerate_candidates_from_metadata, + fk_name_root, + types_compatible, +) +from sema.models.assertions import AssertionPredicate +from sema.models.extraction import ExtractedColumn + +pytestmark = pytest.mark.unit + +SCHEMA = "cbioportal_msk_chord_2024" +CAT = "workspace" + + +def _col(name: str, table: str, dtype: str = "STRING") -> ExtractedColumn: + return ExtractedColumn( + name=name, table_name=table, + catalog=CAT, schema=SCHEMA, data_type=dtype, + ) + + +class TestNameAndType: + def test_fk_name_root_recognizes_id_suffix(self): + assert fk_name_root("patient_id") == "patient" + assert fk_name_root("sample_key") == "sample" + assert fk_name_root("gene_code") == "gene" + + def test_fk_name_root_rejects_unrelated(self): + assert fk_name_root("os_status") is None + assert fk_name_root("age") is None + + def test_types_compatible_handles_aliases(self): + assert 
types_compatible("STRING", "VARCHAR(64)") + assert types_compatible("BIGINT", "INT") + assert not types_compatible("STRING", "BIGINT") + + +class TestCandidateEnumeration: + def test_proposes_patient_to_sample(self): + cols = [ + _col("patient_id", "patient"), + _col("os_status", "patient"), + _col("patient_id", "sample"), + _col("sample_id", "sample"), + ] + cands = enumerate_candidates_from_metadata(cols) + match = [ + c for c in cands + if c.fk_table == "sample" and c.fk_column == "patient_id" + and c.pk_table == "patient" and c.pk_column == "patient_id" + ] + assert len(match) == 1 + + def test_rejects_cross_schema_pair(self): + cols = [ + ExtractedColumn( + "patient_id", "patient", CAT, "schema_a", "STRING", + ), + ExtractedColumn( + "patient_id", "sample", CAT, "schema_b", "STRING", + ), + ] + cands = enumerate_candidates_from_metadata(cols) + assert all( + c.schema_name in {"schema_a", "schema_b"} for c in cands + ) + for c in cands: + assert ( + c.fk_column == "patient_id" + and c.pk_table == "patient" + ) is False or c.schema_name == c.schema_name + cross = [ + c for c in cands + if c.fk_table == "sample" and c.pk_table == "patient" + ] + assert cross == [] + + def test_rejects_type_mismatch(self): + cols = [ + _col("patient_id", "patient", "STRING"), + _col("patient_id", "sample", "BIGINT"), + ] + assert enumerate_candidates_from_metadata(cols) == [] + + +class TestTierAssignment: + def setup_method(self): + self.cols = [ + _col("patient_id", "patient"), + _col("patient_id", "sample"), + ] + + def test_tier_1_with_subset_samples(self): + det = JoinDetector() + samples = { + (SCHEMA, "patient", "patient_id"): {"P1", "P2", "P3"}, + (SCHEMA, "sample", "patient_id"): {"P1", "P2"}, + } + out = det.detect( + columns=self.cols, source_schema=SCHEMA, samples=samples, + ) + assert len(out) == 1 + assert out[0].tier == 1 + assert out[0].confidence == TIER_1 + + def test_tier_2_with_consistent_cardinality(self): + det = JoinDetector() + profiles = { + (SCHEMA, 
"patient", "patient_id"): (1000, 1000), + (SCHEMA, "sample", "patient_id"): (850, 3000), + } + out = det.detect( + columns=self.cols, source_schema=SCHEMA, profiles=profiles, + ) + assert out[0].tier == 2 + assert out[0].confidence == TIER_2 + + def test_tier_3_structural_only(self): + det = JoinDetector() + out = det.detect(columns=self.cols, source_schema=SCHEMA) + assert out[0].tier == 3 + assert out[0].confidence == TIER_3 + + +class TestSampleSourcingFallback: + def setup_method(self): + self.cols = [ + _col("patient_id", "patient"), + _col("patient_id", "sample"), + ] + + def test_profiler_samples_preferred(self): + det = JoinDetector() + sampler_calls: list[tuple] = [] + + def sampler(key): + sampler_calls.append(key) + return {"unused"} + + samples = { + (SCHEMA, "patient", "patient_id"): {"a", "b"}, + (SCHEMA, "sample", "patient_id"): {"a"}, + } + det.detect( + columns=self.cols, source_schema=SCHEMA, + samples=samples, sampler=sampler, + ) + assert sampler_calls == [] + + def test_detector_owned_sampling_when_profiler_missing(self): + det = JoinDetector() + sampler_calls: list[tuple] = [] + + def sampler(key): + sampler_calls.append(key) + return {"a", "b"} if "patient_id" in key[2] else None + + out = det.detect( + columns=self.cols, source_schema=SCHEMA, sampler=sampler, + ) + assert len(sampler_calls) == 2 + assert out[0].tier == 1 + + def test_cap_exceeded_downgrades_to_tier_2(self): + det = JoinDetector(sample_cap=3) + # FK sample has exactly cap → inconclusive, downgrade. 
+ samples = { + (SCHEMA, "patient", "patient_id"): {"a", "b", "c"}, + (SCHEMA, "sample", "patient_id"): {"a", "b", "c"}, + } + profiles = { + (SCHEMA, "patient", "patient_id"): (100, 100), + (SCHEMA, "sample", "patient_id"): (90, 200), + } + out = det.detect( + columns=self.cols, source_schema=SCHEMA, + samples=samples, profiles=profiles, + ) + assert out[0].tier == 2 + + def test_cap_exceeded_no_cardinality_downgrades_to_tier_3(self): + det = JoinDetector(sample_cap=2) + samples = { + (SCHEMA, "patient", "patient_id"): {"a", "b"}, + (SCHEMA, "sample", "patient_id"): {"a", "b"}, + } + out = det.detect( + columns=self.cols, source_schema=SCHEMA, samples=samples, + ) + assert out[0].tier == 3 + + def test_warehouse_error_falls_through_to_tier_3(self): + det = JoinDetector() + + def sampler(key): + raise RuntimeError("warehouse unreachable") + + out = det.detect( + columns=self.cols, source_schema=SCHEMA, sampler=sampler, + ) + assert out[0].tier == 3 + + +class TestMaterializationThreshold: + def setup_method(self): + self.cols = [ + _col("patient_id", "patient"), + _col("patient_id", "sample"), + ] + + def test_default_excludes_tier_3(self): + det = JoinDetector() + out = det.detect(columns=self.cols, source_schema=SCHEMA) + assert det.should_materialize(out[0]) is False + + def test_default_threshold_value(self): + assert DEFAULT_MATERIALIZE_THRESHOLD == 0.80 + + def test_default_includes_tier_2(self): + det = JoinDetector() + profiles = { + (SCHEMA, "patient", "patient_id"): (10, 10), + (SCHEMA, "sample", "patient_id"): (5, 30), + } + out = det.detect( + columns=self.cols, source_schema=SCHEMA, profiles=profiles, + ) + assert det.should_materialize(out[0]) is True + + def test_explicit_opt_in_includes_tier_3(self): + det = JoinDetector(materialization_threshold=0.70) + out = det.detect(columns=self.cols, source_schema=SCHEMA) + assert det.should_materialize(out[0]) is True + + +class TestAssertionEmission: + def 
test_to_fk_assertion_carries_source_schema_and_predicate(self): + det = JoinDetector() + cols = [ + _col("patient_id", "patient"), + _col("patient_id", "sample"), + ] + out = det.detect(columns=cols, source_schema=SCHEMA) + assertion = to_fk_assertion(out[0], run_id="run-1") + assert assertion.predicate == AssertionPredicate.FK_TO + assert assertion.source_schema == SCHEMA + assert assertion.confidence == TIER_3 + assert assertion.payload["fk_table"] == "sample" + assert assertion.payload["pk_table"] == "patient" + assert assertion.payload["tier"] == 3 diff --git a/tests/unit/test_join_materializer.py b/tests/unit/test_join_materializer.py new file mode 100644 index 0000000..cc9d7cc --- /dev/null +++ b/tests/unit/test_join_materializer.py @@ -0,0 +1,139 @@ +"""Tests for `sema.graph.join_materializer` source_schema threading.""" +from __future__ import annotations + +from datetime import datetime, timezone +from unittest.mock import MagicMock + +import pytest + +from sema.engine.join_detector import FKAssertion, to_fk_assertion +from sema.engine.join_detector_utils import FKCandidate +from sema.graph.join_materializer import ( + _build_join_path_records, + _derive_join_path_name, + materialize_join_paths, +) +from sema.models.assertions import ( + Assertion, + AssertionPredicate, + AssertionStatus, +) + +pytestmark = pytest.mark.unit + +SCHEMA = "cbioportal_msk_chord_2024" + + +def _join_assertion(name: str = "patient_to_sample"): + payload = { + "join_predicates": [{ + "left_table": "patient", "left_column": "patient_id", + "right_table": "sample", "right_column": "patient_id", + "operator": "=", + }], + "hop_count": 1, + "from_table": "databricks://ws/cat/sch/patient", + "to_table": "databricks://ws/cat/sch/sample", + } + return Assertion( + id=f"a-{name}", + subject_ref="databricks://ws/cat/sch/patient", + predicate=AssertionPredicate.HAS_JOIN_EVIDENCE, + payload=payload, + source="fk_detector", + confidence=0.95, + status=AssertionStatus.AUTO, + run_id="run-1", 
+ observed_at=datetime(2026, 1, 1, tzinfo=timezone.utc), + ) + + +def test_derive_join_path_name(): + name = _derive_join_path_name([ + {"left_table": "p", "left_column": "id", + "right_table": "s", "right_column": "p_id"}, + ]) + assert name == "p/id=s/p_id" + + +def test_build_join_path_records_skips_empty_group(): + records = _build_join_path_records({"x": []}) + assert records == [] + + +def test_materialize_threads_source_schema_to_batch(): + loader = MagicMock() + a = _join_assertion() + groups = {(a.subject_ref, a.predicate.value): [a]} + materialize_join_paths(loader, groups, source_schema=SCHEMA) + loader.add_join_path_uses.assert_called() + for call in loader.add_join_path_uses.call_args_list: + assert call.kwargs.get("source_schema") == SCHEMA + if loader.add_join_path_entity_links.called: + for call in loader.add_join_path_entity_links.call_args_list: + assert call.kwargs.get("source_schema") == SCHEMA + + +def test_materialize_skips_edge_writes_when_source_schema_missing(): + """Edge writes require source_schema; skip if absent (legacy).""" + loader = MagicMock() + a = _join_assertion() + groups = {(a.subject_ref, a.predicate.value): [a]} + materialize_join_paths(loader, groups) + loader.add_join_path_uses.assert_not_called() + loader.add_join_path_entity_links.assert_not_called() + + +def _fk_to_assertion(tier: int = 1, confidence: float = 0.95) -> Assertion: + candidate = FKCandidate( + pk_table="patient", pk_column="patient_id", + fk_table="sample", fk_column="patient_id", + pk_type="string", fk_type="string", + schema_name=SCHEMA, catalog="workspace", + ) + fk = FKAssertion( + candidate=candidate, confidence=confidence, + tier=tier, source_schema=SCHEMA, + ) + return to_fk_assertion(fk, run_id="run-1") + + +def test_fk_to_predicate_produces_join_path(): + loader = MagicMock() + a = _fk_to_assertion() + groups = {(a.subject_ref, a.predicate.value): [a]} + materialize_join_paths(loader, groups, source_schema=SCHEMA) + batch_calls = [ + c for c in 
loader.method_calls if c[0] == "_run" + ] + assert batch_calls, "expected JoinPath upsert to fire" + rows = batch_calls[0].kwargs["rows"] + assert len(rows) == 1 + rec = rows[0] + assert rec["name"] == "patient/patient_id=sample/patient_id" + assert rec["confidence"] == 0.95 + assert rec["source_schema"] == SCHEMA + loader.add_join_path_uses.assert_called() + loader.add_join_path_entity_links.assert_called() + for call in loader.add_join_path_uses.call_args_list: + assert call.kwargs.get("source_schema") == SCHEMA + + +def test_fk_to_and_legacy_join_evidence_coexist(): + loader = MagicMock() + legacy = _join_assertion(name="legacy") + fk = _fk_to_assertion() + groups = { + (legacy.subject_ref, legacy.predicate.value): [legacy], + (fk.subject_ref, fk.predicate.value): [fk], + } + materialize_join_paths(loader, groups, source_schema=SCHEMA) + rows = [ + c for c in loader.method_calls if c[0] == "_run" + ][0].kwargs["rows"] + names = sorted(r["name"] for r in rows) + assert names == [ + "patient/patient_id=sample/patient_id", + "patient/patient_id=sample/patient_id", + ] + assert len(rows) == 2 diff --git a/tests/unit/test_llm_client.py b/tests/unit/test_llm_client.py index 9b29986..9795abe 100644 --- a/tests/unit/test_llm_client.py +++ b/tests/unit/test_llm_client.py @@ -258,7 +258,10 @@ def invoke_side_effect(prompt): llm.invoke.side_effect = invoke_side_effect - client = LLMClient(llm, retry_max_attempts=3, retry_base_delay=0.01) + client = LLMClient( + llm, retry_max_attempts=3, retry_base_delay=0.01, + rate_limit_base_delay=0.01, + ) result = client.invoke("prompt", TableSummary) assert result.entity_name == "Patient" assert call_count[0] == 2 @@ -316,7 +319,10 @@ def test_max_attempts_exhausted(self): rate_error.status_code = 429 llm.invoke.side_effect = rate_error - client = LLMClient(llm, retry_max_attempts=3, retry_base_delay=0.01) + client = LLMClient( + llm, retry_max_attempts=3, retry_base_delay=0.01, + rate_limit_base_delay=0.01, + ) with 
pytest.raises(LLMStageError): client.invoke("prompt", TableSummary, table_ref="ref", stage_name="test") assert llm.invoke.call_count == 3 # 3 attempts in step 2 (no structured output) @@ -327,14 +333,13 @@ def test_max_attempts_exhausted(self): # --------------------------------------------------------------------------- class TestBackoffTiming: - def test_exponential_delay_pattern(self): + def test_exponential_delay_pattern_for_non_rate_limit_transient(self): llm = MagicMock(spec=["invoke"]) - rate_error = Exception("rate limit") - rate_error.status_code = 429 - llm.invoke.side_effect = rate_error + server_error = Exception("internal server error") + server_error.status_code = 500 + llm.invoke.side_effect = server_error sleep_times = [] - original_sleep = time.sleep with patch("sema.llm_client.time.sleep") as mock_sleep: mock_sleep.side_effect = lambda t: sleep_times.append(t) client = LLMClient( @@ -344,16 +349,15 @@ def test_exponential_delay_pattern(self): with pytest.raises(LLMStageError): client.invoke("prompt", TableSummary, table_ref="ref", stage_name="test") - # 2 sleeps (between attempt 1→2 and 2→3) assert len(sleep_times) == 2 - assert sleep_times[0] == pytest.approx(2.0, abs=0.1) # base * 2^0 - assert sleep_times[1] == pytest.approx(4.0, abs=0.1) # base * 2^1 + assert sleep_times[0] == pytest.approx(2.0, abs=0.1) + assert sleep_times[1] == pytest.approx(4.0, abs=0.1) - def test_jitter_within_range(self): + def test_jitter_within_range_for_non_rate_limit_transient(self): llm = MagicMock(spec=["invoke"]) - rate_error = Exception("rate limit") - rate_error.status_code = 429 - llm.invoke.side_effect = rate_error + server_error = Exception("internal server error") + server_error.status_code = 500 + llm.invoke.side_effect = server_error sleep_times = [] with patch("sema.llm_client.time.sleep") as mock_sleep: @@ -366,11 +370,109 @@ def test_jitter_within_range(self): client.invoke("prompt", TableSummary, table_ref="ref", stage_name="test") assert 
len(sleep_times) == 2 - # First delay: 2.0 ± 0.5 → [1.5, 2.5] assert 1.5 <= sleep_times[0] <= 2.5 - # Second delay: 4.0 ± 0.5 → [3.5, 4.5] assert 3.5 <= sleep_times[1] <= 4.5 + def test_rate_limit_uses_long_backoff_schedule(self): + llm = MagicMock(spec=["invoke"]) + rate_error = Exception("REQUEST_LIMIT_EXCEEDED: rate limit") + rate_error.status_code = 429 + llm.invoke.side_effect = rate_error + + sleep_times = [] + with patch("sema.llm_client.time.sleep") as mock_sleep: + mock_sleep.side_effect = lambda t: sleep_times.append(t) + client = LLMClient( + llm, retry_max_attempts=3, retry_base_delay=2.0, + retry_multiplier=2.0, retry_jitter=0.0, + ) + with pytest.raises(LLMStageError): + client.invoke("prompt", TableSummary, table_ref="ref", stage_name="test") + + assert len(sleep_times) == 2 + assert sleep_times[0] >= 10.0 + assert sleep_times[1] >= 30.0 + + +@pytest.mark.unit +class TestRateLimitClassification: + def test_request_limit_exceeded_message_classified_as_rate_limit(self): + from sema.llm_client import _is_rate_limit_error + + err = Exception("REQUEST_LIMIT_EXCEEDED: Exceeded workspace tokens/min") + assert _is_rate_limit_error(err) is True + + def test_429_status_code_classified_as_rate_limit(self): + from sema.llm_client import _is_rate_limit_error + + err = Exception("nope") + err.status_code = 429 + assert _is_rate_limit_error(err) is True + + def test_500_not_classified_as_rate_limit(self): + from sema.llm_client import _is_rate_limit_error + + err = Exception("internal server error") + err.status_code = 500 + assert _is_rate_limit_error(err) is False + + +@pytest.mark.unit +class TestCircuitBreakerInteraction: + def test_pure_rate_limit_failure_does_not_record_circuit_breaker_failure(self): + llm = MagicMock(spec=["invoke"]) + rate_error = Exception("REQUEST_LIMIT_EXCEEDED") + rate_error.status_code = 429 + llm.invoke.side_effect = rate_error + breaker = MagicMock() + breaker.check.return_value = None + + with patch("sema.llm_client.time.sleep"): + 
client = LLMClient( + llm, retry_max_attempts=2, retry_base_delay=0.01, + circuit_breaker=breaker, + ) + with pytest.raises(LLMStageError): + client.invoke("prompt", TableSummary, table_ref="ref", stage_name="test") + + breaker.record_failure.assert_not_called() + + def test_non_rate_limit_failure_still_records_circuit_breaker_failure(self): + llm = MagicMock(spec=["invoke"]) + server_error = Exception("internal server error") + server_error.status_code = 500 + llm.invoke.side_effect = server_error + breaker = MagicMock() + breaker.check.return_value = None + + with patch("sema.llm_client.time.sleep"): + client = LLMClient( + llm, retry_max_attempts=2, retry_base_delay=0.01, + circuit_breaker=breaker, + ) + with pytest.raises(LLMStageError): + client.invoke("prompt", TableSummary, table_ref="ref", stage_name="test") + + breaker.record_failure.assert_called_once() + + def test_success_records_circuit_breaker_success(self): + llm = MagicMock(spec=["invoke"]) + response = MagicMock() + response.content = '{"entity_name": "Patient"}' + llm.invoke.return_value = response + breaker = MagicMock() + breaker.check.return_value = None + + client = LLMClient( + llm, retry_max_attempts=1, retry_base_delay=0.01, + circuit_breaker=breaker, + ) + result = client.invoke("prompt", TableSummary) + + assert result.entity_name == "Patient" + breaker.record_success.assert_called_once() + breaker.record_failure.assert_not_called() + # --------------------------------------------------------------------------- # Configurable retry_max_attempts tests (Task 3.3) @@ -383,7 +485,10 @@ def test_custom_retry_count(self): rate_error.status_code = 429 llm.invoke.side_effect = rate_error - client = LLMClient(llm, retry_max_attempts=5, retry_base_delay=0.01) + client = LLMClient( + llm, retry_max_attempts=5, retry_base_delay=0.01, + rate_limit_base_delay=0.01, + ) with pytest.raises(LLMStageError): client.invoke("prompt", TableSummary, table_ref="ref", stage_name="test") assert 
llm.invoke.call_count == 5 @@ -394,7 +499,10 @@ def test_retries_disabled(self): rate_error.status_code = 429 llm.invoke.side_effect = rate_error - client = LLMClient(llm, retry_max_attempts=1, retry_base_delay=0.01) + client = LLMClient( + llm, retry_max_attempts=1, retry_base_delay=0.01, + rate_limit_base_delay=0.01, + ) with pytest.raises(LLMStageError): client.invoke("prompt", TableSummary, table_ref="ref", stage_name="test") assert llm.invoke.call_count == 1 diff --git a/tests/unit/test_loader_utils_columns.py b/tests/unit/test_loader_utils_columns.py new file mode 100644 index 0000000..4b92881 --- /dev/null +++ b/tests/unit/test_loader_utils_columns.py @@ -0,0 +1,59 @@ +"""Tests for `fetch_columns_by_schema` in `sema.graph.loader_utils`.""" +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +from sema.graph.loader_utils import fetch_columns_by_schema +from sema.models.extraction import ExtractedColumn + +pytestmark = pytest.mark.unit + + +def test_fetch_columns_by_schema_returns_extracted_columns(): + loader = MagicMock() + loader._run_read.return_value = [ + { + "name": "patient_id", "table_name": "patient", + "catalog": "workspace", + "schema_name": "cbioportal_msk_chord_2024", + "data_type": "string", "nullable": False, + "comment": None, + }, + { + "name": "patient_id", "table_name": "sample", + "catalog": "workspace", + "schema_name": "cbioportal_msk_chord_2024", + "data_type": "string", "nullable": True, + "comment": "fk to patient.patient_id", + }, + ] + + cols = fetch_columns_by_schema( + loader, "cbioportal_msk_chord_2024", + ) + + assert len(cols) == 2 + assert all(isinstance(c, ExtractedColumn) for c in cols) + assert cols[0].name == "patient_id" + assert cols[0].table_name == "patient" + assert cols[0].nullable is False + assert cols[1].comment == "fk to patient.patient_id" + + +def test_fetch_columns_passes_schema_filter_to_query(): + loader = MagicMock() + loader._run_read.return_value = [] + + 
fetch_columns_by_schema(loader, "some_schema") + + loader._run_read.assert_called_once() + _args, kwargs = loader._run_read.call_args + assert kwargs.get("schema_name") == "some_schema" + + +def test_fetch_columns_returns_empty_list_when_no_results(): + loader = MagicMock() + loader._run_read.return_value = [] + assert fetch_columns_by_schema(loader, "empty_schema") == [] diff --git a/tests/unit/test_materializer_utils.py b/tests/unit/test_materializer_utils.py index ca5f134..fc94d8d 100644 --- a/tests/unit/test_materializer_utils.py +++ b/tests/unit/test_materializer_utils.py @@ -442,6 +442,7 @@ def test_creates_hierarchy_edges(self): apply_resolution_edges(loader, groups) loader.add_term_hierarchy.assert_called_once_with( parent_code="NEOPLASM", child_code="CARCINOMA", + source_schema=None, ) def test_skips_rejected_hierarchy(self): diff --git a/tests/unit/test_naming.py b/tests/unit/test_naming.py new file mode 100644 index 0000000..c5103d0 --- /dev/null +++ b/tests/unit/test_naming.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import hashlib + +import pytest + +from sema.ingest.naming import sanitize_schema_name + + +@pytest.mark.unit +class TestSanitizeSchemaName: + def test_conforming_id_passes_through(self) -> None: + assert sanitize_schema_name("cbioportal", "msk_chord_2024") == "cbioportal_msk_chord_2024" + + def test_lowercases_input(self) -> None: + assert sanitize_schema_name("cbioportal", "MSK_CHORD") == "cbioportal_msk_chord" + + def test_replaces_hyphens_with_underscores(self) -> None: + assert ( + sanitize_schema_name("cbioportal", "BRCA-TCGA-Pan-Cancer") + == "cbioportal_brca_tcga_pan_cancer" + ) + + def test_replaces_non_identifier_chars(self) -> None: + assert sanitize_schema_name("cbioportal", "study/v1.0") == "cbioportal_study_v1_0" + + def test_collapses_runs_of_underscores(self) -> None: + assert sanitize_schema_name("cbioportal", "a---b___c") == "cbioportal_a_b_c" + + def test_strips_leading_and_trailing_underscores(self) -> 
None: + assert sanitize_schema_name("cbioportal", "--abc--") == "cbioportal_abc" + + def test_truncates_with_hash_when_over_63(self) -> None: + long_id = "x" * 100 + result = sanitize_schema_name("cbioportal", long_id) + assert len(result) <= 63 + digest10 = hashlib.sha256(long_id.encode("utf-8")).hexdigest()[:10] + assert result.endswith(f"_{digest10}") + assert result.startswith("cbioportal_") + + def test_hash_suffix_is_deterministic(self) -> None: + long_id = "study-" + ("z" * 100) + first = sanitize_schema_name("cbioportal", long_id) + second = sanitize_schema_name("cbioportal", long_id) + assert first == second + + def test_disambiguates_long_ids_sharing_prefix(self) -> None: + a = "x" * 80 + "_alpha" + b = "x" * 80 + "_beta" + result_a = sanitize_schema_name("cbioportal", a) + result_b = sanitize_schema_name("cbioportal", b) + assert result_a != result_b + assert len(result_a) <= 63 and len(result_b) <= 63 + + def test_short_collisions_not_resolved_by_hash(self) -> None: + # Hash suffix only kicks in when truncation applies. + # Short IDs that sanitize identically MUST collide so registry layer can fail fast. + first = sanitize_schema_name("cbioportal", "BRCA-TCGA") + second = sanitize_schema_name("cbioportal", "BRCA_TCGA") + assert first == second == "cbioportal_brca_tcga" + + def test_result_never_exceeds_63_chars(self) -> None: + for n in (50, 63, 64, 80, 200, 1000): + result = sanitize_schema_name("cbioportal", "x" * n) + assert len(result) <= 63, f"length {len(result)} for n={n}" + + def test_truncation_does_not_leave_trailing_underscore_before_hash(self) -> None: + # Construct an ID where the natural truncation point lands on an underscore. + # cbioportal_ prefix = 11 chars, _ suffix = 11 chars, leaves 41 for sanitized. 
+ long_id = "x" * 40 + "_" + "y" * 50 + result = sanitize_schema_name("cbioportal", long_id) + digest10 = hashlib.sha256(long_id.encode("utf-8")).hexdigest()[:10] + assert result.endswith(f"_{digest10}") + assert "__" not in result + + def test_raises_when_prefix_leaves_no_room(self) -> None: + with pytest.raises(ValueError): + sanitize_schema_name("a" * 60, "x" * 100) + + def test_empty_study_id_raises(self) -> None: + with pytest.raises(ValueError): + sanitize_schema_name("cbioportal", "") + + def test_study_id_only_specials_raises(self) -> None: + with pytest.raises(ValueError): + sanitize_schema_name("cbioportal", "---") diff --git a/tests/unit/test_orchestrate_fk_detection.py b/tests/unit/test_orchestrate_fk_detection.py new file mode 100644 index 0000000..21a224a --- /dev/null +++ b/tests/unit/test_orchestrate_fk_detection.py @@ -0,0 +1,203 @@ +"""Tests for FK detection wiring in `orchestrate_utils.run_fk_detection`.""" +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from sema.models.config import BuildConfig +from sema.models.extraction import ExtractedColumn +from sema.pipeline.orchestrate_utils import run_fk_detection + +pytestmark = pytest.mark.unit + +SCHEMA = "cbioportal_msk_chord_2024" + + +def _columns_with_fk_pair() -> list[ExtractedColumn]: + return [ + ExtractedColumn( + name="patient_id", table_name="patient", + catalog="workspace", schema=SCHEMA, data_type="string", + ), + ExtractedColumn( + name="patient_id", table_name="sample", + catalog="workspace", schema=SCHEMA, data_type="string", + ), + ] + + +def _build_config(**overrides) -> BuildConfig: + return BuildConfig( + catalog="workspace", schemas=[SCHEMA], **overrides, + ) + + +def _connector() -> MagicMock: + c = MagicMock() + c._execute = MagicMock(return_value=[]) + return c + + +def test_skips_when_disabled(): + loader = MagicMock() + config = _build_config(enable_fk_detection=False) + with patch( + 
"sema.pipeline.orchestrate_utils.fetch_columns_by_schema" + ) as fetch: + run_fk_detection( + loader, _connector(), config, [SCHEMA], run_id="r-1", + ) + fetch.assert_not_called() + + +def test_uses_default_threshold_080(): + loader = MagicMock() + config = _build_config() + captured: dict[str, float] = {} + with patch( + "sema.pipeline.orchestrate_utils.fetch_columns_by_schema", + return_value=[], + ), patch( + "sema.pipeline.orchestrate_utils.JoinDetector" + ) as DetectorCls: + DetectorCls.side_effect = lambda **kw: ( + captured.update(kw) or MagicMock() + ) + run_fk_detection( + loader, _connector(), config, [SCHEMA], run_id="r-1", + ) + assert captured["materialization_threshold"] == 0.80 + + +def test_threshold_drops_to_070_with_structural_opt_in(): + loader = MagicMock() + config = _build_config(materialize_structural_fk=True) + captured: dict[str, float] = {} + with patch( + "sema.pipeline.orchestrate_utils.fetch_columns_by_schema", + return_value=[], + ), patch( + "sema.pipeline.orchestrate_utils.JoinDetector" + ) as DetectorCls: + DetectorCls.side_effect = lambda **kw: ( + captured.update(kw) or MagicMock() + ) + run_fk_detection( + loader, _connector(), config, [SCHEMA], run_id="r-1", + ) + assert captured["materialization_threshold"] == 0.70 + + +def test_filters_by_should_materialize_then_calls_materializer(): + loader = MagicMock() + config = _build_config() + columns = _columns_with_fk_pair() + + above = MagicMock(name="above_threshold") + below = MagicMock(name="below_threshold") + detector_inst = MagicMock() + detector_inst.detect.return_value = [above, below] + detector_inst.should_materialize.side_effect = ( + lambda fk: fk is above + ) + + with patch( + "sema.pipeline.orchestrate_utils.fetch_columns_by_schema", + return_value=columns, + ), patch( + "sema.pipeline.orchestrate_utils.JoinDetector", + return_value=detector_inst, + ), patch( + "sema.pipeline.orchestrate_utils.to_fk_assertion" + ) as to_fk, patch( + 
"sema.pipeline.orchestrate_utils.materialize_join_paths" + ) as materialize: + to_fk.return_value = MagicMock( + subject_ref="ref-1", + predicate=MagicMock(value="fk_to"), + ) + run_fk_detection( + loader, _connector(), config, [SCHEMA], run_id="r-1", + ) + + to_fk.assert_called_once_with(above, "r-1") + materialize.assert_called_once() + _, kwargs = materialize.call_args + assert kwargs["source_schema"] == SCHEMA + + +def test_skips_schemas_with_no_columns(): + loader = MagicMock() + config = _build_config() + with patch( + "sema.pipeline.orchestrate_utils.fetch_columns_by_schema", + return_value=[], + ), patch( + "sema.pipeline.orchestrate_utils.materialize_join_paths" + ) as materialize: + run_fk_detection( + loader, _connector(), config, [SCHEMA], run_id="r-1", + ) + materialize.assert_not_called() + + +def test_passes_sampler_and_profiles_to_detector(): + loader = MagicMock() + config = _build_config() + columns = _columns_with_fk_pair() + connector = _connector() + + detector_inst = MagicMock() + detector_inst.detect.return_value = [] + + with patch( + "sema.pipeline.orchestrate_utils.fetch_columns_by_schema", + return_value=columns, + ), patch( + "sema.pipeline.orchestrate_utils.JoinDetector", + return_value=detector_inst, + ): + run_fk_detection( + loader, connector, config, [SCHEMA], run_id="r-1", + ) + + detect_kwargs = detector_inst.detect.call_args.kwargs + assert "sampler" in detect_kwargs + assert detect_kwargs["sampler"] is not None + assert "profiles" in detect_kwargs + assert isinstance(detect_kwargs["profiles"], dict) + + +def test_prebuilds_profiles_only_for_candidate_columns(): + """Profile lookup must run only for columns mentioned in candidates.""" + loader = MagicMock() + config = _build_config() + columns = _columns_with_fk_pair() + + profile_calls: list[tuple] = [] + + def fake_lookup(key): + profile_calls.append(key) + return (5, 10) + + detector_inst = MagicMock() + detector_inst.detect.return_value = [] + + with patch( + 
"sema.pipeline.orchestrate_utils.fetch_columns_by_schema", + return_value=columns, + ), patch( + "sema.pipeline.orchestrate_utils.JoinDetector", + return_value=detector_inst, + ), patch( + "sema.pipeline.orchestrate_utils.WarehouseProfileLookup", + return_value=fake_lookup, + ): + run_fk_detection( + loader, _connector(), config, [SCHEMA], run_id="r-1", + ) + + keys = set(profile_calls) + assert (SCHEMA, "patient", "patient_id") in keys + assert (SCHEMA, "sample", "patient_id") in keys diff --git a/tests/unit/test_parallel_execution.py b/tests/unit/test_parallel_execution.py index 4a32bc7..3f5cfd5 100644 --- a/tests/unit/test_parallel_execution.py +++ b/tests/unit/test_parallel_execution.py @@ -156,7 +156,7 @@ def process_one(work_item): work_item.table_name ) # Track what each loader commits - def capture_commit(assertions): + def capture_commit(assertions, source_schema=None): refs = [a.subject_ref for a in assertions] committed_refs.append( (work_item.table_name, refs) diff --git a/tests/unit/test_push_discovery.py b/tests/unit/test_push_discovery.py new file mode 100644 index 0000000..83dcb18 --- /dev/null +++ b/tests/unit/test_push_discovery.py @@ -0,0 +1,177 @@ +from __future__ import annotations + +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pyarrow as pa +import pytest +from pydantic import SecretStr + +from sema.ingest.databricks_push import Bridge +from sema.ingest.duckdb_staging import Staging +from sema.ingest.study_registry import StudyRegistry +from sema.models.config import ( + DatabricksConfig, + IngestConfig, + IngestDatabricksTargetConfig, +) + + +def _config(target_schemas: list[str] | None = None) -> IngestConfig: + creds = DatabricksConfig( + host="https://test.databricks.com", + token=SecretStr("token"), + http_path="/sql/1.0/warehouses/abc", + ) + target = IngestDatabricksTargetConfig(catalog="workspace") + if target_schemas is not None: + target = IngestDatabricksTargetConfig(catalog="workspace", 
schemas=target_schemas) + return IngestConfig(databricks=target, databricks_creds=creds) + + +def _mock_cursor() -> MagicMock: + cursor = MagicMock() + cursor.__enter__ = MagicMock(return_value=cursor) + cursor.__exit__ = MagicMock(return_value=False) + cursor.fetchone.return_value = (1,) + return cursor + + +def _mock_connection(cursor: MagicMock) -> MagicMock: + conn = MagicMock() + conn.cursor.return_value = cursor + return conn + + +def _seed_table(staging: Staging, schema: str, table: str = "patient") -> None: + staging.execute(f'CREATE SCHEMA IF NOT EXISTS "{schema}"') + staging.write_table( + schema=schema, + table=table, + rows=pa.table({"id": [1]}), + column_types={"id": "INTEGER"}, + column_comments={}, + table_comment=None, + ) + + +@pytest.mark.unit +class TestStagingListRegisteredSchemas: + def test_returns_shared_only_when_registry_absent(self, tmp_path: Path) -> None: + staging = Staging(str(tmp_path / "no_reg.duckdb")) + result = staging.list_registered_schemas() + assert sorted(result) == ["ontology_omop", "vocabulary_omop"] + + def test_unions_registry_with_shared(self, tmp_path: Path) -> None: + staging = Staging(str(tmp_path / "reg.duckdb")) + registry = StudyRegistry(staging) + registry.register("cbioportal_msk_chord_2024", "msk_chord_2024", "cbioportal") + result = staging.list_registered_schemas() + assert sorted(result) == [ + "cbioportal_msk_chord_2024", + "ontology_omop", + "vocabulary_omop", + ] + + def test_dedups_when_registry_lists_shared_schema(self, tmp_path: Path) -> None: + staging = Staging(str(tmp_path / "dedup.duckdb")) + registry = StudyRegistry(staging) + registry.register("ontology_omop", "shared_ontology", "omop") + result = staging.list_registered_schemas() + assert result.count("ontology_omop") == 1 + + +@pytest.mark.unit +class TestStagingListAllSchemas: + def test_excludes_system_schemas(self, tmp_path: Path) -> None: + staging = Staging(str(tmp_path / "sys.duckdb")) + result = staging.list_all_schemas() + assert "main" 
not in result + assert "information_schema" not in result + assert "pg_catalog" not in result + assert "_sema" not in result + assert "ontology_omop" in result + assert "vocabulary_omop" in result + + def test_includes_user_created_schemas(self, tmp_path: Path) -> None: + staging = Staging(str(tmp_path / "user.duckdb")) + staging.execute('CREATE SCHEMA "scratch_experiment"') + result = staging.list_all_schemas() + assert "scratch_experiment" in result + + +@pytest.mark.unit +class TestPushSchemasDefaults: + def test_default_uses_registry_and_shared_only(self, tmp_path: Path) -> None: + staging = Staging(str(tmp_path / "default.duckdb")) + registry = StudyRegistry(staging) + registry.register("cbioportal_msk_chord_2024", "msk_chord_2024", "cbioportal") + _seed_table(staging, "cbioportal_msk_chord_2024") + _seed_table(staging, "ontology_omop") + # scratch schema NOT registered; must NOT be pushed by default + _seed_table(staging, "my_experiments") + + cursor = _mock_cursor() + conn = _mock_connection(cursor) + with patch("sema.ingest.databricks_push.sql_connect", return_value=conn): + bridge = Bridge(_config(target_schemas=[]), staging=staging) + results = bridge.push_schemas() + + pushed_schemas = {r.schema for r in results} + assert "cbioportal_msk_chord_2024" in pushed_schemas + assert "ontology_omop" in pushed_schemas + assert "my_experiments" not in pushed_schemas + + def test_explicit_schemas_overrides_discovery(self, tmp_path: Path) -> None: + staging = Staging(str(tmp_path / "explicit.duckdb")) + registry = StudyRegistry(staging) + registry.register("cbioportal_a", "A", "cbioportal") + registry.register("cbioportal_b", "B", "cbioportal") + _seed_table(staging, "cbioportal_a") + _seed_table(staging, "cbioportal_b") + + cursor = _mock_cursor() + conn = _mock_connection(cursor) + with patch("sema.ingest.databricks_push.sql_connect", return_value=conn): + bridge = Bridge(_config(target_schemas=[]), staging=staging) + results = 
bridge.push_schemas(schemas=["cbioportal_a"]) + + pushed_schemas = {r.schema for r in results} + assert pushed_schemas == {"cbioportal_a"} + + def test_config_target_schemas_acts_as_filter(self, tmp_path: Path) -> None: + staging = Staging(str(tmp_path / "filter.duckdb")) + registry = StudyRegistry(staging) + registry.register("cbioportal_a", "A", "cbioportal") + registry.register("cbioportal_b", "B", "cbioportal") + _seed_table(staging, "cbioportal_a") + _seed_table(staging, "cbioportal_b") + _seed_table(staging, "ontology_omop") + + cursor = _mock_cursor() + conn = _mock_connection(cursor) + with patch("sema.ingest.databricks_push.sql_connect", return_value=conn): + bridge = Bridge(_config(target_schemas=["cbioportal_a"]), staging=staging) + results = bridge.push_schemas() + + pushed_schemas = {r.schema for r in results} + assert pushed_schemas == {"cbioportal_a"} + + def test_discover_all_schemas_includes_unregistered( + self, tmp_path: Path, caplog: pytest.LogCaptureFixture + ) -> None: + staging = Staging(str(tmp_path / "all.duckdb")) + registry = StudyRegistry(staging) + registry.register("cbioportal_a", "A", "cbioportal") + _seed_table(staging, "cbioportal_a") + _seed_table(staging, "scratch_local") + + cursor = _mock_cursor() + conn = _mock_connection(cursor) + with patch("sema.ingest.databricks_push.sql_connect", return_value=conn): + bridge = Bridge(_config(target_schemas=[]), staging=staging) + results = bridge.push_schemas(discover_all=True) + + pushed_schemas = {r.schema for r in results} + assert "cbioportal_a" in pushed_schemas + assert "scratch_local" in pushed_schemas diff --git a/tests/unit/test_slice_contamination.py b/tests/unit/test_slice_contamination.py new file mode 100644 index 0000000..9c382d3 --- /dev/null +++ b/tests/unit/test_slice_contamination.py @@ -0,0 +1,109 @@ +"""Tests for slice contamination check: holdout must not overlap with +few-shot source tables.""" +from __future__ import annotations + +from pathlib import Path + +import 
pytest +import yaml + +from scripts.check_slice_contamination import ( + ContaminationError, + check_contamination, + load_contamination_map, + load_slice_tables, +) + +pytestmark = pytest.mark.unit + + +def _write_yaml(path: Path, data: dict) -> Path: + path.write_text(yaml.safe_dump(data), encoding="utf-8") + return path + + +class TestLoadContaminationMap: + def test_loads_table_names(self, tmp_path: Path) -> None: + path = _write_yaml( + tmp_path / "contamination_map.yaml", + {"contaminated_tables": ["patient", "sample"]}, + ) + result = load_contamination_map(path) + assert result == {"patient", "sample"} + + def test_returns_empty_set_when_field_missing(self, tmp_path: Path) -> None: + path = _write_yaml(tmp_path / "x.yaml", {"version": 1}) + assert load_contamination_map(path) == set() + + +class TestLoadSliceTables: + def test_extracts_table_names(self, tmp_path: Path) -> None: + path = _write_yaml( + tmp_path / "slice.yaml", + { + "schema": "cbioportal_msk_chord_2024", + "tables": [ + {"table_name": "timeline_x", "tier": "standard"}, + {"table_name": "timeline_y", "tier": "edge"}, + ], + }, + ) + result = load_slice_tables(path) + assert result == {"timeline_x", "timeline_y"} + + def test_empty_tables_returns_empty_set(self, tmp_path: Path) -> None: + path = _write_yaml(tmp_path / "empty.yaml", {"tables": []}) + assert load_slice_tables(path) == set() + + +class TestCheckContamination: + def test_no_overlap_passes(self, tmp_path: Path) -> None: + contam = _write_yaml( + tmp_path / "contam.yaml", + {"contaminated_tables": ["patient", "sample"]}, + ) + holdout = _write_yaml( + tmp_path / "holdout.yaml", + {"tables": [{"table_name": "timeline_other"}]}, + ) + check_contamination(holdout, [contam]) + + def test_overlap_raises(self, tmp_path: Path) -> None: + contam = _write_yaml( + tmp_path / "contam.yaml", + {"contaminated_tables": ["patient"]}, + ) + holdout = _write_yaml( + tmp_path / "holdout.yaml", + {"tables": [{"table_name": "patient"}, {"table_name": 
"x"}]}, + ) + with pytest.raises(ContaminationError) as exc_info: + check_contamination(holdout, [contam]) + assert "patient" in str(exc_info.value) + + def test_dev_slice_overlap_with_holdout_raises(self, tmp_path: Path) -> None: + contam = _write_yaml( + tmp_path / "contam.yaml", + {"contaminated_tables": []}, + ) + dev = _write_yaml( + tmp_path / "dev.yaml", + {"tables": [{"table_name": "shared_table"}]}, + ) + holdout = _write_yaml( + tmp_path / "holdout.yaml", + {"tables": [{"table_name": "shared_table"}]}, + ) + with pytest.raises(ContaminationError) as exc_info: + check_contamination(holdout, [contam, dev]) + assert "shared_table" in str(exc_info.value) + + +class TestProjectSliceFilesAreClean: + def test_msk_chord_holdout_disjoint_from_few_shots_and_dev(self) -> None: + repo_root = Path(__file__).resolve().parents[2] + slices = repo_root / "showcase" / "cbioportal_to_omop" / "slices" + contam = slices / "contamination_map.yaml" + dev = slices / "msk_chord_dev.yaml" + holdout = slices / "msk_chord_holdout.yaml" + check_contamination(holdout, [contam, dev]) diff --git a/tests/unit/test_study_registry.py b/tests/unit/test_study_registry.py new file mode 100644 index 0000000..5deee5e --- /dev/null +++ b/tests/unit/test_study_registry.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +from sema.ingest.duckdb_staging import Staging +from sema.ingest.study_registry import ( + StudyCollisionError, + StudyRegistry, +) + + +@pytest.fixture +def staging(tmp_path: Path) -> Staging: + db_path = tmp_path / "test.duckdb" + return Staging(str(db_path)) + + +@pytest.mark.unit +class TestRegistryTable: + def test_initialises_table_in_sema_schema(self, staging: Staging) -> None: + StudyRegistry(staging) + rows = staging.execute( + "SELECT table_name FROM duckdb_tables() " + "WHERE schema_name = '_sema' AND table_name = '_sema_study_registry'" + ).fetchall() + assert len(rows) == 1 + + def test_init_is_idempotent(self, 
staging: Staging) -> None: + StudyRegistry(staging) + StudyRegistry(staging) + rows = staging.execute( + "SELECT count(*) FROM duckdb_tables() " + "WHERE schema_name = '_sema' AND table_name = '_sema_study_registry'" + ).fetchone() + assert rows[0] == 1 + + def test_register_persists_row(self, staging: Staging) -> None: + registry = StudyRegistry(staging) + registry.register( + schema_name="cbioportal_msk_chord_2024", + original_study_id="msk_chord_2024", + source_type="cbioportal", + ) + rows = staging.execute( + "SELECT schema_name, original_study_id, source_type " + "FROM _sema._sema_study_registry" + ).fetchall() + assert rows == [("cbioportal_msk_chord_2024", "msk_chord_2024", "cbioportal")] + + def test_register_same_study_twice_is_no_op(self, staging: Staging) -> None: + registry = StudyRegistry(staging) + registry.register("cbioportal_x", "X", "cbioportal") + registry.register("cbioportal_x", "X", "cbioportal") + rows = staging.execute( + "SELECT count(*) FROM _sema._sema_study_registry" + ).fetchone() + assert rows[0] == 1 + + def test_collision_raises(self, staging: Staging) -> None: + registry = StudyRegistry(staging) + registry.register("cbioportal_brca_tcga", "BRCA-TCGA", "cbioportal") + with pytest.raises(StudyCollisionError) as exc_info: + registry.register("cbioportal_brca_tcga", "BRCA_TCGA", "cbioportal") + msg = str(exc_info.value) + assert "BRCA-TCGA" in msg + assert "BRCA_TCGA" in msg + assert "cbioportal_brca_tcga" in msg + + def test_collision_does_not_overwrite_existing(self, staging: Staging) -> None: + registry = StudyRegistry(staging) + registry.register("cbioportal_brca_tcga", "BRCA-TCGA", "cbioportal") + with pytest.raises(StudyCollisionError): + registry.register("cbioportal_brca_tcga", "BRCA_TCGA", "cbioportal") + rows = staging.execute( + "SELECT original_study_id FROM _sema._sema_study_registry" + ).fetchall() + assert rows == [("BRCA-TCGA",)] + + def test_list_registered_schemas(self, staging: Staging) -> None: + registry = 
StudyRegistry(staging) + registry.register("cbioportal_a", "A", "cbioportal") + registry.register("cbioportal_b", "B", "cbioportal") + names = registry.list_schemas() + assert sorted(names) == ["cbioportal_a", "cbioportal_b"] + + def test_list_empty_when_no_registrations(self, staging: Staging) -> None: + registry = StudyRegistry(staging) + assert registry.list_schemas() == [] + + def test_register_records_created_at(self, staging: Staging) -> None: + registry = StudyRegistry(staging) + registry.register("cbioportal_x", "x", "cbioportal") + row = staging.execute( + "SELECT created_at FROM _sema._sema_study_registry" + ).fetchone() + assert row[0] is not None diff --git a/tests/unit/test_warehouse_lookup.py b/tests/unit/test_warehouse_lookup.py new file mode 100644 index 0000000..b635d27 --- /dev/null +++ b/tests/unit/test_warehouse_lookup.py @@ -0,0 +1,147 @@ +"""Tests for `WarehouseSampler` and `WarehouseProfileLookup`.""" +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +from sema.engine.warehouse_lookup import ( + WarehouseProfileLookup, + WarehouseSampler, +) + +pytestmark = pytest.mark.unit + +KEY = ("cbioportal_msk_chord_2024", "sample", "patient_id") + + +class TestWarehouseSampler: + def test_returns_set_of_distinct_string_values(self): + query_fn = MagicMock( + return_value=[("p1",), ("p2",), ("p3",)], + ) + sampler = WarehouseSampler( + query_fn=query_fn, catalog="workspace", + ) + assert sampler(KEY) == {"p1", "p2", "p3"} + + def test_issues_select_distinct_with_limit(self): + query_fn = MagicMock(return_value=[]) + sampler = WarehouseSampler( + query_fn=query_fn, catalog="workspace", sample_cap=500, + ) + sampler(KEY) + sql = query_fn.call_args[0][0] + assert "SELECT DISTINCT" in sql + assert "LIMIT 500" in sql + assert "`workspace`.`cbioportal_msk_chord_2024`.`sample`" in sql + assert "`patient_id`" in sql + + def test_caches_repeat_lookups(self): + query_fn = MagicMock(return_value=[("p1",)]) + sampler = 
WarehouseSampler( + query_fn=query_fn, catalog="workspace", + ) + sampler(KEY) + sampler(KEY) + sampler(KEY) + query_fn.assert_called_once() + + def test_returns_none_on_query_error(self): + query_fn = MagicMock(side_effect=RuntimeError("warehouse down")) + sampler = WarehouseSampler( + query_fn=query_fn, catalog="workspace", + ) + assert sampler(KEY) is None + + def test_caches_none_to_avoid_retrying_failed_queries(self): + query_fn = MagicMock(side_effect=RuntimeError("boom")) + sampler = WarehouseSampler( + query_fn=query_fn, catalog="workspace", + ) + sampler(KEY) + sampler(KEY) + query_fn.assert_called_once() + + def test_skips_null_values_in_sample(self): + query_fn = MagicMock(return_value=[("p1",), (None,), ("p2",)]) + sampler = WarehouseSampler( + query_fn=query_fn, catalog="workspace", + ) + assert sampler(KEY) == {"p1", "p2"} + + +class TestWarehouseProfileLookup: + def test_returns_distinct_and_row_count(self): + # First call: row count, second call: distinct + query_fn = MagicMock( + side_effect=[[(1000,)], [(742,)]], + ) + lookup = WarehouseProfileLookup( + query_fn=query_fn, catalog="workspace", + ) + result = lookup(KEY) + assert result == (742, 1000) + + def test_caches_row_count_per_table(self): + query_fn = MagicMock( + side_effect=[[(1000,)], [(742,)], [(900,)]], + ) + lookup = WarehouseProfileLookup( + query_fn=query_fn, catalog="workspace", + ) + lookup(("schema_x", "sample", "patient_id")) + lookup(("schema_x", "sample", "sample_id")) + # Third call should hit row-count cache → 3 queries total + # (1 count + 2 distinct), not 4 + assert query_fn.call_count == 3 + + def test_caches_full_result_per_column(self): + query_fn = MagicMock( + side_effect=[[(1000,)], [(742,)]], + ) + lookup = WarehouseProfileLookup( + query_fn=query_fn, catalog="workspace", + ) + lookup(KEY) + lookup(KEY) + assert query_fn.call_count == 2 + + def test_returns_none_on_row_count_error(self): + query_fn = MagicMock(side_effect=RuntimeError("denied")) + lookup = 
WarehouseProfileLookup( + query_fn=query_fn, catalog="workspace", + ) + assert lookup(KEY) is None + + def test_returns_none_on_distinct_error(self): + query_fn = MagicMock( + side_effect=[[(1000,)], RuntimeError("col missing")], + ) + lookup = WarehouseProfileLookup( + query_fn=query_fn, catalog="workspace", + ) + assert lookup(KEY) is None + + def test_issues_count_and_exact_distinct_sql(self): + """Use exact `COUNT(DISTINCT col)` rather than APPROX. + + APPROX_COUNT_DISTINCT has ~2-7% error on Databricks which + breaks `verify_cardinality`'s `pk_distinct == pk_rows` + uniqueness check. Exact distinct is single-pass aggregate + (NOT an RI scan), so it stays within the design's + "no unbounded RI scans" rule. + """ + query_fn = MagicMock( + side_effect=[[(1000,)], [(742,)]], + ) + lookup = WarehouseProfileLookup( + query_fn=query_fn, catalog="workspace", + ) + lookup(KEY) + first_sql = query_fn.call_args_list[0][0][0] + second_sql = query_fn.call_args_list[1][0][0] + assert "COUNT(*)" in first_sql + assert "COUNT(DISTINCT" in second_sql + assert "APPROX_COUNT_DISTINCT" not in second_sql + assert "`patient_id`" in second_sql