Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .env.template
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@
# fallback (default): use configured models only when upstream /models fails, is nil, or is empty.
# allowlist: expose only the configured models for providers that define a list, and skip their upstream /models calls.
# CONFIGURED_PROVIDER_MODELS_MODE=fallback
# Model ID format for GET /v1/models: "qualified" (provider/model), "unqualified" (model only), or "both"
# MODELS_ENDPOINT_ID_FORMAT=qualified
# Examples: OPENROUTER_MODELS=..., OPENROUTER_EU_MODELS=..., AZURE_MODELS=..., VLLM_MODELS=...

# Fallback & Workflow Configuration
Expand Down
1 change: 1 addition & 0 deletions config/config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ models:
enabled_by_default: true # env: MODELS_ENABLED_BY_DEFAULT; when false, models stay unavailable until an override allows one or more user paths
overrides_enabled: true # env: MODEL_OVERRIDES_ENABLED; load/enforce persisted model overrides and enable dashboard editing
configured_provider_models_mode: "fallback" # env: CONFIGURED_PROVIDER_MODELS_MODE; "fallback" uses configured lists only when upstream /models is unavailable/empty, "allowlist" exposes only configured models and skips upstream /models for configured lists
models_endpoint_id_format: "qualified" # env: MODELS_ENDPOINT_ID_FORMAT; "qualified" (provider/model), "unqualified" (model only), or "both"

cache:
model:
Expand Down
1 change: 1 addition & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ func buildDefaultConfig() *Config {
OverridesEnabled: true,
KeepOnlyAliasesAtModelsEndpoint: false,
ConfiguredProviderModelsMode: ConfiguredProviderModelsModeFallback,
ModelsEndpointIDFormat: ModelsEndpointIDFormatQualified,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Add validation for ModelsEndpointIDFormat in Load().

ConfiguredProviderModelsMode is both resolved and validated in Load() (lines 169–172), but ModelsEndpointIDFormat is only defaulted here with no corresponding normalization or validation in Load(). This inconsistency means invalid user-supplied values will silently fall back to the default instead of producing a clear configuration error.

Proposed fix to add validation after line 172
 cfg.Models.ConfiguredProviderModelsMode = ResolveConfiguredProviderModelsMode(cfg.Models.ConfiguredProviderModelsMode)
 if !cfg.Models.ConfiguredProviderModelsMode.Valid() {
 	return nil, fmt.Errorf("models.configured_provider_models_mode must be one of: fallback, allowlist")
 }
+cfg.Models.ModelsEndpointIDFormat = ResolveModelsEndpointIDFormat(cfg.Models.ModelsEndpointIDFormat)
+if !cfg.Models.ModelsEndpointIDFormat.Valid() {
+	return nil, fmt.Errorf("models.models_endpoint_id_format must be one of: qualified, unqualified, both")
+}
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@config/config.go` at line 65, Add validation and normalization for
ModelsEndpointIDFormat inside the Load() function similar to how
ConfiguredProviderModelsMode is handled: after resolving the field, check that
the value is one of the allowed constants (e.g., ModelsEndpointIDFormatQualified
and any other valid enum values), normalize any synonyms if needed, and return
an error if the value is invalid instead of silently defaulting; update
references to ModelsEndpointIDFormat and ensure the defaulting code (which
currently sets ModelsEndpointIDFormat: ModelsEndpointIDFormatQualified) does not
hide invalid user input by performing this validation/normalization in Load().

},
Cache: CacheConfig{
Model: ModelCacheConfig{
Expand Down
38 changes: 38 additions & 0 deletions config/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ type ModelsConfig struct {
// provider *_MODELS env vars affect the provider model inventory.
// Supported values: "fallback", "allowlist". Default: "fallback".
ConfiguredProviderModelsMode ConfiguredProviderModelsMode `yaml:"configured_provider_models_mode" env:"CONFIGURED_PROVIDER_MODELS_MODE"`

// ModelsEndpointIDFormat controls the model ID format returned by GET /v1/models.
// Supported values: "qualified" (provider/model), "unqualified" (model), "both".
// Default: "qualified".
ModelsEndpointIDFormat ModelsEndpointIDFormat `yaml:"models_endpoint_id_format" env:"MODELS_ENDPOINT_ID_FORMAT"`
}

// ConfiguredProviderModelsMode controls how explicitly configured provider
Expand Down Expand Up @@ -57,3 +62,36 @@ func ResolveConfiguredProviderModelsMode(mode ConfiguredProviderModelsMode) Conf
}
return mode
}

// ModelsEndpointIDFormat controls the model ID format returned by GET /v1/models.
type ModelsEndpointIDFormat string

const (
ModelsEndpointIDFormatQualified ModelsEndpointIDFormat = "qualified"
ModelsEndpointIDFormatUnqualified ModelsEndpointIDFormat = "unqualified"
ModelsEndpointIDFormatBoth ModelsEndpointIDFormat = "both"
)

// Valid reports whether f is one of the supported models endpoint ID formats.
func (f ModelsEndpointIDFormat) Valid() bool {
switch NormalizeModelsEndpointIDFormat(f) {
case ModelsEndpointIDFormatQualified, ModelsEndpointIDFormatUnqualified, ModelsEndpointIDFormatBoth:
return true
default:
return false
}
}

// NormalizeModelsEndpointIDFormat canonicalizes a models endpoint ID format value.
func NormalizeModelsEndpointIDFormat(f ModelsEndpointIDFormat) ModelsEndpointIDFormat {
return ModelsEndpointIDFormat(strings.ToLower(strings.TrimSpace(string(f))))
}

// ResolveModelsEndpointIDFormat canonicalizes f and applies the process default.
func ResolveModelsEndpointIDFormat(f ModelsEndpointIDFormat) ModelsEndpointIDFormat {
f = NormalizeModelsEndpointIDFormat(f)
if f == "" {
return ModelsEndpointIDFormatQualified
}
return f
}
1 change: 1 addition & 0 deletions internal/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ func New(ctx context.Context, cfg Config) (*App, error) {
return nil, fmt.Errorf("failed to initialize providers: %w", err)
}
app.providers = providerResult
app.providers.Router.SetModelsEndpointIDFormat(appCfg.Models.ModelsEndpointIDFormat)

// Initialize audit logging
auditResult, err := auditlog.New(ctx, appCfg)
Expand Down
55 changes: 54 additions & 1 deletion internal/providers/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,60 @@ func (r *ModelRegistry) ListPublicModels() []core.Model {
return result
}

// ModelCount returns the number of registered models
// ListModelsWithFormat returns models formatted according to the given ID format.
func (r *ModelRegistry) ListModelsWithFormat(format config.ModelsEndpointIDFormat) []core.Model {
format = config.ResolveModelsEndpointIDFormat(format)
switch format {
case config.ModelsEndpointIDFormatUnqualified:
return r.listModelsUnqualified()
case config.ModelsEndpointIDFormatBoth:
return r.listModelsBoth()
default:
return r.ListPublicModels()
}
}

func (r *ModelRegistry) listModelsUnqualified() []core.Model {
r.mu.RLock()
defer r.mu.RUnlock()

seen := make(map[string]struct{})
result := make([]core.Model, 0, len(r.models))
for providerName, models := range r.modelsByProvider {
for modelID, info := range models {
if _, exists := seen[modelID]; exists {
continue
}
seen[modelID] = struct{}{}
model := info.Model
model.ID = modelID
model.OwnedBy = providerName
result = append(result, model)
}
}
sort.Slice(result, func(i, j int) bool { return result[i].ID < result[j].ID })
return result
Comment on lines +361 to +379
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Non-deterministic deduplication breaks "first provider wins" guarantee

listModelsUnqualified iterates over r.modelsByProvider, a Go map whose iteration order is randomized per run. When two providers share the same model ID (e.g., gpt-4o on both openai and openrouter), which provider "wins" the deduplication — and therefore what OwnedBy and model metadata are returned — is random. The PR description says this matches routing behavior, but routing uses r.models (populated in registration order, first-registered wins deterministically) as the source of truth. Using r.models here instead would give the correct, stable result: the same provider that routing would actually use.

}

func (r *ModelRegistry) listModelsBoth() []core.Model {
qualified := r.ListPublicModels()
unqualified := r.listModelsUnqualified()

seen := make(map[string]struct{}, len(qualified)+len(unqualified))
result := make([]core.Model, 0, len(qualified)+len(unqualified))
for _, m := range qualified {
seen[m.ID] = struct{}{}
result = append(result, m)
}
for _, m := range unqualified {
if _, exists := seen[m.ID]; exists {
continue
}
result = append(result, m)
}
sort.Slice(result, func(i, j int) bool { return result[i].ID < result[j].ID })
return result
}
func (r *ModelRegistry) ModelCount() int {
Comment on lines +398 to 401
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Missing blank line and doc comment between listModelsBoth and ModelCount. The original // ModelCount returns the number of registered models comment was dropped when the new functions were inserted.

Suggested change
sort.Slice(result, func(i, j int) bool { return result[i].ID < result[j].ID })
return result
}
func (r *ModelRegistry) ModelCount() int {
sort.Slice(result, func(i, j int) bool { return result[i].ID < result[j].ID })
return result
}
// ModelCount returns the number of registered models
func (r *ModelRegistry) ModelCount() int {

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

r.mu.RLock()
defer r.mu.RUnlock()
Expand Down
18 changes: 15 additions & 3 deletions internal/providers/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"sort"
"strings"

"gomodel/config"
"gomodel/internal/core"
)

Expand All @@ -21,7 +22,8 @@ var ErrRegistryNotInitialized = fmt.Errorf("model registry has no models: ensure
// It uses a dynamic model-to-provider mapping that is populated at startup
// by fetching available models from each provider's /models endpoint.
type Router struct {
lookup core.ModelLookup
lookup core.ModelLookup
modelsEndpointIDFormat config.ModelsEndpointIDFormat
}

type providerTypeRegistry interface {
Expand Down Expand Up @@ -64,10 +66,16 @@ func NewRouter(lookup core.ModelLookup) (*Router, error) {
return nil, fmt.Errorf("lookup cannot be nil")
}
return &Router{
lookup: lookup,
lookup: lookup,
modelsEndpointIDFormat: config.ModelsEndpointIDFormatQualified,
}, nil
}

// SetModelsEndpointIDFormat configures the model ID format for GET /v1/models.
func (r *Router) SetModelsEndpointIDFormat(format config.ModelsEndpointIDFormat) {
r.modelsEndpointIDFormat = config.ResolveModelsEndpointIDFormat(format)
}

// checkReady verifies the lookup has models available.
// Returns ErrRegistryNotInitialized if no models are loaded.
func (r *Router) checkReady() error {
Expand Down Expand Up @@ -527,7 +535,11 @@ func (r *Router) ListModels(_ context.Context) (*core.ModelsResponse, error) {
return nil, registryUnavailableError(err)
}
var models []core.Model
if public, ok := r.lookup.(publicModelLister); ok {
if formatted, ok := r.lookup.(interface {
ListModelsWithFormat(config.ModelsEndpointIDFormat) []core.Model
}); ok {
models = formatted.ListModelsWithFormat(r.modelsEndpointIDFormat)
} else if public, ok := r.lookup.(publicModelLister); ok {
models = public.ListPublicModels()
} else {
models = r.lookup.ListModels()
Expand Down
65 changes: 65 additions & 0 deletions internal/providers/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"strings"
"testing"

"gomodel/config"
"gomodel/internal/core"
)

Expand Down Expand Up @@ -841,6 +842,70 @@ func TestRouterListModels(t *testing.T) {
}
}

func TestRouterListModelsWithFormat(t *testing.T) {
registry := newTestRegistryWithModels(
registryModelEntry{provider: &mockProvider{}, providerName: "anthropic", providerType: "anthropic", modelID: "claude-sonnet-4-6"},
registryModelEntry{provider: &mockProvider{}, providerName: "openai", providerType: "openai", modelID: "gpt-4o"},
)

t.Run("qualified", func(t *testing.T) {
router, _ := NewRouter(registry)
router.SetModelsEndpointIDFormat(config.ModelsEndpointIDFormatQualified)

resp, err := router.ListModels(context.Background())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
for _, m := range resp.Data {
if !strings.Contains(m.ID, "/") {
t.Errorf("expected qualified ID with '/', got %q", m.ID)
}
}
})

t.Run("unqualified", func(t *testing.T) {
router, _ := NewRouter(registry)
router.SetModelsEndpointIDFormat(config.ModelsEndpointIDFormatUnqualified)

resp, err := router.ListModels(context.Background())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(resp.Data) != 2 {
t.Fatalf("expected 2 models, got %d", len(resp.Data))
}
for _, m := range resp.Data {
if strings.Contains(m.ID, "/") {
t.Errorf("expected unqualified ID without '/', got %q", m.ID)
}
}
})

t.Run("both", func(t *testing.T) {
router, _ := NewRouter(registry)
router.SetModelsEndpointIDFormat(config.ModelsEndpointIDFormatBoth)

resp, err := router.ListModels(context.Background())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(resp.Data) != 4 {
t.Fatalf("expected 4 models (2 qualified + 2 unqualified), got %d", len(resp.Data))
}
var qualified, unqualified int
for _, m := range resp.Data {
if strings.Contains(m.ID, "/") {
qualified++
} else {
unqualified++
}
}
if qualified != 2 || unqualified != 2 {
t.Errorf("expected 2 qualified + 2 unqualified, got %d + %d", qualified, unqualified)
}
})
}

func TestRouterGetProviderType(t *testing.T) {
lookup := newMockLookup()
lookup.addModel("gpt-4o", &mockProvider{}, "openai")
Expand Down