Sets up metadata db for every llm class (#1401)

This commit is contained in:
Pranav Puranik
2024-08-01 14:15:28 -05:00
committed by GitHub
parent 58b6887bf5
commit b386e24f5d
5 changed files with 22 additions and 14 deletions

View File

@@ -20,7 +20,7 @@ from embedchain.cache import (
)
from embedchain.client import Client
from embedchain.config import AppConfig, CacheConfig, ChunkerConfig, Mem0Config
from embedchain.core.db.database import get_session, init_db, setup_engine
from embedchain.core.db.database import get_session
from embedchain.core.db.models import DataSource
from embedchain.embedchain import EmbedChain
from embedchain.embedder.base import BaseEmbedder
@@ -89,10 +89,6 @@ class App(EmbedChain):
if name and config:
raise Exception("Cannot provide both name and config. Please provide only one of them.")
# Initialize the metadata db for the app
setup_engine(database_uri=os.environ.get("EMBEDCHAIN_DB_URI"))
init_db()
self.auto_deploy = auto_deploy
# Store the dict config as an attribute to be able to send it
self.config_data = config_data if (config_data and validate_config(config_data)) else None
@@ -389,10 +385,6 @@ class App(EmbedChain):
vector_db = VectorDBFactory.create(vector_db_provider, vector_db_config_data.get("config", {}))
if llm_config_data:
# Initialize the metadata db for the app here since llmfactory needs it for initialization of
# the llm memory
setup_engine(database_uri=os.environ.get("EMBEDCHAIN_DB_URI"))
init_db()
llm_provider = llm_config_data.get("provider", "openai")
llm = LlmFactory.create(llm_provider, llm_config_data.get("config", {}))
else:

View File

@@ -1,9 +1,11 @@
import logging
import os
from collections.abc import Generator
from typing import Any, Optional
from langchain.schema import BaseMessage as LCBaseMessage
from embedchain.constants import SQLITE_PATH
from embedchain.config import BaseLlmConfig
from embedchain.config.llm.base import (
DEFAULT_PROMPT,
@@ -11,6 +13,7 @@ from embedchain.config.llm.base import (
DEFAULT_PROMPT_WITH_MEM0_MEMORY_TEMPLATE,
DOCS_SITE_PROMPT_TEMPLATE,
)
from embedchain.core.db.database import init_db, setup_engine
from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.memory.base import ChatHistory
from embedchain.memory.message import ChatMessage
@@ -30,6 +33,11 @@ class BaseLlm(JSONSerializable):
else:
self.config = config
# Initialize the metadata db for the app here since llmfactory needs it for initialization of
# the llm memory
setup_engine(database_uri=os.environ.get("EMBEDCHAIN_DB_URI", f"sqlite:///{SQLITE_PATH}"))
init_db()
self.memory = ChatHistory()
self.is_docs_site_instance = False
self.history: Any = None