Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gptcli/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def get_completion_provider(
return LLaMACompletionProvider()
elif model.startswith("command") or model.startswith("c4ai"):
return CohereCompletionProvider()
elif model.startswith("gemini"):
elif model.startswith("gemini") or model.startswith("gemma"):
return GoogleCompletionProvider()
else:
raise ValueError(f"Unknown model: {model}")
Expand Down
13 changes: 9 additions & 4 deletions gptcli/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
import os
from typing import cast
import openai
import google.generativeai as genai
import argparse
import sys
import logging
import datetime
import gptcli.providers.anthropic
import gptcli.providers.cohere
import gptcli.providers.google as google
from gptcli.assistant import (
Assistant,
DEFAULT_ASSISTANTS,
Expand Down Expand Up @@ -201,7 +201,7 @@ def main():
gptcli.providers.cohere.api_key = config.cohere_api_key

if config.google_api_key:
genai.configure(api_key=config.google_api_key)
google.api_key = config.google_api_key

if config.llama_models is not None:
init_llama_models(config.llama_models)
Expand Down Expand Up @@ -240,7 +240,9 @@ def run_non_interactive(args, assistant):


class CLIChatSession(ChatSession):
def __init__(self, assistant: Assistant, markdown: bool, show_price: bool, stream: bool):
def __init__(
self, assistant: Assistant, markdown: bool, show_price: bool, stream: bool
):
listeners = [
CLIChatListener(markdown),
LoggingChatListener(),
Expand All @@ -256,7 +258,10 @@ def __init__(self, assistant: Assistant, markdown: bool, show_price: bool, strea
def run_interactive(args, assistant):
logger.info("Starting a new chat session. Assistant config: %s", assistant.config)
session = CLIChatSession(
assistant=assistant, markdown=args.markdown, show_price=args.show_price, stream=not args.no_stream
assistant=assistant,
markdown=args.markdown,
show_price=args.show_price,
stream=not args.no_stream,
)
history_filename = os.path.expanduser("~/.config/gpt-cli/history")
os.makedirs(os.path.dirname(history_filename), exist_ok=True)
Expand Down
126 changes: 76 additions & 50 deletions gptcli/providers/google.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import google.generativeai as genai
from google.generativeai.types.content_types import ContentDict
from google.generativeai.types.generation_types import GenerationConfig
from google.generativeai.types.safety_types import (
HarmBlockThreshold,
HarmCategory,
)
import os
from google import genai
from google.genai import types

from typing import Iterator, List, Optional

from gptcli.completion import (
Expand All @@ -22,65 +19,74 @@
}


def map_message(message: Message) -> ContentDict:
return {"role": ROLE_MAP[message["role"]], "parts": [message["content"]]}


SAFETY_SETTINGS = [
{"category": category, "threshold": HarmBlockThreshold.BLOCK_NONE}
for category in [
HarmCategory.HARM_CATEGORY_HARASSMENT,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
HarmCategory.HARM_CATEGORY_HATE_SPEECH,
]
]
api_key = os.environ.get("GEMINI_API_KEY")


class GoogleCompletionProvider(CompletionProvider):
def complete(
self, messages: List[Message], args: dict, stream: bool = False
) -> Iterator[CompletionEvent]:
generation_config = GenerationConfig(
temperature=args.get("temperature"),
top_p=args.get("top_p"),
)

model_name = args["model"]

client = genai.Client(api_key=api_key)
model = args["model"]
system_instruction = None
if messages[0]["role"] == "system":
system_instruction = messages[0]["content"]
messages = messages[1:]
else:
system_instruction = None

chat_history = [map_message(m) for m in messages]
contents = [
types.Content(
role=ROLE_MAP[m["role"]],
parts=[types.Part.from_text(text=m["content"])],
)
for m in messages
]

model = genai.GenerativeModel(model_name, system_instruction=system_instruction)
generate_content_config = types.GenerateContentConfig(
system_instruction=system_instruction,
temperature=args.get("temperature"),
top_p=args.get("top_p"),
thinking_config=(
types.ThinkingConfig(
include_thoughts=True,
thinking_budget=args.get("thinking_budget"),
)
if args.get("thinking_budget")
else None
),
response_mime_type="text/plain",
)

if stream:
response = model.generate_content(
chat_history,
generation_config=generation_config,
safety_settings=SAFETY_SETTINGS,
stream=True,
response = client.models.generate_content_stream(
model=model,
contents=list(contents),
config=generate_content_config,
)

for chunk in response:
yield MessageDeltaEvent(chunk.text)
if chunk.usage_metadata:
prompt_tokens = chunk.usage_metadata.prompt_token_count or 0
completion_tokens = chunk.usage_metadata.candidates_token_count or 0
total_tokens = prompt_tokens + completion_tokens
yield MessageDeltaEvent(chunk.text or "")

else:
response = model.generate_content(
chat_history,
generation_config=generation_config,
safety_settings=SAFETY_SETTINGS,
response = client.models.generate_content(
model=model,
contents=list(contents),
config=generate_content_config,
)
yield MessageDeltaEvent(response.text)
yield MessageDeltaEvent(response.text or "")

prompt_tokens = 0
completion_tokens = 0
total_tokens = 0
if response.usage_metadata:
prompt_tokens = response.usage_metadata.prompt_token_count or 0
completion_tokens = response.usage_metadata.candidates_token_count or 0
total_tokens = prompt_tokens + completion_tokens

prompt_tokens = response.usage_metadata.prompt_token_count
completion_tokens = response.usage_metadata.candidates_token_count
total_tokens = prompt_tokens + completion_tokens
pricing = get_gemini_pricing(model_name, prompt_tokens)
pricing = get_gemini_pricing(model, prompt_tokens)
if pricing:
yield UsageEvent.with_pricing(
prompt_tokens=prompt_tokens,
Expand All @@ -91,15 +97,35 @@ def complete(


def get_gemini_pricing(model: str, prompt_tokens: int) -> Optional[Pricing]:
if model.startswith("gemini-1.5-flash-8b"):
return {
"prompt": (0.0375 if prompt_tokens < 128000 else 0.075) / 1_000_000,
"response": (0.15 if prompt_tokens < 128000 else 0.30) / 1_000_000,
}
if model.startswith("gemini-1.5-flash"):
return {
"prompt": (0.35 if prompt_tokens < 128000 else 0.7) / 1_000_000,
"response": (1.05 if prompt_tokens < 128000 else 2.10) / 1_000_000,
"prompt": (0.075 if prompt_tokens < 128000 else 0.15) / 1_000_000,
"response": (0.30 if prompt_tokens < 128000 else 0.60) / 1_000_000,
}
elif model.startswith("gemini-1.5-pro"):
return {
"prompt": (3.50 if prompt_tokens < 128000 else 7.00) / 1_000_000,
"response": (10.5 if prompt_tokens < 128000 else 21.0) / 1_000_000,
"prompt": (1.25 if prompt_tokens < 128000 else 2.50) / 1_000_000,
"response": (5.0 if prompt_tokens < 128000 else 10.0) / 1_000_000,
}
elif model.startswith("gemini-2.0-flash-lite"):
return {
"prompt": 0.075 / 1_000_000,
"response": 0.30 / 1_000_000,
}
elif model.startswith("gemini-2.0-flash"):
return {
"prompt": 0.10 / 1_000_000,
"response": 0.40 / 1_000_000,
}
elif model.startswith("gemini-2.5-pro"):
return {
"prompt": (1.25 if prompt_tokens < 200000 else 2.50) / 1_000_000,
"response": (10.0 if prompt_tokens < 200000 else 15.0) / 1_000_000,
}
elif model.startswith("gemini-pro"):
return {
Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,13 @@ dependencies = [
"attrs~=25.1.0",
"black~=25.1.0",
"cohere~=5.13.12",
"google-generativeai~=0.8.4",
"google-genai~=1.10.0",
"openai~=1.64.0",
"prompt-toolkit~=3.0.50",
"pytest~=8.3.4",
"PyYAML~=6.0.2",
"rich~=13.9.4",
"typing_extensions~=4.12.2",
"pydantic<2",
]

[project.optional-dependencies]
Expand Down
Loading