From 04daa1b20615bf49b04683262098399b7e336bf7 Mon Sep 17 00:00:00 2001
From: Deven Patel
Date: Sat, 30 Dec 2023 14:51:48 +0530
Subject: [PATCH] [Feature] Add support for GPTCache (#1065)

---
 .gitignore                        |  3 +++
 embedchain/app.py                 | 26 ++++++++++++++++++-
 embedchain/cache.py               | 40 +++++++++++++++++++++++++++++
 embedchain/config/__init__.py     |  1 +
 embedchain/config/cache_config.py | 16 ++++++++++++
 embedchain/embedchain.py          | 42 ++++++++++++++++++++++++++-----
 embedchain/embedder/base.py       | 12 +++++++++
 embedchain/utils.py               |  3 +++
 poetry.lock                       | 22 +++++++++++++---
 pyproject.toml                    |  3 ++-
 10 files changed, 157 insertions(+), 11 deletions(-)
 create mode 100644 embedchain/cache.py
 create mode 100644 embedchain/config/cache_config.py

diff --git a/.gitignore b/.gitignore
index 3165bee0..1d9a71e1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -175,3 +175,6 @@ notebooks/*.yaml
 .ipynb_checkpoints/
 
 !configs/*.yaml
+
+# cache db
+*.db
diff --git a/embedchain/app.py b/embedchain/app.py
index 646d7f6b..17d6da1b 100644
--- a/embedchain/app.py
+++ b/embedchain/app.py
@@ -9,8 +9,10 @@ from typing import Any, Dict, Optional
 import requests
 import yaml
 
+from embedchain.cache import (Config, SearchDistanceEvaluation, cache,
+                              gptcache_data_manager, gptcache_pre_function)
 from embedchain.client import Client
-from embedchain.config import AppConfig, ChunkerConfig
+from embedchain.config import AppConfig, CacheConfig, ChunkerConfig
 from embedchain.constants import SQLITE_PATH
 from embedchain.embedchain import EmbedChain
 from embedchain.embedder.base import BaseEmbedder
@@ -48,6 +50,7 @@ class App(EmbedChain):
         log_level=logging.WARN,
         auto_deploy: bool = False,
         chunker: ChunkerConfig = None,
+        cache_config: CacheConfig = None,
     ):
         """
         Initialize a new `App` instance.
@@ -88,6 +91,7 @@ class App(EmbedChain):
         self.chunker = None
         if chunker:
             self.chunker = ChunkerConfig(**chunker)
+        self.cache_config = cache_config
 
         self.config = config or AppConfig()
         self.name = self.config.name
@@ -109,6 +113,10 @@ class App(EmbedChain):
         self.llm = llm or OpenAILlm()
         self._init_db()
 
+        # If cache_config is provided, initialize the cache
+        if self.cache_config is not None:
+            self._init_cache()
+
         # Send anonymous telemetry
         self._telemetry_props = {"class": self.__class__.__name__}
         self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
@@ -147,6 +155,15 @@ class App(EmbedChain):
         self.db._initialize()
         self.db.set_collection_name(self.db.config.collection_name)
 
+    def _init_cache(self):
+        cache.init(
+            pre_embedding_func=gptcache_pre_function,
+            embedding_func=self.embedding_model.to_embeddings,
+            data_manager=gptcache_data_manager(vector_dimension=self.embedding_model.vector_dimension),
+            similarity_evaluation=SearchDistanceEvaluation(max_distance=1.0),
+            config=Config(similarity_threshold=self.cache_config.similarity_threshold),
+        )
+
     def _init_client(self):
         """
         Initialize the client.
@@ -399,6 +416,7 @@ class App(EmbedChain):
         embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
         llm_config_data = config_data.get("llm", {})
         chunker_config_data = config_data.get("chunker", {})
+        cache_config_data = config_data.get("cache", None)
 
         app_config = AppConfig(**app_config_data)
@@ -416,6 +434,11 @@
             embedding_model_provider, embedding_model_config_data.get("config", {})
         )
 
+        if cache_config_data is not None:
+            cache_config = CacheConfig(**cache_config_data)
+        else:
+            cache_config = None
+
         # Send anonymous telemetry
         event_properties = {"init_type": "config_data"}
         AnonymousTelemetry().capture(event_name="init", properties=event_properties)
@@ -428,4 +451,5 @@
             config_data=config_data,
             auto_deploy=auto_deploy,
             chunker=chunker_config_data,
+            cache_config=cache_config,
         )
diff --git a/embedchain/cache.py b/embedchain/cache.py
new file mode 100644
index 00000000..af2aa61c
--- /dev/null
+++ b/embedchain/cache.py
@@ -0,0 +1,40 @@
+import logging
+import os  # noqa: F401
+from typing import Any, Dict
+
+from gptcache import cache  # noqa: F401
+from gptcache.adapter.adapter import adapt  # noqa: F401
+from gptcache.config import Config  # noqa: F401
+from gptcache.manager import get_data_manager
+from gptcache.manager.scalar_data.base import Answer
+from gptcache.manager.scalar_data.base import DataType as CacheDataType
+from gptcache.session import Session
+from gptcache.similarity_evaluation.distance import \
+    SearchDistanceEvaluation  # noqa: F401
+
+
+def gptcache_pre_function(data: Dict[str, Any], **params: Dict[str, Any]):
+    return data["input_query"]
+
+
+def gptcache_data_manager(vector_dimension):
+    return get_data_manager(cache_base="sqlite", vector_base="chromadb", max_size=1000, eviction="LRU")
+
+
+def gptcache_data_convert(cache_data):
+    logging.info("[Cache] Cache hit, returning cache data...")
+    return cache_data
+
+
+def gptcache_update_cache_callback(llm_data, update_cache_func, *args, **kwargs):
+    logging.info("[Cache] Cache missed, updating cache...")
+    update_cache_func(Answer(llm_data, CacheDataType.STR))
+    return llm_data
+
+
+def _gptcache_session_hit_func(cur_session_id: str, cache_session_ids: list, cache_questions: list, cache_answer: str):
+    return cur_session_id in cache_session_ids
+
+
+def get_gptcache_session(session_id: str):
+    return Session(name=session_id, check_hit_func=_gptcache_session_hit_func)
diff --git a/embedchain/config/__init__.py b/embedchain/config/__init__.py
index dfc67d63..0e980c81 100644
--- a/embedchain/config/__init__.py
+++ b/embedchain/config/__init__.py
@@ -3,6 +3,7 @@
 from .add_config import AddConfig, ChunkerConfig
 from .app_config import AppConfig
 from .base_config import BaseConfig
+from .cache_config import CacheConfig
 from .embedder.base import BaseEmbedderConfig
 from .embedder.base import BaseEmbedderConfig as EmbedderConfig
 from .llm.base import BaseLlmConfig
diff --git a/embedchain/config/cache_config.py b/embedchain/config/cache_config.py
new file mode 100644
index 00000000..466f69d4
--- /dev/null
+++ b/embedchain/config/cache_config.py
@@ -0,0 +1,16 @@
+from typing import Optional
+
+from embedchain.config.base_config import BaseConfig
+from embedchain.helpers.json_serializable import register_deserializable
+
+
+@register_deserializable
+class CacheConfig(BaseConfig):
+    def __init__(
+        self,
+        similarity_threshold: Optional[float] = 0.5,
+    ):
+        if similarity_threshold < 0 or similarity_threshold > 1:
+            raise ValueError(f"similarity_threshold {similarity_threshold} should be between 0 and 1")
+
+        self.similarity_threshold = similarity_threshold
diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py
index 5de836bb..741cbd24 100644
--- a/embedchain/embedchain.py
+++ b/embedchain/embedchain.py
@@ -7,6 +7,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 from dotenv import load_dotenv
 from langchain.docstore.document import Document
 
+from embedchain.cache import (adapt, get_gptcache_session,
+                              gptcache_data_convert,
+                              gptcache_update_cache_callback)
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
 from embedchain.config.base_app_config import BaseAppConfig
@@ -52,6 +55,7 @@ class EmbedChain(JSONSerializable):
         """
         self.config = config
+        self.cache_config = None
         # Llm
         self.llm = llm
         # Database has support for config assignment for backwards compatibility
@@ -546,9 +550,22 @@
         else:
             contexts_data_for_llm_query = contexts
 
-        answer = self.llm.query(
-            input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
-        )
+        if self.cache_config is not None:
+            logging.info("Cache enabled. Checking cache...")
+            answer = adapt(
+                llm_handler=self.llm.query,
+                cache_data_convert=gptcache_data_convert,
+                update_cache_callback=gptcache_update_cache_callback,
+                session=get_gptcache_session(session_id=self.config.id),
+                input_query=input_query,
+                contexts=contexts_data_for_llm_query,
+                config=config,
+                dry_run=dry_run,
+            )
+        else:
+            answer = self.llm.query(
+                input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
+            )
 
         # Send anonymous telemetry
         self.telemetry.capture(event_name="query", properties=self._telemetry_props)
@@ -599,9 +616,22 @@
         else:
             contexts_data_for_llm_query = contexts
 
-        answer = self.llm.chat(
-            input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
-        )
+        if self.cache_config is not None:
+            logging.info("Cache enabled. Checking cache...")
+            answer = adapt(
+                llm_handler=self.llm.chat,
+                cache_data_convert=gptcache_data_convert,
+                update_cache_callback=gptcache_update_cache_callback,
+                session=get_gptcache_session(session_id=self.config.id),
+                input_query=input_query,
+                contexts=contexts_data_for_llm_query,
+                config=config,
+                dry_run=dry_run,
+            )
+        else:
+            answer = self.llm.chat(
+                input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
+            )
 
         # add conversation in memory
         self.llm.add_history(self.config.id, input_query, answer)
diff --git a/embedchain/embedder/base.py b/embedchain/embedder/base.py
index 14941b2f..fdad3b2e 100644
--- a/embedchain/embedder/base.py
+++ b/embedchain/embedder/base.py
@@ -75,3 +75,15 @@ class BaseEmbedder:
         """
         return EmbeddingFunc(embeddings.embed_documents)
+
+    def to_embeddings(self, data: str, **_):
+        """
+        Convert data to embeddings
+
+        :param data: data to convert to embeddings
+        :type data: str
+        :return: embeddings
+        :rtype: list[float]
+        """
+        embeddings = self.embedding_fn([data])
+        return embeddings[0]
diff --git a/embedchain/utils.py b/embedchain/utils.py
index 2cfe57a5..f60312d7 100644
--- a/embedchain/utils.py
+++ b/embedchain/utils.py
@@ -436,6 +436,9 @@ def validate_config(config_data):
             Optional("length_function"): str,
             Optional("min_chunk_size"): int,
         },
+        Optional("cache"): {
+            Optional("similarity_threshold"): float,
+        },
     }
 )
diff --git a/poetry.lock b/poetry.lock
index 5202837f..ff2091c0 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1849,11 +1849,11 @@ files = [
 google-auth = ">=2.14.1,<3.0.dev0"
 googleapis-common-protos = ">=1.56.2,<2.0.dev0"
 grpcio = [
-    {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
+    {version = ">=1.33.2,<2.0dev", optional = true, markers = "extra == \"grpc\""},
     {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
 ]
 grpcio-status = [
-    {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
+    {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "extra == \"grpc\""},
     {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
 ]
 protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0"
@@ -2176,6 +2176,22 @@ tqdm = "*"
 
 [package.extras]
 dev = ["black", "isort", "mkautodoc", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]", "pytest", "setuptools", "twine", "wheel"]
 
+[[package]]
+name = "gptcache"
+version = "0.1.43"
+description = "GPTCache, a powerful caching library that can be used to speed up and lower the cost of chat applications that rely on the LLM service. GPTCache works as a memcache for AIGC applications, similar to how Redis works for traditional applications."
+optional = false
+python-versions = ">=3.8.1"
+files = [
+    {file = "gptcache-0.1.43-py3-none-any.whl", hash = "sha256:9c557ec9cc14428942a0ebf1c838520dc6d2be801d67bb6964807043fc2feaf5"},
+    {file = "gptcache-0.1.43.tar.gz", hash = "sha256:cebe7ec5e32a3347bf839e933a34e67c7fcae620deaa7cb8c6d7d276c8686f1a"},
+]
+
+[package.dependencies]
+cachetools = "*"
+numpy = "*"
+requests = "*"
+
 [[package]]
 name = "greenlet"
 version = "3.0.0"
@@ -6583,7 +6599,7 @@ files = [
 ]
 
 [package.dependencies]
-greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\" or extra == \"asyncio\""}
+greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\" or extra == \"asyncio\""}
 typing-extensions = ">=4.2.0"
 
 [package.extras]
diff --git a/pyproject.toml b/pyproject.toml
index 8573521e..b26b8bcb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "embedchain"
-version = "0.1.47"
+version = "0.1.48"
 description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
 authors = [
     "Taranjeet Singh ",
@@ -101,6 +101,7 @@ posthog = "^3.0.2"
 rich = "^13.7.0"
 beautifulsoup4 = "^4.12.2"
 pypdf = "^3.11.0"
+gptcache = "^0.1.43"
 tiktoken = { version = "^0.4.0", optional = true }
 youtube-transcript-api = { version = "^0.6.1", optional = true }
 pytube = { version = "^15.0.0", optional = true }
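
Usage sketch (not part of the diff above): a minimal example of the feature this patch adds. It uses only names introduced here (App's new cache_config parameter, CacheConfig) plus embedchain's existing add/query API; the source URL and queries are illustrative placeholders, and OPENAI_API_KEY is assumed to be set for the default OpenAILlm and embedder.

    from embedchain import App
    from embedchain.config import CacheConfig

    # similarity_threshold must lie in [0, 1] per CacheConfig's validation
    # (default 0.5); App._init_cache forwards it to gptcache's Config.
    app = App(cache_config=CacheConfig(similarity_threshold=0.8))

    app.add("https://en.wikipedia.org/wiki/Alan_Turing")  # placeholder source

    # First call: cache miss; the LLM answers and the result is stored
    # via gptcache_update_cache_callback.
    print(app.query("Who was Alan Turing?"))

    # A sufficiently similar follow-up can be served from the cache
    # through gptcache_data_convert, skipping the LLM call.
    print(app.query("Who was Alan Turing, exactly?"))

The same setup works through config loading: App.from_config now reads a top-level "cache" section (validated in embedchain/utils.py) and builds the equivalent CacheConfig.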