Add GPT4Vision Image loader (#1089)
Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
@@ -249,7 +249,6 @@ class App(EmbedChain):
             query,
             n_results=num_documents,
             where=where,
-            skip_embedding=False,
             citations=True,
         )
         result = []
embedchain/chunkers/image.py (new file, 22 lines)
@@ -0,0 +1,22 @@
from typing import Optional

from langchain.text_splitter import RecursiveCharacterTextSplitter

from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable


@register_deserializable
class ImageChunker(BaseChunker):
    """Chunker for Images."""

    def __init__(self, config: Optional[ChunkerConfig] = None):
        if config is None:
            config = ChunkerConfig(chunk_size=2000, chunk_overlap=0, length_function=len)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
            length_function=config.length_function,
        )
        super().__init__(text_splitter)
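A minimal usage sketch of the new chunker (editorial example, not part of the commit): it splits the textual image description produced by the GPT4Vision loader further below, rather than embedding the image itself. The config values shown are the defaults from the file above, written out explicitly.

from embedchain.chunkers.image import ImageChunker
from embedchain.config.add_config import ChunkerConfig

# split a (possibly long) GPT-4V image description into 2000-character pieces
chunker = ImageChunker(ChunkerConfig(chunk_size=2000, chunk_overlap=0, length_function=len))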
embedchain/chunkers/images.py (deleted, 67 lines)
@@ -1,67 +0,0 @@
import hashlib
import logging
from typing import Optional

from langchain.text_splitter import RecursiveCharacterTextSplitter

from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig


class ImagesChunker(BaseChunker):
    """Chunker for an Image."""

    def __init__(self, config: Optional[ChunkerConfig] = None):
        if config is None:
            config = ChunkerConfig(chunk_size=300, chunk_overlap=0, length_function=len)
        image_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
            length_function=config.length_function,
        )
        super().__init__(image_splitter)

    def create_chunks(self, loader, src, app_id=None, config: Optional[ChunkerConfig] = None):
        """
        Loads the image(s) and creates their corresponding embeddings. This creates one chunk for each image.

        :param loader: The loader whose `load_data` method is used to create the raw data.
        :param src: The data to be handled by the loader. Can be a URL for
            remote sources or local content for local loaders.
        """
        documents = []
        embeddings = []
        ids = []
        min_chunk_size = config.min_chunk_size if config is not None else 0
        logging.info(f"[INFO] Skipping chunks smaller than {min_chunk_size} characters")
        data_result = loader.load_data(src)
        data_records = data_result["data"]
        doc_id = data_result["doc_id"]
        doc_id = f"{app_id}--{doc_id}" if app_id is not None else doc_id
        metadatas = []
        for data in data_records:
            meta_data = data["meta_data"]
            # add data type to meta data to allow query using data type
            meta_data["data_type"] = self.data_type.value
            chunk_id = hashlib.sha256(meta_data["url"].encode()).hexdigest()
            ids.append(chunk_id)
            documents.append(data["content"])
            embeddings.append(data["embedding"])
            meta_data["doc_id"] = doc_id
            metadatas.append(meta_data)

        return {
            "documents": documents,
            "embeddings": embeddings,
            "ids": ids,
            "metadatas": metadatas,
            "doc_id": doc_id,
        }

    def get_word_count(self, documents):
        """
        The number of chunks and the corresponding word count for an image is fixed to 1,
        as one embedding is created for each image.
        """
        return 1
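For contrast with the new text-based chunker: the removed chunker produced exactly one chunk per image, keyed by a hash of the image URL and carrying a precomputed CLIP embedding. A sketch of that ID scheme (values illustrative):

import hashlib

meta_data = {"url": "/photos/cat.jpg"}
chunk_id = hashlib.sha256(meta_data["url"].encode()).hexdigest()  # stable id per image location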
@@ -63,7 +63,7 @@ class DataFormatter(JSONSerializable):
     DataType.DOCS_SITE: "embedchain.loaders.docs_site_loader.DocsSiteLoader",
     DataType.CSV: "embedchain.loaders.csv.CsvLoader",
     DataType.MDX: "embedchain.loaders.mdx.MdxLoader",
-    DataType.IMAGES: "embedchain.loaders.images.ImagesLoader",
+    DataType.IMAGE: "embedchain.loaders.image.ImageLoader",
     DataType.UNSTRUCTURED: "embedchain.loaders.unstructured_file.UnstructuredLoader",
     DataType.JSON: "embedchain.loaders.json.JSONLoader",
     DataType.OPENAPI: "embedchain.loaders.openapi.OpenAPILoader",
@@ -108,7 +108,7 @@ class DataFormatter(JSONSerializable):
     DataType.DOCS_SITE: "embedchain.chunkers.docs_site.DocsSiteChunker",
     DataType.CSV: "embedchain.chunkers.table.TableChunker",
     DataType.MDX: "embedchain.chunkers.mdx.MdxChunker",
-    DataType.IMAGES: "embedchain.chunkers.images.ImagesChunker",
+    DataType.IMAGE: "embedchain.chunkers.image.ImageChunker",
     DataType.UNSTRUCTURED: "embedchain.chunkers.unstructured_file.UnstructuredFileChunker",
     DataType.JSON: "embedchain.chunkers.json.JSONChunker",
     DataType.OPENAPI: "embedchain.chunkers.openapi.OpenAPIChunker",
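Both mappings store dotted import paths rather than classes, which lets loader dependencies be imported lazily. A sketch of how such a path resolves to a class (the helper is hypothetical, not the actual DataFormatter internals):

import importlib

def load_class(dotted_path: str):
    # "embedchain.loaders.image.ImageLoader" -> the ImageLoader class
    module_path, class_name = dotted_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), class_name)

ImageLoader = load_class("embedchain.loaders.image.ImageLoader")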
@@ -438,7 +438,6 @@ class EmbedChain(JSONSerializable):
             documents=documents,
             metadatas=metadatas,
             ids=ids,
-            skip_embedding=(chunker.data_type == DataType.IMAGES),
             **kwargs,
         )
         count_new_chunks = self.db.count() - chunks_before_addition
@@ -490,21 +489,10 @@ class EmbedChain(JSONSerializable):
         if self.config.id is not None:
             where.update({"app_id": self.config.id})

-        # We cannot query the database with the input query in case of an image search. This is because we need
-        # to bring down both the image and text to the same dimension to be able to compare them.
-        db_query = input_query
-        if hasattr(config, "query_type") and config.query_type == "Images":
-            # We import the clip processor here to make sure the package is not dependent on clip dependency even if the
-            # image dataset is not being used
-            from embedchain.models.clip_processor import ClipProcessor
-
-            db_query = ClipProcessor.get_text_features(query=input_query)
-
         contexts = self.db.query(
-            input_query=db_query,
+            input_query=input_query,
             n_results=query_config.number_documents,
             where=where,
-            skip_embedding=(hasattr(config, "query_type") and config.query_type == "Images"),
             citations=citations,
             **kwargs,
         )
embedchain/loaders/image.py (new file, 49 lines)
@@ -0,0 +1,49 @@
import base64
import hashlib
import os
from pathlib import Path

from openai import OpenAI

from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader

DESCRIBE_IMAGE_PROMPT = "Describe the image:"


@register_deserializable
class ImageLoader(BaseLoader):
    def __init__(self, max_tokens: int = 500, api_key: str = None, prompt: str = None):
        super().__init__()
        self.custom_prompt = prompt or DESCRIBE_IMAGE_PROMPT
        self.max_tokens = max_tokens
        self.api_key = api_key or os.environ["OPENAI_API_KEY"]
        self.client = OpenAI(api_key=self.api_key)

    def _encode_image(self, image_path: str):
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")

    def _create_completion_request(self, content: str):
        return self.client.chat.completions.create(
            model="gpt-4-vision-preview", messages=[{"role": "user", "content": content}], max_tokens=self.max_tokens
        )

    def _process_url(self, url: str):
        if url.startswith("http"):
            return [{"type": "text", "text": self.custom_prompt}, {"type": "image_url", "image_url": {"url": url}}]
        elif Path(url).is_file():
            extension = Path(url).suffix.lstrip(".")
            encoded_image = self._encode_image(url)
            image_data = f"data:image/{extension};base64,{encoded_image}"
            return [{"type": "text", "text": self.custom_prompt}, {"type": "image", "image_url": {"url": image_data}}]
        else:
            raise ValueError(f"Invalid URL or file path: {url}")

    def load_data(self, url: str):
        content = self._process_url(url)
        response = self._create_completion_request(content)
        content = response.choices[0].message.content

        doc_id = hashlib.sha256((content + url).encode()).hexdigest()
        return {"doc_id": doc_id, "data": [{"content": content, "meta_data": {"url": url, "type": "image"}}]}
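An end-to-end sketch of the new loader (editorial example; the URL is illustrative and OPENAI_API_KEY must be set). The description returned by GPT-4V, not the raw image, is what gets chunked and embedded downstream:

from embedchain.loaders.image import ImageLoader

loader = ImageLoader(max_tokens=300, prompt="List the objects visible in the image.")
result = loader.load_data("https://example.com/street.jpg")
print(result["doc_id"])
print(result["data"][0]["content"])  # the textual description produced by GPT-4V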
embedchain/loaders/images.py (deleted, 41 lines)
@@ -1,41 +0,0 @@
import hashlib
import logging
import os

from embedchain.loaders.base_loader import BaseLoader


class ImagesLoader(BaseLoader):
    def load_data(self, image_url):
        """
        Loads images from the supplied directory/file and applies the CLIP model
        transformation to represent these images in vector form.

        :param image_url: The URL from which the images are to be loaded
        """
        # load model and image preprocessing
        from embedchain.models.clip_processor import ClipProcessor

        model = ClipProcessor.load_model()
        if os.path.isfile(image_url):
            data = [ClipProcessor.get_image_features(image_url, model)]
        else:
            data = []
            for filename in os.listdir(image_url):
                filepath = os.path.join(image_url, filename)
                try:
                    data.append(ClipProcessor.get_image_features(filepath, model))
                except Exception as e:
                    # Log the file that was not loaded
                    logging.exception("Failed to load the file {}. Exception {}".format(filepath, e))
        # Get the metadata like Size, Last Modified and Last Created timestamps
        image_path_metadata = [
            str(os.path.getsize(image_url)),
            str(os.path.getmtime(image_url)),
            str(os.path.getctime(image_url)),
        ]
        doc_id = hashlib.sha256((" ".join(image_path_metadata) + image_url).encode()).hexdigest()
        return {
            "doc_id": doc_id,
            "data": data,
        }
embedchain/models/clip_processor.py (deleted, 42 lines)
@@ -1,42 +0,0 @@
try:
    from PIL import Image, UnidentifiedImageError
    from sentence_transformers import SentenceTransformer
except ImportError:
    raise ImportError("Images requires extra dependencies. Install with `pip install 'embedchain[images]'`") from None

MODEL_NAME = "clip-ViT-B-32"


class ClipProcessor:
    @staticmethod
    def load_model():
        """Load the CLIP model used to embed a directory of images."""
        # load model and image preprocessing
        model = SentenceTransformer(MODEL_NAME)
        return model

    @staticmethod
    def get_image_features(image_url, model):
        """
        Applies the CLIP model to evaluate the vector representation of the supplied image
        """
        try:
            # load image
            image = Image.open(image_url)
        except FileNotFoundError:
            raise FileNotFoundError("The supplied file does not exist")
        except UnidentifiedImageError:
            raise UnidentifiedImageError("The supplied file is not an image")

        image_features = model.encode(image)
        meta_data = {"url": image_url}
        return {"content": image_url, "embedding": image_features.tolist(), "meta_data": meta_data}

    @staticmethod
    def get_text_features(query):
        """
        Applies the CLIP model to evaluate the vector representation of the supplied text
        """
        model = ClipProcessor.load_model()
        text_features = model.encode(query)
        return text_features.tolist()
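The removed CLIP path encoded images and queries into a shared vector space and compared them directly, which is why the vector stores needed the skip_embedding escape hatch. A sketch of that comparison (numpy assumed; not code from this repository):

import numpy as np

def cosine_similarity(a, b) -> float:
    a, b = np.asarray(a), np.asarray(b)
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

# score = cosine_similarity(ClipProcessor.get_image_features(path, model)["embedding"],
#                           ClipProcessor.get_text_features("a photo of a dog"))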
@@ -24,7 +24,7 @@ class IndirectDataType(Enum):
     NOTION = "notion"
     CSV = "csv"
     MDX = "mdx"
-    IMAGES = "images"
+    IMAGE = "image"
     UNSTRUCTURED = "unstructured"
     JSON = "json"
     OPENAPI = "openapi"
@@ -62,7 +62,7 @@ class DataType(Enum):
     CSV = IndirectDataType.CSV.value
     MDX = IndirectDataType.MDX.value
     QNA_PAIR = SpecialDataType.QNA_PAIR.value
-    IMAGES = IndirectDataType.IMAGES.value
+    IMAGE = IndirectDataType.IMAGE.value
     UNSTRUCTURED = IndirectDataType.UNSTRUCTURED.value
     JSON = IndirectDataType.JSON.value
     OPENAPI = IndirectDataType.OPENAPI.value
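Caller-side effect of the enum swap, sketched (paths and question are illustrative): images are now added one at a time under the "image" data type and answered from their GPT-4V descriptions, replacing the CLIP-backed "images" type:

from embedchain import App

app = App()
app.add("/path/to/photo.jpg", data_type="image")  # was data_type="images"
print(app.query("What is shown in the photo?"))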
@@ -132,7 +132,6 @@ class ChromaDB(BaseVectorDB):
         documents: List[str],
         metadatas: List[object],
         ids: List[str],
-        skip_embedding: bool,
         **kwargs: Optional[Dict[str, Any]],
     ) -> Any:
         """
@@ -146,13 +145,8 @@ class ChromaDB(BaseVectorDB):
         :type metadatas: List[object]
         :param ids: ids
         :type ids: List[str]
-        :param skip_embedding: Optional. If True, then the embeddings are assumed to be already generated.
-        :type skip_embedding: bool
         """
         size = len(documents)
-        if skip_embedding and (embeddings is None or len(embeddings) != len(documents)):
-            raise ValueError("Cannot add documents to chromadb with inconsistent embeddings")
-
         if len(documents) != size or len(metadatas) != size or len(ids) != size:
             raise ValueError(
                 "Cannot add documents to chromadb with inconsistent sizes. Documents size: {}, Metadata size: {},"
@@ -160,19 +154,11 @@ class ChromaDB(BaseVectorDB):
             )

         for i in tqdm(range(0, len(documents), self.BATCH_SIZE), desc="Inserting batches in chromadb"):
-            if skip_embedding:
-                self.collection.add(
-                    embeddings=embeddings[i : i + self.BATCH_SIZE],
-                    documents=documents[i : i + self.BATCH_SIZE],
-                    metadatas=metadatas[i : i + self.BATCH_SIZE],
-                    ids=ids[i : i + self.BATCH_SIZE],
-                )
-            else:
-                self.collection.add(
-                    documents=documents[i : i + self.BATCH_SIZE],
-                    metadatas=metadatas[i : i + self.BATCH_SIZE],
-                    ids=ids[i : i + self.BATCH_SIZE],
-                )
+            self.collection.add(
+                documents=documents[i : i + self.BATCH_SIZE],
+                metadatas=metadatas[i : i + self.BATCH_SIZE],
+                ids=ids[i : i + self.BATCH_SIZE],
+            )

     def _format_result(self, results: QueryResult) -> list[tuple[Document, float]]:
         """
@@ -197,7 +183,6 @@ class ChromaDB(BaseVectorDB):
         input_query: List[str],
         n_results: int,
         where: Dict[str, any],
-        skip_embedding: bool,
         citations: bool = False,
         **kwargs: Optional[Dict[str, Any]],
     ) -> Union[List[Tuple[str, Dict]], List[str]]:
@@ -210,8 +195,6 @@ class ChromaDB(BaseVectorDB):
         :type n_results: int
         :param where: to filter data
         :type where: Dict[str, Any]
-        :param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
-        :type skip_embedding: bool
         :param citations: we use citations boolean param to return context along with the answer.
         :type citations: bool, default is False.
         :raises InvalidDimensionException: Dimensions do not match.
@@ -220,24 +203,14 @@ class ChromaDB(BaseVectorDB):
         :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
         """
         try:
-            if skip_embedding:
-                result = self.collection.query(
-                    query_embeddings=[
-                        input_query,
-                    ],
-                    n_results=n_results,
-                    where=self._generate_where_clause(where),
-                    **kwargs,
-                )
-            else:
-                result = self.collection.query(
-                    query_texts=[
-                        input_query,
-                    ],
-                    n_results=n_results,
-                    where=self._generate_where_clause(where),
-                    **kwargs,
-                )
+            result = self.collection.query(
+                query_texts=[
+                    input_query,
+                ],
+                n_results=n_results,
+                where=self._generate_where_clause(where),
+                **kwargs,
+            )
         except InvalidDimensionException as e:
             raise InvalidDimensionException(
                 e.message()
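Every backend change below repeats this simplification: the skip_embedding branches disappear and the store always embeds raw text itself. A sketch of the embedder contract those call sites assume (types inferred from the calls, not a definition from this repository):

from typing import Callable, List

# embedding_fn maps a batch of strings to one vector per string, so
# self.embedder.embedding_fn([input_query])[0] yields a single query vector
EmbeddingFn = Callable[[List[str]], List[List[float]]]

def embed_one(embedding_fn: EmbeddingFn, text: str) -> List[float]:
    return embedding_fn([text])[0]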
@@ -114,7 +114,6 @@ class ElasticsearchDB(BaseVectorDB):
         documents: List[str],
         metadatas: List[object],
         ids: List[str],
-        skip_embedding: bool,
         **kwargs: Optional[Dict[str, any]],
     ) -> Any:
         """
@@ -127,12 +126,9 @@ class ElasticsearchDB(BaseVectorDB):
         :type metadatas: List[object]
         :param ids: ids of docs
         :type ids: List[str]
-        :param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
-        :type skip_embedding: bool
         """

-        if not skip_embedding:
-            embeddings = self.embedder.embedding_fn(documents)
+        embeddings = self.embedder.embedding_fn(documents)

         for chunk in chunks(
             list(zip(ids, documents, metadatas, embeddings)), self.BATCH_SIZE, desc="Inserting batches in elasticsearch"
@@ -161,7 +157,6 @@ class ElasticsearchDB(BaseVectorDB):
         input_query: List[str],
         n_results: int,
         where: Dict[str, any],
-        skip_embedding: bool,
         citations: bool = False,
         **kwargs: Optional[Dict[str, Any]],
     ) -> Union[List[Tuple[str, Dict]], List[str]]:
@@ -174,8 +169,6 @@ class ElasticsearchDB(BaseVectorDB):
         :type n_results: int
         :param where: Optional. to filter data
         :type where: Dict[str, any]
-        :param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
-        :type skip_embedding: bool
         :return: The context of the document that matched your query, url of the source, doc_id
         :param citations: we use citations boolean param to return context along with the answer.
         :type citations: bool, default is False.
@@ -183,11 +176,8 @@ class ElasticsearchDB(BaseVectorDB):
             along with url of the source and doc_id (if citations flag is true)
         :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
         """
-        if skip_embedding:
-            query_vector = input_query
-        else:
-            input_query_vector = self.embedder.embedding_fn(input_query)
-            query_vector = input_query_vector[0]
+        input_query_vector = self.embedder.embedding_fn(input_query)
+        query_vector = input_query_vector[0]

         # `https://www.elastic.co/guide/en/elasticsearch/reference/7.17/query-dsl-script-score-query.html`
         query = {
@@ -120,7 +120,6 @@ class OpenSearchDB(BaseVectorDB):
         documents: List[str],
         metadatas: List[object],
         ids: List[str],
-        skip_embedding: bool,
         **kwargs: Optional[Dict[str, any]],
     ):
         """Add data in vector database.
@@ -130,17 +129,11 @@ class OpenSearchDB(BaseVectorDB):
             documents (List[str]): List of texts to add.
             metadatas (List[object]): List of metadata associated with docs.
             ids (List[str]): IDs of docs.
-            skip_embedding (bool): If True, then embeddings are assumed to be already generated.
         """
         for batch_start in tqdm(range(0, len(documents), self.BATCH_SIZE), desc="Inserting batches in opensearch"):
             batch_end = batch_start + self.BATCH_SIZE
             batch_documents = documents[batch_start:batch_end]

-            # Generate embeddings for the batch if not skipping embedding
-            if not skip_embedding:
-                batch_embeddings = self.embedder.embedding_fn(batch_documents)
-            else:
-                batch_embeddings = embeddings[batch_start:batch_end]
+            batch_embeddings = embeddings[batch_start:batch_end]

             # Create document entries for bulk upload
             batch_entries = [
@@ -166,7 +159,6 @@ class OpenSearchDB(BaseVectorDB):
         input_query: List[str],
         n_results: int,
         where: Dict[str, any],
-        skip_embedding: bool,
         citations: bool = False,
         **kwargs: Optional[Dict[str, Any]],
     ) -> Union[List[Tuple[str, Dict]], List[str]]:
@@ -179,15 +171,12 @@ class OpenSearchDB(BaseVectorDB):
         :type n_results: int
         :param where: Optional. to filter data
         :type where: Dict[str, any]
-        :param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
-        :type skip_embedding: bool
         :param citations: we use citations boolean param to return context along with the answer.
         :type citations: bool, default is False.
         :return: The content of the document that matched your query,
             along with url of the source and doc_id (if citations flag is true)
         :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
         """
-        # TODO(rupeshbansal, deshraj): Add support for skip embeddings here if already exists
         embeddings = OpenAIEmbeddings()
         docsearch = OpenSearchVectorSearch(
             index_name=self._get_index(),
@@ -92,7 +92,6 @@ class PineconeDB(BaseVectorDB):
         documents: List[str],
         metadatas: List[object],
         ids: List[str],
-        skip_embedding: bool,
         **kwargs: Optional[Dict[str, any]],
     ):
         """add data in vector database
@@ -124,7 +123,6 @@ class PineconeDB(BaseVectorDB):
         input_query: List[str],
         n_results: int,
         where: Dict[str, any],
-        skip_embedding: bool,
         citations: bool = False,
         **kwargs: Optional[Dict[str, any]],
     ) -> Union[List[Tuple[str, Dict]], List[str]]:
@@ -136,18 +134,13 @@ class PineconeDB(BaseVectorDB):
         :type n_results: int
         :param where: Optional. to filter data
         :type where: Dict[str, any]
-        :param skip_embedding: Optional. if True, input_query is already embedded
-        :type skip_embedding: bool
         :param citations: we use citations boolean param to return context along with the answer.
         :type citations: bool, default is False.
         :return: The content of the document that matched your query,
             along with url of the source and doc_id (if citations flag is true)
         :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
         """
-        if not skip_embedding:
-            query_vector = self.embedder.embedding_fn([input_query])[0]
-        else:
-            query_vector = input_query
+        query_vector = self.embedder.embedding_fn([input_query])[0]
         data = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True, **kwargs)
         contexts = []
         for doc in data["matches"]:
@@ -126,7 +126,6 @@ class QdrantDB(BaseVectorDB):
         documents: List[str],
         metadatas: List[object],
         ids: List[str],
-        skip_embedding: bool,
         **kwargs: Optional[Dict[str, any]],
     ):
         """add data in vector database
@@ -138,12 +137,8 @@ class QdrantDB(BaseVectorDB):
         :type metadatas: List[object]
         :param ids: ids of docs
         :type ids: List[str]
-        :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
-            generated or not
-        :type skip_embedding: bool
         """
-        if not skip_embedding:
-            embeddings = self.embedder.embedding_fn(documents)
+        embeddings = self.embedder.embedding_fn(documents)

         payloads = []
         qdrant_ids = []
@@ -167,7 +162,6 @@ class QdrantDB(BaseVectorDB):
         input_query: List[str],
         n_results: int,
         where: Dict[str, any],
-        skip_embedding: bool,
         citations: bool = False,
         **kwargs: Optional[Dict[str, Any]],
     ) -> Union[List[Tuple[str, Dict]], List[str]]:
@@ -179,20 +173,13 @@ class QdrantDB(BaseVectorDB):
         :type n_results: int
         :param where: Optional. to filter data
         :type where: Dict[str, any]
-        :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
-            generated or not
-        :type skip_embedding: bool
         :param citations: we use citations boolean param to return context along with the answer.
         :type citations: bool, default is False.
         :return: The content of the document that matched your query,
             along with url of the source and doc_id (if citations flag is true)
         :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
         """
-        if not skip_embedding:
-            query_vector = self.embedder.embedding_fn([input_query])[0]
-        else:
-            query_vector = input_query
-
+        query_vector = self.embedder.embedding_fn([input_query])[0]
         keys = set(where.keys() if where is not None else set())

         qdrant_must_filters = []
@@ -157,7 +157,6 @@ class WeaviateDB(BaseVectorDB):
         documents: List[str],
         metadatas: List[object],
         ids: List[str],
-        skip_embedding: bool,
         **kwargs: Optional[Dict[str, any]],
     ):
         """add data in vector database
@@ -169,14 +168,8 @@ class WeaviateDB(BaseVectorDB):
         :type metadatas: List[object]
         :param ids: ids of docs
         :type ids: List[str]
-        :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
-            generated or not
-        :type skip_embedding: bool
         """
-
-        print("Adding documents to Weaviate...")
-        if not skip_embedding:
-            embeddings = self.embedder.embedding_fn(documents)
+        embeddings = self.embedder.embedding_fn(documents)
         self.client.batch.configure(batch_size=self.BATCH_SIZE, timeout_retries=3)  # Configure batch
         with self.client.batch as batch:  # Initialize a batch process
             for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
@@ -202,7 +195,6 @@ class WeaviateDB(BaseVectorDB):
         input_query: List[str],
         n_results: int,
         where: Dict[str, any],
-        skip_embedding: bool,
         citations: bool = False,
         **kwargs: Optional[Dict[str, Any]],
     ) -> Union[List[Tuple[str, Dict]], List[str]]:
@@ -214,20 +206,13 @@ class WeaviateDB(BaseVectorDB):
         :type n_results: int
         :param where: Optional. to filter data
         :type where: Dict[str, any]
-        :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
-            generated or not
-        :type skip_embedding: bool
         :param citations: we use citations boolean param to return context along with the answer.
         :type citations: bool, default is False.
         :return: The content of the document that matched your query,
             along with url of the source and doc_id (if citations flag is true)
         :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
         """
-        if not skip_embedding:
-            query_vector = self.embedder.embedding_fn([input_query])[0]
-        else:
-            query_vector = input_query
-
+        query_vector = self.embedder.embedding_fn([input_query])[0]
         keys = set(where.keys() if where is not None else set())
         data_fields = ["text"]
@@ -112,12 +112,10 @@ class ZillizVectorDB(BaseVectorDB):
         documents: List[str],
         metadatas: List[object],
         ids: List[str],
-        skip_embedding: bool,
         **kwargs: Optional[Dict[str, any]],
     ):
         """Add to database"""
-        if not skip_embedding:
-            embeddings = self.embedder.embedding_fn(documents)
+        embeddings = self.embedder.embedding_fn(documents)

         for id, doc, metadata, embedding in zip(ids, documents, metadatas, embeddings):
             data = {**metadata, "id": id, "text": doc, "embeddings": embedding}
@@ -132,7 +130,6 @@ class ZillizVectorDB(BaseVectorDB):
         input_query: List[str],
         n_results: int,
         where: Dict[str, any],
-        skip_embedding: bool,
         citations: bool = False,
         **kwargs: Optional[Dict[str, Any]],
     ) -> Union[List[Tuple[str, Dict]], List[str]]:
@@ -160,27 +157,16 @@ class ZillizVectorDB(BaseVectorDB):
             where = None

         output_fields = ["*"]
-        if skip_embedding:
-            query_vector = input_query
-            query_result = self.client.search(
-                collection_name=self.config.collection_name,
-                data=query_vector,
-                limit=n_results,
-                output_fields=output_fields,
-                **kwargs,
-            )
-        else:
-            input_query_vector = self.embedder.embedding_fn([input_query])
-            query_vector = input_query_vector[0]
-
-            query_result = self.client.search(
-                collection_name=self.config.collection_name,
-                data=[query_vector],
-                limit=n_results,
-                output_fields=output_fields,
-                **kwargs,
-            )
+        input_query_vector = self.embedder.embedding_fn([input_query])
+        query_vector = input_query_vector[0]
+
+        query_result = self.client.search(
+            collection_name=self.config.collection_name,
+            data=[query_vector],
+            limit=n_results,
+            output_fields=output_fields,
+            **kwargs,
+        )
         query_result = query_result[0]
         contexts = []
         for query in query_result:
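Net effect for callers, sketched (db stands for any configured vector store instance; argument values illustrative): query now always takes raw text, with no skip_embedding flag on any backend:

contexts = db.query(
    input_query="what is in the photo?",
    n_results=3,
    where={"app_id": "my-app"},
    citations=True,
)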