Docs Update (#2591)
This commit is contained in:
evaluation/src/langmem.py (new file, 193 lines)
@@ -0,0 +1,193 @@
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
from langgraph.store.memory import InMemoryStore
from langgraph.utils.config import get_store
from langmem import (
    create_manage_memory_tool,
    create_search_memory_tool,
)

import time
import multiprocessing as mp
import json
import os
from collections import defaultdict
from functools import partial

from tqdm import tqdm
from openai import OpenAI
from dotenv import load_dotenv
from jinja2 import Template

from prompts import ANSWER_PROMPT

load_dotenv()

client = OpenAI()

ANSWER_PROMPT_TEMPLATE = Template(ANSWER_PROMPT)

def get_answer(question, speaker_1_user_id, speaker_1_memories, speaker_2_user_id, speaker_2_memories):
    """Render the answer prompt with both speakers' memories and query the LLM."""
    prompt = ANSWER_PROMPT_TEMPLATE.render(
        question=question,
        speaker_1_user_id=speaker_1_user_id,
        speaker_1_memories=speaker_1_memories,
        speaker_2_user_id=speaker_2_user_id,
        speaker_2_memories=speaker_2_memories,
    )

    t1 = time.time()
    response = client.chat.completions.create(
        model=os.getenv("MODEL"),
        messages=[{"role": "system", "content": prompt}],
        temperature=0.0,
    )
    t2 = time.time()
    return response.choices[0].message.content, t2 - t1

def prompt(state):
    """Prepare the messages for the LLM."""
    store = get_store()
    memories = store.search(
        ("memories",),
        query=state["messages"][-1].content,
    )
    system_msg = f"""You are a helpful assistant.

## Memories
<memories>
{memories}
</memories>
"""
    return [{"role": "system", "content": system_msg}, *state["messages"]]

class LangMem:
    def __init__(self):
        # Vector store for memories, embedded with the configured OpenAI embedding model.
        self.store = InMemoryStore(
            index={
                "dims": 1536,
                "embed": f"openai:{os.getenv('EMBEDDING_MODEL')}",
            }
        )
        self.checkpointer = MemorySaver()  # Checkpoint graph state

        self.agent = create_react_agent(
            f"openai:{os.getenv('MODEL')}",
            prompt=prompt,
            tools=[
                create_manage_memory_tool(namespace=("memories",)),
                create_search_memory_tool(namespace=("memories",)),
            ],
            store=self.store,
            checkpointer=self.checkpointer,
        )

    def add_memory(self, message, config):
        return self.agent.invoke(
            {"messages": [{"role": "user", "content": message}]},
            config=config,
        )

    def search_memory(self, query, config):
        t1 = time.time()
        try:
            response = self.agent.invoke(
                {"messages": [{"role": "user", "content": query}]},
                config=config,
            )
            t2 = time.time()
            return response["messages"][-1].content, t2 - t1
        except Exception as e:
            print(f"Error in search_memory: {e}")
            return "", time.time() - t1

def process_conversation(key_value_pair):
    """Build memories for one conversation and answer its questions.

    Defined at module level so that multiprocessing workers can pickle it.
    """
    key, value = key_value_pair
    result = defaultdict(list)

    chat_history = value["conversation"]
    questions = value["question"]

    # One memory agent per speaker.
    agent1 = LangMem()
    agent2 = LangMem()
    config = {"configurable": {"thread_id": f"thread-{key}"}}
    speakers = set()

    # Identify speakers
    for conv in chat_history:
        speakers.add(conv['speaker'])

    if len(speakers) != 2:
        raise ValueError(f"Expected 2 speakers, got {len(speakers)}")

    speaker1 = list(speakers)[0]
    speaker2 = list(speakers)[1]

    # Add memories for each message
    for conv in tqdm(chat_history, desc=f"Processing messages {key}", leave=False):
        message = f"{conv['timestamp']} | {conv['speaker']}: {conv['text']}"
        if conv['speaker'] == speaker1:
            agent1.add_memory(message, config)
        elif conv['speaker'] == speaker2:
            agent2.add_memory(message, config)
        else:
            raise ValueError(f"Expected speaker1 or speaker2, got {conv['speaker']}")

    # Process questions
    for q in tqdm(questions, desc=f"Processing questions {key}", leave=False):
        category = q['category']

        if int(category) == 5:
            continue

        answer = q['answer']
        question = q['question']
        response1, speaker1_memory_time = agent1.search_memory(question, config)
        response2, speaker2_memory_time = agent2.search_memory(question, config)

        generated_answer, response_time = get_answer(
            question, speaker1, response1, speaker2, response2
        )

        result[key].append({
            "question": question,
            "answer": answer,
            "response1": response1,
            "response2": response2,
            "category": category,
            "speaker1_memory_time": speaker1_memory_time,
            "speaker2_memory_time": speaker2_memory_time,
            "response_time": response_time,
            "response": generated_answer,
        })

    return result


class LangMemManager:
    def __init__(self, dataset_path):
        self.dataset_path = dataset_path
        with open(self.dataset_path, 'r') as f:
            self.data = json.load(f)

    def process_all_conversations(self, output_file_path):
        OUTPUT = defaultdict(list)

        # Process conversations in parallel with multiple workers
        with mp.Pool(processes=10) as pool:
            results = list(tqdm(
                pool.imap(process_conversation, list(self.data.items())),
                total=len(self.data),
                desc="Processing conversations"
            ))

        # Combine results from all workers
        for result in results:
            for key, items in result.items():
                OUTPUT[key].extend(items)

        # Save final results
        with open(output_file_path, 'w') as f:
            json.dump(OUTPUT, f, indent=4)
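For reference, a minimal driver might look like the sketch below. The dataset and output paths are placeholders, not part of this commit; the __main__ guard is needed because worker processes are spawned, and MODEL, EMBEDDING_MODEL, and OPENAI_API_KEY are expected in the environment (loaded via .env).

# Illustrative driver, not part of the committed file. Paths are placeholders.
if __name__ == "__main__":
    manager = LangMemManager("path/to/dataset.json")
    manager.process_all_conversations("path/to/langmem_results.json")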