Match output format with APIs (#1595)

This commit is contained in:
Dev Khant
2024-08-01 23:17:39 +05:30
committed by GitHub
parent 45ae1f0313
commit 80945df4ca
5 changed files with 186 additions and 97 deletions

View File

@@ -68,13 +68,7 @@ print(result)
Output: Output:
```python ```python
[ {'message': 'ok'}
{
'id': 'm1',
'event': 'add',
'data': 'Likes to play cricket on weekends'
}
]
``` ```
### Retrieve Memories ### Retrieve Memories
@@ -89,15 +83,15 @@ Output:
```python ```python
[ [
{ {
'id': 'm1', "id":"13efe83b-a8df-4ec0-814e-428d6e8451eb",
'text': 'Likes to play cricket on weekends', "memory":"Likes to play cricket on weekends",
'metadata': { "hash":"87bcddeb-fe45-4353-bc22-15a841c50308",
'data': 'Likes to play cricket on weekends', "metadata":"None",
'category': 'hobbies' "created_at":"2024-07-26T08:44:41.039788-07:00",
} "updated_at":"None",
}, "user_id":"alice"
# ... other memories ... }
] ]
``` ```
@@ -110,12 +104,13 @@ print(specific_memory)
Output: Output:
```python ```python
{ {
'id': 'm1', "id":"13efe83b-a8df-4ec0-814e-428d6e8451eb",
'text': 'Likes to play cricket on weekends', "memory":"Likes to play cricket on weekends",
'metadata': { "hash":"87bcddeb-fe45-4353-bc22-15a841c50308",
'data': 'Likes to play cricket on weekends', "metadata":"None",
'category': 'hobbies' "created_at":"2024-07-26T08:44:41.039788-07:00",
} "updated_at":"None",
"user_id":"alice"
} }
``` ```
@@ -130,16 +125,18 @@ Output:
```python ```python
[ [
{ {
'id': 'm1', "id":"ea925981-272f-40dd-b576-be64e4871429",
'text': 'Likes to play cricket on weekends', "memory":"Likes to play cricket and plays cricket on weekends.",
'metadata': { "hash":"c8809002-25c1-4c97-a3a2-227ce9c20c53",
'data': 'Likes to play cricket on weekends', "metadata":{
'category': 'hobbies' "category":"hobbies"
}, },
'score': 0.85 # Similarity score "score":0.32116443111457704,
}, "created_at":"2024-07-26T10:29:36.630547-07:00",
# ... other related memories ... "updated_at":"None",
"user_id":"alice"
}
] ]
``` ```
@@ -153,11 +150,7 @@ print(result)
Output: Output:
```python ```python
{ {'message': 'Memory updated successfully!'}
'id': 'm1',
'event': 'update',
'data': 'Likes to play tennis on weekends'
}
``` ```
### Memory History ### Memory History
@@ -169,24 +162,24 @@ print(history)
Output: Output:
```python ```python
[ [
{ {
'id': 'h1', "id":"4e0e63d6-a9c6-43c0-b11c-a1bad3fc7abb",
'memory_id': 'm1', "memory_id":"ea925981-272f-40dd-b576-be64e4871429",
'prev_value': None, "old_memory":"None",
'new_value': 'Likes to play cricket on weekends', "new_memory":"Likes to play cricket and plays cricket on weekends.",
'event': 'add', "event":"ADD",
'timestamp': '2024-07-14 10:00:54.466687', "created_at":"2024-07-26T10:29:36.630547-07:00",
'is_deleted': 0 "updated_at":"None"
}, },
{ {
'id': 'h2', "id":"548b75f0-f442-44b9-9ca1-772a105abb12",
'memory_id': 'm1', "memory_id":"ea925981-272f-40dd-b576-be64e4871429",
'prev_value': 'Likes to play cricket on weekends', "old_memory":"Likes to play cricket and plays cricket on weekends.",
'new_value': 'Likes to play tennis on weekends', "new_memory":"Likes to play tennis on weekends",
'event': 'update', "event":"UPDATE",
'timestamp': '2024-07-14 10:15:17.230943', "created_at":"2024-07-26T10:29:36.630547-07:00",
'is_deleted': 0 "updated_at":"2024-07-26T10:32:46.332336-07:00"
} }
] ]
``` ```

View File

@@ -1,7 +1,9 @@
import logging import logging
import hashlib
import os import os
import time
import uuid import uuid
import pytz
from datetime import datetime
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, ValidationError from pydantic import BaseModel, Field, ValidationError
@@ -28,14 +30,15 @@ setup_config()
class MemoryItem(BaseModel): class MemoryItem(BaseModel):
id: str = Field(..., description="The unique identifier for the text data") id: str = Field(..., description="The unique identifier for the text data")
text: str = Field(..., description="The text content") memory: str = Field(..., description="The memory deduced from the text data") # TODO After prompt changes from platform, update this
hash: Optional[str] = Field(None, description="The hash of the memory")
# The metadata value can be anything and not just string. Fix it # The metadata value can be anything and not just string. Fix it
metadata: Dict[str, Any] = Field( metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata for the text data")
default_factory=dict, description="Additional metadata for the text data"
)
score: Optional[float] = Field( score: Optional[float] = Field(
None, description="The score associated with the text data" None, description="The score associated with the text data"
) )
created_at: Optional[str] = Field(None, description="The timestamp when the memory was created")
updated_at: Optional[str] = Field(None, description="The timestamp when the memory was updated")
class MemoryConfig(BaseModel): class MemoryConfig(BaseModel):
@@ -135,7 +138,7 @@ class Memory(MemoryBase):
id=mem.id, id=mem.id,
score=mem.score, score=mem.score,
metadata=mem.payload, metadata=mem.payload,
text=mem.payload["data"], memory=mem.payload["data"],
) )
for mem in existing_memories for mem in existing_memories
] ]
@@ -187,7 +190,7 @@ class Memory(MemoryBase):
{"memory_id": function_result, "function_name": function_name}, {"memory_id": function_result, "function_name": function_name},
) )
capture_event("mem0.add", self) capture_event("mem0.add", self)
return response return {"message": "ok"}
def get(self, memory_id): def get(self, memory_id):
""" """
@@ -203,12 +206,28 @@ class Memory(MemoryBase):
memory = self.vector_store.get(name=self.collection_name, vector_id=memory_id) memory = self.vector_store.get(name=self.collection_name, vector_id=memory_id)
if not memory: if not memory:
return None return None
return MemoryItem(
filters = {key: memory.payload[key] for key in ["user_id", "agent_id", "run_id"] if memory.payload.get(key)}
# Prepare base memory item
memory_item = MemoryItem(
id=memory.id, id=memory.id,
metadata=memory.payload, memory=memory.payload["data"],
text=memory.payload["data"], hash=memory.payload.get("hash"),
created_at=memory.payload.get("created_at"),
updated_at=memory.payload.get("updated_at"),
).model_dump(exclude={"score"}) ).model_dump(exclude={"score"})
# Add metadata if there are additional keys
excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
additional_metadata = {k: v for k, v in memory.payload.items() if k not in excluded_keys}
if additional_metadata:
memory_item["metadata"] = additional_metadata
result = {**memory_item, **filters}
return result
def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100): def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100):
""" """
List all memories. List all memories.
@@ -228,12 +247,21 @@ class Memory(MemoryBase):
memories = self.vector_store.list( memories = self.vector_store.list(
name=self.collection_name, filters=filters, limit=limit name=self.collection_name, filters=filters, limit=limit
) )
excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
return [ return [
MemoryItem( {
id=mem.id, **MemoryItem(
metadata=mem.payload, id=mem.id,
text=mem.payload["data"], memory=mem.payload["data"],
).model_dump(exclude={"score"}) hash=mem.payload.get("hash"),
created_at=mem.payload.get("created_at"),
updated_at=mem.payload.get("updated_at"),
).model_dump(exclude={"score"}),
**{key: mem.payload[key] for key in ["user_id", "agent_id", "run_id"] if key in mem.payload},
**({"metadata": {k: v for k, v in mem.payload.items() if k not in excluded_keys}}
if any(k for k in mem.payload if k not in excluded_keys) else {})
}
for mem in memories[0] for mem in memories[0]
] ]
@@ -267,13 +295,23 @@ class Memory(MemoryBase):
memories = self.vector_store.search( memories = self.vector_store.search(
name=self.collection_name, query=embeddings, limit=limit, filters=filters name=self.collection_name, query=embeddings, limit=limit, filters=filters
) )
excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
return [ return [
MemoryItem( {
id=mem.id, **MemoryItem(
metadata=mem.payload, id=mem.id,
score=mem.score, memory=mem.payload["data"],
text=mem.payload["data"], hash=mem.payload.get("hash"),
).model_dump() created_at=mem.payload.get("created_at"),
updated_at=mem.payload.get("updated_at"),
score=mem.score,
).model_dump(),
**{key: mem.payload[key] for key in ["user_id", "agent_id", "run_id"] if key in mem.payload},
**({"metadata": {k: v for k, v in mem.payload.items() if k not in excluded_keys}}
if any(k for k in mem.payload if k not in excluded_keys) else {})
}
for mem in memories for mem in memories
] ]
@@ -290,6 +328,7 @@ class Memory(MemoryBase):
""" """
capture_event("mem0.get_all", self, {"memory_id": memory_id}) capture_event("mem0.get_all", self, {"memory_id": memory_id})
self._update_memory_tool(memory_id, data) self._update_memory_tool(memory_id, data)
return {'message': 'Memory updated successfully!'}
def delete(self, memory_id): def delete(self, memory_id):
""" """
@@ -300,6 +339,7 @@ class Memory(MemoryBase):
""" """
capture_event("mem0.delete", self, {"memory_id": memory_id}) capture_event("mem0.delete", self, {"memory_id": memory_id})
self._delete_memory_tool(memory_id) self._delete_memory_tool(memory_id)
return {'message': 'Memory deleted successfully!'}
def delete_all(self, user_id=None, agent_id=None, run_id=None): def delete_all(self, user_id=None, agent_id=None, run_id=None):
""" """
@@ -327,6 +367,7 @@ class Memory(MemoryBase):
memories = self.vector_store.list(name=self.collection_name, filters=filters)[0] memories = self.vector_store.list(name=self.collection_name, filters=filters)[0]
for memory in memories: for memory in memories:
self._delete_memory_tool(memory.id) self._delete_memory_tool(memory.id)
return {'message': 'Memories deleted successfully!'}
def history(self, memory_id): def history(self, memory_id):
""" """
@@ -347,7 +388,8 @@ class Memory(MemoryBase):
memory_id = str(uuid.uuid4()) memory_id = str(uuid.uuid4())
metadata = metadata or {} metadata = metadata or {}
metadata["data"] = data metadata["data"] = data
metadata["created_at"] = int(time.time()) metadata["hash"] = hashlib.md5(data.encode()).hexdigest()
metadata["created_at"] = datetime.now(pytz.timezone('US/Pacific')).isoformat()
self.vector_store.insert( self.vector_store.insert(
name=self.collection_name, name=self.collection_name,
@@ -355,7 +397,7 @@ class Memory(MemoryBase):
ids=[memory_id], ids=[memory_id],
payloads=[metadata], payloads=[metadata],
) )
self.db.add_history(memory_id, None, data, "add") self.db.add_history(memory_id, None, data, "ADD", created_at=metadata["created_at"])
return memory_id return memory_id
def _update_memory_tool(self, memory_id, data, metadata=None): def _update_memory_tool(self, memory_id, data, metadata=None):
@@ -366,7 +408,8 @@ class Memory(MemoryBase):
new_metadata = metadata or {} new_metadata = metadata or {}
new_metadata["data"] = data new_metadata["data"] = data
new_metadata["updated_at"] = int(time.time()) new_metadata["created_at"] = existing_memory.payload.get("created_at")
new_metadata["updated_at"] = datetime.now(pytz.timezone('US/Pacific')).isoformat()
embeddings = self.embedding_model.embed(data) embeddings = self.embedding_model.embed(data)
self.vector_store.update( self.vector_store.update(
name=self.collection_name, name=self.collection_name,
@@ -375,7 +418,7 @@ class Memory(MemoryBase):
payload=new_metadata, payload=new_metadata,
) )
logging.info(f"Updating memory with ID {memory_id=} with {data=}") logging.info(f"Updating memory with ID {memory_id=} with {data=}")
self.db.add_history(memory_id, prev_value, data, "update") self.db.add_history(memory_id, prev_value, data, "UPDATE", created_at=new_metadata["created_at"], updated_at=new_metadata["updated_at"])
def _delete_memory_tool(self, memory_id): def _delete_memory_tool(self, memory_id):
logging.info(f"Deleting memory with {memory_id=}") logging.info(f"Deleting memory with {memory_id=}")
@@ -384,7 +427,7 @@ class Memory(MemoryBase):
) )
prev_value = existing_memory.payload["data"] prev_value = existing_memory.payload["data"]
self.vector_store.delete(name=self.collection_name, vector_id=memory_id) self.vector_store.delete(name=self.collection_name, vector_id=memory_id)
self.db.add_history(memory_id, prev_value, None, "delete", is_deleted=1) self.db.add_history(memory_id, prev_value, None, "DELETE", is_deleted=1)
def reset(self): def reset(self):
""" """

View File

@@ -6,8 +6,46 @@ from datetime import datetime
class SQLiteManager: class SQLiteManager:
def __init__(self, db_path=":memory:"): def __init__(self, db_path=":memory:"):
self.connection = sqlite3.connect(db_path, check_same_thread=False) self.connection = sqlite3.connect(db_path, check_same_thread=False)
self._migrate_history_table()
self._create_history_table() self._create_history_table()
def _migrate_history_table(self):
    """Upgrade a legacy ``history`` table to the current column layout.

    The legacy layout stores ``prev_value`` / ``new_value`` / ``timestamp``;
    the current one stores ``old_memory`` / ``new_memory`` / ``created_at`` /
    ``updated_at`` (plus a retained ``new_value`` column). Existing rows are
    copied across, with the legacy ``timestamp`` used for both
    ``created_at`` and ``updated_at``.

    The method is a no-op when there is no ``history`` table or when the
    table does not have the legacy columns (for example because it was
    already migrated), so it is safe to call on every start-up.
    """
    # Every column the copy statement below reads from the old table.
    legacy_columns = {
        "id",
        "memory_id",
        "prev_value",
        "new_value",
        "event",
        "timestamp",
        "is_deleted",
    }

    # The connection context manager commits on success and rolls back
    # on an exception, so no explicit commit() is needed.
    with self.connection:
        cursor = self.connection.cursor()

        # table_info yields one row per column (name at index 1) and no
        # rows at all when the table does not exist.
        cursor.execute("PRAGMA table_info(history)")
        existing_columns = {row[1] for row in cursor.fetchall()}
        if not legacy_columns.issubset(existing_columns):
            # Nothing to migrate: re-running the copy against a table in
            # the new layout would fail on the missing legacy columns.
            return

        # Rename the old table
        cursor.execute("ALTER TABLE history RENAME TO old_history")

        cursor.execute("""
            CREATE TABLE IF NOT EXISTS history (
                id TEXT PRIMARY KEY,
                memory_id TEXT,
                old_memory TEXT,
                new_memory TEXT,
                new_value TEXT,
                event TEXT,
                created_at DATETIME,
                updated_at DATETIME,
                is_deleted INTEGER
            )
        """)

        # Copy data from the old table to the new table
        cursor.execute("""
            INSERT INTO history (id, memory_id, old_memory, new_memory, new_value, event, created_at, updated_at, is_deleted)
            SELECT id, memory_id, prev_value, new_value, new_value, event, timestamp, timestamp, is_deleted
            FROM old_history
        """)

        cursor.execute("DROP TABLE old_history")
def _create_history_table(self): def _create_history_table(self):
with self.connection: with self.connection:
self.connection.execute( self.connection.execute(
@@ -15,29 +53,32 @@ class SQLiteManager:
CREATE TABLE IF NOT EXISTS history ( CREATE TABLE IF NOT EXISTS history (
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY,
memory_id TEXT, memory_id TEXT,
prev_value TEXT, old_memory TEXT,
new_memory TEXT,
new_value TEXT, new_value TEXT,
event TEXT, event TEXT,
timestamp DATETIME, created_at DATETIME,
updated_at DATETIME,
is_deleted INTEGER is_deleted INTEGER
) )
""" """
) )
def add_history(self, memory_id, prev_value, new_value, event, is_deleted=0): def add_history(self, memory_id, old_memory, new_memory, event, created_at = None, updated_at = None, is_deleted=0):
with self.connection: with self.connection:
self.connection.execute( self.connection.execute(
""" """
INSERT INTO history (id, memory_id, prev_value, new_value, event, timestamp, is_deleted) INSERT INTO history (id, memory_id, old_memory, new_memory, event, created_at, updated_at, is_deleted)
VALUES (?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""", """,
( (
str(uuid.uuid4()), str(uuid.uuid4()),
memory_id, memory_id,
prev_value, old_memory,
new_value, new_memory,
event, event,
datetime.utcnow(), created_at,
updated_at,
is_deleted, is_deleted,
), ),
) )
@@ -45,10 +86,10 @@ class SQLiteManager:
def get_history(self, memory_id): def get_history(self, memory_id):
cursor = self.connection.execute( cursor = self.connection.execute(
""" """
SELECT id, memory_id, prev_value, new_value, event, timestamp, is_deleted SELECT id, memory_id, old_memory, new_memory, event, created_at, updated_at
FROM history FROM history
WHERE memory_id = ? WHERE memory_id = ?
ORDER BY timestamp ASC ORDER BY updated_at ASC
""", """,
(memory_id,), (memory_id,),
) )
@@ -57,11 +98,11 @@ class SQLiteManager:
{ {
"id": row[0], "id": row[0],
"memory_id": row[1], "memory_id": row[1],
"prev_value": row[2], "old_memory": row[2],
"new_value": row[3], "new_memory": row[3],
"event": row[4], "event": row[4],
"timestamp": row[5], "created_at": row[5],
"is_deleted": row[6], "updated_at": row[6],
} }
for row in rows for row in rows
] ]

13
poetry.lock generated
View File

@@ -897,6 +897,17 @@ files = [
[package.dependencies] [package.dependencies]
six = ">=1.5" six = ">=1.5"
[[package]]
name = "pytz"
version = "2024.1"
description = "World timezone definitions, modern and historical"
optional = false
python-versions = "*"
files = [
{file = "pytz-2024.1-py2.py3-none-any.whl", hash = "sha256:328171f4e3623139da4983451950b28e95ac706e13f3f2630a879749e7a8b319"},
{file = "pytz-2024.1.tar.gz", hash = "sha256:2a29735ea9c18baf14b448846bde5a48030ed267578472d8955cd0e7443a9812"},
]
[[package]] [[package]]
name = "pywin32" name = "pywin32"
version = "306" version = "306"
@@ -1180,4 +1191,4 @@ zstd = ["zstandard (>=0.18.0)"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.8" python-versions = "^3.8"
content-hash = "8c9526d9748b01cf2e43f6764f3aa841a987a3bd218d62506dfa5f66c0bf7887" content-hash = "984fce48f87c2279c9c9caa8696ab9f70995506c799efa8b9818cc56a927d10a"

View File

@@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "mem0ai" name = "mem0ai"
version = "0.0.9" version = "0.0.10"
description = "Long-term memory for AI Agents" description = "Long-term memory for AI Agents"
authors = ["Mem0 <founders@mem0.ai>"] authors = ["Mem0 <founders@mem0.ai>"]
exclude = [ exclude = [
@@ -20,6 +20,7 @@ qdrant-client = "^1.9.1"
pydantic = "^2.7.3" pydantic = "^2.7.3"
openai = "^1.33.0" openai = "^1.33.0"
posthog = "^3.5.0" posthog = "^3.5.0"
pytz = "^2024.1"
sqlalchemy = "^2.0.31" sqlalchemy = "^2.0.31"
[tool.poetry.group.test.dependencies] [tool.poetry.group.test.dependencies]