[Refactor] Converge Pipeline and App classes (#1021)
Co-authored-by: Deven Patel <deven298@yahoo.com>
@@ -2,10 +2,9 @@ import importlib.metadata

 __version__ = importlib.metadata.version(__package__ or __name__)

-from embedchain.apps.app import App  # noqa: F401
+from embedchain.app import App  # noqa: F401
 from embedchain.client import Client  # noqa: F401
 from embedchain.pipeline import Pipeline  # noqa: F401
 from embedchain.vectordb.chroma import ChromaDB  # noqa: F401

 # Set up the user directory if it doesn't exist already
 Client.setup_dir()
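
After this hunk, `App` is imported from the new top-level module `embedchain.app`, and `Pipeline` stays importable only for backward compatibility. A minimal sketch of the post-refactor entry point (the data source URL and query are illustrative, not part of this commit):

    from embedchain import App

    app = App()
    app.add("https://example.com/article")  # hypothetical source URL
    print(app.query("What is the article about?"))  # hypothetical query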

embedchain/app.py (new file, 431 lines)
@@ -0,0 +1,431 @@
import ast
import json
import logging
import os
import sqlite3
import uuid
from typing import Any, Dict, Optional

import requests
import yaml

from embedchain.client import Client
from embedchain.config import AppConfig, ChunkerConfig
from embedchain.constants import SQLITE_PATH
from embedchain.embedchain import EmbedChain
from embedchain.embedder.base import BaseEmbedder
from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
from embedchain.llm.openai import OpenAILlm
from embedchain.telemetry.posthog import AnonymousTelemetry
from embedchain.utils import validate_config
from embedchain.vectordb.base import BaseVectorDB
from embedchain.vectordb.chroma import ChromaDB

# Set up the user directory if it doesn't exist already
Client.setup_dir()

@register_deserializable
class App(EmbedChain):
    """
    EmbedChain App lets you create an LLM-powered app for your unstructured
    data by defining your chosen data source, embedding model,
    and vector database.
    """

    def __init__(
        self,
        id: str = None,
        name: str = None,
        config: AppConfig = None,
        db: BaseVectorDB = None,
        embedding_model: BaseEmbedder = None,
        llm: BaseLlm = None,
        config_data: dict = None,
        log_level=logging.WARN,
        auto_deploy: bool = False,
        chunker: ChunkerConfig = None,
    ):
        """
        Initialize a new `App` instance.

        :param id: ID of an existing pipeline to load from the platform, defaults to None
        :type id: str, optional
        :param name: Name of the app, defaults to None
        :type name: str, optional
        :param config: Configuration for the app, defaults to None
        :type config: AppConfig, optional
        :param db: The database to use for storing and retrieving embeddings, defaults to None
        :type db: BaseVectorDB, optional
        :param embedding_model: The embedding model used to calculate embeddings, defaults to None
        :type embedding_model: BaseEmbedder, optional
        :param llm: The LLM used to answer queries, defaults to None
        :type llm: BaseLlm, optional
        :param config_data: Config dictionary, defaults to None
        :type config_data: dict, optional
        :param log_level: Log level to use, defaults to logging.WARN
        :type log_level: int, optional
        :param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
        :type auto_deploy: bool, optional
        :param chunker: Configuration for the chunker, defaults to None
        :type chunker: ChunkerConfig, optional
        :raises Exception: If an error occurs while creating the pipeline
        """
        if id and config_data:
            raise Exception("Cannot provide both id and config_data. Please provide only one of them.")

        if id and name:
            raise Exception("Cannot provide both id and name. Please provide only one of them.")

        if name and config:
            raise Exception("Cannot provide both name and config. Please provide only one of them.")

        logging.basicConfig(level=log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        self.logger = logging.getLogger(__name__)
        self.auto_deploy = auto_deploy
        # Store the dict config as an attribute to be able to send it
        self.config_data = config_data if (config_data and validate_config(config_data)) else None
        self.client = None
        # pipeline_id from the backend
        self.id = None
        self.chunker = None
        if chunker:
            self.chunker = ChunkerConfig(**chunker)

        self.config = config or AppConfig()
        self.name = self.config.name
        self.config.id = self.local_id = str(uuid.uuid4()) if self.config.id is None else self.config.id

        if id is not None:
            # Init client first since the user is trying to fetch the pipeline
            # details from the platform
            self._init_client()
            pipeline_details = self._get_pipeline(id)
            self.config.id = self.local_id = pipeline_details["metadata"]["local_id"]
            self.id = id

        if name is not None:
            self.name = name

        self.embedding_model = embedding_model or OpenAIEmbedder()
        self.db = db or ChromaDB()
        self.llm = llm or OpenAILlm()
        self._init_db()

        # Send anonymous telemetry
        self._telemetry_props = {"class": self.__class__.__name__}
        self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)

        # Establish a connection to the SQLite database
        self.connection = sqlite3.connect(SQLITE_PATH, check_same_thread=False)
        self.cursor = self.connection.cursor()

        # Create the 'data_sources' table if it doesn't exist
        self.cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS data_sources (
                pipeline_id TEXT,
                hash TEXT,
                type TEXT,
                value TEXT,
                metadata TEXT,
                is_uploaded INTEGER DEFAULT 0,
                PRIMARY KEY (pipeline_id, hash)
            )
        """
        )
        self.connection.commit()
        # Send anonymous telemetry
        self.telemetry.capture(event_name="init", properties=self._telemetry_props)

        self.user_asks = []
        if self.auto_deploy:
            self.deploy()
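        # The table above drives `deploy()`: each source is keyed by
        # (pipeline_id, hash), and `is_uploaded` flips to 1 once the source has
        # been pushed to the platform, so repeated deploys only upload pending
        # rows. A hypothetical inspection query against the same SQLite file:
        #
        #   SELECT type, value, is_uploaded FROM data_sources
        #   WHERE pipeline_id = '<local_id>';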

    def _init_db(self):
        """
        Initialize the database.
        """
        self.db._set_embedder(self.embedding_model)
        self.db._initialize()
        self.db.set_collection_name(self.db.config.collection_name)

    def _init_client(self):
        """
        Initialize the client.
        """
        config = Client.load_config()
        if config.get("api_key"):
            self.client = Client()
        else:
            api_key = input(
                "🔑 Enter your Embedchain API key. You can find the API key at https://app.embedchain.ai/settings/keys/ \n"  # noqa: E501
            )
            self.client = Client(api_key=api_key)

    def _get_pipeline(self, id):
        """
        Get an existing pipeline.
        """
        print("🛠️ Fetching pipeline details from the platform...")
        url = f"{self.client.host}/api/v1/pipelines/{id}/cli/"
        r = requests.get(
            url,
            headers={"Authorization": f"Token {self.client.api_key}"},
        )
        if r.status_code == 404:
            raise Exception(f"❌ Pipeline with id {id} not found!")

        print(
            f"🎉 Pipeline loaded successfully! Pipeline url: https://app.embedchain.ai/pipelines/{r.json()['id']}\n"  # noqa: E501
        )
        return r.json()

    def _create_pipeline(self):
        """
        Create a pipeline on the platform.
        """
        print("🛠️ Creating pipeline on the platform...")
        # self.config_data is a dict. Pass it inside the key 'yaml_config' to the backend
        payload = {
            "yaml_config": json.dumps(self.config_data),
            "name": self.name,
            "local_id": self.local_id,
        }
        url = f"{self.client.host}/api/v1/pipelines/cli/create/"
        r = requests.post(
            url,
            json=payload,
            headers={"Authorization": f"Token {self.client.api_key}"},
        )
        if r.status_code not in [200, 201]:
            raise Exception(f"❌ Error occurred while creating pipeline. API response: {r.text}")

        if r.status_code == 200:
            print(
                f"🎉🎉🎉 Existing pipeline found! View your pipeline: https://app.embedchain.ai/pipelines/{r.json()['id']}\n"  # noqa: E501
            )
        elif r.status_code == 201:
            print(
                f"🎉🎉🎉 Pipeline created successfully! View your pipeline: https://app.embedchain.ai/pipelines/{r.json()['id']}\n"  # noqa: E501
            )
        return r.json()

    def _get_presigned_url(self, data_type, data_value):
        payload = {"data_type": data_type, "data_value": data_value}
        r = requests.post(
            f"{self.client.host}/api/v1/pipelines/{self.id}/cli/presigned_url/",
            json=payload,
            headers={"Authorization": f"Token {self.client.api_key}"},
        )
        r.raise_for_status()
        return r.json()

    def search(self, query, num_documents=3):
        """
        Search for similar documents related to the query in the vector database.
        """
        # Send anonymous telemetry
        self.telemetry.capture(event_name="search", properties=self._telemetry_props)

        # TODO: Search will call the endpoint rather than fetching the data from the db itself when deploy=True.
        if self.id is None:
            where = {"app_id": self.local_id}
            context = self.db.query(
                query,
                n_results=num_documents,
                where=where,
                skip_embedding=False,
                citations=True,
            )
            result = []
            for c in context:
                result.append(
                    {
                        "context": c[0],
                        "source": c[1],
                        "document_id": c[2],
                    }
                )
            return result
        else:
            # Make API call to the backend to get the results
            raise NotImplementedError("Search is not implemented yet for the prod mode.")
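    # A usage sketch for local (non-deployed) search; the source URL and query
    # are hypothetical, not part of this commit:
    #
    #   app = App()
    #   app.add("https://example.com/article")
    #   hits = app.search("what is the article about?", num_documents=2)
    #   # -> [{"context": "...", "source": "...", "document_id": "..."}, ...]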

    def _upload_file_to_presigned_url(self, presigned_url, file_path):
        try:
            with open(file_path, "rb") as file:
                response = requests.put(presigned_url, data=file)
                response.raise_for_status()
                return response.status_code == 200
        except Exception as e:
            self.logger.exception(f"Error occurred during file upload: {str(e)}")
            print("❌ Error occurred during file upload!")
            return False

    def _upload_data_to_pipeline(self, data_type, data_value, metadata=None):
        payload = {
            "data_type": data_type,
            "data_value": data_value,
            "metadata": metadata,
        }
        try:
            self._send_api_request(f"/api/v1/pipelines/{self.id}/cli/add/", payload)
            # Print the local file path if the user tries to upload a local file
            printed_value = metadata.get("file_path") if metadata and metadata.get("file_path") else data_value
            print(f"✅ Data of type: {data_type}, value: {printed_value} added successfully.")
        except Exception as e:
            print(f"❌ Error occurred during data upload for type {data_type}! Error: {str(e)}")

    def _send_api_request(self, endpoint, payload):
        url = f"{self.client.host}{endpoint}"
        headers = {"Authorization": f"Token {self.client.api_key}"}
        response = requests.post(url, json=payload, headers=headers)
        response.raise_for_status()
        return response

    def _process_and_upload_data(self, data_hash, data_type, data_value):
        if os.path.isabs(data_value):
            presigned_url_data = self._get_presigned_url(data_type, data_value)
            presigned_url = presigned_url_data["presigned_url"]
            s3_key = presigned_url_data["s3_key"]
            if self._upload_file_to_presigned_url(presigned_url, file_path=data_value):
                metadata = {"file_path": data_value, "s3_key": s3_key}
                data_value = presigned_url
            else:
                self.logger.error(f"File upload failed for hash: {data_hash}")
                return False
        else:
            if data_type == "qna_pair":
                data_value = list(ast.literal_eval(data_value))
            metadata = {}

        try:
            self._upload_data_to_pipeline(data_type, data_value, metadata)
            self._mark_data_as_uploaded(data_hash)
            return True
        except Exception:
            print(f"❌ Error occurred during data upload for hash {data_hash}!")
            return False

    def _mark_data_as_uploaded(self, data_hash):
        self.cursor.execute(
            "UPDATE data_sources SET is_uploaded = 1 WHERE hash = ? AND pipeline_id = ?",
            (data_hash, self.local_id),
        )
        self.connection.commit()

    def get_data_sources(self):
        db_data = self.cursor.execute("SELECT * FROM data_sources WHERE pipeline_id = ?", (self.local_id,)).fetchall()

        data_sources = []
        for data in db_data:
            data_sources.append({"data_type": data[2], "data_value": data[3], "metadata": data[4]})

        return data_sources

    def deploy(self):
        if self.client is None:
            self._init_client()

        pipeline_data = self._create_pipeline()
        self.id = pipeline_data["id"]

        results = self.cursor.execute(
            "SELECT * FROM data_sources WHERE pipeline_id = ? AND is_uploaded = 0", (self.local_id,)  # noqa:E501
        ).fetchall()

        if len(results) > 0:
            print("🛠️ Adding data to your pipeline...")
        for result in results:
            data_hash, data_type, data_value = result[1], result[2], result[3]
            self._process_and_upload_data(data_hash, data_type, data_value)

        # Send anonymous telemetry
        self.telemetry.capture(event_name="deploy", properties=self._telemetry_props)

    @classmethod
    def from_config(
        cls,
        config_path: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None,
        auto_deploy: bool = False,
        yaml_path: Optional[str] = None,
    ):
        """
        Instantiate an App object from a configuration.

        :param config_path: Path to the YAML or JSON configuration file.
        :type config_path: Optional[str]
        :param config: A dictionary containing the configuration.
        :type config: Optional[Dict[str, Any]]
        :param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
        :type auto_deploy: bool, optional
        :param yaml_path: (Deprecated) Path to the YAML configuration file. Use config_path instead.
        :type yaml_path: Optional[str]
        :return: An instance of the App class.
        :rtype: App
        """
        # Backward compatibility for yaml_path
        if yaml_path and not config_path:
            config_path = yaml_path

        if config_path and config:
            raise ValueError("Please provide only one of config_path or config.")

        config_data = None

        if config_path:
            file_extension = os.path.splitext(config_path)[1]
            with open(config_path, "r") as file:
                if file_extension in [".yaml", ".yml"]:
                    config_data = yaml.safe_load(file)
                elif file_extension == ".json":
                    config_data = json.load(file)
                else:
                    raise ValueError("config_path must be a path to a YAML or JSON file.")
        elif config and isinstance(config, dict):
            config_data = config
        else:
            logging.error(
                "Please provide either a config file path (YAML or JSON) or a config dictionary. Falling back to defaults because no config is provided.",  # noqa: E501
            )
            config_data = {}

        try:
            validate_config(config_data)
        except Exception as e:
            raise Exception(f"Error occurred while validating the config. Error: {str(e)}")

        app_config_data = config_data.get("app", {}).get("config", {})
        db_config_data = config_data.get("vectordb", {})
        embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
        llm_config_data = config_data.get("llm", {})
        chunker_config_data = config_data.get("chunker", {})

        app_config = AppConfig(**app_config_data)

        db_provider = db_config_data.get("provider", "chroma")
        db = VectorDBFactory.create(db_provider, db_config_data.get("config", {}))

        if llm_config_data:
            llm_provider = llm_config_data.get("provider", "openai")
            llm = LlmFactory.create(llm_provider, llm_config_data.get("config", {}))
        else:
            llm = None

        embedding_model_provider = embedding_model_config_data.get("provider", "openai")
        embedding_model = EmbedderFactory.create(
            embedding_model_provider, embedding_model_config_data.get("config", {})
        )

        # Send anonymous telemetry
        event_properties = {"init_type": "config_data"}
        AnonymousTelemetry().capture(event_name="init", properties=event_properties)

        return cls(
            config=app_config,
            llm=llm,
            db=db,
            embedding_model=embedding_model,
            config_data=config_data,
            auto_deploy=auto_deploy,
            chunker=chunker_config_data,
        )
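
Taken together, `from_config`, `add`, and `deploy` now cover the whole local-to-platform flow in one class. A minimal end-to-end sketch, assuming a `config.yaml` whose `app`/`llm`/`vectordb`/`embedding_model` sections pass `validate_config` (the file name and data source are illustrative):

    from embedchain import App

    app = App.from_config(config_path="config.yaml")
    app.add("https://example.com/docs")  # hypothetical data source
    app.deploy()  # uploads any pending data_sources rows to the platform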

embedchain/apps/app.py (deleted file)
@@ -1,157 +0,0 @@
from typing import Optional

import yaml

from embedchain.config import (AppConfig, BaseEmbedderConfig, BaseLlmConfig,
                               ChunkerConfig)
from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.embedchain import EmbedChain
from embedchain.embedder.base import BaseEmbedder
from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
from embedchain.llm.openai import OpenAILlm
from embedchain.utils import validate_config
from embedchain.vectordb.base import BaseVectorDB
from embedchain.vectordb.chroma import ChromaDB


@register_deserializable
class App(EmbedChain):
    """
    The EmbedChain app in its simplest and most straightforward form.
    An opinionated choice of LLM, vector database and embedding model.

    Methods:
    add(source, data_type): adds the data from the given URL to the vector db.
    query(query): finds an answer to the given query using the vector database and the LLM.
    chat(query): finds an answer to the given query using the vector database and the LLM, with conversation history.
    """

    def __init__(
        self,
        config: Optional[AppConfig] = None,
        llm: BaseLlm = None,
        llm_config: Optional[BaseLlmConfig] = None,
        db: BaseVectorDB = None,
        db_config: Optional[BaseVectorDbConfig] = None,
        embedder: BaseEmbedder = None,
        embedder_config: Optional[BaseEmbedderConfig] = None,
        system_prompt: Optional[str] = None,
        chunker: Optional[ChunkerConfig] = None,
    ):
        """
        Initialize a new `App` instance.

        :param config: Config for the app instance, defaults to None
        :type config: Optional[AppConfig], optional
        :param llm: LLM class instance. example: `from embedchain.llm.openai import OpenAILlm`, defaults to OpenAILlm
        :type llm: BaseLlm, optional
        :param llm_config: Allows you to configure the LLM, e.g. how many documents to return,
        example: `from embedchain.config import BaseLlmConfig`, defaults to None
        :type llm_config: Optional[BaseLlmConfig], optional
        :param db: The database to use for storing and retrieving embeddings,
        example: `from embedchain.vectordb.chroma_db import ChromaDb`, defaults to ChromaDb
        :type db: BaseVectorDB, optional
        :param db_config: Allows you to configure the vector database,
        example: `from embedchain.config import ChromaDbConfig`, defaults to None
        :type db_config: Optional[BaseVectorDbConfig], optional
        :param embedder: The embedder (embedding model and function) used to calculate embeddings.
        example: `from embedchain.embedder.gpt4all_embedder import GPT4AllEmbedder`, defaults to OpenAIEmbedder
        :type embedder: BaseEmbedder, optional
        :param embedder_config: Allows you to configure the Embedder.
        example: `from embedchain.config import BaseEmbedderConfig`, defaults to None
        :type embedder_config: Optional[BaseEmbedderConfig], optional
        :param system_prompt: System prompt that will be provided to the LLM as such, defaults to None
        :type system_prompt: Optional[str], optional
        :raises TypeError: If the LLM, database, or embedder (or their config) is not a valid class instance.
        """
        # Type check configs
        if config and not isinstance(config, AppConfig):
            raise TypeError(
                "Config is not an `AppConfig` instance. "
                "Please make sure the type is right and that you are passing an instance."
            )
        if llm_config and not isinstance(llm_config, BaseLlmConfig):
            raise TypeError(
                "`llm_config` is not a `BaseLlmConfig` instance. "
                "Please make sure the type is right and that you are passing an instance."
            )
        if db_config and not isinstance(db_config, BaseVectorDbConfig):
            raise TypeError(
                "`db_config` is not a `BaseVectorDbConfig` instance. "
                "Please make sure the type is right and that you are passing an instance."
            )
        if embedder_config and not isinstance(embedder_config, BaseEmbedderConfig):
            raise TypeError(
                "`embedder_config` is not a `BaseEmbedderConfig` instance. "
                "Please make sure the type is right and that you are passing an instance."
            )

        # Assign defaults
        if config is None:
            config = AppConfig()
        if llm is None:
            llm = OpenAILlm(config=llm_config)
        if db is None:
            db = ChromaDB(config=db_config)
        if embedder is None:
            embedder = OpenAIEmbedder(config=embedder_config)

        self.chunker = None
        if chunker:
            self.chunker = ChunkerConfig(**chunker)

        # Type check assignments
        if not isinstance(llm, BaseLlm):
            raise TypeError(
                "LLM is not a `BaseLlm` instance. "
                "Please make sure the type is right and that you are passing an instance."
            )
        if not isinstance(db, BaseVectorDB):
            raise TypeError(
                "Database is not a `BaseVectorDB` instance. "
                "Please make sure the type is right and that you are passing an instance."
            )
        if not isinstance(embedder, BaseEmbedder):
            raise TypeError(
                "Embedder is not a `BaseEmbedder` instance. "
                "Please make sure the type is right and that you are passing an instance."
            )
        super().__init__(config, llm=llm, db=db, embedder=embedder, system_prompt=system_prompt)

    @classmethod
    def from_config(cls, yaml_path: str):
        """
        Instantiate an App object from a YAML configuration file.

        :param yaml_path: Path to the YAML configuration file.
        :type yaml_path: str
        :return: An instance of the App class.
        :rtype: App
        """
        with open(yaml_path, "r") as file:
            config_data = yaml.safe_load(file)

        try:
            validate_config(config_data)
        except Exception as e:
            raise Exception(f"❌ Error occurred while validating the YAML config. Error: {str(e)}")

        app_config_data = config_data.get("app", {})
        llm_config_data = config_data.get("llm", {})
        db_config_data = config_data.get("vectordb", {})
        embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
        chunker_config_data = config_data.get("chunker", {})

        app_config = AppConfig(**app_config_data.get("config", {}))

        llm_provider = llm_config_data.get("provider", "openai")
        llm = LlmFactory.create(llm_provider, llm_config_data.get("config", {}))

        db_provider = db_config_data.get("provider", "chroma")
        db = VectorDBFactory.create(db_provider, db_config_data.get("config", {}))

        embedder_provider = embedding_model_config_data.get("provider", "openai")
        embedder = EmbedderFactory.create(embedder_provider, embedding_model_config_data.get("config", {}))
        return cls(config=app_config, llm=llm, db=db, embedder=embedder, chunker=chunker_config_data)
@@ -1,7 +1,7 @@
 from typing import Any

-from embedchain import Pipeline as App
-from embedchain.config import AddConfig, BaseLlmConfig, PipelineConfig
+from embedchain import App
+from embedchain.config import AddConfig, AppConfig, BaseLlmConfig
 from embedchain.embedder.openai import OpenAIEmbedder
 from embedchain.helpers.json_serializable import (JSONSerializable,
                                                   register_deserializable)
@@ -12,7 +12,7 @@ from embedchain.vectordb.chroma import ChromaDB
 @register_deserializable
 class BaseBot(JSONSerializable):
     def __init__(self):
-        self.app = App(config=PipelineConfig(), llm=OpenAILlm(), db=ChromaDB(), embedding_model=OpenAIEmbedder())
+        self.app = App(config=AppConfig(), llm=OpenAILlm(), db=ChromaDB(), embedding_model=OpenAIEmbedder())

     def add(self, data: Any, config: AddConfig = None):
         """
@@ -1,12 +1,11 @@
 # flake8: noqa: F401

 from .add_config import AddConfig, ChunkerConfig
-from .apps.app_config import AppConfig
+from .app_config import AppConfig
 from .base_config import BaseConfig
 from .embedder.base import BaseEmbedderConfig
 from .embedder.base import BaseEmbedderConfig as EmbedderConfig
 from .llm.base import BaseLlmConfig
-from .pipeline_config import PipelineConfig
 from .vectordb.chroma import ChromaDbConfig
 from .vectordb.elasticsearch import ElasticsearchDBConfig
 from .vectordb.opensearch import OpenSearchDBConfig
@@ -15,8 +15,9 @@ class AppConfig(BaseAppConfig):
         self,
         log_level: str = "WARNING",
         id: Optional[str] = None,
+        name: Optional[str] = None,
         collect_metrics: Optional[bool] = True,
-        collection_name: Optional[str] = None,
+        **kwargs,
     ):
         """
         Initializes a configuration class instance for an App. This is the simplest form of an embedchain app.
@@ -28,8 +29,6 @@ class AppConfig(BaseAppConfig):
         :type id: Optional[str], optional
         :param collect_metrics: Send anonymous telemetry to improve embedchain, defaults to True
         :type collect_metrics: Optional[bool], optional
-        :param collection_name: Default collection name. It's recommended to use app.db.set_collection_name() instead,
-        defaults to None
-        :type collection_name: Optional[str], optional
         """
-        super().__init__(log_level=log_level, id=id, collect_metrics=collect_metrics, collection_name=collection_name)
+        self.name = name
+        super().__init__(log_level=log_level, id=id, collect_metrics=collect_metrics, **kwargs)
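
With this change `AppConfig` absorbs the old `PipelineConfig` fields: `name` becomes a first-class parameter and any remaining keyword arguments are forwarded to `BaseAppConfig`. A small sketch (the values are illustrative):

    from embedchain.config import AppConfig

    config = AppConfig(name="my-app", collect_metrics=False)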

embedchain/config/pipeline_config.py (deleted file)
@@ -1,38 +0,0 @@
from typing import Optional

from embedchain.helpers.json_serializable import register_deserializable

from .apps.base_app_config import BaseAppConfig


@register_deserializable
class PipelineConfig(BaseAppConfig):
    """
    Config to initialize an embedchain custom `App` instance, with extra config options.
    """

    def __init__(
        self,
        log_level: str = "WARNING",
        id: Optional[str] = None,
        name: Optional[str] = None,
        collect_metrics: Optional[bool] = True,
    ):
        """
        Initializes a configuration class instance for an App. This is the simplest form of an embedchain app.
        Most of the configuration is done in the `App` class itself.

        :param log_level: Debug level ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], defaults to "WARNING"
        :type log_level: str, optional
        :param id: ID of the app. Document metadata will have this id., defaults to None
        :type id: Optional[str], optional
        :param collect_metrics: Send anonymous telemetry to improve embedchain, defaults to True
        :type collect_metrics: Optional[bool], optional
        :param collection_name: Default collection name. It's recommended to use app.db.set_collection_name() instead,
        defaults to None
        :type collection_name: Optional[str], optional
        """
        self._setup_logging(log_level)
        self.id = id
        self.name = name
        self.collect_metrics = collect_metrics
@@ -2,7 +2,7 @@ import os

 import gradio as gr

-from embedchain import Pipeline as App
+from embedchain import App

 os.environ["OPENAI_API_KEY"] = "sk-xxx"

@@ -1,6 +1,6 @@
 import streamlit as st

-from embedchain import Pipeline as App
+from embedchain import App


 @st.cache_resource
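
The hunk is truncated right at the decorator; a sketch of the cached factory such Streamlit demos typically define under it (the function name and body are assumptions, not part of this diff):

    import streamlit as st

    from embedchain import App


    @st.cache_resource
    def embedchain_bot():
        # Cache a single App per session so the vector DB is not
        # re-initialized on every Streamlit rerun.
        return App()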
@@ -9,7 +9,7 @@ from langchain.docstore.document import Document

 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
-from embedchain.config.apps.base_app_config import BaseAppConfig
+from embedchain.config.base_app_config import BaseAppConfig
 from embedchain.constants import SQLITE_PATH
 from embedchain.data_formatter import DataFormatter
 from embedchain.embedder.base import BaseEmbedder

embedchain/pipeline.py
@@ -1,425 +1,9 @@
-import ast
-import json
-import logging
-import os
-import sqlite3
-import uuid
-from typing import Any, Dict, Optional
-
-import requests
-import yaml
-
-from embedchain import Client
-from embedchain.config import ChunkerConfig, PipelineConfig
-from embedchain.constants import SQLITE_PATH
-from embedchain.embedchain import EmbedChain
-from embedchain.embedder.base import BaseEmbedder
-from embedchain.embedder.openai import OpenAIEmbedder
-from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
-from embedchain.helpers.json_serializable import register_deserializable
-from embedchain.llm.base import BaseLlm
-from embedchain.llm.openai import OpenAILlm
-from embedchain.telemetry.posthog import AnonymousTelemetry
-from embedchain.utils import validate_config
-from embedchain.vectordb.base import BaseVectorDB
-from embedchain.vectordb.chroma import ChromaDB
-
-# Setup the user directory if doesn't exist already
-Client.setup_dir()
+from embedchain.app import App


-@register_deserializable
-class Pipeline(EmbedChain):
+class Pipeline(App):
     """
-    EmbedChain pipeline lets you create a LLM powered app for your unstructured
-    data by defining a pipeline with your chosen data source, embedding model,
-    and vector database.
+    This is deprecated. Use `App` instead.
     """
-
-    def __init__(
-        self,
-        id: str = None,
-        name: str = None,
-        config: PipelineConfig = None,
-        db: BaseVectorDB = None,
-        embedding_model: BaseEmbedder = None,
-        llm: BaseLlm = None,
-        config_data: dict = None,
-        log_level=logging.WARN,
-        auto_deploy: bool = False,
-        chunker: ChunkerConfig = None,
-    ):
-        """
-        Initialize a new `App` instance.
-
-        :param config: Configuration for the pipeline, defaults to None
-        :type config: PipelineConfig, optional
-        :param db: The database to use for storing and retrieving embeddings, defaults to None
-        :type db: BaseVectorDB, optional
-        :param embedding_model: The embedding model used to calculate embeddings, defaults to None
-        :type embedding_model: BaseEmbedder, optional
-        :param llm: The LLM model used to calculate embeddings, defaults to None
-        :type llm: BaseLlm, optional
-        :param config_data: Config dictionary, defaults to None
-        :type config_data: dict, optional
-        :param log_level: Log level to use, defaults to logging.WARN
-        :type log_level: int, optional
-        :param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
-        :type auto_deploy: bool, optional
-        :raises Exception: If an error occurs while creating the pipeline
-        """
-        if id and config_data:
-            raise Exception("Cannot provide both id and config. Please provide only one of them.")
-
-        if id and name:
-            raise Exception("Cannot provide both id and name. Please provide only one of them.")
-
-        if name and config:
-            raise Exception("Cannot provide both name and config. Please provide only one of them.")
-
-        logging.basicConfig(level=log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
-        self.logger = logging.getLogger(__name__)
-        self.auto_deploy = auto_deploy
-        # Store the dict config as an attribute to be able to send it
-        self.config_data = config_data if (config_data and validate_config(config_data)) else None
-        self.client = None
-        # pipeline_id from the backend
-        self.id = None
-        self.chunker = None
-        if chunker:
-            self.chunker = ChunkerConfig(**chunker)
-
-        self.config = config or PipelineConfig()
-        self.name = self.config.name
-        self.config.id = self.local_id = str(uuid.uuid4()) if self.config.id is None else self.config.id
-
-        if id is not None:
-            # Init client first since user is trying to fetch the pipeline
-            # details from the platform
-            self._init_client()
-            pipeline_details = self._get_pipeline(id)
-            self.config.id = self.local_id = pipeline_details["metadata"]["local_id"]
-            self.id = id
-
-        if name is not None:
-            self.name = name
-
-        self.embedding_model = embedding_model or OpenAIEmbedder()
-        self.db = db or ChromaDB()
-        self.llm = llm or OpenAILlm()
-        self._init_db()
-
-        # Send anonymous telemetry
-        self._telemetry_props = {"class": self.__class__.__name__}
-        self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
-
-        # Establish a connection to the SQLite database
-        self.connection = sqlite3.connect(SQLITE_PATH, check_same_thread=False)
-        self.cursor = self.connection.cursor()
-
-        # Create the 'data_sources' table if it doesn't exist
-        self.cursor.execute(
-            """
-            CREATE TABLE IF NOT EXISTS data_sources (
-                pipeline_id TEXT,
-                hash TEXT,
-                type TEXT,
-                value TEXT,
-                metadata TEXT,
-                is_uploaded INTEGER DEFAULT 0,
-                PRIMARY KEY (pipeline_id, hash)
-            )
-        """
-        )
-        self.connection.commit()
-        # Send anonymous telemetry
-        self.telemetry.capture(event_name="init", properties=self._telemetry_props)
-
-        self.user_asks = []
-        if self.auto_deploy:
-            self.deploy()
-
-    def _init_db(self):
-        """
-        Initialize the database.
-        """
-        self.db._set_embedder(self.embedding_model)
-        self.db._initialize()
-        self.db.set_collection_name(self.db.config.collection_name)
-
-    def _init_client(self):
-        """
-        Initialize the client.
-        """
-        config = Client.load_config()
-        if config.get("api_key"):
-            self.client = Client()
-        else:
-            api_key = input(
-                "🔑 Enter your Embedchain API key. You can find the API key at https://app.embedchain.ai/settings/keys/ \n"  # noqa: E501
-            )
-            self.client = Client(api_key=api_key)
-
-    def _get_pipeline(self, id):
-        """
-        Get existing pipeline
-        """
-        print("🛠️ Fetching pipeline details from the platform...")
-        url = f"{self.client.host}/api/v1/pipelines/{id}/cli/"
-        r = requests.get(
-            url,
-            headers={"Authorization": f"Token {self.client.api_key}"},
-        )
-        if r.status_code == 404:
-            raise Exception(f"❌ Pipeline with id {id} not found!")
-
-        print(
-            f"🎉 Pipeline loaded successfully! Pipeline url: https://app.embedchain.ai/pipelines/{r.json()['id']}\n"  # noqa: E501
-        )
-        return r.json()
-
-    def _create_pipeline(self):
-        """
-        Create a pipeline on the platform.
-        """
-        print("🛠️ Creating pipeline on the platform...")
-        # self.config_data is a dict. Pass it inside the key 'yaml_config' to the backend
-        payload = {
-            "yaml_config": json.dumps(self.config_data),
-            "name": self.name,
-            "local_id": self.local_id,
-        }
-        url = f"{self.client.host}/api/v1/pipelines/cli/create/"
-        r = requests.post(
-            url,
-            json=payload,
-            headers={"Authorization": f"Token {self.client.api_key}"},
-        )
-        if r.status_code not in [200, 201]:
-            raise Exception(f"❌ Error occurred while creating pipeline. API response: {r.text}")
-
-        if r.status_code == 200:
-            print(
-                f"🎉🎉🎉 Existing pipeline found! View your pipeline: https://app.embedchain.ai/pipelines/{r.json()['id']}\n"  # noqa: E501
-            )
-        elif r.status_code == 201:
-            print(
-                f"🎉🎉🎉 Pipeline created successfully! View your pipeline: https://app.embedchain.ai/pipelines/{r.json()['id']}\n"  # noqa: E501
-            )
-        return r.json()
-
-    def _get_presigned_url(self, data_type, data_value):
-        payload = {"data_type": data_type, "data_value": data_value}
-        r = requests.post(
-            f"{self.client.host}/api/v1/pipelines/{self.id}/cli/presigned_url/",
-            json=payload,
-            headers={"Authorization": f"Token {self.client.api_key}"},
-        )
-        r.raise_for_status()
-        return r.json()
-
-    def search(self, query, num_documents=3):
-        """
-        Search for similar documents related to the query in the vector database.
-        """
-        # Send anonymous telemetry
-        self.telemetry.capture(event_name="search", properties=self._telemetry_props)
-
-        # TODO: Search will call the endpoint rather than fetching the data from the db itself when deploy=True.
-        if self.id is None:
-            where = {"app_id": self.local_id}
-            context = self.db.query(
-                query,
-                n_results=num_documents,
-                where=where,
-                skip_embedding=False,
-                citations=True,
-            )
-            result = []
-            for c in context:
-                result.append({"context": c[0], "metadata": c[1]})
-            return result
-        else:
-            # Make API call to the backend to get the results
-            NotImplementedError("Search is not implemented yet for the prod mode.")
-
-    def _upload_file_to_presigned_url(self, presigned_url, file_path):
-        try:
-            with open(file_path, "rb") as file:
-                response = requests.put(presigned_url, data=file)
-                response.raise_for_status()
-                return response.status_code == 200
-        except Exception as e:
-            self.logger.exception(f"Error occurred during file upload: {str(e)}")
-            print("❌ Error occurred during file upload!")
-            return False
-
-    def _upload_data_to_pipeline(self, data_type, data_value, metadata=None):
-        payload = {
-            "data_type": data_type,
-            "data_value": data_value,
-            "metadata": metadata,
-        }
-        try:
-            self._send_api_request(f"/api/v1/pipelines/{self.id}/cli/add/", payload)
-            # print the local file path if user tries to upload a local file
-            printed_value = metadata.get("file_path") if metadata.get("file_path") else data_value
-            print(f"✅ Data of type: {data_type}, value: {printed_value} added successfully.")
-        except Exception as e:
-            print(f"❌ Error occurred during data upload for type {data_type}! Error: {str(e)}")
-
-    def _send_api_request(self, endpoint, payload):
-        url = f"{self.client.host}{endpoint}"
-        headers = {"Authorization": f"Token {self.client.api_key}"}
-        response = requests.post(url, json=payload, headers=headers)
-        response.raise_for_status()
-        return response
-
-    def _process_and_upload_data(self, data_hash, data_type, data_value):
-        if os.path.isabs(data_value):
-            presigned_url_data = self._get_presigned_url(data_type, data_value)
-            presigned_url = presigned_url_data["presigned_url"]
-            s3_key = presigned_url_data["s3_key"]
-            if self._upload_file_to_presigned_url(presigned_url, file_path=data_value):
-                metadata = {"file_path": data_value, "s3_key": s3_key}
-                data_value = presigned_url
-            else:
-                self.logger.error(f"File upload failed for hash: {data_hash}")
-                return False
-        else:
-            if data_type == "qna_pair":
-                data_value = list(ast.literal_eval(data_value))
-            metadata = {}
-
-        try:
-            self._upload_data_to_pipeline(data_type, data_value, metadata)
-            self._mark_data_as_uploaded(data_hash)
-            return True
-        except Exception:
-            print(f"❌ Error occurred during data upload for hash {data_hash}!")
-            return False
-
-    def _mark_data_as_uploaded(self, data_hash):
-        self.cursor.execute(
-            "UPDATE data_sources SET is_uploaded = 1 WHERE hash = ? AND pipeline_id = ?",
-            (data_hash, self.local_id),
-        )
-        self.connection.commit()
-
-    def get_data_sources(self):
-        db_data = self.cursor.execute("SELECT * FROM data_sources WHERE pipeline_id = ?", (self.local_id,)).fetchall()
-
-        data_sources = []
-        for data in db_data:
-            data_sources.append({"data_type": data[2], "data_value": data[3], "metadata": data[4]})
-
-        return data_sources
-
-    def deploy(self):
-        if self.client is None:
-            self._init_client()
-
-        pipeline_data = self._create_pipeline()
-        self.id = pipeline_data["id"]
-
-        results = self.cursor.execute(
-            "SELECT * FROM data_sources WHERE pipeline_id = ? AND is_uploaded = 0", (self.local_id,)  # noqa:E501
-        ).fetchall()
-
-        if len(results) > 0:
-            print("🛠️ Adding data to your pipeline...")
-        for result in results:
-            data_hash, data_type, data_value = result[1], result[2], result[3]
-            self._process_and_upload_data(data_hash, data_type, data_value)
-
-        # Send anonymous telemetry
-        self.telemetry.capture(event_name="deploy", properties=self._telemetry_props)
-
-    @classmethod
-    def from_config(
-        cls,
-        config_path: Optional[str] = None,
-        config: Optional[Dict[str, Any]] = None,
-        auto_deploy: bool = False,
-        yaml_path: Optional[str] = None,
-    ):
-        """
-        Instantiate a Pipeline object from a configuration.
-
-        :param config_path: Path to the YAML or JSON configuration file.
-        :type config_path: Optional[str]
-        :param config: A dictionary containing the configuration.
-        :type config: Optional[Dict[str, Any]]
-        :param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
-        :type auto_deploy: bool, optional
-        :param yaml_path: (Deprecated) Path to the YAML configuration file. Use config_path instead.
-        :type yaml_path: Optional[str]
-        :return: An instance of the Pipeline class.
-        :rtype: Pipeline
-        """
-        # Backward compatibility for yaml_path
-        if yaml_path and not config_path:
-            config_path = yaml_path
-
-        if config_path and config:
-            raise ValueError("Please provide only one of config_path or config.")
-
-        config_data = None
-
-        if config_path:
-            file_extension = os.path.splitext(config_path)[1]
-            with open(config_path, "r") as file:
-                if file_extension in [".yaml", ".yml"]:
-                    config_data = yaml.safe_load(file)
-                elif file_extension == ".json":
-                    config_data = json.load(file)
-                else:
-                    raise ValueError("config_path must be a path to a YAML or JSON file.")
-        elif config and isinstance(config, dict):
-            config_data = config
-        else:
-            logging.error(
-                "Please provide either a config file path (YAML or JSON) or a config dictionary. Falling back to defaults because no config is provided.",  # noqa: E501
-            )
-            config_data = {}
-
-        try:
-            validate_config(config_data)
-        except Exception as e:
-            raise Exception(f"Error occurred while validating the config. Error: {str(e)}")
-
-        pipeline_config_data = config_data.get("app", {}).get("config", {})
-        db_config_data = config_data.get("vectordb", {})
-        embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
-        llm_config_data = config_data.get("llm", {})
-        chunker_config_data = config_data.get("chunker", {})
-
-        pipeline_config = PipelineConfig(**pipeline_config_data)
-
-        db_provider = db_config_data.get("provider", "chroma")
-        db = VectorDBFactory.create(db_provider, db_config_data.get("config", {}))
-
-        if llm_config_data:
-            llm_provider = llm_config_data.get("provider", "openai")
-            llm = LlmFactory.create(llm_provider, llm_config_data.get("config", {}))
-        else:
-            llm = None
-
-        embedding_model_provider = embedding_model_config_data.get("provider", "openai")
-        embedding_model = EmbedderFactory.create(
-            embedding_model_provider, embedding_model_config_data.get("config", {})
-        )
-
-        # Send anonymous telemetry
-        event_properties = {"init_type": "config_data"}
-        AnonymousTelemetry().capture(event_name="init", properties=event_properties)
-
-        return cls(
-            config=pipeline_config,
-            llm=llm,
-            db=db,
-            embedding_model=embedding_model,
-            config_data=config_data,
-            auto_deploy=auto_deploy,
-            chunker=chunker_config_data,
-        )
+
+    pass
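
For downstream code, the deprecated alias keeps old imports working while pointing at the converged class; a sketch of the equivalence:

    from embedchain import App, Pipeline

    app = App()          # preferred after this refactor
    legacy = Pipeline()  # deprecated alias; behaves like App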