[Improvements] Package improvements (#993)

Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
Deven Patel
2023-12-05 23:42:45 -08:00
committed by GitHub
parent 1d4e00ccef
commit 51b4966801
13 changed files with 96 additions and 40 deletions

View File

@@ -41,7 +41,6 @@ class BaseChunker(JSONSerializable):
url = meta_data["url"]
chunks = self.get_chunks(content)
for chunk in chunks:
chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest()
chunk_id = f"{app_id}--{chunk_id}" if app_id is not None else chunk_id

View File

@@ -1,5 +1,5 @@
from importlib import import_module
from typing import Any, Dict
from typing import Optional
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config import AddConfig
@@ -16,7 +16,13 @@ class DataFormatter(JSONSerializable):
.add or .add_local method call
"""
def __init__(self, data_type: DataType, config: AddConfig, kwargs: Dict[str, Any]):
def __init__(
self,
data_type: DataType,
config: AddConfig,
loader: Optional[BaseLoader] = None,
chunker: Optional[BaseChunker] = None,
):
"""
Initialize a dataformatter, set data type and chunker based on datatype.
@@ -25,15 +31,15 @@ class DataFormatter(JSONSerializable):
:param config: AddConfig instance with nested loader and chunker config attributes.
:type config: AddConfig
"""
self.loader = self._get_loader(data_type=data_type, config=config.loader, kwargs=kwargs)
self.chunker = self._get_chunker(data_type=data_type, config=config.chunker, kwargs=kwargs)
self.loader = self._get_loader(data_type=data_type, config=config.loader, loader=loader)
self.chunker = self._get_chunker(data_type=data_type, config=config.chunker, chunker=chunker)
def _lazy_load(self, module_path: str):
module_path, class_name = module_path.rsplit(".", 1)
module = import_module(module_path)
return getattr(module, class_name)
def _get_loader(self, data_type: DataType, config: LoaderConfig, kwargs: Dict[str, Any]) -> BaseLoader:
def _get_loader(self, data_type: DataType, config: LoaderConfig, loader: Optional[BaseLoader]) -> BaseLoader:
"""
Returns the appropriate data loader for the given data type.
@@ -68,8 +74,8 @@ class DataFormatter(JSONSerializable):
DataType.DISCORD: "embedchain.loaders.discord.DiscordLoader",
}
if data_type == DataType.CUSTOM or ("loader" in kwargs):
loader_class: type = kwargs.get("loader", None)
if data_type == DataType.CUSTOM or loader is not None:
loader_class: type = loader
if loader_class:
return loader_class
elif data_type in loaders:
@@ -82,7 +88,7 @@ class DataFormatter(JSONSerializable):
check `https://docs.embedchain.ai/data-sources/overview`."
)
def _get_chunker(self, data_type: DataType, config: ChunkerConfig, kwargs: Dict[str, Any]) -> BaseChunker:
def _get_chunker(self, data_type: DataType, config: ChunkerConfig, chunker: Optional[BaseChunker]) -> BaseChunker:
"""Returns the appropriate chunker for the given data type (updated for lazy loading)."""
chunker_classes = {
DataType.YOUTUBE_VIDEO: "embedchain.chunkers.youtube_video.YoutubeVideoChunker",
@@ -108,12 +114,8 @@ class DataFormatter(JSONSerializable):
DataType.CUSTOM: "embedchain.chunkers.common_chunker.CommonChunker",
}
if "chunker" in kwargs:
chunker_class = kwargs.get("chunker", None)
if chunker_class:
chunker = chunker_class(config)
chunker.set_data_type(data_type)
return chunker
if chunker is not None:
return chunker
elif data_type in chunker_classes:
chunker_class = self._lazy_load(chunker_classes[data_type])
chunker = chunker_class(config)

View File

@@ -133,7 +133,9 @@ class EmbedChain(JSONSerializable):
metadata: Optional[Dict[str, Any]] = None,
config: Optional[AddConfig] = None,
dry_run=False,
**kwargs: Dict[str, Any],
loader: Optional[BaseLoader] = None,
chunker: Optional[BaseChunker] = None,
**kwargs: Optional[Dict[str, Any]],
):
"""
Adds the data from the given URL to the vector db.
@@ -192,9 +194,9 @@ class EmbedChain(JSONSerializable):
self.user_asks.append([source, data_type.value, metadata])
data_formatter = DataFormatter(data_type, config, kwargs)
data_formatter = DataFormatter(data_type, config, loader, chunker)
documents, metadatas, _ids, new_chunks = self._load_and_embed(
data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run
data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run, **kwargs
)
if data_type in {DataType.DOCS_SITE}:
self.is_docs_site_instance = True
@@ -238,7 +240,7 @@ class EmbedChain(JSONSerializable):
data_type: Optional[DataType] = None,
metadata: Optional[Dict[str, Any]] = None,
config: Optional[AddConfig] = None,
**kwargs: Dict[str, Any],
**kwargs: Optional[Dict[str, Any]],
):
"""
Adds the data from the given URL to the vector db.
@@ -269,7 +271,7 @@ class EmbedChain(JSONSerializable):
data_type=data_type,
metadata=metadata,
config=config,
kwargs=kwargs,
**kwargs,
)
def _get_existing_doc_id(self, chunker: BaseChunker, src: Any):
@@ -338,6 +340,7 @@ class EmbedChain(JSONSerializable):
metadata: Optional[Dict[str, Any]] = None,
source_hash: Optional[str] = None,
dry_run=False,
**kwargs: Optional[Dict[str, Any]],
):
"""
Loads the data from the given URL, chunks it, and adds it to database.
@@ -431,6 +434,7 @@ class EmbedChain(JSONSerializable):
metadatas=metadatas,
ids=ids,
skip_embedding=(chunker.data_type == DataType.IMAGES),
**kwargs,
)
count_new_chunks = self.db.count() - chunks_before_addition
@@ -448,7 +452,12 @@ class EmbedChain(JSONSerializable):
]
def _retrieve_from_database(
self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None, citations: bool = False
self,
input_query: str,
config: Optional[BaseLlmConfig] = None,
where=None,
citations: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, str, str]], List[str]]:
"""
Queries the vector database based on the given input query.
@@ -492,6 +501,7 @@ class EmbedChain(JSONSerializable):
where=where,
skip_embedding=(hasattr(config, "query_type") and config.query_type == "Images"),
citations=citations,
**kwargs,
)
return contexts
@@ -526,9 +536,13 @@ class EmbedChain(JSONSerializable):
or the dry run result
:rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
"""
citations = kwargs.get("citations", False)
if "citations" in kwargs:
citations = kwargs.pop("citations")
else:
citations = False
contexts = self._retrieve_from_database(
input_query=input_query, config=config, where=where, citations=citations
input_query=input_query, config=config, where=where, citations=citations, **kwargs
)
if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))
@@ -579,9 +593,13 @@ class EmbedChain(JSONSerializable):
or the dry run result
:rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
"""
citations = kwargs.get("citations", False)
if "citations" in kwargs:
citations = kwargs.pop("citations")
else:
citations = False
contexts = self._retrieve_from_database(
input_query=input_query, config=config, where=where, citations=citations
input_query=input_query, config=config, where=where, citations=citations, **kwargs
)
if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))

View File

@@ -196,7 +196,6 @@ class GithubLoader(BaseLoader):
logging.info(f"Total repos found: {repos_results.totalCount}")
for repo_result in tqdm(repos_results, total=repos_results.totalCount, desc="Loading discussions from github"):
teams = repo_result.get_teams()
# import pdb; pdb.set_trace()
for team in teams:
team_discussions = team.get_discussions()
for discussion in team_discussions:

View File

@@ -1,3 +1,4 @@
import itertools
import json
import logging
import os
@@ -6,6 +7,7 @@ import string
from typing import Any
from schema import Optional, Or, Schema
from tqdm import tqdm
from embedchain.models.data_type import DataType
@@ -422,3 +424,16 @@ def validate_config(config_data):
)
return schema.validate(config_data)
def chunks(iterable, batch_size=100, desc="Processing chunks"):
"""A helper function to break an iterable into chunks of size batch_size."""
it = iter(iterable)
total_size = len(iterable)
with tqdm(total=total_size, desc=desc, unit="batch") as pbar:
chunk = tuple(itertools.islice(it, batch_size))
while chunk:
yield chunk
pbar.update(len(chunk))
chunk = tuple(itertools.islice(it, batch_size))

View File

@@ -133,6 +133,7 @@ class ChromaDB(BaseVectorDB):
metadatas: List[object],
ids: List[str],
skip_embedding: bool,
**kwargs: Optional[Dict[str, Any]],
) -> Any:
"""
Add vectors to chroma database
@@ -198,6 +199,7 @@ class ChromaDB(BaseVectorDB):
where: Dict[str, any],
skip_embedding: bool,
citations: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, str, str]], List[str]]:
"""
Query contents from vector database based on vector similarity
@@ -225,6 +227,7 @@ class ChromaDB(BaseVectorDB):
],
n_results=n_results,
where=self._generate_where_clause(where),
**kwargs,
)
else:
result = self.collection.query(
@@ -233,6 +236,7 @@ class ChromaDB(BaseVectorDB):
],
n_results=n_results,
where=self._generate_where_clause(where),
**kwargs,
)
except InvalidDimensionException as e:
raise InvalidDimensionException(

View File

@@ -105,6 +105,7 @@ class ElasticsearchDB(BaseVectorDB):
metadatas: List[object],
ids: List[str],
skip_embedding: bool,
**kwargs: Optional[Dict[str, any]],
) -> Any:
"""
add data in vector database
@@ -142,6 +143,7 @@ class ElasticsearchDB(BaseVectorDB):
where: Dict[str, any],
skip_embedding: bool,
citations: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, str, str]], List[str]]:
"""
query contents from vector data base based on vector similarity

View File

@@ -1,6 +1,6 @@
import logging
import time
from typing import Dict, List, Optional, Set, Tuple, Union
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from tqdm import tqdm
@@ -121,6 +121,7 @@ class OpenSearchDB(BaseVectorDB):
metadatas: List[object],
ids: List[str],
skip_embedding: bool,
**kwargs: Optional[Dict[str, any]],
):
"""Add data in vector database.
@@ -154,7 +155,7 @@ class OpenSearchDB(BaseVectorDB):
]
# Perform bulk operation
bulk(self.client, batch_entries)
bulk(self.client, batch_entries, **kwargs)
self.client.indices.refresh(index=self._get_index())
# Sleep to avoid rate limiting
@@ -167,6 +168,7 @@ class OpenSearchDB(BaseVectorDB):
where: Dict[str, any],
skip_embedding: bool,
citations: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, str, str]], List[str]]:
"""
query contents from vector data base based on vector similarity
@@ -209,6 +211,7 @@ class OpenSearchDB(BaseVectorDB):
metadata_field="metadata",
pre_filter=pre_filter,
k=n_results,
**kwargs,
)
contexts = []

View File

@@ -10,6 +10,7 @@ except ImportError:
from embedchain.config.vectordb.pinecone import PineconeDBConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.utils import chunks
from embedchain.vectordb.base import BaseVectorDB
@@ -92,6 +93,7 @@ class PineconeDB(BaseVectorDB):
metadatas: List[object],
ids: List[str],
skip_embedding: bool,
**kwargs: Optional[Dict[str, any]],
):
"""add data in vector database
@@ -104,7 +106,6 @@ class PineconeDB(BaseVectorDB):
"""
docs = []
print("Adding documents to Pinecone...")
embeddings = self.embedder.embedding_fn(documents)
for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
docs.append(
@@ -115,8 +116,8 @@ class PineconeDB(BaseVectorDB):
}
)
for i in range(0, len(docs), self.BATCH_SIZE):
self.client.upsert(docs[i : i + self.BATCH_SIZE])
for chunk in chunks(docs, self.BATCH_SIZE, desc="Adding chunks in batches..."):
self.client.upsert(chunk, **kwargs)
def query(
self,
@@ -125,6 +126,7 @@ class PineconeDB(BaseVectorDB):
where: Dict[str, any],
skip_embedding: bool,
citations: bool = False,
**kwargs: Optional[Dict[str, any]],
) -> Union[List[Tuple[str, str, str]], List[str]]:
"""
query contents from vector database based on vector similarity
@@ -146,7 +148,7 @@ class PineconeDB(BaseVectorDB):
query_vector = self.embedder.embedding_fn([input_query])[0]
else:
query_vector = input_query
data = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True)
data = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True, **kwargs)
contexts = []
for doc in data["matches"]:
metadata = doc["metadata"]

View File

@@ -1,7 +1,7 @@
import copy
import os
import uuid
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
try:
from qdrant_client import QdrantClient
@@ -127,6 +127,7 @@ class QdrantDB(BaseVectorDB):
metadatas: List[object],
ids: List[str],
skip_embedding: bool,
**kwargs: Optional[Dict[str, any]],
):
"""add data in vector database
:param embeddings: list of embeddings for the corresponding documents to be added
@@ -158,6 +159,7 @@ class QdrantDB(BaseVectorDB):
payloads=payloads[i : i + self.BATCH_SIZE],
vectors=embeddings[i : i + self.BATCH_SIZE],
),
**kwargs,
)
def query(
@@ -167,6 +169,7 @@ class QdrantDB(BaseVectorDB):
where: Dict[str, any],
skip_embedding: bool,
citations: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, str, str]], List[str]]:
"""
query contents from vector database based on vector similarity
@@ -208,6 +211,7 @@ class QdrantDB(BaseVectorDB):
query_filter=models.Filter(must=qdrant_must_filters),
query_vector=query_vector,
limit=n_results,
**kwargs,
)
contexts = []

View File

@@ -1,6 +1,6 @@
import copy
import os
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
try:
import weaviate
@@ -158,6 +158,7 @@ class WeaviateDB(BaseVectorDB):
metadatas: List[object],
ids: List[str],
skip_embedding: bool,
**kwargs: Optional[Dict[str, any]],
):
"""add data in vector database
:param embeddings: list of embeddings for the corresponding documents to be added
@@ -192,7 +193,9 @@ class WeaviateDB(BaseVectorDB):
class_name=self.index_name + "_metadata",
vector=embedding,
)
batch.add_reference(obj_uuid, self.index_name, "metadata", metadata_uuid, self.index_name + "_metadata")
batch.add_reference(
obj_uuid, self.index_name, "metadata", metadata_uuid, self.index_name + "_metadata", **kwargs
)
def query(
self,
@@ -201,6 +204,7 @@ class WeaviateDB(BaseVectorDB):
where: Dict[str, any],
skip_embedding: bool,
citations: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, str, str]], List[str]]:
"""
query contents from vector database based on vector similarity

View File

@@ -1,5 +1,5 @@
import logging
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
from embedchain.config import ZillizDBConfig
from embedchain.helpers.json_serializable import register_deserializable
@@ -113,6 +113,7 @@ class ZillizVectorDB(BaseVectorDB):
metadatas: List[object],
ids: List[str],
skip_embedding: bool,
**kwargs: Optional[Dict[str, any]],
):
"""Add to database"""
if not skip_embedding:
@@ -120,7 +121,7 @@ class ZillizVectorDB(BaseVectorDB):
for id, doc, metadata, embedding in zip(ids, documents, metadatas, embeddings):
data = {**metadata, "id": id, "text": doc, "embeddings": embedding}
self.client.insert(collection_name=self.config.collection_name, data=data)
self.client.insert(collection_name=self.config.collection_name, data=data, **kwargs)
self.collection.load()
self.collection.flush()
@@ -133,6 +134,7 @@ class ZillizVectorDB(BaseVectorDB):
where: Dict[str, any],
skip_embedding: bool,
citations: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, str, str]], List[str]]:
"""
Query contents from vector data base based on vector similarity
@@ -165,6 +167,7 @@ class ZillizVectorDB(BaseVectorDB):
data=query_vector,
limit=n_results,
output_fields=output_fields,
**kwargs,
)
else:
@@ -176,6 +179,7 @@ class ZillizVectorDB(BaseVectorDB):
data=[query_vector],
limit=n_results,
output_fields=output_fields,
**kwargs,
)
contexts = []