Code Formatting (#1828)
This commit is contained in:
@@ -9,7 +9,6 @@ import requests
|
||||
import yaml
|
||||
from tqdm import tqdm
|
||||
|
||||
from mem0 import Memory
|
||||
from embedchain.cache import (
|
||||
Config,
|
||||
ExactMatchEvaluation,
|
||||
@@ -26,7 +25,11 @@ from embedchain.embedchain import EmbedChain
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.embedder.openai import OpenAIEmbedder
|
||||
from embedchain.evaluation.base import BaseMetric
|
||||
from embedchain.evaluation.metrics import AnswerRelevance, ContextRelevance, Groundedness
|
||||
from embedchain.evaluation.metrics import (
|
||||
AnswerRelevance,
|
||||
ContextRelevance,
|
||||
Groundedness,
|
||||
)
|
||||
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
@@ -36,6 +39,7 @@ from embedchain.utils.evaluation import EvalData, EvalMetric
|
||||
from embedchain.utils.misc import validate_config
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
from embedchain.vectordb.chroma import ChromaDB
|
||||
from mem0 import Memory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -3,8 +3,10 @@ from typing import Any
|
||||
from embedchain import App
|
||||
from embedchain.config import AddConfig, AppConfig, BaseLlmConfig
|
||||
from embedchain.embedder.openai import OpenAIEmbedder
|
||||
from embedchain.helpers.json_serializable import (JSONSerializable,
|
||||
register_deserializable)
|
||||
from embedchain.helpers.json_serializable import (
|
||||
JSONSerializable,
|
||||
register_deserializable,
|
||||
)
|
||||
from embedchain.llm.openai import OpenAILlm
|
||||
from embedchain.vectordb.chroma import ChromaDB
|
||||
|
||||
|
||||
@@ -9,10 +9,12 @@ from gptcache.manager import get_data_manager
|
||||
from gptcache.manager.scalar_data.base import Answer
|
||||
from gptcache.manager.scalar_data.base import DataType as CacheDataType
|
||||
from gptcache.session import Session
|
||||
from gptcache.similarity_evaluation.distance import \
|
||||
SearchDistanceEvaluation # noqa: F401
|
||||
from gptcache.similarity_evaluation.exact_match import \
|
||||
ExactMatchEvaluation # noqa: F401
|
||||
from gptcache.similarity_evaluation.distance import ( # noqa: F401
|
||||
SearchDistanceEvaluation,
|
||||
)
|
||||
from gptcache.similarity_evaluation.exact_match import ( # noqa: F401
|
||||
ExactMatchEvaluation,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -14,13 +14,21 @@ import requests
|
||||
from rich.console import Console
|
||||
|
||||
from embedchain.telemetry.posthog import AnonymousTelemetry
|
||||
from embedchain.utils.cli import (deploy_fly, deploy_gradio_app,
|
||||
deploy_hf_spaces, deploy_modal,
|
||||
deploy_render, deploy_streamlit,
|
||||
get_pkg_path_from_name, setup_fly_io_app,
|
||||
setup_gradio_app, setup_hf_app,
|
||||
setup_modal_com_app, setup_render_com_app,
|
||||
setup_streamlit_io_app)
|
||||
from embedchain.utils.cli import (
|
||||
deploy_fly,
|
||||
deploy_gradio_app,
|
||||
deploy_hf_spaces,
|
||||
deploy_modal,
|
||||
deploy_render,
|
||||
deploy_streamlit,
|
||||
get_pkg_path_from_name,
|
||||
setup_fly_io_app,
|
||||
setup_gradio_app,
|
||||
setup_hf_app,
|
||||
setup_modal_com_app,
|
||||
setup_render_com_app,
|
||||
setup_streamlit_io_app,
|
||||
)
|
||||
|
||||
console = Console()
|
||||
api_process = None
|
||||
|
||||
@@ -8,8 +8,8 @@ from .embedder.base import BaseEmbedderConfig
|
||||
from .embedder.base import BaseEmbedderConfig as EmbedderConfig
|
||||
from .embedder.ollama import OllamaEmbedderConfig
|
||||
from .llm.base import BaseLlmConfig
|
||||
from .mem0_config import Mem0Config
|
||||
from .vector_db.chroma import ChromaDbConfig
|
||||
from .vector_db.elasticsearch import ElasticsearchDBConfig
|
||||
from .vector_db.opensearch import OpenSearchDBConfig
|
||||
from .vector_db.zilliz import ZillizDBConfig
|
||||
from .mem0_config import Mem0Config
|
||||
|
||||
@@ -1,2 +1,5 @@
|
||||
from .base import (AnswerRelevanceConfig, ContextRelevanceConfig, # noqa: F401
|
||||
GroundednessConfig)
|
||||
from .base import ( # noqa: F401
|
||||
AnswerRelevanceConfig,
|
||||
ContextRelevanceConfig,
|
||||
GroundednessConfig,
|
||||
)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from string import Template
|
||||
from typing import Any, Mapping, Optional, Dict, Union
|
||||
from pathlib import Path
|
||||
from string import Template
|
||||
from typing import Any, Dict, Mapping, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
@@ -6,7 +6,12 @@ from typing import Any, Optional, Union
|
||||
from dotenv import load_dotenv
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
from embedchain.cache import adapt, get_gptcache_session, gptcache_data_convert, gptcache_update_cache_callback
|
||||
from embedchain.cache import (
|
||||
adapt,
|
||||
get_gptcache_session,
|
||||
gptcache_data_convert,
|
||||
gptcache_update_cache_callback,
|
||||
)
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
|
||||
from embedchain.config.base_app_config import BaseAppConfig
|
||||
@@ -16,7 +21,12 @@ from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.helpers.json_serializable import JSONSerializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
|
||||
from embedchain.models.data_type import (
|
||||
DataType,
|
||||
DirectDataType,
|
||||
IndirectDataType,
|
||||
SpecialDataType,
|
||||
)
|
||||
from embedchain.utils.misc import detect_datatype, is_valid_json_string
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
|
||||
from chromadb import EmbeddingFunction, Embeddings
|
||||
|
||||
from embedchain.config import BaseEmbedderConfig
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
|
||||
from chromadb import EmbeddingFunction, Embeddings
|
||||
|
||||
|
||||
class ClarifaiEmbeddingFunction(EmbeddingFunction):
|
||||
def __init__(self, config: BaseEmbedderConfig) -> None:
|
||||
super().__init__()
|
||||
try:
|
||||
from clarifai.client.model import Model
|
||||
from clarifai.client.input import Inputs
|
||||
from clarifai.client.model import Model
|
||||
except ModuleNotFoundError:
|
||||
raise ModuleNotFoundError(
|
||||
"The required dependencies for ClarifaiEmbeddingFunction are not installed."
|
||||
|
||||
@@ -9,7 +9,9 @@ class GPT4AllEmbedder(BaseEmbedder):
|
||||
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
|
||||
super().__init__(config=config)
|
||||
|
||||
from langchain_community.embeddings import GPT4AllEmbeddings as LangchainGPT4AllEmbeddings
|
||||
from langchain_community.embeddings import (
|
||||
GPT4AllEmbeddings as LangchainGPT4AllEmbeddings,
|
||||
)
|
||||
|
||||
model_name = self.config.model or "all-MiniLM-L6-v2-f16.gguf"
|
||||
gpt4all_kwargs = {'allow_download': 'True'}
|
||||
|
||||
@@ -3,7 +3,6 @@ from typing import Optional
|
||||
|
||||
from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction
|
||||
|
||||
|
||||
from embedchain.config import BaseEmbedderConfig
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.models import VectorDimensions
|
||||
|
||||
@@ -45,8 +45,9 @@ class AWSBedrockLlm(BaseLlm):
|
||||
}
|
||||
|
||||
if config.stream:
|
||||
from langchain.callbacks.streaming_stdout import \
|
||||
StreamingStdOutCallbackHandler
|
||||
from langchain.callbacks.streaming_stdout import (
|
||||
StreamingStdOutCallbackHandler,
|
||||
)
|
||||
|
||||
kwargs["streaming"] = True
|
||||
kwargs["callbacks"] = [StreamingStdOutCallbackHandler()]
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import Any, Optional
|
||||
|
||||
from langchain.schema import BaseMessage as LCBaseMessage
|
||||
|
||||
from embedchain.constants import SQLITE_PATH
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.config.llm.base import (
|
||||
DEFAULT_PROMPT,
|
||||
@@ -13,6 +12,7 @@ from embedchain.config.llm.base import (
|
||||
DEFAULT_PROMPT_WITH_MEM0_MEMORY_TEMPLATE,
|
||||
DOCS_SITE_PROMPT_TEMPLATE,
|
||||
)
|
||||
from embedchain.constants import SQLITE_PATH
|
||||
from embedchain.core.db.database import init_db, setup_engine
|
||||
from embedchain.helpers.json_serializable import JSONSerializable
|
||||
from embedchain.memory.base import ChatHistory
|
||||
|
||||
@@ -26,8 +26,7 @@ class GPT4ALLLlm(BaseLlm):
|
||||
@staticmethod
|
||||
def _get_instance(model):
|
||||
try:
|
||||
from langchain_community.llms.gpt4all import \
|
||||
GPT4All as LangchainGPT4All
|
||||
from langchain_community.llms.gpt4all import GPT4All as LangchainGPT4All
|
||||
except ModuleNotFoundError:
|
||||
raise ModuleNotFoundError(
|
||||
"The GPT4All python package is not installed. Please install it with `pip install --upgrade embedchain[opensource]`" # noqa E501
|
||||
|
||||
@@ -35,8 +35,9 @@ class JinaLlm(BaseLlm):
|
||||
if config.top_p:
|
||||
kwargs["model_kwargs"]["top_p"] = config.top_p
|
||||
if config.stream:
|
||||
from langchain.callbacks.streaming_stdout import \
|
||||
StreamingStdOutCallbackHandler
|
||||
from langchain.callbacks.streaming_stdout import (
|
||||
StreamingStdOutCallbackHandler,
|
||||
)
|
||||
|
||||
chat = JinaChat(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()])
|
||||
else:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import hashlib
|
||||
|
||||
from langchain_community.document_loaders import PyPDFLoader
|
||||
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.utils.misc import clean_string
|
||||
|
||||
@@ -28,7 +28,9 @@ class RSSFeedLoader(BaseLoader):
|
||||
@staticmethod
|
||||
def get_rss_content(url: str):
|
||||
try:
|
||||
from langchain_community.document_loaders import RSSFeedLoader as LangchainRSSFeedLoader
|
||||
from langchain_community.document_loaders import (
|
||||
RSSFeedLoader as LangchainRSSFeedLoader,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"""RSSFeedLoader file requires extra dependencies.
|
||||
|
||||
@@ -11,8 +11,7 @@ class UnstructuredLoader(BaseLoader):
|
||||
"""Load data from an Unstructured file."""
|
||||
try:
|
||||
import unstructured # noqa: F401
|
||||
from langchain_community.document_loaders import \
|
||||
UnstructuredFileLoader
|
||||
from langchain_community.document_loaders import UnstructuredFileLoader
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
'Unstructured file requires extra dependencies. Install with `pip install "unstructured[local-inference, all-docs]"`' # noqa: E501
|
||||
|
||||
@@ -6,8 +6,15 @@ from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
|
||||
try:
|
||||
from pymilvus import (Collection, CollectionSchema, DataType, FieldSchema,
|
||||
MilvusClient, connections, utility)
|
||||
from pymilvus import (
|
||||
Collection,
|
||||
CollectionSchema,
|
||||
DataType,
|
||||
FieldSchema,
|
||||
MilvusClient,
|
||||
connections,
|
||||
utility,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Zilliz requires extra dependencies. Install with `pip install --upgrade embedchain[milvus]`"
|
||||
|
||||
@@ -8,8 +8,7 @@ import streamlit as st
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helpers.callbacks import (StreamingStdOutCallbackHandlerYield,
|
||||
generate)
|
||||
from embedchain.helpers.callbacks import StreamingStdOutCallbackHandlerYield, generate
|
||||
|
||||
|
||||
def embedchain_bot(db_path, api_key):
|
||||
|
||||
@@ -8,8 +8,7 @@ import streamlit as st
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helpers.callbacks import (StreamingStdOutCallbackHandlerYield,
|
||||
generate)
|
||||
from embedchain.helpers.callbacks import StreamingStdOutCallbackHandlerYield, generate
|
||||
|
||||
|
||||
@st.cache_resource
|
||||
|
||||
@@ -4,8 +4,7 @@ import streamlit as st
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helpers.callbacks import (StreamingStdOutCallbackHandlerYield,
|
||||
generate)
|
||||
from embedchain.helpers.callbacks import StreamingStdOutCallbackHandlerYield, generate
|
||||
|
||||
|
||||
@st.cache_resource
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from unittest.mock import patch, Mock
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from embedchain.config import BaseEmbedderConfig
|
||||
from embedchain.embedder.huggingface import HuggingFaceEmbedder
|
||||
|
||||
|
||||
@@ -4,8 +4,10 @@ from string import Template
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig, BaseLlmConfig
|
||||
from embedchain.helpers.json_serializable import (JSONSerializable,
|
||||
register_deserializable)
|
||||
from embedchain.helpers.json_serializable import (
|
||||
JSONSerializable,
|
||||
register_deserializable,
|
||||
)
|
||||
|
||||
|
||||
class TestJsonSerializable(unittest.TestCase):
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_alembic_command_upgrade():
|
||||
with mock.patch("alembic.command.upgrade"):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from unittest.mock import Mock, MagicMock, patch
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import pytest
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.llm.ollama import OllamaLlm
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
from embedchain.models.data_type import (DataType, DirectDataType,
|
||||
IndirectDataType, SpecialDataType)
|
||||
from embedchain.models.data_type import (
|
||||
DataType,
|
||||
DirectDataType,
|
||||
IndirectDataType,
|
||||
SpecialDataType,
|
||||
)
|
||||
|
||||
|
||||
def test_subclass_types_in_data_type():
|
||||
|
||||
Reference in New Issue
Block a user