Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/nene2/database/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any


Expand All @@ -31,7 +32,18 @@ def write(self, sql: str, params: dict[str, Any] | None = None) -> int:


class DatabaseTransactionManagerInterface(ABC):
"""Manage database transactions."""
"""Manage database transactions.

High-level API: use transactional() — it commits on success and rolls back on exception.
Low-level API: begin() / commit() / rollback() for manual control.
"""

@abstractmethod
def transactional[T](
self, callback: Callable[[DatabaseQueryExecutorInterface], T]
) -> T:
"""Run callback inside a transaction; commit on success, rollback on exception."""
...

@abstractmethod
def begin(self) -> None: ...
Expand Down
51 changes: 48 additions & 3 deletions src/nene2/database/sqlalchemy_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
Supports SQLite, MySQL, and PostgreSQL via SQLAlchemy's engine URL.
"""

from collections.abc import Callable
from typing import Any

from sqlalchemy import Engine, text
from sqlalchemy import Connection, Engine, text
from sqlalchemy.exc import OperationalError

from .exceptions import DatabaseConnectionException
Expand Down Expand Up @@ -48,23 +49,67 @@ def write(self, sql: str, params: dict[str, Any] | None = None) -> int:
raise DatabaseConnectionException(str(exc)) from exc


class _BoundQueryExecutor(DatabaseQueryExecutorInterface):
"""Query executor bound to an existing connection (within a transaction)."""

def __init__(self, conn: Connection) -> None:
self._conn = conn

def fetch_all(
self, sql: str, params: dict[str, Any] | None = None
) -> list[dict[str, Any]]:
result = self._conn.execute(text(sql), params or {})
return [dict(row._mapping) for row in result]

def fetch_one(
self, sql: str, params: dict[str, Any] | None = None
) -> dict[str, Any] | None:
result = self._conn.execute(text(sql), params or {})
row = result.fetchone()
return dict(row._mapping) if row else None

def write(self, sql: str, params: dict[str, Any] | None = None) -> int:
result = self._conn.execute(text(sql), params or {})
return result.lastrowid or result.rowcount


class SqlAlchemyTransactionManager(DatabaseTransactionManagerInterface):
"""Manage an explicit transaction on a single SQLAlchemy connection."""
"""Manage database transactions using SQLAlchemy.

Use transactional() for the recommended callback-based API.
Use begin() / commit() / rollback() for explicit transaction control.
"""

def __init__(self, engine: Engine) -> None:
self._engine = engine
self._conn = engine.connect()
self._conn: Connection | None = None
self._tx: Any = None

def transactional[T](
self, callback: Callable[[DatabaseQueryExecutorInterface], T]
) -> T:
try:
with self._engine.begin() as conn:
return callback(_BoundQueryExecutor(conn))
except OperationalError as exc:
raise DatabaseConnectionException(str(exc)) from exc

def begin(self) -> None:
self._conn = self._engine.connect()
self._tx = self._conn.begin()

def commit(self) -> None:
if self._tx is not None:
self._tx.commit()
self._tx = None
if self._conn is not None:
self._conn.close()
self._conn = None

def rollback(self) -> None:
if self._tx is not None:
self._tx.rollback()
self._tx = None
if self._conn is not None:
self._conn.close()
self._conn = None
Empty file.
76 changes: 76 additions & 0 deletions tests/nene2/database/test_transaction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Tests for SqlAlchemyTransactionManager — transactional() and begin/commit/rollback."""

import pytest
from sqlalchemy import create_engine
from sqlalchemy.pool import StaticPool

from nene2.database import DatabaseQueryExecutorInterface, SqlAlchemyTransactionManager
from nene2.database.exceptions import DatabaseConnectionException


def _manager() -> SqlAlchemyTransactionManager:
engine = create_engine(
"sqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
manager = SqlAlchemyTransactionManager(engine)
manager.transactional(
lambda ex: ex.write(
"CREATE TABLE items (id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL)"
)
)
return manager


def test_transactional_commits_on_success() -> None:
mgr = _manager()

def insert(ex: DatabaseQueryExecutorInterface) -> int:
return ex.write("INSERT INTO items (name) VALUES (:name)", {"name": "hello"})

mgr.transactional(insert)
rows = mgr.transactional(lambda ex: ex.fetch_all("SELECT * FROM items"))
assert len(rows) == 1
assert rows[0]["name"] == "hello"


def test_transactional_rollback_on_exception() -> None:
mgr = _manager()

def failing(ex: DatabaseQueryExecutorInterface) -> None:
ex.write("INSERT INTO items (name) VALUES (:name)", {"name": "will-rollback"})
raise ValueError("intentional failure")

with pytest.raises(ValueError):
mgr.transactional(failing)

rows = mgr.transactional(lambda ex: ex.fetch_all("SELECT * FROM items"))
assert rows == []


def test_transactional_returns_callback_value() -> None:
mgr = _manager()
mgr.transactional(lambda ex: ex.write("INSERT INTO items (name) VALUES ('x')"))
count = mgr.transactional(
lambda ex: ex.fetch_one("SELECT COUNT(*) AS cnt FROM items")
)
assert count is not None
assert count["cnt"] == 1


def test_begin_commit_workflow() -> None:
mgr = _manager()
mgr.begin()
mgr.transactional(lambda ex: ex.write("INSERT INTO items (name) VALUES ('low-level')"))
mgr.commit()
rows = mgr.transactional(lambda ex: ex.fetch_all("SELECT * FROM items"))
assert any(r["name"] == "low-level" for r in rows)


def test_begin_rollback_workflow() -> None:
mgr = _manager()
mgr.begin()
mgr.rollback()
rows = mgr.transactional(lambda ex: ex.fetch_all("SELECT * FROM items"))
assert rows == []
Loading