From 3979480532e69dfc7ae73134fd1f998073806c37 Mon Sep 17 00:00:00 2001
From: Deshraj Yadav
Date: Wed, 25 Oct 2023 13:36:24 -0700
Subject: [PATCH] [Feature] Add support for deploying local pipelines to
 Embedchain platform (#847)

---
 embedchain/__init__.py           |   1 +
 embedchain/client.py             | 102 +++++++++++++
 embedchain/config/__init__.py    |   1 -
 embedchain/embedchain.py         |  68 +++++++--
 embedchain/pipeline.py           | 251 ++++++++++++++++++++++++++++---
 tests/conftest.py                |  16 ++
 tests/embedchain/test_add.py     |  15 +-
 tests/test_client.py             |  53 +++++++
 tests/vectordb/test_chroma_db.py |   3 +-
 9 files changed, 467 insertions(+), 43 deletions(-)
 create mode 100644 embedchain/client.py
 create mode 100644 tests/conftest.py
 create mode 100644 tests/test_client.py

diff --git a/embedchain/__init__.py b/embedchain/__init__.py
index 31cc26fd..14d913f2 100644
--- a/embedchain/__init__.py
+++ b/embedchain/__init__.py
@@ -3,5 +3,6 @@ import importlib.metadata
 __version__ = importlib.metadata.version(__package__ or __name__)

 from embedchain.apps.app import App  # noqa: F401
+from embedchain.client import Client  # noqa: F401
 from embedchain.pipeline import Pipeline  # noqa: F401
 from embedchain.vectordb.chroma import ChromaDB  # noqa: F401

diff --git a/embedchain/client.py b/embedchain/client.py
new file mode 100644
index 00000000..56f2d190
--- /dev/null
+++ b/embedchain/client.py
@@ -0,0 +1,102 @@
+import json
+import logging
+import os
+import uuid
+
+import requests
+
+from embedchain.embedchain import CONFIG_DIR, CONFIG_FILE
+
+
+class Client:
+    def __init__(self, api_key=None, host="https://apiv2.embedchain.ai"):
+        self.config_data = self.load_config()
+        self.host = host
+
+        if api_key:
+            if self.check(api_key):
+                self.api_key = api_key
+                self.save()
+            else:
+                raise ValueError(
+                    "Invalid API key provided. You can find your API key on https://app.embedchain.ai/settings/keys."
+                )
+        else:
+            if "api_key" in self.config_data:
+                self.api_key = self.config_data["api_key"]
+                logging.info("API key loaded successfully!")
+            else:
+                raise ValueError(
+                    "You are not logged in. Please obtain an API key from https://app.embedchain.ai/settings/keys/"
+                )
+
+    @classmethod
+    def setup_dir(cls):
+        """
+        Loads the user id from the config file if it exists, otherwise generates a new
+        one and saves it to the config file.
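+
+        Illustrative flow (a sketch; the paths come from the CONFIG_DIR and
+        CONFIG_FILE constants imported above)::
+
+            Client.setup_dir()             # ensures ~/.embedchain/config.json exists
+            config = Client.load_config()  # -> {"user_id": "..."}, plus "api_key" once saved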
+
+        :return: user id
+        :rtype: str
+        """
+        if not os.path.exists(CONFIG_DIR):
+            os.makedirs(CONFIG_DIR)
+
+        if os.path.exists(CONFIG_FILE):
+            with open(CONFIG_FILE, "r") as f:
+                data = json.load(f)
+                if "user_id" in data:
+                    return data["user_id"]
+
+        u_id = str(uuid.uuid4())
+        with open(CONFIG_FILE, "w") as f:
+            json.dump({"user_id": u_id}, f)
+        return u_id
+
+    @classmethod
+    def load_config(cls):
+        if not os.path.exists(CONFIG_FILE):
+            cls.setup_dir()
+
+        with open(CONFIG_FILE, "r") as config_file:
+            return json.load(config_file)
+
+    def save(self):
+        self.config_data["api_key"] = self.api_key
+        with open(CONFIG_FILE, "w") as config_file:
+            json.dump(self.config_data, config_file, indent=4)
+
+        logging.info("API key saved successfully!")
+
+    def clear(self):
+        if "api_key" in self.config_data:
+            del self.config_data["api_key"]
+            with open(CONFIG_FILE, "w") as config_file:
+                json.dump(self.config_data, config_file, indent=4)
+            self.api_key = None
+            logging.info("API key deleted successfully!")
+        else:
+            logging.warning("API key not found in the configuration file.")
+
+    def update(self, api_key):
+        if self.check(api_key):
+            self.api_key = api_key
+            self.save()
+            logging.info("API key updated successfully!")
+        else:
+            logging.warning("Invalid API key provided. API key not updated.")
+
+    def check(self, api_key):
+        validation_url = f"{self.host}/api/v1/accounts/api_keys/validate/"
+        response = requests.post(validation_url, headers={"Authorization": f"Token {api_key}"})
+        if response.status_code == 200:
+            return True
+        else:
+            logging.warning(f"Response from API: {response.text}")
+            logging.warning("Invalid API key. Unable to validate.")
+            return False
+
+    def get(self):
+        return self.api_key
+
+    def __str__(self):
+        return self.api_key

diff --git a/embedchain/config/__init__.py b/embedchain/config/__init__.py
index 7d9d63c7..46fce1e5 100644
--- a/embedchain/config/__init__.py
+++ b/embedchain/config/__init__.py
@@ -2,7 +2,6 @@
 from .add_config import AddConfig, ChunkerConfig
 from .apps.app_config import AppConfig
-from .pipeline_config import PipelineConfig
 from .base_config import BaseConfig
 from .embedder.base import BaseEmbedderConfig
 from .embedder.base import BaseEmbedderConfig as EmbedderConfig

diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py
index 47544769..35c53139 100644
--- a/embedchain/embedchain.py
+++ b/embedchain/embedchain.py
@@ -3,6 +3,7 @@ import importlib.metadata
 import json
 import logging
 import os
+import sqlite3
 import threading
 import uuid
 from pathlib import Path
@@ -32,6 +33,7 @@ ABS_PATH = os.getcwd()
 HOME_DIR = str(Path.home())
 CONFIG_DIR = os.path.join(HOME_DIR, ".embedchain")
 CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
+SQLITE_PATH = os.path.join(CONFIG_DIR, "embedchain.db")


 class EmbedChain(JSONSerializable):
@@ -89,6 +91,27 @@ class EmbedChain(JSONSerializable):
         # Send anonymous telemetry
         self.s_id = self.config.id if self.config.id else str(uuid.uuid4())
         self.u_id = self._load_or_generate_user_id()
+
+        # Establish a connection to the SQLite database
+        self.connection = sqlite3.connect(SQLITE_PATH)
+        self.cursor = self.connection.cursor()
+
+        # Create the 'data_sources' table if it doesn't exist
+        self.cursor.execute(
+            """
+            CREATE TABLE IF NOT EXISTS data_sources (
+                pipeline_id TEXT,
+                hash TEXT,
+                type TEXT,
+                value TEXT,
+                metadata TEXT,
+                is_uploaded INTEGER DEFAULT 0,
+                PRIMARY KEY (pipeline_id, hash)
+            )
+            """
+        )
+        self.connection.commit()
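+        # Illustrative example of a row this table holds (values assumed):
+        #   ("pipeline-uuid", "5d41402a...", "web_page", "https://example.com", '{"foo": "bar"}', 0)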
         # NOTE: Uncomment the next two lines when running tests to see if any test fires a telemetry event.
         # if (self.config.collect_metrics):
         #     raise ConnectionRefusedError("Collection of metrics should not be allowed.")
@@ -163,7 +186,7 @@ class EmbedChain(JSONSerializable):
         :raises ValueError: Invalid data type
         :param dry_run: Optional. A dry run displays the chunks to ensure that the loader and chunker
         work as intended. Defaults to False.
-        :return: source_id, a md5-hash of the source, in hexadecimal representation.
+        :return: source_hash, an MD5 hash of the source, in hexadecimal representation.
         :rtype: str
         """
         if config is None:
@@ -192,18 +215,40 @@ class EmbedChain(JSONSerializable):
         if not data_type:
             data_type = detect_datatype(source)

-        # `source_id` is the hash of the source argument
+        # `source_hash` is the MD5 hash of the source argument
         hash_object = hashlib.md5(str(source).encode("utf-8"))
-        source_id = hash_object.hexdigest()
+        source_hash = hash_object.hexdigest()
+
+        # Check if the data hash already exists; if so, skip the addition
+        self.cursor.execute(
+            "SELECT 1 FROM data_sources WHERE hash = ? AND pipeline_id = ?", (source_hash, self.config.id)
+        )
+        existing_data = self.cursor.fetchone()
+
+        if existing_data:
+            print(f"Data with hash {source_hash} already exists. Skipping addition.")
+            return source_hash

         data_formatter = DataFormatter(data_type, config)
         self.user_asks.append([source, data_type.value, metadata])
         documents, metadatas, _ids, new_chunks = self.load_and_embed(
-            data_formatter.loader, data_formatter.chunker, source, metadata, source_id, dry_run
+            data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run
         )
         if data_type in {DataType.DOCS_SITE}:
             self.is_docs_site_instance = True

+        # Insert the data into the 'data_sources' table
+        self.cursor.execute(
+            """
+            INSERT INTO data_sources (hash, pipeline_id, type, value, metadata)
+            VALUES (?, ?, ?, ?, ?)
+            """,
+            (source_hash, self.config.id, data_type.value, str(source), json.dumps(metadata)),
+        )
+
+        # Commit the transaction
+        self.connection.commit()
+
         if dry_run:
             data_chunks_info = {"chunks": documents, "metadata": metadatas, "count": len(documents), "type": data_type}
             logging.debug(f"Dry run info : {data_chunks_info}")
@@ -218,7 +263,7 @@ class EmbedChain(JSONSerializable):
             thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("add", extra_metadata))
             thread_telemetry.start()

-        return source_id
+        return source_hash

     def add_local(
         self,
@@ -245,7 +290,7 @@ class EmbedChain(JSONSerializable):
         :param config: The `AddConfig` instance to use as configuration options, defaults to None
         :type config: Optional[AddConfig], optional
         :raises ValueError: Invalid data type
-        :return: source_id, a md5-hash of the source, in hexadecimal representation.
+        :return: source_hash, an MD5 hash of the source, in hexadecimal representation.
         :rtype: str
         """
         logging.warning(
@@ -313,7 +358,7 @@ class EmbedChain(JSONSerializable):
         chunker: BaseChunker,
         src: Any,
         metadata: Optional[Dict[str, Any]] = None,
-        source_id: Optional[str] = None,
+        source_hash: Optional[str] = None,
         dry_run=False,
     ):
         """
@@ -324,7 +369,7 @@ class EmbedChain(JSONSerializable):
         :param src: The data to be handled by the loader. Can be a URL for remote sources or local content for
         local loaders.
         :param metadata: Optional. Metadata associated with the data source.
-        :param source_id: Hexadecimal hash of the source.
+        :param source_hash: Hexadecimal hash of the source.
         :param dry_run: Optional. A dry run returns chunks and doesn't update DB.
         :type dry_run: bool, defaults to False
         :return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
@@ -382,7 +427,7 @@ class EmbedChain(JSONSerializable):
             m["app_id"] = self.config.id

             # Add hashed source
-            m["hash"] = source_id
+            m["hash"] = source_hash

             # Note: Metadata is the function argument
             if metadata:
@@ -558,15 +603,14 @@ class EmbedChain(JSONSerializable):
         """
         Resets the database. Deletes all embeddings irreversibly.
         `App` does not have to be reinitialized after using this method.
-
-        DEPRECATED IN FAVOR OF `db.reset()`
         """
         # Send anonymous telemetry
         thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("reset",))
         thread_telemetry.start()

-        logging.warning("DEPRECATION WARNING: Please use `app.db.reset()` instead of `App.reset()`.")
         self.db.reset()
+        self.cursor.execute("DELETE FROM data_sources WHERE pipeline_id = ?", (self.config.id,))
+        self.connection.commit()

     @retry(stop=stop_after_attempt(3), wait=wait_fixed(1))
     def _send_telemetry_event(self, method: str, extra_metadata: Optional[dict] = None):

diff --git a/embedchain/pipeline.py b/embedchain/pipeline.py
index e07cf995..e2562c92 100644
--- a/embedchain/pipeline.py
+++ b/embedchain/pipeline.py
@@ -1,17 +1,27 @@
-import threading
+import ast
+import json
+import logging
+import os
+import sqlite3
 import uuid

+import requests
 import yaml
+from fastapi import FastAPI, HTTPException

+from embedchain.client import Client
-from embedchain.config import PipelineConfig
+from embedchain.config.pipeline_config import PipelineConfig
-from embedchain.embedchain import EmbedChain
+from embedchain.embedchain import CONFIG_DIR, EmbedChain
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.embedder.openai import OpenAIEmbedder
 from embedchain.factory import EmbedderFactory, VectorDBFactory
 from embedchain.helper.json_serializable import register_deserializable
+from embedchain.llm.base import BaseLlm
 from embedchain.vectordb.base import BaseVectorDB
 from embedchain.vectordb.chroma import ChromaDB

+SQLITE_PATH = os.path.join(CONFIG_DIR, "embedchain.db")
+

 @register_deserializable
 class Pipeline(EmbedChain):
@@ -21,7 +31,15 @@ class Pipeline(EmbedChain):
     and vector database.
     """

-    def __init__(self, config: PipelineConfig = None, db: BaseVectorDB = None, embedding_model: BaseEmbedder = None):
+    def __init__(
+        self,
+        config: PipelineConfig = None,
+        db: BaseVectorDB = None,
+        embedding_model: BaseEmbedder = None,
+        llm: BaseLlm = None,
+        yaml_path: str = None,
+        log_level=logging.INFO,
+    ):
         """
         Initialize a new `Pipeline` instance.
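 
         Example (an illustrative sketch; `pipeline.yaml` and the data source are
         assumed, not part of this diff)::
 
             pipeline = Pipeline.from_config(yaml_path="pipeline.yaml")
             pipeline.add("https://example.com")
             pipeline.deploy()  # prompts for an Embedchain API key if none is saved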
@@ -32,42 +50,196 @@ class Pipeline(EmbedChain):
         :param embedding_model: The embedding model used to calculate embeddings, defaults to None
         :type embedding_model: BaseEmbedder, optional
         """
-        super().__init__()
+        logging.basicConfig(level=log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+        self.logger = logging.getLogger(__name__)
+        # Store the yaml config as an attribute to be able to send it
+        self.yaml_config = None
+        self.client = None
+        if yaml_path:
+            with open(yaml_path, "r") as file:
+                config_data = yaml.safe_load(file)
+                self.yaml_config = config_data
+
         self.config = config or PipelineConfig()
         self.name = self.config.name
-        self.id = self.config.id or str(uuid.uuid4())
+        self.local_id = self.config.id or str(uuid.uuid4())
         self.embedding_model = embedding_model or OpenAIEmbedder()
         self.db = db or ChromaDB()
-        self._initialize_db()
+        self.llm = llm
+        self._init_db()
+
+        # setup user id and directory
+        self.u_id = self._load_or_generate_user_id()
+
+        # Establish a connection to the SQLite database
+        self.connection = sqlite3.connect(SQLITE_PATH)
+        self.cursor = self.connection.cursor()
+
+        # Create the 'data_sources' table if it doesn't exist
+        self.cursor.execute(
+            """
+            CREATE TABLE IF NOT EXISTS data_sources (
+                pipeline_id TEXT,
+                hash TEXT,
+                type TEXT,
+                value TEXT,
+                metadata TEXT,
+                is_uploaded INTEGER DEFAULT 0,
+                PRIMARY KEY (pipeline_id, hash)
+            )
+            """
+        )
+        self.connection.commit()

         self.user_asks = []  # legacy defaults
-        self.s_id = self.config.id or str(uuid.uuid4())
-        self.u_id = self._load_or_generate_user_id()
-
-        thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("pipeline_init",))
-        thread_telemetry.start()

-    def _initialize_db(self):
+    def _init_db(self):
         """
         Initialize the database.
         """
         self.db._set_embedder(self.embedding_model)
         self.db._initialize()
-        self.db.set_collection_name(self.name)
+        self.db.set_collection_name(self.db.config.collection_name)
+
+    def _init_client(self):
+        """
+        Initialize the client.
+        """
+        config = Client.load_config()
+        if config.get("api_key"):
+            self.client = Client()
+        else:
+            api_key = input("Enter API key. You can find the API key at https://app.embedchain.ai/settings/keys/ \n")
+            self.client = Client(api_key=api_key)
+
+    def _create_pipeline(self):
+        """
+        Create a pipeline on the platform.
+        """
+        print("Creating pipeline on the platform...")
+        # self.yaml_config is a dict. Pass it inside the key 'yaml_config' to the backend
+        payload = {
+            "yaml_config": json.dumps(self.yaml_config),
+            "name": self.name,
+            "local_id": self.local_id,
+        }
+        url = f"{self.client.host}/api/v1/pipelines/cli/create/"
+        r = requests.post(
+            url,
+            json=payload,
+            headers={"Authorization": f"Token {self.client.api_key}"},
+        )
+        if r.status_code not in [200, 201]:
+            raise Exception(f"Error occurred while creating pipeline. Response from API: {r.text}")
+
+        print(f"Pipeline created. Link: https://app.embedchain.ai/pipelines/{r.json()['id']}")
+        return r.json()
+
+    def _get_presigned_url(self, data_type, data_value):
+        payload = {"data_type": data_type, "data_value": data_value}
+        r = requests.post(
+            f"{self.client.host}/api/v1/pipelines/{self.id}/cli/presigned_url/",
+            json=payload,
+            headers={"Authorization": f"Token {self.client.api_key}"},
+        )
+        r.raise_for_status()
+        return r.json()

     def search(self, query, num_documents=3):
         """
         Search for similar documents related to the query in the vector database.
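 
         Example (illustrative)::
 
             results = pipeline.search("What is Embedchain?", num_documents=2)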
""" - where = {"app_id": self.id} - return self.db.query( - query, - n_results=num_documents, - where=where, - skip_embedding=False, + # TODO: Search will call the endpoint rather than fetching the data from the db itself when deploy=True. + if self.deploy is False: + where = {"app_id": self.local_id} + return self.db.query( + query, + n_results=num_documents, + where=where, + skip_embedding=False, + ) + else: + # Make API call to the backend to get the results + NotImplementedError("Search is not implemented yet for the prod mode.") + + def _upload_file_to_presigned_url(self, presigned_url, file_path): + try: + with open(file_path, "rb") as file: + response = requests.put(presigned_url, data=file) + response.raise_for_status() + return response.status_code == 200 + except Exception as e: + self.logger.exception(f"Error occurred during file upload: {str(e)}") + return False + + def _upload_data_to_pipeline(self, data_type, data_value, metadata=None): + payload = { + "data_type": data_type, + "data_value": data_value, + "metadata": metadata, + } + return self._send_api_request(f"/api/v1/pipelines/{self.id}/cli/add/", payload) + + def _send_api_request(self, endpoint, payload): + url = f"{self.client.host}{endpoint}" + headers = {"Authorization": f"Token {self.client.api_key}"} + response = requests.post(url, json=payload, headers=headers) + response.raise_for_status() + return response + + def _process_and_upload_data(self, data_hash, data_type, data_value): + if os.path.isabs(data_value): + presigned_url_data = self._get_presigned_url(data_type, data_value) + presigned_url = presigned_url_data["presigned_url"] + s3_key = presigned_url_data["s3_key"] + if self._upload_file_to_presigned_url(presigned_url, file_path=data_value): + data_value = presigned_url + metadata = {"file_path": data_value, "s3_key": s3_key} + else: + self.logger.error(f"File upload failed for hash: {data_hash}") + return False + else: + if data_type == "qna_pair": + data_value = list(ast.literal_eval(data_value)) + metadata = {} + + try: + self._upload_data_to_pipeline(data_type, data_value, metadata) + self._mark_data_as_uploaded(data_hash) + self.logger.info(f"Data of type {data_type} uploaded successfully.") + return True + except Exception as e: + self.logger.error(f"Error occurred during data upload: {str(e)}") + return False + + def _mark_data_as_uploaded(self, data_hash): + self.cursor.execute( + "UPDATE data_sources SET is_uploaded = 1 WHERE hash = ? AND pipeline_id = ? AND is_uploaded = 0", + (data_hash, self.local_id), ) + self.connection.commit() + + def deploy(self): + try: + if self.client is None: + self._init_client() + + pipeline_data = self._create_pipeline() + self.id = pipeline_data["id"] + + results = self.cursor.execute( + "SELECT * FROM data_sources WHERE pipeline_id = ? 
+
+            for result in results:
+                data_hash, data_type, data_value = result
+                if self._process_and_upload_data(data_hash, data_type, data_value):
+                    self.logger.info(f"Data with hash {data_hash} uploaded successfully.")
+
+        except Exception as e:
+            self.logger.exception(f"Error occurred during deployment: {str(e)}")
+            raise HTTPException(status_code=500, detail="Error occurred during deployment.")

     @classmethod
     def from_config(cls, yaml_path: str):
@@ -82,7 +254,7 @@ class Pipeline(EmbedChain):
         with open(yaml_path, "r") as file:
             config_data = yaml.safe_load(file)

-        pipeline_config_data = config_data.get("pipeline", {})
+        pipeline_config_data = config_data.get("pipeline", {}).get("config", {})
         db_config_data = config_data.get("vectordb", {})
         embedding_model_config_data = config_data.get("embedding_model", {})
@@ -95,4 +267,39 @@ class Pipeline(EmbedChain):
         embedding_model = EmbedderFactory.create(
             embedding_model_provider, embedding_model_config_data.get("config", {})
         )
-        return cls(config=pipeline_config, db=db, embedding_model=embedding_model)
+        return cls(
+            config=pipeline_config,
+            db=db,
+            embedding_model=embedding_model,
+            yaml_path=yaml_path,
+        )
+
+    def start(self, host="0.0.0.0", port=8000):
+        app = FastAPI()
+
+        @app.post("/add")
+        async def add_document(data_value: str, data_type: str = None):
+            """
+            Add a document to the pipeline.
+            """
+            try:
+                self.add(data_value, data_type=data_type)
+                return {"message": "Document added successfully"}
+            except Exception as e:
+                raise HTTPException(status_code=500, detail=str(e))
+
+        @app.post("/query")
+        async def query_documents(query: str, num_documents: int = 3):
+            """
+            Query for similar documents in the pipeline.
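+
+            Example request (illustrative; FastAPI reads these parameters from
+            the query string)::
+
+                curl -X POST "http://localhost:8000/query?query=hello&num_documents=3"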
+ """ + try: + results = self.search(query, num_documents) + return results + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + import uvicorn + + uvicorn.run(app, host=host, port=port) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..2465fa72 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,16 @@ +import os + +import pytest + + +def clean_db(): + db_path = os.path.expanduser("~/.embedchain/embedchain.db") + if os.path.exists(db_path): + os.remove(db_path) + + +@pytest.fixture +def setup(): + clean_db() + yield + clean_db() diff --git a/tests/embedchain/test_add.py b/tests/embedchain/test_add.py index b1250943..7152f728 100644 --- a/tests/embedchain/test_add.py +++ b/tests/embedchain/test_add.py @@ -16,19 +16,20 @@ def app(mocker): def test_add(app): - app.add("https://example.com", metadata={"meta": "meta-data"}) - assert app.user_asks == [["https://example.com", "web_page", {"meta": "meta-data"}]] + app.add("https://example.com", metadata={"foo": "bar"}) + assert app.user_asks == [["https://example.com", "web_page", {"foo": "bar"}]] -def test_add_sitemap(app): - app.add("https://www.google.com/sitemap.xml", metadata={"meta": "meta-data"}) - assert app.user_asks == [["https://www.google.com/sitemap.xml", "sitemap", {"meta": "meta-data"}]] +# TODO: Make this test faster by generating a sitemap locally rather than using a remote one +# def test_add_sitemap(app): +# app.add("https://www.google.com/sitemap.xml", metadata={"foo": "bar"}) +# assert app.user_asks == [["https://www.google.com/sitemap.xml", "sitemap", {"foo": "bar"}]] def test_add_forced_type(app): data_type = "text" - app.add("https://example.com", data_type=data_type, metadata={"meta": "meta-data"}) - assert app.user_asks == [["https://example.com", data_type, {"meta": "meta-data"}]] + app.add("https://example.com", data_type=data_type, metadata={"foo": "bar"}) + assert app.user_asks == [["https://example.com", data_type, {"foo": "bar"}]] def test_dry_run(app): diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 00000000..5259ecd6 --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,53 @@ +import pytest + +from embedchain import Client + + +class TestClient: + @pytest.fixture + def mock_requests_post(self, mocker): + return mocker.patch("embedchain.client.requests.post") + + def test_valid_api_key(self, mock_requests_post): + mock_requests_post.return_value.status_code = 200 + client = Client(api_key="valid_api_key") + assert client.check("valid_api_key") is True + + def test_invalid_api_key(self, mock_requests_post): + mock_requests_post.return_value.status_code = 401 + with pytest.raises(ValueError): + Client(api_key="invalid_api_key") + + def test_update_valid_api_key(self, mock_requests_post): + mock_requests_post.return_value.status_code = 200 + client = Client(api_key="valid_api_key") + client.update("new_valid_api_key") + assert client.get() == "new_valid_api_key" + + def test_clear_api_key(self, mock_requests_post): + mock_requests_post.return_value.status_code = 200 + client = Client(api_key="valid_api_key") + client.clear() + assert client.get() is None + + def test_save_api_key(self, mock_requests_post): + mock_requests_post.return_value.status_code = 200 + api_key_to_save = "valid_api_key" + client = Client(api_key=api_key_to_save) + client.save() + assert client.get() == api_key_to_save + + def test_load_api_key_from_config(self, mocker): + mocker.patch("embedchain.Client.load_config", return_value={"api_key": 
"test_api_key"}) + client = Client() + assert client.get() == "test_api_key" + + def test_load_invalid_api_key_from_config(self, mocker): + mocker.patch("embedchain.Client.load_config", return_value={}) + with pytest.raises(ValueError): + Client() + + def test_load_missing_api_key_from_config(self, mocker): + mocker.patch("embedchain.Client.load_config", return_value={}) + with pytest.raises(ValueError): + Client() diff --git a/tests/vectordb/test_chroma_db.py b/tests/vectordb/test_chroma_db.py index 3477d717..b6861410 100644 --- a/tests/vectordb/test_chroma_db.py +++ b/tests/vectordb/test_chroma_db.py @@ -1,9 +1,10 @@ import os import shutil -import pytest from unittest.mock import patch +import pytest from chromadb.config import Settings + from embedchain import App from embedchain.config import AppConfig, ChromaDbConfig from embedchain.vectordb.chroma import ChromaDB