Formatting (#2750)

This commit is contained in:
Dev Khant
2025-05-22 01:17:29 +05:30
committed by GitHub
parent dff91154a7
commit d85fcda037
71 changed files with 1391 additions and 1823 deletions

View File

@@ -19,12 +19,12 @@ class ZepAdd:
self.load_data()
def load_data(self):
with open(self.data_path, 'r') as f:
with open(self.data_path, "r") as f:
self.data = json.load(f)
return self.data
def process_conversation(self, run_id, item, idx):
conversation = item['conversation']
conversation = item["conversation"]
user_id = f"run_id_{run_id}_experiment_user_{idx}"
session_id = f"run_id_{run_id}_experiment_session_{idx}"
@@ -41,7 +41,7 @@ class ZepAdd:
print("Starting to add memories... for user", user_id)
for key in tqdm(conversation.keys(), desc=f"Processing user {user_id}"):
if key in ['speaker_a', 'speaker_b'] or "date" in key:
if key in ["speaker_a", "speaker_b"] or "date" in key:
continue
date_time_key = key + "_date_time"
@@ -51,11 +51,13 @@ class ZepAdd:
for chat in tqdm(chats, desc=f"Adding chats for {key}", leave=False):
self.zep_client.memory.add(
session_id=session_id,
messages=[Message(
role=chat['speaker'],
role_type="user",
content=f"{timestamp}: {chat['text']}",
)]
messages=[
Message(
role=chat["speaker"],
role_type="user",
content=f"{timestamp}: {chat['text']}",
)
],
)
def process_all_conversations(self, run_id):
@@ -71,4 +73,4 @@ if __name__ == "__main__":
parser.add_argument("--run_id", type=str, required=True)
args = parser.parse_args()
zep_add = ZepAdd(data_path="../../dataset/locomo10.json")
zep_add.process_all_conversations(args.run_id)
zep_add.process_all_conversations(args.run_id)

View File

@@ -42,9 +42,9 @@ class ZepSearch:
return f"{edge.valid_at if edge.valid_at else 'date unknown'} - {(edge.invalid_at if edge.invalid_at else 'present')}"
def compose_search_context(self, edges: list[EntityEdge], nodes: list[EntityNode]) -> str:
facts = [f' - {edge.fact} ({self.format_edge_date_range(edge)})' for edge in edges]
entities = [f' - {node.name}: {node.summary}' for node in nodes]
return TEMPLATE.format(facts='\n'.join(facts), entities='\n'.join(entities))
facts = [f" - {edge.fact} ({self.format_edge_date_range(edge)})" for edge in edges]
entities = [f" - {node.name}: {node.summary}" for node in nodes]
return TEMPLATE.format(facts="\n".join(facts), entities="\n".join(entities))
def search_memory(self, run_id, idx, query, max_retries=3, retry_delay=1):
start_time = time.time()
@@ -52,8 +52,14 @@ class ZepSearch:
while retries < max_retries:
try:
user_id = f"run_id_{run_id}_experiment_user_{idx}"
edges_results = (self.zep_client.graph.search(user_id=user_id, reranker='cross_encoder', query=query, scope='edges', limit=20)).edges
node_results = (self.zep_client.graph.search(user_id=user_id, reranker='rrf', query=query, scope='nodes', limit=20)).nodes
edges_results = (
self.zep_client.graph.search(
user_id=user_id, reranker="cross_encoder", query=query, scope="edges", limit=20
)
).edges
node_results = (
self.zep_client.graph.search(user_id=user_id, reranker="rrf", query=query, scope="nodes", limit=20)
).nodes
context = self.compose_search_context(edges_results, node_results)
break
except Exception as e:
@@ -68,17 +74,13 @@ class ZepSearch:
return context, end_time - start_time
def process_question(self, run_id, val, idx):
question = val.get('question', '')
answer = val.get('answer', '')
category = val.get('category', -1)
evidence = val.get('evidence', [])
adversarial_answer = val.get('adversarial_answer', '')
question = val.get("question", "")
answer = val.get("answer", "")
category = val.get("category", -1)
evidence = val.get("evidence", [])
adversarial_answer = val.get("adversarial_answer", "")
response, search_memory_time, response_time, context = self.answer_question(
run_id,
idx,
question
)
response, search_memory_time, response_time, context = self.answer_question(run_id, idx, question)
result = {
"question": question,
@@ -89,7 +91,7 @@ class ZepSearch:
"adversarial_answer": adversarial_answer,
"search_memory_time": search_memory_time,
"response_time": response_time,
"context": context
"context": context,
}
return result
@@ -98,44 +100,35 @@ class ZepSearch:
context, search_memory_time = self.search_memory(run_id, idx, question)
template = Template(ANSWER_PROMPT_ZEP)
answer_prompt = template.render(
memories=context,
question=question
)
answer_prompt = template.render(memories=context, question=question)
t1 = time.time()
response = self.openai_client.chat.completions.create(
model=os.getenv("MODEL"),
messages=[
{"role": "system", "content": answer_prompt}
],
temperature=0.0
model=os.getenv("MODEL"), messages=[{"role": "system", "content": answer_prompt}], temperature=0.0
)
t2 = time.time()
response_time = t2 - t1
return response.choices[0].message.content, search_memory_time, response_time, context
def process_data_file(self, file_path, run_id, output_file_path):
with open(file_path, 'r') as f:
with open(file_path, "r") as f:
data = json.load(f)
for idx, item in tqdm(enumerate(data), total=len(data), desc="Processing conversations"):
qa = item['qa']
qa = item["qa"]
for question_item in tqdm(qa, total=len(qa), desc=f"Processing questions for conversation {idx}", leave=False):
result = self.process_question(
run_id,
question_item,
idx
)
for question_item in tqdm(
qa, total=len(qa), desc=f"Processing questions for conversation {idx}", leave=False
):
result = self.process_question(run_id, question_item, idx)
self.results[idx].append(result)
# Save results after each question is processed
with open(output_file_path, 'w') as f:
with open(output_file_path, "w") as f:
json.dump(self.results, f, indent=4)
# Final save at the end
with open(output_file_path, 'w') as f:
with open(output_file_path, "w") as f:
json.dump(self.results, f, indent=4)