Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 70 additions & 18 deletions src/ai_migrate/progress.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,54 @@
import asyncio
import sys
import itertools
from enum import StrEnum
from typing import Dict
import shutil


class Status(StrEnum):
PASSED = "passed"
FAILED = "failed"
RUNNING = "running"
WAITING = "waiting"


class StatusLog:
def __init__(self, line_limit):
self.line_limit = line_limit
self.lines = [""] * line_limit
self.header = ""

def write(self, s: str):
self.lines.extend(s.removesuffix("\n").splitlines())
self.lines = self.lines[-self.line_limit :]

def flush(self):
pass

def close(self):
pass

def getvalue(self):
body = "\n".join(self.lines)
return f"{self.header}\n{body}" if self.header else body


class StatusBar:
def __init__(self, name: str = "Task"):
self.name = name
self.status = None
self.status = Status.WAITING
self.message = ""
self.spinner_chars = "⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏"
self.spinner_iter = itertools.cycle(self.spinner_chars)
self.logger = StatusLog(line_limit=3)

def render(self) -> str:
terminal_width = shutil.get_terminal_size((80, 20)).columns

if self.status == "passed":
if self.status == Status.PASSED:
status_symbol = "✓"
elif self.status == "failed":
elif self.status == Status.FAILED:
status_symbol = "✗"
else:
status_symbol = next(self.spinner_iter)
Expand All @@ -29,7 +59,21 @@ def render(self) -> str:

name_part = f"{self.name}: "
padding = terminal_width - len(name_part) - len(right_part)
return f"\r{name_part}{' ' * max(0, padding)}{right_part}"
logs = (
[
f" {line}"[:terminal_width]
for line in self.logger.getvalue().splitlines()
]
if self.status == Status.RUNNING
else []
)
return f"\r{name_part}{' ' * max(0, padding)}{right_part}" + (
f"\n{'\n'.join(logs)}" if logs else ""
)

def get_logger(self, header: str):
self.logger.header = header
return self.logger


class StatusManager:
Expand Down Expand Up @@ -62,22 +106,33 @@ async def render(self):
sys.stdout.write("\033[2K\033[A" * self._last_render_lines)
sys.stdout.flush()

self._last_render_lines = len(self.bars)
self._last_render_lines = 0
bars = list(self.bars.values())
bars = sorted(bars, key=lambda bar: (bar.status or "~", bar.name))
for bar in bars:
print(bar.render())

async def mark_passed(self, name: str):
bars = itertools.groupby(
sorted(bars, key=lambda bar: (bar.status, bar.name)),
key=lambda bar: bar.status,
)
for status, bars in bars:
bars = [*bars]
if status == Status.WAITING:
print(f"{len(bars)} more in queue...")
self._last_render_lines += 1
else:
for bar in bars:
rendered = bar.render()
print(rendered)
self._last_render_lines += len(rendered.splitlines())

async def mark_with_status(self, name: str, status: Status):
async with self.lock:
if name in self.bars:
self.bars[name].status = "passed"
self.bars[name].status = status
await self.render()

async def mark_failed(self, name: str):
async def set_message(self, name: str, message: str):
async with self.lock:
if name in self.bars:
self.bars[name].status = "failed"
self.bars[name].message = message
await self.render()

async def stop(self):
Expand All @@ -89,8 +144,5 @@ async def stop(self):
except asyncio.CancelledError:
pass

async def update_message(self, name: str, message: str):
async with self.lock:
if name in self.bars:
self.bars[name].message = message
await self.render()
def get_logger(self, name: str, header: str = ""):
return self.bars[name].get_logger(header=header)
36 changes: 26 additions & 10 deletions src/ai_migrate/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
Directory,
)
from .migrate import run as run_migration, FailedPreVerification
from .progress import StatusManager
from .progress import StatusManager, Status


def get_git_sha(directory: str | Path) -> str:
Expand Down Expand Up @@ -75,6 +75,19 @@ def load_tools_from_dir(project_dir: str) -> list[Tool]:
return []


class Tee:
def __init__(self, *files):
self.files = files

def write(self, data):
for f in self.files:
f.write(data)

def flush(self):
for f in self.files:
f.flush()


async def run(
project_dir: str,
logs_dir: str | Path,
Expand Down Expand Up @@ -127,12 +140,17 @@ async def process_one_fileset(index, files: FileGroup, task_name: str):
if target_sha is None:
target_sha = get_git_sha(Path(files.files[0]).parent)

await status_manager.update_message(task_name, "Running...")
await status_manager.mark_with_status(task_name, Status.RUNNING)

log_file = (logs_dir / task_name).with_suffix(".log")
log_file.parent.mkdir(parents=True, exist_ok=True)

log_buffer = open(log_file, "w")

logger = Tee(
status_manager.get_logger(task_name, header=f"==> {log_file} <=="),
log_buffer,
)

try:
await run_migration(
files.files,
Expand All @@ -144,7 +162,7 @@ async def process_one_fileset(index, files: FileGroup, task_name: str):
pre_verify_cmd=manifest.pre_verify_cmd.format(
project_dir=project_dir, py=sys.executable
),
log_stream=log_buffer,
log_stream=logger,
local_worktrees=local_worktrees,
llm_fakes=llm_fakes,
dont_create_evals=dont_create_evals,
Expand All @@ -153,16 +171,15 @@ async def process_one_fileset(index, files: FileGroup, task_name: str):
tools=tools,
)
new_result = "pass"
await status_manager.mark_passed(task_name)
await status_manager.mark_with_status(task_name, Status.PASSED)
except FailedPreVerification:
await status_manager.mark_failed(task_name)
await status_manager.mark_with_status(task_name, Status.FAILED)
new_result = "fail-pre-verify"
except Exception:
await status_manager.mark_failed(task_name)
traceback.print_exc(file=log_buffer)
await status_manager.mark_with_status(task_name, Status.FAILED)
traceback.print_exc(file=logger)
new_result = "fail"
finally:
await status_manager.update_message(task_name, "")
log_buffer.close()

results.append(FileGroup(files=files.files, result=new_result))
Expand All @@ -184,7 +201,6 @@ async def process_one_with_sem(index, files: FileGroup, task_name: str):
task_name = task_name + f" (+{len(file_set.files) - 1})"

await status_manager.add_status(task_name)
await status_manager.update_message(task_name, "Waiting...")
tg.create_task(process_one_with_sem(i, file_set, task_name))

print("Project run complete.")
Expand Down