Fix batch_size for vectordb (#1449)
Makefile (2 changes)
@@ -11,7 +11,7 @@ install:
 
 install_all:
 	poetry install --all-extras
-	poetry run pip install pinecone-text pinecone-client langchain-anthropic "unstructured[local-inference, all-docs]" ollama deepgram-sdk==3.2.7 langchain-huggingface
+	poetry run pip install pinecone-text pinecone-client langchain-anthropic "unstructured[local-inference, all-docs]" ollama deepgram-sdk==3.2.7 langchain-huggingface psutil
 
 install_es:
 	poetry install --extras elasticsearch
@@ -10,7 +10,6 @@ class BaseVectorDbConfig(BaseConfig):
         dir: str = "db",
         host: Optional[str] = None,
         port: Optional[str] = None,
-        batch_size: Optional[int] = 100,
         **kwargs,
     ):
         """
@@ -24,8 +23,6 @@ class BaseVectorDbConfig(BaseConfig):
         :type host: Optional[str], optional
         :param host: Database connection remote port. Use this if you run Embedchain as a client, defaults to None
         :type port: Optional[str], optional
-        :param batch_size: Number of items to insert in one batch, defaults to 100
-        :type batch_size: Optional[int], optional
         :param kwargs: Additional keyword arguments
         :type kwargs: dict
         """
@@ -33,7 +30,6 @@ class BaseVectorDbConfig(BaseConfig):
         self.dir = dir
         self.host = host
         self.port = port
-        self.batch_size = batch_size
         # Assign additional keyword arguments
         if kwargs:
             for key, value in kwargs.items():
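The net effect of the base-config hunks above, together with the per-store hunks below, is to move ownership of batch_size from the shared base class into each store-specific config, so each store can pick its own default. A minimal sketch of that pattern (class bodies and parameter lists heavily abbreviated from the diff; not the embedchain source):

from typing import Optional

class BaseVectorDbConfig:
    def __init__(self, collection_name: Optional[str] = None, **kwargs):
        # batch_size no longer lives on the shared base config
        self.collection_name = collection_name

class ChromaDbConfig(BaseVectorDbConfig):
    def __init__(self, batch_size: Optional[int] = 100, **kwargs):
        self.batch_size = batch_size  # most stores keep the old default of 100
        super().__init__(**kwargs)

class QdrantDBConfig(BaseVectorDbConfig):
    def __init__(self, batch_size: Optional[int] = 10, **kwargs):
        self.batch_size = batch_size  # Qdrant gets a smaller default of 10
        super().__init__(**kwargs)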
@@ -12,6 +12,7 @@ class ChromaDbConfig(BaseVectorDbConfig):
         dir: Optional[str] = None,
         host: Optional[str] = None,
         port: Optional[str] = None,
+        batch_size: Optional[int] = 100,
         allow_reset=False,
         chroma_settings: Optional[dict] = None,
     ):
@@ -26,6 +27,8 @@ class ChromaDbConfig(BaseVectorDbConfig):
         :type host: Optional[str], optional
         :param port: Database connection remote port. Use this if you run Embedchain as a client, defaults to None
         :type port: Optional[str], optional
+        :param batch_size: Number of items to insert in one batch, defaults to 100
+        :type batch_size: Optional[int], optional
         :param allow_reset: Resets the database. defaults to False
         :type allow_reset: bool
         :param chroma_settings: Chroma settings dict, defaults to None
@@ -34,4 +37,5 @@ class ChromaDbConfig(BaseVectorDbConfig):
 
         self.chroma_settings = chroma_settings
         self.allow_reset = allow_reset
+        self.batch_size = batch_size
         super().__init__(collection_name=collection_name, dir=dir, host=host, port=port)
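With the signature above, a caller can pick a custom batch size per store at construction time. A hypothetical usage sketch, assuming ChromaDbConfig is re-exported from embedchain.config (import path assumed, not confirmed by this diff):

from embedchain.config import ChromaDbConfig  # assumed import path

config = ChromaDbConfig(collection_name="my-docs", batch_size=50)
assert config.batch_size == 50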
@@ -13,6 +13,7 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
         dir: Optional[str] = None,
         es_url: Union[str, list[str]] = None,
         cloud_id: Optional[str] = None,
+        batch_size: Optional[int] = 100,
         **ES_EXTRA_PARAMS: dict[str, any],
     ):
         """
@@ -24,6 +25,10 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
         :type dir: Optional[str], optional
         :param es_url: elasticsearch url or list of nodes url to be used for connection, defaults to None
         :type es_url: Union[str, list[str]], optional
+        :param cloud_id: cloud id of the elasticsearch cluster, defaults to None
+        :type cloud_id: Optional[str], optional
+        :param batch_size: Number of items to insert in one batch, defaults to 100
+        :type batch_size: Optional[int], optional
         :param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch.
         :type ES_EXTRA_PARAMS: dict[str, Any], optional
         """
@@ -46,4 +51,6 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
             and not self.ES_EXTRA_PARAMS.get("bearer_auth")
         ):
             self.ES_EXTRA_PARAMS["api_key"] = os.environ.get("ELASTICSEARCH_API_KEY")
 
+        self.batch_size = batch_size
         super().__init__(collection_name=collection_name, dir=dir)
@@ -13,6 +13,7 @@ class OpenSearchDBConfig(BaseVectorDbConfig):
         vector_dimension: int = 1536,
         collection_name: Optional[str] = None,
         dir: Optional[str] = None,
+        batch_size: Optional[int] = 100,
         **extra_params: dict[str, any],
     ):
         """
@@ -28,10 +29,13 @@ class OpenSearchDBConfig(BaseVectorDbConfig):
         :type vector_dimension: int, optional
         :param dir: Path to the database directory, where the database is stored, defaults to None
         :type dir: Optional[str], optional
+        :param batch_size: Number of items to insert in one batch, defaults to 100
+        :type batch_size: Optional[int], optional
         """
         self.opensearch_url = opensearch_url
         self.http_auth = http_auth
         self.vector_dimension = vector_dimension
         self.extra_params = extra_params
+        self.batch_size = batch_size
 
         super().__init__(collection_name=collection_name, dir=dir)
@@ -17,6 +17,7 @@ class PineconeDBConfig(BaseVectorDbConfig):
         serverless_config: Optional[dict[str, any]] = None,
         hybrid_search: bool = False,
         bm25_encoder: any = None,
+        batch_size: Optional[int] = 100,
         **extra_params: dict[str, any],
     ):
         self.metric = metric
@@ -26,6 +27,7 @@ class PineconeDBConfig(BaseVectorDbConfig):
         self.extra_params = extra_params
         self.hybrid_search = hybrid_search
         self.bm25_encoder = bm25_encoder
+        self.batch_size = batch_size
         if pod_config is None and serverless_config is None:
             # If no config is provided, use the default pod spec config
             pod_environment = os.environ.get("PINECONE_ENV", "gcp-starter")
@@ -18,6 +18,7 @@ class QdrantDBConfig(BaseVectorDbConfig):
         hnsw_config: Optional[dict[str, any]] = None,
         quantization_config: Optional[dict[str, any]] = None,
         on_disk: Optional[bool] = None,
+        batch_size: Optional[int] = 10,
         **extra_params: dict[str, any],
     ):
         """
@@ -36,9 +37,12 @@ class QdrantDBConfig(BaseVectorDbConfig):
                         This setting saves RAM by (slightly) increasing the response time.
                         Note: those payload values that are involved in filtering and are indexed - remain in RAM.
         :type on_disk: bool, optional, defaults to None
+        :param batch_size: Number of items to insert in one batch, defaults to 10
+        :type batch_size: Optional[int], optional
         """
         self.hnsw_config = hnsw_config
         self.quantization_config = quantization_config
         self.on_disk = on_disk
+        self.batch_size = batch_size
         self.extra_params = extra_params
         super().__init__(collection_name=collection_name, dir=dir)
@@ -10,7 +10,9 @@ class WeaviateDBConfig(BaseVectorDbConfig):
         self,
         collection_name: Optional[str] = None,
         dir: Optional[str] = None,
+        batch_size: Optional[int] = 100,
         **extra_params: dict[str, any],
     ):
+        self.batch_size = batch_size
         self.extra_params = extra_params
         super().__init__(collection_name=collection_name, dir=dir)
@@ -42,6 +42,7 @@ class ChromaDB(BaseVectorDB):
 
         self.settings = Settings(anonymized_telemetry=False)
         self.settings.allow_reset = self.config.allow_reset if hasattr(self.config, "allow_reset") else False
+        self.batch_size = self.config.batch_size
         if self.config.chroma_settings:
             for key, value in self.config.chroma_settings.items():
                 if hasattr(self.settings, key):
@@ -153,12 +154,13 @@ class ChromaDB(BaseVectorDB):
                 " Ids size: {}".format(len(documents), len(metadatas), len(ids))
             )
 
-        for i in tqdm(range(0, len(documents), self.config.batch_size), desc="Inserting batches in chromadb"):
+        for i in tqdm(range(0, len(documents), self.batch_size), desc="Inserting batches in chromadb"):
             self.collection.add(
-                documents=documents[i : i + self.config.batch_size],
-                metadatas=metadatas[i : i + self.config.batch_size],
-                ids=ids[i : i + self.config.batch_size],
+                documents=documents[i : i + self.batch_size],
+                metadatas=metadatas[i : i + self.batch_size],
+                ids=ids[i : i + self.batch_size],
             )
+        self.config
 
     @staticmethod
     def _format_result(results: QueryResult) -> list[tuple[Document, float]]:
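The ChromaDB hunk pairs a cached self.batch_size with a slice-by-step insert loop. A generic, store-agnostic sketch of that loop (stub collection object standing in for the Chroma collection; not the embedchain class):

from tqdm import tqdm

def add_in_batches(collection, documents, metadatas, ids, batch_size=100):
    # Step through the inputs batch_size items at a time; the final slice
    # may be shorter, which Python slicing handles without special-casing.
    for i in tqdm(range(0, len(documents), batch_size), desc="Inserting batches"):
        collection.add(
            documents=documents[i : i + batch_size],
            metadatas=metadatas[i : i + batch_size],
            ids=ids[i : i + batch_size],
        )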
@@ -55,6 +55,7 @@ class ElasticsearchDB(BaseVectorDB):
                 "Something is wrong with your config. Please check again - `https://docs.embedchain.ai/components/vector-databases#elasticsearch`"  # noqa: E501
             )
 
+        self.batch_size = self.config.batch_size
         # Call parent init here because embedder is needed
         super().__init__(config=self.config)
 
@@ -139,7 +140,7 @@ class ElasticsearchDB(BaseVectorDB):
 
         for chunk in chunks(
             list(zip(ids, documents, metadatas, embeddings)),
-            self.config.batch_size,
+            self.batch_size,
             desc="Inserting batches in elasticsearch",
         ):  # noqa: E501
             ids, docs, metadatas, embeddings = [], [], [], []
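The Elasticsearch hunk feeds a zip of parallel lists through a chunks() helper that this diff does not show. A plausible minimal implementation, assuming it simply yields fixed-size slices with a progress bar (the real embedchain helper may differ):

from tqdm import tqdm

def chunks(items, batch_size, desc=""):
    # Yield successive batch_size-sized slices of items (assumed behavior).
    for i in tqdm(range(0, len(items), batch_size), desc=desc):
        yield items[i : i + batch_size]

rows = list(zip(["id1", "id2", "id3"], ["a", "b", "c"]))
for chunk in chunks(rows, 2, desc="demo"):
    print(chunk)  # [("id1", "a"), ("id2", "b")], then [("id3", "c")]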
@@ -37,6 +37,7 @@ class OpenSearchDB(BaseVectorDB):
         if config is None:
             raise ValueError("OpenSearchDBConfig is required")
         self.config = config
+        self.batch_size = self.config.batch_size
         self.client = OpenSearch(
             hosts=[self.config.opensearch_url],
             http_auth=self.config.http_auth,
@@ -118,10 +119,8 @@ class OpenSearchDB(BaseVectorDB):
         """Adds documents to the opensearch index"""
 
         embeddings = self.embedder.embedding_fn(documents)
-        for batch_start in tqdm(
-            range(0, len(documents), self.config.batch_size), desc="Inserting batches in opensearch"
-        ):
-            batch_end = batch_start + self.config.batch_size
+        for batch_start in tqdm(range(0, len(documents), self.batch_size), desc="Inserting batches in opensearch"):
+            batch_end = batch_start + self.batch_size
             batch_documents = documents[batch_start:batch_end]
             batch_embeddings = embeddings[batch_start:batch_end]
 
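The OpenSearch hunk uses explicit batch_start/batch_end indices rather than open-ended slices. A self-contained sketch of that windowing, with plain lists standing in for documents and embeddings:

documents = [f"doc-{n}" for n in range(7)]
embeddings = [[float(n)] for n in range(7)]
batch_size = 3

for batch_start in range(0, len(documents), batch_size):
    batch_end = batch_start + batch_size  # slicing past the end is safe in Python
    batch_documents = documents[batch_start:batch_end]
    batch_embeddings = embeddings[batch_start:batch_end]
    print(len(batch_documents), len(batch_embeddings))  # 3, 3, then 1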
@@ -48,6 +48,7 @@ class PineconeDB(BaseVectorDB):
 
         # Setup BM25Encoder if sparse vectors are to be used
         self.bm25_encoder = None
+        self.batch_size = self.config.batch_size
         if self.config.hybrid_search:
             logger.info("Initializing BM25Encoder for sparse vectors..")
             self.bm25_encoder = self.config.bm25_encoder if self.config.bm25_encoder else BM25Encoder.default()
@@ -102,8 +103,8 @@ class PineconeDB(BaseVectorDB):
         metadatas = []
 
         if ids is not None:
-            for i in range(0, len(ids), self.config.batch_size):
-                result = self.pinecone_index.fetch(ids=ids[i : i + self.config.batch_size])
+            for i in range(0, len(ids), self.batch_size):
+                result = self.pinecone_index.fetch(ids=ids[i : i + self.batch_size])
                 vectors = result.get("vectors")
                 batch_existing_ids = list(vectors.keys())
                 existing_ids.extend(batch_existing_ids)
@@ -142,7 +143,7 @@ class PineconeDB(BaseVectorDB):
             },
         )
 
-        for chunk in chunks(docs, self.config.batch_size, desc="Adding chunks in batches"):
+        for chunk in chunks(docs, self.batch_size, desc="Adding chunks in batches"):
            self.pinecone_index.upsert(chunk, **kwargs)
 
     def query(
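The first Pinecone hunk batches the existence check as well as the upsert: ids are fetched batch_size at a time and whatever the index actually holds is collected. A generic sketch of that pattern, with a dict standing in for pinecone_index.fetch:

stored = {f"id-{n}": [float(n)] for n in range(250)}  # fake index contents

def fetch(ids):
    # Mimics a fetch API that returns only the ids it actually holds.
    return {"vectors": {i: stored[i] for i in ids if i in stored}}

ids = [f"id-{n}" for n in range(300)]  # some of these do not exist
batch_size, existing_ids = 100, []
for i in range(0, len(ids), batch_size):
    result = fetch(ids[i : i + batch_size])
    existing_ids.extend(result.get("vectors").keys())
print(len(existing_ids))  # 250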
@@ -35,6 +35,7 @@ class QdrantDB(BaseVectorDB):
                 "Please make sure the type is right and that you are passing an instance."
             )
         self.config = config
+        self.batch_size = self.config.batch_size
         self.client = QdrantClient(url=os.getenv("QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY"))
         # Call parent init here because embedder is needed
         super().__init__(config=self.config)
@@ -114,7 +115,7 @@ class QdrantDB(BaseVectorDB):
                 collection_name=self.collection_name,
                 scroll_filter=models.Filter(must=qdrant_must_filters),
                 offset=offset,
-                limit=self.config.batch_size,
+                limit=self.batch_size,
             )
             offset = response[1]
             for doc in response[0]:
@@ -146,13 +147,13 @@ class QdrantDB(BaseVectorDB):
             qdrant_ids.append(id)
             payloads.append({"identifier": id, "text": document, "metadata": copy.deepcopy(metadata)})
 
-        for i in tqdm(range(0, len(qdrant_ids), self.config.batch_size), desc="Adding data in batches"):
+        for i in tqdm(range(0, len(qdrant_ids), self.batch_size), desc="Adding data in batches"):
             self.client.upsert(
                 collection_name=self.collection_name,
                 points=Batch(
-                    ids=qdrant_ids[i : i + self.config.batch_size],
-                    payloads=payloads[i : i + self.config.batch_size],
-                    vectors=embeddings[i : i + self.config.batch_size],
+                    ids=qdrant_ids[i : i + self.batch_size],
+                    payloads=payloads[i : i + self.batch_size],
+                    vectors=embeddings[i : i + self.batch_size],
                 ),
                 **kwargs,
             )
@@ -38,6 +38,7 @@ class WeaviateDB(BaseVectorDB):
                 "Please make sure the type is right and that you are passing an instance."
             )
         self.config = config
+        self.batch_size = self.config.batch_size
         self.client = weaviate.Client(
             url=os.environ.get("WEAVIATE_ENDPOINT"),
             auth_client_secret=weaviate.AuthApiKey(api_key=os.environ.get("WEAVIATE_API_KEY")),
@@ -167,7 +168,7 @@ class WeaviateDB(BaseVectorDB):
                 )
                 .with_where(weaviate_where_clause)
                 .with_additional(["id"])
-                .with_limit(limit or self.config.batch_size),
+                .with_limit(limit or self.batch_size),
                 offset,
             )
 
@@ -196,7 +197,7 @@ class WeaviateDB(BaseVectorDB):
         :type ids: list[str]
         """
         embeddings = self.embedder.embedding_fn(documents)
-        self.client.batch.configure(batch_size=self.config.batch_size, timeout_retries=3)  # Configure batch
+        self.client.batch.configure(batch_size=self.batch_size, timeout_retries=3)  # Configure batch
         with self.client.batch as batch:  # Initialize a batch process
             for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
                 doc = {"identifier": id, "text": text}
@@ -124,7 +124,6 @@ class TestWeaviateDb(unittest.TestCase):
         db = WeaviateDB()
         app_config = AppConfig(collect_metrics=False)
         App(config=app_config, db=db, embedding_model=embedder)
-        db.config.batch_size = 1
 
         documents = ["This is test document"]
         metadatas = [None]
@@ -132,7 +131,7 @@ class TestWeaviateDb(unittest.TestCase):
         db.add(documents, metadatas, ids)
 
         # Check if the document was added to the database.
-        weaviate_client_batch_mock.configure.assert_called_once_with(batch_size=1, timeout_retries=3)
+        weaviate_client_batch_mock.configure.assert_called_once_with(batch_size=100, timeout_retries=3)
         weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
             data_object={"text": documents[0]}, class_name="Embedchain_store_1536_metadata", vector=[1, 2, 3]
        )
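Since batch_size now defaults to 100 on WeaviateDBConfig rather than being patched onto the config after construction, the test asserts the default directly. A self-contained sketch of that assertion style, using unittest.mock alone rather than the embedchain test harness:

from unittest.mock import MagicMock

batch = MagicMock()
batch.configure(batch_size=100, timeout_retries=3)
batch.configure.assert_called_once_with(batch_size=100, timeout_retries=3)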