[Feature] Add support for GPTCache (#1065)

commit 04daa1b206 (parent a7e1520d08)
Author: Deven Patel
Date:   2023-12-30 14:51:48 +05:30
Committed via GitHub
10 changed files with 157 additions and 11 deletions
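
This change lets callers opt in to GPTCache-backed response caching for `query()` and `chat()`. A minimal sketch of the new surface, assuming `App` is exported from the package root as in released versions:

```python
from embedchain import App
from embedchain.config import CacheConfig

# Opt in to response caching; sufficiently similar queries are answered
# from the cache instead of triggering a fresh LLM call.
app = App(cache_config=CacheConfig(similarity_threshold=0.8))
```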

embedchain/app.py

@@ -9,8 +9,10 @@ from typing import Any, Dict, Optional
 import requests
 import yaml
+from embedchain.cache import (Config, SearchDistanceEvaluation, cache,
+                              gptcache_data_manager, gptcache_pre_function)
 from embedchain.client import Client
-from embedchain.config import AppConfig, ChunkerConfig
+from embedchain.config import AppConfig, CacheConfig, ChunkerConfig
 from embedchain.constants import SQLITE_PATH
 from embedchain.embedchain import EmbedChain
 from embedchain.embedder.base import BaseEmbedder
@@ -48,6 +50,7 @@ class App(EmbedChain):
         log_level=logging.WARN,
         auto_deploy: bool = False,
         chunker: ChunkerConfig = None,
+        cache_config: CacheConfig = None,
     ):
         """
         Initialize a new `App` instance.
@@ -88,6 +91,7 @@ class App(EmbedChain):
         self.chunker = None
         if chunker:
             self.chunker = ChunkerConfig(**chunker)
+        self.cache_config = cache_config
 
         self.config = config or AppConfig()
         self.name = self.config.name
@@ -109,6 +113,10 @@ class App(EmbedChain):
         self.llm = llm or OpenAILlm()
         self._init_db()
 
+        # If cache_config is provided, initializing the cache ...
+        if self.cache_config is not None:
+            self._init_cache()
+
         # Send anonymous telemetry
         self._telemetry_props = {"class": self.__class__.__name__}
         self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
@@ -147,6 +155,15 @@ class App(EmbedChain):
         self.db._initialize()
         self.db.set_collection_name(self.db.config.collection_name)
 
+    def _init_cache(self):
+        cache.init(
+            pre_embedding_func=gptcache_pre_function,
+            embedding_func=self.embedding_model.to_embeddings,
+            data_manager=gptcache_data_manager(vector_dimension=self.embedding_model.vector_dimension),
+            similarity_evaluation=SearchDistanceEvaluation(max_distance=1.0),
+            config=Config(similarity_threshold=self.cache_config.similarity_threshold),
+        )
+
     def _init_client(self):
         """
         Initialize the client.
@@ -399,6 +416,7 @@ class App(EmbedChain):
         embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
         llm_config_data = config_data.get("llm", {})
         chunker_config_data = config_data.get("chunker", {})
+        cache_config_data = config_data.get("cache", None)
 
         app_config = AppConfig(**app_config_data)
@@ -416,6 +434,11 @@ class App(EmbedChain):
             embedding_model_provider, embedding_model_config_data.get("config", {})
         )
 
+        if cache_config_data is not None:
+            cache_config = CacheConfig(**cache_config_data)
+        else:
+            cache_config = None
+
         # Send anonymous telemetry
         event_properties = {"init_type": "config_data"}
         AnonymousTelemetry().capture(event_name="init", properties=event_properties)
@@ -428,4 +451,5 @@ class App(EmbedChain):
             config_data=config_data,
             auto_deploy=auto_deploy,
             chunker=chunker_config_data,
+            cache_config=cache_config,
         )
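
With the `cache` key wired through `from_config`, config-driven apps can enable caching declaratively. A sketch; whether `from_config` accepts a config dict (versus only a file path) varies by release, so treat the keyword argument as an assumption:

```python
from embedchain import App

config = {
    "app": {"config": {"name": "cached-app"}},
    "cache": {"similarity_threshold": 0.8},  # parsed into CacheConfig above
}
app = App.from_config(config=config)  # 'config=' kwarg is an assumption
```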

embedchain/cache.py Normal file

@@ -0,0 +1,40 @@
+import logging
+import os  # noqa: F401
+from typing import Any, Dict
+
+from gptcache import cache  # noqa: F401
+from gptcache.adapter.adapter import adapt  # noqa: F401
+from gptcache.config import Config  # noqa: F401
+from gptcache.manager import get_data_manager
+from gptcache.manager.scalar_data.base import Answer
+from gptcache.manager.scalar_data.base import DataType as CacheDataType
+from gptcache.session import Session
+from gptcache.similarity_evaluation.distance import \
+    SearchDistanceEvaluation  # noqa: F401
+
+
+def gptcache_pre_function(data: Dict[str, Any], **params: Dict[str, Any]):
+    return data["input_query"]
+
+
+def gptcache_data_manager(vector_dimension):
+    return get_data_manager(cache_base="sqlite", vector_base="chromadb", max_size=1000, eviction="LRU")
+
+
+def gptcache_data_convert(cache_data):
+    logging.info("[Cache] Cache hit, returning cache data...")
+    return cache_data
+
+
+def gptcache_update_cache_callback(llm_data, update_cache_func, *args, **kwargs):
+    logging.info("[Cache] Cache missed, updating cache...")
+    update_cache_func(Answer(llm_data, CacheDataType.STR))
+    return llm_data
+
+
+def _gptcache_session_hit_func(cur_session_id: str, cache_session_ids: list, cache_questions: list, cache_answer: str):
+    return cur_session_id in cache_session_ids
+
+
+def get_gptcache_session(session_id: str):
+    return Session(name=session_id, check_hit_func=_gptcache_session_hit_func)
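
Two details worth noting here: the cache key is the raw `input_query` only (retrieved contexts do not affect lookups), and `gptcache_data_manager` accepts `vector_dimension` without forwarding it to `get_data_manager`, which is configured for a SQLite scalar store plus a chromadb vector store with LRU eviction capped at 1000 entries. A small sketch of the keying behavior:

```python
from embedchain.cache import gptcache_pre_function

# Only "input_query" is used as the cache key; other fields are ignored.
key = gptcache_pre_function({"input_query": "What is embedchain?", "contexts": ["..."]})
assert key == "What is embedchain?"
```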

embedchain/config/__init__.py

@@ -3,6 +3,7 @@
 from .add_config import AddConfig, ChunkerConfig
 from .app_config import AppConfig
 from .base_config import BaseConfig
+from .cache_config import CacheConfig
 from .embedder.base import BaseEmbedderConfig
 from .embedder.base import BaseEmbedderConfig as EmbedderConfig
 from .llm.base import BaseLlmConfig

embedchain/config/cache_config.py Normal file

@@ -0,0 +1,16 @@
+from typing import Optional
+
+from embedchain.config.base_config import BaseConfig
+from embedchain.helpers.json_serializable import register_deserializable
+
+
+@register_deserializable
+class CacheConfig(BaseConfig):
+    def __init__(
+        self,
+        similarity_threshold: Optional[float] = 0.5,
+    ):
+        if similarity_threshold < 0 or similarity_threshold > 1:
+            raise ValueError(f"similarity_threshold {similarity_threshold} should be between 0 and 1")
+
+        self.similarity_threshold = similarity_threshold
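
The threshold is validated eagerly at construction time, so a bad value fails fast rather than surfacing later inside GPTCache. A short sketch of the behavior:

```python
from embedchain.config import CacheConfig

cfg = CacheConfig()  # defaults to similarity_threshold=0.5
print(cfg.similarity_threshold)

try:
    CacheConfig(similarity_threshold=1.5)  # outside [0, 1]
except ValueError as e:
    print(e)  # similarity_threshold 1.5 should be between 0 and 1
```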

embedchain/embedchain.py

@@ -7,6 +7,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 from dotenv import load_dotenv
 from langchain.docstore.document import Document
 
+from embedchain.cache import (adapt, get_gptcache_session,
+                              gptcache_data_convert,
+                              gptcache_update_cache_callback)
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
 from embedchain.config.base_app_config import BaseAppConfig
@@ -52,6 +55,7 @@ class EmbedChain(JSONSerializable):
         """
         self.config = config
+        self.cache_config = None
 
         # Llm
         self.llm = llm
         # Database has support for config assignment for backwards compatibility
@@ -546,9 +550,22 @@ class EmbedChain(JSONSerializable):
         else:
             contexts_data_for_llm_query = contexts
 
-        answer = self.llm.query(
-            input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
-        )
+        if self.cache_config is not None:
+            logging.info("Cache enabled. Checking cache...")
+            answer = adapt(
+                llm_handler=self.llm.query,
+                cache_data_convert=gptcache_data_convert,
+                update_cache_callback=gptcache_update_cache_callback,
+                session=get_gptcache_session(session_id=self.config.id),
+                input_query=input_query,
+                contexts=contexts_data_for_llm_query,
+                config=config,
+                dry_run=dry_run,
+            )
+        else:
+            answer = self.llm.query(
+                input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
+            )
 
         # Send anonymous telemetry
         self.telemetry.capture(event_name="query", properties=self._telemetry_props)
@@ -599,9 +616,22 @@ class EmbedChain(JSONSerializable):
         else:
             contexts_data_for_llm_query = contexts
 
-        answer = self.llm.chat(
-            input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
-        )
+        if self.cache_config is not None:
+            logging.info("Cache enabled. Checking cache...")
+            answer = adapt(
+                llm_handler=self.llm.chat,
+                cache_data_convert=gptcache_data_convert,
+                update_cache_callback=gptcache_update_cache_callback,
+                session=get_gptcache_session(session_id=self.config.id),
+                input_query=input_query,
+                contexts=contexts_data_for_llm_query,
+                config=config,
+                dry_run=dry_run,
+            )
+        else:
+            answer = self.llm.chat(
+                input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
+            )
 
         # add conversation in memory
         self.llm.add_history(self.config.id, input_query, answer)
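
Because the cache check wraps the existing `llm.query`/`llm.chat` handlers via GPTCache's `adapt`, the change is transparent to callers: the signatures and return values of `query()` and `chat()` are unchanged. A sketch (the source URL is illustrative only):

```python
from embedchain import App
from embedchain.config import CacheConfig

app = App(cache_config=CacheConfig(similarity_threshold=0.8))
app.add("https://example.com/article")  # hypothetical source

app.query("What does the article say?")         # cache miss: calls the LLM, stores the answer
app.query("What is the article's main point?")  # may be served from the cache if similar enough
```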

embedchain/embedder/base.py

@@ -75,3 +75,15 @@ class BaseEmbedder:
         """
         return EmbeddingFunc(embeddings.embed_documents)
+
+    def to_embeddings(self, data: str, **_):
+        """
+        Convert data to embeddings
+
+        :param data: data to convert to embeddings
+        :type data: str
+        :return: embeddings
+        :rtype: list[float]
+        """
+        embeddings = self.embedding_fn([data])
+        return embeddings[0]
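
`to_embeddings` adapts the batch-oriented `embedding_fn` (list of strings in, list of vectors out) to the single-input signature GPTCache expects for `embedding_func`. A self-contained sketch; the lambda is a stand-in for a real embedding model, assigned directly for illustration:

```python
from embedchain.embedder.base import BaseEmbedder

embedder = BaseEmbedder()
# Stand-in embedding function: batch of strings in, batch of vectors out.
embedder.embedding_fn = lambda texts: [[float(len(t))] * 3 for t in texts]

vector = embedder.to_embeddings("hello world")
print(vector)  # [11.0, 11.0, 11.0] -- a single vector, not a batch
```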

embedchain/utils.py

@@ -436,6 +436,9 @@ def validate_config(config_data):
                 Optional("length_function"): str,
                 Optional("min_chunk_size"): int,
             },
+            Optional("cache"): {
+                Optional("similarity_threshold"): float,
+            },
         }
     )
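
Finally, the config validator accepts the new optional `cache` section. A sketch, assuming the module path from this diff and that top-level schema keys are optional:

```python
from embedchain.utils import validate_config  # module path assumed

# A config containing only the new section; the float type is enforced.
validate_config({"cache": {"similarity_threshold": 0.8}})
```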