203 changes: 183 additions & 20 deletions backend/internal/aigateway/service.go
@@ -12,6 +12,7 @@ import (
"fmt"
"io"
"log"
mathrand "math/rand/v2"
"net/http"
"net/url"
"regexp"
Expand Down Expand Up @@ -119,6 +120,17 @@ const autoModelID = "auto"
const maxStoredIdentifierLength = 100
const anthropicVersionHeader = "2023-06-01"
const defaultAnthropicMaxTokens = 4096
const maxFallbackRetries = 2

// errProviderConnection signals that the provider HTTP call failed before any
// response headers were written, making the request safe to retry on a
// different model.
type errProviderConnection struct {
wrapped error
}

func (e *errProviderConnection) Error() string { return e.wrapped.Error() }
func (e *errProviderConnection) Unwrap() error { return e.wrapped }

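For context, a minimal self-contained sketch (the dial error text is invented) of how this wrapper is meant to be consumed downstream: `errors.As` walks the `Unwrap` chain, so fallback checks still match even when an intermediate caller adds its own `%w` layer.

```go
package main

import (
	"errors"
	"fmt"
)

// Mirror of the type added in this PR, reproduced so the sketch compiles alone.
type errProviderConnection struct{ wrapped error }

func (e *errProviderConnection) Error() string { return e.wrapped.Error() }
func (e *errProviderConnection) Unwrap() error { return e.wrapped }

func main() {
	base := &errProviderConnection{wrapped: errors.New("dial tcp: connection refused")}
	// An extra %w layer added by an intermediate caller does not hide the type.
	wrapped := fmt.Errorf("stream failed: %w", base)

	var connErr *errProviderConnection
	fmt.Println(errors.As(wrapped, &connErr))                   // true: safe to retry on another model
	fmt.Println(errors.As(errors.New("bad request"), &connErr)) // false: not retryable
}
```
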
var providerVersionSegmentPattern = regexp.MustCompile(`(?i)^v\d+(?:[a-z0-9._-]*)?$`)

@@ -324,15 +336,24 @@ func (s *service) ListAvailableModels() ([]AvailableModel, error) {
return []AvailableModel{}, nil
}

return []AvailableModel{
{
ID: 0,
DisplayName: "Auto",
Description: stringPtr("Automatically route requests to the best available model under current governance policy."),
IsSecure: false,
Provider: "gateway",
},
}, nil
result := make([]AvailableModel, 0, len(items)+1)
result = append(result, AvailableModel{
ID: 0,
DisplayName: "Auto",
Description: stringPtr("Automatically route requests to the best available model under current governance policy."),
IsSecure: false,
Provider: "gateway",
})
for _, item := range items {
result = append(result, AvailableModel{
ID: item.ID,
DisplayName: item.DisplayName,
Description: item.Description,
IsSecure: item.IsSecure,
Provider: item.ProviderType,
})
}
return result, nil
}

func (s *service) ChatCompletions(ctx context.Context, userID int, req ChatCompletionRequest) (*ProxyResponse, string, error) {
@@ -345,6 +366,22 @@ func (s *service) ChatCompletions(ctx context.Context, userID int, req ChatCompl
return nil, traceID, err
}

resp, traceID, err := s.dispatchCall(ctx, prepared)

// Fallback: retry with alternate models on connection failure or 5xx. A
// wrapped errProviderConnection means no response was received from the
// provider, so the request is safe to re-dispatch.
var connErr *errProviderConnection
if errors.As(err, &connErr) || (err == nil && resp != nil && resp.StatusCode >= 500) {
if fallbackResp, fallbackTraceID, ok := s.callWithFallback(ctx, prepared); ok {
return fallbackResp, fallbackTraceID, nil
}
}
return resp, traceID, err
}

// dispatchCall routes a prepared request to the correct provider.
func (s *service) dispatchCall(ctx context.Context, prepared *preparedChatRequest) (*ProxyResponse, string, error) {
switch models.ResolveLLMProtocolTypeOrDefault(prepared.resolvedModel.ProviderType, prepared.resolvedModel.ProtocolType) {
case models.ProtocolTypeOpenAI, models.ProtocolTypeOpenAICompatible:
return s.callOpenAICompatible(ctx, prepared)
@@ -366,6 +403,39 @@ func (s *service) ChatCompletions(ctx context.Context, userID int, req ChatCompl
}
}

// callWithFallback tries alternate models when the primary model fails.
// Returns (response, traceID, true) on successful fallback, or (nil, "", false) if exhausted.
func (s *service) callWithFallback(ctx context.Context, prepared *preparedChatRequest) (*ProxyResponse, string, bool) {
excludeIDs := map[int]bool{prepared.resolvedModel.ID: true}

for attempt := 0; attempt < maxFallbackRetries; attempt++ {
fallbackModel, err := s.selectAutoModelExcluding(excludeIDs)
if err != nil {
break // no more candidates
}

_ = s.auditEventService.RecordEvent(&models.AuditEvent{
TraceID: prepared.traceID,
SessionID: prepared.req.SessionID,
RequestID: prepared.requestIDPtr,
UserID: prepared.userIDPtr,
InstanceID: prepared.req.InstanceID,
EventType: "gateway.request.fallback",
TrafficClass: models.TrafficClassLLM,
Severity: models.AuditSeverityWarn,
Message: fmt.Sprintf("Falling back from model %s to %s (attempt %d)", prepared.resolvedModel.DisplayName, fallbackModel.DisplayName, attempt+1),
})

prepared.resolvedModel = fallbackModel
resp, traceID, err := s.dispatchCall(ctx, prepared)
if err == nil && (resp == nil || resp.StatusCode < 500) {
return resp, traceID, true
}
excludeIDs[fallbackModel.ID] = true
}
return nil, "", false
}

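As a hedged illustration of the retry rule this loop encodes (`shouldRetry` is an illustrative extraction, not a function in the PR): an attempt is consumed whenever the call errors or the provider answers 5xx, and a nil error with a sub-500 status ends the loop successfully.

```go
package main

import "fmt"

const maxFallbackRetries = 2 // same budget the PR defines

// shouldRetry is an illustrative extraction of callWithFallback's loop
// condition, not a function in the PR: retry while attempts remain and the
// latest result is an error or a 5xx response.
func shouldRetry(attempt int, err error, statusCode int) bool {
	if attempt >= maxFallbackRetries {
		return false
	}
	return err != nil || statusCode >= 500
}

func main() {
	fmt.Println(shouldRetry(0, nil, 502)) // true: first attempt, upstream 502
	fmt.Println(shouldRetry(0, nil, 200)) // false: success, stop
	fmt.Println(shouldRetry(2, nil, 503)) // false: budget of 2 retries exhausted
}
```

Each failed candidate also lands in excludeIDs, so the selector never proposes the same model twice within a single request.
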
func (s *service) StreamChatCompletions(ctx context.Context, userID int, req ChatCompletionRequest, w http.ResponseWriter) (string, error) {
prepared, err := s.prepareChatRequest(userID, req)
if err != nil {
@@ -376,11 +446,31 @@ func (s *service) StreamChatCompletions(ctx context.Context, userID int, req Cha
return traceID, err
}

streamErr := s.dispatchStream(ctx, prepared, w)
if streamErr == nil {
return prepared.traceID, nil
}

// Streaming fallback: only retry when the provider connection failed
// before any response headers were written to the client.
var connErr *errProviderConnection
if !errors.As(streamErr, &connErr) {
return prepared.traceID, streamErr
}

if fallbackErr := s.streamWithFallback(ctx, prepared, w); fallbackErr != nil {
return prepared.traceID, fallbackErr
}
return prepared.traceID, nil
}

// dispatchStream routes a streaming request to the correct provider.
func (s *service) dispatchStream(ctx context.Context, prepared *preparedChatRequest, w http.ResponseWriter) error {
switch models.ResolveLLMProtocolTypeOrDefault(prepared.resolvedModel.ProviderType, prepared.resolvedModel.ProtocolType) {
case models.ProtocolTypeOpenAI, models.ProtocolTypeOpenAICompatible:
return prepared.traceID, s.streamOpenAICompatible(ctx, prepared, w)
return s.streamOpenAICompatible(ctx, prepared, w)
case models.ProtocolTypeAnthropic:
return prepared.traceID, s.streamAnthropic(ctx, prepared, w)
return s.streamAnthropic(ctx, prepared, w)
default:
_ = s.auditEventService.RecordEvent(&models.AuditEvent{
TraceID: prepared.traceID,
@@ -393,10 +483,49 @@ func (s *service) StreamChatCompletions(ctx context.Context, userID int, req Cha
Severity: models.AuditSeverityWarn,
Message: fmt.Sprintf("Provider type %s is not supported yet", prepared.resolvedModel.ProviderType),
})
return prepared.traceID, errors.New("provider type is not supported yet")
return errors.New("provider type is not supported yet")
}
}

// streamWithFallback tries alternate models for streaming when the primary
// provider connection failed before any response was sent to the client.
func (s *service) streamWithFallback(ctx context.Context, prepared *preparedChatRequest, w http.ResponseWriter) error {
excludeIDs := map[int]bool{prepared.resolvedModel.ID: true}

for attempt := 0; attempt < maxFallbackRetries; attempt++ {
fallbackModel, err := s.selectAutoModelExcluding(excludeIDs)
if err != nil {
return fmt.Errorf("no fallback models available: %w", err)
}

_ = s.auditEventService.RecordEvent(&models.AuditEvent{
TraceID: prepared.traceID,
SessionID: prepared.req.SessionID,
RequestID: prepared.requestIDPtr,
UserID: prepared.userIDPtr,
InstanceID: prepared.req.InstanceID,
EventType: "gateway.request.fallback",
TrafficClass: models.TrafficClassLLM,
Severity: models.AuditSeverityWarn,
Message: fmt.Sprintf("Stream falling back from model %s to %s (attempt %d)", prepared.resolvedModel.DisplayName, fallbackModel.DisplayName, attempt+1),
})

prepared.resolvedModel = fallbackModel
streamErr := s.dispatchStream(ctx, prepared, w)
if streamErr == nil {
return nil
}

var connErr *errProviderConnection
if !errors.As(streamErr, &connErr) {
// Headers were written or non-connection error; cannot retry.
return streamErr
}
excludeIDs[fallbackModel.ID] = true
}
return errors.New("all fallback models exhausted for streaming request")
}

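The streaming path can only retry while nothing has reached the client. A more defensive variant (hypothetical, not part of this PR) would track the writer directly instead of relying on error types alone:

```go
package gatewaysketch

import "net/http"

// headerTrackingWriter is a hypothetical helper, not part of this PR: it
// records the first header or byte sent, so fallback logic could ask the
// writer itself whether the stream is still safe to restart.
type headerTrackingWriter struct {
	http.ResponseWriter
	wrote bool
}

func (w *headerTrackingWriter) WriteHeader(code int) {
	w.wrote = true
	w.ResponseWriter.WriteHeader(code)
}

func (w *headerTrackingWriter) Write(b []byte) (int, error) {
	w.wrote = true // net/http sends an implicit 200 on the first Write
	return w.ResponseWriter.Write(b)
}
```

Note that wrapping an embedded interface hides optional interfaces like http.Flusher, which SSE streaming needs, so a real version would also have to forward Flush.
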
func (s *service) prepareChatRequest(userID int, req ChatCompletionRequest) (*preparedChatRequest, error) {
if strings.TrimSpace(req.Model) == "" {
return nil, errors.New("model is required")
Expand Down Expand Up @@ -818,7 +947,7 @@ func (s *service) streamOpenAICompatible(ctx context.Context, prepared *prepared
response, err := s.httpClient.Do(httpRequest)
if err != nil {
s.recordFailure(prepared.traceID, prepared.requestID, prepared.req, userIDOrZero(prepared.userIDPtr), prepared.resolvedModel, startedAt, fmt.Sprintf("provider call failed: %v", err), providerRequestBody)
return fmt.Errorf("failed to call provider: %w", err)
return &errProviderConnection{wrapped: fmt.Errorf("failed to call provider: %w", err)}
}
defer response.Body.Close()

@@ -907,7 +1036,7 @@ func (s *service) streamAnthropic(ctx context.Context, prepared *preparedChatReq
response, err := s.httpClient.Do(httpRequest)
if err != nil {
s.recordFailure(prepared.traceID, prepared.requestID, prepared.req, userIDOrZero(prepared.userIDPtr), prepared.resolvedModel, startedAt, fmt.Sprintf("provider call failed: %v", err), providerRequestBody)
return fmt.Errorf("failed to call provider: %w", err)
return &errProviderConnection{wrapped: fmt.Errorf("failed to call provider: %w", err)}
}
defer response.Body.Close()

@@ -1907,22 +2036,56 @@ func (s *service) resolveRequestedModel(requestedModel string) (*models.LLMModel
}

func (s *service) selectAutoModel() (*models.LLMModel, error) {
return s.selectAutoModelExcluding(nil)
}

// selectAutoModelExcluding picks the best non-secure model by priority,
// randomly breaking ties among the highest-priority group.
// Models whose IDs appear in excludeIDs are skipped (used by fallback).
func (s *service) selectAutoModelExcluding(excludeIDs map[int]bool) (*models.LLMModel, error) {
items, err := s.modelRepo.ListActive()
if err != nil {
return nil, fmt.Errorf("failed to list active models: %w", err)
}
if len(items) == 0 {

// Filter: non-secure and not excluded.
candidates := make([]models.LLMModel, 0, len(items))
for _, item := range items {
if item.IsSecure {
continue
}
if excludeIDs != nil && excludeIDs[item.ID] {
continue
}
candidates = append(candidates, item)
}

if len(candidates) == 0 {
// Fallback: include secure models if no non-secure candidates remain.
for _, item := range items {
if excludeIDs != nil && excludeIDs[item.ID] {
continue
}
candidates = append(candidates, item)
}
}
if len(candidates) == 0 {
return nil, errors.New("no active models are configured")
}

for _, item := range items {
if !item.IsSecure {
selected := item
return &selected, nil
// Items are already sorted by -priority, -is_secure, display_name from the repo.
// Collect the highest-priority group.
highestPriority := candidates[0].Priority
topGroup := make([]models.LLMModel, 0, len(candidates))
for _, c := range candidates {
if c.Priority < highestPriority {
break
}
topGroup = append(topGroup, c)
}

selected := items[0]
// Random pick among the top-priority group for simple load balancing.
selected := topGroup[mathrand.IntN(len(topGroup))]
return &selected, nil
}

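A self-contained sketch of the selection rule (model names and priorities are invented): take the leading run of equal-priority candidates from the already descending-sorted slice, then pick uniformly at random inside it. math/rand/v2 requires Go 1.22+.

```go
package main

import (
	"fmt"
	mathrand "math/rand/v2"
)

type model struct {
	Name     string
	Priority int
}

// pickTop assumes sorted is ordered by descending Priority, as the repo layer
// guarantees in this PR, and load-balances across the tied leaders.
func pickTop(sorted []model) model {
	top := []model{sorted[0]}
	for _, m := range sorted[1:] {
		if m.Priority < sorted[0].Priority {
			break // past the highest-priority group
		}
		top = append(top, m)
	}
	return top[mathrand.IntN(len(top))]
}

func main() {
	ms := []model{{"alpha", 90}, {"beta", 90}, {"gamma", 80}}
	fmt.Println(pickTop(ms).Name) // "alpha" or "beta", never "gamma"
}
```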