Fix all lint errors (#2627)
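The diff below sorts and deduplicates the module imports (standard library first, then third-party, in the order isort/ruff would produce), binds and prints the exception that the BLEU helper previously swallowed, and deletes the unused placeholder variables and commented-out entries from calculate_metrics.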
@@ -1,8 +1,9 @@
-from openai import OpenAI
 import argparse
 import json
 from collections import defaultdict
+
 import numpy as np
-import argparse
+from openai import OpenAI
 
 client = OpenAI()
+
@@ -10,22 +10,17 @@ Borrowed from https://github.com/WujiangXu/AgenticMemory/blob/main/utils.py
 }
 """
 
 import re
 import string
-import numpy as np
-from typing import List, Dict, Union
-import statistics
-from collections import defaultdict
-from rouge_score import rouge_scorer
-from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
-from bert_score import score as bert_score
+from typing import Dict, List, Union
+
+import nltk
+from bert_score import score as bert_score
+from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
 from nltk.translate.meteor_score import meteor_score
+from rouge_score import rouge_scorer
 from sentence_transformers import SentenceTransformer
-import logging
-from dataclasses import dataclass
-from pathlib import Path
-from openai import OpenAI
 
 # from load_dataset import load_locomo_dataset, QA, Turn, Session, Conversation
 from sentence_transformers.util import pytorch_cos_sim
 
@@ -71,7 +66,7 @@ def calculate_bleu_scores(prediction: str, reference: str) -> Dict[str, float]:
     for n, weights in enumerate(weights_list, start=1):
         try:
             score = sentence_bleu(ref_tokens, pred_tokens, weights=weights, smoothing_function=smooth)
-        except Exception:
+        except Exception as e:
+            print(f"Error calculating BLEU score: {e}")
             score = 0.0
         scores[f'bleu{n}'] = score
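For context, a self-contained sketch of the fixed handler. The tokenization, weights_list, and smooth below are assumptions (the diff shows only the loop); NLTK's sentence_bleu takes a list of reference token lists followed by the hypothesis tokens.

from typing import Dict

from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu


def calculate_bleu_scores(prediction: str, reference: str) -> Dict[str, float]:
    pred_tokens = prediction.split()
    ref_tokens = [reference.split()]  # one reference, wrapped in a list
    smooth = SmoothingFunction().method1
    # Cumulative BLEU-1 through BLEU-4 weights
    weights_list = [
        (1.0, 0, 0, 0),
        (0.5, 0.5, 0, 0),
        (1 / 3, 1 / 3, 1 / 3, 0),
        (0.25, 0.25, 0.25, 0.25),
    ]
    scores = {}
    for n, weights in enumerate(weights_list, start=1):
        try:
            score = sentence_bleu(ref_tokens, pred_tokens, weights=weights, smoothing_function=smooth)
        except Exception as e:
            # The lint fix: bind the exception and report it instead of
            # silently falling back to 0.0
            print(f"Error calculating BLEU score: {e}")
            score = 0.0
        scores[f"bleu{n}"] = score
    return scores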
@@ -158,21 +153,13 @@ def calculate_metrics(prediction: str, reference: str) -> Dict[str, float]:
     f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
 
     # Calculate all scores
-    rouge_scores = 0 #calculate_rouge_scores(prediction, reference)
     bleu_scores = calculate_bleu_scores(prediction, reference)
-    bert_scores = 0 # calculate_bert_scores(prediction, reference)
-    meteor = 0 # calculate_meteor_score(prediction, reference)
-    sbert_similarity = 0 # calculate_sentence_similarity(prediction, reference)
 
     # Combine all metrics
     metrics = {
         "exact_match": exact_match,
         "f1": f1,
-        # **rouge_scores,
         **bleu_scores,
-        # **bert_scores,
-        # "meteor": meteor,
-        # "sbert_similarity": sbert_similarity
     }
 
     return metrics
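The exact_match and f1 values combined above come from a token-overlap computation; a minimal sketch under the assumption of the usual SQuAD-style normalization (the file credits WujiangXu/AgenticMemory's utils.py), with normalize_answer and exact_match_and_f1 as illustrative names.

import re
import string
from collections import Counter
from typing import Tuple


def normalize_answer(s: str) -> str:
    # Lowercase, strip punctuation, drop articles, collapse whitespace
    s = s.lower()
    s = "".join(ch for ch in s if ch not in set(string.punctuation))
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())


def exact_match_and_f1(prediction: str, reference: str) -> Tuple[float, float]:
    exact_match = float(normalize_answer(prediction) == normalize_answer(reference))
    pred_tokens = normalize_answer(prediction).split()
    ref_tokens = normalize_answer(reference).split()
    common = Counter(pred_tokens) & Counter(ref_tokens)  # per-token min counts
    num_same = sum(common.values())
    if num_same == 0:
        return exact_match, 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(ref_tokens)
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
    return exact_match, f1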