[Feature] Add support for AIAssistant (#938)
This commit is contained in:
@@ -482,7 +482,7 @@ class EmbedChain(JSONSerializable):
|
|||||||
where = {}
|
where = {}
|
||||||
if query_config is not None and query_config.where is not None:
|
if query_config is not None and query_config.where is not None:
|
||||||
where = query_config.where
|
where = query_config.where
|
||||||
|
|
||||||
if self.config.id is not None:
|
if self.config.id is not None:
|
||||||
where.update({"app_id": self.config.id})
|
where.update({"app_id": self.config.id})
|
||||||
|
|
||||||
|
|||||||
@@ -3,12 +3,14 @@ import os
|
|||||||
import re
|
import re
|
||||||
import tempfile
|
import tempfile
|
||||||
import time
|
import time
|
||||||
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
from openai.types.beta.threads import MessageContentText, ThreadMessage
|
from openai.types.beta.threads import MessageContentText, ThreadMessage
|
||||||
|
|
||||||
|
from embedchain import Pipeline
|
||||||
from embedchain.config import AddConfig
|
from embedchain.config import AddConfig
|
||||||
from embedchain.data_formatter import DataFormatter
|
from embedchain.data_formatter import DataFormatter
|
||||||
from embedchain.models.data_type import DataType
|
from embedchain.models.data_type import DataType
|
||||||
@@ -138,3 +140,65 @@ class OpenAIAssistant:
|
|||||||
with open(file_path, "wb") as file:
|
with open(file_path, "wb") as file:
|
||||||
file.write(data)
|
file.write(data)
|
||||||
return file_path
|
return file_path
|
||||||
|
|
||||||
|
|
||||||
|
class AIAssistant:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name=None,
|
||||||
|
instructions=None,
|
||||||
|
yaml_path=None,
|
||||||
|
assistant_id=None,
|
||||||
|
thread_id=None,
|
||||||
|
data_sources=None,
|
||||||
|
log_level=logging.WARN,
|
||||||
|
collect_metrics=True,
|
||||||
|
):
|
||||||
|
logging.basicConfig(level=log_level)
|
||||||
|
|
||||||
|
self.name = name or "AI Assistant"
|
||||||
|
self.data_sources = data_sources or []
|
||||||
|
self.log_level = log_level
|
||||||
|
self.instructions = instructions
|
||||||
|
self.assistant_id = assistant_id or str(uuid.uuid4())
|
||||||
|
self.thread_id = thread_id or str(uuid.uuid4())
|
||||||
|
self.pipeline = Pipeline.from_config(yaml_path=yaml_path) if yaml_path else Pipeline()
|
||||||
|
self.pipeline.local_id = self.pipeline.config.id = self.thread_id
|
||||||
|
|
||||||
|
if self.instructions:
|
||||||
|
self.pipeline.system_prompt = self.instructions
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"🎉 Created AI Assistant with name: {self.name}, assistant_id: {self.assistant_id}, thread_id: {self.thread_id}" # noqa: E501
|
||||||
|
)
|
||||||
|
|
||||||
|
# telemetry related properties
|
||||||
|
self._telemetry_props = {"class": self.__class__.__name__}
|
||||||
|
self.telemetry = AnonymousTelemetry(enabled=collect_metrics)
|
||||||
|
self.telemetry.capture(event_name="init", properties=self._telemetry_props)
|
||||||
|
|
||||||
|
if self.data_sources:
|
||||||
|
for data_source in self.data_sources:
|
||||||
|
metadata = {"assistant_id": self.assistant_id, "thread_id": "global_knowledge"}
|
||||||
|
self.pipeline.add(data_source["source"], data_source.get("data_type"), metadata=metadata)
|
||||||
|
|
||||||
|
def add(self, source, data_type=None):
|
||||||
|
metadata = {"assistant_id": self.assistant_id, "thread_id": self.thread_id}
|
||||||
|
self.pipeline.add(source, data_type=data_type, metadata=metadata)
|
||||||
|
event_props = {
|
||||||
|
**self._telemetry_props,
|
||||||
|
"data_type": data_type or detect_datatype(source),
|
||||||
|
}
|
||||||
|
self.telemetry.capture(event_name="add", properties=event_props)
|
||||||
|
|
||||||
|
def chat(self, query):
|
||||||
|
where = {
|
||||||
|
"$and": [
|
||||||
|
{"assistant_id": {"$eq": self.assistant_id}},
|
||||||
|
{"thread_id": {"$in": [self.thread_id, "global_knowledge"]}},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
return self.pipeline.chat(query, where=where)
|
||||||
|
|
||||||
|
def delete(self):
|
||||||
|
self.pipeline.reset()
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "embedchain"
|
name = "embedchain"
|
||||||
version = "0.1.6"
|
version = "0.1.7"
|
||||||
description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
|
description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
|
||||||
authors = [
|
authors = [
|
||||||
"Taranjeet Singh <taranjeet@embedchain.ai>",
|
"Taranjeet Singh <taranjeet@embedchain.ai>",
|
||||||
|
|||||||
Reference in New Issue
Block a user