Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,13 @@ results = await asyncio.gather(
# Compute rewards
per_category_rewards = defaultdict(list)
for row, result in zip(test_ds, results, strict=True):
# NOTE: you can also use `ether0.rewards.accuracy_reward`,
# but we decided to go a bit "lower level" for this demo
reward_info = RewardFunctionInfo.model_validate(row["solution"])
yhat = extract_answer_loose(result[0].text)
reward = EVAL_FUNCTIONS[reward_info.fxn_name](yhat=yhat, y=reward_info.answer_info)
reward = EVAL_FUNCTIONS[reward_info.fxn_name](
yhat=yhat, y=reward_info.answer_info, test=True
)
per_category_rewards[get_problem_category(reward_info.problem_type)].append(reward)

for category, rewards in sorted(per_category_rewards.items()):
Expand Down
15 changes: 6 additions & 9 deletions src/ether0/rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from ether0.clients import fetch_forward_rxn, fetch_purchasable, fetch_solubility
from ether0.data import is_reasonable_fp, is_reasonable_ring_system, mol_from_smiles
from ether0.model_prompts import extract_thought_answer_strict
from ether0.model_prompts import extract_answer_loose, extract_thought_answer_strict
from ether0.models import RewardFunctionInfo, RewardReason

block = BlockLogs()
Expand Down Expand Up @@ -702,14 +702,11 @@ def accuracy_reward(
reward_info = RewardFunctionInfo.model_validate(info)
fxn_name, answer_info, problem_type = tuple(reward_info.model_dump().values())
try:
if test:
answer: str | None = (
content.split("<answer>")[1].split("</answer>")[0]
if "<answer>" in content
else content
)
else:
answer = extract_thought_answer_strict(content, reasoning=reasoning)[1]
answer: str | None = (
extract_answer_loose(content)
if test
else extract_thought_answer_strict(content, reasoning=reasoning)[1]
)
if answer is not None:
# During test time, see if full SMILES string was given as input
if problem_type == "valid_mol_eval" and test:
Expand Down