diff --git a/embedchain/client.py b/embedchain/client.py
index 56f2d190..7303aafe 100644
--- a/embedchain/client.py
+++ b/embedchain/client.py
@@ -5,7 +5,7 @@ import uuid
 
 import requests
 
-from embedchain.embedchain import CONFIG_DIR, CONFIG_FILE
+from embedchain.constants import CONFIG_DIR, CONFIG_FILE
 
 
 class Client:
diff --git a/embedchain/constants.py b/embedchain/constants.py
new file mode 100644
index 00000000..d3fda313
--- /dev/null
+++ b/embedchain/constants.py
@@ -0,0 +1,8 @@
+import os
+from pathlib import Path
+
+ABS_PATH = os.getcwd()
+HOME_DIR = str(Path.home())
+CONFIG_DIR = os.path.join(HOME_DIR, ".embedchain")
+CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
+SQLITE_PATH = os.path.join(CONFIG_DIR, "embedchain.db")
diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py
index b1a31a36..373bf635 100644
--- a/embedchain/embedchain.py
+++ b/embedchain/embedchain.py
@@ -1,9 +1,7 @@
 import hashlib
 import json
 import logging
-import os
 import sqlite3
-from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 from dotenv import load_dotenv
@@ -12,6 +10,7 @@ from langchain.docstore.document import Document
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
 from embedchain.config.apps.base_app_config import BaseAppConfig
+from embedchain.constants import SQLITE_PATH
 from embedchain.data_formatter import DataFormatter
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.helper.json_serializable import JSONSerializable
@@ -25,12 +24,6 @@ from embedchain.vectordb.base import BaseVectorDB
 
 load_dotenv()
 
-ABS_PATH = os.getcwd()
-HOME_DIR = str(Path.home())
-CONFIG_DIR = os.path.join(HOME_DIR, ".embedchain")
-CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
-SQLITE_PATH = os.path.join(CONFIG_DIR, "embedchain.db")
-
 
 class EmbedChain(JSONSerializable):
     def __init__(
@@ -602,6 +595,9 @@ class EmbedChain(JSONSerializable):
             input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
         )
 
+        # add the conversation round to persistent memory
+        self.llm.add_history(self.config.id, input_query, answer)
+
         # Send anonymous telemetry
         self.telemetry.capture(event_name="chat", properties=self._telemetry_props)
 
@@ -645,5 +641,9 @@ class EmbedChain(JSONSerializable):
         self.db.reset()
         self.cursor.execute("DELETE FROM data_sources WHERE pipeline_id = ?", (self.config.id,))
         self.connection.commit()
+        self.clear_history()
         # Send anonymous telemetry
         self.telemetry.capture(event_name="reset", properties=self._telemetry_props)
+
+    def clear_history(self):
+        self.llm.memory.delete_chat_history(app_id=self.config.id)
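With these changes, `chat()` persists each question/answer round through `llm.add_history()`, and `reset()` now also drops the stored conversation via the new `clear_history()`. A rough end-to-end sketch of the resulting behavior, not part of the diff; it assumes a default-configured app with a valid `OPENAI_API_KEY` and a writable `~/.embedchain/` directory, and the knowledge text and queries are invented:

```python
from embedchain import App

app = App()
app.add("Embedchain is a framework for building RAG apps.", data_type="text")

# chat() now records the (question, answer) round in ~/.embedchain/embedchain.db
app.chat("What is embedchain?")

# reset() clears the vector store AND the persisted chat history for this app id
app.reset()
```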
diff --git a/embedchain/llm/base.py b/embedchain/llm/base.py
index c90771a7..8bb38833 100644
--- a/embedchain/llm/base.py
+++ b/embedchain/llm/base.py
@@ -1,14 +1,15 @@
 import logging
 from typing import Any, Dict, Generator, List, Optional
 
-from langchain.memory import ConversationBufferMemory
-from langchain.schema import BaseMessage
+from langchain.schema import BaseMessage as LCBaseMessage
 
 from embedchain.config import BaseLlmConfig
 from embedchain.config.llm.base import (DEFAULT_PROMPT,
                                         DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
                                         DOCS_SITE_PROMPT_TEMPLATE)
 from embedchain.helper.json_serializable import JSONSerializable
+from embedchain.memory.base import ECChatMemory
+from embedchain.memory.message import ChatMessage
 
 
 class BaseLlm(JSONSerializable):
@@ -23,7 +24,7 @@ class BaseLlm(JSONSerializable):
         else:
             self.config = config
 
-        self.memory = ConversationBufferMemory()
+        self.memory = ECChatMemory()
         self.is_docs_site_instance = False
         self.online = False
         self.history: Any = None
@@ -44,11 +45,18 @@ class BaseLlm(JSONSerializable):
         """
         self.history = history
 
-    def update_history(self):
+    def update_history(self, app_id: str):
         """Update class history attribute with history in memory (for chat method)"""
-        chat_history = self.memory.load_memory_variables({})["history"]
+        chat_history = self.memory.get_recent_memories(app_id=app_id, num_rounds=10)
         if chat_history:
-            self.set_history(chat_history)
+            self.set_history([str(history) for history in chat_history])
+
+    def add_history(self, app_id: str, question: str, answer: str, metadata: Optional[Dict[str, Any]] = None):
+        chat_message = ChatMessage()
+        chat_message.add_user_message(question, metadata=metadata)
+        chat_message.add_ai_message(answer, metadata=metadata)
+        self.memory.add(app_id=app_id, chat_message=chat_message)
+        self.update_history(app_id=app_id)
 
     def generate_prompt(self, input_query: str, contexts: List[str], **kwargs: Dict[str, Any]) -> str:
         """
@@ -165,7 +173,6 @@ class BaseLlm(JSONSerializable):
             for chunk in answer:
                 streamed_answer = streamed_answer + chunk
                 yield chunk
-            self.memory.chat_memory.add_ai_message(streamed_answer)
             logging.info(f"Answer: {streamed_answer}")
 
     def query(self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False):
@@ -257,8 +264,6 @@ class BaseLlm(JSONSerializable):
         if self.online:
             k["web_search_result"] = self.access_search_and_get_results(input_query)
 
-        self.update_history()
-
         prompt = self.generate_prompt(input_query, contexts, **k)
         logging.info(f"Prompt: {prompt}")
 
@@ -267,16 +272,9 @@ class BaseLlm(JSONSerializable):
 
         answer = self.get_answer_from_llm(prompt)
 
-        self.memory.chat_memory.add_user_message(input_query)
-
         if isinstance(answer, str):
-            self.memory.chat_memory.add_ai_message(answer)
             logging.info(f"Answer: {answer}")
-
-            # NOTE: Adding to history before and after. This could be seen as redundant.
-            # If we change it, we have to change the tests (no big deal).
-            self.update_history()
-
             return answer
         else:
             # this is a streamed response and needs to be handled differently.
@@ -287,7 +285,7 @@ class BaseLlm(JSONSerializable):
         self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
 
     @staticmethod
-    def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[BaseMessage]:
+    def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[LCBaseMessage]:
         """
         Construct a list of langchain messages
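`BaseLlm` now delegates history to `ECChatMemory`: `add_history()` wraps one question/answer round in a `ChatMessage` and persists it, then `update_history()` re-reads the last 10 rounds into `self.history` as strings. A minimal sketch of that round trip, assuming the package is installed and `~/.embedchain/` exists; the app id and strings are invented:

```python
from embedchain.llm.base import BaseLlm

llm = BaseLlm()  # config defaults; memory is now ECChatMemory()

# Persist one round, then refresh llm.history from the last 10 rounds.
llm.add_history("demo-app", "What is embedchain?", "A RAG framework.")

print(llm.history)  # e.g. ['human: What is embedchain? | ai: A RAG framework.']
```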
diff --git a/embedchain/memory/__init__.py b/embedchain/memory/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/embedchain/memory/base.py b/embedchain/memory/base.py
new file mode 100644
index 00000000..83bb146b
--- /dev/null
+++ b/embedchain/memory/base.py
@@ -0,0 +1,111 @@
+import json
+import logging
+import sqlite3
+import uuid
+from typing import Any, Dict, List, Optional
+
+from embedchain.constants import SQLITE_PATH
+from embedchain.memory.message import ChatMessage
+from embedchain.memory.utils import merge_metadata_dict
+
+CHAT_MESSAGE_CREATE_TABLE_QUERY = """
+    CREATE TABLE IF NOT EXISTS chat_history (
+        app_id TEXT,
+        id TEXT,
+        question TEXT,
+        answer TEXT,
+        metadata TEXT,
+        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+        PRIMARY KEY (id, app_id)
+    )
+"""
+
+
+class ECChatMemory:
+    def __init__(self) -> None:
+        with sqlite3.connect(SQLITE_PATH) as self.connection:
+            self.cursor = self.connection.cursor()
+
+            self.cursor.execute(CHAT_MESSAGE_CREATE_TABLE_QUERY)
+            self.connection.commit()
+
+    def add(self, app_id, chat_message: ChatMessage) -> Optional[str]:
+        memory_id = str(uuid.uuid4())
+        metadata_dict = merge_metadata_dict(chat_message.human_message.metadata, chat_message.ai_message.metadata)
+        metadata = self._serialize_json(metadata_dict) if metadata_dict else "{}"
+        ADD_CHAT_MESSAGE_QUERY = """
+            INSERT INTO chat_history (app_id, id, question, answer, metadata)
+            VALUES (?, ?, ?, ?, ?)
+        """
+        self.cursor.execute(
+            ADD_CHAT_MESSAGE_QUERY,
+            (
+                app_id,
+                memory_id,
+                chat_message.human_message.content,
+                chat_message.ai_message.content,
+                metadata,
+            ),
+        )
+        self.connection.commit()
+        logging.info(f"Added chat memory to db with id: {memory_id}")
+        return memory_id
+
+    def delete_chat_history(self, app_id: str):
+        DELETE_CHAT_HISTORY_QUERY = """
+            DELETE FROM chat_history WHERE app_id=?
+        """
+        self.cursor.execute(
+            DELETE_CHAT_HISTORY_QUERY,
+            (app_id,),
+        )
+        self.connection.commit()
+
+    def get_recent_memories(self, app_id, num_rounds=10) -> List[ChatMessage]:
+        """
+        Get the most recent num_rounds rounds of conversations
+        between human and AI, for a given app_id.
+        """
+
+        QUERY = """
+            SELECT * FROM chat_history
+            WHERE app_id=?
+            ORDER BY created_at DESC
+            LIMIT ?
+        """
+        self.cursor.execute(
+            QUERY,
+            (app_id, num_rounds),
+        )
+
+        results = self.cursor.fetchall()
+        history = []
+        for result in results:
+            app_id, id, question, answer, metadata, timestamp = result
+            metadata = self._deserialize_json(metadata=metadata)
+            memory = ChatMessage()
+            memory.add_user_message(question, metadata=metadata)
+            memory.add_ai_message(answer, metadata=metadata)
+            history.append(memory)
+        return history
+
+    def _serialize_json(self, metadata: Dict[str, Any]):
+        return json.dumps(metadata)
+
+    def _deserialize_json(self, metadata: str):
+        return json.loads(metadata)
+
+    def close_connection(self):
+        self.connection.close()
+
+    def count_history_messages(self, app_id: str):
+        QUERY = """
+            SELECT COUNT(*) FROM chat_history
+            WHERE app_id=?
+        """
+        self.cursor.execute(
+            QUERY,
+            (app_id,),
+        )
+        count = self.cursor.fetchone()[0]
+        return count
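The store above can be exercised directly. A small usage sketch; the app id and messages are invented, and note that the database file lives at `~/.embedchain/embedchain.db`, shared by all apps on the machine:

```python
from embedchain.memory.base import ECChatMemory
from embedchain.memory.message import ChatMessage

memory = ECChatMemory()

# One round of conversation: a human message plus the AI reply.
turn = ChatMessage()
turn.add_user_message("Hello, who are you?", metadata={"session": "abc"})
turn.add_ai_message("I'm an assistant.", metadata={"model": "gpt-3.5-turbo"})

row_id = memory.add(app_id="demo-app", chat_message=turn)  # returns the row uuid
print(memory.count_history_messages("demo-app"))           # -> 1

# Newest rounds first, reconstructed as ChatMessage objects.
for past_round in memory.get_recent_memories("demo-app", num_rounds=5):
    print(past_round)

memory.delete_chat_history("demo-app")
memory.close_connection()
```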
diff --git a/embedchain/memory/message.py b/embedchain/memory/message.py
new file mode 100644
index 00000000..3fe74af2
--- /dev/null
+++ b/embedchain/memory/message.py
@@ -0,0 +1,70 @@
+import logging
+from typing import Any, Dict, Optional
+
+from embedchain.helper.json_serializable import JSONSerializable
+
+
+class BaseMessage(JSONSerializable):
+    """
+    The base abstract message class.
+
+    Messages are the inputs and outputs of Models.
+    """
+
+    # The string content of the message.
+    content: str
+
+    # The creator of the message. AI, Human, Bot etc.
+    creator: str
+
+    # Any additional info.
+    metadata: Dict[str, Any]
+
+    def __init__(self, content: str, creator: str, metadata: Optional[Dict[str, Any]] = None) -> None:
+        super().__init__()
+        self.content = content
+        self.creator = creator
+        self.metadata = metadata
+
+    @property
+    def type(self) -> str:
+        """Type of the Message, used for serialization."""
+
+    @classmethod
+    def is_lc_serializable(cls) -> bool:
+        """Return whether this class is serializable."""
+        return True
+
+    def __str__(self) -> str:
+        return f"{self.creator}: {self.content}"
+
+
+class ChatMessage(JSONSerializable):
+    """
+    The base abstract chat message class.
+
+    Chat messages are the pair of (question, answer) conversation
+    between human and model.
+    """
+
+    human_message: Optional[BaseMessage] = None
+    ai_message: Optional[BaseMessage] = None
+
+    def add_user_message(self, message: str, metadata: Optional[dict] = None):
+        if self.human_message:
+            logging.info(
+                "Human message already exists in the chat message, overwriting it with new message."
+            )
+
+        self.human_message = BaseMessage(content=message, creator="human", metadata=metadata)
+
+    def add_ai_message(self, message: str, metadata: Optional[dict] = None):
+        if self.ai_message:
+            logging.info(
+                "AI message already exists in the chat message, overwriting it with new message."
+            )
+
+        self.ai_message = BaseMessage(content=message, creator="ai", metadata=metadata)
+
+    def __str__(self) -> str:
+        return f"{self.human_message} | {self.ai_message}"
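One behavior worth noting from the class above: re-adding a message to an already-populated slot logs a notice and overwrites rather than raising. A quick sketch with invented contents:

```python
import logging

from embedchain.memory.message import ChatMessage

logging.basicConfig(level=logging.INFO)

turn = ChatMessage()
turn.add_user_message("first question")
turn.add_user_message("second question")  # logs the overwrite notice, replaces the first
turn.add_ai_message("the answer")

print(turn)  # -> human: second question | ai: the answer
```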
diff --git a/embedchain/memory/utils.py b/embedchain/memory/utils.py
new file mode 100644
index 00000000..eb2e35f1
--- /dev/null
+++ b/embedchain/memory/utils.py
@@ -0,0 +1,35 @@
+from typing import Any, Dict, Optional
+
+
+def merge_metadata_dict(left: Optional[Dict[str, Any]], right: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
+    """
+    Merge the metadata of two BaseMessage instances.
+
+    Args:
+        left (Dict[str, Any]): metadata of the human message
+        right (Dict[str, Any]): metadata of the ai message
+
+    Returns:
+        Dict[str, Any]: combined metadata dict, deduplicated,
+        to be saved in the db.
+    """
+    if not left and not right:
+        return None
+    elif not left:
+        return right
+    elif not right:
+        return left
+
+    merged = left.copy()
+    for k, v in right.items():
+        if k not in merged:
+            merged[k] = v
+        elif type(merged[k]) is not type(v):
+            raise ValueError(f'Metadata key "{k}" exists in both messages, but with different types.')
+        elif isinstance(merged[k], str):
+            merged[k] += v
+        elif isinstance(merged[k], dict):
+            merged[k] = merge_metadata_dict(merged[k], v)
+        else:
+            raise ValueError(f'Metadata key "{k}" exists in both messages with an unmergeable type.')
+    return merged
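The merge rules: a key present on only one side is kept as-is; when both sides share a key, the types must match, strings concatenate, dicts merge recursively, and anything else raises. A worked example with invented values:

```python
from embedchain.memory.utils import merge_metadata_dict

human = {"session": "abc", "tags": {"lang": "en"}}
ai = {"session": "-123", "tags": {"topic": "rag"}, "model": "gpt-4"}

print(merge_metadata_dict(human, ai))
# {'session': 'abc-123', 'tags': {'lang': 'en', 'topic': 'rag'}, 'model': 'gpt-4'}

print(merge_metadata_dict(None, None))  # -> None
```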
diff --git a/embedchain/pipeline.py b/embedchain/pipeline.py
index dc753704..ed37a283 100644
--- a/embedchain/pipeline.py
+++ b/embedchain/pipeline.py
@@ -10,7 +10,8 @@ import yaml
 
 from embedchain import Client
 from embedchain.config import ChunkerConfig, PipelineConfig
-from embedchain.embedchain import CONFIG_DIR, EmbedChain
+from embedchain.constants import SQLITE_PATH
+from embedchain.embedchain import EmbedChain
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.embedder.openai import OpenAIEmbedder
 from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
@@ -22,8 +23,6 @@ from embedchain.utils import validate_yaml_config
 from embedchain.vectordb.base import BaseVectorDB
 from embedchain.vectordb.chroma import ChromaDB
 
-SQLITE_PATH = os.path.join(CONFIG_DIR, "embedchain.db")
-
 
 @register_deserializable
 class Pipeline(EmbedChain):
diff --git a/embedchain/utils.py b/embedchain/utils.py
index 455430e3..811375f7 100644
--- a/embedchain/utils.py
+++ b/embedchain/utils.py
@@ -138,7 +138,8 @@ def detect_datatype(source: Any) -> DataType:
     formatted_source = format_source(str(source), 30)
 
     if url:
-        from langchain.document_loaders.youtube import ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
+        from langchain.document_loaders.youtube import \
+            ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
 
         if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
             logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
diff --git a/tests/embedchain/test_embedchain.py b/tests/embedchain/test_embedchain.py
index 8d57dca2..ed1c4f64 100644
--- a/tests/embedchain/test_embedchain.py
+++ b/tests/embedchain/test_embedchain.py
@@ -7,6 +7,7 @@ from embedchain import App
 from embedchain.config import AppConfig, ChromaDbConfig
 from embedchain.embedchain import EmbedChain
 from embedchain.llm.base import BaseLlm
+from embedchain.memory.base import ECChatMemory
 
 os.environ["OPENAI_API_KEY"] = "test-api-key"
 
@@ -25,6 +26,11 @@ def test_whole_app(app_instance, mocker):
     mocker.patch.object(BaseLlm, "get_answer_from_llm", return_value=knowledge)
     mocker.patch.object(BaseLlm, "get_llm_model_answer", return_value=knowledge)
     mocker.patch.object(BaseLlm, "generate_prompt")
+    mocker.patch.object(
+        BaseLlm,
+        "add_history",
+    )
+    mocker.patch.object(ECChatMemory, "delete_chat_history", autospec=True)
 
     app_instance.add(knowledge, data_type="text")
     app_instance.query("What text did I give you?")
@@ -41,6 +47,10 @@ def test_add_after_reset(app_instance, mocker):
     chroma_config = {"allow_reset": True}
 
     app_instance = App(config=config, db_config=ChromaDbConfig(**chroma_config))
+
+    # mock delete chat history
+    mocker.patch.object(ECChatMemory, "delete_chat_history", autospec=True)
+
     app_instance.reset()
 
     app_instance.db.client.heartbeat()
diff --git a/tests/llm/test_chat.py b/tests/llm/test_chat.py
index 2fcb7a11..a70e62ce 100644
--- a/tests/llm/test_chat.py
+++ b/tests/llm/test_chat.py
@@ -5,6 +5,8 @@ from unittest.mock import MagicMock, patch
 from embedchain import App
 from embedchain.config import AppConfig, BaseLlmConfig
 from embedchain.llm.base import BaseLlm
+from embedchain.memory.base import ECChatMemory
+from embedchain.memory.message import ChatMessage
 
 
 class TestApp(unittest.TestCase):
@@ -31,14 +33,14 @@ class TestApp(unittest.TestCase):
         """
         config = AppConfig(collect_metrics=False)
         app = App(config=config)
-        first_answer = app.chat("Test query 1")
-        self.assertEqual(first_answer, "Test answer")
-        self.assertEqual(len(app.llm.memory.chat_memory.messages), 2)
-        self.assertEqual(len(app.llm.history.splitlines()), 2)
-        second_answer = app.chat("Test query 2")
-        self.assertEqual(second_answer, "Test answer")
-        self.assertEqual(len(app.llm.memory.chat_memory.messages), 4)
-        self.assertEqual(len(app.llm.history.splitlines()), 4)
+        with patch.object(BaseLlm, "add_history") as mock_history:
+            first_answer = app.chat("Test query 1")
+            self.assertEqual(first_answer, "Test answer")
+            mock_history.assert_called_with(app.config.id, "Test query 1", "Test answer")
+
+            second_answer = app.chat("Test query 2")
+            self.assertEqual(second_answer, "Test answer")
+            mock_history.assert_called_with(app.config.id, "Test query 2", "Test answer")
 
     @patch.object(App, "retrieve_from_database", return_value=["Test context"])
     @patch.object(BaseLlm, "get_answer_from_llm", return_value="Test answer")
@@ -49,16 +51,22 @@ class TestApp(unittest.TestCase):
 
         Also tests that a dry run does not change the history
         """
-        config = AppConfig(collect_metrics=False)
-        app = App(config=config)
-        first_answer = app.chat("Test query 1")
-        self.assertEqual(first_answer, "Test answer")
-        self.assertEqual(len(app.llm.history.splitlines()), 2)
-        history = app.llm.history
-        dry_run = app.chat("Test query 2", dry_run=True)
-        self.assertIn("History:", dry_run)
-        self.assertEqual(history, app.llm.history)
-        self.assertEqual(len(app.llm.history.splitlines()), 2)
+        with patch.object(ECChatMemory, "get_recent_memories") as mock_memory:
+            mock_message = ChatMessage()
+            mock_message.add_user_message("Test query 1")
+            mock_message.add_ai_message("Test answer")
+            mock_memory.return_value = [mock_message]
+
+            config = AppConfig(collect_metrics=False)
+            app = App(config=config)
+            first_answer = app.chat("Test query 1")
+            self.assertEqual(first_answer, "Test answer")
+            self.assertEqual(len(app.llm.history), 1)
+            history = app.llm.history
+            dry_run = app.chat("Test query 2", dry_run=True)
+            self.assertIn("History:", dry_run)
+            self.assertEqual(history, app.llm.history)
+            self.assertEqual(len(app.llm.history), 1)
 
     @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
     def test_chat_with_where_in_params(self):
diff --git a/tests/memory/test_chat_memory.py b/tests/memory/test_chat_memory.py
new file mode 100644
index 00000000..9b78d200
--- /dev/null
+++ b/tests/memory/test_chat_memory.py
@@ -0,0 +1,67 @@
+import pytest
+
+from embedchain.memory.base import ECChatMemory
+from embedchain.memory.message import ChatMessage
+
+
+# Fixture for creating an instance of ECChatMemory
+@pytest.fixture
+def chat_memory_instance():
+    return ECChatMemory()
+
+
+def test_add_chat_memory(chat_memory_instance):
+    app_id = "test_app"
+    human_message = "Hello, how are you?"
+    ai_message = "I'm fine, thank you!"
+
+    chat_message = ChatMessage()
+    chat_message.add_user_message(human_message)
+    chat_message.add_ai_message(ai_message)
+
+    chat_memory_instance.add(app_id, chat_message)
+
+    assert chat_memory_instance.count_history_messages(app_id) == 1
+    chat_memory_instance.delete_chat_history(app_id)
+
+
+def test_get_recent_memories(chat_memory_instance):
+    app_id = "test_app"
+
+    for i in range(1, 7):
+        human_message = f"Question {i}"
+        ai_message = f"Answer {i}"
+
+        chat_message = ChatMessage()
+        chat_message.add_user_message(human_message)
+        chat_message.add_ai_message(ai_message)
+
+        chat_memory_instance.add(app_id, chat_message)
+
+    recent_memories = chat_memory_instance.get_recent_memories(app_id, num_rounds=5)
+
+    assert len(recent_memories) == 5
+
+
+def test_delete_chat_history(chat_memory_instance):
+    app_id = "test_app"
+
+    for i in range(1, 6):
+        human_message = f"Question {i}"
+        ai_message = f"Answer {i}"
+
+        chat_message = ChatMessage()
+        chat_message.add_user_message(human_message)
+        chat_message.add_ai_message(ai_message)
+
+        chat_memory_instance.add(app_id, chat_message)
+
+    chat_memory_instance.delete_chat_history(app_id)
+
+    assert chat_memory_instance.count_history_messages(app_id) == 0
+
+
+@pytest.fixture
+def close_connection(chat_memory_instance):
+    yield
+    chat_memory_instance.close_connection()
diff --git a/tests/memory/test_memory_messages.py b/tests/memory/test_memory_messages.py
new file mode 100644
index 00000000..81adc79c
--- /dev/null
+++ b/tests/memory/test_memory_messages.py
@@ -0,0 +1,37 @@
+from embedchain.memory.message import BaseMessage, ChatMessage
+
+
+def test_ec_base_message():
+    content = "Hello, how are you?"
+    creator = "human"
+    metadata = {"key": "value"}
+
+    message = BaseMessage(content=content, creator=creator, metadata=metadata)
+
+    assert message.content == content
+    assert message.creator == creator
+    assert message.metadata == metadata
+    assert message.type is None
+    assert message.is_lc_serializable() is True
+    assert str(message) == f"{creator}: {content}"
+
+
+def test_ec_base_chat_message():
+    human_message_content = "Hello, how are you?"
+    ai_message_content = "I'm fine, thank you!"
+    human_metadata = {"user": "John"}
+    ai_metadata = {"response_time": 0.5}
+
+    chat_message = ChatMessage()
+    chat_message.add_user_message(human_message_content, metadata=human_metadata)
+    chat_message.add_ai_message(ai_message_content, metadata=ai_metadata)
+
+    assert chat_message.human_message.content == human_message_content
+    assert chat_message.human_message.creator == "human"
+    assert chat_message.human_message.metadata == human_metadata
+
+    assert chat_message.ai_message.content == ai_message_content
+    assert chat_message.ai_message.creator == "ai"
+    assert chat_message.ai_message.metadata == ai_metadata
+
+    assert str(chat_message) == f"human: {human_message_content} | ai: {ai_message_content}"
diff --git a/tests/test_utils.py b/tests/test_utils.py
index c0dd2eb1..722cda23 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,4 +1,5 @@
 import yaml
+
 from embedchain.utils import validate_yaml_config
 
 CONFIG_YAMLS = [