From 0ea8ab228cb622e1d877a67f6890e428840a781b Mon Sep 17 00:00:00 2001 From: Deven Patel Date: Thu, 7 Dec 2023 14:17:59 -0800 Subject: [PATCH] [Improvements] allow setting up the elasticsearch cloud instance (#997) Co-authored-by: Deven Patel --- docs/components/vector-databases.mdx | 15 +++++--- embedchain/config/llm/base.py | 2 +- embedchain/config/vectordb/elasticsearch.py | 11 ++++-- embedchain/vectordb/elasticsearch.py | 42 +++++++++++++++------ 4 files changed, 49 insertions(+), 21 deletions(-) diff --git a/docs/components/vector-databases.mdx b/docs/components/vector-databases.mdx index fe54c575..029afb52 100644 --- a/docs/components/vector-databases.mdx +++ b/docs/components/vector-databases.mdx @@ -58,6 +58,12 @@ Install related dependencies using the following command: pip install --upgrade 'embedchain[elasticsearch]' ``` + +You can configure the Elasticsearch connection by providing either `es_url` or `cloud_id`. If you are using the Elasticsearch Service on Elastic Cloud, you can find the `cloud_id` on the [Elastic Cloud dashboard](https://cloud.elastic.co/deployments). + + +You can authenticate the connection to Elasticsearch by providing one of `basic_auth`, `api_key`, or `bearer_auth`. 
+ ```python main.py @@ -72,11 +78,10 @@ vectordb: provider: elasticsearch config: collection_name: 'es-index' - es_url: http://localhost:9200 - http_auth: - - admin - - admin - api_key: xxx + cloud_id: 'deployment-name:xxxx' + basic_auth: + - elastic + - <your_password> verify_certs: false ``` diff --git a/embedchain/config/llm/base.py b/embedchain/config/llm/base.py index 6dbfdb90..4e4054c2 100644 --- a/embedchain/config/llm/base.py +++ b/embedchain/config/llm/base.py @@ -57,7 +57,7 @@ class BaseLlmConfig(BaseConfig): def __init__( self, - number_documents: int = 1, + number_documents: int = 3, template: Optional[Template] = None, model: Optional[str] = None, temperature: float = 0, diff --git a/embedchain/config/vectordb/elasticsearch.py b/embedchain/config/vectordb/elasticsearch.py index 77d54a16..7ccf4226 100644 --- a/embedchain/config/vectordb/elasticsearch.py +++ b/embedchain/config/vectordb/elasticsearch.py @@ -12,6 +12,7 @@ class ElasticsearchDBConfig(BaseVectorDbConfig): collection_name: Optional[str] = None, dir: Optional[str] = None, es_url: Union[str, List[str]] = None, + cloud_id: Optional[str] = None, **ES_EXTRA_PARAMS: Dict[str, any], ): """ @@ -26,12 +27,15 @@ class ElasticsearchDBConfig(BaseVectorDbConfig): :param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch. 
:type ES_EXTRA_PARAMS: Dict[str, Any], optional """ + if es_url and cloud_id: + raise ValueError("Only one of `es_url` and `cloud_id` can be set.") # self, es_url: Union[str, List[str]] = None, **ES_EXTRA_PARAMS: Dict[str, any]): self.ES_URL = es_url or os.environ.get("ELASTICSEARCH_URL") - if not self.ES_URL: + self.CLOUD_ID = cloud_id or os.environ.get("ELASTICSEARCH_CLOUD_ID") + if not self.ES_URL and not self.CLOUD_ID: raise AttributeError( - "Elasticsearch needs a URL attribute, " - "this can either be passed to `ElasticsearchDBConfig` or as `ELASTICSEARCH_URL` in `.env`" + "Elasticsearch needs a URL or CLOUD_ID attribute, " + "this can either be passed to `ElasticsearchDBConfig` or as `ELASTICSEARCH_URL` or `ELASTICSEARCH_CLOUD_ID` in `.env`" # noqa: E501 ) self.ES_EXTRA_PARAMS = ES_EXTRA_PARAMS # Load API key from .env if it's not explicitly passed. @@ -40,7 +44,6 @@ class ElasticsearchDBConfig(BaseVectorDbConfig): not self.ES_EXTRA_PARAMS.get("api_key") and not self.ES_EXTRA_PARAMS.get("basic_auth") and not self.ES_EXTRA_PARAMS.get("bearer_auth") - and not self.ES_EXTRA_PARAMS.get("http_auth") ): self.ES_EXTRA_PARAMS["api_key"] = os.environ.get("ELASTICSEARCH_API_KEY") super().__init__(collection_name=collection_name, dir=dir) diff --git a/embedchain/vectordb/elasticsearch.py b/embedchain/vectordb/elasticsearch.py index e3b25042..866d3eb1 100644 --- a/embedchain/vectordb/elasticsearch.py +++ b/embedchain/vectordb/elasticsearch.py @@ -11,6 +11,7 @@ except ImportError: from embedchain.config import ElasticsearchDBConfig from embedchain.helpers.json_serializable import register_deserializable +from embedchain.utils import chunks from embedchain.vectordb.base import BaseVectorDB @@ -20,6 +21,8 @@ class ElasticsearchDB(BaseVectorDB): Elasticsearch as vector database """ + BATCH_SIZE = 100 + def __init__( self, config: Optional[ElasticsearchDBConfig] = None, @@ -43,7 +46,14 @@ class ElasticsearchDB(BaseVectorDB): "Please make sure the type is right and that you 
are passing an instance." ) self.config = config or es_config - self.client = Elasticsearch(self.config.ES_URL, **self.config.ES_EXTRA_PARAMS) + if self.config.ES_URL: + self.client = Elasticsearch(self.config.ES_URL, **self.config.ES_EXTRA_PARAMS) + elif self.config.CLOUD_ID: + self.client = Elasticsearch(cloud_id=self.config.CLOUD_ID, **self.config.ES_EXTRA_PARAMS) + else: + raise ValueError( + "Something is wrong with your config. Please check again - `https://docs.embedchain.ai/components/vector-databases#elasticsearch`" # noqa: E501 + ) # Call parent init here because embedder is needed super().__init__(config=self.config) @@ -121,19 +131,29 @@ class ElasticsearchDB(BaseVectorDB): :type skip_embedding: bool """ - docs = [] if not skip_embedding: embeddings = self.embedder.embedding_fn(documents) - for id, text, metadata, embeddings in zip(ids, documents, metadatas, embeddings): - docs.append( - { - "_index": self._get_index(), - "_id": id, - "_source": {"text": text, "metadata": metadata, "embeddings": embeddings}, - } - ) - bulk(self.client, docs) + for chunk in chunks( + list(zip(ids, documents, metadatas, embeddings)), self.BATCH_SIZE, desc="Inserting batches in elasticsearch" + ): # noqa: E501 + ids, docs, metadatas, embeddings = [], [], [], [] + for id, text, metadata, embedding in chunk: + ids.append(id) + docs.append(text) + metadatas.append(metadata) + embeddings.append(embedding) + + batch_docs = [] + for id, text, metadata, embedding in zip(ids, docs, metadatas, embeddings): + batch_docs.append( + { + "_index": self._get_index(), + "_id": id, + "_source": {"text": text, "metadata": metadata, "embeddings": embedding}, + } + ) + bulk(self.client, batch_docs, **kwargs) self.client.indices.refresh(index=self._get_index()) def query(