Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions gptcli/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from gptcli.completion import CompletionProvider, ModelOverrides, Message
from gptcli.google import GoogleCompletionProvider
from gptcli.llama import LLaMACompletionProvider
from gptcli.mistral import MistralCompletionProvider
from gptcli.openai import OpenAICompletionProvider
from gptcli.anthropic import AnthropicCompletionProvider

Expand Down Expand Up @@ -64,6 +65,8 @@ def get_completion_provider(model: str) -> CompletionProvider:
return LLaMACompletionProvider()
elif model.startswith("chat-bison"):
return GoogleCompletionProvider()
elif model.startswith("mistral"):
return MistralCompletionProvider()
else:
raise ValueError(f"Unknown model: {model}")

Expand Down
1 change: 1 addition & 0 deletions gptcli/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class GptCliConfig:
show_price: bool = True
api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
openai_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
mistral_api_key: Optional[str] = os.environ.get("MISTRAL_API_KEY")
anthropic_api_key: Optional[str] = os.environ.get("ANTHROPIC_API_KEY")
google_api_key: Optional[str] = os.environ.get("GOOGLE_API_KEY")
log_file: Optional[str] = None
Expand Down
4 changes: 4 additions & 0 deletions gptcli/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import datetime
import google.generativeai as genai
import gptcli.anthropic
import gptcli.mistral
from gptcli.assistant import (
Assistant,
DEFAULT_ASSISTANTS,
Expand Down Expand Up @@ -178,6 +179,9 @@ def main():
)
sys.exit(1)

if config.mistral_api_key:
gptcli.mistral.api_key = config.mistral_api_key

if config.anthropic_api_key:
gptcli.anthropic.api_key = config.anthropic_api_key

Expand Down
47 changes: 47 additions & 0 deletions gptcli/mistral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import Iterator, List
import os
from gptcli.completion import CompletionProvider, Message
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

api_key = os.environ.get("MISTRAL_API_KEY")


class MistralCompletionProvider(CompletionProvider):
def __init__(self):
self.client = MistralClient(api_key=api_key)

def complete(
self, messages: List[Message], args: dict, stream: bool = False
) -> Iterator[str]:
kwargs = {}
if "temperature" in args:
kwargs["temperature"] = args["temperature"]
if "top_p" in args:
kwargs["top_p"] = args["top_p"]

messages = [
ChatMessage(role=msg["role"], content=msg["content"])
for msg in messages
]

if stream:
response_iter = self.client.chat_stream(
model=args["model"],
messages=messages,
**kwargs,
)

for response in response_iter:
next_choice = response.choices[0]
if next_choice.finish_reason is None and next_choice.delta.content:
yield next_choice.delta.content
else:
response = self.client.chat(
model=args["model"],
messages=messages,
**kwargs,
)
next_choice = response.choices[0]
if next_choice.message.content:
yield next_choice.message.content