Integrate Mem0 (#1462)

Co-authored-by: Deshraj Yadav <deshraj@gatech.edu>
Dev Khant
2024-07-07 00:57:01 +05:30
committed by GitHub
parent bd654e7aac
commit bbe56107fb
11 changed files with 195 additions and 34 deletions


@@ -249,6 +249,9 @@ Alright, let's dive into what each key means in the yaml config above:
- `config` (Optional): The config for initializing the cache. If not provided, sensible default values are used as mentioned below.
- `similarity_threshold` (Float): The threshold for similarity evaluation. Defaults to `0.8`.
- `auto_flush` (Integer): The number of queries after which the cache is flushed. Defaults to `20`.
7. `memory` Section: (Optional)
    - `api_key` (String): The API key for Mem0, obtained from the [Mem0 Platform](https://app.mem0.ai/).
    - `top_k` (Integer): The number of memories to retrieve from Mem0 for each query. Defaults to `10`. (A short example follows the note below.)
<Note>
If you provide a cache section, the app will automatically configure and use a cache to store the results of the language model. This is useful if you want to speed up the response time and save inference cost of your app.
</Note>
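For illustration, here is a minimal sketch of the `memory` section using the dict form of the config (the same keys apply in the yaml form above; values are placeholders):

```python
from embedchain import App

# Minimal sketch: only the `memory` keys documented above are set here;
# the other sections (app, llm, vectordb, ...) keep their defaults.
config = {
    "memory": {
        "api_key": "m0-xxx",  # Mem0 Platform API key
        "top_k": 5,           # number of memories to retrieve per query (defaults to 10)
    }
}
app = App.from_config(config=config)
```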


@@ -144,3 +144,28 @@ app.add("https://www.forbes.com/profile/elon-musk")
query_config = BaseLlmConfig(number_documents=5)
app.chat("What is the net worth of Elon Musk?", config=query_config)
```
### With Mem0 to store chat history
Mem0 is a long-term memory layer for LLMs that enables personalization across the GenAI stack. It lets LLMs remember past interactions and return more personalized responses.
Follow these steps to use Mem0 to enable memory for personalization in your apps:
- Install the [`mem0`](https://docs.mem0.ai/) package (published on PyPI as `memzero`) using `pip install memzero`.
- Get your API key from the [Mem0 Platform](https://app.mem0.ai/).
- Provide the API key in the config under the `memory` key; see [Configurations](docs/api-reference/advanced/configuration.mdx). Memories persist across chat sessions, as shown after the code block below.
```python with mem0
from embedchain import App
config = {
"memory": {
"api_key": "m0-xxx",
"top_k": 5
}
}
app = App.from_config(config=config)
app.add("https://www.forbes.com/profile/elon-musk")
app.chat("What is the net worth of Elon Musk?")
```
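Memories are stored against the app id and the chat session, so later chats can draw on what earlier ones stored. A minimal sketch, assuming the standard `session_id` parameter of `chat` (the session name is a placeholder):

```python
# Queries from earlier chats in the same session are added to Mem0 and
# searched again on later chats, so answers become more personalized.
app.chat("I'm mostly interested in Tesla-related news.", session_id="user-42")
app.chat("What is the net worth of Elon Musk?", session_id="user-42")
```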


@@ -9,19 +9,24 @@ import requests
import yaml
from tqdm import tqdm
from embedchain.cache import (Config, ExactMatchEvaluation,
SearchDistanceEvaluation, cache,
gptcache_data_manager, gptcache_pre_function)
from mem0 import Mem0
from embedchain.cache import (
Config,
ExactMatchEvaluation,
SearchDistanceEvaluation,
cache,
gptcache_data_manager,
gptcache_pre_function,
)
from embedchain.client import Client
from embedchain.config import AppConfig, CacheConfig, ChunkerConfig
from embedchain.config import AppConfig, CacheConfig, ChunkerConfig, Mem0Config
from embedchain.core.db.database import get_session, init_db, setup_engine
from embedchain.core.db.models import DataSource
from embedchain.embedchain import EmbedChain
from embedchain.embedder.base import BaseEmbedder
from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.evaluation.base import BaseMetric
from embedchain.evaluation.metrics import (AnswerRelevance, ContextRelevance,
Groundedness)
from embedchain.evaluation.metrics import AnswerRelevance, ContextRelevance, Groundedness
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
@@ -55,6 +60,7 @@ class App(EmbedChain):
auto_deploy: bool = False,
chunker: ChunkerConfig = None,
cache_config: CacheConfig = None,
memory_config: Mem0Config = None,
log_level: int = logging.WARN,
):
"""
@@ -95,6 +101,7 @@ class App(EmbedChain):
self.id = None
self.chunker = ChunkerConfig(**chunker) if chunker else None
self.cache_config = cache_config
self.memory_config = memory_config
self.config = config or AppConfig()
self.name = self.config.name
@@ -123,6 +130,11 @@ class App(EmbedChain):
if self.cache_config is not None:
self._init_cache()
# If memory_config is provided, initialize the memory ...
self.mem0_client = None
if self.memory_config is not None:
self.mem0_client = Mem0(api_key=self.memory_config.api_key)
# Send anonymous telemetry
self._telemetry_props = {"class": self.__class__.__name__}
self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
@@ -365,11 +377,13 @@ class App(EmbedChain):
app_config_data = config_data.get("app", {}).get("config", {})
vector_db_config_data = config_data.get("vectordb", {})
embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
memory_config_data = config_data.get("memory", {})
llm_config_data = config_data.get("llm", {})
chunker_config_data = config_data.get("chunker", {})
cache_config_data = config_data.get("cache", None)
app_config = AppConfig(**app_config_data)
memory_config = Mem0Config(**memory_config_data) if memory_config_data else None
vector_db_provider = vector_db_config_data.get("provider", "chroma")
vector_db = VectorDBFactory.create(vector_db_provider, vector_db_config_data.get("config", {}))
@@ -403,6 +417,7 @@ class App(EmbedChain):
auto_deploy=auto_deploy,
chunker=chunker_config_data,
cache_config=cache_config,
memory_config=memory_config,
)
def _eval(self, dataset: list[EvalData], metric: Union[BaseMetric, str]):


@@ -12,3 +12,4 @@ from .vectordb.chroma import ChromaDbConfig
from .vectordb.elasticsearch import ElasticsearchDBConfig
from .vectordb.opensearch import OpenSearchDBConfig
from .vectordb.zilliz import ZillizDBConfig
from .mem0_config import Mem0Config


@@ -50,6 +50,35 @@ Query: $query
Answer:
""" # noqa:E501
DEFAULT_PROMPT_WITH_MEM0_MEMORY = """
You are a Q&A expert system. Your responses must always be rooted in the context provided for each query. You are also provided with the user's conversation history and memories. Make sure to use relevant context from the conversation history and memories as needed.
Here are some guidelines to follow:
1. Refrain from explicitly mentioning the context provided in your response.
2. Take into consideration the conversation history and memories provided.
3. The context should silently guide your answers without being directly acknowledged.
4. Do not use phrases such as 'According to the context provided', 'Based on the context, ...' etc.
Context information:
----------------------
$context
----------------------
Conversation history:
----------------------
$history
----------------------
Memories/Preferences:
----------------------
$memories
----------------------
Query: $query
Answer:
""" # noqa:E501
DOCS_SITE_DEFAULT_PROMPT = """
You are an expert AI assistant for a developer support product. Your responses must always be rooted in the context provided for each query. Wherever possible, give a complete code snippet. Don't make up any code snippet on your own.
@@ -70,6 +99,7 @@ Answer:
DEFAULT_PROMPT_TEMPLATE = Template(DEFAULT_PROMPT)
DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE = Template(DEFAULT_PROMPT_WITH_HISTORY)
DEFAULT_PROMPT_WITH_MEM0_MEMORY_TEMPLATE = Template(DEFAULT_PROMPT_WITH_MEM0_MEMORY)
DOCS_SITE_PROMPT_TEMPLATE = Template(DOCS_SITE_DEFAULT_PROMPT)
query_re = re.compile(r"\$\{*query\}*")
context_re = re.compile(r"\$\{*context\}*")


@@ -0,0 +1,21 @@
from typing import Any, Optional
from embedchain.config.base_config import BaseConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class Mem0Config(BaseConfig):
def __init__(self, api_key: str, top_k: Optional[int] = 10):
self.api_key = api_key
self.top_k = top_k
@staticmethod
def from_config(config: Optional[dict[str, Any]]):
if config is None:
return Mem0Config(api_key="")
else:
return Mem0Config(
api_key=config.get("api_key", ""),
top_k=config.get("top_k", 10),
)
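A short usage sketch of this helper, mirroring the `memory` section that `App.from_config` consumes (key names as in the diff, values are placeholders):

```python
# Build a Mem0Config from the `memory` section of a config dict.
memory_config = Mem0Config.from_config({"api_key": "m0-xxx", "top_k": 5})
print(memory_config.api_key, memory_config.top_k)  # -> m0-xxx 5
```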


@@ -52,6 +52,8 @@ class EmbedChain(JSONSerializable):
"""
self.config = config
self.cache_config = None
self.memory_config = None
self.mem0_client = None
# Llm
self.llm = llm
# Database has support for config assignment for backwards compatibility
@@ -595,6 +597,12 @@ class EmbedChain(JSONSerializable):
else:
contexts_data_for_llm_query = contexts
memories = None
if self.mem0_client:
memories = self.mem0_client.search(
query=input_query, agent_id=self.config.id, session_id=session_id, limit=self.memory_config.top_k
)
# Update the history beforehand so that we can handle multiple chat sessions in the same python session
self.llm.update_history(app_id=self.config.id, session_id=session_id)
@@ -615,13 +623,27 @@ class EmbedChain(JSONSerializable):
logger.debug("Cache disabled. Running chat without cache.")
if self.llm.config.token_usage:
answer, token_info = self.llm.query(
input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
input_query=input_query,
contexts=contexts_data_for_llm_query,
config=config,
dry_run=dry_run,
memories=memories,
)
else:
answer = self.llm.query(
input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
input_query=input_query,
contexts=contexts_data_for_llm_query,
config=config,
dry_run=dry_run,
memories=memories,
)
# Add to Mem0 memory if enabled
# TODO: Might need to prepend with some text like:
# "Remember user preferences from following user query: {input_query}"
if self.mem0_client:
self.mem0_client.add(data=input_query, agent_id=self.config.id, session_id=session_id)
# add conversation in memory
self.llm.add_history(self.config.id, input_query, answer, session_id=session_id)
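Taken together, the hunk above wires Mem0 into the chat flow: memories are searched before the LLM call and the user query is stored after it. A hedged sketch of that retrieve-then-store cycle, using only the client calls that appear in this diff (ids and query text are placeholders):

```python
from mem0 import Mem0

client = Mem0(api_key="m0-xxx")

# 1. Retrieve the top-k memories relevant to the incoming query.
memories = client.search(
    query="What is the net worth of Elon Musk?",
    agent_id="my-app",
    session_id="default",
    limit=10,
)

# 2. These memories are passed to llm.query(..., memories=memories) and rendered
#    into DEFAULT_PROMPT_WITH_MEM0_MEMORY_TEMPLATE.

# 3. After the answer is generated, the raw user query is stored for later chats.
client.add(
    data="What is the net worth of Elon Musk?",
    agent_id="my-app",
    session_id="default",
)
```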


@@ -5,9 +5,12 @@ from typing import Any, Optional
from langchain.schema import BaseMessage as LCBaseMessage
from embedchain.config import BaseLlmConfig
from embedchain.config.llm.base import (DEFAULT_PROMPT,
from embedchain.config.llm.base import (
DEFAULT_PROMPT,
DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
DOCS_SITE_PROMPT_TEMPLATE)
DEFAULT_PROMPT_WITH_MEM0_MEMORY_TEMPLATE,
DOCS_SITE_PROMPT_TEMPLATE,
)
from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.memory.base import ChatHistory
from embedchain.memory.message import ChatMessage
@@ -74,6 +77,16 @@ class BaseLlm(JSONSerializable):
"""
return "\n".join(self.history)
def _format_memories(self, memories: list[dict]) -> str:
"""Format memories to be used in prompt
:param memories: Memories to format
:type memories: list[dict]
:return: Formatted memories
:rtype: str
"""
return "\n".join([memory["text"] for memory in memories])
def generate_prompt(self, input_query: str, contexts: list[str], **kwargs: dict[str, Any]) -> str:
"""
Generates a prompt based on the given query and context, ready to be
@@ -88,6 +101,7 @@ class BaseLlm(JSONSerializable):
"""
context_string = " | ".join(contexts)
web_search_result = kwargs.get("web_search_result", "")
memories = kwargs.get("memories", None)
if web_search_result:
context_string = self._append_search_and_context(context_string, web_search_result)
@@ -103,6 +117,15 @@ class BaseLlm(JSONSerializable):
not self.config._validate_prompt_history(self.config.prompt)
and self.config.prompt.template == DEFAULT_PROMPT
):
if memories:
# swap in the Mem0 memory template
prompt = DEFAULT_PROMPT_WITH_MEM0_MEMORY_TEMPLATE.substitute(
context=context_string,
query=input_query,
history=self._format_history(),
memories=self._format_memories(memories),
)
else:
# swap in the template with history
prompt = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE.substitute(
context=context_string, query=input_query, history=self._format_history()
@@ -180,7 +203,7 @@ class BaseLlm(JSONSerializable):
if token_info:
logger.info(f"Token Info: {token_info}")
def query(self, input_query: str, contexts: list[str], config: BaseLlmConfig = None, dry_run=False):
def query(self, input_query: str, contexts: list[str], config: BaseLlmConfig = None, dry_run=False, memories=None):
"""
Queries the vector database based on the given input query.
Gets relevant doc based on the query and then passes it to an
@@ -216,6 +239,7 @@ class BaseLlm(JSONSerializable):
k = {}
if self.config.online:
k["web_search_result"] = self.access_search_and_get_results(input_query)
k["memories"] = memories
prompt = self.generate_prompt(input_query, contexts, **k)
logger.info(f"Prompt: {prompt}")
if dry_run:


@@ -520,6 +520,10 @@ def validate_config(config_data):
Optional("auto_flush"): int,
},
},
Optional("memory"): {
"api_key": str,
Optional("top_k"): int,
},
}
)
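With this schema change, a config carrying the new `memory` section should pass validation; a minimal sketch, assuming the import path of `validate_config` and that the other top-level sections remain optional:

```python
from embedchain.utils.main import validate_config  # module path assumed

# `api_key` is required inside `memory`; `top_k` is optional.
validate_config({"memory": {"api_key": "m0-xxx", "top_k": 5}})
```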

poetry.lock generated

@@ -385,17 +385,17 @@ files = [
[[package]]
name = "boto3"
version = "1.34.139"
version = "1.34.140"
description = "The AWS SDK for Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "boto3-1.34.139-py3-none-any.whl", hash = "sha256:98b2a12bcb30e679fa9f60fc74145a39db5ec2ca7b7c763f42896e3bd9b3a38d"},
{file = "boto3-1.34.139.tar.gz", hash = "sha256:32b99f0d76ec81fdca287ace2c9744a2eb8b92cb62bf4d26d52a4f516b63a6bf"},
{file = "boto3-1.34.140-py3-none-any.whl", hash = "sha256:23ca8d8f7a30c3bbd989808056b5fc5d68ff5121c02c722c6167b6b1bb7f8726"},
{file = "boto3-1.34.140.tar.gz", hash = "sha256:578bbd5e356005719b6b610d03edff7ea1b0824d078afe62d3fb8bea72f83a87"},
]
[package.dependencies]
botocore = ">=1.34.139,<1.35.0"
botocore = ">=1.34.140,<1.35.0"
jmespath = ">=0.7.1,<2.0.0"
s3transfer = ">=0.10.0,<0.11.0"
@@ -404,13 +404,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
[[package]]
name = "botocore"
version = "1.34.139"
version = "1.34.140"
description = "Low-level, data-driven core of boto 3."
optional = false
python-versions = ">=3.8"
files = [
{file = "botocore-1.34.139-py3-none-any.whl", hash = "sha256:dd1e085d4caa2a4c1b7d83e3bc51416111c8238a35d498e9d3b04f3b63b086ba"},
{file = "botocore-1.34.139.tar.gz", hash = "sha256:df023d8cf8999d574214dad4645cb90f9d2ccd1494f6ee2b57b1ab7522f6be77"},
{file = "botocore-1.34.140-py3-none-any.whl", hash = "sha256:43940d3a67d946ba3301631ba4078476a75f1015d4fb0fb0272d0b754b2cf9de"},
{file = "botocore-1.34.140.tar.gz", hash = "sha256:86302b2226c743b9eec7915a4c6cfaffd338ae03989cd9ee181078ef39d1ab39"},
]
[package.dependencies]
@@ -882,13 +882,13 @@ all = ["pycocotools (==2.0.6)"]
[[package]]
name = "clarifai-grpc"
version = "10.5.4"
version = "10.6.1"
description = "Clarifai gRPC API Client"
optional = true
python-versions = ">=3.8"
files = [
{file = "clarifai_grpc-10.5.4-py3-none-any.whl", hash = "sha256:ae4c4d8985fdd2bf326cec27ee834571e44d0e989fb12686dd681f9b553ae218"},
{file = "clarifai_grpc-10.5.4.tar.gz", hash = "sha256:c67ce0dde186e8bab0d42a9923d28ddb4a05017b826c8e52ac7a86ec6df5f12a"},
{file = "clarifai_grpc-10.6.1-py3-none-any.whl", hash = "sha256:7f07c262f46042995b11af10cdd552718c4487e955db1b3f1253fcb0c2ab1ce1"},
{file = "clarifai_grpc-10.6.1.tar.gz", hash = "sha256:f692e3d6a051a1228ca371c3a9dc705cc9a61334eecc454d056f7af0b6f4dbad"},
]
[package.dependencies]
@@ -1280,18 +1280,17 @@ stone = ">=2"
[[package]]
name = "duckduckgo-search"
version = "6.1.8"
version = "6.1.9"
description = "Search for words, documents, images, news, maps and text translation using the DuckDuckGo.com search engine."
optional = true
python-versions = ">=3.8"
files = [
{file = "duckduckgo_search-6.1.8-py3-none-any.whl", hash = "sha256:fb67f6ae8df4f291462010018342aeaaa4f259b54667dc48de37c31d8ecab027"},
{file = "duckduckgo_search-6.1.8.tar.gz", hash = "sha256:e38fa695f598b0b2bd779fffde1fef2eeff1d6a3f218772e50f8b4f381f63279"},
{file = "duckduckgo_search-6.1.9-py3-none-any.whl", hash = "sha256:a208babf87b971290b1afed9908bc5ab6ac6c1738b90b48ad613267f7630cb77"},
{file = "duckduckgo_search-6.1.9.tar.gz", hash = "sha256:0d7d746e003d6b3bcd0d0dc11927c9a69b6fa271f3b3f65df6f01ea4d9d2689d"},
]
[package.dependencies]
click = ">=8.1.7"
orjson = ">=3.10.6"
pyreqwest-impersonate = ">=0.4.9"
[package.extras]
@@ -3337,6 +3336,22 @@ files = [
{file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
]
[[package]]
name = "memzero"
version = "0.0.7"
description = "Long-term memory for AI Agents"
optional = false
python-versions = "<4.0,>=3.9"
files = [
{file = "memzero-0.0.7-py3-none-any.whl", hash = "sha256:65f6da88d46263dbc05621fcd01bd09616d0e7f082d55ed9899dc2152491ffd2"},
{file = "memzero-0.0.7.tar.gz", hash = "sha256:0c1f413d8ee0ade955fe9f8b8f5aff2cf58bc94869537aca62139db3d9f50725"},
]
[package.dependencies]
httpx = ">=0.27.0,<0.28.0"
posthog = ">=3.5.0,<4.0.0"
pydantic = ">=2.7.3,<3.0.0"
[[package]]
name = "milvus-lite"
version = "2.4.8"
@@ -6603,13 +6618,13 @@ files = [
[[package]]
name = "tenacity"
version = "8.4.2"
version = "8.5.0"
description = "Retry code until it succeeds"
optional = false
python-versions = ">=3.8"
files = [
{file = "tenacity-8.4.2-py3-none-any.whl", hash = "sha256:9e6f7cf7da729125c7437222f8a522279751cdfbe6b67bfe64f75d3a348661b2"},
{file = "tenacity-8.4.2.tar.gz", hash = "sha256:cd80a53a79336edba8489e767f729e4f391c896956b57140b5d7511a64bbd3ef"},
{file = "tenacity-8.5.0-py3-none-any.whl", hash = "sha256:b594c2a5945830c267ce6b79a166228323ed52718f30302c1359836112346687"},
{file = "tenacity-8.5.0.tar.gz", hash = "sha256:8bc6c0c8a09b31e6cad13c47afbed1a567518250a9a171418582ed8d9c20ca78"},
]
[package.extras]
@@ -7914,4 +7929,4 @@ youtube = ["youtube-transcript-api", "yt_dlp"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<=3.13"
content-hash = "afc88f00bafd2b76a954c758a0556cba6d3854e98c444bc5e720319bf472caa8"
content-hash = "22f5fb8700344234abb1d98a097a55c35162d2475010f3c0c3a97e37dc72c545"


@@ -103,6 +103,7 @@ beautifulsoup4 = "^4.12.2"
pypdf = "^4.0.1"
gptcache = "^0.1.43"
pysbd = "^0.3.4"
memzero = "^0.0.7"
tiktoken = { version = "^0.7.0", optional = true }
youtube-transcript-api = { version = "^0.6.1", optional = true }
pytube = { version = "^15.0.0", optional = true }