From 653734ff1435e3bd20a8f585ec5598f00ce540a2 Mon Sep 17 00:00:00 2001 From: Valery Kharitonov Date: Wed, 19 Jul 2023 01:37:18 -0400 Subject: [PATCH] Support for Llama 2 chat models --- gptcli/llama.py | 165 ++++++++++++++++++++++++++++++++++++++++++++---- pyproject.toml | 2 +- 2 files changed, 153 insertions(+), 14 deletions(-) diff --git a/gptcli/llama.py b/gptcli/llama.py index 359d61a..df51939 100644 --- a/gptcli/llama.py +++ b/gptcli/llama.py @@ -1,6 +1,8 @@ import os import sys -from typing import Iterator, List, Optional, TypedDict, cast +import random +from typing import Any, Iterator, List, Optional, TypedDict, cast +from typing_extensions import Required try: from llama_cpp import Completion, CompletionChunk, Llama @@ -12,8 +14,10 @@ from gptcli.completion import CompletionProvider, Message -class LLaMAModelConfig(TypedDict): - path: str +class LLaMAModelConfig(TypedDict, total=False): + path: Required[str] + llama2: bool + n_gpu_layers: int human_prompt: str assistant_prompt: str @@ -41,6 +45,10 @@ def init_llama_models(models: dict[str, LLaMAModelConfig]): def role_to_name(role: str, model_config: LLaMAModelConfig) -> str: + assert ( + "human_prompt" in model_config and "assistant_prompt" in model_config + ), "either `llama2: True` or human_prompt and assistant_prompt must be set in the model config" + if role == "system" or role == "user": return model_config["human_prompt"] elif role == "assistant": @@ -49,7 +57,11 @@ def role_to_name(role: str, model_config: LLaMAModelConfig) -> str: raise ValueError(f"Unknown role: {role}") -def make_prompt(messages: List[Message], model_config: LLaMAModelConfig) -> str: +def make_prompt_llama1(messages: List[Message], model_config: LLaMAModelConfig) -> str: + assert ( + "human_prompt" in model_config and "assistant_prompt" in model_config + ), "either `llama2: True` or human_prompt and assistant_prompt must be set in the model config" + prompt = "\n".join( [ f"{role_to_name(message['role'], model_config)} 
{message['content']}" @@ -60,6 +72,73 @@ def make_prompt(messages: List[Message], model_config: LLaMAModelConfig) -> str: return prompt +B_INST, E_INST = "[INST]", "[/INST]" +B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n" +DEFAULT_SYSTEM_PROMPT = """You are a helpful and honest assistant.""" + + +def make_prompt_llama2(llm, messages: List[Message]) -> List[int]: + if messages[0]["role"] != "system": + messages = [ + cast( + Message, + { + "role": "system", + "content": DEFAULT_SYSTEM_PROMPT, + }, + ) + ] + messages + messages = [ + cast( + Message, + { + "role": messages[1]["role"], + "content": B_SYS + + messages[0]["content"] + + E_SYS + + messages[1]["content"], + }, + ) + ] + messages[2:] + assert all([msg["role"] == "user" for msg in messages[::2]]) and all( + [msg["role"] == "assistant" for msg in messages[1::2]] + ), ( + "model only supports 'system', 'user' and 'assistant' roles, " + "starting with 'system', then 'user' and alternating (u/a/u/a/u...)" + ) + + dialog_tokens = sum( + [ + llm.tokenize( + bytes( + f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ", + "utf-8", + ), + add_bos=True, + ) + + [llm.token_eos()] + for prompt, answer in zip( + messages[::2], + messages[1::2], + ) + ], + [], + ) + assert ( + messages[-1]["role"] == "user" + ), f"Last message must be from user, got {messages[-1]['role']}" + + dialog_tokens += llm.tokenize( + bytes(f"{B_INST} {(messages[-1]['content']).strip()} {E_INST}", "utf-8"), + add_bos=True, + ) + + return dialog_tokens + + +llms: dict[str, Any] = {} + + class LLaMACompletionProvider(CompletionProvider): def complete( self, messages: List[Message], args: dict, stream: bool = False @@ -68,15 +147,39 @@ def complete( model_config = LLAMA_MODELS[args["model"]] - with suppress_stderr(): - llm = Llama( - model_path=model_config["path"], - n_ctx=2048, - verbose=False, - use_mlock=True, - ) - prompt = make_prompt(messages, model_config) - print(prompt) + if model_config.get("llama2", False): + return 
self._complete_llama2(model_config, messages, args, stream) + else: + return self._complete_llama1(model_config, messages, args, stream) + + def _create_model(self, model_config: LLaMAModelConfig): + path = model_config["path"] + if path not in llms: + with suppress_stderr(): + llms[path] = Llama( + model_path=path, + n_ctx=4096 if model_config.get("llama2", False) else 2048, + verbose=False, + use_mlock=True, + n_gpu_layers=model_config.get("n_gpu_layers", 0), + seed=random.randint(0, 2**32 - 1), + ) + return llms[path] + + def _complete_llama1( + self, + model_config: LLaMAModelConfig, + messages: List[Message], + args: dict, + stream: bool = False, + ) -> Iterator[str]: + assert ( + "human_prompt" in model_config and "assistant_prompt" in model_config + ), "either `llama2: True` or human_prompt and assistant_prompt must be set in the model config" + + llm = self._create_model(model_config) + + prompt = make_prompt_llama1(messages, model_config) extra_args = {} if "temperature" in args: @@ -98,6 +201,42 @@ def complete( else: yield cast(Completion, gen)["choices"][0]["text"] + def _complete_llama2( + self, + model_config: LLaMAModelConfig, + messages: List[Message], + args: dict, + stream: bool = False, + ) -> Iterator[str]: + llm = self._create_model(model_config) + + prompt = make_prompt_llama2(llm, messages) + + extra_args = {} + if "temperature" in args: + extra_args["temp"] = args["temperature"] + if "top_p" in args: + extra_args["top_p"] = args["top_p"] + + gen = llm.generate( + prompt, + top_k=65536, + **extra_args, + ) + + result = "" + for token in gen: + if token == llm.token_eos(): + break + + text = llm.detokenize([token]).decode("utf-8") + result += text + if stream: + yield text + + if not stream: + yield result + # https://stackoverflow.com/a/50438156 class suppress_stderr(object): diff --git a/pyproject.toml b/pyproject.toml index d391ca8..3de28bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ dependencies = [ 
[project.optional-dependencies] llama = [ - "llama-cpp-python==0.1.57", + "llama-cpp-python==0.1.73", ] [project.urls]