[Feature] Add support for GPTCache (#1065)
@@ -9,8 +9,10 @@ from typing import Any, Dict, Optional
 import requests
 import yaml
 
+from embedchain.cache import (Config, SearchDistanceEvaluation, cache,
+                              gptcache_data_manager, gptcache_pre_function)
 from embedchain.client import Client
-from embedchain.config import AppConfig, ChunkerConfig
+from embedchain.config import AppConfig, CacheConfig, ChunkerConfig
 from embedchain.constants import SQLITE_PATH
 from embedchain.embedchain import EmbedChain
 from embedchain.embedder.base import BaseEmbedder
@@ -48,6 +50,7 @@ class App(EmbedChain):
         log_level=logging.WARN,
         auto_deploy: bool = False,
         chunker: ChunkerConfig = None,
+        cache_config: CacheConfig = None,
     ):
         """
         Initialize a new `App` instance.
@@ -88,6 +91,7 @@ class App(EmbedChain):
         self.chunker = None
         if chunker:
             self.chunker = ChunkerConfig(**chunker)
+        self.cache_config = cache_config
 
         self.config = config or AppConfig()
         self.name = self.config.name
@@ -109,6 +113,10 @@ class App(EmbedChain):
         self.llm = llm or OpenAILlm()
         self._init_db()
 
+        # If cache_config is provided, initializing the cache ...
+        if self.cache_config is not None:
+            self._init_cache()
+
         # Send anonymous telemetry
         self._telemetry_props = {"class": self.__class__.__name__}
         self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
@@ -147,6 +155,15 @@ class App(EmbedChain):
         self.db._initialize()
         self.db.set_collection_name(self.db.config.collection_name)
 
+    def _init_cache(self):
+        cache.init(
+            pre_embedding_func=gptcache_pre_function,
+            embedding_func=self.embedding_model.to_embeddings,
+            data_manager=gptcache_data_manager(vector_dimension=self.embedding_model.vector_dimension),
+            similarity_evaluation=SearchDistanceEvaluation(max_distance=1.0),
+            config=Config(similarity_threshold=self.cache_config.similarity_threshold),
+        )
+
     def _init_client(self):
         """
         Initialize the client.
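For orientation, here is a minimal sketch of how the new `cache_config` parameter could be exercised directly. Only `similarity_threshold` is confirmed as a `CacheConfig` field by this diff; the example value and the surrounding setup are illustrative assumptions, not part of the change.

# Hypothetical usage sketch (not part of this commit). It assumes the default
# LLM/embedder fallbacks that App uses when no components are passed in.
from embedchain import App
from embedchain.config import AppConfig, CacheConfig

app = App(
    config=AppConfig(),
    # similarity_threshold is forwarded to gptcache's Config in _init_cache above;
    # 0.8 is an illustrative value, not a documented default.
    cache_config=CacheConfig(similarity_threshold=0.8),
)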
@@ -399,6 +416,7 @@ class App(EmbedChain):
         embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
         llm_config_data = config_data.get("llm", {})
         chunker_config_data = config_data.get("chunker", {})
+        cache_config_data = config_data.get("cache", None)
 
         app_config = AppConfig(**app_config_data)
 
@@ -416,6 +434,11 @@ class App(EmbedChain):
             embedding_model_provider, embedding_model_config_data.get("config", {})
         )
 
+        if cache_config_data is not None:
+            cache_config = CacheConfig(**cache_config_data)
+        else:
+            cache_config = None
+
         # Send anonymous telemetry
         event_properties = {"init_type": "config_data"}
         AnonymousTelemetry().capture(event_name="init", properties=event_properties)
@@ -428,4 +451,5 @@ class App(EmbedChain):
             config_data=config_data,
             auto_deploy=auto_deploy,
             chunker=chunker_config_data,
+            cache_config=cache_config,
         )
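The config-loading changes above read an optional "cache" section and pass it straight to CacheConfig. A hedged sketch of enabling the cache that way, assuming from_config accepts a plain config mapping in the same way the existing "llm" and "chunker" sections are consumed; only the similarity_threshold key is confirmed by this diff.

# Hypothetical config-driven sketch (not part of this commit). The "cache"
# section maps onto CacheConfig(**cache_config_data) as added above.
from embedchain import App

config = {
    "cache": {
        "similarity_threshold": 0.8,  # illustrative value
    },
}
app = App.from_config(config=config)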