[Feature] Improve github and youtube channel loader (#966)
Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
@@ -1,104 +0,0 @@
|
||||
"""
|
||||
Note that this file is copied from the Chroma repository. We will remove this file once the fix lands in
|
||||
ChromaDB's repository.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from chromadb.api.types import Documents, Embeddings
|
||||
|
||||
|
||||
class OpenAIEmbeddingFunction:
    """Callable that embeds documents with the OpenAI embeddings API.

    Supports the legacy module-level API of ``openai`` 0.x as well as the
    client-based API introduced in ``openai`` 1.x, including Azure deployments.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: str = "text-embedding-ada-002",
        organization_id: Optional[str] = None,
        api_base: Optional[str] = None,
        api_type: Optional[str] = None,
        api_version: Optional[str] = None,
        deployment_id: Optional[str] = None,
    ):
        """
        Initialize the OpenAIEmbeddingFunction.

        Args:
            api_key (str, optional): Your API key for the OpenAI API. If not
                provided, the key already set on ``openai.api_key`` is used.
            model_name (str, optional): The name of the model to use for text
                embeddings. Defaults to "text-embedding-ada-002".
            organization_id (str, optional): The OpenAI organization ID if applicable.
            api_base (str, optional): The base path for the API. If not provided,
                it will use the base path for the OpenAI API. This can be used to
                point to a different deployment, such as an Azure deployment.
            api_type (str, optional): The type of the API deployment. This can be
                used to specify a different deployment, such as 'azure'. If not
                provided, it will use the default OpenAI deployment.
            api_version (str, optional): The api version for the API. If not provided,
                it will use the api version for the OpenAI API. This can be used to
                point to a different deployment, such as an Azure deployment.
            deployment_id (str, optional): Deployment ID for Azure OpenAI.

        Raises:
            ValueError: If the ``openai`` package is not installed, or if no API
                key was passed and none is set on ``openai.api_key``.
        """
        try:
            import openai
        except ImportError as err:
            # ValueError is kept for backward compatibility; chain the cause for debugging.
            raise ValueError(
                "The openai python package is not installed. Please install it with `pip install openai`"
            ) from err

        if api_key is not None:
            openai.api_key = api_key
        # If the api key is still not set, raise an error
        elif openai.api_key is None:
            raise ValueError(
                "Please provide an OpenAI API key. You can get one at https://platform.openai.com/account/api-keys"
            )

        # Module-level settings: these drive the legacy 0.x code path.
        if api_base is not None:
            openai.api_base = api_base

        if api_version is not None:
            openai.api_version = api_version

        self._api_type = api_type
        if api_type is not None:
            openai.api_type = api_type

        if organization_id is not None:
            openai.organization = organization_id

        # Only 0.x exposes `openai.Embedding`; every later major release is client-based.
        self._v1 = not openai.__version__.startswith("0.")
        if self._v1:
            # Explicitly constructed clients do not read the module-level settings,
            # so hand them the resolved key and the organization directly.
            resolved_api_key = openai.api_key
            if api_type == "azure":
                self._client = openai.AzureOpenAI(
                    api_key=resolved_api_key,
                    api_version=api_version,
                    azure_endpoint=api_base,
                    organization=organization_id,
                ).embeddings
            else:
                self._client = openai.OpenAI(
                    api_key=resolved_api_key,
                    base_url=api_base,
                    organization=organization_id,
                ).embeddings
        else:
            self._client = openai.Embedding
        self._model_name = model_name
        self._deployment_id = deployment_id

    def __call__(self, input: Documents) -> Embeddings:
        """Embed the given documents.

        Args:
            input: The documents (strings) to embed.

        Returns:
            One embedding per document, in the same order as ``input``.
        """
        # Nothing to embed: skip the API round trip.
        if not input:
            return []

        # replace newlines, which can negatively affect performance.
        input = [t.replace("\n", " ") for t in input]

        # Call the OpenAI Embedding API
        if self._v1:
            embeddings = self._client.create(input=input, model=self._deployment_id or self._model_name).data

            # Sort resulting embeddings by index
            sorted_embeddings = sorted(embeddings, key=lambda e: e.index)  # type: ignore

            # Return just the embeddings
            return [result.embedding for result in sorted_embeddings]

        if self._api_type == "azure":
            # The legacy Azure API addresses a deployment through `engine`.
            embeddings = self._client.create(input=input, engine=self._deployment_id or self._model_name)["data"]
        else:
            embeddings = self._client.create(input=input, model=self._model_name)["data"]

        # Sort resulting embeddings by index
        sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])  # type: ignore

        # Return just the embeddings
        return [result["embedding"] for result in sorted_embeddings]
|
||||
@@ -1,18 +1,18 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
|
||||
from embedchain.config import BaseEmbedderConfig
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.models import VectorDimensions
|
||||
|
||||
from .chroma_embeddings import OpenAIEmbeddingFunction
|
||||
|
||||
|
||||
class OpenAIEmbedder(BaseEmbedder):
|
||||
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
|
||||
super().__init__(config=config)
|
||||
|
||||
if self.config.model is None:
|
||||
self.config.model = "text-embedding-ada-002"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user