[Feature] Add support for OpenAI assistants and support openai version >=1.0.0 (#921)
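This change adds an `OpenAIAssistant` wrapper (embedchain/store/assistants.py) around the OpenAI Assistants API, vendors Chroma's `OpenAIEmbeddingFunction` (embedchain/embedder/chroma_embeddings.py) so embeddings work with both openai<1.0.0 and openai>=1.0.0, updates the GPT4All default model to its GGUF build, lowers the default pipeline log level from INFO to WARN, and reflows several imports.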
@@ -2,7 +2,8 @@ from typing import Optional
 
 import yaml
 
-from embedchain.config import AppConfig, BaseEmbedderConfig, BaseLlmConfig, ChunkerConfig
+from embedchain.config import (AppConfig, BaseEmbedderConfig, BaseLlmConfig,
+                               ChunkerConfig)
 from embedchain.config.vectordb.base import BaseVectorDbConfig
 from embedchain.embedchain import EmbedChain
 from embedchain.embedder.base import BaseEmbedder
@@ -17,7 +17,8 @@ from embedchain.embedder.base import BaseEmbedder
 from embedchain.helper.json_serializable import JSONSerializable
 from embedchain.llm.base import BaseLlm
 from embedchain.loaders.base_loader import BaseLoader
-from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
+from embedchain.models.data_type import (DataType, DirectDataType,
+                                         IndirectDataType, SpecialDataType)
 from embedchain.telemetry.posthog import AnonymousTelemetry
 from embedchain.utils import detect_datatype, is_valid_json_string
 from embedchain.vectordb.base import BaseVectorDB
104  embedchain/embedder/chroma_embeddings.py  Normal file
@@ -0,0 +1,104 @@
+"""
+Note that this file is copied from the Chroma repository. We will remove it
+once the fix lands in ChromaDB's repository.
+"""
+
+from typing import Optional
+
+from chromadb.api.types import Documents, Embeddings
+
+
+class OpenAIEmbeddingFunction:
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        model_name: str = "text-embedding-ada-002",
+        organization_id: Optional[str] = None,
+        api_base: Optional[str] = None,
+        api_type: Optional[str] = None,
+        api_version: Optional[str] = None,
+        deployment_id: Optional[str] = None,
+    ):
+        """
+        Initialize the OpenAIEmbeddingFunction.
+        Args:
+            api_key (str, optional): Your API key for the OpenAI API. If not
+                provided, an error is raised asking for an OpenAI API key.
+            organization_id (str, optional): The OpenAI organization ID, if applicable.
+            model_name (str, optional): The name of the model to use for text
+                embeddings. Defaults to "text-embedding-ada-002".
+            api_base (str, optional): The base path for the API. If not provided,
+                it will use the base path for the OpenAI API. This can be used to
+                point to a different deployment, such as an Azure deployment.
+            api_type (str, optional): The type of the API deployment. This can be
+                used to specify a different deployment, such as 'azure'. If not
+                provided, it will use the default OpenAI deployment.
+            api_version (str, optional): The API version. If not provided, it will
+                use the API version for the OpenAI API. This can be used to
+                point to a different deployment, such as an Azure deployment.
+            deployment_id (str, optional): Deployment ID for Azure OpenAI.
+
+        """
+        try:
+            import openai
+        except ImportError:
+            raise ValueError("The openai python package is not installed. Please install it with `pip install openai`")
+
+        if api_key is not None:
+            openai.api_key = api_key
+        # If the api key is still not set, raise an error
+        elif openai.api_key is None:
+            raise ValueError(
+                "Please provide an OpenAI API key. You can get one at https://platform.openai.com/account/api-keys"
+            )
+
+        if api_base is not None:
+            openai.api_base = api_base
+
+        if api_version is not None:
+            openai.api_version = api_version
+
+        self._api_type = api_type
+        if api_type is not None:
+            openai.api_type = api_type
+
+        if organization_id is not None:
+            openai.organization = organization_id
+
+        self._v1 = openai.__version__.startswith("1.")
+        if self._v1:
+            if api_type == "azure":
+                self._client = openai.AzureOpenAI(
+                    api_key=api_key, api_version=api_version, azure_endpoint=api_base
+                ).embeddings
+            else:
+                self._client = openai.OpenAI(api_key=api_key, base_url=api_base).embeddings
+        else:
+            self._client = openai.Embedding
+        self._model_name = model_name
+        self._deployment_id = deployment_id
+
+    def __call__(self, input: Documents) -> Embeddings:
+        # replace newlines, which can negatively affect performance.
+        input = [t.replace("\n", " ") for t in input]
+
+        # Call the OpenAI Embedding API
+        if self._v1:
+            embeddings = self._client.create(input=input, model=self._deployment_id or self._model_name).data
+
+            # Sort resulting embeddings by index
+            sorted_embeddings = sorted(embeddings, key=lambda e: e.index)  # type: ignore
+
+            # Return just the embeddings
+            return [result.embedding for result in sorted_embeddings]
+        else:
+            if self._api_type == "azure":
+                embeddings = self._client.create(input=input, engine=self._deployment_id or self._model_name)["data"]
+            else:
+                embeddings = self._client.create(input=input, model=self._model_name)["data"]
+
+            # Sort resulting embeddings by index
+            sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])  # type: ignore
+
+            # Return just the embeddings
+            return [result["embedding"] for result in sorted_embeddings]
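For review context, a minimal usage sketch of the vendored function (not part of the diff; the key and sample texts are placeholders). The callable follows Chroma's embedding-function protocol: it takes a list of documents and returns one embedding per document, routing through `openai.OpenAI(...).embeddings` on openai>=1.0.0 and through `openai.Embedding` otherwise:

    from embedchain.embedder.chroma_embeddings import OpenAIEmbeddingFunction

    # Placeholder key and inputs; any list of strings works.
    embed_fn = OpenAIEmbeddingFunction(api_key="sk-...", model_name="text-embedding-ada-002")
    vectors = embed_fn(["hello world", "openai>=1.0.0 is now supported"])
    assert len(vectors) == 2  # one embedding (a list of floats) per input document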
@@ -7,13 +7,7 @@ from embedchain.config import BaseEmbedderConfig
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.models import VectorDimensions
 
-try:
-    from chromadb.utils import embedding_functions
-except RuntimeError:
-    from embedchain.utils import use_pysqlite3
-
-    use_pysqlite3()
-    from chromadb.utils import embedding_functions
+from .chroma_embeddings import OpenAIEmbeddingFunction
 
 
 class OpenAIEmbedder(BaseEmbedder):
@@ -30,11 +24,10 @@ class OpenAIEmbedder(BaseEmbedder):
             raise ValueError(
                 "OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided"
             )  # noqa:E501
-        embedding_fn = embedding_functions.OpenAIEmbeddingFunction(
+        embedding_fn = OpenAIEmbeddingFunction(
             api_key=os.getenv("OPENAI_API_KEY"),
             organization_id=os.getenv("OPENAI_ORGANIZATION"),
             model_name=self.config.model,
         )
-
         self.set_embedding_fn(embedding_fn=embedding_fn)
         self.set_vector_dimension(vector_dimension=VectorDimensions.OPENAI.value)
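With the vendored function wired in, the embedder keeps its existing surface. A configuration sketch (assumes the usual embedchain constructor signature; the key is a placeholder and OPENAI_ORGANIZATION is optional):

    import os
    from embedchain.config import BaseEmbedderConfig
    from embedchain.embedder.openai import OpenAIEmbedder

    os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder
    embedder = OpenAIEmbedder(config=BaseEmbedderConfig(model="text-embedding-ada-002"))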
@@ -13,7 +13,7 @@ class GPT4ALLLlm(BaseLlm):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
         super().__init__(config=config)
         if self.config.model is None:
-            self.config.model = "orca-mini-3b.ggmlv3.q4_0.bin"
+            self.config.model = "orca-mini-3b-gguf2-q4_0.gguf"
         self.instance = GPT4ALLLlm._get_instance(self.config.model)
         self.instance.streaming = self.config.stream
 
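The default model moves from the GGML build to the GGUF build that newer gpt4all releases expect. A sketch of the effect (assumes gpt4all 2.x is installed and the embedchain.llm.gpt4all module path; the model file is downloaded on first use):

    from embedchain.config import BaseLlmConfig
    from embedchain.llm.gpt4all import GPT4ALLLlm

    # config.model is None here, so the new "orca-mini-3b-gguf2-q4_0.gguf" default applies
    llm = GPT4ALLLlm(config=BaseLlmConfig())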
@@ -9,7 +9,7 @@ import requests
 import yaml
 
 from embedchain import Client
-from embedchain.config import PipelineConfig, ChunkerConfig
+from embedchain.config import ChunkerConfig, PipelineConfig
 from embedchain.embedchain import CONFIG_DIR, EmbedChain
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.embedder.openai import OpenAIEmbedder
@@ -42,7 +42,7 @@ class Pipeline(EmbedChain):
         embedding_model: BaseEmbedder = None,
         llm: BaseLlm = None,
         yaml_path: str = None,
-        log_level=logging.INFO,
+        log_level=logging.WARN,
         auto_deploy: bool = False,
         chunker: ChunkerConfig = None,
     ):
@@ -59,7 +59,7 @@ class Pipeline(EmbedChain):
         :type llm: BaseLlm, optional
         :param yaml_path: Path to the YAML configuration file, defaults to None
         :type yaml_path: str, optional
-        :param log_level: Log level to use, defaults to logging.INFO
+        :param log_level: Log level to use, defaults to logging.WARN
         :type log_level: int, optional
         :param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
         :type auto_deploy: bool, optional
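Callers that relied on the chattier INFO default can opt back in. A sketch (assumes Pipeline is importable from the package root, as elsewhere in embedchain):

    import logging
    from embedchain import Pipeline

    app = Pipeline(log_level=logging.INFO)  # restore the previous verbosity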
0    embedchain/store/__init__.py   Normal file
125  embedchain/store/assistants.py  Normal file
@@ -0,0 +1,125 @@
+import logging
+import os
+import tempfile
+import time
+from pathlib import Path
+from typing import cast
+
+from openai import OpenAI
+from openai.types.beta.threads import MessageContentText, ThreadMessage
+
+from embedchain.config import AddConfig
+from embedchain.data_formatter import DataFormatter
+from embedchain.models.data_type import DataType
+from embedchain.utils import detect_datatype
+
+logging.basicConfig(level=logging.WARN)
+
+
+class OpenAIAssistant:
+    def __init__(
+        self,
+        name=None,
+        instructions=None,
+        tools=None,
+        thread_id=None,
+        model="gpt-4-1106-preview",
+        data_sources=None,
+        assistant_id=None,
+        log_level=logging.WARN,
+    ):
+        self.name = name or "OpenAI Assistant"
+        self.instructions = instructions
+        self.tools = tools or [{"type": "retrieval"}]
+        self.model = model
+        self.data_sources = data_sources or []
+        self.log_level = log_level
+        self._client = OpenAI()
+        self._initialize_assistant(assistant_id)
+        self.thread_id = thread_id or self._create_thread()
+
+    def add(self, source, data_type=None):
+        file_path = self._prepare_source_path(source, data_type)
+        self._add_file_to_assistant(file_path)
+        logging.info("Data successfully added to the assistant.")
+
+    def chat(self, message):
+        self._send_message(message)
+        return self._get_latest_response()
+
+    def delete_thread(self):
+        self._client.beta.threads.delete(self.thread_id)
+        self.thread_id = self._create_thread()
+
+    # Internal methods
+    def _initialize_assistant(self, assistant_id):
+        file_ids = self._generate_file_ids(self.data_sources)
+        self.assistant = (
+            self._client.beta.assistants.retrieve(assistant_id)
+            if assistant_id
+            else self._client.beta.assistants.create(
+                name=self.name, model=self.model, file_ids=file_ids, instructions=self.instructions, tools=self.tools
+            )
+        )
+
+    def _create_thread(self):
+        thread = self._client.beta.threads.create()
+        return thread.id
+
+    def _prepare_source_path(self, source, data_type=None):
+        if Path(source).is_file():
+            return source
+        data_type = data_type or detect_datatype(source)
+        formatter = DataFormatter(data_type=DataType(data_type), config=AddConfig())
+        data = formatter.loader.load_data(source)["data"]
+        return self._save_temp_data(data[0]["content"].encode())
+
+    def _add_file_to_assistant(self, file_path):
+        file_obj = self._client.files.create(file=open(file_path, "rb"), purpose="assistants")
+        self._client.beta.assistants.files.create(assistant_id=self.assistant.id, file_id=file_obj.id)
+
+    def _generate_file_ids(self, data_sources):
+        return [
+            self._add_file_to_assistant(self._prepare_source_path(ds["source"], ds.get("data_type")))
+            for ds in data_sources
+        ]
+
+    def _send_message(self, message):
+        self._client.beta.threads.messages.create(thread_id=self.thread_id, role="user", content=message)
+        self._wait_for_completion()
+
+    def _wait_for_completion(self):
+        run = self._client.beta.threads.runs.create(
+            thread_id=self.thread_id,
+            assistant_id=self.assistant.id,
+            instructions=self.instructions,
+        )
+        run_id = run.id
+        run_status = run.status
+
+        while run_status in ["queued", "in_progress", "requires_action"]:
+            time.sleep(0.1)  # Sleep before making the next API call to avoid hitting rate limits
+            run = self._client.beta.threads.runs.retrieve(thread_id=self.thread_id, run_id=run_id)
+            run_status = run.status
+            if run_status == "failed":
+                raise ValueError(f"Thread run failed with the following error: {run.last_error}")
+
+    def _get_latest_response(self):
+        history = self._get_history()
+        return self._format_message(history[0]) if history else None
+
+    def _get_history(self):
+        messages = self._client.beta.threads.messages.list(thread_id=self.thread_id, order="desc")
+        return list(messages)
+
+    def _format_message(self, thread_message):
+        thread_message = cast(ThreadMessage, thread_message)
+        content = [c.text.value for c in thread_message.content if isinstance(c, MessageContentText)]
+        return " ".join(content)
+
+    def _save_temp_data(self, data):
+        temp_dir = tempfile.mkdtemp()
+        file_path = os.path.join(temp_dir, "temp_data")
+        with open(file_path, "wb") as file:
+            file.write(data)
+        return file_path
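Taken together, the class exposes a small public surface: `add` uploads a source (loading non-file sources through embedchain's DataFormatter first), `chat` sends a message and blocks until the run finishes, and `delete_thread` discards the conversation. A usage sketch (not part of the diff; the file name and question are placeholders, and OPENAI_API_KEY must be set in the environment for the OpenAI() client):

    from embedchain.store.assistants import OpenAIAssistant

    assistant = OpenAIAssistant(name="Docs Helper", instructions="Answer only from the attached files.")
    assistant.add("notes.txt")  # placeholder local file; URLs are routed through detect_datatype
    print(assistant.chat("Summarize the attached notes."))  # blocks until the run completes
    assistant.delete_thread()  # start a fresh conversation thread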
@@ -138,7 +138,8 @@ def detect_datatype(source: Any) -> DataType:
     formatted_source = format_source(str(source), 30)
 
     if url:
-        from langchain.document_loaders.youtube import ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
+        from langchain.document_loaders.youtube import \
+            ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
 
         if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
             logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")