[Feature] Add support for OpenAI assistants and support openai version >=1.0.0 (#921)
This commit is contained in:
0
embedchain/store/__init__.py
Normal file
0
embedchain/store/__init__.py
Normal file
125
embedchain/store/assistants.py
Normal file
125
embedchain/store/assistants.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
from openai import OpenAI
|
||||
from openai.types.beta.threads import MessageContentText, ThreadMessage
|
||||
|
||||
from embedchain.config import AddConfig
|
||||
from embedchain.data_formatter import DataFormatter
|
||||
from embedchain.models.data_type import DataType
|
||||
from embedchain.utils import detect_datatype
|
||||
|
||||
# Configure root logging once at import time; per-instance verbosity is
# requested separately via the `log_level` constructor argument below.
logging.basicConfig(level=logging.WARN)
|
||||
|
||||
|
||||
class OpenAIAssistant:
    """Thin wrapper around the OpenAI Assistants API (beta).

    Creates (or retrieves) an assistant, manages a single conversation
    thread, and lets callers attach embedchain-loadable data sources as
    retrieval files.
    """

    def __init__(
        self,
        name=None,
        instructions=None,
        tools=None,
        thread_id=None,
        model="gpt-4-1106-preview",
        data_sources=None,
        assistant_id=None,
        log_level=logging.WARN,
    ):
        """
        :param name: display name for the assistant (default "OpenAI Assistant")
        :param instructions: system instructions passed to the assistant
        :param tools: assistant tools; defaults to [{"type": "retrieval"}]
        :param thread_id: reuse an existing thread; a new one is created if omitted
        :param model: OpenAI model name to use for runs
        :param data_sources: list of {"source": ..., "data_type": ...} dicts
            uploaded as retrieval files when the assistant is created
        :param assistant_id: retrieve an existing assistant instead of creating one
        :param log_level: logging verbosity for this instance
        """
        self.name = name or "OpenAI Assistant"
        self.instructions = instructions
        self.tools = tools or [{"type": "retrieval"}]
        self.model = model
        self.data_sources = data_sources or []
        self.log_level = log_level
        # OpenAI() picks up credentials from the environment (OPENAI_API_KEY).
        self._client = OpenAI()
        self._initialize_assistant(assistant_id)
        self.thread_id = thread_id or self._create_thread()

    def add(self, source, data_type=None):
        """Load `source` (local file path or a loadable remote source) and
        attach it to the assistant as a retrieval file."""
        file_path = self._prepare_source_path(source, data_type)
        self._add_file_to_assistant(file_path)
        logging.info("Data successfully added to the assistant.")

    def chat(self, message):
        """Send `message` on the current thread, wait for the run to finish,
        and return the assistant's latest reply (or None if there is none)."""
        self._send_message(message)
        return self._get_latest_response()

    def delete_thread(self):
        """Discard the current conversation thread and start a fresh one."""
        self._client.beta.threads.delete(self.thread_id)
        self.thread_id = self._create_thread()

    # Internal methods
    def _initialize_assistant(self, assistant_id):
        # Upload the configured data sources first; the resulting file ids are
        # attached via `file_ids` when the assistant is created.
        file_ids = self._generate_file_ids(self.data_sources)
        self.assistant = (
            self._client.beta.assistants.retrieve(assistant_id)
            if assistant_id
            else self._client.beta.assistants.create(
                name=self.name, model=self.model, file_ids=file_ids, instructions=self.instructions, tools=self.tools
            )
        )

    def _create_thread(self):
        thread = self._client.beta.threads.create()
        return thread.id

    def _prepare_source_path(self, source, data_type=None):
        # Local files are uploaded as-is; anything else is loaded through
        # embedchain's data formatter and dumped to a temporary file.
        if Path(source).is_file():
            return source
        data_type = data_type or detect_datatype(source)
        formatter = DataFormatter(data_type=DataType(data_type), config=AddConfig())
        data = formatter.loader.load_data(source)["data"]
        return self._save_temp_data(data[0]["content"].encode())

    def _upload_file(self, file_path):
        # Upload a local file for assistant retrieval and return its file id.
        # The context manager guarantees the handle is closed even on failure
        # (the previous code leaked the handle by opening inline).
        with open(file_path, "rb") as file:
            file_obj = self._client.files.create(file=file, purpose="assistants")
        return file_obj.id

    def _add_file_to_assistant(self, file_path):
        # Upload, then attach the file to the already-existing assistant.
        file_id = self._upload_file(file_path)
        self._client.beta.assistants.files.create(assistant_id=self.assistant.id, file_id=file_id)

    def _generate_file_ids(self, data_sources):
        # Upload only — do NOT attach: this runs before the assistant exists,
        # and the ids returned here are handed to `assistants.create`.
        # (The previous implementation called `_add_file_to_assistant`, which
        # returned None and dereferenced the not-yet-created `self.assistant`.)
        return [
            self._upload_file(self._prepare_source_path(ds["source"], ds.get("data_type")))
            for ds in data_sources
        ]

    def _send_message(self, message):
        self._client.beta.threads.messages.create(thread_id=self.thread_id, role="user", content=message)
        self._wait_for_completion()

    def _wait_for_completion(self):
        # Start a run for the just-posted message and poll until it leaves the
        # active states; raise if the run ends in "failed".
        run = self._client.beta.threads.runs.create(
            thread_id=self.thread_id,
            assistant_id=self.assistant.id,
            instructions=self.instructions,
        )
        run_id = run.id
        run_status = run.status

        while run_status in ("queued", "in_progress", "requires_action"):
            time.sleep(0.1)  # Sleep before making the next API call to avoid hitting rate limits
            run = self._client.beta.threads.runs.retrieve(thread_id=self.thread_id, run_id=run_id)
            run_status = run.status
            if run_status == "failed":
                raise ValueError(f"Thread run failed with the following error: {run.last_error}")

    def _get_latest_response(self):
        # Messages are listed newest-first (order="desc"), so index 0 is the
        # most recent message on the thread.
        history = self._get_history()
        return self._format_message(history[0]) if history else None

    def _get_history(self):
        messages = self._client.beta.threads.messages.list(thread_id=self.thread_id, order="desc")
        return list(messages)

    def _format_message(self, thread_message):
        # Join all text parts of the message; non-text content is ignored.
        thread_message = cast(ThreadMessage, thread_message)
        content = [c.text.value for c in thread_message.content if isinstance(c, MessageContentText)]
        return " ".join(content)

    def _save_temp_data(self, data):
        # Persist loaded bytes to a throwaway file so they can be uploaded.
        temp_dir = tempfile.mkdtemp()
        file_path = os.path.join(temp_dir, "temp_data")
        with open(file_path, "wb") as file:
            file.write(data)
        return file_path
Reference in New Issue
Block a user