feat: add method - detect format / data_type (#380)
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
import hashlib
|
||||
import importlib.metadata
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from typing import Dict, Optional
|
||||
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
@@ -17,6 +18,8 @@ from embedchain.config.apps.BaseAppConfig import BaseAppConfig
|
||||
from embedchain.config.QueryConfig import DOCS_SITE_PROMPT_TEMPLATE
|
||||
from embedchain.data_formatter import DataFormatter
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.models.data_type import DataType
|
||||
from embedchain.utils import detect_datatype
|
||||
|
||||
load_dotenv()
|
||||
|
||||
@@ -47,27 +50,62 @@ class EmbedChain:
|
||||
thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("init",))
|
||||
thread_telemetry.start()
|
||||
|
||||
def add(self, data_type, url, metadata=None, config: AddConfig = None):
|
||||
def add(
|
||||
self,
|
||||
source,
|
||||
data_type: Optional[DataType] = None,
|
||||
metadata: Optional[Dict] = None,
|
||||
config: Optional[AddConfig] = None,
|
||||
):
|
||||
"""
|
||||
Adds the data from the given URL to the vector db.
|
||||
Loads the data, chunks it, create embedding for each chunk
|
||||
and then stores the embedding to vector database.
|
||||
|
||||
:param data_type: The type of the data to add.
|
||||
:param url: The URL where the data is located.
|
||||
:param source: The data to embed, can be a URL, local file or raw content, depending on the data type.
|
||||
:param data_type: Optional. Automatically detected, but can be forced with this argument.
|
||||
The type of the data to add.
|
||||
:param metadata: Optional. Metadata associated with the data source.
|
||||
:param config: Optional. The `AddConfig` instance to use as configuration
|
||||
options.
|
||||
:return: source_id, a md5-hash of the source, in hexadecimal representation.
|
||||
"""
|
||||
if config is None:
|
||||
config = AddConfig()
|
||||
|
||||
try:
|
||||
DataType(source)
|
||||
logging.warning(
|
||||
f"""Starting from version v0.0.39, Embedchain can automatically detect the data type. So, in the `add` method, the argument order has changed. You no longer need to specify '{source}' for the `source` argument. So the code snippet will be `.add("{data_type}", "{source}")`""" # noqa #E501
|
||||
)
|
||||
logging.warning(
|
||||
"Embedchain is swapping the arguments for you. This functionality might be deprecated in the future, so please adjust your code." # noqa #E501
|
||||
)
|
||||
source, data_type = data_type, source
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if data_type:
|
||||
try:
|
||||
data_type = DataType(data_type)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Invalid data_type: '{data_type}'.",
|
||||
f"Please use one of the following: {[data_type.value for data_type in DataType]}",
|
||||
) from None
|
||||
if not data_type:
|
||||
data_type = detect_datatype(source)
|
||||
|
||||
# `source_id` is the hash of the source argument
|
||||
hash_object = hashlib.md5(str(source).encode("utf-8"))
|
||||
source_id = hash_object.hexdigest()
|
||||
|
||||
data_formatter = DataFormatter(data_type, config)
|
||||
self.user_asks.append([data_type, url, metadata])
|
||||
self.user_asks.append([source, data_type.value, metadata])
|
||||
documents, _metadatas, _ids, new_chunks = self.load_and_embed(
|
||||
data_formatter.loader, data_formatter.chunker, url, metadata
|
||||
data_formatter.loader, data_formatter.chunker, source, metadata, source_id
|
||||
)
|
||||
if data_type in ("docs_site",):
|
||||
if data_type in {DataType.DOCS_SITE}:
|
||||
self.is_docs_site_instance = True
|
||||
|
||||
# Send anonymous telemetry
|
||||
@@ -75,41 +113,35 @@ class EmbedChain:
|
||||
# it's quicker to check the variable twice than to count words when they won't be submitted.
|
||||
word_count = sum([len(document.split(" ")) for document in documents])
|
||||
|
||||
extra_metadata = {"data_type": data_type, "word_count": word_count, "chunks_count": new_chunks}
|
||||
extra_metadata = {"data_type": data_type.value, "word_count": word_count, "chunks_count": new_chunks}
|
||||
thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("add", extra_metadata))
|
||||
thread_telemetry.start()
|
||||
|
||||
def add_local(self, data_type, content, metadata=None, config: AddConfig = None):
|
||||
return source_id
|
||||
|
||||
def add_local(self, source, data_type=None, metadata=None, config: AddConfig = None):
|
||||
"""
|
||||
Adds the data you supply to the vector db.
|
||||
Warning:
|
||||
This method is deprecated and will be removed in future versions. Use `add` instead.
|
||||
|
||||
Adds the data from the given URL to the vector db.
|
||||
Loads the data, chunks it, create embedding for each chunk
|
||||
and then stores the embedding to vector database.
|
||||
|
||||
:param data_type: The type of the data to add.
|
||||
:param content: The local data. Refer to the `README` for formatting.
|
||||
:param source: The data to embed, can be a URL, local file or raw content, depending on the data type.
|
||||
:param data_type: Optional. Automatically detected, but can be forced with this argument.
|
||||
The type of the data to add.
|
||||
:param metadata: Optional. Metadata associated with the data source.
|
||||
:param config: Optional. The `AddConfig` instance to use as
|
||||
configuration options.
|
||||
:param config: Optional. The `AddConfig` instance to use as configuration
|
||||
options.
|
||||
:return: md5-hash of the source, in hexadecimal representation.
|
||||
"""
|
||||
if config is None:
|
||||
config = AddConfig()
|
||||
|
||||
data_formatter = DataFormatter(data_type, config)
|
||||
self.user_asks.append([data_type, content])
|
||||
documents, _metadatas, _ids, new_chunks = self.load_and_embed(
|
||||
data_formatter.loader, data_formatter.chunker, content, metadata
|
||||
logging.warning(
|
||||
"The `add_local` method is deprecated and will be removed in future versions. Please use the `add` method for both local and remote files." # noqa: E501
|
||||
)
|
||||
return self.add(source=source, data_type=data_type, metadata=metadata, config=config)
|
||||
|
||||
# Send anonymous telemetry
|
||||
if self.config.collect_metrics:
|
||||
# it's quicker to check the variable twice than to count words when they won't be submitted.
|
||||
word_count = sum([len(document.split(" ")) for document in documents])
|
||||
|
||||
extra_metadata = {"data_type": data_type, "word_count": word_count, "chunks_count": new_chunks}
|
||||
thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("add_local", extra_metadata))
|
||||
thread_telemetry.start()
|
||||
|
||||
def load_and_embed(self, loader: BaseLoader, chunker: BaseChunker, src, metadata=None):
|
||||
def load_and_embed(self, loader: BaseLoader, chunker: BaseChunker, src, metadata=None, source_id=None):
|
||||
"""
|
||||
Loads the data from the given URL, chunks it, and adds it to database.
|
||||
|
||||
@@ -118,12 +150,16 @@ class EmbedChain:
|
||||
:param src: The data to be handled by the loader. Can be a URL for
|
||||
remote sources or local content for local loaders.
|
||||
:param metadata: Optional. Metadata associated with the data source.
|
||||
:param source_id: Hexadecimal hash of the source.
|
||||
:return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
|
||||
"""
|
||||
embeddings_data = chunker.create_chunks(loader, src)
|
||||
|
||||
# spread chunking results
|
||||
documents = embeddings_data["documents"]
|
||||
metadatas = embeddings_data["metadatas"]
|
||||
ids = embeddings_data["ids"]
|
||||
|
||||
# get existing ids, and discard doc if any common id exist.
|
||||
where = {"app_id": self.config.id} if self.config.id is not None else {}
|
||||
# where={"url": src}
|
||||
@@ -144,22 +180,31 @@ class EmbedChain:
|
||||
ids = list(data_dict.keys())
|
||||
documents, metadatas = zip(*data_dict.values())
|
||||
|
||||
# Add app id in metadatas so that they can be queried on later
|
||||
if self.config.id is not None:
|
||||
metadatas = [{**m, "app_id": self.config.id} for m in metadatas]
|
||||
# Loop though all metadatas and add extras.
|
||||
new_metadatas = []
|
||||
for m in metadatas:
|
||||
# Add app id in metadatas so that they can be queried on later
|
||||
if self.config.id:
|
||||
m["app_id"] = self.config.id
|
||||
|
||||
# FIXME: Fix the error handling logic when metadatas or metadata is None
|
||||
metadatas = metadatas if metadatas else []
|
||||
metadata = metadata if metadata else {}
|
||||
# Add hashed source
|
||||
m["hash"] = source_id
|
||||
|
||||
# Note: Metadata is the function argument
|
||||
if metadata:
|
||||
# Spread whatever is in metadata into the new object.
|
||||
m.update(metadata)
|
||||
|
||||
new_metadatas.append(m)
|
||||
metadatas = new_metadatas
|
||||
|
||||
# Count before, to calculate a delta in the end.
|
||||
chunks_before_addition = self.count()
|
||||
|
||||
# Add metadata to each document
|
||||
metadatas_with_metadata = [{**meta, **metadata} for meta in metadatas]
|
||||
|
||||
self.db.add(documents=documents, metadatas=metadatas_with_metadata, ids=ids)
|
||||
self.db.add(documents=documents, metadatas=metadatas, ids=ids)
|
||||
count_new_chunks = self.count() - chunks_before_addition
|
||||
print((f"Successfully saved {src}. New chunks count: {count_new_chunks}"))
|
||||
return list(documents), metadatas_with_metadata, ids, count_new_chunks
|
||||
print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
|
||||
return list(documents), metadatas, ids, count_new_chunks
|
||||
|
||||
def _format_result(self, results):
|
||||
return [
|
||||
|
||||
Reference in New Issue
Block a user