Make api_version=v1.1 default and version bump -> 0.1.59 (#2278)

Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
This commit is contained in:
Dev Khant
2025-03-01 11:36:20 +05:30
committed by GitHub
parent 5606c3ffb8
commit 4318663697
8 changed files with 45 additions and 61 deletions

View File

@@ -28,8 +28,7 @@ config = {
"password": "xxx" "password": "xxx"
}, },
"custom_prompt": "Please only extract entities containing sports related relationships and nothing else.", "custom_prompt": "Please only extract entities containing sports related relationships and nothing else.",
}, }
"version": "v1.1"
} }
m = Memory.from_config(config_dict=config) m = Memory.from_config(config_dict=config)

View File

@@ -39,7 +39,6 @@ allowfullscreen
To initialize Graph Memory you'll need to set up your configuration with graph store providers. To initialize Graph Memory you'll need to set up your configuration with graph store providers.
Currently, we support Neo4j as a graph store provider. You can setup [Neo4j](https://neo4j.com/) locally or use the hosted [Neo4j AuraDB](https://neo4j.com/product/auradb/). Currently, we support Neo4j as a graph store provider. You can setup [Neo4j](https://neo4j.com/) locally or use the hosted [Neo4j AuraDB](https://neo4j.com/product/auradb/).
Moreover, you also need to set the version to `v1.1` (*prior versions are not supported*).
<Note>If you are using Neo4j locally, then you need to install [APOC plugins](https://neo4j.com/labs/apoc/4.1/installation/).</Note> <Note>If you are using Neo4j locally, then you need to install [APOC plugins](https://neo4j.com/labs/apoc/4.1/installation/).</Note>
@@ -65,8 +64,7 @@ config = {
"username": "neo4j", "username": "neo4j",
"password": "xxx" "password": "xxx"
} }
}, }
"version": "v1.1"
} }
m = Memory.from_config(config_dict=config) m = Memory.from_config(config_dict=config)
@@ -98,8 +96,7 @@ config = {
"temperature": 0.0, "temperature": 0.0,
} }
} }
}, }
"version": "v1.1"
} }
m = Memory.from_config(config_dict=config) m = Memory.from_config(config_dict=config)

View File

@@ -71,8 +71,7 @@ config = {
"username": "neo4j", "username": "neo4j",
"password": "---" "password": "---"
} }
}, }
"version": "v1.1"
} }
m = Memory.from_config(config_dict=config) m = Memory.from_config(config_dict=config)
@@ -100,10 +99,6 @@ result = m.add("Likes to play cricket on weekends", user_id="alice", metadata={"
{ {
"results": [ "results": [
{"id": "bf4d4092-cf91-4181-bfeb-b6fa2ed3061b", "memory": "Likes to play cricket on weekends", "event": "ADD"} {"id": "bf4d4092-cf91-4181-bfeb-b6fa2ed3061b", "memory": "Likes to play cricket on weekends", "event": "ADD"}
],
"relations": [
{"source": "alice", "relationship": "likes_to_play", "target": "cricket"},
{"source": "alice", "relationship": "plays_on", "target": "weekends"}
] ]
} }
``` ```
@@ -129,10 +124,6 @@ all_memories = m.get_all(user_id="alice")
"updated_at": None, "updated_at": None,
"user_id": "alice" "user_id": "alice"
} }
],
"relations": [
{"source": "alice", "relationship": "likes_to_play", "target": "cricket"},
{"source": "alice", "relationship": "plays_on", "target": "weekends"}
] ]
} }
``` ```
@@ -180,10 +171,6 @@ related_memories = m.search(query="What are Alice's hobbies?", user_id="alice")
"updated_at": None, "updated_at": None,
"user_id": "alice" "user_id": "alice"
} }
],
"relations": [
{"source": "alice", "relationship": "plays_on", "target": "weekends"},
{"source": "alice", "relationship": "likes_to_play", "target": "cricket"}
] ]
} }
``` ```
@@ -303,7 +290,7 @@ Mem0 offers extensive configuration options to customize its behavior according
| Parameter | Description | Default | | Parameter | Description | Default |
|------------------|--------------------------------------|----------------------------| |------------------|--------------------------------------|----------------------------|
| `history_db_path` | Path to the history database | "{mem0_dir}/history.db" | | `history_db_path` | Path to the history database | "{mem0_dir}/history.db" |
| `version` | API version | "v1.0" | | `version` | API version | "v1.1" |
| `custom_prompt` | Custom prompt for memory processing | None | | `custom_prompt` | Custom prompt for memory processing | None |
</Accordion> </Accordion>

View File

@@ -46,7 +46,7 @@ class MemoryConfig(BaseModel):
) )
version: str = Field( version: str = Field(
description="The version of the API", description="The version of the API",
default="v1.0", default="v1.1",
) )
custom_prompt: Optional[str] = Field( custom_prompt: Optional[str] = Field(
description="Custom prompt for the memory", description="Custom prompt for the memory",

View File

@@ -46,7 +46,7 @@ class Memory(MemoryBase):
self.enable_graph = False self.enable_graph = False
if self.api_version == "v1.1" and self.config.graph_store.config: if self.config.graph_store.config:
from mem0.memory.graph_memory import MemoryGraph from mem0.memory.graph_memory import MemoryGraph
self.graph = MemoryGraph(self.config) self.graph = MemoryGraph(self.config)
@@ -126,12 +126,7 @@ class Memory(MemoryBase):
vector_store_result = future1.result() vector_store_result = future1.result()
graph_result = future2.result() graph_result = future2.result()
if self.api_version == "v1.1": if self.api_version == "v1.0":
return {
"results": vector_store_result,
"relations": graph_result,
}
else:
warnings.warn( warnings.warn(
"The current add API output format is deprecated. " "The current add API output format is deprecated. "
"To use the latest format, set `api_version='v1.1'`. " "To use the latest format, set `api_version='v1.1'`. "
@@ -141,6 +136,14 @@ class Memory(MemoryBase):
) )
return vector_store_result return vector_store_result
if self.enable_graph:
return {
"results": vector_store_result,
"relations": graph_result,
}
return {"results": vector_store_result}
def _add_to_vector_store(self, messages, metadata, filters): def _add_to_vector_store(self, messages, metadata, filters):
parsed_messages = parse_messages(messages) parsed_messages = parse_messages(messages)
@@ -252,7 +255,7 @@ class Memory(MemoryBase):
def _add_to_graph(self, messages, filters): def _add_to_graph(self, messages, filters):
added_entities = [] added_entities = []
if self.api_version == "v1.1" and self.enable_graph: if self.enable_graph:
if filters.get("user_id") is None: if filters.get("user_id") is None:
filters["user_id"] = "user" filters["user_id"] = "user"
@@ -324,11 +327,7 @@ class Memory(MemoryBase):
with concurrent.futures.ThreadPoolExecutor() as executor: with concurrent.futures.ThreadPoolExecutor() as executor:
future_memories = executor.submit(self._get_all_from_vector_store, filters, limit) future_memories = executor.submit(self._get_all_from_vector_store, filters, limit)
future_graph_entities = ( future_graph_entities = executor.submit(self.graph.get_all, filters, limit) if self.enable_graph else None
executor.submit(self.graph.get_all, filters, limit)
if self.api_version == "v1.1" and self.enable_graph
else None
)
concurrent.futures.wait( concurrent.futures.wait(
[future_memories, future_graph_entities] if future_graph_entities else [future_memories] [future_memories, future_graph_entities] if future_graph_entities else [future_memories]
@@ -337,12 +336,10 @@ class Memory(MemoryBase):
all_memories = future_memories.result() all_memories = future_memories.result()
graph_entities = future_graph_entities.result() if future_graph_entities else None graph_entities = future_graph_entities.result() if future_graph_entities else None
if self.api_version == "v1.1":
if self.enable_graph: if self.enable_graph:
return {"results": all_memories, "relations": graph_entities} return {"results": all_memories, "relations": graph_entities}
else:
return {"results": all_memories} if self.api_version == "v1.0":
else:
warnings.warn( warnings.warn(
"The current get_all API output format is deprecated. " "The current get_all API output format is deprecated. "
"To use the latest format, set `api_version='v1.1'`. " "To use the latest format, set `api_version='v1.1'`. "
@@ -351,6 +348,8 @@ class Memory(MemoryBase):
stacklevel=2, stacklevel=2,
) )
return all_memories return all_memories
else:
return {"results": all_memories}
def _get_all_from_vector_store(self, filters, limit): def _get_all_from_vector_store(self, filters, limit):
memories = self.vector_store.list(filters=filters, limit=limit) memories = self.vector_store.list(filters=filters, limit=limit)
@@ -419,9 +418,7 @@ class Memory(MemoryBase):
with concurrent.futures.ThreadPoolExecutor() as executor: with concurrent.futures.ThreadPoolExecutor() as executor:
future_memories = executor.submit(self._search_vector_store, query, filters, limit) future_memories = executor.submit(self._search_vector_store, query, filters, limit)
future_graph_entities = ( future_graph_entities = (
executor.submit(self.graph.search, query, filters, limit) executor.submit(self.graph.search, query, filters, limit) if self.enable_graph else None
if self.api_version == "v1.1" and self.enable_graph
else None
) )
concurrent.futures.wait( concurrent.futures.wait(
@@ -431,20 +428,20 @@ class Memory(MemoryBase):
original_memories = future_memories.result() original_memories = future_memories.result()
graph_entities = future_graph_entities.result() if future_graph_entities else None graph_entities = future_graph_entities.result() if future_graph_entities else None
if self.api_version == "v1.1":
if self.enable_graph: if self.enable_graph:
return {"results": original_memories, "relations": graph_entities} return {"results": original_memories, "relations": graph_entities}
else:
return {"results": original_memories} if self.api_version == "v1.0":
else:
warnings.warn( warnings.warn(
"The current search API output format is deprecated. " "The current get_all API output format is deprecated. "
"To use the latest format, set `api_version='v1.1'`. " "To use the latest format, set `api_version='v1.1'`. "
"The current format will be removed in mem0ai 1.1.0 and later versions.", "The current format will be removed in mem0ai 1.1.0 and later versions.",
category=DeprecationWarning, category=DeprecationWarning,
stacklevel=2, stacklevel=2,
) )
return original_memories return original_memories
else:
return {"results": original_memories}
def _search_vector_store(self, query, filters, limit): def _search_vector_store(self, query, filters, limit):
embeddings = self.embedding_model.embed(query, "search") embeddings = self.embedding_model.embed(query, "search")
@@ -540,7 +537,7 @@ class Memory(MemoryBase):
logger.info(f"Deleted {len(memories)} memories") logger.info(f"Deleted {len(memories)} memories")
if self.api_version == "v1.1" and self.enable_graph: if self.enable_graph:
self.graph.delete_all(filters) self.graph.delete_all(filters)
return {"message": "Memories deleted successfully!"} return {"message": "Memories deleted successfully!"}

View File

@@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "mem0ai" name = "mem0ai"
version = "0.1.58" version = "0.1.59"
description = "Long-term memory for AI Agents" description = "Long-term memory for AI Agents"
authors = ["Mem0 <founders@mem0.ai>"] authors = ["Mem0 <founders@mem0.ai>"]
exclude = [ exclude = [

View File

@@ -41,10 +41,14 @@ def test_add(memory_instance, version, enable_graph):
result = memory_instance.add(messages=[{"role": "user", "content": "Test message"}], user_id="test_user") result = memory_instance.add(messages=[{"role": "user", "content": "Test message"}], user_id="test_user")
if enable_graph:
assert "results" in result assert "results" in result
assert result["results"] == [{"memory": "Test memory", "event": "ADD"}] assert result["results"] == [{"memory": "Test memory", "event": "ADD"}]
assert "relations" in result assert "relations" in result
assert result["relations"] == [] assert result["relations"] == []
else:
assert "results" in result
assert result["results"] == [{"memory": "Test memory", "event": "ADD"}]
memory_instance._add_to_vector_store.assert_called_once_with( memory_instance._add_to_vector_store.assert_called_once_with(
[{"role": "user", "content": "Test message"}], {"user_id": "test_user"}, {"user_id": "test_user"} [{"role": "user", "content": "Test message"}], {"user_id": "test_user"}, {"user_id": "test_user"}

View File

@@ -55,9 +55,9 @@ class TestQdrant(unittest.TestCase):
results = self.qdrant.search(query=query_vector, limit=1) results = self.qdrant.search(query=query_vector, limit=1)
self.client_mock.search.assert_called_once_with( self.client_mock.query_points.assert_called_once_with(
collection_name="test_collection", collection_name="test_collection",
query_vector=query_vector, query=query_vector,
query_filter=None, query_filter=None,
limit=1, limit=1,
) )