diff --git a/docs/advanced/vector_database.mdx b/docs/advanced/vector_database.mdx
new file mode 100644
index 00000000..b3cdda86
--- /dev/null
+++ b/docs/advanced/vector_database.mdx
@@ -0,0 +1,34 @@
+---
+title: '💾 Vector Database'
+---
+
+We support `Chroma` and `Elasticsearch` as vector databases.
+`Chroma` is used as the default database.
+
+### Elasticsearch
+To use `Elasticsearch` as the vector database, use the `CustomApp` app type.
+```python
+import os
+from embedchain import CustomApp
+from embedchain.config import CustomAppConfig, ElasticsearchDBConfig
+from embedchain.models import Providers, EmbeddingFunctions, VectorDatabases
+
+os.environ["OPENAI_API_KEY"] = 'OPENAI_API_KEY'
+
+es_config = ElasticsearchDBConfig(
+    # Elasticsearch URL, or a list of node URLs with different hosts and ports.
+    es_url='http://localhost:9200',
+    # Pass named parameters supported by the Python Elasticsearch client.
+    ca_certs="/path/to/http_ca.crt",
+    basic_auth=("username", "password")
+)
+config = CustomAppConfig(
+    embedding_fn=EmbeddingFunctions.OPENAI,
+    provider=Providers.OPENAI,
+    db_type=VectorDatabases.ELASTICSEARCH,
+    es_config=es_config,
+)
+es_app = CustomApp(config)
+```
+- Set `db_type=VectorDatabases.ELASTICSEARCH` and `es_config=ElasticsearchDBConfig(es_url='')` in `CustomAppConfig`.
+- `ElasticsearchDBConfig` accepts `es_url` as a single Elasticsearch URL or as a list of node URLs with different hosts and ports. Additionally, we can pass named parameters supported by the Python Elasticsearch client.
\ No newline at end of file
diff --git a/docs/mint.json b/docs/mint.json
index cf73769f..750afc0b 100644
--- a/docs/mint.json
+++ b/docs/mint.json
@@ -32,7 +32,7 @@
     },
     {
       "group": "Advanced",
-      "pages": ["advanced/app_types", "advanced/interface_types", "advanced/adding_data","advanced/data_types", "advanced/query_configuration", "advanced/configuration", "advanced/testing", "advanced/showcase"]
+      "pages": ["advanced/app_types", "advanced/interface_types", "advanced/adding_data","advanced/data_types", "advanced/query_configuration", "advanced/configuration", "advanced/testing", "advanced/vector_database", "advanced/showcase"]
     },
     {
       "group": "Examples",
diff --git a/embedchain/config/__init__.py b/embedchain/config/__init__.py
index 9bd12b26..684b63fa 100644
--- a/embedchain/config/__init__.py
+++ b/embedchain/config/__init__.py
@@ -5,3 +5,4 @@ from .apps.OpenSourceAppConfig import OpenSourceAppConfig  # noqa: F401
 from .BaseConfig import BaseConfig  # noqa: F401
 from .ChatConfig import ChatConfig  # noqa: F401
 from .QueryConfig import QueryConfig  # noqa: F401
+from .vectordbs.ElasticsearchDBConfig import ElasticsearchDBConfig  # noqa: F401
diff --git a/embedchain/config/apps/BaseAppConfig.py b/embedchain/config/apps/BaseAppConfig.py
index c1f0daa6..4b85c1a4 100644
--- a/embedchain/config/apps/BaseAppConfig.py
+++ b/embedchain/config/apps/BaseAppConfig.py
@@ -1,6 +1,8 @@
 import logging
 
 from embedchain.config.BaseConfig import BaseConfig
+from embedchain.config.vectordbs import ElasticsearchDBConfig
+from embedchain.models import VectorDatabases, VectorDimensions
 
 
 class BaseAppConfig(BaseConfig):
@@ -8,7 +10,19 @@ class BaseAppConfig(BaseConfig):
     Parent config to initialize an instance of `App`, `OpenSourceApp` or `CustomApp`.
""" - def __init__(self, log_level=None, embedding_fn=None, db=None, host=None, port=None, id=None, collection_name=None): + def __init__( + self, + log_level=None, + embedding_fn=None, + db=None, + host=None, + port=None, + id=None, + collection_name=None, + db_type: VectorDatabases = None, + vector_dim: VectorDimensions = None, + es_config: ElasticsearchDBConfig = None, + ): """ :param log_level: Optional. (String) Debug level ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']. @@ -18,27 +32,53 @@ class BaseAppConfig(BaseConfig): :param port: Optional. Port for the database server. :param id: Optional. ID of the app. Document metadata will have this id. :param collection_name: Optional. Collection name for the database. + :param db_type: Optional. type of Vector database to use + :param vector_dim: Vector dimension generated by embedding fn + :param es_config: Optional. elasticsearch database config to be used for connection """ self._setup_logging(log_level) - - self.db = db if db else BaseAppConfig.default_db(embedding_fn=embedding_fn, host=host, port=port) self.collection_name = collection_name if collection_name else "embedchain_store" + self.db = BaseAppConfig.get_db( + db=db, + embedding_fn=embedding_fn, + host=host, + port=port, + db_type=db_type, + vector_dim=vector_dim, + collection_name=self.collection_name, + es_config=es_config, + ) self.id = id return @staticmethod - def default_db(embedding_fn, host, port): + def get_db(db, embedding_fn, host, port, db_type, vector_dim, collection_name, es_config): """ - Sets database to default (`ChromaDb`). - + Get db based on db_type, db with default database (`ChromaDb`) + :param Optional. (Vector) database to use for embeddings. :param embedding_fn: Embedding function to use in database. :param host: Optional. Hostname for the database server. :param port: Optional. Port for the database server. - :returns: Default database + :param db_type: Optional. db type to use. Supported values (`es`, `chroma`) + :param vector_dim: Vector dimension generated by embedding fn + :param collection_name: Optional. Collection name for the database. + :param es_config: Optional. elasticsearch database config to be used for connection :raises ValueError: BaseAppConfig knows no default embedding function. 
+        :returns: The database instance.
         """
+        if db:
+            return db
+
         if embedding_fn is None:
             raise ValueError("ChromaDb cannot be instantiated without an embedding function")
+
+        if db_type == VectorDatabases.ELASTICSEARCH:
+            from embedchain.vectordb.elasticsearch_db import ElasticsearchDB
+
+            return ElasticsearchDB(
+                embedding_fn=embedding_fn, vector_dim=vector_dim, collection_name=collection_name, es_config=es_config
+            )
+
         from embedchain.vectordb.chroma_db import ChromaDB
 
         return ChromaDB(embedding_fn=embedding_fn, host=host, port=port)
diff --git a/embedchain/config/apps/CustomAppConfig.py b/embedchain/config/apps/CustomAppConfig.py
index cc42f85a..1d2dd916 100644
--- a/embedchain/config/apps/CustomAppConfig.py
+++ b/embedchain/config/apps/CustomAppConfig.py
@@ -3,7 +3,8 @@ from typing import Any
 from chromadb.api.types import Documents, Embeddings
 from dotenv import load_dotenv
 
-from embedchain.models import EmbeddingFunctions, Providers
+from embedchain.config.vectordbs import ElasticsearchDBConfig
+from embedchain.models import EmbeddingFunctions, Providers, VectorDatabases, VectorDimensions
 
 from .BaseAppConfig import BaseAppConfig
 
@@ -28,6 +29,8 @@ class CustomAppConfig(BaseAppConfig):
         provider: Providers = None,
         open_source_app_config=None,
         deployment_name=None,
+        db_type: VectorDatabases = None,
+        es_config: ElasticsearchDBConfig = None,
     ):
         """
         :param log_level: Optional. (String) Debug level
@@ -41,6 +44,8 @@ class CustomAppConfig(BaseAppConfig):
         :param collection_name: Optional. Collection name for the database.
         :param provider: Optional. (Providers): LLM Provider to use.
         :param open_source_app_config: Optional. Config instance needed for open source apps.
+        :param db_type: Optional. Type of vector database to use.
+        :param es_config: Optional. Elasticsearch database config to be used for the connection.
         """
         if provider:
             self.provider = provider
@@ -59,6 +64,9 @@ class CustomAppConfig(BaseAppConfig):
             port=port,
             id=id,
             collection_name=collection_name,
+            db_type=db_type,
+            vector_dim=CustomAppConfig.get_vector_dimension(embedding_function=embedding_fn),
+            es_config=es_config,
         )
 
     @staticmethod
@@ -108,3 +116,20 @@ class CustomAppConfig(BaseAppConfig):
         from chromadb.utils import embedding_functions
 
         return embedding_functions.SentenceTransformerEmbeddingFunction(model_name=model)
+
+    @staticmethod
+    def get_vector_dimension(embedding_function: EmbeddingFunctions):
+        if not isinstance(embedding_function, EmbeddingFunctions):
+            raise ValueError(f"Invalid option: '{embedding_function}'.")
+
+        if embedding_function == EmbeddingFunctions.OPENAI:
+            return VectorDimensions.OPENAI.value
+
+        elif embedding_function == EmbeddingFunctions.HUGGING_FACE:
+            return VectorDimensions.HUGGING_FACE.value
+
+        elif embedding_function == EmbeddingFunctions.VERTEX_AI:
+            return VectorDimensions.VERTEX_AI.value
+
+        elif embedding_function == EmbeddingFunctions.GPT4ALL:
+            return VectorDimensions.GPT4ALL.value
diff --git a/embedchain/config/vectordbs/ElasticsearchDBConfig.py b/embedchain/config/vectordbs/ElasticsearchDBConfig.py
new file mode 100644
index 00000000..6e7dd0f9
--- /dev/null
+++ b/embedchain/config/vectordbs/ElasticsearchDBConfig.py
@@ -0,0 +1,15 @@
+from typing import Any, Dict, List, Union
+
+from embedchain.config.BaseConfig import BaseConfig
+
+
+class ElasticsearchDBConfig(BaseConfig):
+    """
+    Config to initialize an elasticsearch client.
+    :param es_url: Elasticsearch URL or list of node URLs to be used for the connection.
+    :param ES_EXTRA_PARAMS: Extra named parameters passed through to the Python Elasticsearch client.
+ """ + + def __init__(self, es_url: Union[str, List[str]] = None, **ES_EXTRA_PARAMS: Dict[str, any]): + self.ES_URL = es_url + self.ES_EXTRA_PARAMS = ES_EXTRA_PARAMS diff --git a/embedchain/config/vectordbs/__init__.py b/embedchain/config/vectordbs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index 8033eff0..8cfca455 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -1,7 +1,6 @@ import logging import os -from chromadb.errors import InvalidDimensionException from dotenv import load_dotenv from langchain.docstore.document import Document from langchain.memory import ConversationBufferMemory @@ -31,8 +30,8 @@ class EmbedChain: """ self.config = config - self.db_client = self.config.db.client self.collection = self.config.db._get_or_create_collection(self.config.collection_name) + self.db = self.config.db self.user_asks = [] self.is_docs_site_instance = False self.online = False @@ -99,11 +98,10 @@ class EmbedChain: # get existing ids, and discard doc if any common id exist. where = {"app_id": self.config.id} if self.config.id is not None else {} # where={"url": src} - existing_docs = self.collection.get( + existing_ids = self.db.get( ids=ids, where=where, # optional filter ) - existing_ids = set(existing_docs["ids"]) if len(existing_ids): data_dict = {id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)} @@ -128,7 +126,7 @@ class EmbedChain: # Add metadata to each document metadatas_with_metadata = [{**meta, **metadata} for meta in metadatas] - self.collection.add(documents=documents, metadatas=list(metadatas_with_metadata), ids=ids) + self.db.add(documents=documents, metadatas=list(metadatas_with_metadata), ids=ids) print((f"Successfully saved {src}. New chunks count: " f"{self.count() - chunks_before_addition}")) def _format_result(self, results): @@ -156,23 +154,13 @@ class EmbedChain: :param config: The query configuration. :return: The content of the document that matched your query. """ - try: - where = {"app_id": self.config.id} if self.config.id is not None else {} # optional filter - result = self.collection.query( - query_texts=[ - input_query, - ], - n_results=config.number_documents, - where=where, - ) - except InvalidDimensionException as e: - raise InvalidDimensionException( - e.message() - + ". This is commonly a side-effect when an embedding function, different from the one used to add the embeddings, is used to retrieve an embedding from the database." # noqa E501 - ) from None + where = {"app_id": self.config.id} if self.config.id is not None else {} # optional filter + contents = self.db.query( + input_query=input_query, + n_results=config.number_documents, + where=where, + ) - results_formatted = self._format_result(result) - contents = [result[0].page_content for result in results_formatted] return contents def _append_search_and_context(self, context, web_search_result): @@ -339,11 +327,11 @@ class EmbedChain: :return: The number of embeddings. """ - return self.collection.count() + return self.db.count() def reset(self): """ Resets the database. Deletes all embeddings irreversibly. `App` has to be reinitialized after using this method. 
""" - self.db_client.reset() + self.db.reset() diff --git a/embedchain/models/VectorDatabases.py b/embedchain/models/VectorDatabases.py new file mode 100644 index 00000000..5abf3844 --- /dev/null +++ b/embedchain/models/VectorDatabases.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class VectorDatabases(Enum): + CHROMADB = "CHROMADB" + ELASTICSEARCH = "ELASTICSEARCH" diff --git a/embedchain/models/VectorDimensions.py b/embedchain/models/VectorDimensions.py new file mode 100644 index 00000000..9be1f304 --- /dev/null +++ b/embedchain/models/VectorDimensions.py @@ -0,0 +1,9 @@ +from enum import Enum + + +# vector length created by embedding fn +class VectorDimensions(Enum): + GPT4ALL = 384 + OPENAI = 1536 + VERTEX_AI = 768 + HUGGING_FACE = 384 diff --git a/embedchain/models/__init__.py b/embedchain/models/__init__.py index 7c459977..c5daa449 100644 --- a/embedchain/models/__init__.py +++ b/embedchain/models/__init__.py @@ -1,2 +1,4 @@ from .EmbeddingFunctions import EmbeddingFunctions # noqa: F401 from .Providers import Providers # noqa: F401 +from .VectorDatabases import VectorDatabases # noqa: F401 +from .VectorDimensions import VectorDimensions # noqa: F401 diff --git a/embedchain/vectordb/base_vector_db.py b/embedchain/vectordb/base_vector_db.py index f38e3d31..0ed1e3c0 100644 --- a/embedchain/vectordb/base_vector_db.py +++ b/embedchain/vectordb/base_vector_db.py @@ -10,3 +10,18 @@ class BaseVectorDB: def _get_or_create_collection(self): raise NotImplementedError + + def get(self): + raise NotImplementedError + + def add(self): + raise NotImplementedError + + def query(self): + raise NotImplementedError + + def count(self): + raise NotImplementedError + + def reset(self): + raise NotImplementedError diff --git a/embedchain/vectordb/chroma_db.py b/embedchain/vectordb/chroma_db.py index 168c6221..b50c8b52 100644 --- a/embedchain/vectordb/chroma_db.py +++ b/embedchain/vectordb/chroma_db.py @@ -1,4 +1,8 @@ import logging +from typing import Any, Dict, List + +from chromadb.errors import InvalidDimensionException +from langchain.docstore.document import Document try: import chromadb @@ -7,6 +11,7 @@ except RuntimeError: use_pysqlite3() import chromadb + from chromadb.config import Settings from embedchain.vectordb.base_vector_db import BaseVectorDB @@ -41,7 +46,73 @@ class ChromaDB(BaseVectorDB): def _get_or_create_collection(self, name): """Get or create the collection.""" - return self.client.get_or_create_collection( + self.collection = self.client.get_or_create_collection( name=name, embedding_function=self.embedding_fn, ) + return self.collection + + def get(self, ids: List[str], where: Dict[str, any]) -> List[str]: + """ + Get existing doc ids present in vector database + :param ids: list of doc ids to check for existance + :param where: Optional. 
+        """
+        existing_docs = self.collection.get(
+            ids=ids,
+            where=where,  # optional filter
+        )
+
+        return set(existing_docs["ids"])
+
+    def add(self, documents: List[str], metadatas: List[object], ids: List[str]) -> Any:
+        """
+        Add data to the vector database.
+        :param documents: List of texts to add.
+        :param metadatas: List of metadata associated with the docs.
+        :param ids: Ids of the docs.
+        """
+        self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
+
+    def _format_result(self, results):
+        return [
+            (Document(page_content=result[0], metadata=result[1] or {}), result[2])
+            for result in zip(
+                results["documents"][0],
+                results["metadatas"][0],
+                results["distances"][0],
+            )
+        ]
+
+    def query(self, input_query: List[str], n_results: int, where: Dict[str, Any]) -> List[str]:
+        """
+        Query contents from the vector database based on vector similarity.
+        :param input_query: List of query strings.
+        :param n_results: Number of similar documents to fetch from the database.
+        :param where: Optional. Filter to apply to the data.
+        :return: The content of the document that matched your query.
+        """
+        try:
+            result = self.collection.query(
+                query_texts=[
+                    input_query,
+                ],
+                n_results=n_results,
+                where=where,
+            )
+        except InvalidDimensionException as e:
+            raise InvalidDimensionException(
+                e.message()
+                + ". This is commonly a side-effect when an embedding function, different from the one used to add the embeddings, is used to retrieve an embedding from the database."  # noqa E501
+            ) from None
+
+        results_formatted = self._format_result(result)
+        contents = [result[0].page_content for result in results_formatted]
+        return contents
+
+    def count(self) -> int:
+        return self.collection.count()
+
+    def reset(self):
+        # Delete all data from the database
+        self.client.reset()
diff --git a/embedchain/vectordb/elasticsearch_db.py b/embedchain/vectordb/elasticsearch_db.py
new file mode 100644
index 00000000..4371237e
--- /dev/null
+++ b/embedchain/vectordb/elasticsearch_db.py
@@ -0,0 +1,136 @@
+from typing import Any, Callable, Dict, List
+
+try:
+    from elasticsearch import Elasticsearch
+    from elasticsearch.helpers import bulk
+except ImportError:
+    raise ImportError(
+        "Elasticsearch requires extra dependencies. Install with `pip install embedchain[elasticsearch]`"
+    ) from None
+
+from embedchain.config import ElasticsearchDBConfig
+from embedchain.models.VectorDimensions import VectorDimensions
+from embedchain.vectordb.base_vector_db import BaseVectorDB
+
+
+class ElasticsearchDB(BaseVectorDB):
+    def __init__(
+        self,
+        es_config: ElasticsearchDBConfig = None,
+        embedding_fn: Callable[[list[str]], list[str]] = None,
+        vector_dim: VectorDimensions = None,
+        collection_name: str = None,
+    ):
+        """
+        Elasticsearch as the vector database.
+        :param es_config: Elasticsearch database config to be used for the connection.
+        :param embedding_fn: Function to generate embedding vectors.
+        :param vector_dim: Vector dimension generated by the embedding function.
+        :param collection_name: Optional. Collection name for the database.
+        """
+        if not hasattr(embedding_fn, "__call__"):
+            raise ValueError("Embedding function is not a function")
+        if es_config is None:
+            raise ValueError("ElasticsearchDBConfig is required")
+        if vector_dim is None:
+            raise ValueError("Vector Dimension is required to refer correct index and mapping")
+        if collection_name is None:
+            raise ValueError("collection name is required. It cannot be empty")
It cannot be empty") + self.embedding_fn = embedding_fn + self.client = Elasticsearch(es_config.ES_URL, **es_config.ES_EXTRA_PARAMS) + self.vector_dim = vector_dim + self.es_index = f"{collection_name}_{self.vector_dim}" + index_settings = { + "mappings": { + "properties": { + "text": {"type": "text"}, + "text_vector": {"type": "dense_vector", "index": False, "dims": self.vector_dim}, + } + } + } + if not self.client.indices.exists(index=self.es_index): + # create index if not exist + print("Creating index", self.es_index, index_settings) + self.client.indices.create(index=self.es_index, body=index_settings) + super().__init__() + + def _get_or_create_db(self): + return self.client + + def _get_or_create_collection(self, name): + """Note: nothing to return here. Discuss later""" + + def get(self, ids: List[str], where: Dict[str, any]) -> List[str]: + """ + Get existing doc ids present in vector database + :param ids: list of doc ids to check for existance + :param where: Optional. to filter data + """ + query = {"bool": {"must": [{"ids": {"values": ids}}]}} + if "app_id" in where: + app_id = where["app_id"] + query["bool"]["must"].append({"term": {"metadata.app_id": app_id}}) + response = self.client.search(index=self.es_index, query=query, _source=False) + docs = response["hits"]["hits"] + ids = [doc["_id"] for doc in docs] + return set(ids) + + def add(self, documents: List[str], metadatas: List[object], ids: List[str]) -> Any: + """ + add data in vector database + :param documents: list of texts to add + :param metadatas: list of metadata associated with docs + :param ids: ids of docs + """ + docs = [] + embeddings = self.embedding_fn(documents) + for id, text, metadata, text_vector in zip(ids, documents, metadatas, embeddings): + docs.append( + { + "_index": self.es_index, + "_id": id, + "_source": {"text": text, "metadata": metadata, "text_vector": text_vector}, + } + ) + bulk(self.client, docs) + self.client.indices.refresh(index=self.es_index) + return + + def query(self, input_query: List[str], n_results: int, where: Dict[str, any]) -> List[str]: + """ + query contents from vector data base based on vector similarity + :param input_query: list of query string + :param n_results: no of similar documents to fetch from database + :param where: Optional. 
+        """
+        input_query_vector = self.embedding_fn(input_query)
+        query_vector = input_query_vector[0]
+        query = {
+            "script_score": {
+                "query": {"bool": {"must": [{"exists": {"field": "text"}}]}},
+                "script": {
+                    "source": "cosineSimilarity(params.input_query_vector, 'text_vector') + 1.0",
+                    "params": {"input_query_vector": query_vector},
+                },
+            }
+        }
+        if "app_id" in where:
+            app_id = where["app_id"]
+            query["script_score"]["query"]["bool"]["must"] = [{"term": {"metadata.app_id": app_id}}]
+        _source = ["text"]
+        response = self.client.search(index=self.es_index, query=query, _source=_source, size=n_results)
+        docs = response["hits"]["hits"]
+        contents = [doc["_source"]["text"] for doc in docs]
+        return contents
+
+    def count(self) -> int:
+        query = {"match_all": {}}
+        response = self.client.count(index=self.es_index, query=query)
+        doc_count = response["count"]
+        return doc_count
+
+    def reset(self):
+        # Delete all data from the database
+        if self.client.indices.exists(index=self.es_index):
+            # delete the index in ES
+            self.client.indices.delete(index=self.es_index)
diff --git a/pyproject.toml b/pyproject.toml
index 85211804..d08bc234 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -91,6 +91,7 @@ beautifulsoup4 = "^4.12.2"
 pypdf = "^3.11.0"
 pytube = "^15.0.0"
 llama-index = { version = "^0.7.21", optional = true }
+elasticsearch = { version = "^8.9.0", optional = true }
 
@@ -107,6 +108,7 @@ isort = "^5.12.0"
 
 [tool.poetry.extras]
 streamlit = ["streamlit"]
 community = ["llama-index"]
+elasticsearch = ["elasticsearch"]
 
 [tool.poetry.group.docs.dependencies]
diff --git a/setup.py b/setup.py
index 16d735af..4da77c45 100644
--- a/setup.py
+++ b/setup.py
@@ -37,5 +37,9 @@ setuptools.setup(
         "replicate==0.9.0",
         "duckduckgo-search==3.8.4",
     ],
-    extras_require={"dev": ["black", "ruff", "isort", "pytest"], "community": ["llama-index==0.7.21"]},
+    extras_require={
+        "dev": ["black", "ruff", "isort", "pytest"],
+        "community": ["llama-index==0.7.21"],
+        "elasticsearch": ["elasticsearch>=8.9.0"],
+    },
 )
diff --git a/tests/vectordb/test_elasticsearch_db.py b/tests/vectordb/test_elasticsearch_db.py
new file mode 100644
index 00000000..4f316eae
--- /dev/null
+++ b/tests/vectordb/test_elasticsearch_db.py
@@ -0,0 +1,33 @@
+import unittest
+from unittest.mock import Mock
+
+from embedchain.config import ElasticsearchDBConfig
+from embedchain.vectordb.elasticsearch_db import ElasticsearchDB
+
+
+class TestEsDB(unittest.TestCase):
+    def setUp(self):
+        self.es_config = ElasticsearchDBConfig()
+        self.vector_dim = 384
+
+    def test_init_with_invalid_embedding_fn(self):
+        # Test if an exception is raised when an invalid embedding_fn is provided
+        with self.assertRaises(ValueError):
+            ElasticsearchDB(embedding_fn=None)
+
+    def test_init_with_invalid_es_config(self):
+        # Test if an exception is raised when an invalid es_config is provided
+        with self.assertRaises(ValueError):
+            ElasticsearchDB(embedding_fn=Mock(), es_config=None)
+
+    def test_init_with_invalid_vector_dim(self):
+        # Test if an exception is raised when an invalid vector_dim is provided
+        with self.assertRaises(ValueError):
+            ElasticsearchDB(embedding_fn=Mock(), es_config=self.es_config, vector_dim=None)
+
+    def test_init_with_invalid_collection_name(self):
+        # Test if an exception is raised when an invalid collection_name is provided
+        with self.assertRaises(ValueError):
+            ElasticsearchDB(
+                embedding_fn=Mock(), es_config=self.es_config, vector_dim=self.vector_dim, collection_name=None
+            )
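
For reviewers who want to exercise this branch end to end, here is a minimal sketch of how the new Elasticsearch path is wired together. It assumes a local Elasticsearch node at `http://localhost:9200`, a valid `OPENAI_API_KEY`, and that the `add`/`query` signatures match the embedchain release this change targets; the URL and data source are placeholders, not part of the diff.

```python
import os

from embedchain import CustomApp
from embedchain.config import CustomAppConfig, ElasticsearchDBConfig
from embedchain.models import EmbeddingFunctions, Providers, VectorDatabases

os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder; set your own key

# Single-node connection; es_url may also be a list of node URLs.
es_config = ElasticsearchDBConfig(es_url="http://localhost:9200")

config = CustomAppConfig(
    embedding_fn=EmbeddingFunctions.OPENAI,
    provider=Providers.OPENAI,
    db_type=VectorDatabases.ELASTICSEARCH,
    es_config=es_config,
)
app = CustomApp(config)

# add() chunks the source and routes through ElasticsearchDB.add(), which bulk-indexes
# text, metadata, and embedding vectors into the "<collection_name>_<vector_dim>" index.
app.add("web_page", "https://example.com/some-article")

# query() embeds the question and runs the cosine-similarity script_score search added above.
print(app.query("What is the article about?"))
```

Because `BaseAppConfig.get_db` falls back to `ChromaDB` whenever `db_type` is not `VectorDatabases.ELASTICSEARCH`, omitting `db_type` and `es_config` keeps the existing Chroma behaviour unchanged.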