improvement(OSS): Fix AOSS and AWS BedRock LLM (#2697)

Co-authored-by: Prateek Chhikara <prateekchhikara24@gmail.com>
Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
This commit is contained in:
Saket Aryan
2025-05-16 04:49:29 +05:30
committed by GitHub
parent 267e5b13ea
commit 5c67a5e6bc
14 changed files with 502 additions and 127 deletions

View File

@@ -33,6 +33,10 @@ class BaseEmbedderConfig(ABC):
memory_search_embedding_type: Optional[str] = None,
# LM Studio specific
lmstudio_base_url: Optional[str] = "http://localhost:1234/v1",
# AWS Bedrock specific
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_region: Optional[str] = "us-west-2",
):
"""
Initializes a configuration class instance for the Embeddings.
@@ -92,3 +96,8 @@ class BaseEmbedderConfig(ABC):
# LM Studio specific
self.lmstudio_base_url = lmstudio_base_url
# AWS Bedrock specific
self.aws_access_key_id = aws_access_key_id
self.aws_secret_access_key = aws_secret_access_key
self.aws_region = aws_region

View File

@@ -41,6 +41,10 @@ class BaseLlmConfig(ABC):
xai_base_url: Optional[str] = None,
# LM Studio specific
lmstudio_base_url: Optional[str] = "http://localhost:1234/v1",
# AWS Bedrock specific
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_region: Optional[str] = "us-west-2",
):
"""
Initializes a configuration class instance for the LLM.
@@ -123,3 +127,8 @@ class BaseLlmConfig(ABC):
# LM Studio specific
self.lmstudio_base_url = lmstudio_base_url
# AWS Bedrock specific
self.aws_access_key_id = aws_access_key_id
self.aws_secret_access_key = aws_secret_access_key
self.aws_region = aws_region

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union, Type
from pydantic import BaseModel, Field, model_validator
@@ -7,14 +7,33 @@ class OpenSearchConfig(BaseModel):
collection_name: str = Field("mem0", description="Name of the index")
host: str = Field("localhost", description="OpenSearch host")
port: int = Field(9200, description="OpenSearch port")
user: Optional[str] = Field(None, description="Username for authentication")
password: Optional[str] = Field(None, description="Password for authentication")
api_key: Optional[str] = Field(None, description="API key for authentication (if applicable)")
embedding_model_dims: int = Field(1536, description="Dimension of the embedding vector")
verify_certs: bool = Field(False, description="Verify SSL certificates (default False for OpenSearch)")
use_ssl: bool = Field(False, description="Use SSL for connection (default False for OpenSearch)")
auto_create_index: bool = Field(True, description="Automatically create index during initialization")
http_auth: Optional[object] = Field(None, description="HTTP authentication method / AWS SigV4")
user: Optional[str] = Field(
None, description="Username for authentication"
)
password: Optional[str] = Field(
None, description="Password for authentication"
)
api_key: Optional[str] = Field(
None, description="API key for authentication (if applicable)"
)
embedding_model_dims: int = Field(
1536, description="Dimension of the embedding vector"
)
verify_certs: bool = Field(
False, description="Verify SSL certificates (default False for OpenSearch)"
)
use_ssl: bool = Field(
False, description="Use SSL for connection (default False for OpenSearch)"
)
http_auth: Optional[object] = Field(
None, description="HTTP authentication method / AWS SigV4"
)
connection_class: Optional[Union[str, Type]] = Field(
"RequestsHttpConnection", description="Connection class for OpenSearch"
)
pool_maxsize: int = Field(
20, description="Maximum number of connections in the pool"
)
@model_validator(mode="before")
@classmethod
@@ -22,11 +41,7 @@ class OpenSearchConfig(BaseModel):
# Check if host is provided
if not values.get("host"):
raise ValueError("Host must be provided for OpenSearch")
# Authentication: Either API key or user/password must be provided
if not any([values.get("api_key"), (values.get("user") and values.get("password")), values.get("http_auth")]):
raise ValueError("Either api_key or user/password must be provided for OpenSearch authentication")
return values
@model_validator(mode="before")
@@ -37,6 +52,7 @@ class OpenSearchConfig(BaseModel):
extra_fields = input_fields - allowed_fields
if extra_fields:
raise ValueError(
f"Extra fields not allowed: {', '.join(extra_fields)}. " f"Allowed fields: {', '.join(allowed_fields)}"
f"Extra fields not allowed: {', '.join(extra_fields)}. "
f"Allowed fields: {', '.join(allowed_fields)}"
)
return values