[Add] Error handling for update method in OSS & platform code. (#1939)

This commit is contained in:
Parshva Daftari
2024-10-08 15:04:59 +05:30
committed by GitHub
parent ab862d0d40
commit c689f94c52
3 changed files with 15 additions and 3 deletions

View File

@@ -197,6 +197,7 @@ class MemoryClient:
"""
capture_client_event("client.update", self)
response = self.client.put(f"/v1/memories/{memory_id}/", json={"text": data})
response.raise_for_status()
return response.json()
@api_error_handler

View File

@@ -453,7 +453,10 @@ class Memory(MemoryBase):
dict: Updated memory.
"""
capture_event("mem0.update", self, {"memory_id": memory_id})
self._update_memory(memory_id, data)
existing_embeddings = {data: self.embedding_model.embed(data)}
self._update_memory(memory_id, data, existing_embeddings)
return {"message": "Memory updated successfully!"}
def delete(self, memory_id):
@@ -536,7 +539,11 @@ class Memory(MemoryBase):
def _update_memory(self, memory_id, data, existing_embeddings, metadata=None):
logger.info(f"Updating memory with {data=}")
existing_memory = self.vector_store.get(vector_id=memory_id)
try:
existing_memory = self.vector_store.get(vector_id=memory_id)
except Exception as e:
raise ValueError(f"Error getting memory with ID {memory_id}. Please provide a valid 'memory_id'")
prev_value = existing_memory.payload.get("data")
new_metadata = metadata or {}

View File

@@ -127,11 +127,15 @@ def test_search(memory_instance, version, enable_graph):
def test_update(memory_instance):
memory_instance.embedding_model = Mock()
memory_instance.embedding_model.embed = Mock(return_value=[0.1, 0.2, 0.3])
memory_instance._update_memory = Mock()
result = memory_instance.update("test_id", "Updated memory")
memory_instance._update_memory.assert_called_once_with("test_id", "Updated memory")
memory_instance._update_memory.assert_called_once_with("test_id", "Updated memory", {"Updated memory": [0.1, 0.2, 0.3]})
assert result["message"] == "Memory updated successfully!"