Sets up metadata db for every llm class (#1401)
This commit is contained in:
@@ -20,7 +20,7 @@ from embedchain.cache import (
|
|||||||
)
|
)
|
||||||
from embedchain.client import Client
|
from embedchain.client import Client
|
||||||
from embedchain.config import AppConfig, CacheConfig, ChunkerConfig, Mem0Config
|
from embedchain.config import AppConfig, CacheConfig, ChunkerConfig, Mem0Config
|
||||||
from embedchain.core.db.database import get_session, init_db, setup_engine
|
from embedchain.core.db.database import get_session
|
||||||
from embedchain.core.db.models import DataSource
|
from embedchain.core.db.models import DataSource
|
||||||
from embedchain.embedchain import EmbedChain
|
from embedchain.embedchain import EmbedChain
|
||||||
from embedchain.embedder.base import BaseEmbedder
|
from embedchain.embedder.base import BaseEmbedder
|
||||||
@@ -89,10 +89,6 @@ class App(EmbedChain):
|
|||||||
if name and config:
|
if name and config:
|
||||||
raise Exception("Cannot provide both name and config. Please provide only one of them.")
|
raise Exception("Cannot provide both name and config. Please provide only one of them.")
|
||||||
|
|
||||||
# Initialize the metadata db for the app
|
|
||||||
setup_engine(database_uri=os.environ.get("EMBEDCHAIN_DB_URI"))
|
|
||||||
init_db()
|
|
||||||
|
|
||||||
self.auto_deploy = auto_deploy
|
self.auto_deploy = auto_deploy
|
||||||
# Store the dict config as an attribute to be able to send it
|
# Store the dict config as an attribute to be able to send it
|
||||||
self.config_data = config_data if (config_data and validate_config(config_data)) else None
|
self.config_data = config_data if (config_data and validate_config(config_data)) else None
|
||||||
@@ -389,10 +385,6 @@ class App(EmbedChain):
|
|||||||
vector_db = VectorDBFactory.create(vector_db_provider, vector_db_config_data.get("config", {}))
|
vector_db = VectorDBFactory.create(vector_db_provider, vector_db_config_data.get("config", {}))
|
||||||
|
|
||||||
if llm_config_data:
|
if llm_config_data:
|
||||||
# Initialize the metadata db for the app here since llmfactory needs it for initialization of
|
|
||||||
# the llm memory
|
|
||||||
setup_engine(database_uri=os.environ.get("EMBEDCHAIN_DB_URI"))
|
|
||||||
init_db()
|
|
||||||
llm_provider = llm_config_data.get("provider", "openai")
|
llm_provider = llm_config_data.get("provider", "openai")
|
||||||
llm = LlmFactory.create(llm_provider, llm_config_data.get("config", {}))
|
llm = LlmFactory.create(llm_provider, llm_config_data.get("config", {}))
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from langchain.schema import BaseMessage as LCBaseMessage
|
from langchain.schema import BaseMessage as LCBaseMessage
|
||||||
|
|
||||||
|
from embedchain.constants import SQLITE_PATH
|
||||||
from embedchain.config import BaseLlmConfig
|
from embedchain.config import BaseLlmConfig
|
||||||
from embedchain.config.llm.base import (
|
from embedchain.config.llm.base import (
|
||||||
DEFAULT_PROMPT,
|
DEFAULT_PROMPT,
|
||||||
@@ -11,6 +13,7 @@ from embedchain.config.llm.base import (
|
|||||||
DEFAULT_PROMPT_WITH_MEM0_MEMORY_TEMPLATE,
|
DEFAULT_PROMPT_WITH_MEM0_MEMORY_TEMPLATE,
|
||||||
DOCS_SITE_PROMPT_TEMPLATE,
|
DOCS_SITE_PROMPT_TEMPLATE,
|
||||||
)
|
)
|
||||||
|
from embedchain.core.db.database import init_db, setup_engine
|
||||||
from embedchain.helpers.json_serializable import JSONSerializable
|
from embedchain.helpers.json_serializable import JSONSerializable
|
||||||
from embedchain.memory.base import ChatHistory
|
from embedchain.memory.base import ChatHistory
|
||||||
from embedchain.memory.message import ChatMessage
|
from embedchain.memory.message import ChatMessage
|
||||||
@@ -30,6 +33,11 @@ class BaseLlm(JSONSerializable):
|
|||||||
else:
|
else:
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
|
# Initialize the metadata db for the app here since llmfactory needs it for initialization of
|
||||||
|
# the llm memory
|
||||||
|
setup_engine(database_uri=os.environ.get("EMBEDCHAIN_DB_URI", f"sqlite:///{SQLITE_PATH}"))
|
||||||
|
init_db()
|
||||||
|
|
||||||
self.memory = ChatHistory()
|
self.memory = ChatHistory()
|
||||||
self.is_docs_site_instance = False
|
self.is_docs_site_instance = False
|
||||||
self.history: Any = None
|
self.history: Any = None
|
||||||
|
|||||||
8
embedchain/poetry.lock
generated
8
embedchain/poetry.lock
generated
@@ -2391,18 +2391,18 @@ tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<9.0.0"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "langchain-aws"
|
name = "langchain-aws"
|
||||||
version = "0.1.10"
|
version = "0.1.13"
|
||||||
description = "An integration package connecting AWS and LangChain"
|
description = "An integration package connecting AWS and LangChain"
|
||||||
optional = true
|
optional = true
|
||||||
python-versions = "<4.0,>=3.8.1"
|
python-versions = "<4.0,>=3.8.1"
|
||||||
files = [
|
files = [
|
||||||
{file = "langchain_aws-0.1.10-py3-none-any.whl", hash = "sha256:2cba72efaa9f0dc406d8e06a1fbaa3762678d489cbc5147cf64a7012189c161c"},
|
{file = "langchain_aws-0.1.13-py3-none-any.whl", hash = "sha256:c4db60c8a83b8ff3e66170e0bd646739176fcd1a20a9d0a10828a1e21339af1d"},
|
||||||
{file = "langchain_aws-0.1.10.tar.gz", hash = "sha256:7f01dacbf8345a28192cec4ef31018cc33a91de0b82122f913eec09a76d64fd5"},
|
{file = "langchain_aws-0.1.13.tar.gz", hash = "sha256:fda790732a72de4ccec3760dba24db5f9fa5cb8724dfd9676a7d5cf87a9f1a98"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
boto3 = ">=1.34.131,<1.35.0"
|
boto3 = ">=1.34.131,<1.35.0"
|
||||||
langchain-core = ">=0.2.6,<0.3"
|
langchain-core = ">=0.2.17,<0.3"
|
||||||
numpy = ">=1,<2"
|
numpy = ">=1,<2"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|||||||
8
embedchain/tests/llm/conftest.py
Normal file
8
embedchain/tests/llm/conftest.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
|
||||||
|
from unittest import mock
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_alembic_command_upgrade():
|
||||||
|
with mock.patch("alembic.command.upgrade"):
|
||||||
|
yield
|
||||||
Reference in New Issue
Block a user