refactor: classes and configs (#528)
This commit is contained in:
@@ -1,26 +1,25 @@
|
||||
from embedchain import CustomApp
|
||||
from embedchain.config import AddConfig, CustomAppConfig, QueryConfig
|
||||
from embedchain.config import AddConfig, CustomAppConfig, LlmConfig
|
||||
from embedchain.embedder.openai_embedder import OpenAiEmbedder
|
||||
from embedchain.helper_classes.json_serializable import (
|
||||
JSONSerializable, register_deserializable)
|
||||
from embedchain.models import EmbeddingFunctions, Providers
|
||||
from embedchain.llm.openai_llm import OpenAiLlm
|
||||
from embedchain.vectordb.chroma_db import ChromaDB
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class BaseBot(JSONSerializable):
|
||||
def __init__(self, app_config=None):
|
||||
if app_config is None:
|
||||
app_config = CustomAppConfig(embedding_fn=EmbeddingFunctions.OPENAI, provider=Providers.OPENAI)
|
||||
self.app_config = app_config
|
||||
self.app = CustomApp(config=self.app_config)
|
||||
def __init__(self):
|
||||
self.app = CustomApp(config=CustomAppConfig(), llm=OpenAiLlm(), db=ChromaDB(), embedder=OpenAiEmbedder())
|
||||
|
||||
def add(self, data, config: AddConfig = None):
|
||||
"""Add data to the bot"""
|
||||
config = config if config else AddConfig()
|
||||
self.app.add(data, config=config)
|
||||
|
||||
def query(self, query, config: QueryConfig = None):
|
||||
def query(self, query, config: LlmConfig = None):
|
||||
"""Query bot"""
|
||||
config = config if config else QueryConfig()
|
||||
config = config
|
||||
return self.app.query(query, config=config)
|
||||
|
||||
def start(self):
|
||||
|
||||
@@ -6,6 +6,8 @@ import discord
|
||||
from discord import app_commands
|
||||
from discord.ext import commands
|
||||
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
|
||||
from .base import BaseBot
|
||||
|
||||
intents = discord.Intents.default()
|
||||
@@ -17,6 +19,7 @@ tree = app_commands.CommandTree(client)
|
||||
# https://discord.com/api/oauth2/authorize?client_id={DISCORD_CLIENT_ID}&permissions=2048&scope=bot
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class DiscordBot(BaseBot):
|
||||
def __init__(self, *args, **kwargs):
|
||||
BaseBot.__init__(self, *args, **kwargs)
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import List, Optional
|
||||
|
||||
from fastapi_poe import PoeBot, run
|
||||
|
||||
from embedchain.config import QueryConfig
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
|
||||
from .base import BaseBot
|
||||
@@ -46,7 +45,6 @@ class PoeBot(BaseBot, PoeBot):
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error when processing the chat history. Message is being sent without history. Error: {e}")
|
||||
logging.warning(history)
|
||||
answer = self.handle_message(last_message, history)
|
||||
yield self.text_event(answer)
|
||||
|
||||
@@ -69,8 +67,8 @@ class PoeBot(BaseBot, PoeBot):
|
||||
|
||||
def ask_bot(self, message, history: List[str]):
|
||||
try:
|
||||
config = QueryConfig(history=history)
|
||||
response = self.query(message, config)
|
||||
self.app.llm.set_history(history=history)
|
||||
response = self.query(message)
|
||||
except Exception:
|
||||
logging.exception(f"Failed to query {message}.")
|
||||
response = "An error occurred. Please try again!"
|
||||
|
||||
Reference in New Issue
Block a user