Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions free_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ func (p *tokenPool) ensureSession(ctx context.Context) (string, error) {
if err != nil {
p.session = nil
p.lastError = err.Error()
if isBannedErrorMessage(err.Error()) {
p.disabled = true
}
} else if waitingErr := waitingRoomErrorFromSession(p.name, session, time.Now()); waitingErr != nil {
p.lastError = waitingErr.Error()
} else {
Expand Down
22 changes: 18 additions & 4 deletions models.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"fmt"
"io"
"log"
"math/rand"
"net/http"
"regexp"
"sort"
Expand Down Expand Up @@ -202,7 +201,7 @@ func parseAllFreeModels(source string) map[string][]string {
}

// buildModelMapping creates the model→agent reverse mapping and deduplicated model list.
// When a model appears in multiple agents, one is chosen at random.
// When a model appears in multiple agents, pick the least-used agent to spread traffic.
func buildModelMapping(agentModels map[string][]string) (map[string]string, []string) {
modelAgents := make(map[string][]string)
for agentID, models := range agentModels {
Expand All @@ -213,10 +212,25 @@ func buildModelMapping(agentModels map[string][]string) (map[string]string, []st

modelToAgent := make(map[string]string, len(modelAgents))
allModels := make([]string, 0, len(modelAgents))
for model, agents := range modelAgents {
modelToAgent[model] = agents[rand.Intn(len(agents))]
for model := range modelAgents {
allModels = append(allModels, model)
}
sort.Strings(allModels)

agentUseCount := make(map[string]int, len(agentModels))
for _, model := range allModels {
agents := append([]string(nil), modelAgents[model]...)
sort.Strings(agents)
chosen := agents[0]
bestCount := agentUseCount[chosen]
for _, agentID := range agents[1:] {
if count := agentUseCount[agentID]; count < bestCount {
chosen = agentID
bestCount = count
}
}
modelToAgent[model] = chosen
agentUseCount[chosen]++
}
return modelToAgent, allModels
}
57 changes: 57 additions & 0 deletions run_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ type tokenPool struct {
sessionRefreshCh chan struct{}
lastError string
cooldownUntil time.Time
disabled bool
}

type managedRun struct {
Expand Down Expand Up @@ -63,6 +64,7 @@ type tokenSnapshot struct {
SessionPollAt time.Time `json:"session_poll_at,omitempty"`
CooldownUntil time.Time `json:"cooldown_until,omitempty"`
LastError string `json:"last_error,omitempty"`
Disabled bool `json:"disabled,omitempty"`
}

type runSnapshot struct {
Expand Down Expand Up @@ -157,9 +159,15 @@ func (m *RunManager) prewarm(agentIDs []string) {
defer cancel()

for _, pool := range m.pools {
if pool.isDisabled() {
continue
}
if _, err := pool.ensureSession(ctx); err != nil {
m.logger.Printf("%s: free session prewarm failed: %v", pool.name, err)
}
if pool.isDisabled() {
continue
}
for _, agentID := range agentIDs {
if err := pool.rotateAgent(ctx, agentID); err != nil {
m.logger.Printf("%s: prewarm %s failed: %v", pool.name, agentID, err)
Expand Down Expand Up @@ -247,6 +255,14 @@ func (m *RunManager) Snapshots() []tokenSnapshot {

func (p *tokenPool) acquire(ctx context.Context, agentID string) (*runLease, error) {
p.mu.Lock()
if p.disabled {
lastError := p.lastError
p.mu.Unlock()
if lastError == "" {
lastError = "token disabled"
}
return nil, errors.New(lastError)
}
if now := time.Now(); now.Before(p.cooldownUntil) {
cooldownUntil := p.cooldownUntil
p.mu.Unlock()
Expand Down Expand Up @@ -278,9 +294,15 @@ func (p *tokenPool) acquire(ctx context.Context, agentID string) (*runLease, err
}

func (p *tokenPool) maintain(ctx context.Context) error {
if p.isDisabled() {
return nil
}
if _, err := p.ensureSession(ctx); err != nil {
p.logger.Printf("%s: refresh free session failed: %v", p.name, err)
}
if p.isDisabled() {
return nil
}

p.mu.Lock()
var toRotate []string
Expand Down Expand Up @@ -334,6 +356,14 @@ func (p *tokenPool) shutdown(ctx context.Context) error {

func (p *tokenPool) rotateAgent(ctx context.Context, agentID string) error {
p.mu.Lock()
if p.disabled {
lastError := p.lastError
p.mu.Unlock()
if lastError == "" {
lastError = "token disabled"
}
return errors.New(lastError)
}
if now := time.Now(); now.Before(p.cooldownUntil) {
cooldownUntil := p.cooldownUntil
p.mu.Unlock()
Expand All @@ -343,6 +373,10 @@ func (p *tokenPool) rotateAgent(ctx context.Context, agentID string) error {

runID, err := p.client.StartRun(ctx, p.token, agentID)
if err != nil {
if isBannedErrorMessage(err.Error()) {
p.disable("upstream token banned")
return err
}
p.mu.Lock()
p.lastError = err.Error()
p.mu.Unlock()
Expand Down Expand Up @@ -467,6 +501,7 @@ func (p *tokenPool) snapshot() tokenSnapshot {
DrainingRuns: len(p.draining),
CooldownUntil: p.cooldownUntil,
LastError: p.lastError,
Disabled: p.disabled,
}
if p.session != nil {
snapshot.SessionStatus = string(p.session.status)
Expand All @@ -487,3 +522,25 @@ func (p *tokenPool) snapshot() tokenSnapshot {
}
return snapshot
}

func (p *tokenPool) disable(reason string) {
p.mu.Lock()
defer p.mu.Unlock()
p.disabled = true
p.session = nil
p.cooldownUntil = time.Time{}
if reason != "" {
p.lastError = reason
}
}

func (p *tokenPool) isDisabled() bool {
p.mu.Lock()
defer p.mu.Unlock()
return p.disabled
}

func isBannedErrorMessage(message string) bool {
message = strings.ToLower(strings.TrimSpace(message))
return strings.Contains(message, `"status":"banned"`) || strings.Contains(message, `"status": "banned"`) || strings.Contains(message, "status\":\"banned")
}
33 changes: 24 additions & 9 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,12 @@ func (s *Server) proxyChatRequest(
return
}

for attempt := 0; attempt < 2; attempt++ {
maxAttempts := len(s.cfg.AuthTokens) + 1
if maxAttempts < 2 {
maxAttempts = 2
}

for attempt := 0; attempt < maxAttempts; attempt++ {
lease, err := s.runs.Acquire(r.Context(), agentID)
if err != nil {
var waitingErr *waitingRoomError
Expand Down Expand Up @@ -320,6 +325,21 @@ func (s *Server) proxyChatRequest(
return
}

message, _, code := extractUpstreamError(errorBody)
if isBannedErrorMessage(string(errorBody)) {
s.logger.Printf("%s: upstream token banned, disabling token", lease.pool.name)
lease.pool.disable("upstream token banned")
s.runs.Release(lease)
continue
}
if strings.TrimSpace(code) == "session_model_mismatch" {
s.logger.Printf("%s: session model mismatch on run %s, rotating run and refreshing session", lease.pool.name, lease.run.id)
lease.pool.invalidateSession(strings.TrimSpace(message))
s.runs.Invalidate(lease, strings.TrimSpace(message))
s.runs.Release(lease)
continue
}

if isSessionInvalid(resp.StatusCode, errorBody) {
s.logger.Printf("%s: free session invalid, refreshing and retrying", lease.pool.name)
lease.pool.invalidateSession(strings.TrimSpace(string(errorBody)))
Expand Down Expand Up @@ -388,14 +408,9 @@ func isSessionInvalid(statusCode int, errorBody []byte) bool {
if statusCode < 400 {
return false
}
var payload struct {
Error string `json:"error"`
}
if err := json.Unmarshal(errorBody, &payload); err != nil {
return false
}
switch strings.TrimSpace(payload.Error) {
case "freebuff_update_required", "waiting_room_required", "waiting_room_queued", "session_superseded", "session_expired":
_, _, code := extractUpstreamError(errorBody)
switch strings.TrimSpace(code) {
case "freebuff_update_required", "waiting_room_required", "waiting_room_queued", "session_superseded", "session_expired", "session_model_mismatch":
return true
default:
return false
Expand Down