From 752f638cfcc8fe41df7676056b45e795eeeca340 Mon Sep 17 00:00:00 2001 From: Deshraj Yadav Date: Mon, 26 Feb 2024 13:18:42 -0800 Subject: [PATCH] [Feature/Improvements] Delete data sources from metadata db when using `app.delete()` (#1286) --- embedchain/config/vectordb/pinecone.py | 2 ++ embedchain/embedchain.py | 19 +++++++++++++++---- embedchain/vectordb/pinecone.py | 3 +-- pyproject.toml | 2 +- 4 files changed, 19 insertions(+), 7 deletions(-) diff --git a/embedchain/config/vectordb/pinecone.py b/embedchain/config/vectordb/pinecone.py index f82da24d..f021639d 100644 --- a/embedchain/config/vectordb/pinecone.py +++ b/embedchain/config/vectordb/pinecone.py @@ -16,6 +16,7 @@ class PineconeDBConfig(BaseVectorDbConfig): pod_config: Optional[dict[str, any]] = None, serverless_config: Optional[dict[str, any]] = None, hybrid_search: bool = False, + bm25_encoder: any = None, **extra_params: dict[str, any], ): self.metric = metric @@ -24,6 +25,7 @@ class PineconeDBConfig(BaseVectorDbConfig): self.vector_dimension = vector_dimension self.extra_params = extra_params self.hybrid_search = hybrid_search + self.bm25_encoder = bm25_encoder if pod_config is None and serverless_config is None: # If no config is provided, use the default pod spec config pod_environment = os.environ.get("PINECONE_ENV", "gcp-starter") diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index c8e27f39..639c3ad6 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -6,17 +6,20 @@ from typing import Any, Optional, Union from dotenv import load_dotenv from langchain.docstore.document import Document -from embedchain.cache import adapt, get_gptcache_session, gptcache_data_convert, gptcache_update_cache_callback +from embedchain.cache import (adapt, get_gptcache_session, + gptcache_data_convert, + gptcache_update_cache_callback) from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig from embedchain.config.base_app_config import BaseAppConfig -from embedchain.core.db.models import DataSource +from embedchain.core.db.models import ChatHistory, DataSource from embedchain.data_formatter import DataFormatter from embedchain.embedder.base import BaseEmbedder from embedchain.helpers.json_serializable import JSONSerializable from embedchain.llm.base import BaseLlm from embedchain.loaders.base_loader import BaseLoader -from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType +from embedchain.models.data_type import (DataType, DirectDataType, + IndirectDataType, SpecialDataType) from embedchain.utils.misc import detect_datatype, is_valid_json_string from embedchain.vectordb.base import BaseVectorDB @@ -642,9 +645,10 @@ class EmbedChain(JSONSerializable): """ try: self.db_session.query(DataSource).filter_by(app_id=self.config.id).delete() + self.db_session.query(ChatHistory).filter_by(app_id=self.config.id).delete() self.db_session.commit() except Exception as e: - logging.error(f"Error deleting chat history: {e}") + logging.error(f"Error deleting data sources: {e}") self.db_session.rollback() return None self.db.reset() @@ -682,6 +686,13 @@ class EmbedChain(JSONSerializable): :param source_hash: The hash of the source. :type source_hash: str """ + try: + self.db_session.query(DataSource).filter_by(hash=source_id, app_id=self.config.id).delete() + self.db_session.commit() + except Exception as e: + logging.error(f"Error deleting data sources: {e}") + self.db_session.rollback() + return None self.db.delete(where={"hash": source_id}) logging.info(f"Successfully deleted {source_id}") # Send anonymous telemetry diff --git a/embedchain/vectordb/pinecone.py b/embedchain/vectordb/pinecone.py index 24c030eb..23fd64ce 100644 --- a/embedchain/vectordb/pinecone.py +++ b/embedchain/vectordb/pinecone.py @@ -49,9 +49,8 @@ class PineconeDB(BaseVectorDB): # Setup BM25Encoder if sparse vectors are to be used self.bm25_encoder = None if self.config.hybrid_search: - # TODO: Add support for fitting BM25Encoder on any corpus logging.info("Initializing BM25Encoder for sparse vectors..") - self.bm25_encoder = BM25Encoder.default() + self.bm25_encoder = self.config.bm25_encoder if self.config.bm25_encoder else BM25Encoder.default() # Call parent init here because embedder is needed super().__init__(config=self.config) diff --git a/pyproject.toml b/pyproject.toml index 9c9fb780..f1362687 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "embedchain" -version = "0.1.86" +version = "0.1.87" description = "Simplest open source retrieval(RAG) framework" authors = [ "Taranjeet Singh ",