[Misc] Lint code and fix code smells (#1871)

This commit is contained in:
Deshraj Yadav
2024-09-16 17:39:54 -07:00
committed by GitHub
parent 0a78cb9f7a
commit 55c54beeab
57 changed files with 1178 additions and 1357 deletions

View File

@@ -15,14 +15,14 @@ class AzureOpenAIEmbedding(EmbeddingBase):
azure_deployment = self.config.azure_kwargs.azure_deployment or os.getenv("EMBEDDING_AZURE_DEPLOYMENT")
azure_endpoint = self.config.azure_kwargs.azure_endpoint or os.getenv("EMBEDDING_AZURE_ENDPOINT")
api_version = self.config.azure_kwargs.api_version or os.getenv("EMBEDDING_AZURE_API_VERSION")
self.client = AzureOpenAI(
azure_deployment=azure_deployment,
azure_deployment=azure_deployment,
azure_endpoint=azure_endpoint,
api_version=api_version,
api_key=api_key,
http_client=self.config.http_client
)
http_client=self.config.http_client,
)
def embed(self, text):
"""
@@ -35,8 +35,4 @@ class AzureOpenAIEmbedding(EmbeddingBase):
list: The embedding vector.
"""
text = text.replace("\n", " ")
return (
self.client.embeddings.create(input=[text], model=self.config.model)
.data[0]
.embedding
)
return self.client.embeddings.create(input=[text], model=self.config.model).data[0].embedding

View File

@@ -8,9 +8,7 @@ class EmbedderConfig(BaseModel):
description="Provider of the embedding model (e.g., 'ollama', 'openai')",
default="openai",
)
config: Optional[dict] = Field(
description="Configuration for the specific embedding model", default={}
)
config: Optional[dict] = Field(description="Configuration for the specific embedding model", default={})
@field_validator("config")
def validate_config(cls, v, values):

View File

@@ -9,7 +9,7 @@ try:
from ollama import Client
except ImportError:
user_input = input("The 'ollama' library is required. Install it now? [y/N]: ")
if user_input.lower() == 'y':
if user_input.lower() == "y":
try:
subprocess.check_call([sys.executable, "-m", "pip", "install", "ollama"])
from ollama import Client

View File

@@ -29,8 +29,4 @@ class OpenAIEmbedding(EmbeddingBase):
list: The embedding vector.
"""
text = text.replace("\n", " ")
return (
self.client.embeddings.create(input=[text], model=self.config.model)
.data[0]
.embedding
)
return self.client.embeddings.create(input=[text], model=self.config.model).data[0].embedding

View File

@@ -6,6 +6,7 @@ from vertexai.language_models import TextEmbeddingModel
from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.base import EmbeddingBase
class VertexAI(EmbeddingBase):
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
super().__init__(config)
@@ -34,6 +35,6 @@ class VertexAI(EmbeddingBase):
Returns:
list: The embedding vector.
"""
embeddings = self.model.get_embeddings(texts=[text], output_dimensionality= self.config.embedding_dims)
embeddings = self.model.get_embeddings(texts=[text], output_dimensionality=self.config.embedding_dims)
return embeddings[0].values