Skip to content
1 change: 1 addition & 0 deletions changes/9595.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Register a persistent `BackendAIClientRegistry` on the webserver and use it for the `update-password-no-auth` API handler
23 changes: 23 additions & 0 deletions src/ai/backend/client/v2/auth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from collections.abc import Mapping
from datetime import datetime
from typing import override

from yarl import URL

Expand All @@ -21,6 +22,27 @@ def sign(
raise NotImplementedError


class NoAuth(AuthStrategy):
"""Auth strategy that provides no credentials.

Used by the webserver, which has no real keypair and only proxies requests.
Any accidental call to an authenticated endpoint will get a 401 from the
manager rather than crash internally.
"""

@override
def sign(
self,
method: str,
version: str,
endpoint: URL,
date: datetime,
rel_url: str,
content_type: str,
) -> Mapping[str, str]:
return {}


class HMACAuth(AuthStrategy):
def __init__(
self,
Expand All @@ -32,6 +54,7 @@ def __init__(
self.secret_key: str = secret_key
self.hash_type: str = hash_type

@override
def sign(
self,
method: str,
Expand Down
10 changes: 8 additions & 2 deletions src/ai/backend/client/v2/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,17 @@ async def _request(
*,
json: Any | None = None,
params: dict[str, str] | None = None,
extra_headers: Mapping[str, str] | None = None,
) -> dict[str, Any] | list[Any] | None:
session = self._session
content_type = "application/json"
rel_url = "/" + path.lstrip("/")
if params:
qs = "&".join(f"{k}={v}" for k, v in params.items())
rel_url = f"{rel_url}?{qs}"
headers = self._sign(method, rel_url, content_type)
headers = {**self._sign(method, rel_url, content_type)}
if extra_headers:
headers.update(extra_headers)
url = self._build_url(path)
async with session.request(
method,
Expand Down Expand Up @@ -136,11 +139,14 @@ async def typed_request(
request: BaseRequestModel | None = None,
response_model: type[ResponseT],
params: dict[str, str] | None = None,
extra_headers: Mapping[str, str] | None = None,
) -> ResponseT:
json_body = (
request.model_dump(mode="json", exclude_none=True) if request is not None else None
)
data = await self._request(method, path, json=json_body, params=params)
data = await self._request(
method, path, json=json_body, params=params, extra_headers=extra_headers
)
if data is None:
raise BackendAPIError(
204,
Expand Down
8 changes: 7 additions & 1 deletion src/ai/backend/client/v2/domains/auth.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections.abc import Mapping

from ai.backend.client.v2.base_client import BackendAIAnonymousClient, BackendAIAuthClient
from ai.backend.client.v2.base_domain import BaseDomainClient
from ai.backend.common.dto.manager.auth.request import (
Expand Down Expand Up @@ -74,13 +76,17 @@ async def update_password(self, request: UpdatePasswordRequest) -> UpdatePasswor
)

async def update_password_no_auth(
self, request: UpdatePasswordNoAuthRequest
self,
request: UpdatePasswordNoAuthRequest,
*,
extra_headers: Mapping[str, str] | None = None,
) -> UpdatePasswordNoAuthResponse:
return await self._anon_client.typed_request(
"POST",
"/auth/update-password-no-auth",
request=request,
response_model=UpdatePasswordNoAuthResponse,
extra_headers=extra_headers,
)

async def update_full_name(self, request: UpdateFullNameRequest) -> UpdateFullNameResponse:
Expand Down
28 changes: 20 additions & 8 deletions src/ai/backend/web/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import cast

from aiohttp import web
from multidict import CIMultiDict

from ai.backend.client.config import APIConfig
from ai.backend.client.session import AsyncSession as APISession
Expand Down Expand Up @@ -109,18 +110,29 @@ def get_client_ip(request: web.Request) -> str | None:
return client_ip


def fill_forwarding_hdrs_to_api_session(
request: web.Request,
api_session: APISession,
) -> None:
_headers = {
def build_forwarding_headers(request: web.Request) -> CIMultiDict[str]:
"""Build forwarding headers from the incoming request.

Returns a ``CIMultiDict`` of ``X-Forwarded-*`` HTTP headers that can be
applied to outgoing requests via ``extra_headers`` parameters or
:func:`fill_forwarding_hdrs_to_api_session`.
"""
headers: CIMultiDict[str] = CIMultiDict({
"X-Forwarded-Host": request.headers.get("X-Forwarded-Host", request.host),
"X-Forwarded-Proto": request.headers.get("X-Forwarded-Proto", request.scheme),
}
})
client_ip = get_client_ip(request)
if client_ip:
_headers["X-Forwarded-For"] = client_ip
api_session.aiohttp_session.headers.update(_headers)
headers["X-Forwarded-For"] = client_ip
return headers


def fill_forwarding_hdrs_to_api_session(
request: web.Request,
api_session: APISession,
) -> None:
_headers = build_forwarding_headers(request)
api_session.aiohttp_session.headers.update(_headers)


async def generate_jwt_token_for_session(
Expand Down
64 changes: 39 additions & 25 deletions src/ai/backend/web/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,20 @@
import uvloop
from aiohttp import web
from setproctitle import setproctitle
from yarl import URL

from ai.backend.client.config import APIConfig
from ai.backend.client.exceptions import BackendAPIError, BackendClientError
from ai.backend.client.session import AsyncSession as APISession
from ai.backend.client.v2.auth import NoAuth
from ai.backend.client.v2.config import ClientConfig as V2ClientConfig
from ai.backend.client.v2.registry import BackendAIClientRegistry
from ai.backend.common import config
from ai.backend.common.clients.http_client.client_pool import ClientPool
from ai.backend.common.clients.valkey_client.valkey_session.client import ValkeySessionClient
from ai.backend.common.defs import REDIS_STATISTICS_DB, RedisRole
from ai.backend.common.dto.internal.health import HealthResponse, HealthStatus
from ai.backend.common.dto.manager.auth.request import UpdatePasswordNoAuthRequest
from ai.backend.common.dto.manager.auth.types import (
AuthSuccessResponse,
RequireTwoFactorAuthResponse,
Expand All @@ -68,7 +73,7 @@
from ai.backend.web.security import SecurityPolicy, security_policy_middleware

from . import __version__, user_agent
from .auth import fill_forwarding_hdrs_to_api_session, get_client_ip
from .auth import build_forwarding_headers, fill_forwarding_hdrs_to_api_session, get_client_ip
from .errors import InvalidAPIConfigurationError
from .proxy import (
decrypt_payload,
Expand Down Expand Up @@ -231,31 +236,22 @@ def _check_params(param_names: list[str]) -> web.Response | None:
}

try:
anon_api_config = APIConfig(
domain=config.api.domain,
endpoint=str(config.api.endpoint[0]),
access_key="",
secret_key="", # anonymous session
user_agent=user_agent,
skip_sslcert_validation=not config.api.ssl_verify,
registry: BackendAIClientRegistry = request.app["no_auth_client_registry"]
resp = await registry.auth.update_password_no_auth(
UpdatePasswordNoAuthRequest(
domain=config.api.domain,
username=creds["username"],
current_password=creds["current_password"],
new_password=creds["new_password"],
),
extra_headers=build_forwarding_headers(request),
)
result["password_changed_at"] = resp.password_changed_at
log.info(
"UPDATE_PASSWORD_NO_AUTH: Authorization succeeded for (email:{}, ip:{})",
creds["username"],
client_ip,
)
if not anon_api_config.is_anonymous:
raise InvalidAPIConfigurationError(
"Anonymous API configuration is not properly initialized."
)
async with APISession(config=anon_api_config) as api_session:
fill_forwarding_hdrs_to_api_session(request, api_session)
result = await api_session.Auth.update_password_no_auth(
config.api.domain,
creds["username"],
creds["current_password"],
creds["new_password"],
)
log.info(
"UPDATE_PASSWORD_NO_AUTH: Authorization succeeded for (email:{}, ip:{})",
creds["username"],
client_ip,
)
except BackendClientError as e:
# This is error, not failed login, so we should not update login history.
return web.HTTPBadGateway(
Expand Down Expand Up @@ -764,6 +760,21 @@ async def _shutdown(_app: web.Application) -> None:
yield client_pool


@asynccontextmanager
async def no_auth_client_registry_ctx(
config: WebServerUnifiedConfig,
) -> AsyncGenerator[BackendAIClientRegistry]:
client_config = V2ClientConfig(
endpoint=URL(str(config.api.endpoint[0])),
skip_ssl_verification=not config.api.ssl_verify,
)
registry = await BackendAIClientRegistry.create(client_config, NoAuth())
try:
yield registry
finally:
await registry.close()


@asynccontextmanager
async def webapp_ctx(
config: WebServerUnifiedConfig,
Expand Down Expand Up @@ -945,6 +956,9 @@ async def server_main(
app["config"] = config
app["stats"] = WebStats()
app["client_pool"] = await web_init_stack.enter_async_context(client_ctx(config, app))
app["no_auth_client_registry"] = await web_init_stack.enter_async_context(
no_auth_client_registry_ctx(config)
)
await web_init_stack.enter_async_context(redis_ctx(config, app, pidx))

# Initialize health probe
Expand Down
Loading