diff --git a/gptcli/__init__.py b/gptcli/__init__.py
index f9aa3e1..e19434e 100644
--- a/gptcli/__init__.py
+++ b/gptcli/__init__.py
@@ -1 +1 @@
-__version__ = "0.3.2"
+__version__ = "0.3.3"
diff --git a/gptcli/assistant.py b/gptcli/assistant.py
index c435d8f..fb37b52 100644
--- a/gptcli/assistant.py
+++ b/gptcli/assistant.py
@@ -15,6 +15,7 @@
 from gptcli.providers.anthropic import AnthropicCompletionProvider
 from gptcli.providers.cohere import CohereCompletionProvider
 from gptcli.providers.azure_openai import AzureOpenAICompletionProvider
+from gptcli.providers.xai import XAICompletionProvider
 
 
 class AssistantConfig(TypedDict, total=False):
@@ -93,6 +94,8 @@ def get_completion_provider(
         return CohereCompletionProvider()
     elif model.startswith("gemini"):
         return GoogleCompletionProvider()
+    elif model.startswith("grok"):
+        return XAICompletionProvider()
     else:
         raise ValueError(f"Unknown model: {model}")
 
diff --git a/gptcli/config.py b/gptcli/config.py
index df5c065..0d5ce15 100644
--- a/gptcli/config.py
+++ b/gptcli/config.py
@@ -25,6 +25,7 @@ class GptCliConfig:
     anthropic_api_key: Optional[str] = os.environ.get("ANTHROPIC_API_KEY")
     google_api_key: Optional[str] = os.environ.get("GOOGLE_API_KEY")
     cohere_api_key: Optional[str] = os.environ.get("COHERE_API_KEY")
+    xai_api_key: Optional[str] = os.environ.get("XAI_API_KEY")
     log_file: Optional[str] = None
     log_level: str = "INFO"
     assistants: Dict[str, AssistantConfig] = {}
diff --git a/gptcli/gpt.py b/gptcli/gpt.py
index 1e55c1f..ec8faba 100755
--- a/gptcli/gpt.py
+++ b/gptcli/gpt.py
@@ -193,6 +193,9 @@ def main():
     if config.cohere_api_key:
         gptcli.providers.cohere.api_key = config.cohere_api_key
 
+    if config.xai_api_key:
+        gptcli.providers.xai.api_key = config.xai_api_key
+
     if config.google_api_key:
         genai.configure(api_key=config.google_api_key)
 
diff --git a/gptcli/providers/openai.py b/gptcli/providers/openai.py
index 81b2e08..5afbccd 100644
--- a/gptcli/providers/openai.py
+++ b/gptcli/providers/openai.py
@@ -56,7 +56,7 @@ def complete(
                     ):
                         yield MessageDeltaEvent(response.choices[0].delta.content)
 
-                    if response.usage and (pricing := gpt_pricing(args["model"])):
+                    if response.usage and (pricing := self.pricing(args["model"])):
                         yield UsageEvent.with_pricing(
                             prompt_tokens=response.usage.prompt_tokens,
                             completion_tokens=response.usage.completion_tokens,
@@ -73,7 +73,7 @@ def complete(
                 next_choice = response.choices[0]
                 if next_choice.message.content:
                     yield MessageDeltaEvent(next_choice.message.content)
-                if response.usage and (pricing := gpt_pricing(args["model"])):
+                if response.usage and (pricing := self.pricing(args["model"])):
                     yield UsageEvent.with_pricing(
                         prompt_tokens=response.usage.prompt_tokens,
                         completion_tokens=response.usage.completion_tokens,
@@ -86,6 +86,8 @@ def complete(
         except openai.APIError as e:
             raise CompletionError(e.message) from e
 
+    def pricing(self, model: str) -> Optional[Pricing]:
+        return gpt_pricing(model)
 
 GPT_3_5_TURBO_PRICE_PER_TOKEN: Pricing = {
     "prompt": 0.50 / 1_000_000,
diff --git a/gptcli/providers/xai.py b/gptcli/providers/xai.py
new file mode 100644
index 0000000..b47f0e6
--- /dev/null
+++ b/gptcli/providers/xai.py
@@ -0,0 +1,33 @@
+import os
+from typing import Optional
+
+from openai import OpenAI
+
+from gptcli.completion import Pricing
+from gptcli.providers.openai import OpenAICompletionProvider
+
+api_key = os.environ.get("XAI_API_KEY")
+
+
+class XAICompletionProvider(OpenAICompletionProvider):
+    def __init__(self):
+        self.client = OpenAI(api_key=api_key, base_url="https://api.x.ai/v1")
+
+    def pricing(self, model: str) -> Optional[Pricing]:
+        if model.startswith("grok-beta"):
+            return GROK_BETA_PRICE_PER_TOKEN
+        elif model.startswith("grok-2"):
+            return GROK_2_PRICE_PER_TOKEN
+        else:
+            return None
+
+
+GROK_2_PRICE_PER_TOKEN: Pricing = {
+    "prompt": 2.00 / 1_000_000,
+    "response": 10.00 / 1_000_000,
+}
+
+GROK_BETA_PRICE_PER_TOKEN: Pricing = {
+    "prompt": 5.00 / 1_000_000,
+    "response": 15.00 / 1_000_000,
+}