diff --git a/src/nene2/database/interfaces.py b/src/nene2/database/interfaces.py index 086fcdf..dd7eded 100644 --- a/src/nene2/database/interfaces.py +++ b/src/nene2/database/interfaces.py @@ -5,6 +5,7 @@ """ from abc import ABC, abstractmethod +from collections.abc import Callable from typing import Any @@ -31,7 +32,18 @@ def write(self, sql: str, params: dict[str, Any] | None = None) -> int: class DatabaseTransactionManagerInterface(ABC): - """Manage database transactions.""" + """Manage database transactions. + + High-level API: use transactional() — it commits on success and rolls back on exception. + Low-level API: begin() / commit() / rollback() for manual control. + """ + + @abstractmethod + def transactional[T]( + self, callback: Callable[[DatabaseQueryExecutorInterface], T] + ) -> T: + """Run callback inside a transaction; commit on success, rollback on exception.""" + ... @abstractmethod def begin(self) -> None: ... diff --git a/src/nene2/database/sqlalchemy_executor.py b/src/nene2/database/sqlalchemy_executor.py index 2d1c225..05a8367 100644 --- a/src/nene2/database/sqlalchemy_executor.py +++ b/src/nene2/database/sqlalchemy_executor.py @@ -3,9 +3,10 @@ Supports SQLite, MySQL, and PostgreSQL via SQLAlchemy's engine URL. """ +from collections.abc import Callable from typing import Any -from sqlalchemy import Engine, text +from sqlalchemy import Connection, Engine, text from sqlalchemy.exc import OperationalError from .exceptions import DatabaseConnectionException @@ -48,23 +49,67 @@ def write(self, sql: str, params: dict[str, Any] | None = None) -> int: raise DatabaseConnectionException(str(exc)) from exc +class _BoundQueryExecutor(DatabaseQueryExecutorInterface): + """Query executor bound to an existing connection (within a transaction).""" + + def __init__(self, conn: Connection) -> None: + self._conn = conn + + def fetch_all( + self, sql: str, params: dict[str, Any] | None = None + ) -> list[dict[str, Any]]: + result = self._conn.execute(text(sql), params or {}) + return [dict(row._mapping) for row in result] + + def fetch_one( + self, sql: str, params: dict[str, Any] | None = None + ) -> dict[str, Any] | None: + result = self._conn.execute(text(sql), params or {}) + row = result.fetchone() + return dict(row._mapping) if row else None + + def write(self, sql: str, params: dict[str, Any] | None = None) -> int: + result = self._conn.execute(text(sql), params or {}) + return result.lastrowid or result.rowcount + + class SqlAlchemyTransactionManager(DatabaseTransactionManagerInterface): - """Manage an explicit transaction on a single SQLAlchemy connection.""" + """Manage database transactions using SQLAlchemy. + + Use transactional() for the recommended callback-based API. + Use begin() / commit() / rollback() for explicit transaction control. + """ def __init__(self, engine: Engine) -> None: self._engine = engine - self._conn = engine.connect() + self._conn: Connection | None = None self._tx: Any = None + def transactional[T]( + self, callback: Callable[[DatabaseQueryExecutorInterface], T] + ) -> T: + try: + with self._engine.begin() as conn: + return callback(_BoundQueryExecutor(conn)) + except OperationalError as exc: + raise DatabaseConnectionException(str(exc)) from exc + def begin(self) -> None: + self._conn = self._engine.connect() self._tx = self._conn.begin() def commit(self) -> None: if self._tx is not None: self._tx.commit() self._tx = None + if self._conn is not None: + self._conn.close() + self._conn = None def rollback(self) -> None: if self._tx is not None: self._tx.rollback() self._tx = None + if self._conn is not None: + self._conn.close() + self._conn = None diff --git a/tests/nene2/database/__init__.py b/tests/nene2/database/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/nene2/database/test_transaction.py b/tests/nene2/database/test_transaction.py new file mode 100644 index 0000000..f7fdd55 --- /dev/null +++ b/tests/nene2/database/test_transaction.py @@ -0,0 +1,76 @@ +"""Tests for SqlAlchemyTransactionManager — transactional() and begin/commit/rollback.""" + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.pool import StaticPool + +from nene2.database import DatabaseQueryExecutorInterface, SqlAlchemyTransactionManager +from nene2.database.exceptions import DatabaseConnectionException + + +def _manager() -> SqlAlchemyTransactionManager: + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + manager = SqlAlchemyTransactionManager(engine) + manager.transactional( + lambda ex: ex.write( + "CREATE TABLE items (id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL)" + ) + ) + return manager + + +def test_transactional_commits_on_success() -> None: + mgr = _manager() + + def insert(ex: DatabaseQueryExecutorInterface) -> int: + return ex.write("INSERT INTO items (name) VALUES (:name)", {"name": "hello"}) + + mgr.transactional(insert) + rows = mgr.transactional(lambda ex: ex.fetch_all("SELECT * FROM items")) + assert len(rows) == 1 + assert rows[0]["name"] == "hello" + + +def test_transactional_rollback_on_exception() -> None: + mgr = _manager() + + def failing(ex: DatabaseQueryExecutorInterface) -> None: + ex.write("INSERT INTO items (name) VALUES (:name)", {"name": "will-rollback"}) + raise ValueError("intentional failure") + + with pytest.raises(ValueError): + mgr.transactional(failing) + + rows = mgr.transactional(lambda ex: ex.fetch_all("SELECT * FROM items")) + assert rows == [] + + +def test_transactional_returns_callback_value() -> None: + mgr = _manager() + mgr.transactional(lambda ex: ex.write("INSERT INTO items (name) VALUES ('x')")) + count = mgr.transactional( + lambda ex: ex.fetch_one("SELECT COUNT(*) AS cnt FROM items") + ) + assert count is not None + assert count["cnt"] == 1 + + +def test_begin_commit_workflow() -> None: + mgr = _manager() + mgr.begin() + mgr.transactional(lambda ex: ex.write("INSERT INTO items (name) VALUES ('low-level')")) + mgr.commit() + rows = mgr.transactional(lambda ex: ex.fetch_all("SELECT * FROM items")) + assert any(r["name"] == "low-level" for r in rows) + + +def test_begin_rollback_workflow() -> None: + mgr = _manager() + mgr.begin() + mgr.rollback() + rows = mgr.transactional(lambda ex: ex.fetch_all("SELECT * FROM items")) + assert rows == []