feat: Add embedding manager (#570)

This commit is contained in:
Taranjeet Singh
2023-09-11 23:43:53 -07:00
committed by GitHub
parent ba208f5b48
commit 2bd6881361
16 changed files with 311 additions and 73 deletions

View File

@@ -22,14 +22,17 @@ class BaseChunker(JSONSerializable):
         documents = []
         ids = []
         idMap = {}
-        datas = loader.load_data(src)
+        data_result = loader.load_data(src)
+        data_records = data_result["data"]
+        doc_id = data_result["doc_id"]
         metadatas = []
-        for data in datas:
+        for data in data_records:
             content = data["content"]
             meta_data = data["meta_data"]
             # add data type to meta data to allow query using data type
             meta_data["data_type"] = self.data_type.value
+            meta_data["doc_id"] = doc_id
             url = meta_data["url"]
 
             chunks = self.get_chunks(content)
@@ -45,6 +48,7 @@ class BaseChunker(JSONSerializable):
"documents": documents, "documents": documents,
"ids": ids, "ids": ids,
"metadatas": metadatas, "metadatas": metadatas,
"doc_id": doc_id,
} }
def get_chunks(self, content): def get_chunks(self, content):
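In short, loaders now return a dict with a document-level `doc_id` next to the `data` records, and the chunker threads that id through every chunk's metadata and into its own return value. A minimal sketch of the new contract, with a made-up loader standing in for the real embedchain loaders:

```python
import hashlib


def fake_load_data(src: str) -> dict:
    """Hypothetical loader following the new convention: {"doc_id": ..., "data": [...]}."""
    content = f"contents of {src}"
    return {
        "doc_id": hashlib.sha256((content + src).encode()).hexdigest(),
        "data": [{"content": content, "meta_data": {"url": src}}],
    }


def create_chunks_sketch(src: str) -> dict:
    """Simplified mirror of BaseChunker.create_chunks after this change."""
    data_result = fake_load_data(src)
    doc_id = data_result["doc_id"]
    documents, ids, metadatas = [], [], []
    for record in data_result["data"]:
        meta = dict(record["meta_data"], doc_id=doc_id)  # doc_id now rides on every chunk
        chunk = record["content"]  # the real code splits content via get_chunks()
        documents.append(chunk)
        ids.append(hashlib.sha256((chunk + meta["url"]).encode()).hexdigest())
        metadatas.append(meta)
    # doc_id is also surfaced at the top level so callers can compare whole documents
    return {"documents": documents, "ids": ids, "metadatas": metadatas, "doc_id": doc_id}
```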

View File

@@ -10,6 +10,7 @@ from typing import Any, Dict, List, Optional, Tuple
 import requests
 from dotenv import load_dotenv
+from langchain.docstore.document import Document
 from tenacity import retry, stop_after_attempt, wait_fixed
 
 from embedchain.chunkers.base_chunker import BaseChunker
@@ -179,7 +180,7 @@ class EmbedChain(JSONSerializable):
         data_formatter = DataFormatter(data_type, config)
         self.user_asks.append([source, data_type.value, metadata])
-        documents, metadatas, _ids, new_chunks = self.load_and_embed(
+        documents, metadatas, _ids, new_chunks = self.load_and_embed_v2(
             data_formatter.loader, data_formatter.chunker, source, metadata, source_id, dry_run
         )
         if data_type in {DataType.DOCS_SITE}:
@@ -271,10 +272,11 @@ class EmbedChain(JSONSerializable):
         # get existing ids, and discard doc if any common id exist.
         where = {"app_id": self.config.id} if self.config.id is not None else {}
         # where={"url": src}
-        existing_ids = self.db.get(
+        db_result = self.db.get(
             ids=ids,
             where=where,  # optional filter
         )
+        existing_ids = set(db_result["ids"])
 
         if len(existing_ids):
             data_dict = {id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)}
@@ -317,6 +319,112 @@ class EmbedChain(JSONSerializable):
print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}")) print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
return list(documents), metadatas, ids, count_new_chunks return list(documents), metadatas, ids, count_new_chunks
def load_and_embed_v2(
self,
loader: BaseLoader,
chunker: BaseChunker,
src: Any,
metadata: Optional[Dict[str, Any]] = None,
source_id: Optional[str] = None,
dry_run = False
):
"""
Loads the data from the given URL, chunks it, and adds it to database.
:param loader: The loader to use to load the data.
:param chunker: The chunker to use to chunk the data.
:param src: The data to be handled by the loader. Can be a URL for
remote sources or local content for local loaders.
:param metadata: Optional. Metadata associated with the data source.
:param source_id: Hexadecimal hash of the source.
:return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
"""
existing_embeddings_data = self.db.get(
where={
"url": src,
},
limit=1,
)
try:
existing_doc_id = existing_embeddings_data.get("metadatas", [])[0]["doc_id"]
except Exception:
existing_doc_id = None
embeddings_data = chunker.create_chunks(loader, src)
# spread chunking results
documents = embeddings_data["documents"]
metadatas = embeddings_data["metadatas"]
ids = embeddings_data["ids"]
new_doc_id = embeddings_data["doc_id"]
if existing_doc_id and existing_doc_id == new_doc_id:
print("Doc content has not changed. Skipping creating chunks and embeddings")
return [], [], [], 0
# this means that doc content has changed.
if existing_doc_id and existing_doc_id != new_doc_id:
print("Doc content has changed. Recomputing chunks and embeddings intelligently.")
self.db.delete({
"doc_id": existing_doc_id
})
# get existing ids, and discard doc if any common id exist.
where = {"app_id": self.config.id} if self.config.id is not None else {}
# where={"url": src}
db_result = self.db.get(
ids=ids,
where=where, # optional filter
)
existing_ids = set(db_result["ids"])
if len(existing_ids):
data_dict = {id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)}
data_dict = {id: value for id, value in data_dict.items() if id not in existing_ids}
if not data_dict:
print(f"All data from {src} already exists in the database.")
# Make sure to return a matching return type
return [], [], [], 0
ids = list(data_dict.keys())
documents, metadatas = zip(*data_dict.values())
# Loop though all metadatas and add extras.
new_metadatas = []
for m in metadatas:
# Add app id in metadatas so that they can be queried on later
if self.config.id:
m["app_id"] = self.config.id
# Add hashed source
m["hash"] = source_id
# Note: Metadata is the function argument
if metadata:
# Spread whatever is in metadata into the new object.
m.update(metadata)
new_metadatas.append(m)
metadatas = new_metadatas
# Count before, to calculate a delta in the end.
chunks_before_addition = self.count()
self.db.add(documents=documents, metadatas=metadatas, ids=ids)
count_new_chunks = self.count() - chunks_before_addition
print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
return list(documents), metadatas, ids, count_new_chunks
def _format_result(self, results):
return [
(Document(page_content=result[0], metadata=result[1] or {}), result[2])
for result in zip(
results["documents"][0],
results["metadatas"][0],
results["distances"][0],
)
]
def retrieve_from_database(self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None) -> List[str]: def retrieve_from_database(self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None) -> List[str]:
""" """
Queries the vector database based on the given input query. Queries the vector database based on the given input query.
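The heart of the new method is the `doc_id` comparison: if embeddings already exist for the same `url` and carry the same `doc_id`, the whole source is skipped; if the id differs, the stale chunks are deleted before re-embedding. A rough sketch of that decision, where `db` is a stand-in for the vector database wrapper rather than the exact embedchain API:

```python
def decide_embedding_action(db, src: str, new_doc_id: str) -> str:
    """Sketch of the doc_id check in load_and_embed_v2; `db` is a hypothetical wrapper."""
    existing = db.get(where={"url": src}, limit=1)
    try:
        existing_doc_id = existing.get("metadatas", [])[0]["doc_id"]
    except (IndexError, KeyError):
        existing_doc_id = None

    if existing_doc_id is None:
        return "embed"      # first time this source is added
    if existing_doc_id == new_doc_id:
        return "skip"       # content unchanged, keep existing embeddings
    db.delete({"doc_id": existing_doc_id})  # content changed, drop stale chunks
    return "re-embed"
```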

View File

@@ -1,4 +1,5 @@
 import csv
+import hashlib
 from io import StringIO
 from urllib.parse import urlparse
@@ -34,7 +35,7 @@ class CsvLoader(BaseLoader):
     def load_data(content):
         """Load a csv file with headers. Each line is a document"""
         result = []
+        lines = []
         with CsvLoader._get_file_content(content) as file:
             first_line = file.readline()
             delimiter = CsvLoader._detect_delimiter(first_line)
@@ -42,5 +43,10 @@ class CsvLoader(BaseLoader):
             reader = csv.DictReader(file, delimiter=delimiter)
             for i, row in enumerate(reader):
                 line = ", ".join([f"{field}: {value}" for field, value in row.items()])
+                lines.append(line)
                 result.append({"content": line, "meta_data": {"url": content, "row": i + 1}})
-        return result
+        doc_id = hashlib.sha256((content + " ".join(lines)).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": result
+        }
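The same `doc_id` recipe repeats across the loaders that follow: SHA-256 over the full content concatenated with the source identifier, so a change to either yields a new id. A tiny illustration with made-up values:

```python
import hashlib


def compute_doc_id(content: str, source: str) -> str:
    # Same pattern the loaders use: hash of the content plus the source identifier.
    return hashlib.sha256((content + source).encode()).hexdigest()


before = compute_doc_id("Name: Alice, Age: 28", "people.csv")
after = compute_doc_id("Name: Alice, Age: 29", "people.csv")
assert before != after  # edited content produces a different doc_id, triggering re-embedding
```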

View File

@@ -1,3 +1,4 @@
+import hashlib
 import logging
 from urllib.parse import urljoin, urlparse
@@ -99,4 +100,8 @@ class DocsSiteLoader(BaseLoader):
         output = []
         for u in all_urls:
             output.extend(self._load_data_from_url(u))
-        return output
+        doc_id = hashlib.sha256((" ".join(all_urls) + url).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": output,
+        }

View File

@@ -1,3 +1,5 @@
+import hashlib
+
 from langchain.document_loaders import Docx2txtLoader
 
 from embedchain.helper.json_serializable import register_deserializable
@@ -15,4 +17,8 @@ class DocxFileLoader(BaseLoader):
         meta_data = data[0].metadata
         meta_data["url"] = "local"
         output.append({"content": content, "meta_data": meta_data})
-        return output
+        doc_id = hashlib.sha256((content + url).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": output,
+        }

View File

@@ -1,3 +1,5 @@
+import hashlib
+
 from embedchain.helper.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
@@ -8,12 +10,17 @@ class LocalQnaPairLoader(BaseLoader):
"""Load data from a local QnA pair.""" """Load data from a local QnA pair."""
question, answer = content question, answer = content
content = f"Q: {question}\nA: {answer}" content = f"Q: {question}\nA: {answer}"
url = "local"
meta_data = { meta_data = {
"url": "local", "url": url,
}
doc_id = hashlib.sha256((content + url).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": [
{
"content": content,
"meta_data": meta_data,
}
]
} }
return [
{
"content": content,
"meta_data": meta_data,
}
]

View File

@@ -1,3 +1,5 @@
+import hashlib
+
 from embedchain.helper.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
@@ -6,12 +8,17 @@ from embedchain.loaders.base_loader import BaseLoader
 class LocalTextLoader(BaseLoader):
     def load_data(self, content):
         """Load data from a local text file."""
+        url = "local"
         meta_data = {
-            "url": "local",
+            "url": url,
+        }
+        doc_id = hashlib.sha256((content + url).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": [
+                {
+                    "content": content,
+                    "meta_data": meta_data,
+                }
+            ]
         }
-        return [
-            {
-                "content": content,
-                "meta_data": meta_data,
-            }
-        ]

View File

@@ -1,3 +1,4 @@
+import hashlib
 import logging
 import os
@@ -34,10 +35,13 @@ class NotionLoader(BaseLoader):
         # Clean text
         text = clean_string(raw_text)
 
-        return [
+        doc_id = hashlib.sha256((text + source).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": [
             {
                 "content": text,
                 "meta_data": {"url": f"notion-{formatted_id}"},
             }
-        ]
+            ],
+        }

View File

@@ -1,3 +1,5 @@
+import hashlib
+
 from langchain.document_loaders import PyPDFLoader
 
 from embedchain.helper.json_serializable import register_deserializable
@@ -10,7 +12,8 @@ class PdfFileLoader(BaseLoader):
     def load_data(self, url):
         """Load data from a PDF file."""
         loader = PyPDFLoader(url)
-        output = []
+        data = []
+        all_content = []
         pages = loader.load_and_split()
         if not len(pages):
             raise ValueError("No data found")
@@ -19,10 +22,15 @@ class PdfFileLoader(BaseLoader):
             content = clean_string(content)
             meta_data = page.metadata
             meta_data["url"] = url
-            output.append(
+            data.append(
                 {
                     "content": content,
                     "meta_data": meta_data,
                 }
             )
-        return output
+            all_content.append(content)
+        doc_id = hashlib.sha256((" ".join(all_content) + url).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": data,
+        }
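Multi-record loaders such as this PDF loader hash the concatenation of all page contents plus the URL, so editing any single page changes the document-level id. A compact illustration (page texts and path are invented):

```python
import hashlib

pages = ["page one text", "page two text", "page three text"]
url = "report.pdf"  # hypothetical source path

# Join every page's cleaned content before hashing, mirroring the loader above.
doc_id = hashlib.sha256((" ".join(pages) + url).encode()).hexdigest()
```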

View File

@@ -1,3 +1,4 @@
+import hashlib
 import logging
 
 import requests
@@ -30,6 +31,8 @@ class SitemapLoader(BaseLoader):
         # Get all <loc> tags as a fallback. This might include images.
         links = [link.text for link in soup.find_all("loc")]
 
+        doc_id = hashlib.sha256((" ".join(links) + sitemap_url).encode()).hexdigest()
+
         for link in links:
             try:
                 each_load_data = web_page_loader.load_data(link)
@@ -40,4 +43,7 @@ class SitemapLoader(BaseLoader):
logging.warning(f"Page is not readable (too many invalid characters): {link}") logging.warning(f"Page is not readable (too many invalid characters): {link}")
except ParserRejectedMarkup as e: except ParserRejectedMarkup as e:
logging.error(f"Failed to parse {link}: {e}") logging.error(f"Failed to parse {link}: {e}")
return [data[0] for data in output] return {
"doc_id": doc_id,
"data": [data[0] for data in output]
}

View File

@@ -1,3 +1,4 @@
+import hashlib
 import logging
 
 import requests
@@ -63,10 +64,14 @@ class WebPageLoader(BaseLoader):
         meta_data = {
             "url": url,
         }
-        content = content
-        return [
-            {
-                "content": content,
-                "meta_data": meta_data,
-            }
-        ]
+        doc_id = hashlib.sha256((content + url).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": [
+                {
+                    "content": content,
+                    "meta_data": meta_data,
+                }
+            ],
+        }

View File

@@ -1,3 +1,5 @@
+import hashlib
+
 from langchain.document_loaders import YoutubeLoader
 
 from embedchain.helper.json_serializable import register_deserializable
@@ -18,10 +20,15 @@ class YoutubeVideoLoader(BaseLoader):
         content = clean_string(content)
         meta_data = doc[0].metadata
         meta_data["url"] = url
         output.append(
             {
                 "content": content,
                 "meta_data": meta_data,
             }
         )
-        return output
+        doc_id = hashlib.sha256((content + url).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": output,
+        }

View File

@@ -0,0 +1,50 @@
+from embedchain.config.vectordbs.BaseVectorDbConfig import BaseVectorDbConfig
+from embedchain.embedder.base_embedder import BaseEmbedder
+from embedchain.helper_classes.json_serializable import JSONSerializable
+
+
+class BaseVectorDB(JSONSerializable):
+    """Base class for vector database."""
+
+    def __init__(self, config: BaseVectorDbConfig):
+        self.client = self._get_or_create_db()
+        self.config: BaseVectorDbConfig = config
+
+    def _initialize(self):
+        """
+        This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
+
+        So it's can't be done in __init__ in one step.
+        """
+        raise NotImplementedError
+
+    def _get_or_create_db(self):
+        """Get or create the database."""
+        raise NotImplementedError
+
+    def _get_or_create_collection(self):
+        raise NotImplementedError
+
+    def _set_embedder(self, embedder: BaseEmbedder):
+        self.embedder = embedder
+
+    def get(self):
+        raise NotImplementedError
+
+    def add(self):
+        raise NotImplementedError
+
+    def query(self):
+        raise NotImplementedError
+
+    def count(self):
+        raise NotImplementedError
+
+    def delete(self):
+        raise NotImplementedError
+
+    def reset(self):
+        raise NotImplementedError
+
+    def set_collection_name(self, name: str):
+        raise NotImplementedError
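The base class above is mostly a set of NotImplementedError hooks; a concrete backend fills in the ones it supports. A toy in-memory subclass, purely illustrative and not a real embedchain backend, assuming `BaseVectorDB` from the file above is importable:

```python
from typing import Dict, List, Optional


class InMemoryVectorDB(BaseVectorDB):  # BaseVectorDB as defined in the file above
    """Toy backend showing which hooks a concrete vector DB overrides."""

    def _get_or_create_db(self):
        self._store: Dict[str, dict] = {}  # id -> {"document": ..., "metadata": ...}
        return self._store

    def _initialize(self):
        # Called once the embedder has been attached via _set_embedder().
        assert self.embedder is not None

    def get(self, ids: Optional[List[str]] = None, where: Optional[dict] = None, limit: Optional[int] = None):
        hits = [
            (id_, rec)
            for id_, rec in self._store.items()
            if (ids is None or id_ in ids)
            and (where is None or all(rec["metadata"].get(k) == v for k, v in where.items()))
        ]
        hits = hits[:limit] if limit else hits
        return {"ids": [i for i, _ in hits], "metadatas": [r["metadata"] for _, r in hits]}

    def add(self, documents, metadatas, ids):
        for doc, meta, id_ in zip(documents, metadatas, ids):
            self._store[id_] = {"document": doc, "metadata": meta}

    def count(self):
        return len(self._store)

    def delete(self, where):
        for id_ in self.get(where=where)["ids"]:
            self._store.pop(id_, None)
```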

View File

@@ -1,5 +1,5 @@
 import logging
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Any
 
 from chromadb import Collection, QueryResult
 from langchain.docstore.document import Document
@@ -87,25 +87,32 @@ class ChromaDB(BaseVectorDB):
         )
         return self.collection
 
-    def get(self, ids: List[str], where: Dict[str, any]) -> List[str]:
+    def get(self, ids=None, where=None, limit=None):
         """
         Get existing doc ids present in vector database
 
         :param ids: list of doc ids to check for existence
         :type ids: List[str]
         :param where: Optional. to filter data
-        :type where: Dict[str, any]
+        :type where: Dict[str, Any]
         :return: Existing documents.
         :rtype: List[str]
         """
-        existing_docs = self.collection.get(
-            ids=ids,
-            where=where,  # optional filter
+        args = {}
+        if ids:
+            args["ids"] = ids
+        if where:
+            args["where"] = where
+        if limit:
+            args["limit"] = limit
+        return self.collection.get(
+            **args
         )
-        return set(existing_docs["ids"])
+
+    def get_advanced(self, where):
+        return self.collection.get(where=where, limit=1)
 
-    def add(self, documents: List[str], metadatas: List[object], ids: List[str]):
+    def add(self, documents: List[str], metadatas: List[object], ids: List[str]) -> Any:
         """
         Add vectors to chroma database
@@ -136,7 +143,7 @@ class ChromaDB(BaseVectorDB):
             )
         ]
 
-    def query(self, input_query: List[str], n_results: int, where: Dict[str, any]) -> List[str]:
+    def query(self, input_query: List[str], n_results: int, where: Dict[str, Any]) -> List[str]:
         """
         Query contents from vector data base based on vector similarity
@@ -145,7 +152,7 @@ class ChromaDB(BaseVectorDB):
         :param n_results: no of similar documents to fetch from database
         :type n_results: int
         :param where: to filter data
-        :type where: Dict[str, any]
+        :type where: Dict[str, Any]
         :raises InvalidDimensionException: Dimensions do not match.
         :return: The content of the document that matched your query.
         :rtype: List[str]
@@ -187,6 +194,9 @@ class ChromaDB(BaseVectorDB):
""" """
return self.collection.count() return self.collection.count()
def delete(self, where):
return self.collection.delete(where=where)
def reset(self): def reset(self):
""" """
Resets the database. Deletes all embeddings irreversibly. Resets the database. Deletes all embeddings irreversibly.
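Taken together, the keyword-style `get` and the new `delete` are what `load_and_embed_v2` leans on for dedup and change detection. A short usage sketch, assuming `db` is an already-initialized ChromaDB wrapper like the one above and the ids and URLs are placeholders:

```python
# Dedup: which of these chunk ids already exist for this app?
existing = db.get(ids=["chunk-id-1", "chunk-id-2"], where={"app_id": "my_app"})
existing_ids = set(existing["ids"])

# Change detection: peek at one stored record for the source URL and read its doc_id.
probe = db.get(where={"url": "https://example.com/page"}, limit=1)
old_doc_id = probe["metadatas"][0].get("doc_id") if probe["metadatas"] else None

# If the document changed, drop every chunk carrying the stale doc_id.
if old_doc_id is not None:
    db.delete(where={"doc_id": old_doc_id})
```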

View File

@@ -69,9 +69,12 @@ class MockLoader:
         Mock loader that returns a list of data dictionaries.
         Adjust this method to return different data for testing.
         """
-        return [
-            {
-                "content": src,
-                "meta_data": {"url": "none"},
-            }
-        ]
+        return {
+            "doc_id": "123",
+            "data": [
+                {
+                    "content": src,
+                    "meta_data": {"url": "none"},
+                }
+            ]
+        }

View File

@@ -29,18 +29,19 @@ def test_load_data(delimiter):
     # Loading CSV using CsvLoader
     loader = CsvLoader()
     result = loader.load_data(filename)
+    data = result["data"]
 
     # Assertions
-    assert len(result) == 3
-    assert result[0]["content"] == "Name: Alice, Age: 28, Occupation: Engineer"
-    assert result[0]["meta_data"]["url"] == filename
-    assert result[0]["meta_data"]["row"] == 1
-    assert result[1]["content"] == "Name: Bob, Age: 35, Occupation: Doctor"
-    assert result[1]["meta_data"]["url"] == filename
-    assert result[1]["meta_data"]["row"] == 2
-    assert result[2]["content"] == "Name: Charlie, Age: 22, Occupation: Student"
-    assert result[2]["meta_data"]["url"] == filename
-    assert result[2]["meta_data"]["row"] == 3
+    assert len(data) == 3
+    assert data[0]["content"] == "Name: Alice, Age: 28, Occupation: Engineer"
+    assert data[0]["meta_data"]["url"] == filename
+    assert data[0]["meta_data"]["row"] == 1
+    assert data[1]["content"] == "Name: Bob, Age: 35, Occupation: Doctor"
+    assert data[1]["meta_data"]["url"] == filename
+    assert data[1]["meta_data"]["row"] == 2
+    assert data[2]["content"] == "Name: Charlie, Age: 22, Occupation: Student"
+    assert data[2]["meta_data"]["url"] == filename
+    assert data[2]["meta_data"]["row"] == 3
 
     # Cleaning up the temporary file
     os.unlink(filename)
@@ -67,18 +68,19 @@ def test_load_data_with_file_uri(delimiter):
     # Loading CSV using CsvLoader
     loader = CsvLoader()
     result = loader.load_data(filename)
+    data = result["data"]
 
     # Assertions
-    assert len(result) == 3
-    assert result[0]["content"] == "Name: Alice, Age: 28, Occupation: Engineer"
-    assert result[0]["meta_data"]["url"] == filename
-    assert result[0]["meta_data"]["row"] == 1
-    assert result[1]["content"] == "Name: Bob, Age: 35, Occupation: Doctor"
-    assert result[1]["meta_data"]["url"] == filename
-    assert result[1]["meta_data"]["row"] == 2
-    assert result[2]["content"] == "Name: Charlie, Age: 22, Occupation: Student"
-    assert result[2]["meta_data"]["url"] == filename
-    assert result[2]["meta_data"]["row"] == 3
+    assert len(data) == 3
+    assert data[0]["content"] == "Name: Alice, Age: 28, Occupation: Engineer"
+    assert data[0]["meta_data"]["url"] == filename
+    assert data[0]["meta_data"]["row"] == 1
+    assert data[1]["content"] == "Name: Bob, Age: 35, Occupation: Doctor"
+    assert data[1]["meta_data"]["url"] == filename
+    assert data[1]["meta_data"]["row"] == 2
+    assert data[2]["content"] == "Name: Charlie, Age: 22, Occupation: Student"
+    assert data[2]["meta_data"]["url"] == filename
+    assert data[2]["meta_data"]["row"] == 3
 
     # Cleaning up the temporary file
    os.unlink(tmpfile.name)