[Feature] Add support for deploying local pipelines to Embedchain platform (#847)

Deshraj Yadav
2023-10-25 13:36:24 -07:00
committed by GitHub
parent 76f1993e7a
commit 3979480532
9 changed files with 467 additions and 43 deletions

embedchain/__init__.py

@@ -3,5 +3,6 @@ import importlib.metadata
__version__ = importlib.metadata.version(__package__ or __name__)
from embedchain.apps.app import App # noqa: F401
from embedchain.client import Client # noqa: F401
from embedchain.pipeline import Pipeline # noqa: F401
from embedchain.vectordb.chroma import ChromaDB # noqa: F401

embedchain/client.py (new file, 102 lines)

@@ -0,0 +1,102 @@
import json
import logging
import os
import uuid
import requests
from embedchain.embedchain import CONFIG_DIR, CONFIG_FILE
class Client:
def __init__(self, api_key=None, host="https://apiv2.embedchain.ai"):
self.config_data = self.load_config()
self.host = host
if api_key:
if self.check(api_key):
self.api_key = api_key
self.save()
else:
raise ValueError(
"Invalid API key provided. You can find your API key on https://app.embedchain.ai/settings/keys."
)
else:
if "api_key" in self.config_data:
self.api_key = self.config_data["api_key"]
logging.info("API key loaded successfully!")
else:
raise ValueError(
"You are not logged in. Please obtain an API key from https://app.embedchain.ai/settings/keys/"
)
@classmethod
    def setup_dir(cls):
"""
Loads the user id from the config file if it exists, otherwise generates a new
one and saves it to the config file.
:return: user id
:rtype: str
"""
if not os.path.exists(CONFIG_DIR):
os.makedirs(CONFIG_DIR)
if os.path.exists(CONFIG_FILE):
with open(CONFIG_FILE, "r") as f:
data = json.load(f)
if "user_id" in data:
return data["user_id"]
        u_id = str(uuid.uuid4())
        with open(CONFIG_FILE, "w") as f:
            json.dump({"user_id": u_id}, f)
        return u_id
@classmethod
def load_config(cls):
if not os.path.exists(CONFIG_FILE):
cls.setup_dir()
with open(CONFIG_FILE, "r") as config_file:
return json.load(config_file)
def save(self):
self.config_data["api_key"] = self.api_key
with open(CONFIG_FILE, "w") as config_file:
json.dump(self.config_data, config_file, indent=4)
logging.info("API key saved successfully!")
def clear(self):
if "api_key" in self.config_data:
del self.config_data["api_key"]
with open(CONFIG_FILE, "w") as config_file:
json.dump(self.config_data, config_file, indent=4)
self.api_key = None
logging.info("API key deleted successfully!")
else:
logging.warning("API key not found in the configuration file.")
def update(self, api_key):
if self.check(api_key):
self.api_key = api_key
self.save()
logging.info("API key updated successfully!")
else:
logging.warning("Invalid API key provided. API key not updated.")
def check(self, api_key):
validation_url = f"{self.host}/api/v1/accounts/api_keys/validate/"
response = requests.post(validation_url, headers={"Authorization": f"Token {api_key}"})
if response.status_code == 200:
return True
else:
logging.warning(f"Response from API: {response.text}")
logging.warning("Invalid API key. Unable to validate.")
return False
def get(self):
return self.api_key
def __str__(self):
return self.api_key
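
A minimal usage sketch for the new Client (illustrative only; the keys below are placeholders, and validation makes a real request to the platform):

from embedchain import Client

client = Client(api_key="ec-placeholder")   # validates the key, then persists it to ~/.embedchain/config.json
print(client.get())                         # -> "ec-placeholder"
client.update("ec-another-placeholder")     # re-validates and overwrites the saved key
client.clear()                              # removes the key from the config file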

embedchain/config/__init__.py

@@ -2,7 +2,6 @@
from .add_config import AddConfig, ChunkerConfig
from .apps.app_config import AppConfig
from .pipeline_config import PipelineConfig
from .base_config import BaseConfig
from .embedder.base import BaseEmbedderConfig
from .embedder.base import BaseEmbedderConfig as EmbedderConfig

embedchain/embedchain.py

@@ -3,6 +3,7 @@ import importlib.metadata
import json
import logging
import os
import sqlite3
import threading
import uuid
from pathlib import Path
@@ -32,6 +33,7 @@ ABS_PATH = os.getcwd()
HOME_DIR = str(Path.home())
CONFIG_DIR = os.path.join(HOME_DIR, ".embedchain")
CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
SQLITE_PATH = os.path.join(CONFIG_DIR, "embedchain.db")
class EmbedChain(JSONSerializable):
@@ -89,6 +91,27 @@ class EmbedChain(JSONSerializable):
# Send anonymous telemetry
self.s_id = self.config.id if self.config.id else str(uuid.uuid4())
self.u_id = self._load_or_generate_user_id()
# Establish a connection to the SQLite database
self.connection = sqlite3.connect(SQLITE_PATH)
self.cursor = self.connection.cursor()
# Create the 'data_sources' table if it doesn't exist
self.cursor.execute(
"""
CREATE TABLE IF NOT EXISTS data_sources (
pipeline_id TEXT,
hash TEXT,
type TEXT,
value TEXT,
metadata TEXT,
is_uploaded INTEGER DEFAULT 0,
PRIMARY KEY (pipeline_id, hash)
)
"""
)
self.connection.commit()
# NOTE: Uncomment the next two lines when running tests to see if any test fires a telemetry event.
# if (self.config.collect_metrics):
# raise ConnectionRefusedError("Collection of metrics should not be allowed.")
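
Side note: the data_sources table created above lives in a plain SQLite file, so it can be inspected directly; a quick sketch (path follows SQLITE_PATH above):

import os
import sqlite3

db_path = os.path.join(os.path.expanduser("~"), ".embedchain", "embedchain.db")
conn = sqlite3.connect(db_path)
# columns: pipeline_id, hash, type, value, metadata, is_uploaded
for row in conn.execute("SELECT pipeline_id, hash, type, is_uploaded FROM data_sources"):
    print(row)
conn.close()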
@@ -163,7 +186,7 @@ class EmbedChain(JSONSerializable):
:raises ValueError: Invalid data type
        :param dry_run: Optional. A dry run displays the chunks to ensure that the loader and chunker work as intended,
        defaults to False
-        :return: source_id, a md5-hash of the source, in hexadecimal representation.
+        :return: source_hash, a md5-hash of the source, in hexadecimal representation.
:rtype: str
"""
if config is None:
@@ -192,18 +215,40 @@ class EmbedChain(JSONSerializable):
if not data_type:
data_type = detect_datatype(source)
-        # `source_id` is the hash of the source argument
+        # `source_hash` is the md5 hash of the source argument
hash_object = hashlib.md5(str(source).encode("utf-8"))
-        source_id = hash_object.hexdigest()
+        source_hash = hash_object.hexdigest()
# Check if the data hash already exists, if so, skip the addition
self.cursor.execute(
"SELECT 1 FROM data_sources WHERE hash = ? AND pipeline_id = ?", (source_hash, self.config.id)
)
existing_data = self.cursor.fetchone()
if existing_data:
print(f"Data with hash {source_hash} already exists. Skipping addition.")
return source_hash
data_formatter = DataFormatter(data_type, config)
self.user_asks.append([source, data_type.value, metadata])
documents, metadatas, _ids, new_chunks = self.load_and_embed(
-            data_formatter.loader, data_formatter.chunker, source, metadata, source_id, dry_run
+            data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run
)
if data_type in {DataType.DOCS_SITE}:
self.is_docs_site_instance = True
# Insert the data into the 'data' table
self.cursor.execute(
"""
INSERT INTO data_sources (hash, pipeline_id, type, value, metadata)
VALUES (?, ?, ?, ?, ?)
""",
(source_hash, self.config.id, data_type.value, str(source), json.dumps(metadata)),
)
# Commit the transaction
self.connection.commit()
if dry_run:
data_chunks_info = {"chunks": documents, "metadata": metadatas, "count": len(documents), "type": data_type}
logging.debug(f"Dry run info : {data_chunks_info}")
@@ -218,7 +263,7 @@ class EmbedChain(JSONSerializable):
thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("add", extra_metadata))
thread_telemetry.start()
-        return source_id
+        return source_hash
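
With the hash check above, adding the same source twice becomes a no-op on the second call; a sketch (assumes an App configured with a working embedder):

from embedchain import App

app = App()
first = app.add("https://example.com")   # inserts a row keyed on (pipeline_id, md5 hash)
second = app.add("https://example.com")  # prints "Data with hash ... already exists." and returns early
assert first == second                   # both calls return the same source_hash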
def add_local(
self,
@@ -245,7 +290,7 @@ class EmbedChain(JSONSerializable):
:param config: The `AddConfig` instance to use as configuration options., defaults to None
:type config: Optional[AddConfig], optional
:raises ValueError: Invalid data type
-        :return: source_id, a md5-hash of the source, in hexadecimal representation.
+        :return: source_hash, a md5-hash of the source, in hexadecimal representation.
:rtype: str
"""
logging.warning(
@@ -313,7 +358,7 @@ class EmbedChain(JSONSerializable):
chunker: BaseChunker,
src: Any,
metadata: Optional[Dict[str, Any]] = None,
-        source_id: Optional[str] = None,
+        source_hash: Optional[str] = None,
dry_run=False,
):
"""
@@ -324,7 +369,7 @@ class EmbedChain(JSONSerializable):
:param src: The data to be handled by the loader. Can be a URL for
remote sources or local content for local loaders.
:param metadata: Optional. Metadata associated with the data source.
-        :param source_id: Hexadecimal hash of the source.
+        :param source_hash: Hexadecimal hash of the source.
:param dry_run: Optional. A dry run returns chunks and doesn't update DB.
:type dry_run: bool, defaults to False
:return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
@@ -382,7 +427,7 @@ class EmbedChain(JSONSerializable):
m["app_id"] = self.config.id
# Add hashed source
m["hash"] = source_id
m["hash"] = source_hash
# Note: Metadata is the function argument
if metadata:
@@ -558,15 +603,14 @@ class EmbedChain(JSONSerializable):
"""
Resets the database. Deletes all embeddings irreversibly.
`App` does not have to be reinitialized after using this method.
DEPRECATED IN FAVOR OF `db.reset()`
"""
# Send anonymous telemetry
thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("reset",))
thread_telemetry.start()
logging.warning("DEPRECATION WARNING: Please use `app.db.reset()` instead of `App.reset()`.")
self.db.reset()
self.cursor.execute("DELETE FROM data_sources WHERE pipeline_id = ?", (self.config.id,))
self.connection.commit()
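
The migration for callers of the deprecated method is a one-line change (sketch):

app.db.reset()  # preferred over app.reset(), which now only warns before forwarding here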
@retry(stop=stop_after_attempt(3), wait=wait_fixed(1))
def _send_telemetry_event(self, method: str, extra_metadata: Optional[dict] = None):

embedchain/pipeline.py

@@ -1,17 +1,27 @@
import threading
import ast
import json
import logging
import os
import sqlite3
import uuid
import requests
import yaml
from fastapi import FastAPI, HTTPException
from embedchain import Client
from embedchain.config import PipelineConfig
from embedchain.embedchain import EmbedChain
from embedchain.embedchain import CONFIG_DIR, EmbedChain
from embedchain.embedder.base import BaseEmbedder
from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.factory import EmbedderFactory, VectorDBFactory
from embedchain.helper.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
from embedchain.vectordb.base import BaseVectorDB
from embedchain.vectordb.chroma import ChromaDB
SQLITE_PATH = os.path.join(CONFIG_DIR, "embedchain.db")
@register_deserializable
class Pipeline(EmbedChain):
@@ -21,7 +31,15 @@ class Pipeline(EmbedChain):
and vector database.
"""
-    def __init__(self, config: PipelineConfig = None, db: BaseVectorDB = None, embedding_model: BaseEmbedder = None):
+    def __init__(
+        self,
+        config: PipelineConfig = None,
+        db: BaseVectorDB = None,
+        embedding_model: BaseEmbedder = None,
+        llm: BaseLlm = None,
+        yaml_path: str = None,
+        log_level=logging.INFO,
+    ):
"""
        Initialize a new `Pipeline` instance.
@@ -32,42 +50,196 @@ class Pipeline(EmbedChain):
:param embedding_model: The embedding model used to calculate embeddings, defaults to None
:type embedding_model: BaseEmbedder, optional
"""
super().__init__()
logging.basicConfig(level=log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
self.logger = logging.getLogger(__name__)
# Store the yaml config as an attribute to be able to send it
self.yaml_config = None
self.client = None
if yaml_path:
with open(yaml_path, "r") as file:
config_data = yaml.safe_load(file)
self.yaml_config = config_data
self.config = config or PipelineConfig()
self.name = self.config.name
-        self.id = self.config.id or str(uuid.uuid4())
+        self.local_id = self.config.id or str(uuid.uuid4())
+        self.id = None  # remote pipeline id; assigned by deploy()
self.embedding_model = embedding_model or OpenAIEmbedder()
self.db = db or ChromaDB()
-        self._initialize_db()
+        self.llm = llm or None
+        self._init_db()
# setup user id and directory
self.u_id = self._load_or_generate_user_id()
# Establish a connection to the SQLite database
self.connection = sqlite3.connect(SQLITE_PATH)
self.cursor = self.connection.cursor()
# Create the 'data_sources' table if it doesn't exist
self.cursor.execute(
"""
CREATE TABLE IF NOT EXISTS data_sources (
pipeline_id TEXT,
hash TEXT,
type TEXT,
value TEXT,
                metadata TEXT,
is_uploaded INTEGER DEFAULT 0,
PRIMARY KEY (pipeline_id, hash)
)
"""
)
self.connection.commit()
self.user_asks = [] # legacy defaults
self.s_id = self.config.id or str(uuid.uuid4())
self.u_id = self._load_or_generate_user_id()
thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("pipeline_init",))
thread_telemetry.start()
-    def _initialize_db(self):
+    def _init_db(self):
"""
Initialize the database.
"""
self.db._set_embedder(self.embedding_model)
self.db._initialize()
-        self.db.set_collection_name(self.name)
+        self.db.set_collection_name(self.db.config.collection_name)
def _init_client(self):
"""
Initialize the client.
"""
config = Client.load_config()
if config.get("api_key"):
self.client = Client()
else:
api_key = input("Enter API key. You can find the API key at https://app.embedchain.ai/settings/keys/ \n")
self.client = Client(api_key=api_key)
def _create_pipeline(self):
"""
Create a pipeline on the platform.
"""
print("Creating pipeline on the platform...")
# self.yaml_config is a dict. Pass it inside the key 'yaml_config' to the backend
payload = {
"yaml_config": json.dumps(self.yaml_config),
"name": self.name,
"local_id": self.local_id,
}
url = f"{self.client.host}/api/v1/pipelines/cli/create/"
r = requests.post(
url,
json=payload,
headers={"Authorization": f"Token {self.client.api_key}"},
)
if r.status_code not in [200, 201]:
raise Exception(f"Error occurred while creating pipeline. Response from API: {r.text}")
print(f"Pipeline created. link: https://app.embedchain.ai/pipelines/{r.json()['id']}")
return r.json()
def _get_presigned_url(self, data_type, data_value):
payload = {"data_type": data_type, "data_value": data_value}
r = requests.post(
f"{self.client.host}/api/v1/pipelines/{self.id}/cli/presigned_url/",
json=payload,
headers={"Authorization": f"Token {self.client.api_key}"},
)
r.raise_for_status()
return r.json()
    def search(self, query, num_documents=3):
        """
        Search for similar documents related to the query in the vector database.
        """
-        where = {"app_id": self.id}
-        return self.db.query(
-            query,
-            n_results=num_documents,
-            where=where,
-            skip_embedding=False,
-        )
+        # TODO: Search will call the endpoint rather than fetching the data from the db itself once deployed.
+        if self.id is None:  # not deployed yet (`deploy` is a method, so it cannot double as a flag)
+            where = {"app_id": self.local_id}
+            return self.db.query(
+                query,
+                n_results=num_documents,
+                where=where,
+                skip_embedding=False,
+            )
+        else:
+            # Make API call to the backend to get the results
+            raise NotImplementedError("Search is not implemented yet for the prod mode.")
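
A usage sketch for local (pre-deploy) search; assumes a pipeline with a working embedder, e.g. OPENAI_API_KEY set for the default OpenAIEmbedder:

pipeline = Pipeline()
pipeline.add("https://example.com")
results = pipeline.search("What is on example.com?", num_documents=3)
print(results)  # raw vector-db matches scoped to where={"app_id": pipeline.local_id}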
def _upload_file_to_presigned_url(self, presigned_url, file_path):
try:
with open(file_path, "rb") as file:
response = requests.put(presigned_url, data=file)
response.raise_for_status()
return response.status_code == 200
except Exception as e:
self.logger.exception(f"Error occurred during file upload: {str(e)}")
return False
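
Together, the two helpers above implement a standard two-step upload: ask the backend for a presigned URL, then PUT the file bytes straight to object storage. A condensed sketch with hypothetical values:

data = pipeline._get_presigned_url("pdf_file", "/abs/path/doc.pdf")  # backend returns presigned_url + s3_key
ok = pipeline._upload_file_to_presigned_url(data["presigned_url"], "/abs/path/doc.pdf")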
def _upload_data_to_pipeline(self, data_type, data_value, metadata=None):
payload = {
"data_type": data_type,
"data_value": data_value,
"metadata": metadata,
}
return self._send_api_request(f"/api/v1/pipelines/{self.id}/cli/add/", payload)
def _send_api_request(self, endpoint, payload):
url = f"{self.client.host}{endpoint}"
headers = {"Authorization": f"Token {self.client.api_key}"}
response = requests.post(url, json=payload, headers=headers)
response.raise_for_status()
return response
def _process_and_upload_data(self, data_hash, data_type, data_value):
if os.path.isabs(data_value):
presigned_url_data = self._get_presigned_url(data_type, data_value)
presigned_url = presigned_url_data["presigned_url"]
s3_key = presigned_url_data["s3_key"]
            if self._upload_file_to_presigned_url(presigned_url, file_path=data_value):
                # capture the local path before swapping data_value for the presigned URL
                metadata = {"file_path": data_value, "s3_key": s3_key}
                data_value = presigned_url
else:
self.logger.error(f"File upload failed for hash: {data_hash}")
return False
else:
if data_type == "qna_pair":
data_value = list(ast.literal_eval(data_value))
metadata = {}
try:
self._upload_data_to_pipeline(data_type, data_value, metadata)
self._mark_data_as_uploaded(data_hash)
self.logger.info(f"Data of type {data_type} uploaded successfully.")
return True
except Exception as e:
self.logger.error(f"Error occurred during data upload: {str(e)}")
return False
def _mark_data_as_uploaded(self, data_hash):
self.cursor.execute(
"UPDATE data_sources SET is_uploaded = 1 WHERE hash = ? AND pipeline_id = ? AND is_uploaded = 0",
(data_hash, self.local_id),
)
self.connection.commit()
def deploy(self):
try:
if self.client is None:
self._init_client()
pipeline_data = self._create_pipeline()
self.id = pipeline_data["id"]
results = self.cursor.execute(
"SELECT * FROM data_sources WHERE pipeline_id = ? AND is_uploaded = 0", (self.local_id,)
).fetchall()
for result in results:
                # row columns: pipeline_id, hash, type, value, metadata, is_uploaded
                data_hash, data_type, data_value = result[1], result[2], result[3]
if self._process_and_upload_data(data_hash, data_type, data_value):
self.logger.info(f"Data with hash {data_hash} uploaded successfully.")
except Exception as e:
self.logger.exception(f"Error occurred during deployment: {str(e)}")
raise HTTPException(status_code=500, detail="Error occurred during deployment.")
@classmethod
def from_config(cls, yaml_path: str):
@@ -82,7 +254,7 @@ class Pipeline(EmbedChain):
with open(yaml_path, "r") as file:
config_data = yaml.safe_load(file)
-        pipeline_config_data = config_data.get("pipeline", {})
+        pipeline_config_data = config_data.get("pipeline", {}).get("config", {})
db_config_data = config_data.get("vectordb", {})
embedding_model_config_data = config_data.get("embedding_model", {})
@@ -95,4 +267,39 @@ class Pipeline(EmbedChain):
embedding_model = EmbedderFactory.create(
embedding_model_provider, embedding_model_config_data.get("config", {})
)
-        return cls(config=pipeline_config, db=db, embedding_model=embedding_model)
+        return cls(
+            config=pipeline_config,
+            db=db,
+            embedding_model=embedding_model,
+            yaml_path=yaml_path,
+        )
def start(self, host="0.0.0.0", port=8000):
app = FastAPI()
@app.post("/add")
async def add_document(data_value: str, data_type: str = None):
"""
Add a document to the pipeline.
"""
try:
document = {"data_value": data_value, "data_type": data_type}
self.add(document)
return {"message": "Document added successfully"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/query")
async def query_documents(query: str, num_documents: int = 3):
"""
Query for similar documents in the pipeline.
"""
try:
results = self.search(query, num_documents)
return results
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
import uvicorn
uvicorn.run(app, host=host, port=port)
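
Since the /add and /query handlers take plain str/int parameters, FastAPI exposes them as query parameters; a sketch of exercising the local server (URL-encoding elided):

pipeline.start()  # serves on 0.0.0.0:8000

# from another shell:
#   curl -X POST "http://localhost:8000/add?data_value=https://example.com"
#   curl -X POST "http://localhost:8000/query?query=what%20is%20this&num_documents=3"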

tests/conftest.py (new file, 16 lines)

@@ -0,0 +1,16 @@
import os
import pytest
def clean_db():
db_path = os.path.expanduser("~/.embedchain/embedchain.db")
if os.path.exists(db_path):
os.remove(db_path)
@pytest.fixture
def setup():
clean_db()
yield
clean_db()
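
A sketch of how a test opts into this fixture (hypothetical test; assumes an App configured with a working embedder):

from embedchain import App

def test_starts_with_empty_db(setup):  # requesting `setup` wipes ~/.embedchain/embedchain.db before and after
    app = App()
    count = app.cursor.execute("SELECT COUNT(*) FROM data_sources").fetchone()[0]
    assert count == 0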

tests/embedchain/test_add.py

@@ -16,19 +16,20 @@ def app(mocker):
def test_add(app):
-    app.add("https://example.com", metadata={"meta": "meta-data"})
-    assert app.user_asks == [["https://example.com", "web_page", {"meta": "meta-data"}]]
+    app.add("https://example.com", metadata={"foo": "bar"})
+    assert app.user_asks == [["https://example.com", "web_page", {"foo": "bar"}]]

-def test_add_sitemap(app):
-    app.add("https://www.google.com/sitemap.xml", metadata={"meta": "meta-data"})
-    assert app.user_asks == [["https://www.google.com/sitemap.xml", "sitemap", {"meta": "meta-data"}]]
+# TODO: Make this test faster by generating a sitemap locally rather than using a remote one
+# def test_add_sitemap(app):
+#     app.add("https://www.google.com/sitemap.xml", metadata={"foo": "bar"})
+#     assert app.user_asks == [["https://www.google.com/sitemap.xml", "sitemap", {"foo": "bar"}]]

def test_add_forced_type(app):
    data_type = "text"
-    app.add("https://example.com", data_type=data_type, metadata={"meta": "meta-data"})
-    assert app.user_asks == [["https://example.com", data_type, {"meta": "meta-data"}]]
+    app.add("https://example.com", data_type=data_type, metadata={"foo": "bar"})
+    assert app.user_asks == [["https://example.com", data_type, {"foo": "bar"}]]
def test_dry_run(app):

tests/test_client.py (new file, 53 lines)

@@ -0,0 +1,53 @@
import pytest
from embedchain import Client
class TestClient:
@pytest.fixture
def mock_requests_post(self, mocker):
return mocker.patch("embedchain.client.requests.post")
def test_valid_api_key(self, mock_requests_post):
mock_requests_post.return_value.status_code = 200
client = Client(api_key="valid_api_key")
assert client.check("valid_api_key") is True
def test_invalid_api_key(self, mock_requests_post):
mock_requests_post.return_value.status_code = 401
with pytest.raises(ValueError):
Client(api_key="invalid_api_key")
def test_update_valid_api_key(self, mock_requests_post):
mock_requests_post.return_value.status_code = 200
client = Client(api_key="valid_api_key")
client.update("new_valid_api_key")
assert client.get() == "new_valid_api_key"
def test_clear_api_key(self, mock_requests_post):
mock_requests_post.return_value.status_code = 200
client = Client(api_key="valid_api_key")
client.clear()
assert client.get() is None
def test_save_api_key(self, mock_requests_post):
mock_requests_post.return_value.status_code = 200
api_key_to_save = "valid_api_key"
client = Client(api_key=api_key_to_save)
client.save()
assert client.get() == api_key_to_save
def test_load_api_key_from_config(self, mocker):
mocker.patch("embedchain.Client.load_config", return_value={"api_key": "test_api_key"})
client = Client()
assert client.get() == "test_api_key"
def test_load_invalid_api_key_from_config(self, mocker):
mocker.patch("embedchain.Client.load_config", return_value={})
with pytest.raises(ValueError):
Client()
def test_load_missing_api_key_from_config(self, mocker):
mocker.patch("embedchain.Client.load_config", return_value={})
with pytest.raises(ValueError):
Client()

tests/vectordb/test_chroma_db.py

@@ -1,9 +1,10 @@
import os
import shutil
-import pytest
+from unittest.mock import patch
+import pytest
from chromadb.config import Settings
from embedchain import App
from embedchain.config import AppConfig, ChromaDbConfig
from embedchain.vectordb.chroma import ChromaDB