[Feature] Add support for deploying local pipelines to Embedchain platform (#847)
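In short: this commit introduces a `Client` class that authenticates against the Embedchain platform and persists the API key under `~/.embedchain/config.json`; a local SQLite ledger (`~/.embedchain/embedchain.db`, table `data_sources`) in which every added data source is recorded and deduplicated by md5 hash; and a `Pipeline.deploy()` flow that creates the pipeline on the platform and uploads any rows not yet marked `is_uploaded`. `Pipeline.start()` additionally serves the pipeline over a small FastAPI app.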
embedchain/__init__.py
@@ -3,5 +3,6 @@ import importlib.metadata
 __version__ = importlib.metadata.version(__package__ or __name__)
 
 from embedchain.apps.app import App  # noqa: F401
+from embedchain.client import Client  # noqa: F401
 from embedchain.pipeline import Pipeline  # noqa: F401
 from embedchain.vectordb.chroma import ChromaDB  # noqa: F401
embedchain/client.py (new file, 102 lines)
@@ -0,0 +1,102 @@
+import json
+import logging
+import os
+import uuid
+
+import requests
+
+from embedchain.embedchain import CONFIG_DIR, CONFIG_FILE
+
+
+class Client:
+    def __init__(self, api_key=None, host="https://apiv2.embedchain.ai"):
+        self.config_data = self.load_config()
+        self.host = host
+
+        if api_key:
+            if self.check(api_key):
+                self.api_key = api_key
+                self.save()
+            else:
+                raise ValueError(
+                    "Invalid API key provided. You can find your API key on https://app.embedchain.ai/settings/keys."
+                )
+        else:
+            if "api_key" in self.config_data:
+                self.api_key = self.config_data["api_key"]
+                logging.info("API key loaded successfully!")
+            else:
+                raise ValueError(
+                    "You are not logged in. Please obtain an API key from https://app.embedchain.ai/settings/keys/"
+                )
+
+    @classmethod
+    def setup_dir(self):
+        """
+        Loads the user id from the config file if it exists, otherwise generates a new
+        one and saves it to the config file.
+
+        :return: user id
+        :rtype: str
+        """
+        if not os.path.exists(CONFIG_DIR):
+            os.makedirs(CONFIG_DIR)
+
+        if os.path.exists(CONFIG_FILE):
+            with open(CONFIG_FILE, "r") as f:
+                data = json.load(f)
+                if "user_id" in data:
+                    return data["user_id"]
+
+        u_id = str(uuid.uuid4())
+        with open(CONFIG_FILE, "w") as f:
+            json.dump({"user_id": u_id}, f)
+
+    @classmethod
+    def load_config(cls):
+        if not os.path.exists(CONFIG_FILE):
+            cls.setup_dir()
+
+        with open(CONFIG_FILE, "r") as config_file:
+            return json.load(config_file)
+
+    def save(self):
+        self.config_data["api_key"] = self.api_key
+        with open(CONFIG_FILE, "w") as config_file:
+            json.dump(self.config_data, config_file, indent=4)
+
+        logging.info("API key saved successfully!")
+
+    def clear(self):
+        if "api_key" in self.config_data:
+            del self.config_data["api_key"]
+            with open(CONFIG_FILE, "w") as config_file:
+                json.dump(self.config_data, config_file, indent=4)
+            self.api_key = None
+            logging.info("API key deleted successfully!")
+        else:
+            logging.warning("API key not found in the configuration file.")
+
+    def update(self, api_key):
+        if self.check(api_key):
+            self.api_key = api_key
+            self.save()
+            logging.info("API key updated successfully!")
+        else:
+            logging.warning("Invalid API key provided. API key not updated.")
+
+    def check(self, api_key):
+        validation_url = f"{self.host}/api/v1/accounts/api_keys/validate/"
+        response = requests.post(validation_url, headers={"Authorization": f"Token {api_key}"})
+        if response.status_code == 200:
+            return True
+        else:
+            logging.warning(f"Response from API: {response.text}")
+            logging.warning("Invalid API key. Unable to validate.")
+            return False
+
+    def get(self):
+        return self.api_key
+
+    def __str__(self):
+        return self.api_key
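For orientation, a minimal usage sketch of the new `Client` (the key value below is a hypothetical placeholder; real keys come from https://app.embedchain.ai/settings/keys/):

    from embedchain import Client

    # Validates the key against /api/v1/accounts/api_keys/validate/ and,
    # on success, persists it to ~/.embedchain/config.json.
    client = Client(api_key="ec-xxxxxxxx")
    print(client.get())  # the stored key
    client.clear()       # removes the key from the config file again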
embedchain/config/__init__.py
@@ -2,7 +2,6 @@
 
 from .add_config import AddConfig, ChunkerConfig
 from .apps.app_config import AppConfig
-from .pipeline_config import PipelineConfig
 from .base_config import BaseConfig
 from .embedder.base import BaseEmbedderConfig
 from .embedder.base import BaseEmbedderConfig as EmbedderConfig
embedchain/embedchain.py
@@ -3,6 +3,7 @@ import importlib.metadata
 import json
 import logging
 import os
+import sqlite3
 import threading
 import uuid
 from pathlib import Path
@@ -32,6 +33,7 @@ ABS_PATH = os.getcwd()
 HOME_DIR = str(Path.home())
 CONFIG_DIR = os.path.join(HOME_DIR, ".embedchain")
 CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
+SQLITE_PATH = os.path.join(CONFIG_DIR, "embedchain.db")
 
 
 class EmbedChain(JSONSerializable):
@@ -89,6 +91,27 @@ class EmbedChain(JSONSerializable):
         # Send anonymous telemetry
         self.s_id = self.config.id if self.config.id else str(uuid.uuid4())
         self.u_id = self._load_or_generate_user_id()
+
+        # Establish a connection to the SQLite database
+        self.connection = sqlite3.connect(SQLITE_PATH)
+        self.cursor = self.connection.cursor()
+
+        # Create the 'data_sources' table if it doesn't exist
+        self.cursor.execute(
+            """
+            CREATE TABLE IF NOT EXISTS data_sources (
+                pipeline_id TEXT,
+                hash TEXT,
+                type TEXT,
+                value TEXT,
+                metadata TEXT,
+                is_uploaded INTEGER DEFAULT 0,
+                PRIMARY KEY (pipeline_id, hash)
+            )
+            """
+        )
+        self.connection.commit()
+
         # NOTE: Uncomment the next two lines when running tests to see if any test fires a telemetry event.
         # if (self.config.collect_metrics):
         #     raise ConnectionRefusedError("Collection of metrics should not be allowed.")
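A quick way to inspect the ledger this constructor now maintains, sketched with the standard library only (paths as defined above):

    import os
    import sqlite3

    db_path = os.path.join(os.path.expanduser("~"), ".embedchain", "embedchain.db")
    conn = sqlite3.connect(db_path)
    # One row per added source; is_uploaded flips to 1 once Pipeline.deploy() pushes it.
    print(conn.execute("SELECT pipeline_id, hash, type, is_uploaded FROM data_sources").fetchall())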
@@ -163,7 +186,7 @@ class EmbedChain(JSONSerializable):
         :raises ValueError: Invalid data type
         :param dry_run: Optional. A dry run displays the chunks to ensure that the loader and chunker work as intended.
         deafaults to False
-        :return: source_id, a md5-hash of the source, in hexadecimal representation.
+        :return: source_hash, a md5-hash of the source, in hexadecimal representation.
         :rtype: str
         """
         if config is None:
@@ -192,18 +215,40 @@ class EmbedChain(JSONSerializable):
         if not data_type:
             data_type = detect_datatype(source)
 
-        # `source_id` is the hash of the source argument
+        # `source_hash` is the md5 hash of the source argument
         hash_object = hashlib.md5(str(source).encode("utf-8"))
-        source_id = hash_object.hexdigest()
+        source_hash = hash_object.hexdigest()
 
+        # Check if the data hash already exists, if so, skip the addition
+        self.cursor.execute(
+            "SELECT 1 FROM data_sources WHERE hash = ? AND pipeline_id = ?", (source_hash, self.config.id)
+        )
+        existing_data = self.cursor.fetchone()
+
+        if existing_data:
+            print(f"Data with hash {source_hash} already exists. Skipping addition.")
+            return source_hash
+
         data_formatter = DataFormatter(data_type, config)
         self.user_asks.append([source, data_type.value, metadata])
         documents, metadatas, _ids, new_chunks = self.load_and_embed(
-            data_formatter.loader, data_formatter.chunker, source, metadata, source_id, dry_run
+            data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run
         )
         if data_type in {DataType.DOCS_SITE}:
             self.is_docs_site_instance = True
 
+        # Insert the data into the 'data' table
+        self.cursor.execute(
+            """
+            INSERT INTO data_sources (hash, pipeline_id, type, value, metadata)
+            VALUES (?, ?, ?, ?, ?)
+            """,
+            (source_hash, self.config.id, data_type.value, str(source), json.dumps(metadata)),
+        )
+
+        # Commit the transaction
+        self.connection.commit()
+
         if dry_run:
             data_chunks_info = {"chunks": documents, "metadata": metadatas, "count": len(documents), "type": data_type}
             logging.debug(f"Dry run info : {data_chunks_info}")
@@ -218,7 +263,7 @@ class EmbedChain(JSONSerializable):
         thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("add", extra_metadata))
         thread_telemetry.start()
 
-        return source_id
+        return source_hash
 
     def add_local(
         self,
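The practical effect of the dedup check above, as a sketch (assumes a configured app and embedder; the URL is illustrative):

    from embedchain import App

    app = App()
    first = app.add("https://example.com")
    second = app.add("https://example.com")  # hits the SELECT, prints a skip notice, returns early
    assert first == second                   # both are the md5 hex digest of the source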
@@ -245,7 +290,7 @@ class EmbedChain(JSONSerializable):
         :param config: The `AddConfig` instance to use as configuration options., defaults to None
         :type config: Optional[AddConfig], optional
         :raises ValueError: Invalid data type
-        :return: source_id, a md5-hash of the source, in hexadecimal representation.
+        :return: source_hash, a md5-hash of the source, in hexadecimal representation.
         :rtype: str
         """
         logging.warning(
@@ -313,7 +358,7 @@ class EmbedChain(JSONSerializable):
         chunker: BaseChunker,
         src: Any,
         metadata: Optional[Dict[str, Any]] = None,
-        source_id: Optional[str] = None,
+        source_hash: Optional[str] = None,
         dry_run=False,
     ):
         """
@@ -324,7 +369,7 @@ class EmbedChain(JSONSerializable):
         :param src: The data to be handled by the loader. Can be a URL for
         remote sources or local content for local loaders.
         :param metadata: Optional. Metadata associated with the data source.
-        :param source_id: Hexadecimal hash of the source.
+        :param source_hash: Hexadecimal hash of the source.
         :param dry_run: Optional. A dry run returns chunks and doesn't update DB.
         :type dry_run: bool, defaults to False
         :return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
@@ -382,7 +427,7 @@ class EmbedChain(JSONSerializable):
             m["app_id"] = self.config.id
 
             # Add hashed source
-            m["hash"] = source_id
+            m["hash"] = source_hash
 
             # Note: Metadata is the function argument
             if metadata:
@@ -558,15 +603,14 @@ class EmbedChain(JSONSerializable):
         """
         Resets the database. Deletes all embeddings irreversibly.
         `App` does not have to be reinitialized after using this method.
-
-        DEPRECATED IN FAVOR OF `db.reset()`
         """
         # Send anonymous telemetry
         thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("reset",))
         thread_telemetry.start()
 
-        logging.warning("DEPRECATION WARNING: Please use `app.db.reset()` instead of `App.reset()`.")
         self.db.reset()
+        self.cursor.execute("DELETE FROM data_sources WHERE pipeline_id = ?", (self.config.id,))
+        self.connection.commit()
 
     @retry(stop=stop_after_attempt(3), wait=wait_fixed(1))
     def _send_telemetry_event(self, method: str, extra_metadata: Optional[dict] = None):
embedchain/pipeline.py
@@ -1,17 +1,27 @@
-import threading
+import ast
+import json
+import logging
+import os
+import sqlite3
 import uuid
 
+import requests
 import yaml
+from fastapi import FastAPI, HTTPException
 
+from embedchain import Client
 from embedchain.config import PipelineConfig
-from embedchain.embedchain import EmbedChain
+from embedchain.embedchain import CONFIG_DIR, EmbedChain
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.embedder.openai import OpenAIEmbedder
 from embedchain.factory import EmbedderFactory, VectorDBFactory
 from embedchain.helper.json_serializable import register_deserializable
+from embedchain.llm.base import BaseLlm
 from embedchain.vectordb.base import BaseVectorDB
 from embedchain.vectordb.chroma import ChromaDB
 
+SQLITE_PATH = os.path.join(CONFIG_DIR, "embedchain.db")
+
 
 @register_deserializable
 class Pipeline(EmbedChain):
@@ -21,7 +31,15 @@ class Pipeline(EmbedChain):
     and vector database.
     """
 
-    def __init__(self, config: PipelineConfig = None, db: BaseVectorDB = None, embedding_model: BaseEmbedder = None):
+    def __init__(
+        self,
+        config: PipelineConfig = None,
+        db: BaseVectorDB = None,
+        embedding_model: BaseEmbedder = None,
+        llm: BaseLlm = None,
+        yaml_path: str = None,
+        log_level=logging.INFO,
+    ):
         """
         Initialize a new `App` instance.
 
@@ -32,42 +50,196 @@ class Pipeline(EmbedChain):
         :param embedding_model: The embedding model used to calculate embeddings, defaults to None
         :type embedding_model: BaseEmbedder, optional
         """
-        super().__init__()
+        logging.basicConfig(level=log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+        self.logger = logging.getLogger(__name__)
+        # Store the yaml config as an attribute to be able to send it
+        self.yaml_config = None
+        self.client = None
+        if yaml_path:
+            with open(yaml_path, "r") as file:
+                config_data = yaml.safe_load(file)
+                self.yaml_config = config_data
+
         self.config = config or PipelineConfig()
         self.name = self.config.name
-        self.id = self.config.id or str(uuid.uuid4())
+        self.local_id = self.config.id or str(uuid.uuid4())
 
         self.embedding_model = embedding_model or OpenAIEmbedder()
         self.db = db or ChromaDB()
-        self._initialize_db()
+        self.llm = llm or None
+        self._init_db()
+
+        # setup user id and directory
+        self.u_id = self._load_or_generate_user_id()
+
+        # Establish a connection to the SQLite database
+        self.connection = sqlite3.connect(SQLITE_PATH)
+        self.cursor = self.connection.cursor()
+
+        # Create the 'data_sources' table if it doesn't exist
+        self.cursor.execute(
+            """
+            CREATE TABLE IF NOT EXISTS data_sources (
+                pipeline_id TEXT,
+                hash TEXT,
+                type TEXT,
+                value TEXT,
+                metadata TEXT,
+                is_uploaded INTEGER DEFAULT 0,
+                PRIMARY KEY (pipeline_id, hash)
+            )
+            """
+        )
+        self.connection.commit()
+
         self.user_asks = []  # legacy defaults
 
-        self.s_id = self.config.id or str(uuid.uuid4())
-        self.u_id = self._load_or_generate_user_id()
-
-        thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("pipeline_init",))
-        thread_telemetry.start()
-
-    def _initialize_db(self):
+    def _init_db(self):
         """
         Initialize the database.
         """
         self.db._set_embedder(self.embedding_model)
         self.db._initialize()
-        self.db.set_collection_name(self.name)
+        self.db.set_collection_name(self.db.config.collection_name)
+
+    def _init_client(self):
+        """
+        Initialize the client.
+        """
+        config = Client.load_config()
+        if config.get("api_key"):
+            self.client = Client()
+        else:
+            api_key = input("Enter API key. You can find the API key at https://app.embedchain.ai/settings/keys/ \n")
+            self.client = Client(api_key=api_key)
+
+    def _create_pipeline(self):
+        """
+        Create a pipeline on the platform.
+        """
+        print("Creating pipeline on the platform...")
+        # self.yaml_config is a dict. Pass it inside the key 'yaml_config' to the backend
+        payload = {
+            "yaml_config": json.dumps(self.yaml_config),
+            "name": self.name,
+            "local_id": self.local_id,
+        }
+        url = f"{self.client.host}/api/v1/pipelines/cli/create/"
+        r = requests.post(
+            url,
+            json=payload,
+            headers={"Authorization": f"Token {self.client.api_key}"},
+        )
+        if r.status_code not in [200, 201]:
+            raise Exception(f"Error occurred while creating pipeline. Response from API: {r.text}")
+
+        print(f"Pipeline created. link: https://app.embedchain.ai/pipelines/{r.json()['id']}")
+        return r.json()
+
+    def _get_presigned_url(self, data_type, data_value):
+        payload = {"data_type": data_type, "data_value": data_value}
+        r = requests.post(
+            f"{self.client.host}/api/v1/pipelines/{self.id}/cli/presigned_url/",
+            json=payload,
+            headers={"Authorization": f"Token {self.client.api_key}"},
+        )
+        r.raise_for_status()
+        return r.json()
 
     def search(self, query, num_documents=3):
         """
         Search for similar documents related to the query in the vector database.
         """
-        where = {"app_id": self.id}
-        return self.db.query(
-            query,
-            n_results=num_documents,
-            where=where,
-            skip_embedding=False,
-        )
+        # TODO: Search will call the endpoint rather than fetching the data from the db itself when deploy=True.
+        if self.deploy is False:
+            where = {"app_id": self.local_id}
+            return self.db.query(
+                query,
+                n_results=num_documents,
+                where=where,
+                skip_embedding=False,
+            )
+        else:
+            # Make API call to the backend to get the results
+            NotImplementedError("Search is not implemented yet for the prod mode.")
+
+    def _upload_file_to_presigned_url(self, presigned_url, file_path):
+        try:
+            with open(file_path, "rb") as file:
+                response = requests.put(presigned_url, data=file)
+                response.raise_for_status()
+                return response.status_code == 200
+        except Exception as e:
+            self.logger.exception(f"Error occurred during file upload: {str(e)}")
+            return False
+
+    def _upload_data_to_pipeline(self, data_type, data_value, metadata=None):
+        payload = {
+            "data_type": data_type,
+            "data_value": data_value,
+            "metadata": metadata,
+        }
+        return self._send_api_request(f"/api/v1/pipelines/{self.id}/cli/add/", payload)
+
+    def _send_api_request(self, endpoint, payload):
+        url = f"{self.client.host}{endpoint}"
+        headers = {"Authorization": f"Token {self.client.api_key}"}
+        response = requests.post(url, json=payload, headers=headers)
+        response.raise_for_status()
+        return response
+
+    def _process_and_upload_data(self, data_hash, data_type, data_value):
+        if os.path.isabs(data_value):
+            presigned_url_data = self._get_presigned_url(data_type, data_value)
+            presigned_url = presigned_url_data["presigned_url"]
+            s3_key = presigned_url_data["s3_key"]
+            if self._upload_file_to_presigned_url(presigned_url, file_path=data_value):
+                data_value = presigned_url
+                metadata = {"file_path": data_value, "s3_key": s3_key}
+            else:
+                self.logger.error(f"File upload failed for hash: {data_hash}")
+                return False
+        else:
+            if data_type == "qna_pair":
+                data_value = list(ast.literal_eval(data_value))
+            metadata = {}
+
+        try:
+            self._upload_data_to_pipeline(data_type, data_value, metadata)
+            self._mark_data_as_uploaded(data_hash)
+            self.logger.info(f"Data of type {data_type} uploaded successfully.")
+            return True
+        except Exception as e:
+            self.logger.error(f"Error occurred during data upload: {str(e)}")
+            return False
+
+    def _mark_data_as_uploaded(self, data_hash):
+        self.cursor.execute(
+            "UPDATE data_sources SET is_uploaded = 1 WHERE hash = ? AND pipeline_id = ? AND is_uploaded = 0",
+            (data_hash, self.local_id),
+        )
+        self.connection.commit()
+
+    def deploy(self):
+        try:
+            if self.client is None:
+                self._init_client()
+
+            pipeline_data = self._create_pipeline()
+            self.id = pipeline_data["id"]
+
+            results = self.cursor.execute(
+                "SELECT * FROM data_sources WHERE pipeline_id = ? AND is_uploaded = 0", (self.local_id,)
+            ).fetchall()
+
+            for result in results:
+                data_hash, data_type, data_value = result[0], result[2], result[3]
+                if self._process_and_upload_data(data_hash, data_type, data_value):
+                    self.logger.info(f"Data with hash {data_hash} uploaded successfully.")
+
+        except Exception as e:
+            self.logger.exception(f"Error occurred during deployment: {str(e)}")
+            raise HTTPException(status_code=500, detail="Error occurred during deployment.")
 
     @classmethod
     def from_config(cls, yaml_path: str):
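Putting the pieces together, a deployment sketch (`config.yaml` is an assumed local file in the format `from_config` expects):

    from embedchain import Pipeline

    pipeline = Pipeline.from_config("config.yaml")
    pipeline.add("https://example.com")  # recorded in the local data_sources table
    pipeline.deploy()                    # prompts for an API key if none is saved, creates the
                                         # pipeline on the platform, then uploads pending rows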
@@ -82,7 +254,7 @@ class Pipeline(EmbedChain):
         with open(yaml_path, "r") as file:
             config_data = yaml.safe_load(file)
 
-        pipeline_config_data = config_data.get("pipeline", {})
+        pipeline_config_data = config_data.get("pipeline", {}).get("config", {})
         db_config_data = config_data.get("vectordb", {})
         embedding_model_config_data = config_data.get("embedding_model", {})
 
@@ -95,4 +267,39 @@ class Pipeline(EmbedChain):
         embedding_model = EmbedderFactory.create(
             embedding_model_provider, embedding_model_config_data.get("config", {})
         )
-        return cls(config=pipeline_config, db=db, embedding_model=embedding_model)
+        return cls(
+            config=pipeline_config,
+            db=db,
+            embedding_model=embedding_model,
+            yaml_path=yaml_path,
+        )
+
+    def start(self, host="0.0.0.0", port=8000):
+        app = FastAPI()
+
+        @app.post("/add")
+        async def add_document(data_value: str, data_type: str = None):
+            """
+            Add a document to the pipeline.
+            """
+            try:
+                document = {"data_value": data_value, "data_type": data_type}
+                self.add(document)
+                return {"message": "Document added successfully"}
+            except Exception as e:
+                raise HTTPException(status_code=500, detail=str(e))
+
+        @app.post("/query")
+        async def query_documents(query: str, num_documents: int = 3):
+            """
+            Query for similar documents in the pipeline.
+            """
+            try:
+                results = self.search(query, num_documents)
+                return results
+            except Exception as e:
+                raise HTTPException(status_code=500, detail=str(e))
+
+        import uvicorn
+
+        uvicorn.run(app, host=host, port=port)
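And a sketch of the new HTTP surface (assumes `pipeline.start()` is already running in another process on the default port; both endpoints take query parameters):

    import requests

    requests.post("http://localhost:8000/add", params={"data_value": "https://example.com"})
    response = requests.post("http://localhost:8000/query", params={"query": "What is on that page?"})
    print(response.json())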
tests/conftest.py (new file, 16 lines)
@@ -0,0 +1,16 @@
+import os
+
+import pytest
+
+
+def clean_db():
+    db_path = os.path.expanduser("~/.embedchain/embedchain.db")
+    if os.path.exists(db_path):
+        os.remove(db_path)
+
+
+@pytest.fixture
+def setup():
+    clean_db()
+    yield
+    clean_db()
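Tests opt in to the clean database by requesting the fixture by name, e.g.:

    def test_something(setup):
        ...  # runs against a fresh ~/.embedchain/embedchain.db, removed again afterwards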
@@ -16,19 +16,20 @@ def app(mocker):
 
 
 def test_add(app):
-    app.add("https://example.com", metadata={"meta": "meta-data"})
-    assert app.user_asks == [["https://example.com", "web_page", {"meta": "meta-data"}]]
+    app.add("https://example.com", metadata={"foo": "bar"})
+    assert app.user_asks == [["https://example.com", "web_page", {"foo": "bar"}]]
 
 
-def test_add_sitemap(app):
-    app.add("https://www.google.com/sitemap.xml", metadata={"meta": "meta-data"})
-    assert app.user_asks == [["https://www.google.com/sitemap.xml", "sitemap", {"meta": "meta-data"}]]
+# TODO: Make this test faster by generating a sitemap locally rather than using a remote one
+# def test_add_sitemap(app):
+#     app.add("https://www.google.com/sitemap.xml", metadata={"foo": "bar"})
+#     assert app.user_asks == [["https://www.google.com/sitemap.xml", "sitemap", {"foo": "bar"}]]
 
 
 def test_add_forced_type(app):
     data_type = "text"
-    app.add("https://example.com", data_type=data_type, metadata={"meta": "meta-data"})
-    assert app.user_asks == [["https://example.com", data_type, {"meta": "meta-data"}]]
+    app.add("https://example.com", data_type=data_type, metadata={"foo": "bar"})
+    assert app.user_asks == [["https://example.com", data_type, {"foo": "bar"}]]
 
 
 def test_dry_run(app):
tests/test_client.py (new file, 53 lines)
@@ -0,0 +1,53 @@
+import pytest
+
+from embedchain import Client
+
+
+class TestClient:
+    @pytest.fixture
+    def mock_requests_post(self, mocker):
+        return mocker.patch("embedchain.client.requests.post")
+
+    def test_valid_api_key(self, mock_requests_post):
+        mock_requests_post.return_value.status_code = 200
+        client = Client(api_key="valid_api_key")
+        assert client.check("valid_api_key") is True
+
+    def test_invalid_api_key(self, mock_requests_post):
+        mock_requests_post.return_value.status_code = 401
+        with pytest.raises(ValueError):
+            Client(api_key="invalid_api_key")
+
+    def test_update_valid_api_key(self, mock_requests_post):
+        mock_requests_post.return_value.status_code = 200
+        client = Client(api_key="valid_api_key")
+        client.update("new_valid_api_key")
+        assert client.get() == "new_valid_api_key"
+
+    def test_clear_api_key(self, mock_requests_post):
+        mock_requests_post.return_value.status_code = 200
+        client = Client(api_key="valid_api_key")
+        client.clear()
+        assert client.get() is None
+
+    def test_save_api_key(self, mock_requests_post):
+        mock_requests_post.return_value.status_code = 200
+        api_key_to_save = "valid_api_key"
+        client = Client(api_key=api_key_to_save)
+        client.save()
+        assert client.get() == api_key_to_save
+
+    def test_load_api_key_from_config(self, mocker):
+        mocker.patch("embedchain.Client.load_config", return_value={"api_key": "test_api_key"})
+        client = Client()
+        assert client.get() == "test_api_key"
+
+    def test_load_invalid_api_key_from_config(self, mocker):
+        mocker.patch("embedchain.Client.load_config", return_value={})
+        with pytest.raises(ValueError):
+            Client()
+
+    def test_load_missing_api_key_from_config(self, mocker):
+        mocker.patch("embedchain.Client.load_config", return_value={})
+        with pytest.raises(ValueError):
+            Client()
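The `mocker` fixture used above comes from pytest-mock; every `requests.post` is patched, so the suite exercises `Client` without network access.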
@@ -1,9 +1,10 @@
 import os
 import shutil
-import pytest
 from unittest.mock import patch
+
+import pytest
 from chromadb.config import Settings
 
 from embedchain import App
 from embedchain.config import AppConfig, ChromaDbConfig
 from embedchain.vectordb.chroma import ChromaDB