diff --git a/src/ai/providers/ai_gateway/client/errors.py b/src/ai/providers/ai_gateway/client/errors.py index a7000366..0cfa774f 100644 --- a/src/ai/providers/ai_gateway/client/errors.py +++ b/src/ai/providers/ai_gateway/client/errors.py @@ -33,6 +33,14 @@ class GatewayError(Exception): status_code: int generation_id: str | None is_retryable: bool + response_body: Any = None + """Full HTTP error response body from the gateway (parsed JSON when + possible, else the raw text). + + The gateway includes provider routing/fallback details here -- which + backends it tried and how each one failed -- so the whole body is + retained rather than just the extracted ``message``/``type``. + """ def __init__( self, @@ -41,12 +49,14 @@ def __init__( status_code: int = 500, generation_id: str | None = None, is_retryable: bool | None = None, + response_body: Any = None, ) -> None: display = f"{message} [{generation_id}]" if generation_id else message super().__init__(display) self.message = display self.status_code = status_code self.generation_id = generation_id + self.response_body = response_body self.is_retryable = ( status_code in {408, 409, 429} or status_code >= 500 if is_retryable is None @@ -210,8 +220,8 @@ def __init__( message, status_code=status_code, generation_id=generation_id, + response_body=response_body, ) - self.response_body = response_body self.validation_error = validation_error @@ -295,9 +305,10 @@ def create_gateway_error( error_type: str | None = error_obj.get("type") generation_id: str | None = body.get("generationId") + err: GatewayError match error_type: case "authentication_error": - return GatewayAuthenticationError.create_contextual( + err = GatewayAuthenticationError.create_contextual( api_key_provided=api_key_provided, status_code=status_code, generation_id=generation_id, @@ -306,7 +317,7 @@ def create_gateway_error( case "model_not_found": param = error_obj.get("param") model_id = param.get("modelId") if isinstance(param, dict) else None - return GatewayModelNotFoundError( + err = GatewayModelNotFoundError( message=message, status_code=status_code, model_id=model_id, @@ -315,8 +326,13 @@ def create_gateway_error( case _: cls = _TYPE_MAP.get(error_type or "", GatewayInternalServerError) - return cls( + err = cls( message=message, status_code=status_code, generation_id=generation_id, ) + + # Retain the full body (provider routing/fallback details live alongside + # the extracted message/type) for surfacing to callers. + err.response_body = body + return err diff --git a/src/ai/providers/ai_gateway/errors.py b/src/ai/providers/ai_gateway/errors.py index acfe78f1..0f8d4b45 100644 --- a/src/ai/providers/ai_gateway/errors.py +++ b/src/ai/providers/ai_gateway/errors.py @@ -20,15 +20,14 @@ def map_error(exc: client_errors.GatewayError) -> ai_errors.ProviderAPIError: model_id=exc.model_id, provider="ai-gateway", http_context=_http_context(exc), + body=exc.response_body, error_type=exc.type, is_retryable=exc.is_retryable, ) if isinstance(exc, client_errors.GatewayInternalServerError): return _mapped(ai_errors.ProviderInternalServerError, exc) if isinstance(exc, client_errors.GatewayResponseError): - return _mapped( - ai_errors.ProviderResponseError, exc, body=exc.response_body - ) + return _mapped(ai_errors.ProviderResponseError, exc) if isinstance(exc, client_errors.GatewayTimeoutError): return _mapped(ai_errors.ProviderTimeoutError, exc) return _mapped(ai_errors.ProviderAPIError, exc) @@ -37,14 +36,12 @@ def map_error(exc: client_errors.GatewayError) -> ai_errors.ProviderAPIError: def _mapped( cls: type[ai_errors.ProviderAPIError], exc: client_errors.GatewayError, - *, - body: object | None = None, ) -> ai_errors.ProviderAPIError: return cls( str(exc), provider="ai-gateway", http_context=_http_context(exc), - body=body, + body=exc.response_body, error_type=exc.type, is_retryable=exc.is_retryable, ) diff --git a/tests/providers/ai_gateway/test_errors.py b/tests/providers/ai_gateway/test_errors.py index 2c737a29..6e55ce1d 100644 --- a/tests/providers/ai_gateway/test_errors.py +++ b/tests/providers/ai_gateway/test_errors.py @@ -199,3 +199,43 @@ def test_response_error_mapping_preserves_response_body(self) -> None: assert mapped.body == {"raw": True} assert mapped.http_context is not None assert mapped.http_context.status_code == 502 + + def test_full_body_with_routing_info_is_retained(self) -> None: + # Routing/fallback details live alongside the extracted message, + # so the whole body must survive parsing and mapping. + body = { + "error": { + "message": "All providers failed", + "type": "internal_server_error", + }, + "routing": { + "attempts": [ + {"provider": "anthropic", "error": "overloaded"}, + {"provider": "bedrock", "error": "timeout"}, + ] + }, + } + err = client_errors.create_gateway_error( + response_body=json.dumps(body), status_code=500 + ) + assert err.response_body == body + mapped = errors.map_error(err) + assert mapped.body == body + + def test_body_retained_for_every_mapped_error_type(self) -> None: + for error_type, status in ( + ("authentication_error", 401), + ("invalid_request_error", 400), + ("rate_limit_exceeded", 429), + ("model_not_found", 404), + ("internal_server_error", 500), + ): + body = { + "error": {"message": "boom", "type": error_type}, + "routing": {"tried": ["a", "b"]}, + } + err = client_errors.create_gateway_error( + response_body=body, status_code=status + ) + assert err.response_body == body, error_type + assert errors.map_error(err).body == body, error_type