Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 27 additions & 22 deletions eval/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pickle
from math import comb
from collections import Counter
from concurrent.futures import ThreadPoolExecutor

# envs.VLLM_HOST_IP="0.0.0.0" or "127.0.0.1"

Expand Down Expand Up @@ -187,28 +188,32 @@ def infer(args):



# Sequential scoring pass over every example: extract an answer from each
# generated response, take the most frequent non-empty answer as the majority
# vote, compare against the ground truth, update the running counters, and
# store the per-example fields on file_outputs[i].
# NOTE(review): indentation was lost in this view; the lines below read as the
# body of the `for`. This appears to be the pre-change side of the diff, which
# the process_output/ThreadPoolExecutor code further down replaces -- confirm
# against the commit.
for i in tqdm(range(len(examples)), "check correct..."):
d = examples[i]
gt_cot, gt_ans = parse_ground_truth(d, args.data_name)
generated_responses = file_outputs[i]['generated_responses']
generated_answers = [extract_answer(generated_response) for generated_response in generated_responses]
# Falsy answers (e.g. empty extractions) are excluded from the vote only;
# they are still scored individually below.
filtered_gen_answers = [x for x in generated_answers if x]
if len(filtered_gen_answers) >= 1 :
most_common_answer, _ = Counter(filtered_gen_answers).most_common(1)[0]
else:
# No usable answer at all: the majority vote falls back to the empty string.
most_common_answer = ""
is_correct_list = [check_is_correct(generated_answer, gt_ans) for generated_answer in generated_answers]
# NOTE(review): same call with the same arguments as is_correct_majk below, so
# both fields carry the same value if check_is_correct is deterministic.
is_correct = check_is_correct(most_common_answer, gt_ans)
# Adds the number of correct individual samples for this example.
correct_cnt_passk += sum(is_correct_list)
is_correct_majk = check_is_correct(most_common_answer, gt_ans)
if is_correct_majk:
correct_cnt_majk += 1
file_outputs[i]['generated_answers'] = generated_answers
file_outputs[i]['gold_answer'] = gt_ans
file_outputs[i]['is_correct'] = is_correct
file_outputs[i]['answers_correctness'] = is_correct_list
file_outputs[i]['most_common_answer'] = most_common_answer
file_outputs[i]['is_correct_majk'] = is_correct_majk
def process_output(i):
    """Score the generated responses of the example at index ``i``.

    Reads ``examples[i]``, ``file_outputs[i]['generated_responses']`` and
    ``args.data_name`` from the enclosing scope. An answer is extracted from
    every response, the most frequent non-empty answer becomes the majority
    vote, and each answer (plus the vote) is checked against the ground truth.

    Args:
        i: Index into ``examples`` / ``file_outputs``.

    Returns:
        dict: The per-example fields ``generated_answers``, ``gold_answer``,
        ``is_correct``, ``answers_correctness``, ``most_common_answer`` and
        ``is_correct_majk``, together with two counters for aggregation:
        ``sum_correct`` (number of correct individual answers) and
        ``majk_correct`` (1 if the majority answer is correct, else 0).
    """
    example = examples[i]
    # Only the second element of the parsed ground truth is used here.
    _, gt_ans = parse_ground_truth(example, args.data_name)

    generated_responses = file_outputs[i]['generated_responses']
    generated_answers = [extract_answer(resp) for resp in generated_responses]

    # Falsy answers (e.g. empty extractions) do not take part in the vote,
    # but they are still scored individually below.
    voted_answers = [ans for ans in generated_answers if ans]
    if voted_answers:
        most_common_answer = Counter(voted_answers).most_common(1)[0][0]
    else:
        most_common_answer = ""

    is_correct_list = [check_is_correct(ans, gt_ans) for ans in generated_answers]

    # Check the majority answer once and reuse the verdict: 'is_correct' and
    # 'is_correct_majk' describe the same comparison, so a second call would
    # only repeat the work (and could disagree if the checker is flaky).
    is_correct_majk = check_is_correct(most_common_answer, gt_ans)

    return {
        'generated_answers': generated_answers,
        'gold_answer': gt_ans,
        'is_correct': is_correct_majk,
        'answers_correctness': is_correct_list,
        'most_common_answer': most_common_answer,
        'is_correct_majk': is_correct_majk,
        'sum_correct': sum(is_correct_list),
        'majk_correct': int(is_correct_majk),
    }

# Fan process_output out over all example indices on a thread pool.
# executor.map yields results in input order, so results[i] belongs to
# examples[i]; list(...) drains the iterator so tqdm can show progress.
# NOTE(review): threads only overlap work that releases the GIL -- confirm the
# answer-extraction/checking helpers actually benefit before relying on a
# speed-up here.
with ThreadPoolExecutor() as executor:
results = list(tqdm(executor.map(process_output, range(len(examples))), total=len(examples), desc="check correct..."))
# Aggregate the per-example counters returned by process_output.
correct_cnt_passk = sum(r['sum_correct'] for r in results)
correct_cnt_majk = sum(r['majk_correct'] for r in results)
# Merge each result dict into its output record.
# NOTE(review): update(res) also stores the 'sum_correct' and 'majk_correct'
# helper keys on every record, so they will appear in whatever is written from
# file_outputs -- confirm this is intended.
for i, res in enumerate(results):
file_outputs[i].update(res)


temp_out_file = out_file + ".tmp"
Expand Down