Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 69 additions & 4 deletions pkg/runtime/model_switcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/docker/docker-agent/pkg/environment"
"github.com/docker/docker-agent/pkg/model/provider"
"github.com/docker/docker-agent/pkg/model/provider/options"
"github.com/docker/docker-agent/pkg/modelsdev"
)

// ModelChoice represents a model available for selection in the TUI picker.
Expand All @@ -32,6 +33,31 @@ type ModelChoice struct {
IsCustom bool
// IsCatalog indicates this is a model from the models.dev catalog
IsCatalog bool

// The fields below are populated (best-effort) from the models.dev
// catalog. They are optional and may all be zero/empty when no
// catalog entry is found for the model.

// Family is the model family (e.g., "claude", "gpt").
Family string
// InputCost is the price (in USD) per 1M input tokens.
InputCost float64
// OutputCost is the price (in USD) per 1M output tokens.
OutputCost float64
// CacheReadCost is the price (in USD) per 1M cached input tokens.
CacheReadCost float64
// CacheWriteCost is the price (in USD) per 1M cache-write tokens.
CacheWriteCost float64
// ContextLimit is the maximum context window size in tokens.
ContextLimit int
// OutputLimit is the maximum number of tokens the model can produce
// in a single response.
OutputLimit int64
// InputModalities lists the input modalities supported by the model
// (e.g., "text", "image", "audio").
InputModalities []string
// OutputModalities lists the output modalities the model can produce.
OutputModalities []string
}

// ModelSwitcher is an optional interface for runtimes that support changing the model
Expand Down Expand Up @@ -240,13 +266,18 @@ func (r *LocalRuntime) AvailableModels(ctx context.Context) []ModelChoice {

// Add all configured models, marking the current agent's default
for name, cfg := range r.modelSwitcherCfg.Models {
choices = append(choices, ModelChoice{
choice := ModelChoice{
Name: name,
Ref: name,
Provider: cfg.Provider,
Model: cfg.DisplayOrModel(),
IsDefault: name == currentAgentDefault,
})
}
// Best-effort lookup of pricing / context information from models.dev.
if cfg.Provider != "" && cfg.Model != "" {
r.populateCatalogMetadata(ctx, &choice, cfg.Provider, cfg.Model)
}
choices = append(choices, choice)
}

// Append models.dev catalog entries filtered by available credentials
Expand Down Expand Up @@ -308,13 +339,15 @@ func (r *LocalRuntime) buildCatalogChoices(ctx context.Context) []ModelChoice {
}
existingRefs[ref] = true

choices = append(choices, ModelChoice{
choice := ModelChoice{
Name: model.Name,
Ref: ref,
Provider: dockerAgentProvider,
Model: modelID,
IsCatalog: true,
})
}
applyCatalogMetadata(&choice, &model)
choices = append(choices, choice)
}
}

Expand All @@ -333,6 +366,38 @@ func mapModelsDevProvider(providerID string) (string, bool) {
return "", false
}

// populateCatalogMetadata enriches choice with models.dev metadata for the
// model identified by providerID and modelID. The catalog is queried with
// the key "<providerID>/<modelID>".
//
// The enrichment is best-effort: choice is left untouched when the runtime
// has no models store or when the store returns an error for the lookup.
func (r *LocalRuntime) populateCatalogMetadata(ctx context.Context, choice *ModelChoice, providerID, modelID string) {
	catalogKey := providerID + "/" + modelID

	store := r.modelsStore
	if store == nil {
		return
	}

	entry, lookupErr := store.GetModel(ctx, catalogKey)
	if lookupErr != nil {
		// Deliberately ignored: missing metadata is not an error for callers.
		return
	}
	applyCatalogMetadata(choice, entry)
}

// applyCatalogMetadata transfers the descriptive fields of a models.dev
// catalog entry onto choice: the model family, per-token pricing, token
// limits, and the supported input/output modalities.
//
// A nil entry leaves choice untouched. Pricing fields are only written when
// the entry carries a cost section. The modality slices are cloned rather
// than assigned directly, so choice gets its own copies.
func applyCatalogMetadata(choice *ModelChoice, m *modelsdev.Model) {
	if m == nil {
		return
	}

	choice.Family = m.Family

	if pricing := m.Cost; pricing != nil {
		choice.InputCost = pricing.Input
		choice.OutputCost = pricing.Output
		choice.CacheReadCost = pricing.CacheRead
		choice.CacheWriteCost = pricing.CacheWrite
	}

	limits := m.Limit
	choice.ContextLimit, choice.OutputLimit = limits.Context, limits.Output

	modalities := m.Modalities
	choice.InputModalities = slices.Clone(modalities.Input)
	choice.OutputModalities = slices.Clone(modalities.Output)
}

// isEmbeddingModel returns true if the model is an embedding model
// based on its family or name fields from models.dev.
func isEmbeddingModel(family, name string) bool {
Expand Down
Loading
Loading