From 0a78120a3d79dc5848871d3950df0b9c51b5768b Mon Sep 17 00:00:00 2001 From: Will Handley Date: Sat, 1 Feb 2025 18:55:34 +0000 Subject: [PATCH] First draft of assistant --- gptcli/assistant.py | 59 +++++++++++++++++++------------------- gptcli/config.py | 17 +++++++++-- gptcli/gpt.py | 2 +- gptcli/providers/openai.py | 18 ++++++------ 4 files changed, 56 insertions(+), 40 deletions(-) diff --git a/gptcli/assistant.py b/gptcli/assistant.py index c435d8f..55e3ace 100644 --- a/gptcli/assistant.py +++ b/gptcli/assistant.py @@ -15,15 +15,7 @@ from gptcli.providers.anthropic import AnthropicCompletionProvider from gptcli.providers.cohere import CohereCompletionProvider from gptcli.providers.azure_openai import AzureOpenAICompletionProvider - - -class AssistantConfig(TypedDict, total=False): - messages: List[Message] - model: str - openai_base_url_override: Optional[str] - openai_api_key_override: Optional[str] - temperature: float - top_p: float +from gptcli.config import ModelConfig, AssistantConfig CONFIG_DEFAULTS = { @@ -73,17 +65,7 @@ def get_completion_provider( openai_base_url_override: Optional[str] = None, openai_api_key_override: Optional[str] = None, ) -> CompletionProvider: - if ( - model.startswith("gpt") - or model.startswith("ft:gpt") - or model.startswith("oai-compat:") - or model.startswith("chatgpt") - or model.startswith("o1") - ): - return OpenAICompletionProvider( - openai_base_url_override, openai_api_key_override - ) - elif model.startswith("oai-azure:"): + if model.startswith("oai-azure:"): return AzureOpenAICompletionProvider() elif model.startswith("claude"): return AnthropicCompletionProvider() @@ -94,15 +76,18 @@ def get_completion_provider( elif model.startswith("gemini"): return GoogleCompletionProvider() else: - raise ValueError(f"Unknown model: {model}") + return OpenAICompletionProvider( + openai_base_url_override, openai_api_key_override + ) class Assistant: - def __init__(self, config: AssistantConfig): + def __init__(self, config: AssistantConfig, model_configs: Optional[Dict[str, ModelConfig]] = None): self.config = config + self.model_configs = model_configs or {} @classmethod - def from_config(cls, name: str, config: AssistantConfig): + def from_config(cls, name: str, config: AssistantConfig, model_configs: Optional[Dict[str, ModelConfig]] = None): config = config.copy() if name in DEFAULT_ASSISTANTS: # Merge the config with the default config @@ -112,7 +97,7 @@ def from_config(cls, name: str, config: AssistantConfig): if config.get(key) is None: config[key] = default_config[key] - return cls(config) + return cls(config, model_configs=model_configs) def init_messages(self) -> List[Message]: return self.config.get("messages", [])[:] @@ -124,10 +109,22 @@ def _param(self, param: str) -> Any: def complete_chat(self, messages, stream: bool = True) -> Iterator[CompletionEvent]: model = self._param("model") + # Check if there is a model configuration override for this model. + if model in self.model_configs: + model_conf = self.model_configs[model] + print(model_conf) + openai_api_key_override = model_conf['api_key'] or self.config.get("openai_api_key_override") + openai_base_url_override = model_conf['base_url'] or self.config.get("openai_base_url_override") + pricing_override = model_conf['pricing'] + else: + openai_api_key_override = self._param("openai_api_key_override") + openai_base_url_override = self._param("openai_base_url_override") + pricing_override = None + completion_provider = get_completion_provider( model, - self._param("openai_base_url_override"), - self._param("openai_api_key_override"), + openai_base_url_override, + openai_api_key_override, ) return completion_provider.complete( messages, @@ -135,6 +132,8 @@ def complete_chat(self, messages, stream: bool = True) -> Iterator[CompletionEve "model": model, "temperature": float(self._param("temperature")), "top_p": float(self._param("top_p")), + # Pass along the pricing override if available. + "pricing": pricing_override, }, stream, ) @@ -149,13 +148,15 @@ class AssistantGlobalArgs: def init_assistant( - args: AssistantGlobalArgs, custom_assistants: Dict[str, AssistantConfig] + args: AssistantGlobalArgs, + custom_assistants: Dict[str, AssistantConfig], + model_configs: Optional[Dict[str, ModelConfig]] = None, ) -> Assistant: name = args.assistant_name if name in custom_assistants: - assistant = Assistant.from_config(name, custom_assistants[name]) + assistant = Assistant.from_config(name, custom_assistants[name], model_configs=model_configs) elif name in DEFAULT_ASSISTANTS: - assistant = Assistant.from_config(name, DEFAULT_ASSISTANTS[name]) + assistant = Assistant.from_config(name, DEFAULT_ASSISTANTS[name], model_configs=model_configs) else: print(f"Unknown assistant: {name}") sys.exit(1) diff --git a/gptcli/config.py b/gptcli/config.py index df5c065..84d6e6b 100644 --- a/gptcli/config.py +++ b/gptcli/config.py @@ -4,14 +4,27 @@ import yaml from attr import dataclass -from gptcli.assistant import AssistantConfig from gptcli.providers.llama import LLaMAModelConfig +from gptcli.completion import Message CONFIG_FILE_PATHS = [ os.path.join(os.path.expanduser("~"), ".config", "gpt-cli", "gpt.yml"), os.path.join(os.path.expanduser("~"), ".gptrc"), ] +class AssistantConfig: + messages: List[Message] + model: str + openai_base_url_override: Optional[str] + openai_api_key_override: Optional[str] + temperature: float + top_p: float + +@dataclass +class ModelConfig: + api_key: Optional[str] = None + base_url: Optional[str] = None + pricing: Optional[Dict[str, float]] = None @dataclass class GptCliConfig: @@ -30,7 +43,7 @@ class GptCliConfig: assistants: Dict[str, AssistantConfig] = {} interactive: Optional[bool] = None llama_models: Optional[Dict[str, LLaMAModelConfig]] = None - + model_configs: Optional[Dict[str, ModelConfig]] = None def choose_config_file(paths: List[str]) -> str: for path in paths: diff --git a/gptcli/gpt.py b/gptcli/gpt.py index 1e55c1f..cb75f0c 100755 --- a/gptcli/gpt.py +++ b/gptcli/gpt.py @@ -199,7 +199,7 @@ def main(): if config.llama_models is not None: init_llama_models(config.llama_models) - assistant = init_assistant(cast(AssistantGlobalArgs, args), config.assistants) + assistant = init_assistant(cast(AssistantGlobalArgs, args), config.assistants, model_configs=config.model_configs) if args.prompt is not None: run_non_interactive(args, assistant) diff --git a/gptcli/providers/openai.py b/gptcli/providers/openai.py index 81b2e08..f0b7981 100644 --- a/gptcli/providers/openai.py +++ b/gptcli/providers/openai.py @@ -56,7 +56,8 @@ def complete( ): yield MessageDeltaEvent(response.choices[0].delta.content) - if response.usage and (pricing := gpt_pricing(args["model"])): + pricing = args.get("pricing") or gpt_pricing(args["model"]) + if response.usage and pricing: yield UsageEvent.with_pricing( prompt_tokens=response.usage.prompt_tokens, completion_tokens=response.usage.completion_tokens, @@ -73,13 +74,14 @@ def complete( next_choice = response.choices[0] if next_choice.message.content: yield MessageDeltaEvent(next_choice.message.content) - if response.usage and (pricing := gpt_pricing(args["model"])): - yield UsageEvent.with_pricing( - prompt_tokens=response.usage.prompt_tokens, - completion_tokens=response.usage.completion_tokens, - total_tokens=response.usage.total_tokens, - pricing=pricing, - ) + pricing = args.get("pricing") or gpt_pricing(args["model"]) + if response.usage and pricing: + yield UsageEvent.with_pricing( + prompt_tokens=response.usage.prompt_tokens, + completion_tokens=response.usage.completion_tokens, + total_tokens=response.usage.total_tokens, + pricing=pricing, + ) except openai.BadRequestError as e: raise BadRequestError(e.message) from e