Formatting (#2750)
@@ -33,35 +33,34 @@ Do NOT include both CORRECT and WRONG in your response, or it will break the eva
 Just return the label CORRECT or WRONG in a json format with the key as "label".
 """


 def evaluate_llm_judge(question, gold_answer, generated_answer):
     """Evaluate the generated answer against the gold answer using an LLM judge."""
     response = client.chat.completions.create(
         model="gpt-4o-mini",
-        messages=[{
-            "role": "user",
-            "content": ACCURACY_PROMPT.format(
-                question=question,
-                gold_answer=gold_answer,
-                generated_answer=generated_answer
-            )
-        }],
+        messages=[
+            {
+                "role": "user",
+                "content": ACCURACY_PROMPT.format(
+                    question=question, gold_answer=gold_answer, generated_answer=generated_answer
+                ),
+            }
+        ],
         response_format={"type": "json_object"},
-        temperature=0.0
+        temperature=0.0,
     )
-    label = json.loads(response.choices[0].message.content)['label']
+    label = json.loads(response.choices[0].message.content)["label"]
     return 1 if label == "CORRECT" else 0


 def main():
     """Main function to evaluate RAG results using LLM judge."""
-    parser = argparse.ArgumentParser(
-        description='Evaluate RAG results using LLM judge'
-    )
+    parser = argparse.ArgumentParser(description="Evaluate RAG results using LLM judge")
     parser.add_argument(
-        '--input_file',
+        "--input_file",
         type=str,
         default="results/default_run_v4_k30_new_graph.json",
-        help='Path to the input dataset file'
+        help="Path to the input dataset file",
     )

     args = parser.parse_args()
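For context on the hunk above: `evaluate_llm_judge` sends the formatted `ACCURACY_PROMPT` to `gpt-4o-mini` in JSON mode and maps the judge's CORRECT/WRONG label to 1/0. A minimal usage sketch; the client setup and the question/answer values here are illustrative assumptions, not taken from the dataset:

```python
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

# Illustrative record; in the script these fields come from the input JSON file.
score = evaluate_llm_judge(
    question="Where did Alice say she was travelling in June?",
    gold_answer="Paris",
    generated_answer="She mentioned she was going to Paris.",
)
print(score)  # 1 if the judge returned CORRECT, 0 otherwise
```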
@@ -78,10 +77,10 @@ def main():
     index = 0
     for k, v in data.items():
         for x in v:
-            question = x['question']
-            gold_answer = x['answer']
-            generated_answer = x['response']
-            category = x['category']
+            question = x["question"]
+            gold_answer = x["answer"]
+            generated_answer = x["response"]
+            category = x["category"]

             # Skip category 5
             if int(category) == 5:
@@ -92,13 +91,15 @@ def main():
             LLM_JUDGE[category].append(label)

             # Store the results
-            RESULTS[index].append({
-                "question": question,
-                "gt_answer": gold_answer,
-                "response": generated_answer,
-                "category": category,
-                "llm_label": label
-            })
+            RESULTS[index].append(
+                {
+                    "question": question,
+                    "gt_answer": gold_answer,
+                    "response": generated_answer,
+                    "category": category,
+                    "llm_label": label,
+                }
+            )

             # Save intermediate results
             with open(output_path, "w") as f:
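The hunk above appends one judged record per question to `RESULTS[index]` and rewrites the output file on every iteration, so partial results survive an interrupted run. A small sketch of that pattern; the `defaultdict` initialization and the output path are assumptions, since the script's setup is outside this diff:

```python
import json
from collections import defaultdict

RESULTS = defaultdict(list)  # assumed shape: run index -> list of judged records

RESULTS[0].append(
    {
        "question": "Where did Alice travel?",  # illustrative values
        "gt_answer": "Paris",
        "response": "Paris",
        "category": "1",
        "llm_label": 1,
    }
)

# Rewriting the whole file each time keeps the latest partial results on disk.
with open("results_judged.json", "w") as f:  # hypothetical output path
    json.dump(RESULTS, f, indent=2)
```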
@@ -108,8 +109,7 @@ def main():
     print("All categories accuracy:")
     for cat, results in LLM_JUDGE.items():
         if results:  # Only print if there are results for this category
-            print(f" Category {cat}: {np.mean(results):.4f} "
-                  f"({sum(results)}/{len(results)})")
+            print(f" Category {cat}: {np.mean(results):.4f} " f"({sum(results)}/{len(results)})")
     print("------------------------------------------")
     index += 1
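The reporting loop above averages the 0/1 judge labels per category. A toy illustration of the same computation with made-up labels:

```python
import numpy as np

LLM_JUDGE = {"1": [1, 1, 0, 1], "2": [0, 1]}  # illustrative labels only

for cat, results in LLM_JUDGE.items():
    if results:
        print(f" Category {cat}: {np.mean(results):.4f} ({sum(results)}/{len(results)})")
# Category 1: 0.7500 (3/4)
# Category 2: 0.5000 (1/2)
```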
@@ -3,7 +3,7 @@ Borrowed from https://github.com/WujiangXu/AgenticMemory/blob/main/utils.py

 @article{xu2025mem,
   title={A-mem: Agentic memory for llm agents},
-  author={Xu, Wujiang and Liang, Zujie and Mei, Kai and Gao, Hang and Tan, Juntao
+  author={Xu, Wujiang and Liang, Zujie and Mei, Kai and Gao, Hang and Tan, Juntao
   and Zhang, Yongfeng},
   journal={arXiv preprint arXiv:2502.12110},
   year={2025}
@@ -26,42 +26,45 @@ from sentence_transformers.util import pytorch_cos_sim

 # Download required NLTK data
 try:
-    nltk.download('punkt', quiet=True)
-    nltk.download('wordnet', quiet=True)
+    nltk.download("punkt", quiet=True)
+    nltk.download("wordnet", quiet=True)
 except Exception as e:
     print(f"Error downloading NLTK data: {e}")

 # Initialize SentenceTransformer model (this will be reused)
 try:
-    sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
+    sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
 except Exception as e:
     print(f"Warning: Could not load SentenceTransformer model: {e}")
     sentence_model = None


 def simple_tokenize(text):
     """Simple tokenization function."""
     # Convert to string if not already
     text = str(text)
-    return text.lower().replace('.', ' ').replace(',', ' ').replace('!', ' ').replace('?', ' ').split()
+    return text.lower().replace(".", " ").replace(",", " ").replace("!", " ").replace("?", " ").split()


 def calculate_rouge_scores(prediction: str, reference: str) -> Dict[str, float]:
     """Calculate ROUGE scores for prediction against reference."""
-    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
+    scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
     scores = scorer.score(reference, prediction)
     return {
-        'rouge1_f': scores['rouge1'].fmeasure,
-        'rouge2_f': scores['rouge2'].fmeasure,
-        'rougeL_f': scores['rougeL'].fmeasure
+        "rouge1_f": scores["rouge1"].fmeasure,
+        "rouge2_f": scores["rouge2"].fmeasure,
+        "rougeL_f": scores["rougeL"].fmeasure,
     }


 def calculate_bleu_scores(prediction: str, reference: str) -> Dict[str, float]:
     """Calculate BLEU scores with different n-gram settings."""
     pred_tokens = nltk.word_tokenize(prediction.lower())
     ref_tokens = [nltk.word_tokenize(reference.lower())]

     weights_list = [(1, 0, 0, 0), (0.5, 0.5, 0, 0), (0.33, 0.33, 0.33, 0), (0.25, 0.25, 0.25, 0.25)]
     smooth = SmoothingFunction().method1

     scores = {}
     for n, weights in enumerate(weights_list, start=1):
         try:
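`calculate_rouge_scores` above wraps Google's rouge-score package; note that `RougeScorer.score` takes the reference (target) first and the prediction second. A standalone sketch with illustrative strings:

```python
from rouge_score import rouge_scorer  # pip install rouge-score

scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
scores = scorer.score("the cat sat on the mat", "a cat sat on a mat")  # (reference, prediction)
print(f"{scores['rouge1'].fmeasure:.3f} {scores['rougeL'].fmeasure:.3f}")
```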
@@ -69,26 +72,20 @@ def calculate_bleu_scores(prediction: str, reference: str) -> Dict[str, float]:
         except Exception as e:
             print(f"Error calculating BLEU score: {e}")
             score = 0.0
-        scores[f'bleu{n}'] = score
+        scores[f"bleu{n}"] = score

     return scores


 def calculate_bert_scores(prediction: str, reference: str) -> Dict[str, float]:
     """Calculate BERTScore for semantic similarity."""
     try:
-        P, R, F1 = bert_score([prediction], [reference], lang='en', verbose=False)
-        return {
-            'bert_precision': P.item(),
-            'bert_recall': R.item(),
-            'bert_f1': F1.item()
-        }
+        P, R, F1 = bert_score([prediction], [reference], lang="en", verbose=False)
+        return {"bert_precision": P.item(), "bert_recall": R.item(), "bert_f1": F1.item()}
     except Exception as e:
         print(f"Error calculating BERTScore: {e}")
-        return {
-            'bert_precision': 0.0,
-            'bert_recall': 0.0,
-            'bert_f1': 0.0
-        }
+        return {"bert_precision": 0.0, "bert_recall": 0.0, "bert_f1": 0.0}


 def calculate_meteor_score(prediction: str, reference: str) -> float:
     """Calculate METEOR score for the prediction."""
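The BLEU loop above computes BLEU-1 through BLEU-4 with method1 smoothing; the scoring call itself sits in lines this hunk does not show, so the `sentence_bleu` invocation below is an assumption about how it is used rather than a copy of the script:

```python
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

nltk.download("punkt", quiet=True)  # tokenizer model used by nltk.word_tokenize

pred_tokens = nltk.word_tokenize("a cat sat on a mat")
ref_tokens = [nltk.word_tokenize("the cat sat on the mat")]
smooth = SmoothingFunction().method1

bleu1 = sentence_bleu(ref_tokens, pred_tokens, weights=(1, 0, 0, 0), smoothing_function=smooth)
print(f"BLEU-1: {bleu1:.3f}")
```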
@@ -98,6 +95,7 @@ def calculate_meteor_score(prediction: str, reference: str) -> float:
         print(f"Error calculating METEOR score: {e}")
         return 0.0


 def calculate_sentence_similarity(prediction: str, reference: str) -> float:
     """Calculate sentence embedding similarity using SentenceBERT."""
     if sentence_model is None:
@@ -106,7 +104,7 @@ def calculate_sentence_similarity(prediction: str, reference: str) -> float:
     # Encode sentences
     embedding1 = sentence_model.encode([prediction], convert_to_tensor=True)
     embedding2 = sentence_model.encode([reference], convert_to_tensor=True)

     # Calculate cosine similarity
     similarity = pytorch_cos_sim(embedding1, embedding2).item()
     return float(similarity)
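`calculate_sentence_similarity` above encodes both strings with the module-level all-MiniLM-L6-v2 model and takes their cosine similarity. A self-contained sketch of the same computation (the model is downloaded on first use; the sentences are illustrative):

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import pytorch_cos_sim

model = SentenceTransformer("all-MiniLM-L6-v2")
emb1 = model.encode(["a cat sat on a mat"], convert_to_tensor=True)
emb2 = model.encode(["the cat sat on the mat"], convert_to_tensor=True)
print(float(pytorch_cos_sim(emb1, emb2).item()))  # value in [-1, 1], higher means more similar
```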
@@ -114,6 +112,7 @@ def calculate_sentence_similarity(prediction: str, reference: str) -> float:
         print(f"Error calculating sentence similarity: {e}")
         return 0.0


 def calculate_metrics(prediction: str, reference: str) -> Dict[str, float]:
     """Calculate comprehensive evaluation metrics for a prediction."""
     # Handle empty or None values
@@ -130,31 +129,31 @@ def calculate_metrics(prediction: str, reference: str) -> Dict[str, float]:
             "bleu4": 0.0,
             "bert_f1": 0.0,
             "meteor": 0.0,
-            "sbert_similarity": 0.0
+            "sbert_similarity": 0.0,
         }

     # Convert to strings if they're not already
     prediction = str(prediction).strip()
     reference = str(reference).strip()

     # Calculate exact match
     exact_match = int(prediction.lower() == reference.lower())

     # Calculate token-based F1 score
     pred_tokens = set(simple_tokenize(prediction))
     ref_tokens = set(simple_tokenize(reference))
     common_tokens = pred_tokens & ref_tokens

     if not pred_tokens or not ref_tokens:
         f1 = 0.0
     else:
         precision = len(common_tokens) / len(pred_tokens)
         recall = len(common_tokens) / len(ref_tokens)
         f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

     # Calculate all scores
     bleu_scores = calculate_bleu_scores(prediction, reference)

     # Combine all metrics
     metrics = {
         "exact_match": exact_match,
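For the token-level F1 in the hunk above, precision and recall are computed over the sets of tokens shared by prediction and reference. A worked example with toy strings (tokenization simplified to whitespace splitting, which matches `simple_tokenize` for punctuation-free text):

```python
prediction = "paris in june"
reference = "she visited paris"

pred_tokens = set(prediction.split())  # {"paris", "in", "june"}
ref_tokens = set(reference.split())    # {"she", "visited", "paris"}
common = pred_tokens & ref_tokens      # {"paris"}

precision = len(common) / len(pred_tokens)          # 1/3
recall = len(common) / len(ref_tokens)              # 1/3
f1 = 2 * precision * recall / (precision + recall)  # 1/3
print(round(f1, 3))  # 0.333
```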
@@ -164,48 +163,49 @@ def calculate_metrics(prediction: str, reference: str) -> Dict[str, float]:

     return metrics

-def aggregate_metrics(all_metrics: List[Dict[str, float]], all_categories: List[int]) -> Dict[str, Dict[str, Union[float, Dict[str, float]]]]:
+def aggregate_metrics(
+    all_metrics: List[Dict[str, float]], all_categories: List[int]
+) -> Dict[str, Dict[str, Union[float, Dict[str, float]]]]:
     """Calculate aggregate statistics for all metrics, split by category."""
     if not all_metrics:
         return {}

     # Initialize aggregates for overall and per-category metrics
     aggregates = defaultdict(list)
     category_aggregates = defaultdict(lambda: defaultdict(list))

     # Collect all values for each metric, both overall and per category
     for metrics, category in zip(all_metrics, all_categories):
         for metric_name, value in metrics.items():
             aggregates[metric_name].append(value)
             category_aggregates[category][metric_name].append(value)

     # Calculate statistics for overall metrics
-    results = {
-        "overall": {}
-    }
+    results = {"overall": {}}

     for metric_name, values in aggregates.items():
         results["overall"][metric_name] = {
-            'mean': statistics.mean(values),
-            'std': statistics.stdev(values) if len(values) > 1 else 0.0,
-            'median': statistics.median(values),
-            'min': min(values),
-            'max': max(values),
-            'count': len(values)
+            "mean": statistics.mean(values),
+            "std": statistics.stdev(values) if len(values) > 1 else 0.0,
+            "median": statistics.median(values),
+            "min": min(values),
+            "max": max(values),
+            "count": len(values),
         }

     # Calculate statistics for each category
     for category in sorted(category_aggregates.keys()):
         results[f"category_{category}"] = {}
         for metric_name, values in category_aggregates[category].items():
             if values:  # Only calculate if we have values for this category
                 results[f"category_{category}"][metric_name] = {
-                    'mean': statistics.mean(values),
-                    'std': statistics.stdev(values) if len(values) > 1 else 0.0,
-                    'median': statistics.median(values),
-                    'min': min(values),
-                    'max': max(values),
-                    'count': len(values)
+                    "mean": statistics.mean(values),
+                    "std": statistics.stdev(values) if len(values) > 1 else 0.0,
+                    "median": statistics.median(values),
+                    "min": min(values),
+                    "max": max(values),
+                    "count": len(values),
                 }

     return results
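`aggregate_metrics` returns per-metric mean/std/median/min/max/count, both overall and per category. A hypothetical usage sketch; the import path and the toy metric dicts are assumptions:

```python
from utils import aggregate_metrics  # hypothetical import path for the module above

all_metrics = [
    {"exact_match": 1, "f1": 1.0},  # toy per-example metric dicts
    {"exact_match": 0, "f1": 0.5},
]
all_categories = [1, 2]

stats = aggregate_metrics(all_metrics, all_categories)
print(stats["overall"]["f1"]["mean"])      # 0.75
print(stats["category_1"]["f1"]["count"])  # 1
```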