-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathsolution_code_generation.py
More file actions
122 lines (99 loc) · 3.96 KB
/
Copy pathsolution_code_generation.py
File metadata and controls
122 lines (99 loc) · 3.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import json
import re
import os
import time
import argparse
from openai import OpenAI
import random
# Command-line interface: two positional paths plus an optional sample size.
parser = argparse.ArgumentParser(description='Generate Python code for math word problems.')
parser.add_argument('input_file', type=str, help='Path to input JSONL file')
parser.add_argument('output_file', type=str, help='Path to output JSONL file')
parser.add_argument('--num_samples', type=int, default=1000, help='Number of samples to process (default: 1000)')
args = parser.parse_args()
data_file = args.input_file
output_file = args.output_file
num_samples = args.num_samples
if num_samples < 0:
    # random.sample rejects negative sizes; fail with a usage message instead.
    parser.error("--num_samples must be non-negative")
# Load one JSON object per line. Blank lines (e.g. a trailing newline at the
# end of the file) are skipped rather than handed to json.loads.
with open(data_file, "r", encoding="utf-8") as f:
    data = [json.loads(line) for line in f if line.strip()]
# random.sample raises ValueError when asked for more items than are
# available, so cap the request at the number of records actually loaded.
data = random.sample(data, min(num_samples, len(data)))
def parse_final_answer(answer_str):
    """Extract the final answer that follows the '#### ' marker.

    The first occurrence of ``####`` is located, any whitespace after it is
    skipped, and the remainder of that line is returned with surrounding
    whitespace removed.

    Args:
        answer_str: Full worked-solution text to search.

    Returns:
        The text after the marker, or ``None`` when no ``####`` is present.
    """
    found = re.search(r"####\s*(.*)", answer_str)
    return None if found is None else found.group(1).strip()
def strip_code_block_markers(code_str):
    """Return the bare code contained in a Markdown-formatted model reply.

    If the reply contains one or more complete fenced blocks (an opening
    fence line with an optional info string such as ``python`` or
    ``python3``, followed later by a closing fence), only the fenced contents
    are kept and any prose around them is dropped; several blocks are joined
    with a blank line in their original order. If no complete fenced block is
    found, a lone opening fence at the start and/or a lone closing fence at
    the end are removed, and otherwise the text is returned unchanged.

    Args:
        code_str: Raw text returned by the model.

    Returns:
        The code with fence markers removed and surrounding whitespace
        stripped.
    """
    # Strip first so a reply with leading blank lines or spaces is still
    # recognised as starting with a fence.
    text = code_str.strip()
    # The info string may contain digits or symbols ("python3", "c++") and,
    # with CRLF line endings, a trailing "\r"; accept anything up to the
    # end of the fence line.
    fenced_parts = re.findall(r"```[^\n`]*\n(.*?)```", text, flags=re.DOTALL)
    if fenced_parts:
        return "\n\n".join(part.strip() for part in fenced_parts).strip()
    # No complete fenced block: tolerate a missing opening or closing fence.
    text = re.sub(r"^```[^\n`]*\n", "", text)
    text = re.sub(r"```$", "", text)
    return text.strip()
def test_generated_code(code_str):
    """Execute generated code and report whether it ran without error.

    The code runs in a fresh namespace whose ``__name__`` is ``"__main__"``,
    so ``if __name__ == "__main__":`` guards in the generated code execute.

    Args:
        code_str: Python source to execute.

    Returns:
        ``(True, None)`` if the code ran to completion, otherwise
        ``(False, error_message)``. ``error_message`` is the exception's
        message, or the exception class name when the message is empty
        (for example a bare ``assert`` failure).
    """
    # A single dict must serve as both globals and locals. With separate
    # dicts, exec behaves like a class body: top-level imports and functions
    # are stored in the locals dict, but code inside a function resolves
    # names through globals, so helpers, recursion and imported modules
    # would all fail with NameError.
    namespace = {"__name__": "__main__"}
    try:
        # NOTE(security): this runs model-generated code in-process with no
        # sandboxing; only use it in an environment where that is acceptable.
        exec(code_str, namespace)
    except Exception as e:
        return False, str(e) or type(e).__name__
    return True, None


# Despite its "test_" prefix this is a helper, not a pytest test; opt out of
# collection so importing it into a test module does not make pytest run it.
test_generated_code.__test__ = False
# Main generation loop: for every sampled problem, ask the model for a
# generalized solver, execute the returned code as a smoke test, and append
# one JSON record per problem to the output file.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
results = []
total_entries = len(data)
# Bounded retries: a persistent failure (bad API key, unknown model, unwritable
# output path) must not spin forever.
max_api_attempts = 5
max_write_attempts = 5
for i, entry in enumerate(data):
    question = entry.get("question", "")
    answer = entry.get("answer", "")
    final_answer = parse_final_answer(answer)
    print(f"Question: {question}")
    print(f"Answer: {answer}")
    print(f"Final answer: {final_answer}")
    prompt = f"""
Given the following math word problem and its full worked solution, write a Python function that solves the problem for all possible initial values (i.e., make the numbers in the problem parameters). The function should take the relevant parameters as arguments and return the answer. Do not use hardcoded values from the example; generalize the solution. Please ensure that the function throws an exception when it receives improper parameter values.
Problem:
{question}
Full Solution:
{answer}
Write only the Python function code. Then write a function call with the initial values from the problem description and assert the answer is equal to the ground truth answer.
"""
    print(f"Prompt: {prompt}")
    generation_error = None
    for attempt in range(1, max_api_attempts + 1):
        try:
            completion = client.chat.completions.create(
                model="o3",
                messages=[
                    {"role": "developer", "content": "You are a helpful assistant that writes Python code for math problems."},
                    {"role": "user", "content": prompt}
                ],
            )
            code = completion.choices[0].message.content.strip()
            print(f"Code: {code}")
            generation_error = None
            break
        except Exception as e:
            # Keep the message: the name `e` is unbound once this block ends.
            generation_error = str(e)
            code = f"Error: {e}"
            print(f"Error: {e}")
            if attempt < max_api_attempts:
                # Exponential backoff (1s, 2s, 4s, ...) capped at 30 seconds.
                time.sleep(min(2 ** (attempt - 1), 30))
    if generation_error is None:
        # Remove code block markers if present
        code_clean = strip_code_block_markers(code)
        # Test the generated code
        ok, err = test_generated_code(code_clean)
        code_test_result = "success" if ok else f"error: {err}"
    else:
        # Every attempt failed: there is no code to run, so record the failure
        # instead of executing the error message as if it were Python.
        code_clean = ""
        code_test_result = f"error: generation failed: {generation_error}"
    results.append({
        "question": question,
        "answer": answer,
        "parsed_final_answer": final_answer,
        "generated_code": code,
        "code_clean": code_clean,
        "code_test_result": code_test_result
    })
    print(f"Processed {i+1}/{total_entries} entries. Code test result: {code_test_result}")
    time.sleep(1)
    # Append the newest record immediately so finished work survives a crash.
    for attempt in range(1, max_write_attempts + 1):
        try:
            with open(output_file, "a", encoding="utf-8") as f:
                json.dump(results[-1], f, ensure_ascii=False)
                f.write("\n")
            break
        except OSError as e:
            print(f"Error: {e}")
            if attempt == max_write_attempts:
                # Fail loudly rather than silently dropping the record.
                raise
            time.sleep(10)
print(f"Saved results to {output_file}")