Rename embedchain to mem0 and open-source code for long-term memory (#1474)

Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
Author: Taranjeet Singh
Date: 2024-07-12 07:51:33 -07:00
Committed by: GitHub
Parent: 83e8c97295
Commit: f842a92e25
665 changed files with 9427 additions and 6592 deletions

View File

@@ -0,0 +1,10 @@
import importlib.metadata
__version__ = importlib.metadata.version(__package__ or __name__)
from embedchain.app import App # noqa: F401
from embedchain.client import Client # noqa: F401
from embedchain.pipeline import Pipeline # noqa: F401
# Set up the user directory if it doesn't exist already
Client.setup()
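
A minimal usage sketch for the package initialized above (the data URL and question are illustrative, and an OPENAI_API_KEY is assumed to be available for the default embedder and LLM):

import os
from embedchain import App

# assumes a valid OpenAI key; the default App uses the OpenAI embedder and LLM
os.environ.setdefault("OPENAI_API_KEY", "sk-...")  # placeholder key

app = App()  # defaults: ChromaDB vector store, OpenAI embedder, OpenAI LLM
app.add("https://docs.embedchain.ai")  # illustrative data source
print(app.query("What is Embedchain?"))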

View File

@@ -0,0 +1,116 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
script_location = embedchain:migrations
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library.
# Any required deps can be installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to alembic/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
sqlalchemy.url = driver://user:pass@localhost/dbname
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = --fix REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = WARN
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
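
For reference, the migrations configured above can also be run from Python with Alembic's programmatic API; a hedged sketch (the "alembic.ini" path and the SQLite URL are assumptions for illustration):

from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")  # path to the configuration shown above
cfg.set_main_option("sqlalchemy.url", "sqlite:///embedchain.db")  # override the placeholder driver URL
command.upgrade(cfg, "head")  # apply all migrations up to the latest revision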

View File

@@ -0,0 +1,521 @@
import ast
import concurrent.futures
import json
import logging
import os
from typing import Any, Optional, Union
import requests
import yaml
from tqdm import tqdm
from mem0 import Mem0
from embedchain.cache import (
Config,
ExactMatchEvaluation,
SearchDistanceEvaluation,
cache,
gptcache_data_manager,
gptcache_pre_function,
)
from embedchain.client import Client
from embedchain.config import AppConfig, CacheConfig, ChunkerConfig, Mem0Config
from embedchain.core.db.database import get_session, init_db, setup_engine
from embedchain.core.db.models import DataSource
from embedchain.embedchain import EmbedChain
from embedchain.embedder.base import BaseEmbedder
from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.evaluation.base import BaseMetric
from embedchain.evaluation.metrics import AnswerRelevance, ContextRelevance, Groundedness
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
from embedchain.llm.openai import OpenAILlm
from embedchain.telemetry.posthog import AnonymousTelemetry
from embedchain.utils.evaluation import EvalData, EvalMetric
from embedchain.utils.misc import validate_config
from embedchain.vectordb.base import BaseVectorDB
from embedchain.vectordb.chroma import ChromaDB
logger = logging.getLogger(__name__)
@register_deserializable
class App(EmbedChain):
"""
EmbedChain App lets you create an LLM-powered app for your unstructured
data by defining your chosen data source, embedding model,
and vector database.
"""
def __init__(
self,
id: str = None,
name: str = None,
config: AppConfig = None,
db: BaseVectorDB = None,
embedding_model: BaseEmbedder = None,
llm: BaseLlm = None,
config_data: dict = None,
auto_deploy: bool = False,
chunker: ChunkerConfig = None,
cache_config: CacheConfig = None,
memory_config: Mem0Config = None,
log_level: int = logging.WARN,
):
"""
Initialize a new `App` instance.
:param config: Configuration for the pipeline, defaults to None
:type config: AppConfig, optional
:param db: The database to use for storing and retrieving embeddings, defaults to None
:type db: BaseVectorDB, optional
:param embedding_model: The embedding model used to calculate embeddings, defaults to None
:type embedding_model: BaseEmbedder, optional
:param llm: The LLM used to generate answers for queries, defaults to None
:type llm: BaseLlm, optional
:param config_data: Config dictionary, defaults to None
:type config_data: dict, optional
:param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
:type auto_deploy: bool, optional
:raises Exception: If an error occurs while creating the pipeline
"""
if id and config_data:
raise Exception("Cannot provide both id and config. Please provide only one of them.")
if id and name:
raise Exception("Cannot provide both id and name. Please provide only one of them.")
if name and config:
raise Exception("Cannot provide both name and config. Please provide only one of them.")
# Initialize the metadata db for the app
setup_engine(database_uri=os.environ.get("EMBEDCHAIN_DB_URI"))
init_db()
self.auto_deploy = auto_deploy
# Store the dict config as an attribute to be able to send it
self.config_data = config_data if (config_data and validate_config(config_data)) else None
self.client = None
# pipeline_id from the backend
self.id = None
self.chunker = ChunkerConfig(**chunker) if chunker else None
self.cache_config = cache_config
self.memory_config = memory_config
self.config = config or AppConfig()
self.name = self.config.name
self.config.id = self.local_id = "default-app-id" if self.config.id is None else self.config.id
if id is not None:
# Init client first since user is trying to fetch the pipeline
# details from the platform
self._init_client()
pipeline_details = self._get_pipeline(id)
self.config.id = self.local_id = pipeline_details["metadata"]["local_id"]
self.id = id
if name is not None:
self.name = name
self.embedding_model = embedding_model or OpenAIEmbedder()
self.db = db or ChromaDB()
self.llm = llm or OpenAILlm()
self._init_db()
# Session for the metadata db
self.db_session = get_session()
# If cache_config is provided, initialize the cache
if self.cache_config is not None:
self._init_cache()
# If memory_config is provided, initialize the Mem0 client
self.mem0_client = None
if self.memory_config is not None:
self.mem0_client = Mem0(api_key=self.memory_config.api_key)
# Send anonymous telemetry
self._telemetry_props = {"class": self.__class__.__name__}
self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
self.telemetry.capture(event_name="init", properties=self._telemetry_props)
self.user_asks = []
if self.auto_deploy:
self.deploy()
def _init_db(self):
"""
Initialize the database.
"""
self.db._set_embedder(self.embedding_model)
self.db._initialize()
self.db.set_collection_name(self.db.config.collection_name)
def _init_cache(self):
if self.cache_config.similarity_eval_config.strategy == "exact":
similarity_eval_func = ExactMatchEvaluation()
else:
similarity_eval_func = SearchDistanceEvaluation(
max_distance=self.cache_config.similarity_eval_config.max_distance,
positive=self.cache_config.similarity_eval_config.positive,
)
cache.init(
pre_embedding_func=gptcache_pre_function,
embedding_func=self.embedding_model.to_embeddings,
data_manager=gptcache_data_manager(vector_dimension=self.embedding_model.vector_dimension),
similarity_evaluation=similarity_eval_func,
config=Config(**self.cache_config.init_config.as_dict()),
)
def _init_client(self):
"""
Initialize the client.
"""
config = Client.load_config()
if config.get("api_key"):
self.client = Client()
else:
api_key = input(
"🔑 Enter your Embedchain API key. You can find the API key at https://app.embedchain.ai/settings/keys/ \n" # noqa: E501
)
self.client = Client(api_key=api_key)
def _get_pipeline(self, id):
"""
Get existing pipeline
"""
print("🛠️ Fetching pipeline details from the platform...")
url = f"{self.client.host}/api/v1/pipelines/{id}/cli/"
r = requests.get(
url,
headers={"Authorization": f"Token {self.client.api_key}"},
)
if r.status_code == 404:
raise Exception(f"❌ Pipeline with id {id} not found!")
print(
f"🎉 Pipeline loaded successfully! Pipeline url: https://app.embedchain.ai/pipelines/{r.json()['id']}\n" # noqa: E501
)
return r.json()
def _create_pipeline(self):
"""
Create a pipeline on the platform.
"""
print("🛠️ Creating pipeline on the platform...")
# self.config_data is a dict. Pass it inside the key 'yaml_config' to the backend
payload = {
"yaml_config": json.dumps(self.config_data),
"name": self.name,
"local_id": self.local_id,
}
url = f"{self.client.host}/api/v1/pipelines/cli/create/"
r = requests.post(
url,
json=payload,
headers={"Authorization": f"Token {self.client.api_key}"},
)
if r.status_code not in [200, 201]:
raise Exception(f"❌ Error occurred while creating pipeline. API response: {r.text}")
if r.status_code == 200:
print(
f"🎉🎉🎉 Existing pipeline found! View your pipeline: https://app.embedchain.ai/pipelines/{r.json()['id']}\n" # noqa: E501
) # noqa: E501
elif r.status_code == 201:
print(
f"🎉🎉🎉 Pipeline created successfully! View your pipeline: https://app.embedchain.ai/pipelines/{r.json()['id']}\n" # noqa: E501
)
return r.json()
def _get_presigned_url(self, data_type, data_value):
payload = {"data_type": data_type, "data_value": data_value}
r = requests.post(
f"{self.client.host}/api/v1/pipelines/{self.id}/cli/presigned_url/",
json=payload,
headers={"Authorization": f"Token {self.client.api_key}"},
)
r.raise_for_status()
return r.json()
def _upload_file_to_presigned_url(self, presigned_url, file_path):
try:
with open(file_path, "rb") as file:
response = requests.put(presigned_url, data=file)
response.raise_for_status()
return response.status_code == 200
except Exception as e:
logger.exception(f"Error occurred during file upload: {str(e)}")
print("❌ Error occurred during file upload!")
return False
def _upload_data_to_pipeline(self, data_type, data_value, metadata=None):
payload = {
"data_type": data_type,
"data_value": data_value,
"metadata": metadata,
}
try:
self._send_api_request(f"/api/v1/pipelines/{self.id}/cli/add/", payload)
# print the local file path if user tries to upload a local file
printed_value = metadata.get("file_path") if metadata.get("file_path") else data_value
print(f"✅ Data of type: {data_type}, value: {printed_value} added successfully.")
except Exception as e:
print(f"❌ Error occurred during data upload for type {data_type}!. Error: {str(e)}")
def _send_api_request(self, endpoint, payload):
url = f"{self.client.host}{endpoint}"
headers = {"Authorization": f"Token {self.client.api_key}"}
response = requests.post(url, json=payload, headers=headers)
response.raise_for_status()
return response
def _process_and_upload_data(self, data_hash, data_type, data_value):
if os.path.isabs(data_value):
presigned_url_data = self._get_presigned_url(data_type, data_value)
presigned_url = presigned_url_data["presigned_url"]
s3_key = presigned_url_data["s3_key"]
if self._upload_file_to_presigned_url(presigned_url, file_path=data_value):
metadata = {"file_path": data_value, "s3_key": s3_key}
data_value = presigned_url
else:
logger.error(f"File upload failed for hash: {data_hash}")
return False
else:
if data_type == "qna_pair":
data_value = list(ast.literal_eval(data_value))
metadata = {}
try:
self._upload_data_to_pipeline(data_type, data_value, metadata)
self._mark_data_as_uploaded(data_hash)
return True
except Exception:
print(f"❌ Error occurred during data upload for hash {data_hash}!")
return False
def _mark_data_as_uploaded(self, data_hash):
self.db_session.query(DataSource).filter_by(hash=data_hash, app_id=self.local_id).update({"is_uploaded": 1})
def get_data_sources(self):
data_sources = self.db_session.query(DataSource).filter_by(app_id=self.local_id).all()
results = []
for row in data_sources:
results.append({"data_type": row.type, "data_value": row.value, "metadata": row.meta_data})
return results
def deploy(self):
if self.client is None:
self._init_client()
pipeline_data = self._create_pipeline()
self.id = pipeline_data["id"]
results = self.db_session.query(DataSource).filter_by(app_id=self.local_id, is_uploaded=0).all()
if len(results) > 0:
print("🛠️ Adding data to your pipeline...")
for result in results:
data_hash, data_type, data_value = result.hash, result.data_type, result.data_value
self._process_and_upload_data(data_hash, data_type, data_value)
# Send anonymous telemetry
self.telemetry.capture(event_name="deploy", properties=self._telemetry_props)
@classmethod
def from_config(
cls,
config_path: Optional[str] = None,
config: Optional[dict[str, Any]] = None,
auto_deploy: bool = False,
yaml_path: Optional[str] = None,
):
"""
Instantiate an App object from a configuration.
:param config_path: Path to the YAML or JSON configuration file.
:type config_path: Optional[str]
:param config: A dictionary containing the configuration.
:type config: Optional[dict[str, Any]]
:param auto_deploy: Whether to deploy the app automatically, defaults to False
:type auto_deploy: bool, optional
:param yaml_path: (Deprecated) Path to the YAML configuration file. Use config_path instead.
:type yaml_path: Optional[str]
:return: An instance of the App class.
:rtype: App
"""
# Backward compatibility for yaml_path
if yaml_path and not config_path:
config_path = yaml_path
if config_path and config:
raise ValueError("Please provide only one of config_path or config.")
config_data = None
if config_path:
file_extension = os.path.splitext(config_path)[1]
with open(config_path, "r", encoding="UTF-8") as file:
if file_extension in [".yaml", ".yml"]:
config_data = yaml.safe_load(file)
elif file_extension == ".json":
config_data = json.load(file)
else:
raise ValueError("config_path must be a path to a YAML or JSON file.")
elif config and isinstance(config, dict):
config_data = config
else:
logger.error(
"Please provide either a config file path (YAML or JSON) or a config dictionary. Falling back to defaults because no config is provided.", # noqa: E501
)
config_data = {}
# Validate the config
validate_config(config_data)
app_config_data = config_data.get("app", {}).get("config", {})
vector_db_config_data = config_data.get("vectordb", {})
embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
memory_config_data = config_data.get("memory", {})
llm_config_data = config_data.get("llm", {})
chunker_config_data = config_data.get("chunker", {})
cache_config_data = config_data.get("cache", None)
app_config = AppConfig(**app_config_data)
memory_config = Mem0Config(**memory_config_data) if memory_config_data else None
vector_db_provider = vector_db_config_data.get("provider", "chroma")
vector_db = VectorDBFactory.create(vector_db_provider, vector_db_config_data.get("config", {}))
if llm_config_data:
# Initialize the metadata db for the app here since LlmFactory needs it to initialize
# the llm memory
setup_engine(database_uri=os.environ.get("EMBEDCHAIN_DB_URI"))
init_db()
llm_provider = llm_config_data.get("provider", "openai")
llm = LlmFactory.create(llm_provider, llm_config_data.get("config", {}))
else:
llm = None
embedding_model_provider = embedding_model_config_data.get("provider", "openai")
embedding_model = EmbedderFactory.create(
embedding_model_provider, embedding_model_config_data.get("config", {})
)
if cache_config_data is not None:
cache_config = CacheConfig.from_config(cache_config_data)
else:
cache_config = None
return cls(
config=app_config,
llm=llm,
db=vector_db,
embedding_model=embedding_model,
config_data=config_data,
auto_deploy=auto_deploy,
chunker=chunker_config_data,
cache_config=cache_config,
memory_config=memory_config,
)
def _eval(self, dataset: list[EvalData], metric: Union[BaseMetric, str]):
"""
Evaluate the app on a dataset for a given metric.
"""
metric_str = metric.name if isinstance(metric, BaseMetric) else metric
eval_class_map = {
EvalMetric.CONTEXT_RELEVANCY.value: ContextRelevance,
EvalMetric.ANSWER_RELEVANCY.value: AnswerRelevance,
EvalMetric.GROUNDEDNESS.value: Groundedness,
}
if metric_str in eval_class_map:
return eval_class_map[metric_str]().evaluate(dataset)
# Handle the case for custom metrics
if isinstance(metric, BaseMetric):
return metric.evaluate(dataset)
else:
raise ValueError(f"Invalid metric: {metric}")
def evaluate(
self,
questions: Union[str, list[str]],
metrics: Optional[list[Union[BaseMetric, str]]] = None,
num_workers: int = 4,
):
"""
Evaluate the app on one or more questions.
:param questions: A question or a list of questions to evaluate.
:type questions: Union[str, list[str]]
:param metrics: A list of metrics to evaluate. Defaults to all metrics.
:type metrics: Optional[list[Union[BaseMetric, str]]]
:param num_workers: Number of workers to use for parallel processing.
:type num_workers: int
:return: A dictionary containing the evaluation results.
:rtype: dict
"""
if "OPENAI_API_KEY" not in os.environ:
raise ValueError("Please set the OPENAI_API_KEY environment variable with permission to use `gpt4` model.")
queries, answers, contexts = [], [], []
if isinstance(questions, list):
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
future_to_data = {executor.submit(self.query, q, citations=True): q for q in questions}
for future in tqdm(
concurrent.futures.as_completed(future_to_data),
total=len(future_to_data),
desc="Getting answer and contexts for questions",
):
question = future_to_data[future]
queries.append(question)
answer, context = future.result()
answers.append(answer)
contexts.append(list(map(lambda x: x[0], context)))
else:
answer, context = self.query(questions, citations=True)
queries = [questions]
answers = [answer]
contexts = [list(map(lambda x: x[0], context))]
metrics = metrics or [
EvalMetric.CONTEXT_RELEVANCY.value,
EvalMetric.ANSWER_RELEVANCY.value,
EvalMetric.GROUNDEDNESS.value,
]
logger.info(f"Collecting data from {len(queries)} questions for evaluation...")
dataset = []
for q, a, c in zip(queries, answers, contexts):
dataset.append(EvalData(question=q, answer=a, contexts=c))
logger.info(f"Evaluating {len(dataset)} data points...")
result = {}
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
future_to_metric = {executor.submit(self._eval, dataset, metric): metric for metric in metrics}
for future in tqdm(
concurrent.futures.as_completed(future_to_metric),
total=len(future_to_metric),
desc="Evaluating metrics",
):
metric = future_to_metric[future]
if isinstance(metric, BaseMetric):
result[metric.name] = future.result()
else:
result[metric] = future.result()
if self.config.collect_metrics:
telemetry_props = self._telemetry_props
metrics_names = []
for metric in metrics:
if isinstance(metric, BaseMetric):
metrics_names.append(metric.name)
else:
metrics_names.append(metric)
telemetry_props["metrics"] = metrics_names
self.telemetry.capture(event_name="evaluate", properties=telemetry_props)
return result
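
To make the from_config() path above concrete, a hedged usage sketch (the provider names follow the defaults in from_config; the exact config schema is enforced by validate_config, so treat these keys as illustrative):

from embedchain import App

config = {
    "app": {"config": {"name": "demo-app"}},
    "llm": {"provider": "openai", "config": {"model": "gpt-3.5-turbo"}},
    "vectordb": {"provider": "chroma", "config": {"collection_name": "demo"}},
    "embedder": {"provider": "openai"},
}

app = App.from_config(config=config)
app.add("https://docs.embedchain.ai")  # illustrative data source
print(app.evaluate(questions="What is Embedchain?"))  # runs context relevancy, answer relevancy and groundedness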

View File

@@ -0,0 +1,5 @@
from embedchain.bots.poe import PoeBot # noqa: F401
from embedchain.bots.whatsapp import WhatsAppBot # noqa: F401
# TODO: fix discord import
# from embedchain.bots.discord import DiscordBot

View File

@@ -0,0 +1,46 @@
from typing import Any
from embedchain import App
from embedchain.config import AddConfig, AppConfig, BaseLlmConfig
from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.helpers.json_serializable import (JSONSerializable,
register_deserializable)
from embedchain.llm.openai import OpenAILlm
from embedchain.vectordb.chroma import ChromaDB
@register_deserializable
class BaseBot(JSONSerializable):
def __init__(self):
self.app = App(config=AppConfig(), llm=OpenAILlm(), db=ChromaDB(), embedding_model=OpenAIEmbedder())
def add(self, data: Any, config: AddConfig = None):
"""
Add data to the bot (to the vector database).
Auto-detects the data type, so some data types might not be usable.
:param data: data to embed
:type data: Any
:param config: configuration class instance, defaults to None
:type config: AddConfig, optional
"""
config = config if config else AddConfig()
self.app.add(data, config=config)
def query(self, query: str, config: BaseLlmConfig = None) -> str:
"""
Query the bot
:param query: the user query
:type query: str
:param config: configuration class instance, defaults to None
:type config: BaseLlmConfig, optional
:return: Answer
:rtype: str
"""
config = config
return self.app.query(query, config=config)
def start(self):
"""Start the bot's functionality."""
raise NotImplementedError("Subclasses must implement the start method.")
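
A minimal sketch of a concrete bot built on this base class (the EchoBot name and its terminal loop are purely illustrative):

from embedchain.bots.base import BaseBot

class EchoBot(BaseBot):
    """Toy bot that answers questions from the terminal."""

    def start(self):
        # simple REPL; the real bots below wire this into Discord, Poe, Slack or WhatsApp
        while True:
            question = input("you> ")
            if question in {"exit", "quit"}:
                break
            print(self.query(question))

if __name__ == "__main__":
    EchoBot().start()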

View File

@@ -0,0 +1,128 @@
import argparse
import logging
import os
from embedchain.helpers.json_serializable import register_deserializable
from .base import BaseBot
try:
import discord
from discord import app_commands
from discord.ext import commands
except ModuleNotFoundError:
raise ModuleNotFoundError(
"The required dependencies for Discord are not installed." "Please install with `pip install discord==2.3.2`"
) from None
logger = logging.getLogger(__name__)
intents = discord.Intents.default()
intents.message_content = True
client = discord.Client(intents=intents)
tree = app_commands.CommandTree(client)
# Invite link example
# https://discord.com/api/oauth2/authorize?client_id={DISCORD_CLIENT_ID}&permissions=2048&scope=bot
@register_deserializable
class DiscordBot(BaseBot):
def __init__(self, *args, **kwargs):
BaseBot.__init__(self, *args, **kwargs)
def add_data(self, message):
data = message.split(" ")[-1]
try:
self.add(data)
response = f"Added data from: {data}"
except Exception:
logger.exception(f"Failed to add data {data}.")
response = "Some error occurred while adding data."
return response
def ask_bot(self, message):
try:
response = self.query(message)
except Exception:
logger.exception(f"Failed to query {message}.")
response = "An error occurred. Please try again!"
return response
def start(self):
client.run(os.environ["DISCORD_BOT_TOKEN"])
# @tree decorator cannot be used in a class. A global discord_bot is used as a workaround.
@tree.command(name="question", description="ask embedchain")
async def query_command(interaction: discord.Interaction, question: str):
await interaction.response.defer()
member = client.guilds[0].get_member(client.user.id)
logger.info(f"User: {member}, Query: {question}")
try:
answer = discord_bot.ask_bot(question)
if args.include_question:
response = f"> {question}\n\n{answer}"
else:
response = answer
await interaction.followup.send(response)
except Exception as e:
await interaction.followup.send("An error occurred. Please try again!")
logger.error("Error occurred during 'query' command:", e)
@tree.command(name="add", description="add new content to the embedchain database")
async def add_command(interaction: discord.Interaction, url_or_text: str):
await interaction.response.defer()
member = client.guilds[0].get_member(client.user.id)
logger.info(f"User: {member}, Add: {url_or_text}")
try:
response = discord_bot.add_data(url_or_text)
await interaction.followup.send(response)
except Exception as e:
await interaction.followup.send("An error occurred. Please try again!")
logger.error("Error occurred during 'add' command:", e)
@tree.command(name="ping", description="Simple ping pong command")
async def ping(interaction: discord.Interaction):
await interaction.response.send_message("Pong", ephemeral=True)
@tree.error
async def on_app_command_error(interaction: discord.Interaction, error: discord.app_commands.AppCommandError) -> None:
if isinstance(error, commands.CommandNotFound):
await interaction.followup.send("Invalid command. Please refer to the documentation for correct syntax.")
else:
logger.error("Error occurred during command execution:", error)
@client.event
async def on_ready():
# TODO: Sync in admin command, to not hit rate limits.
# This might be overkill for most users, and it would require setting a guild or user id where sync is allowed.
await tree.sync()
logger.debug("Command tree synced")
logger.info(f"Logged in as {client.user.name}")
def start_command():
parser = argparse.ArgumentParser(description="EmbedChain DiscordBot command line interface")
parser.add_argument(
"--include-question",
help="include question in query reply, otherwise it is hidden behind the slash command.",
action="store_true",
)
global args
args = parser.parse_args()
global discord_bot
discord_bot = DiscordBot()
discord_bot.start()
if __name__ == "__main__":
start_command()
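
A hedged sketch of launching the bot above; start_command() is the intended entry point since it sets the module-level args and discord_bot globals that the slash commands read (the token value is a placeholder, and the module path is assumed from this file):

import os

os.environ["DISCORD_BOT_TOKEN"] = "<your-bot-token>"  # placeholder token, read in DiscordBot.start()
from embedchain.bots.discord import start_command  # module path assumed

start_command()  # parses --include-question, builds a DiscordBot and runs the Discord client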

View File

@@ -0,0 +1,87 @@
import argparse
import logging
import os
from typing import Optional
from embedchain.helpers.json_serializable import register_deserializable
from .base import BaseBot
try:
from fastapi_poe import PoeBot, run
except ModuleNotFoundError:
raise ModuleNotFoundError(
"The required dependencies for Poe are not installed." "Please install with `pip install fastapi-poe==0.0.16`"
) from None
def start_command():
parser = argparse.ArgumentParser(description="EmbedChain PoeBot command line interface")
# parser.add_argument("--host", default="0.0.0.0", help="Host IP to bind")
parser.add_argument("--port", default=8080, type=int, help="Port to bind")
parser.add_argument("--api-key", type=str, help="Poe API key")
# parser.add_argument(
# "--history-length",
# default=5,
# type=int,
# help="Set the max size of the chat history. Multiplies cost, but improves conversation awareness.",
# )
args = parser.parse_args()
# FIXME: Arguments are automatically loaded by PoeBot's ArgumentParser, which causes it to fail.
# The port argument here is also just for show; it actually works because Poe has the same argument.
run(PoeBot(), api_key=args.api_key or os.environ.get("POE_API_KEY"))
@register_deserializable
class PoeBot(BaseBot, PoeBot):
def __init__(self):
self.history_length = 5
super().__init__()
async def get_response(self, query):
last_message = query.query[-1].content
try:
history = (
[f"{m.role}: {m.content}" for m in query.query[-(self.history_length + 1) : -1]]
if len(query.query) > 0
else None
)
except Exception as e:
logging.error(f"Error when processing the chat history. Message is being sent without history. Error: {e}")
answer = self.handle_message(last_message, history)
yield self.text_event(answer)
def handle_message(self, message, history: Optional[list[str]] = None):
if message.startswith("/add "):
response = self.add_data(message)
else:
response = self.ask_bot(message, history)
return response
# def add_data(self, message):
# data = message.split(" ")[-1]
# try:
# self.add(data)
# response = f"Added data from: {data}"
# except Exception:
# logging.exception(f"Failed to add data {data}.")
# response = "Some error occurred while adding data."
# return response
def ask_bot(self, message, history: list[str]):
try:
self.app.llm.set_history(history=history)
response = self.query(message)
except Exception:
logging.exception(f"Failed to query {message}.")
response = "An error occurred. Please try again!"
return response
def start(self):
start_command()
if __name__ == "__main__":
start_command()

View File

@@ -0,0 +1,101 @@
import argparse
import logging
import os
import signal
import sys
from embedchain import App
from embedchain.helpers.json_serializable import register_deserializable
from .base import BaseBot
try:
from flask import Flask, request
from slack_sdk import WebClient
except ModuleNotFoundError:
raise ModuleNotFoundError(
"The required dependencies for Slack are not installed."
"Please install with `pip install slack-sdk==3.21.3 flask==2.3.3`"
) from None
logger = logging.getLogger(__name__)
SLACK_BOT_TOKEN = os.environ.get("SLACK_BOT_TOKEN")
@register_deserializable
class SlackBot(BaseBot):
def __init__(self):
self.client = WebClient(token=SLACK_BOT_TOKEN)
self.chat_bot = App()
self.recent_message = {"ts": 0, "channel": ""}
super().__init__()
def handle_message(self, event_data):
message = event_data.get("event")
if message and "text" in message and message.get("subtype") != "bot_message":
text: str = message["text"]
if float(message.get("ts")) > float(self.recent_message["ts"]):
self.recent_message["ts"] = message["ts"]
self.recent_message["channel"] = message["channel"]
if text.startswith("query"):
_, question = text.split(" ", 1)
try:
response = self.chat_bot.chat(question)
self.send_slack_message(message["channel"], response)
logger.info("Query answered successfully!")
except Exception as e:
self.send_slack_message(message["channel"], "An error occurred. Please try again!")
logger.error("Error occurred during 'query' command:", e)
elif text.startswith("add"):
_, data_type, url_or_text = text.split(" ", 2)
if url_or_text.startswith("<") and url_or_text.endswith(">"):
url_or_text = url_or_text[1:-1]
try:
self.chat_bot.add(url_or_text, data_type)
self.send_slack_message(message["channel"], f"Added {data_type} : {url_or_text}")
except ValueError as e:
self.send_slack_message(message["channel"], f"Error: {str(e)}")
logger.error("Error occurred during 'add' command:", e)
except Exception as e:
self.send_slack_message(message["channel"], f"Failed to add {data_type} : {url_or_text}")
logger.error("Error occurred during 'add' command:", e)
def send_slack_message(self, channel, message):
response = self.client.chat_postMessage(channel=channel, text=message)
return response
def start(self, host="0.0.0.0", port=5000, debug=True):
app = Flask(__name__)
def signal_handler(sig, frame):
logger.info("\nGracefully shutting down the SlackBot...")
sys.exit(0)
signal.signal(signal.SIGINT, signal_handler)
@app.route("/", methods=["POST"])
def chat():
# Check if the request is a verification request
if request.json.get("challenge"):
return str(request.json.get("challenge"))
response = self.handle_message(request.json)
return str(response)
app.run(host=host, port=port, debug=debug)
def start_command():
parser = argparse.ArgumentParser(description="EmbedChain SlackBot command line interface")
parser.add_argument("--host", default="0.0.0.0", help="Host IP to bind")
parser.add_argument("--port", default=5000, type=int, help="Port to bind")
args = parser.parse_args()
slack_bot = SlackBot()
slack_bot.start(host=args.host, port=args.port)
if __name__ == "__main__":
start_command()
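
A hedged sketch of running the Slack bot above (the token and port are placeholders; Slack's Events API must be pointed at this server's "/" route). Note that SLACK_BOT_TOKEN is read at import time, so it must be set before the import:

import os

os.environ["SLACK_BOT_TOKEN"] = "xoxb-..."  # placeholder bot token, read when the module is imported
from embedchain.bots.slack import SlackBot  # module path assumed from this file

bot = SlackBot()
bot.start(host="0.0.0.0", port=5000)  # serves the Events API endpoint, answering "query ..." and "add ..." messages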

View File

@@ -0,0 +1,83 @@
import argparse
import importlib
import logging
import signal
import sys
from embedchain.helpers.json_serializable import register_deserializable
from .base import BaseBot
logger = logging.getLogger(__name__)
@register_deserializable
class WhatsAppBot(BaseBot):
def __init__(self):
try:
self.flask = importlib.import_module("flask")
self.twilio = importlib.import_module("twilio")
except ModuleNotFoundError:
raise ModuleNotFoundError(
"The required dependencies for WhatsApp are not installed. "
"Please install with `pip install twilio==8.5.0 flask==2.3.3`"
) from None
super().__init__()
def handle_message(self, message):
if message.startswith("add "):
response = self.add_data(message)
else:
response = self.ask_bot(message)
return response
def add_data(self, message):
data = message.split(" ")[-1]
try:
self.add(data)
response = f"Added data from: {data}"
except Exception:
logger.exception(f"Failed to add data {data}.")
response = "Some error occurred while adding data."
return response
def ask_bot(self, message):
try:
response = self.query(message)
except Exception:
logger.exception(f"Failed to query {message}.")
response = "An error occurred. Please try again!"
return response
def start(self, host="0.0.0.0", port=5000, debug=True):
app = self.flask.Flask(__name__)
def signal_handler(sig, frame):
logger.info("\nGracefully shutting down the WhatsAppBot...")
sys.exit(0)
signal.signal(signal.SIGINT, signal_handler)
@app.route("/chat", methods=["POST"])
def chat():
incoming_message = self.flask.request.values.get("Body", "").lower()
response = self.handle_message(incoming_message)
twilio_response = self.twilio.twiml.messaging_response.MessagingResponse()
twilio_response.message(response)
return str(twilio_response)
app.run(host=host, port=port, debug=debug)
def start_command():
parser = argparse.ArgumentParser(description="EmbedChain WhatsAppBot command line interface")
parser.add_argument("--host", default="0.0.0.0", help="Host IP to bind")
parser.add_argument("--port", default=5000, type=int, help="Port to bind")
args = parser.parse_args()
whatsapp_bot = WhatsAppBot()
whatsapp_bot.start(host=args.host, port=args.port)
if __name__ == "__main__":
start_command()
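
A hedged sketch of exercising the webhook above locally (the URL and message body are illustrative; in production Twilio posts the incoming message to the /chat route):

import requests

resp = requests.post(
    "http://localhost:5000/chat",
    data={"Body": "add https://docs.embedchain.ai"},  # "Body" is the form field Twilio sends with the message text
)
print(resp.text)  # TwiML response wrapping the bot's reply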

View File

@@ -0,0 +1,44 @@
import logging
import os # noqa: F401
from typing import Any
from gptcache import cache # noqa: F401
from gptcache.adapter.adapter import adapt # noqa: F401
from gptcache.config import Config # noqa: F401
from gptcache.manager import get_data_manager
from gptcache.manager.scalar_data.base import Answer
from gptcache.manager.scalar_data.base import DataType as CacheDataType
from gptcache.session import Session
from gptcache.similarity_evaluation.distance import \
SearchDistanceEvaluation # noqa: F401
from gptcache.similarity_evaluation.exact_match import \
ExactMatchEvaluation # noqa: F401
logger = logging.getLogger(__name__)
def gptcache_pre_function(data: dict[str, Any], **params: dict[str, Any]):
return data["input_query"]
def gptcache_data_manager(vector_dimension):
return get_data_manager(cache_base="sqlite", vector_base="chromadb", max_size=1000, eviction="LRU")
def gptcache_data_convert(cache_data):
logger.info("[Cache] Cache hit, returning cache data...")
return cache_data
def gptcache_update_cache_callback(llm_data, update_cache_func, *args, **kwargs):
logger.info("[Cache] Cache missed, updating cache...")
update_cache_func(Answer(llm_data, CacheDataType.STR))
return llm_data
def _gptcache_session_hit_func(cur_session_id: str, cache_session_ids: list, cache_questions: list, cache_answer: str):
return cur_session_id in cache_session_ids
def get_gptcache_session(session_id: str):
return Session(name=session_id, check_hit_func=_gptcache_session_hit_func)
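
A small sketch of how the session helpers above behave (the ids and strings are illustrative): the hit function only accepts cache entries created under the same session id, which keeps each app's cache isolated.

session = get_gptcache_session(session_id="app-42")  # gptcache Session bound to the hit function above
hit = _gptcache_session_hit_func(
    cur_session_id="app-42",
    cache_session_ids=["app-42", "app-7"],
    cache_questions=["What is mem0?"],
    cache_answer="An open-source memory layer.",
)
print(hit)  # True: the current session id is among the cached ones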

View File

@@ -0,0 +1,22 @@
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class AudioChunker(BaseChunker):
"""Chunker for audio."""
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.chunk_size,
chunk_overlap=config.chunk_overlap,
length_function=config.length_function,
)
super().__init__(text_splitter)
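
Every chunker below repeats this construction: a default ChunkerConfig fed into a RecursiveCharacterTextSplitter. A hedged sketch of overriding the defaults (the module path and sample text are assumptions for illustration):

from embedchain.chunkers.audio import AudioChunker  # module path assumed from this file
from embedchain.config.add_config import ChunkerConfig

config = ChunkerConfig(chunk_size=500, chunk_overlap=50, length_function=len)
chunker = AudioChunker(config=config)
pieces = chunker.get_chunks("transcribed audio text " * 100)  # split into chunks of at most 500 characters
print(len(pieces), len(pieces[0]))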

View File

@@ -0,0 +1,87 @@
import hashlib
import logging
from typing import Optional
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.models.data_type import DataType
logger = logging.getLogger(__name__)
class BaseChunker(JSONSerializable):
def __init__(self, text_splitter):
"""Initialize the chunker."""
self.text_splitter = text_splitter
self.data_type = None
def create_chunks(self, loader, src, app_id=None, config: Optional[ChunkerConfig] = None):
"""
Loads data and chunks it.
:param loader: The loader whose `load_data` method is used to create
the raw data.
:param src: The data to be handled by the loader. Can be a URL for
remote sources or local content for local loaders.
:param app_id: App id used to generate the doc_id.
"""
documents = []
chunk_ids = []
id_map = {}
min_chunk_size = config.min_chunk_size if config is not None else 1
logger.info(f"Skipping chunks smaller than {min_chunk_size} characters")
data_result = loader.load_data(src)
data_records = data_result["data"]
doc_id = data_result["doc_id"]
# Prefix app_id in the document id if app_id is not None to
# distinguish between different documents stored in the same
# elasticsearch or opensearch index
doc_id = f"{app_id}--{doc_id}" if app_id is not None else doc_id
metadatas = []
for data in data_records:
content = data["content"]
metadata = data["meta_data"]
# add data type to metadata to allow querying by data type
metadata["data_type"] = self.data_type.value
metadata["doc_id"] = doc_id
# TODO: Currently defaulting to the src as the url. This is done intentionally since some
# loaders, such as the 'gmail' loader, don't have the url in the metadata.
url = metadata.get("url", src)
chunks = self.get_chunks(content)
for chunk in chunks:
chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest()
chunk_id = f"{app_id}--{chunk_id}" if app_id is not None else chunk_id
if id_map.get(chunk_id) is None and len(chunk) >= min_chunk_size:
id_map[chunk_id] = True
chunk_ids.append(chunk_id)
documents.append(chunk)
metadatas.append(metadata)
return {
"documents": documents,
"ids": chunk_ids,
"metadatas": metadatas,
"doc_id": doc_id,
}
def get_chunks(self, content):
"""
Returns chunks using text splitter instance.
Override in child class if custom logic.
"""
return self.text_splitter.split_text(content)
def set_data_type(self, data_type: DataType):
"""
Set the data type of the chunker.
"""
self.data_type = data_type
# TODO: This should be done during initialization. This means it has to be done in the child classes.
@staticmethod
def get_word_count(documents) -> int:
return sum(len(document.split(" ")) for document in documents)
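
To make the loader contract used by create_chunks() explicit, a hedged sketch with a stub loader (the data, doc_id and module paths are illustrative; real loaders live elsewhere in the package):

from embedchain.chunkers.common_chunker import CommonChunker  # module path assumed
from embedchain.models.data_type import DataType

class StubLoader:
    def load_data(self, src):
        # loaders must return a dict with "doc_id" and a list of {"content", "meta_data"} records
        return {
            "doc_id": "doc-1",
            "data": [{"content": "mem0 adds long-term memory to LLM apps.", "meta_data": {"url": src}}],
        }

chunker = CommonChunker()
chunker.set_data_type(DataType.TEXT)  # required before chunking; the type is stored in each chunk's metadata
result = chunker.create_chunks(StubLoader(), "https://example.com", app_id="demo")
print(result["doc_id"], len(result["documents"]))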

View File

@@ -0,0 +1,22 @@
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class BeehiivChunker(BaseChunker):
"""Chunker for Beehiiv."""
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.chunk_size,
chunk_overlap=config.chunk_overlap,
length_function=config.length_function,
)
super().__init__(text_splitter)

View File

@@ -0,0 +1,22 @@
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class CommonChunker(BaseChunker):
"""Common chunker for all loaders."""
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = ChunkerConfig(chunk_size=2000, chunk_overlap=0, length_function=len)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.chunk_size,
chunk_overlap=config.chunk_overlap,
length_function=config.length_function,
)
super().__init__(text_splitter)

View File

@@ -0,0 +1,22 @@
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class DiscourseChunker(BaseChunker):
"""Chunker for discourse."""
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.chunk_size,
chunk_overlap=config.chunk_overlap,
length_function=config.length_function,
)
super().__init__(text_splitter)

View File

@@ -0,0 +1,22 @@
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class DocsSiteChunker(BaseChunker):
"""Chunker for code docs site."""
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = ChunkerConfig(chunk_size=500, chunk_overlap=50, length_function=len)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.chunk_size,
chunk_overlap=config.chunk_overlap,
length_function=config.length_function,
)
super().__init__(text_splitter)

View File

@@ -0,0 +1,22 @@
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class DocxFileChunker(BaseChunker):
"""Chunker for .docx file."""
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.chunk_size,
chunk_overlap=config.chunk_overlap,
length_function=config.length_function,
)
super().__init__(text_splitter)

View File

@@ -0,0 +1,22 @@
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class ExcelFileChunker(BaseChunker):
"""Chunker for Excel file."""
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.chunk_size,
chunk_overlap=config.chunk_overlap,
length_function=config.length_function,
)
super().__init__(text_splitter)

View File

@@ -0,0 +1,22 @@
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class GmailChunker(BaseChunker):
"""Chunker for gmail."""
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.chunk_size,
chunk_overlap=config.chunk_overlap,
length_function=config.length_function,
)
super().__init__(text_splitter)

View File

@@ -0,0 +1,22 @@
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class GoogleDriveChunker(BaseChunker):
"""Chunker for google drive folder."""
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.chunk_size,
chunk_overlap=config.chunk_overlap,
length_function=config.length_function,
)
super().__init__(text_splitter)

View File

@@ -0,0 +1,22 @@
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class ImageChunker(BaseChunker):
"""Chunker for Images."""
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = ChunkerConfig(chunk_size=2000, chunk_overlap=0, length_function=len)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.chunk_size,
chunk_overlap=config.chunk_overlap,
length_function=config.length_function,
)
super().__init__(text_splitter)

View File

@@ -0,0 +1,22 @@
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class JSONChunker(BaseChunker):
"""Chunker for json."""
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.chunk_size,
chunk_overlap=config.chunk_overlap,
length_function=config.length_function,
)
super().__init__(text_splitter)

View File

@@ -0,0 +1,22 @@
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class MdxChunker(BaseChunker):
"""Chunker for mdx files."""
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.chunk_size,
chunk_overlap=config.chunk_overlap,
length_function=config.length_function,
)
super().__init__(text_splitter)

View File

@@ -0,0 +1,22 @@
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class MySQLChunker(BaseChunker):
"""Chunker for json."""
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.chunk_size,
chunk_overlap=config.chunk_overlap,
length_function=config.length_function,
)
super().__init__(text_splitter)

View File

@@ -0,0 +1,22 @@
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class NotionChunker(BaseChunker):
"""Chunker for notion."""
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = ChunkerConfig(chunk_size=300, chunk_overlap=0, length_function=len)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.chunk_size,
chunk_overlap=config.chunk_overlap,
length_function=config.length_function,
)
super().__init__(text_splitter)

View File

@@ -0,0 +1,18 @@
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
class OpenAPIChunker(BaseChunker):
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.chunk_size,
chunk_overlap=config.chunk_overlap,
length_function=config.length_function,
)
super().__init__(text_splitter)

View File

@@ -0,0 +1,22 @@
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class PdfFileChunker(BaseChunker):
"""Chunker for PDF file."""
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.chunk_size,
chunk_overlap=config.chunk_overlap,
length_function=config.length_function,
)
super().__init__(text_splitter)

View File

@@ -0,0 +1,22 @@
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class PostgresChunker(BaseChunker):
"""Chunker for postgres."""
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.chunk_size,
chunk_overlap=config.chunk_overlap,
length_function=config.length_function,
)
super().__init__(text_splitter)

View File

@@ -0,0 +1,22 @@
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class QnaPairChunker(BaseChunker):
"""Chunker for QnA pair."""
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = ChunkerConfig(chunk_size=300, chunk_overlap=0, length_function=len)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.chunk_size,
chunk_overlap=config.chunk_overlap,
length_function=config.length_function,
)
super().__init__(text_splitter)

View File

@@ -0,0 +1,22 @@
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class RSSFeedChunker(BaseChunker):
"""Chunker for RSS Feed."""
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = ChunkerConfig(chunk_size=2000, chunk_overlap=0, length_function=len)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.chunk_size,
chunk_overlap=config.chunk_overlap,
length_function=config.length_function,
)
super().__init__(text_splitter)

View File

@@ -0,0 +1,22 @@
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class SitemapChunker(BaseChunker):
"""Chunker for sitemap."""
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = ChunkerConfig(chunk_size=500, chunk_overlap=0, length_function=len)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.chunk_size,
chunk_overlap=config.chunk_overlap,
length_function=config.length_function,
)
super().__init__(text_splitter)

View File

@@ -0,0 +1,22 @@
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class SlackChunker(BaseChunker):
"""Chunker for postgres."""
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.chunk_size,
chunk_overlap=config.chunk_overlap,
length_function=config.length_function,
)
super().__init__(text_splitter)

View File

@@ -0,0 +1,22 @@
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class SubstackChunker(BaseChunker):
"""Chunker for Substack."""
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.chunk_size,
chunk_overlap=config.chunk_overlap,
length_function=config.length_function,
)
super().__init__(text_splitter)

View File

@@ -0,0 +1,20 @@
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
class TableChunker(BaseChunker):
"""Chunker for tables, for instance csv, google sheets or databases."""
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = ChunkerConfig(chunk_size=300, chunk_overlap=0, length_function=len)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.chunk_size,
chunk_overlap=config.chunk_overlap,
length_function=config.length_function,
)
super().__init__(text_splitter)

View File

@@ -0,0 +1,22 @@
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class TextChunker(BaseChunker):
"""Chunker for text."""
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = ChunkerConfig(chunk_size=300, chunk_overlap=0, length_function=len)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.chunk_size,
chunk_overlap=config.chunk_overlap,
length_function=config.length_function,
)
super().__init__(text_splitter)

View File

@@ -0,0 +1,22 @@
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class UnstructuredFileChunker(BaseChunker):
"""Chunker for Unstructured file."""
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.chunk_size,
chunk_overlap=config.chunk_overlap,
length_function=config.length_function,
)
super().__init__(text_splitter)

View File

@@ -0,0 +1,22 @@
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class WebPageChunker(BaseChunker):
"""Chunker for web page."""
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = ChunkerConfig(chunk_size=2000, chunk_overlap=0, length_function=len)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.chunk_size,
chunk_overlap=config.chunk_overlap,
length_function=config.length_function,
)
super().__init__(text_splitter)

View File

@@ -0,0 +1,22 @@
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class XmlChunker(BaseChunker):
"""Chunker for XML files."""
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = ChunkerConfig(chunk_size=500, chunk_overlap=50, length_function=len)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.chunk_size,
chunk_overlap=config.chunk_overlap,
length_function=config.length_function,
)
super().__init__(text_splitter)

View File

@@ -0,0 +1,22 @@
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class YoutubeVideoChunker(BaseChunker):
"""Chunker for Youtube video."""
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = ChunkerConfig(chunk_size=2000, chunk_overlap=0, length_function=len)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.chunk_size,
chunk_overlap=config.chunk_overlap,
length_function=config.length_function,
)
super().__init__(text_splitter)

View File

@@ -0,0 +1,327 @@
import json
import os
import shutil
import signal
import subprocess
import sys
import tempfile
import time
import zipfile
from pathlib import Path
import click
import requests
from rich.console import Console
from embedchain.telemetry.posthog import AnonymousTelemetry
from embedchain.utils.cli import (deploy_fly, deploy_gradio_app,
deploy_hf_spaces, deploy_modal,
deploy_render, deploy_streamlit,
get_pkg_path_from_name, setup_fly_io_app,
setup_gradio_app, setup_hf_app,
setup_modal_com_app, setup_render_com_app,
setup_streamlit_io_app)
console = Console()
api_process = None
ui_process = None
anonymous_telemetry = AnonymousTelemetry()
def signal_handler(sig, frame):
"""Signal handler to catch termination signals and kill server processes."""
global api_process, ui_process
console.print("\n🛑 [bold yellow]Stopping servers...[/bold yellow]")
if api_process:
api_process.terminate()
console.print("🛑 [bold yellow]API server stopped.[/bold yellow]")
if ui_process:
ui_process.terminate()
console.print("🛑 [bold yellow]UI server stopped.[/bold yellow]")
sys.exit(0)
@click.group()
def cli():
pass
@cli.command()
@click.argument("app_name")
@click.option("--docker", is_flag=True, help="Use docker to create the app.")
@click.pass_context
def create_app(ctx, app_name, docker):
if Path(app_name).exists():
console.print(
f"❌ [red]Directory '{app_name}' already exists. Try using a new directory name, or remove it.[/red]"
)
return
os.makedirs(app_name)
os.chdir(app_name)
# Step 1: Download the zip file
zip_url = "http://github.com/embedchain/ec-admin/archive/main.zip"
console.print(f"Creating a new embedchain app in [green]{Path().resolve()}[/green]\n")
try:
response = requests.get(zip_url)
response.raise_for_status()
with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
tmp_file.write(response.content)
zip_file_path = tmp_file.name
console.print("✅ [bold green]Fetched template successfully.[/bold green]")
except requests.RequestException as e:
console.print(f"❌ [bold red]Failed to download zip file: {e}[/bold red]")
anonymous_telemetry.capture(event_name="ec_create_app", properties={"success": False})
return
# Step 2: Extract the zip file
try:
with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
# Get the name of the root directory inside the zip file
root_dir = Path(zip_ref.namelist()[0])
for member in zip_ref.infolist():
# Build the path to extract the file to, skipping the root directory
target_file = Path(member.filename).relative_to(root_dir)
source_file = zip_ref.open(member, "r")
if member.is_dir():
# Create directory if it doesn't exist
os.makedirs(target_file, exist_ok=True)
else:
with open(target_file, "wb") as file:
# Write the file
shutil.copyfileobj(source_file, file)
console.print("✅ [bold green]Extracted zip file successfully.[/bold green]")
anonymous_telemetry.capture(event_name="ec_create_app", properties={"success": True})
except zipfile.BadZipFile:
console.print("❌ [bold red]Error in extracting zip file. The file might be corrupted.[/bold red]")
anonymous_telemetry.capture(event_name="ec_create_app", properties={"success": False})
return
if docker:
subprocess.run(["docker-compose", "build"], check=True)
else:
ctx.invoke(install_reqs)
@cli.command()
def install_reqs():
try:
console.print("Installing python requirements...\n")
time.sleep(2)
os.chdir("api")
subprocess.run(["pip", "install", "-r", "requirements.txt"], check=True)
os.chdir("..")
console.print("\n ✅ [bold green]Installed API requirements successfully.[/bold green]\n")
except Exception as e:
console.print(f"❌ [bold red]Failed to install API requirements: {e}[/bold red]")
anonymous_telemetry.capture(event_name="ec_install_reqs", properties={"success": False})
return
try:
os.chdir("ui")
subprocess.run(["yarn"], check=True)
console.print("\n✅ [bold green]Successfully installed frontend requirements.[/bold green]")
anonymous_telemetry.capture(event_name="ec_install_reqs", properties={"success": True})
except Exception as e:
console.print(f"❌ [bold red]Failed to install frontend requirements. Error: {e}[/bold red]")
anonymous_telemetry.capture(event_name="ec_install_reqs", properties={"success": False})
@cli.command()
@click.option("--docker", is_flag=True, help="Run inside docker.")
def start(docker):
global api_process, ui_process
if docker:
subprocess.run(["docker-compose", "up"], check=True)
return
# Set up signal handling
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
# Step 1: Start the API server
try:
os.chdir("api")
api_process = subprocess.Popen(["python", "-m", "main"], stdout=None, stderr=None)
os.chdir("..")
console.print("✅ [bold green]API server started successfully.[/bold green]")
except Exception as e:
console.print(f"❌ [bold red]Failed to start the API server: {e}[/bold red]")
anonymous_telemetry.capture(event_name="ec_start", properties={"success": False})
return
# Sleep for 2 seconds to give the user time to read the message
time.sleep(2)
# Step 2: Install UI requirements and start the UI server
try:
os.chdir("ui")
subprocess.run(["yarn"], check=True)
ui_process = subprocess.Popen(["yarn", "dev"])
console.print("✅ [bold green]UI server started successfully.[/bold green]")
anonymous_telemetry.capture(event_name="ec_start", properties={"success": True})
except Exception as e:
console.print(f"❌ [bold red]Failed to start the UI server: {e}[/bold red]")
anonymous_telemetry.capture(event_name="ec_start", properties={"success": False})
# Keep the script running until it receives a kill signal
try:
api_process.wait()
ui_process.wait()
except KeyboardInterrupt:
console.print("\n🛑 [bold yellow]Stopping server...[/bold yellow]")
@cli.command()
@click.option("--template", default="fly.io", help="The template to use.")
@click.argument("extra_args", nargs=-1, type=click.UNPROCESSED)
def create(template, extra_args):
anonymous_telemetry.capture(event_name="ec_create", properties={"template_used": template})
template_dir = template
if "/" in template_dir:
template_dir = template.split("/")[1]
src_path = get_pkg_path_from_name(template_dir)
shutil.copytree(src_path, os.getcwd(), dirs_exist_ok=True)
console.print(f"✅ [bold green]Successfully created app from template '{template}'.[/bold green]")
if template == "fly.io":
setup_fly_io_app(extra_args)
elif template == "modal.com":
setup_modal_com_app(extra_args)
elif template == "render.com":
setup_render_com_app()
elif template == "streamlit.io":
setup_streamlit_io_app()
elif template == "gradio.app":
setup_gradio_app()
elif template == "hf/gradio.app" or template == "hf/streamlit.io":
setup_hf_app()
else:
raise ValueError(f"Unknown template '{template}'.")
embedchain_config = {"provider": template}
with open("embedchain.json", "w") as file:
json.dump(embedchain_config, file, indent=4)
console.print(
f"🎉 [green]All done! Successfully created `embedchain.json` with '{template}' as provider.[/green]"
)
def run_dev_fly_io(debug, host, port):
uvicorn_command = ["uvicorn", "app:app"]
if debug:
uvicorn_command.append("--reload")
uvicorn_command.extend(["--host", host, "--port", str(port)])
try:
console.print(f"🚀 [bold cyan]Running FastAPI app with command: {' '.join(uvicorn_command)}[/bold cyan]")
subprocess.run(uvicorn_command, check=True)
except subprocess.CalledProcessError as e:
console.print(f"❌ [bold red]An error occurred: {e}[/bold red]")
except KeyboardInterrupt:
console.print("\n🛑 [bold yellow]FastAPI server stopped[/bold yellow]")
def run_dev_modal_com():
modal_run_cmd = ["modal", "serve", "app"]
try:
console.print(f"🚀 [bold cyan]Running FastAPI app with command: {' '.join(modal_run_cmd)}[/bold cyan]")
subprocess.run(modal_run_cmd, check=True)
except subprocess.CalledProcessError as e:
console.print(f"❌ [bold red]An error occurred: {e}[/bold red]")
except KeyboardInterrupt:
console.print("\n🛑 [bold yellow]FastAPI server stopped[/bold yellow]")
def run_dev_streamlit_io():
streamlit_run_cmd = ["streamlit", "run", "app.py"]
try:
console.print(f"🚀 [bold cyan]Running Streamlit app with command: {' '.join(streamlit_run_cmd)}[/bold cyan]")
subprocess.run(streamlit_run_cmd, check=True)
except subprocess.CalledProcessError as e:
console.print(f"❌ [bold red]An error occurred: {e}[/bold red]")
except KeyboardInterrupt:
console.print("\n🛑 [bold yellow]Streamlit server stopped[/bold yellow]")
def run_dev_render_com(debug, host, port):
uvicorn_command = ["uvicorn", "app:app"]
if debug:
uvicorn_command.append("--reload")
uvicorn_command.extend(["--host", host, "--port", str(port)])
try:
console.print(f"🚀 [bold cyan]Running FastAPI app with command: {' '.join(uvicorn_command)}[/bold cyan]")
subprocess.run(uvicorn_command, check=True)
except subprocess.CalledProcessError as e:
console.print(f"❌ [bold red]An error occurred: {e}[/bold red]")
except KeyboardInterrupt:
console.print("\n🛑 [bold yellow]FastAPI server stopped[/bold yellow]")
def run_dev_gradio():
gradio_run_cmd = ["gradio", "app.py"]
try:
console.print(f"🚀 [bold cyan]Running Gradio app with command: {' '.join(gradio_run_cmd)}[/bold cyan]")
subprocess.run(gradio_run_cmd, check=True)
except subprocess.CalledProcessError as e:
console.print(f"❌ [bold red]An error occurred: {e}[/bold red]")
except KeyboardInterrupt:
console.print("\n🛑 [bold yellow]Gradio server stopped[/bold yellow]")
@cli.command()
@click.option("--debug", is_flag=True, help="Enable or disable debug mode.")
@click.option("--host", default="127.0.0.1", help="The host address to run the FastAPI app on.")
@click.option("--port", default=8000, help="The port to run the FastAPI app on.")
def dev(debug, host, port):
template = ""
with open("embedchain.json", "r") as file:
embedchain_config = json.load(file)
template = embedchain_config["provider"]
anonymous_telemetry.capture(event_name="ec_dev", properties={"template_used": template})
if template == "fly.io":
run_dev_fly_io(debug, host, port)
elif template == "modal.com":
run_dev_modal_com()
elif template == "render.com":
run_dev_render_com(debug, host, port)
elif template == "streamlit.io" or template == "hf/streamlit.io":
run_dev_streamlit_io()
elif template == "gradio.app" or template == "hf/gradio.app":
run_dev_gradio()
else:
raise ValueError(f"Unknown template '{template}'.")
@cli.command()
def deploy():
# Check for platform-specific files
template = ""
ec_app_name = ""
with open("embedchain.json", "r") as file:
embedchain_config = json.load(file)
ec_app_name = embedchain_config["name"] if "name" in embedchain_config else None
template = embedchain_config["provider"]
anonymous_telemetry.capture(event_name="ec_deploy", properties={"template_used": template})
if template == "fly.io":
deploy_fly()
elif template == "modal.com":
deploy_modal()
elif template == "render.com":
deploy_render()
elif template == "streamlit.io":
deploy_streamlit()
elif template == "gradio.app":
deploy_gradio_app()
elif template.startswith("hf/"):
deploy_hf_spaces(ec_app_name)
else:
console.print("❌ [bold red]No recognized deployment platform found.[/bold red]")

View File

@@ -0,0 +1,103 @@
import json
import logging
import os
import uuid
import requests
from embedchain.constants import CONFIG_DIR, CONFIG_FILE
logger = logging.getLogger(__name__)
class Client:
def __init__(self, api_key=None, host="https://apiv2.embedchain.ai"):
self.config_data = self.load_config()
self.host = host
if api_key:
if self.check(api_key):
self.api_key = api_key
self.save()
else:
raise ValueError(
"Invalid API key provided. You can find your API key on https://app.embedchain.ai/settings/keys."
)
else:
if "api_key" in self.config_data:
self.api_key = self.config_data["api_key"]
logger.info("API key loaded successfully!")
else:
raise ValueError(
"You are not logged in. Please obtain an API key from https://app.embedchain.ai/settings/keys/"
)
@classmethod
def setup(cls):
"""
Loads the user id from the config file if it exists, otherwise generates a new
one and saves it to the config file.
:return: user id
:rtype: str
"""
os.makedirs(CONFIG_DIR, exist_ok=True)
if os.path.exists(CONFIG_FILE):
with open(CONFIG_FILE, "r") as f:
data = json.load(f)
if "user_id" in data:
return data["user_id"]
u_id = str(uuid.uuid4())
with open(CONFIG_FILE, "w") as f:
json.dump({"user_id": u_id}, f)
@classmethod
def load_config(cls):
if not os.path.exists(CONFIG_FILE):
cls.setup()
with open(CONFIG_FILE, "r") as config_file:
return json.load(config_file)
def save(self):
self.config_data["api_key"] = self.api_key
with open(CONFIG_FILE, "w") as config_file:
json.dump(self.config_data, config_file, indent=4)
logger.info("API key saved successfully!")
def clear(self):
if "api_key" in self.config_data:
del self.config_data["api_key"]
with open(CONFIG_FILE, "w") as config_file:
json.dump(self.config_data, config_file, indent=4)
self.api_key = None
logger.info("API key deleted successfully!")
else:
logger.warning("API key not found in the configuration file.")
def update(self, api_key):
if self.check(api_key):
self.api_key = api_key
self.save()
logger.info("API key updated successfully!")
else:
logger.warning("Invalid API key provided. API key not updated.")
def check(self, api_key):
validation_url = f"{self.host}/api/v1/accounts/api_keys/validate/"
response = requests.post(validation_url, headers={"Authorization": f"Token {api_key}"})
if response.status_code == 200:
return True
else:
logger.warning(f"Response from API: {response.text}")
logger.warning("Invalid API key. Unable to validate.")
return False
def get(self):
return self.api_key
def __str__(self):
return self.api_key
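
A minimal sketch of the intended flow; the key below is a placeholder, and check() makes a network call to the hosted API, so this is not runnable offline:

from embedchain.client import Client

Client.setup()                      # ensures ~/.embedchain/config.json exists with a user_id
client = Client(api_key="ec-xxxx")  # placeholder key; validated against the hosted API, then persisted
print(client.get())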

View File

@@ -0,0 +1,15 @@
# flake8: noqa: F401
from .add_config import AddConfig, ChunkerConfig
from .app_config import AppConfig
from .base_config import BaseConfig
from .cache_config import CacheConfig
from .embedder.base import BaseEmbedderConfig
from .embedder.base import BaseEmbedderConfig as EmbedderConfig
from .embedder.ollama import OllamaEmbedderConfig
from .llm.base import BaseLlmConfig
from .vector_db.chroma import ChromaDbConfig
from .vector_db.elasticsearch import ElasticsearchDBConfig
from .vector_db.opensearch import OpenSearchDBConfig
from .vector_db.zilliz import ZillizDBConfig
from .mem0_config import Mem0Config

View File

@@ -0,0 +1,79 @@
import builtins
import logging
from collections.abc import Callable
from importlib import import_module
from typing import Optional
from embedchain.config.base_config import BaseConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class ChunkerConfig(BaseConfig):
"""
Config for the chunker used in `add` method
"""
def __init__(
self,
chunk_size: Optional[int] = 2000,
chunk_overlap: Optional[int] = 0,
length_function: Optional[Callable[[str], int]] = None,
min_chunk_size: Optional[int] = 0,
):
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
self.min_chunk_size = min_chunk_size
if self.min_chunk_size >= self.chunk_size:
raise ValueError(f"min_chunk_size {min_chunk_size} should be less than chunk_size {chunk_size}")
if self.min_chunk_size < self.chunk_overlap:
logging.warning(
f"min_chunk_size {min_chunk_size} should be greater than chunk_overlap {chunk_overlap}, otherwise it is redundant." # noqa:E501
)
if isinstance(length_function, str):
self.length_function = self.load_func(length_function)
else:
self.length_function = length_function if length_function else len
@staticmethod
def load_func(dotpath: str):
if "." not in dotpath:
return getattr(builtins, dotpath)
else:
module_, func = dotpath.rsplit(".", maxsplit=1)
m = import_module(module_)
return getattr(m, func)
@register_deserializable
class LoaderConfig(BaseConfig):
"""
Config for the loader used in `add` method
"""
def __init__(self):
pass
@register_deserializable
class AddConfig(BaseConfig):
"""
Config for the `add` method.
"""
def __init__(
self,
chunker: Optional[ChunkerConfig] = None,
loader: Optional[LoaderConfig] = None,
):
"""
Initializes a configuration class instance for the `add` method.
:param chunker: Chunker config, defaults to None
:type chunker: Optional[ChunkerConfig], optional
:param loader: Loader config, defaults to None
:type loader: Optional[LoaderConfig], optional
"""
self.loader = loader
self.chunker = chunker
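
A small sketch of the dotted-path form of length_function, which ChunkerConfig.load_func resolves at construction time:

from embedchain.config.add_config import AddConfig, ChunkerConfig

# "len" contains no dot, so load_func resolves it against builtins.
chunker_config = ChunkerConfig(chunk_size=800, chunk_overlap=0, length_function="len")
add_config = AddConfig(chunker=chunker_config)
print(add_config.chunker.length_function("hello"))  # -> 5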

View File

@@ -0,0 +1,34 @@
from typing import Optional
from embedchain.helpers.json_serializable import register_deserializable
from .base_app_config import BaseAppConfig
@register_deserializable
class AppConfig(BaseAppConfig):
"""
Config to initialize an embedchain custom `App` instance, with extra config options.
"""
def __init__(
self,
log_level: str = "WARNING",
id: Optional[str] = None,
name: Optional[str] = None,
collect_metrics: Optional[bool] = True,
**kwargs,
):
"""
Initializes a configuration class instance for an App. This is the simplest form of an embedchain app.
Most of the configuration is done in the `App` class itself.
:param log_level: Log level ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], defaults to "WARNING"
:type log_level: str, optional
:param id: ID of the app. Document metadata will have this id, defaults to None
:type id: Optional[str], optional
:param name: Name of the app, defaults to None
:type name: Optional[str], optional
:param collect_metrics: Send anonymous telemetry to improve embedchain, defaults to True
:type collect_metrics: Optional[bool], optional
"""
self.name = name
super().__init__(log_level=log_level, id=id, collect_metrics=collect_metrics, **kwargs)

View File

@@ -0,0 +1,58 @@
import logging
from typing import Optional
from embedchain.config.base_config import BaseConfig
from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.vectordb.base import BaseVectorDB
logger = logging.getLogger(__name__)
class BaseAppConfig(BaseConfig, JSONSerializable):
"""
Parent config to initialize an instance of `App`.
"""
def __init__(
self,
log_level: str = "WARNING",
db: Optional[BaseVectorDB] = None,
id: Optional[str] = None,
collect_metrics: bool = True,
collection_name: Optional[str] = None,
):
"""
Initializes a configuration class instance for an App.
Most of the configuration is done in the `App` class itself.
:param log_level: Log level ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], defaults to "WARNING"
:type log_level: str, optional
:param db: A database class. It is recommended to set this directly in the `App` class, not this config,
defaults to None
:type db: Optional[BaseVectorDB], optional
:param id: ID of the app. Document metadata will have this id, defaults to None
:type id: Optional[str], optional
:param collect_metrics: Send anonymous telemetry to improve embedchain, defaults to True
:type collect_metrics: Optional[bool], optional
:param collection_name: Default collection name. It's recommended to use app.db.set_collection_name() instead,
defaults to None
:type collection_name: Optional[str], optional
"""
self.id = id
self.collect_metrics = True if (collect_metrics is True or collect_metrics is None) else False
self.collection_name = collection_name
if db:
self._db = db
logger.warning(
"DEPRECATION WARNING: Please supply the database as the second parameter during app init. "
"Such as `app(config=config, db=db)`."
)
if collection_name:
logger.warning("DEPRECATION WARNING: Please supply the collection name to the database config.")
return
def _setup_logging(self, log_level):
logging.basicConfig(format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s", level=log_level)
self.logger = logging.getLogger(__name__)

View File

@@ -0,0 +1,21 @@
from typing import Any
from embedchain.helpers.json_serializable import JSONSerializable
class BaseConfig(JSONSerializable):
"""
Base config.
"""
def __init__(self):
"""Initializes a configuration class for a class."""
pass
def as_dict(self) -> dict[str, Any]:
"""Return config object as a dict
:return: config object as dict
:rtype: dict[str, Any]
"""
return vars(self)
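
For illustration, as_dict() simply returns vars(self), so whatever attributes a subclass sets in its constructor come back unchanged; ExampleConfig below is a hypothetical subclass, not part of the library:

from embedchain.config.base_config import BaseConfig


class ExampleConfig(BaseConfig):
    def __init__(self, retries: int = 3):
        self.retries = retries


print(ExampleConfig(retries=5).as_dict())  # -> {'retries': 5}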

View File

@@ -0,0 +1,96 @@
from typing import Any, Optional
from embedchain.config.base_config import BaseConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class CacheSimilarityEvalConfig(BaseConfig):
"""
Evaluator that compares two embeddings using the distance computed during the embedding retrieval stage.
In the retrieval stage, `search_result` is the distance used for the approximate nearest neighbor search and has
already been put into `cache_dict`. `max_distance` bounds this distance to the range [0, `max_distance`].
`positive` indicates whether the distance is directly proportional to the similarity of the two entities.
If `positive` is set to `False`, the final score is computed as `max_distance` minus the distance.
:param max_distance: the upper bound of the distance.
:type max_distance: float
:param positive: True if a larger distance means the two entities are more similar; otherwise False.
:type positive: bool
"""
def __init__(
self,
strategy: Optional[str] = "distance",
max_distance: Optional[float] = 1.0,
positive: Optional[bool] = False,
):
self.strategy = strategy
self.max_distance = max_distance
self.positive = positive
@staticmethod
def from_config(config: Optional[dict[str, Any]]):
if config is None:
return CacheSimilarityEvalConfig()
else:
return CacheSimilarityEvalConfig(
strategy=config.get("strategy", "distance"),
max_distance=config.get("max_distance", 1.0),
positive=config.get("positive", False),
)
@register_deserializable
class CacheInitConfig(BaseConfig):
"""
This is a cache init config. Used to initialize a cache.
:param similarity_threshold: a threshold ranging from 0 to 1 used to filter search results whose similarity score is \
higher than the threshold. When it is 0, there are no hits; when it is 1, all search results are returned as hits.
:type similarity_threshold: float
:param auto_flush: the cache is flushed automatically every time this many pieces of data are added, defaults to 20
:type auto_flush: int
"""
def __init__(
self,
similarity_threshold: Optional[float] = 0.8,
auto_flush: Optional[int] = 20,
):
if similarity_threshold < 0 or similarity_threshold > 1:
raise ValueError(f"similarity_threshold {similarity_threshold} should be between 0 and 1")
self.similarity_threshold = similarity_threshold
self.auto_flush = auto_flush
@staticmethod
def from_config(config: Optional[dict[str, Any]]):
if config is None:
return CacheInitConfig()
else:
return CacheInitConfig(
similarity_threshold=config.get("similarity_threshold", 0.8),
auto_flush=config.get("auto_flush", 20),
)
@register_deserializable
class CacheConfig(BaseConfig):
def __init__(
self,
similarity_eval_config: Optional[CacheSimilarityEvalConfig] = CacheSimilarityEvalConfig(),
init_config: Optional[CacheInitConfig] = CacheInitConfig(),
):
self.similarity_eval_config = similarity_eval_config
self.init_config = init_config
@staticmethod
def from_config(config: Optional[dict[str, Any]]):
if config is None:
return CacheConfig()
else:
return CacheConfig(
similarity_eval_config=CacheSimilarityEvalConfig.from_config(config.get("similarity_evaluation", {})),
init_config=CacheInitConfig.from_config(config.get("init_config", {})),
)
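
A sketch of how a nested dict (for example, one parsed from a YAML config) is turned into these objects through the from_config chain; the keys mirror the config.get(...) calls above:

from embedchain.config.cache_config import CacheConfig

raw = {
    "similarity_evaluation": {"strategy": "distance", "max_distance": 1.0, "positive": False},
    "init_config": {"similarity_threshold": 0.9, "auto_flush": 50},
}
cache_config = CacheConfig.from_config(raw)
print(cache_config.init_config.similarity_threshold)  # -> 0.9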

View File

@@ -0,0 +1,42 @@
from typing import Any, Dict, Optional
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class BaseEmbedderConfig:
def __init__(
self,
model: Optional[str] = None,
deployment_name: Optional[str] = None,
vector_dimension: Optional[int] = None,
endpoint: Optional[str] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Initialize a new instance of an embedder config class.
:param model: model name of the llm embedding model (not applicable to all providers), defaults to None
:type model: Optional[str], optional
:param deployment_name: deployment name for llm embedding model, defaults to None
:type deployment_name: Optional[str], optional
:param vector_dimension: vector dimension of the embedding model, defaults to None
:type vector_dimension: Optional[int], optional
:param endpoint: endpoint for the embedding model, defaults to None
:type endpoint: Optional[str], optional
:param api_key: Hugging Face API key, defaults to None
:type api_key: Optional[str], optional
:param api_base: Hugging Face API base URL, defaults to None
:type api_base: Optional[str], optional
:param model_kwargs: key-value arguments for the embedding model, defaults to an empty dict
:type model_kwargs: Optional[Dict[str, Any]], optional
"""
self.model = model
self.deployment_name = deployment_name
self.vector_dimension = vector_dimension
self.endpoint = endpoint
self.api_key = api_key
self.api_base = api_base
self.model_kwargs = model_kwargs or {}

View File

@@ -0,0 +1,19 @@
from typing import Optional
from embedchain.config.embedder.base import BaseEmbedderConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class GoogleAIEmbedderConfig(BaseEmbedderConfig):
def __init__(
self,
model: Optional[str] = None,
deployment_name: Optional[str] = None,
vector_dimension: Optional[int] = None,
task_type: Optional[str] = None,
title: Optional[str] = None,
):
super().__init__(model, deployment_name, vector_dimension)
self.task_type = task_type or "retrieval_document"
self.title = title or "Embeddings for Embedchain"

View File

@@ -0,0 +1,16 @@
from typing import Optional
from embedchain.config.embedder.base import BaseEmbedderConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class OllamaEmbedderConfig(BaseEmbedderConfig):
def __init__(
self,
model: Optional[str] = None,
base_url: Optional[str] = None,
vector_dimension: Optional[int] = None,
):
super().__init__(model=model, vector_dimension=vector_dimension)
self.base_url = base_url or "http://localhost:11434"

View File

@@ -0,0 +1,2 @@
from .base import (AnswerRelevanceConfig, ContextRelevanceConfig, # noqa: F401
GroundednessConfig)

View File

@@ -0,0 +1,92 @@
from typing import Optional
from embedchain.config.base_config import BaseConfig
ANSWER_RELEVANCY_PROMPT = """
Please provide $num_gen_questions questions from the provided answer.
You must provide the complete question; if you are not able to provide the complete question, return an empty string ("").
Please only provide one question per line without numbers or bullets to distinguish them.
You must only provide the questions and no other text.
$answer
""" # noqa:E501
CONTEXT_RELEVANCY_PROMPT = """
Please extract the relevant sentences from the provided context that are required to answer the given question.
If no relevant sentences are found, or if you believe the question cannot be answered from the given context, return the empty string ("").
While extracting candidate sentences you're not allowed to make any changes to sentences from the given context or make up any sentences.
You must only provide sentences from the given context and nothing else.
Context: $context
Question: $question
""" # noqa:E501
GROUNDEDNESS_ANSWER_CLAIMS_PROMPT = """
Please provide one or more statements from each sentence of the provided answer.
You must provide the semantically equivalent statements for each sentence of the answer.
You must provide the complete statement; if you are not able to provide the complete statement, return an empty string ("").
Please only provide one statement per line WITHOUT numbers or bullets.
If the question provided is not being answered in the provided answer, return an empty string ("").
You must only provide the statements and no other text.
$question
$answer
""" # noqa:E501
GROUNDEDNESS_CLAIMS_INFERENCE_PROMPT = """
Given the context and the provided claim statements, please provide a verdict for each claim statement indicating whether it can be completely inferred from the given context or not.
Use only "1" (yes), "0" (no) and "-1" (null) for "yes", "no" or "null" respectively.
You must provide one verdict per line, ONLY WITH "1", "0" or "-1" as per your verdict to the given statement and nothing else.
You must provide the verdicts in the same order as the claim statements.
Contexts:
$context
Claim statements:
$claim_statements
""" # noqa:E501
class GroundednessConfig(BaseConfig):
def __init__(
self,
model: str = "gpt-4",
api_key: Optional[str] = None,
answer_claims_prompt: str = GROUNDEDNESS_ANSWER_CLAIMS_PROMPT,
claims_inference_prompt: str = GROUNDEDNESS_CLAIMS_INFERENCE_PROMPT,
):
self.model = model
self.api_key = api_key
self.answer_claims_prompt = answer_claims_prompt
self.claims_inference_prompt = claims_inference_prompt
class AnswerRelevanceConfig(BaseConfig):
def __init__(
self,
model: str = "gpt-4",
embedder: str = "text-embedding-ada-002",
api_key: Optional[str] = None,
num_gen_questions: int = 1,
prompt: str = ANSWER_RELEVANCY_PROMPT,
):
self.model = model
self.embedder = embedder
self.api_key = api_key
self.num_gen_questions = num_gen_questions
self.prompt = prompt
class ContextRelevanceConfig(BaseConfig):
def __init__(
self,
model: str = "gpt-4",
api_key: Optional[str] = None,
language: str = "en",
prompt: str = CONTEXT_RELEVANCY_PROMPT,
):
self.model = model
self.api_key = api_key
self.language = language
self.prompt = prompt
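
These prompts are plain string.Template strings ($-placeholders), so an evaluator is expected to fill them roughly as below; the module path assumes this file is embedchain/config/evaluation/base.py, and the answer text is illustrative:

from string import Template

from embedchain.config.evaluation.base import AnswerRelevanceConfig  # assumed module path

config = AnswerRelevanceConfig(num_gen_questions=2)
prompt = Template(config.prompt).substitute(
    num_gen_questions=config.num_gen_questions,
    answer="Embedchain loads, chunks, embeds and indexes your data.",
)
print(prompt)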

View File

@@ -0,0 +1,275 @@
import json
import logging
import re
from string import Template
from typing import Any, Mapping, Optional, Dict, Union
import httpx
from embedchain.config.base_config import BaseConfig
from embedchain.helpers.json_serializable import register_deserializable
logger = logging.getLogger(__name__)
DEFAULT_PROMPT = """
You are a Q&A expert system. Your responses must always be rooted in the context provided for each query. Here are some guidelines to follow:
1. Refrain from explicitly mentioning the context provided in your response.
2. The context should silently guide your answers without being directly acknowledged.
3. Do not use phrases such as 'According to the context provided', 'Based on the context, ...' etc.
Context information:
----------------------
$context
----------------------
Query: $query
Answer:
""" # noqa:E501
DEFAULT_PROMPT_WITH_HISTORY = """
You are a Q&A expert system. Your responses must always be rooted in the context provided for each query. You are also provided with the conversation history with the user. Make sure to use relevant context from conversation history as needed.
Here are some guidelines to follow:
1. Refrain from explicitly mentioning the context provided in your response.
2. The context should silently guide your answers without being directly acknowledged.
3. Do not use phrases such as 'According to the context provided', 'Based on the context, ...' etc.
Context information:
----------------------
$context
----------------------
Conversation history:
----------------------
$history
----------------------
Query: $query
Answer:
""" # noqa:E501
DEFAULT_PROMPT_WITH_MEM0_MEMORY = """
You are a Q&A expert system. Your responses must always be rooted in the context provided for each query. You are also provided with the conversation history and memories with the user. Make sure to use relevant context from conversation history and memories as needed.
Here are some guidelines to follow:
1. Refrain from explicitly mentioning the context provided in your response.
2. Take into consideration the conversation history and memories provided.
3. The context should silently guide your answers without being directly acknowledged.
4. Do not use phrases such as 'According to the context provided', 'Based on the context, ...' etc.
Context information:
----------------------
$context
----------------------
Conversation history:
----------------------
$history
----------------------
Memories/Preferences:
----------------------
$memories
----------------------
Query: $query
Answer:
""" # noqa:E501
DOCS_SITE_DEFAULT_PROMPT = """
You are an expert AI assistant for a developer support product. Your responses must always be rooted in the context provided for each query. Wherever possible, give complete code snippets. Don't make up any code snippets on your own.
Here are some guidelines to follow:
1. Refrain from explicitly mentioning the context provided in your response.
2. The context should silently guide your answers without being directly acknowledged.
3. Do not use phrases such as 'According to the context provided', 'Based on the context, ...' etc.
Context information:
----------------------
$context
----------------------
Query: $query
Answer:
""" # noqa:E501
DEFAULT_PROMPT_TEMPLATE = Template(DEFAULT_PROMPT)
DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE = Template(DEFAULT_PROMPT_WITH_HISTORY)
DEFAULT_PROMPT_WITH_MEM0_MEMORY_TEMPLATE = Template(DEFAULT_PROMPT_WITH_MEM0_MEMORY)
DOCS_SITE_PROMPT_TEMPLATE = Template(DOCS_SITE_DEFAULT_PROMPT)
query_re = re.compile(r"\$\{*query\}*")
context_re = re.compile(r"\$\{*context\}*")
history_re = re.compile(r"\$\{*history\}*")
@register_deserializable
class BaseLlmConfig(BaseConfig):
"""
Config for the `query` method.
"""
def __init__(
self,
number_documents: int = 3,
template: Optional[Template] = None,
prompt: Optional[Template] = None,
model: Optional[str] = None,
temperature: float = 0,
max_tokens: int = 1000,
top_p: float = 1,
stream: bool = False,
online: bool = False,
token_usage: bool = False,
deployment_name: Optional[str] = None,
system_prompt: Optional[str] = None,
where: dict[str, Any] = None,
query_type: Optional[str] = None,
callbacks: Optional[list] = None,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
endpoint: Optional[str] = None,
model_kwargs: Optional[dict[str, Any]] = None,
http_client_proxies: Optional[Union[Dict, str]] = None,
http_async_client_proxies: Optional[Union[Dict, str]] = None,
local: Optional[bool] = False,
default_headers: Optional[Mapping[str, str]] = None,
api_version: Optional[str] = None,
):
"""
Initializes a configuration class instance for the LLM.
Takes the place of the former `QueryConfig` or `ChatConfig`.
:param number_documents: Number of documents to pull from the database as
context, defaults to 3
:type number_documents: int, optional
:param template: The `Template` instance to use as a template for
prompt, defaults to None (deprecated)
:type template: Optional[Template], optional
:param prompt: The `Template` instance to use as a template for
prompt, defaults to None
:type prompt: Optional[Template], optional
:param model: Controls the OpenAI model used, defaults to None
:type model: Optional[str], optional
:param temperature: Controls the randomness of the model's output.
Higher values (closer to 1) make output more random, lower values make it more deterministic, defaults to 0
:type temperature: float, optional
:param max_tokens: Controls how many tokens are generated, defaults to 1000
:type max_tokens: int, optional
:param top_p: Controls the diversity of words. Higher values (closer to 1) make word selection more diverse,
defaults to 1
:type top_p: float, optional
:param stream: Control if response is streamed back to user, defaults to False
:type stream: bool, optional
:param online: Controls whether to use internet for answering query, defaults to False
:type online: bool, optional
:param token_usage: Controls whether to return token usage in response, defaults to False
:type token_usage: bool, optional
:param deployment_name: t.b.a., defaults to None
:type deployment_name: Optional[str], optional
:param system_prompt: System prompt string, defaults to None
:type system_prompt: Optional[str], optional
:param where: A dictionary of key-value pairs to filter the database results., defaults to None
:type where: dict[str, Any], optional
:param api_key: The api key of the custom endpoint, defaults to None
:type api_key: Optional[str], optional
:param endpoint: The api url of the custom endpoint, defaults to None
:type endpoint: Optional[str], optional
:param model_kwargs: A dictionary of key-value pairs to pass to the model, defaults to None
:type model_kwargs: Optional[Dict[str, Any]], optional
:param callbacks: Langchain callback functions to use, defaults to None
:type callbacks: Optional[list], optional
:param query_type: The type of query to use, defaults to None
:type query_type: Optional[str], optional
:param http_client_proxies: The proxy server settings used to create self.http_client, defaults to None
:type http_client_proxies: Optional[Dict | str], optional
:param http_async_client_proxies: The proxy server settings for async calls used to create
self.http_async_client, defaults to None
:type http_async_client_proxies: Optional[Dict | str], optional
:param local: If True, the model will be run locally, defaults to False (for huggingface provider)
:type local: Optional[bool], optional
:param default_headers: Set additional HTTP headers to be sent with requests to OpenAI
:type default_headers: Optional[Mapping[str, str]], optional
:raises ValueError: If the template is not valid as template should
contain $context and $query (and optionally $history)
:raises ValueError: Stream is not boolean
"""
if template is not None:
logger.warning(
"The `template` argument is deprecated and will be removed in a future version. "
+ "Please use `prompt` instead."
)
if prompt is None:
prompt = template
if prompt is None:
prompt = DEFAULT_PROMPT_TEMPLATE
self.number_documents = number_documents
self.temperature = temperature
self.max_tokens = max_tokens
self.model = model
self.top_p = top_p
self.online = online
self.token_usage = token_usage
self.deployment_name = deployment_name
self.system_prompt = system_prompt
self.query_type = query_type
self.callbacks = callbacks
self.api_key = api_key
self.base_url = base_url
self.endpoint = endpoint
self.model_kwargs = model_kwargs
self.http_client = httpx.Client(proxies=http_client_proxies) if http_client_proxies else None
self.http_async_client = (
httpx.AsyncClient(proxies=http_async_client_proxies) if http_async_client_proxies else None
)
self.local = local
self.default_headers = default_headers
self.api_version = api_version
if token_usage:
f = open("model_prices_and_context_window.json")
self.model_pricing_map = json.load(f)
if isinstance(prompt, str):
prompt = Template(prompt)
if self.validate_prompt(prompt):
self.prompt = prompt
else:
raise ValueError("The 'prompt' should have 'query' and 'context' keys and potentially 'history' (if used).")
if not isinstance(stream, bool):
raise ValueError("`stream` should be bool")
self.stream = stream
self.where = where
@staticmethod
def validate_prompt(prompt: Template) -> Optional[re.Match[str]]:
"""
validate the prompt
:param prompt: the prompt to validate
:type prompt: Template
:return: valid (true) or invalid (false)
:rtype: Optional[re.Match[str]]
"""
return re.search(query_re, prompt.template) and re.search(context_re, prompt.template)
@staticmethod
def _validate_prompt_history(prompt: Template) -> Optional[re.Match[str]]:
"""
validate the prompt with history
:param prompt: the prompt to validate
:type prompt: Template
:return: valid (true) or invalid (false)
:rtype: Optional[re.Match[str]]
"""
return re.search(history_re, prompt.template)
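
A minimal sketch of supplying a custom prompt; a plain string is wrapped into a Template automatically, and it must contain $context and $query or the constructor raises a ValueError:

from embedchain.config.llm.base import BaseLlmConfig

custom_prompt = """
Use only the passages below to answer the question.
Passages: $context
Question: $query
Answer:"""

llm_config = BaseLlmConfig(number_documents=5, temperature=0.2, prompt=custom_prompt)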

View File

@@ -0,0 +1,21 @@
from typing import Any, Optional
from embedchain.config.base_config import BaseConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class Mem0Config(BaseConfig):
def __init__(self, api_key: str, top_k: Optional[int] = 10):
self.api_key = api_key
self.top_k = top_k
@staticmethod
def from_config(config: Optional[dict[str, Any]]):
if config is None:
return Mem0Config(api_key="")
else:
return Mem0Config(
api_key=config.get("api_key", ""),
top_k=config.get("top_k", 10),
)
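
A small sketch of the dict-based construction; the key is a placeholder:

from embedchain.config.mem0_config import Mem0Config

mem0_config = Mem0Config.from_config({"api_key": "m0-xxxx", "top_k": 5})
print(mem0_config.top_k)  # -> 5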

View File

@@ -0,0 +1,36 @@
from typing import Optional
from embedchain.config.base_config import BaseConfig
class BaseVectorDbConfig(BaseConfig):
def __init__(
self,
collection_name: Optional[str] = None,
dir: str = "db",
host: Optional[str] = None,
port: Optional[str] = None,
**kwargs,
):
"""
Initializes a configuration class instance for the vector database.
:param collection_name: Default name for the collection, defaults to None
:type collection_name: Optional[str], optional
:param dir: Path to the database directory, where the database is stored, defaults to "db"
:type dir: str, optional
:param host: Database connection remote host. Use this if you run Embedchain as a client, defaults to None
:type host: Optional[str], optional
:param port: Database connection remote port. Use this if you run Embedchain as a client, defaults to None
:type port: Optional[str], optional
:param kwargs: Additional keyword arguments
:type kwargs: dict
"""
self.collection_name = collection_name or "embedchain_store"
self.dir = dir
self.host = host
self.port = port
# Assign additional keyword arguments
if kwargs:
for key, value in kwargs.items():
setattr(self, key, value)
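
Because the trailing **kwargs are set as attributes, provider-specific options can ride along on the base config; `timeout` below is purely illustrative, not a recognized option:

from embedchain.config.vector_db.base import BaseVectorDbConfig

db_config = BaseVectorDbConfig(collection_name="docs", dir="./db", timeout=30)
print(db_config.collection_name, db_config.timeout)  # -> docs 30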

View File

@@ -0,0 +1,41 @@
from typing import Optional
from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class ChromaDbConfig(BaseVectorDbConfig):
def __init__(
self,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
host: Optional[str] = None,
port: Optional[str] = None,
batch_size: Optional[int] = 100,
allow_reset=False,
chroma_settings: Optional[dict] = None,
):
"""
Initializes a configuration class instance for ChromaDB.
:param collection_name: Default name for the collection, defaults to None
:type collection_name: Optional[str], optional
:param dir: Path to the database directory, where the database is stored, defaults to None
:type dir: Optional[str], optional
:param host: Database connection remote host. Use this if you run Embedchain as a client, defaults to None
:type host: Optional[str], optional
:param port: Database connection remote port. Use this if you run Embedchain as a client, defaults to None
:type port: Optional[str], optional
:param batch_size: Number of items to insert in one batch, defaults to 100
:type batch_size: Optional[int], optional
:param allow_reset: Whether resetting the database is allowed, defaults to False
:type allow_reset: bool
:param chroma_settings: Chroma settings dict, defaults to None
:type chroma_settings: Optional[dict], optional
"""
self.chroma_settings = chroma_settings
self.allow_reset = allow_reset
self.batch_size = batch_size
super().__init__(collection_name=collection_name, dir=dir, host=host, port=port)

View File

@@ -0,0 +1,56 @@
import os
from typing import Optional, Union
from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class ElasticsearchDBConfig(BaseVectorDbConfig):
def __init__(
self,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
es_url: Union[str, list[str]] = None,
cloud_id: Optional[str] = None,
batch_size: Optional[int] = 100,
**ES_EXTRA_PARAMS: dict[str, any],
):
"""
Initializes a configuration class instance for an Elasticsearch client.
:param collection_name: Default name for the collection, defaults to None
:type collection_name: Optional[str], optional
:param dir: Path to the database directory, where the database is stored, defaults to None
:type dir: Optional[str], optional
:param es_url: elasticsearch url or list of nodes url to be used for connection, defaults to None
:type es_url: Union[str, list[str]], optional
:param cloud_id: cloud id of the elasticsearch cluster, defaults to None
:type cloud_id: Optional[str], optional
:param batch_size: Number of items to insert in one batch, defaults to 100
:type batch_size: Optional[int], optional
:param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch.
:type ES_EXTRA_PARAMS: dict[str, Any], optional
"""
if es_url and cloud_id:
raise ValueError("Only one of `es_url` and `cloud_id` can be set.")
# self, es_url: Union[str, list[str]] = None, **ES_EXTRA_PARAMS: dict[str, any]):
self.ES_URL = es_url or os.environ.get("ELASTICSEARCH_URL")
self.CLOUD_ID = cloud_id or os.environ.get("ELASTICSEARCH_CLOUD_ID")
if not self.ES_URL and not self.CLOUD_ID:
raise AttributeError(
"Elasticsearch needs a URL or CLOUD_ID attribute, "
"this can either be passed to `ElasticsearchDBConfig` or as `ELASTICSEARCH_URL` or `ELASTICSEARCH_CLOUD_ID` in `.env`" # noqa: E501
)
self.ES_EXTRA_PARAMS = ES_EXTRA_PARAMS
# Load API key from .env if it's not explicitly passed.
# Can only set one of 'api_key', 'basic_auth', and 'bearer_auth'
if (
not self.ES_EXTRA_PARAMS.get("api_key")
and not self.ES_EXTRA_PARAMS.get("basic_auth")
and not self.ES_EXTRA_PARAMS.get("bearer_auth")
):
self.ES_EXTRA_PARAMS["api_key"] = os.environ.get("ELASTICSEARCH_API_KEY")
self.batch_size = batch_size
super().__init__(collection_name=collection_name, dir=dir)
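
A hedged sketch of pointing this at a cluster; the values are placeholders, and only one of es_url / cloud_id may be set:

import os

from embedchain.config.vector_db.elasticsearch import ElasticsearchDBConfig

os.environ.setdefault("ELASTICSEARCH_API_KEY", "es-api-key-placeholder")  # placeholder key
es_config = ElasticsearchDBConfig(es_url="http://localhost:9200", batch_size=50)
print(es_config.ES_URL, es_config.batch_size)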

View File

@@ -0,0 +1,33 @@
from typing import Optional
from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class LanceDBConfig(BaseVectorDbConfig):
def __init__(
self,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
host: Optional[str] = None,
port: Optional[str] = None,
allow_reset=True,
):
"""
Initializes a configuration class instance for LanceDB.
:param collection_name: Default name for the collection, defaults to None
:type collection_name: Optional[str], optional
:param dir: Path to the database directory, where the database is stored, defaults to None
:type dir: Optional[str], optional
:param host: Database connection remote host. Use this if you run Embedchain as a client, defaults to None
:type host: Optional[str], optional
:param port: Database connection remote port. Use this if you run Embedchain as a client, defaults to None
:type port: Optional[str], optional
:param allow_reset: Whether resetting the database is allowed, defaults to True
:type allow_reset: bool
"""
self.allow_reset = allow_reset
super().__init__(collection_name=collection_name, dir=dir, host=host, port=port)

View File

@@ -0,0 +1,41 @@
from typing import Optional
from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class OpenSearchDBConfig(BaseVectorDbConfig):
def __init__(
self,
opensearch_url: str,
http_auth: tuple[str, str],
vector_dimension: int = 1536,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
batch_size: Optional[int] = 100,
**extra_params: dict[str, any],
):
"""
Initializes a configuration class instance for an OpenSearch client.
:param collection_name: Default name for the collection, defaults to None
:type collection_name: Optional[str], optional
:param opensearch_url: URL of the OpenSearch domain
:type opensearch_url: str, e.g. "http://localhost:9200"
:param http_auth: Tuple of username and password
:type http_auth: tuple[str, str], e.g. ("username", "password")
:param vector_dimension: Dimension of the vector, defaults to 1536 (openai embedding model)
:type vector_dimension: int, optional
:param dir: Path to the database directory, where the database is stored, defaults to None
:type dir: Optional[str], optional
:param batch_size: Number of items to insert in one batch, defaults to 100
:type batch_size: Optional[int], optional
"""
self.opensearch_url = opensearch_url
self.http_auth = http_auth
self.vector_dimension = vector_dimension
self.extra_params = extra_params
self.batch_size = batch_size
super().__init__(collection_name=collection_name, dir=dir)

View File

@@ -0,0 +1,47 @@
import os
from typing import Optional
from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class PineconeDBConfig(BaseVectorDbConfig):
def __init__(
self,
index_name: Optional[str] = None,
api_key: Optional[str] = None,
vector_dimension: int = 1536,
metric: Optional[str] = "cosine",
pod_config: Optional[dict[str, any]] = None,
serverless_config: Optional[dict[str, any]] = None,
hybrid_search: bool = False,
bm25_encoder: any = None,
batch_size: Optional[int] = 100,
**extra_params: dict[str, any],
):
self.metric = metric
self.api_key = api_key
self.index_name = index_name
self.vector_dimension = vector_dimension
self.extra_params = extra_params
self.hybrid_search = hybrid_search
self.bm25_encoder = bm25_encoder
self.batch_size = batch_size
if pod_config is None and serverless_config is None:
# If no config is provided, use the default pod spec config
pod_environment = os.environ.get("PINECONE_ENV", "gcp-starter")
self.pod_config = {"environment": pod_environment, "metadata_config": {"indexed": ["*"]}}
else:
self.pod_config = pod_config
self.serverless_config = serverless_config
if self.pod_config and self.serverless_config:
raise ValueError("Only one of pod_config or serverless_config can be provided.")
if self.hybrid_search and self.metric != "dotproduct":
raise ValueError(
"Hybrid search is only supported with dotproduct metric in Pinecone. See full docs here: https://docs.pinecone.io/docs/hybrid-search#limitations"
) # noqa:E501
super().__init__(collection_name=self.index_name, dir=None)

View File

@@ -0,0 +1,48 @@
from typing import Optional
from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class QdrantDBConfig(BaseVectorDbConfig):
"""
Config to initialize a qdrant client.
:param: url. qdrant url or list of nodes url to be used for connection
"""
def __init__(
self,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
hnsw_config: Optional[dict[str, any]] = None,
quantization_config: Optional[dict[str, any]] = None,
on_disk: Optional[bool] = None,
batch_size: Optional[int] = 10,
**extra_params: dict[str, any],
):
"""
Initializes a configuration class instance for a qdrant client.
:param collection_name: Default name for the collection, defaults to None
:type collection_name: Optional[str], optional
:param dir: Path to the database directory, where the database is stored, defaults to None
:type dir: Optional[str], optional
:param hnsw_config: Params for HNSW index
:type hnsw_config: Optional[dict[str, any]], defaults to None
:param quantization_config: Params for quantization, if None - quantization will be disabled
:type quantization_config: Optional[dict[str, any]], defaults to None
:param on_disk: If True, the point's payload will not be stored in memory.
It will be read from the disk every time it is requested.
This setting saves RAM by (slightly) increasing the response time.
Note: payload values that are involved in filtering and are indexed remain in RAM.
:type on_disk: bool, optional, defaults to None
:param batch_size: Number of items to insert in one batch, defaults to 10
:type batch_size: Optional[int], optional
"""
self.hnsw_config = hnsw_config
self.quantization_config = quantization_config
self.on_disk = on_disk
self.batch_size = batch_size
self.extra_params = extra_params
super().__init__(collection_name=collection_name, dir=dir)

View File

@@ -0,0 +1,18 @@
from typing import Optional
from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class WeaviateDBConfig(BaseVectorDbConfig):
def __init__(
self,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
batch_size: Optional[int] = 100,
**extra_params: dict[str, any],
):
self.batch_size = batch_size
self.extra_params = extra_params
super().__init__(collection_name=collection_name, dir=dir)

View File

@@ -0,0 +1,49 @@
import os
from typing import Optional
from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class ZillizDBConfig(BaseVectorDbConfig):
def __init__(
self,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
uri: Optional[str] = None,
token: Optional[str] = None,
vector_dim: Optional[str] = None,
metric_type: Optional[str] = None,
):
"""
Initializes a configuration class instance for the vector database.
:param collection_name: Default name for the collection, defaults to None
:type collection_name: Optional[str], optional
:param dir: Path to the database directory, where the database is stored, defaults to "db"
:type dir: str, optional
:param uri: Cluster endpoint obtained from the Zilliz Console, defaults to None
:type uri: Optional[str], optional
:param token: API Key, if a Serverless Cluster, username:password, if a Dedicated Cluster, defaults to None
:type token: Optional[str], optional
"""
self.uri = uri or os.environ.get("ZILLIZ_CLOUD_URI")
if not self.uri:
raise AttributeError(
"Zilliz needs a URI attribute, "
"this can either be passed to `ZILLIZ_CLOUD_URI` or as `ZILLIZ_CLOUD_URI` in `.env`"
)
self.token = token or os.environ.get("ZILLIZ_CLOUD_TOKEN")
if not self.token:
raise AttributeError(
"Zilliz needs a token attribute, "
"this can either be passed to `ZILLIZ_CLOUD_TOKEN` or as `ZILLIZ_CLOUD_TOKEN` in `.env`,"
"if having a username and password, pass it in the form 'username:password' to `ZILLIZ_CLOUD_TOKEN`"
)
self.metric_type = metric_type if metric_type else "L2"
self.vector_dim = vector_dim
super().__init__(collection_name=collection_name, dir=dir)

View File

@@ -0,0 +1,11 @@
import os
from pathlib import Path
ABS_PATH = os.getcwd()
HOME_DIR = os.environ.get("EMBEDCHAIN_CONFIG_DIR", str(Path.home()))
CONFIG_DIR = os.path.join(HOME_DIR, ".embedchain")
CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
SQLITE_PATH = os.path.join(CONFIG_DIR, "embedchain.db")
# Set the environment variable for the database URI
os.environ.setdefault("EMBEDCHAIN_DB_URI", f"sqlite:///{SQLITE_PATH}")

View File

View File

@@ -0,0 +1 @@
from .data_formatter import DataFormatter # noqa: F401

View File

@@ -0,0 +1,148 @@
from importlib import import_module
from typing import Optional
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config import AddConfig
from embedchain.config.add_config import ChunkerConfig, LoaderConfig
from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.models.data_type import DataType
class DataFormatter(JSONSerializable):
"""
DataFormatter is an internal utility class which abstracts the mapping for
loaders and chunkers to the data_type entered by the user in their
.add or .add_local method call
"""
def __init__(
self,
data_type: DataType,
config: AddConfig,
loader: Optional[BaseLoader] = None,
chunker: Optional[BaseChunker] = None,
):
"""
Initialize a dataformatter, set data type and chunker based on datatype.
:param data_type: The type of the data to load and chunk.
:type data_type: DataType
:param config: AddConfig instance with nested loader and chunker config attributes.
:type config: AddConfig
"""
self.loader = self._get_loader(data_type=data_type, config=config.loader, loader=loader)
self.chunker = self._get_chunker(data_type=data_type, config=config.chunker, chunker=chunker)
@staticmethod
def _lazy_load(module_path: str):
module_path, class_name = module_path.rsplit(".", 1)
module = import_module(module_path)
return getattr(module, class_name)
def _get_loader(self, data_type: DataType, config: LoaderConfig, loader: Optional[BaseLoader]) -> BaseLoader:
"""
Returns the appropriate data loader for the given data type.
:param data_type: The type of the data to load.
:type data_type: DataType
:param config: Config to initialize the loader with.
:type config: LoaderConfig
:raises ValueError: If an unsupported data type is provided.
:return: The loader for the given data type.
:rtype: BaseLoader
"""
loaders = {
DataType.YOUTUBE_VIDEO: "embedchain.loaders.youtube_video.YoutubeVideoLoader",
DataType.PDF_FILE: "embedchain.loaders.pdf_file.PdfFileLoader",
DataType.WEB_PAGE: "embedchain.loaders.web_page.WebPageLoader",
DataType.QNA_PAIR: "embedchain.loaders.local_qna_pair.LocalQnaPairLoader",
DataType.TEXT: "embedchain.loaders.local_text.LocalTextLoader",
DataType.DOCX: "embedchain.loaders.docx_file.DocxFileLoader",
DataType.SITEMAP: "embedchain.loaders.sitemap.SitemapLoader",
DataType.XML: "embedchain.loaders.xml.XmlLoader",
DataType.DOCS_SITE: "embedchain.loaders.docs_site_loader.DocsSiteLoader",
DataType.CSV: "embedchain.loaders.csv.CsvLoader",
DataType.MDX: "embedchain.loaders.mdx.MdxLoader",
DataType.IMAGE: "embedchain.loaders.image.ImageLoader",
DataType.UNSTRUCTURED: "embedchain.loaders.unstructured_file.UnstructuredLoader",
DataType.JSON: "embedchain.loaders.json.JSONLoader",
DataType.OPENAPI: "embedchain.loaders.openapi.OpenAPILoader",
DataType.GMAIL: "embedchain.loaders.gmail.GmailLoader",
DataType.NOTION: "embedchain.loaders.notion.NotionLoader",
DataType.SUBSTACK: "embedchain.loaders.substack.SubstackLoader",
DataType.YOUTUBE_CHANNEL: "embedchain.loaders.youtube_channel.YoutubeChannelLoader",
DataType.DISCORD: "embedchain.loaders.discord.DiscordLoader",
DataType.RSSFEED: "embedchain.loaders.rss_feed.RSSFeedLoader",
DataType.BEEHIIV: "embedchain.loaders.beehiiv.BeehiivLoader",
DataType.GOOGLE_DRIVE: "embedchain.loaders.google_drive.GoogleDriveLoader",
DataType.DIRECTORY: "embedchain.loaders.directory_loader.DirectoryLoader",
DataType.SLACK: "embedchain.loaders.slack.SlackLoader",
DataType.DROPBOX: "embedchain.loaders.dropbox.DropboxLoader",
DataType.TEXT_FILE: "embedchain.loaders.text_file.TextFileLoader",
DataType.EXCEL_FILE: "embedchain.loaders.excel_file.ExcelFileLoader",
DataType.AUDIO: "embedchain.loaders.audio.AudioLoader",
}
if data_type == DataType.CUSTOM or loader is not None:
loader_class: type = loader
if loader_class:
return loader_class
elif data_type in loaders:
loader_class: type = self._lazy_load(loaders[data_type])
return loader_class()
raise ValueError(
f"Cant find the loader for {data_type}.\
We recommend to pass the loader to use data_type: {data_type},\
check `https://docs.embedchain.ai/data-sources/overview`."
)
def _get_chunker(self, data_type: DataType, config: ChunkerConfig, chunker: Optional[BaseChunker]) -> BaseChunker:
"""Returns the appropriate chunker for the given data type (updated for lazy loading)."""
chunker_classes = {
DataType.YOUTUBE_VIDEO: "embedchain.chunkers.youtube_video.YoutubeVideoChunker",
DataType.PDF_FILE: "embedchain.chunkers.pdf_file.PdfFileChunker",
DataType.WEB_PAGE: "embedchain.chunkers.web_page.WebPageChunker",
DataType.QNA_PAIR: "embedchain.chunkers.qna_pair.QnaPairChunker",
DataType.TEXT: "embedchain.chunkers.text.TextChunker",
DataType.DOCX: "embedchain.chunkers.docx_file.DocxFileChunker",
DataType.SITEMAP: "embedchain.chunkers.sitemap.SitemapChunker",
DataType.XML: "embedchain.chunkers.xml.XmlChunker",
DataType.DOCS_SITE: "embedchain.chunkers.docs_site.DocsSiteChunker",
DataType.CSV: "embedchain.chunkers.table.TableChunker",
DataType.MDX: "embedchain.chunkers.mdx.MdxChunker",
DataType.IMAGE: "embedchain.chunkers.image.ImageChunker",
DataType.UNSTRUCTURED: "embedchain.chunkers.unstructured_file.UnstructuredFileChunker",
DataType.JSON: "embedchain.chunkers.json.JSONChunker",
DataType.OPENAPI: "embedchain.chunkers.openapi.OpenAPIChunker",
DataType.GMAIL: "embedchain.chunkers.gmail.GmailChunker",
DataType.NOTION: "embedchain.chunkers.notion.NotionChunker",
DataType.SUBSTACK: "embedchain.chunkers.substack.SubstackChunker",
DataType.YOUTUBE_CHANNEL: "embedchain.chunkers.common_chunker.CommonChunker",
DataType.DISCORD: "embedchain.chunkers.common_chunker.CommonChunker",
DataType.CUSTOM: "embedchain.chunkers.common_chunker.CommonChunker",
DataType.RSSFEED: "embedchain.chunkers.rss_feed.RSSFeedChunker",
DataType.BEEHIIV: "embedchain.chunkers.beehiiv.BeehiivChunker",
DataType.GOOGLE_DRIVE: "embedchain.chunkers.google_drive.GoogleDriveChunker",
DataType.DIRECTORY: "embedchain.chunkers.common_chunker.CommonChunker",
DataType.SLACK: "embedchain.chunkers.common_chunker.CommonChunker",
DataType.DROPBOX: "embedchain.chunkers.common_chunker.CommonChunker",
DataType.TEXT_FILE: "embedchain.chunkers.common_chunker.CommonChunker",
DataType.EXCEL_FILE: "embedchain.chunkers.excel_file.ExcelFileChunker",
DataType.AUDIO: "embedchain.chunkers.audio.AudioChunker",
}
if chunker is not None:
return chunker
elif data_type in chunker_classes:
chunker_class = self._lazy_load(chunker_classes[data_type])
chunker = chunker_class(config)
chunker.set_data_type(data_type)
return chunker
raise ValueError(
f"Cant find the chunker for {data_type}.\
We recommend to pass the chunker to use data_type: {data_type},\
check `https://docs.embedchain.ai/data-sources/overview`."
)
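
A hedged sketch of how the mapping above can be bypassed: passing an explicit loader and chunker skips the lazy-loaded registry. The custom loader, its return shape, and the use of `ChunkerConfig` defaults are assumptions for illustration.

from embedchain.chunkers.common_chunker import CommonChunker
from embedchain.config import AddConfig
from embedchain.config.add_config import ChunkerConfig
from embedchain.data_formatter import DataFormatter
from embedchain.loaders.base_loader import BaseLoader
from embedchain.models.data_type import DataType


class MyLoader(BaseLoader):  # illustrative custom loader
    def load_data(self, url):
        # Illustrative return shape; a real loader returns whatever the chunker expects.
        return {"doc_id": "example", "data": [{"content": "hello world", "meta_data": {"url": url}}]}


chunker = CommonChunker(ChunkerConfig())
chunker.set_data_type(DataType.CUSTOM)  # explicit chunkers are returned as-is, so set the type yourself

formatter = DataFormatter(
    data_type=DataType.CUSTOM,
    config=AddConfig(),
    loader=MyLoader(),
    chunker=chunker,
)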

View File

@@ -0,0 +1 @@
db/

View File

@@ -0,0 +1 @@
OPENAI_API_KEY=sk-xxx

View File

@@ -0,0 +1,13 @@
FROM python:3.11-slim
WORKDIR /app
COPY requirements.txt /app/
RUN pip install -r requirements.txt
COPY . /app
EXPOSE 8080
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8080"]

View File

@@ -0,0 +1,56 @@
from dotenv import load_dotenv
from fastapi import FastAPI, responses
from pydantic import BaseModel
from embedchain import App
load_dotenv(".env")
app = FastAPI(title="Embedchain FastAPI App")
embedchain_app = App()
class SourceModel(BaseModel):
source: str
class QuestionModel(BaseModel):
question: str
@app.post("/add")
async def add_source(source_model: SourceModel):
"""
Adds a new source to the EmbedChain app.
Expects a JSON with a "source" key.
"""
source = source_model.source
embedchain_app.add(source)
return {"message": f"Source '{source}' added successfully."}
@app.post("/query")
async def handle_query(question_model: QuestionModel):
"""
Handles a query to the EmbedChain app.
Expects a JSON with a "question" key.
"""
question = question_model.question
answer = embedchain_app.query(question)
return {"answer": answer}
@app.post("/chat")
async def handle_chat(question_model: QuestionModel):
"""
Handles a chat request to the EmbedChain app.
Expects a JSON with a "question" key.
"""
question = question_model.question
response = embedchain_app.chat(question)
return {"response": response}
@app.get("/")
async def root():
return responses.RedirectResponse(url="/docs")
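
A small client sketch against the endpoints above, assuming the server is reachable locally on port 8080 (the port the accompanying Dockerfile exposes); the URL and source are placeholders.

import requests

BASE_URL = "http://localhost:8080"  # placeholder base URL

# Index a source, then ask a question about it.
requests.post(f"{BASE_URL}/add", json={"source": "https://www.forbes.com/profile/elon-musk"})
answer = requests.post(f"{BASE_URL}/query", json={"question": "Who is Elon Musk?"}).json()["answer"]
print(answer)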

View File

@@ -0,0 +1,4 @@
fastapi==0.104.0
uvicorn==0.23.2
embedchain
beautifulsoup4

View File

@@ -0,0 +1,18 @@
import os
import gradio as gr
from embedchain import App
os.environ["OPENAI_API_KEY"] = "sk-xxx"
app = App()
def query(message, history):
return app.chat(message)
demo = gr.ChatInterface(query)
demo.launch()
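
One note on the sketch above: `gr.ChatInterface` passes both the message and the UI history, but the history argument is ignored here because `app.chat` keeps its own conversation memory. Sources can be indexed before `demo.launch()` is called, for example:

# Illustrative: index a source so the first question already has context (placeholder URL).
app.add("https://www.forbes.com/profile/elon-musk")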

View File

@@ -0,0 +1,2 @@
gradio==4.11.0
embedchain

View File

@@ -0,0 +1 @@
OPENAI_API_KEY=sk-xxx

View File

@@ -0,0 +1 @@
.env

View File

@@ -0,0 +1,86 @@
from dotenv import load_dotenv
from fastapi import Body, FastAPI, responses
from modal import Image, Secret, Stub, asgi_app
from embedchain import App
load_dotenv(".env")
image = Image.debian_slim().pip_install(
"embedchain",
"lanchain_community==0.2.6",
"youtube-transcript-api==0.6.1",
"pytube==15.0.0",
"beautifulsoup4==4.12.3",
"slack-sdk==3.21.3",
"huggingface_hub==0.23.0",
"gitpython==3.1.38",
"yt_dlp==2023.11.14",
"PyGithub==1.59.1",
"feedparser==6.0.10",
"newspaper3k==0.2.8",
"listparser==0.19",
)
stub = Stub(
name="embedchain-app",
image=image,
secrets=[Secret.from_dotenv(".env")],
)
web_app = FastAPI()
embedchain_app = App(name="embedchain-modal-app")
@web_app.post("/add")
async def add(
source: str = Body(..., description="Source to be added"),
data_type: str | None = Body(None, description="Type of the data source"),
):
"""
Adds a new source to the EmbedChain app.
Expects a JSON with a "source" and "data_type" key.
"data_type" is optional.
"""
if source and data_type:
embedchain_app.add(source, data_type)
elif source:
embedchain_app.add(source)
else:
return {"message": "No source provided."}
return {"message": f"Source '{source}' added successfully."}
@web_app.post("/query")
async def query(question: str = Body(..., description="Question to be answered")):
"""
Handles a query to the EmbedChain app.
Expects a JSON with a "question" key.
"""
if not question:
return {"message": "No question provided."}
answer = embedchain_app.query(question)
return {"answer": answer}
@web_app.get("/chat")
async def chat(question: str = Body(..., description="Question to be answered")):
"""
Handles a chat request to the EmbedChain app.
Expects a JSON with a "question" key.
"""
if not question:
return {"message": "No question provided."}
response = embedchain_app.chat(question)
return {"response": response}
@web_app.get("/")
async def root():
return responses.RedirectResponse(url="/docs")
@stub.function(image=image)
@asgi_app()
def fastapi_app():
return web_app

View File

@@ -0,0 +1,4 @@
modal==0.56.4329
fastapi==0.104.0
uvicorn==0.23.2
embedchain

View File

@@ -0,0 +1 @@
OPENAI_API_KEY=sk-xxx

View File

@@ -0,0 +1 @@
.env

View File

@@ -0,0 +1,53 @@
from fastapi import FastAPI, responses
from pydantic import BaseModel
from embedchain import App
app = FastAPI(title="Embedchain FastAPI App")
embedchain_app = App()
class SourceModel(BaseModel):
source: str
class QuestionModel(BaseModel):
question: str
@app.post("/add")
async def add_source(source_model: SourceModel):
"""
Adds a new source to the EmbedChain app.
Expects a JSON with a "source" key.
"""
source = source_model.source
embedchain_app.add(source)
return {"message": f"Source '{source}' added successfully."}
@app.post("/query")
async def handle_query(question_model: QuestionModel):
"""
Handles a query to the EmbedChain app.
Expects a JSON with a "question" key.
"""
question = question_model.question
answer = embedchain_app.query(question)
return {"answer": answer}
@app.post("/chat")
async def handle_chat(question_model: QuestionModel):
"""
Handles a chat request to the EmbedChain app.
Expects a JSON with a "question" key.
"""
question = question_model.question
response = embedchain_app.chat(question)
return {"response": response}
@app.get("/")
async def root():
return responses.RedirectResponse(url="/docs")

View File

@@ -0,0 +1,16 @@
services:
- type: web
name: ec-render-app
runtime: python
repo: https://github.com/<your-username>/<repo-name>
scaling:
minInstances: 1
maxInstances: 3
targetMemoryPercent: 60 # optional if targetCPUPercent is set
targetCPUPercent: 60 # optional if targetMemoryPercent is set
buildCommand: pip install -r requirements.txt
startCommand: uvicorn app:app --host 0.0.0.0
envVars:
- key: OPENAI_API_KEY
value: sk-xxx
autoDeploy: false # optional

View File

@@ -0,0 +1,4 @@
fastapi==0.104.0
uvicorn==0.23.2
embedchain
beautifulsoup4

View File

@@ -0,0 +1 @@
OPENAI_API_KEY="sk-xxx"

View File

@@ -0,0 +1,59 @@
import streamlit as st
from embedchain import App
@st.cache_resource
def embedchain_bot():
return App()
st.title("💬 Chatbot")
st.caption("🚀 An Embedchain app powered by OpenAI!")
if "messages" not in st.session_state:
st.session_state.messages = [
{
"role": "assistant",
"content": """
Hi! I'm a chatbot. I can answer questions and learn new things!\n
Ask me anything and if you want me to learn something do `/add <source>`.\n
I can learn almost anything. :)
""",
}
]
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
if prompt := st.chat_input("Ask me anything!"):
app = embedchain_bot()
if prompt.startswith("/add"):
with st.chat_message("user"):
st.markdown(prompt)
st.session_state.messages.append({"role": "user", "content": prompt})
prompt = prompt.replace("/add", "").strip()
with st.chat_message("assistant"):
message_placeholder = st.empty()
message_placeholder.markdown("Adding to knowledge base...")
app.add(prompt)
message_placeholder.markdown(f"Added {prompt} to knowledge base!")
st.session_state.messages.append({"role": "assistant", "content": f"Added {prompt} to knowledge base!"})
st.stop()
with st.chat_message("user"):
st.markdown(prompt)
st.session_state.messages.append({"role": "user", "content": prompt})
with st.chat_message("assistant"):
msg_placeholder = st.empty()
msg_placeholder.markdown("Thinking...")
full_response = ""
for response in app.chat(prompt):
msg_placeholder.empty()
full_response += response
msg_placeholder.markdown(full_response)
st.session_state.messages.append({"role": "assistant", "content": full_response})
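
The loop above iterates over `app.chat(prompt)`. With a non-streaming LLM, `app.chat` returns a plain string, so the loop walks it character by character; configuring the LLM for streaming makes it yield larger chunks instead. A hedged sketch of enabling streaming, assuming embedchain's `App.from_config` accepts a config dict of this shape:

import streamlit as st

from embedchain import App


@st.cache_resource
def embedchain_bot():
    # Assumed config shape; adjust provider/model to your setup.
    return App.from_config(config={"llm": {"provider": "openai", "config": {"stream": True}}})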

View File

@@ -0,0 +1,2 @@
streamlit==1.29.0
embedchain

View File

@@ -0,0 +1,776 @@
import hashlib
import json
import logging
from typing import Any, Optional, Union
from dotenv import load_dotenv
from langchain.docstore.document import Document
from embedchain.cache import adapt, get_gptcache_session, gptcache_data_convert, gptcache_update_cache_callback
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
from embedchain.config.base_app_config import BaseAppConfig
from embedchain.core.db.models import ChatHistory, DataSource
from embedchain.data_formatter import DataFormatter
from embedchain.embedder.base import BaseEmbedder
from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.llm.base import BaseLlm
from embedchain.loaders.base_loader import BaseLoader
from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
from embedchain.utils.misc import detect_datatype, is_valid_json_string
from embedchain.vectordb.base import BaseVectorDB
load_dotenv()
logger = logging.getLogger(__name__)
class EmbedChain(JSONSerializable):
def __init__(
self,
config: BaseAppConfig,
llm: BaseLlm,
db: BaseVectorDB = None,
embedder: BaseEmbedder = None,
system_prompt: Optional[str] = None,
):
"""
Initializes the EmbedChain instance, sets up a vector DB client and
creates a collection.
:param config: Configuration just for the app, not the db or llm or embedder.
:type config: BaseAppConfig
:param llm: Instance of the LLM you want to use.
:type llm: BaseLlm
:param db: Instance of the Database to use, defaults to None
:type db: BaseVectorDB, optional
:param embedder: instance of the embedder to use, defaults to None
:type embedder: BaseEmbedder, optional
:param system_prompt: System prompt to use in the llm query, defaults to None
:type system_prompt: Optional[str], optional
:raises ValueError: No database or embedder provided.
"""
self.config = config
self.cache_config = None
self.memory_config = None
self.mem0_client = None
# Llm
self.llm = llm
# Database has support for config assignment for backwards compatibility
if db is None and (not hasattr(self.config, "db") or self.config.db is None):
raise ValueError("App requires Database.")
self.db = db or self.config.db
# Embedder
if embedder is None:
raise ValueError("App requires Embedder.")
self.embedder = embedder
# Initialize database
self.db._set_embedder(self.embedder)
self.db._initialize()
# Set collection name from app config for backwards compatibility.
if config.collection_name:
self.db.set_collection_name(config.collection_name)
# Add variables that are "shortcuts"
if system_prompt:
self.llm.config.system_prompt = system_prompt
# Fetch the history from the database if exists
self.llm.update_history(app_id=self.config.id)
# Attributes that aren't subclass related.
self.user_asks = []
self.chunker: Optional[ChunkerConfig] = None
@property
def collect_metrics(self):
return self.config.collect_metrics
@collect_metrics.setter
def collect_metrics(self, value):
if not isinstance(value, bool):
raise ValueError(f"Boolean value expected but got {type(value)}.")
self.config.collect_metrics = value
@property
def online(self):
return self.llm.config.online
@online.setter
def online(self, value):
if not isinstance(value, bool):
raise ValueError(f"Boolean value expected but got {type(value)}.")
self.llm.config.online = value
def add(
self,
source: Any,
data_type: Optional[DataType] = None,
metadata: Optional[dict[str, Any]] = None,
config: Optional[AddConfig] = None,
dry_run=False,
loader: Optional[BaseLoader] = None,
chunker: Optional[BaseChunker] = None,
**kwargs: Optional[dict[str, Any]],
):
"""
Adds the data from the given URL to the vector db.
Loads the data, chunks it, create embedding for each chunk
and then stores the embedding to vector database.
:param source: The data to embed, can be a URL, local file or raw content, depending on the data type.
:type source: Any
:param data_type: Automatically detected, but can be forced with this argument. The type of the data to add,
defaults to None
:type data_type: Optional[DataType], optional
:param metadata: Metadata associated with the data source., defaults to None
:type metadata: Optional[dict[str, Any]], optional
:param config: The `AddConfig` instance to use as configuration options., defaults to None
:type config: Optional[AddConfig], optional
:raises ValueError: Invalid data type
:param dry_run: Optional. A dry run displays the chunks to ensure that the loader and chunker work as intended.
defaults to False
:type dry_run: bool
:param loader: The loader to use to load the data, defaults to None
:type loader: BaseLoader, optional
:param chunker: The chunker to use to chunk the data, defaults to None
:type chunker: BaseChunker, optional
:param kwargs: To read more params for the query function
:type kwargs: dict[str, Any]
:return: source_hash, an MD5 hash of the source, in hexadecimal representation.
:rtype: str
"""
if config is not None:
pass
elif self.chunker is not None:
config = AddConfig(chunker=self.chunker)
else:
config = AddConfig()
try:
DataType(source)
logger.warning(
f"""Starting from version v0.0.40, Embedchain can automatically detect the data type. So, in the `add` method, the argument order has changed. You no longer need to specify '{source}' for the `source` argument. So the code snippet will be `.add("{data_type}", "{source}")`""" # noqa #E501
)
logger.warning(
"Embedchain is swapping the arguments for you. This functionality might be deprecated in the future, so please adjust your code." # noqa #E501
)
source, data_type = data_type, source
except ValueError:
pass
if data_type:
try:
data_type = DataType(data_type)
except ValueError:
logger.info(
f"Invalid data_type: '{data_type}', using `custom` instead.\n Check docs to pass the valid data type: `https://docs.embedchain.ai/data-sources/overview`" # noqa: E501
)
data_type = DataType.CUSTOM
if not data_type:
data_type = detect_datatype(source)
# `source_hash` is the md5 hash of the source argument
source_hash = hashlib.md5(str(source).encode("utf-8")).hexdigest()
self.user_asks.append([source, data_type.value, metadata])
data_formatter = DataFormatter(data_type, config, loader, chunker)
documents, metadatas, _ids, new_chunks = self._load_and_embed(
data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, config, dry_run, **kwargs
)
if data_type in {DataType.DOCS_SITE}:
self.is_docs_site_instance = True
# Convert the source to a string if it is not already
if not isinstance(source, str):
source = str(source)
# Insert the data into the 'ec_data_sources' table
self.db_session.add(
DataSource(
hash=source_hash,
app_id=self.config.id,
type=data_type.value,
value=source,
metadata=json.dumps(metadata),
)
)
try:
self.db_session.commit()
except Exception as e:
logger.error(f"Error adding data source: {e}")
self.db_session.rollback()
if dry_run:
data_chunks_info = {"chunks": documents, "metadata": metadatas, "count": len(documents), "type": data_type}
logger.debug(f"Dry run info : {data_chunks_info}")
return data_chunks_info
# Send anonymous telemetry
if self.config.collect_metrics:
# it's quicker to check the variable twice than to count words when they won't be submitted.
word_count = data_formatter.chunker.get_word_count(documents)
# Send anonymous telemetry
event_properties = {
**self._telemetry_props,
"data_type": data_type.value,
"word_count": word_count,
"chunks_count": new_chunks,
}
self.telemetry.capture(event_name="add", properties=event_properties)
return source_hash
def _get_existing_doc_id(self, chunker: BaseChunker, src: Any):
"""
Get id of existing document for a given source, based on the data type
"""
# Find existing embeddings for the source
# Depending on the data type, existing embeddings are checked for.
if chunker.data_type.value in [item.value for item in DirectDataType]:
# DirectDataTypes can't be updated.
# Think of a text:
# Either it's the same, then it won't change, so it's not an update.
# Or it's different, then it will be added as a new text.
return None
elif chunker.data_type.value in [item.value for item in IndirectDataType]:
# These types have an indirect source reference
# As long as the reference is the same, they can be updated.
where = {"url": src}
if chunker.data_type == DataType.JSON and is_valid_json_string(src):
url = hashlib.sha256((src).encode("utf-8")).hexdigest()
where = {"url": url}
if self.config.id is not None:
where.update({"app_id": self.config.id})
existing_embeddings = self.db.get(
where=where,
limit=1,
)
if len(existing_embeddings.get("metadatas", [])) > 0:
return existing_embeddings["metadatas"][0]["doc_id"]
else:
return None
elif chunker.data_type.value in [item.value for item in SpecialDataType]:
# These types don't contain indirect references.
# Through custom logic, they can be attributed to a source and be updated.
if chunker.data_type == DataType.QNA_PAIR:
# QNA_PAIRs update the answer if the question already exists.
where = {"question": src[0]}
if self.config.id is not None:
where.update({"app_id": self.config.id})
existing_embeddings = self.db.get(
where=where,
limit=1,
)
if len(existing_embeddings.get("metadatas", [])) > 0:
return existing_embeddings["metadatas"][0]["doc_id"]
else:
return None
else:
raise NotImplementedError(
f"SpecialDataType {chunker.data_type} must have a custom logic to check for existing data"
)
else:
raise TypeError(
f"{chunker.data_type} is type {type(chunker.data_type)}. "
"When it should be DirectDataType, IndirectDataType or SpecialDataType."
)
def _load_and_embed(
self,
loader: BaseLoader,
chunker: BaseChunker,
src: Any,
metadata: Optional[dict[str, Any]] = None,
source_hash: Optional[str] = None,
add_config: Optional[AddConfig] = None,
dry_run=False,
**kwargs: Optional[dict[str, Any]],
):
"""
Loads the data from the given URL, chunks it, and adds it to database.
:param loader: The loader to use to load the data.
:type loader: BaseLoader
:param chunker: The chunker to use to chunk the data.
:type chunker: BaseChunker
:param src: The data to be handled by the loader. Can be a URL for
remote sources or local content for local loaders.
:type src: Any
:param metadata: Metadata associated with the data source.
:type metadata: dict[str, Any], optional
:param source_hash: Hexadecimal hash of the source.
:type source_hash: str, optional
:param add_config: The `AddConfig` instance to use as configuration options.
:type add_config: AddConfig, optional
:param dry_run: A dry run returns chunks and doesn't update DB.
:type dry_run: bool, defaults to False
:return: (list) documents (embedded text), (list) metadata, (list) ids, (int) number of chunks
"""
existing_doc_id = self._get_existing_doc_id(chunker=chunker, src=src)
app_id = self.config.id if self.config is not None else None
# Create chunks
embeddings_data = chunker.create_chunks(loader, src, app_id=app_id, config=add_config.chunker)
# spread chunking results
documents = embeddings_data["documents"]
metadatas = embeddings_data["metadatas"]
ids = embeddings_data["ids"]
new_doc_id = embeddings_data["doc_id"]
if existing_doc_id and existing_doc_id == new_doc_id:
logger.info("Doc content has not changed. Skipping creating chunks and embeddings")
return [], [], [], 0
# this means that doc content has changed.
if existing_doc_id and existing_doc_id != new_doc_id:
logger.info("Doc content has changed. Recomputing chunks and embeddings intelligently.")
self.db.delete({"doc_id": existing_doc_id})
# get existing ids, and discard doc if any common id exist.
where = {"url": src}
if chunker.data_type == DataType.JSON and is_valid_json_string(src):
url = hashlib.sha256((src).encode("utf-8")).hexdigest()
where = {"url": url}
# if data type is qna_pair, we check for question
if chunker.data_type == DataType.QNA_PAIR:
where = {"question": src[0]}
if self.config.id is not None:
where["app_id"] = self.config.id
db_result = self.db.get(ids=ids, where=where) # optional filter
existing_ids = set(db_result["ids"])
if len(existing_ids):
data_dict = {id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)}
data_dict = {id: value for id, value in data_dict.items() if id not in existing_ids}
if not data_dict:
src_copy = src
if len(src_copy) > 50:
src_copy = src[:50] + "..."
logger.info(f"All data from {src_copy} already exists in the database.")
# Make sure to return a matching return type
return [], [], [], 0
ids = list(data_dict.keys())
documents, metadatas = zip(*data_dict.values())
# Loop through all metadatas and add extras.
new_metadatas = []
for m in metadatas:
# Add app id in metadatas so that they can be queried on later
if self.config.id:
m["app_id"] = self.config.id
# Add hashed source
m["hash"] = source_hash
# Note: Metadata is the function argument
if metadata:
# Spread whatever is in metadata into the new object.
m.update(metadata)
new_metadatas.append(m)
metadatas = new_metadatas
if dry_run:
return list(documents), metadatas, ids, 0
# Count before, to calculate a delta in the end.
chunks_before_addition = self.db.count()
# Filter out empty documents and ensure they meet the API requirements
valid_documents = [doc for doc in documents if doc and isinstance(doc, str)]
documents = valid_documents
# Chunk documents into batches of 2048 and handle each batch
# helps with large loads of embeddings that hit OpenAI limits
document_batches = [documents[i : i + 2048] for i in range(0, len(documents), 2048)]
metadata_batches = [metadatas[i : i + 2048] for i in range(0, len(metadatas), 2048)]
id_batches = [ids[i : i + 2048] for i in range(0, len(ids), 2048)]
for batch_docs, batch_meta, batch_ids in zip(document_batches, metadata_batches, id_batches):
try:
# Add only valid batches
if batch_docs:
self.db.add(documents=batch_docs, metadatas=batch_meta, ids=batch_ids, **kwargs)
except Exception as e:
logger.info(f"Failed to add batch due to a bad request: {e}")
# Handle the error, e.g., by logging, retrying, or skipping
pass
count_new_chunks = self.db.count() - chunks_before_addition
logger.info(f"Successfully saved {str(src)[:100]} ({chunker.data_type}). New chunks count: {count_new_chunks}")
return list(documents), metadatas, ids, count_new_chunks
@staticmethod
def _format_result(results):
return [
(Document(page_content=result[0], metadata=result[1] or {}), result[2])
for result in zip(
results["documents"][0],
results["metadatas"][0],
results["distances"][0],
)
]
def _retrieve_from_database(
self,
input_query: str,
config: Optional[BaseLlmConfig] = None,
where=None,
citations: bool = False,
**kwargs: Optional[dict[str, Any]],
) -> Union[list[tuple[str, str, str]], list[str]]:
"""
Queries the vector database based on the given input query.
Gets relevant doc based on the query
:param input_query: The query to use.
:type input_query: str
:param config: The query configuration, defaults to None
:type config: Optional[BaseLlmConfig], optional
:param where: A dictionary of key-value pairs to filter the database results, defaults to None
:type where: _type_, optional
:param citations: A boolean to indicate if db should fetch citation source
:type citations: bool
:return: List of contents of the document that matched your query
:rtype: list[str]
"""
query_config = config or self.llm.config
if where is not None:
where = where
else:
where = {}
if query_config is not None and query_config.where is not None:
where = query_config.where
if self.config.id is not None:
where.update({"app_id": self.config.id})
contexts = self.db.query(
input_query=input_query,
n_results=query_config.number_documents,
where=where,
citations=citations,
**kwargs,
)
return contexts
def query(
self,
input_query: str,
config: BaseLlmConfig = None,
dry_run=False,
where: Optional[dict] = None,
citations: bool = False,
**kwargs: dict[str, Any],
) -> Union[tuple[str, list[tuple[str, dict]]], str, dict[str, Any]]:
"""
Queries the vector database based on the given input query.
Gets relevant doc based on the query and then passes it to an
LLM as context to get the answer.
:param input_query: The query to use.
:type input_query: str
:param config: The `BaseLlmConfig` instance to use as configuration options. This is used for one method call.
To persistently use a config, declare it during app init., defaults to None
:type config: BaseLlmConfig, optional
:param dry_run: A dry run does everything except send the resulting prompt to
the LLM. The purpose is to test the prompt, not the response., defaults to False
:type dry_run: bool, optional
:param where: A dictionary of key-value pairs to filter the database results., defaults to None
:type where: dict[str, str], optional
:param citations: A boolean to indicate if db should fetch citation source
:type citations: bool
:param kwargs: To read more params for the query function. Ex. we use citations boolean
param to return context along with the answer
:type kwargs: dict[str, Any]
:return: The answer to the query, with citations if the citation flag is True
or the dry run result
:rtype: str, if citations is False and token_usage is False, otherwise if citations is true then
tuple[str, list[tuple[str,str,str]]] and if token_usage is true then
tuple[str, list[tuple[str,str,str]], dict[str, Any]]
"""
contexts = self._retrieve_from_database(
input_query=input_query, config=config, where=where, citations=citations, **kwargs
)
if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))
else:
contexts_data_for_llm_query = contexts
if self.cache_config is not None:
logger.info("Cache enabled. Checking cache...")
answer = adapt(
llm_handler=self.llm.query,
cache_data_convert=gptcache_data_convert,
update_cache_callback=gptcache_update_cache_callback,
session=get_gptcache_session(session_id=self.config.id),
input_query=input_query,
contexts=contexts_data_for_llm_query,
config=config,
dry_run=dry_run,
)
else:
if self.llm.config.token_usage:
answer, token_info = self.llm.query(
input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
)
else:
answer = self.llm.query(
input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
)
# Send anonymous telemetry
self.telemetry.capture(event_name="query", properties=self._telemetry_props)
if citations:
if self.llm.config.token_usage:
return {"answer": answer, "contexts": contexts, "usage": token_info}
return answer, contexts
if self.llm.config.token_usage:
return {"answer": answer, "usage": token_info}
logger.warning(
"Starting from v0.1.125 the return type of query method will be changed to tuple containing `answer`."
)
return answer
def chat(
self,
input_query: str,
config: Optional[BaseLlmConfig] = None,
dry_run=False,
session_id: str = "default",
where: Optional[dict[str, str]] = None,
citations: bool = False,
**kwargs: dict[str, Any],
) -> Union[tuple[str, list[tuple[str, dict]]], str, dict[str, Any]]:
"""
Queries the vector database on the given input query.
Gets relevant doc based on the query and then passes it to an
LLM as context to get the answer.
Maintains the whole conversation in memory.
:param input_query: The query to use.
:type input_query: str
:param config: The `BaseLlmConfig` instance to use as configuration options. This is used for one method call.
To persistently use a config, declare it during app init., defaults to None
:type config: BaseLlmConfig, optional
:param dry_run: A dry run does everything except send the resulting prompt to
the LLM. The purpose is to test the prompt, not the response., defaults to False
:type dry_run: bool, optional
:param session_id: The session id to use for chat history, defaults to 'default'.
:type session_id: str, optional
:param where: A dictionary of key-value pairs to filter the database results., defaults to None
:type where: dict[str, str], optional
:param citations: A boolean to indicate if db should fetch citation source
:type citations: bool
:param kwargs: To read more params for the query function. Ex. we use citations boolean
param to return context along with the answer
:type kwargs: dict[str, Any]
:return: The answer to the query, with citations if the citation flag is True
or the dry run result
:rtype: str, if citations is False and token_usage is False, otherwise if citations is true then
tuple[str, list[tuple[str,str,str]]] and if token_usage is true then
tuple[str, list[tuple[str,str,str]], dict[str, Any]]
"""
contexts = self._retrieve_from_database(
input_query=input_query, config=config, where=where, citations=citations, **kwargs
)
if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))
else:
contexts_data_for_llm_query = contexts
memories = None
if self.mem0_client:
memories = self.mem0_client.search(
query=input_query, agent_id=self.config.id, session_id=session_id, limit=self.memory_config.top_k
)
# Update the history beforehand so that we can handle multiple chat sessions in the same python session
self.llm.update_history(app_id=self.config.id, session_id=session_id)
if self.cache_config is not None:
logger.debug("Cache enabled. Checking cache...")
cache_id = f"{session_id}--{self.config.id}"
answer = adapt(
llm_handler=self.llm.chat,
cache_data_convert=gptcache_data_convert,
update_cache_callback=gptcache_update_cache_callback,
session=get_gptcache_session(session_id=cache_id),
input_query=input_query,
contexts=contexts_data_for_llm_query,
config=config,
dry_run=dry_run,
)
else:
logger.debug("Cache disabled. Running chat without cache.")
if self.llm.config.token_usage:
answer, token_info = self.llm.query(
input_query=input_query,
contexts=contexts_data_for_llm_query,
config=config,
dry_run=dry_run,
memories=memories,
)
else:
answer = self.llm.query(
input_query=input_query,
contexts=contexts_data_for_llm_query,
config=config,
dry_run=dry_run,
memories=memories,
)
# Add to Mem0 memory if enabled
# TODO: Might need to prepend with some text like:
# "Remember user preferences from following user query: {input_query}"
if self.mem0_client:
self.mem0_client.add(data=input_query, agent_id=self.config.id, session_id=session_id)
# add conversation in memory
self.llm.add_history(self.config.id, input_query, answer, session_id=session_id)
# Send anonymous telemetry
self.telemetry.capture(event_name="chat", properties=self._telemetry_props)
if citations:
if self.llm.config.token_usage:
return {"answer": answer, "contexts": contexts, "usage": token_info}
return answer, contexts
if self.llm.config.token_usage:
return {"answer": answer, "usage": token_info}
logger.warning(
"Starting from v0.1.125 the return type of query method will be changed to tuple containing `answer`."
)
return answer
def search(self, query, num_documents=3, where=None, raw_filter=None, namespace=None):
"""
Search for similar documents related to the query in the vector database.
Args:
query (str): The query to use.
num_documents (int, optional): Number of similar documents to fetch. Defaults to 3.
where (dict[str, any], optional): Filter criteria for the search.
raw_filter (dict[str, any], optional): Advanced raw filter criteria for the search.
namespace (str, optional): The namespace to search in. Defaults to None.
Raises:
ValueError: If both `raw_filter` and `where` are used simultaneously.
Returns:
list[dict]: A list of dictionaries, each containing the 'context' and 'metadata' of a document.
"""
# Send anonymous telemetry
self.telemetry.capture(event_name="search", properties=self._telemetry_props)
if raw_filter and where:
raise ValueError("You can't use both `raw_filter` and `where` together.")
filter_type = "raw_filter" if raw_filter else "where"
filter_criteria = raw_filter if raw_filter else where
params = {
"input_query": query,
"n_results": num_documents,
"citations": True,
"app_id": self.config.id,
"namespace": namespace,
filter_type: filter_criteria,
}
return [{"context": c[0], "metadata": c[1]} for c in self.db.query(**params)]
def set_collection_name(self, name: str):
"""
Set the name of the collection. A collection is an isolated space for vectors.
Using `app.db.set_collection_name` method is preferred to this.
:param name: Name of the collection.
:type name: str
"""
self.db.set_collection_name(name)
# Create the collection if it does not exist
self.db._get_or_create_collection(name)
# TODO: Check whether it is necessary to assign to the `self.collection` attribute,
# since the main purpose is the creation.
def reset(self):
"""
Resets the database. Deletes all embeddings irreversibly.
`App` does not have to be reinitialized after using this method.
"""
try:
self.db_session.query(DataSource).filter_by(app_id=self.config.id).delete()
self.db_session.query(ChatHistory).filter_by(app_id=self.config.id).delete()
self.db_session.commit()
except Exception as e:
logger.error(f"Error deleting data sources: {e}")
self.db_session.rollback()
return None
self.db.reset()
self.delete_all_chat_history(app_id=self.config.id)
# Send anonymous telemetry
self.telemetry.capture(event_name="reset", properties=self._telemetry_props)
def get_history(
self,
num_rounds: int = 10,
display_format: bool = True,
session_id: Optional[str] = "default",
fetch_all: bool = False,
):
history = self.llm.memory.get(
app_id=self.config.id,
session_id=session_id,
num_rounds=num_rounds,
display_format=display_format,
fetch_all=fetch_all,
)
return history
def delete_session_chat_history(self, session_id: str = "default"):
self.llm.memory.delete(app_id=self.config.id, session_id=session_id)
self.llm.update_history(app_id=self.config.id)
def delete_all_chat_history(self, app_id: str):
self.llm.memory.delete(app_id=app_id)
self.llm.update_history(app_id=app_id)
def delete(self, source_id: str):
"""
Deletes the data from the database.
:param source_id: The hash of the source, as returned by the `add` method.
:type source_id: str
"""
try:
self.db_session.query(DataSource).filter_by(hash=source_id, app_id=self.config.id).delete()
self.db_session.commit()
except Exception as e:
logger.error(f"Error deleting data sources: {e}")
self.db_session.rollback()
return None
self.db.delete(where={"hash": source_id})
logger.info(f"Successfully deleted {source_id}")
# Send anonymous telemetry
if self.config.collect_metrics:
self.telemetry.capture(event_name="delete", properties=self._telemetry_props)
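
An end-to-end sketch of the public surface implemented above, as exposed through `App`. The URL is a placeholder and the availability of each method follows from the code in this file.

from embedchain import App

app = App()

# add() returns the md5 hash of the source, which can later be passed to delete().
source_hash = app.add("https://www.forbes.com/profile/elon-musk")  # placeholder URL

# One-shot retrieval-augmented answer, optionally with citations.
answer, sources = app.query("What is Elon Musk's net worth?", citations=True)

# chat() additionally keeps per-session history.
reply = app.chat("And what companies does he run?", session_id="demo")

# Raw similarity search over the vector store.
hits = app.search("Tesla", num_documents=2)

# Remove the source again, or wipe everything.
app.delete(source_hash)
app.reset()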

View File

@@ -0,0 +1,22 @@
from typing import Optional
from langchain_community.embeddings import AzureOpenAIEmbeddings
from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.models import VectorDimensions
class AzureOpenAIEmbedder(BaseEmbedder):
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
super().__init__(config=config)
if self.config.model is None:
self.config.model = "text-embedding-ada-002"
embeddings = AzureOpenAIEmbeddings(deployment=self.config.deployment_name)
embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
self.set_embedding_fn(embedding_fn=embedding_fn)
vector_dimension = self.config.vector_dimension or VectorDimensions.OPENAI.value
self.set_vector_dimension(vector_dimension=vector_dimension)

View File

@@ -0,0 +1,90 @@
from collections.abc import Callable
from typing import Any, Optional
from embedchain.config.embedder.base import BaseEmbedderConfig
try:
from chromadb.api.types import Embeddable, EmbeddingFunction, Embeddings
except RuntimeError:
from embedchain.utils.misc import use_pysqlite3
use_pysqlite3()
from chromadb.api.types import Embeddable, EmbeddingFunction, Embeddings
class EmbeddingFunc(EmbeddingFunction):
def __init__(self, embedding_fn: Callable[[list[str]], list[str]]):
self.embedding_fn = embedding_fn
def __call__(self, input: Embeddable) -> Embeddings:
return self.embedding_fn(input)
class BaseEmbedder:
"""
Class that manages everything regarding embeddings. Including embedding function, loaders and chunkers.
Embedding functions and vector dimensions are set based on the child class you choose.
To manually overwrite you can use this classes `set_...` methods.
"""
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
"""
Initialize the embedder class.
:param config: embedder configuration option class, defaults to None
:type config: Optional[BaseEmbedderConfig], optional
"""
if config is None:
self.config = BaseEmbedderConfig()
else:
self.config = config
self.vector_dimension: int
def set_embedding_fn(self, embedding_fn: Callable[[list[str]], list[str]]):
"""
Set or overwrite the embedding function to be used by the database to store and retrieve documents.
:param embedding_fn: Function to be used to generate embeddings.
:type embedding_fn: Callable[[list[str]], list[str]]
:raises ValueError: Embedding function is not callable.
"""
if not callable(embedding_fn):
raise ValueError("Embedding function is not a function")
self.embedding_fn = embedding_fn
def set_vector_dimension(self, vector_dimension: int):
"""
Set or overwrite the vector dimension size
:param vector_dimension: vector dimension size
:type vector_dimension: int
"""
if not isinstance(vector_dimension, int):
raise TypeError("vector dimension must be int")
self.vector_dimension = vector_dimension
@staticmethod
def _langchain_default_concept(embeddings: Any):
"""
Langchains default function layout for embeddings.
:param embeddings: Langchain embeddings
:type embeddings: Any
:return: embedding function
:rtype: Callable
"""
return EmbeddingFunc(embeddings.embed_documents)
def to_embeddings(self, data: str, **_):
"""
Convert data to embeddings
:param data: data to convert to embeddings
:type data: str
:return: embeddings
:rtype: list[float]
"""
embeddings = self.embedding_fn([data])
return embeddings[0]
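
A minimal sketch of wiring a custom embedding function into the base class above; the fixed-size random vectors are purely illustrative stand-ins for a real embedding model.

import random

from embedchain.embedder.base import BaseEmbedder


def fake_embed(texts: list[str]) -> list[list[float]]:
    # Illustrative only: a real implementation would call an embedding model here.
    return [[random.random() for _ in range(16)] for _ in texts]


embedder = BaseEmbedder()
embedder.set_embedding_fn(embedding_fn=fake_embed)
embedder.set_vector_dimension(vector_dimension=16)

print(embedder.to_embeddings("hello world"))  # a single 16-dimensional vector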

View File

@@ -0,0 +1,52 @@
import os
from typing import Optional, Union
from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
from chromadb import EmbeddingFunction, Embeddings
class ClarifaiEmbeddingFunction(EmbeddingFunction):
def __init__(self, config: BaseEmbedderConfig) -> None:
super().__init__()
try:
from clarifai.client.model import Model
from clarifai.client.input import Inputs
except ModuleNotFoundError:
raise ModuleNotFoundError(
"The required dependencies for ClarifaiEmbeddingFunction are not installed."
'Please install with `pip install --upgrade "embedchain[clarifai]"`'
) from None
self.config = config
self.api_key = config.api_key or os.getenv("CLARIFAI_PAT")
self.model = config.model
self.model_obj = Model(url=self.model, pat=self.api_key)
self.input_obj = Inputs(pat=self.api_key)
def __call__(self, input: Union[str, list[str]]) -> Embeddings:
if isinstance(input, str):
input = [input]
batch_size = 32
embeddings = []
try:
for i in range(0, len(input), batch_size):
batch = input[i : i + batch_size]
input_batch = [
self.input_obj.get_text_input(input_id=str(id), raw_text=inp) for id, inp in enumerate(batch)
]
response = self.model_obj.predict(input_batch)
embeddings.extend([list(output.data.embeddings[0].vector) for output in response.outputs])
except Exception as e:
print(f"Predict failed, exception: {e}")
return embeddings
class ClarifaiEmbedder(BaseEmbedder):
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
super().__init__(config)
embedding_func = ClarifaiEmbeddingFunction(config=self.config)
self.set_embedding_fn(embedding_fn=embedding_func)

View File

@@ -0,0 +1,19 @@
from typing import Optional
from langchain_cohere.embeddings import CohereEmbeddings
from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.models import VectorDimensions
class CohereEmbedder(BaseEmbedder):
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
super().__init__(config=config)
embeddings = CohereEmbeddings(model=self.config.model)
embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
self.set_embedding_fn(embedding_fn=embedding_fn)
vector_dimension = self.config.vector_dimension or VectorDimensions.COHERE.value
self.set_vector_dimension(vector_dimension=vector_dimension)

View File

@@ -0,0 +1,38 @@
from typing import Optional, Union
import google.generativeai as genai
from chromadb import EmbeddingFunction, Embeddings
from embedchain.config.embedder.google import GoogleAIEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.models import VectorDimensions
class GoogleAIEmbeddingFunction(EmbeddingFunction):
def __init__(self, config: Optional[GoogleAIEmbedderConfig] = None) -> None:
super().__init__()
self.config = config or GoogleAIEmbedderConfig()
def __call__(self, input: Union[list[str], str]) -> Embeddings:
model = self.config.model
title = self.config.title
task_type = self.config.task_type
if isinstance(input, str):
input_ = [input]
else:
input_ = input
data = genai.embed_content(model=model, content=input_, task_type=task_type, title=title)
embeddings = data["embedding"]
if isinstance(input_, str):
embeddings = [embeddings]
return embeddings
class GoogleAIEmbedder(BaseEmbedder):
def __init__(self, config: Optional[GoogleAIEmbedderConfig] = None):
super().__init__(config)
embedding_fn = GoogleAIEmbeddingFunction(config=config)
self.set_embedding_fn(embedding_fn=embedding_fn)
vector_dimension = self.config.vector_dimension or VectorDimensions.GOOGLE_AI.value
self.set_vector_dimension(vector_dimension=vector_dimension)
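
The embedder above calls `genai.embed_content` but never configures credentials itself; the `google.generativeai` SDK expects `genai.configure` (or a `GOOGLE_API_KEY` environment variable) to be set up beforehand. A brief sketch with placeholder values; the module path and model name are assumptions.

import google.generativeai as genai

from embedchain.config.embedder.google import GoogleAIEmbedderConfig
from embedchain.embedder.google import GoogleAIEmbedder  # assumed module path

genai.configure(api_key="your-google-api-key")  # placeholder key
embedder = GoogleAIEmbedder(config=GoogleAIEmbedderConfig(model="models/embedding-001"))  # model name is an assumption
print(len(embedder.to_embeddings("hello world")))  # length of the returned embedding vector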

View File

@@ -0,0 +1,20 @@
from typing import Optional
from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.models import VectorDimensions
class GPT4AllEmbedder(BaseEmbedder):
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
super().__init__(config=config)
from langchain_community.embeddings import GPT4AllEmbeddings as LangchainGPT4AllEmbeddings
model_name = self.config.model or "all-MiniLM-L6-v2-f16.gguf"
embeddings = LangchainGPT4AllEmbeddings(model_name=model_name)
embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
self.set_embedding_fn(embedding_fn=embedding_fn)
vector_dimension = self.config.vector_dimension or VectorDimensions.GPT4ALL.value
self.set_vector_dimension(vector_dimension=vector_dimension)

View File

@@ -0,0 +1,40 @@
import os
from typing import Optional
from langchain_community.embeddings import HuggingFaceEmbeddings
try:
from langchain_huggingface import HuggingFaceEndpointEmbeddings
except ModuleNotFoundError:
raise ModuleNotFoundError(
"The required dependencies for HuggingFaceHub are not installed."
"Please install with `pip install langchain_huggingface`"
) from None
from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.models import VectorDimensions
class HuggingFaceEmbedder(BaseEmbedder):
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
super().__init__(config=config)
if self.config.endpoint:
if not self.config.api_key and "HUGGINGFACE_ACCESS_TOKEN" not in os.environ:
raise ValueError(
"Please set the HUGGINGFACE_ACCESS_TOKEN environment variable or pass API Key in the config."
)
embeddings = HuggingFaceEndpointEmbeddings(
model=self.config.endpoint,
huggingfacehub_api_token=self.config.api_key or os.getenv("HUGGINGFACE_ACCESS_TOKEN"),
)
else:
embeddings = HuggingFaceEmbeddings(model_name=self.config.model, model_kwargs=self.config.model_kwargs)
embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
self.set_embedding_fn(embedding_fn=embedding_fn)
vector_dimension = self.config.vector_dimension or VectorDimensions.HUGGING_FACE.value
self.set_vector_dimension(vector_dimension=vector_dimension)

View File

@@ -0,0 +1,46 @@
import os
from typing import Optional, Union
from chromadb import EmbeddingFunction, Embeddings
from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.models import VectorDimensions
class MistralAIEmbeddingFunction(EmbeddingFunction):
def __init__(self, config: BaseEmbedderConfig) -> None:
super().__init__()
try:
from langchain_mistralai import MistralAIEmbeddings
except ModuleNotFoundError:
raise ModuleNotFoundError(
"The required dependencies for MistralAI are not installed."
'Please install with `pip install --upgrade "embedchain[mistralai]"`'
) from None
self.config = config
api_key = self.config.api_key or os.getenv("MISTRAL_API_KEY")
self.client = MistralAIEmbeddings(mistral_api_key=api_key)
self.client.model = self.config.model
def __call__(self, input: Union[list[str], str]) -> Embeddings:
if isinstance(input, str):
input_ = [input]
else:
input_ = input
response = self.client.embed_documents(input_)
return response
class MistralAIEmbedder(BaseEmbedder):
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
super().__init__(config)
if self.config.model is None:
self.config.model = "mistral-embed"
embedding_fn = MistralAIEmbeddingFunction(config=self.config)
self.set_embedding_fn(embedding_fn=embedding_fn)
vector_dimension = self.config.vector_dimension or VectorDimensions.MISTRAL_AI.value
self.set_vector_dimension(vector_dimension=vector_dimension)

Some files were not shown because too many files have changed in this diff.