[improvement]: Duplicate embedding generation removed. (#1900)

This commit is contained in:
Mayank
2024-09-25 09:54:30 +05:30
committed by GitHub
parent 3914f4d6ac
commit 5525c4e6fe

View File

@@ -148,8 +148,10 @@ class Memory(MemoryBase):
new_retrieved_facts = []
retrieved_old_memory = []
new_message_embeddings = {}
for new_mem in new_retrieved_facts:
messages_embeddings = self.embedding_model.embed(new_mem)
new_message_embeddings[new_mem] = messages_embeddings
existing_memories = self.vector_store.search(
query=messages_embeddings,
limit=5,
@@ -173,7 +175,7 @@ class Memory(MemoryBase):
logging.info(resp)
try:
if resp["event"] == "ADD":
_ = self._create_memory(data=resp["text"], metadata=metadata)
_ = self._create_memory(data=resp["text"], existing_embeddings=new_message_embeddings, metadata=metadata)
returned_memories.append(
{
"memory": resp["text"],
@@ -181,7 +183,7 @@ class Memory(MemoryBase):
}
)
elif resp["event"] == "UPDATE":
self._update_memory(memory_id=resp["id"], data=resp["text"], metadata=metadata)
self._update_memory(memory_id=resp["id"], data=resp["text"], existing_embeddings=new_message_embeddings, metadata=metadata)
returned_memories.append(
{
"memory": resp["text"],
@@ -504,9 +506,12 @@ class Memory(MemoryBase):
capture_event("mem0.history", self, {"memory_id": memory_id})
return self.db.get_history(memory_id)
def _create_memory(self, data, metadata=None):
def _create_memory(self, data, existing_embeddings, metadata=None):
logging.info(f"Creating memory with {data=}")
embeddings = self.embedding_model.embed(data)
if data in existing_embeddings:
embeddings = existing_embeddings[data]
else:
embeddings = self.embedding_model.embed(data)
memory_id = str(uuid.uuid4())
metadata = metadata or {}
metadata["data"] = data
@@ -521,7 +526,7 @@ class Memory(MemoryBase):
self.db.add_history(memory_id, None, data, "ADD", created_at=metadata["created_at"])
return memory_id
def _update_memory(self, memory_id, data, metadata=None):
def _update_memory(self, memory_id, data, existing_embeddings, metadata=None):
logger.info(f"Updating memory with {data=}")
existing_memory = self.vector_store.get(vector_id=memory_id)
prev_value = existing_memory.payload.get("data")
@@ -539,7 +544,10 @@ class Memory(MemoryBase):
if "run_id" in existing_memory.payload:
new_metadata["run_id"] = existing_memory.payload["run_id"]
embeddings = self.embedding_model.embed(data)
if data in existing_embeddings:
embeddings = existing_embeddings[data]
else:
embeddings = self.embedding_model.embed(data)
self.vector_store.update(
vector_id=memory_id,
vector=embeddings,