Formatting (#2750)
@@ -14,46 +14,47 @@ def process_item(item_data):
     local_results = defaultdict(list)

     for item in v:
-        gt_answer = str(item['answer'])
-        pred_answer = str(item['response'])
-        category = str(item['category'])
-        question = str(item['question'])
+        gt_answer = str(item["answer"])
+        pred_answer = str(item["response"])
+        category = str(item["category"])
+        question = str(item["question"])

         # Skip category 5
-        if category == '5':
+        if category == "5":
             continue

         metrics = calculate_metrics(pred_answer, gt_answer)
         bleu_scores = calculate_bleu_scores(pred_answer, gt_answer)
         llm_score = evaluate_llm_judge(question, gt_answer, pred_answer)

-        local_results[k].append({
-            "question": question,
-            "answer": gt_answer,
-            "response": pred_answer,
-            "category": category,
-            "bleu_score": bleu_scores["bleu1"],
-            "f1_score": metrics["f1"],
-            "llm_score": llm_score
-        })
+        local_results[k].append(
+            {
+                "question": question,
+                "answer": gt_answer,
+                "response": pred_answer,
+                "category": category,
+                "bleu_score": bleu_scores["bleu1"],
+                "f1_score": metrics["f1"],
+                "llm_score": llm_score,
+            }
+        )

     return local_results


 def main():
-    parser = argparse.ArgumentParser(description='Evaluate RAG results')
-    parser.add_argument('--input_file', type=str,
-                        default="results/rag_results_500_k1.json",
-                        help='Path to the input dataset file')
-    parser.add_argument('--output_file', type=str,
-                        default="evaluation_metrics.json",
-                        help='Path to save the evaluation results')
-    parser.add_argument('--max_workers', type=int, default=10,
-                        help='Maximum number of worker threads')
+    parser = argparse.ArgumentParser(description="Evaluate RAG results")
+    parser.add_argument(
+        "--input_file", type=str, default="results/rag_results_500_k1.json", help="Path to the input dataset file"
+    )
+    parser.add_argument(
+        "--output_file", type=str, default="evaluation_metrics.json", help="Path to save the evaluation results"
+    )
+    parser.add_argument("--max_workers", type=int, default=10, help="Maximum number of worker threads")

     args = parser.parse_args()

-    with open(args.input_file, 'r') as f:
+    with open(args.input_file, "r") as f:
         data = json.load(f)

     results = defaultdict(list)
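For context on the hunk above: the function body reads `k` and `v`, but their unpacking sits before file line 14, so the diff never shows it. A minimal sketch of the input shape this implies — the concrete keys and record values here are hypothetical, inferred only from the field names (`question`, `answer`, `response`, `category`) and the `data.items()` call visible in the diff:

```python
# Hypothetical input inferred from the diff: a JSON object mapping some key
# (e.g., a retrieval setting) to a list of QA records.
example_data = {
    "k1": [
        {
            "question": "What is RAG?",
            "answer": "Retrieval-augmented generation.",
            "response": "Retrieval augmented generation.",
            "category": "1",
        }
    ],
}

# process_item receives one (key, records) pair produced by data.items(),
# presumably unpacked as `k, v = item_data` above the hunk's start.
k, v = next(iter(example_data.items()))
```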
@@ -61,18 +62,16 @@ def main():

     # Use ThreadPoolExecutor with specified workers
     with concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor:
-        futures = [executor.submit(process_item, item_data)
-                   for item_data in data.items()]
+        futures = [executor.submit(process_item, item_data) for item_data in data.items()]

-        for future in tqdm(concurrent.futures.as_completed(futures),
-                           total=len(futures)):
+        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
             local_results = future.result()
             with results_lock:
                 for k, items in local_results.items():
                     results[k].extend(items)

     # Save results to JSON file
-    with open(args.output_file, 'w') as f:
+    with open(args.output_file, "w") as f:
         json.dump(results, f, indent=4)

     print(f"Results saved to {args.output_file}")