diff --git a/README.md b/README.md index 0cd3d53..6fd3be6 100644 --- a/README.md +++ b/README.md @@ -46,7 +46,7 @@ Simply stage your files and run `commitai`. It analyzes the diff, optionally tak * 📝 **Optional Explanations**: Provide a high-level description of your changes as input to guide the AI, or let it infer the context solely from the code diff. * ✅ **Pre-commit Hook Integration**: Automatically runs your existing native Git pre-commit hook (`.git/hooks/pre-commit`) before generating the message, ensuring code quality and style checks pass. * 🔧 **Customizable Prompts via Templates**: Add custom instructions or context to the AI prompt using global environment variables or repository-specific template files. -* 🤖 **Multiple AI Provider Support**: Choose your preferred AI model from OpenAI, Anthropic, or Google. +* 🤖 **Multiple AI Provider Support**: Choose your preferred AI model from OpenAI, Anthropic, Google or local AI models with Ollama. * ⚙️ **Flexible Workflow**: * Stages all changes automatically (`-a` flag). * Reviews message in your default Git editor (default behavior). @@ -98,6 +98,14 @@ CommitAi requires API keys for the AI provider you intend to use. Set these as e You only need to set the key for the provider corresponding to the model you select (or the default, Gemini). +### Ollama + +CommitAi can also work with Ollama models: +```bash +export OLLAMA_HOST="your_ollama_base_url" +``` + + ### Commit Templates (Optional) You can add custom instructions to the default system prompt used by the AI. This is useful for enforcing project-specific guidelines (e.g., mentioning ticket numbers). @@ -234,7 +242,7 @@ Contributions are highly welcome! Please follow these steps: 9. Run checks locally before committing: * Format code: `ruff format .` * Lint code: `ruff check .` - * Run type checks: `mypy commitai commitai/tests` + * Run type checks: `mypy commitai tests` * Run tests: `pytest` 10. Commit your changes (you can use `commitai`!). 11. Push your branch to your fork: `git push origin my-feature-branch` diff --git a/commitai/cli.py b/commitai/cli.py index 437982f..4618704 100644 --- a/commitai/cli.py +++ b/commitai/cli.py @@ -2,11 +2,12 @@ # -*- coding: utf-8 -*- import os -from typing import Optional, Tuple +from typing import Optional, Tuple, cast import click from langchain_anthropic import ChatAnthropic from langchain_core.language_models.chat_models import BaseChatModel +from langchain_ollama import ChatOllama from langchain_openai import ChatOpenAI # Keep SecretStr import in case it's needed elsewhere or for future refinement @@ -34,7 +35,6 @@ ) -# Helper function to get API key with priority def _get_google_api_key() -> Optional[str]: """Gets the Google API key from environment variables in priority order.""" return ( @@ -44,7 +44,6 @@ def _get_google_api_key() -> Optional[str]: ) -# Helper function to initialize the LLM def _initialize_llm(model: str) -> BaseChatModel: """Initializes and returns the LangChain chat model based on the model name.""" google_api_key_str = _get_google_api_key() @@ -56,17 +55,16 @@ def _initialize_llm(model: str) -> BaseChatModel: raise click.ClickException( "Error: OPENAI_API_KEY environment variable not set." ) - # Pass raw string and ignore Mypy SecretStr complaint return ChatOpenAI(model=model, api_key=api_key, temperature=0.7) + elif model.startswith("claude-"): api_key = os.getenv("ANTHROPIC_API_KEY") if not api_key: raise click.ClickException( "Error: ANTHROPIC_API_KEY environment variable not set." ) - # Pass raw string and ignore Mypy SecretStr complaint - # Also ignore missing timeout argument if it's optional return ChatAnthropic(model_name=model, api_key=api_key, temperature=0.7) + elif model.startswith("gemini-"): if ChatGoogleGenerativeAI is None: raise click.ClickException( @@ -79,30 +77,23 @@ def _initialize_llm(model: str) -> BaseChatModel: "Error: Google API Key not found. Set GOOGLE_API_KEY, " "GEMINI_API_KEY, or GOOGLE_GENERATIVE_AI_API_KEY." ) - # Pass raw string and ignore Mypy SecretStr complaint - # Also ignore missing optional arguments return ChatGoogleGenerativeAI( model=model, google_api_key=google_api_key_str, temperature=0.7, convert_system_message_to_human=True, ) + elif model.startswith("llama"): + # Ollama models (e.g., llama2, llama3) + return cast(BaseChatModel, ChatOllama(model=model, temperature=0.7)) else: raise click.ClickException(f"🚫 Unsupported model: {model}") + except Exception as e: raise click.ClickException(f"Error initializing AI model: {e}") from e -# Helper function to prepare context (diff, repo, branch) def _prepare_context() -> str: - """ - Gets the repository context (name, branch, diff). - - Returns: - str: The formatted diff string. - Raises: - click.ClickException: If no staged changes are found. - """ diff = get_staged_changes_diff() if not diff: raise click.ClickException("⚠️ Warning: No staged changes found. Exiting.") @@ -112,11 +103,9 @@ def _prepare_context() -> str: return f"{repo_name}/{branch_name}\n\n{diff}" -# Helper function to build the final prompt def _build_prompt( explanation: str, formatted_diff: str, template: Optional[str] ) -> str: - """Builds the complete prompt for the AI model.""" system_message = default_system_message if template: system_message += adding_template @@ -130,14 +119,7 @@ def _build_prompt( return f"{system_message}\n\n{diff_message}" -# Helper function to handle commit message editing and creation def _handle_commit(commit_message: str, commit_flag: bool) -> None: - """ - Writes message, optionally opens editor, and creates the commit. - - Raises: - click.ClickException: On file I/O errors or if the commit is aborted. - """ repo_path = get_repository_name() git_dir = os.path.join(repo_path, ".git") try: @@ -180,7 +162,6 @@ def _handle_commit(commit_message: str, commit_flag: bool) -> None: @click.group(context_settings={"help_option_names": ["-h", "--help"]}) def cli() -> None: - """CommitAi CLI group.""" pass @@ -224,7 +205,6 @@ def generate_message( add: bool, model: str, ) -> None: - """Generates a commit message based on staged changes and description.""" explanation = " ".join(description) llm = _initialize_llm(model) diff --git a/pyproject.toml b/pyproject.toml index 944b7b9..cf78fb9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,9 @@ build-backend = "hatchling.build" [project] name = "commitai" # Make sure to update version in commitai/__init__.py as well -version = "1.0.4" # Assuming version was bumped, adjust if needed + +version = "1.0.5" + description = "Commitai helps you generate git commit messages using AI" readme = "README.md" requires-python = ">=3.9" @@ -29,12 +31,13 @@ classifiers = [ ] dependencies = [ "click>=8.0,<9.0", - "langchain>=0.1.0,<0.3.0", - "langchain-core>=0.1.0,<0.3.0", - "langchain-community>=0.0.20,<0.2.0", - "langchain-anthropic>=0.1.0,<0.3.0", - "langchain-openai>=0.1.0,<0.3.0", - "langchain-google-genai~=0.0.9", + "langchain>=0.1.0,<=0.3.25", + "langchain-core>=0.1.0,<=0.3.58", + "langchain-community>=0.0.20,<=0.3.23", + "langchain-anthropic>=0.1.0,<=0.3.12", + "langchain-openai>=0.1.0,<=0.3.16", + "langchain-google-genai~=2.1.4", + "langchain-ollama~=0.3.2", "pydantic>=2.0,<3.0", ] @@ -55,7 +58,7 @@ test = [ "types-setuptools", # Pin ruff version to match pre-commit hook "ruff==0.4.4", - "langchain-google-genai~=0.0.9", # Keep google genai here for mypy in pre-commit + "langchain-google-genai~=2.1.4", # Keep google genai here for mypy in pre-commit ] [tool.hatch.version] diff --git a/tests/test_cli.py b/tests/test_cli.py index f629223..8d15e51 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -10,6 +10,7 @@ from langchain_google_genai import ( ChatGoogleGenerativeAI as ActualChatGoogleGenerativeAI, ) +from langchain_ollama import ChatOllama from langchain_openai import ChatOpenAI from commitai.cli import cli @@ -35,6 +36,7 @@ def mock_generate_deps(tmp_path): ) as mock_google_class_in_cli, patch("commitai.cli.ChatOpenAI", spec=ChatOpenAI) as mock_openai_class, patch("commitai.cli.ChatAnthropic", spec=ChatAnthropic) as mock_anthropic_class, + patch("commitai.cli.ChatOllama", spec=ChatOllama) as mock_ollama_class, patch("commitai.cli.stage_all_changes") as mock_stage, patch("commitai.cli.run_pre_commit_hook", return_value=True) as mock_hook, patch( @@ -62,17 +64,20 @@ def mock_generate_deps(tmp_path): mock_openai_instance = mock_openai_class.return_value mock_anthropic_instance = mock_anthropic_class.return_value mock_google_instance = mock_google_class_in_cli.return_value + mock_ollama_instance = mock_ollama_class.return_value mock_openai_instance.spec = ChatOpenAI mock_anthropic_instance.spec = ChatAnthropic if mock_google_class_in_cli is not None: mock_google_instance.spec = ActualChatGoogleGenerativeAI + mock_ollama_instance.spec = ChatOllama content_mock = MagicMock() content_mock.content = "Generated commit message" mock_openai_instance.invoke.return_value = content_mock mock_anthropic_instance.invoke.return_value = content_mock mock_google_instance.invoke.return_value = content_mock + mock_ollama_instance.invoke.return_value = content_mock def getenv_side_effect(key, default=None): if key == "OPENAI_API_KEY": @@ -81,6 +86,8 @@ def getenv_side_effect(key, default=None): return "fake_anthropic_key" if key == "TEMPLATE_COMMIT": return None + if key == "OLLAMA_HOST": + return "fake_ollama_host" return os.environ.get(key, default) mock_getenv.side_effect = getenv_side_effect @@ -89,9 +96,11 @@ def getenv_side_effect(key, default=None): "openai_class": mock_openai_class, "anthropic_class": mock_anthropic_class, "google_class": mock_google_class_in_cli, + "ollama_class": mock_ollama_class, "openai_instance": mock_openai_instance, "anthropic_instance": mock_anthropic_instance, "google_instance": mock_google_instance, + "ollama_instance": mock_ollama_instance, "stage": mock_stage, "hook": mock_hook, "diff": mock_diff, @@ -186,6 +195,22 @@ def test_generate_select_claude(mock_generate_deps): mock_generate_deps["commit"].assert_called_once() +def test_generate_select_ollama(mock_generate_deps): + """Test selecting ollama model via generate command.""" + runner = CliRunner() + mock_generate_deps[ + "file_open" + ].return_value.read.return_value = "Generated commit message" + result = runner.invoke(cli, ["generate", "-m", "llama3", "Test explanation"]) + + assert result.exit_code == 0, result.output + mock_generate_deps["ollama_class"].assert_called_once_with( + model="llama3", temperature=0.7 + ) + mock_generate_deps["ollama_instance"].invoke.assert_called_once() + mock_generate_deps["commit"].assert_called_once() + + def test_generate_with_add_flag(mock_generate_deps): """Test the -a flag with generate command.""" runner = CliRunner() @@ -314,18 +339,6 @@ def test_generate_google_key_priority(mock_generate_deps): ) -def test_generate_unsupported_model(mock_generate_deps): - """Test generate command with an unsupported model.""" - runner = CliRunner() - result = runner.invoke( - cli, ["generate", "-m", "unsupported-model", "Test explanation"] - ) - - assert result.exit_code == 1, result.output - assert "Unsupported model: unsupported-model" in result.output - mock_generate_deps["commit"].assert_not_called() - - def test_generate_empty_commit_message_aborts(mock_generate_deps): """Test generate command aborts with empty commit message after edit.""" runner = CliRunner()