141 lines
5.1 KiB
Python
141 lines
5.1 KiB
Python
import argparse
|
|
import json
|
|
import os
|
|
import time
|
|
from collections import defaultdict
|
|
|
|
from dotenv import load_dotenv
|
|
from jinja2 import Template
|
|
from openai import OpenAI
|
|
from prompts import ANSWER_PROMPT_ZEP
|
|
from tqdm import tqdm
|
|
from zep_cloud import EntityEdge, EntityNode
|
|
from zep_cloud.client import Zep
|
|
|
|
# Populate os.environ from a local .env file (if present) before the clients below
# read ZEP_API_KEY / MODEL.
load_dotenv()

# Context block injected into the answer prompt. The `{facts}` and `{entities}`
# placeholders are filled with str.format (see ZepSearch.compose_search_context),
# so any literal braces added to this text must be doubled (`{{`, `}}`).
TEMPLATE = """
FACTS and ENTITIES represent relevant context to the current conversation.

# These are the most relevant facts and their valid date ranges
# format: FACT (Date range: from - to)

{facts}


# These are the most relevant entities
# ENTITY_NAME: entity summary

{entities}

"""
|
|
|
|
|
|
class ZepSearch:
    """Answer benchmark questions using Zep graph search plus an OpenAI chat model.

    For each question, the most relevant facts (graph edges) and entities (graph
    nodes) are retrieved from the Zep user graph for that conversation. They are
    rendered into ``TEMPLATE``, inserted into ``ANSWER_PROMPT_ZEP``, and sent to the
    chat model named by the ``MODEL`` environment variable. Per-question results
    accumulate in ``self.results``, keyed by conversation index.
    """

    def __init__(self):
        self.zep_client = Zep(api_key=os.getenv("ZEP_API_KEY"))
        # conversation index -> list of per-question result dicts
        self.results = defaultdict(list)
        self.openai_client = OpenAI()
        # The answer prompt never changes, so compile the Jinja template once.
        self.answer_template = Template(ANSWER_PROMPT_ZEP)

    def format_edge_date_range(self, edge: "EntityEdge") -> str:
        """Return ``"<valid_at> - <invalid_at>"`` for *edge*.

        A falsy ``valid_at`` is shown as ``'date unknown'`` and a falsy
        ``invalid_at`` as ``'present'``. Timestamps are inserted as-is, with no
        reformatting.
        """
        valid_from = edge.valid_at if edge.valid_at else "date unknown"
        valid_to = edge.invalid_at if edge.invalid_at else "present"
        return f"{valid_from} - {valid_to}"

    def compose_search_context(
        self, edges: "list[EntityEdge] | None", nodes: "list[EntityNode] | None"
    ) -> str:
        """Render search hits into ``TEMPLATE``.

        Each edge becomes ``" - <fact> (<date range>)"`` and each node becomes
        ``" - <name>: <summary>"``, one per line. ``None`` is treated as an empty
        list.

        NOTE(review): the Zep result fields appear to be optional; confirm against
        the zep_cloud SDK.
        """
        facts = [f" - {edge.fact} ({self.format_edge_date_range(edge)})" for edge in edges or []]
        entities = [f" - {node.name}: {node.summary}" for node in nodes or []]
        return TEMPLATE.format(facts="\n".join(facts), entities="\n".join(entities))

    def search_memory(self, run_id, idx, query, max_retries=3, retry_delay=1):
        """Retrieve Zep context for *query* and time the retrieval.

        Args:
            run_id: experiment run identifier; part of the Zep user id.
            idx: conversation index; part of the Zep user id.
            query: question text to search with.
            max_retries: total number of attempts; must be >= 1.
            retry_delay: seconds to sleep between failed attempts.

        Returns:
            ``(context, elapsed_seconds)``. The elapsed time covers all attempts,
            including retry sleeps.

        Raises:
            ValueError: if ``max_retries`` < 1 (previously this surfaced as an
                UnboundLocalError).
            Exception: whatever the last failed attempt raised.
        """
        if max_retries < 1:
            raise ValueError(f"max_retries must be >= 1, got {max_retries!r}")

        start_time = time.time()
        # NOTE(review): presumably must match the user id used when the conversation
        # was ingested into Zep; confirm against the ingestion script.
        user_id = f"run_id_{run_id}_experiment_user_{idx}"

        attempt = 0
        while True:
            attempt += 1
            try:
                edges = self.zep_client.graph.search(
                    user_id=user_id, reranker="cross_encoder", query=query, scope="edges", limit=20
                ).edges
                nodes = self.zep_client.graph.search(
                    user_id=user_id, reranker="rrf", query=query, scope="nodes", limit=20
                ).nodes
                context = self.compose_search_context(edges, nodes)
                break
            except Exception as e:
                if attempt >= max_retries:
                    # Out of attempts: re-raise with the original traceback intact.
                    raise
                print(f"Retrying ({attempt}/{max_retries}) after error: {e!r}")
                time.sleep(retry_delay)

        return context, time.time() - start_time

    def process_question(self, run_id, val, idx):
        """Answer one QA item *val* for conversation *idx* and return a result record.

        Reads ``question``, ``answer``, ``category``, ``evidence`` and
        ``adversarial_answer`` from *val*, with defaults when absent. It then adds the
        model's ``response``, the ``search_memory_time`` and ``response_time`` in
        seconds, and the retrieved ``context``.
        """
        question = val.get("question", "")
        answer = val.get("answer", "")
        category = val.get("category", -1)
        evidence = val.get("evidence", [])
        adversarial_answer = val.get("adversarial_answer", "")

        response, search_memory_time, response_time, context = self.answer_question(run_id, idx, question)

        return {
            "question": question,
            "answer": answer,
            "category": category,
            "evidence": evidence,
            "response": response,
            "adversarial_answer": adversarial_answer,
            "search_memory_time": search_memory_time,
            "response_time": response_time,
            "context": context,
        }

    def answer_question(self, run_id, idx, question):
        """Search Zep for *question*, then ask the chat model to answer it.

        Returns:
            ``(response_text, search_memory_time, response_time, context)``.

        Raises:
            RuntimeError: if the ``MODEL`` environment variable is unset or empty.
                This is checked before any search is performed.
        """
        model = os.getenv("MODEL")
        if not model:
            raise RuntimeError("MODEL environment variable is not set")

        context, search_memory_time = self.search_memory(run_id, idx, question)
        answer_prompt = self.answer_template.render(memories=context, question=question)

        t1 = time.time()
        response = self.openai_client.chat.completions.create(
            model=model, messages=[{"role": "system", "content": answer_prompt}], temperature=0.0
        )
        response_time = time.time() - t1

        return response.choices[0].message.content, search_memory_time, response_time, context

    def process_data_file(self, file_path, run_id, output_file_path):
        """Answer every QA item in the JSON dataset at *file_path*.

        The dataset is expected to be a list of items, each with a ``"qa"`` list.
        Results are checkpointed to *output_file_path* after every question and
        written once more at the end. JSON turns the integer conversation keys into
        strings.
        """
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        for idx, item in tqdm(enumerate(data), total=len(data), desc="Processing conversations"):
            qa = item["qa"]

            for question_item in tqdm(
                qa, total=len(qa), desc=f"Processing questions for conversation {idx}", leave=False
            ):
                result = self.process_question(run_id, question_item, idx)
                self.results[idx].append(result)

                # Checkpoint after each question so a crash loses at most one answer.
                self._save_results(output_file_path)

        # Final save; this also produces a file when the dataset has no questions.
        self._save_results(output_file_path)

    def _save_results(self, output_file_path):
        """Write ``self.results`` as indented JSON to *output_file_path*.

        Creates the parent directory if needed. The data is written to a temporary
        sibling file first and then swapped into place with ``os.replace``, so an
        interrupted write cannot corrupt an existing checkpoint.
        """
        directory = os.path.dirname(output_file_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        tmp_path = f"{output_file_path}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(self.results, f, indent=4)
        os.replace(tmp_path, output_file_path)
|
|
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: answer every question in the dataset and write the results JSON.
    # The path flags default to the previously hard-coded locations, so existing
    # invocations (`--run_id` only) behave exactly as before.
    parser = argparse.ArgumentParser(description="Run Zep search + answer generation over a QA dataset.")
    parser.add_argument("--run_id", type=str, required=True, help="Run identifier used to build Zep user ids.")
    parser.add_argument(
        "--data_path",
        type=str,
        default="../../dataset/locomo10.json",
        help="Path to the input dataset JSON.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="results/zep_search_results.json",
        help="Where to write the results JSON.",
    )
    args = parser.parse_args()

    zep_search = ZepSearch()
    zep_search.process_data_file(args.data_path, args.run_id, args.output_path)
|