[Bug fix] Fix issues related to creating pipelines (#850)

This commit is contained in:
Deshraj Yadav
2023-10-25 20:26:58 -07:00
committed by GitHub
parent 797bb567c6
commit 413ccb83e6
3 changed files with 81 additions and 13 deletions

3
.gitignore vendored
View File

@@ -173,4 +173,5 @@ test-db
.DS_Store .DS_Store
notebooks/*.yaml notebooks/*.yaml
.ipynb_checkpoints/ .ipynb_checkpoints/
!configs/*.yaml

26
configs/pipeline.yaml Normal file
View File

@@ -0,0 +1,26 @@
pipeline:
config:
name: Example pipeline
id: pipeline-1 # Make sure that id is different every time you create a new pipeline
vectordb:
provider: chroma
config:
collection_name: pipeline-1
dir: db
allow_reset: true
llm:
provider: gpt4all
config:
model: 'orca-mini-3b.ggmlv3.q4_0.bin'
temperature: 0.5
max_tokens: 1000
top_p: 1
stream: false
embedding_model:
provider: gpt4all
config:
model: 'all-MiniLM-L6-v2'
deployment_name: null

View File

@@ -14,7 +14,7 @@ from embedchain.config import PipelineConfig
from embedchain.embedchain import CONFIG_DIR, EmbedChain from embedchain.embedchain import CONFIG_DIR, EmbedChain
from embedchain.embedder.base import BaseEmbedder from embedchain.embedder.base import BaseEmbedder
from embedchain.embedder.openai import OpenAIEmbedder from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.factory import EmbedderFactory, VectorDBFactory from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
from embedchain.helper.json_serializable import register_deserializable from embedchain.helper.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm from embedchain.llm.base import BaseLlm
from embedchain.vectordb.base import BaseVectorDB from embedchain.vectordb.base import BaseVectorDB
@@ -39,6 +39,7 @@ class Pipeline(EmbedChain):
llm: BaseLlm = None, llm: BaseLlm = None,
yaml_path: str = None, yaml_path: str = None,
log_level=logging.INFO, log_level=logging.INFO,
auto_deploy: bool = False,
): ):
""" """
Initialize a new `App` instance. Initialize a new `App` instance.
@@ -49,12 +50,26 @@ class Pipeline(EmbedChain):
:type db: BaseVectorDB, optional :type db: BaseVectorDB, optional
:param embedding_model: The embedding model used to calculate embeddings, defaults to None :param embedding_model: The embedding model used to calculate embeddings, defaults to None
:type embedding_model: BaseEmbedder, optional :type embedding_model: BaseEmbedder, optional
:param llm: The LLM model used to calculate embeddings, defaults to None
:type llm: BaseLlm, optional
:param yaml_path: Path to the YAML configuration file, defaults to None
:type yaml_path: str, optional
:param log_level: Log level to use, defaults to logging.INFO
:type log_level: int, optional
:param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
:type auto_deploy: bool, optional
:raises Exception: If an error occurs while creating the pipeline
""" """
logging.basicConfig(level=log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") logging.basicConfig(level=log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
self.auto_deploy = auto_deploy
# Store the yaml config as an attribute to be able to send it # Store the yaml config as an attribute to be able to send it
self.yaml_config = None self.yaml_config = None
self.client = None self.client = None
# pipeline_id from the backend
self.id = None
if yaml_path: if yaml_path:
with open(yaml_path, "r") as file: with open(yaml_path, "r") as file:
config_data = yaml.safe_load(file) config_data = yaml.safe_load(file)
@@ -84,7 +99,7 @@ class Pipeline(EmbedChain):
hash TEXT, hash TEXT,
type TEXT, type TEXT,
value TEXT, value TEXT,
metadata TEXT metadata TEXT,
is_uploaded INTEGER DEFAULT 0, is_uploaded INTEGER DEFAULT 0,
PRIMARY KEY (pipeline_id, hash) PRIMARY KEY (pipeline_id, hash)
) )
@@ -93,6 +108,8 @@ class Pipeline(EmbedChain):
self.connection.commit() self.connection.commit()
self.user_asks = [] # legacy defaults self.user_asks = [] # legacy defaults
if self.auto_deploy:
self.deploy()
def _init_db(self): def _init_db(self):
""" """
@@ -110,14 +127,16 @@ class Pipeline(EmbedChain):
if config.get("api_key"): if config.get("api_key"):
self.client = Client() self.client = Client()
else: else:
api_key = input("Enter API key. You can find the API key at https://app.embedchain.ai/settings/keys/ \n") api_key = input(
"Enter Embedchain API key. You can find the API key at https://app.embedchain.ai/settings/keys/ \n"
)
self.client = Client(api_key=api_key) self.client = Client(api_key=api_key)
def _create_pipeline(self): def _create_pipeline(self):
""" """
Create a pipeline on the platform. Create a pipeline on the platform.
""" """
print("Creating pipeline on the platform...") print("🛠️ Creating pipeline on the platform...")
# self.yaml_config is a dict. Pass it inside the key 'yaml_config' to the backend # self.yaml_config is a dict. Pass it inside the key 'yaml_config' to the backend
payload = { payload = {
"yaml_config": json.dumps(self.yaml_config), "yaml_config": json.dumps(self.yaml_config),
@@ -133,7 +152,9 @@ class Pipeline(EmbedChain):
if r.status_code not in [200, 201]: if r.status_code not in [200, 201]:
raise Exception(f"Error occurred while creating pipeline. Response from API: {r.text}") raise Exception(f"Error occurred while creating pipeline. Response from API: {r.text}")
print(f"Pipeline created. link: https://app.embedchain.ai/pipelines/{r.json()['id']}") print(
f"🎉🎉🎉 Pipeline created successfully! View your pipeline: https://app.embedchain.ai/pipelines/{r.json()['id']}\n" # noqa: E501
)
return r.json() return r.json()
def _get_presigned_url(self, data_type, data_value): def _get_presigned_url(self, data_type, data_value):
@@ -151,7 +172,7 @@ class Pipeline(EmbedChain):
Search for similar documents related to the query in the vector database. Search for similar documents related to the query in the vector database.
""" """
# TODO: Search will call the endpoint rather than fetching the data from the db itself when deploy=True. # TODO: Search will call the endpoint rather than fetching the data from the db itself when deploy=True.
if self.deploy is False: if self.id is None:
where = {"app_id": self.local_id} where = {"app_id": self.local_id}
return self.db.query( return self.db.query(
query, query,
@@ -171,6 +192,7 @@ class Pipeline(EmbedChain):
return response.status_code == 200 return response.status_code == 200
except Exception as e: except Exception as e:
self.logger.exception(f"Error occurred during file upload: {str(e)}") self.logger.exception(f"Error occurred during file upload: {str(e)}")
print("❌ Error occurred during file upload!")
return False return False
def _upload_data_to_pipeline(self, data_type, data_value, metadata=None): def _upload_data_to_pipeline(self, data_type, data_value, metadata=None):
@@ -179,7 +201,14 @@ class Pipeline(EmbedChain):
"data_value": data_value, "data_value": data_value,
"metadata": metadata, "metadata": metadata,
} }
return self._send_api_request(f"/api/v1/pipelines/{self.id}/cli/add/", payload) try:
self._send_api_request(f"/api/v1/pipelines/{self.id}/cli/add/", payload)
# print the local file path if user tries to upload a local file
printed_value = metadata.get("file_path") if metadata.get("file_path") else data_value
print(f"✅ Data of type: {data_type}, value: {printed_value} added successfully.")
except Exception as e:
self.logger.error(f"Error occurred during data upload: {str(e)}")
print(f"❌ Error occurred during data upload for type {data_type}!")
def _send_api_request(self, endpoint, payload): def _send_api_request(self, endpoint, payload):
url = f"{self.client.host}{endpoint}" url = f"{self.client.host}{endpoint}"
@@ -194,8 +223,8 @@ class Pipeline(EmbedChain):
presigned_url = presigned_url_data["presigned_url"] presigned_url = presigned_url_data["presigned_url"]
s3_key = presigned_url_data["s3_key"] s3_key = presigned_url_data["s3_key"]
if self._upload_file_to_presigned_url(presigned_url, file_path=data_value): if self._upload_file_to_presigned_url(presigned_url, file_path=data_value):
data_value = presigned_url
metadata = {"file_path": data_value, "s3_key": s3_key} metadata = {"file_path": data_value, "s3_key": s3_key}
data_value = presigned_url
else: else:
self.logger.error(f"File upload failed for hash: {data_hash}") self.logger.error(f"File upload failed for hash: {data_hash}")
return False return False
@@ -207,10 +236,10 @@ class Pipeline(EmbedChain):
try: try:
self._upload_data_to_pipeline(data_type, data_value, metadata) self._upload_data_to_pipeline(data_type, data_value, metadata)
self._mark_data_as_uploaded(data_hash) self._mark_data_as_uploaded(data_hash)
self.logger.info(f"Data of type {data_type} uploaded successfully.")
return True return True
except Exception as e: except Exception as e:
self.logger.error(f"Error occurred during data upload: {str(e)}") self.logger.error(f"Error occurred during data upload: {str(e)}")
print(f"❌ Error occurred during data upload for hash {data_hash}!")
return False return False
def _mark_data_as_uploaded(self, data_hash): def _mark_data_as_uploaded(self, data_hash):
@@ -232,22 +261,25 @@ class Pipeline(EmbedChain):
"SELECT * FROM data_sources WHERE pipeline_id = ? AND is_uploaded = 0", (self.local_id,) "SELECT * FROM data_sources WHERE pipeline_id = ? AND is_uploaded = 0", (self.local_id,)
).fetchall() ).fetchall()
if len(results) > 0:
print("🛠️ Adding data to your pipeline...")
for result in results: for result in results:
data_hash, data_type, data_value = result[0], result[2], result[3] data_hash, data_type, data_value = result[0], result[2], result[3]
if self._process_and_upload_data(data_hash, data_type, data_value): self._process_and_upload_data(data_hash, data_type, data_value)
self.logger.info(f"Data with hash {data_hash} uploaded successfully.")
except Exception as e: except Exception as e:
self.logger.exception(f"Error occurred during deployment: {str(e)}") self.logger.exception(f"Error occurred during deployment: {str(e)}")
raise HTTPException(status_code=500, detail="Error occurred during deployment.") raise HTTPException(status_code=500, detail="Error occurred during deployment.")
@classmethod @classmethod
def from_config(cls, yaml_path: str): def from_config(cls, yaml_path: str, auto_deploy: bool = False):
""" """
Instantiate a Pipeline object from a YAML configuration file. Instantiate a Pipeline object from a YAML configuration file.
:param yaml_path: Path to the YAML configuration file. :param yaml_path: Path to the YAML configuration file.
:type yaml_path: str :type yaml_path: str
:param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
:type auto_deploy: bool, optional
:return: An instance of the Pipeline class. :return: An instance of the Pipeline class.
:rtype: Pipeline :rtype: Pipeline
""" """
@@ -257,21 +289,30 @@ class Pipeline(EmbedChain):
pipeline_config_data = config_data.get("pipeline", {}).get("config", {}) pipeline_config_data = config_data.get("pipeline", {}).get("config", {})
db_config_data = config_data.get("vectordb", {}) db_config_data = config_data.get("vectordb", {})
embedding_model_config_data = config_data.get("embedding_model", {}) embedding_model_config_data = config_data.get("embedding_model", {})
llm_config_data = config_data.get("llm", {})
pipeline_config = PipelineConfig(**pipeline_config_data) pipeline_config = PipelineConfig(**pipeline_config_data)
db_provider = db_config_data.get("provider", "chroma") db_provider = db_config_data.get("provider", "chroma")
db = VectorDBFactory.create(db_provider, db_config_data.get("config", {})) db = VectorDBFactory.create(db_provider, db_config_data.get("config", {}))
if llm_config_data:
llm_provider = llm_config_data.get("provider", "openai")
llm = LlmFactory.create(llm_provider, llm_config_data.get("config", {}))
else:
llm = None
embedding_model_provider = embedding_model_config_data.get("provider", "openai") embedding_model_provider = embedding_model_config_data.get("provider", "openai")
embedding_model = EmbedderFactory.create( embedding_model = EmbedderFactory.create(
embedding_model_provider, embedding_model_config_data.get("config", {}) embedding_model_provider, embedding_model_config_data.get("config", {})
) )
return cls( return cls(
config=pipeline_config, config=pipeline_config,
llm=llm,
db=db, db=db,
embedding_model=embedding_model, embedding_model=embedding_model,
yaml_path=yaml_path, yaml_path=yaml_path,
auto_deploy=auto_deploy,
) )
def start(self, host="0.0.0.0", port=8000): def start(self, host="0.0.0.0", port=8000):