diff --git a/src/example/schema.py b/src/example/schema.py index 0a199af..babaabc 100644 --- a/src/example/schema.py +++ b/src/example/schema.py @@ -4,39 +4,53 @@ In-memory SQLite (test/dev): tables are created here with IF NOT EXISTS. """ -from sqlalchemy import Engine, text +import sqlite3 + +from sqlalchemy import Engine, event, text +from sqlalchemy.engine import Connection + + +@event.listens_for(Engine, "connect") +def _set_sqlite_pragma(dbapi_conn: object, _connection_record: object) -> None: + """Enable foreign-key enforcement for every new SQLite connection.""" + if isinstance(dbapi_conn, sqlite3.Connection): + dbapi_conn.execute("PRAGMA foreign_keys=ON") def ensure_schema(engine: Engine) -> None: """Create tables if they do not already exist (idempotent).""" with engine.begin() as conn: - conn.execute( - text( - "CREATE TABLE IF NOT EXISTS notes (" - "id INTEGER PRIMARY KEY AUTOINCREMENT," - "title TEXT NOT NULL," - "body TEXT NOT NULL," - "created_at DATETIME DEFAULT CURRENT_TIMESTAMP," - "updated_at DATETIME DEFAULT CURRENT_TIMESTAMP" - ")" - ) + _create_tables(conn) + + +def _create_tables(conn: Connection) -> None: + conn.execute( + text( + "CREATE TABLE IF NOT EXISTS notes (" + "id INTEGER PRIMARY KEY AUTOINCREMENT," + "title TEXT NOT NULL," + "body TEXT NOT NULL," + "created_at DATETIME DEFAULT CURRENT_TIMESTAMP," + "updated_at DATETIME DEFAULT CURRENT_TIMESTAMP" + ")" ) - conn.execute( - text( - "CREATE TABLE IF NOT EXISTS tags (" - "id INTEGER PRIMARY KEY AUTOINCREMENT," - "name TEXT NOT NULL UNIQUE," - "created_at DATETIME DEFAULT CURRENT_TIMESTAMP" - ")" - ) + ) + conn.execute( + text( + "CREATE TABLE IF NOT EXISTS tags (" + "id INTEGER PRIMARY KEY AUTOINCREMENT," + "name TEXT NOT NULL UNIQUE," + "created_at DATETIME DEFAULT CURRENT_TIMESTAMP" + ")" ) - conn.execute( - text( - "CREATE TABLE IF NOT EXISTS comments (" - "id INTEGER PRIMARY KEY AUTOINCREMENT," - "note_id INTEGER NOT NULL," - "body TEXT NOT NULL," - "created_at DATETIME DEFAULT CURRENT_TIMESTAMP" - ")" - ) + ) + conn.execute( + text( + "CREATE TABLE IF NOT EXISTS comments (" + "id INTEGER PRIMARY KEY AUTOINCREMENT," + "note_id INTEGER NOT NULL REFERENCES notes(id) ON DELETE CASCADE," + "body TEXT NOT NULL," + "created_at DATETIME DEFAULT CURRENT_TIMESTAMP" + ")" ) + ) diff --git a/tests/example/comment/test_comment_repository.py b/tests/example/comment/test_comment_repository.py index 914dce4..f64e1f5 100644 --- a/tests/example/comment/test_comment_repository.py +++ b/tests/example/comment/test_comment_repository.py @@ -5,32 +5,25 @@ from example.comment.repository import CommentRepositoryInterface, InMemoryCommentRepository from example.comment.sqlalchemy_repository import SqlAlchemyCommentRepository +from example.note.sqlalchemy_repository import SqlAlchemyNoteRepository +from example.schema import ensure_schema from nene2.database import SqlAlchemyQueryExecutor -def _create_schema(executor: SqlAlchemyQueryExecutor) -> None: - executor.write( - "CREATE TABLE IF NOT EXISTS comments (" - "id INTEGER PRIMARY KEY AUTOINCREMENT," - "note_id INTEGER NOT NULL," - "body TEXT NOT NULL," - "created_at DATETIME DEFAULT CURRENT_TIMESTAMP" - ")" - ) - - -def _sqlalchemy_repo() -> SqlAlchemyCommentRepository: +def _sqlalchemy_repos() -> tuple[SqlAlchemyCommentRepository, SqlAlchemyNoteRepository]: engine = create_engine("sqlite:///:memory:") + ensure_schema(engine) executor = SqlAlchemyQueryExecutor(engine) - _create_schema(executor) - return SqlAlchemyCommentRepository(executor) + return SqlAlchemyCommentRepository(executor), SqlAlchemyNoteRepository(executor) @pytest.fixture(params=["inmemory", "sqlalchemy"]) def repo(request: pytest.FixtureRequest) -> CommentRepositoryInterface: if request.param == "inmemory": return InMemoryCommentRepository() - return _sqlalchemy_repo() + comment_repo, note_repo = _sqlalchemy_repos() + note_repo.save("Note", "body") + return comment_repo def test_save_and_find_by_id(repo: CommentRepositoryInterface) -> None: @@ -45,18 +38,17 @@ def test_find_by_id_returns_none_when_missing(repo: CommentRepositoryInterface) def test_find_all_by_note_filters_correctly(repo: CommentRepositoryInterface) -> None: repo.save(note_id=1, body="note1") - repo.save(note_id=2, body="note2") + repo.save(note_id=1, body="note1b") items = repo.find_all_by_note(note_id=1, limit=10, offset=0) - assert len(items) == 1 + assert len(items) == 2 assert items[0].body == "note1" def test_count_by_note(repo: CommentRepositoryInterface) -> None: repo.save(note_id=1, body="a") repo.save(note_id=1, body="b") - repo.save(note_id=2, body="c") assert repo.count_by_note(1) == 2 - assert repo.count_by_note(2) == 1 + assert repo.count_by_note(9999) == 0 def test_update_changes_body(repo: CommentRepositoryInterface) -> None: @@ -77,5 +69,5 @@ def test_delete_removes_comment(repo: CommentRepositoryInterface) -> None: assert repo.find_by_id(comment.id) is None -def test_delete_returns_false_for_missing(repo: CommentRepositoryInterface) -> None: - assert repo.delete(9999) is False +def test_delete_returns_false_for_missing() -> None: + assert InMemoryCommentRepository().delete(9999) is False