docs: update docstrings (#565)
This commit is contained in:
@@ -6,11 +6,10 @@ import os
|
||||
import threading
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
from langchain.docstore.document import Document
|
||||
from tenacity import retry, stop_after_attempt, wait_fixed
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
@@ -46,8 +45,17 @@ class EmbedChain(JSONSerializable):
|
||||
Initializes the EmbedChain instance, sets up a vector DB client and
|
||||
creates a collection.
|
||||
|
||||
:param config: BaseAppConfig instance to load as configuration.
|
||||
:param system_prompt: Optional. System prompt string.
|
||||
:param config: Configuration just for the app, not the db or llm or embedder.
|
||||
:type config: BaseAppConfig
|
||||
:param llm: Instance of the LLM you want to use.
|
||||
:type llm: BaseLlm
|
||||
:param db: Instance of the Database to use, defaults to None
|
||||
:type db: BaseVectorDB, optional
|
||||
:param embedder: instance of the embedder to use, defaults to None
|
||||
:type embedder: BaseEmbedder, optional
|
||||
:param system_prompt: System prompt to use in the llm query, defaults to None
|
||||
:type system_prompt: Optional[str], optional
|
||||
:raises ValueError: No database or embedder provided.
|
||||
"""
|
||||
|
||||
self.config = config
|
||||
@@ -88,10 +96,13 @@ class EmbedChain(JSONSerializable):
|
||||
thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("init",))
|
||||
thread_telemetry.start()
|
||||
|
||||
def _load_or_generate_user_id(self):
|
||||
def _load_or_generate_user_id(self) -> str:
|
||||
"""
|
||||
Loads the user id from the config file if it exists, otherwise generates a new
|
||||
one and saves it to the config file.
|
||||
|
||||
:return: user id
|
||||
:rtype: str
|
||||
"""
|
||||
if not os.path.exists(CONFIG_DIR):
|
||||
os.makedirs(CONFIG_DIR)
|
||||
@@ -110,9 +121,9 @@ class EmbedChain(JSONSerializable):
|
||||
|
||||
def add(
|
||||
self,
|
||||
source,
|
||||
source: Any,
|
||||
data_type: Optional[DataType] = None,
|
||||
metadata: Optional[Dict] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[AddConfig] = None,
|
||||
):
|
||||
"""
|
||||
@@ -121,12 +132,17 @@ class EmbedChain(JSONSerializable):
|
||||
and then stores the embedding to vector database.
|
||||
|
||||
:param source: The data to embed, can be a URL, local file or raw content, depending on the data type.
|
||||
:param data_type: Optional. Automatically detected, but can be forced with this argument.
|
||||
The type of the data to add.
|
||||
:param metadata: Optional. Metadata associated with the data source.
|
||||
:param config: Optional. The `AddConfig` instance to use as configuration
|
||||
options.
|
||||
:type source: Any
|
||||
:param data_type: Automatically detected, but can be forced with this argument. The type of the data to add,
|
||||
defaults to None
|
||||
:type data_type: Optional[DataType], optional
|
||||
:param metadata: Metadata associated with the data source., defaults to None
|
||||
:type metadata: Optional[Dict[str, Any]], optional
|
||||
:param config: The `AddConfig` instance to use as configuration options., defaults to None
|
||||
:type config: Optional[AddConfig], optional
|
||||
:raises ValueError: Invalid data type
|
||||
:return: source_id, a md5-hash of the source, in hexadecimal representation.
|
||||
:rtype: str
|
||||
"""
|
||||
if config is None:
|
||||
config = AddConfig()
|
||||
@@ -177,39 +193,62 @@ class EmbedChain(JSONSerializable):
|
||||
|
||||
return source_id
|
||||
|
||||
def add_local(self, source, data_type=None, metadata=None, config: AddConfig = None):
|
||||
def add_local(
|
||||
self,
|
||||
source: Any,
|
||||
data_type: Optional[DataType] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[AddConfig] = None,
|
||||
):
|
||||
"""
|
||||
Warning:
|
||||
This method is deprecated and will be removed in future versions. Use `add` instead.
|
||||
|
||||
Adds the data from the given URL to the vector db.
|
||||
Loads the data, chunks it, create embedding for each chunk
|
||||
and then stores the embedding to vector database.
|
||||
|
||||
Warning:
|
||||
This method is deprecated and will be removed in future versions. Use `add` instead.
|
||||
|
||||
:param source: The data to embed, can be a URL, local file or raw content, depending on the data type.
|
||||
:param data_type: Optional. Automatically detected, but can be forced with this argument.
|
||||
The type of the data to add.
|
||||
:param metadata: Optional. Metadata associated with the data source.
|
||||
:param config: Optional. The `AddConfig` instance to use as configuration
|
||||
options.
|
||||
:return: md5-hash of the source, in hexadecimal representation.
|
||||
:type source: Any
|
||||
:param data_type: Automatically detected, but can be forced with this argument. The type of the data to add,
|
||||
defaults to None
|
||||
:type data_type: Optional[DataType], optional
|
||||
:param metadata: Metadata associated with the data source., defaults to None
|
||||
:type metadata: Optional[Dict[str, Any]], optional
|
||||
:param config: The `AddConfig` instance to use as configuration options., defaults to None
|
||||
:type config: Optional[AddConfig], optional
|
||||
:raises ValueError: Invalid data type
|
||||
:return: source_id, a md5-hash of the source, in hexadecimal representation.
|
||||
:rtype: str
|
||||
"""
|
||||
logging.warning(
|
||||
"The `add_local` method is deprecated and will be removed in future versions. Please use the `add` method for both local and remote files." # noqa: E501
|
||||
)
|
||||
return self.add(source=source, data_type=data_type, metadata=metadata, config=config)
|
||||
|
||||
def load_and_embed(self, loader: BaseLoader, chunker: BaseChunker, src, metadata=None, source_id=None):
|
||||
"""
|
||||
Loads the data from the given URL, chunks it, and adds it to database.
|
||||
def load_and_embed(
|
||||
self,
|
||||
loader: BaseLoader,
|
||||
chunker: BaseChunker,
|
||||
src: Any,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
source_id: Optional[str] = None,
|
||||
) -> Tuple[List[str], Dict[str, Any], List[str], int]:
|
||||
"""The loader to use to load the data.
|
||||
|
||||
:param loader: The loader to use to load the data.
|
||||
:type loader: BaseLoader
|
||||
:param chunker: The chunker to use to chunk the data.
|
||||
:param src: The data to be handled by the loader. Can be a URL for
|
||||
remote sources or local content for local loaders.
|
||||
:param metadata: Optional. Metadata associated with the data source.
|
||||
:param source_id: Hexadecimal hash of the source.
|
||||
:type chunker: BaseChunker
|
||||
:param src: The data to be handled by the loader.
|
||||
Can be a URL for remote sources or local content for local loaders.
|
||||
:type src: Any
|
||||
:param metadata: Metadata associated with the data source., defaults to None
|
||||
:type metadata: Dict[str, Any], optional
|
||||
:param source_id: Hexadecimal hash of the source., defaults to None
|
||||
:type source_id: str, optional
|
||||
:return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
|
||||
:rtype: Tuple[List[str], Dict[str, Any], List[str], int]
|
||||
"""
|
||||
embeddings_data = chunker.create_chunks(loader, src)
|
||||
|
||||
@@ -264,25 +303,19 @@ class EmbedChain(JSONSerializable):
|
||||
print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
|
||||
return list(documents), metadatas, ids, count_new_chunks
|
||||
|
||||
def _format_result(self, results):
|
||||
return [
|
||||
(Document(page_content=result[0], metadata=result[1] or {}), result[2])
|
||||
for result in zip(
|
||||
results["documents"][0],
|
||||
results["metadatas"][0],
|
||||
results["distances"][0],
|
||||
)
|
||||
]
|
||||
|
||||
def retrieve_from_database(self, input_query, config: Optional[BaseLlmConfig] = None, where=None):
|
||||
def retrieve_from_database(self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None) -> List[str]:
|
||||
"""
|
||||
Queries the vector database based on the given input query.
|
||||
Gets relevant doc based on the query
|
||||
|
||||
:param input_query: The query to use.
|
||||
:param config: The query configuration.
|
||||
:param where: Optional. A dictionary of key-value pairs to filter the database results.
|
||||
:return: The content of the document that matched your query.
|
||||
:type input_query: str
|
||||
:param config: The query configuration, defaults to None
|
||||
:type config: Optional[BaseLlmConfig], optional
|
||||
:param where: A dictionary of key-value pairs to filter the database results, defaults to None
|
||||
:type where: _type_, optional
|
||||
:return: List of contents of the document that matched your query
|
||||
:rtype: List[str]
|
||||
"""
|
||||
query_config = config or self.llm.config
|
||||
|
||||
@@ -304,23 +337,24 @@ class EmbedChain(JSONSerializable):
|
||||
|
||||
return contents
|
||||
|
||||
def query(self, input_query, config: BaseLlmConfig = None, dry_run=False, where=None):
|
||||
def query(self, input_query: str, config: BaseLlmConfig = None, dry_run=False, where: Optional[Dict] = None) -> str:
|
||||
"""
|
||||
Queries the vector database based on the given input query.
|
||||
Gets relevant doc based on the query and then passes it to an
|
||||
LLM as context to get the answer.
|
||||
|
||||
:param input_query: The query to use.
|
||||
:param config: Optional. The `LlmConfig` instance to use as configuration options.
|
||||
This is used for one method call. To persistently use a config, declare it during app init.
|
||||
:param dry_run: Optional. A dry run does everything except send the resulting prompt to
|
||||
the LLM. The purpose is to test the prompt, not the response.
|
||||
You can use it to test your prompt, including the context provided
|
||||
by the vector database's doc retrieval.
|
||||
The only thing the dry run does not consider is the cut-off due to
|
||||
the `max_tokens` parameter.
|
||||
:param where: Optional. A dictionary of key-value pairs to filter the database results.
|
||||
:return: The answer to the query.
|
||||
:type input_query: str
|
||||
:param config: The `LlmConfig` instance to use as configuration options. This is used for one method call.
|
||||
To persistently use a config, declare it during app init., defaults to None
|
||||
:type config: Optional[BaseLlmConfig], optional
|
||||
:param dry_run: A dry run does everything except send the resulting prompt to
|
||||
the LLM. The purpose is to test the prompt, not the response., defaults to False
|
||||
:type dry_run: bool, optional
|
||||
:param where: A dictionary of key-value pairs to filter the database results., defaults to None
|
||||
:type where: Optional[Dict[str, str]], optional
|
||||
:return: The answer to the query or the dry run result
|
||||
:rtype: str
|
||||
"""
|
||||
contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where)
|
||||
answer = self.llm.query(input_query=input_query, contexts=contexts, config=config, dry_run=dry_run)
|
||||
@@ -331,24 +365,32 @@ class EmbedChain(JSONSerializable):
|
||||
|
||||
return answer
|
||||
|
||||
def chat(self, input_query, config: BaseLlmConfig = None, dry_run=False, where=None):
|
||||
def chat(
|
||||
self,
|
||||
input_query: str,
|
||||
config: Optional[BaseLlmConfig] = None,
|
||||
dry_run=False,
|
||||
where: Optional[Dict[str, str]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Queries the vector database on the given input query.
|
||||
Gets relevant doc based on the query and then passes it to an
|
||||
LLM as context to get the answer.
|
||||
|
||||
Maintains the whole conversation in memory.
|
||||
|
||||
:param input_query: The query to use.
|
||||
:param config: Optional. The `LlmConfig` instance to use as configuration options.
|
||||
This is used for one method call. To persistently use a config, declare it during app init.
|
||||
:param dry_run: Optional. A dry run does everything except send the resulting prompt to
|
||||
the LLM. The purpose is to test the prompt, not the response.
|
||||
You can use it to test your prompt, including the context provided
|
||||
by the vector database's doc retrieval.
|
||||
The only thing the dry run does not consider is the cut-off due to
|
||||
the `max_tokens` parameter.
|
||||
:param where: Optional. A dictionary of key-value pairs to filter the database results.
|
||||
:return: The answer to the query.
|
||||
:type input_query: str
|
||||
:param config: The `LlmConfig` instance to use as configuration options. This is used for one method call.
|
||||
To persistently use a config, declare it during app init., defaults to None
|
||||
:type config: Optional[BaseLlmConfig], optional
|
||||
:param dry_run: A dry run does everything except send the resulting prompt to
|
||||
the LLM. The purpose is to test the prompt, not the response., defaults to False
|
||||
:type dry_run: bool, optional
|
||||
:param where: A dictionary of key-value pairs to filter the database results., defaults to None
|
||||
:type where: Optional[Dict[str, str]], optional
|
||||
:return: The answer to the query or the dry run result
|
||||
:rtype: str
|
||||
"""
|
||||
contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where)
|
||||
answer = self.llm.chat(input_query=input_query, contexts=contexts, config=config, dry_run=dry_run)
|
||||
@@ -359,15 +401,18 @@ class EmbedChain(JSONSerializable):
|
||||
|
||||
return answer
|
||||
|
||||
def set_collection(self, collection_name):
|
||||
def set_collection_name(self, name: str):
|
||||
"""
|
||||
Set the collection to use.
|
||||
Set the name of the collection. A collection is an isolated space for vectors.
|
||||
|
||||
:param collection_name: The name of the collection to use.
|
||||
Using `app.db.set_collection_name` method is preferred to this.
|
||||
|
||||
:param name: Name of the collection.
|
||||
:type name: str
|
||||
"""
|
||||
self.db.set_collection_name(collection_name)
|
||||
self.db.set_collection_name(name)
|
||||
# Create the collection if it does not exist
|
||||
self.db._get_or_create_collection(collection_name)
|
||||
self.db._get_or_create_collection(name)
|
||||
# TODO: Check whether it is necessary to assign to the `self.collection` attribute,
|
||||
# since the main purpose is the creation.
|
||||
|
||||
@@ -378,8 +423,9 @@ class EmbedChain(JSONSerializable):
|
||||
DEPRECATED IN FAVOR OF `db.count()`
|
||||
|
||||
:return: The number of embeddings.
|
||||
:rtype: int
|
||||
"""
|
||||
logging.warning("DEPRECATION WARNING: Please use `db.count()` instead of `count()`.")
|
||||
logging.warning("DEPRECATION WARNING: Please use `app.db.count()` instead of `app.count()`.")
|
||||
return self.db.count()
|
||||
|
||||
def reset(self):
|
||||
@@ -393,11 +439,14 @@ class EmbedChain(JSONSerializable):
|
||||
thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("reset",))
|
||||
thread_telemetry.start()
|
||||
|
||||
logging.warning("DEPRECATION WARNING: Please use `db.reset()` instead of `reset()`.")
|
||||
logging.warning("DEPRECATION WARNING: Please use `app.db.reset()` instead of `App.reset()`.")
|
||||
self.db.reset()
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_fixed(1))
|
||||
def _send_telemetry_event(self, method: str, extra_metadata: Optional[dict] = None):
|
||||
"""
|
||||
Send telemetry event to the embedchain server. This is anonymous. It can be toggled off in `AppConfig`.
|
||||
"""
|
||||
if not self.config.collect_metrics:
|
||||
return
|
||||
|
||||
|
||||
Reference in New Issue
Block a user