67 changes: 67 additions & 0 deletions README.md
@@ -128,3 +128,70 @@ or in your terminal via `chafa valid_molecule.svg`
([chafa docs](https://hpjansson.org/chafa/)).

![valid molecule](docs/assets/valid_molecule.svg)

### Benchmark

Here is a sample baseline of
[`ether0-benchmark`](https://huggingface.co/datasets/futurehouse/ether0-benchmark)
on `gpt-4o` using [`lmi`](https://github.com/Future-House/ldp/tree/main/packages/lmi).
To install `lmi`, please install `ether0` with the `baselines` extra
(for example `uv sync --extra baselines`).

We also need to run our remote rewards server via `ether0-serve`
(for more information, see [`ether0.remotes` docs](packages/remotes/README.md)):

```bash
ETHER0_REMOTES_API_TOKEN=abc123 ether0-serve
```

Next, start `ipython` with the relevant environment variables set:

```bash
ETHER0_REMOTES_API_BASE_URL="http://127.0.0.1:8000" ETHER0_REMOTES_API_TOKEN=abc123 ipython
```

Then run the following Python code (the top-level `await` works because IPython supports it natively):

```python
import itertools
import statistics
from collections import defaultdict

from aviary.core import Message
from datasets import load_dataset
from lmi import LiteLLMModel
from tqdm.asyncio import tqdm_asyncio as asyncio  # tqdm_asyncio.gather behaves like asyncio.gather, plus a progress bar

from ether0.data import get_problem_category
from ether0.model_prompts import LOOSE_XML_ANSWER_USER_PROMPT, extract_answer_loose
from ether0.models import RewardFunctionInfo
from ether0.rewards import EVAL_FUNCTIONS

# Add an LLM prompt of your making to each row of the dataset
test_ds = load_dataset("futurehouse/ether0-benchmark", split="test").map(
    lambda x: {"prompt": "\n\n".join((LOOSE_XML_ANSWER_USER_PROMPT, x["problem"]))}
)

# Send each prompt to the LLM
model = LiteLLMModel(name="gpt-4o")
results = await asyncio.gather(
    *(model.acompletion([Message(content=row["prompt"])]) for row in test_ds),
    desc="Running evaluation",
)

# Compute rewards per problem category
per_category_rewards = defaultdict(list)
for row, result in zip(test_ds, results, strict=True):
    reward_info = RewardFunctionInfo.model_validate(row["solution"])
    yhat = extract_answer_loose(result[0].text)
    reward = EVAL_FUNCTIONS[reward_info.fxn_name](yhat=yhat, y=reward_info.answer_info)
    per_category_rewards[get_problem_category(reward_info.problem_type)].append(reward)

for category, rewards in sorted(per_category_rewards.items()):
    print(
        f"In category {category!r} of {len(rewards)} questions,"
        f" average reward was {statistics.mean(rewards):.3f}."
    )
accuracy = statistics.mean(itertools.chain.from_iterable(per_category_rewards.values()))
print(f"Cumulative average reward across {len(test_ds)} questions was {accuracy:.3f}.")
```
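
If you want to keep this breakdown around (for example to compare models later), here is a minimal follow-on sketch that uses only the `per_category_rewards` mapping built above plus the standard library; the output filename is just an illustration:

```python
import json
from pathlib import Path

# Summarize each category with its question count and mean reward
summary = {
    str(category): {"num_questions": len(rewards), "mean_reward": statistics.mean(rewards)}
    for category, rewards in sorted(per_category_rewards.items())
}
Path("ether0_benchmark_gpt-4o.json").write_text(json.dumps(summary, indent=2))
```
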
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -49,6 +49,11 @@ add-tokens = [
"ipywidgets>=8", # For Jupyter notebook support, and pin to keep recent
"transformers>=4.49", # Pin to keep recent
]
baselines = [
"fhaviary>=0.19", # Pin for Python 3.13 compatibility
"fhlmi>=0.26", # Pin for Python 3.13 compatibility
"ipython",
]
dev = [
"ether0[add-tokens,typing]",
"huggingface-hub[cli]", # For login inside of CI
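
The new `baselines` extra above is what the README's `uv sync --extra baselines` command pulls in. An equivalent pip-based sketch, assuming an editable checkout of this repository:

```bash
pip install -e ".[baselines]"
```
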
8 changes: 8 additions & 0 deletions src/ether0/model_prompts.py
@@ -118,6 +118,14 @@ def extract_thought_answer_strict(
return None, None # Consider nested answer as a failure


LOOSE_XML_ANSWER_USER_PROMPT = (
    "When answering,"
    " be sure to place the final answer as"
    " SMILES notation into XML tags <answer></answer>."
    " An example is <answer>CCO</answer>."
)


def extract_answer_loose(text: str | None) -> str:
"""
Extract thought and answer from text using a loose XML pattern.
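
A minimal usage sketch of the new prompt constant and loose extractor (the problem text and completion below are made up for illustration; per the README snippet above, `extract_answer_loose` is expected to return the SMILES string inside the `<answer>` tags):

```python
from ether0.model_prompts import LOOSE_XML_ANSWER_USER_PROMPT, extract_answer_loose

# Build a prompt the same way the README benchmark snippet does (hypothetical problem text)
prompt = "\n\n".join((LOOSE_XML_ANSWER_USER_PROMPT, "Propose a molecule with the formula C2H6O."))

# A model completion that follows the prompt's XML convention
completion = "Ethanol matches that formula. <answer>CCO</answer>"
print(extract_answer_loose(completion))  # expected output: CCO
```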