fix: Elasticsearch - use correct class attributes (#566)
This commit is contained in:
@@ -24,6 +24,7 @@ class BaseEmbedder:
|
|||||||
self.config = BaseEmbedderConfig()
|
self.config = BaseEmbedderConfig()
|
||||||
else:
|
else:
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.vector_dimension: int
|
||||||
|
|
||||||
def set_embedding_fn(self, embedding_fn: Callable[[list[str]], list[str]]):
|
def set_embedding_fn(self, embedding_fn: Callable[[list[str]], list[str]]):
|
||||||
if not hasattr(embedding_fn, "__call__"):
|
if not hasattr(embedding_fn, "__call__"):
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ class ElasticsearchDB(BaseVectorDB):
|
|||||||
:param ids: ids of docs
|
:param ids: ids of docs
|
||||||
"""
|
"""
|
||||||
docs = []
|
docs = []
|
||||||
embeddings = self.config.embedding_fn(documents)
|
embeddings = self.embedder.embedding_fn(documents)
|
||||||
for id, text, metadata, embeddings in zip(ids, documents, metadatas, embeddings):
|
for id, text, metadata, embeddings in zip(ids, documents, metadatas, embeddings):
|
||||||
docs.append(
|
docs.append(
|
||||||
{
|
{
|
||||||
@@ -101,7 +101,7 @@ class ElasticsearchDB(BaseVectorDB):
|
|||||||
:param n_results: no of similar documents to fetch from database
|
:param n_results: no of similar documents to fetch from database
|
||||||
:param where: Optional. to filter data
|
:param where: Optional. to filter data
|
||||||
"""
|
"""
|
||||||
input_query_vector = self.config.embedding_fn(input_query)
|
input_query_vector = self.embedder.embedding_fn(input_query)
|
||||||
query_vector = input_query_vector[0]
|
query_vector = input_query_vector[0]
|
||||||
query = {
|
query = {
|
||||||
"script_score": {
|
"script_score": {
|
||||||
@@ -126,17 +126,17 @@ class ElasticsearchDB(BaseVectorDB):
|
|||||||
|
|
||||||
def count(self) -> int:
|
def count(self) -> int:
|
||||||
query = {"match_all": {}}
|
query = {"match_all": {}}
|
||||||
response = self.client.count(index=self.es_index, query=query)
|
response = self.client.count(index=self._get_index(), query=query)
|
||||||
doc_count = response["count"]
|
doc_count = response["count"]
|
||||||
return doc_count
|
return doc_count
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
# Delete all data from the database
|
# Delete all data from the database
|
||||||
if self.client.indices.exists(index=self.es_index):
|
if self.client.indices.exists(index=self._get_index()):
|
||||||
# delete index in Es
|
# delete index in Es
|
||||||
self.client.indices.delete(index=self.es_index)
|
self.client.indices.delete(index=self._get_index())
|
||||||
|
|
||||||
def _get_index(self):
|
def _get_index(self):
|
||||||
# NOTE: The method is preferred to an attribute, because if collection name changes,
|
# NOTE: The method is preferred to an attribute, because if collection name changes,
|
||||||
# it's always up-to-date.
|
# it's always up-to-date.
|
||||||
return f"{self.config.collection_name}_{self.config.vector_dim}"
|
return f"{self.config.collection_name}_{self.embedder.vector_dimension}"
|
||||||
|
|||||||
Reference in New Issue
Block a user