diff --git a/src/ai_migrate/progress.py b/src/ai_migrate/progress.py index b56f75a..c474906 100644 --- a/src/ai_migrate/progress.py +++ b/src/ai_migrate/progress.py @@ -1,24 +1,54 @@ import asyncio import sys import itertools +from enum import StrEnum from typing import Dict import shutil +class Status(StrEnum): + PASSED = "passed" + FAILED = "failed" + RUNNING = "running" + WAITING = "waiting" + + +class StatusLog: + def __init__(self, line_limit): + self.line_limit = line_limit + self.lines = [""] * line_limit + self.header = "" + + def write(self, s: str): + self.lines.extend(s.removesuffix("\n").splitlines()) + self.lines = self.lines[-self.line_limit :] + + def flush(self): + pass + + def close(self): + pass + + def getvalue(self): + body = "\n".join(self.lines) + return f"{self.header}\n{body}" if self.header else body + + class StatusBar: def __init__(self, name: str = "Task"): self.name = name - self.status = None + self.status = Status.WAITING self.message = "" self.spinner_chars = "⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏" self.spinner_iter = itertools.cycle(self.spinner_chars) + self.logger = StatusLog(line_limit=3) def render(self) -> str: terminal_width = shutil.get_terminal_size((80, 20)).columns - if self.status == "passed": + if self.status == Status.PASSED: status_symbol = "✓" - elif self.status == "failed": + elif self.status == Status.FAILED: status_symbol = "✗" else: status_symbol = next(self.spinner_iter) @@ -29,7 +59,21 @@ def render(self) -> str: name_part = f"{self.name}: " padding = terminal_width - len(name_part) - len(right_part) - return f"\r{name_part}{' ' * max(0, padding)}{right_part}" + logs = ( + [ + f" {line}"[:terminal_width] + for line in self.logger.getvalue().splitlines() + ] + if self.status == Status.RUNNING + else [] + ) + return f"\r{name_part}{' ' * max(0, padding)}{right_part}" + ( + f"\n{'\n'.join(logs)}" if logs else "" + ) + + def get_logger(self, header: str): + self.logger.header = header + return self.logger class StatusManager: @@ -62,22 +106,33 @@ async def render(self): sys.stdout.write("\033[2K\033[A" * self._last_render_lines) sys.stdout.flush() - self._last_render_lines = len(self.bars) + self._last_render_lines = 0 bars = list(self.bars.values()) - bars = sorted(bars, key=lambda bar: (bar.status or "~", bar.name)) - for bar in bars: - print(bar.render()) - - async def mark_passed(self, name: str): + bars = itertools.groupby( + sorted(bars, key=lambda bar: (bar.status, bar.name)), + key=lambda bar: bar.status, + ) + for status, bars in bars: + bars = [*bars] + if status == Status.WAITING: + print(f"{len(bars)} more in queue...") + self._last_render_lines += 1 + else: + for bar in bars: + rendered = bar.render() + print(rendered) + self._last_render_lines += len(rendered.splitlines()) + + async def mark_with_status(self, name: str, status: Status): async with self.lock: if name in self.bars: - self.bars[name].status = "passed" + self.bars[name].status = status await self.render() - async def mark_failed(self, name: str): + async def set_message(self, name: str, message: str): async with self.lock: if name in self.bars: - self.bars[name].status = "failed" + self.bars[name].message = message await self.render() async def stop(self): @@ -89,8 +144,5 @@ async def stop(self): except asyncio.CancelledError: pass - async def update_message(self, name: str, message: str): - async with self.lock: - if name in self.bars: - self.bars[name].message = message - await self.render() + def get_logger(self, name: str, header: str = ""): + return self.bars[name].get_logger(header=header) diff --git a/src/ai_migrate/projects.py b/src/ai_migrate/projects.py index 0379856..580e0fe 100644 --- a/src/ai_migrate/projects.py +++ b/src/ai_migrate/projects.py @@ -18,7 +18,7 @@ Directory, ) from .migrate import run as run_migration, FailedPreVerification -from .progress import StatusManager +from .progress import StatusManager, Status def get_git_sha(directory: str | Path) -> str: @@ -75,6 +75,19 @@ def load_tools_from_dir(project_dir: str) -> list[Tool]: return [] +class Tee: + def __init__(self, *files): + self.files = files + + def write(self, data): + for f in self.files: + f.write(data) + + def flush(self): + for f in self.files: + f.flush() + + async def run( project_dir: str, logs_dir: str | Path, @@ -127,12 +140,17 @@ async def process_one_fileset(index, files: FileGroup, task_name: str): if target_sha is None: target_sha = get_git_sha(Path(files.files[0]).parent) - await status_manager.update_message(task_name, "Running...") + await status_manager.mark_with_status(task_name, Status.RUNNING) log_file = (logs_dir / task_name).with_suffix(".log") log_file.parent.mkdir(parents=True, exist_ok=True) - log_buffer = open(log_file, "w") + + logger = Tee( + status_manager.get_logger(task_name, header=f"==> {log_file} <=="), + log_buffer, + ) + try: await run_migration( files.files, @@ -144,7 +162,7 @@ async def process_one_fileset(index, files: FileGroup, task_name: str): pre_verify_cmd=manifest.pre_verify_cmd.format( project_dir=project_dir, py=sys.executable ), - log_stream=log_buffer, + log_stream=logger, local_worktrees=local_worktrees, llm_fakes=llm_fakes, dont_create_evals=dont_create_evals, @@ -153,16 +171,15 @@ async def process_one_fileset(index, files: FileGroup, task_name: str): tools=tools, ) new_result = "pass" - await status_manager.mark_passed(task_name) + await status_manager.mark_with_status(task_name, Status.PASSED) except FailedPreVerification: - await status_manager.mark_failed(task_name) + await status_manager.mark_with_status(task_name, Status.FAILED) new_result = "fail-pre-verify" except Exception: - await status_manager.mark_failed(task_name) - traceback.print_exc(file=log_buffer) + await status_manager.mark_with_status(task_name, Status.FAILED) + traceback.print_exc(file=logger) new_result = "fail" finally: - await status_manager.update_message(task_name, "") log_buffer.close() results.append(FileGroup(files=files.files, result=new_result)) @@ -184,7 +201,6 @@ async def process_one_with_sem(index, files: FileGroup, task_name: str): task_name = task_name + f" (+{len(file_set.files) - 1})" await status_manager.add_status(task_name) - await status_manager.update_message(task_name, "Waiting...") tg.create_task(process_one_with_sem(i, file_set, task_name)) print("Project run complete.")