Files
t6_mem0/evaluation/metrics/utils.py
2025-05-06 01:16:02 +05:30

212 lines
7.5 KiB
Python

"""
Borrowed from https://github.com/WujiangXu/AgenticMemory/blob/main/utils.py
@article{xu2025mem,
title={A-mem: Agentic memory for llm agents},
author={Xu, Wujiang and Liang, Zujie and Mei, Kai and Gao, Hang and Tan, Juntao
and Zhang, Yongfeng},
journal={arXiv preprint arXiv:2502.12110},
year={2025}
}
"""
import statistics
from collections import defaultdict
from typing import Dict, List, Union
import nltk
from bert_score import score as bert_score
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
from nltk.translate.meteor_score import meteor_score
from rouge_score import rouge_scorer
from sentence_transformers import SentenceTransformer
# from load_dataset import load_locomo_dataset, QA, Turn, Session, Conversation
from sentence_transformers.util import pytorch_cos_sim
# Download required NLTK data.
# 'punkt' / 'punkt_tab' back nltk.word_tokenize (NLTK >= 3.9 looks up
# 'punkt_tab' instead of the pickled 'punkt' models); 'wordnet' backs
# meteor_score. Each resource is fetched in its own try block so that one
# failure does not skip the remaining downloads, and a False return value from
# nltk.download (its way of signalling a failed download) is reported instead
# of being silently ignored.
for _nltk_resource in ('punkt', 'punkt_tab', 'wordnet'):
    try:
        if not nltk.download(_nltk_resource, quiet=True):
            print(f"Error downloading NLTK data: could not fetch '{_nltk_resource}'")
    except Exception as e:
        print(f"Error downloading NLTK data: {e}")
# Load the shared SentenceTransformer encoder once at import time so every
# similarity call reuses it. The name is bound to None up front; it only gets
# replaced when loading succeeds, which lets callers test `sentence_model is None`
# to detect that the model is unavailable.
sentence_model = None
try:
    sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
except Exception as e:
    print(f"Warning: Could not load SentenceTransformer model: {e}")
def simple_tokenize(text):
    """Lowercase ``text`` and split it into whitespace-delimited tokens.

    The input is coerced with ``str()`` first, so non-string values are
    accepted. The characters ``.``, ``,``, ``!`` and ``?`` act as token
    separators; any other punctuation stays attached to its token.

    Returns:
        A list of lowercase tokens (empty when there is nothing to split).
    """
    # One translate pass maps each separator character to a space.
    separators_to_space = str.maketrans('.,!?', '    ')
    lowered = str(text).lower()
    return lowered.translate(separators_to_space).split()
def calculate_rouge_scores(prediction: str, reference: str) -> Dict[str, float]:
    """Score ``prediction`` against ``reference`` with ROUGE-1, ROUGE-2 and ROUGE-L.

    A new stemming ``RougeScorer`` is built on each call.

    Returns:
        A dict with the keys ``rouge1_f``, ``rouge2_f`` and ``rougeL_f``, each
        holding the F-measure of the corresponding ROUGE variant.
    """
    rouge_types = ['rouge1', 'rouge2', 'rougeL']
    scorer = rouge_scorer.RougeScorer(rouge_types, use_stemmer=True)
    # The reference text is passed first, the prediction second.
    results = scorer.score(reference, prediction)
    return {f'{rouge_type}_f': results[rouge_type].fmeasure for rouge_type in rouge_types}
def calculate_bleu_scores(prediction: str, reference: str) -> Dict[str, float]:
    """Compute smoothed sentence-level BLEU-1 through BLEU-4.

    Both texts are lowercased and tokenized with ``nltk.word_tokenize``.
    Tokenization happens outside the error handling, so a tokenizer failure
    propagates to the caller; only a failing ``sentence_bleu`` call is caught,
    reported on stdout and scored as ``0.0``.

    Returns:
        A dict with the keys ``bleu1`` .. ``bleu4``.
    """
    hypothesis = nltk.word_tokenize(prediction.lower())
    references = [nltk.word_tokenize(reference.lower())]
    smoothing = SmoothingFunction().method1
    # n-gram weights for each reported BLEU order.
    weights_by_key = {
        'bleu1': (1, 0, 0, 0),
        'bleu2': (0.5, 0.5, 0, 0),
        'bleu3': (0.33, 0.33, 0.33, 0),
        'bleu4': (0.25, 0.25, 0.25, 0.25),
    }
    results = {}
    for key, weights in weights_by_key.items():
        try:
            value = sentence_bleu(references, hypothesis, weights=weights, smoothing_function=smoothing)
        except Exception as e:
            print(f"Error calculating BLEU score: {e}")
            value = 0.0
        results[key] = value
    return results
def calculate_bert_scores(prediction: str, reference: str) -> Dict[str, float]:
    """Compute BERTScore precision, recall and F1 for one prediction/reference pair.

    The pair is scored as a batch of one with ``lang='en'``. Any exception is
    reported on stdout and all three scores fall back to ``0.0``.

    Returns:
        A dict with the keys ``bert_precision``, ``bert_recall`` and ``bert_f1``.
    """
    result_keys = ('bert_precision', 'bert_recall', 'bert_f1')
    try:
        # bert_score yields (precision, recall, f1); .item() unwraps each value.
        score_tensors = bert_score([prediction], [reference], lang='en', verbose=False)
        precision, recall, f1 = score_tensors
        return dict(zip(result_keys, (precision.item(), recall.item(), f1.item())))
    except Exception as e:
        print(f"Error calculating BERTScore: {e}")
        return dict.fromkeys(result_keys, 0.0)
def calculate_meteor_score(prediction: str, reference: str) -> float:
    """Compute the METEOR score of ``prediction`` against a single ``reference``.

    Both texts are split on whitespace before scoring. Any exception is
    reported on stdout and the score falls back to ``0.0``.
    """
    try:
        reference_tokens = reference.split()
        hypothesis_tokens = prediction.split()
        # meteor_score expects a list of tokenized references.
        return meteor_score([reference_tokens], hypothesis_tokens)
    except Exception as e:
        print(f"Error calculating METEOR score: {e}")
        return 0.0
def calculate_sentence_similarity(prediction: str, reference: str) -> float:
    """Return the cosine similarity between the sentence embeddings of both texts.

    Uses the module-level ``sentence_model``. When that model is ``None``, or
    when encoding or comparison raises, the result is ``0.0`` (errors are
    reported on stdout).
    """
    if sentence_model is None:
        return 0.0
    try:
        # Each text is encoded as its own single-item batch.
        prediction_embedding = sentence_model.encode([prediction], convert_to_tensor=True)
        reference_embedding = sentence_model.encode([reference], convert_to_tensor=True)
        cosine = pytorch_cos_sim(prediction_embedding, reference_embedding)
        return float(cosine.item())
    except Exception as e:
        print(f"Error calculating sentence similarity: {e}")
        return 0.0
def calculate_metrics(prediction: str, reference: str) -> Dict[str, float]:
    """Calculate evaluation metrics for a prediction against a reference answer.

    Both inputs are converted with ``str()`` and stripped *before* the
    emptiness check. This means a falsy but meaningful answer such as the
    number ``0`` is scored normally, while ``None``, empty strings and
    whitespace-only strings are all treated as missing.

    Args:
        prediction: Predicted answer; non-string values are stringified.
        reference: Ground-truth answer; non-string values are stringified.

    Returns:
        For non-empty inputs, a dict with ``exact_match`` (0 or 1,
        case-insensitive comparison), ``f1`` (F1 over the *sets* of tokens
        produced by ``simple_tokenize``, so repeated tokens count once) and
        ``bleu1`` .. ``bleu4``.

        When either input is missing, an all-zero dict is returned. It also
        carries the keys ``rouge1_f``, ``rouge2_f``, ``rougeL_f``, ``bert_f1``,
        ``meteor`` and ``sbert_similarity``; those are kept for backward
        compatibility even though the normal path does not compute them.
    """
    # Normalise first so the emptiness test sees the text that would be scored.
    prediction = "" if prediction is None else str(prediction).strip()
    reference = "" if reference is None else str(reference).strip()

    # Handle missing values: nothing to compare, every metric is zero.
    if not prediction or not reference:
        return {
            "exact_match": 0,
            "f1": 0.0,
            "rouge1_f": 0.0,
            "rouge2_f": 0.0,
            "rougeL_f": 0.0,
            "bleu1": 0.0,
            "bleu2": 0.0,
            "bleu3": 0.0,
            "bleu4": 0.0,
            "bert_f1": 0.0,
            "meteor": 0.0,
            "sbert_similarity": 0.0,
        }

    # Case-insensitive exact match.
    exact_match = int(prediction.lower() == reference.lower())

    # Token-based F1 over unique tokens.
    pred_tokens = set(simple_tokenize(prediction))
    ref_tokens = set(simple_tokenize(reference))
    common_tokens = pred_tokens & ref_tokens
    # Punctuation-only text can still tokenize to nothing, so keep the guard.
    if not pred_tokens or not ref_tokens:
        f1 = 0.0
    else:
        precision = len(common_tokens) / len(pred_tokens)
        recall = len(common_tokens) / len(ref_tokens)
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

    bleu_scores = calculate_bleu_scores(prediction, reference)

    return {
        "exact_match": exact_match,
        "f1": f1,
        **bleu_scores,
    }
def aggregate_metrics(all_metrics: List[Dict[str, float]], all_categories: List[int]) -> Dict[str, Dict[str, Union[float, Dict[str, float]]]]:
    """Summarise per-sample metrics overall and per category.

    ``all_metrics`` and ``all_categories`` are paired positionally with
    ``zip``, so surplus entries in the longer list are ignored.

    Returns:
        ``{}`` when ``all_metrics`` is empty. Otherwise a dict with an
        ``"overall"`` entry followed by one ``"category_<c>"`` entry per
        category in sorted order. Each entry maps a metric name to a dict of
        ``mean``, ``std`` (sample standard deviation, ``0.0`` for a single
        value), ``median``, ``min``, ``max`` and ``count``.
    """
    if not all_metrics:
        return {}

    def summarize(samples):
        """Return the summary statistics reported for one list of values."""
        return {
            'mean': statistics.mean(samples),
            'std': statistics.stdev(samples) if len(samples) > 1 else 0.0,
            'median': statistics.median(samples),
            'min': min(samples),
            'max': max(samples),
            'count': len(samples)
        }

    # Group every observed value by metric name, pooled and per category.
    pooled = defaultdict(list)
    by_category = defaultdict(lambda: defaultdict(list))
    for sample, category in zip(all_metrics, all_categories):
        for name, value in sample.items():
            pooled[name].append(value)
            # Looked up inside the loop so a sample without metrics does not
            # create an empty category entry.
            by_category[category][name].append(value)

    results = {
        "overall": {name: summarize(values) for name, values in pooled.items()}
    }
    for category in sorted(by_category):
        results[f"category_{category}"] = {
            name: summarize(values)
            for name, values in by_category[category].items()
            if values
        }
    return results