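"""RAG baseline over the LOCOMO conversation dataset.

Each conversation is chunked with tiktoken, the chunks are embedded via the
OpenAI embeddings API, the top-k chunks per question are retrieved by cosine
similarity, and a chat model produces a short answer. Model names are read
from the MODEL and EMBEDDING_MODEL environment variables (loaded via .env).
"""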
import json
import os
import time
from collections import defaultdict

import numpy as np
import tiktoken
from dotenv import load_dotenv
from jinja2 import Template
from openai import OpenAI
from tqdm import tqdm

load_dotenv()

PROMPT = """
# Question:
{{QUESTION}}

# Context:
{{CONTEXT}}

# Short answer:
"""


class RAGManager:
    def __init__(self, data_path="dataset/locomo10_rag.json", chunk_size=500, k=1):
        self.model = os.getenv("MODEL")
        self.client = OpenAI()
        self.data_path = data_path
        self.chunk_size = chunk_size
        self.k = k

    def generate_response(self, question, context):
        template = Template(PROMPT)
        prompt = template.render(
            CONTEXT=context,
            QUESTION=question
        )

        max_retries = 3
        retries = 0

        while retries <= max_retries:
            try:
                t1 = time.time()
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {"role": "system",
                         "content": "You are a helpful assistant that can answer "
                                    "questions based on the provided context. "
                                    "If the question involves timing, use the "
                                    "conversation date for reference. "
                                    "Provide the shortest possible answer. "
                                    "Use words directly from the conversation "
                                    "when possible. "
                                    "Avoid using subjects in your answer."},
                        {"role": "user", "content": prompt}
                    ],
                    temperature=0
                )
                t2 = time.time()
                return response.choices[0].message.content.strip(), t2 - t1
            except Exception as e:
                retries += 1
                if retries > max_retries:
                    raise e
                time.sleep(1)  # Wait before retrying

    def clean_chat_history(self, chat_history):
        cleaned_chat_history = ""
        for c in chat_history:
            cleaned_chat_history += (f"{c['timestamp']} | {c['speaker']}: "
                                     f"{c['text']}\n")

        return cleaned_chat_history

    def calculate_embedding(self, document):
        response = self.client.embeddings.create(
            model=os.getenv("EMBEDDING_MODEL"),
            input=document
        )
        return response.data[0].embedding

    def calculate_similarity(self, embedding1, embedding2):
        # Cosine similarity between two embedding vectors
        return np.dot(embedding1, embedding2) / (
            np.linalg.norm(embedding1) * np.linalg.norm(embedding2))

    def search(self, query, chunks, embeddings, k=1):
        """
        Search for the top-k most similar chunks to the query.

        Args:
            query: The query string
            chunks: List of text chunks
            embeddings: List of embeddings for each chunk
            k: Number of top chunks to return (default: 1)

        Returns:
            combined_chunks: The combined text of the top-k chunks
            search_time: Time taken for the search
        """
        t1 = time.time()
        query_embedding = self.calculate_embedding(query)
        similarities = [
            self.calculate_similarity(query_embedding, embedding)
            for embedding in embeddings
        ]

        # Get indices of top-k most similar chunks
        if k == 1:
            # Original behavior - just get the most similar chunk
            top_indices = [np.argmax(similarities)]
        else:
            # Get indices of top-k chunks, most similar first
            top_indices = np.argsort(similarities)[-k:][::-1]

        # Combine the top-k chunks
        combined_chunks = "\n<->\n".join([chunks[i] for i in top_indices])

        t2 = time.time()
        return combined_chunks, t2 - t1

    def create_chunks(self, chat_history, chunk_size=500):
        """
        Create chunks using tiktoken for more accurate token counting.
        """
        # Get the encoding for the embedding model
        encoding = tiktoken.encoding_for_model(os.getenv("EMBEDDING_MODEL"))

        documents = self.clean_chat_history(chat_history)

        # chunk_size == -1 means no chunking: return the full conversation
        if chunk_size == -1:
            return [documents], []

        chunks = []

        # Encode the document
        tokens = encoding.encode(documents)

        # Split into chunks based on token count
        for i in range(0, len(tokens), chunk_size):
            chunk_tokens = tokens[i:i + chunk_size]
            chunk = encoding.decode(chunk_tokens)
            chunks.append(chunk)

        embeddings = []
        for chunk in chunks:
            embedding = self.calculate_embedding(chunk)
            embeddings.append(embedding)

        return chunks, embeddings

    def process_all_conversations(self, output_file_path):
        with open(self.data_path, "r") as f:
            data = json.load(f)

        FINAL_RESULTS = defaultdict(list)
        for key, value in tqdm(data.items(), desc="Processing conversations"):
            chat_history = value["conversation"]
            questions = value["question"]

            chunks, embeddings = self.create_chunks(
                chat_history, self.chunk_size
            )

            for item in tqdm(
                questions, desc="Answering questions", leave=False
            ):
                question = item["question"]
                answer = item.get("answer", "")
                category = item["category"]

                if self.chunk_size == -1:
                    # No chunking: the whole conversation is the context
                    context = chunks[0]
                    search_time = 0
                else:
                    context, search_time = self.search(
                        question, chunks, embeddings, k=self.k
                    )
                response, response_time = self.generate_response(
                    question, context
                )

                FINAL_RESULTS[key].append({
                    "question": question,
                    "answer": answer,
                    "category": category,
                    "context": context,
                    "response": response,
                    "search_time": search_time,
                    "response_time": response_time,
                })

            # Checkpoint after each conversation
            with open(output_file_path, "w+") as f:
                json.dump(FINAL_RESULTS, f, indent=4)

        # Save final results
        with open(output_file_path, "w+") as f:
            json.dump(FINAL_RESULTS, f, indent=4)
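

# Minimal usage sketch (not part of the original module). It assumes the MODEL,
# EMBEDDING_MODEL, and OPENAI_API_KEY environment variables are set (e.g. via
# .env) and that the default dataset file exists; the chunking/top-k values and
# output path below are illustrative only.
if __name__ == "__main__":
    rag = RAGManager(data_path="dataset/locomo10_rag.json", chunk_size=500, k=2)
    rag.process_all_conversations("rag_results_k2.json")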