[Feature]: Add support for creating app using yaml config (#787)

This commit is contained in:
Deshraj Yadav
2023-10-12 15:35:49 -07:00
committed by GitHub
parent 4820ea15d6
commit a86d7f52e9
36 changed files with 479 additions and 95 deletions

View File

@@ -37,7 +37,7 @@ class ChromaDB(BaseVectorDB):
self.config = ChromaDbConfig()
self.settings = Settings()
self.settings.allow_reset = self.config.allow_reset
self.settings.allow_reset = self.config.allow_reset if hasattr(self.config, "allow_reset") else False
if self.config.chroma_settings:
for key, value in self.config.chroma_settings.items():
if hasattr(self.settings, key):
@@ -72,6 +72,17 @@ class ChromaDB(BaseVectorDB):
"""Called during initialization"""
return self.client
def _generate_where_clause(self, where: Dict[str, any]) -> str:
# If only one filter is supplied, return it as is
# (no need to wrap in $and based on chroma docs)
if len(where.keys()) == 1:
return where
where_filters = []
for k, v in where.items():
if isinstance(v, str):
where_filters.append({k: v})
return {"$and": where_filters}
def _get_or_create_collection(self, name: str) -> Collection:
"""
Get or create a named collection.
@@ -107,13 +118,14 @@ class ChromaDB(BaseVectorDB):
if ids:
args["ids"] = ids
if where:
args["where"] = where
args["where"] = self._generate_where_clause(where)
if limit:
args["limit"] = limit
return self.collection.get(**args)
def get_advanced(self, where):
    """
    Fetch at most one record from the collection that matches the filters.

    :param where: metadata filters; normalised with
        ``self._generate_where_clause`` before being passed on
    :return: whatever ``self.collection.get`` returns for the normalised
        filter with ``limit=1``
    """
    # Normalise first so that multi-key filters reach the collection in the
    # combined form; returning before this step would bypass it.
    where_clause = self._generate_where_clause(where)
    return self.collection.get(where=where_clause, limit=1)
def add(
self,

View File

@@ -110,8 +110,13 @@ class OpenSearchDB(BaseVectorDB):
return result
def add(
self, embeddings: List[List[str]], documents: List[str], metadatas: List[object], ids: List[str],
skip_embedding: bool):
self,
embeddings: List[List[str]],
documents: List[str],
metadatas: List[object],
ids: List[str],
skip_embedding: bool,
):
"""add data in vector database
:param embeddings: list of embeddings to add
@@ -162,7 +167,8 @@ class OpenSearchDB(BaseVectorDB):
embedding_function=embeddings,
opensearch_url=f"{self.config.opensearch_url}",
http_auth=self.config.http_auth,
use_ssl=True,
use_ssl=hasattr(self.config, "use_ssl") and self.config.use_ssl,
verify_certs=hasattr(self.config, "verify_certs") and self.config.verify_certs,
)
pre_filter = {"match_all": {}} # default

View File

@@ -5,8 +5,8 @@ from embedchain.helper.json_serializable import register_deserializable
from embedchain.vectordb.base import BaseVectorDB
try:
from pymilvus import MilvusClient
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
from pymilvus import (Collection, CollectionSchema, DataType, FieldSchema,
MilvusClient, connections, utility)
except ImportError:
raise ImportError(
"Zilliz requires extra dependencies. Install with `pip install --upgrade embedchain[milvus]`"