Add infer param and version bump (#2389)
Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
This commit is contained in:
@@ -85,14 +85,18 @@ m = Memory.from_config(config_dict=config)
|
|||||||
|
|
||||||
<CodeGroup>
|
<CodeGroup>
|
||||||
```python Code
|
```python Code
|
||||||
const messages = [
|
messages = [
|
||||||
{"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
|
{"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
|
||||||
{"role": "assistant", "content": "How about a thriller movies? They can be quite engaging."},
|
{"role": "assistant", "content": "How about a thriller movies? They can be quite engaging."},
|
||||||
{"role": "user", "content": "I'm not a big fan of thriller movies but I love sci-fi movies."},
|
{"role": "user", "content": "I'm not a big fan of thriller movies but I love sci-fi movies."},
|
||||||
{"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."}
|
{"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."}
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Store inferred memories (default behavior)
|
||||||
result = m.add(messages, user_id="alice", metadata={"category": "movie_recommendations"})
|
result = m.add(messages, user_id="alice", metadata={"category": "movie_recommendations"})
|
||||||
|
|
||||||
|
# Store raw messages without inference
|
||||||
|
# result = m.add(messages, user_id="alice", metadata={"category": "movie_recommendations"}, infer=False)
|
||||||
```
|
```
|
||||||
|
|
||||||
```json Output
|
```json Output
|
||||||
|
|||||||
@@ -16,8 +16,12 @@ from mem0.memory.base import MemoryBase
|
|||||||
from mem0.memory.setup import setup_config
|
from mem0.memory.setup import setup_config
|
||||||
from mem0.memory.storage import SQLiteManager
|
from mem0.memory.storage import SQLiteManager
|
||||||
from mem0.memory.telemetry import capture_event
|
from mem0.memory.telemetry import capture_event
|
||||||
from mem0.memory.utils import (get_fact_retrieval_messages, parse_messages,
|
from mem0.memory.utils import (
|
||||||
parse_vision_messages, remove_code_blocks)
|
get_fact_retrieval_messages,
|
||||||
|
parse_messages,
|
||||||
|
parse_vision_messages,
|
||||||
|
remove_code_blocks,
|
||||||
|
)
|
||||||
from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory
|
from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory
|
||||||
|
|
||||||
# Setup user config
|
# Setup user config
|
||||||
@@ -82,6 +86,7 @@ class Memory(MemoryBase):
|
|||||||
run_id=None,
|
run_id=None,
|
||||||
metadata=None,
|
metadata=None,
|
||||||
filters=None,
|
filters=None,
|
||||||
|
infer=True,
|
||||||
prompt=None,
|
prompt=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -94,6 +99,7 @@ class Memory(MemoryBase):
|
|||||||
run_id (str, optional): ID of the run creating the memory. Defaults to None.
|
run_id (str, optional): ID of the run creating the memory. Defaults to None.
|
||||||
metadata (dict, optional): Metadata to store with the memory. Defaults to None.
|
metadata (dict, optional): Metadata to store with the memory. Defaults to None.
|
||||||
filters (dict, optional): Filters to apply to the search. Defaults to None.
|
filters (dict, optional): Filters to apply to the search. Defaults to None.
|
||||||
|
infer (bool, optional): Whether to infer the memories. Defaults to True.
|
||||||
prompt (str, optional): Prompt to use for memory deduction. Defaults to None.
|
prompt (str, optional): Prompt to use for memory deduction. Defaults to None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -132,7 +138,7 @@ class Memory(MemoryBase):
|
|||||||
messages = parse_vision_messages(messages)
|
messages = parse_vision_messages(messages)
|
||||||
|
|
||||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
future1 = executor.submit(self._add_to_vector_store, messages, metadata, filters)
|
future1 = executor.submit(self._add_to_vector_store, messages, metadata, filters, infer)
|
||||||
future2 = executor.submit(self._add_to_graph, messages, filters)
|
future2 = executor.submit(self._add_to_graph, messages, filters)
|
||||||
|
|
||||||
concurrent.futures.wait([future1, future2])
|
concurrent.futures.wait([future1, future2])
|
||||||
@@ -158,7 +164,16 @@ class Memory(MemoryBase):
|
|||||||
|
|
||||||
return {"results": vector_store_result}
|
return {"results": vector_store_result}
|
||||||
|
|
||||||
def _add_to_vector_store(self, messages, metadata, filters):
|
def _add_to_vector_store(self, messages, metadata, filters, infer):
|
||||||
|
if not infer:
|
||||||
|
returned_memories = []
|
||||||
|
for message in messages:
|
||||||
|
if message["role"] != "system":
|
||||||
|
message_embeddings = self.embedding_model.embed(message["content"], "add")
|
||||||
|
memory_id = self._create_memory(message["content"], message_embeddings, metadata)
|
||||||
|
returned_memories.append({"id": memory_id, "memory": message["content"], "event": "ADD"})
|
||||||
|
return returned_memories
|
||||||
|
|
||||||
parsed_messages = parse_messages(messages)
|
parsed_messages = parse_messages(messages)
|
||||||
|
|
||||||
if self.custom_prompt:
|
if self.custom_prompt:
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "mem0ai"
|
name = "mem0ai"
|
||||||
version = "0.1.68"
|
version = "0.1.69"
|
||||||
description = "Long-term memory for AI Agents"
|
description = "Long-term memory for AI Agents"
|
||||||
authors = ["Mem0 <founders@mem0.ai>"]
|
authors = ["Mem0 <founders@mem0.ai>"]
|
||||||
exclude = [
|
exclude = [
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ def test_add(memory_instance, version, enable_graph):
|
|||||||
assert result["results"] == [{"memory": "Test memory", "event": "ADD"}]
|
assert result["results"] == [{"memory": "Test memory", "event": "ADD"}]
|
||||||
|
|
||||||
memory_instance._add_to_vector_store.assert_called_once_with(
|
memory_instance._add_to_vector_store.assert_called_once_with(
|
||||||
[{"role": "user", "content": "Test message"}], {"user_id": "test_user"}, {"user_id": "test_user"}
|
[{"role": "user", "content": "Test message"}], {"user_id": "test_user"}, {"user_id": "test_user"}, True
|
||||||
)
|
)
|
||||||
|
|
||||||
# Remove the conditional assertion for _add_to_graph
|
# Remove the conditional assertion for _add_to_graph
|
||||||
|
|||||||
Reference in New Issue
Block a user