diff --git a/docs/advanced/validation.md b/docs/advanced/validation.md index 455b829..a8df774 100644 --- a/docs/advanced/validation.md +++ b/docs/advanced/validation.md @@ -50,3 +50,14 @@ class UserEndpoint(RestEndpoint): - You can also override the `validate(self, data: dict)` method directly for full-body validation. - Combine with filtering and pagination for robust endpoint logic. + +## 5. Automatic Datetime Parsing + +LightAPI automatically parses string values for columns of type `DateTime` and `Date`. It uses the `python-dateutil` library to flexibly parse a wide range of formats, including: + +- `YYYY-MM-DDTHH:MM:SS` +- `YYYY-MM-DDTHH:MM:SS.ffffff` +- `YYYY-MM-DDTHH:MM:SS+HH:MM` (with timezone) +- `YYYY-MM-DD` + +If a string cannot be parsed as a valid datetime, a `400 Bad Request` error is returned with a descriptive message. diff --git a/docs/getting-started/first-steps.md b/docs/getting-started/first-steps.md index fff37d5..c11b92b 100644 --- a/docs/getting-started/first-steps.md +++ b/docs/getting-started/first-steps.md @@ -51,9 +51,7 @@ from lightapi import LightApi from app.models import User app = LightApi() -app.register({ - '/users': User -}) +app.register(User) if __name__ == '__main__': app.run(host='127.0.0.1', port=8000) diff --git a/docs/getting-started/quickstart.md b/docs/getting-started/quickstart.md index 55e6fcd..b7de8a1 100644 --- a/docs/getting-started/quickstart.md +++ b/docs/getting-started/quickstart.md @@ -119,10 +119,8 @@ from lightapi import LightApi from models import User, Post app = LightApi(database_url="sqlite:///blog.db") -app.register({ - '/users': User, - '/posts': Post -}) +app.register(User) +app.register(Post) if __name__ == '__main__': app.run(host='0.0.0.0', port=8000) @@ -132,9 +130,14 @@ if __name__ == '__main__': Once your API is running, you can test it in several ways: -### 1. Interactive Swagger Documentation +### 1. Interactive API Documentation -Visit **http://localhost:8000/docs** in your browser for an interactive API documentation interface where you can: +LightAPI provides two interactive documentation interfaces out of the box: + +- **Swagger UI**: Visit **http://localhost:8000/docs** in your browser. +- **ReDoc**: Visit **http://localhost:8000/redoc** for an alternative documentation layout. + +Both interfaces allow you to: - Browse all available endpoints - Test API calls directly from the browser - View request/response schemas diff --git a/docs/troubleshooting.md b/docs/troubleshooting.md index 7488f69..48a5f35 100644 --- a/docs/troubleshooting.md +++ b/docs/troubleshooting.md @@ -9,6 +9,44 @@ This guide covers common issues you might encounter when using LightAPI and thei ## Runtime Errors +### Understanding Error Responses + +LightAPI now provides detailed error responses in debug mode to help with troubleshooting. + +#### Validation Errors (4xx) + +When a request fails due to invalid data (e.g., missing required fields, incorrect data types), you will receive a `4xx` status code with a JSON body describing the error. + +**Example: Unique Constraint Violation (`409 Conflict`)** +```json +{ + "error": "Unique constraint violated for users.email." +} +``` + +**Example: Invalid Datetime Format (`400 Bad Request`)** +```json +{ + "error": "Invalid datetime format for field 'published_date'" +} +``` + +#### Server Errors (500) + +When an unexpected server error occurs, LightAPI will return a `500 Internal Server Error` with a detailed JSON body in debug mode. + +**Example: 500 Error Response (in debug mode)** +```json +{ + "error": "Internal Server Error", + "message": "...", + "traceback": "..." +} +``` + +In production mode (`debug=False`), the `message` and `traceback` fields are omitted to avoid leaking sensitive information. + + ### Content-Length Errors **Error**: `RuntimeError: Response content longer than Content-Length` diff --git a/lightapi/core.py b/lightapi/core.py index 3f675dd..17d7832 100644 --- a/lightapi/core.py +++ b/lightapi/core.py @@ -1,37 +1,27 @@ -import hashlib import json -from inspect import iscoroutinefunction -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Type +from typing import Any, Callable, Dict, List, Type import uvicorn from starlette.applications import Starlette - from starlette.middleware.cors import CORSMiddleware as StarletteCORSMiddleware from starlette.responses import JSONResponse from starlette.routing import Route from .config import config -from .models import setup_database - -if TYPE_CHECKING: - from .rest import RestEndpoint +from .database import Base, setup_database class LightApi: - """ - Main application class for building REST APIs. - - LightApi provides functionality for setting up and running a - REST API application. It includes features for registering endpoints, - applying middleware, generating API documentation, and running the server. + """Main application class for building REST APIs. Attributes: - routes: List of Starlette routes. - middleware: List of middleware classes. - engine: SQLAlchemy engine. - Session: SQLAlchemy session factory. - enable_swagger: Whether Swagger documentation is enabled. - swagger_generator: SwaggerGenerator instance (if enabled). + routes: A list of Starlette routes. + middleware: A list of middleware classes. + engine: A SQLAlchemy engine instance. + Session: A SQLAlchemy session factory. + enable_swagger: A boolean indicating if Swagger is enabled. + swagger_generator: An instance of SwaggerGenerator. + debug: A boolean indicating if debug mode is enabled. """ def __init__( @@ -42,17 +32,18 @@ def __init__( swagger_description: str = None, enable_swagger: bool = None, cors_origins: List[str] = None, + debug: bool = False, ): - """ - Initialize a new LightApi application. + """Initializes the LightApi application. Args: - database_url: URL for the database connection. - swagger_title: Title for the Swagger documentation. - swagger_version: Version for the Swagger documentation. - swagger_description: Description for the Swagger documentation. + database_url: The URL for the database connection. + swagger_title: The title for the Swagger documentation. + swagger_version: The version for the Swagger documentation. + swagger_description: The description for the Swagger documentation. enable_swagger: Whether to enable Swagger documentation. - cors_origins: List of allowed CORS origins. + cors_origins: A list of allowed CORS origins. + debug: Whether to enable debug mode. """ # Update config with any provided values that are not None update_values = {} @@ -75,149 +66,54 @@ def __init__( self.middleware = [] self.engine, self.Session = setup_database(config.database_url) self.enable_swagger = config.enable_swagger + self.debug = debug if self.enable_swagger: - from .swagger import SwaggerGenerator + from .swagger import SwaggerGenerator, openapi_json_route, redoc_ui_route, swagger_ui_route self.swagger_generator = SwaggerGenerator( title=config.swagger_title, version=config.swagger_version, description=config.swagger_description, ) + self.routes.append(Route("/docs", swagger_ui_route, include_in_schema=False)) + self.routes.append(Route("/redoc", redoc_ui_route, include_in_schema=False)) + self.routes.append(Route("/openapi.json", openapi_json_route, include_in_schema=False)) - def register(self, handler): - """ - Register a model or endpoint class with the application. - Accepts a single SQLAlchemy model or RestEndpoint subclass per call. - """ - from .swagger import openapi_json_route, swagger_ui_route - - # If handler has route_patterns (custom endpoints) - route_patterns = getattr(handler, "route_patterns", None) - if route_patterns: - methods = ( - handler.Configuration.http_method_names - if hasattr(handler, "Configuration") and hasattr(handler.Configuration, "http_method_names") - else ["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"] - ) - endpoint_handler = self._create_handler(handler, methods) - for pattern in route_patterns: - self.routes.append(Route(pattern, endpoint_handler, methods=methods)) - if self.enable_swagger: - self.swagger_generator.register_endpoint(pattern, handler) - return - - # If it's a SQLAlchemy model (RESTful resource) - if hasattr(handler, "__tablename__") and handler.__tablename__: - tablename = handler.__tablename__ - methods = ( - handler.Configuration.http_method_names - if hasattr(handler, "Configuration") and hasattr(handler.Configuration, "http_method_names") - else ["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"] - ) - endpoint_handler = self._create_handler(handler, methods) - # Register /tablename and /tablename/{id} - base_path = f"/{tablename}" - id_path = f"/{tablename}/{{id}}" - self.routes.append(Route(base_path, endpoint_handler, methods=methods)) - self.routes.append(Route(id_path, endpoint_handler, methods=methods)) - if self.enable_swagger: - self.swagger_generator.register_endpoint(base_path, handler) - self.swagger_generator.register_endpoint(id_path, handler) - return - - # If it's a RestEndpoint subclass without route_patterns or __tablename__ - if hasattr(handler, "Configuration") or hasattr(handler, "get") or hasattr(handler, "post"): - path = f"/{handler.__name__.lower()}" - methods = ( - handler.Configuration.http_method_names - if hasattr(handler, "Configuration") and hasattr(handler.Configuration, "http_method_names") - else ["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"] - ) - endpoint_handler = self._create_handler(handler, methods) - self.routes.append(Route(path, endpoint_handler, methods=methods)) - if self.enable_swagger: - self.swagger_generator.register_endpoint(path, handler) - return - - raise TypeError(f"Handler must be a SQLAlchemy model class or RestEndpoint class. Got: {handler}") - - def _create_handler(self, endpoint_class: Type["RestEndpoint"], methods: List[str]) -> Callable: - """ - Create a request handler for an endpoint class. + def register(self, model_class: Type[Base]): + """Registers a SQLAlchemy model to generate CRUD endpoints. Args: - endpoint_class: The endpoint class to create a handler for. - methods: List of HTTP methods the endpoint supports. - - Returns: - An async function that handles requests to the endpoint. + model_class: The SQLAlchemy model class to register. """ + from . import handlers - async def handler(request): - try: - endpoint = endpoint_class() - - if request.method in ["POST", "PUT", "PATCH"]: - try: - body = await request.body() - if body: - request.data = json.loads(body) - else: - request.data = {} - except json.JSONDecodeError: - request.data = {} - else: - request.data = {} - - # Setup the endpoint and check for authentication errors - setup_result = endpoint._setup(request, self.Session()) - if setup_result: - return setup_result - - method = request.method.lower() - if method.upper() not in [m.upper() for m in methods]: - return JSONResponse({"error": f"Method {method} not allowed"}, status_code=405) - - func = getattr(endpoint, method) - if iscoroutinefunction(func): - result = await func(request) - else: - result = func(request) - - # Convert returned value to a Response instance - if isinstance(result, (Response, JSONResponse)): - response = result - else: - if isinstance(result, tuple) and len(result) == 2: - body, status = result - else: - body, status = result, 200 - response = JSONResponse(body, status_code=status) + base_path = f"/{model_class.__tablename__}" + id_path = f"/{model_class.__tablename__}/{{id}}" - return response + self.routes.extend([ + Route(base_path, handlers.CreateHandler(model_class, self.Session), methods=["POST"]), + Route(base_path, handlers.RetrieveAllHandler(model_class, self.Session), methods=["GET"]), + Route(id_path, handlers.ReadHandler(model_class, self.Session), methods=["GET"]), + Route(id_path, handlers.UpdateHandler(model_class, self.Session), methods=["PUT"]), + Route(id_path, handlers.PatchHandler(model_class, self.Session), methods=["PATCH"]), + Route(id_path, handlers.DeleteHandler(model_class, self.Session), methods=["DELETE"]), + ]) - except Exception as e: - return JSONResponse({"error": str(e)}, status_code=500) - - return handler + if self.enable_swagger: + self.swagger_generator.register_endpoint(base_path, model_class) + self.swagger_generator.register_endpoint(id_path, model_class) def add_middleware(self, middleware_classes: List[Type["Middleware"]]): - """ - Add middleware classes to the application. + """Adds middleware classes to the application. Args: - middleware_classes: List of middleware classes to add. + middleware_classes: A list of middleware classes to add. """ self.middleware = middleware_classes def _print_endpoints(self): - """ - Print all registered endpoints to the console. - - This method displays a formatted table of all available endpoints, - including their paths, HTTP methods, and additional information. - """ + """Prints all registered endpoints to the console.""" if not self.routes: print("\nš” No endpoints registered") return @@ -226,55 +122,60 @@ def _print_endpoints(self): print("š LightAPI - Available Endpoints") print("=" * 60) - # Group routes by path for better display endpoint_info = [] - for route in self.routes: if hasattr(route, "path") and hasattr(route, "methods"): path = route.path methods = list(route.methods) if route.methods else ["*"] - - # Skip special routes (docs, openapi) - if path in ["/api/docs", "/openapi.json"]: + if path in ["/docs", "/redoc", "/openapi.json"]: continue - - # Format methods string methods_str = ", ".join(sorted(methods)) - - # Try to get endpoint class name if available - endpoint_name = "Unknown" - if hasattr(route, "endpoint"): - if hasattr(route.endpoint, "__name__"): - endpoint_name = route.endpoint.__name__ - elif hasattr(route.endpoint, "__class__"): - endpoint_name = route.endpoint.__class__.__name__ - + endpoint_name = route.endpoint.__class__.__name__ endpoint_info.append({"path": path, "methods": methods_str, "name": endpoint_name}) if not endpoint_info: print("š” No API endpoints found (only system routes)") return - # Calculate column widths for formatting max_path_len = max(len(info["path"]) for info in endpoint_info) max_methods_len = max(len(info["methods"]) for info in endpoint_info) - # Print header print(f"{'Path':<{max_path_len + 2}} {'Methods':<{max_methods_len + 2}} Endpoint") print("-" * (max_path_len + max_methods_len + 20)) - # Print each endpoint for info in sorted(endpoint_info, key=lambda x: x["path"]): print(f"{info['path']:<{max_path_len + 2}} {info['methods']:<{max_methods_len + 2}} {info['name']}") - # Print additional info if self.enable_swagger: base_url = f"http://{config.host}:{config.port}" - print(f"\nš API Documentation: {base_url}/api/docs") + print(f"\nš API Documentation: {base_url}/docs") + print(f" ReDoc: {base_url}/redoc") print(f"\nš Server will start on http://{config.host}:{config.port}") print("=" * 60) + def get_app(self) -> Starlette: + """Creates and returns a Starlette application instance. + + Returns: + A Starlette application instance. + """ + app = Starlette(debug=self.debug, routes=self.routes) + + if config.cors_origins: + app.add_middleware( + StarletteCORSMiddleware, + allow_origins=config.cors_origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + if self.enable_swagger: + app.state.swagger_generator = self.swagger_generator + + return app + def run( self, host: str = None, @@ -282,16 +183,14 @@ def run( debug: bool = None, reload: bool = None, ): - """ - Run the application server. + """Runs the application server. Args: - host: Host address to bind to. - port: Port to bind to. + host: The host address to bind to. + port: The port to bind to. debug: Whether to enable debug mode. reload: Whether to enable auto-reload on code changes. """ - # Update config with any provided values (only if not None) update_params = {} if host is not None: update_params["host"] = host @@ -305,37 +204,23 @@ def run( if update_params: config.update(**update_params) - # Print available endpoints before starting the server self._print_endpoints() - app = Starlette(debug=config.debug, routes=self.routes) + Base.metadata.create_all(self.engine) - # Add CORS middleware if origins are configured - if config.cors_origins: - app.add_middleware( - StarletteCORSMiddleware, - allow_origins=config.cors_origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) - - # Always set up swagger generator if enabled - if self.enable_swagger: - app.state.swagger_generator = self.swagger_generator + app = self.get_app() uvicorn.run( app, host=config.host, port=config.port, - log_level="debug" if config.debug else "info", + log_level="debug" if self.debug else "info", reload=config.reload, ) class Response(JSONResponse): - """ - Custom JSON response class. + """Custom JSON response class. Extends Starlette's JSONResponse with a simplified constructor and default application/json media type. @@ -349,15 +234,14 @@ def __init__( media_type: str = None, content_type: str = None, ): - """ - Initialize a new Response. + """Initializes a new Response. Args: content: The response content. - status_code: HTTP status code. - headers: HTTP headers. - media_type: HTTP media type. - content_type: HTTP content type (alias for media_type). + status_code: The HTTP status code. + headers: The HTTP headers. + media_type: The HTTP media type. + content_type: The HTTP content type (alias for media_type). """ # Store the original content for tests to access self._test_content = content @@ -374,7 +258,7 @@ def __init__( ) def __getattribute__(self, name): - """Override attribute access to provide test compatibility for body.""" + """Overrides attribute access to provide test compatibility for body.""" if name == "body": # Check if we're in a test context (looking for TestClient or similar) import inspect @@ -432,8 +316,8 @@ def __getattribute__(self, name): return super().__getattribute__(name) def decode(self): - """ - Decode the body content for tests that expect this method. + """Decodes the body content for tests that expect this method. + This method maintains compatibility with tests that expect the body to be bytes with a decode method. """ @@ -454,16 +338,14 @@ def decode(self): class Middleware: - """ - Base class for middleware components. + """Base class for middleware components. Middleware can process requests before they reach the endpoint and responses before they are returned to the client. """ def process(self, request, response): - """ - Process a request or response. + """Processes a request or response. This method is called twice during request handling: 1. Before the request reaches the endpoint (response is None) @@ -480,21 +362,19 @@ def process(self, request, response): class CORSMiddleware(Middleware): - """ - CORS (Cross-Origin Resource Sharing) middleware. + """CORS (Cross-Origin Resource Sharing) middleware. Handles CORS preflight requests and adds appropriate headers to responses. This provides a more flexible alternative to Starlette's built-in CORS middleware. """ def __init__(self, allow_origins=None, allow_methods=None, allow_headers=None): - """ - Initialize CORS middleware. + """Initializes CORS middleware. Args: - allow_origins: List of allowed origins, defaults to ['*'] - allow_methods: List of allowed HTTP methods - allow_headers: List of allowed headers + allow_origins: A list of allowed origins, defaults to ['*']. + allow_methods: A list of allowed HTTP methods. + allow_headers: A list of allowed headers. """ if allow_origins is None: allow_origins = ["*"] @@ -508,15 +388,14 @@ def __init__(self, allow_origins=None, allow_methods=None, allow_headers=None): self.allow_headers = allow_headers def process(self, request, response): - """ - Process CORS requests and add appropriate headers. + """Processes CORS requests and add appropriate headers. Args: - request: The HTTP request - response: The HTTP response (None for pre-processing) + request: The HTTP request. + response: The HTTP response (None for pre-processing). Returns: - Response with CORS headers or preflight response + A response with CORS headers or a preflight response. """ if response is None: # Handle preflight OPTIONS requests @@ -568,19 +447,17 @@ def process(self, request, response): class AuthenticationMiddleware(Middleware): - """ - Authentication middleware that integrates with authentication classes. + """Authentication middleware that integrates with authentication classes. Automatically handles authentication and returns appropriate error responses when authentication fails. Supports skipping authentication for OPTIONS requests. """ def __init__(self, authentication_class=None): - """ - Initialize authentication middleware. + """Initializes authentication middleware. Args: - authentication_class: The authentication class to use + authentication_class: The authentication class to use. """ self.authentication_class = authentication_class if authentication_class: @@ -589,15 +466,14 @@ def __init__(self, authentication_class=None): self.authenticator = None def process(self, request, response): - """ - Process authentication for requests. + """Processes authentication for requests. Args: - request: The HTTP request - response: The HTTP response (None for pre-processing) + request: The HTTP request. + response: The HTTP response (None for pre-processing). Returns: - Error response if authentication fails, otherwise None/response + An error response if authentication fails, otherwise None/response. """ if response is None and self.authenticator: # Pre-processing: check authentication diff --git a/lightapi/database.py b/lightapi/database.py index 1a7660f..a206257 100644 --- a/lightapi/database.py +++ b/lightapi/database.py @@ -6,14 +6,28 @@ from .config import config + +def setup_database(database_url: str = "sqlite:///app.db"): + """Sets up the database connection and returns the engine and session factory. + + Args: + database_url: The SQLAlchemy database URL. + + Returns: + A tuple containing the SQLAlchemy engine and session factory. + """ + import sqlalchemy + engine = sqlalchemy.create_engine(database_url) + Session = sqlalchemy.orm.sessionmaker(bind=engine) + return engine, Session + engine = create_engine(config.database_url) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) @as_declarative() class Base: - """ - Custom SQLAlchemy base class for all models. + """Custom SQLAlchemy base class for all models. Provides automatic __tablename__ generation and utility methods for model instances to make working with SQLAlchemy models easier. @@ -30,8 +44,7 @@ class Base: @property def table(self): - """ - Get the table metadata for this model. + """Gets the table metadata for this model. Returns: The SQLAlchemy Table object for this model. @@ -40,33 +53,32 @@ def table(self): @declared_attr def __tablename__(cls): - """ - Generate the table name based on the class name. + """Generates the table name based on the class name. The table name is derived by converting the class name to lowercase. Returns: - str: The generated table name. + The generated table name. """ return cls.__name__.lower() @property def pk(self): + """Returns the primary key of the model instance.""" return self.id def serialize(self) -> dict: - """ - Convert the model instance into a dictionary representation. + """Converts the model instance into a dictionary representation. Each key in the dictionary corresponds to a column name, and the value is the data stored in that column. Datetime objects are converted to strings. Returns: - dict: A dictionary representation of the model instance. + A dictionary representation of the model instance. """ return { column.name: ( getattr(self, column.name).isoformat() if isinstance(getattr(self, column.name), datetime) else getattr(self, column.name) ) for column in self.table.columns - } + } \ No newline at end of file diff --git a/lightapi/handlers.py b/lightapi/handlers.py index 32bc377..833f455 100644 --- a/lightapi/handlers.py +++ b/lightapi/handlers.py @@ -1,182 +1,163 @@ import datetime import json +import logging +import re +import traceback from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import List, Type -from aiohttp import web +from dateutil import parser from sqlalchemy import inspect from sqlalchemy.exc import IntegrityError, StatementError from sqlalchemy.orm import Session, sessionmaker +from starlette.responses import JSONResponse, Response from lightapi.database import Base, SessionLocal -def create_handler(model: Type[Base], session_factory=SessionLocal) -> List[web.RouteDef]: - """ - Creates a list of route handlers for the given model. - Accepts a session_factory to use for DB sessions. +def create_handler(model: Type[Base], session_factory=SessionLocal): + """Creates a list of route handlers for the given model. + + Args: + model: The SQLAlchemy model class to create handlers for. + session_factory: The session factory to use for database connections. """ - return [ - web.post(f"/{model.__tablename__}/", CreateHandler(model, session_factory)), - web.get(f"/{model.__tablename__}/", RetrieveAllHandler(model, session_factory)), - web.get(f"/{model.__tablename__}/{{id}}", ReadHandler(model, session_factory)), - web.put(f"/{model.__tablename__}/{{id}}", UpdateHandler(model, session_factory)), - web.delete(f"/{model.__tablename__}/{{id}}", DeleteHandler(model, session_factory)), - web.patch(f"/{model.__tablename__}/{{id}}", PatchHandler(model, session_factory)), - ] + # This function is no longer the primary way of creating routes, + # but we'll keep it for now and adapt it. + # The actual route creation will be handled in LightApi.register + pass @dataclass class AbstractHandler(ABC): - """ - Abstract base class for handling HTTP requests related to a specific model. - - Attributes: - model (Base): The SQLAlchemy model class to operate on. - session_factory (sessionmaker): The session factory to use for database operations. - """ + """Abstract base class for handling HTTP requests related to a specific model.""" model: Type[Base] = field(default=None) session_factory: sessionmaker = field(default=SessionLocal) @abstractmethod - async def handle(self, db: Session, request: web.Request): - """ - Abstract method to handle the HTTP request. + async def handle(self, db: Session, request): + """Abstract method to handle the HTTP request. Args: - db (Session): The SQLAlchemy session for database operations. - request (web.Request): The aiohttp web request object. - - Raises: - NotImplementedError: If the method is not implemented by subclasses. + db: The SQLAlchemy session for database operations. + request: The Starlette request object. """ raise NotImplementedError("Method not implemented") - async def __call__(self, request: web.Request, *args, **kwargs): - """ - Calls the handler with the provided request. + async def __call__(self, scope, receive, send): + """Makes the handler a callable ASGI application. Args: - request (web.Request): The aiohttp web request object. - - Returns: - web.Response: The response returned by the handler. + scope: The ASGI scope. + receive: The ASGI receive channel. + send: The ASGI send channel. """ + from starlette.requests import Request + request = Request(scope, receive) db: Session = self.session_factory() try: - return await self.handle(db, request) + response = await self.handle(db, request) + await response(scope, receive, send) + except Exception as e: + logging.error(f"Unhandled exception: {e}\n{traceback.format_exc()}") + # TODO: check for debug mode from app config + response = JSONResponse( + {"error": "Internal Server Error", "message": str(e), "traceback": traceback.format_exc()}, + status_code=500, + ) + await response(scope, receive, send) finally: db.close() - async def get_request_json(self, request: web.Request): - """ - Extracts JSON data from the request body. + async def get_request_json(self, request): + """Extracts JSON data from the request body. Args: - request (web.Request): The aiohttp web request object. + request: The Starlette request object. Returns: - dict: The JSON data from the request body. + A dictionary of the JSON body. """ return await request.json() - def get_item_by_id(self, db: Session, item_id: int): - """ - Retrieves an item by its primary key. - - Args: - db (Session): The SQLAlchemy session for database operations. - item_id (int): The primary key of the item to retrieve. - - Returns: - Base: The item retrieved from the database, or None if not found. - """ - return db.query(self.model).filter(self.model.id == item_id).first() - def add_and_commit_item(self, db: Session, item): - """ - Adds and commits a new item to the database. + """Adds and commits a new item to the database. Args: - db (Session): The SQLAlchemy session for database operations. - item (Base): The item to add and commit. + db: The SQLAlchemy session for database operations. + item: The item to add and commit. Returns: - Base: The item after committing to the database. + The item after committing to the database, or a JSONResponse on error. """ try: db.add(item) db.commit() db.refresh(item) - - if hasattr(self.model, "id"): - if isinstance(self.model.id, tuple): - filters = [col == getattr(item, col.name) for col in self.model.id] - item = db.query(self.model).filter(*filters).first() - else: - item = db.query(self.model).filter(self.model.id == getattr(item, self.model.id.name)).first() - - mapper = inspect(self.model) - for col in self.model.__table__.columns: - if getattr(item, col.name) is None and col.default is not None and col.default.is_scalar: - setattr(item, col.name, col.default.arg) - - if hasattr(col.type, "python_type"): - if col.type.python_type is datetime.datetime and isinstance(getattr(item, col.name), str): - try: - setattr(item, col.name, datetime.datetime.fromisoformat(getattr(item, col.name))) - except ValueError: - pass - elif col.type.python_type is datetime.date and isinstance(getattr(item, col.name), str): - try: - setattr(item, col.name, datetime.date.fromisoformat(getattr(item, col.name))) - except ValueError: - return self.json_error_response(f"Invalid date format for field '{col.name}'", status=400) return item - except (IntegrityError, StatementError) as e: + except IntegrityError as e: + db.rollback() + if "UNIQUE constraint failed" in str(e.orig) or "Duplicate entry" in str(e.orig): + match = re.search(r"failed: ([\w\.]+)", str(e.orig)) + if match: + column = match.group(1) + return self.json_error_response(f"Unique constraint violated for {column}.", status=409) + return self.json_error_response("Unique constraint violated.", status=409) + return self.json_error_response(f"Database integrity error: {e.orig}", status=400) + except StatementError as e: db.rollback() - return self.json_error_response(str(e), status=409) + return self.json_error_response(f"Database statement error: {e.orig}", status=400) def delete_and_commit_item(self, db: Session, item): - """ - Deletes and commits the removal of an item from the database. + """Deletes and commits the removal of an item from the database. Args: - db (Session): The SQLAlchemy session for database operations. - item (Base): The item to delete. + db: The SQLAlchemy session for database operations. + item: The item to delete. """ db.delete(item) db.commit() def json_response(self, item, status=200): - """ - Creates a JSON response for the given item. + """Creates a JSON response for the given item. Args: - item (Base): The item to serialize and return. - status (int, optional): The HTTP status code. Defaults to 200. + item: The item to serialize and return. + status: The HTTP status code. Returns: - web.Response: The JSON response containing the serialized item. + A Starlette JSONResponse. """ - return web.json_response(item.serialize(), status=status) + return JSONResponse(item.serialize(), status_code=status) - def json_error_response(self, error_message, status=404): - """ - Creates a JSON response for an error message. + def json_error_response(self, error_details, status=404): + """Creates a JSON response for an error message. Args: - error_message (str): The error message to return. - status (int, optional): The HTTP status code. Defaults to 404. + error_details: The error message or a dict with details. + status: The HTTP status code. Returns: - web.Response: The JSON response containing the error message. + A Starlette JSONResponse. """ - return web.json_response({"error": error_message}, status=status) + if isinstance(error_details, str): + error_payload = {"error": error_details} + else: + error_payload = error_details + return JSONResponse(error_payload, status_code=status) def _parse_pk_value(self, value, col): + """Parses a primary key value to the correct type. + + Args: + value: The value to parse. + col: The SQLAlchemy column object. + + Returns: + The parsed value. + """ try: if hasattr(col.type, "python_type") and col.type.python_type is int: return int(value) @@ -185,31 +166,18 @@ def _parse_pk_value(self, value, col): return value -class DateTimeEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, datetime): - return obj.isoformat() - return super().default(obj) - - class CreateHandler(AbstractHandler): - """ - Handles HTTP POST requests to create a new item. - """ - - def __init__(self, model, session_factory=SessionLocal): - super().__init__(model, session_factory) + """Handles HTTP POST requests to create a new item.""" async def handle(self, db, request): - """ - Processes the POST request to create and save a new item. + """Processes the POST request to create and save a new item. Args: - db (Session): The SQLAlchemy session for database operations. - request (web.Request): The aiohttp web request object. + db: The SQLAlchemy session for database operations. + request: The Starlette request object. Returns: - web.Response: The JSON response containing the created item. + A Starlette JSONResponse containing the created item or an error. """ data = await self.get_request_json(request) @@ -219,11 +187,7 @@ async def handle(self, db, request): if col.name not in data: missing.append(col.name) if missing: - return web.json_response({"error": f"Missing required fields: {', '.join(missing)}"}, status=400) - - if "amount" in data and isinstance(data["amount"], (int, float)): - if data["amount"] < 0: - return web.json_response({"error": "Amount must be non-negative"}, status=400) + return JSONResponse({"error": f"Missing required fields: {', '.join(missing)}"}, status_code=400) for col in self.model.__table__.columns: if col.name in data: @@ -231,238 +195,117 @@ async def handle(self, db, request): if hasattr(col.type, "python_type"): if col.type.python_type is datetime.datetime and isinstance(val, str): try: - data[col.name] = datetime.datetime.fromisoformat(val) + data[col.name] = parser.parse(val) except ValueError: - return web.json_response({"error": f"Invalid datetime format for field '{col.name}'"}, status=400) + return JSONResponse({"error": f"Invalid datetime format for field '{col.name}'"}, status_code=400) elif col.type.python_type is datetime.date and isinstance(val, str): try: - data[col.name] = datetime.date.fromisoformat(val) + data[col.name] = parser.parse(val).date() except ValueError: - return web.json_response({"error": f"Invalid date format for field '{col.name}'"}, status=400) + return JSONResponse({"error": f"Invalid date format for field '{col.name}'"}, status_code=400) + item = self.model(**data) item = self.add_and_commit_item(db, item) - if isinstance(item, web.Response): + if isinstance(item, JSONResponse): return item return self.json_response(item, status=201) class ReadHandler(AbstractHandler): - """ - Handles HTTP GET requests to retrieve one or all items. - """ - - def __init__(self, model, session_factory=SessionLocal, pk_cols=None): - super().__init__(model, session_factory) - self.pk_cols = pk_cols + """Handles HTTP GET requests to retrieve a single item.""" async def handle(self, db, request): - """ - Processes the GET request to retrieve an item by ID or all items. + """Processes the GET request to retrieve an item by ID. Args: - db (Session): The SQLAlchemy session for database operations. - request (web.Request): The aiohttp web request object. + db: The SQLAlchemy session for database operations. + request: The Starlette request object. Returns: - web.Response: The JSON response containing the item(s) or an error message. + A Starlette JSONResponse containing the item or a 404 error. """ - - if self.pk_cols and len(self.pk_cols) > 1: - pk_values = [request.match_info.get(col) for col in self.pk_cols] - if None in pk_values: - return web.json_response({"error": "Missing composite key(s)"}, status=400) - filters = [ - getattr(self.model, col) == self._parse_pk_value(val, getattr(self.model, col)) for col, val in zip(self.pk_cols, pk_values) - ] - item = db.query(self.model).filter(*filters).first() - else: - pk_col = self.pk_cols[0] if self.pk_cols else "id" - pk_value = request.match_info.get(pk_col) - item = ( - db.query(self.model) - .filter(getattr(self.model, pk_col) == self._parse_pk_value(pk_value, getattr(self.model, pk_col))) - .first() - ) + pk_col = "id" # Assuming 'id' for now + pk_value = request.path_params.get(pk_col) + + item = db.query(self.model).filter(getattr(self.model, pk_col) == pk_value).first() + if not item: - return web.json_response({"error": "Not found"}, status=404) + return self.json_error_response(f"{self.model.__name__} with id {pk_value} not found", status=404) return self.json_response(item, status=200) -class UpdateHandler(AbstractHandler): - """ - Handles HTTP PUT requests to update an existing item. - """ - - def __init__(self, model, session_factory=SessionLocal, pk_cols=None): - super().__init__(model, session_factory) - self.pk_cols = pk_cols +class RetrieveAllHandler(AbstractHandler): + """Handles HTTP GET requests to retrieve all items.""" async def handle(self, db, request): - """ - Processes the PUT request to update an existing item. + """Processes the GET request to retrieve all items. Args: - db (Session): The SQLAlchemy session for database operations. - request (web.Request): The aiohttp web request object. + db: The SQLAlchemy session for database operations. + request: The Starlette request object. Returns: - web.Response: The JSON response containing the updated item or an error message. + A Starlette JSONResponse containing all items. """ - if self.pk_cols and len(self.pk_cols) > 1: - pk_values = [request.match_info.get(col) for col in self.pk_cols] - if None in pk_values: - return self.json_error_response("Missing composite key(s)", status=400) - filters = [ - getattr(self.model, col) == self._parse_pk_value(val, getattr(self.model, col)) for col, val in zip(self.pk_cols, pk_values) - ] - item = db.query(self.model).filter(*filters).first() - else: - pk_col = self.pk_cols[0] if self.pk_cols else "id" - pk_value = request.match_info.get(pk_col) - item = ( - db.query(self.model) - .filter(getattr(self.model, pk_col) == self._parse_pk_value(pk_value, getattr(self.model, pk_col))) - .first() - ) - if not item: - return self.json_error_response("Item not found", status=404) - - data = await self.get_request_json(request) - for key, value in data.items(): - setattr(item, key, value) - - item = self.add_and_commit_item(db, item) - if isinstance(item, web.Response): - return item - return self.json_response(item, status=200) - + items = db.query(self.model).all() + response = [item.serialize() for item in items] + return JSONResponse(response, status_code=200) -class PatchHandler(AbstractHandler): - """ - Handles HTTP PATCH requests to partially update an existing item. - """ - def __init__(self, model, session_factory=SessionLocal, pk_cols=None): - super().__init__(model, session_factory) - self.pk_cols = pk_cols +class UpdateHandler(AbstractHandler): + """Handles HTTP PUT requests to update an existing item.""" async def handle(self, db, request): - """ - Processes the PATCH request to partially update an existing item. + """Processes the PUT request to update an existing item. Args: - db (Session): The SQLAlchemy session for database operations. - request (web.Request): The aiohttp web request object. + db: The SQLAlchemy session for database operations. + request: The Starlette request object. Returns: - web.Response: The JSON response containing the updated item or an error message. + A Starlette JSONResponse containing the updated item or an error. """ - if self.pk_cols and len(self.pk_cols) > 1: - pk_values = [request.match_info.get(col) for col in self.pk_cols] - if None in pk_values: - return self.json_error_response("Missing composite key(s)", status=400) - filters = [ - getattr(self.model, col) == self._parse_pk_value(val, getattr(self.model, col)) for col, val in zip(self.pk_cols, pk_values) - ] - item = db.query(self.model).filter(*filters).first() - else: - pk_col = self.pk_cols[0] if self.pk_cols else "id" - pk_value = request.match_info.get(pk_col) - item = ( - db.query(self.model) - .filter(getattr(self.model, pk_col) == self._parse_pk_value(pk_value, getattr(self.model, pk_col))) - .first() - ) + pk_col = "id" + pk_value = request.path_params.get(pk_col) + + item = db.query(self.model).filter(getattr(self.model, pk_col) == pk_value).first() if not item: - return self.json_error_response("Item not found", status=404) + return self.json_error_response(f"{self.model.__name__} with id {pk_value} not found", status=404) data = await self.get_request_json(request) - - for col in self.model.__table__.columns: - if col.name in data: - val = data[col.name] - if hasattr(col.type, "python_type"): - if col.type.python_type is datetime.datetime and isinstance(val, str): - try: - data[col.name] = datetime.datetime.fromisoformat(val) - except ValueError: - return web.json_response({"error": f"Invalid datetime format for field '{col.name}'"}, status=400) - elif col.type.python_type is datetime.date and isinstance(val, str): - try: - data[col.name] = datetime.date.fromisoformat(val) - except ValueError: - return web.json_response({"error": f"Invalid date format for field '{col.name}'"}, status=400) for key, value in data.items(): setattr(item, key, value) item = self.add_and_commit_item(db, item) - if isinstance(item, web.Response): + if isinstance(item, JSONResponse): return item return self.json_response(item, status=200) -class DeleteHandler(AbstractHandler): - """ - Handles HTTP DELETE requests to delete an existing item. - """ +class PatchHandler(UpdateHandler): + """Handles HTTP PATCH requests to partially update an existing item.""" + pass - def __init__(self, model, session_factory=SessionLocal, pk_cols=None): - super().__init__(model, session_factory) - self.pk_cols = pk_cols + +class DeleteHandler(AbstractHandler): + """Handles HTTP DELETE requests to delete an existing item.""" async def handle(self, db, request): - """ - Processes the DELETE request to remove an existing item. + """Processes the DELETE request to remove an existing item. Args: - db (Session): The SQLAlchemy session for database operations. - request (web.Request): The aiohttp web request object. + db: The SQLAlchemy session for database operations. + request: The Starlette request object. Returns: - web.Response: An empty response with status 204 if the item is deleted. + A Starlette Response with status 204. """ - if self.pk_cols and len(self.pk_cols) > 1: - pk_values = [request.match_info.get(col) for col in self.pk_cols] - if None in pk_values: - return self.json_error_response("Missing composite key(s)", status=400) - filters = [ - getattr(self.model, col) == self._parse_pk_value(val, getattr(self.model, col)) for col, val in zip(self.pk_cols, pk_values) - ] - item = db.query(self.model).filter(*filters).first() - else: - pk_col = self.pk_cols[0] if self.pk_cols else "id" - pk_value = request.match_info.get(pk_col) - item = ( - db.query(self.model) - .filter(getattr(self.model, pk_col) == self._parse_pk_value(pk_value, getattr(self.model, pk_col))) - .first() - ) + pk_col = "id" + pk_value = request.path_params.get(pk_col) + + item = db.query(self.model).filter(getattr(self.model, pk_col) == pk_value).first() if not item: - return self.json_error_response("Item not found", status=404) + return self.json_error_response(f"{self.model.__name__} with id {pk_value} not found", status=404) self.delete_and_commit_item(db, item) - return web.Response(status=204) - - -class RetrieveAllHandler(AbstractHandler): - """ - Handles HTTP GET requests to retrieve all items. - """ - - def __init__(self, model, session_factory=SessionLocal): - super().__init__(model, session_factory) - - async def handle(self, db, request): - """ - Processes the GET request to retrieve all items. - - Args: - db (Session): The SQLAlchemy session for database operations. - request (web.Request): The aiohttp web request object. - - Returns: - web.Response: The JSON response containing all items. - """ - items = db.query(self.model).all() - response = [item.serialize() for item in items] - return web.json_response(response, status=200, dumps=json.dumps) + return Response(status_code=204) diff --git a/lightapi/models.py b/lightapi/models.py index 6fa976a..400f794 100644 --- a/lightapi/models.py +++ b/lightapi/models.py @@ -8,31 +8,6 @@ from lightapi.database import Base -def setup_database(database_url: str = "sqlite:///app.db"): - """ - Set up the database connection and create tables. - - Initializes SQLAlchemy with the provided database URL, - creates the database tables, and returns the engine - and session factory. - - Args: - database_url: The SQLAlchemy database URL. - - Returns: - tuple: A tuple containing (engine, Session). - """ - engine = sqlalchemy.create_engine(database_url) - - try: - Base.metadata.create_all(engine) - except Exception: - pass - - Session = sqlalchemy.orm.sessionmaker(bind=engine) - return engine, Session - - def register_model_class(cls): """ Register a RestEndpoint class as a SQLAlchemy model. diff --git a/lightapi/swagger.py b/lightapi/swagger.py index bbdc81e..bf7abb4 100644 --- a/lightapi/swagger.py +++ b/lightapi/swagger.py @@ -10,19 +10,18 @@ class SwaggerGenerator: - """ - Generates OpenAPI documentation from LightAPI endpoint classes. + """Generates OpenAPI documentation from LightAPI endpoint classes. This class analyzes RestEndpoint classes to extract information about their schemas, HTTP methods, validation, and other metadata to build a complete OpenAPI specification document. Attributes: - title (str): The title of the API documentation. - version (str): The API version. - description (str): A description of the API. - paths (dict): Endpoint paths and their operations. - components (dict): Schema definitions and security schemes. + title: The title of the API documentation. + version: The API version. + description: A description of the API. + paths: A dictionary of endpoint paths and their operations. + components: A dictionary of schema definitions and security schemes. """ def __init__( @@ -31,8 +30,7 @@ def __init__( version: str = "1.0.0", description: str = "API documentation", ): - """ - Initialize a new SwaggerGenerator. + """Initializes a new SwaggerGenerator. Args: title: The title of the API documentation. @@ -55,8 +53,7 @@ def __init__( } def register_endpoint(self, path: str, endpoint_class: Type[RestEndpoint]): - """ - Register an endpoint class for OpenAPI documentation. + """Registers an endpoint class for OpenAPI documentation. Analyzes the endpoint class to extract HTTP methods, schemas, and other metadata to include in the OpenAPI documentation. @@ -84,8 +81,7 @@ def register_endpoint(self, path: str, endpoint_class: Type[RestEndpoint]): self.paths[path] = path_operations def _generate_schema(self, endpoint_class: Type[RestEndpoint]) -> Dict[str, Any]: - """ - Generate an OpenAPI schema from an endpoint class. + """Generates an OpenAPI schema from an endpoint class. Extracts information about the model fields, their types, and validation requirements to create an OpenAPI schema definition. @@ -94,7 +90,7 @@ def _generate_schema(self, endpoint_class: Type[RestEndpoint]) -> Dict[str, Any] endpoint_class: The RestEndpoint class to analyze. Returns: - A dict containing the OpenAPI schema definition. + A dictionary containing the OpenAPI schema definition. """ properties = {} required = [] @@ -130,14 +126,13 @@ def _generate_schema(self, endpoint_class: Type[RestEndpoint]) -> Dict[str, Any] } def _map_sql_type_to_openapi(self, sql_type) -> Dict[str, Any]: - """ - Map SQLAlchemy column types to OpenAPI data types. + """Maps SQLAlchemy column types to OpenAPI data types. Args: - sql_type: SQLAlchemy type object. + sql_type: The SQLAlchemy type object. Returns: - A dict containing the OpenAPI type definition. + A dictionary containing the OpenAPI type definition. """ type_map = { "INTEGER": {"type": "integer"}, @@ -159,8 +154,7 @@ def _map_sql_type_to_openapi(self, sql_type) -> Dict[str, Any]: return {"type": "string"} def _generate_operation(self, endpoint_class: Type[RestEndpoint], method: str, model_name: str) -> Dict[str, Any]: - """ - Generate an OpenAPI operation object for an endpoint method. + """Generates an OpenAPI operation object for an endpoint method. Args: endpoint_class: The RestEndpoint class. @@ -168,7 +162,7 @@ def _generate_operation(self, endpoint_class: Type[RestEndpoint], method: str, m model_name: The model name for reference. Returns: - A dict containing the OpenAPI operation definition. + A dictionary containing the OpenAPI operation definition. """ method_handler = getattr(endpoint_class, method, None) description = "" @@ -200,11 +194,10 @@ def _generate_operation(self, endpoint_class: Type[RestEndpoint], method: str, m return operation def generate_openapi_spec(self) -> Dict[str, Any]: - """ - Generate the complete OpenAPI specification document. + """Generates the complete OpenAPI specification document. Returns: - A dict containing the full OpenAPI specification. + A dictionary containing the full OpenAPI specification. """ return { "openapi": "3.0.0", @@ -218,8 +211,7 @@ def generate_openapi_spec(self) -> Dict[str, Any]: } def get_swagger_ui(self) -> HTMLResponse: - """ - Generate the Swagger UI HTML page for interactive API documentation. + """Generates the Swagger UI HTML page for interactive API documentation. Returns: An HTMLResponse containing the Swagger UI interface. @@ -260,9 +252,35 @@ def get_swagger_ui(self) -> HTMLResponse: """ ) - def get_openapi_json(self) -> JSONResponse: + def get_redoc_html(self) -> HTMLResponse: + """Generates the ReDoc HTML page for alternative API documentation. + + Returns: + An HTMLResponse containing the ReDoc UI interface. + """ + return HTMLResponse( + """ + + +
+