fix: elastic search (#600)

This commit is contained in:
cachho
2023-09-13 19:58:18 +02:00
committed by GitHub
parent 79efa51941
commit 119ec5e405
11 changed files with 135 additions and 55 deletions

View File

@@ -5,30 +5,66 @@ title: '💾 Vector Database'
We support `Chroma` and `Elasticsearch` as two vector databases. We support `Chroma` and `Elasticsearch` as two vector databases.
`Chroma` is used as a default database. `Chroma` is used as a default database.
### Elasticsearch ## Elasticsearch
### Minimal Example
In order to use `Elasticsearch` as vector database we need to use App type `CustomApp`. In order to use `Elasticsearch` as vector database we need to use App type `CustomApp`.
1. Set the environment variables in a `.env` file.
```
OPENAI_API_KEY=sk-SECRETKEY
ELASTICSEARCH_API_KEY=SECRETKEY==
ELASTICSEARCH_URL=https://secret-domain.europe-west3.gcp.cloud.es.io:443
```
Please note that the key needs certain privileges. For testing you can just toggle off `restrict privileges` under `/app/management/security/api_keys/` in your web interface.
2. Load the app
```python
from embedchain import CustomApp
from embedchain.embedder.openai import OpenAiEmbedder
from embedchain.llm.openai import OpenAILlm
from embedchain.vectordb.elasticsearch import ElasticsearchDB
es_app = CustomApp(
llm=OpenAILlm(),
embedder=OpenAiEmbedder(),
db=ElasticsearchDB(),
)
```
### More custom settings
You can get a URL for elasticsearch in the cloud, or run it locally.
The following example shows you how to configure embedchain to work with a locally running elasticsearch.
Instead of using an API key, we use HTTP login credentials. The localhost URL can be defined in `.env` or in the config.
```python ```python
import os import os
from embedchain import CustomApp from embedchain import CustomApp
from embedchain.config import CustomAppConfig, ElasticsearchDBConfig from embedchain.config import CustomAppConfig, ElasticsearchDBConfig
from embedchain.models import Providers, EmbeddingFunctions, VectorDatabases from embedchain.embedder.openai import OpenAiEmbedder
from embedchain.llm.openai import OpenAILlm
os.environ["OPENAI_API_KEY"] = 'OPENAI_API_KEY' from embedchain.vectordb.elasticsearch import ElasticsearchDB
es_config = ElasticsearchDBConfig( es_config = ElasticsearchDBConfig(
# elasticsearch url or list of nodes url with different hosts and ports. # elasticsearch url or list of nodes url with different hosts and ports.
es_url='http://localhost:9200', es_url='https://localhost:9200',
# pass named parameters supported by Python Elasticsearch client # pass named parameters supported by Python Elasticsearch client
ca_certs="/path/to/http_ca.crt", http_auth=("elastic", "secret"),
basic_auth=("username", "password") ca_certs="~/binaries/elasticsearch-8.7.0/config/certs/http_ca.crt" # your cert path
# verify_certs=False # Alternative, if you aren't using certs
) # pass named parameters supported by elasticsearch-py
es_app = CustomApp(
config=CustomAppConfig(log_level="INFO"),
llm=OpenAILlm(),
embedder=OpenAiEmbedder(),
db=ElasticsearchDB(config=es_config),
) )
config = CustomAppConfig(
embedding_fn=EmbeddingFunctions.OPENAI,
provider=Providers.OPENAI,
db_type=VectorDatabases.ELASTICSEARCH,
es_config=es_config,
)
es_app = CustomApp(config)
``` ```
- Set `db_type=VectorDatabases.ELASTICSEARCH` and `es_config=ElasticsearchDBConfig(es_url='')` in `CustomAppConfig`. 3. This should log your connection details to the console.
- `ElasticsearchDBConfig` accepts `es_url` as elasticsearch url or as list of nodes url with different hosts and ports. Additionally we can pass named parameters supported by Python Elasticsearch client. 4. As an alternative to a single URL, `ElasticsearchDBConfig` accepts `es_url` as a list of node URLs with different hosts and ports.
5. Additionally, we can pass named parameters supported by the Python Elasticsearch client.

View File

@@ -1,3 +1,4 @@
import os
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from embedchain.config.vectordbs.BaseVectorDbConfig import BaseVectorDbConfig from embedchain.config.vectordbs.BaseVectorDbConfig import BaseVectorDbConfig
@@ -26,7 +27,20 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
:type ES_EXTRA_PARAMS: Dict[str, Any], optional :type ES_EXTRA_PARAMS: Dict[str, Any], optional
""" """
# self, es_url: Union[str, List[str]] = None, **ES_EXTRA_PARAMS: Dict[str, any]): # self, es_url: Union[str, List[str]] = None, **ES_EXTRA_PARAMS: Dict[str, any]):
self.ES_URL = es_url self.ES_URL = es_url or os.environ.get("ELASTICSEARCH_URL")
if not self.ES_URL:
raise AttributeError(
"Elasticsearch needs a URL attribute, "
"this can either be passed to `ElasticsearchDBConfig` or as `ELASTICSEARCH_URL` in `.env`"
)
self.ES_EXTRA_PARAMS = ES_EXTRA_PARAMS self.ES_EXTRA_PARAMS = ES_EXTRA_PARAMS
# Load API key from .env if it's not explicitly passed.
# Can only set one of 'api_key', 'basic_auth', and 'bearer_auth'
if (
not self.ES_EXTRA_PARAMS.get("api_key")
and not self.ES_EXTRA_PARAMS.get("basic_auth")
and not self.ES_EXTRA_PARAMS.get("bearer_auth")
and not self.ES_EXTRA_PARAMS.get("http_auth")
):
self.ES_EXTRA_PARAMS["api_key"] = os.environ.get("ELASTICSEARCH_API_KEY")
super().__init__(collection_name=collection_name, dir=dir) super().__init__(collection_name=collection_name, dir=dir)

View File

@@ -51,6 +51,8 @@ class BaseEmbedder:
:param vector_dimension: vector dimension size :param vector_dimension: vector dimension size
:type vector_dimension: int :type vector_dimension: int
""" """
if not isinstance(vector_dimension, int):
raise TypeError("vector dimension must be int")
self.vector_dimension = vector_dimension self.vector_dimension = vector_dimension
@staticmethod @staticmethod

View File

@@ -4,7 +4,7 @@ from chromadb.utils import embedding_functions
from embedchain.config import BaseEmbedderConfig from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder from embedchain.embedder.base import BaseEmbedder
from embedchain.models import EmbeddingFunctions from embedchain.models import VectorDimensions
class GPT4AllEmbedder(BaseEmbedder): class GPT4AllEmbedder(BaseEmbedder):
@@ -17,5 +17,5 @@ class GPT4AllEmbedder(BaseEmbedder):
embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=self.config.model) embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=self.config.model)
self.set_embedding_fn(embedding_fn=embedding_fn) self.set_embedding_fn(embedding_fn=embedding_fn)
vector_dimension = EmbeddingFunctions.GPT4ALL.value vector_dimension = VectorDimensions.GPT4ALL.value
self.set_vector_dimension(vector_dimension=vector_dimension) self.set_vector_dimension(vector_dimension=vector_dimension)

View File

@@ -4,7 +4,7 @@ from langchain.embeddings import HuggingFaceEmbeddings
from embedchain.config import BaseEmbedderConfig from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder from embedchain.embedder.base import BaseEmbedder
from embedchain.models import EmbeddingFunctions from embedchain.models import VectorDimensions
class HuggingFaceEmbedder(BaseEmbedder): class HuggingFaceEmbedder(BaseEmbedder):
@@ -15,5 +15,5 @@ class HuggingFaceEmbedder(BaseEmbedder):
embedding_fn = BaseEmbedder._langchain_default_concept(embeddings) embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
self.set_embedding_fn(embedding_fn=embedding_fn) self.set_embedding_fn(embedding_fn=embedding_fn)
vector_dimension = EmbeddingFunctions.HUGGING_FACE.value vector_dimension = VectorDimensions.HUGGING_FACE.value
self.set_vector_dimension(vector_dimension=vector_dimension) self.set_vector_dimension(vector_dimension=vector_dimension)

View File

@@ -5,7 +5,7 @@ from langchain.embeddings import OpenAIEmbeddings
from embedchain.config import BaseEmbedderConfig from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder from embedchain.embedder.base import BaseEmbedder
from embedchain.models import EmbeddingFunctions from embedchain.models import VectorDimensions
try: try:
from chromadb.utils import embedding_functions from chromadb.utils import embedding_functions
@@ -37,4 +37,4 @@ class OpenAiEmbedder(BaseEmbedder):
) )
self.set_embedding_fn(embedding_fn=embedding_fn) self.set_embedding_fn(embedding_fn=embedding_fn)
self.set_vector_dimension(vector_dimension=EmbeddingFunctions.OPENAI.value) self.set_vector_dimension(vector_dimension=VectorDimensions.OPENAI.value)

View File

@@ -4,7 +4,7 @@ from langchain.embeddings import VertexAIEmbeddings
from embedchain.config import BaseEmbedderConfig from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder from embedchain.embedder.base import BaseEmbedder
from embedchain.models import EmbeddingFunctions from embedchain.models import VectorDimensions
class VertexAiEmbedder(BaseEmbedder): class VertexAiEmbedder(BaseEmbedder):
@@ -15,5 +15,5 @@ class VertexAiEmbedder(BaseEmbedder):
embedding_fn = BaseEmbedder._langchain_default_concept(embeddings) embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
self.set_embedding_fn(embedding_fn=embedding_fn) self.set_embedding_fn(embedding_fn=embedding_fn)
vector_dimension = EmbeddingFunctions.VERTEX_AI.value vector_dimension = VectorDimensions.VERTEX_AI.value
self.set_vector_dimension(vector_dimension=vector_dimension) self.set_vector_dimension(vector_dimension=vector_dimension)

View File

@@ -87,7 +87,7 @@ class ChromaDB(BaseVectorDB):
) )
return self.collection return self.collection
def get(self, ids=None, where=None, limit=None): def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
""" """
Get existing doc ids present in vector database Get existing doc ids present in vector database
@@ -95,6 +95,8 @@ class ChromaDB(BaseVectorDB):
:type ids: List[str] :type ids: List[str]
:param where: Optional. to filter data :param where: Optional. to filter data
:type where: Dict[str, Any] :type where: Dict[str, Any]
:param limit: Optional. maximum number of documents
:type limit: Optional[int]
:return: Existing documents. :return: Existing documents.
:rtype: List[str] :rtype: List[str]
""" """
@@ -180,6 +182,8 @@ class ChromaDB(BaseVectorDB):
:param name: Name of the collection. :param name: Name of the collection.
:type name: str :type name: str
""" """
if not isinstance(name, str):
raise TypeError("Collection name must be a string")
self.config.collection_name = name self.config.collection_name = name
self._get_or_create_collection(self.config.collection_name) self._get_or_create_collection(self.config.collection_name)

View File

@@ -1,3 +1,4 @@
import logging
from typing import Dict, List, Optional, Set from typing import Dict, List, Optional, Set
try: try:
@@ -34,9 +35,15 @@ class ElasticsearchDB(BaseVectorDB):
:raises ValueError: No config provided :raises ValueError: No config provided
""" """
if config is None and es_config is None: if config is None and es_config is None:
raise ValueError("ElasticsearchDBConfig is required") self.config = ElasticsearchDBConfig()
else:
if not isinstance(config, ElasticsearchDBConfig):
raise TypeError(
"config is not a `ElasticsearchDBConfig` instance. "
"Please make sure the type is right and that you are passing an instance."
)
self.config = config or es_config self.config = config or es_config
self.client = Elasticsearch(es_config.ES_URL, **es_config.ES_EXTRA_PARAMS) self.client = Elasticsearch(self.config.ES_URL, **self.config.ES_EXTRA_PARAMS)
# Call parent init here because embedder is needed # Call parent init here because embedder is needed
super().__init__(config=self.config) super().__init__(config=self.config)
@@ -45,6 +52,7 @@ class ElasticsearchDB(BaseVectorDB):
""" """
This method is needed because `embedder` attribute needs to be set externally before it can be initialized. This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
""" """
logging.info(self.client.info())
index_settings = { index_settings = {
"mappings": { "mappings": {
"properties": { "properties": {
@@ -66,7 +74,9 @@ class ElasticsearchDB(BaseVectorDB):
def _get_or_create_collection(self, name): def _get_or_create_collection(self, name):
"""Note: nothing to return here. Discuss later""" """Note: nothing to return here. Discuss later"""
def get(self, ids: List[str], where: Dict[str, any]) -> Set[str]: def get(
self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None
) -> Set[str]:
""" """
Get existing doc ids present in vector database Get existing doc ids present in vector database
@@ -77,14 +87,18 @@ class ElasticsearchDB(BaseVectorDB):
:return: ids :return: ids
:rtype: Set[str] :rtype: Set[str]
""" """
if ids:
query = {"bool": {"must": [{"ids": {"values": ids}}]}} query = {"bool": {"must": [{"ids": {"values": ids}}]}}
else:
query = {"bool": {"must": []}}
if "app_id" in where: if "app_id" in where:
app_id = where["app_id"] app_id = where["app_id"]
query["bool"]["must"].append({"term": {"metadata.app_id": app_id}}) query["bool"]["must"].append({"term": {"metadata.app_id": app_id}})
response = self.client.search(index=self.es_index, query=query, _source=False)
response = self.client.search(index=self._get_index(), query=query, _source=False, size=limit)
docs = response["hits"]["hits"] docs = response["hits"]["hits"]
ids = [doc["_id"] for doc in docs] ids = [doc["_id"] for doc in docs]
return set(ids) return {"ids": set(ids)}
def add(self, documents: List[str], metadatas: List[object], ids: List[str]): def add(self, documents: List[str], metadatas: List[object], ids: List[str]):
"""add data in vector database """add data in vector database
@@ -150,6 +164,8 @@ class ElasticsearchDB(BaseVectorDB):
:param name: Name of the collection. :param name: Name of the collection.
:type name: str :type name: str
""" """
if not isinstance(name, str):
raise TypeError("Collection name must be a string")
self.config.collection_name = name self.config.collection_name = name
def count(self) -> int: def count(self) -> int:
@@ -181,4 +197,4 @@ class ElasticsearchDB(BaseVectorDB):
""" """
# NOTE: The method is preferred to an attribute, because if collection name changes, # NOTE: The method is preferred to an attribute, because if collection name changes,
# it's always up-to-date. # it's always up-to-date.
return f"{self.config.collection_name}_{self.embedder.vector_dimension}" return f"{self.config.collection_name}_{self.embedder.vector_dimension}".lower()

View File

@@ -0,0 +1,11 @@
import unittest
from embedchain.embedder.base import BaseEmbedder
class TestEmbedder(unittest.TestCase):
    """Unit tests for `BaseEmbedder`."""

    def test_init_with_invalid_vector_dim(self):
        """`set_vector_dimension` must reject a non-integer dimension with `TypeError`."""
        base_embedder = BaseEmbedder()
        # `None` is not an int, so the setter is expected to raise.
        self.assertRaises(TypeError, base_embedder.set_vector_dimension, None)

View File

@@ -1,29 +1,26 @@
import os
import unittest import unittest
from embedchain.config import ElasticsearchDBConfig from embedchain.config import ElasticsearchDBConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.vectordb.elasticsearch import ElasticsearchDB from embedchain.vectordb.elasticsearch import ElasticsearchDB
class TestEsDB(unittest.TestCase): class TestEsDB(unittest.TestCase):
def setUp(self): def setUp(self):
self.es_config = ElasticsearchDBConfig() self.es_config = ElasticsearchDBConfig(es_url="http://mock-url.net")
self.vector_dim = 384 self.vector_dim = 384
def test_init_without_url(self):
    """Constructing `ElasticsearchDB` without any URL must raise `AttributeError`."""
    # Make sure the URL is not loaded from the environment. Keep the previous
    # value so it can be put back and this test does not leak into others.
    previous_url = os.environ.pop("ELASTICSEARCH_URL", None)
    try:
        # Test if an exception is raised when neither a config nor
        # `ELASTICSEARCH_URL` provides a URL.
        with self.assertRaises(AttributeError):
            ElasticsearchDB()
    finally:
        # Restore the environment exactly as it was before the test ran.
        if previous_url is not None:
            os.environ["ELASTICSEARCH_URL"] = previous_url
def test_init_with_invalid_es_config(self): def test_init_with_invalid_es_config(self):
# Test if an exception is raised when an invalid es_config is provided # Test if an exception is raised when an invalid es_config is provided
with self.assertRaises(ValueError): with self.assertRaises(TypeError):
ElasticsearchDB(es_config=None) ElasticsearchDB(es_config={"ES_URL": "some_url", "valid es_config": False})
def test_init_with_invalid_vector_dim(self):
# Test if an exception is raised when an invalid vector_dim is provided
embedder = BaseEmbedder()
embedder.set_vector_dimension(None)
with self.assertRaises(ValueError):
ElasticsearchDB(es_config=self.es_config)
def test_init_with_invalid_collection_name(self):
# Test if an exception is raised when an invalid collection_name is provided
self.es_config.collection_name = None
with self.assertRaises(ValueError):
ElasticsearchDB(es_config=self.es_config)