Add support for OpenSearch as vector database (#725)

This commit is contained in:
Deshraj Yadav
2023-09-28 14:54:42 -07:00
committed by GitHub
parent 9951b58005
commit 414c69fd62
22 changed files with 326 additions and 82 deletions

View File

@@ -7,8 +7,19 @@ PROJECT_NAME := embedchain
.PHONY: install format lint clean test ci_lint ci_test .PHONY: install format lint clean test ci_lint ci_test
install: install:
$(PIP) install --upgrade pip poetry install
$(PIP) install -e .[dev]
install_es:
poetry install --extras elasticsearch
install_opensearch:
poetry install --extras opensearch
shell:
poetry shell
py_shell:
poetry run python
format: format:
$(PYTHON) -m black . $(PYTHON) -m black .

View File

@@ -70,6 +70,6 @@ app.reset()
Counts the number of embeddings (chunks) in the database. Counts the number of embeddings (chunks) in the database.
```python ```python
print(app.count()) print(app.db.count())
# returns: 481 # returns: 481
``` ```

View File

@@ -2,7 +2,7 @@
title: '💾 Vector Database' title: '💾 Vector Database'
--- ---
We support `Chroma` and `Elasticsearch` as two vector database. We support `Chroma`, `Elasticsearch` and `OpenSearch` as vector databases.
`Chroma` is used as a default database. `Chroma` is used as a default database.
## Elasticsearch ## Elasticsearch
@@ -22,13 +22,13 @@ Please note that the key needs certain privileges. For testing you can just togg
2. Load the app 2. Load the app
```python ```python
from embedchain import CustomApp from embedchain import CustomApp
from embedchain.embedder.openai import OpenAiEmbedder from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.llm.openai import OpenAILlm from embedchain.llm.openai import OpenAILlm
from embedchain.vectordb.elasticsearch import ElasticsearchDB from embedchain.vectordb.elasticsearch import ElasticsearchDB
es_app = CustomApp( es_app = CustomApp(
llm=OpenAILlm(), llm=OpenAILlm(),
embedder=OpenAiEmbedder(), embedder=OpenAIEmbedder(),
db=ElasticsearchDB(), db=ElasticsearchDB(),
) )
``` ```
@@ -45,7 +45,7 @@ import os
from embedchain import CustomApp from embedchain import CustomApp
from embedchain.config import CustomAppConfig, ElasticsearchDBConfig from embedchain.config import CustomAppConfig, ElasticsearchDBConfig
from embedchain.embedder.openai import OpenAiEmbedder from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.llm.openai import OpenAILlm from embedchain.llm.openai import OpenAILlm
from embedchain.vectordb.elasticsearch import ElasticsearchDB from embedchain.vectordb.elasticsearch import ElasticsearchDB
@@ -61,10 +61,58 @@ es_config = ElasticsearchDBConfig(
es_app = CustomApp( es_app = CustomApp(
config=CustomAppConfig(log_level="INFO"), config=CustomAppConfig(log_level="INFO"),
llm=OpenAILlm(), llm=OpenAILlm(),
embedder=OpenAiEmbedder(), embedder=OpenAIEmbedder(),
db=ElasticsearchDB(config=es_config), db=ElasticsearchDB(config=es_config),
) )
``` ```
3. This should log your connection details to the console. 3. This should log your connection details to the console.
4. Alternatively to a URL, you `ElasticsearchDBConfig` accepts `es_url` as a list of nodes url with different hosts and ports. 4. Alternatively to a URL, you `ElasticsearchDBConfig` accepts `es_url` as a list of nodes url with different hosts and ports.
5. Additionally we can pass named parameters supported by Python Elasticsearch client. 5. Additionally we can pass named parameters supported by Python Elasticsearch client.
## OpenSearch 🔍
To use OpenSearch as a vector database with a CustomApp, follow these simple steps:
1. Set the `OPENAI_API_KEY` environment variable:
```
OPENAI_API_KEY=sk-xxxx
```
2. Define the OpenSearch configuration in your Python code:
```python
from embedchain import CustomApp
from embedchain.config import OpenSearchDBConfig
from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.llm.openai import OpenAILlm
from embedchain.vectordb.opensearch import OpenSearchDB
opensearch_url = "https://localhost:9200"
http_auth = ("username", "password")
db_config = OpenSearchDBConfig(
opensearch_url=opensearch_url,
http_auth=http_auth,
collection_name="embedchain-app",
use_ssl=True,
timeout=30,
)
db = OpenSearchDB(config=db_config)
```
2. Instantiate the app and add data:
```python
app = CustomApp(llm=OpenAILlm(), embedder=OpenAIEmbedder(), db=db)
app.add("https://en.wikipedia.org/wiki/Elon_Musk")
app.add("https://www.forbes.com/profile/elon-musk")
app.add("https://www.britannica.com/biography/Elon-Musk")
```
3. You're all set! Start querying using the following command:
```python
app.query("What is the net worth of Elon Musk?")
```

View File

@@ -2,7 +2,7 @@ from typing import Optional
from embedchain.apps.custom_app import CustomApp from embedchain.apps.custom_app import CustomApp
from embedchain.config import CustomAppConfig from embedchain.config import CustomAppConfig
from embedchain.embedder.openai import OpenAiEmbedder from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.helper.json_serializable import register_deserializable from embedchain.helper.json_serializable import register_deserializable
from embedchain.llm.llama2 import Llama2Llm from embedchain.llm.llama2 import Llama2Llm
from embedchain.vectordb.chroma import ChromaDB from embedchain.vectordb.chroma import ChromaDB
@@ -29,5 +29,5 @@ class Llama2App(CustomApp):
config = CustomAppConfig() config = CustomAppConfig()
super().__init__( super().__init__(
config=config, llm=Llama2Llm(), db=ChromaDB(), embedder=OpenAiEmbedder(), system_prompt=system_prompt config=config, llm=Llama2Llm(), db=ChromaDB(), embedder=OpenAIEmbedder(), system_prompt=system_prompt
) )

View File

@@ -3,7 +3,7 @@ from typing import Optional
from embedchain.config import (AppConfig, BaseEmbedderConfig, BaseLlmConfig, from embedchain.config import (AppConfig, BaseEmbedderConfig, BaseLlmConfig,
ChromaDbConfig) ChromaDbConfig)
from embedchain.embedchain import EmbedChain from embedchain.embedchain import EmbedChain
from embedchain.embedder.openai import OpenAiEmbedder from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.helper.json_serializable import register_deserializable from embedchain.helper.json_serializable import register_deserializable
from embedchain.llm.openai import OpenAILlm from embedchain.llm.openai import OpenAILlm
from embedchain.vectordb.chroma import ChromaDB from embedchain.vectordb.chroma import ChromaDB
@@ -48,7 +48,7 @@ class App(EmbedChain):
config = AppConfig() config = AppConfig()
llm = OpenAILlm(config=llm_config) llm = OpenAILlm(config=llm_config)
embedder = OpenAiEmbedder(config=BaseEmbedderConfig(model="text-embedding-ada-002")) embedder = OpenAIEmbedder(config=BaseEmbedderConfig(model="text-embedding-ada-002"))
database = ChromaDB(config=chromadb_config) database = ChromaDB(config=chromadb_config)
super().__init__(config, llm, db=database, embedder=embedder, system_prompt=system_prompt) super().__init__(config, llm, db=database, embedder=embedder, system_prompt=system_prompt)

View File

@@ -2,7 +2,7 @@ from typing import Any
from embedchain import CustomApp from embedchain import CustomApp
from embedchain.config import AddConfig, CustomAppConfig, LlmConfig from embedchain.config import AddConfig, CustomAppConfig, LlmConfig
from embedchain.embedder.openai import OpenAiEmbedder from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.helper.json_serializable import (JSONSerializable, from embedchain.helper.json_serializable import (JSONSerializable,
register_deserializable) register_deserializable)
from embedchain.llm.openai import OpenAILlm from embedchain.llm.openai import OpenAILlm
@@ -12,7 +12,7 @@ from embedchain.vectordb.chroma import ChromaDB
@register_deserializable @register_deserializable
class BaseBot(JSONSerializable): class BaseBot(JSONSerializable):
def __init__(self): def __init__(self):
self.app = CustomApp(config=CustomAppConfig(), llm=OpenAILlm(), db=ChromaDB(), embedder=OpenAiEmbedder()) self.app = CustomApp(config=CustomAppConfig(), llm=OpenAILlm(), db=ChromaDB(), embedder=OpenAIEmbedder())
def add(self, data: Any, config: AddConfig = None): def add(self, data: Any, config: AddConfig = None):
""" """

View File

@@ -5,9 +5,10 @@ from .apps.app_config import AppConfig
from .apps.custom_app_config import CustomAppConfig from .apps.custom_app_config import CustomAppConfig
from .apps.open_source_app_config import OpenSourceAppConfig from .apps.open_source_app_config import OpenSourceAppConfig
from .base_config import BaseConfig from .base_config import BaseConfig
from .embedder.BaseEmbedderConfig import BaseEmbedderConfig from .embedder.base import BaseEmbedderConfig
from .embedder.BaseEmbedderConfig import BaseEmbedderConfig as EmbedderConfig from .embedder.base import BaseEmbedderConfig as EmbedderConfig
from .llm.base_llm_config import BaseLlmConfig from .llm.base_llm_config import BaseLlmConfig
from .llm.base_llm_config import BaseLlmConfig as LlmConfig from .llm.base_llm_config import BaseLlmConfig as LlmConfig
from .vectordbs.ChromaDbConfig import ChromaDbConfig from .vectordb.chroma import ChromaDbConfig
from .vectordbs.ElasticsearchDBConfig import ElasticsearchDBConfig from .vectordb.elasticsearch import ElasticsearchDBConfig
from .vectordb.opensearch import OpenSearchDBConfig

View File

@@ -1,6 +1,6 @@
from typing import Optional from typing import Optional
from embedchain.config.vectordbs.BaseVectorDbConfig import BaseVectorDbConfig from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helper.json_serializable import register_deserializable

View File

@@ -1,7 +1,7 @@
import os import os
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from embedchain.config.vectordbs.BaseVectorDbConfig import BaseVectorDbConfig from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helper.json_serializable import register_deserializable

View File

@@ -0,0 +1,37 @@
from typing import Dict, Optional, Tuple
from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.helper.json_serializable import register_deserializable
@register_deserializable
class OpenSearchDBConfig(BaseVectorDbConfig):
def __init__(
self,
opensearch_url: str,
http_auth: Tuple[str, str],
vector_dimension: int = 1536,
collection_name: Optional[str] = None,
dir: Optional[str] = None,
**extra_params: Dict[str, any],
):
"""
Initializes a configuration class instance for an OpenSearch client.
:param collection_name: Default name for the collection, defaults to None
:type collection_name: Optional[str], optional
:param opensearch_url: URL of the OpenSearch domain
:type opensearch_url: str, Eg, "http://localhost:9200"
:param http_auth: Tuple of username and password
:type http_auth: Tuple[str, str], Eg, ("username", "password")
:param vector_dimension: Dimension of the vector, defaults to 1536 (openai embedding model)
:type vector_dimension: int, optional
:param dir: Path to the database directory, where the database is stored, defaults to None
:type dir: Optional[str], optional
"""
self.opensearch_url = opensearch_url
self.http_auth = http_auth
self.vector_dimension = vector_dimension
self.extra_params = extra_params
super().__init__(collection_name=collection_name, dir=dir)

View File

@@ -61,16 +61,13 @@ class EmbedChain(JSONSerializable):
""" """
self.config = config self.config = config
# Llm
# Add subclasses
## Llm
self.llm = llm self.llm = llm
## Database
# Database has support for config assignment for backwards compatibility # Database has support for config assignment for backwards compatibility
if db is None and (not hasattr(self.config, "db") or self.config.db is None): if db is None and (not hasattr(self.config, "db") or self.config.db is None):
raise ValueError("App requires Database.") raise ValueError("App requires Database.")
self.db = db or self.config.db self.db = db or self.config.db
## Embedder # Embedder
if embedder is None: if embedder is None:
raise ValueError("App requires Embedder.") raise ValueError("App requires Embedder.")
self.embedder = embedder self.embedder = embedder
@@ -256,7 +253,6 @@ class EmbedChain(JSONSerializable):
) )
return self.add(source=source, data_type=data_type, metadata=metadata, config=config) return self.add(source=source, data_type=data_type, metadata=metadata, config=config)
def _get_existing_doc_id(self, chunker: BaseChunker, src: Any): def _get_existing_doc_id(self, chunker: BaseChunker, src: Any):
""" """
Get id of existing document for a given source, based on the data type Get id of existing document for a given source, based on the data type
@@ -395,10 +391,10 @@ class EmbedChain(JSONSerializable):
return list(documents), metadatas, ids, 0 return list(documents), metadatas, ids, 0
# Count before, to calculate a delta in the end. # Count before, to calculate a delta in the end.
chunks_before_addition = self.count() chunks_before_addition = self.db.count()
self.db.add(documents=documents, metadatas=metadatas, ids=ids) self.db.add(documents=documents, metadatas=metadatas, ids=ids)
count_new_chunks = self.count() - chunks_before_addition count_new_chunks = self.db.count() - chunks_before_addition
print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}")) print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
return list(documents), metadatas, ids, count_new_chunks return list(documents), metadatas, ids, count_new_chunks

View File

@@ -1,6 +1,6 @@
from typing import Any, Callable, Optional from typing import Any, Callable, Optional
from embedchain.config.embedder.BaseEmbedderConfig import BaseEmbedderConfig from embedchain.config.embedder.base import BaseEmbedderConfig
try: try:
from chromadb.api.types import Documents, Embeddings from chromadb.api.types import Documents, Embeddings

View File

@@ -16,7 +16,7 @@ except RuntimeError:
from chromadb.utils import embedding_functions from chromadb.utils import embedding_functions
class OpenAiEmbedder(BaseEmbedder): class OpenAIEmbedder(BaseEmbedder):
def __init__(self, config: Optional[BaseEmbedderConfig] = None): def __init__(self, config: Optional[BaseEmbedderConfig] = None):
super().__init__(config=config) super().__init__(config=config)
if self.config.model is None: if self.config.model is None:

View File

@@ -4,3 +4,4 @@ from enum import Enum
class VectorDatabases(Enum): class VectorDatabases(Enum):
CHROMADB = "CHROMADB" CHROMADB = "CHROMADB"
ELASTICSEARCH = "ELASTICSEARCH" ELASTICSEARCH = "ELASTICSEARCH"
OPENSEARCH = "OPENSEARCH"

View File

@@ -1,4 +1,4 @@
from embedchain.config.vectordbs.BaseVectorDbConfig import BaseVectorDbConfig from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.embedder.base import BaseEmbedder from embedchain.embedder.base import BaseEmbedder
from embedchain.helper.json_serializable import JSONSerializable from embedchain.helper.json_serializable import JSONSerializable

View File

@@ -1,50 +0,0 @@
from embedchain.config.vectordbs.BaseVectorDbConfig import BaseVectorDbConfig
from embedchain.embedder.base_embedder import BaseEmbedder
from embedchain.helper_classes.json_serializable import JSONSerializable
class BaseVectorDB(JSONSerializable):
"""Base class for vector database."""
def __init__(self, config: BaseVectorDbConfig):
self.client = self._get_or_create_db()
self.config: BaseVectorDbConfig = config
def _initialize(self):
"""
This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
So it's can't be done in __init__ in one step.
"""
raise NotImplementedError
def _get_or_create_db(self):
"""Get or create the database."""
raise NotImplementedError
def _get_or_create_collection(self):
raise NotImplementedError
def _set_embedder(self, embedder: BaseEmbedder):
self.embedder = embedder
def get(self):
raise NotImplementedError
def add(self):
raise NotImplementedError
def query(self):
raise NotImplementedError
def count(self):
raise NotImplementedError
def delete(self):
raise NotImplementedError
def reset(self):
raise NotImplementedError
def set_collection_name(self, name: str):
raise NotImplementedError

View File

@@ -63,7 +63,9 @@ class ChromaDB(BaseVectorDB):
This method is needed because `embedder` attribute needs to be set externally before it can be initialized. This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
""" """
if not self.embedder: if not self.embedder:
raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.") raise ValueError(
"Embedder not set. Please set an embedder with `_set_embedder()` function before initialization."
)
self._get_or_create_collection(self.config.collection_name) self._get_or_create_collection(self.config.collection_name)
def _get_or_create_db(self): def _get_or_create_db(self):

View File

@@ -0,0 +1,196 @@
import logging
from typing import Dict, List, Optional, Set
try:
from opensearchpy import OpenSearch
from opensearchpy.helpers import bulk
except ImportError:
raise ImportError(
"OpenSearch requires extra dependencies. Install with `pip install --upgrade embedchain[opensearch]`"
) from None
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import OpenSearchVectorSearch
from embedchain.config import OpenSearchDBConfig
from embedchain.helper.json_serializable import register_deserializable
from embedchain.vectordb.base import BaseVectorDB
@register_deserializable
class OpenSearchDB(BaseVectorDB):
"""
OpenSearch as vector database
"""
def __init__(self, config: OpenSearchDBConfig):
"""OpenSearch as vector database.
:param config: OpenSearch domain config
:type config: OpenSearchDBConfig
"""
if config is None:
raise ValueError("OpenSearchDBConfig is required")
self.config = config
self.client = OpenSearch(
hosts=[self.config.opensearch_url],
http_auth=self.config.http_auth,
**self.config.extra_params,
)
info = self.client.info()
logging.info(f"Connected to {info['version']['distribution']}. Version: {info['version']['number']}")
# Remove auth credentials from config after successful connection
super().__init__(config=self.config)
def _initialize(self):
logging.info(self.client.info())
index_name = self._get_index()
if self.client.indices.exists(index=index_name):
print(f"Index '{index_name}' already exists.")
return
index_body = {
"settings": {"knn": True},
"mappings": {
"properties": {
"text": {"type": "text"},
"embeddings": {
"type": "knn_vector",
"index": False,
"dimension": self.config.vector_dimension,
},
}
},
}
self.client.indices.create(index_name, body=index_body)
print(self.client.indices.get(index_name))
def _get_or_create_db(self):
"""Called during initialization"""
return self.client
def _get_or_create_collection(self, name):
"""Note: nothing to return here. Discuss later"""
def get(
self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None
) -> Set[str]:
"""
Get existing doc ids present in vector database
:param ids: _list of doc ids to check for existence
:type ids: List[str]
:param where: to filter data
:type where: Dict[str, any]
:return: ids
:type: Set[str]
"""
if ids:
query = {"query": {"bool": {"must": [{"ids": {"values": ids}}]}}}
else:
query = {"query": {"bool": {"must": []}}}
if "app_id" in where:
app_id = where["app_id"]
query["query"]["bool"]["must"].append({"term": {"metadata.app_id": app_id}})
# OpenSearch syntax is different from Elasticsearch
response = self.client.search(index=self._get_index(), body=query, _source=False, size=limit)
docs = response["hits"]["hits"]
ids = [doc["_id"] for doc in docs]
return {"ids": set(ids)}
def add(self, documents: List[str], metadatas: List[object], ids: List[str]):
"""add data in vector database
:param documents: list of texts to add
:type documents: List[str]
:param metadatas: list of metadata associated with docs
:type metadatas: List[object]
:param ids: ids of docs
:type ids: List[str]
"""
docs = []
embeddings = self.embedder.embedding_fn(documents)
for id, text, metadata, embeddings in zip(ids, documents, metadatas, embeddings):
docs.append(
{
"_index": self._get_index(),
"_id": id,
"_source": {"text": text, "metadata": metadata, "embeddings": embeddings},
}
)
bulk(self.client, docs)
self.client.indices.refresh(index=self._get_index())
def query(self, input_query: List[str], n_results: int, where: Dict[str, any]) -> List[str]:
"""
query contents from vector data base based on vector similarity
:param input_query: list of query string
:type input_query: List[str]
:param n_results: no of similar documents to fetch from database
:type n_results: int
:param where: Optional. to filter data
:type where: Dict[str, any]
:return: Database contents that are the result of the query
:rtype: List[str]
"""
embeddings = OpenAIEmbeddings()
docsearch = OpenSearchVectorSearch(
index_name=self._get_index(),
embedding_function=embeddings,
opensearch_url=f"{self.config.opensearch_url}",
http_auth=self.config.http_auth,
use_ssl=True,
)
docs = docsearch.similarity_search(
input_query,
search_type="script_scoring",
space_type="cosinesimil",
vector_field="embeddings",
text_field="text",
metadata_field="metadata",
)
contents = [doc.page_content for doc in docs]
return contents
def set_collection_name(self, name: str):
"""
Set the name of the collection. A collection is an isolated space for vectors.
:param name: Name of the collection.
:type name: str
"""
if not isinstance(name, str):
raise TypeError("Collection name must be a string")
self.config.collection_name = name
def count(self) -> int:
"""
Count number of documents/chunks embedded in the database.
:return: number of documents
:rtype: int
"""
query = {"query": {"match_all": {}}}
response = self.client.count(index=self._get_index(), body=query)
doc_count = response["count"]
return doc_count
def reset(self):
"""
Resets the database. Deletes all embeddings irreversibly.
"""
# Delete all data from the database
if self.client.indices.exists(index=self._get_index()):
# delete index in Es
self.client.indices.delete(index=self._get_index())
def _get_index(self) -> str:
"""Get the OpenSearch index for a collection
:return: OpenSearch index
:rtype: str
"""
return self.config.collection_name

View File

@@ -98,6 +98,7 @@ torch = { version = ">=2.0.0, !=2.0.1", optional = true }
# Torch 2.0.1 is not compatible with poetry (https://github.com/pytorch/pytorch/issues/100974) # Torch 2.0.1 is not compatible with poetry (https://github.com/pytorch/pytorch/issues/100974)
gpt4all = { version = "1.0.8", optional = true } gpt4all = { version = "1.0.8", optional = true }
# 1.0.9 is not working for some users (https://github.com/nomic-ai/gpt4all/issues/1394) # 1.0.9 is not working for some users (https://github.com/nomic-ai/gpt4all/issues/1394)
opensearch-py = { version = "2.3.1", optional = true }
elasticsearch = { version = "^8.9.0", optional = true } elasticsearch = { version = "^8.9.0", optional = true }
flask = { version = "^2.3.3", optional = true } flask = { version = "^2.3.3", optional = true }
twilio = { version = "^8.5.0", optional = true } twilio = { version = "^8.5.0", optional = true }
@@ -123,6 +124,7 @@ streamlit = ["streamlit"]
community = ["llama-hub"] community = ["llama-hub"]
opensource = ["sentence-transformers", "torch", "gpt4all"] opensource = ["sentence-transformers", "torch", "gpt4all"]
elasticsearch = ["elasticsearch"] elasticsearch = ["elasticsearch"]
opensearch = ["opensearch-py"]
poe = ["fastapi-poe"] poe = ["fastapi-poe"]
discord = ["discord"] discord = ["discord"]
slack = ["slack-sdk", "flask"] slack = ["slack-sdk", "flask"]