feat: Add embedding manager (#570)
@@ -22,14 +22,17 @@ class BaseChunker(JSONSerializable):
         documents = []
         ids = []
         idMap = {}
-        datas = loader.load_data(src)
+        data_result = loader.load_data(src)
+        data_records = data_result["data"]
+        doc_id = data_result["doc_id"]
         metadatas = []
-        for data in datas:
+        for data in data_records:
             content = data["content"]

             meta_data = data["meta_data"]
             # add data type to meta data to allow query using data type
             meta_data["data_type"] = self.data_type.value
+            meta_data["doc_id"] = doc_id
             url = meta_data["url"]

             chunks = self.get_chunks(content)
@@ -45,6 +48,7 @@ class BaseChunker(JSONSerializable):
             "documents": documents,
             "ids": ids,
             "metadatas": metadatas,
+            "doc_id": doc_id,
         }

     def get_chunks(self, content):
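With this change every loader result carries a document-level `doc_id`, and `create_chunks` threads it into each chunk's metadata and into its return value. A minimal sketch of the shape a chunker now produces (the values are illustrative, not taken from the diff):

# Hypothetical illustration of the create_chunks() return shape after this change.
chunker_output = {
    "documents": ["chunk one text", "chunk two text"],   # embedded text chunks
    "ids": ["<sha256-of-chunk-1>", "<sha256-of-chunk-2>"],  # one id per chunk
    "metadatas": [
        {"url": "https://example.com", "data_type": "web_page", "doc_id": "abc123"},
        {"url": "https://example.com", "data_type": "web_page", "doc_id": "abc123"},
    ],
    "doc_id": "abc123",  # hash of the whole source document, shared by all its chunks
}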
@@ -10,6 +10,7 @@ from typing import Any, Dict, List, Optional, Tuple

 import requests
 from dotenv import load_dotenv
+from langchain.docstore.document import Document
 from tenacity import retry, stop_after_attempt, wait_fixed

 from embedchain.chunkers.base_chunker import BaseChunker
@@ -179,7 +180,7 @@ class EmbedChain(JSONSerializable):

         data_formatter = DataFormatter(data_type, config)
         self.user_asks.append([source, data_type.value, metadata])
-        documents, metadatas, _ids, new_chunks = self.load_and_embed(
+        documents, metadatas, _ids, new_chunks = self.load_and_embed_v2(
             data_formatter.loader, data_formatter.chunker, source, metadata, source_id, dry_run
         )
         if data_type in {DataType.DOCS_SITE}:
@@ -271,10 +272,11 @@ class EmbedChain(JSONSerializable):
         # get existing ids, and discard doc if any common id exist.
         where = {"app_id": self.config.id} if self.config.id is not None else {}
         # where={"url": src}
-        existing_ids = self.db.get(
+        db_result = self.db.get(
             ids=ids,
             where=where,  # optional filter
         )
+        existing_ids = set(db_result["ids"])

         if len(existing_ids):
             data_dict = {id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)}
@@ -317,6 +319,112 @@ class EmbedChain(JSONSerializable):
         print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
         return list(documents), metadatas, ids, count_new_chunks

+    def load_and_embed_v2(
+        self,
+        loader: BaseLoader,
+        chunker: BaseChunker,
+        src: Any,
+        metadata: Optional[Dict[str, Any]] = None,
+        source_id: Optional[str] = None,
+        dry_run = False
+    ):
+        """
+        Loads the data from the given URL, chunks it, and adds it to database.
+
+        :param loader: The loader to use to load the data.
+        :param chunker: The chunker to use to chunk the data.
+        :param src: The data to be handled by the loader. Can be a URL for
+        remote sources or local content for local loaders.
+        :param metadata: Optional. Metadata associated with the data source.
+        :param source_id: Hexadecimal hash of the source.
+        :return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
+        """
+        existing_embeddings_data = self.db.get(
+            where={
+                "url": src,
+            },
+            limit=1,
+        )
+        try:
+            existing_doc_id = existing_embeddings_data.get("metadatas", [])[0]["doc_id"]
+        except Exception:
+            existing_doc_id = None
+        embeddings_data = chunker.create_chunks(loader, src)
+
+        # spread chunking results
+        documents = embeddings_data["documents"]
+        metadatas = embeddings_data["metadatas"]
+        ids = embeddings_data["ids"]
+        new_doc_id = embeddings_data["doc_id"]
+
+        if existing_doc_id and existing_doc_id == new_doc_id:
+            print("Doc content has not changed. Skipping creating chunks and embeddings")
+            return [], [], [], 0
+
+        # this means that doc content has changed.
+        if existing_doc_id and existing_doc_id != new_doc_id:
+            print("Doc content has changed. Recomputing chunks and embeddings intelligently.")
+            self.db.delete({
+                "doc_id": existing_doc_id
+            })
+
+        # get existing ids, and discard doc if any common id exist.
+        where = {"app_id": self.config.id} if self.config.id is not None else {}
+        # where={"url": src}
+        db_result = self.db.get(
+            ids=ids,
+            where=where,  # optional filter
+        )
+        existing_ids = set(db_result["ids"])
+
+        if len(existing_ids):
+            data_dict = {id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)}
+            data_dict = {id: value for id, value in data_dict.items() if id not in existing_ids}
+
+            if not data_dict:
+                print(f"All data from {src} already exists in the database.")
+                # Make sure to return a matching return type
+                return [], [], [], 0
+
+            ids = list(data_dict.keys())
+            documents, metadatas = zip(*data_dict.values())
+
+        # Loop though all metadatas and add extras.
+        new_metadatas = []
+        for m in metadatas:
+            # Add app id in metadatas so that they can be queried on later
+            if self.config.id:
+                m["app_id"] = self.config.id
+
+            # Add hashed source
+            m["hash"] = source_id
+
+            # Note: Metadata is the function argument
+            if metadata:
+                # Spread whatever is in metadata into the new object.
+                m.update(metadata)
+
+            new_metadatas.append(m)
+        metadatas = new_metadatas
+
+        # Count before, to calculate a delta in the end.
+        chunks_before_addition = self.count()
+
+        self.db.add(documents=documents, metadatas=metadatas, ids=ids)
+        count_new_chunks = self.count() - chunks_before_addition
+        print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
+        return list(documents), metadatas, ids, count_new_chunks
+
+    def _format_result(self, results):
+        return [
+            (Document(page_content=result[0], metadata=result[1] or {}), result[2])
+            for result in zip(
+                results["documents"][0],
+                results["metadatas"][0],
+                results["distances"][0],
+            )
+        ]
+
     def retrieve_from_database(self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None) -> List[str]:
         """
         Queries the vector database based on the given input query.
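The v2 path uses the document-level hash to short-circuit re-embedding: if the stored `doc_id` matches the freshly computed one the source is skipped entirely, and if it differs the stale chunks are deleted before the new ones are added. A rough usage sketch, assuming the package-level `App` wrapper routes `add()` through `load_and_embed_v2` as the hunk above suggests (the URL and app setup are illustrative, and a configured embedder/LLM key is assumed):

from embedchain import App  # assumed entry point for the embedding manager

app = App()
app.add("https://example.com/post")   # first add: chunks are embedded and stored
app.add("https://example.com/post")   # unchanged content: doc_id matches, nothing is re-embedded
# if the page content changes, the stored doc_id no longer matches, so the old
# chunks are deleted and fresh chunks are embedded in their place
app.add("https://example.com/post")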
@@ -1,4 +1,5 @@
 import csv
+import hashlib
 from io import StringIO
 from urllib.parse import urlparse

@@ -34,7 +35,7 @@ class CsvLoader(BaseLoader):
     def load_data(content):
         """Load a csv file with headers. Each line is a document"""
         result = []
+        lines = []
         with CsvLoader._get_file_content(content) as file:
             first_line = file.readline()
             delimiter = CsvLoader._detect_delimiter(first_line)
@@ -42,5 +43,10 @@ class CsvLoader(BaseLoader):
             reader = csv.DictReader(file, delimiter=delimiter)
             for i, row in enumerate(reader):
                 line = ", ".join([f"{field}: {value}" for field, value in row.items()])
+                lines.append(line)
                 result.append({"content": line, "meta_data": {"url": content, "row": i + 1}})
-        return result
+        doc_id = hashlib.sha256((content + " ".join(lines)).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": result
+        }
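Every loader now returns a dict with a `doc_id` (a SHA-256 over the source content plus its URL or path) alongside the list of records it previously returned directly. A minimal sketch of the new contract using the CSV loader, with the module path and file name assumed for illustration:

from embedchain.loaders.csv import CsvLoader  # module path assumed from the hunks above

result = CsvLoader().load_data("people.csv")       # hypothetical local CSV file
print(result["doc_id"])                            # sha256 hex digest of path + row contents
for record in result["data"]:
    print(record["content"], record["meta_data"])  # same per-row dicts as before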
@@ -1,3 +1,4 @@
+import hashlib
 import logging
 from urllib.parse import urljoin, urlparse

@@ -99,4 +100,8 @@ class DocsSiteLoader(BaseLoader):
         output = []
         for u in all_urls:
             output.extend(self._load_data_from_url(u))
-        return output
+        doc_id = hashlib.sha256((" ".join(all_urls) + url).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": output,
+        }
@@ -1,3 +1,5 @@
+import hashlib
+
 from langchain.document_loaders import Docx2txtLoader

 from embedchain.helper.json_serializable import register_deserializable
@@ -15,4 +17,8 @@ class DocxFileLoader(BaseLoader):
         meta_data = data[0].metadata
         meta_data["url"] = "local"
         output.append({"content": content, "meta_data": meta_data})
-        return output
+        doc_id = hashlib.sha256((content + url).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": output,
+        }
@@ -1,3 +1,5 @@
+import hashlib
+
 from embedchain.helper.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader

@@ -8,12 +10,17 @@ class LocalQnaPairLoader(BaseLoader):
         """Load data from a local QnA pair."""
         question, answer = content
         content = f"Q: {question}\nA: {answer}"
+        url = "local"
         meta_data = {
-            "url": "local",
+            "url": url,
         }
-        return [
+        doc_id = hashlib.sha256((content + url).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": [
                 {
                     "content": content,
                     "meta_data": meta_data,
                 }
             ]
+        }
@@ -1,3 +1,5 @@
+import hashlib
+
 from embedchain.helper.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader

@@ -6,12 +8,17 @@ from embedchain.loaders.base_loader import BaseLoader
 class LocalTextLoader(BaseLoader):
     def load_data(self, content):
         """Load data from a local text file."""
+        url = "local"
         meta_data = {
-            "url": "local",
+            "url": url,
         }
-        return [
+        doc_id = hashlib.sha256((content + url).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": [
                 {
                     "content": content,
                     "meta_data": meta_data,
                 }
             ]
+        }
@@ -1,3 +1,4 @@
+import hashlib
 import logging
 import os

@@ -34,10 +35,13 @@ class NotionLoader(BaseLoader):

         # Clean text
         text = clean_string(raw_text)
-        return [
+        doc_id = hashlib.sha256((text + source).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": [
                 {
                     "content": text,
                     "meta_data": {"url": f"notion-{formatted_id}"},
                 }
-        ]
+            ],
+        }
@@ -1,3 +1,5 @@
+import hashlib
+
 from langchain.document_loaders import PyPDFLoader

 from embedchain.helper.json_serializable import register_deserializable
@@ -10,7 +12,8 @@ class PdfFileLoader(BaseLoader):
     def load_data(self, url):
         """Load data from a PDF file."""
         loader = PyPDFLoader(url)
-        output = []
+        data = []
+        all_content = []
         pages = loader.load_and_split()
         if not len(pages):
             raise ValueError("No data found")
@@ -19,10 +22,15 @@ class PdfFileLoader(BaseLoader):
             content = clean_string(content)
             meta_data = page.metadata
             meta_data["url"] = url
-            output.append(
+            data.append(
                 {
                     "content": content,
                     "meta_data": meta_data,
                 }
             )
-        return output
+            all_content.append(content)
+        doc_id = hashlib.sha256((" ".join(all_content) + url).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": data,
+        }
@@ -1,3 +1,4 @@
+import hashlib
 import logging

 import requests
@@ -30,6 +31,8 @@ class SitemapLoader(BaseLoader):
         # Get all <loc> tags as a fallback. This might include images.
         links = [link.text for link in soup.find_all("loc")]

+        doc_id = hashlib.sha256((" ".join(links) + sitemap_url).encode()).hexdigest()
+
         for link in links:
             try:
                 each_load_data = web_page_loader.load_data(link)
@@ -40,4 +43,7 @@ class SitemapLoader(BaseLoader):
                     logging.warning(f"Page is not readable (too many invalid characters): {link}")
             except ParserRejectedMarkup as e:
                 logging.error(f"Failed to parse {link}: {e}")
-        return [data[0] for data in output]
+        return {
+            "doc_id": doc_id,
+            "data": [data[0] for data in output]
+        }
@@ -1,3 +1,4 @@
+import hashlib
 import logging

 import requests
@@ -63,10 +64,14 @@ class WebPageLoader(BaseLoader):
         meta_data = {
             "url": url,
         }
-        return [
+        content = content
+        doc_id = hashlib.sha256((content + url).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": [
                 {
                     "content": content,
                     "meta_data": meta_data,
                 }
-        ]
+            ],
+        }
@@ -1,3 +1,5 @@
+import hashlib
+
 from langchain.document_loaders import YoutubeLoader

 from embedchain.helper.json_serializable import register_deserializable
@@ -18,10 +20,15 @@ class YoutubeVideoLoader(BaseLoader):
         content = clean_string(content)
         meta_data = doc[0].metadata
         meta_data["url"] = url

         output.append(
             {
                 "content": content,
                 "meta_data": meta_data,
             }
         )
-        return output
+        doc_id = hashlib.sha256((content + url).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": output,
+        }
embedchain/vectordb/base_vector_db.py · new file · 50 lines
@@ -0,0 +1,50 @@
+from embedchain.config.vectordbs.BaseVectorDbConfig import BaseVectorDbConfig
+from embedchain.embedder.base_embedder import BaseEmbedder
+from embedchain.helper_classes.json_serializable import JSONSerializable
+
+
+class BaseVectorDB(JSONSerializable):
+    """Base class for vector database."""
+
+    def __init__(self, config: BaseVectorDbConfig):
+        self.client = self._get_or_create_db()
+        self.config: BaseVectorDbConfig = config
+
+    def _initialize(self):
+        """
+        This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
+
+        So it's can't be done in __init__ in one step.
+        """
+        raise NotImplementedError
+
+    def _get_or_create_db(self):
+        """Get or create the database."""
+        raise NotImplementedError
+
+    def _get_or_create_collection(self):
+        raise NotImplementedError
+
+    def _set_embedder(self, embedder: BaseEmbedder):
+        self.embedder = embedder
+
+    def get(self):
+        raise NotImplementedError
+
+    def add(self):
+        raise NotImplementedError
+
+    def query(self):
+        raise NotImplementedError
+
+    def count(self):
+        raise NotImplementedError
+
+    def delete(self):
+        raise NotImplementedError
+
+    def reset(self):
+        raise NotImplementedError
+
+    def set_collection_name(self, name: str):
+        raise NotImplementedError
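The new base class defines the interface every vector store is expected to fill in; the embedder is attached after construction via `_set_embedder`, which is why `_initialize` exists as a separate step. A minimal sketch of a hypothetical in-memory subclass (not part of the commit) showing which hooks a concrete store overrides:

class InMemoryDB(BaseVectorDB):
    """Hypothetical toy store illustrating the BaseVectorDB contract."""

    def _get_or_create_db(self):
        self._store = {}  # id -> (document, metadata)
        return self._store

    def get(self, ids=None, where=None, limit=None):
        # return only the ids that are already stored
        return {"ids": [i for i in (ids or []) if i in self._store]}

    def add(self, documents, metadatas, ids):
        for i, doc, meta in zip(ids, documents, metadatas):
            self._store[i] = (doc, meta)

    def count(self):
        return len(self._store)

    def delete(self, where):
        # toy filter: drop entries whose metadata matches every key in `where`
        self._store = {
            i: (d, m) for i, (d, m) in self._store.items()
            if not all(m.get(k) == v for k, v in where.items())
        }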
@@ -1,5 +1,5 @@
 import logging
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Any

 from chromadb import Collection, QueryResult
 from langchain.docstore.document import Document
@@ -87,25 +87,32 @@ class ChromaDB(BaseVectorDB):
         )
         return self.collection

-    def get(self, ids: List[str], where: Dict[str, any]) -> List[str]:
+    def get(self, ids=None, where=None, limit=None):
         """
         Get existing doc ids present in vector database

         :param ids: list of doc ids to check for existence
         :type ids: List[str]
         :param where: Optional. to filter data
-        :type where: Dict[str, any]
+        :type where: Dict[str, Any]
         :return: Existing documents.
         :rtype: List[str]
         """
-        existing_docs = self.collection.get(
-            ids=ids,
-            where=where,  # optional filter
+        args = {}
+        if ids:
+            args["ids"] = ids
+        if where:
+            args["where"] = where
+        if limit:
+            args["limit"] = limit
+        return self.collection.get(
+            **args
         )

-        return set(existing_docs["ids"])
+    def get_advanced(self, where):
+        return self.collection.get(where=where, limit=1)

-    def add(self, documents: List[str], metadatas: List[object], ids: List[str]):
+    def add(self, documents: List[str], metadatas: List[object], ids: List[str]) -> Any:
         """
         Add vectors to chroma database

@@ -136,7 +143,7 @@ class ChromaDB(BaseVectorDB):
             )
         ]

-    def query(self, input_query: List[str], n_results: int, where: Dict[str, any]) -> List[str]:
+    def query(self, input_query: List[str], n_results: int, where: Dict[str, Any]) -> List[str]:
         """
         Query contents from vector data base based on vector similarity

@@ -145,7 +152,7 @@ class ChromaDB(BaseVectorDB):
         :param n_results: no of similar documents to fetch from database
         :type n_results: int
         :param where: to filter data
-        :type where: Dict[str, any]
+        :type where: Dict[str, Any]
         :raises InvalidDimensionException: Dimensions do not match.
         :return: The content of the document that matched your query.
         :rtype: List[str]
@@ -187,6 +194,9 @@ class ChromaDB(BaseVectorDB):
         """
         return self.collection.count()

+    def delete(self, where):
+        return self.collection.delete(where=where)
+
     def reset(self):
         """
         Resets the database. Deletes all embeddings irreversibly.
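`get` now builds its Chroma arguments dynamically so callers can pass any combination of `ids`, `where` and `limit`, and `delete` forwards a metadata filter straight to the collection. A hedged sketch of how the embedchain code above drives these methods (the constructor arguments and filter values are illustrative):

db = ChromaDB(config=None)  # assumes the default Chroma configuration is acceptable

# look up at most one stored chunk for a source URL to recover its doc_id
existing = db.get(where={"url": "https://example.com/post"}, limit=1)

# check which of a batch of chunk ids already exist, scoped to an app id
present = db.get(ids=["id1", "id2"], where={"app_id": "my-app"})

# drop every chunk that belongs to a stale document hash
db.delete(where={"doc_id": "old-sha256-digest"})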
@@ -69,9 +69,12 @@ class MockLoader:
         Mock loader that returns a list of data dictionaries.
         Adjust this method to return different data for testing.
         """
-        return [
+        return {
+            "doc_id": "123",
+            "data": [
                 {
                     "content": src,
                     "meta_data": {"url": "none"},
                 }
             ]
+        }
@@ -29,18 +29,19 @@ def test_load_data(delimiter):
     # Loading CSV using CsvLoader
     loader = CsvLoader()
     result = loader.load_data(filename)
+    data = result["data"]

     # Assertions
-    assert len(result) == 3
-    assert result[0]["content"] == "Name: Alice, Age: 28, Occupation: Engineer"
-    assert result[0]["meta_data"]["url"] == filename
-    assert result[0]["meta_data"]["row"] == 1
-    assert result[1]["content"] == "Name: Bob, Age: 35, Occupation: Doctor"
-    assert result[1]["meta_data"]["url"] == filename
-    assert result[1]["meta_data"]["row"] == 2
-    assert result[2]["content"] == "Name: Charlie, Age: 22, Occupation: Student"
-    assert result[2]["meta_data"]["url"] == filename
-    assert result[2]["meta_data"]["row"] == 3
+    assert len(data) == 3
+    assert data[0]["content"] == "Name: Alice, Age: 28, Occupation: Engineer"
+    assert data[0]["meta_data"]["url"] == filename
+    assert data[0]["meta_data"]["row"] == 1
+    assert data[1]["content"] == "Name: Bob, Age: 35, Occupation: Doctor"
+    assert data[1]["meta_data"]["url"] == filename
+    assert data[1]["meta_data"]["row"] == 2
+    assert data[2]["content"] == "Name: Charlie, Age: 22, Occupation: Student"
+    assert data[2]["meta_data"]["url"] == filename
+    assert data[2]["meta_data"]["row"] == 3

     # Cleaning up the temporary file
     os.unlink(filename)
@@ -67,18 +68,19 @@ def test_load_data_with_file_uri(delimiter):
     # Loading CSV using CsvLoader
     loader = CsvLoader()
     result = loader.load_data(filename)
+    data = result["data"]

     # Assertions
-    assert len(result) == 3
-    assert result[0]["content"] == "Name: Alice, Age: 28, Occupation: Engineer"
-    assert result[0]["meta_data"]["url"] == filename
-    assert result[0]["meta_data"]["row"] == 1
-    assert result[1]["content"] == "Name: Bob, Age: 35, Occupation: Doctor"
-    assert result[1]["meta_data"]["url"] == filename
-    assert result[1]["meta_data"]["row"] == 2
-    assert result[2]["content"] == "Name: Charlie, Age: 22, Occupation: Student"
-    assert result[2]["meta_data"]["url"] == filename
-    assert result[2]["meta_data"]["row"] == 3
+    assert len(data) == 3
+    assert data[0]["content"] == "Name: Alice, Age: 28, Occupation: Engineer"
+    assert data[0]["meta_data"]["url"] == filename
+    assert data[0]["meta_data"]["row"] == 1
+    assert data[1]["content"] == "Name: Bob, Age: 35, Occupation: Doctor"
+    assert data[1]["meta_data"]["url"] == filename
+    assert data[1]["meta_data"]["row"] == 2
+    assert data[2]["content"] == "Name: Charlie, Age: 22, Occupation: Student"
+    assert data[2]["meta_data"]["url"] == filename
+    assert data[2]["meta_data"]["row"] == 3

     # Cleaning up the temporary file
     os.unlink(tmpfile.name)