diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index f7ebbde5..0b06e83a 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -80,6 +80,30 @@ The logic applies in the following order: 2. **Blacklist Check**: For any model *not* on the whitelist, the client checks the blacklist (`IGNORE_MODELS_`). If the model matches a blacklist pattern (supports wildcards like `*-preview`), it is excluded. 3. **Default**: If a model is on neither list, it is included. +#### Per-Model Routing Overrides (v1) + +`MODEL_ROUTING_OVERRIDES` lets operators rewrite `weighted-router/` aliases into a concrete provider-prefixed model before provider lock-in. v1 supports only strict `single` routes so retry, cooldown, and credential rotation continue to run inside one provider lane. + +In v1, `allowed_providers` must contain only the primary provider and `fallback_providers` must remain empty. + +Example: + +```bash +MODEL_ROUTING_OVERRIDES='{ + "nemotron-3-super": { + "strategy": "single", + "primary": "ollama", + "allowed_providers": ["ollama"], + "fallback_providers": [], + "strict": true, + "allow_global_fallback": false, + "reason": "Only available on Ollama Cloud" + } +}' +``` + +This rewrites `weighted-router/nemotron-3-super` to `ollama/nemotron-3-super`. Invalid override config fails at startup, and unmatched `weighted-router/*` models fail closed instead of silently falling back to another provider. + #### Request Lifecycle: A Deadline-Driven Approach The request lifecycle has been designed around a single, authoritative time budget to ensure predictable performance: @@ -1925,4 +1949,3 @@ The GUI modifies the same environment variables that the `RotatingClient` reads: 3. **Proxy applies rules** → `get_available_models()` filters based on rules **Note**: The proxy must be restarted to pick up rule changes made via the GUI (or use the Launcher TUI's reload functionality if available). - diff --git a/README.md b/README.md index a7c3c438..ef512e1e 100644 --- a/README.md +++ b/README.md @@ -477,6 +477,7 @@ The proxy includes a powerful text-based UI for configuration and management. | `ROTATION_MODE_` | `balanced` or `sequential` | `ROTATION_MODE_GEMINI=sequential` | | `IGNORE_MODELS_` | Blacklist (comma-separated, supports `*`) | `IGNORE_MODELS_OPENAI=*-preview*` | | `WHITELIST_MODELS_` | Whitelist (overrides blacklist) | `WHITELIST_MODELS_GEMINI=gemini-2.5-pro` | +| `MODEL_ROUTING_OVERRIDES` | JSON per-model routing overrides for `weighted-router/*` aliases | `{"nemotron-3-super":{"strategy":"single","primary":"ollama","allowed_providers":["ollama"],"fallback_providers":[],"strict":true,"allow_global_fallback":false}}` | ### Advanced Features @@ -491,6 +492,31 @@ The proxy includes a powerful text-based UI for configuration and management. +
+Weighted Router Per-Model Overrides (v1) + +Use `MODEL_ROUTING_OVERRIDES` to pin a `weighted-router/` alias to a single provider before credential selection begins. v1 supports only the `single` strategy and fails closed if a matching override is missing or invalid. + +In v1, `allowed_providers` must contain only the primary provider and `fallback_providers` must stay empty. + +```bash +export MODEL_ROUTING_OVERRIDES='{ + "nemotron-3-super": { + "strategy": "single", + "primary": "ollama", + "allowed_providers": ["ollama"], + "fallback_providers": [], + "strict": true, + "allow_global_fallback": false, + "reason": "Only available on Ollama Cloud" + } +}' +``` + +With that configuration, a request for `weighted-router/nemotron-3-super` is rewritten to `ollama/nemotron-3-super` before the normal retry and credential rotation flow runs. + +
+
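+The same decision logic can be exercised programmatically through `RoutingPolicy`. A minimal sketch, assuming `rotator_library` is importable (the test suite adds `src/` to `sys.path`); the model and provider names are illustrative:
+
+```python
+from rotator_library.routing_policy import RoutingPolicy
+
+# Omitted fields fall back to the v1 defaults: allowed_providers=[primary],
+# fallback_providers=[], strict=True, allow_global_fallback=False.
+policy = RoutingPolicy(
+    model_overrides={
+        "nemotron-3-super": {"strategy": "single", "primary": "ollama"}
+    },
+    available_providers={"ollama"},
+)
+
+decision = policy.resolve("weighted-router/nemotron-3-super")
+print(decision.rewritten_model)  # ollama/nemotron-3-super
+```
+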
Model Filtering (Whitelists & Blacklists) diff --git a/src/rotator_library/client.py b/src/rotator_library/client.py index fdd12d67..bef37028 100644 --- a/src/rotator_library/client.py +++ b/src/rotator_library/client.py @@ -43,6 +43,7 @@ from .credential_manager import CredentialManager from .background_refresher import BackgroundRefresher from .model_definitions import ModelDefinitions +from .routing_policy import RouteDecision, RoutingPolicy, RoutingPolicyError from .transaction_logger import TransactionLogger from .utils.paths import get_default_root, get_logs_dir, get_oauth_dir, get_data_file from .utils.suppress_litellm_warnings import suppress_litellm_serialization_warnings @@ -85,6 +86,7 @@ def __init__( enable_request_logging: bool = False, max_concurrent_requests_per_key: Optional[Dict[str, int]] = None, rotation_tolerance: float = DEFAULT_ROTATION_TOLERANCE, + model_routing_overrides: Optional[Dict[str, Any]] = None, data_dir: Optional[Union[str, Path]] = None, ): """ @@ -103,6 +105,7 @@ def __init__( whitelist_models: Models to explicitly whitelist per provider enable_request_logging: Whether to enable detailed request logging max_concurrent_requests_per_key: Max concurrent requests per key by provider + model_routing_overrides: Per-model routing overrides for weighted-router/* models. rotation_tolerance: Tolerance for weighted random credential rotation. - 0.0: Deterministic, least-used credential always selected - 2.0 - 4.0 (default, recommended): Balanced randomness, can pick credentials within 2 uses of max @@ -172,6 +175,11 @@ def __init__( for provider, paths in self.oauth_credentials.items(): all_credentials.setdefault(provider, []).extend(paths) self.all_credentials = all_credentials + self.model_routing_overrides = ( + model_routing_overrides + if model_routing_overrides is not None + else self._load_model_routing_overrides_from_env() + ) self.max_retries = max_retries self.global_timeout = global_timeout @@ -180,6 +188,7 @@ def __init__( # Initialize provider plugins early so they can be used for rotation mode detection self._provider_plugins = PROVIDER_PLUGINS self._provider_instances = {} + self.routing_policy = self._build_routing_policy() # Build provider rotation modes map # Each provider can specify its preferred rotation mode ("balanced" or "sequential") @@ -247,9 +256,7 @@ def __init__( priority_multipliers_by_mode[provider] = {} if mode not in priority_multipliers_by_mode[provider]: priority_multipliers_by_mode[provider][mode] = {} - priority_multipliers_by_mode[provider][mode][ - priority - ] = multiplier + priority_multipliers_by_mode[provider][mode][priority] = multiplier lib_logger.info( f"Provider '{provider}' priority {priority} ({mode} mode) multiplier: {multiplier}x" ) @@ -273,13 +280,9 @@ def __init__( # Log configured multipliers for provider, multipliers in priority_multipliers.items(): if multipliers: - lib_logger.info( - f"Provider '{provider}' priority multipliers: {multipliers}" - ) + lib_logger.info(f"Provider '{provider}' priority multipliers: {multipliers}") for provider, fallback in sequential_fallback_multipliers.items(): - lib_logger.info( - f"Provider '{provider}' sequential fallback multiplier: {fallback}x" - ) + lib_logger.info(f"Provider '{provider}' sequential fallback multiplier: {fallback}x") # Build fair cycle configuration fair_cycle_enabled: Dict[str, bool] = {} @@ -296,9 +299,7 @@ def __init__( env_val = os.getenv(env_key) if env_val is not None: fair_cycle_enabled[provider] = env_val.lower() in ("true", "1", "yes") - elif 
provider_class and hasattr( - provider_class, "default_fair_cycle_enabled" - ): + elif provider_class and hasattr(provider_class, "default_fair_cycle_enabled"): default_val = provider_class.default_fair_cycle_enabled if default_val is not None: fair_cycle_enabled[provider] = default_val @@ -310,12 +311,8 @@ def __init__( env_val = os.getenv(env_key) if env_val is not None and env_val.lower() in ("model_group", "credential"): fair_cycle_tracking_mode[provider] = env_val.lower() - elif provider_class and hasattr( - provider_class, "default_fair_cycle_tracking_mode" - ): - fair_cycle_tracking_mode[provider] = ( - provider_class.default_fair_cycle_tracking_mode - ) + elif provider_class and hasattr(provider_class, "default_fair_cycle_tracking_mode"): + fair_cycle_tracking_mode[provider] = provider_class.default_fair_cycle_tracking_mode # Cross-tier - check env, then provider default env_key = f"FAIR_CYCLE_CROSS_TIER_{provider.upper()}" @@ -326,9 +323,7 @@ def __init__( "1", "yes", ) - elif provider_class and hasattr( - provider_class, "default_fair_cycle_cross_tier" - ): + elif provider_class and hasattr(provider_class, "default_fair_cycle_cross_tier"): if provider_class.default_fair_cycle_cross_tier: fair_cycle_cross_tier[provider] = True @@ -339,12 +334,8 @@ def __init__( try: fair_cycle_duration[provider] = int(env_val) except ValueError: - lib_logger.warning( - f"Invalid {env_key}: {env_val}. Must be integer." - ) - elif provider_class and hasattr( - provider_class, "default_fair_cycle_duration" - ): + lib_logger.warning(f"Invalid {env_key}: {env_val}. Must be integer.") + elif provider_class and hasattr(provider_class, "default_fair_cycle_duration"): duration = provider_class.default_fair_cycle_duration if ( duration != DEFAULT_FAIR_CYCLE_DURATION @@ -374,9 +365,7 @@ def __init__( try: exhaustion_cooldown_threshold[provider] = int(env_val) except ValueError: - lib_logger.warning( - f"Invalid {env_key}: {env_val}. Must be integer." - ) + lib_logger.warning(f"Invalid {env_key}: {env_val}. 
Must be integer.") elif provider_class and hasattr( provider_class, "default_exhaustion_cooldown_threshold" ): @@ -395,9 +384,7 @@ def __init__( lib_logger.info(f"Provider '{provider}' fair cycle: disabled") for provider, mode in fair_cycle_tracking_mode.items(): if mode != "model_group": - lib_logger.info( - f"Provider '{provider}' fair cycle tracking mode: {mode}" - ) + lib_logger.info(f"Provider '{provider}' fair cycle tracking mode: {mode}") for provider, cross_tier in fair_cycle_cross_tier.items(): if cross_tier: lib_logger.info(f"Provider '{provider}' fair cycle cross-tier: enabled") @@ -426,9 +413,7 @@ def __init__( cooldown_prefix = f"CUSTOM_CAP_COOLDOWN_{provider_upper}_T" for env_key, env_value in os.environ.items(): - if env_key.startswith(cap_prefix) and not env_key.startswith( - cooldown_prefix - ): + if env_key.startswith(cap_prefix) and not env_key.startswith(cooldown_prefix): # Parse cap value remainder = env_key[len(cap_prefix) :] tier_key, model_key = self._parse_custom_cap_env_key(remainder) @@ -443,9 +428,7 @@ def __init__( custom_caps[provider][tier_key][model_key] = {} # Store max_requests value - custom_caps[provider][tier_key][model_key]["max_requests"] = ( - env_value - ) + custom_caps[provider][tier_key][model_key]["max_requests"] = env_value elif env_key.startswith(cooldown_prefix): # Parse cooldown config @@ -460,9 +443,7 @@ def __init__( try: value = int(value_str) except ValueError: - lib_logger.warning( - f"Invalid cooldown value in {env_key}: {env_value}" - ) + lib_logger.warning(f"Invalid cooldown value in {env_key}: {env_value}") continue else: mode = env_value @@ -644,9 +625,9 @@ def _is_model_whitelisted(self, provider: str, model_id: str) -> bool: for whitelisted_pattern in whitelist: # Use fnmatch for full glob pattern support - if fnmatch.fnmatch( - provider_model_name, whitelisted_pattern - ) or fnmatch.fnmatch(model_id, whitelisted_pattern): + if fnmatch.fnmatch(provider_model_name, whitelisted_pattern) or fnmatch.fnmatch( + model_id, whitelisted_pattern + ): return True return False @@ -719,13 +700,9 @@ def _litellm_logger_callback(self, log_data: dict): # For failures, extract key info to make debug logs more readable. model = log_data.get("model", "N/A") call_id = log_data.get("litellm_call_id", "N/A") - error_info = log_data.get("standard_logging_object", {}).get( - "error_information", {} - ) + error_info = log_data.get("standard_logging_object", {}).get("error_information", {}) error_class = error_info.get("error_class", "UnknownError") - error_message = error_info.get( - "error_message", str(log_data.get("exception", "")) - ) + error_message = error_info.get("error_message", str(log_data.get("exception", ""))) error_message = " ".join(error_message.split()) # Sanitize lib_logger.debug( @@ -744,9 +721,7 @@ async def close(self): if hasattr(self, "http_client") and self.http_client: await self.http_client.aclose() - def _apply_default_safety_settings( - self, litellm_kwargs: Dict[str, Any], provider: str - ): + def _apply_default_safety_settings(self, litellm_kwargs: Dict[str, Any], provider: str): """ Ensure default Gemini safety settings are present when calling the Gemini provider. This will not override any explicit settings provided by the request. 
It accepts @@ -798,10 +773,7 @@ def _apply_default_safety_settings( return # Neither present: set generic defaults so provider conversion will translate them - if ( - "safety_settings" not in litellm_kwargs - and "safetySettings" not in litellm_kwargs - ): + if "safety_settings" not in litellm_kwargs and "safetySettings" not in litellm_kwargs: litellm_kwargs["safety_settings"] = default_generic.copy() def get_oauth_credentials(self) -> Dict[str, List[str]]: @@ -848,9 +820,7 @@ def _get_provider_instance(self, provider_name: str): if provider_name not in self._provider_instances: if provider_name in self._provider_plugins: - self._provider_instances[provider_name] = self._provider_plugins[ - provider_name - ]() + self._provider_instances[provider_name] = self._provider_plugins[provider_name]() elif self._is_custom_openai_compatible_provider(provider_name): # Create a generic OpenAI-compatible provider for custom providers try: @@ -886,9 +856,7 @@ def _resolve_model_id(self, model: str, provider: str) -> str: # Check if provider has model definitions if provider_plugin and hasattr(provider_plugin, "model_definitions"): - model_id = provider_plugin.model_definitions.get_model_id( - provider, model_name - ) + model_id = provider_plugin.model_definitions.get_model_id(provider, model_name) if model_id and model_id != model_name: # Return with provider prefix return f"{provider}/{model_id}" @@ -901,6 +869,68 @@ def _resolve_model_id(self, model: str, provider: str) -> str: # No conversion needed, return original return model + def _load_model_routing_overrides_from_env(self) -> Dict[str, Any]: + raw_value = os.getenv("MODEL_ROUTING_OVERRIDES") + if not raw_value: + return {} + + try: + overrides = json.loads(raw_value) + except json.JSONDecodeError as exc: + raise RoutingPolicyError(f"Invalid JSON in MODEL_ROUTING_OVERRIDES: {exc}") from exc + + if not isinstance(overrides, dict): + raise RoutingPolicyError( + "MODEL_ROUTING_OVERRIDES must decode to an object keyed by clean model name" + ) + + return overrides + + def _build_routing_policy(self) -> Optional[RoutingPolicy]: + if not self.model_routing_overrides: + return None + + model_definitions = ModelDefinitions() + provider_models = { + provider: set(model_definitions.get_provider_models(provider).keys()) + for provider in self.all_credentials.keys() + } + + policy = RoutingPolicy( + model_overrides=self.model_routing_overrides, + available_providers=self.all_credentials.keys(), + provider_models=provider_models, + ) + lib_logger.info( + "Loaded %d model routing override(s)", + len(self.model_routing_overrides), + ) + return policy + + def _apply_routing_policy(self, model: str) -> Tuple[str, Optional[RouteDecision]]: + if not self.routing_policy: + return model, None + + decision = self.routing_policy.resolve(model) + return decision.rewritten_model or model, decision + + def _log_route_decision(self, decision: Optional[RouteDecision]) -> None: + if not decision or not decision.override_applied: + return + + lib_logger.info( + "Route decision: requested_model=%s rewritten_model=%s selected_provider=%s strategy=%s selection_source=%s strict=%s allow_global_fallback=%s candidate_providers=%s reason=%s", + decision.requested_model, + decision.rewritten_model, + decision.selected_provider, + decision.strategy, + decision.selection_source, + decision.strict, + decision.allow_global_fallback, + decision.candidate_providers, + decision.reason, + ) + async def _safe_streaming_wrapper( self, stream: Any, @@ -1000,9 +1030,7 @@ async def 
_safe_streaming_wrapper( if last_usage: # Create a dummy ModelResponse for recording (only usage matters) dummy_response = litellm.ModelResponse(usage=last_usage) - await self.usage_manager.record_success( - key, model, dummy_response - ) + await self.usage_manager.record_success(key, model, dummy_response) else: # If no usage seen (rare), record success without tokens/cost await self.usage_manager.record_success(key, model) @@ -1041,9 +1069,7 @@ async def _safe_streaming_wrapper( raw_chunk = codecs.decode(match.group(1), "unicode_escape") else: # Fallback for other potential error formats that use "Received chunk:". - chunk_from_split = ( - str(e).split("Received chunk:")[-1].strip() - ) + chunk_from_split = str(e).split("Received chunk:")[-1].strip() if chunk_from_split != str( e ): # Ensure the split actually did something @@ -1059,9 +1085,7 @@ async def _safe_streaming_wrapper( parsed_data = json.loads(json_buffer) # If parsing succeeds, we have the complete object. - lib_logger.info( - f"Successfully reassembled JSON from stream: {json_buffer}" - ) + lib_logger.info(f"Successfully reassembled JSON from stream: {json_buffer}") # Wrap the complete error object and raise it. The outer function will decide how to handle it. raise StreamedAPIError( @@ -1082,9 +1106,7 @@ async def _safe_streaming_wrapper( lib_logger.error( f"Error during stream buffering logic: {buffer_exc}. Discarding buffer." ) - json_buffer = ( - "" # Clear the corrupted buffer to prevent further issues. - ) + json_buffer = "" # Clear the corrupted buffer to prevent further issues. raise buffer_exc except StreamedAPIError: @@ -1111,9 +1133,7 @@ async def _safe_streaming_wrapper( # Only send [DONE] if the stream completed naturally and the client is still there. # This prevents sending [DONE] to a disconnected client or after an error. - if stream_completed and ( - not request or not await request.is_disconnected() - ): + if stream_completed and (not request or not await request.is_disconnected()): yield "data: [DONE]\n\n" async def _transaction_logging_stream_wrapper( @@ -1183,6 +1203,11 @@ async def _execute_with_retry( if not model: raise ValueError("'model' is a required parameter.") + model, route_decision = self._apply_routing_policy(model) + if model != kwargs.get("model"): + kwargs["model"] = model + self._log_route_decision(route_decision) + provider = model.split("/")[0] if provider not in self.all_credentials: raise ValueError( @@ -1318,16 +1343,14 @@ async def _execute_with_retry( error_accumulator.model = model error_accumulator.provider = provider - while ( - len(tried_creds) < len(credentials_for_provider) and time.time() < deadline - ): + while len(tried_creds) < len(credentials_for_provider) and time.time() < deadline: current_cred = None key_acquired = False try: # Check for a provider-wide cooldown first. 
if await self.cooldown_manager.is_cooling_down(provider): - remaining_cooldown = ( - await self.cooldown_manager.get_cooldown_remaining(provider) + remaining_cooldown = await self.cooldown_manager.get_cooldown_remaining( + provider ) remaining_budget = deadline - time.time() @@ -1343,17 +1366,13 @@ async def _execute_with_retry( ) await asyncio.sleep(remaining_cooldown) - creds_to_try = [ - c for c in credentials_for_provider if c not in tried_creds - ] + creds_to_try = [c for c in credentials_for_provider if c not in tried_creds] if not creds_to_try: break # Get count of credentials not on cooldown for this model - availability_stats = ( - await self.usage_manager.get_credential_availability_stats( - creds_to_try, model, credential_priorities - ) + availability_stats = await self.usage_manager.get_credential_availability_stats( + creds_to_try, model, credential_priorities ) available_count = availability_stats["available"] total_count = len(credentials_for_provider) @@ -1366,9 +1385,7 @@ async def _execute_with_retry( exclusion_parts.append(f"cd:{on_cooldown}") if fc_excluded > 0: exclusion_parts.append(f"fc:{fc_excluded}") - exclusion_str = ( - f",{','.join(exclusion_parts)}" if exclusion_parts else "" - ) + exclusion_str = f",{','.join(exclusion_parts)}" if exclusion_parts else "" lib_logger.info( f"Acquiring key for model {model}. Tried keys: {len(tried_creds)}/{available_count}({total_count}{exclusion_str})" @@ -1412,9 +1429,7 @@ async def _execute_with_retry( litellm_kwargs[key] = value if provider_plugin and provider_plugin.has_custom_logic(): - lib_logger.debug( - f"Provider '{provider}' has custom logic. Delegating call." - ) + lib_logger.debug(f"Provider '{provider}' has custom logic. Delegating call.") litellm_kwargs["credential_identifier"] = current_cred litellm_kwargs["transaction_context"] = ( transaction_logger.get_context() if transaction_logger else None @@ -1445,9 +1460,7 @@ async def _execute_with_retry( ) # For non-streaming, success is immediate - await self.usage_manager.record_success( - current_cred, model, response - ) + await self.usage_manager.record_success(current_cred, model, response) await self.usage_manager.release_key(current_cred, model) key_acquired = False @@ -1476,9 +1489,7 @@ async def _execute_with_retry( model=model, attempt=attempt + 1, error=e, - request_headers=dict(request.headers) - if request - else {}, + request_headers=dict(request.headers) if request else {}, ) # Record in accumulator for client reporting @@ -1519,9 +1530,7 @@ async def _execute_with_retry( model=model, attempt=attempt + 1, error=e, - request_headers=dict(request.headers) - if request - else {}, + request_headers=dict(request.headers) if request else {}, ) classified_error = classify_error(e, provider=provider) error_message = str(e).split("\n")[0] @@ -1569,9 +1578,7 @@ async def _execute_with_retry( model=model, attempt=attempt + 1, error=e, - request_headers=dict(request.headers) - if request - else {}, + request_headers=dict(request.headers) if request else {}, ) classified_error = classify_error(e, provider=provider) error_message = str(e).split("\n")[0] @@ -1623,20 +1630,14 @@ async def _execute_with_retry( if provider_instance: # Ensure default Gemini safety settings are present (without overriding request) try: - self._apply_default_safety_settings( - litellm_kwargs, provider - ) + self._apply_default_safety_settings(litellm_kwargs, provider) except Exception: # If anything goes wrong here, avoid breaking the request flow. 
- lib_logger.debug( - "Could not apply default safety settings; continuing." - ) + lib_logger.debug("Could not apply default safety settings; continuing.") if "safety_settings" in litellm_kwargs: - converted_settings = ( - provider_instance.convert_safety_settings( - litellm_kwargs["safety_settings"] - ) + converted_settings = provider_instance.convert_safety_settings( + litellm_kwargs["safety_settings"] ) if converted_settings is not None: litellm_kwargs["safety_settings"] = converted_settings @@ -1644,13 +1645,9 @@ async def _execute_with_retry( del litellm_kwargs["safety_settings"] if provider == "gemini" and provider_instance: - provider_instance.handle_thinking_parameter( - litellm_kwargs, model - ) + provider_instance.handle_thinking_parameter(litellm_kwargs, model) if provider == "nvidia_nim" and provider_instance: - provider_instance.handle_thinking_parameter( - litellm_kwargs, model - ) + provider_instance.handle_thinking_parameter(litellm_kwargs, model) if "gemma-3" in model and "messages" in litellm_kwargs: litellm_kwargs["messages"] = [ @@ -1691,9 +1688,7 @@ async def _execute_with_retry( logger_fn=self._litellm_logger_callback, ) - await self.usage_manager.record_success( - current_cred, model, response - ) + await self.usage_manager.record_success(current_cred, model, response) await self.usage_manager.release_key(current_cred, model) key_acquired = False @@ -1716,9 +1711,7 @@ async def _execute_with_retry( model=model, attempt=attempt + 1, error=e, - request_headers=dict(request.headers) - if request - else {}, + request_headers=dict(request.headers) if request else {}, ) classified_error = classify_error(e, provider=provider) @@ -1760,9 +1753,7 @@ async def _execute_with_retry( model=model, attempt=attempt + 1, error=e, - request_headers=dict(request.headers) - if request - else {}, + request_headers=dict(request.headers) if request else {}, ) classified_error = classify_error(e, provider=provider) error_message = str(e).split("\n")[0] @@ -1815,9 +1806,7 @@ async def _execute_with_retry( model=model, attempt=attempt + 1, error=e, - request_headers=dict(request.headers) - if request - else {}, + request_headers=dict(request.headers) if request else {}, ) classified_error = classify_error(e, provider=provider) @@ -1878,9 +1867,7 @@ async def _execute_with_retry( model=model, attempt=attempt + 1, error=e, - request_headers=dict(request.headers) - if request - else {}, + request_headers=dict(request.headers) if request else {}, ) if request and await request.is_disconnected(): @@ -1952,6 +1939,14 @@ async def _streaming_acompletion_with_retry( ) -> AsyncGenerator[str, None]: """A dedicated generator for retrying streaming completions with full request preparation and per-key retries.""" model = kwargs.get("model") + if not model: + raise ValueError("'model' is a required parameter.") + + model, route_decision = self._apply_routing_policy(model) + if model != kwargs.get("model"): + kwargs["model"] = model + self._log_route_decision(route_decision) + provider = model.split("/")[0] # Extract internal logging parameters (not passed to API) @@ -2081,16 +2076,13 @@ async def _streaming_acompletion_with_retry( error_accumulator.provider = provider try: - while ( - len(tried_creds) < len(credentials_for_provider) - and time.time() < deadline - ): + while len(tried_creds) < len(credentials_for_provider) and time.time() < deadline: current_cred = None key_acquired = False try: if await self.cooldown_manager.is_cooling_down(provider): - remaining_cooldown = ( - await 
self.cooldown_manager.get_cooldown_remaining(provider) + remaining_cooldown = await self.cooldown_manager.get_cooldown_remaining( + provider ) remaining_budget = deadline - time.time() if remaining_cooldown > remaining_budget: @@ -2103,9 +2095,7 @@ async def _streaming_acompletion_with_retry( ) await asyncio.sleep(remaining_cooldown) - creds_to_try = [ - c for c in credentials_for_provider if c not in tried_creds - ] + creds_to_try = [c for c in credentials_for_provider if c not in tried_creds] if not creds_to_try: lib_logger.warning( f"All credentials for provider {provider} have been tried. No more credentials to rotate to." @@ -2113,10 +2103,8 @@ async def _streaming_acompletion_with_retry( break # Get count of credentials not on cooldown for this model - availability_stats = ( - await self.usage_manager.get_credential_availability_stats( - creds_to_try, model, credential_priorities - ) + availability_stats = await self.usage_manager.get_credential_availability_stats( + creds_to_try, model, credential_priorities ) available_count = availability_stats["available"] total_count = len(credentials_for_provider) @@ -2129,16 +2117,12 @@ async def _streaming_acompletion_with_retry( exclusion_parts.append(f"cd:{on_cooldown}") if fc_excluded > 0: exclusion_parts.append(f"fc:{fc_excluded}") - exclusion_str = ( - f",{','.join(exclusion_parts)}" if exclusion_parts else "" - ) + exclusion_str = f",{','.join(exclusion_parts)}" if exclusion_parts else "" lib_logger.info( f"Acquiring credential for model {model}. Tried credentials: {len(tried_creds)}/{available_count}({total_count}{exclusion_str})" ) - max_concurrent = self.max_concurrent_requests_per_key.get( - provider, 1 - ) + max_concurrent = self.max_concurrent_requests_per_key.get(provider, 1) current_cred = await self.usage_manager.acquire_key( available_keys=creds_to_try, model=model, @@ -2168,9 +2152,7 @@ async def _streaming_acompletion_with_retry( # No further resolution needed here. 
# Apply model-specific options for custom providers - if provider_plugin and hasattr( - provider_plugin, "get_model_options" - ): + if provider_plugin and hasattr(provider_plugin, "get_model_options"): model_options = provider_plugin.get_model_options(model) if model_options: # Merge model options into litellm_kwargs @@ -2185,9 +2167,7 @@ async def _streaming_acompletion_with_retry( ) litellm_kwargs["credential_identifier"] = current_cred litellm_kwargs["transaction_context"] = ( - transaction_logger.get_context() - if transaction_logger - else None + transaction_logger.get_context() if transaction_logger else None ) for attempt in range(self.max_retries): @@ -2198,9 +2178,7 @@ async def _streaming_acompletion_with_retry( if pre_request_callback: try: - await pre_request_callback( - request, litellm_kwargs - ) + await pre_request_callback(request, litellm_kwargs) except Exception as e: if self.abort_on_callback_error: raise PreRequestCallbackError( @@ -2229,10 +2207,8 @@ async def _streaming_acompletion_with_retry( ) # Wrap with transaction logging - logged_stream = ( - self._transaction_logging_stream_wrapper( - stream_generator, transaction_logger, kwargs - ) + logged_stream = self._transaction_logging_stream_wrapper( + stream_generator, transaction_logger, kwargs ) async for chunk in logged_stream: @@ -2247,9 +2223,7 @@ async def _streaming_acompletion_with_retry( last_exception = e # If the exception is our custom wrapper, unwrap the original error original_exc = getattr(e, "data", e) - classified_error = classify_error( - original_exc, provider=provider - ) + classified_error = classify_error(original_exc, provider=provider) error_message = str(original_exc).split("\n")[0] log_failure( @@ -2257,9 +2231,7 @@ async def _streaming_acompletion_with_retry( model=model, attempt=attempt + 1, error=e, - request_headers=dict(request.headers) - if request - else {}, + request_headers=dict(request.headers) if request else {}, ) # Record in accumulator for client reporting @@ -2276,9 +2248,7 @@ async def _streaming_acompletion_with_retry( # Handle rate limits with cooldown (exclude quota_exceeded) if classified_error.error_type == "rate_limit": - cooldown_duration = ( - classified_error.retry_after or 60 - ) + cooldown_duration = classified_error.retry_after or 60 await self.cooldown_manager.start_cooldown( provider, cooldown_duration ) @@ -2302,9 +2272,7 @@ async def _streaming_acompletion_with_retry( model=model, attempt=attempt + 1, error=e, - request_headers=dict(request.headers) - if request - else {}, + request_headers=dict(request.headers) if request else {}, ) classified_error = classify_error(e, provider=provider) error_message = str(e).split("\n")[0] @@ -2352,9 +2320,7 @@ async def _streaming_acompletion_with_retry( model=model, attempt=attempt + 1, error=e, - request_headers=dict(request.headers) - if request - else {}, + request_headers=dict(request.headers) if request else {}, ) classified_error = classify_error(e, provider=provider) error_message = str(e).split("\n")[0] @@ -2396,19 +2362,15 @@ async def _streaming_acompletion_with_retry( if provider_instance: # Ensure default Gemini safety settings are present (without overriding request) try: - self._apply_default_safety_settings( - litellm_kwargs, provider - ) + self._apply_default_safety_settings(litellm_kwargs, provider) except Exception: lib_logger.debug( "Could not apply default safety settings for streaming path; continuing." 
) if "safety_settings" in litellm_kwargs: - converted_settings = ( - provider_instance.convert_safety_settings( - litellm_kwargs["safety_settings"] - ) + converted_settings = provider_instance.convert_safety_settings( + litellm_kwargs["safety_settings"] ) if converted_settings is not None: litellm_kwargs["safety_settings"] = converted_settings @@ -2416,13 +2378,9 @@ async def _streaming_acompletion_with_retry( del litellm_kwargs["safety_settings"] if provider == "gemini" and provider_instance: - provider_instance.handle_thinking_parameter( - litellm_kwargs, model - ) + provider_instance.handle_thinking_parameter(litellm_kwargs, model) if provider == "nvidia_nim" and provider_instance: - provider_instance.handle_thinking_parameter( - litellm_kwargs, model - ) + provider_instance.handle_thinking_parameter(litellm_kwargs, model) if "gemma-3" in model and "messages" in litellm_kwargs: litellm_kwargs["messages"] = [ @@ -2504,9 +2462,7 @@ async def _streaming_acompletion_with_retry( cleaned_str = None # The actual exception might be wrapped in our StreamedAPIError. original_exc = getattr(e, "data", e) - classified_error = classify_error( - original_exc, provider=provider - ) + classified_error = classify_error(original_exc, provider=provider) # Check if this error should trigger rotation if not should_rotate_on_error(classified_error): @@ -2533,9 +2489,7 @@ async def _streaming_acompletion_with_retry( model=model, attempt=attempt + 1, error=e, - request_headers=dict(request.headers) - if request - else {}, + request_headers=dict(request.headers) if request else {}, raw_response_text=cleaned_str, ) @@ -2565,15 +2519,10 @@ async def _streaming_acompletion_with_retry( if isinstance(detail.get("violations"), list): for violation in detail["violations"]: if "quotaValue" in violation: - quota_value = violation[ - "quotaValue" - ] + quota_value = violation["quotaValue"] if "quotaId" in violation: quota_id = violation["quotaId"] - if ( - quota_value != "N/A" - and quota_id != "N/A" - ): + if quota_value != "N/A" and quota_id != "N/A": break await self.usage_manager.record_failure( @@ -2605,9 +2554,7 @@ async def _streaming_acompletion_with_retry( ) if classified_error.error_type == "rate_limit": - cooldown_duration = ( - classified_error.retry_after or 60 - ) + cooldown_duration = classified_error.retry_after or 60 await self.cooldown_manager.start_cooldown( provider, cooldown_duration ) @@ -2629,9 +2576,7 @@ async def _streaming_acompletion_with_retry( model=model, attempt=attempt + 1, error=e, - request_headers=dict(request.headers) - if request - else {}, + request_headers=dict(request.headers) if request else {}, ) classified_error = classify_error(e, provider=provider) error_message_text = str(e).split("\n")[0] @@ -2680,9 +2625,7 @@ async def _streaming_acompletion_with_retry( model=model, attempt=attempt + 1, error=e, - request_headers=dict(request.headers) - if request - else {}, + request_headers=dict(request.headers) if request else {}, ) classified_error = classify_error(e, provider=provider) error_message_text = str(e).split("\n")[0] @@ -2746,12 +2689,8 @@ async def _streaming_acompletion_with_retry( "Request failed: No available API keys after rotation or timeout." ) if last_exception: - final_error_message = ( - f"Request failed. Last error: {str(last_exception)}" - ) - error_data = { - "error": {"message": final_error_message, "type": "proxy_error"} - } + final_error_message = f"Request failed. 
Last error: {str(last_exception)}" + error_data = {"error": {"message": final_error_message, "type": "proxy_error"}} lib_logger.error(final_error_message) yield f"data: {json.dumps(error_data)}\n\n" @@ -2798,14 +2737,19 @@ def acompletion( Returns: The completion response object, or an async generator for streaming responses, or None if all retries fail. """ - # Handle iflow provider: remove stream_options to avoid HTTP 406 model = kwargs.get("model", "") + if model: + routed_model, route_decision = self._apply_routing_policy(model) + if routed_model != model: + kwargs["model"] = routed_model + model = routed_model + self._log_route_decision(route_decision) + + # Handle iflow provider: remove stream_options to avoid HTTP 406 provider = model.split("/")[0] if "/" in model else "" if provider == "iflow" and "stream_options" in kwargs: - lib_logger.debug( - "Removing stream_options for iflow provider to avoid HTTP 406" - ) + lib_logger.debug("Removing stream_options for iflow provider to avoid HTTP 406") kwargs.pop("stream_options", None) if kwargs.get("stream"): @@ -2924,12 +2868,8 @@ async def get_available_models(self, provider: str) -> List[str]: lib_logger.debug( f"Attempting to get models for {provider} with credential {cred_display}" ) - models = await provider_instance.get_models( - credential, self.http_client - ) - lib_logger.info( - f"Got {len(models)} models for provider: {provider}" - ) + models = await provider_instance.get_models(credential, self.http_client) + lib_logger.info(f"Got {len(models)} models for provider: {provider}") # Whitelist and blacklist logic final_models = [] @@ -2977,9 +2917,7 @@ async def get_all_available_models( all_provider_models = {} for provider, result in zip(all_providers, results): if isinstance(result, Exception): - lib_logger.error( - f"Failed to get models for provider {provider}: {result}" - ) + lib_logger.error(f"Failed to get models for provider {provider}: {result}") all_provider_models[provider] = [] else: all_provider_models[provider] = result @@ -3051,9 +2989,7 @@ async def get_quota_stats( # Track tier - get directly from provider cache since cred["tier"] not set yet tier = cred.get("tier") - if not tier and hasattr( - provider_instance, "project_tier_cache" - ): + if not tier and hasattr(provider_instance, "project_tier_cache"): cred_path = cred.get("full_path", "") tier = provider_instance.project_tier_cache.get(cred_path) tier = tier or "unknown" @@ -3062,9 +2998,7 @@ async def get_quota_stats( if tier not in group_stats["tiers"]: priority = 10 # default if hasattr(provider_instance, "_resolve_tier_priority"): - priority = provider_instance._resolve_tier_priority( - tier - ) + priority = provider_instance._resolve_tier_priority(tier) group_stats["tiers"][tier] = { "total": 0, "active": 0, @@ -3098,9 +3032,7 @@ async def get_quota_stats( if baseline is not None: remaining_pct = int(baseline * 100) - group_stats["total_remaining_pcts"].append( - remaining_pct - ) + group_stats["total_remaining_pcts"].append(remaining_pct) if baseline <= 0: group_stats["credentials_exhausted"] += 1 else: @@ -3120,9 +3052,7 @@ async def get_quota_stats( used = group_stats["total_requests_used"] max_r = group_stats["total_requests_max"] group_stats["total_requests_remaining"] = max_r - used - group_stats["total_remaining_pct"] = max( - 0, int((1 - used / max_r) * 100) - ) + group_stats["total_remaining_pct"] = max(0, int((1 - used / max_r) * 100)) else: group_stats["total_requests_remaining"] = 0 # Fallback to avg_remaining_pct when max_requests 
unavailable @@ -3150,10 +3080,7 @@ async def get_quota_stats( # Track the best (latest) reset_ts from any model in group candidate_reset_ts = candidate.get("quota_reset_ts") if candidate_reset_ts: - if ( - best_reset_ts is None - or candidate_reset_ts > best_reset_ts - ): + if best_reset_ts is None or candidate_reset_ts > best_reset_ts: best_reset_ts = candidate_reset_ts baseline = candidate.get("baseline_remaining_fraction") @@ -3169,13 +3096,9 @@ async def get_quota_stats( max_req = model_stats.get("quota_max_requests") req_count = model_stats.get("request_count", 0) # Use best_reset_ts from any model in the group - reset_ts = best_reset_ts or model_stats.get( - "quota_reset_ts" - ) + reset_ts = best_reset_ts or model_stats.get("quota_reset_ts") - remaining_pct = ( - int(baseline * 100) if baseline is not None else None - ) + remaining_pct = int(baseline * 100) if baseline is not None else None is_exhausted = baseline is not None and baseline <= 0 # Format reset time @@ -3190,9 +3113,7 @@ async def get_quota_stats( except (ValueError, OSError): pass - requests_remaining = ( - max(0, max_req - req_count) if max_req else 0 - ) + requests_remaining = max(0, max_req - req_count) if max_req else 0 # Determine display format # Priority: requests (if max known) > percentage (if baseline available) > unknown @@ -3212,17 +3133,14 @@ async def get_quota_stats( "is_exhausted": is_exhausted, "reset_time_iso": reset_iso, "models": group_models, - "confidence": self._get_baseline_confidence( - model_stats - ), + "confidence": self._get_baseline_confidence(model_stats), } # Recalculate credential's requests from model_groups # This fixes double-counting when models share quota groups if cred.get("model_groups"): group_requests = sum( - g.get("requests_used", 0) - for g in cred["model_groups"].values() + g.get("requests_used", 0) for g in cred["model_groups"].values() ) cred["requests"] = group_requests @@ -3279,9 +3197,7 @@ def _find_model_stats_in_data( api_model = provider_instance._user_to_api_model(model) if api_model != model: prefixed_api = f"{provider}/{api_model}" - model_stats = models_data.get(prefixed_api) or models_data.get( - api_model - ) + model_stats = models_data.get(prefixed_api) or models_data.get(api_model) return model_stats @@ -3334,9 +3250,7 @@ async def force_refresh_quota( """ result = { "action": "force_refresh", - "scope": "credential" - if credential - else ("provider" if provider else "all"), + "scope": "credential" if credential else ("provider" if provider else "all"), "provider": provider, "credential": credential, "credentials_refreshed": 0, @@ -3350,9 +3264,7 @@ async def force_refresh_quota( # Determine which providers to refresh if provider: - providers_to_refresh = ( - [provider] if provider in self.all_credentials else [] - ) + providers_to_refresh = [provider] if provider in self.all_credentials else [] else: providers_to_refresh = list(self.all_credentials.keys()) @@ -3390,10 +3302,8 @@ async def force_refresh_quota( # Store baselines in usage manager if hasattr(provider_instance, "_store_baselines_to_usage_manager"): - stored = ( - await provider_instance._store_baselines_to_usage_manager( - quota_results, self.usage_manager - ) + stored = await provider_instance._store_baselines_to_usage_manager( + quota_results, self.usage_manager ) result["success_count"] += stored @@ -3506,13 +3416,9 @@ async def anthropic_messages( # Convert OpenAI response to Anthropic format openai_response = ( - response.model_dump() - if hasattr(response, "model_dump") - else 
dict(response) - ) - anthropic_response = openai_to_anthropic_response( - openai_response, original_model + response.model_dump() if hasattr(response, "model_dump") else dict(response) ) + anthropic_response = openai_to_anthropic_response(openai_response, original_model) # Override the ID with our request ID anthropic_response["id"] = request_id @@ -3565,9 +3471,7 @@ async def anthropic_count_tokens( if request.tools: # Tools add tokens based on their definitions # Convert to JSON string and count tokens for tool definitions - openai_tools = anthropic_to_openai_tools( - [tool.model_dump() for tool in request.tools] - ) + openai_tools = anthropic_to_openai_tools([tool.model_dump() for tool in request.tools]) if openai_tools: # Serialize tools to count their token contribution tools_text = json.dumps(openai_tools) diff --git a/src/rotator_library/routing_policy.py b/src/rotator_library/routing_policy.py new file mode 100644 index 00000000..b334f75f --- /dev/null +++ b/src/rotator_library/routing_policy.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, Iterable, Optional, Set + + +class RoutingPolicyError(ValueError): + """Raised when routing override configuration or resolution is invalid.""" + + +@dataclass(frozen=True) +class RouteDecision: + requested_model: str + clean_model: str + selected_provider: Optional[str] + rewritten_model: Optional[str] + strategy: str + selection_source: str + override_applied: bool + candidate_providers: list[str] + strict: bool + allow_global_fallback: bool + reason: Optional[str] = None + + +class RoutingPolicy: + """Resolve weighted-router models into concrete provider-prefixed models. + + v1 intentionally supports only strict single-provider overrides. It rewrites + abstract `weighted-router/` requests before provider lock-in so the + existing retry and credential machinery can continue unchanged. 
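+
+    A minimal usage sketch (model and provider names are illustrative):
+
+        policy = RoutingPolicy(
+            model_overrides={"my-model": {"strategy": "single", "primary": "ollama"}},
+            available_providers={"ollama"},
+        )
+        policy.resolve("weighted-router/my-model").rewritten_model
+        # -> "ollama/my-model"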
+ """ + + def __init__( + self, + model_overrides: Dict[str, Any], + available_providers: Iterable[str], + provider_models: Optional[Dict[str, Set[str]]] = None, + ) -> None: + if not isinstance(model_overrides, dict): + raise RoutingPolicyError("MODEL_ROUTING_OVERRIDES must decode to an object") + + self.model_overrides = model_overrides + self.available_providers = set(available_providers) + self.provider_models = provider_models or {} + self._validate() + + def _validate(self) -> None: + for clean_model, override in self.model_overrides.items(): + if not isinstance(clean_model, str) or not clean_model: + raise RoutingPolicyError("routing override keys must be non-empty model names") + if not isinstance(override, dict): + raise RoutingPolicyError(f"routing override for '{clean_model}' must be an object") + + strategy = override.get("strategy") + if strategy != "single": + raise RoutingPolicyError( + f"routing override for '{clean_model}' must use strategy 'single' in v1" + ) + + primary = override.get("primary") + if not isinstance(primary, str) or not primary: + raise RoutingPolicyError( + f"routing override for '{clean_model}' requires a non-empty 'primary' provider" + ) + if primary not in self.available_providers: + raise RoutingPolicyError( + f"routing override for '{clean_model}' references unknown provider '{primary}'" + ) + + allowed_providers = override.get("allowed_providers", [primary]) + if not isinstance(allowed_providers, list) or not all( + isinstance(provider, str) and provider for provider in allowed_providers + ): + raise RoutingPolicyError( + f"routing override for '{clean_model}' must use a string list for 'allowed_providers'" + ) + if allowed_providers != [primary]: + raise RoutingPolicyError( + f"routing override for '{clean_model}' must restrict 'allowed_providers' to ['{primary}'] in v1" + ) + + fallback_providers = override.get("fallback_providers", []) + if fallback_providers not in (None, []): + raise RoutingPolicyError( + f"routing override for '{clean_model}' cannot define 'fallback_providers' in v1" + ) + + provider_models = self.provider_models.get(primary) + if provider_models and clean_model not in provider_models: + raise RoutingPolicyError( + f"provider '{primary}' does not expose model '{clean_model}' in configured model definitions" + ) + + def resolve(self, model: str) -> RouteDecision: + if "/" not in model: + return RouteDecision( + requested_model=model, + clean_model=model, + selected_provider=None, + rewritten_model=model, + strategy="passthrough", + selection_source="passthrough", + override_applied=False, + candidate_providers=[], + strict=False, + allow_global_fallback=True, + ) + + provider, clean_model = model.split("/", 1) + if provider != "weighted-router": + return RouteDecision( + requested_model=model, + clean_model=clean_model, + selected_provider=provider, + rewritten_model=model, + strategy="passthrough", + selection_source="passthrough", + override_applied=False, + candidate_providers=[provider], + strict=False, + allow_global_fallback=True, + ) + + override = self.model_overrides.get(clean_model) + if override is None: + raise RoutingPolicyError( + f"No routing override configured for weighted-router model '{clean_model}'" + ) + + selected_provider = override["primary"] + return RouteDecision( + requested_model=model, + clean_model=clean_model, + selected_provider=selected_provider, + rewritten_model=f"{selected_provider}/{clean_model}", + strategy="single", + selection_source="model_override", + override_applied=True, + 
candidate_providers=[selected_provider], + strict=bool(override.get("strict", True)), + allow_global_fallback=bool(override.get("allow_global_fallback", False)), + reason=override.get("reason"), + ) diff --git a/tests/test_client_routing_policy.py b/tests/test_client_routing_policy.py new file mode 100644 index 00000000..df898038 --- /dev/null +++ b/tests/test_client_routing_policy.py @@ -0,0 +1,98 @@ +import sys +import asyncio +from pathlib import Path + +import pytest + + +ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(ROOT / "src")) + +from rotator_library.client import RotatingClient +from rotator_library.routing_policy import RoutingPolicy, RoutingPolicyError + + +def test_client_helper_rewrites_weighted_router_model(): + client = RotatingClient.__new__(RotatingClient) + client.routing_policy = RoutingPolicy( + model_overrides={ + "nemotron-3-super": { + "strategy": "single", + "primary": "ollama", + "allowed_providers": ["ollama"], + "fallback_providers": [], + "strict": True, + "allow_global_fallback": False, + } + }, + available_providers={"ollama"}, + provider_models={"ollama": {"nemotron-3-super"}}, + ) + + model, decision = client._apply_routing_policy("weighted-router/nemotron-3-super") + + assert model == "ollama/nemotron-3-super" + assert decision is not None + assert decision.override_applied is True + + +def test_client_helper_passthrough_without_routing_policy(): + client = RotatingClient.__new__(RotatingClient) + client.routing_policy = None + + model, decision = client._apply_routing_policy("ollama/nemotron-3-super") + + assert model == "ollama/nemotron-3-super" + assert decision is None + + +def test_load_model_routing_overrides_from_env(monkeypatch): + client = RotatingClient.__new__(RotatingClient) + monkeypatch.setenv( + "MODEL_ROUTING_OVERRIDES", + '{"nemotron-3-super":{"strategy":"single","primary":"ollama","allowed_providers":["ollama"],"fallback_providers":[]}}', + ) + + overrides = client._load_model_routing_overrides_from_env() + + assert overrides["nemotron-3-super"]["primary"] == "ollama" + + +def test_invalid_model_routing_overrides_env_fails_closed(monkeypatch): + client = RotatingClient.__new__(RotatingClient) + monkeypatch.setenv("MODEL_ROUTING_OVERRIDES", "{invalid") + + with pytest.raises(RoutingPolicyError, match="Invalid JSON"): + client._load_model_routing_overrides_from_env() + + +def test_acompletion_rewrites_model_before_dispatch(monkeypatch): + client = RotatingClient.__new__(RotatingClient) + client.routing_policy = RoutingPolicy( + model_overrides={ + "nemotron-3-super": { + "strategy": "single", + "primary": "ollama", + "allowed_providers": ["ollama"], + "fallback_providers": [], + "strict": True, + "allow_global_fallback": False, + } + }, + available_providers={"ollama"}, + provider_models={"ollama": {"nemotron-3-super"}}, + ) + + captured = {} + + async def fake_execute_with_retry(api_call, request=None, pre_request_callback=None, **kwargs): + captured.update(kwargs) + return {"ok": True} + + monkeypatch.setattr(client, "_execute_with_retry", fake_execute_with_retry) + monkeypatch.setattr(client, "_log_route_decision", lambda decision: None) + + result = asyncio.run(client.acompletion(model="weighted-router/nemotron-3-super", stream=False)) + + assert result == {"ok": True} + assert captured["model"] == "ollama/nemotron-3-super" diff --git a/tests/test_routing_policy.py b/tests/test_routing_policy.py new file mode 100644 index 00000000..392b093c --- /dev/null +++ b/tests/test_routing_policy.py @@ -0,0 +1,118 @@ +import 
sys +from pathlib import Path + +import pytest + + +ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(ROOT / "src")) + +from rotator_library.routing_policy import RoutingPolicy, RoutingPolicyError + + +def make_policy(overrides=None, provider_models=None, providers=None): + return RoutingPolicy( + model_overrides=overrides + or { + "nemotron-3-super": { + "strategy": "single", + "primary": "ollama", + "allowed_providers": ["ollama"], + "fallback_providers": [], + "strict": True, + "allow_global_fallback": False, + "reason": "Only available on Ollama Cloud", + } + }, + available_providers=providers or {"ollama", "chutes"}, + provider_models=provider_models + or { + "ollama": {"nemotron-3-super", "qwen3.5"}, + "chutes": {"qwen3.5"}, + }, + ) + + +def test_single_override_rewrites_weighted_router_model(): + decision = make_policy().resolve("weighted-router/nemotron-3-super") + + assert decision.selected_provider == "ollama" + assert decision.rewritten_model == "ollama/nemotron-3-super" + assert decision.selection_source == "model_override" + assert decision.override_applied is True + + +def test_non_weighted_router_model_passes_through(): + decision = make_policy().resolve("ollama/nemotron-3-super") + + assert decision.rewritten_model == "ollama/nemotron-3-super" + assert decision.override_applied is False + assert decision.selection_source == "passthrough" + + +def test_missing_override_for_weighted_router_model_fails_closed(): + with pytest.raises(RoutingPolicyError, match="No routing override configured"): + make_policy().resolve("weighted-router/qwen3.5") + + +def test_unknown_provider_fails_validation(): + with pytest.raises(RoutingPolicyError, match="unknown provider 'go'"): + make_policy( + overrides={ + "nemotron-3-super": { + "strategy": "single", + "primary": "go", + "allowed_providers": ["go"], + "fallback_providers": [], + } + } + ) + + +@pytest.mark.parametrize( + "override, expected_error", + [ + ( + {"nemotron-3-super": {"primary": "ollama", "allowed_providers": ["ollama"]}}, + "strategy 'single'", + ), + ( + { + "nemotron-3-super": { + "strategy": "single", + "primary": "ollama", + "allowed_providers": ["ollama", "chutes"], + } + }, + "restrict 'allowed_providers'", + ), + ( + { + "nemotron-3-super": { + "strategy": "single", + "primary": "ollama", + "allowed_providers": ["ollama"], + "fallback_providers": ["chutes"], + } + }, + "cannot define 'fallback_providers'", + ), + ], +) +def test_invalid_single_override_shapes_fail_validation(override, expected_error): + with pytest.raises(RoutingPolicyError, match=expected_error): + make_policy(overrides=override) + + +def test_provider_model_mismatch_fails_validation_when_models_are_known(): + with pytest.raises(RoutingPolicyError, match="does not expose model 'nemotron-3-super'"): + make_policy( + overrides={ + "nemotron-3-super": { + "strategy": "single", + "primary": "chutes", + "allowed_providers": ["chutes"], + "fallback_providers": [], + } + } + )
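+
+
+# A sketch of one more case, assuming the make_policy helper above:
+# flags omitted from an override should fall back to strict routing
+# with global fallback disabled.
+def test_omitted_flags_default_to_strict_without_global_fallback():
+    decision = make_policy(
+        overrides={
+            "nemotron-3-super": {
+                "strategy": "single",
+                "primary": "ollama",
+            }
+        }
+    ).resolve("weighted-router/nemotron-3-super")
+
+    assert decision.strict is True
+    assert decision.allow_global_fallback is False
+    assert decision.reason is None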