diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 029ff40..49102fd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,6 +6,7 @@ repos: rev: 25.1.0 hooks: - id: black-jupyter + types: [jupyter] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v5.0.0 hooks: diff --git a/src/fhda/Dockerfile.custom_deployment b/src/fhda/Dockerfile.custom_deployment index 3914fef..c94c3fd 100644 --- a/src/fhda/Dockerfile.custom_deployment +++ b/src/fhda/Dockerfile.custom_deployment @@ -5,8 +5,7 @@ WORKDIR /app ENV PYTHONUNBUFFERED=1 ENV DEBIAN_FRONTEND=noninteractive -RUN --mount=type=cache,target=/var/cache/apt \ - apt-get update -qq && \ +RUN apt-get update -qq && \ apt-get install -yq --no-install-recommends \ git \ openssh-client \ @@ -28,15 +27,13 @@ ENV PATH="/app/miniconda/bin:$PATH" ENV PYTHONPATH="/app/miniconda/lib/python3.12/site-packages:${PYTHONPATH:-}" # Install uv & mamba -RUN --mount=type=cache,target=/root/.cache/pip \ - pip3 install --no-cache-dir uv==0.5.21 -RUN --mount=type=cache,target=/app/miniconda/pkgs \ - conda install -c conda-forge mamba -y +RUN pip3 install --no-cache-dir uv==0.5.21 +RUN conda install -c conda-forge mamba -y # Install R and kernels in the crow_env environment -RUN --mount=type=cache,target=/app/miniconda/pkgs \ - mamba install -c conda-forge -y \ +RUN mamba install -c conda-forge -y \ r-base=4.3.3 \ + r-r.utils=2.13.0 \ r-recommended=4.3 \ r-irkernel=1.3.2 \ r-factominer=2.11 \ @@ -86,13 +83,10 @@ RUN --mount=type=cache,target=/app/miniconda/pkgs \ statsmodels=0.14.4 \ umap-learn=0.5.7 -RUN --mount=type=cache,target=/app/miniconda/pkgs \ - python -m ipykernel install --user --name python3 --display-name "Python 3 (ipykernel)" -RUN --mount=type=cache,target=/app/miniconda/pkgs \ - R -e 'IRkernel::installspec(name = "R", displayname = "R (4.3.3)")' +RUN python -m ipykernel install --user --name python3 --display-name "Python 3 (ipykernel)" +RUN R -e 'IRkernel::installspec(name = "R", 
displayname = "R (4.3.3)")' -RUN --mount=type=cache,target=/app/miniconda/pkgs \ - mamba install -c conda-forge -c bioconda -y \ +RUN mamba install -c conda-forge -c bioconda -y \ biokit=0.5.0 \ gseapy=1.1.4 \ blast=2.16.0 \ @@ -116,7 +110,9 @@ RUN --mount=type=cache,target=/app/miniconda/pkgs \ bioconductor-summarizedexperiment=1.32.0 \ bioconductor-apeglm=1.24.0 \ bioconductor-flowcore=2.14.0 \ - bioconductor-flowmeans=1.62.0 + bioconductor-flowmeans=1.62.0 \ + bioconductor-limma=3.58.1 \ + bioconductor-geoquery=2.70.0 ENV UV_COMPILE_BYTECODE=1 ENV UV_LINK_MODE=copy @@ -131,8 +127,7 @@ FROM base AS builder ARG MODULE_NAME -RUN --mount=type=cache,target=/var/cache/apt \ - apt-get update -qq && \ +RUN apt-get update -qq && \ apt-get install -yq --no-install-recommends \ build-essential && \ apt-get clean && rm -rf /var/lib/apt/lists/* @@ -147,9 +142,7 @@ COPY ./scripts/run_crow_job.py /app/scripts/ # Install application dependencies (this will only rerun when code changes) WORKDIR /app/${MODULE_NAME} -RUN --mount=type=cache,target=/root/.cache/uv \ - --mount=type=cache,target=/app/miniconda/pkgs \ - if [ -f "pyproject.toml" ]; then \ +RUN if [ -f "pyproject.toml" ]; then \ uv pip install --system -e .; \ elif [ -f "requirements.txt" ]; then \ uv pip install --system -r requirements.txt; \ @@ -167,4 +160,4 @@ COPY --from=builder /app/ /app/ ENV VIRTUAL_ENV="/app/miniconda/bin" ENV PATH="/app/miniconda/bin:$PATH" ENV PYTHONPATH="/app/miniconda/lib/python3.12/site-packages:${PYTHONPATH:-}" -CMD ["python", "scripts/run_crow_job.py"] +CMD ["python", "scripts/run_crow_job.py"] \ No newline at end of file diff --git a/src/fhda/config.py b/src/fhda/config.py index 7508480..27aa514 100644 --- a/src/fhda/config.py +++ b/src/fhda/config.py @@ -25,4 +25,4 @@ # FutureHosue client config ENVIRONMENT = os.getenv("ENVIRONMENT", "prod") CROW_STAGE = getattr(Stage, ENVIRONMENT.upper(), Stage.PROD) -PLATFORM_API_KEY = os.getenv("CROW_API_KEY", None) +PLATFORM_API_KEY = 
os.getenv("FH_API_KEY", None) diff --git a/src/fhda/data_analysis_env.py b/src/fhda/data_analysis_env.py index 0d70f7a..d89764b 100644 --- a/src/fhda/data_analysis_env.py +++ b/src/fhda/data_analysis_env.py @@ -15,7 +15,7 @@ from futurehouse_client import FutureHouseClient from .notebook_env import NBEnvironment -from .utils import NBLanguage, MultipleChoiceQuestion +from .utils import NBLanguage, MultipleChoiceQuestion, extract_xml_content from . import prompts from . import config as cfg @@ -174,6 +174,7 @@ def from_task( trajectory_id: str | None = None, user_id: str | None = None, environment_config: dict[str, Any] | None = None, + continued_trajectory_id: str | None = None, ) -> "DataAnalysisEnv": """ Perform data analysis on a user query. @@ -188,18 +189,21 @@ def from_task( logger.info("environment_config: %s", environment_config) logger.info("trajectory_id: %s", trajectory_id) logger.info("user_id: %s", user_id) - # Track cost of running the environment + logger.info("continued_trajectory_id: %s", continued_trajectory_id) enable_cost_tracking() + if ( - not gcs_artifact_path + (not gcs_artifact_path) and not continued_trajectory_id ): # Platform jobs should always be associated with data from a GCS bucket raise NotImplementedError( "Running crow jobs without gcs_artifact_path is not supported" ) if user_id is None: + logger.warning("No user_id provided, using default_user") user_id = "default_user" if trajectory_id is None: + logger.warning("No trajectory_id provided, using time-based id") trajectory_id = f"{gcs_artifact_path}-{time.time()}" if environment_config: kwargs = { @@ -214,11 +218,49 @@ def from_task( trajectory_path = ( cfg.DATA_STORAGE_PATH / "user_trajectories" / user_id / trajectory_id ) - if environment_config.get("gcs_override", False): - data_path = cfg.DATA_STORAGE_PATH / gcs_artifact_path + + if continued_trajectory_id: + kwargs["rerun_all_cells"] = True + data_path = ( + cfg.DATA_STORAGE_PATH + / "user_trajectories" + / user_id + / 
continued_trajectory_id + ) + logger.info("Continuing trajectory from %s", continued_trajectory_id) + if cfg.PLATFORM_API_KEY is None: + logger.warning( + "Platform API key is not set, can't fetch previous trajectory" + ) + previous_research_question = None + previous_final_answer = None + else: + logger.info("Fetching previous trajectory") + client = FutureHouseClient( + stage=cfg.CROW_STAGE, + auth_type=AuthType.API_KEY, + api_key=cfg.PLATFORM_API_KEY, + ) + previous_trajectory = client.get_task( + continued_trajectory_id, verbose=True + ) + previous_research_question = extract_xml_content( + previous_trajectory.query, "query" + ) + previous_final_answer = previous_trajectory.environment_frame["state"][ + "state" + ]["answer"] + language = previous_trajectory.environment_frame["state"]["info"][ + "language" + ] + language = getattr(NBLanguage, language.upper()) + kwargs["language"] = language + + elif environment_config.get("gcs_override", False): + data_path = cfg.DATA_STORAGE_PATH / gcs_artifact_path # type: ignore else: data_path = ( - cfg.DATA_STORAGE_PATH / "user_data" / user_id / gcs_artifact_path + cfg.DATA_STORAGE_PATH / "user_data" / user_id / gcs_artifact_path # type: ignore ) logger.info("Trajectory path: %s", trajectory_path) logger.info("Data path: %s", data_path) @@ -230,12 +272,19 @@ def from_task( shutil.copytree(item, trajectory_path / item.name, dirs_exist_ok=True) logger.info("Filtered kwargs: %s", kwargs) - language = getattr(NBLanguage, environment_config.get("language", "PYTHON")) - # Overwrite the language in the kwargs with NBLanguage enum - kwargs["language"] = language - logger.info("Language: %s", language.name) + # If it's continued, we already have the language + if continued_trajectory_id: + logger.info( + "Language already set from previous trajectory notebook %s", + kwargs.get("language", None), + ) + else: + language = getattr(NBLanguage, environment_config.get("language", "PYTHON")) + # Overwrite the language in the kwargs with 
NBLanguage enum + kwargs["language"] = language + logger.info("Language: %s", language.name) - if not environment_config.get("eval", False): + if not environment_config.get("eval", False) and not continued_trajectory_id: logger.info( "Platform job detected, augmenting user query with CoT instructions" ) @@ -248,6 +297,17 @@ def from_task( f"{task}\n" f"\n" ) + if continued_trajectory_id and not environment_config.get("eval", False): + logger.info( + "Continuation job detected, augmenting user query with continuation instructions" + ) + task = prompts.CONTINUATION_PROMPT_TEMPLATE.format( + previous_research_question=previous_research_question, + previous_final_answer=previous_final_answer, + query=task, + language=kwargs.get("language", "PYTHON"), + ) + nb_path = trajectory_path / NBEnvironment.NOTEBOOK_NAME logger.info("NB path: %s", nb_path) diff --git a/src/fhda/notebook_env.py b/src/fhda/notebook_env.py index 8569dfd..a1cf7b2 100644 --- a/src/fhda/notebook_env.py +++ b/src/fhda/notebook_env.py @@ -124,6 +124,7 @@ def __init__( language: utils.NBLanguage = utils.NBLanguage.PYTHON, allow_download_from_gcs: bool = False, run_notebook_on_edit: bool = False, + rerun_all_cells: bool = False, ): """Initialize a notebook environment. @@ -140,6 +141,7 @@ def __init__( task requires data on GCS. Disabled by default. run_notebook_on_edit: If True (default), the whole notebook will be rerun after each edit. If False, only the cell that was edited will be rerun. + rerun_all_cells: If True, the whole notebook will be run at the beginning of the episode. This is for continued trajectories. 
""" self.work_dir = Path(work_dir) self.nb_path = Path(nb_path) if nb_path else self.work_dir / self.NOTEBOOK_NAME @@ -149,6 +151,7 @@ def __init__( self.allow_download_from_gcs = allow_download_from_gcs self.use_docker = cfg.USE_DOCKER self.run_notebook_on_edit = run_notebook_on_edit + self.rerun_all_cells = rerun_all_cells async def reset(self) -> tuple[Messages, list[Tool]]: nb_path, work_dir = self._set_work_dir() @@ -158,6 +161,8 @@ async def reset(self) -> tuple[Messages, list[Tool]]: language=self.language, use_docker=self.use_docker, ) + if self.rerun_all_cells: + await self.run_notebook() self.tools = [ Tool.from_function(self.edit_cell), diff --git a/src/fhda/prompts.py b/src/fhda/prompts.py index 0105a5b..786ce04 100644 --- a/src/fhda/prompts.py +++ b/src/fhda/prompts.py @@ -238,3 +238,28 @@ {GENERAL_NOTEBOOK_GUIDELINES} {R_SPECIFIC_GUIDELINES} """ + +CONTINUATION_PROMPT_TEMPLATE = f""" +{GENERAL_NOTEBOOK_GUIDELINES} + +You have been provided with a notebook previously generated by an agent based on a user's research question. + +This was the user's research question: + +{{previous_research_question}} + + +This was the final answer generated by the previous agent: + +{{previous_final_answer}} + + +The user has now tasked you with addressing a new query: + +{{query}} + + +Please make any edits required to the notebook and the answer to address the new query. Be extremely diligent and ensure that the notebook is fully updated to address the new query. +Note you may have to run all cells one by one again if the user query involved updating one of the intermediate cells and subsequent cells depend on it. +Once you have updated the notebook, use the submit_answer tool to submit your final answer once the user's query is addressed. 
+""" diff --git a/src/fhda/tortoise.py b/src/fhda/tortoise.py index 26a75d3..d30ce76 100644 --- a/src/fhda/tortoise.py +++ b/src/fhda/tortoise.py @@ -7,11 +7,18 @@ import time import json from pydantic import BaseModel, Field +from tenacity import ( + retry, + stop_after_attempt, + wait_exponential, + retry_if_exception_type, +) from . import prompts from futurehouse_client import FutureHouseClient from futurehouse_client.models import TaskRequest, RuntimeConfig from futurehouse_client.models.app import AuthType +from futurehouse_client.clients.rest_client import TaskFetchError class StepConfig(BaseModel): @@ -95,7 +102,9 @@ class Tortoise: def __init__(self, api_key: str): """Initialize the tortoise framework with FutureHouse API key.""" - self.client = FutureHouseClient(auth_type=AuthType.API_KEY, api_key=api_key) + self.client = FutureHouseClient( + auth_type=AuthType.API_KEY, api_key=api_key, verbose=True + ) self.steps: list[Step] = [] self.results: dict[str, Any] = {} @@ -119,6 +128,33 @@ def save_results(self, output_dir: str | PathLike = "output") -> None: except Exception as e: print(f"Error saving results to {results_path}: {e}") + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=2, max=10), + retry=retry_if_exception_type(Exception), + ) + def _upload_file_with_retry( + self, job_name: str, file_path: str, upload_id: str + ) -> None: + """Upload a file with retry logic.""" + self.client.upload_file(job_name, file_path=file_path, upload_id=upload_id) + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=2, max=10), + retry=retry_if_exception_type(Exception), + ) + def _download_file_with_retry( + self, job_name: str, trajectory_id: str, file_path: str, destination_path: str + ) -> None: + """Download a file with retry logic.""" + self.client.download_file( + job_name, + trajectory_id=trajectory_id, + file_path=file_path, + destination_path=destination_path, + ) + def _create_task_requests( self, 
step: Step, runtime_config: RuntimeConfig ) -> list[TaskRequest]: @@ -165,6 +201,23 @@ def _create_task_requests( return task_requests + @retry( + stop=stop_after_attempt(5), + wait=wait_exponential(multiplier=1, min=2, max=30), + retry=retry_if_exception_type((Exception, TaskFetchError)), + ) + async def _run_tasks_with_retry( + self, task_requests, progress_bar, verbose, timeout + ): + """Run tasks with retry logic.""" + return await self.client.arun_tasks_until_done( + task_requests, + progress_bar=progress_bar, + verbose=verbose, + timeout=timeout, + concurrency=1, # Reduce concurrency to avoid overwhelming the server + ) + async def run_pipeline( self, output_dir: str | PathLike = "output" ) -> dict[str, Any]: @@ -178,9 +231,15 @@ async def run_pipeline( for source_path, dest_name in step.input_files.items(): print(f"Uploading file {source_path} as {dest_name}") - self.client.upload_file( - step.name, file_path=source_path, upload_id=step.upload_id - ) + try: + self._upload_file_with_retry( + step.name, file_path=source_path, upload_id=step.upload_id + ) + except Exception as e: + print( + f"Failed to upload file {source_path} after multiple retries: {e}" + ) + raise if step.config: runtime_config = RuntimeConfig( @@ -199,12 +258,25 @@ async def run_pipeline( print( f"Running {len(task_requests)} task{'s' if len(task_requests) > 1 else ''}" ) - task_responses = await self.client.arun_tasks_until_done( - task_requests, - progress_bar=True, - verbose=False, - timeout=step.config.timeout, - ) + try: + task_responses = await self._run_tasks_with_retry( + task_requests, + progress_bar=True, + verbose=False, + timeout=step.config.timeout, + ) + except Exception as e: + print( + f"Failed to run tasks for step {step.step_id} after multiple retries: {e}" + ) + # Create an error result entry and continue to the next step + self.results[step.step_id] = { + "task_ids": [], + "task_responses": [], + "success_rate": 0, + "error": str(e), + } + continue task_ids = 
[str(task.task_id) for task in task_responses] success_rate = sum( @@ -236,12 +308,17 @@ async def run_pipeline( os.path.dirname(os.path.abspath(path)), exist_ok=True ) print(f"Downloading file {source_name} to {path}") - self.client.download_file( - step.name, - trajectory_id=task_id, - file_path=source_name, - destination_path=path, - ) + try: + self._download_file_with_retry( + step.name, + trajectory_id=task_id, + file_path=source_name, + destination_path=path, + ) + except Exception as e: + print( + f"Failed to download {source_name} from task {task_id} after multiple retries: {e}" + ) except Exception as e: print( f"Error downloading {source_name} from task {task_id}: {e}" diff --git a/src/fhda/utils.py b/src/fhda/utils.py index f06f3e6..7d1532e 100644 --- a/src/fhda/utils.py +++ b/src/fhda/utils.py @@ -5,6 +5,7 @@ from enum import StrEnum, auto from typing import TYPE_CHECKING, assert_never import os +import re import nbformat from traitlets.config import Config @@ -438,3 +439,23 @@ def nb_to_html(nb: nbformat.NotebookNode) -> str: exporter = HTMLExporter(config=c) html, _ = exporter.from_notebook_node(nb) return html + + +def extract_xml_content(text, tag_name): + """ + Extract content between XML-like tags from text. 
+ + Args: + text (str): The text containing XML-like tags + tag_name (str): The name of the tag to extract content from + + Returns: + str or None: The content between the tags, or None if not found + """ + + pattern = f"<{tag_name}>(.*?)</{tag_name}>" + match = re.search(pattern, text, re.DOTALL) + + if match: + return match.group(1).strip() + return None diff --git a/src/scripts/deploy.py b/src/scripts/deploy.py index 8b8a848..186d73c 100644 --- a/src/scripts/deploy.py +++ b/src/scripts/deploy.py @@ -13,7 +13,7 @@ from futurehouse_client.models.app import TaskQueuesConfig HIGH = True -ENVIRONMENT = "DEV" +ENVIRONMENT = "PROD" ENV_VARS = { # "OPENAI_API_KEY": os.environ["OPENAI_API_KEY"], @@ -21,7 +21,7 @@ "USE_DOCKER": "false", "STAGE": ENVIRONMENT, "ENVIRONMENT": ENVIRONMENT, - "API_KEY": os.environ[f"CROW_API_KEY_{ENVIRONMENT}"], + "FH_API_KEY": os.environ[f"CROW_API_KEY_{ENVIRONMENT}"], } CONTAINER_CONFIG = DockerContainerConfiguration(cpu="8", memory="16Gi")