[Feature] Add support for GPTCache (#1065)

Author:    Deven Patel
Date:      2023-12-30 14:51:48 +05:30
Committed: GitHub
Parent:    a7e1520d08
Commit:    04daa1b206

10 changed files with 157 additions and 11 deletions

.gitignore

@@ -175,3 +175,6 @@ notebooks/*.yaml
 .ipynb_checkpoints/
 !configs/*.yaml
+# cache db
+*.db

embedchain/app.py

@@ -9,8 +9,10 @@ from typing import Any, Dict, Optional
 import requests
 import yaml

+from embedchain.cache import (Config, SearchDistanceEvaluation, cache,
+                              gptcache_data_manager, gptcache_pre_function)
 from embedchain.client import Client
-from embedchain.config import AppConfig, ChunkerConfig
+from embedchain.config import AppConfig, CacheConfig, ChunkerConfig
 from embedchain.constants import SQLITE_PATH
 from embedchain.embedchain import EmbedChain
 from embedchain.embedder.base import BaseEmbedder
@@ -48,6 +50,7 @@ class App(EmbedChain):
         log_level=logging.WARN,
         auto_deploy: bool = False,
         chunker: ChunkerConfig = None,
+        cache_config: CacheConfig = None,
     ):
         """
         Initialize a new `App` instance.
@@ -88,6 +91,7 @@ class App(EmbedChain):
         self.chunker = None
         if chunker:
             self.chunker = ChunkerConfig(**chunker)
+        self.cache_config = cache_config

         self.config = config or AppConfig()
         self.name = self.config.name
@@ -109,6 +113,10 @@ class App(EmbedChain):
         self.llm = llm or OpenAILlm()
         self._init_db()

+        # If cache_config is provided, initializing the cache ...
+        if self.cache_config is not None:
+            self._init_cache()
+
         # Send anonymous telemetry
         self._telemetry_props = {"class": self.__class__.__name__}
         self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
@@ -147,6 +155,15 @@ class App(EmbedChain):
         self.db._initialize()
         self.db.set_collection_name(self.db.config.collection_name)

+    def _init_cache(self):
+        cache.init(
+            pre_embedding_func=gptcache_pre_function,
+            embedding_func=self.embedding_model.to_embeddings,
+            data_manager=gptcache_data_manager(vector_dimension=self.embedding_model.vector_dimension),
+            similarity_evaluation=SearchDistanceEvaluation(max_distance=1.0),
+            config=Config(similarity_threshold=self.cache_config.similarity_threshold),
+        )
+
     def _init_client(self):
         """
         Initialize the client.
@@ -399,6 +416,7 @@ class App(EmbedChain):
         embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
         llm_config_data = config_data.get("llm", {})
         chunker_config_data = config_data.get("chunker", {})
+        cache_config_data = config_data.get("cache", None)

         app_config = AppConfig(**app_config_data)
@@ -416,6 +434,11 @@ class App(EmbedChain):
             embedding_model_provider, embedding_model_config_data.get("config", {})
         )

+        if cache_config_data is not None:
+            cache_config = CacheConfig(**cache_config_data)
+        else:
+            cache_config = None
+
         # Send anonymous telemetry
         event_properties = {"init_type": "config_data"}
         AnonymousTelemetry().capture(event_name="init", properties=event_properties)
@@ -428,4 +451,5 @@ class App(EmbedChain):
             config_data=config_data,
             auto_deploy=auto_deploy,
             chunker=chunker_config_data,
+            cache_config=cache_config,
         )
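Taken together, the App changes let callers opt into GPTCache programmatically. A minimal sketch of how the new keyword might be used (the data URL and queries are illustrative, not from this commit):

    from embedchain import App
    from embedchain.config import CacheConfig

    # Enable the semantic cache; similarity_threshold defaults to 0.5.
    app = App(cache_config=CacheConfig(similarity_threshold=0.8))
    app.add("https://www.forbes.com/profile/elon-musk")

    # The first query goes to the LLM and seeds the cache; a sufficiently
    # similar rephrasing should then be answered from the cache.
    print(app.query("What is the net worth of Elon Musk?"))
    print(app.query("Tell me Elon Musk's net worth."))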

embedchain/cache.py (new file)

@@ -0,0 +1,40 @@
+import logging
+import os  # noqa: F401
+from typing import Any, Dict
+
+from gptcache import cache  # noqa: F401
+from gptcache.adapter.adapter import adapt  # noqa: F401
+from gptcache.config import Config  # noqa: F401
+from gptcache.manager import get_data_manager
+from gptcache.manager.scalar_data.base import Answer
+from gptcache.manager.scalar_data.base import DataType as CacheDataType
+from gptcache.session import Session
+from gptcache.similarity_evaluation.distance import \
+    SearchDistanceEvaluation  # noqa: F401
+
+
+def gptcache_pre_function(data: Dict[str, Any], **params: Dict[str, Any]):
+    return data["input_query"]
+
+
+def gptcache_data_manager(vector_dimension):
+    return get_data_manager(cache_base="sqlite", vector_base="chromadb", max_size=1000, eviction="LRU")
+
+
+def gptcache_data_convert(cache_data):
+    logging.info("[Cache] Cache hit, returning cache data...")
+    return cache_data
+
+
+def gptcache_update_cache_callback(llm_data, update_cache_func, *args, **kwargs):
+    logging.info("[Cache] Cache missed, updating cache...")
+    update_cache_func(Answer(llm_data, CacheDataType.STR))
+    return llm_data
+
+
+def _gptcache_session_hit_func(cur_session_id: str, cache_session_ids: list, cache_questions: list, cache_answer: str):
+    return cur_session_id in cache_session_ids
+
+
+def get_gptcache_session(session_id: str):
+    return Session(name=session_id, check_hit_func=_gptcache_session_hit_func)
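The new module is thin glue over GPTCache: `gptcache_pre_function` picks the string to embed out of the keyword arguments passed to `adapt()`, `gptcache_data_manager` keeps scalar data in SQLite and vectors in ChromaDB (note that `vector_dimension` is accepted but not forwarded in this version), and the session helpers scope cache hits to one app's session id. A small sketch of those contracts (values are illustrative):

    from embedchain.cache import _gptcache_session_hit_func, gptcache_pre_function

    # The pre-embedding function receives adapt()'s kwargs and returns the
    # string that GPTCache embeds and matches against the cache.
    data = {"input_query": "What is GPTCache?", "contexts": ["..."]}
    assert gptcache_pre_function(data) == "What is GPTCache?"

    # Session hits require the cached entry to carry the same session id,
    # so separate apps never read each other's cached answers.
    assert _gptcache_session_hit_func("app-1", ["app-1", "app-2"], [], "")
    assert not _gptcache_session_hit_func("app-3", ["app-1", "app-2"], [], "")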

embedchain/config/__init__.py

@@ -3,6 +3,7 @@
 from .add_config import AddConfig, ChunkerConfig
 from .app_config import AppConfig
 from .base_config import BaseConfig
+from .cache_config import CacheConfig
 from .embedder.base import BaseEmbedderConfig
 from .embedder.base import BaseEmbedderConfig as EmbedderConfig
 from .llm.base import BaseLlmConfig

embedchain/config/cache_config.py (new file)

@@ -0,0 +1,16 @@
+from typing import Optional
+
+from embedchain.config.base_config import BaseConfig
+from embedchain.helpers.json_serializable import register_deserializable
+
+
+@register_deserializable
+class CacheConfig(BaseConfig):
+    def __init__(
+        self,
+        similarity_threshold: Optional[float] = 0.5,
+    ):
+        if similarity_threshold < 0 or similarity_threshold > 1:
+            raise ValueError(f"similarity_threshold {similarity_threshold} should be between 0 and 1")
+        self.similarity_threshold = similarity_threshold
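The single knob is validated eagerly, so a bad threshold fails at construction time rather than surfacing later inside GPTCache; for example:

    from embedchain.config import CacheConfig

    config = CacheConfig(similarity_threshold=0.8)
    print(config.similarity_threshold)  # 0.8

    # Out-of-range values are rejected immediately:
    try:
        CacheConfig(similarity_threshold=1.5)
    except ValueError as e:
        print(e)  # similarity_threshold 1.5 should be between 0 and 1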

embedchain/embedchain.py

@@ -7,6 +7,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 from dotenv import load_dotenv
 from langchain.docstore.document import Document

+from embedchain.cache import (adapt, get_gptcache_session,
+                              gptcache_data_convert,
+                              gptcache_update_cache_callback)
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
 from embedchain.config.base_app_config import BaseAppConfig
@@ -52,6 +55,7 @@ class EmbedChain(JSONSerializable):
         """
         self.config = config
+        self.cache_config = None
         # Llm
         self.llm = llm
         # Database has support for config assignment for backwards compatibility
@@ -546,9 +550,22 @@ class EmbedChain(JSONSerializable):
         else:
             contexts_data_for_llm_query = contexts

-        answer = self.llm.query(
-            input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
-        )
+        if self.cache_config is not None:
+            logging.info("Cache enabled. Checking cache...")
+            answer = adapt(
+                llm_handler=self.llm.query,
+                cache_data_convert=gptcache_data_convert,
+                update_cache_callback=gptcache_update_cache_callback,
+                session=get_gptcache_session(session_id=self.config.id),
+                input_query=input_query,
+                contexts=contexts_data_for_llm_query,
+                config=config,
+                dry_run=dry_run,
+            )
+        else:
+            answer = self.llm.query(
+                input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
+            )

         # Send anonymous telemetry
         self.telemetry.capture(event_name="query", properties=self._telemetry_props)
@@ -599,9 +616,22 @@ class EmbedChain(JSONSerializable):
         else:
             contexts_data_for_llm_query = contexts

-        answer = self.llm.chat(
-            input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
-        )
+        if self.cache_config is not None:
+            logging.info("Cache enabled. Checking cache...")
+            answer = adapt(
+                llm_handler=self.llm.chat,
+                cache_data_convert=gptcache_data_convert,
+                update_cache_callback=gptcache_update_cache_callback,
+                session=get_gptcache_session(session_id=self.config.id),
+                input_query=input_query,
+                contexts=contexts_data_for_llm_query,
+                config=config,
+                dry_run=dry_run,
+            )
+        else:
+            answer = self.llm.chat(
+                input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
+            )

         # add conversation in memory
         self.llm.add_history(self.config.id, input_query, answer)
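In both methods, GPTCache's `adapt()` sits between embedchain and the LLM: it embeds the query (via the pre-function), looks for a near match under the current session, and either converts the cached answer or calls the wrapped handler and writes the fresh answer back. A simplified stand-in for that control flow, with hypothetical `lookup_similar`/`store` callables in place of GPTCache's internals:

    def adapt_sketch(llm_handler, cache_data_convert, update_cache_callback,
                     lookup_similar, store, **kwargs):
        """Rough sketch of gptcache.adapter.adapter.adapt(); not the real code."""
        query = kwargs["input_query"]          # what gptcache_pre_function extracts
        cached = lookup_similar(query)         # hypothetical similarity search
        if cached is not None:
            return cache_data_convert(cached)  # cache hit: reuse the stored answer
        answer = llm_handler(**kwargs)         # cache miss: call the real LLM
        # The callback stores the fresh answer, then returns it to the caller.
        return update_cache_callback(answer, lambda a: store(query, a))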

embedchain/embedder/base.py

@@ -75,3 +75,15 @@ class BaseEmbedder:
         """
         return EmbeddingFunc(embeddings.embed_documents)
+
+    def to_embeddings(self, data: str, **_):
+        """
+        Convert data to embeddings
+
+        :param data: data to convert to embeddings
+        :type data: str
+        :return: embeddings
+        :rtype: list[float]
+        """
+        embeddings = self.embedding_fn([data])
+        return embeddings[0]
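`to_embeddings` adapts the embedder to GPTCache's `embedding_func` contract: one string in, one vector out (the `**_` swallows extra keyword arguments GPTCache may pass). A sketch with a stubbed embedding function, assuming `BaseEmbedder.set_embedding_fn` wires up `self.embedding_fn` as elsewhere in embedchain:

    from embedchain.embedder.base import BaseEmbedder

    embedder = BaseEmbedder()
    # Stub batch embedder: real embedders return one vector per input text.
    embedder.set_embedding_fn(lambda texts: [[0.1, 0.2, 0.3] for _ in texts])

    print(embedder.to_embeddings("hello world"))  # [0.1, 0.2, 0.3], one vector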

embedchain/utils.py

@@ -436,6 +436,9 @@ def validate_config(config_data):
             Optional("length_function"): str,
             Optional("min_chunk_size"): int,
         },
+        Optional("cache"): {
+            Optional("similarity_threshold"): float,
+        },
     }
 )
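With this schema addition, the cache can also be enabled declaratively. A sketch assuming `App.from_config` accepts an inline `config` dict in this release (the equivalent YAML would be a top-level `cache:` block with `similarity_threshold: 0.8`):

    from embedchain import App

    app = App.from_config(config={
        "cache": {"similarity_threshold": 0.8},
    })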

poetry.lock (generated)

@@ -1849,11 +1849,11 @@ files = [
 google-auth = ">=2.14.1,<3.0.dev0"
 googleapis-common-protos = ">=1.56.2,<2.0.dev0"
 grpcio = [
-    {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
+    {version = ">=1.33.2,<2.0dev", optional = true, markers = "extra == \"grpc\""},
     {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
 ]
 grpcio-status = [
-    {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
+    {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "extra == \"grpc\""},
     {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
 ]
 protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0"
@@ -2176,6 +2176,22 @@ tqdm = "*"
 [package.extras]
 dev = ["black", "isort", "mkautodoc", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]", "pytest", "setuptools", "twine", "wheel"]

+[[package]]
+name = "gptcache"
+version = "0.1.43"
+description = "GPTCache, a powerful caching library that can be used to speed up and lower the cost of chat applications that rely on the LLM service. GPTCache works as a memcache for AIGC applications, similar to how Redis works for traditional applications."
+optional = false
+python-versions = ">=3.8.1"
+files = [
+    {file = "gptcache-0.1.43-py3-none-any.whl", hash = "sha256:9c557ec9cc14428942a0ebf1c838520dc6d2be801d67bb6964807043fc2feaf5"},
+    {file = "gptcache-0.1.43.tar.gz", hash = "sha256:cebe7ec5e32a3347bf839e933a34e67c7fcae620deaa7cb8c6d7d276c8686f1a"},
+]
+
+[package.dependencies]
+cachetools = "*"
+numpy = "*"
+requests = "*"
+
 [[package]]
 name = "greenlet"
 version = "3.0.0"
@@ -6583,7 +6599,7 @@ files = [
 ]

 [package.dependencies]
-greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\" or extra == \"asyncio\""}
+greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\" or extra == \"asyncio\""}
 typing-extensions = ">=4.2.0"

 [package.extras]
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "embedchain"
-version = "0.1.47"
+version = "0.1.48"
 description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
 authors = [
     "Taranjeet Singh <taranjeet@embedchain.ai>",
@@ -101,6 +101,7 @@ posthog = "^3.0.2"
 rich = "^13.7.0"
 beautifulsoup4 = "^4.12.2"
 pypdf = "^3.11.0"
+gptcache = "^0.1.43"
 tiktoken = { version = "^0.4.0", optional = true }
 youtube-transcript-api = { version = "^0.6.1", optional = true }
 pytube = { version = "^15.0.0", optional = true }