[Feature] Add support for OpenAI assistants and support openai version >=1.0.0 (#921)

Authored by Deshraj Yadav on 2023-11-08 22:49:03 -08:00; committed by GitHub
parent d8cdbe0041
commit f7dd65a3de
28 changed files with 621 additions and 247 deletions

View File

@@ -2,7 +2,8 @@ from typing import Optional
import yaml
-from embedchain.config import AppConfig, BaseEmbedderConfig, BaseLlmConfig, ChunkerConfig
+from embedchain.config import (AppConfig, BaseEmbedderConfig, BaseLlmConfig,
+                               ChunkerConfig)
from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.embedchain import EmbedChain
from embedchain.embedder.base import BaseEmbedder

View File

@@ -17,7 +17,8 @@ from embedchain.embedder.base import BaseEmbedder
from embedchain.helper.json_serializable import JSONSerializable
from embedchain.llm.base import BaseLlm
from embedchain.loaders.base_loader import BaseLoader
-from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
+from embedchain.models.data_type import (DataType, DirectDataType,
+                                         IndirectDataType, SpecialDataType)
from embedchain.telemetry.posthog import AnonymousTelemetry
from embedchain.utils import detect_datatype, is_valid_json_string
from embedchain.vectordb.base import BaseVectorDB

View File

@@ -0,0 +1,104 @@
"""
Note that this file is copied from Chroma repository. We will remove this file once the fix in
ChromaDB's repository.
"""
from typing import Optional
from chromadb.api.types import Documents, Embeddings
class OpenAIEmbeddingFunction:
def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "text-embedding-ada-002",
organization_id: Optional[str] = None,
api_base: Optional[str] = None,
api_type: Optional[str] = None,
api_version: Optional[str] = None,
deployment_id: Optional[str] = None,
):
"""
Initialize the OpenAIEmbeddingFunction.
Args:
api_key (str, optional): Your API key for the OpenAI API. If not
provided, it will raise an error to provide an OpenAI API key.
organization_id(str, optional): The OpenAI organization ID if applicable
model_name (str, optional): The name of the model to use for text
embeddings. Defaults to "text-embedding-ada-002".
api_base (str, optional): The base path for the API. If not provided,
it will use the base path for the OpenAI API. This can be used to
point to a different deployment, such as an Azure deployment.
api_type (str, optional): The type of the API deployment. This can be
used to specify a different deployment, such as 'azure'. If not
provided, it will use the default OpenAI deployment.
api_version (str, optional): The api version for the API. If not provided,
it will use the api version for the OpenAI API. This can be used to
point to a different deployment, such as an Azure deployment.
deployment_id (str, optional): Deployment ID for Azure OpenAI.
"""
try:
import openai
except ImportError:
raise ValueError("The openai python package is not installed. Please install it with `pip install openai`")
if api_key is not None:
openai.api_key = api_key
# If the api key is still not set, raise an error
elif openai.api_key is None:
raise ValueError(
"Please provide an OpenAI API key. You can get one at https://platform.openai.com/account/api-keys"
)
if api_base is not None:
openai.api_base = api_base
if api_version is not None:
openai.api_version = api_version
self._api_type = api_type
if api_type is not None:
openai.api_type = api_type
if organization_id is not None:
openai.organization = organization_id
self._v1 = openai.__version__.startswith("1.")
if self._v1:
if api_type == "azure":
self._client = openai.AzureOpenAI(
api_key=api_key, api_version=api_version, azure_endpoint=api_base
).embeddings
else:
self._client = openai.OpenAI(api_key=api_key, base_url=api_base).embeddings
else:
self._client = openai.Embedding
self._model_name = model_name
self._deployment_id = deployment_id
def __call__(self, input: Documents) -> Embeddings:
# replace newlines, which can negatively affect performance.
input = [t.replace("\n", " ") for t in input]
# Call the OpenAI Embedding API
if self._v1:
embeddings = self._client.create(input=input, model=self._deployment_id or self._model_name).data
# Sort resulting embeddings by index
sorted_embeddings = sorted(embeddings, key=lambda e: e.index) # type: ignore
# Return just the embeddings
return [result.embedding for result in sorted_embeddings]
else:
if self._api_type == "azure":
embeddings = self._client.create(input=input, engine=self._deployment_id or self._model_name)["data"]
else:
embeddings = self._client.create(input=input, model=self._model_name)["data"]
# Sort resulting embeddings by index
sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) # type: ignore
# Return just the embeddings
return [result["embedding"] for result in sorted_embeddings]

View File

@@ -7,13 +7,7 @@ from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.models import VectorDimensions
-try:
-    from chromadb.utils import embedding_functions
-except RuntimeError:
-    from embedchain.utils import use_pysqlite3
-    use_pysqlite3()
-    from chromadb.utils import embedding_functions
+from .chroma_embeddings import OpenAIEmbeddingFunction
class OpenAIEmbedder(BaseEmbedder):
@@ -30,11 +24,10 @@ class OpenAIEmbedder(BaseEmbedder):
            raise ValueError(
                "OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided"
            ) # noqa:E501
-        embedding_fn = embedding_functions.OpenAIEmbeddingFunction(
+        embedding_fn = OpenAIEmbeddingFunction(
            api_key=os.getenv("OPENAI_API_KEY"),
            organization_id=os.getenv("OPENAI_ORGANIZATION"),
            model_name=self.config.model,
        )
        self.set_embedding_fn(embedding_fn=embedding_fn)
        self.set_vector_dimension(vector_dimension=VectorDimensions.OPENAI.value)
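
A hedged sketch of constructing the updated embedder; the environment variable comes from the check above, and the config type is the one imported at the top of the file:

import os

from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.openai import OpenAIEmbedder  # module path assumed

os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder
embedder = OpenAIEmbedder(config=BaseEmbedderConfig(model="text-embedding-ada-002"))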

View File

@@ -13,7 +13,7 @@ class GPT4ALLLlm(BaseLlm):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        super().__init__(config=config)
        if self.config.model is None:
-            self.config.model = "orca-mini-3b.ggmlv3.q4_0.bin"
+            self.config.model = "orca-mini-3b-gguf2-q4_0.gguf"
        self.instance = GPT4ALLLlm._get_instance(self.config.model)
        self.instance.streaming = self.config.stream
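
The default local model moves from the legacy GGML binary to the GGUF format used by newer GPT4All releases. A small sketch of pinning a model explicitly, assuming the usual module path:

from embedchain.config import BaseLlmConfig
from embedchain.llm.gpt4all import GPT4ALLLlm  # module path assumed

llm = GPT4ALLLlm(config=BaseLlmConfig(model="orca-mini-3b-gguf2-q4_0.gguf"))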

View File

@@ -9,7 +9,7 @@ import requests
import yaml
from embedchain import Client
-from embedchain.config import PipelineConfig, ChunkerConfig
+from embedchain.config import ChunkerConfig, PipelineConfig
from embedchain.embedchain import CONFIG_DIR, EmbedChain
from embedchain.embedder.base import BaseEmbedder
from embedchain.embedder.openai import OpenAIEmbedder
@@ -42,7 +42,7 @@ class Pipeline(EmbedChain):
        embedding_model: BaseEmbedder = None,
        llm: BaseLlm = None,
        yaml_path: str = None,
-        log_level=logging.INFO,
+        log_level=logging.WARN,
        auto_deploy: bool = False,
        chunker: ChunkerConfig = None,
    ):
@@ -59,7 +59,7 @@ class Pipeline(EmbedChain):
        :type llm: BaseLlm, optional
        :param yaml_path: Path to the YAML configuration file, defaults to None
        :type yaml_path: str, optional
-        :param log_level: Log level to use, defaults to logging.INFO
+        :param log_level: Log level to use, defaults to logging.WARN
        :type log_level: int, optional
        :param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
        :type auto_deploy: bool, optional
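
Since the default log level drops from INFO to WARN, callers who want the old verbosity can pass it explicitly; a sketch, assuming Pipeline is importable from the package root:

import logging

from embedchain import Pipeline  # import path assumed

app = Pipeline(log_level=logging.INFO)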

View File

@@ -0,0 +1,125 @@
import logging
import os
import tempfile
import time
from pathlib import Path
from typing import cast

from openai import OpenAI
from openai.types.beta.threads import MessageContentText, ThreadMessage

from embedchain.config import AddConfig
from embedchain.data_formatter import DataFormatter
from embedchain.models.data_type import DataType
from embedchain.utils import detect_datatype

logging.basicConfig(level=logging.WARN)


class OpenAIAssistant:
    def __init__(
        self,
        name=None,
        instructions=None,
        tools=None,
        thread_id=None,
        model="gpt-4-1106-preview",
        data_sources=None,
        assistant_id=None,
        log_level=logging.WARN,
    ):
        self.name = name or "OpenAI Assistant"
        self.instructions = instructions
        self.tools = tools or [{"type": "retrieval"}]
        self.model = model
        self.data_sources = data_sources or []
        self.log_level = log_level
        self._client = OpenAI()
        self._initialize_assistant(assistant_id)
        self.thread_id = thread_id or self._create_thread()

    def add(self, source, data_type=None):
        file_path = self._prepare_source_path(source, data_type)
        self._add_file_to_assistant(file_path)
        logging.info("Data successfully added to the assistant.")

    def chat(self, message):
        self._send_message(message)
        return self._get_latest_response()

    def delete_thread(self):
        self._client.beta.threads.delete(self.thread_id)
        self.thread_id = self._create_thread()

    # Internal methods
    def _initialize_assistant(self, assistant_id):
        file_ids = self._generate_file_ids(self.data_sources)
        self.assistant = (
            self._client.beta.assistants.retrieve(assistant_id)
            if assistant_id
            else self._client.beta.assistants.create(
                name=self.name, model=self.model, file_ids=file_ids, instructions=self.instructions, tools=self.tools
            )
        )

    def _create_thread(self):
        thread = self._client.beta.threads.create()
        return thread.id

    def _prepare_source_path(self, source, data_type=None):
        if Path(source).is_file():
            return source
        data_type = data_type or detect_datatype(source)
        formatter = DataFormatter(data_type=DataType(data_type), config=AddConfig())
        data = formatter.loader.load_data(source)["data"]
        return self._save_temp_data(data[0]["content"].encode())

    def _upload_file(self, file_path):
        # Upload the file with the "assistants" purpose and return its id.
        with open(file_path, "rb") as f:
            file_obj = self._client.files.create(file=f, purpose="assistants")
        return file_obj.id

    def _add_file_to_assistant(self, file_path):
        # Upload the file, then attach it to the already-created assistant.
        file_id = self._upload_file(file_path)
        self._client.beta.assistants.files.create(assistant_id=self.assistant.id, file_id=file_id)

    def _generate_file_ids(self, data_sources):
        # Only upload here; the ids are attached via `file_ids` when the
        # assistant is created (the assistant does not exist yet at this point).
        return [
            self._upload_file(self._prepare_source_path(ds["source"], ds.get("data_type")))
            for ds in data_sources
        ]

    def _send_message(self, message):
        self._client.beta.threads.messages.create(thread_id=self.thread_id, role="user", content=message)
        self._wait_for_completion()

    def _wait_for_completion(self):
        run = self._client.beta.threads.runs.create(
            thread_id=self.thread_id,
            assistant_id=self.assistant.id,
            instructions=self.instructions,
        )
        run_id = run.id
        run_status = run.status
        while run_status in ["queued", "in_progress", "requires_action"]:
            time.sleep(0.1)  # Sleep before making the next API call to avoid hitting rate limits
            run = self._client.beta.threads.runs.retrieve(thread_id=self.thread_id, run_id=run_id)
            run_status = run.status
        if run_status == "failed":
            raise ValueError(f"Thread run failed with the following error: {run.last_error}")

    def _get_latest_response(self):
        history = self._get_history()
        return self._format_message(history[0]) if history else None

    def _get_history(self):
        messages = self._client.beta.threads.messages.list(thread_id=self.thread_id, order="desc")
        return list(messages)

    def _format_message(self, thread_message):
        thread_message = cast(ThreadMessage, thread_message)
        content = [c.text.value for c in thread_message.content if isinstance(c, MessageContentText)]
        return " ".join(content)

    def _save_temp_data(self, data):
        temp_dir = tempfile.mkdtemp()
        file_path = os.path.join(temp_dir, "temp_data")
        with open(file_path, "wb") as file:
            file.write(data)
        return file_path
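
A minimal, hedged usage sketch for the new wrapper. OPENAI_API_KEY is read from the environment by OpenAI(), the source URL is illustrative, and the import path is an assumption:

from embedchain.store.assistants import OpenAIAssistant  # import path assumed

assistant = OpenAIAssistant(name="docs-helper", instructions="Answer only from the attached files.")
assistant.add("https://www.forbes.com/profile/elon-musk")  # any source detect_datatype understands
print(assistant.chat("Summarize the page that was added."))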

View File

@@ -138,7 +138,8 @@ def detect_datatype(source: Any) -> DataType:
    formatted_source = format_source(str(source), 30)
    if url:
-        from langchain.document_loaders.youtube import ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
+        from langchain.document_loaders.youtube import \
+            ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
        if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
            logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")