[Feature] Setup base for creating pipelines in embedchain (#834)
This commit is contained in:
@@ -3,4 +3,5 @@ import importlib.metadata
|
|||||||
__version__ = importlib.metadata.version(__package__ or __name__)
|
__version__ = importlib.metadata.version(__package__ or __name__)
|
||||||
|
|
||||||
from embedchain.apps.app import App # noqa: F401
|
from embedchain.apps.app import App # noqa: F401
|
||||||
|
from embedchain.pipeline import Pipeline # noqa: F401
|
||||||
from embedchain.vectordb.chroma import ChromaDB # noqa: F401
|
from embedchain.vectordb.chroma import ChromaDB # noqa: F401
|
||||||
|
|||||||
@@ -2,10 +2,12 @@
|
|||||||
|
|
||||||
from .add_config import AddConfig, ChunkerConfig
|
from .add_config import AddConfig, ChunkerConfig
|
||||||
from .apps.app_config import AppConfig
|
from .apps.app_config import AppConfig
|
||||||
|
from .pipeline_config import PipelineConfig
|
||||||
from .base_config import BaseConfig
|
from .base_config import BaseConfig
|
||||||
from .embedder.base import BaseEmbedderConfig
|
from .embedder.base import BaseEmbedderConfig
|
||||||
from .embedder.base import BaseEmbedderConfig as EmbedderConfig
|
from .embedder.base import BaseEmbedderConfig as EmbedderConfig
|
||||||
from .llm.base import BaseLlmConfig
|
from .llm.base import BaseLlmConfig
|
||||||
|
from .pipeline_config import PipelineConfig
|
||||||
from .vectordb.chroma import ChromaDbConfig
|
from .vectordb.chroma import ChromaDbConfig
|
||||||
from .vectordb.elasticsearch import ElasticsearchDBConfig
|
from .vectordb.elasticsearch import ElasticsearchDBConfig
|
||||||
from .vectordb.opensearch import OpenSearchDBConfig
|
from .vectordb.opensearch import OpenSearchDBConfig
|
||||||
|
|||||||
38
embedchain/config/pipeline_config.py
Normal file
38
embedchain/config/pipeline_config.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from embedchain.helper.json_serializable import register_deserializable
|
||||||
|
|
||||||
|
from .apps.base_app_config import BaseAppConfig
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
class PipelineConfig(BaseAppConfig):
    """
    Config to initialize an embedchain `Pipeline` instance.
    """

    def __init__(
        self,
        log_level: str = "WARNING",
        id: Optional[str] = None,
        name: Optional[str] = None,
        collect_metrics: Optional[bool] = False,
    ):
        """
        Initializes a configuration class instance for a Pipeline.

        The values are stored on the instance as given; no validation or
        defaulting beyond the signature defaults happens here.

        :param log_level: Debug level ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], defaults to "WARNING"
        :type log_level: str, optional
        :param id: ID of the pipeline, defaults to None
        :type id: Optional[str], optional
        :param name: Name of the pipeline, defaults to None
        :type name: Optional[str], optional
        :param collect_metrics: Whether to send anonymous telemetry to improve embedchain, defaults to False
        :type collect_metrics: Optional[bool], optional
        """
        # NOTE(review): the base-class initializer is not invoked; only logging
        # setup and the three attributes below are configured. Confirm nothing
        # else from BaseAppConfig is required by consumers of this config.
        self._setup_logging(log_level)
        self.id = id
        self.name = name
        self.collect_metrics = collect_metrics
|
||||||
98
embedchain/pipeline.py
Normal file
98
embedchain/pipeline.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
import threading
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from embedchain.config import PipelineConfig
|
||||||
|
from embedchain.embedchain import EmbedChain
|
||||||
|
from embedchain.embedder.base import BaseEmbedder
|
||||||
|
from embedchain.embedder.openai import OpenAIEmbedder
|
||||||
|
from embedchain.factory import EmbedderFactory, VectorDBFactory
|
||||||
|
from embedchain.helper.json_serializable import register_deserializable
|
||||||
|
from embedchain.vectordb.base import BaseVectorDB
|
||||||
|
from embedchain.vectordb.chroma import ChromaDB
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
class Pipeline(EmbedChain):
    """
    EmbedChain pipeline lets you create a LLM powered app for your unstructured
    data by defining a pipeline with your chosen data source, embedding model,
    and vector database.
    """

    def __init__(self, config: PipelineConfig = None, db: BaseVectorDB = None, embedding_model: BaseEmbedder = None):
        """
        Initialize a new `Pipeline` instance.

        :param config: Configuration for the pipeline, defaults to None
            (a default `PipelineConfig()` is created)
        :type config: PipelineConfig, optional
        :param db: The database to use for storing and retrieving embeddings, defaults to None
            (a default `ChromaDB()` is created)
        :type db: BaseVectorDB, optional
        :param embedding_model: The embedding model used to calculate embeddings, defaults to None
            (a default `OpenAIEmbedder()` is created)
        :type embedding_model: BaseEmbedder, optional
        """
        # NOTE(review): the parent initializer is called without arguments;
        # confirm EmbedChain.__init__ accepts that.
        super().__init__()
        self.config = config or PipelineConfig()
        self.name = self.config.name
        # Fall back to a random UUID when no id was configured.
        self.id = self.config.id or str(uuid.uuid4())

        self.embedding_model = embedding_model or OpenAIEmbedder()
        self.db = db or ChromaDB()
        self._initialize_db()

        self.user_asks = []  # legacy defaults

        # When no id is configured this is a second, independently generated
        # UUID, i.e. it is not equal to `self.id`.
        # NOTE(review): confirm the two ids are meant to differ in that case.
        self.s_id = self.config.id or str(uuid.uuid4())
        self.u_id = self._load_or_generate_user_id()

        # Send the "pipeline_init" telemetry event from a separate thread.
        thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("pipeline_init",))
        thread_telemetry.start()

    def _initialize_db(self):
        """
        Initialize the database.

        Attaches the embedding model to the database, runs the database's own
        initialization, and then selects the collection named after the
        pipeline when a name is configured.
        """
        self.db._set_embedder(self.embedding_model)
        self.db._initialize()
        # `PipelineConfig.name` defaults to None; only select a collection when
        # a name was actually provided, otherwise keep the database's default.
        if self.name is not None:
            self.db.set_collection_name(self.name)

    def search(self, query, num_documents=3):
        """
        Search for similar documents related to the query in the vector database.

        The lookup is restricted to entries whose `app_id` equals this
        pipeline's id.

        :param query: The query forwarded to the vector database.
        :param num_documents: Number of results to request, defaults to 3
        :type num_documents: int, optional
        :return: Whatever the database's `query` method returns.
        """
        where = {"app_id": self.id}
        return self.db.query(
            query,
            n_results=num_documents,
            where=where,
            skip_embedding=False,
        )

    @classmethod
    def from_config(cls, yaml_path: str):
        """
        Instantiate a Pipeline object from a YAML configuration file.

        Recognized top-level keys (all optional):

        - ``pipeline``: keyword arguments for `PipelineConfig`.
        - ``vectordb``: ``provider`` (defaults to "chroma") and ``config``.
        - ``embedding_model``: ``provider`` (defaults to "openai") and ``config``.

        An empty file, or a section that is present but empty (null), is
        treated the same as a missing section.

        :param yaml_path: Path to the YAML configuration file.
        :type yaml_path: str
        :return: An instance of the Pipeline class.
        :rtype: Pipeline
        """
        with open(yaml_path, "r", encoding="utf-8") as file:
            # `safe_load` returns None for an empty document; normalize to a dict.
            config_data = yaml.safe_load(file) or {}

        # A key with no children (e.g. "pipeline:") parses to None, so `or {}`
        # is needed in addition to the `.get` default.
        pipeline_config_data = config_data.get("pipeline") or {}
        db_config_data = config_data.get("vectordb") or {}
        embedding_model_config_data = config_data.get("embedding_model") or {}

        pipeline_config = PipelineConfig(**pipeline_config_data)

        db_provider = db_config_data.get("provider") or "chroma"
        db = VectorDBFactory.create(db_provider, db_config_data.get("config") or {})

        embedding_model_provider = embedding_model_config_data.get("provider") or "openai"
        embedding_model = EmbedderFactory.create(
            embedding_model_provider, embedding_model_config_data.get("config") or {}
        )
        return cls(config=pipeline_config, db=db, embedding_model=embedding_model)
|
||||||
Reference in New Issue
Block a user