Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ repos:
rev: 25.1.0
hooks:
- id: black-jupyter
types: [jupyter]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
Expand Down
35 changes: 14 additions & 21 deletions src/fhda/Dockerfile.custom_deployment
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ WORKDIR /app
ENV PYTHONUNBUFFERED=1
ENV DEBIAN_FRONTEND=noninteractive

RUN --mount=type=cache,target=/var/cache/apt \
apt-get update -qq && \
RUN apt-get update -qq && \
apt-get install -yq --no-install-recommends \
git \
openssh-client \
Expand All @@ -28,15 +27,13 @@ ENV PATH="/app/miniconda/bin:$PATH"
ENV PYTHONPATH="/app/miniconda/lib/python3.12/site-packages:${PYTHONPATH:-}"

# Install uv & mamba
RUN --mount=type=cache,target=/root/.cache/pip \
pip3 install --no-cache-dir uv==0.5.21
RUN --mount=type=cache,target=/app/miniconda/pkgs \
conda install -c conda-forge mamba -y
RUN pip3 install --no-cache-dir uv==0.5.21
RUN conda install -c conda-forge mamba -y

# Install R and kernels in the crow_env environment
RUN --mount=type=cache,target=/app/miniconda/pkgs \
mamba install -c conda-forge -y \
RUN mamba install -c conda-forge -y \
r-base=4.3.3 \
r-r.utils=2.13.0 \
r-recommended=4.3 \
r-irkernel=1.3.2 \
r-factominer=2.11 \
Expand Down Expand Up @@ -86,13 +83,10 @@ RUN --mount=type=cache,target=/app/miniconda/pkgs \
statsmodels=0.14.4 \
umap-learn=0.5.7

RUN --mount=type=cache,target=/app/miniconda/pkgs \
python -m ipykernel install --user --name python3 --display-name "Python 3 (ipykernel)"
RUN --mount=type=cache,target=/app/miniconda/pkgs \
R -e 'IRkernel::installspec(name = "R", displayname = "R (4.3.3)")'
RUN python -m ipykernel install --user --name python3 --display-name "Python 3 (ipykernel)"
RUN R -e 'IRkernel::installspec(name = "R", displayname = "R (4.3.3)")'

RUN --mount=type=cache,target=/app/miniconda/pkgs \
mamba install -c conda-forge -c bioconda -y \
RUN mamba install -c conda-forge -c bioconda -y \
biokit=0.5.0 \
gseapy=1.1.4 \
blast=2.16.0 \
Expand All @@ -116,7 +110,9 @@ RUN --mount=type=cache,target=/app/miniconda/pkgs \
bioconductor-summarizedexperiment=1.32.0 \
bioconductor-apeglm=1.24.0 \
bioconductor-flowcore=2.14.0 \
bioconductor-flowmeans=1.62.0
bioconductor-flowmeans=1.62.0 \
bioconductor-limma=3.58.1 \
bioconductor-geoquery=2.70.0

ENV UV_COMPILE_BYTECODE=1
ENV UV_LINK_MODE=copy
Expand All @@ -131,8 +127,7 @@ FROM base AS builder

ARG MODULE_NAME

RUN --mount=type=cache,target=/var/cache/apt \
apt-get update -qq && \
RUN apt-get update -qq && \
apt-get install -yq --no-install-recommends \
build-essential && \
apt-get clean && rm -rf /var/lib/apt/lists/*
Expand All @@ -147,9 +142,7 @@ COPY ./scripts/run_crow_job.py /app/scripts/

# Install application dependencies (this will only rerun when code changes)
WORKDIR /app/${MODULE_NAME}
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=cache,target=/app/miniconda/pkgs \
if [ -f "pyproject.toml" ]; then \
RUN if [ -f "pyproject.toml" ]; then \
uv pip install --system -e .; \
elif [ -f "requirements.txt" ]; then \
uv pip install --system -r requirements.txt; \
Expand All @@ -167,4 +160,4 @@ COPY --from=builder /app/ /app/
ENV VIRTUAL_ENV="/app/miniconda/bin"
ENV PATH="/app/miniconda/bin:$PATH"
ENV PYTHONPATH="/app/miniconda/lib/python3.12/site-packages:${PYTHONPATH:-}"
CMD ["python", "scripts/run_crow_job.py"]
CMD ["python", "scripts/run_crow_job.py"]
2 changes: 1 addition & 1 deletion src/fhda/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@
# FutureHouse client config
ENVIRONMENT = os.getenv("ENVIRONMENT", "prod")
CROW_STAGE = getattr(Stage, ENVIRONMENT.upper(), Stage.PROD)
PLATFORM_API_KEY = os.getenv("CROW_API_KEY", None)
PLATFORM_API_KEY = os.getenv("FH_API_KEY", None)
82 changes: 71 additions & 11 deletions src/fhda/data_analysis_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from futurehouse_client import FutureHouseClient

from .notebook_env import NBEnvironment
from .utils import NBLanguage, MultipleChoiceQuestion
from .utils import NBLanguage, MultipleChoiceQuestion, extract_xml_content
from . import prompts
from . import config as cfg

Expand Down Expand Up @@ -174,6 +174,7 @@ def from_task(
trajectory_id: str | None = None,
user_id: str | None = None,
environment_config: dict[str, Any] | None = None,
continued_trajectory_id: str | None = None,
) -> "DataAnalysisEnv":
"""
Perform data analysis on a user query.
Expand All @@ -188,18 +189,21 @@ def from_task(
logger.info("environment_config: %s", environment_config)
logger.info("trajectory_id: %s", trajectory_id)
logger.info("user_id: %s", user_id)
# Track cost of running the environment
logger.info("continued_trajectory_id: %s", continued_trajectory_id)
enable_cost_tracking()

if (
not gcs_artifact_path
(not gcs_artifact_path) and not continued_trajectory_id
): # Platform jobs should always be associated with data from a GCS bucket
raise NotImplementedError(
"Running crow jobs without gcs_artifact_path is not supported"
)

if user_id is None:
logger.warning("No user_id provided, using default_user")
user_id = "default_user"
if trajectory_id is None:
logger.warning("No trajectory_id provided, using time-based id")
trajectory_id = f"{gcs_artifact_path}-{time.time()}"
if environment_config:
kwargs = {
Expand All @@ -214,11 +218,49 @@ def from_task(
trajectory_path = (
cfg.DATA_STORAGE_PATH / "user_trajectories" / user_id / trajectory_id
)
if environment_config.get("gcs_override", False):
data_path = cfg.DATA_STORAGE_PATH / gcs_artifact_path

if continued_trajectory_id:
kwargs["rerun_all_cells"] = True
data_path = (
cfg.DATA_STORAGE_PATH
/ "user_trajectories"
/ user_id
/ continued_trajectory_id
)
logger.info("Continuing trajectory from %s", continued_trajectory_id)
if cfg.PLATFORM_API_KEY is None:
logger.warning(
"Platform API key is not set, can't fetch previous trajectory"
)
previous_research_question = None
previous_final_answer = None
else:
logger.info("Fetching previous trajectory")
client = FutureHouseClient(
stage=cfg.CROW_STAGE,
auth_type=AuthType.API_KEY,
api_key=cfg.PLATFORM_API_KEY,
)
previous_trajectory = client.get_task(
continued_trajectory_id, verbose=True
)
previous_research_question = extract_xml_content(
previous_trajectory.query, "query"
)
previous_final_answer = previous_trajectory.environment_frame["state"][
"state"
]["answer"]
language = previous_trajectory.environment_frame["state"]["info"][
"language"
]
language = getattr(NBLanguage, language.upper())
kwargs["language"] = language

elif environment_config.get("gcs_override", False):
data_path = cfg.DATA_STORAGE_PATH / gcs_artifact_path # type: ignore
else:
data_path = (
cfg.DATA_STORAGE_PATH / "user_data" / user_id / gcs_artifact_path
cfg.DATA_STORAGE_PATH / "user_data" / user_id / gcs_artifact_path # type: ignore
)
logger.info("Trajectory path: %s", trajectory_path)
logger.info("Data path: %s", data_path)
Expand All @@ -230,12 +272,19 @@ def from_task(
shutil.copytree(item, trajectory_path / item.name, dirs_exist_ok=True)
logger.info("Filtered kwargs: %s", kwargs)

language = getattr(NBLanguage, environment_config.get("language", "PYTHON"))
# Overwrite the language in the kwargs with NBLanguage enum
kwargs["language"] = language
logger.info("Language: %s", language.name)
# If it's continued, we already have the language
if continued_trajectory_id:
logger.info(
"Language already set from previous trajectory notebook %s",
kwargs.get("language", None),
)
else:
language = getattr(NBLanguage, environment_config.get("language", "PYTHON"))
# Overwrite the language in the kwargs with NBLanguage enum
kwargs["language"] = language
logger.info("Language: %s", language.name)

if not environment_config.get("eval", False):
if not environment_config.get("eval", False) and not continued_trajectory_id:
logger.info(
"Platform job detected, augmenting user query with CoT instructions"
)
Expand All @@ -248,6 +297,17 @@ def from_task(
f"{task}\n"
f"</query>\n"
)
if continued_trajectory_id and not environment_config.get("eval", False):
logger.info(
"Continuation job detected, augmenting user query with continuation instructions"
)
task = prompts.CONTINUATION_PROMPT_TEMPLATE.format(
previous_research_question=previous_research_question,
previous_final_answer=previous_final_answer,
query=task,
language=kwargs.get("language", "PYTHON"),
)

nb_path = trajectory_path / NBEnvironment.NOTEBOOK_NAME
logger.info("NB path: %s", nb_path)

Expand Down
5 changes: 5 additions & 0 deletions src/fhda/notebook_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def __init__(
language: utils.NBLanguage = utils.NBLanguage.PYTHON,
allow_download_from_gcs: bool = False,
run_notebook_on_edit: bool = False,
rerun_all_cells: bool = False,
):
"""Initialize a notebook environment.

Expand All @@ -140,6 +141,7 @@ def __init__(
task requires data on GCS. Disabled by default.
run_notebook_on_edit: If True, the whole notebook will be rerun
after each edit. If False (default), only the cell that was edited will be rerun.
rerun_all_cells: If True, the whole notebook will be run at the beginning of the episode. This is for continued trajectories.
"""
self.work_dir = Path(work_dir)
self.nb_path = Path(nb_path) if nb_path else self.work_dir / self.NOTEBOOK_NAME
Expand All @@ -149,6 +151,7 @@ def __init__(
self.allow_download_from_gcs = allow_download_from_gcs
self.use_docker = cfg.USE_DOCKER
self.run_notebook_on_edit = run_notebook_on_edit
self.rerun_all_cells = rerun_all_cells

async def reset(self) -> tuple[Messages, list[Tool]]:
nb_path, work_dir = self._set_work_dir()
Expand All @@ -158,6 +161,8 @@ async def reset(self) -> tuple[Messages, list[Tool]]:
language=self.language,
use_docker=self.use_docker,
)
if self.rerun_all_cells:
await self.run_notebook()

self.tools = [
Tool.from_function(self.edit_cell),
Expand Down
25 changes: 25 additions & 0 deletions src/fhda/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,3 +238,28 @@
{GENERAL_NOTEBOOK_GUIDELINES}
{R_SPECIFIC_GUIDELINES}
"""

# Prompt template for a job that continues a previously generated notebook.
#
# This is an f-string, so GENERAL_NOTEBOOK_GUIDELINES is interpolated once, when
# the module is imported. The double-braced fields ({{previous_research_question}},
# {{previous_final_answer}}, {{query}}) collapse to single-brace placeholders and
# are left to be filled in later with str.format(...).
# NOTE(review): because the guidelines text is baked in before .format() runs,
# any literal braces inside GENERAL_NOTEBOOK_GUIDELINES would be interpreted as
# format fields by the later .format() call — confirm it contains none.
CONTINUATION_PROMPT_TEMPLATE = f"""
{GENERAL_NOTEBOOK_GUIDELINES}

You have been provided with a notebook previously generated by an agent based on a user's research question.

This was the user's research question:
<previous_research_question>
{{previous_research_question}}
</previous_research_question>

This was the final answer generated by the previous agent:
<previous_final_answer>
{{previous_final_answer}}
</previous_final_answer>

The user has now tasked you with addressing a new query:
<query>
{{query}}
</query>

Please make any edits required to the notebook and the answer to address the new query. Be extremely diligent and ensure that the notebook is fully updated to address the new query.
Note you may have to run all cells one by one again if the user query involved updating one of the intermediate cells and subsequent cells depend on it.
Once you have updated the notebook, use the submit_answer tool to submit your final answer once the user's query is addressed.
"""
Loading