# Reconstructed from a whitespace-collapsed diff of eval/eval.py.
# The hunk header places everything below the imports inside the body of
# `infer(args)`; `examples`, `file_outputs`, `args`, `out_file`, `tqdm`,
# `parse_ground_truth`, `extract_answer` and `check_is_correct` are not defined
# in this fragment and must come from that enclosing scope / module.
import pickle
from math import comb
from collections import Counter
from concurrent.futures import ThreadPoolExecutor

# envs.VLLM_HOST_IP="0.0.0.0" or "127.0.0.1"


def process_output(i):
    """Score the generated responses stored for example ``i``.

    Reads ``examples[i]`` and ``file_outputs[i]['generated_responses']`` from
    the enclosing scope, extracts an answer from every response, takes the
    most frequent non-empty answer as the majority vote, and checks every
    answer (and the vote) against the ground truth.

    Args:
        i: Index into ``examples`` / ``file_outputs``.

    Returns:
        A dict holding exactly the fields that are merged into
        ``file_outputs[i]``: ``generated_answers``, ``gold_answer``,
        ``is_correct``, ``answers_correctness``, ``most_common_answer`` and
        ``is_correct_majk``. ``is_correct`` and ``is_correct_majk`` are the
        same value (correctness of the majority answer).
    """
    # The chain-of-thought half of the ground truth is not used here.
    _, gt_ans = parse_ground_truth(examples[i], args.data_name)

    generated_answers = [
        extract_answer(response)
        for response in file_outputs[i]['generated_responses']
    ]

    # Falsy extractions (e.g. "") do not take part in the vote; when nothing
    # was extracted at all the majority answer is the empty string.
    votes = Counter(answer for answer in generated_answers if answer)
    most_common_answer = votes.most_common(1)[0][0] if votes else ""

    is_correct_list = [
        check_is_correct(answer, gt_ans) for answer in generated_answers
    ]
    # Evaluated once and reused for both 'is_correct' and 'is_correct_majk'.
    is_correct_majk = check_is_correct(most_common_answer, gt_ans)

    return {
        'generated_answers': generated_answers,
        'gold_answer': gt_ans,
        'is_correct': is_correct_majk,
        'answers_correctness': is_correct_list,
        'most_common_answer': most_common_answer,
        'is_correct_majk': is_correct_majk,
    }


# NOTE(review): threads only overlap work that releases the GIL, and
# signal-based timeouts do not fire outside the main thread. Confirm that
# extract_answer / check_is_correct are safe and actually faster here.
with ThreadPoolExecutor() as executor:
    # executor.map yields results in input order, so results[i] belongs to
    # examples[i].
    results = list(tqdm(
        executor.map(process_output, range(len(examples))),
        total=len(examples),
        desc="check correct...",
    ))

# Tallies are derived from the per-example records instead of being carried as
# extra dict keys, so no bookkeeping fields leak into file_outputs.
correct_cnt_passk = sum(sum(res['answers_correctness']) for res in results)
correct_cnt_majk = sum(1 for res in results if res['is_correct_majk'])
for i, res in enumerate(results):
    file_outputs[i].update(res)

temp_out_file = out_file + ".tmp"