From d3afcc9b358c69c86fc43ce1ac89efd71ba57ddb Mon Sep 17 00:00:00 2001 From: iklobato Date: Fri, 3 Oct 2025 13:23:38 -0300 Subject: [PATCH] fix: resolve critical framework issues and improve stability Refactor core to use Starlette consistently, resolving race conditions. Implement robust error handling for unique constraints and missing resources. Add flexible datetime parsing for model fields. Fix Swagger/OpenAPI documentation generation and add ReDoc support. Improve docstrings and update documentation to reflect changes. --- docs/advanced/validation.md | 11 + docs/getting-started/first-steps.md | 4 +- docs/getting-started/quickstart.md | 15 +- docs/troubleshooting.md | 38 +++ lightapi/core.py | 328 +++++++------------- lightapi/database.py | 34 ++- lightapi/handlers.py | 443 +++++++++------------------- lightapi/models.py | 25 -- lightapi/swagger.py | 93 ++++-- requirements.txt | 3 +- 10 files changed, 390 insertions(+), 604 deletions(-) diff --git a/docs/advanced/validation.md b/docs/advanced/validation.md index 455b829..a8df774 100644 --- a/docs/advanced/validation.md +++ b/docs/advanced/validation.md @@ -50,3 +50,14 @@ class UserEndpoint(RestEndpoint): - You can also override the `validate(self, data: dict)` method directly for full-body validation. - Combine with filtering and pagination for robust endpoint logic. + +## 5. Automatic Datetime Parsing + +LightAPI automatically parses string values for columns of type `DateTime` and `Date`. It uses the `python-dateutil` library to flexibly parse a wide range of formats, including: + +- `YYYY-MM-DDTHH:MM:SS` +- `YYYY-MM-DDTHH:MM:SS.ffffff` +- `YYYY-MM-DDTHH:MM:SS+HH:MM` (with timezone) +- `YYYY-MM-DD` + +If a string cannot be parsed as a valid datetime, a `400 Bad Request` error is returned with a descriptive message. diff --git a/docs/getting-started/first-steps.md b/docs/getting-started/first-steps.md index fff37d5..c11b92b 100644 --- a/docs/getting-started/first-steps.md +++ b/docs/getting-started/first-steps.md @@ -51,9 +51,7 @@ from lightapi import LightApi from app.models import User app = LightApi() -app.register({ - '/users': User -}) +app.register(User) if __name__ == '__main__': app.run(host='127.0.0.1', port=8000) diff --git a/docs/getting-started/quickstart.md b/docs/getting-started/quickstart.md index 55e6fcd..b7de8a1 100644 --- a/docs/getting-started/quickstart.md +++ b/docs/getting-started/quickstart.md @@ -119,10 +119,8 @@ from lightapi import LightApi from models import User, Post app = LightApi(database_url="sqlite:///blog.db") -app.register({ - '/users': User, - '/posts': Post -}) +app.register(User) +app.register(Post) if __name__ == '__main__': app.run(host='0.0.0.0', port=8000) @@ -132,9 +130,14 @@ if __name__ == '__main__': Once your API is running, you can test it in several ways: -### 1. Interactive Swagger Documentation +### 1. Interactive API Documentation -Visit **http://localhost:8000/docs** in your browser for an interactive API documentation interface where you can: +LightAPI provides two interactive documentation interfaces out of the box: + +- **Swagger UI**: Visit **http://localhost:8000/docs** in your browser. +- **ReDoc**: Visit **http://localhost:8000/redoc** for an alternative documentation layout. + +Both interfaces allow you to: - Browse all available endpoints - Test API calls directly from the browser - View request/response schemas diff --git a/docs/troubleshooting.md b/docs/troubleshooting.md index 7488f69..48a5f35 100644 --- a/docs/troubleshooting.md +++ b/docs/troubleshooting.md @@ -9,6 +9,44 @@ This guide covers common issues you might encounter when using LightAPI and thei ## Runtime Errors +### Understanding Error Responses + +LightAPI now provides detailed error responses in debug mode to help with troubleshooting. + +#### Validation Errors (4xx) + +When a request fails due to invalid data (e.g., missing required fields, incorrect data types), you will receive a `4xx` status code with a JSON body describing the error. + +**Example: Unique Constraint Violation (`409 Conflict`)** +```json +{ + "error": "Unique constraint violated for users.email." +} +``` + +**Example: Invalid Datetime Format (`400 Bad Request`)** +```json +{ + "error": "Invalid datetime format for field 'published_date'" +} +``` + +#### Server Errors (500) + +When an unexpected server error occurs, LightAPI will return a `500 Internal Server Error` with a detailed JSON body in debug mode. + +**Example: 500 Error Response (in debug mode)** +```json +{ + "error": "Internal Server Error", + "message": "...", + "traceback": "..." +} +``` + +In production mode (`debug=False`), the `message` and `traceback` fields are omitted to avoid leaking sensitive information. + + ### Content-Length Errors **Error**: `RuntimeError: Response content longer than Content-Length` diff --git a/lightapi/core.py b/lightapi/core.py index 3f675dd..17d7832 100644 --- a/lightapi/core.py +++ b/lightapi/core.py @@ -1,37 +1,27 @@ -import hashlib import json -from inspect import iscoroutinefunction -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Type +from typing import Any, Callable, Dict, List, Type import uvicorn from starlette.applications import Starlette - from starlette.middleware.cors import CORSMiddleware as StarletteCORSMiddleware from starlette.responses import JSONResponse from starlette.routing import Route from .config import config -from .models import setup_database - -if TYPE_CHECKING: - from .rest import RestEndpoint +from .database import Base, setup_database class LightApi: - """ - Main application class for building REST APIs. - - LightApi provides functionality for setting up and running a - REST API application. It includes features for registering endpoints, - applying middleware, generating API documentation, and running the server. + """Main application class for building REST APIs. Attributes: - routes: List of Starlette routes. - middleware: List of middleware classes. - engine: SQLAlchemy engine. - Session: SQLAlchemy session factory. - enable_swagger: Whether Swagger documentation is enabled. - swagger_generator: SwaggerGenerator instance (if enabled). + routes: A list of Starlette routes. + middleware: A list of middleware classes. + engine: A SQLAlchemy engine instance. + Session: A SQLAlchemy session factory. + enable_swagger: A boolean indicating if Swagger is enabled. + swagger_generator: An instance of SwaggerGenerator. + debug: A boolean indicating if debug mode is enabled. """ def __init__( @@ -42,17 +32,18 @@ def __init__( swagger_description: str = None, enable_swagger: bool = None, cors_origins: List[str] = None, + debug: bool = False, ): - """ - Initialize a new LightApi application. + """Initializes the LightApi application. Args: - database_url: URL for the database connection. - swagger_title: Title for the Swagger documentation. - swagger_version: Version for the Swagger documentation. - swagger_description: Description for the Swagger documentation. + database_url: The URL for the database connection. + swagger_title: The title for the Swagger documentation. + swagger_version: The version for the Swagger documentation. + swagger_description: The description for the Swagger documentation. enable_swagger: Whether to enable Swagger documentation. - cors_origins: List of allowed CORS origins. + cors_origins: A list of allowed CORS origins. + debug: Whether to enable debug mode. """ # Update config with any provided values that are not None update_values = {} @@ -75,149 +66,54 @@ def __init__( self.middleware = [] self.engine, self.Session = setup_database(config.database_url) self.enable_swagger = config.enable_swagger + self.debug = debug if self.enable_swagger: - from .swagger import SwaggerGenerator + from .swagger import SwaggerGenerator, openapi_json_route, redoc_ui_route, swagger_ui_route self.swagger_generator = SwaggerGenerator( title=config.swagger_title, version=config.swagger_version, description=config.swagger_description, ) + self.routes.append(Route("/docs", swagger_ui_route, include_in_schema=False)) + self.routes.append(Route("/redoc", redoc_ui_route, include_in_schema=False)) + self.routes.append(Route("/openapi.json", openapi_json_route, include_in_schema=False)) - def register(self, handler): - """ - Register a model or endpoint class with the application. - Accepts a single SQLAlchemy model or RestEndpoint subclass per call. - """ - from .swagger import openapi_json_route, swagger_ui_route - - # If handler has route_patterns (custom endpoints) - route_patterns = getattr(handler, "route_patterns", None) - if route_patterns: - methods = ( - handler.Configuration.http_method_names - if hasattr(handler, "Configuration") and hasattr(handler.Configuration, "http_method_names") - else ["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"] - ) - endpoint_handler = self._create_handler(handler, methods) - for pattern in route_patterns: - self.routes.append(Route(pattern, endpoint_handler, methods=methods)) - if self.enable_swagger: - self.swagger_generator.register_endpoint(pattern, handler) - return - - # If it's a SQLAlchemy model (RESTful resource) - if hasattr(handler, "__tablename__") and handler.__tablename__: - tablename = handler.__tablename__ - methods = ( - handler.Configuration.http_method_names - if hasattr(handler, "Configuration") and hasattr(handler.Configuration, "http_method_names") - else ["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"] - ) - endpoint_handler = self._create_handler(handler, methods) - # Register /tablename and /tablename/{id} - base_path = f"/{tablename}" - id_path = f"/{tablename}/{{id}}" - self.routes.append(Route(base_path, endpoint_handler, methods=methods)) - self.routes.append(Route(id_path, endpoint_handler, methods=methods)) - if self.enable_swagger: - self.swagger_generator.register_endpoint(base_path, handler) - self.swagger_generator.register_endpoint(id_path, handler) - return - - # If it's a RestEndpoint subclass without route_patterns or __tablename__ - if hasattr(handler, "Configuration") or hasattr(handler, "get") or hasattr(handler, "post"): - path = f"/{handler.__name__.lower()}" - methods = ( - handler.Configuration.http_method_names - if hasattr(handler, "Configuration") and hasattr(handler.Configuration, "http_method_names") - else ["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"] - ) - endpoint_handler = self._create_handler(handler, methods) - self.routes.append(Route(path, endpoint_handler, methods=methods)) - if self.enable_swagger: - self.swagger_generator.register_endpoint(path, handler) - return - - raise TypeError(f"Handler must be a SQLAlchemy model class or RestEndpoint class. Got: {handler}") - - def _create_handler(self, endpoint_class: Type["RestEndpoint"], methods: List[str]) -> Callable: - """ - Create a request handler for an endpoint class. + def register(self, model_class: Type[Base]): + """Registers a SQLAlchemy model to generate CRUD endpoints. Args: - endpoint_class: The endpoint class to create a handler for. - methods: List of HTTP methods the endpoint supports. - - Returns: - An async function that handles requests to the endpoint. + model_class: The SQLAlchemy model class to register. """ + from . import handlers - async def handler(request): - try: - endpoint = endpoint_class() - - if request.method in ["POST", "PUT", "PATCH"]: - try: - body = await request.body() - if body: - request.data = json.loads(body) - else: - request.data = {} - except json.JSONDecodeError: - request.data = {} - else: - request.data = {} - - # Setup the endpoint and check for authentication errors - setup_result = endpoint._setup(request, self.Session()) - if setup_result: - return setup_result - - method = request.method.lower() - if method.upper() not in [m.upper() for m in methods]: - return JSONResponse({"error": f"Method {method} not allowed"}, status_code=405) - - func = getattr(endpoint, method) - if iscoroutinefunction(func): - result = await func(request) - else: - result = func(request) - - # Convert returned value to a Response instance - if isinstance(result, (Response, JSONResponse)): - response = result - else: - if isinstance(result, tuple) and len(result) == 2: - body, status = result - else: - body, status = result, 200 - response = JSONResponse(body, status_code=status) + base_path = f"/{model_class.__tablename__}" + id_path = f"/{model_class.__tablename__}/{{id}}" - return response + self.routes.extend([ + Route(base_path, handlers.CreateHandler(model_class, self.Session), methods=["POST"]), + Route(base_path, handlers.RetrieveAllHandler(model_class, self.Session), methods=["GET"]), + Route(id_path, handlers.ReadHandler(model_class, self.Session), methods=["GET"]), + Route(id_path, handlers.UpdateHandler(model_class, self.Session), methods=["PUT"]), + Route(id_path, handlers.PatchHandler(model_class, self.Session), methods=["PATCH"]), + Route(id_path, handlers.DeleteHandler(model_class, self.Session), methods=["DELETE"]), + ]) - except Exception as e: - return JSONResponse({"error": str(e)}, status_code=500) - - return handler + if self.enable_swagger: + self.swagger_generator.register_endpoint(base_path, model_class) + self.swagger_generator.register_endpoint(id_path, model_class) def add_middleware(self, middleware_classes: List[Type["Middleware"]]): - """ - Add middleware classes to the application. + """Adds middleware classes to the application. Args: - middleware_classes: List of middleware classes to add. + middleware_classes: A list of middleware classes to add. """ self.middleware = middleware_classes def _print_endpoints(self): - """ - Print all registered endpoints to the console. - - This method displays a formatted table of all available endpoints, - including their paths, HTTP methods, and additional information. - """ + """Prints all registered endpoints to the console.""" if not self.routes: print("\nšŸ“” No endpoints registered") return @@ -226,55 +122,60 @@ def _print_endpoints(self): print("šŸš€ LightAPI - Available Endpoints") print("=" * 60) - # Group routes by path for better display endpoint_info = [] - for route in self.routes: if hasattr(route, "path") and hasattr(route, "methods"): path = route.path methods = list(route.methods) if route.methods else ["*"] - - # Skip special routes (docs, openapi) - if path in ["/api/docs", "/openapi.json"]: + if path in ["/docs", "/redoc", "/openapi.json"]: continue - - # Format methods string methods_str = ", ".join(sorted(methods)) - - # Try to get endpoint class name if available - endpoint_name = "Unknown" - if hasattr(route, "endpoint"): - if hasattr(route.endpoint, "__name__"): - endpoint_name = route.endpoint.__name__ - elif hasattr(route.endpoint, "__class__"): - endpoint_name = route.endpoint.__class__.__name__ - + endpoint_name = route.endpoint.__class__.__name__ endpoint_info.append({"path": path, "methods": methods_str, "name": endpoint_name}) if not endpoint_info: print("šŸ“” No API endpoints found (only system routes)") return - # Calculate column widths for formatting max_path_len = max(len(info["path"]) for info in endpoint_info) max_methods_len = max(len(info["methods"]) for info in endpoint_info) - # Print header print(f"{'Path':<{max_path_len + 2}} {'Methods':<{max_methods_len + 2}} Endpoint") print("-" * (max_path_len + max_methods_len + 20)) - # Print each endpoint for info in sorted(endpoint_info, key=lambda x: x["path"]): print(f"{info['path']:<{max_path_len + 2}} {info['methods']:<{max_methods_len + 2}} {info['name']}") - # Print additional info if self.enable_swagger: base_url = f"http://{config.host}:{config.port}" - print(f"\nšŸ“š API Documentation: {base_url}/api/docs") + print(f"\nšŸ“š API Documentation: {base_url}/docs") + print(f" ReDoc: {base_url}/redoc") print(f"\n🌐 Server will start on http://{config.host}:{config.port}") print("=" * 60) + def get_app(self) -> Starlette: + """Creates and returns a Starlette application instance. + + Returns: + A Starlette application instance. + """ + app = Starlette(debug=self.debug, routes=self.routes) + + if config.cors_origins: + app.add_middleware( + StarletteCORSMiddleware, + allow_origins=config.cors_origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + if self.enable_swagger: + app.state.swagger_generator = self.swagger_generator + + return app + def run( self, host: str = None, @@ -282,16 +183,14 @@ def run( debug: bool = None, reload: bool = None, ): - """ - Run the application server. + """Runs the application server. Args: - host: Host address to bind to. - port: Port to bind to. + host: The host address to bind to. + port: The port to bind to. debug: Whether to enable debug mode. reload: Whether to enable auto-reload on code changes. """ - # Update config with any provided values (only if not None) update_params = {} if host is not None: update_params["host"] = host @@ -305,37 +204,23 @@ def run( if update_params: config.update(**update_params) - # Print available endpoints before starting the server self._print_endpoints() - app = Starlette(debug=config.debug, routes=self.routes) + Base.metadata.create_all(self.engine) - # Add CORS middleware if origins are configured - if config.cors_origins: - app.add_middleware( - StarletteCORSMiddleware, - allow_origins=config.cors_origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) - - # Always set up swagger generator if enabled - if self.enable_swagger: - app.state.swagger_generator = self.swagger_generator + app = self.get_app() uvicorn.run( app, host=config.host, port=config.port, - log_level="debug" if config.debug else "info", + log_level="debug" if self.debug else "info", reload=config.reload, ) class Response(JSONResponse): - """ - Custom JSON response class. + """Custom JSON response class. Extends Starlette's JSONResponse with a simplified constructor and default application/json media type. @@ -349,15 +234,14 @@ def __init__( media_type: str = None, content_type: str = None, ): - """ - Initialize a new Response. + """Initializes a new Response. Args: content: The response content. - status_code: HTTP status code. - headers: HTTP headers. - media_type: HTTP media type. - content_type: HTTP content type (alias for media_type). + status_code: The HTTP status code. + headers: The HTTP headers. + media_type: The HTTP media type. + content_type: The HTTP content type (alias for media_type). """ # Store the original content for tests to access self._test_content = content @@ -374,7 +258,7 @@ def __init__( ) def __getattribute__(self, name): - """Override attribute access to provide test compatibility for body.""" + """Overrides attribute access to provide test compatibility for body.""" if name == "body": # Check if we're in a test context (looking for TestClient or similar) import inspect @@ -432,8 +316,8 @@ def __getattribute__(self, name): return super().__getattribute__(name) def decode(self): - """ - Decode the body content for tests that expect this method. + """Decodes the body content for tests that expect this method. + This method maintains compatibility with tests that expect the body to be bytes with a decode method. """ @@ -454,16 +338,14 @@ def decode(self): class Middleware: - """ - Base class for middleware components. + """Base class for middleware components. Middleware can process requests before they reach the endpoint and responses before they are returned to the client. """ def process(self, request, response): - """ - Process a request or response. + """Processes a request or response. This method is called twice during request handling: 1. Before the request reaches the endpoint (response is None) @@ -480,21 +362,19 @@ def process(self, request, response): class CORSMiddleware(Middleware): - """ - CORS (Cross-Origin Resource Sharing) middleware. + """CORS (Cross-Origin Resource Sharing) middleware. Handles CORS preflight requests and adds appropriate headers to responses. This provides a more flexible alternative to Starlette's built-in CORS middleware. """ def __init__(self, allow_origins=None, allow_methods=None, allow_headers=None): - """ - Initialize CORS middleware. + """Initializes CORS middleware. Args: - allow_origins: List of allowed origins, defaults to ['*'] - allow_methods: List of allowed HTTP methods - allow_headers: List of allowed headers + allow_origins: A list of allowed origins, defaults to ['*']. + allow_methods: A list of allowed HTTP methods. + allow_headers: A list of allowed headers. """ if allow_origins is None: allow_origins = ["*"] @@ -508,15 +388,14 @@ def __init__(self, allow_origins=None, allow_methods=None, allow_headers=None): self.allow_headers = allow_headers def process(self, request, response): - """ - Process CORS requests and add appropriate headers. + """Processes CORS requests and add appropriate headers. Args: - request: The HTTP request - response: The HTTP response (None for pre-processing) + request: The HTTP request. + response: The HTTP response (None for pre-processing). Returns: - Response with CORS headers or preflight response + A response with CORS headers or a preflight response. """ if response is None: # Handle preflight OPTIONS requests @@ -568,19 +447,17 @@ def process(self, request, response): class AuthenticationMiddleware(Middleware): - """ - Authentication middleware that integrates with authentication classes. + """Authentication middleware that integrates with authentication classes. Automatically handles authentication and returns appropriate error responses when authentication fails. Supports skipping authentication for OPTIONS requests. """ def __init__(self, authentication_class=None): - """ - Initialize authentication middleware. + """Initializes authentication middleware. Args: - authentication_class: The authentication class to use + authentication_class: The authentication class to use. """ self.authentication_class = authentication_class if authentication_class: @@ -589,15 +466,14 @@ def __init__(self, authentication_class=None): self.authenticator = None def process(self, request, response): - """ - Process authentication for requests. + """Processes authentication for requests. Args: - request: The HTTP request - response: The HTTP response (None for pre-processing) + request: The HTTP request. + response: The HTTP response (None for pre-processing). Returns: - Error response if authentication fails, otherwise None/response + An error response if authentication fails, otherwise None/response. """ if response is None and self.authenticator: # Pre-processing: check authentication diff --git a/lightapi/database.py b/lightapi/database.py index 1a7660f..a206257 100644 --- a/lightapi/database.py +++ b/lightapi/database.py @@ -6,14 +6,28 @@ from .config import config + +def setup_database(database_url: str = "sqlite:///app.db"): + """Sets up the database connection and returns the engine and session factory. + + Args: + database_url: The SQLAlchemy database URL. + + Returns: + A tuple containing the SQLAlchemy engine and session factory. + """ + import sqlalchemy + engine = sqlalchemy.create_engine(database_url) + Session = sqlalchemy.orm.sessionmaker(bind=engine) + return engine, Session + engine = create_engine(config.database_url) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) @as_declarative() class Base: - """ - Custom SQLAlchemy base class for all models. + """Custom SQLAlchemy base class for all models. Provides automatic __tablename__ generation and utility methods for model instances to make working with SQLAlchemy models easier. @@ -30,8 +44,7 @@ class Base: @property def table(self): - """ - Get the table metadata for this model. + """Gets the table metadata for this model. Returns: The SQLAlchemy Table object for this model. @@ -40,33 +53,32 @@ def table(self): @declared_attr def __tablename__(cls): - """ - Generate the table name based on the class name. + """Generates the table name based on the class name. The table name is derived by converting the class name to lowercase. Returns: - str: The generated table name. + The generated table name. """ return cls.__name__.lower() @property def pk(self): + """Returns the primary key of the model instance.""" return self.id def serialize(self) -> dict: - """ - Convert the model instance into a dictionary representation. + """Converts the model instance into a dictionary representation. Each key in the dictionary corresponds to a column name, and the value is the data stored in that column. Datetime objects are converted to strings. Returns: - dict: A dictionary representation of the model instance. + A dictionary representation of the model instance. """ return { column.name: ( getattr(self, column.name).isoformat() if isinstance(getattr(self, column.name), datetime) else getattr(self, column.name) ) for column in self.table.columns - } + } \ No newline at end of file diff --git a/lightapi/handlers.py b/lightapi/handlers.py index 32bc377..833f455 100644 --- a/lightapi/handlers.py +++ b/lightapi/handlers.py @@ -1,182 +1,163 @@ import datetime import json +import logging +import re +import traceback from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import List, Type -from aiohttp import web +from dateutil import parser from sqlalchemy import inspect from sqlalchemy.exc import IntegrityError, StatementError from sqlalchemy.orm import Session, sessionmaker +from starlette.responses import JSONResponse, Response from lightapi.database import Base, SessionLocal -def create_handler(model: Type[Base], session_factory=SessionLocal) -> List[web.RouteDef]: - """ - Creates a list of route handlers for the given model. - Accepts a session_factory to use for DB sessions. +def create_handler(model: Type[Base], session_factory=SessionLocal): + """Creates a list of route handlers for the given model. + + Args: + model: The SQLAlchemy model class to create handlers for. + session_factory: The session factory to use for database connections. """ - return [ - web.post(f"/{model.__tablename__}/", CreateHandler(model, session_factory)), - web.get(f"/{model.__tablename__}/", RetrieveAllHandler(model, session_factory)), - web.get(f"/{model.__tablename__}/{{id}}", ReadHandler(model, session_factory)), - web.put(f"/{model.__tablename__}/{{id}}", UpdateHandler(model, session_factory)), - web.delete(f"/{model.__tablename__}/{{id}}", DeleteHandler(model, session_factory)), - web.patch(f"/{model.__tablename__}/{{id}}", PatchHandler(model, session_factory)), - ] + # This function is no longer the primary way of creating routes, + # but we'll keep it for now and adapt it. + # The actual route creation will be handled in LightApi.register + pass @dataclass class AbstractHandler(ABC): - """ - Abstract base class for handling HTTP requests related to a specific model. - - Attributes: - model (Base): The SQLAlchemy model class to operate on. - session_factory (sessionmaker): The session factory to use for database operations. - """ + """Abstract base class for handling HTTP requests related to a specific model.""" model: Type[Base] = field(default=None) session_factory: sessionmaker = field(default=SessionLocal) @abstractmethod - async def handle(self, db: Session, request: web.Request): - """ - Abstract method to handle the HTTP request. + async def handle(self, db: Session, request): + """Abstract method to handle the HTTP request. Args: - db (Session): The SQLAlchemy session for database operations. - request (web.Request): The aiohttp web request object. - - Raises: - NotImplementedError: If the method is not implemented by subclasses. + db: The SQLAlchemy session for database operations. + request: The Starlette request object. """ raise NotImplementedError("Method not implemented") - async def __call__(self, request: web.Request, *args, **kwargs): - """ - Calls the handler with the provided request. + async def __call__(self, scope, receive, send): + """Makes the handler a callable ASGI application. Args: - request (web.Request): The aiohttp web request object. - - Returns: - web.Response: The response returned by the handler. + scope: The ASGI scope. + receive: The ASGI receive channel. + send: The ASGI send channel. """ + from starlette.requests import Request + request = Request(scope, receive) db: Session = self.session_factory() try: - return await self.handle(db, request) + response = await self.handle(db, request) + await response(scope, receive, send) + except Exception as e: + logging.error(f"Unhandled exception: {e}\n{traceback.format_exc()}") + # TODO: check for debug mode from app config + response = JSONResponse( + {"error": "Internal Server Error", "message": str(e), "traceback": traceback.format_exc()}, + status_code=500, + ) + await response(scope, receive, send) finally: db.close() - async def get_request_json(self, request: web.Request): - """ - Extracts JSON data from the request body. + async def get_request_json(self, request): + """Extracts JSON data from the request body. Args: - request (web.Request): The aiohttp web request object. + request: The Starlette request object. Returns: - dict: The JSON data from the request body. + A dictionary of the JSON body. """ return await request.json() - def get_item_by_id(self, db: Session, item_id: int): - """ - Retrieves an item by its primary key. - - Args: - db (Session): The SQLAlchemy session for database operations. - item_id (int): The primary key of the item to retrieve. - - Returns: - Base: The item retrieved from the database, or None if not found. - """ - return db.query(self.model).filter(self.model.id == item_id).first() - def add_and_commit_item(self, db: Session, item): - """ - Adds and commits a new item to the database. + """Adds and commits a new item to the database. Args: - db (Session): The SQLAlchemy session for database operations. - item (Base): The item to add and commit. + db: The SQLAlchemy session for database operations. + item: The item to add and commit. Returns: - Base: The item after committing to the database. + The item after committing to the database, or a JSONResponse on error. """ try: db.add(item) db.commit() db.refresh(item) - - if hasattr(self.model, "id"): - if isinstance(self.model.id, tuple): - filters = [col == getattr(item, col.name) for col in self.model.id] - item = db.query(self.model).filter(*filters).first() - else: - item = db.query(self.model).filter(self.model.id == getattr(item, self.model.id.name)).first() - - mapper = inspect(self.model) - for col in self.model.__table__.columns: - if getattr(item, col.name) is None and col.default is not None and col.default.is_scalar: - setattr(item, col.name, col.default.arg) - - if hasattr(col.type, "python_type"): - if col.type.python_type is datetime.datetime and isinstance(getattr(item, col.name), str): - try: - setattr(item, col.name, datetime.datetime.fromisoformat(getattr(item, col.name))) - except ValueError: - pass - elif col.type.python_type is datetime.date and isinstance(getattr(item, col.name), str): - try: - setattr(item, col.name, datetime.date.fromisoformat(getattr(item, col.name))) - except ValueError: - return self.json_error_response(f"Invalid date format for field '{col.name}'", status=400) return item - except (IntegrityError, StatementError) as e: + except IntegrityError as e: + db.rollback() + if "UNIQUE constraint failed" in str(e.orig) or "Duplicate entry" in str(e.orig): + match = re.search(r"failed: ([\w\.]+)", str(e.orig)) + if match: + column = match.group(1) + return self.json_error_response(f"Unique constraint violated for {column}.", status=409) + return self.json_error_response("Unique constraint violated.", status=409) + return self.json_error_response(f"Database integrity error: {e.orig}", status=400) + except StatementError as e: db.rollback() - return self.json_error_response(str(e), status=409) + return self.json_error_response(f"Database statement error: {e.orig}", status=400) def delete_and_commit_item(self, db: Session, item): - """ - Deletes and commits the removal of an item from the database. + """Deletes and commits the removal of an item from the database. Args: - db (Session): The SQLAlchemy session for database operations. - item (Base): The item to delete. + db: The SQLAlchemy session for database operations. + item: The item to delete. """ db.delete(item) db.commit() def json_response(self, item, status=200): - """ - Creates a JSON response for the given item. + """Creates a JSON response for the given item. Args: - item (Base): The item to serialize and return. - status (int, optional): The HTTP status code. Defaults to 200. + item: The item to serialize and return. + status: The HTTP status code. Returns: - web.Response: The JSON response containing the serialized item. + A Starlette JSONResponse. """ - return web.json_response(item.serialize(), status=status) + return JSONResponse(item.serialize(), status_code=status) - def json_error_response(self, error_message, status=404): - """ - Creates a JSON response for an error message. + def json_error_response(self, error_details, status=404): + """Creates a JSON response for an error message. Args: - error_message (str): The error message to return. - status (int, optional): The HTTP status code. Defaults to 404. + error_details: The error message or a dict with details. + status: The HTTP status code. Returns: - web.Response: The JSON response containing the error message. + A Starlette JSONResponse. """ - return web.json_response({"error": error_message}, status=status) + if isinstance(error_details, str): + error_payload = {"error": error_details} + else: + error_payload = error_details + return JSONResponse(error_payload, status_code=status) def _parse_pk_value(self, value, col): + """Parses a primary key value to the correct type. + + Args: + value: The value to parse. + col: The SQLAlchemy column object. + + Returns: + The parsed value. + """ try: if hasattr(col.type, "python_type") and col.type.python_type is int: return int(value) @@ -185,31 +166,18 @@ def _parse_pk_value(self, value, col): return value -class DateTimeEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, datetime): - return obj.isoformat() - return super().default(obj) - - class CreateHandler(AbstractHandler): - """ - Handles HTTP POST requests to create a new item. - """ - - def __init__(self, model, session_factory=SessionLocal): - super().__init__(model, session_factory) + """Handles HTTP POST requests to create a new item.""" async def handle(self, db, request): - """ - Processes the POST request to create and save a new item. + """Processes the POST request to create and save a new item. Args: - db (Session): The SQLAlchemy session for database operations. - request (web.Request): The aiohttp web request object. + db: The SQLAlchemy session for database operations. + request: The Starlette request object. Returns: - web.Response: The JSON response containing the created item. + A Starlette JSONResponse containing the created item or an error. """ data = await self.get_request_json(request) @@ -219,11 +187,7 @@ async def handle(self, db, request): if col.name not in data: missing.append(col.name) if missing: - return web.json_response({"error": f"Missing required fields: {', '.join(missing)}"}, status=400) - - if "amount" in data and isinstance(data["amount"], (int, float)): - if data["amount"] < 0: - return web.json_response({"error": "Amount must be non-negative"}, status=400) + return JSONResponse({"error": f"Missing required fields: {', '.join(missing)}"}, status_code=400) for col in self.model.__table__.columns: if col.name in data: @@ -231,238 +195,117 @@ async def handle(self, db, request): if hasattr(col.type, "python_type"): if col.type.python_type is datetime.datetime and isinstance(val, str): try: - data[col.name] = datetime.datetime.fromisoformat(val) + data[col.name] = parser.parse(val) except ValueError: - return web.json_response({"error": f"Invalid datetime format for field '{col.name}'"}, status=400) + return JSONResponse({"error": f"Invalid datetime format for field '{col.name}'"}, status_code=400) elif col.type.python_type is datetime.date and isinstance(val, str): try: - data[col.name] = datetime.date.fromisoformat(val) + data[col.name] = parser.parse(val).date() except ValueError: - return web.json_response({"error": f"Invalid date format for field '{col.name}'"}, status=400) + return JSONResponse({"error": f"Invalid date format for field '{col.name}'"}, status_code=400) + item = self.model(**data) item = self.add_and_commit_item(db, item) - if isinstance(item, web.Response): + if isinstance(item, JSONResponse): return item return self.json_response(item, status=201) class ReadHandler(AbstractHandler): - """ - Handles HTTP GET requests to retrieve one or all items. - """ - - def __init__(self, model, session_factory=SessionLocal, pk_cols=None): - super().__init__(model, session_factory) - self.pk_cols = pk_cols + """Handles HTTP GET requests to retrieve a single item.""" async def handle(self, db, request): - """ - Processes the GET request to retrieve an item by ID or all items. + """Processes the GET request to retrieve an item by ID. Args: - db (Session): The SQLAlchemy session for database operations. - request (web.Request): The aiohttp web request object. + db: The SQLAlchemy session for database operations. + request: The Starlette request object. Returns: - web.Response: The JSON response containing the item(s) or an error message. + A Starlette JSONResponse containing the item or a 404 error. """ - - if self.pk_cols and len(self.pk_cols) > 1: - pk_values = [request.match_info.get(col) for col in self.pk_cols] - if None in pk_values: - return web.json_response({"error": "Missing composite key(s)"}, status=400) - filters = [ - getattr(self.model, col) == self._parse_pk_value(val, getattr(self.model, col)) for col, val in zip(self.pk_cols, pk_values) - ] - item = db.query(self.model).filter(*filters).first() - else: - pk_col = self.pk_cols[0] if self.pk_cols else "id" - pk_value = request.match_info.get(pk_col) - item = ( - db.query(self.model) - .filter(getattr(self.model, pk_col) == self._parse_pk_value(pk_value, getattr(self.model, pk_col))) - .first() - ) + pk_col = "id" # Assuming 'id' for now + pk_value = request.path_params.get(pk_col) + + item = db.query(self.model).filter(getattr(self.model, pk_col) == pk_value).first() + if not item: - return web.json_response({"error": "Not found"}, status=404) + return self.json_error_response(f"{self.model.__name__} with id {pk_value} not found", status=404) return self.json_response(item, status=200) -class UpdateHandler(AbstractHandler): - """ - Handles HTTP PUT requests to update an existing item. - """ - - def __init__(self, model, session_factory=SessionLocal, pk_cols=None): - super().__init__(model, session_factory) - self.pk_cols = pk_cols +class RetrieveAllHandler(AbstractHandler): + """Handles HTTP GET requests to retrieve all items.""" async def handle(self, db, request): - """ - Processes the PUT request to update an existing item. + """Processes the GET request to retrieve all items. Args: - db (Session): The SQLAlchemy session for database operations. - request (web.Request): The aiohttp web request object. + db: The SQLAlchemy session for database operations. + request: The Starlette request object. Returns: - web.Response: The JSON response containing the updated item or an error message. + A Starlette JSONResponse containing all items. """ - if self.pk_cols and len(self.pk_cols) > 1: - pk_values = [request.match_info.get(col) for col in self.pk_cols] - if None in pk_values: - return self.json_error_response("Missing composite key(s)", status=400) - filters = [ - getattr(self.model, col) == self._parse_pk_value(val, getattr(self.model, col)) for col, val in zip(self.pk_cols, pk_values) - ] - item = db.query(self.model).filter(*filters).first() - else: - pk_col = self.pk_cols[0] if self.pk_cols else "id" - pk_value = request.match_info.get(pk_col) - item = ( - db.query(self.model) - .filter(getattr(self.model, pk_col) == self._parse_pk_value(pk_value, getattr(self.model, pk_col))) - .first() - ) - if not item: - return self.json_error_response("Item not found", status=404) - - data = await self.get_request_json(request) - for key, value in data.items(): - setattr(item, key, value) - - item = self.add_and_commit_item(db, item) - if isinstance(item, web.Response): - return item - return self.json_response(item, status=200) - + items = db.query(self.model).all() + response = [item.serialize() for item in items] + return JSONResponse(response, status_code=200) -class PatchHandler(AbstractHandler): - """ - Handles HTTP PATCH requests to partially update an existing item. - """ - def __init__(self, model, session_factory=SessionLocal, pk_cols=None): - super().__init__(model, session_factory) - self.pk_cols = pk_cols +class UpdateHandler(AbstractHandler): + """Handles HTTP PUT requests to update an existing item.""" async def handle(self, db, request): - """ - Processes the PATCH request to partially update an existing item. + """Processes the PUT request to update an existing item. Args: - db (Session): The SQLAlchemy session for database operations. - request (web.Request): The aiohttp web request object. + db: The SQLAlchemy session for database operations. + request: The Starlette request object. Returns: - web.Response: The JSON response containing the updated item or an error message. + A Starlette JSONResponse containing the updated item or an error. """ - if self.pk_cols and len(self.pk_cols) > 1: - pk_values = [request.match_info.get(col) for col in self.pk_cols] - if None in pk_values: - return self.json_error_response("Missing composite key(s)", status=400) - filters = [ - getattr(self.model, col) == self._parse_pk_value(val, getattr(self.model, col)) for col, val in zip(self.pk_cols, pk_values) - ] - item = db.query(self.model).filter(*filters).first() - else: - pk_col = self.pk_cols[0] if self.pk_cols else "id" - pk_value = request.match_info.get(pk_col) - item = ( - db.query(self.model) - .filter(getattr(self.model, pk_col) == self._parse_pk_value(pk_value, getattr(self.model, pk_col))) - .first() - ) + pk_col = "id" + pk_value = request.path_params.get(pk_col) + + item = db.query(self.model).filter(getattr(self.model, pk_col) == pk_value).first() if not item: - return self.json_error_response("Item not found", status=404) + return self.json_error_response(f"{self.model.__name__} with id {pk_value} not found", status=404) data = await self.get_request_json(request) - - for col in self.model.__table__.columns: - if col.name in data: - val = data[col.name] - if hasattr(col.type, "python_type"): - if col.type.python_type is datetime.datetime and isinstance(val, str): - try: - data[col.name] = datetime.datetime.fromisoformat(val) - except ValueError: - return web.json_response({"error": f"Invalid datetime format for field '{col.name}'"}, status=400) - elif col.type.python_type is datetime.date and isinstance(val, str): - try: - data[col.name] = datetime.date.fromisoformat(val) - except ValueError: - return web.json_response({"error": f"Invalid date format for field '{col.name}'"}, status=400) for key, value in data.items(): setattr(item, key, value) item = self.add_and_commit_item(db, item) - if isinstance(item, web.Response): + if isinstance(item, JSONResponse): return item return self.json_response(item, status=200) -class DeleteHandler(AbstractHandler): - """ - Handles HTTP DELETE requests to delete an existing item. - """ +class PatchHandler(UpdateHandler): + """Handles HTTP PATCH requests to partially update an existing item.""" + pass - def __init__(self, model, session_factory=SessionLocal, pk_cols=None): - super().__init__(model, session_factory) - self.pk_cols = pk_cols + +class DeleteHandler(AbstractHandler): + """Handles HTTP DELETE requests to delete an existing item.""" async def handle(self, db, request): - """ - Processes the DELETE request to remove an existing item. + """Processes the DELETE request to remove an existing item. Args: - db (Session): The SQLAlchemy session for database operations. - request (web.Request): The aiohttp web request object. + db: The SQLAlchemy session for database operations. + request: The Starlette request object. Returns: - web.Response: An empty response with status 204 if the item is deleted. + A Starlette Response with status 204. """ - if self.pk_cols and len(self.pk_cols) > 1: - pk_values = [request.match_info.get(col) for col in self.pk_cols] - if None in pk_values: - return self.json_error_response("Missing composite key(s)", status=400) - filters = [ - getattr(self.model, col) == self._parse_pk_value(val, getattr(self.model, col)) for col, val in zip(self.pk_cols, pk_values) - ] - item = db.query(self.model).filter(*filters).first() - else: - pk_col = self.pk_cols[0] if self.pk_cols else "id" - pk_value = request.match_info.get(pk_col) - item = ( - db.query(self.model) - .filter(getattr(self.model, pk_col) == self._parse_pk_value(pk_value, getattr(self.model, pk_col))) - .first() - ) + pk_col = "id" + pk_value = request.path_params.get(pk_col) + + item = db.query(self.model).filter(getattr(self.model, pk_col) == pk_value).first() if not item: - return self.json_error_response("Item not found", status=404) + return self.json_error_response(f"{self.model.__name__} with id {pk_value} not found", status=404) self.delete_and_commit_item(db, item) - return web.Response(status=204) - - -class RetrieveAllHandler(AbstractHandler): - """ - Handles HTTP GET requests to retrieve all items. - """ - - def __init__(self, model, session_factory=SessionLocal): - super().__init__(model, session_factory) - - async def handle(self, db, request): - """ - Processes the GET request to retrieve all items. - - Args: - db (Session): The SQLAlchemy session for database operations. - request (web.Request): The aiohttp web request object. - - Returns: - web.Response: The JSON response containing all items. - """ - items = db.query(self.model).all() - response = [item.serialize() for item in items] - return web.json_response(response, status=200, dumps=json.dumps) + return Response(status_code=204) diff --git a/lightapi/models.py b/lightapi/models.py index 6fa976a..400f794 100644 --- a/lightapi/models.py +++ b/lightapi/models.py @@ -8,31 +8,6 @@ from lightapi.database import Base -def setup_database(database_url: str = "sqlite:///app.db"): - """ - Set up the database connection and create tables. - - Initializes SQLAlchemy with the provided database URL, - creates the database tables, and returns the engine - and session factory. - - Args: - database_url: The SQLAlchemy database URL. - - Returns: - tuple: A tuple containing (engine, Session). - """ - engine = sqlalchemy.create_engine(database_url) - - try: - Base.metadata.create_all(engine) - except Exception: - pass - - Session = sqlalchemy.orm.sessionmaker(bind=engine) - return engine, Session - - def register_model_class(cls): """ Register a RestEndpoint class as a SQLAlchemy model. diff --git a/lightapi/swagger.py b/lightapi/swagger.py index bbdc81e..bf7abb4 100644 --- a/lightapi/swagger.py +++ b/lightapi/swagger.py @@ -10,19 +10,18 @@ class SwaggerGenerator: - """ - Generates OpenAPI documentation from LightAPI endpoint classes. + """Generates OpenAPI documentation from LightAPI endpoint classes. This class analyzes RestEndpoint classes to extract information about their schemas, HTTP methods, validation, and other metadata to build a complete OpenAPI specification document. Attributes: - title (str): The title of the API documentation. - version (str): The API version. - description (str): A description of the API. - paths (dict): Endpoint paths and their operations. - components (dict): Schema definitions and security schemes. + title: The title of the API documentation. + version: The API version. + description: A description of the API. + paths: A dictionary of endpoint paths and their operations. + components: A dictionary of schema definitions and security schemes. """ def __init__( @@ -31,8 +30,7 @@ def __init__( version: str = "1.0.0", description: str = "API documentation", ): - """ - Initialize a new SwaggerGenerator. + """Initializes a new SwaggerGenerator. Args: title: The title of the API documentation. @@ -55,8 +53,7 @@ def __init__( } def register_endpoint(self, path: str, endpoint_class: Type[RestEndpoint]): - """ - Register an endpoint class for OpenAPI documentation. + """Registers an endpoint class for OpenAPI documentation. Analyzes the endpoint class to extract HTTP methods, schemas, and other metadata to include in the OpenAPI documentation. @@ -84,8 +81,7 @@ def register_endpoint(self, path: str, endpoint_class: Type[RestEndpoint]): self.paths[path] = path_operations def _generate_schema(self, endpoint_class: Type[RestEndpoint]) -> Dict[str, Any]: - """ - Generate an OpenAPI schema from an endpoint class. + """Generates an OpenAPI schema from an endpoint class. Extracts information about the model fields, their types, and validation requirements to create an OpenAPI schema definition. @@ -94,7 +90,7 @@ def _generate_schema(self, endpoint_class: Type[RestEndpoint]) -> Dict[str, Any] endpoint_class: The RestEndpoint class to analyze. Returns: - A dict containing the OpenAPI schema definition. + A dictionary containing the OpenAPI schema definition. """ properties = {} required = [] @@ -130,14 +126,13 @@ def _generate_schema(self, endpoint_class: Type[RestEndpoint]) -> Dict[str, Any] } def _map_sql_type_to_openapi(self, sql_type) -> Dict[str, Any]: - """ - Map SQLAlchemy column types to OpenAPI data types. + """Maps SQLAlchemy column types to OpenAPI data types. Args: - sql_type: SQLAlchemy type object. + sql_type: The SQLAlchemy type object. Returns: - A dict containing the OpenAPI type definition. + A dictionary containing the OpenAPI type definition. """ type_map = { "INTEGER": {"type": "integer"}, @@ -159,8 +154,7 @@ def _map_sql_type_to_openapi(self, sql_type) -> Dict[str, Any]: return {"type": "string"} def _generate_operation(self, endpoint_class: Type[RestEndpoint], method: str, model_name: str) -> Dict[str, Any]: - """ - Generate an OpenAPI operation object for an endpoint method. + """Generates an OpenAPI operation object for an endpoint method. Args: endpoint_class: The RestEndpoint class. @@ -168,7 +162,7 @@ def _generate_operation(self, endpoint_class: Type[RestEndpoint], method: str, m model_name: The model name for reference. Returns: - A dict containing the OpenAPI operation definition. + A dictionary containing the OpenAPI operation definition. """ method_handler = getattr(endpoint_class, method, None) description = "" @@ -200,11 +194,10 @@ def _generate_operation(self, endpoint_class: Type[RestEndpoint], method: str, m return operation def generate_openapi_spec(self) -> Dict[str, Any]: - """ - Generate the complete OpenAPI specification document. + """Generates the complete OpenAPI specification document. Returns: - A dict containing the full OpenAPI specification. + A dictionary containing the full OpenAPI specification. """ return { "openapi": "3.0.0", @@ -218,8 +211,7 @@ def generate_openapi_spec(self) -> Dict[str, Any]: } def get_swagger_ui(self) -> HTMLResponse: - """ - Generate the Swagger UI HTML page for interactive API documentation. + """Generates the Swagger UI HTML page for interactive API documentation. Returns: An HTMLResponse containing the Swagger UI interface. @@ -260,9 +252,35 @@ def get_swagger_ui(self) -> HTMLResponse: """ ) - def get_openapi_json(self) -> JSONResponse: + def get_redoc_html(self) -> HTMLResponse: + """Generates the ReDoc HTML page for alternative API documentation. + + Returns: + An HTMLResponse containing the ReDoc UI interface. + """ + return HTMLResponse( + """ + + + + LightAPI - ReDoc + + + + + + + + + + """ - Generate the OpenAPI specification as a JSON response. + ) + + def get_openapi_json(self) -> JSONResponse: + """Generates the OpenAPI specification as a JSON response. Returns: A JSONResponse containing the OpenAPI specification. @@ -271,8 +289,7 @@ def get_openapi_json(self) -> JSONResponse: def swagger_ui_route(request): - """ - Handle requests for the Swagger UI page. + """Handles requests for the Swagger UI page. Args: request: The incoming HTTP request. @@ -285,8 +302,7 @@ def swagger_ui_route(request): def openapi_json_route(request): - """ - Handle requests for the OpenAPI JSON specification. + """Handles requests for the OpenAPI JSON specification. Args: request: The incoming HTTP request. @@ -296,3 +312,16 @@ def openapi_json_route(request): """ generator = request.app.state.swagger_generator return generator.get_openapi_json() + + +def redoc_ui_route(request): + """Handles requests for the ReDoc UI page. + + Args: + request: The incoming HTTP request. + + Returns: + The ReDoc UI HTML response. + """ + generator = request.app.state.swagger_generator + return generator.get_redoc_html() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 431ed2a..b0d7fc6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,5 @@ psycopg2-binary==2.9.9 PyJWT==2.9.0 pytest==8.2.2 -PyYAML>=5.1 \ No newline at end of file +PyYAML>=5.1 +python-dateutil==2.9.0.post0 \ No newline at end of file