From dea1addb1ca45e526dcf5310f49902e2c15251f8 Mon Sep 17 00:00:00 2001 From: georgie Date: Fri, 15 May 2026 08:26:03 +1000 Subject: [PATCH 01/21] explicit backend api cleanup --- .gitignore | 1 + CHANGELOG.md | 7 +- README.md | 68 +++-- TODO.txt | 2 + docs/index.md | 19 +- docs/loaders/context.md | 1 + docs/loaders/helpers.md | 22 +- docs/loaders/index.md | 23 +- docs/loaders/loaders.md | 21 +- docs/tables/loadable_table.md | 9 +- docs/tables/mat_view.md | 6 +- pyproject.toml | 21 +- src/orm_loader/backends/__init__.py | 12 + src/orm_loader/backends/base.py | 245 +++++++++++++++++ src/orm_loader/backends/postgres.py | 212 +++++++++++++++ src/orm_loader/backends/resolve.py | 42 +++ src/orm_loader/backends/sqlite.py | 248 ++++++++++++++++++ src/orm_loader/helpers/__init__.py | 11 +- src/orm_loader/helpers/bootstrap.py | 6 +- src/orm_loader/helpers/bulk.py | 120 +++------ src/orm_loader/helpers/discovery.py | 11 +- src/orm_loader/helpers/logging.py | 24 +- src/orm_loader/helpers/sqlite.py | 76 ++++-- src/orm_loader/loaders/data_classes.py | 28 -- src/orm_loader/loaders/loading_helpers.py | 13 +- .../mappers/materialised_view_mixin.py | 39 +-- src/orm_loader/tables/loadable_table.py | 219 ++++------------ src/orm_loader/tables/orm_table.py | 24 +- src/orm_loader/tables/serialisable_table.py | 18 +- src/orm_loader/tables/typing.py | 24 +- tests/backends/test_base_backend.py | 237 +++++++++++++++++ tests/backends/test_postgres_backend.py | 139 ++++++++++ tests/backends/test_sqlite_backend.py | 205 +++++++++++++++ tests/conftest.py | 3 +- tests/models.py | 4 +- uv.lock | 66 ++++- 36 files changed, 1742 insertions(+), 484 deletions(-) create mode 100644 TODO.txt create mode 100644 src/orm_loader/backends/__init__.py create mode 100644 src/orm_loader/backends/base.py create mode 100644 src/orm_loader/backends/postgres.py create mode 100644 src/orm_loader/backends/resolve.py create mode 100644 src/orm_loader/backends/sqlite.py create mode 100644 tests/backends/test_base_backend.py create mode 100644 tests/backends/test_postgres_backend.py create mode 100644 tests/backends/test_sqlite_backend.py diff --git a/.gitignore b/.gitignore index b6dc481..f898ecf 100644 --- a/.gitignore +++ b/.gitignore @@ -211,3 +211,4 @@ OMOP_CDM*.csv *.db .vscode/ .DS_Store +_temp/ \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 9e912da..e973d48 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -108,4 +108,9 @@ - literally just removing stale sqlalchemy-utils dependency # 0.3.27 -- adding minimum versions for dependabot alerts (dev deps only) \ No newline at end of file +- adding minimum versions for dependabot alerts (dev deps only) + +# 0.4.0 +- update to handle psycopg (as opposed to psycopg2) cleanly +- overall api cleanup with the goal of being more explicit about selection of specific db backends +- general typing cleanup \ No newline at end of file diff --git a/README.md b/README.md index 63c4644..c0b1c98 100644 --- a/README.md +++ b/README.md @@ -4,28 +4,27 @@ https://github.com/AustralianCancerDataNetwork/orm-loader/actions/workflows/tests.yml ) -A lightweight, reusable foundation for building and validating SQLAlchemy-based clinical (and non-clinical) data models. +A lightweight foundation for building and validating SQLAlchemy-based data models. -This library provides general-purpose ORM infrastructure that sits below any specific data model (OMOP, PCORnet, custom CDMs, etc.), focusing on: +`orm-loader` sits below any particular schema or CDM. It gives you a small set of reusable pieces for defining tables, loading files through staging tables, and checking models against external specifications. It stays out of domain logic on purpose. -* declarative base configuration -* bulk ingestion patterns -* file-based validation & loading -* table introspection -* model-agnostic validation scaffolding -* safe, database-portable operational helpers +The library focuses on: -It intentionally contains no domain logic and no assumptions about a specific schema. +* ORM table mixins and introspection +* staged file loading +* loader and validation infrastructure +* operational helpers that work across supported backends +At the moment, the built-in backends are SQLite and PostgreSQL. -### What this library provides: -This library provides a small set of composable building blocks for defining, loading, inspecting, and validating SQLAlchemy-based data models. -All components are model-agnostic and can be selectively combined in downstream libraries. +### What this library provides -1. A minimal, opinionated ORM table base +The package is deliberately small. Most downstream projects only need a couple of these pieces. -ORMTableBase provides structural introspection utilities for SQLAlchemy-mapped tables, without imposing any domain semantics. +1. A minimal ORM table base + +`ORMTableBase` provides structural utilities for mapped tables without pulling domain rules into the base layer. It supports: * mapper access and inspection @@ -41,17 +40,19 @@ class MyTable(ORMTableBase, Base): __tablename__ = "my_table" ``` -This base is intended to be inherited by all ORM tables, either directly or via higher-level mixins. +You can inherit from it directly or pick it up through one of the higher-level mixins. 2. CSV-based ingestion mixins -CSVLoadableTableInterface adds opt-in CSV loading support for ORM tables using pandas, with a focus on correctness and scalability. +`CSVLoadableTableInterface` adds staged file loading to ORM tables. It can use pandas or PyArrow loaders, and on PostgreSQL it can use a fast `COPY` path when the input is clean enough. Features include: +* staging table creation and cleanup * chunked loading for large files -* optional per-table normalisation logic -* optional deduplication against existing database rows -* safe bulk inserts using SQLAlchemy sessions +* optional casting and deduplication before insert +* backend-specific merge behaviour +* PostgreSQL fast-path loading with ORM fallback +* backend-aware index handling during merge ```python class MyTable(CSVLoadableTableInterface, ORMTableBase, Base): @@ -59,15 +60,11 @@ class MyTable(CSVLoadableTableInterface, ORMTableBase, Base): ``` -Downstream models may override: -* normalise_dataframe(...) -* dedupe_dataframe(...) -* csv_columns() -to implement table-specific ingestion policies. +The main extension points here are loader choice, column mapping, and the normal SQLAlchemy model definitions themselves. Most downstream projects do not need to override much beyond `csv_columns()` and the model schema. 3. Structured serialisation and hashing -SerialisableTableInterface adds lightweight, explicit serialisation helpers for ORM rows. +`SerialisableTableInterface` adds lightweight serialisation helpers for ORM rows. It supports: * conversion to dictionaries @@ -92,7 +89,7 @@ This is useful for: 4. Model registry and validation scaffolding -The library includes model-agnostic validation infrastructure, designed to compare ORM models against external specifications. +The library includes validation infrastructure for comparing ORM models against external specifications. This includes: * a model registry @@ -118,7 +115,8 @@ Validation output is available as: * exit codes suitable for pipelines 5. Database bootstrap helpers -The library provides lightweight helpers for schema creation and bootstrapping, without imposing a migration strategy. + +The library provides lightweight helpers for schema creation and bootstrapping. It does not try to replace migrations. ```python from orm_loader.metadata import Base @@ -127,24 +125,20 @@ from orm_loader.bootstrap import bootstrap bootstrap(engine, create=True) ``` -6. Safe bulk-loading utilities +6. Bulk-loading helpers -A reusable context manager simplifies trusted bulk ingestion workflows: -* temporarily disables foreign key checks where supported -* suppresses autoflush for performance -* ensures reliable rollback on failure +There are a few lower-level helpers for trusted bulk workflows, including backend-aware foreign key management and SQLite connection setup for heavy local loads. ## Summary -This library intentionally focuses on infrastructure, not semantics. +This library is meant to be the boring layer underneath downstream models: -It provides: * reusable ORM mixins -* safe ingestion patterns +* staged ingestion patterns * validation scaffolding -* database-portable utilities +* operational helpers -while leaving domain rules, business logic, and schema semantics to downstream libraries. +Domain rules, business logic, and schema semantics stay in the downstream project. This makes it suitable as a shared foundation for: * clinical data models diff --git a/TODO.txt b/TODO.txt new file mode 100644 index 0000000..fe7d4f0 --- /dev/null +++ b/TODO.txt @@ -0,0 +1,2 @@ +[] consider opt-in malformed text repair (as opposed to existing normalisation) - e.g. load_csv(..., text_repair: str | None = None) +- consider ftfy.fix_encoding() \ No newline at end of file diff --git a/docs/index.md b/docs/index.md index c8ff6bb..015be3c 100644 --- a/docs/index.md +++ b/docs/index.md @@ -3,7 +3,7 @@ A lightweight, reusable foundation for building and validating SQLAlchemy-based data models. -`orm-loader` provides **infrastructure, not semantics**. +`orm-loader` provides infrastructure for SQLAlchemy-based data models. It is the shared plumbing layer, not the place where model-specific rules live. It focuses on: @@ -11,17 +11,16 @@ It focuses on: - safe bulk ingestion patterns - file-based loading via staging tables - model-agnostic validation scaffolding -- database-portable operational helpers +- operational helpers for supported backends -No domain logic is included. -No schema assumptions are enforced. +It currently ships with backend implementations for SQLite and PostgreSQL. --- ## Core Concepts - **Tables are structural** — semantics live downstream -- **Mixins define capabilities**, not behaviour contracts +- **Mixins define capabilities** - **Protocols decouple infrastructure from implementations** - **Ingestion is explicit and staged** @@ -37,13 +36,7 @@ No schema assumptions are enforced. # Design Philosophy -`orm-loader` is intentionally conservative. - -It provides: - -- *mechanisms*, not policies -- *capabilities*, not workflows -- *structure*, not semantics +`orm-loader` is intentionally conservative. It gives downstream libraries the machinery to load, inspect, and validate data without deciding what the data means. The library is designed to sit **below**: @@ -65,6 +58,7 @@ and **above**: - No schema enforcement - No migrations - No concurrency guarantees +- No support yet for arbitrary database dialects --- @@ -81,4 +75,3 @@ This allows downstream libraries to: - replace base classes - mock implementations - incrementally adopt features - diff --git a/docs/loaders/context.md b/docs/loaders/context.md index 2e4528d..29418fd 100644 --- a/docs/loaders/context.md +++ b/docs/loaders/context.md @@ -25,6 +25,7 @@ on globals or implicit configuration. | `chunksize` | Optional chunk size | | `normalise` | Whether to cast values to ORM types | | `dedupe` | Whether to deduplicate incoming data | +| `quote_mode` | CSV quoting mode for PostgreSQL fast-path loading | ::: orm_loader.loaders.data_classes.LoaderContext diff --git a/docs/loaders/helpers.md b/docs/loaders/helpers.md index f299a95..8842dd3 100644 --- a/docs/loaders/helpers.md +++ b/docs/loaders/helpers.md @@ -1,8 +1,6 @@ # Loader Helper Utilities -This page documents low-level helper functions used by loaders. - -These utilities are stateless and intentionally conservative. +This page covers the low-level functions that support the loader implementations. --- @@ -37,17 +35,17 @@ Used by `ParquetLoader` for internal deduplication. --- -## Conservative CSV parsing +## Batch-oriented CSV parsing ### `conservative_load_parquet(...)` -Reads CSV files using PyArrow with: +Despite the name, this helper reads delimited text with PyArrow and yields batches: - strict column inclusion - malformed row skipping - chunked batch iteration -This is used when loading CSVs via the Parquet pipeline. +This is used by the PyArrow-based loader path. --- @@ -55,18 +53,18 @@ This is used when loading CSVs via the Parquet pipeline. ### `quick_load_pg(...)` -Loads CSV files into PostgreSQL staging tables using `COPY`. +Loads CSV files into a PostgreSQL staging table using `COPY`. ### Characteristics -- Extremely fast -- Bypasses ORM -- Sensitive to data quality issues +- Fast +- Bypasses ORM row construction +- Works best on clean input ### Failure handling - Errors trigger rollback -- Loader falls back to ORM-based loading -- No partial silent loads +- `CSVLoadableTableInterface` falls back to ORM-based loading +- Failures are noisy on purpose This helper is only used when explicitly supported by the database. diff --git a/docs/loaders/index.md b/docs/loaders/index.md index b38f267..3d41b0f 100644 --- a/docs/loaders/index.md +++ b/docs/loaders/index.md @@ -1,14 +1,13 @@ # Loaders -The `orm_loader.loaders` module provides **conservative, schema-aware file -ingestion infrastructure** for loading external data into ORM-backed -staging tables. +The `orm_loader.loaders` module provides conservative, schema-aware file +loading into ORM-backed staging tables. This subsystem is designed to handle: - untrusted or messy source files - large datasets requiring chunked processing -- incremental and repeatable loads +- repeatable staged loads - dialect-specific optimisations (e.g. PostgreSQL COPY) - explicit, inspectable failure modes @@ -23,7 +22,7 @@ they do not embed domain rules or business semantics. [`LoaderContext`](context.md) -A `LoaderContext` object carries all state required to load a single file: +A `LoaderContext` object carries the state required to load one file: - target ORM table - database session @@ -44,8 +43,7 @@ All loaders implement a common interface: - `orm_file_load(ctx)` — orchestrates file ingestion - `dedupe(data, ctx)` — defines deduplication semantics -Concrete implementations differ only in **how data is read and processed**, -not in how it is staged. +Concrete implementations mainly differ in how they read and transform incoming data. --- @@ -54,11 +52,10 @@ not in how it is staged. Loaders always write to **staging tables**, never directly to production tables. -This allows: +This gives you: - safe rollback - repeatable merges -- database-level deduplication - bulk loading optimisations Final merge semantics are handled by the table mixins, not by loaders. @@ -69,8 +66,8 @@ Final merge semantics are handled by the table mixins, not by loaders. | Loader | Use case | |------|----------| -| `PandasLoader` | Flexible, debuggable CSV ingestion | -| `ParquetLoader` | High-volume, columnar ingestion | +| `PandasLoader` | Flexible CSV and TSV ingestion | +| `ParquetLoader` | Columnar or batch-oriented ingestion | Both loaders share the same lifecycle and guarantees. @@ -81,11 +78,11 @@ Both loaders share the same lifecycle and guarantees. 1. Detect file format and encoding 2. Read data in chunks or batches 3. Optionally normalise to ORM column types -4. Optionally deduplicate (internal and/or database-level) +4. Optionally deduplicate within the incoming data 5. Insert into staging table 6. Return row count -No implicit commits or merges occur at this layer. +Final merge behaviour belongs to the table mixins and backend layer, not to the loader itself. --- diff --git a/docs/loaders/loaders.md b/docs/loaders/loaders.md index 897bdb2..1fa8728 100644 --- a/docs/loaders/loaders.md +++ b/docs/loaders/loaders.md @@ -3,8 +3,7 @@ This page documents the concrete loader implementations provided by `orm_loader`. -All loaders implement the same interface and differ only in -how data is read and processed. +All loaders implement the same interface. The difference is in how they read data and how much work they do before rows reach the staging table. --- @@ -24,7 +23,7 @@ All loaders: - load into staging tables only - respect `LoaderContext` flags - return row counts -- avoid implicit commits +- leave final merge behaviour to the table layer --- @@ -34,7 +33,7 @@ All loaders: ### Characteristics -- Supports CSV and TSV inputs +- Works well with CSV and TSV inputs - Easy to debug and inspect - Supports chunked loading - Flexible transformation pipeline @@ -67,7 +66,6 @@ All loaders: - More complex pipeline - Less flexible row-wise transformations -- DB-level deduplication not yet implemented ### Best suited for @@ -79,16 +77,7 @@ All loaders: ## Deduplication behaviour -Deduplication occurs in two phases: - -1. **Internal deduplication** - Removes duplicate primary key rows within the incoming data. - -2. **Database-level deduplication (optional)** - Removes rows that already exist in the database. - -Database-level deduplication is currently implemented for pandas-based -loads. +Deduplication here means deduplicating within the incoming data before it is inserted into staging. The merge step is what decides what happens when incoming rows overlap with existing target rows. --- @@ -100,4 +89,4 @@ When enabled, loaders: - drop rows violating required constraints - log casting failures with examples -No schema changes are performed. +No schema changes are performed at the loader layer. diff --git a/docs/tables/loadable_table.md b/docs/tables/loadable_table.md index 90302e1..51ebfe4 100644 --- a/docs/tables/loadable_table.md +++ b/docs/tables/loadable_table.md @@ -3,10 +3,11 @@ Infrastructure for staged, file-based ingestion into ORM tables. Supports: -- CSV-based ingestion -- optional fast-path database COPY -- dialect-aware merge strategies -- Parquet loading hooks +- staged file loading into backend-specific staging tables +- PostgreSQL fast-path `COPY` with ORM fallback +- backend-aware merge strategies +- pandas and PyArrow-based loader paths +- index handling during merge --- diff --git a/docs/tables/mat_view.md b/docs/tables/mat_view.md index 5bbd429..2a1a807 100644 --- a/docs/tables/mat_view.md +++ b/docs/tables/mat_view.md @@ -1,6 +1,6 @@ # Materialised Views -This module provides a SQLAlchemy-native pattern for defining, creating, refreshing, and orchestrating materialized views using normal `Select` constructs, with explicit dependency management and deterministic refresh order. +This module provides a SQLAlchemy-native way to define, create, refresh, and order materialized views from ordinary `Select` constructs. It is designed for: @@ -9,7 +9,7 @@ It is designed for: * large fact tables with repeated joins or aggregates * schema-level orchestration (migrations, setup, Airflow, admin tasks) -The implementation is PostgreSQL-oriented (due to materialized view support), but remains cleanly isolated from ORM persistence logic. +The implementation is PostgreSQL-oriented. The mixin resolves a backend from the supplied bind, and the built-in PostgreSQL backend is currently the only one that supports materialized views. ## Overview @@ -21,7 +21,7 @@ The materialized view system consists of four main parts: * backing `Select` * optional dependencies 3. Dependency resolution: A topological sort over declared dependencies to determine refresh order. -4. Refresh orchestration: Helpers to refresh one or many materialized views safely and predictably. +4. Refresh orchestration: Helpers to refresh one or many materialized views in a predictable order. ### Defining the Materialised View diff --git a/pyproject.toml b/pyproject.toml index 398dc4c..030cc52 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "orm-loader" -version = "0.3.27" +version = "0.4.0" description = "Generic base classes to handle ORM functionality for multiple downstream datamodels" readme = "README.md" authors = [ @@ -14,6 +14,18 @@ dependencies = [ "sqlalchemy>=2.0.45", ] + +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "Topic :: Scientific/Engineering :: Medical Science Apps.", + "Topic :: Database :: Database Engines/Servers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.12", +] + [project.urls] Homepage = "https://AustralianCancerDataNetwork.github.io/orm-loader" Documentation = "https://AustralianCancerDataNetwork.github.io/orm-loader" @@ -25,7 +37,11 @@ requires = ["uv_build>=0.9.2,<0.10.0"] build-backend = "uv_build" [project.optional-dependencies] +postgres = [ + "psycopg[binary]>=3.2", +] dev = [ + "pytest>=9.0.3", "mypy>=1.19.1", "pytest>=9.0.3", "ruff>=0.14.11", @@ -54,3 +70,6 @@ python_files = ["test_*.py"] python_classes = ["Test*"] python_functions = ["test_*"] addopts = "-ra" + +[tool.pyright] +reportMissingTypeStubs = false \ No newline at end of file diff --git a/src/orm_loader/backends/__init__.py b/src/orm_loader/backends/__init__.py new file mode 100644 index 0000000..d12fe23 --- /dev/null +++ b/src/orm_loader/backends/__init__.py @@ -0,0 +1,12 @@ +from .postgres import PostgresBackend +from .resolve import resolve_backend +from .sqlite import SQLiteBackend +from .base import BackendCapabilities, DatabaseBackend + +__all__ = [ + "BackendCapabilities", + "DatabaseBackend", + "PostgresBackend", + "SQLiteBackend", + "resolve_backend", +] diff --git a/src/orm_loader/backends/base.py b/src/orm_loader/backends/base.py new file mode 100644 index 0000000..09207f4 --- /dev/null +++ b/src/orm_loader/backends/base.py @@ -0,0 +1,245 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from contextlib import AbstractContextManager, contextmanager, nullcontext +from dataclasses import dataclass +from typing import TYPE_CHECKING, Type, Any, Iterator + +import sqlalchemy as sa +import sqlalchemy.orm as so +from sqlalchemy.engine import Connection, Engine + +if TYPE_CHECKING: + from ..loaders.data_classes import LoaderContext + from ..tables.typing import CSVTableProtocol + + +@dataclass(frozen=True) +class BackendCapabilities: + """ + Capability flags exposed by a database backend. + + These defaults are intentionally conservative. Concrete backends should + opt into capabilities explicitly. + """ + + supports_fast_load: bool = False + supports_unlogged_staging: bool = False + supports_fk_toggle: bool = False + supports_materialized_views: bool = False + + +class DatabaseBackend(ABC): + """ + Abstract base class for database-specific loader behavior. + + This class defines the stable contract for future backend implementations + without changing existing loader orchestration yet. + """ + + @property + @abstractmethod + def name(self) -> str: + """Human-readable backend name.""" + + @property + @abstractmethod + def dialect_names(self) -> tuple[str, ...]: + """SQLAlchemy dialect names handled by this backend.""" + + @property + @abstractmethod + def capabilities(self) -> BackendCapabilities: + """Capability flags supported by this backend.""" + + def supports_dialect(self, dialect_name: str) -> bool: + """Return ``True`` when the backend handles the given dialect name.""" + return dialect_name in self.dialect_names + + @property + def default_index_strategy(self) -> str: + """Default index strategy used when callers request ``auto``.""" + return "drop_rebuild" + + def resolve_index_strategy(self, index_strategy: str) -> str: + """ + Resolve a caller-facing index strategy to a concrete backend choice. + """ + valid = {"auto", "drop_rebuild", "keep"} + if index_strategy not in valid: + raise ValueError( + f"Unknown index_strategy '{index_strategy}'. Expected one of: {sorted(valid)}" + ) + if index_strategy == "auto": + return self.default_index_strategy + return index_strategy + + def _require_capability(self, capability_name: str, feature_name: str) -> None: + """ + Raise a clear error when a backend capability is not supported. + """ + if not hasattr(self.capabilities, capability_name): + raise AttributeError( + f"Unknown backend capability {capability_name!r} on {type(self.capabilities).__name__}" + ) + if not getattr(self.capabilities, capability_name): + raise NotImplementedError( + f"Backend '{self.name}' does not support {feature_name}" + ) + + @contextmanager + def _as_connection( + self, + bind: Engine | Connection, + ) -> Iterator[Connection]: + if isinstance(bind, Engine): + with bind.connect() as conn: + yield conn + else: + yield bind + + def _insertable_column_names( + self, + table_cls: Type["CSVTableProtocol"], + ) -> list[str]: + """ + Return column names safe to include in generic insert statements. + + Computed columns are excluded because backend loaders and merge helpers + should not attempt to write to them directly. + """ + return [c.name for c in table_cls.__table__.columns if c.computed is None] + + @abstractmethod + def create_staging_table( + self, + table_cls: Type["CSVTableProtocol"], + session: so.Session, + staging_name: str, + ) -> None: + """Create a staging table for the supplied ORM table class.""" + + @abstractmethod + def drop_staging_table( + self, + session: so.Session, + staging_name: str, + ) -> None: + """Drop a staging table if it exists.""" + + def load_staging_fast( + self, + loader_context: "LoaderContext", + staging_name: str, + ) -> int | None: + """ + Attempt a backend-native fast-path load. + + Return the inserted row count when handled, or ``None`` when the + backend has no fast-path loader for the given context. + """ + return None + + @abstractmethod + def disable_fk_check(self, session: so.Session) -> str | int: + """Disable FK checks and return the previous backend-specific state.""" + + @abstractmethod + def enable_fk_check(self, session: so.Session) -> str | int: + """Explicitly enable FK checks and return the previous backend-specific state.""" + + @abstractmethod + def restore_fk_check( + self, + session: so.Session, + previous_state: str | int, + ) -> None: + """Restore FK checks to a previously returned backend-specific state.""" + + @abstractmethod + def merge_replace( + self, + table_cls: Type["CSVTableProtocol"], + session: so.Session, + target_name: str, + staging_name: str, + pk_cols: list[str], + ) -> None: + """Merge staging rows by replacing matching target rows first.""" + + @abstractmethod + def merge_upsert( + self, + table_cls: Type["CSVTableProtocol"], + session: so.Session, + target_name: str, + staging_name: str, + pk_cols: list[str], + ) -> None: + """Merge staging rows using backend-specific upsert semantics.""" + + @abstractmethod + def merge_insert( + self, + table_cls: Type["CSVTableProtocol"], + session: so.Session, + target_name: str, + staging_name: str, + ) -> None: + """Insert all staging rows into the target table.""" + + def merge_context( + self, + table_cls: Type["CSVTableProtocol"], + session: so.Session, + ) -> AbstractContextManager[None]: + """Return a context manager for merge-time backend operations.""" + return nullcontext() + + @contextmanager + def bulk_load_context( + self, + session: so.Session, + *, + disable_fk: bool = True, + no_autoflush: bool = True, + ): + """ + Generic bulk-load context that defers FK semantics to the backend. + """ + previous_fk_state: str | int | None = None + try: + if disable_fk: + self._require_capability("supports_fk_toggle", "foreign key toggling") + previous_fk_state = self.disable_fk_check(session) + + if no_autoflush: + with session.no_autoflush: + yield + else: + yield + + except Exception: + session.rollback() + raise + + finally: + if previous_fk_state is not None: + self.restore_fk_check(session, previous_fk_state) + + @abstractmethod + def create_materialized_view( + self, + bind: "Engine | Connection", + name: str, + selectable: sa.sql.Select[Any], + ) -> None: + """Create a materialized view for the supplied selectable.""" + + @abstractmethod + def refresh_materialized_view( + self, + bind: "Engine | Connection", + name: str, + ) -> None: + """Refresh a materialized view.""" diff --git a/src/orm_loader/backends/postgres.py b/src/orm_loader/backends/postgres.py new file mode 100644 index 0000000..9dbdc4a --- /dev/null +++ b/src/orm_loader/backends/postgres.py @@ -0,0 +1,212 @@ +from __future__ import annotations + +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any +import sqlalchemy as sa +import sqlalchemy.orm as so + +from .base import BackendCapabilities, DatabaseBackend +from ..loaders.loading_helpers import quick_load_pg + +if TYPE_CHECKING: + from sqlalchemy.engine import Connection, Engine + + from ..loaders.data_classes import LoaderContext + from ..tables.typing import CSVTableProtocol + + +class PostgresBackend(DatabaseBackend): + @property + def name(self) -> str: + return "postgres" + + @property + def dialect_names(self) -> tuple[str, ...]: + return ("postgresql",) + + @property + def capabilities(self) -> BackendCapabilities: + return BackendCapabilities( + supports_fast_load=True, + supports_unlogged_staging=True, + supports_fk_toggle=True, + supports_materialized_views=True, + ) + + def create_staging_table( + self, + table_cls: type["CSVTableProtocol"], + session: so.Session, + staging_name: str, + ) -> None: + table = table_cls.__table__ + session.execute(sa.text(f'DROP TABLE IF EXISTS "{staging_name}";')) + session.execute( + sa.text( + f''' + CREATE UNLOGGED TABLE "{staging_name}" + (LIKE "{table.name}" INCLUDING DEFAULTS INCLUDING CONSTRAINTS); + ''' + ) + ) + + computed_cols = [c.name for c in table.columns if c.computed is not None] + for col in computed_cols: + session.execute(sa.text(f'ALTER TABLE "{staging_name}" DROP COLUMN "{col}";')) + + session.commit() + + def drop_staging_table( + self, + session: so.Session, + staging_name: str, + ) -> None: + session.execute(sa.text(f'DROP TABLE IF EXISTS "{staging_name}"')) + + def load_staging_fast( + self, + loader_context: "LoaderContext", + staging_name: str, + ) -> int | None: + return quick_load_pg( + path=loader_context.path, + session=loader_context.session, + tablename=staging_name, + quote_mode=loader_context.quote_mode, + ) + + def disable_fk_check(self, session: so.Session) -> str | int: + previous_state = session.execute(sa.text("SHOW session_replication_role")).scalar() + session.execute(sa.text("SET session_replication_role = 'replica'")) + assert isinstance(previous_state, str), "Expected PostgreSQL FK state to be a string" + return previous_state + + def enable_fk_check(self, session: so.Session) -> str | int: + previous_state = session.execute(sa.text("SHOW session_replication_role")).scalar() + session.execute(sa.text("SET session_replication_role = 'origin'")) + assert isinstance(previous_state, str), "Expected PostgreSQL FK state to be a string" + return previous_state + + def restore_fk_check( + self, + session: so.Session, + previous_state: str | int, + ) -> None: + session.execute(sa.text(f"SET session_replication_role = '{previous_state}'")) + + def merge_replace( + self, + table_cls: type["CSVTableProtocol"], + session: so.Session, + target_name: str, + staging_name: str, + pk_cols: list[str], + ) -> None: + pk_join = " AND ".join( + f't."{c}" = s."{c}"' for c in pk_cols + ) + session.execute( + sa.text( + f""" + DELETE FROM "{target_name}" t + USING "{staging_name}" s + WHERE {pk_join}; + """ + ) + ) + + def merge_upsert( + self, + table_cls: type["CSVTableProtocol"], + session: so.Session, + target_name: str, + staging_name: str, + pk_cols: list[str], + ) -> None: + insertable_cols = self._insertable_column_names(table_cls) + cols_str = ", ".join(f'"{c}"' for c in insertable_cols) + conflict_cols = ", ".join(f'"{c}"' for c in pk_cols) + session.execute( + sa.text( + f""" + INSERT INTO "{target_name}" ({cols_str}) + SELECT {cols_str} FROM "{staging_name}" + ON CONFLICT ({conflict_cols}) DO NOTHING; + """ + ) + ) + + def merge_insert( + self, + table_cls: type["CSVTableProtocol"], + session: so.Session, + target_name: str, + staging_name: str, + ) -> None: + insertable_cols = self._insertable_column_names(table_cls) + cols_str = ", ".join(f'"{c}"' for c in insertable_cols) + session.execute( + sa.text( + f""" + INSERT INTO "{target_name}" ({cols_str}) + SELECT {cols_str} FROM "{staging_name}"; + """ + ) + ) + + def merge_context( + self, + table_cls: type["CSVTableProtocol"], + session: so.Session, + ): + return self.bulk_load_context(session, disable_fk=True, no_autoflush=False) + + + + def create_materialized_view( + self, + bind: Engine | Connection, + name: str, + selectable: sa.sql.Select[Any], + ) -> None: + from ..mappers.materialised_view_mixin import CreateMaterializedView + + with self._as_connection(bind) as conn: + conn.execute(CreateMaterializedView(name, selectable)) + + def refresh_materialized_view( + self, + bind: Engine | Connection, + name: str, + ) -> None: + with self._as_connection(bind) as conn: + safe_name = name + dialect = getattr(conn, "dialect", None) + if dialect is not None: + safe_name = dialect.identifier_preparer.quote(name) + conn.execute( + sa.text(f"REFRESH MATERIALIZED VIEW {safe_name};") + ) + + @contextmanager + def engine_with_replica_role(self, engine: "Engine"): + @sa.event.listens_for(engine, "connect") # type: ignore[arg-type] + def _set_replica_role( + dbapi_conn: sa.engine.interfaces.DBAPIConnection, + _, + ) -> None: + cur = dbapi_conn.cursor() + cur.execute("SET session_replication_role = replica") + cur.close() + + try: + yield engine + finally: + with engine.connect() as conn: + conn = conn.execution_options(isolation_level="AUTOCOMMIT") + conn.execute(sa.text("SET session_replication_role = DEFAULT")) + role = conn.execute( + sa.text("SHOW session_replication_role") + ).scalar() + if role != "origin": + raise RuntimeError("Failed to restore session_replication_role") diff --git a/src/orm_loader/backends/resolve.py b/src/orm_loader/backends/resolve.py new file mode 100644 index 0000000..9333c0e --- /dev/null +++ b/src/orm_loader/backends/resolve.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import sqlalchemy as sa +import sqlalchemy.orm as so + +from .base import DatabaseBackend +from .postgres import PostgresBackend +from .sqlite import SQLiteBackend + +if TYPE_CHECKING: + from sqlalchemy.engine import Connection, Engine + + +_BACKEND_TYPES: tuple[type[DatabaseBackend], ...] = ( + PostgresBackend, + SQLiteBackend, +) + + +def _dialect_name(bindable: so.Session | "Engine" | "Connection") -> str: + if isinstance(bindable, so.Session): + bind = bindable.get_bind() + return bind.dialect.name + + if hasattr(bindable, "dialect"): + return bindable.dialect.name + + raise TypeError(f"Unsupported bindable type: {type(bindable)!r}") + + +def resolve_backend(bindable: so.Session | "Engine" | "Connection") -> DatabaseBackend: + """ + Resolve a concrete backend from a SQLAlchemy session, engine, or connection. + """ + dialect_name = _dialect_name(bindable) + for backend_type in _BACKEND_TYPES: + backend = backend_type() + if backend.supports_dialect(dialect_name): + return backend + raise NotImplementedError(f"No backend registered for dialect '{dialect_name}'") diff --git a/src/orm_loader/backends/sqlite.py b/src/orm_loader/backends/sqlite.py new file mode 100644 index 0000000..54dcd45 --- /dev/null +++ b/src/orm_loader/backends/sqlite.py @@ -0,0 +1,248 @@ +from __future__ import annotations + +import logging +import sqlite3 +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import sqlalchemy as sa +import sqlalchemy.orm as so +from sqlalchemy import event, text +from sqlalchemy.exc import IntegrityError + +from .base import BackendCapabilities, DatabaseBackend + +if TYPE_CHECKING: + from sqlalchemy.engine import Connection, Engine + + from ..tables.typing import CSVTableProtocol + + +logger = logging.getLogger(__name__) + + +class SQLiteBackend(DatabaseBackend): + def __init__( + self, + *, + busy_timeout_ms: int = 60000, + journal_mode: str = "WAL", + defer_foreign_keys: bool = True, + ) -> None: + self.busy_timeout_ms = busy_timeout_ms + self.journal_mode = journal_mode + self.defer_foreign_keys = defer_foreign_keys + + @property + def name(self) -> str: + return "sqlite" + + @property + def dialect_names(self) -> tuple[str, ...]: + return ("sqlite",) + + @property + def capabilities(self) -> BackendCapabilities: + return BackendCapabilities( + supports_fast_load=False, + supports_unlogged_staging=False, + supports_fk_toggle=True, + supports_materialized_views=False, + ) + + @property + def default_index_strategy(self) -> str: + return "keep" + + def create_staging_table( + self, + table_cls: type["CSVTableProtocol"], + session: so.Session, + staging_name: str, + ) -> None: + session.execute(sa.text(f'DROP TABLE IF EXISTS "{staging_name}";')) + + metadata = sa.MetaData() + staging_columns = [ + sa.Column(col.name, col.type, nullable=True) + for col in table_cls.__table__.columns + ] + staging_table = sa.Table(staging_name, metadata, *staging_columns) + metadata.create_all(bind=session.connection(), tables=[staging_table]) + session.commit() + + def drop_staging_table( + self, + session: so.Session, + staging_name: str, + ) -> None: + session.execute(sa.text(f'DROP TABLE IF EXISTS "{staging_name}"')) + + def disable_fk_check(self, session: so.Session) -> str | int: + previous_state = session.execute(text("PRAGMA foreign_keys")).scalar() + session.execute(text("PRAGMA foreign_keys = OFF")) + assert isinstance(previous_state, int), "Expected SQLite FK state to be an int" + return previous_state + + def enable_fk_check(self, session: so.Session) -> str | int: + previous_state = session.execute(text("PRAGMA foreign_keys")).scalar() + session.execute(text("PRAGMA foreign_keys = ON")) + assert isinstance(previous_state, int), "Expected SQLite FK state to be an int" + return previous_state + + def restore_fk_check( + self, + session: so.Session, + previous_state: str | int, + ) -> None: + session.execute(text(f"PRAGMA foreign_keys = {previous_state}")) + + def merge_replace( + self, + table_cls: type["CSVTableProtocol"], + session: so.Session, + target_name: str, + staging_name: str, + pk_cols: list[str], + ) -> None: + if len(pk_cols) == 1: + pk = pk_cols[0] + session.execute( + sa.text( + f""" + DELETE FROM "{target_name}" + WHERE "{pk}" IN ( + SELECT "{pk}" FROM "{staging_name}" + ); + """ + ) + ) + return + + pk_match = " AND ".join( + f'"{target_name}"."{c}" = "{staging_name}"."{c}"' for c in pk_cols + ) + session.execute( + sa.text( + f""" + DELETE FROM "{target_name}" + WHERE EXISTS ( + SELECT 1 FROM "{staging_name}" + WHERE {pk_match} + ); + """ + ) + ) + + def merge_upsert( + self, + table_cls: type["CSVTableProtocol"], + session: so.Session, + target_name: str, + staging_name: str, + pk_cols: list[str], + ) -> None: + insertable_cols = self._insertable_column_names(table_cls) + cols_str = ", ".join(f'"{c}"' for c in insertable_cols) + session.execute( + sa.text( + f""" + INSERT OR IGNORE INTO "{target_name}" ({cols_str}) + SELECT {cols_str} FROM "{staging_name}"; + """ + ) + ) + + def merge_insert( + self, + table_cls: type["CSVTableProtocol"], + session: so.Session, + target_name: str, + staging_name: str, + ) -> None: + insertable_cols = self._insertable_column_names(table_cls) + cols_str = ", ".join(f'"{c}"' for c in insertable_cols) + session.execute( + sa.text( + f""" + INSERT INTO "{target_name}" ({cols_str}) + SELECT {cols_str} FROM "{staging_name}"; + """ + ) + ) + + def merge_context( + self, + table_cls: type["CSVTableProtocol"], + session: so.Session, + ): + return self.bulk_load_context(session, disable_fk=True, no_autoflush=False) + + def create_materialized_view( + self, + bind: "Engine | Connection", + name: str, + selectable: sa.sql.Select[Any], + ) -> None: + self._require_capability("supports_materialized_views", "materialized views") + + def refresh_materialized_view( + self, + bind: "Engine | Connection", + name: str, + ) -> None: + self._require_capability("supports_materialized_views", "materialized views") + + def configure_dbapi_connection(self, dbapi_connection: sa.engine.interfaces.DBAPIConnection) -> None: + if dbapi_connection.__class__.__module__.startswith("sqlite3"): + cursor = dbapi_connection.cursor() + cursor.execute(f"PRAGMA busy_timeout = {self.busy_timeout_ms}") + cursor.execute(f"PRAGMA journal_mode = {self.journal_mode}") + if self.defer_foreign_keys: + cursor.execute("PRAGMA defer_foreign_keys = ON;") + cursor.close() + + def install_engine_hooks(self, engine: "Engine") -> None: + @event.listens_for(engine, "connect") + def _enable_sqlite_foreign_keys( # type: ignore + dbapi_connection: sa.engine.interfaces.DBAPIConnection, + _connection_record: Any + ) -> None: + self.configure_dbapi_connection(dbapi_connection) + + def explain_fk_error( + self, + session: so.Session, + exc: IntegrityError, + *, + raise_error: bool = True, + ) -> None: + bind: Engine | Connection = session.get_bind() + if bind.dialect.name != "sqlite": + raise exc + + with self._as_connection(bind) as conn: + rows = conn.execute(text("PRAGMA foreign_key_check")).fetchall() + + if rows: + for row in rows: + logger.error( + "FK violation: table=%s rowid=%s references=%s fk_index=%s", + row[0], row[1], row[2], row[3] + ) + + if raise_error: + raise exc + + def restore_journal_mode(self, db_path: Path) -> None: + timeout_s = max(self.busy_timeout_ms / 1000, 5) + try: + with sqlite3.connect(db_path.resolve(), timeout=timeout_s) as conn: + conn.execute("PRAGMA wal_checkpoint(TRUNCATE)") + conn.execute("PRAGMA journal_mode = DELETE") + conn.commit() + except sqlite3.OperationalError as exc: + raise RuntimeError( + "Failed to restore SQLite journal mode. " + "Close or dispose active SQLite connections before calling this helper." + ) from exc diff --git a/src/orm_loader/helpers/__init__.py b/src/orm_loader/helpers/__init__.py index 32623f5..01742a7 100644 --- a/src/orm_loader/helpers/__init__.py +++ b/src/orm_loader/helpers/__init__.py @@ -1,7 +1,12 @@ from .errors import IngestError, ValidationError from .logging import get_logger, configure_logging from .bootstrap import bootstrap, create_db -from .sqlite import enable_sqlite_foreign_keys, explain_sqlite_fk_error +from .sqlite import ( + attach_sqlite_bulk_load_pragmas, + enable_sqlite_foreign_keys, + explain_sqlite_fk_error, + restore_sqlite_journal_mode, +) from .bulk import bulk_load_context, engine_with_replica_role from .metadata import Base from .discovery import get_model_by_tablename @@ -14,11 +19,13 @@ "configure_logging", "bootstrap", "create_db", + "attach_sqlite_bulk_load_pragmas", "enable_sqlite_foreign_keys", "explain_sqlite_fk_error", + "restore_sqlite_journal_mode", "bulk_load_context", "engine_with_replica_role", "Base", "get_model_by_tablename", "normalise_null", -] \ No newline at end of file +] diff --git a/src/orm_loader/helpers/bootstrap.py b/src/orm_loader/helpers/bootstrap.py index 473d6e5..08f7760 100644 --- a/src/orm_loader/helpers/bootstrap.py +++ b/src/orm_loader/helpers/bootstrap.py @@ -1,13 +1,13 @@ from .metadata import Base import logging - +import sqlalchemy as sa logger = logging.getLogger(__name__) -def create_db(engine): +def create_db(engine: sa.engine.Engine) -> None: logger.debug("Creating database schema") Base.metadata.create_all(engine) -def bootstrap(engine, *, create: bool = True): +def bootstrap(engine: sa.engine.Engine, *, create: bool = True) -> None: logger.info("Bootstrapping schema (create=%s)", create) if create: create_db(engine) diff --git a/src/orm_loader/helpers/bulk.py b/src/orm_loader/helpers/bulk.py index 4c3a40a..7af521a 100644 --- a/src/orm_loader/helpers/bulk.py +++ b/src/orm_loader/helpers/bulk.py @@ -1,61 +1,28 @@ from contextlib import contextmanager -from sqlalchemy import text, Engine +from sqlalchemy import Engine from sqlalchemy.orm import Session -import sqlalchemy as sa +from ..backends.resolve import resolve_backend from .logging import get_logger logger = get_logger(__name__) def disable_fk_check(session: Session) -> str | int: - """Disables FK checks and returns the previous state.""" - engine = session.get_bind() - dialect = engine.dialect.name - previous_state = None - - if dialect == "postgresql": - previous_state = session.execute(text("SHOW session_replication_role")).scalar() - session.execute(text("SET session_replication_role = 'replica'")) - elif dialect == "sqlite": - previous_state = session.execute(text("PRAGMA foreign_keys")).scalar() - session.execute(text("PRAGMA foreign_keys = OFF")) - else: - raise NotImplementedError(f"FK disable not implemented for {dialect}") - + """Disable foreign-key checks for the current session and return the previous state.""" + previous_state = resolve_backend(session).disable_fk_check(session) logger.info("Disabled foreign key checks for bulk load.") assert isinstance(previous_state, (str, int)), "Expected previous FK state to be str or int" return previous_state def enable_fk_check(session: Session) -> str | int: - """Explicitly enables FK checks and returns the previous state.""" - engine = session.get_bind() - dialect = engine.dialect.name - previous_state = None - - if dialect == "postgresql": - previous_state = session.execute(text("SHOW session_replication_role")).scalar() - session.execute(text("SET session_replication_role = 'origin'")) - elif dialect == "sqlite": - previous_state = session.execute(text("PRAGMA foreign_keys")).scalar() - session.execute(text("PRAGMA foreign_keys = ON")) - else: - raise NotImplementedError(f"FK enable not implemented for {dialect}") - + """Enable foreign-key checks for the current session and return the previous state.""" + previous_state = resolve_backend(session).enable_fk_check(session) logger.info("Explicitly re-enabled foreign key checks.") assert isinstance(previous_state, (str, int)), "Expected previous FK state to be str or int" return previous_state def restore_fk_check(session: Session, previous_state: str | int): - """Restores FK checks to a specifically provided previous state.""" - engine = session.get_bind() - dialect = engine.dialect.name - - if dialect == "postgresql": - session.execute(text(f"SET session_replication_role = '{previous_state}'")) - elif dialect == "sqlite": - session.execute(text(f"PRAGMA foreign_keys = {previous_state}")) - else: - raise NotImplementedError(f"FK restore not implemented for {dialect}") - + """Restore foreign-key checks to a previously captured backend-specific state.""" + resolve_backend(session).restore_fk_check(session, previous_state) logger.info(f"Restored foreign key checks to state: {previous_state}") @contextmanager @@ -65,60 +32,35 @@ def bulk_load_context( disable_fk: bool = True, no_autoflush: bool = True, ): - previous_fk_state = None - try: - if disable_fk: - previous_fk_state = disable_fk_check(session) - - if no_autoflush: - with session.no_autoflush: - yield - else: - yield - - except Exception: - session.rollback() - raise + """ + Wrap a trusted bulk operation in backend-aware session settings. - finally: - if previous_fk_state is not None: - restore_fk_check(session, previous_fk_state) + This is a thin helper over ``DatabaseBackend.bulk_load_context()``. + It exists so older call sites can keep using the helper import path. + """ + backend = resolve_backend(session) + with backend.bulk_load_context( + session, + disable_fk=disable_fk, + no_autoflush=no_autoflush, + ): + yield @contextmanager def engine_with_replica_role(engine: Engine): """ - Context manager that: - - forces session_replication_role=replica on all connections - - restores DEFAULT on exit - - this is different to bulk_load_context manager from orm_loader.helpers - because this is engine scoped where that one is session scoped + Force ``session_replication_role=replica`` on PostgreSQL engine connections. - postgres only + This is engine-scoped rather than session-scoped. It is only available + on backends that explicitly implement the behaviour. """ - @sa.event.listens_for(engine, "connect") # type: ignore - def _set_replica_role(dbapi_conn, _): - cur = dbapi_conn.cursor() - cur.execute("SET session_replication_role = replica") - cur.close() - - try: - yield engine - finally: - # Explicitly restore on a fresh connection - with engine.connect() as conn: - conn = conn.execution_options(isolation_level="AUTOCOMMIT") - conn.execute(text("SET session_replication_role = DEFAULT")) - - role = conn.execute( - text("SHOW session_replication_role") - ).scalar() - - if role != "origin": - raise RuntimeError( - "Failed to restore session_replication_role" - ) - - logger.info("session_replication_role restored to DEFAULT") + backend = resolve_backend(engine) + method = getattr(backend, "engine_with_replica_role", None) + if method is None: + raise NotImplementedError( + f"Backend '{backend.name}' does not support replica-role engine contexts" + ) + with method(engine) as wrapped: + yield wrapped diff --git a/src/orm_loader/helpers/discovery.py b/src/orm_loader/helpers/discovery.py index eb3e1a1..69ec5b3 100644 --- a/src/orm_loader/helpers/discovery.py +++ b/src/orm_loader/helpers/discovery.py @@ -1,10 +1,13 @@ -from typing import Type +from typing import TypeVar from .metadata import Base -def get_model_by_tablename(tablename: str, base: Type[Base] | None = None) -> Type | None: +ModelT = TypeVar("ModelT", bound=Base) + +def get_model_by_tablename( + tablename: str, + base: type[ModelT] = Base, +) -> type[ModelT] | None: tablename = tablename.lower().strip() - if base is None: - base = Base for cls in base.__subclasses__(): if getattr(cls, "__tablename__", None) == tablename: return cls diff --git a/src/orm_loader/helpers/logging.py b/src/orm_loader/helpers/logging.py index fc92ae8..376f5e8 100644 --- a/src/orm_loader/helpers/logging.py +++ b/src/orm_loader/helpers/logging.py @@ -1,6 +1,6 @@ from __future__ import annotations import logging -from typing import Optional +from typing import Optional, Any import re SENSITIVE_KEYS = { @@ -19,19 +19,15 @@ def _coerce_log_level(level: int | str) -> int: if isinstance(level, int): return level - if isinstance(level, str): - s = level.strip().upper() - if s.isdigit(): - return int(s) + s = level.strip().upper() + if s.isdigit(): + return int(s) - mapping = logging.getLevelNamesMapping() - if s in mapping: - return mapping[s] - - raise ValueError(f"Invalid log level: {level!r}") - - raise TypeError(f"Invalid log level type: {type(level)}") + mapping = logging.getLevelNamesMapping() + if s in mapping: + return mapping[s] + raise ValueError(f"Invalid log level: {level!r}") def get_logger(name: Optional[str] = None) -> logging.Logger: """ @@ -46,13 +42,13 @@ def get_logger(name: Optional[str] = None) -> logging.Logger: class RedactingFormatter(logging.Formatter): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) self._pattern = re.compile( r"(?i)\\b(" + "|".join(SENSITIVE_KEYS) + r")\\b\\s*[:=]\\s*[^\\s,;]+" ) - def format(self, record): + def format(self, record: logging.LogRecord) -> str: msg = super().format(record) return self._pattern.sub(r"\\1=", msg) diff --git a/src/orm_loader/helpers/sqlite.py b/src/orm_loader/helpers/sqlite.py index 19e4fe0..b27ce18 100644 --- a/src/orm_loader/helpers/sqlite.py +++ b/src/orm_loader/helpers/sqlite.py @@ -1,32 +1,56 @@ -from sqlalchemy import event, text +from pathlib import Path +from typing import Any + from sqlalchemy.engine import Engine from sqlalchemy.exc import IntegrityError -import logging -logger = logging.getLogger(__name__) +from ..backends.sqlite import SQLiteBackend + +def enable_sqlite_foreign_keys( + dbapi_connection: Any, + connection_record: Any, +) -> None: + """ + Apply the default SQLite connection settings used by orm-loader. + + This helper is kept for compatibility with older event-hook setups. + It delegates to ``SQLiteBackend.configure_dbapi_connection()``, + which may apply more than just foreign-key settings. + """ + del connection_record + SQLiteBackend().configure_dbapi_connection(dbapi_connection) + + +def attach_sqlite_bulk_load_pragmas( + engine: Engine, + *, + busy_timeout_ms: int = 60000, + journal_mode: str = "WAL", + defer_foreign_keys: bool = True, +) -> None: + """ + Install SQLite connect hooks aimed at heavy local write workloads. + + The hook currently sets ``busy_timeout`` and journal mode, and can + also enable deferred foreign-key checking for the connection. + """ + SQLiteBackend( + busy_timeout_ms=busy_timeout_ms, + journal_mode=journal_mode, + defer_foreign_keys=defer_foreign_keys, + ).install_engine_hooks(engine) -@event.listens_for(Engine, "connect") -def enable_sqlite_foreign_keys(dbapi_connection, connection_record): - if dbapi_connection.__class__.__module__.startswith("sqlite3"): - logger.debug("Enabling SQLite foreign key enforcement") - cursor = dbapi_connection.cursor() - cursor.execute("PRAGMA defer_foreign_keys = ON;") - cursor.close() def explain_sqlite_fk_error(session, exc: IntegrityError, raise_error: bool = True): - engine = session.get_bind() - if engine.dialect.name != "sqlite": - raise exc - - with engine.connect() as conn: - rows = conn.execute(text("PRAGMA foreign_key_check")).fetchall() - - if rows: - for r in rows: - logger.error( - "FK violation: table=%s rowid=%s references=%s fk_index=%s", - r[0], r[1], r[2], r[3] - ) - - if raise_error: - raise exc + """Log SQLite foreign-key check details before re-raising an error.""" + SQLiteBackend().explain_fk_error(session, exc, raise_error=raise_error) + + +def restore_sqlite_journal_mode(db_path: Path) -> None: + """ + Checkpoint WAL contents and switch the database back to ``DELETE`` mode. + + Call this after disposing active SQLite connections. Reconnecting + through an engine that still installs WAL hooks will enable WAL again. + """ + SQLiteBackend().restore_journal_mode(db_path) diff --git a/src/orm_loader/loaders/data_classes.py b/src/orm_loader/loaders/data_classes.py index f7dfe8b..148cd0e 100644 --- a/src/orm_loader/loaders/data_classes.py +++ b/src/orm_loader/loaders/data_classes.py @@ -170,33 +170,6 @@ def dedupe(cls, data: pd.DataFrame | pa.Table, ctx: LoaderContext) -> Any: """ raise NotImplementedError - # @classmethod - # def _dedupe_db(cls, df: pd.DataFrame, ctx: LoaderContext) -> pd.DataFrame: - # """ - # Perform database-level deduplication against existing rows. - - # Parameters - # ---------- - # df - # Incoming DataFrame. - # ctx - # Loader context. - - # Returns - # ------- - # pandas.DataFrame - # DataFrame with rows already present in the database removed. - # """ - # pk_names = ctx.tableclass.pk_names() - # pk_tuples = list(df[pk_names].itertuples(index=False, name=None)) - # if not pk_tuples: - # return df - # tableclass = ( - # ctx.staging_table - # if ctx.staging_table is not None - # else ctx.tableclass.__table__ - # ) - # pk_cols = [getattr(tableclass.c, pk) for pk in pk_names] # vars_per_row = len(pk_cols) # chunk_size = max(1, 10_000 // vars_per_row) @@ -286,4 +259,3 @@ def to_dict(self) -> dict[str, dict[str, Any]]: for col, stats in self.columns.items() } - diff --git a/src/orm_loader/loaders/loading_helpers.py b/src/orm_loader/loaders/loading_helpers.py index 93dd09e..353d456 100644 --- a/src/orm_loader/loaders/loading_helpers.py +++ b/src/orm_loader/loaders/loading_helpers.py @@ -10,6 +10,7 @@ import io logger = logging.getLogger(__name__) +COPY_BLOCK_SIZE = 8192 """ Loader Helper Functions @@ -202,17 +203,17 @@ def quick_load_pg( try: with open(path, "rb") as f: stream = NormalisedCSVStream(f, encoding=encoding, delimiter=delimiter) - - cur.copy_expert( - sql=f''' + with cur.copy( + f''' COPY "{tablename}" FROM STDIN WITH ( {copy_options} ) - ''', - file=stream, - ) + ''' + ) as copy: + while data := stream.read(COPY_BLOCK_SIZE): + copy.write(data) session.flush() total = session.execute(sa.text(f'SELECT COUNT(*) FROM "{tablename}"')).scalar_one() return total diff --git a/src/orm_loader/mappers/materialised_view_mixin.py b/src/orm_loader/mappers/materialised_view_mixin.py index a01096f..34e037c 100644 --- a/src/orm_loader/mappers/materialised_view_mixin.py +++ b/src/orm_loader/mappers/materialised_view_mixin.py @@ -1,7 +1,9 @@ from sqlalchemy.ext import compiler from sqlalchemy.schema import DDLElement import sqlalchemy as sa +from typing import Any from collections import defaultdict, deque +from ..backends.resolve import resolve_backend class CreateMaterializedView(DDLElement): """ @@ -23,12 +25,16 @@ class CreateMaterializedView(DDLElement): materialized view. """ - def __init__(self, name, selectable): + def __init__(self, name: str, selectable: sa.sql.Select[Any]): self.name = name self.selectable = selectable @compiler.compiles(CreateMaterializedView) -def _create_view(element, compiler, **kw): +def _create_view( # type: ignore + element: CreateMaterializedView, + compiler: sa.sql.compiler.SQLCompiler, + **kwargs: Any +) -> str: """ `_create_view` @@ -150,11 +156,11 @@ class DailyObservationCountsMV(Base, MaterializedViewMixin): """ __mv_name__: str - __mv_select__: sa.sql.Select + __mv_select__: sa.sql.Select[Any] __mv_dependencies__: set[str] = set() @classmethod - def create_mv(cls, bind): + def create_mv(cls, bind: "sa.engine.Connection | sa.engine.Engine") -> None: """ Create the materialized view if it does not already exist. @@ -166,8 +172,8 @@ def create_mv(cls, bind): Notes ----- The underlying SQL is emitted via a custom DDL element and executed - directly against the database. This operation is not transactional - on all backends. + through the resolved backend. With the built-in backends, this means + PostgreSQL. Unsupported backends raise ``NotImplementedError``. Examples @@ -193,11 +199,11 @@ def create_mv(cls, bind): WHERE observation.observation_date >= CURRENT_DATE - INTERVAL '30 days'; ``` """ - ddl = CreateMaterializedView(cls.__mv_name__, cls.__mv_select__) - bind.execute(ddl) + backend = resolve_backend(bind) + backend.create_materialized_view(bind, cls.__mv_name__, cls.__mv_select__) @classmethod - def refresh_mv(cls, bind): + def refresh_mv(cls, bind: "sa.engine.Connection | sa.engine.Engine") -> None: """ Refresh the contents of the materialized view. @@ -208,9 +214,9 @@ def refresh_mv(cls, bind): Notes ----- - This method issues a REFRESH MATERIALIZED VIEW statement and assumes - backend support (e.g. PostgreSQL). Concurrent refresh semantics are - not handled here. + This method issues a backend-specific refresh statement. With the + built-in backends, materialized views are PostgreSQL-only. + Concurrent refresh semantics are not handled here. Examples -------- @@ -219,7 +225,8 @@ def refresh_mv(cls, bind): RecentObservationMV.refresh_mv(conn) ``` """ - bind.execute(sa.text(f"REFRESH MATERIALIZED VIEW {cls.__mv_name__};")) + backend = resolve_backend(bind) + backend.refresh_materialized_view(bind, cls.__mv_name__) def resolve_mv_refresh_order(mv_classes: list[type[MaterializedViewMixin]]) -> list[type]: @@ -271,7 +278,7 @@ def resolve_mv_refresh_order(mv_classes: list[type[MaterializedViewMixin]]) -> l return [name_to_mv[name] for name in ordered] -def refresh_all_mvs(bind, mv_classes): +def refresh_all_mvs(bind: "sa.engine.Connection | sa.engine.Engine", mv_classes: list[type[MaterializedViewMixin]]) -> None: """ `refresh_all_mvs` @@ -289,7 +296,7 @@ def refresh_all_mvs(bind, mv_classes): refresh_all_mvs(engine, ALL_MVS) ``` """ - ordered = resolve_mv_refresh_order(mv_classes) + ordered: list[type[MaterializedViewMixin]] = resolve_mv_refresh_order(mv_classes) for mv in ordered: - mv.refresh_mv(bind) \ No newline at end of file + mv.refresh_mv(bind) diff --git a/src/orm_loader/tables/loadable_table.py b/src/orm_loader/tables/loadable_table.py index f8a91c5..d5db296 100644 --- a/src/orm_loader/tables/loadable_table.py +++ b/src/orm_loader/tables/loadable_table.py @@ -1,16 +1,16 @@ +# pyright: reportPrivateUsage=false import sqlalchemy as sa import sqlalchemy.orm as so import logging -from typing import Type, ClassVar, Optional +from typing import Type, ClassVar, Optional, Any from pathlib import Path from contextlib import contextmanager from .orm_table import ORMTableBase from .typing import CSVTableProtocol +from ..backends.resolve import resolve_backend from ..loaders.loader_interface import LoaderInterface, LoaderContext, PandasLoader, ParquetLoader -from ..loaders.loading_helpers import quick_load_pg -from ..helpers.bulk import restore_fk_check, disable_fk_check logger = logging.getLogger(__name__) @@ -66,7 +66,7 @@ def staging_tablename(cls: Type[CSVTableProtocol]) -> str: str The staging table name. """ - if cls._staging_tablename: + if cls._staging_tablename: # type: ignore return cls._staging_tablename return f"_staging_{cls.__tablename__}" @@ -93,74 +93,30 @@ def create_staging_table( NotImplementedError If the database dialect is unsupported. """ - table = cls.__table__ - session.execute(sa.text(f"""DROP TABLE IF EXISTS "{cls.staging_tablename()}";""")) - if session.bind is None: raise RuntimeError("Session is not bound to an engine") - - dialect = session.bind.dialect.name - - if dialect == "postgresql": - logger.info("Disabling indices on staging table for performance") - session.execute(sa.text(f''' - CREATE UNLOGGED TABLE "{cls.staging_tablename()}" - (LIKE "{table.name}" INCLUDING DEFAULTS INCLUDING CONSTRAINTS); - ''')) - - # Need to drop the columns we are not going to load into, otherwise the COPY will fail - computed_cols = [c.name for c in table.columns if c.computed is not None] - for col in computed_cols: - session.execute(sa.text(f'ALTER TABLE "{cls.staging_tablename()}" DROP COLUMN "{col}";')) - - elif dialect == "sqlite": - - metadata = sa.MetaData() - - staging_columns = [] - for col in table.columns: - staging_columns.append( - sa.Column( - col.name, - col.type, - nullable=True, - ) - ) - - staging_table = sa.Table( - cls.staging_tablename(), - metadata, - *staging_columns, - ) - - conn = session.connection() - metadata.create_all(bind=conn, tables=[staging_table]) - # this borks on date cols because it loses the date - # specification and reverts to NUM - # - changing to metadata.create_all approach for sqlite - # but not postgresql for now to keep unlogged table feature - # session.execute(sa.text(f''' - # CREATE TABLE "{cls.staging_tablename()}" AS - # SELECT * FROM "{table.name}" WHERE 0; - # ''')) - else: - raise NotImplementedError( - f"Staging table creation not implemented for dialect '{dialect}'" - ) - # query the sense of having internal commit here, but for now - # it is required for the ORM-based fallback loader to function - # cleanly for external pipeline purposes - - session.commit() + backend = resolve_backend(session) + backend.create_staging_table(cls, session, cls.staging_tablename()) @classmethod @contextmanager - def manage_indices(cls: Type['CSVTableProtocol'], session: so.Session): + def manage_indices( + cls: Type['CSVTableProtocol'], + session: so.Session, + index_strategy: str = "auto", + ): """ - Temporarily drops non-primary key indices before a bulk operation - and recreates them afterwards to prevent write amplification. + Manage non-primary-key indexes around a staged merge. + + ``index_strategy`` may be ``"auto"``, ``"drop_rebuild"``, or + ``"keep"``. The backend decides what ``"auto"`` means. At the + moment SQLite keeps indexes by default, while PostgreSQL drops + and rebuilds them. """ - indices = list(cls.__table__.indexes) + backend = resolve_backend(session) + resolved_index_strategy = backend.resolve_index_strategy(index_strategy) + + indices = list(cls.__table__.indexes) if resolved_index_strategy == "drop_rebuild" else [] inspector = sa.inspect(session.bind) assert inspector is not None, "Failed to create inspector for index management" @@ -174,20 +130,16 @@ def manage_indices(cls: Type['CSVTableProtocol'], session: so.Session): session.execute(sa.schema.DropIndex(idx)) session.commit() - # session.commit() above restores the original state of the session. We need that one after we are done - previous_fk_state = disable_fk_check(session) - try: - yield - session.commit() + with backend.merge_context(cls, session): + yield + session.commit() except Exception as e: session.rollback() logger.error(f"Table `{cls.__tablename__}`: Merge operation failed - {e}") raise finally: - restore_fk_check(session, previous_fk_state) - if indices: logger.info(f"Table `{cls.__tablename__}`: Verifying/Rebuilding indices.") inspector.clear_cache() # Required to ensure we get the current state of the database after potential changes @@ -269,25 +221,23 @@ def load_staging( if loader_context.session.bind is None: raise RuntimeError("Session is not bound to an engine") - dialect = loader_context.session.bind.dialect.name + backend = resolve_backend(loader_context.session) total = 0 try: cls.create_staging_table(loader_context.session) - if dialect == "postgresql": - try: - total = quick_load_pg( - path=loader_context.path, - session=loader_context.session, - tablename=cls.staging_tablename(), - quote_mode=loader_context.quote_mode, - ) + try: + total = backend.load_staging_fast( + loader_context=loader_context, + staging_name=cls.staging_tablename(), + ) + if total is not None: return total - except Exception as e: - loader_context.session.rollback() - logger.warning(f"COPY failed for {cls.staging_tablename()}: {e}") - logger.info('Falling back to ORM-based load functionality') + except Exception as e: + loader_context.session.rollback() + logger.warning(f"Fast-path load failed for {cls.staging_tablename()}: {e}") + logger.info('Falling back to ORM-based load functionality') total = cls.orm_staging_load( loader=loader, @@ -347,6 +297,7 @@ def load_csv( chunksize: int | None = None, merge_strategy: str = "replace", quote_mode: str = "csv", + index_strategy: str = "auto", ) -> int: """ @@ -374,11 +325,16 @@ def load_csv( Optional chunk size for incremental loading. merge_strategy Merge strategy to apply (e.g. ``replace`` or ``upsert``). + quote_mode + Quoting mode used by the PostgreSQL fast-path loader. + index_strategy + Index handling strategy during merge. Use ``"auto"`` to let + the backend choose a sensible default. Returns ------- int - Number of rows loaded. + Number of rows loaded into staging before merge. """ logger.debug(f"Table `{cls.__tablename__}`: Loading CSV from {path}") @@ -403,12 +359,12 @@ def load_csv( loader = cls._select_loader(path) # Load to staging (Indices are already excluded via updated create_staging_table) - logger.info(f"Table `{cls.__tablename__}`: Loading data into unlogged staging table") + logger.info(f"Table `{cls.__tablename__}`: Loading data into staging table") total = cls.load_staging(loader=loader, loader_context=loader_context) # Merge staging to target (Wrapped in our index dropper!) logger.info(f"Table `{cls.__tablename__}`: Merging staging data into target table") - with cls.manage_indices(session): + with cls.manage_indices(session, index_strategy=index_strategy): cls.merge_from_staging(session, merge_strategy=merge_strategy) cls.drop_staging_table(session) @@ -423,8 +379,7 @@ def _merge_replace( session: so.Session, target: str, staging: str, - pk_cols: list[str], - dialect: str + pk_cols: list[str] ): """ Merge staging data by replacing existing rows. @@ -432,37 +387,8 @@ def _merge_replace( Existing target rows matching the staging primary keys are deleted prior to insertion. """ - if dialect == "postgresql": - pk_join = " AND ".join( - f't."{c}" = s."{c}"' for c in pk_cols - ) - - session.execute(sa.text(f""" - DELETE FROM "{target}" t - USING "{staging}" s - WHERE {pk_join}; - """)) - - elif dialect == "sqlite": - if len(pk_cols) == 1: - pk = pk_cols[0] - session.execute(sa.text(f""" - DELETE FROM "{target}" - WHERE "{pk}" IN ( - SELECT "{pk}" FROM "{staging}" - ); - """)) - else: - pk_match = " AND ".join( - f'"{target}"."{c}" = "{staging}"."{c}"' for c in pk_cols - ) - session.execute(sa.text(f""" - DELETE FROM "{target}" - WHERE EXISTS ( - SELECT 1 FROM "{staging}" - WHERE {pk_match} - ); - """)) + backend = resolve_backend(session) + backend.merge_replace(cls, session, target, staging, pk_cols) @classmethod def _merge_insert( @@ -470,18 +396,12 @@ def _merge_insert( session: so.Session, target: str, staging: str - ): + ): """ Insert all rows from the staging table into the target table. """ - # Get all columns that are NOT computed - insertable_cols = [c.name for c in cls.__table__.columns if c.computed is None] - cols_str = ", ".join(f'"{c}"' for c in insertable_cols) - - session.execute(sa.text(f""" - INSERT INTO "{target}" ({cols_str}) - SELECT {cols_str} FROM "{staging}"; - """)) + backend = resolve_backend(session) + backend.merge_insert(cls, session, target, staging) @classmethod @@ -490,33 +410,13 @@ def _merge_upsert( session: so.Session, target: str, staging: str, - pk_cols: list[str], - dialect: str + pk_cols: list[str] ): """ Merge staging data using an upsert strategy. """ - - # Get all columns that are NOT computed - insertable_cols = [c.name for c in cls.__table__.columns if c.computed is None] - cols_str = ", ".join(f'"{c}"' for c in insertable_cols) - - if dialect == "postgresql": - # INSERT … ON CONFLICT DO NOTHING - session.execute(sa.text(f""" - INSERT INTO "{target}" ({cols_str}) - SELECT {cols_str} FROM "{staging}" - ON CONFLICT ({", ".join(f'"{c}"' for c in pk_cols)}) DO NOTHING; - """)) - - elif dialect == "sqlite": - session.execute(sa.text(f""" - INSERT OR IGNORE INTO "{target}" ({cols_str}) - SELECT {cols_str} FROM "{staging}"; - """)) - - else: - raise NotImplementedError + backend = resolve_backend(session) + backend.merge_upsert(cls, session, target, staging, pk_cols) @classmethod def merge_from_staging( @@ -540,15 +440,12 @@ def merge_from_staging( if not session.bind: raise RuntimeError("Session is not bound to an engine") - - dialect = session.bind.dialect.name if merge_strategy == "replace": cls._merge_replace( session=session, target=target, staging=staging, pk_cols=pk_cols, - dialect=dialect, ) cls._merge_insert( session=session, @@ -561,7 +458,6 @@ def merge_from_staging( target=target, staging=staging, pk_cols=pk_cols, - dialect=dialect, ) else: raise ValueError(f"Unknown merge strategy '{merge_strategy}'") @@ -571,12 +467,11 @@ def drop_staging_table(cls: Type[CSVTableProtocol], session: so.Session): """ Drop the staging table if it exists. """ - session.execute( - sa.text(f'DROP TABLE IF EXISTS "{cls.staging_tablename()}"') - ) + backend = resolve_backend(session) + backend.drop_staging_table(session, cls.staging_tablename()) @classmethod - def csv_columns(cls) -> dict[str, sa.ColumnElement]: + def csv_columns(cls) -> dict[str, sa.ColumnElement[Any]]: """ Return a mapping of CSV column names to model columns. @@ -590,4 +485,4 @@ def csv_columns(cls) -> dict[str, sa.ColumnElement]: """ cols = cls.model_columns() computed_names = {c.name for c in cls.__table__.columns if c.computed is not None} # type: ignore - return {k: v for k, v in cols.items() if k not in computed_names} \ No newline at end of file + return {k: v for k, v in cols.items() if k not in computed_names} diff --git a/src/orm_loader/tables/orm_table.py b/src/orm_loader/tables/orm_table.py index a7748e6..771f0e4 100644 --- a/src/orm_loader/tables/orm_table.py +++ b/src/orm_loader/tables/orm_table.py @@ -1,7 +1,7 @@ import sqlalchemy as sa import sqlalchemy.orm as so from sqlalchemy.exc import StatementError -from typing import Any, Tuple, Type, cast +from typing import Any import logging from .allocators import IdAllocator from ..helpers import normalise_null @@ -46,7 +46,7 @@ class ORMTableBase: __abstract__ = True @classmethod - def mapper_for(cls: Type) -> so.Mapper: + def mapper_for(cls: type[Any]) -> so.Mapper[Any]: """ Return the SQLAlchemy mapper associated with this ORM class. @@ -63,13 +63,13 @@ def mapper_for(cls: Type) -> so.Mapper: TypeError If the class is not a mapped SQLAlchemy ORM class. """ - mapper = sa.inspect(cls) + mapper: so.Mapper[Any] = sa.inspect(cls) if not mapper: raise TypeError(f"{cls.__name__} is not a mapped ORM class") - return cast(so.Mapper, mapper) + return mapper @classmethod - def pk_columns(cls) -> list[sa.ColumnElement]: + def pk_columns(cls) -> list[sa.ColumnElement[Any]]: """ Return the primary key columns for the mapped table. @@ -120,7 +120,7 @@ def pk_values(cls, obj: Any) -> dict[str, Any]: return {c.key: getattr(obj, c.key) for c in cls.pk_columns() if c.key is not None} @classmethod - def pk_tuple(cls, obj: Any) -> Tuple[Any, ...]: + def pk_tuple(cls, obj: Any) -> tuple[Any, ...]: """ Extract primary key values from an ORM instance as a tuple. @@ -143,7 +143,7 @@ def pk_tuple(cls, obj: Any) -> Tuple[Any, ...]: ) @classmethod - def model_columns(cls) -> dict[str, sa.ColumnElement]: + def model_columns(cls) -> dict[str, sa.ColumnElement[Any]]: """ Return all mapped columns for the table. @@ -153,7 +153,7 @@ def model_columns(cls) -> dict[str, sa.ColumnElement]: A mapping of column name to column object. """ mapper = cls.mapper_for() - return {c.key: c for c in mapper.columns if c.key is not None} + return {c.key: c for c in mapper.columns} @classmethod def required_columns(cls) -> set[str]: @@ -177,11 +177,11 @@ def required_columns(cls) -> set[str]: return { c.key for c in mapper.columns - if not c.nullable and not c.default and not c.server_default and c.key is not None + if not c.nullable and not c.default and not c.server_default } @classmethod - def max_id(cls, session) -> int: + def max_id(cls, session: so.Session) -> int: """ Return the maximum value of the primary key column. @@ -211,7 +211,7 @@ def max_id(cls, session) -> int: return session.query(sa.func.max(pk)).scalar() or 0 @classmethod - def allocator(cls, session) -> IdAllocator: + def allocator(cls, session: so.Session) -> IdAllocator: """ Create an ID allocator initialised from the current table state. @@ -251,7 +251,7 @@ def clean_kwargs( """ cols = cls.model_columns() - cleaned = {} + cleaned: dict[str, Any] = {} for k, v in data.items(): if k not in cols: continue # ignore unknown keys safely diff --git a/src/orm_loader/tables/serialisable_table.py b/src/orm_loader/tables/serialisable_table.py index e340b45..2310855 100644 --- a/src/orm_loader/tables/serialisable_table.py +++ b/src/orm_loader/tables/serialisable_table.py @@ -1,10 +1,14 @@ -from .orm_table import ORMTableBase -from typing import Any +from typing import Any, Unpack +from collections.abc import Iterator import json import hashlib import datetime -def json_default(obj) -> str: +from .orm_table import ORMTableBase +from .typing import ToDictKwargs + + +def json_default(obj: Any) -> str: """ Default JSON serialisation handler for unsupported types. @@ -79,7 +83,7 @@ def to_dict( dict[str, Any] A dictionary representation of the ORM row. """ - data = {} + data: dict[str, Any] = {} for key, _ in self.model_columns().items(): if only and key not in only: continue @@ -91,7 +95,7 @@ def to_dict( data[key] = value return data - def to_json(self, **kwargs) -> str: + def to_json(self, **kwargs: Unpack[ToDictKwargs]) -> str: """ Serialise the ORM instance to a JSON string. @@ -133,7 +137,7 @@ def fingerprint(self) -> str: payload = self.to_json(include_nulls=True) return hashlib.sha256(payload.encode("utf-8")).hexdigest() - def __iter__(self): + def __iter__(self) -> Iterator[tuple[str, Any]]: """ Iterate over the ORM instance as ``(key, value)`` pairs. @@ -147,7 +151,7 @@ def __iter__(self): """ yield from self.to_dict().items() - def __json__(self): + def __json__(self) -> dict[str, Any]: """ Return a JSON-serialisable representation of the ORM instance. diff --git a/src/orm_loader/tables/typing.py b/src/orm_loader/tables/typing.py index b61d183..2df700a 100644 --- a/src/orm_loader/tables/typing.py +++ b/src/orm_loader/tables/typing.py @@ -1,4 +1,4 @@ -from typing import Protocol, ClassVar, runtime_checkable, TYPE_CHECKING, Optional, Type, Dict, Any +from typing import Protocol, ClassVar, runtime_checkable, TYPE_CHECKING, Optional, Type, Dict, Any, Unpack, TypedDict import sqlalchemy.orm as so import sqlalchemy as sa from pathlib import Path @@ -6,6 +6,11 @@ if TYPE_CHECKING: from ..loaders import LoaderContext, LoaderInterface +class ToDictKwargs(TypedDict, total=False): + include_nulls: bool + only: set[str] | None + exclude: set[str] | None + @runtime_checkable class ORMTableProtocol(Protocol): """ @@ -28,17 +33,16 @@ class ORMTableProtocol(Protocol): metadata: ClassVar[sa.MetaData] @classmethod - def mapper_for(cls) -> so.Mapper: ... + def mapper_for(cls) -> so.Mapper[Any]: ... @classmethod def pk_names(cls) -> list[str]: ... @classmethod - def pk_columns(cls) -> list[sa.ColumnElement]: ... + def pk_columns(cls) -> list[sa.ColumnElement[Any]]: ... @classmethod - def model_columns(cls) -> dict[str, sa.ColumnElement]: ... - + def model_columns(cls) -> dict[str, sa.ColumnElement[Any]]: ... @runtime_checkable class CSVTableProtocol(ORMTableProtocol, Protocol): @@ -81,6 +85,7 @@ def load_csv( chunksize: int | None = None, merge_strategy: str = "replace", quote_mode: str = "csv", + index_strategy: str = "auto", ) -> int: ... @classmethod @@ -99,13 +104,13 @@ def drop_staging_table(cls, session: so.Session) -> None: ... def _merge_insert(cls, session: so.Session, target: str, staging: str) -> None: ... @classmethod - def _merge_replace(cls, session: so.Session, target: str, staging: str, pk_cols: list[str], dialect: str) -> None: ... + def _merge_replace(cls, session: so.Session, target: str, staging: str, pk_cols: list[str]) -> None: ... @classmethod - def _merge_upsert(cls, session: so.Session, target: str, staging: str, pk_cols: list[str], dialect: str) -> None: ... + def _merge_upsert(cls, session: so.Session, target: str, staging: str, pk_cols: list[str]) -> None: ... @classmethod - def manage_indices(cls, session: so.Session) -> AbstractContextManager[None]: + def manage_indices(cls, session: so.Session, index_strategy: str = "auto") -> AbstractContextManager[None]: ... @@ -130,11 +135,10 @@ def to_dict( exclude: set[str] | None = None, ) -> Dict[str, Any]: ... - def to_json(self, **kwargs) -> str: ... + def to_json(self, **kwargs: Unpack[ToDictKwargs]) -> str: ... def fingerprint(self) -> str: ... def __iter__(self) -> Any: ... def __json__(self) -> Any: ... - diff --git a/tests/backends/test_base_backend.py b/tests/backends/test_base_backend.py new file mode 100644 index 0000000..dc81d4d --- /dev/null +++ b/tests/backends/test_base_backend.py @@ -0,0 +1,237 @@ +from __future__ import annotations + +import importlib +import importlib.abc +import sys + +import pytest +import sqlalchemy as sa +import sqlalchemy.orm as so + +from orm_loader.backends import BackendCapabilities, DatabaseBackend, resolve_backend + + +class _BlockPsycopg(importlib.abc.MetaPathFinder): + def find_spec(self, fullname, path=None, target=None): + if fullname == "psycopg" or fullname.startswith("psycopg."): + raise ModuleNotFoundError("No module named 'psycopg'") + return None + + +class FakeBackend(DatabaseBackend): + def __init__(self) -> None: + self.calls: list[tuple[str, object]] = [] + + @property + def name(self) -> str: + return "fake" + + @property + def dialect_names(self) -> tuple[str, ...]: + return ("fake",) + + @property + def capabilities(self) -> BackendCapabilities: + return BackendCapabilities( + supports_fast_load=True, + supports_fk_toggle=True, + ) + + def create_staging_table(self, table_cls, session, staging_name) -> None: + return None + + def drop_staging_table(self, session, staging_name) -> None: + return None + + def merge_replace(self, table_cls, session, target_name, staging_name, pk_cols) -> None: + return None + + def merge_upsert(self, table_cls, session, target_name, staging_name, pk_cols) -> None: + return None + + def merge_insert(self, table_cls, session, target_name, staging_name) -> None: + return None + + def disable_fk_check(self, session) -> str | int: + self.calls.append(("disable_fk_check", session)) + return "enabled" + + def enable_fk_check(self, session) -> str | int: + self.calls.append(("enable_fk_check", session)) + return "disabled" + + def restore_fk_check(self, session, previous_state: str | int) -> None: + self.calls.append(("restore_fk_check", previous_state)) + + def create_materialized_view(self, bind, name: str, selectable: sa.sql.Select) -> None: + return None + + def refresh_materialized_view(self, bind, name: str) -> None: + return None + + +class _ComputedTable: + __table__ = sa.Table( + "computed_table", + sa.MetaData(), + sa.Column("id", sa.Integer, primary_key=True), + sa.Column("name", sa.String), + sa.Column("slug", sa.String, sa.Computed("lower(name)")), + ) + + +def test_backend_capabilities_defaults(): + caps = BackendCapabilities() + + assert caps.supports_fast_load is False + assert caps.supports_unlogged_staging is False + assert caps.supports_fk_toggle is False + assert caps.supports_materialized_views is False + + +def test_database_backend_is_abstract(): + with pytest.raises(TypeError): + DatabaseBackend() + + +def test_fake_backend_can_implement_contract(): + backend = FakeBackend() + + assert backend.name == "fake" + assert backend.dialect_names == ("fake",) + assert backend.capabilities.supports_fast_load is True + assert backend.capabilities.supports_fk_toggle is True + assert backend.supports_dialect("fake") is True + assert backend.supports_dialect("sqlite") is False + assert backend.resolve_index_strategy("auto") == "drop_rebuild" + assert backend.resolve_index_strategy("keep") == "keep" + assert backend.load_staging_fast(None, "staging") is None + + with backend.merge_context(None, None): + pass + + +def test_require_capability_passes_for_supported_feature(): + backend = FakeBackend() + + backend._require_capability("supports_fast_load", "fast loading") + + +def test_require_capability_raises_for_unsupported_feature(): + backend = FakeBackend() + + with pytest.raises(NotImplementedError, match="does not support materialized views"): + backend._require_capability("supports_materialized_views", "materialized views") + + +def test_require_capability_raises_for_unknown_flag(): + backend = FakeBackend() + + with pytest.raises(AttributeError, match="Unknown backend capability"): + backend._require_capability("not_a_capability", "something") + + +def test_resolve_index_strategy_raises_for_invalid_value(): + backend = FakeBackend() + + with pytest.raises(ValueError, match="Unknown index_strategy"): + backend.resolve_index_strategy("not-valid") + + +def test_insertable_column_names_exclude_computed_columns(): + backend = FakeBackend() + + assert backend._insertable_column_names(_ComputedTable) == ["id", "name"] + + +def test_bulk_load_context_toggles_fk_and_restores(session): + backend = FakeBackend() + + with backend.bulk_load_context(session): + pass + + assert backend.calls == [ + ("disable_fk_check", session), + ("restore_fk_check", "enabled"), + ] + + +def test_bulk_load_context_without_fk_toggle(session): + backend = FakeBackend() + + with backend.bulk_load_context(session, disable_fk=False): + pass + + assert backend.calls == [] + + +def test_bulk_load_context_raises_when_capability_missing(session): + class NoFKBackend(FakeBackend): + @property + def capabilities(self) -> BackendCapabilities: + return BackendCapabilities() + + backend = NoFKBackend() + + with pytest.raises(NotImplementedError, match="does not support foreign key toggling"): + with backend.bulk_load_context(session): + pass + + +def test_bulk_load_context_rolls_back_and_restores(session): + backend = FakeBackend() + + with pytest.raises(RuntimeError, match="boom"): + with backend.bulk_load_context(session): + raise RuntimeError("boom") + + assert backend.calls == [ + ("disable_fk_check", session), + ("restore_fk_check", "enabled"), + ] + + +def test_backends_package_exports(): + import orm_loader.backends as backends + + assert backends.DatabaseBackend is DatabaseBackend + assert backends.BackendCapabilities is BackendCapabilities + assert backends.resolve_backend is resolve_backend + + +def test_resolve_backend_for_sqlite_engine_and_session(): + engine = sa.create_engine("sqlite:///:memory:", future=True) + session = so.Session(engine) + + try: + engine_backend = resolve_backend(engine) + session_backend = resolve_backend(session) + + assert engine_backend.name == "sqlite" + assert session_backend.name == "sqlite" + finally: + session.close() + + +def test_resolve_backend_raises_for_unknown_dialect(): + class _Unknown: + class dialect: + name = "unknown" + + with pytest.raises(NotImplementedError, match="No backend registered"): + resolve_backend(_Unknown()) + + +def test_backends_import_does_not_require_psycopg(): + blocker = _BlockPsycopg() + original = sys.modules.pop("orm_loader.backends", None) + sys.meta_path.insert(0, blocker) + + try: + module = importlib.import_module("orm_loader.backends") + assert module.DatabaseBackend is not None + finally: + sys.meta_path.remove(blocker) + sys.modules.pop("orm_loader.backends", None) + if original is not None: + sys.modules["orm_loader.backends"] = original diff --git a/tests/backends/test_postgres_backend.py b/tests/backends/test_postgres_backend.py new file mode 100644 index 0000000..8fafa0a --- /dev/null +++ b/tests/backends/test_postgres_backend.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from orm_loader.backends import PostgresBackend + + +class _ComputedTable: + __table__ = sa.Table( + "target_table", + sa.MetaData(), + sa.Column("id", sa.Integer, primary_key=True), + sa.Column("name", sa.String), + sa.Column("slug", sa.String, sa.Computed("lower(name)")), + ) + + +class _FakeSession: + def __init__(self, scalar_result="origin") -> None: + self.statements: list[str] = [] + self.scalar_result = scalar_result + self.commits = 0 + + def execute(self, statement): + if hasattr(statement, "compile"): + sql = str(statement.compile(dialect=postgresql.dialect())) + else: + sql = str(statement) + self.statements.append(sql) + + class _Result: + def __init__(self, value): + self._value = value + + def scalar(self): + return self._value + + return _Result(self.scalar_result) + + def commit(self) -> None: + self.commits += 1 + + +def test_postgres_backend_identity_and_capabilities(): + backend = PostgresBackend() + + assert backend.name == "postgres" + assert backend.supports_dialect("postgresql") is True + assert backend.capabilities.supports_fast_load is True + assert backend.capabilities.supports_unlogged_staging is True + assert backend.capabilities.supports_fk_toggle is True + assert backend.capabilities.supports_materialized_views is True + + +def test_postgres_backend_create_staging_table_drops_computed_columns(): + backend = PostgresBackend() + session = _FakeSession() + + backend.create_staging_table(_ComputedTable, session, "_staging_target_table") + + assert any('DROP TABLE IF EXISTS "_staging_target_table"' in sql for sql in session.statements) + assert any('CREATE UNLOGGED TABLE "_staging_target_table"' in sql for sql in session.statements) + assert any('ALTER TABLE "_staging_target_table" DROP COLUMN "slug"' in sql for sql in session.statements) + assert session.commits == 1 + + +def test_postgres_backend_drop_staging_table(): + backend = PostgresBackend() + session = _FakeSession() + + backend.drop_staging_table(session, "_staging_target_table") + + assert session.statements == ['DROP TABLE IF EXISTS "_staging_target_table"'] + + +def test_postgres_backend_fk_methods_emit_expected_sql(): + backend = PostgresBackend() + session = _FakeSession() + + previous = backend.disable_fk_check(session) + enabled = backend.enable_fk_check(session) + backend.restore_fk_check(session, previous) + + assert previous == "origin" + assert enabled == "origin" + assert session.statements == [ + "SHOW session_replication_role", + "SET session_replication_role = 'replica'", + "SHOW session_replication_role", + "SET session_replication_role = 'origin'", + "SET session_replication_role = 'origin'", + ] + + +def test_postgres_backend_merge_replace_uses_using_delete(): + backend = PostgresBackend() + session = _FakeSession() + + backend.merge_replace(_ComputedTable, session, "target_table", "_staging_target_table", ["id", "name"]) + + sql = session.statements[0] + assert 'DELETE FROM "target_table" t' in sql + assert 'USING "_staging_target_table" s' in sql + assert 't."id" = s."id" AND t."name" = s."name"' in sql + + +def test_postgres_backend_merge_insert_excludes_computed_columns(): + backend = PostgresBackend() + session = _FakeSession() + + backend.merge_insert(_ComputedTable, session, "target_table", "_staging_target_table") + + sql = session.statements[0] + assert 'INSERT INTO "target_table" ("id", "name")' in sql + assert 'SELECT "id", "name" FROM "_staging_target_table"' in sql + + +def test_postgres_backend_merge_upsert_excludes_computed_columns(): + backend = PostgresBackend() + session = _FakeSession() + + backend.merge_upsert(_ComputedTable, session, "target_table", "_staging_target_table", ["id"]) + + sql = session.statements[0] + assert 'INSERT INTO "target_table" ("id", "name")' in sql + assert 'ON CONFLICT ("id") DO NOTHING' in sql + + +def test_postgres_backend_materialized_view_methods_emit_expected_sql(): + backend = PostgresBackend() + session = _FakeSession() + selectable = sa.select(sa.literal(1).label("n")) + + backend.create_materialized_view(session, "mv_test", selectable) + backend.refresh_materialized_view(session, "mv_test") + + assert any("CREATE MATERIALIZED VIEW IF NOT EXISTS mv_test as SELECT" in sql for sql in session.statements) + assert any("REFRESH MATERIALIZED VIEW mv_test;" == sql for sql in session.statements) diff --git a/tests/backends/test_sqlite_backend.py b/tests/backends/test_sqlite_backend.py new file mode 100644 index 0000000..e7adc4f --- /dev/null +++ b/tests/backends/test_sqlite_backend.py @@ -0,0 +1,205 @@ +from __future__ import annotations + +import sqlite3 +from pathlib import Path + +import sqlalchemy as sa +import sqlalchemy.orm as so + +from orm_loader.backends import SQLiteBackend +from orm_loader.helpers.sqlite import attach_sqlite_bulk_load_pragmas + + +class _ComputedTable: + __table__ = sa.Table( + "target_table", + sa.MetaData(), + sa.Column("id", sa.Integer, primary_key=True), + sa.Column("name", sa.String), + sa.Column("slug", sa.String, sa.Computed("lower(name)")), + ) + + +class _FakeSession: + def __init__(self, scalar_result=1) -> None: + self.statements: list[str] = [] + self.scalar_result = scalar_result + + def execute(self, statement): + self.statements.append(str(statement)) + + class _Result: + def __init__(self, value): + self._value = value + + def scalar(self): + return self._value + + return _Result(self.scalar_result) + + +def test_sqlite_backend_identity_and_capabilities(): + backend = SQLiteBackend() + + assert backend.name == "sqlite" + assert backend.supports_dialect("sqlite") is True + assert backend.capabilities.supports_fast_load is False + assert backend.capabilities.supports_unlogged_staging is False + assert backend.capabilities.supports_fk_toggle is True + assert backend.capabilities.supports_materialized_views is False + assert backend.resolve_index_strategy("auto") == "keep" + + +def test_sqlite_backend_create_staging_table(): + backend = SQLiteBackend() + engine = sa.create_engine("sqlite:///:memory:", future=True) + session = so.Session(engine) + + try: + backend.create_staging_table(_ComputedTable, session, "_staging_target_table") + inspector = sa.inspect(engine) + assert inspector.has_table("_staging_target_table") is True + cols = inspector.get_columns("_staging_target_table") + assert [c["name"] for c in cols] == ["id", "name", "slug"] + assert all(c["nullable"] is True for c in cols) + finally: + session.close() + + +def test_sqlite_backend_drop_staging_table(): + backend = SQLiteBackend() + session = _FakeSession() + + backend.drop_staging_table(session, "_staging_target_table") + + assert session.statements == ['DROP TABLE IF EXISTS "_staging_target_table"'] + + +def test_sqlite_backend_fk_methods_emit_expected_sql(): + backend = SQLiteBackend() + session = _FakeSession() + + previous = backend.disable_fk_check(session) + enabled = backend.enable_fk_check(session) + backend.restore_fk_check(session, previous) + + assert previous == 1 + assert enabled == 1 + assert session.statements == [ + "PRAGMA foreign_keys", + "PRAGMA foreign_keys = OFF", + "PRAGMA foreign_keys", + "PRAGMA foreign_keys = ON", + "PRAGMA foreign_keys = 1", + ] + + +def test_sqlite_backend_merge_replace_single_pk(): + backend = SQLiteBackend() + session = _FakeSession() + + backend.merge_replace(_ComputedTable, session, "target_table", "_staging_target_table", ["id"]) + + sql = session.statements[0] + assert 'DELETE FROM "target_table"' in sql + assert 'SELECT "id" FROM "_staging_target_table"' in sql + + +def test_sqlite_backend_merge_replace_composite_pk(): + backend = SQLiteBackend() + session = _FakeSession() + + backend.merge_replace(_ComputedTable, session, "target_table", "_staging_target_table", ["id", "name"]) + + sql = session.statements[0] + assert 'WHERE EXISTS (' in sql + assert '"target_table"."id" = "_staging_target_table"."id"' in sql + assert '"target_table"."name" = "_staging_target_table"."name"' in sql + + +def test_sqlite_backend_merge_insert_excludes_computed_columns(): + backend = SQLiteBackend() + session = _FakeSession() + + backend.merge_insert(_ComputedTable, session, "target_table", "_staging_target_table") + + sql = session.statements[0] + assert 'INSERT INTO "target_table" ("id", "name")' in sql + assert 'SELECT "id", "name" FROM "_staging_target_table"' in sql + + +def test_sqlite_backend_merge_upsert_excludes_computed_columns(): + backend = SQLiteBackend() + session = _FakeSession() + + backend.merge_upsert(_ComputedTable, session, "target_table", "_staging_target_table", ["id"]) + + sql = session.statements[0] + assert 'INSERT OR IGNORE INTO "target_table" ("id", "name")' in sql + + +def test_sqlite_backend_materialized_view_methods_raise(): + backend = SQLiteBackend() + session = _FakeSession() + selectable = sa.select(sa.literal(1).label("n")) + + try: + backend.create_materialized_view(session, "mv_test", selectable) + except NotImplementedError as exc: + assert "does not support materialized views" in str(exc) + else: + raise AssertionError("Expected create_materialized_view() to raise NotImplementedError") + + try: + backend.refresh_materialized_view(session, "mv_test") + except NotImplementedError as exc: + assert "does not support materialized views" in str(exc) + else: + raise AssertionError("Expected refresh_materialized_view() to raise NotImplementedError") + + +def test_sqlite_backend_configures_bulk_load_pragmas(tmp_path: Path): + backend = SQLiteBackend() + db_path = tmp_path / "test.db" + engine = sa.create_engine(f"sqlite:///{db_path}", future=True) + backend.install_engine_hooks(engine) + + with engine.connect() as conn: + busy_timeout = conn.execute(sa.text("PRAGMA busy_timeout")).scalar_one() + journal_mode = conn.execute(sa.text("PRAGMA journal_mode")).scalar_one() + + assert busy_timeout == 60000 + assert str(journal_mode).lower() == "wal" + + +def test_sqlite_backend_restore_journal_mode(tmp_path: Path): + backend = SQLiteBackend() + db_path = tmp_path / "journal.db" + engine = sa.create_engine(f"sqlite:///{db_path}", future=True) + backend.install_engine_hooks(engine) + + with engine.begin() as conn: + conn.execute(sa.text("CREATE TABLE t (id INTEGER PRIMARY KEY, name TEXT)")) + conn.execute(sa.text("INSERT INTO t (name) VALUES ('x')")) + + engine.dispose() + backend.restore_journal_mode(db_path) + + with sqlite3.connect(db_path.resolve()) as conn: + journal_mode = conn.execute("PRAGMA journal_mode").fetchone()[0] + + assert str(journal_mode).lower() == "delete" + + +def test_attach_sqlite_bulk_load_pragmas_installs_backend_hook(tmp_path: Path): + db_path = tmp_path / "attached.db" + engine = sa.create_engine(f"sqlite:///{db_path}", future=True) + + attach_sqlite_bulk_load_pragmas(engine, busy_timeout_ms=45000) + + with engine.connect() as conn: + busy_timeout = conn.execute(sa.text("PRAGMA busy_timeout")).scalar_one() + journal_mode = conn.execute(sa.text("PRAGMA journal_mode")).scalar_one() + + assert busy_timeout == 45000 + assert str(journal_mode).lower() == "wal" diff --git a/tests/conftest.py b/tests/conftest.py index e12e9a3..d509cbf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,7 +16,7 @@ def session(engine): yield s -POSTGRES_URL = "postgresql+psycopg2://test:test@localhost:55432/test_db" +POSTGRES_URL = "postgresql+psycopg://test:test@localhost:55432/test_db" @pytest.fixture(scope="session") def pg_engine(): @@ -54,4 +54,3 @@ def pg_session(pg_engine): - diff --git a/tests/models.py b/tests/models.py index b581b87..630361a 100644 --- a/tests/models.py +++ b/tests/models.py @@ -14,6 +14,9 @@ class PandasLoaderTable(CSVLoadableTableInterface, Base): class SimpleTable(Base, CSVLoadableTableInterface): __tablename__ = "test_table" + __table_args__ = ( + sa.Index("ix_test_table_name", "name"), + ) id: so.Mapped[int] = so.mapped_column(sa.Integer, primary_key=True) name: so.Mapped[str] = so.mapped_column(sa.String, nullable=False) @@ -32,4 +35,3 @@ class CompositeTable(Base, CSVLoadableTableInterface): a: so.Mapped[int] = so.mapped_column(sa.Integer, primary_key=True) b: so.Mapped[int] = so.mapped_column(sa.Integer, primary_key=True) value: so.Mapped[str] = so.mapped_column(sa.String) - diff --git a/uv.lock b/uv.lock index 0ec9570..6073ac0 100644 --- a/uv.lock +++ b/uv.lock @@ -618,7 +618,7 @@ wheels = [ [[package]] name = "orm-loader" -version = "0.3.27" +version = "0.4.0" source = { editable = "." } dependencies = [ { name = "chardet" }, @@ -639,6 +639,9 @@ dev = [ { name = "requests" }, { name = "ruff" }, ] +postgres = [ + { name = "psycopg", extra = ["binary"] }, +] [package.metadata] requires-dist = [ @@ -649,6 +652,7 @@ requires-dist = [ { name = "mkdocstrings-python", marker = "extra == 'dev'", specifier = ">=2.0.1" }, { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.19.1" }, { name = "pandas", specifier = ">=2.3.3" }, + { name = "psycopg", extras = ["binary"], marker = "extra == 'postgres'", specifier = ">=3.2" }, { name = "pyarrow", specifier = ">=23.0.0" }, { name = "pygments", marker = "extra == 'dev'", specifier = ">=2.20.0" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=9.0.3" }, @@ -656,7 +660,7 @@ requires-dist = [ { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.14.11" }, { name = "sqlalchemy", specifier = ">=2.0.45" }, ] -provides-extras = ["dev"] +provides-extras = ["postgres", "dev"] [[package]] name = "packaging" @@ -750,6 +754,64 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] +[[package]] +name = "psycopg" +version = "3.3.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, + { name = "tzdata", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/db/2f/cb91e5502ec9de1de6f1b76cfbf69531932725361168bb06963620c77e2e/psycopg-3.3.4.tar.gz", hash = "sha256:e21207764952cff81b6b8bdacad9a3939f2793367fdac2987b3aac36a651b5bc", size = 165799, upload-time = "2026-05-01T23:31:55.179Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5c/e0/7b3dee031daae7743609ce3c746565d4a3ed7c2c186479eb48e34e838c64/psycopg-3.3.4-py3-none-any.whl", hash = "sha256:b6bbc25ccf05c8fad3b061d9db2ef0909a555171b84b07f29458a447253d679a", size = 213001, upload-time = "2026-05-01T23:20:50.816Z" }, +] + +[package.optional-dependencies] +binary = [ + { name = "psycopg-binary", marker = "implementation_name != 'pypy'" }, +] + +[[package]] +name = "psycopg-binary" +version = "3.3.4" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/95/7d/03818e13ba7f36de93573c93ee3482006d3dfa8b0f8d28df511bad0a1a92/psycopg_binary-3.3.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5ab28a2a7649df3b72e6b674b4c190e448e8e77cf496a65bd846472048de2089", size = 4591122, upload-time = "2026-05-01T23:27:56.162Z" }, + { url = "https://files.pythonhosted.org/packages/a5/b9/11b341edf8d54e2694726b273fe9652b254d989f4f63e3ac6816ad6b55f4/psycopg_binary-3.3.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6402a9d8146cf4b3974ded3fd28a971e83dc6a0333eb7822524a3aa20b546578", size = 4669943, upload-time = "2026-05-01T23:28:04.522Z" }, + { url = "https://files.pythonhosted.org/packages/8b/18/4665bacd65e7865b4372fcd8abb8b9186ada4b0025f8c2ca691b364a556c/psycopg_binary-3.3.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:580ae30a5f95ccd90008ec697d3ed6a4a2047a516407ad904283fa42086936e9", size = 5469697, upload-time = "2026-05-01T23:28:11.337Z" }, + { url = "https://files.pythonhosted.org/packages/7c/b1/b83136c6e510593d9b0c759ba5384337bc4ad82d19fda675adc4b2703c84/psycopg_binary-3.3.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e7510c37550f91a187e3660a8cc50d4b760f8c3b8b2f89ebc5698cd2c7f2c85d", size = 5152995, upload-time = "2026-05-01T23:28:20.529Z" }, + { url = "https://files.pythonhosted.org/packages/67/8d/a9821e2a648afe6091989929982a3b0f00b2631a859cb81379728f08fb75/psycopg_binary-3.3.4-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:77df19583501ea288eaf15ac0fe7ad01e6d8091a91d5c41df5c718f307d8e31b", size = 6738180, upload-time = "2026-05-01T23:28:30.654Z" }, + { url = "https://files.pythonhosted.org/packages/7e/58/2e349e8d23905dc2317b80ac65f48fb6f821a4777a4e994a60da91c4850f/psycopg_binary-3.3.4-cp312-cp312-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:018fbed325936da502feb546642c982dcc4b9ffdea32dfef78dbf3b7f7ad4070", size = 4978828, upload-time = "2026-05-01T23:28:37.277Z" }, + { url = "https://files.pythonhosted.org/packages/45/48/57b00d03b4721878326122a1f1e6b0a90b85bcaec56b5b2f8ea6cfa45235/psycopg_binary-3.3.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:17a21953a9e5ff3a16dab692625a3676e2f101db5e40072f39dbee2250194d68", size = 4509757, upload-time = "2026-05-01T23:28:43.078Z" }, + { url = "https://files.pythonhosted.org/packages/25/37/33b47d8c007df69aec500df5889767c4d313748e8e9e27a2fef8a6dabcee/psycopg_binary-3.3.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:eb05ee1c2b817d27c537333224c9e83c7afb86fe7296ba970990068baf819b16", size = 4190546, upload-time = "2026-05-01T23:28:50.016Z" }, + { url = "https://files.pythonhosted.org/packages/ca/c6/32b0835dbc2122617902b649d76a91c1e75406e76bf3d595b0c3bb5ffad6/psycopg_binary-3.3.4-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:773d573e11f437ce0bdb95b7c18dc58390494f96d43f8b45b9760436114f7652", size = 3926197, upload-time = "2026-05-01T23:28:55.55Z" }, + { url = "https://files.pythonhosted.org/packages/cd/68/d190ef0c0c5b16ded07831dabc8ddd412f4cdab07ec6e30ed38d9bda0e1f/psycopg_binary-3.3.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:71e55ccbdfae79a2ed9c6369c3008a3025817ff9d7e27b32a2d84e2a4267e66e", size = 4236627, upload-time = "2026-05-01T23:29:05.336Z" }, + { url = "https://files.pythonhosted.org/packages/25/8f/81dcbc2e8454b74d14881275ea45f00791052dac531a9fa8be1730d1685b/psycopg_binary-3.3.4-cp312-cp312-win_amd64.whl", hash = "sha256:494ca54901be8cf9eb7e02c25b731f2317c378efa44f43e8f9bd0e1184ae7be4", size = 3560782, upload-time = "2026-05-01T23:29:11.967Z" }, + { url = "https://files.pythonhosted.org/packages/09/43/13e9c406fbbf354580476e248a16b64802a376873ebe6339e30bb655572d/psycopg_binary-3.3.4-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:fbd1d4ed566895ad2d3bf4ddfd8bae90026930ddf29df3b9d91d32c8c47866a7", size = 4590377, upload-time = "2026-05-01T23:29:18.782Z" }, + { url = "https://files.pythonhosted.org/packages/22/be/2923cd7c3683e7afdecf4f10796a18de02f5c5ddc0969aa2ad0a8cdd3bbd/psycopg_binary-3.3.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:75a9067e236f9b9ae3535b66fe99bddb33d39c0de10112e49b9ab11eee53dc31", size = 4669023, upload-time = "2026-05-01T23:29:25.884Z" }, + { url = "https://files.pythonhosted.org/packages/96/a0/2c913d6fe13d6a8bd13597d36739bf47af063ad9399e402cfecab16f3c1e/psycopg_binary-3.3.4-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:b56b603ebcea8aa10b46228b8410ba7f13e7c2ee54389d4d9be0927fd8ce2a70", size = 5467423, upload-time = "2026-05-01T23:29:33.416Z" }, + { url = "https://files.pythonhosted.org/packages/e7/38/205d10bc1ad0df4a21c5c51659126bd3ea0ef98fcad1e852f78c249bb9c3/psycopg_binary-3.3.4-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c677c4ad433cb7150c8cd304a0769ae3bcfbe5ea0676eb53faa7b1443b16d0d3", size = 5151137, upload-time = "2026-05-01T23:29:42.013Z" }, + { url = "https://files.pythonhosted.org/packages/36/fc/f0381ddcd45eff3bb70dbca6823a996048d7f507b2ec3fc92c6fabc0fe87/psycopg_binary-3.3.4-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:26df2717e59c0473e4465a97dfb1b7afebaa479277870fd5784d1436470db47c", size = 6736671, upload-time = "2026-05-01T23:29:51.626Z" }, + { url = "https://files.pythonhosted.org/packages/95/40/fa545ae152c24327651e5624e4902121e808270be36c10b12e9939be09bc/psycopg_binary-3.3.4-cp313-cp313-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:1dc1f79fd16bb1f3f4421417a514607539f17804d95c7ed617265369d1981cae", size = 4979601, upload-time = "2026-05-01T23:29:56.961Z" }, + { url = "https://files.pythonhosted.org/packages/86/e4/2f8a47ee97f90cd2b933d0463081d35631ff419de2b8c984a5f369857de0/psycopg_binary-3.3.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:136f199a407b5348b9b857c504aff60c77622a28482e7195839ce1b51238c4cc", size = 4510513, upload-time = "2026-05-01T23:30:07.243Z" }, + { url = "https://files.pythonhosted.org/packages/0e/0e/94e842ff4a7f98ed162580ca2e8b8864b28c1e0350f2443f8ee47f821167/psycopg_binary-3.3.4-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:b6f5a29e9c775b9f12a1a717aa7a2c80f9e1db6f27ba44a5b59c80ac61d2ffcf", size = 4187243, upload-time = "2026-05-01T23:30:15.352Z" }, + { url = "https://files.pythonhosted.org/packages/d0/83/fc6c174b672e29b7de996ea77b6cbddf46c891751c3355f6974292baa6b4/psycopg_binary-3.3.4-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:ee17a2cf4943cde261adfad1bbc5bf38d6b3776d7afff74c7cabcbeaeb08c260", size = 3927347, upload-time = "2026-05-01T23:30:21.186Z" }, + { url = "https://files.pythonhosted.org/packages/e9/65/768364d4a97a15b1a7f47ba52688c1686f22941d8332a8398cefc468e25f/psycopg_binary-3.3.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5c4ab71be17bdca30cb34c34c4e1496e2f5d6f20c199c12bad226070b22ef9bf", size = 4236393, upload-time = "2026-05-01T23:30:26.211Z" }, + { url = "https://files.pythonhosted.org/packages/bd/3b/218efbc9e645becd80cdf651acda05f85cfe546b7a9c0458c7cbc8fe1f74/psycopg_binary-3.3.4-cp313-cp313-win_amd64.whl", hash = "sha256:dbfdb9b6cc79f31104a7b162a2b921b765fcc62af6c00540a167a8de47e4ed38", size = 3564592, upload-time = "2026-05-01T23:30:31.764Z" }, + { url = "https://files.pythonhosted.org/packages/48/a6/828c9185701dab71b234c2a76c38a08b098ebfec5020716b4e93807492b5/psycopg_binary-3.3.4-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:28b7398fdd19db3232c884fb24550bdfe951221f510e195e233299e4c9b78f97", size = 4607292, upload-time = "2026-05-01T23:30:38.962Z" }, + { url = "https://files.pythonhosted.org/packages/92/58/5b40dbc9d839045c9dae956960e4fb6d20bcabe6c59a2aa34fc3a371913f/psycopg_binary-3.3.4-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:1fbaa292a3c8bb61b45df1ad3da1908ccee7cb889db9425e3557d9e34e2a4829", size = 4687023, upload-time = "2026-05-01T23:30:47.227Z" }, + { url = "https://files.pythonhosted.org/packages/85/a9/793f0ac107a9003b48441d0d1f9f616d96e0f37458dd8dc12528ceff55fb/psycopg_binary-3.3.4-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:94596f9e7633ee3f6440711d43bb70aa31cc0a46a900ab8b4201a366ace5c9e7", size = 5486985, upload-time = "2026-05-01T23:30:55.517Z" }, + { url = "https://files.pythonhosted.org/packages/8f/26/42e8533497e2592334f68ec529cf5f840f7fa4e99575a4bb61aa184dbfbf/psycopg_binary-3.3.4-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:8c0056529e68dbe9184cd4019a1f3d8f3a4ead2f6fc7a5afcf27d3314edd1277", size = 5168745, upload-time = "2026-05-01T23:31:01.904Z" }, + { url = "https://files.pythonhosted.org/packages/15/af/b7151776cc08d5935d45c833ec818a9beb417cf7c08239af1aafbdae78ee/psycopg_binary-3.3.4-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2c09aad7051326e7603c14e50636db9c01f78272dc54b3accff03d46370461e6", size = 6761486, upload-time = "2026-05-01T23:31:14.511Z" }, + { url = "https://files.pythonhosted.org/packages/d0/ed/c92533b9124712d592cbf1cd6c76da933a2e0acea81dfe1fbe7e735f0cff/psycopg_binary-3.3.4-cp314-cp314-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:514404ed543efd620c85602b747df2a23cf1241b4067199e1a66f2d2757aaa41", size = 4997427, upload-time = "2026-05-01T23:31:20.901Z" }, + { url = "https://files.pythonhosted.org/packages/a2/23/ccadfd0de416aa188356daa199453af24087b042e296088706d190ae0295/psycopg_binary-3.3.4-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:46893c26858be12cc49ca4226ed6a60b4bfccadd946b3bebb783a60b38788228", size = 4533549, upload-time = "2026-05-01T23:31:26.204Z" }, + { url = "https://files.pythonhosted.org/packages/fd/a0/c8f43cee36386f7bc891ab41a9d31ea07cf9826038e732da79f26b1e5f34/psycopg_binary-3.3.4-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:df1d567fc430f6df15c9fcf67d87685fc49bdb325adc0db5af1adfb2f44eb5c9", size = 4210256, upload-time = "2026-05-01T23:31:33.884Z" }, + { url = "https://files.pythonhosted.org/packages/4e/2c/c1547871be3790676e8868b38655496422f94f0978dfb66b74bdba2f1676/psycopg_binary-3.3.4-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:6b9016b1714da4dd5ecaaa75b82098aa5a0b87854ce9b092e21c27c4ae23e014", size = 3946204, upload-time = "2026-05-01T23:31:39.626Z" }, + { url = "https://files.pythonhosted.org/packages/c4/b1/f6670f00fa7ea601584623f6c11602ab92117d83eaff885e0210f6de7418/psycopg_binary-3.3.4-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:47c656a8a7ba6eb0cff1801a4caaa9c8bdc12d03080e273aff1c8ac39971a77e", size = 4255811, upload-time = "2026-05-01T23:31:44.986Z" }, + { url = "https://files.pythonhosted.org/packages/eb/e6/5fff07a70d1f945ed90ae131c3bd76cab32beff7c58c6db15ad5820b6d1f/psycopg_binary-3.3.4-cp314-cp314-win_amd64.whl", hash = "sha256:c37e024c07308cd06cf3ec51bfd0e7f6157585a4d84d1bce4a7f5f7913719bf8", size = 3666849, upload-time = "2026-05-01T23:31:51.165Z" }, +] + [[package]] name = "pyarrow" version = "23.0.0" From 0ad18940d7209eca014bd3e7d776294c28d1b1d2 Mon Sep 17 00:00:00 2001 From: georgie Date: Fri, 15 May 2026 08:29:06 +1000 Subject: [PATCH 02/21] linting --- src/orm_loader/backends/resolve.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/orm_loader/backends/resolve.py b/src/orm_loader/backends/resolve.py index 9333c0e..0f237fd 100644 --- a/src/orm_loader/backends/resolve.py +++ b/src/orm_loader/backends/resolve.py @@ -1,8 +1,6 @@ from __future__ import annotations from typing import TYPE_CHECKING - -import sqlalchemy as sa import sqlalchemy.orm as so from .base import DatabaseBackend From 765eeebf4a9fede9e3a9fcbcb199fc86b23cebfd Mon Sep 17 00:00:00 2001 From: georgie Date: Fri, 15 May 2026 08:31:48 +1000 Subject: [PATCH 03/21] linting --- src/orm_loader/backends/resolve.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/orm_loader/backends/resolve.py b/src/orm_loader/backends/resolve.py index 0f237fd..fc3d6b5 100644 --- a/src/orm_loader/backends/resolve.py +++ b/src/orm_loader/backends/resolve.py @@ -17,7 +17,7 @@ ) -def _dialect_name(bindable: so.Session | "Engine" | "Connection") -> str: +def _dialect_name(bindable: "so.Session | Engine | Connection",) -> str: if isinstance(bindable, so.Session): bind = bindable.get_bind() return bind.dialect.name @@ -28,7 +28,7 @@ def _dialect_name(bindable: so.Session | "Engine" | "Connection") -> str: raise TypeError(f"Unsupported bindable type: {type(bindable)!r}") -def resolve_backend(bindable: so.Session | "Engine" | "Connection") -> DatabaseBackend: +def resolve_backend(bindable: "so.Session | Engine | Connection") -> DatabaseBackend: """ Resolve a concrete backend from a SQLAlchemy session, engine, or connection. """ From 228c847c9e0f573d21d90a17244bbed438e438f5 Mon Sep 17 00:00:00 2001 From: georgie Date: Fri, 15 May 2026 08:52:36 +1000 Subject: [PATCH 04/21] code review journal mode and session handling updates --- .gitignore | 3 +- CHANGELOG.md | 3 +- notebooks/01_setup_registry.ipynb | 207 --------------- notebooks/02_test_file_load.ipynb | 220 ---------------- notebooks/03_improve_load_perf.ipynb | 318 ------------------------ pyproject.toml | 1 - src/orm_loader/backends/postgres.py | 6 +- src/orm_loader/backends/sqlite.py | 16 +- src/orm_loader/helpers/sqlite.py | 8 +- tests/backends/test_postgres_backend.py | 54 ++++ tests/backends/test_sqlite_backend.py | 14 ++ 11 files changed, 96 insertions(+), 754 deletions(-) delete mode 100644 notebooks/01_setup_registry.ipynb delete mode 100644 notebooks/02_test_file_load.ipynb delete mode 100644 notebooks/03_improve_load_perf.ipynb diff --git a/.gitignore b/.gitignore index f898ecf..8475748 100644 --- a/.gitignore +++ b/.gitignore @@ -211,4 +211,5 @@ OMOP_CDM*.csv *.db .vscode/ .DS_Store -_temp/ \ No newline at end of file +_temp/ +notebooks/ \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index e973d48..3580cb4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -113,4 +113,5 @@ # 0.4.0 - update to handle psycopg (as opposed to psycopg2) cleanly - overall api cleanup with the goal of being more explicit about selection of specific db backends -- general typing cleanup \ No newline at end of file +- general typing cleanup +- removed example notebooks until they can be cleaned up with working use-cases according to updated api \ No newline at end of file diff --git a/notebooks/01_setup_registry.ipynb b/notebooks/01_setup_registry.ipynb deleted file mode 100644 index 1e8a679..0000000 --- a/notebooks/01_setup_registry.ipynb +++ /dev/null @@ -1,207 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "d4a7dfa5", - "metadata": {}, - "outputs": [], - "source": [ - "from orm_loader.registry import (\n", - " ModelRegistry,\n", - " ModelDescriptor,\n", - " TableSpec,\n", - " FieldSpec,\n", - " Validator,\n", - " ValidationIssue,\n", - " SeverityLevel,\n", - " ValidationRunner,\n", - " always_on_validators,\n", - ")\n", - "from pathlib import Path" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "9fec9cb5", - "metadata": {}, - "outputs": [], - "source": [ - "m = ModelRegistry(model_version = '5.4', model_name = 'CDM')" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "5685c951", - "metadata": {}, - "outputs": [], - "source": [ - "field_spec = Path('OMOP_CDMv5.4_Field_Level.csv')\n", - "table_spec = Path('OMOP_CDMv5.4_Table_Level.csv')\n", - "\n", - "m.load_table_specs(table_csv=table_spec, field_csv=field_spec)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "6efeec23", - "metadata": {}, - "outputs": [], - "source": [ - "m.discover_models('omop_alchemy.cdm.model')" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "9cb956ec", - "metadata": {}, - "outputs": [], - "source": [ - "runner = ValidationRunner(\n", - " validators=always_on_validators(),\n", - " fail_fast=False,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "ad537e7e", - "metadata": {}, - "outputs": [], - "source": [ - "report = runner.run(m)" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "e3d63142", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "MODEL v5.4: 0 error(s), 27 warning(s), 8 info\n" - ] - } - ], - "source": [ - "print(report.summary())\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "43589cc8", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "📦 cdm_source\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: cdm_source_name) Hint: ORM primary key not marked as primary key in specification\n", - "\n", - "📦 cohort\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: cohort_definition_id) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: subject_id) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ COMPOSITE_PRIMARY_KEY Hint: Composite primary key detected\n", - "\n", - "📦 cohort_definition\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: cohort_definition_id) Hint: ORM primary key not marked as primary key in specification\n", - "\n", - "📦 concept_ancestor\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: ancestor_concept_id) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: descendant_concept_id) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ COMPOSITE_PRIMARY_KEY Hint: Composite primary key detected\n", - "\n", - "📦 concept_relationship\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: concept_id_1) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: concept_id_2) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: relationship_id) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ COMPOSITE_PRIMARY_KEY Hint: Composite primary key detected\n", - "\n", - "📦 concept_synonym\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: concept_id) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: concept_synonym_name) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ COMPOSITE_PRIMARY_KEY Hint: Composite primary key detected\n", - "\n", - "📦 death\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: person_id) Hint: ORM primary key not marked as primary key in specification\n", - "\n", - "📦 drug_strength\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: drug_concept_id) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: ingredient_concept_id) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ COMPOSITE_PRIMARY_KEY Hint: Composite primary key detected\n", - "\n", - "📦 episode\n", - " ⚠️ FOREIGN_KEY_NOT_IN_SPEC (field: episode_parent_id) Hint: ORM defines FK but specification does not\n", - "\n", - "📦 episode_event\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: episode_event_field_concept_id) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: episode_id) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: event_id) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ COMPOSITE_PRIMARY_KEY Hint: Composite primary key detected\n", - "\n", - "📦 fact_relationship\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: domain_concept_id_1) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: domain_concept_id_2) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: fact_id_1) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: fact_id_2) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: relationship_concept_id) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ COMPOSITE_PRIMARY_KEY Hint: Composite primary key detected\n", - "\n", - "📦 relationship\n", - " ⚠️ FOREIGN_KEY_NOT_IN_SPEC (field: reverse_relationship_id) Hint: ORM defines FK but specification does not\n", - "\n", - "📦 source_to_concept_map\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: source_code) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: source_concept_id) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: source_vocabulary_id) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ COMPOSITE_PRIMARY_KEY Hint: Composite primary key detected\n" - ] - } - ], - "source": [ - "if not report.is_valid():\n", - " print(report.render_text_report())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1ef909ef", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "orm-loader (3.11.12)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.10" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/notebooks/02_test_file_load.ipynb b/notebooks/02_test_file_load.ipynb deleted file mode 100644 index eb2a64d..0000000 --- a/notebooks/02_test_file_load.ipynb +++ /dev/null @@ -1,220 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "c5d4e71b", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "897b6570", - "metadata": {}, - "outputs": [], - "source": [ - "import sqlalchemy as sa\n", - "import sqlalchemy.orm as so\n", - "from sqlalchemy.orm import DeclarativeBase, Session\n", - "from pathlib import Path\n", - "import pandas as pd\n", - "import tempfile\n", - "import logging\n", - "from orm_loader.tables.base import CSVLoadableTableInterface \n", - "\n", - "logging.basicConfig(level=logging.INFO)\n", - "\n", - "class Base(DeclarativeBase):\n", - " pass\n", - "\n", - "engine = sa.create_engine(\"sqlite:///test.db\", echo=False, future=True)\n", - "Base.metadata.bind = engine\n", - "\n", - "\n", - "class TestTable(Base, CSVLoadableTableInterface):\n", - " __tablename__ = \"test_table\"\n", - "\n", - " id: so.Mapped[int] = so.mapped_column(primary_key=True)\n", - " name: so.Mapped[str] = so.mapped_column(nullable=False)\n", - "\n", - "Base.metadata.create_all(engine)\n", - "\n", - "tmp = Path(tempfile.mkdtemp())\n", - "\n", - "csv_initial = tmp / \"test_table.csv\"\n", - "csv_replace = tmp / \"test_table_replace.csv\"\n", - "csv_empty = tmp / \"test_table_empty.csv\"\n", - "\n", - "pd.DataFrame(\n", - " [\n", - " {\"id\": 1, \"name\": \"alpha\"},\n", - " {\"id\": 2, \"name\": \"beta\"},\n", - " {\"id\": 3, \"name\": \"gamma\"},\n", - " ]\n", - ").to_csv(csv_initial, index=False, sep=\"\\t\")\n", - "\n", - "pd.DataFrame(\n", - " [\n", - " {\"id\": 2, \"name\": \"beta_updated\"},\n", - " {\"id\": 3, \"name\": \"gamma_updated\"},\n", - " ]\n", - ").to_csv(csv_replace, index=False, sep=\"\\t\")\n", - "\n", - "csv_empty.touch()\n" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "a62502c4", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:orm_loader.tables.base.loadable_table:Staging table _staging_test_table does not exist; recreating\n" - ] - }, - { - "data": { - "text/plain": [ - "[<__main__.TestTable at 0x120949d30>, <__main__.TestTable at 0x1166facf0>]" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "with Session(engine) as session:\n", - " inserted = TestTable.load_csv(\n", - " session,\n", - " csv_initial,\n", - " dedupe=False,\n", - " )\n", - " session.commit()\n", - "\n", - " rows = session.execute(\n", - " sa.select(TestTable).order_by(TestTable.id)\n", - " ).scalars().all()\n", - "\n", - "rows\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ba6337f0", - "metadata": {}, - "outputs": [], - "source": [ - "with Session(engine) as session:\n", - " rows = session.execute(\n", - " sa.select(TestTable).order_by(TestTable.id)\n", - " ).scalars().all()\n", - "rows" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a6956332", - "metadata": {}, - "outputs": [], - "source": [ - "with Session(engine) as session:\n", - " replaced = TestTable.replace_from_csv(\n", - " session,\n", - " csv_replace,\n", - " )\n", - " session.commit()\n", - "\n", - " rows = session.execute(\n", - " sa.select(TestTable).order_by(TestTable.id)\n", - " ).scalars().all()\n", - "\n", - "rows\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "29c775f5", - "metadata": {}, - "outputs": [], - "source": [ - "with engine.connect() as conn:\n", - " tables = conn.execute(\n", - " sa.text(\n", - " \"SELECT name FROM sqlite_master WHERE type='table'\"\n", - " )\n", - " ).fetchall()\n", - "\n", - "tables" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a6e8fc89", - "metadata": {}, - "outputs": [], - "source": [ - "with Session(engine) as session:\n", - " loaded = TestTable.replace_from_csv(\n", - " session,\n", - " csv_empty,\n", - " )\n", - " session.commit()\n", - "\n", - " rows = session.execute(\n", - " sa.select(TestTable).order_by(TestTable.id)\n", - " ).scalars().all()\n", - "\n", - " print(\"After empty file replace:\", [(r.id, r.name) for r in rows])\n", - " print(\"Rows loaded from empty file:\", loaded)\n", - "\n", - " # hard assertions (will raise if broken)\n", - " assert loaded == 0, \"Empty CSV should load 0 rows\"\n", - " assert [(r.id, r.name) for r in rows] == [\n", - " (1, \"alpha\"),\n", - " (2, \"beta_updated\"),\n", - " (3, \"gamma_updated\"),\n", - " ], \"Empty CSV must not modify existing rows\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "30eea280", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "orm-loader (3.12.10)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.10" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/notebooks/03_improve_load_perf.ipynb b/notebooks/03_improve_load_perf.ipynb deleted file mode 100644 index 079f1c8..0000000 --- a/notebooks/03_improve_load_perf.ipynb +++ /dev/null @@ -1,318 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "a251fa62", - "metadata": {}, - "outputs": [], - "source": [ - "import sqlalchemy as sa\n", - "import sqlalchemy.orm as so\n", - "from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker\n", - "from sqlalchemy.exc import IntegrityError\n", - "from pathlib import Path\n", - "import pandas as pd\n", - "import numpy as np\n", - "import tempfile, logging, os\n", - "from orm_loader.tables.base import CSVLoadableTableInterface \n", - "from orm_loader.loaders import LoaderContext\n", - "from orm_loader.loaders.loader_interface import ParquetLoader, LoaderInterface, PandasLoader\n", - "\n", - "from orm_loader.helpers import configure_logging, bootstrap, explain_sqlite_fk_error, bulk_load_context, configure_logging\n", - "\n", - "from omop_alchemy import get_engine_name, load_environment, TEST_PATH, ROOT_PATH\n", - "from omop_alchemy.cdm.model.vocabulary import (\n", - " Domain,\n", - " Vocabulary,\n", - " Concept_Class,\n", - " Relationship,\n", - " Concept,\n", - " Concept_Ancestor,\n", - " Concept_Relationship,\n", - " Concept_Synonym,\n", - ")\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "9173aad2", - "metadata": {}, - "outputs": [], - "source": [ - "logging.basicConfig(level=logging.INFO)\n", - "\n", - "class Base(DeclarativeBase):\n", - " pass\n", - "\n", - "engine_string = \"postgresql+psycopg2://airflow:airflow@0.0.0.0:5433/mosaiq\"\n", - "engine = sa.create_engine(engine_string, echo=False, future=True)\n", - "Base.metadata.bind = engine\n", - "\n", - "class TestTable(Base, CSVLoadableTableInterface):\n", - " __tablename__ = \"test_table\"\n", - "\n", - " id: so.Mapped[int] = so.mapped_column(primary_key=True)\n", - " name: so.Mapped[str] = so.mapped_column(nullable=False)\n", - "\n", - "Base.metadata.create_all(engine)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "8c2f3d9a", - "metadata": {}, - "outputs": [], - "source": [ - "tmp = Path(tempfile.mkdtemp())\n", - "\n", - "csv_initial = tmp / \"test_table.csv\"\n", - "csv_replace = tmp / \"test_table_replace.csv\"\n", - "csv_empty = tmp / \"test_table_empty.csv\"\n", - "\n", - "pd.DataFrame(\n", - " [\n", - " {\"id\": 1, \"name\": \"alpha\"},\n", - " {\"id\": 2, \"name\": \"beta\"},\n", - " {\"id\": 3, \"name\": \"gamma\"},\n", - " ]\n", - ").to_csv(csv_initial, index=False, sep=\"\\t\")\n", - "\n", - "pd.DataFrame(\n", - " [\n", - " {\"id\": 2, \"name\": \"beta_updated\"},\n", - " {\"id\": 3, \"name\": \"gamma_updated\"},\n", - " ]\n", - ").to_csv(csv_replace, index=False, sep=\"\\t\")\n", - "\n", - "csv_empty.touch()\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "53b8a52a", - "metadata": {}, - "outputs": [], - "source": [ - "session = Session(engine)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "5a683a1b", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:orm_loader.tables.base.loadable_table:Staging table _staging_test_table does not exist; recreating\n", - "INFO:orm_loader.loaders.loading_helpers:Bulk loading _staging_test_table via COPY (encoding=utf-8, delimiter=\t)\n" - ] - }, - { - "data": { - "text/plain": [ - "3" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "TestTable.load_csv(path=csv_initial, session=session)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "83505eb9", - "metadata": {}, - "outputs": [], - "source": [ - "session.commit()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "eba074e6", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2026-01-22 17:10:53,957 | INFO | sql_loader.omop_alchemy.config | Environment variables loaded from .env file\n", - "INFO:sql_loader.omop_alchemy.config:Environment variables loaded from .env file\n" - ] - } - ], - "source": [ - "ATHENA_INITIAL_LOAD = [\n", - " Domain,\n", - " Vocabulary,\n", - " Concept_Class,\n", - " Relationship,\n", - " Concept\n", - "]\n", - "\n", - "ATHENA_SUBSEQUENT_LOAD = [\n", - " Concept_Ancestor,\n", - " Concept_Relationship,\n", - " Concept_Synonym,\n", - "]\n", - "\n", - "configure_logging()\n", - "load_environment()" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "930f6572", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2026-01-22 17:10:54,687 | INFO | sql_loader.omop_alchemy.config | Database engine configured for schema 'cdm'\n", - "INFO:sql_loader.omop_alchemy.config:Database engine configured for schema 'cdm'\n" - ] - } - ], - "source": [ - "engine_string = get_engine_name('cdm')" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "de3d47e5", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:orm_loader.helpers.bootstrap:Bootstrapping schema (create=True)\n" - ] - } - ], - "source": [ - "engine = sa.create_engine(engine_string, future=True, echo=False)\n", - "bootstrap(engine, create=True)\n", - "\n", - "Session = sessionmaker(bind=engine, future=True)\n", - "session = Session()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7da32a8f", - "metadata": {}, - "outputs": [], - "source": [ - "source_path = Path(os.environ['SOURCE_PATH'])\n", - "\n", - "\n", - "p = ParquetLoader()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6ac0d9a5", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "23ae5a8a", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2026-01-22 17:12:05,251 | INFO | sql_loader.orm_loader.helpers.bulk | Disabled foreign key checks for bulk load\n", - "INFO:sql_loader.orm_loader.helpers.bulk:Disabled foreign key checks for bulk load\n", - "WARNING:orm_loader.tables.base.loadable_table:Staging table _staging_domain does not exist; recreating\n", - "INFO:orm_loader.loaders.loading_helpers:Bulk loading _staging_domain via COPY (encoding=utf-8, delimiter=\t)\n", - "WARNING:orm_loader.tables.base.loadable_table:Staging table _staging_vocabulary does not exist; recreating\n", - "INFO:orm_loader.loaders.loading_helpers:Bulk loading _staging_vocabulary via COPY (encoding=utf-8, delimiter=\t)\n", - "WARNING:orm_loader.tables.base.loadable_table:Staging table _staging_concept_class does not exist; recreating\n", - "INFO:orm_loader.loaders.loading_helpers:Bulk loading _staging_concept_class via COPY (encoding=utf-8, delimiter=\t)\n", - "WARNING:orm_loader.tables.base.loadable_table:Staging table _staging_relationship does not exist; recreating\n", - "INFO:orm_loader.loaders.loading_helpers:Bulk loading _staging_relationship via COPY (encoding=utf-8, delimiter=\t)\n", - "WARNING:orm_loader.tables.base.loadable_table:Staging table _staging_concept does not exist; recreating\n", - "INFO:orm_loader.loaders.loading_helpers:Bulk loading _staging_concept via COPY (encoding=utf-8, delimiter=\t)\n" - ] - } - ], - "source": [ - "with bulk_load_context(session):\n", - " for model in ATHENA_INITIAL_LOAD:\n", - " _ = model.load_csv(\n", - " session,\n", - " source_path / f\"{model.__tablename__.upper()}.csv\",\n", - " dedupe=False,\n", - " merge_strategy=\"upsert\",\n", - " loader=p,\n", - " )\n", - " session.commit()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e694d48c", - "metadata": {}, - "outputs": [], - "source": [ - "with bulk_load_context(session):\n", - " for model in ATHENA_SUBSEQUENT_LOAD:\n", - " _ = model.load_csv(\n", - " session,\n", - " source_path / f\"{model.__tablename__.upper()}.csv\",\n", - " dedupe=False,\n", - " chunksize=60_000_000, # parquet loader chunk is bytes not rows\n", - " merge_strategy=\"replace\",\n", - " loader=p,\n", - " )\n", - " session.commit()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "orm-loader (3.12.10)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.10" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/pyproject.toml b/pyproject.toml index 030cc52..cd6ba20 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,6 @@ postgres = [ dev = [ "pytest>=9.0.3", "mypy>=1.19.1", - "pytest>=9.0.3", "ruff>=0.14.11", "mkdocs-material>=9.7.1", "mkdocstrings-python>=2.0.1", diff --git a/src/orm_loader/backends/postgres.py b/src/orm_loader/backends/postgres.py index 9dbdc4a..4fbe2fa 100644 --- a/src/orm_loader/backends/postgres.py +++ b/src/orm_loader/backends/postgres.py @@ -190,18 +190,20 @@ def refresh_materialized_view( @contextmanager def engine_with_replica_role(self, engine: "Engine"): - @sa.event.listens_for(engine, "connect") # type: ignore[arg-type] def _set_replica_role( dbapi_conn: sa.engine.interfaces.DBAPIConnection, _, ) -> None: cur = dbapi_conn.cursor() - cur.execute("SET session_replication_role = replica") + cur.execute("SET session_replication_role = 'replica'") cur.close() + sa.event.listen(engine, "connect", _set_replica_role) + try: yield engine finally: + sa.event.remove(engine, "connect", _set_replica_role) with engine.connect() as conn: conn = conn.execution_options(isolation_level="AUTOCOMMIT") conn.execute(sa.text("SET session_replication_role = DEFAULT")) diff --git a/src/orm_loader/backends/sqlite.py b/src/orm_loader/backends/sqlite.py index 54dcd45..4ae1c70 100644 --- a/src/orm_loader/backends/sqlite.py +++ b/src/orm_loader/backends/sqlite.py @@ -19,6 +19,9 @@ logger = logging.getLogger(__name__) +VALID_SQLITE_JOURNAL_MODES = frozenset( + {"DELETE", "TRUNCATE", "PERSIST", "MEMORY", "WAL", "OFF"} +) class SQLiteBackend(DatabaseBackend): @@ -30,9 +33,19 @@ def __init__( defer_foreign_keys: bool = True, ) -> None: self.busy_timeout_ms = busy_timeout_ms - self.journal_mode = journal_mode + self.journal_mode = self._validate_journal_mode(journal_mode) self.defer_foreign_keys = defer_foreign_keys + @staticmethod + def _validate_journal_mode(journal_mode: str) -> str: + normalised = journal_mode.strip().upper() + if normalised not in VALID_SQLITE_JOURNAL_MODES: + raise ValueError( + "Unsupported SQLite journal_mode " + f"{journal_mode!r}. Expected one of: {sorted(VALID_SQLITE_JOURNAL_MODES)}" + ) + return normalised + @property def name(self) -> str: return "sqlite" @@ -198,6 +211,7 @@ def configure_dbapi_connection(self, dbapi_connection: sa.engine.interfaces.DBA cursor = dbapi_connection.cursor() cursor.execute(f"PRAGMA busy_timeout = {self.busy_timeout_ms}") cursor.execute(f"PRAGMA journal_mode = {self.journal_mode}") + cursor.execute("PRAGMA foreign_keys = ON;") if self.defer_foreign_keys: cursor.execute("PRAGMA defer_foreign_keys = ON;") cursor.close() diff --git a/src/orm_loader/helpers/sqlite.py b/src/orm_loader/helpers/sqlite.py index b27ce18..1c26091 100644 --- a/src/orm_loader/helpers/sqlite.py +++ b/src/orm_loader/helpers/sqlite.py @@ -15,7 +15,8 @@ def enable_sqlite_foreign_keys( This helper is kept for compatibility with older event-hook setups. It delegates to ``SQLiteBackend.configure_dbapi_connection()``, - which may apply more than just foreign-key settings. + which enables foreign-key enforcement and may apply more than just + foreign-key settings. """ del connection_record SQLiteBackend().configure_dbapi_connection(dbapi_connection) @@ -31,8 +32,9 @@ def attach_sqlite_bulk_load_pragmas( """ Install SQLite connect hooks aimed at heavy local write workloads. - The hook currently sets ``busy_timeout`` and journal mode, and can - also enable deferred foreign-key checking for the connection. + The hook currently sets ``busy_timeout``, journal mode, and foreign-key + enforcement, and can also enable deferred foreign-key checking for the + connection. """ SQLiteBackend( busy_timeout_ms=busy_timeout_ms, diff --git a/tests/backends/test_postgres_backend.py b/tests/backends/test_postgres_backend.py index 8fafa0a..9051f6e 100644 --- a/tests/backends/test_postgres_backend.py +++ b/tests/backends/test_postgres_backend.py @@ -137,3 +137,57 @@ def test_postgres_backend_materialized_view_methods_emit_expected_sql(): assert any("CREATE MATERIALIZED VIEW IF NOT EXISTS mv_test as SELECT" in sql for sql in session.statements) assert any("REFRESH MATERIALIZED VIEW mv_test;" == sql for sql in session.statements) + + +def test_postgres_backend_engine_with_replica_role_unregisters_listener(monkeypatch): + backend = PostgresBackend() + events: list[tuple[str, object, str]] = [] + statements: list[str] = [] + + class _Result: + def scalar(self): + return "origin" + + class _Conn: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def execution_options(self, **kwargs): + return self + + def execute(self, statement): + sql = str(statement.compile(dialect=postgresql.dialect())) + statements.append(sql) + return _Result() + + class _Engine: + def connect(self): + events.append(("connect", self, "connect")) + return _Conn() + + engine = _Engine() + + def _listen(target, name, fn) -> None: + events.append(("listen", target, name)) + + def _remove(target, name, fn) -> None: + events.append(("remove", target, name)) + + monkeypatch.setattr(sa.event, "listen", _listen) + monkeypatch.setattr(sa.event, "remove", _remove) + + with backend.engine_with_replica_role(engine): + pass + + assert events == [ + ("listen", engine, "connect"), + ("remove", engine, "connect"), + ("connect", engine, "connect"), + ] + assert statements == [ + "SET session_replication_role = DEFAULT", + "SHOW session_replication_role", + ] diff --git a/tests/backends/test_sqlite_backend.py b/tests/backends/test_sqlite_backend.py index e7adc4f..e7d33a2 100644 --- a/tests/backends/test_sqlite_backend.py +++ b/tests/backends/test_sqlite_backend.py @@ -48,6 +48,7 @@ def test_sqlite_backend_identity_and_capabilities(): assert backend.capabilities.supports_fk_toggle is True assert backend.capabilities.supports_materialized_views is False assert backend.resolve_index_strategy("auto") == "keep" + assert backend.journal_mode == "WAL" def test_sqlite_backend_create_staging_table(): @@ -167,9 +168,11 @@ def test_sqlite_backend_configures_bulk_load_pragmas(tmp_path: Path): with engine.connect() as conn: busy_timeout = conn.execute(sa.text("PRAGMA busy_timeout")).scalar_one() journal_mode = conn.execute(sa.text("PRAGMA journal_mode")).scalar_one() + foreign_keys = conn.execute(sa.text("PRAGMA foreign_keys")).scalar_one() assert busy_timeout == 60000 assert str(journal_mode).lower() == "wal" + assert foreign_keys == 1 def test_sqlite_backend_restore_journal_mode(tmp_path: Path): @@ -200,6 +203,17 @@ def test_attach_sqlite_bulk_load_pragmas_installs_backend_hook(tmp_path: Path): with engine.connect() as conn: busy_timeout = conn.execute(sa.text("PRAGMA busy_timeout")).scalar_one() journal_mode = conn.execute(sa.text("PRAGMA journal_mode")).scalar_one() + foreign_keys = conn.execute(sa.text("PRAGMA foreign_keys")).scalar_one() assert busy_timeout == 45000 assert str(journal_mode).lower() == "wal" + assert foreign_keys == 1 + + +def test_sqlite_backend_rejects_invalid_journal_mode(): + try: + SQLiteBackend(journal_mode="wal; drop table x;") + except ValueError as exc: + assert "Unsupported SQLite journal_mode" in str(exc) + else: + raise AssertionError("Expected invalid journal_mode to raise ValueError") From 1da053abd49605409a3ae298e4df01d75511f271 Mon Sep 17 00:00:00 2001 From: Nico Loesch Date: Mon, 18 May 2026 05:10:27 +0000 Subject: [PATCH 05/21] Add test cases for dialects and index_options tackling #3 --- src/orm_loader/backends/postgres.py | 5 +- tests/loaders/test_loader_e2e.py | 77 +++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 2 deletions(-) diff --git a/src/orm_loader/backends/postgres.py b/src/orm_loader/backends/postgres.py index 4fbe2fa..e8823ee 100644 --- a/src/orm_loader/backends/postgres.py +++ b/src/orm_loader/backends/postgres.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any import sqlalchemy as sa import sqlalchemy.orm as so +import sqlalchemy.event as sae from .base import BackendCapabilities, DatabaseBackend from ..loaders.loading_helpers import quick_load_pg @@ -198,12 +199,12 @@ def _set_replica_role( cur.execute("SET session_replication_role = 'replica'") cur.close() - sa.event.listen(engine, "connect", _set_replica_role) + sae.listen(engine, "connect", _set_replica_role) try: yield engine finally: - sa.event.remove(engine, "connect", _set_replica_role) + sae.remove(engine, "connect", _set_replica_role) with engine.connect() as conn: conn = conn.execution_options(isolation_level="AUTOCOMMIT") conn.execute(sa.text("SET session_replication_role = DEFAULT")) diff --git a/tests/loaders/test_loader_e2e.py b/tests/loaders/test_loader_e2e.py index 697d601..8a60b89 100644 --- a/tests/loaders/test_loader_e2e.py +++ b/tests/loaders/test_loader_e2e.py @@ -1,4 +1,5 @@ import sqlalchemy as sa +import sqlalchemy.event as sae import sqlalchemy.orm as so from sqlalchemy.orm import Session from pathlib import Path @@ -417,6 +418,82 @@ class TextTable2(Base, CSVLoadableTableInterface): rows = session.execute(sa.select(TextTable2)).scalars().all() assert rows[0].name == "foo\tbar" + +# --- index_strategy tests --- + +def _make_ddl_tracker(engine): + """Return a list that is populated with DROP/CREATE INDEX statements as they execute.""" + ddl_log: list[str] = [] + + @sae.listens_for(engine, "before_cursor_execute") + def _capture(conn, cursor, statement, parameters, context, executemany): + upper = statement.strip().upper() + if upper.startswith("DROP INDEX") or upper.startswith("CREATE INDEX"): + ddl_log.append(statement.strip()) + + return ddl_log + + +def test_auto_strategy_keeps_indices_on_sqlite(session, engine, tmp_csv_dir): + """On SQLite, 'auto' resolves to 'keep' — no index DDL should be emitted.""" + ddl_log = _make_ddl_tracker(engine) + csv_path = tmp_csv_dir / "test_table.csv" + pd.DataFrame([{"id": 1, "name": "alpha"}, {"id": 2, "name": "beta"}]).to_csv( + csv_path, index=False, sep="\t" + ) + + SimpleTable.load_csv(session, csv_path, loader=PandasLoader(), index_strategy="auto") # type: ignore + session.commit() + + assert not any("DROP INDEX" in s.upper() for s in ddl_log) + assert not any("CREATE INDEX" in s.upper() for s in ddl_log) + inspector = sa.inspect(engine) + inspector.clear_cache() + assert "ix_test_table_name" in {idx["name"] for idx in inspector.get_indexes("test_table")} + + +def test_explicit_keep_preserves_indices(session, engine, tmp_csv_dir): + """Explicit 'keep' emits no index DDL regardless of dialect.""" + ddl_log = _make_ddl_tracker(engine) + csv_path = tmp_csv_dir / "test_table.csv" + pd.DataFrame([{"id": 1, "name": "alpha"}]).to_csv(csv_path, index=False, sep="\t") + + SimpleTable.load_csv(session, csv_path, loader=PandasLoader(), index_strategy="keep") # type: ignore + session.commit() + + assert not any("DROP INDEX" in s.upper() for s in ddl_log) + inspector = sa.inspect(engine) + inspector.clear_cache() + assert "ix_test_table_name" in {idx["name"] for idx in inspector.get_indexes("test_table")} + + +def test_explicit_drop_rebuild_on_sqlite_restores_index(session, engine, tmp_csv_dir): + """Explicit 'drop_rebuild' drops then restores the index even on SQLite.""" + ddl_log = _make_ddl_tracker(engine) + csv_path = tmp_csv_dir / "test_table.csv" + pd.DataFrame([{"id": 1, "name": "alpha"}, {"id": 2, "name": "beta"}]).to_csv( + csv_path, index=False, sep="\t" + ) + + SimpleTable.load_csv(session, csv_path, loader=PandasLoader(), index_strategy="drop_rebuild") # type: ignore + session.commit() + + assert any("DROP INDEX" in s.upper() for s in ddl_log) + assert any("CREATE INDEX" in s.upper() for s in ddl_log) + inspector = sa.inspect(engine) + inspector.clear_cache() + assert "ix_test_table_name" in {idx["name"] for idx in inspector.get_indexes("test_table")} + + +def test_invalid_index_strategy_raises(session, tmp_csv_dir): + """An unrecognised strategy value raises ValueError before any DB work.""" + csv_path = tmp_csv_dir / "test_table.csv" + pd.DataFrame([{"id": 1, "name": "alpha"}]).to_csv(csv_path, index=False, sep="\t") + + with pytest.raises(ValueError, match="Unknown index_strategy"): + SimpleTable.load_csv(session, csv_path, loader=PandasLoader(), index_strategy="not-valid") # type: ignore + + # from hypothesis import given, strategies as st # from sqlalchemy.orm import declarative_base # from pathlib import Path From be353f96b5118dd4091f1577ca726aa3df0eba2e Mon Sep 17 00:00:00 2001 From: gkennos Date: Mon, 18 May 2026 22:10:34 +1000 Subject: [PATCH 06/21] adding infer quote mode functionality --- src/orm_loader/loaders/loading_helpers.py | 70 ++++++++++++++++++++++- tests/loaders/test_helpers.py | 34 ++++++++++- 2 files changed, 100 insertions(+), 4 deletions(-) diff --git a/src/orm_loader/loaders/loading_helpers.py b/src/orm_loader/loaders/loading_helpers.py index 353d456..b1f41a9 100644 --- a/src/orm_loader/loaders/loading_helpers.py +++ b/src/orm_loader/loaders/loading_helpers.py @@ -1,6 +1,7 @@ from __future__ import annotations from pathlib import Path import chardet +import csv as _csv import sqlalchemy as sa import sqlalchemy.orm as so import logging @@ -86,6 +87,66 @@ def infer_delim(file): return '\t' return ',' + +def infer_quote_mode( + path: Path, + delimiter: str, + encoding: str = "utf-8", + sample_rows: int = 200, +) -> str: + """Return 'csv' or 'literal' by comparing column-count consistency under both + quoting interpretations across a sample of rows. + + - 'csv' → standard RFC-4180 quoting; surrounding double-quotes are stripped + and embedded delimiters/newlines inside quotes are preserved. + - 'literal' → double-quote has no special meaning; every byte is stored as-is. + + Defaults to 'csv' when both modes produce identical output (no quoting in play) + or when the evidence is tied. Callers can always override by passing an + explicit value instead of relying on auto-detection. + """ + with open(path, encoding=encoding, errors="replace", newline="") as f: + lines = [f.readline() for _ in range(sample_rows + 1)] + + raw = "".join(ln for ln in lines if ln) + if not raw: + return "csv" + + try: + rows_csv = list(_csv.reader(io.StringIO(raw), delimiter=delimiter)) + except _csv.Error: + return "literal" + + try: + rows_lit = list( + _csv.reader(io.StringIO(raw), delimiter=delimiter, quoting=_csv.QUOTE_NONE) + ) + except _csv.Error: + return "csv" + + if not rows_csv: + return "csv" + + ncols = len(rows_csv[0]) + if ncols <= 1: + return "csv" + + # No difference between modes → no quoting is active, csv is the safe default + if rows_csv == rows_lit: + return "csv" + + data_csv = rows_csv[1:] + data_lit = rows_lit[1:] if len(rows_lit) > 1 else [] + + if not data_csv: + return "csv" + + csv_ok = sum(1 for r in data_csv if len(r) == ncols) + lit_ok = sum(1 for r in data_lit if len(r) == ncols) + + # Prefer csv on a tie; only choose literal when it is strictly more consistent + return "literal" if lit_ok > csv_ok else "csv" + def arrow_drop_duplicates( table: pa.Table, pk_names: list[str], @@ -169,15 +230,18 @@ def quick_load_pg( path: Path, session: so.Session, tablename: str, - quote_mode: str = "csv", + quote_mode: str = "auto", ) -> int: - raw_conn = session.connection().connection + raw_conn = session.connection().connection if not hasattr(raw_conn, "cursor"): raise RuntimeError("Expected DB-API connection for COPY") - + encoding = infer_encoding(path)['encoding'] or 'utf-8' delimiter = infer_delim(path) + if quote_mode == "auto": + quote_mode = infer_quote_mode(path, delimiter=delimiter, encoding=encoding) + logger.info(f"Auto-detected quote_mode={quote_mode!r} for {path.name}") if quote_mode == "csv": copy_options = f""" FORMAT csv, diff --git a/tests/loaders/test_helpers.py b/tests/loaders/test_helpers.py index 9a55de2..da907d1 100644 --- a/tests/loaders/test_helpers.py +++ b/tests/loaders/test_helpers.py @@ -1,5 +1,5 @@ from orm_loader.loaders.data_classes import ColumnCastingStats, TableCastingStats -from orm_loader.loaders.loading_helpers import infer_delim, infer_encoding +from orm_loader.loaders.loading_helpers import infer_delim, infer_encoding, infer_quote_mode def test_column_casting_stats_records_examples(): stats = ColumnCastingStats() @@ -35,3 +35,35 @@ def test_infer_encoding_utf8(tmp_path): p.write_text("hello") enc = infer_encoding(p).get("encoding") or "" assert enc.lower() in {"utf-8", "ascii"} + + +def test_infer_quote_mode_unquoted_tsv_returns_csv(tmp_path): + # No quotes anywhere: both modes identical, csv is the safe default + p = tmp_path / "x.csv" + p.write_text("id\tname\tvalue\n1\tAlice\t10\n2\tBob\t20\n") + assert infer_quote_mode(p, delimiter="\t") == "csv" + + +def test_infer_quote_mode_rfc4180_quoted_field_returns_csv(tmp_path): + # Athena-style: quoted concept_name at the varchar(255) boundary, + # no embedded delimiter — the column-count tie-break must favour csv + p = tmp_path / "x.csv" + long_name = "A" * 255 + p.write_text(f'id\tname\n1\t"{long_name}"\n2\tnormal\n') + assert infer_quote_mode(p, delimiter="\t") == "csv" + + +def test_infer_quote_mode_embedded_delimiter_in_quoted_field_returns_csv(tmp_path): + # Quoted field contains the delimiter: csv mode keeps column count consistent, + # literal mode splits on the embedded tab and produces ragged rows + p = tmp_path / "x.csv" + p.write_text('id\tname\tval\n1\t"foo\tbar"\t99\n2\tbaz\t0\n') + assert infer_quote_mode(p, delimiter="\t") == "csv" + + +def test_infer_quote_mode_unbalanced_quote_returns_literal(tmp_path): + # Unbalanced leading quote breaks CSV parsing: literal mode produces + # consistent 2-column rows while csv mode does not + p = tmp_path / "x.csv" + p.write_text('id\tname\n1\t"open\n2\t"open\n3\t"open\n') + assert infer_quote_mode(p, delimiter="\t") == "literal" From 7f2e2ed26c9f1daa25ba6f3952e91e782c0edba2 Mon Sep 17 00:00:00 2001 From: Nico Loesch Date: Mon, 18 May 2026 22:54:16 +0000 Subject: [PATCH 07/21] Adapted tests with improved readability and interfaces --- pyproject.toml | 3 + src/orm_loader/tables/typing.py | 15 +-- tests/backends/test_base_backend.py | 72 ++++++++++---- tests/backends/test_postgres_backend.py | 53 +++++++---- tests/backends/test_sqlite_backend.py | 53 ++++++----- tests/conftest.py | 85 +++++++++++++++-- tests/loaders/test_dedupe.py | 11 ++- tests/loaders/test_loader_e2e.py | 119 +++++++++++------------- tests/loaders/test_parquet_loader.py | 11 ++- tests/loaders/test_pg_loader.py | 12 ++- tests/pytest.ini | 3 - 11 files changed, 285 insertions(+), 152 deletions(-) delete mode 100644 tests/pytest.ini diff --git a/pyproject.toml b/pyproject.toml index cd6ba20..f801f4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,6 +69,9 @@ python_files = ["test_*.py"] python_classes = ["Test*"] python_functions = ["test_*"] addopts = "-ra" +markers = [ + "postgres: requires a running Postgres instance (set TEST_POSTGRES_URL)", +] [tool.pyright] reportMissingTypeStubs = false \ No newline at end of file diff --git a/src/orm_loader/tables/typing.py b/src/orm_loader/tables/typing.py index 2df700a..bda4751 100644 --- a/src/orm_loader/tables/typing.py +++ b/src/orm_loader/tables/typing.py @@ -76,13 +76,14 @@ def load_staging(cls: Type["CSVTableProtocol"], loader: "LoaderInterface", loade @classmethod def load_csv( - cls, - session: so.Session, - path: Path, - *, - normalise: bool = True, - dedupe: bool = False, - chunksize: int | None = None, + cls, + session: so.Session, + path: Path, + *, + loader: Optional["LoaderInterface"] = None, + normalise: bool = True, + dedupe: bool = False, + chunksize: int | None = None, merge_strategy: str = "replace", quote_mode: str = "csv", index_strategy: str = "auto", diff --git a/tests/backends/test_base_backend.py b/tests/backends/test_base_backend.py index dc81d4d..44b9502 100644 --- a/tests/backends/test_base_backend.py +++ b/tests/backends/test_base_backend.py @@ -3,16 +3,29 @@ import importlib import importlib.abc import sys +from importlib.machinery import ModuleSpec +from types import ModuleType +from typing import TYPE_CHECKING, Sequence, Type, cast, Any import pytest import sqlalchemy as sa import sqlalchemy.orm as so +from sqlalchemy.engine import Connection, Engine from orm_loader.backends import BackendCapabilities, DatabaseBackend, resolve_backend +if TYPE_CHECKING: + from orm_loader.loaders.data_classes import LoaderContext + from orm_loader.tables.typing import CSVTableProtocol + class _BlockPsycopg(importlib.abc.MetaPathFinder): - def find_spec(self, fullname, path=None, target=None): + def find_spec( + self, + fullname: str, + path: Sequence[str] | None = None, + target: ModuleType | None = None, + ) -> ModuleSpec | None: if fullname == "psycopg" or fullname.startswith("psycopg."): raise ModuleNotFoundError("No module named 'psycopg'") return None @@ -37,36 +50,60 @@ def capabilities(self) -> BackendCapabilities: supports_fk_toggle=True, ) - def create_staging_table(self, table_cls, session, staging_name) -> None: + def create_staging_table( + self, table_cls: Type[CSVTableProtocol], session: so.Session, staging_name: str + ) -> None: return None - def drop_staging_table(self, session, staging_name) -> None: + def drop_staging_table(self, session: so.Session, staging_name: str) -> None: return None - def merge_replace(self, table_cls, session, target_name, staging_name, pk_cols) -> None: + def merge_replace( + self, + table_cls: Type[CSVTableProtocol], + session: so.Session, + target_name: str, + staging_name: str, + pk_cols: list[str], + ) -> None: return None - def merge_upsert(self, table_cls, session, target_name, staging_name, pk_cols) -> None: + def merge_upsert( + self, + table_cls: Type[CSVTableProtocol], + session: so.Session, + target_name: str, + staging_name: str, + pk_cols: list[str], + ) -> None: return None - def merge_insert(self, table_cls, session, target_name, staging_name) -> None: + def merge_insert( + self, + table_cls: Type[CSVTableProtocol], + session: so.Session, + target_name: str, + staging_name: str, + ) -> None: return None - def disable_fk_check(self, session) -> str | int: + def disable_fk_check(self, session: so.Session) -> str | int: self.calls.append(("disable_fk_check", session)) return "enabled" - def enable_fk_check(self, session) -> str | int: + def enable_fk_check(self, session: so.Session) -> str | int: self.calls.append(("enable_fk_check", session)) return "disabled" - def restore_fk_check(self, session, previous_state: str | int) -> None: + def restore_fk_check(self, session: so.Session, previous_state: str | int) -> None: self.calls.append(("restore_fk_check", previous_state)) - def create_materialized_view(self, bind, name: str, selectable: sa.sql.Select) -> None: + def create_materialized_view( + self, bind: Engine | Connection, name: str, selectable: sa.sql.Select[Any] + ) -> None: return None - def refresh_materialized_view(self, bind, name: str) -> None: + def refresh_materialized_view(self, bind: Engine | Connection, name: str) -> None: return None @@ -80,6 +117,9 @@ class _ComputedTable: ) +_ComputedTableCls = cast("Type[CSVTableProtocol]", _ComputedTable) + + def test_backend_capabilities_defaults(): caps = BackendCapabilities() @@ -91,7 +131,7 @@ def test_backend_capabilities_defaults(): def test_database_backend_is_abstract(): with pytest.raises(TypeError): - DatabaseBackend() + DatabaseBackend() # type: ignore def test_fake_backend_can_implement_contract(): @@ -105,9 +145,9 @@ def test_fake_backend_can_implement_contract(): assert backend.supports_dialect("sqlite") is False assert backend.resolve_index_strategy("auto") == "drop_rebuild" assert backend.resolve_index_strategy("keep") == "keep" - assert backend.load_staging_fast(None, "staging") is None + assert backend.load_staging_fast(cast("LoaderContext", None), "staging") is None - with backend.merge_context(None, None): + with backend.merge_context(cast("Type[CSVTableProtocol]", None), cast(so.Session, None)): pass @@ -141,7 +181,7 @@ def test_resolve_index_strategy_raises_for_invalid_value(): def test_insertable_column_names_exclude_computed_columns(): backend = FakeBackend() - assert backend._insertable_column_names(_ComputedTable) == ["id", "name"] + assert backend._insertable_column_names(_ComputedTableCls) == ["id", "name"] def test_bulk_load_context_toggles_fk_and_restores(session): @@ -219,7 +259,7 @@ class dialect: name = "unknown" with pytest.raises(NotImplementedError, match="No backend registered"): - resolve_backend(_Unknown()) + resolve_backend(cast(Engine, _Unknown())) def test_backends_import_does_not_require_psycopg(): diff --git a/tests/backends/test_postgres_backend.py b/tests/backends/test_postgres_backend.py index 9051f6e..cee40fc 100644 --- a/tests/backends/test_postgres_backend.py +++ b/tests/backends/test_postgres_backend.py @@ -1,10 +1,18 @@ from __future__ import annotations +import sqlalchemy.event as sae +from typing import TYPE_CHECKING, Type, cast + import sqlalchemy as sa +import sqlalchemy.orm as so from sqlalchemy.dialects import postgresql +from sqlalchemy.engine import Connection, Engine from orm_loader.backends import PostgresBackend +if TYPE_CHECKING: + from orm_loader.tables.typing import CSVTableProtocol + class _ComputedTable: __table__ = sa.Table( @@ -42,6 +50,17 @@ def commit(self) -> None: self.commits += 1 +_ComputedTableCls = cast("Type[CSVTableProtocol]", _ComputedTable) + + +def _sess(s: _FakeSession) -> so.Session: + return cast(so.Session, s) + + +def _as_engine(s: _FakeSession) -> Engine | Connection: + return cast(Engine, s) + + def test_postgres_backend_identity_and_capabilities(): backend = PostgresBackend() @@ -57,7 +76,7 @@ def test_postgres_backend_create_staging_table_drops_computed_columns(): backend = PostgresBackend() session = _FakeSession() - backend.create_staging_table(_ComputedTable, session, "_staging_target_table") + backend.create_staging_table(_ComputedTableCls, _sess(session), "_staging_target_table") assert any('DROP TABLE IF EXISTS "_staging_target_table"' in sql for sql in session.statements) assert any('CREATE UNLOGGED TABLE "_staging_target_table"' in sql for sql in session.statements) @@ -69,7 +88,7 @@ def test_postgres_backend_drop_staging_table(): backend = PostgresBackend() session = _FakeSession() - backend.drop_staging_table(session, "_staging_target_table") + backend.drop_staging_table(_sess(session), "_staging_target_table") assert session.statements == ['DROP TABLE IF EXISTS "_staging_target_table"'] @@ -78,9 +97,9 @@ def test_postgres_backend_fk_methods_emit_expected_sql(): backend = PostgresBackend() session = _FakeSession() - previous = backend.disable_fk_check(session) - enabled = backend.enable_fk_check(session) - backend.restore_fk_check(session, previous) + previous = backend.disable_fk_check(_sess(session)) + enabled = backend.enable_fk_check(_sess(session)) + backend.restore_fk_check(_sess(session), previous) assert previous == "origin" assert enabled == "origin" @@ -97,7 +116,7 @@ def test_postgres_backend_merge_replace_uses_using_delete(): backend = PostgresBackend() session = _FakeSession() - backend.merge_replace(_ComputedTable, session, "target_table", "_staging_target_table", ["id", "name"]) + backend.merge_replace(_ComputedTableCls, _sess(session), "target_table", "_staging_target_table", ["id", "name"]) sql = session.statements[0] assert 'DELETE FROM "target_table" t' in sql @@ -109,7 +128,7 @@ def test_postgres_backend_merge_insert_excludes_computed_columns(): backend = PostgresBackend() session = _FakeSession() - backend.merge_insert(_ComputedTable, session, "target_table", "_staging_target_table") + backend.merge_insert(_ComputedTableCls, _sess(session), "target_table", "_staging_target_table") sql = session.statements[0] assert 'INSERT INTO "target_table" ("id", "name")' in sql @@ -120,7 +139,7 @@ def test_postgres_backend_merge_upsert_excludes_computed_columns(): backend = PostgresBackend() session = _FakeSession() - backend.merge_upsert(_ComputedTable, session, "target_table", "_staging_target_table", ["id"]) + backend.merge_upsert(_ComputedTableCls, _sess(session), "target_table", "_staging_target_table", ["id"]) sql = session.statements[0] assert 'INSERT INTO "target_table" ("id", "name")' in sql @@ -132,8 +151,8 @@ def test_postgres_backend_materialized_view_methods_emit_expected_sql(): session = _FakeSession() selectable = sa.select(sa.literal(1).label("n")) - backend.create_materialized_view(session, "mv_test", selectable) - backend.refresh_materialized_view(session, "mv_test") + backend.create_materialized_view(_as_engine(session), "mv_test", selectable) + backend.refresh_materialized_view(_as_engine(session), "mv_test") assert any("CREATE MATERIALIZED VIEW IF NOT EXISTS mv_test as SELECT" in sql for sql in session.statements) assert any("REFRESH MATERIALIZED VIEW mv_test;" == sql for sql in session.statements) @@ -152,10 +171,10 @@ class _Conn: def __enter__(self): return self - def __exit__(self, exc_type, exc, tb) -> None: + def __exit__(self, *_) -> None: return None - def execution_options(self, **kwargs): + def execution_options(self, **_): return self def execute(self, statement): @@ -170,16 +189,16 @@ def connect(self): engine = _Engine() - def _listen(target, name, fn) -> None: + def _listen(target, name, *_) -> None: events.append(("listen", target, name)) - def _remove(target, name, fn) -> None: + def _remove(target, name, *_) -> None: events.append(("remove", target, name)) - monkeypatch.setattr(sa.event, "listen", _listen) - monkeypatch.setattr(sa.event, "remove", _remove) + monkeypatch.setattr(sae, "listen", _listen) + monkeypatch.setattr(sae, "remove", _remove) - with backend.engine_with_replica_role(engine): + with backend.engine_with_replica_role(cast(Engine, engine)): pass assert events == [ diff --git a/tests/backends/test_sqlite_backend.py b/tests/backends/test_sqlite_backend.py index e7d33a2..344ed29 100644 --- a/tests/backends/test_sqlite_backend.py +++ b/tests/backends/test_sqlite_backend.py @@ -2,6 +2,7 @@ import sqlite3 from pathlib import Path +from typing import TYPE_CHECKING, Type, cast import sqlalchemy as sa import sqlalchemy.orm as so @@ -9,6 +10,9 @@ from orm_loader.backends import SQLiteBackend from orm_loader.helpers.sqlite import attach_sqlite_bulk_load_pragmas +if TYPE_CHECKING: + from orm_loader.tables.typing import CSVTableProtocol + class _ComputedTable: __table__ = sa.Table( @@ -38,6 +42,13 @@ def scalar(self): return _Result(self.scalar_result) +_ComputedTableCls = cast("Type[CSVTableProtocol]", _ComputedTable) + + +def _sess(s: _FakeSession) -> so.Session: + return cast(so.Session, s) + + def test_sqlite_backend_identity_and_capabilities(): backend = SQLiteBackend() @@ -51,27 +62,22 @@ def test_sqlite_backend_identity_and_capabilities(): assert backend.journal_mode == "WAL" -def test_sqlite_backend_create_staging_table(): +def test_sqlite_backend_create_staging_table(session, engine): backend = SQLiteBackend() - engine = sa.create_engine("sqlite:///:memory:", future=True) - session = so.Session(engine) - try: - backend.create_staging_table(_ComputedTable, session, "_staging_target_table") - inspector = sa.inspect(engine) - assert inspector.has_table("_staging_target_table") is True - cols = inspector.get_columns("_staging_target_table") - assert [c["name"] for c in cols] == ["id", "name", "slug"] - assert all(c["nullable"] is True for c in cols) - finally: - session.close() + backend.create_staging_table(_ComputedTableCls, session, "_staging_target_table") + inspector = sa.inspect(engine) + assert inspector.has_table("_staging_target_table") is True + cols = inspector.get_columns("_staging_target_table") + assert [c["name"] for c in cols] == ["id", "name", "slug"] + assert all(c["nullable"] is True for c in cols) def test_sqlite_backend_drop_staging_table(): backend = SQLiteBackend() session = _FakeSession() - backend.drop_staging_table(session, "_staging_target_table") + backend.drop_staging_table(_sess(session), "_staging_target_table") assert session.statements == ['DROP TABLE IF EXISTS "_staging_target_table"'] @@ -80,9 +86,9 @@ def test_sqlite_backend_fk_methods_emit_expected_sql(): backend = SQLiteBackend() session = _FakeSession() - previous = backend.disable_fk_check(session) - enabled = backend.enable_fk_check(session) - backend.restore_fk_check(session, previous) + previous = backend.disable_fk_check(_sess(session)) + enabled = backend.enable_fk_check(_sess(session)) + backend.restore_fk_check(_sess(session), previous) assert previous == 1 assert enabled == 1 @@ -99,7 +105,7 @@ def test_sqlite_backend_merge_replace_single_pk(): backend = SQLiteBackend() session = _FakeSession() - backend.merge_replace(_ComputedTable, session, "target_table", "_staging_target_table", ["id"]) + backend.merge_replace(_ComputedTableCls, _sess(session), "target_table", "_staging_target_table", ["id"]) sql = session.statements[0] assert 'DELETE FROM "target_table"' in sql @@ -110,7 +116,7 @@ def test_sqlite_backend_merge_replace_composite_pk(): backend = SQLiteBackend() session = _FakeSession() - backend.merge_replace(_ComputedTable, session, "target_table", "_staging_target_table", ["id", "name"]) + backend.merge_replace(_ComputedTableCls, _sess(session), "target_table", "_staging_target_table", ["id", "name"]) sql = session.statements[0] assert 'WHERE EXISTS (' in sql @@ -122,7 +128,7 @@ def test_sqlite_backend_merge_insert_excludes_computed_columns(): backend = SQLiteBackend() session = _FakeSession() - backend.merge_insert(_ComputedTable, session, "target_table", "_staging_target_table") + backend.merge_insert(_ComputedTableCls, _sess(session), "target_table", "_staging_target_table") sql = session.statements[0] assert 'INSERT INTO "target_table" ("id", "name")' in sql @@ -133,26 +139,25 @@ def test_sqlite_backend_merge_upsert_excludes_computed_columns(): backend = SQLiteBackend() session = _FakeSession() - backend.merge_upsert(_ComputedTable, session, "target_table", "_staging_target_table", ["id"]) + backend.merge_upsert(_ComputedTableCls, _sess(session), "target_table", "_staging_target_table", ["id"]) sql = session.statements[0] assert 'INSERT OR IGNORE INTO "target_table" ("id", "name")' in sql -def test_sqlite_backend_materialized_view_methods_raise(): +def test_sqlite_backend_materialized_view_methods_raise(engine): backend = SQLiteBackend() - session = _FakeSession() selectable = sa.select(sa.literal(1).label("n")) try: - backend.create_materialized_view(session, "mv_test", selectable) + backend.create_materialized_view(engine, "mv_test", selectable) except NotImplementedError as exc: assert "does not support materialized views" in str(exc) else: raise AssertionError("Expected create_materialized_view() to raise NotImplementedError") try: - backend.refresh_materialized_view(session, "mv_test") + backend.refresh_materialized_view(engine, "mv_test") except NotImplementedError as exc: assert "does not support materialized views" in str(exc) else: diff --git a/tests/conftest.py b/tests/conftest.py index d509cbf..ef42411 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,25 +1,94 @@ +import os +import time +from pathlib import Path +from urllib.parse import urlparse, urlunparse + import pytest import sqlalchemy as sa import sqlalchemy.orm as so -import time +from dotenv import load_dotenv from tests.models import Base +load_dotenv(Path(__file__).parent.parent / ".env") + + @pytest.fixture def engine(): - return sa.create_engine("sqlite:///:memory:") + engine = sa.create_engine("sqlite:///:memory:", future=True) + Base.metadata.create_all(engine) + return engine + @pytest.fixture def session(engine): - Base.metadata.create_all(engine) with so.Session(engine) as s: yield s -POSTGRES_URL = "postgresql+psycopg://test:test@localhost:55432/test_db" +# --------------------------------------------------------------------------- +# Postgres fixtures +# --------------------------------------------------------------------------- + +POSTGRES_URL = os.getenv( + "TEST_POSTGRES_URL", + "postgresql+psycopg://test:test@localhost:55432/test_db", +) + +# Shown whenever Postgres is unreachable — centralised so every skip carries +# the same actionable instructions. +_PG_SKIP_MSG = ( + "Postgres tests skipped — could not connect to {url}.\n" + " Set TEST_POSTGRES_URL to a writable test database and re-run, e.g.:\n" + " export TEST_POSTGRES_URL='postgresql+psycopg://user:pass@host:5432/orm_loader_test'\n" + " Or add it to orm-loader/.env.\n" + " Last error: {{last_err}}" +).format(url=POSTGRES_URL) + +# Module-level sentinel: None = not yet attempted, str = skip reason. +# Prevents the 20-retry loop from running once per postgres test when +# the server is not reachable. +_pg_unavailable: str | None = None + + +def _ensure_db_exists(url: str) -> None: + """Create the target database if it doesn't already exist. + + Connects to the 'postgres' maintenance database (same host/user/pass) + so the target database can be created without touching anything else. + """ + parsed = urlparse(url) + db_name = parsed.path.lstrip("/") + admin_url = urlunparse(parsed._replace(path="/postgres")) + + admin_engine = sa.create_engine(admin_url, isolation_level="AUTOCOMMIT") + try: + with admin_engine.connect() as conn: + exists = conn.execute( + sa.text("SELECT 1 FROM pg_database WHERE datname = :name"), + {"name": db_name}, + ).scalar() + if not exists: + conn.execute(sa.text(f'CREATE DATABASE "{db_name}"')) + print(f"Created test database: {db_name!r}") + finally: + admin_engine.dispose() + @pytest.fixture(scope="session") def pg_engine(): + global _pg_unavailable + + # Fast path: already know Postgres is not reachable — skip immediately + # without re-running the retry loop. + if _pg_unavailable is not None: + pytest.skip(_pg_unavailable) + + try: + _ensure_db_exists(POSTGRES_URL) + except Exception as e: + print(f"Could not ensure test DB exists (will try connecting anyway): {e}") + last_err = None for i in range(20): try: @@ -35,13 +104,14 @@ def pg_engine(): print(f"[{i}] Postgres not ready:", repr(e)) time.sleep(1) - raise RuntimeError(f"Postgres never became available: {last_err!r}") + _pg_unavailable = _PG_SKIP_MSG.format(last_err=last_err) + pytest.skip(_pg_unavailable) + @pytest.fixture def pg_session(pg_engine): Session = so.sessionmaker(bind=pg_engine, future=True) with pg_engine.begin() as conn: - # optional: recreate schema per test Base.metadata.drop_all(conn) Base.metadata.create_all(conn) @@ -51,6 +121,3 @@ def pg_session(pg_engine): finally: session.rollback() session.close() - - - diff --git a/tests/loaders/test_dedupe.py b/tests/loaders/test_dedupe.py index 6cae76b..c84a6b3 100644 --- a/tests/loaders/test_dedupe.py +++ b/tests/loaders/test_dedupe.py @@ -1,10 +1,12 @@ import pyarrow as pa +from typing import cast, Type from orm_loader.loaders.loading_helpers import arrow_drop_duplicates import pandas as pd import sqlalchemy as sa import sqlalchemy.orm as so from sqlalchemy.orm import DeclarativeBase from orm_loader.tables.loadable_table import CSVLoadableTableInterface +from orm_loader.tables.typing import CSVTableProtocol from orm_loader.loaders.loader_interface import PandasLoader @@ -19,6 +21,9 @@ class DedupTable(Base, CSVLoadableTableInterface): value: so.Mapped[str] = so.mapped_column(sa.String, nullable=False) +_DedupTable = cast(Type[CSVTableProtocol], DedupTable) + + def test_arrow_drop_duplicates_simple(): table = pa.table({ "id": [1, 1, 2], @@ -31,8 +36,8 @@ def test_arrow_drop_duplicates_simple(): -def test_internal_deduplication(session, tmp_path): - Base.metadata.create_all(session.get_bind()) +def test_internal_deduplication(session, engine, tmp_path): + Base.metadata.create_all(engine) csv = tmp_path / "dedup_table.csv" pd.DataFrame( @@ -43,7 +48,7 @@ def test_internal_deduplication(session, tmp_path): ] ).to_csv(csv, index=False) - inserted = DedupTable.load_csv( # type: ignore + inserted = _DedupTable.load_csv( session, csv, loader=PandasLoader(), diff --git a/tests/loaders/test_loader_e2e.py b/tests/loaders/test_loader_e2e.py index 8a60b89..1821f7c 100644 --- a/tests/loaders/test_loader_e2e.py +++ b/tests/loaders/test_loader_e2e.py @@ -1,38 +1,27 @@ import sqlalchemy as sa import sqlalchemy.event as sae import sqlalchemy.orm as so -from sqlalchemy.orm import Session from pathlib import Path +from typing import cast, Type import pandas as pd import pytest +import numpy as np from orm_loader.loaders.data_classes import _clean_nulls from orm_loader.tables.loadable_table import CSVLoadableTableInterface +from orm_loader.tables.typing import CSVTableProtocol from orm_loader.loaders.loader_interface import PandasLoader from tests.models import Base, SimpleTable, RequiredTable, CompositeTable -import numpy as np - -@pytest.fixture -def engine(): - engine = sa.create_engine("sqlite:///:memory:", future=True) - Base.metadata.create_all(engine) - return engine - - -@pytest.fixture -def session(engine): - with Session(engine) as session: - yield session - - -@pytest.fixture -def tmp_csv_dir(tmp_path: Path) -> Path: - return tmp_path +# Typed aliases: Pylance cannot verify SQLAlchemy metaclass-generated attrs +# satisfy CSVTableProtocol structurally, so we cast once per class here. +_SimpleTable = cast(Type[CSVTableProtocol], SimpleTable) +_RequiredTable = cast(Type[CSVTableProtocol], RequiredTable) +_CompositeTable = cast(Type[CSVTableProtocol], CompositeTable) -def test_initial_csv_load(session, tmp_csv_dir): - csv_path = tmp_csv_dir / "test_table.csv" +def test_initial_csv_load(session, tmp_path): + csv_path = tmp_path / "test_table.csv" pd.DataFrame( [ @@ -44,7 +33,7 @@ def test_initial_csv_load(session, tmp_csv_dir): loader = PandasLoader() - inserted = SimpleTable.load_csv( # type: ignore + inserted = _SimpleTable.load_csv( session, csv_path, dedupe=False, @@ -65,8 +54,8 @@ def test_initial_csv_load(session, tmp_csv_dir): ] -def test_replace_merge_strategy(session, tmp_csv_dir): - csv_path = tmp_csv_dir / "test_table.csv" +def test_replace_merge_strategy(session, tmp_path): + csv_path = tmp_path / "test_table.csv" # Initial load pd.DataFrame( @@ -79,7 +68,7 @@ def test_replace_merge_strategy(session, tmp_csv_dir): loader = PandasLoader() - SimpleTable.load_csv( # type: ignore + _SimpleTable.load_csv( session, csv_path, dedupe=False, @@ -95,7 +84,7 @@ def test_replace_merge_strategy(session, tmp_csv_dir): ] ).to_csv(csv_path, index=False, sep="\t") - replaced = SimpleTable.load_csv( # type: ignore + replaced = _SimpleTable.load_csv( session, csv_path, dedupe=False, @@ -117,14 +106,14 @@ def test_replace_merge_strategy(session, tmp_csv_dir): ] -def test_empty_csv_is_noop(session, tmp_csv_dir): - csv_path = tmp_csv_dir / "test_table.csv" +def test_empty_csv_is_noop(session, tmp_path): + csv_path = tmp_path / "test_table.csv" csv_path.touch() loader = PandasLoader() - inserted = SimpleTable.load_csv( # type: ignore - session, + inserted = _SimpleTable.load_csv( + session, csv_path, dedupe=False, loader=loader, @@ -139,8 +128,6 @@ def test_empty_csv_is_noop(session, tmp_csv_dir): def test_required_column_violation_drops_rows(session, tmp_path): - Base.metadata.create_all(session.get_bind()) - csv = tmp_path / "required_table.csv" pd.DataFrame( [ @@ -149,7 +136,7 @@ def test_required_column_violation_drops_rows(session, tmp_path): ] ).to_csv(csv, index=False) - inserted = RequiredTable.load_csv( # type: ignore + inserted = _RequiredTable.load_csv( session, csv, loader=PandasLoader(), @@ -162,8 +149,6 @@ def test_required_column_violation_drops_rows(session, tmp_path): def test_composite_pk_dedup(session, tmp_path): - Base.metadata.create_all(session.get_bind()) - csv = tmp_path / "composite_table.csv" pd.DataFrame( [ @@ -173,7 +158,7 @@ def test_composite_pk_dedup(session, tmp_path): ] ).to_csv(csv, index=False) - inserted = CompositeTable.load_csv( # type: ignore + inserted = _CompositeTable.load_csv( session, csv, loader=PandasLoader(), @@ -207,8 +192,8 @@ def test_composite_pk_dedup(session, tmp_path): ), ], ) -def test_merge_strategies(session, tmp_csv_dir, merge_strategy, expected_rows, expected_inserted): - csv_path = tmp_csv_dir / "test_table.csv" +def test_merge_strategies(session, tmp_path, merge_strategy, expected_rows, expected_inserted): + csv_path = tmp_path / "test_table.csv" pd.DataFrame( [ @@ -247,12 +232,12 @@ def test_merge_strategies(session, tmp_csv_dir, merge_strategy, expected_rows, e assert [(r.id, r.name) for r in rows] == expected_rows -def test_staging_table_is_created_and_dropped(session, tmp_csv_dir): - csv_path = tmp_csv_dir / "test_table.csv" +def test_staging_table_is_created_and_dropped(session, engine, tmp_path): + csv_path = tmp_path / "test_table.csv" pd.DataFrame([{"id": 1, "name": "alpha"}]).to_csv(csv_path, index=False) - SimpleTable.load_csv( + _SimpleTable.load_csv( session, csv_path, loader=PandasLoader(), @@ -260,7 +245,7 @@ def test_staging_table_is_created_and_dropped(session, tmp_csv_dir): ) session.commit() - inspector = sa.inspect(session.get_bind()) + inspector = sa.inspect(engine) assert not inspector.has_table(SimpleTable.staging_tablename()) @@ -326,14 +311,15 @@ def test_clean_nulls_passthrough(): assert _clean_nulls("S") == "S" -def test_nullable_column_with_nan_does_not_crash(session, tmp_path): +def test_nullable_column_with_nan_does_not_crash(session, engine, tmp_path): class NullableTable(Base, CSVLoadableTableInterface): __tablename__ = "nullable_table" id: so.Mapped[int] = so.mapped_column(sa.Integer, primary_key=True) flag: so.Mapped[str | None] = so.mapped_column(sa.String, nullable=True) - Base.metadata.create_all(session.get_bind()) + Base.metadata.create_all(engine) + _NullableTable = cast(Type[CSVTableProtocol], NullableTable) csv = tmp_path / "nullable_table.csv" pd.DataFrame( @@ -343,7 +329,7 @@ class NullableTable(Base, CSVLoadableTableInterface): ] ).to_csv(csv, index=False) - inserted = NullableTable.load_csv( # type: ignore + inserted = _NullableTable.load_csv( session, csv, loader=PandasLoader(), @@ -363,14 +349,15 @@ class NullableTable(Base, CSVLoadableTableInterface): ] -def test_embedded_newline_in_field_is_preserved(session, tmp_path): +def test_embedded_newline_in_field_is_preserved(session, engine, tmp_path): class TextTable(Base, CSVLoadableTableInterface): __tablename__ = "text_table" id: so.Mapped[int] = so.mapped_column(sa.Integer, primary_key=True) name: so.Mapped[str] = so.mapped_column(sa.String) - Base.metadata.create_all(session.get_bind()) + Base.metadata.create_all(engine) + _TextTable = cast(Type[CSVTableProtocol], TextTable) csv = tmp_path / "text_table.csv" @@ -380,7 +367,7 @@ class TextTable(Base, CSVLoadableTableInterface): '1\t"hello\nworld"\n' ) - TextTable.load_csv( # type: ignore + _TextTable.load_csv( session, csv, loader=PandasLoader(), @@ -392,14 +379,15 @@ class TextTable(Base, CSVLoadableTableInterface): assert rows[0].name == "hello\nworld" -def test_embedded_tab_in_field(session, tmp_path): +def test_embedded_tab_in_field(session, engine, tmp_path): class TextTable2(Base, CSVLoadableTableInterface): __tablename__ = "tab_table" id: so.Mapped[int] = so.mapped_column(sa.Integer, primary_key=True) name: so.Mapped[str] = so.mapped_column(sa.String) - Base.metadata.create_all(session.get_bind()) + Base.metadata.create_all(engine) + _TextTable2 = cast(Type[CSVTableProtocol], TextTable2) csv = tmp_path / "tab_table.csv" csv.write_text( @@ -407,7 +395,7 @@ class TextTable2(Base, CSVLoadableTableInterface): '1\t"foo\tbar"\n' ) - TextTable2.load_csv( # type: ignore + _TextTable2.load_csv( session, csv, loader=PandasLoader(), @@ -425,24 +413,25 @@ def _make_ddl_tracker(engine): """Return a list that is populated with DROP/CREATE INDEX statements as they execute.""" ddl_log: list[str] = [] - @sae.listens_for(engine, "before_cursor_execute") - def _capture(conn, cursor, statement, parameters, context, executemany): + def _capture(*args): + statement: str = args[2] upper = statement.strip().upper() if upper.startswith("DROP INDEX") or upper.startswith("CREATE INDEX"): ddl_log.append(statement.strip()) + sae.listen(engine, "before_cursor_execute", _capture) return ddl_log -def test_auto_strategy_keeps_indices_on_sqlite(session, engine, tmp_csv_dir): +def test_auto_strategy_keeps_indices_on_sqlite(session, engine, tmp_path): """On SQLite, 'auto' resolves to 'keep' — no index DDL should be emitted.""" ddl_log = _make_ddl_tracker(engine) - csv_path = tmp_csv_dir / "test_table.csv" + csv_path = tmp_path / "test_table.csv" pd.DataFrame([{"id": 1, "name": "alpha"}, {"id": 2, "name": "beta"}]).to_csv( csv_path, index=False, sep="\t" ) - SimpleTable.load_csv(session, csv_path, loader=PandasLoader(), index_strategy="auto") # type: ignore + _SimpleTable.load_csv(session, csv_path, loader=PandasLoader(), index_strategy="auto") session.commit() assert not any("DROP INDEX" in s.upper() for s in ddl_log) @@ -452,13 +441,13 @@ def test_auto_strategy_keeps_indices_on_sqlite(session, engine, tmp_csv_dir): assert "ix_test_table_name" in {idx["name"] for idx in inspector.get_indexes("test_table")} -def test_explicit_keep_preserves_indices(session, engine, tmp_csv_dir): +def test_explicit_keep_preserves_indices(session, engine, tmp_path): """Explicit 'keep' emits no index DDL regardless of dialect.""" ddl_log = _make_ddl_tracker(engine) - csv_path = tmp_csv_dir / "test_table.csv" + csv_path = tmp_path / "test_table.csv" pd.DataFrame([{"id": 1, "name": "alpha"}]).to_csv(csv_path, index=False, sep="\t") - SimpleTable.load_csv(session, csv_path, loader=PandasLoader(), index_strategy="keep") # type: ignore + _SimpleTable.load_csv(session, csv_path, loader=PandasLoader(), index_strategy="keep") session.commit() assert not any("DROP INDEX" in s.upper() for s in ddl_log) @@ -467,15 +456,15 @@ def test_explicit_keep_preserves_indices(session, engine, tmp_csv_dir): assert "ix_test_table_name" in {idx["name"] for idx in inspector.get_indexes("test_table")} -def test_explicit_drop_rebuild_on_sqlite_restores_index(session, engine, tmp_csv_dir): +def test_explicit_drop_rebuild_on_sqlite_restores_index(session, engine, tmp_path): """Explicit 'drop_rebuild' drops then restores the index even on SQLite.""" ddl_log = _make_ddl_tracker(engine) - csv_path = tmp_csv_dir / "test_table.csv" + csv_path = tmp_path / "test_table.csv" pd.DataFrame([{"id": 1, "name": "alpha"}, {"id": 2, "name": "beta"}]).to_csv( csv_path, index=False, sep="\t" ) - SimpleTable.load_csv(session, csv_path, loader=PandasLoader(), index_strategy="drop_rebuild") # type: ignore + _SimpleTable.load_csv(session, csv_path, loader=PandasLoader(), index_strategy="drop_rebuild") session.commit() assert any("DROP INDEX" in s.upper() for s in ddl_log) @@ -485,13 +474,13 @@ def test_explicit_drop_rebuild_on_sqlite_restores_index(session, engine, tmp_csv assert "ix_test_table_name" in {idx["name"] for idx in inspector.get_indexes("test_table")} -def test_invalid_index_strategy_raises(session, tmp_csv_dir): +def test_invalid_index_strategy_raises(session, tmp_path): """An unrecognised strategy value raises ValueError before any DB work.""" - csv_path = tmp_csv_dir / "test_table.csv" + csv_path = tmp_path / "test_table.csv" pd.DataFrame([{"id": 1, "name": "alpha"}]).to_csv(csv_path, index=False, sep="\t") with pytest.raises(ValueError, match="Unknown index_strategy"): - SimpleTable.load_csv(session, csv_path, loader=PandasLoader(), index_strategy="not-valid") # type: ignore + _SimpleTable.load_csv(session, csv_path, loader=PandasLoader(), index_strategy="not-valid") # from hypothesis import given, strategies as st diff --git a/tests/loaders/test_parquet_loader.py b/tests/loaders/test_parquet_loader.py index 8dbca70..a1e735a 100644 --- a/tests/loaders/test_parquet_loader.py +++ b/tests/loaders/test_parquet_loader.py @@ -4,8 +4,10 @@ import sqlalchemy as sa import sqlalchemy.orm as so from sqlalchemy.orm import DeclarativeBase +from typing import cast, Type from orm_loader.tables.loadable_table import CSVLoadableTableInterface +from orm_loader.tables.typing import CSVTableProtocol from orm_loader.loaders.loader_interface import ParquetLoader @@ -20,8 +22,11 @@ class ParquetTable(Base, CSVLoadableTableInterface): value: so.Mapped[int] = so.mapped_column(sa.Integer, nullable=False) -def test_parquet_loader(session, tmp_path): - Base.metadata.create_all(session.get_bind()) +_ParquetTable = cast(Type[CSVTableProtocol], ParquetTable) + + +def test_parquet_loader(session, engine, tmp_path): + Base.metadata.create_all(engine) df = pd.DataFrame( [ @@ -33,7 +38,7 @@ def test_parquet_loader(session, tmp_path): path = tmp_path / "parquet_table.parquet" pq.write_table(table, path) - inserted = ParquetTable.load_csv( # type: ignore + inserted = _ParquetTable.load_csv( session, path, loader=ParquetLoader(), diff --git a/tests/loaders/test_pg_loader.py b/tests/loaders/test_pg_loader.py index 32228fd..0e278a8 100644 --- a/tests/loaders/test_pg_loader.py +++ b/tests/loaders/test_pg_loader.py @@ -49,8 +49,8 @@ def fake_quick_load_pg(*args, **kwargs): called["copy"] = True return 1 - import orm_loader.tables.loadable_table as loadable_table - monkeypatch.setattr(loadable_table, "quick_load_pg", fake_quick_load_pg) + import orm_loader.backends.postgres as pg_backend + monkeypatch.setattr(pg_backend, "quick_load_pg", fake_quick_load_pg) inserted = SimpleTable.load_csv(pg_session, csv) pg_session.commit() @@ -63,12 +63,12 @@ def test_copy_failure_falls_back_to_orm(pg_session, tmp_path, monkeypatch): csv = tmp_path / "test_table.csv" pd.DataFrame([{"id": 1, "name": "alpha"}]).to_csv(csv, index=False) - from orm_loader.loaders import loading_helpers + import orm_loader.backends.postgres as pg_backend def broken_copy(*args, **kwargs): raise RuntimeError("boom") - monkeypatch.setattr(loading_helpers, "quick_load_pg", broken_copy) + monkeypatch.setattr(pg_backend, "quick_load_pg", broken_copy) inserted = SimpleTable.load_csv(pg_session, csv) pg_session.commit() @@ -145,7 +145,9 @@ def test_infer_encoding_utf8(tmp_path): p.write_text("id,name\n1,α\n", encoding="utf-8") enc = infer_encoding(p) - assert enc["encoding"].lower().startswith("utf") + enc_str = enc["encoding"] + assert enc_str is not None + assert enc_str.lower().startswith("utf") def test_infer_delim_tab(tmp_path): p = tmp_path / "tab.csv" diff --git a/tests/pytest.ini b/tests/pytest.ini deleted file mode 100644 index d89701f..0000000 --- a/tests/pytest.ini +++ /dev/null @@ -1,3 +0,0 @@ -[pytest] -markers = - postgres: requires a running postgres instance \ No newline at end of file From cbc8be235f8b88bfbbc38ccbbb21efefce68ce6f Mon Sep 17 00:00:00 2001 From: Nico Loesch Date: Mon, 18 May 2026 22:59:48 +0000 Subject: [PATCH 08/21] Update the CI/CD to run the postgres tests --- .github/workflows/tests.yml | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c052d04..38ea4e8 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -10,6 +10,21 @@ jobs: name: pytest (Python ${{ matrix.python-version }}) runs-on: ubuntu-latest + services: + postgres: + image: postgres:16 + env: + POSTGRES_USER: test + POSTGRES_PASSWORD: test + POSTGRES_DB: orm_loader_test + ports: + - 5432:5432 + options: >- + --health-cmd pg_isready + --health-interval 5s + --health-timeout 5s + --health-retries 10 + strategy: fail-fast: false matrix: @@ -30,9 +45,10 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -e ".[dev]" + pip install -e ".[dev,postgres]" - name: Run pytest env: PYTHONPATH: src - run: pytest -m "not postgres" \ No newline at end of file + TEST_POSTGRES_URL: postgresql+psycopg://test:test@localhost:5432/orm_loader_test + run: pytest From 0d9cf6ef1b65bce36fe9e243eb4dffddf1c50c0f Mon Sep 17 00:00:00 2001 From: Nico Loesch Date: Mon, 18 May 2026 23:13:08 +0000 Subject: [PATCH 09/21] Attempt to fix ruff and missing dotenv for CI/CD --- pyproject.toml | 4 +- tests/loaders/test_loader_e2e.py | 64 ++++++++++++++------------------ 2 files changed, 31 insertions(+), 37 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f801f4e..3b72c69 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,9 @@ dev = [ "requests>=2.33.0", "mkdocs>=1.6.1", "mkdocs-mermaid2-plugin", - "Pygments>=2.20.0" + "Pygments>=2.20.0", + "dotenv", + "ruff" ] [tool.setuptools] diff --git a/tests/loaders/test_loader_e2e.py b/tests/loaders/test_loader_e2e.py index 1821f7c..bb53dd9 100644 --- a/tests/loaders/test_loader_e2e.py +++ b/tests/loaders/test_loader_e2e.py @@ -1,17 +1,17 @@ +from typing import Type, cast + +import numpy as np +import pandas as pd +import pytest import sqlalchemy as sa import sqlalchemy.event as sae import sqlalchemy.orm as so -from pathlib import Path -from typing import cast, Type -import pandas as pd -import pytest -import numpy as np + from orm_loader.loaders.data_classes import _clean_nulls +from orm_loader.loaders.loader_interface import PandasLoader from orm_loader.tables.loadable_table import CSVLoadableTableInterface from orm_loader.tables.typing import CSVTableProtocol -from orm_loader.loaders.loader_interface import PandasLoader - -from tests.models import Base, SimpleTable, RequiredTable, CompositeTable +from tests.models import Base, CompositeTable, RequiredTable, SimpleTable # Typed aliases: Pylance cannot verify SQLAlchemy metaclass-generated attrs # satisfy CSVTableProtocol structurally, so we cast once per class here. @@ -43,9 +43,7 @@ def test_initial_csv_load(session, tmp_path): assert inserted == 3 - rows = session.execute( - sa.select(SimpleTable).order_by(SimpleTable.id) - ).scalars().all() + rows = session.execute(sa.select(SimpleTable).order_by(SimpleTable.id)).scalars().all() assert [(r.id, r.name) for r in rows] == [ (1, "alpha"), @@ -95,9 +93,7 @@ def test_replace_merge_strategy(session, tmp_path): assert replaced == 2 - rows = session.execute( - sa.select(SimpleTable).order_by(SimpleTable.id) - ).scalars().all() + rows = session.execute(sa.select(SimpleTable).order_by(SimpleTable.id)).scalars().all() assert [(r.id, r.name) for r in rows] == [ (1, "alpha"), @@ -126,7 +122,6 @@ def test_empty_csv_is_noop(session, tmp_path): assert rows == [] - def test_required_column_violation_drops_rows(session, tmp_path): csv = tmp_path / "required_table.csv" pd.DataFrame( @@ -147,7 +142,6 @@ def test_required_column_violation_drops_rows(session, tmp_path): assert inserted == 1 - def test_composite_pk_dedup(session, tmp_path): csv = tmp_path / "composite_table.csv" pd.DataFrame( @@ -225,9 +219,11 @@ def test_merge_strategies(session, tmp_path, merge_strategy, expected_rows, expe assert inserted == expected_inserted - rows = session.execute( - sa.select(SimpleTable).order_by(SimpleTable.id, SimpleTable.name) - ).scalars().all() + rows = ( + session.execute(sa.select(SimpleTable).order_by(SimpleTable.id, SimpleTable.name)) + .scalars() + .all() + ) assert [(r.id, r.name) for r in rows] == expected_rows @@ -276,9 +272,11 @@ def test_composite_pk_replace_merge(session, tmp_path): ) session.commit() - rows = session.execute( - sa.select(CompositeTable).order_by(CompositeTable.a, CompositeTable.b) - ).scalars().all() + rows = ( + session.execute(sa.select(CompositeTable).order_by(CompositeTable.a, CompositeTable.b)) + .scalars() + .all() + ) assert [(r.a, r.b, r.value) for r in rows] == [ (1, 1, "x_updated"), @@ -304,9 +302,10 @@ def test_clean_nulls_basic(): assert _clean_nulls(float("nan")) is None assert _clean_nulls(np.nan) is None + def test_clean_nulls_passthrough(): assert _clean_nulls("") == "" - assert _clean_nulls("nan") == "nan" # string 'nan' must not be converted + assert _clean_nulls("nan") == "nan" # string 'nan' must not be converted assert _clean_nulls(0) == 0 assert _clean_nulls("S") == "S" @@ -325,7 +324,7 @@ class NullableTable(Base, CSVLoadableTableInterface): pd.DataFrame( [ {"id": 1, "flag": "S"}, - {"id": 2, "flag": None}, # becomes NaN in pandas + {"id": 2, "flag": None}, # becomes NaN in pandas ] ).to_csv(csv, index=False) @@ -339,9 +338,7 @@ class NullableTable(Base, CSVLoadableTableInterface): assert inserted == 2 - rows = session.execute( - sa.select(NullableTable).order_by(NullableTable.id) - ).scalars().all() + rows = session.execute(sa.select(NullableTable).order_by(NullableTable.id)).scalars().all() assert [(r.id, r.flag) for r in rows] == [ (1, "S"), @@ -362,10 +359,7 @@ class TextTable(Base, CSVLoadableTableInterface): csv = tmp_path / "text_table.csv" # Properly quoted CSV with embedded newline - csv.write_text( - 'id\tname\n' - '1\t"hello\nworld"\n' - ) + csv.write_text('id\tname\n1\t"hello\nworld"\n') _TextTable.load_csv( session, @@ -390,10 +384,7 @@ class TextTable2(Base, CSVLoadableTableInterface): _TextTable2 = cast(Type[CSVTableProtocol], TextTable2) csv = tmp_path / "tab_table.csv" - csv.write_text( - 'id\tname\n' - '1\t"foo\tbar"\n' - ) + csv.write_text('id\tname\n1\t"foo\tbar"\n') _TextTable2.load_csv( session, @@ -409,6 +400,7 @@ class TextTable2(Base, CSVLoadableTableInterface): # --- index_strategy tests --- + def _make_ddl_tracker(engine): """Return a list that is populated with DROP/CREATE INDEX statements as they execute.""" ddl_log: list[str] = [] @@ -526,4 +518,4 @@ def test_invalid_index_strategy_raises(session, tmp_path): # assert rows == [] # else: # # stored value may be str-canonicalised version -# assert rows[0].txt.encode("utf-8", errors="replace") == s.encode("utf-8", errors="replace") \ No newline at end of file +# assert rows[0].txt.encode("utf-8", errors="replace") == s.encode("utf-8", errors="replace") From dddc15e101036a3130561f10c9d925679be91714 Mon Sep 17 00:00:00 2001 From: Nico Loesch Date: Tue, 19 May 2026 00:05:44 +0000 Subject: [PATCH 10/21] Fix remaining PR comments, including SQL injection sites --- src/orm_loader/backends/base.py | 11 +++- src/orm_loader/backends/postgres.py | 22 +++++++- src/orm_loader/backends/sqlite.py | 21 ++++++- src/orm_loader/helpers/logging.py | 12 +++- src/orm_loader/tables/orm_table.py | 8 +-- tests/backends/test_base_backend.py | 4 ++ tests/backends/test_sqlite_backend.py | 79 ++++++++++++++++++++++----- tests/tables/test_orm_table_base.py | 11 ++-- 8 files changed, 141 insertions(+), 27 deletions(-) diff --git a/src/orm_loader/backends/base.py b/src/orm_loader/backends/base.py index 09207f4..60ea2b0 100644 --- a/src/orm_loader/backends/base.py +++ b/src/orm_loader/backends/base.py @@ -93,7 +93,7 @@ def _as_connection( bind: Engine | Connection, ) -> Iterator[Connection]: if isinstance(bind, Engine): - with bind.connect() as conn: + with bind.begin() as conn: yield conn else: yield bind @@ -140,6 +140,15 @@ def load_staging_fast( """ return None + @staticmethod + @abstractmethod + def _normalize_fk_check_state(previous_state: str | int) -> str | int: + """Validate and normalise a previously-returned FK state before interpolating into SQL. + + Each backend accepts a different type (SQLite: int, Postgres: str) and must + implement this to guard restore_fk_check() against invalid or injected values. + """ + @abstractmethod def disable_fk_check(self, session: so.Session) -> str | int: """Disable FK checks and return the previous backend-specific state.""" diff --git a/src/orm_loader/backends/postgres.py b/src/orm_loader/backends/postgres.py index e8823ee..5ee0a38 100644 --- a/src/orm_loader/backends/postgres.py +++ b/src/orm_loader/backends/postgres.py @@ -15,6 +15,8 @@ from ..loaders.data_classes import LoaderContext from ..tables.typing import CSVTableProtocol +_VALID_PG_REPLICATION_ROLES = frozenset({"origin", "local", "replica"}) + class PostgresBackend(DatabaseBackend): @property @@ -76,6 +78,23 @@ def load_staging_fast( quote_mode=loader_context.quote_mode, ) + @staticmethod + def _normalize_fk_check_state(previous_state: str | int) -> str: + if isinstance(previous_state, int): + raise ValueError( + f"Invalid PostgreSQL session_replication_role {previous_state!r}: " + "Postgres uses string roles ('origin', 'local', 'replica'), not integers. " + "The value passed here should always come from this backend's own " + "disable_fk_check(), which returns a string." + ) + normalised = previous_state.strip().lower() + if normalised not in _VALID_PG_REPLICATION_ROLES: + raise ValueError( + f"Invalid PostgreSQL session_replication_role {previous_state!r}. " + f"Expected one of: {sorted(_VALID_PG_REPLICATION_ROLES)}" + ) + return normalised + def disable_fk_check(self, session: so.Session) -> str | int: previous_state = session.execute(sa.text("SHOW session_replication_role")).scalar() session.execute(sa.text("SET session_replication_role = 'replica'")) @@ -93,7 +112,8 @@ def restore_fk_check( session: so.Session, previous_state: str | int, ) -> None: - session.execute(sa.text(f"SET session_replication_role = '{previous_state}'")) + safe_state = self._normalize_fk_check_state(previous_state) + session.execute(sa.text(f"SET session_replication_role = '{safe_state}'")) def merge_replace( self, diff --git a/src/orm_loader/backends/sqlite.py b/src/orm_loader/backends/sqlite.py index 4ae1c70..c196955 100644 --- a/src/orm_loader/backends/sqlite.py +++ b/src/orm_loader/backends/sqlite.py @@ -46,6 +46,24 @@ def _validate_journal_mode(journal_mode: str) -> str: ) return normalised + @staticmethod + def _normalize_fk_check_state(previous_state: str | int) -> str: + if isinstance(previous_state, int): + if previous_state == 1: + return "ON" + if previous_state == 0: + return "OFF" + elif isinstance(previous_state, str): + normalised = previous_state.strip().upper() + if normalised in {"1", "ON"}: + return "ON" + if normalised in {"0", "OFF"}: + return "OFF" + raise ValueError( + f"Invalid SQLite foreign_keys state {previous_state!r}. " + "Expected 0, 1, 'OFF', or 'ON'." + ) + @property def name(self) -> str: return "sqlite" @@ -108,7 +126,8 @@ def restore_fk_check( session: so.Session, previous_state: str | int, ) -> None: - session.execute(text(f"PRAGMA foreign_keys = {previous_state}")) + safe_state = self._normalize_fk_check_state(previous_state) + session.execute(text(f"PRAGMA foreign_keys = {safe_state}")) def merge_replace( self, diff --git a/src/orm_loader/helpers/logging.py b/src/orm_loader/helpers/logging.py index 376f5e8..bce30f9 100644 --- a/src/orm_loader/helpers/logging.py +++ b/src/orm_loader/helpers/logging.py @@ -1,7 +1,8 @@ from __future__ import annotations + import logging -from typing import Optional, Any import re +from typing import Any, Optional SENSITIVE_KEYS = { "password", @@ -15,10 +16,13 @@ } LOGGING_NAMESPACE = "sql_loader" + def _coerce_log_level(level: int | str) -> int: if isinstance(level, int): return level + if not isinstance(level, str): + raise TypeError(f"log level must be an int or str, got {type(level).__name__}") s = level.strip().upper() if s.isdigit(): return int(s) @@ -29,6 +33,7 @@ def _coerce_log_level(level: int | str) -> int: raise ValueError(f"Invalid log level: {level!r}") + def get_logger(name: Optional[str] = None) -> logging.Logger: """ Return a namespaced logger. @@ -51,7 +56,8 @@ def __init__(self, *args: Any, **kwargs: Any): def format(self, record: logging.LogRecord) -> str: msg = super().format(record) return self._pattern.sub(r"\\1=", msg) - + + def configure_logging( *, level: int | str = logging.INFO, @@ -83,4 +89,4 @@ def configure_logging( logger.propagate = propagate -logging.getLogger(LOGGING_NAMESPACE).addHandler(logging.NullHandler()) \ No newline at end of file +logging.getLogger(LOGGING_NAMESPACE).addHandler(logging.NullHandler()) diff --git a/src/orm_loader/tables/orm_table.py b/src/orm_loader/tables/orm_table.py index 771f0e4..24c634d 100644 --- a/src/orm_loader/tables/orm_table.py +++ b/src/orm_loader/tables/orm_table.py @@ -1,6 +1,6 @@ import sqlalchemy as sa import sqlalchemy.orm as so -from sqlalchemy.exc import StatementError +from sqlalchemy.exc import NoInspectionAvailable, StatementError from typing import Any import logging from .allocators import IdAllocator @@ -63,10 +63,10 @@ def mapper_for(cls: type[Any]) -> so.Mapper[Any]: TypeError If the class is not a mapped SQLAlchemy ORM class. """ - mapper: so.Mapper[Any] = sa.inspect(cls) - if not mapper: + try: + return sa.inspect(cls) + except NoInspectionAvailable: raise TypeError(f"{cls.__name__} is not a mapped ORM class") - return mapper @classmethod def pk_columns(cls) -> list[sa.ColumnElement[Any]]: diff --git a/tests/backends/test_base_backend.py b/tests/backends/test_base_backend.py index 44b9502..829fbc8 100644 --- a/tests/backends/test_base_backend.py +++ b/tests/backends/test_base_backend.py @@ -87,6 +87,10 @@ def merge_insert( ) -> None: return None + @staticmethod + def _normalize_fk_check_state(previous_state: str | int) -> str | int: + return previous_state + def disable_fk_check(self, session: so.Session) -> str | int: self.calls.append(("disable_fk_check", session)) return "enabled" diff --git a/tests/backends/test_sqlite_backend.py b/tests/backends/test_sqlite_backend.py index 344ed29..32c9a57 100644 --- a/tests/backends/test_sqlite_backend.py +++ b/tests/backends/test_sqlite_backend.py @@ -82,30 +82,79 @@ def test_sqlite_backend_drop_staging_table(): assert session.statements == ['DROP TABLE IF EXISTS "_staging_target_table"'] -def test_sqlite_backend_fk_methods_emit_expected_sql(): +def test_sqlite_backend_disable_fk_reads_then_sets(): backend = SQLiteBackend() - session = _FakeSession() + session = _FakeSession(scalar_result=1) previous = backend.disable_fk_check(_sess(session)) - enabled = backend.enable_fk_check(_sess(session)) - backend.restore_fk_check(_sess(session), previous) assert previous == 1 - assert enabled == 1 assert session.statements == [ - "PRAGMA foreign_keys", - "PRAGMA foreign_keys = OFF", - "PRAGMA foreign_keys", + "PRAGMA foreign_keys", # read current state + "PRAGMA foreign_keys = OFF", # set to OFF + ] + + +def test_sqlite_backend_enable_fk_reads_then_sets(): + backend = SQLiteBackend() + session = _FakeSession(scalar_result=0) + + previous = backend.enable_fk_check(_sess(session)) + + assert previous == 0 + assert session.statements == [ + "PRAGMA foreign_keys", # read current state + "PRAGMA foreign_keys = ON", # set to ON + ] + + +def test_sqlite_backend_restore_fk_normalises_int_and_emits(): + backend = SQLiteBackend() + session = _FakeSession() + + backend.restore_fk_check(_sess(session), 1) + backend.restore_fk_check(_sess(session), 0) + + assert session.statements == [ "PRAGMA foreign_keys = ON", - "PRAGMA foreign_keys = 1", + "PRAGMA foreign_keys = OFF", ] +def test_sqlite_backend_normalize_fk_check_state(): + normalize = SQLiteBackend._normalize_fk_check_state + + assert normalize(1) == "ON" + assert normalize(0) == "OFF" + assert normalize("1") == "ON" + assert normalize("0") == "OFF" + assert normalize("ON") == "ON" + assert normalize("OFF") == "OFF" + assert normalize("on") == "ON" + assert normalize("off") == "OFF" + + try: + normalize(2) + except ValueError as exc: + assert "Invalid SQLite foreign_keys state" in str(exc) + else: + raise AssertionError("Expected ValueError for out-of-range int") + + try: + normalize("enabled") + except ValueError as exc: + assert "Invalid SQLite foreign_keys state" in str(exc) + else: + raise AssertionError("Expected ValueError for unrecognised string") + + def test_sqlite_backend_merge_replace_single_pk(): backend = SQLiteBackend() session = _FakeSession() - backend.merge_replace(_ComputedTableCls, _sess(session), "target_table", "_staging_target_table", ["id"]) + backend.merge_replace( + _ComputedTableCls, _sess(session), "target_table", "_staging_target_table", ["id"] + ) sql = session.statements[0] assert 'DELETE FROM "target_table"' in sql @@ -116,10 +165,12 @@ def test_sqlite_backend_merge_replace_composite_pk(): backend = SQLiteBackend() session = _FakeSession() - backend.merge_replace(_ComputedTableCls, _sess(session), "target_table", "_staging_target_table", ["id", "name"]) + backend.merge_replace( + _ComputedTableCls, _sess(session), "target_table", "_staging_target_table", ["id", "name"] + ) sql = session.statements[0] - assert 'WHERE EXISTS (' in sql + assert "WHERE EXISTS (" in sql assert '"target_table"."id" = "_staging_target_table"."id"' in sql assert '"target_table"."name" = "_staging_target_table"."name"' in sql @@ -139,7 +190,9 @@ def test_sqlite_backend_merge_upsert_excludes_computed_columns(): backend = SQLiteBackend() session = _FakeSession() - backend.merge_upsert(_ComputedTableCls, _sess(session), "target_table", "_staging_target_table", ["id"]) + backend.merge_upsert( + _ComputedTableCls, _sess(session), "target_table", "_staging_target_table", ["id"] + ) sql = session.statements[0] assert 'INSERT OR IGNORE INTO "target_table" ("id", "name")' in sql diff --git a/tests/tables/test_orm_table_base.py b/tests/tables/test_orm_table_base.py index 23c1dab..806ad10 100644 --- a/tests/tables/test_orm_table_base.py +++ b/tests/tables/test_orm_table_base.py @@ -1,11 +1,12 @@ -from sqlalchemy.exc import NoInspectionAvailable -import sqlalchemy.orm as so +import pytest import sqlalchemy as sa +import sqlalchemy.orm as so + from orm_loader.tables.orm_table import ORMTableBase -import pytest Base = so.declarative_base() + def test_pk_introspection(): class T(ORMTableBase, Base): __tablename__ = "t" @@ -13,8 +14,10 @@ class T(ORMTableBase, Base): assert T.pk_names() == ["id"] + def test_pk_missing_raises(): class T(ORMTableBase): __tablename__ = "t" - with pytest.raises(NoInspectionAvailable): + + with pytest.raises(TypeError, match="not a mapped ORM class"): T.pk_columns() From 14af889a88facb63764721828567c8eaedb77a43 Mon Sep 17 00:00:00 2001 From: Nico Loesch Date: Tue, 19 May 2026 00:49:46 +0000 Subject: [PATCH 11/21] Include FK check/normalisation in bulk_load_context --- src/orm_loader/backends/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/orm_loader/backends/base.py b/src/orm_loader/backends/base.py index 60ea2b0..7bb7d45 100644 --- a/src/orm_loader/backends/base.py +++ b/src/orm_loader/backends/base.py @@ -220,7 +220,8 @@ def bulk_load_context( try: if disable_fk: self._require_capability("supports_fk_toggle", "foreign key toggling") - previous_fk_state = self.disable_fk_check(session) + raw_state = self.disable_fk_check(session) + previous_fk_state = self._normalize_fk_check_state(raw_state) if no_autoflush: with session.no_autoflush: From 4f3006c0fe48a3509d7a2bf15e651e8262c9fc08 Mon Sep 17 00:00:00 2001 From: Nico Loesch Date: Tue, 19 May 2026 02:08:51 +0000 Subject: [PATCH 12/21] Recitify the pyproject.toml --- pyproject.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3b72c69..407130e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,8 +50,7 @@ dev = [ "mkdocs>=1.6.1", "mkdocs-mermaid2-plugin", "Pygments>=2.20.0", - "dotenv", - "ruff" + "python-dotenv" ] [tool.setuptools] From 204ba8f77cb2cc8f7ac5ec6161d62459fb7fe8f6 Mon Sep 17 00:00:00 2001 From: georgie Date: Tue, 19 May 2026 14:00:16 +1000 Subject: [PATCH 13/21] actioned dialect enum --- src/orm_loader/backends/__init__.py | 3 ++- src/orm_loader/backends/base.py | 18 ++++++++++++----- src/orm_loader/backends/postgres.py | 6 +++--- src/orm_loader/backends/resolve.py | 26 +++++++++++++++---------- src/orm_loader/backends/sqlite.py | 6 +++--- tests/backends/test_base_backend.py | 20 ++++++++++++------- tests/backends/test_postgres_backend.py | 5 +++-- tests/backends/test_sqlite_backend.py | 5 +++-- tests/conftest.py | 2 +- tests/docker-compose.yaml | 10 +++++++--- tests/pg_db.py | 2 +- uv.lock | 23 ++++++++++++++++++++++ 12 files changed, 88 insertions(+), 38 deletions(-) diff --git a/src/orm_loader/backends/__init__.py b/src/orm_loader/backends/__init__.py index d12fe23..3fa6888 100644 --- a/src/orm_loader/backends/__init__.py +++ b/src/orm_loader/backends/__init__.py @@ -1,11 +1,12 @@ from .postgres import PostgresBackend from .resolve import resolve_backend from .sqlite import SQLiteBackend -from .base import BackendCapabilities, DatabaseBackend +from .base import BackendCapabilities, DatabaseBackend, Dialect __all__ = [ "BackendCapabilities", "DatabaseBackend", + "Dialect", "PostgresBackend", "SQLiteBackend", "resolve_backend", diff --git a/src/orm_loader/backends/base.py b/src/orm_loader/backends/base.py index 7bb7d45..b5fc9a1 100644 --- a/src/orm_loader/backends/base.py +++ b/src/orm_loader/backends/base.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from contextlib import AbstractContextManager, contextmanager, nullcontext from dataclasses import dataclass +from enum import Enum from typing import TYPE_CHECKING, Type, Any, Iterator import sqlalchemy as sa @@ -29,6 +30,13 @@ class BackendCapabilities: supports_materialized_views: bool = False +class Dialect(str, Enum): + """Supported SQLAlchemy dialect names.""" + + SQLITE = "sqlite" + POSTGRESQL = "postgresql" + + class DatabaseBackend(ABC): """ Abstract base class for database-specific loader behavior. @@ -44,17 +52,17 @@ def name(self) -> str: @property @abstractmethod - def dialect_names(self) -> tuple[str, ...]: - """SQLAlchemy dialect names handled by this backend.""" + def dialect(self) -> Dialect: + """SQLAlchemy dialect handled by this backend.""" @property @abstractmethod def capabilities(self) -> BackendCapabilities: """Capability flags supported by this backend.""" - def supports_dialect(self, dialect_name: str) -> bool: - """Return ``True`` when the backend handles the given dialect name.""" - return dialect_name in self.dialect_names + def supports_dialect(self, dialect: Dialect) -> bool: + """Return ``True`` when the backend handles the given dialect.""" + return self.dialect == dialect @property def default_index_strategy(self) -> str: diff --git a/src/orm_loader/backends/postgres.py b/src/orm_loader/backends/postgres.py index 5ee0a38..96c76b1 100644 --- a/src/orm_loader/backends/postgres.py +++ b/src/orm_loader/backends/postgres.py @@ -6,7 +6,7 @@ import sqlalchemy.orm as so import sqlalchemy.event as sae -from .base import BackendCapabilities, DatabaseBackend +from .base import BackendCapabilities, DatabaseBackend, Dialect from ..loaders.loading_helpers import quick_load_pg if TYPE_CHECKING: @@ -24,8 +24,8 @@ def name(self) -> str: return "postgres" @property - def dialect_names(self) -> tuple[str, ...]: - return ("postgresql",) + def dialect(self) -> Dialect: + return Dialect.POSTGRESQL @property def capabilities(self) -> BackendCapabilities: diff --git a/src/orm_loader/backends/resolve.py b/src/orm_loader/backends/resolve.py index fc3d6b5..e3919c6 100644 --- a/src/orm_loader/backends/resolve.py +++ b/src/orm_loader/backends/resolve.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING import sqlalchemy.orm as so -from .base import DatabaseBackend +from .base import DatabaseBackend, Dialect from .postgres import PostgresBackend from .sqlite import SQLiteBackend @@ -17,24 +17,30 @@ ) -def _dialect_name(bindable: "so.Session | Engine | Connection",) -> str: +def _dialect(bindable: "so.Session | Engine | Connection") -> Dialect: if isinstance(bindable, so.Session): bind = bindable.get_bind() - return bind.dialect.name + dialect_name = bind.dialect.name + elif hasattr(bindable, "dialect"): + dialect_name = bindable.dialect.name + else: + raise TypeError(f"Unsupported bindable type: {type(bindable)!r}") - if hasattr(bindable, "dialect"): - return bindable.dialect.name - - raise TypeError(f"Unsupported bindable type: {type(bindable)!r}") + try: + return Dialect(dialect_name) + except ValueError as exc: + raise NotImplementedError( + f"Unsupported SQLAlchemy dialect '{dialect_name}'" + ) from exc def resolve_backend(bindable: "so.Session | Engine | Connection") -> DatabaseBackend: """ Resolve a concrete backend from a SQLAlchemy session, engine, or connection. """ - dialect_name = _dialect_name(bindable) + dialect = _dialect(bindable) for backend_type in _BACKEND_TYPES: backend = backend_type() - if backend.supports_dialect(dialect_name): + if backend.supports_dialect(dialect): return backend - raise NotImplementedError(f"No backend registered for dialect '{dialect_name}'") + raise NotImplementedError(f"No backend registered for dialect '{dialect.value}'") diff --git a/src/orm_loader/backends/sqlite.py b/src/orm_loader/backends/sqlite.py index c196955..0f16c3e 100644 --- a/src/orm_loader/backends/sqlite.py +++ b/src/orm_loader/backends/sqlite.py @@ -10,7 +10,7 @@ from sqlalchemy import event, text from sqlalchemy.exc import IntegrityError -from .base import BackendCapabilities, DatabaseBackend +from .base import BackendCapabilities, DatabaseBackend, Dialect if TYPE_CHECKING: from sqlalchemy.engine import Connection, Engine @@ -69,8 +69,8 @@ def name(self) -> str: return "sqlite" @property - def dialect_names(self) -> tuple[str, ...]: - return ("sqlite",) + def dialect(self) -> Dialect: + return Dialect.SQLITE @property def capabilities(self) -> BackendCapabilities: diff --git a/tests/backends/test_base_backend.py b/tests/backends/test_base_backend.py index 829fbc8..e1d2b44 100644 --- a/tests/backends/test_base_backend.py +++ b/tests/backends/test_base_backend.py @@ -12,7 +12,12 @@ import sqlalchemy.orm as so from sqlalchemy.engine import Connection, Engine -from orm_loader.backends import BackendCapabilities, DatabaseBackend, resolve_backend +from orm_loader.backends import ( + BackendCapabilities, + DatabaseBackend, + Dialect, + resolve_backend, +) if TYPE_CHECKING: from orm_loader.loaders.data_classes import LoaderContext @@ -40,8 +45,8 @@ def name(self) -> str: return "fake" @property - def dialect_names(self) -> tuple[str, ...]: - return ("fake",) + def dialect(self) -> Dialect: + return Dialect.SQLITE @property def capabilities(self) -> BackendCapabilities: @@ -142,11 +147,11 @@ def test_fake_backend_can_implement_contract(): backend = FakeBackend() assert backend.name == "fake" - assert backend.dialect_names == ("fake",) + assert backend.dialect == Dialect.SQLITE assert backend.capabilities.supports_fast_load is True assert backend.capabilities.supports_fk_toggle is True - assert backend.supports_dialect("fake") is True - assert backend.supports_dialect("sqlite") is False + assert backend.supports_dialect(Dialect.SQLITE) is True + assert backend.supports_dialect(Dialect.POSTGRESQL) is False assert backend.resolve_index_strategy("auto") == "drop_rebuild" assert backend.resolve_index_strategy("keep") == "keep" assert backend.load_staging_fast(cast("LoaderContext", None), "staging") is None @@ -240,6 +245,7 @@ def test_backends_package_exports(): assert backends.DatabaseBackend is DatabaseBackend assert backends.BackendCapabilities is BackendCapabilities + assert backends.Dialect is Dialect assert backends.resolve_backend is resolve_backend @@ -262,7 +268,7 @@ class _Unknown: class dialect: name = "unknown" - with pytest.raises(NotImplementedError, match="No backend registered"): + with pytest.raises(NotImplementedError, match="Unsupported SQLAlchemy dialect"): resolve_backend(cast(Engine, _Unknown())) diff --git a/tests/backends/test_postgres_backend.py b/tests/backends/test_postgres_backend.py index cee40fc..3cef920 100644 --- a/tests/backends/test_postgres_backend.py +++ b/tests/backends/test_postgres_backend.py @@ -8,7 +8,7 @@ from sqlalchemy.dialects import postgresql from sqlalchemy.engine import Connection, Engine -from orm_loader.backends import PostgresBackend +from orm_loader.backends import Dialect, PostgresBackend if TYPE_CHECKING: from orm_loader.tables.typing import CSVTableProtocol @@ -65,7 +65,8 @@ def test_postgres_backend_identity_and_capabilities(): backend = PostgresBackend() assert backend.name == "postgres" - assert backend.supports_dialect("postgresql") is True + assert backend.dialect == Dialect.POSTGRESQL + assert backend.supports_dialect(Dialect.POSTGRESQL) is True assert backend.capabilities.supports_fast_load is True assert backend.capabilities.supports_unlogged_staging is True assert backend.capabilities.supports_fk_toggle is True diff --git a/tests/backends/test_sqlite_backend.py b/tests/backends/test_sqlite_backend.py index 32c9a57..aa3163e 100644 --- a/tests/backends/test_sqlite_backend.py +++ b/tests/backends/test_sqlite_backend.py @@ -7,7 +7,7 @@ import sqlalchemy as sa import sqlalchemy.orm as so -from orm_loader.backends import SQLiteBackend +from orm_loader.backends import Dialect, SQLiteBackend from orm_loader.helpers.sqlite import attach_sqlite_bulk_load_pragmas if TYPE_CHECKING: @@ -53,7 +53,8 @@ def test_sqlite_backend_identity_and_capabilities(): backend = SQLiteBackend() assert backend.name == "sqlite" - assert backend.supports_dialect("sqlite") is True + assert backend.dialect == Dialect.SQLITE + assert backend.supports_dialect(Dialect.SQLITE) is True assert backend.capabilities.supports_fast_load is False assert backend.capabilities.supports_unlogged_staging is False assert backend.capabilities.supports_fk_toggle is True diff --git a/tests/conftest.py b/tests/conftest.py index ef42411..64a2531 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -32,7 +32,7 @@ def session(engine): POSTGRES_URL = os.getenv( "TEST_POSTGRES_URL", - "postgresql+psycopg://test:test@localhost:55432/test_db", + "postgresql+psycopg://test:test@localhost:55432/test", ) # Shown whenever Postgres is unreachable — centralised so every skip carries diff --git a/tests/docker-compose.yaml b/tests/docker-compose.yaml index 6328bfd..b8d6f8f 100644 --- a/tests/docker-compose.yaml +++ b/tests/docker-compose.yaml @@ -4,11 +4,15 @@ services: environment: POSTGRES_USER: test POSTGRES_PASSWORD: test - POSTGRES_DB: test_db + POSTGRES_DB: test ports: - "55432:5432" + volumes: + - postgres_orm_test_data:/var/lib/postgresql/data healthcheck: - test: ["CMD-SHELL", "pg_isready -U test"] + test: ["CMD-SHELL", "pg_isready -U test -d test"] interval: 2s timeout: 2s - retries: 10 \ No newline at end of file + retries: 10 +volumes: + postgres_orm_test_data: \ No newline at end of file diff --git a/tests/pg_db.py b/tests/pg_db.py index e383f9d..d0aacd5 100644 --- a/tests/pg_db.py +++ b/tests/pg_db.py @@ -5,7 +5,7 @@ from tests.models import Base -POSTGRES_URL = "postgresql+psycopg://test:test@localhost:55432/test_db" +POSTGRES_URL = "postgresql+psycopg://test:test@localhost:55432/test" @pytest.fixture(scope="session") def pg_engine(): diff --git a/uv.lock b/uv.lock index 6073ac0..f483117 100644 --- a/uv.lock +++ b/uv.lock @@ -134,6 +134,17 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, ] +[[package]] +name = "dotenv" +version = "0.9.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "python-dotenv" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/b2/b7/545d2c10c1fc15e48653c91efde329a790f2eecfbbf2bd16003b5db2bab0/dotenv-0.9.9-py2.py3-none-any.whl", hash = "sha256:29cf74a087b31dafdb5a446b6d7e11cbce8ed2741540e2339c69fbef92c94ce9", size = 1892, upload-time = "2025-02-19T22:15:01.647Z" }, +] + [[package]] name = "editorconfig" version = "0.17.1" @@ -629,6 +640,7 @@ dependencies = [ [package.optional-dependencies] dev = [ + { name = "dotenv" }, { name = "mkdocs" }, { name = "mkdocs-material" }, { name = "mkdocs-mermaid2-plugin" }, @@ -646,6 +658,7 @@ postgres = [ [package.metadata] requires-dist = [ { name = "chardet", specifier = ">=5.2.0" }, + { name = "dotenv", marker = "extra == 'dev'" }, { name = "mkdocs", marker = "extra == 'dev'", specifier = ">=1.6.1" }, { name = "mkdocs-material", marker = "extra == 'dev'", specifier = ">=9.7.1" }, { name = "mkdocs-mermaid2-plugin", marker = "extra == 'dev'" }, @@ -657,6 +670,7 @@ requires-dist = [ { name = "pygments", marker = "extra == 'dev'", specifier = ">=2.20.0" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=9.0.3" }, { name = "requests", marker = "extra == 'dev'", specifier = ">=2.33.0" }, + { name = "ruff", marker = "extra == 'dev'" }, { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.14.11" }, { name = "sqlalchemy", specifier = ">=2.0.45" }, ] @@ -905,6 +919,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" }, ] +[[package]] +name = "python-dotenv" +version = "1.2.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/82/ed/0301aeeac3e5353ef3d94b6ec08bbcabd04a72018415dcb29e588514bba8/python_dotenv-1.2.2.tar.gz", hash = "sha256:2c371a91fbd7ba082c2c1dc1f8bf89ca22564a087c2c287cd9b662adde799cf3", size = 50135, upload-time = "2026-03-01T16:00:26.196Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0b/d7/1959b9648791274998a9c3526f6d0ec8fd2233e4d4acce81bbae76b44b2a/python_dotenv-1.2.2-py3-none-any.whl", hash = "sha256:1d8214789a24de455a8b8bd8ae6fe3c6b69a5e3d64aa8a8e5d68e694bbcb285a", size = 22101, upload-time = "2026-03-01T16:00:25.09Z" }, +] + [[package]] name = "pytz" version = "2025.2" From 22c71145a56853ce1b2526856b0fe781030413d4 Mon Sep 17 00:00:00 2001 From: georgie Date: Tue, 19 May 2026 14:07:39 +1000 Subject: [PATCH 14/21] prefer base.registry.mappers to __subclass__ --- src/orm_loader/helpers/discovery.py | 7 ++++++- tests/helpers/test_discovery.py | 22 ++++++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) create mode 100644 tests/helpers/test_discovery.py diff --git a/src/orm_loader/helpers/discovery.py b/src/orm_loader/helpers/discovery.py index 69ec5b3..0e0333c 100644 --- a/src/orm_loader/helpers/discovery.py +++ b/src/orm_loader/helpers/discovery.py @@ -8,7 +8,12 @@ def get_model_by_tablename( base: type[ModelT] = Base, ) -> type[ModelT] | None: tablename = tablename.lower().strip() - for cls in base.__subclasses__(): + for mapper in base.registry.mappers: + cls = mapper.class_ + if not isinstance(cls, type): + continue + if not issubclass(cls, base): + continue if getattr(cls, "__tablename__", None) == tablename: return cls return None diff --git a/tests/helpers/test_discovery.py b/tests/helpers/test_discovery.py new file mode 100644 index 0000000..5bf16c8 --- /dev/null +++ b/tests/helpers/test_discovery.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +import sqlalchemy as sa + +from orm_loader.helpers.discovery import get_model_by_tablename +from orm_loader.helpers.metadata import Base + + +def test_get_model_by_tablename_supports_nested_inheritance() -> None: + class Child(Base): + __abstract__ = True + + class GrandChild(Child): + __tablename__ = "_discovery_grandchild" + id = sa.Column(sa.Integer, primary_key=True) + + resolved = get_model_by_tablename("_discovery_grandchild") + assert resolved is GrandChild + + +def test_get_model_by_tablename_returns_none_for_unknown_table() -> None: + assert get_model_by_tablename("_not_a_real_table_name_") is None From 004648836eb44cf1097fc454aa4fc30b41e0fd5f Mon Sep 17 00:00:00 2001 From: georgie Date: Tue, 19 May 2026 14:24:36 +1000 Subject: [PATCH 15/21] removed redundant helpers --- src/orm_loader/helpers/__init__.py | 2 -- src/orm_loader/helpers/sqlite.py | 15 --------------- 2 files changed, 17 deletions(-) diff --git a/src/orm_loader/helpers/__init__.py b/src/orm_loader/helpers/__init__.py index 01742a7..f2c49a1 100644 --- a/src/orm_loader/helpers/__init__.py +++ b/src/orm_loader/helpers/__init__.py @@ -3,7 +3,6 @@ from .bootstrap import bootstrap, create_db from .sqlite import ( attach_sqlite_bulk_load_pragmas, - enable_sqlite_foreign_keys, explain_sqlite_fk_error, restore_sqlite_journal_mode, ) @@ -20,7 +19,6 @@ "bootstrap", "create_db", "attach_sqlite_bulk_load_pragmas", - "enable_sqlite_foreign_keys", "explain_sqlite_fk_error", "restore_sqlite_journal_mode", "bulk_load_context", diff --git a/src/orm_loader/helpers/sqlite.py b/src/orm_loader/helpers/sqlite.py index 1c26091..a252748 100644 --- a/src/orm_loader/helpers/sqlite.py +++ b/src/orm_loader/helpers/sqlite.py @@ -6,21 +6,6 @@ from ..backends.sqlite import SQLiteBackend -def enable_sqlite_foreign_keys( - dbapi_connection: Any, - connection_record: Any, -) -> None: - """ - Apply the default SQLite connection settings used by orm-loader. - - This helper is kept for compatibility with older event-hook setups. - It delegates to ``SQLiteBackend.configure_dbapi_connection()``, - which enables foreign-key enforcement and may apply more than just - foreign-key settings. - """ - del connection_record - SQLiteBackend().configure_dbapi_connection(dbapi_connection) - def attach_sqlite_bulk_load_pragmas( engine: Engine, From fe7c2af39f2ea95e58877dea4674d864552dfb83 Mon Sep 17 00:00:00 2001 From: georgie Date: Tue, 19 May 2026 14:26:54 +1000 Subject: [PATCH 16/21] removed old-style bindings --- src/orm_loader/tables/loadable_table.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/src/orm_loader/tables/loadable_table.py b/src/orm_loader/tables/loadable_table.py index d5db296..12a44b7 100644 --- a/src/orm_loader/tables/loadable_table.py +++ b/src/orm_loader/tables/loadable_table.py @@ -2,6 +2,7 @@ import sqlalchemy as sa import sqlalchemy.orm as so import logging +from sqlalchemy.exc import InvalidRequestError, UnboundExecutionError from typing import Type, ClassVar, Optional, Any from pathlib import Path @@ -15,6 +16,14 @@ logger = logging.getLogger(__name__) +def _require_bind(session: so.Session) -> sa.Engine | sa.Connection: + """Return a bound connectable or raise a stable runtime error.""" + try: + return session.get_bind() + except (InvalidRequestError, UnboundExecutionError) as exc: + raise RuntimeError("Session is not bound to an engine") from exc + + """ CSV Loadable Table Mixins ================================== @@ -93,8 +102,7 @@ def create_staging_table( NotImplementedError If the database dialect is unsupported. """ - if session.bind is None: - raise RuntimeError("Session is not bound to an engine") + _require_bind(session) backend = resolve_backend(session) backend.create_staging_table(cls, session, cls.staging_tablename()) @@ -117,7 +125,7 @@ def manage_indices( resolved_index_strategy = backend.resolve_index_strategy(index_strategy) indices = list(cls.__table__.indexes) if resolved_index_strategy == "drop_rebuild" else [] - inspector = sa.inspect(session.bind) + inspector = sa.inspect(_require_bind(session)) assert inspector is not None, "Failed to create inspector for index management" if indices: @@ -177,10 +185,7 @@ def get_staging_table( sqlalchemy.Table The reflected staging table. """ - if session.bind is None: - raise RuntimeError("Session is not bound to an engine") - - engine = session.get_bind() + engine = _require_bind(session) inspector = sa.inspect(engine) staging_name = cls.staging_tablename() @@ -218,8 +223,7 @@ def load_staging( int Number of rows loaded into the staging table. """ - if loader_context.session.bind is None: - raise RuntimeError("Session is not bound to an engine") + _require_bind(loader_context.session) backend = resolve_backend(loader_context.session) total = 0 @@ -438,8 +442,7 @@ def merge_from_staging( staging = cls.staging_tablename() pk_cols = cls.pk_names() - if not session.bind: - raise RuntimeError("Session is not bound to an engine") + _require_bind(session) if merge_strategy == "replace": cls._merge_replace( session=session, From 2f057230e3f52ef7dd1f00c1c27ccbd025e0e4b8 Mon Sep 17 00:00:00 2001 From: georgie Date: Tue, 19 May 2026 14:31:15 +1000 Subject: [PATCH 17/21] context manager typing --- src/orm_loader/backends/postgres.py | 4 ++-- src/orm_loader/backends/sqlite.py | 3 ++- src/orm_loader/tables/loadable_table.py | 4 ++-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/orm_loader/backends/postgres.py b/src/orm_loader/backends/postgres.py index 96c76b1..b6174ec 100644 --- a/src/orm_loader/backends/postgres.py +++ b/src/orm_loader/backends/postgres.py @@ -1,6 +1,6 @@ from __future__ import annotations -from contextlib import contextmanager +from contextlib import contextmanager, AbstractContextManager from typing import TYPE_CHECKING, Any import sqlalchemy as sa import sqlalchemy.orm as so @@ -179,7 +179,7 @@ def merge_context( self, table_cls: type["CSVTableProtocol"], session: so.Session, - ): + ) -> AbstractContextManager[None]: return self.bulk_load_context(session, disable_fk=True, no_autoflush=False) diff --git a/src/orm_loader/backends/sqlite.py b/src/orm_loader/backends/sqlite.py index 0f16c3e..8862617 100644 --- a/src/orm_loader/backends/sqlite.py +++ b/src/orm_loader/backends/sqlite.py @@ -4,6 +4,7 @@ import sqlite3 from pathlib import Path from typing import TYPE_CHECKING, Any +from contextlib import AbstractContextManager import sqlalchemy as sa import sqlalchemy.orm as so @@ -207,7 +208,7 @@ def merge_context( self, table_cls: type["CSVTableProtocol"], session: so.Session, - ): + ) -> AbstractContextManager[None]: return self.bulk_load_context(session, disable_fk=True, no_autoflush=False) def create_materialized_view( diff --git a/src/orm_loader/tables/loadable_table.py b/src/orm_loader/tables/loadable_table.py index 12a44b7..deeae45 100644 --- a/src/orm_loader/tables/loadable_table.py +++ b/src/orm_loader/tables/loadable_table.py @@ -4,7 +4,7 @@ import logging from sqlalchemy.exc import InvalidRequestError, UnboundExecutionError -from typing import Type, ClassVar, Optional, Any +from typing import Type, ClassVar, Optional, Any, Iterator from pathlib import Path from contextlib import contextmanager @@ -112,7 +112,7 @@ def manage_indices( cls: Type['CSVTableProtocol'], session: so.Session, index_strategy: str = "auto", - ): + ) -> Iterator[None]: """ Manage non-primary-key indexes around a staged merge. From 4913bc18ad11c5f6003e7251c7e5db6739c02877 Mon Sep 17 00:00:00 2001 From: georgie Date: Tue, 19 May 2026 14:54:35 +1000 Subject: [PATCH 18/21] final pr cleanup --- src/orm_loader/backends/postgres.py | 6 +++-- src/orm_loader/backends/sqlite.py | 6 +++-- src/orm_loader/loaders/data_classes.py | 30 +---------------------- src/orm_loader/loaders/loading_helpers.py | 5 ++++ src/orm_loader/tables/loadable_table.py | 2 +- 5 files changed, 15 insertions(+), 34 deletions(-) diff --git a/src/orm_loader/backends/postgres.py b/src/orm_loader/backends/postgres.py index b6174ec..7b57c8b 100644 --- a/src/orm_loader/backends/postgres.py +++ b/src/orm_loader/backends/postgres.py @@ -98,13 +98,15 @@ def _normalize_fk_check_state(previous_state: str | int) -> str: def disable_fk_check(self, session: so.Session) -> str | int: previous_state = session.execute(sa.text("SHOW session_replication_role")).scalar() session.execute(sa.text("SET session_replication_role = 'replica'")) - assert isinstance(previous_state, str), "Expected PostgreSQL FK state to be a string" + if not isinstance(previous_state, str): + raise RuntimeError("Expected PostgreSQL FK state to be a string") return previous_state def enable_fk_check(self, session: so.Session) -> str | int: previous_state = session.execute(sa.text("SHOW session_replication_role")).scalar() session.execute(sa.text("SET session_replication_role = 'origin'")) - assert isinstance(previous_state, str), "Expected PostgreSQL FK state to be a string" + if not isinstance(previous_state, str): + raise RuntimeError("Expected PostgreSQL FK state to be a string") return previous_state def restore_fk_check( diff --git a/src/orm_loader/backends/sqlite.py b/src/orm_loader/backends/sqlite.py index 8862617..753abd4 100644 --- a/src/orm_loader/backends/sqlite.py +++ b/src/orm_loader/backends/sqlite.py @@ -113,13 +113,15 @@ def drop_staging_table( def disable_fk_check(self, session: so.Session) -> str | int: previous_state = session.execute(text("PRAGMA foreign_keys")).scalar() session.execute(text("PRAGMA foreign_keys = OFF")) - assert isinstance(previous_state, int), "Expected SQLite FK state to be an int" + if not isinstance(previous_state, int): + raise RuntimeError("Expected SQLite FK state to be an int") return previous_state def enable_fk_check(self, session: so.Session) -> str | int: previous_state = session.execute(text("PRAGMA foreign_keys")).scalar() session.execute(text("PRAGMA foreign_keys = ON")) - assert isinstance(previous_state, int), "Expected SQLite FK state to be an int" + if not isinstance(previous_state, int): + raise RuntimeError("Expected SQLite FK state to be an int") return previous_state def restore_fk_check( diff --git a/src/orm_loader/loaders/data_classes.py b/src/orm_loader/loaders/data_classes.py index 148cd0e..d7031fd 100644 --- a/src/orm_loader/loaders/data_classes.py +++ b/src/orm_loader/loaders/data_classes.py @@ -70,7 +70,7 @@ class LoaderContext: chunksize: int | None = None normalise: bool = True dedupe: bool = True - quote_mode: str = "csv" + quote_mode: str = "auto" class LoaderInterface: @@ -171,34 +171,6 @@ def dedupe(cls, data: pd.DataFrame | pa.Table, ctx: LoaderContext) -> Any: raise NotImplementedError - # vars_per_row = len(pk_cols) - # chunk_size = max(1, 10_000 // vars_per_row) - # existing_rows: list[tuple] = [] - - # for i in range(0, len(pk_tuples), chunk_size): - # chunk = pk_tuples[i : i + chunk_size] - - # rows = ( - # ctx.session.query(*pk_cols) - # .filter(sa.tuple_(*pk_cols).in_(chunk)) - # .all() - # ) - # existing_rows.extend(rows) - - # if not existing_rows: - # return df - - # existing = pd.DataFrame(existing_rows, columns=pk_names) - - # logger.warning(f"Dropping {len(existing)} rows from {ctx.tableclass.__tablename__} that already exist in the database") - # df = ( - # df.merge(existing, on=pk_names, how="left", indicator=True) - # .loc[lambda x: x["_merge"] == "left_only"] - # .drop(columns="_merge") - # ) - # return df - - @dataclass class ColumnCastingStats: """ diff --git a/src/orm_loader/loaders/loading_helpers.py b/src/orm_loader/loaders/loading_helpers.py index b1f41a9..cbc5ef7 100644 --- a/src/orm_loader/loaders/loading_helpers.py +++ b/src/orm_loader/loaders/loading_helpers.py @@ -2,6 +2,7 @@ from pathlib import Path import chardet import csv as _csv +import re import sqlalchemy as sa import sqlalchemy.orm as so import logging @@ -10,6 +11,8 @@ import pyarrow.csv as pv import io +_SAFE_ENCODING = re.compile(r'^[A-Za-z][A-Za-z0-9_-]*$') + logger = logging.getLogger(__name__) COPY_BLOCK_SIZE = 8192 @@ -238,6 +241,8 @@ def quick_load_pg( encoding = infer_encoding(path)['encoding'] or 'utf-8' + if not _SAFE_ENCODING.match(encoding): + raise ValueError(f"Unsafe encoding value from chardet: {encoding!r}") delimiter = infer_delim(path) if quote_mode == "auto": quote_mode = infer_quote_mode(path, delimiter=delimiter, encoding=encoding) diff --git a/src/orm_loader/tables/loadable_table.py b/src/orm_loader/tables/loadable_table.py index deeae45..1ce97fe 100644 --- a/src/orm_loader/tables/loadable_table.py +++ b/src/orm_loader/tables/loadable_table.py @@ -300,7 +300,7 @@ def load_csv( dedupe: bool = False, chunksize: int | None = None, merge_strategy: str = "replace", - quote_mode: str = "csv", + quote_mode: str = "auto", index_strategy: str = "auto", ) -> int: From f561a47934d96806823eebd634bfaabe44e07913 Mon Sep 17 00:00:00 2001 From: georgie Date: Tue, 19 May 2026 15:00:24 +1000 Subject: [PATCH 19/21] added test coverage for fk management --- tests/backends/test_postgres_backend.py | 49 ++++++++++++++++++++++- tests/backends/test_sqlite_backend.py | 52 ++++++++++++++++++++++++- 2 files changed, 99 insertions(+), 2 deletions(-) diff --git a/tests/backends/test_postgres_backend.py b/tests/backends/test_postgres_backend.py index 3cef920..6ac4467 100644 --- a/tests/backends/test_postgres_backend.py +++ b/tests/backends/test_postgres_backend.py @@ -25,7 +25,7 @@ class _ComputedTable: class _FakeSession: - def __init__(self, scalar_result="origin") -> None: + def __init__(self, scalar_result: str | int = "origin") -> None: self.statements: list[str] = [] self.scalar_result = scalar_result self.commits = 0 @@ -159,6 +159,53 @@ def test_postgres_backend_materialized_view_methods_emit_expected_sql(): assert any("REFRESH MATERIALIZED VIEW mv_test;" == sql for sql in session.statements) +def test_postgres_backend_normalize_fk_check_state(): + normalize = PostgresBackend._normalize_fk_check_state + + assert normalize("origin") == "origin" + assert normalize("local") == "local" + assert normalize("replica") == "replica" + assert normalize(" ORIGIN ") == "origin" + + try: + normalize("invalid_role") + except ValueError as exc: + assert "Invalid PostgreSQL session_replication_role" in str(exc) + else: + raise AssertionError("Expected ValueError for unrecognised role") + + try: + normalize(1) + except ValueError as exc: + assert "Postgres uses string roles" in str(exc) + else: + raise AssertionError("Expected ValueError for integer input") + + +def test_postgres_backend_disable_fk_raises_when_show_returns_non_string(): + backend = PostgresBackend() + session = _FakeSession(scalar_result=42) + + try: + backend.disable_fk_check(_sess(session)) + except RuntimeError as exc: + assert "Expected PostgreSQL FK state to be a string" in str(exc) + else: + raise AssertionError("Expected RuntimeError when SHOW returns a non-string") + + +def test_postgres_backend_enable_fk_raises_when_show_returns_non_string(): + backend = PostgresBackend() + session = _FakeSession(scalar_result=42) + + try: + backend.enable_fk_check(_sess(session)) + except RuntimeError as exc: + assert "Expected PostgreSQL FK state to be a string" in str(exc) + else: + raise AssertionError("Expected RuntimeError when SHOW returns a non-string") + + def test_postgres_backend_engine_with_replica_role_unregisters_listener(monkeypatch): backend = PostgresBackend() events: list[tuple[str, object, str]] = [] diff --git a/tests/backends/test_sqlite_backend.py b/tests/backends/test_sqlite_backend.py index aa3163e..5b1f060 100644 --- a/tests/backends/test_sqlite_backend.py +++ b/tests/backends/test_sqlite_backend.py @@ -25,7 +25,7 @@ class _ComputedTable: class _FakeSession: - def __init__(self, scalar_result=1) -> None: + def __init__(self, scalar_result: int | str = 1) -> None: self.statements: list[str] = [] self.scalar_result = scalar_result @@ -276,3 +276,53 @@ def test_sqlite_backend_rejects_invalid_journal_mode(): assert "Unsupported SQLite journal_mode" in str(exc) else: raise AssertionError("Expected invalid journal_mode to raise ValueError") + + +def test_sqlite_backend_disable_fk_raises_when_pragma_returns_non_int(): + backend = SQLiteBackend() + session = _FakeSession(scalar_result="not_an_int") + + try: + backend.disable_fk_check(_sess(session)) + except RuntimeError as exc: + assert "Expected SQLite FK state to be an int" in str(exc) + else: + raise AssertionError("Expected RuntimeError when PRAGMA returns a non-int") + + +def test_sqlite_backend_enable_fk_raises_when_pragma_returns_non_int(): + backend = SQLiteBackend() + session = _FakeSession(scalar_result="not_an_int") + + try: + backend.enable_fk_check(_sess(session)) + except RuntimeError as exc: + assert "Expected SQLite FK state to be an int" in str(exc) + else: + raise AssertionError("Expected RuntimeError when PRAGMA returns a non-int") + + +def test_sqlite_backend_restore_fk_accepts_string_values(): + backend = SQLiteBackend() + session = _FakeSession() + + backend.restore_fk_check(_sess(session), "ON") + backend.restore_fk_check(_sess(session), "OFF") + + assert session.statements == [ + "PRAGMA foreign_keys = ON", + "PRAGMA foreign_keys = OFF", + ] + + +def test_sqlite_backend_fk_toggle_round_trip(session): + backend = SQLiteBackend() + + session.execute(sa.text("PRAGMA foreign_keys = ON")) + assert session.execute(sa.text("PRAGMA foreign_keys")).scalar() == 1 + + previous = backend.disable_fk_check(session) + assert session.execute(sa.text("PRAGMA foreign_keys")).scalar() == 0 + + backend.restore_fk_check(session, previous) + assert session.execute(sa.text("PRAGMA foreign_keys")).scalar() == 1 From 2d3a466d280c31d44757847ea941e0b633ac7dde Mon Sep 17 00:00:00 2001 From: georgie Date: Tue, 19 May 2026 15:01:28 +1000 Subject: [PATCH 20/21] linting --- src/orm_loader/helpers/sqlite.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/orm_loader/helpers/sqlite.py b/src/orm_loader/helpers/sqlite.py index a252748..b569b90 100644 --- a/src/orm_loader/helpers/sqlite.py +++ b/src/orm_loader/helpers/sqlite.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import Any from sqlalchemy.engine import Engine from sqlalchemy.exc import IntegrityError From bd8565ba1603ddd427c8194e462cf82fd9e631b0 Mon Sep 17 00:00:00 2001 From: gkennos Date: Tue, 19 May 2026 16:18:24 +1000 Subject: [PATCH 21/21] typing and docstring --- src/orm_loader/helpers/bulk.py | 13 +++++++++---- src/orm_loader/helpers/sqlite.py | 3 +++ 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/orm_loader/helpers/bulk.py b/src/orm_loader/helpers/bulk.py index 7af521a..4be22b4 100644 --- a/src/orm_loader/helpers/bulk.py +++ b/src/orm_loader/helpers/bulk.py @@ -1,6 +1,7 @@ from contextlib import contextmanager from sqlalchemy import Engine from sqlalchemy.orm import Session +from typing import Iterator from ..backends.resolve import resolve_backend from .logging import get_logger @@ -10,14 +11,18 @@ def disable_fk_check(session: Session) -> str | int: """Disable foreign-key checks for the current session and return the previous state.""" previous_state = resolve_backend(session).disable_fk_check(session) logger.info("Disabled foreign key checks for bulk load.") - assert isinstance(previous_state, (str, int)), "Expected previous FK state to be str or int" + if not isinstance(previous_state, (str, int)): + logger.error(f"Unexpected FK state type: {type(previous_state)}. Expected str or int.") + raise TypeError(f"Expected previous FK state to be str or int, got {type(previous_state)}") return previous_state def enable_fk_check(session: Session) -> str | int: """Enable foreign-key checks for the current session and return the previous state.""" previous_state = resolve_backend(session).enable_fk_check(session) logger.info("Explicitly re-enabled foreign key checks.") - assert isinstance(previous_state, (str, int)), "Expected previous FK state to be str or int" + if not isinstance(previous_state, (str, int)): + logger.error(f"Unexpected FK state type: {type(previous_state)}. Expected str or int.") + raise TypeError(f"Expected previous FK state to be str or int, got {type(previous_state)}") return previous_state def restore_fk_check(session: Session, previous_state: str | int): @@ -31,7 +36,7 @@ def bulk_load_context( *, disable_fk: bool = True, no_autoflush: bool = True, -): +) -> Iterator[None]: """ Wrap a trusted bulk operation in backend-aware session settings. @@ -48,7 +53,7 @@ def bulk_load_context( @contextmanager -def engine_with_replica_role(engine: Engine): +def engine_with_replica_role(engine: Engine) -> Iterator[Engine]: """ Force ``session_replication_role=replica`` on PostgreSQL engine connections. diff --git a/src/orm_loader/helpers/sqlite.py b/src/orm_loader/helpers/sqlite.py index b569b90..ca8c134 100644 --- a/src/orm_loader/helpers/sqlite.py +++ b/src/orm_loader/helpers/sqlite.py @@ -19,6 +19,9 @@ def attach_sqlite_bulk_load_pragmas( The hook currently sets ``busy_timeout``, journal mode, and foreign-key enforcement, and can also enable deferred foreign-key checking for the connection. + + Note that this is the replacement for old ``enable_sqlite_foreign_keys()`` + workaround - this should be no longer needed. """ SQLiteBackend( busy_timeout_ms=busy_timeout_ms,