[Improvements] allow setting up the elasticsearch cloud instance (#997)
Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
@@ -58,6 +58,12 @@ Install related dependencies using the following command:
|
|||||||
pip install --upgrade 'embedchain[elasticsearch]'
|
pip install --upgrade 'embedchain[elasticsearch]'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
<Note>
|
||||||
|
You can configure the Elasticsearch connection by providing either `es_url` or `cloud_id`, but not both: passing both raises a `ValueError`. If you are using the Elasticsearch Service on Elastic Cloud, you can find the `cloud_id` on the [Elastic Cloud dashboard](https://cloud.elastic.co/deployments).
|
||||||
|
</Note>
|
||||||
|
|
||||||
|
You can authenticate the connection to Elasticsearch by providing one of `basic_auth`, `api_key`, or `bearer_auth`.
|
||||||
|
|
||||||
<CodeGroup>
|
<CodeGroup>
|
||||||
|
|
||||||
```python main.py
|
```python main.py
|
||||||
@@ -72,11 +78,10 @@ vectordb:
|
|||||||
provider: elasticsearch
|
provider: elasticsearch
|
||||||
config:
|
config:
|
||||||
collection_name: 'es-index'
|
collection_name: 'es-index'
|
||||||
es_url: http://localhost:9200
|
cloud_id: 'deployment-name:xxxx'
|
||||||
http_auth:
|
basic_auth:
|
||||||
- admin
|
- elastic
|
||||||
- admin
|
- <your_password>
|
||||||
api_key: xxx
|
|
||||||
verify_certs: false
|
verify_certs: false
|
||||||
```
|
```
|
||||||
</CodeGroup>
|
</CodeGroup>
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ class BaseLlmConfig(BaseConfig):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
number_documents: int = 1,
|
number_documents: int = 3,
|
||||||
template: Optional[Template] = None,
|
template: Optional[Template] = None,
|
||||||
model: Optional[str] = None,
|
model: Optional[str] = None,
|
||||||
temperature: float = 0,
|
temperature: float = 0,
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
|
|||||||
collection_name: Optional[str] = None,
|
collection_name: Optional[str] = None,
|
||||||
dir: Optional[str] = None,
|
dir: Optional[str] = None,
|
||||||
es_url: Union[str, List[str]] = None,
|
es_url: Union[str, List[str]] = None,
|
||||||
|
cloud_id: Optional[str] = None,
|
||||||
**ES_EXTRA_PARAMS: Dict[str, any],
|
**ES_EXTRA_PARAMS: Dict[str, any],
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -26,12 +27,15 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
|
|||||||
:param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch.
|
:param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch.
|
||||||
:type ES_EXTRA_PARAMS: Dict[str, Any], optional
|
:type ES_EXTRA_PARAMS: Dict[str, Any], optional
|
||||||
"""
|
"""
|
||||||
|
if es_url and cloud_id:
|
||||||
|
raise ValueError("Only one of `es_url` and `cloud_id` can be set.")
|
||||||
# self, es_url: Union[str, List[str]] = None, **ES_EXTRA_PARAMS: Dict[str, any]):
|
# self, es_url: Union[str, List[str]] = None, **ES_EXTRA_PARAMS: Dict[str, any]):
|
||||||
self.ES_URL = es_url or os.environ.get("ELASTICSEARCH_URL")
|
self.ES_URL = es_url or os.environ.get("ELASTICSEARCH_URL")
|
||||||
if not self.ES_URL:
|
self.CLOUD_ID = cloud_id or os.environ.get("ELASTICSEARCH_CLOUD_ID")
|
||||||
|
if not self.ES_URL and not self.CLOUD_ID:
|
||||||
raise AttributeError(
|
raise AttributeError(
|
||||||
"Elasticsearch needs a URL attribute, "
|
"Elasticsearch needs a URL or CLOUD_ID attribute, "
|
||||||
"this can either be passed to `ElasticsearchDBConfig` or as `ELASTICSEARCH_URL` in `.env`"
|
"this can either be passed to `ElasticsearchDBConfig` or as `ELASTICSEARCH_URL` or `ELASTICSEARCH_CLOUD_ID` in `.env`" # noqa: E501
|
||||||
)
|
)
|
||||||
self.ES_EXTRA_PARAMS = ES_EXTRA_PARAMS
|
self.ES_EXTRA_PARAMS = ES_EXTRA_PARAMS
|
||||||
# Load API key from .env if it's not explicitly passed.
|
# Load API key from .env if it's not explicitly passed.
|
||||||
@@ -40,7 +44,6 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
|
|||||||
not self.ES_EXTRA_PARAMS.get("api_key")
|
not self.ES_EXTRA_PARAMS.get("api_key")
|
||||||
and not self.ES_EXTRA_PARAMS.get("basic_auth")
|
and not self.ES_EXTRA_PARAMS.get("basic_auth")
|
||||||
and not self.ES_EXTRA_PARAMS.get("bearer_auth")
|
and not self.ES_EXTRA_PARAMS.get("bearer_auth")
|
||||||
and not self.ES_EXTRA_PARAMS.get("http_auth")
|
|
||||||
):
|
):
|
||||||
self.ES_EXTRA_PARAMS["api_key"] = os.environ.get("ELASTICSEARCH_API_KEY")
|
self.ES_EXTRA_PARAMS["api_key"] = os.environ.get("ELASTICSEARCH_API_KEY")
|
||||||
super().__init__(collection_name=collection_name, dir=dir)
|
super().__init__(collection_name=collection_name, dir=dir)
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ except ImportError:
|
|||||||
|
|
||||||
from embedchain.config import ElasticsearchDBConfig
|
from embedchain.config import ElasticsearchDBConfig
|
||||||
from embedchain.helpers.json_serializable import register_deserializable
|
from embedchain.helpers.json_serializable import register_deserializable
|
||||||
|
from embedchain.utils import chunks
|
||||||
from embedchain.vectordb.base import BaseVectorDB
|
from embedchain.vectordb.base import BaseVectorDB
|
||||||
|
|
||||||
|
|
||||||
@@ -20,6 +21,8 @@ class ElasticsearchDB(BaseVectorDB):
|
|||||||
Elasticsearch as vector database
|
Elasticsearch as vector database
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
BATCH_SIZE = 100
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: Optional[ElasticsearchDBConfig] = None,
|
config: Optional[ElasticsearchDBConfig] = None,
|
||||||
@@ -43,7 +46,14 @@ class ElasticsearchDB(BaseVectorDB):
|
|||||||
"Please make sure the type is right and that you are passing an instance."
|
"Please make sure the type is right and that you are passing an instance."
|
||||||
)
|
)
|
||||||
self.config = config or es_config
|
self.config = config or es_config
|
||||||
self.client = Elasticsearch(self.config.ES_URL, **self.config.ES_EXTRA_PARAMS)
|
if self.config.ES_URL:
|
||||||
|
self.client = Elasticsearch(self.config.ES_URL, **self.config.ES_EXTRA_PARAMS)
|
||||||
|
elif self.config.CLOUD_ID:
|
||||||
|
self.client = Elasticsearch(cloud_id=self.config.CLOUD_ID, **self.config.ES_EXTRA_PARAMS)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Something is wrong with your config. Please check again - `https://docs.embedchain.ai/components/vector-databases#elasticsearch`" # noqa: E501
|
||||||
|
)
|
||||||
|
|
||||||
# Call parent init here because embedder is needed
|
# Call parent init here because embedder is needed
|
||||||
super().__init__(config=self.config)
|
super().__init__(config=self.config)
|
||||||
@@ -121,19 +131,29 @@ class ElasticsearchDB(BaseVectorDB):
|
|||||||
:type skip_embedding: bool
|
:type skip_embedding: bool
|
||||||
"""
|
"""
|
||||||
|
|
||||||
docs = []
|
|
||||||
if not skip_embedding:
|
if not skip_embedding:
|
||||||
embeddings = self.embedder.embedding_fn(documents)
|
embeddings = self.embedder.embedding_fn(documents)
|
||||||
|
|
||||||
for id, text, metadata, embeddings in zip(ids, documents, metadatas, embeddings):
|
for chunk in chunks(
|
||||||
docs.append(
|
list(zip(ids, documents, metadatas, embeddings)), self.BATCH_SIZE, desc="Inserting batches in elasticsearch"
|
||||||
{
|
): # noqa: E501
|
||||||
"_index": self._get_index(),
|
ids, docs, metadatas, embeddings = [], [], [], []
|
||||||
"_id": id,
|
for id, text, metadata, embedding in chunk:
|
||||||
"_source": {"text": text, "metadata": metadata, "embeddings": embeddings},
|
ids.append(id)
|
||||||
}
|
docs.append(text)
|
||||||
)
|
metadatas.append(metadata)
|
||||||
bulk(self.client, docs)
|
embeddings.append(embedding)
|
||||||
|
|
||||||
|
batch_docs = []
|
||||||
|
for id, text, metadata, embedding in zip(ids, docs, metadatas, embeddings):
|
||||||
|
batch_docs.append(
|
||||||
|
{
|
||||||
|
"_index": self._get_index(),
|
||||||
|
"_id": id,
|
||||||
|
"_source": {"text": text, "metadata": metadata, "embeddings": embedding},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
bulk(self.client, batch_docs, **kwargs)
|
||||||
self.client.indices.refresh(index=self._get_index())
|
self.client.indices.refresh(index=self._get_index())
|
||||||
|
|
||||||
def query(
|
def query(
|
||||||
|
|||||||
Reference in New Issue
Block a user