Make api_version=v1.1 default and version bump -> 0.1.59 (#2278)
Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
This commit is contained in:
@@ -46,7 +46,7 @@ class Memory(MemoryBase):
|
||||
|
||||
self.enable_graph = False
|
||||
|
||||
if self.api_version == "v1.1" and self.config.graph_store.config:
|
||||
if self.config.graph_store.config:
|
||||
from mem0.memory.graph_memory import MemoryGraph
|
||||
|
||||
self.graph = MemoryGraph(self.config)
|
||||
@@ -126,12 +126,7 @@ class Memory(MemoryBase):
|
||||
vector_store_result = future1.result()
|
||||
graph_result = future2.result()
|
||||
|
||||
if self.api_version == "v1.1":
|
||||
return {
|
||||
"results": vector_store_result,
|
||||
"relations": graph_result,
|
||||
}
|
||||
else:
|
||||
if self.api_version == "v1.0":
|
||||
warnings.warn(
|
||||
"The current add API output format is deprecated. "
|
||||
"To use the latest format, set `api_version='v1.1'`. "
|
||||
@@ -141,6 +136,14 @@ class Memory(MemoryBase):
|
||||
)
|
||||
return vector_store_result
|
||||
|
||||
if self.enable_graph:
|
||||
return {
|
||||
"results": vector_store_result,
|
||||
"relations": graph_result,
|
||||
}
|
||||
|
||||
return {"results": vector_store_result}
|
||||
|
||||
def _add_to_vector_store(self, messages, metadata, filters):
|
||||
parsed_messages = parse_messages(messages)
|
||||
|
||||
@@ -252,7 +255,7 @@ class Memory(MemoryBase):
|
||||
|
||||
def _add_to_graph(self, messages, filters):
|
||||
added_entities = []
|
||||
if self.api_version == "v1.1" and self.enable_graph:
|
||||
if self.enable_graph:
|
||||
if filters.get("user_id") is None:
|
||||
filters["user_id"] = "user"
|
||||
|
||||
@@ -324,11 +327,7 @@ class Memory(MemoryBase):
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future_memories = executor.submit(self._get_all_from_vector_store, filters, limit)
|
||||
future_graph_entities = (
|
||||
executor.submit(self.graph.get_all, filters, limit)
|
||||
if self.api_version == "v1.1" and self.enable_graph
|
||||
else None
|
||||
)
|
||||
future_graph_entities = executor.submit(self.graph.get_all, filters, limit) if self.enable_graph else None
|
||||
|
||||
concurrent.futures.wait(
|
||||
[future_memories, future_graph_entities] if future_graph_entities else [future_memories]
|
||||
@@ -337,12 +336,10 @@ class Memory(MemoryBase):
|
||||
all_memories = future_memories.result()
|
||||
graph_entities = future_graph_entities.result() if future_graph_entities else None
|
||||
|
||||
if self.api_version == "v1.1":
|
||||
if self.enable_graph:
|
||||
return {"results": all_memories, "relations": graph_entities}
|
||||
else:
|
||||
return {"results": all_memories}
|
||||
else:
|
||||
if self.enable_graph:
|
||||
return {"results": all_memories, "relations": graph_entities}
|
||||
|
||||
if self.api_version == "v1.0":
|
||||
warnings.warn(
|
||||
"The current get_all API output format is deprecated. "
|
||||
"To use the latest format, set `api_version='v1.1'`. "
|
||||
@@ -351,6 +348,8 @@ class Memory(MemoryBase):
|
||||
stacklevel=2,
|
||||
)
|
||||
return all_memories
|
||||
else:
|
||||
return {"results": all_memories}
|
||||
|
||||
def _get_all_from_vector_store(self, filters, limit):
|
||||
memories = self.vector_store.list(filters=filters, limit=limit)
|
||||
@@ -419,9 +418,7 @@ class Memory(MemoryBase):
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future_memories = executor.submit(self._search_vector_store, query, filters, limit)
|
||||
future_graph_entities = (
|
||||
executor.submit(self.graph.search, query, filters, limit)
|
||||
if self.api_version == "v1.1" and self.enable_graph
|
||||
else None
|
||||
executor.submit(self.graph.search, query, filters, limit) if self.enable_graph else None
|
||||
)
|
||||
|
||||
concurrent.futures.wait(
|
||||
@@ -431,20 +428,20 @@ class Memory(MemoryBase):
|
||||
original_memories = future_memories.result()
|
||||
graph_entities = future_graph_entities.result() if future_graph_entities else None
|
||||
|
||||
if self.api_version == "v1.1":
|
||||
if self.enable_graph:
|
||||
return {"results": original_memories, "relations": graph_entities}
|
||||
else:
|
||||
return {"results": original_memories}
|
||||
else:
|
||||
if self.enable_graph:
|
||||
return {"results": original_memories, "relations": graph_entities}
|
||||
|
||||
if self.api_version == "v1.0":
|
||||
warnings.warn(
|
||||
"The current search API output format is deprecated. "
|
||||
"The current get_all API output format is deprecated. "
|
||||
"To use the latest format, set `api_version='v1.1'`. "
|
||||
"The current format will be removed in mem0ai 1.1.0 and later versions.",
|
||||
category=DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return original_memories
|
||||
else:
|
||||
return {"results": original_memories}
|
||||
|
||||
def _search_vector_store(self, query, filters, limit):
|
||||
embeddings = self.embedding_model.embed(query, "search")
|
||||
@@ -540,7 +537,7 @@ class Memory(MemoryBase):
|
||||
|
||||
logger.info(f"Deleted {len(memories)} memories")
|
||||
|
||||
if self.api_version == "v1.1" and self.enable_graph:
|
||||
if self.enable_graph:
|
||||
self.graph.delete_all(filters)
|
||||
|
||||
return {"message": "Memories deleted successfully!"}
|
||||
|
||||
Reference in New Issue
Block a user