Add embedder docs and config changes (#1684)

This commit is contained in:
Dev Khant
2024-08-12 16:09:01 +05:30
committed by GitHub
parent 464a188662
commit b245309242
12 changed files with 130 additions and 28 deletions

View File

@@ -1,4 +1,4 @@
from typing import Optional,ClassVar
from typing import Optional, ClassVar, Dict, Any
from pydantic import BaseModel, Field, model_validator
@@ -21,6 +21,17 @@ class ChromaDbConfig(BaseModel):
if not path and not (host and port):
raise ValueError("Either 'host' and 'port' or 'path' must be provided.")
return values
@model_validator(mode="before")
@classmethod
def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
allowed_fields = set(cls.model_fields.keys())
input_fields = set(values.keys())
extra_fields = input_fields - allowed_fields
if extra_fields:
raise ValueError(f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}")
return values
class Config:
arbitrary_types_allowed = True
model_config = {
"arbitrary_types_allowed": True,
}

View File

@@ -1,6 +1,5 @@
from typing import Optional,ClassVar
from pydantic import BaseModel, Field, model_validator
from typing import Optional, ClassVar, Dict, Any
class QdrantConfig(BaseModel):
from qdrant_client import QdrantClient
@@ -14,10 +13,11 @@ class QdrantConfig(BaseModel):
path: Optional[str] = Field("/tmp/qdrant", description="Path for local Qdrant database")
url: Optional[str] = Field(None, description="Full URL for Qdrant server")
api_key: Optional[str] = Field(None, description="API key for Qdrant server")
on_disk: Optional[bool]= Field(False, description="Enables persistant storage")
on_disk: Optional[bool] = Field(False, description="Enables persistent storage")
@model_validator(mode="before")
def check_host_port_or_path(cls, values):
@classmethod
def check_host_port_or_path(cls, values: Dict[str, Any]) -> Dict[str, Any]:
host, port, path, url, api_key = (
values.get("host"),
values.get("port"),
@@ -31,5 +31,16 @@ class QdrantConfig(BaseModel):
)
return values
class Config:
arbitrary_types_allowed = True
@model_validator(mode="before")
@classmethod
def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
allowed_fields = set(cls.model_fields.keys())
input_fields = set(values.keys())
extra_fields = input_fields - allowed_fields
if extra_fields:
raise ValueError(f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}")
return values
model_config = {
"arbitrary_types_allowed": True,
}