[Feature] Add support for AIAssistant (#938)

This commit is contained in:
Deshraj Yadav
2023-11-10 16:47:34 -08:00
committed by GitHub
parent deaa7f50f8
commit 1364975396
3 changed files with 66 additions and 2 deletions

View File

@@ -482,7 +482,7 @@ class EmbedChain(JSONSerializable):
where = {}
if query_config is not None and query_config.where is not None:
where = query_config.where
if self.config.id is not None:
where.update({"app_id": self.config.id})

View File

@@ -3,12 +3,14 @@ import os
import re
import tempfile
import time
import uuid
from pathlib import Path
from typing import cast
from openai import OpenAI
from openai.types.beta.threads import MessageContentText, ThreadMessage
from embedchain import Pipeline
from embedchain.config import AddConfig
from embedchain.data_formatter import DataFormatter
from embedchain.models.data_type import DataType
@@ -138,3 +140,65 @@ class OpenAIAssistant:
with open(file_path, "wb") as file:
file.write(data)
return file_path
class AIAssistant:
def __init__(
self,
name=None,
instructions=None,
yaml_path=None,
assistant_id=None,
thread_id=None,
data_sources=None,
log_level=logging.WARN,
collect_metrics=True,
):
logging.basicConfig(level=log_level)
self.name = name or "AI Assistant"
self.data_sources = data_sources or []
self.log_level = log_level
self.instructions = instructions
self.assistant_id = assistant_id or str(uuid.uuid4())
self.thread_id = thread_id or str(uuid.uuid4())
self.pipeline = Pipeline.from_config(yaml_path=yaml_path) if yaml_path else Pipeline()
self.pipeline.local_id = self.pipeline.config.id = self.thread_id
if self.instructions:
self.pipeline.system_prompt = self.instructions
print(
f"🎉 Created AI Assistant with name: {self.name}, assistant_id: {self.assistant_id}, thread_id: {self.thread_id}" # noqa: E501
)
# telemetry related properties
self._telemetry_props = {"class": self.__class__.__name__}
self.telemetry = AnonymousTelemetry(enabled=collect_metrics)
self.telemetry.capture(event_name="init", properties=self._telemetry_props)
if self.data_sources:
for data_source in self.data_sources:
metadata = {"assistant_id": self.assistant_id, "thread_id": "global_knowledge"}
self.pipeline.add(data_source["source"], data_source.get("data_type"), metadata=metadata)
def add(self, source, data_type=None):
metadata = {"assistant_id": self.assistant_id, "thread_id": self.thread_id}
self.pipeline.add(source, data_type=data_type, metadata=metadata)
event_props = {
**self._telemetry_props,
"data_type": data_type or detect_datatype(source),
}
self.telemetry.capture(event_name="add", properties=event_props)
def chat(self, query):
where = {
"$and": [
{"assistant_id": {"$eq": self.assistant_id}},
{"thread_id": {"$in": [self.thread_id, "global_knowledge"]}},
]
}
return self.pipeline.chat(query, where=where)
def delete(self):
self.pipeline.reset()

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "embedchain"
version = "0.1.6"
version = "0.1.7"
description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
authors = [
"Taranjeet Singh <taranjeet@embedchain.ai>",