Formatting (#2750)
This commit is contained in:
@@ -59,23 +59,19 @@ class OpenAIPredict:
|
||||
self.results = defaultdict(list)
|
||||
|
||||
def search_memory(self, idx):
|
||||
|
||||
with open(f'memories/{idx}.txt', 'r') as file:
|
||||
with open(f"memories/{idx}.txt", "r") as file:
|
||||
memories = file.read()
|
||||
|
||||
return memories, 0
|
||||
|
||||
def process_question(self, val, idx):
|
||||
question = val.get('question', '')
|
||||
answer = val.get('answer', '')
|
||||
category = val.get('category', -1)
|
||||
evidence = val.get('evidence', [])
|
||||
adversarial_answer = val.get('adversarial_answer', '')
|
||||
question = val.get("question", "")
|
||||
answer = val.get("answer", "")
|
||||
category = val.get("category", -1)
|
||||
evidence = val.get("evidence", [])
|
||||
adversarial_answer = val.get("adversarial_answer", "")
|
||||
|
||||
response, search_memory_time, response_time, context = self.answer_question(
|
||||
idx,
|
||||
question
|
||||
)
|
||||
response, search_memory_time, response_time, context = self.answer_question(idx, question)
|
||||
|
||||
result = {
|
||||
"question": question,
|
||||
@@ -86,7 +82,7 @@ class OpenAIPredict:
|
||||
"adversarial_answer": adversarial_answer,
|
||||
"search_memory_time": search_memory_time,
|
||||
"response_time": response_time,
|
||||
"context": context
|
||||
"context": context,
|
||||
}
|
||||
|
||||
return result
|
||||
@@ -95,43 +91,35 @@ class OpenAIPredict:
|
||||
memories, search_memory_time = self.search_memory(idx)
|
||||
|
||||
template = Template(ANSWER_PROMPT)
|
||||
answer_prompt = template.render(
|
||||
memories=memories,
|
||||
question=question
|
||||
)
|
||||
answer_prompt = template.render(memories=memories, question=question)
|
||||
|
||||
t1 = time.time()
|
||||
response = self.openai_client.chat.completions.create(
|
||||
model=os.getenv("MODEL"),
|
||||
messages=[
|
||||
{"role": "system", "content": answer_prompt}
|
||||
],
|
||||
temperature=0.0
|
||||
model=os.getenv("MODEL"), messages=[{"role": "system", "content": answer_prompt}], temperature=0.0
|
||||
)
|
||||
t2 = time.time()
|
||||
response_time = t2 - t1
|
||||
return response.choices[0].message.content, search_memory_time, response_time, memories
|
||||
|
||||
def process_data_file(self, file_path, output_file_path):
|
||||
with open(file_path, 'r') as f:
|
||||
with open(file_path, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
for idx, item in tqdm(enumerate(data), total=len(data), desc="Processing conversations"):
|
||||
qa = item['qa']
|
||||
qa = item["qa"]
|
||||
|
||||
for question_item in tqdm(qa, total=len(qa), desc=f"Processing questions for conversation {idx}", leave=False):
|
||||
result = self.process_question(
|
||||
question_item,
|
||||
idx
|
||||
)
|
||||
for question_item in tqdm(
|
||||
qa, total=len(qa), desc=f"Processing questions for conversation {idx}", leave=False
|
||||
):
|
||||
result = self.process_question(question_item, idx)
|
||||
self.results[idx].append(result)
|
||||
|
||||
# Save results after each question is processed
|
||||
with open(output_file_path, 'w') as f:
|
||||
with open(output_file_path, "w") as f:
|
||||
json.dump(self.results, f, indent=4)
|
||||
|
||||
# Final save at the end
|
||||
with open(output_file_path, 'w') as f:
|
||||
with open(output_file_path, "w") as f:
|
||||
json.dump(self.results, f, indent=4)
|
||||
|
||||
|
||||
@@ -141,4 +129,3 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
openai_predict = OpenAIPredict()
|
||||
openai_predict.process_data_file("../../dataset/locomo10.json", args.output_file_path)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user