From 79981a0853c84766b499d2fc65b21e59b9f6b706 Mon Sep 17 00:00:00 2001 From: Valery Kharitonov Date: Tue, 12 Dec 2023 00:36:23 -0500 Subject: [PATCH 1/2] Add together.xyz endpoint --- gptcli/assistant.py | 22 ++++++++++++++++++++-- gptcli/config.py | 1 + gptcli/gpt.py | 4 ++++ pyproject.toml | 2 ++ 4 files changed, 27 insertions(+), 2 deletions(-) diff --git a/gptcli/assistant.py b/gptcli/assistant.py index 6adac67..1c16ccd 100644 --- a/gptcli/assistant.py +++ b/gptcli/assistant.py @@ -9,6 +9,7 @@ from gptcli.llama import LLaMACompletionProvider from gptcli.openai import OpenAICompletionProvider from gptcli.anthropic import AnthropicCompletionProvider +from gptcli.together import TogetherCompletionProvider class AssistantConfig(TypedDict, total=False): @@ -16,6 +17,11 @@ class AssistantConfig(TypedDict, total=False): model: str temperature: float top_p: float + system_prefix: str + system_suffix: str + user_prefix: str + user_suffix: str + stop_tokens: List[str] CONFIG_DEFAULTS = { @@ -55,7 +61,9 @@ class AssistantConfig(TypedDict, total=False): } -def get_completion_provider(model: str) -> CompletionProvider: +def get_completion_provider( + model: str, assistant_config: AssistantConfig +) -> CompletionProvider: if model.startswith("gpt"): return OpenAICompletionProvider() elif model.startswith("claude"): @@ -64,6 +72,16 @@ def get_completion_provider(model: str) -> CompletionProvider: return LLaMACompletionProvider() elif model.startswith("chat-bison"): return GoogleCompletionProvider() + elif model.startswith("together"): + return TogetherCompletionProvider( + { + "system_prefix": assistant_config.get("system_prefix", ""), + "system_suffix": assistant_config.get("system_suffix", ""), + "user_prefix": assistant_config.get("user_prefix", ""), + "user_suffix": assistant_config.get("user_suffix", ""), + "stop_tokens": assistant_config.get("stop_tokens", []), + } + ) else: raise ValueError(f"Unknown model: {model}") @@ -103,7 +121,7 @@ def complete_chat( self, messages, override_params: ModelOverrides = {}, stream: bool = True ) -> Iterator[str]: model = self._param("model", override_params) - completion_provider = get_completion_provider(model) + completion_provider = get_completion_provider(model, self.config) return completion_provider.complete( messages, { diff --git a/gptcli/config.py b/gptcli/config.py index 3cb9070..b205114 100644 --- a/gptcli/config.py +++ b/gptcli/config.py @@ -22,6 +22,7 @@ class GptCliConfig: openai_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY") anthropic_api_key: Optional[str] = os.environ.get("ANTHROPIC_API_KEY") google_api_key: Optional[str] = os.environ.get("GOOGLE_API_KEY") + together_api_key: Optional[str] = os.environ.get("TOGETHER_API_KEY") log_file: Optional[str] = None log_level: str = "INFO" assistants: Dict[str, AssistantConfig] = {} diff --git a/gptcli/gpt.py b/gptcli/gpt.py index e9634e9..32a4e31 100755 --- a/gptcli/gpt.py +++ b/gptcli/gpt.py @@ -15,6 +15,7 @@ import datetime import google.generativeai as genai import gptcli.anthropic +import gptcli.together from gptcli.assistant import ( Assistant, DEFAULT_ASSISTANTS, @@ -178,6 +179,9 @@ def main(): ) sys.exit(1) + if config.together_api_key: + gptcli.together.api_key = config.together_api_key + if config.anthropic_api_key: gptcli.anthropic.api_key = config.anthropic_api_key diff --git a/pyproject.toml b/pyproject.toml index 612739d..4585571 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,8 @@ dependencies = [ "pytest==7.3.1", "PyYAML==6.0", "rich==13.7.0", + "requests==2.31.0", + "sseclient-py==1.8.0", "tiktoken==0.5.2", "tokenizers==0.15.0", "typing_extensions==4.5.0", From e7781288f235d174b4e042e7394296a81e98b2f9 Mon Sep 17 00:00:00 2001 From: Valery Kharitonov Date: Tue, 12 Dec 2023 01:20:49 -0500 Subject: [PATCH 2/2] Add missing file --- gptcli/together.py | 84 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 gptcli/together.py diff --git a/gptcli/together.py b/gptcli/together.py new file mode 100644 index 0000000..35f4c0a --- /dev/null +++ b/gptcli/together.py @@ -0,0 +1,84 @@ +import json +import logging +import os +import requests +import sseclient + +from typing import Iterator, List, TypedDict, cast + +from gptcli.completion import CompletionProvider, Message + +api_key = os.environ.get("TOGETHER_API_KEY") +url = "https://api.together.xyz/inference" + + +class PromptConfig(TypedDict): + system_prefix: str + system_suffix: str + user_prefix: str + user_suffix: str + stop_tokens: List[str] + + +def build_prompt(messages: List[Message], prompt_config: PromptConfig) -> str: + prompt = "" + for message in messages: + if message["role"] == "system": + prompt += prompt_config["system_prefix"] + prompt += message["content"] + prompt += prompt_config["system_suffix"] + elif message["role"] == "user": + prompt += prompt_config["user_prefix"] + prompt += message["content"] + prompt += prompt_config["user_suffix"] + else: + prompt += message["content"] + return prompt + + +class TogetherCompletionProvider(CompletionProvider): + def __init__(self, prompt_config: PromptConfig): + self.prompt_config = prompt_config + + def complete( + self, messages: List[Message], args: dict, stream: bool = False + ) -> Iterator[str]: + kwargs = {} + if "temperature" in args: + kwargs["temperature"] = args["temperature"] + if "top_p" in args: + kwargs["top_p"] = args["top_p"] + + assert stream, "Together only supports streaming completions" + + model = args["model"].split("/", 1)[1] + prompt = build_prompt(messages, self.prompt_config) + + logging.info(f"Prompt: {prompt}") + + payload = { + "model": model, + "prompt": prompt, + "max_tokens": 2048, + "stream_tokens": True, + "stop": self.prompt_config["stop_tokens"], + **kwargs, + } + + headers = { + "accept": "application/json", + "content-type": "application/json", + "Authorization": f"Bearer {api_key}", + } + + response = requests.post(url, json=payload, headers=headers, stream=True) + response.raise_for_status() + + client = sseclient.SSEClient(response) + for event in client.events(): + if event.data == "[DONE]": + break + + partial_result = json.loads(event.data) + token = partial_result["choices"][0]["text"] + yield token