diff --git a/embedchain/embedder/base_embedder.py b/embedchain/embedder/base_embedder.py index 28614351..cef9f0f4 100644 --- a/embedchain/embedder/base_embedder.py +++ b/embedchain/embedder/base_embedder.py @@ -24,6 +24,7 @@ class BaseEmbedder: self.config = BaseEmbedderConfig() else: self.config = config + self.vector_dimension: int def set_embedding_fn(self, embedding_fn: Callable[[list[str]], list[str]]): if not hasattr(embedding_fn, "__call__"): diff --git a/embedchain/vectordb/elasticsearch_db.py b/embedchain/vectordb/elasticsearch_db.py index 7a1d93f5..e913c2b6 100644 --- a/embedchain/vectordb/elasticsearch_db.py +++ b/embedchain/vectordb/elasticsearch_db.py @@ -81,7 +81,7 @@ class ElasticsearchDB(BaseVectorDB): :param ids: ids of docs """ docs = [] - embeddings = self.config.embedding_fn(documents) + embeddings = self.embedder.embedding_fn(documents) for id, text, metadata, embeddings in zip(ids, documents, metadatas, embeddings): docs.append( { @@ -101,7 +101,7 @@ class ElasticsearchDB(BaseVectorDB): :param n_results: no of similar documents to fetch from database :param where: Optional. to filter data """ - input_query_vector = self.config.embedding_fn(input_query) + input_query_vector = self.embedder.embedding_fn(input_query) query_vector = input_query_vector[0] query = { "script_score": { @@ -126,17 +126,17 @@ class ElasticsearchDB(BaseVectorDB): def count(self) -> int: query = {"match_all": {}} - response = self.client.count(index=self.es_index, query=query) + response = self.client.count(index=self._get_index(), query=query) doc_count = response["count"] return doc_count def reset(self): # Delete all data from the database - if self.client.indices.exists(index=self.es_index): + if self.client.indices.exists(index=self._get_index()): # delete index in Es - self.client.indices.delete(index=self.es_index) + self.client.indices.delete(index=self._get_index()) def _get_index(self): # NOTE: The method is preferred to an attribute, because if collection name changes, # it's always up-to-date. - return f"{self.config.collection_name}_{self.config.vector_dim}" + return f"{self.config.collection_name}_{self.embedder.vector_dimension}"