[Improvement] Use SQLite for chat memory (#910)
Co-authored-by: Deven Patel <deven298@yahoo.com>
@@ -5,7 +5,7 @@ import uuid

 import requests

-from embedchain.embedchain import CONFIG_DIR, CONFIG_FILE
+from embedchain.constants import CONFIG_DIR, CONFIG_FILE


 class Client:

embedchain/constants.py (new file, 8 lines)
@@ -0,0 +1,8 @@
+import os
+from pathlib import Path
+
+ABS_PATH = os.getcwd()
+HOME_DIR = str(Path.home())
+CONFIG_DIR = os.path.join(HOME_DIR, ".embedchain")
+CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
+SQLITE_PATH = os.path.join(CONFIG_DIR, "embedchain.db")
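
Aside: a two-line sanity check, not part of the commit, showing where the new chat-history database will live; the printed path below is an example and varies per machine.

from embedchain.constants import SQLITE_PATH

print(SQLITE_PATH)  # e.g. /home/<user>/.embedchain/embedchain.db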

@@ -1,9 +1,7 @@
 import hashlib
 import json
 import logging
-import os
 import sqlite3
-from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union

 from dotenv import load_dotenv

@@ -12,6 +10,7 @@ from langchain.docstore.document import Document
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
 from embedchain.config.apps.base_app_config import BaseAppConfig
+from embedchain.constants import SQLITE_PATH
 from embedchain.data_formatter import DataFormatter
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.helper.json_serializable import JSONSerializable

@@ -25,12 +24,6 @@ from embedchain.vectordb.base import BaseVectorDB

 load_dotenv()

-ABS_PATH = os.getcwd()
-HOME_DIR = str(Path.home())
-CONFIG_DIR = os.path.join(HOME_DIR, ".embedchain")
-CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
-SQLITE_PATH = os.path.join(CONFIG_DIR, "embedchain.db")
-

 class EmbedChain(JSONSerializable):
     def __init__(

@@ -602,6 +595,9 @@ class EmbedChain(JSONSerializable):
             input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
         )

+        # add conversation in memory
+        self.llm.add_history(self.config.id, input_query, answer)
+
         # Send anonymous telemetry
         self.telemetry.capture(event_name="chat", properties=self._telemetry_props)

@@ -645,5 +641,9 @@ class EmbedChain(JSONSerializable):
         self.db.reset()
         self.cursor.execute("DELETE FROM data_sources WHERE pipeline_id = ?", (self.config.id,))
         self.connection.commit()
+        self.clear_history()
         # Send anonymous telemetry
         self.telemetry.capture(event_name="reset", properties=self._telemetry_props)
+
+    def clear_history(self):
+        self.llm.memory.delete_chat_history(app_id=self.config.id)

@@ -1,14 +1,15 @@
 import logging
 from typing import Any, Dict, Generator, List, Optional

-from langchain.memory import ConversationBufferMemory
-from langchain.schema import BaseMessage
+from langchain.schema import BaseMessage as LCBaseMessage

 from embedchain.config import BaseLlmConfig
 from embedchain.config.llm.base import (DEFAULT_PROMPT,
                                         DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
                                         DOCS_SITE_PROMPT_TEMPLATE)
 from embedchain.helper.json_serializable import JSONSerializable
+from embedchain.memory.base import ECChatMemory
+from embedchain.memory.message import ChatMessage


 class BaseLlm(JSONSerializable):

@@ -23,7 +24,7 @@ class BaseLlm(JSONSerializable):
         else:
             self.config = config

-        self.memory = ConversationBufferMemory()
+        self.memory = ECChatMemory()
         self.is_docs_site_instance = False
         self.online = False
         self.history: Any = None

@@ -44,11 +45,18 @@ class BaseLlm(JSONSerializable):
         """
         self.history = history

-    def update_history(self):
+    def update_history(self, app_id: str):
         """Update class history attribute with history in memory (for chat method)"""
-        chat_history = self.memory.load_memory_variables({})["history"]
+        chat_history = self.memory.get_recent_memories(app_id=app_id, num_rounds=10)
         if chat_history:
-            self.set_history(chat_history)
+            self.set_history([str(history) for history in chat_history])
+
+    def add_history(self, app_id: str, question: str, answer: str, metadata: Optional[Dict[str, Any]] = None):
+        chat_message = ChatMessage()
+        chat_message.add_user_message(question, metadata=metadata)
+        chat_message.add_ai_message(answer, metadata=metadata)
+        self.memory.add(app_id=app_id, chat_message=chat_message)
+        self.update_history(app_id=app_id)

     def generate_prompt(self, input_query: str, contexts: List[str], **kwargs: Dict[str, Any]) -> str:
         """

@@ -165,7 +173,6 @@ class BaseLlm(JSONSerializable):
         for chunk in answer:
             streamed_answer = streamed_answer + chunk
             yield chunk
-        self.memory.chat_memory.add_ai_message(streamed_answer)
         logging.info(f"Answer: {streamed_answer}")

     def query(self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False):

@@ -257,8 +264,6 @@ class BaseLlm(JSONSerializable):
         if self.online:
             k["web_search_result"] = self.access_search_and_get_results(input_query)

-        self.update_history()
-
         prompt = self.generate_prompt(input_query, contexts, **k)
         logging.info(f"Prompt: {prompt}")

@@ -267,16 +272,9 @@ class BaseLlm(JSONSerializable):

         answer = self.get_answer_from_llm(prompt)

-        self.memory.chat_memory.add_user_message(input_query)
-
         if isinstance(answer, str):
-            self.memory.chat_memory.add_ai_message(answer)
             logging.info(f"Answer: {answer}")

-            # NOTE: Adding to history before and after. This could be seen as redundant.
-            # If we change it, we have to change the tests (no big deal).
-            self.update_history()
-
             return answer
         else:
             # this is a streamed response and needs to be handled differently.

@@ -287,7 +285,7 @@ class BaseLlm(JSONSerializable):
         self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)

     @staticmethod
-    def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[BaseMessage]:
+    def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[LCBaseMessage]:
         """
         Construct a list of langchain messages


embedchain/memory/__init__.py (new file, empty)

embedchain/memory/base.py (new file, 112 lines)
@@ -0,0 +1,112 @@
+import json
+import logging
+import sqlite3
+import uuid
+from typing import Any, Dict, List, Optional
+
+from embedchain.constants import SQLITE_PATH
+from embedchain.memory.message import ChatMessage
+from embedchain.memory.utils import merge_metadata_dict
+
+CHAT_MESSAGE_CREATE_TABLE_QUERY = """
+    CREATE TABLE IF NOT EXISTS chat_history (
+        app_id TEXT,
+        id TEXT,
+        question TEXT,
+        answer TEXT,
+        metadata TEXT,
+        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+        PRIMARY KEY (id, app_id)
+    )
+"""
+
+
+class ECChatMemory:
+    def __init__(self) -> None:
+        with sqlite3.connect(SQLITE_PATH) as self.connection:
+            self.cursor = self.connection.cursor()
+
+            self.cursor.execute(CHAT_MESSAGE_CREATE_TABLE_QUERY)
+            self.connection.commit()
+
+    def add(self, app_id, chat_message: ChatMessage) -> Optional[str]:
+        memory_id = str(uuid.uuid4())
+        metadata_dict = merge_metadata_dict(chat_message.human_message.metadata, chat_message.ai_message.metadata)
+        if metadata_dict:
+            metadata = self._serialize_json(metadata_dict)
+        ADD_CHAT_MESSAGE_QUERY = """
+            INSERT INTO chat_history (app_id, id, question, answer, metadata)
+            VALUES (?, ?, ?, ?, ?)
+        """
+        self.cursor.execute(
+            ADD_CHAT_MESSAGE_QUERY,
+            (
+                app_id,
+                memory_id,
+                chat_message.human_message.content,
+                chat_message.ai_message.content,
+                metadata if metadata_dict else "{}",
+            ),
+        )
+        self.connection.commit()
+        logging.info(f"Added chat memory to db with id: {memory_id}")
+        return memory_id
+
+    def delete_chat_history(self, app_id: str):
+        DELETE_CHAT_HISTORY_QUERY = """
+            DELETE FROM chat_history WHERE app_id=?
+        """
+        self.cursor.execute(
+            DELETE_CHAT_HISTORY_QUERY,
+            (app_id,),
+        )
+        self.connection.commit()
+
+    def get_recent_memories(self, app_id, num_rounds=10) -> List[ChatMessage]:
+        """
+        Get the most recent num_rounds rounds of conversations
+        between human and AI, for a given app_id.
+        """
+
+        QUERY = """
+            SELECT * FROM chat_history
+            WHERE app_id=?
+            ORDER BY created_at DESC
+            LIMIT ?
+        """
+        self.cursor.execute(
+            QUERY,
+            (app_id, num_rounds),
+        )
+
+        results = self.cursor.fetchall()
+        history = []
+        for result in results:
+            app_id, id, question, answer, metadata, timestamp = result
+            metadata = self._deserialize_json(metadata=metadata)
+            memory = ChatMessage()
+            memory.add_user_message(question, metadata=metadata)
+            memory.add_ai_message(answer, metadata=metadata)
+            history.append(memory)
+        return history
+
+    def _serialize_json(self, metadata: Dict[str, Any]):
+        return json.dumps(metadata)
+
+    def _deserialize_json(self, metadata: str):
+        return json.loads(metadata)
+
+    def close_connection(self):
+        self.connection.close()
+
+    def count_history_messages(self, app_id: str):
+        QUERY = """
+            SELECT COUNT(*) FROM chat_history
+            WHERE app_id=?
+        """
+        self.cursor.execute(
+            QUERY,
+            (app_id,),
+        )
+        count = self.cursor.fetchone()[0]
+        return count
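
For orientation, here is how the ECChatMemory pieces defined above fit together: a standalone sketch, not part of the commit, mirroring what BaseLlm.add_history and update_history do. The app id is made up; writes go to the database at SQLITE_PATH.

from embedchain.memory.base import ECChatMemory
from embedchain.memory.message import ChatMessage

memory = ECChatMemory()  # creates ~/.embedchain/embedchain.db and the chat_history table if needed
app_id = "demo_app"      # hypothetical app id

# One chat round = one ChatMessage (question + answer), as in BaseLlm.add_history:
chat_round = ChatMessage()
chat_round.add_user_message("What is embedchain?")
chat_round.add_ai_message("A framework for building RAG applications.")
memory.add(app_id=app_id, chat_message=chat_round)

# What BaseLlm.update_history does before prompting: fetch recent rounds and stringify them.
recent = memory.get_recent_memories(app_id=app_id, num_rounds=10)
history = [str(r) for r in recent]  # e.g. ['human: What is embedchain? | ai: A framework ...']

memory.delete_chat_history(app_id=app_id)  # what App.reset() now triggers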
embedchain/memory/message.py (new file, 72 lines)
@@ -0,0 +1,72 @@
+import logging
+from typing import Any, Dict, Optional
+
+from embedchain.helper.json_serializable import JSONSerializable
+
+
+class BaseMessage(JSONSerializable):
+    """
+    The base abstract message class.
+
+    Messages are the inputs and outputs of Models.
+    """
+
+    # The string content of the message.
+    content: str
+
+    # The creator of the message. AI, Human, Bot etc.
+    by: str
+
+    # Any additional info.
+    metadata: Dict[str, Any]
+
+    def __init__(self, content: str, creator: str, metadata: Optional[Dict[str, Any]] = None) -> None:
+        super().__init__()
+        self.content = content
+        self.creator = creator
+        self.metadata = metadata
+
+    @property
+    def type(self) -> str:
+        """Type of the Message, used for serialization."""
+
+    @classmethod
+    def is_lc_serializable(cls) -> bool:
+        """Return whether this class is serializable."""
+        return True
+
+    def __str__(self) -> str:
+        return f"{self.creator}: {self.content}"
+
+
+class ChatMessage(JSONSerializable):
+    """
+    The base abstract chat message class.
+
+    Chat messages are the pair of (question, answer) conversation
+    between human and model.
+    """
+
+    human_message: Optional[BaseMessage] = None
+    ai_message: Optional[BaseMessage] = None
+
+    def add_user_message(self, message: str, metadata: Optional[dict] = None):
+        if self.human_message:
+            logging.info(
+                "Human message already exists in the chat message,\
+                overwriting it with new message."
+            )
+
+        self.human_message = BaseMessage(content=message, creator="human", metadata=metadata)
+
+    def add_ai_message(self, message: str, metadata: Optional[dict] = None):
+        if self.ai_message:
+            logging.info(
+                "AI message already exists in the chat message,\
+                overwriting it with new message."
+            )
+
+        self.ai_message = BaseMessage(content=message, creator="ai", metadata=metadata)
+
+    def __str__(self) -> str:
+        return f"{self.human_message} | {self.ai_message}"
embedchain/memory/utils.py (new file, 35 lines)
@@ -0,0 +1,35 @@
+from typing import Any, Dict, Optional
+
+
+def merge_metadata_dict(left: Optional[Dict[str, Any]], right: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
+    """
+    Merge the metadatas of two BaseMessage types.
+
+    Args:
+        left (Dict[str, Any]): metadata of human message
+        right (Dict[str, Any]): metadata of ai message
+
+    Returns:
+        Dict[str, Any]: combined metadata dict with dedup
+        to be saved in db.
+    """
+    if not left and not right:
+        return None
+    elif not left:
+        return right
+    elif not right:
+        return left
+
+    merged = left.copy()
+    for k, v in right.items():
+        if k not in merged:
+            merged[k] = v
+        elif type(merged[k]) != type(v):
+            raise ValueError(f'additional_kwargs["{k}"] already exists in this message,' " but with a different type.")
+        elif isinstance(merged[k], str):
+            merged[k] += v
+        elif isinstance(merged[k], dict):
+            merged[k] = merge_metadata_dict(merged[k], v)
+        else:
+            raise ValueError(f"Additional kwargs key {k} already exists in this message.")
+    return merged
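
The merge rules above are easiest to see with a few concrete calls; the values here are made up for illustration only.

from embedchain.memory.utils import merge_metadata_dict

# Disjoint keys are combined into one dict.
assert merge_metadata_dict({"user": "John"}, {"response_time": 0.5}) == {"user": "John", "response_time": 0.5}

# Same key, both strings: values are concatenated; same key, both dicts, would merge recursively.
assert merge_metadata_dict({"tag": "a"}, {"tag": "b"}) == {"tag": "ab"}

# A missing side passes the other through unchanged; both missing yields None.
assert merge_metadata_dict(None, {"k": "v"}) == {"k": "v"}
assert merge_metadata_dict(None, None) is None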

@@ -10,7 +10,8 @@ import yaml

 from embedchain import Client
 from embedchain.config import ChunkerConfig, PipelineConfig
-from embedchain.embedchain import CONFIG_DIR, EmbedChain
+from embedchain.constants import SQLITE_PATH
+from embedchain.embedchain import EmbedChain
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.embedder.openai import OpenAIEmbedder
 from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory

@@ -22,8 +23,6 @@ from embedchain.utils import validate_yaml_config
 from embedchain.vectordb.base import BaseVectorDB
 from embedchain.vectordb.chroma import ChromaDB

-SQLITE_PATH = os.path.join(CONFIG_DIR, "embedchain.db")
-

 @register_deserializable
 class Pipeline(EmbedChain):

@@ -138,7 +138,8 @@ def detect_datatype(source: Any) -> DataType:
     formatted_source = format_source(str(source), 30)

     if url:
-        from langchain.document_loaders.youtube import ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
+        from langchain.document_loaders.youtube import \
+            ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS

         if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
             logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")

@@ -7,6 +7,7 @@ from embedchain import App
 from embedchain.config import AppConfig, ChromaDbConfig
 from embedchain.embedchain import EmbedChain
 from embedchain.llm.base import BaseLlm
+from embedchain.memory.base import ECChatMemory

 os.environ["OPENAI_API_KEY"] = "test-api-key"

@@ -25,6 +26,11 @@ def test_whole_app(app_instance, mocker):
     mocker.patch.object(BaseLlm, "get_answer_from_llm", return_value=knowledge)
     mocker.patch.object(BaseLlm, "get_llm_model_answer", return_value=knowledge)
     mocker.patch.object(BaseLlm, "generate_prompt")
+    mocker.patch.object(
+        BaseLlm,
+        "add_history",
+    )
+    mocker.patch.object(ECChatMemory, "delete_chat_history", autospec=True)

     app_instance.add(knowledge, data_type="text")
     app_instance.query("What text did I give you?")

@@ -41,6 +47,10 @@ def test_add_after_reset(app_instance, mocker):
     chroma_config = {"allow_reset": True}

     app_instance = App(config=config, db_config=ChromaDbConfig(**chroma_config))
+
+    # mock delete chat history
+    mocker.patch.object(ECChatMemory, "delete_chat_history", autospec=True)

     app_instance.reset()

     app_instance.db.client.heartbeat()

@@ -5,6 +5,8 @@ from unittest.mock import MagicMock, patch
 from embedchain import App
 from embedchain.config import AppConfig, BaseLlmConfig
 from embedchain.llm.base import BaseLlm
+from embedchain.memory.base import ECChatMemory
+from embedchain.memory.message import ChatMessage


 class TestApp(unittest.TestCase):

@@ -31,14 +33,14 @@ class TestApp(unittest.TestCase):
         """
         config = AppConfig(collect_metrics=False)
         app = App(config=config)
-        first_answer = app.chat("Test query 1")
-        self.assertEqual(first_answer, "Test answer")
-        self.assertEqual(len(app.llm.memory.chat_memory.messages), 2)
-        self.assertEqual(len(app.llm.history.splitlines()), 2)
-        second_answer = app.chat("Test query 2")
-        self.assertEqual(second_answer, "Test answer")
-        self.assertEqual(len(app.llm.memory.chat_memory.messages), 4)
-        self.assertEqual(len(app.llm.history.splitlines()), 4)
+        with patch.object(BaseLlm, "add_history") as mock_history:
+            first_answer = app.chat("Test query 1")
+            self.assertEqual(first_answer, "Test answer")
+            mock_history.assert_called_with(app.config.id, "Test query 1", "Test answer")
+
+            second_answer = app.chat("Test query 2")
+            self.assertEqual(second_answer, "Test answer")
+            mock_history.assert_called_with(app.config.id, "Test query 2", "Test answer")

     @patch.object(App, "retrieve_from_database", return_value=["Test context"])
     @patch.object(BaseLlm, "get_answer_from_llm", return_value="Test answer")

@@ -49,16 +51,22 @@ class TestApp(unittest.TestCase):

         Also tests that a dry run does not change the history
         """
-        config = AppConfig(collect_metrics=False)
-        app = App(config=config)
-        first_answer = app.chat("Test query 1")
-        self.assertEqual(first_answer, "Test answer")
-        self.assertEqual(len(app.llm.history.splitlines()), 2)
-        history = app.llm.history
-        dry_run = app.chat("Test query 2", dry_run=True)
-        self.assertIn("History:", dry_run)
-        self.assertEqual(history, app.llm.history)
-        self.assertEqual(len(app.llm.history.splitlines()), 2)
+        with patch.object(ECChatMemory, "get_recent_memories") as mock_memory:
+            mock_message = ChatMessage()
+            mock_message.add_user_message("Test query 1")
+            mock_message.add_ai_message("Test answer")
+            mock_memory.return_value = [mock_message]
+
+            config = AppConfig(collect_metrics=False)
+            app = App(config=config)
+            first_answer = app.chat("Test query 1")
+            self.assertEqual(first_answer, "Test answer")
+            self.assertEqual(len(app.llm.history), 1)
+            history = app.llm.history
+            dry_run = app.chat("Test query 2", dry_run=True)
+            self.assertIn("History:", dry_run)
+            self.assertEqual(history, app.llm.history)
+            self.assertEqual(len(app.llm.history), 1)

     @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
     def test_chat_with_where_in_params(self):

tests/memory/test_chat_memory.py (new file, 67 lines)
@@ -0,0 +1,67 @@
+import pytest
+
+from embedchain.memory.base import ECChatMemory
+from embedchain.memory.message import ChatMessage
+
+
+# Fixture for creating an instance of ECChatMemory
+@pytest.fixture
+def chat_memory_instance():
+    return ECChatMemory()
+
+
+def test_add_chat_memory(chat_memory_instance):
+    app_id = "test_app"
+    human_message = "Hello, how are you?"
+    ai_message = "I'm fine, thank you!"
+
+    chat_message = ChatMessage()
+    chat_message.add_user_message(human_message)
+    chat_message.add_ai_message(ai_message)
+
+    chat_memory_instance.add(app_id, chat_message)
+
+    assert chat_memory_instance.count_history_messages(app_id) == 1
+    chat_memory_instance.delete_chat_history(app_id)
+
+
+def test_get_recent_memories(chat_memory_instance):
+    app_id = "test_app"
+
+    for i in range(1, 7):
+        human_message = f"Question {i}"
+        ai_message = f"Answer {i}"
+
+        chat_message = ChatMessage()
+        chat_message.add_user_message(human_message)
+        chat_message.add_ai_message(ai_message)
+
+        chat_memory_instance.add(app_id, chat_message)
+
+    recent_memories = chat_memory_instance.get_recent_memories(app_id, num_rounds=5)
+
+    assert len(recent_memories) == 5
+
+
+def test_delete_chat_history(chat_memory_instance):
+    app_id = "test_app"
+
+    for i in range(1, 6):
+        human_message = f"Question {i}"
+        ai_message = f"Answer {i}"
+
+        chat_message = ChatMessage()
+        chat_message.add_user_message(human_message)
+        chat_message.add_ai_message(ai_message)
+
+        chat_memory_instance.add(app_id, chat_message)
+
+    chat_memory_instance.delete_chat_history(app_id)
+
+    assert chat_memory_instance.count_history_messages(app_id) == 0
+
+
+@pytest.fixture
+def close_connection(chat_memory_instance):
+    yield
+    chat_memory_instance.close_connection()
tests/memory/test_memory_messages.py (new file, 37 lines)
@@ -0,0 +1,37 @@
+from embedchain.memory.message import BaseMessage, ChatMessage
+
+
+def test_ec_base_message():
+    content = "Hello, how are you?"
+    creator = "human"
+    metadata = {"key": "value"}
+
+    message = BaseMessage(content=content, creator=creator, metadata=metadata)
+
+    assert message.content == content
+    assert message.creator == creator
+    assert message.metadata == metadata
+    assert message.type is None
+    assert message.is_lc_serializable() is True
+    assert str(message) == f"{creator}: {content}"
+
+
+def test_ec_base_chat_message():
+    human_message_content = "Hello, how are you?"
+    ai_message_content = "I'm fine, thank you!"
+    human_metadata = {"user": "John"}
+    ai_metadata = {"response_time": 0.5}
+
+    chat_message = ChatMessage()
+    chat_message.add_user_message(human_message_content, metadata=human_metadata)
+    chat_message.add_ai_message(ai_message_content, metadata=ai_metadata)
+
+    assert chat_message.human_message.content == human_message_content
+    assert chat_message.human_message.creator == "human"
+    assert chat_message.human_message.metadata == human_metadata
+
+    assert chat_message.ai_message.content == ai_message_content
+    assert chat_message.ai_message.creator == "ai"
+    assert chat_message.ai_message.metadata == ai_metadata
+
+    assert str(chat_message) == f"human: {human_message_content} | ai: {ai_message_content}"

@@ -1,4 +1,5 @@
 import yaml
+
 from embedchain.utils import validate_yaml_config

 CONFIG_YAMLS = [