Match output format with APIs (#1595)

This commit is contained in:
Dev Khant
2024-08-01 23:17:39 +05:30
committed by GitHub
parent 45ae1f0313
commit 80945df4ca
5 changed files with 186 additions and 97 deletions

View File

@@ -1,7 +1,9 @@
import logging
import hashlib
import os
import time
import uuid
import pytz
from datetime import datetime
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, ValidationError
@@ -28,14 +30,15 @@ setup_config()
# Pydantic model describing a single memory record; the Memory methods shown
# in this diff build it and then serialize it with model_dump().
# NOTE(review): this span is a rendered diff hunk, so pre- and post-change
# lines appear side by side without +/- markers. `text` and the first
# `metadata` declaration look like the removed lines, with `memory`, `hash`,
# the Optional `metadata`, `created_at` and `updated_at` as the added ones --
# confirm against the actual file.
class MemoryItem(BaseModel):
id: str = Field(..., description="The unique identifier for the text data")
text: str = Field(..., description="The text content")
memory: str = Field(..., description="The memory deduced from the text data") # TODO After prompt changes from platform, update this
hash: Optional[str] = Field(None, description="The hash of the memory")
# TODO: metadata values may be of any type, not only strings.
metadata: Dict[str, Any] = Field(
default_factory=dict, description="Additional metadata for the text data"
)
# A second `metadata` declaration in a class body replaces the one above.
metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata for the text data")
score: Optional[float] = Field(
None, description="The score associated with the text data"
)
# The memory tools in this diff populate these with ISO-8601 strings
# (datetime.now(...).isoformat()).
created_at: Optional[str] = Field(None, description="The timestamp when the memory was created")
updated_at: Optional[str] = Field(None, description="The timestamp when the memory was updated")
class MemoryConfig(BaseModel):
@@ -135,7 +138,7 @@ class Memory(MemoryBase):
id=mem.id,
score=mem.score,
metadata=mem.payload,
text=mem.payload["data"],
memory=mem.payload["data"],
)
for mem in existing_memories
]
@@ -187,7 +190,7 @@ class Memory(MemoryBase):
{"memory_id": function_result, "function_name": function_name},
)
capture_event("mem0.add", self)
return response
return {"message": "ok"}
def get(self, memory_id):
"""
@@ -203,11 +206,27 @@ class Memory(MemoryBase):
memory = self.vector_store.get(name=self.collection_name, vector_id=memory_id)
if not memory:
return None
return MemoryItem(
filters = {key: memory.payload[key] for key in ["user_id", "agent_id", "run_id"] if memory.payload.get(key)}
# Prepare base memory item
memory_item = MemoryItem(
id=memory.id,
metadata=memory.payload,
text=memory.payload["data"],
memory=memory.payload["data"],
hash=memory.payload.get("hash"),
created_at=memory.payload.get("created_at"),
updated_at=memory.payload.get("updated_at"),
).model_dump(exclude={"score"})
# Add metadata if there are additional keys
excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
additional_metadata = {k: v for k, v in memory.payload.items() if k not in excluded_keys}
if additional_metadata:
memory_item["metadata"] = additional_metadata
result = {**memory_item, **filters}
return result
def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100):
"""
@@ -228,12 +247,21 @@ class Memory(MemoryBase):
memories = self.vector_store.list(
name=self.collection_name, filters=filters, limit=limit
)
excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
return [
MemoryItem(
id=mem.id,
metadata=mem.payload,
text=mem.payload["data"],
).model_dump(exclude={"score"})
{
**MemoryItem(
id=mem.id,
memory=mem.payload["data"],
hash=mem.payload.get("hash"),
created_at=mem.payload.get("created_at"),
updated_at=mem.payload.get("updated_at"),
).model_dump(exclude={"score"}),
**{key: mem.payload[key] for key in ["user_id", "agent_id", "run_id"] if key in mem.payload},
**({"metadata": {k: v for k, v in mem.payload.items() if k not in excluded_keys}}
if any(k for k in mem.payload if k not in excluded_keys) else {})
}
for mem in memories[0]
]
@@ -267,13 +295,23 @@ class Memory(MemoryBase):
memories = self.vector_store.search(
name=self.collection_name, query=embeddings, limit=limit, filters=filters
)
excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
return [
MemoryItem(
id=mem.id,
metadata=mem.payload,
score=mem.score,
text=mem.payload["data"],
).model_dump()
{
**MemoryItem(
id=mem.id,
memory=mem.payload["data"],
hash=mem.payload.get("hash"),
created_at=mem.payload.get("created_at"),
updated_at=mem.payload.get("updated_at"),
score=mem.score,
).model_dump(),
**{key: mem.payload[key] for key in ["user_id", "agent_id", "run_id"] if key in mem.payload},
**({"metadata": {k: v for k, v in mem.payload.items() if k not in excluded_keys}}
if any(k for k in mem.payload if k not in excluded_keys) else {})
}
for mem in memories
]
@@ -290,6 +328,7 @@ class Memory(MemoryBase):
"""
capture_event("mem0.get_all", self, {"memory_id": memory_id})
self._update_memory_tool(memory_id, data)
return {'message': 'Memory updated successfully!'}
def delete(self, memory_id):
"""
@@ -300,6 +339,7 @@ class Memory(MemoryBase):
"""
capture_event("mem0.delete", self, {"memory_id": memory_id})
self._delete_memory_tool(memory_id)
return {'message': 'Memory deleted successfully!'}
def delete_all(self, user_id=None, agent_id=None, run_id=None):
"""
@@ -327,6 +367,7 @@ class Memory(MemoryBase):
memories = self.vector_store.list(name=self.collection_name, filters=filters)[0]
for memory in memories:
self._delete_memory_tool(memory.id)
return {'message': 'Memories deleted successfully!'}
def history(self, memory_id):
"""
@@ -347,7 +388,8 @@ class Memory(MemoryBase):
memory_id = str(uuid.uuid4())
metadata = metadata or {}
metadata["data"] = data
metadata["created_at"] = int(time.time())
metadata["hash"] = hashlib.md5(data.encode()).hexdigest()
metadata["created_at"] = datetime.now(pytz.timezone('US/Pacific')).isoformat()
self.vector_store.insert(
name=self.collection_name,
@@ -355,7 +397,7 @@ class Memory(MemoryBase):
ids=[memory_id],
payloads=[metadata],
)
self.db.add_history(memory_id, None, data, "add")
self.db.add_history(memory_id, None, data, "ADD", created_at=metadata["created_at"])
return memory_id
def _update_memory_tool(self, memory_id, data, metadata=None):
@@ -366,7 +408,8 @@ class Memory(MemoryBase):
new_metadata = metadata or {}
new_metadata["data"] = data
new_metadata["updated_at"] = int(time.time())
new_metadata["created_at"] = existing_memory.payload.get("created_at")
new_metadata["updated_at"] = datetime.now(pytz.timezone('US/Pacific')).isoformat()
embeddings = self.embedding_model.embed(data)
self.vector_store.update(
name=self.collection_name,
@@ -375,7 +418,7 @@ class Memory(MemoryBase):
payload=new_metadata,
)
logging.info(f"Updating memory with ID {memory_id=} with {data=}")
self.db.add_history(memory_id, prev_value, data, "update")
self.db.add_history(memory_id, prev_value, data, "UPDATE", created_at=new_metadata["created_at"], updated_at=new_metadata["updated_at"])
def _delete_memory_tool(self, memory_id):
logging.info(f"Deleting memory with {memory_id=}")
@@ -384,7 +427,7 @@ class Memory(MemoryBase):
)
prev_value = existing_memory.payload["data"]
self.vector_store.delete(name=self.collection_name, vector_id=memory_id)
self.db.add_history(memory_id, prev_value, None, "delete", is_deleted=1)
self.db.add_history(memory_id, prev_value, None, "DELETE", is_deleted=1)
def reset(self):
"""

View File

@@ -6,7 +6,45 @@ from datetime import datetime
class SQLiteManager:
# Open the history database and make sure the `history` table is usable.
# db_path defaults to ":memory:", i.e. a private in-memory database.
def __init__(self, db_path=":memory:"):
# check_same_thread=False allows the connection to be used from threads
# other than the one that created it.
self.connection = sqlite3.connect(db_path, check_same_thread=False)
# Order matters: migrate an existing table first; the create step uses
# CREATE TABLE IF NOT EXISTS and therefore only acts when no table exists.
self._migrate_history_table()
self._create_history_table()
def _migrate_history_table(self):
    """Upgrade a legacy ``history`` table to the current column layout.

    The legacy layout stores ``prev_value`` / ``new_value`` / ``timestamp``;
    the current layout stores ``old_memory`` / ``new_memory`` /
    ``created_at`` / ``updated_at``. Rows are copied so that ``prev_value``
    becomes ``old_memory``, ``new_value`` fills both ``new_memory`` and
    ``new_value``, and ``timestamp`` fills both ``created_at`` and
    ``updated_at``.

    The migration runs only when the existing table still has every legacy
    column. It is therefore a no-op when the table is missing or has
    already been migrated; previously a second run failed with
    ``sqlite3.OperationalError`` (no such column: prev_value) and left the
    data stranded in ``old_history``.

    Raises:
        sqlite3.Error: if any of the migration statements fails.
    """
    legacy_columns = {
        "id",
        "memory_id",
        "prev_value",
        "new_value",
        "event",
        "timestamp",
        "is_deleted",
    }
    # The connection context manager commits on success and rolls back on
    # error, so no explicit commit() is needed.
    with self.connection:
        cursor = self.connection.cursor()
        # PRAGMA table_info returns one row per column (name at index 1)
        # and no rows at all when the table does not exist.
        cursor.execute("PRAGMA table_info(history)")
        existing_columns = {row[1] for row in cursor.fetchall()}
        if not legacy_columns.issubset(existing_columns):
            return

        # Move the legacy table aside, build the new one, copy, then drop.
        cursor.execute("ALTER TABLE history RENAME TO old_history")
        cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS history (
                id TEXT PRIMARY KEY,
                memory_id TEXT,
                old_memory TEXT,
                new_memory TEXT,
                new_value TEXT,
                event TEXT,
                created_at DATETIME,
                updated_at DATETIME,
                is_deleted INTEGER
            )
            """
        )
        cursor.execute(
            """
            INSERT INTO history (id, memory_id, old_memory, new_memory, new_value, event, created_at, updated_at, is_deleted)
            SELECT id, memory_id, prev_value, new_value, new_value, event, timestamp, timestamp, is_deleted
            FROM old_history
            """
        )
        cursor.execute("DROP TABLE old_history")
def _create_history_table(self):
with self.connection:
@@ -15,29 +53,32 @@ class SQLiteManager:
CREATE TABLE IF NOT EXISTS history (
id TEXT PRIMARY KEY,
memory_id TEXT,
prev_value TEXT,
old_memory TEXT,
new_memory TEXT,
new_value TEXT,
event TEXT,
timestamp DATETIME,
created_at DATETIME,
updated_at DATETIME,
is_deleted INTEGER
)
"""
)
# Append one row to the `history` table describing a change to a memory.
# Each row gets a fresh uuid4 string as its id; the insert runs inside
# `with self.connection`, which commits on success and rolls back on error.
# NOTE(review): this span is a rendered diff hunk -- the old and new `def`
# lines, INSERT statements and parameter tuples appear interleaved without
# +/- markers. The second signature (old_memory/new_memory plus optional
# created_at/updated_at) looks like the post-change version -- confirm
# against the actual file.
# NOTE(review): in the new version created_at/updated_at default to None and
# are stored as given. The DELETE call site in this diff passes neither, and
# get_history orders by updated_at, so such rows carry NULL timestamps --
# verify that the resulting ordering is what is intended.
def add_history(self, memory_id, prev_value, new_value, event, is_deleted=0):
def add_history(self, memory_id, old_memory, new_memory, event, created_at = None, updated_at = None, is_deleted=0):
with self.connection:
self.connection.execute(
"""
INSERT INTO history (id, memory_id, prev_value, new_value, event, timestamp, is_deleted)
VALUES (?, ?, ?, ?, ?, ?, ?)
INSERT INTO history (id, memory_id, old_memory, new_memory, event, created_at, updated_at, is_deleted)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(
# Fresh random identifier for this history row.
str(uuid.uuid4()),
memory_id,
prev_value,
new_value,
old_memory,
new_memory,
event,
datetime.utcnow(),
created_at,
updated_at,
is_deleted,
),
)
@@ -45,10 +86,10 @@ class SQLiteManager:
def get_history(self, memory_id):
cursor = self.connection.execute(
"""
SELECT id, memory_id, prev_value, new_value, event, timestamp, is_deleted
SELECT id, memory_id, old_memory, new_memory, event, created_at, updated_at
FROM history
WHERE memory_id = ?
ORDER BY timestamp ASC
ORDER BY updated_at ASC
""",
(memory_id,),
)
@@ -57,11 +98,11 @@ class SQLiteManager:
{
"id": row[0],
"memory_id": row[1],
"prev_value": row[2],
"new_value": row[3],
"old_memory": row[2],
"new_memory": row[3],
"event": row[4],
"timestamp": row[5],
"is_deleted": row[6],
"created_at": row[5],
"updated_at": row[6],
}
for row in rows
]