[Misc] Lint code and fix code smells (#1871)
This commit is contained in:
@@ -17,18 +17,10 @@ class MemoryItem(BaseModel):
|
||||
) # TODO After prompt changes from platform, update this
|
||||
hash: Optional[str] = Field(None, description="The hash of the memory")
|
||||
# The metadata value can be anything and not just string. Fix it
|
||||
metadata: Optional[Dict[str, Any]] = Field(
|
||||
None, description="Additional metadata for the text data"
|
||||
)
|
||||
score: Optional[float] = Field(
|
||||
None, description="The score associated with the text data"
|
||||
)
|
||||
created_at: Optional[str] = Field(
|
||||
None, description="The timestamp when the memory was created"
|
||||
)
|
||||
updated_at: Optional[str] = Field(
|
||||
None, description="The timestamp when the memory was updated"
|
||||
)
|
||||
metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata for the text data")
|
||||
score: Optional[float] = Field(None, description="The score associated with the text data")
|
||||
created_at: Optional[str] = Field(None, description="The timestamp when the memory was created")
|
||||
updated_at: Optional[str] = Field(None, description="The timestamp when the memory was updated")
|
||||
|
||||
|
||||
class MemoryConfig(BaseModel):
|
||||
@@ -60,7 +52,7 @@ class MemoryConfig(BaseModel):
|
||||
description="Custom prompt for the memory",
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
|
||||
class AzureConfig(BaseModel):
|
||||
"""
|
||||
@@ -73,7 +65,10 @@ class AzureConfig(BaseModel):
|
||||
api_version (str): The version of the Azure API being used.
|
||||
"""
|
||||
|
||||
api_key: str = Field(description="The API key used for authenticating with the Azure service.", default=None)
|
||||
azure_deployment : str = Field(description="The name of the Azure deployment.", default=None)
|
||||
azure_endpoint : str = Field(description="The endpoint URL for the Azure service.", default=None)
|
||||
api_version : str = Field(description="The version of the Azure API being used.", default=None)
|
||||
api_key: str = Field(
|
||||
description="The API key used for authenticating with the Azure service.",
|
||||
default=None,
|
||||
)
|
||||
azure_deployment: str = Field(description="The name of the Azure deployment.", default=None)
|
||||
azure_endpoint: str = Field(description="The endpoint URL for the Azure service.", default=None)
|
||||
api_version: str = Field(description="The version of the Azure API being used.", default=None)
|
||||
|
||||
@@ -60,6 +60,6 @@ class BaseEmbedderConfig(ABC):
|
||||
|
||||
# Huggingface specific
|
||||
self.model_kwargs = model_kwargs or {}
|
||||
|
||||
|
||||
# AzureOpenAI specific
|
||||
self.azure_kwargs = AzureConfig(**azure_kwargs) or {}
|
||||
|
||||
@@ -59,6 +59,7 @@ You should detect the language of the user input and record the facts in the sam
|
||||
If you do not find anything relevant facts, user memories, and preferences in the below conversation, you can return an empty list corresponding to the "facts" key.
|
||||
"""
|
||||
|
||||
|
||||
def get_update_memory_messages(retrieved_old_memory_dict, response_content):
|
||||
return f"""You are a smart memory manager which controls the memory of a system.
|
||||
You can perform four operations: (1) add into the memory, (2) update the memory, (3) delete from the memory, and (4) no change.
|
||||
|
||||
@@ -13,9 +13,7 @@ class ChromaDbConfig(BaseModel):
|
||||
Client: ClassVar[type] = Client
|
||||
|
||||
collection_name: str = Field("mem0", description="Default name for the collection")
|
||||
client: Optional[Client] = Field(
|
||||
None, description="Existing ChromaDB client instance"
|
||||
)
|
||||
client: Optional[Client] = Field(None, description="Existing ChromaDB client instance")
|
||||
path: Optional[str] = Field(None, description="Path to the database directory")
|
||||
host: Optional[str] = Field(None, description="Database connection remote host")
|
||||
port: Optional[int] = Field(None, description="Database connection remote port")
|
||||
|
||||
@@ -1,22 +1,24 @@
|
||||
from enum import Enum
|
||||
from typing import Dict, Any
|
||||
from pydantic import BaseModel, model_validator, Field
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class MetricType(str, Enum):
|
||||
"""
|
||||
Metric Constant for milvus/ zilliz server.
|
||||
"""
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.value)
|
||||
|
||||
|
||||
L2 = "L2"
|
||||
IP = "IP"
|
||||
COSINE = "COSINE"
|
||||
HAMMING = "HAMMING"
|
||||
JACCARD = "JACCARD"
|
||||
|
||||
|
||||
IP = "IP"
|
||||
COSINE = "COSINE"
|
||||
HAMMING = "HAMMING"
|
||||
JACCARD = "JACCARD"
|
||||
|
||||
|
||||
class MilvusDBConfig(BaseModel):
|
||||
url: str = Field("http://localhost:19530", description="Full URL for Milvus/Zilliz server")
|
||||
token: str = Field(None, description="Token for Zilliz server / local setup defaults to None.")
|
||||
@@ -38,4 +40,4 @@ class MilvusDBConfig(BaseModel):
|
||||
|
||||
model_config = {
|
||||
"arbitrary_types_allowed": True,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,12 +4,9 @@ from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class PGVectorConfig(BaseModel):
|
||||
|
||||
dbname: str = Field("postgres", description="Default name for the database")
|
||||
collection_name: str = Field("mem0", description="Default name for the collection")
|
||||
embedding_model_dims: Optional[int] = Field(
|
||||
1536, description="Dimensions of the embedding model"
|
||||
)
|
||||
embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
|
||||
user: Optional[str] = Field(None, description="Database user")
|
||||
password: Optional[str] = Field(None, description="Database password")
|
||||
host: Optional[str] = Field(None, description="Database host. Default is localhost")
|
||||
|
||||
@@ -9,17 +9,11 @@ class QdrantConfig(BaseModel):
|
||||
QdrantClient: ClassVar[type] = QdrantClient
|
||||
|
||||
collection_name: str = Field("mem0", description="Name of the collection")
|
||||
embedding_model_dims: Optional[int] = Field(
|
||||
1536, description="Dimensions of the embedding model"
|
||||
)
|
||||
client: Optional[QdrantClient] = Field(
|
||||
None, description="Existing Qdrant client instance"
|
||||
)
|
||||
embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
|
||||
client: Optional[QdrantClient] = Field(None, description="Existing Qdrant client instance")
|
||||
host: Optional[str] = Field(None, description="Host address for Qdrant server")
|
||||
port: Optional[int] = Field(None, description="Port for Qdrant server")
|
||||
path: Optional[str] = Field(
|
||||
"/tmp/qdrant", description="Path for local Qdrant database"
|
||||
)
|
||||
path: Optional[str] = Field("/tmp/qdrant", description="Path for local Qdrant database")
|
||||
url: Optional[str] = Field(None, description="Full URL for Qdrant server")
|
||||
api_key: Optional[str] = Field(None, description="API key for Qdrant server")
|
||||
on_disk: Optional[bool] = Field(False, description="Enables persistent storage")
|
||||
@@ -35,9 +29,7 @@ class QdrantConfig(BaseModel):
|
||||
values.get("api_key"),
|
||||
)
|
||||
if not path and not (host and port) and not (url and api_key):
|
||||
raise ValueError(
|
||||
"Either 'host' and 'port' or 'url' and 'api_key' or 'path' must be provided."
|
||||
)
|
||||
raise ValueError("Either 'host' and 'port' or 'url' and 'api_key' or 'path' must be provided.")
|
||||
return values
|
||||
|
||||
@model_validator(mode="before")
|
||||
|
||||
Reference in New Issue
Block a user