[Improvements] Allow setting up the Elasticsearch cloud instance (#997)

Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
Deven Patel
2023-12-07 14:17:59 -08:00
committed by GitHub
parent d62a23edf6
commit 0ea8ab228c
4 changed files with 49 additions and 21 deletions

View File

@@ -11,6 +11,7 @@ except ImportError:
from embedchain.config import ElasticsearchDBConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.utils import chunks
from embedchain.vectordb.base import BaseVectorDB
@@ -20,6 +21,8 @@ class ElasticsearchDB(BaseVectorDB):
Elasticsearch as vector database
"""
BATCH_SIZE = 100
def __init__(
self,
config: Optional[ElasticsearchDBConfig] = None,
@@ -43,7 +46,14 @@ class ElasticsearchDB(BaseVectorDB):
"Please make sure the type is right and that you are passing an instance."
)
self.config = config or es_config
self.client = Elasticsearch(self.config.ES_URL, **self.config.ES_EXTRA_PARAMS)
if self.config.ES_URL:
self.client = Elasticsearch(self.config.ES_URL, **self.config.ES_EXTRA_PARAMS)
elif self.config.CLOUD_ID:
self.client = Elasticsearch(cloud_id=self.config.CLOUD_ID, **self.config.ES_EXTRA_PARAMS)
else:
raise ValueError(
"Something is wrong with your config. Please check again - `https://docs.embedchain.ai/components/vector-databases#elasticsearch`" # noqa: E501
)
# Call parent init here because embedder is needed
super().__init__(config=self.config)
@@ -121,19 +131,29 @@ class ElasticsearchDB(BaseVectorDB):
:type skip_embedding: bool
"""
docs = []
if not skip_embedding:
embeddings = self.embedder.embedding_fn(documents)
for id, text, metadata, embeddings in zip(ids, documents, metadatas, embeddings):
docs.append(
{
"_index": self._get_index(),
"_id": id,
"_source": {"text": text, "metadata": metadata, "embeddings": embeddings},
}
)
bulk(self.client, docs)
for chunk in chunks(
list(zip(ids, documents, metadatas, embeddings)), self.BATCH_SIZE, desc="Inserting batches in elasticsearch"
): # noqa: E501
ids, docs, metadatas, embeddings = [], [], [], []
for id, text, metadata, embedding in chunk:
ids.append(id)
docs.append(text)
metadatas.append(metadata)
embeddings.append(embedding)
batch_docs = []
for id, text, metadata, embedding in zip(ids, docs, metadatas, embeddings):
batch_docs.append(
{
"_index": self._get_index(),
"_id": id,
"_source": {"text": text, "metadata": metadata, "embeddings": embedding},
}
)
bulk(self.client, batch_docs, **kwargs)
self.client.indices.refresh(index=self._get_index())
def query(