[Feature] Add support for RAG evaluation (#1154)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
Deven Patel
2024-01-11 20:02:47 +05:30
committed by GitHub
parent 69e83adae0
commit e2cca61cd3
18 changed files with 788 additions and 21 deletions

@@ -1,13 +1,15 @@
import ast
import concurrent.futures
import json
import logging
import os
import sqlite3
import uuid
from typing import Any, Optional
from typing import Any, Optional, Union
import requests
import yaml
from tqdm import tqdm
from embedchain.cache import (Config, ExactMatchEvaluation,
                              SearchDistanceEvaluation, cache,
@@ -18,11 +20,15 @@ from embedchain.constants import SQLITE_PATH
from embedchain.embedchain import EmbedChain
from embedchain.embedder.base import BaseEmbedder
from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.eval.base import BaseMetric
from embedchain.eval.metrics import (AnswerRelevance, ContextRelevance,
                                     Groundedness)
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
from embedchain.llm.openai import OpenAILlm
from embedchain.telemetry.posthog import AnonymousTelemetry
from embedchain.utils.eval import EvalData, EvalMetric
from embedchain.utils.misc import validate_config
from embedchain.vectordb.base import BaseVectorDB
from embedchain.vectordb.chroma import ChromaDB
@@ -455,3 +461,103 @@ class App(EmbedChain):
            chunker=chunker_config_data,
            cache_config=cache_config,
        )

    def _eval(self, dataset: list[EvalData], metric: Union[BaseMetric, str]):
        """
        Evaluate the app on a dataset for a given metric.
        """
        metric_str = metric.name if isinstance(metric, BaseMetric) else metric
        eval_class_map = {
            EvalMetric.CONTEXT_RELEVANCY.value: ContextRelevance,
            EvalMetric.ANSWER_RELEVANCY.value: AnswerRelevance,
            EvalMetric.GROUNDEDNESS.value: Groundedness,
        }

        if metric_str in eval_class_map:
            return eval_class_map[metric_str]().evaluate(dataset)

        # Handle the case for custom metrics
        if isinstance(metric, BaseMetric):
            return metric.evaluate(dataset)
        else:
            raise ValueError(f"Invalid metric: {metric}")

    def evaluate(
        self,
        questions: Union[str, list[str]],
        metrics: Optional[list[Union[BaseMetric, str]]] = None,
        num_workers: int = 4,
    ):
        """
        Evaluate the app on a question or a list of questions.

        :param questions: A question or a list of questions to evaluate.
        :type questions: Union[str, list[str]]
        :param metrics: A list of metrics to evaluate. Defaults to all metrics.
        :type metrics: Optional[list[Union[BaseMetric, str]]]
        :param num_workers: Number of workers to use for parallel processing.
        :type num_workers: int
        :return: A dictionary containing the evaluation results.
        :rtype: dict
        """
        if "OPENAI_API_KEY" not in os.environ:
            raise ValueError("Please set the OPENAI_API_KEY environment variable with permission to use the `gpt-4` model.")

        queries, answers, contexts = [], [], []
        if isinstance(questions, list):
            # Answer every question in parallel, keeping the retrieved contexts.
            with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
                future_to_data = {executor.submit(self.query, q, citations=True): q for q in questions}
                for future in tqdm(
                    concurrent.futures.as_completed(future_to_data),
                    total=len(future_to_data),
                    desc="Getting answer and contexts for questions",
                ):
                    question = future_to_data[future]
                    queries.append(question)
                    answer, context = future.result()
                    answers.append(answer)
                    contexts.append(list(map(lambda x: x[0], context)))
        else:
            answer, context = self.query(questions, citations=True)
            queries = [questions]
            answers = [answer]
            contexts = [list(map(lambda x: x[0], context))]

        metrics = metrics or [
            EvalMetric.CONTEXT_RELEVANCY.value,
            EvalMetric.ANSWER_RELEVANCY.value,
            EvalMetric.GROUNDEDNESS.value,
        ]

        logging.info(f"Collecting data from {len(queries)} questions for evaluation...")
        dataset = []
        for q, a, c in zip(queries, answers, contexts):
            dataset.append(EvalData(question=q, answer=a, contexts=c))

        logging.info(f"Evaluating {len(dataset)} data points...")
        result = {}
        # Run each requested metric over the collected dataset in parallel.
        with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
            future_to_metric = {executor.submit(self._eval, dataset, metric): metric for metric in metrics}
            for future in tqdm(
                concurrent.futures.as_completed(future_to_metric),
                total=len(future_to_metric),
                desc="Evaluating metrics",
            ):
                metric = future_to_metric[future]
                if isinstance(metric, BaseMetric):
                    result[metric.name] = future.result()
                else:
                    result[metric] = future.result()

        if self.config.collect_metrics:
            telemetry_props = self._telemetry_props
            metrics_names = []
            for metric in metrics:
                if isinstance(metric, BaseMetric):
                    metrics_names.append(metric.name)
                else:
                    metrics_names.append(metric)
            telemetry_props["metrics"] = metrics_names
            self.telemetry.capture(event_name="evaluate", properties=telemetry_props)

        return result
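
A minimal usage sketch of the new API, under stated assumptions: the app is built with the default App() configuration, the data source URL is only an example, and the keys of the returned dict follow the EvalMetric values used above.

import os

from embedchain import App

os.environ["OPENAI_API_KEY"] = "sk-..."  # evaluation relies on OpenAI models

app = App()
app.add("https://www.forbes.com/profile/elon-musk")  # any supported data source works

# Runs the default metrics (context relevancy, answer relevancy, groundedness)
# over both questions and returns a dict of scores keyed by metric name.
result = app.evaluate(
    [
        "What is the net worth of Elon Musk?",
        "Which companies does Elon Musk run?",
    ],
    num_workers=2,
)
print(result)

A subset of metrics can be requested via the metrics argument, passing either the string values from EvalMetric or BaseMetric instances, as handled by _eval above.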