Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/publish.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ on:
push:
tags:
- 'v[0-9]+.[0-9]+.[0-9]+'
- 'v[0-9]+.[0-9]+.[0-9]+a[0-9]*'
- 'v[0-9]+.[0-9]+.[0-9]+b[0-9]*'
- 'v[0-9]+.[0-9]+.[0-9]+rc[0-9]*'
- 'v[0-9]+.[0-9]+.[0-9]+a[0-9]+'
- 'v[0-9]+.[0-9]+.[0-9]+b[0-9]+'
- 'v[0-9]+.[0-9]+.[0-9]+rc[0-9]+'

permissions:
id-token: write
Expand Down
5 changes: 3 additions & 2 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@ on:

jobs:
check:
name: Ruff check
name: Ruff & uv check
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4
- uses: cashapp/activate-hermit@v1
- run: |
uv lock --check
uv run ruff check .
uv run ruff format --check

Expand All @@ -34,4 +35,4 @@ jobs:
- uses: actions/checkout@v4
- uses: cashapp/activate-hermit@v1
- run: |
uv run pytest
uv run pytest
4 changes: 3 additions & 1 deletion projects/mini/tools.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from pydantic_ai import Tool

from ai_migrate.context import ToolCallContext

def add(x: int, y: int) -> int:

def add(ctx: ToolCallContext, x: int, y: int) -> int:
"""Add two numbers"""
return x + y

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "ai-migrate-tools"
version = "0.1.4a1"
version = "0.1.4a2"
description = "LLM-powered code migrations at scale"
license = "Apache-2.0"
readme = "README.md"
Expand Down
13 changes: 13 additions & 0 deletions src/ai_migrate/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from dataclasses import dataclass
from pathlib import Path

from pydantic_ai import RunContext


@dataclass
class MigrationContext:
target_files: list[str]
target_dir: Path | None


ToolCallContext = RunContext[MigrationContext]
28 changes: 18 additions & 10 deletions src/ai_migrate/migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@

from pydantic_ai.messages import ToolCallPart
from pydantic_ai.tools import Tool
from pydantic_ai import RunContext

from ai_migrate.llm_providers import DefaultClient
from .context import MigrationContext, ToolCallContext
from .fake_llm_client import FakeLLMClient
from .git_identity import environment_variables
from .manifest import FileGroup, FileEntry, Manifest
Expand All @@ -31,9 +31,6 @@ def log(*args, **kwargs):
print(*args, **kwargs, file=LOG_STREAM.get(), flush=True)


NoneContext = RunContext[None]


@dataclass
class FileContent:
name: str
Expand Down Expand Up @@ -201,7 +198,7 @@ def migrate_prompt(example: MigrationExample) -> list[dict]:


async def handle_tool_calls(
tools: list[Tool], tool_calls: list[dict[str, Any]]
tools: list[Tool], tool_calls: list[dict[str, Any]], context: MigrationContext
) -> list[dict[str, str]]:
tool_results = []
tools_by_name = {tool.name: tool for tool in tools}
Expand All @@ -215,8 +212,8 @@ async def handle_tool_calls(
log("[agent] Running tool", tool.name, f"{tool_call_message=}")
result = await tool._run(
tool_call_message,
NoneContext(
deps=None,
ToolCallContext(
deps=context,
model=None,
usage=None,
prompt="",
Expand Down Expand Up @@ -265,7 +262,12 @@ def combine_examples_into_conversation(


async def call_llm(
client: DefaultClient, messages: list, tools: list[Tool], temperature=0.1
client: DefaultClient,
messages: list,
tools: list[Tool],
*,
context: MigrationContext,
temperature=0.1,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just curious why temperature is now a keyword arg?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There isn't really a hard rule here, but once you start having several args I think it's good practice to make most of them keyword-only. Not to mention if you call this particular one without using the keyword, the value is pretty confusing: call_llm(messages, tools, 0.2) for example.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fair

) -> tuple[dict, list[dict]]:
"""Call LLM for completions

Expand All @@ -283,7 +285,7 @@ async def call_llm(
if not (tool_calls := assistant_message.get("tool_calls")):
return response, messages

tool_results = await handle_tool_calls(tools, tool_calls)
tool_results = await handle_tool_calls(tools, tool_calls, context)

messages.append(assistant_message)
for result in tool_results:
Expand Down Expand Up @@ -606,7 +608,13 @@ async def _run(
iteration_messages = iteration_messages[1:]
messages = build_messages(messages, iteration_messages)

response, messages = await call_llm(client, messages, tools or [])
context = MigrationContext(
target_files=target_files,
target_dir=Path(target_dir) if target_dir else None,
)
response, messages = await call_llm(
client, messages, tools or [], context=context
)

response_text = response["choices"][0]["message"]["content"]
parsed_result = extract_code_blocks(response_text)
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.