Add embedder docs and config changes (#1684)

This commit is contained in:
Dev Khant
2024-08-12 16:09:01 +05:30
committed by GitHub
parent 464a188662
commit b245309242
12 changed files with 130 additions and 28 deletions

View File

@@ -9,9 +9,12 @@ class BaseEmbedderConfig(ABC):
def __init__(
self,
model: Optional[str] = None,
api_key: Optional[str] = None,
embedding_dims: Optional[int] = None,
# Ollama specific
base_url: Optional[str] = None,
ollama_base_url: Optional[str] = None,
# Huggingface specific
model_kwargs: Optional[dict] = None
):
@@ -20,20 +23,23 @@ class BaseEmbedderConfig(ABC):
:param model: Embedding model to use, defaults to None
:type model: Optional[str], optional
:param api_key: API key to be use, defaults to None
:type api_key: Optional[str], optional
:param embedding_dims: The number of dimensions in the embedding, defaults to None
:type embedding_dims: Optional[int], optional
:param base_url: Base URL for the Ollama API, defaults to None
:type base_url: Optional[str], optional
:param ollama_base_url: Base URL for the Ollama API, defaults to None
:type ollama_base_url: Optional[str], optional
:param model_kwargs: key-value arguments for the huggingface embedding model, defaults a dict inside init
:type model_kwargs: Optional[Dict[str, Any]], defaults a dict inside init
"""
self.model = model
self.api_key = api_key
self.embedding_dims = embedding_dims
# Ollama specific
self.base_url = base_url
self.ollama_base_url = ollama_base_url
# Huggingface specific
self.model_kwargs = model_kwargs or {}

View File

@@ -10,6 +10,7 @@ class BaseLlmConfig(ABC):
self,
model: Optional[str] = None,
temperature: float = 0,
api_key: Optional[str] = None,
max_tokens: int = 3000,
top_p: float = 0,
top_k: int = 1,
@@ -32,6 +33,8 @@ class BaseLlmConfig(ABC):
:param temperature: Controls the randomness of the model's output.
Higher values (closer to 1) make output more random, lower values make it more deterministic, defaults to 0
:type temperature: float, optional
:param api_key: OpenAI API key to be use, defaults to None
:type api_key: Optional[str], optional
:param max_tokens: Controls how many tokens are generated, defaults to 3000
:type max_tokens: int, optional
:param top_p: Controls the diversity of words. Higher values (closer to 1) make word selection more diverse,
@@ -39,15 +42,15 @@ class BaseLlmConfig(ABC):
:type top_p: float, optional
:param top_k: Controls the diversity of words. Higher values make word selection more diverse, defaults to 0
:type top_k: int, optional
:param models: Controls the Openrouter models used, defaults to None
:param models: Openrouter models to use, defaults to None
:type models: Optional[list[str]], optional
:param route: Controls the Openrouter route used, defaults to "fallback"
:param route: Openrouter route to be used, defaults to "fallback"
:type route: Optional[str], optional
:param openrouter_base_url: Controls the Openrouter base URL used, defaults to "https://openrouter.ai/api/v1"
:param openrouter_base_url: Openrouter base URL to be use, defaults to "https://openrouter.ai/api/v1"
:type openrouter_base_url: Optional[str], optional
:param site_url: Controls the Openrouter site URL used, defaults to None
:param site_url: Openrouter site URL to use, defaults to None
:type site_url: Optional[str], optional
:param app_name: Controls the Openrouter app name used, defaults to None
:param app_name: Openrouter app name to use, defaults to None
:type app_name: Optional[str], optional
:param ollama_base_url: The base URL of the LLM, defaults to None
:type ollama_base_url: Optional[str], optional
@@ -55,6 +58,7 @@ class BaseLlmConfig(ABC):
self.model = model
self.temperature = temperature
self.api_key = api_key
self.max_tokens = max_tokens
self.top_p = top_p
self.top_k = top_k

View File

@@ -1,4 +1,4 @@
from typing import Optional,ClassVar
from typing import Optional, ClassVar, Dict, Any
from pydantic import BaseModel, Field, model_validator
@@ -21,6 +21,17 @@ class ChromaDbConfig(BaseModel):
if not path and not (host and port):
raise ValueError("Either 'host' and 'port' or 'path' must be provided.")
return values
@model_validator(mode="before")
@classmethod
def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
allowed_fields = set(cls.model_fields.keys())
input_fields = set(values.keys())
extra_fields = input_fields - allowed_fields
if extra_fields:
raise ValueError(f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}")
return values
class Config:
arbitrary_types_allowed = True
model_config = {
"arbitrary_types_allowed": True,
}

View File

@@ -1,6 +1,5 @@
from typing import Optional,ClassVar
from pydantic import BaseModel, Field, model_validator
from typing import Optional, ClassVar, Dict, Any
class QdrantConfig(BaseModel):
from qdrant_client import QdrantClient
@@ -14,10 +13,11 @@ class QdrantConfig(BaseModel):
path: Optional[str] = Field("/tmp/qdrant", description="Path for local Qdrant database")
url: Optional[str] = Field(None, description="Full URL for Qdrant server")
api_key: Optional[str] = Field(None, description="API key for Qdrant server")
on_disk: Optional[bool]= Field(False, description="Enables persistant storage")
on_disk: Optional[bool] = Field(False, description="Enables persistent storage")
@model_validator(mode="before")
def check_host_port_or_path(cls, values):
@classmethod
def check_host_port_or_path(cls, values: Dict[str, Any]) -> Dict[str, Any]:
host, port, path, url, api_key = (
values.get("host"),
values.get("port"),
@@ -31,5 +31,16 @@ class QdrantConfig(BaseModel):
)
return values
class Config:
arbitrary_types_allowed = True
@model_validator(mode="before")
@classmethod
def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
allowed_fields = set(cls.model_fields.keys())
input_fields = set(values.keys())
extra_fields = input_fields - allowed_fields
if extra_fields:
raise ValueError(f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}")
return values
model_config = {
"arbitrary_types_allowed": True,
}