Formatting (#2750)

This commit is contained in:
Dev Khant
2025-05-22 01:17:29 +05:30
committed by GitHub
parent dff91154a7
commit d85fcda037
71 changed files with 1391 additions and 1823 deletions

View File

@@ -14,46 +14,47 @@ def process_item(item_data):
local_results = defaultdict(list)
for item in v:
gt_answer = str(item['answer'])
pred_answer = str(item['response'])
category = str(item['category'])
question = str(item['question'])
gt_answer = str(item["answer"])
pred_answer = str(item["response"])
category = str(item["category"])
question = str(item["question"])
# Skip category 5
if category == '5':
if category == "5":
continue
metrics = calculate_metrics(pred_answer, gt_answer)
bleu_scores = calculate_bleu_scores(pred_answer, gt_answer)
llm_score = evaluate_llm_judge(question, gt_answer, pred_answer)
local_results[k].append({
"question": question,
"answer": gt_answer,
"response": pred_answer,
"category": category,
"bleu_score": bleu_scores["bleu1"],
"f1_score": metrics["f1"],
"llm_score": llm_score
})
local_results[k].append(
{
"question": question,
"answer": gt_answer,
"response": pred_answer,
"category": category,
"bleu_score": bleu_scores["bleu1"],
"f1_score": metrics["f1"],
"llm_score": llm_score,
}
)
return local_results
def main():
    """CLI entry point: evaluate RAG results in parallel and save metrics.

    Reads a JSON results file, fans each top-level item out to
    ``process_item`` on a thread pool, merges the per-worker results under
    ``results_lock``, and writes the aggregated metrics to a JSON file.

    Command-line arguments:
        --input_file:  path to the input dataset JSON (dict of items).
        --output_file: path where the evaluation metrics JSON is written.
        --max_workers: thread-pool size for concurrent evaluation.
    """
    parser = argparse.ArgumentParser(description="Evaluate RAG results")
    parser.add_argument(
        "--input_file", type=str, default="results/rag_results_500_k1.json", help="Path to the input dataset file"
    )
    parser.add_argument(
        "--output_file", type=str, default="evaluation_metrics.json", help="Path to save the evaluation results"
    )
    parser.add_argument("--max_workers", type=int, default=10, help="Maximum number of worker threads")
    args = parser.parse_args()

    with open(args.input_file, "r") as f:
        data = json.load(f)

    results = defaultdict(list)

    # Use ThreadPoolExecutor with specified workers; threads are appropriate
    # here because each worker is dominated by I/O-bound LLM-judge calls.
    with concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor:
        futures = [executor.submit(process_item, item_data) for item_data in data.items()]

        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
            local_results = future.result()
            # Merge under the shared lock so concurrent workers cannot
            # interleave writes into the shared `results` mapping.
            with results_lock:
                for k, items in local_results.items():
                    results[k].extend(items)

    # Save results to JSON file
    with open(args.output_file, "w") as f:
        json.dump(results, f, indent=4)

    print(f"Results saved to {args.output_file}")