diff --git a/changes/9595.feature.md b/changes/9595.feature.md new file mode 100644 index 00000000000..48f8e281ee3 --- /dev/null +++ b/changes/9595.feature.md @@ -0,0 +1 @@ +Register a persistent `BackendAIClientRegistry` on the webserver and use it for the `update-password-no-auth` API handler diff --git a/src/ai/backend/client/v2/auth.py b/src/ai/backend/client/v2/auth.py index fb464662288..7f6c4a0ad72 100644 --- a/src/ai/backend/client/v2/auth.py +++ b/src/ai/backend/client/v2/auth.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from collections.abc import Mapping from datetime import datetime +from typing import override from yarl import URL @@ -21,6 +22,27 @@ def sign( raise NotImplementedError +class NoAuth(AuthStrategy): + """Auth strategy that provides no credentials. + + Used by the webserver, which has no real keypair and only proxies requests. + Any accidental call to an authenticated endpoint will get a 401 from the + manager rather than crash internally. + """ + + @override + def sign( + self, + method: str, + version: str, + endpoint: URL, + date: datetime, + rel_url: str, + content_type: str, + ) -> Mapping[str, str]: + return {} + + class HMACAuth(AuthStrategy): def __init__( self, @@ -32,6 +54,7 @@ def __init__( self.secret_key: str = secret_key self.hash_type: str = hash_type + @override def sign( self, method: str, diff --git a/src/ai/backend/client/v2/base_client.py b/src/ai/backend/client/v2/base_client.py index 6be0bb2c63b..115d724676d 100644 --- a/src/ai/backend/client/v2/base_client.py +++ b/src/ai/backend/client/v2/base_client.py @@ -101,6 +101,7 @@ async def _request( *, json: Any | None = None, params: dict[str, str] | None = None, + extra_headers: Mapping[str, str] | None = None, ) -> dict[str, Any] | list[Any] | None: session = self._session content_type = "application/json" @@ -108,7 +109,9 @@ async def _request( if params: qs = "&".join(f"{k}={v}" for k, v in params.items()) rel_url = f"{rel_url}?{qs}" - headers = self._sign(method, rel_url, content_type) + headers = {**self._sign(method, rel_url, content_type)} + if extra_headers: + headers.update(extra_headers) url = self._build_url(path) async with session.request( method, @@ -136,11 +139,14 @@ async def typed_request( request: BaseRequestModel | None = None, response_model: type[ResponseT], params: dict[str, str] | None = None, + extra_headers: Mapping[str, str] | None = None, ) -> ResponseT: json_body = ( request.model_dump(mode="json", exclude_none=True) if request is not None else None ) - data = await self._request(method, path, json=json_body, params=params) + data = await self._request( + method, path, json=json_body, params=params, extra_headers=extra_headers + ) if data is None: raise BackendAPIError( 204, diff --git a/src/ai/backend/client/v2/domains/auth.py b/src/ai/backend/client/v2/domains/auth.py index 70532b9ab9d..f53d1f04f88 100644 --- a/src/ai/backend/client/v2/domains/auth.py +++ b/src/ai/backend/client/v2/domains/auth.py @@ -1,3 +1,5 @@ +from collections.abc import Mapping + from ai.backend.client.v2.base_client import BackendAIAnonymousClient, BackendAIAuthClient from ai.backend.client.v2.base_domain import BaseDomainClient from ai.backend.common.dto.manager.auth.request import ( @@ -74,13 +76,17 @@ async def update_password(self, request: UpdatePasswordRequest) -> UpdatePasswor ) async def update_password_no_auth( - self, request: UpdatePasswordNoAuthRequest + self, + request: UpdatePasswordNoAuthRequest, + *, + extra_headers: Mapping[str, str] | None = None, ) -> UpdatePasswordNoAuthResponse: return await self._anon_client.typed_request( "POST", "/auth/update-password-no-auth", request=request, response_model=UpdatePasswordNoAuthResponse, + extra_headers=extra_headers, ) async def update_full_name(self, request: UpdateFullNameRequest) -> UpdateFullNameResponse: diff --git a/src/ai/backend/web/auth.py b/src/ai/backend/web/auth.py index dde5e525db2..8f3aeb292d2 100644 --- a/src/ai/backend/web/auth.py +++ b/src/ai/backend/web/auth.py @@ -4,6 +4,7 @@ from typing import cast from aiohttp import web +from multidict import CIMultiDict from ai.backend.client.config import APIConfig from ai.backend.client.session import AsyncSession as APISession @@ -109,18 +110,29 @@ def get_client_ip(request: web.Request) -> str | None: return client_ip -def fill_forwarding_hdrs_to_api_session( - request: web.Request, - api_session: APISession, -) -> None: - _headers = { +def build_forwarding_headers(request: web.Request) -> CIMultiDict[str]: + """Build forwarding headers from the incoming request. + + Returns a ``CIMultiDict`` of ``X-Forwarded-*`` HTTP headers that can be + applied to outgoing requests via ``extra_headers`` parameters or + :func:`fill_forwarding_hdrs_to_api_session`. + """ + headers: CIMultiDict[str] = CIMultiDict({ "X-Forwarded-Host": request.headers.get("X-Forwarded-Host", request.host), "X-Forwarded-Proto": request.headers.get("X-Forwarded-Proto", request.scheme), - } + }) client_ip = get_client_ip(request) if client_ip: - _headers["X-Forwarded-For"] = client_ip - api_session.aiohttp_session.headers.update(_headers) + headers["X-Forwarded-For"] = client_ip + return headers + + +def fill_forwarding_hdrs_to_api_session( + request: web.Request, + api_session: APISession, +) -> None: + _headers = build_forwarding_headers(request) + api_session.aiohttp_session.headers.update(_headers) async def generate_jwt_token_for_session( diff --git a/src/ai/backend/web/server.py b/src/ai/backend/web/server.py index 3de0f3f7e64..09e6b2904b7 100644 --- a/src/ai/backend/web/server.py +++ b/src/ai/backend/web/server.py @@ -36,15 +36,20 @@ import uvloop from aiohttp import web from setproctitle import setproctitle +from yarl import URL from ai.backend.client.config import APIConfig from ai.backend.client.exceptions import BackendAPIError, BackendClientError from ai.backend.client.session import AsyncSession as APISession +from ai.backend.client.v2.auth import NoAuth +from ai.backend.client.v2.config import ClientConfig as V2ClientConfig +from ai.backend.client.v2.registry import BackendAIClientRegistry from ai.backend.common import config from ai.backend.common.clients.http_client.client_pool import ClientPool from ai.backend.common.clients.valkey_client.valkey_session.client import ValkeySessionClient from ai.backend.common.defs import REDIS_STATISTICS_DB, RedisRole from ai.backend.common.dto.internal.health import HealthResponse, HealthStatus +from ai.backend.common.dto.manager.auth.request import UpdatePasswordNoAuthRequest from ai.backend.common.dto.manager.auth.types import ( AuthSuccessResponse, RequireTwoFactorAuthResponse, @@ -68,7 +73,7 @@ from ai.backend.web.security import SecurityPolicy, security_policy_middleware from . import __version__, user_agent -from .auth import fill_forwarding_hdrs_to_api_session, get_client_ip +from .auth import build_forwarding_headers, fill_forwarding_hdrs_to_api_session, get_client_ip from .errors import InvalidAPIConfigurationError from .proxy import ( decrypt_payload, @@ -231,31 +236,22 @@ def _check_params(param_names: list[str]) -> web.Response | None: } try: - anon_api_config = APIConfig( - domain=config.api.domain, - endpoint=str(config.api.endpoint[0]), - access_key="", - secret_key="", # anonymous session - user_agent=user_agent, - skip_sslcert_validation=not config.api.ssl_verify, + registry: BackendAIClientRegistry = request.app["no_auth_client_registry"] + resp = await registry.auth.update_password_no_auth( + UpdatePasswordNoAuthRequest( + domain=config.api.domain, + username=creds["username"], + current_password=creds["current_password"], + new_password=creds["new_password"], + ), + extra_headers=build_forwarding_headers(request), + ) + result["password_changed_at"] = resp.password_changed_at + log.info( + "UPDATE_PASSWORD_NO_AUTH: Authorization succeeded for (email:{}, ip:{})", + creds["username"], + client_ip, ) - if not anon_api_config.is_anonymous: - raise InvalidAPIConfigurationError( - "Anonymous API configuration is not properly initialized." - ) - async with APISession(config=anon_api_config) as api_session: - fill_forwarding_hdrs_to_api_session(request, api_session) - result = await api_session.Auth.update_password_no_auth( - config.api.domain, - creds["username"], - creds["current_password"], - creds["new_password"], - ) - log.info( - "UPDATE_PASSWORD_NO_AUTH: Authorization succeeded for (email:{}, ip:{})", - creds["username"], - client_ip, - ) except BackendClientError as e: # This is error, not failed login, so we should not update login history. return web.HTTPBadGateway( @@ -764,6 +760,21 @@ async def _shutdown(_app: web.Application) -> None: yield client_pool +@asynccontextmanager +async def no_auth_client_registry_ctx( + config: WebServerUnifiedConfig, +) -> AsyncGenerator[BackendAIClientRegistry]: + client_config = V2ClientConfig( + endpoint=URL(str(config.api.endpoint[0])), + skip_ssl_verification=not config.api.ssl_verify, + ) + registry = await BackendAIClientRegistry.create(client_config, NoAuth()) + try: + yield registry + finally: + await registry.close() + + @asynccontextmanager async def webapp_ctx( config: WebServerUnifiedConfig, @@ -945,6 +956,9 @@ async def server_main( app["config"] = config app["stats"] = WebStats() app["client_pool"] = await web_init_stack.enter_async_context(client_ctx(config, app)) + app["no_auth_client_registry"] = await web_init_stack.enter_async_context( + no_auth_client_registry_ctx(config) + ) await web_init_stack.enter_async_context(redis_ctx(config, app, pidx)) # Initialize health probe