feat: add method - detect format / data_type (#380)

Author: cachho
Date: 2023-08-16 22:18:24 +02:00 (committed by GitHub)
Commit: 4c8876f032 (parent f92e890aa1)
18 changed files with 472 additions and 121 deletions


@@ -1,9 +1,10 @@
+import hashlib
 import importlib.metadata
 import logging
 import os
 import threading
 import uuid
-from typing import Optional
+from typing import Dict, Optional
 
 import requests
 from dotenv import load_dotenv
@@ -17,6 +18,8 @@ from embedchain.config.apps.BaseAppConfig import BaseAppConfig
 from embedchain.config.QueryConfig import DOCS_SITE_PROMPT_TEMPLATE
 from embedchain.data_formatter import DataFormatter
 from embedchain.loaders.base_loader import BaseLoader
+from embedchain.models.data_type import DataType
+from embedchain.utils import detect_datatype
 
 load_dotenv()
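
The headline feature is the `detect_datatype` helper imported above; its implementation lives in embedchain/utils.py and is not part of this excerpt. A rough, hypothetical sketch of the kind of heuristic such a detector can use (the real function covers more DataType values; names below are illustrative only):

from urllib.parse import urlparse

def detect_datatype_sketch(source: str) -> str:
    # Hypothetical illustration only; not the code from embedchain/utils.py.
    parts = urlparse(str(source))
    if parts.scheme in ("http", "https") and parts.netloc:
        if "youtube.com" in parts.netloc or "youtu.be" in parts.netloc:
            return "youtube_video"
        if parts.path.endswith(".pdf"):
            return "pdf_file"
        if parts.path.endswith(".xml"):
            return "sitemap"
        return "web_page"
    if str(source).endswith(".docx"):
        return "docx"
    # Anything that is neither a URL nor a known file type is raw text.
    return "text"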
@@ -47,27 +50,62 @@ class EmbedChain:
         thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("init",))
         thread_telemetry.start()
 
-    def add(self, data_type, url, metadata=None, config: AddConfig = None):
+    def add(
+        self,
+        source,
+        data_type: Optional[DataType] = None,
+        metadata: Optional[Dict] = None,
+        config: Optional[AddConfig] = None,
+    ):
         """
         Adds the data from the given URL to the vector db.
         Loads the data, chunks it, create embedding for each chunk
         and then stores the embedding to vector database.
 
-        :param data_type: The type of the data to add.
-        :param url: The URL where the data is located.
+        :param source: The data to embed, can be a URL, local file or raw content, depending on the data type.
+        :param data_type: Optional. Automatically detected, but can be forced with this argument.
+        The type of the data to add.
         :param metadata: Optional. Metadata associated with the data source.
         :param config: Optional. The `AddConfig` instance to use as configuration
         options.
+        :return: source_id, a md5-hash of the source, in hexadecimal representation.
         """
         if config is None:
             config = AddConfig()
 
+        try:
+            DataType(source)
+            logging.warning(
+                f"""Starting from version v0.0.39, Embedchain can automatically detect the data type. So, in the `add` method, the argument order has changed. You no longer need to specify '{source}' for the `source` argument. So the code snippet will be `.add("{data_type}", "{source}")`"""  # noqa #E501
+            )
+            logging.warning(
+                "Embedchain is swapping the arguments for you. This functionality might be deprecated in the future, so please adjust your code."  # noqa #E501
+            )
+            source, data_type = data_type, source
+        except ValueError:
+            pass
+
+        if data_type:
+            try:
+                data_type = DataType(data_type)
+            except ValueError:
+                raise ValueError(
+                    f"Invalid data_type: '{data_type}'.",
+                    f"Please use one of the following: {[data_type.value for data_type in DataType]}",
+                ) from None
+
+        if not data_type:
+            data_type = detect_datatype(source)
+
+        # `source_id` is the hash of the source argument
+        hash_object = hashlib.md5(str(source).encode("utf-8"))
+        source_id = hash_object.hexdigest()
+
         data_formatter = DataFormatter(data_type, config)
-        self.user_asks.append([data_type, url, metadata])
+        self.user_asks.append([source, data_type.value, metadata])
         documents, _metadatas, _ids, new_chunks = self.load_and_embed(
-            data_formatter.loader, data_formatter.chunker, url, metadata
+            data_formatter.loader, data_formatter.chunker, source, metadata, source_id
         )
-        if data_type in ("docs_site",):
+        if data_type in {DataType.DOCS_SITE}:
             self.is_docs_site_instance = True
 
         # Send anonymous telemetry
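
At the call site, the new signature means a data type is no longer required. A short usage sketch (the `App` instance and URLs below are illustrative, as in the project README):

from embedchain import App

app = App()

# New call style: the data type is detected from the source.
source_id = app.add("https://www.example.com/article.html")

# Forcing a type still works, but it is now the second argument.
app.add("https://www.example.com/feed.xml", "sitemap")

# Legacy call style (data type first): the DataType(source) probe above
# recognizes it, logs the two warnings, and swaps the arguments.
app.add("web_page", "https://www.example.com/article.html")

# `add` now returns the md5 hex digest of the source.
print(source_id)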
@@ -75,41 +113,35 @@ class EmbedChain:
             # it's quicker to check the variable twice than to count words when they won't be submitted.
             word_count = sum([len(document.split(" ")) for document in documents])
 
-            extra_metadata = {"data_type": data_type, "word_count": word_count, "chunks_count": new_chunks}
+            extra_metadata = {"data_type": data_type.value, "word_count": word_count, "chunks_count": new_chunks}
             thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("add", extra_metadata))
             thread_telemetry.start()
 
-    def add_local(self, data_type, content, metadata=None, config: AddConfig = None):
+        return source_id
+
+    def add_local(self, source, data_type=None, metadata=None, config: AddConfig = None):
         """
-        Adds the data you supply to the vector db.
+        Warning:
+            This method is deprecated and will be removed in future versions. Use `add` instead.
+
+        Adds the data from the given URL to the vector db.
+        Loads the data, chunks it, create embedding for each chunk
+        and then stores the embedding to vector database.
 
-        :param data_type: The type of the data to add.
-        :param content: The local data. Refer to the `README` for formatting.
+        :param source: The data to embed, can be a URL, local file or raw content, depending on the data type.
+        :param data_type: Optional. Automatically detected, but can be forced with this argument.
+        The type of the data to add.
         :param metadata: Optional. Metadata associated with the data source.
-        :param config: Optional. The `AddConfig` instance to use as
-        configuration options.
+        :param config: Optional. The `AddConfig` instance to use as configuration
+        options.
+        :return: md5-hash of the source, in hexadecimal representation.
         """
-        if config is None:
-            config = AddConfig()
-
-        data_formatter = DataFormatter(data_type, config)
-        self.user_asks.append([data_type, content])
-        documents, _metadatas, _ids, new_chunks = self.load_and_embed(
-            data_formatter.loader, data_formatter.chunker, content, metadata
+        logging.warning(
+            "The `add_local` method is deprecated and will be removed in future versions. Please use the `add` method for both local and remote files."  # noqa: E501
         )
+        return self.add(source=source, data_type=data_type, metadata=metadata, config=config)
 
-        # Send anonymous telemetry
-        if self.config.collect_metrics:
-            # it's quicker to check the variable twice than to count words when they won't be submitted.
-            word_count = sum([len(document.split(" ")) for document in documents])
-
-            extra_metadata = {"data_type": data_type, "word_count": word_count, "chunks_count": new_chunks}
-            thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("add_local", extra_metadata))
-            thread_telemetry.start()
-
-    def load_and_embed(self, loader: BaseLoader, chunker: BaseChunker, src, metadata=None):
+    def load_and_embed(self, loader: BaseLoader, chunker: BaseChunker, src, metadata=None, source_id=None):
         """
         Loads the data from the given URL, chunks it, and adds it to database.
@@ -118,12 +150,16 @@ class EmbedChain:
         :param src: The data to be handled by the loader. Can be a URL for
         remote sources or local content for local loaders.
         :param metadata: Optional. Metadata associated with the data source.
+        :param source_id: Hexadecimal hash of the source.
         :return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
         """
         embeddings_data = chunker.create_chunks(loader, src)
+
+        # spread chunking results
         documents = embeddings_data["documents"]
         metadatas = embeddings_data["metadatas"]
         ids = embeddings_data["ids"]
+
         # get existing ids, and discard doc if any common id exist.
         where = {"app_id": self.config.id} if self.config.id is not None else {}
         # where={"url": src}
@@ -144,22 +180,31 @@ class EmbedChain:
             ids = list(data_dict.keys())
             documents, metadatas = zip(*data_dict.values())
 
-        # Add app id in metadatas so that they can be queried on later
-        if self.config.id is not None:
-            metadatas = [{**m, "app_id": self.config.id} for m in metadatas]
+        # Loop though all metadatas and add extras.
+        new_metadatas = []
+        for m in metadatas:
+            # Add app id in metadatas so that they can be queried on later
+            if self.config.id:
+                m["app_id"] = self.config.id
 
-        # FIXME: Fix the error handling logic when metadatas or metadata is None
-        metadatas = metadatas if metadatas else []
-        metadata = metadata if metadata else {}
+            # Add hashed source
+            m["hash"] = source_id
+
+            # Note: Metadata is the function argument
+            if metadata:
+                # Spread whatever is in metadata into the new object.
+                m.update(metadata)
+
+            new_metadatas.append(m)
+        metadatas = new_metadatas
 
         # Count before, to calculate a delta in the end.
         chunks_before_addition = self.count()
 
-        # Add metadata to each document
-        metadatas_with_metadata = [{**meta, **metadata} for meta in metadatas]
-
-        self.db.add(documents=documents, metadatas=metadatas_with_metadata, ids=ids)
+        self.db.add(documents=documents, metadatas=metadatas, ids=ids)
         count_new_chunks = self.count() - chunks_before_addition
-        print((f"Successfully saved {src}. New chunks count: {count_new_chunks}"))
-        return list(documents), metadatas_with_metadata, ids, count_new_chunks
+        print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
+        return list(documents), metadatas, ids, count_new_chunks
 
     def _format_result(self, results):
         return [
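
The net effect of the rewritten metadata loop is easiest to see in isolation. A standalone sketch that mirrors the new single-pass logic (function and argument names are illustrative, not the library's):

import hashlib

def build_metadatas(metadatas, app_id, source, metadata=None):
    # Mirrors the loop added in load_and_embed above (illustrative sketch).
    source_id = hashlib.md5(str(source).encode("utf-8")).hexdigest()
    new_metadatas = []
    for m in metadatas:
        if app_id:
            m["app_id"] = app_id  # makes chunks queryable per app
        m["hash"] = source_id     # ties every chunk back to its source
        if metadata:
            m.update(metadata)    # spread user-supplied keys into each chunk
        new_metadatas.append(m)
    return new_metadatas

print(build_metadatas([{"url": "https://example.com"}], "my-app",
                      "https://example.com", {"tag": "demo"}))
# -> [{'url': 'https://example.com', 'app_id': 'my-app', 'hash': '<md5 of source>', 'tag': 'demo'}]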