# t6_mem0/embedchain/store/assistants.py
import logging
import os
import re
import tempfile
import time
import uuid
from pathlib import Path
from typing import cast

from openai import OpenAI
from openai.types.beta.threads import MessageContentText, ThreadMessage

from embedchain import Client, Pipeline
from embedchain.config import AddConfig
from embedchain.data_formatter import DataFormatter
from embedchain.models.data_type import DataType
from embedchain.telemetry.posthog import AnonymousTelemetry
from embedchain.utils.misc import detect_datatype

logging.basicConfig(level=logging.WARN)

# Set up the user directory if it doesn't exist already
Client.setup()


class OpenAIAssistant:
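    """Thin wrapper around the OpenAI Assistants API.

    Data sources are uploaded as files and attached to the assistant so its
    built-in retrieval tool can use them; conversation state lives in an
    OpenAI thread identified by ``thread_id``.
    """
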
    def __init__(
        self,
        name=None,
        instructions=None,
        tools=None,
        thread_id=None,
        model="gpt-4-1106-preview",
        data_sources=None,
        assistant_id=None,
        log_level=logging.WARN,
        collect_metrics=True,
    ):
        self.name = name or "OpenAI Assistant"
        self.instructions = instructions
        self.tools = tools or [{"type": "retrieval"}]
        self.model = model
        self.data_sources = data_sources or []
        self.log_level = log_level
        self._client = OpenAI()
        self._initialize_assistant(assistant_id)
        self.thread_id = thread_id or self._create_thread()
        self._telemetry_props = {"class": self.__class__.__name__}
        self.telemetry = AnonymousTelemetry(enabled=collect_metrics)
        self.telemetry.capture(event_name="init", properties=self._telemetry_props)

    def add(self, source, data_type=None):
        file_path = self._prepare_source_path(source, data_type)
        self._add_file_to_assistant(file_path)

        event_props = {
            **self._telemetry_props,
            "data_type": data_type or detect_datatype(source),
        }
        self.telemetry.capture(event_name="add", properties=event_props)
        logging.info("Data successfully added to the assistant.")

    def chat(self, message):
        self._send_message(message)
        self.telemetry.capture(event_name="chat", properties=self._telemetry_props)
        return self._get_latest_response()

    def delete_thread(self):
        self._client.beta.threads.delete(self.thread_id)
        self.thread_id = self._create_thread()

    # Internal methods
    def _initialize_assistant(self, assistant_id):
        file_ids = self._generate_file_ids(self.data_sources)
        self.assistant = (
            self._client.beta.assistants.retrieve(assistant_id)
            if assistant_id
            else self._client.beta.assistants.create(
                name=self.name, model=self.model, file_ids=file_ids, instructions=self.instructions, tools=self.tools
            )
        )

    def _create_thread(self):
        thread = self._client.beta.threads.create()
        return thread.id

    def _prepare_source_path(self, source, data_type=None):
        if Path(source).is_file():
            return source
        data_type = data_type or detect_datatype(source)
        formatter = DataFormatter(data_type=DataType(data_type), config=AddConfig())
        data = formatter.loader.load_data(source)["data"]
        return self._save_temp_data(data=data[0]["content"].encode(), source=source)

    def _add_file_to_assistant(self, file_path):
        file_obj = self._client.files.create(file=open(file_path, "rb"), purpose="assistants")
        self._client.beta.assistants.files.create(assistant_id=self.assistant.id, file_id=file_obj.id)
    def _generate_file_ids(self, data_sources):
        # Upload the initial data sources and return their file ids so they can be
        # attached via `file_ids` when the assistant is created (self.assistant does
        # not exist yet at this point).
        file_ids = []
        for ds in data_sources:
            file_path = self._prepare_source_path(ds["source"], ds.get("data_type"))
            file_obj = self._client.files.create(file=open(file_path, "rb"), purpose="assistants")
            file_ids.append(file_obj.id)
        return file_ids
    def _send_message(self, message):
        self._client.beta.threads.messages.create(thread_id=self.thread_id, role="user", content=message)
        self._wait_for_completion()

    def _wait_for_completion(self):
        run = self._client.beta.threads.runs.create(
            thread_id=self.thread_id,
            assistant_id=self.assistant.id,
            instructions=self.instructions,
        )
        run_id = run.id
        run_status = run.status
        while run_status in ["queued", "in_progress", "requires_action"]:
            time.sleep(0.1)  # Sleep before making the next API call to avoid hitting rate limits
            run = self._client.beta.threads.runs.retrieve(thread_id=self.thread_id, run_id=run_id)
            run_status = run.status
        if run_status == "failed":
            raise ValueError(f"Thread run failed with the following error: {run.last_error}")

    def _get_latest_response(self):
        history = self._get_history()
        return self._format_message(history[0]) if history else None

    def _get_history(self):
        messages = self._client.beta.threads.messages.list(thread_id=self.thread_id, order="desc")
        return list(messages)

    @staticmethod
    def _format_message(thread_message):
        thread_message = cast(ThreadMessage, thread_message)
        content = [c.text.value for c in thread_message.content if isinstance(c, MessageContentText)]
        return " ".join(content)

    @staticmethod
    def _save_temp_data(data, source):
        special_chars_pattern = r'[\\/:*?"<>|&=% ]+'
        sanitized_source = re.sub(special_chars_pattern, "_", source)[:256]
        temp_dir = tempfile.mkdtemp()
        file_path = os.path.join(temp_dir, sanitized_source)
        with open(file_path, "wb") as file:
            file.write(data)
        return file_path


class AIAssistant:
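    """Assistant backed by a local embedchain ``Pipeline`` rather than the OpenAI Assistants API.

    Added sources are indexed with metadata that scopes them to this ``assistant_id``
    and either the current ``thread_id`` or the shared ``"global_knowledge"`` scope;
    ``chat`` restricts retrieval to those scopes.
    """
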
    def __init__(
        self,
        name=None,
        instructions=None,
        yaml_path=None,
        assistant_id=None,
        thread_id=None,
        data_sources=None,
        log_level=logging.WARN,
        collect_metrics=True,
    ):
        logging.basicConfig(level=log_level)
        self.name = name or "AI Assistant"
        self.data_sources = data_sources or []
        self.log_level = log_level
        self.instructions = instructions
        self.assistant_id = assistant_id or str(uuid.uuid4())
        self.thread_id = thread_id or str(uuid.uuid4())
        self.pipeline = Pipeline.from_config(config_path=yaml_path) if yaml_path else Pipeline()
        self.pipeline.local_id = self.pipeline.config.id = self.thread_id

        if self.instructions:
            self.pipeline.system_prompt = self.instructions
        print(
            f"🎉 Created AI Assistant with name: {self.name}, assistant_id: {self.assistant_id}, thread_id: {self.thread_id}"  # noqa: E501
        )

        # telemetry related properties
        self._telemetry_props = {"class": self.__class__.__name__}
        self.telemetry = AnonymousTelemetry(enabled=collect_metrics)
        self.telemetry.capture(event_name="init", properties=self._telemetry_props)

        if self.data_sources:
            for data_source in self.data_sources:
                metadata = {"assistant_id": self.assistant_id, "thread_id": "global_knowledge"}
                self.pipeline.add(data_source["source"], data_source.get("data_type"), metadata=metadata)

    def add(self, source, data_type=None):
        metadata = {"assistant_id": self.assistant_id, "thread_id": self.thread_id}
        self.pipeline.add(source, data_type=data_type, metadata=metadata)

        event_props = {
            **self._telemetry_props,
            "data_type": data_type or detect_datatype(source),
        }
        self.telemetry.capture(event_name="add", properties=event_props)

    def chat(self, query):
        where = {
            "$and": [
                {"assistant_id": {"$eq": self.assistant_id}},
                {"thread_id": {"$in": [self.thread_id, "global_knowledge"]}},
            ]
        }
        return self.pipeline.chat(query, where=where)

    def delete(self):
        self.pipeline.reset()
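

# Minimal usage sketch (illustrative only): assumes OPENAI_API_KEY is set in the
# environment and that network access to the OpenAI API is available. The file
# path and questions below are placeholders, not files shipped with the library.
if __name__ == "__main__":
    assistant = OpenAIAssistant(
        name="Docs helper",
        instructions="Answer questions using the attached documents.",
    )
    assistant.add("data/notes.pdf")  # placeholder path
    print(assistant.chat("Summarize the attached notes."))

    # The local-pipeline variant works the same way but stores data locally:
    # ai = AIAssistant(name="Local helper")
    # ai.add("data/notes.pdf")
    # print(ai.chat("Summarize the attached notes."))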