feat: Add embedding manager (#570)

This commit is contained in:
Taranjeet Singh
2023-09-11 23:43:53 -07:00
committed by GitHub
parent ba208f5b48
commit 2bd6881361
16 changed files with 311 additions and 73 deletions

View File

@@ -22,14 +22,17 @@ class BaseChunker(JSONSerializable):
documents = []
ids = []
idMap = {}
datas = loader.load_data(src)
data_result = loader.load_data(src)
data_records = data_result["data"]
doc_id = data_result["doc_id"]
metadatas = []
for data in datas:
for data in data_records:
content = data["content"]
meta_data = data["meta_data"]
# add data type to meta data to allow query using data type
meta_data["data_type"] = self.data_type.value
meta_data["doc_id"] = doc_id
url = meta_data["url"]
chunks = self.get_chunks(content)
@@ -45,6 +48,7 @@ class BaseChunker(JSONSerializable):
"documents": documents,
"ids": ids,
"metadatas": metadatas,
"doc_id": doc_id,
}
def get_chunks(self, content):

View File

@@ -10,6 +10,7 @@ from typing import Any, Dict, List, Optional, Tuple
import requests
from dotenv import load_dotenv
from langchain.docstore.document import Document
from tenacity import retry, stop_after_attempt, wait_fixed
from embedchain.chunkers.base_chunker import BaseChunker
@@ -179,7 +180,7 @@ class EmbedChain(JSONSerializable):
data_formatter = DataFormatter(data_type, config)
self.user_asks.append([source, data_type.value, metadata])
documents, metadatas, _ids, new_chunks = self.load_and_embed(
documents, metadatas, _ids, new_chunks = self.load_and_embed_v2(
data_formatter.loader, data_formatter.chunker, source, metadata, source_id, dry_run
)
if data_type in {DataType.DOCS_SITE}:
@@ -271,10 +272,11 @@ class EmbedChain(JSONSerializable):
# get existing ids, and discard doc if any common id exist.
where = {"app_id": self.config.id} if self.config.id is not None else {}
# where={"url": src}
existing_ids = self.db.get(
db_result = self.db.get(
ids=ids,
where=where, # optional filter
)
existing_ids = set(db_result["ids"])
if len(existing_ids):
data_dict = {id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)}
@@ -317,6 +319,112 @@ class EmbedChain(JSONSerializable):
print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
return list(documents), metadatas, ids, count_new_chunks
def load_and_embed_v2(
self,
loader: BaseLoader,
chunker: BaseChunker,
src: Any,
metadata: Optional[Dict[str, Any]] = None,
source_id: Optional[str] = None,
dry_run = False
):
"""
Loads the data from the given URL, chunks it, and adds it to database.
:param loader: The loader to use to load the data.
:param chunker: The chunker to use to chunk the data.
:param src: The data to be handled by the loader. Can be a URL for
remote sources or local content for local loaders.
:param metadata: Optional. Metadata associated with the data source.
:param source_id: Hexadecimal hash of the source.
:return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
"""
existing_embeddings_data = self.db.get(
where={
"url": src,
},
limit=1,
)
try:
existing_doc_id = existing_embeddings_data.get("metadatas", [])[0]["doc_id"]
except Exception:
existing_doc_id = None
embeddings_data = chunker.create_chunks(loader, src)
# spread chunking results
documents = embeddings_data["documents"]
metadatas = embeddings_data["metadatas"]
ids = embeddings_data["ids"]
new_doc_id = embeddings_data["doc_id"]
if existing_doc_id and existing_doc_id == new_doc_id:
print("Doc content has not changed. Skipping creating chunks and embeddings")
return [], [], [], 0
# this means that doc content has changed.
if existing_doc_id and existing_doc_id != new_doc_id:
print("Doc content has changed. Recomputing chunks and embeddings intelligently.")
self.db.delete({
"doc_id": existing_doc_id
})
# get existing ids, and discard doc if any common id exist.
where = {"app_id": self.config.id} if self.config.id is not None else {}
# where={"url": src}
db_result = self.db.get(
ids=ids,
where=where, # optional filter
)
existing_ids = set(db_result["ids"])
if len(existing_ids):
data_dict = {id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)}
data_dict = {id: value for id, value in data_dict.items() if id not in existing_ids}
if not data_dict:
print(f"All data from {src} already exists in the database.")
# Make sure to return a matching return type
return [], [], [], 0
ids = list(data_dict.keys())
documents, metadatas = zip(*data_dict.values())
# Loop though all metadatas and add extras.
new_metadatas = []
for m in metadatas:
# Add app id in metadatas so that they can be queried on later
if self.config.id:
m["app_id"] = self.config.id
# Add hashed source
m["hash"] = source_id
# Note: Metadata is the function argument
if metadata:
# Spread whatever is in metadata into the new object.
m.update(metadata)
new_metadatas.append(m)
metadatas = new_metadatas
# Count before, to calculate a delta in the end.
chunks_before_addition = self.count()
self.db.add(documents=documents, metadatas=metadatas, ids=ids)
count_new_chunks = self.count() - chunks_before_addition
print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
return list(documents), metadatas, ids, count_new_chunks
def _format_result(self, results):
return [
(Document(page_content=result[0], metadata=result[1] or {}), result[2])
for result in zip(
results["documents"][0],
results["metadatas"][0],
results["distances"][0],
)
]
def retrieve_from_database(self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None) -> List[str]:
"""
Queries the vector database based on the given input query.

View File

@@ -1,4 +1,5 @@
import csv
import hashlib
from io import StringIO
from urllib.parse import urlparse
@@ -34,7 +35,7 @@ class CsvLoader(BaseLoader):
def load_data(content):
"""Load a csv file with headers. Each line is a document"""
result = []
lines = []
with CsvLoader._get_file_content(content) as file:
first_line = file.readline()
delimiter = CsvLoader._detect_delimiter(first_line)
@@ -42,5 +43,10 @@ class CsvLoader(BaseLoader):
reader = csv.DictReader(file, delimiter=delimiter)
for i, row in enumerate(reader):
line = ", ".join([f"{field}: {value}" for field, value in row.items()])
lines.append(line)
result.append({"content": line, "meta_data": {"url": content, "row": i + 1}})
return result
doc_id = hashlib.sha256((content + " ".join(lines)).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": result
}

View File

@@ -1,3 +1,4 @@
import hashlib
import logging
from urllib.parse import urljoin, urlparse
@@ -99,4 +100,8 @@ class DocsSiteLoader(BaseLoader):
output = []
for u in all_urls:
output.extend(self._load_data_from_url(u))
return output
doc_id = hashlib.sha256((" ".join(all_urls) + url).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": output,
}

View File

@@ -1,3 +1,5 @@
import hashlib
from langchain.document_loaders import Docx2txtLoader
from embedchain.helper.json_serializable import register_deserializable
@@ -15,4 +17,8 @@ class DocxFileLoader(BaseLoader):
meta_data = data[0].metadata
meta_data["url"] = "local"
output.append({"content": content, "meta_data": meta_data})
return output
doc_id = hashlib.sha256((content + url).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": output,
}

View File

@@ -1,3 +1,5 @@
import hashlib
from embedchain.helper.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
@@ -8,12 +10,17 @@ class LocalQnaPairLoader(BaseLoader):
"""Load data from a local QnA pair."""
question, answer = content
content = f"Q: {question}\nA: {answer}"
url = "local"
meta_data = {
"url": "local",
"url": url,
}
doc_id = hashlib.sha256((content + url).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": [
{
"content": content,
"meta_data": meta_data,
}
]
}
return [
{
"content": content,
"meta_data": meta_data,
}
]

View File

@@ -1,3 +1,5 @@
import hashlib
from embedchain.helper.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
@@ -6,12 +8,17 @@ from embedchain.loaders.base_loader import BaseLoader
class LocalTextLoader(BaseLoader):
def load_data(self, content):
"""Load data from a local text file."""
url = "local"
meta_data = {
"url": "local",
"url": url,
}
doc_id = hashlib.sha256((content + url).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": [
{
"content": content,
"meta_data": meta_data,
}
]
}
return [
{
"content": content,
"meta_data": meta_data,
}
]

View File

@@ -1,3 +1,4 @@
import hashlib
import logging
import os
@@ -34,10 +35,13 @@ class NotionLoader(BaseLoader):
# Clean text
text = clean_string(raw_text)
return [
doc_id = hashlib.sha256((text + source).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": [
{
"content": text,
"meta_data": {"url": f"notion-{formatted_id}"},
}
]
],
}

View File

@@ -1,3 +1,5 @@
import hashlib
from langchain.document_loaders import PyPDFLoader
from embedchain.helper.json_serializable import register_deserializable
@@ -10,7 +12,8 @@ class PdfFileLoader(BaseLoader):
def load_data(self, url):
"""Load data from a PDF file."""
loader = PyPDFLoader(url)
output = []
data = []
all_content = []
pages = loader.load_and_split()
if not len(pages):
raise ValueError("No data found")
@@ -19,10 +22,15 @@ class PdfFileLoader(BaseLoader):
content = clean_string(content)
meta_data = page.metadata
meta_data["url"] = url
output.append(
data.append(
{
"content": content,
"meta_data": meta_data,
}
)
return output
all_content.append(content)
doc_id = hashlib.sha256((" ".join(all_content) + url).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": data,
}

View File

@@ -1,3 +1,4 @@
import hashlib
import logging
import requests
@@ -30,6 +31,8 @@ class SitemapLoader(BaseLoader):
# Get all <loc> tags as a fallback. This might include images.
links = [link.text for link in soup.find_all("loc")]
doc_id = hashlib.sha256((" ".join(links) + sitemap_url).encode()).hexdigest()
for link in links:
try:
each_load_data = web_page_loader.load_data(link)
@@ -40,4 +43,7 @@ class SitemapLoader(BaseLoader):
logging.warning(f"Page is not readable (too many invalid characters): {link}")
except ParserRejectedMarkup as e:
logging.error(f"Failed to parse {link}: {e}")
return [data[0] for data in output]
return {
"doc_id": doc_id,
"data": [data[0] for data in output]
}

View File

@@ -1,3 +1,4 @@
import hashlib
import logging
import requests
@@ -63,10 +64,14 @@ class WebPageLoader(BaseLoader):
meta_data = {
"url": url,
}
return [
{
"content": content,
"meta_data": meta_data,
}
]
content = content
doc_id = hashlib.sha256((content + url).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": [
{
"content": content,
"meta_data": meta_data,
}
],
}

View File

@@ -1,3 +1,5 @@
import hashlib
from langchain.document_loaders import YoutubeLoader
from embedchain.helper.json_serializable import register_deserializable
@@ -18,10 +20,15 @@ class YoutubeVideoLoader(BaseLoader):
content = clean_string(content)
meta_data = doc[0].metadata
meta_data["url"] = url
output.append(
{
"content": content,
"meta_data": meta_data,
}
)
return output
doc_id = hashlib.sha256((content + url).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": output,
}

View File

@@ -0,0 +1,50 @@
from embedchain.config.vectordbs.BaseVectorDbConfig import BaseVectorDbConfig
from embedchain.embedder.base_embedder import BaseEmbedder
from embedchain.helper_classes.json_serializable import JSONSerializable
class BaseVectorDB(JSONSerializable):
"""Base class for vector database."""
def __init__(self, config: BaseVectorDbConfig):
self.client = self._get_or_create_db()
self.config: BaseVectorDbConfig = config
def _initialize(self):
"""
This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
So it's can't be done in __init__ in one step.
"""
raise NotImplementedError
def _get_or_create_db(self):
"""Get or create the database."""
raise NotImplementedError
def _get_or_create_collection(self):
raise NotImplementedError
def _set_embedder(self, embedder: BaseEmbedder):
self.embedder = embedder
def get(self):
raise NotImplementedError
def add(self):
raise NotImplementedError
def query(self):
raise NotImplementedError
def count(self):
raise NotImplementedError
def delete(self):
raise NotImplementedError
def reset(self):
raise NotImplementedError
def set_collection_name(self, name: str):
raise NotImplementedError

View File

@@ -1,5 +1,5 @@
import logging
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Any
from chromadb import Collection, QueryResult
from langchain.docstore.document import Document
@@ -87,25 +87,32 @@ class ChromaDB(BaseVectorDB):
)
return self.collection
def get(self, ids: List[str], where: Dict[str, any]) -> List[str]:
def get(self, ids=None, where=None, limit=None):
"""
Get existing doc ids present in vector database
:param ids: list of doc ids to check for existence
:type ids: List[str]
:param where: Optional. to filter data
:type where: Dict[str, any]
:type where: Dict[str, Any]
:return: Existing documents.
:rtype: List[str]
"""
existing_docs = self.collection.get(
ids=ids,
where=where, # optional filter
args = {}
if ids:
args["ids"] = ids
if where:
args["where"] = where
if limit:
args["limit"] = limit
return self.collection.get(
**args
)
return set(existing_docs["ids"])
def get_advanced(self, where):
return self.collection.get(where=where, limit=1)
def add(self, documents: List[str], metadatas: List[object], ids: List[str]):
def add(self, documents: List[str], metadatas: List[object], ids: List[str]) -> Any:
"""
Add vectors to chroma database
@@ -136,7 +143,7 @@ class ChromaDB(BaseVectorDB):
)
]
def query(self, input_query: List[str], n_results: int, where: Dict[str, any]) -> List[str]:
def query(self, input_query: List[str], n_results: int, where: Dict[str, Any]) -> List[str]:
"""
Query contents from vector data base based on vector similarity
@@ -145,7 +152,7 @@ class ChromaDB(BaseVectorDB):
:param n_results: no of similar documents to fetch from database
:type n_results: int
:param where: to filter data
:type where: Dict[str, any]
:type where: Dict[str, Any]
:raises InvalidDimensionException: Dimensions do not match.
:return: The content of the document that matched your query.
:rtype: List[str]
@@ -187,6 +194,9 @@ class ChromaDB(BaseVectorDB):
"""
return self.collection.count()
def delete(self, where):
return self.collection.delete(where=where)
def reset(self):
"""
Resets the database. Deletes all embeddings irreversibly.

View File

@@ -69,9 +69,12 @@ class MockLoader:
Mock loader that returns a list of data dictionaries.
Adjust this method to return different data for testing.
"""
return [
{
"content": src,
"meta_data": {"url": "none"},
}
]
return {
"doc_id": "123",
"data": [
{
"content": src,
"meta_data": {"url": "none"},
}
]
}

View File

@@ -29,18 +29,19 @@ def test_load_data(delimiter):
# Loading CSV using CsvLoader
loader = CsvLoader()
result = loader.load_data(filename)
data = result["data"]
# Assertions
assert len(result) == 3
assert result[0]["content"] == "Name: Alice, Age: 28, Occupation: Engineer"
assert result[0]["meta_data"]["url"] == filename
assert result[0]["meta_data"]["row"] == 1
assert result[1]["content"] == "Name: Bob, Age: 35, Occupation: Doctor"
assert result[1]["meta_data"]["url"] == filename
assert result[1]["meta_data"]["row"] == 2
assert result[2]["content"] == "Name: Charlie, Age: 22, Occupation: Student"
assert result[2]["meta_data"]["url"] == filename
assert result[2]["meta_data"]["row"] == 3
assert len(data) == 3
assert data[0]["content"] == "Name: Alice, Age: 28, Occupation: Engineer"
assert data[0]["meta_data"]["url"] == filename
assert data[0]["meta_data"]["row"] == 1
assert data[1]["content"] == "Name: Bob, Age: 35, Occupation: Doctor"
assert data[1]["meta_data"]["url"] == filename
assert data[1]["meta_data"]["row"] == 2
assert data[2]["content"] == "Name: Charlie, Age: 22, Occupation: Student"
assert data[2]["meta_data"]["url"] == filename
assert data[2]["meta_data"]["row"] == 3
# Cleaning up the temporary file
os.unlink(filename)
@@ -67,18 +68,19 @@ def test_load_data_with_file_uri(delimiter):
# Loading CSV using CsvLoader
loader = CsvLoader()
result = loader.load_data(filename)
data = result["data"]
# Assertions
assert len(result) == 3
assert result[0]["content"] == "Name: Alice, Age: 28, Occupation: Engineer"
assert result[0]["meta_data"]["url"] == filename
assert result[0]["meta_data"]["row"] == 1
assert result[1]["content"] == "Name: Bob, Age: 35, Occupation: Doctor"
assert result[1]["meta_data"]["url"] == filename
assert result[1]["meta_data"]["row"] == 2
assert result[2]["content"] == "Name: Charlie, Age: 22, Occupation: Student"
assert result[2]["meta_data"]["url"] == filename
assert result[2]["meta_data"]["row"] == 3
assert len(data) == 3
assert data[0]["content"] == "Name: Alice, Age: 28, Occupation: Engineer"
assert data[0]["meta_data"]["url"] == filename
assert data[0]["meta_data"]["row"] == 1
assert data[1]["content"] == "Name: Bob, Age: 35, Occupation: Doctor"
assert data[1]["meta_data"]["url"] == filename
assert data[1]["meta_data"]["row"] == 2
assert data[2]["content"] == "Name: Charlie, Age: 22, Occupation: Student"
assert data[2]["meta_data"]["url"] == filename
assert data[2]["meta_data"]["row"] == 3
# Cleaning up the temporary file
os.unlink(tmpfile.name)