[Bug Fix] fix chromadb where clause for query and delete (#937)

Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
Deven Patel
2023-11-10 16:04:25 -08:00
committed by GitHub
parent 744ab5156f
commit deaa7f50f8
5 changed files with 15 additions and 13 deletions

View File

@@ -3,7 +3,8 @@ from typing import Optional
import yaml import yaml
from embedchain.client import Client from embedchain.client import Client
from embedchain.config import AppConfig, BaseEmbedderConfig, BaseLlmConfig, ChunkerConfig from embedchain.config import (AppConfig, BaseEmbedderConfig, BaseLlmConfig,
ChunkerConfig)
from embedchain.config.vectordb.base import BaseVectorDbConfig from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.embedchain import EmbedChain from embedchain.embedchain import EmbedChain
from embedchain.embedder.base import BaseEmbedder from embedchain.embedder.base import BaseEmbedder

View File

@@ -1,9 +1,10 @@
from typing import Any from typing import Any
from embedchain import Pipeline as App from embedchain import Pipeline as App
from embedchain.config import AddConfig, PipelineConfig, BaseLlmConfig from embedchain.config import AddConfig, BaseLlmConfig, PipelineConfig
from embedchain.embedder.openai import OpenAIEmbedder from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.helper.json_serializable import JSONSerializable, register_deserializable from embedchain.helper.json_serializable import (JSONSerializable,
register_deserializable)
from embedchain.llm.openai import OpenAILlm from embedchain.llm.openai import OpenAILlm
from embedchain.vectordb.chroma import ChromaDB from embedchain.vectordb.chroma import ChromaDB

View File

@@ -478,13 +478,13 @@ class EmbedChain(JSONSerializable):
query_config = config or self.llm.config query_config = config or self.llm.config
if where is not None: if where is not None:
where = where where = where
elif query_config is not None and query_config.where is not None:
where = query_config.where
else: else:
where = {} where = {}
if query_config is not None and query_config.where is not None:
if self.config.id is not None: where = query_config.where
where.update({"app_id": self.config.id})
if self.config.id is not None:
where.update({"app_id": self.config.id})
# We cannot query the database with the input query in case of an image search. This is because we need # We cannot query the database with the input query in case of an image search. This is because we need
# to bring down both the image and text to the same dimension to be able to compare them. # to bring down both the image and text to the same dimension to be able to compare them.

View File

@@ -77,7 +77,7 @@ class ChromaDB(BaseVectorDB):
def _generate_where_clause(self, where: Dict[str, any]) -> str: def _generate_where_clause(self, where: Dict[str, any]) -> str:
# If only one filter is supplied, return it as is # If only one filter is supplied, return it as is
# (no need to wrap in $and based on chroma docs) # (no need to wrap in $and based on chroma docs)
if len(where.keys()) == 1: if len(where.keys()) <= 1:
return where return where
where_filters = [] where_filters = []
for k, v in where.items(): for k, v in where.items():
@@ -224,7 +224,7 @@ class ChromaDB(BaseVectorDB):
input_query, input_query,
], ],
n_results=n_results, n_results=n_results,
where=where, where=self._generate_where_clause(where),
) )
else: else:
result = self.collection.query( result = self.collection.query(
@@ -232,7 +232,7 @@ class ChromaDB(BaseVectorDB):
input_query, input_query,
], ],
n_results=n_results, n_results=n_results,
where=where, where=self._generate_where_clause(where),
) )
except InvalidDimensionException as e: except InvalidDimensionException as e:
raise InvalidDimensionException( raise InvalidDimensionException(
@@ -275,7 +275,7 @@ class ChromaDB(BaseVectorDB):
return self.collection.count() return self.collection.count()
def delete(self, where): def delete(self, where):
return self.collection.delete(where=where) return self.collection.delete(where=self._generate_where_clause(where))
def reset(self): def reset(self):
""" """

View File

@@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "embedchain" name = "embedchain"
version = "0.1.5" version = "0.1.6"
description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data" description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
authors = [ authors = [
"Taranjeet Singh <taranjeet@embedchain.ai>", "Taranjeet Singh <taranjeet@embedchain.ai>",