diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py
index a764d13f..991397cb 100644
--- a/embedchain/embedchain.py
+++ b/embedchain/embedchain.py
@@ -482,7 +482,7 @@ class EmbedChain(JSONSerializable):
         where = {}
         if query_config is not None and query_config.where is not None:
             where = query_config.where
-
+
         if self.config.id is not None:
             where.update({"app_id": self.config.id})
 
diff --git a/embedchain/store/assistants.py b/embedchain/store/assistants.py
index 938af55a..e3bde8de 100644
--- a/embedchain/store/assistants.py
+++ b/embedchain/store/assistants.py
@@ -3,12 +3,14 @@ import os
 import re
 import tempfile
 import time
+import uuid
 from pathlib import Path
 from typing import cast
 
 from openai import OpenAI
 from openai.types.beta.threads import MessageContentText, ThreadMessage
 
+from embedchain import Pipeline
 from embedchain.config import AddConfig
 from embedchain.data_formatter import DataFormatter
 from embedchain.models.data_type import DataType
@@ -138,3 +140,65 @@ class OpenAIAssistant:
         with open(file_path, "wb") as file:
             file.write(data)
         return file_path
+
+
+class AIAssistant:
+    def __init__(
+        self,
+        name=None,
+        instructions=None,
+        yaml_path=None,
+        assistant_id=None,
+        thread_id=None,
+        data_sources=None,
+        log_level=logging.WARN,
+        collect_metrics=True,
+    ):
+        logging.basicConfig(level=log_level)
+
+        self.name = name or "AI Assistant"
+        self.data_sources = data_sources or []
+        self.log_level = log_level
+        self.instructions = instructions
+        self.assistant_id = assistant_id or str(uuid.uuid4())
+        self.thread_id = thread_id or str(uuid.uuid4())
+        self.pipeline = Pipeline.from_config(yaml_path=yaml_path) if yaml_path else Pipeline()
+        self.pipeline.local_id = self.pipeline.config.id = self.thread_id
+
+        if self.instructions:
+            self.pipeline.system_prompt = self.instructions
+
+        print(
+            f"🎉 Created AI Assistant with name: {self.name}, assistant_id: {self.assistant_id}, thread_id: {self.thread_id}"  # noqa: E501
+        )
+
+        # telemetry related properties
+        self._telemetry_props = {"class": self.__class__.__name__}
+        self.telemetry = AnonymousTelemetry(enabled=collect_metrics)
+        self.telemetry.capture(event_name="init", properties=self._telemetry_props)
+
+        if self.data_sources:
+            for data_source in self.data_sources:
+                metadata = {"assistant_id": self.assistant_id, "thread_id": "global_knowledge"}
+                self.pipeline.add(data_source["source"], data_source.get("data_type"), metadata=metadata)
+
+    def add(self, source, data_type=None):
+        metadata = {"assistant_id": self.assistant_id, "thread_id": self.thread_id}
+        self.pipeline.add(source, data_type=data_type, metadata=metadata)
+        event_props = {
+            **self._telemetry_props,
+            "data_type": data_type or detect_datatype(source),
+        }
+        self.telemetry.capture(event_name="add", properties=event_props)
+
+    def chat(self, query):
+        where = {
+            "$and": [
+                {"assistant_id": {"$eq": self.assistant_id}},
+                {"thread_id": {"$in": [self.thread_id, "global_knowledge"]}},
+            ]
+        }
+        return self.pipeline.chat(query, where=where)
+
+    def delete(self):
+        self.pipeline.reset()
diff --git a/pyproject.toml b/pyproject.toml
index 6b6b391f..317a8026 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "embedchain"
-version = "0.1.6"
+version = "0.1.7"
 description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
 authors = [
     "Taranjeet Singh ",
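
A minimal usage sketch of the new AIAssistant class added in embedchain/store/assistants.py, based only on the constructor and the add, chat, and delete methods in the diff above; the data-source URL and the query string are placeholders, not part of this change.

    from embedchain.store.assistants import AIAssistant

    # Create an assistant backed by a default Pipeline.
    # The instructions string becomes the pipeline's system prompt.
    assistant = AIAssistant(
        name="Docs Helper",
        instructions="Answer questions using only the indexed documentation.",
    )

    # Index a source; it is tagged with this assistant's assistant_id and thread_id.
    assistant.add("https://example.com/docs")  # placeholder URL

    # chat() filters retrieval to this thread plus any "global_knowledge" sources.
    answer = assistant.chat("How do I install the project?")  # placeholder query
    print(answer)

    # Reset the underlying pipeline when the assistant is no longer needed.
    assistant.delete()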