# Source file: t6_mem0/evaluation/evals.py
# Last commit: 393a4fd5a6 "Docs Update (#2591)" by Prateek Chhikara, 2025-04-29 08:15:25 -07:00
# 82 lines, 2.6 KiB, Python
import json
import argparse
from metrics.utils import calculate_metrics, calculate_bleu_scores
from metrics.llm_judge import evaluate_llm_judge
from collections import defaultdict
from tqdm import tqdm
import concurrent.futures
import threading
def process_item(item_data):
    """Score every QA record belonging to one dataset key.

    Args:
        item_data: A ``(key, records)`` pair, where ``records`` is a list of
            dicts carrying at least ``question``, ``answer``, ``response``,
            and ``category`` fields.

    Returns:
        A ``defaultdict(list)`` mapping the key to its scored records, each
        annotated with BLEU-1, F1, and LLM-judge scores. Records whose
        category is ``'5'`` are omitted.
    """
    key, records = item_data
    scored = defaultdict(list)
    for record in records:
        gt_answer = str(record['answer'])
        pred_answer = str(record['response'])
        category = str(record['category'])
        question = str(record['question'])
        # Skip category 5
        if category == '5':
            continue
        metrics = calculate_metrics(pred_answer, gt_answer)
        bleu_scores = calculate_bleu_scores(pred_answer, gt_answer)
        llm_score = evaluate_llm_judge(question, gt_answer, pred_answer)
        entry = {
            "question": question,
            "answer": gt_answer,
            "response": pred_answer,
            "category": category,
            "bleu_score": bleu_scores["bleu1"],
            "f1_score": metrics["f1"],
            "llm_score": llm_score,
        }
        scored[key].append(entry)
    return scored
def main():
    """CLI entry point: score a RAG results file and write per-key metrics.

    Loads a JSON object mapping dataset keys to lists of QA records, scores
    each key's records concurrently via ``process_item``, and dumps the
    aggregated metrics to the output JSON file.
    """
    parser = argparse.ArgumentParser(description='Evaluate RAG results')
    parser.add_argument('--input_file', type=str,
                        default="results/rag_results_500_k1.json",
                        help='Path to the input dataset file')
    parser.add_argument('--output_file', type=str,
                        default="evaluation_metrics.json",
                        help='Path to save the evaluation results')
    parser.add_argument('--max_workers', type=int, default=10,
                        help='Maximum number of worker threads')
    args = parser.parse_args()

    with open(args.input_file, 'r') as f:
        data = json.load(f)

    results = defaultdict(list)

    # Threads are appropriate here: the per-item cost is dominated by the
    # LLM-judge network calls (I/O-bound), which release the GIL while
    # waiting. Workers only *return* values; all merging happens below on
    # the main thread, so no lock is required (the original held a lock
    # that protected nothing — removed).
    with concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor:
        futures = [executor.submit(process_item, item_data)
                   for item_data in data.items()]
        for future in tqdm(concurrent.futures.as_completed(futures),
                           total=len(futures)):
            for k, items in future.result().items():
                results[k].extend(items)

    # Save results to JSON file
    with open(args.output_file, 'w') as f:
        json.dump(results, f, indent=4)
    print(f"Results saved to {args.output_file}")


if __name__ == "__main__":
    main()