[Improvements] allow setting up the elasticsearch cloud instance (#997)

Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
Deven Patel
2023-12-07 14:17:59 -08:00
committed by GitHub
parent d62a23edf6
commit 0ea8ab228c
4 changed files with 49 additions and 21 deletions

View File

@@ -58,6 +58,12 @@ Install related dependencies using the following command:
pip install --upgrade 'embedchain[elasticsearch]' pip install --upgrade 'embedchain[elasticsearch]'
``` ```
<Note>
You can configure the Elasticsearch connection by providing either `es_url` or `cloud_id`. If you are using the Elasticsearch Service on Elastic Cloud, you can find the `cloud_id` on the [Elastic Cloud dashboard](https://cloud.elastic.co/deployments).
</Note>
You can authenticate the connection to Elasticsearch by providing one of `basic_auth`, `api_key`, or `bearer_auth`.
<CodeGroup> <CodeGroup>
```python main.py ```python main.py
@@ -72,11 +78,10 @@ vectordb:
provider: elasticsearch provider: elasticsearch
config: config:
collection_name: 'es-index' collection_name: 'es-index'
es_url: http://localhost:9200 cloud_id: 'deployment-name:xxxx'
http_auth: basic_auth:
- admin - elastic
- admin - <your_password>
api_key: xxx
verify_certs: false verify_certs: false
``` ```
</CodeGroup> </CodeGroup>

View File

@@ -57,7 +57,7 @@ class BaseLlmConfig(BaseConfig):
def __init__( def __init__(
self, self,
number_documents: int = 1, number_documents: int = 3,
template: Optional[Template] = None, template: Optional[Template] = None,
model: Optional[str] = None, model: Optional[str] = None,
temperature: float = 0, temperature: float = 0,

View File

@@ -12,6 +12,7 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
collection_name: Optional[str] = None, collection_name: Optional[str] = None,
dir: Optional[str] = None, dir: Optional[str] = None,
es_url: Union[str, List[str]] = None, es_url: Union[str, List[str]] = None,
cloud_id: Optional[str] = None,
**ES_EXTRA_PARAMS: Dict[str, any], **ES_EXTRA_PARAMS: Dict[str, any],
): ):
""" """
@@ -26,12 +27,15 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
:param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch. :param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch.
:type ES_EXTRA_PARAMS: Dict[str, Any], optional :type ES_EXTRA_PARAMS: Dict[str, Any], optional
""" """
if es_url and cloud_id:
raise ValueError("Only one of `es_url` and `cloud_id` can be set.")
# self, es_url: Union[str, List[str]] = None, **ES_EXTRA_PARAMS: Dict[str, any]): # self, es_url: Union[str, List[str]] = None, **ES_EXTRA_PARAMS: Dict[str, any]):
self.ES_URL = es_url or os.environ.get("ELASTICSEARCH_URL") self.ES_URL = es_url or os.environ.get("ELASTICSEARCH_URL")
if not self.ES_URL: self.CLOUD_ID = cloud_id or os.environ.get("ELASTICSEARCH_CLOUD_ID")
if not self.ES_URL and not self.CLOUD_ID:
raise AttributeError( raise AttributeError(
"Elasticsearch needs a URL attribute, " "Elasticsearch needs a URL or CLOUD_ID attribute, "
"this can either be passed to `ElasticsearchDBConfig` or as `ELASTICSEARCH_URL` in `.env`" "this can either be passed to `ElasticsearchDBConfig` or as `ELASTICSEARCH_URL` or `ELASTICSEARCH_CLOUD_ID` in `.env`" # noqa: E501
) )
self.ES_EXTRA_PARAMS = ES_EXTRA_PARAMS self.ES_EXTRA_PARAMS = ES_EXTRA_PARAMS
# Load API key from .env if it's not explicitly passed. # Load API key from .env if it's not explicitly passed.
@@ -40,7 +44,6 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
not self.ES_EXTRA_PARAMS.get("api_key") not self.ES_EXTRA_PARAMS.get("api_key")
and not self.ES_EXTRA_PARAMS.get("basic_auth") and not self.ES_EXTRA_PARAMS.get("basic_auth")
and not self.ES_EXTRA_PARAMS.get("bearer_auth") and not self.ES_EXTRA_PARAMS.get("bearer_auth")
and not self.ES_EXTRA_PARAMS.get("http_auth")
): ):
self.ES_EXTRA_PARAMS["api_key"] = os.environ.get("ELASTICSEARCH_API_KEY") self.ES_EXTRA_PARAMS["api_key"] = os.environ.get("ELASTICSEARCH_API_KEY")
super().__init__(collection_name=collection_name, dir=dir) super().__init__(collection_name=collection_name, dir=dir)

View File

@@ -11,6 +11,7 @@ except ImportError:
from embedchain.config import ElasticsearchDBConfig from embedchain.config import ElasticsearchDBConfig
from embedchain.helpers.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.utils import chunks
from embedchain.vectordb.base import BaseVectorDB from embedchain.vectordb.base import BaseVectorDB
@@ -20,6 +21,8 @@ class ElasticsearchDB(BaseVectorDB):
Elasticsearch as vector database Elasticsearch as vector database
""" """
BATCH_SIZE = 100
def __init__( def __init__(
self, self,
config: Optional[ElasticsearchDBConfig] = None, config: Optional[ElasticsearchDBConfig] = None,
@@ -43,7 +46,14 @@ class ElasticsearchDB(BaseVectorDB):
"Please make sure the type is right and that you are passing an instance." "Please make sure the type is right and that you are passing an instance."
) )
self.config = config or es_config self.config = config or es_config
self.client = Elasticsearch(self.config.ES_URL, **self.config.ES_EXTRA_PARAMS) if self.config.ES_URL:
self.client = Elasticsearch(self.config.ES_URL, **self.config.ES_EXTRA_PARAMS)
elif self.config.CLOUD_ID:
self.client = Elasticsearch(cloud_id=self.config.CLOUD_ID, **self.config.ES_EXTRA_PARAMS)
else:
raise ValueError(
"Something is wrong with your config. Please check again - `https://docs.embedchain.ai/components/vector-databases#elasticsearch`" # noqa: E501
)
# Call parent init here because embedder is needed # Call parent init here because embedder is needed
super().__init__(config=self.config) super().__init__(config=self.config)
@@ -121,19 +131,29 @@ class ElasticsearchDB(BaseVectorDB):
:type skip_embedding: bool :type skip_embedding: bool
""" """
docs = []
if not skip_embedding: if not skip_embedding:
embeddings = self.embedder.embedding_fn(documents) embeddings = self.embedder.embedding_fn(documents)
for id, text, metadata, embeddings in zip(ids, documents, metadatas, embeddings): for chunk in chunks(
docs.append( list(zip(ids, documents, metadatas, embeddings)), self.BATCH_SIZE, desc="Inserting batches in elasticsearch"
{ ): # noqa: E501
"_index": self._get_index(), ids, docs, metadatas, embeddings = [], [], [], []
"_id": id, for id, text, metadata, embedding in chunk:
"_source": {"text": text, "metadata": metadata, "embeddings": embeddings}, ids.append(id)
} docs.append(text)
) metadatas.append(metadata)
bulk(self.client, docs) embeddings.append(embedding)
batch_docs = []
for id, text, metadata, embedding in zip(ids, docs, metadatas, embeddings):
batch_docs.append(
{
"_index": self._get_index(),
"_id": id,
"_source": {"text": text, "metadata": metadata, "embeddings": embedding},
}
)
bulk(self.client, batch_docs, **kwargs)
self.client.indices.refresh(index=self._get_index()) self.client.indices.refresh(index=self._get_index())
def query( def query(