Fix batch_size for vectordb (#1449)
This commit is contained in:
@@ -10,7 +10,6 @@ class BaseVectorDbConfig(BaseConfig):
|
||||
dir: str = "db",
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
batch_size: Optional[int] = 100,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -24,8 +23,6 @@ class BaseVectorDbConfig(BaseConfig):
|
||||
:type host: Optional[str], optional
|
||||
:param port: Database connection remote port. Use this if you run Embedchain as a client, defaults to None
|
||||
:type port: Optional[str], optional
|
||||
:param batch_size: Number of items to insert in one batch, defaults to 100
|
||||
:type batch_size: Optional[int], optional
|
||||
:param kwargs: Additional keyword arguments
|
||||
:type kwargs: dict
|
||||
"""
|
||||
@@ -33,7 +30,6 @@ class BaseVectorDbConfig(BaseConfig):
|
||||
self.dir = dir
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.batch_size = batch_size
|
||||
# Assign additional keyword arguments
|
||||
if kwargs:
|
||||
for key, value in kwargs.items():
|
||||
|
||||
@@ -12,6 +12,7 @@ class ChromaDbConfig(BaseVectorDbConfig):
|
||||
dir: Optional[str] = None,
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
batch_size: Optional[int] = 100,
|
||||
allow_reset=False,
|
||||
chroma_settings: Optional[dict] = None,
|
||||
):
|
||||
@@ -26,6 +27,8 @@ class ChromaDbConfig(BaseVectorDbConfig):
|
||||
:type host: Optional[str], optional
|
||||
:param port: Database connection remote port. Use this if you run Embedchain as a client, defaults to None
|
||||
:type port: Optional[str], optional
|
||||
:param batch_size: Number of items to insert in one batch, defaults to 100
|
||||
:type batch_size: Optional[int], optional
|
||||
:param allow_reset: Resets the database, defaults to False
|
||||
:type allow_reset: bool
|
||||
:param chroma_settings: Chroma settings dict, defaults to None
|
||||
@@ -34,4 +37,5 @@ class ChromaDbConfig(BaseVectorDbConfig):
|
||||
|
||||
self.chroma_settings = chroma_settings
|
||||
self.allow_reset = allow_reset
|
||||
self.batch_size = batch_size
|
||||
super().__init__(collection_name=collection_name, dir=dir, host=host, port=port)
|
||||
|
||||
@@ -13,6 +13,7 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
|
||||
dir: Optional[str] = None,
|
||||
es_url: Union[str, list[str]] = None,
|
||||
cloud_id: Optional[str] = None,
|
||||
batch_size: Optional[int] = 100,
|
||||
**ES_EXTRA_PARAMS: dict[str, any],
|
||||
):
|
||||
"""
|
||||
@@ -24,6 +25,10 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
|
||||
:type dir: Optional[str], optional
|
||||
:param es_url: elasticsearch url or list of nodes url to be used for connection, defaults to None
|
||||
:type es_url: Union[str, list[str]], optional
|
||||
:param cloud_id: cloud id of the elasticsearch cluster, defaults to None
|
||||
:type cloud_id: Optional[str], optional
|
||||
:param batch_size: Number of items to insert in one batch, defaults to 100
|
||||
:type batch_size: Optional[int], optional
|
||||
:param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch.
|
||||
:type ES_EXTRA_PARAMS: dict[str, Any], optional
|
||||
"""
|
||||
@@ -46,4 +51,6 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
|
||||
and not self.ES_EXTRA_PARAMS.get("bearer_auth")
|
||||
):
|
||||
self.ES_EXTRA_PARAMS["api_key"] = os.environ.get("ELASTICSEARCH_API_KEY")
|
||||
|
||||
self.batch_size = batch_size
|
||||
super().__init__(collection_name=collection_name, dir=dir)
|
||||
|
||||
@@ -13,6 +13,7 @@ class OpenSearchDBConfig(BaseVectorDbConfig):
|
||||
vector_dimension: int = 1536,
|
||||
collection_name: Optional[str] = None,
|
||||
dir: Optional[str] = None,
|
||||
batch_size: Optional[int] = 100,
|
||||
**extra_params: dict[str, any],
|
||||
):
|
||||
"""
|
||||
@@ -28,10 +29,13 @@ class OpenSearchDBConfig(BaseVectorDbConfig):
|
||||
:type vector_dimension: int, optional
|
||||
:param dir: Path to the database directory, where the database is stored, defaults to None
|
||||
:type dir: Optional[str], optional
|
||||
:param batch_size: Number of items to insert in one batch, defaults to 100
|
||||
:type batch_size: Optional[int], optional
|
||||
"""
|
||||
self.opensearch_url = opensearch_url
|
||||
self.http_auth = http_auth
|
||||
self.vector_dimension = vector_dimension
|
||||
self.extra_params = extra_params
|
||||
self.batch_size = batch_size
|
||||
|
||||
super().__init__(collection_name=collection_name, dir=dir)
|
||||
|
||||
@@ -17,6 +17,7 @@ class PineconeDBConfig(BaseVectorDbConfig):
|
||||
serverless_config: Optional[dict[str, any]] = None,
|
||||
hybrid_search: bool = False,
|
||||
bm25_encoder: any = None,
|
||||
batch_size: Optional[int] = 100,
|
||||
**extra_params: dict[str, any],
|
||||
):
|
||||
self.metric = metric
|
||||
@@ -26,6 +27,7 @@ class PineconeDBConfig(BaseVectorDbConfig):
|
||||
self.extra_params = extra_params
|
||||
self.hybrid_search = hybrid_search
|
||||
self.bm25_encoder = bm25_encoder
|
||||
self.batch_size = batch_size
|
||||
if pod_config is None and serverless_config is None:
|
||||
# If no config is provided, use the default pod spec config
|
||||
pod_environment = os.environ.get("PINECONE_ENV", "gcp-starter")
|
||||
|
||||
@@ -18,6 +18,7 @@ class QdrantDBConfig(BaseVectorDbConfig):
|
||||
hnsw_config: Optional[dict[str, any]] = None,
|
||||
quantization_config: Optional[dict[str, any]] = None,
|
||||
on_disk: Optional[bool] = None,
|
||||
batch_size: Optional[int] = 10,
|
||||
**extra_params: dict[str, any],
|
||||
):
|
||||
"""
|
||||
@@ -36,9 +37,12 @@ class QdrantDBConfig(BaseVectorDbConfig):
|
||||
This setting saves RAM by (slightly) increasing the response time.
|
||||
Note: those payload values that are involved in filtering and are indexed - remain in RAM.
|
||||
:type on_disk: bool, optional, defaults to None
|
||||
:param batch_size: Number of items to insert in one batch, defaults to 10
|
||||
:type batch_size: Optional[int], optional
|
||||
"""
|
||||
self.hnsw_config = hnsw_config
|
||||
self.quantization_config = quantization_config
|
||||
self.on_disk = on_disk
|
||||
self.batch_size = batch_size
|
||||
self.extra_params = extra_params
|
||||
super().__init__(collection_name=collection_name, dir=dir)
|
||||
|
||||
@@ -10,7 +10,9 @@ class WeaviateDBConfig(BaseVectorDbConfig):
|
||||
self,
|
||||
collection_name: Optional[str] = None,
|
||||
dir: Optional[str] = None,
|
||||
batch_size: Optional[int] = 100,
|
||||
**extra_params: dict[str, any],
|
||||
):
|
||||
self.batch_size = batch_size
|
||||
self.extra_params = extra_params
|
||||
super().__init__(collection_name=collection_name, dir=dir)
|
||||
|
||||
Reference in New Issue
Block a user