diff --git a/embedchain/chunkers/base_chunker.py b/embedchain/chunkers/base_chunker.py
index 3787ff5b..42c030be 100644
--- a/embedchain/chunkers/base_chunker.py
+++ b/embedchain/chunkers/base_chunker.py
@@ -22,14 +22,17 @@ class BaseChunker(JSONSerializable):
         documents = []
         ids = []
         idMap = {}
-        datas = loader.load_data(src)
+        data_result = loader.load_data(src)
+        data_records = data_result["data"]
+        doc_id = data_result["doc_id"]
         metadatas = []
-        for data in datas:
+        for data in data_records:
             content = data["content"]
             meta_data = data["meta_data"]
             # add data type to meta data to allow query using data type
             meta_data["data_type"] = self.data_type.value
+            meta_data["doc_id"] = doc_id
             url = meta_data["url"]

             chunks = self.get_chunks(content)

@@ -45,6 +48,7 @@ class BaseChunker(JSONSerializable):
             "documents": documents,
             "ids": ids,
             "metadatas": metadatas,
+            "doc_id": doc_id,
         }

     def get_chunks(self, content):
diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py
index e91b9fe1..a2019b75 100644
--- a/embedchain/embedchain.py
+++ b/embedchain/embedchain.py
@@ -10,6 +10,7 @@ from typing import Any, Dict, List, Optional, Tuple

 import requests
 from dotenv import load_dotenv
+from langchain.docstore.document import Document
 from tenacity import retry, stop_after_attempt, wait_fixed

 from embedchain.chunkers.base_chunker import BaseChunker
@@ -179,7 +180,7 @@ class EmbedChain(JSONSerializable):
         data_formatter = DataFormatter(data_type, config)
         self.user_asks.append([source, data_type.value, metadata])

-        documents, metadatas, _ids, new_chunks = self.load_and_embed(
+        documents, metadatas, _ids, new_chunks = self.load_and_embed_v2(
             data_formatter.loader, data_formatter.chunker, source, metadata, source_id, dry_run
         )
         if data_type in {DataType.DOCS_SITE}:
@@ -271,10 +272,11 @@ class EmbedChain(JSONSerializable):
         # get existing ids, and discard doc if any common id exist.
         where = {"app_id": self.config.id} if self.config.id is not None else {}
         # where={"url": src}
-        existing_ids = self.db.get(
+        db_result = self.db.get(
             ids=ids,
             where=where,  # optional filter
         )
+        existing_ids = set(db_result["ids"])

         if len(existing_ids):
             data_dict = {id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)}
@@ -317,6 +319,112 @@ class EmbedChain(JSONSerializable):
         print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
         return list(documents), metadatas, ids, count_new_chunks

+    def load_and_embed_v2(
+        self,
+        loader: BaseLoader,
+        chunker: BaseChunker,
+        src: Any,
+        metadata: Optional[Dict[str, Any]] = None,
+        source_id: Optional[str] = None,
+        dry_run=False
+    ):
+        """
+        Loads the data from the given URL, chunks it, and adds it to the database.
+
+        :param loader: The loader to use to load the data.
+        :param chunker: The chunker to use to chunk the data.
+        :param src: The data to be handled by the loader. Can be a URL for
+        remote sources or local content for local loaders.
+        :param metadata: Optional. Metadata associated with the data source.
+        :param source_id: Hexadecimal hash of the source.
+        :return: (List) documents (embedded text), (List) metadata, (List) ids, (int) number of chunks
+        """
+        existing_embeddings_data = self.db.get(
+            where={
+                "url": src,
+            },
+            limit=1,
+        )
+        try:
+            existing_doc_id = existing_embeddings_data.get("metadatas", [])[0]["doc_id"]
+        except Exception:
+            existing_doc_id = None
+        embeddings_data = chunker.create_chunks(loader, src)
+
+        # spread chunking results
+        documents = embeddings_data["documents"]
+        metadatas = embeddings_data["metadatas"]
+        ids = embeddings_data["ids"]
+        new_doc_id = embeddings_data["doc_id"]
+
+        if existing_doc_id and existing_doc_id == new_doc_id:
+            print("Doc content has not changed. Skipping creating chunks and embeddings")
+            return [], [], [], 0
+
+        # this means that doc content has changed.
+        if existing_doc_id and existing_doc_id != new_doc_id:
+            print("Doc content has changed. Recomputing chunks and embeddings intelligently.")
+            self.db.delete({
+                "doc_id": existing_doc_id
+            })
+
+        # get existing ids, and discard doc if any common id exist.
+        where = {"app_id": self.config.id} if self.config.id is not None else {}
+        # where={"url": src}
+        db_result = self.db.get(
+            ids=ids,
+            where=where,  # optional filter
+        )
+        existing_ids = set(db_result["ids"])
+
+        if len(existing_ids):
+            data_dict = {id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)}
+            data_dict = {id: value for id, value in data_dict.items() if id not in existing_ids}
+
+            if not data_dict:
+                print(f"All data from {src} already exists in the database.")
+                # Make sure to return a matching return type
+                return [], [], [], 0
+
+            ids = list(data_dict.keys())
+            documents, metadatas = zip(*data_dict.values())
+
+        # Loop through all metadatas and add extras.
+        new_metadatas = []
+        for m in metadatas:
+            # Add app id in metadatas so that they can be queried on later
+            if self.config.id:
+                m["app_id"] = self.config.id
+
+            # Add hashed source
+            m["hash"] = source_id
+
+            # Note: Metadata is the function argument
+            if metadata:
+                # Spread whatever is in metadata into the new object.
+                m.update(metadata)
+
+            new_metadatas.append(m)
+        metadatas = new_metadatas
+
+        # Count before, to calculate a delta in the end.
+        chunks_before_addition = self.count()
+
+        self.db.add(documents=documents, metadatas=metadatas, ids=ids)
+        count_new_chunks = self.count() - chunks_before_addition
+        print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
+        return list(documents), metadatas, ids, count_new_chunks
+
+    def _format_result(self, results):
+        return [
+            (Document(page_content=result[0], metadata=result[1] or {}), result[2])
+            for result in zip(
+                results["documents"][0],
+                results["metadatas"][0],
+                results["distances"][0],
+            )
+        ]
+
     def retrieve_from_database(self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None) -> List[str]:
         """
         Queries the vector database based on the given input query.
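Note: the heart of the change above is the doc_id comparison in load_and_embed_v2 — an unchanged source is skipped outright, a changed source first has its stale chunks deleted. A minimal sketch of that decision flow, assuming a `db` object with the get/delete/add surface shown in this patch (the helper name sync_source is illustrative only, not part of the codebase):

    def sync_source(db, chunker, loader, src):
        # Look up the doc_id stored alongside any previously embedded chunk of this source.
        previous = db.get(where={"url": src}, limit=1)
        try:
            existing_doc_id = previous.get("metadatas", [])[0]["doc_id"]
        except (IndexError, KeyError):
            existing_doc_id = None

        chunks = chunker.create_chunks(loader, src)  # returns doc_id, documents, metadatas, ids

        if existing_doc_id and existing_doc_id == chunks["doc_id"]:
            return 0  # content unchanged: nothing new to embed

        if existing_doc_id and existing_doc_id != chunks["doc_id"]:
            db.delete(where={"doc_id": existing_doc_id})  # drop the stale chunks first

        db.add(documents=chunks["documents"], metadatas=chunks["metadatas"], ids=chunks["ids"])
        return len(chunks["ids"])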
diff --git a/embedchain/loaders/csv.py b/embedchain/loaders/csv.py
index 6c5eb7f0..9de84de0 100644
--- a/embedchain/loaders/csv.py
+++ b/embedchain/loaders/csv.py
@@ -1,4 +1,5 @@
 import csv
+import hashlib
 from io import StringIO
 from urllib.parse import urlparse

@@ -34,7 +35,7 @@ class CsvLoader(BaseLoader):
     def load_data(content):
         """Load a csv file with headers. Each line is a document"""
         result = []
-
+        lines = []
         with CsvLoader._get_file_content(content) as file:
             first_line = file.readline()
             delimiter = CsvLoader._detect_delimiter(first_line)
@@ -42,5 +43,10 @@
             reader = csv.DictReader(file, delimiter=delimiter)
             for i, row in enumerate(reader):
                 line = ", ".join([f"{field}: {value}" for field, value in row.items()])
+                lines.append(line)
                 result.append({"content": line, "meta_data": {"url": content, "row": i + 1}})
-        return result
+        doc_id = hashlib.sha256((content + " ".join(lines)).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": result
+        }
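Note: CsvLoader shows the contract every loader in this patch now follows — the old list of {content, meta_data} records moves under a "data" key, and a "doc_id" is added as the SHA-256 of the loaded content plus the source reference. A minimal sketch of that shape for a plain string source (load_plain_text is a hypothetical helper, not part of the diff):

    import hashlib

    def load_plain_text(text: str, url: str = "local") -> dict:
        # Same two-key payload that BaseChunker.create_chunks() now expects from loaders.
        doc_id = hashlib.sha256((text + url).encode()).hexdigest()
        return {
            "doc_id": doc_id,
            "data": [{"content": text, "meta_data": {"url": url}}],
        }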
diff --git a/embedchain/loaders/docs_site_loader.py b/embedchain/loaders/docs_site_loader.py
index 11c9758e..5f8ed164 100644
--- a/embedchain/loaders/docs_site_loader.py
+++ b/embedchain/loaders/docs_site_loader.py
@@ -1,3 +1,4 @@
+import hashlib
 import logging
 from urllib.parse import urljoin, urlparse

@@ -99,4 +100,8 @@ class DocsSiteLoader(BaseLoader):
         output = []
         for u in all_urls:
             output.extend(self._load_data_from_url(u))
-        return output
+        doc_id = hashlib.sha256((" ".join(all_urls) + url).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": output,
+        }
diff --git a/embedchain/loaders/docx_file.py b/embedchain/loaders/docx_file.py
index 5ea1931b..7dc2ce4a 100644
--- a/embedchain/loaders/docx_file.py
+++ b/embedchain/loaders/docx_file.py
@@ -1,3 +1,5 @@
+import hashlib
+
 from langchain.document_loaders import Docx2txtLoader

 from embedchain.helper.json_serializable import register_deserializable
@@ -15,4 +17,8 @@ class DocxFileLoader(BaseLoader):
         meta_data = data[0].metadata
         meta_data["url"] = "local"
         output.append({"content": content, "meta_data": meta_data})
-        return output
+        doc_id = hashlib.sha256((content + url).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": output,
+        }
diff --git a/embedchain/loaders/local_qna_pair.py b/embedchain/loaders/local_qna_pair.py
index 5130f184..36da278e 100644
--- a/embedchain/loaders/local_qna_pair.py
+++ b/embedchain/loaders/local_qna_pair.py
@@ -1,3 +1,5 @@
+import hashlib
+
 from embedchain.helper.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader

@@ -8,12 +10,17 @@ class LocalQnaPairLoader(BaseLoader):
         """Load data from a local QnA pair."""
         question, answer = content
         content = f"Q: {question}\nA: {answer}"
+        url = "local"
         meta_data = {
-            "url": "local",
+            "url": url,
+        }
+        doc_id = hashlib.sha256((content + url).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": [
+                {
+                    "content": content,
+                    "meta_data": meta_data,
+                }
+            ]
         }
-        return [
-            {
-                "content": content,
-                "meta_data": meta_data,
-            }
-        ]
diff --git a/embedchain/loaders/local_text.py b/embedchain/loaders/local_text.py
index 80b13d29..7a519578 100644
--- a/embedchain/loaders/local_text.py
+++ b/embedchain/loaders/local_text.py
@@ -1,3 +1,5 @@
+import hashlib
+
 from embedchain.helper.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader

@@ -6,12 +8,17 @@ from embedchain.loaders.base_loader import BaseLoader
 class LocalTextLoader(BaseLoader):
     def load_data(self, content):
         """Load data from a local text file."""
+        url = "local"
         meta_data = {
-            "url": "local",
+            "url": url,
+        }
+        doc_id = hashlib.sha256((content + url).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": [
+                {
+                    "content": content,
+                    "meta_data": meta_data,
+                }
+            ]
         }
-        return [
-            {
-                "content": content,
-                "meta_data": meta_data,
-            }
-        ]
diff --git a/embedchain/loaders/notion.py b/embedchain/loaders/notion.py
index a4bf79ff..065673e5 100644
--- a/embedchain/loaders/notion.py
+++ b/embedchain/loaders/notion.py
@@ -1,3 +1,4 @@
+import hashlib
 import logging
 import os

@@ -34,10 +35,13 @@ class NotionLoader(BaseLoader):
         # Clean text
         text = clean_string(raw_text)
-
-        return [
+        doc_id = hashlib.sha256((text + source).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": [
             {
                 "content": text,
                 "meta_data": {"url": f"notion-{formatted_id}"},
             }
-        ]
+            ],
+        }
diff --git a/embedchain/loaders/pdf_file.py b/embedchain/loaders/pdf_file.py
index 4084299f..b5431d31 100644
--- a/embedchain/loaders/pdf_file.py
+++ b/embedchain/loaders/pdf_file.py
@@ -1,3 +1,5 @@
+import hashlib
+
 from langchain.document_loaders import PyPDFLoader

 from embedchain.helper.json_serializable import register_deserializable
@@ -10,7 +12,8 @@ class PdfFileLoader(BaseLoader):
     def load_data(self, url):
         """Load data from a PDF file."""
         loader = PyPDFLoader(url)
-        output = []
+        data = []
+        all_content = []
         pages = loader.load_and_split()
         if not len(pages):
             raise ValueError("No data found")
@@ -19,10 +22,15 @@
             content = clean_string(content)
             meta_data = page.metadata
             meta_data["url"] = url
-            output.append(
+            data.append(
                 {
                     "content": content,
                     "meta_data": meta_data,
                 }
             )
-        return output
+            all_content.append(content)
+        doc_id = hashlib.sha256((" ".join(all_content) + url).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": data,
+        }
diff --git a/embedchain/loaders/sitemap.py b/embedchain/loaders/sitemap.py
index 555cbe72..c78542c4 100644
--- a/embedchain/loaders/sitemap.py
+++ b/embedchain/loaders/sitemap.py
@@ -1,3 +1,4 @@
+import hashlib
 import logging

 import requests
@@ -30,6 +31,8 @@ class SitemapLoader(BaseLoader):
         # Get all <loc> tags as a fallback. This might include images.
         links = [link.text for link in soup.find_all("loc")]

+        doc_id = hashlib.sha256((" ".join(links) + sitemap_url).encode()).hexdigest()
+
         for link in links:
             try:
                 each_load_data = web_page_loader.load_data(link)
@@ -40,4 +43,7 @@
                 logging.warning(f"Page is not readable (too many invalid characters): {link}")
             except ParserRejectedMarkup as e:
                 logging.error(f"Failed to parse {link}: {e}")
-        return [data[0] for data in output]
+        return {
+            "doc_id": doc_id,
+            "data": [data[0] for data in output]
+        }
diff --git a/embedchain/loaders/web_page.py b/embedchain/loaders/web_page.py
index 9b6b8d94..9e62df38 100644
--- a/embedchain/loaders/web_page.py
+++ b/embedchain/loaders/web_page.py
@@ -1,3 +1,4 @@
+import hashlib
 import logging

 import requests
@@ -63,10 +64,14 @@ class WebPageLoader(BaseLoader):
         meta_data = {
             "url": url,
         }
-
-        return [
-            {
-                "content": content,
-                "meta_data": meta_data,
-            }
-        ]
+        content = content
+        doc_id = hashlib.sha256((content + url).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": [
+                {
+                    "content": content,
+                    "meta_data": meta_data,
+                }
+            ],
+        }
diff --git a/embedchain/loaders/youtube_video.py b/embedchain/loaders/youtube_video.py
index 7a606064..9b3ca30d 100644
--- a/embedchain/loaders/youtube_video.py
+++ b/embedchain/loaders/youtube_video.py
@@ -1,3 +1,5 @@
+import hashlib
+
 from langchain.document_loaders import YoutubeLoader

 from embedchain.helper.json_serializable import register_deserializable
@@ -18,10 +20,15 @@ class YoutubeVideoLoader(BaseLoader):
         content = clean_string(content)
         meta_data = doc[0].metadata
         meta_data["url"] = url
+
         output.append(
             {
                 "content": content,
                 "meta_data": meta_data,
             }
         )
-        return output
+        doc_id = hashlib.sha256((content + url).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": output,
+        }
diff --git a/embedchain/vectordb/base_vector_db.py b/embedchain/vectordb/base_vector_db.py
new file mode 100644
index 00000000..e7ef5c8e
--- /dev/null
+++ b/embedchain/vectordb/base_vector_db.py
@@ -0,0 +1,50 @@
+from embedchain.config.vectordbs.BaseVectorDbConfig import BaseVectorDbConfig
+from embedchain.embedder.base_embedder import BaseEmbedder
+from embedchain.helper_classes.json_serializable import JSONSerializable
+
+
+class BaseVectorDB(JSONSerializable):
+    """Base class for vector database."""
+
+    def __init__(self, config: BaseVectorDbConfig):
+        self.client = self._get_or_create_db()
+        self.config: BaseVectorDbConfig = config
+
+    def _initialize(self):
+        """
+        This method is needed because the `embedder` attribute needs to be set externally before it can be initialized.
+
+        So it can't be done in __init__ in one step.
+        """
+        raise NotImplementedError
+
+    def _get_or_create_db(self):
+        """Get or create the database."""
+        raise NotImplementedError
+
+    def _get_or_create_collection(self):
+        raise NotImplementedError
+
+    def _set_embedder(self, embedder: BaseEmbedder):
+        self.embedder = embedder
+
+    def get(self):
+        raise NotImplementedError
+
+    def add(self):
+        raise NotImplementedError
+
+    def query(self):
+        raise NotImplementedError
+
+    def count(self):
+        raise NotImplementedError
+
+    def delete(self):
+        raise NotImplementedError
+
+    def reset(self):
+        raise NotImplementedError
+
+    def set_collection_name(self, name: str):
+        raise NotImplementedError
\ No newline at end of file
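Note: BaseVectorDB is a pure interface — __init__ wires up the client via _get_or_create_db(), and every other operation raises NotImplementedError until a concrete store overrides it. A partial, purely illustrative subclass sketch (the in-memory store below is not part of the codebase) showing which hooks a backend is expected to fill in:

    class InMemoryVectorDB(BaseVectorDB):
        def _get_or_create_db(self):
            return {}  # stands in for a real client handle

        def _initialize(self):
            # runs after _set_embedder(), once self.embedder is available
            self._get_or_create_collection()

        def _get_or_create_collection(self):
            self.collection = self.client.setdefault("default", {})
            return self.collection

        def add(self, documents, metadatas, ids):
            for doc, meta, id_ in zip(documents, metadatas, ids):
                self.collection[id_] = (doc, meta)

        def count(self):
            return len(self.collection)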
+ """ + raise NotImplementedError + + def _get_or_create_db(self): + """Get or create the database.""" + raise NotImplementedError + + def _get_or_create_collection(self): + raise NotImplementedError + + def _set_embedder(self, embedder: BaseEmbedder): + self.embedder = embedder + + def get(self): + raise NotImplementedError + + def add(self): + raise NotImplementedError + + def query(self): + raise NotImplementedError + + def count(self): + raise NotImplementedError + + def delete(self): + raise NotImplementedError + + def reset(self): + raise NotImplementedError + + def set_collection_name(self, name: str): + raise NotImplementedError \ No newline at end of file diff --git a/embedchain/vectordb/chroma.py b/embedchain/vectordb/chroma.py index 84f45840..2717b378 100644 --- a/embedchain/vectordb/chroma.py +++ b/embedchain/vectordb/chroma.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Any from chromadb import Collection, QueryResult from langchain.docstore.document import Document @@ -87,25 +87,32 @@ class ChromaDB(BaseVectorDB): ) return self.collection - def get(self, ids: List[str], where: Dict[str, any]) -> List[str]: + def get(self, ids=None, where=None, limit=None): """ Get existing doc ids present in vector database :param ids: list of doc ids to check for existence :type ids: List[str] :param where: Optional. to filter data - :type where: Dict[str, any] + :type where: Dict[str, Any] :return: Existing documents. :rtype: List[str] """ - existing_docs = self.collection.get( - ids=ids, - where=where, # optional filter + args = {} + if ids: + args["ids"] = ids + if where: + args["where"] = where + if limit: + args["limit"] = limit + return self.collection.get( + **args ) - return set(existing_docs["ids"]) + def get_advanced(self, where): + return self.collection.get(where=where, limit=1) - def add(self, documents: List[str], metadatas: List[object], ids: List[str]): + def add(self, documents: List[str], metadatas: List[object], ids: List[str]) -> Any: """ Add vectors to chroma database @@ -136,7 +143,7 @@ class ChromaDB(BaseVectorDB): ) ] - def query(self, input_query: List[str], n_results: int, where: Dict[str, any]) -> List[str]: + def query(self, input_query: List[str], n_results: int, where: Dict[str, Any]) -> List[str]: """ Query contents from vector data base based on vector similarity @@ -145,7 +152,7 @@ class ChromaDB(BaseVectorDB): :param n_results: no of similar documents to fetch from database :type n_results: int :param where: to filter data - :type where: Dict[str, any] + :type where: Dict[str, Any] :raises InvalidDimensionException: Dimensions do not match. :return: The content of the document that matched your query. :rtype: List[str] @@ -187,6 +194,9 @@ class ChromaDB(BaseVectorDB): """ return self.collection.count() + def delete(self, where): + return self.collection.delete(where=where) + def reset(self): """ Resets the database. Deletes all embeddings irreversibly. diff --git a/tests/chunkers/test_text.py b/tests/chunkers/test_text.py index 5545cf69..6ea56620 100644 --- a/tests/chunkers/test_text.py +++ b/tests/chunkers/test_text.py @@ -69,9 +69,12 @@ class MockLoader: Mock loader that returns a list of data dictionaries. Adjust this method to return different data for testing. 
""" - return [ - { - "content": src, - "meta_data": {"url": "none"}, - } - ] + return { + "doc_id": "123", + "data": [ + { + "content": src, + "meta_data": {"url": "none"}, + } + ] + } diff --git a/tests/loaders/test_csv.py b/tests/loaders/test_csv.py index d004c510..07f06ae8 100644 --- a/tests/loaders/test_csv.py +++ b/tests/loaders/test_csv.py @@ -29,18 +29,19 @@ def test_load_data(delimiter): # Loading CSV using CsvLoader loader = CsvLoader() result = loader.load_data(filename) + data = result["data"] # Assertions - assert len(result) == 3 - assert result[0]["content"] == "Name: Alice, Age: 28, Occupation: Engineer" - assert result[0]["meta_data"]["url"] == filename - assert result[0]["meta_data"]["row"] == 1 - assert result[1]["content"] == "Name: Bob, Age: 35, Occupation: Doctor" - assert result[1]["meta_data"]["url"] == filename - assert result[1]["meta_data"]["row"] == 2 - assert result[2]["content"] == "Name: Charlie, Age: 22, Occupation: Student" - assert result[2]["meta_data"]["url"] == filename - assert result[2]["meta_data"]["row"] == 3 + assert len(data) == 3 + assert data[0]["content"] == "Name: Alice, Age: 28, Occupation: Engineer" + assert data[0]["meta_data"]["url"] == filename + assert data[0]["meta_data"]["row"] == 1 + assert data[1]["content"] == "Name: Bob, Age: 35, Occupation: Doctor" + assert data[1]["meta_data"]["url"] == filename + assert data[1]["meta_data"]["row"] == 2 + assert data[2]["content"] == "Name: Charlie, Age: 22, Occupation: Student" + assert data[2]["meta_data"]["url"] == filename + assert data[2]["meta_data"]["row"] == 3 # Cleaning up the temporary file os.unlink(filename) @@ -67,18 +68,19 @@ def test_load_data_with_file_uri(delimiter): # Loading CSV using CsvLoader loader = CsvLoader() result = loader.load_data(filename) + data = result["data"] # Assertions - assert len(result) == 3 - assert result[0]["content"] == "Name: Alice, Age: 28, Occupation: Engineer" - assert result[0]["meta_data"]["url"] == filename - assert result[0]["meta_data"]["row"] == 1 - assert result[1]["content"] == "Name: Bob, Age: 35, Occupation: Doctor" - assert result[1]["meta_data"]["url"] == filename - assert result[1]["meta_data"]["row"] == 2 - assert result[2]["content"] == "Name: Charlie, Age: 22, Occupation: Student" - assert result[2]["meta_data"]["url"] == filename - assert result[2]["meta_data"]["row"] == 3 + assert len(data) == 3 + assert data[0]["content"] == "Name: Alice, Age: 28, Occupation: Engineer" + assert data[0]["meta_data"]["url"] == filename + assert data[0]["meta_data"]["row"] == 1 + assert data[1]["content"] == "Name: Bob, Age: 35, Occupation: Doctor" + assert data[1]["meta_data"]["url"] == filename + assert data[1]["meta_data"]["row"] == 2 + assert data[2]["content"] == "Name: Charlie, Age: 22, Occupation: Student" + assert data[2]["meta_data"]["url"] == filename + assert data[2]["meta_data"]["row"] == 3 # Cleaning up the temporary file os.unlink(tmpfile.name)