[Improvement] Set a default app id if not provided in the app configuration (#1300)

This commit is contained in:
Deshraj Yadav
2024-03-02 15:10:34 -08:00
committed by GitHub
parent 8d7e8b6fb9
commit faacfeb891
4 changed files with 42 additions and 41 deletions

View File

@@ -31,41 +31,47 @@ This section gives a quickstart example of using Mistral as the Open source LLM
We are using Mistral hosted at Hugging Face, so will you need a Hugging Face token to run this example. Its *free* and you can create one [here](https://huggingface.co/docs/hub/security-tokens). We are using Mistral hosted at Hugging Face, so will you need a Hugging Face token to run this example. Its *free* and you can create one [here](https://huggingface.co/docs/hub/security-tokens).
<CodeGroup> <CodeGroup>
```python quickstart.py ```python huggingface_demo.py
import os import os
# replace this with your HF key # Replace this with your HF token
os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "hf_xxxx" os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "hf_xxxx"
from embedchain import App from embedchain import App
app = App.from_config("mistral.yaml")
config = {
'llm': {
'provider': 'huggingface',
'config': {
'model': 'mistralai/Mistral-7B-Instruct-v0.2',
'top_p': 0.5
}
},
'embedder': {
'provider': 'huggingface',
'config': {
'model': 'sentence-transformers/all-mpnet-base-v2'
}
}
}
app = App.from_config(config=config)
app.add("https://www.forbes.com/profile/elon-musk") app.add("https://www.forbes.com/profile/elon-musk")
app.add("https://en.wikipedia.org/wiki/Elon_Musk") app.add("https://en.wikipedia.org/wiki/Elon_Musk")
app.query("What is the net worth of Elon Musk today?") app.query("What is the net worth of Elon Musk today?")
# Answer: The net worth of Elon Musk today is $258.7 billion. # Answer: The net worth of Elon Musk today is $258.7 billion.
``` ```
```yaml mistral.yaml
llm:
provider: huggingface
config:
model: 'mistralai/Mistral-7B-Instruct-v0.2'
top_p: 0.5
embedder:
provider: huggingface
config:
model: 'sentence-transformers/all-mpnet-base-v2'
```
</CodeGroup> </CodeGroup>
## Paid Models ## Paid Models
In this section, we will use both LLM and embedding model from OpenAI. In this section, we will use both LLM and embedding model from OpenAI.
```python quickstart.py ```python openai_demo.py
import os import os
# replace this with your OpenAI key from embedchain import App
# Replace this with your OpenAI key
os.environ["OPENAI_API_KEY"] = "sk-xxxx" os.environ["OPENAI_API_KEY"] = "sk-xxxx"
from embedchain import App
app = App() app = App()
app.add("https://www.forbes.com/profile/elon-musk") app.add("https://www.forbes.com/profile/elon-musk")
app.add("https://en.wikipedia.org/wiki/Elon_Musk") app.add("https://en.wikipedia.org/wiki/Elon_Musk")

View File

@@ -3,21 +3,15 @@ import concurrent.futures
import json import json
import logging import logging
import os import os
import uuid
from typing import Any, Optional, Union from typing import Any, Optional, Union
import requests import requests
import yaml import yaml
from tqdm import tqdm from tqdm import tqdm
from embedchain.cache import ( from embedchain.cache import (Config, ExactMatchEvaluation,
Config, SearchDistanceEvaluation, cache,
ExactMatchEvaluation, gptcache_data_manager, gptcache_pre_function)
SearchDistanceEvaluation,
cache,
gptcache_data_manager,
gptcache_pre_function,
)
from embedchain.client import Client from embedchain.client import Client
from embedchain.config import AppConfig, CacheConfig, ChunkerConfig from embedchain.config import AppConfig, CacheConfig, ChunkerConfig
from embedchain.core.db.database import get_session, init_db, setup_engine from embedchain.core.db.database import get_session, init_db, setup_engine
@@ -26,7 +20,8 @@ from embedchain.embedchain import EmbedChain
from embedchain.embedder.base import BaseEmbedder from embedchain.embedder.base import BaseEmbedder
from embedchain.embedder.openai import OpenAIEmbedder from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.evaluation.base import BaseMetric from embedchain.evaluation.base import BaseMetric
from embedchain.evaluation.metrics import AnswerRelevance, ContextRelevance, Groundedness from embedchain.evaluation.metrics import (AnswerRelevance, ContextRelevance,
Groundedness)
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
from embedchain.helpers.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm from embedchain.llm.base import BaseLlm
@@ -106,7 +101,7 @@ class App(EmbedChain):
self.config = config or AppConfig() self.config = config or AppConfig()
self.name = self.config.name self.name = self.config.name
self.config.id = self.local_id = str(uuid.uuid4()) if self.config.id is None else self.config.id self.config.id = self.local_id = "default-app-id" if self.config.id is None else self.config.id
if id is not None: if id is not None:
# Init client first since user is trying to fetch the pipeline # Init client first since user is trying to fetch the pipeline

View File

@@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "embedchain" name = "embedchain"
version = "0.1.91" version = "0.1.92"
description = "Simplest open source retrieval(RAG) framework" description = "Simplest open source retrieval(RAG) framework"
authors = [ authors = [
"Taranjeet Singh <taranjeet@embedchain.ai>", "Taranjeet Singh <taranjeet@embedchain.ai>",

View File

@@ -29,7 +29,7 @@ class TestQdrantDB(unittest.TestCase):
def test_initialize(self, qdrant_client_mock): def test_initialize(self, qdrant_client_mock):
# Set the embedder # Set the embedder
embedder = BaseEmbedder() embedder = BaseEmbedder()
embedder.set_vector_dimension(1526) embedder.set_vector_dimension(1536)
embedder.set_embedding_fn(mock_embedding_fn) embedder.set_embedding_fn(mock_embedding_fn)
# Create a Qdrant instance # Create a Qdrant instance
@@ -37,7 +37,7 @@ class TestQdrantDB(unittest.TestCase):
app_config = AppConfig(collect_metrics=False) app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedding_model=embedder) App(config=app_config, db=db, embedding_model=embedder)
self.assertEqual(db.collection_name, "embedchain-store-1526") self.assertEqual(db.collection_name, "embedchain-store-1536")
self.assertEqual(db.client, qdrant_client_mock.return_value) self.assertEqual(db.client, qdrant_client_mock.return_value)
qdrant_client_mock.return_value.get_collections.assert_called_once() qdrant_client_mock.return_value.get_collections.assert_called_once()
@@ -47,7 +47,7 @@ class TestQdrantDB(unittest.TestCase):
# Set the embedder # Set the embedder
embedder = BaseEmbedder() embedder = BaseEmbedder()
embedder.set_vector_dimension(1526) embedder.set_vector_dimension(1536)
embedder.set_embedding_fn(mock_embedding_fn) embedder.set_embedding_fn(mock_embedding_fn)
# Create a Qdrant instance # Create a Qdrant instance
@@ -67,7 +67,7 @@ class TestQdrantDB(unittest.TestCase):
# Set the embedder # Set the embedder
embedder = BaseEmbedder() embedder = BaseEmbedder()
embedder.set_vector_dimension(1526) embedder.set_vector_dimension(1536)
embedder.set_embedding_fn(mock_embedding_fn) embedder.set_embedding_fn(mock_embedding_fn)
# Create a Qdrant instance # Create a Qdrant instance
@@ -80,9 +80,9 @@ class TestQdrantDB(unittest.TestCase):
ids = ["123", "456"] ids = ["123", "456"]
db.add(documents, metadatas, ids) db.add(documents, metadatas, ids)
qdrant_client_mock.return_value.upsert.assert_called_once_with( qdrant_client_mock.return_value.upsert.assert_called_once_with(
collection_name="embedchain-store-1526", collection_name="embedchain-store-1536",
points=Batch( points=Batch(
ids=["def", "ghi"], ids=["abc", "def"],
payloads=[ payloads=[
{ {
"identifier": "123", "identifier": "123",
@@ -103,7 +103,7 @@ class TestQdrantDB(unittest.TestCase):
def test_query(self, qdrant_client_mock): def test_query(self, qdrant_client_mock):
# Set the embedder # Set the embedder
embedder = BaseEmbedder() embedder = BaseEmbedder()
embedder.set_vector_dimension(1526) embedder.set_vector_dimension(1536)
embedder.set_embedding_fn(mock_embedding_fn) embedder.set_embedding_fn(mock_embedding_fn)
# Create a Qdrant instance # Create a Qdrant instance
@@ -115,7 +115,7 @@ class TestQdrantDB(unittest.TestCase):
db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"}) db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"})
qdrant_client_mock.return_value.search.assert_called_once_with( qdrant_client_mock.return_value.search.assert_called_once_with(
collection_name="embedchain-store-1526", collection_name="embedchain-store-1536",
query_filter=models.Filter( query_filter=models.Filter(
must=[ must=[
models.FieldCondition( models.FieldCondition(
@@ -134,7 +134,7 @@ class TestQdrantDB(unittest.TestCase):
def test_count(self, qdrant_client_mock): def test_count(self, qdrant_client_mock):
# Set the embedder # Set the embedder
embedder = BaseEmbedder() embedder = BaseEmbedder()
embedder.set_vector_dimension(1526) embedder.set_vector_dimension(1536)
embedder.set_embedding_fn(mock_embedding_fn) embedder.set_embedding_fn(mock_embedding_fn)
# Create a Qdrant instance # Create a Qdrant instance
@@ -143,13 +143,13 @@ class TestQdrantDB(unittest.TestCase):
App(config=app_config, db=db, embedding_model=embedder) App(config=app_config, db=db, embedding_model=embedder)
db.count() db.count()
qdrant_client_mock.return_value.get_collection.assert_called_once_with(collection_name="embedchain-store-1526") qdrant_client_mock.return_value.get_collection.assert_called_once_with(collection_name="embedchain-store-1536")
@patch("embedchain.vectordb.qdrant.QdrantClient") @patch("embedchain.vectordb.qdrant.QdrantClient")
def test_reset(self, qdrant_client_mock): def test_reset(self, qdrant_client_mock):
# Set the embedder # Set the embedder
embedder = BaseEmbedder() embedder = BaseEmbedder()
embedder.set_vector_dimension(1526) embedder.set_vector_dimension(1536)
embedder.set_embedding_fn(mock_embedding_fn) embedder.set_embedding_fn(mock_embedding_fn)
# Create a Qdrant instance # Create a Qdrant instance
@@ -159,7 +159,7 @@ class TestQdrantDB(unittest.TestCase):
db.reset() db.reset()
qdrant_client_mock.return_value.delete_collection.assert_called_once_with( qdrant_client_mock.return_value.delete_collection.assert_called_once_with(
collection_name="embedchain-store-1526" collection_name="embedchain-store-1536"
) )