Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,4 @@ jobs/**
node_modules/
bench/__pycache__
.ai/
.forge/FORGE_EDITMSG.md
233 changes: 230 additions & 3 deletions crates/forge_config/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,72 @@ pub struct ProviderUrlParam {
pub options: Vec<String>,
}

/// Input modality supported by a model.
///
/// Serialized in lowercase (e.g. `"text"`, `"image"`).
//
// Fieldless two-variant enum: `Copy`, `Eq`, and `Hash` are free and let
// callers compare/collect modalities without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema, Dummy)]
#[serde(rename_all = "lowercase")]
pub enum InputModality {
    /// Plain-text input.
    Text,
    /// Image input (vision-capable models).
    Image,
}

/// Serde default for `input_modalities`: a model with no declared
/// modalities is assumed to accept text only.
fn default_input_modalities() -> Vec<InputModality> {
    Vec::from([InputModality::Text])
}

/// A static model entry for inline provider configuration.
///
/// This allows defining model capabilities directly in `forge.toml` without
/// requiring a URL-based model discovery.
#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize, JsonSchema, Dummy, Setters)]
#[serde(rename_all = "snake_case")]
#[setters(strip_option)]
pub struct StaticModelEntry {
/// Unique model identifier (e.g. `"gpt-4"`).
pub id: String,
/// Human-readable model name.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
/// Description of the model.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
/// Maximum context window size in tokens.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub context_length: Option<u64>,
/// Whether the model supports tool calls.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tools_supported: Option<bool>,
/// Whether the model supports parallel tool calls.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub supports_parallel_tool_calls: Option<bool>,
/// Whether the model supports reasoning.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub supports_reasoning: Option<bool>,
/// Input modalities supported by the model (defaults to text-only).
#[serde(default = "default_input_modalities")]
pub input_modalities: Vec<InputModality>,
}

/// Model source for a provider.
///
/// This can be either a URL template for fetching the model list, or a
/// static list of model entries defined inline.
///
/// Deserialized with `#[serde(untagged)]`: a JSON string maps to
/// [`ProviderModels::Url`], a JSON array of entries maps to
/// [`ProviderModels::Static`]. Variants are tried in declaration order.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, JsonSchema)]
#[serde(untagged)]
pub enum ProviderModels {
    /// URL template for fetching the model list.
    Url(String),
    /// Static list of model entries.
    Static(Vec<StaticModelEntry>),
}

/// Test-data generator: always fabricates a one-entry static model list.
impl fake::Dummy<fake::Faker> for ProviderModels {
    fn dummy_with_rng<R: fake::RngExt + ?Sized>(_: &fake::Faker, rng: &mut R) -> Self {
        let entry = StaticModelEntry::dummy_with_rng(&fake::Faker, rng);
        Self::Static(vec![entry])
    }
}

/// A single provider entry defined inline in `forge.toml`.
///
/// Inline providers are merged with the built-in provider list; entries with
Expand All @@ -75,10 +141,10 @@ pub struct ProviderEntry {
/// URL template for chat completions; may contain `{{VAR}}` placeholders
/// that are substituted from the credential's url params.
pub url: String,
/// URL template for fetching the model list; may contain `{{VAR}}`
/// placeholders.
/// Model source for this provider. Can be either a URL template for
/// fetching the model list, or a static list of model entries.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub models: Option<String>,
pub models: Option<ProviderModels>,
/// Wire protocol used by this provider.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub response_type: Option<ProviderResponseType>,
Expand Down Expand Up @@ -367,4 +433,165 @@ mod tests {

assert_eq!(actual.temperature, fixture.temperature);
}

#[test]
fn test_provider_models_url() {
    // A bare JSON string must deserialize into the `Url` variant.
    let json = r#""https://api.example.com/models""#;
    let models: ProviderModels = serde_json::from_str(json).unwrap();

    match models {
        ProviderModels::Url(url) => assert_eq!(url, "https://api.example.com/models"),
        other => panic!("Expected Url variant, got {:?}", other),
    }
}

#[test]
fn test_provider_models_static() {
    // A JSON array must deserialize into the `Static` variant.
    let json = r#"[{"id": "gpt-4", "name": "GPT-4", "context_length": 128000}]"#;
    let models: ProviderModels = serde_json::from_str(json).unwrap();

    let ProviderModels::Static(entries) = models else {
        panic!("Expected Static variant");
    };
    assert_eq!(entries.len(), 1);
    let entry = &entries[0];
    assert_eq!(entry.id, "gpt-4");
    assert_eq!(entry.name.as_deref(), Some("GPT-4"));
    assert_eq!(entry.context_length, Some(128000));
}

#[test]
fn test_provider_models_multiple_entries() {
    let json = r#"[
        {"id": "gpt-4", "tools_supported": true},
        {"id": "gpt-3.5-turbo", "tools_supported": false}
    ]"#;
    let models: ProviderModels = serde_json::from_str(json).unwrap();

    let ProviderModels::Static(entries) = models else {
        panic!("Expected Static variant");
    };

    // Table-driven check: (expected id, expected tools flag) per entry.
    let expected = [("gpt-4", true), ("gpt-3.5-turbo", false)];
    assert_eq!(entries.len(), expected.len());
    for (entry, (id, tools)) in entries.iter().zip(expected) {
        assert_eq!(entry.id, id);
        assert_eq!(entry.tools_supported, Some(tools));
    }
}

#[test]
fn test_provider_models_round_trip() {
    // Serialize then deserialize; the derived `PartialEq` lets us assert
    // full structural equality instead of spot-checking a subset of fields,
    // so any field that fails to round-trip is caught.
    let original = ProviderModels::Static(vec![StaticModelEntry {
        id: "qwen3-35b".to_string(),
        name: Some("Qwen 3.5 35B".to_string()),
        description: Some("Local reasoning model".to_string()),
        context_length: Some(262144),
        tools_supported: Some(true),
        supports_parallel_tool_calls: Some(true),
        supports_reasoning: Some(true),
        input_modalities: vec![InputModality::Text],
    }]);

    let json = serde_json::to_string(&original).unwrap();
    let parsed: ProviderModels = serde_json::from_str(&json).unwrap();

    assert_eq!(parsed, original);
}

#[test]
fn test_provider_models_input_modalities_default() {
    // Omitting `input_modalities` in the JSON must fall back to text-only.
    let json = r#"[{"id": "test-model"}]"#;
    let models: ProviderModels = serde_json::from_str(json).unwrap();

    if let ProviderModels::Static(entries) = &models {
        assert_eq!(entries[0].input_modalities, vec![InputModality::Text]);
    } else {
        panic!("Expected Static variant, got {:?}", models);
    }
}

#[test]
fn test_provider_models_with_image_modality() {
    // Explicit modalities must be parsed in the order given.
    let json = r#"[{"id": "vision-model", "input_modalities": ["text", "image"]}]"#;
    let models: ProviderModels = serde_json::from_str(json).unwrap();

    let ProviderModels::Static(entries) = models else {
        panic!("Expected Static variant");
    };
    let expected = [InputModality::Text, InputModality::Image];
    assert_eq!(entries[0].input_modalities, expected);
}

#[test]
fn test_provider_entry_with_static_models() {
    // End-to-end: a provider entry with an inline model list.
    let json = r#"{
        "id": "ollama",
        "url": "http://localhost:8000/v1/chat/completions",
        "models": [
            {
                "id": "qwen3-35b",
                "name": "Qwen 3.5 35B",
                "context_length": 262144,
                "tools_supported": true,
                "supports_parallel_tool_calls": true,
                "supports_reasoning": true,
                "input_modalities": ["text"]
            }
        ]
    }"#;

    let entry: ProviderEntry = serde_json::from_str(json).unwrap();
    assert_eq!(entry.id, "ollama");
    assert_eq!(entry.url, "http://localhost:8000/v1/chat/completions");

    let Some(ProviderModels::Static(models)) = entry.models else {
        panic!("Expected Static variant");
    };
    assert_eq!(models.len(), 1);
    let model = &models[0];
    assert_eq!(model.id, "qwen3-35b");
    assert_eq!(model.context_length, Some(262144));
    assert_eq!(model.tools_supported, Some(true));
    assert_eq!(model.supports_reasoning, Some(true));
}

#[test]
fn test_provider_entry_with_url_models() {
    // End-to-end: a provider entry whose `models` field is a URL string.
    let json = r#"{
        "id": "openai",
        "url": "https://api.openai.com/v1/chat/completions",
        "models": "https://api.openai.com/v1/models"
    }"#;

    let entry: ProviderEntry = serde_json::from_str(json).unwrap();
    assert_eq!(entry.id, "openai");

    match entry.models {
        Some(ProviderModels::Url(url)) => {
            assert_eq!(url, "https://api.openai.com/v1/models");
        }
        other => panic!("Expected Url variant, got {:?}", other),
    }
}
}
30 changes: 29 additions & 1 deletion crates/forge_repo/src/provider/provider_repo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,35 @@ impl From<forge_config::ProviderEntry> for ProviderConfig {
url_param_vars: entry.url_param_vars.into_iter().map(Into::into).collect(),
response_type,
url: entry.url,
models: entry.models.map(Models::Url),
models: entry.models.map(|m| match m {
forge_config::ProviderModels::Url(url) => Models::Url(url),
forge_config::ProviderModels::Static(entries) => Models::Hardcoded(
entries
.into_iter()
.map(|e| forge_app::domain::Model {
id: e.id.into(),
name: e.name,
description: e.description,
context_length: e.context_length,
tools_supported: e.tools_supported,
supports_parallel_tool_calls: e.supports_parallel_tool_calls,
supports_reasoning: e.supports_reasoning,
input_modalities: e
.input_modalities
.into_iter()
.map(|m| match m {
forge_config::InputModality::Text => {
forge_app::domain::InputModality::Text
}
forge_config::InputModality::Image => {
forge_app::domain::InputModality::Image
}
})
.collect(),
})
.collect(),
),
}),
auth_methods,
custom_headers: entry.custom_headers,
}
Expand Down
Loading
Loading