Match output format with APIs (#1595)
@@ -68,13 +68,7 @@ print(result)
Output:
```python
[
    {
        'id': 'm1',
        'event': 'add',
        'data': 'Likes to play cricket on weekends'
    }
]
{'message': 'ok'}
```
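
For context, the output above comes from an `add` call like the following (a minimal sketch, assuming a `Memory()` client named `m` as the surrounding docs use):

```python
from mem0 import Memory

m = Memory()
# Store a memory; the method now returns the API-style acknowledgement shown above
result = m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"})
print(result)
```
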
### Retrieve Memories

@@ -89,15 +83,15 @@ Output:

```python
[
    {
        'id': 'm1',
        'text': 'Likes to play cricket on weekends',
        'metadata': {
            'data': 'Likes to play cricket on weekends',
            'category': 'hobbies'
        }
    },
    # ... other memories ...
    {
        "id":"13efe83b-a8df-4ec0-814e-428d6e8451eb",
        "memory":"Likes to play cricket on weekends",
        "hash":"87bcddeb-fe45-4353-bc22-15a841c50308",
        "metadata":"None",
        "created_at":"2024-07-26T08:44:41.039788-07:00",
        "updated_at":"None",
        "user_id":"alice"
    }
]
```
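
These entries correspond to listing all of a user's memories; a short sketch of the call, assuming the same `m` client:

```python
# List every memory stored for the user
all_memories = m.get_all(user_id="alice")
print(all_memories)
```
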
@@ -110,12 +104,13 @@ print(specific_memory)
Output:
```python
{
    'id': 'm1',
    'text': 'Likes to play cricket on weekends',
    'metadata': {
        'data': 'Likes to play cricket on weekends',
        'category': 'hobbies'
    }
    "id":"13efe83b-a8df-4ec0-814e-428d6e8451eb",
    "memory":"Likes to play cricket on weekends",
    "hash":"87bcddeb-fe45-4353-bc22-15a841c50308",
    "metadata":"None",
    "created_at":"2024-07-26T08:44:41.039788-07:00",
    "updated_at":"None",
    "user_id":"alice"
}
```
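
The single-memory output above is produced by `get`; a minimal sketch, reusing the id from the example output:

```python
# Fetch one memory by its id
specific_memory = m.get("13efe83b-a8df-4ec0-814e-428d6e8451eb")
print(specific_memory)
```
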
@@ -130,16 +125,18 @@ Output:

```python
[
    {
        'id': 'm1',
        'text': 'Likes to play cricket on weekends',
        'metadata': {
            'data': 'Likes to play cricket on weekends',
            'category': 'hobbies'
        },
        'score': 0.85  # Similarity score
    },
    # ... other related memories ...
    {
        "id":"ea925981-272f-40dd-b576-be64e4871429",
        "memory":"Likes to play cricket and plays cricket on weekends.",
        "hash":"c8809002-25c1-4c97-a3a2-227ce9c20c53",
        "metadata":{
            "category":"hobbies"
        },
        "score":0.32116443111457704,
        "created_at":"2024-07-26T10:29:36.630547-07:00",
        "updated_at":"None",
        "user_id":"alice"
    }
]
```
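
These results come from a semantic `search` over the user's memories; a sketch with an illustrative query string (the query text and keyword form are assumptions, not taken from this diff):

```python
# Search memories by meaning; each result now carries a similarity score
related_memories = m.search(query="What are Alice's hobbies?", user_id="alice")
print(related_memories)
```
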
@@ -153,11 +150,7 @@ print(result)
Output:

```python
{
    'id': 'm1',
    'event': 'update',
    'data': 'Likes to play tennis on weekends'
}
{'message': 'Memory updated successfully!'}
```
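
The confirmation above is what `update` now returns; a minimal sketch, reusing the memory id from the search output:

```python
# Rewrite an existing memory in place
result = m.update(memory_id="ea925981-272f-40dd-b576-be64e4871429", data="Likes to play tennis on weekends")
print(result)
```
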
### Memory History

@@ -169,24 +162,24 @@ print(history)
Output:
```python
[
    {
        'id': 'h1',
        'memory_id': 'm1',
        'prev_value': None,
        'new_value': 'Likes to play cricket on weekends',
        'event': 'add',
        'timestamp': '2024-07-14 10:00:54.466687',
        'is_deleted': 0
    },
    {
        'id': 'h2',
        'memory_id': 'm1',
        'prev_value': 'Likes to play cricket on weekends',
        'new_value': 'Likes to play tennis on weekends',
        'event': 'update',
        'timestamp': '2024-07-14 10:15:17.230943',
        'is_deleted': 0
    }
    {
        "id":"4e0e63d6-a9c6-43c0-b11c-a1bad3fc7abb",
        "memory_id":"ea925981-272f-40dd-b576-be64e4871429",
        "old_memory":"None",
        "new_memory":"Likes to play cricket and plays cricket on weekends.",
        "event":"ADD",
        "created_at":"2024-07-26T10:29:36.630547-07:00",
        "updated_at":"None"
    },
    {
        "id":"548b75f0-f442-44b9-9ca1-772a105abb12",
        "memory_id":"ea925981-272f-40dd-b576-be64e4871429",
        "old_memory":"Likes to play cricket and plays cricket on weekends.",
        "new_memory":"Likes to play tennis on weekends",
        "event":"UPDATE",
        "created_at":"2024-07-26T10:29:36.630547-07:00",
        "updated_at":"2024-07-26T10:32:46.332336-07:00"
    }
]
```
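
The history entries above are returned by `history`; a sketch, again assuming the `m` client and an existing memory id:

```python
# Inspect the audit trail of a memory; entries now use old_memory/new_memory and created_at/updated_at
history = m.history(memory_id="ea925981-272f-40dd-b576-be64e4871429")
print(history)
```
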

@@ -1,7 +1,9 @@
import logging
import hashlib
import os
import time
import uuid
import pytz
from datetime import datetime
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field, ValidationError
@@ -28,14 +30,15 @@ setup_config()

class MemoryItem(BaseModel):
    id: str = Field(..., description="The unique identifier for the text data")
    text: str = Field(..., description="The text content")
    memory: str = Field(..., description="The memory deduced from the text data")  # TODO After prompt changes from platform, update this
    hash: Optional[str] = Field(None, description="The hash of the memory")
    # The metadata value can be anything and not just string. Fix it
    metadata: Dict[str, Any] = Field(
        default_factory=dict, description="Additional metadata for the text data"
    )
    metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata for the text data")
    score: Optional[float] = Field(
        None, description="The score associated with the text data"
    )
    created_at: Optional[str] = Field(None, description="The timestamp when the memory was created")
    updated_at: Optional[str] = Field(None, description="The timestamp when the memory was updated")


class MemoryConfig(BaseModel):
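
In this hunk the `metadata` field becomes optional (defaulting to `None` instead of an empty dict), and `memory`, `hash`, `created_at`, and `updated_at` are introduced. A quick sketch of the updated model in use (values lifted from the docs above; assumes the `MemoryItem` class from this diff is in scope):

```python
# Build an item and serialize it the way get()/get_all() now do, dropping the unset score
item = MemoryItem(
    id="13efe83b-a8df-4ec0-814e-428d6e8451eb",
    text="Likes to play cricket on weekends",
    memory="Likes to play cricket on weekends",
    hash="87bcddeb-fe45-4353-bc22-15a841c50308",
    created_at="2024-07-26T08:44:41.039788-07:00",
)
print(item.model_dump(exclude={"score"}))
```
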
@@ -135,7 +138,7 @@ class Memory(MemoryBase):
                id=mem.id,
                score=mem.score,
                metadata=mem.payload,
                text=mem.payload["data"],
                memory=mem.payload["data"],
            )
            for mem in existing_memories
        ]
@@ -187,7 +190,7 @@ class Memory(MemoryBase):
                    {"memory_id": function_result, "function_name": function_name},
                )
        capture_event("mem0.add", self)
        return response
        return {"message": "ok"}

    def get(self, memory_id):
        """
@@ -203,11 +206,27 @@ class Memory(MemoryBase):
        memory = self.vector_store.get(name=self.collection_name, vector_id=memory_id)
        if not memory:
            return None
        return MemoryItem(

        filters = {key: memory.payload[key] for key in ["user_id", "agent_id", "run_id"] if memory.payload.get(key)}

        # Prepare base memory item
        memory_item = MemoryItem(
            id=memory.id,
            metadata=memory.payload,
            text=memory.payload["data"],
            memory=memory.payload["data"],
            hash=memory.payload.get("hash"),
            created_at=memory.payload.get("created_at"),
            updated_at=memory.payload.get("updated_at"),
        ).model_dump(exclude={"score"})

        # Add metadata if there are additional keys
        excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
        additional_metadata = {k: v for k, v in memory.payload.items() if k not in excluded_keys}
        if additional_metadata:
            memory_item["metadata"] = additional_metadata

        result = {**memory_item, **filters}

        return result

    def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100):
        """
@@ -228,12 +247,21 @@ class Memory(MemoryBase):
        memories = self.vector_store.list(
            name=self.collection_name, filters=filters, limit=limit
        )

        excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
        return [
            MemoryItem(
                id=mem.id,
                metadata=mem.payload,
                text=mem.payload["data"],
            ).model_dump(exclude={"score"})
            {
                **MemoryItem(
                    id=mem.id,
                    memory=mem.payload["data"],
                    hash=mem.payload.get("hash"),
                    created_at=mem.payload.get("created_at"),
                    updated_at=mem.payload.get("updated_at"),
                ).model_dump(exclude={"score"}),
                **{key: mem.payload[key] for key in ["user_id", "agent_id", "run_id"] if key in mem.payload},
                **({"metadata": {k: v for k, v in mem.payload.items() if k not in excluded_keys}}
                   if any(k for k in mem.payload if k not in excluded_keys) else {})
            }
            for mem in memories[0]
        ]
@@ -267,13 +295,23 @@ class Memory(MemoryBase):
        memories = self.vector_store.search(
            name=self.collection_name, query=embeddings, limit=limit, filters=filters
        )

        excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}

        return [
            MemoryItem(
                id=mem.id,
                metadata=mem.payload,
                score=mem.score,
                text=mem.payload["data"],
            ).model_dump()
            {
                **MemoryItem(
                    id=mem.id,
                    memory=mem.payload["data"],
                    hash=mem.payload.get("hash"),
                    created_at=mem.payload.get("created_at"),
                    updated_at=mem.payload.get("updated_at"),
                    score=mem.score,
                ).model_dump(),
                **{key: mem.payload[key] for key in ["user_id", "agent_id", "run_id"] if key in mem.payload},
                **({"metadata": {k: v for k, v in mem.payload.items() if k not in excluded_keys}}
                   if any(k for k in mem.payload if k not in excluded_keys) else {})
            }
            for mem in memories
        ]
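
`get`, `get_all`, and `search` now share the same payload-splitting pattern: core fields go through `MemoryItem`, session identifiers are surfaced at the top level of the result, and any remaining keys are grouped under `metadata`. A standalone sketch of that split (the payload dict below is illustrative, built from values in the docs above):

```python
# Illustrative payload as stored in the vector store
payload = {
    "data": "Likes to play cricket on weekends",
    "hash": "87bcddeb-fe45-4353-bc22-15a841c50308",
    "created_at": "2024-07-26T08:44:41.039788-07:00",
    "user_id": "alice",
    "category": "hobbies",
}

excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}

# Session identifiers are promoted to the top level of the result
filters = {k: payload[k] for k in ["user_id", "agent_id", "run_id"] if k in payload}

# Anything that is neither a core field nor a session identifier becomes user-facing metadata
extra_metadata = {k: v for k, v in payload.items() if k not in excluded_keys}

print(filters)         # {'user_id': 'alice'}
print(extra_metadata)  # {'category': 'hobbies'}
```
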
@@ -290,6 +328,7 @@ class Memory(MemoryBase):
        """
        capture_event("mem0.get_all", self, {"memory_id": memory_id})
        self._update_memory_tool(memory_id, data)
        return {'message': 'Memory updated successfully!'}

    def delete(self, memory_id):
        """
@@ -300,6 +339,7 @@ class Memory(MemoryBase):
        """
        capture_event("mem0.delete", self, {"memory_id": memory_id})
        self._delete_memory_tool(memory_id)
        return {'message': 'Memory deleted successfully!'}

    def delete_all(self, user_id=None, agent_id=None, run_id=None):
        """
@@ -327,6 +367,7 @@ class Memory(MemoryBase):
        memories = self.vector_store.list(name=self.collection_name, filters=filters)[0]
        for memory in memories:
            self._delete_memory_tool(memory.id)
        return {'message': 'Memories deleted successfully!'}

    def history(self, memory_id):
        """
@@ -347,7 +388,8 @@ class Memory(MemoryBase):
        memory_id = str(uuid.uuid4())
        metadata = metadata or {}
        metadata["data"] = data
        metadata["created_at"] = int(time.time())
        metadata["hash"] = hashlib.md5(data.encode()).hexdigest()
        metadata["created_at"] = datetime.now(pytz.timezone('US/Pacific')).isoformat()

        self.vector_store.insert(
            name=self.collection_name,
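
A self-contained sketch of the new metadata preparation, an MD5 content hash plus a timezone-aware ISO-8601 timestamp pinned to US/Pacific as in the diff (the sample string is taken from the docs above):

```python
import hashlib
from datetime import datetime

import pytz

data = "Likes to play cricket on weekends"
metadata = {"data": data}
metadata["hash"] = hashlib.md5(data.encode()).hexdigest()          # 32-character hex digest of the content
metadata["created_at"] = datetime.now(pytz.timezone("US/Pacific")).isoformat()  # e.g. '2024-07-26T08:44:41.039788-07:00'
print(metadata)
```
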
@@ -355,7 +397,7 @@ class Memory(MemoryBase):
            ids=[memory_id],
            payloads=[metadata],
        )
        self.db.add_history(memory_id, None, data, "add")
        self.db.add_history(memory_id, None, data, "ADD", created_at=metadata["created_at"])
        return memory_id

    def _update_memory_tool(self, memory_id, data, metadata=None):
@@ -366,7 +408,8 @@ class Memory(MemoryBase):

        new_metadata = metadata or {}
        new_metadata["data"] = data
        new_metadata["updated_at"] = int(time.time())
        new_metadata["created_at"] = existing_memory.payload.get("created_at")
        new_metadata["updated_at"] = datetime.now(pytz.timezone('US/Pacific')).isoformat()
        embeddings = self.embedding_model.embed(data)
        self.vector_store.update(
            name=self.collection_name,
@@ -375,7 +418,7 @@ class Memory(MemoryBase):
            payload=new_metadata,
        )
        logging.info(f"Updating memory with ID {memory_id=} with {data=}")
        self.db.add_history(memory_id, prev_value, data, "update")
        self.db.add_history(memory_id, prev_value, data, "UPDATE", created_at=new_metadata["created_at"], updated_at=new_metadata["updated_at"])

    def _delete_memory_tool(self, memory_id):
        logging.info(f"Deleting memory with {memory_id=}")
@@ -384,7 +427,7 @@ class Memory(MemoryBase):
        )
        prev_value = existing_memory.payload["data"]
        self.vector_store.delete(name=self.collection_name, vector_id=memory_id)
        self.db.add_history(memory_id, prev_value, None, "delete", is_deleted=1)
        self.db.add_history(memory_id, prev_value, None, "DELETE", is_deleted=1)

    def reset(self):
        """

@@ -6,7 +6,45 @@ from datetime import datetime
class SQLiteManager:
    def __init__(self, db_path=":memory:"):
        self.connection = sqlite3.connect(db_path, check_same_thread=False)
        self._migrate_history_table()
        self._create_history_table()

    def _migrate_history_table(self):
        with self.connection:
            cursor = self.connection.cursor()

            cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='history'")
            table_exists = cursor.fetchone() is not None

            if table_exists:
                # Rename the old table
                cursor.execute("ALTER TABLE history RENAME TO old_history")

                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS history (
                        id TEXT PRIMARY KEY,
                        memory_id TEXT,
                        old_memory TEXT,
                        new_memory TEXT,
                        new_value TEXT,
                        event TEXT,
                        created_at DATETIME,
                        updated_at DATETIME,
                        is_deleted INTEGER
                    )
                """)

                # Copy data from the old table to the new table
                cursor.execute("""
                    INSERT INTO history (id, memory_id, old_memory, new_memory, new_value, event, created_at, updated_at, is_deleted)
                    SELECT id, memory_id, prev_value, new_value, new_value, event, timestamp, timestamp, is_deleted
                    FROM old_history
                """)

                cursor.execute("DROP TABLE old_history")

                self.connection.commit()

    def _create_history_table(self):
        with self.connection:
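
A minimal, self-contained sketch of what this migration does, using an in-memory database and a sample row; it mirrors the rename, create, copy, drop flow above rather than reproducing the class verbatim:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
with conn:
    cur = conn.cursor()
    # Simulate an existing pre-migration table with the old column names
    cur.execute("CREATE TABLE history (id TEXT PRIMARY KEY, memory_id TEXT, prev_value TEXT, new_value TEXT, event TEXT, timestamp DATETIME, is_deleted INTEGER)")
    cur.execute("INSERT INTO history VALUES ('h1', 'm1', NULL, 'Likes to play cricket on weekends', 'add', '2024-07-14 10:00:54', 0)")

    # Rename the old table, create the new schema, copy rows across, drop the old table
    cur.execute("ALTER TABLE history RENAME TO old_history")
    cur.execute("""
        CREATE TABLE history (
            id TEXT PRIMARY KEY, memory_id TEXT, old_memory TEXT, new_memory TEXT,
            new_value TEXT, event TEXT, created_at DATETIME, updated_at DATETIME, is_deleted INTEGER
        )
    """)
    cur.execute("""
        INSERT INTO history (id, memory_id, old_memory, new_memory, new_value, event, created_at, updated_at, is_deleted)
        SELECT id, memory_id, prev_value, new_value, new_value, event, timestamp, timestamp, is_deleted
        FROM old_history
    """)
    cur.execute("DROP TABLE old_history")

print(conn.execute("SELECT old_memory, new_memory, created_at FROM history").fetchall())
# [(None, 'Likes to play cricket on weekends', '2024-07-14 10:00:54')]
```
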
@@ -15,29 +53,32 @@ class SQLiteManager:
                CREATE TABLE IF NOT EXISTS history (
                    id TEXT PRIMARY KEY,
                    memory_id TEXT,
                    prev_value TEXT,
                    old_memory TEXT,
                    new_memory TEXT,
                    new_value TEXT,
                    event TEXT,
                    timestamp DATETIME,
                    created_at DATETIME,
                    updated_at DATETIME,
                    is_deleted INTEGER
                )
                """
            )

    def add_history(self, memory_id, prev_value, new_value, event, is_deleted=0):
    def add_history(self, memory_id, old_memory, new_memory, event, created_at = None, updated_at = None, is_deleted=0):
        with self.connection:
            self.connection.execute(
                """
                INSERT INTO history (id, memory_id, prev_value, new_value, event, timestamp, is_deleted)
                VALUES (?, ?, ?, ?, ?, ?, ?)
                INSERT INTO history (id, memory_id, old_memory, new_memory, event, created_at, updated_at, is_deleted)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    str(uuid.uuid4()),
                    memory_id,
                    prev_value,
                    new_value,
                    old_memory,
                    new_memory,
                    event,
                    datetime.utcnow(),
                    created_at,
                    updated_at,
                    is_deleted,
                ),
            )
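
With the new signature, callers pass the event timestamps explicitly instead of relying on `datetime.utcnow()` inside the manager. A hedged usage sketch, assuming the `SQLiteManager` class from this diff is in scope (ids and timestamps are illustrative, taken from the docs above):

```python
# Sketch: how main.py now records history
db = SQLiteManager()  # defaults to an in-memory database
db.add_history(
    "ea925981-272f-40dd-b576-be64e4871429",   # memory_id
    None,                                     # old_memory
    "Likes to play cricket on weekends",      # new_memory
    "ADD",
    created_at="2024-07-26T10:29:36.630547-07:00",
)
print(db.get_history("ea925981-272f-40dd-b576-be64e4871429"))
```
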
@@ -45,10 +86,10 @@ class SQLiteManager:
    def get_history(self, memory_id):
        cursor = self.connection.execute(
            """
            SELECT id, memory_id, prev_value, new_value, event, timestamp, is_deleted
            SELECT id, memory_id, old_memory, new_memory, event, created_at, updated_at
            FROM history
            WHERE memory_id = ?
            ORDER BY timestamp ASC
            ORDER BY updated_at ASC
            """,
            (memory_id,),
        )
@@ -57,11 +98,11 @@ class SQLiteManager:
            {
                "id": row[0],
                "memory_id": row[1],
                "prev_value": row[2],
                "new_value": row[3],
                "old_memory": row[2],
                "new_memory": row[3],
                "event": row[4],
                "timestamp": row[5],
                "is_deleted": row[6],
                "created_at": row[5],
                "updated_at": row[6],
            }
            for row in rows
        ]

poetry.lock (generated)
@@ -897,6 +897,17 @@ files = [
[package.dependencies]
six = ">=1.5"

[[package]]
name = "pytz"
version = "2024.1"
description = "World timezone definitions, modern and historical"
optional = false
python-versions = "*"
files = [
    {file = "pytz-2024.1-py2.py3-none-any.whl", hash = "sha256:328171f4e3623139da4983451950b28e95ac706e13f3f2630a879749e7a8b319"},
    {file = "pytz-2024.1.tar.gz", hash = "sha256:2a29735ea9c18baf14b448846bde5a48030ed267578472d8955cd0e7443a9812"},
]

[[package]]
name = "pywin32"
version = "306"
@@ -1180,4 +1191,4 @@ zstd = ["zstandard (>=0.18.0)"]
[metadata]
lock-version = "2.0"
python-versions = "^3.8"
content-hash = "8c9526d9748b01cf2e43f6764f3aa841a987a3bd218d62506dfa5f66c0bf7887"
content-hash = "984fce48f87c2279c9c9caa8696ab9f70995506c799efa8b9818cc56a927d10a"

@@ -1,6 +1,6 @@
[tool.poetry]
name = "mem0ai"
version = "0.0.9"
version = "0.0.10"
description = "Long-term memory for AI Agents"
authors = ["Mem0 <founders@mem0.ai>"]
exclude = [
@@ -20,6 +20,7 @@ qdrant-client = "^1.9.1"
pydantic = "^2.7.3"
openai = "^1.33.0"
posthog = "^3.5.0"
pytz = "^2024.1"
sqlalchemy = "^2.0.31"

[tool.poetry.group.test.dependencies]