Lint and formatting fixes (#554)

Co-authored-by: cachho <admin@ch-webdev.com>
Co-authored-by: Taranjeet Singh <reachtotj@gmail.com>
This commit is contained in:
Dev Khant
2023-09-06 04:24:19 +05:30
committed by GitHub
parent 6481b555b4
commit 129242534d
13 changed files with 33 additions and 34 deletions

View File

@@ -1,4 +1,5 @@
from embedchain.bots.poe import PoeBot from embedchain.bots.poe import PoeBot # noqa: F401
from embedchain.bots.whatsapp import WhatsAppBot from embedchain.bots.whatsapp import WhatsAppBot # noqa: F401
# TODO: fix discord import # TODO: fix discord import
# from embedchain.bots.discord import DiscordBot # from embedchain.bots.discord import DiscordBot

View File

@@ -1,4 +1,5 @@
import argparse import argparse
import importlib
import logging import logging
import signal import signal
import sys import sys
@@ -11,8 +12,14 @@ from .base import BaseBot
@register_deserializable @register_deserializable
class WhatsAppBot(BaseBot): class WhatsAppBot(BaseBot):
def __init__(self): def __init__(self):
from flask import Flask, request try:
from twilio.twiml.messaging_response import MessagingResponse self.flask = importlib.import_module("flask")
self.twilio = importlib.import_module("twilio")
except ModuleNotFoundError:
raise ModuleNotFoundError(
"The required dependencies for WhatsApp are not installed. "
'Please install with `pip install --upgrade "embedchain[whatsapp]"`'
) from None
super().__init__() super().__init__()
def handle_message(self, message): def handle_message(self, message):
@@ -41,7 +48,7 @@ class WhatsAppBot(BaseBot):
return response return response
def start(self, host="0.0.0.0", port=5000, debug=True): def start(self, host="0.0.0.0", port=5000, debug=True):
app = Flask(__name__) app = self.flask.Flask(__name__)
def signal_handler(sig, frame): def signal_handler(sig, frame):
logging.info("\nGracefully shutting down the WhatsAppBot...") logging.info("\nGracefully shutting down the WhatsAppBot...")
@@ -51,9 +58,9 @@ class WhatsAppBot(BaseBot):
@app.route("/chat", methods=["POST"]) @app.route("/chat", methods=["POST"])
def chat(): def chat():
incoming_message = request.values.get("Body", "").lower() incoming_message = self.flask.request.values.get("Body", "").lower()
response = self.handle_message(incoming_message) response = self.handle_message(incoming_message)
twilio_response = MessagingResponse() twilio_response = self.twilio.twiml.messaging_response.MessagingResponse()
twilio_response.message(response) twilio_response.message(response)
return str(twilio_response) return str(twilio_response)

View File

@@ -82,7 +82,7 @@ class EmbedChain(JSONSerializable):
# Send anonymous telemetry # Send anonymous telemetry
self.s_id = self.config.id if self.config.id else str(uuid.uuid4()) self.s_id = self.config.id if self.config.id else str(uuid.uuid4())
self.u_id = self._load_or_generate_user_id() self.u_id = self._load_or_generate_user_id()
# NOTE: Uncomment the next two lines when running tests to see if any test fires a telemetry event. # NOTE: Uncomment the next two lines when running tests to see if any test fires a telemetry event.
# if (self.config.collect_metrics): # if (self.config.collect_metrics):
# raise ConnectionRefusedError("Collection of metrics should not be allowed.") # raise ConnectionRefusedError("Collection of metrics should not be allowed.")
thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("init",)) thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("init",))

View File

@@ -2,9 +2,8 @@ import logging
from typing import Optional from typing import Optional
from embedchain.config import BaseLlmConfig from embedchain.config import BaseLlmConfig
from embedchain.llm.base_llm import BaseLlm
from embedchain.helper_classes.json_serializable import register_deserializable from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.llm.base_llm import BaseLlm
@register_deserializable @register_deserializable

View File

@@ -2,9 +2,8 @@ import logging
from typing import Optional from typing import Optional
from embedchain.config import BaseLlmConfig from embedchain.config import BaseLlmConfig
from embedchain.llm.base_llm import BaseLlm
from embedchain.helper_classes.json_serializable import register_deserializable from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.llm.base_llm import BaseLlm
@register_deserializable @register_deserializable

View File

@@ -3,12 +3,12 @@ from typing import List, Optional
from langchain.memory import ConversationBufferMemory from langchain.memory import ConversationBufferMemory
from langchain.schema import BaseMessage from langchain.schema import BaseMessage
from embedchain.helper_classes.json_serializable import JSONSerializable
from embedchain.config import BaseLlmConfig from embedchain.config import BaseLlmConfig
from embedchain.config.llm.base_llm_config import ( from embedchain.config.llm.base_llm_config import (
DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE, DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
DOCS_SITE_PROMPT_TEMPLATE) DOCS_SITE_PROMPT_TEMPLATE)
from embedchain.helper_classes.json_serializable import JSONSerializable
class BaseLlm(JSONSerializable): class BaseLlm(JSONSerializable):

View File

@@ -1,9 +1,8 @@
from typing import Iterable, Optional, Union from typing import Iterable, Optional, Union
from embedchain.config import BaseLlmConfig from embedchain.config import BaseLlmConfig
from embedchain.llm.base_llm import BaseLlm
from embedchain.helper_classes.json_serializable import register_deserializable from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.llm.base_llm import BaseLlm
@register_deserializable @register_deserializable

View File

@@ -4,9 +4,8 @@ from typing import Optional
from langchain.llms import Replicate from langchain.llms import Replicate
from embedchain.config import BaseLlmConfig from embedchain.config import BaseLlmConfig
from embedchain.llm.base_llm import BaseLlm
from embedchain.helper_classes.json_serializable import register_deserializable from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.llm.base_llm import BaseLlm
@register_deserializable @register_deserializable

View File

@@ -3,9 +3,8 @@ from typing import Optional
import openai import openai
from embedchain.config import BaseLlmConfig from embedchain.config import BaseLlmConfig
from embedchain.llm.base_llm import BaseLlm
from embedchain.helper_classes.json_serializable import register_deserializable from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.llm.base_llm import BaseLlm
@register_deserializable @register_deserializable

View File

@@ -2,9 +2,8 @@ import logging
from typing import Optional from typing import Optional
from embedchain.config import BaseLlmConfig from embedchain.config import BaseLlmConfig
from embedchain.llm.base_llm import BaseLlm
from embedchain.helper_classes.json_serializable import register_deserializable from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.llm.base_llm import BaseLlm
@register_deserializable @register_deserializable

View File

@@ -1,7 +1,6 @@
import os import os
import unittest import unittest
from unittest.mock import patch, MagicMock from unittest.mock import MagicMock, patch
from embedchain import App from embedchain import App
from embedchain.config import AppConfig, BaseLlmConfig from embedchain.config import AppConfig, BaseLlmConfig
@@ -88,8 +87,8 @@ class TestApp(unittest.TestCase):
self.assertEqual(answer, "Test answer") self.assertEqual(answer, "Test answer")
_args, kwargs = mock_retrieve.call_args _args, kwargs = mock_retrieve.call_args
self.assertEqual(kwargs.get('input_query'), "Test query") self.assertEqual(kwargs.get("input_query"), "Test query")
self.assertEqual(kwargs.get('where'), {"attribute": "value"}) self.assertEqual(kwargs.get("where"), {"attribute": "value"})
mock_answer.assert_called_once() mock_answer.assert_called_once()
@patch("chromadb.api.models.Collection.Collection.add", MagicMock) @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
@@ -120,6 +119,6 @@ class TestApp(unittest.TestCase):
self.assertEqual(answer, "Test answer") self.assertEqual(answer, "Test answer")
_args, kwargs = mock_database_query.call_args _args, kwargs = mock_database_query.call_args
self.assertEqual(kwargs.get('input_query'), "Test query") self.assertEqual(kwargs.get("input_query"), "Test query")
self.assertEqual(kwargs.get('where'), {"attribute": "value"}) self.assertEqual(kwargs.get("where"), {"attribute": "value"})
mock_answer.assert_called_once() mock_answer.assert_called_once()

View File

@@ -109,8 +109,8 @@ class TestApp(unittest.TestCase):
self.assertEqual(answer, "Test answer") self.assertEqual(answer, "Test answer")
_args, kwargs = mock_retrieve.call_args _args, kwargs = mock_retrieve.call_args
self.assertEqual(kwargs.get('input_query'), "Test query") self.assertEqual(kwargs.get("input_query"), "Test query")
self.assertEqual(kwargs.get('where'), {"attribute": "value"}) self.assertEqual(kwargs.get("where"), {"attribute": "value"})
mock_answer.assert_called_once() mock_answer.assert_called_once()
@patch("chromadb.api.models.Collection.Collection.add", MagicMock) @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
@@ -142,6 +142,6 @@ class TestApp(unittest.TestCase):
self.assertEqual(answer, "Test answer") self.assertEqual(answer, "Test answer")
_args, kwargs = mock_database_query.call_args _args, kwargs = mock_database_query.call_args
self.assertEqual(kwargs.get('input_query'), "Test query") self.assertEqual(kwargs.get("input_query"), "Test query")
self.assertEqual(kwargs.get('where'), {"attribute": "value"}) self.assertEqual(kwargs.get("where"), {"attribute": "value"})
mock_answer.assert_called_once() mock_answer.assert_called_once()

View File

@@ -7,7 +7,6 @@ from chromadb.config import Settings
from embedchain import App from embedchain import App
from embedchain.config import AppConfig, ChromaDbConfig from embedchain.config import AppConfig, ChromaDbConfig
from embedchain.models import EmbeddingFunctions, Providers
from embedchain.vectordb.chroma_db import ChromaDB from embedchain.vectordb.chroma_db import ChromaDB
@@ -86,7 +85,6 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
""" """
Test if the `App` instance is initialized without a config that does not contain default hosts and ports. Test if the `App` instance is initialized without a config that does not contain default hosts and ports.
""" """
config = AppConfig(log_level="DEBUG")
_app = App(config=AppConfig(collect_metrics=False)) _app = App(config=AppConfig(collect_metrics=False))