[Feature] Add support for deploying local pipelines to Embedchain platform (#847)
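
In short, this change lets a pipeline built locally be pushed to the Embedchain platform. A minimal usage sketch under assumptions (the `pipeline.yaml` file name and its contents are hypothetical; `from_config`, `add`, and `deploy` are the APIs introduced or touched by this diff):

    from embedchain import Pipeline

    # Build a pipeline from a local YAML config, add data, then push it to the platform.
    pipeline = Pipeline.from_config(yaml_path="pipeline.yaml")
    pipeline.add("https://example.com")  # tracked locally in ~/.embedchain/embedchain.db
    pipeline.deploy()  # prompts for an API key on first use, then uploads pending data sources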

embedchain/__init__.py
@@ -3,5 +3,6 @@ import importlib.metadata
 __version__ = importlib.metadata.version(__package__ or __name__)

 from embedchain.apps.app import App  # noqa: F401
+from embedchain.client import Client  # noqa: F401
 from embedchain.pipeline import Pipeline  # noqa: F401
 from embedchain.vectordb.chroma import ChromaDB  # noqa: F401

embedchain/client.py (new file, 102 lines)
@@ -0,0 +1,102 @@
import json
import logging
import os
import uuid

import requests

from embedchain.embedchain import CONFIG_DIR, CONFIG_FILE


class Client:
    def __init__(self, api_key=None, host="https://apiv2.embedchain.ai"):
        self.config_data = self.load_config()
        self.host = host

        if api_key:
            if self.check(api_key):
                self.api_key = api_key
                self.save()
            else:
                raise ValueError(
                    "Invalid API key provided. You can find your API key on https://app.embedchain.ai/settings/keys."
                )
        else:
            if "api_key" in self.config_data:
                self.api_key = self.config_data["api_key"]
                logging.info("API key loaded successfully!")
            else:
                raise ValueError(
                    "You are not logged in. Please obtain an API key from https://app.embedchain.ai/settings/keys/"
                )

    @classmethod
    def setup_dir(cls):
        """
        Loads the user id from the config file if it exists, otherwise generates a new
        one and saves it to the config file.

        :return: user id
        :rtype: str
        """
        if not os.path.exists(CONFIG_DIR):
            os.makedirs(CONFIG_DIR)

        if os.path.exists(CONFIG_FILE):
            with open(CONFIG_FILE, "r") as f:
                data = json.load(f)
                if "user_id" in data:
                    return data["user_id"]

        u_id = str(uuid.uuid4())
        with open(CONFIG_FILE, "w") as f:
            json.dump({"user_id": u_id}, f)
        return u_id

    @classmethod
    def load_config(cls):
        if not os.path.exists(CONFIG_FILE):
            cls.setup_dir()

        with open(CONFIG_FILE, "r") as config_file:
            return json.load(config_file)

    def save(self):
        self.config_data["api_key"] = self.api_key
        with open(CONFIG_FILE, "w") as config_file:
            json.dump(self.config_data, config_file, indent=4)

        logging.info("API key saved successfully!")

    def clear(self):
        if "api_key" in self.config_data:
            del self.config_data["api_key"]
            with open(CONFIG_FILE, "w") as config_file:
                json.dump(self.config_data, config_file, indent=4)
            self.api_key = None
            logging.info("API key deleted successfully!")
        else:
            logging.warning("API key not found in the configuration file.")

    def update(self, api_key):
        if self.check(api_key):
            self.api_key = api_key
            self.save()
            logging.info("API key updated successfully!")
        else:
            logging.warning("Invalid API key provided. API key not updated.")

    def check(self, api_key):
        validation_url = f"{self.host}/api/v1/accounts/api_keys/validate/"
        response = requests.post(validation_url, headers={"Authorization": f"Token {api_key}"})
        if response.status_code == 200:
            return True
        else:
            logging.warning(f"Response from API: {response.text}")
            logging.warning("Invalid API key. Unable to validate.")
            return False

    def get(self):
        return self.api_key

    def __str__(self):
        return self.api_key
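
A minimal sketch of how the new Client is meant to be used (the key string is a placeholder; Client validates it against the platform and persists it to ~/.embedchain/config.json):

    from embedchain import Client

    client = Client(api_key="your-api-key")  # raises ValueError if the key fails validation
    print(client.get())                      # -> "your-api-key"
    client.update("another-valid-key")       # re-validates, then overwrites the stored key
    client.clear()                           # removes the key from the config file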

embedchain/config/__init__.py
@@ -2,7 +2,6 @@
 from .add_config import AddConfig, ChunkerConfig
 from .apps.app_config import AppConfig
 from .pipeline_config import PipelineConfig
 from .base_config import BaseConfig
 from .embedder.base import BaseEmbedderConfig
 from .embedder.base import BaseEmbedderConfig as EmbedderConfig

embedchain/embedchain.py
@@ -3,6 +3,7 @@ import importlib.metadata
 import json
 import logging
 import os
+import sqlite3
 import threading
 import uuid
 from pathlib import Path

@@ -32,6 +33,7 @@ ABS_PATH = os.getcwd()
 HOME_DIR = str(Path.home())
 CONFIG_DIR = os.path.join(HOME_DIR, ".embedchain")
 CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
+SQLITE_PATH = os.path.join(CONFIG_DIR, "embedchain.db")


 class EmbedChain(JSONSerializable):

@@ -89,6 +91,27 @@ class EmbedChain(JSONSerializable):
         # Send anonymous telemetry
         self.s_id = self.config.id if self.config.id else str(uuid.uuid4())
         self.u_id = self._load_or_generate_user_id()

+        # Establish a connection to the SQLite database
+        self.connection = sqlite3.connect(SQLITE_PATH)
+        self.cursor = self.connection.cursor()
+
+        # Create the 'data_sources' table if it doesn't exist
+        self.cursor.execute(
+            """
+            CREATE TABLE IF NOT EXISTS data_sources (
+                pipeline_id TEXT,
+                hash TEXT,
+                type TEXT,
+                value TEXT,
+                metadata TEXT,
+                is_uploaded INTEGER DEFAULT 0,
+                PRIMARY KEY (pipeline_id, hash)
+            )
+        """
+        )
+        self.connection.commit()
+
         # NOTE: Uncomment the next two lines when running tests to see if any test fires a telemetry event.
         # if (self.config.collect_metrics):
         #     raise ConnectionRefusedError("Collection of metrics should not be allowed.")

@@ -163,7 +186,7 @@ class EmbedChain(JSONSerializable):
         :raises ValueError: Invalid data type
         :param dry_run: Optional. A dry run displays the chunks to ensure that the loader and chunker work as intended,
             defaults to False
-        :return: source_id, a md5-hash of the source, in hexadecimal representation.
+        :return: source_hash, a md5-hash of the source, in hexadecimal representation.
         :rtype: str
         """
         if config is None:

@@ -192,18 +215,40 @@ class EmbedChain(JSONSerializable):
         if not data_type:
             data_type = detect_datatype(source)

-        # `source_id` is the hash of the source argument
+        # `source_hash` is the md5 hash of the source argument
         hash_object = hashlib.md5(str(source).encode("utf-8"))
-        source_id = hash_object.hexdigest()
+        source_hash = hash_object.hexdigest()
+
+        # Check if the data hash already exists, if so, skip the addition
+        self.cursor.execute(
+            "SELECT 1 FROM data_sources WHERE hash = ? AND pipeline_id = ?", (source_hash, self.config.id)
+        )
+        existing_data = self.cursor.fetchone()
+
+        if existing_data:
+            print(f"Data with hash {source_hash} already exists. Skipping addition.")
+            return source_hash

         data_formatter = DataFormatter(data_type, config)
         self.user_asks.append([source, data_type.value, metadata])
         documents, metadatas, _ids, new_chunks = self.load_and_embed(
-            data_formatter.loader, data_formatter.chunker, source, metadata, source_id, dry_run
+            data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run
         )
         if data_type in {DataType.DOCS_SITE}:
             self.is_docs_site_instance = True

+        # Insert the data into the 'data_sources' table
+        self.cursor.execute(
+            """
+            INSERT INTO data_sources (hash, pipeline_id, type, value, metadata)
+            VALUES (?, ?, ?, ?, ?)
+        """,
+            (source_hash, self.config.id, data_type.value, str(source), json.dumps(metadata)),
+        )
+
+        # Commit the transaction
+        self.connection.commit()
+
         if dry_run:
             data_chunks_info = {"chunks": documents, "metadata": metadatas, "count": len(documents), "type": data_type}
             logging.debug(f"Dry run info : {data_chunks_info}")
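
The dedup check added above makes `add()` idempotent: every source is keyed by (pipeline_id, md5-of-source), so re-adding the same source is skipped. A standalone sketch of the same check (the db path and pipeline id are illustrative):

    import hashlib
    import sqlite3

    source = "https://example.com"
    source_hash = hashlib.md5(str(source).encode("utf-8")).hexdigest()

    conn = sqlite3.connect("embedchain.db")  # the real path is ~/.embedchain/embedchain.db
    row = conn.execute(
        "SELECT 1 FROM data_sources WHERE hash = ? AND pipeline_id = ?", (source_hash, "my-pipeline")
    ).fetchone()
    if row:
        print(f"Data with hash {source_hash} already exists. Skipping addition.")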

@@ -218,7 +263,7 @@ class EmbedChain(JSONSerializable):
         thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("add", extra_metadata))
         thread_telemetry.start()

-        return source_id
+        return source_hash

     def add_local(
         self,

@@ -245,7 +290,7 @@ class EmbedChain(JSONSerializable):
         :param config: The `AddConfig` instance to use as configuration options, defaults to None
         :type config: Optional[AddConfig], optional
         :raises ValueError: Invalid data type
-        :return: source_id, a md5-hash of the source, in hexadecimal representation.
+        :return: source_hash, a md5-hash of the source, in hexadecimal representation.
         :rtype: str
         """
         logging.warning(

@@ -313,7 +358,7 @@ class EmbedChain(JSONSerializable):
         chunker: BaseChunker,
         src: Any,
         metadata: Optional[Dict[str, Any]] = None,
-        source_id: Optional[str] = None,
+        source_hash: Optional[str] = None,
         dry_run=False,
     ):
         """

@@ -324,7 +369,7 @@ class EmbedChain(JSONSerializable):
         :param src: The data to be handled by the loader. Can be a URL for
             remote sources or local content for local loaders.
         :param metadata: Optional. Metadata associated with the data source.
-        :param source_id: Hexadecimal hash of the source.
+        :param source_hash: Hexadecimal hash of the source.
         :param dry_run: Optional. A dry run returns chunks and doesn't update DB.
         :type dry_run: bool, defaults to False
         :return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks

@@ -382,7 +427,7 @@ class EmbedChain(JSONSerializable):
             m["app_id"] = self.config.id

             # Add hashed source
-            m["hash"] = source_id
+            m["hash"] = source_hash

             # Note: Metadata is the function argument
             if metadata:

@@ -558,15 +603,14 @@ class EmbedChain(JSONSerializable):
         """
         Resets the database. Deletes all embeddings irreversibly.
         `App` does not have to be reinitialized after using this method.

         DEPRECATED IN FAVOR OF `db.reset()`
         """
         # Send anonymous telemetry
         thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("reset",))
         thread_telemetry.start()

         logging.warning("DEPRECATION WARNING: Please use `app.db.reset()` instead of `App.reset()`.")
         self.db.reset()
+        self.cursor.execute("DELETE FROM data_sources WHERE pipeline_id = ?", (self.config.id,))
+        self.connection.commit()

     @retry(stop=stop_after_attempt(3), wait=wait_fixed(1))
     def _send_telemetry_event(self, method: str, extra_metadata: Optional[dict] = None):

embedchain/pipeline.py
@@ -1,17 +1,27 @@
 import threading
+import ast
+import json
+import logging
+import os
+import sqlite3
 import uuid

+import requests
 import yaml
+from fastapi import FastAPI, HTTPException

+from embedchain import Client
 from embedchain.config import PipelineConfig
-from embedchain.embedchain import EmbedChain
+from embedchain.embedchain import CONFIG_DIR, EmbedChain
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.embedder.openai import OpenAIEmbedder
 from embedchain.factory import EmbedderFactory, VectorDBFactory
 from embedchain.helper.json_serializable import register_deserializable
+from embedchain.llm.base import BaseLlm
 from embedchain.vectordb.base import BaseVectorDB
 from embedchain.vectordb.chroma import ChromaDB

+SQLITE_PATH = os.path.join(CONFIG_DIR, "embedchain.db")
+

 @register_deserializable
 class Pipeline(EmbedChain):

@@ -21,7 +31,15 @@ class Pipeline(EmbedChain):
     and vector database.
     """

-    def __init__(self, config: PipelineConfig = None, db: BaseVectorDB = None, embedding_model: BaseEmbedder = None):
+    def __init__(
+        self,
+        config: PipelineConfig = None,
+        db: BaseVectorDB = None,
+        embedding_model: BaseEmbedder = None,
+        llm: BaseLlm = None,
+        yaml_path: str = None,
+        log_level=logging.INFO,
+    ):
         """
         Initialize a new `Pipeline` instance.

@@ -32,42 +50,196 @@ class Pipeline(EmbedChain):
         :param embedding_model: The embedding model used to calculate embeddings, defaults to None
         :type embedding_model: BaseEmbedder, optional
         """
         super().__init__()
+        logging.basicConfig(level=log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+        self.logger = logging.getLogger(__name__)
+        # Store the yaml config as an attribute to be able to send it
+        self.yaml_config = None
+        self.client = None
+        if yaml_path:
+            with open(yaml_path, "r") as file:
+                config_data = yaml.safe_load(file)
+                self.yaml_config = config_data

         self.config = config or PipelineConfig()
         self.name = self.config.name
-        self.id = self.config.id or str(uuid.uuid4())
+        self.local_id = self.config.id or str(uuid.uuid4())

         self.embedding_model = embedding_model or OpenAIEmbedder()
         self.db = db or ChromaDB()
-        self._initialize_db()
+        self.llm = llm or None
+        self._init_db()
+
+        # setup user id and directory
+        self.u_id = self._load_or_generate_user_id()
+
+        # Establish a connection to the SQLite database
+        self.connection = sqlite3.connect(SQLITE_PATH)
+        self.cursor = self.connection.cursor()
+
+        # Create the 'data_sources' table if it doesn't exist
+        self.cursor.execute(
+            """
+            CREATE TABLE IF NOT EXISTS data_sources (
+                pipeline_id TEXT,
+                hash TEXT,
+                type TEXT,
+                value TEXT,
+                metadata TEXT,
+                is_uploaded INTEGER DEFAULT 0,
+                PRIMARY KEY (pipeline_id, hash)
+            )
+        """
+        )
+        self.connection.commit()

         self.user_asks = []  # legacy defaults

         self.s_id = self.config.id or str(uuid.uuid4())
         self.u_id = self._load_or_generate_user_id()

         thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("pipeline_init",))
         thread_telemetry.start()

-    def _initialize_db(self):
+    def _init_db(self):
         """
         Initialize the database.
         """
         self.db._set_embedder(self.embedding_model)
         self.db._initialize()
-        self.db.set_collection_name(self.name)
+        self.db.set_collection_name(self.db.config.collection_name)

+    def _init_client(self):
+        """
+        Initialize the client.
+        """
+        config = Client.load_config()
+        if config.get("api_key"):
+            self.client = Client()
+        else:
+            api_key = input("Enter API key. You can find the API key at https://app.embedchain.ai/settings/keys/ \n")
+            self.client = Client(api_key=api_key)
+
+    def _create_pipeline(self):
+        """
+        Create a pipeline on the platform.
+        """
+        print("Creating pipeline on the platform...")
+        # self.yaml_config is a dict. Pass it inside the key 'yaml_config' to the backend
+        payload = {
+            "yaml_config": json.dumps(self.yaml_config),
+            "name": self.name,
+            "local_id": self.local_id,
+        }
+        url = f"{self.client.host}/api/v1/pipelines/cli/create/"
+        r = requests.post(
+            url,
+            json=payload,
+            headers={"Authorization": f"Token {self.client.api_key}"},
+        )
+        if r.status_code not in [200, 201]:
+            raise Exception(f"Error occurred while creating pipeline. Response from API: {r.text}")
+
+        print(f"Pipeline created. link: https://app.embedchain.ai/pipelines/{r.json()['id']}")
+        return r.json()
+
+    def _get_presigned_url(self, data_type, data_value):
+        payload = {"data_type": data_type, "data_value": data_value}
+        r = requests.post(
+            f"{self.client.host}/api/v1/pipelines/{self.id}/cli/presigned_url/",
+            json=payload,
+            headers={"Authorization": f"Token {self.client.api_key}"},
+        )
+        r.raise_for_status()
+        return r.json()

     def search(self, query, num_documents=3):
         """
         Search for similar documents related to the query in the vector database.
         """
-        where = {"app_id": self.id}
-        return self.db.query(
-            query,
-            n_results=num_documents,
-            where=where,
-            skip_embedding=False,
-        )
+        # TODO: Search will call the endpoint rather than fetching the data from the db itself when deploy=True.
+        if self.deploy is False:
+            where = {"app_id": self.local_id}
+            return self.db.query(
+                query,
+                n_results=num_documents,
+                where=where,
+                skip_embedding=False,
+            )
+        else:
+            # Make API call to the backend to get the results
+            raise NotImplementedError("Search is not implemented yet for the prod mode.")
+
+    def _upload_file_to_presigned_url(self, presigned_url, file_path):
+        try:
+            with open(file_path, "rb") as file:
+                response = requests.put(presigned_url, data=file)
+                response.raise_for_status()
+                return response.status_code == 200
+        except Exception as e:
+            self.logger.exception(f"Error occurred during file upload: {str(e)}")
+            return False
+
+    def _upload_data_to_pipeline(self, data_type, data_value, metadata=None):
+        payload = {
+            "data_type": data_type,
+            "data_value": data_value,
+            "metadata": metadata,
+        }
+        return self._send_api_request(f"/api/v1/pipelines/{self.id}/cli/add/", payload)
+
+    def _send_api_request(self, endpoint, payload):
+        url = f"{self.client.host}{endpoint}"
+        headers = {"Authorization": f"Token {self.client.api_key}"}
+        response = requests.post(url, json=payload, headers=headers)
+        response.raise_for_status()
+        return response
+
+    def _process_and_upload_data(self, data_hash, data_type, data_value):
+        if os.path.isabs(data_value):
+            presigned_url_data = self._get_presigned_url(data_type, data_value)
+            presigned_url = presigned_url_data["presigned_url"]
+            s3_key = presigned_url_data["s3_key"]
+            if self._upload_file_to_presigned_url(presigned_url, file_path=data_value):
+                data_value = presigned_url
+                metadata = {"file_path": data_value, "s3_key": s3_key}
+            else:
+                self.logger.error(f"File upload failed for hash: {data_hash}")
+                return False
+        else:
+            if data_type == "qna_pair":
+                data_value = list(ast.literal_eval(data_value))
+            metadata = {}
+
+        try:
+            self._upload_data_to_pipeline(data_type, data_value, metadata)
+            self._mark_data_as_uploaded(data_hash)
+            self.logger.info(f"Data of type {data_type} uploaded successfully.")
+            return True
+        except Exception as e:
+            self.logger.error(f"Error occurred during data upload: {str(e)}")
+            return False
+
+    def _mark_data_as_uploaded(self, data_hash):
+        self.cursor.execute(
+            "UPDATE data_sources SET is_uploaded = 1 WHERE hash = ? AND pipeline_id = ? AND is_uploaded = 0",
+            (data_hash, self.local_id),
+        )
+        self.connection.commit()
+
+    def deploy(self):
+        try:
+            if self.client is None:
+                self._init_client()
+
+            pipeline_data = self._create_pipeline()
+            self.id = pipeline_data["id"]
+
+            results = self.cursor.execute(
+                "SELECT * FROM data_sources WHERE pipeline_id = ? AND is_uploaded = 0", (self.local_id,)
+            ).fetchall()
+
+            for result in results:
+                data_hash, data_type, data_value = result[0], result[2], result[3]
+                if self._process_and_upload_data(data_hash, data_type, data_value):
+                    self.logger.info(f"Data with hash {data_hash} uploaded successfully.")
+
+        except Exception as e:
+            self.logger.exception(f"Error occurred during deployment: {str(e)}")
+            raise HTTPException(status_code=500, detail="Error occurred during deployment.")

     @classmethod
     def from_config(cls, yaml_path: str):

@@ -82,7 +254,7 @@ class Pipeline(EmbedChain):
         with open(yaml_path, "r") as file:
             config_data = yaml.safe_load(file)

-        pipeline_config_data = config_data.get("pipeline", {})
+        pipeline_config_data = config_data.get("pipeline", {}).get("config", {})
         db_config_data = config_data.get("vectordb", {})
         embedding_model_config_data = config_data.get("embedding_model", {})
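
Note the extra `.get("config", {})`: pipeline settings now live one level deeper, under `pipeline.config`. A sketch of the parsed config shape this implies (the `name` and `provider` keys are assumptions; only the nesting comes from this diff):

    # Hypothetical result of yaml.safe_load(file) for a pipeline YAML.
    config_data = {
        "pipeline": {"config": {"name": "my-pipeline"}},
        "vectordb": {"provider": "chroma", "config": {}},
        "embedding_model": {"provider": "openai", "config": {}},
    }
    pipeline_config_data = config_data.get("pipeline", {}).get("config", {})  # -> {"name": "my-pipeline"}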

@@ -95,4 +267,39 @@ class Pipeline(EmbedChain):
         embedding_model = EmbedderFactory.create(
             embedding_model_provider, embedding_model_config_data.get("config", {})
         )
-        return cls(config=pipeline_config, db=db, embedding_model=embedding_model)
+        return cls(
+            config=pipeline_config,
+            db=db,
+            embedding_model=embedding_model,
+            yaml_path=yaml_path,
+        )
+
+    def start(self, host="0.0.0.0", port=8000):
+        app = FastAPI()
+
+        @app.post("/add")
+        async def add_document(data_value: str, data_type: str = None):
+            """
+            Add a document to the pipeline.
+            """
+            try:
+                self.add(data_value, data_type=data_type)
+                return {"message": "Document added successfully"}
+            except Exception as e:
+                raise HTTPException(status_code=500, detail=str(e))
+
+        @app.post("/query")
+        async def query_documents(query: str, num_documents: int = 3):
+            """
+            Query for similar documents in the pipeline.
+            """
+            try:
+                results = self.search(query, num_documents)
+                return results
+            except Exception as e:
+                raise HTTPException(status_code=500, detail=str(e))
+
+        import uvicorn
+
+        uvicorn.run(app, host=host, port=port)
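
With `start()` running, the pipeline is reachable over HTTP. A sketch of exercising both endpoints (assumes the default host and port; the handlers declare plain scalars, which FastAPI maps to query parameters even on POST):

    import requests

    r = requests.post("http://localhost:8000/add", params={"data_value": "https://example.com"})
    print(r.json())  # {"message": "Document added successfully"}

    r = requests.post("http://localhost:8000/query", params={"query": "What does the page say?", "num_documents": 3})
    print(r.json())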

tests/conftest.py (new file, 16 lines)
@@ -0,0 +1,16 @@
import os

import pytest


def clean_db():
    db_path = os.path.expanduser("~/.embedchain/embedchain.db")
    if os.path.exists(db_path):
        os.remove(db_path)


@pytest.fixture
def setup():
    clean_db()
    yield
    clean_db()

@@ -16,19 +16,20 @@ def app(mocker):


 def test_add(app):
-    app.add("https://example.com", metadata={"meta": "meta-data"})
-    assert app.user_asks == [["https://example.com", "web_page", {"meta": "meta-data"}]]
+    app.add("https://example.com", metadata={"foo": "bar"})
+    assert app.user_asks == [["https://example.com", "web_page", {"foo": "bar"}]]


-def test_add_sitemap(app):
-    app.add("https://www.google.com/sitemap.xml", metadata={"meta": "meta-data"})
-    assert app.user_asks == [["https://www.google.com/sitemap.xml", "sitemap", {"meta": "meta-data"}]]
+# TODO: Make this test faster by generating a sitemap locally rather than using a remote one
+# def test_add_sitemap(app):
+#     app.add("https://www.google.com/sitemap.xml", metadata={"foo": "bar"})
+#     assert app.user_asks == [["https://www.google.com/sitemap.xml", "sitemap", {"foo": "bar"}]]


 def test_add_forced_type(app):
     data_type = "text"
-    app.add("https://example.com", data_type=data_type, metadata={"meta": "meta-data"})
-    assert app.user_asks == [["https://example.com", data_type, {"meta": "meta-data"}]]
+    app.add("https://example.com", data_type=data_type, metadata={"foo": "bar"})
+    assert app.user_asks == [["https://example.com", data_type, {"foo": "bar"}]]


 def test_dry_run(app):

tests/test_client.py (new file, 53 lines)
@@ -0,0 +1,53 @@
import pytest

from embedchain import Client


class TestClient:
    @pytest.fixture
    def mock_requests_post(self, mocker):
        return mocker.patch("embedchain.client.requests.post")

    def test_valid_api_key(self, mock_requests_post):
        mock_requests_post.return_value.status_code = 200
        client = Client(api_key="valid_api_key")
        assert client.check("valid_api_key") is True

    def test_invalid_api_key(self, mock_requests_post):
        mock_requests_post.return_value.status_code = 401
        with pytest.raises(ValueError):
            Client(api_key="invalid_api_key")

    def test_update_valid_api_key(self, mock_requests_post):
        mock_requests_post.return_value.status_code = 200
        client = Client(api_key="valid_api_key")
        client.update("new_valid_api_key")
        assert client.get() == "new_valid_api_key"

    def test_clear_api_key(self, mock_requests_post):
        mock_requests_post.return_value.status_code = 200
        client = Client(api_key="valid_api_key")
        client.clear()
        assert client.get() is None

    def test_save_api_key(self, mock_requests_post):
        mock_requests_post.return_value.status_code = 200
        api_key_to_save = "valid_api_key"
        client = Client(api_key=api_key_to_save)
        client.save()
        assert client.get() == api_key_to_save

    def test_load_api_key_from_config(self, mocker):
        mocker.patch("embedchain.Client.load_config", return_value={"api_key": "test_api_key"})
        client = Client()
        assert client.get() == "test_api_key"

    def test_load_invalid_api_key_from_config(self, mocker):
        mocker.patch("embedchain.Client.load_config", return_value={})
        with pytest.raises(ValueError):
            Client()

    def test_load_missing_api_key_from_config(self, mocker):
        mocker.patch("embedchain.Client.load_config", return_value={})
        with pytest.raises(ValueError):
            Client()

@@ -1,9 +1,10 @@
 import os
 import shutil
-import pytest
 from unittest.mock import patch

+import pytest
 from chromadb.config import Settings

 from embedchain import App
 from embedchain.config import AppConfig, ChromaDbConfig
 from embedchain.vectordb.chroma import ChromaDB