[Bugfix] fix cache session id in chat method (#1107)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Author: Deven Patel
Date: 2024-01-03 13:51:11 +05:30
Committed by: GitHub
Parent: 1976d38b25
Commit: ae2e9cb890

2 changed files with 12 additions and 2 deletions


@@ -612,11 +612,12 @@ class EmbedChain(JSONSerializable):
         if self.cache_config is not None:
             logging.info("Cache enabled. Checking cache...")
+            cache_id = f"{session_id}--{self.config.id}"
             answer = adapt(
                 llm_handler=self.llm.chat,
                 cache_data_convert=gptcache_data_convert,
                 update_cache_callback=gptcache_update_cache_callback,
-                session=get_gptcache_session(session_id=self.config.id),
+                session=get_gptcache_session(session_id=cache_id),
                 input_query=input_query,
                 contexts=contexts_data_for_llm_query,
                 config=config,
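Note on the fix: previously the GPTCache session was keyed on self.config.id
alone, so every chat session of the same app shared a single cache namespace.
The new cache_id folds the per-chat session_id into the key. A minimal sketch
of the resulting key (the concrete values here are hypothetical):

    session_id = "user-42"      # per-chat session identifier
    app_config_id = "my-app"    # corresponds to self.config.id
    cache_id = f"{session_id}--{app_config_id}"
    print(cache_id)             # -> "user-42--my-app"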


@@ -1,3 +1,5 @@
+import os
+from pathlib import Path
 from typing import Iterable, Optional, Union

 from langchain.callbacks.stdout import StdOutCallbackHandler
@@ -29,7 +31,14 @@ class GPT4ALLLlm(BaseLlm):
                 "The GPT4All python package is not installed. Please install it with `pip install --upgrade embedchain[opensource]`"  # noqa E501
             ) from None
-        return LangchainGPT4All(model=model)
+        model_path = Path(model).expanduser()
+        if os.path.isabs(model_path):
+            if os.path.exists(model_path):
+                return LangchainGPT4All(model=str(model_path))
+            else:
+                raise ValueError(f"Model does not exist at {model_path=}")
+        else:
+            return LangchainGPT4All(model=model, allow_download=True)

     def _get_answer(self, prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
         if config.model and config.model != self.config.model:
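The new model-resolution behavior in GPT4ALLLlm, as a standalone sketch (the
function name and return values are illustrative, not the embedchain API):

    import os
    from pathlib import Path

    def resolve_gpt4all_model(model: str):
        # Absolute paths (after ~ expansion) must already exist on disk;
        # bare model names are handed to GPT4All with downloading allowed.
        model_path = Path(model).expanduser()
        if os.path.isabs(model_path):
            if os.path.exists(model_path):
                return ("local", str(model_path))
            raise ValueError(f"Model does not exist at {model_path=}")
        return ("download", model)

    print(resolve_gpt4all_model("orca-mini-3b-gguf2-q4_0.gguf"))  # ('download', ...)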