Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/9723.enhance.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Migrate `ScalingGroupOpts` from attrs/trafaret to Pydantic `BaseModel` for consistency with other model classes in the scaling group module.
8 changes: 4 additions & 4 deletions src/ai/backend/manager/api/gql_legacy/scaling_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ def from_row(
driver=row.driver,
driver_opts=row.driver_opts,
scheduler=row.scheduler,
scheduler_opts=row.scheduler_opts.to_json(),
scheduler_opts=row.scheduler_opts.model_dump(mode="json"),
use_host_network=row.use_host_network,
)

Expand All @@ -480,7 +480,7 @@ def from_orm_row(
driver=row.driver,
driver_opts=row.driver_opts,
scheduler=row.scheduler,
scheduler_opts=row.scheduler_opts.to_json(),
scheduler_opts=row.scheduler_opts.model_dump(mode="json"),
use_host_network=row.use_host_network,
)

Expand Down Expand Up @@ -680,7 +680,7 @@ def to_updater(self, name: str) -> Updater[ScalingGroupRow]:
scheduler_spec = ScalingGroupSchedulerConfigUpdaterSpec(
scheduler=OptionalState.from_graphql(self.scheduler),
scheduler_opts=OptionalState.from_graphql(
ScalingGroupOpts.from_json(self.scheduler_opts)
ScalingGroupOpts.model_validate(self.scheduler_opts)
if self.scheduler_opts is not None and self.scheduler_opts is not Undefined
else Undefined
),
Expand Down Expand Up @@ -725,7 +725,7 @@ async def mutate(
driver=props.driver,
driver_opts=props.driver_opts,
scheduler=props.scheduler,
scheduler_opts=ScalingGroupOpts.from_json(props.scheduler_opts),
scheduler_opts=ScalingGroupOpts.model_validate(props.scheduler_opts),
use_host_network=bool(props.use_host_network),
)
creator = Creator(spec=spec)
Expand Down
70 changes: 21 additions & 49 deletions src/ai/backend/manager/models/scaling_group/row.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@
override,
)

import attr
import sqlalchemy as sa
import trafaret as t
from pydantic import BaseModel, ConfigDict, Field, field_serializer
from sqlalchemy.dialects import postgresql as pgsql
from sqlalchemy.engine.row import Row
from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection
Expand All @@ -31,19 +30,15 @@
)
from sqlalchemy.sql.expression import true

from ai.backend.common import validators as tx
from ai.backend.common.config import agent_selector_config_iv
from ai.backend.common.types import (
AgentSelectionStrategy,
JSONSerializableMixin,
SessionTypes,
)
from ai.backend.manager.data.scaling_group.types import ScalingGroupData
from ai.backend.manager.models.base import (
GUID,
Base,
PydanticColumn,
StructuredJSONObjectColumn,
)
from ai.backend.manager.models.group import resolve_group_name_or_id, resolve_groups
from ai.backend.manager.models.rbac import (
Expand Down Expand Up @@ -84,67 +79,44 @@
)


@attr.define(slots=True)
class ScalingGroupOpts(JSONSerializableMixin):
allowed_session_types: list[SessionTypes] = attr.Factory(
lambda: [
class ScalingGroupOpts(BaseModel):
model_config = ConfigDict(frozen=True)

allowed_session_types: list[SessionTypes] = Field(
default_factory=lambda: [
SessionTypes.INTERACTIVE,
SessionTypes.BATCH,
SessionTypes.INFERENCE,
],
]
)
pending_timeout: timedelta = timedelta(seconds=0)
config: Mapping[str, Any] = attr.field(factory=dict)
config: dict[str, Any] = Field(default_factory=dict)

# Scheduler has a dedicated database column to store its name,
# but agent selector configuration is stored as a part of the scheduler_opts column.
agent_selection_strategy: AgentSelectionStrategy = AgentSelectionStrategy.DISPERSED
agent_selector_config: Mapping[str, Any] = attr.field(factory=dict)
agent_selector_config: dict[str, Any] = Field(default_factory=dict)

# Only used in the ConcentratedAgentSelector
enforce_spreading_endpoint_replica: bool = False

allow_fractional_resource_fragmentation: bool = True
"""If set to false, agent will refuse to start kernel when they are forced to fragment fractional resource request"""

route_cleanup_target_statuses: list[str] = attr.field(factory=lambda: ["unhealthy"])
route_cleanup_target_statuses: list[str] = Field(default_factory=lambda: ["unhealthy"])
"""List of route statuses that should be automatically cleaned up. Valid values: healthy, unhealthy, degraded"""

def to_json(self) -> dict[str, Any]:
return {
"allowed_session_types": [item.value for item in self.allowed_session_types],
"pending_timeout": self.pending_timeout.total_seconds(),
"config": self.config,
"agent_selection_strategy": self.agent_selection_strategy,
"agent_selector_config": self.agent_selector_config,
"enforce_spreading_endpoint_replica": self.enforce_spreading_endpoint_replica,
"allow_fractional_resource_fragmentation": self.allow_fractional_resource_fragmentation,
"route_cleanup_target_statuses": self.route_cleanup_target_statuses,
}
@field_serializer("allowed_session_types", mode="plain")
def serialize_allowed_session_types(self, value: list[SessionTypes]) -> list[str]:
return [item.value for item in value]

@classmethod
def from_json(cls, obj: Mapping[str, Any]) -> ScalingGroupOpts:
return cls(**cls.as_trafaret().check(obj))
@field_serializer("pending_timeout", mode="plain")
def serialize_pending_timeout(self, value: timedelta) -> float:
return value.total_seconds()

@classmethod
def as_trafaret(cls) -> t.Trafaret:
return t.Dict({
t.Key("allowed_session_types", default=["interactive", "batch"]): t.List(
tx.Enum(SessionTypes), min_length=1
),
t.Key("pending_timeout", default=0): tx.TimeDuration(allow_negative=False),
# Each scheduler impl refers an additional "config" key.
t.Key("config", default={}): t.Mapping(t.String, t.Any),
t.Key("agent_selection_strategy", default=AgentSelectionStrategy.DISPERSED): tx.Enum(
AgentSelectionStrategy
),
t.Key("agent_selector_config", default={}): agent_selector_config_iv,
t.Key("enforce_spreading_endpoint_replica", default=False): t.ToBool,
t.Key("allow_fractional_resource_fragmentation", default=True): t.ToBool,
t.Key("route_cleanup_target_statuses", default=["unhealthy"]): t.List(
t.Enum("healthy", "unhealthy", "degraded")
),
}).allow_extra("*")
@field_serializer("agent_selection_strategy", mode="plain")
def serialize_agent_selection_strategy(self, value: AgentSelectionStrategy) -> str:
return value.value


# When scheduling, we take the union of allowed scaling groups for
Expand Down Expand Up @@ -292,9 +264,9 @@ class ScalingGroupRow(Base): # type: ignore[misc]
)
scheduler_opts: Mapped[ScalingGroupOpts] = mapped_column(
"scheduler_opts",
StructuredJSONObjectColumn(ScalingGroupOpts),
PydanticColumn(ScalingGroupOpts),
nullable=False,
default={},
default=ScalingGroupOpts,
)
fair_share_spec: Mapped[FairShareScalingGroupSpec | None] = mapped_column(
"fair_share_spec",
Expand Down
194 changes: 194 additions & 0 deletions tests/unit/manager/models/scaling_group/test_row.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
"""Tests for ScalingGroupOpts Pydantic model serialization and validation."""

from __future__ import annotations

from datetime import timedelta

import pytest
from pydantic import ValidationError

from ai.backend.common.types import AgentSelectionStrategy, SessionTypes
from ai.backend.manager.models.scaling_group.row import ScalingGroupOpts


class TestScalingGroupOptsDefaults:
    """Behaviour of a ``ScalingGroupOpts`` constructed without arguments."""

    def test_create_with_no_args(self) -> None:
        """Every field falls back to its declared default value."""
        defaults = ScalingGroupOpts()
        expected_session_types = [
            SessionTypes.INTERACTIVE,
            SessionTypes.BATCH,
            SessionTypes.INFERENCE,
        ]

        # Session admission and timeout settings.
        assert defaults.allowed_session_types == expected_session_types
        assert defaults.pending_timeout == timedelta(seconds=0)
        assert defaults.config == {}

        # Agent selector settings.
        assert defaults.agent_selection_strategy == AgentSelectionStrategy.DISPERSED
        assert defaults.agent_selector_config == {}
        assert defaults.enforce_spreading_endpoint_replica is False

        # Remaining flags.
        assert defaults.allow_fractional_resource_fragmentation is True
        assert defaults.route_cleanup_target_statuses == ["unhealthy"]

    def test_frozen_model_raises_on_mutation(self) -> None:
        """Assigning to a field of an existing instance is rejected."""
        frozen_opts = ScalingGroupOpts()
        replacement = timedelta(seconds=30)
        with pytest.raises(ValidationError):
            frozen_opts.pending_timeout = replacement  # type: ignore[misc]


class TestScalingGroupOptsRoundtrip:
    """Dump with ``model_dump(mode="json")`` and reload with ``model_validate()``."""

    def test_roundtrip_defaults(self) -> None:
        """A default instance survives a dump/validate cycle unchanged."""
        source = ScalingGroupOpts()
        payload = source.model_dump(mode="json")
        reloaded = ScalingGroupOpts.model_validate(payload)
        assert reloaded == source

    def test_roundtrip_custom_values(self) -> None:
        """An instance with every field overridden survives the same cycle."""
        overrides = {
            "allowed_session_types": [SessionTypes.BATCH],
            "pending_timeout": timedelta(seconds=120),
            "config": {"key": "value"},
            "agent_selection_strategy": AgentSelectionStrategy.CONCENTRATED,
            "agent_selector_config": {"max_agents": 2},
            "enforce_spreading_endpoint_replica": True,
            "allow_fractional_resource_fragmentation": False,
            "route_cleanup_target_statuses": ["unhealthy", "degraded"],
        }
        source = ScalingGroupOpts(**overrides)
        payload = source.model_dump(mode="json")
        reloaded = ScalingGroupOpts.model_validate(payload)
        assert reloaded == source

    def test_model_dump_json_mode_types(self) -> None:
        """Verify model_dump(mode='json') produces JSON-compatible types."""
        payload = ScalingGroupOpts(
            allowed_session_types=[SessionTypes.INTERACTIVE, SessionTypes.BATCH],
            pending_timeout=timedelta(seconds=60),
            agent_selection_strategy=AgentSelectionStrategy.DISPERSED,
        ).model_dump(mode="json")

        # pending_timeout is emitted as a float number of seconds.
        timeout = payload["pending_timeout"]
        assert isinstance(timeout, float)
        assert timeout == 60.0

        # allowed_session_types is emitted as a list of plain strings.
        session_types = payload["allowed_session_types"]
        assert isinstance(session_types, list)
        assert all(isinstance(entry, str) for entry in session_types)
        assert "interactive" in session_types
        assert "batch" in session_types

        # agent_selection_strategy is emitted as the enum's string value.
        strategy = payload["agent_selection_strategy"]
        assert isinstance(strategy, str)
        assert strategy == AgentSelectionStrategy.DISPERSED.value


class TestScalingGroupOptsPendingTimeout:
    """Accepted inputs and JSON output of the ``pending_timeout`` field."""

    def test_validate_from_int(self) -> None:
        """An integer input is read as a number of seconds."""
        parsed = ScalingGroupOpts(pending_timeout=30)  # type: ignore[arg-type]
        assert parsed.pending_timeout == timedelta(seconds=30)

    def test_validate_from_float(self) -> None:
        """A float input is read as a (fractional) number of seconds."""
        parsed = ScalingGroupOpts(pending_timeout=45.5)  # type: ignore[arg-type]
        assert parsed.pending_timeout == timedelta(seconds=45.5)

    def test_validate_from_timedelta(self) -> None:
        """A ``timedelta`` input is stored as-is."""
        two_minutes = timedelta(minutes=2)
        parsed = ScalingGroupOpts(pending_timeout=two_minutes)
        assert parsed.pending_timeout == two_minutes

    def test_validate_from_zero(self) -> None:
        """Zero is accepted and yields an empty duration."""
        parsed = ScalingGroupOpts(pending_timeout=0)  # type: ignore[arg-type]
        assert parsed.pending_timeout == timedelta(seconds=0)

    def test_serialize_to_seconds(self) -> None:
        """JSON-mode dumping reports the duration as total seconds."""
        ninety_seconds = timedelta(minutes=1, seconds=30)
        payload = ScalingGroupOpts(pending_timeout=ninety_seconds).model_dump(mode="json")
        assert payload["pending_timeout"] == 90.0

    def test_validate_invalid_type_raises(self) -> None:
        """A string that is not a duration fails validation."""
        with pytest.raises(ValidationError):
            ScalingGroupOpts(pending_timeout="not-a-number")  # type: ignore[arg-type]


class TestScalingGroupOptsAllowedSessionTypes:
    """Accepted inputs and JSON output of the ``allowed_session_types`` field."""

    def test_validate_from_string_list(self) -> None:
        """String values are converted to ``SessionTypes`` members."""
        raw = {"allowed_session_types": ["interactive", "batch"]}
        parsed = ScalingGroupOpts.model_validate(raw)
        assert parsed.allowed_session_types == [SessionTypes.INTERACTIVE, SessionTypes.BATCH]

    def test_validate_from_enum_list(self) -> None:
        """Enum members passed directly are kept unchanged."""
        only_inference = [SessionTypes.INFERENCE]
        parsed = ScalingGroupOpts(allowed_session_types=only_inference)
        assert parsed.allowed_session_types == [SessionTypes.INFERENCE]

    def test_validate_invalid_session_type_raises(self) -> None:
        """An unknown session type name fails validation."""
        raw = {"allowed_session_types": ["not_a_valid_type"]}
        with pytest.raises(ValidationError):
            ScalingGroupOpts.model_validate(raw)

    def test_validate_non_list_raises(self) -> None:
        """A bare string in place of a list fails validation."""
        raw = {"allowed_session_types": "interactive"}
        with pytest.raises(ValidationError):
            ScalingGroupOpts.model_validate(raw)

    def test_serialize_to_string_list(self) -> None:
        """JSON-mode dumping emits each member's ``.value`` in order."""
        members = [SessionTypes.INTERACTIVE, SessionTypes.INFERENCE]
        payload = ScalingGroupOpts(allowed_session_types=members).model_dump(mode="json")
        assert payload["allowed_session_types"] == [
            SessionTypes.INTERACTIVE.value,
            SessionTypes.INFERENCE.value,
        ]


class TestScalingGroupOptsAgentSelectionStrategy:
    """JSON output and string input of the ``agent_selection_strategy`` field."""

    def test_serialize_to_string(self) -> None:
        """JSON-mode dumping emits the enum member's ``.value``."""
        strategy = AgentSelectionStrategy.CONCENTRATED
        payload = ScalingGroupOpts(agent_selection_strategy=strategy).model_dump(mode="json")
        assert payload["agent_selection_strategy"] == AgentSelectionStrategy.CONCENTRATED.value

    def test_validate_from_string(self) -> None:
        """The enum's string value is converted back to the member."""
        raw = {"agent_selection_strategy": AgentSelectionStrategy.CONCENTRATED.value}
        parsed = ScalingGroupOpts.model_validate(raw)
        assert parsed.agent_selection_strategy == AgentSelectionStrategy.CONCENTRATED


class TestScalingGroupOptsBackwardCompatibility:
    """Test backward compatibility: existing JSON data (without new fields) loads correctly."""

    def test_empty_dict_uses_all_defaults(self) -> None:
        """An empty mapping validates and yields the default for every field."""
        opts = ScalingGroupOpts.model_validate({})
        assert opts.allowed_session_types == [
            SessionTypes.INTERACTIVE,
            SessionTypes.BATCH,
            SessionTypes.INFERENCE,
        ]
        assert opts.pending_timeout == timedelta(seconds=0)
        assert opts.config == {}
        assert opts.agent_selection_strategy == AgentSelectionStrategy.DISPERSED
        assert opts.agent_selector_config == {}
        assert opts.enforce_spreading_endpoint_replica is False
        assert opts.allow_fractional_resource_fragmentation is True
        assert opts.route_cleanup_target_statuses == ["unhealthy"]

    def test_partial_legacy_data(self) -> None:
        """Legacy JSON with only some fields sets defaults for missing ones."""
        legacy_data = {
            "allowed_session_types": ["interactive", "batch"],
            "pending_timeout": 0,
        }
        opts = ScalingGroupOpts.model_validate(legacy_data)

        # Fields present in the legacy payload keep the supplied values.
        assert opts.allowed_session_types == [SessionTypes.INTERACTIVE, SessionTypes.BATCH]
        assert opts.pending_timeout == timedelta(seconds=0)

        # Every field absent from the payload gets its default. All six are
        # checked so that a changed default cannot slip past this test.
        assert opts.config == {}
        assert opts.agent_selection_strategy == AgentSelectionStrategy.DISPERSED
        assert opts.agent_selector_config == {}
        assert opts.enforce_spreading_endpoint_replica is False
        assert opts.allow_fractional_resource_fragmentation is True
        assert opts.route_cleanup_target_statuses == ["unhealthy"]
2 changes: 1 addition & 1 deletion tests/unit/manager/test_scaling_group_for_enqueue.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def _create_mock_sgroup(name: str, allowed_session_types: list[str]) -> MagicMoc
"""Create a mock scaling group with proper attribute access."""
mock = MagicMock()
mock.name = name
mock.scheduler_opts = ScalingGroupOpts.from_json({
mock.scheduler_opts = ScalingGroupOpts.model_validate({
"allowed_session_types": allowed_session_types,
})
return mock
Expand Down
Loading