Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions docs/api/rlm.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ RLM(
max_tokens: int | None = None,
max_errors: int | None = None,
custom_system_prompt: str | None = None,
tool_prompts: str | list[str] | None = None,
tool_code: str | None = None,
other_backends: list[str] | None = None,
other_backend_kwargs: list[dict] | None = None,
logger: RLMLogger | None = None,
Expand Down Expand Up @@ -267,6 +269,50 @@ rlm = RLM(..., custom_system_prompt=custom_prompt)

---

#### `tool_prompts`
{: .no_toc }

**Type:** `str | list[str] | None`
**Default:** `None`

Extra prompt section(s) appended to the system prompt (without replacing the default) so the root LM knows about extra helper functions/tools you provide. Accepts a single string or a list of strings.

```python
rlm = RLM(
...,
tool_prompts=[
"Extra tools available:\n- fetch(url: str) -> str: Fetch a URL and return text\n- summarize(text: str) -> str: Summarize text",
],
)
```

---

#### `tool_code`
{: .no_toc }

**Type:** `str | None`
**Default:** `None`

Python code executed in the environment before iterations. This is merged into `environment_kwargs["setup_code"]` so your helper functions exist in the REPL namespace. The code must be valid Python syntax (validated at initialization).

**Execution order:** When both `tool_code` and `environment_kwargs["setup_code"]` are provided, `setup_code` runs first, then `tool_code` is appended after it. This means `tool_code` can reference variables or imports defined in `setup_code`.

**Persistent mode:** When `persistent=True`, `tool_code` is injected only when the environment is first created. Subsequent calls reuse the persistent environment, which already has the tool functions available.

```python
rlm = RLM(
...,
tool_prompts=["You can call add(a, b) to add two numbers."],
tool_code="""
def add(a, b):
    return a + b
""",
)
```

---

#### `other_backends` / `other_backend_kwargs`
{: .no_toc }

Expand Down
33 changes: 33 additions & 0 deletions rlm/core/rlm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ast
import time
from collections.abc import Callable
from contextlib import contextmanager
Expand Down Expand Up @@ -32,6 +33,7 @@
from rlm.utils.prompts import (
RLM_SYSTEM_PROMPT,
QueryMetadata,
append_prompt_sections,
build_rlm_system_prompt,
build_user_prompt,
)
Expand Down Expand Up @@ -61,6 +63,8 @@ def __init__(
max_tokens: int | None = None,
max_errors: int | None = None,
custom_system_prompt: str | None = None,
tool_prompts: str | list[str] | None = None,
tool_code: str | None = None,
other_backends: list[ClientBackend] | None = None,
other_backend_kwargs: list[dict[str, Any]] | None = None,
logger: RLMLogger | None = None,
Expand Down Expand Up @@ -89,6 +93,10 @@ def __init__(
max_tokens: Maximum total tokens (input + output). Execution stops if exceeded, returning best answer if available.
max_errors: Maximum consecutive errors before stopping. Execution stops if exceeded, returning best answer if available.
custom_system_prompt: The custom system prompt to use for the RLM.
tool_prompts: Extra prompt section(s) appended to the system prompt (e.g., tool/function descriptions).
Accepts a single string or a list of strings.
tool_code: Python code executed in the environment before iterations. Merged into
environment_kwargs['setup_code']: existing setup_code runs first, then tool_code.
other_backends: A list of other client backends that the environments can use to make sub-calls.
other_backend_kwargs: The kwargs to pass to the other client backends (ordered to match other_backends).
logger: The logger to use for the RLM.
Expand Down Expand Up @@ -141,6 +149,19 @@ def __init__(
self.max_tokens = max_tokens
self.max_errors = max_errors
self.system_prompt = custom_system_prompt if custom_system_prompt else RLM_SYSTEM_PROMPT

if isinstance(tool_prompts, str):
tool_prompts = [tool_prompts]

base_system_prompt = custom_system_prompt if custom_system_prompt else RLM_SYSTEM_PROMPT
self.system_prompt = append_prompt_sections(base_system_prompt, tool_prompts)

if tool_code is not None:
try:
ast.parse(tool_code)
except SyntaxError as e:
raise ValueError(f"tool_code contains invalid Python syntax: {e}") from e
self.tool_code = tool_code
self.logger = logger
self.verbose = VerbosePrinter(enabled=verbose)

Expand Down Expand Up @@ -238,6 +259,18 @@ def _spawn_completion_context(self, prompt: str | dict[str, Any]):
env_kwargs["custom_sub_tools"] = self.custom_sub_tools
if self.compaction and self.environment_type == "local":
env_kwargs["compaction"] = True

# Merge tool_code into setup_code so user-defined functions exist in the REPL namespace.
# This only runs on environment creation; persistent envs inherit it on subsequent calls.
if self.tool_code and self.tool_code.strip():
existing_setup_code = env_kwargs.get("setup_code")
if existing_setup_code and existing_setup_code.strip():
env_kwargs["setup_code"] = (
existing_setup_code.rstrip() + "\n\n" + self.tool_code.lstrip()
)
else:
env_kwargs["setup_code"] = self.tool_code

environment: BaseEnv = get_environment(self.environment_type, env_kwargs)

if self.persistent:
Expand Down
14 changes: 14 additions & 0 deletions rlm/utils/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,20 @@

from rlm.core.types import QueryMetadata


def append_prompt_sections(base_prompt: str, extra_sections: str | list[str] | None) -> str:
if not extra_sections:
return base_prompt

non_empty = [section for section in extra_sections if section and section.strip()]
cleaned_sections = [section.strip() for section in non_empty]
if not cleaned_sections:
return base_prompt

# Keep the existing prompt verbatim; just append extra sections.
return base_prompt.rstrip() + "\n\n" + "\n\n".join(cleaned_sections) + "\n"


# System prompt for the REPL environment with explicit final answer checking
RLM_SYSTEM_PROMPT = textwrap.dedent(
"""You are tasked with answering a query with associated context. You can access, transform, and analyze this context interactively in a REPL environment that can recursively query sub-LLMs, which you are strongly encouraged to use as much as possible. You will be queried iteratively until you provide a final answer.
Expand Down
172 changes: 172 additions & 0 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
"""Tests for tool_prompts and tool_code functionality."""

import pytest

from rlm.environments.local_repl import LocalREPL
from rlm.utils.prompts import RLM_SYSTEM_PROMPT, append_prompt_sections


class TestAppendPromptSections:
    """Behavior of the append_prompt_sections helper."""

    def test_none_returns_base(self):
        """None means there is nothing to append."""
        combined = append_prompt_sections("base", None)
        assert combined == "base"

    def test_empty_list_returns_base(self):
        """An empty list means there is nothing to append."""
        combined = append_prompt_sections("base", [])
        assert combined == "base"

    def test_list_of_empty_strings_returns_base(self):
        """Blank and whitespace-only entries are all discarded."""
        blanks = ["", " ", ""]
        assert append_prompt_sections("base", blanks) == "base"

    def test_single_section(self):
        """One section is separated from the base by a blank line."""
        expected = "base\n\nextra info\n"
        assert append_prompt_sections("base", ["extra info"]) == expected

    def test_multiple_sections(self):
        """Sections are separated from each other by a blank line."""
        expected = "base\n\nsection1\n\nsection2\n"
        assert append_prompt_sections("base", ["section1", "section2"]) == expected

    def test_strips_whitespace_from_sections(self):
        """Surrounding whitespace on a section is removed."""
        expected = "base\n\npadded\n"
        assert append_prompt_sections("base", [" padded "]) == expected

    def test_filters_out_empty_among_valid(self):
        """Blank entries are skipped without affecting their neighbours."""
        sections = ["valid", "", "also valid"]
        expected = "base\n\nvalid\n\nalso valid\n"
        assert append_prompt_sections("base", sections) == expected

    def test_preserves_base_prompt_content(self):
        """The default system prompt stays at the front of the result."""
        combined = append_prompt_sections(RLM_SYSTEM_PROMPT, ["tool hint"])
        expected_prefix = RLM_SYSTEM_PROMPT.rstrip()
        assert combined.startswith(expected_prefix)
        assert combined.endswith("tool hint\n")


class TestToolCodeValidation:
    """RLM.__init__ syntax-checks tool_code at construction time."""

    def test_invalid_tool_code_raises_value_error(self):
        """Unparseable tool_code is rejected with a ValueError."""
        from rlm.core.rlm import RLM

        init_kwargs = {
            "backend": "openai",
            "backend_kwargs": {"model_name": "gpt-4"},
            "tool_code": "def broken(",
        }
        with pytest.raises(ValueError, match="tool_code contains invalid Python syntax"):
            RLM(**init_kwargs)

    def test_none_tool_code_is_accepted(self):
        """Omitting tool_code is valid and is stored as None."""
        from rlm.core.rlm import RLM

        model = RLM(
            backend="openai",
            backend_kwargs={"model_name": "gpt-4"},
            tool_code=None,
        )
        assert model.tool_code is None

    def test_valid_tool_code_is_accepted(self):
        """Well-formed tool_code is stored unchanged."""
        from rlm.core.rlm import RLM

        source = "def add(a, b):\n return a + b\n"
        model = RLM(
            backend="openai",
            backend_kwargs={"model_name": "gpt-4"},
            tool_code=source,
        )
        assert model.tool_code == source


class TestToolPromptsNormalization:
    """tool_prompts may be given as a str, a list[str], or None."""

    def test_string_tool_prompts(self):
        """A single string ends up inside the system prompt."""
        from rlm.core.rlm import RLM

        model = RLM(
            backend="openai",
            backend_kwargs={"model_name": "gpt-4"},
            tool_prompts="You can call add(a, b).",
        )
        assert "You can call add(a, b)." in model.system_prompt

    def test_list_tool_prompts(self):
        """Every entry of a list ends up inside the system prompt."""
        from rlm.core.rlm import RLM

        model = RLM(
            backend="openai",
            backend_kwargs={"model_name": "gpt-4"},
            tool_prompts=["hint1", "hint2"],
        )
        for hint in ("hint1", "hint2"):
            assert hint in model.system_prompt

    def test_none_tool_prompts_uses_default(self):
        """Without tool prompts the default system prompt is used as-is."""
        from rlm.core.rlm import RLM

        model = RLM(
            backend="openai",
            backend_kwargs={"model_name": "gpt-4"},
            tool_prompts=None,
        )
        assert model.system_prompt == RLM_SYSTEM_PROMPT


class TestToolCodeMerge:
    """How tool_code is combined with setup_code in the environment kwargs."""

    def test_tool_code_only(self):
        """tool_code becomes setup_code when no existing setup_code."""
        from rlm.core.rlm import RLM

        model = RLM(
            backend="openai",
            backend_kwargs={"model_name": "gpt-4"},
            tool_code="x = 1",
        )
        # The merge itself only happens during a completion call, so this
        # test just checks that the attribute is stored for later use.
        assert model.tool_code == "x = 1"

    def test_both_setup_and_tool_code_order(self):
        """setup_code should run before tool_code when both are provided."""
        from rlm.core.rlm import RLM

        model = RLM(
            backend="openai",
            backend_kwargs={"model_name": "gpt-4"},
            environment_kwargs={"setup_code": "first = 1"},
            tool_code="second = 2",
        )

        # Mirror the merge logic that happens in _spawn_completion_context.
        env_kwargs = model.environment_kwargs.copy()
        tool_code = model.tool_code
        if tool_code and tool_code.strip():
            prior = env_kwargs.get("setup_code")
            if prior and prior.strip():
                combined = prior.rstrip() + "\n\n" + tool_code.lstrip()
            else:
                combined = tool_code
            env_kwargs["setup_code"] = combined

        merged = env_kwargs["setup_code"]
        setup_at = merged.index("first = 1")
        tool_at = merged.index("second = 2")
        assert setup_at < tool_at, "setup_code should come before tool_code"


class TestToolCodeInREPL:
    """End-to-end test: tool function defined via setup_code is callable."""

    def test_tool_function_callable_in_repl(self):
        """A function defined in setup_code can be called from executed code."""
        setup = "def add(a, b):\n return a + b\n"
        repl = LocalREPL(setup_code=setup)
        # Clean up in ``finally`` so a failing assertion does not leak the
        # environment into later tests.
        try:
            result = repl.execute_code("result = add(2, 3)\nprint(result)")
            assert result.stderr == ""
            assert "5" in result.stdout
            assert repl.locals["result"] == 5
        finally:
            repl.cleanup()

    def test_setup_code_then_tool_code_both_available(self):
        """Names from both halves of a combined setup_code are usable together."""
        combined = "base_val = 10\n\ndef multiply(a, b):\n return a * b\n"
        repl = LocalREPL(setup_code=combined)
        # Clean up in ``finally`` so a failing assertion does not leak the
        # environment into later tests.
        try:
            result = repl.execute_code("result = multiply(base_val, 3)\nprint(result)")
            assert result.stderr == ""
            assert "30" in result.stdout
        finally:
            repl.cleanup()