feat: Adding app id in metadata while reading and writing to vector db (#189)
This commit is contained in:
committed by
GitHub
parent
fd97fb268a
commit
d4b8542207
@@ -1,32 +1,34 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from chromadb.utils import embedding_functions
|
from chromadb.utils import embedding_functions
|
||||||
|
|
||||||
from embedchain.config.BaseConfig import BaseConfig
|
from embedchain.config.BaseConfig import BaseConfig
|
||||||
|
|
||||||
|
|
||||||
class InitConfig(BaseConfig):
|
class InitConfig(BaseConfig):
|
||||||
"""
|
"""
|
||||||
Config to initialize an embedchain `App` instance.
|
Config to initialize an embedchain `App` instance.
|
||||||
"""
|
"""
|
||||||
|
def __init__(self, log_level=None, ef=None, db=None, host=None, port=None, id=None):
|
||||||
def __init__(self, log_level=None, ef=None, db=None, host=None, port=None):
|
|
||||||
"""
|
"""
|
||||||
:param log_level: Optional. (String) Debug level
|
:param log_level: Optional. (String) Debug level
|
||||||
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
|
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
|
||||||
:param ef: Optional. Embedding function to use.
|
:param ef: Optional. Embedding function to use.
|
||||||
:param db: Optional. (Vector) database to use for embeddings.
|
:param db: Optional. (Vector) database to use for embeddings.
|
||||||
|
:param id: Optional. ID of the app. Document metadata will have this id.
|
||||||
:param host: Optional. Hostname for the database server.
|
:param host: Optional. Hostname for the database server.
|
||||||
:param port: Optional. Port for the database server.
|
:param port: Optional. Port for the database server.
|
||||||
"""
|
"""
|
||||||
self._setup_logging(log_level)
|
self._setup_logging(log_level)
|
||||||
|
|
||||||
|
if db is None:
|
||||||
|
from embedchain.vectordb.chroma_db import ChromaDB
|
||||||
|
self.db = ChromaDB(ef=self.ef)
|
||||||
|
else:
|
||||||
|
self.db = db
|
||||||
|
|
||||||
self.ef = ef
|
self.ef = ef
|
||||||
self.db = db
|
|
||||||
|
|
||||||
self.host = host
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
|
self.id = id
|
||||||
return
|
return
|
||||||
|
|
||||||
def _set_embedding_function(self, ef):
|
def _set_embedding_function(self, ef):
|
||||||
|
|||||||
@@ -97,9 +97,11 @@ class EmbedChain:
|
|||||||
metadatas = embeddings_data["metadatas"]
|
metadatas = embeddings_data["metadatas"]
|
||||||
ids = embeddings_data["ids"]
|
ids = embeddings_data["ids"]
|
||||||
# get existing ids, and discard doc if any common id exist.
|
# get existing ids, and discard doc if any common id exist.
|
||||||
|
where={"app_id": self.config.id} if self.config.id is not None else {}
|
||||||
|
# where={"url": src}
|
||||||
existing_docs = self.collection.get(
|
existing_docs = self.collection.get(
|
||||||
ids=ids,
|
ids=ids,
|
||||||
# where={"url": src}
|
where=where, # optional filter
|
||||||
)
|
)
|
||||||
existing_ids = set(existing_docs["ids"])
|
existing_ids = set(existing_docs["ids"])
|
||||||
|
|
||||||
@@ -113,6 +115,10 @@ class EmbedChain:
|
|||||||
|
|
||||||
ids = list(data_dict.keys())
|
ids = list(data_dict.keys())
|
||||||
documents, metadatas = zip(*data_dict.values())
|
documents, metadatas = zip(*data_dict.values())
|
||||||
|
|
||||||
|
# Add app id in metadatas so that they can be queried on later
|
||||||
|
if (self.config.id is not None):
|
||||||
|
metadatas = [{**m, "app_id": self.config.id} for m in metadatas]
|
||||||
|
|
||||||
chunks_before_addition = self.count()
|
chunks_before_addition = self.count()
|
||||||
|
|
||||||
@@ -144,11 +150,11 @@ class EmbedChain:
|
|||||||
:param config: The query configuration.
|
:param config: The query configuration.
|
||||||
:return: The content of the document that matched your query.
|
:return: The content of the document that matched your query.
|
||||||
"""
|
"""
|
||||||
|
where = {"app_id": self.config.id} if self.config.id is not None else {} # optional filter
|
||||||
result = self.collection.query(
|
result = self.collection.query(
|
||||||
query_texts=[
|
query_texts=[input_query,],
|
||||||
input_query,
|
|
||||||
],
|
|
||||||
n_results=config.number_documents,
|
n_results=config.number_documents,
|
||||||
|
where=where,
|
||||||
)
|
)
|
||||||
results_formatted = self._format_result(result)
|
results_formatted = self._format_result(result)
|
||||||
contents = [result[0].page_content for result in results_formatted]
|
contents = [result[0].page_content for result in results_formatted]
|
||||||
|
|||||||
2
setup.py
2
setup.py
@@ -33,7 +33,7 @@ setuptools.setup(
|
|||||||
"gpt4all",
|
"gpt4all",
|
||||||
"sentence_transformers",
|
"sentence_transformers",
|
||||||
"docx2txt",
|
"docx2txt",
|
||||||
"pydantic==1.10.8",
|
"pydantic==1.10.8"
|
||||||
],
|
],
|
||||||
extras_require={"dev": ["black", "ruff", "isort", "pytest"]},
|
extras_require={"dev": ["black", "ruff", "isort", "pytest"]},
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user