diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py
index 3f1ad6a4..eec35cb5 100644
--- a/embedchain/embedchain.py
+++ b/embedchain/embedchain.py
@@ -612,11 +612,12 @@ class EmbedChain(JSONSerializable):
 
         if self.cache_config is not None:
             logging.info("Cache enabled. Checking cache...")
+            cache_id = f"{session_id}--{self.config.id}"
             answer = adapt(
                 llm_handler=self.llm.chat,
                 cache_data_convert=gptcache_data_convert,
                 update_cache_callback=gptcache_update_cache_callback,
-                session=get_gptcache_session(session_id=self.config.id),
+                session=get_gptcache_session(session_id=cache_id),
                 input_query=input_query,
                 contexts=contexts_data_for_llm_query,
                 config=config,
diff --git a/embedchain/llm/gpt4all.py b/embedchain/llm/gpt4all.py
index 4afddfab..fe4d6970 100644
--- a/embedchain/llm/gpt4all.py
+++ b/embedchain/llm/gpt4all.py
@@ -1,3 +1,5 @@
+import os
+from pathlib import Path
 from typing import Iterable, Optional, Union
 
 from langchain.callbacks.stdout import StdOutCallbackHandler
@@ -29,7 +31,14 @@ class GPT4ALLLlm(BaseLlm):
                 "The GPT4All python package is not installed. Please install it with `pip install --upgrade embedchain[opensource]`"  # noqa E501
             ) from None
 
-        return LangchainGPT4All(model=model)
+        model_path = Path(model).expanduser()
+        if os.path.isabs(model_path):
+            if os.path.exists(model_path):
+                return LangchainGPT4All(model=str(model_path))
+            else:
+                raise ValueError(f"Model does not exist at {model_path=}")
+        else:
+            return LangchainGPT4All(model=model, allow_download=True)
 
     def _get_answer(self, prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
        if config.model and config.model != self.config.model:
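The `embedchain.py` hunk scopes the GPTCache session key to the pair of chat session and app id, so cached answers from one chat session are no longer served to another session (or another app) that happens to ask the same question. For the `gpt4all.py` hunk, here is a minimal standalone sketch of the new branching in `_get_instance`; the `resolve_model_kwargs` helper name is mine, not part of the diff. An absolute or `~`-prefixed path must already exist on disk and is loaded directly, while a bare model name falls through to GPT4All's downloader via `allow_download=True`:

```python
import os
from pathlib import Path


def resolve_model_kwargs(model: str) -> dict:
    # Hypothetical helper mirroring the new _get_instance logic;
    # returns the kwargs that would be passed to LangchainGPT4All.
    model_path = Path(model).expanduser()
    if os.path.isabs(model_path):
        if os.path.exists(model_path):
            # Existing local model file: load it directly.
            return {"model": str(model_path)}
        # Absolute path that does not exist: fail fast rather than
        # letting GPT4All treat it as a downloadable model name.
        raise ValueError(f"Model does not exist at {model_path=}")
    # Bare model name: let GPT4All download it to its default location.
    return {"model": model, "allow_download": True}


# A bare name is passed through with downloads enabled:
assert resolve_model_kwargs("orca-mini-3b-gguf2-q4_0.gguf") == {
    "model": "orca-mini-3b-gguf2-q4_0.gguf",
    "allow_download": True,
}
```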