Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ Simply stage your files and run `commitai`. It analyzes the diff, optionally tak
* 📝 **Optional Explanations**: Provide a high-level description of your changes as input to guide the AI, or let it infer the context solely from the code diff.
* ✅ **Pre-commit Hook Integration**: Automatically runs your existing native Git pre-commit hook (`.git/hooks/pre-commit`) before generating the message, ensuring code quality and style checks pass.
* 🔧 **Customizable Prompts via Templates**: Add custom instructions or context to the AI prompt using global environment variables or repository-specific template files.
* 🤖 **Multiple AI Provider Support**: Choose your preferred AI model from OpenAI, Anthropic, or Google.
* 🤖 **Multiple AI Provider Support**: Choose your preferred AI model from OpenAI, Anthropic, or Google, or run local models with Ollama.
* ⚙️ **Flexible Workflow**:
* Stages all changes automatically (`-a` flag).
* Reviews message in your default Git editor (default behavior).
Expand Down Expand Up @@ -98,6 +98,14 @@ CommitAi requires API keys for the AI provider you intend to use. Set these as e

You only need to set the key for the provider corresponding to the model you select (or the default, Gemini).

### Ollama

CommitAi can also work with locally hosted models served by Ollama:
```bash
export OLLAMA_HOST="your_ollama_base_url"
```


### Commit Templates (Optional)

You can add custom instructions to the default system prompt used by the AI. This is useful for enforcing project-specific guidelines (e.g., mentioning ticket numbers).
Expand Down Expand Up @@ -234,7 +242,7 @@ Contributions are highly welcome! Please follow these steps:
9. Run checks locally before committing:
* Format code: `ruff format .`
* Lint code: `ruff check .`
* Run type checks: `mypy commitai commitai/tests`
* Run type checks: `mypy commitai tests`
* Run tests: `pytest`
10. Commit your changes (you can use `commitai`!).
11. Push your branch to your fork: `git push origin my-feature-branch`
Expand Down
36 changes: 8 additions & 28 deletions commitai/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
# -*- coding: utf-8 -*-

import os
from typing import Optional, Tuple
from typing import Optional, Tuple, cast

import click
from langchain_anthropic import ChatAnthropic
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI

# Keep SecretStr import in case it's needed elsewhere or for future refinement
Expand Down Expand Up @@ -34,7 +35,6 @@
)


# Helper function to get API key with priority
def _get_google_api_key() -> Optional[str]:
"""Gets the Google API key from environment variables in priority order."""
return (
Expand All @@ -44,7 +44,6 @@ def _get_google_api_key() -> Optional[str]:
)


# Helper function to initialize the LLM
def _initialize_llm(model: str) -> BaseChatModel:
"""Initializes and returns the LangChain chat model based on the model name."""
google_api_key_str = _get_google_api_key()
Expand All @@ -56,17 +55,16 @@ def _initialize_llm(model: str) -> BaseChatModel:
raise click.ClickException(
"Error: OPENAI_API_KEY environment variable not set."
)
# Pass raw string and ignore Mypy SecretStr complaint
return ChatOpenAI(model=model, api_key=api_key, temperature=0.7)

elif model.startswith("claude-"):
api_key = os.getenv("ANTHROPIC_API_KEY")
if not api_key:
raise click.ClickException(
"Error: ANTHROPIC_API_KEY environment variable not set."
)
# Pass raw string and ignore Mypy SecretStr complaint
# Also ignore missing timeout argument if it's optional
return ChatAnthropic(model_name=model, api_key=api_key, temperature=0.7)

elif model.startswith("gemini-"):
if ChatGoogleGenerativeAI is None:
raise click.ClickException(
Expand All @@ -79,30 +77,23 @@ def _initialize_llm(model: str) -> BaseChatModel:
"Error: Google API Key not found. Set GOOGLE_API_KEY, "
"GEMINI_API_KEY, or GOOGLE_GENERATIVE_AI_API_KEY."
)
# Pass raw string and ignore Mypy SecretStr complaint
# Also ignore missing optional arguments
return ChatGoogleGenerativeAI(
model=model,
google_api_key=google_api_key_str,
temperature=0.7,
convert_system_message_to_human=True,
)
elif model.startswith("llama"):
# Ollama models (e.g., llama2, llama3)
return cast(BaseChatModel, ChatOllama(model=model, temperature=0.7))
else:
raise click.ClickException(f"🚫 Unsupported model: {model}")

except Exception as e:
raise click.ClickException(f"Error initializing AI model: {e}") from e


# Helper function to prepare context (diff, repo, branch)
def _prepare_context() -> str:
"""
Gets the repository context (name, branch, diff).

Returns:
str: The formatted diff string.
Raises:
click.ClickException: If no staged changes are found.
"""
diff = get_staged_changes_diff()
if not diff:
raise click.ClickException("⚠️ Warning: No staged changes found. Exiting.")
Expand All @@ -112,11 +103,9 @@ def _prepare_context() -> str:
return f"{repo_name}/{branch_name}\n\n{diff}"


# Helper function to build the final prompt
def _build_prompt(
explanation: str, formatted_diff: str, template: Optional[str]
) -> str:
"""Builds the complete prompt for the AI model."""
system_message = default_system_message
if template:
system_message += adding_template
Expand All @@ -130,14 +119,7 @@ def _build_prompt(
return f"{system_message}\n\n{diff_message}"


# Helper function to handle commit message editing and creation
def _handle_commit(commit_message: str, commit_flag: bool) -> None:
"""
Writes message, optionally opens editor, and creates the commit.

Raises:
click.ClickException: On file I/O errors or if the commit is aborted.
"""
repo_path = get_repository_name()
git_dir = os.path.join(repo_path, ".git")
try:
Expand Down Expand Up @@ -180,7 +162,6 @@ def _handle_commit(commit_message: str, commit_flag: bool) -> None:

@click.group(context_settings={"help_option_names": ["-h", "--help"]})
def cli() -> None:
    """Top-level Click command group for CommitAi; accepts -h as well as --help."""
    pass


Expand Down Expand Up @@ -224,7 +205,6 @@ def generate_message(
add: bool,
model: str,
) -> None:
"""Generates a commit message based on staged changes and description."""
explanation = " ".join(description)

llm = _initialize_llm(model)
Expand Down
19 changes: 11 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ build-backend = "hatchling.build"
[project]
name = "commitai"
# Make sure to update version in commitai/__init__.py as well
version = "1.0.4" # Assuming version was bumped, adjust if needed

version = "1.0.5"

description = "Commitai helps you generate git commit messages using AI"
readme = "README.md"
requires-python = ">=3.9"
Expand All @@ -29,12 +31,13 @@ classifiers = [
]
dependencies = [
"click>=8.0,<9.0",
"langchain>=0.1.0,<0.3.0",
"langchain-core>=0.1.0,<0.3.0",
"langchain-community>=0.0.20,<0.2.0",
"langchain-anthropic>=0.1.0,<0.3.0",
"langchain-openai>=0.1.0,<0.3.0",
"langchain-google-genai~=0.0.9",
"langchain>=0.1.0,<=0.3.25",
"langchain-core>=0.1.0,<=0.3.58",
"langchain-community>=0.0.20,<=0.3.23",
"langchain-anthropic>=0.1.0,<=0.3.12",
"langchain-openai>=0.1.0,<=0.3.16",
"langchain-google-genai~=2.1.4",
"langchain-ollama~=0.3.2",
"pydantic>=2.0,<3.0",
]

Expand All @@ -55,7 +58,7 @@ test = [
"types-setuptools",
# Pin ruff version to match pre-commit hook
"ruff==0.4.4",
"langchain-google-genai~=0.0.9", # Keep google genai here for mypy in pre-commit
"langchain-google-genai~=2.1.4", # Keep google genai here for mypy in pre-commit
]

[tool.hatch.version]
Expand Down
37 changes: 25 additions & 12 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from langchain_google_genai import (
ChatGoogleGenerativeAI as ActualChatGoogleGenerativeAI,
)
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI

from commitai.cli import cli
Expand All @@ -35,6 +36,7 @@ def mock_generate_deps(tmp_path):
) as mock_google_class_in_cli,
patch("commitai.cli.ChatOpenAI", spec=ChatOpenAI) as mock_openai_class,
patch("commitai.cli.ChatAnthropic", spec=ChatAnthropic) as mock_anthropic_class,
patch("commitai.cli.ChatOllama", spec=ChatOllama) as mock_ollama_class,
patch("commitai.cli.stage_all_changes") as mock_stage,
patch("commitai.cli.run_pre_commit_hook", return_value=True) as mock_hook,
patch(
Expand Down Expand Up @@ -62,17 +64,20 @@ def mock_generate_deps(tmp_path):
mock_openai_instance = mock_openai_class.return_value
mock_anthropic_instance = mock_anthropic_class.return_value
mock_google_instance = mock_google_class_in_cli.return_value
mock_ollama_instance = mock_ollama_class.return_value

mock_openai_instance.spec = ChatOpenAI
mock_anthropic_instance.spec = ChatAnthropic
if mock_google_class_in_cli is not None:
mock_google_instance.spec = ActualChatGoogleGenerativeAI
mock_ollama_instance.spec = ChatOllama

content_mock = MagicMock()
content_mock.content = "Generated commit message"
mock_openai_instance.invoke.return_value = content_mock
mock_anthropic_instance.invoke.return_value = content_mock
mock_google_instance.invoke.return_value = content_mock
mock_ollama_instance.invoke.return_value = content_mock

def getenv_side_effect(key, default=None):
if key == "OPENAI_API_KEY":
Expand All @@ -81,6 +86,8 @@ def getenv_side_effect(key, default=None):
return "fake_anthropic_key"
if key == "TEMPLATE_COMMIT":
return None
if key == "OLLAMA_HOST":
return "fake_ollama_host"
return os.environ.get(key, default)

mock_getenv.side_effect = getenv_side_effect
Expand All @@ -89,9 +96,11 @@ def getenv_side_effect(key, default=None):
"openai_class": mock_openai_class,
"anthropic_class": mock_anthropic_class,
"google_class": mock_google_class_in_cli,
"ollama_class": mock_ollama_class,
"openai_instance": mock_openai_instance,
"anthropic_instance": mock_anthropic_instance,
"google_instance": mock_google_instance,
"ollama_instance": mock_ollama_instance,
"stage": mock_stage,
"hook": mock_hook,
"diff": mock_diff,
Expand Down Expand Up @@ -186,6 +195,22 @@ def test_generate_select_claude(mock_generate_deps):
mock_generate_deps["commit"].assert_called_once()


def test_generate_select_ollama(mock_generate_deps):
    """Verify that `generate -m llama3` routes message creation through ChatOllama."""
    deps = mock_generate_deps
    # The edited commit-message file is read back after the editor step.
    deps["file_open"].return_value.read.return_value = "Generated commit message"

    outcome = CliRunner().invoke(
        cli, ["generate", "-m", "llama3", "Test explanation"]
    )

    assert outcome.exit_code == 0, outcome.output
    # The llama prefix must construct ChatOllama with the expected arguments,
    # invoke it exactly once, and finish by creating the commit.
    deps["ollama_class"].assert_called_once_with(model="llama3", temperature=0.7)
    deps["ollama_instance"].invoke.assert_called_once()
    deps["commit"].assert_called_once()


def test_generate_with_add_flag(mock_generate_deps):
"""Test the -a flag with generate command."""
runner = CliRunner()
Expand Down Expand Up @@ -314,18 +339,6 @@ def test_generate_google_key_priority(mock_generate_deps):
)


def test_generate_unsupported_model(mock_generate_deps):
    """Test generate command with an unsupported model."""
    runner = CliRunner()
    # A model name matching none of the supported prefixes (claude-, gemini-,
    # llama, ...) should make model initialization raise a ClickException.
    result = runner.invoke(
        cli, ["generate", "-m", "unsupported-model", "Test explanation"]
    )

    # ClickException exits with status 1 and prints the error message.
    assert result.exit_code == 1, result.output
    assert "Unsupported model: unsupported-model" in result.output
    # No commit may be attempted when the model cannot be initialized.
    mock_generate_deps["commit"].assert_not_called()


def test_generate_empty_commit_message_aborts(mock_generate_deps):
"""Test generate command aborts with empty commit message after edit."""
runner = CliRunner()
Expand Down
Loading