[improvement]: Duplicate embedding generation removed. (#1900)

This commit is contained in:
Mayank
2024-09-25 09:54:30 +05:30
committed by GitHub
parent 3914f4d6ac
commit 5525c4e6fe

View File

@@ -148,8 +148,10 @@ class Memory(MemoryBase):
new_retrieved_facts = [] new_retrieved_facts = []
retrieved_old_memory = [] retrieved_old_memory = []
new_message_embeddings = {}
for new_mem in new_retrieved_facts: for new_mem in new_retrieved_facts:
messages_embeddings = self.embedding_model.embed(new_mem) messages_embeddings = self.embedding_model.embed(new_mem)
new_message_embeddings[new_mem] = messages_embeddings
existing_memories = self.vector_store.search( existing_memories = self.vector_store.search(
query=messages_embeddings, query=messages_embeddings,
limit=5, limit=5,
@@ -173,7 +175,7 @@ class Memory(MemoryBase):
logging.info(resp) logging.info(resp)
try: try:
if resp["event"] == "ADD": if resp["event"] == "ADD":
_ = self._create_memory(data=resp["text"], metadata=metadata) _ = self._create_memory(data=resp["text"], existing_embeddings=new_message_embeddings, metadata=metadata)
returned_memories.append( returned_memories.append(
{ {
"memory": resp["text"], "memory": resp["text"],
@@ -181,7 +183,7 @@ class Memory(MemoryBase):
} }
) )
elif resp["event"] == "UPDATE": elif resp["event"] == "UPDATE":
self._update_memory(memory_id=resp["id"], data=resp["text"], metadata=metadata) self._update_memory(memory_id=resp["id"], data=resp["text"], existing_embeddings=new_message_embeddings, metadata=metadata)
returned_memories.append( returned_memories.append(
{ {
"memory": resp["text"], "memory": resp["text"],
@@ -504,9 +506,12 @@ class Memory(MemoryBase):
capture_event("mem0.history", self, {"memory_id": memory_id}) capture_event("mem0.history", self, {"memory_id": memory_id})
return self.db.get_history(memory_id) return self.db.get_history(memory_id)
def _create_memory(self, data, metadata=None): def _create_memory(self, data, existing_embeddings, metadata=None):
logging.info(f"Creating memory with {data=}") logging.info(f"Creating memory with {data=}")
embeddings = self.embedding_model.embed(data) if data in existing_embeddings:
embeddings = existing_embeddings[data]
else:
embeddings = self.embedding_model.embed(data)
memory_id = str(uuid.uuid4()) memory_id = str(uuid.uuid4())
metadata = metadata or {} metadata = metadata or {}
metadata["data"] = data metadata["data"] = data
@@ -521,7 +526,7 @@ class Memory(MemoryBase):
self.db.add_history(memory_id, None, data, "ADD", created_at=metadata["created_at"]) self.db.add_history(memory_id, None, data, "ADD", created_at=metadata["created_at"])
return memory_id return memory_id
def _update_memory(self, memory_id, data, metadata=None): def _update_memory(self, memory_id, data, existing_embeddings, metadata=None):
logger.info(f"Updating memory with {data=}") logger.info(f"Updating memory with {data=}")
existing_memory = self.vector_store.get(vector_id=memory_id) existing_memory = self.vector_store.get(vector_id=memory_id)
prev_value = existing_memory.payload.get("data") prev_value = existing_memory.payload.get("data")
@@ -539,7 +544,10 @@ class Memory(MemoryBase):
if "run_id" in existing_memory.payload: if "run_id" in existing_memory.payload:
new_metadata["run_id"] = existing_memory.payload["run_id"] new_metadata["run_id"] = existing_memory.payload["run_id"]
embeddings = self.embedding_model.embed(data) if data in existing_embeddings:
embeddings = existing_embeddings[data]
else:
embeddings = self.embedding_model.embed(data)
self.vector_store.update( self.vector_store.update(
vector_id=memory_id, vector_id=memory_id,
vector=embeddings, vector=embeddings,