[Feature] Add support for RAG evaluation (#1154)
Co-authored-by: Deven Patel <deven298@yahoo.com>
Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
@@ -1,13 +1,15 @@
import ast
import concurrent.futures
import json
import logging
import os
import sqlite3
import uuid
from typing import Any, Optional
from typing import Any, Optional, Union

import requests
import yaml
from tqdm import tqdm

from embedchain.cache import (Config, ExactMatchEvaluation,
                              SearchDistanceEvaluation, cache,
@@ -18,11 +20,15 @@ from embedchain.constants import SQLITE_PATH
from embedchain.embedchain import EmbedChain
from embedchain.embedder.base import BaseEmbedder
from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.eval.base import BaseMetric
from embedchain.eval.metrics import (AnswerRelevance, ContextRelevance,
                                     Groundedness)
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
from embedchain.llm.openai import OpenAILlm
from embedchain.telemetry.posthog import AnonymousTelemetry
from embedchain.utils.eval import EvalData, EvalMetric
from embedchain.utils.misc import validate_config
from embedchain.vectordb.base import BaseVectorDB
from embedchain.vectordb.chroma import ChromaDB
@@ -455,3 +461,103 @@ class App(EmbedChain):
            chunker=chunker_config_data,
            cache_config=cache_config,
        )

    def _eval(self, dataset: list[EvalData], metric: Union[BaseMetric, str]):
        """
        Evaluate the app on a dataset for a given metric.
        """
        metric_str = metric.name if isinstance(metric, BaseMetric) else metric
        eval_class_map = {
            EvalMetric.CONTEXT_RELEVANCY.value: ContextRelevance,
            EvalMetric.ANSWER_RELEVANCY.value: AnswerRelevance,
            EvalMetric.GROUNDEDNESS.value: Groundedness,
        }

        if metric_str in eval_class_map:
            return eval_class_map[metric_str]().evaluate(dataset)

        # Handle the case for custom metrics
        if isinstance(metric, BaseMetric):
            return metric.evaluate(dataset)
        else:
            raise ValueError(f"Invalid metric: {metric}")

    def evaluate(
        self,
        questions: Union[str, list[str]],
        metrics: Optional[list[Union[BaseMetric, str]]] = None,
        num_workers: int = 4,
    ):
        """
        Evaluate the app on a question or a list of questions.

        param: questions: A question or a list of questions to evaluate.
        type: questions: Union[str, list[str]]
        param: metrics: A list of metrics to evaluate. Defaults to all metrics.
        type: metrics: Optional[list[Union[BaseMetric, str]]]
        param: num_workers: Number of workers to use for parallel processing.
        type: num_workers: int
        return: A dictionary containing the evaluation results.
        rtype: dict
        """
        if "OPENAI_API_KEY" not in os.environ:
            raise ValueError("Please set the OPENAI_API_KEY environment variable with permission to use the `gpt-4` model.")

        queries, answers, contexts = [], [], []
        if isinstance(questions, list):
            with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
                future_to_data = {executor.submit(self.query, q, citations=True): q for q in questions}
                for future in tqdm(
                    concurrent.futures.as_completed(future_to_data),
                    total=len(future_to_data),
                    desc="Getting answer and contexts for questions",
                ):
                    question = future_to_data[future]
                    queries.append(question)
                    answer, context = future.result()
                    answers.append(answer)
                    contexts.append(list(map(lambda x: x[0], context)))
        else:
            answer, context = self.query(questions, citations=True)
            queries = [questions]
            answers = [answer]
            contexts = [list(map(lambda x: x[0], context))]

        metrics = metrics or [
            EvalMetric.CONTEXT_RELEVANCY.value,
            EvalMetric.ANSWER_RELEVANCY.value,
            EvalMetric.GROUNDEDNESS.value,
        ]

        logging.info(f"Collecting data from {len(queries)} questions for evaluation...")
        dataset = []
        for q, a, c in zip(queries, answers, contexts):
            dataset.append(EvalData(question=q, answer=a, contexts=c))

        logging.info(f"Evaluating {len(dataset)} data points...")
        result = {}
        with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
            future_to_metric = {executor.submit(self._eval, dataset, metric): metric for metric in metrics}
            for future in tqdm(
                concurrent.futures.as_completed(future_to_metric),
                total=len(future_to_metric),
                desc="Evaluating metrics",
            ):
                metric = future_to_metric[future]
                if isinstance(metric, BaseMetric):
                    result[metric.name] = future.result()
                else:
                    result[metric] = future.result()

        if self.config.collect_metrics:
            telemetry_props = self._telemetry_props
            metrics_names = []
            for metric in metrics:
                if isinstance(metric, BaseMetric):
                    metrics_names.append(metric.name)
                else:
                    metrics_names.append(metric)
            telemetry_props["metrics"] = metrics_names
            self.telemetry.capture(event_name="evaluate", properties=telemetry_props)

        return result
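A minimal usage sketch of the evaluate API added above, assuming an App instance with data already added and OPENAI_API_KEY set; the source URL, question, and printed numbers are illustrative, not part of the diff:

# Hypothetical usage of the new evaluation API (values are illustrative).
from embedchain import App

app = App()
app.add("https://www.example.com/some-doc")  # any supported data source
scores = app.evaluate(
    questions=["What does the document say about X?"],
    metrics=["answer_relevancy", "groundedness"],  # strings or BaseMetric instances; defaults to all three metrics
    num_workers=2,
)
print(scores)  # e.g. {"answer_relevancy": 0.92, "groundedness": 0.88}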
2  embedchain/config/eval/__init__.py  Normal file
@@ -0,0 +1,2 @@
from .base import (AnswerRelevanceConfig, ContextRelevanceConfig,  # noqa: F401
                   GroundednessConfig)
92  embedchain/config/eval/base.py  Normal file
@@ -0,0 +1,92 @@
from typing import Optional

from embedchain.config.base_config import BaseConfig

ANSWER_RELEVANCY_PROMPT = """
Please provide $num_gen_questions questions from the provided answer.
You must provide the complete question; if you are not able to provide the complete question, return empty string ("").
Please only provide one question per line without numbers or bullets to distinguish them.
You must only provide the questions and no other text.

$answer
"""  # noqa:E501


CONTEXT_RELEVANCY_PROMPT = """
Please extract relevant sentences from the provided context that are required to answer the given question.
If no relevant sentences are found, or if you believe the question cannot be answered from the given context, return the empty string ("").
While extracting candidate sentences you're not allowed to make any changes to sentences from the given context or make up any sentences.
You must only provide sentences from the given context and nothing else.

Context: $context
Question: $question
"""  # noqa:E501

GROUNDEDNESS_ANSWER_CLAIMS_PROMPT = """
Please provide one or more statements from each sentence of the provided answer.
You must provide the semantically equivalent statements for each sentence of the answer.
You must provide the complete statement; if you are not able to provide the complete statement, return empty string ("").
Please only provide one statement per line WITHOUT numbers or bullets.
If the question provided is not being answered in the provided answer, return empty string ("").
You must only provide the statements and no other text.

$question
$answer
"""  # noqa:E501

GROUNDEDNESS_CLAIMS_INFERENCE_PROMPT = """
Given the context and the provided claim statements, please provide a verdict for each claim statement on whether it can be completely inferred from the given context or not.
Use only "1" (yes), "0" (no) and "-1" (null) for "yes", "no" or "null" respectively.
You must provide one verdict per line, ONLY WITH "1", "0" or "-1" as your verdict for the given statement and nothing else.
You must provide the verdicts in the same order as the claim statements.

Contexts:
$context

Claim statements:
$claim_statements
"""  # noqa:E501


class GroundednessConfig(BaseConfig):
    def __init__(
        self,
        model: str = "gpt-4",
        api_key: Optional[str] = None,
        answer_claims_prompt: str = GROUNDEDNESS_ANSWER_CLAIMS_PROMPT,
        claims_inference_prompt: str = GROUNDEDNESS_CLAIMS_INFERENCE_PROMPT,
    ):
        self.model = model
        self.api_key = api_key
        self.answer_claims_prompt = answer_claims_prompt
        self.claims_inference_prompt = claims_inference_prompt


class AnswerRelevanceConfig(BaseConfig):
    def __init__(
        self,
        model: str = "gpt-4",
        embedder: str = "text-embedding-ada-002",
        api_key: Optional[str] = None,
        num_gen_questions: int = 1,
        prompt: str = ANSWER_RELEVANCY_PROMPT,
    ):
        self.model = model
        self.embedder = embedder
        self.api_key = api_key
        self.num_gen_questions = num_gen_questions
        self.prompt = prompt


class ContextRelevanceConfig(BaseConfig):
    def __init__(
        self,
        model: str = "gpt-4",
        api_key: Optional[str] = None,
        language: str = "en",
        prompt: str = CONTEXT_RELEVANCY_PROMPT,
    ):
        self.model = model
        self.api_key = api_key
        self.language = language
        self.prompt = prompt
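A short sketch of how these config classes might be wired into a metric, with defaults overridden explicitly; the parameter values are illustrative, not part of the diff:

# Hypothetical: customize a metric via its eval config (values are illustrative).
from embedchain.config.eval.base import AnswerRelevanceConfig
from embedchain.eval.metrics import AnswerRelevance

config = AnswerRelevanceConfig(model="gpt-4", num_gen_questions=3)
metric = AnswerRelevance(config=config)
# The configured metric can then be passed to app.evaluate(questions, metrics=[metric]).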
@@ -7,12 +7,9 @@ from typing import Any, Optional, Union
from dotenv import load_dotenv
from langchain.docstore.document import Document

from embedchain.cache import (
    adapt,
    get_gptcache_session,
    gptcache_data_convert,
    gptcache_update_cache_callback,
)
from embedchain.cache import (adapt, get_gptcache_session,
                              gptcache_data_convert,
                              gptcache_update_cache_callback)
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
from embedchain.config.base_app_config import BaseAppConfig
@@ -22,7 +19,8 @@ from embedchain.embedder.base import BaseEmbedder
from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.llm.base import BaseLlm
from embedchain.loaders.base_loader import BaseLoader
from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
from embedchain.models.data_type import (DataType, DirectDataType,
                                         IndirectDataType, SpecialDataType)
from embedchain.telemetry.posthog import AnonymousTelemetry
from embedchain.utils.misc import detect_datatype, is_valid_json_string
from embedchain.vectordb.base import BaseVectorDB
0  embedchain/eval/__init__.py  Normal file
29  embedchain/eval/base.py  Normal file
@@ -0,0 +1,29 @@
from abc import ABC, abstractmethod

from embedchain.utils.eval import EvalData


class BaseMetric(ABC):
    """Base class for a metric.

    This class provides a common interface for all metrics.
    """

    def __init__(self, name: str = "base_metric"):
        """
        Initialize the BaseMetric.
        """
        self.name = name

    @abstractmethod
    def evaluate(self, dataset: list[EvalData]):
        """
        Abstract method to evaluate the dataset.

        This method should be implemented by subclasses to perform the actual
        evaluation on the dataset.

        :param dataset: dataset to evaluate
        :type dataset: list[EvalData]
        """
        raise NotImplementedError()
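A minimal sketch of a custom metric built on BaseMetric; the class name and scoring rule are purely illustrative and not part of the diff:

# Hypothetical custom metric: average answer length in words.
from embedchain.eval.base import BaseMetric
from embedchain.utils.eval import EvalData


class AnswerLength(BaseMetric):
    def __init__(self):
        super().__init__(name="answer_length")

    def evaluate(self, dataset: list[EvalData]):
        # Return a single aggregate number, like the built-in metrics do.
        if not dataset:
            return 0.0
        return sum(len(data.answer.split()) for data in dataset) / len(dataset)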
3  embedchain/eval/metrics/__init__.py  Normal file
@@ -0,0 +1,3 @@
from .answer_relevancy import AnswerRelevance  # noqa: F401
from .context_relevancy import ContextRelevance  # noqa: F401
from .groundedness import Groundedness  # noqa: F401
93  embedchain/eval/metrics/answer_relevancy.py  Normal file
@@ -0,0 +1,93 @@
import concurrent.futures
import logging
import os
from string import Template
from typing import Optional

import numpy as np
from openai import OpenAI
from tqdm import tqdm

from embedchain.config.eval.base import AnswerRelevanceConfig
from embedchain.eval.base import BaseMetric
from embedchain.utils.eval import EvalData, EvalMetric


class AnswerRelevance(BaseMetric):
    """
    Metric for evaluating the relevance of answers.
    """

    def __init__(self, config: Optional[AnswerRelevanceConfig] = AnswerRelevanceConfig()):
        super().__init__(name=EvalMetric.ANSWER_RELEVANCY.value)
        self.config = config
        api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("API key not found. Set 'OPENAI_API_KEY' or pass it in the config.")
        self.client = OpenAI(api_key=api_key)

    def _generate_prompt(self, data: EvalData) -> str:
        """
        Generates a prompt based on the provided data.
        """
        return Template(self.config.prompt).substitute(
            num_gen_questions=self.config.num_gen_questions, answer=data.answer
        )

    def _generate_questions(self, prompt: str) -> list[str]:
        """
        Generates questions from the prompt.
        """
        response = self.client.chat.completions.create(
            model=self.config.model,
            messages=[{"role": "user", "content": prompt}],
        )
        return response.choices[0].message.content.strip().split("\n")

    def _generate_embedding(self, question: str) -> np.ndarray:
        """
        Generates the embedding for a question.
        """
        response = self.client.embeddings.create(
            input=question,
            model=self.config.embedder,
        )
        return np.array(response.data[0].embedding)

    def _compute_similarity(self, original: np.ndarray, generated: np.ndarray) -> float:
        """
        Computes the cosine similarity between two embeddings.
        """
        original = original.reshape(1, -1)
        norm = np.linalg.norm(original) * np.linalg.norm(generated, axis=1)
        return np.dot(generated, original.T).flatten() / norm

    def _compute_score(self, data: EvalData) -> float:
        """
        Computes the relevance score for a given data item.
        """
        prompt = self._generate_prompt(data)
        generated_questions = self._generate_questions(prompt)
        original_embedding = self._generate_embedding(data.question)
        generated_embeddings = np.array([self._generate_embedding(q) for q in generated_questions])
        similarities = self._compute_similarity(original_embedding, generated_embeddings)
        return np.mean(similarities)

    def evaluate(self, dataset: list[EvalData]) -> float:
        """
        Evaluates the dataset and returns the average answer relevance score.
        """
        results = []

        with concurrent.futures.ThreadPoolExecutor() as executor:
            future_to_data = {executor.submit(self._compute_score, data): data for data in dataset}
            for future in tqdm(
                concurrent.futures.as_completed(future_to_data), total=len(dataset), desc="Evaluating Answer Relevancy"
            ):
                data = future_to_data[future]
                try:
                    results.append(future.result())
                except Exception as e:
                    logging.error(f"Error evaluating answer relevancy for {data}: {e}")

        return np.mean(results) if results else 0.0
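For intuition on the score above: answer relevancy is the mean cosine similarity between the embedding of the original question and the embeddings of the questions the model regenerates from the answer. As a purely illustrative example, similarities of 0.95 and 0.85 for two generated questions give (0.95 + 0.85) / 2 = 0.90 for that data point.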
69  embedchain/eval/metrics/context_relevancy.py  Normal file
@@ -0,0 +1,69 @@
import concurrent.futures
import os
from string import Template
from typing import Optional

import numpy as np
import pysbd
from openai import OpenAI
from tqdm import tqdm

from embedchain.config.eval.base import ContextRelevanceConfig
from embedchain.eval.base import BaseMetric
from embedchain.utils.eval import EvalData, EvalMetric


class ContextRelevance(BaseMetric):
    """
    Metric for evaluating the relevance of context in a dataset.
    """

    def __init__(self, config: Optional[ContextRelevanceConfig] = ContextRelevanceConfig()):
        super().__init__(name=EvalMetric.CONTEXT_RELEVANCY.value)
        self.config = config
        api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("API key not found. Set 'OPENAI_API_KEY' or pass it in the config.")
        self.client = OpenAI(api_key=api_key)
        self._sbd = pysbd.Segmenter(language=self.config.language, clean=False)

    def _sentence_segmenter(self, text: str) -> list[str]:
        """
        Segments the given text into sentences.
        """
        return self._sbd.segment(text)

    def _compute_score(self, data: EvalData) -> float:
        """
        Computes the context relevance score for a given data item.
        """
        original_context = "\n".join(data.contexts)
        prompt = Template(self.config.prompt).substitute(context=original_context, question=data.question)
        response = self.client.chat.completions.create(
            model=self.config.model, messages=[{"role": "user", "content": prompt}]
        )
        useful_context = response.choices[0].message.content.strip()
        useful_context_sentences = self._sentence_segmenter(useful_context)
        original_context_sentences = self._sentence_segmenter(original_context)

        if not original_context_sentences:
            return 0.0
        return len(useful_context_sentences) / len(original_context_sentences)

    def evaluate(self, dataset: list[EvalData]) -> float:
        """
        Evaluates the dataset and returns the average context relevance score.
        """
        scores = []

        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [executor.submit(self._compute_score, data) for data in dataset]
            for future in tqdm(
                concurrent.futures.as_completed(futures), total=len(dataset), desc="Evaluating Context Relevancy"
            ):
                try:
                    scores.append(future.result())
                except Exception as e:
                    print(f"Error during evaluation: {e}")

        return np.mean(scores) if scores else 0.0
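For intuition on the score above: context relevancy is the fraction of context sentences the model extracts as necessary for answering the question. As a purely illustrative example, if 3 of 10 context sentences are extracted, the score for that data point is 3 / 10 = 0.3.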
102  embedchain/eval/metrics/groundedness.py  Normal file
@@ -0,0 +1,102 @@
import concurrent.futures
import logging
import os
from string import Template
from typing import Optional

import numpy as np
from openai import OpenAI
from tqdm import tqdm

from embedchain.config.eval.base import GroundednessConfig
from embedchain.eval.base import BaseMetric
from embedchain.utils.eval import EvalData, EvalMetric


class Groundedness(BaseMetric):
    """
    Metric for groundedness (aka faithfulness) of the answer with respect to the given contexts.
    """

    def __init__(self, config: Optional[GroundednessConfig] = None):
        super().__init__(name=EvalMetric.GROUNDEDNESS.value)
        self.config = config or GroundednessConfig()
        api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("Please set the OPENAI_API_KEY environment variable or pass the `api_key` in config.")
        self.client = OpenAI(api_key=api_key)

    def _generate_answer_claim_prompt(self, data: EvalData) -> str:
        """
        Generate the prompt for the given data.
        """
        prompt = Template(self.config.answer_claims_prompt).substitute(question=data.question, answer=data.answer)
        return prompt

    def _get_claim_statements(self, prompt: str) -> np.ndarray:
        """
        Get claim statements from the answer.
        """
        response = self.client.chat.completions.create(
            model=self.config.model,
            messages=[{"role": "user", "content": f"{prompt}"}],
        )
        result = response.choices[0].message.content.strip()
        claim_statements = np.array([statement for statement in result.split("\n") if statement])
        return claim_statements

    def _generate_claim_inference_prompt(self, data: EvalData, claim_statements: list[str]) -> str:
        """
        Generate the claim inference prompt for the given data and claim statements.
        """
        prompt = Template(self.config.claims_inference_prompt).substitute(
            context="\n".join(data.contexts), claim_statements="\n".join(claim_statements)
        )
        return prompt

    def _get_claim_verdict_scores(self, prompt: str) -> np.ndarray:
        """
        Get verdicts for claim statements.
        """
        response = self.client.chat.completions.create(
            model=self.config.model,
            messages=[{"role": "user", "content": f"{prompt}"}],
        )
        result = response.choices[0].message.content.strip()
        claim_verdicts = result.split("\n")
        verdict_score_map = {"1": 1, "0": 0, "-1": np.nan}
        verdict_scores = np.array([verdict_score_map[verdict] for verdict in claim_verdicts])
        return verdict_scores

    def _compute_score(self, data: EvalData) -> float:
        """
        Compute the groundedness score (aka faithfulness) for a single data point.
        """
        answer_claims_prompt = self._generate_answer_claim_prompt(data)
        claim_statements = self._get_claim_statements(answer_claims_prompt)

        claim_inference_prompt = self._generate_claim_inference_prompt(data, claim_statements)
        verdict_scores = self._get_claim_verdict_scores(claim_inference_prompt)
        return np.sum(verdict_scores) / claim_statements.size

    def evaluate(self, dataset: list[EvalData]):
        """
        Evaluate the dataset and return the average groundedness score.
        """
        results = []

        with concurrent.futures.ThreadPoolExecutor() as executor:
            future_to_data = {executor.submit(self._compute_score, data): data for data in dataset}
            for future in tqdm(
                concurrent.futures.as_completed(future_to_data),
                total=len(future_to_data),
                desc="Evaluating groundedness (aka faithfulness)",
            ):
                data = future_to_data[future]
                try:
                    score = future.result()
                    results.append(score)
                except Exception as e:
                    logging.error(f"Error while evaluating groundedness for data point {data}: {e}")

        return np.mean(results) if results else 0.0
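For intuition on the score above: groundedness is the fraction of claims extracted from the answer that the model judges inferable from the contexts. As a purely illustrative example, verdicts of 1, 1, 0, 1 over four claims give (1 + 1 + 0 + 1) / 4 = 0.75 for that data point; a "-1" (null) verdict maps to np.nan, which makes that data point's score NaN.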
17  embedchain/utils/eval.py  Normal file
@@ -0,0 +1,17 @@
from enum import Enum
from typing import Optional

from pydantic import BaseModel


class EvalMetric(Enum):
    CONTEXT_RELEVANCY = "context_relevancy"
    ANSWER_RELEVANCY = "answer_relevancy"
    GROUNDEDNESS = "groundedness"


class EvalData(BaseModel):
    question: str
    contexts: list[str]
    answer: str
    ground_truth: Optional[str] = None  # Not used as of now
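A small illustrative example of the schema above; the strings are made up and not part of the diff:

# Hypothetical EvalData instance as consumed by the metrics.
from embedchain.utils.eval import EvalData, EvalMetric

data = EvalData(
    question="What is Embedchain?",
    contexts=["Embedchain is an open-source RAG framework."],
    answer="Embedchain is a framework for building RAG applications.",
)
assert EvalMetric.GROUNDEDNESS.value == "groundedness"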
@@ -201,7 +201,8 @@ def detect_datatype(source: Any) -> DataType:
    formatted_source = format_source(str(source), 30)

    if url:
        from langchain.document_loaders.youtube import ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
        from langchain.document_loaders.youtube import \
            ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS

        if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
            logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
@@ -6,15 +6,8 @@ from embedchain.helpers.json_serializable import register_deserializable
from embedchain.vectordb.base import BaseVectorDB

try:
    from pymilvus import (
        Collection,
        CollectionSchema,
        DataType,
        FieldSchema,
        MilvusClient,
        connections,
        utility,
    )
    from pymilvus import (Collection, CollectionSchema, DataType, FieldSchema,
                          MilvusClient, connections, utility)
except ImportError:
    raise ImportError(
        "Zilliz requires extra dependencies. Install with `pip install --upgrade embedchain[milvus]`"