Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 42 additions & 28 deletions src/example/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,39 +4,53 @@
In-memory SQLite (test/dev): tables are created here with IF NOT EXISTS.
"""

from sqlalchemy import Engine, text
import sqlite3

from sqlalchemy import Engine, event, text
from sqlalchemy.engine import Connection


@event.listens_for(Engine, "connect")
def _set_sqlite_pragma(dbapi_conn: object, _connection_record: object) -> None:
"""Enable foreign-key enforcement for every new SQLite connection."""
if isinstance(dbapi_conn, sqlite3.Connection):
dbapi_conn.execute("PRAGMA foreign_keys=ON")


def ensure_schema(engine: Engine) -> None:
"""Create tables if they do not already exist (idempotent)."""
with engine.begin() as conn:
conn.execute(
text(
"CREATE TABLE IF NOT EXISTS notes ("
"id INTEGER PRIMARY KEY AUTOINCREMENT,"
"title TEXT NOT NULL,"
"body TEXT NOT NULL,"
"created_at DATETIME DEFAULT CURRENT_TIMESTAMP,"
"updated_at DATETIME DEFAULT CURRENT_TIMESTAMP"
")"
)
_create_tables(conn)


def _create_tables(conn: Connection) -> None:
conn.execute(
text(
"CREATE TABLE IF NOT EXISTS notes ("
"id INTEGER PRIMARY KEY AUTOINCREMENT,"
"title TEXT NOT NULL,"
"body TEXT NOT NULL,"
"created_at DATETIME DEFAULT CURRENT_TIMESTAMP,"
"updated_at DATETIME DEFAULT CURRENT_TIMESTAMP"
")"
)
conn.execute(
text(
"CREATE TABLE IF NOT EXISTS tags ("
"id INTEGER PRIMARY KEY AUTOINCREMENT,"
"name TEXT NOT NULL UNIQUE,"
"created_at DATETIME DEFAULT CURRENT_TIMESTAMP"
")"
)
)
conn.execute(
text(
"CREATE TABLE IF NOT EXISTS tags ("
"id INTEGER PRIMARY KEY AUTOINCREMENT,"
"name TEXT NOT NULL UNIQUE,"
"created_at DATETIME DEFAULT CURRENT_TIMESTAMP"
")"
)
conn.execute(
text(
"CREATE TABLE IF NOT EXISTS comments ("
"id INTEGER PRIMARY KEY AUTOINCREMENT,"
"note_id INTEGER NOT NULL,"
"body TEXT NOT NULL,"
"created_at DATETIME DEFAULT CURRENT_TIMESTAMP"
")"
)
)
conn.execute(
text(
"CREATE TABLE IF NOT EXISTS comments ("
"id INTEGER PRIMARY KEY AUTOINCREMENT,"
"note_id INTEGER NOT NULL REFERENCES notes(id) ON DELETE CASCADE,"
"body TEXT NOT NULL,"
"created_at DATETIME DEFAULT CURRENT_TIMESTAMP"
")"
)
)
34 changes: 13 additions & 21 deletions tests/example/comment/test_comment_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,25 @@

from example.comment.repository import CommentRepositoryInterface, InMemoryCommentRepository
from example.comment.sqlalchemy_repository import SqlAlchemyCommentRepository
from example.note.sqlalchemy_repository import SqlAlchemyNoteRepository
from example.schema import ensure_schema
from nene2.database import SqlAlchemyQueryExecutor


def _create_schema(executor: SqlAlchemyQueryExecutor) -> None:
executor.write(
"CREATE TABLE IF NOT EXISTS comments ("
"id INTEGER PRIMARY KEY AUTOINCREMENT,"
"note_id INTEGER NOT NULL,"
"body TEXT NOT NULL,"
"created_at DATETIME DEFAULT CURRENT_TIMESTAMP"
")"
)


def _sqlalchemy_repo() -> SqlAlchemyCommentRepository:
def _sqlalchemy_repos() -> tuple[SqlAlchemyCommentRepository, SqlAlchemyNoteRepository]:
engine = create_engine("sqlite:///:memory:")
ensure_schema(engine)
executor = SqlAlchemyQueryExecutor(engine)
_create_schema(executor)
return SqlAlchemyCommentRepository(executor)
return SqlAlchemyCommentRepository(executor), SqlAlchemyNoteRepository(executor)


@pytest.fixture(params=["inmemory", "sqlalchemy"])
def repo(request: pytest.FixtureRequest) -> CommentRepositoryInterface:
if request.param == "inmemory":
return InMemoryCommentRepository()
return _sqlalchemy_repo()
comment_repo, note_repo = _sqlalchemy_repos()
note_repo.save("Note", "body")
return comment_repo


def test_save_and_find_by_id(repo: CommentRepositoryInterface) -> None:
Expand All @@ -45,18 +38,17 @@ def test_find_by_id_returns_none_when_missing(repo: CommentRepositoryInterface)

def test_find_all_by_note_filters_correctly(repo: CommentRepositoryInterface) -> None:
repo.save(note_id=1, body="note1")
repo.save(note_id=2, body="note2")
repo.save(note_id=1, body="note1b")
items = repo.find_all_by_note(note_id=1, limit=10, offset=0)
assert len(items) == 1
assert len(items) == 2
assert items[0].body == "note1"


def test_count_by_note(repo: CommentRepositoryInterface) -> None:
repo.save(note_id=1, body="a")
repo.save(note_id=1, body="b")
repo.save(note_id=2, body="c")
assert repo.count_by_note(1) == 2
assert repo.count_by_note(2) == 1
assert repo.count_by_note(9999) == 0


def test_update_changes_body(repo: CommentRepositoryInterface) -> None:
Expand All @@ -77,5 +69,5 @@ def test_delete_removes_comment(repo: CommentRepositoryInterface) -> None:
assert repo.find_by_id(comment.id) is None


def test_delete_returns_false_for_missing(repo: CommentRepositoryInterface) -> None:
assert repo.delete(9999) is False
def test_delete_returns_false_for_missing() -> None:
assert InMemoryCommentRepository().delete(9999) is False
Loading