[Misc] Lint code and fix code smells (#1871)

Deshraj Yadav
2024-09-16 17:39:54 -07:00
committed by GitHub
parent 0a78cb9f7a
commit 55c54beeab
57 changed files with 1178 additions and 1357 deletions
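
The hunks below appear to be a mechanical lint pass over the Memory class: imports sorted, over-wrapped statements collapsed onto single lines, dict spacing and trailing commas made consistent, and stray blank lines dropped. For orientation, here is a minimal usage sketch of the class being touched, assembled only from the method signatures visible in the diff; the import path, the no-argument constructor, and the example values are assumptions and are not part of this commit.

    # Hedged sketch -- `from mem0 import Memory` and the default-config constructor
    # are assumptions; the method names and the filter rule (at least one of
    # user_id / agent_id / run_id) come from the diff below.
    from mem0 import Memory

    m = Memory()  # or Memory.from_config({...}), as in from_config() below

    # add() accepts a plain string and wraps it as a user message (see the
    # isinstance(messages, str) branch in add()).
    m.add("I prefer dark roast coffee", user_id="alice")

    # search() / get_all() take the same identifiers plus a limit.
    hits = m.search("coffee", user_id="alice", limit=5)

    # With version "v1.1" (and graph memory enabled) results come back as
    # {"results": [...], "relations": [...]}; older versions return the
    # deprecated flat format and emit a DeprecationWarning.
    print(hits)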


@@ -10,14 +10,14 @@ from typing import Any, Dict
import pytz
from pydantic import ValidationError
+from mem0.configs.base import MemoryConfig, MemoryItem
from mem0.configs.prompts import get_update_memory_messages
from mem0.memory.base import MemoryBase
from mem0.memory.setup import setup_config
from mem0.memory.storage import SQLiteManager
from mem0.memory.telemetry import capture_event
from mem0.memory.utils import get_fact_retrieval_messages, parse_messages
-from mem0.utils.factory import LlmFactory, EmbedderFactory, VectorStoreFactory
-from mem0.configs.base import MemoryItem, MemoryConfig
+from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory
# Setup user config
setup_config()
@@ -30,9 +30,7 @@ class Memory(MemoryBase):
self.config = config
self.custom_prompt = self.config.custom_prompt
-self.embedding_model = EmbedderFactory.create(
-self.config.embedder.provider, self.config.embedder.config
-)
+self.embedding_model = EmbedderFactory.create(self.config.embedder.provider, self.config.embedder.config)
self.vector_store = VectorStoreFactory.create(
self.config.vector_store.provider, self.config.vector_store.config
)
@@ -45,12 +43,12 @@ class Memory(MemoryBase):
if self.version == "v1.1" and self.config.graph_store.config:
from mem0.memory.graph_memory import MemoryGraph
self.graph = MemoryGraph(self.config)
self.enable_graph = True
capture_event("mem0.init", self)
@classmethod
def from_config(cls, config_dict: Dict[str, Any]):
try:
@@ -60,7 +58,6 @@ class Memory(MemoryBase):
raise
return cls(config)
def add(
self,
messages,
@@ -98,9 +95,7 @@ class Memory(MemoryBase):
filters["run_id"] = metadata["run_id"] = run_id
if not any(key in filters for key in ("user_id", "agent_id", "run_id")):
-raise ValueError(
-"One of the filters: user_id, agent_id or run_id is required!"
-)
+raise ValueError("One of the filters: user_id, agent_id or run_id is required!")
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
@@ -116,8 +111,8 @@ class Memory(MemoryBase):
if self.version == "v1.1":
return {
"results" : vector_store_result,
"relations" : graph_result,
"results": vector_store_result,
"relations": graph_result,
}
else:
warnings.warn(
@@ -125,29 +120,29 @@ class Memory(MemoryBase):
"To use the latest format, set `api_version='v1.1'`. "
"The current format will be removed in mem0ai 1.1.0 and later versions.",
category=DeprecationWarning,
-stacklevel=2
+stacklevel=2,
)
return {"message": "ok"}
def _add_to_vector_store(self, messages, metadata, filters):
parsed_messages = parse_messages(messages)
if self.custom_prompt:
-system_prompt=self.custom_prompt
-user_prompt=f"Input: {parsed_messages}"
+system_prompt = self.custom_prompt
+user_prompt = f"Input: {parsed_messages}"
else:
system_prompt, user_prompt = get_fact_retrieval_messages(parsed_messages)
response = self.llm.generate_response(
-messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}],
+messages=[
+{"role": "system", "content": system_prompt},
+{"role": "user", "content": user_prompt},
+],
response_format={"type": "json_object"},
)
try:
-new_retrieved_facts = json.loads(response)[
-"facts"
-]
+new_retrieved_facts = json.loads(response)["facts"]
except Exception as e:
logging.error(f"Error in new_retrieved_facts: {e}")
new_retrieved_facts = []
@@ -178,24 +173,30 @@ class Memory(MemoryBase):
logging.info(resp)
try:
if resp["event"] == "ADD":
-memory_id = self._create_memory(data=resp["text"], metadata=metadata)
-returned_memories.append({
-"memory" : resp["text"],
-"event" : resp["event"],
-})
+_ = self._create_memory(data=resp["text"], metadata=metadata)
+returned_memories.append(
+{
+"memory": resp["text"],
+"event": resp["event"],
+}
+)
elif resp["event"] == "UPDATE":
self._update_memory(memory_id=resp["id"], data=resp["text"], metadata=metadata)
-returned_memories.append({
-"memory" : resp["text"],
-"event" : resp["event"],
-"previous_memory" : resp["old_memory"],
-})
+returned_memories.append(
+{
+"memory": resp["text"],
+"event": resp["event"],
+"previous_memory": resp["old_memory"],
+}
+)
elif resp["event"] == "DELETE":
self._delete_memory(memory_id=resp["id"])
-returned_memories.append({
-"memory" : resp["text"],
-"event" : resp["event"],
-})
+returned_memories.append(
+{
+"memory": resp["text"],
+"event": resp["event"],
+}
+)
elif resp["event"] == "NONE":
logging.info("NOOP for Memory.")
except Exception as e:
@@ -206,7 +207,6 @@ class Memory(MemoryBase):
capture_event("mem0.add", self)
return returned_memories
def _add_to_graph(self, messages, filters):
added_entities = []
@@ -220,7 +220,6 @@ class Memory(MemoryBase):
return added_entities
def get(self, memory_id):
"""
Retrieve a memory by ID.
@@ -236,11 +235,7 @@ class Memory(MemoryBase):
if not memory:
return None
-filters = {
-key: memory.payload[key]
-for key in ["user_id", "agent_id", "run_id"]
-if memory.payload.get(key)
-}
+filters = {key: memory.payload[key] for key in ["user_id", "agent_id", "run_id"] if memory.payload.get(key)}
# Prepare base memory item
memory_item = MemoryItem(
@@ -261,9 +256,7 @@ class Memory(MemoryBase):
"created_at",
"updated_at",
}
-additional_metadata = {
-k: v for k, v in memory.payload.items() if k not in excluded_keys
-}
+additional_metadata = {k: v for k, v in memory.payload.items() if k not in excluded_keys}
if additional_metadata:
memory_item["metadata"] = additional_metadata
@@ -271,7 +264,6 @@ class Memory(MemoryBase):
return result
def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100):
"""
List all memories.
@@ -288,10 +280,12 @@ class Memory(MemoryBase):
filters["run_id"] = run_id
capture_event("mem0.get_all", self, {"filters": len(filters), "limit": limit})
with concurrent.futures.ThreadPoolExecutor() as executor:
future_memories = executor.submit(self._get_all_from_vector_store, filters, limit)
-future_graph_entities = executor.submit(self.graph.get_all, filters) if self.version == "v1.1" and self.enable_graph else None
+future_graph_entities = (
+executor.submit(self.graph.get_all, filters) if self.version == "v1.1" and self.enable_graph else None
+)
all_memories = future_memories.result()
graph_entities = future_graph_entities.result() if future_graph_entities else None
@@ -307,15 +301,22 @@ class Memory(MemoryBase):
"To use the latest format, set `api_version='v1.1'`. "
"The current format will be removed in mem0ai 1.1.0 and later versions.",
category=DeprecationWarning,
-stacklevel=2
+stacklevel=2,
)
return all_memories
def _get_all_from_vector_store(self, filters, limit):
memories = self.vector_store.list(filters=filters, limit=limit)
excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
excluded_keys = {
"user_id",
"agent_id",
"run_id",
"hash",
"data",
"created_at",
"updated_at",
}
all_memories = [
{
**MemoryItem(
@@ -325,19 +326,9 @@ class Memory(MemoryBase):
created_at=mem.payload.get("created_at"),
updated_at=mem.payload.get("updated_at"),
).model_dump(exclude={"score"}),
-**{
-key: mem.payload[key]
-for key in ["user_id", "agent_id", "run_id"]
-if key in mem.payload
-},
+**{key: mem.payload[key] for key in ["user_id", "agent_id", "run_id"] if key in mem.payload},
**(
-{
-"metadata": {
-k: v
-for k, v in mem.payload.items()
-if k not in excluded_keys
-}
-}
+{"metadata": {k: v for k, v in mem.payload.items() if k not in excluded_keys}}
if any(k for k in mem.payload if k not in excluded_keys)
else {}
),
@@ -346,10 +337,7 @@ class Memory(MemoryBase):
]
return all_memories
-def search(
-self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None
-):
+def search(self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None):
"""
Search for memories.
@@ -373,15 +361,21 @@ class Memory(MemoryBase):
filters["run_id"] = run_id
if not any(key in filters for key in ("user_id", "agent_id", "run_id")):
-raise ValueError(
-"One of the filters: user_id, agent_id or run_id is required!"
-)
+raise ValueError("One of the filters: user_id, agent_id or run_id is required!")
-capture_event("mem0.search", self, {"filters": len(filters), "limit": limit, "version": self.version})
+capture_event(
+"mem0.search",
+self,
+{"filters": len(filters), "limit": limit, "version": self.version},
+)
with concurrent.futures.ThreadPoolExecutor() as executor:
future_memories = executor.submit(self._search_vector_store, query, filters, limit)
-future_graph_entities = executor.submit(self.graph.search, query, filters) if self.version == "v1.1" and self.enable_graph else None
+future_graph_entities = (
+executor.submit(self.graph.search, query, filters)
+if self.version == "v1.1" and self.enable_graph
+else None
+)
original_memories = future_memories.result()
graph_entities = future_graph_entities.result() if future_graph_entities else None
@@ -390,23 +384,20 @@ class Memory(MemoryBase):
if self.enable_graph:
return {"results": original_memories, "relations": graph_entities}
else:
return {"results" : original_memories}
return {"results": original_memories}
else:
warnings.warn(
"The current get_all API output format is deprecated. "
"To use the latest format, set `api_version='v1.1'`. "
"The current format will be removed in mem0ai 1.1.0 and later versions.",
category=DeprecationWarning,
-stacklevel=2
+stacklevel=2,
)
return original_memories
def _search_vector_store(self, query, filters, limit):
embeddings = self.embedding_model.embed(query)
-memories = self.vector_store.search(
-query=embeddings, limit=limit, filters=filters
-)
+memories = self.vector_store.search(query=embeddings, limit=limit, filters=filters)
excluded_keys = {
"user_id",
@@ -428,19 +419,9 @@ class Memory(MemoryBase):
updated_at=mem.payload.get("updated_at"),
score=mem.score,
).model_dump(),
-**{
-key: mem.payload[key]
-for key in ["user_id", "agent_id", "run_id"]
-if key in mem.payload
-},
+**{key: mem.payload[key] for key in ["user_id", "agent_id", "run_id"] if key in mem.payload},
**(
-{
-"metadata": {
-k: v
-for k, v in mem.payload.items()
-if k not in excluded_keys
-}
-}
+{"metadata": {k: v for k, v in mem.payload.items() if k not in excluded_keys}}
if any(k for k in mem.payload if k not in excluded_keys)
else {}
),
@@ -450,7 +431,6 @@ class Memory(MemoryBase):
return original_memories
def update(self, memory_id, data):
"""
Update a memory by ID.
@@ -466,7 +446,6 @@ class Memory(MemoryBase):
self._update_memory(memory_id, data)
return {"message": "Memory updated successfully!"}
def delete(self, memory_id):
"""
Delete a memory by ID.
@@ -478,7 +457,6 @@ class Memory(MemoryBase):
self._delete_memory(memory_id)
return {"message": "Memory deleted successfully!"}
def delete_all(self, user_id=None, agent_id=None, run_id=None):
"""
Delete all memories.
@@ -511,8 +489,7 @@ class Memory(MemoryBase):
if self.version == "v1.1" and self.enable_graph:
self.graph.delete_all(filters)
-return {'message': 'Memories deleted successfully!'}
+return {"message": "Memories deleted successfully!"}
def history(self, memory_id):
"""
@@ -527,7 +504,6 @@ class Memory(MemoryBase):
capture_event("mem0.history", self, {"memory_id": memory_id})
return self.db.get_history(memory_id)
def _create_memory(self, data, metadata=None):
logging.info(f"Creating memory with {data=}")
embeddings = self.embedding_model.embed(data)
@@ -542,12 +518,9 @@ class Memory(MemoryBase):
ids=[memory_id],
payloads=[metadata],
)
-self.db.add_history(
-memory_id, None, data, "ADD", created_at=metadata["created_at"]
-)
+self.db.add_history(memory_id, None, data, "ADD", created_at=metadata["created_at"])
return memory_id
def _update_memory(self, memory_id, data, metadata=None):
logger.info(f"Updating memory with {data=}")
existing_memory = self.vector_store.get(vector_id=memory_id)
@@ -557,9 +530,7 @@ class Memory(MemoryBase):
new_metadata["data"] = data
new_metadata["hash"] = existing_memory.payload.get("hash")
new_metadata["created_at"] = existing_memory.payload.get("created_at")
new_metadata["updated_at"] = datetime.now(
pytz.timezone("US/Pacific")
).isoformat()
new_metadata["updated_at"] = datetime.now(pytz.timezone("US/Pacific")).isoformat()
if "user_id" in existing_memory.payload:
new_metadata["user_id"] = existing_memory.payload["user_id"]
@@ -584,7 +555,6 @@ class Memory(MemoryBase):
updated_at=new_metadata["updated_at"],
)
def _delete_memory(self, memory_id):
logging.info(f"Deleting memory with {memory_id=}")
existing_memory = self.vector_store.get(vector_id=memory_id)
@@ -592,7 +562,6 @@ class Memory(MemoryBase):
self.vector_store.delete(vector_id=memory_id)
self.db.add_history(memory_id, prev_value, None, "DELETE", is_deleted=1)
def reset(self):
"""
Reset the memory store.
@@ -602,6 +571,5 @@ class Memory(MemoryBase):
self.db.reset()
capture_event("mem0.reset", self)
def chat(self, query):
raise NotImplementedError("Chat function not implemented yet.")