diff --git a/README.md b/README.md
index 68ebfff..ef1b4aa 100644
--- a/README.md
+++ b/README.md
@@ -182,9 +182,13 @@ results = await asyncio.gather(
# Compute rewards
per_category_rewards = defaultdict(list)
for row, result in zip(test_ds, results, strict=True):
+ # NOTE: you can also use `ether0.rewards.accuracy_reward`,
+ # but we decided to go a bit "lower level" for this demo
reward_info = RewardFunctionInfo.model_validate(row["solution"])
yhat = extract_answer_loose(result[0].text)
- reward = EVAL_FUNCTIONS[reward_info.fxn_name](yhat=yhat, y=reward_info.answer_info)
+ reward = EVAL_FUNCTIONS[reward_info.fxn_name](
+ yhat=yhat, y=reward_info.answer_info, test=True
+ )
per_category_rewards[get_problem_category(reward_info.problem_type)].append(reward)
for category, rewards in sorted(per_category_rewards.items()):
diff --git a/src/ether0/rewards.py b/src/ether0/rewards.py
index a86bce7..911c94e 100644
--- a/src/ether0/rewards.py
+++ b/src/ether0/rewards.py
@@ -19,7 +19,7 @@
from ether0.clients import fetch_forward_rxn, fetch_purchasable, fetch_solubility
from ether0.data import is_reasonable_fp, is_reasonable_ring_system, mol_from_smiles
-from ether0.model_prompts import extract_thought_answer_strict
+from ether0.model_prompts import extract_answer_loose, extract_thought_answer_strict
from ether0.models import RewardFunctionInfo, RewardReason
block = BlockLogs()
@@ -702,14 +702,11 @@ def accuracy_reward(
reward_info = RewardFunctionInfo.model_validate(info)
fxn_name, answer_info, problem_type = tuple(reward_info.model_dump().values())
try:
- if test:
- answer: str | None = (
- content.split("")[1].split("")[0]
- if "" in content
- else content
- )
- else:
- answer = extract_thought_answer_strict(content, reasoning=reasoning)[1]
+ answer: str | None = (
+ extract_answer_loose(content)
+ if test
+ else extract_thought_answer_strict(content, reasoning=reasoning)[1]
+ )
if answer is not None:
# During test time, see if full SMILES string was given as input
if problem_type == "valid_mol_eval" and test: