[Bug Fix] fix chromadb where clause for query and delete (#937)
Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
@@ -3,7 +3,8 @@ from typing import Optional
|
||||
import yaml
|
||||
|
||||
from embedchain.client import Client
|
||||
from embedchain.config import AppConfig, BaseEmbedderConfig, BaseLlmConfig, ChunkerConfig
|
||||
from embedchain.config import (AppConfig, BaseEmbedderConfig, BaseLlmConfig,
|
||||
ChunkerConfig)
|
||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||
from embedchain.embedchain import EmbedChain
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from typing import Any
|
||||
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain.config import AddConfig, PipelineConfig, BaseLlmConfig
|
||||
from embedchain.config import AddConfig, BaseLlmConfig, PipelineConfig
|
||||
from embedchain.embedder.openai import OpenAIEmbedder
|
||||
from embedchain.helper.json_serializable import JSONSerializable, register_deserializable
|
||||
from embedchain.helper.json_serializable import (JSONSerializable,
|
||||
register_deserializable)
|
||||
from embedchain.llm.openai import OpenAILlm
|
||||
from embedchain.vectordb.chroma import ChromaDB
|
||||
|
||||
|
||||
@@ -478,13 +478,13 @@ class EmbedChain(JSONSerializable):
|
||||
query_config = config or self.llm.config
|
||||
if where is not None:
|
||||
where = where
|
||||
elif query_config is not None and query_config.where is not None:
|
||||
where = query_config.where
|
||||
else:
|
||||
where = {}
|
||||
|
||||
if self.config.id is not None:
|
||||
where.update({"app_id": self.config.id})
|
||||
if query_config is not None and query_config.where is not None:
|
||||
where = query_config.where
|
||||
|
||||
if self.config.id is not None:
|
||||
where.update({"app_id": self.config.id})
|
||||
|
||||
# We cannot query the database with the input query in case of an image search. This is because we need
|
||||
# to bring down both the image and text to the same dimension to be able to compare them.
|
||||
|
||||
@@ -77,7 +77,7 @@ class ChromaDB(BaseVectorDB):
|
||||
def _generate_where_clause(self, where: Dict[str, any]) -> str:
|
||||
# If only one filter is supplied, return it as is
|
||||
# (no need to wrap in $and based on chroma docs)
|
||||
if len(where.keys()) == 1:
|
||||
if len(where.keys()) <= 1:
|
||||
return where
|
||||
where_filters = []
|
||||
for k, v in where.items():
|
||||
@@ -224,7 +224,7 @@ class ChromaDB(BaseVectorDB):
|
||||
input_query,
|
||||
],
|
||||
n_results=n_results,
|
||||
where=where,
|
||||
where=self._generate_where_clause(where),
|
||||
)
|
||||
else:
|
||||
result = self.collection.query(
|
||||
@@ -232,7 +232,7 @@ class ChromaDB(BaseVectorDB):
|
||||
input_query,
|
||||
],
|
||||
n_results=n_results,
|
||||
where=where,
|
||||
where=self._generate_where_clause(where),
|
||||
)
|
||||
except InvalidDimensionException as e:
|
||||
raise InvalidDimensionException(
|
||||
@@ -275,7 +275,7 @@ class ChromaDB(BaseVectorDB):
|
||||
return self.collection.count()
|
||||
|
||||
def delete(self, where):
|
||||
return self.collection.delete(where=where)
|
||||
return self.collection.delete(where=self._generate_where_clause(where))
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "embedchain"
|
||||
version = "0.1.5"
|
||||
version = "0.1.6"
|
||||
description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
|
||||
authors = [
|
||||
"Taranjeet Singh <taranjeet@embedchain.ai>",
|
||||
|
||||
Reference in New Issue
Block a user