[Bug fix] Fix typos, static methods and other sanity improvements in the package (#1129)

This commit is contained in:
Sandra Serrano
2024-01-08 19:47:46 +01:00
committed by GitHub
parent 62c0c52e31
commit 2496ed133e
41 changed files with 133 additions and 103 deletions

View File

@@ -26,7 +26,7 @@ class ChunkerConfig(BaseConfig):
if self.min_chunk_size >= self.chunk_size:
raise ValueError(f"min_chunk_size {min_chunk_size} should be less than chunk_size {chunk_size}")
if self.min_chunk_size < self.chunk_overlap:
logging.warn(
logging.warning(
f"min_chunk_size {min_chunk_size} should be greater than chunk_overlap {chunk_overlap}, otherwise it is redundant." # noqa:E501
)
@@ -35,7 +35,8 @@ class ChunkerConfig(BaseConfig):
else:
self.length_function = length_function if length_function else len
def load_func(self, dotpath: str):
@staticmethod
def load_func(dotpath: str):
if "." not in dotpath:
return getattr(builtins, dotpath)
else:

View File

@@ -10,12 +10,12 @@ class CacheSimilarityEvalConfig(BaseConfig):
This is the evaluator to compare two embeddings according to their distance computed in embedding retrieval stage.
In the retrieval stage, `search_result` is the distance used for approximate nearest neighbor search and have been
put into `cache_dict`. `max_distance` is used to bound this distance to make it between [0-`max_distance`].
`positive` is used to indicate this distance is directly proportional to the similarity of two entites.
If `positive` is set `False`, `max_distance` will be used to substract this distance to get the final score.
`positive` is used to indicate this distance is directly proportional to the similarity of two entities.
If `positive` is set `False`, `max_distance` will be used to subtract this distance to get the final score.
:param max_distance: the bound of maximum distance.
:type max_distance: float
:param positive: if the larger distance indicates more similar of two entities, It is True. Otherwise it is False.
:param positive: if the larger distance indicates more similar of two entities, It is True. Otherwise, it is False.
:type positive: bool
"""
@@ -29,6 +29,7 @@ class CacheSimilarityEvalConfig(BaseConfig):
self.max_distance = max_distance
self.positive = positive
@staticmethod
def from_config(config: Optional[Dict[str, Any]]):
if config is None:
return CacheSimilarityEvalConfig()
@@ -63,6 +64,7 @@ class CacheInitConfig(BaseConfig):
self.similarity_threshold = similarity_threshold
self.auto_flush = auto_flush
@staticmethod
def from_config(config: Optional[Dict[str, Any]]):
if config is None:
return CacheInitConfig()
@@ -83,6 +85,7 @@ class CacheConfig(BaseConfig):
self.similarity_eval_config = similarity_eval_config
self.init_config = init_config
@staticmethod
def from_config(config: Optional[Dict[str, Any]]):
if config is None:
return CacheConfig()

View File

@@ -155,24 +155,26 @@ class BaseLlmConfig(BaseConfig):
self.stream = stream
self.where = where
def validate_prompt(self, prompt: Template) -> bool:
@staticmethod
def validate_prompt(prompt: Template) -> Optional[re.Match[str]]:
"""
validate the prompt
:param prompt: the prompt to validate
:type prompt: Template
:return: valid (true) or invalid (false)
:rtype: bool
:rtype: Optional[re.Match[str]]
"""
return re.search(query_re, prompt.template) and re.search(context_re, prompt.template)
def _validate_prompt_history(self, prompt: Template) -> bool:
@staticmethod
def _validate_prompt_history(prompt: Template) -> Optional[re.Match[str]]:
"""
validate the prompt with history
:param prompt: the prompt to validate
:type prompt: Template
:return: valid (true) or invalid (false)
:rtype: bool
:rtype: Optional[re.Match[str]]
"""
return re.search(history_re, prompt.template)

View File

@@ -7,8 +7,8 @@ from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class QdrantDBConfig(BaseVectorDbConfig):
"""
Config to initialize an qdrant client.
:param url. qdrant url or list of nodes url to be used for connection
Config to initialize a qdrant client.
:param: url. qdrant url or list of nodes url to be used for connection
"""
def __init__(

View File

@@ -26,7 +26,7 @@ class ZillizDBConfig(BaseVectorDbConfig):
:param uri: Cluster endpoint obtained from the Zilliz Console, defaults to None
:type uri: Optional[str], optional
:param token: API Key, if a Serverless Cluster, username:password, if a Dedicated Cluster, defaults to None
:type port: Optional[str], optional
:type token: Optional[str], optional
"""
self.uri = uri or os.environ.get("ZILLIZ_CLOUD_URI")
if not self.uri: