fix: Elasticsearch - use correct class attributes (#566)

This commit is contained in:
cachho
2023-09-06 00:58:40 +02:00
committed by GitHub
parent f0844ed923
commit b0d8711b65
2 changed files with 7 additions and 6 deletions

View File

@@ -24,6 +24,7 @@ class BaseEmbedder:
self.config = BaseEmbedderConfig()
else:
self.config = config
self.vector_dimension: int
def set_embedding_fn(self, embedding_fn: Callable[[list[str]], list[str]]):
if not hasattr(embedding_fn, "__call__"):

View File

@@ -81,7 +81,7 @@ class ElasticsearchDB(BaseVectorDB):
:param ids: ids of docs
"""
docs = []
embeddings = self.config.embedding_fn(documents)
embeddings = self.embedder.embedding_fn(documents)
for id, text, metadata, embeddings in zip(ids, documents, metadatas, embeddings):
docs.append(
{
@@ -101,7 +101,7 @@ class ElasticsearchDB(BaseVectorDB):
:param n_results: no of similar documents to fetch from database
:param where: Optional. to filter data
"""
input_query_vector = self.config.embedding_fn(input_query)
input_query_vector = self.embedder.embedding_fn(input_query)
query_vector = input_query_vector[0]
query = {
"script_score": {
@@ -126,17 +126,17 @@ class ElasticsearchDB(BaseVectorDB):
def count(self) -> int:
query = {"match_all": {}}
response = self.client.count(index=self.es_index, query=query)
response = self.client.count(index=self._get_index(), query=query)
doc_count = response["count"]
return doc_count
def reset(self):
# Delete all data from the database
if self.client.indices.exists(index=self.es_index):
if self.client.indices.exists(index=self._get_index()):
# delete index in Es
self.client.indices.delete(index=self.es_index)
self.client.indices.delete(index=self._get_index())
def _get_index(self):
# NOTE: The method is preferred to an attribute, because if collection name changes,
# it's always up-to-date.
return f"{self.config.collection_name}_{self.config.vector_dim}"
return f"{self.config.collection_name}_{self.embedder.vector_dimension}"