[Feature] Add support to use any SQL database as the metadata storage for embedchain apps (#1273)

Deshraj Yadav
2024-02-19 13:04:18 -08:00
committed by GitHub
parent 6c12bc9044
commit 5e2e7fb639
20 changed files with 601 additions and 202 deletions

.gitignore (vendored)

@@ -165,6 +165,7 @@ cython_debug/
 # Database
 db
 test-db
+!embedchain/core/db/
 .vscode
 .idea/

embedchain/__init__.py

@@ -7,4 +7,4 @@ from embedchain.client import Client # noqa: F401
 from embedchain.pipeline import Pipeline  # noqa: F401
 # Setup the user directory if doesn't exist already
-Client.setup_dir()
+Client.setup()

embedchain/alembic.ini (new file)

@@ -0,0 +1,116 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
script_location = embedchain/migrations
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library.
# Any required deps can be installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to alembic/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
sqlalchemy.url = driver://user:pass@localhost/dbname
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = --fix REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = WARN
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
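
For contributors adding schema changes later, a new revision can be generated against this config programmatically. A minimal sketch, assuming it is run from the repository root so the relative script_location resolves; the revision message is illustrative:

    from alembic import command
    from alembic.config import Config

    # Points at the ini file added above; env.py supplies the database URL.
    cfg = Config("embedchain/alembic.ini")
    command.revision(cfg, message="describe your schema change", autogenerate=True)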

embedchain/app.py

@@ -3,7 +3,6 @@ import concurrent.futures
 import json
 import logging
 import os
-import sqlite3
 import uuid
 from typing import Any, Optional, Union
@@ -16,7 +15,8 @@ from embedchain.cache import (Config, ExactMatchEvaluation,
                              gptcache_data_manager, gptcache_pre_function)
 from embedchain.client import Client
 from embedchain.config import AppConfig, CacheConfig, ChunkerConfig
-from embedchain.constants import SQLITE_PATH
+from embedchain.core.db.database import get_session
+from embedchain.core.db.models import DataSource
 from embedchain.embedchain import EmbedChain
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.embedder.openai import OpenAIEmbedder
@@ -33,9 +33,6 @@ from embedchain.utils.misc import validate_config
 from embedchain.vectordb.base import BaseVectorDB
 from embedchain.vectordb.chroma import ChromaDB

-# Set up the user directory if it doesn't exist already
-Client.setup_dir()

 @register_deserializable
 class App(EmbedChain):
@@ -120,6 +117,9 @@ class App(EmbedChain):
         self.llm = llm or OpenAILlm()
         self._init_db()

+        # Session for the metadata db
+        self.db_session = get_session()

         # If cache_config is provided, initializing the cache ...
         if self.cache_config is not None:
             self._init_cache()
@@ -127,27 +127,6 @@ class App(EmbedChain):
         # Send anonymous telemetry
         self._telemetry_props = {"class": self.__class__.__name__}
         self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)

-        # Establish a connection to the SQLite database
-        self.connection = sqlite3.connect(SQLITE_PATH, check_same_thread=False)
-        self.cursor = self.connection.cursor()
-
-        # Create the 'data_sources' table if it doesn't exist
-        self.cursor.execute(
-            """
-            CREATE TABLE IF NOT EXISTS data_sources (
-                pipeline_id TEXT,
-                hash TEXT,
-                type TEXT,
-                value TEXT,
-                metadata TEXT,
-                is_uploaded INTEGER DEFAULT 0,
-                PRIMARY KEY (pipeline_id, hash)
-            )
-            """
-        )
-        self.connection.commit()
-
-        # Send anonymous telemetry
         self.telemetry.capture(event_name="init", properties=self._telemetry_props)
         self.user_asks = []
@@ -307,20 +286,14 @@ class App(EmbedChain):
             return False

     def _mark_data_as_uploaded(self, data_hash):
-        self.cursor.execute(
-            "UPDATE data_sources SET is_uploaded = 1 WHERE hash = ? AND pipeline_id = ?",
-            (data_hash, self.local_id),
-        )
-        self.connection.commit()
+        self.db_session.query(DataSource).filter_by(hash=data_hash, app_id=self.local_id).update({"is_uploaded": 1})

     def get_data_sources(self):
-        db_data = self.cursor.execute("SELECT * FROM data_sources WHERE pipeline_id = ?", (self.local_id,)).fetchall()
-
-        data_sources = []
-        for data in db_data:
-            data_sources.append({"data_type": data[2], "data_value": data[3], "metadata": data[4]})
-
-        return data_sources
+        data_sources = self.db_session.query(DataSource).filter_by(app_id=self.local_id).all()
+        results = []
+        for row in data_sources:
+            results.append({"data_type": row.data_type, "data_value": row.data_value, "metadata": row.metadata})
+        return results

     def deploy(self):
         if self.client is None:
def deploy(self): def deploy(self):
if self.client is None: if self.client is None:
@@ -329,14 +302,11 @@ class App(EmbedChain):
         pipeline_data = self._create_pipeline()
         self.id = pipeline_data["id"]

-        results = self.cursor.execute(
-            "SELECT * FROM data_sources WHERE pipeline_id = ? AND is_uploaded = 0", (self.local_id,)  # noqa:E501
-        ).fetchall()
+        results = self.db_session.query(DataSource).filter_by(app_id=self.local_id, is_uploaded=0).all()
         if len(results) > 0:
             print("🛠️ Adding data to your pipeline...")
             for result in results:
-                data_hash, data_type, data_value = result[1], result[2], result[3]
+                data_hash, data_type, data_value = result.hash, result.data_type, result.data_value
                 self._process_and_upload_data(data_hash, data_type, data_value)

         # Send anonymous telemetry
@@ -423,10 +393,6 @@ class App(EmbedChain):
         else:
             cache_config = None

-        # Send anonymous telemetry
-        event_properties = {"init_type": "config_data"}
-        AnonymousTelemetry().capture(event_name="init", properties=event_properties)
-
         return cls(
             config=app_config,
             llm=llm,

embedchain/client.py

@@ -5,7 +5,8 @@ import uuid
 import requests

-from embedchain.constants import CONFIG_DIR, CONFIG_FILE
+from embedchain.constants import CONFIG_DIR, CONFIG_FILE, DB_URI
+from embedchain.core.db.database import init_db, setup_engine

 class Client:
@@ -31,7 +32,7 @@ class Client:
         )

     @classmethod
-    def setup_dir(cls):
+    def setup(cls):
         """
         Loads the user id from the config file if it exists, otherwise generates a new
         one and saves it to the config file.
@@ -40,6 +41,9 @@ class Client:
         :rtype: str
         """
         os.makedirs(CONFIG_DIR, exist_ok=True)
+        setup_engine(database_uri=DB_URI)
+        init_db()
+
         if os.path.exists(CONFIG_FILE):
             with open(CONFIG_FILE, "r") as f:
                 data = json.load(f)
@@ -53,7 +57,7 @@ class Client:
     @classmethod
     def load_config(cls):
         if not os.path.exists(CONFIG_FILE):
-            cls.setup_dir()
+            cls.setup()
         with open(CONFIG_FILE, "r") as config_file:
             return json.load(config_file)

embedchain/constants.py

@@ -6,3 +6,4 @@ HOME_DIR = str(Path.home())
 CONFIG_DIR = os.path.join(HOME_DIR, ".embedchain")
 CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
 SQLITE_PATH = os.path.join(CONFIG_DIR, "embedchain.db")
+DB_URI = f"sqlite:///{SQLITE_PATH}"
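
DB_URI wraps the existing SQLite path in a SQLAlchemy URI, which is the hook that makes other SQL backends possible: setup_engine() accepts any URI. A sketch of the default wiring that Client.setup() performs:

    from embedchain.constants import DB_URI
    from embedchain.core.db.database import get_session, init_db, setup_engine

    setup_engine(database_uri=DB_URI)  # sqlite:///~/.embedchain/embedchain.db
    init_db()                          # alembic upgrade head
    session = get_session()

In principle another backend is just a different URI (for example a Postgres one), but note two caveats in this revision: setup_engine() hard-codes the SQLite-only connect_args={"check_same_thread": False}, and migrations/env.py pins sqlalchemy.url to DB_URI, so both would need adjusting for a non-SQLite target.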

embedchain/core/__init__.py (new, empty)

embedchain/core/db/__init__.py (new, empty)

embedchain/core/db/database.py (new file)

@@ -0,0 +1,83 @@
import os
from alembic import command
from alembic.config import Config
from sqlalchemy import create_engine
from sqlalchemy.engine.base import Engine
from sqlalchemy.orm import Session as SQLAlchemySession
from sqlalchemy.orm import scoped_session, sessionmaker
from .models import Base
class DatabaseManager:
def __init__(self, database_uri: str = "sqlite:///embedchain.db", echo: bool = False):
self.database_uri = database_uri
self.echo = echo
self.engine: Engine = None
self._session_factory = None
def setup_engine(self) -> None:
"""Initializes the database engine and session factory."""
self.engine = create_engine(self.database_uri, echo=self.echo, connect_args={"check_same_thread": False})
self._session_factory = scoped_session(sessionmaker(bind=self.engine))
Base.metadata.bind = self.engine
def init_db(self) -> None:
"""Creates all tables defined in the Base metadata."""
if not self.engine:
raise RuntimeError("Database engine is not initialized. Call setup_engine() first.")
Base.metadata.create_all(self.engine)
def get_session(self) -> SQLAlchemySession:
"""Provides a session for database operations."""
if not self._session_factory:
raise RuntimeError("Session factory is not initialized. Call setup_engine() first.")
return self._session_factory()
def close_session(self) -> None:
"""Closes the current session."""
if self._session_factory:
self._session_factory.remove()
def execute_transaction(self, transaction_block):
"""Executes a block of code within a database transaction."""
session = self.get_session()
try:
transaction_block(session)
session.commit()
except Exception as e:
session.rollback()
raise e
finally:
self.close_session()
# Singleton pattern to use throughout the application
database_manager = DatabaseManager()
# Convenience functions for backward compatibility and ease of use
def setup_engine(database_uri: str = "sqlite:///embedchain.db", echo: bool = False) -> None:
database_manager.database_uri = database_uri
database_manager.echo = echo
database_manager.setup_engine()
def alembic_upgrade() -> None:
"""Upgrades the database to the latest version."""
alembic_config_path = os.path.join(os.path.dirname(__file__), "..", "..", "alembic.ini")
alembic_cfg = Config(alembic_config_path)
command.upgrade(alembic_cfg, "head")
def init_db() -> None:
alembic_upgrade()
def get_session() -> SQLAlchemySession:
return database_manager.get_session()
def execute_transaction(transaction_block):
database_manager.execute_transaction(transaction_block)
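
A short sketch of how these convenience wrappers compose; the bulk update inside the transaction block is illustrative only:

    from embedchain.constants import DB_URI
    from embedchain.core.db.database import execute_transaction, init_db, setup_engine
    from embedchain.core.db.models import DataSource

    setup_engine(database_uri=DB_URI)
    init_db()

    def mark_everything_uploaded(session):
        # execute_transaction commits on success, rolls back on any exception,
        # and removes the scoped session afterwards.
        session.query(DataSource).update({"is_uploaded": 1})

    execute_transaction(mark_everything_uploaded)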

embedchain/core/db/models.py (new file)

@@ -0,0 +1,31 @@
import uuid
from sqlalchemy import TIMESTAMP, Column, Integer, String, Text, func
from sqlalchemy.orm import declarative_base
Base = declarative_base()
metadata = Base.metadata
class DataSource(Base):
__tablename__ = "ec_data_sources"
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
app_id = Column(Text, index=True)
hash = Column(Text, index=True)
type = Column(Text, index=True)
value = Column(Text)
meta_data = Column(Text, name="metadata")
is_uploaded = Column(Integer, default=0)
class ChatHistory(Base):
__tablename__ = "ec_chat_history"
app_id = Column(String, primary_key=True)
id = Column(String, primary_key=True)
session_id = Column(String, primary_key=True, index=True)
question = Column(Text)
answer = Column(Text)
meta_data = Column(Text, name="metadata")
created_at = Column(TIMESTAMP, default=func.current_timestamp(), index=True)
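
One subtlety in these models: SQLAlchemy's declarative base reserves the name metadata (Base.metadata), so the Python attribute is meta_data while the underlying column is literally named "metadata". Constructor keywords and attribute access go through the attribute name, as in this small sketch with illustrative values:

    import json

    from embedchain.constants import DB_URI
    from embedchain.core.db.database import get_session, init_db, setup_engine
    from embedchain.core.db.models import DataSource

    setup_engine(database_uri=DB_URI)
    init_db()
    session = get_session()

    # The keyword is meta_data (attribute name), mapped to the "metadata" column.
    session.add(DataSource(app_id="demo-app", hash="abc123", type="text",
                           value="hello world", meta_data=json.dumps({"k": "v"})))
    session.commit()

    row = session.query(DataSource).filter_by(app_id="demo-app").first()
    print(json.loads(row.meta_data))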

embedchain/embedchain.py

@@ -1,7 +1,6 @@
 import hashlib
 import json
 import logging
-import sqlite3
 from typing import Any, Optional, Union

 from dotenv import load_dotenv
@@ -13,7 +12,7 @@ from embedchain.cache import (adapt, get_gptcache_session,
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
 from embedchain.config.base_app_config import BaseAppConfig
-from embedchain.constants import SQLITE_PATH
+from embedchain.core.db.models import DataSource
 from embedchain.data_formatter import DataFormatter
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.helpers.json_serializable import JSONSerializable
@@ -21,7 +20,6 @@ from embedchain.llm.base import BaseLlm
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.models.data_type import (DataType, DirectDataType,
                                          IndirectDataType, SpecialDataType)
-from embedchain.telemetry.posthog import AnonymousTelemetry
 from embedchain.utils.misc import detect_datatype, is_valid_json_string
 from embedchain.vectordb.base import BaseVectorDB
@@ -85,30 +83,6 @@ class EmbedChain(JSONSerializable):
         self.user_asks = []
         self.chunker: Optional[ChunkerConfig] = None

-        # Send anonymous telemetry
-        self._telemetry_props = {"class": self.__class__.__name__}
-        self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
-
-        # Establish a connection to the SQLite database
-        self.connection = sqlite3.connect(SQLITE_PATH, check_same_thread=False)
-        self.cursor = self.connection.cursor()
-
-        # Create the 'data_sources' table if it doesn't exist
-        self.cursor.execute(
-            """
-            CREATE TABLE IF NOT EXISTS data_sources (
-                pipeline_id TEXT,
-                hash TEXT,
-                type TEXT,
-                value TEXT,
-                metadata TEXT,
-                is_uploaded INTEGER DEFAULT 0,
-                PRIMARY KEY (pipeline_id, hash)
-            )
-            """
-        )
-        self.connection.commit()
-
-        # Send anonymous telemetry
-        self.telemetry.capture(event_name="init", properties=self._telemetry_props)

     @property
     def collect_metrics(self):
@property @property
def collect_metrics(self): def collect_metrics(self):
@@ -204,17 +178,21 @@ class EmbedChain(JSONSerializable):
         if data_type in {DataType.DOCS_SITE}:
             self.is_docs_site_instance = True

-        # Insert the data into the 'data' table
-        self.cursor.execute(
-            """
-            INSERT OR REPLACE INTO data_sources (hash, pipeline_id, type, value, metadata)
-            VALUES (?, ?, ?, ?, ?)
-            """,
-            (source_hash, self.config.id, data_type.value, str(source), json.dumps(metadata)),
-        )
-
-        # Commit the transaction
-        self.connection.commit()
+        # Insert the data into the 'ec_data_sources' table
+        self.db_session.add(
+            DataSource(
+                hash=source_hash,
+                app_id=self.config.id,
+                type=data_type.value,
+                value=source,
+                metadata=json.dumps(metadata),
+            )
+        )
+        try:
+            self.db_session.commit()
+        except Exception as e:
+            logging.error(f"Error adding data source: {e}")
+            self.db_session.rollback()

         if dry_run:
             data_chunks_info = {"chunks": documents, "metadata": metadatas, "count": len(documents), "type": data_type}
@@ -666,9 +644,14 @@ class EmbedChain(JSONSerializable):
         Resets the database. Deletes all embeddings irreversibly.
         `App` does not have to be reinitialized after using this method.
         """
+        try:
+            self.db_session.query(DataSource).filter_by(app_id=self.config.id).delete()
+            self.db_session.commit()
+        except Exception as e:
+            logging.error(f"Error deleting chat history: {e}")
+            self.db_session.rollback()
+            return None
         self.db.reset()
-        self.cursor.execute("DELETE FROM data_sources WHERE pipeline_id = ?", (self.config.id,))
-        self.connection.commit()
         self.delete_all_chat_history(app_id=self.config.id)
         # Send anonymous telemetry
         self.telemetry.capture(event_name="reset", properties=self._telemetry_props)

embedchain/memory/base.py

@@ -1,55 +1,40 @@
 import json
 import logging
-import sqlite3
 import uuid
 from typing import Any, Optional

-from embedchain.constants import SQLITE_PATH
+from embedchain.core.db.database import get_session
+from embedchain.core.db.models import ChatHistory as ChatHistoryModel
 from embedchain.memory.message import ChatMessage
 from embedchain.memory.utils import merge_metadata_dict

-CHAT_MESSAGE_CREATE_TABLE_QUERY = """
-    CREATE TABLE IF NOT EXISTS ec_chat_history (
-        app_id TEXT,
-        id TEXT,
-        session_id TEXT,
-        question TEXT,
-        answer TEXT,
-        metadata TEXT,
-        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
-        PRIMARY KEY (id, app_id, session_id)
-    )
-"""

 class ChatHistory:
     def __init__(self) -> None:
-        with sqlite3.connect(SQLITE_PATH, check_same_thread=False) as self.connection:
-            self.cursor = self.connection.cursor()
-            self.cursor.execute(CHAT_MESSAGE_CREATE_TABLE_QUERY)
-            self.connection.commit()
+        self.db_session = get_session()

     def add(self, app_id, session_id, chat_message: ChatMessage) -> Optional[str]:
         memory_id = str(uuid.uuid4())
         metadata_dict = merge_metadata_dict(chat_message.human_message.metadata, chat_message.ai_message.metadata)
         if metadata_dict:
             metadata = self._serialize_json(metadata_dict)
-        ADD_CHAT_MESSAGE_QUERY = """
-            INSERT INTO ec_chat_history (app_id, id, session_id, question, answer, metadata)
-            VALUES (?, ?, ?, ?, ?, ?)
-        """
-        self.cursor.execute(
-            ADD_CHAT_MESSAGE_QUERY,
-            (
-                app_id,
-                memory_id,
-                session_id,
-                chat_message.human_message.content,
-                chat_message.ai_message.content,
-                metadata if metadata_dict else "{}",
-            ),
-        )
-        self.connection.commit()
+        self.db_session.add(
+            ChatHistoryModel(
+                app_id=app_id,
+                id=memory_id,
+                session_id=session_id,
+                question=chat_message.human_message.content,
+                answer=chat_message.ai_message.content,
+                metadata=metadata if metadata_dict else "{}",
+            )
+        )
+        try:
+            self.db_session.commit()
+        except Exception as e:
+            logging.error(f"Error adding chat memory to db: {e}")
+            self.db_session.rollback()
+            return None
         logging.info(f"Added chat memory to db with id: {memory_id}")
         return memory_id
@@ -63,15 +48,15 @@ class ChatHistory:
         :return: None
         """
+        params = {"app_id": app_id}
         if session_id:
-            DELETE_CHAT_HISTORY_QUERY = "DELETE FROM ec_chat_history WHERE app_id=? AND session_id=?"
-            params = (app_id, session_id)
-        else:
-            DELETE_CHAT_HISTORY_QUERY = "DELETE FROM ec_chat_history WHERE app_id=?"
-            params = (app_id,)
-
-        self.cursor.execute(DELETE_CHAT_HISTORY_QUERY, params)
-        self.connection.commit()
+            params["session_id"] = session_id
+        self.db_session.query(ChatHistoryModel).filter_by(**params).delete()
+        try:
+            self.db_session.commit()
+        except Exception as e:
+            logging.error(f"Error deleting chat history: {e}")
+            self.db_session.rollback()
def get( def get(
self, app_id, session_id: str = "default", num_rounds=10, fetch_all: bool = False, display_format=False self, app_id, session_id: str = "default", num_rounds=10, fetch_all: bool = False, display_format=False
@@ -85,50 +70,31 @@ class ChatHistory:
         param: fetch_all (optional) - Whether to fetch all chat history or not. Defaults to False
         param: display_format (optional) - Whether to return the chat history in display format. Defaults to False
         """
-        base_query = """
-            SELECT * FROM ec_chat_history
-            WHERE app_id=?
-        """
-        if fetch_all:
-            additional_query = "ORDER BY created_at ASC"
-            params = (app_id,)
-        else:
-            additional_query = """
-                AND session_id=?
-                ORDER BY created_at ASC
-                LIMIT ?
-            """
-            params = (app_id, session_id, num_rounds)
-
-        QUERY = base_query + additional_query
-        self.cursor.execute(
-            QUERY,
-            params,
-        )
-
-        results = self.cursor.fetchall()
+        params = {"app_id": app_id}
+        if not fetch_all:
+            params["session_id"] = session_id
+        results = (
+            self.db_session.query(ChatHistoryModel).filter_by(**params).order_by(ChatHistoryModel.created_at.asc())
+        )
+        results = results.limit(num_rounds) if not fetch_all else results
         history = []
         for result in results:
-            app_id, _, session_id, question, answer, metadata, timestamp = result
-            metadata = self._deserialize_json(metadata=metadata)
+            metadata = self._deserialize_json(metadata=result.meta_data or "{}")
             # Return list of dict if display_format is True
             if display_format:
                 history.append(
                     {
-                        "session_id": session_id,
-                        "human": question,
-                        "ai": answer,
-                        "metadata": metadata,
-                        "timestamp": timestamp,
+                        "session_id": result.session_id,
+                        "human": result.question,
+                        "ai": result.answer,
+                        "metadata": result.meta_data,
+                        "timestamp": result.created_at,
                     }
                 )
             else:
                 memory = ChatMessage()
-                memory.add_user_message(question, metadata=metadata)
-                memory.add_ai_message(answer, metadata=metadata)
+                memory.add_user_message(result.question, metadata=metadata)
+                memory.add_ai_message(result.answer, metadata=metadata)
                 history.append(memory)
         return history
@@ -141,16 +107,11 @@ class ChatHistory:
         :return: The number of chat messages for a given app_id and session_id
         """
-        # Rewrite the logic below with sqlalchemy
+        params = {"app_id": app_id}
         if session_id:
-            QUERY = "SELECT COUNT(*) FROM ec_chat_history WHERE app_id=? AND session_id=?"
-            params = (app_id, session_id)
-        else:
-            QUERY = "SELECT COUNT(*) FROM ec_chat_history WHERE app_id=?"
-            params = (app_id,)
-
-        self.cursor.execute(QUERY, params)
-        count = self.cursor.fetchone()[0]
-        return count
+            params["session_id"] = session_id
+        return self.db_session.query(ChatHistoryModel).filter_by(**params).count()

     @staticmethod
     def _serialize_json(metadata: dict[str, Any]):
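
A usage sketch of the reworked ChatHistory in isolation, with import paths and the count method name inferred from this diff; Client.setup() must run first so get_session() has an engine behind it:

    from embedchain.client import Client
    from embedchain.memory.base import ChatHistory
    from embedchain.memory.message import ChatMessage

    Client.setup()  # wires the engine to DB_URI and applies migrations

    history = ChatHistory()
    message = ChatMessage()
    message.add_user_message("What is embedchain?")
    message.add_ai_message("An open-source RAG framework.")
    history.add(app_id="demo-app", session_id="default", chat_message=message)

    print(history.count(app_id="demo-app", session_id="default"))
    for entry in history.get(app_id="demo-app", session_id="default", display_format=True):
        print(entry["human"], "->", entry["ai"])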

embedchain/migrations/env.py (new file)

@@ -0,0 +1,74 @@
from logging.config import fileConfig
from alembic import context
from sqlalchemy import engine_from_config, pool
from embedchain.constants import DB_URI
from embedchain.core.db.models import Base
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
config.set_main_option("sqlalchemy.url", DB_URI)
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

embedchain/migrations/script.py.mako (new file)

@@ -0,0 +1,26 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}

embedchain/migrations/versions/ (initial migration, revision 40a327b3debd, new file)

@@ -0,0 +1,62 @@
"""Create initial migrations
Revision ID: 40a327b3debd
Revises:
Create Date: 2024-02-18 15:29:19.409064
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "40a327b3debd"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"ec_chat_history",
sa.Column("app_id", sa.String(), nullable=False),
sa.Column("id", sa.String(), nullable=False),
sa.Column("session_id", sa.String(), nullable=False),
sa.Column("question", sa.Text(), nullable=True),
sa.Column("answer", sa.Text(), nullable=True),
sa.Column("metadata", sa.Text(), nullable=True),
sa.Column("created_at", sa.TIMESTAMP(), nullable=True),
sa.PrimaryKeyConstraint("app_id", "id", "session_id"),
)
op.create_index(op.f("ix_ec_chat_history_created_at"), "ec_chat_history", ["created_at"], unique=False)
op.create_index(op.f("ix_ec_chat_history_session_id"), "ec_chat_history", ["session_id"], unique=False)
op.create_table(
"ec_data_sources",
sa.Column("id", sa.String(), nullable=False),
sa.Column("app_id", sa.Text(), nullable=True),
sa.Column("hash", sa.Text(), nullable=True),
sa.Column("type", sa.Text(), nullable=True),
sa.Column("value", sa.Text(), nullable=True),
sa.Column("metadata", sa.Text(), nullable=True),
sa.Column("is_uploaded", sa.Integer(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f("ix_ec_data_sources_hash"), "ec_data_sources", ["hash"], unique=False)
op.create_index(op.f("ix_ec_data_sources_app_id"), "ec_data_sources", ["app_id"], unique=False)
op.create_index(op.f("ix_ec_data_sources_type"), "ec_data_sources", ["type"], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f("ix_ec_data_sources_type"), table_name="ec_data_sources")
op.drop_index(op.f("ix_ec_data_sources_app_id"), table_name="ec_data_sources")
op.drop_index(op.f("ix_ec_data_sources_hash"), table_name="ec_data_sources")
op.drop_table("ec_data_sources")
op.drop_index(op.f("ix_ec_chat_history_session_id"), table_name="ec_chat_history")
op.drop_index(op.f("ix_ec_chat_history_created_at"), table_name="ec_chat_history")
op.drop_table("ec_chat_history")
# ### end Alembic commands ###
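
To confirm the revision applied cleanly, the resulting schema can be inspected; a minimal sketch against the default SQLite target:

    from sqlalchemy import create_engine, inspect

    from embedchain.constants import DB_URI
    from embedchain.core.db.database import init_db, setup_engine

    setup_engine(database_uri=DB_URI)
    init_db()  # runs "alembic upgrade head", i.e. revision 40a327b3debd

    inspector = inspect(create_engine(DB_URI))
    print(inspector.get_table_names())  # expect alembic_version, ec_chat_history, ec_data_sources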

embedchain/store/assistants.py

@@ -20,7 +20,7 @@ from embedchain.utils.misc import detect_datatype
 logging.basicConfig(level=logging.WARN)

 # Set up the user directory if it doesn't exist already
-Client.setup_dir()
+Client.setup()

 class OpenAIAssistant:

poetry.lock (generated)

@@ -161,6 +161,25 @@ files = [
 [package.dependencies]
 typing-extensions = "*"
[[package]]
name = "alembic"
version = "1.13.1"
description = "A database migration tool for SQLAlchemy."
optional = false
python-versions = ">=3.8"
files = [
{file = "alembic-1.13.1-py3-none-any.whl", hash = "sha256:2edcc97bed0bd3272611ce3a98d98279e9c209e7186e43e75bbb1b2bdfdbcc43"},
{file = "alembic-1.13.1.tar.gz", hash = "sha256:4932c8558bf68f2ee92b9bbcb8218671c627064d5b08939437af6d77dc05e595"},
]
[package.dependencies]
Mako = "*"
SQLAlchemy = ">=1.3.0"
typing-extensions = ">=4"
[package.extras]
tz = ["backports.zoneinfo"]
 [[package]]
 name = "annotated-types"
 version = "0.6.0"
@@ -3334,6 +3353,25 @@ html5 = ["html5lib"]
 html5 = ["html5lib"]
 htmlsoup = ["BeautifulSoup4"]
 source = ["Cython (>=0.29.35)"]
[[package]]
name = "mako"
version = "1.3.2"
description = "A super-fast templating language that borrows the best ideas from the existing templating languages."
optional = false
python-versions = ">=3.8"
files = [
{file = "Mako-1.3.2-py3-none-any.whl", hash = "sha256:32a99d70754dfce237019d17ffe4a282d2d3351b9c476e90d8a60e63f133b80c"},
{file = "Mako-1.3.2.tar.gz", hash = "sha256:2a0c8ad7f6274271b3bb7467dd37cf9cc6dab4bc19cb69a4ef10669402de698e"},
]
[package.dependencies]
MarkupSafe = ">=0.9.2"
[package.extras]
babel = ["Babel"]
lingua = ["lingua"]
testing = ["pytest"]
 [[package]]
 name = "markdown"
 version = "3.5"
@@ -3380,7 +3418,7 @@ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
name = "markupsafe" name = "markupsafe"
version = "2.1.3" version = "2.1.3"
description = "Safely add untrusted strings to HTML/XML markup." description = "Safely add untrusted strings to HTML/XML markup."
optional = true optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
{file = "MarkupSafe-2.1.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cd0f502fe016460680cd20aaa5a76d241d6f35a1c3350c474bac1273803893fa"}, {file = "MarkupSafe-2.1.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cd0f502fe016460680cd20aaa5a76d241d6f35a1c3350c474bac1273803893fa"},
@@ -6701,28 +6739,70 @@ files = [
 [[package]]
 name = "sqlalchemy"
-version = "2.0.22"
+version = "2.0.27"
 description = "Database Abstraction Library"
 optional = false
 python-versions = ">=3.7"
 files = [
-    {file = "SQLAlchemy-2.0.22-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4f6ff392b27a743c1ad346d215655503cec64405d3b694228b3454878bf21590"},
-    {file = "SQLAlchemy-2.0.22-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f776c2c30f0e5f4db45c3ee11a5f2a8d9de68e81eb73ec4237de1e32e04ae81c"},
-    {file = "SQLAlchemy-2.0.22-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:2c9bac865ee06d27a1533471405ad240a6f5d83195eca481f9fc4a71d8b87df8"},
-    {file = "SQLAlchemy-2.0.22-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:625b72d77ac8ac23da3b1622e2da88c4aedaee14df47c8432bf8f6495e655de2"},
-    {file = "SQLAlchemy-2.0.22-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:3940677d341f2b685a999bffe7078697b5848a40b5f6952794ffcf3af150c301"},
-    {file = "SQLAlchemy-2.0.22-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:3aa1472bf44f61dd27987cd051f1c893b7d3b17238bff8c23fceaef4f1133868"},
-    {file = "SQLAlchemy-2.0.22-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:56a7e2bb639df9263bf6418231bc2a92a773f57886d371ddb7a869a24919face"},
-    {file = "SQLAlchemy-2.0.22.tar.gz", hash = "sha256:5434cc601aa17570d79e5377f5fd45ff92f9379e2abed0be5e8c2fba8d353d2b"},
+    {file = "SQLAlchemy-2.0.27-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d04e579e911562f1055d26dab1868d3e0bb905db3bccf664ee8ad109f035618a"},
+    {file = "SQLAlchemy-2.0.27-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fa67d821c1fd268a5a87922ef4940442513b4e6c377553506b9db3b83beebbd8"},
+    {file = "SQLAlchemy-2.0.27-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c7a596d0be71b7baa037f4ac10d5e057d276f65a9a611c46970f012752ebf2d"},
+    {file = "SQLAlchemy-2.0.27-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:954d9735ee9c3fa74874c830d089a815b7b48df6f6b6e357a74130e478dbd951"},
+    {file = "SQLAlchemy-2.0.27-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:5cd20f58c29bbf2680039ff9f569fa6d21453fbd2fa84dbdb4092f006424c2e6"},
+    {file = "SQLAlchemy-2.0.27-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:03f448ffb731b48323bda68bcc93152f751436ad6037f18a42b7e16af9e91c07"},
+    {file = "SQLAlchemy-2.0.27-cp310-cp310-win32.whl", hash = "sha256:d997c5938a08b5e172c30583ba6b8aad657ed9901fc24caf3a7152eeccb2f1b4"},
+    {file = "SQLAlchemy-2.0.27-cp310-cp310-win_amd64.whl", hash = "sha256:eb15ef40b833f5b2f19eeae65d65e191f039e71790dd565c2af2a3783f72262f"},
{file = "SQLAlchemy-2.0.27-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6c5bad7c60a392850d2f0fee8f355953abaec878c483dd7c3836e0089f046bf6"},
{file = "SQLAlchemy-2.0.27-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a3012ab65ea42de1be81fff5fb28d6db893ef978950afc8130ba707179b4284a"},
{file = "SQLAlchemy-2.0.27-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dbcd77c4d94b23e0753c5ed8deba8c69f331d4fd83f68bfc9db58bc8983f49cd"},
{file = "SQLAlchemy-2.0.27-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d177b7e82f6dd5e1aebd24d9c3297c70ce09cd1d5d37b43e53f39514379c029c"},
{file = "SQLAlchemy-2.0.27-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:680b9a36029b30cf063698755d277885d4a0eab70a2c7c6e71aab601323cba45"},
{file = "SQLAlchemy-2.0.27-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1306102f6d9e625cebaca3d4c9c8f10588735ef877f0360b5cdb4fdfd3fd7131"},
{file = "SQLAlchemy-2.0.27-cp311-cp311-win32.whl", hash = "sha256:5b78aa9f4f68212248aaf8943d84c0ff0f74efc65a661c2fc68b82d498311fd5"},
{file = "SQLAlchemy-2.0.27-cp311-cp311-win_amd64.whl", hash = "sha256:15e19a84b84528f52a68143439d0c7a3a69befcd4f50b8ef9b7b69d2628ae7c4"},
{file = "SQLAlchemy-2.0.27-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:0de1263aac858f288a80b2071990f02082c51d88335a1db0d589237a3435fe71"},
{file = "SQLAlchemy-2.0.27-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce850db091bf7d2a1f2fdb615220b968aeff3849007b1204bf6e3e50a57b3d32"},
{file = "SQLAlchemy-2.0.27-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8dfc936870507da96aebb43e664ae3a71a7b96278382bcfe84d277b88e379b18"},
{file = "SQLAlchemy-2.0.27-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c4fbe6a766301f2e8a4519f4500fe74ef0a8509a59e07a4085458f26228cd7cc"},
{file = "SQLAlchemy-2.0.27-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:4535c49d961fe9a77392e3a630a626af5baa967172d42732b7a43496c8b28876"},
{file = "SQLAlchemy-2.0.27-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:0fb3bffc0ced37e5aa4ac2416f56d6d858f46d4da70c09bb731a246e70bff4d5"},
{file = "SQLAlchemy-2.0.27-cp312-cp312-win32.whl", hash = "sha256:7f470327d06400a0aa7926b375b8e8c3c31d335e0884f509fe272b3c700a7254"},
{file = "SQLAlchemy-2.0.27-cp312-cp312-win_amd64.whl", hash = "sha256:f9374e270e2553653d710ece397df67db9d19c60d2647bcd35bfc616f1622dcd"},
{file = "SQLAlchemy-2.0.27-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e97cf143d74a7a5a0f143aa34039b4fecf11343eed66538610debc438685db4a"},
{file = "SQLAlchemy-2.0.27-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7b5a3e2120982b8b6bd1d5d99e3025339f7fb8b8267551c679afb39e9c7c7f1"},
{file = "SQLAlchemy-2.0.27-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e36aa62b765cf9f43a003233a8c2d7ffdeb55bc62eaa0a0380475b228663a38f"},
{file = "SQLAlchemy-2.0.27-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:5ada0438f5b74c3952d916c199367c29ee4d6858edff18eab783b3978d0db16d"},
{file = "SQLAlchemy-2.0.27-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:b1d9d1bfd96eef3c3faedb73f486c89e44e64e40e5bfec304ee163de01cf996f"},
{file = "SQLAlchemy-2.0.27-cp37-cp37m-win32.whl", hash = "sha256:ca891af9f3289d24a490a5fde664ea04fe2f4984cd97e26de7442a4251bd4b7c"},
{file = "SQLAlchemy-2.0.27-cp37-cp37m-win_amd64.whl", hash = "sha256:fd8aafda7cdff03b905d4426b714601c0978725a19efc39f5f207b86d188ba01"},
{file = "SQLAlchemy-2.0.27-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ec1f5a328464daf7a1e4e385e4f5652dd9b1d12405075ccba1df842f7774b4fc"},
{file = "SQLAlchemy-2.0.27-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ad862295ad3f644e3c2c0d8b10a988e1600d3123ecb48702d2c0f26771f1c396"},
{file = "SQLAlchemy-2.0.27-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48217be1de7d29a5600b5c513f3f7664b21d32e596d69582be0a94e36b8309cb"},
{file = "SQLAlchemy-2.0.27-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e56afce6431450442f3ab5973156289bd5ec33dd618941283847c9fd5ff06bf"},
{file = "SQLAlchemy-2.0.27-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:611068511b5531304137bcd7fe8117c985d1b828eb86043bd944cebb7fae3910"},
{file = "SQLAlchemy-2.0.27-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b86abba762ecfeea359112b2bb4490802b340850bbee1948f785141a5e020de8"},
{file = "SQLAlchemy-2.0.27-cp38-cp38-win32.whl", hash = "sha256:30d81cc1192dc693d49d5671cd40cdec596b885b0ce3b72f323888ab1c3863d5"},
{file = "SQLAlchemy-2.0.27-cp38-cp38-win_amd64.whl", hash = "sha256:120af1e49d614d2525ac247f6123841589b029c318b9afbfc9e2b70e22e1827d"},
{file = "SQLAlchemy-2.0.27-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d07ee7793f2aeb9b80ec8ceb96bc8cc08a2aec8a1b152da1955d64e4825fcbac"},
{file = "SQLAlchemy-2.0.27-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:cb0845e934647232b6ff5150df37ceffd0b67b754b9fdbb095233deebcddbd4a"},
{file = "SQLAlchemy-2.0.27-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fc19ae2e07a067663dd24fca55f8ed06a288384f0e6e3910420bf4b1270cc51"},
{file = "SQLAlchemy-2.0.27-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b90053be91973a6fb6020a6e44382c97739736a5a9d74e08cc29b196639eb979"},
{file = "SQLAlchemy-2.0.27-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:2f5c9dfb0b9ab5e3a8a00249534bdd838d943ec4cfb9abe176a6c33408430230"},
{file = "SQLAlchemy-2.0.27-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:33e8bde8fff203de50399b9039c4e14e42d4d227759155c21f8da4a47fc8053c"},
{file = "SQLAlchemy-2.0.27-cp39-cp39-win32.whl", hash = "sha256:d873c21b356bfaf1589b89090a4011e6532582b3a8ea568a00e0c3aab09399dd"},
{file = "SQLAlchemy-2.0.27-cp39-cp39-win_amd64.whl", hash = "sha256:ff2f1b7c963961d41403b650842dc2039175b906ab2093635d8319bef0b7d620"},
{file = "SQLAlchemy-2.0.27-py3-none-any.whl", hash = "sha256:1ab4e0448018d01b142c916cc7119ca573803a4745cfe341b8f95657812700ac"},
{file = "SQLAlchemy-2.0.27.tar.gz", hash = "sha256:86a6ed69a71fe6b88bf9331594fa390a2adda4a49b5c06f98e47bf0d392534f8"},
 ]
 [package.dependencies]
 greenlet = {version = "!=0.4.17", markers = "platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\""}
-typing-extensions = ">=4.2.0"
+typing-extensions = ">=4.6.0"

 [package.extras]
 aiomysql = ["aiomysql (>=0.2.0)", "greenlet (!=0.4.17)"]
+aioodbc = ["aioodbc", "greenlet (!=0.4.17)"]
-aiosqlite = ["aiosqlite", "greenlet (!=0.4.17)", "typing-extensions (!=3.10.0.1)"]
+aiosqlite = ["aiosqlite", "greenlet (!=0.4.17)", "typing_extensions (!=3.10.0.1)"]
 asyncio = ["greenlet (!=0.4.17)"]
 asyncmy = ["asyncmy (>=0.2.3,!=0.2.4,!=0.2.6)", "greenlet (!=0.4.17)"]
 mariadb-connector = ["mariadb (>=1.0.1,!=1.1.2,!=1.1.5)"]
@@ -6732,7 +6812,7 @@ mssql-pyodbc = ["pyodbc"]
 mypy = ["mypy (>=0.910)"]
 mysql = ["mysqlclient (>=1.4.0)"]
 mysql-connector = ["mysql-connector-python"]
-oracle = ["cx-oracle (>=7)"]
+oracle = ["cx_oracle (>=8)"]
 oracle-oracledb = ["oracledb (>=1.0.1)"]
 postgresql = ["psycopg2 (>=2.7)"]
 postgresql-asyncpg = ["asyncpg", "greenlet (!=0.4.17)"]
@@ -6742,7 +6822,7 @@ postgresql-psycopg2binary = ["psycopg2-binary"]
 postgresql-psycopg2cffi = ["psycopg2cffi"]
 postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"]
 pymysql = ["pymysql"]
-sqlcipher = ["sqlcipher3-binary"]
+sqlcipher = ["sqlcipher3_binary"]
 [[package]]
 name = "sse-starlette"
@@ -8362,4 +8442,4 @@ youtube = ["youtube-transcript-api", "yt_dlp"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,<3.12"
-content-hash = "f613dc1a3e9b724c95b407d4d8b9e67518e718142c77ad4723b7cb1e43eec9db"
+content-hash = "e62b0c29fbd4b394814cff9a59ed4a22e0e8b09ec9a9bf8fc34533cee810a2f3"

pyproject.toml

@@ -154,6 +154,8 @@ boto3 = { version = "^1.34.20", optional = true }
 langchain-mistralai = { version = "^0.0.3", optional = true }
 langchain-openai = "^0.0.5"
 langchain-google-vertexai = { version = "^0.0.5", optional = true }
+sqlalchemy = "^2.0.27"
+alembic = "^1.13.1"

 [tool.poetry.group.dev.dependencies]
 black = "^23.3.0"

tests/conftest.py

@@ -1,19 +1,31 @@
 import os

 import pytest
+from sqlalchemy import MetaData, create_engine
+from sqlalchemy.orm import sessionmaker

+@pytest.fixture(autouse=True)
 def clean_db():
     db_path = os.path.expanduser("~/.embedchain/embedchain.db")
-    if os.path.exists(db_path):
-        os.remove(db_path)
-
-
-@pytest.fixture
-def setup():
-    clean_db()
-    yield
-    clean_db()
+    db_url = f"sqlite:///{db_path}"
+    engine = create_engine(db_url)
+    metadata = MetaData()
+    metadata.reflect(bind=engine)  # Reflect schema from the engine
+    Session = sessionmaker(bind=engine)
+    session = Session()
+    try:
+        # Iterate over all tables in reversed order to respect foreign keys
+        for table in reversed(metadata.sorted_tables):
+            if table.name != "alembic_version":  # Skip the Alembic version table
+                session.execute(table.delete())
+        session.commit()
+    except Exception as e:
+        session.rollback()
+        print(f"Error cleaning database: {e}")
+    finally:
+        session.close()

 @pytest.fixture(autouse=True)


@@ -27,10 +27,7 @@ def test_whole_app(app_instance, mocker):
mocker.patch.object(BaseLlm, "get_answer_from_llm", return_value=knowledge) mocker.patch.object(BaseLlm, "get_answer_from_llm", return_value=knowledge)
mocker.patch.object(BaseLlm, "get_llm_model_answer", return_value=knowledge) mocker.patch.object(BaseLlm, "get_llm_model_answer", return_value=knowledge)
mocker.patch.object(BaseLlm, "generate_prompt") mocker.patch.object(BaseLlm, "generate_prompt")
mocker.patch.object( mocker.patch.object(BaseLlm, "add_history")
BaseLlm,
"add_history",
)
mocker.patch.object(ChatHistory, "delete", autospec=True) mocker.patch.object(ChatHistory, "delete", autospec=True)
app_instance.add(knowledge, data_type="text") app_instance.add(knowledge, data_type="text")