From be4063ce94cc306eed9c91076ad58d04ddd1f67b Mon Sep 17 00:00:00 2001 From: HyeockJinKim Date: Fri, 6 Mar 2026 01:59:59 +0900 Subject: [PATCH 1/5] refactor(BA-4911): migrate ScalingGroupOpts from attrs/trafaret to Pydantic BaseModel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace @attr.define with Pydantic BaseModel (frozen=True) - Add field validators for pending_timeout (float→timedelta) and allowed_session_types (str list→SessionTypes list) - Add field serializers for pending_timeout (→float), allowed_session_types (→str list), and agent_selection_strategy (→str) - Remove to_json(), from_json(), as_trafaret() methods - Change scheduler_opts column type from StructuredJSONObjectColumn to PydanticColumn - Update default from {} to ScalingGroupOpts (callable factory) - Replace to_json()/from_json() calls in gql_legacy with model_dump()/model_validate() - Remove unused imports (attr, trafaret, JSONSerializableMixin, StructuredJSONObjectColumn) Co-Authored-By: Claude Sonnet 4.6 --- .../manager/api/gql_legacy/scaling_group.py | 8 +- .../manager/models/scaling_group/row.py | 84 ++++++++----------- 2 files changed, 40 insertions(+), 52 deletions(-) diff --git a/src/ai/backend/manager/api/gql_legacy/scaling_group.py b/src/ai/backend/manager/api/gql_legacy/scaling_group.py index bcb093a1e2d..2ca80f357c9 100644 --- a/src/ai/backend/manager/api/gql_legacy/scaling_group.py +++ b/src/ai/backend/manager/api/gql_legacy/scaling_group.py @@ -460,7 +460,7 @@ def from_row( driver=row.driver, driver_opts=row.driver_opts, scheduler=row.scheduler, - scheduler_opts=row.scheduler_opts.to_json(), + scheduler_opts=row.scheduler_opts.model_dump(mode="json"), use_host_network=row.use_host_network, ) @@ -480,7 +480,7 @@ def from_orm_row( driver=row.driver, driver_opts=row.driver_opts, scheduler=row.scheduler, - scheduler_opts=row.scheduler_opts.to_json(), + scheduler_opts=row.scheduler_opts.model_dump(mode="json"), use_host_network=row.use_host_network, ) @@ -680,7 +680,7 @@ def to_updater(self, name: str) -> Updater[ScalingGroupRow]: scheduler_spec = ScalingGroupSchedulerConfigUpdaterSpec( scheduler=OptionalState.from_graphql(self.scheduler), scheduler_opts=OptionalState.from_graphql( - ScalingGroupOpts.from_json(self.scheduler_opts) + ScalingGroupOpts.model_validate(self.scheduler_opts) if self.scheduler_opts is not None and self.scheduler_opts is not Undefined else Undefined ), @@ -725,7 +725,7 @@ async def mutate( driver=props.driver, driver_opts=props.driver_opts, scheduler=props.scheduler, - scheduler_opts=ScalingGroupOpts.from_json(props.scheduler_opts), + scheduler_opts=ScalingGroupOpts.model_validate(props.scheduler_opts), use_host_network=bool(props.use_host_network), ) creator = Creator(spec=spec) diff --git a/src/ai/backend/manager/models/scaling_group/row.py b/src/ai/backend/manager/models/scaling_group/row.py index 3346d6b3ce8..e57d68adf84 100644 --- a/src/ai/backend/manager/models/scaling_group/row.py +++ b/src/ai/backend/manager/models/scaling_group/row.py @@ -13,9 +13,8 @@ override, ) -import attr import sqlalchemy as sa -import trafaret as t +from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator from sqlalchemy.dialects import postgresql as pgsql from sqlalchemy.engine.row import Row from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection @@ -31,11 +30,8 @@ ) from sqlalchemy.sql.expression import true -from ai.backend.common import validators as tx -from ai.backend.common.config import agent_selector_config_iv from ai.backend.common.types import ( AgentSelectionStrategy, - JSONSerializableMixin, SessionTypes, ) from ai.backend.manager.data.scaling_group.types import ScalingGroupData @@ -43,7 +39,6 @@ GUID, Base, PydanticColumn, - StructuredJSONObjectColumn, ) from ai.backend.manager.models.group import resolve_group_name_or_id, resolve_groups from ai.backend.manager.models.rbac import ( @@ -84,22 +79,23 @@ ) -@attr.define(slots=True) -class ScalingGroupOpts(JSONSerializableMixin): - allowed_session_types: list[SessionTypes] = attr.Factory( - lambda: [ +class ScalingGroupOpts(BaseModel): + model_config = ConfigDict(frozen=True) + + allowed_session_types: list[SessionTypes] = Field( + default_factory=lambda: [ SessionTypes.INTERACTIVE, SessionTypes.BATCH, SessionTypes.INFERENCE, - ], + ] ) pending_timeout: timedelta = timedelta(seconds=0) - config: Mapping[str, Any] = attr.field(factory=dict) + config: dict[str, Any] = Field(default_factory=dict) # Scheduler has a dedicated database column to store its name, # but agent selector configuration is stored as a part of the scheduler_opts column. agent_selection_strategy: AgentSelectionStrategy = AgentSelectionStrategy.DISPERSED - agent_selector_config: Mapping[str, Any] = attr.field(factory=dict) + agent_selector_config: dict[str, Any] = Field(default_factory=dict) # Only used in the ConcentratedAgentSelector enforce_spreading_endpoint_replica: bool = False @@ -107,44 +103,36 @@ class ScalingGroupOpts(JSONSerializableMixin): allow_fractional_resource_fragmentation: bool = True """If set to false, agent will refuse to start kernel when they are forced to fragment fractional resource request""" - route_cleanup_target_statuses: list[str] = attr.field(factory=lambda: ["unhealthy"]) + route_cleanup_target_statuses: list[str] = Field(default_factory=lambda: ["unhealthy"]) """List of route statuses that should be automatically cleaned up. Valid values: healthy, unhealthy, degraded""" - def to_json(self) -> dict[str, Any]: - return { - "allowed_session_types": [item.value for item in self.allowed_session_types], - "pending_timeout": self.pending_timeout.total_seconds(), - "config": self.config, - "agent_selection_strategy": self.agent_selection_strategy, - "agent_selector_config": self.agent_selector_config, - "enforce_spreading_endpoint_replica": self.enforce_spreading_endpoint_replica, - "allow_fractional_resource_fragmentation": self.allow_fractional_resource_fragmentation, - "route_cleanup_target_statuses": self.route_cleanup_target_statuses, - } - + @field_validator("allowed_session_types", mode="before") @classmethod - def from_json(cls, obj: Mapping[str, Any]) -> ScalingGroupOpts: - return cls(**cls.as_trafaret().check(obj)) + def validate_allowed_session_types(cls, value: Any) -> list[SessionTypes]: + if not isinstance(value, list): + raise ValueError(f"Expected a list, got {type(value)}") + return [SessionTypes(v) if isinstance(v, str) else v for v in value] + @field_serializer("allowed_session_types", mode="plain") + def serialize_allowed_session_types(self, value: list[SessionTypes]) -> list[str]: + return [item.value for item in value] + + @field_validator("pending_timeout", mode="before") @classmethod - def as_trafaret(cls) -> t.Trafaret: - return t.Dict({ - t.Key("allowed_session_types", default=["interactive", "batch"]): t.List( - tx.Enum(SessionTypes), min_length=1 - ), - t.Key("pending_timeout", default=0): tx.TimeDuration(allow_negative=False), - # Each scheduler impl refers an additional "config" key. - t.Key("config", default={}): t.Mapping(t.String, t.Any), - t.Key("agent_selection_strategy", default=AgentSelectionStrategy.DISPERSED): tx.Enum( - AgentSelectionStrategy - ), - t.Key("agent_selector_config", default={}): agent_selector_config_iv, - t.Key("enforce_spreading_endpoint_replica", default=False): t.ToBool, - t.Key("allow_fractional_resource_fragmentation", default=True): t.ToBool, - t.Key("route_cleanup_target_statuses", default=["unhealthy"]): t.List( - t.Enum("healthy", "unhealthy", "degraded") - ), - }).allow_extra("*") + def validate_pending_timeout(cls, value: Any) -> timedelta: + if isinstance(value, (int, float)): + return timedelta(seconds=value) + if isinstance(value, timedelta): + return value + raise ValueError(f"Expected a number or timedelta, got {type(value)}") + + @field_serializer("pending_timeout", mode="plain") + def serialize_pending_timeout(self, value: timedelta) -> float: + return value.total_seconds() + + @field_serializer("agent_selection_strategy", mode="plain") + def serialize_agent_selection_strategy(self, value: AgentSelectionStrategy) -> str: + return value.value # When scheduling, we take the union of allowed scaling groups for @@ -292,9 +280,9 @@ class ScalingGroupRow(Base): # type: ignore[misc] ) scheduler_opts: Mapped[ScalingGroupOpts] = mapped_column( "scheduler_opts", - StructuredJSONObjectColumn(ScalingGroupOpts), + PydanticColumn(ScalingGroupOpts), nullable=False, - default={}, + default=ScalingGroupOpts, ) fair_share_spec: Mapped[FairShareScalingGroupSpec | None] = mapped_column( "fair_share_spec", From ec4511d1f0c938b987b9774a854a500784a6f978 Mon Sep 17 00:00:00 2001 From: HyeockJinKim Date: Fri, 6 Mar 2026 02:02:58 +0900 Subject: [PATCH 2/5] test(BA-4911): add unit tests for ScalingGroupOpts Pydantic model Test roundtrip serialization, default instantiation, field validators (pending_timeout timedelta, allowed_session_types enum, agent_selection_strategy), invalid input validation errors, and backward compatibility with legacy JSON data. Co-Authored-By: Claude Sonnet 4.6 --- .../manager/models/scaling_group/test_row.py | 194 ++++++++++++++++++ 1 file changed, 194 insertions(+) create mode 100644 tests/unit/manager/models/scaling_group/test_row.py diff --git a/tests/unit/manager/models/scaling_group/test_row.py b/tests/unit/manager/models/scaling_group/test_row.py new file mode 100644 index 00000000000..2b7a84423d4 --- /dev/null +++ b/tests/unit/manager/models/scaling_group/test_row.py @@ -0,0 +1,194 @@ +"""Tests for ScalingGroupOpts Pydantic model serialization and validation.""" + +from __future__ import annotations + +from datetime import timedelta + +import pytest +from pydantic import ValidationError + +from ai.backend.common.types import AgentSelectionStrategy, SessionTypes +from ai.backend.manager.models.scaling_group.row import ScalingGroupOpts + + +class TestScalingGroupOptsDefaults: + """Test default instantiation of ScalingGroupOpts.""" + + def test_create_with_no_args(self) -> None: + opts = ScalingGroupOpts() + assert opts.allowed_session_types == [ + SessionTypes.INTERACTIVE, + SessionTypes.BATCH, + SessionTypes.INFERENCE, + ] + assert opts.pending_timeout == timedelta(seconds=0) + assert opts.config == {} + assert opts.agent_selection_strategy == AgentSelectionStrategy.DISPERSED + assert opts.agent_selector_config == {} + assert opts.enforce_spreading_endpoint_replica is False + assert opts.allow_fractional_resource_fragmentation is True + assert opts.route_cleanup_target_statuses == ["unhealthy"] + + def test_frozen_model_raises_on_mutation(self) -> None: + opts = ScalingGroupOpts() + with pytest.raises(ValidationError): + opts.pending_timeout = timedelta(seconds=30) # type: ignore[misc] + + +class TestScalingGroupOptsRoundtrip: + """Test roundtrip serialization: model_dump(mode='json') → model_validate().""" + + def test_roundtrip_defaults(self) -> None: + original = ScalingGroupOpts() + dumped = original.model_dump(mode="json") + restored = ScalingGroupOpts.model_validate(dumped) + assert restored == original + + def test_roundtrip_custom_values(self) -> None: + original = ScalingGroupOpts( + allowed_session_types=[SessionTypes.BATCH], + pending_timeout=timedelta(seconds=120), + config={"key": "value"}, + agent_selection_strategy=AgentSelectionStrategy.CONCENTRATED, + agent_selector_config={"max_agents": 2}, + enforce_spreading_endpoint_replica=True, + allow_fractional_resource_fragmentation=False, + route_cleanup_target_statuses=["unhealthy", "degraded"], + ) + dumped = original.model_dump(mode="json") + restored = ScalingGroupOpts.model_validate(dumped) + assert restored == original + + def test_model_dump_json_mode_types(self) -> None: + """Verify model_dump(mode='json') produces JSON-compatible types.""" + opts = ScalingGroupOpts( + allowed_session_types=[SessionTypes.INTERACTIVE, SessionTypes.BATCH], + pending_timeout=timedelta(seconds=60), + agent_selection_strategy=AgentSelectionStrategy.DISPERSED, + ) + dumped = opts.model_dump(mode="json") + + # pending_timeout → float seconds + assert isinstance(dumped["pending_timeout"], float) + assert dumped["pending_timeout"] == 60.0 + + # allowed_session_types → list of strings + assert isinstance(dumped["allowed_session_types"], list) + assert all(isinstance(v, str) for v in dumped["allowed_session_types"]) + assert "interactive" in dumped["allowed_session_types"] + assert "batch" in dumped["allowed_session_types"] + + # agent_selection_strategy → string + assert isinstance(dumped["agent_selection_strategy"], str) + assert dumped["agent_selection_strategy"] == AgentSelectionStrategy.DISPERSED.value + + +class TestScalingGroupOptsPendingTimeout: + """Test pending_timeout field validator and serializer.""" + + def test_validate_from_int(self) -> None: + opts = ScalingGroupOpts(pending_timeout=30) # type: ignore[arg-type] + assert opts.pending_timeout == timedelta(seconds=30) + + def test_validate_from_float(self) -> None: + opts = ScalingGroupOpts(pending_timeout=45.5) # type: ignore[arg-type] + assert opts.pending_timeout == timedelta(seconds=45.5) + + def test_validate_from_timedelta(self) -> None: + td = timedelta(minutes=2) + opts = ScalingGroupOpts(pending_timeout=td) + assert opts.pending_timeout == td + + def test_validate_from_zero(self) -> None: + opts = ScalingGroupOpts(pending_timeout=0) # type: ignore[arg-type] + assert opts.pending_timeout == timedelta(seconds=0) + + def test_serialize_to_seconds(self) -> None: + opts = ScalingGroupOpts(pending_timeout=timedelta(minutes=1, seconds=30)) + dumped = opts.model_dump(mode="json") + assert dumped["pending_timeout"] == 90.0 + + def test_validate_invalid_type_raises(self) -> None: + with pytest.raises(ValidationError): + ScalingGroupOpts(pending_timeout="not-a-number") # type: ignore[arg-type] + + +class TestScalingGroupOptsAllowedSessionTypes: + """Test allowed_session_types field validator and serializer.""" + + def test_validate_from_string_list(self) -> None: + opts = ScalingGroupOpts.model_validate({ + "allowed_session_types": ["interactive", "batch"], + }) + assert opts.allowed_session_types == [SessionTypes.INTERACTIVE, SessionTypes.BATCH] + + def test_validate_from_enum_list(self) -> None: + opts = ScalingGroupOpts(allowed_session_types=[SessionTypes.INFERENCE]) + assert opts.allowed_session_types == [SessionTypes.INFERENCE] + + def test_validate_invalid_session_type_raises(self) -> None: + with pytest.raises(ValidationError): + ScalingGroupOpts.model_validate({"allowed_session_types": ["not_a_valid_type"]}) + + def test_validate_non_list_raises(self) -> None: + with pytest.raises(ValidationError): + ScalingGroupOpts.model_validate({"allowed_session_types": "interactive"}) + + def test_serialize_to_string_list(self) -> None: + opts = ScalingGroupOpts( + allowed_session_types=[SessionTypes.INTERACTIVE, SessionTypes.INFERENCE] + ) + dumped = opts.model_dump(mode="json") + assert dumped["allowed_session_types"] == [ + SessionTypes.INTERACTIVE.value, + SessionTypes.INFERENCE.value, + ] + + +class TestScalingGroupOptsAgentSelectionStrategy: + """Test agent_selection_strategy field serializer.""" + + def test_serialize_to_string(self) -> None: + opts = ScalingGroupOpts(agent_selection_strategy=AgentSelectionStrategy.CONCENTRATED) + dumped = opts.model_dump(mode="json") + assert dumped["agent_selection_strategy"] == AgentSelectionStrategy.CONCENTRATED.value + + def test_validate_from_string(self) -> None: + opts = ScalingGroupOpts.model_validate({ + "agent_selection_strategy": AgentSelectionStrategy.CONCENTRATED.value, + }) + assert opts.agent_selection_strategy == AgentSelectionStrategy.CONCENTRATED + + +class TestScalingGroupOptsBackwardCompatibility: + """Test backward compatibility: existing JSON data (without new fields) loads correctly.""" + + def test_empty_dict_uses_all_defaults(self) -> None: + opts = ScalingGroupOpts.model_validate({}) + assert opts.allowed_session_types == [ + SessionTypes.INTERACTIVE, + SessionTypes.BATCH, + SessionTypes.INFERENCE, + ] + assert opts.pending_timeout == timedelta(seconds=0) + assert opts.config == {} + assert opts.agent_selection_strategy == AgentSelectionStrategy.DISPERSED + assert opts.agent_selector_config == {} + assert opts.enforce_spreading_endpoint_replica is False + assert opts.allow_fractional_resource_fragmentation is True + assert opts.route_cleanup_target_statuses == ["unhealthy"] + + def test_partial_legacy_data(self) -> None: + """Legacy JSON with only some fields sets defaults for missing ones.""" + legacy_data = { + "allowed_session_types": ["interactive", "batch"], + "pending_timeout": 0, + } + opts = ScalingGroupOpts.model_validate(legacy_data) + assert opts.allowed_session_types == [SessionTypes.INTERACTIVE, SessionTypes.BATCH] + assert opts.pending_timeout == timedelta(seconds=0) + # New fields get defaults + assert opts.config == {} + assert opts.agent_selection_strategy == AgentSelectionStrategy.DISPERSED + assert opts.enforce_spreading_endpoint_replica is False + assert opts.allow_fractional_resource_fragmentation is True From b008b0a61e15eeaccaa55fba97745fd0cdce59cf Mon Sep 17 00:00:00 2001 From: HyeockJinKim Date: Fri, 6 Mar 2026 02:04:15 +0900 Subject: [PATCH 3/5] changelog: add news fragment for PR #9723 --- changes/9723.enhance.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 changes/9723.enhance.md diff --git a/changes/9723.enhance.md b/changes/9723.enhance.md new file mode 100644 index 00000000000..db5a843e51e --- /dev/null +++ b/changes/9723.enhance.md @@ -0,0 +1 @@ +Migrate `ScalingGroupOpts` from attrs/trafaret to Pydantic `BaseModel` for consistency with other model classes in the scaling group module. From 25f7857690688567342b2e2966c868970120a1c5 Mon Sep 17 00:00:00 2001 From: HyeockJinKim Date: Fri, 6 Mar 2026 10:11:38 +0900 Subject: [PATCH 4/5] fix: replace ScalingGroupOpts.from_json() with model_validate() in test Co-Authored-By: Claude Opus 4.6 --- tests/unit/manager/test_scaling_group_for_enqueue.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/manager/test_scaling_group_for_enqueue.py b/tests/unit/manager/test_scaling_group_for_enqueue.py index 5873edd50d3..db6e75e462a 100644 --- a/tests/unit/manager/test_scaling_group_for_enqueue.py +++ b/tests/unit/manager/test_scaling_group_for_enqueue.py @@ -19,7 +19,7 @@ def _create_mock_sgroup(name: str, allowed_session_types: list[str]) -> MagicMoc """Create a mock scaling group with proper attribute access.""" mock = MagicMock() mock.name = name - mock.scheduler_opts = ScalingGroupOpts.from_json({ + mock.scheduler_opts = ScalingGroupOpts.model_validate({ "allowed_session_types": allowed_session_types, }) return mock From d76b825b34f8756c0cc70c6275d43070a48b4539 Mon Sep 17 00:00:00 2001 From: HyeockJinKim Date: Fri, 6 Mar 2026 11:43:08 +0900 Subject: [PATCH 5/5] Remove redundant Pydantic field validators for ScalingGroupOpts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pydantic v2 natively handles str→enum coercion for string-backed enums and int/float→timedelta coercion, making the custom `mode="before"` validators unnecessary. Serializers are retained for DB-compatible JSON output format (float seconds, string enum values). Co-Authored-By: Claude Opus 4.6 --- .../manager/models/scaling_group/row.py | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) diff --git a/src/ai/backend/manager/models/scaling_group/row.py b/src/ai/backend/manager/models/scaling_group/row.py index e57d68adf84..10877b3ecbf 100644 --- a/src/ai/backend/manager/models/scaling_group/row.py +++ b/src/ai/backend/manager/models/scaling_group/row.py @@ -14,7 +14,7 @@ ) import sqlalchemy as sa -from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator +from pydantic import BaseModel, ConfigDict, Field, field_serializer from sqlalchemy.dialects import postgresql as pgsql from sqlalchemy.engine.row import Row from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection @@ -106,26 +106,10 @@ class ScalingGroupOpts(BaseModel): route_cleanup_target_statuses: list[str] = Field(default_factory=lambda: ["unhealthy"]) """List of route statuses that should be automatically cleaned up. Valid values: healthy, unhealthy, degraded""" - @field_validator("allowed_session_types", mode="before") - @classmethod - def validate_allowed_session_types(cls, value: Any) -> list[SessionTypes]: - if not isinstance(value, list): - raise ValueError(f"Expected a list, got {type(value)}") - return [SessionTypes(v) if isinstance(v, str) else v for v in value] - @field_serializer("allowed_session_types", mode="plain") def serialize_allowed_session_types(self, value: list[SessionTypes]) -> list[str]: return [item.value for item in value] - @field_validator("pending_timeout", mode="before") - @classmethod - def validate_pending_timeout(cls, value: Any) -> timedelta: - if isinstance(value, (int, float)): - return timedelta(seconds=value) - if isinstance(value, timedelta): - return value - raise ValueError(f"Expected a number or timedelta, got {type(value)}") - @field_serializer("pending_timeout", mode="plain") def serialize_pending_timeout(self, value: timedelta) -> float: return value.total_seconds()