From 101e5a0b21bb375d519571b091edd8f76b66af6b Mon Sep 17 00:00:00 2001
From: hippoley
Date: Mon, 27 Apr 2026 23:04:19 +0800
Subject: [PATCH] feat: smart model routing with priority selection and
 provider fallback

Addresses Issue #68 (problems 1, 2, 3):

1. ListAvailableModels now returns Auto + all active models so users can
   see and select specific models instead of only 'Auto'.

2. selectAutoModel uses a priority field (INT, higher = preferred) with
   random tie-breaking among the highest-priority group for simple load
   balancing. Falls back to secure models when no non-secure candidates
   remain.

3. Provider fallback on failure:
   - Non-streaming: retries with alternate models on connection error or
     5xx response (max 2 retries via callWithFallback).
   - Streaming: retries only on connection-level failure before any
     response headers are written (max 2 retries via streamWithFallback).
   - Each fallback attempt records an audit event for observability.

Schema: adds 'priority' column to llm_models (INT NOT NULL DEFAULT 0)
with idempotent ALTER TABLE guarded by duplicate-column-name check.

New errProviderConnection sentinel type distinguishes retriable
connection failures from committed-response errors in the streaming
path.

Tests: 9 new tests covering model listing, priority selection, load
distribution, exclusion logic, and error type detection.
---
 backend/internal/aigateway/service.go       | 203 +++++++++++++++--
 backend/internal/aigateway/service_test.go  | 210 ++++++++++++++++++
 backend/internal/models/llm_model.go        |   1 +
 .../repository/llm_model_repository.go      |  15 +-
 4 files changed, 407 insertions(+), 22 deletions(-)

diff --git a/backend/internal/aigateway/service.go b/backend/internal/aigateway/service.go
index bb515db..1518cf4 100644
--- a/backend/internal/aigateway/service.go
+++ b/backend/internal/aigateway/service.go
@@ -12,6 +12,7 @@ import (
 	"fmt"
 	"io"
 	"log"
+	mathrand "math/rand/v2"
 	"net/http"
 	"net/url"
 	"regexp"
@@ -119,6 +120,17 @@ const autoModelID = "auto"
 const maxStoredIdentifierLength = 100
 const anthropicVersionHeader = "2023-06-01"
 const defaultAnthropicMaxTokens = 4096
+const maxFallbackRetries = 2
+
+// errProviderConnection signals that the provider HTTP call failed before any
+// response headers were written, making the request safe to retry on a
+// different model.
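+// Callers detect it via errors.As with a *errProviderConnection target.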
+type errProviderConnection struct {
+	wrapped error
+}
+
+func (e *errProviderConnection) Error() string { return e.wrapped.Error() }
+func (e *errProviderConnection) Unwrap() error { return e.wrapped }
 
 var providerVersionSegmentPattern = regexp.MustCompile(`(?i)^v\d+(?:[a-z0-9._-]*)?$`)
 
@@ -324,15 +336,24 @@ func (s *service) ListAvailableModels() ([]AvailableModel, error) {
 		return []AvailableModel{}, nil
 	}
 
-	return []AvailableModel{
-		{
-			ID:          0,
-			DisplayName: "Auto",
-			Description: stringPtr("Automatically route requests to the best available model under current governance policy."),
-			IsSecure:    false,
-			Provider:    "gateway",
-		},
-	}, nil
+	result := make([]AvailableModel, 0, len(items)+1)
+	result = append(result, AvailableModel{
+		ID:          0,
+		DisplayName: "Auto",
+		Description: stringPtr("Automatically route requests to the best available model under current governance policy."),
+		IsSecure:    false,
+		Provider:    "gateway",
+	})
+	for _, item := range items {
+		result = append(result, AvailableModel{
+			ID:          item.ID,
+			DisplayName: item.DisplayName,
+			Description: item.Description,
+			IsSecure:    item.IsSecure,
+			Provider:    item.ProviderType,
+		})
+	}
+	return result, nil
 }
 
 func (s *service) ChatCompletions(ctx context.Context, userID int, req ChatCompletionRequest) (*ProxyResponse, string, error) {
@@ -345,6 +366,22 @@ func (s *service) ChatCompletions(ctx context.Context, userID int, req ChatCompl
 		return nil, traceID, err
 	}
 
+	resp, traceID, err := s.dispatchCall(ctx, prepared)
+
+	// Fallback: retry with alternate models when the provider call failed
+	// outright (connection error) or returned a 5xx response.
+	if err != nil || (resp != nil && resp.StatusCode >= 500) {
+		if fallbackResp, fallbackTraceID, ok := s.callWithFallback(ctx, prepared); ok {
+			return fallbackResp, fallbackTraceID, nil
+		}
+		// Fallback exhausted; fall through and surface the original
+		// response or error to the caller.
+	}
+	return resp, traceID, err
+}
+
+// dispatchCall routes a prepared request to the correct provider.
+func (s *service) dispatchCall(ctx context.Context, prepared *preparedChatRequest) (*ProxyResponse, string, error) {
 	switch models.ResolveLLMProtocolTypeOrDefault(prepared.resolvedModel.ProviderType, prepared.resolvedModel.ProtocolType) {
 	case models.ProtocolTypeOpenAI, models.ProtocolTypeOpenAICompatible:
 		return s.callOpenAICompatible(ctx, prepared)
@@ -366,6 +403,39 @@ func (s *service) ChatCompletions(ctx context.Context, userID int, req ChatCompl
 	}
 }
 
+// callWithFallback tries alternate models when the primary model fails.
+// Returns (response, traceID, true) on successful fallback, or (nil, "", false) if exhausted.
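+// Models that have already failed are tracked in excludeIDs, so no model is tried twice.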
+func (s *service) callWithFallback(ctx context.Context, prepared *preparedChatRequest) (*ProxyResponse, string, bool) {
+	excludeIDs := map[int]bool{prepared.resolvedModel.ID: true}
+
+	for attempt := 0; attempt < maxFallbackRetries; attempt++ {
+		fallbackModel, err := s.selectAutoModelExcluding(excludeIDs)
+		if err != nil {
+			break // no more candidates
+		}
+
+		_ = s.auditEventService.RecordEvent(&models.AuditEvent{
+			TraceID:      prepared.traceID,
+			SessionID:    prepared.req.SessionID,
+			RequestID:    prepared.requestIDPtr,
+			UserID:       prepared.userIDPtr,
+			InstanceID:   prepared.req.InstanceID,
+			EventType:    "gateway.request.fallback",
+			TrafficClass: models.TrafficClassLLM,
+			Severity:     models.AuditSeverityWarn,
+			Message:      fmt.Sprintf("Falling back from model %s to %s (attempt %d)", prepared.resolvedModel.DisplayName, fallbackModel.DisplayName, attempt+1),
+		})
+
+		prepared.resolvedModel = fallbackModel
+		resp, traceID, err := s.dispatchCall(ctx, prepared)
+		if err == nil && (resp == nil || resp.StatusCode < 500) {
+			return resp, traceID, true
+		}
+		excludeIDs[fallbackModel.ID] = true
+	}
+	return nil, "", false
+}
+
 func (s *service) StreamChatCompletions(ctx context.Context, userID int, req ChatCompletionRequest, w http.ResponseWriter) (string, error) {
 	prepared, err := s.prepareChatRequest(userID, req)
 	if err != nil {
@@ -376,11 +446,31 @@ func (s *service) StreamChatCompletions(ctx context.Context, userID int, req Cha
 		return traceID, err
 	}
 
+	streamErr := s.dispatchStream(ctx, prepared, w)
+	if streamErr == nil {
+		return prepared.traceID, nil
+	}
+
+	// Streaming fallback: only retry when the provider connection failed
+	// before any response headers were written to the client.
+	var connErr *errProviderConnection
+	if !errors.As(streamErr, &connErr) {
+		return prepared.traceID, streamErr
+	}
+
+	if fallbackErr := s.streamWithFallback(ctx, prepared, w); fallbackErr != nil {
+		return prepared.traceID, fallbackErr
+	}
+	return prepared.traceID, nil
+}
+
+// dispatchStream routes a streaming request to the correct provider.
+func (s *service) dispatchStream(ctx context.Context, prepared *preparedChatRequest, w http.ResponseWriter) error {
 	switch models.ResolveLLMProtocolTypeOrDefault(prepared.resolvedModel.ProviderType, prepared.resolvedModel.ProtocolType) {
 	case models.ProtocolTypeOpenAI, models.ProtocolTypeOpenAICompatible:
-		return prepared.traceID, s.streamOpenAICompatible(ctx, prepared, w)
+		return s.streamOpenAICompatible(ctx, prepared, w)
 	case models.ProtocolTypeAnthropic:
-		return prepared.traceID, s.streamAnthropic(ctx, prepared, w)
+		return s.streamAnthropic(ctx, prepared, w)
 	default:
 		_ = s.auditEventService.RecordEvent(&models.AuditEvent{
 			TraceID:      prepared.traceID,
@@ -393,10 +483,49 @@ func (s *service) StreamChatCompletions(ctx context.Context, userID int, req Cha
 			Severity:     models.AuditSeverityWarn,
 			Message:      fmt.Sprintf("Provider type %s is not supported yet", prepared.resolvedModel.ProviderType),
 		})
-		return prepared.traceID, errors.New("provider type is not supported yet")
+		return errors.New("provider type is not supported yet")
 	}
 }
 
+// streamWithFallback tries alternate models for streaming when the primary
+// provider connection failed before any response was sent to the client.
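+// Once any bytes have reached the client the stream cannot be replayed safely.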
+func (s *service) streamWithFallback(ctx context.Context, prepared *preparedChatRequest, w http.ResponseWriter) error {
+	excludeIDs := map[int]bool{prepared.resolvedModel.ID: true}
+
+	for attempt := 0; attempt < maxFallbackRetries; attempt++ {
+		fallbackModel, err := s.selectAutoModelExcluding(excludeIDs)
+		if err != nil {
+			return fmt.Errorf("no fallback models available: %w", err)
+		}
+
+		_ = s.auditEventService.RecordEvent(&models.AuditEvent{
+			TraceID:      prepared.traceID,
+			SessionID:    prepared.req.SessionID,
+			RequestID:    prepared.requestIDPtr,
+			UserID:       prepared.userIDPtr,
+			InstanceID:   prepared.req.InstanceID,
+			EventType:    "gateway.request.fallback",
+			TrafficClass: models.TrafficClassLLM,
+			Severity:     models.AuditSeverityWarn,
+			Message:      fmt.Sprintf("Stream falling back from model %s to %s (attempt %d)", prepared.resolvedModel.DisplayName, fallbackModel.DisplayName, attempt+1),
+		})
+
+		prepared.resolvedModel = fallbackModel
+		streamErr := s.dispatchStream(ctx, prepared, w)
+		if streamErr == nil {
+			return nil
+		}
+
+		var connErr *errProviderConnection
+		if !errors.As(streamErr, &connErr) {
+			// Headers were written or non-connection error; cannot retry.
+			return streamErr
+		}
+		excludeIDs[fallbackModel.ID] = true
+	}
+	return errors.New("all fallback models exhausted for streaming request")
+}
+
 func (s *service) prepareChatRequest(userID int, req ChatCompletionRequest) (*preparedChatRequest, error) {
 	if strings.TrimSpace(req.Model) == "" {
 		return nil, errors.New("model is required")
@@ -818,7 +947,7 @@ func (s *service) streamOpenAICompatible(ctx context.Context, prepared *prepared
 	response, err := s.httpClient.Do(httpRequest)
 	if err != nil {
 		s.recordFailure(prepared.traceID, prepared.requestID, prepared.req, userIDOrZero(prepared.userIDPtr), prepared.resolvedModel, startedAt, fmt.Sprintf("provider call failed: %v", err), providerRequestBody)
-		return fmt.Errorf("failed to call provider: %w", err)
+		return &errProviderConnection{wrapped: fmt.Errorf("failed to call provider: %w", err)}
 	}
 	defer response.Body.Close()
 
@@ -907,7 +1036,7 @@ func (s *service) streamAnthropic(ctx context.Context, prepared *preparedChatReq
 	response, err := s.httpClient.Do(httpRequest)
 	if err != nil {
 		s.recordFailure(prepared.traceID, prepared.requestID, prepared.req, userIDOrZero(prepared.userIDPtr), prepared.resolvedModel, startedAt, fmt.Sprintf("provider call failed: %v", err), providerRequestBody)
-		return fmt.Errorf("failed to call provider: %w", err)
+		return &errProviderConnection{wrapped: fmt.Errorf("failed to call provider: %w", err)}
 	}
 	defer response.Body.Close()
 
@@ -1907,22 +2036,56 @@ func (s *service) resolveRequestedModel(requestedModel string) (*models.LLMModel
 }
 
 func (s *service) selectAutoModel() (*models.LLMModel, error) {
+	return s.selectAutoModelExcluding(nil)
+}
+
+// selectAutoModelExcluding picks the best non-secure model by priority,
+// randomly breaking ties among the highest-priority group.
+// Models whose IDs appear in excludeIDs are skipped (used by fallback).
+func (s *service) selectAutoModelExcluding(excludeIDs map[int]bool) (*models.LLMModel, error) {
 	items, err := s.modelRepo.ListActive()
 	if err != nil {
 		return nil, fmt.Errorf("failed to list active models: %w", err)
 	}
-	if len(items) == 0 {
+
+	// Filter: non-secure and not excluded.
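+	// Secure models are reconsidered below only when this pass yields no candidates.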
+	candidates := make([]models.LLMModel, 0, len(items))
+	for _, item := range items {
+		if item.IsSecure {
+			continue
+		}
+		if excludeIDs != nil && excludeIDs[item.ID] {
+			continue
+		}
+		candidates = append(candidates, item)
+	}
+
+	if len(candidates) == 0 {
+		// Fallback: include secure models if no non-secure candidates remain.
+		for _, item := range items {
+			if excludeIDs != nil && excludeIDs[item.ID] {
+				continue
+			}
+			candidates = append(candidates, item)
+		}
+	}
+	if len(candidates) == 0 {
 		return nil, errors.New("no active models are configured")
 	}
 
-	for _, item := range items {
-		if !item.IsSecure {
-			selected := item
-			return &selected, nil
+	// Items are already sorted by -priority, -is_secure, display_name from the repo.
+	// Collect the highest-priority group.
+	highestPriority := candidates[0].Priority
+	topGroup := make([]models.LLMModel, 0, len(candidates))
+	for _, c := range candidates {
+		if c.Priority < highestPriority {
+			break
 		}
+		topGroup = append(topGroup, c)
 	}
-	selected := items[0]
+	// Random pick among the top-priority group for simple load balancing.
+	selected := topGroup[mathrand.IntN(len(topGroup))]
 	return &selected, nil
 }
 
diff --git a/backend/internal/aigateway/service_test.go b/backend/internal/aigateway/service_test.go
index 32f00b7..0a994af 100644
--- a/backend/internal/aigateway/service_test.go
+++ b/backend/internal/aigateway/service_test.go
@@ -2,6 +2,7 @@ package aigateway
 
 import (
 	"encoding/json"
+	"errors"
 	"strings"
 	"testing"
 
@@ -44,6 +45,215 @@ func (s *stubChatSessionService) EnsureSession(sessionID string, userID, instanc
 	return s.session, nil
 }
 
+// --- LLMModelRepository stub ---
+
+type stubLLMModelRepository struct {
+	activeModels []models.LLMModel
+	allModels    []models.LLMModel
+	byName       map[string]*models.LLMModel
+}
+
+func (r *stubLLMModelRepository) List() ([]models.LLMModel, error) {
+	return r.allModels, nil
+}
+
+func (r *stubLLMModelRepository) ListActive() ([]models.LLMModel, error) {
+	return r.activeModels, nil
+}
+
+func (r *stubLLMModelRepository) GetByID(id int) (*models.LLMModel, error) {
+	for i := range r.allModels {
+		if r.allModels[i].ID == id {
+			return &r.allModels[i], nil
+		}
+	}
+	return nil, nil
+}
+
+func (r *stubLLMModelRepository) GetByDisplayName(displayName string) (*models.LLMModel, error) {
+	if r.byName != nil {
+		return r.byName[displayName], nil
+	}
+	for i := range r.allModels {
+		if r.allModels[i].DisplayName == displayName {
+			return &r.allModels[i], nil
+		}
+	}
+	return nil, nil
+}
+
+func (r *stubLLMModelRepository) Save(model *models.LLMModel) error { return nil }
+func (r *stubLLMModelRepository) Delete(id int) error               { return nil }
+
+// --- Tests for ListAvailableModels ---
+
+func TestListAvailableModelsReturnsAutoAndAllActive(t *testing.T) {
+	repo := &stubLLMModelRepository{
+		activeModels: []models.LLMModel{
+			{ID: 1, DisplayName: "GPT-4o", ProviderType: "openai", IsActive: true, Priority: 10},
+			{ID: 2, DisplayName: "Claude-3", ProviderType: "anthropic", IsActive: true, Priority: 5},
+		},
+	}
+	svc := &service{modelRepo: repo}
+
+	result, err := svc.ListAvailableModels()
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if len(result) != 3 {
+		t.Fatalf("expected 3 models (Auto + 2 active), got %d", len(result))
+	}
+	if result[0].DisplayName != "Auto" {
+		t.Errorf("first model should be Auto, got %q", result[0].DisplayName)
+	}
+	if result[1].DisplayName != "GPT-4o" {
+		t.Errorf("second model should be GPT-4o, got %q", result[1].DisplayName)
+	}
+	if result[2].DisplayName != "Claude-3" {
+		t.Errorf("third model should be Claude-3, got %q", result[2].DisplayName)
+	}
+}
+
+func TestListAvailableModelsEmptyWhenNoActive(t *testing.T) {
+	repo := &stubLLMModelRepository{activeModels: []models.LLMModel{}}
+	svc := &service{modelRepo: repo}
+
+	result, err := svc.ListAvailableModels()
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if len(result) != 0 {
+		t.Fatalf("expected 0 models when none active, got %d", len(result))
+	}
+}
+
+// --- Tests for selectAutoModel ---
+
+func TestSelectAutoModelPicksHighestPriority(t *testing.T) {
+	repo := &stubLLMModelRepository{
+		activeModels: []models.LLMModel{
+			{ID: 1, DisplayName: "High", Priority: 100, IsSecure: false, IsActive: true},
+			{ID: 2, DisplayName: "Low", Priority: 10, IsSecure: false, IsActive: true},
+		},
+	}
+	svc := &service{modelRepo: repo}
+
+	selected, err := svc.selectAutoModel()
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if selected.DisplayName != "High" {
+		t.Errorf("expected highest priority model 'High', got %q", selected.DisplayName)
+	}
+}
+
+func TestSelectAutoModelSkipsSecure(t *testing.T) {
+	repo := &stubLLMModelRepository{
+		activeModels: []models.LLMModel{
+			{ID: 1, DisplayName: "Secure", Priority: 100, IsSecure: true, IsActive: true},
+			{ID: 2, DisplayName: "Normal", Priority: 50, IsSecure: false, IsActive: true},
+		},
+	}
+	svc := &service{modelRepo: repo}
+
+	selected, err := svc.selectAutoModel()
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if selected.DisplayName != "Normal" {
+		t.Errorf("expected non-secure model 'Normal', got %q", selected.DisplayName)
+	}
+}
+
+func TestSelectAutoModelFallsBackToSecureWhenNoNonSecure(t *testing.T) {
+	repo := &stubLLMModelRepository{
+		activeModels: []models.LLMModel{
+			{ID: 1, DisplayName: "SecureOnly", Priority: 10, IsSecure: true, IsActive: true},
+		},
+	}
+	svc := &service{modelRepo: repo}
+
+	selected, err := svc.selectAutoModel()
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if selected.DisplayName != "SecureOnly" {
+		t.Errorf("expected fallback to secure model, got %q", selected.DisplayName)
+	}
+}
+
+func TestSelectAutoModelDistributesAmongSamePriority(t *testing.T) {
+	repo := &stubLLMModelRepository{
+		activeModels: []models.LLMModel{
+			{ID: 1, DisplayName: "A", Priority: 100, IsSecure: false, IsActive: true},
+			{ID: 2, DisplayName: "B", Priority: 100, IsSecure: false, IsActive: true},
+			{ID: 3, DisplayName: "C", Priority: 100, IsSecure: false, IsActive: true},
+		},
+	}
+	svc := &service{modelRepo: repo}
+
+	counts := map[string]int{}
+	iterations := 300
+	for i := 0; i < iterations; i++ {
+		selected, err := svc.selectAutoModel()
+		if err != nil {
+			t.Fatalf("unexpected error: %v", err)
+		}
+		counts[selected.DisplayName]++
+	}
+
+	// Each model should be selected at least once in 300 iterations.
+	for _, name := range []string{"A", "B", "C"} {
+		if counts[name] == 0 {
+			t.Errorf("model %q was never selected in %d iterations — load balancing broken", name, iterations)
+		}
+	}
+}
+
+func TestSelectAutoModelExcludingSkipsExcluded(t *testing.T) {
+	repo := &stubLLMModelRepository{
+		activeModels: []models.LLMModel{
+			{ID: 1, DisplayName: "Primary", Priority: 100, IsSecure: false, IsActive: true},
+			{ID: 2, DisplayName: "Fallback", Priority: 50, IsSecure: false, IsActive: true},
+		},
+	}
+	svc := &service{modelRepo: repo}
+
+	selected, err := svc.selectAutoModelExcluding(map[int]bool{1: true})
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if selected.DisplayName != "Fallback" {
+		t.Errorf("expected 'Fallback' after excluding primary, got %q", selected.DisplayName)
+	}
+}
+
+func TestSelectAutoModelExcludingAllReturnsError(t *testing.T) {
+	repo := &stubLLMModelRepository{
+		activeModels: []models.LLMModel{
+			{ID: 1, DisplayName: "Only", Priority: 100, IsSecure: false, IsActive: true},
+		},
+	}
+	svc := &service{modelRepo: repo}
+
+	_, err := svc.selectAutoModelExcluding(map[int]bool{1: true})
+	if err == nil {
+		t.Fatal("expected error when all models excluded, got nil")
+	}
+}
+
+func TestErrProviderConnectionIsDetectable(t *testing.T) {
+	var err error = &errProviderConnection{wrapped: errors.New("connection refused")}
+
+	var connErr *errProviderConnection
+	if !errors.As(err, &connErr) {
+		t.Fatal("errors.As should detect errProviderConnection")
+	}
+	if connErr.Error() != "connection refused" {
+		t.Errorf("unexpected error message: %q", connErr.Error())
+	}
+}
+
 func TestBuildProviderRequestPreservesToolConfiguration(t *testing.T) {
 	req := ChatCompletionRequest{
 		Model: "gateway-model",
diff --git a/backend/internal/models/llm_model.go b/backend/internal/models/llm_model.go
index c0e986d..ba47cca 100644
--- a/backend/internal/models/llm_model.go
+++ b/backend/internal/models/llm_model.go
@@ -15,6 +15,7 @@ type LLMModel struct {
 	APIKeySecretRef *string `db:"api_key_secret_ref" json:"api_key_secret_ref,omitempty"`
 	IsSecure        bool    `db:"is_secure" json:"is_secure"`
 	IsActive        bool    `db:"is_active" json:"is_active"`
+	Priority        int     `db:"priority" json:"priority"`
 	InputPrice      float64 `db:"input_price" json:"input_price"`
 	OutputPrice     float64 `db:"output_price" json:"output_price"`
 	Currency        string  `db:"currency" json:"currency"`
diff --git a/backend/internal/repository/llm_model_repository.go b/backend/internal/repository/llm_model_repository.go
index 9a2e2f6..81d0af8 100644
--- a/backend/internal/repository/llm_model_repository.go
+++ b/backend/internal/repository/llm_model_repository.go
@@ -87,11 +87,22 @@ WHERE protocol_type IS NULL OR TRIM(protocol_type) = '';
 	if _, err := r.sess.SQL().Exec(backfillProtocolTypeQuery); err != nil {
 		panic(fmt.Errorf("failed to ensure llm_models protocol_type column: %w", err))
 	}
+
+	const addPriorityColumn = `
+ALTER TABLE llm_models
+	ADD COLUMN priority INT NOT NULL DEFAULT 0 AFTER is_active;
+`
+
+	if _, err := r.sess.SQL().Exec(addPriorityColumn); err != nil {
+		if !strings.Contains(strings.ToLower(err.Error()), "duplicate column name") {
+			panic(fmt.Errorf("failed to ensure llm_models priority column: %w", err))
+		}
+	}
 }
 
 func (r *llmModelRepository) List() ([]models.LLMModel, error) {
 	var items []models.LLMModel
-	if err := r.sess.Collection("llm_models").Find().OrderBy("-is_secure", "display_name").All(&items); err != nil {
+	if err := r.sess.Collection("llm_models").Find().OrderBy("-priority", "-is_secure", "display_name").All(&items); err != nil {
 		return nil, fmt.Errorf("failed to list llm models: %w", err)
 	}
 	return items, nil
@@ -99,7 +110,7 @@ func (r *llmModelRepository) List() ([]models.LLMModel, error) {
 
 func (r *llmModelRepository) ListActive() ([]models.LLMModel, error) {
 	var items []models.LLMModel
-	if err := r.sess.Collection("llm_models").Find(db.Cond{"is_active": true}).OrderBy("-is_secure", "display_name").All(&items); err != nil {
+	if err := r.sess.Collection("llm_models").Find(db.Cond{"is_active": true}).OrderBy("-priority", "-is_secure", "display_name").All(&items); err != nil {
 		return nil, fmt.Errorf("failed to list active llm models: %w", err)
 	}
 	return items, nil