From a73382340a1667573ce1d484b05e864c5c75ac81 Mon Sep 17 00:00:00 2001 From: Sanghun Lee Date: Sun, 26 May 2024 12:10:12 +0900 Subject: [PATCH 1/6] impl basic APIs and client side logic --- src/ai/backend/common/validators.py | 17 ++++++ src/ai/backend/manager/config.py | 4 ++ src/ai/backend/manager/registry.py | 6 +- .../backend/manager/scheduler/dispatcher.py | 55 ++++++++++++++++--- src/ai/backend/manager/types.py | 9 +++ 5 files changed, 83 insertions(+), 8 deletions(-) diff --git a/src/ai/backend/common/validators.py b/src/ai/backend/common/validators.py index 541a7e3ad3c..c2745b9c666 100644 --- a/src/ai/backend/common/validators.py +++ b/src/ai/backend/common/validators.py @@ -218,6 +218,23 @@ def check_and_return(self, value: Any) -> T_enum: self._failure(f"value is not a valid member of {self.enum_cls.__name__}", value=value) +class EnumList(t.Trafaret, Generic[T_enum]): + def __init__(self, enum_cls: Type[T_enum], *, use_name: bool = False) -> None: + self.enum_cls = enum_cls + self.use_name = use_name + + def check_and_return(self, value: Any) -> list[T_enum]: + try: + if self.use_name: + return [self.enum_cls[val] for val in value] + else: + return [self.enum_cls(val) for val in value] + except TypeError: + self._failure("cannot parse value into list", value=value) + except (KeyError, ValueError): + self._failure(f"value is not a valid member of {self.enum_cls.__name__}", value=value) + + class JSONString(t.Trafaret): def check_and_return(self, value: Any) -> dict: try: diff --git a/src/ai/backend/manager/config.py b/src/ai/backend/manager/config.py index 8b93b29078d..16c991709cd 100644 --- a/src/ai/backend/manager/config.py +++ b/src/ai/backend/manager/config.py @@ -224,6 +224,7 @@ from .api.exceptions import ObjectNotFound, ServerMisconfiguredError from .models.session import SessionStatus from .pglock import PgAdvisoryLock +from .types import DEFAULT_AGENT_RESOURE_SYNC_TRIGGERS, AgentResourceSyncTrigger log = BraceStyleAdapter(logging.getLogger(__spec__.name)) @@ -295,6 +296,9 @@ "agent-selection-resource-priority", default=["cuda", "rocm", "tpu", "cpu", "mem"], ): t.List(t.String), + t.Key( + "agent-resource-sync-trigger", default=DEFAULT_AGENT_RESOURE_SYNC_TRIGGERS + ): tx.EnumList(AgentResourceSyncTrigger), t.Key("importer-image", default="lablup/importer:manylinux2010"): t.String, t.Key("max-wsmsg-size", default=16 * (2**20)): t.ToInt, # default: 16 MiB tx.AliasedKey(["aiomonitor-termui-port", "aiomonitor-port"], default=48100): t.ToInt[ diff --git a/src/ai/backend/manager/registry.py b/src/ai/backend/manager/registry.py index d62e78bd6eb..01eb83f708b 100644 --- a/src/ai/backend/manager/registry.py +++ b/src/ai/backend/manager/registry.py @@ -182,7 +182,7 @@ reenter_txn_session, sql_json_merge, ) -from .types import UserScope +from .types import AgentResourceSyncTrigger, UserScope if TYPE_CHECKING: from sqlalchemy.engine.row import Row @@ -1694,6 +1694,10 @@ async def _create_kernels_in_one_agent( is_local = image_info["is_local"] resource_policy: KeyPairResourcePolicyRow = image_info["resource_policy"] auto_pull = image_info["auto_pull"] + agent_resource_sync_trigger = cast( + list[AgentResourceSyncTrigger], + self.local_config["manager"]["agent-resource-sync-policy"], + ) assert agent_alloc_ctx.agent_id is not None assert scheduled_session.id is not None diff --git a/src/ai/backend/manager/scheduler/dispatcher.py b/src/ai/backend/manager/scheduler/dispatcher.py index 6c3768f3897..23389d7757f 100644 --- a/src/ai/backend/manager/scheduler/dispatcher.py +++ b/src/ai/backend/manager/scheduler/dispatcher.py @@ -20,6 +20,7 @@ Sequence, Tuple, Union, + cast, ) import aiotools @@ -55,6 +56,7 @@ from ai.backend.common.plugin.hook import PASSED, HookResult from ai.backend.common.types import ( AgentId, + AgentKernelRegistryByStatus, ClusterMode, RedisConnectionInfo, ResourceSlot, @@ -74,7 +76,7 @@ SessionNotFound, ) from ..defs import SERVICE_MAX_RETRIES, LockID -from ..exceptions import convert_to_status_data +from ..exceptions import MultiAgentError, convert_to_status_data from ..models import ( AgentRow, AgentStatus, @@ -94,6 +96,7 @@ ) from ..models.utils import ExtendedAsyncSAEngine as SAEngine from ..models.utils import execute_with_retry, sql_json_increment, sql_json_merge +from ..types import AgentResourceSyncTrigger from .predicates import ( check_concurrency, check_dependencies, @@ -265,6 +268,10 @@ async def schedule( log.debug("schedule(): triggered") manager_id = self.local_config["manager"]["id"] redis_key = f"manager.{manager_id}.schedule" + agent_resource_sync_trigger = cast( + list[AgentResourceSyncTrigger], + self.local_config["manager"]["agent-resource-sync-policy"], + ) def _pipeline(r: Redis) -> RedisPipeline: pipe = r.pipeline() @@ -293,10 +300,6 @@ def _pipeline(r: Redis) -> RedisPipeline: # as its individual steps are composed of many short-lived transactions. async with self.lock_factory(LockID.LOCKID_SCHEDULE, 60): async with self.db.begin_readonly_session() as db_sess: - # query = ( - # sa.select(ScalingGroupRow) - # .join(ScalingGroupRow.agents.and_(AgentRow.status == AgentStatus.ALIVE)) - # ) query = ( sa.select(AgentRow.scaling_group) .where(AgentRow.status == AgentStatus.ALIVE) @@ -304,12 +307,15 @@ def _pipeline(r: Redis) -> RedisPipeline: ) result = await db_sess.execute(query) schedulable_scaling_groups = [row.scaling_group for row in result.fetchall()] + produce_do_prepare = False for sgroup_name in schedulable_scaling_groups: try: - await self._schedule_in_sgroup( + kernel_agent_bindings = await self._schedule_in_sgroup( sched_ctx, sgroup_name, ) + if kernel_agent_bindings: + produce_do_prepare = True await redis_helper.execute( self.redis_live, lambda r: r.hset( @@ -320,6 +326,35 @@ def _pipeline(r: Redis) -> RedisPipeline: ) except Exception as e: log.exception("schedule({}): scheduling error!\n{}", sgroup_name, repr(e)) + else: + if ( + AgentResourceSyncTrigger.AFTER_SCHEDULING in agent_resource_sync_trigger + and kernel_agent_bindings + ): + selected_agent_ids = [ + binding.agent_alloc_ctx.agent_id + for binding in kernel_agent_bindings + if binding.agent_alloc_ctx.agent_id is not None + ] + async with self.db.begin() as db_conn: + results = await self.registry.sync_agent_resource( + self.db, selected_agent_ids + ) + for agent_id, result in results.items(): + match result: + case AgentKernelRegistryByStatus( + all_running_kernels, + actual_terminating_kernels, + actual_terminated_kernels, + ): + pass + case MultiAgentError(): + pass + case _: + pass + pass + async with SASession(bind=db_conn) as db_session: + pass await redis_helper.execute( self.redis_live, lambda r: r.hset( @@ -328,6 +363,8 @@ def _pipeline(r: Redis) -> RedisPipeline: datetime.now(tzutc()).isoformat(), ), ) + if produce_do_prepare: + await self.event_producer.produce_event(DoPrepareEvent()) except DBAPIError as e: if getattr(e.orig, "pgcode", None) == "55P03": log.info( @@ -433,7 +470,6 @@ async def _update(): ) zero = ResourceSlot() kernel_agent_bindings_in_sgroup: list[KernelAgentBinding] = [] - while len(pending_sessions) > 0: async with self.db.begin_readonly_session() as db_sess: candidate_agents = await list_schedulable_agents_by_sgroup(db_sess, sgroup_name) @@ -1011,6 +1047,11 @@ async def _finalize_scheduled() -> None: ) return kernel_agent_bindings + kernel_agent_bindings: list[KernelAgentBinding] = [] + for kernel_row in sess_ctx.kernels: + kernel_agent_bindings.append(KernelAgentBinding(kernel_row, agent_alloc_ctx, set())) + return kernel_agent_bindings + async def _schedule_multi_node_session( self, sched_ctx: SchedulingContext, diff --git a/src/ai/backend/manager/types.py b/src/ai/backend/manager/types.py index dee2842dd62..045721ae5df 100644 --- a/src/ai/backend/manager/types.py +++ b/src/ai/backend/manager/types.py @@ -56,3 +56,12 @@ class MountOptionModel(BaseModel): MountPermission | None, Field(validation_alias=AliasChoices("permission", "perm"), default=None), ] + + +class AgentResourceSyncTrigger(enum.StrEnum): + AFTER_SCHEDULING = "after-scheduling" + BEFORE_KERNEL_CREATION = "before-kernel-creation" + ON_CREATION_FAILURE = "on-creation-failure" + + +DEFAULT_AGENT_RESOURE_SYNC_TRIGGERS: list[AgentResourceSyncTrigger] = [] From c01e3b94d23a3dbce571e58fdf71462fe3c22c08 Mon Sep 17 00:00:00 2001 From: Sanghun Lee Date: Sun, 26 May 2024 23:17:39 +0900 Subject: [PATCH 2/6] revert DoPrepareEvent trigger --- src/ai/backend/manager/scheduler/dispatcher.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/ai/backend/manager/scheduler/dispatcher.py b/src/ai/backend/manager/scheduler/dispatcher.py index 23389d7757f..47a8ff69db6 100644 --- a/src/ai/backend/manager/scheduler/dispatcher.py +++ b/src/ai/backend/manager/scheduler/dispatcher.py @@ -307,15 +307,12 @@ def _pipeline(r: Redis) -> RedisPipeline: ) result = await db_sess.execute(query) schedulable_scaling_groups = [row.scaling_group for row in result.fetchall()] - produce_do_prepare = False for sgroup_name in schedulable_scaling_groups: try: kernel_agent_bindings = await self._schedule_in_sgroup( sched_ctx, sgroup_name, ) - if kernel_agent_bindings: - produce_do_prepare = True await redis_helper.execute( self.redis_live, lambda r: r.hset( @@ -363,8 +360,6 @@ def _pipeline(r: Redis) -> RedisPipeline: datetime.now(tzutc()).isoformat(), ), ) - if produce_do_prepare: - await self.event_producer.produce_event(DoPrepareEvent()) except DBAPIError as e: if getattr(e.orig, "pgcode", None) == "55P03": log.info( @@ -469,7 +464,9 @@ async def _update(): len(cancelled_sessions), ) zero = ResourceSlot() + num_scheduled = 0 kernel_agent_bindings_in_sgroup: list[KernelAgentBinding] = [] + while len(pending_sessions) > 0: async with self.db.begin_readonly_session() as db_sess: candidate_agents = await list_schedulable_agents_by_sgroup(db_sess, sgroup_name) From 5a26c2b0265e12b48bb29b18d25b6ca1c4d28c23 Mon Sep 17 00:00:00 2001 From: Sanghun Lee Date: Tue, 28 May 2024 16:11:19 +0900 Subject: [PATCH 3/6] impl agent resource sync API --- src/ai/backend/manager/registry.py | 2 +- .../backend/manager/scheduler/dispatcher.py | 87 +++++++++---------- src/ai/backend/manager/types.py | 4 +- 3 files changed, 44 insertions(+), 49 deletions(-) diff --git a/src/ai/backend/manager/registry.py b/src/ai/backend/manager/registry.py index 01eb83f708b..0dd7d52bff6 100644 --- a/src/ai/backend/manager/registry.py +++ b/src/ai/backend/manager/registry.py @@ -1696,7 +1696,7 @@ async def _create_kernels_in_one_agent( auto_pull = image_info["auto_pull"] agent_resource_sync_trigger = cast( list[AgentResourceSyncTrigger], - self.local_config["manager"]["agent-resource-sync-policy"], + self.local_config["manager"]["agent-resource-sync-trigger"], ) assert agent_alloc_ctx.agent_id is not None assert scheduled_session.id is not None diff --git a/src/ai/backend/manager/scheduler/dispatcher.py b/src/ai/backend/manager/scheduler/dispatcher.py index 47a8ff69db6..70f7951d03b 100644 --- a/src/ai/backend/manager/scheduler/dispatcher.py +++ b/src/ai/backend/manager/scheduler/dispatcher.py @@ -56,7 +56,6 @@ from ai.backend.common.plugin.hook import PASSED, HookResult from ai.backend.common.types import ( AgentId, - AgentKernelRegistryByStatus, ClusterMode, RedisConnectionInfo, ResourceSlot, @@ -76,7 +75,7 @@ SessionNotFound, ) from ..defs import SERVICE_MAX_RETRIES, LockID -from ..exceptions import MultiAgentError, convert_to_status_data +from ..exceptions import convert_to_status_data from ..models import ( AgentRow, AgentStatus, @@ -270,7 +269,7 @@ async def schedule( redis_key = f"manager.{manager_id}.schedule" agent_resource_sync_trigger = cast( list[AgentResourceSyncTrigger], - self.local_config["manager"]["agent-resource-sync-policy"], + self.local_config["manager"]["agent-resource-sync-trigger"], ) def _pipeline(r: Redis) -> RedisPipeline: @@ -307,51 +306,45 @@ def _pipeline(r: Redis) -> RedisPipeline: ) result = await db_sess.execute(query) schedulable_scaling_groups = [row.scaling_group for row in result.fetchall()] - for sgroup_name in schedulable_scaling_groups: - try: - kernel_agent_bindings = await self._schedule_in_sgroup( - sched_ctx, - sgroup_name, - ) - await redis_helper.execute( - self.redis_live, - lambda r: r.hset( - redis_key, - "resource_group", + + async with self.db.begin() as db_conn: + for sgroup_name in schedulable_scaling_groups: + try: + kernel_agent_bindings = await self._schedule_in_sgroup( + sched_ctx, sgroup_name, - ), - ) - except Exception as e: - log.exception("schedule({}): scheduling error!\n{}", sgroup_name, repr(e)) - else: - if ( - AgentResourceSyncTrigger.AFTER_SCHEDULING in agent_resource_sync_trigger - and kernel_agent_bindings - ): - selected_agent_ids = [ - binding.agent_alloc_ctx.agent_id - for binding in kernel_agent_bindings - if binding.agent_alloc_ctx.agent_id is not None - ] - async with self.db.begin() as db_conn: - results = await self.registry.sync_agent_resource( - self.db, selected_agent_ids - ) - for agent_id, result in results.items(): - match result: - case AgentKernelRegistryByStatus( - all_running_kernels, - actual_terminating_kernels, - actual_terminated_kernels, - ): - pass - case MultiAgentError(): - pass - case _: - pass - pass - async with SASession(bind=db_conn) as db_session: - pass + ) + except Exception as e: + log.exception("schedule({}): scheduling error!\n{}", sgroup_name, repr(e)) + else: + if ( + AgentResourceSyncTrigger.AFTER_SCHEDULING in agent_resource_sync_trigger + and kernel_agent_bindings + ): + selected_agent_ids = [ + binding.agent_alloc_ctx.agent_id + for binding in kernel_agent_bindings + if binding.agent_alloc_ctx.agent_id is not None + ] + async with self.db.begin() as db_conn: + results = await self.registry.sync_agent_resource( + self.db, selected_agent_ids + ) + for agent_id, result in results.items(): + match result: + case AgentKernelRegistryByStatus( + all_running_kernels, + actual_terminating_kernels, + actual_terminated_kernels, + ): + pass + case MultiAgentError(): + pass + case _: + pass + pass + async with SASession(bind=db_conn) as db_session: + pass await redis_helper.execute( self.redis_live, lambda r: r.hset( diff --git a/src/ai/backend/manager/types.py b/src/ai/backend/manager/types.py index 045721ae5df..78833ff861b 100644 --- a/src/ai/backend/manager/types.py +++ b/src/ai/backend/manager/types.py @@ -64,4 +64,6 @@ class AgentResourceSyncTrigger(enum.StrEnum): ON_CREATION_FAILURE = "on-creation-failure" -DEFAULT_AGENT_RESOURE_SYNC_TRIGGERS: list[AgentResourceSyncTrigger] = [] +DEFAULT_AGENT_RESOURE_SYNC_TRIGGERS: list[AgentResourceSyncTrigger] = [ + AgentResourceSyncTrigger.ON_CREATION_FAILURE +] From e40213bb0446c1c5cabe272720ca053a49509594 Mon Sep 17 00:00:00 2001 From: Sanghun Lee Date: Tue, 28 May 2024 16:17:12 +0900 Subject: [PATCH 4/6] add agent resource sync API endpoint --- src/ai/backend/manager/api/session.py | 37 ++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/src/ai/backend/manager/api/session.py b/src/ai/backend/manager/api/session.py index 15bce8562ef..ae634f1a41c 100644 --- a/src/ai/backend/manager/api/session.py +++ b/src/ai/backend/manager/api/session.py @@ -43,7 +43,11 @@ import trafaret as t from aiohttp import hdrs, web from dateutil.tz import tzutc -from pydantic import AliasChoices, BaseModel, Field +from pydantic import ( + AliasChoices, + BaseModel, + Field, +) from redis.asyncio import Redis from sqlalchemy.orm import noload, selectinload from sqlalchemy.sql.expression import null, true @@ -969,6 +973,36 @@ async def sync_agent_registry(request: web.Request, params: Any) -> web.StreamRe return web.json_response({}, status=200) +class SyncAgentResourceRequestModel(BaseModel): + agent_id: AgentId = Field( + validation_alias=AliasChoices("agent_id", "agent"), + description="Target agent id to sync resource.", + ) + + +@server_status_required(ALL_ALLOWED) +@auth_required +@pydantic_params_api_handler(SyncAgentResourceRequestModel) +async def sync_agent_resource( + request: web.Request, params: SyncAgentResourceRequestModel +) -> web.Response: + root_ctx: RootContext = request.app["_root.context"] + requester_access_key, owner_access_key = await get_access_key_scopes(request) + + agent_id = params.agent_id + log.info( + "SYNC_AGENT_RESOURCE (ak:{}/{}, a:{})", requester_access_key, owner_access_key, agent_id + ) + + async with root_ctx.db.begin() as db_conn: + try: + await root_ctx.registry.sync_agent_resource(db_conn, [agent_id]) + except BackendError: + log.exception("SYNC_AGENT_RESOURCE: exception") + raise + return web.Response(status=204) + + @server_status_required(ALL_ALLOWED) @auth_required @check_api_params( @@ -2315,6 +2349,7 @@ def create_app( cors.add(app.router.add_route("POST", "/_/create-cluster", create_cluster)) cors.add(app.router.add_route("GET", "/_/match", match_sessions)) cors.add(app.router.add_route("POST", "/_/sync-agent-registry", sync_agent_registry)) + cors.add(app.router.add_route("POST", "/_/sync-agent-resource", sync_agent_resource)) session_resource = cors.add(app.router.add_resource(r"/{session_name}")) cors.add(session_resource.add_route("GET", get_info)) cors.add(session_resource.add_route("PATCH", restart)) From 8bf8d76850b84204ded8c1b8d9e2132abaad2022 Mon Sep 17 00:00:00 2001 From: Sanghun Lee Date: Tue, 4 Jun 2024 15:44:04 +0900 Subject: [PATCH 5/6] change field names --- src/ai/backend/manager/scheduler/dispatcher.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/ai/backend/manager/scheduler/dispatcher.py b/src/ai/backend/manager/scheduler/dispatcher.py index 70f7951d03b..d3477f103bb 100644 --- a/src/ai/backend/manager/scheduler/dispatcher.py +++ b/src/ai/backend/manager/scheduler/dispatcher.py @@ -56,6 +56,7 @@ from ai.backend.common.plugin.hook import PASSED, HookResult from ai.backend.common.types import ( AgentId, + AgentKernelRegistryByStatus, ClusterMode, RedisConnectionInfo, ResourceSlot, @@ -75,7 +76,7 @@ SessionNotFound, ) from ..defs import SERVICE_MAX_RETRIES, LockID -from ..exceptions import convert_to_status_data +from ..exceptions import MultiAgentError, convert_to_status_data from ..models import ( AgentRow, AgentStatus, @@ -315,10 +316,13 @@ def _pipeline(r: Redis) -> RedisPipeline: sgroup_name, ) except Exception as e: - log.exception("schedule({}): scheduling error!\n{}", sgroup_name, repr(e)) + log.exception( + "schedule({}): scheduling error!\n{}", sgroup_name, repr(e) + ) else: if ( - AgentResourceSyncTrigger.AFTER_SCHEDULING in agent_resource_sync_trigger + AgentResourceSyncTrigger.AFTER_SCHEDULING + in agent_resource_sync_trigger and kernel_agent_bindings ): selected_agent_ids = [ From d7789703cf50c62ac5f990e0a6b50fc96d5f9cbc Mon Sep 17 00:00:00 2001 From: Sanghun Lee Date: Fri, 12 Jul 2024 18:00:59 +0900 Subject: [PATCH 6/6] dont pass db cxn to sync_agent_resource() and fix wrong types --- src/ai/backend/manager/api/session.py | 23 ++++-- src/ai/backend/manager/registry.py | 17 ++++- .../backend/manager/scheduler/dispatcher.py | 75 +++++++------------ 3 files changed, 59 insertions(+), 56 deletions(-) diff --git a/src/ai/backend/manager/api/session.py b/src/ai/backend/manager/api/session.py index ae634f1a41c..ad041e33322 100644 --- a/src/ai/backend/manager/api/session.py +++ b/src/ai/backend/manager/api/session.py @@ -77,6 +77,7 @@ ClusterMode, ImageRegistry, KernelId, + KernelStatusCollection, MountPermission, MountTypes, SessionTypes, @@ -86,6 +87,7 @@ from ..config import DEFAULT_CHUNK_SIZE from ..defs import DEFAULT_IMAGE_ARCH, DEFAULT_ROLE +from ..exceptions import MultiAgentError from ..models import ( AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES, DEAD_SESSION_STATUSES, @@ -994,13 +996,20 @@ async def sync_agent_resource( "SYNC_AGENT_RESOURCE (ak:{}/{}, a:{})", requester_access_key, owner_access_key, agent_id ) - async with root_ctx.db.begin() as db_conn: - try: - await root_ctx.registry.sync_agent_resource(db_conn, [agent_id]) - except BackendError: - log.exception("SYNC_AGENT_RESOURCE: exception") - raise - return web.Response(status=204) + try: + result = await root_ctx.registry.sync_agent_resource(root_ctx.db, [agent_id]) + except BackendError: + log.exception("SYNC_AGENT_RESOURCE: exception") + raise + val = result.get(agent_id) + match val: + case KernelStatusCollection(): + pass + case MultiAgentError(): + return web.Response(status=500) + case _: + pass + return web.Response(status=204) @server_status_required(ALL_ALLOWED) diff --git a/src/ai/backend/manager/registry.py b/src/ai/backend/manager/registry.py index 0dd7d52bff6..320a1d037cd 100644 --- a/src/ai/backend/manager/registry.py +++ b/src/ai/backend/manager/registry.py @@ -129,7 +129,7 @@ ) from .config import LocalConfig, SharedConfig from .defs import DEFAULT_IMAGE_ARCH, DEFAULT_ROLE, INTRINSIC_SLOTS -from .exceptions import MultiAgentError, convert_to_status_data +from .exceptions import ErrorStatusInfo, MultiAgentError, convert_to_status_data from .models import ( AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES, AGENT_RESOURCE_OCCUPYING_SESSION_STATUSES, @@ -1794,6 +1794,9 @@ async def _update_kernel() -> None: ex = e err_info = convert_to_status_data(ex, self.debug) + def _is_insufficient_resource_err(err_info: ErrorStatusInfo) -> bool: + return err_info["error"]["name"] == "InsufficientResource" + # The agent has already cancelled or issued the destruction lifecycle event # for this batch of kernels. for binding in items: @@ -1825,6 +1828,18 @@ async def _update_failure() -> None: await db_sess.execute(query) await execute_with_retry(_update_failure) + if ( + AgentResourceSyncTrigger.ON_CREATION_FAILURE in agent_resource_sync_trigger + and _is_insufficient_resource_err(err_info) + ): + await self.sync_agent_resource( + self.db, + [ + binding.agent_alloc_ctx.agent_id + for binding in items + if binding.agent_alloc_ctx.agent_id is not None + ], + ) raise async def create_cluster_ssh_keypair(self) -> ClusterSSHKeyPair: diff --git a/src/ai/backend/manager/scheduler/dispatcher.py b/src/ai/backend/manager/scheduler/dispatcher.py index d3477f103bb..1c3d49da8aa 100644 --- a/src/ai/backend/manager/scheduler/dispatcher.py +++ b/src/ai/backend/manager/scheduler/dispatcher.py @@ -56,7 +56,6 @@ from ai.backend.common.plugin.hook import PASSED, HookResult from ai.backend.common.types import ( AgentId, - AgentKernelRegistryByStatus, ClusterMode, RedisConnectionInfo, ResourceSlot, @@ -76,7 +75,7 @@ SessionNotFound, ) from ..defs import SERVICE_MAX_RETRIES, LockID -from ..exceptions import MultiAgentError, convert_to_status_data +from ..exceptions import convert_to_status_data from ..models import ( AgentRow, AgentStatus, @@ -308,47 +307,33 @@ def _pipeline(r: Redis) -> RedisPipeline: result = await db_sess.execute(query) schedulable_scaling_groups = [row.scaling_group for row in result.fetchall()] - async with self.db.begin() as db_conn: - for sgroup_name in schedulable_scaling_groups: - try: - kernel_agent_bindings = await self._schedule_in_sgroup( - sched_ctx, + for sgroup_name in schedulable_scaling_groups: + try: + kernel_agent_bindings = await self._schedule_in_sgroup( + sched_ctx, + sgroup_name, + ) + await redis_helper.execute( + self.redis_live, + lambda r: r.hset( + redis_key, + "resource_group", sgroup_name, - ) - except Exception as e: - log.exception( - "schedule({}): scheduling error!\n{}", sgroup_name, repr(e) - ) - else: - if ( - AgentResourceSyncTrigger.AFTER_SCHEDULING - in agent_resource_sync_trigger - and kernel_agent_bindings - ): - selected_agent_ids = [ - binding.agent_alloc_ctx.agent_id - for binding in kernel_agent_bindings - if binding.agent_alloc_ctx.agent_id is not None - ] - async with self.db.begin() as db_conn: - results = await self.registry.sync_agent_resource( - self.db, selected_agent_ids - ) - for agent_id, result in results.items(): - match result: - case AgentKernelRegistryByStatus( - all_running_kernels, - actual_terminating_kernels, - actual_terminated_kernels, - ): - pass - case MultiAgentError(): - pass - case _: - pass - pass - async with SASession(bind=db_conn) as db_session: - pass + ), + ) + except Exception as e: + log.exception("schedule({}): scheduling error!\n{}", sgroup_name, repr(e)) + else: + if ( + AgentResourceSyncTrigger.AFTER_SCHEDULING in agent_resource_sync_trigger + and kernel_agent_bindings + ): + selected_agent_ids = [ + binding.agent_alloc_ctx.agent_id + for binding in kernel_agent_bindings + if binding.agent_alloc_ctx.agent_id is not None + ] + await self.registry.sync_agent_resource(self.db, selected_agent_ids) await redis_helper.execute( self.redis_live, lambda r: r.hset( @@ -461,7 +446,6 @@ async def _update(): len(cancelled_sessions), ) zero = ResourceSlot() - num_scheduled = 0 kernel_agent_bindings_in_sgroup: list[KernelAgentBinding] = [] while len(pending_sessions) > 0: @@ -1041,11 +1025,6 @@ async def _finalize_scheduled() -> None: ) return kernel_agent_bindings - kernel_agent_bindings: list[KernelAgentBinding] = [] - for kernel_row in sess_ctx.kernels: - kernel_agent_bindings.append(KernelAgentBinding(kernel_row, agent_alloc_ctx, set())) - return kernel_agent_bindings - async def _schedule_multi_node_session( self, sched_ctx: SchedulingContext,