Sets up metadata db for every llm class (#1401)

This commit is contained in:
Pranav Puranik
2024-08-01 14:15:28 -05:00
committed by GitHub
parent 58b6887bf5
commit b386e24f5d
5 changed files with 22 additions and 14 deletions

View File

@@ -20,7 +20,7 @@ from embedchain.cache import (
) )
from embedchain.client import Client from embedchain.client import Client
from embedchain.config import AppConfig, CacheConfig, ChunkerConfig, Mem0Config from embedchain.config import AppConfig, CacheConfig, ChunkerConfig, Mem0Config
from embedchain.core.db.database import get_session, init_db, setup_engine from embedchain.core.db.database import get_session
from embedchain.core.db.models import DataSource from embedchain.core.db.models import DataSource
from embedchain.embedchain import EmbedChain from embedchain.embedchain import EmbedChain
from embedchain.embedder.base import BaseEmbedder from embedchain.embedder.base import BaseEmbedder
@@ -89,10 +89,6 @@ class App(EmbedChain):
if name and config: if name and config:
raise Exception("Cannot provide both name and config. Please provide only one of them.") raise Exception("Cannot provide both name and config. Please provide only one of them.")
# Initialize the metadata db for the app
setup_engine(database_uri=os.environ.get("EMBEDCHAIN_DB_URI"))
init_db()
self.auto_deploy = auto_deploy self.auto_deploy = auto_deploy
# Store the dict config as an attribute to be able to send it # Store the dict config as an attribute to be able to send it
self.config_data = config_data if (config_data and validate_config(config_data)) else None self.config_data = config_data if (config_data and validate_config(config_data)) else None
@@ -389,10 +385,6 @@ class App(EmbedChain):
vector_db = VectorDBFactory.create(vector_db_provider, vector_db_config_data.get("config", {})) vector_db = VectorDBFactory.create(vector_db_provider, vector_db_config_data.get("config", {}))
if llm_config_data: if llm_config_data:
# Initialize the metadata db for the app here since llmfactory needs it for initialization of
# the llm memory
setup_engine(database_uri=os.environ.get("EMBEDCHAIN_DB_URI"))
init_db()
llm_provider = llm_config_data.get("provider", "openai") llm_provider = llm_config_data.get("provider", "openai")
llm = LlmFactory.create(llm_provider, llm_config_data.get("config", {})) llm = LlmFactory.create(llm_provider, llm_config_data.get("config", {}))
else: else:

View File

@@ -1,9 +1,11 @@
import logging import logging
import os
from collections.abc import Generator from collections.abc import Generator
from typing import Any, Optional from typing import Any, Optional
from langchain.schema import BaseMessage as LCBaseMessage from langchain.schema import BaseMessage as LCBaseMessage
from embedchain.constants import SQLITE_PATH
from embedchain.config import BaseLlmConfig from embedchain.config import BaseLlmConfig
from embedchain.config.llm.base import ( from embedchain.config.llm.base import (
DEFAULT_PROMPT, DEFAULT_PROMPT,
@@ -11,6 +13,7 @@ from embedchain.config.llm.base import (
DEFAULT_PROMPT_WITH_MEM0_MEMORY_TEMPLATE, DEFAULT_PROMPT_WITH_MEM0_MEMORY_TEMPLATE,
DOCS_SITE_PROMPT_TEMPLATE, DOCS_SITE_PROMPT_TEMPLATE,
) )
from embedchain.core.db.database import init_db, setup_engine
from embedchain.helpers.json_serializable import JSONSerializable from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.memory.base import ChatHistory from embedchain.memory.base import ChatHistory
from embedchain.memory.message import ChatMessage from embedchain.memory.message import ChatMessage
@@ -30,6 +33,11 @@ class BaseLlm(JSONSerializable):
else: else:
self.config = config self.config = config
# Initialize the metadata db for the app here since llmfactory needs it for initialization of
# the llm memory
setup_engine(database_uri=os.environ.get("EMBEDCHAIN_DB_URI", f"sqlite:///{SQLITE_PATH}"))
init_db()
self.memory = ChatHistory() self.memory = ChatHistory()
self.is_docs_site_instance = False self.is_docs_site_instance = False
self.history: Any = None self.history: Any = None

View File

@@ -2391,18 +2391,18 @@ tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<9.0.0"
[[package]] [[package]]
name = "langchain-aws" name = "langchain-aws"
version = "0.1.10" version = "0.1.13"
description = "An integration package connecting AWS and LangChain" description = "An integration package connecting AWS and LangChain"
optional = true optional = true
python-versions = "<4.0,>=3.8.1" python-versions = "<4.0,>=3.8.1"
files = [ files = [
{file = "langchain_aws-0.1.10-py3-none-any.whl", hash = "sha256:2cba72efaa9f0dc406d8e06a1fbaa3762678d489cbc5147cf64a7012189c161c"}, {file = "langchain_aws-0.1.13-py3-none-any.whl", hash = "sha256:c4db60c8a83b8ff3e66170e0bd646739176fcd1a20a9d0a10828a1e21339af1d"},
{file = "langchain_aws-0.1.10.tar.gz", hash = "sha256:7f01dacbf8345a28192cec4ef31018cc33a91de0b82122f913eec09a76d64fd5"}, {file = "langchain_aws-0.1.13.tar.gz", hash = "sha256:fda790732a72de4ccec3760dba24db5f9fa5cb8724dfd9676a7d5cf87a9f1a98"},
] ]
[package.dependencies] [package.dependencies]
boto3 = ">=1.34.131,<1.35.0" boto3 = ">=1.34.131,<1.35.0"
langchain-core = ">=0.2.6,<0.3" langchain-core = ">=0.2.17,<0.3"
numpy = ">=1,<2" numpy = ">=1,<2"
[[package]] [[package]]

View File

@@ -0,0 +1,8 @@
from unittest import mock
import pytest
# Autouse fixture: applies to every test collected in this conftest's scope
# without tests having to request it explicitly.
@pytest.fixture(autouse=True)
def mock_alembic_command_upgrade():
    """Stub out Alembic's `command.upgrade` for the duration of each test.

    `BaseLlm.__init__` now calls `init_db()` (which runs Alembic migrations)
    for every LLM instance; patching `alembic.command.upgrade` keeps tests
    from touching a real migration environment or database.
    Yields with the patch active; the patch is removed on test teardown.
    """
    with mock.patch("alembic.command.upgrade"):
        yield