diff --git a/pyproject.toml b/pyproject.toml index 0675846..63032a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,7 @@ authors = [ ] dependencies = [ "aiodocker==0.24.0", + "anthropic==0.52.2", # this is necessary for tortoise, remove in favor of LMI when it works with search "fhaviary[server]==0.19.0", "ldp==0.26.0", "pandas==2.2.3", @@ -17,11 +18,12 @@ dependencies = [ "google-auth==2.38.0", "google-cloud-storage==3.0.0", "google-cloud-secret-manager==2.23.0", - "futurehouse-client==0.3.18", + "futurehouse-client==0.3.19", "jupyter==1.1.1", "nbconvert==7.16.6", "notebook==7.3.2", - "nbformat==5.10.4" + "nbformat==5.10.4", + "seaborn==0.13.2" ] description = "Data analysis crow" name = "fhda" diff --git a/src/fhda/tortoise.py b/src/fhda/tortoise.py index e591507..8e4b0cd 100644 --- a/src/fhda/tortoise.py +++ b/src/fhda/tortoise.py @@ -13,28 +13,17 @@ wait_exponential, retry_if_exception_type, ) -from . import prompts +from . import config as cfg from futurehouse_client import FutureHouseClient from futurehouse_client.models import TaskRequest, RuntimeConfig -from futurehouse_client.models.app import AuthType -from futurehouse_client.clients.rest_client import TaskFetchError +from futurehouse_client.models.app import AuthType, Stage +import anthropic +import logging +import traceback - -class StepConfig(BaseModel): - """Agent runtime configuration.""" - - language: str = Field( - default="PYTHON", description="Language for execution environment" - ) - max_steps: int = Field( - default=30, description="Maximum number of steps for the agent" - ) - timeout: int = Field(default=15 * 60, description="Timeout for the step in seconds") - eval: bool = Field( - default=True, - description="For Finch, this indicates whether this is an API call or UI call. Setting it to True removes the automatic CoT additions.", - ) +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) class Step(BaseModel): @@ -43,10 +32,16 @@ class Step(BaseModel): name: str = Field( description="Name of the job to run (e.g. 'job-futurehouse-data-analysis-crow-high')" ) - prompt_template: str = Field(description="Prompt template to use for the step") - cot_prompt: bool = Field( - default=False, description="Whether to augment the query with COT prompting" + llm_call: bool = Field( + default=False, description="Whether to call the LLM for the step" + ) + include_search_tool: bool = Field( + default=False, description="Whether to include the search tool in the LLM call" ) + model_name: str = Field( + default=cfg.DEFAULT_MODEL, description="Name of the model to use for the step" + ) + prompt_template: str = Field(description="Prompt template to use for the step") prompt_args: dict[str, Any] = Field( default_factory=dict, description="Keyword arguments to format the prompt template.", @@ -59,13 +54,14 @@ class Step(BaseModel): description="Files to download {'source_name': 'dest_path'}", ) step_id: str = Field( - default_factory=lambda: str(uuid.uuid4())[:8], + default_factory=lambda: str(uuid.uuid4()), description="Small UID for the step", ) - upload_id: Optional[str] = Field(default=None, description="Upload ID for GCS") - parallel: int = Field(default=1, description="Number of parallel tasks to run") - config: StepConfig = Field( - default_factory=StepConfig, description="Configuration for the step" + n_replicate_tasks: int = Field( + default=1, description="Number of parallel tasks to run" + ) + runtime_config: RuntimeConfig = Field( + default_factory=RuntimeConfig, description="Configuration for the step" ) post_process: Optional[Callable[[dict[str, Any], str], None]] = Field( default=None, description="Function to run after step completion" @@ -74,36 +70,24 @@ class Step(BaseModel): default=None, description="Function to generate prompts and args for parallel tasks based on previous results", ) - - def cot_prompting(self, query: str, language: str) -> str: - """Apply chain-of-thought prompting to the query.""" - guidelines = prompts.GENERAL_NOTEBOOK_GUIDELINES.format(language=language) - if language == "R": - guidelines = prompts.R_SPECIFIC_GUIDELINES.format(language=language) - return ( - f"{prompts.CHAIN_OF_THOUGHT_AGNOSTIC.format(language=language)}\n" - f"{guidelines}" - f"Here is the research question to address:\n" - f"\n" - f"{query}\n" - f"\n" - ) + timeout: int = Field(default=15 * 60, description="Timeout for the step in seconds") def format_prompt(self) -> str: """Format the prompt template with the provided arguments.""" final_prompt = self.prompt_template.format(**self.prompt_args) - if self.cot_prompt: - final_prompt = self.cot_prompting(final_prompt, self.config.language) return final_prompt class Tortoise: """Runner for multi-step agent pipelines.""" - def __init__(self, api_key: str): + def __init__(self, api_key: str, environment: str = "PROD"): """Initialize the tortoise framework with FutureHouse API key.""" self.client = FutureHouseClient( - auth_type=AuthType.API_KEY, api_key=api_key, verbose_logging=True + auth_type=AuthType.API_KEY, + api_key=api_key, + verbose_logging=True, + stage=getattr(Stage, environment.upper(), Stage.PROD), ) self.steps: list[Step] = [] self.results: dict[str, Any] = {} @@ -115,7 +99,7 @@ def add_step(self, step: Step) -> None: def save_results(self, output_dir: str | PathLike = "output") -> None: """Save the results to a JSON file.""" results_path = f"{output_dir}/results_{time.strftime('%Y%m%d_%H%M%S')}.json" - print(f"Saving all results to {results_path}") + logger.info(f"Saving all results to {results_path}") try: os.makedirs(output_dir, exist_ok=True) serializable_results = {} @@ -123,10 +107,10 @@ def save_results(self, output_dir: str | PathLike = "output") -> None: serializable_results[step_id] = dict(step_result) with open(results_path, "w") as f: - json.dump(serializable_results, f, indent=2) - print(f"Results successfully saved to {results_path}") + json.dump(serializable_results, f, indent=2, default=str) + logger.info(f"Results successfully saved to {results_path}") except Exception as e: - print(f"Error saving results to {results_path}: {e}") + logger.error(f"Error saving results to {results_path}: {e}") @retry( stop=stop_after_attempt(3), @@ -168,7 +152,21 @@ def _create_task_requests( List of task requests to be executed """ task_requests = [] - task_count = max(step.parallel, 1) + task_count = max(step.n_replicate_tasks, 1) + + if step.model_name: + agent_config = cfg.get_custom_agent_config(step.model_name) + runtime_config.agent = agent_config + + if step.runtime_config.continued_job_id: + task_ids = self.results[str(step.runtime_config.continued_job_id)][ + "task_ids" + ] + if len(task_ids) > 1: + logger.warning( + f"Continued job {step.runtime_config.continued_job_id} has multiple task ids, using the first one" + ) + runtime_config.continued_job_id = str(task_ids[0]) if step.prompt_generator and task_count > 1: # Generate dynamic prompts based on previous results @@ -201,11 +199,34 @@ def _create_task_requests( return task_requests - @retry( - stop=stop_after_attempt(5), - wait=wait_exponential(multiplier=1, min=2, max=30), - retry=retry_if_exception_type((Exception, TaskFetchError)), - ) + async def call_llm(self, step: Step) -> list: + """Call the LLM for the step.""" + anthropic_client = anthropic.Anthropic() + # TODO: This is a hack to get the model name without the provider prefix + model_name = step.model_name.replace("anthropic/", "") + if step.include_search_tool: + tools = [ + { + "type": "web_search_20250305", + "name": "web_search", + } + ] + else: + tools = [] + response = anthropic_client.messages.create( + model=model_name, + messages=[ + { + "role": "user", + "content": step.prompt_template, + } + ], + tools=tools, + max_tokens=8192, + ) + result = "\n".join([r.text for r in response.content if hasattr(r, "text")]) + return [result] + async def _run_tasks_with_retry( self, task_requests, progress_bar, verbose, timeout ): @@ -225,64 +246,60 @@ async def run_pipeline( os.makedirs(output_dir, exist_ok=True) for i, step in enumerate(self.steps): - print(f"Running step {i + 1}/{len(self.steps)}: {step.name}") - if not step.upload_id: - step.upload_id = f"{step.name}_{step.step_id}" + logger.info(f"Running step {i + 1}/{len(self.steps)}: {step.name}") + if not step.runtime_config.upload_id: + step.runtime_config.upload_id = step.step_id for source_path, dest_name in step.input_files.items(): - print(f"Uploading file {source_path} as {dest_name}") + logger.info(f"Uploading file {source_path} as {dest_name}") try: self._upload_file_with_retry( - step.name, file_path=source_path, upload_id=step.upload_id + step.name, + file_path=source_path, + upload_id=step.runtime_config.upload_id, ) except Exception as e: - print( + logger.error( f"Failed to upload file {source_path} after multiple retries: {e}" ) raise - if step.config: - runtime_config = RuntimeConfig( - max_steps=step.config.max_steps, - upload_id=step.upload_id, - environment_config={ - "eval": step.config.eval, - "language": step.config.language, - }, - ) + if step.llm_call: + task_responses = await self.call_llm(step) + task_ids = [f"llm_{str(uuid.uuid4())[:8]}"] + success_rate = 1 else: - runtime_config = None - - task_requests = self._create_task_requests(step, runtime_config) - - print( - f"Running {len(task_requests)} task{'s' if len(task_requests) > 1 else ''}" - ) - try: - task_responses = await self._run_tasks_with_retry( - task_requests, - progress_bar=True, - verbose=False, - timeout=step.config.timeout, - ) - except Exception as e: - print( - f"Failed to run tasks for step {step.step_id} after multiple retries: {e}" - ) - # Create an error result entry and continue to the next step - self.results[step.step_id] = { - "task_ids": [], - "task_responses": [], - "success_rate": 0, - "error": str(e), - } - continue + task_requests = self._create_task_requests(step, step.runtime_config) - task_ids = [str(task.task_id) for task in task_responses] - success_rate = sum( - [task.status == "success" for task in task_responses] - ) / len(task_responses) - print(f"Task success rate: {success_rate * 100}%") + logger.info( + f"Running {len(task_requests)} task{'s' if len(task_requests) > 1 else ''}" + ) + try: + task_responses = await self._run_tasks_with_retry( + task_requests, + progress_bar=True, + verbose=False, + timeout=step.timeout, + ) + except Exception as e: + logger.error( + f"Failed to run tasks for step {step.step_id} after multiple retries: {e}" + ) + logger.error(f"Full traceback:\n{traceback.format_exc()}") + # Create an error result entry and continue to the next step + self.results[step.step_id] = { + "task_ids": [], + "task_responses": [], + "success_rate": 0, + "error": str(e), + } + continue + + task_ids = [str(task.task_id) for task in task_responses] + success_rate = sum( + [task.status == "success" for task in task_responses] + ) / len(task_responses) + logger.info(f"Task success rate: {success_rate * 100}%") self.results[step.step_id] = { "task_ids": task_ids, @@ -307,7 +324,7 @@ async def run_pipeline( os.makedirs( os.path.dirname(os.path.abspath(path)), exist_ok=True ) - print(f"Downloading file {source_name} to {path}") + logger.info(f"Downloading file {source_name} to {path}") try: self._download_file_with_retry( step.name, @@ -316,21 +333,21 @@ async def run_pipeline( destination_path=path, ) except Exception as e: - print( + logger.error( f"Failed to download {source_name} from task {task_id} after multiple retries: {e}" ) except Exception as e: - print( + logger.error( f"Error downloading {source_name} from task {task_id}: {e}" ) if step.post_process: - print(f"Running post-processing for step {step.step_id}") + logger.info(f"Running post-processing for step {step.step_id}") step.post_process( self.results[step.step_id], f"{output_dir}/{step.step_id}" ) - print(f"Completed step {i + 1}/{len(self.steps)}") + logger.info(f"Completed step {i + 1}/{len(self.steps)}") self.save_results(output_dir) return self.results diff --git a/tutorial/consensus.ipynb b/tutorial/consensus.ipynb index cea2619..94ab0b0 100644 --- a/tutorial/consensus.ipynb +++ b/tutorial/consensus.ipynb @@ -141,7 +141,7 @@ " max_steps=30,\n", " upload_id=DEA_UPLOAD_ID,\n", " environment_config={\n", - " \"eval\": True, # DO NOT CHANGE THIS\n", + " \"default_cot_prompt\": False,\n", " \"language\": \"R\",\n", " },\n", ")\n", @@ -200,7 +200,7 @@ " max_steps=30,\n", " upload_id=CONSENSUS_UPLOAD_ID,\n", " environment_config={\n", - " \"eval\": True, # DO NOT CHANGE THIS\n", + " \"default_cot_prompt\": False,\n", " \"language\": \"R\",\n", " },\n", ")\n", @@ -304,7 +304,7 @@ " max_steps=30,\n", " upload_id=PQA_UPLOAD_ID,\n", " environment_config={\n", - " \"eval\": True, # DO NOT CHANGE THIS\n", + " \"default_cot_prompt\": False,\n", " \"language\": \"PYTHON\",\n", " },\n", ")\n", diff --git a/tutorial/example.ipynb b/tutorial/example.ipynb index b3d0bc3..7e66c25 100644 --- a/tutorial/example.ipynb +++ b/tutorial/example.ipynb @@ -75,6 +75,7 @@ " language=language,\n", " system_prompt=prompts.CAPSULE_SYSTEM_PROMPT_QUERY,\n", " use_tmp_work_dir=False,\n", + " run_notebook_on_edit=True if cfg.USE_DOCKER else False,\n", " )\n", " return dae" ] @@ -172,7 +173,7 @@ "outputs": [], "source": [ "# VANILLA ROLLOUT - this is a simple version of the what the rollout Manager does\n", - "dataset_folder = Path(\"datasets/brain_size_data.csv\")\n", + "dataset_folder = Path(\"dataset\")\n", "query = \"Analyze the dataset and give me an in depth analysis using pretty plots. I am particularly interested in crows.\"\n", "environment = setup_data_analysis_env(query, dataset_folder)\n", "\n", diff --git a/tutorial/multi_agent_orchestration.ipynb b/tutorial/multi_agent_orchestration.ipynb index 54059a9..a3d003d 100644 --- a/tutorial/multi_agent_orchestration.ipynb +++ b/tutorial/multi_agent_orchestration.ipynb @@ -17,7 +17,8 @@ "metadata": {}, "outputs": [], "source": [ - "from fhda.tortoise import Tortoise, Step, StepConfig\n", + "from fhda.tortoise import Tortoise, Step\n", + "from futurehouse_client.models import RuntimeConfig\n", "from futurehouse_client import JobNames\n", "import pandas as pd\n", "import json" @@ -90,14 +91,17 @@ "dea_step = Step(\n", " name=JobNames.FINCH,\n", " prompt_template=DEA_PROMPT,\n", - " cot_prompt=True,\n", " prompt_args={\"treatment\": TREATMENT, \"mechanism\": MECHANISM, \"context\": CONTEXT},\n", " input_files={\n", " \"datasets/GSE52778_All_Sample_FPKM_Matrix.txt.gz\": \"GSE52778_series_matrix.txt.gz\"\n", " },\n", " output_files={\"dea_results.csv\": \"dea_results/dea_results.csv\"},\n", - " parallel=PARALLEL_DEA,\n", - " config=StepConfig(language=\"R\", max_steps=30, timeout=15 * 60),\n", + " n_replicate_tasks=PARALLEL_DEA,\n", + " runtime_config=RuntimeConfig(\n", + " max_steps=30,\n", + " environment_config={\"language\": \"R\", \"default_cot_prompt\": True},\n", + " timeout=15 * 60,\n", + " ),\n", ")\n", "tortoise.add_step(dea_step)\n", "\n", @@ -105,13 +109,16 @@ "consensus_step = Step(\n", " name=JobNames.FINCH,\n", " prompt_template=CONSENSUS_PROMPT,\n", - " cot_prompt=True,\n", " input_files={f\"{OUTPUT_DIR}/{dea_step.step_id}/dea_results\": \"dea_results/\"},\n", " output_files={\n", " \"consensus_results.csv\": \"consensus_results.csv\",\n", " f\"top{N_TOP_GENES}_genes.csv\": f\"top{N_TOP_GENES}_genes.csv\",\n", " },\n", - " config=StepConfig(language=\"R\", max_steps=30, timeout=15 * 60),\n", + " runtime_config=RuntimeConfig(\n", + " max_steps=30,\n", + " environment_config={\"language\": \"R\", \"default_cot_prompt\": True},\n", + " timeout=15 * 60,\n", + " ),\n", ")\n", "tortoise.add_step(consensus_step)\n", "\n", @@ -164,7 +171,7 @@ " name=JobNames.CROW,\n", " prompt_template=PQA_PROMPT,\n", " prompt_generator=pqa_prompt_generator,\n", - " parallel=N_TOP_GENES, # Will process all top genes in parallel\n", + " n_replicate_tasks=N_TOP_GENES, # Will process all top genes in parallel\n", " post_process=pqa_post_process,\n", ")\n", "tortoise.add_step(pqa_step)\n", @@ -173,12 +180,15 @@ "volcano_step = Step(\n", " name=JobNames.FINCH,\n", " prompt_template=VOLCANO_PROMPT,\n", - " cot_prompt=True,\n", " input_files={\n", " f\"{OUTPUT_DIR}/{consensus_step.step_id}/consensus_results.csv\": \"consensus_results.csv\",\n", " f\"{OUTPUT_DIR}/{pqa_step.step_id}/pqa_results.csv\": \"pqa_results.csv\",\n", " },\n", - " config=StepConfig(language=\"PYTHON\", max_steps=30, timeout=15 * 60),\n", + " runtime_config=RuntimeConfig(\n", + " max_steps=30,\n", + " environment_config={\"language\": \"PYTHON\", \"default_cot_prompt\": True},\n", + " timeout=15 * 60,\n", + " ),\n", ")\n", "tortoise.add_step(volcano_step)\n", "\n", diff --git a/tutorial/platform_api.ipynb b/tutorial/platform_api.ipynb index c293091..97efb12 100644 --- a/tutorial/platform_api.ipynb +++ b/tutorial/platform_api.ipynb @@ -110,7 +110,7 @@ " max_steps=MAX_STEPS,\n", " upload_id=UPLOAD_ID,\n", " environment_config={\n", - " \"eval\": True, # DO NOT CHANGE THIS\n", + " \"default_cot_prompt\": False,\n", " \"language\": LANGUAGE,\n", " },\n", " ),\n",