[Feature] Add support for GPTCache (#1065)

This commit is contained in:
Deven Patel
2023-12-30 14:51:48 +05:30
committed by GitHub
parent a7e1520d08
commit 04daa1b206
10 changed files with 157 additions and 11 deletions

View File

@@ -9,8 +9,10 @@ from typing import Any, Dict, Optional
import requests
import yaml
from embedchain.cache import (Config, SearchDistanceEvaluation, cache,
gptcache_data_manager, gptcache_pre_function)
from embedchain.client import Client
from embedchain.config import AppConfig, ChunkerConfig
from embedchain.config import AppConfig, CacheConfig, ChunkerConfig
from embedchain.constants import SQLITE_PATH
from embedchain.embedchain import EmbedChain
from embedchain.embedder.base import BaseEmbedder
@@ -48,6 +50,7 @@ class App(EmbedChain):
log_level=logging.WARN,
auto_deploy: bool = False,
chunker: ChunkerConfig = None,
cache_config: CacheConfig = None,
):
"""
Initialize a new `App` instance.
@@ -88,6 +91,7 @@ class App(EmbedChain):
self.chunker = None
if chunker:
self.chunker = ChunkerConfig(**chunker)
self.cache_config = cache_config
self.config = config or AppConfig()
self.name = self.config.name
@@ -109,6 +113,10 @@ class App(EmbedChain):
self.llm = llm or OpenAILlm()
self._init_db()
# If cache_config is provided, initializing the cache ...
if self.cache_config is not None:
self._init_cache()
# Send anonymous telemetry
self._telemetry_props = {"class": self.__class__.__name__}
self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
@@ -147,6 +155,15 @@ class App(EmbedChain):
self.db._initialize()
self.db.set_collection_name(self.db.config.collection_name)
def _init_cache(self):
cache.init(
pre_embedding_func=gptcache_pre_function,
embedding_func=self.embedding_model.to_embeddings,
data_manager=gptcache_data_manager(vector_dimension=self.embedding_model.vector_dimension),
similarity_evaluation=SearchDistanceEvaluation(max_distance=1.0),
config=Config(similarity_threshold=self.cache_config.similarity_threshold),
)
def _init_client(self):
"""
Initialize the client.
@@ -399,6 +416,7 @@ class App(EmbedChain):
embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
llm_config_data = config_data.get("llm", {})
chunker_config_data = config_data.get("chunker", {})
cache_config_data = config_data.get("cache", None)
app_config = AppConfig(**app_config_data)
@@ -416,6 +434,11 @@ class App(EmbedChain):
embedding_model_provider, embedding_model_config_data.get("config", {})
)
if cache_config_data is not None:
cache_config = CacheConfig(**cache_config_data)
else:
cache_config = None
# Send anonymous telemetry
event_properties = {"init_type": "config_data"}
AnonymousTelemetry().capture(event_name="init", properties=event_properties)
@@ -428,4 +451,5 @@ class App(EmbedChain):
config_data=config_data,
auto_deploy=auto_deploy,
chunker=chunker_config_data,
cache_config=cache_config,
)