diff --git a/examples/05_advanced_caching_redis.py b/examples/advanced_caching_redis.py similarity index 100% rename from examples/05_advanced_caching_redis.py rename to examples/advanced_caching_redis.py diff --git a/examples/04_advanced_filtering_pagination.py b/examples/advanced_filtering_pagination.py similarity index 100% rename from examples/04_advanced_filtering_pagination.py rename to examples/advanced_filtering_pagination.py diff --git a/examples/03_advanced_validation.py b/examples/advanced_validation.py similarity index 100% rename from examples/03_advanced_validation.py rename to examples/advanced_validation.py diff --git a/examples/06_async_performance.py b/examples/async_performance.py similarity index 100% rename from examples/06_async_performance.py rename to examples/async_performance.py diff --git a/examples/02_authentication_jwt.py b/examples/authentication_jwt.py similarity index 100% rename from examples/02_authentication_jwt.py rename to examples/authentication_jwt.py diff --git a/examples/10_batch_operations.py b/examples/batch_operations.py similarity index 100% rename from examples/10_batch_operations.py rename to examples/batch_operations.py diff --git a/examples/10_blog_post.py b/examples/blog_post.py similarity index 100% rename from examples/10_blog_post.py rename to examples/blog_post.py diff --git a/examples/05_caching_redis_custom.py b/examples/caching_redis_custom.py similarity index 100% rename from examples/05_caching_redis_custom.py rename to examples/caching_redis_custom.py diff --git a/examples/10_comprehensive_ideal_usage.py b/examples/comprehensive_ideal_usage.py similarity index 100% rename from examples/10_comprehensive_ideal_usage.py rename to examples/comprehensive_ideal_usage.py diff --git a/examples/01_database_transactions.py b/examples/database_transactions.py similarity index 100% rename from examples/01_database_transactions.py rename to examples/database_transactions.py diff --git a/examples/01_error_handling_basic.py b/examples/error_handling_basic.py similarity index 100% rename from examples/01_error_handling_basic.py rename to examples/error_handling_basic.py diff --git a/examples/01_example.py b/examples/example.py similarity index 100% rename from examples/01_example.py rename to examples/example.py diff --git a/examples/04_filtering_pagination.py b/examples/filtering_pagination.py similarity index 100% rename from examples/04_filtering_pagination.py rename to examples/filtering_pagination.py diff --git a/examples/01_general_usage.py b/examples/general_usage.py similarity index 100% rename from examples/01_general_usage.py rename to examples/general_usage.py diff --git a/examples/10_mega_example.py b/examples/mega_example.py similarity index 100% rename from examples/10_mega_example.py rename to examples/mega_example.py diff --git a/examples/07_middleware_cors_auth.py b/examples/middleware_cors_auth.py similarity index 100% rename from examples/07_middleware_cors_auth.py rename to examples/middleware_cors_auth.py diff --git a/examples/07_middleware_custom.py b/examples/middleware_custom.py similarity index 100% rename from examples/07_middleware_custom.py rename to examples/middleware_custom.py diff --git a/examples/10_nested_resources.py b/examples/nested_resources.py similarity index 100% rename from examples/10_nested_resources.py rename to examples/nested_resources.py diff --git a/examples/10_relationships_sqlalchemy.py b/examples/relationships_sqlalchemy.py similarity index 100% rename from examples/10_relationships_sqlalchemy.py rename to examples/relationships_sqlalchemy.py diff --git a/examples/01_response_customization.py b/examples/response_customization.py similarity index 100% rename from examples/01_response_customization.py rename to examples/response_customization.py diff --git a/examples/01_rest_crud_basic.py b/examples/rest_crud_basic.py similarity index 100% rename from examples/01_rest_crud_basic.py rename to examples/rest_crud_basic.py diff --git a/examples/04_search_functionality.py b/examples/search_functionality.py similarity index 100% rename from examples/04_search_functionality.py rename to examples/search_functionality.py diff --git a/examples/08_swagger_openapi_docs.py b/examples/swagger_openapi_docs.py similarity index 100% rename from examples/08_swagger_openapi_docs.py rename to examples/swagger_openapi_docs.py diff --git a/examples/10_user_goal_example.py b/examples/user_goal_example.py similarity index 100% rename from examples/10_user_goal_example.py rename to examples/user_goal_example.py diff --git a/examples/03_validation_custom_fields.py b/examples/validation_custom_fields.py similarity index 100% rename from examples/03_validation_custom_fields.py rename to examples/validation_custom_fields.py diff --git a/examples/09_yaml_advanced_permissions.py b/examples/yaml_advanced_permissions.py similarity index 100% rename from examples/09_yaml_advanced_permissions.py rename to examples/yaml_advanced_permissions.py diff --git a/examples/09_yaml_basic_example.py b/examples/yaml_basic_example.py similarity index 100% rename from examples/09_yaml_basic_example.py rename to examples/yaml_basic_example.py diff --git a/examples/09_yaml_comprehensive_example.py b/examples/yaml_comprehensive_example.py similarity index 100% rename from examples/09_yaml_comprehensive_example.py rename to examples/yaml_comprehensive_example.py diff --git a/examples/09_yaml_configuration.py b/examples/yaml_configuration.py similarity index 100% rename from examples/09_yaml_configuration.py rename to examples/yaml_configuration.py diff --git a/examples/09_yaml_database_types.py b/examples/yaml_database_types.py similarity index 100% rename from examples/09_yaml_database_types.py rename to examples/yaml_database_types.py diff --git a/examples/09_yaml_environment_variables.py b/examples/yaml_environment_variables.py similarity index 100% rename from examples/09_yaml_environment_variables.py rename to examples/yaml_environment_variables.py diff --git a/examples/09_yaml_minimal_readonly.py b/examples/yaml_minimal_readonly.py similarity index 100% rename from examples/09_yaml_minimal_readonly.py rename to examples/yaml_minimal_readonly.py diff --git a/lightapi/auth.py b/lightapi/auth.py index f38d076..942f751 100644 --- a/lightapi/auth.py +++ b/lightapi/auth.py @@ -2,10 +2,14 @@ from datetime import datetime, timedelta from typing import Any, Dict, Optional -import jwt from starlette.responses import JSONResponse from .config import config +from .jwt_custom import jwt_encode, jwt_decode + + +class InvalidTokenError(Exception): + pass class BaseAuthentication: @@ -86,7 +90,7 @@ def authenticate(self, request): payload = self.decode_token(token) request.state.user = payload return True - except jwt.InvalidTokenError: + except InvalidTokenError: return False def generate_token(self, payload: Dict, expiration: Optional[int] = None) -> str: @@ -103,9 +107,9 @@ def generate_token(self, payload: Dict, expiration: Optional[int] = None) -> str exp_seconds = expiration or self.expiration token_data = { **payload, - "exp": datetime.utcnow() + timedelta(seconds=exp_seconds), + "exp": time.time() + exp_seconds, } - return jwt.encode(token_data, self.secret_key, algorithm=self.algorithm) + return jwt_encode(token_data, self.secret_key, algorithm=self.algorithm) def decode_token(self, token: str) -> Dict: """ @@ -118,6 +122,9 @@ def decode_token(self, token: str) -> Dict: dict: The decoded token payload. Raises: - jwt.InvalidTokenError: If the token is invalid or expired. + InvalidTokenError: If the token is invalid or expired. """ - return jwt.decode(token, self.secret_key, algorithms=[self.algorithm]) + try: + return jwt_decode(token, self.secret_key, algorithms=[self.algorithm]) + except ValueError as e: + raise InvalidTokenError(str(e)) diff --git a/lightapi/jwt_custom.py b/lightapi/jwt_custom.py new file mode 100644 index 0000000..1545c27 --- /dev/null +++ b/lightapi/jwt_custom.py @@ -0,0 +1,61 @@ + +import base64 +import json +import hmac +import hashlib +import time + +def b64_encode(data: bytes) -> str: + return base64.urlsafe_b64encode(data).rstrip(b"=").decode("utf-8") + +def b64_decode(data: str) -> bytes: + padding = b"=" * (4 - (len(data) % 4)) + return base64.urlsafe_b64decode(data.encode("utf-8") + padding) + +def jwt_encode(payload: dict, secret: str, algorithm: str = "HS256") -> str: + header = {"alg": algorithm, "typ": "JWT"} + + encoded_header = b64_encode(json.dumps(header, separators=(",", ":")).encode("utf-8")) + encoded_payload = b64_encode(json.dumps(payload, separators=(",", ":")).encode("utf-8")) + + signing_input = f"{encoded_header}.{encoded_payload}".encode("utf-8") + + if algorithm == "HS256": + signature = hmac.new(secret.encode("utf-8"), signing_input, hashlib.sha256).digest() + else: + raise ValueError("Unsupported algorithm") + + encoded_signature = b64_encode(signature) + + return f"{encoded_header}.{encoded_payload}.{encoded_signature}" + +def jwt_decode(token: str, secret: str, algorithms: list[str] = ["HS256"]) -> dict: + try: + encoded_header, encoded_payload, encoded_signature = token.split(".") + except ValueError: + raise ValueError("Invalid token") + + header_data = json.loads(b64_decode(encoded_header)) + alg = header_data.get("alg") + + if not alg or alg not in algorithms: + raise ValueError("Invalid algorithm") + + signing_input = f"{encoded_header}.{encoded_payload}".encode("utf-8") + + if alg == "HS256": + expected_signature = hmac.new(secret.encode("utf-8"), signing_input, hashlib.sha256).digest() + else: + raise ValueError("Unsupported algorithm") + + decoded_signature = b64_decode(encoded_signature) + + if not hmac.compare_digest(decoded_signature, expected_signature): + raise ValueError("Invalid signature") + + payload = json.loads(b64_decode(encoded_payload)) + + if "exp" in payload and payload["exp"] < time.time(): + raise ValueError("Token has expired") + + return payload diff --git a/pyproject.toml b/pyproject.toml index 00713fd..2e0f2eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ classifiers = [ dependencies = [ "SQLAlchemy>=2.0.30,<3.0.0", "aiohttp>=3.9.5,<4.0.0", - "PyJWT>=2.8.0,<3.0.0", + "starlette>=0.37.0,<1.0.0", "uvicorn>=0.30.0,<1.0.0", "redis>=5.0.0,<6.0.0", @@ -56,7 +56,7 @@ dev = [ ] test = [ "pytest>=7.3.1,<8.0.0", - "PyJWT>=2.8.0,<3.0.0", + "starlette>=0.37.0,<1.0.0", "uvicorn>=0.30.0,<1.0.0", "redis>=5.0.0,<6.0.0", diff --git a/requirements.txt b/requirements.txt index 431ed2a..036ae19 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,5 @@ setuptools==68.2.0 aiohttp==3.9.5 psycopg2-binary==2.9.9 -PyJWT==2.9.0 pytest==8.2.2 -PyYAML>=5.1 \ No newline at end of file +PyYAML>=5.1 diff --git a/tests/test_caching_example.py b/tests/test_caching_example.py index 8906ed1..ee44130 100644 --- a/tests/test_caching_example.py +++ b/tests/test_caching_example.py @@ -7,7 +7,7 @@ import pytest -from examples.05_caching_redis_custom import ( +from examples.caching_redis_custom import ( ConfigurableCacheEndpoint, CustomCache, WeatherEndpoint, @@ -32,8 +32,8 @@ def endpoint(self): endpoint.cache = CustomCache() return endpoint - @patch("examples.05_caching_redis_custom.time.sleep") - @patch("examples.05_caching_redis_custom.print") + @patch("examples.caching_redis_custom.time.sleep") + @patch("examples.caching_redis_custom.print") def test_get_cache_miss(self, mock_print, mock_sleep, endpoint): """Test that get generates new data on cache miss. @@ -64,8 +64,8 @@ class MockRequest: # Verify sleep was called to simulate slow operation mock_sleep.assert_called_once_with(0.1) - @patch("examples.05_caching_redis_custom.time.sleep") - @patch("examples.05_caching_redis_custom.print") + @patch("examples.caching_redis_custom.time.sleep") + @patch("examples.caching_redis_custom.print") def test_get_cache_hit(self, mock_print, mock_sleep, endpoint): """Test that get returns cached data on cache hit. @@ -107,7 +107,7 @@ class MockRequest: # Verify sleep was not called (no slow operation) mock_sleep.assert_not_called() - @patch("examples.05_caching_redis_custom.print") + @patch("examples.caching_redis_custom.print") def test_delete_specific_city(self, mock_print, endpoint): """Test that delete removes cache for a specific city. @@ -137,7 +137,7 @@ class MockRequest: # Verify delete cache method was called mock_print.assert_any_call("Cache DELETE for 'weather:London'") - @patch("examples.05_caching_redis_custom.print") + @patch("examples.caching_redis_custom.print") def test_delete_all_cities(self, mock_print, endpoint): """Test that delete with no city param clears all cache. @@ -185,8 +185,8 @@ def endpoint(self): endpoint.cache = CustomCache() return endpoint - @patch("examples.05_caching_redis_custom.time.sleep") - @patch("examples.05_caching_redis_custom.print") + @patch("examples.caching_redis_custom.time.sleep") + @patch("examples.caching_redis_custom.print") def test_get_with_custom_ttl(self, mock_print, mock_sleep, endpoint): """Test that get respects custom TTL from query params. @@ -215,8 +215,8 @@ class MockRequest: # Verify sleep was called to simulate slow operation mock_sleep.assert_called_once_with(1) - @patch("examples.05_caching_redis_custom.time.sleep") - @patch("examples.05_caching_redis_custom.print") + @patch("examples.caching_redis_custom.time.sleep") + @patch("examples.caching_redis_custom.print") def test_get_with_default_ttl(self, mock_print, mock_sleep, endpoint): """Test that get uses default TTL when not specified. @@ -241,8 +241,8 @@ class MockRequest: # Verify cache was set with default TTL mock_print.assert_any_call("Cache SET for 'resource:resource123' (expires in 60s)") - @patch("examples.05_caching_redis_custom.time.sleep") - @patch("examples.05_caching_redis_custom.print") + @patch("examples.caching_redis_custom.time.sleep") + @patch("examples.caching_redis_custom.print") def test_get_cache_hit(self, mock_print, mock_sleep, endpoint): """Test that get uses cached data when available. diff --git a/tests/test_custom_snippet.py b/tests/test_custom_snippet.py index 100f691..1a40014 100644 --- a/tests/test_custom_snippet.py +++ b/tests/test_custom_snippet.py @@ -9,7 +9,7 @@ import jwt from starlette.testclient import TestClient -from examples.07_middleware_cors_auth import Company, CustomEndpoint, create_app +from examples.middleware_cors_auth import Company, CustomEndpoint, create_app from lightapi.config import config from lightapi.core import Middleware, Response from lightapi.lightapi import LightApi diff --git a/tests/test_filtering_pagination_example.py b/tests/test_filtering_pagination_example.py index e9862a2..321be2f 100644 --- a/tests/test_filtering_pagination_example.py +++ b/tests/test_filtering_pagination_example.py @@ -6,7 +6,7 @@ from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker -from examples.04_filtering_pagination import ( +from examples.filtering_pagination import ( Product, ProductFilter, ProductPaginator, @@ -457,7 +457,7 @@ def test_init_database(self): # Call the init_database function with a custom engine from unittest.mock import patch - with patch("examples.filtering_pagination_04.create_engine", return_value=engine): + with patch("examples.filtering_pagination.create_engine", return_value=engine): init_database() # Create a session to verify the data