From 8f42ced9b590d30b3abf4303fda5a19e828cc41e Mon Sep 17 00:00:00 2001 From: Rayhan Patel <73016463+Rayhanpatel@users.noreply.github.com> Date: Thu, 13 Jul 2023 10:17:42 +0530 Subject: [PATCH] Add metadata support to added data sources (#253) --- embedchain/embedchain.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index 1ea79ace..3cd2f417 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -36,7 +36,7 @@ class EmbedChain: self.collection = self.config.db.collection self.user_asks = [] - def add(self, data_type, url, config: AddConfig = None): + def add(self, data_type, url, metadata=None, config: AddConfig = None): """ Adds the data from the given URL to the vector db. Loads the data, chunks it, create embedding for each chunk @@ -44,6 +44,7 @@ class EmbedChain: :param data_type: The type of the data to add. :param url: The URL where the data is located. + :param metadata: Optional. Metadata associated with the data source. :param config: Optional. The `AddConfig` instance to use as configuration options. """ @@ -51,10 +52,10 @@ class EmbedChain: config = AddConfig() data_formatter = DataFormatter(data_type, config) - self.user_asks.append([data_type, url]) - self.load_and_embed(data_formatter.loader, data_formatter.chunker, url) + self.user_asks.append([data_type, url, metadata]) + self.load_and_embed(data_formatter.loader, data_formatter.chunker, url, metadata) - def add_local(self, data_type, content, config: AddConfig = None): + def add_local(self, data_type, content, metadata=None, config: AddConfig = None): """ Adds the data you supply to the vector db. Loads the data, chunks it, create embedding for each chunk @@ -62,6 +63,7 @@ class EmbedChain: :param data_type: The type of the data to add. :param content: The local data. Refer to the `README` for formatting. + :param metadata: Optional. Metadata associated with the data source. :param config: Optional. The `AddConfig` instance to use as configuration options. """ @@ -74,9 +76,10 @@ class EmbedChain: data_formatter.loader, data_formatter.chunker, content, + metadata, ) - def load_and_embed(self, loader, chunker, src): + def load_and_embed(self, loader, chunker, src, metadata=None): """ Loads the data from the given URL, chunks it, and adds it to database. @@ -84,6 +87,7 @@ class EmbedChain: :param chunker: The chunker to use to chunk the data. :param src: The data to be handled by the loader. Can be a URL for remote sources or local content for local loaders. + :param metadata: Optional. Metadata associated with the data source. """ embeddings_data = chunker.create_chunks(loader, src) documents = embeddings_data["documents"] @@ -112,7 +116,11 @@ class EmbedChain: documents, metadatas = zip(*data_dict.values()) chunks_before_addition = self.count() - self.collection.add(documents=documents, metadatas=list(metadatas), ids=ids) + + # Add metadata to each document + metadatas_with_metadata = [meta or metadata for meta in metadatas] + + self.collection.add(documents=documents, metadatas=list(metadatas_with_metadata), ids=ids) print( ( f"Successfully saved {src}. New chunks count: "