diff --git a/mem0/memory/main.py b/mem0/memory/main.py index 8a0cc1ac..d97c344f 100644 --- a/mem0/memory/main.py +++ b/mem0/memory/main.py @@ -148,8 +148,10 @@ class Memory(MemoryBase): new_retrieved_facts = [] retrieved_old_memory = [] + new_message_embeddings = {} for new_mem in new_retrieved_facts: messages_embeddings = self.embedding_model.embed(new_mem) + new_message_embeddings[new_mem] = messages_embeddings existing_memories = self.vector_store.search( query=messages_embeddings, limit=5, @@ -173,7 +175,7 @@ class Memory(MemoryBase): logging.info(resp) try: if resp["event"] == "ADD": - _ = self._create_memory(data=resp["text"], metadata=metadata) + _ = self._create_memory(data=resp["text"], existing_embeddings=new_message_embeddings, metadata=metadata) returned_memories.append( { "memory": resp["text"], @@ -181,7 +183,7 @@ class Memory(MemoryBase): } ) elif resp["event"] == "UPDATE": - self._update_memory(memory_id=resp["id"], data=resp["text"], metadata=metadata) + self._update_memory(memory_id=resp["id"], data=resp["text"], existing_embeddings=new_message_embeddings, metadata=metadata) returned_memories.append( { "memory": resp["text"], @@ -504,9 +506,12 @@ class Memory(MemoryBase): capture_event("mem0.history", self, {"memory_id": memory_id}) return self.db.get_history(memory_id) - def _create_memory(self, data, metadata=None): + def _create_memory(self, data, existing_embeddings, metadata=None): logging.info(f"Creating memory with {data=}") - embeddings = self.embedding_model.embed(data) + if data in existing_embeddings: + embeddings = existing_embeddings[data] + else: + embeddings = self.embedding_model.embed(data) memory_id = str(uuid.uuid4()) metadata = metadata or {} metadata["data"] = data @@ -521,7 +526,7 @@ class Memory(MemoryBase): self.db.add_history(memory_id, None, data, "ADD", created_at=metadata["created_at"]) return memory_id - def _update_memory(self, memory_id, data, metadata=None): + def _update_memory(self, memory_id, data, existing_embeddings, metadata=None): logger.info(f"Updating memory with {data=}") existing_memory = self.vector_store.get(vector_id=memory_id) prev_value = existing_memory.payload.get("data") @@ -539,7 +544,10 @@ class Memory(MemoryBase): if "run_id" in existing_memory.payload: new_metadata["run_id"] = existing_memory.payload["run_id"] - embeddings = self.embedding_model.embed(data) + if data in existing_embeddings: + embeddings = existing_embeddings[data] + else: + embeddings = self.embedding_model.embed(data) self.vector_store.update( vector_id=memory_id, vector=embeddings,