[Updates] Update GPTCache configuration/docs (#1098)

Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
Deven Patel
2024-01-02 17:32:48 +05:30
committed by GitHub
parent c62663f2e4
commit 295cd3fac6
6 changed files with 146 additions and 11 deletions

View File

@@ -9,7 +9,8 @@ from typing import Any, Dict, Optional
import requests
import yaml
from embedchain.cache import (Config, SearchDistanceEvaluation, cache,
from embedchain.cache import (Config, ExactMatchEvaluation,
SearchDistanceEvaluation, cache,
gptcache_data_manager, gptcache_pre_function)
from embedchain.client import Client
from embedchain.config import AppConfig, CacheConfig, ChunkerConfig
@@ -156,12 +157,20 @@ class App(EmbedChain):
self.db.set_collection_name(self.db.config.collection_name)
def _init_cache(self):
    """
    Initialise the shared GPTCache ``cache`` object from ``self.cache_config``.

    The similarity evaluator is selected by
    ``self.cache_config.similarity_eval_config.strategy``:

    * ``"exact"`` -> ``ExactMatchEvaluation()``
    * anything else -> ``SearchDistanceEvaluation`` built from the configured
      ``max_distance`` and ``positive`` values.

    The GPTCache ``Config`` is built from the keyword arguments returned by
    ``self.cache_config.init_config.as_dict()``.
    """
    eval_config = self.cache_config.similarity_eval_config
    if eval_config.strategy == "exact":
        similarity_eval_func = ExactMatchEvaluation()
    else:
        similarity_eval_func = SearchDistanceEvaluation(
            max_distance=eval_config.max_distance,
            positive=eval_config.positive,
        )

    # `similarity_evaluation` and `config` must each be passed exactly once.
    # The earlier hard-coded pair (SearchDistanceEvaluation(max_distance=1.0)
    # and Config(similarity_threshold=self.cache_config.similarity_threshold))
    # repeated both keywords, ignored the strategy chosen above, and read an
    # attribute that CacheConfig does not set.
    cache.init(
        pre_embedding_func=gptcache_pre_function,
        embedding_func=self.embedding_model.to_embeddings,
        data_manager=gptcache_data_manager(vector_dimension=self.embedding_model.vector_dimension),
        similarity_evaluation=similarity_eval_func,
        config=Config(**self.cache_config.init_config.as_dict()),
    )
def _init_client(self):
@@ -428,7 +437,7 @@ class App(EmbedChain):
)
if cache_config_data is not None:
cache_config = CacheConfig(**cache_config_data)
cache_config = CacheConfig.from_config(cache_config_data)
else:
cache_config = None

View File

@@ -11,6 +11,8 @@ from gptcache.manager.scalar_data.base import DataType as CacheDataType
from gptcache.session import Session
from gptcache.similarity_evaluation.distance import \
SearchDistanceEvaluation # noqa: F401
from gptcache.similarity_evaluation.exact_match import \
ExactMatchEvaluation # noqa: F401
def gptcache_pre_function(data: Dict[str, Any], **params: Dict[str, Any]):

View File

@@ -1,16 +1,93 @@
from typing import Optional
from typing import Any, Dict, Optional
from embedchain.config.base_config import BaseConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class CacheConfig(BaseConfig):
class CacheSimilarityEvalConfig(BaseConfig):
    """
    Configuration for the cache similarity evaluator.

    With the distance strategy, two embeddings are compared according to the
    distance computed in the embedding retrieval stage. In the retrieval stage,
    `search_result` is the distance used for approximate nearest neighbor
    search and has been put into `cache_dict`. `max_distance` is used to bound
    this distance so that it lies in [0, `max_distance`]. `positive` indicates
    whether this distance is directly proportional to the similarity of two
    entities. If `positive` is `False`, the distance is subtracted from
    `max_distance` to get the final score.

    :param similarity_threshold: accepted but not stored by this class
        (NOTE(review): looks like a leftover; the threshold is configured on
        `CacheInitConfig` - confirm before relying on it here).
    :type similarity_threshold: float
    :param strategy: evaluation strategy name, `"distance"` (default) or `"exact"`.
    :type strategy: str
    :param max_distance: the bound of maximum distance.
    :type max_distance: float
    :param positive: `True` if a larger distance means the two entities are
        more similar, `False` otherwise.
    :type positive: bool
    """

    def __init__(
        self,
        similarity_threshold: Optional[float] = 0.5,
        strategy: Optional[str] = "distance",
        max_distance: Optional[float] = 1.0,
        positive: Optional[bool] = False,
    ):
        # `similarity_threshold` is intentionally not assigned: the original
        # constructor never stored it either.
        self.strategy = strategy
        self.max_distance = max_distance
        self.positive = positive

    @staticmethod
    def from_config(config: Optional[Dict[str, Any]]):
        """
        Build a `CacheSimilarityEvalConfig` from a plain mapping.

        Declared as a static method so it can be called on the class or on an
        instance; without the decorator an instance call would bind the
        instance to `config`.

        :param config: mapping with optional `strategy`, `max_distance` and
            `positive` keys, or `None` to get an all-defaults instance.
        :return: the populated `CacheSimilarityEvalConfig`.
        """
        if config is None:
            return CacheSimilarityEvalConfig()

        # Missing keys fall back to the same defaults as the constructor.
        return CacheSimilarityEvalConfig(
            strategy=config.get("strategy", "distance"),
            max_distance=config.get("max_distance", 1.0),
            positive=config.get("positive", False),
        )
@register_deserializable
class CacheInitConfig(BaseConfig):
    """
    This is a cache init config. Used to initialize a cache.

    :param similarity_threshold: a threshold ranged from 0 to 1 to filter search results with similarity score higher \
    than the threshold. When it is 0, there are no hits. When it is 1, all search results will be returned as hits.
    `None` is treated as "use the default" (0.8).
    :type similarity_threshold: float
    :param auto_flush: it will be automatically flushed every time xx pieces of data are added, default to 20
    :type auto_flush: int
    :raises ValueError: if `similarity_threshold` is outside the range [0, 1].
    """

    def __init__(
        self,
        similarity_threshold: Optional[float] = 0.8,
        auto_flush: Optional[int] = 20,
    ):
        # The parameter is annotated Optional, so map None to the default
        # instead of failing on the range comparison with a TypeError.
        if similarity_threshold is None:
            similarity_threshold = 0.8
        if similarity_threshold < 0 or similarity_threshold > 1:
            raise ValueError(f"similarity_threshold {similarity_threshold} should be between 0 and 1")

        self.similarity_threshold = similarity_threshold
        self.auto_flush = auto_flush

    @staticmethod
    def from_config(config: Optional[Dict[str, Any]]):
        """
        Build a `CacheInitConfig` from a plain mapping.

        Declared as a static method so it can be called on the class or on an
        instance; without the decorator an instance call would bind the
        instance to `config`.

        :param config: mapping with optional `similarity_threshold` and
            `auto_flush` keys, or `None` to get an all-defaults instance.
        :return: the populated `CacheInitConfig`.
        :raises ValueError: if the supplied `similarity_threshold` is outside [0, 1].
        """
        if config is None:
            return CacheInitConfig()

        # Missing keys fall back to the same defaults as the constructor.
        return CacheInitConfig(
            similarity_threshold=config.get("similarity_threshold", 0.8),
            auto_flush=config.get("auto_flush", 20),
        )
@register_deserializable
class CacheConfig(BaseConfig):
    """
    Top-level cache configuration.

    Groups the similarity-evaluator settings (`similarity_eval_config`) and
    the cache initialisation settings (`init_config`).

    :param similarity_eval_config: evaluator settings; a fresh
        `CacheSimilarityEvalConfig()` is created when omitted or `None`.
    :type similarity_eval_config: CacheSimilarityEvalConfig
    :param init_config: cache init settings; a fresh `CacheInitConfig()` is
        created when omitted or `None`.
    :type init_config: CacheInitConfig
    """

    def __init__(
        self,
        similarity_eval_config: Optional[CacheSimilarityEvalConfig] = None,
        init_config: Optional[CacheInitConfig] = None,
    ):
        # Defaults are created per instance. Using config objects as default
        # argument values would evaluate them once at definition time and
        # share the same mutable objects between every CacheConfig.
        if similarity_eval_config is None:
            similarity_eval_config = CacheSimilarityEvalConfig()
        if init_config is None:
            init_config = CacheInitConfig()

        self.similarity_eval_config = similarity_eval_config
        self.init_config = init_config

    @staticmethod
    def from_config(config: Optional[Dict[str, Any]]):
        """
        Build a `CacheConfig` from the `cache` section of a config mapping.

        Recognised keys:

        * `similarity_evaluation` - passed to `CacheSimilarityEvalConfig.from_config`.
        * `config` - passed to `CacheInitConfig.from_config`. This is the key
          the config validation schema accepts. The older `init_config` key
          is still honoured when `config` is absent.

        :param config: the mapping described above, or `None` for all defaults.
        :return: the populated `CacheConfig`.
        """
        if config is None:
            return CacheConfig()

        # Prefer the schema-validated "config" key; fall back to the legacy
        # "init_config" key so existing callers keep working.
        if "config" in config:
            init_section = config["config"]
        else:
            init_section = config.get("init_config", {})

        return CacheConfig(
            similarity_eval_config=CacheSimilarityEvalConfig.from_config(config.get("similarity_evaluation", {})),
            init_config=CacheInitConfig.from_config(init_section),
        )

View File

@@ -42,6 +42,7 @@ class OpenAILlm(BaseLlm):
chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks, api_key=api_key)
else:
chat = ChatOpenAI(**kwargs, api_key=api_key)
if self.functions is not None:
from langchain.chains.openai_functions import \
create_openai_fn_runnable

View File

@@ -441,7 +441,15 @@ def validate_config(config_data):
Optional("min_chunk_size"): int,
},
Optional("cache"): {
Optional("similarity_threshold"): float,
Optional("similarity_evaluation"): {
Optional("strategy"): Or("distance", "exact"),
Optional("max_distance"): float,
Optional("positive"): bool,
},
Optional("config"): {
Optional("similarity_threshold"): float,
Optional("auto_flush"): int,
},
},
}
)