[Misc] Lint code and fix code smells (#1871)
This commit is contained in:
@@ -15,14 +15,14 @@ class AzureOpenAIEmbedding(EmbeddingBase):
|
||||
azure_deployment = self.config.azure_kwargs.azure_deployment or os.getenv("EMBEDDING_AZURE_DEPLOYMENT")
|
||||
azure_endpoint = self.config.azure_kwargs.azure_endpoint or os.getenv("EMBEDDING_AZURE_ENDPOINT")
|
||||
api_version = self.config.azure_kwargs.api_version or os.getenv("EMBEDDING_AZURE_API_VERSION")
|
||||
|
||||
|
||||
self.client = AzureOpenAI(
|
||||
azure_deployment=azure_deployment,
|
||||
azure_deployment=azure_deployment,
|
||||
azure_endpoint=azure_endpoint,
|
||||
api_version=api_version,
|
||||
api_key=api_key,
|
||||
http_client=self.config.http_client
|
||||
)
|
||||
http_client=self.config.http_client,
|
||||
)
|
||||
|
||||
def embed(self, text):
|
||||
"""
|
||||
@@ -35,8 +35,4 @@ class AzureOpenAIEmbedding(EmbeddingBase):
|
||||
list: The embedding vector.
|
||||
"""
|
||||
text = text.replace("\n", " ")
|
||||
return (
|
||||
self.client.embeddings.create(input=[text], model=self.config.model)
|
||||
.data[0]
|
||||
.embedding
|
||||
)
|
||||
return self.client.embeddings.create(input=[text], model=self.config.model).data[0].embedding
|
||||
|
||||
@@ -8,9 +8,7 @@ class EmbedderConfig(BaseModel):
|
||||
description="Provider of the embedding model (e.g., 'ollama', 'openai')",
|
||||
default="openai",
|
||||
)
|
||||
config: Optional[dict] = Field(
|
||||
description="Configuration for the specific embedding model", default={}
|
||||
)
|
||||
config: Optional[dict] = Field(description="Configuration for the specific embedding model", default={})
|
||||
|
||||
@field_validator("config")
|
||||
def validate_config(cls, v, values):
|
||||
|
||||
@@ -9,7 +9,7 @@ try:
|
||||
from ollama import Client
|
||||
except ImportError:
|
||||
user_input = input("The 'ollama' library is required. Install it now? [y/N]: ")
|
||||
if user_input.lower() == 'y':
|
||||
if user_input.lower() == "y":
|
||||
try:
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "install", "ollama"])
|
||||
from ollama import Client
|
||||
|
||||
@@ -29,8 +29,4 @@ class OpenAIEmbedding(EmbeddingBase):
|
||||
list: The embedding vector.
|
||||
"""
|
||||
text = text.replace("\n", " ")
|
||||
return (
|
||||
self.client.embeddings.create(input=[text], model=self.config.model)
|
||||
.data[0]
|
||||
.embedding
|
||||
)
|
||||
return self.client.embeddings.create(input=[text], model=self.config.model).data[0].embedding
|
||||
|
||||
@@ -6,6 +6,7 @@ from vertexai.language_models import TextEmbeddingModel
|
||||
from mem0.configs.embeddings.base import BaseEmbedderConfig
|
||||
from mem0.embeddings.base import EmbeddingBase
|
||||
|
||||
|
||||
class VertexAI(EmbeddingBase):
|
||||
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
|
||||
super().__init__(config)
|
||||
@@ -34,6 +35,6 @@ class VertexAI(EmbeddingBase):
|
||||
Returns:
|
||||
list: The embedding vector.
|
||||
"""
|
||||
embeddings = self.model.get_embeddings(texts=[text], output_dimensionality= self.config.embedding_dims)
|
||||
|
||||
embeddings = self.model.get_embeddings(texts=[text], output_dimensionality=self.config.embedding_dims)
|
||||
|
||||
return embeddings[0].values
|
||||
|
||||
Reference in New Issue
Block a user