diff --git a/README.md b/README.md index 68ebfff..ef1b4aa 100644 --- a/README.md +++ b/README.md @@ -182,9 +182,13 @@ results = await asyncio.gather( # Compute rewards per_category_rewards = defaultdict(list) for row, result in zip(test_ds, results, strict=True): + # NOTE: you can also use `ether0.rewards.accuracy_reward`, + # but we decided to go a bit "lower level" for this demo reward_info = RewardFunctionInfo.model_validate(row["solution"]) yhat = extract_answer_loose(result[0].text) - reward = EVAL_FUNCTIONS[reward_info.fxn_name](yhat=yhat, y=reward_info.answer_info) + reward = EVAL_FUNCTIONS[reward_info.fxn_name]( + yhat=yhat, y=reward_info.answer_info, test=True + ) per_category_rewards[get_problem_category(reward_info.problem_type)].append(reward) for category, rewards in sorted(per_category_rewards.items()): diff --git a/src/ether0/rewards.py b/src/ether0/rewards.py index a86bce7..911c94e 100644 --- a/src/ether0/rewards.py +++ b/src/ether0/rewards.py @@ -19,7 +19,7 @@ from ether0.clients import fetch_forward_rxn, fetch_purchasable, fetch_solubility from ether0.data import is_reasonable_fp, is_reasonable_ring_system, mol_from_smiles -from ether0.model_prompts import extract_thought_answer_strict +from ether0.model_prompts import extract_answer_loose, extract_thought_answer_strict from ether0.models import RewardFunctionInfo, RewardReason block = BlockLogs() @@ -702,14 +702,11 @@ def accuracy_reward( reward_info = RewardFunctionInfo.model_validate(info) fxn_name, answer_info, problem_type = tuple(reward_info.model_dump().values()) try: - if test: - answer: str | None = ( - content.split("")[1].split("")[0] - if "" in content - else content - ) - else: - answer = extract_thought_answer_strict(content, reasoning=reasoning)[1] + answer: str | None = ( + extract_answer_loose(content) + if test + else extract_thought_answer_strict(content, reasoning=reasoning)[1] + ) if answer is not None: # During test time, see if full SMILES string was given as input if problem_type == "valid_mol_eval" and test: