docs: update docstrings (#565)

This commit is contained in:
cachho
2023-09-07 02:04:44 +02:00
committed by GitHub
parent 4754372fcd
commit 1ac8aef4de
25 changed files with 736 additions and 298 deletions

View File

@@ -6,11 +6,10 @@ import os
import threading
import uuid
from pathlib import Path
from typing import Dict, Optional
from typing import Any, Dict, List, Optional, Tuple
import requests
from dotenv import load_dotenv
from langchain.docstore.document import Document
from tenacity import retry, stop_after_attempt, wait_fixed
from embedchain.chunkers.base_chunker import BaseChunker
@@ -46,8 +45,17 @@ class EmbedChain(JSONSerializable):
Initializes the EmbedChain instance, sets up a vector DB client and
creates a collection.
:param config: BaseAppConfig instance to load as configuration.
:param system_prompt: Optional. System prompt string.
:param config: Configuration just for the app, not the db or llm or embedder.
:type config: BaseAppConfig
:param llm: Instance of the LLM you want to use.
:type llm: BaseLlm
:param db: Instance of the Database to use, defaults to None
:type db: BaseVectorDB, optional
:param embedder: instance of the embedder to use, defaults to None
:type embedder: BaseEmbedder, optional
:param system_prompt: System prompt to use in the llm query, defaults to None
:type system_prompt: Optional[str], optional
:raises ValueError: No database or embedder provided.
"""
self.config = config
@@ -88,10 +96,13 @@ class EmbedChain(JSONSerializable):
thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("init",))
thread_telemetry.start()
def _load_or_generate_user_id(self):
def _load_or_generate_user_id(self) -> str:
"""
Loads the user id from the config file if it exists, otherwise generates a new
one and saves it to the config file.
:return: user id
:rtype: str
"""
if not os.path.exists(CONFIG_DIR):
os.makedirs(CONFIG_DIR)
@@ -110,9 +121,9 @@ class EmbedChain(JSONSerializable):
def add(
self,
source,
source: Any,
data_type: Optional[DataType] = None,
metadata: Optional[Dict] = None,
metadata: Optional[Dict[str, Any]] = None,
config: Optional[AddConfig] = None,
):
"""
@@ -121,12 +132,17 @@ class EmbedChain(JSONSerializable):
and then stores the embedding to vector database.
:param source: The data to embed, can be a URL, local file or raw content, depending on the data type.
:param data_type: Optional. Automatically detected, but can be forced with this argument.
The type of the data to add.
:param metadata: Optional. Metadata associated with the data source.
:param config: Optional. The `AddConfig` instance to use as configuration
options.
:type source: Any
:param data_type: Automatically detected, but can be forced with this argument. The type of the data to add,
defaults to None
:type data_type: Optional[DataType], optional
:param metadata: Metadata associated with the data source., defaults to None
:type metadata: Optional[Dict[str, Any]], optional
:param config: The `AddConfig` instance to use as configuration options., defaults to None
:type config: Optional[AddConfig], optional
:raises ValueError: Invalid data type
:return: source_id, a md5-hash of the source, in hexadecimal representation.
:rtype: str
"""
if config is None:
config = AddConfig()
@@ -177,39 +193,62 @@ class EmbedChain(JSONSerializable):
return source_id
def add_local(self, source, data_type=None, metadata=None, config: AddConfig = None):
def add_local(
self,
source: Any,
data_type: Optional[DataType] = None,
metadata: Optional[Dict[str, Any]] = None,
config: Optional[AddConfig] = None,
):
"""
Warning:
This method is deprecated and will be removed in future versions. Use `add` instead.
Adds the data from the given URL to the vector db.
Loads the data, chunks it, create embedding for each chunk
and then stores the embedding to vector database.
Warning:
This method is deprecated and will be removed in future versions. Use `add` instead.
:param source: The data to embed, can be a URL, local file or raw content, depending on the data type.
:param data_type: Optional. Automatically detected, but can be forced with this argument.
The type of the data to add.
:param metadata: Optional. Metadata associated with the data source.
:param config: Optional. The `AddConfig` instance to use as configuration
options.
:return: md5-hash of the source, in hexadecimal representation.
:type source: Any
:param data_type: Automatically detected, but can be forced with this argument. The type of the data to add,
defaults to None
:type data_type: Optional[DataType], optional
:param metadata: Metadata associated with the data source., defaults to None
:type metadata: Optional[Dict[str, Any]], optional
:param config: The `AddConfig` instance to use as configuration options., defaults to None
:type config: Optional[AddConfig], optional
:raises ValueError: Invalid data type
:return: source_id, a md5-hash of the source, in hexadecimal representation.
:rtype: str
"""
logging.warning(
"The `add_local` method is deprecated and will be removed in future versions. Please use the `add` method for both local and remote files." # noqa: E501
)
return self.add(source=source, data_type=data_type, metadata=metadata, config=config)
def load_and_embed(self, loader: BaseLoader, chunker: BaseChunker, src, metadata=None, source_id=None):
"""
Loads the data from the given URL, chunks it, and adds it to database.
def load_and_embed(
self,
loader: BaseLoader,
chunker: BaseChunker,
src: Any,
metadata: Optional[Dict[str, Any]] = None,
source_id: Optional[str] = None,
) -> Tuple[List[str], Dict[str, Any], List[str], int]:
"""The loader to use to load the data.
:param loader: The loader to use to load the data.
:type loader: BaseLoader
:param chunker: The chunker to use to chunk the data.
:param src: The data to be handled by the loader. Can be a URL for
remote sources or local content for local loaders.
:param metadata: Optional. Metadata associated with the data source.
:param source_id: Hexadecimal hash of the source.
:type chunker: BaseChunker
:param src: The data to be handled by the loader.
Can be a URL for remote sources or local content for local loaders.
:type src: Any
:param metadata: Metadata associated with the data source., defaults to None
:type metadata: Dict[str, Any], optional
:param source_id: Hexadecimal hash of the source., defaults to None
:type source_id: str, optional
:return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
:rtype: Tuple[List[str], Dict[str, Any], List[str], int]
"""
embeddings_data = chunker.create_chunks(loader, src)
@@ -264,25 +303,19 @@ class EmbedChain(JSONSerializable):
print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
return list(documents), metadatas, ids, count_new_chunks
def _format_result(self, results):
return [
(Document(page_content=result[0], metadata=result[1] or {}), result[2])
for result in zip(
results["documents"][0],
results["metadatas"][0],
results["distances"][0],
)
]
def retrieve_from_database(self, input_query, config: Optional[BaseLlmConfig] = None, where=None):
def retrieve_from_database(self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None) -> List[str]:
"""
Queries the vector database based on the given input query.
Gets relevant doc based on the query
:param input_query: The query to use.
:param config: The query configuration.
:param where: Optional. A dictionary of key-value pairs to filter the database results.
:return: The content of the document that matched your query.
:type input_query: str
:param config: The query configuration, defaults to None
:type config: Optional[BaseLlmConfig], optional
:param where: A dictionary of key-value pairs to filter the database results, defaults to None
:type where: _type_, optional
:return: List of contents of the document that matched your query
:rtype: List[str]
"""
query_config = config or self.llm.config
@@ -304,23 +337,24 @@ class EmbedChain(JSONSerializable):
return contents
def query(self, input_query, config: BaseLlmConfig = None, dry_run=False, where=None):
def query(self, input_query: str, config: BaseLlmConfig = None, dry_run=False, where: Optional[Dict] = None) -> str:
"""
Queries the vector database based on the given input query.
Gets relevant doc based on the query and then passes it to an
LLM as context to get the answer.
:param input_query: The query to use.
:param config: Optional. The `LlmConfig` instance to use as configuration options.
This is used for one method call. To persistently use a config, declare it during app init.
:param dry_run: Optional. A dry run does everything except send the resulting prompt to
the LLM. The purpose is to test the prompt, not the response.
You can use it to test your prompt, including the context provided
by the vector database's doc retrieval.
The only thing the dry run does not consider is the cut-off due to
the `max_tokens` parameter.
:param where: Optional. A dictionary of key-value pairs to filter the database results.
:return: The answer to the query.
:type input_query: str
:param config: The `LlmConfig` instance to use as configuration options. This is used for one method call.
To persistently use a config, declare it during app init., defaults to None
:type config: Optional[BaseLlmConfig], optional
:param dry_run: A dry run does everything except send the resulting prompt to
the LLM. The purpose is to test the prompt, not the response., defaults to False
:type dry_run: bool, optional
:param where: A dictionary of key-value pairs to filter the database results., defaults to None
:type where: Optional[Dict[str, str]], optional
:return: The answer to the query or the dry run result
:rtype: str
"""
contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where)
answer = self.llm.query(input_query=input_query, contexts=contexts, config=config, dry_run=dry_run)
@@ -331,24 +365,32 @@ class EmbedChain(JSONSerializable):
return answer
def chat(self, input_query, config: BaseLlmConfig = None, dry_run=False, where=None):
def chat(
self,
input_query: str,
config: Optional[BaseLlmConfig] = None,
dry_run=False,
where: Optional[Dict[str, str]] = None,
) -> str:
"""
Queries the vector database on the given input query.
Gets relevant doc based on the query and then passes it to an
LLM as context to get the answer.
Maintains the whole conversation in memory.
:param input_query: The query to use.
:param config: Optional. The `LlmConfig` instance to use as configuration options.
This is used for one method call. To persistently use a config, declare it during app init.
:param dry_run: Optional. A dry run does everything except send the resulting prompt to
the LLM. The purpose is to test the prompt, not the response.
You can use it to test your prompt, including the context provided
by the vector database's doc retrieval.
The only thing the dry run does not consider is the cut-off due to
the `max_tokens` parameter.
:param where: Optional. A dictionary of key-value pairs to filter the database results.
:return: The answer to the query.
:type input_query: str
:param config: The `LlmConfig` instance to use as configuration options. This is used for one method call.
To persistently use a config, declare it during app init., defaults to None
:type config: Optional[BaseLlmConfig], optional
:param dry_run: A dry run does everything except send the resulting prompt to
the LLM. The purpose is to test the prompt, not the response., defaults to False
:type dry_run: bool, optional
:param where: A dictionary of key-value pairs to filter the database results., defaults to None
:type where: Optional[Dict[str, str]], optional
:return: The answer to the query or the dry run result
:rtype: str
"""
contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where)
answer = self.llm.chat(input_query=input_query, contexts=contexts, config=config, dry_run=dry_run)
@@ -359,15 +401,18 @@ class EmbedChain(JSONSerializable):
return answer
def set_collection(self, collection_name):
def set_collection_name(self, name: str):
"""
Set the collection to use.
Set the name of the collection. A collection is an isolated space for vectors.
:param collection_name: The name of the collection to use.
Using `app.db.set_collection_name` method is preferred to this.
:param name: Name of the collection.
:type name: str
"""
self.db.set_collection_name(collection_name)
self.db.set_collection_name(name)
# Create the collection if it does not exist
self.db._get_or_create_collection(collection_name)
self.db._get_or_create_collection(name)
# TODO: Check whether it is necessary to assign to the `self.collection` attribute,
# since the main purpose is the creation.
@@ -378,8 +423,9 @@ class EmbedChain(JSONSerializable):
DEPRECATED IN FAVOR OF `db.count()`
:return: The number of embeddings.
:rtype: int
"""
logging.warning("DEPRECATION WARNING: Please use `db.count()` instead of `count()`.")
logging.warning("DEPRECATION WARNING: Please use `app.db.count()` instead of `app.count()`.")
return self.db.count()
def reset(self):
@@ -393,11 +439,14 @@ class EmbedChain(JSONSerializable):
thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("reset",))
thread_telemetry.start()
logging.warning("DEPRECATION WARNING: Please use `db.reset()` instead of `reset()`.")
logging.warning("DEPRECATION WARNING: Please use `app.db.reset()` instead of `App.reset()`.")
self.db.reset()
@retry(stop=stop_after_attempt(3), wait=wait_fixed(1))
def _send_telemetry_event(self, method: str, extra_metadata: Optional[dict] = None):
"""
Send telemetry event to the embedchain server. This is anonymous. It can be toggled off in `AppConfig`.
"""
if not self.config.collect_metrics:
return