From ee9c9020b7878e3f006124175ca4e54d2cdc8b3e Mon Sep 17 00:00:00 2001 From: Jack Amadeo Date: Wed, 2 Apr 2025 14:40:51 -0400 Subject: [PATCH 1/3] Add context to tool calls --- .github/workflows/publish.yaml | 6 +++--- pyproject.toml | 2 +- src/ai_migrate/context.py | 11 +++++++++++ src/ai_migrate/migrate.py | 18 ++++++++---------- uv.lock | 2 +- 5 files changed, 24 insertions(+), 15 deletions(-) create mode 100644 src/ai_migrate/context.py diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index 5ae9a98..fb81527 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -4,9 +4,9 @@ on: push: tags: - 'v[0-9]+.[0-9]+.[0-9]+' - - 'v[0-9]+.[0-9]+.[0-9]+a[0-9]*' - - 'v[0-9]+.[0-9]+.[0-9]+b[0-9]*' - - 'v[0-9]+.[0-9]+.[0-9]+rc[0-9]*' + - 'v[0-9]+.[0-9]+.[0-9]+a[0-9]+' + - 'v[0-9]+.[0-9]+.[0-9]+b[0-9]+' + - 'v[0-9]+.[0-9]+.[0-9]+rc[0-9]+' permissions: id-token: write diff --git a/pyproject.toml b/pyproject.toml index aeb8ff5..e4f862e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ai-migrate-tools" -version = "0.1.4a1" +version = "0.1.4a2" description = "LLM-powered code migrations at scale" license = "Apache-2.0" readme = "README.md" diff --git a/src/ai_migrate/context.py b/src/ai_migrate/context.py new file mode 100644 index 0000000..ccfb7b2 --- /dev/null +++ b/src/ai_migrate/context.py @@ -0,0 +1,11 @@ +from dataclasses import dataclass + +from pydantic_ai import RunContext + + +@dataclass +class MigrationContext: + target_files: list[str] + + +ToolCallContext = RunContext[MigrationContext] diff --git a/src/ai_migrate/migrate.py b/src/ai_migrate/migrate.py index 5b3c8d7..a795ce6 100644 --- a/src/ai_migrate/migrate.py +++ b/src/ai_migrate/migrate.py @@ -12,9 +12,9 @@ from pydantic_ai.messages import ToolCallPart from pydantic_ai.tools import Tool -from pydantic_ai import RunContext from ai_migrate.llm_providers import DefaultClient +from .context import MigrationContext, ToolCallContext from .fake_llm_client import FakeLLMClient from .git_identity import environment_variables from .manifest import FileGroup, FileEntry, Manifest @@ -31,9 +31,6 @@ def log(*args, **kwargs): print(*args, **kwargs, file=LOG_STREAM.get(), flush=True) -NoneContext = RunContext[None] - - @dataclass class FileContent: name: str @@ -201,7 +198,7 @@ def migrate_prompt(example: MigrationExample) -> list[dict]: async def handle_tool_calls( - tools: list[Tool], tool_calls: list[dict[str, Any]] + tools: list[Tool], tool_calls: list[dict[str, Any]], context: MigrationContext ) -> list[dict[str, str]]: tool_results = [] tools_by_name = {tool.name: tool for tool in tools} @@ -215,8 +212,8 @@ async def handle_tool_calls( log("[agent] Running tool", tool.name, f"{tool_call_message=}") result = await tool._run( tool_call_message, - NoneContext( - deps=None, + ToolCallContext( + deps=context, model=None, usage=None, prompt="", @@ -265,7 +262,7 @@ def combine_examples_into_conversation( async def call_llm( - client: DefaultClient, messages: list, tools: list[Tool], temperature=0.1 + client: DefaultClient, messages: list, tools: list[Tool], *, context: MigrationContext, temperature=0.1 ) -> tuple[dict, list[dict]]: """Call LLM for completions @@ -283,7 +280,7 @@ async def call_llm( if not (tool_calls := assistant_message.get("tool_calls")): return response, messages - tool_results = await handle_tool_calls(tools, tool_calls) + tool_results = await handle_tool_calls(tools, tool_calls, context) messages.append(assistant_message) for result in tool_results: @@ -606,7 +603,8 @@ async def _run( iteration_messages = iteration_messages[1:] messages = build_messages(messages, iteration_messages) - response, messages = await call_llm(client, messages, tools or []) + context = MigrationContext(target_files=target_files) + response, messages = await call_llm(client, messages, tools or [], context=context) response_text = response["choices"][0]["message"]["content"] parsed_result = extract_code_blocks(response_text) diff --git a/uv.lock b/uv.lock index 923283a..3622a63 100644 --- a/uv.lock +++ b/uv.lock @@ -4,7 +4,7 @@ requires-python = ">=3.13" [[package]] name = "ai-migrate-tools" -version = "0.1.3" +version = "0.1.4a2" source = { editable = "." } dependencies = [ { name = "click" }, From 24141c665e7b1caffe08821b632dfbb8a2cdbdaa Mon Sep 17 00:00:00 2001 From: Jack Amadeo Date: Wed, 2 Apr 2025 14:44:00 -0400 Subject: [PATCH 2/3] Add target_dir --- src/ai_migrate/context.py | 2 ++ src/ai_migrate/migrate.py | 16 +++++++++++++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/ai_migrate/context.py b/src/ai_migrate/context.py index ccfb7b2..d904a88 100644 --- a/src/ai_migrate/context.py +++ b/src/ai_migrate/context.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from pathlib import Path from pydantic_ai import RunContext @@ -6,6 +7,7 @@ @dataclass class MigrationContext: target_files: list[str] + target_dir: Path | None ToolCallContext = RunContext[MigrationContext] diff --git a/src/ai_migrate/migrate.py b/src/ai_migrate/migrate.py index a795ce6..ad4388b 100644 --- a/src/ai_migrate/migrate.py +++ b/src/ai_migrate/migrate.py @@ -262,7 +262,12 @@ def combine_examples_into_conversation( async def call_llm( - client: DefaultClient, messages: list, tools: list[Tool], *, context: MigrationContext, temperature=0.1 + client: DefaultClient, + messages: list, + tools: list[Tool], + *, + context: MigrationContext, + temperature=0.1, ) -> tuple[dict, list[dict]]: """Call LLM for completions @@ -603,8 +608,13 @@ async def _run( iteration_messages = iteration_messages[1:] messages = build_messages(messages, iteration_messages) - context = MigrationContext(target_files=target_files) - response, messages = await call_llm(client, messages, tools or [], context=context) + context = MigrationContext( + target_files=target_files, + target_dir=Path(target_dir) if target_dir else None, + ) + response, messages = await call_llm( + client, messages, tools or [], context=context + ) response_text = response["choices"][0]["message"]["content"] parsed_result = extract_code_blocks(response_text) From 18ea003e8138247ef490e68e27e5bde6bf0b59ca Mon Sep 17 00:00:00 2001 From: Jack Amadeo Date: Wed, 2 Apr 2025 14:48:58 -0400 Subject: [PATCH 3/3] Tests --- .github/workflows/test.yaml | 5 +++-- projects/mini/tools.py | 4 +++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 7481b61..bc1ada6 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -5,13 +5,14 @@ on: jobs: check: - name: Ruff check + name: Ruff & uv check runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - uses: cashapp/activate-hermit@v1 - run: | + uv lock --check uv run ruff check . uv run ruff format --check @@ -34,4 +35,4 @@ jobs: - uses: actions/checkout@v4 - uses: cashapp/activate-hermit@v1 - run: | - uv run pytest \ No newline at end of file + uv run pytest diff --git a/projects/mini/tools.py b/projects/mini/tools.py index 1667c92..3d6323d 100644 --- a/projects/mini/tools.py +++ b/projects/mini/tools.py @@ -1,7 +1,9 @@ from pydantic_ai import Tool +from ai_migrate.context import ToolCallContext -def add(x: int, y: int) -> int: + +def add(ctx: ToolCallContext, x: int, y: int) -> int: """Add two numbers""" return x + y