Code Formatting (#1828)

Author: Dev Khant
Date: 2024-09-07 22:39:28 +05:30
Committed by: GitHub
Parent: 6a54d27286
Commit: a972d2fb07
66 changed files with 208 additions and 138 deletions
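
This is a formatting-only pass: backslash-continued and over-long single-line imports are rewritten as parenthesized multi-line imports with a trailing comma, statements that now fit the line limit are joined back onto one line, and import blocks are re-sorted. A minimal before/after sketch of the convention, using standard-library names rather than code from this repository:

# Before: one over-long line, or a backslash continuation
# from collections import Counter, OrderedDict, defaultdict, deque, namedtuple

# After: parenthesized multi-line import with a trailing comma
from collections import (
    Counter,
    OrderedDict,
    defaultdict,
    deque,
    namedtuple,
)

# The reverse also appears below: a continuation that fits the
# line limit is collapsed back onto a single line.
from functools import lru_cache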

View File

@@ -9,7 +9,6 @@ import requests
import yaml
from tqdm import tqdm
from mem0 import Memory
from embedchain.cache import (
Config,
ExactMatchEvaluation,
@@ -26,7 +25,11 @@ from embedchain.embedchain import EmbedChain
from embedchain.embedder.base import BaseEmbedder
from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.evaluation.base import BaseMetric
from embedchain.evaluation.metrics import AnswerRelevance, ContextRelevance, Groundedness
from embedchain.evaluation.metrics import (
AnswerRelevance,
ContextRelevance,
Groundedness,
)
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
@@ -36,6 +39,7 @@ from embedchain.utils.evaluation import EvalData, EvalMetric
from embedchain.utils.misc import validate_config
from embedchain.vectordb.base import BaseVectorDB
from embedchain.vectordb.chroma import ChromaDB
from mem0 import Memory
logger = logging.getLogger(__name__)

View File

@@ -3,8 +3,10 @@ from typing import Any
from embedchain import App
from embedchain.config import AddConfig, AppConfig, BaseLlmConfig
from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.helpers.json_serializable import (JSONSerializable,
register_deserializable)
from embedchain.helpers.json_serializable import (
JSONSerializable,
register_deserializable,
)
from embedchain.llm.openai import OpenAILlm
from embedchain.vectordb.chroma import ChromaDB

View File

@@ -9,10 +9,12 @@ from gptcache.manager import get_data_manager
from gptcache.manager.scalar_data.base import Answer
from gptcache.manager.scalar_data.base import DataType as CacheDataType
from gptcache.session import Session
from gptcache.similarity_evaluation.distance import \
SearchDistanceEvaluation # noqa: F401
from gptcache.similarity_evaluation.exact_match import \
ExactMatchEvaluation # noqa: F401
from gptcache.similarity_evaluation.distance import ( # noqa: F401
SearchDistanceEvaluation,
)
from gptcache.similarity_evaluation.exact_match import ( # noqa: F401
ExactMatchEvaluation,
)
logger = logging.getLogger(__name__)
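
When a wrapped import carries a lint suppression, the commit keeps the comment on the opening-parenthesis line, as in the gptcache hunk above; a standard-library sketch of the same placement:

from os.path import (  # noqa: F401
    basename,
    dirname,
)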

View File

@@ -14,13 +14,21 @@ import requests
from rich.console import Console
from embedchain.telemetry.posthog import AnonymousTelemetry
from embedchain.utils.cli import (deploy_fly, deploy_gradio_app,
deploy_hf_spaces, deploy_modal,
deploy_render, deploy_streamlit,
get_pkg_path_from_name, setup_fly_io_app,
setup_gradio_app, setup_hf_app,
setup_modal_com_app, setup_render_com_app,
setup_streamlit_io_app)
from embedchain.utils.cli import (
deploy_fly,
deploy_gradio_app,
deploy_hf_spaces,
deploy_modal,
deploy_render,
deploy_streamlit,
get_pkg_path_from_name,
setup_fly_io_app,
setup_gradio_app,
setup_hf_app,
setup_modal_com_app,
setup_render_com_app,
setup_streamlit_io_app,
)
console = Console()
api_process = None

View File

@@ -8,8 +8,8 @@ from .embedder.base import BaseEmbedderConfig
from .embedder.base import BaseEmbedderConfig as EmbedderConfig
from .embedder.ollama import OllamaEmbedderConfig
from .llm.base import BaseLlmConfig
from .mem0_config import Mem0Config
from .vector_db.chroma import ChromaDbConfig
from .vector_db.elasticsearch import ElasticsearchDBConfig
from .vector_db.opensearch import OpenSearchDBConfig
from .vector_db.zilliz import ZillizDBConfig
from .mem0_config import Mem0Config

View File

@@ -1,2 +1,5 @@
from .base import (AnswerRelevanceConfig, ContextRelevanceConfig, # noqa: F401
GroundednessConfig)
from .base import ( # noqa: F401
AnswerRelevanceConfig,
ContextRelevanceConfig,
GroundednessConfig,
)

View File

@@ -1,9 +1,9 @@
import json
import logging
import re
from string import Template
from typing import Any, Mapping, Optional, Dict, Union
from pathlib import Path
from string import Template
from typing import Any, Dict, Mapping, Optional, Union
import httpx

View File

@@ -6,7 +6,12 @@ from typing import Any, Optional, Union
from dotenv import load_dotenv
from langchain.docstore.document import Document
from embedchain.cache import adapt, get_gptcache_session, gptcache_data_convert, gptcache_update_cache_callback
from embedchain.cache import (
adapt,
get_gptcache_session,
gptcache_data_convert,
gptcache_update_cache_callback,
)
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
from embedchain.config.base_app_config import BaseAppConfig
@@ -16,7 +21,12 @@ from embedchain.embedder.base import BaseEmbedder
from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.llm.base import BaseLlm
from embedchain.loaders.base_loader import BaseLoader
from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
from embedchain.models.data_type import (
DataType,
DirectDataType,
IndirectDataType,
SpecialDataType,
)
from embedchain.utils.misc import detect_datatype, is_valid_json_string
from embedchain.vectordb.base import BaseVectorDB

View File

@@ -1,18 +1,18 @@
import os
from typing import Optional, Union
from chromadb import EmbeddingFunction, Embeddings
from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
from chromadb import EmbeddingFunction, Embeddings
class ClarifaiEmbeddingFunction(EmbeddingFunction):
def __init__(self, config: BaseEmbedderConfig) -> None:
super().__init__()
try:
from clarifai.client.model import Model
from clarifai.client.input import Inputs
from clarifai.client.model import Model
except ModuleNotFoundError:
raise ModuleNotFoundError(
"The required dependencies for ClarifaiEmbeddingFunction are not installed."

View File

@@ -9,7 +9,9 @@ class GPT4AllEmbedder(BaseEmbedder):
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
super().__init__(config=config)
from langchain_community.embeddings import GPT4AllEmbeddings as LangchainGPT4AllEmbeddings
from langchain_community.embeddings import (
GPT4AllEmbeddings as LangchainGPT4AllEmbeddings,
)
model_name = self.config.model or "all-MiniLM-L6-v2-f16.gguf"
gpt4all_kwargs = {'allow_download': 'True'}

View File

@@ -3,7 +3,6 @@ from typing import Optional
from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction
from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.models import VectorDimensions

View File

@@ -45,8 +45,9 @@ class AWSBedrockLlm(BaseLlm):
}
if config.stream:
from langchain.callbacks.streaming_stdout import \
StreamingStdOutCallbackHandler
from langchain.callbacks.streaming_stdout import (
StreamingStdOutCallbackHandler,
)
kwargs["streaming"] = True
kwargs["callbacks"] = [StreamingStdOutCallbackHandler()]

View File

@@ -5,7 +5,6 @@ from typing import Any, Optional
from langchain.schema import BaseMessage as LCBaseMessage
from embedchain.constants import SQLITE_PATH
from embedchain.config import BaseLlmConfig
from embedchain.config.llm.base import (
DEFAULT_PROMPT,
@@ -13,6 +12,7 @@ from embedchain.config.llm.base import (
DEFAULT_PROMPT_WITH_MEM0_MEMORY_TEMPLATE,
DOCS_SITE_PROMPT_TEMPLATE,
)
from embedchain.constants import SQLITE_PATH
from embedchain.core.db.database import init_db, setup_engine
from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.memory.base import ChatHistory

View File

@@ -26,8 +26,7 @@ class GPT4ALLLlm(BaseLlm):
@staticmethod
def _get_instance(model):
try:
from langchain_community.llms.gpt4all import \
GPT4All as LangchainGPT4All
from langchain_community.llms.gpt4all import GPT4All as LangchainGPT4All
except ModuleNotFoundError:
raise ModuleNotFoundError(
"The GPT4All python package is not installed. Please install it with `pip install --upgrade embedchain[opensource]`" # noqa E501

View File

@@ -35,8 +35,9 @@ class JinaLlm(BaseLlm):
if config.top_p:
kwargs["model_kwargs"]["top_p"] = config.top_p
if config.stream:
from langchain.callbacks.streaming_stdout import \
StreamingStdOutCallbackHandler
from langchain.callbacks.streaming_stdout import (
StreamingStdOutCallbackHandler,
)
chat = JinaChat(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()])
else:

View File

@@ -1,6 +1,7 @@
import hashlib
from langchain_community.document_loaders import PyPDFLoader
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string

View File

@@ -28,7 +28,9 @@ class RSSFeedLoader(BaseLoader):
@staticmethod
def get_rss_content(url: str):
try:
from langchain_community.document_loaders import RSSFeedLoader as LangchainRSSFeedLoader
from langchain_community.document_loaders import (
RSSFeedLoader as LangchainRSSFeedLoader,
)
except ImportError:
raise ImportError(
"""RSSFeedLoader file requires extra dependencies.

View File

@@ -11,8 +11,7 @@ class UnstructuredLoader(BaseLoader):
"""Load data from an Unstructured file."""
try:
import unstructured # noqa: F401
from langchain_community.document_loaders import \
UnstructuredFileLoader
from langchain_community.document_loaders import UnstructuredFileLoader
except ImportError:
raise ImportError(
'Unstructured file requires extra dependencies. Install with `pip install "unstructured[local-inference, all-docs]"`' # noqa: E501

View File

@@ -6,8 +6,15 @@ from embedchain.helpers.json_serializable import register_deserializable
from embedchain.vectordb.base import BaseVectorDB
try:
from pymilvus import (Collection, CollectionSchema, DataType, FieldSchema,
MilvusClient, connections, utility)
from pymilvus import (
Collection,
CollectionSchema,
DataType,
FieldSchema,
MilvusClient,
connections,
utility,
)
except ImportError:
raise ImportError(
"Zilliz requires extra dependencies. Install with `pip install --upgrade embedchain[milvus]`"

View File

@@ -8,8 +8,7 @@ import streamlit as st
from embedchain import App
from embedchain.config import BaseLlmConfig
from embedchain.helpers.callbacks import (StreamingStdOutCallbackHandlerYield,
generate)
from embedchain.helpers.callbacks import StreamingStdOutCallbackHandlerYield, generate
def embedchain_bot(db_path, api_key):

View File

@@ -8,8 +8,7 @@ import streamlit as st
from embedchain import App
from embedchain.config import BaseLlmConfig
from embedchain.helpers.callbacks import (StreamingStdOutCallbackHandlerYield,
generate)
from embedchain.helpers.callbacks import StreamingStdOutCallbackHandlerYield, generate
@st.cache_resource

View File

@@ -4,8 +4,7 @@ import streamlit as st
from embedchain import App
from embedchain.config import BaseLlmConfig
from embedchain.helpers.callbacks import (StreamingStdOutCallbackHandlerYield,
generate)
from embedchain.helpers.callbacks import StreamingStdOutCallbackHandlerYield, generate
@st.cache_resource

View File

@@ -1,4 +1,4 @@
from unittest.mock import patch, Mock
from unittest.mock import Mock, patch
import httpx

View File

@@ -1,5 +1,6 @@
from unittest.mock import patch
from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.huggingface import HuggingFaceEmbedder

View File

@@ -4,8 +4,10 @@ from string import Template
from embedchain import App
from embedchain.config import AppConfig, BaseLlmConfig
from embedchain.helpers.json_serializable import (JSONSerializable,
register_deserializable)
from embedchain.helpers.json_serializable import (
JSONSerializable,
register_deserializable,
)
class TestJsonSerializable(unittest.TestCase):

View File

@@ -1,7 +1,9 @@
from unittest import mock
import pytest
@pytest.fixture(autouse=True)
def mock_alembic_command_upgrade():
with mock.patch("alembic.command.upgrade"):

View File

@@ -1,4 +1,4 @@
from unittest.mock import Mock, MagicMock, patch
from unittest.mock import MagicMock, Mock, patch
import httpx
import pytest

View File

@@ -1,8 +1,8 @@
import pytest
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from embedchain.config import BaseLlmConfig
from embedchain.llm.ollama import OllamaLlm
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
@pytest.fixture

View File

@@ -1,5 +1,9 @@
from embedchain.models.data_type import (DataType, DirectDataType,
IndirectDataType, SpecialDataType)
from embedchain.models.data_type import (
DataType,
DirectDataType,
IndirectDataType,
SpecialDataType,
)
def test_subclass_types_in_data_type():