From 431866369749479bcc100ba4491a040fd2986e07 Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Sat, 1 Mar 2025 11:36:20 +0530 Subject: [PATCH] Make api_version=v1.1 default and version bump -> 0.1.59 (#2278) Co-authored-by: Deshraj Yadav --- docs/open-source/graph_memory/features.mdx | 3 +- docs/open-source/graph_memory/overview.mdx | 9 ++-- docs/open-source/python-quickstart.mdx | 17 +------ mem0/configs/base.py | 2 +- mem0/memory/main.py | 57 ++++++++++------------ pyproject.toml | 2 +- tests/test_main.py | 12 +++-- tests/vector_stores/test_qdrant.py | 4 +- 8 files changed, 45 insertions(+), 61 deletions(-) diff --git a/docs/open-source/graph_memory/features.mdx b/docs/open-source/graph_memory/features.mdx index 52abfc79..5c74203a 100644 --- a/docs/open-source/graph_memory/features.mdx +++ b/docs/open-source/graph_memory/features.mdx @@ -28,8 +28,7 @@ config = { "password": "xxx" }, "custom_prompt": "Please only extract entities containing sports related relationships and nothing else.", - }, - "version": "v1.1" + } } m = Memory.from_config(config_dict=config) diff --git a/docs/open-source/graph_memory/overview.mdx b/docs/open-source/graph_memory/overview.mdx index 849b0530..11edcd4e 100644 --- a/docs/open-source/graph_memory/overview.mdx +++ b/docs/open-source/graph_memory/overview.mdx @@ -38,8 +38,7 @@ allowfullscreen ## Initialize Graph Memory To initialize Graph Memory you'll need to set up your configuration with graph store providers. -Currently, we support Neo4j as a graph store provider. You can setup [Neo4j](https://neo4j.com/) locally or use the hosted [Neo4j AuraDB](https://neo4j.com/product/auradb/). -Moreover, you also need to set the version to `v1.1` (*prior versions are not supported*). +Currently, we support Neo4j as a graph store provider. You can setup [Neo4j](https://neo4j.com/) locally or use the hosted [Neo4j AuraDB](https://neo4j.com/product/auradb/). If you are using Neo4j locally, then you need to install [APOC plugins](https://neo4j.com/labs/apoc/4.1/installation/). @@ -65,8 +64,7 @@ config = { "username": "neo4j", "password": "xxx" } - }, - "version": "v1.1" + } } m = Memory.from_config(config_dict=config) @@ -98,8 +96,7 @@ config = { "temperature": 0.0, } } - }, - "version": "v1.1" + } } m = Memory.from_config(config_dict=config) diff --git a/docs/open-source/python-quickstart.mdx b/docs/open-source/python-quickstart.mdx index 8467419f..b52399db 100644 --- a/docs/open-source/python-quickstart.mdx +++ b/docs/open-source/python-quickstart.mdx @@ -71,8 +71,7 @@ config = { "username": "neo4j", "password": "---" } - }, - "version": "v1.1" + } } m = Memory.from_config(config_dict=config) @@ -100,10 +99,6 @@ result = m.add("Likes to play cricket on weekends", user_id="alice", metadata={" { "results": [ {"id": "bf4d4092-cf91-4181-bfeb-b6fa2ed3061b", "memory": "Likes to play cricket on weekends", "event": "ADD"} - ], - "relations": [ - {"source": "alice", "relationship": "likes_to_play", "target": "cricket"}, - {"source": "alice", "relationship": "plays_on", "target": "weekends"} ] } ``` @@ -129,10 +124,6 @@ all_memories = m.get_all(user_id="alice") "updated_at": None, "user_id": "alice" } - ], - "relations": [ - {"source": "alice", "relationship": "likes_to_play", "target": "cricket"}, - {"source": "alice", "relationship": "plays_on", "target": "weekends"} ] } ``` @@ -180,10 +171,6 @@ related_memories = m.search(query="What are Alice's hobbies?", user_id="alice") "updated_at": None, "user_id": "alice" } - ], - "relations": [ - {"source": "alice", "relationship": "plays_on", "target": "weekends"}, - {"source": "alice", "relationship": "likes_to_play", "target": "cricket"} ] } ``` @@ -303,7 +290,7 @@ Mem0 offers extensive configuration options to customize its behavior according | Parameter | Description | Default | |------------------|--------------------------------------|----------------------------| | `history_db_path` | Path to the history database | "{mem0_dir}/history.db" | -| `version` | API version | "v1.0" | +| `version` | API version | "v1.1" | | `custom_prompt` | Custom prompt for memory processing | None | diff --git a/mem0/configs/base.py b/mem0/configs/base.py index 55d6b2e9..4d54dc2d 100644 --- a/mem0/configs/base.py +++ b/mem0/configs/base.py @@ -46,7 +46,7 @@ class MemoryConfig(BaseModel): ) version: str = Field( description="The version of the API", - default="v1.0", + default="v1.1", ) custom_prompt: Optional[str] = Field( description="Custom prompt for the memory", diff --git a/mem0/memory/main.py b/mem0/memory/main.py index 1c3a4b03..6ab3fb64 100644 --- a/mem0/memory/main.py +++ b/mem0/memory/main.py @@ -46,7 +46,7 @@ class Memory(MemoryBase): self.enable_graph = False - if self.api_version == "v1.1" and self.config.graph_store.config: + if self.config.graph_store.config: from mem0.memory.graph_memory import MemoryGraph self.graph = MemoryGraph(self.config) @@ -126,12 +126,7 @@ class Memory(MemoryBase): vector_store_result = future1.result() graph_result = future2.result() - if self.api_version == "v1.1": - return { - "results": vector_store_result, - "relations": graph_result, - } - else: + if self.api_version == "v1.0": warnings.warn( "The current add API output format is deprecated. " "To use the latest format, set `api_version='v1.1'`. " @@ -141,6 +136,14 @@ class Memory(MemoryBase): ) return vector_store_result + if self.enable_graph: + return { + "results": vector_store_result, + "relations": graph_result, + } + + return {"results": vector_store_result} + def _add_to_vector_store(self, messages, metadata, filters): parsed_messages = parse_messages(messages) @@ -252,7 +255,7 @@ class Memory(MemoryBase): def _add_to_graph(self, messages, filters): added_entities = [] - if self.api_version == "v1.1" and self.enable_graph: + if self.enable_graph: if filters.get("user_id") is None: filters["user_id"] = "user" @@ -324,11 +327,7 @@ class Memory(MemoryBase): with concurrent.futures.ThreadPoolExecutor() as executor: future_memories = executor.submit(self._get_all_from_vector_store, filters, limit) - future_graph_entities = ( - executor.submit(self.graph.get_all, filters, limit) - if self.api_version == "v1.1" and self.enable_graph - else None - ) + future_graph_entities = executor.submit(self.graph.get_all, filters, limit) if self.enable_graph else None concurrent.futures.wait( [future_memories, future_graph_entities] if future_graph_entities else [future_memories] @@ -337,12 +336,10 @@ class Memory(MemoryBase): all_memories = future_memories.result() graph_entities = future_graph_entities.result() if future_graph_entities else None - if self.api_version == "v1.1": - if self.enable_graph: - return {"results": all_memories, "relations": graph_entities} - else: - return {"results": all_memories} - else: + if self.enable_graph: + return {"results": all_memories, "relations": graph_entities} + + if self.api_version == "v1.0": warnings.warn( "The current get_all API output format is deprecated. " "To use the latest format, set `api_version='v1.1'`. " @@ -351,6 +348,8 @@ class Memory(MemoryBase): stacklevel=2, ) return all_memories + else: + return {"results": all_memories} def _get_all_from_vector_store(self, filters, limit): memories = self.vector_store.list(filters=filters, limit=limit) @@ -419,9 +418,7 @@ class Memory(MemoryBase): with concurrent.futures.ThreadPoolExecutor() as executor: future_memories = executor.submit(self._search_vector_store, query, filters, limit) future_graph_entities = ( - executor.submit(self.graph.search, query, filters, limit) - if self.api_version == "v1.1" and self.enable_graph - else None + executor.submit(self.graph.search, query, filters, limit) if self.enable_graph else None ) concurrent.futures.wait( @@ -431,20 +428,20 @@ class Memory(MemoryBase): original_memories = future_memories.result() graph_entities = future_graph_entities.result() if future_graph_entities else None - if self.api_version == "v1.1": - if self.enable_graph: - return {"results": original_memories, "relations": graph_entities} - else: - return {"results": original_memories} - else: + if self.enable_graph: + return {"results": original_memories, "relations": graph_entities} + + if self.api_version == "v1.0": warnings.warn( - "The current search API output format is deprecated. " + "The current get_all API output format is deprecated. " "To use the latest format, set `api_version='v1.1'`. " "The current format will be removed in mem0ai 1.1.0 and later versions.", category=DeprecationWarning, stacklevel=2, ) return original_memories + else: + return {"results": original_memories} def _search_vector_store(self, query, filters, limit): embeddings = self.embedding_model.embed(query, "search") @@ -540,7 +537,7 @@ class Memory(MemoryBase): logger.info(f"Deleted {len(memories)} memories") - if self.api_version == "v1.1" and self.enable_graph: + if self.enable_graph: self.graph.delete_all(filters) return {"message": "Memories deleted successfully!"} diff --git a/pyproject.toml b/pyproject.toml index c5c1b8c1..adc09067 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "mem0ai" -version = "0.1.58" +version = "0.1.59" description = "Long-term memory for AI Agents" authors = ["Mem0 "] exclude = [ diff --git a/tests/test_main.py b/tests/test_main.py index 29e66200..b1375310 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -41,10 +41,14 @@ def test_add(memory_instance, version, enable_graph): result = memory_instance.add(messages=[{"role": "user", "content": "Test message"}], user_id="test_user") - assert "results" in result - assert result["results"] == [{"memory": "Test memory", "event": "ADD"}] - assert "relations" in result - assert result["relations"] == [] + if enable_graph: + assert "results" in result + assert result["results"] == [{"memory": "Test memory", "event": "ADD"}] + assert "relations" in result + assert result["relations"] == [] + else: + assert "results" in result + assert result["results"] == [{"memory": "Test memory", "event": "ADD"}] memory_instance._add_to_vector_store.assert_called_once_with( [{"role": "user", "content": "Test message"}], {"user_id": "test_user"}, {"user_id": "test_user"} diff --git a/tests/vector_stores/test_qdrant.py b/tests/vector_stores/test_qdrant.py index ab80cde9..ce5cc5b1 100644 --- a/tests/vector_stores/test_qdrant.py +++ b/tests/vector_stores/test_qdrant.py @@ -55,9 +55,9 @@ class TestQdrant(unittest.TestCase): results = self.qdrant.search(query=query_vector, limit=1) - self.client_mock.search.assert_called_once_with( + self.client_mock.query_points.assert_called_once_with( collection_name="test_collection", - query_vector=query_vector, + query=query_vector, query_filter=None, limit=1, )