diff --git a/src/ai_migrate/manifest.py b/src/ai_migrate/manifest.py index 3a9a431..4fed3ea 100644 --- a/src/ai_migrate/manifest.py +++ b/src/ai_migrate/manifest.py @@ -5,6 +5,7 @@ from pydantic import BaseModel, Field SYSTEM_PROMPT_FILE = "system_prompt.md" +GOOSE_PROMPT_FILE = "goose_prompt.md" VERIFY_SCRIPT_FILE = "verify.py" @@ -70,6 +71,12 @@ def to_file_group(self) -> FileGroup: return FileGroup(files=files, result=self.result, base_name=base_name) +class GooseConfig(BaseModel): + user_prompt: str = f"{{project_dir}}/{GOOSE_PROMPT_FILE}" + timeout_seconds: int = 15 * 60 + max_retries: int = 3 + + class Manifest(BaseModel): eval_target_repo_ref: str = "" eval_target_repo_remote: str = "" @@ -80,3 +87,4 @@ class Manifest(BaseModel): verify_cmd: str = f"{{py}} {{project_dir}}/{VERIFY_SCRIPT_FILE}" pre_verify_cmd: str = f"{{py}} {{project_dir}}/{VERIFY_SCRIPT_FILE} --pre" time: datetime = Field(default_factory=datetime.now) + goose_config: GooseConfig = Field(default_factory=GooseConfig) diff --git a/src/ai_migrate/migrate.py b/src/ai_migrate/migrate.py index ad4388b..3cd6222 100644 --- a/src/ai_migrate/migrate.py +++ b/src/ai_migrate/migrate.py @@ -17,7 +17,7 @@ from .context import MigrationContext, ToolCallContext from .fake_llm_client import FakeLLMClient from .git_identity import environment_variables -from .manifest import FileGroup, FileEntry, Manifest +from .manifest import FileGroup, FileEntry, GooseConfig, Manifest from .eval_generator import generate_eval_from_migration @@ -349,6 +349,7 @@ async def run( target_basename: str = "", dont_create_evals: bool = False, tools: list[Tool] = None, + goose_config: Optional[GooseConfig] = None, ): """Run the migration process on the target files. 
Args: @@ -431,11 +432,13 @@ async def run( check=True, cwd=git_root, ) - await subprocess_run( - ["git", "checkout", "--force", "-B", branch, start_point], - check=True, - cwd=worktree_root, - ) + + if int(os.getenv("AI_MIGRATE_MAX_TRIES", 10)) >= 1: + await subprocess_run( + ["git", "checkout", "--force", "-B", branch, start_point], + check=True, + cwd=worktree_root, + ) # If using target_dir, read files from original location instead of worktree target_root = source_git_root if target_dir else worktree_root @@ -454,6 +457,7 @@ async def run( target_dir_rel_path=target_dir_rel_path, target_basename=target_basename, tools=tools, + goose_config=goose_config, ) @@ -529,6 +533,7 @@ async def _run( target_dir_rel_path: Path | str | None = None, target_basename: str = None, tools: list[Tool] = None, + goose_config: Optional[GooseConfig] = None, ): if llm_fakes: client = FakeLLMClient(llm_fakes) @@ -569,6 +574,19 @@ async def _run( messages = combine_examples_into_conversation(examples, target, system_prompt) all_files_to_verify = set() + def build_verify_cmd(all_files_to_verify: set[str]): + if target_dir: + return [ + *verify_cmd, + str(Path(target_dir_rel_path) / target_basename), + ] + return [ + *verify_cmd, + *[str(Path(worktree_root) / f) for f in all_files_to_verify], + ] + + full_verify_cmd = build_verify_cmd(all_files_to_verify) + iteration_messages = [] if pre_verify_cmd: @@ -652,17 +670,7 @@ async def _run( all_files_to_verify |= written_files - # Add the files to verify with the correct paths - if target_dir: - full_verify_cmd = [ - *verify_cmd, - str(Path(target_dir_rel_path) / target_basename), - ] - else: - full_verify_cmd = [ - *verify_cmd, - *[str(Path(worktree_root) / f) for f in all_files_to_verify], - ] + full_verify_cmd = build_verify_cmd(all_files_to_verify) log(f"Running verification: {full_verify_cmd}") verify_process = await asyncio.create_subprocess_exec( @@ -765,7 +773,7 @@ async def _run( log(f"Exception type: {type(e).__name__}") await 
remove_worktree(worktree_root) - break + return True log("Verification failed:") for line in verification_output.splitlines(): log(f"[verify] {line}") @@ -789,4 +797,173 @@ async def _run( iteration_messages.append(iteration_message) else: - raise ValueError("Migration failed: Out of tries") + log("Migration failed: Out of tries") + + if goose_config: + best_exit_code = float("inf") # Track best exit code so far + + for i in range(goose_config.max_retries): + log(f"Running migration attempt {i + 1} with Goose") + + goose_user_extra = "" + if goose_config.user_prompt: + goose_user_extra = Path(goose_config.user_prompt).read_text() + + directory_instructions = ( + f"You may only make changes to the files inside {target_dir_rel_path}/{target_basename}. Under no circumstances should you touch any files outside of this directory. If I detect that you do, I will be very disappointed in you and will switch to a smarter model." + if target_dir + else f"You may only make changes to the files: {', '.join(target_files)}. Under no circumstances should you touch any other files." + ) + + verify_cmd_str = " ".join(full_verify_cmd) + + goose_prompt = ( + "You are a helpful assistant for code migration. The migration is almost done but is not passing verification. " + "With as few changes as possible, make the migration pass verification. " + f"{directory_instructions} " + "You may verify if the migration is correct by running the following command: " + f"{verify_cmd_str} " + "The verification output may be large so pipe it to a file verification_output.txt and read it from there. " + "Keep trying until the migration passes verification." 
+ ) + + if goose_user_extra: + goose_prompt += f"\n\n{goose_user_extra}" + + goose_command = [ + "goose", + "run", + "--text", + goose_prompt, + "--with-builtin", + "developer", + ] + + log(f"Running goose: {goose_command}") + + goose_process = await asyncio.create_subprocess_exec( + *goose_command, + cwd=worktree_root, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + async def kill_after_timeout(): + await asyncio.sleep(goose_config.timeout_seconds) + if goose_process.returncode is None: + log( + f"[goose] Killing process after {goose_config.timeout_seconds} seconds timeout" + ) + goose_process.kill() + + timeout_task = asyncio.create_task(kill_after_timeout()) + + output_lines = [] + + async def read_stream(stream, prefix): + while True: + line = await stream.readline() + if not line: + break + decoded = line.decode().rstrip() + log(f"[{prefix}] {decoded}") + output_lines.append(decoded) + + stdout_task = asyncio.create_task( + read_stream(goose_process.stdout, "goose") + ) + stderr_task = asyncio.create_task( + read_stream(goose_process.stderr, "goose-err") + ) + + try: + await goose_process.wait() + await stdout_task + await stderr_task + timeout_task.cancel() + except asyncio.CancelledError: + if goose_process.returncode is None: + goose_process.kill() + stdout_task.cancel() + stderr_task.cancel() + raise + + goose_output = "\n".join(output_lines[-50:]) + + if target_dir: + await subprocess_run( + ["git", "reset"], + cwd=worktree_root, + ) + + git_path = Path(target_dir_rel_path) / target_basename + await subprocess_run( + ["git", "add", git_path], + cwd=worktree_root, + ) + else: + await subprocess_run( + ["git", "reset"], + cwd=worktree_root, + ) + + for file in written_files: + await subprocess_run( + ["git", "add", file], + cwd=worktree_root, + ) + + verify_process = await asyncio.create_subprocess_exec( + *full_verify_cmd, + cwd=worktree_root, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + stdout, stderr = await 
verify_process.communicate() + verification_output = (stderr or stdout or b"").decode() + exit_code = verify_process.returncode + + if exit_code > best_exit_code and best_exit_code != 0: + log( + f"Exit code {exit_code} is worse than previous best {best_exit_code}, resetting changes" + ) + await subprocess_run( + ["git", "reset", "--hard"], + cwd=worktree_root, + ) + continue + + best_exit_code = min(best_exit_code, exit_code) + + commit_message = f"Goose attempt {i + 1}, remaining tests: {exit_code}:\n\nGoose response:\n{goose_output}" + + await subprocess_run( + ["git", "commit", "--allow-empty", "-m", commit_message], + check=True, + cwd=worktree_root, + env={**os.environ, **environment_variables()}, + ) + + await subprocess_run( + [ + "git", + "notes", + "--ref=migrator-verify", + "add", + "-f", + "-m", + verification_output, + ], + check=True, + cwd=worktree_root, + ) + + if exit_code == 0: + log("Verification successful") + return True + else: + log(f"Verification failed with {exit_code} remaining test steps:") + for line in verification_output.splitlines(): + log(f"[verify] {line}") + + raise ValueError("Migration failed: Out of tries") diff --git a/src/ai_migrate/projects.py b/src/ai_migrate/projects.py index 580e0fe..9790052 100644 --- a/src/ai_migrate/projects.py +++ b/src/ai_migrate/projects.py @@ -150,6 +150,10 @@ async def process_one_fileset(index, files: FileGroup, task_name: str): status_manager.get_logger(task_name, header=f"==> {log_file} <=="), log_buffer, ) + if manifest.goose_config: + manifest.goose_config.user_prompt = ( + manifest.goose_config.user_prompt.format(project_dir=project_dir) + ) try: await run_migration( @@ -169,6 +173,7 @@ async def process_one_fileset(index, files: FileGroup, task_name: str): target_dir=manifest.target_dir, target_basename=files.base_name, tools=tools, + goose_config=manifest.goose_config, ) new_result = "pass" await status_manager.mark_with_status(task_name, Status.PASSED) diff --git 
a/src/ai_migrate/test_manifest.py b/src/ai_migrate/test_manifest.py index 972db47..b57831e 100644 --- a/src/ai_migrate/test_manifest.py +++ b/src/ai_migrate/test_manifest.py @@ -22,7 +22,8 @@ def test_valid_manifest(): "system_prompt": "{project_dir}/system_prompt.md", "verify_cmd": "{py} {project_dir}/verify.py", "pre_verify_cmd": "{py} {project_dir}/verify.py --pre", - "time": "2025-02-10T11:26:33.969758" + "time": "2025-02-10T11:26:33.969758", + "goose_config": {} } """ Manifest.model_validate_json(json) @@ -57,7 +58,12 @@ def test_valid_manifest_groups(): "system_prompt": "{project_dir}/system_prompt.md", "verify_cmd": "{py} {project_dir}/verify.py", "pre_verify_cmd": "{py} {project_dir}/verify.py --pre", - "time": "2025-02-10T11:26:33.969758" + "time": "2025-02-10T11:26:33.969758", + "goose_config": { + "user_prompt": "{project_dir}/goose_prompt.md", + "timeout_seconds": 100, + "max_retries": 3 + } } """ manifest = Manifest.model_validate_json(json) @@ -93,6 +99,10 @@ def test_valid_manifest_groups(): # The hash should be different because the glob pattern is different assert dir_glob_group_name.split("-")[1] != dir_group_name.split("-")[1] + assert manifest.goose_config.user_prompt == "{project_dir}/goose_prompt.md" + assert manifest.goose_config.timeout_seconds == 100 + assert manifest.goose_config.max_retries == 3 + def test_normalize_files(): """Test that Directory.to_file_group correctly converts Directory objects to FileGroup objects."""