diff --git a/docs/components/data-sources/image.mdx b/docs/components/data-sources/image.mdx
new file mode 100644
index 00000000..b7904366
--- /dev/null
+++ b/docs/components/data-sources/image.mdx
@@ -0,0 +1,45 @@
+---
+title: "🖼️ Image"
+---
+
+To use an image as a data source, set `data_type` to `image` and pass in the path of the image (local or hosted).
+
+We use [GPT-4 Vision](https://platform.openai.com/docs/guides/vision) to generate a text description of the image using a custom prompt, and then use the generated text as the data source.
+
+You will need an OpenAI API key with access to the `gpt-4-vision-preview` model to use this feature.
+
+### Without customization
+
+```python
+import os
+from embedchain import App
+
+os.environ["OPENAI_API_KEY"] = "sk-xxx"
+
+app = App()
+app.add("./Elon-Musk.webp", data_type="image")
+response = app.query("Describe the man in the image.")
+print(response)
+# Answer: The man in the image is dressed in formal attire, wearing a dark suit jacket and a white collared shirt. He has short hair and is standing. He appears to be gazing off to the side with a reflective expression. The background is dark with faint, warm-toned vertical lines, possibly from a lit environment behind the individual or reflections. The overall atmosphere is somewhat moody and introspective.
+```
+
+### Customization
+
+```python
+import os
+from embedchain import App
+from embedchain.loaders.image import ImageLoader
+
+image_loader = ImageLoader(
+ max_tokens=100,
+ api_key="sk-xxx",
+ prompt="Is the person looking wealthy? Structure your thoughts around what you see in the image.",
+)
+
+app = App()
+app.add("./Elon-Musk.webp", data_type="image", loader=image_loader)
+response = app.query("Describe the man in the image.")
+print(response)
+# Answer: The man in the image appears to be well-dressed in a suit and shirt, suggesting that he may be in a professional or formal setting. His composed demeanor and confident posture further indicate a sense of self-assurance. Based on these visual cues, one could infer that the man may have a certain level of economic or social status, possibly indicating wealth or professional success.
+```
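+
+### Remote images
+
+Hosted images work the same way; the loader passes `http(s)` URLs straight to GPT-4 Vision instead of base64-encoding a local file. A minimal sketch (the URL below is a placeholder):
+
+```python
+app.add("https://example.com/photo.jpg", data_type="image")
+```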
diff --git a/docs/components/data-sources/overview.mdx b/docs/components/data-sources/overview.mdx
index 878614f9..ed963aff 100644
--- a/docs/components/data-sources/overview.mdx
+++ b/docs/components/data-sources/overview.mdx
@@ -31,6 +31,7 @@ Embedchain comes with built-in support for various data sources. We handle the c
+
diff --git a/docs/mint.json b/docs/mint.json
index c345cabc..b4f8a641 100644
--- a/docs/mint.json
+++ b/docs/mint.json
@@ -119,7 +119,9 @@
"components/data-sources/discourse",
"components/data-sources/substack",
"components/data-sources/beehiiv",
+ "components/data-sources/directory",
"components/data-sources/dropbox",
+ "components/data-sources/image",
"components/data-sources/custom"
]
},
diff --git a/embedchain/app.py b/embedchain/app.py
index 0906a3c6..18042e2f 100644
--- a/embedchain/app.py
+++ b/embedchain/app.py
@@ -249,7 +249,6 @@ class App(EmbedChain):
query,
n_results=num_documents,
where=where,
- skip_embedding=False,
citations=True,
)
result = []
diff --git a/embedchain/chunkers/image.py b/embedchain/chunkers/image.py
new file mode 100644
index 00000000..d29a84f4
--- /dev/null
+++ b/embedchain/chunkers/image.py
@@ -0,0 +1,22 @@
+from typing import Optional
+
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.add_config import ChunkerConfig
+from embedchain.helpers.json_serializable import register_deserializable
+
+
+@register_deserializable
+class ImageChunker(BaseChunker):
+ """Chunker for Images."""
+
+ def __init__(self, config: Optional[ChunkerConfig] = None):
+ if config is None:
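+            # Assumption: a GPT-4 Vision description is short, so a single 2000-character chunk with no overlap is usually enough.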
+ config = ChunkerConfig(chunk_size=2000, chunk_overlap=0, length_function=len)
+ text_splitter = RecursiveCharacterTextSplitter(
+ chunk_size=config.chunk_size,
+ chunk_overlap=config.chunk_overlap,
+ length_function=config.length_function,
+ )
+ super().__init__(text_splitter)
diff --git a/embedchain/chunkers/images.py b/embedchain/chunkers/images.py
deleted file mode 100644
index 8e0ac03d..00000000
--- a/embedchain/chunkers/images.py
+++ /dev/null
@@ -1,67 +0,0 @@
-import hashlib
-import logging
-from typing import Optional
-
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-
-from embedchain.chunkers.base_chunker import BaseChunker
-from embedchain.config.add_config import ChunkerConfig
-
-
-class ImagesChunker(BaseChunker):
- """Chunker for an Image."""
-
- def __init__(self, config: Optional[ChunkerConfig] = None):
- if config is None:
- config = ChunkerConfig(chunk_size=300, chunk_overlap=0, length_function=len)
- image_splitter = RecursiveCharacterTextSplitter(
- chunk_size=config.chunk_size,
- chunk_overlap=config.chunk_overlap,
- length_function=config.length_function,
- )
- super().__init__(image_splitter)
-
- def create_chunks(self, loader, src, app_id=None, config: Optional[ChunkerConfig] = None):
- """
- Loads the image(s), and creates their corresponding embedding. This creates one chunk for each image
-
- :param loader: The loader whose `load_data` method is used to create
- the raw data.
- :param src: The data to be handled by the loader. Can be a URL for
- remote sources or local content for local loaders.
- """
- documents = []
- embeddings = []
- ids = []
- min_chunk_size = config.min_chunk_size if config is not None else 0
- logging.info(f"[INFO] Skipping chunks smaller than {min_chunk_size} characters")
- data_result = loader.load_data(src)
- data_records = data_result["data"]
- doc_id = data_result["doc_id"]
- doc_id = f"{app_id}--{doc_id}" if app_id is not None else doc_id
- metadatas = []
- for data in data_records:
- meta_data = data["meta_data"]
- # add data type to meta data to allow query using data type
- meta_data["data_type"] = self.data_type.value
- chunk_id = hashlib.sha256(meta_data["url"].encode()).hexdigest()
- ids.append(chunk_id)
- documents.append(data["content"])
- embeddings.append(data["embedding"])
- meta_data["doc_id"] = doc_id
- metadatas.append(meta_data)
-
- return {
- "documents": documents,
- "embeddings": embeddings,
- "ids": ids,
- "metadatas": metadatas,
- "doc_id": doc_id,
- }
-
- def get_word_count(self, documents):
- """
- The number of chunks and the corresponding word count for an image is fixed to 1, as 1 embedding is created for
- each image
- """
- return 1
diff --git a/embedchain/data_formatter/data_formatter.py b/embedchain/data_formatter/data_formatter.py
index 36f5c719..9ec7c258 100644
--- a/embedchain/data_formatter/data_formatter.py
+++ b/embedchain/data_formatter/data_formatter.py
@@ -63,7 +63,7 @@ class DataFormatter(JSONSerializable):
DataType.DOCS_SITE: "embedchain.loaders.docs_site_loader.DocsSiteLoader",
DataType.CSV: "embedchain.loaders.csv.CsvLoader",
DataType.MDX: "embedchain.loaders.mdx.MdxLoader",
- DataType.IMAGES: "embedchain.loaders.images.ImagesLoader",
+ DataType.IMAGE: "embedchain.loaders.image.ImageLoader",
DataType.UNSTRUCTURED: "embedchain.loaders.unstructured_file.UnstructuredLoader",
DataType.JSON: "embedchain.loaders.json.JSONLoader",
DataType.OPENAPI: "embedchain.loaders.openapi.OpenAPILoader",
@@ -108,7 +108,7 @@ class DataFormatter(JSONSerializable):
DataType.DOCS_SITE: "embedchain.chunkers.docs_site.DocsSiteChunker",
DataType.CSV: "embedchain.chunkers.table.TableChunker",
DataType.MDX: "embedchain.chunkers.mdx.MdxChunker",
- DataType.IMAGES: "embedchain.chunkers.images.ImagesChunker",
+ DataType.IMAGE: "embedchain.chunkers.image.ImageChunker",
DataType.UNSTRUCTURED: "embedchain.chunkers.unstructured_file.UnstructuredFileChunker",
DataType.JSON: "embedchain.chunkers.json.JSONChunker",
DataType.OPENAPI: "embedchain.chunkers.openapi.OpenAPIChunker",
diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py
index d634f3eb..3f1ad6a4 100644
--- a/embedchain/embedchain.py
+++ b/embedchain/embedchain.py
@@ -438,7 +438,6 @@ class EmbedChain(JSONSerializable):
documents=documents,
metadatas=metadatas,
ids=ids,
- skip_embedding=(chunker.data_type == DataType.IMAGES),
**kwargs,
)
count_new_chunks = self.db.count() - chunks_before_addition
@@ -490,21 +489,10 @@ class EmbedChain(JSONSerializable):
if self.config.id is not None:
where.update({"app_id": self.config.id})
- # We cannot query the database with the input query in case of an image search. This is because we need
- # to bring down both the image and text to the same dimension to be able to compare them.
- db_query = input_query
- if hasattr(config, "query_type") and config.query_type == "Images":
- # We import the clip processor here to make sure the package is not dependent on clip dependency even if the
- # image dataset is not being used
- from embedchain.models.clip_processor import ClipProcessor
-
- db_query = ClipProcessor.get_text_features(query=input_query)
-
contexts = self.db.query(
- input_query=db_query,
+ input_query=input_query,
n_results=query_config.number_documents,
where=where,
- skip_embedding=(hasattr(config, "query_type") and config.query_type == "Images"),
citations=citations,
**kwargs,
)
diff --git a/embedchain/loaders/image.py b/embedchain/loaders/image.py
new file mode 100644
index 00000000..911f6d34
--- /dev/null
+++ b/embedchain/loaders/image.py
@@ -0,0 +1,49 @@
+import base64
+import hashlib
+import os
+from pathlib import Path
+
+from openai import OpenAI
+
+from embedchain.helpers.json_serializable import register_deserializable
+from embedchain.loaders.base_loader import BaseLoader
+
+DESCRIBE_IMAGE_PROMPT = "Describe the image:"
+
+
+@register_deserializable
+class ImageLoader(BaseLoader):
+ def __init__(self, max_tokens: int = 500, api_key: str = None, prompt: str = None):
+ super().__init__()
+ self.custom_prompt = prompt or DESCRIBE_IMAGE_PROMPT
+ self.max_tokens = max_tokens
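+        # Fall back to the OPENAI_API_KEY environment variable; this raises a KeyError if neither is provided.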
+ self.api_key = api_key or os.environ["OPENAI_API_KEY"]
+ self.client = OpenAI(api_key=self.api_key)
+
+ def _encode_image(self, image_path: str):
+ with open(image_path, "rb") as image_file:
+ return base64.b64encode(image_file.read()).decode("utf-8")
+
+ def _create_completion_request(self, content: str):
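+        # Single-turn chat completion against the vision model; max_tokens caps the length of the generated description.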
+ return self.client.chat.completions.create(
+ model="gpt-4-vision-preview", messages=[{"role": "user", "content": content}], max_tokens=self.max_tokens
+ )
+
+ def _process_url(self, url: str):
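+        # Hosted images (http/https URLs) are passed to the API as-is; local files are inlined as base64 data URLs.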
+ if url.startswith("http"):
+ return [{"type": "text", "text": self.custom_prompt}, {"type": "image_url", "image_url": {"url": url}}]
+ elif Path(url).is_file():
+ extension = Path(url).suffix.lstrip(".")
+ encoded_image = self._encode_image(url)
+ image_data = f"data:image/{extension};base64,{encoded_image}"
+            return [{"type": "text", "text": self.custom_prompt}, {"type": "image_url", "image_url": {"url": image_data}}]
+ else:
+ raise ValueError(f"Invalid URL or file path: {url}")
+
+ def load_data(self, url: str):
+        request_content = self._process_url(url)
+        response = self._create_completion_request(request_content)
+        description = response.choices[0].message.content
+
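+        # Derive a stable document id from the generated description plus the source path/URL.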
+        doc_id = hashlib.sha256((description + url).encode()).hexdigest()
+        return {"doc_id": doc_id, "data": [{"content": description, "meta_data": {"url": url, "type": "image"}}]}
diff --git a/embedchain/loaders/images.py b/embedchain/loaders/images.py
deleted file mode 100644
index bd954b0d..00000000
--- a/embedchain/loaders/images.py
+++ /dev/null
@@ -1,41 +0,0 @@
-import hashlib
-import logging
-import os
-
-from embedchain.loaders.base_loader import BaseLoader
-
-
-class ImagesLoader(BaseLoader):
- def load_data(self, image_url):
- """
- Loads images from the supplied directory/file and applies CLIP model transformation to represent these images
- in vector form
-
- :param image_url: The URL from which the images are to be loaded
- """
- # load model and image preprocessing
- from embedchain.models.clip_processor import ClipProcessor
-
- model = ClipProcessor.load_model()
- if os.path.isfile(image_url):
- data = [ClipProcessor.get_image_features(image_url, model)]
- else:
- data = []
- for filename in os.listdir(image_url):
- filepath = os.path.join(image_url, filename)
- try:
- data.append(ClipProcessor.get_image_features(filepath, model))
- except Exception as e:
- # Log the file that was not loaded
- logging.exception("Failed to load the file {}. Exception {}".format(filepath, e))
- # Get the metadata like Size, Last Modified and Last Created timestamps
- image_path_metadata = [
- str(os.path.getsize(image_url)),
- str(os.path.getmtime(image_url)),
- str(os.path.getctime(image_url)),
- ]
- doc_id = hashlib.sha256((" ".join(image_path_metadata) + image_url).encode()).hexdigest()
- return {
- "doc_id": doc_id,
- "data": data,
- }
diff --git a/embedchain/models/clip_processor.py b/embedchain/models/clip_processor.py
deleted file mode 100644
index 46a89c16..00000000
--- a/embedchain/models/clip_processor.py
+++ /dev/null
@@ -1,42 +0,0 @@
-try:
- from PIL import Image, UnidentifiedImageError
- from sentence_transformers import SentenceTransformer
-except ImportError:
- raise ImportError("Images requires extra dependencies. Install with `pip install 'embedchain[images]'") from None
-
-MODEL_NAME = "clip-ViT-B-32"
-
-
-class ClipProcessor:
- @staticmethod
- def load_model():
- """Load data from a director of images."""
- # load model and image preprocessing
- model = SentenceTransformer(MODEL_NAME)
- return model
-
- @staticmethod
- def get_image_features(image_url, model):
- """
- Applies the CLIP model to evaluate the vector representation of the supplied image
- """
- try:
- # load image
- image = Image.open(image_url)
- except FileNotFoundError:
- raise FileNotFoundError("The supplied file does not exist`")
- except UnidentifiedImageError:
- raise UnidentifiedImageError("The supplied file is not an image`")
-
- image_features = model.encode(image)
- meta_data = {"url": image_url}
- return {"content": image_url, "embedding": image_features.tolist(), "meta_data": meta_data}
-
- @staticmethod
- def get_text_features(query):
- """
- Applies the CLIP model to evaluate the vector representation of the supplied text
- """
- model = ClipProcessor.load_model()
- text_features = model.encode(query)
- return text_features.tolist()
diff --git a/embedchain/models/data_type.py b/embedchain/models/data_type.py
index fb9da2cd..23a1fffc 100644
--- a/embedchain/models/data_type.py
+++ b/embedchain/models/data_type.py
@@ -24,7 +24,7 @@ class IndirectDataType(Enum):
NOTION = "notion"
CSV = "csv"
MDX = "mdx"
- IMAGES = "images"
+ IMAGE = "image"
UNSTRUCTURED = "unstructured"
JSON = "json"
OPENAPI = "openapi"
@@ -62,7 +62,7 @@ class DataType(Enum):
CSV = IndirectDataType.CSV.value
MDX = IndirectDataType.MDX.value
QNA_PAIR = SpecialDataType.QNA_PAIR.value
- IMAGES = IndirectDataType.IMAGES.value
+ IMAGE = IndirectDataType.IMAGE.value
UNSTRUCTURED = IndirectDataType.UNSTRUCTURED.value
JSON = IndirectDataType.JSON.value
OPENAPI = IndirectDataType.OPENAPI.value
diff --git a/embedchain/vectordb/chroma.py b/embedchain/vectordb/chroma.py
index 32e528f0..7763c207 100644
--- a/embedchain/vectordb/chroma.py
+++ b/embedchain/vectordb/chroma.py
@@ -132,7 +132,6 @@ class ChromaDB(BaseVectorDB):
documents: List[str],
metadatas: List[object],
ids: List[str],
- skip_embedding: bool,
**kwargs: Optional[Dict[str, Any]],
) -> Any:
"""
@@ -146,13 +145,8 @@ class ChromaDB(BaseVectorDB):
:type metadatas: List[object]
:param ids: ids
:type ids: List[str]
- :param skip_embedding: Optional. If True, then the embeddings are assumed to be already generated.
- :type skip_embedding: bool
"""
size = len(documents)
- if skip_embedding and (embeddings is None or len(embeddings) != len(documents)):
- raise ValueError("Cannot add documents to chromadb with inconsistent embeddings")
-
if len(documents) != size or len(metadatas) != size or len(ids) != size:
raise ValueError(
"Cannot add documents to chromadb with inconsistent sizes. Documents size: {}, Metadata size: {},"
@@ -160,19 +154,11 @@ class ChromaDB(BaseVectorDB):
)
for i in tqdm(range(0, len(documents), self.BATCH_SIZE), desc="Inserting batches in chromadb"):
- if skip_embedding:
- self.collection.add(
- embeddings=embeddings[i : i + self.BATCH_SIZE],
- documents=documents[i : i + self.BATCH_SIZE],
- metadatas=metadatas[i : i + self.BATCH_SIZE],
- ids=ids[i : i + self.BATCH_SIZE],
- )
- else:
- self.collection.add(
- documents=documents[i : i + self.BATCH_SIZE],
- metadatas=metadatas[i : i + self.BATCH_SIZE],
- ids=ids[i : i + self.BATCH_SIZE],
- )
+ self.collection.add(
+ documents=documents[i : i + self.BATCH_SIZE],
+ metadatas=metadatas[i : i + self.BATCH_SIZE],
+ ids=ids[i : i + self.BATCH_SIZE],
+ )
def _format_result(self, results: QueryResult) -> list[tuple[Document, float]]:
"""
@@ -197,7 +183,6 @@ class ChromaDB(BaseVectorDB):
input_query: List[str],
n_results: int,
where: Dict[str, any],
- skip_embedding: bool,
citations: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, Dict]], List[str]]:
@@ -210,8 +195,6 @@ class ChromaDB(BaseVectorDB):
:type n_results: int
:param where: to filter data
:type where: Dict[str, Any]
- :param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
- :type skip_embedding: bool
:param citations: we use citations boolean param to return context along with the answer.
:type citations: bool, default is False.
:raises InvalidDimensionException: Dimensions do not match.
@@ -220,24 +203,14 @@ class ChromaDB(BaseVectorDB):
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
"""
try:
- if skip_embedding:
- result = self.collection.query(
- query_embeddings=[
- input_query,
- ],
- n_results=n_results,
- where=self._generate_where_clause(where),
- **kwargs,
- )
- else:
- result = self.collection.query(
- query_texts=[
- input_query,
- ],
- n_results=n_results,
- where=self._generate_where_clause(where),
- **kwargs,
- )
+ result = self.collection.query(
+ query_texts=[
+ input_query,
+ ],
+ n_results=n_results,
+ where=self._generate_where_clause(where),
+ **kwargs,
+ )
except InvalidDimensionException as e:
raise InvalidDimensionException(
e.message()
diff --git a/embedchain/vectordb/elasticsearch.py b/embedchain/vectordb/elasticsearch.py
index 11610b12..62744f5d 100644
--- a/embedchain/vectordb/elasticsearch.py
+++ b/embedchain/vectordb/elasticsearch.py
@@ -114,7 +114,6 @@ class ElasticsearchDB(BaseVectorDB):
documents: List[str],
metadatas: List[object],
ids: List[str],
- skip_embedding: bool,
**kwargs: Optional[Dict[str, any]],
) -> Any:
"""
@@ -127,12 +126,9 @@ class ElasticsearchDB(BaseVectorDB):
:type metadatas: List[object]
:param ids: ids of docs
:type ids: List[str]
- :param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
- :type skip_embedding: bool
"""
- if not skip_embedding:
- embeddings = self.embedder.embedding_fn(documents)
+ embeddings = self.embedder.embedding_fn(documents)
for chunk in chunks(
list(zip(ids, documents, metadatas, embeddings)), self.BATCH_SIZE, desc="Inserting batches in elasticsearch"
@@ -161,7 +157,6 @@ class ElasticsearchDB(BaseVectorDB):
input_query: List[str],
n_results: int,
where: Dict[str, any],
- skip_embedding: bool,
citations: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, Dict]], List[str]]:
@@ -174,8 +169,6 @@ class ElasticsearchDB(BaseVectorDB):
:type n_results: int
:param where: Optional. to filter data
:type where: Dict[str, any]
- :param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
- :type skip_embedding: bool
:return: The context of the document that matched your query, url of the source, doc_id
:param citations: we use citations boolean param to return context along with the answer.
:type citations: bool, default is False.
@@ -183,11 +176,8 @@ class ElasticsearchDB(BaseVectorDB):
along with url of the source and doc_id (if citations flag is true)
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
"""
- if skip_embedding:
- query_vector = input_query
- else:
- input_query_vector = self.embedder.embedding_fn(input_query)
- query_vector = input_query_vector[0]
+ input_query_vector = self.embedder.embedding_fn(input_query)
+ query_vector = input_query_vector[0]
# `https://www.elastic.co/guide/en/elasticsearch/reference/7.17/query-dsl-script-score-query.html`
query = {
diff --git a/embedchain/vectordb/opensearch.py b/embedchain/vectordb/opensearch.py
index a1f408f1..365ccda3 100644
--- a/embedchain/vectordb/opensearch.py
+++ b/embedchain/vectordb/opensearch.py
@@ -120,7 +120,6 @@ class OpenSearchDB(BaseVectorDB):
documents: List[str],
metadatas: List[object],
ids: List[str],
- skip_embedding: bool,
**kwargs: Optional[Dict[str, any]],
):
"""Add data in vector database.
@@ -130,17 +129,11 @@ class OpenSearchDB(BaseVectorDB):
documents (List[str]): List of texts to add.
metadatas (List[object]): List of metadata associated with docs.
ids (List[str]): IDs of docs.
- skip_embedding (bool): If True, then embeddings are assumed to be already generated.
"""
for batch_start in tqdm(range(0, len(documents), self.BATCH_SIZE), desc="Inserting batches in opensearch"):
batch_end = batch_start + self.BATCH_SIZE
batch_documents = documents[batch_start:batch_end]
-
- # Generate embeddings for the batch if not skipping embedding
- if not skip_embedding:
- batch_embeddings = self.embedder.embedding_fn(batch_documents)
- else:
- batch_embeddings = embeddings[batch_start:batch_end]
+            batch_embeddings = self.embedder.embedding_fn(batch_documents)
# Create document entries for bulk upload
batch_entries = [
@@ -166,7 +159,6 @@ class OpenSearchDB(BaseVectorDB):
input_query: List[str],
n_results: int,
where: Dict[str, any],
- skip_embedding: bool,
citations: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, Dict]], List[str]]:
@@ -179,15 +171,12 @@ class OpenSearchDB(BaseVectorDB):
:type n_results: int
:param where: Optional. to filter data
:type where: Dict[str, any]
- :param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
- :type skip_embedding: bool
:param citations: we use citations boolean param to return context along with the answer.
:type citations: bool, default is False.
:return: The content of the document that matched your query,
along with url of the source and doc_id (if citations flag is true)
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
"""
- # TODO(rupeshbansal, deshraj): Add support for skip embeddings here if already exists
embeddings = OpenAIEmbeddings()
docsearch = OpenSearchVectorSearch(
index_name=self._get_index(),
diff --git a/embedchain/vectordb/pinecone.py b/embedchain/vectordb/pinecone.py
index cd039d62..dd3da4ac 100644
--- a/embedchain/vectordb/pinecone.py
+++ b/embedchain/vectordb/pinecone.py
@@ -92,7 +92,6 @@ class PineconeDB(BaseVectorDB):
documents: List[str],
metadatas: List[object],
ids: List[str],
- skip_embedding: bool,
**kwargs: Optional[Dict[str, any]],
):
"""add data in vector database
@@ -124,7 +123,6 @@ class PineconeDB(BaseVectorDB):
input_query: List[str],
n_results: int,
where: Dict[str, any],
- skip_embedding: bool,
citations: bool = False,
**kwargs: Optional[Dict[str, any]],
) -> Union[List[Tuple[str, Dict]], List[str]]:
@@ -136,18 +134,13 @@ class PineconeDB(BaseVectorDB):
:type n_results: int
:param where: Optional. to filter data
:type where: Dict[str, any]
- :param skip_embedding: Optional. if True, input_query is already embedded
- :type skip_embedding: bool
:param citations: we use citations boolean param to return context along with the answer.
:type citations: bool, default is False.
:return: The content of the document that matched your query,
along with url of the source and doc_id (if citations flag is true)
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
"""
- if not skip_embedding:
- query_vector = self.embedder.embedding_fn([input_query])[0]
- else:
- query_vector = input_query
+ query_vector = self.embedder.embedding_fn([input_query])[0]
data = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True, **kwargs)
contexts = []
for doc in data["matches"]:
diff --git a/embedchain/vectordb/qdrant.py b/embedchain/vectordb/qdrant.py
index e9df0217..be2d9523 100644
--- a/embedchain/vectordb/qdrant.py
+++ b/embedchain/vectordb/qdrant.py
@@ -126,7 +126,6 @@ class QdrantDB(BaseVectorDB):
documents: List[str],
metadatas: List[object],
ids: List[str],
- skip_embedding: bool,
**kwargs: Optional[Dict[str, any]],
):
"""add data in vector database
@@ -138,12 +137,8 @@ class QdrantDB(BaseVectorDB):
:type metadatas: List[object]
:param ids: ids of docs
:type ids: List[str]
- :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
- generated or not
- :type skip_embedding: bool
"""
- if not skip_embedding:
- embeddings = self.embedder.embedding_fn(documents)
+ embeddings = self.embedder.embedding_fn(documents)
payloads = []
qdrant_ids = []
@@ -167,7 +162,6 @@ class QdrantDB(BaseVectorDB):
input_query: List[str],
n_results: int,
where: Dict[str, any],
- skip_embedding: bool,
citations: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, Dict]], List[str]]:
@@ -179,20 +173,13 @@ class QdrantDB(BaseVectorDB):
:type n_results: int
:param where: Optional. to filter data
:type where: Dict[str, any]
- :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
- generated or not
- :type skip_embedding: bool
:param citations: we use citations boolean param to return context along with the answer.
:type citations: bool, default is False.
:return: The content of the document that matched your query,
along with url of the source and doc_id (if citations flag is true)
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
"""
- if not skip_embedding:
- query_vector = self.embedder.embedding_fn([input_query])[0]
- else:
- query_vector = input_query
-
+ query_vector = self.embedder.embedding_fn([input_query])[0]
keys = set(where.keys() if where is not None else set())
qdrant_must_filters = []
diff --git a/embedchain/vectordb/weaviate.py b/embedchain/vectordb/weaviate.py
index 620087bf..08861bb7 100644
--- a/embedchain/vectordb/weaviate.py
+++ b/embedchain/vectordb/weaviate.py
@@ -157,7 +157,6 @@ class WeaviateDB(BaseVectorDB):
documents: List[str],
metadatas: List[object],
ids: List[str],
- skip_embedding: bool,
**kwargs: Optional[Dict[str, any]],
):
"""add data in vector database
@@ -169,14 +168,8 @@ class WeaviateDB(BaseVectorDB):
:type metadatas: List[object]
:param ids: ids of docs
:type ids: List[str]
- :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
- generated or not
- :type skip_embedding: bool
"""
-
- print("Adding documents to Weaviate...")
- if not skip_embedding:
- embeddings = self.embedder.embedding_fn(documents)
+ embeddings = self.embedder.embedding_fn(documents)
self.client.batch.configure(batch_size=self.BATCH_SIZE, timeout_retries=3) # Configure batch
with self.client.batch as batch: # Initialize a batch process
for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
@@ -202,7 +195,6 @@ class WeaviateDB(BaseVectorDB):
input_query: List[str],
n_results: int,
where: Dict[str, any],
- skip_embedding: bool,
citations: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, Dict]], List[str]]:
@@ -214,20 +206,13 @@ class WeaviateDB(BaseVectorDB):
:type n_results: int
:param where: Optional. to filter data
:type where: Dict[str, any]
- :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
- generated or not
- :type skip_embedding: bool
:param citations: we use citations boolean param to return context along with the answer.
:type citations: bool, default is False.
:return: The content of the document that matched your query,
along with url of the source and doc_id (if citations flag is true)
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
"""
- if not skip_embedding:
- query_vector = self.embedder.embedding_fn([input_query])[0]
- else:
- query_vector = input_query
-
+ query_vector = self.embedder.embedding_fn([input_query])[0]
keys = set(where.keys() if where is not None else set())
data_fields = ["text"]
diff --git a/embedchain/vectordb/zilliz.py b/embedchain/vectordb/zilliz.py
index ca398f14..35bd2230 100644
--- a/embedchain/vectordb/zilliz.py
+++ b/embedchain/vectordb/zilliz.py
@@ -112,12 +112,10 @@ class ZillizVectorDB(BaseVectorDB):
documents: List[str],
metadatas: List[object],
ids: List[str],
- skip_embedding: bool,
**kwargs: Optional[Dict[str, any]],
):
"""Add to database"""
- if not skip_embedding:
- embeddings = self.embedder.embedding_fn(documents)
+ embeddings = self.embedder.embedding_fn(documents)
for id, doc, metadata, embedding in zip(ids, documents, metadatas, embeddings):
data = {**metadata, "id": id, "text": doc, "embeddings": embedding}
@@ -132,7 +130,6 @@ class ZillizVectorDB(BaseVectorDB):
input_query: List[str],
n_results: int,
where: Dict[str, any],
- skip_embedding: bool,
citations: bool = False,
**kwargs: Optional[Dict[str, Any]],
) -> Union[List[Tuple[str, Dict]], List[str]]:
@@ -160,27 +157,16 @@ class ZillizVectorDB(BaseVectorDB):
where = None
output_fields = ["*"]
- if skip_embedding:
- query_vector = input_query
- query_result = self.client.search(
- collection_name=self.config.collection_name,
- data=query_vector,
- limit=n_results,
- output_fields=output_fields,
- **kwargs,
- )
+ input_query_vector = self.embedder.embedding_fn([input_query])
+ query_vector = input_query_vector[0]
- else:
- input_query_vector = self.embedder.embedding_fn([input_query])
- query_vector = input_query_vector[0]
-
- query_result = self.client.search(
- collection_name=self.config.collection_name,
- data=[query_vector],
- limit=n_results,
- output_fields=output_fields,
- **kwargs,
- )
+ query_result = self.client.search(
+ collection_name=self.config.collection_name,
+ data=[query_vector],
+ limit=n_results,
+ output_fields=output_fields,
+ **kwargs,
+ )
query_result = query_result[0]
contexts = []
for query in query_result:
diff --git a/poetry.lock b/poetry.lock
index 5550ea23..99c28892 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
[[package]]
name = "aiofiles"
@@ -333,6 +333,26 @@ description = "The uncompromising code formatter."
optional = false
python-versions = ">=3.8"
files = [
+ {file = "black-23.9.1-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:d6bc09188020c9ac2555a498949401ab35bb6bf76d4e0f8ee251694664df6301"},
+ {file = "black-23.9.1-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:13ef033794029b85dfea8032c9d3b92b42b526f1ff4bf13b2182ce4e917f5100"},
+ {file = "black-23.9.1-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:75a2dc41b183d4872d3a500d2b9c9016e67ed95738a3624f4751a0cb4818fe71"},
+ {file = "black-23.9.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:13a2e4a93bb8ca74a749b6974925c27219bb3df4d42fc45e948a5d9feb5122b7"},
+ {file = "black-23.9.1-cp310-cp310-win_amd64.whl", hash = "sha256:adc3e4442eef57f99b5590b245a328aad19c99552e0bdc7f0b04db6656debd80"},
+ {file = "black-23.9.1-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:8431445bf62d2a914b541da7ab3e2b4f3bc052d2ccbf157ebad18ea126efb91f"},
+ {file = "black-23.9.1-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:8fc1ddcf83f996247505db6b715294eba56ea9372e107fd54963c7553f2b6dfe"},
+ {file = "black-23.9.1-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:7d30ec46de88091e4316b17ae58bbbfc12b2de05e069030f6b747dfc649ad186"},
+ {file = "black-23.9.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:031e8c69f3d3b09e1aa471a926a1eeb0b9071f80b17689a655f7885ac9325a6f"},
+ {file = "black-23.9.1-cp311-cp311-win_amd64.whl", hash = "sha256:538efb451cd50f43aba394e9ec7ad55a37598faae3348d723b59ea8e91616300"},
+ {file = "black-23.9.1-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:638619a559280de0c2aa4d76f504891c9860bb8fa214267358f0a20f27c12948"},
+ {file = "black-23.9.1-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:a732b82747235e0542c03bf352c126052c0fbc458d8a239a94701175b17d4855"},
+ {file = "black-23.9.1-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:cf3a4d00e4cdb6734b64bf23cd4341421e8953615cba6b3670453737a72ec204"},
+ {file = "black-23.9.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cf99f3de8b3273a8317681d8194ea222f10e0133a24a7548c73ce44ea1679377"},
+ {file = "black-23.9.1-cp38-cp38-win_amd64.whl", hash = "sha256:14f04c990259576acd093871e7e9b14918eb28f1866f91968ff5524293f9c573"},
+ {file = "black-23.9.1-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:c619f063c2d68f19b2d7270f4cf3192cb81c9ec5bc5ba02df91471d0b88c4c5c"},
+ {file = "black-23.9.1-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:6a3b50e4b93f43b34a9d3ef00d9b6728b4a722c997c99ab09102fd5efdb88325"},
+ {file = "black-23.9.1-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:c46767e8df1b7beefb0899c4a95fb43058fa8500b6db144f4ff3ca38eb2f6393"},
+ {file = "black-23.9.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50254ebfa56aa46a9fdd5d651f9637485068a1adf42270148cd101cdf56e0ad9"},
+ {file = "black-23.9.1-cp39-cp39-win_amd64.whl", hash = "sha256:403397c033adbc45c2bd41747da1f7fc7eaa44efbee256b53842470d4ac5a70f"},
{file = "black-23.9.1-py3-none-any.whl", hash = "sha256:6ccd59584cc834b6d127628713e4b6b968e5f79572da66284532525a042549f9"},
{file = "black-23.9.1.tar.gz", hash = "sha256:24b6b3ff5c6d9ea08a8888f6977eae858e1f340d7260cf56d70a49823236b62d"},
]
@@ -1732,20 +1752,6 @@ smb = ["smbprotocol"]
ssh = ["paramiko"]
tqdm = ["tqdm"]
-[[package]]
-name = "ftfy"
-version = "6.1.1"
-description = "Fixes mojibake and other problems with Unicode, after the fact"
-optional = true
-python-versions = ">=3.7,<4"
-files = [
- {file = "ftfy-6.1.1-py3-none-any.whl", hash = "sha256:0ffd33fce16b54cccaec78d6ec73d95ad370e5df5a25255c8966a6147bd667ca"},
- {file = "ftfy-6.1.1.tar.gz", hash = "sha256:bfc2019f84fcd851419152320a6375604a0f1459c281b5b199b2cd0d2e727f8f"},
-]
-
-[package.dependencies]
-wcwidth = ">=0.2.5"
-
[[package]]
name = "gitdb"
version = "4.0.11"
@@ -1808,11 +1814,11 @@ files = [
google-auth = ">=2.14.1,<3.0.dev0"
googleapis-common-protos = ">=1.56.2,<2.0.dev0"
grpcio = [
- {version = ">=1.33.2,<2.0dev", optional = true, markers = "extra == \"grpc\""},
+ {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
{version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
]
grpcio-status = [
- {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "extra == \"grpc\""},
+ {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
{version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
]
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0"
@@ -2228,7 +2234,7 @@ files = [
{file = "greenlet-3.0.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0b72b802496cccbd9b31acea72b6f87e7771ccfd7f7927437d592e5c92ed703c"},
{file = "greenlet-3.0.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:527cd90ba3d8d7ae7dceb06fda619895768a46a1b4e423bdb24c1969823b8362"},
{file = "greenlet-3.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:37f60b3a42d8b5499be910d1267b24355c495064f271cfe74bf28b17b099133c"},
- {file = "greenlet-3.0.0-cp311-universal2-macosx_10_9_universal2.whl", hash = "sha256:c3692ecf3fe754c8c0f2c95ff19626584459eab110eaab66413b1e7425cd84e9"},
+ {file = "greenlet-3.0.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:1482fba7fbed96ea7842b5a7fc11d61727e8be75a077e603e8ab49d24e234383"},
{file = "greenlet-3.0.0-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:be557119bf467d37a8099d91fbf11b2de5eb1fd5fc5b91598407574848dc910f"},
{file = "greenlet-3.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:73b2f1922a39d5d59cc0e597987300df3396b148a9bd10b76a058a2f2772fc04"},
{file = "greenlet-3.0.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d1e22c22f7826096ad503e9bb681b05b8c1f5a8138469b255eb91f26a76634f2"},
@@ -2238,7 +2244,6 @@ files = [
{file = "greenlet-3.0.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:952256c2bc5b4ee8df8dfc54fc4de330970bf5d79253c863fb5e6761f00dda35"},
{file = "greenlet-3.0.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:269d06fa0f9624455ce08ae0179430eea61085e3cf6457f05982b37fd2cefe17"},
{file = "greenlet-3.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:9adbd8ecf097e34ada8efde9b6fec4dd2a903b1e98037adf72d12993a1c80b51"},
- {file = "greenlet-3.0.0-cp312-universal2-macosx_10_9_universal2.whl", hash = "sha256:553d6fb2324e7f4f0899e5ad2c427a4579ed4873f42124beba763f16032959af"},
{file = "greenlet-3.0.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6b5ce7f40f0e2f8b88c28e6691ca6806814157ff05e794cdd161be928550f4c"},
{file = "greenlet-3.0.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ecf94aa539e97a8411b5ea52fc6ccd8371be9550c4041011a091eb8b3ca1d810"},
{file = "greenlet-3.0.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:80dcd3c938cbcac986c5c92779db8e8ce51a89a849c135172c88ecbdc8c056b7"},
@@ -3301,6 +3306,16 @@ files = [
{file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"},
{file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"},
{file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"},
+ {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"},
+ {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"},
+ {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"},
+ {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"},
+ {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"},
+ {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"},
+ {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"},
+ {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"},
+ {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"},
+ {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"},
{file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"},
{file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"},
{file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"},
@@ -4139,12 +4154,10 @@ files = [
[package.dependencies]
numpy = [
- {version = ">=1.21.0", markers = "python_version <= \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""},
- {version = ">=1.21.2", markers = "python_version >= \"3.10\""},
- {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\""},
- {version = ">=1.19.3", markers = "python_version >= \"3.6\" and platform_system == \"Linux\" and platform_machine == \"aarch64\" or python_version >= \"3.9\""},
- {version = ">=1.17.0", markers = "python_version >= \"3.7\""},
- {version = ">=1.17.3", markers = "python_version >= \"3.8\""},
+ {version = ">=1.21.0", markers = "python_version == \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""},
+ {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
+ {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
+ {version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" and python_version < \"3.10\" or python_version > \"3.9\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_system != \"Darwin\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_machine != \"arm64\" and python_version < \"3.10\""},
{version = ">=1.23.5", markers = "python_version >= \"3.11\""},
]
@@ -5515,6 +5528,7 @@ files = [
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"},
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"},
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"},
+ {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"},
{file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"},
{file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"},
{file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"},
@@ -5522,8 +5536,15 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"},
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"},
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"},
+ {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"},
{file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"},
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
+ {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
+ {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
+ {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
+ {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
+ {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
+ {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"},
{file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"},
{file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"},
{file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"},
@@ -5540,6 +5561,7 @@ files = [
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"},
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"},
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"},
+ {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"},
{file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"},
{file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"},
{file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"},
@@ -5547,6 +5569,7 @@ files = [
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"},
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"},
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"},
+ {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"},
{file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"},
{file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"},
{file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"},
@@ -6093,6 +6116,11 @@ files = [
{file = "scikit_learn-1.3.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f66eddfda9d45dd6cadcd706b65669ce1df84b8549875691b1f403730bdef217"},
{file = "scikit_learn-1.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c6448c37741145b241eeac617028ba6ec2119e1339b1385c9720dae31367f2be"},
{file = "scikit_learn-1.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:c413c2c850241998168bbb3bd1bb59ff03b1195a53864f0b80ab092071af6028"},
+ {file = "scikit_learn-1.3.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ef540e09873e31569bc8b02c8a9f745ee04d8e1263255a15c9969f6f5caa627f"},
+ {file = "scikit_learn-1.3.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:9147a3a4df4d401e618713880be023e36109c85d8569b3bf5377e6cd3fecdeac"},
+ {file = "scikit_learn-1.3.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d2cd3634695ad192bf71645702b3df498bd1e246fc2d529effdb45a06ab028b4"},
+ {file = "scikit_learn-1.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c275a06c5190c5ce00af0acbb61c06374087949f643ef32d355ece12c4db043"},
+ {file = "scikit_learn-1.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:0e1aa8f206d0de814b81b41d60c1ce31f7f2c7354597af38fae46d9c47c45122"},
{file = "scikit_learn-1.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:52b77cc08bd555969ec5150788ed50276f5ef83abb72e6f469c5b91a0009bbca"},
{file = "scikit_learn-1.3.1-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:a683394bc3f80b7c312c27f9b14ebea7766b1f0a34faf1a2e9158d80e860ec26"},
{file = "scikit_learn-1.3.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a15d964d9eb181c79c190d3dbc2fff7338786bf017e9039571418a1d53dab236"},
@@ -6435,18 +6463,59 @@ description = "Database Abstraction Library"
optional = false
python-versions = ">=3.7"
files = [
+ {file = "SQLAlchemy-2.0.22-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f146c61ae128ab43ea3a0955de1af7e1633942c2b2b4985ac51cc292daf33222"},
+ {file = "SQLAlchemy-2.0.22-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:875de9414393e778b655a3d97d60465eb3fae7c919e88b70cc10b40b9f56042d"},
+ {file = "SQLAlchemy-2.0.22-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:13790cb42f917c45c9c850b39b9941539ca8ee7917dacf099cc0b569f3d40da7"},
+ {file = "SQLAlchemy-2.0.22-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e04ab55cf49daf1aeb8c622c54d23fa4bec91cb051a43cc24351ba97e1dd09f5"},
+ {file = "SQLAlchemy-2.0.22-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:a42c9fa3abcda0dcfad053e49c4f752eef71ecd8c155221e18b99d4224621176"},
+ {file = "SQLAlchemy-2.0.22-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:14cd3bcbb853379fef2cd01e7c64a5d6f1d005406d877ed9509afb7a05ff40a5"},
+ {file = "SQLAlchemy-2.0.22-cp310-cp310-win32.whl", hash = "sha256:d143c5a9dada696bcfdb96ba2de4a47d5a89168e71d05a076e88a01386872f97"},
+ {file = "SQLAlchemy-2.0.22-cp310-cp310-win_amd64.whl", hash = "sha256:ccd87c25e4c8559e1b918d46b4fa90b37f459c9b4566f1dfbce0eb8122571547"},
{file = "SQLAlchemy-2.0.22-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4f6ff392b27a743c1ad346d215655503cec64405d3b694228b3454878bf21590"},
{file = "SQLAlchemy-2.0.22-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f776c2c30f0e5f4db45c3ee11a5f2a8d9de68e81eb73ec4237de1e32e04ae81c"},
+ {file = "SQLAlchemy-2.0.22-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c8f1792d20d2f4e875ce7a113f43c3561ad12b34ff796b84002a256f37ce9437"},
+ {file = "SQLAlchemy-2.0.22-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d80eeb5189d7d4b1af519fc3f148fe7521b9dfce8f4d6a0820e8f5769b005051"},
+ {file = "SQLAlchemy-2.0.22-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:69fd9e41cf9368afa034e1c81f3570afb96f30fcd2eb1ef29cb4d9371c6eece2"},
+ {file = "SQLAlchemy-2.0.22-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:54bcceaf4eebef07dadfde424f5c26b491e4a64e61761dea9459103ecd6ccc95"},
+ {file = "SQLAlchemy-2.0.22-cp311-cp311-win32.whl", hash = "sha256:7ee7ccf47aa503033b6afd57efbac6b9e05180f492aeed9fcf70752556f95624"},
+ {file = "SQLAlchemy-2.0.22-cp311-cp311-win_amd64.whl", hash = "sha256:b560f075c151900587ade06706b0c51d04b3277c111151997ea0813455378ae0"},
{file = "SQLAlchemy-2.0.22-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:2c9bac865ee06d27a1533471405ad240a6f5d83195eca481f9fc4a71d8b87df8"},
{file = "SQLAlchemy-2.0.22-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:625b72d77ac8ac23da3b1622e2da88c4aedaee14df47c8432bf8f6495e655de2"},
+ {file = "SQLAlchemy-2.0.22-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b39a6e21110204a8c08d40ff56a73ba542ec60bab701c36ce721e7990df49fb9"},
+ {file = "SQLAlchemy-2.0.22-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:53a766cb0b468223cafdf63e2d37f14a4757476157927b09300c8c5832d88560"},
+ {file = "SQLAlchemy-2.0.22-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0e1ce8ebd2e040357dde01a3fb7d30d9b5736b3e54a94002641dfd0aa12ae6ce"},
+ {file = "SQLAlchemy-2.0.22-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:505f503763a767556fa4deae5194b2be056b64ecca72ac65224381a0acab7ebe"},
+ {file = "SQLAlchemy-2.0.22-cp312-cp312-win32.whl", hash = "sha256:154a32f3c7b00de3d090bc60ec8006a78149e221f1182e3edcf0376016be9396"},
+ {file = "SQLAlchemy-2.0.22-cp312-cp312-win_amd64.whl", hash = "sha256:129415f89744b05741c6f0b04a84525f37fbabe5dc3774f7edf100e7458c48cd"},
{file = "SQLAlchemy-2.0.22-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:3940677d341f2b685a999bffe7078697b5848a40b5f6952794ffcf3af150c301"},
+ {file = "SQLAlchemy-2.0.22-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:55914d45a631b81a8a2cb1a54f03eea265cf1783241ac55396ec6d735be14883"},
+ {file = "SQLAlchemy-2.0.22-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2096d6b018d242a2bcc9e451618166f860bb0304f590d205173d317b69986c95"},
+ {file = "SQLAlchemy-2.0.22-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:19c6986cf2fb4bc8e0e846f97f4135a8e753b57d2aaaa87c50f9acbe606bd1db"},
+ {file = "SQLAlchemy-2.0.22-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:6ac28bd6888fe3c81fbe97584eb0b96804bd7032d6100b9701255d9441373ec1"},
+ {file = "SQLAlchemy-2.0.22-cp37-cp37m-win32.whl", hash = "sha256:cb9a758ad973e795267da334a92dd82bb7555cb36a0960dcabcf724d26299db8"},
+ {file = "SQLAlchemy-2.0.22-cp37-cp37m-win_amd64.whl", hash = "sha256:40b1206a0d923e73aa54f0a6bd61419a96b914f1cd19900b6c8226899d9742ad"},
{file = "SQLAlchemy-2.0.22-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:3aa1472bf44f61dd27987cd051f1c893b7d3b17238bff8c23fceaef4f1133868"},
{file = "SQLAlchemy-2.0.22-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:56a7e2bb639df9263bf6418231bc2a92a773f57886d371ddb7a869a24919face"},
+ {file = "SQLAlchemy-2.0.22-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ccca778c0737a773a1ad86b68bda52a71ad5950b25e120b6eb1330f0df54c3d0"},
+ {file = "SQLAlchemy-2.0.22-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7c6c3e9350f9fb16de5b5e5fbf17b578811a52d71bb784cc5ff71acb7de2a7f9"},
+ {file = "SQLAlchemy-2.0.22-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:564e9f9e4e6466273dbfab0e0a2e5fe819eec480c57b53a2cdee8e4fdae3ad5f"},
+ {file = "SQLAlchemy-2.0.22-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:af66001d7b76a3fab0d5e4c1ec9339ac45748bc4a399cbc2baa48c1980d3c1f4"},
+ {file = "SQLAlchemy-2.0.22-cp38-cp38-win32.whl", hash = "sha256:9e55dff5ec115316dd7a083cdc1a52de63693695aecf72bc53a8e1468ce429e5"},
+ {file = "SQLAlchemy-2.0.22-cp38-cp38-win_amd64.whl", hash = "sha256:4e869a8ff7ee7a833b74868a0887e8462445ec462432d8cbeff5e85f475186da"},
+ {file = "SQLAlchemy-2.0.22-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9886a72c8e6371280cb247c5d32c9c8fa141dc560124348762db8a8b236f8692"},
+ {file = "SQLAlchemy-2.0.22-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a571bc8ac092a3175a1d994794a8e7a1f2f651e7c744de24a19b4f740fe95034"},
+ {file = "SQLAlchemy-2.0.22-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8db5ba8b7da759b727faebc4289a9e6a51edadc7fc32207a30f7c6203a181592"},
+ {file = "SQLAlchemy-2.0.22-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b0b3f2686c3f162123adba3cb8b626ed7e9b8433ab528e36ed270b4f70d1cdb"},
+ {file = "SQLAlchemy-2.0.22-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0c1fea8c0abcb070ffe15311853abfda4e55bf7dc1d4889497b3403629f3bf00"},
+ {file = "SQLAlchemy-2.0.22-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:4bb062784f37b2d75fd9b074c8ec360ad5df71f933f927e9e95c50eb8e05323c"},
+ {file = "SQLAlchemy-2.0.22-cp39-cp39-win32.whl", hash = "sha256:58a3aba1bfb32ae7af68da3f277ed91d9f57620cf7ce651db96636790a78b736"},
+ {file = "SQLAlchemy-2.0.22-cp39-cp39-win_amd64.whl", hash = "sha256:92e512a6af769e4725fa5b25981ba790335d42c5977e94ded07db7d641490a85"},
+ {file = "SQLAlchemy-2.0.22-py3-none-any.whl", hash = "sha256:3076740335e4aaadd7deb3fe6dcb96b3015f1613bd190a4e1634e1b99b02ec86"},
{file = "SQLAlchemy-2.0.22.tar.gz", hash = "sha256:5434cc601aa17570d79e5377f5fd45ff92f9379e2abed0be5e8c2fba8d353d2b"},
]
[package.dependencies]
-greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\" or extra == \"asyncio\""}
+greenlet = {version = "!=0.4.17", markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\""}
typing-extensions = ">=4.2.0"
[package.extras]
@@ -7610,17 +7679,6 @@ files = [
[package.dependencies]
anyio = ">=3.0.0"
-[[package]]
-name = "wcwidth"
-version = "0.2.8"
-description = "Measures the displayed width of unicode strings in a terminal"
-optional = true
-python-versions = "*"
-files = [
- {file = "wcwidth-0.2.8-py2.py3-none-any.whl", hash = "sha256:77f719e01648ed600dfa5402c347481c0992263b81a027344f3e1ba25493a704"},
- {file = "wcwidth-0.2.8.tar.gz", hash = "sha256:8705c569999ffbb4f6a87c6d1b80f324bd6db952f5eb0b95bc07517f4c1813d4"},
-]
-
[[package]]
name = "weaviate-client"
version = "3.24.2"
@@ -8038,7 +8096,6 @@ github = ["PyGithub", "gitpython"]
gmail = ["google-api-core", "google-api-python-client", "google-auth", "google-auth-httplib2", "google-auth-oauthlib", "requests"]
google = ["google-generativeai"]
huggingface-hub = ["huggingface_hub"]
-images = ["ftfy", "pillow", "regex", "torch", "torchvision"]
llama2 = ["replicate"]
milvus = ["pymilvus"]
modal = ["modal"]
@@ -8061,4 +8118,4 @@ youtube = ["youtube-transcript-api", "yt_dlp"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<3.12"
-content-hash = "335c42c91a2b5e4a1c3d8a7c39dee8665fd1eee0410e1bc6cb6cb1d6f6722445"
+content-hash = "8def3cb3aa4737793eaacd9358c092e0331f001044f5cacca513fc47faf44b06"
diff --git a/pyproject.toml b/pyproject.toml
index 650130e8..030baf43 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -125,10 +125,6 @@ docx2txt = { version = "^0.8", optional = true }
pinecone-client = { version = "^2.2.4", optional = true }
qdrant-client = { version = "1.6.3", optional = true }
unstructured = {extras = ["local-inference", "all-docs"], version = "^0.10.18", optional = true}
-pillow = { version = "10.0.1", optional = true }
-torchvision = { version = ">=0.15.1, !=0.15.2", optional = true }
-ftfy = { version = "6.1.1", optional = true }
-regex = { version = "2023.8.8", optional = true }
huggingface_hub = { version = "^0.17.3", optional = true }
pymilvus = { version = "2.3.1", optional = true }
google-cloud-aiplatform = { version = "^1.26.1", optional = true }
@@ -179,7 +175,6 @@ whatsapp = ["twilio", "flask"]
weaviate = ["weaviate-client"]
pinecone = ["pinecone-client"]
qdrant = ["qdrant-client"]
-images = ["torch", "ftfy", "regex", "pillow", "torchvision"]
huggingface_hub=["huggingface_hub"]
cohere = ["cohere"]
together = ["together"]
diff --git a/tests/chunkers/test_image_chunker.py b/tests/chunkers/test_image_chunker.py
deleted file mode 100644
index 67f5e563..00000000
--- a/tests/chunkers/test_image_chunker.py
+++ /dev/null
@@ -1,78 +0,0 @@
-import unittest
-
-from embedchain.chunkers.images import ImagesChunker
-from embedchain.config import ChunkerConfig
-from embedchain.models.data_type import DataType
-
-
-class TestImageChunker(unittest.TestCase):
- def test_chunks(self):
- """
- Test the chunks generated by ImagesChunker.
- # TODO: Not a very precise test.
- """
- chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len, min_chunk_size=0)
- chunker = ImagesChunker(config=chunker_config)
- # Data type must be set manually in the test
- chunker.set_data_type(DataType.IMAGES)
-
- image_path = "./tmp/image.jpeg"
- app_id = "app1"
- result = chunker.create_chunks(MockLoader(), image_path, app_id=app_id)
-
- expected_chunks = {
- "doc_id": f"{app_id}--123",
- "documents": [image_path],
- "embeddings": ["embedding"],
- "ids": ["140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe"],
- "metadatas": [{"data_type": "images", "doc_id": f"{app_id}--123", "url": "none"}],
- }
- self.assertEqual(expected_chunks, result)
-
- def test_chunks_with_default_config(self):
- """
- Test the chunks generated by ImageChunker with default config.
- """
- chunker = ImagesChunker()
- # Data type must be set manually in the test
- chunker.set_data_type(DataType.IMAGES)
-
- image_path = "./tmp/image.jpeg"
- app_id = "app1"
- result = chunker.create_chunks(MockLoader(), image_path, app_id=app_id)
-
- expected_chunks = {
- "doc_id": f"{app_id}--123",
- "documents": [image_path],
- "embeddings": ["embedding"],
- "ids": ["140bedbf9c3f6d56a9846d2ba7088798683f4da0c248231336e6a05679e4fdfe"],
- "metadatas": [{"data_type": "images", "doc_id": f"{app_id}--123", "url": "none"}],
- }
- self.assertEqual(expected_chunks, result)
-
- def test_word_count(self):
- chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len, min_chunk_size=0)
- chunker = ImagesChunker(config=chunker_config)
- chunker.set_data_type(DataType.IMAGES)
-
- document = [["ab cd", "ef gh"], ["ij kl", "mn op"]]
- result = chunker.get_word_count(document)
- self.assertEqual(result, 1)
-
-
-class MockLoader:
- def load_data(self, src):
- """
- Mock loader that returns a list of data dictionaries.
- Adjust this method to return different data for testing.
- """
- return {
- "doc_id": "123",
- "data": [
- {
- "content": src,
- "embedding": "embedding",
- "meta_data": {"url": "none"},
- }
- ],
- }
diff --git a/tests/models/test_clip_processor.py b/tests/models/test_clip_processor.py
deleted file mode 100644
index 3dcd5b9e..00000000
--- a/tests/models/test_clip_processor.py
+++ /dev/null
@@ -1,44 +0,0 @@
-import os
-import tempfile
-import urllib
-
-from PIL import Image
-
-from embedchain.models.clip_processor import ClipProcessor
-
-
-class TestClipProcessor:
- def test_load_model(self):
- # Test that the `load_model()` method loads the CLIP model and image preprocessing correctly.
- model = ClipProcessor.load_model()
- assert model is not None
-
- def test_get_image_features(self):
- # Clone the image to a temporary folder.
- with tempfile.TemporaryDirectory() as tmp_dir:
- urllib.request.urlretrieve("https://upload.wikimedia.org/wikipedia/en/a/a9/Example.jpg", "image.jpg")
-
- image = Image.open("image.jpg")
- image.save(os.path.join(tmp_dir, "image.jpg"))
-
- # Get the image features.
- model = ClipProcessor.load_model()
- ClipProcessor.get_image_features(os.path.join(tmp_dir, "image.jpg"), model)
-
- # Delete the temporary file.
- os.remove(os.path.join(tmp_dir, "image.jpg"))
- os.remove("image.jpg")
-
- def test_get_text_features(self):
- # Test that the `get_text_features()` method returns a list containing the text embedding.
- query = "This is a text query."
- text_features = ClipProcessor.get_text_features(query)
-
- # Assert that the text embedding is not None.
- assert text_features is not None
-
- # Assert that the text embedding is a list of floats.
- assert isinstance(text_features, list)
-
- # Assert that the text embedding has the correct length.
- assert len(text_features) == 512
diff --git a/tests/vectordb/test_chroma_db.py b/tests/vectordb/test_chroma_db.py
index 0af4490c..31deb241 100644
--- a/tests/vectordb/test_chroma_db.py
+++ b/tests/vectordb/test_chroma_db.py
@@ -148,73 +148,6 @@ def test_chroma_db_collection_changes_encapsulated():
app.db.reset()
-def test_chroma_db_collection_add_with_skip_embedding(app_with_settings):
- # Start with a clean app
- app_with_settings.db.reset()
-
- assert app_with_settings.db.count() == 0
-
- app_with_settings.db.add(
- embeddings=[[0, 0, 0]],
- documents=["document"],
- metadatas=[{"url": "url_1", "doc_id": "doc_id_1"}],
- ids=["id"],
- skip_embedding=True,
- )
-
- assert app_with_settings.db.count() == 1
-
- data = app_with_settings.db.get(["id"], limit=1)
- expected_value = {
- "documents": ["document"],
- "embeddings": None,
- "ids": ["id"],
- "metadatas": [{"url": "url_1", "doc_id": "doc_id_1"}],
- "data": None,
- "uris": None,
- }
-
- assert data == expected_value
-
- data_without_citations = app_with_settings.db.query(
- input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True
- )
- expected_value_without_citations = ["document"]
- assert data_without_citations == expected_value_without_citations
-
- app_with_settings.db.reset()
-
-
-def test_chroma_db_collection_add_with_invalid_inputs(app_with_settings):
- # Start with a clean app
- app_with_settings.db.reset()
-
- assert app_with_settings.db.count() == 0
-
- with pytest.raises(ValueError):
- app_with_settings.db.add(
- embeddings=[[0, 0, 0]],
- documents=["document", "document2"],
- metadatas=[{"value": "somevalue"}],
- ids=["id"],
- skip_embedding=True,
- )
-
- assert app_with_settings.db.count() == 0
-
- with pytest.raises(ValueError):
- app_with_settings.db.add(
- embeddings=None,
- documents=["document", "document2"],
- metadatas=[{"value": "somevalue"}],
- ids=["id"],
- skip_embedding=True,
- )
-
- assert app_with_settings.db.count() == 0
- app_with_settings.db.reset()
-
-
def test_chroma_db_collection_collections_are_persistent():
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False), db=db)
@@ -312,60 +245,3 @@ def test_chroma_db_collection_reset():
app2.db.reset()
app3.db.reset()
app4.db.reset()
-
-
-def test_chroma_db_collection_query(app_with_settings):
- app_with_settings.db.reset()
-
- assert app_with_settings.db.count() == 0
-
- app_with_settings.db.add(
- embeddings=[[0, 0, 0]],
- documents=["document"],
- metadatas=[{"url": "url_1", "doc_id": "doc_id_1"}],
- ids=["id"],
- skip_embedding=True,
- )
-
- assert app_with_settings.db.count() == 1
-
- app_with_settings.db.add(
- embeddings=[[0, 1, 0]],
- documents=["document2"],
- metadatas=[{"url": "url_2", "doc_id": "doc_id_2"}],
- ids=["id2"],
- skip_embedding=True,
- )
-
- assert app_with_settings.db.count() == 2
-
- data_without_citations = app_with_settings.db.query(
- input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True
- )
- expected_value_without_citations = ["document", "document2"]
- assert data_without_citations == expected_value_without_citations
-
- data_with_citations = app_with_settings.db.query(
- input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True, citations=True
- )
- expected_value_with_citations = [
- (
- "document",
- {
- "url": "url_1",
- "doc_id": "doc_id_1",
- "score": 0.0,
- },
- ),
- (
- "document2",
- {
- "url": "url_2",
- "doc_id": "doc_id_2",
- "score": 1.0,
- },
- ),
- ]
- assert data_with_citations == expected_value_with_citations
-
- app_with_settings.db.reset()
diff --git a/tests/vectordb/test_elasticsearch_db.py b/tests/vectordb/test_elasticsearch_db.py
index 28e2ec8f..953f7813 100644
--- a/tests/vectordb/test_elasticsearch_db.py
+++ b/tests/vectordb/test_elasticsearch_db.py
@@ -35,7 +35,7 @@ class TestEsDB(unittest.TestCase):
ids = ["doc_1", "doc_2"]
# Add the data to the database.
- self.db.add(embeddings, documents, metadatas, ids, skip_embedding=False)
+ self.db.add(embeddings, documents, metadatas, ids)
search_response = {
"hits": {
@@ -60,63 +60,17 @@ class TestEsDB(unittest.TestCase):
# Query the database for the documents that are most similar to the query "This is a document".
query = ["This is a document"]
- results_without_citations = self.db.query(query, n_results=2, where={}, skip_embedding=False)
+ results_without_citations = self.db.query(query, n_results=2, where={})
expected_results_without_citations = ["This is a document.", "This is another document."]
self.assertEqual(results_without_citations, expected_results_without_citations)
- results_with_citations = self.db.query(query, n_results=2, where={}, skip_embedding=False, citations=True)
+ results_with_citations = self.db.query(query, n_results=2, where={}, citations=True)
expected_results_with_citations = [
("This is a document.", {"url": "url_1", "doc_id": "doc_id_1", "score": 0.9}),
("This is another document.", {"url": "url_2", "doc_id": "doc_id_2", "score": 0.8}),
]
self.assertEqual(results_with_citations, expected_results_with_citations)
- @patch("embedchain.vectordb.elasticsearch.Elasticsearch")
- def test_query_with_skip_embedding(self, mock_client):
- self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
- app_config = AppConfig(collect_metrics=False)
- self.app = App(config=app_config, db=self.db)
-
- # Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
- self.assertEqual(self.db.client, mock_client.return_value)
-
- # Create some dummy data.
- embeddings = [[1, 2, 3], [4, 5, 6]]
- documents = ["This is a document.", "This is another document."]
- metadatas = [{"url": "url_1", "doc_id": "doc_id_1"}, {"url": "url_2", "doc_id": "doc_id_2"}]
- ids = ["doc_1", "doc_2"]
-
- # Add the data to the database.
- self.db.add(embeddings, documents, metadatas, ids, skip_embedding=True)
-
- search_response = {
- "hits": {
- "hits": [
- {
- "_source": {"text": "This is a document.", "metadata": {"url": "url_1", "doc_id": "doc_id_1"}},
- "_score": 0.9,
- },
- {
- "_source": {
- "text": "This is another document.",
- "metadata": {"url": "url_2", "doc_id": "doc_id_2"},
- },
- "_score": 0.8,
- },
- ]
- }
- }
-
- # Configure the mock client to return the mocked response.
- mock_client.return_value.search.return_value = search_response
-
- # Query the database for the documents that are most similar to the query "This is a document".
- query = ["This is a document"]
- results = self.db.query(query, n_results=2, where={}, skip_embedding=True)
-
- # Assert that the results are correct.
- self.assertEqual(results, ["This is a document.", "This is another document."])
-
def test_init_without_url(self):
# Make sure it's not loaded from env
try:
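The `citations=True` path asserted above follows one contract across backends: instead of a bare list of documents, `query` returns `(document, metadata)` tuples whose metadata carries `url`, `doc_id`, and `score`. A minimal sketch of consuming that shape, assuming a `db` already configured and populated as in these tests:

```python
# Sketch only: assumes `db` is a configured vector DB instance populated as
# in the tests above; the metadata keys mirror the shapes asserted there.
results = db.query(["This is a document"], n_results=2, where={}, citations=True)
for document, metadata in results:
    # Each result pairs the matched text with its source metadata.
    print(f"{metadata['score']:.2f} {metadata['url']} ({metadata['doc_id']}): {document}")
```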
diff --git a/tests/vectordb/test_pinecone.py b/tests/vectordb/test_pinecone.py
index 7d53840b..08a18a65 100644
--- a/tests/vectordb/test_pinecone.py
+++ b/tests/vectordb/test_pinecone.py
@@ -54,7 +54,7 @@ class TestPinecone:
documents = ["This is a document.", "This is another document."]
metadatas = [{}, {}]
ids = ["doc1", "doc2"]
- db.add(vectors, documents, metadatas, ids, True)
+ db.add(vectors, documents, metadatas, ids)
expected_pinecone_upsert_args = [
{"id": "doc1", "values": [0, 0, 0], "metadata": {"text": "This is a document."}},
@@ -81,7 +81,7 @@ class TestPinecone:
# Query the database for documents that are similar to "document"
input_query = ["document"]
n_results = 1
- db.query(input_query, n_results, where={}, skip_embedding=False)
+ db.query(input_query, n_results, where={})
# Assert that the Pinecone client was called to query the database
pinecone_client_mock.query.assert_called_once_with(
diff --git a/tests/vectordb/test_qdrant.py b/tests/vectordb/test_qdrant.py
index 0754cde2..c38e5786 100644
--- a/tests/vectordb/test_qdrant.py
+++ b/tests/vectordb/test_qdrant.py
@@ -12,6 +12,11 @@ from embedchain.embedder.base import BaseEmbedder
from embedchain.vectordb.qdrant import QdrantDB
+def mock_embedding_fn(texts: list[str]) -> list[list[float]]:
+ """A mock embedding function."""
+ return [[1, 2, 3], [4, 5, 6]]
+
+
class TestQdrantDB(unittest.TestCase):
TEST_UUIDS = ["abc", "def", "ghi"]
@@ -25,6 +30,7 @@ class TestQdrantDB(unittest.TestCase):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
+ embedder.set_embedding_fn(mock_embedding_fn)
# Create a Qdrant instance
db = QdrantDB()
@@ -42,6 +48,7 @@ class TestQdrantDB(unittest.TestCase):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
+ embedder.set_embedding_fn(mock_embedding_fn)
# Create a Qdrant instance
db = QdrantDB()
@@ -61,6 +68,7 @@ class TestQdrantDB(unittest.TestCase):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
+ embedder.set_embedding_fn(mock_embedding_fn)
# Create a Qdrant instance
db = QdrantDB()
@@ -71,8 +79,7 @@ class TestQdrantDB(unittest.TestCase):
documents = ["This is a test document.", "This is another test document."]
metadatas = [{}, {}]
ids = ["123", "456"]
- skip_embedding = True
- db.add(embeddings, documents, metadatas, ids, skip_embedding)
+ db.add(embeddings, documents, metadatas, ids)
qdrant_client_mock.return_value.upsert.assert_called_once_with(
collection_name="embedchain-store-1526",
points=Batch(
@@ -98,6 +105,7 @@ class TestQdrantDB(unittest.TestCase):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
+ embedder.set_embedding_fn(mock_embedding_fn)
# Create a Qdrant instance
db = QdrantDB()
@@ -105,7 +113,7 @@ class TestQdrantDB(unittest.TestCase):
App(config=app_config, db=db, embedding_model=embedder)
# Query for the document.
- db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"}, skip_embedding=True)
+ db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"})
qdrant_client_mock.return_value.search.assert_called_once_with(
collection_name="embedchain-store-1526",
@@ -119,7 +127,7 @@ class TestQdrantDB(unittest.TestCase):
)
]
),
- query_vector=["This is a test document."],
+ query_vector=[1, 2, 3],
limit=1,
)
@@ -128,6 +136,7 @@ class TestQdrantDB(unittest.TestCase):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
+ embedder.set_embedding_fn(mock_embedding_fn)
# Create a Qdrant instance
db = QdrantDB()
@@ -142,6 +151,7 @@ class TestQdrantDB(unittest.TestCase):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
+ embedder.set_embedding_fn(mock_embedding_fn)
# Create a Qdrant instance
db = QdrantDB()
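With `skip_embedding` gone, every backend asks its embedder for vectors, so the updated Qdrant and Weaviate tests register a deterministic embedding function on `BaseEmbedder` before wiring up the database. A self-contained sketch of that stub pattern, using only the setters shown in these diffs:

```python
# Stub-embedder pattern from the updated tests: a fixed embedding function
# makes the vectors the DB computes deterministic, so mock-call assertions
# (e.g. query_vector=[1, 2, 3]) can match exactly.
from embedchain.embedder.base import BaseEmbedder


def mock_embedding_fn(texts: list[str]) -> list[list[float]]:
    """Return fixed vectors regardless of input, for deterministic tests."""
    return [[1, 2, 3], [4, 5, 6]]


embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
embedder.set_embedding_fn(mock_embedding_fn)
# Any DB constructed against this embedder now embeds queries itself, so a
# single-text query reaches the backend as the vector [1, 2, 3].
```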
diff --git a/tests/vectordb/test_weaviate.py b/tests/vectordb/test_weaviate.py
index e4535048..ba4045a7 100644
--- a/tests/vectordb/test_weaviate.py
+++ b/tests/vectordb/test_weaviate.py
@@ -8,6 +8,11 @@ from embedchain.embedder.base import BaseEmbedder
from embedchain.vectordb.weaviate import WeaviateDB
+def mock_embedding_fn(texts: list[str]) -> list[list[float]]:
+ """A mock embedding function."""
+ return [[1, 2, 3], [4, 5, 6]]
+
+
class TestWeaviateDb(unittest.TestCase):
def test_incorrect_config_throws_error(self):
"""Test the init method of the WeaviateDb class throws error for incorrect config"""
@@ -25,6 +30,7 @@ class TestWeaviateDb(unittest.TestCase):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
+ embedder.set_embedding_fn(mock_embedding_fn)
# Create a Weaviate instance
db = WeaviateDB()
@@ -92,6 +98,7 @@ class TestWeaviateDb(unittest.TestCase):
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
+ embedder.set_embedding_fn(mock_embedding_fn)
# Create a Weaviate instance
db = WeaviateDB()
@@ -111,6 +118,7 @@ class TestWeaviateDb(unittest.TestCase):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
+ embedder.set_embedding_fn(mock_embedding_fn)
# Create a Weaviate instance
db = WeaviateDB()
@@ -122,8 +130,7 @@ class TestWeaviateDb(unittest.TestCase):
documents = ["This is a test document.", "This is another test document."]
metadatas = [None, None]
ids = ["123", "456"]
- skip_embedding = True
- db.add(embeddings, documents, metadatas, ids, skip_embedding)
+ db.add(embeddings, documents, metadatas, ids)
# Check if the document was added to the database.
weaviate_client_batch_mock.configure.assert_called_once_with(batch_size=1, timeout_retries=3)
@@ -155,6 +162,7 @@ class TestWeaviateDb(unittest.TestCase):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
+ embedder.set_embedding_fn(mock_embedding_fn)
# Create a Weaviate instance
db = WeaviateDB()
@@ -162,12 +170,10 @@ class TestWeaviateDb(unittest.TestCase):
App(config=app_config, db=db, embedding_model=embedder)
# Query for the document.
- db.query(input_query=["This is a test document."], n_results=1, where={}, skip_embedding=True)
+ db.query(input_query=["This is a test document."], n_results=1, where={})
weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1526", ["text"])
- weaviate_client_query_get_mock.with_near_vector.assert_called_once_with(
- {"vector": ["This is a test document."]}
- )
+ weaviate_client_query_get_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})
@patch("embedchain.vectordb.weaviate.weaviate")
def test_query_with_where(self, weaviate_mock):
@@ -180,6 +186,7 @@ class TestWeaviateDb(unittest.TestCase):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
+ embedder.set_embedding_fn(mock_embedding_fn)
# Create a Weaviate instance
db = WeaviateDB()
@@ -187,15 +194,13 @@ class TestWeaviateDb(unittest.TestCase):
App(config=app_config, db=db, embedding_model=embedder)
# Query for the document.
- db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"}, skip_embedding=True)
+ db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"})
weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1526", ["text"])
weaviate_client_query_get_mock.with_where.assert_called_once_with(
{"operator": "Equal", "path": ["metadata", "Embedchain_store_1526_metadata", "doc_id"], "valueText": "123"}
)
- weaviate_client_query_get_where_mock.with_near_vector.assert_called_once_with(
- {"vector": ["This is a test document."]}
- )
+ weaviate_client_query_get_where_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})
@patch("embedchain.vectordb.weaviate.weaviate")
def test_reset(self, weaviate_mock):
@@ -206,6 +211,7 @@ class TestWeaviateDb(unittest.TestCase):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
+ embedder.set_embedding_fn(mock_embedding_fn)
# Create a Weaviate instance
db = WeaviateDB()
@@ -228,6 +234,7 @@ class TestWeaviateDb(unittest.TestCase):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
+ embedder.set_embedding_fn(mock_embedding_fn)
# Create a Weaviate instance
db = WeaviateDB()
diff --git a/tests/vectordb/test_zilliz_db.py b/tests/vectordb/test_zilliz_db.py
index d4d9fdd4..d4ec4675 100644
--- a/tests/vectordb/test_zilliz_db.py
+++ b/tests/vectordb/test_zilliz_db.py
@@ -108,65 +108,7 @@ class TestZillizDBCollection:
@patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
@patch("embedchain.vectordb.zilliz.connections", autospec=True)
- def test_query_with_skip_embedding(self, mock_connect, mock_client, mock_config):
- """
- Test that the `ZillizVectorDB` instance handles a query with skip_embedding.
- """
- # Create an instance of ZillizVectorDB with mock config
- zilliz_db = ZillizVectorDB(config=mock_config)
-
- # Add a 'collection' attribute to the ZillizVectorDB instance for testing
- zilliz_db.collection = Mock(is_empty=False) # Mock the 'collection' object
-
- assert zilliz_db.client == mock_client()
-
- # Mock the MilvusClient search method
- with patch.object(zilliz_db.client, "search") as mock_search:
- # Mock the search result
- mock_search.return_value = [
- [
- {
- "distance": 0.5,
- "entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1", "embeddings": [1, 2, 3]},
- }
- ]
- ]
-
- # Call the query method with skip_embedding=True
- query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=True)
-
- # Assert that MilvusClient.search was called with the correct parameters
- mock_search.assert_called_with(
- collection_name=mock_config.collection_name,
- data=["query_text"],
- limit=1,
- output_fields=["*"],
- )
-
- # Assert that the query result matches the expected result
- assert query_result == ["result_doc"]
-
- query_result_with_citations = zilliz_db.query(
- input_query=["query_text"], n_results=1, where={}, skip_embedding=True, citations=True
- )
-
- mock_search.assert_called_with(
- collection_name=mock_config.collection_name,
- data=["query_text"],
- limit=1,
- output_fields=["*"],
- )
-
- assert query_result_with_citations == [
- ("result_doc", {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1", "score": 0.5})
- ]
-
- @patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
- @patch("embedchain.vectordb.zilliz.connections", autospec=True)
- def test_query_without_skip_embedding(self, mock_connect, mock_client, mock_embedder, mock_config):
- """
- Test that the `ZillizVectorDB` instance handles a query without skip_embedding.
- """
+ def test_query(self, mock_connect, mock_client, mock_embedder, mock_config):
# Create an instance of ZillizVectorDB with mock config
zilliz_db = ZillizVectorDB(config=mock_config)
@@ -193,8 +135,7 @@ class TestZillizDBCollection:
]
]
- # Call the query method with skip_embedding=False
- query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=False)
+ query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={})
# Assert that MilvusClient.search was called with the correct parameters
mock_search.assert_called_with(
@@ -208,7 +149,7 @@ class TestZillizDBCollection:
assert query_result == ["result_doc"]
query_result_with_citations = zilliz_db.query(
- input_query=["query_text"], n_results=1, where={}, skip_embedding=False, citations=True
+ input_query=["query_text"], n_results=1, where={}, citations=True
)
mock_search.assert_called_with(