Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
19 changes: 13 additions & 6 deletions lightapi/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@
from datetime import datetime, timedelta
from typing import Any, Dict, Optional

import jwt
from starlette.responses import JSONResponse

from .config import config
from .jwt_custom import jwt_encode, jwt_decode


class InvalidTokenError(Exception):
pass


class BaseAuthentication:
Expand Down Expand Up @@ -86,7 +90,7 @@ def authenticate(self, request):
payload = self.decode_token(token)
request.state.user = payload
return True
except jwt.InvalidTokenError:
except InvalidTokenError:
return False

def generate_token(self, payload: Dict, expiration: Optional[int] = None) -> str:
Expand All @@ -103,9 +107,9 @@ def generate_token(self, payload: Dict, expiration: Optional[int] = None) -> str
exp_seconds = expiration or self.expiration
token_data = {
**payload,
"exp": datetime.utcnow() + timedelta(seconds=exp_seconds),
"exp": time.time() + exp_seconds,
}
return jwt.encode(token_data, self.secret_key, algorithm=self.algorithm)
return jwt_encode(token_data, self.secret_key, algorithm=self.algorithm)

def decode_token(self, token: str) -> Dict:
"""
Expand All @@ -118,6 +122,9 @@ def decode_token(self, token: str) -> Dict:
dict: The decoded token payload.

Raises:
jwt.InvalidTokenError: If the token is invalid or expired.
InvalidTokenError: If the token is invalid or expired.
"""
return jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
try:
return jwt_decode(token, self.secret_key, algorithms=[self.algorithm])
except ValueError as e:
raise InvalidTokenError(str(e))
61 changes: 61 additions & 0 deletions lightapi/jwt_custom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@

import base64
import json
import hmac
import hashlib
import time

Comment on lines +2 to +7
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Make typing Python 3.8‑compatible and remove mutable default.

list[str] breaks on 3.8 and default list violates B006; pipeline is failing. Use Optional[List[str]] and set default inside.

Apply:

 import base64
 import json
 import hmac
 import hashlib
 import time
+from typing import List, Optional
@@
-def jwt_decode(token: str, secret: str, algorithms: list[str] = ["HS256"]) -> dict:
+def jwt_decode(token: str, secret: str, algorithms: Optional[List[str]] = None) -> dict:

And inside jwt_decode before using algorithms:

-    header_data = json.loads(b64_decode(encoded_header))
+    header_data = json.loads(b64_decode(encoded_header))
+    algorithms = algorithms or ["HS256"]

Also applies to: 32-32

🤖 Prompt for AI Agents
In lightapi/jwt_custom.py around lines 2-7 and line 32, update any type
annotations using list[str] to use typing.Optional and typing.List (e.g.,
Optional[List[str]]), add the necessary imports from typing, change default
mutable parameters from [] to None, and inside jwt_decode set algorithms =
algorithms or [] (or algorithms = list(algorithms) if you need a copy) before
use so the default is created at call time rather than at function definition.

def b64_encode(data: bytes) -> str:
return base64.urlsafe_b64encode(data).rstrip(b"=").decode("utf-8")

def b64_decode(data: str) -> bytes:
padding = b"=" * (4 - (len(data) % 4))
return base64.urlsafe_b64decode(data.encode("utf-8") + padding)
Comment on lines +11 to +13
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Fix Base64 padding (current code adds 4 "=" when none needed).

When len%4==0 you still append four "=" causing decode errors. Compute missing padding modulo 4.

Apply:

-def b64_decode(data: str) -> bytes:
-    padding = b"=" * (4 - (len(data) % 4))
-    return base64.urlsafe_b64decode(data.encode("utf-8") + padding)
+def b64_decode(data: str) -> bytes:
+    # Add only the required padding (0..3)
+    missing = (-len(data)) % 4
+    if missing:
+        data += "=" * missing
+    return base64.urlsafe_b64decode(data.encode("utf-8"))
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def b64_decode(data: str) -> bytes:
padding = b"=" * (4 - (len(data) % 4))
return base64.urlsafe_b64decode(data.encode("utf-8") + padding)
def b64_decode(data: str) -> bytes:
# Add only the required padding (0..3)
missing = (-len(data)) % 4
if missing:
data += "=" * missing
return base64.urlsafe_b64decode(data.encode("utf-8"))
🤖 Prompt for AI Agents
In lightapi/jwt_custom.py around lines 11 to 13, the base64 padding logic always
appends four "=" when data length is already a multiple of 4 causing decode
errors; change the padding calculation to compute missing = (4 - (len(data) %
4)) % 4 (or equivalent negative-mod trick) and append b"=" * missing so you only
add the needed padding bytes before urlsafe_b64decode.


def jwt_encode(payload: dict, secret: str, algorithm: str = "HS256") -> str:
header = {"alg": algorithm, "typ": "JWT"}

encoded_header = b64_encode(json.dumps(header, separators=(",", ":")).encode("utf-8"))
encoded_payload = b64_encode(json.dumps(payload, separators=(",", ":")).encode("utf-8"))

signing_input = f"{encoded_header}.{encoded_payload}".encode("utf-8")

if algorithm == "HS256":
signature = hmac.new(secret.encode("utf-8"), signing_input, hashlib.sha256).digest()
else:
raise ValueError("Unsupported algorithm")

encoded_signature = b64_encode(signature)

return f"{encoded_header}.{encoded_payload}.{encoded_signature}"

def jwt_decode(token: str, secret: str, algorithms: list[str] = ["HS256"]) -> dict:
try:
encoded_header, encoded_payload, encoded_signature = token.split(".")
except ValueError:
raise ValueError("Invalid token")

header_data = json.loads(b64_decode(encoded_header))
alg = header_data.get("alg")

if not alg or alg not in algorithms:
raise ValueError("Invalid algorithm")

signing_input = f"{encoded_header}.{encoded_payload}".encode("utf-8")

if alg == "HS256":
expected_signature = hmac.new(secret.encode("utf-8"), signing_input, hashlib.sha256).digest()
else:
raise ValueError("Unsupported algorithm")

decoded_signature = b64_decode(encoded_signature)

if not hmac.compare_digest(decoded_signature, expected_signature):
raise ValueError("Invalid signature")

payload = json.loads(b64_decode(encoded_payload))

if "exp" in payload and payload["exp"] < time.time():
raise ValueError("Token has expired")

return payload
Comment on lines +33 to +61
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Normalize decode errors to ValueError to avoid 500s upstream.

Base64/JSON errors currently leak as non‑ValueError; auth.decode_token won’t catch them. Wrap and re‑raise ValueError.

Apply:

-    try:
-        encoded_header, encoded_payload, encoded_signature = token.split(".")
-    except ValueError:
-        raise ValueError("Invalid token")
+    try:
+        encoded_header, encoded_payload, encoded_signature = token.split(".")
+    except ValueError:
+        raise ValueError("Invalid token") from None
@@
-    header_data = json.loads(b64_decode(encoded_header))
+    try:
+        header_data = json.loads(b64_decode(encoded_header))
+    except Exception:
+        raise ValueError("Invalid token") from None
@@
-    decoded_signature = b64_decode(encoded_signature)
+    try:
+        decoded_signature = b64_decode(encoded_signature)
+    except Exception:
+        raise ValueError("Invalid token") from None
@@
-    payload = json.loads(b64_decode(encoded_payload))
+    try:
+        payload = json.loads(b64_decode(encoded_payload))
+    except Exception:
+        raise ValueError("Invalid token") from None
🧰 Tools
🪛 Ruff (0.14.1)

36-36: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)


36-36: Avoid specifying long messages outside the exception class

(TRY003)


42-42: Avoid specifying long messages outside the exception class

(TRY003)


49-49: Avoid specifying long messages outside the exception class

(TRY003)


54-54: Avoid specifying long messages outside the exception class

(TRY003)


59-59: Avoid specifying long messages outside the exception class

(TRY003)

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ classifiers = [
dependencies = [
"SQLAlchemy>=2.0.30,<3.0.0",
"aiohttp>=3.9.5,<4.0.0",
"PyJWT>=2.8.0,<3.0.0",

"starlette>=0.37.0,<1.0.0",
"uvicorn>=0.30.0,<1.0.0",
"redis>=5.0.0,<6.0.0",
Expand All @@ -56,7 +56,7 @@ dev = [
]
test = [
"pytest>=7.3.1,<8.0.0",
"PyJWT>=2.8.0,<3.0.0",

"starlette>=0.37.0,<1.0.0",
"uvicorn>=0.30.0,<1.0.0",
"redis>=5.0.0,<6.0.0",
Expand Down
3 changes: 1 addition & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,5 @@ setuptools==68.2.0
aiohttp==3.9.5
psycopg2-binary==2.9.9

PyJWT==2.9.0
pytest==8.2.2
PyYAML>=5.1
PyYAML>=5.1
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

LGTM; confirm tests don’t rely on PyJWT.

PyJWT is removed project-wide; ensure no tests still import it. See fix proposed in tests/test_custom_snippet.py.

Run to find stragglers:


🏁 Script executed:

#!/bin/bash
rg -nP '^\s*import\s+jwt\b|jwt\.(encode|decode)\b|from\s+jwt\s+import' -g '!**/.venv/**' -S

Length of output: 1747


Critical: Tests still import PyJWT despite removal from requirements.txt.

The verification found multiple test files that still import and use PyJWT:

  • tests/test_custom_snippet.py:9 - import jwt
  • tests/test_auth.py:8, 22, 56, 86 - import jwt with jwt.encode() and jwt.decode() calls

Removing PyJWT from requirements.txt will break these tests. Either update tests to not rely on PyJWT, or retain the dependency in requirements.txt.

🤖 Prompt for AI Agents
In requirements.txt at line 8, removing PyJWT breaks tests that still import and
use jwt (tests/test_custom_snippet.py and tests/test_auth.py); either restore
the dependency by re-adding a compatible PyJWT entry (e.g., PyJWT>=2.0,<3.0) to
requirements.txt so tests keep working, or update the tests to stop relying on
PyJWT by replacing jwt.encode()/jwt.decode() calls with a project-supported
alternative (mock the jwt module or use a different JWT library already in
requirements) and adjust imports accordingly; pick one approach and apply
consistently across the mentioned test files.

26 changes: 13 additions & 13 deletions tests/test_caching_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import pytest

from examples.05_caching_redis_custom import (
from examples.caching_redis_custom import (
ConfigurableCacheEndpoint,
CustomCache,
WeatherEndpoint,
Expand All @@ -32,8 +32,8 @@ def endpoint(self):
endpoint.cache = CustomCache()
return endpoint

@patch("examples.05_caching_redis_custom.time.sleep")
@patch("examples.05_caching_redis_custom.print")
@patch("examples.caching_redis_custom.time.sleep")
@patch("examples.caching_redis_custom.print")
def test_get_cache_miss(self, mock_print, mock_sleep, endpoint):
"""Test that get generates new data on cache miss.

Expand Down Expand Up @@ -64,8 +64,8 @@ class MockRequest:
# Verify sleep was called to simulate slow operation
mock_sleep.assert_called_once_with(0.1)

@patch("examples.05_caching_redis_custom.time.sleep")
@patch("examples.05_caching_redis_custom.print")
@patch("examples.caching_redis_custom.time.sleep")
@patch("examples.caching_redis_custom.print")
def test_get_cache_hit(self, mock_print, mock_sleep, endpoint):
"""Test that get returns cached data on cache hit.

Expand Down Expand Up @@ -107,7 +107,7 @@ class MockRequest:
# Verify sleep was not called (no slow operation)
mock_sleep.assert_not_called()

@patch("examples.05_caching_redis_custom.print")
@patch("examples.caching_redis_custom.print")
def test_delete_specific_city(self, mock_print, endpoint):
"""Test that delete removes cache for a specific city.

Expand Down Expand Up @@ -137,7 +137,7 @@ class MockRequest:
# Verify delete cache method was called
mock_print.assert_any_call("Cache DELETE for 'weather:London'")

@patch("examples.05_caching_redis_custom.print")
@patch("examples.caching_redis_custom.print")
def test_delete_all_cities(self, mock_print, endpoint):
"""Test that delete with no city param clears all cache.

Expand Down Expand Up @@ -185,8 +185,8 @@ def endpoint(self):
endpoint.cache = CustomCache()
return endpoint

@patch("examples.05_caching_redis_custom.time.sleep")
@patch("examples.05_caching_redis_custom.print")
@patch("examples.caching_redis_custom.time.sleep")
@patch("examples.caching_redis_custom.print")
def test_get_with_custom_ttl(self, mock_print, mock_sleep, endpoint):
"""Test that get respects custom TTL from query params.

Expand Down Expand Up @@ -215,8 +215,8 @@ class MockRequest:
# Verify sleep was called to simulate slow operation
mock_sleep.assert_called_once_with(1)

@patch("examples.05_caching_redis_custom.time.sleep")
@patch("examples.05_caching_redis_custom.print")
@patch("examples.caching_redis_custom.time.sleep")
@patch("examples.caching_redis_custom.print")
def test_get_with_default_ttl(self, mock_print, mock_sleep, endpoint):
"""Test that get uses default TTL when not specified.

Expand All @@ -241,8 +241,8 @@ class MockRequest:
# Verify cache was set with default TTL
mock_print.assert_any_call("Cache SET for 'resource:resource123' (expires in 60s)")

@patch("examples.05_caching_redis_custom.time.sleep")
@patch("examples.05_caching_redis_custom.print")
@patch("examples.caching_redis_custom.time.sleep")
@patch("examples.caching_redis_custom.print")
def test_get_cache_hit(self, mock_print, mock_sleep, endpoint):
"""Test that get uses cached data when available.

Expand Down
2 changes: 1 addition & 1 deletion tests/test_custom_snippet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import jwt
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Remove PyJWT usage in tests; use local jwt helper.

PyJWT is no longer a dependency; this import will fail.

Apply:

-import jwt
+import time
+from lightapi.jwt_custom import jwt_encode
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
import jwt
import time
from lightapi.jwt_custom import jwt_encode
🤖 Prompt for AI Agents
In tests/test_custom_snippet.py around line 9, remove the external "import jwt"
(PyJWT is no longer a dependency) and instead import the repository's local JWT
helper; replace the import with the local helper module/function used across the
test suite (for example import the encode helper such as from
tests.helpers.jwt_helper import jwt_encode or from tests.utils import
jwt_encode) and update any jwt.* calls in this test to use that local helper
(e.g., jwt_encode) so the test no longer depends on PyJWT.

from starlette.testclient import TestClient

from examples.07_middleware_cors_auth import Company, CustomEndpoint, create_app
from examples.middleware_cors_auth import Company, CustomEndpoint, create_app
from lightapi.config import config
from lightapi.core import Middleware, Response
from lightapi.lightapi import LightApi
Expand Down
4 changes: 2 additions & 2 deletions tests/test_filtering_pagination_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from examples.04_filtering_pagination import (
from examples.filtering_pagination import (
Product,
ProductFilter,
ProductPaginator,
Expand Down Expand Up @@ -457,7 +457,7 @@ def test_init_database(self):
# Call the init_database function with a custom engine
from unittest.mock import patch

with patch("examples.filtering_pagination_04.create_engine", return_value=engine):
with patch("examples.filtering_pagination.create_engine", return_value=engine):
init_database()

# Create a session to verify the data
Expand Down
Loading