Fix batch_size for vectordb (#1449)

This commit is contained in:
Dev Khant
2024-06-28 23:48:22 +05:30
committed by GitHub
parent 0a78198bb5
commit 50c0285cb2
15 changed files with 49 additions and 26 deletions

View File

@@ -10,7 +10,6 @@ class BaseVectorDbConfig(BaseConfig):
dir: str = "db",
host: Optional[str] = None,
port: Optional[str] = None,
batch_size: Optional[int] = 100,
**kwargs,
):
"""
@@ -24,8 +23,6 @@ class BaseVectorDbConfig(BaseConfig):
:type host: Optional[str], optional
:param port: Database connection remote port. Use this if you run Embedchain as a client, defaults to None
:type port: Optional[str], optional
:param batch_size: Number of items to insert in one batch, defaults to 100
:type batch_size: Optional[int], optional
:param kwargs: Additional keyword arguments
:type kwargs: dict
"""
@@ -33,7 +30,6 @@ class BaseVectorDbConfig(BaseConfig):
self.dir = dir
self.host = host
self.port = port
self.batch_size = batch_size
# Assign additional keyword arguments
if kwargs:
for key, value in kwargs.items():

View File

@@ -12,6 +12,7 @@ class ChromaDbConfig(BaseVectorDbConfig):
dir: Optional[str] = None,
host: Optional[str] = None,
port: Optional[str] = None,
batch_size: Optional[int] = 100,
allow_reset=False,
chroma_settings: Optional[dict] = None,
):
@@ -26,6 +27,8 @@ class ChromaDbConfig(BaseVectorDbConfig):
:type host: Optional[str], optional
:param port: Database connection remote port. Use this if you run Embedchain as a client, defaults to None
:type port: Optional[str], optional
:param batch_size: Number of items to insert in one batch, defaults to 100
:type batch_size: Optional[int], optional
:param allow_reset: Resets the database, defaults to False
:type allow_reset: bool
:param chroma_settings: Chroma settings dict, defaults to None
@@ -34,4 +37,5 @@ class ChromaDbConfig(BaseVectorDbConfig):
self.chroma_settings = chroma_settings
self.allow_reset = allow_reset
self.batch_size = batch_size
super().__init__(collection_name=collection_name, dir=dir, host=host, port=port)

View File

@@ -13,6 +13,7 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
dir: Optional[str] = None,
es_url: Union[str, list[str]] = None,
cloud_id: Optional[str] = None,
batch_size: Optional[int] = 100,
**ES_EXTRA_PARAMS: dict[str, any],
):
"""
@@ -24,6 +25,10 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
:type dir: Optional[str], optional
:param es_url: elasticsearch url or list of nodes url to be used for connection, defaults to None
:type es_url: Union[str, list[str]], optional
:param cloud_id: cloud id of the elasticsearch cluster, defaults to None
:type cloud_id: Optional[str], optional
:param batch_size: Number of items to insert in one batch, defaults to 100
:type batch_size: Optional[int], optional
:param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch.
:type ES_EXTRA_PARAMS: dict[str, Any], optional
"""
@@ -46,4 +51,6 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
and not self.ES_EXTRA_PARAMS.get("bearer_auth")
):
self.ES_EXTRA_PARAMS["api_key"] = os.environ.get("ELASTICSEARCH_API_KEY")
self.batch_size = batch_size
super().__init__(collection_name=collection_name, dir=dir)

View File

@@ -13,6 +13,7 @@ class OpenSearchDBConfig(BaseVectorDbConfig):
vector_dimension: int = 1536,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
batch_size: Optional[int] = 100,
**extra_params: dict[str, any],
):
"""
@@ -28,10 +29,13 @@ class OpenSearchDBConfig(BaseVectorDbConfig):
:type vector_dimension: int, optional
:param dir: Path to the database directory, where the database is stored, defaults to None
:type dir: Optional[str], optional
:param batch_size: Number of items to insert in one batch, defaults to 100
:type batch_size: Optional[int], optional
"""
self.opensearch_url = opensearch_url
self.http_auth = http_auth
self.vector_dimension = vector_dimension
self.extra_params = extra_params
self.batch_size = batch_size
super().__init__(collection_name=collection_name, dir=dir)

View File

@@ -17,6 +17,7 @@ class PineconeDBConfig(BaseVectorDbConfig):
serverless_config: Optional[dict[str, any]] = None,
hybrid_search: bool = False,
bm25_encoder: any = None,
batch_size: Optional[int] = 100,
**extra_params: dict[str, any],
):
self.metric = metric
@@ -26,6 +27,7 @@ class PineconeDBConfig(BaseVectorDbConfig):
self.extra_params = extra_params
self.hybrid_search = hybrid_search
self.bm25_encoder = bm25_encoder
self.batch_size = batch_size
if pod_config is None and serverless_config is None:
# If no config is provided, use the default pod spec config
pod_environment = os.environ.get("PINECONE_ENV", "gcp-starter")

View File

@@ -18,6 +18,7 @@ class QdrantDBConfig(BaseVectorDbConfig):
hnsw_config: Optional[dict[str, any]] = None,
quantization_config: Optional[dict[str, any]] = None,
on_disk: Optional[bool] = None,
batch_size: Optional[int] = 10,
**extra_params: dict[str, any],
):
"""
@@ -36,9 +37,12 @@ class QdrantDBConfig(BaseVectorDbConfig):
This setting saves RAM by (slightly) increasing the response time.
Note: those payload values that are involved in filtering and are indexed - remain in RAM.
:type on_disk: bool, optional, defaults to None
:param batch_size: Number of items to insert in one batch, defaults to 10
:type batch_size: Optional[int], optional
"""
self.hnsw_config = hnsw_config
self.quantization_config = quantization_config
self.on_disk = on_disk
self.batch_size = batch_size
self.extra_params = extra_params
super().__init__(collection_name=collection_name, dir=dir)

View File

@@ -10,7 +10,9 @@ class WeaviateDBConfig(BaseVectorDbConfig):
self,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
batch_size: Optional[int] = 100,
**extra_params: dict[str, any],
):
self.batch_size = batch_size
self.extra_params = extra_params
super().__init__(collection_name=collection_name, dir=dir)