fix: elastic search (#600)

This commit is contained in:
cachho
2023-09-13 19:58:18 +02:00
committed by GitHub
parent 79efa51941
commit 119ec5e405
11 changed files with 135 additions and 55 deletions

View File

@@ -5,30 +5,66 @@ title: '💾 Vector Database'
We support `Chroma` and `Elasticsearch` as two vector database.
`Chroma` is used as a default database.
### Elasticsearch
In order to use `Elasticsearch` as vector database we need to use App type `CustomApp`.
## Elasticsearch
### Minimal Example
In order to use `Elasticsearch` as vector database we need to use App type `CustomApp`.
1. Set the environment variables in a `.env` file.
```
OPENAI_API_KEY=sk-SECRETKEY
ELASTICSEARCH_API_KEY=SECRETKEY==
ELASTICSEARCH_URL=https://secret-domain.europe-west3.gcp.cloud.es.io:443
```
Please note that the key needs certain privileges. For testing you can just toggle off `restrict privileges` under `/app/management/security/api_keys/` in your web interface.
2. Load the app
```python
from embedchain import CustomApp
from embedchain.embedder.openai import OpenAiEmbedder
from embedchain.llm.openai import OpenAILlm
from embedchain.vectordb.elasticsearch import ElasticsearchDB
es_app = CustomApp(
llm=OpenAILlm(),
embedder=OpenAiEmbedder(),
db=ElasticsearchDB(),
)
```
### More custom settings
You can get a URL for elasticsearch in the cloud, or run it locally.
The following example shows you how to configure embedchain to work with a locally running elasticsearch.
Instead of using an API key, we use http login credentials. The localhost url can be defined in .env or in the config.
```python
import os
from embedchain import CustomApp
from embedchain.config import CustomAppConfig, ElasticsearchDBConfig
from embedchain.models import Providers, EmbeddingFunctions, VectorDatabases
os.environ["OPENAI_API_KEY"] = 'OPENAI_API_KEY'
from embedchain.embedder.openai import OpenAiEmbedder
from embedchain.llm.openai import OpenAILlm
from embedchain.vectordb.elasticsearch import ElasticsearchDB
es_config = ElasticsearchDBConfig(
# elasticsearch url or list of nodes url with different hosts and ports.
es_url='http://localhost:9200',
# pass named parameters supported by Python Elasticsearch client
ca_certs="/path/to/http_ca.crt",
basic_auth=("username", "password")
# elasticsearch url or list of nodes url with different hosts and ports.
es_url='https://localhost:9200',
# pass named parameters supported by Python Elasticsearch client
http_auth=("elastic", "secret"),
ca_certs="~/binaries/elasticsearch-8.7.0/config/certs/http_ca.crt" # your cert path
# verify_certs=False # Alternative, if you aren't using certs
) # pass named parameters supported by elasticsearch-py
es_app = CustomApp(
config=CustomAppConfig(log_level="INFO"),
llm=OpenAILlm(),
embedder=OpenAiEmbedder(),
db=ElasticsearchDB(config=es_config),
)
config = CustomAppConfig(
embedding_fn=EmbeddingFunctions.OPENAI,
provider=Providers.OPENAI,
db_type=VectorDatabases.ELASTICSEARCH,
es_config=es_config,
)
es_app = CustomApp(config)
```
- Set `db_type=VectorDatabases.ELASTICSEARCH` and `es_config=ElasticsearchDBConfig(es_url='')` in `CustomAppConfig`.
- `ElasticsearchDBConfig` accepts `es_url` as elasticsearch url or as list of nodes url with different hosts and ports. Additionally we can pass named parameters supported by Python Elasticsearch client.
3. This should log your connection details to the console.
4. Alternatively to a URL, you `ElasticsearchDBConfig` accepts `es_url` as a list of nodes url with different hosts and ports.
5. Additionally we can pass named parameters supported by Python Elasticsearch client.

View File

@@ -1,3 +1,4 @@
import os
from typing import Dict, List, Optional, Union
from embedchain.config.vectordbs.BaseVectorDbConfig import BaseVectorDbConfig
@@ -26,7 +27,20 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
:type ES_EXTRA_PARAMS: Dict[str, Any], optional
"""
# self, es_url: Union[str, List[str]] = None, **ES_EXTRA_PARAMS: Dict[str, any]):
self.ES_URL = es_url
self.ES_URL = es_url or os.environ.get("ELASTICSEARCH_URL")
if not self.ES_URL:
raise AttributeError(
"Elasticsearch needs a URL attribute, "
"this can either be passed to `ElasticsearchDBConfig` or as `ELASTICSEARCH_URL` in `.env`"
)
self.ES_EXTRA_PARAMS = ES_EXTRA_PARAMS
# Load API key from .env if it's not explicitly passed.
# Can only set one of 'api_key', 'basic_auth', and 'bearer_auth'
if (
not self.ES_EXTRA_PARAMS.get("api_key")
and not self.ES_EXTRA_PARAMS.get("basic_auth")
and not self.ES_EXTRA_PARAMS.get("bearer_auth")
and not self.ES_EXTRA_PARAMS.get("http_auth")
):
self.ES_EXTRA_PARAMS["api_key"] = os.environ.get("ELASTICSEARCH_API_KEY")
super().__init__(collection_name=collection_name, dir=dir)

View File

@@ -51,6 +51,8 @@ class BaseEmbedder:
:param vector_dimension: vector dimension size
:type vector_dimension: int
"""
if not isinstance(vector_dimension, int):
raise TypeError("vector dimension must be int")
self.vector_dimension = vector_dimension
@staticmethod

View File

@@ -4,7 +4,7 @@ from chromadb.utils import embedding_functions
from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.models import EmbeddingFunctions
from embedchain.models import VectorDimensions
class GPT4AllEmbedder(BaseEmbedder):
@@ -17,5 +17,5 @@ class GPT4AllEmbedder(BaseEmbedder):
embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=self.config.model)
self.set_embedding_fn(embedding_fn=embedding_fn)
vector_dimension = EmbeddingFunctions.GPT4ALL.value
vector_dimension = VectorDimensions.GPT4ALL.value
self.set_vector_dimension(vector_dimension=vector_dimension)

View File

@@ -4,7 +4,7 @@ from langchain.embeddings import HuggingFaceEmbeddings
from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.models import EmbeddingFunctions
from embedchain.models import VectorDimensions
class HuggingFaceEmbedder(BaseEmbedder):
@@ -15,5 +15,5 @@ class HuggingFaceEmbedder(BaseEmbedder):
embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
self.set_embedding_fn(embedding_fn=embedding_fn)
vector_dimension = EmbeddingFunctions.HUGGING_FACE.value
vector_dimension = VectorDimensions.HUGGING_FACE.value
self.set_vector_dimension(vector_dimension=vector_dimension)

View File

@@ -5,7 +5,7 @@ from langchain.embeddings import OpenAIEmbeddings
from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.models import EmbeddingFunctions
from embedchain.models import VectorDimensions
try:
from chromadb.utils import embedding_functions
@@ -37,4 +37,4 @@ class OpenAiEmbedder(BaseEmbedder):
)
self.set_embedding_fn(embedding_fn=embedding_fn)
self.set_vector_dimension(vector_dimension=EmbeddingFunctions.OPENAI.value)
self.set_vector_dimension(vector_dimension=VectorDimensions.OPENAI.value)

View File

@@ -4,7 +4,7 @@ from langchain.embeddings import VertexAIEmbeddings
from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.models import EmbeddingFunctions
from embedchain.models import VectorDimensions
class VertexAiEmbedder(BaseEmbedder):
@@ -15,5 +15,5 @@ class VertexAiEmbedder(BaseEmbedder):
embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
self.set_embedding_fn(embedding_fn=embedding_fn)
vector_dimension = EmbeddingFunctions.VERTEX_AI.value
vector_dimension = VectorDimensions.VERTEX_AI.value
self.set_vector_dimension(vector_dimension=vector_dimension)

View File

@@ -87,7 +87,7 @@ class ChromaDB(BaseVectorDB):
)
return self.collection
def get(self, ids=None, where=None, limit=None):
def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
"""
Get existing doc ids present in vector database
@@ -95,6 +95,8 @@ class ChromaDB(BaseVectorDB):
:type ids: List[str]
:param where: Optional. to filter data
:type where: Dict[str, Any]
:param limit: Optional. maximum number of documents
:type limit: Optional[int]
:return: Existing documents.
:rtype: List[str]
"""
@@ -180,6 +182,8 @@ class ChromaDB(BaseVectorDB):
:param name: Name of the collection.
:type name: str
"""
if not isinstance(name, str):
raise TypeError("Collection name must be a string")
self.config.collection_name = name
self._get_or_create_collection(self.config.collection_name)

View File

@@ -1,3 +1,4 @@
import logging
from typing import Dict, List, Optional, Set
try:
@@ -34,9 +35,15 @@ class ElasticsearchDB(BaseVectorDB):
:raises ValueError: No config provided
"""
if config is None and es_config is None:
raise ValueError("ElasticsearchDBConfig is required")
self.config = config or es_config
self.client = Elasticsearch(es_config.ES_URL, **es_config.ES_EXTRA_PARAMS)
self.config = ElasticsearchDBConfig()
else:
if not isinstance(config, ElasticsearchDBConfig):
raise TypeError(
"config is not a `ElasticsearchDBConfig` instance. "
"Please make sure the type is right and that you are passing an instance."
)
self.config = config or es_config
self.client = Elasticsearch(self.config.ES_URL, **self.config.ES_EXTRA_PARAMS)
# Call parent init here because embedder is needed
super().__init__(config=self.config)
@@ -45,6 +52,7 @@ class ElasticsearchDB(BaseVectorDB):
"""
This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
"""
logging.info(self.client.info())
index_settings = {
"mappings": {
"properties": {
@@ -66,7 +74,9 @@ class ElasticsearchDB(BaseVectorDB):
def _get_or_create_collection(self, name):
"""Note: nothing to return here. Discuss later"""
def get(self, ids: List[str], where: Dict[str, any]) -> Set[str]:
def get(
self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None
) -> Set[str]:
"""
Get existing doc ids present in vector database
@@ -77,14 +87,18 @@ class ElasticsearchDB(BaseVectorDB):
:return: ids
:rtype: Set[str]
"""
query = {"bool": {"must": [{"ids": {"values": ids}}]}}
if ids:
query = {"bool": {"must": [{"ids": {"values": ids}}]}}
else:
query = {"bool": {"must": []}}
if "app_id" in where:
app_id = where["app_id"]
query["bool"]["must"].append({"term": {"metadata.app_id": app_id}})
response = self.client.search(index=self.es_index, query=query, _source=False)
response = self.client.search(index=self._get_index(), query=query, _source=False, size=limit)
docs = response["hits"]["hits"]
ids = [doc["_id"] for doc in docs]
return set(ids)
return {"ids": set(ids)}
def add(self, documents: List[str], metadatas: List[object], ids: List[str]):
"""add data in vector database
@@ -150,6 +164,8 @@ class ElasticsearchDB(BaseVectorDB):
:param name: Name of the collection.
:type name: str
"""
if not isinstance(name, str):
raise TypeError("Collection name must be a string")
self.config.collection_name = name
def count(self) -> int:
@@ -181,4 +197,4 @@ class ElasticsearchDB(BaseVectorDB):
"""
# NOTE: The method is preferred to an attribute, because if collection name changes,
# it's always up-to-date.
return f"{self.config.collection_name}_{self.embedder.vector_dimension}"
return f"{self.config.collection_name}_{self.embedder.vector_dimension}".lower()

View File

@@ -0,0 +1,11 @@
import unittest
from embedchain.embedder.base import BaseEmbedder
class TestEmbedder(unittest.TestCase):
def test_init_with_invalid_vector_dim(self):
# Test if an exception is raised when an invalid vector_dim is provided
embedder = BaseEmbedder()
with self.assertRaises(TypeError):
embedder.set_vector_dimension(None)

View File

@@ -1,29 +1,26 @@
import os
import unittest
from embedchain.config import ElasticsearchDBConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.vectordb.elasticsearch import ElasticsearchDB
class TestEsDB(unittest.TestCase):
def setUp(self):
self.es_config = ElasticsearchDBConfig()
self.es_config = ElasticsearchDBConfig(es_url="http://mock-url.net")
self.vector_dim = 384
def test_init_without_url(self):
# Make sure it's not loaded from env
try:
del os.environ["ELASTICSEARCH_URL"]
except KeyError:
pass
# Test if an exception is raised when an invalid es_config is provided
with self.assertRaises(AttributeError):
ElasticsearchDB()
def test_init_with_invalid_es_config(self):
# Test if an exception is raised when an invalid es_config is provided
with self.assertRaises(ValueError):
ElasticsearchDB(es_config=None)
def test_init_with_invalid_vector_dim(self):
# Test if an exception is raised when an invalid vector_dim is provided
embedder = BaseEmbedder()
embedder.set_vector_dimension(None)
with self.assertRaises(ValueError):
ElasticsearchDB(es_config=self.es_config)
def test_init_with_invalid_collection_name(self):
# Test if an exception is raised when an invalid collection_name is provided
self.es_config.collection_name = None
with self.assertRaises(ValueError):
ElasticsearchDB(es_config=self.es_config)
with self.assertRaises(TypeError):
ElasticsearchDB(es_config={"ES_URL": "some_url", "valid es_config": False})