Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions src/example/comment/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,7 @@ async def update_comment(
errors.append(ValidationError("body", "Body must not be empty.", "required"))
if errors:
raise ValidationException(errors)
comment = update_use_case.execute(
UpdateCommentInput(comment_id=comment_id, body=body.body)
)
comment = update_use_case.execute(UpdateCommentInput(comment_id=comment_id, body=body.body))
return JSONResponse(_comment_dict(comment))

@router.delete("/{comment_id}", status_code=204)
Expand Down
4 changes: 1 addition & 3 deletions src/example/comment/sqlalchemy_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,7 @@ def update(self, comment_id: int, body: str) -> Comment | None:
return Comment(id=comment_id, note_id=row["note_id"], body=body)

def delete(self, comment_id: int) -> bool:
affected = self._executor.write(
"DELETE FROM comments WHERE id = :id", {"id": comment_id}
)
affected = self._executor.write("DELETE FROM comments WHERE id = :id", {"id": comment_id})
return affected > 0

def count_by_note(self, note_id: int) -> int:
Expand Down
1 change: 1 addition & 0 deletions src/example/note/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def make_note_router(
delete_use_case: DeleteNoteUseCase,
) -> APIRouter:
router = APIRouter(prefix="/notes", tags=["notes"])

@router.get("")
async def list_notes(request: Request) -> JSONResponse:
pagination = PaginationQueryParser.parse(request)
Expand Down
54 changes: 30 additions & 24 deletions src/example/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,33 @@
def ensure_schema(engine: Engine) -> None:
"""Create tables if they do not already exist (idempotent)."""
with engine.begin() as conn:
conn.execute(text(
"CREATE TABLE IF NOT EXISTS notes ("
"id INTEGER PRIMARY KEY AUTOINCREMENT,"
"title TEXT NOT NULL,"
"body TEXT NOT NULL,"
"created_at DATETIME DEFAULT CURRENT_TIMESTAMP,"
"updated_at DATETIME DEFAULT CURRENT_TIMESTAMP"
")"
))
conn.execute(text(
"CREATE TABLE IF NOT EXISTS tags ("
"id INTEGER PRIMARY KEY AUTOINCREMENT,"
"name TEXT NOT NULL UNIQUE,"
"created_at DATETIME DEFAULT CURRENT_TIMESTAMP"
")"
))
conn.execute(text(
"CREATE TABLE IF NOT EXISTS comments ("
"id INTEGER PRIMARY KEY AUTOINCREMENT,"
"note_id INTEGER NOT NULL,"
"body TEXT NOT NULL,"
"created_at DATETIME DEFAULT CURRENT_TIMESTAMP"
")"
))
conn.execute(
text(
"CREATE TABLE IF NOT EXISTS notes ("
"id INTEGER PRIMARY KEY AUTOINCREMENT,"
"title TEXT NOT NULL,"
"body TEXT NOT NULL,"
"created_at DATETIME DEFAULT CURRENT_TIMESTAMP,"
"updated_at DATETIME DEFAULT CURRENT_TIMESTAMP"
")"
)
)
conn.execute(
text(
"CREATE TABLE IF NOT EXISTS tags ("
"id INTEGER PRIMARY KEY AUTOINCREMENT,"
"name TEXT NOT NULL UNIQUE,"
"created_at DATETIME DEFAULT CURRENT_TIMESTAMP"
")"
)
)
conn.execute(
text(
"CREATE TABLE IF NOT EXISTS comments ("
"id INTEGER PRIMARY KEY AUTOINCREMENT,"
"note_id INTEGER NOT NULL,"
"body TEXT NOT NULL,"
"created_at DATETIME DEFAULT CURRENT_TIMESTAMP"
")"
)
)
1 change: 1 addition & 0 deletions src/example/tag/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def make_tag_router(
delete_use_case: DeleteTagUseCase,
) -> APIRouter:
router = APIRouter(prefix="/tags", tags=["tags"])

@router.get("")
async def list_tags(request: Request) -> JSONResponse:
pagination = PaginationQueryParser.parse(request)
Expand Down
2 changes: 1 addition & 1 deletion src/nene2/auth/bearer_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -
)
response.headers["WWW-Authenticate"] = _WWW_AUTH
return response
token = auth[len("Bearer "):]
token = auth[len("Bearer ") :]
try:
verified = self._verifier.verify(token)
except TokenVerificationException:
Expand Down
4 changes: 3 additions & 1 deletion src/nene2/config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,6 @@ def db_url(self) -> str:
port = self.db_port
if self.db_adapter == "mysql":
return f"mysql+pymysql://{self.db_user}:{password}@{self.db_host}:{port}/{self.db_name}"
return f"postgresql+psycopg2://{self.db_user}:{password}@{self.db_host}:{port}/{self.db_name}"
return (
f"postgresql+psycopg2://{self.db_user}:{password}@{self.db_host}:{port}/{self.db_name}"
)
8 changes: 2 additions & 6 deletions src/nene2/database/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@ class DatabaseQueryExecutorInterface(ABC):
"""Execute parameterised SQL queries against a database."""

@abstractmethod
def fetch_all(
self, sql: str, params: dict[str, Any] | None = None
) -> list[dict[str, Any]]: ...
def fetch_all(self, sql: str, params: dict[str, Any] | None = None) -> list[dict[str, Any]]: ...

@abstractmethod
def fetch_one(
Expand All @@ -39,9 +37,7 @@ class DatabaseTransactionManagerInterface(ABC):
"""

@abstractmethod
def transactional[T](
self, callback: Callable[[DatabaseQueryExecutorInterface], T]
) -> T:
def transactional[T](self, callback: Callable[[DatabaseQueryExecutorInterface], T]) -> T:
"""Run callback inside a transaction; commit on success, rollback on exception."""
...

Expand Down
20 changes: 5 additions & 15 deletions src/nene2/database/sqlalchemy_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,15 @@ class SqlAlchemyQueryExecutor(DatabaseQueryExecutorInterface):
def __init__(self, engine: Engine) -> None:
self._engine = engine

def fetch_all(
self, sql: str, params: dict[str, Any] | None = None
) -> list[dict[str, Any]]:
def fetch_all(self, sql: str, params: dict[str, Any] | None = None) -> list[dict[str, Any]]:
try:
with self._engine.connect() as conn:
result = conn.execute(text(sql), params or {})
return [dict(row._mapping) for row in result]
except OperationalError as exc:
raise DatabaseConnectionException(str(exc)) from exc

def fetch_one(
self, sql: str, params: dict[str, Any] | None = None
) -> dict[str, Any] | None:
def fetch_one(self, sql: str, params: dict[str, Any] | None = None) -> dict[str, Any] | None:
try:
with self._engine.connect() as conn:
result = conn.execute(text(sql), params or {})
Expand All @@ -55,15 +51,11 @@ class _BoundQueryExecutor(DatabaseQueryExecutorInterface):
def __init__(self, conn: Connection) -> None:
self._conn = conn

def fetch_all(
self, sql: str, params: dict[str, Any] | None = None
) -> list[dict[str, Any]]:
def fetch_all(self, sql: str, params: dict[str, Any] | None = None) -> list[dict[str, Any]]:
result = self._conn.execute(text(sql), params or {})
return [dict(row._mapping) for row in result]

def fetch_one(
self, sql: str, params: dict[str, Any] | None = None
) -> dict[str, Any] | None:
def fetch_one(self, sql: str, params: dict[str, Any] | None = None) -> dict[str, Any] | None:
result = self._conn.execute(text(sql), params or {})
row = result.fetchone()
return dict(row._mapping) if row else None
Expand All @@ -85,9 +77,7 @@ def __init__(self, engine: Engine) -> None:
self._conn: Connection | None = None
self._tx: Any = None

def transactional[T](
self, callback: Callable[[DatabaseQueryExecutorInterface], T]
) -> T:
def transactional[T](self, callback: Callable[[DatabaseQueryExecutorInterface], T]) -> T:
try:
with self._engine.begin() as conn:
return callback(_BoundQueryExecutor(conn))
Expand Down
16 changes: 4 additions & 12 deletions src/nene2/mcp/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,9 @@ class McpHttpClientProtocol(Protocol):

def get(self, base_url: str, path: str) -> McpHttpResponse: ...

def post(
self, base_url: str, path: str, body: dict[str, object]
) -> McpHttpResponse: ...
def post(self, base_url: str, path: str, body: dict[str, object]) -> McpHttpResponse: ...

def put(
self, base_url: str, path: str, body: dict[str, object]
) -> McpHttpResponse: ...
def put(self, base_url: str, path: str, body: dict[str, object]) -> McpHttpResponse: ...

def delete(self, base_url: str, path: str) -> McpHttpResponse: ...

Expand All @@ -64,14 +60,10 @@ def __init__(
def get(self, base_url: str, path: str) -> McpHttpResponse:
return self._request("GET", base_url, path, None)

def post(
self, base_url: str, path: str, body: dict[str, object]
) -> McpHttpResponse:
def post(self, base_url: str, path: str, body: dict[str, object]) -> McpHttpResponse:
return self._request("POST", base_url, path, body)

def put(
self, base_url: str, path: str, body: dict[str, object]
) -> McpHttpResponse:
def put(self, base_url: str, path: str, body: dict[str, object]) -> McpHttpResponse:
return self._request("PUT", base_url, path, body)

def delete(self, base_url: str, path: str) -> McpHttpResponse:
Expand Down
4 changes: 1 addition & 3 deletions tests/nene2/database/test_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ def failing(ex: DatabaseQueryExecutorInterface) -> None:
def test_transactional_returns_callback_value() -> None:
mgr = _manager()
mgr.transactional(lambda ex: ex.write("INSERT INTO items (name) VALUES ('x')"))
count = mgr.transactional(
lambda ex: ex.fetch_one("SELECT COUNT(*) AS cnt FROM items")
)
count = mgr.transactional(lambda ex: ex.fetch_one("SELECT COUNT(*) AS cnt FROM items"))
assert count is not None
assert count["cnt"] == 1

Expand Down
Loading