diff --git a/.gitignore b/.gitignore index de3e6e5..3dcf1b5 100644 --- a/.gitignore +++ b/.gitignore @@ -13,4 +13,7 @@ wheels/ docs/backup/ docs/omop_relationships.csv .vscode/ -.env \ No newline at end of file +.env +resources/ +*.DS_Store +logging/ \ No newline at end of file diff --git a/README.md b/README.md index 9c5f4f1..700e3bb 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,3 @@ -# Architecture - -This library provides a lightweight, query-time knowledge-graph layer over an OMOP vocabulary database, with explicit separation between: - -* graph access (nodes, edges, predicates), -* graph algorithms (traversal, pathfinding), -* path scoring and explanation, and -* presentation / inspection utilities. - # omop-graph **omop-graph** is a lightweight, opinionated knowledge-graph traversal and path-analysis library built on top of the OMOP vocabulary model. @@ -14,8 +5,8 @@ This library provides a lightweight, query-time knowledge-graph layer over an OM It provides: - a stable **KnowledgeGraph façade** over OMOP concepts and relationships - flexible **graph traversal** (forward, backward, bidirectional) -- **path discovery and ranking** with transparent scoring -- **traceable explanations** of why one path is preferred over another +- **path discovery** with transparent scoring +- **traceable explanations** of traversal decisions - multiple **rendering backends** (text, HTML, Mermaid) The library is designed for: @@ -31,105 +22,95 @@ The library is designed for: pip install omop-graph ``` +With embedding support (sqlite-vec backend, zero config): + +```bash +pip install "omop-graph[emb]" +``` + +For larger deployments use `[pgvector]` or `[faiss-cpu]` instead (or in addition). +Full setup is covered in the [omop-emb documentation](https://australiancancerdatanetwork.github.io/omop-emb/). + +--- + ## Core Concepts ### KnowledgeGraph -KnowledgeGraph is the main entry point. It wraps an existing SQLAlchemy session connected to an OMOP vocabulary schema. kg-core assumes OMOP semantics and tables. +`KnowledgeGraph` is the main entry point. It wraps a SQLAlchemy `Engine` connected to an OMOP vocabulary schema and provides a high-level Pythonic API over the relational tables. ```python +from sqlalchemy import create_engine from omop_graph.graph.kg import KnowledgeGraph -``` -### Nodes and Edges +engine = create_engine("postgresql://user:pass@localhost/omop") +kg = KnowledgeGraph(engine) -Nodes are OMOP Concepts; Edges are OMOP Concept_Relationships +# Lookup a concept by label +match_group = kg.label_lookup("Atrial Fibrillation", fuzzy=False) +concept = match_group.best_match +print(f"ID: {concept.concept_id}, Name: {concept.matched_label}") -Relationships are classified into semantic kinds: +# Traverse the hierarchy +parents = kg.parents(concept.concept_id) +``` + +### Nodes and Edges -* ONTOLOGICAL -* MAPPING -* ATTRIBUTE -* VERSIONING -* METADATA +Nodes are OMOP Concepts; Edges are OMOP Concept_Relationships. -This classification drives traversal and scoring. +Relationships are pre-classified into semantic kinds (`ClassIDEnum`): -### Traversal, Paths and Scoring +- `HIERARCHY` — parent/child ontological relationships +- `IDENTITY` — mapping to standard concepts +- `COMPOSITION` — part-of relationships +- `ASSOCIATION` — lateral clinical associations +- `ATTRIBUTE` — concept attribute relationships -You can: +This classification drives traversal filtering and scoring. -* expand neighbourhoods -* extract subgraphs -* trace traversal decisions -* control which relationship kinds are followed -* discover multiple candidate paths between concepts and rank them -* render simple HTML cards for easy interactive exploration +### Traversal and Paths ```python from omop_graph.graph.paths import find_shortest_paths from omop_graph.extensions.omop_alchemy import ClassIDEnum -ingredient = kg.concept_id_by_code("RxNorm", "6809") # Metformin -drug = kg.concept_id_by_code("RxNorm", "860975") # Metformin 500 MG Oral Tablet - -kg.concept_view(drug) # ConceptView(id=40163924, RxNorm:860975, name='24 HR metformin hydrochloride 500 MG Extended Release Oral Tablet') -kg.concept_view(ingredient) # ConceptView(id=1503297, RxNorm:6809, name='metformin') +ingredient = kg.concept_id_by_code("RxNorm", "6809") # Metformin +drug = kg.concept_id_by_code("RxNorm", "860975") # Metformin 500 MG Oral Tablet paths, trace = find_shortest_paths( kg, source=drug, target=ingredient, - predicate_kinds={ - ClassIDEnum.HIERARCHICAL, - ClassIDEnum.IDENTITY, - }, + predicate_kinds=frozenset({ClassIDEnum.HIERARCHY, ClassIDEnum.IDENTITY}), max_depth=6, traced=True, ) - -ranked = rank_paths(kg, paths) - -``` - -### - -```python -paths = kg.find_shortest_paths( - source=a, - target=b, - max_depth=6, -) -ranked = kg.rank_paths(paths) ``` ### Rendering -Outputs can be rendered as: +Outputs can be rendered as plain text, HTML (Jupyter), or Mermaid diagrams. Rendering auto-detects the environment. -* plain text (CLI / logs) -* HTML (Jupyter) -* Mermaid diagrams - -Rendering auto-detects the environment. - -```python +```python from IPython.display import HTML, display from omop_graph.render import render_trace display(HTML(render_trace(kg, trace))) ``` +--- + ## Project Structure -```graphql +``` omop_graph/ ├── graph/ # graph logic, traversal, paths, scoring ├── render/ # HTML / text / Mermaid renderers -├── reasoning/ # Ontology traversal methods for specific reasoner tasks -├────── resolvers/ # Resolve labels for exact / fuzzy / synonym matches - TODO: embedding matches -├────── phenotypes/ # Set operations to build efficient hierarchical groupings for reasoning +├── reasoning/ # ontology traversal methods for specific reasoner tasks +│ ├── resolvers/ # resolve labels via exact / fuzzy / full-text / synonym search +│ └── phenotypes/ # set operations for hierarchical groupings +├── oaklib_interface/ # OAK-compliant adapter ├── api.py # stable public API surface └── db/ # session helpers - -``` \ No newline at end of file +``` diff --git a/docs/predicate_classification.csv b/config/predicate_classification.csv similarity index 100% rename from docs/predicate_classification.csv rename to config/predicate_classification.csv diff --git a/docs/predicate_mapping.csv b/config/predicate_mapping.csv similarity index 100% rename from docs/predicate_mapping.csv rename to config/predicate_mapping.csv diff --git a/docker-compose.yaml b/docker-compose.yaml new file mode 100644 index 0000000..9035ad9 --- /dev/null +++ b/docker-compose.yaml @@ -0,0 +1,28 @@ +services: + omop-cdm-db: + image: postgres:16-alpine + restart: always + env_file: .env + environment: + - POSTGRES_USER=${OMOP_CDM_DB_USER:-omop} + - POSTGRES_PASSWORD=${OMOP_CDM_DB_PASSWORD:-omop} + - POSTGRES_DB=${OMOP_CDM_DB_NAME:-omop} + - PGDATA=/var/lib/postgresql/data/pgdata + volumes: + - db_data:/var/lib/postgresql/data + networks: + - omop-net + healthcheck: + test: ["CMD-SHELL", "pg_isready -U ${OMOP_CDM_DB_USER:-omop} -d ${OMOP_CDM_DB_NAME:-omop}"] + interval: 5s + timeout: 5s + retries: 5 + ports: + - "5432:5432" + +networks: + omop-net: + name: omop-net + +volumes: + db_data: diff --git a/docs/graph/edges.md b/docs/graph/edges.md index 8ac96a4..0093e25 100644 --- a/docs/graph/edges.md +++ b/docs/graph/edges.md @@ -16,10 +16,10 @@ To allow reproduction and evaluation of this approach, we provide clear guidelin ??? "Expand to see the grouping classification of predicates" - {{ to_grouped_table('docs/predicate_classification.csv', [0, 1], [0, 1, 2, 3, 4], [0, 1],) }} + {{ to_grouped_table('config/predicate_classification.csv', [0, 1], [0, 1, 2, 3, 4], [0, 1],) }} ## Predicate Mappings -Following the predicate classification guidelines of the previous seciton, we calssified the following predicates into their respective classification groups. +Following the predicate classification guidelines of the previous section, we classified the following predicates into their respective classification groups. !!! warning @@ -27,5 +27,5 @@ Following the predicate classification guidelines of the previous seciton, we ca ??? "Expand to see the classification of all edge connections" - {{ to_grouped_table('docs/predicate_mapping.csv', [0, 1], [0, 1, 2, 3], [0, 1], {"r_id": "relationship_id", "r_name": "relationship_name"}) }} + {{ to_grouped_table('config/predicate_mapping.csv', [0, 1], [0, 1, 2, 3], [0, 1], {"r_id": "relationship_id", "r_name": "relationship_name"}) }} \ No newline at end of file diff --git a/docs/graph/kg.md b/docs/graph/kg.md index 1aabfef..b9ca6e9 100644 --- a/docs/graph/kg.md +++ b/docs/graph/kg.md @@ -27,19 +27,14 @@ While the OMOP CDM is stored in a Relational Database Management System (RDBMS), ### Basic Usage -The `KnowledgeGraph` can be used standalone after connecting to the OMOP CDM database on disk. +The `KnowledgeGraph` can be used standalone after connecting to the OMOP CDM database. ```python from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker from omop_graph.graph.kg import KnowledgeGraph -# Setup your SQLAlchemy session engine = create_engine("postgresql://user:pass@localhost/omop") -SessionLocal = sessionmaker(bind=engine) - -# Initialize the Virtual Knowledge Graph -kg = KnowledgeGraph(SessionLocal) +kg = KnowledgeGraph(engine) # Lookup a concept by its label match_group = kg.label_lookup("Atrial Fibrillation", fuzzy=False) @@ -59,41 +54,53 @@ print(f"Parent IDs: {parents}") To enable semantic similarity and RAG-based retrieval, pass a `KnowledgeGraphEmbeddingConfiguration` when initialising the graph. This requires the optional `omop-emb` package — see the [installation guide](../usage/installation.md#embedding-rag). +!!! info "omop-emb documentation" + `omop-emb` manages all embedding storage, backends, and retrieval. Full documentation — including backend setup, CLI reference, FAISS sidecar, and configuration — is available at [australiancancerdatanetwork.github.io/omop-emb](https://australiancancerdatanetwork.github.io/omop-emb/). + #### Read-only (pre-computed embeddings already in the DB) Use this when embeddings have already been indexed and you only need retrieval: ```python +from sqlalchemy import create_engine from omop_graph.graph.kg import KnowledgeGraph, KnowledgeGraphEmbeddingConfiguration -from omop_emb import BackendType, ProviderType +from omop_emb.config import BackendType, MetricType, ProviderType + +engine = create_engine("postgresql://user:pass@localhost/omop") emb_config = KnowledgeGraphEmbeddingConfiguration( - backend_type=BackendType.FAISS, + backend_type=BackendType.PGVECTOR, # or BackendType.SQLITEVEC provider_type=ProviderType.OLLAMA, - canonical_model_name="text-embedding-3-small:0.6b", - base_storage_dir="/data/embeddings", + model_name="nomic-embed-text:v1.5", # must match the name used at ingestion time + metric_type=MetricType.COSINE, ) -kg = KnowledgeGraph(SessionLocal, emb_config=emb_config) +kg = KnowledgeGraph(engine, emb_config=emb_config) ``` +The backend is resolved from `backend_type` or, as a fallback, from the `OMOP_EMB_BACKEND` environment variable. +See the [omop-emb configuration reference](https://australiancancerdatanetwork.github.io/omop-emb/usage/configuration/) for all connection variables. + #### Write-capable (generate and store embeddings at runtime) -Provide an `EmbeddingClient` to enable both reading and writing embeddings: +Provide an `EmbeddingClient` to enable both reading and writing embeddings. The `provider_type` and `model_name` +are derived automatically from the client: ```python from omop_emb import EmbeddingClient -from omop_emb import BackendType, ProviderType +from omop_emb.config import BackendType, MetricType -client = EmbeddingClient(...) # configured for your provider +client = EmbeddingClient( + model="nomic-embed-text:v1.5", + api_base="http://ollama:11434/v1", +) emb_config = KnowledgeGraphEmbeddingConfiguration( - backend_type=BackendType.FAISS, - base_storage_dir="/data/embeddings", + backend_type=BackendType.PGVECTOR, + metric_type=MetricType.COSINE, client=client, ) -kg = KnowledgeGraph(SessionLocal, emb_config=emb_config) +kg = KnowledgeGraph(engine, emb_config=emb_config) ``` -The `provider_type` will be automatically determined from the `client`. #### Fallback embedding calculation @@ -107,12 +114,12 @@ for any missing concepts on-the-fly during a similarity call. ```python emb_config = KnowledgeGraphEmbeddingConfiguration( - backend_type="faiss", - base_storage_dir="/data/embeddings", + backend_type=BackendType.PGVECTOR, + metric_type=MetricType.COSINE, client=client, - compute_missing_embeddings=True, # compute embeddings for concepts not yet in the store + compute_missing_embeddings=True, ) -kg = KnowledgeGraph(SessionLocal, emb_config=emb_config) +kg = KnowledgeGraph(engine, emb_config=emb_config) ``` | `compute_missing_embeddings` | `client` present | Behaviour when concepts are missing | diff --git a/docs/index.md b/docs/index.md index 3f10243..c2b2b0a 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,32 +1,40 @@ # omop-graph -**omop-graph** is a lightweight virtual knowledge Graph (VKG) built on-top of the OMOP CDM. -It transforms the static OMOP vocabulary tables into a dynamic graph environment suitable for NLP grounding, clinical reasoning and other tasks that benefit from a knowledge graph. +**omop-graph** is a lightweight Virtual Knowledge Graph (VKG) built on top of the OMOP CDM. +It transforms the static OMOP vocabulary tables into a dynamic graph environment suitable for NLP grounding, clinical reasoning, and other tasks that benefit from a knowledge graph. ## Why omop-graph? Unlike generic graph libraries, `omop-graph` is built specifically for clinical data: -- **Semantic Awareness**: Understands the difference between relationships. -- **Efficient Grounding**: Instead of traversing every possible path, the library uses a **Standard Anchor** approach: translating non-standard terms to standard concepts and leveraging the OMOP `concept_ancestor` table for high-speed hierarchy validation. -- **Transparent Scoring**: Decisions aren't black boxes. Every path is scored based on textual similarity, graph distance (parsimony), and clinical generality (broadness). -- **Pre-classification**: Relationships are already pre-classified into overarching groups, allowing quicker restrictions of connections and more efficient graph traversal. +- **Semantic Awareness**: Understands the difference between relationship kinds (hierarchy, identity, composition, association, attribute). +- **Efficient Grounding**: Instead of traversing every possible path, the library uses a **Standard Anchor** approach — translating non-standard terms to standard concepts and leveraging the OMOP `concept_ancestor` table for high-speed hierarchy validation. +- **Transparent Scoring**: Decisions aren't black boxes. Every candidate concept is scored based on textual similarity, graph distance (parsimony), and clinical generality (broadness). +- **Pre-classification**: Relationships are pre-classified into semantic groups, enabling quicker traversal restrictions and more targeted reasoning. + --- ## Documentation Overview ### Core Components -- [KnowledgeGraph](graph/kg.md): The VKG interface and what it attempts to solve. -- [Relationships](graph/edges.md): Pre-classification of edges/relationships of the OMOP CDM. -- [Oaklib Interface](oaklib/interface.md): `oaklib`-compliant interface +- [KnowledgeGraph](graph/kg.md): The VKG interface — connecting to OMOP and traversing the graph. +- [Relationships](graph/edges.md): Pre-classification of OMOP edges into semantic kinds. +- [Oaklib Interface](oaklib/interface.md): OAK-compliant adapter for cross-ontology tooling. ### Reasoning Explore the grounding pipeline used by clinical NLP tools. -- [Semantic grounding](reasoning/grounding.md): How regular search terms can be traced to a standard Ontology +- [Semantic Grounding](reasoning/grounding.md): Mapping free-text terms to standard OMOP concepts. +- [Resolver Pipelines](reasoning/resolvers.md): How candidate concepts are retrieved from the database. + +### Embedding Support + +!!! info "Powered by omop-emb" + Embedding-based similarity (vector search, RAG retrieval, on-the-fly embedding computation) is provided by the companion [`omop-emb`](https://australiancancerdatanetwork.github.io/omop-emb/) package. + Install it with `pip install "omop-graph[emb]"` and see [Knowledge Graph — Embedding Configuration](graph/kg.md#embedding-configuration) for integration details. ### Interactive Exploration -`omop-graph` includes built-in HTML renderers for Jupyter Notebooks, allowing you to visualize concepts and relationship summaries instantly. +`omop-graph` includes built-in HTML and Mermaid renderers for Jupyter Notebooks, allowing you to visualise concepts, traversal traces, and relationship summaries directly in a notebook. ### Testing -- [Testing](usage/testing.md): How test configuration works, what is covered, and how to set up environment variables for local test runs. \ No newline at end of file +- [Testing](usage/testing.md): Test configuration, coverage, and how to set up environment variables for local runs. diff --git a/docs/oaklib/interface.md b/docs/oaklib/interface.md index d2b4c9a..473b5c6 100644 --- a/docs/oaklib/interface.md +++ b/docs/oaklib/interface.md @@ -32,7 +32,7 @@ The primary adapter class that inherits from multiple OAK interfaces: To initialize a connection, `omop-graph` uses a specialized resource factory: * **`OMOPOntologyResource`**: A dataclass that wraps the SQLAlchemy connection URL, treating the database as a live ontology source. -* **`omop_resource()`**: A factory function that resolves database credentials from an explicit URL or the `OMOP_DATABASE_URL` environment variable. +* **`omop_resource()`**: A factory function that resolves database credentials from an explicit URL or the `OMOP_CDM_DB_URL` environment variable. --- diff --git a/docs/reasoning/grounding.md b/docs/reasoning/grounding.md index 77935cf..387d59e 100644 --- a/docs/reasoning/grounding.md +++ b/docs/reasoning/grounding.md @@ -21,9 +21,9 @@ To accelerate the grounding to standard concepts, `omop-graph` makes use of: The following steps summarise the entire grounding approach and are found in `omop_graph.reasoning.grounding` 1. **Configuration**: Determine graph restrictions using [`GroundingConstraints`](#grounding-constraints) - - `parent_id`: The `concept_id` of the parent Ontology. This attribute is **required** and allows testing whether a standard concept is part of the correct branch - - `domains`: The OMOP CDM domains that are allowed to be searched for. Each Ontolgoy has an associated domain as described in the [OMOP CDM](https://ohdsi.github.io/CommonDataModel/cdm54.html#concept). Specifying multiple permits all specified domains. - - `vocabs`: The OMOP CDM vocabularies that are allowed to be searched for. Each Ontology is also part of a vocabulary as described in the [OMOP CDM](https://ohdsi.github.io/CommonDataModel/cdm54.html#concept). Specifying multiple values permits all specified vocabularies. + - `parent_ids`: OMOP Concept IDs that act as required ancestors — any valid result must be a descendant of at least one of these. + - `search_constraint`: A [`SearchConstraintConcept`](#searchconstraintconcept) that filters the initial resolver query by domain, vocabulary, and/or standard status. + - `max_depth` / `predicate_kinds`: Control how far and along which relationship kinds the anchor walk is allowed to travel. 2. **Resolve**: Use the [`ResolverPipeline`](resolvers.md) to find any concepts (Standard or Non-Standard) matching the text. 3. **Anchor**: For each candidate, find the nearest **Standard Concept**. This is required for Step 3 as all standard concepts are in `concept_ancestor`. @@ -35,22 +35,47 @@ To accelerate the grounding to standard concepts, `omop-graph` makes use of: - Details of scoring algorithm shown [here](#scoring) ## Grounding Constraints -You can restrict the search using `GroundingConstraints`: -- **parent_ids**: Only return concepts that fall under these ancestors (e.g., only search within "Procedures"). -- **search_constraint**: Limit search to specific vocabularies or domains (e.g., "RxNorm" only). -- `parent_ids`: Restricts the search to descendants of specific OMOP concepts (e.g., only search for concepts under `Condition`). -- `vocabs`: Restricts the search to specific vocabularies (e.g., `SNOMED`, `RxNorm`). -- `domains`: Restricts by OMOP Domain ID (e.g., `Condition`, `Drug`). + +`GroundingConstraints` is composed of two layers: + +| Field | Type | Default | Purpose | +|---|---|---|---| +| `parent_ids` | `tuple[int, ...]` | `None` | Only accept candidates that are descendants of these OMOP concept IDs (hierarchy validation via `concept_ancestor`). | +| `search_constraint` | `SearchConstraintConcept` | `None` | Filters applied to the initial resolver query (domain, vocabulary, standard flag). | +| `max_depth` | `int` | `6` | Maximum hop distance allowed between a candidate and its standard anchor. | +| `predicate_kinds` | `frozenset[ClassIDEnum]` | `{IDENTITY}` | Relationship kinds followed when walking from a non-standard candidate to its standard anchor. | + +### SearchConstraintConcept + +`SearchConstraintConcept` controls which concepts are even considered as candidates during the resolve step. All fields are optional and composable: + +| Field | Type | Default | Purpose | +|---|---|---|---| +| `concept_ids` | `tuple[int, ...]` | `None` | Restrict to a specific set of concept IDs. | +| `domains` | `tuple[str, ...]` | `None` | Restrict by OMOP Domain ID (e.g. `"Condition"`, `"Drug"`). | +| `vocabularies` | `tuple[str, ...]` | `None` | Restrict by Vocabulary ID (e.g. `"SNOMED"`, `"RxNorm"`). | +| `require_standard` | `bool` | `False` | When `True`, only concepts with `standard_concept` in `('S', 'C')` are returned. | +| `limit` | `int` | `None` | Cap the number of candidates returned from the resolver query. | + +### Example ```python from omop_graph.reasoning.grounding import ground_term, GroundingConstraints +from omop_graph.graph.constraints import SearchConstraintConcept +from omop_graph.extensions.omop_alchemy import ClassIDEnum constraints = GroundingConstraints( - parent_ids=(441484,), # 'Clinical Finding' - max_depth=6 + parent_ids=(441484,), # 'Clinical Finding' — only accept descendants of this ancestor + search_constraint=SearchConstraintConcept( + domains=("Condition",), + vocabularies=("SNOMED",), + require_standard=True, + ), + max_depth=6, + predicate_kinds=frozenset({ClassIDEnum.IDENTITY}), ) -results = ground_term(pipeline, kg, "chest pain", constraints) +results = ground_term(pipeline, kg, "chest pain", text_embedding=None, text_embedding_model=None, constraints=constraints) ``` ## Scoring @@ -65,10 +90,15 @@ TotalScore = Relevance - ParsimonyPenalty + BroadnessBonus $$ #### 1. Relevance -Relevance represents the initial semantic fit. It is the product of: +Relevance represents the initial semantic fit and is computed as **either** embedding similarity **or** textual similarity — not both simultaneously: + +- **Without embeddings**: textual similarity is used exclusively. +- **With embeddings** (default when `omop-graph[emb]` is installed and configured): embedding cosine similarity **replaces** the textual score entirely. + +The two scoring modes: -- **Embedding Similarity**: Cosine similarity between the input text and the concept name. -- **Textual Similarity**: A custom token-overlap score that heavily penalizes missing words from the user's query but allows for "extra" descriptive words in the OMOP concept name. +- **Embedding Similarity**: Cosine similarity between the input text embedding and the concept embedding. Requires `omop-graph[emb]` and a configured `KnowledgeGraphEmbeddingConfiguration` — see the [Knowledge Graph docs](../graph/kg.md#embedding-configuration) and the [omop-emb documentation](https://australiancancerdatanetwork.github.io/omop-emb/) for setup. +- **Textual Similarity**: A custom token-overlap score that heavily penalizes missing words from the user's query but allows for "extra" descriptive words in the OMOP concept name. Used as a fallback when no embedding is available. #### 2. Parsimony: Distance Penalty OMOP is a deep hierarchy. A concept that is 1 hop away from your search term is more likely to be correct than one found 5 hops away. @@ -88,10 +118,11 @@ Scoring is performed in a batch operation to minimize database overhead: ```python from omop_graph.graph.scoring import score_standard_concepts -ranked = score_standard_concepts( +scored = score_standard_concepts( text="Hodgkin lymphoma", standard_concepts=candidates, kg=kg, - similarity_scores=embeddings_array + nearest_concept_matches=nearest_matches, # optional; from omop-emb embedding index ) +ranked = sorted(scored, key=lambda s: s.total_score, reverse=True) ``` \ No newline at end of file diff --git a/docs/usage/cli.md b/docs/usage/cli.md index 04a1fd9..bb9a83f 100644 --- a/docs/usage/cli.md +++ b/docs/usage/cli.md @@ -4,69 +4,29 @@ The OMOP CDM instantiation tool provides a streamlined way to bootstrap a local --- -## `omop-cdm` {: #omop-cdm } +## `omop-maint load-vocab-source` {: #load-vocab-source } -Bootstrap the OMOP CDM and load reference data from Athena into a local database. +!!! info "Moved to `omop-maint`" + Vocabulary loading was previously exposed as `omop-graph omop-cdm`. It is now provided by the [`OMOP_Alchemy`](https://australiancancerdatanetwork.github.io/OMOP_Alchemy/) package under the `omop-maint` CLI. -If you want PostgreSQL full-text sidecars for `concept` and `concept_synonym`, pass -`--fulltext`. The command will install and populate the sidecars after the vocabulary -load finishes. +Load Athena vocabulary CSV files from a configured source path into the OMOP CDM database using the ORM staged CSV loader. -!!! danger "Warning" - This command will wipe the existing database in the target container before loading new data. - -### Prerequisites - -Before running the command, ensure your environment is configured with a `.env` file or exported variables: - -- **`OMOP_DATABASE_URL`**: SQLAlchemy connection string (e.g., `postgresql://user:pass@localhost:5432/omop`). -- **`SOURCE_PATH`**: Local directory path containing the Athena CSV files (e.g., `CONCEPT.csv`, `VOCABULARY.csv`). - -### Usage -If installed as a package: -```bash -omop-graph omop-cdm [--add-test-data] [--fulltext] --chunk-size= -``` - -**Example Usage:** -```bash -# Instantiate with test data and a custom chunk size of 10,000 -omop-graph omop-cdm --add-test-data --chunk-size=10000 -``` -```bash -# Display the help -omop-graph omop-cdm --help -``` - -### Command Arguments -| Argument | Type | Default | Description | -| :--- | :--- | :---: | :--- | -| **`--add-test-data`** | `Boolean` | False | Whether to add synthetic test data after loading Athena data.| -| **`--chunk-size`**, **`-c`** | `Integer` | `5000` | Number of rows to process in each chunk. Adjust based on your system's memory capacity to avoid OOM errors. | -| **`--fulltext`** | `Boolean` | False | Install and populate PostgreSQL full-text sidecars for `concept` and `concept_synonym` after the vocabulary load. | -| **`--fulltext-regconfig`** | `String` | `english` | PostgreSQL text search configuration used when populating the full-text sidecars. | - ---- ## `relationship-classification` {: #relationship-classification } This command ingests pre-defined relationship classifications and mappings into the database. It categorizes standard OMOP relationships into semantic groups (e.g., Hierarchical, Lateral, Mapping) to enable more intelligent graph reasoning. ### Rationale -The standard OMOP `relationship` table provides basic metadata, but lacks unified semantic "kinds" out of the box. This tool maps those relationships to a specific `ClassIDEnum` (like `EQUIVALENT`, `HIERARCHICAL`, or `IDENTITY`) and provides detailed inference descriptions used by the `KnowledgeGraph` facade. - -### Prerequisites - -The command expects two CSV files to be present in the target directory: +The standard OMOP `relationship` table provides basic metadata, but lacks unified semantic "kinds" out of the box. This tool maps those relationships to a specific `ClassIDEnum` (like `HIERARCHY`, `IDENTITY`, or `ASSOCIATION`) and provides detailed inference descriptions used by the `KnowledgeGraph` facade. ### Prerequisites Before running the command, ensure your environment is configured with a `.env` file or exported variables: -1. Prepopulated OMOP CDM (e.g. using command [`omop-cdm`](#omop-cdm)) +1. Prepopulated OMOP CDM (e.g. using [`omop-maint load-vocab-source`](#load-vocab-source)) 2. **`predicate_classification.csv`**: Defines the semantic classes and subclasses (descriptions, semantics, and inference rules). 3. **`predicate_mapping.csv`**: Maps specific OMOP `relationship_id`s to the classes defined in the classification file. 4. Set following environment variables: - - **`OMOP_DATABASE_URL`**: SQLAlchemy connection string (e.g., `postgresql://user:pass@localhost:5432/omop`). - - **`SOURCE_PATH`**: Local directory path containing the Athena CSV files (e.g., `CONCEPT.csv`, `VOCABULARY.csv`). This is required as the new connections/tables are stored there after creation. + - **`OMOP_CDM_DB_URL`**: SQLAlchemy connection string (e.g., `postgresql+psycopg://user:pass@localhost:5432/omop`). See [`omop-maint load-vocab-source` options](#load-vocab-source) for connection configuration details. + - **`OMOP_VOCABULARY_DIR`**: Local directory path where the generated classification tables will be written as CSV files. ### Usage diff --git a/docs/usage/installation.md b/docs/usage/installation.md index ca98f7d..1da882d 100644 --- a/docs/usage/installation.md +++ b/docs/usage/installation.md @@ -1,29 +1,50 @@ -# Installation: Core +# Installation -!!! note - - The dependency on uv is not strictly enforced and can be replaced using `pip install omop-graph`. - ``` - -The package can be regularly installed using `pip` and `uv`: +The package can be installed with `pip` or `uv`: ```bash +pip install omop-graph +# or uv pip install omop-graph ``` -## Installation: Embedding and RAG support (optional, recommended) {:#embedding-rag} +## Embedding and RAG support (optional, recommended) {:#embedding-rag} !!! tip - This is a recommended setting and improves the functionality of the library detrimental. + Installing with `[emb]` is recommended. Without it the `KnowledgeGraph` operates in text-only mode and all embedding-based similarity scoring is disabled. -The optional [`omop-emb` module](https://australiancancerdatanetwork.github.io/omop-emb/) can be installed using the option `[emb]`: ```bash -uv pip install omop-graph[emb] +pip install "omop-graph[emb]" +# or +uv pip install "omop-graph[emb]" ``` -This allows: +This pulls in [`omop-emb`](https://australiancancerdatanetwork.github.io/omop-emb/) with its default **sqlite-vec** backend — a file-based vector store that requires no external database server and works out of the box. + +It enables: + +- vector similarity search over OMOP concepts +- embedding-weighted grounding scores +- on-the-fly embedding computation for un-indexed concepts + +### Scaling up: pgvector or FAISS + +For larger deployments or approximate-nearest-neighbour acceleration, install the corresponding extra instead of (or alongside) `[emb]`: + +| Extra | What it adds | +|---|---| +| `omop-graph[emb]` | sqlite-vec backend (default, zero config) | +| `omop-graph[pgvector]` | PostgreSQL/pgvector backend | +| `omop-graph[faiss-cpu]` | FAISS sidecar for fast approximate search | + +These can be combined: + +```bash +pip install "omop-graph[pgvector,faiss-cpu]" +``` -- RAG-based retrieval -- semantic similarity searches -- graph reasoning -- (Future): Agentic LLM interfaces \ No newline at end of file +!!! info "omop-emb documentation" + Backend configuration, CLI reference, and index management are covered in the + [omop-emb documentation](https://australiancancerdatanetwork.github.io/omop-emb/usage/installation/). + The `[emb]` extra mirrors the base `omop-emb` install; `[pgvector]` and `[faiss-cpu]` mirror + `omop-emb[pgvector]` and `omop-emb[faiss-cpu]` respectively. diff --git a/docs/usage/testing.md b/docs/usage/testing.md index 61c1b0f..55230d6 100644 --- a/docs/usage/testing.md +++ b/docs/usage/testing.md @@ -44,7 +44,7 @@ The grounding test suite is now structured with parametrized cases so each clini Create a local `.env` file in the `omop-graph` repo root: ```dotenv -OMOP_DATABASE_URL=postgresql://omop:omop@db-omop:5432/omop +OMOP_CDM_DB_URL=postgresql://omop:omop@db-omop:5432/omop OMOP_OLLAMA_API_BASE=http://ollama:11434/v1 OMOP_EMB_BACKEND=pgvector ``` diff --git a/pyproject.toml b/pyproject.toml index c572e81..7a682b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,8 +18,6 @@ keywords = [ "knowledge-graph", "LLM-grounding", "real-world-evidence", - "rapidfuzz", - "omop-llm", ] classifiers = [ @@ -35,25 +33,29 @@ classifiers = [ ] dependencies = [ - "omop-alchemy>=0.6.0", - "orm-loader>=0.3.15", - "psycopg2-binary>=2.9.11", + "omop-alchemy>=0.6.0,<=0.6.2", + "orm-loader>=0.3.27,<0.4.0", "sqlalchemy>=2.0.45", "typing-extensions>=4.15.0", "typer", + "oaklib", ] [project.optional-dependencies] -faiss = [ - "omop-emb[faiss]==0.4.0", +emb = [ + "omop-emb>=1.0.0", +] +faiss-cpu = [ + "omop-emb[faiss-cpu]>=1.0.0", ] pgvector = [ - "omop-emb[pgvector]==0.4.0", + "omop-emb[pgvector]>=1.0.0", ] -emb = [ - "omop-emb[all]==0.4.0" +postgres = [ + "omop-alchemy[postgres]", ] + [dependency-groups] dev = [ "ipython>=9.8.0", @@ -62,7 +64,7 @@ dev = [ "pytest-cov>=7.0.0", "rich>=14.2.0", "ruff>=0.14.10", - "omop-emb[all]==0.4.0", + "omop-graph[emb,faiss-cpu,pgvector,postgres]", "mkdocs<2.0", "mkdocs-material", "mkdocstrings[python]", diff --git a/scripts/benchmarks/README.md b/scripts/benchmarks/README.md index e784fa9..6367adc 100644 --- a/scripts/benchmarks/README.md +++ b/scripts/benchmarks/README.md @@ -2,7 +2,7 @@ This benchmark evaluates resolver configurations against a live OMOP CDM database. -Set `OMOP_DATABASE_URL` or pass `--database-url` to point the benchmark at your local database. +Set `OMOP_CDM_DB_URL` or pass `--database-url` to point the benchmark at your local database. ## What It Measures diff --git a/scripts/benchmarks/__init__.py b/scripts/benchmarks/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/scripts/benchmarks/benchmark_base.py b/scripts/benchmarks/benchmark.py similarity index 56% rename from scripts/benchmarks/benchmark_base.py rename to scripts/benchmarks/benchmark.py index f66e298..cca64cf 100644 --- a/scripts/benchmarks/benchmark_base.py +++ b/scripts/benchmarks/benchmark.py @@ -11,15 +11,13 @@ import os from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, List, Optional, Sequence, Tuple, cast +from typing import Any, Dict, List, Optional, Sequence, Tuple, cast, Annotated +import typer import sqlalchemy as sa from dotenv import load_dotenv from sqlalchemy.orm import sessionmaker -import argparse -import statistics -import time import numpy as np from omop_emb.config import ( BackendType, @@ -29,11 +27,16 @@ parse_index_type, parse_metric_type, ) +from omop_emb.embeddings import ( + EmbeddingClient, + EmbeddingRole +) +from omop_emb.backends.index_config import index_config_from_index_type from omop_graph.cli import configure_logging_level -from omop_graph.extensions.emb import EmbeddingBackendType, MissingExtensionError +from omop_graph.extensions.emb import get_embedding_writer_interface, MissingExtensionError from omop_graph.extensions.omop_alchemy import ClassIDEnum from omop_graph.graph.constraints import SearchConstraintConcept -from omop_graph.graph.kg import KnowledgeGraph +from omop_graph.graph.kg import KnowledgeGraph, KnowledgeGraphEmbeddingConfiguration from omop_graph.graph.scoring import StandardConceptWithScore from omop_graph.reasoning.grounding import GroundingConstraints, ground_term from omop_graph.reasoning.resolvers.resolver_pipeline import ResolverPipeline @@ -47,8 +50,8 @@ PartialLabelResolver, PartialSynonymResolver, ) -from omop_emb import EmbeddingClient - +from omop_graph.db.session import make_engine +app = typer.Typer() DEFAULT_VOCABULARIES: Tuple[str, ...] = ("SNOMED", "ICDO3", "HemOnc") @@ -161,10 +164,10 @@ def build_session_factory(database_url: Optional[str]) -> sessionmaker: """Build a SQLAlchemy session factory for the configured OMOP database.""" load_dotenv() - resolved_url = database_url or os.getenv("OMOP_DATABASE_URL") + resolved_url = database_url or os.getenv("OMOP_CDM_DB_URL") if not resolved_url: raise RuntimeError( - "No database URL provided. Pass --database-url or set OMOP_DATABASE_URL." + "No database URL provided. Pass --database-url or set OMOP_CDM_DB_URL." ) engine = sa.create_engine(resolved_url, future=True, echo=False) @@ -175,10 +178,10 @@ def build_engine(database_url: Optional[str]) -> sa.Engine: """Build a SQLAlchemy engine for the configured OMOP database.""" load_dotenv() - resolved_url = database_url or os.getenv("OMOP_DATABASE_URL") + resolved_url = database_url or os.getenv("OMOP_CDM_DB_URL") if not resolved_url: raise RuntimeError( - "No database URL provided. Pass --database-url or set OMOP_DATABASE_URL." + "No database URL provided. Pass --database-url or set OMOP_CDM_DB_URL." ) return sa.create_engine(resolved_url, future=True, echo=False) @@ -186,25 +189,32 @@ def build_engine(database_url: Optional[str]) -> sa.Engine: def build_knowledge_graph(database_url: Optional[str]) -> KnowledgeGraph: """Create a KnowledgeGraph backed by the live OMOP CDM database.""" - - return KnowledgeGraph(session_factory=build_session_factory(database_url)) + return KnowledgeGraph(cdm_engine=make_engine(database_url)) def build_embedding_knowledge_graph( database_url: Optional[str], + embedding_metric: MetricType, + embedding_model: Optional[str], embedding_backend: Optional[str | BackendType], embedding_client: Optional[EmbeddingClient], - embedding_storage_base_dir: Optional[str], ) -> KnowledgeGraph: """Create a KnowledgeGraph with embedding support configured.""" - session_factory = build_session_factory(database_url) + cdm_engine = make_engine(database_url) resolved_embedding_backend = parse_backend_type(embedding_backend) if embedding_backend is not None else None + resolved_metric_type = parse_metric_type(embedding_metric) + + config = KnowledgeGraphEmbeddingConfiguration( + metric_type=resolved_metric_type, + backend_type=resolved_embedding_backend, + client=embedding_client, + compute_missing_embeddings=True, + model_name=embedding_model, + ) return KnowledgeGraph( - session_factory=session_factory, - emb_backend=resolved_embedding_backend, - emb_base_storage_dir=embedding_storage_base_dir, - emb_client=embedding_client, + cdm_engine=cdm_engine, + emb_config=config ) @@ -332,6 +342,80 @@ def _order_cases_for_report(cases: Sequence[BenchmarkCase]) -> List[BenchmarkCas return sorted(cases, key=lambda case: (_bucket_sort_key(case.bucket), case.id)) +def _summarise_config(rows: Sequence[Dict[str, Any]], label: str) -> Dict[str, Any]: + """Aggregate case-level ranking metrics into one configuration summary.""" + + if not rows: + return {"config": label, "count": 0} + n = len(rows) + return { + "config": label, + "count": n, + "top1_accuracy": sum(float(r["top1_correct"]) for r in rows) / n, + "mrr": sum(float(r["mrr"]) for r in rows) / n, + "recall_at_k": sum(float(r["recall_at_k"]) for r in rows) / n, + } + + +def _print_summary_report( + summaries: Dict[str, Dict[str, Any]], + bucket_summaries: Dict[str, Dict[str, Dict[str, Any]]], + significance: Dict[str, Dict[str, float]], + k: int, +) -> None: + """Print a formatted benchmark summary table to stdout.""" + + rk_label = f"R@{k}" + col_w = (35, 8, 8, 8, 6) + header = f"{'Config':<{col_w[0]}} {'Top-1':>{col_w[1]}} {'MRR':>{col_w[2]}} {rk_label:>{col_w[3]}} {'N':>{col_w[4]}}" + sep = "-" * len(header) + + def _row(name: str, s: Dict[str, Any]) -> str: + return ( + f"{name:<{col_w[0]}} " + f"{float(s.get('top1_accuracy', 0.0)):>{col_w[1]}.3f} " + f"{float(s.get('mrr', 0.0)):>{col_w[2]}.3f} " + f"{float(s.get('recall_at_k', 0.0)):>{col_w[3]}.3f} " + f"{int(s.get('count', 0)):>{col_w[4]}}" + ) + + print("\n=== Benchmark Summary ===") + print(header) + print(sep) + for config_name, summary in summaries.items(): + if int(summary.get("count", 0)) == 0: + continue + print(_row(config_name, summary)) + + all_buckets = sorted( + {b for bs in bucket_summaries.values() for b in bs}, + key=_bucket_sort_key, + ) + if all_buckets: + for bucket in all_buckets: + print(f"\n -- {bucket.upper()} --") + print(" " + header) + print(" " + sep) + for config_name, by_bucket in bucket_summaries.items(): + if bucket not in by_bucket: + continue + bs = by_bucket[bucket] + if int(bs.get("count", 0)) == 0: + continue + print(" " + _row(config_name, bs)) + + if significance: + print("\n--- McNemar significance tests ---") + for pair, result in significance.items(): + print( + f" {pair}: χ²={result.get('mcnemar_chi2_cc', 0.0):.3f}" + f" (a_only={int(result.get('a_only_correct', 0))}," + f" b_only={int(result.get('b_only_correct', 0))})" + ) + + print() + + def _grounded_element_to_dict( concept: StandardConceptWithScore, ) -> Dict[str, object]: @@ -343,7 +427,7 @@ def _grounded_element_to_dict( "relevance": float(concept.relevance), "embedding_score": float(concept.embedding_score) if concept.embedding_score is not None else 0.0, "separation": int(concept.separation), - "matched_label": concept.matched_label, + "matched_concept_label": concept.matched_concept_label, "match_kind": str(concept.match_kind), "synonym": concept.synonym, } @@ -384,18 +468,12 @@ def _evaluate_grounded_case( grounding_kwargs = grounding_kwargs or {} text_embedding = cast(Optional[np.ndarray], grounding_kwargs.get("text_embedding")) - text_embedding_model = cast(Optional[str], grounding_kwargs.get("text_embedding_model")) - embedding_client = cast(Optional[EmbeddingClient], grounding_kwargs.get("embedding_client")) - metric_type = cast(Optional[MetricType], grounding_kwargs.get("metric_type")) - index_type = cast(Optional[IndexType], grounding_kwargs.get("index_type")) grounded = ground_term( resolver_pipeline=resolver_pipeline, kg=kg, - text=case.text, - text_embedding=text_embedding, - text_embedding_model=text_embedding_model, - embedding_client=embedding_client, + query=case.text, + query_embedding=text_embedding, constraints=GroundingConstraints( parent_ids=parent_ids, search_constraint=search_constraint, @@ -403,8 +481,6 @@ def _evaluate_grounded_case( predicate_kinds=frozenset({ClassIDEnum.IDENTITY}), ), max_candidates=10, - metric_type=metric_type, - index_type=index_type, ) return { @@ -423,27 +499,70 @@ def _evaluate_grounded_case( } -def run_grounded_benchmark( - cases_path: Path, - k: int, - database_url: Optional[str] = None, - embedding_backend: Optional[str | BackendType] = None, - embedding_storage_base_dir: Optional[str] = None, - embedding_model: Optional[str] = None, - embedding_api_base: Optional[str] = None, - embedding_api_key: Optional[str] = None, - embedding_metric_type: str = "cosine", - embedding_index_type: str = "flat", - domain_filter: Optional[set[str]] = None, - vocab_filter: Optional[set[str]] = None, - grounding_parent_ids: Optional[Tuple[int, ...]] = None, -) -> Dict[str, object]: - """Run the grounded benchmark and return a JSON-serialisable report object.""" - - cases = load_cases(cases_path) - if domain_filter: +@app.command("Generalised benchmark interface.") +def run_benchmark( + cases_file: Annotated[str, typer.Option( + "--cases-file", "-c", + help="Path to the JSON file containing benchmark cases.") + ], + embedding_model: Annotated[str, typer.Option( + "--embedding-model", "-m", + help="Name of the embedding model to use (e.g., 'text-embedding-3-small').") + ], + embedding_api_base_url: Annotated[str, typer.Option( + "--embedding-api-base-url", "-u", + help="Base URL for the embedding API (e.g., 'http://localhost:8000').") + ], + embedding_api_key: Annotated[str, typer.Option( + "--embedding-api-key", "-k", + help="API key for the embedding service, if required.") + ], + embedding_metric_type: Annotated[str, typer.Option( + "--embedding-metric-type", "-M", + help="Distance metric type for embedding similarity (e.g., 'cosine').") + ], + embedding_index_type: Annotated[str, typer.Option( + "--embedding-index-type", "-I", + help="Index type for embedding retrieval (e.g., 'flat').") + ], + out_file: Annotated[Optional[str], typer.Option( + "--out-file", "-o", + help="Path to the output JSON file where results will be saved. If not provided, results will be printed to stdout.") + ] = None, + k: Annotated[int, typer.Option( + "--k", "-K", + help="Number of nearest neighbors to retrieve for each case.") + ] = 5, + allowed_domains: Annotated[Optional[str], typer.Option( + "--allowed-domains", "-D", + help="Comma-separated list of allowed OMOP domains to filter concepts (e.g., 'Condition,Drug'). If not provided, no domain filtering will be applied.") + ] = None, + allowed_vocabularies: Annotated[Optional[str], typer.Option( + "--allowed-vocabularies", "-V", + help="Comma-separated list of allowed vocabularies to filter concepts (e.g., 'SNOMED,LOINC'). If not provided, no vocabulary filtering will be applied.") + ] = None, + parent_ids: Annotated[Optional[str], typer.Option( + "--grounding-parent-ids", "-G", + help="Comma-separated list of OMOP concept IDs to use as parent nodes for grounding. If not provided, no parent ID filtering will be applied.") + ] = None, + database_url: Annotated[Optional[str], typer.Option( + "--database-url", "-d", + help="Database URL for the OMOP CDM database. If not provided, will use the environment variable specified by the library (e.g., OMOP_CDM_DATABASE_URL).") + ] = None, + embedding_backend: Annotated[Optional[str], typer.Option( + "--embedding-backend", "-e", + help="Embedding backend to use (e.g., 'sqlite_vec' or 'pgvector'). If not provided, will use the environment variable specified by the library (e.g., OMOP_EMB_BACKEND).") + ] = None, + verbosity: Annotated[int, typer.Option("--verbose", "-v", count=True, help="Increase verbosity (up to two levels)")] = 0, +): + configure_logging_level(verbosity) + + cases = load_cases(Path(cases_file)) + if allowed_domains: + domain_filter = set(allowed_domains.split(",")) cases = [c for c in cases if c.domain in domain_filter] - if vocab_filter: + if allowed_vocabularies: + vocab_filter = set(allowed_vocabularies.split(",")) cases = [ c for c in cases @@ -451,48 +570,51 @@ def run_grounded_benchmark( ] cases = _order_cases_for_report(cases) + grounding_parent_ids = tuple(map(int, parent_ids.split(","))) if parent_ids else None if grounding_parent_ids is None and all(c.parent_ids is None for c in cases): raise RuntimeError( - "No grounding parent IDs provided. Set --grounding-parent-id or add parent_ids per case." + "No grounding parent IDs provided." ) embedding_client = None embedding_kg = None query_embeddings: Dict[str, np.ndarray] = {} - engine = build_engine(database_url) resolved_embedding_index_type: Optional[IndexType] = None resolved_embedding_metric_type: Optional[MetricType] = None - if embedding_model is not None and embedding_api_base is not None: + if embedding_model is not None and embedding_api_base_url is not None: resolved_embedding_index_type = parse_index_type(embedding_index_type) resolved_embedding_metric_type = parse_metric_type(embedding_metric_type) embedding_client = EmbeddingClient( model=embedding_model, - api_base=embedding_api_base, + api_base=embedding_api_base_url, api_key=embedding_api_key or "ollama", ) canonical_model = embedding_client.provider.canonical_model_name(embedding_model) embedding_kg = build_embedding_knowledge_graph( database_url=database_url, + embedding_model=canonical_model, + embedding_metric=resolved_embedding_metric_type, embedding_backend=embedding_backend, embedding_client=embedding_client, - embedding_storage_base_dir=embedding_storage_base_dir, ) embedding_dim = embedding_client.embedding_dim if embedding_dim is None: raise RuntimeError("Embedding client did not expose an embedding dimension.") - - embedding_kg.emb.setup_and_register_model( - engine=engine, - canonical_model_name=canonical_model, - dimensions=embedding_dim, - index_type=IndexType(embedding_index_type), + + embedding_writer = get_embedding_writer_interface(embedding_kg) + assert embedding_writer is not None, "Embedding backend does not support writing embeddings, which is required for this benchmark configuration." + + embedding_writer.register_model( + index_config=index_config_from_index_type( + index_type=resolved_embedding_index_type, + ), ) query_embeddings = { - case.id: embedding_kg.emb.embed_texts(case.text) + case.id: embedding_writer.embed_texts(case.text, embedding_role=EmbeddingRole.QUERY) for case in cases } @@ -500,48 +622,48 @@ def run_grounded_benchmark( configs = build_grounded_configs() errors: Dict[str, str] = {} - case_reports: List[Dict[str, object]] = [] + case_reports: List[Dict[str, Any]] = [] active_kg = embedding_kg if embedding_kg is not None else kg for case in cases: - config_results: List[Dict[str, object]] = [] + config_results: List[Dict[str, Any]] = [] for config in configs: - try: - if config.requires_embedding and embedding_kg is None: - raise MissingExtensionError( - "Embedding config requires omop-emb plus embedding model/api settings." - ) - - grounding_kwargs: Optional[Dict[str, object]] = None - if embedding_kg is not None and embedding_model is not None: - grounding_kwargs = { - "text_embedding": query_embeddings.get(case.id), - "text_embedding_model": embedding_model, - "embedding_client": embedding_client, - "metric_type": resolved_embedding_metric_type, - "index_type": resolved_embedding_index_type, - } - - row = _evaluate_grounded_case( - kg=active_kg, - case=case, - config=config, - default_parent_ids=grounding_parent_ids, - grounding_kwargs=grounding_kwargs, - ) - config_results.append(row) - - except Exception as exc: - errors[f"{case.id}:{config.name}"] = str(exc) - config_results.append( - { - "config": config.name, - "error": str(exc), - "predicted_top": {"concept_id": None, "concept_name": None, "total_score": 0.0, "relevance": 0.0, "embedding_score": 0.0}, - "target_total_score": 0.0, - } + #try: + if config.requires_embedding and embedding_kg is None: + raise MissingExtensionError( + "Embedding config requires omop-emb plus embedding model/api settings." ) + grounding_kwargs: Optional[Dict[str, object]] = None + if embedding_kg is not None and embedding_model is not None: + grounding_kwargs = { + "text_embedding": query_embeddings.get(case.id), + "text_embedding_model": embedding_model, + "embedding_client": embedding_client, + "metric_type": resolved_embedding_metric_type, + "index_type": resolved_embedding_index_type, + } + + row = _evaluate_grounded_case( + kg=active_kg, + case=case, + config=config, + default_parent_ids=grounding_parent_ids, + grounding_kwargs=grounding_kwargs, + ) + config_results.append(row) + + #except Exception as exc: + # errors[f"{case.id}:{config.name}"] = str(exc) + # config_results.append( + # { + # "config": config.name, + # "error": str(exc), + # "predicted_top": {"concept_id": None, "concept_name": None, "total_score": 0.0, "relevance": 0.0, "embedding_score": 0.0}, + # "target_total_score": 0.0, + # } + # ) + case_reports.append( { "case_id": case.id, @@ -553,149 +675,78 @@ def run_grounded_benchmark( } ) - return { + # Build per-config ranking rows for summary statistics. + per_config: Dict[str, List[Dict[str, Any]]] = {} + for case_report in case_reports: + for cfg_result in case_report["config_results"]: + config_name = str(cfg_result.get("config", "")) + if "error" in cfg_result: + continue + target_idx = cfg_result.get("target_idx_in_grounded") + expected_id = case_report["expected_concept_id"] + if expected_id is None or target_idx is None: + top1, mrr_val, rak = 0.0, 0.0, 0.0 + else: + top1 = 1.0 if target_idx == 0 else 0.0 + mrr_val = 1.0 / (int(target_idx) + 1) + rak = 1.0 if int(target_idx) < k else 0.0 + per_config.setdefault(config_name, []).append({ + "case_id": case_report["case_id"], + "bucket": case_report["bucket"], + "top1_correct": top1, + "mrr": mrr_val, + "recall_at_k": rak, + }) + + summaries = {name: _summarise_config(rows, name) for name, rows in per_config.items()} + + bucket_summaries: Dict[str, Dict[str, Dict[str, Any]]] = {} + for config_name, rows in per_config.items(): + by_bucket: Dict[str, List[Dict[str, Any]]] = {} + for row in rows: + b = str(row["bucket"]) + by_bucket.setdefault(b, []).append(row) + bucket_summaries[config_name] = { + b: _summarise_config(b_rows, f"{config_name}:{b}") + for b, b_rows in by_bucket.items() + } + + significance: Dict[str, Dict[str, float]] = {} + if "basic" in per_config and "extended" in per_config: + significance["basic_vs_extended"] = mcnemar(per_config["basic"], per_config["extended"]) + if "extended" in per_config and "full_text" in per_config: + significance["extended_vs_full_text"] = mcnemar(per_config["extended"], per_config["full_text"]) + if "full_text" in per_config and "full_text_with_embedding" in per_config: + significance["full_text_vs_full_text_with_embedding"] = mcnemar( + per_config["full_text"], per_config["full_text_with_embedding"] + ) + + report = { "cases_evaluated": len(cases), "cases": case_reports, + "summaries": summaries, + "bucket_summaries": bucket_summaries, + "significance": significance, "errors": errors, - "database_url": database_url or os.getenv("OMOP_DATABASE_URL"), + "database_url": database_url or os.getenv("OMOP_CDM_DB_URL"), "embedding_model": embedding_model, "embedding_backend": embedding_backend, - "embedding_storage_base_dir": embedding_storage_base_dir, "embedding_metric_type": embedding_metric_type, "embedding_index_type": embedding_index_type, "grounding_parent_ids": grounding_parent_ids, "k": k, } + output = json.dumps(report, indent=2) + out_str = f"Results for {len(cases)} cases across {len(configs)} configs." + if out_file is not None: + with open(out_file, "w", encoding="utf-8") as f: + f.write(output) + out_str += f" Results saved to {out_file}" -def build_grounded_benchmark_parser( - *, - description: str, - default_cases: Path, - default_out: Optional[Path] = None, -) -> argparse.ArgumentParser: - """Build the CLI parser for grounded benchmark scripts.""" - - parser = argparse.ArgumentParser(description=description) - parser.add_argument( - "--cases", - type=Path, - default=default_cases, - help="Path to benchmark case JSON file.", - ) - parser.add_argument( - "--database-url", - type=str, - default=None, - help="SQLAlchemy database URL for the local OMOP CDM.", - ) - parser.add_argument( - "--embedding-backend", - type=str, - default=os.getenv("OMOP_EMB_BACKEND"), - help="Embedding backend. Defaults to OMOP_EMB_BACKEND.", - ) - parser.add_argument( - "--embedding-storage-base-dir", - type=str, - default=os.getenv("OMOP_EMB_BASE_STORAGE_DIR"), - help="Optional base directory for file-backed embedding backends.", - ) - parser.add_argument( - "--embedding-model", - type=str, - default=os.getenv("OMOP_EMB_MODEL"), - help="Embedding model name.", - ) - parser.add_argument( - "--embedding-api-base", - type=str, - default=os.getenv("OMOP_OLLAMA_API_BASE"), - help="OpenAI-compatible API base for embedding calls.", - ) - parser.add_argument( - "--embedding-api-key", - type=str, - default=os.getenv("OMOP_OLLAMA_API_KEY"), - help="Embedding API key.", - ) - parser.add_argument( - "--embedding-metric-type", - type=str, - default="cosine", - help="Embedding similarity metric for retrieval/scoring.", - ) - parser.add_argument( - "--embedding-index-type", - type=str, - default="flat", - help="Embedding index type for retrieval/scoring.", - ) - parser.add_argument("--k", type=int, default=5, help="K for Recall@K.") - parser.add_argument( - "--grounding-parent-id", - type=int, - action="append", - default=None, - help="Grounding parent concept ID (repeatable).", - ) - parser.add_argument( - "--domain", - action="append", - default=None, - help="Optional domain filter (repeatable).", - ) - parser.add_argument( - "--vocabulary", - action="append", - default=None, - help="Optional vocabulary filter (repeatable).", - ) - parser.add_argument( - "-v", - "--verbose", - action="count", - default=0, - help="Increase logging verbosity; use -vv for DEBUG output.", - ) - parser.add_argument("--out", type=Path, default=default_out, help="Optional output JSON path.") - return parser - - -def run_grounded_benchmark_cli( - *, - description: str, - default_cases: Path, - default_out: Path, -) -> None: - """Parse CLI args, run the grounded benchmark, and write JSON output.""" + _print_summary_report(summaries, bucket_summaries, significance, k) + print(out_str) - parser = build_grounded_benchmark_parser( - description=description, - default_cases=default_cases, - default_out=default_out, - ) - args = parser.parse_args() - - configure_logging_level(args.verbose) - - report = run_grounded_benchmark( - cases_path=args.cases, - k=args.k, - database_url=args.database_url, - embedding_backend=args.embedding_backend, - embedding_storage_base_dir=args.embedding_storage_base_dir, - embedding_model=args.embedding_model, - embedding_api_base=args.embedding_api_base, - embedding_api_key=args.embedding_api_key, - embedding_metric_type=args.embedding_metric_type, - embedding_index_type=args.embedding_index_type, - domain_filter=set(args.domain) if args.domain else None, - vocab_filter=set(args.vocabulary) if args.vocabulary else None, - grounding_parent_ids=(tuple(args.grounding_parent_id) if args.grounding_parent_id else None), - ) - output = json.dumps(report, indent=2) - print(output) - if args.out is not None: - args.out.write_text(output + "\n", encoding="utf-8") +if __name__ == "__main__": + app() \ No newline at end of file diff --git a/scripts/benchmarks/benchmark_cancer_nsw.py b/scripts/benchmarks/benchmark_cancer_nsw.py deleted file mode 100644 index ba0ca21..0000000 --- a/scripts/benchmarks/benchmark_cancer_nsw.py +++ /dev/null @@ -1,19 +0,0 @@ -"""Run the grounded benchmark with the cancer NSW case set by default.""" - -from __future__ import annotations - -from pathlib import Path - -from benchmark_base import run_grounded_benchmark_cli - - -def main() -> None: - run_grounded_benchmark_cli( - description="Grounding benchmark for the cancer NSW case set.", - default_cases=Path(__file__).with_name("cancer_nsw_cases.json"), - default_out=Path("/home/vscode/benchmark_cancer_nsw.json"), - ) - - -if __name__ == "__main__": - main() diff --git a/scripts/benchmarks/cancer_nsw_cases.json b/scripts/benchmarks/benchmark_cases/cancer_nsw_cases.json similarity index 100% rename from scripts/benchmarks/cancer_nsw_cases.json rename to scripts/benchmarks/benchmark_cases/cancer_nsw_cases.json diff --git a/scripts/benchmarks/latest_report.json b/scripts/benchmarks/benchmark_cases/latest_report.json similarity index 100% rename from scripts/benchmarks/latest_report.json rename to scripts/benchmarks/benchmark_cases/latest_report.json diff --git a/scripts/benchmarks/resolver_cases.json b/scripts/benchmarks/benchmark_cases/resolver_cases.json similarity index 100% rename from scripts/benchmarks/resolver_cases.json rename to scripts/benchmarks/benchmark_cases/resolver_cases.json diff --git a/scripts/benchmarks/benchmark_poster.py b/scripts/benchmarks/benchmark_poster.py deleted file mode 100644 index 1543c31..0000000 --- a/scripts/benchmarks/benchmark_poster.py +++ /dev/null @@ -1,19 +0,0 @@ -"""Grounding-focused benchmark tailored for poster/showcase outputs.""" - -from __future__ import annotations - -from pathlib import Path - -from benchmark_base import run_grounded_benchmark_cli - - -def main() -> None: - run_grounded_benchmark_cli( - description="Grounding benchmark for publication/poster showcases.", - default_cases=Path(__file__).with_name("resolver_cases.json"), - default_out=Path("/home/vscode/benchmark_poster.json"), - ) - - -if __name__ == "__main__": - main() diff --git a/scripts/benchmarks/benchmark_resolvers.py b/scripts/benchmarks/benchmark_resolvers.py deleted file mode 100644 index b298b5c..0000000 --- a/scripts/benchmarks/benchmark_resolvers.py +++ /dev/null @@ -1,497 +0,0 @@ -"""Resolver benchmark against a live OMOP CDM database. - -This benchmark measures how the resolver pipeline performs on real OMOP data. -Each case supplies an input phrase, a difficulty bucket, optional domain and -vocabulary constraints, and the expected OMOP concept ID. The benchmark does -not simulate resolver output; it runs the actual resolvers against the local -database configured by ``OMOP_DATABASE_URL`` or ``--database-url``. -""" - -from __future__ import annotations - -import argparse -import json -import os -import statistics -import time -from dataclasses import dataclass -from pathlib import Path -from typing import Any, Dict, List, Optional, Sequence, Tuple - -import numpy as np -import sqlalchemy as sa -from dotenv import load_dotenv -from sqlalchemy.orm import sessionmaker - -from omop_graph.graph.constraints import SearchConstraintConcept -from omop_graph.extensions.emb import EmbeddingBackendType, MissingExtensionError -from omop_graph.graph.kg import KnowledgeGraph -from omop_graph.reasoning.resolvers.resolver_pipeline import ResolverPipeline -from omop_graph.reasoning.resolvers.resolvers import ( - CandidateResolver, - EmbeddingResolver, - ExactLabelResolver, - ExactSynonymResolver, - FullTextResolver, - FullTextSynonymResolver, - PartialLabelResolver, - PartialSynonymResolver, -) -from omop_emb import EmbeddingClient - - -@dataclass(frozen=True) -class BenchmarkCase: - """One real benchmark example and its expected ground-truth concept ID.""" - - id: str - text: str - bucket: str - domain: str - vocabulary: str - expected_concept_id: Optional[int] - expected_concept_name: Optional[str] = None - - -@dataclass(frozen=True) -class BenchmarkConfig: - """One resolver ablation evaluated by the benchmark.""" - - name: str - resolvers: Tuple[CandidateResolver, ...] - requires_embedding: bool = False - - -def _load_cases(path: Path) -> List[BenchmarkCase]: - """Load benchmark cases from JSON into typed dataclass instances. - - The loader accepts either a legacy flat list of cases or a bucketed mapping - of bucket name to list of cases. Bucket names are attached to each case at - load time so the evaluation/reporting code can stay unchanged. - """ - - payload = json.loads(path.read_text(encoding="utf-8")) - - if isinstance(payload, list): - return [BenchmarkCase(**row) for row in payload] - - if isinstance(payload, dict): - cases: List[BenchmarkCase] = [] - for bucket, bucket_cases in payload.items(): - for row in bucket_cases: - cases.append(BenchmarkCase(bucket=bucket, **row)) - return cases - - raise TypeError(f"Unsupported benchmark case file shape: {type(payload).__name__}") - - -def _build_session_factory(database_url: Optional[str]) -> sessionmaker: - """Build a SQLAlchemy session factory for the configured OMOP database.""" - - load_dotenv() - resolved_url = database_url or os.getenv("OMOP_DATABASE_URL") - if not resolved_url: - raise RuntimeError( - "No database URL provided. Pass --database-url or set OMOP_DATABASE_URL." - ) - - engine = sa.create_engine(resolved_url, future=True, echo=False) - return sessionmaker(bind=engine, future=True) - - -def _build_knowledge_graph(database_url: Optional[str]) -> KnowledgeGraph: - """Create a KnowledgeGraph backed by the live OMOP CDM database.""" - - return KnowledgeGraph(session_factory=_build_session_factory(database_url)) - - -def _build_embedding_knowledge_graph( - database_url: Optional[str], - embedding_backend: Optional[EmbeddingBackendType], - embedding_client: Optional[EmbeddingClient], - embedding_storage_base_dir: Optional[str], -) -> KnowledgeGraph: - """Create a KnowledgeGraph with embedding support configured when requested.""" - - session_factory = _build_session_factory(database_url) - return KnowledgeGraph( - session_factory=session_factory, - emb_backend=embedding_backend, - emb_base_storage_dir=embedding_storage_base_dir, - emb_client=embedding_client, - ) - - -def _case_constraints(case: BenchmarkCase) -> Optional[SearchConstraintConcept]: - """Translate case metadata into OMOP search constraints when available.""" - - if case.domain == "NA" and case.vocabulary == "NA": - return None - - domains = (case.domain,) if case.domain != "NA" else None - vocabularies = (case.vocabulary,) if case.vocabulary != "NA" else None - - return SearchConstraintConcept( - domains=domains, - vocabularies=vocabularies, - require_standard=False, - ) - - -def _build_configs() -> Tuple[BenchmarkConfig, ...]: - """Build the real resolver ablations compared by the benchmark report.""" - - basic = ( - ExactLabelResolver(), - ExactSynonymResolver(), - ) - extended = ( - *basic, - PartialLabelResolver(), - PartialSynonymResolver(), - ) - full_text = ( - *extended, - FullTextResolver(), - FullTextSynonymResolver(), - ) - full_text_with_embedding = ( - *full_text, - EmbeddingResolver(), - ) - - return ( - BenchmarkConfig(name="basic", resolvers=basic), - BenchmarkConfig(name="extended", resolvers=extended), - BenchmarkConfig(name="full_text", resolvers=full_text), - BenchmarkConfig(name="full_text_with_embedding", resolvers=full_text_with_embedding, requires_embedding=True), - ) - - -def _evaluate_case( - kg: KnowledgeGraph, - case: BenchmarkCase, - resolvers: Tuple[CandidateResolver, ...], - k: int, - resolver_kwargs: Optional[Dict[str, Any]] = None, -) -> Dict[str, float | int | bool | str]: - """Run one real benchmark case and derive ranking, safety, and pruning metrics.""" - - constraints = _case_constraints(case) - pipeline = ResolverPipeline(resolvers=resolvers) - resolver_kwargs = resolver_kwargs or {} - - # Candidate pruning estimate: pre-dedup hits vs deduped pipeline output. - raw_hits = 0 - for resolver in resolvers: - raw_hits += len(tuple(resolver.resolve(kg, case.text, constraints=constraints, **resolver_kwargs))) - - t0 = time.perf_counter() - predictions = [hit.concept_id for hit in pipeline.resolve(kg, case.text, constraints=constraints, **resolver_kwargs)] - latency_ms = (time.perf_counter() - t0) * 1000.0 - - expected = case.expected_concept_id - top1_correct = False - mrr = 0.0 - recall_at_k = 0.0 - false_grounding = False - safe_null = False - - if expected is None: - safe_null = len(predictions) == 0 - false_grounding = len(predictions) > 0 - else: - if expected in predictions: - rank = predictions.index(expected) + 1 - top1_correct = rank == 1 - mrr = 1.0 / rank - recall_at_k = 1.0 if rank <= k else 0.0 - else: - false_grounding = len(predictions) > 0 - - unique_hits = len(predictions) - pruning_ratio = (1.0 - (unique_hits / raw_hits)) if raw_hits > 0 else 0.0 - - return { - "case_id": case.id, - "bucket": case.bucket, - "expected": -1 if expected is None else expected, - "expected_concept_name": case.expected_concept_name or "", - "pred_count": len(predictions), - "top1_correct": float(top1_correct), - "mrr": mrr, - "recall_at_k": recall_at_k, - "false_grounding": float(false_grounding), - "safe_null": float(safe_null), - "latency_ms": latency_ms, - "raw_hits": raw_hits, - "unique_hits": unique_hits, - "pruning_ratio": pruning_ratio, - } - - -def _summarise(results: Sequence[Dict[str, float | int | bool | str]], label: str) -> Dict[str, float | str]: - """Aggregate case-level measurements into one configuration summary.""" - - if not results: - return {"config": label, "count": 0} - - latencies = [float(r["latency_ms"]) for r in results] - - return { - "config": label, - "count": float(len(results)), - "top1_accuracy": sum(float(r["top1_correct"]) for r in results) / len(results), - "mrr": sum(float(r["mrr"]) for r in results) / len(results), - "recall_at_k": sum(float(r["recall_at_k"]) for r in results) / len(results), - "false_grounding_rate": sum(float(r["false_grounding"]) for r in results) / len(results), - "safe_null_rate": sum(float(r["safe_null"]) for r in results) / len(results), - "latency_median_ms": statistics.median(latencies), - "latency_p95_ms": _percentile(latencies, 95), - "pruning_ratio_mean": sum(float(r["pruning_ratio"]) for r in results) / len(results), - } - - -def _percentile(values: List[float], p: int) -> float: - """Return a simple nearest-rank percentile for a list of latencies.""" - - if not values: - return 0.0 - ordered = sorted(values) - idx = max(0, min(len(ordered) - 1, int((p / 100.0) * (len(ordered) - 1)))) - return ordered[idx] - - -def _mcnemar(a: Sequence[Dict[str, float | int | bool | str]], b: Sequence[Dict[str, float | int | bool | str]]) -> Dict[str, float]: - """Compute a lightweight paired comparison on top-1 correctness.""" - - paired = [(float(x["top1_correct"]), float(y["top1_correct"])) for x, y in zip(a, b, strict=True)] - b_only = sum(1 for x, y in paired if x == 0.0 and y == 1.0) - a_only = sum(1 for x, y in paired if x == 1.0 and y == 0.0) - denom = b_only + a_only - chi2 = ((abs(b_only - a_only) - 1.0) ** 2 / denom) if denom > 0 else 0.0 - return { - "a_only_correct": float(a_only), - "b_only_correct": float(b_only), - "mcnemar_chi2_cc": chi2, - } - - -def run( - cases_path: Path, - k: int, - database_url: Optional[str] = None, - embedding_backend: Optional[EmbeddingBackendType] = None, - embedding_storage_base_dir: Optional[str] = None, - embedding_model: Optional[str] = None, - embedding_api_base: Optional[str] = None, - embedding_api_key: Optional[str] = None, - embedding_metric_type: str = "cosine", - embedding_index_type: str = "flat", - domain_filter: Optional[set[str]] = None, - vocab_filter: Optional[set[str]] = None, -) -> Dict[str, object]: - """Execute the benchmark and return a JSON-serialisable report object.""" - - cases = _load_cases(cases_path) - if domain_filter: - cases = [c for c in cases if c.domain in domain_filter] - if vocab_filter: - cases = [c for c in cases if c.vocabulary in vocab_filter] - - embedding_client = None - embedding_kg = None - if embedding_model is not None and embedding_api_base is not None: - embedding_client = EmbeddingClient( - model=embedding_model, - api_base=embedding_api_base, - api_key=embedding_api_key or "ollama", - ) - embedding_kg = _build_embedding_knowledge_graph( - database_url=database_url, - embedding_backend=embedding_backend, - embedding_client=embedding_client, - embedding_storage_base_dir=embedding_storage_base_dir, - ) - - kg = _build_knowledge_graph(database_url) - configs = _build_configs() - - per_config: Dict[str, List[Dict[str, float | int | bool | str]]] = {} - summaries: Dict[str, Dict[str, float | str]] = {} - bucket_summaries: Dict[str, Dict[str, Dict[str, float | str]]] = {} - errors: Dict[str, str] = {} - - for config in configs: - try: - if config.requires_embedding and embedding_kg is None: - raise MissingExtensionError( - "Embedding benchmark requires `omop-emb` plus `OMOP_EMB_BACKEND`, `--embedding-model`, and `--embedding-api-base`.") - - if config.requires_embedding: - active_kg = embedding_kg - assert active_kg is not None - query_embeddings = { - case.id: active_kg.emb.embed_texts(case.text) - for case in cases - } - resolver_kwargs: Dict[str, Any] = { - "text_embedding": None, - "text_embedding_model": embedding_model, - "metric_type": embedding_metric_type, - "index_type": embedding_index_type, - } - - rows = [] - for case in cases: - resolver_kwargs["text_embedding"] = query_embeddings[case.id] - rows.append( - _evaluate_case( - kg=active_kg, - case=case, - resolvers=config.resolvers, - k=k, - resolver_kwargs=resolver_kwargs, - ) - ) - else: - active_kg = kg - rows = [ - _evaluate_case( - kg=active_kg, - case=case, - resolvers=config.resolvers, - k=k, - ) - for case in cases - ] - except Exception as exc: - errors[config.name] = str(exc) - continue - - per_config[config.name] = rows - summaries[config.name] = _summarise(rows, config.name) - - buckets: Dict[str, List[Dict[str, float | int | bool | str]]] = {} - for row in rows: - bucket = str(row["bucket"]) - buckets.setdefault(bucket, []).append(row) - bucket_summaries[config.name] = { - bucket: _summarise(bucket_rows, f"{config.name}:{bucket}") - for bucket, bucket_rows in buckets.items() - } - - significance: Dict[str, Dict[str, float]] = {} - if "basic" in per_config and "extended" in per_config: - significance["basic_vs_extended"] = _mcnemar(per_config["basic"], per_config["extended"]) - if "extended" in per_config and "full_text" in per_config: - significance["extended_vs_full_text"] = _mcnemar(per_config["extended"], per_config["full_text"]) - if "full_text" in per_config and "full_text_with_embedding" in per_config: - significance["full_text_vs_full_text_with_embedding"] = _mcnemar( - per_config["full_text"], - per_config["full_text_with_embedding"], - ) - - return { - "cases_evaluated": len(cases), - "k": k, - "summaries": summaries, - "bucket_summaries": bucket_summaries, - "significance": significance, - "errors": errors, - "database_url": database_url or os.getenv("OMOP_DATABASE_URL"), - "embedding_model": embedding_model, - "embedding_backend": embedding_backend, - "embedding_storage_base_dir": embedding_storage_base_dir, - "embedding_metric_type": embedding_metric_type, - "embedding_index_type": embedding_index_type, - } - - -def main() -> None: - """CLI entry point for running the real OMOP benchmark from the shell.""" - - parser = argparse.ArgumentParser(description="Resolver benchmark against a live OMOP CDM database.") - parser.add_argument( - "--cases", - type=Path, - default=Path(__file__).with_name("resolver_cases.json"), - help="Path to benchmark case JSON file.", - ) - parser.add_argument( - "--database-url", - type=str, - default=None, - help="SQLAlchemy database URL for the local OMOP CDM. Defaults to OMOP_DATABASE_URL.", - ) - parser.add_argument( - "--embedding-backend", - type=str, - default=os.getenv("OMOP_EMB_BACKEND"), - help="Embedding backend to use for the benchmark. Defaults to OMOP_EMB_BACKEND.", - ) - parser.add_argument( - "--embedding-storage-base-dir", - type=str, - default=os.getenv("OMOP_EMB_BASE_STORAGE_DIR"), - help="Optional base directory for file-backed embedding backends.", - ) - parser.add_argument( - "--embedding-model", - type=str, - default=os.getenv("OMOP_EMB_MODEL"), - help="Embedding model name for query embeddings and model lookup.", - ) - parser.add_argument( - "--embedding-api-base", - type=str, - default=os.getenv("OMOP_OLLAMA_API_BASE"), - help="OpenAI-compatible API base used by the embedding client.", - ) - parser.add_argument( - "--embedding-api-key", - type=str, - default=os.getenv("OMOP_OLLAMA_API_KEY"), - help="API key used by the embedding client. Defaults to the Ollama compatibility value.", - ) - parser.add_argument( - "--embedding-metric-type", - type=str, - default="cosine", - help="Similarity metric used by embedding retrieval.", - ) - parser.add_argument( - "--embedding-index-type", - type=str, - default="flat", - help="Embedding index type used by embedding retrieval.", - ) - parser.add_argument("--k", type=int, default=5, help="K for Recall@K.") - parser.add_argument("--domain", action="append", default=None, help="Optional domain filter (repeatable).") - parser.add_argument("--vocabulary", action="append", default=None, help="Optional vocabulary filter (repeatable).") - parser.add_argument("--out", type=Path, default=None, help="Optional output JSON path.") - args = parser.parse_args() - - report = run( - cases_path=args.cases, - k=args.k, - database_url=args.database_url, - embedding_backend=args.embedding_backend, - embedding_storage_base_dir=args.embedding_storage_base_dir, - embedding_model=args.embedding_model, - embedding_api_base=args.embedding_api_base, - embedding_api_key=args.embedding_api_key, - embedding_metric_type=args.embedding_metric_type, - embedding_index_type=args.embedding_index_type, - domain_filter=set(args.domain) if args.domain else None, - vocab_filter=set(args.vocabulary) if args.vocabulary else None, - ) - - output = json.dumps(report, indent=2) - print(output) - if args.out is not None: - args.out.write_text(output + "\n", encoding="utf-8") - - -if __name__ == "__main__": - main() diff --git a/src/omop_graph/cli.py b/src/omop_graph/cli.py index f659d2c..6427a54 100644 --- a/src/omop_graph/cli.py +++ b/src/omop_graph/cli.py @@ -1,70 +1,26 @@ import sqlalchemy as sa -from sqlalchemy.orm import sessionmaker, Session - -from orm_loader.helpers import create_db, bulk_load_context -from orm_loader.loaders.loader_interface import PandasLoader -from orm_loader.helpers.metadata import Base -from omop_alchemy.cdm.handlers import ( - install_fulltext_columns, - populate_fulltext_columns, -) -from omop_alchemy.cdm.base import CDMTableBase -from omop_alchemy.cdm.model.health_system import Location, Care_Site, Provider, Visit_Occurrence -from omop_alchemy.cdm.model.clinical import ( - Person, - Condition_Occurrence, - Death, - Measurement, -) -from omop_alchemy.cdm.model.structural.episode import Episode -from omop_alchemy.cdm.model.structural.episode_event import Episode_Event -from omop_alchemy.cdm.model.derived import Observation_Period -from omop_alchemy.cdm.model.vocabulary import ( - Domain, - Vocabulary, - Concept_Class, - Relationship, - Concept, - Concept_Ancestor, - Concept_Relationship, - Concept_Synonym -) -from omop_graph.extensions.omop_alchemy import RelationshipClass, RelationshipMapping -from typing import Annotated, Union, Optional +from sqlalchemy.orm import sessionmaker +from typing import Annotated, Optional import pandas as pd -import os from pathlib import Path -from random import randint, choice -import numpy as np -from datetime import date, timedelta + from dotenv import load_dotenv import typer +import tempfile import logging -app = typer.Typer() -logger = logging.getLogger(__name__) - -ATHENA_INITIAL_LOAD = [ - Domain, - Vocabulary, - Concept_Class, - Relationship, - Concept -] +from orm_loader.helpers import bulk_load_context +from orm_loader.loaders.loader_interface import PandasLoader +from orm_loader.helpers.metadata import Base +from omop_graph.extensions.omop_alchemy import RelationshipClass, RelationshipMapping +from omop_graph.oaklib_interface.omop_factory import build_engine_string -ATHENA_SUBSEQUENT_LOAD = [ - Concept_Ancestor, - Concept_Relationship, - Concept_Synonym, -] +app = typer.Typer() +logger = logging.getLogger(__name__) -ATHENA_RELATIONSHIP_CLASSIFICATION_LOAD = [ - RelationshipClass, - RelationshipMapping -] -def configure_logging_level(verbosity: int, reduce_logging: bool = True) -> None: +def configure_logging_level(verbosity: int, reduce_logging: bool = False) -> None: """Configure global logging.""" level_map = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG} log_level = level_map.get(min(verbosity, 2), logging.DEBUG) @@ -97,360 +53,17 @@ def filter(self, record: logging.LogRecord) -> bool: logger_instance.propagate = False -def _enable_fulltext_sidecars(engine: sa.Engine, regconfig: str) -> None: - install_fulltext_columns(engine) - populate_fulltext_columns(engine, regconfig=regconfig) - -def _populate_reference_data( - session: Session, - avail_country: list[int], - avail_place_of_service: list[int], - avail_provider: list[int], - avail_gender: list[int], -): - - loc_ids = Location.allocator(session) - cs_ids = Care_Site.allocator(session) - pro_ids = Provider.allocator(session) - - location_data = [{'location_id': loc_ids.next(), 'country_concept_id': choice(avail_country), 'city': f'City {idx}'} for idx in range(10)] - locations = [Location(**row) for row in location_data] - care_site_data = [{'care_site_id': cs_ids.next(), 'care_site_name': f'Care Site {idx}', 'location_id': choice(locations).location_id, 'place_of_service_concept_id': choice(avail_place_of_service)} for idx in range(30)] - care_sites = [Care_Site(**row) for row in care_site_data] - provider_data = [{'provider_id': pro_ids.next(), 'specialty_concept_id': choice(avail_provider), 'gender_concept_id': choice(avail_gender), 'care_site_id': choice(care_sites).care_site_id} for _ in range(50)] - providers = [Provider(**row) for row in provider_data] - - session.add_all(locations) - session.add_all(care_sites) - session.add_all(providers) - session.commit() - - return locations, care_sites, providers - -def _populate_people_and_visits( - session: Session, - care_sites: list[Care_Site], - avail_gender: list[int], - avail_race: list[int], - avail_ethnicity: list[int], - avail_place_of_service: list[int], - ): - - person_ids = Person.allocator(session) - visit_ids = Visit_Occurrence.allocator(session) - - person_data = [{'person_id': person_ids.next(), 'year_of_birth': randint(1950, 2020), 'month_of_birth': randint(1, 12), 'gender_concept_id':choice(avail_gender), 'race_concept_id':choice(avail_race), 'ethnicity_concept_id':choice(avail_ethnicity)} for idx in range(1000)] - people = [Person(**row) for row in person_data] - - visits = [] - for person in people: - cs = choice(care_sites) - visit_num = randint(1, 3) - for v in range(visit_num): - days_delay = randint(0, 365) - visit_date = date(2020, 1, 1) + timedelta(days_delay) - visit = Visit_Occurrence( - visit_occurrence_id=visit_ids.next(), - person_id=person.person_id, - care_site_id=cs.care_site_id, - visit_concept_id=choice(avail_place_of_service), - visit_start_date=visit_date, - visit_end_date=visit_date, - ) - visits.append(visit) - session.add_all(people) - session.add_all(visits) - session.commit() - return people, visits - -def _populate_observation_periods( - session: Session, - avail_types: list[int], - ): - op_ids = Observation_Period.allocator(session) - deaths = [] - rows = ( - session.query( - Visit_Occurrence.person_id, - sa.func.min(Visit_Occurrence.visit_start_date).label("start"), - sa.func.max(Visit_Occurrence.visit_end_date).label("end"), - Death.death_date, - Observation_Period.observation_period_id - ) - .join(Death, Death.person_id==Visit_Occurrence.person_id, isouter=True) - .join(Observation_Period, Observation_Period.person_id==Visit_Occurrence.person_id, isouter=True) - .filter(Observation_Period.observation_period_id==None) - .group_by(Visit_Occurrence.person_id) - .all() - ) - obs = [] - for idx, r in enumerate(rows): - deceased = np.random.choice([True, False], p=[0.05, 0.95]) - if deceased: - death_date = r.end + timedelta(days=randint(1, 365)) - deaths.append( - Death( - person_id=r.person_id, - death_date=death_date, - death_type_concept_id=choice(avail_types), - ) - ) - obs_end = death_date - else: - obs_end = r.end - obs.append( - Observation_Period( - observation_period_id=op_ids.next(), - person_id=r.person_id, - observation_period_start_date=r.start, - observation_period_end_date=obs_end, - period_type_concept_id=choice(avail_types), - ) - ) - session.add_all(deaths) - session.add_all(obs) - session.commit() - return obs - -def _populate_conditions_and_modifiers( - session: Session, - staging_sets: dict[str, pd.DataFrame], - cancers: list[int], - avail_types: list[int], - ): - cond_ids = Condition_Occurrence.allocator(session) - meas_ids = Measurement.allocator(session) - ep_ids = Episode.allocator(session) - rows = ( - session.query( - Observation_Period, Death, Condition_Occurrence - ) - .join(Death, Observation_Period.person_id==Death.person_id, isouter=True) - .join(Condition_Occurrence, Observation_Period.person_id==Condition_Occurrence.person_id, isouter=True) - .all() - ) - conditions = [] - measurements = [] - episodes = [] - episode_events = [] - for obs, death, condition in rows: - if condition: - continue - t = choice(list(staging_sets['T'].concept_id)) - n = choice(list(staging_sets['N'].concept_id)) - m = choice(list(staging_sets['M'].concept_id)) - # don't worry abt overall stage for now as it should be calculated - condition_concept = choice(cancers) - condition = Condition_Occurrence( - condition_occurrence_id=cond_ids.next(), - condition_concept_id = condition_concept, - condition_start_date = obs.observation_period_start_date, - condition_type_concept_id = choice(avail_types), - person_id = obs.person_id, - condition_status_concept_id = 32902 - ) - conditions.append(condition) - episode = Episode( - episode_id=ep_ids.next(), - person_id=obs.person_id, - episode_concept_id=32533, # Episode of care - episode_object_concept_id=condition.condition_concept_id, - episode_start_date=condition.condition_start_date, - episode_end_date=( - death.death_date if death else obs.observation_period_end_date - ), - episode_type_concept_id=choice(avail_types), # EHR / registry / derived - ) - episodes.append(episode) - - for stage in [t, n, m]: - measurement = Measurement( - person_id = obs.person_id, - measurement_id = meas_ids.next(), - measurement_concept_id = stage, - measurement_event_id = condition.condition_occurrence_id, - meas_event_field_concept_id = 1147127, # condition_occurrence.condition_occurrence_id - measurement_date = condition.condition_start_date, - measurement_type_concept_id = choice(avail_types), - value_as_number = 1 - ) - measurements.append(measurement) - episode_events.append( - Episode_Event( - episode_id=episode.episode_id, - event_id=measurement.measurement_id, - episode_event_field_concept_id=1147138, # measurement.measurement_id - ) - ) - episode_events.append( - Episode_Event( - episode_id=episode.episode_id, - event_id=condition.condition_occurrence_id, - episode_event_field_concept_id=1147127, # condition_occurrence.condition_occurrence_id - ) - ) - session.add_all(conditions) - session.add_all(measurements) - session.add_all(episodes) - session.add_all(episode_events) - session.commit() - -def _populate_test_data(session): - """Brute force addition of test data for development/testing purposes.""" - - # Data - concept_by_domain = pd.DataFrame( - session.query( - *Concept.__table__.columns - ) - .filter( - sa.or_( - Concept.domain_id.in_(['Gender', 'Ethnicity', 'Race', 'Visit', 'Location', 'Provider', 'Type Concept']), - sa.and_( - Concept.domain_id == 'Condition', - Concept.vocabulary_id == 'ICDO3' - ) - ) - ) - ) - - avail_gender = list(concept_by_domain[concept_by_domain.domain_id=='Gender'].concept_id) - avail_ethnicity = list(concept_by_domain[concept_by_domain.domain_id=='Ethnicity'].concept_id) - avail_race = list(concept_by_domain[concept_by_domain.domain_id=='Race'].concept_id) - avail_place_of_service = list(concept_by_domain[concept_by_domain.domain_id=='Visit'].concept_id) - avail_country = list(concept_by_domain[concept_by_domain.concept_class_id=='Location'].concept_id) - avail_provider = list(concept_by_domain[concept_by_domain.domain_id=='Provider'].concept_id) - avail_types = list(concept_by_domain[concept_by_domain.domain_id=='Type Concept'].concept_id) - cancers = list(concept_by_domain[(concept_by_domain.domain_id=='Condition')&(concept_by_domain.vocabulary_id=='ICDO3') & (concept_by_domain.concept_code.str.contains('/3'))].concept_id) - - staging_parents = pd.DataFrame( - session.query( - *Concept.__table__.columns - ) - .join(Concept_Ancestor, Concept.concept_id==Concept_Ancestor.descendant_concept_id) - .filter(Concept_Ancestor.ancestor_concept_id==734320) - .filter(Concept_Ancestor.max_levels_of_separation==1) - ) - - staging_sets = {} - - for axis in ['T', 'N', 'M', 'Stage']: - parents = list(staging_parents[staging_parents.concept_name.str.contains(axis)].concept_id) - s = pd.DataFrame( - session.query( - *Concept.__table__.columns - ) - .join(Concept_Ancestor, Concept.concept_id==Concept_Ancestor.descendant_concept_id) - .filter(Concept_Ancestor.ancestor_concept_id.in_(parents)) - .filter(Concept.concept_code.ilike('%8th%')) - .filter(~Concept.concept_code.ilike('%yp%')) - ) - staging_sets[axis] = s - - - # Care sites - _populate_reference_data(session, avail_country, avail_place_of_service, avail_provider, avail_gender) - session.commit() - care_sites = session.query(Care_Site).all() - - # People and visits - _populate_people_and_visits(session, care_sites, avail_gender, avail_race, avail_ethnicity, avail_place_of_service) - _populate_observation_periods(session, avail_types) - _populate_conditions_and_modifiers(session, staging_sets, cancers, avail_types) - - -@app.command() -def omop_cdm( - add_test_data: Annotated[bool, typer.Option(help="Whether to add synthetic test data after loading Athena data. Omit if not used.")] = False, - chunk_size: Annotated[int, typer.Option( - "--chunk-size", "-c", - help="Number of rows to process in each chunk when loading large tables with fallback pandas loader.")] = 5000, - pred_class_dir: Annotated[Optional[str], typer.Option(help="Path to the directory containing `predicate_classification.csv` and `predicate_mapping.csv`.")] = None, - fulltext: Annotated[bool, typer.Option("--fulltext/--no-fulltext", help="Install and populate PostgreSQL full-text sidecars after loading the vocabulary tables.")] = False, - fulltext_regconfig: Annotated[str, typer.Option("--fulltext-regconfig", help="PostgreSQL text search configuration to use when populating the full-text sidecars.")] = "english", - verbosity: Annotated[int, typer.Option("--verbose", "-v", count=True, help="Increase verbosity (up to two levels)")] = 0, -): - """ - Instantiate the database from scratch by loading the Athena vocabularies. - IMPORTANT: This will wipe the entire existing database in the db container. - """ - configure_logging_level(verbosity) - load_dotenv() - - engine_string = os.getenv('OMOP_DATABASE_URL') - if engine_string is None: - raise RuntimeError("OMOP_DATABASE_URL environment variable not set.") - - engine = sa.create_engine(engine_string, future=True, echo=False) - - # Drop all existing tables for a fresh bootstrap - metadata = Base.metadata - metadata.reflect(bind=engine) - metadata.drop_all(engine) - - # Re-init tables - create_db(engine) - - Session = sessionmaker(bind=engine, future=True) - session = Session() - - loader = PandasLoader() - - athena_db_path = os.getenv('SOURCE_PATH') - if athena_db_path is None: - raise RuntimeError("SOURCE_PATH environment variable not set. Please set it in your .env file to point to the Athena CSV files base directory.") - base_path = Path(athena_db_path).resolve() - assert base_path.exists(), f"Source path {base_path} does not exist" - - with bulk_load_context(session): - for model in ATHENA_INITIAL_LOAD: - model.load_csv( - session, - base_path / f"{model.__tablename__.upper()}.csv", - dedupe=True, - merge_strategy="upsert", - loader=loader - ) - session.commit() - - with bulk_load_context(session): - for model in ATHENA_SUBSEQUENT_LOAD: - model.load_csv( - session, - base_path / f"{model.__tablename__.upper()}.csv", - dedupe=True, - chunksize=chunk_size, - merge_strategy="replace", - loader=loader - ) - session.commit() - - if fulltext: - try: - _enable_fulltext_sidecars(engine, fulltext_regconfig) - except Exception as exc: - logger.error(f"Failed to enable PostgreSQL full-text sidecars: {exc}") - logger.info("Continuing with bootstrap without full-text sidecars. You can rerun omop-maint fulltext install and omop-maint fulltext populate later.") - - try: - relationship_classification(pred_class_dir) - except Exception as e: - logger.error(f"Failed to ingest predicate classifications: {e}") - logger.info("Continuing with bootstrap without predicate classifications. Re-run cli `ingest-classification` command once the issue is resolved.") - - if add_test_data: - _populate_test_data(session) - @app.command() def relationship_classification( pred_class_dir: Annotated[Optional[str], typer.Option(help="Path to the directory containing `predicate_classification.csv` and `predicate_mapping.csv`.")] = None, + env_file: Annotated[Optional[str], typer.Option("--env-file", "-e", help="Path to the .env file containing database connection variables. If not provided, will look for .env in the current working directory.")] = None, verbosity: Annotated[int, typer.Option("--verbose", "-v", count=True, help="Increase verbosity (up to two levels)")] = 0, ): """ Method to get the pre-classified predicates into the database. """ configure_logging_level(verbosity) - load_dotenv() + load_dotenv(env_file) if pred_class_dir is None: pred_class_dir = str((Path(__file__).parent.parent.parent / "docs").resolve()) @@ -461,7 +74,7 @@ def relationship_classification( if not pred_mapping_file.is_file(): raise FileNotFoundError(f"`predicate_mapping.csv` not found in {pred_class_dir_pl}") pred_class_file = pred_class_dir_pl / "predicate_classification.csv" - if not pred_class_file: + if not pred_class_file.is_file(): raise FileNotFoundError(f"`predicate_classification.csv` not found in {pred_class_dir_pl}") df_class = pd.read_csv(pred_class_file) @@ -492,17 +105,7 @@ def relationship_classification( df_rel_mapping = df_rel_mapping.dropna(subset=['class_id', 'subclass_id'], how='any') df_rel_mapping_to_export = df_rel_mapping.drop_duplicates(subset=["relationship_id", "class_id", "subclass_id"]) - # Save and then load again - athena_db_path = os.getenv('SOURCE_PATH') - if athena_db_path is None: - raise RuntimeError("SOURCE_PATH environment variable not set. Please set it in your .env file to point to the Athena CSV files base directory.") - base_path = Path(athena_db_path).resolve() - assert base_path.exists(), f"Source path {base_path} does not exist" - - engine_string = os.getenv('OMOP_DATABASE_URL') - if engine_string is None: - raise RuntimeError("OMOP_DATABASE_URL environment variable not set. Please set it in your .env file to point to your database.") - + engine_string = build_engine_string() engine = sa.create_engine(engine_string, future=True, echo=False) Session = sessionmaker(bind=engine, future=True) session = Session() @@ -520,20 +123,22 @@ def relationship_classification( Base.metadata.drop_all(bind=engine, tables=tables_to_drop, checkfirst=True) # type: ignore Base.metadata.create_all(bind=engine, tables=tables_to_drop) # type: ignore + # Save to temporary file and then reload from there for model, df in zip([RelationshipClass, RelationshipMapping], [df_rel_cls_to_export, df_rel_mapping_to_export]): - csv_path = base_path / f"{model.__tablename__.upper()}.csv" - df.to_csv(csv_path, index=False) - logger.info(f"Saved {len(df)} records to `{csv_path}` for model `{model.__name__}`") - - with bulk_load_context(session): - model.load_csv( # type: ignore - session, - csv_path, - dedupe=True, - merge_strategy="replace", - loader=PandasLoader() - ) - session.commit() + with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as tmp: + csv_path = tmp.name + df.to_csv(csv_path, index=False) + logger.info(f"Saved {len(df)} records to `{csv_path}` for model `{model.__name__}`") + + with bulk_load_context(session): + model.load_csv( # type: ignore + session, + csv_path, + dedupe=True, + merge_strategy="replace", + loader=PandasLoader() + ) + session.commit() if __name__ == "__main__": app() \ No newline at end of file diff --git a/src/omop_graph/cli_utils/__init__.py b/src/omop_graph/cli_utils/__init__.py new file mode 100644 index 0000000..46056ea --- /dev/null +++ b/src/omop_graph/cli_utils/__init__.py @@ -0,0 +1 @@ +from .cli_add_test_data import populate_test_data \ No newline at end of file diff --git a/src/omop_graph/cli_utils/cli_add_test_data.py b/src/omop_graph/cli_utils/cli_add_test_data.py new file mode 100644 index 0000000..6ce5daf --- /dev/null +++ b/src/omop_graph/cli_utils/cli_add_test_data.py @@ -0,0 +1,280 @@ +from random import randint, choice +import numpy as np +from datetime import date, timedelta +import pandas as pd + +import sqlalchemy as sa +from sqlalchemy.orm import Session + +from omop_alchemy.cdm.model.structural.episode import Episode +from omop_alchemy.cdm.model.structural.episode_event import Episode_Event +from omop_alchemy.cdm.model.derived import Observation_Period +from omop_alchemy.cdm.model.health_system import Location, Care_Site, Provider, Visit_Occurrence +from omop_alchemy.cdm.model.clinical import ( + Person, + Condition_Occurrence, + Death, + Measurement, +) +from omop_alchemy.cdm.model.vocabulary import ( + Concept, + Concept_Ancestor +) + + +def populate_reference_data( + session: Session, + avail_country: list[int], + avail_place_of_service: list[int], + avail_provider: list[int], + avail_gender: list[int], +): + + loc_ids = Location.allocator(session) + cs_ids = Care_Site.allocator(session) + pro_ids = Provider.allocator(session) + + location_data = [{'location_id': loc_ids.next(), 'country_concept_id': choice(avail_country), 'city': f'City {idx}'} for idx in range(10)] + locations = [Location(**row) for row in location_data] + care_site_data = [{'care_site_id': cs_ids.next(), 'care_site_name': f'Care Site {idx}', 'location_id': choice(locations).location_id, 'place_of_service_concept_id': choice(avail_place_of_service)} for idx in range(30)] + care_sites = [Care_Site(**row) for row in care_site_data] + provider_data = [{'provider_id': pro_ids.next(), 'specialty_concept_id': choice(avail_provider), 'gender_concept_id': choice(avail_gender), 'care_site_id': choice(care_sites).care_site_id} for _ in range(50)] + providers = [Provider(**row) for row in provider_data] + + session.add_all(locations) + session.add_all(care_sites) + session.add_all(providers) + session.commit() + + return locations, care_sites, providers + +def populate_people_and_visits( + session: Session, + care_sites: list[Care_Site], + avail_gender: list[int], + avail_race: list[int], + avail_ethnicity: list[int], + avail_place_of_service: list[int], + ): + + person_ids = Person.allocator(session) + visit_ids = Visit_Occurrence.allocator(session) + + person_data = [{'person_id': person_ids.next(), 'year_of_birth': randint(1950, 2020), 'month_of_birth': randint(1, 12), 'gender_concept_id':choice(avail_gender), 'race_concept_id':choice(avail_race), 'ethnicity_concept_id':choice(avail_ethnicity)} for idx in range(1000)] + people = [Person(**row) for row in person_data] + + visits = [] + for person in people: + cs = choice(care_sites) + visit_num = randint(1, 3) + for v in range(visit_num): + days_delay = randint(0, 365) + visit_date = date(2020, 1, 1) + timedelta(days_delay) + visit = Visit_Occurrence( + visit_occurrence_id=visit_ids.next(), + person_id=person.person_id, + care_site_id=cs.care_site_id, + visit_concept_id=choice(avail_place_of_service), + visit_start_date=visit_date, + visit_end_date=visit_date, + ) + visits.append(visit) + session.add_all(people) + session.add_all(visits) + session.commit() + return people, visits + +def populate_observation_periods( + session: Session, + avail_types: list[int], + ): + op_ids = Observation_Period.allocator(session) + deaths = [] + rows = ( + session.query( + Visit_Occurrence.person_id, + sa.func.min(Visit_Occurrence.visit_start_date).label("start"), + sa.func.max(Visit_Occurrence.visit_end_date).label("end"), + Death.death_date, + Observation_Period.observation_period_id + ) + .join(Death, Death.person_id==Visit_Occurrence.person_id, isouter=True) + .join(Observation_Period, Observation_Period.person_id==Visit_Occurrence.person_id, isouter=True) + .filter(Observation_Period.observation_period_id==None) + .group_by(Visit_Occurrence.person_id) + .all() + ) + obs = [] + for idx, r in enumerate(rows): + deceased = np.random.choice([True, False], p=[0.05, 0.95]) + if deceased: + death_date = r.end + timedelta(days=randint(1, 365)) + deaths.append( + Death( + person_id=r.person_id, + death_date=death_date, + death_type_concept_id=choice(avail_types), + ) + ) + obs_end = death_date + else: + obs_end = r.end + obs.append( + Observation_Period( + observation_period_id=op_ids.next(), + person_id=r.person_id, + observation_period_start_date=r.start, + observation_period_end_date=obs_end, + period_type_concept_id=choice(avail_types), + ) + ) + session.add_all(deaths) + session.add_all(obs) + session.commit() + return obs + +def populate_conditions_and_modifiers( + session: Session, + staging_sets: dict[str, pd.DataFrame], + cancers: list[int], + avail_types: list[int], + ): + cond_ids = Condition_Occurrence.allocator(session) + meas_ids = Measurement.allocator(session) + ep_ids = Episode.allocator(session) + rows = ( + session.query( + Observation_Period, Death, Condition_Occurrence + ) + .join(Death, Observation_Period.person_id==Death.person_id, isouter=True) + .join(Condition_Occurrence, Observation_Period.person_id==Condition_Occurrence.person_id, isouter=True) + .all() + ) + conditions = [] + measurements = [] + episodes = [] + episode_events = [] + for obs, death, condition in rows: + if condition: + continue + t = choice(list(staging_sets['T'].concept_id)) + n = choice(list(staging_sets['N'].concept_id)) + m = choice(list(staging_sets['M'].concept_id)) + # don't worry abt overall stage for now as it should be calculated + condition_concept = choice(cancers) + condition = Condition_Occurrence( + condition_occurrence_id=cond_ids.next(), + condition_concept_id = condition_concept, + condition_start_date = obs.observation_period_start_date, + condition_type_concept_id = choice(avail_types), + person_id = obs.person_id, + condition_status_concept_id = 32902 + ) + conditions.append(condition) + episode = Episode( + episode_id=ep_ids.next(), + person_id=obs.person_id, + episode_concept_id=32533, # Episode of care + episode_object_concept_id=condition.condition_concept_id, + episode_start_date=condition.condition_start_date, + episode_end_date=( + death.death_date if death else obs.observation_period_end_date + ), + episode_type_concept_id=choice(avail_types), # EHR / registry / derived + ) + episodes.append(episode) + + for stage in [t, n, m]: + measurement = Measurement( + person_id = obs.person_id, + measurement_id = meas_ids.next(), + measurement_concept_id = stage, + measurement_event_id = condition.condition_occurrence_id, + meas_event_field_concept_id = 1147127, # condition_occurrence.condition_occurrence_id + measurement_date = condition.condition_start_date, + measurement_type_concept_id = choice(avail_types), + value_as_number = 1 + ) + measurements.append(measurement) + episode_events.append( + Episode_Event( + episode_id=episode.episode_id, + event_id=measurement.measurement_id, + episode_event_field_concept_id=1147138, # measurement.measurement_id + ) + ) + episode_events.append( + Episode_Event( + episode_id=episode.episode_id, + event_id=condition.condition_occurrence_id, + episode_event_field_concept_id=1147127, # condition_occurrence.condition_occurrence_id + ) + ) + session.add_all(conditions) + session.add_all(measurements) + session.add_all(episodes) + session.add_all(episode_events) + session.commit() + +def populate_test_data(session): + """Brute force addition of test data for development/testing purposes.""" + + # Data + concept_by_domain = pd.DataFrame( + session.query( + *Concept.__table__.columns + ) + .filter( + sa.or_( + Concept.domain_id.in_(['Gender', 'Ethnicity', 'Race', 'Visit', 'Location', 'Provider', 'Type Concept']), + sa.and_( + Concept.domain_id == 'Condition', + Concept.vocabulary_id == 'ICDO3' + ) + ) + ) + ) + + avail_gender = list(concept_by_domain[concept_by_domain.domain_id=='Gender'].concept_id) + avail_ethnicity = list(concept_by_domain[concept_by_domain.domain_id=='Ethnicity'].concept_id) + avail_race = list(concept_by_domain[concept_by_domain.domain_id=='Race'].concept_id) + avail_place_of_service = list(concept_by_domain[concept_by_domain.domain_id=='Visit'].concept_id) + avail_country = list(concept_by_domain[concept_by_domain.concept_class_id=='Location'].concept_id) + avail_provider = list(concept_by_domain[concept_by_domain.domain_id=='Provider'].concept_id) + avail_types = list(concept_by_domain[concept_by_domain.domain_id=='Type Concept'].concept_id) + cancers = list(concept_by_domain[(concept_by_domain.domain_id=='Condition')&(concept_by_domain.vocabulary_id=='ICDO3') & (concept_by_domain.concept_code.str.contains('/3'))].concept_id) + + staging_parents = pd.DataFrame( + session.query( + *Concept.__table__.columns + ) + .join(Concept_Ancestor, Concept.concept_id==Concept_Ancestor.descendant_concept_id) + .filter(Concept_Ancestor.ancestor_concept_id==734320) + .filter(Concept_Ancestor.max_levels_of_separation==1) + ) + + staging_sets = {} + + for axis in ['T', 'N', 'M', 'Stage']: + parents = list(staging_parents[staging_parents.concept_name.str.contains(axis)].concept_id) + s = pd.DataFrame( + session.query( + *Concept.__table__.columns + ) + .join(Concept_Ancestor, Concept.concept_id==Concept_Ancestor.descendant_concept_id) + .filter(Concept_Ancestor.ancestor_concept_id.in_(parents)) + .filter(Concept.concept_code.ilike('%8th%')) + .filter(~Concept.concept_code.ilike('%yp%')) + ) + staging_sets[axis] = s + + + # Care sites + populate_reference_data(session, avail_country, avail_place_of_service, avail_provider, avail_gender) + session.commit() + care_sites = session.query(Care_Site).all() + + # People and visits + populate_people_and_visits(session, care_sites, avail_gender, avail_race, avail_ethnicity, avail_place_of_service) + populate_observation_periods(session, avail_types) + populate_conditions_and_modifiers(session, staging_sets, cancers, avail_types) diff --git a/src/omop_graph/config.py b/src/omop_graph/config.py new file mode 100644 index 0000000..6d36bf2 --- /dev/null +++ b/src/omop_graph/config.py @@ -0,0 +1,13 @@ +"""General configuration for the omop graph, including envrionment variables.""" + +# DB connection for OMOP CDM database +ENV_OMOP_CDM_DB_URL = "OMOP_CDM_DB_URL" +ENV_OMOP_CDM_DB_USER = "OMOP_CDM_DB_USER" +ENV_OMOP_CDM_DB_PASSWORD = "OMOP_CDM_DB_PASSWORD" +ENV_OMOP_CDM_DB_HOST = "OMOP_CDM_DB_HOST" +ENV_OMOP_CDM_DB_PORT = "OMOP_CDM_DB_PORT" +ENV_OMOP_CDM_DB_NAME = "OMOP_CDM_DB_NAME" +ENV_OMOP_CDM_DB_DRIVER = "OMOP_CDM_DB_DRIVER" + +# Ingestion +ENV_OMOP_VOCABULARY_DIR = "OMOP_VOCABULARY_DIR" \ No newline at end of file diff --git a/src/omop_graph/db/session.py b/src/omop_graph/db/session.py index a895156..0b4af49 100644 --- a/src/omop_graph/db/session.py +++ b/src/omop_graph/db/session.py @@ -2,10 +2,22 @@ from datetime import date from functools import wraps import warnings +import os +from typing import Optional, Union from sqlalchemy.exc import PendingRollbackError, InvalidRequestError -from sqlalchemy import create_engine +from sqlalchemy import create_engine, URL, make_url from sqlalchemy.orm import sessionmaker, Session +from omop_graph.config import ( + ENV_OMOP_CDM_DB_DRIVER, + ENV_OMOP_CDM_DB_HOST, + ENV_OMOP_CDM_DB_NAME, + ENV_OMOP_CDM_DB_PASSWORD, + ENV_OMOP_CDM_DB_PORT, + ENV_OMOP_CDM_DB_URL, + ENV_OMOP_CDM_DB_USER +) + def safe_execute(method): """ Decorator for OmopKnowledgeGraph methods. @@ -32,16 +44,69 @@ def wrapper(self, *args, **kwargs): def make_engine( - url: str, + url: Optional[Union[URL, str]] = None, *, echo: bool = False, connect_timeout: int = 10, ): + url = url or build_engine_string() + if isinstance(url, str): + url = URL.create(url) + kwargs = {} - if not url.startswith("sqlite"): + if not url.drivername.startswith("sqlite"): kwargs["connect_args"] = {"connect_timeout": connect_timeout} + return create_engine(url, echo=echo, **kwargs) +def build_engine_string() -> "URL": + """Compose a SQLAlchemy ``URL`` for the given backend at runtime. + + Returns + ------- + sqlalchemy.URL + + Notes + ----- + If ``OMOP_CDM_DB_URL`` is set it is directly used to create the URL, and all other environment variables are ignored. + Otherwise, the following environment variables are read to compose the URL for a relational database backend: + - ``OMOP_CDM_DB_DRIVER`` (required): the SQLAlchemy driver name (e.g. 'postgresql', 'mysql', 'sqlite'). + - ``OMOP_CDM_DB_USER`` (required): the username for database authentication. + - ``OMOP_CDM_DB_PASSWORD`` (required): the password for database authentication. + - ``OMOP_CDM_DB_HOST`` (required): the hostname or IP address of the database server. + - ``OMOP_CDM_DB_NAME`` (required): the name of the database to connect to. + - ``OMOP_CDM_DB_PORT`` (optional, default 5432): the port number on which the database server is listening. + + Raises + ------ + RuntimeError + If a required environment variable is missing. + ValueError + If ``backend`` does not support URL composition from environment + variables (e.g. ``FAISS``). + """ + + + optional_url = os.getenv(ENV_OMOP_CDM_DB_URL) + if optional_url: + return make_url(optional_url) + + driver = _get_required_env_variable(ENV_OMOP_CDM_DB_DRIVER) + user = _get_required_env_variable(ENV_OMOP_CDM_DB_USER) + password = _get_required_env_variable(ENV_OMOP_CDM_DB_PASSWORD) + host = _get_required_env_variable(ENV_OMOP_CDM_DB_HOST) + database = _get_required_env_variable(ENV_OMOP_CDM_DB_NAME) + port_str = os.getenv(ENV_OMOP_CDM_DB_PORT, "5432") + port = int(port_str) if port_str else None + return URL.create( + drivername=driver, + username=user, + password=password, + host=host, + port=port, + database=database, + ) + def make_session( url: str, @@ -50,4 +115,28 @@ def make_session( ) -> Session: engine = make_engine(url, echo=echo) SessionLocal = sessionmaker(bind=engine) - return SessionLocal() \ No newline at end of file + return SessionLocal() + + +def _get_required_env_variable(name: str) -> str: + """Get the value of a required environment variable. + + Parameters + ---------- + name : str + Environment variable name. + + Returns + ------- + str + Environment variable value. + + Raises + ------ + RuntimeError + If the environment variable is not set. + """ + value = os.getenv(name) + if value is None: + raise RuntimeError(f"Required environment variable {name!r} is not set.") + return value \ No newline at end of file diff --git a/src/omop_graph/extensions/emb.py b/src/omop_graph/extensions/emb.py index 2ca35ea..c7f8090 100644 --- a/src/omop_graph/extensions/emb.py +++ b/src/omop_graph/extensions/emb.py @@ -2,10 +2,8 @@ import logging import importlib.util -from typing import TYPE_CHECKING, Optional, Sequence, Mapping, TypeAlias, Tuple +from typing import TYPE_CHECKING, Optional, Sequence, TypeAlias, Tuple import numpy as np -from sqlalchemy.orm import Session -from omop_graph.graph.constraints import SearchConstraintConcept HAS_OMOP_EMB = importlib.util.find_spec("omop_emb") is not None @@ -13,10 +11,16 @@ # Optional embedding-specific ones from omop_emb import BackendType, MetricType, IndexType, ProviderType from omop_emb import EmbeddingWriterInterface, EmbeddingReaderInterface + from omop_emb.embeddings import EmbeddingRole + from omop_emb.utils.embedding_utils import NearestConceptMatch + from omop_emb.utils.embedding_utils import EmbeddingConceptFilter + + EmbeddingBackendType: TypeAlias = BackendType EmbeddingMetricType: TypeAlias = MetricType EmbeddingIndexType: TypeAlias = IndexType EmbeddingProviderType: TypeAlias = ProviderType + EmbeddingRoleType: TypeAlias = EmbeddingRole # Circular imports for static type hints from omop_graph.graph.kg import KnowledgeGraph @@ -27,22 +31,18 @@ EmbeddingMetricType = str EmbeddingIndexType = str EmbeddingProviderType = str + EmbeddingRoleType = str SUPPORTED_BACKENDS: Tuple[str, ...] = () SUPPORTED_METRICS: Tuple[str, ...] = () -_PARSE_INDEX_TYPE = None -_PARSE_METRIC_TYPE = None - if HAS_OMOP_EMB: try: from omop_emb import BackendType, MetricType - from omop_emb.config import parse_index_type, parse_metric_type + from omop_emb.embeddings import EmbeddingRole from omop_emb import EmbeddingReaderInterface, EmbeddingWriterInterface # Extract the string values from the StrEnums SUPPORTED_BACKENDS = tuple(v.value for v in BackendType) SUPPORTED_METRICS = tuple(v.value for v in MetricType) - _PARSE_INDEX_TYPE = parse_index_type - _PARSE_METRIC_TYPE = parse_metric_type except ModuleNotFoundError as exc: # Only swallow missing optional dependency imports. if exc.name and exc.name.startswith("omop_emb"): @@ -75,8 +75,11 @@ def _get_embedding_interface(kg: KnowledgeGraph) -> Optional[EmbeddingReaderInte """ try: return kg.emb - except (MissingExtensionError, ValueError) as exc: - logger.error(f"Embedding interface not available: {exc}") + except ValueError: + logger.debug("Embedding interface not available: no EmbeddingConfiguration provided.") + return None + except MissingExtensionError as exc: + logger.warning(f"Embedding interface not available: {exc}") return None def get_embedding_reader_interface(kg: KnowledgeGraph) -> Optional["EmbeddingReaderInterface"]: @@ -103,36 +106,24 @@ def get_embedding_writer_interface(kg: KnowledgeGraph) -> Optional["EmbeddingWri def semantic_similarity( kg: KnowledgeGraph, standard_concepts: Sequence[StandardConcept], - text_embedding: Optional[np.ndarray], - text_embedding_model: Optional[str], - metric_type: Optional[EmbeddingMetricType], - index_type: Optional[EmbeddingIndexType], -) -> Optional[Tuple[Mapping[int, float], ...]]: + query_embedding: np.ndarray, +) -> Optional[Tuple[Tuple[NearestConceptMatch, ...], ...]]: """ - Calculates similarity between text embeddings and concept embeddings. + Calculates similarity between a query embedding and stored concept embeddings. Parameters ---------- kg : KnowledgeGraph The knowledge graph instance, used to access the embedding interface. standard_concepts : Sequence[StandardConcept] - A sequence of standard concepts for which to calculate similarity scores against using the text_embedding. - text_embedding : Optional[np.ndarray] - The embedding vector to compare against concept embeddings. Expected shape is (q, dimension) where q is the number of query vectors and dimension is the size of the embedding space for the model. Note: q=1 for a single text embedding. - text_embedding_model : Optional[str] - The name of the text embedding model used to generate the text_embedding. This should correspond to - a model registered in the embedding interface. If None, similarity calculation will not be attempted. - metric_type : Optional[EmbeddingMetricType] - The similarity or distance metric to use for calculating similarity scores. This must be compatible with the index type used by the database. If None, similarity calculation will not be attempted. - index_type : Optional[EmbeddingIndexType] - The type of vector index used to store the embeddings. This is required to ensure that the correct retrieval method is used from the embedding interface. If None, similarity calculation will not be attempted. + A sequence of standard concepts to score against the query embedding. + query_embedding : np.ndarray + The query vector to compare against concept embeddings. Expected shape is (1, D). Returns ------- - Optional[Tuple[Mapping[int, float], ...]] - A tuple of dictionaries mapping concept IDs to similarity scores for each query embedding. - The outer tuple is of length q (number of query embeddings, shape[0] of text_embedding), and each inner dictionary contains up to k (the number of unique concepts) entries mapping concept IDs to their similarity scores with the query embedding. - + Optional[Tuple[Tuple[NearestConceptMatch, ...], ...]] + A tuple of tuple of NearestConceptMatch objects containing similarity scores for each concept. The tuples are of shape (q, k) where q is the number of query vectors (usually 1 for a single text embedding) and k is the number of nearest neighbors returned by the embedding interface. """ if not HAS_OMOP_EMB: logger.info("Embedding functionality is not available. Ensure 'omop-emb' is installed to use this feature.") @@ -143,160 +134,95 @@ def semantic_similarity( logger.info("Embedding reader interface not found in KG. Skipping similarity calculation.") return None - if index_type is None: - logger.info("Index type is required for similarity calculation but not provided. Skipping similarity calculation.") - return None + from omop_emb.utils.embedding_utils import EmbeddingConceptFilter concept_ids = tuple(dict.fromkeys(sc.concept_id for sc in standard_concepts)) - concept_filter = SearchConstraintConcept(concept_ids=concept_ids, limit=len(concept_ids)) + concept_filter = EmbeddingConceptFilter(concept_ids=concept_ids, limit=len(concept_ids)) - with kg.session_factory() as session: - missing_sc_embeddings = embedding_reader.get_concepts_without_embedding( - session=session, - concept_filter=concept_filter, # type: ignore - index_type=index_type, - ) + missing_sc_embeddings = embedding_reader.get_concepts_without_embedding( + omop_cdm_engine=kg.cdm_engine, + concept_filter=concept_filter, + ) - if missing_sc_embeddings: - if kg.compute_missing_embeddings: - logger.debug(f"Concepts missing embeddings: {missing_sc_embeddings}. Computing missing embeddings on-the-fly.") - embedding_writer = get_embedding_writer_interface(kg) - if ( - embedding_writer is not None and - text_embedding_model is not None and - text_embedding is not None - ): + if missing_sc_embeddings: + if kg.compute_missing_embeddings: + logger.debug(f"Concepts missing embeddings: {missing_sc_embeddings}. Computing missing embeddings on-the-fly.") + embedding_writer = get_embedding_writer_interface(kg) + if embedding_writer is not None: - missing_concept_ids = tuple(missing_sc_embeddings.keys()) - missing_concept_texts = tuple(missing_sc_embeddings.values()) - calculated_embeddings = embedding_writer.embed_texts(texts=missing_concept_texts) - embedding_writer.add_to_db( - embeddings=calculated_embeddings, - concept_ids=missing_concept_ids, - session=session, - index_type=index_type, - ) - logger.debug(f"Computed and stored embeddings for missing concepts: {missing_concept_ids}") - else: - param_dict = { - "text_embedding_model": text_embedding_model, - "embedding_writer": embedding_writer, - "text_embedding": text_embedding, - "index_type": index_type - } - none_params = [k for k, v in param_dict.items() if v is None] - logger.info( - f"Cannot compute missing embeddings due to missing parameters: {none_params}\n" - "Ensure the KG was initialised with a write-capable client to enable on-the-fly embedding computation.\n" - f"Expect missing embedding scores for concepts: {missing_sc_embeddings}" - ) + missing_concept_ids = tuple(missing_sc_embeddings.keys()) + missing_concept_texts = tuple(missing_sc_embeddings.values()) + + embedding_writer.embed_and_upsert_concepts( + omop_cdm_engine=kg.cdm_engine, + concept_ids=missing_concept_ids, + concept_texts=missing_concept_texts, + ) + logger.debug(f"Computed and stored embeddings for missing concepts: {missing_concept_ids}") else: logger.info( - f"Concepts missing embeddings: {missing_sc_embeddings}.\n" - "compute_missing_embeddings is disabled; these concepts will be skipped in similarity scoring.\n" - "Expect missing embedding scores for these concepts in the results." + f"Cannot compute missing embeddings due to missing embedding_writer.\n" + "Ensure the KG was initialised with a write-capable client to enable on-the-fly embedding computation.\n" + f"Expect missing embedding scores for concepts: {missing_sc_embeddings}" ) - - similarity_scores_tuple_of_dicts = get_neareast_concepts( - session=session, - kg=kg, - text_embedding_model=text_embedding_model, - text_embedding=text_embedding, - concept_filter=concept_filter, - metric_type=metric_type, - index_type=index_type, - ) + else: + logger.info( + f"Concepts missing embeddings: {missing_sc_embeddings}.\n" + "compute_missing_embeddings is disabled; these concepts will be skipped in similarity scoring.\n" + "Expect missing embedding scores for these concepts in the results." + ) + + nearest_concept_matches = get_neareast_concepts( + kg=kg, + query_embedding=query_embedding, + concept_filter=concept_filter, + ) - return similarity_scores_tuple_of_dicts + return nearest_concept_matches def get_neareast_concepts( - session: Session, kg: KnowledgeGraph, - text_embedding_model: Optional[str], - text_embedding: Optional[np.ndarray], - concept_filter: Optional[SearchConstraintConcept], - metric_type: Optional[EmbeddingMetricType], - index_type: Optional[EmbeddingIndexType], -) -> Optional[Tuple[Mapping[int, float], ...]]: + query_embedding: np.ndarray, + concept_filter: Optional[EmbeddingConceptFilter], +) -> Optional[Tuple[Tuple[NearestConceptMatch, ...], ...]]: """ - RAG retrieval for concept similarity scores. The text_embedding is used to retrieve the nearest concepts from the database - using stored embeddings and the specified similarity metric. + RAG retrieval for concept similarity scores. The query_embedding is compared against + stored embeddings using the metric and model already configured on the KG's embedding + reader interface. Parameters ---------- - session : Session - SQLAlchemy session for any required relational access. kg : KnowledgeGraph The knowledge graph instance, used to access the embedding interface. - text_embedding_model : Optional[str] - The name of the text embedding model to use for retrieval. This should correspond to a model registered in the embedding interface. If None, retrieval will not be attempted. - text_embedding : Optional[np.ndarray] - The embedding vector to search with. Expected shape is (q, dimension) where q is the number of query vectors and dimension is the size of the embedding space for the model. If None, retrieval will not be attempted. - concept_filter : Optional[EmbeddingConceptFilter], optional - A filter to specify which concepts to consider as potential nearest neighbors. Also limits the number of neighbors returned (K). If None, internal defaults are used to limit the number of neighbours. - index_type : IndexType - The type of vector index used to store the embeddings. - metric_type : MetricType - The similarity or distance metric to use for nearest neighbor search. This must be compatible with the index type used by the database. + query_embedding : np.ndarray + The query vector to search with. Expected shape is (q, D). + concept_filter : Optional[EmbeddingConceptFilter] + Pre-filter applied during KNN (concept IDs, domain, vocabulary, standard). + Also caps k to ``concept_filter.limit`` when set. Returns ------- - Tuple[Mapping[int, float], ...], optional - A tuple of dictionaries containing nearest concept matches for each query vector. The outer tuple is of length q (number of query vectors), and each inner dictionary maps concept IDs to their similarity scores with the query embedding (having k entries corresponding to the k nearest neighbors). - If retrieval fails or if any required parameters are missing, returns None. + Tuple[Tuple[NearestConceptMatch, ...], ...], optional + Shape ``(q, ≤k)``. Returns None when the interface is unavailable or + no embedding is provided. """ if not HAS_OMOP_EMB: return None - + embedding_reader = get_embedding_reader_interface(kg) if not embedding_reader: logger.info("Embedding interface not available in KG.") return None - - if not text_embedding_model: - logger.info("No text embedding model specified.") - return None - - if not index_type: - logger.info("No index type specified for retrieval.") - return None - - if metric_type is None: - logger.info("No metric type specified for retrieval.") - return None - - if _PARSE_INDEX_TYPE is None or _PARSE_METRIC_TYPE is None: - logger.info("Embedding type parsers are unavailable; cannot validate metric/index inputs.") - return None - try: - resolved_index_type = _PARSE_INDEX_TYPE(index_type) - resolved_metric_type = _PARSE_METRIC_TYPE(metric_type) - except ValueError as exc: - logger.info(f"Invalid embedding retrieval parameters: {exc}") - return None - - if not embedding_reader.is_model_registered(index_type=resolved_index_type): - logger.info(f"Model '{text_embedding_model}' not registered.") - return None - - if text_embedding is None: + if not embedding_reader.is_model_registered(): + logger.info("Model '%s' not registered.", embedding_reader.canonical_model_name) return None - similarity_scores_tuple = embedding_reader.get_nearest_concepts( - session=session, - index_type=resolved_index_type, - query_embedding=text_embedding, - concept_filter=concept_filter, # type: ignore - metric_type=resolved_metric_type, + nearest_concepts = embedding_reader.get_nearest_concepts( + query_embedding=query_embedding, + concept_filter=concept_filter, ) - - if similarity_scores_tuple is None: - logger.info("No similarity scores retrieved from embedding interface.") + if not nearest_concepts: + logger.info("No nearest concepts found for the given query embedding and filter.") return None - - if not all(isinstance(d, dict) for d in similarity_scores_tuple): - raise RuntimeError( - "Expected each item in similarity_scores_tuple to be a dictionary mapping concept IDs to scores." - ) - return similarity_scores_tuple \ No newline at end of file + return nearest_concepts \ No newline at end of file diff --git a/src/omop_graph/extensions/omop_alchemy.py b/src/omop_graph/extensions/omop_alchemy.py index f51a0d6..2f45c9f 100644 --- a/src/omop_graph/extensions/omop_alchemy.py +++ b/src/omop_graph/extensions/omop_alchemy.py @@ -2,10 +2,10 @@ import sqlalchemy as sa import sqlalchemy.orm as so from orm_loader.helpers import Base -from omop_alchemy.cdm.base import ReferenceTable, cdm_table, CDMTableBase, ReferenceContext +from omop_alchemy.cdm.base import ReferenceTable, cdm_table, CDMTableBase import functools -from enum import Enum, auto +from enum import Enum from dataclasses import dataclass class ClassIDEnum(Enum): diff --git a/src/omop_graph/graph/base.py b/src/omop_graph/graph/base.py index 46810b7..1580d98 100644 --- a/src/omop_graph/graph/base.py +++ b/src/omop_graph/graph/base.py @@ -1,6 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod from datetime import date +from functools import lru_cache from typing import Iterable, Optional, Literal from sqlalchemy.orm import Session @@ -17,6 +18,7 @@ class GraphBackend(ABC): """ @abstractmethod + @lru_cache(maxsize=200_000) def concept_view(self, concept_id: int) -> ConceptView: ... @@ -25,6 +27,7 @@ def predicate_kind(self, relationship_id: str) -> ClassIDEnum: ... @abstractmethod + @lru_cache(maxsize=10_000) def predicate_name(self, relationship_id: str) -> str: ... @@ -59,7 +62,7 @@ def iter_edges( 'out' for outgoing, 'in' for incoming. predicate_ids : frozenset[str], optional Filter by specific relationship IDs. - predicate_kinds : Set[PredicateKind], optional + predicate_kinds : Set[ClassIDEnum], optional Filter by semantic kind of relationship. active_only : bool If True, return only valid/active edges. diff --git a/src/omop_graph/graph/edges.py b/src/omop_graph/graph/edges.py index 9f67fea..1853e7b 100644 --- a/src/omop_graph/graph/edges.py +++ b/src/omop_graph/graph/edges.py @@ -18,10 +18,12 @@ from __future__ import annotations import logging -from dataclasses import dataclass, fields +from dataclasses import dataclass from datetime import date from typing import TYPE_CHECKING, Optional +from sqlalchemy.engine import Row + from ..extensions.omop_alchemy import ClassIDEnum if TYPE_CHECKING: @@ -84,8 +86,8 @@ def pretty(self, kg: KnowledgeGraph) -> str: return f"{s.concept_name} -[{pred.name}]-> {o.concept_name}" @classmethod - def from_query(cls, entry) -> "EdgeView": - data = dict(zip([f.name for f in fields(cls)], entry)) + def from_query(cls, entry: Row) -> "EdgeView": + data = dict(entry._mapping) if "class_id" in data: data["class_id"] = ClassIDEnum(data["class_id"]) return cls(**data) diff --git a/src/omop_graph/graph/kg.py b/src/omop_graph/graph/kg.py index 9ff2cab..94284ab 100644 --- a/src/omop_graph/graph/kg.py +++ b/src/omop_graph/graph/kg.py @@ -19,11 +19,13 @@ import logging import re +import os from datetime import date from functools import lru_cache -from typing import Dict, Optional, Tuple, Union, Literal, Generator, TYPE_CHECKING +from typing import Dict, Optional, Tuple, Literal, Generator, TYPE_CHECKING from dataclasses import dataclass, field +from sqlalchemy import Engine from sqlalchemy.exc import InvalidRequestError, PendingRollbackError from sqlalchemy.orm import Session, sessionmaker from omop_alchemy.cdm.handlers.fulltext import FullTextError @@ -32,7 +34,7 @@ from omop_emb import EmbeddingWriterInterface, EmbeddingReaderInterface, EmbeddingClient # Local Application Imports -from ..extensions.emb import MissingExtensionError, EmbeddingBackendType, EmbeddingProviderType +from ..extensions.emb import MissingExtensionError, EmbeddingBackendType, EmbeddingProviderType, EmbeddingMetricType from ..extensions.omop_alchemy import ClassIDEnum, RelationshipCache, validate_mapping_table from .base import GraphBackend from .constraints import SearchConstraintConcept @@ -77,18 +79,23 @@ class KnowledgeGraphEmbeddingConfiguration: Parameters ---------- + metric_type : EmbeddingMetricType + The similarity/distance metric to use for embedding comparisons (e.g., cosine, euclidean). + This is required to ensure that the correct type of index is used in the backend and that + similarity computations are consistent. + model_name : str + The canonical model name to use for the embedding reader interface (e.g., 'text-embedding-3-small:0.6b'). + Required for read-only embedding interface to determine which embeddings to retrieve for concepts. + Obtained from client if a client is provided, otherwise must be set explicitly for read-only use cases. backend_type : EmbeddingBackendType The embedding backend name (e.g., 'faiss', 'pinecone') or type to use. - base_storage_dir : str, optional - The directory where embeddings are stored. client : EmbeddingClient, optional An optional client instance for generating embeddings. If not provided, no writing operations can take place. provider_type : EmbeddingProviderType, optional The respective provider name (e.g., 'openai', 'ollama') or type if using a read-only embedding reader interface. - canonical_model_name : str, optional - The canonical model name to use for the embedding reader interface (e.g., 'text-embedding-3-small:0.6b'). - Required for read-only embedding interface to determine which embeddings to retrieve for concepts. - Obtained from client if a client is provided, otherwise must be set explicitly for read-only use cases. + provider_type : EmbeddingProviderType, optional + The provider type to use for the embedding reader interface (e.g., 'ollama'). + Required for read-only embedding interface to determine provider-specific canonical model name. compute_missing_embeddings : bool If True, the system will compute embeddings on-the-fly for any concept that is not yet present in the embedding store, and persist those embeddings back to the DB before running similarity scoring. @@ -96,12 +103,11 @@ class KnowledgeGraphEmbeddingConfiguration: the KG only holds a read-only interface, the flag has no effect, and missing concepts are silently skipped. Defaults to ``False`` so that unexpected writes do not occur when only a read-only configuration is given. """ - + metric_type: EmbeddingMetricType + model_name: Optional[str] = None backend_type: Optional[EmbeddingBackendType] = None - base_storage_dir: Optional[str] = None client: Optional[EmbeddingClient] = None provider_type: Optional[EmbeddingProviderType] = None - canonical_model_name: Optional[str] = None compute_missing_embeddings: bool = field(default=False) class KnowledgeGraph(GraphBackend): @@ -113,22 +119,28 @@ class KnowledgeGraph(GraphBackend): Parameters ---------- - session_factory : sessionmaker - The SQLAlchemy sessionmaker factory capable of creating separate sessions for - each database access. - + cdm_engine : Engine + The SQLAlchemy engine for the OMOP CDM database. """ def __init__( self, - session_factory: sessionmaker, + cdm_engine: Engine, emb_config: Optional[KnowledgeGraphEmbeddingConfiguration] = None, ): - self.session_factory = session_factory + self.cdm_engine = cdm_engine + self.session_factory = sessionmaker(bind=self.cdm_engine, future=True) - # Populate the relationshipcache - with self.session_factory() as session: - RelationshipCache.load(session) + try: + with self.session_factory() as session: + RelationshipCache.load(session) + except Exception as exc: + raise RuntimeError( + "Failed to load RelationshipCache. " + "The KnowledgeGraph requires relationship classification data. " + "Run `omop-graph relationship-classification` to populate it, " + "or `omop-graph omop-cdm` for a full bootstrap." + ) from exc # Embedding-specific private args self._emb_config = emb_config @@ -150,27 +162,38 @@ def emb(self) -> "EmbeddingWriterInterface | EmbeddingReaderInterface": try: from omop_emb.interface import EmbeddingWriterInterface, EmbeddingReaderInterface - + from omop_emb.config import ENV_OMOP_EMB_BACKEND + from omop_emb.backends.base_backend import resolve_backend + if self._emb_config is None: raise ValueError("Embedding configuration is not set. Please provide an EmbeddingConfiguration when initializing the KnowledgeGraph to use embedding features.") + + backend_type = self._emb_config.backend_type or os.getenv(ENV_OMOP_EMB_BACKEND, None) + if backend_type is None: + raise ValueError(f"Embedding backend type must be specified either in the configuration or via the {ENV_OMOP_EMB_BACKEND} environment variable.") + + backend = resolve_backend(backend_type) + if self._emb_config.client is not None: # Write-capable interface self._emb = EmbeddingWriterInterface( embedding_client=self._emb_config.client, - backend_name_or_type=self._emb_config.backend_type, - storage_base_dir=self._emb_config.base_storage_dir, + backend=backend, + metric_type=self._emb_config.metric_type, + omop_cdm_engine=self.cdm_engine, ) else: if self._emb_config.provider_type is None: raise ValueError("Provider type must be specified for read-only embedding interface.") - if self._emb_config.canonical_model_name is None: + if self._emb_config.model_name is None: raise ValueError("Canonical model name must be specified for read-only embedding interface.") # Read-only interface self._emb = EmbeddingReaderInterface( - canonical_model_name=self._emb_config.canonical_model_name, - backend_name_or_type=self._emb_config.backend_type, + model=self._emb_config.model_name, + backend=backend, + metric_type=self._emb_config.metric_type, + omop_cdm_engine=self.cdm_engine, provider_name_or_type=self._emb_config.provider_type, - storage_base_dir=self._emb_config.base_storage_dir, ) return self._emb @@ -266,18 +289,18 @@ def concept_id_by_code(self, vocabulary_id: str, concept_code: str) -> int: @lru_cache(maxsize=600_000) def concept_lookup( self, - label: str, + query_term: str, match_kind: LabelMatchKind, synonym: bool = False, search_constraint: Optional[SearchConstraintConcept] = None, sort: bool = True ) -> tuple[LabelMatch, ...]: """ - Resolve a label to concept_id(s). + Resolve a query to concept_id(s). Parameters ---------- - label : str + query_term : str The term to search for. match_kind : LabelMatchKind The kind of match to perform (exact, fulltext, partial). @@ -287,8 +310,8 @@ def concept_lookup( Additional filters for domain/vocabulary. """ - input_label = self._normalise_label(label) - if not input_label: + input_query_term = self._normalise_query_term(query_term) + if not input_query_term: return () if match_kind == LabelMatchKind.EXACT: @@ -301,31 +324,28 @@ def concept_lookup( raise ValueError(f"Unsupported search mode: {match_kind}") try: cn = fn( - input_label, + input_query_term, search_constraint=search_constraint, synonym=synonym, sort=sort, + engine=self.cdm_engine ) - except FullTextError: + except FullTextError as e: if match_kind == LabelMatchKind.FTS: - logger.info( - "Skipping full-text concept lookup because the optional OMOP " - "Alchemy tsvector columns are not available. Run `omop-maint " - "fulltext install` and `omop-maint fulltext populate` to enable " - "this resolver." - ) + logger.info(e) return () raise with self.session_factory() as session: matches = tuple( LabelMatch( - input_label=input_label, - matched_label=name, - concept_id=int(cid), + input_query=input_query_term, + matched_concept_label=name, + matched_concept_id=int(cid), match_kind=match_kind, is_standard=is_standard, is_active=is_active, + synonym=synonym, ) for cid, name, is_standard, is_active in session.execute(cn) ) @@ -422,8 +442,9 @@ def relationships( Yields ------- - Tuple[CURIE, PRED_CURIE, CURIE] - Triples (subject, predicate, object). + Tuple[int, str, int] + Triples of (subject_concept_id, relationship_id, object_concept_id). + When ``invert=True``, the triple is (object_concept_id, relationship_id, subject_concept_id). """ if invert: for s, p, o in self.relationships( @@ -434,17 +455,15 @@ def relationships( ): yield o, p, s return - - - for s,p,o in session.execute( + for s, p, o in session.execute( q_relationships( - subjects=objects, + subjects=subjects, predicates=predicates, - objects=subjects, + objects=objects, ) ): - yield o, p, s + yield s, p, o def reverse_predicate_id(self, relationship_id: str) -> Optional[str]: @@ -453,11 +472,11 @@ def reverse_predicate_id(self, relationship_id: str) -> Optional[str]: """ return self.predicate(relationship_id).reverse_id - def _normalise_label(self, s: str) -> str: + def _normalise_query_term(self, query_term: str) -> str: """ Normalize a string for lookup (lowercase, single spaces). """ - return re.sub(r"\s+", " ", s.strip().lower()) + return re.sub(r"\s+", " ", query_term.strip().lower()) @lru_cache(maxsize=1_000_000) def edges( @@ -628,7 +647,7 @@ def synonyms_for_concept(self, concept_id: int) -> tuple[str, ...]: """ with self.session_factory() as session: rows = session.execute(q_concept_synonym_filtered(concept_id)).all() - return tuple(row.concept_synonym_name for row in rows) + return tuple(row.name for row in rows) def rollback_session(self) -> None: """ @@ -729,10 +748,16 @@ def clear_caches(self) -> None: Clear all LRU caches associated with the graph. """ self.concept_view.cache_clear() + self.concept_views.cache_clear() self.concept_id_by_code.cache_clear() self.concept_ids_by_label.cache_clear() self.concept_lookup.cache_clear() self.predicate.cache_clear() self.predicate_name.cache_clear() self.parents.cache_clear() + self.children.cache_clear() + self.roots.cache_clear() + self.leaves.cache_clear() + self.singletons.cache_clear() + self.synonyms_for_concept.cache_clear() self.edges.cache_clear() \ No newline at end of file diff --git a/src/omop_graph/graph/nodes.py b/src/omop_graph/graph/nodes.py index 996d393..bb67d94 100644 --- a/src/omop_graph/graph/nodes.py +++ b/src/omop_graph/graph/nodes.py @@ -16,7 +16,7 @@ from collections import defaultdict from dataclasses import dataclass from datetime import date -from enum import Enum, auto +from enum import Enum from html import escape from itertools import chain from typing import Dict, Iterable, List, Optional, Tuple @@ -155,8 +155,9 @@ class LabelMatchKind(Enum): - PARTIAL: Partial match (fuzzy) substrings with ILIKE. - EMBEDDING: Match based on vector similarity. - It currently does not distinguish between synonym vs. concept_name matches as they are recognised as identical. - Could be extended in the future if needed. + It does not use synonym vs. concept_name as a ranking signal, + as these are treated as identical quality. + The ``LabelMatch.synonym`` field carries that distinction for callers that need it. """ EXACT = 0 @@ -175,27 +176,32 @@ class LabelMatch: Parameters ---------- - input_label : str + input_query : str The original text that was searched. - matched_label : str + matched_concept_label : str The text in the database that matched (concept name or synonym). - concept_id : int - The ID of the matched concept. + matched_concept_id : int + The OMOP Concept ID of the matched concept. match_kind : LabelMatchKind - How the match was found (Exact, Synonym, etc.). + How the match was found (Exact, FTS, Partial, Embedding). is_standard : bool Whether the matched concept is Standard. is_active : bool Whether the matched concept is currently valid. + synonym : bool + True if the match came from the ``concept_synonym`` table rather than + the primary ``concept_name`` field. This is informational only and does + not affect priority ordering. See ``LabelMatchKind`` for ranking. """ - input_label: str - matched_label: str - concept_id: int + input_query: str + matched_concept_label: str + matched_concept_id: int match_kind: LabelMatchKind is_standard: bool is_active: bool + synonym: bool def _repr_html_(self) -> str: """ @@ -224,8 +230,8 @@ def _repr_html_(self) -> str: return f"""
- {escape(self.matched_label)} - → concept_id {self.concept_id} + {escape(self.matched_concept_label)} + → concept_id {self.matched_concept_id}
{kind_badge} {std_badge} {active_badge} @@ -274,9 +280,12 @@ def from_matches(cls, matches: Iterable[LabelMatch]) -> LabelMatchGroupView: """ grouped: Dict[int, List[LabelMatch]] = defaultdict(list) for m in matches: - grouped[m.concept_id].append(m) + grouped[m.matched_concept_id].append(m) - grouped_tuple = {cid: tuple(ms) for cid, ms in grouped.items()} + grouped_tuple = { + cid: tuple(sorted(ms, key=lambda m: m.match_kind.value)) + for cid, ms in grouped.items() + } return cls(groups=grouped_tuple) def __iter__(self): @@ -322,12 +331,12 @@ def _repr_html_(self) -> str: reasons.append("active" if best.is_active else "inactive") # Collect other matched synonyms - other_labels = ", ".join(escape(m.matched_label) for m in ms[1:]) + other_labels = ", ".join(escape(m.matched_concept_label) for m in ms[1:]) rows.append(f""" {cid} - {escape(best.matched_label)} + {escape(best.matched_concept_label)} {escape(", ".join(reasons))} {other_labels if other_labels else "—"} diff --git a/src/omop_graph/graph/paths.py b/src/omop_graph/graph/paths.py index 2735bb7..7950e6b 100644 --- a/src/omop_graph/graph/paths.py +++ b/src/omop_graph/graph/paths.py @@ -22,15 +22,12 @@ Any, Dict, List, - Literal, Optional, Set, Tuple, Union, ) -import numpy as np - # Local Application Imports from omop_graph.extensions.omop_alchemy import ClassIDEnum from omop_graph.graph.edges import EdgeView @@ -175,27 +172,35 @@ def explain(self, kg: "KnowledgeGraph") -> str: return "\n ↳ ".join(parts) -def reconstruct_paths(source, target, meet, parents_fwd, parents_bwd): +def reconstruct_paths( + source, + target, + meet, + parents_fwd, + parents_bwd, + concept_standard_map: Dict[int, bool], +): """ - Helper function to reconstruct full paths from bidirectional BFS parent pointers. + Reconstruct full paths from bidirectional BFS parent pointers. + + Parameters + ---------- + concept_standard_map : dict[int, bool] + Mapping of concept_id → is_standard for all nodes discovered during BFS. + Built with a single batched ``kg.concept_views`` call after the BFS completes + so that every ``Node`` carries the correct flag with zero extra DB round-trips. """ + def std(cid: int) -> bool: + return concept_standard_map.get(cid, False) + def left(n): if n == source: return [()] out = [] for p, pred in parents_fwd[n]: for L in left(p): - # We need to construct Nodes here. In raw BFS we only tracked IDs. - # NOTE: This reconstruction assumes we fetch 'is_standard' later or ignore it here. - # For strictly typing PathStep, we create dummy Nodes here or need access to KG. - # Assuming simple reconstruction for now. - # To fix strictly: BFS needs to store Node info or we look it up. - # For now, we instantiate Nodes with is_standard=False as placeholders if strictly required, - # but usually the calling function enriches this. - subj = Node(p, False) - obj = Node(n, False) - out.append(L + (PathStep(subj, pred, obj),)) + out.append(L + (PathStep(Node(p, std(p)), pred, Node(n, std(n))),)) return out def right(n): @@ -204,9 +209,7 @@ def right(n): out = [] for nxt, pred in parents_bwd[n]: for R in right(nxt): - subj = Node(n, False) - obj = Node(nxt, False) - out.append((PathStep(subj, pred, obj),) + R) + out.append((PathStep(Node(n, std(n)), pred, Node(nxt, std(nxt))),) + R) return out return [GraphPath(L + R) for L in left(meet) for R in right(meet)] @@ -379,11 +382,18 @@ def find_shortest_paths( else None ) + # One batched lookup to get is_standard for every discovered concept so that + # reconstructed Node objects carry the correct flag (zero extra per-node DB calls). + all_discovered = tuple(set(depth_fwd.keys()) | set(depth_bwd.keys())) + concept_standard_map: Dict[int, bool] = { + v.concept_id: v.standard_concept for v in kg.concept_views(all_discovered) + } + paths: List[GraphPath] = [] for meet in meeting_nodes: - # Note: reconstruction logic needs careful implementation to create proper Node objects - # if using the simplified 'reconstruct_paths' helper above. - paths.extend(reconstruct_paths(source, target, meet, parents_fwd, parents_bwd)) + paths.extend( + reconstruct_paths(source, target, meet, parents_fwd, parents_bwd, concept_standard_map) + ) if len(paths) >= max_paths: break @@ -539,13 +549,19 @@ def find_shortest_paths_batch( else: bwd_frontier = next_frontier - # Reconstruct paths if not meeting_nodes: return [] + all_discovered = tuple(set(depth_fwd.keys()) | set(depth_bwd.keys())) + concept_standard_map: Dict[int, bool] = { + v.concept_id: v.standard_concept for v in kg.concept_views(all_discovered) + } + paths: List[GraphPath] = [] for meet in meeting_nodes: - paths.extend(reconstruct_paths(source, target, meet, parents_fwd, parents_bwd)) + paths.extend( + reconstruct_paths(source, target, meet, parents_fwd, parents_bwd, concept_standard_map) + ) if len(paths) >= max_paths: break @@ -573,7 +589,7 @@ class StandardConcept: separation: int original_id: int original_name: str - matched_label: str + matched_concept_label: str match_kind: LabelMatchKind synonym: bool hierarchy_cost: float = 0.0 @@ -614,38 +630,40 @@ def find_standard_paths( predicate_kinds: Optional[frozenset[Any]] = None, max_depth: int = 6, max_concepts: Optional[int] = None, - num_hops: int = 1, *args, **kwargs, ) -> List[StandardConcept]: """ - Search for Standard Concepts related to a target ID, starting from a candidate. + Search for Standard Concepts reachable from a candidate, verified against a target ancestor. - This method traverses from the candidate (Non-Standard) concept to find - Standard Concepts, then verifies if those Standard Concepts are ancestors - of the target concept in the hierarchy. + Starting from the candidate, outgoing edges are walked and only Standard Concept + neighbors are enqueued (non-standard neighbors are skipped to prevent graph explosion). + When a Standard Concept is reached its ancestry is verified against ``target`` via + ``concept_ancestor``. + It is never expanded further to standard_concepts related to this standard_concept, + as we want to find the closest standard_concept to the candidate that satisfies the + ancestor constraint, and expanding further would only find more distant standard_concepts, + thus diluting the results. Parameters ---------- kg : KnowledgeGraph The graph instance. target : int - The ancestor concept ID to check against. + The ancestor concept ID to verify candidates against. candidate : CandidateHit The search hit to start traversal from. predicate_kinds : frozenset, optional Allowed edge types for traversal. max_depth : int - Max separation levels allowed in the ancestor check. + Maximum ``min_levels_of_separation`` allowed in the ``concept_ancestor`` check. max_concepts : int, optional Stop after finding this many unique standard concepts. - num_hops : int - Max hops allowed from the candidate to reach a standard concept. Returns ------- list[StandardConcept] - The resolved concepts. + The resolved standard concepts that satisfy the ancestor constraint. """ source_view = kg.concept_view(candidate.concept_id) source_is_std = source_view.standard_concept if source_view else False @@ -676,10 +694,6 @@ def find_standard_paths( if max_concepts and len(found_standard_concepts) >= max_concepts: break - # Prevent infinite loops / deep traversals - if iterations > num_hops: - continue - if subject_node.is_standard: # We found a standard concept -> Check ancestry with target potential_ancestor = kg.get_potential_ancestor( @@ -700,7 +714,7 @@ def find_standard_paths( separation=potential_ancestor.min_levels_of_separation, original_id=candidate.concept_id, original_name=source_view.concept_name, - matched_label=candidate.matched_label, + matched_concept_label=candidate.matched_concept_label, match_kind=mk, synonym=candidate.synonym, ) @@ -720,16 +734,15 @@ def find_standard_paths( if not edges: continue - # Singular trip to the DB for object views - object_ids = tuple(e.object_id for e in edges) - object_views = kg.concept_views(object_ids) + # Batch-fetch views keyed by concept_id. Using a dict avoids the silent + # misalignment that zip produces when two edges share an object id and + # concept_views deduplicates the result set. + unique_object_ids = tuple(dict.fromkeys(e.object_id for e in edges)) + view_map = {v.concept_id: v for v in kg.concept_views(unique_object_ids)} - for edge, object_view in zip(edges, object_views): + for edge in edges: object_id = edge.object_id - - if object_view.concept_id != object_id: - object_view = kg.concept_view(object_id) - + object_view = view_map.get(object_id) or kg.concept_view(object_id) object_is_std = object_view.standard_concept # Optimization: Only traverse to Standard concepts @@ -737,9 +750,6 @@ def find_standard_paths( continue next_iterations = iterations + 1 - if next_iterations > num_hops: - continue - prev_best_iteration = visited_min_iteration.get(object_id) if prev_best_iteration is not None and prev_best_iteration <= next_iterations: continue @@ -798,56 +808,60 @@ def from_path( kg: "KnowledgeGraph", path: GraphPath, match_kind: LabelMatchKind, - embedding_sims: Optional[np.ndarray] = None, + source_concept_id: Optional[int] = None, ) -> "PathProfile": """ Analyze a path to determine the 'Standard Anchor'. - It traverses the path from the candidate term. The first Standard Concept - encountered via a MAPPING or VERSIONING edge is promoted as the Anchor. + The first Standard Concept encountered via an IDENTITY edge is promoted as + the anchor. + + Notes + ----- + For zero-hop paths (source == target), ``source_concept_id`` + must be provided; a ``ValueError`` is raised otherwise. + + Parameters + ---------- + source_concept_id : int, optional + Required when ``path`` has no steps (i.e. source == target). """ - # Path Traversal - standard_anchor: Optional[Tuple[int, str]] = None + if not path.steps: + if source_concept_id is None: + raise ValueError( + "source_concept_id is required for zero-hop paths " + "(find_shortest_paths was called with source == target)." + ) + view = kg.concept_view(source_concept_id) + return cls( + concept_id=source_concept_id, + concept_name=view.concept_name, + is_standard=view.standard_concept, + original_concept_id=source_concept_id, + original_concept_name=view.concept_name, + path=path, + ) - # Pre-fetch views to check standard status - # path.nodes() returns tuple of IDs (start + all objects) node_ids = path.nodes() - concept_views = kg.concept_views(node_ids) - - # NOTE: kg.concept_views usually returns tuple. - # If order is guaranteed, we can index by step. - # Ideally, map by ID to be safe. - view_map = {v.concept_id: v for v in concept_views} - - # Helper to get view by index in path sequence + view_map = {v.concept_id: v for v in kg.concept_views(node_ids)} + def get_view(idx): - cid = node_ids[idx] - return view_map[cid] + return view_map[node_ids[idx]] predicate_kinds = kg.predicate_kinds(tuple(p.predicate for p in path.steps)) + standard_anchor: Optional[Tuple[int, str]] = None for step_idx in range(len(path.steps)): - predicate_kind = predicate_kinds[step_idx] - - # We promote the first swap to a standard concept as the anchor point - # Check Next Node (index + 1) next_view = get_view(step_idx + 1) - - is_translation_edge = predicate_kind in ( - ClassIDEnum.IDENTITY, - ) - if ( - is_translation_edge + predicate_kinds[step_idx] is ClassIDEnum.IDENTITY and not standard_anchor and next_view.standard_concept ): standard_anchor = (next_view.concept_id, next_view.concept_name) - - # Logic for scoring/indices removed as it wasn't used in return first_view = get_view(0) - + if standard_anchor is None: concept_id = first_view.concept_id concept_name = first_view.concept_name @@ -898,7 +912,8 @@ def from_path( Construct an explanation by combining the path, the trace log, and semantic profiles. """ steps: List[PathExplanationStep] = [] - profile = PathProfile.from_path(kg, path, match_kind=match_kind) + source = path.steps[0].subject.concept_id if path.steps else (trace.seeds[0] if trace.seeds else None) + profile = PathProfile.from_path(kg, path, match_kind=match_kind, source_concept_id=source) for step in path.steps: ts = trace_contains_step(trace, step) diff --git a/src/omop_graph/graph/queries.py b/src/omop_graph/graph/queries.py index 118aa83..ee0f906 100644 --- a/src/omop_graph/graph/queries.py +++ b/src/omop_graph/graph/queries.py @@ -17,7 +17,7 @@ from typing import Optional, Tuple, Literal, Union from datetime import date -from sqlalchemy import and_, case, exists, func, literal, select +from sqlalchemy import and_, case, exists, func, literal, select, Engine, inspect, column from sqlalchemy.orm import aliased from sqlalchemy.sql import Select @@ -34,7 +34,7 @@ Relationship, ) -from ..extensions.omop_alchemy import RelationshipClass, RelationshipMapping, ClassIDEnum +from ..extensions.omop_alchemy import RelationshipMapping, ClassIDEnum from .constraints import SearchConstraintConcept @@ -209,17 +209,18 @@ def q_concept_synonym() -> Select: def q_concept_name_match( - name: str, + query_concept_name: str, search_constraint: Optional[SearchConstraintConcept] = None, synonym: bool = False, sort: bool = True, + **kwargs, ) -> Select: """ Query for exact case-insensitive matches on concept names. Parameters ---------- - name : str + query_concept_name : str The concept name to match. search_constraint : SearchConstraintConcept, optional Additional filters (domain, vocab). @@ -235,11 +236,11 @@ def q_concept_name_match( if synonym: base_stmt = q_concept_synonym().where( - func.lower(name_expr) == func.lower(name) + func.lower(name_expr) == func.lower(query_concept_name) ) else: base_stmt = q_concept_name().where( - func.lower(name_expr) == func.lower(name) + func.lower(name_expr) == func.lower(query_concept_name) ) if search_constraint: if not isinstance(search_constraint, SearchConstraintConcept): @@ -253,18 +254,19 @@ def q_concept_name_match( def q_concept_name_ilike( - term: str, + query_concept_name: str, search_constraint: Optional[SearchConstraintConcept] = None, synonym: bool = False, sort: bool = True, + **kwargs ) -> Select: """ Query for partial matches on concept names using ILIKE. Parameters ---------- - term : str - The search term (without wildcards; wildcards are added automatically). + query_concept_name : str + The concept name to search for. search_constraint : SearchConstraintConcept, optional Additional filters. synonym : bool, optional @@ -277,13 +279,16 @@ def q_concept_name_ilike( """ name_expr = Concept_Synonym.concept_synonym_name if synonym else Concept.concept_name + if "%" in query_concept_name: + raise ValueError("query_concept_name should not contain wildcards like '%'.") + if synonym: base_stmt = q_concept_synonym().where( - name_expr.ilike(f"%{term}%") + name_expr.ilike(f"%{query_concept_name}%") ) else: base_stmt = q_concept_name().where( - name_expr.ilike(f"%{term}%") + name_expr.ilike(f"%{query_concept_name}%") ) if search_constraint: if not isinstance(search_constraint, SearchConstraintConcept): @@ -297,7 +302,9 @@ def q_concept_name_ilike( def q_concept_name_fulltext( - term: str, + query_concept_name: str, + *, + engine: Engine, search_constraint: Optional['SearchConstraintConcept'] = None, synonym: bool = False, sort: bool = True, @@ -314,8 +321,8 @@ def q_concept_name_fulltext( Parameters ---------- - term : str - The search term to match. + query_concept_name : str + The concept name to search for. search_constraint : SearchConstraintConcept, optional Additional filters (domain, vocab). synonym : bool, optional @@ -324,22 +331,22 @@ def q_concept_name_fulltext( """ name_expr = Concept_Synonym.concept_synonym_name if synonym else Concept.concept_name - if synonym: - vector = Concept_Synonym.__table__.c.get(CONCEPT_SYNONYM_NAME_TSVECTOR_COLUMN) - stmt = q_concept_synonym() - else: - vector = Concept.__table__.c.get(CONCEPT_NAME_TSVECTOR_COLUMN) - stmt = q_concept_name() + inspector = inspect(engine) + target_table = Concept_Synonym if synonym else Concept + target_col = CONCEPT_SYNONYM_NAME_TSVECTOR_COLUMN if synonym else CONCEPT_NAME_TSVECTOR_COLUMN + stmt = q_concept_synonym() if synonym else q_concept_name() - if vector is None: + tsvector_col = next((c["name"] for c in inspector.get_columns(target_table.__tablename__) + if c["name"] == target_col), None) + + if tsvector_col is None: raise FullTextError( - "Full-text search is disabled because the optional OMOP Alchemy " - "tsvector columns are not registered on the current ORM metadata. " - "Run `omop-maint fulltext install` and `omop-maint fulltext populate` " - "to enable FTS, or skip LabelMatchKind.FTS resolvers." + f"Full-text search column '{target_col}' not found in table '{target_table.__tablename__}'. " + "Make sure to run 'omop-maint fulltext install' and 'omop-maint fulltext populate' to set up full-text search." ) - - query = func.plainto_tsquery("english", term) + + vector = column(tsvector_col) + query = func.plainto_tsquery("english", query_concept_name) stmt = stmt.where(vector.op("@@")(query)) # Hits the GIN index instantly @@ -422,17 +429,24 @@ def q_predicate_row_with_ancestry(relationship_id: str) -> Select: def q_all_predicates_with_ancestry() -> Select: - """Query all predicates with derived ancestry direction flags.""" + """Query all predicates with derived ancestry direction flags and classification.""" Rel = Relationship Rev = aliased(Relationship) - return select( - Rel.relationship_id, - Rel.relationship_name, - Rel.reverse_relationship_id, - Rel.is_hierarchical, - Rel.defines_ancestry.label("anc_down"), - Rev.defines_ancestry.label("anc_up"), - ).join(Rev, Rel.reverse_relationship_id == Rev.relationship_id) + Rm = aliased(RelationshipMapping) + return ( + select( + Rel.relationship_id, + Rel.relationship_name, + Rel.reverse_relationship_id, + Rel.is_hierarchical, + Rel.defines_ancestry.label("anc_down"), + Rev.defines_ancestry.label("anc_up"), + Rm.class_id, + Rm.subclass_id, + ) + .join(Rev, Rel.reverse_relationship_id == Rev.relationship_id) + .join(Rm, Rel.relationship_id == Rm.relationship_id) + ) def q_edges( @@ -452,14 +466,14 @@ def q_edges( Obj = aliased(Concept) stmt = select( - Concept_Relationship.concept_id_1, - Concept_Relationship.relationship_id, - Concept_Relationship.concept_id_2, + Concept_Relationship.concept_id_1.label("subject_id"), + Concept_Relationship.relationship_id.label("predicate_id"), + Concept_Relationship.concept_id_2.label("object_id"), Concept_Relationship.valid_start_date, Concept_Relationship.valid_end_date, Concept_Relationship.invalid_reason, RelationshipMapping.class_id, - RelationshipMapping.subclass_id + RelationshipMapping.subclass_id, ).join( RelationshipMapping, Concept_Relationship.relationship_id == RelationshipMapping.relationship_id @@ -657,7 +671,7 @@ def q_concept_vocabulary_ids() -> Select: def q_concept_potential_ancestor(child_id: int, parent_id: int) -> Select: """ - Check if a parent is an ancestor of a child (separation > 1). + Check if a parent is an ancestor of a child (including immediate parent). """ return select( Concept_Ancestor.ancestor_concept_id, @@ -667,7 +681,7 @@ def q_concept_potential_ancestor(child_id: int, parent_id: int) -> Select: and_( Concept_Ancestor.ancestor_concept_id == parent_id, Concept_Ancestor.descendant_concept_id == child_id, - Concept_Ancestor.min_levels_of_separation > 1, + Concept_Ancestor.min_levels_of_separation > 0, ) ) @@ -681,10 +695,11 @@ def q_concept_num_ancestors(concept_ids: Tuple[int, ...]) -> Select: func.count(Concept_Ancestor.ancestor_concept_id).label("num_ancestors"), ) .join( - Concept_Ancestor, + Concept_Ancestor, Concept.concept_id == Concept_Ancestor.descendant_concept_id ) .where(Concept.concept_id.in_(concept_ids)) + .where(Concept_Ancestor.min_levels_of_separation > 0) .group_by(Concept.concept_id) ) @@ -702,6 +717,7 @@ def q_concept_num_descendants(concept_ids: Tuple[int, ...]) -> Select: Concept_Ancestor, Concept.concept_id == Concept_Ancestor.ancestor_concept_id ) .where(Concept.concept_id.in_(concept_ids)) + .where(Concept_Ancestor.min_levels_of_separation > 0) .group_by(Concept.concept_id) ) diff --git a/src/omop_graph/graph/scoring.py b/src/omop_graph/graph/scoring.py index 111697a..577b451 100644 --- a/src/omop_graph/graph/scoring.py +++ b/src/omop_graph/graph/scoring.py @@ -14,7 +14,7 @@ import re from dataclasses import dataclass, field from difflib import SequenceMatcher -from typing import TYPE_CHECKING, List, Optional, Tuple, Mapping +from typing import TYPE_CHECKING, List, Optional, Tuple import numpy as np @@ -23,6 +23,7 @@ if TYPE_CHECKING: from omop_graph.graph.kg import KnowledgeGraph + from omop_emb.utils.embedding_utils import NearestConceptMatch logger = logging.getLogger(__name__) @@ -40,7 +41,8 @@ class StandardConceptWithScore(StandardConcept): embedding_score : float, optional The cosine similarity score from the embedding model. relevance : float - The composite relevance score (embedding * textual similarity). + The relevance score used for ranking: embedding similarity when available, + textual similarity otherwise. parsimony_penalty : float Penalty based on graph distance (separation). broadness_bonus : float @@ -87,45 +89,62 @@ def score_standard_concepts( text: str, standard_concepts: tuple[StandardConcept, ...], kg: "KnowledgeGraph", - similarity_scores_with_concept_ids: Optional[Tuple[Mapping[int, float], ...]] = None, + nearest_concept_matches: Optional[Tuple[Tuple[NearestConceptMatch, ...], ...]] = None, ) -> List[StandardConceptWithScore]: """ - Rank a list of standard concepts against a query text. + Attach scoring metrics to each standard concept. + + Notes + ----- + Scores are computed but the returned list preserves the input order. + Callers are responsible for sorting if ranking is required. Parameters ---------- text : str The original query text. standard_concepts : tuple[StandardConcept, ...] - The tuple of candidate concepts to score. + The candidate concepts to score. kg : KnowledgeGraph - The graph instance used for retrieving metadata (like ancestor counts). - similarity_scores_with_concept_ids : Tuple[Mapping[int, float], ...], optional - Pre-computed embedding similarity scores. The outer tuple corresponds to the query vectors in order, and each inner dictionary maps concept IDs to their similarity scores with the query embedding. + The graph instance used for retrieving metadata (ancestor counts). + nearest_concept_matches : Tuple[Tuple[NearestConceptMatch, ...], ...], optional + Pre-computed nearest-concept matches from the embedding index. The outer + tuple corresponds to query vectors in order; each inner tuple holds the + nearest matches for that query vector. Currently only a single query + vector is supported. Returns ------- list[StandardConceptWithScore] - The list of concepts with scores attached. + Scored concepts in the same order as ``standard_concepts``. """ # Get specificity scores (ancestor counts) for the standard concepts sc_dict = {sc.concept_id: sc for sc in standard_concepts} num_ancestors = kg.get_num_ancestors(tuple(sc_dict.keys())) # singular text - if similarity_scores_with_concept_ids is None: - similarity_scores_with_concept_ids = ({}, ) - - assert len(similarity_scores_with_concept_ids) == 1, "Currently only supports scoring with a single query embedding vector for the singular text input" - _similarity_scores_dict = similarity_scores_with_concept_ids[0] + if nearest_concept_matches is None: + nearest_concept_matches_dict = ({}, ) + else: + nearest_concept_matches_dict = tuple( + { + match.concept_id: match.similarity + for match in matches_for_query + } + for matches_for_query in nearest_concept_matches + ) + + assert len(nearest_concept_matches_dict) == 1, "Currently only supports scoring with a single query embedding vector for the singular text input" + nearest_concept_matches_dict_for_single_query = nearest_concept_matches_dict[0] + ranked_concepts = [ _score_standard_concept( text=text, kg=kg, standard_concept=sc, num_ancestors=num_ancestors.get(sc.concept_id, 0), - similarity_score=_similarity_scores_dict.get(sc.concept_id, None), + similarity_score=nearest_concept_matches_dict_for_single_query.get(sc.concept_id, None), ) for sc in standard_concepts ] @@ -175,7 +194,7 @@ def _score_standard_concept( if similarity_score is None: relevance = _textual_similarity_score( - query_text=text, matched_label=standard_concept.matched_label + query_text=text, matched_concept_label=standard_concept.matched_concept_label ) else: relevance = similarity_score @@ -197,7 +216,7 @@ def _score_standard_concept( def _textual_similarity_score( query_text: str, - matched_label: str, + matched_concept_label: str, similarity_threshold: float = 0.85, missing_penalty: float = 2.0, extra_penalty: float = 0.5, @@ -213,7 +232,7 @@ def _textual_similarity_score( ---------- query_text : str The user's query. - matched_label : str + matched_concept_label : str The label of the candidate concept. similarity_threshold : float, optional Minimum Levenshtein ratio to consider two tokens a 'match'. Default 0.85. @@ -235,7 +254,7 @@ def tokenize(text: str) -> List[str]: return [t for t in tokens if t not in stop_words] q_tokens = tokenize(query_text) - m_tokens = tokenize(matched_label) + m_tokens = tokenize(matched_concept_label) if not q_tokens or not m_tokens: return 0.0 diff --git a/src/omop_graph/graph/traverse.py b/src/omop_graph/graph/traverse.py index 03c3827..c9dd5a7 100644 --- a/src/omop_graph/graph/traverse.py +++ b/src/omop_graph/graph/traverse.py @@ -198,8 +198,15 @@ def traverse( TraceStep(depth=depth, node=node, expanded_edges=tuple(expanded)) ) - # Deduplicate edges found (multiple paths might traverse the same edge) - dedup = {(e.subject_id, e.predicate_id, e.object_id): e for e in edges_out} + # Deduplicate edges and drop any that target an unvisited node. + # The latter can happen when max_nodes terminates the loop while neighbour + # nodes are still queued but never expanded, which would break the invariant + # that every edge in the Subgraph is closed over its node set. + dedup = { + (e.subject_id, e.predicate_id, e.object_id): e + for e in edges_out + if e.object_id in visited + } sg = Subgraph(nodes=frozenset(visited), edges=tuple(dedup.values())) graph_trace = ( diff --git a/src/omop_graph/oaklib_interface/omop_factory.py b/src/omop_graph/oaklib_interface/omop_factory.py index 0bd0f62..d52d6e0 100644 --- a/src/omop_graph/oaklib_interface/omop_factory.py +++ b/src/omop_graph/oaklib_interface/omop_factory.py @@ -4,14 +4,85 @@ from sqlalchemy.engine import URL from .omop_resource import OMOPOntologyResource +from omop_graph.config import ( + ENV_OMOP_CDM_DB_URL, + ENV_OMOP_CDM_DB_HOST, + ENV_OMOP_CDM_DB_NAME, + ENV_OMOP_CDM_DB_PASSWORD, + ENV_OMOP_CDM_DB_PORT, + ENV_OMOP_CDM_DB_USER, + ENV_OMOP_CDM_DB_DRIVER, +) -OMOP_DATABASE_ENV_VAR = "OMOP_DATABASE_URL" + +def build_engine_string() -> URL: + """Compose a SQLAlchemy ``URL`` for the OMOP CDM database from environment variables. + + Returns + ------- + sqlalchemy.URL + + Notes + ----- + If ``OMOP_CDM_DB_URL`` is set it is used as-is for any backend, allowing + callers to supply a fully-qualified connection string without setting the + individual component variables. + + Raises + ------ + RuntimeError + If a required environment variable is missing. + """ + from sqlalchemy import URL + from sqlalchemy.engine import make_url + + optional_url = os.getenv(ENV_OMOP_CDM_DB_URL) + if optional_url: + return make_url(optional_url) + + # Required variables for composing the URL + driver = _get_required_env_variable(ENV_OMOP_CDM_DB_DRIVER) + user = _get_required_env_variable(ENV_OMOP_CDM_DB_USER) + password = _get_required_env_variable(ENV_OMOP_CDM_DB_PASSWORD) + host = _get_required_env_variable(ENV_OMOP_CDM_DB_HOST) + database = _get_required_env_variable(ENV_OMOP_CDM_DB_NAME) + port = int(_get_required_env_variable(ENV_OMOP_CDM_DB_PORT)) + return URL.create( + drivername=driver, + username=user, + password=password, + host=host, + port=port, + database=database, + ) + +def _get_required_env_variable(name: str) -> str: + """Get the value of a required environment variable. + + Parameters + ---------- + name : str + Environment variable name. + + Returns + ------- + str + Environment variable value. + + Raises + ------ + RuntimeError + If the environment variable is not set. + """ + value = os.getenv(name) + if value is None: + raise RuntimeError(f"Required environment variable {name!r} is not set.") + return value def omop_resource( *, url: Optional[Union[str, URL]] = None, - env_var: str = OMOP_DATABASE_ENV_VAR, slug: Optional[str] = "omop", ) -> OMOPOntologyResource: """ @@ -27,7 +98,7 @@ def omop_resource( The explicit database connection URL (highest priority). env_var : str, optional The name of the environment variable to check if `url` is None. - Defaults to 'OMOP_DATABASE_URL'. + Defaults to 'OMOP_CDM_DB_URL'. slug : str, optional A slug identifier for the resource. Defaults to 'omop'. @@ -41,11 +112,11 @@ def omop_resource( ValueError If neither `url` is provided nor the `env_var` is set. """ - resolved = url or os.getenv(env_var) + resolved = url or build_engine_string() if not resolved: raise ValueError( - f"No database URL provided and environment variable {env_var} is not set" + f"No database URL provided and required environment variables not set" ) return OMOPOntologyResource( diff --git a/src/omop_graph/oaklib_interface/omop_implementation.py b/src/omop_graph/oaklib_interface/omop_implementation.py index a85491b..68ef4cc 100644 --- a/src/omop_graph/oaklib_interface/omop_implementation.py +++ b/src/omop_graph/oaklib_interface/omop_implementation.py @@ -1,7 +1,7 @@ import logging import re from collections import defaultdict -from typing import Dict, Iterable, Iterator, List, Optional, Tuple, TYPE_CHECKING +from typing import Dict, Iterable, Iterator, List, Optional, Tuple import numpy as np from dotenv import load_dotenv @@ -29,13 +29,11 @@ from oaklib.interfaces.text_annotator_interface import nen_annotation from oaklib.types import CURIE, PRED_CURIE -from omop_alchemy.cdm.model import Concept, Concept_Relationship from omop_graph.graph import ( KnowledgeGraph, KnowledgeGraphEmbeddingConfiguration ) from omop_graph.extensions.omop_alchemy import ClassIDEnum -from omop_graph.extensions.emb import EmbeddingBackendType, MissingExtensionError, get_embedding_writer_interface from omop_graph.graph.constraints import SearchConstraintConcept from omop_graph.graph.nodes import LabelMatchKind from omop_graph.reasoning.grounding import GroundingConstraints, ground_term @@ -45,13 +43,9 @@ from omop_graph.oaklib_interface.omop_resource import OMOPOntologyResource from omop_graph.oaklib_interface.omop_factory import omop_resource -if TYPE_CHECKING: - from omop_emb import EmbeddingClient -from orm_loader.helpers.bootstrap import create_db -from sqlalchemy import create_engine, select +from sqlalchemy import create_engine from sqlalchemy.engine import URL -from sqlalchemy.orm import sessionmaker logger = logging.getLogger(__name__) @@ -240,13 +234,12 @@ def __init__( def _simple_tokenizer(self, text: str): for m in re.finditer(r"\b[\w\- ]{3,}\b", text): yield m.start(), m.end(), m.group() - + def annotate_text( self, text: str, - text_embedding: Optional[np.ndarray] = None, - text_embedding_model: Optional[str] = None, configuration: Optional[TextAnnotationConfiguration] = None, + query_embedding: Optional[np.ndarray] = None, annotations: Optional[Dict[str, Annotation]] = None, ) -> Iterator[TextAnnotation]: """ @@ -256,8 +249,9 @@ def annotate_text( ---------- text : str The input text to annotate. - text_embedding : np.ndarray - The embedding of the input text. + query_embedding : np.ndarray + Pre-computed query embedding for the input text. When None and the KG + has a writer interface, the embedding is computed on demand. configuration : TextAnnotationConfiguration, optional Configuration settings for annotation (e.g., token exclusion). annotations : Dict[str, Annotation], optional @@ -345,10 +339,9 @@ def split_annotations(ann): grounded = ground_term( resolver_pipeline=resolver_pipeline, kg=self.kg, - text=text, + query=text, constraints=constraints, - text_embedding=text_embedding, - text_embedding_model=text_embedding_model, + query_embedding=query_embedding, ) if not grounded: @@ -423,7 +416,7 @@ def basic_search( synonym=False, ) for lm in matches: - cid = lm.concept_id + cid = lm.matched_concept_id if cid not in seen: seen.add(cid) yield self._predicate_curie(cid) @@ -435,7 +428,7 @@ def basic_search( synonym=True, ) for lm in matches: - cid = lm.concept_id + cid = lm.matched_concept_id if cid not in seen: seen.add(cid) yield self._predicate_curie(cid) @@ -469,7 +462,7 @@ def __init__(self, *args, kg: KnowledgeGraph, **kwargs): def supports_reasoning(self) -> bool: return False - def entity_aliases(self, curie: CURIE) -> Iterable[str]: + def entity_aliases(self, curie: CURIE) -> List[str]: """ Retrieve aliases (synonyms and codes) for a given entity. @@ -480,7 +473,7 @@ def entity_aliases(self, curie: CURIE) -> Iterable[str]: Returns ------- - Iterable[str] + List[str] A sorted list of aliases. """ cid = self._parse_concept(curie) @@ -494,7 +487,7 @@ def entity_aliases(self, curie: CURIE) -> Iterable[str]: # vocabulary-qualified code alias aliases.add(f"{cv.vocabulary_id}:{cv.concept_code}") - return sorted(aliases) + return list(sorted(aliases)) def parents(self, curie: CURIE) -> Iterable[CURIE]: """ @@ -528,7 +521,11 @@ def languages(self) -> Iterable[str]: def multilingual(self) -> bool: return False - def entities( + @multilingual.setter + def multilingual(self, value: bool) -> None: + pass # OMOP is always monolingual; setter required by base interface contract + + def entities( # type: ignore[override] self, domain: str | None = None, standard_only: bool = True, @@ -619,24 +616,25 @@ def curies_by_label(self, label: str) -> List[CURIE]: cids = self.kg.concept_ids_by_label(label.strip()) return [self._concept_curie(cid) for cid in cids] - def relationships( + def relationships( # type: ignore[override] self, - subjects: list[CURIE] | None = None, - predicates: list[str] | None = None, - objects: list[CURIE] | None = None, + subjects: Iterable[CURIE] | None = None, + predicates: Iterable[str] | None = None, + objects: Iterable[CURIE] | None = None, invert: bool = False, + **kwargs, ) -> Iterable[Tuple[CURIE, PRED_CURIE, CURIE]]: """ Query relationships between concepts. Parameters ---------- - subjects : list[CURIE] | None - List of subject CURIEs. - predicates : list[str] | None - List of predicate (relationship) IDs. - objects : list[CURIE] | None - List of object CURIEs. + subjects : Iterable[CURIE] | None + Subject CURIEs to filter on. + predicates : Iterable[str] | None + Predicate (relationship) IDs to filter on. + objects : Iterable[CURIE] | None + Object CURIEs to filter on. invert : bool If True, swaps subjects and objects in the query and result. @@ -676,7 +674,6 @@ def hierarchical_parents( return [self._concept_curie(p) for p in parents] def simple_mappings_by_curie(self, curie: CURIE): - cid = self._parse_concept(curie) raise NotImplementedError( "TODO: need to implement mapping interface and have self.sssom_mappings" ) @@ -760,7 +757,7 @@ def entailed_outgoing_relationships( pred_curie = self._predicate_curie(edge.predicate_id) # hierarchical entailment - if self.kg.predicate_kind(edge.predicate_id) == PredicateKind.ONTO_UP: + if self.kg.predicate_kind(edge.predicate_id) == ClassIDEnum.HIERARCHY: yield pred_curie, self._concept_curie(edge.object_id) for parent in self.kg.parents(edge.object_id): @@ -792,19 +789,16 @@ def entailed_incoming_relationships( ) with self.kg.session_factory() as session: - # Consume and close the session - edges = self.kg.iter_edges( + for edge in self.kg.iter_edges( session=session, concept_ids=concept_id, direction="in", - predicate_ids=frozenset(pred_filter) if pred_filter else None - ) - - for edge in edges: - yield ( - self._predicate_curie(edge.predicate_id), - self._concept_curie(edge.subject_id), - ) + predicate_ids=frozenset(pred_filter) if pred_filter else None, + ): + yield ( + self._predicate_curie(edge.predicate_id), + self._concept_curie(edge.subject_id), + ) def entailed_incoming_relationships_by_curie( self, *args, **kwargs @@ -837,7 +831,7 @@ def entailed_relationships_between( yield self._predicate_curie("is_a") -class OMOPAlchemyImplementation( +class OMOPAlchemyImplementation( # type: ignore[override] OMOPRelationGraphInterface, OMOPSearchInterface, OMOPTextAnnotatorInterface, @@ -898,27 +892,19 @@ def __init__( assert self.engine_string is not None, "No database URL provided for OMOPAlchemyImplementation" - self.engine = create_engine(self.engine_string, future=True, echo=False) - create_db(self.engine) + engine = create_engine(self.engine_string, future=True, echo=False) - self._session_factory = sessionmaker(self.engine) self._connection = None if kg is None: kg = KnowledgeGraph( - session_factory=self._session_factory, emb_config=kg_emb_config, + cdm_engine=engine ) bind_default_renderers(kg) super().__init__(kg=kg, **kwargs) - @property - def session_factory(self) -> sessionmaker: - """ - Return the factory to a session. - """ - return self._session_factory # TODO: Implement if necessary! def _all_relationships(self): diff --git a/src/omop_graph/oaklib_interface/omop_resource.py b/src/omop_graph/oaklib_interface/omop_resource.py index cd896b5..4ca5a7e 100644 --- a/src/omop_graph/oaklib_interface/omop_resource.py +++ b/src/omop_graph/oaklib_interface/omop_resource.py @@ -29,12 +29,12 @@ class OMOPOntologyResource(OntologyResource): Whether the resource is read-only. Defaults to True. """ - url: Optional[Union[str, URL]] = None - slug: Optional[str] = None - scheme: str = "omop_alchemy" - local: bool = False - in_memory: bool = False - readonly: bool = True + url: Optional[Union[str, URL]] = None # type: ignore[assignment] + slug: Optional[str] = None # type: ignore[assignment] + scheme: str = "omop_alchemy" # type: ignore[assignment] + local: bool = False # type: ignore[assignment] + in_memory: bool = False # type: ignore[assignment] + readonly: bool = True # type: ignore[assignment] def _parsed_url(self) -> Optional[URL]: """ diff --git a/src/omop_graph/reasoning/grounding.py b/src/omop_graph/reasoning/grounding.py index 9d8eb15..3afa894 100644 --- a/src/omop_graph/reasoning/grounding.py +++ b/src/omop_graph/reasoning/grounding.py @@ -15,7 +15,7 @@ import logging from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Optional, Tuple, Sequence, Mapping +from typing import List, Optional, Tuple import numpy as np @@ -32,15 +32,10 @@ ResolverPipeline, ) from omop_graph.extensions.emb import ( - EmbeddingIndexType, - EmbeddingMetricType, get_embedding_writer_interface, semantic_similarity, ) -if TYPE_CHECKING: - from omop_emb import EmbeddingClient - logger = logging.getLogger(__name__) @@ -71,13 +66,10 @@ class GroundingConstraints: def ground_term( resolver_pipeline: ResolverPipeline, kg: KnowledgeGraph, - text: str, - text_embedding: Optional[np.ndarray], - text_embedding_model: Optional[str], + query: str, + query_embedding: Optional[np.ndarray], constraints: GroundingConstraints, max_candidates: Optional[int] = None, - metric_type: Optional[EmbeddingMetricType] = None, - index_type: Optional[EmbeddingIndexType] = None, ) -> List[StandardConceptWithScore]: """ Ground a text string to a ranked list of standard OMOP concepts. @@ -88,18 +80,15 @@ def ground_term( The pipeline of search strategies to find initial candidates. kg : KnowledgeGraph The OMOP Knowledge Graph instance. - text : str - The input text to ground. - text_embedding : np.ndarray - The embedding vector for the input text. - text_embedding_model : str, optional - The name of the embedding model used to generate `text_embedding`. Used for RAG retrieval from the database. + query : str + The input query to ground. + query_embedding : np.ndarray + The embedding vector for the input query. When None and a writer interface is + available, the embedding is computed on demand from ``query``. constraints : GroundingConstraints Contextual constraints (parents, domains, etc.) to apply. max_candidates : int, optional Limit for the number of candidates returned. If None, returns all candidates. - metric_type : EmbeddingMetricType, optional - The similarity or distance metric to use for optional embedding-based scoring. Returns ------- @@ -117,34 +106,40 @@ def ground_term( if search_constraints is not None: kg.check_search_constraints(search_constraints) - # Calculate the text embedding on demand if possible - embedding_writer = get_embedding_writer_interface(kg) - if ( - embedding_writer is not None and - text_embedding is None - ): - if embedding_writer._embedding_client is None: - logger.info("Embedding interface is available but no embedding_client provided. Skipping embedding-based scoring.") - else: - text_embedding = embedding_writer.embed_texts(texts=(text,)) + # If no embedding was passed, try to compute one on demand via the writer interface. + # Falls back to None, which disables embedding-based features for this call. + if query_embedding is None: + embedding_writer = get_embedding_writer_interface(kg) + if embedding_writer is not None: + from omop_emb.embeddings import EmbeddingRole + query_embedding = embedding_writer.embed_texts( + texts=(query,), + embedding_role=EmbeddingRole.QUERY, + ) - if text_embedding is not None: - # TODO: Support grounding to more texts - assert text_embedding.shape[0] == 1, "text_embedding should have shape (1, embedding_dim) for a single term to be grounded." + if query_embedding is not None: + assert query_embedding.shape[0] == 1, ( + "query_embedding must have shape (1, D) — one vector per call to ground_term." + ) + else: + logger.info( + f"No text embedding provided for '{query}' and no embedding_writer available. " + "Embedding-based features will be disabled for this grounding operation." + ) resolved = list( resolver_pipeline.resolve( kg, - text, + query, constraints=search_constraints, - text_embedding=text_embedding, - text_embedding_model=text_embedding_model, - metric_type=metric_type, - index_type=index_type, + query_embedding=query_embedding, ) ) + if not resolved: + logger.info(f"No candidates found for '{query}' using the resolver pipeline: {resolver_pipeline}") + return [] - # Anchoring + # Hierarchy anchoring for hit in resolved: if constraints.parent_ids is not None: candidate_standard_concepts = find_standard_concepts( @@ -155,54 +150,42 @@ def ground_term( max_paths=None, predicate_kinds=constraints.predicate_kinds, ) - if not candidate_standard_concepts: concept_name = kg.concept_view(hit.concept_id).concept_name logger.debug( f"Failed hierarchy constraint: {hit.concept_id} ({concept_name}) " - f"has no path to parents {constraints.parent_ids} with {constraints.max_depth} max depth and predicates {constraints.predicate_kinds}." + f"has no path to parents {constraints.parent_ids} " + f"(max_depth={constraints.max_depth}, predicates={constraints.predicate_kinds})" ) continue - standard_concepts.extend(candidate_standard_concepts) else: - # Note: We currently require parent_ids for clinical safety/context raise NotImplementedError("Grounding without parent_ids is not supported.") if not standard_concepts: - logger.info(f"No standard concepts found for '{text}' after hierarchy validation.") + logger.info(f"No standard concepts found for '{query}' after hierarchy validation.") return [] - - similarity_scores_with_concept_ids = semantic_similarity( - kg=kg, - standard_concepts=standard_concepts, - text_embedding=text_embedding, - text_embedding_model=text_embedding_model, - metric_type=metric_type, - index_type=index_type, + nearest_concept_matches = ( + semantic_similarity(kg=kg, standard_concepts=standard_concepts, query_embedding=query_embedding) + if query_embedding is not None + else None ) - # Scoring - ranked_standard_concepts = score_standard_concepts( - text=text, + standard_concepts_with_score = score_standard_concepts( + text=query, standard_concepts=tuple(standard_concepts), kg=kg, - similarity_scores_with_concept_ids=similarity_scores_with_concept_ids + nearest_concept_matches=nearest_concept_matches, ) - # Keep one best-scoring entry per standard concept after scoring all evidence. best_by_concept_id: dict[int, StandardConceptWithScore] = {} - for concept in ranked_standard_concepts: + for concept in standard_concepts_with_score: existing = best_by_concept_id.get(concept.concept_id) if existing is None or concept.total_score > existing.total_score: best_by_concept_id[concept.concept_id] = concept - deduped_ranked = sorted( - best_by_concept_id.values(), - key=lambda sc: sc.total_score, - reverse=True, - ) + deduped_ranked = sorted(best_by_concept_id.values(), key=lambda sc: sc.total_score, reverse=True) return deduped_ranked[:max_candidates] if max_candidates is not None else deduped_ranked diff --git a/src/omop_graph/reasoning/resolvers/resolver_pipeline.py b/src/omop_graph/reasoning/resolvers/resolver_pipeline.py index 8d10f39..80067bf 100644 --- a/src/omop_graph/reasoning/resolvers/resolver_pipeline.py +++ b/src/omop_graph/reasoning/resolvers/resolver_pipeline.py @@ -103,19 +103,19 @@ def with_all_resolvers(cls) -> "ResolverPipeline": def resolve( self, kg: KnowledgeGraph, - text: str, + query: str, constraints: Optional[SearchConstraintConcept] = None, **kwargs ) -> Generator[CandidateHit, None, None]: """ - Execute the pipeline to find candidate concepts for the input text. + Execute the pipeline to find candidate concepts for the input query. Parameters ---------- kg : KnowledgeGraph The graph instance used for lookups. - text : str - The input text to resolve. + query : str + The input query to resolve. constraints : SearchConstraintConcept, optional Domain or vocabulary restrictions to apply to the search. Determines also the number of candidates returned for each resolver using the `limit` field. If None, no additional filtering is applied. @@ -129,8 +129,8 @@ def resolve( for resolver in self.resolvers: hits = resolver.resolve( - kg, - text, + kg=kg, + query=query, constraints=constraints, **kwargs ) @@ -141,6 +141,6 @@ def resolve( yield hit # Early stopping - if type(resolver) == self._stop_at: + if type(resolver) is self._stop_at: logger.info(f"Stopping pipeline after resolver {type(resolver).__name__} as configured.") break \ No newline at end of file diff --git a/src/omop_graph/reasoning/resolvers/resolvers.py b/src/omop_graph/reasoning/resolvers/resolvers.py index 3bb4aee..63a1200 100644 --- a/src/omop_graph/reasoning/resolvers/resolvers.py +++ b/src/omop_graph/reasoning/resolvers/resolvers.py @@ -21,8 +21,6 @@ from omop_graph.graph.nodes import LabelMatch, LabelMatchKind from omop_graph.extensions.emb import ( HAS_OMOP_EMB, - EmbeddingIndexType, - EmbeddingMetricType, get_neareast_concepts, ) @@ -42,13 +40,13 @@ class CandidateHit: The OMOP Concept ID. match_kind : LabelMatchKind The kind of match of this hit. - matched_label : str + matched_concept_label : str The specific text in the database (name or synonym) that matched. """ concept_id: int match_kind: LabelMatchKind - matched_label: str + matched_concept_label: str synonym: bool @@ -84,7 +82,7 @@ def synonym(self) -> bool: def get_matches( self, kg: KnowledgeGraph, - text: str, + query: str, constraints: Optional[SearchConstraintConcept] = None, sort: bool = False, **kwargs @@ -96,8 +94,8 @@ def get_matches( ---------- kg : KnowledgeGraph The graph instance. - text : str - The input text to search for. + query : str + The input query to search for. constraints : SearchConstraintConcept, optional Filters for domain/vocabulary. sort : bool, default False @@ -110,7 +108,7 @@ def get_matches( """ return tuple( kg.concept_lookup( - label=text, + query_term=query, match_kind=self.match_kind, synonym=self.synonym, search_constraint=constraints, @@ -121,7 +119,7 @@ def get_matches( def resolve( self, kg: KnowledgeGraph, - text: str, + query: str, constraints: Optional[SearchConstraintConcept] = None, **kwargs ) -> Iterable[CandidateHit]: @@ -132,8 +130,8 @@ def resolve( ---------- kg : KnowledgeGraph The graph instance. - text : str - The input text. + query : str + The input query. constraints : SearchConstraintConcept, optional Filters for concepts to consider in the search. Also limits the number of candidates returned using the `limit` field. @@ -143,12 +141,12 @@ def resolve( Iterable[CandidateHit] The formatted candidate hits. """ - matches = self.get_matches(kg, text, constraints=constraints, **kwargs) + matches = self.get_matches(kg, query, constraints=constraints, **kwargs) hits = [ CandidateHit( - concept_id=m.concept_id, + concept_id=m.matched_concept_id, match_kind=self.match_kind, - matched_label=m.matched_label, + matched_concept_label=m.matched_concept_label, synonym=self.synonym ) for m in matches @@ -224,47 +222,58 @@ def __init__(self) -> None: def get_matches( self, kg: KnowledgeGraph, - text: str, + query: str, constraints: Optional[SearchConstraintConcept] = None, - text_embedding: Optional[np.ndarray] = None, - text_embedding_model: Optional[str] = None, - metric_type: Optional[EmbeddingMetricType] = None, - index_type: Optional[EmbeddingIndexType] = None, sort: bool = False, + query_embedding: Optional[np.ndarray] = None, + **kwargs, ) -> Tuple[LabelMatch, ...]: - - with kg.session_factory() as session: - matches = get_neareast_concepts( - session=session, - kg=kg, - text_embedding=text_embedding, - text_embedding_model=text_embedding_model, - concept_filter=constraints, - metric_type=metric_type, - index_type=index_type, - ) - if matches is None: - return () - if text_embedding is not None: - assert text_embedding.shape[0] == 1, "text_embedding should have shape (1, embedding_dim) for a single query." - assert len(matches) == 1, "Expected get_neareast_concepts to return a single dictionary given the text_embedding shape (1, embedding_dim)." - matches = matches[0] # Unpack the single dictionary from the tuple - concept_views = kg.concept_views( - concept_ids=tuple(matches.keys()), - sort=sort + + from omop_emb.utils.embedding_utils import EmbeddingConceptFilter + concept_filter: Optional[EmbeddingConceptFilter] = ( + EmbeddingConceptFilter( + concept_ids=constraints.concept_ids, + domains=constraints.domains, + vocabularies=constraints.vocabularies, + require_standard=constraints.require_standard, + limit=constraints.limit, ) - label_matches = tuple( - LabelMatch( - input_label=text, - matched_label=cv.concept_name, - concept_id=int(cv.concept_id), - match_kind=LabelMatchKind.EMBEDDING, - is_standard=bool(cv.standard_concept), - is_active=cv.invalid_reason is None, - ) - for cv in concept_views + if isinstance(constraints, SearchConstraintConcept) + else None + ) + + if query_embedding is None: + logger.warning(f"No text embedding provided for {self.__class__.__name__}. Returning no matches.") + return () + + matches = get_neareast_concepts( + kg=kg, + query_embedding=query_embedding, + concept_filter=concept_filter, + ) + if matches is None: + return () + if query_embedding is not None: + assert query_embedding.shape[0] == 1, "query_embedding should have shape (1, embedding_dim) for a single query." + assert len(matches) == 1, "Expected get_neareast_concepts to return a single dictionary given the query_embedding shape (1, embedding_dim)." + matches = matches[0] # Unpack the single dictionary from the tuple + concept_views = kg.concept_views( + concept_ids=tuple(m.concept_id for m in matches), + sort=sort + ) + label_matches = tuple( + LabelMatch( + input_query=query, + matched_concept_label=cv.concept_name, + matched_concept_id=int(cv.concept_id), + match_kind=LabelMatchKind.EMBEDDING, + is_standard=bool(cv.standard_concept), + is_active=cv.invalid_reason is None, + synonym=False, # Embedding matches are based on the primary name, not synonyms ) - return label_matches + for cv in concept_views + ) + return label_matches # Default sequence of resolvers to be used in a pipeline diff --git a/tests/fixtures/mock_cdm.py b/tests/fixtures/mock_cdm.py index 99eabd2..7e92d04 100644 --- a/tests/fixtures/mock_cdm.py +++ b/tests/fixtures/mock_cdm.py @@ -31,7 +31,7 @@ @pytest.fixture(scope="module") -def mock_cdm_session_factory() -> sessionmaker: +def mock_cdm_engine() -> sa.Engine: engine = sa.create_engine("sqlite+pysqlite:///:memory:", future=True) tables = cast( list[sa.Table], @@ -55,12 +55,12 @@ def mock_cdm_session_factory() -> sessionmaker: with session_local() as session: seed_mock_cdm(session) - return session_local + return engine @pytest.fixture() def mock_cdm_kg( - mock_cdm_session_factory: sessionmaker, + mock_cdm_engine: sa.Engine, monkeypatch: pytest.MonkeyPatch, ) -> KnowledgeGraph: # Ensure cache does not leak between tests. @@ -70,7 +70,7 @@ def mock_cdm_kg( # Grounding tests here focus on SQL + resolver + path pipeline. monkeypatch.setattr("omop_graph.reasoning.grounding.get_embedding_writer_interface", lambda _kg: None) - return KnowledgeGraph(session_factory=mock_cdm_session_factory) + return KnowledgeGraph(cdm_engine=mock_cdm_engine) def seed_mock_cdm(session: Session) -> None: diff --git a/tests/test_embedding_optional.py b/tests/test_embedding_optional.py index e8a577f..a9a4b9f 100644 --- a/tests/test_embedding_optional.py +++ b/tests/test_embedding_optional.py @@ -11,7 +11,7 @@ import numpy as np import pytest -from sqlalchemy.orm import Session +from sqlalchemy import Engine from omop_graph.extensions import emb as emb_ext from omop_graph.extensions.emb import MissingExtensionError @@ -20,41 +20,6 @@ from omop_graph.graph.paths import StandardConcept -def test_get_neareast_concepts_returns_none_when_index_type_missing(monkeypatch: pytest.MonkeyPatch): - mock_reader = Mock() - monkeypatch.setattr(emb_ext, "get_embedding_reader_interface", lambda kg: mock_reader) - kg = cast(KnowledgeGraph, SimpleNamespace(emb=SimpleNamespace())) - - result = emb_ext.get_neareast_concepts( - session=Mock(spec=Session), - kg=kg, - text_embedding_model="test-model", - text_embedding=np.zeros((1, 2), dtype=np.float32), - concept_filter=None, - metric_type=None, - index_type=None, - ) - - assert result is None - - -def test_get_neareast_concepts_returns_none_when_metric_type_missing(monkeypatch: pytest.MonkeyPatch): - mock_reader = Mock() - monkeypatch.setattr(emb_ext, "get_embedding_reader_interface", lambda kg: mock_reader) - kg = cast(KnowledgeGraph, SimpleNamespace(emb=SimpleNamespace())) - - result = emb_ext.get_neareast_concepts( - session=Mock(spec=Session), - kg=kg, - text_embedding_model="test-model", - text_embedding=np.zeros((1, 2), dtype=np.float32), - concept_filter=None, - metric_type=None, - index_type=cast(emb_ext.EmbeddingIndexType, "flat"), - ) - - assert result is None - def test_get_embedding_interface_returns_none_for_missing_extension_error(): class BrokenKG: @@ -83,102 +48,6 @@ def fake_import(name, globals=None, locals=None, fromlist=(), level=0): _ = kg.emb -def test_write_path_forwards_index_type_and_only_missing_ids(monkeypatch: pytest.MonkeyPatch): - """Verify the omop-emb write-path contracts when compute_missing_embeddings is True. - - 1. ``index_type`` is forwarded to ``get_concepts_without_embedding``. - 2. ``add_to_db`` receives only the IDs reported as missing — not the full - candidate set passed to ``semantic_similarity``. - - Concept 3 ("gamma") is in ``standard_concepts`` but is NOT returned by - ``get_concepts_without_embedding``, so ``add_to_db`` must receive only (1, 2). - """ - class FakeEmbeddingInterface: - canonical_model_name = "test-model" - - def __init__(self): - self.last_missing_kwargs = None - self.last_add_kwargs = None - - def get_concepts_without_embedding(self, **kwargs): - self.last_missing_kwargs = kwargs - return {1: "alpha", 2: "beta"} - - def embed_texts(self, texts): - assert tuple(texts) == ("alpha", "beta") - return np.zeros((2, 3), dtype=np.float32) - - def add_to_db(self, **kwargs): - self.last_add_kwargs = kwargs - - emb_interface = FakeEmbeddingInterface() - - class FakeKG: - compute_missing_embeddings = True - - def session_factory(self): - return contextlib.nullcontext(Mock(spec=Session)) - - # get_neareast_concepts is called once; return a proper tuple-of-dicts. - def fake_nearest(*args, **kwargs): - fake_nearest.calls += 1 - return ({1: 0.9, 2: 0.8},) - - fake_nearest.calls = 0 - - monkeypatch.setattr(emb_ext, "HAS_OMOP_EMB", True) - monkeypatch.setattr(emb_ext, "get_embedding_reader_interface", lambda kg: emb_interface) - monkeypatch.setattr(emb_ext, "get_embedding_writer_interface", lambda kg: emb_interface) - monkeypatch.setattr(emb_ext, "get_neareast_concepts", fake_nearest) - - result = emb_ext.semantic_similarity( - kg=cast(KnowledgeGraph, FakeKG()), - standard_concepts=[ - StandardConcept( - concept_id=1, - concept_name="alpha", - separation=0, - original_id=1, - original_name="alpha", - matched_label="alpha", - match_kind=LabelMatchKind.EXACT, - synonym=False - ), - StandardConcept( - concept_id=2, - concept_name="beta", - separation=0, - original_id=2, - original_name="beta", - matched_label="beta", - match_kind=LabelMatchKind.EXACT, - synonym=False - ), - StandardConcept( - concept_id=3, - concept_name="gamma", - separation=0, - original_id=3, - original_name="gamma", - matched_label="gamma", - match_kind=LabelMatchKind.EXACT, - synonym=False - ), - ], - text_embedding=np.zeros((1, 3), dtype=np.float32), - text_embedding_model="test-model", - metric_type=cast(emb_ext.EmbeddingMetricType, "cosine"), - index_type=cast(emb_ext.EmbeddingIndexType, "flat"), - ) - - assert result is not None - assert fake_nearest.calls == 1 - assert emb_interface.last_missing_kwargs is not None - assert emb_interface.last_missing_kwargs["index_type"] == "flat" - assert emb_interface.last_add_kwargs is not None - assert emb_interface.last_add_kwargs["concept_ids"] == (1, 2) - - # ── compute_missing_embeddings flag tests (using real KG + dummy CDM DB) ── def _make_standard_concept(concept_id: int, name: str) -> StandardConcept: @@ -188,7 +57,7 @@ def _make_standard_concept(concept_id: int, name: str) -> StandardConcept: separation=0, original_id=concept_id, original_name=name, - matched_label=name, + matched_concept_label=name, match_kind=LabelMatchKind.EXACT, synonym=False, ) @@ -208,6 +77,7 @@ def test_fallback_flag_true_logs_attempt_when_concepts_missing( """ mock_cdm_kg._emb_config = KnowledgeGraphEmbeddingConfiguration( compute_missing_embeddings=True, + metric_type=cast(emb_ext.EmbeddingMetricType, "cosine"), ) fake_reader = Mock() @@ -226,10 +96,7 @@ def test_fallback_flag_true_logs_attempt_when_concepts_missing( emb_ext.semantic_similarity( kg=mock_cdm_kg, standard_concepts=[_make_standard_concept(196653, "Malignant tumor of kidney")], - text_embedding=np.zeros((1, 3), dtype=np.float32), - text_embedding_model="test-model", - metric_type=cast(emb_ext.EmbeddingMetricType, "cosine"), - index_type=cast(emb_ext.EmbeddingIndexType, "flat"), + query_embedding=np.zeros((1, 3), dtype=np.float32), ) assert "Computing missing embeddings on-the-fly" in caplog.text @@ -248,6 +115,7 @@ def test_fallback_flag_false_logs_disabled_when_concepts_missing( """ mock_cdm_kg._emb_config = KnowledgeGraphEmbeddingConfiguration( compute_missing_embeddings=False, + metric_type=cast(emb_ext.EmbeddingMetricType, "cosine"), ) fake_reader = Mock() @@ -263,10 +131,7 @@ def test_fallback_flag_false_logs_disabled_when_concepts_missing( emb_ext.semantic_similarity( kg=mock_cdm_kg, standard_concepts=[_make_standard_concept(196653, "Malignant tumor of kidney")], - text_embedding=np.zeros((1, 3), dtype=np.float32), - text_embedding_model="test-model", - metric_type=cast(emb_ext.EmbeddingMetricType, "cosine"), - index_type=cast(emb_ext.EmbeddingIndexType, "flat"), + query_embedding=np.zeros((1, 3), dtype=np.float32), ) assert "compute_missing_embeddings is disabled" in caplog.text diff --git a/tests/test_fulltext_optional.py b/tests/test_fulltext_optional.py index ba6dbb9..e70b730 100644 --- a/tests/test_fulltext_optional.py +++ b/tests/test_fulltext_optional.py @@ -1,5 +1,6 @@ import pytest +from sqlalchemy import Engine from omop_alchemy.cdm.handlers.fulltext import ( CONCEPT_NAME_TSVECTOR_COLUMN, CONCEPT_SYNONYM_NAME_TSVECTOR_COLUMN, @@ -14,7 +15,7 @@ @pytest.mark.parametrize("synonym", [False, True]) -def test_fulltext_query_requires_registered_tsvector_columns(synonym: bool): +def test_fulltext_query_requires_registered_tsvector_columns(synonym: bool, mock_cdm_engine: Engine): """Full-text queries fail cleanly when optional tsvector metadata is absent.""" had_name_column = CONCEPT_NAME_TSVECTOR_COLUMN in Concept.__table__.c had_synonym_column = CONCEPT_SYNONYM_NAME_TSVECTOR_COLUMN in Concept_Synonym.__table__.c @@ -22,7 +23,7 @@ def test_fulltext_query_requires_registered_tsvector_columns(synonym: bool): unregister_optional_fulltext_columns() try: with pytest.raises(FullTextError): - q_concept_name_fulltext("kidney cancer", synonym=synonym) + q_concept_name_fulltext("kidney cancer", synonym=synonym, engine=mock_cdm_engine) finally: if had_name_column or had_synonym_column: register_optional_fulltext_columns() \ No newline at end of file diff --git a/tests/test_grounding.py b/tests/test_grounding.py index 67efce0..53c455f 100644 --- a/tests/test_grounding.py +++ b/tests/test_grounding.py @@ -30,7 +30,7 @@ def _constraints() -> GroundingConstraints: @pytest.mark.parametrize( - "input_text,expected_concept_id", + "query,expected_concept_id", [ pytest.param("Hodgkin's disease (clinical)", 4038835, id="exact-hodgkin"), pytest.param("Malignant neoplasm of ovary", 4181351, id="exact-ovary"), @@ -40,7 +40,7 @@ def _constraints() -> GroundingConstraints: ) def test_grounding_resolves_expected_standard_concepts( mock_cdm_kg: KnowledgeGraph, - input_text: str, + query: str, expected_concept_id: int, ) -> None: pipeline = ResolverPipeline( @@ -55,16 +55,13 @@ def test_grounding_resolves_expected_standard_concepts( ranked = ground_term( resolver_pipeline=pipeline, kg=mock_cdm_kg, - text=input_text, - text_embedding=None, - text_embedding_model=None, + query=query, + query_embedding=None, constraints=_constraints(), max_candidates=1, - metric_type=None, - index_type=None, ) - assert ranked, f"Expected at least one grounding for: {input_text}" + assert ranked, f"Expected at least one grounding for: {query}" assert ranked[0].concept_id == expected_concept_id @@ -76,13 +73,10 @@ def test_grounding_maps_non_standard_candidate_via_relationships( ranked = ground_term( resolver_pipeline=pipeline, kg=mock_cdm_kg, - text="Kidney carcinoma term", - text_embedding=None, - text_embedding_model=None, + query="Kidney carcinoma term", + query_embedding=None, constraints=_constraints(), max_candidates=1, - metric_type=None, - index_type=None, ) assert ranked, "Expected non-standard concept to map to a valid standard concept" @@ -97,13 +91,10 @@ def test_grounding_rejects_concepts_outside_anchored_hierarchy( ranked = ground_term( resolver_pipeline=pipeline, kg=mock_cdm_kg, - text="Meta concept", - text_embedding=None, - text_embedding_model=None, + query="Meta concept", + query_embedding=None, constraints=_constraints(), max_candidates=1, - metric_type=None, - index_type=None, ) assert ranked == [] diff --git a/tests/test_resolvers_mock.py b/tests/test_resolvers_mock.py index 5485ab7..0d14006 100644 --- a/tests/test_resolvers_mock.py +++ b/tests/test_resolvers_mock.py @@ -24,17 +24,18 @@ class _KG: def __init__(self, case: _Case) -> None: self.case = case - def concept_lookup(self, label: str, match_kind: LabelMatchKind, synonym: bool = False, search_constraint=None, sort: bool = False): + def concept_lookup(self, query_term: str, match_kind: LabelMatchKind, synonym: bool = False, search_constraint=None, sort: bool = False): key = _key(match_kind, synonym) concept_ids = self.case.hits.get(key, []) return tuple( LabelMatch( - input_label=label, - matched_label=f"concept-{cid}", - concept_id=cid, + input_query=query_term, + matched_concept_label=f"concept-{cid}", + matched_concept_id=cid, match_kind=match_kind, is_standard=True, is_active=True, + synonym=synonym, ) for cid in concept_ids )