[Feature] Add support for RAG evaluation (#1154)
Co-authored-by: Deven Patel <deven298@yahoo.com>
Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
@@ -1,13 +1,15 @@
 import ast
+import concurrent.futures
 import json
 import logging
 import os
 import sqlite3
 import uuid
-from typing import Any, Optional
+from typing import Any, Optional, Union
 
 import requests
 import yaml
+from tqdm import tqdm
 
 from embedchain.cache import (Config, ExactMatchEvaluation,
                               SearchDistanceEvaluation, cache,
@@ -18,11 +20,15 @@ from embedchain.constants import SQLITE_PATH
 from embedchain.embedchain import EmbedChain
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.embedder.openai import OpenAIEmbedder
+from embedchain.eval.base import BaseMetric
+from embedchain.eval.metrics import (AnswerRelevance, ContextRelevance,
+                                     Groundedness)
 from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
 from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
 from embedchain.llm.openai import OpenAILlm
 from embedchain.telemetry.posthog import AnonymousTelemetry
+from embedchain.utils.eval import EvalData, EvalMetric
 from embedchain.utils.misc import validate_config
 from embedchain.vectordb.base import BaseVectorDB
 from embedchain.vectordb.chroma import ChromaDB
@@ -455,3 +461,103 @@ class App(EmbedChain):
             chunker=chunker_config_data,
             cache_config=cache_config,
         )
+
+    def _eval(self, dataset: list[EvalData], metric: Union[BaseMetric, str]):
+        """
+        Evaluate the app on a dataset for a given metric.
+        """
+        metric_str = metric.name if isinstance(metric, BaseMetric) else metric
+        eval_class_map = {
+            EvalMetric.CONTEXT_RELEVANCY.value: ContextRelevance,
+            EvalMetric.ANSWER_RELEVANCY.value: AnswerRelevance,
+            EvalMetric.GROUNDEDNESS.value: Groundedness,
+        }
+
+        if metric_str in eval_class_map:
+            return eval_class_map[metric_str]().evaluate(dataset)
+
+        # Handle the case for custom metrics
+        if isinstance(metric, BaseMetric):
+            return metric.evaluate(dataset)
+        else:
+            raise ValueError(f"Invalid metric: {metric}")
+
+    def evaluate(
+        self,
+        questions: Union[str, list[str]],
+        metrics: Optional[list[Union[BaseMetric, str]]] = None,
+        num_workers: int = 4,
+    ):
+        """
+        Evaluate the app on a question.
+
+        param: questions: A question or a list of questions to evaluate.
+        type: questions: Union[str, list[str]]
+        param: metrics: A list of metrics to evaluate. Defaults to all metrics.
+        type: metrics: Optional[list[Union[BaseMetric, str]]]
+        param: num_workers: Number of workers to use for parallel processing.
+        type: num_workers: int
+        return: A dictionary containing the evaluation results.
+        rtype: dict
+        """
+        if "OPENAI_API_KEY" not in os.environ:
+            raise ValueError("Please set the OPENAI_API_KEY environment variable with permission to use `gpt4` model.")
+
+        queries, answers, contexts = [], [], []
+        if isinstance(questions, list):
+            with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
+                future_to_data = {executor.submit(self.query, q, citations=True): q for q in questions}
+                for future in tqdm(
+                    concurrent.futures.as_completed(future_to_data),
+                    total=len(future_to_data),
+                    desc="Getting answer and contexts for questions",
+                ):
+                    question = future_to_data[future]
+                    queries.append(question)
+                    answer, context = future.result()
+                    answers.append(answer)
+                    contexts.append(list(map(lambda x: x[0], context)))
+        else:
+            answer, context = self.query(questions, citations=True)
+            queries = [questions]
+            answers = [answer]
+            contexts = [list(map(lambda x: x[0], context))]
+
+        metrics = metrics or [
+            EvalMetric.CONTEXT_RELEVANCY.value,
+            EvalMetric.ANSWER_RELEVANCY.value,
+            EvalMetric.GROUNDEDNESS.value,
+        ]
+
+        logging.info(f"Collecting data from {len(queries)} questions for evaluation...")
+        dataset = []
+        for q, a, c in zip(queries, answers, contexts):
+            dataset.append(EvalData(question=q, answer=a, contexts=c))
+
+        logging.info(f"Evaluating {len(dataset)} data points...")
+        result = {}
+        with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
+            future_to_metric = {executor.submit(self._eval, dataset, metric): metric for metric in metrics}
+            for future in tqdm(
+                concurrent.futures.as_completed(future_to_metric),
+                total=len(future_to_metric),
+                desc="Evaluating metrics",
+            ):
+                metric = future_to_metric[future]
+                if isinstance(metric, BaseMetric):
+                    result[metric.name] = future.result()
+                else:
+                    result[metric] = future.result()
+
+        if self.config.collect_metrics:
+            telemetry_props = self._telemetry_props
+            metrics_names = []
+            for metric in metrics:
+                if isinstance(metric, BaseMetric):
+                    metrics_names.append(metric.name)
+                else:
+                    metrics_names.append(metric)
+            telemetry_props["metrics"] = metrics_names
+            self.telemetry.capture(event_name="evaluate", properties=telemetry_props)
+
+        return result
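
For reference, a minimal usage sketch of the new evaluate() API (illustration only, not part of the diff). It assumes an App with a data source already added and OPENAI_API_KEY exported; the metric names are the EvalMetric values referenced above, and the source URL is a placeholder.

import os

from embedchain import App
from embedchain.utils.eval import EvalMetric

os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder key

app = App()
app.add("https://example.com/some-article")  # hypothetical data source

# Omitting `metrics` runs all three built-in metrics; they are listed
# explicitly here to show the accepted string values.
result = app.evaluate(
    questions=["What is the article about?", "Who wrote it?"],
    metrics=[
        EvalMetric.CONTEXT_RELEVANCY.value,
        EvalMetric.ANSWER_RELEVANCY.value,
        EvalMetric.GROUNDEDNESS.value,
    ],
    num_workers=2,
)
# `result` maps each metric name to the score returned by that metric.
print(result)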
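Because _eval() falls back to metric.evaluate(dataset) for anything outside the built-in map and keys the result dict by metric.name, a custom metric only needs those two members. A hypothetical sketch follows; the exact BaseMetric constructor is an assumption, since the diff relies only on .name and evaluate(dataset).

from embedchain.eval.base import BaseMetric
from embedchain.utils.eval import EvalData


class AnswerLengthMetric(BaseMetric):
    """Toy custom metric: average answer length in characters."""

    def __init__(self):
        # _eval() uses this attribute as the key in the result dict.
        self.name = "answer_length"

    def evaluate(self, dataset: list[EvalData]) -> float:
        # EvalData carries question/answer/contexts, as constructed in evaluate().
        if not dataset:
            return 0.0
        return sum(len(row.answer) for row in dataset) / len(dataset)


# Then app.evaluate(questions, metrics=[AnswerLengthMetric()]) returns a dict
# containing result["answer_length"].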