diff --git a/cmd/wave/commands/db_logging_emitter_test.go b/cmd/wave/commands/db_logging_emitter_test.go new file mode 100644 index 000000000..65ae10b04 --- /dev/null +++ b/cmd/wave/commands/db_logging_emitter_test.go @@ -0,0 +1,198 @@ +package commands + +import ( + "testing" + + "github.com/recinq/wave/internal/event" + "github.com/recinq/wave/internal/state" +) + +type logEventCall struct { + runID, stepID, state, persona, message string + tokens int + durationMs int64 + model, configuredModel, adapter string +} + +// fakeLogEventStore satisfies state.StateStore by embedding the interface. +// Only LogEvent is implemented — any other method call panics, which is fine +// because dbLoggingEmitter.Emit only invokes LogEvent. +type fakeLogEventStore struct { + state.StateStore + calls []logEventCall +} + +func (f *fakeLogEventStore) LogEvent(runID, stepID, st, persona, message string, tokens int, durationMs int64, model, configuredModel, adapter string) error { + f.calls = append(f.calls, logEventCall{ + runID: runID, + stepID: stepID, + state: st, + persona: persona, + message: message, + tokens: tokens, + durationMs: durationMs, + model: model, + configuredModel: configuredModel, + adapter: adapter, + }) + return nil +} + +type fakeEventEmitter struct { + events []event.Event +} + +func (f *fakeEventEmitter) Emit(e event.Event) { + f.events = append(f.events, e) +} + +func TestDBLoggingEmitter_Emit(t *testing.T) { + tests := []struct { + name string + ev event.Event + wantPersist bool + wantMessage string + wantRunID string + }{ + { + name: "empty step_progress heartbeat is dropped", + ev: event.Event{State: "step_progress"}, + wantPersist: false, + }, + { + name: "empty stream_activity heartbeat is dropped", + ev: event.Event{State: "stream_activity"}, + wantPersist: false, + }, + { + name: "stream_activity with ToolName composes message", + ev: event.Event{ + State: "stream_activity", + ToolName: "Read", + ToolTarget: "cmd/wave/commands/run.go", + StepID: 
"step-1", + Persona: "navigator", + }, + wantPersist: true, + wantMessage: "Read cmd/wave/commands/run.go", + wantRunID: "default-run", + }, + { + name: "step_progress with tokens used is persisted", + ev: event.Event{ + State: "step_progress", + TokensUsed: 42, + StepID: "step-1", + }, + wantPersist: true, + wantMessage: "", + wantRunID: "default-run", + }, + { + name: "step_progress with duration is persisted", + ev: event.Event{ + State: "step_progress", + DurationMs: 100, + StepID: "step-1", + }, + wantPersist: true, + wantMessage: "", + wantRunID: "default-run", + }, + { + name: "running state with message is persisted", + ev: event.Event{ + State: "running", + Message: "step started", + StepID: "step-1", + Persona: "implementer", + }, + wantPersist: true, + wantMessage: "step started", + wantRunID: "default-run", + }, + { + name: "completed state with no message still persists (not heartbeat)", + ev: event.Event{ + State: "completed", + StepID: "step-1", + }, + wantPersist: true, + wantMessage: "", + wantRunID: "default-run", + }, + { + name: "event with PipelineID overrides default runID", + ev: event.Event{ + State: "running", + Message: "child running", + PipelineID: "child-run-id", + StepID: "step-1", + }, + wantPersist: true, + wantMessage: "child running", + wantRunID: "child-run-id", + }, + { + name: "stream_activity with ToolName but no target", + ev: event.Event{ + State: "stream_activity", + ToolName: "Bash", + StepID: "step-1", + }, + wantPersist: true, + wantMessage: "Bash ", + wantRunID: "default-run", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + inner := &fakeEventEmitter{} + store := &fakeLogEventStore{} + d := &dbLoggingEmitter{inner: inner, store: store, runID: "default-run"} + + d.Emit(tt.ev) + + if len(inner.events) != 1 { + t.Fatalf("inner.Emit called %d times, want 1", len(inner.events)) + } + if inner.events[0].State != tt.ev.State || inner.events[0].StepID != tt.ev.StepID || inner.events[0].Message != 
tt.ev.Message { + t.Errorf("inner received unexpected event: got %+v want %+v", inner.events[0], tt.ev) + } + + if !tt.wantPersist { + if len(store.calls) != 0 { + t.Errorf("LogEvent called %d times, want 0 (heartbeat should be dropped); calls=%+v", len(store.calls), store.calls) + } + return + } + + if len(store.calls) != 1 { + t.Fatalf("LogEvent called %d times, want 1", len(store.calls)) + } + c := store.calls[0] + if c.runID != tt.wantRunID { + t.Errorf("LogEvent runID = %q, want %q", c.runID, tt.wantRunID) + } + if c.message != tt.wantMessage { + t.Errorf("LogEvent message = %q, want %q", c.message, tt.wantMessage) + } + if c.state != tt.ev.State { + t.Errorf("LogEvent state = %q, want %q", c.state, tt.ev.State) + } + if c.stepID != tt.ev.StepID { + t.Errorf("LogEvent stepID = %q, want %q", c.stepID, tt.ev.StepID) + } + if c.persona != tt.ev.Persona { + t.Errorf("LogEvent persona = %q, want %q", c.persona, tt.ev.Persona) + } + if c.tokens != tt.ev.TokensUsed { + t.Errorf("LogEvent tokens = %d, want %d", c.tokens, tt.ev.TokensUsed) + } + if c.durationMs != tt.ev.DurationMs { + t.Errorf("LogEvent durationMs = %d, want %d", c.durationMs, tt.ev.DurationMs) + } + }) + } +} diff --git a/cmd/wave/commands/preflight_metadata_test.go b/cmd/wave/commands/preflight_metadata_test.go index dbf5cfe4b..e51b0668f 100644 --- a/cmd/wave/commands/preflight_metadata_test.go +++ b/cmd/wave/commands/preflight_metadata_test.go @@ -2,100 +2,134 @@ package commands import ( "errors" + "fmt" + "reflect" "testing" "github.com/recinq/wave/internal/preflight" + "github.com/recinq/wave/internal/recovery" ) func TestExtractPreflightMetadata(t *testing.T) { tests := []struct { - name string - err error - wantSkills []string - wantTools []string - wantNil bool + name string + err error + want *recovery.PreflightMetadata }{ { - name: "nil error returns nil", - err: nil, - wantNil: true, + name: "nil error returns nil", + err: nil, + want: nil, }, { - name: "non-preflight error returns nil", - 
err: errors.New("generic error"), - wantNil: true, + name: "non-preflight error returns nil", + err: errors.New("generic error"), + want: nil, }, { name: "skill error extracts missing skills", err: &preflight.SkillError{ MissingSkills: []string{"speckit", "testkit"}, }, - wantSkills: []string{"speckit", "testkit"}, + want: &recovery.PreflightMetadata{ + MissingSkills: []string{"speckit", "testkit"}, + }, }, { name: "tool error extracts missing tools", err: &preflight.ToolError{ MissingTools: []string{"jq", "yq"}, }, - wantTools: []string{"jq", "yq"}, + want: &recovery.PreflightMetadata{ + MissingTools: []string{"jq", "yq"}, + }, }, { - name: "wrapped skill error extracts missing skills", + name: "errors.Join with skill error extracts missing skills", err: errors.Join( errors.New("preflight check failed"), &preflight.SkillError{ MissingSkills: []string{"speckit"}, }, ), - wantSkills: []string{"speckit"}, + want: &recovery.PreflightMetadata{ + MissingSkills: []string{"speckit"}, + }, }, { - name: "wrapped tool error extracts missing tools", + name: "errors.Join with tool error extracts missing tools", err: errors.Join( errors.New("preflight check failed"), &preflight.ToolError{ MissingTools: []string{"jq"}, }, ), - wantTools: []string{"jq"}, + want: &recovery.PreflightMetadata{ + MissingTools: []string{"jq"}, + }, + }, + { + name: "errors.Join with both skill and tool errors extracts both", + err: errors.Join( + &preflight.SkillError{MissingSkills: []string{"speckit"}}, + &preflight.ToolError{MissingTools: []string{"jq"}}, + ), + want: &recovery.PreflightMetadata{ + MissingSkills: []string{"speckit"}, + MissingTools: []string{"jq"}, + }, + }, + { + name: "fmt.Errorf %w wrapping skill error extracts missing skills", + err: fmt.Errorf("preflight failed: %w", &preflight.SkillError{ + MissingSkills: []string{"speckit", "testkit"}, + }), + want: &recovery.PreflightMetadata{ + MissingSkills: []string{"speckit", "testkit"}, + }, + }, + { + name: "fmt.Errorf %w wrapping tool 
error extracts missing tools", + err: fmt.Errorf("preflight failed: %w", &preflight.ToolError{ + MissingTools: []string{"jq"}, + }), + want: &recovery.PreflightMetadata{ + MissingTools: []string{"jq"}, + }, + }, + { + name: "double-wrapped fmt.Errorf %w skill error still extracts missing skills", + err: fmt.Errorf("outer: %w", + fmt.Errorf("inner: %w", &preflight.SkillError{ + MissingSkills: []string{"speckit"}, + }), + ), + want: &recovery.PreflightMetadata{ + MissingSkills: []string{"speckit"}, + }, + }, + { + name: "skill error with empty MissingSkills returns nil", + err: &preflight.SkillError{ + MissingSkills: nil, + }, + want: nil, + }, + { + name: "tool error with empty MissingTools returns nil", + err: &preflight.ToolError{ + MissingTools: []string{}, + }, + want: nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - meta := extractPreflightMetadata(tt.err) - - if tt.wantNil { - if meta != nil { - t.Errorf("expected nil metadata, got %+v", meta) - } - return - } - - if meta == nil { - t.Fatal("expected non-nil metadata") - } - - if len(tt.wantSkills) > 0 { - if len(meta.MissingSkills) != len(tt.wantSkills) { - t.Errorf("MissingSkills count = %d, want %d", len(meta.MissingSkills), len(tt.wantSkills)) - } - for i, skill := range tt.wantSkills { - if i >= len(meta.MissingSkills) || meta.MissingSkills[i] != skill { - t.Errorf("MissingSkills[%d] = %q, want %q", i, meta.MissingSkills[i], skill) - } - } - } + got := extractPreflightMetadata(tt.err) - if len(tt.wantTools) > 0 { - if len(meta.MissingTools) != len(tt.wantTools) { - t.Errorf("MissingTools count = %d, want %d", len(meta.MissingTools), len(tt.wantTools)) - } - for i, tool := range tt.wantTools { - if i >= len(meta.MissingTools) || meta.MissingTools[i] != tool { - t.Errorf("MissingTools[%d] = %q, want %q", i, meta.MissingTools[i], tool) - } - } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("extractPreflightMetadata(%v) = %+v, want %+v", tt.err, got, tt.want) } }) } diff --git 
a/cmd/wave/commands/run.go b/cmd/wave/commands/run.go index b5be16ee4..3b83093d3 100644 --- a/cmd/wave/commands/run.go +++ b/cmd/wave/commands/run.go @@ -463,7 +463,7 @@ func runRun(opts RunOptions, debug bool) error { if opts.ForceModel { execOpts = append(execOpts, pipeline.WithForceModel(true)) } - registry := adapter.NewAdapterRegistry(nil) + registry := adapter.NewAdapterRegistry(m.Runtime.Fallbacks) for name, a := range m.Adapters { if a.Binary != "" { registry.SetBinary(name, a.Binary) diff --git a/internal/adapter/fallback.go b/internal/adapter/fallback.go index 2ed218b41..fc033a989 100644 --- a/internal/adapter/fallback.go +++ b/internal/adapter/fallback.go @@ -80,11 +80,23 @@ func (f *FallbackRunner) Run(ctx context.Context, cfg AdapterRunConfig) (*Adapte return lastResult, fmt.Errorf("all fallback adapters exhausted") } -// isFallbackTrigger returns true if the result indicates a rate limit -// failure that should trigger fallback to the next provider. +// isFallbackTrigger returns true when retrying on a different provider has a +// real chance of succeeding — i.e. the failure is upstream-capacity +// (rate limit), wall-clock (the model stalled past the timeout), or +// context-budget (the model couldn't fit the prompt). All three commonly +// resolve when retried on a peer with different limits, model architecture, +// or context window. +// +// Other classifications (general_error, validation, etc.) are intentionally +// excluded — they typically indicate a bug or schema mismatch that will fail +// the same way on any provider. 
func isFallbackTrigger(result *AdapterResult) bool { if result == nil { return false } - return result.FailureReason == "rate_limit" + switch result.FailureReason { + case FailureReasonRateLimit, FailureReasonTimeout, FailureReasonContextExhaustion: + return true + } + return false } diff --git a/internal/adapter/fallback_test.go b/internal/adapter/fallback_test.go index c796082f4..54c8cb96e 100644 --- a/internal/adapter/fallback_test.go +++ b/internal/adapter/fallback_test.go @@ -72,7 +72,7 @@ func TestFallbackRunner_RateLimitTriggersFallback(t *testing.T) { assert.Equal(t, 1, fallback.callCount) } -func TestFallbackRunner_ContextExhaustionDoesNotTriggerFallback(t *testing.T) { +func TestFallbackRunner_ContextExhaustionTriggersFallback(t *testing.T) { primary := &failingRunner{failureReason: "context_exhaustion"} fallback := &successRunner{} @@ -83,8 +83,25 @@ func TestFallbackRunner_ContextExhaustionDoesNotTriggerFallback(t *testing.T) { result, err := fr.Run(context.Background(), AdapterRunConfig{}) assert.NoError(t, err) - assert.Equal(t, "context_exhaustion", result.FailureReason) - assert.Equal(t, 0, fallback.callCount, "should NOT call fallback on context_exhaustion") + assert.Equal(t, "success", result.ResultContent, + "context_exhaustion should trigger fallback to a peer with a different context budget") + assert.Equal(t, 1, fallback.callCount) +} + +func TestFallbackRunner_TimeoutTriggersFallback(t *testing.T) { + primary := &failingRunner{failureReason: "timeout"} + fallback := &successRunner{} + + registry := NewAdapterRegistry(nil) + registry.RegisterOverride("codex", fallback) + + fr := NewFallbackRunner(primary, []string{"codex"}, registry) + result, err := fr.Run(context.Background(), AdapterRunConfig{}) + + assert.NoError(t, err) + assert.Equal(t, "success", result.ResultContent, + "timeout should trigger fallback so a stalled local model hands off to a peer") + assert.Equal(t, 1, fallback.callCount) } func 
TestFallbackRunner_GeneralErrorDoesNotTriggerFallback(t *testing.T) { @@ -130,9 +147,14 @@ func TestFallbackRunner_EmptyChainReturnsPrimaryResult(t *testing.T) { result, err := fr.Run(context.Background(), AdapterRunConfig{}) // Empty chain — the initial rate_limit result won't trigger any fallbacks - // but the loop doesn't execute, so we get the "all fallback adapters exhausted" error - assert.Error(t, err) - assert.NotNil(t, result) + // but the loop doesn't execute, so we get the "all fallback adapters exhausted" error. + // The primary's lastResult is returned unwrapped so callers can inspect why + // the fallback chain bailed out (here: the primary's rate_limit failure). + assert.EqualError(t, err, "all fallback adapters exhausted") + require.NotNil(t, result) + assert.Equal(t, "rate_limit", result.FailureReason) + assert.Equal(t, "failed: rate_limit", result.ResultContent) + assert.Equal(t, 1, primary.callCount) } func TestFallbackRunner_HardErrorFromPrimary(t *testing.T) { @@ -183,9 +205,9 @@ func TestFallbackRunner_ContextCancelledDuringFallback(t *testing.T) { func TestIsFallbackTrigger(t *testing.T) { assert.True(t, isFallbackTrigger(&AdapterResult{FailureReason: "rate_limit"})) - assert.False(t, isFallbackTrigger(&AdapterResult{FailureReason: "context_exhaustion"})) + assert.True(t, isFallbackTrigger(&AdapterResult{FailureReason: "context_exhaustion"})) + assert.True(t, isFallbackTrigger(&AdapterResult{FailureReason: "timeout"})) assert.False(t, isFallbackTrigger(&AdapterResult{FailureReason: "general_error"})) - assert.False(t, isFallbackTrigger(&AdapterResult{FailureReason: "timeout"})) assert.False(t, isFallbackTrigger(&AdapterResult{FailureReason: ""})) assert.False(t, isFallbackTrigger(nil)) } diff --git a/wave.yaml b/wave.yaml index 6f408ffaa..70cd0804c 100644 --- a/wave.yaml +++ b/wave.yaml @@ -473,6 +473,14 @@ runtime: default_timeout_minutes: 30 stall_timeout: 10m max_concurrent_workers: 5 + # Adapter fallback chains. 
When the primary adapter returns a fallback- + # eligible failure (rate_limit, timeout, context_exhaustion), the registry + # walks the chain in order until one succeeds. Useful for self-healing + # when local Ollama models stall on tool-call streams (the GLM-via-Ollama + # bug documented in PRs #1404 / #1468). + fallbacks: + opencode-glm: [opencode-qwen, claude] + opencode-qwen: [claude] meta_pipeline: max_depth: 2 max_total_steps: 20