Add embedder docs and config changes (#1684)
This commit is contained in:
@@ -9,9 +9,12 @@ class BaseEmbedderConfig(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
embedding_dims: Optional[int] = None,
|
||||
|
||||
# Ollama specific
|
||||
base_url: Optional[str] = None,
|
||||
ollama_base_url: Optional[str] = None,
|
||||
|
||||
# Huggingface specific
|
||||
model_kwargs: Optional[dict] = None
|
||||
):
|
||||
@@ -20,20 +23,23 @@ class BaseEmbedderConfig(ABC):
|
||||
|
||||
:param model: Embedding model to use, defaults to None
|
||||
:type model: Optional[str], optional
|
||||
:param api_key: API key to be use, defaults to None
|
||||
:type api_key: Optional[str], optional
|
||||
:param embedding_dims: The number of dimensions in the embedding, defaults to None
|
||||
:type embedding_dims: Optional[int], optional
|
||||
:param base_url: Base URL for the Ollama API, defaults to None
|
||||
:type base_url: Optional[str], optional
|
||||
:param ollama_base_url: Base URL for the Ollama API, defaults to None
|
||||
:type ollama_base_url: Optional[str], optional
|
||||
:param model_kwargs: key-value arguments for the huggingface embedding model, defaults a dict inside init
|
||||
:type model_kwargs: Optional[Dict[str, Any]], defaults a dict inside init
|
||||
|
||||
"""
|
||||
|
||||
self.model = model
|
||||
self.api_key = api_key
|
||||
self.embedding_dims = embedding_dims
|
||||
|
||||
# Ollama specific
|
||||
self.base_url = base_url
|
||||
self.ollama_base_url = ollama_base_url
|
||||
|
||||
# Huggingface specific
|
||||
self.model_kwargs = model_kwargs or {}
|
||||
|
||||
@@ -10,6 +10,7 @@ class BaseLlmConfig(ABC):
|
||||
self,
|
||||
model: Optional[str] = None,
|
||||
temperature: float = 0,
|
||||
api_key: Optional[str] = None,
|
||||
max_tokens: int = 3000,
|
||||
top_p: float = 0,
|
||||
top_k: int = 1,
|
||||
@@ -32,6 +33,8 @@ class BaseLlmConfig(ABC):
|
||||
:param temperature: Controls the randomness of the model's output.
|
||||
Higher values (closer to 1) make output more random, lower values make it more deterministic, defaults to 0
|
||||
:type temperature: float, optional
|
||||
:param api_key: OpenAI API key to be use, defaults to None
|
||||
:type api_key: Optional[str], optional
|
||||
:param max_tokens: Controls how many tokens are generated, defaults to 3000
|
||||
:type max_tokens: int, optional
|
||||
:param top_p: Controls the diversity of words. Higher values (closer to 1) make word selection more diverse,
|
||||
@@ -39,15 +42,15 @@ class BaseLlmConfig(ABC):
|
||||
:type top_p: float, optional
|
||||
:param top_k: Controls the diversity of words. Higher values make word selection more diverse, defaults to 0
|
||||
:type top_k: int, optional
|
||||
:param models: Controls the Openrouter models used, defaults to None
|
||||
:param models: Openrouter models to use, defaults to None
|
||||
:type models: Optional[list[str]], optional
|
||||
:param route: Controls the Openrouter route used, defaults to "fallback"
|
||||
:param route: Openrouter route to be used, defaults to "fallback"
|
||||
:type route: Optional[str], optional
|
||||
:param openrouter_base_url: Controls the Openrouter base URL used, defaults to "https://openrouter.ai/api/v1"
|
||||
:param openrouter_base_url: Openrouter base URL to be use, defaults to "https://openrouter.ai/api/v1"
|
||||
:type openrouter_base_url: Optional[str], optional
|
||||
:param site_url: Controls the Openrouter site URL used, defaults to None
|
||||
:param site_url: Openrouter site URL to use, defaults to None
|
||||
:type site_url: Optional[str], optional
|
||||
:param app_name: Controls the Openrouter app name used, defaults to None
|
||||
:param app_name: Openrouter app name to use, defaults to None
|
||||
:type app_name: Optional[str], optional
|
||||
:param ollama_base_url: The base URL of the LLM, defaults to None
|
||||
:type ollama_base_url: Optional[str], optional
|
||||
@@ -55,6 +58,7 @@ class BaseLlmConfig(ABC):
|
||||
|
||||
self.model = model
|
||||
self.temperature = temperature
|
||||
self.api_key = api_key
|
||||
self.max_tokens = max_tokens
|
||||
self.top_p = top_p
|
||||
self.top_k = top_k
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Optional,ClassVar
|
||||
from typing import Optional, ClassVar, Dict, Any
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
@@ -21,6 +21,17 @@ class ChromaDbConfig(BaseModel):
|
||||
if not path and not (host and port):
|
||||
raise ValueError("Either 'host' and 'port' or 'path' must be provided.")
|
||||
return values
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
allowed_fields = set(cls.model_fields.keys())
|
||||
input_fields = set(values.keys())
|
||||
extra_fields = input_fields - allowed_fields
|
||||
if extra_fields:
|
||||
raise ValueError(f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}")
|
||||
return values
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = {
|
||||
"arbitrary_types_allowed": True,
|
||||
}
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import Optional,ClassVar
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing import Optional, ClassVar, Dict, Any
|
||||
|
||||
class QdrantConfig(BaseModel):
|
||||
from qdrant_client import QdrantClient
|
||||
@@ -14,10 +13,11 @@ class QdrantConfig(BaseModel):
|
||||
path: Optional[str] = Field("/tmp/qdrant", description="Path for local Qdrant database")
|
||||
url: Optional[str] = Field(None, description="Full URL for Qdrant server")
|
||||
api_key: Optional[str] = Field(None, description="API key for Qdrant server")
|
||||
on_disk: Optional[bool]= Field(False, description="Enables persistant storage")
|
||||
on_disk: Optional[bool] = Field(False, description="Enables persistent storage")
|
||||
|
||||
@model_validator(mode="before")
|
||||
def check_host_port_or_path(cls, values):
|
||||
@classmethod
|
||||
def check_host_port_or_path(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
host, port, path, url, api_key = (
|
||||
values.get("host"),
|
||||
values.get("port"),
|
||||
@@ -31,5 +31,16 @@ class QdrantConfig(BaseModel):
|
||||
)
|
||||
return values
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
allowed_fields = set(cls.model_fields.keys())
|
||||
input_fields = set(values.keys())
|
||||
extra_fields = input_fields - allowed_fields
|
||||
if extra_fields:
|
||||
raise ValueError(f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}")
|
||||
return values
|
||||
|
||||
model_config = {
|
||||
"arbitrary_types_allowed": True,
|
||||
}
|
||||
@@ -18,7 +18,7 @@ class OllamaEmbedding(EmbeddingBase):
|
||||
if not self.config.embedding_dims:
|
||||
self.config.embedding_dims=512
|
||||
|
||||
self.client = Client(host=self.config.base_url)
|
||||
self.client = Client(host=self.config.ollama_base_url)
|
||||
self._ensure_model_exists()
|
||||
|
||||
def _ensure_model_exists(self):
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from openai import OpenAI
|
||||
@@ -9,12 +10,11 @@ class OpenAIEmbedding(EmbeddingBase):
|
||||
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
|
||||
super().__init__(config)
|
||||
|
||||
if not self.config.model:
|
||||
self.config.model="text-embedding-3-small"
|
||||
if not self.config.embedding_dims:
|
||||
self.config.embedding_dims=1536
|
||||
self.config.model = self.config.model or "text-embedding-3-small"
|
||||
self.config.embedding_dims = self.config.embedding_dims or 1536
|
||||
|
||||
self.client = OpenAI()
|
||||
api_key = os.getenv("OPENAI_API_KEY") or self.config.api_key
|
||||
self.client = OpenAI(api_key=api_key)
|
||||
|
||||
def embed(self, text):
|
||||
"""
|
||||
|
||||
@@ -14,10 +14,11 @@ class OpenAILLM(LLMBase):
|
||||
if not self.config.model:
|
||||
self.config.model="gpt-4o"
|
||||
|
||||
if os.environ.get("OPENROUTER_API_KEY"):
|
||||
if os.environ.get("OPENROUTER_API_KEY"): # Use OpenRouter
|
||||
self.client = OpenAI(api_key=os.environ.get("OPENROUTER_API_KEY"), base_url=self.config.openrouter_base_url)
|
||||
else:
|
||||
self.client = OpenAI()
|
||||
api_key = os.getenv("OPENAI_API_KEY") or self.config.api_key
|
||||
self.client = OpenAI(api_key=api_key)
|
||||
|
||||
def _parse_response(self, response, tools):
|
||||
"""
|
||||
|
||||
@@ -374,6 +374,7 @@ class Memory(MemoryBase):
|
||||
|
||||
new_metadata = metadata or {}
|
||||
new_metadata["data"] = data
|
||||
new_metadata["hash"] = existing_memory.payload.get("hash")
|
||||
new_metadata["created_at"] = existing_memory.payload.get("created_at")
|
||||
new_metadata["updated_at"] = datetime.now(pytz.timezone('US/Pacific')).isoformat()
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
class VectorStoreConfig(BaseModel):
|
||||
provider: str = Field(
|
||||
description="Provider of the vector store (e.g., 'qdrant', 'chromadb')",
|
||||
description="Provider of the vector store (e.g., 'qdrant', 'chroma')",
|
||||
default="qdrant",
|
||||
)
|
||||
config: Optional[Dict] = Field(
|
||||
|
||||
Reference in New Issue
Block a user