tools fix and formatting (#2441)

This commit is contained in:
Dev Khant
2025-03-26 11:25:03 +05:30
committed by GitHub
parent 2517ccd489
commit 2004427acd
18 changed files with 536 additions and 1151 deletions

View File

@@ -13,7 +13,16 @@ class EmbedderConfig(BaseModel):
@field_validator("config")
def validate_config(cls, v, values):
provider = values.data.get("provider")
if provider in ["openai", "ollama", "huggingface", "azure_openai", "gemini", "vertexai", "together", "lmstudio"]:
if provider in [
"openai",
"ollama",
"huggingface",
"azure_openai",
"gemini",
"vertexai",
"together",
"lmstudio",
]:
return v
else:
raise ValueError(f"Unsupported embedding provider: {provider}")

View File

@@ -28,5 +28,7 @@ class GoogleGenAIEmbedding(EmbeddingBase):
list: The embedding vector.
"""
text = text.replace("\n", " ")
response = genai.embed_content(model=self.config.model, content=text, output_dimensionality=self.config.embedding_dims)
response = genai.embed_content(
model=self.config.model, content=text, output_dimensionality=self.config.embedding_dims
)
return response["embedding"]

View File

@@ -26,8 +26,4 @@ class LMStudioEmbedding(EmbeddingBase):
list: The embedding vector.
"""
text = text.replace("\n", " ")
return (
self.client.embeddings.create(input=[text], model=self.config.model)
.data[0]
.embedding
)
return self.client.embeddings.create(input=[text], model=self.config.model).data[0].embedding

View File

@@ -17,16 +17,16 @@ class OpenAIEmbedding(EmbeddingBase):
api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
base_url = (
self.config.openai_base_url
or os.getenv("OPENAI_API_BASE")
or os.getenv("OPENAI_BASE_URL")
or "https://api.openai.com/v1"
self.config.openai_base_url
or os.getenv("OPENAI_API_BASE")
or os.getenv("OPENAI_BASE_URL")
or "https://api.openai.com/v1"
)
if os.environ.get("OPENAI_API_BASE"):
warnings.warn(
"The environment variable 'OPENAI_API_BASE' is deprecated and will be removed in the 0.1.80. "
"Please use 'OPENAI_BASE_URL' instead.",
DeprecationWarning
DeprecationWarning,
)
self.client = OpenAI(api_key=api_key, base_url=base_url)
@@ -42,4 +42,8 @@ class OpenAIEmbedding(EmbeddingBase):
list: The embedding vector.
"""
text = text.replace("\n", " ")
return self.client.embeddings.create(input=[text], model=self.config.model, dimensions = self.config.embedding_dims).data[0].embedding
return (
self.client.embeddings.create(input=[text], model=self.config.model, dimensions=self.config.embedding_dims)
.data[0]
.embedding
)