From b0d8711b655bec416609952d1b84ce5f7c93b383 Mon Sep 17 00:00:00 2001 From: cachho Date: Wed, 6 Sep 2023 00:58:40 +0200 Subject: [PATCH] fix: Elasticsearch - use correct class attributes (#566) --- embedchain/embedder/base_embedder.py | 1 + embedchain/vectordb/elasticsearch_db.py | 12 ++++++------ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/embedchain/embedder/base_embedder.py b/embedchain/embedder/base_embedder.py index 28614351..cef9f0f4 100644 --- a/embedchain/embedder/base_embedder.py +++ b/embedchain/embedder/base_embedder.py @@ -24,6 +24,7 @@ class BaseEmbedder: self.config = BaseEmbedderConfig() else: self.config = config + self.vector_dimension: int def set_embedding_fn(self, embedding_fn: Callable[[list[str]], list[str]]): if not hasattr(embedding_fn, "__call__"): diff --git a/embedchain/vectordb/elasticsearch_db.py b/embedchain/vectordb/elasticsearch_db.py index 7a1d93f5..e913c2b6 100644 --- a/embedchain/vectordb/elasticsearch_db.py +++ b/embedchain/vectordb/elasticsearch_db.py @@ -81,7 +81,7 @@ class ElasticsearchDB(BaseVectorDB): :param ids: ids of docs """ docs = [] - embeddings = self.config.embedding_fn(documents) + embeddings = self.embedder.embedding_fn(documents) for id, text, metadata, embeddings in zip(ids, documents, metadatas, embeddings): docs.append( { @@ -101,7 +101,7 @@ class ElasticsearchDB(BaseVectorDB): :param n_results: no of similar documents to fetch from database :param where: Optional. to filter data """ - input_query_vector = self.config.embedding_fn(input_query) + input_query_vector = self.embedder.embedding_fn(input_query) query_vector = input_query_vector[0] query = { "script_score": { @@ -126,17 +126,17 @@ class ElasticsearchDB(BaseVectorDB): def count(self) -> int: query = {"match_all": {}} - response = self.client.count(index=self.es_index, query=query) + response = self.client.count(index=self._get_index(), query=query) doc_count = response["count"] return doc_count def reset(self): # Delete all data from the database - if self.client.indices.exists(index=self.es_index): + if self.client.indices.exists(index=self._get_index()): # delete index in Es - self.client.indices.delete(index=self.es_index) + self.client.indices.delete(index=self._get_index()) def _get_index(self): # NOTE: The method is preferred to an attribute, because if collection name changes, # it's always up-to-date. - return f"{self.config.collection_name}_{self.config.vector_dim}" + return f"{self.config.collection_name}_{self.embedder.vector_dimension}"