[Feature] Add support for AIAssistant (#938)
@@ -482,7 +482,7 @@ class EmbedChain(JSONSerializable):
        where = {}
        if query_config is not None and query_config.where is not None:
            where = query_config.where

        if self.config.id is not None:
            where.update({"app_id": self.config.id})
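
The hunk above shows EmbedChain merging the caller-supplied where filter with an app_id constraint so retrieval stays scoped to the current app. A minimal sketch of that merging behaviour, using a hypothetical caller filter and app id (both placeholders, not from the commit):

# Hypothetical values for illustration only.
query_where = {"url": {"$eq": "https://example.com/doc"}}  # stands in for query_config.where (assumed)
app_id = "my-app"                                          # stands in for self.config.id (assumed)

where = {}
if query_where is not None:
    where = query_where
if app_id is not None:
    where.update({"app_id": app_id})

# where == {"url": {"$eq": "https://example.com/doc"}, "app_id": "my-app"}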

@@ -3,12 +3,14 @@ import os
import re
import tempfile
import time
import uuid
from pathlib import Path
from typing import cast

from openai import OpenAI
from openai.types.beta.threads import MessageContentText, ThreadMessage

from embedchain import Pipeline
from embedchain.config import AddConfig
from embedchain.data_formatter import DataFormatter
from embedchain.models.data_type import DataType

@@ -138,3 +140,65 @@ class OpenAIAssistant:
        with open(file_path, "wb") as file:
            file.write(data)
        return file_path


class AIAssistant:
    def __init__(
        self,
        name=None,
        instructions=None,
        yaml_path=None,
        assistant_id=None,
        thread_id=None,
        data_sources=None,
        log_level=logging.WARN,
        collect_metrics=True,
    ):
        logging.basicConfig(level=log_level)

        self.name = name or "AI Assistant"
        self.data_sources = data_sources or []
        self.log_level = log_level
        self.instructions = instructions
        self.assistant_id = assistant_id or str(uuid.uuid4())
        self.thread_id = thread_id or str(uuid.uuid4())
        self.pipeline = Pipeline.from_config(yaml_path=yaml_path) if yaml_path else Pipeline()
        self.pipeline.local_id = self.pipeline.config.id = self.thread_id

        if self.instructions:
            self.pipeline.system_prompt = self.instructions

        print(
            f"🎉 Created AI Assistant with name: {self.name}, assistant_id: {self.assistant_id}, thread_id: {self.thread_id}"  # noqa: E501
        )

        # telemetry related properties
        self._telemetry_props = {"class": self.__class__.__name__}
        self.telemetry = AnonymousTelemetry(enabled=collect_metrics)
        self.telemetry.capture(event_name="init", properties=self._telemetry_props)

        if self.data_sources:
            for data_source in self.data_sources:
                metadata = {"assistant_id": self.assistant_id, "thread_id": "global_knowledge"}
                self.pipeline.add(data_source["source"], data_source.get("data_type"), metadata=metadata)

    def add(self, source, data_type=None):
        metadata = {"assistant_id": self.assistant_id, "thread_id": self.thread_id}
        self.pipeline.add(source, data_type=data_type, metadata=metadata)
        event_props = {
            **self._telemetry_props,
            "data_type": data_type or detect_datatype(source),
        }
        self.telemetry.capture(event_name="add", properties=event_props)

    def chat(self, query):
        where = {
            "$and": [
                {"assistant_id": {"$eq": self.assistant_id}},
                {"thread_id": {"$in": [self.thread_id, "global_knowledge"]}},
            ]
        }
        return self.pipeline.chat(query, where=where)

    def delete(self):
        self.pipeline.reset()
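
Taken together, the new AIAssistant wraps a Pipeline and tags everything it indexes with an assistant_id plus a thread_id, while chat() restricts retrieval to the current thread and the shared "global_knowledge" bucket via a Chroma-style $and/$in filter. A minimal usage sketch of the API added above; the import path is inferred from the OpenAIAssistant hunk context, and the URL and question are placeholders:

from embedchain.store.assistants import AIAssistant  # module path assumed from the hunk context

# Create an assistant; assistant_id and thread_id are generated when not supplied.
assistant = AIAssistant(
    name="Docs Helper",
    instructions="Answer only from the indexed documents.",
)

# Index a source; it is stored with this assistant_id and thread_id as metadata.
assistant.add("https://example.com/faq.html")  # placeholder URL

# Retrieval is filtered to this thread plus anything tagged "global_knowledge".
print(assistant.chat("What does the FAQ say about refunds?"))  # placeholder question

# Reset the underlying pipeline's stored data.
assistant.delete()

Data sources passed at construction time are indexed under thread_id "global_knowledge", which is why every thread's chat() filter includes that value.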

@@ -1,6 +1,6 @@
[tool.poetry]
name = "embedchain"
-version = "0.1.6"
+version = "0.1.7"
description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
authors = [
    "Taranjeet Singh <taranjeet@embedchain.ai>",