[Refactor] Converge Pipeline and App classes (#1021)
Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
@@ -63,7 +63,7 @@ For example, you can create an Elon Musk bot using the following code:
|
||||
|
||||
```python
|
||||
import os
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
# Create a bot instance
|
||||
os.environ["OPENAI_API_KEY"] = "YOUR API KEY"
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
app:
|
||||
config:
|
||||
id: 'my-app'
|
||||
collection_name: 'my-app'
|
||||
|
||||
llm:
|
||||
provider: openai
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
app:
|
||||
config:
|
||||
id: 'open-source-app'
|
||||
collection_name: 'open-source-app'
|
||||
collect_metrics: false
|
||||
|
||||
llm:
|
||||
|
||||
@@ -21,7 +21,7 @@ title: '📊 add'
|
||||
### Load data from webpage
|
||||
|
||||
```python Code example
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
app = App()
|
||||
app.add("https://www.forbes.com/profile/elon-musk")
|
||||
@@ -32,7 +32,7 @@ app.add("https://www.forbes.com/profile/elon-musk")
|
||||
### Load data from sitemap
|
||||
|
||||
```python Code example
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
app = App()
|
||||
app.add("https://python.langchain.com/sitemap.xml", data_type="sitemap")
|
||||
|
||||
@@ -36,7 +36,7 @@ title: '💬 chat'
|
||||
If you want to get the answer to a question and return both the answer and citations, use the following code snippet:
|
||||
|
||||
```python With Citations
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
# Initialize app
|
||||
app = App()
|
||||
@@ -79,7 +79,7 @@ When `citations=True`, note that the returned `sources` are a list of tuples whe
|
||||
If you just want to return answers and don't want to return citations, you can use the following example:
|
||||
|
||||
```python Without Citations
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
# Initialize app
|
||||
app = App()
|
||||
|
||||
@@ -7,7 +7,7 @@ title: 🗑 delete
|
||||
## Usage
|
||||
|
||||
```python
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
app = App()
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ The `deploy()` method not only deploys your pipeline but also efficiently manage
|
||||
## Usage
|
||||
|
||||
```python
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
# Initialize app
|
||||
app = App()
|
||||
|
||||
@@ -41,7 +41,7 @@ You can create an embedchain pipeline instance using the following methods:
|
||||
### Default setting
|
||||
|
||||
```python Code Example
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
app = App()
|
||||
```
|
||||
|
||||
@@ -49,7 +49,7 @@ app = App()
|
||||
### Python Dict
|
||||
|
||||
```python Code Example
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
config_dict = {
|
||||
'llm': {
|
||||
@@ -76,7 +76,7 @@ app = App.from_config(config=config_dict)
|
||||
<CodeGroup>
|
||||
|
||||
```python main.py
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
# load llm configuration from config.yaml file
|
||||
app = App.from_config(config_path="config.yaml")
|
||||
@@ -103,7 +103,7 @@ embedder:
|
||||
<CodeGroup>
|
||||
|
||||
```python main.py
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
# load llm configuration from config.json file
|
||||
app = App.from_config(config_path="config.json")
|
||||
|
||||
@@ -36,7 +36,7 @@ title: '❓ query'
|
||||
If you want to get the answer to a question and return both the answer and citations, use the following code snippet:
|
||||
|
||||
```python With Citations
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
# Initialize app
|
||||
app = App()
|
||||
@@ -78,7 +78,7 @@ When `citations=True`, note that the returned `sources` are a list of tuples whe
|
||||
If you just want to return answers and don't want to return citations, you can use the following example:
|
||||
|
||||
```python Without Citations
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
# Initialize app
|
||||
app = App()
|
||||
|
||||
@@ -7,7 +7,7 @@ title: 🔄 reset
|
||||
## Usage
|
||||
|
||||
```python
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
app = App()
|
||||
app.add("https://www.forbes.com/profile/elon-musk")
|
||||
|
||||
@@ -24,7 +24,7 @@ title: '🔍 search'
|
||||
Refer to the following example on how to use the search api:
|
||||
|
||||
```python Code example
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
# Initialize app
|
||||
app = App()
|
||||
|
||||
@@ -5,7 +5,7 @@ title: "🐝 Beehiiv"
|
||||
To add any Beehiiv data sources to your app, just add the base url as the source and set the data_type to `beehiiv`.
|
||||
|
||||
```python
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
app = App()
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ title: '📊 CSV'
|
||||
To add any csv file, use the data_type as `csv`. `csv` allows remote urls and conventional file paths. Headers are included for each line, so if you have an `age` column, `18` will be added as `age: 18`. Eg:
|
||||
|
||||
```python
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
app = App()
|
||||
app.add('https://people.sc.fsu.edu/~jburkardt/data/csv/airtravel.csv', data_type="csv")
|
||||
|
||||
@@ -5,7 +5,7 @@ title: '⚙️ Custom'
|
||||
When we say "custom", we mean that you can customize the loader and chunker to your needs. This is done by passing a custom loader and chunker to the `add` method.
|
||||
|
||||
```python
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
import your_loader
|
||||
import your_chunker
|
||||
|
||||
@@ -27,7 +27,7 @@ app.add("source", data_type="custom", loader=loader, chunker=chunker)
|
||||
Example:
|
||||
|
||||
```python
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
from embedchain.loaders.github import GithubLoader
|
||||
|
||||
app = App()
|
||||
|
||||
@@ -35,7 +35,7 @@ Default behavior is to create a persistent vector db in the directory **./db**.
|
||||
Create a local index:
|
||||
|
||||
```python
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
naval_chat_bot = App()
|
||||
naval_chat_bot.add("https://www.youtube.com/watch?v=3qHkcs3kG44")
|
||||
@@ -45,7 +45,7 @@ naval_chat_bot.add("https://navalmanack.s3.amazonaws.com/Eric-Jorgenson_The-Alma
|
||||
You can reuse the local index with the same code, but without adding new documents:
|
||||
|
||||
```python
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
naval_chat_bot = App()
|
||||
print(naval_chat_bot.query("What unique capacity does Naval argue humans possess when it comes to understanding explanations or concepts?"))
|
||||
@@ -56,7 +56,7 @@ print(naval_chat_bot.query("What unique capacity does Naval argue humans possess
|
||||
You can reset the app by simply calling the `reset` method. This will delete the vector database and all other app related files.
|
||||
|
||||
```python
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
app = App()
|
||||
app.add("https://www.youtube.com/watch?v=3qHkcs3kG44")
|
||||
|
||||
@@ -8,7 +8,7 @@ To use an entire directory as data source, just add `data_type` as `directory` a
|
||||
|
||||
```python
|
||||
import os
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "sk-xxx"
|
||||
|
||||
@@ -23,7 +23,7 @@ print(response)
|
||||
|
||||
```python
|
||||
import os
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
from embedchain.loaders.directory_loader import DirectoryLoader
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "sk-xxx"
|
||||
|
||||
@@ -12,7 +12,7 @@ To add any Discord channel messages to your app, just add the `channel_id` as th
|
||||
|
||||
```python
|
||||
import os
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
# add your discord "BOT" token
|
||||
os.environ["DISCORD_TOKEN"] = "xxx"
|
||||
|
||||
@@ -5,7 +5,7 @@ title: '📚 Code documentation'
|
||||
To add any code documentation website as a loader, use the data_type as `docs_site`. Eg:
|
||||
|
||||
```python
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
app = App()
|
||||
app.add("https://docs.embedchain.ai/", data_type="docs_site")
|
||||
|
||||
@@ -7,7 +7,7 @@ title: '📄 Docx file'
|
||||
To add any doc/docx file, use the data_type as `docx`. `docx` allows remote urls and conventional file paths. Eg:
|
||||
|
||||
```python
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
app = App()
|
||||
app.add('https://example.com/content/intro.docx', data_type="docx")
|
||||
|
||||
@@ -24,7 +24,7 @@ To use this you need to save `credentials.json` in the directory from where you
|
||||
12. Put the `.json` file in your current directory and rename it to `credentials.json`
|
||||
|
||||
```python
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
app = App()
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ If you would like to add other data structures (e.g. list, dict etc.), convert i
|
||||
<CodeGroup>
|
||||
|
||||
```python python
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
app = App()
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ title: '📝 Mdx file'
|
||||
To add any `.mdx` file to your app, use the data_type (first argument to `.add()` method) as `mdx`. Note that this only supports mdx files present on the local machine, so this should be a file path. Eg:
|
||||
|
||||
```python
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
app = App()
|
||||
app.add('path/to/file.mdx', data_type='mdx')
|
||||
|
||||
@@ -8,7 +8,7 @@ To load a notion page, use the data_type as `notion`. Since it is hard to automa
|
||||
The next argument must **end** with the `notion page id`. The id is a 32-character string. Eg:
|
||||
|
||||
```python
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
app = App()
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ title: 🙌 OpenAPI
|
||||
To add any OpenAPI spec yaml file (currently the json file will be detected as JSON data type), use the data_type as 'openapi'. 'openapi' allows remote urls and conventional file paths.
|
||||
|
||||
```python
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
app = App()
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ title: '📰 PDF file'
|
||||
To add any pdf file, use the data_type as `pdf_file`. Eg:
|
||||
|
||||
```python
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
app = App()
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ title: '❓💬 Question and answer pair'
|
||||
QnA pair is a local data type. To supply your own QnA pair, use the data_type as `qna_pair` and enter a tuple. Eg:
|
||||
|
||||
```python
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
app = App()
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ title: '🗺️ Sitemap'
|
||||
Add all web pages from an xml-sitemap. Filters non-text files. Use the data_type as `sitemap`. Eg:
|
||||
|
||||
```python
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
app = App()
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ This will automatically retrieve data from the workspace associated with the use
|
||||
|
||||
```python
|
||||
import os
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
os.environ["SLACK_USER_TOKEN"] = "xoxp-xxx"
|
||||
app = App()
|
||||
|
||||
@@ -5,7 +5,7 @@ title: "📝 Substack"
|
||||
To add any Substack data sources to your app, just add the main base url as the source and set the data_type to `substack`.
|
||||
|
||||
```python
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
app = App()
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ title: '📝 Text'
|
||||
Text is a local data type. To supply your own text, use the data_type as `text` and enter a string. The text is not processed, this can be very versatile. Eg:
|
||||
|
||||
```python
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
app = App()
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ title: '🌐 Web page'
|
||||
To add any web page, use the data_type as `web_page`. Eg:
|
||||
|
||||
```python
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
app = App()
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ title: '🧾 XML file'
|
||||
To add any xml file, use the data_type as `xml`. Eg:
|
||||
|
||||
```python
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
app = App()
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ pip install -u "embedchain[youtube]"
|
||||
</Note>
|
||||
|
||||
```python
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
app = App()
|
||||
app.add("@channel_name", data_type="youtube_channel")
|
||||
|
||||
@@ -5,7 +5,7 @@ title: '📺 Youtube'
|
||||
To add any youtube video to your app, use the data_type as `youtube_video`. Eg:
|
||||
|
||||
```python
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
app = App()
|
||||
app.add('a_valid_youtube_url_here', data_type='youtube_video')
|
||||
|
||||
@@ -25,7 +25,7 @@ Once you have obtained the key, you can use it like this:
|
||||
|
||||
```python main.py
|
||||
import os
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
os.environ['OPENAI_API_KEY'] = 'xxx'
|
||||
|
||||
@@ -52,7 +52,7 @@ To use Google AI embedding function, you have to set the `GOOGLE_API_KEY` enviro
|
||||
<CodeGroup>
|
||||
```python main.py
|
||||
import os
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
os.environ["GOOGLE_API_KEY"] = "xxx"
|
||||
|
||||
@@ -81,7 +81,7 @@ To use Azure OpenAI embedding model, you have to set some of the azure openai re
|
||||
|
||||
```python main.py
|
||||
import os
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
os.environ["OPENAI_API_TYPE"] = "azure"
|
||||
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://xxx.openai.azure.com/"
|
||||
@@ -119,7 +119,7 @@ GPT4All supports generating high quality embeddings of arbitrary length document
|
||||
<CodeGroup>
|
||||
|
||||
```python main.py
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
# load embedding model configuration from config.yaml file
|
||||
app = App.from_config(config_path="config.yaml")
|
||||
@@ -148,7 +148,7 @@ Hugging Face supports generating embeddings of arbitrary length documents of tex
|
||||
<CodeGroup>
|
||||
|
||||
```python main.py
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
# load embedding model configuration from config.yaml file
|
||||
app = App.from_config(config_path="config.yaml")
|
||||
@@ -179,7 +179,7 @@ Embedchain supports Google's VertexAI embeddings model through a simple interfac
|
||||
<CodeGroup>
|
||||
|
||||
```python main.py
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
# load embedding model configuration from config.yaml file
|
||||
app = App.from_config(config_path="config.yaml")
|
||||
|
||||
@@ -29,7 +29,7 @@ Once you have obtained the key, you can use it like this:
|
||||
|
||||
```python
|
||||
import os
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
os.environ['OPENAI_API_KEY'] = 'xxx'
|
||||
|
||||
@@ -44,7 +44,7 @@ If you are looking to configure the different parameters of the LLM, you can do
|
||||
|
||||
```python main.py
|
||||
import os
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
os.environ['OPENAI_API_KEY'] = 'xxx'
|
||||
|
||||
@@ -71,7 +71,7 @@ Examples:
|
||||
<Accordion title="Using Pydantic Models">
|
||||
```python
|
||||
import os
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
from embedchain.llm.openai import OpenAILlm
|
||||
import requests
|
||||
from pydantic import BaseModel, Field, ValidationError, field_validator
|
||||
@@ -123,7 +123,7 @@ print(result)
|
||||
<Accordion title="Using OpenAI JSON schema">
|
||||
```python
|
||||
import os
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
from embedchain.llm.openai import OpenAILlm
|
||||
import requests
|
||||
from pydantic import BaseModel, Field, ValidationError, field_validator
|
||||
@@ -158,7 +158,7 @@ print(result)
|
||||
<Accordion title="Using actual python functions">
|
||||
```python
|
||||
import os
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
from embedchain.llm.openai import OpenAILlm
|
||||
import requests
|
||||
from pydantic import BaseModel, Field, ValidationError, field_validator
|
||||
@@ -192,7 +192,7 @@ To use Google AI model, you have to set the `GOOGLE_API_KEY` environment variabl
|
||||
<CodeGroup>
|
||||
```python main.py
|
||||
import os
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
os.environ["GOOGLE_API_KEY"] = "xxx"
|
||||
|
||||
@@ -235,7 +235,7 @@ To use Azure OpenAI model, you have to set some of the azure openai related envi
|
||||
|
||||
```python main.py
|
||||
import os
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
os.environ["OPENAI_API_TYPE"] = "azure"
|
||||
os.environ["OPENAI_API_BASE"] = "https://xxx.openai.azure.com/"
|
||||
@@ -274,7 +274,7 @@ To use anthropic's model, please set the `ANTHROPIC_API_KEY` which you find on t
|
||||
|
||||
```python main.py
|
||||
import os
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
os.environ["ANTHROPIC_API_KEY"] = "xxx"
|
||||
|
||||
@@ -311,7 +311,7 @@ Once you have the API key, you are all set to use it with Embedchain.
|
||||
|
||||
```python main.py
|
||||
import os
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
os.environ["COHERE_API_KEY"] = "xxx"
|
||||
|
||||
@@ -347,7 +347,7 @@ Once you have the API key, you are all set to use it with Embedchain.
|
||||
|
||||
```python main.py
|
||||
import os
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
os.environ["TOGETHER_API_KEY"] = "xxx"
|
||||
|
||||
@@ -375,7 +375,7 @@ Setup Ollama using https://github.com/jmorganca/ollama
|
||||
|
||||
```python main.py
|
||||
import os
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
# load llm configuration from config.yaml file
|
||||
app = App.from_config(config_path="config.yaml")
|
||||
@@ -406,7 +406,7 @@ GPT4all is a free-to-use, locally running, privacy-aware chatbot. No GPU or inte
|
||||
<CodeGroup>
|
||||
|
||||
```python main.py
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
# load llm configuration from config.yaml file
|
||||
app = App.from_config(config_path="config.yaml")
|
||||
@@ -438,7 +438,7 @@ Once you have the key, load the app using the config yaml file:
|
||||
|
||||
```python main.py
|
||||
import os
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
os.environ["JINACHAT_API_KEY"] = "xxx"
|
||||
# load llm configuration from config.yaml file
|
||||
@@ -474,7 +474,7 @@ Once you have the token, load the app using the config yaml file:
|
||||
|
||||
```python main.py
|
||||
import os
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "xxx"
|
||||
|
||||
@@ -504,7 +504,7 @@ Once you have the token, load the app using the config yaml file:
|
||||
|
||||
```python main.py
|
||||
import os
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
os.environ["REPLICATE_API_TOKEN"] = "xxx"
|
||||
|
||||
@@ -531,7 +531,7 @@ Setup Google Cloud Platform application credentials by following the instruction
|
||||
<CodeGroup>
|
||||
|
||||
```python main.py
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
# load llm configuration from config.yaml file
|
||||
app = App.from_config(config_path="config.yaml")
|
||||
|
||||
@@ -22,7 +22,7 @@ Utilizing a vector database alongside Embedchain is a seamless process. All you
|
||||
<CodeGroup>
|
||||
|
||||
```python main.py
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
# load chroma configuration from yaml file
|
||||
app = App.from_config(config_path="config1.yaml")
|
||||
@@ -67,7 +67,7 @@ You can authorize the connection to Elasticsearch by providing either `basic_aut
|
||||
<CodeGroup>
|
||||
|
||||
```python main.py
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
# load elasticsearch configuration from yaml file
|
||||
app = App.from_config(config_path="config.yaml")
|
||||
@@ -97,7 +97,7 @@ pip install --upgrade 'embedchain[opensearch]'
|
||||
<CodeGroup>
|
||||
|
||||
```python main.py
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
# load opensearch configuration from yaml file
|
||||
app = App.from_config(config_path="config.yaml")
|
||||
@@ -133,7 +133,7 @@ Set the Zilliz environment variables `ZILLIZ_CLOUD_URI` and `ZILLIZ_CLOUD_TOKEN`
|
||||
|
||||
```python main.py
|
||||
import os
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
os.environ['ZILLIZ_CLOUD_URI'] = 'https://xxx.zillizcloud.com'
|
||||
os.environ['ZILLIZ_CLOUD_TOKEN'] = 'xxx'
|
||||
@@ -172,7 +172,7 @@ In order to use Pinecone as vector database, set the environment variables `PINE
|
||||
<CodeGroup>
|
||||
|
||||
```python main.py
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
# load pinecone configuration from yaml file
|
||||
app = App.from_config(config_path="config.yaml")
|
||||
@@ -195,7 +195,7 @@ In order to use Qdrant as a vector database, set the environment variables `QDRA
|
||||
|
||||
<CodeGroup>
|
||||
```python main.py
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
# load qdrant configuration from yaml file
|
||||
app = App.from_config(config_path="config.yaml")
|
||||
@@ -215,7 +215,7 @@ In order to use Weaviate as a vector database, set the environment variables `WE
|
||||
|
||||
<CodeGroup>
|
||||
```python main.py
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
# load weaviate configuration from yaml file
|
||||
app = App.from_config(config_path="config.yaml")
|
||||
|
||||
@@ -10,7 +10,7 @@ Embedchain enables developers to deploy their LLM-powered apps in production usi
|
||||
See the example below on how to use the deploy your app (for free):
|
||||
|
||||
```python
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
# Initialize app
|
||||
app = App()
|
||||
|
||||
@@ -11,7 +11,7 @@ Use the model provided on huggingface: `mistralai/Mistral-7B-v0.1`
|
||||
<CodeGroup>
|
||||
```python main.py
|
||||
import os
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "hf_your_token"
|
||||
|
||||
@@ -40,7 +40,7 @@ Use the model `gpt-4-turbo` provided by OpenAI.
|
||||
|
||||
```python main.py
|
||||
import os
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
os.environ['OPENAI_API_KEY'] = 'xxx'
|
||||
|
||||
@@ -65,7 +65,7 @@ llm:
|
||||
|
||||
```python main.py
|
||||
import os
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
os.environ['OPENAI_API_KEY'] = 'xxx'
|
||||
|
||||
@@ -90,7 +90,7 @@ llm:
|
||||
<CodeGroup>
|
||||
|
||||
```python main.py
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
# load llm configuration from opensource.yaml file
|
||||
app = App.from_config(config_path="opensource.yaml")
|
||||
@@ -131,7 +131,7 @@ llm:
|
||||
|
||||
```python main.py
|
||||
import os
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
os.environ['OPENAI_API_KEY'] = 'sk-xxx'
|
||||
|
||||
@@ -149,7 +149,7 @@ response = app.query("What is the net worth of Elon Musk?")
|
||||
Set up the app by adding an `id` in the config file. This keeps the data for future use. You can include this `id` in the yaml config or input it directly in `config` dict.
|
||||
```python app1.py
|
||||
import os
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
os.environ['OPENAI_API_KEY'] = 'sk-xxx'
|
||||
|
||||
@@ -167,7 +167,7 @@ response = app.query("What is the net worth of Elon Musk?")
|
||||
```
|
||||
```python app2.py
|
||||
import os
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
os.environ['OPENAI_API_KEY'] = 'sk-xxx'
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ Creating an app involves 3 steps:
|
||||
<Steps>
|
||||
<Step title="⚙️ Import app instance">
|
||||
```python
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
app = App()
|
||||
```
|
||||
<Accordion title="Customize your app by a simple YAML config" icon="gear-complex">
|
||||
@@ -22,15 +22,15 @@ Creating an app involves 3 steps:
|
||||
Explore the custom configurations [here](https://docs.embedchain.ai/advanced/configuration).
|
||||
<CodeGroup>
|
||||
```python yaml_app.py
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
app = App.from_config(config_path="config.yaml")
|
||||
```
|
||||
```python json_app.py
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
app = App.from_config(config_path="config.json")
|
||||
```
|
||||
```python app.py
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
config = {} # Add your config here
|
||||
app = App.from_config(config=config)
|
||||
```
|
||||
|
||||
@@ -21,7 +21,7 @@ Create a new file called `app.py` and add the following code:
|
||||
|
||||
```python
|
||||
import chainlit as cl
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
import os
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ os.environ['LANGCHAIN_PROJECT] = <your-project>
|
||||
|
||||
|
||||
```python
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
app = App()
|
||||
app.add("https://en.wikipedia.org/wiki/Elon_Musk")
|
||||
|
||||
@@ -17,7 +17,7 @@ pip install embedchain streamlit
|
||||
<Tab title="app.py">
|
||||
```python
|
||||
import os
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
import streamlit as st
|
||||
|
||||
with st.sidebar:
|
||||
|
||||
@@ -24,7 +24,7 @@ Quickly create a RAG pipeline to answer queries about the [Next.JS Framework](ht
|
||||
First, let's create your RAG pipeline. Open your Python environment and enter:
|
||||
|
||||
```python Create pipeline
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
app = App()
|
||||
```
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ Embedchain offers a simple yet customizable `search()` API that you can use for
|
||||
First, let's create your RAG pipeline. Open your Python environment and enter:
|
||||
|
||||
```python Create pipeline
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
app = App()
|
||||
```
|
||||
|
||||
|
||||
@@ -2,10 +2,9 @@ import importlib.metadata
|
||||
|
||||
__version__ = importlib.metadata.version(__package__ or __name__)
|
||||
|
||||
from embedchain.apps.app import App # noqa: F401
|
||||
from embedchain.app import App # noqa: F401
|
||||
from embedchain.client import Client # noqa: F401
|
||||
from embedchain.pipeline import Pipeline # noqa: F401
|
||||
from embedchain.vectordb.chroma import ChromaDB # noqa: F401
|
||||
|
||||
# Setup the user directory if doesn't exist already
|
||||
Client.setup_dir()
|
||||
|
||||
431
embedchain/app.py
Normal file
431
embedchain/app.py
Normal file
@@ -0,0 +1,431 @@
|
||||
import ast
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
import uuid
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import requests
|
||||
import yaml
|
||||
|
||||
from embedchain.client import Client
|
||||
from embedchain.config import AppConfig, ChunkerConfig
|
||||
from embedchain.constants import SQLITE_PATH
|
||||
from embedchain.embedchain import EmbedChain
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.embedder.openai import OpenAIEmbedder
|
||||
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
from embedchain.llm.openai import OpenAILlm
|
||||
from embedchain.telemetry.posthog import AnonymousTelemetry
|
||||
from embedchain.utils import validate_config
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
from embedchain.vectordb.chroma import ChromaDB
|
||||
|
||||
# Setup the user directory if doesn't exist already
|
||||
Client.setup_dir()
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class App(EmbedChain):
|
||||
"""
|
||||
EmbedChain App lets you create a LLM powered app for your unstructured
|
||||
data by defining your chosen data source, embedding model,
|
||||
and vector database.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str = None,
|
||||
name: str = None,
|
||||
config: AppConfig = None,
|
||||
db: BaseVectorDB = None,
|
||||
embedding_model: BaseEmbedder = None,
|
||||
llm: BaseLlm = None,
|
||||
config_data: dict = None,
|
||||
log_level=logging.WARN,
|
||||
auto_deploy: bool = False,
|
||||
chunker: ChunkerConfig = None,
|
||||
):
|
||||
"""
|
||||
Initialize a new `App` instance.
|
||||
|
||||
:param config: Configuration for the pipeline, defaults to None
|
||||
:type config: AppConfig, optional
|
||||
:param db: The database to use for storing and retrieving embeddings, defaults to None
|
||||
:type db: BaseVectorDB, optional
|
||||
:param embedding_model: The embedding model used to calculate embeddings, defaults to None
|
||||
:type embedding_model: BaseEmbedder, optional
|
||||
:param llm: The LLM model used to calculate embeddings, defaults to None
|
||||
:type llm: BaseLlm, optional
|
||||
:param config_data: Config dictionary, defaults to None
|
||||
:type config_data: dict, optional
|
||||
:param log_level: Log level to use, defaults to logging.WARN
|
||||
:type log_level: int, optional
|
||||
:param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
|
||||
:type auto_deploy: bool, optional
|
||||
:raises Exception: If an error occurs while creating the pipeline
|
||||
"""
|
||||
if id and config_data:
|
||||
raise Exception("Cannot provide both id and config. Please provide only one of them.")
|
||||
|
||||
if id and name:
|
||||
raise Exception("Cannot provide both id and name. Please provide only one of them.")
|
||||
|
||||
if name and config:
|
||||
raise Exception("Cannot provide both name and config. Please provide only one of them.")
|
||||
|
||||
logging.basicConfig(level=log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.auto_deploy = auto_deploy
|
||||
# Store the dict config as an attribute to be able to send it
|
||||
self.config_data = config_data if (config_data and validate_config(config_data)) else None
|
||||
self.client = None
|
||||
# pipeline_id from the backend
|
||||
self.id = None
|
||||
self.chunker = None
|
||||
if chunker:
|
||||
self.chunker = ChunkerConfig(**chunker)
|
||||
|
||||
self.config = config or AppConfig()
|
||||
self.name = self.config.name
|
||||
self.config.id = self.local_id = str(uuid.uuid4()) if self.config.id is None else self.config.id
|
||||
|
||||
if id is not None:
|
||||
# Init client first since user is trying to fetch the pipeline
|
||||
# details from the platform
|
||||
self._init_client()
|
||||
pipeline_details = self._get_pipeline(id)
|
||||
self.config.id = self.local_id = pipeline_details["metadata"]["local_id"]
|
||||
self.id = id
|
||||
|
||||
if name is not None:
|
||||
self.name = name
|
||||
|
||||
self.embedding_model = embedding_model or OpenAIEmbedder()
|
||||
self.db = db or ChromaDB()
|
||||
self.llm = llm or OpenAILlm()
|
||||
self._init_db()
|
||||
|
||||
# Send anonymous telemetry
|
||||
self._telemetry_props = {"class": self.__class__.__name__}
|
||||
self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
|
||||
|
||||
# Establish a connection to the SQLite database
|
||||
self.connection = sqlite3.connect(SQLITE_PATH, check_same_thread=False)
|
||||
self.cursor = self.connection.cursor()
|
||||
|
||||
# Create the 'data_sources' table if it doesn't exist
|
||||
self.cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS data_sources (
|
||||
pipeline_id TEXT,
|
||||
hash TEXT,
|
||||
type TEXT,
|
||||
value TEXT,
|
||||
metadata TEXT,
|
||||
is_uploaded INTEGER DEFAULT 0,
|
||||
PRIMARY KEY (pipeline_id, hash)
|
||||
)
|
||||
"""
|
||||
)
|
||||
self.connection.commit()
|
||||
# Send anonymous telemetry
|
||||
self.telemetry.capture(event_name="init", properties=self._telemetry_props)
|
||||
|
||||
self.user_asks = []
|
||||
if self.auto_deploy:
|
||||
self.deploy()
|
||||
|
||||
def _init_db(self):
|
||||
"""
|
||||
Initialize the database.
|
||||
"""
|
||||
self.db._set_embedder(self.embedding_model)
|
||||
self.db._initialize()
|
||||
self.db.set_collection_name(self.db.config.collection_name)
|
||||
|
||||
def _init_client(self):
|
||||
"""
|
||||
Initialize the client.
|
||||
"""
|
||||
config = Client.load_config()
|
||||
if config.get("api_key"):
|
||||
self.client = Client()
|
||||
else:
|
||||
api_key = input(
|
||||
"🔑 Enter your Embedchain API key. You can find the API key at https://app.embedchain.ai/settings/keys/ \n" # noqa: E501
|
||||
)
|
||||
self.client = Client(api_key=api_key)
|
||||
|
||||
def _get_pipeline(self, id):
|
||||
"""
|
||||
Get existing pipeline
|
||||
"""
|
||||
print("🛠️ Fetching pipeline details from the platform...")
|
||||
url = f"{self.client.host}/api/v1/pipelines/{id}/cli/"
|
||||
r = requests.get(
|
||||
url,
|
||||
headers={"Authorization": f"Token {self.client.api_key}"},
|
||||
)
|
||||
if r.status_code == 404:
|
||||
raise Exception(f"❌ Pipeline with id {id} not found!")
|
||||
|
||||
print(
|
||||
f"🎉 Pipeline loaded successfully! Pipeline url: https://app.embedchain.ai/pipelines/{r.json()['id']}\n" # noqa: E501
|
||||
)
|
||||
return r.json()
|
||||
|
||||
def _create_pipeline(self):
|
||||
"""
|
||||
Create a pipeline on the platform.
|
||||
"""
|
||||
print("🛠️ Creating pipeline on the platform...")
|
||||
# self.config_data is a dict. Pass it inside the key 'yaml_config' to the backend
|
||||
payload = {
|
||||
"yaml_config": json.dumps(self.config_data),
|
||||
"name": self.name,
|
||||
"local_id": self.local_id,
|
||||
}
|
||||
url = f"{self.client.host}/api/v1/pipelines/cli/create/"
|
||||
r = requests.post(
|
||||
url,
|
||||
json=payload,
|
||||
headers={"Authorization": f"Token {self.client.api_key}"},
|
||||
)
|
||||
if r.status_code not in [200, 201]:
|
||||
raise Exception(f"❌ Error occurred while creating pipeline. API response: {r.text}")
|
||||
|
||||
if r.status_code == 200:
|
||||
print(
|
||||
f"🎉🎉🎉 Existing pipeline found! View your pipeline: https://app.embedchain.ai/pipelines/{r.json()['id']}\n" # noqa: E501
|
||||
) # noqa: E501
|
||||
elif r.status_code == 201:
|
||||
print(
|
||||
f"🎉🎉🎉 Pipeline created successfully! View your pipeline: https://app.embedchain.ai/pipelines/{r.json()['id']}\n" # noqa: E501
|
||||
)
|
||||
return r.json()
|
||||
|
||||
def _get_presigned_url(self, data_type, data_value):
|
||||
payload = {"data_type": data_type, "data_value": data_value}
|
||||
r = requests.post(
|
||||
f"{self.client.host}/api/v1/pipelines/{self.id}/cli/presigned_url/",
|
||||
json=payload,
|
||||
headers={"Authorization": f"Token {self.client.api_key}"},
|
||||
)
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
|
||||
def search(self, query, num_documents=3):
|
||||
"""
|
||||
Search for similar documents related to the query in the vector database.
|
||||
"""
|
||||
# Send anonymous telemetry
|
||||
self.telemetry.capture(event_name="search", properties=self._telemetry_props)
|
||||
|
||||
# TODO: Search will call the endpoint rather than fetching the data from the db itself when deploy=True.
|
||||
if self.id is None:
|
||||
where = {"app_id": self.local_id}
|
||||
context = self.db.query(
|
||||
query,
|
||||
n_results=num_documents,
|
||||
where=where,
|
||||
skip_embedding=False,
|
||||
citations=True,
|
||||
)
|
||||
result = []
|
||||
for c in context:
|
||||
result.append(
|
||||
{
|
||||
"context": c[0],
|
||||
"source": c[1],
|
||||
"document_id": c[2],
|
||||
}
|
||||
)
|
||||
return result
|
||||
else:
|
||||
# Make API call to the backend to get the results
|
||||
NotImplementedError("Search is not implemented yet for the prod mode.")
|
||||
|
||||
def _upload_file_to_presigned_url(self, presigned_url, file_path):
|
||||
try:
|
||||
with open(file_path, "rb") as file:
|
||||
response = requests.put(presigned_url, data=file)
|
||||
response.raise_for_status()
|
||||
return response.status_code == 200
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Error occurred during file upload: {str(e)}")
|
||||
print("❌ Error occurred during file upload!")
|
||||
return False
|
||||
|
||||
def _upload_data_to_pipeline(self, data_type, data_value, metadata=None):
|
||||
payload = {
|
||||
"data_type": data_type,
|
||||
"data_value": data_value,
|
||||
"metadata": metadata,
|
||||
}
|
||||
try:
|
||||
self._send_api_request(f"/api/v1/pipelines/{self.id}/cli/add/", payload)
|
||||
# print the local file path if user tries to upload a local file
|
||||
printed_value = metadata.get("file_path") if metadata.get("file_path") else data_value
|
||||
print(f"✅ Data of type: {data_type}, value: {printed_value} added successfully.")
|
||||
except Exception as e:
|
||||
print(f"❌ Error occurred during data upload for type {data_type}!. Error: {str(e)}")
|
||||
|
||||
def _send_api_request(self, endpoint, payload):
|
||||
url = f"{self.client.host}{endpoint}"
|
||||
headers = {"Authorization": f"Token {self.client.api_key}"}
|
||||
response = requests.post(url, json=payload, headers=headers)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
def _process_and_upload_data(self, data_hash, data_type, data_value):
|
||||
if os.path.isabs(data_value):
|
||||
presigned_url_data = self._get_presigned_url(data_type, data_value)
|
||||
presigned_url = presigned_url_data["presigned_url"]
|
||||
s3_key = presigned_url_data["s3_key"]
|
||||
if self._upload_file_to_presigned_url(presigned_url, file_path=data_value):
|
||||
metadata = {"file_path": data_value, "s3_key": s3_key}
|
||||
data_value = presigned_url
|
||||
else:
|
||||
self.logger.error(f"File upload failed for hash: {data_hash}")
|
||||
return False
|
||||
else:
|
||||
if data_type == "qna_pair":
|
||||
data_value = list(ast.literal_eval(data_value))
|
||||
metadata = {}
|
||||
|
||||
try:
|
||||
self._upload_data_to_pipeline(data_type, data_value, metadata)
|
||||
self._mark_data_as_uploaded(data_hash)
|
||||
return True
|
||||
except Exception:
|
||||
print(f"❌ Error occurred during data upload for hash {data_hash}!")
|
||||
return False
|
||||
|
||||
def _mark_data_as_uploaded(self, data_hash):
|
||||
self.cursor.execute(
|
||||
"UPDATE data_sources SET is_uploaded = 1 WHERE hash = ? AND pipeline_id = ?",
|
||||
(data_hash, self.local_id),
|
||||
)
|
||||
self.connection.commit()
|
||||
|
||||
def get_data_sources(self):
|
||||
db_data = self.cursor.execute("SELECT * FROM data_sources WHERE pipeline_id = ?", (self.local_id,)).fetchall()
|
||||
|
||||
data_sources = []
|
||||
for data in db_data:
|
||||
data_sources.append({"data_type": data[2], "data_value": data[3], "metadata": data[4]})
|
||||
|
||||
return data_sources
|
||||
|
||||
def deploy(self):
|
||||
if self.client is None:
|
||||
self._init_client()
|
||||
|
||||
pipeline_data = self._create_pipeline()
|
||||
self.id = pipeline_data["id"]
|
||||
|
||||
results = self.cursor.execute(
|
||||
"SELECT * FROM data_sources WHERE pipeline_id = ? AND is_uploaded = 0", (self.local_id,) # noqa:E501
|
||||
).fetchall()
|
||||
|
||||
if len(results) > 0:
|
||||
print("🛠️ Adding data to your pipeline...")
|
||||
for result in results:
|
||||
data_hash, data_type, data_value = result[1], result[2], result[3]
|
||||
self._process_and_upload_data(data_hash, data_type, data_value)
|
||||
|
||||
# Send anonymous telemetry
|
||||
self.telemetry.capture(event_name="deploy", properties=self._telemetry_props)
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config_path: Optional[str] = None,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
auto_deploy: bool = False,
|
||||
yaml_path: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Instantiate a Pipeline object from a configuration.
|
||||
|
||||
:param config_path: Path to the YAML or JSON configuration file.
|
||||
:type config_path: Optional[str]
|
||||
:param config: A dictionary containing the configuration.
|
||||
:type config: Optional[Dict[str, Any]]
|
||||
:param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
|
||||
:type auto_deploy: bool, optional
|
||||
:param yaml_path: (Deprecated) Path to the YAML configuration file. Use config_path instead.
|
||||
:type yaml_path: Optional[str]
|
||||
:return: An instance of the Pipeline class.
|
||||
:rtype: Pipeline
|
||||
"""
|
||||
# Backward compatibility for yaml_path
|
||||
if yaml_path and not config_path:
|
||||
config_path = yaml_path
|
||||
|
||||
if config_path and config:
|
||||
raise ValueError("Please provide only one of config_path or config.")
|
||||
|
||||
config_data = None
|
||||
|
||||
if config_path:
|
||||
file_extension = os.path.splitext(config_path)[1]
|
||||
with open(config_path, "r") as file:
|
||||
if file_extension in [".yaml", ".yml"]:
|
||||
config_data = yaml.safe_load(file)
|
||||
elif file_extension == ".json":
|
||||
config_data = json.load(file)
|
||||
else:
|
||||
raise ValueError("config_path must be a path to a YAML or JSON file.")
|
||||
elif config and isinstance(config, dict):
|
||||
config_data = config
|
||||
else:
|
||||
logging.error(
|
||||
"Please provide either a config file path (YAML or JSON) or a config dictionary. Falling back to defaults because no config is provided.", # noqa: E501
|
||||
)
|
||||
config_data = {}
|
||||
|
||||
try:
|
||||
validate_config(config_data)
|
||||
except Exception as e:
|
||||
raise Exception(f"Error occurred while validating the config. Error: {str(e)}")
|
||||
|
||||
app_config_data = config_data.get("app", {}).get("config", {})
|
||||
db_config_data = config_data.get("vectordb", {})
|
||||
embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
|
||||
llm_config_data = config_data.get("llm", {})
|
||||
chunker_config_data = config_data.get("chunker", {})
|
||||
|
||||
app_config = AppConfig(**app_config_data)
|
||||
|
||||
db_provider = db_config_data.get("provider", "chroma")
|
||||
db = VectorDBFactory.create(db_provider, db_config_data.get("config", {}))
|
||||
|
||||
if llm_config_data:
|
||||
llm_provider = llm_config_data.get("provider", "openai")
|
||||
llm = LlmFactory.create(llm_provider, llm_config_data.get("config", {}))
|
||||
else:
|
||||
llm = None
|
||||
|
||||
embedding_model_provider = embedding_model_config_data.get("provider", "openai")
|
||||
embedding_model = EmbedderFactory.create(
|
||||
embedding_model_provider, embedding_model_config_data.get("config", {})
|
||||
)
|
||||
|
||||
# Send anonymous telemetry
|
||||
event_properties = {"init_type": "config_data"}
|
||||
AnonymousTelemetry().capture(event_name="init", properties=event_properties)
|
||||
|
||||
return cls(
|
||||
config=app_config,
|
||||
llm=llm,
|
||||
db=db,
|
||||
embedding_model=embedding_model,
|
||||
config_data=config_data,
|
||||
auto_deploy=auto_deploy,
|
||||
chunker=chunker_config_data,
|
||||
)
|
||||
@@ -1,157 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import yaml
|
||||
|
||||
from embedchain.config import (AppConfig, BaseEmbedderConfig, BaseLlmConfig,
|
||||
ChunkerConfig)
|
||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||
from embedchain.embedchain import EmbedChain
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.embedder.openai import OpenAIEmbedder
|
||||
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
from embedchain.llm.openai import OpenAILlm
|
||||
from embedchain.utils import validate_config
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
from embedchain.vectordb.chroma import ChromaDB
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class App(EmbedChain):
|
||||
"""
|
||||
The EmbedChain app in it's simplest and most straightforward form.
|
||||
An opinionated choice of LLM, vector database and embedding model.
|
||||
|
||||
Methods:
|
||||
add(source, data_type): adds the data from the given URL to the vector db.
|
||||
query(query): finds answer to the given query using vector database and LLM.
|
||||
chat(query): finds answer to the given query using vector database and LLM, with conversation history.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[AppConfig] = None,
|
||||
llm: BaseLlm = None,
|
||||
llm_config: Optional[BaseLlmConfig] = None,
|
||||
db: BaseVectorDB = None,
|
||||
db_config: Optional[BaseVectorDbConfig] = None,
|
||||
embedder: BaseEmbedder = None,
|
||||
embedder_config: Optional[BaseEmbedderConfig] = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
chunker: Optional[ChunkerConfig] = None,
|
||||
):
|
||||
"""
|
||||
Initialize a new `App` instance.
|
||||
|
||||
:param config: Config for the app instance., defaults to None
|
||||
:type config: Optional[AppConfig], optional
|
||||
:param llm: LLM Class instance. example: `from embedchain.llm.openai import OpenAILlm`, defaults to OpenAiLlm
|
||||
:type llm: BaseLlm, optional
|
||||
:param llm_config: Allows you to configure the LLM, e.g. how many documents to return,
|
||||
example: `from embedchain.config import BaseLlmConfig`, defaults to None
|
||||
:type llm_config: Optional[BaseLlmConfig], optional
|
||||
:param db: The database to use for storing and retrieving embeddings,
|
||||
example: `from embedchain.vectordb.chroma_db import ChromaDb`, defaults to ChromaDb
|
||||
:type db: BaseVectorDB, optional
|
||||
:param db_config: Allows you to configure the vector database,
|
||||
example: `from embedchain.config import ChromaDbConfig`, defaults to None
|
||||
:type db_config: Optional[BaseVectorDbConfig], optional
|
||||
:param embedder: The embedder (embedding model and function) use to calculate embeddings.
|
||||
example: `from embedchain.embedder.gpt4all_embedder import GPT4AllEmbedder`, defaults to OpenAIEmbedder
|
||||
:type embedder: BaseEmbedder, optional
|
||||
:param embedder_config: Allows you to configure the Embedder.
|
||||
example: `from embedchain.config import BaseEmbedderConfig`, defaults to None
|
||||
:type embedder_config: Optional[BaseEmbedderConfig], optional
|
||||
:param system_prompt: System prompt that will be provided to the LLM as such, defaults to None
|
||||
:type system_prompt: Optional[str], optional
|
||||
:raises TypeError: LLM, database or embedder or their config is not a valid class instance.
|
||||
"""
|
||||
# Type check configs
|
||||
if config and not isinstance(config, AppConfig):
|
||||
raise TypeError(
|
||||
"Config is not a `AppConfig` instance. "
|
||||
"Please make sure the type is right and that you are passing an instance."
|
||||
)
|
||||
if llm_config and not isinstance(llm_config, BaseLlmConfig):
|
||||
raise TypeError(
|
||||
"`llm_config` is not a `BaseLlmConfig` instance. "
|
||||
"Please make sure the type is right and that you are passing an instance."
|
||||
)
|
||||
if db_config and not isinstance(db_config, BaseVectorDbConfig):
|
||||
raise TypeError(
|
||||
"`db_config` is not a `BaseVectorDbConfig` instance. "
|
||||
"Please make sure the type is right and that you are passing an instance."
|
||||
)
|
||||
if embedder_config and not isinstance(embedder_config, BaseEmbedderConfig):
|
||||
raise TypeError(
|
||||
"`embedder_config` is not a `BaseEmbedderConfig` instance. "
|
||||
"Please make sure the type is right and that you are passing an instance."
|
||||
)
|
||||
|
||||
# Assign defaults
|
||||
if config is None:
|
||||
config = AppConfig()
|
||||
if llm is None:
|
||||
llm = OpenAILlm(config=llm_config)
|
||||
if db is None:
|
||||
db = ChromaDB(config=db_config)
|
||||
if embedder is None:
|
||||
embedder = OpenAIEmbedder(config=embedder_config)
|
||||
|
||||
self.chunker = None
|
||||
if chunker:
|
||||
self.chunker = ChunkerConfig(**chunker)
|
||||
# Type check assignments
|
||||
if not isinstance(llm, BaseLlm):
|
||||
raise TypeError(
|
||||
"LLM is not a `BaseLlm` instance. "
|
||||
"Please make sure the type is right and that you are passing an instance."
|
||||
)
|
||||
if not isinstance(db, BaseVectorDB):
|
||||
raise TypeError(
|
||||
"Database is not a `BaseVectorDB` instance. "
|
||||
"Please make sure the type is right and that you are passing an instance."
|
||||
)
|
||||
if not isinstance(embedder, BaseEmbedder):
|
||||
raise TypeError(
|
||||
"Embedder is not a `BaseEmbedder` instance. "
|
||||
"Please make sure the type is right and that you are passing an instance."
|
||||
)
|
||||
super().__init__(config, llm=llm, db=db, embedder=embedder, system_prompt=system_prompt)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, yaml_path: str):
|
||||
"""
|
||||
Instantiate an App object from a YAML configuration file.
|
||||
|
||||
:param yaml_path: Path to the YAML configuration file.
|
||||
:type yaml_path: str
|
||||
:return: An instance of the App class.
|
||||
:rtype: App
|
||||
"""
|
||||
with open(yaml_path, "r") as file:
|
||||
config_data = yaml.safe_load(file)
|
||||
|
||||
try:
|
||||
validate_config(config_data)
|
||||
except Exception as e:
|
||||
raise Exception(f"❌ Error occurred while validating the YAML config. Error: {str(e)}")
|
||||
|
||||
app_config_data = config_data.get("app", {})
|
||||
llm_config_data = config_data.get("llm", {})
|
||||
db_config_data = config_data.get("vectordb", {})
|
||||
embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
|
||||
chunker_config_data = config_data.get("chunker", {})
|
||||
|
||||
app_config = AppConfig(**app_config_data.get("config", {}))
|
||||
|
||||
llm_provider = llm_config_data.get("provider", "openai")
|
||||
llm = LlmFactory.create(llm_provider, llm_config_data.get("config", {}))
|
||||
|
||||
db_provider = db_config_data.get("provider", "chroma")
|
||||
db = VectorDBFactory.create(db_provider, db_config_data.get("config", {}))
|
||||
|
||||
embedder_provider = embedding_model_config_data.get("provider", "openai")
|
||||
embedder = EmbedderFactory.create(embedder_provider, embedding_model_config_data.get("config", {}))
|
||||
return cls(config=app_config, llm=llm, db=db, embedder=embedder, chunker=chunker_config_data)
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any
|
||||
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain.config import AddConfig, BaseLlmConfig, PipelineConfig
|
||||
from embedchain import App
|
||||
from embedchain.config import AddConfig, AppConfig, BaseLlmConfig
|
||||
from embedchain.embedder.openai import OpenAIEmbedder
|
||||
from embedchain.helpers.json_serializable import (JSONSerializable,
|
||||
register_deserializable)
|
||||
@@ -12,7 +12,7 @@ from embedchain.vectordb.chroma import ChromaDB
|
||||
@register_deserializable
|
||||
class BaseBot(JSONSerializable):
|
||||
def __init__(self):
|
||||
self.app = App(config=PipelineConfig(), llm=OpenAILlm(), db=ChromaDB(), embedding_model=OpenAIEmbedder())
|
||||
self.app = App(config=AppConfig(), llm=OpenAILlm(), db=ChromaDB(), embedding_model=OpenAIEmbedder())
|
||||
|
||||
def add(self, data: Any, config: AddConfig = None):
|
||||
"""
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
# flake8: noqa: F401
|
||||
|
||||
from .add_config import AddConfig, ChunkerConfig
|
||||
from .apps.app_config import AppConfig
|
||||
from .app_config import AppConfig
|
||||
from .base_config import BaseConfig
|
||||
from .embedder.base import BaseEmbedderConfig
|
||||
from .embedder.base import BaseEmbedderConfig as EmbedderConfig
|
||||
from .llm.base import BaseLlmConfig
|
||||
from .pipeline_config import PipelineConfig
|
||||
from .vectordb.chroma import ChromaDbConfig
|
||||
from .vectordb.elasticsearch import ElasticsearchDBConfig
|
||||
from .vectordb.opensearch import OpenSearchDBConfig
|
||||
|
||||
@@ -15,8 +15,9 @@ class AppConfig(BaseAppConfig):
|
||||
self,
|
||||
log_level: str = "WARNING",
|
||||
id: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
collect_metrics: Optional[bool] = True,
|
||||
collection_name: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initializes a configuration class instance for an App. This is the simplest form of an embedchain app.
|
||||
@@ -28,8 +29,6 @@ class AppConfig(BaseAppConfig):
|
||||
:type id: Optional[str], optional
|
||||
:param collect_metrics: Send anonymous telemetry to improve embedchain, defaults to True
|
||||
:type collect_metrics: Optional[bool], optional
|
||||
:param collection_name: Default collection name. It's recommended to use app.db.set_collection_name() instead,
|
||||
defaults to None
|
||||
:type collection_name: Optional[str], optional
|
||||
"""
|
||||
super().__init__(log_level=log_level, id=id, collect_metrics=collect_metrics, collection_name=collection_name)
|
||||
self.name = name
|
||||
super().__init__(log_level=log_level, id=id, collect_metrics=collect_metrics, **kwargs)
|
||||
@@ -1,38 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
from .apps.base_app_config import BaseAppConfig
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class PipelineConfig(BaseAppConfig):
|
||||
"""
|
||||
Config to initialize an embedchain custom `App` instance, with extra config options.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
log_level: str = "WARNING",
|
||||
id: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
collect_metrics: Optional[bool] = True,
|
||||
):
|
||||
"""
|
||||
Initializes a configuration class instance for an App. This is the simplest form of an embedchain app.
|
||||
Most of the configuration is done in the `App` class itself.
|
||||
|
||||
:param log_level: Debug level ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], defaults to "WARNING"
|
||||
:type log_level: str, optional
|
||||
:param id: ID of the app. Document metadata will have this id., defaults to None
|
||||
:type id: Optional[str], optional
|
||||
:param collect_metrics: Send anonymous telemetry to improve embedchain, defaults to True
|
||||
:type collect_metrics: Optional[bool], optional
|
||||
:param collection_name: Default collection name. It's recommended to use app.db.set_collection_name() instead,
|
||||
defaults to None
|
||||
:type collection_name: Optional[str], optional
|
||||
"""
|
||||
self._setup_logging(log_level)
|
||||
self.id = id
|
||||
self.name = name
|
||||
self.collect_metrics = collect_metrics
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "sk-xxx"
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import streamlit as st
|
||||
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
|
||||
@st.cache_resource
|
||||
|
||||
@@ -9,7 +9,7 @@ from langchain.docstore.document import Document
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
|
||||
from embedchain.config.apps.base_app_config import BaseAppConfig
|
||||
from embedchain.config.base_app_config import BaseAppConfig
|
||||
from embedchain.constants import SQLITE_PATH
|
||||
from embedchain.data_formatter import DataFormatter
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
|
||||
@@ -1,425 +1,9 @@
|
||||
import ast
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
import uuid
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import requests
|
||||
import yaml
|
||||
|
||||
from embedchain import Client
|
||||
from embedchain.config import ChunkerConfig, PipelineConfig
|
||||
from embedchain.constants import SQLITE_PATH
|
||||
from embedchain.embedchain import EmbedChain
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.embedder.openai import OpenAIEmbedder
|
||||
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
from embedchain.llm.openai import OpenAILlm
|
||||
from embedchain.telemetry.posthog import AnonymousTelemetry
|
||||
from embedchain.utils import validate_config
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
from embedchain.vectordb.chroma import ChromaDB
|
||||
|
||||
# Setup the user directory if doesn't exist already
|
||||
Client.setup_dir()
|
||||
from embedchain.app import App
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class Pipeline(EmbedChain):
|
||||
class Pipeline(App):
|
||||
"""
|
||||
EmbedChain pipeline lets you create a LLM powered app for your unstructured
|
||||
data by defining a pipeline with your chosen data source, embedding model,
|
||||
and vector database.
|
||||
This is deprecated. Use `App` instead.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str = None,
|
||||
name: str = None,
|
||||
config: PipelineConfig = None,
|
||||
db: BaseVectorDB = None,
|
||||
embedding_model: BaseEmbedder = None,
|
||||
llm: BaseLlm = None,
|
||||
config_data: dict = None,
|
||||
log_level=logging.WARN,
|
||||
auto_deploy: bool = False,
|
||||
chunker: ChunkerConfig = None,
|
||||
):
|
||||
"""
|
||||
Initialize a new `App` instance.
|
||||
|
||||
:param config: Configuration for the pipeline, defaults to None
|
||||
:type config: PipelineConfig, optional
|
||||
:param db: The database to use for storing and retrieving embeddings, defaults to None
|
||||
:type db: BaseVectorDB, optional
|
||||
:param embedding_model: The embedding model used to calculate embeddings, defaults to None
|
||||
:type embedding_model: BaseEmbedder, optional
|
||||
:param llm: The LLM model used to calculate embeddings, defaults to None
|
||||
:type llm: BaseLlm, optional
|
||||
:param config_data: Config dictionary, defaults to None
|
||||
:type config_data: dict, optional
|
||||
:param log_level: Log level to use, defaults to logging.WARN
|
||||
:type log_level: int, optional
|
||||
:param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
|
||||
:type auto_deploy: bool, optional
|
||||
:raises Exception: If an error occurs while creating the pipeline
|
||||
"""
|
||||
if id and config_data:
|
||||
raise Exception("Cannot provide both id and config. Please provide only one of them.")
|
||||
|
||||
if id and name:
|
||||
raise Exception("Cannot provide both id and name. Please provide only one of them.")
|
||||
|
||||
if name and config:
|
||||
raise Exception("Cannot provide both name and config. Please provide only one of them.")
|
||||
|
||||
logging.basicConfig(level=log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.auto_deploy = auto_deploy
|
||||
# Store the dict config as an attribute to be able to send it
|
||||
self.config_data = config_data if (config_data and validate_config(config_data)) else None
|
||||
self.client = None
|
||||
# pipeline_id from the backend
|
||||
self.id = None
|
||||
self.chunker = None
|
||||
if chunker:
|
||||
self.chunker = ChunkerConfig(**chunker)
|
||||
|
||||
self.config = config or PipelineConfig()
|
||||
self.name = self.config.name
|
||||
self.config.id = self.local_id = str(uuid.uuid4()) if self.config.id is None else self.config.id
|
||||
|
||||
if id is not None:
|
||||
# Init client first since user is trying to fetch the pipeline
|
||||
# details from the platform
|
||||
self._init_client()
|
||||
pipeline_details = self._get_pipeline(id)
|
||||
self.config.id = self.local_id = pipeline_details["metadata"]["local_id"]
|
||||
self.id = id
|
||||
|
||||
if name is not None:
|
||||
self.name = name
|
||||
|
||||
self.embedding_model = embedding_model or OpenAIEmbedder()
|
||||
self.db = db or ChromaDB()
|
||||
self.llm = llm or OpenAILlm()
|
||||
self._init_db()
|
||||
|
||||
# Send anonymous telemetry
|
||||
self._telemetry_props = {"class": self.__class__.__name__}
|
||||
self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
|
||||
|
||||
# Establish a connection to the SQLite database
|
||||
self.connection = sqlite3.connect(SQLITE_PATH, check_same_thread=False)
|
||||
self.cursor = self.connection.cursor()
|
||||
|
||||
# Create the 'data_sources' table if it doesn't exist
|
||||
self.cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS data_sources (
|
||||
pipeline_id TEXT,
|
||||
hash TEXT,
|
||||
type TEXT,
|
||||
value TEXT,
|
||||
metadata TEXT,
|
||||
is_uploaded INTEGER DEFAULT 0,
|
||||
PRIMARY KEY (pipeline_id, hash)
|
||||
)
|
||||
"""
|
||||
)
|
||||
self.connection.commit()
|
||||
# Send anonymous telemetry
|
||||
self.telemetry.capture(event_name="init", properties=self._telemetry_props)
|
||||
|
||||
self.user_asks = []
|
||||
if self.auto_deploy:
|
||||
self.deploy()
|
||||
|
||||
def _init_db(self):
|
||||
"""
|
||||
Initialize the database.
|
||||
"""
|
||||
self.db._set_embedder(self.embedding_model)
|
||||
self.db._initialize()
|
||||
self.db.set_collection_name(self.db.config.collection_name)
|
||||
|
||||
def _init_client(self):
    """Create the platform API client, prompting for a key when none is saved."""
    saved_config = Client.load_config()
    if not saved_config.get("api_key"):
        # No stored key: ask the user interactively; Client persists it.
        entered_key = input(
            "🔑 Enter your Embedchain API key. You can find the API key at https://app.embedchain.ai/settings/keys/ \n"  # noqa: E501
        )
        self.client = Client(api_key=entered_key)
    else:
        self.client = Client()
|
||||
|
||||
def _get_pipeline(self, id):
    """Fetch an existing pipeline's details from the platform by its id.

    Raises Exception when the platform reports the pipeline does not exist.
    """
    print("🛠️ Fetching pipeline details from the platform...")
    auth_header = {"Authorization": f"Token {self.client.api_key}"}
    response = requests.get(
        f"{self.client.host}/api/v1/pipelines/{id}/cli/",
        headers=auth_header,
    )
    if response.status_code == 404:
        raise Exception(f"❌ Pipeline with id {id} not found!")

    print(
        f"🎉 Pipeline loaded successfully! Pipeline url: https://app.embedchain.ai/pipelines/{response.json()['id']}\n"  # noqa: E501
    )
    return response.json()
|
||||
|
||||
def _create_pipeline(self):
    """Register this pipeline with the platform and return its JSON record.

    The backend answers 200 when an equivalent pipeline already exists and
    201 when a new one was created; any other status is treated as an error.
    """
    print("🛠️ Creating pipeline on the platform...")
    # The backend expects the dict config serialized under the 'yaml_config' key.
    request_body = {
        "yaml_config": json.dumps(self.config_data),
        "name": self.name,
        "local_id": self.local_id,
    }
    response = requests.post(
        f"{self.client.host}/api/v1/pipelines/cli/create/",
        json=request_body,
        headers={"Authorization": f"Token {self.client.api_key}"},
    )
    if response.status_code not in [200, 201]:
        raise Exception(f"❌ Error occurred while creating pipeline. API response: {response.text}")

    if response.status_code == 200:
        print(
            f"🎉🎉🎉 Existing pipeline found! View your pipeline: https://app.embedchain.ai/pipelines/{response.json()['id']}\n"  # noqa: E501
        )
    elif response.status_code == 201:
        print(
            f"🎉🎉🎉 Pipeline created successfully! View your pipeline: https://app.embedchain.ai/pipelines/{response.json()['id']}\n"  # noqa: E501
        )
    return response.json()
|
||||
|
||||
def _get_presigned_url(self, data_type, data_value):
    """Ask the backend for a presigned URL to upload a local file to S3."""
    request_body = {"data_type": data_type, "data_value": data_value}
    endpoint = f"{self.client.host}/api/v1/pipelines/{self.id}/cli/presigned_url/"
    response = requests.post(
        endpoint,
        json=request_body,
        headers={"Authorization": f"Token {self.client.api_key}"},
    )
    # Fail loudly on any non-2xx status.
    response.raise_for_status()
    return response.json()
|
||||
|
||||
def search(self, query, num_documents=3):
    """
    Search for similar documents related to the query in the vector database.

    :param query: The query string to search for.
    :param num_documents: Maximum number of documents to return, defaults to 3.
    :return: A list of dicts with 'context' and 'metadata' keys.
    :raises NotImplementedError: When the pipeline is deployed (prod mode),
        since remote search is not implemented yet.
    """
    # Send anonymous telemetry
    self.telemetry.capture(event_name="search", properties=self._telemetry_props)

    # TODO: Search will call the endpoint rather than fetching the data from the db itself when deploy=True.
    if self.id is None:
        where = {"app_id": self.local_id}
        context = self.db.query(
            query,
            n_results=num_documents,
            where=where,
            skip_embedding=False,
            citations=True,
        )
        result = []
        for c in context:
            result.append({"context": c[0], "metadata": c[1]})
        return result
    else:
        # Make API call to the backend to get the results
        # Fix: the exception was previously constructed but never raised,
        # so prod-mode search silently returned None.
        raise NotImplementedError("Search is not implemented yet for the prod mode.")
|
||||
|
||||
def _upload_file_to_presigned_url(self, presigned_url, file_path):
|
||||
try:
|
||||
with open(file_path, "rb") as file:
|
||||
response = requests.put(presigned_url, data=file)
|
||||
response.raise_for_status()
|
||||
return response.status_code == 200
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Error occurred during file upload: {str(e)}")
|
||||
print("❌ Error occurred during file upload!")
|
||||
return False
|
||||
|
||||
def _upload_data_to_pipeline(self, data_type, data_value, metadata=None):
|
||||
payload = {
|
||||
"data_type": data_type,
|
||||
"data_value": data_value,
|
||||
"metadata": metadata,
|
||||
}
|
||||
try:
|
||||
self._send_api_request(f"/api/v1/pipelines/{self.id}/cli/add/", payload)
|
||||
# print the local file path if user tries to upload a local file
|
||||
printed_value = metadata.get("file_path") if metadata.get("file_path") else data_value
|
||||
print(f"✅ Data of type: {data_type}, value: {printed_value} added successfully.")
|
||||
except Exception as e:
|
||||
print(f"❌ Error occurred during data upload for type {data_type}!. Error: {str(e)}")
|
||||
|
||||
def _send_api_request(self, endpoint, payload):
    """POST `payload` to the platform at `endpoint`, raising on HTTP errors."""
    full_url = f"{self.client.host}{endpoint}"
    auth_headers = {"Authorization": f"Token {self.client.api_key}"}
    result = requests.post(full_url, json=payload, headers=auth_headers)
    result.raise_for_status()
    return result
|
||||
|
||||
def _process_and_upload_data(self, data_hash, data_type, data_value):
|
||||
if os.path.isabs(data_value):
|
||||
presigned_url_data = self._get_presigned_url(data_type, data_value)
|
||||
presigned_url = presigned_url_data["presigned_url"]
|
||||
s3_key = presigned_url_data["s3_key"]
|
||||
if self._upload_file_to_presigned_url(presigned_url, file_path=data_value):
|
||||
metadata = {"file_path": data_value, "s3_key": s3_key}
|
||||
data_value = presigned_url
|
||||
else:
|
||||
self.logger.error(f"File upload failed for hash: {data_hash}")
|
||||
return False
|
||||
else:
|
||||
if data_type == "qna_pair":
|
||||
data_value = list(ast.literal_eval(data_value))
|
||||
metadata = {}
|
||||
|
||||
try:
|
||||
self._upload_data_to_pipeline(data_type, data_value, metadata)
|
||||
self._mark_data_as_uploaded(data_hash)
|
||||
return True
|
||||
except Exception:
|
||||
print(f"❌ Error occurred during data upload for hash {data_hash}!")
|
||||
return False
|
||||
|
||||
def _mark_data_as_uploaded(self, data_hash):
|
||||
self.cursor.execute(
|
||||
"UPDATE data_sources SET is_uploaded = 1 WHERE hash = ? AND pipeline_id = ?",
|
||||
(data_hash, self.local_id),
|
||||
)
|
||||
self.connection.commit()
|
||||
|
||||
def get_data_sources(self):
    """Return all data sources recorded locally for this pipeline.

    :return: A list of dicts with 'data_type', 'data_value' and 'metadata' keys.
    """
    rows = self.cursor.execute("SELECT * FROM data_sources WHERE pipeline_id = ?", (self.local_id,)).fetchall()
    # Row layout: pipeline_id, hash, type, value, metadata, is_uploaded.
    return [{"data_type": row[2], "data_value": row[3], "metadata": row[4]} for row in rows]
|
||||
|
||||
def deploy(self):
    """Deploy the pipeline to the platform and sync pending local data.

    Lazily initializes the API client, creates (or finds) the pipeline on
    the backend, then uploads every locally-recorded data source that has
    not been uploaded yet.
    """
    if self.client is None:
        self._init_client()

    self.id = self._create_pipeline()["id"]

    pending_rows = self.cursor.execute(
        "SELECT * FROM data_sources WHERE pipeline_id = ? AND is_uploaded = 0", (self.local_id,)  # noqa:E501
    ).fetchall()

    if len(pending_rows) > 0:
        print("🛠️ Adding data to your pipeline...")
        for row in pending_rows:
            # Row layout: pipeline_id, hash, type, value, ...
            self._process_and_upload_data(row[1], row[2], row[3])

    # Send anonymous telemetry
    self.telemetry.capture(event_name="deploy", properties=self._telemetry_props)
|
||||
|
||||
@classmethod
def from_config(
    cls,
    config_path: Optional[str] = None,
    config: Optional[Dict[str, Any]] = None,
    auto_deploy: bool = False,
    yaml_path: Optional[str] = None,
):
    """
    Instantiate a Pipeline object from a configuration.

    :param config_path: Path to the YAML or JSON configuration file.
    :type config_path: Optional[str]
    :param config: A dictionary containing the configuration.
    :type config: Optional[Dict[str, Any]]
    :param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
    :type auto_deploy: bool, optional
    :param yaml_path: (Deprecated) Path to the YAML configuration file. Use config_path instead.
    :type yaml_path: Optional[str]
    :return: An instance of the Pipeline class.
    :rtype: Pipeline
    """
    # Backward compatibility for the deprecated yaml_path argument.
    if yaml_path and not config_path:
        config_path = yaml_path

    if config_path and config:
        raise ValueError("Please provide only one of config_path or config.")

    config_data = None

    if config_path:
        extension = os.path.splitext(config_path)[1]
        with open(config_path, "r") as file:
            if extension == ".json":
                config_data = json.load(file)
            elif extension in [".yaml", ".yml"]:
                config_data = yaml.safe_load(file)
            else:
                raise ValueError("config_path must be a path to a YAML or JSON file.")
    elif config and isinstance(config, dict):
        config_data = config
    else:
        logging.error(
            "Please provide either a config file path (YAML or JSON) or a config dictionary. Falling back to defaults because no config is provided.",  # noqa: E501
        )
        config_data = {}

    try:
        validate_config(config_data)
    except Exception as e:
        raise Exception(f"Error occurred while validating the config. Error: {str(e)}")

    app_section = config_data.get("app", {}).get("config", {})
    db_section = config_data.get("vectordb", {})
    # 'embedder' is the legacy key; 'embedding_model' takes precedence.
    embedder_section = config_data.get("embedding_model", config_data.get("embedder", {}))
    llm_section = config_data.get("llm", {})
    chunker_section = config_data.get("chunker", {})

    pipeline_config = PipelineConfig(**app_section)

    db = VectorDBFactory.create(db_section.get("provider", "chroma"), db_section.get("config", {}))

    llm = None
    if llm_section:
        llm = LlmFactory.create(llm_section.get("provider", "openai"), llm_section.get("config", {}))

    embedding_model = EmbedderFactory.create(
        embedder_section.get("provider", "openai"), embedder_section.get("config", {})
    )

    # Send anonymous telemetry
    AnonymousTelemetry().capture(event_name="init", properties={"init_type": "config_data"})

    return cls(
        config=pipeline_config,
        llm=llm,
        db=db,
        embedding_model=embedding_model,
        config_data=config_data,
        auto_deploy=auto_deploy,
        chunker=chunker_section,
    )
|
||||
pass
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
|
||||
import chainlit as cl
|
||||
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "sk-xxx"
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import threading
|
||||
|
||||
import streamlit as st
|
||||
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helpers.callbacks import (StreamingStdOutCallbackHandlerYield,
|
||||
generate)
|
||||
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
|
||||
import streamlit as st
|
||||
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
|
||||
|
||||
@st.cache_resource
|
||||
|
||||
@@ -9,7 +9,7 @@ from services import get_app, get_apps, remove_app, save_app
|
||||
from sqlalchemy.orm import Session
|
||||
from utils import generate_error_message_for_api_keys
|
||||
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
from embedchain.client import Client
|
||||
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
@@ -6,7 +6,7 @@ from io import StringIO
|
||||
import requests
|
||||
import streamlit as st
|
||||
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helpers.callbacks import (StreamingStdOutCallbackHandlerYield,
|
||||
generate)
|
||||
|
||||
@@ -2,7 +2,7 @@ import queue
|
||||
|
||||
import streamlit as st
|
||||
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helpers.callbacks import (StreamingStdOutCallbackHandlerYield,
|
||||
generate)
|
||||
@@ -35,7 +35,7 @@ with st.expander(":grey[Want to create your own Unacademy UPSC AI?]"):
|
||||
```
|
||||
|
||||
```python
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain import App
|
||||
unacademy_ai_app = App()
|
||||
unacademy_ai_app.add(
|
||||
"https://unacademy.com/content/upsc/study-material/plan-policy/atma-nirbhar-bharat-3-0/",
|
||||
|
||||
@@ -54,7 +54,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from embedchain import Pipeline as App\n",
|
||||
"from embedchain import App\n",
|
||||
"\n",
|
||||
"os.environ[\"OPENAI_API_KEY\"] = \"sk-xxx\"\n",
|
||||
"os.environ[\"ANTHROPIC_API_KEY\"] = \"xxx\""
|
||||
|
||||
@@ -44,7 +44,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from embedchain import Pipeline as App\n",
|
||||
"from embedchain import App\n",
|
||||
"\n",
|
||||
"os.environ[\"OPENAI_API_TYPE\"] = \"azure\"\n",
|
||||
"os.environ[\"OPENAI_API_BASE\"] = \"https://xxx.openai.azure.com/\"\n",
|
||||
@@ -143,7 +143,7 @@
|
||||
"source": [
|
||||
"while(True):\n",
|
||||
" question = input(\"Enter question: \")\n",
|
||||
" if question in ['q', 'exit', 'quit']\n",
|
||||
" if question in ['q', 'exit', 'quit']:\n",
|
||||
" break\n",
|
||||
" answer = app.query(question)\n",
|
||||
" print(answer)"
|
||||
|
||||
@@ -49,7 +49,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from embedchain import Pipeline as App\n",
|
||||
"from embedchain import App\n",
|
||||
"\n",
|
||||
"os.environ[\"OPENAI_API_KEY\"] = \"sk-xxx\""
|
||||
]
|
||||
|
||||
@@ -53,7 +53,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from embedchain import Pipeline as App\n",
|
||||
"from embedchain import App\n",
|
||||
"\n",
|
||||
"os.environ[\"OPENAI_API_KEY\"] = \"sk-xxx\"\n",
|
||||
"os.environ[\"COHERE_API_KEY\"] = \"xxx\""
|
||||
|
||||
@@ -49,7 +49,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from embedchain import Pipeline as App\n",
|
||||
"from embedchain import App\n",
|
||||
"\n",
|
||||
"os.environ[\"OPENAI_API_KEY\"] = \"sk-xxx\""
|
||||
]
|
||||
|
||||
@@ -33,7 +33,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from embedchain import Pipeline as App\n",
|
||||
"from embedchain import App\n",
|
||||
"from embedchain.config import AppConfig\n",
|
||||
"\n",
|
||||
"\n",
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from embedchain import Pipeline as App\n",
|
||||
"from embedchain import App\n",
|
||||
"\n",
|
||||
"embedchain_docs_bot = App()"
|
||||
]
|
||||
|
||||
@@ -52,7 +52,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from embedchain import Pipeline as App"
|
||||
"from embedchain import App"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -54,7 +54,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from embedchain import Pipeline as App\n",
|
||||
"from embedchain import App\n",
|
||||
"\n",
|
||||
"os.environ[\"HUGGINGFACE_ACCESS_TOKEN\"] = \"hf_xxx\""
|
||||
]
|
||||
|
||||
@@ -54,7 +54,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from embedchain import Pipeline as App\n",
|
||||
"from embedchain import App\n",
|
||||
"\n",
|
||||
"os.environ[\"OPENAI_API_KEY\"] = \"sk-xxx\"\n",
|
||||
"os.environ[\"JINACHAT_API_KEY\"] = \"xxx\""
|
||||
|
||||
@@ -53,7 +53,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from embedchain import Pipeline as App\n",
|
||||
"from embedchain import App\n",
|
||||
"\n",
|
||||
"os.environ[\"OPENAI_API_KEY\"] = \"sk-xxx\"\n",
|
||||
"os.environ[\"REPLICATE_API_TOKEN\"] = \"xxx\""
|
||||
|
||||
@@ -92,7 +92,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from embedchain import Pipeline as App\n",
|
||||
"from embedchain import App\n",
|
||||
"app = App.from_config(config_path=\"ollama.yaml\")"
|
||||
]
|
||||
},
|
||||
|
||||
@@ -54,7 +54,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from embedchain import Pipeline as App\n",
|
||||
"from embedchain import App\n",
|
||||
"\n",
|
||||
"os.environ[\"OPENAI_API_KEY\"] = \"sk-xxx\""
|
||||
]
|
||||
@@ -80,7 +80,7 @@
|
||||
"llm:\n",
|
||||
" provider: openai\n",
|
||||
" config:\n",
|
||||
" model: gpt-35-turbo\n",
|
||||
" model: gpt-3.5-turbo\n",
|
||||
" temperature: 0.5\n",
|
||||
" max_tokens: 1000\n",
|
||||
" top_p: 1\n",
|
||||
|
||||
@@ -49,7 +49,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from embedchain import Pipeline as App\n",
|
||||
"from embedchain import App\n",
|
||||
"\n",
|
||||
"os.environ[\"OPENAI_API_KEY\"] = \"sk-xxx\""
|
||||
]
|
||||
|
||||
@@ -49,7 +49,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from embedchain import Pipeline as App\n",
|
||||
"from embedchain import App\n",
|
||||
"\n",
|
||||
"os.environ[\"OPENAI_API_KEY\"] = \"sk-xxx\"\n",
|
||||
"os.environ[\"PINECONE_API_KEY\"] = \"xxx\"\n",
|
||||
|
||||
@@ -53,7 +53,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from embedchain import Pipeline as App\n",
|
||||
"from embedchain import App\n",
|
||||
"\n",
|
||||
"os.environ[\"OPENAI_API_KEY\"] = \"\"\n",
|
||||
"os.environ[\"TOGETHER_API_KEY\"] = \"\""
|
||||
|
||||
@@ -53,7 +53,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from embedchain import Pipeline as App\n",
|
||||
"from embedchain import App\n",
|
||||
"\n",
|
||||
"os.environ[\"OPENAI_API_KEY\"] = \"sk-xxx\""
|
||||
]
|
||||
|
||||
@@ -8,6 +8,7 @@ from embedchain.config import AppConfig, ChromaDbConfig
|
||||
from embedchain.embedchain import EmbedChain
|
||||
from embedchain.llm.base import BaseLlm
|
||||
from embedchain.memory.base import ECChatMemory
|
||||
from embedchain.vectordb.chroma import ChromaDB
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "test-api-key"
|
||||
|
||||
@@ -15,7 +16,7 @@ os.environ["OPENAI_API_KEY"] = "test-api-key"
|
||||
@pytest.fixture
|
||||
def app_instance():
|
||||
config = AppConfig(log_level="DEBUG", collect_metrics=False)
|
||||
return App(config)
|
||||
return App(config=config)
|
||||
|
||||
|
||||
def test_whole_app(app_instance, mocker):
|
||||
@@ -44,9 +45,9 @@ def test_add_after_reset(app_instance, mocker):
|
||||
mocker.patch("embedchain.vectordb.chroma.chromadb.Client")
|
||||
|
||||
config = AppConfig(log_level="DEBUG", collect_metrics=False)
|
||||
chroma_config = {"allow_reset": True}
|
||||
|
||||
app_instance = App(config=config, db_config=ChromaDbConfig(**chroma_config))
|
||||
chroma_config = ChromaDbConfig(allow_reset=True)
|
||||
db = ChromaDB(config=chroma_config)
|
||||
app_instance = App(config=config, db=db)
|
||||
|
||||
# mock delete chat history
|
||||
mocker.patch.object(ECChatMemory, "delete_chat_history", autospec=True)
|
||||
|
||||
@@ -114,5 +114,7 @@ class TestApp(unittest.TestCase):
|
||||
self.assertEqual(answer, "Test answer")
|
||||
_args, kwargs = mock_database_query.call_args
|
||||
self.assertEqual(kwargs.get("input_query"), "Test query")
|
||||
self.assertEqual(kwargs.get("where"), {"attribute": "value"})
|
||||
where = kwargs.get("where")
|
||||
assert "app_id" in where
|
||||
assert "attribute" in where
|
||||
mock_answer.assert_called_once()
|
||||
|
||||
@@ -5,6 +5,7 @@ import pytest
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig, BaseLlmConfig
|
||||
from embedchain.llm.openai import OpenAILlm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -37,25 +38,14 @@ def test_query_config_app_passing(mock_get_answer):
|
||||
|
||||
config = AppConfig(collect_metrics=False)
|
||||
chat_config = BaseLlmConfig(system_prompt="Test system prompt")
|
||||
app = App(config=config, llm_config=chat_config)
|
||||
llm = OpenAILlm(config=chat_config)
|
||||
app = App(config=config, llm=llm)
|
||||
answer = app.llm.get_llm_model_answer("Test query")
|
||||
|
||||
assert app.llm.config.system_prompt == "Test system prompt"
|
||||
assert answer == "Test answer"
|
||||
|
||||
|
||||
@patch("embedchain.llm.openai.OpenAILlm._get_answer")
|
||||
def test_app_passing(mock_get_answer):
|
||||
mock_get_answer.return_value = MagicMock()
|
||||
mock_get_answer.return_value = "Test answer"
|
||||
config = AppConfig(collect_metrics=False)
|
||||
chat_config = BaseLlmConfig()
|
||||
app = App(config=config, llm_config=chat_config, system_prompt="Test system prompt")
|
||||
answer = app.llm.get_llm_model_answer("Test query")
|
||||
assert app.llm.config.system_prompt == "Test system prompt"
|
||||
assert answer == "Test answer"
|
||||
|
||||
|
||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||
def test_query_with_where_in_params(app):
|
||||
with patch.object(app, "_retrieve_from_database") as mock_retrieve:
|
||||
@@ -83,5 +73,7 @@ def test_query_with_where_in_query_config(app):
|
||||
assert answer == "Test answer"
|
||||
_, kwargs = mock_database_query.call_args
|
||||
assert kwargs.get("input_query") == "Test query"
|
||||
assert kwargs.get("where") == {"attribute": "value"}
|
||||
where = kwargs.get("where")
|
||||
assert "app_id" in where
|
||||
assert "attribute" in where
|
||||
mock_answer.assert_called_once()
|
||||
|
||||
@@ -4,11 +4,10 @@ import pytest
|
||||
import yaml
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import (AddConfig, AppConfig, BaseEmbedderConfig,
|
||||
BaseLlmConfig, ChromaDbConfig)
|
||||
from embedchain.config import ChromaDbConfig
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.llm.base import BaseLlm
|
||||
from embedchain.vectordb.base import BaseVectorDB, BaseVectorDbConfig
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
from embedchain.vectordb.chroma import ChromaDB
|
||||
|
||||
|
||||
@@ -21,13 +20,14 @@ def app():
|
||||
def test_app(app):
|
||||
assert isinstance(app.llm, BaseLlm)
|
||||
assert isinstance(app.db, BaseVectorDB)
|
||||
assert isinstance(app.embedder, BaseEmbedder)
|
||||
assert isinstance(app.embedding_model, BaseEmbedder)
|
||||
|
||||
|
||||
class TestConfigForAppComponents:
|
||||
def test_constructor_config(self):
|
||||
collection_name = "my-test-collection"
|
||||
app = App(db_config=ChromaDbConfig(collection_name=collection_name))
|
||||
db = ChromaDB(config=ChromaDbConfig(collection_name=collection_name))
|
||||
app = App(db=db)
|
||||
assert app.db.config.collection_name == collection_name
|
||||
|
||||
def test_component_config(self):
|
||||
@@ -36,50 +36,6 @@ class TestConfigForAppComponents:
|
||||
app = App(db=database)
|
||||
assert app.db.config.collection_name == collection_name
|
||||
|
||||
def test_different_configs_are_proper_instances(self):
|
||||
app_config = AppConfig()
|
||||
wrong_config = AddConfig()
|
||||
with pytest.raises(TypeError):
|
||||
App(config=wrong_config)
|
||||
|
||||
assert isinstance(app_config, AppConfig)
|
||||
|
||||
llm_config = BaseLlmConfig()
|
||||
wrong_llm_config = "wrong_llm_config"
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
App(llm_config=wrong_llm_config)
|
||||
|
||||
assert isinstance(llm_config, BaseLlmConfig)
|
||||
|
||||
db_config = BaseVectorDbConfig()
|
||||
wrong_db_config = "wrong_db_config"
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
App(db_config=wrong_db_config)
|
||||
|
||||
assert isinstance(db_config, BaseVectorDbConfig)
|
||||
|
||||
embedder_config = BaseEmbedderConfig()
|
||||
wrong_embedder_config = "wrong_embedder_config"
|
||||
with pytest.raises(TypeError):
|
||||
App(embedder_config=wrong_embedder_config)
|
||||
|
||||
assert isinstance(embedder_config, BaseEmbedderConfig)
|
||||
|
||||
def test_components_raises_type_error_if_not_proper_instances(self):
|
||||
wrong_llm = "wrong_llm"
|
||||
with pytest.raises(TypeError):
|
||||
App(llm=wrong_llm)
|
||||
|
||||
wrong_db = "wrong_db"
|
||||
with pytest.raises(TypeError):
|
||||
App(db=wrong_db)
|
||||
|
||||
wrong_embedder = "wrong_embedder"
|
||||
with pytest.raises(TypeError):
|
||||
App(embedder=wrong_embedder)
|
||||
|
||||
|
||||
class TestAppFromConfig:
|
||||
def load_config_data(self, yaml_path):
|
||||
@@ -92,14 +48,13 @@ class TestAppFromConfig:
|
||||
yaml_path = "configs/chroma.yaml"
|
||||
config_data = self.load_config_data(yaml_path)
|
||||
|
||||
app = App.from_config(yaml_path)
|
||||
app = App.from_config(config_path=yaml_path)
|
||||
|
||||
# Check if the App instance and its components were created correctly
|
||||
assert isinstance(app, App)
|
||||
|
||||
# Validate the AppConfig values
|
||||
assert app.config.id == config_data["app"]["config"]["id"]
|
||||
assert app.config.collection_name == config_data["app"]["config"]["collection_name"]
|
||||
# Even though not present in the config, the default value is used
|
||||
assert app.config.collect_metrics is True
|
||||
|
||||
@@ -118,8 +73,8 @@ class TestAppFromConfig:
|
||||
|
||||
# Validate the Embedder config values
|
||||
embedder_config = config_data["embedder"]["config"]
|
||||
assert app.embedder.config.model == embedder_config["model"]
|
||||
assert app.embedder.config.deployment_name == embedder_config.get("deployment_name")
|
||||
assert app.embedding_model.config.model == embedder_config["model"]
|
||||
assert app.embedding_model.config.deployment_name == embedder_config.get("deployment_name")
|
||||
|
||||
def test_from_opensource_config(self, mocker):
|
||||
mocker.patch("embedchain.vectordb.chroma.chromadb.Client")
|
||||
@@ -134,7 +89,6 @@ class TestAppFromConfig:
|
||||
|
||||
# Validate the AppConfig values
|
||||
assert app.config.id == config_data["app"]["config"]["id"]
|
||||
assert app.config.collection_name == config_data["app"]["config"]["collection_name"]
|
||||
assert app.config.collect_metrics == config_data["app"]["config"]["collect_metrics"]
|
||||
|
||||
# Validate the LLM config values
|
||||
@@ -153,4 +107,4 @@ class TestAppFromConfig:
|
||||
|
||||
# Validate the Embedder config values
|
||||
embedder_config = config_data["embedder"]["config"]
|
||||
assert app.embedder.config.deployment_name == embedder_config["deployment_name"]
|
||||
assert app.embedding_model.config.deployment_name == embedder_config["deployment_name"]
|
||||
@@ -20,8 +20,9 @@ def chroma_db():
|
||||
@pytest.fixture
|
||||
def app_with_settings():
|
||||
chroma_config = ChromaDbConfig(allow_reset=True, dir="test-db")
|
||||
chroma_db = ChromaDB(config=chroma_config)
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
return App(config=app_config, db_config=chroma_config)
|
||||
return App(config=app_config, db=chroma_db)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
@@ -65,7 +66,8 @@ def test_app_init_with_host_and_port(mock_client):
|
||||
port = "1234"
|
||||
config = AppConfig(collect_metrics=False)
|
||||
db_config = ChromaDbConfig(host=host, port=port)
|
||||
_app = App(config, db_config=db_config)
|
||||
db = ChromaDB(config=db_config)
|
||||
_app = App(config=config, db=db)
|
||||
|
||||
called_settings: Settings = mock_client.call_args[0][0]
|
||||
assert called_settings.chroma_server_host == host
|
||||
@@ -74,7 +76,8 @@ def test_app_init_with_host_and_port(mock_client):
|
||||
|
||||
@patch("embedchain.vectordb.chroma.chromadb.Client")
|
||||
def test_app_init_with_host_and_port_none(mock_client):
|
||||
_app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
_app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
|
||||
called_settings: Settings = mock_client.call_args[0][0]
|
||||
assert called_settings.chroma_server_host is None
|
||||
@@ -82,7 +85,8 @@ def test_app_init_with_host_and_port_none(mock_client):
|
||||
|
||||
|
||||
def test_chroma_db_duplicates_throw_warning(caplog):
|
||||
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||
assert "Insert of existing embedding ID: 0" in caplog.text
|
||||
@@ -91,7 +95,8 @@ def test_chroma_db_duplicates_throw_warning(caplog):
|
||||
|
||||
|
||||
def test_chroma_db_duplicates_collections_no_warning(caplog):
|
||||
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
app.set_collection_name("test_collection_1")
|
||||
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||
app.set_collection_name("test_collection_2")
|
||||
@@ -104,24 +109,28 @@ def test_chroma_db_duplicates_collections_no_warning(caplog):
|
||||
|
||||
|
||||
def test_chroma_db_collection_init_with_default_collection():
|
||||
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
assert app.db.collection.name == "embedchain_store"
|
||||
|
||||
|
||||
def test_chroma_db_collection_init_with_custom_collection():
|
||||
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
app.set_collection_name(name="test_collection")
|
||||
assert app.db.collection.name == "test_collection"
|
||||
|
||||
|
||||
def test_chroma_db_collection_set_collection_name():
|
||||
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
app.set_collection_name("test_collection")
|
||||
assert app.db.collection.name == "test_collection"
|
||||
|
||||
|
||||
def test_chroma_db_collection_changes_encapsulated():
|
||||
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
app.set_collection_name("test_collection_1")
|
||||
assert app.db.count() == 0
|
||||
|
||||
@@ -207,12 +216,14 @@ def test_chroma_db_collection_add_with_invalid_inputs(app_with_settings):
|
||||
|
||||
|
||||
def test_chroma_db_collection_collections_are_persistent():
|
||||
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
app.set_collection_name("test_collection_1")
|
||||
app.db.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
|
||||
del app
|
||||
|
||||
app = App(config=AppConfig(collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
db = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app = App(config=AppConfig(collect_metrics=False), db=db)
|
||||
app.set_collection_name("test_collection_1")
|
||||
assert app.db.count() == 1
|
||||
|
||||
@@ -220,13 +231,15 @@ def test_chroma_db_collection_collections_are_persistent():
|
||||
|
||||
|
||||
def test_chroma_db_collection_parallel_collections():
|
||||
db1 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db", collection_name="test_collection_1"))
|
||||
app1 = App(
|
||||
AppConfig(collection_name="test_collection_1", collect_metrics=False),
|
||||
db_config=ChromaDbConfig(allow_reset=True, dir="test-db"),
|
||||
config=AppConfig(collect_metrics=False),
|
||||
db=db1,
|
||||
)
|
||||
db2 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db", collection_name="test_collection_2"))
|
||||
app2 = App(
|
||||
AppConfig(collection_name="test_collection_2", collect_metrics=False),
|
||||
db_config=ChromaDbConfig(allow_reset=True, dir="test-db"),
|
||||
config=AppConfig(collect_metrics=False),
|
||||
db=db2,
|
||||
)
|
||||
|
||||
# cleanup if any previous tests failed or were interrupted
|
||||
@@ -251,13 +264,11 @@ def test_chroma_db_collection_parallel_collections():
|
||||
|
||||
|
||||
def test_chroma_db_collection_ids_share_collections():
|
||||
app1 = App(
|
||||
AppConfig(id="new_app_id_1", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
|
||||
)
|
||||
db1 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app1 = App(config=AppConfig(collect_metrics=False), db=db1)
|
||||
app1.set_collection_name("one_collection")
|
||||
app2 = App(
|
||||
AppConfig(id="new_app_id_2", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
|
||||
)
|
||||
db2 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app2 = App(config=AppConfig(collect_metrics=False), db=db2)
|
||||
app2.set_collection_name("one_collection")
|
||||
|
||||
app1.db.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
|
||||
@@ -272,21 +283,17 @@ def test_chroma_db_collection_ids_share_collections():
|
||||
|
||||
|
||||
def test_chroma_db_collection_reset():
|
||||
app1 = App(
|
||||
AppConfig(id="new_app_id_1", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
|
||||
)
|
||||
db1 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app1 = App(config=AppConfig(collect_metrics=False), db=db1)
|
||||
app1.set_collection_name("one_collection")
|
||||
app2 = App(
|
||||
AppConfig(id="new_app_id_2", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
|
||||
)
|
||||
db2 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app2 = App(config=AppConfig(collect_metrics=False), db=db2)
|
||||
app2.set_collection_name("two_collection")
|
||||
app3 = App(
|
||||
AppConfig(id="new_app_id_1", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
|
||||
)
|
||||
db3 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app3 = App(config=AppConfig(collect_metrics=False), db=db3)
|
||||
app3.set_collection_name("three_collection")
|
||||
app4 = App(
|
||||
AppConfig(id="new_app_id_4", collect_metrics=False), db_config=ChromaDbConfig(allow_reset=True, dir="test-db")
|
||||
)
|
||||
db4 = ChromaDB(config=ChromaDbConfig(allow_reset=True, dir="test-db"))
|
||||
app4 = App(config=AppConfig(collect_metrics=False), db=db4)
|
||||
app4.set_collection_name("four_collection")
|
||||
|
||||
app1.db.collection.add(embeddings=[0, 0, 0], ids=["1"])
|
||||
|
||||
@@ -13,7 +13,7 @@ class TestEsDB(unittest.TestCase):
|
||||
def test_setUp(self, mock_client):
|
||||
self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
|
||||
self.vector_dim = 384
|
||||
app_config = AppConfig(collection_name=False, collect_metrics=False)
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
self.app = App(config=app_config, db=self.db)
|
||||
|
||||
# Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
|
||||
@@ -22,8 +22,8 @@ class TestEsDB(unittest.TestCase):
|
||||
@patch("embedchain.vectordb.elasticsearch.Elasticsearch")
|
||||
def test_query(self, mock_client):
|
||||
self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
|
||||
app_config = AppConfig(collection_name=False, collect_metrics=False)
|
||||
self.app = App(config=app_config, db=self.db, embedder=GPT4AllEmbedder())
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
self.app = App(config=app_config, db=self.db, embedding_model=GPT4AllEmbedder())
|
||||
|
||||
# Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
|
||||
self.assertEqual(self.db.client, mock_client.return_value)
|
||||
@@ -74,7 +74,7 @@ class TestEsDB(unittest.TestCase):
|
||||
@patch("embedchain.vectordb.elasticsearch.Elasticsearch")
|
||||
def test_query_with_skip_embedding(self, mock_client):
|
||||
self.db = ElasticsearchDB(config=ElasticsearchDBConfig(es_url="https://localhost:9200"))
|
||||
app_config = AppConfig(collection_name=False, collect_metrics=False)
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
self.app = App(config=app_config, db=self.db)
|
||||
|
||||
# Assert that the Elasticsearch client is stored in the ElasticsearchDB class.
|
||||
|
||||
@@ -29,7 +29,7 @@ class TestPinecone:
|
||||
# Create a PineconeDB instance
|
||||
db = PineconeDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
# Assert that the embedder was set
|
||||
assert db.embedder == embedder
|
||||
@@ -48,7 +48,7 @@ class TestPinecone:
|
||||
# Create a PineconeDb instance
|
||||
db = PineconeDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=base_embedder)
|
||||
App(config=app_config, db=db, embedding_model=base_embedder)
|
||||
|
||||
# Add some documents to the database
|
||||
documents = ["This is a document.", "This is another document."]
|
||||
@@ -76,7 +76,7 @@ class TestPinecone:
|
||||
# Create a PineconeDB instance
|
||||
db = PineconeDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=base_embedder)
|
||||
App(config=app_config, db=db, embedding_model=base_embedder)
|
||||
|
||||
# Query the database for documents that are similar to "document"
|
||||
input_query = ["document"]
|
||||
@@ -94,7 +94,7 @@ class TestPinecone:
|
||||
# Create a PineconeDb instance
|
||||
db = PineconeDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=BaseEmbedder())
|
||||
App(config=app_config, db=db, embedding_model=BaseEmbedder())
|
||||
|
||||
# Reset the database
|
||||
db.reset()
|
||||
|
||||
@@ -29,7 +29,7 @@ class TestQdrantDB(unittest.TestCase):
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
self.assertEqual(db.collection_name, "embedchain-store-1526")
|
||||
self.assertEqual(db.client, qdrant_client_mock.return_value)
|
||||
@@ -46,7 +46,7 @@ class TestQdrantDB(unittest.TestCase):
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
resp = db.get(ids=[], where={})
|
||||
self.assertEqual(resp, {"ids": []})
|
||||
@@ -65,7 +65,7 @@ class TestQdrantDB(unittest.TestCase):
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
embeddings = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
|
||||
documents = ["This is a test document.", "This is another test document."]
|
||||
@@ -76,7 +76,7 @@ class TestQdrantDB(unittest.TestCase):
|
||||
qdrant_client_mock.return_value.upsert.assert_called_once_with(
|
||||
collection_name="embedchain-store-1526",
|
||||
points=Batch(
|
||||
ids=["abc", "def"],
|
||||
ids=["def", "ghi"],
|
||||
payloads=[
|
||||
{
|
||||
"identifier": "123",
|
||||
@@ -102,7 +102,7 @@ class TestQdrantDB(unittest.TestCase):
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
# Query for the document.
|
||||
db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"}, skip_embedding=True)
|
||||
@@ -132,7 +132,7 @@ class TestQdrantDB(unittest.TestCase):
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
db.count()
|
||||
qdrant_client_mock.return_value.get_collection.assert_called_once_with(collection_name="embedchain-store-1526")
|
||||
@@ -146,7 +146,7 @@ class TestQdrantDB(unittest.TestCase):
|
||||
# Create a Qdrant instance
|
||||
db = QdrantDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
db.reset()
|
||||
qdrant_client_mock.return_value.delete_collection.assert_called_once_with(
|
||||
|
||||
@@ -29,7 +29,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
# Create a Weaviate instance
|
||||
db = WeaviateDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
expected_class_obj = {
|
||||
"classes": [
|
||||
@@ -96,7 +96,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
# Create a Weaviate instance
|
||||
db = WeaviateDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
expected_client = db._get_or_create_db()
|
||||
self.assertEqual(expected_client, weaviate_client_mock)
|
||||
@@ -115,7 +115,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
# Create a Weaviate instance
|
||||
db = WeaviateDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
db.BATCH_SIZE = 1
|
||||
|
||||
embeddings = [[1, 2, 3], [4, 5, 6]]
|
||||
@@ -159,7 +159,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
# Create a Weaviate instance
|
||||
db = WeaviateDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
# Query for the document.
|
||||
db.query(input_query=["This is a test document."], n_results=1, where={}, skip_embedding=True)
|
||||
@@ -184,7 +184,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
# Create a Weaviate instance
|
||||
db = WeaviateDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
# Query for the document.
|
||||
db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"}, skip_embedding=True)
|
||||
@@ -210,7 +210,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
# Create a Weaviate instance
|
||||
db = WeaviateDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
# Reset the database.
|
||||
db.reset()
|
||||
@@ -232,7 +232,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
# Create a Weaviate instance
|
||||
db = WeaviateDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
# Reset the database.
|
||||
db.count()
|
||||
|
||||
Reference in New Issue
Block a user