Rename embedchain to mem0 and open sourcing code for long term memory (#1474)
Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
10
embedchain/embedchain/__init__.py
Normal file
@@ -0,0 +1,10 @@
import importlib.metadata

__version__ = importlib.metadata.version(__package__ or __name__)

from embedchain.app import App  # noqa: F401
from embedchain.client import Client  # noqa: F401
from embedchain.pipeline import Pipeline  # noqa: F401

# Set up the user directory if it doesn't exist already
Client.setup()
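The package root re-exports the public entry points, so the smallest possible app needs nothing but `App`. A minimal usage sketch (not part of the diff; assumes `OPENAI_API_KEY` is set, since the defaults use the OpenAI embedder/LLM plus ChromaDB, and `add`/`query` come from the `EmbedChain` base class):

    from embedchain import App

    app = App()  # default AppConfig, ChromaDB, OpenAI embedder and LLM
    app.add("https://www.example.com/article")  # hypothetical source URL
    print(app.query("What is the article about?"))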
116
embedchain/embedchain/alembic.ini
Normal file
@@ -0,0 +1,116 @@
# A generic, single database configuration.

[alembic]
# path to migration scripts
script_location = embedchain:migrations

# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s

# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .

# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires python>=3.9 or the backports.zoneinfo library.
# Any required deps can be installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =

# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40

# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false

# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false

# version location specification; This defaults
# to alembic/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions

# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os  # Use os.pathsep. Default configuration used for new projects.

# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false

# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8

sqlalchemy.url = driver://user:pass@localhost/dbname


[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples

# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME

# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = --fix REVISION_SCRIPT_FILENAME

# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic

[handlers]
keys = console

[formatters]
keys = generic

[logger_root]
level = WARN
handlers = console
qualname =

[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = WARN
handlers =
qualname = alembic

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
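This is the stock Alembic template pointed at `embedchain:migrations`; `sqlalchemy.url` is left as a placeholder because the app wires the real URI at runtime. A sketch of that wiring, mirroring what `App.__init__` does in `app.py` below (the SQLite URI is a hypothetical example, and whether `init_db` applies these migrations is an assumption, since `embedchain.core.db.database` is outside this diff):

    import os
    from embedchain.core.db.database import init_db, setup_engine

    os.environ.setdefault("EMBEDCHAIN_DB_URI", "sqlite:///embedchain.db")  # hypothetical local URI
    setup_engine(database_uri=os.environ["EMBEDCHAIN_DB_URI"])
    init_db()  # presumably brings the metadata schema up to date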
521
embedchain/embedchain/app.py
Normal file
@@ -0,0 +1,521 @@
import ast
import concurrent.futures
import json
import logging
import os
from typing import Any, Optional, Union

import requests
import yaml
from tqdm import tqdm

from mem0 import Mem0
from embedchain.cache import (
    Config,
    ExactMatchEvaluation,
    SearchDistanceEvaluation,
    cache,
    gptcache_data_manager,
    gptcache_pre_function,
)
from embedchain.client import Client
from embedchain.config import AppConfig, CacheConfig, ChunkerConfig, Mem0Config
from embedchain.core.db.database import get_session, init_db, setup_engine
from embedchain.core.db.models import DataSource
from embedchain.embedchain import EmbedChain
from embedchain.embedder.base import BaseEmbedder
from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.evaluation.base import BaseMetric
from embedchain.evaluation.metrics import AnswerRelevance, ContextRelevance, Groundedness
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
from embedchain.llm.openai import OpenAILlm
from embedchain.telemetry.posthog import AnonymousTelemetry
from embedchain.utils.evaluation import EvalData, EvalMetric
from embedchain.utils.misc import validate_config
from embedchain.vectordb.base import BaseVectorDB
from embedchain.vectordb.chroma import ChromaDB

logger = logging.getLogger(__name__)


@register_deserializable
class App(EmbedChain):
    """
    EmbedChain App lets you create a LLM powered app for your unstructured
    data by defining your chosen data source, embedding model,
    and vector database.
    """

    def __init__(
        self,
        id: str = None,
        name: str = None,
        config: AppConfig = None,
        db: BaseVectorDB = None,
        embedding_model: BaseEmbedder = None,
        llm: BaseLlm = None,
        config_data: dict = None,
        auto_deploy: bool = False,
        chunker: ChunkerConfig = None,
        cache_config: CacheConfig = None,
        memory_config: Mem0Config = None,
        log_level: int = logging.WARN,
    ):
        """
        Initialize a new `App` instance.

        :param config: Configuration for the pipeline, defaults to None
        :type config: AppConfig, optional
        :param db: The database to use for storing and retrieving embeddings, defaults to None
        :type db: BaseVectorDB, optional
        :param embedding_model: The embedding model used to calculate embeddings, defaults to None
        :type embedding_model: BaseEmbedder, optional
        :param llm: The LLM model used to generate answers, defaults to None
        :type llm: BaseLlm, optional
        :param config_data: Config dictionary, defaults to None
        :type config_data: dict, optional
        :param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
        :type auto_deploy: bool, optional
        :raises Exception: If an error occurs while creating the pipeline
        """
        if id and config_data:
            raise Exception("Cannot provide both id and config. Please provide only one of them.")

        if id and name:
            raise Exception("Cannot provide both id and name. Please provide only one of them.")

        if name and config:
            raise Exception("Cannot provide both name and config. Please provide only one of them.")

        # Initialize the metadata db for the app
        setup_engine(database_uri=os.environ.get("EMBEDCHAIN_DB_URI"))
        init_db()

        self.auto_deploy = auto_deploy
        # Store the dict config as an attribute to be able to send it
        self.config_data = config_data if (config_data and validate_config(config_data)) else None
        self.client = None
        # pipeline_id from the backend
        self.id = None
        self.chunker = ChunkerConfig(**chunker) if chunker else None
        self.cache_config = cache_config
        self.memory_config = memory_config

        self.config = config or AppConfig()
        self.name = self.config.name
        self.config.id = self.local_id = "default-app-id" if self.config.id is None else self.config.id

        if id is not None:
            # Init client first since user is trying to fetch the pipeline
            # details from the platform
            self._init_client()
            pipeline_details = self._get_pipeline(id)
            self.config.id = self.local_id = pipeline_details["metadata"]["local_id"]
            self.id = id

        if name is not None:
            self.name = name

        self.embedding_model = embedding_model or OpenAIEmbedder()
        self.db = db or ChromaDB()
        self.llm = llm or OpenAILlm()
        self._init_db()

        # Session for the metadata db
        self.db_session = get_session()

        # If cache_config is provided, initializing the cache ...
        if self.cache_config is not None:
            self._init_cache()

        # If memory_config is provided, initializing the memory ...
        self.mem0_client = None
        if self.memory_config is not None:
            self.mem0_client = Mem0(api_key=self.memory_config.api_key)

        # Send anonymous telemetry
        self._telemetry_props = {"class": self.__class__.__name__}
        self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
        self.telemetry.capture(event_name="init", properties=self._telemetry_props)

        self.user_asks = []
        if self.auto_deploy:
            self.deploy()

    def _init_db(self):
        """
        Initialize the database.
        """
        self.db._set_embedder(self.embedding_model)
        self.db._initialize()
        self.db.set_collection_name(self.db.config.collection_name)

    def _init_cache(self):
        if self.cache_config.similarity_eval_config.strategy == "exact":
            similarity_eval_func = ExactMatchEvaluation()
        else:
            similarity_eval_func = SearchDistanceEvaluation(
                max_distance=self.cache_config.similarity_eval_config.max_distance,
                positive=self.cache_config.similarity_eval_config.positive,
            )

        cache.init(
            pre_embedding_func=gptcache_pre_function,
            embedding_func=self.embedding_model.to_embeddings,
            data_manager=gptcache_data_manager(vector_dimension=self.embedding_model.vector_dimension),
            similarity_evaluation=similarity_eval_func,
            config=Config(**self.cache_config.init_config.as_dict()),
        )

    def _init_client(self):
        """
        Initialize the client.
        """
        config = Client.load_config()
        if config.get("api_key"):
            self.client = Client()
        else:
            api_key = input(
                "🔑 Enter your Embedchain API key. You can find the API key at https://app.embedchain.ai/settings/keys/ \n"  # noqa: E501
            )
            self.client = Client(api_key=api_key)

    def _get_pipeline(self, id):
        """
        Get existing pipeline
        """
        print("🛠️ Fetching pipeline details from the platform...")
        url = f"{self.client.host}/api/v1/pipelines/{id}/cli/"
        r = requests.get(
            url,
            headers={"Authorization": f"Token {self.client.api_key}"},
        )
        if r.status_code == 404:
            raise Exception(f"❌ Pipeline with id {id} not found!")

        print(
            f"🎉 Pipeline loaded successfully! Pipeline url: https://app.embedchain.ai/pipelines/{r.json()['id']}\n"  # noqa: E501
        )
        return r.json()

    def _create_pipeline(self):
        """
        Create a pipeline on the platform.
        """
        print("🛠️ Creating pipeline on the platform...")
        # self.config_data is a dict. Pass it inside the key 'yaml_config' to the backend
        payload = {
            "yaml_config": json.dumps(self.config_data),
            "name": self.name,
            "local_id": self.local_id,
        }
        url = f"{self.client.host}/api/v1/pipelines/cli/create/"
        r = requests.post(
            url,
            json=payload,
            headers={"Authorization": f"Token {self.client.api_key}"},
        )
        if r.status_code not in [200, 201]:
            raise Exception(f"❌ Error occurred while creating pipeline. API response: {r.text}")

        if r.status_code == 200:
            print(
                f"🎉🎉🎉 Existing pipeline found! View your pipeline: https://app.embedchain.ai/pipelines/{r.json()['id']}\n"  # noqa: E501
            )
        elif r.status_code == 201:
            print(
                f"🎉🎉🎉 Pipeline created successfully! View your pipeline: https://app.embedchain.ai/pipelines/{r.json()['id']}\n"  # noqa: E501
            )
        return r.json()

    def _get_presigned_url(self, data_type, data_value):
        payload = {"data_type": data_type, "data_value": data_value}
        r = requests.post(
            f"{self.client.host}/api/v1/pipelines/{self.id}/cli/presigned_url/",
            json=payload,
            headers={"Authorization": f"Token {self.client.api_key}"},
        )
        r.raise_for_status()
        return r.json()

    def _upload_file_to_presigned_url(self, presigned_url, file_path):
        try:
            with open(file_path, "rb") as file:
                response = requests.put(presigned_url, data=file)
                response.raise_for_status()
                return response.status_code == 200
        except Exception as e:
            logger.exception(f"Error occurred during file upload: {str(e)}")
            print("❌ Error occurred during file upload!")
            return False

    def _upload_data_to_pipeline(self, data_type, data_value, metadata=None):
        payload = {
            "data_type": data_type,
            "data_value": data_value,
            "metadata": metadata,
        }
        try:
            self._send_api_request(f"/api/v1/pipelines/{self.id}/cli/add/", payload)
            # print the local file path if user tries to upload a local file
            printed_value = metadata.get("file_path") if metadata.get("file_path") else data_value
            print(f"✅ Data of type: {data_type}, value: {printed_value} added successfully.")
        except Exception as e:
            print(f"❌ Error occurred during data upload for type {data_type}! Error: {str(e)}")

    def _send_api_request(self, endpoint, payload):
        url = f"{self.client.host}{endpoint}"
        headers = {"Authorization": f"Token {self.client.api_key}"}
        response = requests.post(url, json=payload, headers=headers)
        response.raise_for_status()
        return response

    def _process_and_upload_data(self, data_hash, data_type, data_value):
        if os.path.isabs(data_value):
            presigned_url_data = self._get_presigned_url(data_type, data_value)
            presigned_url = presigned_url_data["presigned_url"]
            s3_key = presigned_url_data["s3_key"]
            if self._upload_file_to_presigned_url(presigned_url, file_path=data_value):
                metadata = {"file_path": data_value, "s3_key": s3_key}
                data_value = presigned_url
            else:
                logger.error(f"File upload failed for hash: {data_hash}")
                return False
        else:
            if data_type == "qna_pair":
                data_value = list(ast.literal_eval(data_value))
            metadata = {}

        try:
            self._upload_data_to_pipeline(data_type, data_value, metadata)
            self._mark_data_as_uploaded(data_hash)
            return True
        except Exception:
            print(f"❌ Error occurred during data upload for hash {data_hash}!")
            return False

    def _mark_data_as_uploaded(self, data_hash):
        self.db_session.query(DataSource).filter_by(hash=data_hash, app_id=self.local_id).update({"is_uploaded": 1})

    def get_data_sources(self):
        data_sources = self.db_session.query(DataSource).filter_by(app_id=self.local_id).all()
        results = []
        for row in data_sources:
            results.append({"data_type": row.type, "data_value": row.value, "metadata": row.meta_data})
        return results

    def deploy(self):
        if self.client is None:
            self._init_client()

        pipeline_data = self._create_pipeline()
        self.id = pipeline_data["id"]

        results = self.db_session.query(DataSource).filter_by(app_id=self.local_id, is_uploaded=0).all()
        if len(results) > 0:
            print("🛠️ Adding data to your pipeline...")
        for result in results:
            data_hash, data_type, data_value = result.hash, result.data_type, result.data_value
            self._process_and_upload_data(data_hash, data_type, data_value)

        # Send anonymous telemetry
        self.telemetry.capture(event_name="deploy", properties=self._telemetry_props)

    @classmethod
    def from_config(
        cls,
        config_path: Optional[str] = None,
        config: Optional[dict[str, Any]] = None,
        auto_deploy: bool = False,
        yaml_path: Optional[str] = None,
    ):
        """
        Instantiate an App object from a configuration.

        :param config_path: Path to the YAML or JSON configuration file.
        :type config_path: Optional[str]
        :param config: A dictionary containing the configuration.
        :type config: Optional[dict[str, Any]]
        :param auto_deploy: Whether to deploy the app automatically, defaults to False
        :type auto_deploy: bool, optional
        :param yaml_path: (Deprecated) Path to the YAML configuration file. Use config_path instead.
        :type yaml_path: Optional[str]
        :return: An instance of the App class.
        :rtype: App
        """
        # Backward compatibility for yaml_path
        if yaml_path and not config_path:
            config_path = yaml_path

        if config_path and config:
            raise ValueError("Please provide only one of config_path or config.")

        config_data = None

        if config_path:
            file_extension = os.path.splitext(config_path)[1]
            with open(config_path, "r", encoding="UTF-8") as file:
                if file_extension in [".yaml", ".yml"]:
                    config_data = yaml.safe_load(file)
                elif file_extension == ".json":
                    config_data = json.load(file)
                else:
                    raise ValueError("config_path must be a path to a YAML or JSON file.")
        elif config and isinstance(config, dict):
            config_data = config
        else:
            logger.error(
                "Please provide either a config file path (YAML or JSON) or a config dictionary. Falling back to defaults because no config is provided.",  # noqa: E501
            )
            config_data = {}

        # Validate the config
        validate_config(config_data)

        app_config_data = config_data.get("app", {}).get("config", {})
        vector_db_config_data = config_data.get("vectordb", {})
        embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
        memory_config_data = config_data.get("memory", {})
        llm_config_data = config_data.get("llm", {})
        chunker_config_data = config_data.get("chunker", {})
        cache_config_data = config_data.get("cache", None)

        app_config = AppConfig(**app_config_data)
        memory_config = Mem0Config(**memory_config_data) if memory_config_data else None

        vector_db_provider = vector_db_config_data.get("provider", "chroma")
        vector_db = VectorDBFactory.create(vector_db_provider, vector_db_config_data.get("config", {}))

        if llm_config_data:
            # Initialize the metadata db for the app here since llmfactory needs it for initialization of
            # the llm memory
            setup_engine(database_uri=os.environ.get("EMBEDCHAIN_DB_URI"))
            init_db()
            llm_provider = llm_config_data.get("provider", "openai")
            llm = LlmFactory.create(llm_provider, llm_config_data.get("config", {}))
        else:
            llm = None

        embedding_model_provider = embedding_model_config_data.get("provider", "openai")
        embedding_model = EmbedderFactory.create(
            embedding_model_provider, embedding_model_config_data.get("config", {})
        )

        if cache_config_data is not None:
            cache_config = CacheConfig.from_config(cache_config_data)
        else:
            cache_config = None

        return cls(
            config=app_config,
            llm=llm,
            db=vector_db,
            embedding_model=embedding_model,
            config_data=config_data,
            auto_deploy=auto_deploy,
            chunker=chunker_config_data,
            cache_config=cache_config,
            memory_config=memory_config,
        )

    def _eval(self, dataset: list[EvalData], metric: Union[BaseMetric, str]):
        """
        Evaluate the app on a dataset for a given metric.
        """
        metric_str = metric.name if isinstance(metric, BaseMetric) else metric
        eval_class_map = {
            EvalMetric.CONTEXT_RELEVANCY.value: ContextRelevance,
            EvalMetric.ANSWER_RELEVANCY.value: AnswerRelevance,
            EvalMetric.GROUNDEDNESS.value: Groundedness,
        }

        if metric_str in eval_class_map:
            return eval_class_map[metric_str]().evaluate(dataset)

        # Handle the case for custom metrics
        if isinstance(metric, BaseMetric):
            return metric.evaluate(dataset)
        else:
            raise ValueError(f"Invalid metric: {metric}")

    def evaluate(
        self,
        questions: Union[str, list[str]],
        metrics: Optional[list[Union[BaseMetric, str]]] = None,
        num_workers: int = 4,
    ):
        """
        Evaluate the app on a question.

        param: questions: A question or a list of questions to evaluate.
        type: questions: Union[str, list[str]]
        param: metrics: A list of metrics to evaluate. Defaults to all metrics.
        type: metrics: Optional[list[Union[BaseMetric, str]]]
        param: num_workers: Number of workers to use for parallel processing.
        type: num_workers: int
        return: A dictionary containing the evaluation results.
        rtype: dict
        """
        if "OPENAI_API_KEY" not in os.environ:
            raise ValueError("Please set the OPENAI_API_KEY environment variable with permission to use `gpt4` model.")

        queries, answers, contexts = [], [], []
        if isinstance(questions, list):
            with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
                future_to_data = {executor.submit(self.query, q, citations=True): q for q in questions}
                for future in tqdm(
                    concurrent.futures.as_completed(future_to_data),
                    total=len(future_to_data),
                    desc="Getting answer and contexts for questions",
                ):
                    question = future_to_data[future]
                    queries.append(question)
                    answer, context = future.result()
                    answers.append(answer)
                    contexts.append(list(map(lambda x: x[0], context)))
        else:
            answer, context = self.query(questions, citations=True)
            queries = [questions]
            answers = [answer]
            contexts = [list(map(lambda x: x[0], context))]

        metrics = metrics or [
            EvalMetric.CONTEXT_RELEVANCY.value,
            EvalMetric.ANSWER_RELEVANCY.value,
            EvalMetric.GROUNDEDNESS.value,
        ]

        logger.info(f"Collecting data from {len(queries)} questions for evaluation...")
        dataset = []
        for q, a, c in zip(queries, answers, contexts):
            dataset.append(EvalData(question=q, answer=a, contexts=c))

        logger.info(f"Evaluating {len(dataset)} data points...")
        result = {}
        with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
            future_to_metric = {executor.submit(self._eval, dataset, metric): metric for metric in metrics}
            for future in tqdm(
                concurrent.futures.as_completed(future_to_metric),
                total=len(future_to_metric),
                desc="Evaluating metrics",
            ):
                metric = future_to_metric[future]
                if isinstance(metric, BaseMetric):
                    result[metric.name] = future.result()
                else:
                    result[metric] = future.result()

        if self.config.collect_metrics:
            telemetry_props = self._telemetry_props
            metrics_names = []
            for metric in metrics:
                if isinstance(metric, BaseMetric):
                    metrics_names.append(metric.name)
                else:
                    metrics_names.append(metric)
            telemetry_props["metrics"] = metrics_names
            self.telemetry.capture(event_name="evaluate", properties=telemetry_props)

        return result
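For reference, a hedged sketch of the `from_config`/`evaluate` flow defined above (the dict keys follow `from_config`; the provider-specific config fields and the exact metric key names are assumptions, since the factories and `EvalMetric` values live outside this diff; requires `OPENAI_API_KEY`):

    from embedchain import App

    config = {
        "app": {"config": {"name": "demo-app"}},
        "llm": {"provider": "openai", "config": {}},
        "vectordb": {"provider": "chroma", "config": {}},
        "embedder": {"provider": "openai", "config": {}},
    }
    app = App.from_config(config=config)
    app.add("https://www.example.com/post")  # hypothetical source
    scores = app.evaluate("What does the post say?")
    print(scores)  # one entry per metric: context relevancy, answer relevancy, groundedness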
5
embedchain/embedchain/bots/__init__.py
Normal file
@@ -0,0 +1,5 @@
from embedchain.bots.poe import PoeBot  # noqa: F401
from embedchain.bots.whatsapp import WhatsAppBot  # noqa: F401

# TODO: fix discord import
# from embedchain.bots.discord import DiscordBot
46
embedchain/embedchain/bots/base.py
Normal file
@@ -0,0 +1,46 @@
from typing import Any

from embedchain import App
from embedchain.config import AddConfig, AppConfig, BaseLlmConfig
from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.helpers.json_serializable import (JSONSerializable,
                                                  register_deserializable)
from embedchain.llm.openai import OpenAILlm
from embedchain.vectordb.chroma import ChromaDB


@register_deserializable
class BaseBot(JSONSerializable):
    def __init__(self):
        self.app = App(config=AppConfig(), llm=OpenAILlm(), db=ChromaDB(), embedding_model=OpenAIEmbedder())

    def add(self, data: Any, config: AddConfig = None):
        """
        Add data to the bot (to the vector database).
        Auto-detects the data type, so some data types might not be usable.

        :param data: data to embed
        :type data: Any
        :param config: configuration class instance, defaults to None
        :type config: AddConfig, optional
        """
        config = config if config else AddConfig()
        self.app.add(data, config=config)

    def query(self, query: str, config: BaseLlmConfig = None) -> str:
        """
        Query the bot

        :param query: the user query
        :type query: str
        :param config: configuration class instance, defaults to None
        :type config: BaseLlmConfig, optional
        :return: Answer
        :rtype: str
        """
        config = config
        return self.app.query(query, config=config)

    def start(self):
        """Start the bot's functionality."""
        raise NotImplementedError("Subclasses must implement the start method.")
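`BaseBot` wires up a default `App` and leaves `start()` abstract, so a new integration only has to route messages to `add()` and `query()`. A minimal, hypothetical subclass (not part of the commit; assumes `OPENAI_API_KEY` is set because the default `App` uses OpenAI):

    from embedchain.bots.base import BaseBot

    class ReplBot(BaseBot):
        def start(self):
            # Tiny REPL: "add <source>" ingests data, anything else is a query.
            while True:
                message = input("you> ")
                if message.startswith("add "):
                    self.add(message.split(" ", 1)[1])
                    print("bot> added")
                else:
                    print("bot>", self.query(message))

    ReplBot().start()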
128
embedchain/embedchain/bots/discord.py
Normal file
@@ -0,0 +1,128 @@
import argparse
import logging
import os

from embedchain.helpers.json_serializable import register_deserializable

from .base import BaseBot

try:
    import discord
    from discord import app_commands
    from discord.ext import commands
except ModuleNotFoundError:
    raise ModuleNotFoundError(
        "The required dependencies for Discord are not installed. " "Please install with `pip install discord==2.3.2`"
    ) from None


logger = logging.getLogger(__name__)

intents = discord.Intents.default()
intents.message_content = True
client = discord.Client(intents=intents)
tree = app_commands.CommandTree(client)

# Invite link example
# https://discord.com/api/oauth2/authorize?client_id={DISCORD_CLIENT_ID}&permissions=2048&scope=bot


@register_deserializable
class DiscordBot(BaseBot):
    def __init__(self, *args, **kwargs):
        BaseBot.__init__(self, *args, **kwargs)

    def add_data(self, message):
        data = message.split(" ")[-1]
        try:
            self.add(data)
            response = f"Added data from: {data}"
        except Exception:
            logger.exception(f"Failed to add data {data}.")
            response = "Some error occurred while adding data."
        return response

    def ask_bot(self, message):
        try:
            response = self.query(message)
        except Exception:
            logger.exception(f"Failed to query {message}.")
            response = "An error occurred. Please try again!"
        return response

    def start(self):
        client.run(os.environ["DISCORD_BOT_TOKEN"])


# @tree decorator cannot be used in a class. A global discord_bot is used as a workaround.


@tree.command(name="question", description="ask embedchain")
async def query_command(interaction: discord.Interaction, question: str):
    await interaction.response.defer()
    member = client.guilds[0].get_member(client.user.id)
    logger.info(f"User: {member}, Query: {question}")
    try:
        answer = discord_bot.ask_bot(question)
        if args.include_question:
            response = f"> {question}\n\n{answer}"
        else:
            response = answer
        await interaction.followup.send(response)
    except Exception as e:
        await interaction.followup.send("An error occurred. Please try again!")
        logger.error("Error occurred during 'query' command:", e)


@tree.command(name="add", description="add new content to the embedchain database")
async def add_command(interaction: discord.Interaction, url_or_text: str):
    await interaction.response.defer()
    member = client.guilds[0].get_member(client.user.id)
    logger.info(f"User: {member}, Add: {url_or_text}")
    try:
        response = discord_bot.add_data(url_or_text)
        await interaction.followup.send(response)
    except Exception as e:
        await interaction.followup.send("An error occurred. Please try again!")
        logger.error("Error occurred during 'add' command:", e)


@tree.command(name="ping", description="Simple ping pong command")
async def ping(interaction: discord.Interaction):
    await interaction.response.send_message("Pong", ephemeral=True)


@tree.error
async def on_app_command_error(interaction: discord.Interaction, error: discord.app_commands.AppCommandError) -> None:
    if isinstance(error, commands.CommandNotFound):
        await interaction.followup.send("Invalid command. Please refer to the documentation for correct syntax.")
    else:
        logger.error("Error occurred during command execution:", error)


@client.event
async def on_ready():
    # TODO: Sync in admin command, to not hit rate limits.
    # This might be overkill for most users, and it would require setting a guild or user id where sync is allowed.
    await tree.sync()
    logger.debug("Command tree synced")
    logger.info(f"Logged in as {client.user.name}")


def start_command():
    parser = argparse.ArgumentParser(description="EmbedChain DiscordBot command line interface")
    parser.add_argument(
        "--include-question",
        help="include question in query reply, otherwise it is hidden behind the slash command.",
        action="store_true",
    )
    global args
    args = parser.parse_args()

    global discord_bot
    discord_bot = DiscordBot()
    discord_bot.start()


if __name__ == "__main__":
    start_command()
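The Discord integration is launched through `start_command()`, which parses `--include-question` and runs the global client. A launch sketch (the token value is a placeholder, and the `discord` extra must be installed as the import guard above indicates):

    import os
    from embedchain.bots.discord import start_command

    os.environ["DISCORD_BOT_TOKEN"] = "<your-bot-token>"  # placeholder
    start_command()  # registers the /question, /add and /ping slash commands and blocks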
87
embedchain/embedchain/bots/poe.py
Normal file
@@ -0,0 +1,87 @@
import argparse
import logging
import os
from typing import Optional

from embedchain.helpers.json_serializable import register_deserializable

from .base import BaseBot

try:
    from fastapi_poe import PoeBot, run
except ModuleNotFoundError:
    raise ModuleNotFoundError(
        "The required dependencies for Poe are not installed. " "Please install with `pip install fastapi-poe==0.0.16`"
    ) from None


def start_command():
    parser = argparse.ArgumentParser(description="EmbedChain PoeBot command line interface")
    # parser.add_argument("--host", default="0.0.0.0", help="Host IP to bind")
    parser.add_argument("--port", default=8080, type=int, help="Port to bind")
    parser.add_argument("--api-key", type=str, help="Poe API key")
    # parser.add_argument(
    #     "--history-length",
    #     default=5,
    #     type=int,
    #     help="Set the max size of the chat history. Multiplies cost, but improves conversation awareness.",
    # )
    args = parser.parse_args()

    # FIXME: Arguments are automatically loaded by PoeBot's ArgumentParser, which causes it to fail.
    # The port argument here is also just for show; it actually works because Poe has the same argument.

    run(PoeBot(), api_key=args.api_key or os.environ.get("POE_API_KEY"))


@register_deserializable
class PoeBot(BaseBot, PoeBot):
    def __init__(self):
        self.history_length = 5
        super().__init__()

    async def get_response(self, query):
        last_message = query.query[-1].content
        try:
            history = (
                [f"{m.role}: {m.content}" for m in query.query[-(self.history_length + 1) : -1]]
                if len(query.query) > 0
                else None
            )
        except Exception as e:
            logging.error(f"Error when processing the chat history. Message is being sent without history. Error: {e}")
        answer = self.handle_message(last_message, history)
        yield self.text_event(answer)

    def handle_message(self, message, history: Optional[list[str]] = None):
        if message.startswith("/add "):
            response = self.add_data(message)
        else:
            response = self.ask_bot(message, history)
        return response

    # def add_data(self, message):
    #     data = message.split(" ")[-1]
    #     try:
    #         self.add(data)
    #         response = f"Added data from: {data}"
    #     except Exception:
    #         logging.exception(f"Failed to add data {data}.")
    #         response = "Some error occurred while adding data."
    #     return response

    def ask_bot(self, message, history: list[str]):
        try:
            self.app.llm.set_history(history=history)
            response = self.query(message)
        except Exception:
            logging.exception(f"Failed to query {message}.")
            response = "An error occurred. Please try again!"
        return response

    def start(self):
        start_command()


if __name__ == "__main__":
    start_command()
101
embedchain/embedchain/bots/slack.py
Normal file
@@ -0,0 +1,101 @@
import argparse
import logging
import os
import signal
import sys

from embedchain import App
from embedchain.helpers.json_serializable import register_deserializable

from .base import BaseBot

try:
    from flask import Flask, request
    from slack_sdk import WebClient
except ModuleNotFoundError:
    raise ModuleNotFoundError(
        "The required dependencies for Slack are not installed. "
        "Please install with `pip install slack-sdk==3.21.3 flask==2.3.3`"
    ) from None


logger = logging.getLogger(__name__)

SLACK_BOT_TOKEN = os.environ.get("SLACK_BOT_TOKEN")


@register_deserializable
class SlackBot(BaseBot):
    def __init__(self):
        self.client = WebClient(token=SLACK_BOT_TOKEN)
        self.chat_bot = App()
        self.recent_message = {"ts": 0, "channel": ""}
        super().__init__()

    def handle_message(self, event_data):
        message = event_data.get("event")
        if message and "text" in message and message.get("subtype") != "bot_message":
            text: str = message["text"]
            if float(message.get("ts")) > float(self.recent_message["ts"]):
                self.recent_message["ts"] = message["ts"]
                self.recent_message["channel"] = message["channel"]
                if text.startswith("query"):
                    _, question = text.split(" ", 1)
                    try:
                        response = self.chat_bot.chat(question)
                        self.send_slack_message(message["channel"], response)
                        logger.info("Query answered successfully!")
                    except Exception as e:
                        self.send_slack_message(message["channel"], "An error occurred. Please try again!")
                        logger.error("Error occurred during 'query' command:", e)
                elif text.startswith("add"):
                    _, data_type, url_or_text = text.split(" ", 2)
                    if url_or_text.startswith("<") and url_or_text.endswith(">"):
                        url_or_text = url_or_text[1:-1]
                    try:
                        self.chat_bot.add(url_or_text, data_type)
                        self.send_slack_message(message["channel"], f"Added {data_type} : {url_or_text}")
                    except ValueError as e:
                        self.send_slack_message(message["channel"], f"Error: {str(e)}")
                        logger.error("Error occurred during 'add' command:", e)
                    except Exception as e:
                        self.send_slack_message(message["channel"], f"Failed to add {data_type} : {url_or_text}")
                        logger.error("Error occurred during 'add' command:", e)

    def send_slack_message(self, channel, message):
        response = self.client.chat_postMessage(channel=channel, text=message)
        return response

    def start(self, host="0.0.0.0", port=5000, debug=True):
        app = Flask(__name__)

        def signal_handler(sig, frame):
            logger.info("\nGracefully shutting down the SlackBot...")
            sys.exit(0)

        signal.signal(signal.SIGINT, signal_handler)

        @app.route("/", methods=["POST"])
        def chat():
            # Check if the request is a verification request
            if request.json.get("challenge"):
                return str(request.json.get("challenge"))

            response = self.handle_message(request.json)
            return str(response)

        app.run(host=host, port=port, debug=debug)


def start_command():
    parser = argparse.ArgumentParser(description="EmbedChain SlackBot command line interface")
    parser.add_argument("--host", default="0.0.0.0", help="Host IP to bind")
    parser.add_argument("--port", default=5000, type=int, help="Port to bind")
    args = parser.parse_args()

    slack_bot = SlackBot()
    slack_bot.start(host=args.host, port=args.port)


if __name__ == "__main__":
    start_command()
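`SlackBot.start()` exposes a single Flask endpoint meant to receive Slack Events API callbacks; messages starting with `query ` are answered via `chat()` and messages starting with `add ` are ingested. A launch sketch (token values are placeholders; creating the Slack app and event subscription is outside this diff, and `OPENAI_API_KEY` is assumed):

    import os

    os.environ["SLACK_BOT_TOKEN"] = "xoxb-..."  # placeholder bot token
    from embedchain.bots.slack import SlackBot  # reads SLACK_BOT_TOKEN at import time

    SlackBot().start(host="0.0.0.0", port=5000)  # point the Slack Events API request URL at POST /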
83
embedchain/embedchain/bots/whatsapp.py
Normal file
@@ -0,0 +1,83 @@
import argparse
import importlib
import logging
import signal
import sys

from embedchain.helpers.json_serializable import register_deserializable

from .base import BaseBot

logger = logging.getLogger(__name__)


@register_deserializable
class WhatsAppBot(BaseBot):
    def __init__(self):
        try:
            self.flask = importlib.import_module("flask")
            self.twilio = importlib.import_module("twilio")
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "The required dependencies for WhatsApp are not installed. "
                "Please install with `pip install twilio==8.5.0 flask==2.3.3`"
            ) from None
        super().__init__()

    def handle_message(self, message):
        if message.startswith("add "):
            response = self.add_data(message)
        else:
            response = self.ask_bot(message)
        return response

    def add_data(self, message):
        data = message.split(" ")[-1]
        try:
            self.add(data)
            response = f"Added data from: {data}"
        except Exception:
            logger.exception(f"Failed to add data {data}.")
            response = "Some error occurred while adding data."
        return response

    def ask_bot(self, message):
        try:
            response = self.query(message)
        except Exception:
            logger.exception(f"Failed to query {message}.")
            response = "An error occurred. Please try again!"
        return response

    def start(self, host="0.0.0.0", port=5000, debug=True):
        app = self.flask.Flask(__name__)

        def signal_handler(sig, frame):
            logger.info("\nGracefully shutting down the WhatsAppBot...")
            sys.exit(0)

        signal.signal(signal.SIGINT, signal_handler)

        @app.route("/chat", methods=["POST"])
        def chat():
            incoming_message = self.flask.request.values.get("Body", "").lower()
            response = self.handle_message(incoming_message)
            twilio_response = self.twilio.twiml.messaging_response.MessagingResponse()
            twilio_response.message(response)
            return str(twilio_response)

        app.run(host=host, port=port, debug=debug)


def start_command():
    parser = argparse.ArgumentParser(description="EmbedChain WhatsAppBot command line interface")
    parser.add_argument("--host", default="0.0.0.0", help="Host IP to bind")
    parser.add_argument("--port", default=5000, type=int, help="Port to bind")
    args = parser.parse_args()

    whatsapp_bot = WhatsAppBot()
    whatsapp_bot.start(host=args.host, port=args.port)


if __name__ == "__main__":
    start_command()
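`WhatsAppBot` mirrors the Slack bot but serves a Twilio-style `/chat` webhook. A launch sketch (assumes the `twilio` and `flask` extras from the import check above, plus `OPENAI_API_KEY`):

    from embedchain.bots.whatsapp import WhatsAppBot

    bot = WhatsAppBot()
    bot.start(host="0.0.0.0", port=5000)  # point the Twilio WhatsApp webhook at POST /chat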
44
embedchain/embedchain/cache.py
Normal file
@@ -0,0 +1,44 @@
import logging
import os  # noqa: F401
from typing import Any

from gptcache import cache  # noqa: F401
from gptcache.adapter.adapter import adapt  # noqa: F401
from gptcache.config import Config  # noqa: F401
from gptcache.manager import get_data_manager
from gptcache.manager.scalar_data.base import Answer
from gptcache.manager.scalar_data.base import DataType as CacheDataType
from gptcache.session import Session
from gptcache.similarity_evaluation.distance import \
    SearchDistanceEvaluation  # noqa: F401
from gptcache.similarity_evaluation.exact_match import \
    ExactMatchEvaluation  # noqa: F401

logger = logging.getLogger(__name__)


def gptcache_pre_function(data: dict[str, Any], **params: dict[str, Any]):
    return data["input_query"]


def gptcache_data_manager(vector_dimension):
    return get_data_manager(cache_base="sqlite", vector_base="chromadb", max_size=1000, eviction="LRU")


def gptcache_data_convert(cache_data):
    logger.info("[Cache] Cache hit, returning cache data...")
    return cache_data


def gptcache_update_cache_callback(llm_data, update_cache_func, *args, **kwargs):
    logger.info("[Cache] Cache missed, updating cache...")
    update_cache_func(Answer(llm_data, CacheDataType.STR))
    return llm_data


def _gptcache_session_hit_func(cur_session_id: str, cache_session_ids: list, cache_questions: list, cache_answer: str):
    return cur_session_id in cache_session_ids


def get_gptcache_session(session_id: str):
    return Session(name=session_id, check_hit_func=_gptcache_session_hit_func)
0
embedchain/embedchain/chunkers/__init__.py
Normal file
22
embedchain/embedchain/chunkers/audio.py
Normal file
@@ -0,0 +1,22 @@
from typing import Optional

from langchain.text_splitter import RecursiveCharacterTextSplitter

from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable


@register_deserializable
class AudioChunker(BaseChunker):
    """Chunker for audio."""

    def __init__(self, config: Optional[ChunkerConfig] = None):
        if config is None:
            config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
            length_function=config.length_function,
        )
        super().__init__(text_splitter)
87
embedchain/embedchain/chunkers/base_chunker.py
Normal file
@@ -0,0 +1,87 @@
import hashlib
import logging
from typing import Optional

from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.models.data_type import DataType

logger = logging.getLogger(__name__)


class BaseChunker(JSONSerializable):
    def __init__(self, text_splitter):
        """Initialize the chunker."""
        self.text_splitter = text_splitter
        self.data_type = None

    def create_chunks(self, loader, src, app_id=None, config: Optional[ChunkerConfig] = None):
        """
        Loads data and chunks it.

        :param loader: The loader whose `load_data` method is used to create
            the raw data.
        :param src: The data to be handled by the loader. Can be a URL for
            remote sources or local content for local loaders.
        :param app_id: App id used to generate the doc_id.
        """
        documents = []
        chunk_ids = []
        id_map = {}
        min_chunk_size = config.min_chunk_size if config is not None else 1
        logger.info(f"Skipping chunks smaller than {min_chunk_size} characters")
        data_result = loader.load_data(src)
        data_records = data_result["data"]
        doc_id = data_result["doc_id"]
        # Prefix app_id in the document id if app_id is not None to
        # distinguish between different documents stored in the same
        # elasticsearch or opensearch index
        doc_id = f"{app_id}--{doc_id}" if app_id is not None else doc_id
        metadatas = []
        for data in data_records:
            content = data["content"]

            metadata = data["meta_data"]
            # add data type to meta data to allow query using data type
            metadata["data_type"] = self.data_type.value
            metadata["doc_id"] = doc_id

            # TODO: Currently defaulting to the src as the url. This is done intentionally since some
            # data types, like the 'gmail' loader, don't have the url in the meta data.
            url = metadata.get("url", src)

            chunks = self.get_chunks(content)
            for chunk in chunks:
                chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest()
                chunk_id = f"{app_id}--{chunk_id}" if app_id is not None else chunk_id
                if id_map.get(chunk_id) is None and len(chunk) >= min_chunk_size:
                    id_map[chunk_id] = True
                    chunk_ids.append(chunk_id)
                    documents.append(chunk)
                    metadatas.append(metadata)
        return {
            "documents": documents,
            "ids": chunk_ids,
            "metadatas": metadatas,
            "doc_id": doc_id,
        }

    def get_chunks(self, content):
        """
        Returns chunks using the text splitter instance.

        Override in a child class if custom logic is needed.
        """
        return self.text_splitter.split_text(content)

    def set_data_type(self, data_type: DataType):
        """
        Set the data type of the chunker.
        """
        self.data_type = data_type

        # TODO: This should be done during initialization. This means it has to be done in the child classes.

    @staticmethod
    def get_word_count(documents) -> int:
        return sum(len(document.split(" ")) for document in documents)
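`create_chunks` only needs a loader whose `load_data(src)` returns `{"doc_id": ..., "data": [{"content": ..., "meta_data": {...}}]}`. A toy illustration using the `CommonChunker` defined later in this commit and a hypothetical in-memory loader (the `DataType.TEXT` member is an assumption, since the enum is outside this diff):

    from embedchain.chunkers.common_chunker import CommonChunker
    from embedchain.config.add_config import ChunkerConfig
    from embedchain.models.data_type import DataType

    class InMemoryLoader:
        def load_data(self, src):
            return {"doc_id": "doc-1", "data": [{"content": src, "meta_data": {"url": "local"}}]}

    chunker = CommonChunker(ChunkerConfig(chunk_size=50, chunk_overlap=0, length_function=len))
    chunker.set_data_type(DataType.TEXT)  # assumed enum member
    result = chunker.create_chunks(InMemoryLoader(), "some long text " * 20, app_id="demo")
    print(result["doc_id"], len(result["documents"]))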
22
embedchain/embedchain/chunkers/beehiiv.py
Normal file
@@ -0,0 +1,22 @@
from typing import Optional

from langchain.text_splitter import RecursiveCharacterTextSplitter

from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable


@register_deserializable
class BeehiivChunker(BaseChunker):
    """Chunker for Beehiiv."""

    def __init__(self, config: Optional[ChunkerConfig] = None):
        if config is None:
            config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
            length_function=config.length_function,
        )
        super().__init__(text_splitter)
22
embedchain/embedchain/chunkers/common_chunker.py
Normal file
@@ -0,0 +1,22 @@
from typing import Optional

from langchain.text_splitter import RecursiveCharacterTextSplitter

from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable


@register_deserializable
class CommonChunker(BaseChunker):
    """Common chunker for all loaders."""

    def __init__(self, config: Optional[ChunkerConfig] = None):
        if config is None:
            config = ChunkerConfig(chunk_size=2000, chunk_overlap=0, length_function=len)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
            length_function=config.length_function,
        )
        super().__init__(text_splitter)
22
embedchain/embedchain/chunkers/discourse.py
Normal file
@@ -0,0 +1,22 @@
from typing import Optional

from langchain.text_splitter import RecursiveCharacterTextSplitter

from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable


@register_deserializable
class DiscourseChunker(BaseChunker):
    """Chunker for discourse."""

    def __init__(self, config: Optional[ChunkerConfig] = None):
        if config is None:
            config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
            length_function=config.length_function,
        )
        super().__init__(text_splitter)
22
embedchain/embedchain/chunkers/docs_site.py
Normal file
@@ -0,0 +1,22 @@
from typing import Optional

from langchain.text_splitter import RecursiveCharacterTextSplitter

from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable


@register_deserializable
class DocsSiteChunker(BaseChunker):
    """Chunker for code docs site."""

    def __init__(self, config: Optional[ChunkerConfig] = None):
        if config is None:
            config = ChunkerConfig(chunk_size=500, chunk_overlap=50, length_function=len)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
            length_function=config.length_function,
        )
        super().__init__(text_splitter)
22
embedchain/embedchain/chunkers/docx_file.py
Normal file
@@ -0,0 +1,22 @@
from typing import Optional

from langchain.text_splitter import RecursiveCharacterTextSplitter

from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable


@register_deserializable
class DocxFileChunker(BaseChunker):
    """Chunker for .docx file."""

    def __init__(self, config: Optional[ChunkerConfig] = None):
        if config is None:
            config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
            length_function=config.length_function,
        )
        super().__init__(text_splitter)
22
embedchain/embedchain/chunkers/excel_file.py
Normal file
@@ -0,0 +1,22 @@
from typing import Optional

from langchain.text_splitter import RecursiveCharacterTextSplitter

from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable


@register_deserializable
class ExcelFileChunker(BaseChunker):
    """Chunker for Excel file."""

    def __init__(self, config: Optional[ChunkerConfig] = None):
        if config is None:
            config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
            length_function=config.length_function,
        )
        super().__init__(text_splitter)
22
embedchain/embedchain/chunkers/gmail.py
Normal file
@@ -0,0 +1,22 @@
from typing import Optional

from langchain.text_splitter import RecursiveCharacterTextSplitter

from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable


@register_deserializable
class GmailChunker(BaseChunker):
    """Chunker for gmail."""

    def __init__(self, config: Optional[ChunkerConfig] = None):
        if config is None:
            config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
            length_function=config.length_function,
        )
        super().__init__(text_splitter)
22
embedchain/embedchain/chunkers/google_drive.py
Normal file
@@ -0,0 +1,22 @@
from typing import Optional

from langchain.text_splitter import RecursiveCharacterTextSplitter

from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable


@register_deserializable
class GoogleDriveChunker(BaseChunker):
    """Chunker for google drive folder."""

    def __init__(self, config: Optional[ChunkerConfig] = None):
        if config is None:
            config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
            length_function=config.length_function,
        )
        super().__init__(text_splitter)
22
embedchain/embedchain/chunkers/image.py
Normal file
22
embedchain/embedchain/chunkers/image.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class ImageChunker(BaseChunker):
|
||||
"""Chunker for Images."""
|
||||
|
||||
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||
if config is None:
|
||||
config = ChunkerConfig(chunk_size=2000, chunk_overlap=0, length_function=len)
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=config.chunk_size,
|
||||
chunk_overlap=config.chunk_overlap,
|
||||
length_function=config.length_function,
|
||||
)
|
||||
super().__init__(text_splitter)
|
||||
22
embedchain/embedchain/chunkers/json.py
Normal file
22
embedchain/embedchain/chunkers/json.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class JSONChunker(BaseChunker):
|
||||
"""Chunker for json."""
|
||||
|
||||
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||
if config is None:
|
||||
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=config.chunk_size,
|
||||
chunk_overlap=config.chunk_overlap,
|
||||
length_function=config.length_function,
|
||||
)
|
||||
super().__init__(text_splitter)
|
||||
22
embedchain/embedchain/chunkers/mdx.py
Normal file
22
embedchain/embedchain/chunkers/mdx.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class MdxChunker(BaseChunker):
|
||||
"""Chunker for mdx files."""
|
||||
|
||||
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||
if config is None:
|
||||
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=config.chunk_size,
|
||||
chunk_overlap=config.chunk_overlap,
|
||||
length_function=config.length_function,
|
||||
)
|
||||
super().__init__(text_splitter)
|
||||
22
embedchain/embedchain/chunkers/mysql.py
Normal file
22
embedchain/embedchain/chunkers/mysql.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class MySQLChunker(BaseChunker):
|
||||
"""Chunker for json."""
|
||||
|
||||
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||
if config is None:
|
||||
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=config.chunk_size,
|
||||
chunk_overlap=config.chunk_overlap,
|
||||
length_function=config.length_function,
|
||||
)
|
||||
super().__init__(text_splitter)
|
||||
22
embedchain/embedchain/chunkers/notion.py
Normal file
22
embedchain/embedchain/chunkers/notion.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class NotionChunker(BaseChunker):
|
||||
"""Chunker for notion."""
|
||||
|
||||
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||
if config is None:
|
||||
config = ChunkerConfig(chunk_size=300, chunk_overlap=0, length_function=len)
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=config.chunk_size,
|
||||
chunk_overlap=config.chunk_overlap,
|
||||
length_function=config.length_function,
|
||||
)
|
||||
super().__init__(text_splitter)
|
||||
18
embedchain/embedchain/chunkers/openapi.py
Normal file
18
embedchain/embedchain/chunkers/openapi.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
|
||||
|
||||
class OpenAPIChunker(BaseChunker):
|
||||
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||
if config is None:
|
||||
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=config.chunk_size,
|
||||
chunk_overlap=config.chunk_overlap,
|
||||
length_function=config.length_function,
|
||||
)
|
||||
super().__init__(text_splitter)
|
||||
22
embedchain/embedchain/chunkers/pdf_file.py
Normal file
22
embedchain/embedchain/chunkers/pdf_file.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class PdfFileChunker(BaseChunker):
|
||||
"""Chunker for PDF file."""
|
||||
|
||||
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||
if config is None:
|
||||
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=config.chunk_size,
|
||||
chunk_overlap=config.chunk_overlap,
|
||||
length_function=config.length_function,
|
||||
)
|
||||
super().__init__(text_splitter)
|
||||
22
embedchain/embedchain/chunkers/postgres.py
Normal file
22
embedchain/embedchain/chunkers/postgres.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class PostgresChunker(BaseChunker):
|
||||
"""Chunker for postgres."""
|
||||
|
||||
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||
if config is None:
|
||||
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=config.chunk_size,
|
||||
chunk_overlap=config.chunk_overlap,
|
||||
length_function=config.length_function,
|
||||
)
|
||||
super().__init__(text_splitter)
|
||||
22
embedchain/embedchain/chunkers/qna_pair.py
Normal file
22
embedchain/embedchain/chunkers/qna_pair.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class QnaPairChunker(BaseChunker):
|
||||
"""Chunker for QnA pair."""
|
||||
|
||||
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||
if config is None:
|
||||
config = ChunkerConfig(chunk_size=300, chunk_overlap=0, length_function=len)
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=config.chunk_size,
|
||||
chunk_overlap=config.chunk_overlap,
|
||||
length_function=config.length_function,
|
||||
)
|
||||
super().__init__(text_splitter)
|
||||
22
embedchain/embedchain/chunkers/rss_feed.py
Normal file
22
embedchain/embedchain/chunkers/rss_feed.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class RSSFeedChunker(BaseChunker):
|
||||
"""Chunker for RSS Feed."""
|
||||
|
||||
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||
if config is None:
|
||||
config = ChunkerConfig(chunk_size=2000, chunk_overlap=0, length_function=len)
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=config.chunk_size,
|
||||
chunk_overlap=config.chunk_overlap,
|
||||
length_function=config.length_function,
|
||||
)
|
||||
super().__init__(text_splitter)
|
||||
22
embedchain/embedchain/chunkers/sitemap.py
Normal file
22
embedchain/embedchain/chunkers/sitemap.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class SitemapChunker(BaseChunker):
|
||||
"""Chunker for sitemap."""
|
||||
|
||||
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||
if config is None:
|
||||
config = ChunkerConfig(chunk_size=500, chunk_overlap=0, length_function=len)
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=config.chunk_size,
|
||||
chunk_overlap=config.chunk_overlap,
|
||||
length_function=config.length_function,
|
||||
)
|
||||
super().__init__(text_splitter)
|
||||
22
embedchain/embedchain/chunkers/slack.py
Normal file
22
embedchain/embedchain/chunkers/slack.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class SlackChunker(BaseChunker):
|
||||
"""Chunker for postgres."""
|
||||
|
||||
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||
if config is None:
|
||||
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=config.chunk_size,
|
||||
chunk_overlap=config.chunk_overlap,
|
||||
length_function=config.length_function,
|
||||
)
|
||||
super().__init__(text_splitter)
|
||||
22
embedchain/embedchain/chunkers/substack.py
Normal file
22
embedchain/embedchain/chunkers/substack.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class SubstackChunker(BaseChunker):
|
||||
"""Chunker for Substack."""
|
||||
|
||||
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||
if config is None:
|
||||
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=config.chunk_size,
|
||||
chunk_overlap=config.chunk_overlap,
|
||||
length_function=config.length_function,
|
||||
)
|
||||
super().__init__(text_splitter)
|
||||
20
embedchain/embedchain/chunkers/table.py
Normal file
20
embedchain/embedchain/chunkers/table.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
|
||||
|
||||
class TableChunker(BaseChunker):
|
||||
"""Chunker for tables, for instance csv, google sheets or databases."""
|
||||
|
||||
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||
if config is None:
|
||||
config = ChunkerConfig(chunk_size=300, chunk_overlap=0, length_function=len)
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=config.chunk_size,
|
||||
chunk_overlap=config.chunk_overlap,
|
||||
length_function=config.length_function,
|
||||
)
|
||||
super().__init__(text_splitter)
|
||||
22
embedchain/embedchain/chunkers/text.py
Normal file
22
embedchain/embedchain/chunkers/text.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class TextChunker(BaseChunker):
|
||||
"""Chunker for text."""
|
||||
|
||||
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||
if config is None:
|
||||
config = ChunkerConfig(chunk_size=300, chunk_overlap=0, length_function=len)
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=config.chunk_size,
|
||||
chunk_overlap=config.chunk_overlap,
|
||||
length_function=config.length_function,
|
||||
)
|
||||
super().__init__(text_splitter)
|
||||
22
embedchain/embedchain/chunkers/unstructured_file.py
Normal file
22
embedchain/embedchain/chunkers/unstructured_file.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class UnstructuredFileChunker(BaseChunker):
|
||||
"""Chunker for Unstructured file."""
|
||||
|
||||
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||
if config is None:
|
||||
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=config.chunk_size,
|
||||
chunk_overlap=config.chunk_overlap,
|
||||
length_function=config.length_function,
|
||||
)
|
||||
super().__init__(text_splitter)
|
||||
22
embedchain/embedchain/chunkers/web_page.py
Normal file
22
embedchain/embedchain/chunkers/web_page.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class WebPageChunker(BaseChunker):
|
||||
"""Chunker for web page."""
|
||||
|
||||
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||
if config is None:
|
||||
config = ChunkerConfig(chunk_size=2000, chunk_overlap=0, length_function=len)
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=config.chunk_size,
|
||||
chunk_overlap=config.chunk_overlap,
|
||||
length_function=config.length_function,
|
||||
)
|
||||
super().__init__(text_splitter)
|
||||
22
embedchain/embedchain/chunkers/xml.py
Normal file
22
embedchain/embedchain/chunkers/xml.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class XmlChunker(BaseChunker):
|
||||
"""Chunker for XML files."""
|
||||
|
||||
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||
if config is None:
|
||||
config = ChunkerConfig(chunk_size=500, chunk_overlap=50, length_function=len)
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=config.chunk_size,
|
||||
chunk_overlap=config.chunk_overlap,
|
||||
length_function=config.length_function,
|
||||
)
|
||||
super().__init__(text_splitter)
|
||||
22
embedchain/embedchain/chunkers/youtube_video.py
Normal file
22
embedchain/embedchain/chunkers/youtube_video.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class YoutubeVideoChunker(BaseChunker):
|
||||
"""Chunker for Youtube video."""
|
||||
|
||||
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||
if config is None:
|
||||
config = ChunkerConfig(chunk_size=2000, chunk_overlap=0, length_function=len)
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=config.chunk_size,
|
||||
chunk_overlap=config.chunk_overlap,
|
||||
length_function=config.length_function,
|
||||
)
|
||||
super().__init__(text_splitter)
|
||||
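Every chunker above follows the same pattern and differs only in its defaults, so adding a new data type is mostly boilerplate. A hypothetical example following that pattern (the class name and defaults are illustrative, not part of this change):

from typing import Optional

from langchain.text_splitter import RecursiveCharacterTextSplitter

from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable


@register_deserializable
class EpubChunker(BaseChunker):
    """Hypothetical chunker for EPUB files, mirroring the pattern above."""

    def __init__(self, config: Optional[ChunkerConfig] = None):
        if config is None:
            config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
            length_function=config.length_function,
        )
        super().__init__(text_splitter)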
327
embedchain/embedchain/cli.py
Normal file
@@ -0,0 +1,327 @@
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
import requests
|
||||
from rich.console import Console
|
||||
|
||||
from embedchain.telemetry.posthog import AnonymousTelemetry
|
||||
from embedchain.utils.cli import (deploy_fly, deploy_gradio_app,
|
||||
deploy_hf_spaces, deploy_modal,
|
||||
deploy_render, deploy_streamlit,
|
||||
get_pkg_path_from_name, setup_fly_io_app,
|
||||
setup_gradio_app, setup_hf_app,
|
||||
setup_modal_com_app, setup_render_com_app,
|
||||
setup_streamlit_io_app)
|
||||
|
||||
console = Console()
|
||||
api_process = None
|
||||
ui_process = None
|
||||
|
||||
anonymous_telemetry = AnonymousTelemetry()
|
||||
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
"""Signal handler to catch termination signals and kill server processes."""
|
||||
global api_process, ui_process
|
||||
console.print("\n🛑 [bold yellow]Stopping servers...[/bold yellow]")
|
||||
if api_process:
|
||||
api_process.terminate()
|
||||
console.print("🛑 [bold yellow]API server stopped.[/bold yellow]")
|
||||
if ui_process:
|
||||
ui_process.terminate()
|
||||
console.print("🛑 [bold yellow]UI server stopped.[/bold yellow]")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
@click.group()
|
||||
def cli():
|
||||
pass
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("app_name")
|
||||
@click.option("--docker", is_flag=True, help="Use docker to create the app.")
|
||||
@click.pass_context
|
||||
def create_app(ctx, app_name, docker):
|
||||
if Path(app_name).exists():
|
||||
console.print(
|
||||
f"❌ [red]Directory '{app_name}' already exists. Try using a new directory name, or remove it.[/red]"
|
||||
)
|
||||
return
|
||||
|
||||
os.makedirs(app_name)
|
||||
os.chdir(app_name)
|
||||
|
||||
# Step 1: Download the zip file
|
||||
zip_url = "http://github.com/embedchain/ec-admin/archive/main.zip"
|
||||
console.print(f"Creating a new embedchain app in [green]{Path().resolve()}[/green]\n")
|
||||
try:
|
||||
response = requests.get(zip_url)
|
||||
response.raise_for_status()
|
||||
with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
|
||||
tmp_file.write(response.content)
|
||||
zip_file_path = tmp_file.name
|
||||
console.print("✅ [bold green]Fetched template successfully.[/bold green]")
|
||||
except requests.RequestException as e:
|
||||
console.print(f"❌ [bold red]Failed to download zip file: {e}[/bold red]")
|
||||
anonymous_telemetry.capture(event_name="ec_create_app", properties={"success": False})
|
||||
return
|
||||
|
||||
# Step 2: Extract the zip file
|
||||
try:
|
||||
with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
|
||||
# Get the name of the root directory inside the zip file
|
||||
root_dir = Path(zip_ref.namelist()[0])
|
||||
for member in zip_ref.infolist():
|
||||
# Build the path to extract the file to, skipping the root directory
|
||||
target_file = Path(member.filename).relative_to(root_dir)
|
||||
source_file = zip_ref.open(member, "r")
|
||||
if member.is_dir():
|
||||
# Create directory if it doesn't exist
|
||||
os.makedirs(target_file, exist_ok=True)
|
||||
else:
|
||||
with open(target_file, "wb") as file:
|
||||
# Write the file
|
||||
shutil.copyfileobj(source_file, file)
|
||||
console.print("✅ [bold green]Extracted zip file successfully.[/bold green]")
|
||||
anonymous_telemetry.capture(event_name="ec_create_app", properties={"success": True})
|
||||
except zipfile.BadZipFile:
|
||||
console.print("❌ [bold red]Error in extracting zip file. The file might be corrupted.[/bold red]")
|
||||
anonymous_telemetry.capture(event_name="ec_create_app", properties={"success": False})
|
||||
return
|
||||
|
||||
if docker:
|
||||
subprocess.run(["docker-compose", "build"], check=True)
|
||||
else:
|
||||
ctx.invoke(install_reqs)
|
||||
|
||||
|
||||
@cli.command()
|
||||
def install_reqs():
|
||||
try:
|
||||
console.print("Installing python requirements...\n")
|
||||
time.sleep(2)
|
||||
os.chdir("api")
|
||||
subprocess.run(["pip", "install", "-r", "requirements.txt"], check=True)
|
||||
os.chdir("..")
|
||||
console.print("\n ✅ [bold green]Installed API requirements successfully.[/bold green]\n")
|
||||
except Exception as e:
|
||||
console.print(f"❌ [bold red]Failed to install API requirements: {e}[/bold red]")
|
||||
anonymous_telemetry.capture(event_name="ec_install_reqs", properties={"success": False})
|
||||
return
|
||||
|
||||
try:
|
||||
os.chdir("ui")
|
||||
subprocess.run(["yarn"], check=True)
|
||||
console.print("\n✅ [bold green]Successfully installed frontend requirements.[/bold green]")
|
||||
anonymous_telemetry.capture(event_name="ec_install_reqs", properties={"success": True})
|
||||
except Exception as e:
|
||||
console.print(f"❌ [bold red]Failed to install frontend requirements. Error: {e}[/bold red]")
|
||||
anonymous_telemetry.capture(event_name="ec_install_reqs", properties={"success": False})
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("--docker", is_flag=True, help="Run inside docker.")
|
||||
def start(docker):
global api_process, ui_process
if docker:
|
||||
subprocess.run(["docker-compose", "up"], check=True)
|
||||
return
|
||||
|
||||
# Set up signal handling
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
# Step 1: Start the API server
|
||||
try:
|
||||
os.chdir("api")
|
||||
api_process = subprocess.Popen(["python", "-m", "main"], stdout=None, stderr=None)
|
||||
os.chdir("..")
|
||||
console.print("✅ [bold green]API server started successfully.[/bold green]")
|
||||
except Exception as e:
|
||||
console.print(f"❌ [bold red]Failed to start the API server: {e}[/bold red]")
|
||||
anonymous_telemetry.capture(event_name="ec_start", properties={"success": False})
|
||||
return
|
||||
|
||||
# Sleep for 2 seconds to give the user time to read the message
|
||||
time.sleep(2)
|
||||
|
||||
# Step 2: Install UI requirements and start the UI server
|
||||
try:
|
||||
os.chdir("ui")
|
||||
subprocess.run(["yarn"], check=True)
|
||||
ui_process = subprocess.Popen(["yarn", "dev"])
|
||||
console.print("✅ [bold green]UI server started successfully.[/bold green]")
|
||||
anonymous_telemetry.capture(event_name="ec_start", properties={"success": True})
|
||||
except Exception as e:
|
||||
console.print(f"❌ [bold red]Failed to start the UI server: {e}[/bold red]")
|
||||
anonymous_telemetry.capture(event_name="ec_start", properties={"success": False})
|
||||
|
||||
# Keep the script running until it receives a kill signal
|
||||
try:
|
||||
api_process.wait()
|
||||
ui_process.wait()
|
||||
except KeyboardInterrupt:
|
||||
console.print("\n🛑 [bold yellow]Stopping server...[/bold yellow]")
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("--template", default="fly.io", help="The template to use.")
|
||||
@click.argument("extra_args", nargs=-1, type=click.UNPROCESSED)
|
||||
def create(template, extra_args):
|
||||
anonymous_telemetry.capture(event_name="ec_create", properties={"template_used": template})
|
||||
template_dir = template
|
||||
if "/" in template_dir:
|
||||
template_dir = template.split("/")[1]
|
||||
src_path = get_pkg_path_from_name(template_dir)
|
||||
shutil.copytree(src_path, os.getcwd(), dirs_exist_ok=True)
|
||||
console.print(f"✅ [bold green]Successfully created app from template '{template}'.[/bold green]")
|
||||
|
||||
if template == "fly.io":
|
||||
setup_fly_io_app(extra_args)
|
||||
elif template == "modal.com":
|
||||
setup_modal_com_app(extra_args)
|
||||
elif template == "render.com":
|
||||
setup_render_com_app()
|
||||
elif template == "streamlit.io":
|
||||
setup_streamlit_io_app()
|
||||
elif template == "gradio.app":
|
||||
setup_gradio_app()
|
||||
elif template == "hf/gradio.app" or template == "hf/streamlit.io":
|
||||
setup_hf_app()
|
||||
else:
|
||||
raise ValueError(f"Unknown template '{template}'.")
|
||||
|
||||
embedchain_config = {"provider": template}
|
||||
with open("embedchain.json", "w") as file:
|
||||
json.dump(embedchain_config, file, indent=4)
|
||||
console.print(
|
||||
f"🎉 [green]All done! Successfully created `embedchain.json` with '{template}' as provider.[/green]"
|
||||
)
|
||||
|
||||
|
||||
def run_dev_fly_io(debug, host, port):
|
||||
uvicorn_command = ["uvicorn", "app:app"]
|
||||
|
||||
if debug:
|
||||
uvicorn_command.append("--reload")
|
||||
|
||||
uvicorn_command.extend(["--host", host, "--port", str(port)])
|
||||
|
||||
try:
|
||||
console.print(f"🚀 [bold cyan]Running FastAPI app with command: {' '.join(uvicorn_command)}[/bold cyan]")
|
||||
subprocess.run(uvicorn_command, check=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
console.print(f"❌ [bold red]An error occurred: {e}[/bold red]")
|
||||
except KeyboardInterrupt:
|
||||
console.print("\n🛑 [bold yellow]FastAPI server stopped[/bold yellow]")
|
||||
|
||||
|
||||
def run_dev_modal_com():
|
||||
modal_run_cmd = ["modal", "serve", "app"]
|
||||
try:
|
||||
console.print(f"🚀 [bold cyan]Running FastAPI app with command: {' '.join(modal_run_cmd)}[/bold cyan]")
|
||||
subprocess.run(modal_run_cmd, check=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
console.print(f"❌ [bold red]An error occurred: {e}[/bold red]")
|
||||
except KeyboardInterrupt:
|
||||
console.print("\n🛑 [bold yellow]FastAPI server stopped[/bold yellow]")
|
||||
|
||||
|
||||
def run_dev_streamlit_io():
|
||||
streamlit_run_cmd = ["streamlit", "run", "app.py"]
|
||||
try:
|
||||
console.print(f"🚀 [bold cyan]Running Streamlit app with command: {' '.join(streamlit_run_cmd)}[/bold cyan]")
|
||||
subprocess.run(streamlit_run_cmd, check=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
console.print(f"❌ [bold red]An error occurred: {e}[/bold red]")
|
||||
except KeyboardInterrupt:
|
||||
console.print("\n🛑 [bold yellow]Streamlit server stopped[/bold yellow]")
|
||||
|
||||
|
||||
def run_dev_render_com(debug, host, port):
|
||||
uvicorn_command = ["uvicorn", "app:app"]
|
||||
|
||||
if debug:
|
||||
uvicorn_command.append("--reload")
|
||||
|
||||
uvicorn_command.extend(["--host", host, "--port", str(port)])
|
||||
|
||||
try:
|
||||
console.print(f"🚀 [bold cyan]Running FastAPI app with command: {' '.join(uvicorn_command)}[/bold cyan]")
|
||||
subprocess.run(uvicorn_command, check=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
console.print(f"❌ [bold red]An error occurred: {e}[/bold red]")
|
||||
except KeyboardInterrupt:
|
||||
console.print("\n🛑 [bold yellow]FastAPI server stopped[/bold yellow]")
|
||||
|
||||
|
||||
def run_dev_gradio():
|
||||
gradio_run_cmd = ["gradio", "app.py"]
|
||||
try:
|
||||
console.print(f"🚀 [bold cyan]Running Gradio app with command: {' '.join(gradio_run_cmd)}[/bold cyan]")
|
||||
subprocess.run(gradio_run_cmd, check=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
console.print(f"❌ [bold red]An error occurred: {e}[/bold red]")
|
||||
except KeyboardInterrupt:
|
||||
console.print("\n🛑 [bold yellow]Gradio server stopped[/bold yellow]")
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("--debug", is_flag=True, help="Enable or disable debug mode.")
|
||||
@click.option("--host", default="127.0.0.1", help="The host address to run the FastAPI app on.")
|
||||
@click.option("--port", default=8000, help="The port to run the FastAPI app on.")
|
||||
def dev(debug, host, port):
|
||||
template = ""
|
||||
with open("embedchain.json", "r") as file:
|
||||
embedchain_config = json.load(file)
|
||||
template = embedchain_config["provider"]
|
||||
|
||||
anonymous_telemetry.capture(event_name="ec_dev", properties={"template_used": template})
|
||||
if template == "fly.io":
|
||||
run_dev_fly_io(debug, host, port)
|
||||
elif template == "modal.com":
|
||||
run_dev_modal_com()
|
||||
elif template == "render.com":
|
||||
run_dev_render_com(debug, host, port)
|
||||
elif template == "streamlit.io" or template == "hf/streamlit.io":
|
||||
run_dev_streamlit_io()
|
||||
elif template == "gradio.app" or template == "hf/gradio.app":
|
||||
run_dev_gradio()
|
||||
else:
|
||||
raise ValueError(f"Unknown template '{template}'.")
|
||||
|
||||
|
||||
@cli.command()
|
||||
def deploy():
|
||||
# Check for platform-specific files
|
||||
template = ""
|
||||
ec_app_name = ""
|
||||
with open("embedchain.json", "r") as file:
|
||||
embedchain_config = json.load(file)
|
||||
ec_app_name = embedchain_config["name"] if "name" in embedchain_config else None
|
||||
template = embedchain_config["provider"]
|
||||
|
||||
anonymous_telemetry.capture(event_name="ec_deploy", properties={"template_used": template})
|
||||
if template == "fly.io":
|
||||
deploy_fly()
|
||||
elif template == "modal.com":
|
||||
deploy_modal()
|
||||
elif template == "render.com":
|
||||
deploy_render()
|
||||
elif template == "streamlit.io":
|
||||
deploy_streamlit()
|
||||
elif template == "gradio.app":
|
||||
deploy_gradio_app()
|
||||
elif template.startswith("hf/"):
|
||||
deploy_hf_spaces(ec_app_name)
|
||||
else:
|
||||
console.print("❌ [bold red]No recognized deployment platform found.[/bold red]")
|
||||
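The commands above are standard click commands, so they can be exercised programmatically; a sketch using click's test runner (the template name is just an example, and `create` shells out to external tooling, so treat this as illustration rather than a hermetic test):

from click.testing import CliRunner

from embedchain.cli import cli

runner = CliRunner()
# Scaffolds the fly.io template into the working directory and writes embedchain.json.
result = runner.invoke(cli, ["create", "--template", "fly.io"])
print(result.exit_code, result.output)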
103
embedchain/embedchain/client.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
|
||||
import requests
|
||||
|
||||
from embedchain.constants import CONFIG_DIR, CONFIG_FILE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Client:
|
||||
def __init__(self, api_key=None, host="https://apiv2.embedchain.ai"):
|
||||
self.config_data = self.load_config()
|
||||
self.host = host
|
||||
|
||||
if api_key:
|
||||
if self.check(api_key):
|
||||
self.api_key = api_key
|
||||
self.save()
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid API key provided. You can find your API key on https://app.embedchain.ai/settings/keys."
|
||||
)
|
||||
else:
|
||||
if "api_key" in self.config_data:
|
||||
self.api_key = self.config_data["api_key"]
|
||||
logger.info("API key loaded successfully!")
|
||||
else:
|
||||
raise ValueError(
|
||||
"You are not logged in. Please obtain an API key from https://app.embedchain.ai/settings/keys/"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def setup(cls):
|
||||
"""
|
||||
Loads the user id from the config file if it exists, otherwise generates a new
|
||||
one and saves it to the config file.
|
||||
|
||||
:return: user id
|
||||
:rtype: str
|
||||
"""
|
||||
os.makedirs(CONFIG_DIR, exist_ok=True)
|
||||
|
||||
if os.path.exists(CONFIG_FILE):
|
||||
with open(CONFIG_FILE, "r") as f:
|
||||
data = json.load(f)
|
||||
if "user_id" in data:
|
||||
return data["user_id"]
|
||||
|
||||
u_id = str(uuid.uuid4())
|
||||
with open(CONFIG_FILE, "w") as f:
|
||||
json.dump({"user_id": u_id}, f)
|
||||
|
||||
@classmethod
|
||||
def load_config(cls):
|
||||
if not os.path.exists(CONFIG_FILE):
|
||||
cls.setup()
|
||||
|
||||
with open(CONFIG_FILE, "r") as config_file:
|
||||
return json.load(config_file)
|
||||
|
||||
def save(self):
|
||||
self.config_data["api_key"] = self.api_key
|
||||
with open(CONFIG_FILE, "w") as config_file:
|
||||
json.dump(self.config_data, config_file, indent=4)
|
||||
|
||||
logger.info("API key saved successfully!")
|
||||
|
||||
def clear(self):
|
||||
if "api_key" in self.config_data:
|
||||
del self.config_data["api_key"]
|
||||
with open(CONFIG_FILE, "w") as config_file:
|
||||
json.dump(self.config_data, config_file, indent=4)
|
||||
self.api_key = None
|
||||
logger.info("API key deleted successfully!")
|
||||
else:
|
||||
logger.warning("API key not found in the configuration file.")
|
||||
|
||||
def update(self, api_key):
|
||||
if self.check(api_key):
|
||||
self.api_key = api_key
|
||||
self.save()
|
||||
logger.info("API key updated successfully!")
|
||||
else:
|
||||
logger.warning("Invalid API key provided. API key not updated.")
|
||||
|
||||
def check(self, api_key):
|
||||
validation_url = f"{self.host}/api/v1/accounts/api_keys/validate/"
|
||||
response = requests.post(validation_url, headers={"Authorization": f"Token {api_key}"})
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"Response from API: {response.text}")
|
||||
logger.warning("Invalid API key. Unable to validate.")
|
||||
return False
|
||||
|
||||
def get(self):
|
||||
return self.api_key
|
||||
|
||||
def __str__(self):
|
||||
return self.api_key
|
||||
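A usage sketch for the Client above; the key is a placeholder and `check` requires network access to the configured host:

from embedchain.client import Client

Client.setup()                      # ensures the config dir and anonymous user id exist
client = Client(api_key="ec-xxx")   # placeholder key; validated against the API and persisted
print(client.get())                 # returns the active API key
client.clear()                      # removes the key from the local config file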
15
embedchain/embedchain/config/__init__.py
Normal file
15
embedchain/embedchain/config/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# flake8: noqa: F401
|
||||
|
||||
from .add_config import AddConfig, ChunkerConfig
|
||||
from .app_config import AppConfig
|
||||
from .base_config import BaseConfig
|
||||
from .cache_config import CacheConfig
|
||||
from .embedder.base import BaseEmbedderConfig
|
||||
from .embedder.base import BaseEmbedderConfig as EmbedderConfig
|
||||
from .embedder.ollama import OllamaEmbedderConfig
|
||||
from .llm.base import BaseLlmConfig
|
||||
from .vector_db.chroma import ChromaDbConfig
|
||||
from .vector_db.elasticsearch import ElasticsearchDBConfig
|
||||
from .vector_db.opensearch import OpenSearchDBConfig
|
||||
from .vector_db.zilliz import ZillizDBConfig
|
||||
from .mem0_config import Mem0Config
|
||||
79
embedchain/embedchain/config/add_config.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import builtins
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from importlib import import_module
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.base_config import BaseConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class ChunkerConfig(BaseConfig):
|
||||
"""
|
||||
Config for the chunker used in `add` method
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chunk_size: Optional[int] = 2000,
|
||||
chunk_overlap: Optional[int] = 0,
|
||||
length_function: Optional[Callable[[str], int]] = None,
|
||||
min_chunk_size: Optional[int] = 0,
|
||||
):
|
||||
self.chunk_size = chunk_size
|
||||
self.chunk_overlap = chunk_overlap
|
||||
self.min_chunk_size = min_chunk_size
|
||||
if self.min_chunk_size >= self.chunk_size:
|
||||
raise ValueError(f"min_chunk_size {min_chunk_size} should be less than chunk_size {chunk_size}")
|
||||
if self.min_chunk_size < self.chunk_overlap:
|
||||
logging.warning(
|
||||
f"min_chunk_size {min_chunk_size} should be greater than chunk_overlap {chunk_overlap}, otherwise it is redundant." # noqa:E501
|
||||
)
|
||||
|
||||
if isinstance(length_function, str):
|
||||
self.length_function = self.load_func(length_function)
|
||||
else:
|
||||
self.length_function = length_function if length_function else len
|
||||
|
||||
@staticmethod
|
||||
def load_func(dotpath: str):
|
||||
if "." not in dotpath:
|
||||
return getattr(builtins, dotpath)
|
||||
else:
|
||||
module_, func = dotpath.rsplit(".", maxsplit=1)
|
||||
m = import_module(module_)
|
||||
return getattr(m, func)
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class LoaderConfig(BaseConfig):
|
||||
"""
|
||||
Config for the loader used in `add` method
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class AddConfig(BaseConfig):
|
||||
"""
|
||||
Config for the `add` method.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chunker: Optional[ChunkerConfig] = None,
|
||||
loader: Optional[LoaderConfig] = None,
|
||||
):
|
||||
"""
|
||||
Initializes a configuration class instance for the `add` method.
|
||||
|
||||
:param chunker: Chunker config, defaults to None
|
||||
:type chunker: Optional[ChunkerConfig], optional
|
||||
:param loader: Loader config, defaults to None
|
||||
:type loader: Optional[LoaderConfig], optional
|
||||
"""
|
||||
self.loader = loader
|
||||
self.chunker = chunker
|
||||
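A short sketch of the configs above; note that `length_function` may also be passed as a string, which `load_func` resolves at init time:

from embedchain.config.add_config import AddConfig, ChunkerConfig

# "len" is resolved by load_func to builtins.len; a dotted path such as
# "mymodule.token_len" (hypothetical) would be imported and resolved the same way.
chunker = ChunkerConfig(chunk_size=800, chunk_overlap=100, length_function="len", min_chunk_size=200)
add_config = AddConfig(chunker=chunker)

# ChunkerConfig raises ValueError if min_chunk_size >= chunk_size.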
34
embedchain/embedchain/config/app_config.py
Normal file
34
embedchain/embedchain/config/app_config.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
from .base_app_config import BaseAppConfig
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class AppConfig(BaseAppConfig):
|
||||
"""
|
||||
Config to initialize an embedchain custom `App` instance, with extra config options.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
log_level: str = "WARNING",
|
||||
id: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
collect_metrics: Optional[bool] = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initializes a configuration class instance for an App. This is the simplest form of an embedchain app.
|
||||
Most of the configuration is done in the `App` class itself.
|
||||
|
||||
:param log_level: Debug level ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], defaults to "WARNING"
|
||||
:type log_level: str, optional
|
||||
:param id: ID of the app. Document metadata will have this id., defaults to None
|
||||
:type id: Optional[str], optional
|
||||
:param collect_metrics: Send anonymous telemetry to improve embedchain, defaults to True
|
||||
:type collect_metrics: Optional[bool], optional
|
||||
"""
|
||||
self.name = name
|
||||
super().__init__(log_level=log_level, id=id, collect_metrics=collect_metrics, **kwargs)
|
||||
58
embedchain/embedchain/config/base_app_config.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.base_config import BaseConfig
|
||||
from embedchain.helpers.json_serializable import JSONSerializable
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseAppConfig(BaseConfig, JSONSerializable):
|
||||
"""
|
||||
Parent config to initialize an instance of `App`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
log_level: str = "WARNING",
|
||||
db: Optional[BaseVectorDB] = None,
|
||||
id: Optional[str] = None,
|
||||
collect_metrics: bool = True,
|
||||
collection_name: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initializes a configuration class instance for an App.
|
||||
Most of the configuration is done in the `App` class itself.
|
||||
|
||||
:param log_level: Debug level ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], defaults to "WARNING"
|
||||
:type log_level: str, optional
|
||||
:param db: A database class. It is recommended to set this directly in the `App` class, not this config,
|
||||
defaults to None
|
||||
:type db: Optional[BaseVectorDB], optional
|
||||
:param id: ID of the app. Document metadata will have this id., defaults to None
|
||||
:type id: Optional[str], optional
|
||||
:param collect_metrics: Send anonymous telemetry to improve embedchain, defaults to True
|
||||
:type collect_metrics: Optional[bool], optional
|
||||
:param collection_name: Default collection name. It's recommended to use app.db.set_collection_name() instead,
|
||||
defaults to None
|
||||
:type collection_name: Optional[str], optional
|
||||
"""
|
||||
self.id = id
|
||||
self.collect_metrics = True if (collect_metrics is True or collect_metrics is None) else False
|
||||
self.collection_name = collection_name
|
||||
|
||||
if db:
|
||||
self._db = db
|
||||
logger.warning(
|
||||
"DEPRECATION WARNING: Please supply the database as the second parameter during app init. "
|
||||
"Such as `app(config=config, db=db)`."
|
||||
)
|
||||
|
||||
if collection_name:
|
||||
logger.warning("DEPRECATION WARNING: Please supply the collection name to the database config.")
|
||||
return
|
||||
|
||||
def _setup_logging(self, log_level):
|
||||
logging.basicConfig(format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s", level=log_level)
self.logger = logging.getLogger(__name__)
|
||||
21
embedchain/embedchain/config/base_config.py
Normal file
21
embedchain/embedchain/config/base_config.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from typing import Any
|
||||
|
||||
from embedchain.helpers.json_serializable import JSONSerializable
|
||||
|
||||
|
||||
class BaseConfig(JSONSerializable):
|
||||
"""
|
||||
Base config.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initializes a configuration class for a class."""
|
||||
pass
|
||||
|
||||
def as_dict(self) -> dict[str, Any]:
|
||||
"""Return config object as a dict
|
||||
|
||||
:return: config object as dict
|
||||
:rtype: dict[str, Any]
|
||||
"""
|
||||
return vars(self)
|
||||
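Because `as_dict` simply returns `vars(self)`, every config subclass serializes whatever its constructor stored; a quick sketch:

from embedchain.config.add_config import ChunkerConfig

config = ChunkerConfig(chunk_size=500, chunk_overlap=50)
print(config.as_dict())
# -> {'chunk_size': 500, 'chunk_overlap': 50, 'min_chunk_size': 0, 'length_function': <built-in function len>}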
96
embedchain/embedchain/config/cache_config.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from embedchain.config.base_config import BaseConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class CacheSimilarityEvalConfig(BaseConfig):
|
||||
"""
|
||||
This is the evaluator to compare two embeddings according to their distance computed in the embedding retrieval stage.
In the retrieval stage, `search_result` is the distance used for approximate nearest neighbor search and has been
|
||||
put into `cache_dict`. `max_distance` is used to bound this distance to make it between [0-`max_distance`].
|
||||
`positive` is used to indicate this distance is directly proportional to the similarity of two entities.
|
||||
If `positive` is set to `False`, the final score is obtained by subtracting this distance from `max_distance`.
|
||||
|
||||
:param max_distance: the bound of maximum distance.
|
||||
:type max_distance: float
|
||||
:param positive: True if a larger distance indicates that two entities are more similar; otherwise False.
|
||||
:type positive: bool
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
strategy: Optional[str] = "distance",
|
||||
max_distance: Optional[float] = 1.0,
|
||||
positive: Optional[bool] = False,
|
||||
):
|
||||
self.strategy = strategy
|
||||
self.max_distance = max_distance
|
||||
self.positive = positive
|
||||
|
||||
@staticmethod
|
||||
def from_config(config: Optional[dict[str, Any]]):
|
||||
if config is None:
|
||||
return CacheSimilarityEvalConfig()
|
||||
else:
|
||||
return CacheSimilarityEvalConfig(
|
||||
strategy=config.get("strategy", "distance"),
|
||||
max_distance=config.get("max_distance", 1.0),
|
||||
positive=config.get("positive", False),
|
||||
)
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class CacheInitConfig(BaseConfig):
|
||||
"""
|
||||
This is a cache init config. Used to initialize a cache.
|
||||
|
||||
:param similarity_threshold: a threshold ranged from 0 to 1 to filter search results with similarity score higher \
|
||||
than the threshold. When it is 0, there are no hits. When it is 1, all search results will be returned as hits.
|
||||
:type similarity_threshold: float
|
||||
:param auto_flush: the cache is automatically flushed every time this many pieces of data are added, defaults to 20
|
||||
:type auto_flush: int
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
similarity_threshold: Optional[float] = 0.8,
|
||||
auto_flush: Optional[int] = 20,
|
||||
):
|
||||
if similarity_threshold < 0 or similarity_threshold > 1:
|
||||
raise ValueError(f"similarity_threshold {similarity_threshold} should be between 0 and 1")
|
||||
|
||||
self.similarity_threshold = similarity_threshold
|
||||
self.auto_flush = auto_flush
|
||||
|
||||
@staticmethod
|
||||
def from_config(config: Optional[dict[str, Any]]):
|
||||
if config is None:
|
||||
return CacheInitConfig()
|
||||
else:
|
||||
return CacheInitConfig(
|
||||
similarity_threshold=config.get("similarity_threshold", 0.8),
|
||||
auto_flush=config.get("auto_flush", 20),
|
||||
)
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class CacheConfig(BaseConfig):
|
||||
def __init__(
|
||||
self,
|
||||
similarity_eval_config: Optional[CacheSimilarityEvalConfig] = CacheSimilarityEvalConfig(),
|
||||
init_config: Optional[CacheInitConfig] = CacheInitConfig(),
|
||||
):
|
||||
self.similarity_eval_config = similarity_eval_config
|
||||
self.init_config = init_config
|
||||
|
||||
@staticmethod
|
||||
def from_config(config: Optional[dict[str, Any]]):
|
||||
if config is None:
|
||||
return CacheConfig()
|
||||
else:
|
||||
return CacheConfig(
|
||||
similarity_eval_config=CacheSimilarityEvalConfig.from_config(config.get("similarity_evaluation", {})),
|
||||
init_config=CacheInitConfig.from_config(config.get("init_config", {})),
|
||||
)
|
||||
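A sketch of building the cache config above from a plain dict, as it would typically arrive from a parsed app config file (the values shown are just the documented defaults):

from embedchain.config.cache_config import CacheConfig

cache_config = CacheConfig.from_config(
    {
        "similarity_evaluation": {"strategy": "distance", "max_distance": 1.0, "positive": False},
        "init_config": {"similarity_threshold": 0.8, "auto_flush": 20},
    }
)
print(cache_config.init_config.similarity_threshold)  # 0.8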
0
embedchain/embedchain/config/embedder/__init__.py
Normal file
0
embedchain/embedchain/config/embedder/__init__.py
Normal file
42
embedchain/embedchain/config/embedder/base.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class BaseEmbedderConfig:
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[str] = None,
|
||||
deployment_name: Optional[str] = None,
|
||||
vector_dimension: Optional[int] = None,
|
||||
endpoint: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
model_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
Initialize a new instance of an embedder config class.
|
||||
|
||||
:param model: model name of the llm embedding model (not applicable to all providers), defaults to None
|
||||
:type model: Optional[str], optional
|
||||
:param deployment_name: deployment name for llm embedding model, defaults to None
|
||||
:type deployment_name: Optional[str], optional
|
||||
:param vector_dimension: vector dimension of the embedding model, defaults to None
|
||||
:type vector_dimension: Optional[int], optional
|
||||
:param endpoint: endpoint for the embedding model, defaults to None
|
||||
:type endpoint: Optional[str], optional
|
||||
:param api_key: Hugging Face API key, defaults to None
|
||||
:type api_key: Optional[str], optional
|
||||
:param api_base: Hugging Face API base URL, defaults to None
|
||||
:type api_base: Optional[str], optional
|
||||
:param model_kwargs: key-value arguments for the embedding model, defaults to an empty dict
:type model_kwargs: Optional[Dict[str, Any]], optional
|
||||
"""
|
||||
self.model = model
|
||||
self.deployment_name = deployment_name
|
||||
self.vector_dimension = vector_dimension
|
||||
self.endpoint = endpoint
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
self.model_kwargs = model_kwargs or {}
|
||||
19
embedchain/embedchain/config/embedder/google.py
Normal file
19
embedchain/embedchain/config/embedder/google.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.embedder.base import BaseEmbedderConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class GoogleAIEmbedderConfig(BaseEmbedderConfig):
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[str] = None,
|
||||
deployment_name: Optional[str] = None,
|
||||
vector_dimension: Optional[int] = None,
|
||||
task_type: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
):
|
||||
super().__init__(model, deployment_name, vector_dimension)
|
||||
self.task_type = task_type or "retrieval_document"
|
||||
self.title = title or "Embeddings for Embedchain"
|
||||
16
embedchain/embedchain/config/embedder/ollama.py
Normal file
16
embedchain/embedchain/config/embedder/ollama.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.embedder.base import BaseEmbedderConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class OllamaEmbedderConfig(BaseEmbedderConfig):
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
vector_dimension: Optional[int] = None,
|
||||
):
|
||||
super().__init__(model=model, vector_dimension=vector_dimension)
|
||||
self.base_url = base_url or "http://localhost:11434"
|
||||
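A sketch for the embedder config above; the model name is illustrative, and `base_url` falls back to the local Ollama default when omitted:

from embedchain.config.embedder.ollama import OllamaEmbedderConfig

config = OllamaEmbedderConfig(model="nomic-embed-text", vector_dimension=768)
print(config.base_url)  # http://localhost:11434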
2
embedchain/embedchain/config/evaluation/__init__.py
Normal file
2
embedchain/embedchain/config/evaluation/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .base import (AnswerRelevanceConfig, ContextRelevanceConfig, # noqa: F401
|
||||
GroundednessConfig)
|
||||
92
embedchain/embedchain/config/evaluation/base.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.base_config import BaseConfig
|
||||
|
||||
ANSWER_RELEVANCY_PROMPT = """
|
||||
Please provide $num_gen_questions questions from the provided answer.
|
||||
You must provide the complete question; if you are not able to provide the complete question, return an empty string ("").
|
||||
Please only provide one question per line without numbers or bullets to distinguish them.
|
||||
You must only provide the questions and no other text.
|
||||
|
||||
$answer
|
||||
""" # noqa:E501
|
||||
|
||||
|
||||
CONTEXT_RELEVANCY_PROMPT = """
|
||||
Please extract the relevant sentences from the provided context that are required to answer the given question.
|
||||
If no relevant sentences are found, or if you believe the question cannot be answered from the given context, return the empty string ("").
|
||||
While extracting candidate sentences, you're not allowed to make any changes to sentences from the given context or make up any sentences.
|
||||
You must only provide sentences from the given context and nothing else.
|
||||
|
||||
Context: $context
|
||||
Question: $question
|
||||
""" # noqa:E501
|
||||
|
||||
GROUNDEDNESS_ANSWER_CLAIMS_PROMPT = """
|
||||
Please provide one or more statements from each sentence of the provided answer.
|
||||
You must provide the semantically equivalent statements for each sentence of the answer.
You must provide the complete statement; if you are not able to provide the complete statement, return an empty string ("").
|
||||
Please only provide one statement per line WITHOUT numbers or bullets.
|
||||
If the question provided is not being answered in the provided answer, return empty string ("").
|
||||
You must only provide the statements and no other text.
|
||||
|
||||
$question
|
||||
$answer
|
||||
""" # noqa:E501
|
||||
|
||||
GROUNDEDNESS_CLAIMS_INFERENCE_PROMPT = """
|
||||
Given the context and the provided claim statements, please provide a verdict for each claim statement on whether it can be completely inferred from the given context or not.
|
||||
Use only "1" (yes), "0" (no) and "-1" (null) for "yes", "no" or "null" respectively.
|
||||
You must provide one verdict per line, ONLY WITH "1", "0" or "-1" as per your verdict to the given statement and nothing else.
|
||||
You must provide the verdicts in the same order as the claim statements.
|
||||
|
||||
Contexts:
|
||||
$context
|
||||
|
||||
Claim statements:
|
||||
$claim_statements
|
||||
""" # noqa:E501
|
||||
|
||||
|
||||
class GroundednessConfig(BaseConfig):
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "gpt-4",
|
||||
api_key: Optional[str] = None,
|
||||
answer_claims_prompt: str = GROUNDEDNESS_ANSWER_CLAIMS_PROMPT,
|
||||
claims_inference_prompt: str = GROUNDEDNESS_CLAIMS_INFERENCE_PROMPT,
|
||||
):
|
||||
self.model = model
|
||||
self.api_key = api_key
|
||||
self.answer_claims_prompt = answer_claims_prompt
|
||||
self.claims_inference_prompt = claims_inference_prompt
|
||||
|
||||
|
||||
class AnswerRelevanceConfig(BaseConfig):
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "gpt-4",
|
||||
embedder: str = "text-embedding-ada-002",
|
||||
api_key: Optional[str] = None,
|
||||
num_gen_questions: int = 1,
|
||||
prompt: str = ANSWER_RELEVANCY_PROMPT,
|
||||
):
|
||||
self.model = model
|
||||
self.embedder = embedder
|
||||
self.api_key = api_key
|
||||
self.num_gen_questions = num_gen_questions
|
||||
self.prompt = prompt
|
||||
|
||||
|
||||
class ContextRelevanceConfig(BaseConfig):
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "gpt-4",
|
||||
api_key: Optional[str] = None,
|
||||
language: str = "en",
|
||||
prompt: str = CONTEXT_RELEVANCY_PROMPT,
|
||||
):
|
||||
self.model = model
|
||||
self.api_key = api_key
|
||||
self.language = language
|
||||
self.prompt = prompt
|
||||
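The prompts above use `$placeholder` fields, so they are presumably rendered with `string.Template`; a hedged sketch of filling one of them from its config:

from string import Template

from embedchain.config.evaluation.base import ANSWER_RELEVANCY_PROMPT, AnswerRelevanceConfig

config = AnswerRelevanceConfig(num_gen_questions=2)
prompt = Template(ANSWER_RELEVANCY_PROMPT).substitute(
    num_gen_questions=config.num_gen_questions,
    answer="Embedchain chunks each data source before embedding and storing it.",
)
print(prompt)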
0
embedchain/embedchain/config/llm/__init__.py
Normal file
0
embedchain/embedchain/config/llm/__init__.py
Normal file
275
embedchain/embedchain/config/llm/base.py
Normal file
@@ -0,0 +1,275 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from string import Template
|
||||
from typing import Any, Mapping, Optional, Dict, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from embedchain.config.base_config import BaseConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_PROMPT = """
|
||||
You are a Q&A expert system. Your responses must always be rooted in the context provided for each query. Here are some guidelines to follow:
|
||||
|
||||
1. Refrain from explicitly mentioning the context provided in your response.
|
||||
2. The context should silently guide your answers without being directly acknowledged.
|
||||
3. Do not use phrases such as 'According to the context provided', 'Based on the context, ...' etc.
|
||||
|
||||
Context information:
|
||||
----------------------
|
||||
$context
|
||||
----------------------
|
||||
|
||||
Query: $query
|
||||
Answer:
|
||||
""" # noqa:E501
|
||||
|
||||
DEFAULT_PROMPT_WITH_HISTORY = """
|
||||
You are a Q&A expert system. Your responses must always be rooted in the context provided for each query. You are also provided with the conversation history with the user. Make sure to use relevant context from conversation history as needed.
|
||||
|
||||
Here are some guidelines to follow:
|
||||
|
||||
1. Refrain from explicitly mentioning the context provided in your response.
|
||||
2. The context should silently guide your answers without being directly acknowledged.
|
||||
3. Do not use phrases such as 'According to the context provided', 'Based on the context, ...' etc.
|
||||
|
||||
Context information:
|
||||
----------------------
|
||||
$context
|
||||
----------------------
|
||||
|
||||
Conversation history:
|
||||
----------------------
|
||||
$history
|
||||
----------------------
|
||||
|
||||
Query: $query
|
||||
Answer:
|
||||
""" # noqa:E501
|
||||
|
||||
DEFAULT_PROMPT_WITH_MEM0_MEMORY = """
|
||||
You are a Q&A expert system. Your responses must always be rooted in the context provided for each query. You are also provided with the conversation history and memories with the user. Make sure to use relevant context from conversation history and memories as needed.
|
||||
|
||||
Here are some guidelines to follow:
|
||||
|
||||
1. Refrain from explicitly mentioning the context provided in your response.
|
||||
2. Take into consideration the conversation history and memories provided.
|
||||
3. The context should silently guide your answers without being directly acknowledged.
|
||||
4. Do not use phrases such as 'According to the context provided', 'Based on the context, ...' etc.
|
||||
|
||||
Context information:
|
||||
----------------------
|
||||
$context
|
||||
----------------------
|
||||
|
||||
Conversation history:
|
||||
----------------------
|
||||
$history
|
||||
----------------------
|
||||
|
||||
Memories/Preferences:
|
||||
----------------------
|
||||
$memories
|
||||
----------------------
|
||||
|
||||
Query: $query
|
||||
Answer:
|
||||
""" # noqa:E501
|
||||
|
||||
DOCS_SITE_DEFAULT_PROMPT = """
|
||||
You are an expert AI assistant for a developer support product. Your responses must always be rooted in the context provided for each query. Wherever possible, give a complete code snippet. Don't make up any code snippets on your own.
|
||||
|
||||
Here are some guidelines to follow:
|
||||
|
||||
1. Refrain from explicitly mentioning the context provided in your response.
|
||||
2. The context should silently guide your answers without being directly acknowledged.
|
||||
3. Do not use phrases such as 'According to the context provided', 'Based on the context, ...' etc.
|
||||
|
||||
Context information:
|
||||
----------------------
|
||||
$context
|
||||
----------------------
|
||||
|
||||
Query: $query
|
||||
Answer:
|
||||
""" # noqa:E501
|
||||
|
||||
DEFAULT_PROMPT_TEMPLATE = Template(DEFAULT_PROMPT)
|
||||
DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE = Template(DEFAULT_PROMPT_WITH_HISTORY)
|
||||
DEFAULT_PROMPT_WITH_MEM0_MEMORY_TEMPLATE = Template(DEFAULT_PROMPT_WITH_MEM0_MEMORY)
|
||||
DOCS_SITE_PROMPT_TEMPLATE = Template(DOCS_SITE_DEFAULT_PROMPT)
|
||||
query_re = re.compile(r"\$\{*query\}*")
|
||||
context_re = re.compile(r"\$\{*context\}*")
|
||||
history_re = re.compile(r"\$\{*history\}*")
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class BaseLlmConfig(BaseConfig):
|
||||
"""
|
||||
Config for the `query` method.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
number_documents: int = 3,
|
||||
template: Optional[Template] = None,
|
||||
prompt: Optional[Template] = None,
|
||||
model: Optional[str] = None,
|
||||
temperature: float = 0,
|
||||
max_tokens: int = 1000,
|
||||
top_p: float = 1,
|
||||
stream: bool = False,
|
||||
online: bool = False,
|
||||
token_usage: bool = False,
|
||||
deployment_name: Optional[str] = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
where: dict[str, Any] = None,
|
||||
query_type: Optional[str] = None,
|
||||
callbacks: Optional[list] = None,
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
endpoint: Optional[str] = None,
|
||||
model_kwargs: Optional[dict[str, Any]] = None,
|
||||
http_client_proxies: Optional[Union[Dict, str]] = None,
|
||||
http_async_client_proxies: Optional[Union[Dict, str]] = None,
|
||||
local: Optional[bool] = False,
|
||||
default_headers: Optional[Mapping[str, str]] = None,
|
||||
api_version: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initializes a configuration class instance for the LLM.
|
||||
|
||||
Takes the place of the former `QueryConfig` or `ChatConfig`.
|
||||
|
||||
:param number_documents: Number of documents to pull from the database as
|
||||
context, defaults to 3
|
||||
:type number_documents: int, optional
|
||||
:param template: The `Template` instance to use as a template for
|
||||
prompt, defaults to None (deprecated)
|
||||
:type template: Optional[Template], optional
|
||||
:param prompt: The `Template` instance to use as a template for
|
||||
prompt, defaults to None
|
||||
:type prompt: Optional[Template], optional
|
||||
:param model: Controls the OpenAI model used, defaults to None
|
||||
:type model: Optional[str], optional
|
||||
:param temperature: Controls the randomness of the model's output.
|
||||
Higher values (closer to 1) make output more random, lower values make it more deterministic, defaults to 0
|
||||
:type temperature: float, optional
|
||||
:param max_tokens: Controls how many tokens are generated, defaults to 1000
|
||||
:type max_tokens: int, optional
|
||||
:param top_p: Controls the diversity of words. Higher values (closer to 1) make word selection more diverse,
|
||||
defaults to 1
|
||||
:type top_p: float, optional
|
||||
:param stream: Control if response is streamed back to user, defaults to False
|
||||
:type stream: bool, optional
|
||||
:param online: Controls whether to use internet for answering query, defaults to False
|
||||
:type online: bool, optional
|
||||
:param token_usage: Controls whether to return token usage in response, defaults to False
|
||||
:type token_usage: bool, optional
|
||||
:param deployment_name: Name of the model deployment (used by providers such as Azure OpenAI), defaults to None
|
||||
:type deployment_name: Optional[str], optional
|
||||
:param system_prompt: System prompt string, defaults to None
|
||||
:type system_prompt: Optional[str], optional
|
||||
:param where: A dictionary of key-value pairs to filter the database results., defaults to None
|
||||
:type where: dict[str, Any], optional
|
||||
:param api_key: The api key of the custom endpoint, defaults to None
|
||||
:type api_key: Optional[str], optional
|
||||
:param endpoint: The api url of the custom endpoint, defaults to None
|
||||
:type endpoint: Optional[str], optional
|
||||
:param model_kwargs: A dictionary of key-value pairs to pass to the model, defaults to None
|
||||
:type model_kwargs: Optional[Dict[str, Any]], optional
|
||||
:param callbacks: Langchain callback functions to use, defaults to None
|
||||
:type callbacks: Optional[list], optional
|
||||
:param query_type: The type of query to use, defaults to None
|
||||
:type query_type: Optional[str], optional
|
||||
:param http_client_proxies: The proxy server settings used to create self.http_client, defaults to None
|
||||
:type http_client_proxies: Optional[Dict | str], optional
|
||||
:param http_async_client_proxies: The proxy server settings for async calls used to create
|
||||
self.http_async_client, defaults to None
|
||||
:type http_async_client_proxies: Optional[Dict | str], optional
|
||||
:param local: If True, the model will be run locally, defaults to False (for huggingface provider)
|
||||
:type local: Optional[bool], optional
|
||||
:param default_headers: Set additional HTTP headers to be sent with requests to OpenAI
|
||||
:type default_headers: Optional[Mapping[str, str]], optional
|
||||
:raises ValueError: If the template is not valid, as the template should
|
||||
contain $context and $query (and optionally $history)
|
||||
:raises ValueError: Stream is not boolean
|
||||
"""
|
||||
if template is not None:
|
||||
logger.warning(
|
||||
"The `template` argument is deprecated and will be removed in a future version. "
|
||||
+ "Please use `prompt` instead."
|
||||
)
|
||||
if prompt is None:
|
||||
prompt = template
|
||||
|
||||
if prompt is None:
|
||||
prompt = DEFAULT_PROMPT_TEMPLATE
|
||||
|
||||
self.number_documents = number_documents
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.model = model
|
||||
self.top_p = top_p
|
||||
self.online = online
|
||||
self.token_usage = token_usage
|
||||
self.deployment_name = deployment_name
|
||||
self.system_prompt = system_prompt
|
||||
self.query_type = query_type
|
||||
self.callbacks = callbacks
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self.endpoint = endpoint
|
||||
self.model_kwargs = model_kwargs
|
||||
self.http_client = httpx.Client(proxies=http_client_proxies) if http_client_proxies else None
|
||||
self.http_async_client = (
|
||||
httpx.AsyncClient(proxies=http_async_client_proxies) if http_async_client_proxies else None
|
||||
)
|
||||
self.local = local
|
||||
self.default_headers = default_headers
|
||||
self.online = online
|
||||
self.api_version = api_version
|
||||
|
||||
if token_usage:
|
||||
f = open("model_prices_and_context_window.json")
|
||||
self.model_pricing_map = json.load(f)
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = Template(prompt)
|
||||
|
||||
if self.validate_prompt(prompt):
|
||||
self.prompt = prompt
|
||||
else:
|
||||
raise ValueError("The 'prompt' should have 'query' and 'context' keys and potentially 'history' (if used).")
|
||||
|
||||
if not isinstance(stream, bool):
|
||||
raise ValueError("`stream` should be bool")
|
||||
self.stream = stream
|
||||
self.where = where
|
||||
|
||||
@staticmethod
|
||||
def validate_prompt(prompt: Template) -> Optional[re.Match[str]]:
|
||||
"""
|
||||
validate the prompt
|
||||
|
||||
:param prompt: the prompt to validate
|
||||
:type prompt: Template
|
||||
:return: valid (true) or invalid (false)
|
||||
:rtype: Optional[re.Match[str]]
|
||||
"""
|
||||
return re.search(query_re, prompt.template) and re.search(context_re, prompt.template)
|
||||
|
||||
@staticmethod
|
||||
def _validate_prompt_history(prompt: Template) -> Optional[re.Match[str]]:
|
||||
"""
|
||||
validate the prompt with history
|
||||
|
||||
:param prompt: the prompt to validate
|
||||
:type prompt: Template
|
||||
:return: valid (true) or invalid (false)
|
||||
:rtype: Optional[re.Match[str]]
|
||||
"""
|
||||
return re.search(history_re, prompt.template)
|
||||
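A minimal sketch of how the prompt validation above behaves, assuming BaseLlmConfig is imported from the module added in this file:

from string import Template

from embedchain.config.llm.base import BaseLlmConfig

# Valid: the template contains both $context and $query, so validate_prompt succeeds
config = BaseLlmConfig(prompt=Template("Answer using $context only.\nQuestion: $query"))

# Invalid: $context is missing, so __init__ raises ValueError
# BaseLlmConfig(prompt=Template("Question: $query"))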
21
embedchain/embedchain/config/mem0_config.py
Normal file
21
embedchain/embedchain/config/mem0_config.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from embedchain.config.base_config import BaseConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class Mem0Config(BaseConfig):
|
||||
def __init__(self, api_key: str, top_k: Optional[int] = 10):
|
||||
self.api_key = api_key
|
||||
self.top_k = top_k
|
||||
|
||||
@staticmethod
|
||||
def from_config(config: Optional[dict[str, Any]]):
|
||||
if config is None:
|
||||
return Mem0Config(api_key="")
|
||||
else:
|
||||
return Mem0Config(
|
||||
api_key=config.get("api_key", ""),
|
||||
top_k=config.get("top_k", 10),
|
||||
)
|
||||
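A small usage sketch for Mem0Config.from_config (values are placeholders), showing how a plain dict from an app-level config is turned into the typed object:

from embedchain.config.mem0_config import Mem0Config

memory_config = Mem0Config.from_config({"api_key": "m0-placeholder", "top_k": 5})
assert memory_config.top_k == 5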
36
embedchain/embedchain/config/vector_db/base.py
Normal file
36
embedchain/embedchain/config/vector_db/base.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.base_config import BaseConfig
|
||||
|
||||
|
||||
class BaseVectorDbConfig(BaseConfig):
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: Optional[str] = None,
|
||||
dir: str = "db",
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initializes a configuration class instance for the vector database.
|
||||
|
||||
:param collection_name: Default name for the collection, defaults to None
|
||||
:type collection_name: Optional[str], optional
|
||||
:param dir: Path to the database directory, where the database is stored, defaults to "db"
|
||||
:type dir: str, optional
|
||||
:param host: Database connection remote host. Use this if you run Embedchain as a client, defaults to None
|
||||
:type host: Optional[str], optional
|
||||
:param port: Database connection remote port. Use this if you run Embedchain as a client, defaults to None
|
||||
:type port: Optional[str], optional
|
||||
:param kwargs: Additional keyword arguments
|
||||
:type kwargs: dict
|
||||
"""
|
||||
self.collection_name = collection_name or "embedchain_store"
|
||||
self.dir = dir
|
||||
self.host = host
|
||||
self.port = port
|
||||
# Assign additional keyword arguments
|
||||
if kwargs:
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
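Because any extra keyword arguments are assigned as attributes in the loop above, vendor-specific options can ride along without widening the base signature; a quick sketch with an arbitrary extra kwarg:

from embedchain.config.vector_db.base import BaseVectorDbConfig

cfg = BaseVectorDbConfig(collection_name="my_docs", dir="/tmp/db", ssl_verify=False)
print(cfg.collection_name)  # "my_docs"
print(cfg.ssl_verify)       # False, set via setattr in __init__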
41
embedchain/embedchain/config/vector_db/chroma.py
Normal file
41
embedchain/embedchain/config/vector_db/chroma.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.vector_db.base import BaseVectorDbConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class ChromaDbConfig(BaseVectorDbConfig):
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: Optional[str] = None,
|
||||
dir: Optional[str] = None,
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
batch_size: Optional[int] = 100,
|
||||
allow_reset=False,
|
||||
chroma_settings: Optional[dict] = None,
|
||||
):
|
||||
"""
|
||||
Initializes a configuration class instance for ChromaDB.
|
||||
|
||||
:param collection_name: Default name for the collection, defaults to None
|
||||
:type collection_name: Optional[str], optional
|
||||
:param dir: Path to the database directory, where the database is stored, defaults to None
|
||||
:type dir: Optional[str], optional
|
||||
:param host: Database connection remote host. Use this if you run Embedchain as a client, defaults to None
|
||||
:type host: Optional[str], optional
|
||||
:param port: Database connection remote port. Use this if you run Embedchain as a client, defaults to None
|
||||
:type port: Optional[str], optional
|
||||
:param batch_size: Number of items to insert in one batch, defaults to 100
|
||||
:type batch_size: Optional[int], optional
|
||||
:param allow_reset: Resets the database. defaults to False
|
||||
:type allow_reset: bool
|
||||
:param chroma_settings: Chroma settings dict, defaults to None
|
||||
:type chroma_settings: Optional[dict], optional
|
||||
"""
|
||||
|
||||
self.chroma_settings = chroma_settings
|
||||
self.allow_reset = allow_reset
|
||||
self.batch_size = batch_size
|
||||
super().__init__(collection_name=collection_name, dir=dir, host=host, port=port)
|
||||
56
embedchain/embedchain/config/vector_db/elasticsearch.py
Normal file
56
embedchain/embedchain/config/vector_db/elasticsearch.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
|
||||
from embedchain.config.vector_db.base import BaseVectorDbConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class ElasticsearchDBConfig(BaseVectorDbConfig):
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: Optional[str] = None,
|
||||
dir: Optional[str] = None,
|
||||
es_url: Union[str, list[str]] = None,
|
||||
cloud_id: Optional[str] = None,
|
||||
batch_size: Optional[int] = 100,
|
||||
**ES_EXTRA_PARAMS: dict[str, any],
|
||||
):
|
||||
"""
|
||||
Initializes a configuration class instance for an Elasticsearch client.
|
||||
|
||||
:param collection_name: Default name for the collection, defaults to None
|
||||
:type collection_name: Optional[str], optional
|
||||
:param dir: Path to the database directory, where the database is stored, defaults to None
|
||||
:type dir: Optional[str], optional
|
||||
:param es_url: Elasticsearch URL or list of node URLs to be used for connection, defaults to None
|
||||
:type es_url: Union[str, list[str]], optional
|
||||
:param cloud_id: cloud id of the elasticsearch cluster, defaults to None
|
||||
:type cloud_id: Optional[str], optional
|
||||
:param batch_size: Number of items to insert in one batch, defaults to 100
|
||||
:type batch_size: Optional[int], optional
|
||||
:param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch.
|
||||
:type ES_EXTRA_PARAMS: dict[str, Any], optional
|
||||
"""
|
||||
if es_url and cloud_id:
|
||||
raise ValueError("Only one of `es_url` and `cloud_id` can be set.")
|
||||
# self, es_url: Union[str, list[str]] = None, **ES_EXTRA_PARAMS: dict[str, any]):
|
||||
self.ES_URL = es_url or os.environ.get("ELASTICSEARCH_URL")
|
||||
self.CLOUD_ID = cloud_id or os.environ.get("ELASTICSEARCH_CLOUD_ID")
|
||||
if not self.ES_URL and not self.CLOUD_ID:
|
||||
raise AttributeError(
|
||||
"Elasticsearch needs a URL or CLOUD_ID attribute, "
|
||||
"this can either be passed to `ElasticsearchDBConfig` or as `ELASTICSEARCH_URL` or `ELASTICSEARCH_CLOUD_ID` in `.env`" # noqa: E501
|
||||
)
|
||||
self.ES_EXTRA_PARAMS = ES_EXTRA_PARAMS
|
||||
# Load API key from .env if it's not explicitly passed.
|
||||
# Can only set one of 'api_key', 'basic_auth', and 'bearer_auth'
|
||||
if (
|
||||
not self.ES_EXTRA_PARAMS.get("api_key")
|
||||
and not self.ES_EXTRA_PARAMS.get("basic_auth")
|
||||
and not self.ES_EXTRA_PARAMS.get("bearer_auth")
|
||||
):
|
||||
self.ES_EXTRA_PARAMS["api_key"] = os.environ.get("ELASTICSEARCH_API_KEY")
|
||||
|
||||
self.batch_size = batch_size
|
||||
super().__init__(collection_name=collection_name, dir=dir)
|
||||
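A sketch of the two connection styles supported above; exactly one of es_url or cloud_id may be set, and the API key falls back to ELASTICSEARCH_API_KEY when no auth kwarg is passed (all values below are placeholders):

import os

from embedchain.config.vector_db.elasticsearch import ElasticsearchDBConfig

# Direct URL with explicit basic auth passed through ES_EXTRA_PARAMS
local_cfg = ElasticsearchDBConfig(es_url="http://localhost:9200", basic_auth=("elastic", "changeme"))

# Cloud deployment; the API key is picked up from the environment
os.environ["ELASTICSEARCH_API_KEY"] = "es-placeholder"
cloud_cfg = ElasticsearchDBConfig(cloud_id="deployment:abc123")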
33
embedchain/embedchain/config/vector_db/lancedb.py
Normal file
33
embedchain/embedchain/config/vector_db/lancedb.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.vector_db.base import BaseVectorDbConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class LanceDBConfig(BaseVectorDbConfig):
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: Optional[str] = None,
|
||||
dir: Optional[str] = None,
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
allow_reset=True,
|
||||
):
|
||||
"""
|
||||
Initializes a configuration class instance for LanceDB.
|
||||
|
||||
:param collection_name: Default name for the collection, defaults to None
|
||||
:type collection_name: Optional[str], optional
|
||||
:param dir: Path to the database directory, where the database is stored, defaults to None
|
||||
:type dir: Optional[str], optional
|
||||
:param host: Database connection remote host. Use this if you run Embedchain as a client, defaults to None
|
||||
:type host: Optional[str], optional
|
||||
:param port: Database connection remote port. Use this if you run Embedchain as a client, defaults to None
|
||||
:type port: Optional[str], optional
|
||||
:param allow_reset: Resets the database, defaults to True
|
||||
:type allow_reset: bool
|
||||
"""
|
||||
|
||||
self.allow_reset = allow_reset
|
||||
super().__init__(collection_name=collection_name, dir=dir, host=host, port=port)
|
||||
41
embedchain/embedchain/config/vector_db/opensearch.py
Normal file
41
embedchain/embedchain/config/vector_db/opensearch.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.vector_db.base import BaseVectorDbConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class OpenSearchDBConfig(BaseVectorDbConfig):
|
||||
def __init__(
|
||||
self,
|
||||
opensearch_url: str,
|
||||
http_auth: tuple[str, str],
|
||||
vector_dimension: int = 1536,
|
||||
collection_name: Optional[str] = None,
|
||||
dir: Optional[str] = None,
|
||||
batch_size: Optional[int] = 100,
|
||||
**extra_params: dict[str, any],
|
||||
):
|
||||
"""
|
||||
Initializes a configuration class instance for an OpenSearch client.
|
||||
|
||||
:param collection_name: Default name for the collection, defaults to None
|
||||
:type collection_name: Optional[str], optional
|
||||
:param opensearch_url: URL of the OpenSearch domain
|
||||
:type opensearch_url: str, Eg, "http://localhost:9200"
|
||||
:param http_auth: Tuple of username and password
|
||||
:type http_auth: tuple[str, str], Eg, ("username", "password")
|
||||
:param vector_dimension: Dimension of the vector, defaults to 1536 (openai embedding model)
|
||||
:type vector_dimension: int, optional
|
||||
:param dir: Path to the database directory, where the database is stored, defaults to None
|
||||
:type dir: Optional[str], optional
|
||||
:param batch_size: Number of items to insert in one batch, defaults to 100
|
||||
:type batch_size: Optional[int], optional
|
||||
"""
|
||||
self.opensearch_url = opensearch_url
|
||||
self.http_auth = http_auth
|
||||
self.vector_dimension = vector_dimension
|
||||
self.extra_params = extra_params
|
||||
self.batch_size = batch_size
|
||||
|
||||
super().__init__(collection_name=collection_name, dir=dir)
|
||||
47
embedchain/embedchain/config/vector_db/pinecone.py
Normal file
47
embedchain/embedchain/config/vector_db/pinecone.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.vector_db.base import BaseVectorDbConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class PineconeDBConfig(BaseVectorDbConfig):
|
||||
def __init__(
|
||||
self,
|
||||
index_name: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
vector_dimension: int = 1536,
|
||||
metric: Optional[str] = "cosine",
|
||||
pod_config: Optional[dict[str, any]] = None,
|
||||
serverless_config: Optional[dict[str, any]] = None,
|
||||
hybrid_search: bool = False,
|
||||
bm25_encoder: any = None,
|
||||
batch_size: Optional[int] = 100,
|
||||
**extra_params: dict[str, any],
|
||||
):
|
||||
self.metric = metric
|
||||
self.api_key = api_key
|
||||
self.index_name = index_name
|
||||
self.vector_dimension = vector_dimension
|
||||
self.extra_params = extra_params
|
||||
self.hybrid_search = hybrid_search
|
||||
self.bm25_encoder = bm25_encoder
|
||||
self.batch_size = batch_size
|
||||
if pod_config is None and serverless_config is None:
|
||||
# If no config is provided, use the default pod spec config
|
||||
pod_environment = os.environ.get("PINECONE_ENV", "gcp-starter")
|
||||
self.pod_config = {"environment": pod_environment, "metadata_config": {"indexed": ["*"]}}
|
||||
else:
|
||||
self.pod_config = pod_config
|
||||
self.serverless_config = serverless_config
|
||||
|
||||
if self.pod_config and self.serverless_config:
|
||||
raise ValueError("Only one of pod_config or serverless_config can be provided.")
|
||||
|
||||
if self.hybrid_search and self.metric != "dotproduct":
|
||||
raise ValueError(
|
||||
"Hybrid search is only supported with dotproduct metric in Pinecone. See full docs here: https://docs.pinecone.io/docs/hybrid-search#limitations"
|
||||
) # noqa:E501
|
||||
|
||||
super().__init__(collection_name=self.index_name, dir=None)
|
||||
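A sketch of the serverless and hybrid-search paths in PineconeDBConfig; hybrid search requires the dotproduct metric, otherwise __init__ raises (all values are placeholders):

from embedchain.config.vector_db.pinecone import PineconeDBConfig

# Serverless index
serverless_cfg = PineconeDBConfig(
    index_name="my-index",
    api_key="pc-placeholder",
    serverless_config={"cloud": "aws", "region": "us-east-1"},
)

# Hybrid search must be paired with metric="dotproduct"
hybrid_cfg = PineconeDBConfig(index_name="hybrid-index", api_key="pc-placeholder", metric="dotproduct", hybrid_search=True)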
48
embedchain/embedchain/config/vector_db/qdrant.py
Normal file
48
embedchain/embedchain/config/vector_db/qdrant.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.vector_db.base import BaseVectorDbConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class QdrantDBConfig(BaseVectorDbConfig):
|
||||
"""
|
||||
Config to initialize a qdrant client.
|
||||
:param url: Qdrant URL or list of node URLs to be used for connection
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: Optional[str] = None,
|
||||
dir: Optional[str] = None,
|
||||
hnsw_config: Optional[dict[str, any]] = None,
|
||||
quantization_config: Optional[dict[str, any]] = None,
|
||||
on_disk: Optional[bool] = None,
|
||||
batch_size: Optional[int] = 10,
|
||||
**extra_params: dict[str, any],
|
||||
):
|
||||
"""
|
||||
Initializes a configuration class instance for a qdrant client.
|
||||
|
||||
:param collection_name: Default name for the collection, defaults to None
|
||||
:type collection_name: Optional[str], optional
|
||||
:param dir: Path to the database directory, where the database is stored, defaults to None
|
||||
:type dir: Optional[str], optional
|
||||
:param hnsw_config: Params for HNSW index
|
||||
:type hnsw_config: Optional[dict[str, any]], defaults to None
|
||||
:param quantization_config: Params for quantization, if None - quantization will be disabled
|
||||
:type quantization_config: Optional[dict[str, any]], defaults to None
|
||||
:param on_disk: If True, the point's payload will not be stored in memory.
|
||||
It will be read from the disk every time it is requested.
|
||||
This setting saves RAM by (slightly) increasing the response time.
|
||||
Note: payload values that are involved in filtering and are indexed remain in RAM.
|
||||
:type on_disk: bool, optional, defaults to None
|
||||
:param batch_size: Number of items to insert in one batch, defaults to 10
|
||||
:type batch_size: Optional[int], optional
|
||||
"""
|
||||
self.hnsw_config = hnsw_config
|
||||
self.quantization_config = quantization_config
|
||||
self.on_disk = on_disk
|
||||
self.batch_size = batch_size
|
||||
self.extra_params = extra_params
|
||||
super().__init__(collection_name=collection_name, dir=dir)
|
||||
18
embedchain/embedchain/config/vector_db/weaviate.py
Normal file
18
embedchain/embedchain/config/vector_db/weaviate.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.vector_db.base import BaseVectorDbConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class WeaviateDBConfig(BaseVectorDbConfig):
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: Optional[str] = None,
|
||||
dir: Optional[str] = None,
|
||||
batch_size: Optional[int] = 100,
|
||||
**extra_params: dict[str, any],
|
||||
):
|
||||
self.batch_size = batch_size
|
||||
self.extra_params = extra_params
|
||||
super().__init__(collection_name=collection_name, dir=dir)
|
||||
49
embedchain/embedchain/config/vector_db/zilliz.py
Normal file
49
embedchain/embedchain/config/vector_db/zilliz.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.vector_db.base import BaseVectorDbConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class ZillizDBConfig(BaseVectorDbConfig):
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: Optional[str] = None,
|
||||
dir: Optional[str] = None,
|
||||
uri: Optional[str] = None,
|
||||
token: Optional[str] = None,
|
||||
vector_dim: Optional[str] = None,
|
||||
metric_type: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initializes a configuration class instance for the vector database.
|
||||
|
||||
:param collection_name: Default name for the collection, defaults to None
|
||||
:type collection_name: Optional[str], optional
|
||||
:param dir: Path to the database directory, where the database is stored, defaults to "db"
|
||||
:type dir: str, optional
|
||||
:param uri: Cluster endpoint obtained from the Zilliz Console, defaults to None
|
||||
:type uri: Optional[str], optional
|
||||
:param token: API Key, if a Serverless Cluster, username:password, if a Dedicated Cluster, defaults to None
|
||||
:type token: Optional[str], optional
|
||||
"""
|
||||
self.uri = uri or os.environ.get("ZILLIZ_CLOUD_URI")
|
||||
if not self.uri:
|
||||
raise AttributeError(
|
||||
"Zilliz needs a URI attribute, "
|
||||
"this can either be passed to `ZILLIZ_CLOUD_URI` or as `ZILLIZ_CLOUD_URI` in `.env`"
|
||||
)
|
||||
|
||||
self.token = token or os.environ.get("ZILLIZ_CLOUD_TOKEN")
|
||||
if not self.token:
|
||||
raise AttributeError(
|
||||
"Zilliz needs a token attribute, "
|
||||
"this can either be passed to `ZILLIZ_CLOUD_TOKEN` or as `ZILLIZ_CLOUD_TOKEN` in `.env`,"
|
||||
"if having a username and password, pass it in the form 'username:password' to `ZILLIZ_CLOUD_TOKEN`"
|
||||
)
|
||||
|
||||
self.metric_type = metric_type if metric_type else "L2"
|
||||
|
||||
self.vector_dim = vector_dim
|
||||
super().__init__(collection_name=collection_name, dir=dir)
|
||||
0
embedchain/embedchain/config/vectordb/__init__.py
Normal file
0
embedchain/embedchain/config/vectordb/__init__.py
Normal file
11
embedchain/embedchain/constants.py
Normal file
11
embedchain/embedchain/constants.py
Normal file
@@ -0,0 +1,11 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
ABS_PATH = os.getcwd()
|
||||
HOME_DIR = os.environ.get("EMBEDCHAIN_CONFIG_DIR", str(Path.home()))
|
||||
CONFIG_DIR = os.path.join(HOME_DIR, ".embedchain")
|
||||
CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
|
||||
SQLITE_PATH = os.path.join(CONFIG_DIR, "embedchain.db")
|
||||
|
||||
# Set the environment variable for the database URI
|
||||
os.environ.setdefault("EMBEDCHAIN_DB_URI", f"sqlite:///{SQLITE_PATH}")
|
||||
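Since HOME_DIR honors EMBEDCHAIN_CONFIG_DIR and the database URI is only set if absent, the storage location can be redirected before the package is imported; a sketch:

import os

# Must be set before embedchain.constants is imported
os.environ["EMBEDCHAIN_CONFIG_DIR"] = "/tmp/embedchain-test"

from embedchain import constants

print(constants.CONFIG_DIR)             # /tmp/embedchain-test/.embedchain
print(os.environ["EMBEDCHAIN_DB_URI"])  # sqlite:////tmp/embedchain-test/.embedchain/embedchain.db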
0
embedchain/embedchain/core/__init__.py
Normal file
0
embedchain/embedchain/core/__init__.py
Normal file
1
embedchain/embedchain/data_formatter/__init__.py
Normal file
1
embedchain/embedchain/data_formatter/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .data_formatter import DataFormatter # noqa: F401
|
||||
148
embedchain/embedchain/data_formatter/data_formatter.py
Normal file
148
embedchain/embedchain/data_formatter/data_formatter.py
Normal file
@@ -0,0 +1,148 @@
|
||||
from importlib import import_module
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config import AddConfig
|
||||
from embedchain.config.add_config import ChunkerConfig, LoaderConfig
|
||||
from embedchain.helpers.json_serializable import JSONSerializable
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.models.data_type import DataType
|
||||
|
||||
|
||||
class DataFormatter(JSONSerializable):
|
||||
"""
|
||||
DataFormatter is an internal utility class which maps the appropriate
|
||||
loader and chunker to the data_type passed by the user in their
|
||||
.add or .add_local method call
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_type: DataType,
|
||||
config: AddConfig,
|
||||
loader: Optional[BaseLoader] = None,
|
||||
chunker: Optional[BaseChunker] = None,
|
||||
):
|
||||
"""
|
||||
Initialize a dataformatter, set data type and chunker based on datatype.
|
||||
|
||||
:param data_type: The type of the data to load and chunk.
|
||||
:type data_type: DataType
|
||||
:param config: AddConfig instance with nested loader and chunker config attributes.
|
||||
:type config: AddConfig
|
||||
"""
|
||||
self.loader = self._get_loader(data_type=data_type, config=config.loader, loader=loader)
|
||||
self.chunker = self._get_chunker(data_type=data_type, config=config.chunker, chunker=chunker)
|
||||
|
||||
@staticmethod
|
||||
def _lazy_load(module_path: str):
|
||||
module_path, class_name = module_path.rsplit(".", 1)
|
||||
module = import_module(module_path)
|
||||
return getattr(module, class_name)
|
||||
|
||||
def _get_loader(self, data_type: DataType, config: LoaderConfig, loader: Optional[BaseLoader]) -> BaseLoader:
|
||||
"""
|
||||
Returns the appropriate data loader for the given data type.
|
||||
|
||||
:param data_type: The type of the data to load.
|
||||
:type data_type: DataType
|
||||
:param config: Config to initialize the loader with.
|
||||
:type config: LoaderConfig
|
||||
:raises ValueError: If an unsupported data type is provided.
|
||||
:return: The loader for the given data type.
|
||||
:rtype: BaseLoader
|
||||
"""
|
||||
loaders = {
|
||||
DataType.YOUTUBE_VIDEO: "embedchain.loaders.youtube_video.YoutubeVideoLoader",
|
||||
DataType.PDF_FILE: "embedchain.loaders.pdf_file.PdfFileLoader",
|
||||
DataType.WEB_PAGE: "embedchain.loaders.web_page.WebPageLoader",
|
||||
DataType.QNA_PAIR: "embedchain.loaders.local_qna_pair.LocalQnaPairLoader",
|
||||
DataType.TEXT: "embedchain.loaders.local_text.LocalTextLoader",
|
||||
DataType.DOCX: "embedchain.loaders.docx_file.DocxFileLoader",
|
||||
DataType.SITEMAP: "embedchain.loaders.sitemap.SitemapLoader",
|
||||
DataType.XML: "embedchain.loaders.xml.XmlLoader",
|
||||
DataType.DOCS_SITE: "embedchain.loaders.docs_site_loader.DocsSiteLoader",
|
||||
DataType.CSV: "embedchain.loaders.csv.CsvLoader",
|
||||
DataType.MDX: "embedchain.loaders.mdx.MdxLoader",
|
||||
DataType.IMAGE: "embedchain.loaders.image.ImageLoader",
|
||||
DataType.UNSTRUCTURED: "embedchain.loaders.unstructured_file.UnstructuredLoader",
|
||||
DataType.JSON: "embedchain.loaders.json.JSONLoader",
|
||||
DataType.OPENAPI: "embedchain.loaders.openapi.OpenAPILoader",
|
||||
DataType.GMAIL: "embedchain.loaders.gmail.GmailLoader",
|
||||
DataType.NOTION: "embedchain.loaders.notion.NotionLoader",
|
||||
DataType.SUBSTACK: "embedchain.loaders.substack.SubstackLoader",
|
||||
DataType.YOUTUBE_CHANNEL: "embedchain.loaders.youtube_channel.YoutubeChannelLoader",
|
||||
DataType.DISCORD: "embedchain.loaders.discord.DiscordLoader",
|
||||
DataType.RSSFEED: "embedchain.loaders.rss_feed.RSSFeedLoader",
|
||||
DataType.BEEHIIV: "embedchain.loaders.beehiiv.BeehiivLoader",
|
||||
DataType.GOOGLE_DRIVE: "embedchain.loaders.google_drive.GoogleDriveLoader",
|
||||
DataType.DIRECTORY: "embedchain.loaders.directory_loader.DirectoryLoader",
|
||||
DataType.SLACK: "embedchain.loaders.slack.SlackLoader",
|
||||
DataType.DROPBOX: "embedchain.loaders.dropbox.DropboxLoader",
|
||||
DataType.TEXT_FILE: "embedchain.loaders.text_file.TextFileLoader",
|
||||
DataType.EXCEL_FILE: "embedchain.loaders.excel_file.ExcelFileLoader",
|
||||
DataType.AUDIO: "embedchain.loaders.audio.AudioLoader",
|
||||
}
|
||||
|
||||
if data_type == DataType.CUSTOM or loader is not None:
|
||||
loader_class: type = loader
|
||||
if loader_class:
|
||||
return loader_class
|
||||
elif data_type in loaders:
|
||||
loader_class: type = self._lazy_load(loaders[data_type])
|
||||
return loader_class()
|
||||
|
||||
raise ValueError(
|
||||
f"Cant find the loader for {data_type}.\
|
||||
We recommend passing a loader for data_type: {data_type},\
|
||||
check `https://docs.embedchain.ai/data-sources/overview`."
|
||||
)
|
||||
|
||||
def _get_chunker(self, data_type: DataType, config: ChunkerConfig, chunker: Optional[BaseChunker]) -> BaseChunker:
|
||||
"""Returns the appropriate chunker for the given data type (updated for lazy loading)."""
|
||||
chunker_classes = {
|
||||
DataType.YOUTUBE_VIDEO: "embedchain.chunkers.youtube_video.YoutubeVideoChunker",
|
||||
DataType.PDF_FILE: "embedchain.chunkers.pdf_file.PdfFileChunker",
|
||||
DataType.WEB_PAGE: "embedchain.chunkers.web_page.WebPageChunker",
|
||||
DataType.QNA_PAIR: "embedchain.chunkers.qna_pair.QnaPairChunker",
|
||||
DataType.TEXT: "embedchain.chunkers.text.TextChunker",
|
||||
DataType.DOCX: "embedchain.chunkers.docx_file.DocxFileChunker",
|
||||
DataType.SITEMAP: "embedchain.chunkers.sitemap.SitemapChunker",
|
||||
DataType.XML: "embedchain.chunkers.xml.XmlChunker",
|
||||
DataType.DOCS_SITE: "embedchain.chunkers.docs_site.DocsSiteChunker",
|
||||
DataType.CSV: "embedchain.chunkers.table.TableChunker",
|
||||
DataType.MDX: "embedchain.chunkers.mdx.MdxChunker",
|
||||
DataType.IMAGE: "embedchain.chunkers.image.ImageChunker",
|
||||
DataType.UNSTRUCTURED: "embedchain.chunkers.unstructured_file.UnstructuredFileChunker",
|
||||
DataType.JSON: "embedchain.chunkers.json.JSONChunker",
|
||||
DataType.OPENAPI: "embedchain.chunkers.openapi.OpenAPIChunker",
|
||||
DataType.GMAIL: "embedchain.chunkers.gmail.GmailChunker",
|
||||
DataType.NOTION: "embedchain.chunkers.notion.NotionChunker",
|
||||
DataType.SUBSTACK: "embedchain.chunkers.substack.SubstackChunker",
|
||||
DataType.YOUTUBE_CHANNEL: "embedchain.chunkers.common_chunker.CommonChunker",
|
||||
DataType.DISCORD: "embedchain.chunkers.common_chunker.CommonChunker",
|
||||
DataType.CUSTOM: "embedchain.chunkers.common_chunker.CommonChunker",
|
||||
DataType.RSSFEED: "embedchain.chunkers.rss_feed.RSSFeedChunker",
|
||||
DataType.BEEHIIV: "embedchain.chunkers.beehiiv.BeehiivChunker",
|
||||
DataType.GOOGLE_DRIVE: "embedchain.chunkers.google_drive.GoogleDriveChunker",
|
||||
DataType.DIRECTORY: "embedchain.chunkers.common_chunker.CommonChunker",
|
||||
DataType.SLACK: "embedchain.chunkers.common_chunker.CommonChunker",
|
||||
DataType.DROPBOX: "embedchain.chunkers.common_chunker.CommonChunker",
|
||||
DataType.TEXT_FILE: "embedchain.chunkers.common_chunker.CommonChunker",
|
||||
DataType.EXCEL_FILE: "embedchain.chunkers.excel_file.ExcelFileChunker",
|
||||
DataType.AUDIO: "embedchain.chunkers.audio.AudioChunker",
|
||||
}
|
||||
|
||||
if chunker is not None:
|
||||
return chunker
|
||||
elif data_type in chunker_classes:
|
||||
chunker_class = self._lazy_load(chunker_classes[data_type])
|
||||
chunker = chunker_class(config)
|
||||
chunker.set_data_type(data_type)
|
||||
return chunker
|
||||
|
||||
raise ValueError(
|
||||
f"Cant find the chunker for {data_type}.\
|
||||
We recommend passing a chunker for data_type: {data_type},\
|
||||
check `https://docs.embedchain.ai/data-sources/overview`."
|
||||
)
|
||||
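A sketch of the override path in DataFormatter: passing a loader instance together with DataType.CUSTOM bypasses the lazy lookup table (MyLoader below is a hypothetical stand-in, not an embedchain class):

from embedchain.config import AddConfig
from embedchain.data_formatter import DataFormatter
from embedchain.loaders.base_loader import BaseLoader
from embedchain.models.data_type import DataType


class MyLoader(BaseLoader):
    def load_data(self, url):
        # Return the shape expected by the chunkers: a doc_id plus a list of content/meta_data pairs
        return {"doc_id": "1", "data": [{"content": "hello world", "meta_data": {"url": url}}]}


formatter = DataFormatter(data_type=DataType.CUSTOM, config=AddConfig(), loader=MyLoader())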
1
embedchain/embedchain/deployment/fly.io/.dockerignore
Normal file
1
embedchain/embedchain/deployment/fly.io/.dockerignore
Normal file
@@ -0,0 +1 @@
|
||||
db/
|
||||
1
embedchain/embedchain/deployment/fly.io/.env.example
Normal file
1
embedchain/embedchain/deployment/fly.io/.env.example
Normal file
@@ -0,0 +1 @@
|
||||
OPENAI_API_KEY=sk-xxx
|
||||
13
embedchain/embedchain/deployment/fly.io/Dockerfile
Normal file
13
embedchain/embedchain/deployment/fly.io/Dockerfile
Normal file
@@ -0,0 +1,13 @@
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY requirements.txt /app/
|
||||
|
||||
RUN pip install -r requirements.txt
|
||||
|
||||
COPY . /app
|
||||
|
||||
EXPOSE 8080
|
||||
|
||||
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8080"]
|
||||
56
embedchain/embedchain/deployment/fly.io/app.py
Normal file
56
embedchain/embedchain/deployment/fly.io/app.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import FastAPI, responses
|
||||
from pydantic import BaseModel
|
||||
|
||||
from embedchain import App
|
||||
|
||||
load_dotenv(".env")
|
||||
|
||||
app = FastAPI(title="Embedchain FastAPI App")
|
||||
embedchain_app = App()
|
||||
|
||||
|
||||
class SourceModel(BaseModel):
|
||||
source: str
|
||||
|
||||
|
||||
class QuestionModel(BaseModel):
|
||||
question: str
|
||||
|
||||
|
||||
@app.post("/add")
|
||||
async def add_source(source_model: SourceModel):
|
||||
"""
|
||||
Adds a new source to the EmbedChain app.
|
||||
Expects a JSON with a "source" key.
|
||||
"""
|
||||
source = source_model.source
|
||||
embedchain_app.add(source)
|
||||
return {"message": f"Source '{source}' added successfully."}
|
||||
|
||||
|
||||
@app.post("/query")
|
||||
async def handle_query(question_model: QuestionModel):
|
||||
"""
|
||||
Handles a query to the EmbedChain app.
|
||||
Expects a JSON with a "question" key.
|
||||
"""
|
||||
question = question_model.question
|
||||
answer = embedchain_app.query(question)
|
||||
return {"answer": answer}
|
||||
|
||||
|
||||
@app.post("/chat")
|
||||
async def handle_chat(question_model: QuestionModel):
|
||||
"""
|
||||
Handles a chat request to the EmbedChain app.
|
||||
Expects a JSON with a "question" key.
|
||||
"""
|
||||
question = question_model.question
|
||||
response = embedchain_app.chat(question)
|
||||
return {"response": response}
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return responses.RedirectResponse(url="/docs")
|
||||
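A client-side sketch for exercising the endpoints above once the service is running on port 8080 (uses the requests library; the URL and payloads are illustrative):

import requests

BASE_URL = "http://localhost:8080"

# Add a source to the knowledge base
requests.post(f"{BASE_URL}/add", json={"source": "https://example.com/article"})

# Ask a question against the ingested data
resp = requests.post(f"{BASE_URL}/query", json={"question": "What is the article about?"})
print(resp.json()["answer"])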
4
embedchain/embedchain/deployment/fly.io/requirements.txt
Normal file
4
embedchain/embedchain/deployment/fly.io/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
fastapi==0.104.0
|
||||
uvicorn==0.23.2
|
||||
embedchain
|
||||
beautifulsoup4
|
||||
18
embedchain/embedchain/deployment/gradio.app/app.py
Normal file
18
embedchain/embedchain/deployment/gradio.app/app.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import os
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from embedchain import App
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "sk-xxx"
|
||||
|
||||
app = App()
|
||||
|
||||
|
||||
def query(message, history):
|
||||
return app.chat(message)
|
||||
|
||||
|
||||
demo = gr.ChatInterface(query)
|
||||
|
||||
demo.launch()
|
||||
@@ -0,0 +1,2 @@
|
||||
gradio==4.11.0
|
||||
embedchain
|
||||
1
embedchain/embedchain/deployment/modal.com/.env.example
Normal file
1
embedchain/embedchain/deployment/modal.com/.env.example
Normal file
@@ -0,0 +1 @@
|
||||
OPENAI_API_KEY=sk-xxx
|
||||
1
embedchain/embedchain/deployment/modal.com/.gitignore
vendored
Normal file
1
embedchain/embedchain/deployment/modal.com/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
.env
|
||||
86
embedchain/embedchain/deployment/modal.com/app.py
Normal file
86
embedchain/embedchain/deployment/modal.com/app.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import Body, FastAPI, responses
|
||||
from modal import Image, Secret, Stub, asgi_app
|
||||
|
||||
from embedchain import App
|
||||
|
||||
load_dotenv(".env")
|
||||
|
||||
image = Image.debian_slim().pip_install(
|
||||
"embedchain",
|
||||
"lanchain_community==0.2.6",
|
||||
"youtube-transcript-api==0.6.1",
|
||||
"pytube==15.0.0",
|
||||
"beautifulsoup4==4.12.3",
|
||||
"slack-sdk==3.21.3",
|
||||
"huggingface_hub==0.23.0",
|
||||
"gitpython==3.1.38",
|
||||
"yt_dlp==2023.11.14",
|
||||
"PyGithub==1.59.1",
|
||||
"feedparser==6.0.10",
|
||||
"newspaper3k==0.2.8",
|
||||
"listparser==0.19",
|
||||
)
|
||||
|
||||
stub = Stub(
|
||||
name="embedchain-app",
|
||||
image=image,
|
||||
secrets=[Secret.from_dotenv(".env")],
|
||||
)
|
||||
|
||||
web_app = FastAPI()
|
||||
embedchain_app = App(name="embedchain-modal-app")
|
||||
|
||||
|
||||
@web_app.post("/add")
|
||||
async def add(
|
||||
source: str = Body(..., description="Source to be added"),
|
||||
data_type: str | None = Body(None, description="Type of the data source"),
|
||||
):
|
||||
"""
|
||||
Adds a new source to the EmbedChain app.
|
||||
Expects a JSON with a "source" and "data_type" key.
|
||||
"data_type" is optional.
|
||||
"""
|
||||
if source and data_type:
|
||||
embedchain_app.add(source, data_type)
|
||||
elif source:
|
||||
embedchain_app.add(source)
|
||||
else:
|
||||
return {"message": "No source provided."}
|
||||
return {"message": f"Source '{source}' added successfully."}
|
||||
|
||||
|
||||
@web_app.post("/query")
|
||||
async def query(question: str = Body(..., description="Question to be answered")):
|
||||
"""
|
||||
Handles a query to the EmbedChain app.
|
||||
Expects a JSON with a "question" key.
|
||||
"""
|
||||
if not question:
|
||||
return {"message": "No question provided."}
|
||||
answer = embedchain_app.query(question)
|
||||
return {"answer": answer}
|
||||
|
||||
|
||||
@web_app.get("/chat")
|
||||
async def chat(question: str = Body(..., description="Question to be answered")):
|
||||
"""
|
||||
Handles a chat request to the EmbedChain app.
|
||||
Expects a JSON with a "question" key.
|
||||
"""
|
||||
if not question:
|
||||
return {"message": "No question provided."}
|
||||
response = embedchain_app.chat(question)
|
||||
return {"response": response}
|
||||
|
||||
|
||||
@web_app.get("/")
|
||||
async def root():
|
||||
return responses.RedirectResponse(url="/docs")
|
||||
|
||||
|
||||
@stub.function(image=image)
|
||||
@asgi_app()
|
||||
def fastapi_app():
|
||||
return web_app
|
||||
@@ -0,0 +1,4 @@
|
||||
modal==0.56.4329
|
||||
fastapi==0.104.0
|
||||
uvicorn==0.23.2
|
||||
embedchain
|
||||
1
embedchain/embedchain/deployment/render.com/.env.example
Normal file
1
embedchain/embedchain/deployment/render.com/.env.example
Normal file
@@ -0,0 +1 @@
|
||||
OPENAI_API_KEY=sk-xxx
|
||||
1
embedchain/embedchain/deployment/render.com/.gitignore
vendored
Normal file
1
embedchain/embedchain/deployment/render.com/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
.env
|
||||
53
embedchain/embedchain/deployment/render.com/app.py
Normal file
53
embedchain/embedchain/deployment/render.com/app.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from fastapi import FastAPI, responses
|
||||
from pydantic import BaseModel
|
||||
|
||||
from embedchain import App
|
||||
|
||||
app = FastAPI(title="Embedchain FastAPI App")
|
||||
embedchain_app = App()
|
||||
|
||||
|
||||
class SourceModel(BaseModel):
|
||||
source: str
|
||||
|
||||
|
||||
class QuestionModel(BaseModel):
|
||||
question: str
|
||||
|
||||
|
||||
@app.post("/add")
|
||||
async def add_source(source_model: SourceModel):
|
||||
"""
|
||||
Adds a new source to the EmbedChain app.
|
||||
Expects a JSON with a "source" key.
|
||||
"""
|
||||
source = source_model.source
|
||||
embedchain_app.add(source)
|
||||
return {"message": f"Source '{source}' added successfully."}
|
||||
|
||||
|
||||
@app.post("/query")
|
||||
async def handle_query(question_model: QuestionModel):
|
||||
"""
|
||||
Handles a query to the EmbedChain app.
|
||||
Expects a JSON with a "question" key.
|
||||
"""
|
||||
question = question_model.question
|
||||
answer = embedchain_app.query(question)
|
||||
return {"answer": answer}
|
||||
|
||||
|
||||
@app.post("/chat")
|
||||
async def handle_chat(question_model: QuestionModel):
|
||||
"""
|
||||
Handles a chat request to the EmbedChain app.
|
||||
Expects a JSON with a "question" key.
|
||||
"""
|
||||
question = question_model.question
|
||||
response = embedchain_app.chat(question)
|
||||
return {"response": response}
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return responses.RedirectResponse(url="/docs")
|
||||
16
embedchain/embedchain/deployment/render.com/render.yaml
Normal file
16
embedchain/embedchain/deployment/render.com/render.yaml
Normal file
@@ -0,0 +1,16 @@
|
||||
services:
|
||||
- type: web
|
||||
name: ec-render-app
|
||||
runtime: python
|
||||
repo: https://github.com/<your-username>/<repo-name>
|
||||
scaling:
|
||||
minInstances: 1
|
||||
maxInstances: 3
|
||||
targetMemoryPercent: 60 # optional if targetCPUPercent is set
|
||||
targetCPUPercent: 60 # optional if targetMemory is set
|
||||
buildCommand: pip install -r requirements.txt
|
||||
startCommand: uvicorn app:app --host 0.0.0.0
|
||||
envVars:
|
||||
- key: OPENAI_API_KEY
|
||||
value: sk-xxx
|
||||
autoDeploy: false # optional
|
||||
@@ -0,0 +1,4 @@
|
||||
fastapi==0.104.0
|
||||
uvicorn==0.23.2
|
||||
embedchain
|
||||
beautifulsoup4
|
||||
@@ -0,0 +1 @@
|
||||
OPENAI_API_KEY="sk-xxx"
|
||||
59
embedchain/embedchain/deployment/streamlit.io/app.py
Normal file
59
embedchain/embedchain/deployment/streamlit.io/app.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import streamlit as st
|
||||
|
||||
from embedchain import App
|
||||
|
||||
|
||||
@st.cache_resource
|
||||
def embedchain_bot():
|
||||
return App()
|
||||
|
||||
|
||||
st.title("💬 Chatbot")
|
||||
st.caption("🚀 An Embedchain app powered by OpenAI!")
|
||||
if "messages" not in st.session_state:
|
||||
st.session_state.messages = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": """
|
||||
Hi! I'm a chatbot. I can answer questions and learn new things!\n
|
||||
Ask me anything and if you want me to learn something do `/add <source>`.\n
|
||||
I can learn almost anything. :)
|
||||
""",
|
||||
}
|
||||
]
|
||||
|
||||
for message in st.session_state.messages:
|
||||
with st.chat_message(message["role"]):
|
||||
st.markdown(message["content"])
|
||||
|
||||
if prompt := st.chat_input("Ask me anything!"):
|
||||
app = embedchain_bot()
|
||||
|
||||
if prompt.startswith("/add"):
|
||||
with st.chat_message("user"):
|
||||
st.markdown(prompt)
|
||||
st.session_state.messages.append({"role": "user", "content": prompt})
|
||||
prompt = prompt.replace("/add", "").strip()
|
||||
with st.chat_message("assistant"):
|
||||
message_placeholder = st.empty()
|
||||
message_placeholder.markdown("Adding to knowledge base...")
|
||||
app.add(prompt)
|
||||
message_placeholder.markdown(f"Added {prompt} to knowledge base!")
|
||||
st.session_state.messages.append({"role": "assistant", "content": f"Added {prompt} to knowledge base!"})
|
||||
st.stop()
|
||||
|
||||
with st.chat_message("user"):
|
||||
st.markdown(prompt)
|
||||
st.session_state.messages.append({"role": "user", "content": prompt})
|
||||
|
||||
with st.chat_message("assistant"):
|
||||
msg_placeholder = st.empty()
|
||||
msg_placeholder.markdown("Thinking...")
|
||||
full_response = ""
|
||||
|
||||
for response in app.chat(prompt):
|
||||
msg_placeholder.empty()
|
||||
full_response += response
|
||||
|
||||
msg_placeholder.markdown(full_response)
|
||||
st.session_state.messages.append({"role": "assistant", "content": full_response})
|
||||
@@ -0,0 +1,2 @@
|
||||
streamlit==1.29.0
|
||||
embedchain
|
||||
776
embedchain/embedchain/embedchain.py
Normal file
776
embedchain/embedchain/embedchain.py
Normal file
@@ -0,0 +1,776 @@
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
from embedchain.cache import adapt, get_gptcache_session, gptcache_data_convert, gptcache_update_cache_callback
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
|
||||
from embedchain.config.base_app_config import BaseAppConfig
|
||||
from embedchain.core.db.models import ChatHistory, DataSource
|
||||
from embedchain.data_formatter import DataFormatter
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.helpers.json_serializable import JSONSerializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
|
||||
from embedchain.utils.misc import detect_datatype, is_valid_json_string
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmbedChain(JSONSerializable):
|
||||
def __init__(
|
||||
self,
|
||||
config: BaseAppConfig,
|
||||
llm: BaseLlm,
|
||||
db: BaseVectorDB = None,
|
||||
embedder: BaseEmbedder = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initializes the EmbedChain instance, sets up a vector DB client and
|
||||
creates a collection.
|
||||
|
||||
:param config: Configuration just for the app, not the db or llm or embedder.
|
||||
:type config: BaseAppConfig
|
||||
:param llm: Instance of the LLM you want to use.
|
||||
:type llm: BaseLlm
|
||||
:param db: Instance of the Database to use, defaults to None
|
||||
:type db: BaseVectorDB, optional
|
||||
:param embedder: instance of the embedder to use, defaults to None
|
||||
:type embedder: BaseEmbedder, optional
|
||||
:param system_prompt: System prompt to use in the llm query, defaults to None
|
||||
:type system_prompt: Optional[str], optional
|
||||
:raises ValueError: No database or embedder provided.
|
||||
"""
|
||||
self.config = config
|
||||
self.cache_config = None
|
||||
self.memory_config = None
|
||||
self.mem0_client = None
|
||||
# Llm
|
||||
self.llm = llm
|
||||
# Database has support for config assignment for backwards compatibility
|
||||
if db is None and (not hasattr(self.config, "db") or self.config.db is None):
|
||||
raise ValueError("App requires Database.")
|
||||
self.db = db or self.config.db
|
||||
# Embedder
|
||||
if embedder is None:
|
||||
raise ValueError("App requires Embedder.")
|
||||
self.embedder = embedder
|
||||
|
||||
# Initialize database
|
||||
self.db._set_embedder(self.embedder)
|
||||
self.db._initialize()
|
||||
# Set collection name from app config for backwards compatibility.
|
||||
if config.collection_name:
|
||||
self.db.set_collection_name(config.collection_name)
|
||||
|
||||
# Add variables that are "shortcuts"
|
||||
if system_prompt:
|
||||
self.llm.config.system_prompt = system_prompt
|
||||
|
||||
# Fetch the history from the database if exists
|
||||
self.llm.update_history(app_id=self.config.id)
|
||||
|
||||
# Attributes that aren't subclass related.
|
||||
self.user_asks = []
|
||||
|
||||
self.chunker: Optional[ChunkerConfig] = None
|
||||
|
||||
@property
|
||||
def collect_metrics(self):
|
||||
return self.config.collect_metrics
|
||||
|
||||
@collect_metrics.setter
|
||||
def collect_metrics(self, value):
|
||||
if not isinstance(value, bool):
|
||||
raise ValueError(f"Boolean value expected but got {type(value)}.")
|
||||
self.config.collect_metrics = value
|
||||
|
||||
@property
|
||||
def online(self):
|
||||
return self.llm.config.online
|
||||
|
||||
@online.setter
|
||||
def online(self, value):
|
||||
if not isinstance(value, bool):
|
||||
raise ValueError(f"Boolean value expected but got {type(value)}.")
|
||||
self.llm.config.online = value
|
||||
|
||||
    def add(
        self,
        source: Any,
        data_type: Optional[DataType] = None,
        metadata: Optional[dict[str, Any]] = None,
        config: Optional[AddConfig] = None,
        dry_run=False,
        loader: Optional[BaseLoader] = None,
        chunker: Optional[BaseChunker] = None,
        **kwargs: Optional[dict[str, Any]],
    ):
        """
        Adds the data from the given URL to the vector db.
        Loads the data, chunks it, creates an embedding for each chunk,
        and then stores the embeddings in the vector database.

        :param source: The data to embed, can be a URL, local file or raw content, depending on the data type.
        :type source: Any
        :param data_type: Automatically detected, but can be forced with this argument. The type of the data to add,
            defaults to None
        :type data_type: Optional[DataType], optional
        :param metadata: Metadata associated with the data source, defaults to None
        :type metadata: Optional[dict[str, Any]], optional
        :param config: The `AddConfig` instance to use as configuration options, defaults to None
        :type config: Optional[AddConfig], optional
        :raises ValueError: Invalid data type
        :param dry_run: Optional. A dry run displays the chunks to ensure that the loader and chunker work as intended,
            defaults to False
        :type dry_run: bool
        :param loader: The loader to use to load the data, defaults to None
        :type loader: BaseLoader, optional
        :param chunker: The chunker to use to chunk the data, defaults to None
        :type chunker: BaseChunker, optional
        :param kwargs: Extra keyword arguments forwarded to the underlying vector database `add` call
        :type kwargs: dict[str, Any]
        :return: source_hash, an MD5 hash of the source, in hexadecimal representation.
        :rtype: str
        """
        if config is not None:
            pass
        elif self.chunker is not None:
            config = AddConfig(chunker=self.chunker)
        else:
            config = AddConfig()

        try:
            DataType(source)
            logger.warning(
                f"""Starting from version v0.0.40, Embedchain can automatically detect the data type. So, in the `add` method, the argument order has changed. You no longer need to specify '{source}' for the `source` argument. So the code snippet will be `.add("{data_type}", "{source}")`"""  # noqa #E501
            )
            logger.warning(
                "Embedchain is swapping the arguments for you. This functionality might be deprecated in the future, so please adjust your code."  # noqa #E501
            )
            source, data_type = data_type, source
        except ValueError:
            pass

        if data_type:
            try:
                data_type = DataType(data_type)
            except ValueError:
                logger.info(
                    f"Invalid data_type: '{data_type}', using `custom` instead.\n Check docs to pass the valid data type: `https://docs.embedchain.ai/data-sources/overview`"  # noqa: E501
                )
                data_type = DataType.CUSTOM

        if not data_type:
            data_type = detect_datatype(source)

        # `source_hash` is the md5 hash of the source argument
        source_hash = hashlib.md5(str(source).encode("utf-8")).hexdigest()

        self.user_asks.append([source, data_type.value, metadata])

        data_formatter = DataFormatter(data_type, config, loader, chunker)
        documents, metadatas, _ids, new_chunks = self._load_and_embed(
            data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, config, dry_run, **kwargs
        )
        if data_type in {DataType.DOCS_SITE}:
            self.is_docs_site_instance = True

        # Convert the source to a string if it is not already
        if not isinstance(source, str):
            source = str(source)

        # Insert the data into the 'ec_data_sources' table
        self.db_session.add(
            DataSource(
                hash=source_hash,
                app_id=self.config.id,
                type=data_type.value,
                value=source,
                metadata=json.dumps(metadata),
            )
        )
        try:
            self.db_session.commit()
        except Exception as e:
            logger.error(f"Error adding data source: {e}")
            self.db_session.rollback()

        if dry_run:
            data_chunks_info = {"chunks": documents, "metadata": metadatas, "count": len(documents), "type": data_type}
            logger.debug(f"Dry run info : {data_chunks_info}")
            return data_chunks_info

        # Send anonymous telemetry
        if self.config.collect_metrics:
            # it's quicker to check the variable twice than to count words when they won't be submitted.
            word_count = data_formatter.chunker.get_word_count(documents)

            # Send anonymous telemetry
            event_properties = {
                **self._telemetry_props,
                "data_type": data_type.value,
                "word_count": word_count,
                "chunks_count": new_chunks,
            }
            self.telemetry.capture(event_name="add", properties=event_properties)

        return source_hash

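For orientation, a minimal usage sketch of the `add` flow above (illustrative only, not part of this commit; assumes a default-configured `App` and network access to the example URL):

# Example (illustrative): add a source, keep the returned hash, preview chunking with dry_run
from embedchain import App

app = App()
source_hash = app.add("https://www.forbes.com/profile/elon-musk", metadata={"topic": "bio"})
preview = app.add("Some raw text to preview chunking.", data_type="text", dry_run=True)
print(source_hash, preview["count"])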
    def _get_existing_doc_id(self, chunker: BaseChunker, src: Any):
        """
        Get id of existing document for a given source, based on the data type
        """
        # Find existing embeddings for the source
        # Depending on the data type, existing embeddings are checked for.
        if chunker.data_type.value in [item.value for item in DirectDataType]:
            # DirectDataTypes can't be updated.
            # Think of a text:
            #   Either it's the same, then it won't change, so it's not an update.
            #   Or it's different, then it will be added as a new text.
            return None
        elif chunker.data_type.value in [item.value for item in IndirectDataType]:
            # These types have an indirect source reference.
            # As long as the reference is the same, they can be updated.
            where = {"url": src}
            if chunker.data_type == DataType.JSON and is_valid_json_string(src):
                url = hashlib.sha256((src).encode("utf-8")).hexdigest()
                where = {"url": url}

            if self.config.id is not None:
                where.update({"app_id": self.config.id})

            existing_embeddings = self.db.get(
                where=where,
                limit=1,
            )
            if len(existing_embeddings.get("metadatas", [])) > 0:
                return existing_embeddings["metadatas"][0]["doc_id"]
            else:
                return None
        elif chunker.data_type.value in [item.value for item in SpecialDataType]:
            # These types don't contain indirect references.
            # Through custom logic, they can be attributed to a source and be updated.
            if chunker.data_type == DataType.QNA_PAIR:
                # QNA_PAIRs update the answer if the question already exists.
                where = {"question": src[0]}
                if self.config.id is not None:
                    where.update({"app_id": self.config.id})

                existing_embeddings = self.db.get(
                    where=where,
                    limit=1,
                )
                if len(existing_embeddings.get("metadatas", [])) > 0:
                    return existing_embeddings["metadatas"][0]["doc_id"]
                else:
                    return None
            else:
                raise NotImplementedError(
                    f"SpecialDataType {chunker.data_type} must have a custom logic to check for existing data"
                )
        else:
            raise TypeError(
                f"{chunker.data_type} is of type {type(chunker.data_type)}, "
                "but it should be DirectDataType, IndirectDataType or SpecialDataType."
            )

    def _load_and_embed(
        self,
        loader: BaseLoader,
        chunker: BaseChunker,
        src: Any,
        metadata: Optional[dict[str, Any]] = None,
        source_hash: Optional[str] = None,
        add_config: Optional[AddConfig] = None,
        dry_run=False,
        **kwargs: Optional[dict[str, Any]],
    ):
        """
        Loads the data from the given URL, chunks it, and adds it to the database.

        :param loader: The loader to use to load the data.
        :type loader: BaseLoader
        :param chunker: The chunker to use to chunk the data.
        :type chunker: BaseChunker
        :param src: The data to be handled by the loader. Can be a URL for
            remote sources or local content for local loaders.
        :type src: Any
        :param metadata: Metadata associated with the data source.
        :type metadata: dict[str, Any], optional
        :param source_hash: Hexadecimal hash of the source.
        :type source_hash: str, optional
        :param add_config: The `AddConfig` instance to use as configuration options.
        :type add_config: AddConfig, optional
        :param dry_run: A dry run returns chunks and doesn't update the DB.
        :type dry_run: bool, defaults to False
        :return: (list) documents (embedded text), (list) metadata, (list) ids, (int) number of chunks
        """
        existing_doc_id = self._get_existing_doc_id(chunker=chunker, src=src)
        app_id = self.config.id if self.config is not None else None

        # Create chunks
        embeddings_data = chunker.create_chunks(loader, src, app_id=app_id, config=add_config.chunker)
        # spread chunking results
        documents = embeddings_data["documents"]
        metadatas = embeddings_data["metadatas"]
        ids = embeddings_data["ids"]
        new_doc_id = embeddings_data["doc_id"]

        if existing_doc_id and existing_doc_id == new_doc_id:
            logger.info("Doc content has not changed. Skipping creating chunks and embeddings")
            return [], [], [], 0

        # this means that doc content has changed.
        if existing_doc_id and existing_doc_id != new_doc_id:
            logger.info("Doc content has changed. Recomputing chunks and embeddings intelligently.")
            self.db.delete({"doc_id": existing_doc_id})

        # get existing ids, and discard doc if any common id exists.
        where = {"url": src}
        if chunker.data_type == DataType.JSON and is_valid_json_string(src):
            url = hashlib.sha256((src).encode("utf-8")).hexdigest()
            where = {"url": url}

        # if data type is qna_pair, we check for question
        if chunker.data_type == DataType.QNA_PAIR:
            where = {"question": src[0]}

        if self.config.id is not None:
            where["app_id"] = self.config.id

        db_result = self.db.get(ids=ids, where=where)  # optional filter
        existing_ids = set(db_result["ids"])
        if len(existing_ids):
            data_dict = {id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)}
            data_dict = {id: value for id, value in data_dict.items() if id not in existing_ids}

            if not data_dict:
                src_copy = src
                if len(src_copy) > 50:
                    src_copy = src[:50] + "..."
                logger.info(f"All data from {src_copy} already exists in the database.")
                # Make sure to return a matching return type
                return [], [], [], 0

            ids = list(data_dict.keys())
            documents, metadatas = zip(*data_dict.values())

        # Loop through all metadatas and add extras.
        new_metadatas = []
        for m in metadatas:
            # Add app id in metadatas so that they can be queried on later
            if self.config.id:
                m["app_id"] = self.config.id

            # Add hashed source
            m["hash"] = source_hash

            # Note: Metadata is the function argument
            if metadata:
                # Spread whatever is in metadata into the new object.
                m.update(metadata)

            new_metadatas.append(m)
        metadatas = new_metadatas

        if dry_run:
            return list(documents), metadatas, ids, 0

        # Count before, to calculate a delta in the end.
        chunks_before_addition = self.db.count()

        # Filter out empty documents and ensure they meet the API requirements
        valid_documents = [doc for doc in documents if doc and isinstance(doc, str)]

        documents = valid_documents

        # Chunk documents into batches of 2048 and handle each batch;
        # helps with large loads of embeddings that hit OpenAI limits
        document_batches = [documents[i : i + 2048] for i in range(0, len(documents), 2048)]
        metadata_batches = [metadatas[i : i + 2048] for i in range(0, len(metadatas), 2048)]
        id_batches = [ids[i : i + 2048] for i in range(0, len(ids), 2048)]
        for batch_docs, batch_meta, batch_ids in zip(document_batches, metadata_batches, id_batches):
            try:
                # Add only valid batches
                if batch_docs:
                    self.db.add(documents=batch_docs, metadatas=batch_meta, ids=batch_ids, **kwargs)
            except Exception as e:
                logger.info(f"Failed to add batch due to a bad request: {e}")
                # Handle the error, e.g., by logging, retrying, or skipping
                pass

        count_new_chunks = self.db.count() - chunks_before_addition
        logger.info(f"Successfully saved {str(src)[:100]} ({chunker.data_type}). New chunks count: {count_new_chunks}")

        return list(documents), metadatas, ids, count_new_chunks

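The tail of `_load_and_embed` writes chunks in batches of at most 2048 to stay under embedding-API payload limits. A self-contained sketch of that batching pattern (the `store` object is a stand-in for any client with an `add` method, not an embedchain API):

# Example (illustrative): insert documents into a store in fixed-size batches
def add_in_batches(store, documents, metadatas, ids, batch_size=2048):
    for i in range(0, len(documents), batch_size):
        store.add(
            documents=documents[i : i + batch_size],
            metadatas=metadatas[i : i + batch_size],
            ids=ids[i : i + batch_size],
        )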
    @staticmethod
    def _format_result(results):
        return [
            (Document(page_content=result[0], metadata=result[1] or {}), result[2])
            for result in zip(
                results["documents"][0],
                results["metadatas"][0],
                results["distances"][0],
            )
        ]

    def _retrieve_from_database(
        self,
        input_query: str,
        config: Optional[BaseLlmConfig] = None,
        where=None,
        citations: bool = False,
        **kwargs: Optional[dict[str, Any]],
    ) -> Union[list[tuple[str, str, str]], list[str]]:
        """
        Queries the vector database based on the given input query.
        Gets relevant docs based on the query.

        :param input_query: The query to use.
        :type input_query: str
        :param config: The query configuration, defaults to None
        :type config: Optional[BaseLlmConfig], optional
        :param where: A dictionary of key-value pairs to filter the database results, defaults to None
        :type where: dict[str, Any], optional
        :param citations: A boolean to indicate if db should fetch citation source
        :type citations: bool
        :return: List of contents of the documents that matched your query
        :rtype: list[str]
        """
        query_config = config or self.llm.config
        if where is None:
            where = {}
            if query_config is not None and query_config.where is not None:
                where = query_config.where

            if self.config.id is not None:
                where.update({"app_id": self.config.id})

        contexts = self.db.query(
            input_query=input_query,
            n_results=query_config.number_documents,
            where=where,
            citations=citations,
            **kwargs,
        )

        return contexts

    def query(
        self,
        input_query: str,
        config: BaseLlmConfig = None,
        dry_run=False,
        where: Optional[dict] = None,
        citations: bool = False,
        **kwargs: dict[str, Any],
    ) -> Union[tuple[str, list[tuple[str, dict]]], str, dict[str, Any]]:
        """
        Queries the vector database based on the given input query.
        Gets relevant docs based on the query and then passes them to an
        LLM as context to get the answer.

        :param input_query: The query to use.
        :type input_query: str
        :param config: The `BaseLlmConfig` instance to use as configuration options. This is used for one method call.
            To persistently use a config, declare it during app init. Defaults to None
        :type config: BaseLlmConfig, optional
        :param dry_run: A dry run does everything except send the resulting prompt to
            the LLM. The purpose is to test the prompt, not the response. Defaults to False
        :type dry_run: bool, optional
        :param where: A dictionary of key-value pairs to filter the database results, defaults to None
        :type where: dict[str, str], optional
        :param citations: A boolean to indicate if db should fetch citation source
        :type citations: bool
        :param kwargs: Extra keyword arguments forwarded to the underlying vector database query,
            e.g. the citations boolean is used to return the contexts along with the answer
        :type kwargs: dict[str, Any]
        :return: The answer to the query, with citations if the citation flag is True,
            or the dry run result
        :rtype: str if neither `citations` nor `token_usage` is enabled; tuple[str, list[tuple[str, dict]]] if
            `citations` is True; a dict with `answer` (plus `contexts` when citations are requested) and `usage`
            if `token_usage` is enabled
        """
        contexts = self._retrieve_from_database(
            input_query=input_query, config=config, where=where, citations=citations, **kwargs
        )
        if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
            contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))
        else:
            contexts_data_for_llm_query = contexts

        if self.cache_config is not None:
            logger.info("Cache enabled. Checking cache...")
            answer = adapt(
                llm_handler=self.llm.query,
                cache_data_convert=gptcache_data_convert,
                update_cache_callback=gptcache_update_cache_callback,
                session=get_gptcache_session(session_id=self.config.id),
                input_query=input_query,
                contexts=contexts_data_for_llm_query,
                config=config,
                dry_run=dry_run,
            )
        else:
            if self.llm.config.token_usage:
                answer, token_info = self.llm.query(
                    input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
                )
            else:
                answer = self.llm.query(
                    input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
                )

        # Send anonymous telemetry
        self.telemetry.capture(event_name="query", properties=self._telemetry_props)

        if citations:
            if self.llm.config.token_usage:
                return {"answer": answer, "contexts": contexts, "usage": token_info}
            return answer, contexts
        if self.llm.config.token_usage:
            return {"answer": answer, "usage": token_info}

        logger.warning(
            "Starting from v0.1.125 the return type of query method will be changed to tuple containing `answer`."
        )
        return answer

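A usage sketch for `query` (illustrative, not part of this commit; assumes an `App` instance named `app` with data already added):

# Example (illustrative): plain answer vs. answer plus retrieved contexts
answer = app.query("What did the added sources say about funding?")
answer, sources = app.query("What did the added sources say about funding?", citations=True)
for context, source_metadata in sources:
    print(source_metadata, context[:80])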
    def chat(
        self,
        input_query: str,
        config: Optional[BaseLlmConfig] = None,
        dry_run=False,
        session_id: str = "default",
        where: Optional[dict[str, str]] = None,
        citations: bool = False,
        **kwargs: dict[str, Any],
    ) -> Union[tuple[str, list[tuple[str, dict]]], str, dict[str, Any]]:
        """
        Queries the vector database based on the given input query.
        Gets relevant docs based on the query and then passes them to an
        LLM as context to get the answer.

        Maintains the whole conversation in memory.

        :param input_query: The query to use.
        :type input_query: str
        :param config: The `BaseLlmConfig` instance to use as configuration options. This is used for one method call.
            To persistently use a config, declare it during app init. Defaults to None
        :type config: BaseLlmConfig, optional
        :param dry_run: A dry run does everything except send the resulting prompt to
            the LLM. The purpose is to test the prompt, not the response. Defaults to False
        :type dry_run: bool, optional
        :param session_id: The session id to use for chat history, defaults to 'default'.
        :type session_id: str, optional
        :param where: A dictionary of key-value pairs to filter the database results, defaults to None
        :type where: dict[str, str], optional
        :param citations: A boolean to indicate if db should fetch citation source
        :type citations: bool
        :param kwargs: Extra keyword arguments forwarded to the underlying vector database query,
            e.g. the citations boolean is used to return the contexts along with the answer
        :type kwargs: dict[str, Any]
        :return: The answer to the query, with citations if the citation flag is True,
            or the dry run result
        :rtype: str if neither `citations` nor `token_usage` is enabled; tuple[str, list[tuple[str, dict]]] if
            `citations` is True; a dict with `answer` (plus `contexts` when citations are requested) and `usage`
            if `token_usage` is enabled
        """
        contexts = self._retrieve_from_database(
            input_query=input_query, config=config, where=where, citations=citations, **kwargs
        )
        if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
            contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))
        else:
            contexts_data_for_llm_query = contexts

        memories = None
        if self.mem0_client:
            memories = self.mem0_client.search(
                query=input_query, agent_id=self.config.id, session_id=session_id, limit=self.memory_config.top_k
            )

        # Update the history beforehand so that we can handle multiple chat sessions in the same python session
        self.llm.update_history(app_id=self.config.id, session_id=session_id)

        if self.cache_config is not None:
            logger.debug("Cache enabled. Checking cache...")
            cache_id = f"{session_id}--{self.config.id}"
            answer = adapt(
                llm_handler=self.llm.chat,
                cache_data_convert=gptcache_data_convert,
                update_cache_callback=gptcache_update_cache_callback,
                session=get_gptcache_session(session_id=cache_id),
                input_query=input_query,
                contexts=contexts_data_for_llm_query,
                config=config,
                dry_run=dry_run,
            )
        else:
            logger.debug("Cache disabled. Running chat without cache.")
            if self.llm.config.token_usage:
                answer, token_info = self.llm.query(
                    input_query=input_query,
                    contexts=contexts_data_for_llm_query,
                    config=config,
                    dry_run=dry_run,
                    memories=memories,
                )
            else:
                answer = self.llm.query(
                    input_query=input_query,
                    contexts=contexts_data_for_llm_query,
                    config=config,
                    dry_run=dry_run,
                    memories=memories,
                )

        # Add to Mem0 memory if enabled
        # TODO: Might need to prepend with some text like:
        # "Remember user preferences from following user query: {input_query}"
        if self.mem0_client:
            self.mem0_client.add(data=input_query, agent_id=self.config.id, session_id=session_id)

        # add conversation in memory
        self.llm.add_history(self.config.id, input_query, answer, session_id=session_id)

        # Send anonymous telemetry
        self.telemetry.capture(event_name="chat", properties=self._telemetry_props)

        if citations:
            if self.llm.config.token_usage:
                return {"answer": answer, "contexts": contexts, "usage": token_info}
            return answer, contexts
        if self.llm.config.token_usage:
            return {"answer": answer, "usage": token_info}

        logger.warning(
            "Starting from v0.1.125 the return type of query method will be changed to tuple containing `answer`."
        )
        return answer

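A usage sketch for `chat` (illustrative, not part of this commit; assumes an `App` instance named `app`). Separate `session_id` values keep independent histories, while the optional Mem0 client adds long-term memory across sessions:

# Example (illustrative): per-session chat history
app.chat("My name is Alice and I prefer short answers.", session_id="alice")
app.chat("What is my name?", session_id="alice")  # answered using the same session's history
app.chat("What is my name?", session_id="bob")    # a new session starts with no history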
    def search(self, query, num_documents=3, where=None, raw_filter=None, namespace=None):
        """
        Search for similar documents related to the query in the vector database.

        Args:
            query (str): The query to use.
            num_documents (int, optional): Number of similar documents to fetch. Defaults to 3.
            where (dict[str, any], optional): Filter criteria for the search.
            raw_filter (dict[str, any], optional): Advanced raw filter criteria for the search.
            namespace (str, optional): The namespace to search in. Defaults to None.

        Raises:
            ValueError: If both `raw_filter` and `where` are used simultaneously.

        Returns:
            list[dict]: A list of dictionaries, each containing the 'context' and 'metadata' of a document.
        """
        # Send anonymous telemetry
        self.telemetry.capture(event_name="search", properties=self._telemetry_props)

        if raw_filter and where:
            raise ValueError("You can't use both `raw_filter` and `where` together.")

        filter_type = "raw_filter" if raw_filter else "where"
        filter_criteria = raw_filter if raw_filter else where

        params = {
            "input_query": query,
            "n_results": num_documents,
            "citations": True,
            "app_id": self.config.id,
            "namespace": namespace,
            filter_type: filter_criteria,
        }

        return [{"context": c[0], "metadata": c[1]} for c in self.db.query(**params)]

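A usage sketch for `search` (illustrative, not part of this commit; assumes an `App` instance named `app` and that the added sources carry a `topic` metadata key):

# Example (illustrative): metadata-filtered similarity search
results = app.search("funding rounds", num_documents=5, where={"topic": "bio"})
for result in results:
    print(result["metadata"], result["context"][:80])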
    def set_collection_name(self, name: str):
        """
        Set the name of the collection. A collection is an isolated space for vectors.

        Using the `app.db.set_collection_name` method is preferred over this one.

        :param name: Name of the collection.
        :type name: str
        """
        self.db.set_collection_name(name)
        # Create the collection if it does not exist
        self.db._get_or_create_collection(name)
        # TODO: Check whether it is necessary to assign to the `self.collection` attribute,
        # since the main purpose is the creation.

    def reset(self):
        """
        Resets the database. Deletes all embeddings irreversibly.
        `App` does not have to be reinitialized after using this method.
        """
        try:
            self.db_session.query(DataSource).filter_by(app_id=self.config.id).delete()
            self.db_session.query(ChatHistory).filter_by(app_id=self.config.id).delete()
            self.db_session.commit()
        except Exception as e:
            logger.error(f"Error deleting data sources: {e}")
            self.db_session.rollback()
            return None
        self.db.reset()
        self.delete_all_chat_history(app_id=self.config.id)
        # Send anonymous telemetry
        self.telemetry.capture(event_name="reset", properties=self._telemetry_props)

    def get_history(
        self,
        num_rounds: int = 10,
        display_format: bool = True,
        session_id: Optional[str] = "default",
        fetch_all: bool = False,
    ):
        history = self.llm.memory.get(
            app_id=self.config.id,
            session_id=session_id,
            num_rounds=num_rounds,
            display_format=display_format,
            fetch_all=fetch_all,
        )
        return history

    def delete_session_chat_history(self, session_id: str = "default"):
        self.llm.memory.delete(app_id=self.config.id, session_id=session_id)
        self.llm.update_history(app_id=self.config.id)

    def delete_all_chat_history(self, app_id: str):
        self.llm.memory.delete(app_id=app_id)
        self.llm.update_history(app_id=app_id)

    def delete(self, source_id: str):
        """
        Deletes the data from the database.

        :param source_id: The hash of the source, as returned by `add`.
        :type source_id: str
        """
        try:
            self.db_session.query(DataSource).filter_by(hash=source_id, app_id=self.config.id).delete()
            self.db_session.commit()
        except Exception as e:
            logger.error(f"Error deleting data sources: {e}")
            self.db_session.rollback()
            return None
        self.db.delete(where={"hash": source_id})
        logger.info(f"Successfully deleted {source_id}")
        # Send anonymous telemetry
        if self.config.collect_metrics:
            self.telemetry.capture(event_name="delete", properties=self._telemetry_props)
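A usage sketch for `delete` and `reset` (illustrative, not part of this commit; assumes an `App` instance named `app`). `add` returns the source hash that `delete` expects:

# Example (illustrative): remove one source by hash, or wipe the whole app
source_hash = app.add("https://example.com/post")
app.delete(source_hash)
app.reset()  # irreversible: drops all embeddings and chat history for this app id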
0
embedchain/embedchain/embedder/__init__.py
Normal file
22
embedchain/embedchain/embedder/azure_openai.py
Normal file
@@ -0,0 +1,22 @@
from typing import Optional

from langchain_community.embeddings import AzureOpenAIEmbeddings

from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.models import VectorDimensions


class AzureOpenAIEmbedder(BaseEmbedder):
    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config=config)

        if self.config.model is None:
            self.config.model = "text-embedding-ada-002"

        embeddings = AzureOpenAIEmbeddings(deployment=self.config.deployment_name)
        embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)

        self.set_embedding_fn(embedding_fn=embedding_fn)
        vector_dimension = self.config.vector_dimension or VectorDimensions.OPENAI.value
        self.set_vector_dimension(vector_dimension=vector_dimension)
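A configuration sketch for the Azure embedder (illustrative, not part of this commit; the deployment name is an assumption, and the usual Azure OpenAI environment variables such as AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_API_KEY are assumed to be set so langchain can pick them up):

# Example (illustrative): point the embedder at an existing Azure deployment (name is hypothetical)
from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.azure_openai import AzureOpenAIEmbedder

embedder = AzureOpenAIEmbedder(config=BaseEmbedderConfig(deployment_name="my-ada-002-deployment"))
print(embedder.vector_dimension)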
90
embedchain/embedchain/embedder/base.py
Normal file
@@ -0,0 +1,90 @@
from collections.abc import Callable
from typing import Any, Optional

from embedchain.config.embedder.base import BaseEmbedderConfig

try:
    from chromadb.api.types import Embeddable, EmbeddingFunction, Embeddings
except RuntimeError:
    from embedchain.utils.misc import use_pysqlite3

    use_pysqlite3()
    from chromadb.api.types import Embeddable, EmbeddingFunction, Embeddings


class EmbeddingFunc(EmbeddingFunction):
    def __init__(self, embedding_fn: Callable[[list[str]], list[str]]):
        self.embedding_fn = embedding_fn

    def __call__(self, input: Embeddable) -> Embeddings:
        return self.embedding_fn(input)


class BaseEmbedder:
    """
    Class that manages everything regarding embeddings, including the embedding function, loaders and chunkers.

    Embedding functions and vector dimensions are set based on the child class you choose.
    To manually overwrite them, you can use this class's `set_...` methods.
    """

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        """
        Initialize the embedder class.

        :param config: embedder configuration option class, defaults to None
        :type config: Optional[BaseEmbedderConfig], optional
        """
        if config is None:
            self.config = BaseEmbedderConfig()
        else:
            self.config = config
        self.vector_dimension: int

    def set_embedding_fn(self, embedding_fn: Callable[[list[str]], list[str]]):
        """
        Set or overwrite the embedding function to be used by the database to store and retrieve documents.

        :param embedding_fn: Function to be used to generate embeddings.
        :type embedding_fn: Callable[[list[str]], list[str]]
        :raises ValueError: Embedding function is not callable.
        """
        if not hasattr(embedding_fn, "__call__"):
            raise ValueError("Embedding function is not a function")
        self.embedding_fn = embedding_fn

    def set_vector_dimension(self, vector_dimension: int):
        """
        Set or overwrite the vector dimension size.

        :param vector_dimension: vector dimension size
        :type vector_dimension: int
        """
        if not isinstance(vector_dimension, int):
            raise TypeError("vector dimension must be int")
        self.vector_dimension = vector_dimension

    @staticmethod
    def _langchain_default_concept(embeddings: Any):
        """
        Langchain's default function layout for embeddings.

        :param embeddings: Langchain embeddings
        :type embeddings: Any
        :return: embedding function
        :rtype: Callable
        """

        return EmbeddingFunc(embeddings.embed_documents)

    def to_embeddings(self, data: str, **_):
        """
        Convert data to embeddings.

        :param data: data to convert to embeddings
        :type data: str
        :return: embeddings
        :rtype: list[float]
        """
        embeddings = self.embedding_fn([data])
        return embeddings[0]
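Because `BaseEmbedder` only needs a callable that maps a list of strings to a list of vectors plus a dimension, wiring in a custom model is short. A sketch under that assumption (`my_encode` is a placeholder, not a real embedchain or library function):

# Example (illustrative): a custom embedder built on BaseEmbedder
from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder


def my_encode(text: str) -> list[float]:
    # placeholder: call any embedding model here and return one vector per input string
    return [0.0] * 384


class MyEmbedder(BaseEmbedder):
    def __init__(self, config: BaseEmbedderConfig = None):
        super().__init__(config=config)
        self.set_embedding_fn(embedding_fn=lambda texts: [my_encode(t) for t in texts])
        self.set_vector_dimension(vector_dimension=384)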
52
embedchain/embedchain/embedder/clarifai.py
Normal file
@@ -0,0 +1,52 @@
import os
from typing import Optional, Union

from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder

from chromadb import EmbeddingFunction, Embeddings


class ClarifaiEmbeddingFunction(EmbeddingFunction):
    def __init__(self, config: BaseEmbedderConfig) -> None:
        super().__init__()
        try:
            from clarifai.client.model import Model
            from clarifai.client.input import Inputs
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "The required dependencies for ClarifaiEmbeddingFunction are not installed. "
                'Please install with `pip install --upgrade "embedchain[clarifai]"`'
            ) from None
        self.config = config
        self.api_key = config.api_key or os.getenv("CLARIFAI_PAT")
        self.model = config.model
        self.model_obj = Model(url=self.model, pat=self.api_key)
        self.input_obj = Inputs(pat=self.api_key)

    def __call__(self, input: Union[str, list[str]]) -> Embeddings:
        if isinstance(input, str):
            input = [input]

        batch_size = 32
        embeddings = []
        try:
            for i in range(0, len(input), batch_size):
                batch = input[i : i + batch_size]
                input_batch = [
                    self.input_obj.get_text_input(input_id=str(id), raw_text=inp) for id, inp in enumerate(batch)
                ]
                response = self.model_obj.predict(input_batch)
                embeddings.extend([list(output.data.embeddings[0].vector) for output in response.outputs])
        except Exception as e:
            print(f"Predict failed, exception: {e}")

        return embeddings


class ClarifaiEmbedder(BaseEmbedder):
    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config)

        embedding_func = ClarifaiEmbeddingFunction(config=self.config)
        self.set_embedding_fn(embedding_fn=embedding_func)
19
embedchain/embedchain/embedder/cohere.py
Normal file
@@ -0,0 +1,19 @@
from typing import Optional

from langchain_cohere.embeddings import CohereEmbeddings

from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.models import VectorDimensions


class CohereEmbedder(BaseEmbedder):
    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config=config)

        embeddings = CohereEmbeddings(model=self.config.model)
        embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
        self.set_embedding_fn(embedding_fn=embedding_fn)

        vector_dimension = self.config.vector_dimension or VectorDimensions.COHERE.value
        self.set_vector_dimension(vector_dimension=vector_dimension)
38
embedchain/embedchain/embedder/google.py
Normal file
@@ -0,0 +1,38 @@
from typing import Optional, Union

import google.generativeai as genai
from chromadb import EmbeddingFunction, Embeddings

from embedchain.config.embedder.google import GoogleAIEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.models import VectorDimensions


class GoogleAIEmbeddingFunction(EmbeddingFunction):
    def __init__(self, config: Optional[GoogleAIEmbedderConfig] = None) -> None:
        super().__init__()
        self.config = config or GoogleAIEmbedderConfig()

    def __call__(self, input: Union[list[str], str]) -> Embeddings:
        model = self.config.model
        title = self.config.title
        task_type = self.config.task_type
        if isinstance(input, str):
            input_ = [input]
        else:
            input_ = input
        data = genai.embed_content(model=model, content=input_, task_type=task_type, title=title)
        embeddings = data["embedding"]
        if isinstance(input_, str):
            embeddings = [embeddings]
        return embeddings


class GoogleAIEmbedder(BaseEmbedder):
    def __init__(self, config: Optional[GoogleAIEmbedderConfig] = None):
        super().__init__(config)
        embedding_fn = GoogleAIEmbeddingFunction(config=config)
        self.set_embedding_fn(embedding_fn=embedding_fn)

        vector_dimension = self.config.vector_dimension or VectorDimensions.GOOGLE_AI.value
        self.set_vector_dimension(vector_dimension=vector_dimension)
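A usage sketch for the Google embedder (illustrative, not part of this commit; the model and task_type values are assumptions, and configuring the `google.generativeai` API key is assumed to be done by the caller):

# Example (illustrative): embed a string with the Google embedding function
import google.generativeai as genai
from embedchain.config.embedder.google import GoogleAIEmbedderConfig
from embedchain.embedder.google import GoogleAIEmbedder

genai.configure(api_key="<GOOGLE_API_KEY>")  # placeholder key
embedder = GoogleAIEmbedder(config=GoogleAIEmbedderConfig(model="models/embedding-001", task_type="retrieval_document"))
vector = embedder.to_embeddings("hello world")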
20
embedchain/embedchain/embedder/gpt4all.py
Normal file
@@ -0,0 +1,20 @@
from typing import Optional

from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.models import VectorDimensions


class GPT4AllEmbedder(BaseEmbedder):
    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config=config)

        from langchain_community.embeddings import GPT4AllEmbeddings as LangchainGPT4AllEmbeddings

        model_name = self.config.model or "all-MiniLM-L6-v2-f16.gguf"
        embeddings = LangchainGPT4AllEmbeddings(model_name=model_name)
        embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
        self.set_embedding_fn(embedding_fn=embedding_fn)

        vector_dimension = self.config.vector_dimension or VectorDimensions.GPT4ALL.value
        self.set_vector_dimension(vector_dimension=vector_dimension)
40
embedchain/embedchain/embedder/huggingface.py
Normal file
@@ -0,0 +1,40 @@
import os
from typing import Optional

from langchain_community.embeddings import HuggingFaceEmbeddings

try:
    from langchain_huggingface import HuggingFaceEndpointEmbeddings
except ModuleNotFoundError:
    raise ModuleNotFoundError(
        "The required dependencies for HuggingFaceHub are not installed. "
        "Please install with `pip install langchain_huggingface`"
    ) from None

from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.models import VectorDimensions


class HuggingFaceEmbedder(BaseEmbedder):
    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config=config)

        if self.config.endpoint:
            if not self.config.api_key and "HUGGINGFACE_ACCESS_TOKEN" not in os.environ:
                raise ValueError(
                    "Please set the HUGGINGFACE_ACCESS_TOKEN environment variable or pass API Key in the config."
                )

            embeddings = HuggingFaceEndpointEmbeddings(
                model=self.config.endpoint,
                huggingfacehub_api_token=self.config.api_key or os.getenv("HUGGINGFACE_ACCESS_TOKEN"),
            )
        else:
            embeddings = HuggingFaceEmbeddings(model_name=self.config.model, model_kwargs=self.config.model_kwargs)

        embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
        self.set_embedding_fn(embedding_fn=embedding_fn)

        vector_dimension = self.config.vector_dimension or VectorDimensions.HUGGING_FACE.value
        self.set_vector_dimension(vector_dimension=vector_dimension)
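A configuration sketch for the Hugging Face embedder (illustrative, not part of this commit; the model name and endpoint URL are assumptions): a local model goes through `HuggingFaceEmbeddings`, while setting `endpoint` switches to a hosted inference endpoint and requires an access token.

# Example (illustrative): local model vs. hosted inference endpoint
from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.huggingface import HuggingFaceEmbedder

local = HuggingFaceEmbedder(config=BaseEmbedderConfig(model="sentence-transformers/all-mpnet-base-v2"))
# requires HUGGINGFACE_ACCESS_TOKEN in the environment or api_key in the config
hosted = HuggingFaceEmbedder(config=BaseEmbedderConfig(endpoint="https://<your-endpoint>.endpoints.huggingface.cloud"))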
46
embedchain/embedchain/embedder/mistralai.py
Normal file
@@ -0,0 +1,46 @@
import os
from typing import Optional, Union

from chromadb import EmbeddingFunction, Embeddings

from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.models import VectorDimensions


class MistralAIEmbeddingFunction(EmbeddingFunction):
    def __init__(self, config: BaseEmbedderConfig) -> None:
        super().__init__()
        try:
            from langchain_mistralai import MistralAIEmbeddings
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "The required dependencies for MistralAI are not installed. "
                'Please install with `pip install --upgrade "embedchain[mistralai]"`'
            ) from None
        self.config = config
        api_key = self.config.api_key or os.getenv("MISTRAL_API_KEY")
        self.client = MistralAIEmbeddings(mistral_api_key=api_key)
        self.client.model = self.config.model

    def __call__(self, input: Union[list[str], str]) -> Embeddings:
        if isinstance(input, str):
            input_ = [input]
        else:
            input_ = input
        response = self.client.embed_documents(input_)
        return response


class MistralAIEmbedder(BaseEmbedder):
    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config)

        if self.config.model is None:
            self.config.model = "mistral-embed"

        embedding_fn = MistralAIEmbeddingFunction(config=self.config)
        self.set_embedding_fn(embedding_fn=embedding_fn)

        vector_dimension = self.config.vector_dimension or VectorDimensions.MISTRAL_AI.value
        self.set_vector_dimension(vector_dimension=vector_dimension)