[Refactor] Converge Pipeline and App classes (#1021)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Authored by Deven Patel on 2023-12-29 16:52:41 +05:30, committed by GitHub
parent c0aafd38c9
commit a926bcc640
91 changed files with 646 additions and 875 deletions

View File

@@ -63,7 +63,7 @@ For example, you can create an Elon Musk bot using the following code:
```python
import os
from embedchain import Pipeline as App
from embedchain import App
# Create a bot instance
os.environ["OPENAI_API_KEY"] = "YOUR API KEY"

View File

@@ -1,7 +1,6 @@
app:
config:
id: 'my-app'
collection_name: 'my-app'
llm:
provider: openai

View File

@@ -1,7 +1,6 @@
app:
config:
id: 'open-source-app'
collection_name: 'open-source-app'
collect_metrics: false
llm:

View File

@@ -21,7 +21,7 @@ title: '📊 add'
### Load data from webpage
```python Code example
from embedchain import Pipeline as App
from embedchain import App
app = App()
app.add("https://www.forbes.com/profile/elon-musk")
@@ -32,7 +32,7 @@ app.add("https://www.forbes.com/profile/elon-musk")
### Load data from sitemap
```python Code example
from embedchain import Pipeline as App
from embedchain import App
app = App()
app.add("https://python.langchain.com/sitemap.xml", data_type="sitemap")

View File

@@ -36,7 +36,7 @@ title: '💬 chat'
If you want to get the answer to a question and return both the answer and citations, use the following code snippet:
```python With Citations
from embedchain import Pipeline as App
from embedchain import App
# Initialize app
app = App()
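# Hedged sketch of the flow described above (the URL and question are
# illustrative). With citations=True, chat() returns the answer together
# with the sources used to produce it.
app.add("https://www.forbes.com/profile/elon-musk")
answer, sources = app.chat("What is the net worth of Elon Musk?", citations=True)
print(answer)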
@@ -79,7 +79,7 @@ When `citations=True`, note that the returned `sources` are a list of tuples whe
If you just want to return answers and don't want to return citations, you can use the following example:
```python Without Citations
from embedchain import Pipeline as App
from embedchain import App
# Initialize app
app = App()

View File

@@ -7,7 +7,7 @@ title: 🗑 delete
## Usage
```python
from embedchain import Pipeline as App
from embedchain import App
app = App()

View File

@@ -9,7 +9,7 @@ The `deploy()` method not only deploys your pipeline but also efficiently manage
## Usage
```python
from embedchain import Pipeline as App
from embedchain import App
# Initialize app
app = App()
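# Illustrative usage: add a data source locally, then push the pipeline to
# the Embedchain platform. deploy() prompts for an API key if none is saved.
app.add("https://www.forbes.com/profile/elon-musk")
app.deploy()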

View File

@@ -41,7 +41,7 @@ You can create an embedchain pipeline instance using the following methods:
### Default setting
```python Code Example
from embedchain import Pipeline as App
from embedchain import App
app = App()
```
@@ -49,7 +49,7 @@ app = App()
### Python Dict
```python Code Example
from embedchain import Pipeline as App
from embedchain import App
config_dict = {
'llm': {
@@ -76,7 +76,7 @@ app = App.from_config(config=config_dict)
<CodeGroup>
```python main.py
from embedchain import Pipeline as App
from embedchain import App
# load llm configuration from config.yaml file
app = App.from_config(config_path="config.yaml")
@@ -103,7 +103,7 @@ embedder:
<CodeGroup>
```python main.py
from embedchain import Pipeline as App
from embedchain import App
# load llm configuration from config.json file
app = App.from_config(config_path="config.json")

View File

@@ -36,7 +36,7 @@ title: '❓ query'
If you want to get the answer to a question and return both the answer and citations, use the following code snippet:
```python With Citations
from embedchain import Pipeline as App
from embedchain import App
# Initialize app
app = App()
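# Hedged sketch of the call described above (the URL and question are
# illustrative). With citations=True, query() returns the answer together
# with the sources used to produce it.
app.add("https://www.forbes.com/profile/elon-musk")
answer, sources = app.query("What is the net worth of Elon Musk?", citations=True)
print(answer)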
@@ -78,7 +78,7 @@ When `citations=True`, note that the returned `sources` are a list of tuples whe
If you just want to return answers and don't want to return citations, you can use the following example:
```python Without Citations
from embedchain import Pipeline as App
from embedchain import App
# Initialize app
app = App()

View File

@@ -7,7 +7,7 @@ title: 🔄 reset
## Usage
```python
from embedchain import Pipeline as App
from embedchain import App
app = App()
app.add("https://www.forbes.com/profile/elon-musk")

View File

@@ -24,7 +24,7 @@ title: '🔍 search'
Refer to the following example on how to use the search API:
```python Code example
from embedchain import Pipeline as App
from embedchain import App
# Initialize app
app = App()
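# Illustrative usage, based on the search() implementation added in this
# commit: it returns a list of dicts with "context", "source" and
# "document_id" keys.
app.add("https://www.forbes.com/profile/elon-musk")
results = app.search("net worth of Elon Musk", num_documents=3)
for result in results:
    print(result["source"], "->", result["context"])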

View File

@@ -5,7 +5,7 @@ title: "🐝 Beehiiv"
To add any Beehiiv data source to your app, just add the base url as the source and set the data_type to `beehiiv`.
```python
from embedchain import Pipeline as App
from embedchain import App
app = App()

View File

@@ -5,7 +5,7 @@ title: '📊 CSV'
To add any csv file, use the data_type as `csv`. `csv` allows remote urls and conventional file paths. Headers are included for each line, so if you have an `age` column, `18` will be added as `age: 18`. Eg:
```python
from embedchain import Pipeline as App
from embedchain import App
app = App()
app.add('https://people.sc.fsu.edu/~jburkardt/data/csv/airtravel.csv', data_type="csv")

View File

@@ -5,7 +5,7 @@ title: '⚙️ Custom'
When we say "custom", we mean that you can customize the loader and chunker to your needs. This is done by passing a custom loader and chunker to the `add` method.
```python
from embedchain import Pipeline as App
from embedchain import App
import your_loader
import your_chunker
@@ -27,7 +27,7 @@ app.add("source", data_type="custom", loader=loader, chunker=chunker)
Example:
```python
from embedchain import Pipeline as App
from embedchain import App
from embedchain.loaders.github import GithubLoader
app = App()

View File

@@ -35,7 +35,7 @@ Default behavior is to create a persistent vector db in the directory **./db**.
Create a local index:
```python
from embedchain import Pipeline as App
from embedchain import App
naval_chat_bot = App()
naval_chat_bot.add("https://www.youtube.com/watch?v=3qHkcs3kG44")
@@ -45,7 +45,7 @@ naval_chat_bot.add("https://navalmanack.s3.amazonaws.com/Eric-Jorgenson_The-Alma
You can reuse the local index with the same code, but without adding new documents:
```python
from embedchain import Pipeline as App
from embedchain import App
naval_chat_bot = App()
print(naval_chat_bot.query("What unique capacity does Naval argue humans possess when it comes to understanding explanations or concepts?"))
@@ -56,7 +56,7 @@ print(naval_chat_bot.query("What unique capacity does Naval argue humans possess
You can reset the app by simply calling the `reset` method. This will delete the vector database and all other app related files.
```python
from embedchain import Pipeline as App
from embedchain import App
app = App()
app.add("https://www.youtube.com/watch?v=3qHkcs3kG44")

View File

@@ -8,7 +8,7 @@ To use an entire directory as data source, just add `data_type` as `directory` a
```python
import os
from embedchain import Pipeline as App
from embedchain import App
os.environ["OPENAI_API_KEY"] = "sk-xxx"
@@ -23,7 +23,7 @@ print(response)
```python
import os
from embedchain import Pipeline as App
from embedchain import App
from embedchain.loaders.directory_loader import DirectoryLoader
os.environ["OPENAI_API_KEY"] = "sk-xxx"

View File

@@ -12,7 +12,7 @@ To add any Discord channel messages to your app, just add the `channel_id` as th
```python
import os
from embedchain import Pipeline as App
from embedchain import App
# add your discord "BOT" token
os.environ["DISCORD_TOKEN"] = "xxx"

View File

@@ -5,7 +5,7 @@ title: '📚 Code documentation'
To add any code documentation website as a loader, use the data_type as `docs_site`. Eg:
```python
from embedchain import Pipeline as App
from embedchain import App
app = App()
app.add("https://docs.embedchain.ai/", data_type="docs_site")

View File

@@ -7,7 +7,7 @@ title: '📄 Docx file'
To add any doc/docx file, use the data_type as `docx`. `docx` allows remote urls and conventional file paths. Eg:
```python
from embedchain import Pipeline as App
from embedchain import App
app = App()
app.add('https://example.com/content/intro.docx', data_type="docx")

View File

@@ -24,7 +24,7 @@ To use this you need to save `credentials.json` in the directory from where you
12. Put the `.json` file in your current directory and rename it to `credentials.json`
```python
from embedchain import Pipeline as App
from embedchain import App
app = App()

View File

@@ -21,7 +21,7 @@ If you would like to add other data structures (e.g. list, dict etc.), convert i
<CodeGroup>
```python python
from embedchain import Pipeline as App
from embedchain import App
app = App()

View File

@@ -5,7 +5,7 @@ title: '📝 Mdx file'
To add any `.mdx` file to your app, use the data_type (the second argument to the `.add()` method) as `mdx`. Note that this only supports mdx files present on the local machine, so the source should be a file path. Eg:
```python
from embedchain import Pipeline as App
from embedchain import App
app = App()
app.add('path/to/file.mdx', data_type='mdx')

View File

@@ -8,7 +8,7 @@ To load a notion page, use the data_type as `notion`. Since it is hard to automa
The next argument must **end** with the `notion page id`. The id is a 32-character string. Eg:
```python
from embedchain import Pipeline as App
from embedchain import App
app = App()
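# Illustrative: the source must end with the 32-character notion page id
# (the id below is a made-up placeholder).
app.add("cfbc134ca6464fc980d0391613e3f09f", data_type="notion")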

View File

@@ -5,7 +5,7 @@ title: 🙌 OpenAPI
To add any OpenAPI spec YAML file (currently, a JSON file will be detected as the JSON data type), use the data_type as `openapi`. `openapi` allows remote urls and conventional file paths.
```python
from embedchain import Pipeline as App
from embedchain import App
app = App()

View File

@@ -5,7 +5,7 @@ title: '📰 PDF file'
To add any pdf file, use the data_type as `pdf_file`. Eg:
```python
from embedchain import Pipeline as App
from embedchain import App
app = App()

View File

@@ -5,7 +5,7 @@ title: '❓💬 Question and answer pair'
QnA pair is a local data type. To supply your own QnA pair, use the data_type as `qna_pair` and enter a tuple. Eg:
```python
from embedchain import Pipeline as App
from embedchain import App
app = App()
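# Illustrative (question, answer) tuple, per the description above:
app.add(("Who is Elon Musk?", "Elon Musk is the CEO of Tesla and SpaceX."), data_type="qna_pair")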

View File

@@ -5,7 +5,7 @@ title: '🗺️ Sitemap'
Add all web pages from an XML sitemap. Non-text files are filtered out. Use the data_type as `sitemap`. Eg:
```python
from embedchain import Pipeline as App
from embedchain import App
app = App()

View File

@@ -16,7 +16,7 @@ This will automatically retrieve data from the workspace associated with the use
```python
import os
from embedchain import Pipeline as App
from embedchain import App
os.environ["SLACK_USER_TOKEN"] = "xoxp-xxx"
app = App()

View File

@@ -5,7 +5,7 @@ title: "📝 Substack"
To add any Substack data source to your app, just add the main base url as the source and set the data_type to `substack`.
```python
from embedchain import Pipeline as App
from embedchain import App
app = App()

View File

@@ -7,7 +7,7 @@ title: '📝 Text'
Text is a local data type. To supply your own text, use the data_type as `text` and enter a string. The text is not processed, which makes this data type very versatile. Eg:
```python
from embedchain import Pipeline as App
from embedchain import App
app = App()
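# Illustrative: any raw string can be supplied directly.
app.add("Elon Musk is the CEO of Tesla and SpaceX.", data_type="text")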

View File

@@ -5,7 +5,7 @@ title: '🌐 Web page'
To add any web page, use the data_type as `web_page`. Eg:
```python
from embedchain import Pipeline as App
from embedchain import App
app = App()

View File

@@ -7,7 +7,7 @@ title: '🧾 XML file'
To add any xml file, use the data_type as `xml`. Eg:
```python
from embedchain import Pipeline as App
from embedchain import App
app = App()

View File

@@ -13,7 +13,7 @@ pip install -U "embedchain[youtube]"
</Note>
```python
from embedchain import Pipeline as App
from embedchain import App
app = App()
app.add("@channel_name", data_type="youtube_channel")

View File

@@ -5,7 +5,7 @@ title: '📺 Youtube'
To add any youtube video to your app, use the data_type as `youtube_video`. Eg:
```python
from embedchain import Pipeline as App
from embedchain import App
app = App()
app.add('a_valid_youtube_url_here', data_type='youtube_video')

View File

@@ -25,7 +25,7 @@ Once you have obtained the key, you can use it like this:
```python main.py
import os
from embedchain import Pipeline as App
from embedchain import App
os.environ['OPENAI_API_KEY'] = 'xxx'
@@ -52,7 +52,7 @@ To use Google AI embedding function, you have to set the `GOOGLE_API_KEY` enviro
<CodeGroup>
```python main.py
import os
from embedchain import Pipeline as App
from embedchain import App
os.environ["GOOGLE_API_KEY"] = "xxx"
@@ -81,7 +81,7 @@ To use Azure OpenAI embedding model, you have to set some of the azure openai re
```python main.py
import os
from embedchain import Pipeline as App
from embedchain import App
os.environ["OPENAI_API_TYPE"] = "azure"
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://xxx.openai.azure.com/"
@@ -119,7 +119,7 @@ GPT4All supports generating high quality embeddings of arbitrary length document
<CodeGroup>
```python main.py
from embedchain import Pipeline as App
from embedchain import App
# load embedding model configuration from config.yaml file
app = App.from_config(config_path="config.yaml")
@@ -148,7 +148,7 @@ Hugging Face supports generating embeddings of arbitrary length documents of tex
<CodeGroup>
```python main.py
from embedchain import Pipeline as App
from embedchain import App
# load embedding model configuration from config.yaml file
app = App.from_config(config_path="config.yaml")
@@ -179,7 +179,7 @@ Embedchain supports Google's VertexAI embeddings model through a simple interfac
<CodeGroup>
```python main.py
from embedchain import Pipeline as App
from embedchain import App
# load embedding model configuration from config.yaml file
app = App.from_config(config_path="config.yaml")

View File

@@ -29,7 +29,7 @@ Once you have obtained the key, you can use it like this:
```python
import os
from embedchain import Pipeline as App
from embedchain import App
os.environ['OPENAI_API_KEY'] = 'xxx'
@@ -44,7 +44,7 @@ If you are looking to configure the different parameters of the LLM, you can do
```python main.py
import os
from embedchain import Pipeline as App
from embedchain import App
os.environ['OPENAI_API_KEY'] = 'xxx'
@@ -71,7 +71,7 @@ Examples:
<Accordion title="Using Pydantic Models">
```python
import os
from embedchain import Pipeline as App
from embedchain import App
from embedchain.llm.openai import OpenAILlm
import requests
from pydantic import BaseModel, Field, ValidationError, field_validator
@@ -123,7 +123,7 @@ print(result)
<Accordion title="Using OpenAI JSON schema">
```python
import os
from embedchain import Pipeline as App
from embedchain import App
from embedchain.llm.openai import OpenAILlm
import requests
from pydantic import BaseModel, Field, ValidationError, field_validator
@@ -158,7 +158,7 @@ print(result)
<Accordion title="Using actual python functions">
```python
import os
from embedchain import Pipeline as App
from embedchain import App
from embedchain.llm.openai import OpenAILlm
import requests
from pydantic import BaseModel, Field, ValidationError, field_validator
@@ -192,7 +192,7 @@ To use Google AI model, you have to set the `GOOGLE_API_KEY` environment variabl
<CodeGroup>
```python main.py
import os
from embedchain import Pipeline as App
from embedchain import App
os.environ["GOOGLE_API_KEY"] = "xxx"
@@ -235,7 +235,7 @@ To use Azure OpenAI model, you have to set some of the azure openai related envi
```python main.py
import os
from embedchain import Pipeline as App
from embedchain import App
os.environ["OPENAI_API_TYPE"] = "azure"
os.environ["OPENAI_API_BASE"] = "https://xxx.openai.azure.com/"
@@ -274,7 +274,7 @@ To use anthropic's model, please set the `ANTHROPIC_API_KEY` which you find on t
```python main.py
import os
from embedchain import Pipeline as App
from embedchain import App
os.environ["ANTHROPIC_API_KEY"] = "xxx"
@@ -311,7 +311,7 @@ Once you have the API key, you are all set to use it with Embedchain.
```python main.py
import os
from embedchain import Pipeline as App
from embedchain import App
os.environ["COHERE_API_KEY"] = "xxx"
@@ -347,7 +347,7 @@ Once you have the API key, you are all set to use it with Embedchain.
```python main.py
import os
from embedchain import Pipeline as App
from embedchain import App
os.environ["TOGETHER_API_KEY"] = "xxx"
@@ -375,7 +375,7 @@ Setup Ollama using https://github.com/jmorganca/ollama
```python main.py
import os
from embedchain import Pipeline as App
from embedchain import App
# load llm configuration from config.yaml file
app = App.from_config(config_path="config.yaml")
@@ -406,7 +406,7 @@ GPT4all is a free-to-use, locally running, privacy-aware chatbot. No GPU or inte
<CodeGroup>
```python main.py
from embedchain import Pipeline as App
from embedchain import App
# load llm configuration from config.yaml file
app = App.from_config(config_path="config.yaml")
@@ -438,7 +438,7 @@ Once you have the key, load the app using the config yaml file:
```python main.py
import os
from embedchain import Pipeline as App
from embedchain import App
os.environ["JINACHAT_API_KEY"] = "xxx"
# load llm configuration from config.yaml file
@@ -474,7 +474,7 @@ Once you have the token, load the app using the config yaml file:
```python main.py
import os
from embedchain import Pipeline as App
from embedchain import App
os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "xxx"
@@ -504,7 +504,7 @@ Once you have the token, load the app using the config yaml file:
```python main.py
import os
from embedchain import Pipeline as App
from embedchain import App
os.environ["REPLICATE_API_TOKEN"] = "xxx"
@@ -531,7 +531,7 @@ Setup Google Cloud Platform application credentials by following the instruction
<CodeGroup>
```python main.py
from embedchain import Pipeline as App
from embedchain import App
# load llm configuration from config.yaml file
app = App.from_config(config_path="config.yaml")

View File

@@ -22,7 +22,7 @@ Utilizing a vector database alongside Embedchain is a seamless process. All you
<CodeGroup>
```python main.py
from embedchain import Pipeline as App
from embedchain import App
# load chroma configuration from yaml file
app = App.from_config(config_path="config1.yaml")
@@ -67,7 +67,7 @@ You can authorize the connection to Elasticsearch by providing either `basic_aut
<CodeGroup>
```python main.py
from embedchain import Pipeline as App
from embedchain import App
# load elasticsearch configuration from yaml file
app = App.from_config(config_path="config.yaml")
@@ -97,7 +97,7 @@ pip install --upgrade 'embedchain[opensearch]'
<CodeGroup>
```python main.py
from embedchain import Pipeline as App
from embedchain import App
# load opensearch configuration from yaml file
app = App.from_config(config_path="config.yaml")
@@ -133,7 +133,7 @@ Set the Zilliz environment variables `ZILLIZ_CLOUD_URI` and `ZILLIZ_CLOUD_TOKEN`
```python main.py
import os
from embedchain import Pipeline as App
from embedchain import App
os.environ['ZILLIZ_CLOUD_URI'] = 'https://xxx.zillizcloud.com'
os.environ['ZILLIZ_CLOUD_TOKEN'] = 'xxx'
@@ -172,7 +172,7 @@ In order to use Pinecone as vector database, set the environment variables `PINE
<CodeGroup>
```python main.py
from embedchain import Pipeline as App
from embedchain import App
# load pinecone configuration from yaml file
app = App.from_config(config_path="config.yaml")
@@ -195,7 +195,7 @@ In order to use Qdrant as a vector database, set the environment variables `QDRA
<CodeGroup>
```python main.py
from embedchain import Pipeline as App
from embedchain import App
# load qdrant configuration from yaml file
app = App.from_config(config_path="config.yaml")
@@ -215,7 +215,7 @@ In order to use Weaviate as a vector database, set the environment variables `WE
<CodeGroup>
```python main.py
from embedchain import Pipeline as App
from embedchain import App
# load weaviate configuration from yaml file
app = App.from_config(config_path="config.yaml")

View File

@@ -10,7 +10,7 @@ Embedchain enables developers to deploy their LLM-powered apps in production usi
See the example below on how to deploy your app (for free):
```python
from embedchain import Pipeline as App
from embedchain import App
# Initialize app
app = App()

View File

@@ -11,7 +11,7 @@ Use the model provided on huggingface: `mistralai/Mistral-7B-v0.1`
<CodeGroup>
```python main.py
import os
from embedchain import Pipeline as App
from embedchain import App
os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "hf_your_token"
@@ -40,7 +40,7 @@ Use the model `gpt-4-turbo` provided by OpenAI.
```python main.py
import os
from embedchain import Pipeline as App
from embedchain import App
os.environ['OPENAI_API_KEY'] = 'xxx'
@@ -65,7 +65,7 @@ llm:
```python main.py
import os
from embedchain import Pipeline as App
from embedchain import App
os.environ['OPENAI_API_KEY'] = 'xxx'
@@ -90,7 +90,7 @@ llm:
<CodeGroup>
```python main.py
from embedchain import Pipeline as App
from embedchain import App
# load llm configuration from opensource.yaml file
app = App.from_config(config_path="opensource.yaml")
@@ -131,7 +131,7 @@ llm:
```python main.py
import os
from embedchain import Pipeline as App
from embedchain import App
os.environ['OPENAI_API_KEY'] = 'sk-xxx'
@@ -149,7 +149,7 @@ response = app.query("What is the net worth of Elon Musk?")
Set up the app by adding an `id` in the config file. This persists the data for future use. You can include this `id` in the yaml config or pass it directly in the `config` dict.
```python app1.py
import os
from embedchain import Pipeline as App
from embedchain import App
os.environ['OPENAI_API_KEY'] = 'sk-xxx'
@@ -167,7 +167,7 @@ response = app.query("What is the net worth of Elon Musk?")
```
```python app2.py
import os
from embedchain import Pipeline as App
from embedchain import App
os.environ['OPENAI_API_KEY'] = 'sk-xxx'

View File

@@ -14,7 +14,7 @@ Creating an app involves 3 steps:
<Steps>
<Step title="⚙️ Import app instance">
```python
from embedchain import Pipeline as App
from embedchain import App
app = App()
```
<Accordion title="Customize your app by a simple YAML config" icon="gear-complex">
@@ -22,15 +22,15 @@ Creating an app involves 3 steps:
Explore the custom configurations [here](https://docs.embedchain.ai/advanced/configuration).
<CodeGroup>
```python yaml_app.py
from embedchain import Pipeline as App
from embedchain import App
app = App.from_config(config_path="config.yaml")
```
```python json_app.py
from embedchain import Pipeline as App
from embedchain import App
app = App.from_config(config_path="config.json")
```
```python app.py
from embedchain import Pipeline as App
from embedchain import App
config = {} # Add your config here
app = App.from_config(config=config)
```

View File

@@ -21,7 +21,7 @@ Create a new file called `app.py` and add the following code:
```python
import chainlit as cl
from embedchain import Pipeline as App
from embedchain import App
import os

View File

@@ -39,7 +39,7 @@ os.environ['LANGCHAIN_PROJECT'] = '<your-project>'
```python
from embedchain import Pipeline as App
from embedchain import App
app = App()
app.add("https://en.wikipedia.org/wiki/Elon_Musk")

View File

@@ -17,7 +17,7 @@ pip install embedchain streamlit
<Tab title="app.py">
```python
import os
from embedchain import Pipeline as App
from embedchain import App
import streamlit as st
with st.sidebar:

View File

@@ -24,7 +24,7 @@ Quickly create a RAG pipeline to answer queries about the [Next.JS Framework](ht
First, let's create your RAG pipeline. Open your Python environment and enter:
```python Create pipeline
from embedchain import Pipeline as App
from embedchain import App
app = App()
```

View File

@@ -19,7 +19,7 @@ Embedchain offers a simple yet customizable `search()` API that you can use for
First, let's create your RAG pipeline. Open your Python environment and enter:
```python Create pipeline
from embedchain import Pipeline as App
from embedchain import App
app = App()
```

View File

@@ -2,10 +2,9 @@ import importlib.metadata
__version__ = importlib.metadata.version(__package__ or __name__)
from embedchain.apps.app import App # noqa: F401
from embedchain.app import App # noqa: F401
from embedchain.client import Client # noqa: F401
from embedchain.pipeline import Pipeline # noqa: F401
from embedchain.vectordb.chroma import ChromaDB # noqa: F401
# Setup the user directory if doesn't exist already
Client.setup_dir()
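After this change both import paths resolve to the converged class, and `Pipeline` survives only as a deprecated subclass of `App` (see `embedchain/pipeline.py` below). A minimal sketch of what that implies:

```python
from embedchain import App, Pipeline

# Pipeline is kept for backward compatibility but is deprecated in favor of App.
assert issubclass(Pipeline, App)
```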

embedchain/app.py (new file, 431 lines)
View File

@@ -0,0 +1,431 @@
import ast
import json
import logging
import os
import sqlite3
import uuid
from typing import Any, Dict, Optional
import requests
import yaml
from embedchain.client import Client
from embedchain.config import AppConfig, ChunkerConfig
from embedchain.constants import SQLITE_PATH
from embedchain.embedchain import EmbedChain
from embedchain.embedder.base import BaseEmbedder
from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
from embedchain.llm.openai import OpenAILlm
from embedchain.telemetry.posthog import AnonymousTelemetry
from embedchain.utils import validate_config
from embedchain.vectordb.base import BaseVectorDB
from embedchain.vectordb.chroma import ChromaDB
# Setup the user directory if doesn't exist already
Client.setup_dir()
@register_deserializable
class App(EmbedChain):
"""
EmbedChain App lets you create a LLM powered app for your unstructured
data by defining your chosen data source, embedding model,
and vector database.
"""
def __init__(
self,
id: str = None,
name: str = None,
config: AppConfig = None,
db: BaseVectorDB = None,
embedding_model: BaseEmbedder = None,
llm: BaseLlm = None,
config_data: dict = None,
log_level=logging.WARN,
auto_deploy: bool = False,
chunker: ChunkerConfig = None,
):
"""
Initialize a new `App` instance.
:param config: Configuration for the pipeline, defaults to None
:type config: AppConfig, optional
:param db: The database to use for storing and retrieving embeddings, defaults to None
:type db: BaseVectorDB, optional
:param embedding_model: The embedding model used to calculate embeddings, defaults to None
:type embedding_model: BaseEmbedder, optional
:param llm: The LLM used to generate answers, defaults to None
:type llm: BaseLlm, optional
:param config_data: Config dictionary, defaults to None
:type config_data: dict, optional
:param log_level: Log level to use, defaults to logging.WARN
:type log_level: int, optional
:param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
:type auto_deploy: bool, optional
:raises Exception: If an error occurs while creating the pipeline
"""
if id and config_data:
raise Exception("Cannot provide both id and config. Please provide only one of them.")
if id and name:
raise Exception("Cannot provide both id and name. Please provide only one of them.")
if name and config:
raise Exception("Cannot provide both name and config. Please provide only one of them.")
logging.basicConfig(level=log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
self.logger = logging.getLogger(__name__)
self.auto_deploy = auto_deploy
# Store the dict config as an attribute to be able to send it
self.config_data = config_data if (config_data and validate_config(config_data)) else None
self.client = None
# pipeline_id from the backend
self.id = None
self.chunker = None
if chunker:
self.chunker = ChunkerConfig(**chunker)
self.config = config or AppConfig()
self.name = self.config.name
self.config.id = self.local_id = str(uuid.uuid4()) if self.config.id is None else self.config.id
if id is not None:
# Init client first since user is trying to fetch the pipeline
# details from the platform
self._init_client()
pipeline_details = self._get_pipeline(id)
self.config.id = self.local_id = pipeline_details["metadata"]["local_id"]
self.id = id
if name is not None:
self.name = name
self.embedding_model = embedding_model or OpenAIEmbedder()
self.db = db or ChromaDB()
self.llm = llm or OpenAILlm()
self._init_db()
# Send anonymous telemetry
self._telemetry_props = {"class": self.__class__.__name__}
self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
# Establish a connection to the SQLite database
self.connection = sqlite3.connect(SQLITE_PATH, check_same_thread=False)
self.cursor = self.connection.cursor()
# Create the 'data_sources' table if it doesn't exist
self.cursor.execute(
"""
CREATE TABLE IF NOT EXISTS data_sources (
pipeline_id TEXT,
hash TEXT,
type TEXT,
value TEXT,
metadata TEXT,
is_uploaded INTEGER DEFAULT 0,
PRIMARY KEY (pipeline_id, hash)
)
"""
)
self.connection.commit()
# Send anonymous telemetry
self.telemetry.capture(event_name="init", properties=self._telemetry_props)
self.user_asks = []
if self.auto_deploy:
self.deploy()
def _init_db(self):
"""
Initialize the database.
"""
self.db._set_embedder(self.embedding_model)
self.db._initialize()
self.db.set_collection_name(self.db.config.collection_name)
def _init_client(self):
"""
Initialize the client.
"""
config = Client.load_config()
if config.get("api_key"):
self.client = Client()
else:
api_key = input(
"🔑 Enter your Embedchain API key. You can find the API key at https://app.embedchain.ai/settings/keys/ \n" # noqa: E501
)
self.client = Client(api_key=api_key)
def _get_pipeline(self, id):
"""
Get existing pipeline
"""
print("🛠️ Fetching pipeline details from the platform...")
url = f"{self.client.host}/api/v1/pipelines/{id}/cli/"
r = requests.get(
url,
headers={"Authorization": f"Token {self.client.api_key}"},
)
if r.status_code == 404:
raise Exception(f"❌ Pipeline with id {id} not found!")
print(
f"🎉 Pipeline loaded successfully! Pipeline url: https://app.embedchain.ai/pipelines/{r.json()['id']}\n" # noqa: E501
)
return r.json()
def _create_pipeline(self):
"""
Create a pipeline on the platform.
"""
print("🛠️ Creating pipeline on the platform...")
# self.config_data is a dict. Pass it inside the key 'yaml_config' to the backend
payload = {
"yaml_config": json.dumps(self.config_data),
"name": self.name,
"local_id": self.local_id,
}
url = f"{self.client.host}/api/v1/pipelines/cli/create/"
r = requests.post(
url,
json=payload,
headers={"Authorization": f"Token {self.client.api_key}"},
)
if r.status_code not in [200, 201]:
raise Exception(f"❌ Error occurred while creating pipeline. API response: {r.text}")
if r.status_code == 200:
print(
f"🎉🎉🎉 Existing pipeline found! View your pipeline: https://app.embedchain.ai/pipelines/{r.json()['id']}\n" # noqa: E501
) # noqa: E501
elif r.status_code == 201:
print(
f"🎉🎉🎉 Pipeline created successfully! View your pipeline: https://app.embedchain.ai/pipelines/{r.json()['id']}\n" # noqa: E501
)
return r.json()
def _get_presigned_url(self, data_type, data_value):
payload = {"data_type": data_type, "data_value": data_value}
r = requests.post(
f"{self.client.host}/api/v1/pipelines/{self.id}/cli/presigned_url/",
json=payload,
headers={"Authorization": f"Token {self.client.api_key}"},
)
r.raise_for_status()
return r.json()
def search(self, query, num_documents=3):
"""
Search for similar documents related to the query in the vector database.
"""
# Send anonymous telemetry
self.telemetry.capture(event_name="search", properties=self._telemetry_props)
# TODO: Search will call the endpoint rather than fetching the data from the db itself when deploy=True.
if self.id is None:
where = {"app_id": self.local_id}
context = self.db.query(
query,
n_results=num_documents,
where=where,
skip_embedding=False,
citations=True,
)
result = []
for c in context:
result.append(
{
"context": c[0],
"source": c[1],
"document_id": c[2],
}
)
return result
else:
# Make API call to the backend to get the results
raise NotImplementedError("Search is not implemented yet for the prod mode.")
def _upload_file_to_presigned_url(self, presigned_url, file_path):
try:
with open(file_path, "rb") as file:
response = requests.put(presigned_url, data=file)
response.raise_for_status()
return response.status_code == 200
except Exception as e:
self.logger.exception(f"Error occurred during file upload: {str(e)}")
print("❌ Error occurred during file upload!")
return False
def _upload_data_to_pipeline(self, data_type, data_value, metadata=None):
payload = {
"data_type": data_type,
"data_value": data_value,
"metadata": metadata,
}
try:
self._send_api_request(f"/api/v1/pipelines/{self.id}/cli/add/", payload)
# print the local file path if user tries to upload a local file
printed_value = metadata.get("file_path") if metadata.get("file_path") else data_value
print(f"✅ Data of type: {data_type}, value: {printed_value} added successfully.")
except Exception as e:
print(f"❌ Error occurred during data upload for type {data_type}!. Error: {str(e)}")
def _send_api_request(self, endpoint, payload):
url = f"{self.client.host}{endpoint}"
headers = {"Authorization": f"Token {self.client.api_key}"}
response = requests.post(url, json=payload, headers=headers)
response.raise_for_status()
return response
def _process_and_upload_data(self, data_hash, data_type, data_value):
if os.path.isabs(data_value):
presigned_url_data = self._get_presigned_url(data_type, data_value)
presigned_url = presigned_url_data["presigned_url"]
s3_key = presigned_url_data["s3_key"]
if self._upload_file_to_presigned_url(presigned_url, file_path=data_value):
metadata = {"file_path": data_value, "s3_key": s3_key}
data_value = presigned_url
else:
self.logger.error(f"File upload failed for hash: {data_hash}")
return False
else:
if data_type == "qna_pair":
data_value = list(ast.literal_eval(data_value))
metadata = {}
try:
self._upload_data_to_pipeline(data_type, data_value, metadata)
self._mark_data_as_uploaded(data_hash)
return True
except Exception:
print(f"❌ Error occurred during data upload for hash {data_hash}!")
return False
def _mark_data_as_uploaded(self, data_hash):
self.cursor.execute(
"UPDATE data_sources SET is_uploaded = 1 WHERE hash = ? AND pipeline_id = ?",
(data_hash, self.local_id),
)
self.connection.commit()
def get_data_sources(self):
db_data = self.cursor.execute("SELECT * FROM data_sources WHERE pipeline_id = ?", (self.local_id,)).fetchall()
data_sources = []
for data in db_data:
data_sources.append({"data_type": data[2], "data_value": data[3], "metadata": data[4]})
return data_sources
def deploy(self):
if self.client is None:
self._init_client()
pipeline_data = self._create_pipeline()
self.id = pipeline_data["id"]
results = self.cursor.execute(
"SELECT * FROM data_sources WHERE pipeline_id = ? AND is_uploaded = 0", (self.local_id,) # noqa:E501
).fetchall()
if len(results) > 0:
print("🛠️ Adding data to your pipeline...")
for result in results:
data_hash, data_type, data_value = result[1], result[2], result[3]
self._process_and_upload_data(data_hash, data_type, data_value)
# Send anonymous telemetry
self.telemetry.capture(event_name="deploy", properties=self._telemetry_props)
@classmethod
def from_config(
cls,
config_path: Optional[str] = None,
config: Optional[Dict[str, Any]] = None,
auto_deploy: bool = False,
yaml_path: Optional[str] = None,
):
"""
Instantiate an App object from a configuration.
:param config_path: Path to the YAML or JSON configuration file.
:type config_path: Optional[str]
:param config: A dictionary containing the configuration.
:type config: Optional[Dict[str, Any]]
:param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
:type auto_deploy: bool, optional
:param yaml_path: (Deprecated) Path to the YAML configuration file. Use config_path instead.
:type yaml_path: Optional[str]
:return: An instance of the App class.
:rtype: App
"""
# Backward compatibility for yaml_path
if yaml_path and not config_path:
config_path = yaml_path
if config_path and config:
raise ValueError("Please provide only one of config_path or config.")
config_data = None
if config_path:
file_extension = os.path.splitext(config_path)[1]
with open(config_path, "r") as file:
if file_extension in [".yaml", ".yml"]:
config_data = yaml.safe_load(file)
elif file_extension == ".json":
config_data = json.load(file)
else:
raise ValueError("config_path must be a path to a YAML or JSON file.")
elif config and isinstance(config, dict):
config_data = config
else:
logging.error(
"Please provide either a config file path (YAML or JSON) or a config dictionary. Falling back to defaults because no config is provided.", # noqa: E501
)
config_data = {}
try:
validate_config(config_data)
except Exception as e:
raise Exception(f"Error occurred while validating the config. Error: {str(e)}")
app_config_data = config_data.get("app", {}).get("config", {})
db_config_data = config_data.get("vectordb", {})
embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
llm_config_data = config_data.get("llm", {})
chunker_config_data = config_data.get("chunker", {})
app_config = AppConfig(**app_config_data)
db_provider = db_config_data.get("provider", "chroma")
db = VectorDBFactory.create(db_provider, db_config_data.get("config", {}))
if llm_config_data:
llm_provider = llm_config_data.get("provider", "openai")
llm = LlmFactory.create(llm_provider, llm_config_data.get("config", {}))
else:
llm = None
embedding_model_provider = embedding_model_config_data.get("provider", "openai")
embedding_model = EmbedderFactory.create(
embedding_model_provider, embedding_model_config_data.get("config", {})
)
# Send anonymous telemetry
event_properties = {"init_type": "config_data"}
AnonymousTelemetry().capture(event_name="init", properties=event_properties)
return cls(
config=app_config,
llm=llm,
db=db,
embedding_model=embedding_model,
config_data=config_data,
auto_deploy=auto_deploy,
chunker=chunker_config_data,
)
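For reference, a minimal sketch of driving the new `from_config` path with an inline dict, assuming `validate_config` accepts this shape (the `id` value is illustrative; omitted sections fall back to the default OpenAI LLM/embedder and Chroma):

```python
from embedchain import App

app = App.from_config(config={"app": {"config": {"id": "my-app"}}})
```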

View File

@@ -1,157 +0,0 @@
from typing import Optional
import yaml
from embedchain.config import (AppConfig, BaseEmbedderConfig, BaseLlmConfig,
ChunkerConfig)
from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.embedchain import EmbedChain
from embedchain.embedder.base import BaseEmbedder
from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
from embedchain.llm.openai import OpenAILlm
from embedchain.utils import validate_config
from embedchain.vectordb.base import BaseVectorDB
from embedchain.vectordb.chroma import ChromaDB
@register_deserializable
class App(EmbedChain):
"""
The EmbedChain app in it's simplest and most straightforward form.
An opinionated choice of LLM, vector database and embedding model.
Methods:
add(source, data_type): adds the data from the given URL to the vector db.
query(query): finds answer to the given query using vector database and LLM.
chat(query): finds answer to the given query using vector database and LLM, with conversation history.
"""
def __init__(
self,
config: Optional[AppConfig] = None,
llm: BaseLlm = None,
llm_config: Optional[BaseLlmConfig] = None,
db: BaseVectorDB = None,
db_config: Optional[BaseVectorDbConfig] = None,
embedder: BaseEmbedder = None,
embedder_config: Optional[BaseEmbedderConfig] = None,
system_prompt: Optional[str] = None,
chunker: Optional[ChunkerConfig] = None,
):
"""
Initialize a new `App` instance.
:param config: Config for the app instance., defaults to None
:type config: Optional[AppConfig], optional
:param llm: LLM Class instance. example: `from embedchain.llm.openai import OpenAILlm`, defaults to OpenAiLlm
:type llm: BaseLlm, optional
:param llm_config: Allows you to configure the LLM, e.g. how many documents to return,
example: `from embedchain.config import BaseLlmConfig`, defaults to None
:type llm_config: Optional[BaseLlmConfig], optional
:param db: The database to use for storing and retrieving embeddings,
example: `from embedchain.vectordb.chroma_db import ChromaDb`, defaults to ChromaDb
:type db: BaseVectorDB, optional
:param db_config: Allows you to configure the vector database,
example: `from embedchain.config import ChromaDbConfig`, defaults to None
:type db_config: Optional[BaseVectorDbConfig], optional
:param embedder: The embedder (embedding model and function) use to calculate embeddings.
example: `from embedchain.embedder.gpt4all_embedder import GPT4AllEmbedder`, defaults to OpenAIEmbedder
:type embedder: BaseEmbedder, optional
:param embedder_config: Allows you to configure the Embedder.
example: `from embedchain.config import BaseEmbedderConfig`, defaults to None
:type embedder_config: Optional[BaseEmbedderConfig], optional
:param system_prompt: System prompt that will be provided to the LLM as such, defaults to None
:type system_prompt: Optional[str], optional
:raises TypeError: LLM, database or embedder or their config is not a valid class instance.
"""
# Type check configs
if config and not isinstance(config, AppConfig):
raise TypeError(
"Config is not a `AppConfig` instance. "
"Please make sure the type is right and that you are passing an instance."
)
if llm_config and not isinstance(llm_config, BaseLlmConfig):
raise TypeError(
"`llm_config` is not a `BaseLlmConfig` instance. "
"Please make sure the type is right and that you are passing an instance."
)
if db_config and not isinstance(db_config, BaseVectorDbConfig):
raise TypeError(
"`db_config` is not a `BaseVectorDbConfig` instance. "
"Please make sure the type is right and that you are passing an instance."
)
if embedder_config and not isinstance(embedder_config, BaseEmbedderConfig):
raise TypeError(
"`embedder_config` is not a `BaseEmbedderConfig` instance. "
"Please make sure the type is right and that you are passing an instance."
)
# Assign defaults
if config is None:
config = AppConfig()
if llm is None:
llm = OpenAILlm(config=llm_config)
if db is None:
db = ChromaDB(config=db_config)
if embedder is None:
embedder = OpenAIEmbedder(config=embedder_config)
self.chunker = None
if chunker:
self.chunker = ChunkerConfig(**chunker)
# Type check assignments
if not isinstance(llm, BaseLlm):
raise TypeError(
"LLM is not a `BaseLlm` instance. "
"Please make sure the type is right and that you are passing an instance."
)
if not isinstance(db, BaseVectorDB):
raise TypeError(
"Database is not a `BaseVectorDB` instance. "
"Please make sure the type is right and that you are passing an instance."
)
if not isinstance(embedder, BaseEmbedder):
raise TypeError(
"Embedder is not a `BaseEmbedder` instance. "
"Please make sure the type is right and that you are passing an instance."
)
super().__init__(config, llm=llm, db=db, embedder=embedder, system_prompt=system_prompt)
@classmethod
def from_config(cls, yaml_path: str):
"""
Instantiate an App object from a YAML configuration file.
:param yaml_path: Path to the YAML configuration file.
:type yaml_path: str
:return: An instance of the App class.
:rtype: App
"""
with open(yaml_path, "r") as file:
config_data = yaml.safe_load(file)
try:
validate_config(config_data)
except Exception as e:
raise Exception(f"❌ Error occurred while validating the YAML config. Error: {str(e)}")
app_config_data = config_data.get("app", {})
llm_config_data = config_data.get("llm", {})
db_config_data = config_data.get("vectordb", {})
embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
chunker_config_data = config_data.get("chunker", {})
app_config = AppConfig(**app_config_data.get("config", {}))
llm_provider = llm_config_data.get("provider", "openai")
llm = LlmFactory.create(llm_provider, llm_config_data.get("config", {}))
db_provider = db_config_data.get("provider", "chroma")
db = VectorDBFactory.create(db_provider, db_config_data.get("config", {}))
embedder_provider = embedding_model_config_data.get("provider", "openai")
embedder = EmbedderFactory.create(embedder_provider, embedding_model_config_data.get("config", {}))
return cls(config=app_config, llm=llm, db=db, embedder=embedder, chunker=chunker_config_data)

View File

@@ -1,7 +1,7 @@
from typing import Any
from embedchain import Pipeline as App
from embedchain.config import AddConfig, BaseLlmConfig, PipelineConfig
from embedchain import App
from embedchain.config import AddConfig, AppConfig, BaseLlmConfig
from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.helpers.json_serializable import (JSONSerializable,
register_deserializable)
@@ -12,7 +12,7 @@ from embedchain.vectordb.chroma import ChromaDB
@register_deserializable
class BaseBot(JSONSerializable):
def __init__(self):
self.app = App(config=PipelineConfig(), llm=OpenAILlm(), db=ChromaDB(), embedding_model=OpenAIEmbedder())
self.app = App(config=AppConfig(), llm=OpenAILlm(), db=ChromaDB(), embedding_model=OpenAIEmbedder())
def add(self, data: Any, config: AddConfig = None):
"""

View File

@@ -1,12 +1,11 @@
# flake8: noqa: F401
from .add_config import AddConfig, ChunkerConfig
from .apps.app_config import AppConfig
from .app_config import AppConfig
from .base_config import BaseConfig
from .embedder.base import BaseEmbedderConfig
from .embedder.base import BaseEmbedderConfig as EmbedderConfig
from .llm.base import BaseLlmConfig
from .pipeline_config import PipelineConfig
from .vectordb.chroma import ChromaDbConfig
from .vectordb.elasticsearch import ElasticsearchDBConfig
from .vectordb.opensearch import OpenSearchDBConfig

View File

@@ -15,8 +15,9 @@ class AppConfig(BaseAppConfig):
self,
log_level: str = "WARNING",
id: Optional[str] = None,
name: Optional[str] = None,
collect_metrics: Optional[bool] = True,
collection_name: Optional[str] = None,
**kwargs,
):
"""
Initializes a configuration class instance for an App. This is the simplest form of an embedchain app.
@@ -28,8 +29,6 @@ class AppConfig(BaseAppConfig):
:type id: Optional[str], optional
:param collect_metrics: Send anonymous telemetry to improve embedchain, defaults to True
:type collect_metrics: Optional[bool], optional
:param collection_name: Default collection name. It's recommended to use app.db.set_collection_name() instead,
defaults to None
:type collection_name: Optional[str], optional
"""
super().__init__(log_level=log_level, id=id, collect_metrics=collect_metrics, collection_name=collection_name)
self.name = name
super().__init__(log_level=log_level, id=id, collect_metrics=collect_metrics, **kwargs)
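Given the updated signature, constructing the config directly might look like this (values illustrative; note that `collection_name` is no longer accepted):

```python
from embedchain import App
from embedchain.config import AppConfig

config = AppConfig(name="my-app", collect_metrics=False)
app = App(config=config)
```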

View File

@@ -1,38 +0,0 @@
from typing import Optional
from embedchain.helpers.json_serializable import register_deserializable
from .apps.base_app_config import BaseAppConfig
@register_deserializable
class PipelineConfig(BaseAppConfig):
"""
Config to initialize an embedchain custom `App` instance, with extra config options.
"""
def __init__(
self,
log_level: str = "WARNING",
id: Optional[str] = None,
name: Optional[str] = None,
collect_metrics: Optional[bool] = True,
):
"""
Initializes a configuration class instance for an App. This is the simplest form of an embedchain app.
Most of the configuration is done in the `App` class itself.
:param log_level: Debug level ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], defaults to "WARNING"
:type log_level: str, optional
:param id: ID of the app. Document metadata will have this id., defaults to None
:type id: Optional[str], optional
:param collect_metrics: Send anonymous telemetry to improve embedchain, defaults to True
:type collect_metrics: Optional[bool], optional
:param collection_name: Default collection name. It's recommended to use app.db.set_collection_name() instead,
defaults to None
:type collection_name: Optional[str], optional
"""
self._setup_logging(log_level)
self.id = id
self.name = name
self.collect_metrics = collect_metrics

View File

@@ -2,7 +2,7 @@ import os
import gradio as gr
from embedchain import Pipeline as App
from embedchain import App
os.environ["OPENAI_API_KEY"] = "sk-xxx"

View File

@@ -1,6 +1,6 @@
import streamlit as st
from embedchain import Pipeline as App
from embedchain import App
@st.cache_resource

View File

@@ -9,7 +9,7 @@ from langchain.docstore.document import Document
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
from embedchain.config.apps.base_app_config import BaseAppConfig
from embedchain.config.base_app_config import BaseAppConfig
from embedchain.constants import SQLITE_PATH
from embedchain.data_formatter import DataFormatter
from embedchain.embedder.base import BaseEmbedder

View File

@@ -1,425 +1,9 @@
import ast
import json
import logging
import os
import sqlite3
import uuid
from typing import Any, Dict, Optional
import requests
import yaml
from embedchain import Client
from embedchain.config import ChunkerConfig, PipelineConfig
from embedchain.constants import SQLITE_PATH
from embedchain.embedchain import EmbedChain
from embedchain.embedder.base import BaseEmbedder
from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
from embedchain.llm.openai import OpenAILlm
from embedchain.telemetry.posthog import AnonymousTelemetry
from embedchain.utils import validate_config
from embedchain.vectordb.base import BaseVectorDB
from embedchain.vectordb.chroma import ChromaDB
# Setup the user directory if doesn't exist already
Client.setup_dir()
from embedchain.app import App
@register_deserializable
class Pipeline(EmbedChain):
class Pipeline(App):
"""
EmbedChain pipeline lets you create a LLM powered app for your unstructured
data by defining a pipeline with your chosen data source, embedding model,
and vector database.
This is deprecated. Use `App` instead.
"""
def __init__(
self,
id: str = None,
name: str = None,
config: PipelineConfig = None,
db: BaseVectorDB = None,
embedding_model: BaseEmbedder = None,
llm: BaseLlm = None,
config_data: dict = None,
log_level=logging.WARN,
auto_deploy: bool = False,
chunker: ChunkerConfig = None,
):
"""
Initialize a new `App` instance.
:param config: Configuration for the pipeline, defaults to None
:type config: PipelineConfig, optional
:param db: The database to use for storing and retrieving embeddings, defaults to None
:type db: BaseVectorDB, optional
:param embedding_model: The embedding model used to calculate embeddings, defaults to None
:type embedding_model: BaseEmbedder, optional
:param llm: The LLM model used to calculate embeddings, defaults to None
:type llm: BaseLlm, optional
:param config_data: Config dictionary, defaults to None
:type config_data: dict, optional
:param log_level: Log level to use, defaults to logging.WARN
:type log_level: int, optional
:param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
:type auto_deploy: bool, optional
:raises Exception: If an error occurs while creating the pipeline
"""
if id and config_data:
raise Exception("Cannot provide both id and config. Please provide only one of them.")
if id and name:
raise Exception("Cannot provide both id and name. Please provide only one of them.")
if name and config:
raise Exception("Cannot provide both name and config. Please provide only one of them.")
logging.basicConfig(level=log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
self.logger = logging.getLogger(__name__)
self.auto_deploy = auto_deploy
# Store the dict config as an attribute to be able to send it
self.config_data = config_data if (config_data and validate_config(config_data)) else None
self.client = None
# pipeline_id from the backend
self.id = None
self.chunker = None
if chunker:
self.chunker = ChunkerConfig(**chunker)
self.config = config or PipelineConfig()
self.name = self.config.name
self.config.id = self.local_id = str(uuid.uuid4()) if self.config.id is None else self.config.id
if id is not None:
# Init client first since user is trying to fetch the pipeline
# details from the platform
self._init_client()
pipeline_details = self._get_pipeline(id)
self.config.id = self.local_id = pipeline_details["metadata"]["local_id"]
self.id = id
if name is not None:
self.name = name
self.embedding_model = embedding_model or OpenAIEmbedder()
self.db = db or ChromaDB()
self.llm = llm or OpenAILlm()
self._init_db()
# Send anonymous telemetry
self._telemetry_props = {"class": self.__class__.__name__}
self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
# Establish a connection to the SQLite database
self.connection = sqlite3.connect(SQLITE_PATH, check_same_thread=False)
self.cursor = self.connection.cursor()
# Create the 'data_sources' table if it doesn't exist
self.cursor.execute(
"""
CREATE TABLE IF NOT EXISTS data_sources (
pipeline_id TEXT,
hash TEXT,
type TEXT,
value TEXT,
metadata TEXT,
is_uploaded INTEGER DEFAULT 0,
PRIMARY KEY (pipeline_id, hash)
)
"""
)
self.connection.commit()
# Send anonymous telemetry
self.telemetry.capture(event_name="init", properties=self._telemetry_props)
self.user_asks = []
if self.auto_deploy:
self.deploy()
def _init_db(self):
"""
Initialize the database.
"""
self.db._set_embedder(self.embedding_model)
self.db._initialize()
self.db.set_collection_name(self.db.config.collection_name)
def _init_client(self):
"""
Initialize the client.
"""
config = Client.load_config()
if config.get("api_key"):
self.client = Client()
else:
api_key = input(
"🔑 Enter your Embedchain API key. You can find the API key at https://app.embedchain.ai/settings/keys/ \n" # noqa: E501
)
self.client = Client(api_key=api_key)
def _get_pipeline(self, id):
"""
Get existing pipeline
"""
print("🛠️ Fetching pipeline details from the platform...")
url = f"{self.client.host}/api/v1/pipelines/{id}/cli/"
r = requests.get(
url,
headers={"Authorization": f"Token {self.client.api_key}"},
)
if r.status_code == 404:
raise Exception(f"❌ Pipeline with id {id} not found!")
print(
f"🎉 Pipeline loaded successfully! Pipeline url: https://app.embedchain.ai/pipelines/{r.json()['id']}\n" # noqa: E501
)
return r.json()
def _create_pipeline(self):
"""
Create a pipeline on the platform.
"""
print("🛠️ Creating pipeline on the platform...")
# self.config_data is a dict. Pass it inside the key 'yaml_config' to the backend
payload = {
"yaml_config": json.dumps(self.config_data),
"name": self.name,
"local_id": self.local_id,
}
url = f"{self.client.host}/api/v1/pipelines/cli/create/"
r = requests.post(
url,
json=payload,
headers={"Authorization": f"Token {self.client.api_key}"},
)
if r.status_code not in [200, 201]:
raise Exception(f"❌ Error occurred while creating pipeline. API response: {r.text}")
if r.status_code == 200:
print(
f"🎉🎉🎉 Existing pipeline found! View your pipeline: https://app.embedchain.ai/pipelines/{r.json()['id']}\n" # noqa: E501
) # noqa: E501
elif r.status_code == 201:
print(
f"🎉🎉🎉 Pipeline created successfully! View your pipeline: https://app.embedchain.ai/pipelines/{r.json()['id']}\n" # noqa: E501
)
return r.json()
def _get_presigned_url(self, data_type, data_value):
payload = {"data_type": data_type, "data_value": data_value}
r = requests.post(
f"{self.client.host}/api/v1/pipelines/{self.id}/cli/presigned_url/",
json=payload,
headers={"Authorization": f"Token {self.client.api_key}"},
)
r.raise_for_status()
return r.json()
def search(self, query, num_documents=3):
"""
Search for similar documents related to the query in the vector database.
"""
# Send anonymous telemetry
self.telemetry.capture(event_name="search", properties=self._telemetry_props)
# TODO: Search will call the endpoint rather than fetching the data from the db itself when deploy=True.
if self.id is None:
where = {"app_id": self.local_id}
context = self.db.query(
query,
n_results=num_documents,
where=where,
skip_embedding=False,
citations=True,
)
result = []
for c in context:
result.append({"context": c[0], "metadata": c[1]})
return result
else:
# Make API call to the backend to get the results
NotImplementedError("Search is not implemented yet for the prod mode.")
def _upload_file_to_presigned_url(self, presigned_url, file_path):
try:
with open(file_path, "rb") as file:
response = requests.put(presigned_url, data=file)
response.raise_for_status()
return response.status_code == 200
except Exception as e:
self.logger.exception(f"Error occurred during file upload: {str(e)}")
print("❌ Error occurred during file upload!")
return False
def _upload_data_to_pipeline(self, data_type, data_value, metadata=None):
payload = {
"data_type": data_type,
"data_value": data_value,
"metadata": metadata,
}
try:
self._send_api_request(f"/api/v1/pipelines/{self.id}/cli/add/", payload)
# print the local file path if user tries to upload a local file
printed_value = metadata.get("file_path") if metadata.get("file_path") else data_value
print(f"✅ Data of type: {data_type}, value: {printed_value} added successfully.")
except Exception as e:
print(f"❌ Error occurred during data upload for type {data_type}!. Error: {str(e)}")
def _send_api_request(self, endpoint, payload):
url = f"{self.client.host}{endpoint}"
headers = {"Authorization": f"Token {self.client.api_key}"}
response = requests.post(url, json=payload, headers=headers)
response.raise_for_status()
return response
def _process_and_upload_data(self, data_hash, data_type, data_value):
if os.path.isabs(data_value):
presigned_url_data = self._get_presigned_url(data_type, data_value)
presigned_url = presigned_url_data["presigned_url"]
s3_key = presigned_url_data["s3_key"]
if self._upload_file_to_presigned_url(presigned_url, file_path=data_value):
metadata = {"file_path": data_value, "s3_key": s3_key}
data_value = presigned_url
else:
self.logger.error(f"File upload failed for hash: {data_hash}")
return False
else:
if data_type == "qna_pair":
data_value = list(ast.literal_eval(data_value))
metadata = {}
try:
self._upload_data_to_pipeline(data_type, data_value, metadata)
self._mark_data_as_uploaded(data_hash)
return True
except Exception:
print(f"❌ Error occurred during data upload for hash {data_hash}!")
return False
def _mark_data_as_uploaded(self, data_hash):
self.cursor.execute(
"UPDATE data_sources SET is_uploaded = 1 WHERE hash = ? AND pipeline_id = ?",
(data_hash, self.local_id),
)
self.connection.commit()
def get_data_sources(self):
db_data = self.cursor.execute("SELECT * FROM data_sources WHERE pipeline_id = ?", (self.local_id,)).fetchall()
data_sources = []
for data in db_data:
data_sources.append({"data_type": data[2], "data_value": data[3], "metadata": data[4]})
return data_sources
def deploy(self):
if self.client is None:
self._init_client()
pipeline_data = self._create_pipeline()
self.id = pipeline_data["id"]
results = self.cursor.execute(
"SELECT * FROM data_sources WHERE pipeline_id = ? AND is_uploaded = 0", (self.local_id,) # noqa:E501
).fetchall()
if len(results) > 0:
print("🛠️ Adding data to your pipeline...")
for result in results:
data_hash, data_type, data_value = result[1], result[2], result[3]
self._process_and_upload_data(data_hash, data_type, data_value)
# Send anonymous telemetry
self.telemetry.capture(event_name="deploy", properties=self._telemetry_props)
@classmethod
def from_config(
cls,
config_path: Optional[str] = None,
config: Optional[Dict[str, Any]] = None,
auto_deploy: bool = False,
yaml_path: Optional[str] = None,
):
"""
Instantiate a Pipeline object from a configuration.
:param config_path: Path to the YAML or JSON configuration file.
:type config_path: Optional[str]
:param config: A dictionary containing the configuration.
:type config: Optional[Dict[str, Any]]
:param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
:type auto_deploy: bool, optional
:param yaml_path: (Deprecated) Path to the YAML configuration file. Use config_path instead.
:type yaml_path: Optional[str]
:return: An instance of the Pipeline class.
:rtype: Pipeline
"""
# Backward compatibility for yaml_path
if yaml_path and not config_path:
config_path = yaml_path
if config_path and config:
raise ValueError("Please provide only one of config_path or config.")
config_data = None
if config_path:
file_extension = os.path.splitext(config_path)[1]
with open(config_path, "r") as file:
if file_extension in [".yaml", ".yml"]:
config_data = yaml.safe_load(file)
elif file_extension == ".json":
config_data = json.load(file)
else:
raise ValueError("config_path must be a path to a YAML or JSON file.")
elif config and isinstance(config, dict):
config_data = config
else:
logging.error(
"Please provide either a config file path (YAML or JSON) or a config dictionary. Falling back to defaults because no config is provided.", # noqa: E501
)
config_data = {}
try:
validate_config(config_data)
except Exception as e:
raise Exception(f"Error occurred while validating the config. Error: {str(e)}")
pipeline_config_data = config_data.get("app", {}).get("config", {})
db_config_data = config_data.get("vectordb", {})
embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
llm_config_data = config_data.get("llm", {})
chunker_config_data = config_data.get("chunker", {})
pipeline_config = PipelineConfig(**pipeline_config_data)
db_provider = db_config_data.get("provider", "chroma")
db = VectorDBFactory.create(db_provider, db_config_data.get("config", {}))
if llm_config_data:
llm_provider = llm_config_data.get("provider", "openai")
llm = LlmFactory.create(llm_provider, llm_config_data.get("config", {}))
else:
llm = None
embedding_model_provider = embedding_model_config_data.get("provider", "openai")
embedding_model = EmbedderFactory.create(
embedding_model_provider, embedding_model_config_data.get("config", {})
)
# Send anonymous telemetry
event_properties = {"init_type": "config_data"}
AnonymousTelemetry().capture(event_name="init", properties=event_properties)
return cls(
config=pipeline_config,
llm=llm,
db=db,
embedding_model=embedding_model,
config_data=config_data,
auto_deploy=auto_deploy,
chunker=chunker_config_data,
)

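Taken together, these methods give the converged class its whole lifecycle: configure, ingest, search locally, deploy. A minimal sketch of how they compose, assuming a valid `config.yaml` and an `OPENAI_API_KEY` (both placeholders):

```python
import os

from embedchain import App

os.environ["OPENAI_API_KEY"] = "sk-xxx"  # illustrative placeholder

# from_config accepts a YAML or JSON path (or a dict via config=)
app = App.from_config(config_path="config.yaml")

app.add("https://www.forbes.com/profile/elon-musk")

# search() runs against the local vector database; routing to the hosted
# endpoint when deployed is still a TODO per the comment in the method
results = app.search("What is the net worth of Elon Musk?", num_documents=3)
for r in results:
    print(r["context"], r["metadata"])

# deploy() creates the pipeline on the platform and uploads any data
# sources that have not been uploaded yet
app.deploy()
```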
View File

@@ -2,7 +2,7 @@ import os
import chainlit as cl
from embedchain import Pipeline as App
from embedchain import App
os.environ["OPENAI_API_KEY"] = "sk-xxx"

View File

@@ -6,7 +6,7 @@ import threading
import streamlit as st
from embedchain import Pipeline as App
from embedchain import App
from embedchain.config import BaseLlmConfig
from embedchain.helpers.callbacks import (StreamingStdOutCallbackHandlerYield,
generate)

View File

@@ -2,7 +2,7 @@ import os
import streamlit as st
from embedchain import Pipeline as App
from embedchain import App
@st.cache_resource

View File

@@ -9,7 +9,7 @@ from services import get_app, get_apps, remove_app, save_app
from sqlalchemy.orm import Session
from utils import generate_error_message_for_api_keys
from embedchain import Pipeline as App
from embedchain import App
from embedchain.client import Client
Base.metadata.create_all(bind=engine)

View File

@@ -6,7 +6,7 @@ from io import StringIO
import requests
import streamlit as st
from embedchain import Pipeline as App
from embedchain import App
from embedchain.config import BaseLlmConfig
from embedchain.helpers.callbacks import (StreamingStdOutCallbackHandlerYield,
generate)

View File

@@ -2,7 +2,7 @@ import queue
import streamlit as st
from embedchain import Pipeline as App
from embedchain import App
from embedchain.config import BaseLlmConfig
from embedchain.helpers.callbacks import (StreamingStdOutCallbackHandlerYield,
generate)
@@ -35,7 +35,7 @@ with st.expander(":grey[Want to create your own Unacademy UPSC AI?]"):
```
```python
from embedchain import Pipeline as App
from embedchain import App
unacademy_ai_app = App()
unacademy_ai_app.add(
"https://unacademy.com/content/upsc/study-material/plan-policy/atma-nirbhar-bharat-3-0/",

View File

@@ -54,7 +54,7 @@
"outputs": [],
"source": [
"import os\n",
"from embedchain import Pipeline as App\n",
"from embedchain import App\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = \"sk-xxx\"\n",
"os.environ[\"ANTHROPIC_API_KEY\"] = \"xxx\""

View File

@@ -44,7 +44,7 @@
"outputs": [],
"source": [
"import os\n",
"from embedchain import Pipeline as App\n",
"from embedchain import App\n",
"\n",
"os.environ[\"OPENAI_API_TYPE\"] = \"azure\"\n",
"os.environ[\"OPENAI_API_BASE\"] = \"https://xxx.openai.azure.com/\"\n",
@@ -143,7 +143,7 @@
"source": [
"while(True):\n",
" question = input(\"Enter question: \")\n",
" if question in ['q', 'exit', 'quit']\n",
" if question in ['q', 'exit', 'quit']:\n",
" break\n",
" answer = app.query(question)\n",
" print(answer)"

View File

@@ -49,7 +49,7 @@
"outputs": [],
"source": [
"import os\n",
"from embedchain import Pipeline as App\n",
"from embedchain import App\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = \"sk-xxx\""
]

View File

@@ -53,7 +53,7 @@
"outputs": [],
"source": [
"import os\n",
"from embedchain import Pipeline as App\n",
"from embedchain import App\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = \"sk-xxx\"\n",
"os.environ[\"COHERE_API_KEY\"] = \"xxx\""

View File

@@ -49,7 +49,7 @@
"outputs": [],
"source": [
"import os\n",
"from embedchain import Pipeline as App\n",
"from embedchain import App\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = \"sk-xxx\""
]

View File

@@ -33,7 +33,7 @@
"outputs": [],
"source": [
"import os\n",
"from embedchain import Pipeline as App\n",
"from embedchain import App\n",
"from embedchain.config import AppConfig\n",
"\n",
"\n",

View File

@@ -7,7 +7,7 @@
"metadata": {},
"outputs": [],
"source": [
"from embedchain import Pipeline as App\n",
"from embedchain import App\n",
"\n",
"embedchain_docs_bot = App()"
]

View File

@@ -52,7 +52,7 @@
},
"outputs": [],
"source": [
"from embedchain import Pipeline as App"
"from embedchain import App"
]
},
{

View File

@@ -54,7 +54,7 @@
"outputs": [],
"source": [
"import os\n",
"from embedchain import Pipeline as App\n",
"from embedchain import App\n",
"\n",
"os.environ[\"HUGGINGFACE_ACCESS_TOKEN\"] = \"hf_xxx\""
]

View File

@@ -54,7 +54,7 @@
"outputs": [],
"source": [
"import os\n",
"from embedchain import Pipeline as App\n",
"from embedchain import App\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = \"sk-xxx\"\n",
"os.environ[\"JINACHAT_API_KEY\"] = \"xxx\""

View File

@@ -53,7 +53,7 @@
"outputs": [],
"source": [
"import os\n",
"from embedchain import Pipeline as App\n",
"from embedchain import App\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = \"sk-xxx\"\n",
"os.environ[\"REPLICATE_API_TOKEN\"] = \"xxx\""

View File

@@ -92,7 +92,7 @@
}
],
"source": [
"from embedchain import Pipeline as App\n",
"from embedchain import App\n",
"app = App.from_config(config_path=\"ollama.yaml\")"
]
},

View File

@@ -54,7 +54,7 @@
"outputs": [],
"source": [
"import os\n",
"from embedchain import Pipeline as App\n",
"from embedchain import App\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = \"sk-xxx\""
]
@@ -80,7 +80,7 @@
"llm:\n",
" provider: openai\n",
" config:\n",
" model: gpt-35-turbo\n",
" model: gpt-3.5-turbo\n",
" temperature: 0.5\n",
" max_tokens: 1000\n",
" top_p: 1\n",

View File

@@ -49,7 +49,7 @@
"outputs": [],
"source": [
"import os\n",
"from embedchain import Pipeline as App\n",
"from embedchain import App\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = \"sk-xxx\""
]

View File

@@ -49,7 +49,7 @@
"outputs": [],
"source": [
"import os\n",
"from embedchain import Pipeline as App\n",
"from embedchain import App\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = \"sk-xxx\"\n",
"os.environ[\"PINECONE_API_KEY\"] = \"xxx\"\n",

View File

@@ -53,7 +53,7 @@
"outputs": [],
"source": [
"import os\n",
"from embedchain import Pipeline as App\n",
"from embedchain import App\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = \"\"\n",
"os.environ[\"TOGETHER_API_KEY\"] = \"\""

View File

@@ -53,7 +53,7 @@
"outputs": [],
"source": [
"import os\n",
"from embedchain import Pipeline as App\n",
"from embedchain import App\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = \"sk-xxx\""
]

View File

@@ -8,6 +8,7 @@ from embedchain.config import AppConfig, ChromaDbConfig
from embedchain.embedchain import EmbedChain
from embedchain.llm.base import BaseLlm
from embedchain.memory.base import ECChatMemory
from embedchain.vectordb.chroma import ChromaDB
os.environ["OPENAI_API_KEY"] = "test-api-key"
@@ -15,7 +16,7 @@ os.environ["OPENAI_API_KEY"] = "test-api-key"
@pytest.fixture
def app_instance():
config = AppConfig(log_level="DEBUG", collect_metrics=False)
return App(config)
return App(config=config)
def test_whole_app(app_instance, mocker):
@@ -44,9 +45,9 @@ def test_add_after_reset(app_instance, mocker):
mocker.patch("embedchain.vectordb.chroma.chromadb.Client")
config = AppConfig(log_level="DEBUG", collect_metrics=False)
chroma_config = {"allow_reset": True}
app_instance = App(config=config, db_config=ChromaDbConfig(**chroma_config))
chroma_config = ChromaDbConfig(allow_reset=True)
db = ChromaDB(config=chroma_config)
app_instance = App(config=config, db=db)
# mock delete chat history
mocker.patch.object(ECChatMemory, "delete_chat_history", autospec=True)

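The fixture change above shows the new component-injection style: the App constructor no longer takes `db_config=`; you build the vector store from its own config and pass the component itself. A minimal sketch using the same Chroma setup as these tests:

```python
from embedchain import App
from embedchain.config import AppConfig, ChromaDbConfig
from embedchain.vectordb.chroma import ChromaDB

# build the component from its own config, then hand the component to App
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False), db=db)
```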
View File

@@ -114,5 +114,7 @@ class TestApp(unittest.TestCase):
self.assertEqual(answer, "Test answer")
_args, kwargs = mock_database_query.call_args
self.assertEqual(kwargs.get("input_query"), "Test query")
self.assertEqual(kwargs.get("where"), {"attribute": "value"})
where = kwargs.get("where")
assert "app_id" in where
assert "attribute" in where
mock_answer.assert_called_once()

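These reworked assertions reflect that, after the refactor, the query path appears to merge the app's own id into the caller-supplied `where` filter instead of passing it through verbatim. A self-contained illustration of what the test now checks (the id value is hypothetical):

```python
user_where = {"attribute": "value"}
app_id = "my-app"  # hypothetical; each App instance carries its own id

# the merged filter the database layer is expected to receive
merged_where = {**user_where, "app_id": app_id}
assert "app_id" in merged_where
assert "attribute" in merged_where
```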
View File

@@ -5,6 +5,7 @@ import pytest
from embedchain import App
from embedchain.config import AppConfig, BaseLlmConfig
from embedchain.llm.openai import OpenAILlm
@pytest.fixture
@@ -37,25 +38,14 @@ def test_query_config_app_passing(mock_get_answer):
config = AppConfig(collect_metrics=False)
chat_config = BaseLlmConfig(system_prompt="Test system prompt")
app = App(config=config, llm_config=chat_config)
llm = OpenAILlm(config=chat_config)
app = App(config=config, llm=llm)
answer = app.llm.get_llm_model_answer("Test query")
assert app.llm.config.system_prompt == "Test system prompt"
assert answer == "Test answer"
@patch("embedchain.llm.openai.OpenAILlm._get_answer")
def test_app_passing(mock_get_answer):
mock_get_answer.return_value = MagicMock()
mock_get_answer.return_value = "Test answer"
config = AppConfig(collect_metrics=False)
chat_config = BaseLlmConfig()
app = App(config=config, llm_config=chat_config, system_prompt="Test system prompt")
answer = app.llm.get_llm_model_answer("Test query")
assert app.llm.config.system_prompt == "Test system prompt"
assert answer == "Test answer"
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query_with_where_in_params(app):
with patch.object(app, "_retrieve_from_database") as mock_retrieve:
@@ -83,5 +73,7 @@ def test_query_with_where_in_query_config(app):
assert answer == "Test answer"
_, kwargs = mock_database_query.call_args
assert kwargs.get("input_query") == "Test query"
assert kwargs.get("where") == {"attribute": "value"}
where = kwargs.get("where")
assert "app_id" in where
assert "attribute" in where
mock_answer.assert_called_once()

View File

@@ -4,11 +4,10 @@ import pytest
import yaml
from embedchain import App
from embedchain.config import (AddConfig, AppConfig, BaseEmbedderConfig,
BaseLlmConfig, ChromaDbConfig)
from embedchain.config import ChromaDbConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.llm.base import BaseLlm
from embedchain.vectordb.base import BaseVectorDB, BaseVectorDbConfig
from embedchain.vectordb.base import BaseVectorDB
from embedchain.vectordb.chroma import ChromaDB
@@ -21,13 +20,14 @@ def app():
def test_app(app):
assert isinstance(app.llm, BaseLlm)
assert isinstance(app.db, BaseVectorDB)
assert isinstance(app.embedder, BaseEmbedder)
assert isinstance(app.embedding_model, BaseEmbedder)
class TestConfigForAppComponents:
def test_constructor_config(self):
collection_name = "my-test-collection"
app = App(db_config=ChromaDbConfig(collection_name=collection_name))
db = ChromaDB(config=ChromaDbConfig(collection_name=collection_name))
app = App(db=db)
assert app.db.config.collection_name == collection_name
def test_component_config(self):
@@ -36,50 +36,6 @@ class TestConfigForAppComponents:
app = App(db=database)
assert app.db.config.collection_name == collection_name
def test_different_configs_are_proper_instances(self):
app_config = AppConfig()
wrong_config = AddConfig()
with pytest.raises(TypeError):
App(config=wrong_config)
assert isinstance(app_config, AppConfig)
llm_config = BaseLlmConfig()
wrong_llm_config = "wrong_llm_config"
with pytest.raises(TypeError):
App(llm_config=wrong_llm_config)
assert isinstance(llm_config, BaseLlmConfig)
db_config = BaseVectorDbConfig()
wrong_db_config = "wrong_db_config"
with pytest.raises(TypeError):
App(db_config=wrong_db_config)
assert isinstance(db_config, BaseVectorDbConfig)
embedder_config = BaseEmbedderConfig()
wrong_embedder_config = "wrong_embedder_config"
with pytest.raises(TypeError):
App(embedder_config=wrong_embedder_config)
assert isinstance(embedder_config, BaseEmbedderConfig)
def test_components_raises_type_error_if_not_proper_instances(self):
wrong_llm = "wrong_llm"
with pytest.raises(TypeError):
App(llm=wrong_llm)
wrong_db = "wrong_db"
with pytest.raises(TypeError):
App(db=wrong_db)
wrong_embedder = "wrong_embedder"
with pytest.raises(TypeError):
App(embedder=wrong_embedder)
class TestAppFromConfig:
def load_config_data(self, yaml_path):
@@ -92,14 +48,13 @@ class TestAppFromConfig:
yaml_path = "configs/chroma.yaml"
config_data = self.load_config_data(yaml_path)
app = App.from_config(yaml_path)
app = App.from_config(config_path=yaml_path)
# Check if the App instance and its components were created correctly
assert isinstance(app, App)
# Validate the AppConfig values
assert app.config.id == config_data["app"]["config"]["id"]
assert app.config.collection_name == config_data["app"]["config"]["collection_name"]
# Even though not present in the config, the default value is used
assert app.config.collect_metrics is True
@@ -118,8 +73,8 @@ class TestAppFromConfig:
# Validate the Embedder config values
embedder_config = config_data["embedder"]["config"]
assert app.embedder.config.model == embedder_config["model"]
assert app.embedder.config.deployment_name == embedder_config.get("deployment_name")
assert app.embedding_model.config.model == embedder_config["model"]
assert app.embedding_model.config.deployment_name == embedder_config.get("deployment_name")
def test_from_opensource_config(self, mocker):
mocker.patch("embedchain.vectordb.chroma.chromadb.Client")
@@ -134,7 +89,6 @@ class TestAppFromConfig:
# Validate the AppConfig values
assert app.config.id == config_data["app"]["config"]["id"]
assert app.config.collection_name == config_data["app"]["config"]["collection_name"]
assert app.config.collect_metrics == config_data["app"]["config"]["collect_metrics"]
# Validate the LLM config values
@@ -153,4 +107,4 @@ class TestAppFromConfig:
# Validate the Embedder config values
embedder_config = config_data["embedder"]["config"]
assert app.embedder.config.deployment_name == embedder_config["deployment_name"]
assert app.embedding_model.config.deployment_name == embedder_config["deployment_name"]

View File

@@ -20,8 +20,9 @@ def chroma_db():
@pytest.fixture
def app_with_settings():
chroma_config = ChromaDbConfig(allow_reset=True, dir="test-db")
chroma_db = ChromaDB(config=chroma_config)
app_config = AppConfig(collect_metrics=False)
return App(config=app_config, db_config=chroma_config)
return App(config=app_config, db=chroma_db)
@pytest.fixture(scope="session", autouse=True)
@@ -65,7 +66,8 @@ def test_app_init_with_host_and_port(mock_client):
port = "1234"
config = AppConfig(collect_metrics=False)
db_config = ChromaDbConfig(host=host, port=port)
_app = App(config, db_config=db_config)
db = ChromaDB(config=db_config)
_app = App(config=config, db=db)
called_settings: Settings = mock_client.call_args[0][0]
assert called_settings.chroma_server_host == host
@@ -74,7 +76,8 @@ def test_app_init_with_host_and_port(mock_client):
@patch("embedchain.vectordb.chroma.chromadb.Client")
def test_app_init_with_host_and_port_none(mock_client):
_app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
_app = App(config=AppConfig(collect_metrics=False), db=db)
called_settings: Settings = mock_client.call_args[0][0]
assert called_settings.chroma_server_host is None
@@ -82,7 +85,8 @@ def test_app_init_with_host_and_port_none(mock_client):
def test_chroma_db_duplicates_throw_warning(caplog):
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False), db=db)
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
assert "Insert of existing embedding ID: 0" in caplog.text
@@ -91,7 +95,8 @@ def test_chroma_db_duplicates_throw_warning(caplog):
def test_chroma_db_duplicates_collections_no_warning(caplog):
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False), db=db)
app.set_collection_name("test_collection_1")
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
app.set_collection_name("test_collection_2")
@@ -104,24 +109,28 @@ def test_chroma_db_duplicates_collections_no_warning(caplog):
def test_chroma_db_collection_init_with_default_collection():
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False), db=db)
assert app.db.collection.name == "embedchain_store"
def test_chroma_db_collection_init_with_custom_collection():
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False), db=db)
app.set_collection_name(name="test_collection")
assert app.db.collection.name == "test_collection"
def test_chroma_db_collection_set_collection_name():
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False), db=db)
app.set_collection_name("test_collection")
assert app.db.collection.name == "test_collection"
def test_chroma_db_collection_changes_encapsulated():
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False), db=db)
app.set_collection_name("test_collection_1")
assert app.db.count() == 0
@@ -207,12 +216,14 @@ def test_chroma_db_collection_add_with_invalid_inputs(app_with_settings):
def test_chroma_db_collection_collections_are_persistent():
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False), db=db)
app.set_collection_name("test_collection_1")
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
del app
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app = App(config=AppConfig(collect_metrics=False), db=db)
app.set_collection_name("test_collection_1")
assert app.db.count() == 1
@@ -220,13 +231,15 @@ def test_chroma_db_collection_collections_are_persistent():
def test_chroma_db_collection_parallel_collections():
db1 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db", collection_name="test_collection_1"))
app1 = App(
AppConfig(collection_name="test_collection_1", collect_metrics=False),
db_config=ChromaDbConfig(allow_reset=True, dir="test-db"),
config=AppConfig(collect_metrics=False),
db=db1,
)
db2 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db", collection_name="test_collection_2"))
app2 = App(
AppConfig(collection_name="test_collection_2", collect_metrics=False),
db_config=ChromaDbConfig(allow_reset=True, dir="test-db"),
config=AppConfig(collect_metrics=False),
db=db2,
)
# cleanup if any previous tests failed or were interrupted
@@ -251,13 +264,11 @@ def test_chroma_db_collection_parallel_collections():
def test_chroma_db_collection_ids_share_collections():
app1 = App(
AppConfig(id="new_app_id_1", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
)
db1 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app1 = App(config=AppConfig(collect_metrics=False), db=db1)
app1.set_collection_name("one_collection")
app2 = App(
AppConfig(id="new_app_id_2", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
)
db2 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app2 = App(config=AppConfig(collect_metrics=False), db=db2)
app2.set_collection_name("one_collection")
app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
@@ -272,21 +283,17 @@ def test_chroma_db_collection_ids_share_collections():
def test_chroma_db_collection_reset():
app1 = App(
AppConfig(id="new_app_id_1", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
)
db1 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app1 = App(config=AppConfig(collect_metrics=False), db=db1)
app1.set_collection_name("one_collection")
app2 = App(
AppConfig(id="new_app_id_2", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
)
db2 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app2 = App(config=AppConfig(collect_metrics=False), db=db2)
app2.set_collection_name("two_collection")
app3 = App(
AppConfig(id="new_app_id_1", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
)
db3 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app3 = App(config=AppConfig(collect_metrics=False), db=db3)
app3.set_collection_name("three_collection")
app4 = App(
AppConfig(id="new_app_id_4", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
)
db4 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
app4 = App(config=AppConfig(collect_metrics=False), db=db4)
app4.set_collection_name("four_collection")
app1.db.collection.add(embeddings=[[0, 0, 0]], ids=["1"])

View File

@@ -13,7 +13,7 @@ class TestEsDB(unittest.TestCase):
def test_setUp(self, mock_client):
self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
self.vector_dim = 384
app_config = AppConfig(collection_name=False, collect_metrics=False)
app_config = AppConfig(collect_metrics=False)
self.app = App(config=app_config, db=self.db)
# Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
@@ -22,8 +22,8 @@ class TestEsDB(unittest.TestCase):
@patch("embedchain.vectordb.elasticsearch.Elasticsearch")
def test_query(self, mock_client):
self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
app_config = AppConfig(collection_name=False, collect_metrics=False)
self.app = App(config=app_config, db=self.db, embedder=GPT4AllEmbedder())
app_config = AppConfig(collect_metrics=False)
self.app = App(config=app_config, db=self.db, embedding_model=GPT4AllEmbedder())
# Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
self.assertEqual(self.db.client, mock_client.return_value)
@@ -74,7 +74,7 @@ class TestEsDB(unittest.TestCase):
@patch("embedchain.vectordb.elasticsearch.Elasticsearch")
def test_query_with_skip_embedding(self, mock_client):
self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
app_config = AppConfig(collection_name=False, collect_metrics=False)
app_config = AppConfig(collect_metrics=False)
self.app = App(config=app_config, db=self.db)
# Assert that the Elasticsearch client is stored in the ElasticsearchDB class.

View File

@@ -29,7 +29,7 @@ class TestPinecone:
# Create a PineconeDB instance
db = PineconeDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder)
App(config=app_config, db=db, embedding_model=embedder)
# Assert that the embedder was set
assert db.embedder == embedder
@@ -48,7 +48,7 @@ class TestPinecone:
# Create a PineconeDb instance
db = PineconeDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=base_embedder)
App(config=app_config, db=db, embedding_model=base_embedder)
# Add some documents to the database
documents = ["This is a document.", "This is another document."]
@@ -76,7 +76,7 @@ class TestPinecone:
# Create a PineconeDB instance
db = PineconeDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=base_embedder)
App(config=app_config, db=db, embedding_model=base_embedder)
# Query the database for documents that are similar to "document"
input_query = ["document"]
@@ -94,7 +94,7 @@ class TestPinecone:
# Create a PineconeDb instance
db = PineconeDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=BaseEmbedder())
App(config=app_config, db=db, embedding_model=BaseEmbedder())
# Reset the database
db.reset()

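As these hunks show, the constructor keyword `embedder=` is renamed to `embedding_model=` across the vector-store tests (matching the earlier `app.embedder` → `app.embedding_model` attribute rename). A minimal sketch of the new spelling, assuming the GPT4All embedder import path used in this codebase:

```python
from embedchain import App
from embedchain.config import AppConfig
from embedchain.embedder.gpt4all import GPT4AllEmbedder  # assumed import path

app = App(config=AppConfig(collect_metrics=False), embedding_model=GPT4AllEmbedder())
```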
View File

@@ -29,7 +29,7 @@ class TestQdrantDB(unittest.TestCase):
# Create a Qdrant instance
db = QdrantDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder)
App(config=app_config, db=db, embedding_model=embedder)
self.assertEqual(db.collection_name, "embedchain-store-1526")
self.assertEqual(db.client, qdrant_client_mock.return_value)
@@ -46,7 +46,7 @@ class TestQdrantDB(unittest.TestCase):
# Create a Qdrant instance
db = QdrantDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder)
App(config=app_config, db=db, embedding_model=embedder)
resp = db.get(ids=[], where={})
self.assertEqual(resp, {"ids": []})
@@ -65,7 +65,7 @@ class TestQdrantDB(unittest.TestCase):
# Create a Qdrant instance
db = QdrantDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder)
App(config=app_config, db=db, embedding_model=embedder)
embeddings = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
documents = ["This is a test document.", "This is another test document."]
@@ -76,7 +76,7 @@ class TestQdrantDB(unittest.TestCase):
qdrant_client_mock.return_value.upsert.assert_called_once_with(
collection_name="embedchain-store-1526",
points=Batch(
ids=["abc", "def"],
ids=["def", "ghi"],
payloads=[
{
"identifier": "123",
@@ -102,7 +102,7 @@ class TestQdrantDB(unittest.TestCase):
# Create a Qdrant instance
db = QdrantDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder)
App(config=app_config, db=db, embedding_model=embedder)
# Query for the document.
db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"}, skip_embedding=True)
@@ -132,7 +132,7 @@ class TestQdrantDB(unittest.TestCase):
# Create a Qdrant instance
db = QdrantDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder)
App(config=app_config, db=db, embedding_model=embedder)
db.count()
qdrant_client_mock.return_value.get_collection.assert_called_once_with(collection_name="embedchain-store-1526")
@@ -146,7 +146,7 @@ class TestQdrantDB(unittest.TestCase):
# Create a Qdrant instance
db = QdrantDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder)
App(config=app_config, db=db, embedding_model=embedder)
db.reset()
qdrant_client_mock.return_value.delete_collection.assert_called_once_with(

View File

@@ -29,7 +29,7 @@ class TestWeaviateDb(unittest.TestCase):
# Create a Weaviate instance
db = WeaviateDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder)
App(config=app_config, db=db, embedding_model=embedder)
expected_class_obj = {
"classes": [
@@ -96,7 +96,7 @@ class TestWeaviateDb(unittest.TestCase):
# Create a Weaviate instance
db = WeaviateDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder)
App(config=app_config, db=db, embedding_model=embedder)
expected_client = db._get_or_create_db()
self.assertEqual(expected_client, weaviate_client_mock)
@@ -115,7 +115,7 @@ class TestWeaviateDb(unittest.TestCase):
# Create a Weaviate instance
db = WeaviateDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder)
App(config=app_config, db=db, embedding_model=embedder)
db.BATCH_SIZE = 1
embeddings = [[1, 2, 3], [4, 5, 6]]
@@ -159,7 +159,7 @@ class TestWeaviateDb(unittest.TestCase):
# Create a Weaviate instance
db = WeaviateDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder)
App(config=app_config, db=db, embedding_model=embedder)
# Query for the document.
db.query(input_query=["This is a test document."], n_results=1, where={}, skip_embedding=True)
@@ -184,7 +184,7 @@ class TestWeaviateDb(unittest.TestCase):
# Create a Weaviate instance
db = WeaviateDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder)
App(config=app_config, db=db, embedding_model=embedder)
# Query for the document.
db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"}, skip_embedding=True)
@@ -210,7 +210,7 @@ class TestWeaviateDb(unittest.TestCase):
# Create a Weaviate instance
db = WeaviateDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder)
App(config=app_config, db=db, embedding_model=embedder)
# Reset the database.
db.reset()
@@ -232,7 +232,7 @@ class TestWeaviateDb(unittest.TestCase):
# Create a Weaviate instance
db = WeaviateDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedder=embedder)
App(config=app_config, db=db, embedding_model=embedder)
# Reset the database.
db.count()