[Feature] Improve GitHub and YouTube channel loader (#966)
Co-authored-by: Deven Patel <deven298@yahoo.com>
@@ -1,8 +1,5 @@
 import hashlib
 
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-
-from embedchain.config.add_config import ChunkerConfig
 from embedchain.helper.json_serializable import JSONSerializable
 from embedchain.models.data_type import DataType
 
@@ -10,15 +7,7 @@ from embedchain.models.data_type import DataType
 class BaseChunker(JSONSerializable):
     def __init__(self, text_splitter):
         """Initialize the chunker."""
-        if text_splitter is None:
-            config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
-            self.text_splitter = RecursiveCharacterTextSplitter(
-                chunk_size=config.chunk_size,
-                chunk_overlap=config.chunk_overlap,
-                length_function=config.length_function,
-            )
-        else:
-            self.text_splitter = text_splitter
+        self.text_splitter = text_splitter
         self.data_type = None
 
     def create_chunks(self, loader, src, app_id=None):

embedchain/chunkers/common_chunker.py (new file, 22 lines)
@@ -0,0 +1,22 @@
+from typing import Optional
+
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.add_config import ChunkerConfig
+from embedchain.helper.json_serializable import register_deserializable
+
+
+@register_deserializable
+class CommonChunker(BaseChunker):
+    """Common chunker for all loaders."""
+
+    def __init__(self, config: Optional[ChunkerConfig] = None):
+        if config is None:
+            config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=config.chunk_size,
+            chunk_overlap=config.chunk_overlap,
+            length_function=config.length_function,
+        )
+        super().__init__(text_splitter)
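
With the splitter default moved out of BaseChunker, CommonChunker is now the one place that builds the fallback RecursiveCharacterTextSplitter. A minimal usage sketch, assuming the package layout shown in this diff:

    from embedchain.chunkers.common_chunker import CommonChunker
    from embedchain.config.add_config import ChunkerConfig

    # No config: falls back to chunk_size=1000, chunk_overlap=0, length_function=len,
    # mirroring the defaults removed from BaseChunker above.
    chunker = CommonChunker()

    # An explicit config flows straight into RecursiveCharacterTextSplitter.
    chunker = CommonChunker(config=ChunkerConfig(chunk_size=500, chunk_overlap=50, length_function=len))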

@@ -116,8 +116,8 @@ class DataFormatter(JSONSerializable):
             DataType.SLACK: "embedchain.chunkers.slack.SlackChunker",
             DataType.DISCOURSE: "embedchain.chunkers.discourse.DiscourseChunker",
             DataType.SUBSTACK: "embedchain.chunkers.substack.SubstackChunker",
-            DataType.GITHUB: "embedchain.chunkers.base_chunker.BaseChunker",
-            DataType.YOUTUBE_CHANNEL: "embedchain.chunkers.base_chunker.BaseChunker",
+            DataType.GITHUB: "embedchain.chunkers.common_chunker.CommonChunker",
+            DataType.YOUTUBE_CHANNEL: "embedchain.chunkers.common_chunker.CommonChunker",
         }
 
         if data_type in chunker_classes:
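
The chunker map stores dotted paths rather than classes, so the import happens lazily at lookup time. A sketch of how such a path can be resolved (hypothetical helper, not the actual DataFormatter code):

    import importlib

    def resolve(class_path: str):
        # "pkg.module.Class" -> the Class object
        module_name, class_name = class_path.rsplit(".", 1)
        return getattr(importlib.import_module(module_name), class_name)

    chunker_cls = resolve("embedchain.chunkers.common_chunker.CommonChunker")
    chunker = chunker_cls()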

@@ -1,104 +0,0 @@
-"""
-Note that this file is copied from the Chroma repository. We will remove this file once the fix lands in
-ChromaDB's repository.
-"""
-
-from typing import Optional
-
-from chromadb.api.types import Documents, Embeddings
-
-
-class OpenAIEmbeddingFunction:
-    def __init__(
-        self,
-        api_key: Optional[str] = None,
-        model_name: str = "text-embedding-ada-002",
-        organization_id: Optional[str] = None,
-        api_base: Optional[str] = None,
-        api_type: Optional[str] = None,
-        api_version: Optional[str] = None,
-        deployment_id: Optional[str] = None,
-    ):
-        """
-        Initialize the OpenAIEmbeddingFunction.
-        Args:
-            api_key (str, optional): Your API key for the OpenAI API. If not
-                provided, it will raise an error to provide an OpenAI API key.
-            organization_id(str, optional): The OpenAI organization ID if applicable
-            model_name (str, optional): The name of the model to use for text
-                embeddings. Defaults to "text-embedding-ada-002".
-            api_base (str, optional): The base path for the API. If not provided,
-                it will use the base path for the OpenAI API. This can be used to
-                point to a different deployment, such as an Azure deployment.
-            api_type (str, optional): The type of the API deployment. This can be
-                used to specify a different deployment, such as 'azure'. If not
-                provided, it will use the default OpenAI deployment.
-            api_version (str, optional): The api version for the API. If not provided,
-                it will use the api version for the OpenAI API. This can be used to
-                point to a different deployment, such as an Azure deployment.
-            deployment_id (str, optional): Deployment ID for Azure OpenAI.
-
-        """
-        try:
-            import openai
-        except ImportError:
-            raise ValueError("The openai python package is not installed. Please install it with `pip install openai`")
-
-        if api_key is not None:
-            openai.api_key = api_key
-        # If the api key is still not set, raise an error
-        elif openai.api_key is None:
-            raise ValueError(
-                "Please provide an OpenAI API key. You can get one at https://platform.openai.com/account/api-keys"
-            )
-
-        if api_base is not None:
-            openai.api_base = api_base
-
-        if api_version is not None:
-            openai.api_version = api_version
-
-        self._api_type = api_type
-        if api_type is not None:
-            openai.api_type = api_type
-
-        if organization_id is not None:
-            openai.organization = organization_id
-
-        self._v1 = openai.__version__.startswith("1.")
-        if self._v1:
-            if api_type == "azure":
-                self._client = openai.AzureOpenAI(
-                    api_key=api_key, api_version=api_version, azure_endpoint=api_base
-                ).embeddings
-            else:
-                self._client = openai.OpenAI(api_key=api_key, base_url=api_base).embeddings
-        else:
-            self._client = openai.Embedding
-        self._model_name = model_name
-        self._deployment_id = deployment_id
-
-    def __call__(self, input: Documents) -> Embeddings:
-        # replace newlines, which can negatively affect performance.
-        input = [t.replace("\n", " ") for t in input]
-
-        # Call the OpenAI Embedding API
-        if self._v1:
-            embeddings = self._client.create(input=input, model=self._deployment_id or self._model_name).data
-
-            # Sort resulting embeddings by index
-            sorted_embeddings = sorted(embeddings, key=lambda e: e.index)  # type: ignore
-
-            # Return just the embeddings
-            return [result.embedding for result in sorted_embeddings]
-        else:
-            if self._api_type == "azure":
-                embeddings = self._client.create(input=input, engine=self._deployment_id or self._model_name)["data"]
-            else:
-                embeddings = self._client.create(input=input, model=self._model_name)["data"]
-
-            # Sort resulting embeddings by index
-            sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])  # type: ignore
-
-            # Return just the embeddings
-            return [result["embedding"] for result in sorted_embeddings]
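
The vendored copy above is dropped in favor of chromadb's own implementation, imported in the embedder below. A minimal sketch of the upstream function, assuming a chromadb version that exposes it and a valid OpenAI key:

    from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction

    # Returns one embedding vector per input text.
    openai_ef = OpenAIEmbeddingFunction(api_key="sk-...", model_name="text-embedding-ada-002")
    vectors = openai_ef(["hello world"])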

@@ -1,18 +1,18 @@
 import os
 from typing import Optional
 
+from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction
 from langchain.embeddings import OpenAIEmbeddings
 
 from embedchain.config import BaseEmbedderConfig
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.models import VectorDimensions
 
-from .chroma_embeddings import OpenAIEmbeddingFunction
 
 
 class OpenAIEmbedder(BaseEmbedder):
     def __init__(self, config: Optional[BaseEmbedderConfig] = None):
         super().__init__(config=config)
 
         if self.config.model is None:
             self.config.model = "text-embedding-ada-002"

@@ -8,10 +8,35 @@ from tqdm import tqdm
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.loaders.json import JSONLoader
 from embedchain.loaders.mdx import MdxLoader
-from embedchain.loaders.unstructured_file import UnstructuredLoader
 from embedchain.utils import detect_datatype
 
 
+def _load_file_data(path):
+    data = []
+    data_content = []
+    try:
+        with open(path, "rb") as f:
+            content = f.read().decode("utf-8")
+    except Exception as e:
+        print(f"Error reading file {path}: {e}")
+        raise ValueError(f"Failed to read file {path}")
+
+    meta_data = {}
+    meta_data["url"] = path
+    data.append(
+        {
+            "content": content,
+            "meta_data": meta_data,
+        }
+    )
+    data_content.append(content)
+    doc_id = hashlib.sha256((" ".join(data_content) + path).encode()).hexdigest()
+    return {
+        "doc_id": doc_id,
+        "data": data,
+    }
+
+
 class GithubLoader(BaseLoader):
     def load_data(self, repo_url):
         """Load data from a git repo."""
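
_load_file_data derives doc_id from the file content plus its path, so re-adding an unchanged file produces the same id. A quick sanity check, assuming the helper is importable and the (hypothetical) path exists as utf-8 text:

    import hashlib

    result = _load_file_data("README.md")
    content = result["data"][0]["content"]
    assert result["doc_id"] == hashlib.sha256((content + "README.md").encode()).hexdigest()
    assert result["data"][0]["meta_data"]["url"] == "README.md"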
@@ -24,7 +49,6 @@ class GithubLoader(BaseLoader):
 
         mdx_loader = MdxLoader()
         json_loader = JSONLoader()
-        unstructured_loader = UnstructuredLoader()
         data = []
         data_urls = []
 
@@ -51,7 +75,7 @@ class GithubLoader(BaseLoader):
             elif data_type == "json":
                 data = json_loader.load_data(file_path)
             else:
-                data = unstructured_loader.load_data(file_path)
+                data = _load_file_data(file_path)
 
             return data.get("data", [])
 
@@ -64,7 +88,7 @@ class GithubLoader(BaseLoader):
            return file_extension[1:] in whitelisted_extensions
 
        def _add_repo_files(repo_path: str):
-           with concurrent.futures.ThreadPoolExecutor() as executor:
+           with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
                future_to_file = {
                    executor.submit(_load_file, os.path.join(root, filename)): os.path.join(root, filename)
                    for root, _, files in os.walk(repo_path)
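
Capping the pool at four workers keeps a large repository from spawning one thread per file. A standalone sketch of the bounded-executor pattern, with _work standing in for the real file loader:

    import concurrent.futures

    def _work(path: str) -> int:
        return len(path)  # stand-in for reading and parsing a file

    paths = ["a.md", "b.py", "c.json"]
    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
        future_to_path = {executor.submit(_work, p): p for p in paths}
        for future in concurrent.futures.as_completed(future_to_path):
            print(future_to_path[future], future.result())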

@@ -53,7 +53,7 @@ class SitemapLoader(BaseLoader):
 
         with concurrent.futures.ThreadPoolExecutor() as executor:
             future_to_link = {executor.submit(load_link, link): link for link in links}
-            for future in tqdm(concurrent.futures.as_completed(future_to_link), total=len(links)):
+            for future in tqdm(concurrent.futures.as_completed(future_to_link), total=len(links), desc="Loading pages"):
                 link = future_to_link[future]
                 try:
                     data = future.result()
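
Because the number of links is known up front, tqdm can wrap as_completed directly and show a labeled progress bar; the same pattern appears in the YouTube channel loader below. A minimal sketch, with load_link standing in for the real fetch:

    import concurrent.futures
    from tqdm import tqdm

    def load_link(link: str) -> str:
        return link.upper()  # stand-in for an HTTP fetch

    links = ["https://example.com/a", "https://example.com/b"]
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future_to_link = {executor.submit(load_link, link): link for link in links}
        for future in tqdm(concurrent.futures.as_completed(future_to_link), total=len(links), desc="Loading pages"):
            print(future.result())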

@@ -2,6 +2,8 @@ import concurrent.futures
 import hashlib
 import logging
 
+from tqdm import tqdm
+
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.loaders.youtube_video import YoutubeVideoLoader
@@ -48,11 +50,16 @@ class YoutubeChannelLoader(BaseLoader):
 
        def _add_youtube_channel():
            video_links = _get_yt_video_links()
+           logging.info("Loading videos from youtube channel...")
            with concurrent.futures.ThreadPoolExecutor() as executor:
+               # Submitting all tasks and storing the future object with the video link
                future_to_video = {
                    executor.submit(_load_yt_video, video_link): video_link for video_link in video_links
-               }  # noqa: E501
-               for future in concurrent.futures.as_completed(future_to_video):
+               }
+
+               for future in tqdm(
+                   concurrent.futures.as_completed(future_to_video), total=len(video_links), desc="Processing videos"
+               ):
                    video = future_to_video[future]
                    try:
                        results = future.result()

@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 
 from chromadb import Collection, QueryResult
 from langchain.docstore.document import Document
+from tqdm import tqdm
 
 from embedchain.config import ChromaDbConfig
 from embedchain.helper.json_serializable import register_deserializable
@@ -157,12 +158,7 @@ class ChromaDB(BaseVectorDB):
                " Ids size: {}".format(len(documents), len(metadatas), len(ids))
            )
 
-       for i in range(0, len(documents), self.BATCH_SIZE):
-           print(
-               "Inserting batches from {} to {} in vector database.".format(
-                   i, min(len(documents), i + self.BATCH_SIZE)
-               )
-           )
+       for i in tqdm(range(0, len(documents), self.BATCH_SIZE), desc="Inserting batches in chromadb"):
            if skip_embedding:
                self.collection.add(
                    embeddings=embeddings[i : i + self.BATCH_SIZE],
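
The noisy per-batch print is replaced by a single tqdm bar over the batch indices. A standalone sketch of the slicing (BATCH_SIZE and documents are illustrative; print stands in for collection.add):

    from tqdm import tqdm

    BATCH_SIZE = 100
    documents = [f"doc-{n}" for n in range(250)]

    for i in tqdm(range(0, len(documents), BATCH_SIZE), desc="Inserting batches in chromadb"):
        batch = documents[i : i + BATCH_SIZE]  # the final batch may be shorter
        print(len(batch))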

@@ -1,6 +1,9 @@
 import logging
+import time
 from typing import Dict, List, Optional, Set, Tuple, Union
 
+from tqdm import tqdm
+
 try:
     from opensearchpy import OpenSearch
     from opensearchpy.helpers import bulk
@@ -23,6 +26,8 @@ class OpenSearchDB(BaseVectorDB):
     OpenSearch as vector database
     """
 
+    BATCH_SIZE = 100
+
     def __init__(self, config: OpenSearchDBConfig):
         """OpenSearch as vector database.
@@ -131,19 +136,28 @@ class OpenSearchDB(BaseVectorDB):
        :type skip_embedding: bool
        """
 
-       docs = []
-       if not skip_embedding:
-           embeddings = self.embedder.embedding_fn(documents)
-       for id, text, metadata, embeddings in zip(ids, documents, metadatas, embeddings):
-           docs.append(
-               {
-                   "_index": self._get_index(),
-                   "_id": id,
-                   "_source": {"text": text, "metadata": metadata, "embeddings": embeddings},
-               }
-           )
-       bulk(self.client, docs)
-       self.client.indices.refresh(index=self._get_index())
+       for i in tqdm(range(0, len(documents), self.BATCH_SIZE), desc="Inserting batches in opensearch"):
+           if not skip_embedding:
+               embeddings = self.embedder.embedding_fn(documents[i : i + self.BATCH_SIZE])
+
+           docs = []
+           for id, text, metadata, embeddings in zip(
+               ids[i : i + self.BATCH_SIZE],
+               documents[i : i + self.BATCH_SIZE],
+               metadatas[i : i + self.BATCH_SIZE],
+               embeddings[i : i + self.BATCH_SIZE],
+           ):
+               docs.append(
+                   {
+                       "_index": self._get_index(),
+                       "_id": id,
+                       "_source": {"text": text, "metadata": metadata, "embeddings": embeddings},
+                   }
+               )
+           bulk(self.client, docs)
+           self.client.indices.refresh(index=self._get_index())
+           # Sleep for 0.1 seconds to avoid rate limiting
+           time.sleep(0.1)
 
        def query(
            self,
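
Embedding and bulk-indexing per batch bounds both memory use and request size, with a short pause between batches. A standalone sketch of the loop shape (embed and the index name are stand-ins; each row's vector is named embedding here for clarity, since the loop in the diff reuses the batch-level name embeddings):

    import time

    BATCH_SIZE = 100
    ids = [str(n) for n in range(250)]
    documents = [f"text-{n}" for n in range(250)]

    def embed(batch):  # stand-in for self.embedder.embedding_fn
        return [[0.0] for _ in batch]

    for i in range(0, len(ids), BATCH_SIZE):
        batch_docs = documents[i : i + BATCH_SIZE]
        embeddings = embed(batch_docs)
        actions = [
            {"_index": "embedchain_store", "_id": doc_id, "_source": {"text": text, "embeddings": embedding}}
            for doc_id, text, embedding in zip(ids[i : i + BATCH_SIZE], batch_docs, embeddings)
        ]
        print(len(actions))  # stand-in for bulk(self.client, actions)
        time.sleep(0.1)  # brief pause between batches, as in the change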