Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions src/ai/providers/ai_gateway/client/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ class GatewayError(Exception):
status_code: int
generation_id: str | None
is_retryable: bool
response_body: Any = None
"""Full HTTP error response body from the gateway (parsed JSON when
possible, else the raw text).

The gateway includes provider routing/fallback details here -- which
backends it tried and how each one failed -- so the whole body is
retained rather than just the extracted ``message``/``type``.
"""

def __init__(
self,
Expand All @@ -41,12 +49,14 @@ def __init__(
status_code: int = 500,
generation_id: str | None = None,
is_retryable: bool | None = None,
response_body: Any = None,
) -> None:
display = f"{message} [{generation_id}]" if generation_id else message
super().__init__(display)
self.message = display
self.status_code = status_code
self.generation_id = generation_id
self.response_body = response_body
self.is_retryable = (
status_code in {408, 409, 429} or status_code >= 500
if is_retryable is None
Expand Down Expand Up @@ -210,8 +220,8 @@ def __init__(
message,
status_code=status_code,
generation_id=generation_id,
response_body=response_body,
)
self.response_body = response_body
self.validation_error = validation_error


Expand Down Expand Up @@ -295,9 +305,10 @@ def create_gateway_error(
error_type: str | None = error_obj.get("type")
generation_id: str | None = body.get("generationId")

err: GatewayError
match error_type:
case "authentication_error":
return GatewayAuthenticationError.create_contextual(
err = GatewayAuthenticationError.create_contextual(
api_key_provided=api_key_provided,
status_code=status_code,
generation_id=generation_id,
Expand All @@ -306,7 +317,7 @@ def create_gateway_error(
case "model_not_found":
param = error_obj.get("param")
model_id = param.get("modelId") if isinstance(param, dict) else None
return GatewayModelNotFoundError(
err = GatewayModelNotFoundError(
message=message,
status_code=status_code,
model_id=model_id,
Expand All @@ -315,8 +326,13 @@ def create_gateway_error(

case _:
cls = _TYPE_MAP.get(error_type or "", GatewayInternalServerError)
return cls(
err = cls(
message=message,
status_code=status_code,
generation_id=generation_id,
)

# Retain the full body (provider routing/fallback details live alongside
# the extracted message/type) for surfacing to callers.
err.response_body = body
return err
9 changes: 3 additions & 6 deletions src/ai/providers/ai_gateway/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,14 @@ def map_error(exc: client_errors.GatewayError) -> ai_errors.ProviderAPIError:
model_id=exc.model_id,
provider="ai-gateway",
http_context=_http_context(exc),
body=exc.response_body,
error_type=exc.type,
is_retryable=exc.is_retryable,
)
if isinstance(exc, client_errors.GatewayInternalServerError):
return _mapped(ai_errors.ProviderInternalServerError, exc)
if isinstance(exc, client_errors.GatewayResponseError):
return _mapped(
ai_errors.ProviderResponseError, exc, body=exc.response_body
)
return _mapped(ai_errors.ProviderResponseError, exc)
if isinstance(exc, client_errors.GatewayTimeoutError):
return _mapped(ai_errors.ProviderTimeoutError, exc)
return _mapped(ai_errors.ProviderAPIError, exc)
Expand All @@ -37,14 +36,12 @@ def map_error(exc: client_errors.GatewayError) -> ai_errors.ProviderAPIError:
def _mapped(
cls: type[ai_errors.ProviderAPIError],
exc: client_errors.GatewayError,
*,
body: object | None = None,
) -> ai_errors.ProviderAPIError:
return cls(
str(exc),
provider="ai-gateway",
http_context=_http_context(exc),
body=body,
body=exc.response_body,
error_type=exc.type,
is_retryable=exc.is_retryable,
)
Expand Down
40 changes: 40 additions & 0 deletions tests/providers/ai_gateway/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,43 @@ def test_response_error_mapping_preserves_response_body(self) -> None:
assert mapped.body == {"raw": True}
assert mapped.http_context is not None
assert mapped.http_context.status_code == 502

def test_full_body_with_routing_info_is_retained(self) -> None:
# Routing/fallback details live alongside the extracted message,
# so the whole body must survive parsing and mapping.
body = {
"error": {
"message": "All providers failed",
"type": "internal_server_error",
},
"routing": {
"attempts": [
{"provider": "anthropic", "error": "overloaded"},
{"provider": "bedrock", "error": "timeout"},
]
},
}
err = client_errors.create_gateway_error(
response_body=json.dumps(body), status_code=500
)
assert err.response_body == body
mapped = errors.map_error(err)
assert mapped.body == body

def test_body_retained_for_every_mapped_error_type(self) -> None:
for error_type, status in (
("authentication_error", 401),
("invalid_request_error", 400),
("rate_limit_exceeded", 429),
("model_not_found", 404),
("internal_server_error", 500),
):
body = {
"error": {"message": "boom", "type": error_type},
"routing": {"tried": ["a", "b"]},
}
err = client_errors.create_gateway_error(
response_body=body, status_code=status
)
assert err.response_body == body, error_type
assert errors.map_error(err).body == body, error_type
Loading