[Bug fix] Fix typos, static methods and other sanity improvements in the package (#1129)
This commit is contained in:
@@ -26,7 +26,7 @@ class ChunkerConfig(BaseConfig):
|
||||
if self.min_chunk_size >= self.chunk_size:
|
||||
raise ValueError(f"min_chunk_size {min_chunk_size} should be less than chunk_size {chunk_size}")
|
||||
if self.min_chunk_size < self.chunk_overlap:
|
||||
logging.warn(
|
||||
logging.warning(
|
||||
f"min_chunk_size {min_chunk_size} should be greater than chunk_overlap {chunk_overlap}, otherwise it is redundant." # noqa:E501
|
||||
)
|
||||
|
||||
@@ -35,7 +35,8 @@ class ChunkerConfig(BaseConfig):
|
||||
else:
|
||||
self.length_function = length_function if length_function else len
|
||||
|
||||
def load_func(self, dotpath: str):
|
||||
@staticmethod
|
||||
def load_func(dotpath: str):
|
||||
if "." not in dotpath:
|
||||
return getattr(builtins, dotpath)
|
||||
else:
|
||||
|
||||
@@ -10,12 +10,12 @@ class CacheSimilarityEvalConfig(BaseConfig):
|
||||
This is the evaluator to compare two embeddings according to their distance computed in embedding retrieval stage.
|
||||
In the retrieval stage, `search_result` is the distance used for approximate nearest neighbor search and has been
|
||||
put into `cache_dict`. `max_distance` is used to bound this distance to make it between [0-`max_distance`].
|
||||
`positive` is used to indicate this distance is directly proportional to the similarity of two entites.
|
||||
If `positive` is set `False`, `max_distance` will be used to substract this distance to get the final score.
|
||||
`positive` is used to indicate this distance is directly proportional to the similarity of two entities.
|
||||
If `positive` is set `False`, `max_distance` will be used to subtract this distance to get the final score.
|
||||
|
||||
:param max_distance: the bound of maximum distance.
|
||||
:type max_distance: float
|
||||
:param positive: if the larger distance indicates more similar of two entities, It is True. Otherwise it is False.
|
||||
:param positive: if the larger distance indicates more similar of two entities, It is True. Otherwise, it is False.
|
||||
:type positive: bool
|
||||
"""
|
||||
|
||||
@@ -29,6 +29,7 @@ class CacheSimilarityEvalConfig(BaseConfig):
|
||||
self.max_distance = max_distance
|
||||
self.positive = positive
|
||||
|
||||
@staticmethod
|
||||
def from_config(config: Optional[Dict[str, Any]]):
|
||||
if config is None:
|
||||
return CacheSimilarityEvalConfig()
|
||||
@@ -63,6 +64,7 @@ class CacheInitConfig(BaseConfig):
|
||||
self.similarity_threshold = similarity_threshold
|
||||
self.auto_flush = auto_flush
|
||||
|
||||
@staticmethod
|
||||
def from_config(config: Optional[Dict[str, Any]]):
|
||||
if config is None:
|
||||
return CacheInitConfig()
|
||||
@@ -83,6 +85,7 @@ class CacheConfig(BaseConfig):
|
||||
self.similarity_eval_config = similarity_eval_config
|
||||
self.init_config = init_config
|
||||
|
||||
@staticmethod
|
||||
def from_config(config: Optional[Dict[str, Any]]):
|
||||
if config is None:
|
||||
return CacheConfig()
|
||||
|
||||
@@ -155,24 +155,26 @@ class BaseLlmConfig(BaseConfig):
|
||||
self.stream = stream
|
||||
self.where = where
|
||||
|
||||
def validate_prompt(self, prompt: Template) -> bool:
|
||||
@staticmethod
|
||||
def validate_prompt(prompt: Template) -> Optional[re.Match[str]]:
|
||||
"""
|
||||
validate the prompt
|
||||
|
||||
:param prompt: the prompt to validate
|
||||
:type prompt: Template
|
||||
:return: a match object (truthy) if the prompt is valid, otherwise None (falsy)
|
||||
:rtype: bool
|
||||
:rtype: Optional[re.Match[str]]
|
||||
"""
|
||||
return re.search(query_re, prompt.template) and re.search(context_re, prompt.template)
|
||||
|
||||
def _validate_prompt_history(self, prompt: Template) -> bool:
|
||||
@staticmethod
|
||||
def _validate_prompt_history(prompt: Template) -> Optional[re.Match[str]]:
|
||||
"""
|
||||
validate the prompt with history
|
||||
|
||||
:param prompt: the prompt to validate
|
||||
:type prompt: Template
|
||||
:return: a match object (truthy) if the prompt is valid, otherwise None (falsy)
|
||||
:rtype: bool
|
||||
:rtype: Optional[re.Match[str]]
|
||||
"""
|
||||
return re.search(history_re, prompt.template)
|
||||
|
||||
@@ -7,8 +7,8 @@ from embedchain.helpers.json_serializable import register_deserializable
|
||||
@register_deserializable
|
||||
class QdrantDBConfig(BaseVectorDbConfig):
|
||||
"""
|
||||
Config to initialize an qdrant client.
|
||||
:param url. qdrant url or list of nodes url to be used for connection
|
||||
Config to initialize a qdrant client.
|
||||
:param url: qdrant url or list of node urls to be used for connection
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -26,7 +26,7 @@ class ZillizDBConfig(BaseVectorDbConfig):
|
||||
:param uri: Cluster endpoint obtained from the Zilliz Console, defaults to None
|
||||
:type uri: Optional[str], optional
|
||||
:param token: API Key, if a Serverless Cluster, username:password, if a Dedicated Cluster, defaults to None
|
||||
:type port: Optional[str], optional
|
||||
:type token: Optional[str], optional
|
||||
"""
|
||||
self.uri = uri or os.environ.get("ZILLIZ_CLOUD_URI")
|
||||
if not self.uri:
|
||||
|
||||
Reference in New Issue
Block a user