diff --git a/docs/api-reference/advanced/configuration.mdx b/docs/api-reference/advanced/configuration.mdx index 11eab691..b631e799 100644 --- a/docs/api-reference/advanced/configuration.mdx +++ b/docs/api-reference/advanced/configuration.mdx @@ -56,6 +56,14 @@ chunker: chunk_overlap: 100 length_function: 'len' min_chunk_size: 0 + +cache: + similarity_evaluation: + strategy: distance + max_distance: 1.0 + config: + similarity_threshold: 0.8 + auto_flush: 50 ``` ```json config.json @@ -98,7 +106,17 @@ chunker: "chunk_overlap": 100, "length_function": "len", "min_chunk_size": 0 - } + }, + "cache": { + "similarity_evaluation": { + "strategy": "distance", + "max_distance": 1.0 + }, + "config": { + "similarity_threshold": 0.8, + "auto_flush": 50 + } + } } ``` @@ -148,7 +166,17 @@ config = { 'chunk_overlap': 100, 'length_function': 'len', 'min_chunk_size': 0 - } + }, + 'cache': { + 'similarity_evaluation': { + 'strategy': 'distance', + 'max_distance': 1.0, + }, + 'config': { + 'similarity_threshold': 0.8, + 'auto_flush': 50, + }, + }, } ``` @@ -192,7 +220,17 @@ Alright, let's dive into what each key means in the yaml config above: - `chunk_overlap` (Integer): The amount of overlap between each chunk of text. - `length_function` (String): The function used to calculate the length of each chunk of text. In this case, it's set to 'len'. You can also use any function import directly as a string here. - `min_chunk_size` (Integer): The minimum size of each chunk of text that is sent to the language model. Must be less than `chunk_size`, and greater than `chunk_overlap`. - +6. `cache` Section: (Optional) + - `similarity_evaluation` (Optional): The config for similarity evaluation strategy. If not provided, the default `distance` based similarity evaluation strategy is used. + - `strategy` (String): The strategy to use for similarity evaluation. Currently, only `distance` and `exact` based similarity evaluation is supported. Defaults to `distance`.
+ - `max_distance` (Float): The bound of maximum distance. Defaults to `1.0`. + - `positive` (Boolean): If the larger distance indicates more similar of two entities, set it `True`, otherwise `False`. Defaults to `False`. + - `config` (Optional): The config for initializing the cache. If not provided, sensible default values are used as mentioned below. + - `similarity_threshold` (Float): The threshold for similarity evaluation. Defaults to `0.8`. + - `auto_flush` (Integer): The number of queries after which the cache is flushed. Defaults to `20`. + + If you provide a cache section, the app will automatically configure and use a cache to store the results of the language model. This is useful if you want to speed up the response time and save inference cost of your app. + If you have questions about the configuration above, please feel free to reach out to us using one of the following methods: \ No newline at end of file diff --git a/embedchain/app.py b/embedchain/app.py index 18042e2f..288e5650 100644 --- a/embedchain/app.py +++ b/embedchain/app.py @@ -9,7 +9,8 @@ from typing import Any, Dict, Optional import requests import yaml -from embedchain.cache import (Config, SearchDistanceEvaluation, cache, +from embedchain.cache import (Config, ExactMatchEvaluation, + SearchDistanceEvaluation, cache, gptcache_data_manager, gptcache_pre_function) from embedchain.client import Client from embedchain.config import AppConfig, CacheConfig, ChunkerConfig @@ -156,12 +157,20 @@ class App(EmbedChain): self.db.set_collection_name(self.db.config.collection_name) def _init_cache(self): + if self.cache_config.similarity_eval_config.strategy == "exact": + similarity_eval_func = ExactMatchEvaluation() + else: + similarity_eval_func = SearchDistanceEvaluation( + max_distance=self.cache_config.similarity_eval_config.max_distance, + positive=self.cache_config.similarity_eval_config.positive, + ) + cache.init( pre_embedding_func=gptcache_pre_function, 
embedding_func=self.embedding_model.to_embeddings, data_manager=gptcache_data_manager(vector_dimension=self.embedding_model.vector_dimension), - similarity_evaluation=SearchDistanceEvaluation(max_distance=1.0), - config=Config(similarity_threshold=self.cache_config.similarity_threshold), + similarity_evaluation=similarity_eval_func, + config=Config(**self.cache_config.init_config.as_dict()), ) def _init_client(self): @@ -428,7 +437,7 @@ class App(EmbedChain): ) if cache_config_data is not None: - cache_config = CacheConfig(**cache_config_data) + cache_config = CacheConfig.from_config(cache_config_data) else: cache_config = None diff --git a/embedchain/cache.py b/embedchain/cache.py index af2aa61c..ba1675ed 100644 --- a/embedchain/cache.py +++ b/embedchain/cache.py @@ -11,6 +11,8 @@ from gptcache.manager.scalar_data.base import DataType as CacheDataType from gptcache.session import Session from gptcache.similarity_evaluation.distance import \ SearchDistanceEvaluation # noqa: F401 +from gptcache.similarity_evaluation.exact_match import \ + ExactMatchEvaluation # noqa: F401 def gptcache_pre_function(data: Dict[str, Any], **params: Dict[str, Any]): diff --git a/embedchain/config/cache_config.py b/embedchain/config/cache_config.py index 466f69d4..1121efd1 100644 --- a/embedchain/config/cache_config.py +++ b/embedchain/config/cache_config.py @@ -1,16 +1,93 @@ -from typing import Optional +from typing import Any, Dict, Optional from embedchain.config.base_config import BaseConfig from embedchain.helpers.json_serializable import register_deserializable @register_deserializable -class CacheConfig(BaseConfig): +class CacheSimilarityEvalConfig(BaseConfig): + """ + This is the evaluator to compare two embeddings according to their distance computed in embedding retrieval stage. + In the retrieval stage, `search_result` is the distance used for approximate nearest neighbor search and have been + put into `cache_dict`. 
`max_distance` is used to bound this distance to make it between [0-`max_distance`]. + `positive` is used to indicate this distance is directly proportional to the similarity of two entities. + If `positive` is set `False`, `max_distance` will be used to subtract this distance to get the final score. + + :param max_distance: the bound of maximum distance. + :type max_distance: float + :param positive: if a larger distance indicates greater similarity between two entities, it is True; otherwise it is False. + :type positive: bool + """ + + def __init__( + self, - similarity_threshold: Optional[float] = 0.5, + strategy: Optional[str] = "distance", + max_distance: Optional[float] = 1.0, + positive: Optional[bool] = False, + ): + self.strategy = strategy + self.max_distance = max_distance + self.positive = positive + + def from_config(config: Optional[Dict[str, Any]]): + if config is None: + return CacheSimilarityEvalConfig() + else: + return CacheSimilarityEvalConfig( + strategy=config.get("strategy", "distance"), + max_distance=config.get("max_distance", 1.0), + positive=config.get("positive", False), + ) + + +@register_deserializable +class CacheInitConfig(BaseConfig): + """ + This is a cache init config. Used to initialize a cache. + + :param similarity_threshold: a threshold ranging from 0 to 1 to filter search results with similarity score higher \ + than the threshold. When it is 0, there are no hits. When it is 1, all search results will be returned as hits.
+ :type similarity_threshold: float + :param auto_flush: the cache is automatically flushed after this many pieces of data are added; defaults to 20 + :type auto_flush: int + """ + + def __init__( + self, + similarity_threshold: Optional[float] = 0.8, + auto_flush: Optional[int] = 20, ): if similarity_threshold < 0 or similarity_threshold > 1: raise ValueError(f"similarity_threshold {similarity_threshold} should be between 0 and 1") self.similarity_threshold = similarity_threshold + self.auto_flush = auto_flush + + def from_config(config: Optional[Dict[str, Any]]): + if config is None: + return CacheInitConfig() + else: + return CacheInitConfig( + similarity_threshold=config.get("similarity_threshold", 0.8), + auto_flush=config.get("auto_flush", 20), + ) + + +@register_deserializable +class CacheConfig(BaseConfig): + def __init__( + self, + similarity_eval_config: Optional[CacheSimilarityEvalConfig] = CacheSimilarityEvalConfig(), + init_config: Optional[CacheInitConfig] = CacheInitConfig(), + ): + self.similarity_eval_config = similarity_eval_config + self.init_config = init_config + + def from_config(config: Optional[Dict[str, Any]]): + if config is None: + return CacheConfig() + else: + return CacheConfig( + similarity_eval_config=CacheSimilarityEvalConfig.from_config(config.get("similarity_evaluation", {})), + init_config=CacheInitConfig.from_config(config.get("config", {})), + ) diff --git a/embedchain/llm/openai.py b/embedchain/llm/openai.py index 5c38774b..78c2f136 100644 --- a/embedchain/llm/openai.py +++ b/embedchain/llm/openai.py @@ -42,6 +42,7 @@ class OpenAILlm(BaseLlm): chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks, api_key=api_key) else: chat = ChatOpenAI(**kwargs, api_key=api_key) + if self.functions is not None: from langchain.chains.openai_functions import \ create_openai_fn_runnable diff --git a/embedchain/utils.py b/embedchain/utils.py index 60cb32f9..4169558e 100644 --- a/embedchain/utils.py +++ b/embedchain/utils.py @@
-441,7 +441,15 @@ def validate_config(config_data): Optional("min_chunk_size"): int, }, Optional("cache"): { - Optional("similarity_threshold"): float, + Optional("similarity_evaluation"): { + Optional("strategy"): Or("distance", "exact"), + Optional("max_distance"): float, + Optional("positive"): bool, + }, + Optional("config"): { + Optional("similarity_threshold"): float, + Optional("auto_flush"): int, + }, }, } )