Formatting (#2750)
This commit is contained in:
@@ -3,7 +3,7 @@ Borrowed from https://github.com/WujiangXu/AgenticMemory/blob/main/utils.py
|
||||
|
||||
@article{xu2025mem,
|
||||
title={A-mem: Agentic memory for llm agents},
|
||||
author={Xu, Wujiang and Liang, Zujie and Mei, Kai and Gao, Hang and Tan, Juntao
|
||||
author={Xu, Wujiang and Liang, Zujie and Mei, Kai and Gao, Hang and Tan, Juntao
|
||||
and Zhang, Yongfeng},
|
||||
journal={arXiv preprint arXiv:2502.12110},
|
||||
year={2025}
|
||||
@@ -26,42 +26,45 @@ from sentence_transformers.util import pytorch_cos_sim
|
||||
|
||||
# Download required NLTK data
|
||||
try:
|
||||
nltk.download('punkt', quiet=True)
|
||||
nltk.download('wordnet', quiet=True)
|
||||
nltk.download("punkt", quiet=True)
|
||||
nltk.download("wordnet", quiet=True)
|
||||
except Exception as e:
|
||||
print(f"Error downloading NLTK data: {e}")
|
||||
|
||||
# Initialize SentenceTransformer model (this will be reused)
|
||||
try:
|
||||
sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
|
||||
sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not load SentenceTransformer model: {e}")
|
||||
sentence_model = None
|
||||
|
||||
|
||||
def simple_tokenize(text):
|
||||
"""Simple tokenization function."""
|
||||
# Convert to string if not already
|
||||
text = str(text)
|
||||
return text.lower().replace('.', ' ').replace(',', ' ').replace('!', ' ').replace('?', ' ').split()
|
||||
return text.lower().replace(".", " ").replace(",", " ").replace("!", " ").replace("?", " ").split()
|
||||
|
||||
|
||||
def calculate_rouge_scores(prediction: str, reference: str) -> Dict[str, float]:
|
||||
"""Calculate ROUGE scores for prediction against reference."""
|
||||
scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
|
||||
scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
|
||||
scores = scorer.score(reference, prediction)
|
||||
return {
|
||||
'rouge1_f': scores['rouge1'].fmeasure,
|
||||
'rouge2_f': scores['rouge2'].fmeasure,
|
||||
'rougeL_f': scores['rougeL'].fmeasure
|
||||
"rouge1_f": scores["rouge1"].fmeasure,
|
||||
"rouge2_f": scores["rouge2"].fmeasure,
|
||||
"rougeL_f": scores["rougeL"].fmeasure,
|
||||
}
|
||||
|
||||
|
||||
def calculate_bleu_scores(prediction: str, reference: str) -> Dict[str, float]:
|
||||
"""Calculate BLEU scores with different n-gram settings."""
|
||||
pred_tokens = nltk.word_tokenize(prediction.lower())
|
||||
ref_tokens = [nltk.word_tokenize(reference.lower())]
|
||||
|
||||
|
||||
weights_list = [(1, 0, 0, 0), (0.5, 0.5, 0, 0), (0.33, 0.33, 0.33, 0), (0.25, 0.25, 0.25, 0.25)]
|
||||
smooth = SmoothingFunction().method1
|
||||
|
||||
|
||||
scores = {}
|
||||
for n, weights in enumerate(weights_list, start=1):
|
||||
try:
|
||||
@@ -69,26 +72,20 @@ def calculate_bleu_scores(prediction: str, reference: str) -> Dict[str, float]:
|
||||
except Exception as e:
|
||||
print(f"Error calculating BLEU score: {e}")
|
||||
score = 0.0
|
||||
scores[f'bleu{n}'] = score
|
||||
|
||||
scores[f"bleu{n}"] = score
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
def calculate_bert_scores(prediction: str, reference: str) -> Dict[str, float]:
|
||||
"""Calculate BERTScore for semantic similarity."""
|
||||
try:
|
||||
P, R, F1 = bert_score([prediction], [reference], lang='en', verbose=False)
|
||||
return {
|
||||
'bert_precision': P.item(),
|
||||
'bert_recall': R.item(),
|
||||
'bert_f1': F1.item()
|
||||
}
|
||||
P, R, F1 = bert_score([prediction], [reference], lang="en", verbose=False)
|
||||
return {"bert_precision": P.item(), "bert_recall": R.item(), "bert_f1": F1.item()}
|
||||
except Exception as e:
|
||||
print(f"Error calculating BERTScore: {e}")
|
||||
return {
|
||||
'bert_precision': 0.0,
|
||||
'bert_recall': 0.0,
|
||||
'bert_f1': 0.0
|
||||
}
|
||||
return {"bert_precision": 0.0, "bert_recall": 0.0, "bert_f1": 0.0}
|
||||
|
||||
|
||||
def calculate_meteor_score(prediction: str, reference: str) -> float:
|
||||
"""Calculate METEOR score for the prediction."""
|
||||
@@ -98,6 +95,7 @@ def calculate_meteor_score(prediction: str, reference: str) -> float:
|
||||
print(f"Error calculating METEOR score: {e}")
|
||||
return 0.0
|
||||
|
||||
|
||||
def calculate_sentence_similarity(prediction: str, reference: str) -> float:
|
||||
"""Calculate sentence embedding similarity using SentenceBERT."""
|
||||
if sentence_model is None:
|
||||
@@ -106,7 +104,7 @@ def calculate_sentence_similarity(prediction: str, reference: str) -> float:
|
||||
# Encode sentences
|
||||
embedding1 = sentence_model.encode([prediction], convert_to_tensor=True)
|
||||
embedding2 = sentence_model.encode([reference], convert_to_tensor=True)
|
||||
|
||||
|
||||
# Calculate cosine similarity
|
||||
similarity = pytorch_cos_sim(embedding1, embedding2).item()
|
||||
return float(similarity)
|
||||
@@ -114,6 +112,7 @@ def calculate_sentence_similarity(prediction: str, reference: str) -> float:
|
||||
print(f"Error calculating sentence similarity: {e}")
|
||||
return 0.0
|
||||
|
||||
|
||||
def calculate_metrics(prediction: str, reference: str) -> Dict[str, float]:
|
||||
"""Calculate comprehensive evaluation metrics for a prediction."""
|
||||
# Handle empty or None values
|
||||
@@ -130,31 +129,31 @@ def calculate_metrics(prediction: str, reference: str) -> Dict[str, float]:
|
||||
"bleu4": 0.0,
|
||||
"bert_f1": 0.0,
|
||||
"meteor": 0.0,
|
||||
"sbert_similarity": 0.0
|
||||
"sbert_similarity": 0.0,
|
||||
}
|
||||
|
||||
|
||||
# Convert to strings if they're not already
|
||||
prediction = str(prediction).strip()
|
||||
reference = str(reference).strip()
|
||||
|
||||
|
||||
# Calculate exact match
|
||||
exact_match = int(prediction.lower() == reference.lower())
|
||||
|
||||
|
||||
# Calculate token-based F1 score
|
||||
pred_tokens = set(simple_tokenize(prediction))
|
||||
ref_tokens = set(simple_tokenize(reference))
|
||||
common_tokens = pred_tokens & ref_tokens
|
||||
|
||||
|
||||
if not pred_tokens or not ref_tokens:
|
||||
f1 = 0.0
|
||||
else:
|
||||
precision = len(common_tokens) / len(pred_tokens)
|
||||
recall = len(common_tokens) / len(ref_tokens)
|
||||
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
|
||||
|
||||
|
||||
# Calculate all scores
|
||||
bleu_scores = calculate_bleu_scores(prediction, reference)
|
||||
|
||||
|
||||
# Combine all metrics
|
||||
metrics = {
|
||||
"exact_match": exact_match,
|
||||
@@ -164,48 +163,49 @@ def calculate_metrics(prediction: str, reference: str) -> Dict[str, float]:
|
||||
|
||||
return metrics
|
||||
|
||||
def aggregate_metrics(all_metrics: List[Dict[str, float]], all_categories: List[int]) -> Dict[str, Dict[str, Union[float, Dict[str, float]]]]:
|
||||
|
||||
def aggregate_metrics(
|
||||
all_metrics: List[Dict[str, float]], all_categories: List[int]
|
||||
) -> Dict[str, Dict[str, Union[float, Dict[str, float]]]]:
|
||||
"""Calculate aggregate statistics for all metrics, split by category."""
|
||||
if not all_metrics:
|
||||
return {}
|
||||
|
||||
|
||||
# Initialize aggregates for overall and per-category metrics
|
||||
aggregates = defaultdict(list)
|
||||
category_aggregates = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
|
||||
# Collect all values for each metric, both overall and per category
|
||||
for metrics, category in zip(all_metrics, all_categories):
|
||||
for metric_name, value in metrics.items():
|
||||
aggregates[metric_name].append(value)
|
||||
category_aggregates[category][metric_name].append(value)
|
||||
|
||||
|
||||
# Calculate statistics for overall metrics
|
||||
results = {
|
||||
"overall": {}
|
||||
}
|
||||
|
||||
results = {"overall": {}}
|
||||
|
||||
for metric_name, values in aggregates.items():
|
||||
results["overall"][metric_name] = {
|
||||
'mean': statistics.mean(values),
|
||||
'std': statistics.stdev(values) if len(values) > 1 else 0.0,
|
||||
'median': statistics.median(values),
|
||||
'min': min(values),
|
||||
'max': max(values),
|
||||
'count': len(values)
|
||||
"mean": statistics.mean(values),
|
||||
"std": statistics.stdev(values) if len(values) > 1 else 0.0,
|
||||
"median": statistics.median(values),
|
||||
"min": min(values),
|
||||
"max": max(values),
|
||||
"count": len(values),
|
||||
}
|
||||
|
||||
|
||||
# Calculate statistics for each category
|
||||
for category in sorted(category_aggregates.keys()):
|
||||
results[f"category_{category}"] = {}
|
||||
for metric_name, values in category_aggregates[category].items():
|
||||
if values: # Only calculate if we have values for this category
|
||||
results[f"category_{category}"][metric_name] = {
|
||||
'mean': statistics.mean(values),
|
||||
'std': statistics.stdev(values) if len(values) > 1 else 0.0,
|
||||
'median': statistics.median(values),
|
||||
'min': min(values),
|
||||
'max': max(values),
|
||||
'count': len(values)
|
||||
"mean": statistics.mean(values),
|
||||
"std": statistics.stdev(values) if len(values) > 1 else 0.0,
|
||||
"median": statistics.median(values),
|
||||
"min": min(values),
|
||||
"max": max(values),
|
||||
"count": len(values),
|
||||
}
|
||||
|
||||
|
||||
return results
|
||||
|
||||
Reference in New Issue
Block a user