# t6_mem0/evaluation/src/langmem.py
import json
import multiprocessing as mp
import os
import time
from collections import defaultdict

from dotenv import load_dotenv
from jinja2 import Template
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
from langgraph.store.memory import InMemoryStore
from langgraph.utils.config import get_store
from langmem import (
    create_manage_memory_tool,
    create_search_memory_tool,
)
from openai import OpenAI
from tqdm import tqdm

from prompts import ANSWER_PROMPT

load_dotenv()

client = OpenAI()
ANSWER_PROMPT_TEMPLATE = Template(ANSWER_PROMPT)


def get_answer(question, speaker_1_user_id, speaker_1_memories, speaker_2_user_id, speaker_2_memories):
    """Answer a question from both speakers' retrieved memories; returns the answer and its latency."""
    prompt = ANSWER_PROMPT_TEMPLATE.render(
        question=question,
        speaker_1_user_id=speaker_1_user_id,
        speaker_1_memories=speaker_1_memories,
        speaker_2_user_id=speaker_2_user_id,
        speaker_2_memories=speaker_2_memories,
    )
    t1 = time.time()
    response = client.chat.completions.create(
        model=os.getenv("MODEL"),
        messages=[{"role": "system", "content": prompt}],
        temperature=0.0,
    )
    t2 = time.time()
    return response.choices[0].message.content, t2 - t1


def prompt(state):
    """Prepare the messages for the LLM."""
    store = get_store()
    memories = store.search(
        ("memories",),
        query=state["messages"][-1].content,
    )
    system_msg = f"""You are a helpful assistant.
## Memories
<memories>
{memories}
</memories>
"""
    return [{"role": "system", "content": system_msg}, *state["messages"]]


class LangMem:
    """Wraps a LangGraph ReAct agent with LangMem memory management and search tools."""

    def __init__(self):
        # Vector store used by the memory tools; the embedding model comes from the environment.
        self.store = InMemoryStore(
            index={
                "dims": 1536,
                "embed": f"openai:{os.getenv('EMBEDDING_MODEL')}",
            }
        )
        self.checkpointer = MemorySaver()  # Checkpoint graph state
        self.agent = create_react_agent(
            f"openai:{os.getenv('MODEL')}",
            prompt=prompt,
            tools=[
                create_manage_memory_tool(namespace=("memories",)),
                create_search_memory_tool(namespace=("memories",)),
            ],
            store=self.store,
            checkpointer=self.checkpointer,
        )

    def add_memory(self, message, config):
        """Send a conversation message to the agent so it can store it as memory."""
        return self.agent.invoke(
            {"messages": [{"role": "user", "content": message}]},
            config=config,
        )

    def search_memory(self, query, config):
        """Ask the agent a question against its stored memories; returns the reply and its latency."""
        t1 = time.time()
        try:
            response = self.agent.invoke(
                {"messages": [{"role": "user", "content": query}]},
                config=config,
            )
            return response["messages"][-1].content, time.time() - t1
        except Exception as e:
            # On failure, return an empty answer along with the elapsed time so the caller can proceed.
            print(f"Error in search_memory: {e}")
            return "", time.time() - t1


def process_conversation(key_value_pair):
    """Worker that builds one LangMem agent per speaker, ingests the conversation, and answers its questions."""
    # Defined at module level so multiprocessing can pickle it when dispatching work to the pool.
    key, value = key_value_pair
    result = defaultdict(list)
    chat_history = value["conversation"]
    questions = value["question"]

    agent1 = LangMem()
    agent2 = LangMem()
    config = {"configurable": {"thread_id": f"thread-{key}"}}

    # Identify the two speakers
    speakers = set()
    for conv in chat_history:
        speakers.add(conv['speaker'])
    if len(speakers) != 2:
        raise ValueError(f"Expected 2 speakers, got {len(speakers)}")
    speaker1 = list(speakers)[0]
    speaker2 = list(speakers)[1]

    # Add memories for each message, routed to the speaker's own agent
    for conv in tqdm(chat_history, desc=f"Processing messages {key}", leave=False):
        message = f"{conv['timestamp']} | {conv['speaker']}: {conv['text']}"
        if conv['speaker'] == speaker1:
            agent1.add_memory(message, config)
        elif conv['speaker'] == speaker2:
            agent2.add_memory(message, config)
        else:
            raise ValueError(f"Expected speaker1 or speaker2, got {conv['speaker']}")

    # Answer each question from both speakers' memories (category 5 questions are skipped)
    for q in tqdm(questions, desc=f"Processing questions {key}", leave=False):
        category = q['category']
        if int(category) == 5:
            continue
        answer = q['answer']
        question = q['question']

        response1, speaker1_memory_time = agent1.search_memory(question, config)
        response2, speaker2_memory_time = agent2.search_memory(question, config)
        generated_answer, response_time = get_answer(
            question, speaker1, response1, speaker2, response2
        )

        result[key].append({
            "question": question,
            "answer": answer,
            "response1": response1,
            "response2": response2,
            "category": category,
            "speaker1_memory_time": speaker1_memory_time,
            "speaker2_memory_time": speaker2_memory_time,
            "response_time": response_time,
            'response': generated_answer
        })
    return result


class LangMemManager:
    """Loads a conversation dataset and evaluates LangMem over every conversation in it."""

    def __init__(self, dataset_path):
        self.dataset_path = dataset_path
        with open(self.dataset_path, 'r') as f:
            self.data = json.load(f)

    def process_all_conversations(self, output_file_path):
        OUTPUT = defaultdict(list)

        # Process conversations in parallel with multiple workers
        with mp.Pool(processes=10) as pool:
            results = list(tqdm(
                pool.imap(process_conversation, list(self.data.items())),
                total=len(self.data),
                desc="Processing conversations"
            ))

        # Combine results from all workers
        for result in results:
            for key, items in result.items():
                OUTPUT[key].extend(items)

        # Save final results
        with open(output_file_path, 'w') as f:
            json.dump(OUTPUT, f, indent=4)
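

# Example usage: a minimal sketch, not part of the original module. The dataset and
# output paths below are hypothetical placeholders, and the OPENAI_API_KEY, MODEL,
# and EMBEDDING_MODEL environment variables are assumed to be set (e.g. via .env).
if __name__ == "__main__":
    manager = LangMemManager(dataset_path="data/conversations.json")
    manager.process_all_conversations(output_file_path="results/langmem_results.json")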