[Feature] Add support for GPTCache (#1065)
.gitignore (vendored, 3 lines added)
@@ -175,3 +175,6 @@ notebooks/*.yaml
 .ipynb_checkpoints/

 !configs/*.yaml
+
+# cache db
+*.db
embedchain/app.py
@@ -9,8 +9,10 @@ from typing import Any, Dict, Optional
 import requests
 import yaml

+from embedchain.cache import (Config, SearchDistanceEvaluation, cache,
+                              gptcache_data_manager, gptcache_pre_function)
 from embedchain.client import Client
-from embedchain.config import AppConfig, ChunkerConfig
+from embedchain.config import AppConfig, CacheConfig, ChunkerConfig
 from embedchain.constants import SQLITE_PATH
 from embedchain.embedchain import EmbedChain
 from embedchain.embedder.base import BaseEmbedder
@@ -48,6 +50,7 @@ class App(EmbedChain):
         log_level=logging.WARN,
         auto_deploy: bool = False,
         chunker: ChunkerConfig = None,
+        cache_config: CacheConfig = None,
     ):
         """
         Initialize a new `App` instance.
@@ -88,6 +91,7 @@ class App(EmbedChain):
         self.chunker = None
         if chunker:
             self.chunker = ChunkerConfig(**chunker)
+        self.cache_config = cache_config

         self.config = config or AppConfig()
         self.name = self.config.name
@@ -109,6 +113,10 @@ class App(EmbedChain):
         self.llm = llm or OpenAILlm()
         self._init_db()

+        # If cache_config is provided, initializing the cache ...
+        if self.cache_config is not None:
+            self._init_cache()
+
         # Send anonymous telemetry
         self._telemetry_props = {"class": self.__class__.__name__}
         self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
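With this wiring, enabling the cache is just a matter of passing a `CacheConfig` when constructing the app. A minimal usage sketch (the source URL and threshold are illustrative, not part of this commit):

```python
# Hypothetical usage: enable GPTCache-backed answer caching on an App.
# Assumes OPENAI_API_KEY is set; the URL and threshold are only examples.
from embedchain import App
from embedchain.config import CacheConfig

app = App(cache_config=CacheConfig(similarity_threshold=0.8))
app.add("https://www.forbes.com/profile/elon-musk")

app.query("Who is Elon Musk?")   # first call goes to the LLM and is cached
app.query("Who is Elon Musk??")  # a near-identical query can be answered from cache
```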
@@ -147,6 +155,15 @@ class App(EmbedChain):
         self.db._initialize()
         self.db.set_collection_name(self.db.config.collection_name)

+    def _init_cache(self):
+        cache.init(
+            pre_embedding_func=gptcache_pre_function,
+            embedding_func=self.embedding_model.to_embeddings,
+            data_manager=gptcache_data_manager(vector_dimension=self.embedding_model.vector_dimension),
+            similarity_evaluation=SearchDistanceEvaluation(max_distance=1.0),
+            config=Config(similarity_threshold=self.cache_config.similarity_threshold),
+        )
+
     def _init_client(self):
         """
         Initialize the client.
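For context on the two gptcache knobs used here: `SearchDistanceEvaluation(max_distance=1.0)` turns the vector-store distance of the closest cached question into a similarity score, and `Config(similarity_threshold=...)` decides whether that score counts as a hit. A rough sketch of the gate, assuming gptcache's `{"search_result": (distance, data)}` evaluation input shape (an assumption about gptcache internals, not something defined in this commit):

```python
# Rough sketch of the cache-hit gate; the input shape passed to evaluation()
# is an assumption about gptcache internals, used here only for illustration.
from gptcache.config import Config
from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation

evaluation = SearchDistanceEvaluation(max_distance=1.0)
config = Config(similarity_threshold=0.5)

# Smaller distances mean closer questions; with max_distance=1.0 a distance
# of 0.2 should map to a score of roughly 0.8, which clears the threshold.
score = evaluation.evaluation({}, {"search_result": (0.2, None)})
print(score, score >= config.similarity_threshold)  # expected: 0.8 True
```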
@@ -399,6 +416,7 @@ class App(EmbedChain):
         embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
         llm_config_data = config_data.get("llm", {})
         chunker_config_data = config_data.get("chunker", {})
+        cache_config_data = config_data.get("cache", None)

         app_config = AppConfig(**app_config_data)

@@ -416,6 +434,11 @@ class App(EmbedChain):
             embedding_model_provider, embedding_model_config_data.get("config", {})
         )

+        if cache_config_data is not None:
+            cache_config = CacheConfig(**cache_config_data)
+        else:
+            cache_config = None
+
         # Send anonymous telemetry
         event_properties = {"init_type": "config_data"}
         AnonymousTelemetry().capture(event_name="init", properties=event_properties)
@@ -428,4 +451,5 @@ class App(EmbedChain):
             config_data=config_data,
             auto_deploy=auto_deploy,
             chunker=chunker_config_data,
+            cache_config=cache_config,
         )
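The same `cache` block can come from a YAML config. A small sketch of what the classmethod above does with it (only the `cache` section is introduced by this commit; the rest of a real config file is omitted):

```python
# Sketch: pull the "cache" section out of a YAML config, as the loader above does.
import yaml

from embedchain.config import CacheConfig

config_yaml = """
cache:
  similarity_threshold: 0.8
"""

config_data = yaml.safe_load(config_yaml)
cache_config_data = config_data.get("cache", None)
cache_config = CacheConfig(**cache_config_data) if cache_config_data is not None else None
print(cache_config.similarity_threshold)  # 0.8
```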
embedchain/cache.py (new file, 40 lines)
@@ -0,0 +1,40 @@
+import logging
+import os  # noqa: F401
+from typing import Any, Dict
+
+from gptcache import cache  # noqa: F401
+from gptcache.adapter.adapter import adapt  # noqa: F401
+from gptcache.config import Config  # noqa: F401
+from gptcache.manager import get_data_manager
+from gptcache.manager.scalar_data.base import Answer
+from gptcache.manager.scalar_data.base import DataType as CacheDataType
+from gptcache.session import Session
+from gptcache.similarity_evaluation.distance import \
+    SearchDistanceEvaluation  # noqa: F401
+
+
+def gptcache_pre_function(data: Dict[str, Any], **params: Dict[str, Any]):
+    return data["input_query"]
+
+
+def gptcache_data_manager(vector_dimension):
+    return get_data_manager(cache_base="sqlite", vector_base="chromadb", max_size=1000, eviction="LRU")
+
+
+def gptcache_data_convert(cache_data):
+    logging.info("[Cache] Cache hit, returning cache data...")
+    return cache_data
+
+
+def gptcache_update_cache_callback(llm_data, update_cache_func, *args, **kwargs):
+    logging.info("[Cache] Cache missed, updating cache...")
+    update_cache_func(Answer(llm_data, CacheDataType.STR))
+    return llm_data
+
+
+def _gptcache_session_hit_func(cur_session_id: str, cache_session_ids: list, cache_questions: list, cache_answer: str):
+    return cur_session_id in cache_session_ids
+
+
+def get_gptcache_session(session_id: str):
+    return Session(name=session_id, check_hit_func=_gptcache_session_hit_func)
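Most of these helpers are small and pure, so their contracts are easy to check in isolation (the strings below are made up):

```python
# Quick illustration of the helper contracts defined above.
from embedchain.cache import (gptcache_pre_function,
                              gptcache_update_cache_callback)

# The pre-function picks out the text that gets embedded for the cache lookup.
assert gptcache_pre_function({"input_query": "What is Embedchain?"}) == "What is Embedchain?"

# On a cache miss, the callback stores the fresh LLM answer and passes it through.
stored = []
answer = gptcache_update_cache_callback("Embedchain is a RAG framework.", stored.append)
assert answer == "Embedchain is a RAG framework."
assert len(stored) == 1  # one Answer(llm_data, CacheDataType.STR) record was written
```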
embedchain/config/__init__.py
@@ -3,6 +3,7 @@
 from .add_config import AddConfig, ChunkerConfig
 from .app_config import AppConfig
 from .base_config import BaseConfig
+from .cache_config import CacheConfig
 from .embedder.base import BaseEmbedderConfig
 from .embedder.base import BaseEmbedderConfig as EmbedderConfig
 from .llm.base import BaseLlmConfig
embedchain/config/cache_config.py (new file, 16 lines)
@@ -0,0 +1,16 @@
+from typing import Optional
+
+from embedchain.config.base_config import BaseConfig
+from embedchain.helpers.json_serializable import register_deserializable
+
+
+@register_deserializable
+class CacheConfig(BaseConfig):
+    def __init__(
+        self,
+        similarity_threshold: Optional[float] = 0.5,
+    ):
+        if similarity_threshold < 0 or similarity_threshold > 1:
+            raise ValueError(f"similarity_threshold {similarity_threshold} should be between 0 and 1")
+
+        self.similarity_threshold = similarity_threshold
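A quick look at the validation behaviour:

```python
from embedchain.config import CacheConfig

CacheConfig(similarity_threshold=0.8)  # ok
CacheConfig()                          # ok, defaults to 0.5

try:
    CacheConfig(similarity_threshold=1.5)
except ValueError as e:
    print(e)  # similarity_threshold 1.5 should be between 0 and 1
```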
embedchain/embedchain.py
@@ -7,6 +7,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 from dotenv import load_dotenv
 from langchain.docstore.document import Document

+from embedchain.cache import (adapt, get_gptcache_session,
+                              gptcache_data_convert,
+                              gptcache_update_cache_callback)
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
 from embedchain.config.base_app_config import BaseAppConfig
@@ -52,6 +55,7 @@ class EmbedChain(JSONSerializable):
         """

         self.config = config
+        self.cache_config = None
         # Llm
         self.llm = llm
         # Database has support for config assignment for backwards compatibility
@@ -546,9 +550,22 @@ class EmbedChain(JSONSerializable):
         else:
             contexts_data_for_llm_query = contexts

-        answer = self.llm.query(
-            input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
-        )
+        if self.cache_config is not None:
+            logging.info("Cache enabled. Checking cache...")
+            answer = adapt(
+                llm_handler=self.llm.query,
+                cache_data_convert=gptcache_data_convert,
+                update_cache_callback=gptcache_update_cache_callback,
+                session=get_gptcache_session(session_id=self.config.id),
+                input_query=input_query,
+                contexts=contexts_data_for_llm_query,
+                config=config,
+                dry_run=dry_run,
+            )
+        else:
+            answer = self.llm.query(
+                input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
+            )

         # Send anonymous telemetry
         self.telemetry.capture(event_name="query", properties=self._telemetry_props)
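For readers unfamiliar with gptcache, `adapt` is what decides between the two branches: it embeds `input_query`, looks for a close-enough cached question, and either hands the stored answer to `gptcache_data_convert` or calls the wrapped handler and persists the result via `gptcache_update_cache_callback`. A simplified stand-in (not gptcache's actual implementation) to show the contract:

```python
# Simplified stand-in for gptcache's adapt(); the real version in
# gptcache.adapter.adapter also does embedding, vector search and sessions.
def adapt_sketch(llm_handler, cache_data_convert, update_cache_callback,
                 cached_answer=None, **llm_kwargs):
    if cached_answer is not None:
        # Cache hit: hand the stored answer to the convert hook and return it.
        return cache_data_convert(cached_answer)
    # Cache miss: call the LLM, then let the callback persist the fresh answer.
    return update_cache_callback(llm_handler(**llm_kwargs), lambda answer: None)
```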
@@ -599,9 +616,22 @@ class EmbedChain(JSONSerializable):
         else:
             contexts_data_for_llm_query = contexts

-        answer = self.llm.chat(
-            input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
-        )
+        if self.cache_config is not None:
+            logging.info("Cache enabled. Checking cache...")
+            answer = adapt(
+                llm_handler=self.llm.chat,
+                cache_data_convert=gptcache_data_convert,
+                update_cache_callback=gptcache_update_cache_callback,
+                session=get_gptcache_session(session_id=self.config.id),
+                input_query=input_query,
+                contexts=contexts_data_for_llm_query,
+                config=config,
+                dry_run=dry_run,
+            )
+        else:
+            answer = self.llm.chat(
+                input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
+            )

         # add conversation in memory
         self.llm.add_history(self.config.id, input_query, answer)
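Both `query` and `chat` scope lookups with `get_gptcache_session(session_id=self.config.id)`, so cached answers are only reused within the same app id. The session hit function from `embedchain/cache.py` makes that explicit (ids below are made up):

```python
# Illustration of the session-scoping check with made-up ids.
from embedchain.cache import _gptcache_session_hit_func

cached_session_ids = ["app-1234", "app-5678"]
assert _gptcache_session_hit_func("app-1234", cached_session_ids, [], "") is True
assert _gptcache_session_hit_func("app-9999", cached_session_ids, [], "") is False
```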
embedchain/embedder/base.py
@@ -75,3 +75,15 @@ class BaseEmbedder:
         """

         return EmbeddingFunc(embeddings.embed_documents)
+
+    def to_embeddings(self, data: str, **_):
+        """
+        Convert data to embeddings
+
+        :param data: data to convert to embeddings
+        :type data: str
+        :return: embeddings
+        :rtype: list[float]
+        """
+        embeddings = self.embedding_fn([data])
+        return embeddings[0]
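`to_embeddings` exists so the embedder can be handed to gptcache as its `embedding_func`, which embeds one string at a time. A stand-alone sketch of the contract with a fake embedding function (the three-dimensional vectors are obviously not real embeddings):

```python
# Stand-alone sketch of the to_embeddings contract with a fake embedding_fn.
def fake_embedding_fn(texts):
    # Real embedding functions return one vector per input text.
    return [[0.1, 0.2, 0.3] for _ in texts]

class FakeEmbedder:
    embedding_fn = staticmethod(fake_embedding_fn)

    def to_embeddings(self, data: str, **_):
        # Same shape as BaseEmbedder.to_embeddings: wrap the single string,
        # embed it, and unwrap the single resulting vector.
        return self.embedding_fn([data])[0]

assert FakeEmbedder().to_embeddings("hello world") == [0.1, 0.2, 0.3]
```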
embedchain/utils.py
@@ -436,6 +436,9 @@ def validate_config(config_data):
                 Optional("length_function"): str,
                 Optional("min_chunk_size"): int,
             },
+            Optional("cache"): {
+                Optional("similarity_threshold"): float,
+            },
         }
     )

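The new fragment only allows an optional `similarity_threshold` float under `cache`. A minimal check with the same `schema` library, using just this fragment rather than the full config schema:

```python
# Minimal reproduction of the added schema fragment (not the full config schema).
from schema import Optional, Schema

cache_schema = Schema({
    Optional("cache"): {
        Optional("similarity_threshold"): float,
    },
})

cache_schema.validate({"cache": {"similarity_threshold": 0.75}})  # passes
cache_schema.validate({})                                         # passes; cache is optional

try:
    cache_schema.validate({"cache": {"similarity_threshold": "high"}})
except Exception as exc:
    print(type(exc).__name__)  # SchemaError, since "high" is not a float
```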
poetry.lock (generated, 22 changed lines)
@@ -1849,11 +1849,11 @@ files = [
 google-auth = ">=2.14.1,<3.0.dev0"
 googleapis-common-protos = ">=1.56.2,<2.0.dev0"
 grpcio = [
-    {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
+    {version = ">=1.33.2,<2.0dev", optional = true, markers = "extra == \"grpc\""},
     {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
 ]
 grpcio-status = [
-    {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
+    {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "extra == \"grpc\""},
     {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
 ]
 protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0"
@@ -2176,6 +2176,22 @@ tqdm = "*"
 [package.extras]
 dev = ["black", "isort", "mkautodoc", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]", "pytest", "setuptools", "twine", "wheel"]

+[[package]]
+name = "gptcache"
+version = "0.1.43"
+description = "GPTCache, a powerful caching library that can be used to speed up and lower the cost of chat applications that rely on the LLM service. GPTCache works as a memcache for AIGC applications, similar to how Redis works for traditional applications."
+optional = false
+python-versions = ">=3.8.1"
+files = [
+    {file = "gptcache-0.1.43-py3-none-any.whl", hash = "sha256:9c557ec9cc14428942a0ebf1c838520dc6d2be801d67bb6964807043fc2feaf5"},
+    {file = "gptcache-0.1.43.tar.gz", hash = "sha256:cebe7ec5e32a3347bf839e933a34e67c7fcae620deaa7cb8c6d7d276c8686f1a"},
+]
+
+[package.dependencies]
+cachetools = "*"
+numpy = "*"
+requests = "*"
+
 [[package]]
 name = "greenlet"
 version = "3.0.0"
@@ -6583,7 +6599,7 @@ files = [
 ]

 [package.dependencies]
-greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\" or extra == \"asyncio\""}
+greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\" or extra == \"asyncio\""}
 typing-extensions = ">=4.2.0"

 [package.extras]
pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "embedchain"
-version = "0.1.47"
+version = "0.1.48"
 description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
 authors = [
     "Taranjeet Singh <taranjeet@embedchain.ai>",
@@ -101,6 +101,7 @@ posthog = "^3.0.2"
 rich = "^13.7.0"
 beautifulsoup4 = "^4.12.2"
 pypdf = "^3.11.0"
+gptcache = "^0.1.43"
 tiktoken = { version = "^0.4.0", optional = true }
 youtube-transcript-api = { version = "^0.6.1", optional = true }
 pytube = { version = "^15.0.0", optional = true }