[Bug Fix] Fix ChromaDB where clause for query and delete (#937)
Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
@@ -3,7 +3,8 @@ from typing import Optional
|
|||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from embedchain.client import Client
|
from embedchain.client import Client
|
||||||
from embedchain.config import AppConfig, BaseEmbedderConfig, BaseLlmConfig, ChunkerConfig
|
from embedchain.config import (AppConfig, BaseEmbedderConfig, BaseLlmConfig,
|
||||||
|
ChunkerConfig)
|
||||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||||
from embedchain.embedchain import EmbedChain
|
from embedchain.embedchain import EmbedChain
|
||||||
from embedchain.embedder.base import BaseEmbedder
|
from embedchain.embedder.base import BaseEmbedder
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from embedchain import Pipeline as App
|
from embedchain import Pipeline as App
|
||||||
from embedchain.config import AddConfig, PipelineConfig, BaseLlmConfig
|
from embedchain.config import AddConfig, BaseLlmConfig, PipelineConfig
|
||||||
from embedchain.embedder.openai import OpenAIEmbedder
|
from embedchain.embedder.openai import OpenAIEmbedder
|
||||||
from embedchain.helper.json_serializable import JSONSerializable, register_deserializable
|
from embedchain.helper.json_serializable import (JSONSerializable,
|
||||||
|
register_deserializable)
|
||||||
from embedchain.llm.openai import OpenAILlm
|
from embedchain.llm.openai import OpenAILlm
|
||||||
from embedchain.vectordb.chroma import ChromaDB
|
from embedchain.vectordb.chroma import ChromaDB
|
||||||
|
|
||||||
|
|||||||
@@ -478,13 +478,13 @@ class EmbedChain(JSONSerializable):
|
|||||||
query_config = config or self.llm.config
|
query_config = config or self.llm.config
|
||||||
if where is not None:
|
if where is not None:
|
||||||
where = where
|
where = where
|
||||||
elif query_config is not None and query_config.where is not None:
|
|
||||||
where = query_config.where
|
|
||||||
else:
|
else:
|
||||||
where = {}
|
where = {}
|
||||||
|
if query_config is not None and query_config.where is not None:
|
||||||
if self.config.id is not None:
|
where = query_config.where
|
||||||
where.update({"app_id": self.config.id})
|
|
||||||
|
if self.config.id is not None:
|
||||||
|
where.update({"app_id": self.config.id})
|
||||||
|
|
||||||
# We cannot query the database with the input query in case of an image search. This is because we need
|
# We cannot query the database with the input query in case of an image search. This is because we need
|
||||||
# to bring down both the image and text to the same dimension to be able to compare them.
|
# to bring down both the image and text to the same dimension to be able to compare them.
|
||||||
|
|||||||
@@ -77,7 +77,7 @@ class ChromaDB(BaseVectorDB):
|
|||||||
def _generate_where_clause(self, where: Dict[str, any]) -> str:
|
def _generate_where_clause(self, where: Dict[str, any]) -> str:
|
||||||
# If only one filter is supplied, return it as is
|
# If only one filter is supplied, return it as is
|
||||||
# (no need to wrap in $and based on chroma docs)
|
# (no need to wrap in $and based on chroma docs)
|
||||||
if len(where.keys()) == 1:
|
if len(where.keys()) <= 1:
|
||||||
return where
|
return where
|
||||||
where_filters = []
|
where_filters = []
|
||||||
for k, v in where.items():
|
for k, v in where.items():
|
||||||
@@ -224,7 +224,7 @@ class ChromaDB(BaseVectorDB):
|
|||||||
input_query,
|
input_query,
|
||||||
],
|
],
|
||||||
n_results=n_results,
|
n_results=n_results,
|
||||||
where=where,
|
where=self._generate_where_clause(where),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
result = self.collection.query(
|
result = self.collection.query(
|
||||||
@@ -232,7 +232,7 @@ class ChromaDB(BaseVectorDB):
|
|||||||
input_query,
|
input_query,
|
||||||
],
|
],
|
||||||
n_results=n_results,
|
n_results=n_results,
|
||||||
where=where,
|
where=self._generate_where_clause(where),
|
||||||
)
|
)
|
||||||
except InvalidDimensionException as e:
|
except InvalidDimensionException as e:
|
||||||
raise InvalidDimensionException(
|
raise InvalidDimensionException(
|
||||||
@@ -275,7 +275,7 @@ class ChromaDB(BaseVectorDB):
|
|||||||
return self.collection.count()
|
return self.collection.count()
|
||||||
|
|
||||||
def delete(self, where):
|
def delete(self, where):
|
||||||
return self.collection.delete(where=where)
|
return self.collection.delete(where=self._generate_where_clause(where))
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "embedchain"
|
name = "embedchain"
|
||||||
version = "0.1.5"
|
version = "0.1.6"
|
||||||
description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
|
description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
|
||||||
authors = [
|
authors = [
|
||||||
"Taranjeet Singh <taranjeet@embedchain.ai>",
|
"Taranjeet Singh <taranjeet@embedchain.ai>",
|
||||||
|
|||||||
Reference in New Issue
Block a user