-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy patheval.py
More file actions
125 lines (107 loc) · 4.32 KB
/
eval.py
File metadata and controls
125 lines (107 loc) · 4.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
import sys
import argparse
import re
import json
import torch
from transformers import AutoModelForCausalLM
from accelerate import Accelerator
from accelerate.utils import gather_object
from utils.dschat_utils import print_rank_0, set_random_seed
from utils.model_utils import load_hf_tokenizer, create_hf_model
from utils.utils import load_args, convert_model, i_prompt, extract_answer, extract_answer_number, extract_answer_letter, generate_test, calc_correct
def restore_model(saved_path, saved_args, epoch):
    """Load a saved model and its tokenizer for evaluation.

    Args:
        saved_path: Directory holding the saved model, or the parent of
            per-epoch ``Epoch_{n}`` subdirectories.
        saved_args: Args namespace restored from training (consumed by
            ``convert_model`` when loading an epoch checkpoint).
        epoch: Epoch number selecting ``saved_path/Epoch_{epoch}``; the
            sentinel ``-1`` loads directly from ``saved_path`` and skips
            both ``convert_model`` and the state-dict reload.

    Returns:
        Tuple ``(model, tokenizer)`` ready for generation.
    """
    if epoch != -1:
        model_path = os.path.join(saved_path, f'Epoch_{epoch}')
    else:
        model_path = saved_path
    # print_rank_0 (not bare print) so multi-process runs log once.
    print_rank_0(model_path)
    tokenizer = load_hf_tokenizer(model_path, fast_tokenizer=True)
    # Left padding is required for batched generation with decoder-only models.
    tokenizer.padding_side = "left"
    print_rank_0(f"tokenizer pad side: {tokenizer.padding_side}")
    model = create_hf_model(
        AutoModelForCausalLM,
        model_path,
        tokenizer,
    )
    if epoch != -1:
        model = convert_model(model, saved_args, is_train=False)
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints produced by this project's own training runs.
        state_dict = torch.load(os.path.join(model_path, 'pytorch_model.bin'), map_location='cpu')
        # strict=True surfaces any missing/unexpected keys; log the result once.
        print_rank_0(model.load_state_dict(state_dict, strict=True))
    return model, tokenizer
@torch.no_grad()
def main(args):
    """Generate answers for a test set with a restored model and score them.

    Loads the checkpoint selected by ``args.epoch``, shards the test prompts
    across processes with Accelerate, gathers the generations, computes
    accuracy via ``calc_correct`` and writes per-example predictions to
    ``args.output_dir/model_predictions.jsonl``.
    """
    accelerator = Accelerator()
    set_random_seed(args.seed)
    saved_args = load_args(args.saved_path)
    print_rank_0("Loading model and tokenizer...")
    model, tokenizer = restore_model(args.saved_path, saved_args, args.epoch)
    # Map the CLI dtype string to a torch dtype (anything else means bf16,
    # matching the original fallback).
    dtype_map = {'fp16': torch.float16, 'fp32': torch.float32}
    args.dtype = dtype_map.get(args.dtype, torch.bfloat16)
    model = model.to(dtype=args.dtype)
    print_rank_0('Move!')
    model = accelerator.prepare(model)
    model.eval()
    tpath = f'{args.data_path}/{args.dataset}/test.json'
    # Context manager closes the file (the original leaked the handle).
    with open(tpath, 'r') as fin:
        t_test_data = json.load(fin)
    # sample == 0 means "use the full test set".
    if args.sample != 0:
        t_test_data = t_test_data[:args.sample]
    prompts = [i_prompt.format_map(example) for example in t_test_data]
    print_rank_0(prompts[0])
    accelerator.wait_for_everyone()
    print_rank_0('Let`s Start!')
    device = accelerator.device
    with accelerator.split_between_processes(prompts) as prompt:
        model_outputs = []
        outputs = generate_test(
            model=model,
            tokenizer=tokenizer,
            device=device,
            prompts=prompt,
            batch_size=args.per_device_eval_batch_size,
            verbose=True,
        )
        model_outputs.extend(outputs)
    # gather_object returns the full list on every rank.
    outputs = gather_object(model_outputs)
    save_outputs, correct = calc_correct(t_test_data, outputs, args.dataset, args.data_type)
    print_rank_0(f"Saving outputs to {args.output_dir}")
    weighted_acc = correct / len(t_test_data)
    print_rank_0("Result {:.1f}, total: {}".format(weighted_acc * 100, len(t_test_data)))
    # Only the main process writes, avoiding a concurrent-write race where
    # every rank previously opened the same file in "w" mode.
    if accelerator.is_main_process:
        with open(os.path.join(args.output_dir, "model_predictions.jsonl"), "w") as fout:
            for example in save_outputs:
                fout.write(json.dumps(example) + "\n")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required data/output locations — all plain strings.
    for flag in ("--data_path", "--data_type", "--dataset", "--output_dir"):
        parser.add_argument(flag, type=str, default="", required=True)
    parser.add_argument("--saved_path", type=str, required=True,
                        help="Path to saved model")
    parser.add_argument("--epoch", type=int, default=0,
                        help="Epoch select.")
    parser.add_argument("--sample", type=int, default=0,
                        help="data sampled.")
    parser.add_argument("--seed", type=int, default=1234,
                        help="A seed for reproducible training.")
    parser.add_argument('--dtype', type=str, default='bf16',
                        choices=['fp16', 'bf16', 'fp32'],
                        help='Inference data type')
    parser.add_argument("--per_device_eval_batch_size", type=int, default=16,
                        help="batch size for evaluation.")
    main(parser.parse_args())