From 04368b2d2ccc29938ce711d5302209fe05c512e3 Mon Sep 17 00:00:00 2001 From: rbotnari Date: Thu, 19 Feb 2026 00:26:46 +0200 Subject: [PATCH 1/2] add extra user defined tools --- docs/api/rlm.md | 46 ++++++++++++ rlm/core/rlm.py | 31 ++++++++ rlm/utils/prompts.py | 13 ++++ tests/test_tools.py | 172 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 262 insertions(+) create mode 100644 tests/test_tools.py diff --git a/docs/api/rlm.md b/docs/api/rlm.md index 383f2cb2..42bdb098 100644 --- a/docs/api/rlm.md +++ b/docs/api/rlm.md @@ -50,6 +50,8 @@ RLM( max_tokens: int | None = None, max_errors: int | None = None, custom_system_prompt: str | None = None, + tool_prompts: str | list[str] | None = None, + tool_code: str | None = None, other_backends: list[str] | None = None, other_backend_kwargs: list[dict] | None = None, logger: RLMLogger | None = None, @@ -267,6 +269,50 @@ rlm = RLM(..., custom_system_prompt=custom_prompt) --- +#### `tool_prompts` +{: .no_toc } + +**Type:** `str | list[str] | None` +**Default:** `None` + +Extra prompt section(s) appended to the system prompt (without replacing the default) so the root LM knows about extra helper functions/tools you provide. Accepts a single string or a list of strings. + +```python +rlm = RLM( + ..., + tool_prompts=[ + "Extra tools available:\n- fetch(url: str) -> str: Fetch a URL and return text\n- summarize(text: str) -> str: Summarize text", + ], +) +``` + +--- + +#### `tool_code` +{: .no_toc } + +**Type:** `str | None` +**Default:** `None` + +Python code executed in the environment before iterations. This is merged into `environment_kwargs["setup_code"]` so your helper functions exist in the REPL namespace. The code must be valid Python syntax (validated at initialization). + +**Execution order:** When both `tool_code` and `environment_kwargs["setup_code"]` are provided, `setup_code` runs first, then `tool_code` is appended after it. This means `tool_code` can reference variables or imports defined in `setup_code`. + +**Persistent mode:** When `persistent=True`, `tool_code` is injected only on the first environment creation. Subsequent calls reuse the persistent environment which already has the tool functions available. + +```python +rlm = RLM( + ..., + tool_prompts=["You can call add(a, b) to add two numbers."], + tool_code=""" +def add(a, b): + return a + b +""", +) +``` + +--- + #### `other_backends` / `other_backend_kwargs` {: .no_toc } diff --git a/rlm/core/rlm.py b/rlm/core/rlm.py index ce247ecd..387a7a11 100644 --- a/rlm/core/rlm.py +++ b/rlm/core/rlm.py @@ -1,3 +1,4 @@ +import ast import time from collections.abc import Callable from contextlib import contextmanager @@ -32,6 +33,7 @@ from rlm.utils.prompts import ( RLM_SYSTEM_PROMPT, QueryMetadata, + append_prompt_sections, build_rlm_system_prompt, build_user_prompt, ) @@ -61,6 +63,8 @@ def __init__( max_tokens: int | None = None, max_errors: int | None = None, custom_system_prompt: str | None = None, + tool_prompts: str | list[str] | None = None, + tool_code: str | None = None, other_backends: list[ClientBackend] | None = None, other_backend_kwargs: list[dict[str, Any]] | None = None, logger: RLMLogger | None = None, @@ -89,6 +93,10 @@ def __init__( max_tokens: Maximum total tokens (input + output). Execution stops if exceeded, returning best answer if available. max_errors: Maximum consecutive errors before stopping. Execution stops if exceeded, returning best answer if available. custom_system_prompt: The custom system prompt to use for the RLM. + tool_prompts: Extra prompt section(s) appended to the system prompt (e.g., tool/function descriptions). + Accepts a single string or a list of strings. + tool_code: Python code executed in the environment before iterations. Merged into + environment_kwargs['setup_code']: existing setup_code runs first, then tool_code. other_backends: A list of other client backends that the environments can use to make sub-calls. other_backend_kwargs: The kwargs to pass to the other client backends (ordered to match other_backends). logger: The logger to use for the RLM. @@ -141,6 +149,19 @@ def __init__( self.max_tokens = max_tokens self.max_errors = max_errors self.system_prompt = custom_system_prompt if custom_system_prompt else RLM_SYSTEM_PROMPT + + if isinstance(tool_prompts, str): + tool_prompts = [tool_prompts] + + base_system_prompt = custom_system_prompt if custom_system_prompt else RLM_SYSTEM_PROMPT + self.system_prompt = append_prompt_sections(base_system_prompt, tool_prompts) + + if tool_code is not None: + try: + ast.parse(tool_code) + except SyntaxError as e: + raise ValueError(f"tool_code contains invalid Python syntax: {e}") from e + self.tool_code = tool_code self.logger = logger self.verbose = VerbosePrinter(enabled=verbose) @@ -238,6 +259,16 @@ def _spawn_completion_context(self, prompt: str | dict[str, Any]): env_kwargs["custom_sub_tools"] = self.custom_sub_tools if self.compaction and self.environment_type == "local": env_kwargs["compaction"] = True + + # Merge tool_code into setup_code so user-defined functions exist in the REPL namespace. + # This only runs on environment creation; persistent envs inherit it on subsequent calls. + if self.tool_code and self.tool_code.strip(): + existing_setup_code = env_kwargs.get("setup_code") + if existing_setup_code and existing_setup_code.strip(): + env_kwargs["setup_code"] = existing_setup_code.rstrip() + "\n\n" + self.tool_code.lstrip() + else: + env_kwargs["setup_code"] = self.tool_code + environment: BaseEnv = get_environment(self.environment_type, env_kwargs) if self.persistent: diff --git a/rlm/utils/prompts.py b/rlm/utils/prompts.py index 3add902e..f09a17bc 100644 --- a/rlm/utils/prompts.py +++ b/rlm/utils/prompts.py @@ -3,6 +3,19 @@ from rlm.core.types import QueryMetadata + +def append_prompt_sections(base_prompt: str, extra_sections: str | list[str] | None) -> str: + if not extra_sections: + return base_prompt + + non_empty = [section for section in extra_sections if section and section.strip()] + cleaned_sections = [section.strip() for section in non_empty] + if not cleaned_sections: + return base_prompt + + # Keep the existing prompt verbatim; just append extra sections. + return base_prompt.rstrip() + "\n\n" + "\n\n".join(cleaned_sections) + "\n" + # System prompt for the REPL environment with explicit final answer checking RLM_SYSTEM_PROMPT = textwrap.dedent( """You are tasked with answering a query with associated context. You can access, transform, and analyze this context interactively in a REPL environment that can recursively query sub-LLMs, which you are strongly encouraged to use as much as possible. You will be queried iteratively until you provide a final answer. diff --git a/tests/test_tools.py b/tests/test_tools.py new file mode 100644 index 00000000..6580db8c --- /dev/null +++ b/tests/test_tools.py @@ -0,0 +1,172 @@ +"""Tests for tool_prompts and tool_code functionality.""" + +import pytest + +from rlm.environments.local_repl import LocalREPL +from rlm.utils.prompts import RLM_SYSTEM_PROMPT, append_prompt_sections + + +class TestAppendPromptSections: + """Tests for append_prompt_sections helper.""" + + def test_none_returns_base(self): + assert append_prompt_sections("base", None) == "base" + + def test_empty_list_returns_base(self): + assert append_prompt_sections("base", []) == "base" + + def test_list_of_empty_strings_returns_base(self): + assert append_prompt_sections("base", ["", " ", ""]) == "base" + + def test_single_section(self): + result = append_prompt_sections("base", ["extra info"]) + assert result == "base\n\nextra info\n" + + def test_multiple_sections(self): + result = append_prompt_sections("base", ["section1", "section2"]) + assert result == "base\n\nsection1\n\nsection2\n" + + def test_strips_whitespace_from_sections(self): + result = append_prompt_sections("base", [" padded "]) + assert result == "base\n\npadded\n" + + def test_filters_out_empty_among_valid(self): + result = append_prompt_sections("base", ["valid", "", "also valid"]) + assert result == "base\n\nvalid\n\nalso valid\n" + + def test_preserves_base_prompt_content(self): + result = append_prompt_sections(RLM_SYSTEM_PROMPT, ["tool hint"]) + assert result.startswith(RLM_SYSTEM_PROMPT.rstrip()) + assert result.endswith("tool hint\n") + + +class TestToolCodeValidation: + """Tests for tool_code syntax validation in RLM.__init__.""" + + def test_invalid_tool_code_raises_value_error(self): + from rlm.core.rlm import RLM + + with pytest.raises(ValueError, match="tool_code contains invalid Python syntax"): + RLM( + backend="openai", + backend_kwargs={"model_name": "gpt-4"}, + tool_code="def broken(", + ) + + def test_none_tool_code_is_accepted(self): + from rlm.core.rlm import RLM + + rlm = RLM( + backend="openai", + backend_kwargs={"model_name": "gpt-4"}, + tool_code=None, + ) + assert rlm.tool_code is None + + def test_valid_tool_code_is_accepted(self): + from rlm.core.rlm import RLM + + code = "def add(a, b):\n return a + b\n" + rlm = RLM( + backend="openai", + backend_kwargs={"model_name": "gpt-4"}, + tool_code=code, + ) + assert rlm.tool_code == code + + +class TestToolPromptsNormalization: + """Tests for tool_prompts accepting str | list[str] | None.""" + + def test_string_tool_prompts(self): + from rlm.core.rlm import RLM + + rlm = RLM( + backend="openai", + backend_kwargs={"model_name": "gpt-4"}, + tool_prompts="You can call add(a, b).", + ) + assert "You can call add(a, b)." in rlm.system_prompt + + def test_list_tool_prompts(self): + from rlm.core.rlm import RLM + + rlm = RLM( + backend="openai", + backend_kwargs={"model_name": "gpt-4"}, + tool_prompts=["hint1", "hint2"], + ) + assert "hint1" in rlm.system_prompt + assert "hint2" in rlm.system_prompt + + def test_none_tool_prompts_uses_default(self): + from rlm.core.rlm import RLM + + rlm = RLM( + backend="openai", + backend_kwargs={"model_name": "gpt-4"}, + tool_prompts=None, + ) + assert rlm.system_prompt == RLM_SYSTEM_PROMPT + + +class TestToolCodeMerge: + """Tests for tool_code merging with setup_code in environment kwargs.""" + + def test_tool_code_only(self): + """tool_code becomes setup_code when no existing setup_code.""" + from rlm.core.rlm import RLM + + rlm = RLM( + backend="openai", + backend_kwargs={"model_name": "gpt-4"}, + tool_code="x = 1", + ) + # We can't easily test the merge without calling completion, so + # verify the attribute is stored and will be used. + assert rlm.tool_code == "x = 1" + + def test_both_setup_and_tool_code_order(self): + """setup_code should run before tool_code when both are provided.""" + from rlm.core.rlm import RLM + + rlm = RLM( + backend="openai", + backend_kwargs={"model_name": "gpt-4"}, + environment_kwargs={"setup_code": "first = 1"}, + tool_code="second = 2", + ) + # Simulate the merge logic that happens in _spawn_completion_context + env_kwargs = rlm.environment_kwargs.copy() + if rlm.tool_code and rlm.tool_code.strip(): + existing = env_kwargs.get("setup_code") + if existing and existing.strip(): + env_kwargs["setup_code"] = existing.rstrip() + "\n\n" + rlm.tool_code.lstrip() + else: + env_kwargs["setup_code"] = rlm.tool_code + + merged = env_kwargs["setup_code"] + first_pos = merged.index("first = 1") + second_pos = merged.index("second = 2") + assert first_pos < second_pos, "setup_code should come before tool_code" + + +class TestToolCodeInREPL: + """End-to-end test: tool function defined via setup_code is callable.""" + + def test_tool_function_callable_in_repl(self): + setup = "def add(a, b):\n return a + b\n" + repl = LocalREPL(setup_code=setup) + result = repl.execute_code("result = add(2, 3)\nprint(result)") + assert result.stderr == "" + assert "5" in result.stdout + assert repl.locals["result"] == 5 + repl.cleanup() + + def test_setup_code_then_tool_code_both_available(self): + combined = "base_val = 10\n\ndef multiply(a, b):\n return a * b\n" + repl = LocalREPL(setup_code=combined) + result = repl.execute_code("result = multiply(base_val, 3)\nprint(result)") + assert result.stderr == "" + assert "30" in result.stdout + repl.cleanup() From f7ef2babb53394c4146c8fef5ccbb696d3b7055d Mon Sep 17 00:00:00 2001 From: rbotnari Date: Thu, 19 Feb 2026 08:32:04 +0200 Subject: [PATCH 2/2] fix formatting --- rlm/core/rlm.py | 4 +++- rlm/utils/prompts.py | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/rlm/core/rlm.py b/rlm/core/rlm.py index 387a7a11..f281b928 100644 --- a/rlm/core/rlm.py +++ b/rlm/core/rlm.py @@ -265,7 +265,9 @@ def _spawn_completion_context(self, prompt: str | dict[str, Any]): if self.tool_code and self.tool_code.strip(): existing_setup_code = env_kwargs.get("setup_code") if existing_setup_code and existing_setup_code.strip(): - env_kwargs["setup_code"] = existing_setup_code.rstrip() + "\n\n" + self.tool_code.lstrip() + env_kwargs["setup_code"] = ( + existing_setup_code.rstrip() + "\n\n" + self.tool_code.lstrip() + ) else: env_kwargs["setup_code"] = self.tool_code diff --git a/rlm/utils/prompts.py b/rlm/utils/prompts.py index f09a17bc..c4da426a 100644 --- a/rlm/utils/prompts.py +++ b/rlm/utils/prompts.py @@ -16,6 +16,7 @@ def append_prompt_sections(base_prompt: str, extra_sections: str | list[str] | N # Keep the existing prompt verbatim; just append extra sections. return base_prompt.rstrip() + "\n\n" + "\n\n".join(cleaned_sections) + "\n" + # System prompt for the REPL environment with explicit final answer checking RLM_SYSTEM_PROMPT = textwrap.dedent( """You are tasked with answering a query with associated context. You can access, transform, and analyze this context interactively in a REPL environment that can recursively query sub-LLMs, which you are strongly encouraged to use as much as possible. You will be queried iteratively until you provide a final answer.