[Bug fix] Fix issues related to creating pipelines (#850)
.gitignore (vendored): +1 line

@@ -174,3 +174,4 @@ test-db
 
 notebooks/*.yaml
 .ipynb_checkpoints/
+!configs/*.yaml
configs/pipeline.yaml (new file): +26 lines

@@ -0,0 +1,26 @@
+pipeline:
+  config:
+    name: Example pipeline
+    id: pipeline-1 # Make sure that id is different every time you create a new pipeline
+
+vectordb:
+  provider: chroma
+  config:
+    collection_name: pipeline-1
+    dir: db
+    allow_reset: true
+
+llm:
+  provider: gpt4all
+  config:
+    model: 'orca-mini-3b.ggmlv3.q4_0.bin'
+    temperature: 0.5
+    max_tokens: 1000
+    top_p: 1
+    stream: false
+
+embedding_model:
+  provider: gpt4all
+  config:
+    model: 'all-MiniLM-L6-v2'
+    deployment_name: null
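Note: this config is consumed by `Pipeline.from_config` (see the Python changes below). A minimal loading sketch, assuming the class is importable as `from embedchain.pipeline import Pipeline` (the import path is inferred, not shown in this diff):

# Minimal sketch: load the new YAML config into a Pipeline.
# The import path is an assumption; it is not shown in this diff.
from embedchain.pipeline import Pipeline

# from_config reads the YAML, then builds the vector DB, the LLM, and the
# embedder via VectorDBFactory, LlmFactory, and EmbedderFactory.
pipeline = Pipeline.from_config(yaml_path="configs/pipeline.yaml")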
embedchain/pipeline.py

@@ -14,7 +14,7 @@ from embedchain.config import PipelineConfig
 from embedchain.embedchain import CONFIG_DIR, EmbedChain
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.embedder.openai import OpenAIEmbedder
-from embedchain.factory import EmbedderFactory, VectorDBFactory
+from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
 from embedchain.helper.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
 from embedchain.vectordb.base import BaseVectorDB
@@ -39,6 +39,7 @@ class Pipeline(EmbedChain):
         llm: BaseLlm = None,
         yaml_path: str = None,
         log_level=logging.INFO,
+        auto_deploy: bool = False,
     ):
         """
         Initialize a new `App` instance.
@@ -49,12 +50,26 @@ class Pipeline(EmbedChain):
         :type db: BaseVectorDB, optional
         :param embedding_model: The embedding model used to calculate embeddings, defaults to None
         :type embedding_model: BaseEmbedder, optional
+        :param llm: The LLM model used to calculate embeddings, defaults to None
+        :type llm: BaseLlm, optional
+        :param yaml_path: Path to the YAML configuration file, defaults to None
+        :type yaml_path: str, optional
+        :param log_level: Log level to use, defaults to logging.INFO
+        :type log_level: int, optional
+        :param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
+        :type auto_deploy: bool, optional
+        :raises Exception: If an error occurs while creating the pipeline
         """
         logging.basicConfig(level=log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
         self.logger = logging.getLogger(__name__)
+
+        self.auto_deploy = auto_deploy
+
         # Store the yaml config as an attribute to be able to send it
         self.yaml_config = None
         self.client = None
+        # pipeline_id from the backend
+        self.id = None
         if yaml_path:
             with open(yaml_path, "r") as file:
                 config_data = yaml.safe_load(file)
@@ -84,7 +99,7 @@ class Pipeline(EmbedChain):
                 hash TEXT,
                 type TEXT,
                 value TEXT,
-                metadata TEXT
+                metadata TEXT,
                 is_uploaded INTEGER DEFAULT 0,
                 PRIMARY KEY (pipeline_id, hash)
             )
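The restored comma after `metadata TEXT` is the core pipeline-creation bug. SQLite's permissive type-name grammar most likely parsed the old text as a single `metadata` column of type `TEXT is_uploaded INTEGER`, leaving the table without an `is_uploaded` column, so the later `SELECT ... WHERE is_uploaded = 0` failed. A quick sanity-check sketch of the corrected schema using only the stdlib (the leading `pipeline_id TEXT` column is inferred from the `PRIMARY KEY` clause and the `SELECT` further down; its exact type is an assumption):

# Sanity-check sketch for the corrected schema, stdlib only.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    """
    CREATE TABLE IF NOT EXISTS data_sources (
        pipeline_id TEXT,
        hash TEXT,
        type TEXT,
        value TEXT,
        metadata TEXT,
        is_uploaded INTEGER DEFAULT 0,
        PRIMARY KEY (pipeline_id, hash)
    )
    """
)
# With the comma restored, this query succeeds instead of raising
# "no such column: is_uploaded".
conn.execute("SELECT * FROM data_sources WHERE is_uploaded = 0").fetchall()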
@@ -93,6 +108,8 @@ class Pipeline(EmbedChain):
         self.connection.commit()

         self.user_asks = []  # legacy defaults
+        if self.auto_deploy:
+            self.deploy()

     def _init_db(self):
         """
@@ -110,14 +127,16 @@ class Pipeline(EmbedChain):
         if config.get("api_key"):
             self.client = Client()
         else:
-            api_key = input("Enter API key. You can find the API key at https://app.embedchain.ai/settings/keys/ \n")
+            api_key = input(
+                "Enter Embedchain API key. You can find the API key at https://app.embedchain.ai/settings/keys/ \n"
+            )
             self.client = Client(api_key=api_key)

     def _create_pipeline(self):
         """
         Create a pipeline on the platform.
         """
-        print("Creating pipeline on the platform...")
+        print("🛠️ Creating pipeline on the platform...")
         # self.yaml_config is a dict. Pass it inside the key 'yaml_config' to the backend
         payload = {
             "yaml_config": json.dumps(self.yaml_config),
@@ -133,7 +152,9 @@ class Pipeline(EmbedChain):
         if r.status_code not in [200, 201]:
             raise Exception(f"Error occurred while creating pipeline. Response from API: {r.text}")

-        print(f"Pipeline created. link: https://app.embedchain.ai/pipelines/{r.json()['id']}")
+        print(
+            f"🎉🎉🎉 Pipeline created successfully! View your pipeline: https://app.embedchain.ai/pipelines/{r.json()['id']}\n"  # noqa: E501
+        )
         return r.json()

     def _get_presigned_url(self, data_type, data_value):
@@ -151,7 +172,7 @@ class Pipeline(EmbedChain):
         Search for similar documents related to the query in the vector database.
         """
         # TODO: Search will call the endpoint rather than fetching the data from the db itself when deploy=True.
-        if self.deploy is False:
+        if self.id is None:
             where = {"app_id": self.local_id}
             return self.db.query(
                 query,
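This fixes a real guard bug: `self.deploy` is a bound method (it is called as `self.deploy()` in `__init__` above), so `self.deploy is False` is always False and the local-search branch was unreachable. Checking `self.id is None` (no backend id assigned yet) expresses the intended "not deployed" condition. A tiny sketch of the pitfall:

# Sketch of the pitfall: a bound method is never `False`.
class Demo:
    def deploy(self):
        pass

d = Demo()
print(d.deploy is False)  # False -- so `if self.deploy is False:` never ran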
@@ -171,6 +192,7 @@ class Pipeline(EmbedChain):
             return response.status_code == 200
         except Exception as e:
             self.logger.exception(f"Error occurred during file upload: {str(e)}")
+            print("❌ Error occurred during file upload!")
             return False

     def _upload_data_to_pipeline(self, data_type, data_value, metadata=None):
@@ -179,7 +201,14 @@ class Pipeline(EmbedChain):
             "data_value": data_value,
             "metadata": metadata,
         }
-        return self._send_api_request(f"/api/v1/pipelines/{self.id}/cli/add/", payload)
+        try:
+            self._send_api_request(f"/api/v1/pipelines/{self.id}/cli/add/", payload)
+            # print the local file path if user tries to upload a local file
+            printed_value = metadata.get("file_path") if metadata.get("file_path") else data_value
+            print(f"✅ Data of type: {data_type}, value: {printed_value} added successfully.")
+        except Exception as e:
+            self.logger.error(f"Error occurred during data upload: {str(e)}")
+            print(f"❌ Error occurred during data upload for type {data_type}!")

     def _send_api_request(self, endpoint, payload):
         url = f"{self.client.host}{endpoint}"
@@ -194,8 +223,8 @@ class Pipeline(EmbedChain):
             presigned_url = presigned_url_data["presigned_url"]
             s3_key = presigned_url_data["s3_key"]
             if self._upload_file_to_presigned_url(presigned_url, file_path=data_value):
-                data_value = presigned_url
                 metadata = {"file_path": data_value, "s3_key": s3_key}
+                data_value = presigned_url
             else:
                 self.logger.error(f"File upload failed for hash: {data_hash}")
                 return False
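The reorder matters because the old code overwrote `data_value` with the presigned URL before building `metadata`, so `metadata["file_path"]` recorded the presigned URL instead of the local path. A sketch with placeholder values:

# Placeholder values; the real ones come from the presigned-URL response.
data_value = "/tmp/report.pdf"                       # local file path
presigned_url = "https://s3.example.com/upload?sig"  # hypothetical URL

# New order: metadata captures the local path first...
metadata = {"file_path": data_value, "s3_key": "some-key"}
# ...then data_value is swapped to the URL for the upload payload.
data_value = presigned_url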
@@ -207,10 +236,10 @@ class Pipeline(EmbedChain):
         try:
             self._upload_data_to_pipeline(data_type, data_value, metadata)
             self._mark_data_as_uploaded(data_hash)
-            self.logger.info(f"Data of type {data_type} uploaded successfully.")
             return True
         except Exception as e:
             self.logger.error(f"Error occurred during data upload: {str(e)}")
+            print(f"❌ Error occurred during data upload for hash {data_hash}!")
             return False

     def _mark_data_as_uploaded(self, data_hash):
@@ -232,22 +261,25 @@ class Pipeline(EmbedChain):
                 "SELECT * FROM data_sources WHERE pipeline_id = ? AND is_uploaded = 0", (self.local_id,)
             ).fetchall()

+            if len(results) > 0:
+                print("🛠️ Adding data to your pipeline...")
             for result in results:
                 data_hash, data_type, data_value = result[0], result[2], result[3]
-                if self._process_and_upload_data(data_hash, data_type, data_value):
-                    self.logger.info(f"Data with hash {data_hash} uploaded successfully.")
+                self._process_and_upload_data(data_hash, data_type, data_value)

         except Exception as e:
             self.logger.exception(f"Error occurred during deployment: {str(e)}")
             raise HTTPException(status_code=500, detail="Error occurred during deployment.")

     @classmethod
-    def from_config(cls, yaml_path: str):
+    def from_config(cls, yaml_path: str, auto_deploy: bool = False):
         """
         Instantiate a Pipeline object from a YAML configuration file.

         :param yaml_path: Path to the YAML configuration file.
         :type yaml_path: str
+        :param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
+        :type auto_deploy: bool, optional
         :return: An instance of the Pipeline class.
         :rtype: Pipeline
         """
@@ -257,21 +289,30 @@ class Pipeline(EmbedChain):
         pipeline_config_data = config_data.get("pipeline", {}).get("config", {})
         db_config_data = config_data.get("vectordb", {})
         embedding_model_config_data = config_data.get("embedding_model", {})
+        llm_config_data = config_data.get("llm", {})

         pipeline_config = PipelineConfig(**pipeline_config_data)

         db_provider = db_config_data.get("provider", "chroma")
         db = VectorDBFactory.create(db_provider, db_config_data.get("config", {}))

+        if llm_config_data:
+            llm_provider = llm_config_data.get("provider", "openai")
+            llm = LlmFactory.create(llm_provider, llm_config_data.get("config", {}))
+        else:
+            llm = None
+
         embedding_model_provider = embedding_model_config_data.get("provider", "openai")
         embedding_model = EmbedderFactory.create(
             embedding_model_provider, embedding_model_config_data.get("config", {})
         )
         return cls(
             config=pipeline_config,
+            llm=llm,
             db=db,
             embedding_model=embedding_model,
             yaml_path=yaml_path,
+            auto_deploy=auto_deploy,
         )

     def start(self, host="0.0.0.0", port=8000):
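Putting the pieces together, an end-to-end sketch of the new surface; the import path is an assumption, while the `from_config` and `start` signatures are taken from this diff:

# End-to-end sketch; the import path is an assumption.
from embedchain.pipeline import Pipeline

# auto_deploy=True makes __init__ call self.deploy() once setup finishes,
# creating the pipeline on the platform and uploading pending data sources.
pipeline = Pipeline.from_config("configs/pipeline.yaml", auto_deploy=True)

# Serve the pipeline locally (signature shown in the context line above).
pipeline.start(host="0.0.0.0", port=8000)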