[Feature] Add support for GPTCache (#1065)
.gitignore (vendored)
@@ -175,3 +175,6 @@ notebooks/*.yaml
 .ipynb_checkpoints/
 
 !configs/*.yaml
+
+# cache db
+*.db
embedchain/app.py
@@ -9,8 +9,10 @@ from typing import Any, Dict, Optional
 import requests
 import yaml
 
+from embedchain.cache import (Config, SearchDistanceEvaluation, cache,
+                              gptcache_data_manager, gptcache_pre_function)
 from embedchain.client import Client
-from embedchain.config import AppConfig, ChunkerConfig
+from embedchain.config import AppConfig, CacheConfig, ChunkerConfig
 from embedchain.constants import SQLITE_PATH
 from embedchain.embedchain import EmbedChain
 from embedchain.embedder.base import BaseEmbedder
@@ -48,6 +50,7 @@ class App(EmbedChain):
         log_level=logging.WARN,
         auto_deploy: bool = False,
         chunker: ChunkerConfig = None,
+        cache_config: CacheConfig = None,
     ):
         """
         Initialize a new `App` instance.
@@ -88,6 +91,7 @@ class App(EmbedChain):
         self.chunker = None
         if chunker:
             self.chunker = ChunkerConfig(**chunker)
+        self.cache_config = cache_config
 
         self.config = config or AppConfig()
         self.name = self.config.name
@@ -109,6 +113,10 @@ class App(EmbedChain):
         self.llm = llm or OpenAILlm()
         self._init_db()
 
+        # If cache_config is provided, initialize the cache.
+        if self.cache_config is not None:
+            self._init_cache()
+
         # Send anonymous telemetry
         self._telemetry_props = {"class": self.__class__.__name__}
         self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
@@ -147,6 +155,15 @@ class App(EmbedChain):
         self.db._initialize()
         self.db.set_collection_name(self.db.config.collection_name)
 
+    def _init_cache(self):
+        cache.init(
+            pre_embedding_func=gptcache_pre_function,
+            embedding_func=self.embedding_model.to_embeddings,
+            data_manager=gptcache_data_manager(vector_dimension=self.embedding_model.vector_dimension),
+            similarity_evaluation=SearchDistanceEvaluation(max_distance=1.0),
+            config=Config(similarity_threshold=self.cache_config.similarity_threshold),
+        )
+
     def _init_client(self):
         """
         Initialize the client.
@@ -399,6 +416,7 @@ class App(EmbedChain):
         embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
         llm_config_data = config_data.get("llm", {})
         chunker_config_data = config_data.get("chunker", {})
+        cache_config_data = config_data.get("cache", None)
 
         app_config = AppConfig(**app_config_data)
 
@@ -416,6 +434,11 @@ class App(EmbedChain):
             embedding_model_provider, embedding_model_config_data.get("config", {})
         )
 
+        if cache_config_data is not None:
+            cache_config = CacheConfig(**cache_config_data)
+        else:
+            cache_config = None
+
         # Send anonymous telemetry
         event_properties = {"init_type": "config_data"}
         AnonymousTelemetry().capture(event_name="init", properties=event_properties)
@@ -428,4 +451,5 @@ class App(EmbedChain):
             config_data=config_data,
             auto_deploy=auto_deploy,
             chunker=chunker_config_data,
+            cache_config=cache_config,
         )
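Usage: callers opt into the cache by passing a CacheConfig when constructing the app. A minimal sketch (the threshold value is illustrative; LLM and embedder defaults apply as usual):

from embedchain import App
from embedchain.config import CacheConfig

# Queries whose embeddings fall within the similarity threshold of a
# previously answered query are served from the cache instead of the LLM.
app = App(cache_config=CacheConfig(similarity_threshold=0.8))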
embedchain/cache.py (new file)
@@ -0,0 +1,40 @@
+import logging
+import os  # noqa: F401
+from typing import Any, Dict
+
+from gptcache import cache  # noqa: F401
+from gptcache.adapter.adapter import adapt  # noqa: F401
+from gptcache.config import Config  # noqa: F401
+from gptcache.manager import get_data_manager
+from gptcache.manager.scalar_data.base import Answer
+from gptcache.manager.scalar_data.base import DataType as CacheDataType
+from gptcache.session import Session
+from gptcache.similarity_evaluation.distance import \
+    SearchDistanceEvaluation  # noqa: F401
+
+
+def gptcache_pre_function(data: Dict[str, Any], **params: Dict[str, Any]):
+    return data["input_query"]
+
+
+def gptcache_data_manager(vector_dimension):
+    return get_data_manager(cache_base="sqlite", vector_base="chromadb", max_size=1000, eviction="LRU")
+
+
+def gptcache_data_convert(cache_data):
+    logging.info("[Cache] Cache hit, returning cache data...")
+    return cache_data
+
+
+def gptcache_update_cache_callback(llm_data, update_cache_func, *args, **kwargs):
+    logging.info("[Cache] Cache missed, updating cache...")
+    update_cache_func(Answer(llm_data, CacheDataType.STR))
+    return llm_data
+
+
+def _gptcache_session_hit_func(cur_session_id: str, cache_session_ids: list, cache_questions: list, cache_answer: str):
+    return cur_session_id in cache_session_ids
+
+
+def get_gptcache_session(session_id: str):
+    return Session(name=session_id, check_hit_func=_gptcache_session_hit_func)
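These helpers map embedchain's query path onto GPTCache's hooks: gptcache_pre_function picks the cache key out of the adapter kwargs, and the session hit function scopes cache hits to a single app id. A sketch of the pre-function contract (the sample dict is illustrative of what adapt() forwards):

from embedchain.cache import gptcache_pre_function

# adapt() passes its keyword arguments through as `data`; only
# input_query is embedded and used as the cache lookup key.
data = {"input_query": "What is Embedchain?", "contexts": ["..."], "dry_run": False}
assert gptcache_pre_function(data) == "What is Embedchain?"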
embedchain/config/__init__.py
@@ -3,6 +3,7 @@
 from .add_config import AddConfig, ChunkerConfig
 from .app_config import AppConfig
 from .base_config import BaseConfig
+from .cache_config import CacheConfig
 from .embedder.base import BaseEmbedderConfig
 from .embedder.base import BaseEmbedderConfig as EmbedderConfig
 from .llm.base import BaseLlmConfig
embedchain/config/cache_config.py (new file)
@@ -0,0 +1,16 @@
+from typing import Optional
+
+from embedchain.config.base_config import BaseConfig
+from embedchain.helpers.json_serializable import register_deserializable
+
+
+@register_deserializable
+class CacheConfig(BaseConfig):
+    def __init__(
+        self,
+        similarity_threshold: Optional[float] = 0.5,
+    ):
+        if similarity_threshold < 0 or similarity_threshold > 1:
+            raise ValueError(f"similarity_threshold {similarity_threshold} should be between 0 and 1")
+
+        self.similarity_threshold = similarity_threshold
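CacheConfig validates its only knob up front, so an out-of-range threshold fails at construction time rather than at query time:

from embedchain.config import CacheConfig

CacheConfig(similarity_threshold=0.8)  # accepted
CacheConfig(similarity_threshold=1.5)  # raises ValueError: should be between 0 and 1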
embedchain/embedchain.py
@@ -7,6 +7,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 from dotenv import load_dotenv
 from langchain.docstore.document import Document
 
+from embedchain.cache import (adapt, get_gptcache_session,
+                              gptcache_data_convert,
+                              gptcache_update_cache_callback)
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
 from embedchain.config.base_app_config import BaseAppConfig
@@ -52,6 +55,7 @@ class EmbedChain(JSONSerializable):
         """
 
         self.config = config
+        self.cache_config = None
         # Llm
         self.llm = llm
         # Database has support for config assignment for backwards compatibility
@@ -546,9 +550,22 @@ class EmbedChain(JSONSerializable):
         else:
             contexts_data_for_llm_query = contexts
 
-        answer = self.llm.query(
-            input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
-        )
+        if self.cache_config is not None:
+            logging.info("Cache enabled. Checking cache...")
+            answer = adapt(
+                llm_handler=self.llm.query,
+                cache_data_convert=gptcache_data_convert,
+                update_cache_callback=gptcache_update_cache_callback,
+                session=get_gptcache_session(session_id=self.config.id),
+                input_query=input_query,
+                contexts=contexts_data_for_llm_query,
+                config=config,
+                dry_run=dry_run,
+            )
+        else:
+            answer = self.llm.query(
+                input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
+            )
 
         # Send anonymous telemetry
         self.telemetry.capture(event_name="query", properties=self._telemetry_props)
@@ -599,9 +616,22 @@ class EmbedChain(JSONSerializable):
         else:
             contexts_data_for_llm_query = contexts
 
-        answer = self.llm.chat(
-            input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
-        )
+        if self.cache_config is not None:
+            logging.info("Cache enabled. Checking cache...")
+            answer = adapt(
+                llm_handler=self.llm.chat,
+                cache_data_convert=gptcache_data_convert,
+                update_cache_callback=gptcache_update_cache_callback,
+                session=get_gptcache_session(session_id=self.config.id),
+                input_query=input_query,
+                contexts=contexts_data_for_llm_query,
+                config=config,
+                dry_run=dry_run,
+            )
+        else:
+            answer = self.llm.chat(
+                input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
+            )
 
         # add conversation in memory
         self.llm.add_history(self.config.id, input_query, answer)
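With the cache wired into both query() and chat(), a repeated or near-duplicate question can skip the LLM call entirely: embed the question, look it up, and either return the stored answer (hit) or call the LLM and write the answer back (miss). An end-to-end sketch, assuming an OpenAI key is configured and using an illustrative source URL:

from embedchain import App
from embedchain.config import CacheConfig

app = App(cache_config=CacheConfig(similarity_threshold=0.8))
app.add("https://docs.embedchain.ai")  # illustrative source
app.query("What is Embedchain?")       # cache miss: calls the LLM, stores the answer
app.query("What's Embedchain?")        # similar enough: served from the cache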
embedchain/embedder/base.py
@@ -75,3 +75,15 @@ class BaseEmbedder:
         """
 
         return EmbeddingFunc(embeddings.embed_documents)
+
+    def to_embeddings(self, data: str, **_):
+        """
+        Convert data to embeddings
+
+        :param data: data to convert to embeddings
+        :type data: str
+        :return: embeddings
+        :rtype: list[float]
+        """
+        embeddings = self.embedding_fn([data])
+        return embeddings[0]
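to_embeddings bridges embedchain's batch embedding_fn and GPTCache's embedding_func, which expects one string in and one vector out. Roughly (app as in the sketch above):

# GPTCache calls embedding_func(query_text); this wraps the batch API.
vector = app.embedding_model.to_embeddings("What is Embedchain?")
assert len(vector) == app.embedding_model.vector_dimension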
embedchain/utils.py
@@ -436,6 +436,9 @@ def validate_config(config_data):
             Optional("length_function"): str,
             Optional("min_chunk_size"): int,
         },
+        Optional("cache"): {
+            Optional("similarity_threshold"): float,
+        },
     }
 )
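Since validate_config now accepts a cache block, the same setting can also come from a YAML config file, e.g. (surrounding keys are illustrative, not part of this diff):

app:
  config:
    name: demo-app
cache:
  similarity_threshold: 0.8

App.from_config then builds a CacheConfig from the cache block and passes it through to the constructor, as shown in the app.py hunks above.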
poetry.lock (generated)
@@ -1849,11 +1849,11 @@ files = [
 google-auth = ">=2.14.1,<3.0.dev0"
 googleapis-common-protos = ">=1.56.2,<2.0.dev0"
 grpcio = [
-    {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
+    {version = ">=1.33.2,<2.0dev", optional = true, markers = "extra == \"grpc\""},
     {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
 ]
 grpcio-status = [
-    {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
+    {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "extra == \"grpc\""},
     {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
 ]
 protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0"
@@ -2176,6 +2176,22 @@ tqdm = "*"
 [package.extras]
 dev = ["black", "isort", "mkautodoc", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]", "pytest", "setuptools", "twine", "wheel"]
 
+[[package]]
+name = "gptcache"
+version = "0.1.43"
+description = "GPTCache, a powerful caching library that can be used to speed up and lower the cost of chat applications that rely on the LLM service. GPTCache works as a memcache for AIGC applications, similar to how Redis works for traditional applications."
+optional = false
+python-versions = ">=3.8.1"
+files = [
+    {file = "gptcache-0.1.43-py3-none-any.whl", hash = "sha256:9c557ec9cc14428942a0ebf1c838520dc6d2be801d67bb6964807043fc2feaf5"},
+    {file = "gptcache-0.1.43.tar.gz", hash = "sha256:cebe7ec5e32a3347bf839e933a34e67c7fcae620deaa7cb8c6d7d276c8686f1a"},
+]
+
+[package.dependencies]
+cachetools = "*"
+numpy = "*"
+requests = "*"
+
 [[package]]
 name = "greenlet"
 version = "3.0.0"
@@ -6583,7 +6599,7 @@ files = [
 ]
 
 [package.dependencies]
-greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\" or extra == \"asyncio\""}
+greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\" or extra == \"asyncio\""}
 typing-extensions = ">=4.2.0"
 
 [package.extras]
pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "embedchain"
-version = "0.1.47"
+version = "0.1.48"
 description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
 authors = [
     "Taranjeet Singh <taranjeet@embedchain.ai>",
@@ -101,6 +101,7 @@ posthog = "^3.0.2"
 rich = "^13.7.0"
 beautifulsoup4 = "^4.12.2"
 pypdf = "^3.11.0"
+gptcache = "^0.1.43"
 tiktoken = { version = "^0.4.0", optional = true }
 youtube-transcript-api = { version = "^0.6.1", optional = true }
 pytube = { version = "^15.0.0", optional = true }