[Bug fix] Fix typos, static methods and other sanity improvements in the package (#1129)

This commit is contained in:
Sandra Serrano
2024-01-08 19:47:46 +01:00
committed by GitHub
parent 62c0c52e31
commit 2496ed133e
41 changed files with 133 additions and 103 deletions

View File

@@ -9,9 +9,14 @@ from typing import Any, Dict, Optional
import requests
import yaml
-from embedchain.cache import (Config, ExactMatchEvaluation,
-                              SearchDistanceEvaluation, cache,
-                              gptcache_data_manager, gptcache_pre_function)
+from embedchain.cache import (
+    Config,
+    ExactMatchEvaluation,
+    SearchDistanceEvaluation,
+    cache,
+    gptcache_data_manager,
+    gptcache_pre_function,
+)
from embedchain.client import Client
from embedchain.config import AppConfig, CacheConfig, ChunkerConfig
from embedchain.constants import SQLITE_PATH
@@ -27,7 +32,7 @@ from embedchain.utils.misc import validate_config
from embedchain.vectordb.base import BaseVectorDB
from embedchain.vectordb.chroma import ChromaDB
-# Setup the user directory if doesn't exist already
+# Set up the user directory if it doesn't exist already
Client.setup_dir()

View File

@@ -17,7 +17,7 @@ class BaseChunker(JSONSerializable):
""" """
Loads data and chunks it. Loads data and chunks it.
:param loader: The loader which's `load_data` method is used to create :param loader: The loader whose `load_data` method is used to create
the raw data. the raw data.
:param src: The data to be handled by the loader. Can be a URL for :param src: The data to be handled by the loader. Can be a URL for
remote sources or local content for local loaders. remote sources or local content for local loaders.
@@ -25,7 +25,7 @@ class BaseChunker(JSONSerializable):
""" """
documents = [] documents = []
chunk_ids = [] chunk_ids = []
idMap = {} id_map = {}
min_chunk_size = config.min_chunk_size if config is not None else 1 min_chunk_size = config.min_chunk_size if config is not None else 1
logging.info(f"[INFO] Skipping chunks smaller than {min_chunk_size} characters") logging.info(f"[INFO] Skipping chunks smaller than {min_chunk_size} characters")
data_result = loader.load_data(src) data_result = loader.load_data(src)
@@ -49,8 +49,8 @@ class BaseChunker(JSONSerializable):
            for chunk in chunks:
                chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest()
                chunk_id = f"{app_id}--{chunk_id}" if app_id is not None else chunk_id
-                if idMap.get(chunk_id) is None and len(chunk) >= min_chunk_size:
-                    idMap[chunk_id] = True
+                if id_map.get(chunk_id) is None and len(chunk) >= min_chunk_size:
+                    id_map[chunk_id] = True
                    chunk_ids.append(chunk_id)
                    documents.append(chunk)
                    metadatas.append(meta_data)
@@ -77,5 +77,6 @@ class BaseChunker(JSONSerializable):
    # TODO: This should be done during initialization. This means it has to be done in the child classes.
-    def get_word_count(self, documents):
+    @staticmethod
+    def get_word_count(documents) -> int:
        return sum([len(document.split(" ")) for document in documents])
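The idMap → id_map rename above sits inside the chunker's dedup loop: each chunk gets a deterministic id hashed from its text plus its source, and a chunk is kept only the first time that id is seen. A minimal standalone sketch of that pattern (function name and signature are illustrative, not the package's exact API):

import hashlib
from typing import Optional

def dedup_chunks(chunks: list, url: str, app_id: Optional[str] = None, min_chunk_size: int = 1):
    # Deterministic ids: identical chunk text from the same source hashes to the same id.
    id_map = {}
    documents, chunk_ids = [], []
    for chunk in chunks:
        chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest()
        if app_id is not None:
            chunk_id = f"{app_id}--{chunk_id}"
        # Keep a chunk only on its first occurrence and only if it is long enough.
        if id_map.get(chunk_id) is None and len(chunk) >= min_chunk_size:
            id_map[chunk_id] = True
            chunk_ids.append(chunk_id)
            documents.append(chunk)
    return documents, chunk_ids

# The repeated "hello world" collapses to a single chunk.
print(dedup_chunks(["hello world", "hello world", "hi"], url="https://example.com"))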

View File

@@ -31,7 +31,7 @@ class Client:
        )
    @classmethod
-    def setup_dir(self):
+    def setup_dir(cls):
        """
        Loads the user id from the config file if it exists, otherwise generates a new
        one and saves it to the config file.
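The self → cls change is the conventional fix for a @classmethod declared with an instance-style first parameter: the decorator passes the class itself, not an instance. A small illustrative sketch (class body and path are made up, not embedchain's actual implementation):

import os

class Client:
    CONFIG_DIR = os.path.expanduser("~/.example_app")  # illustrative path

    @classmethod
    def setup_dir(cls):
        # A classmethod receives the class, so shared setup can read class-level
        # attributes and be called without constructing an instance.
        os.makedirs(cls.CONFIG_DIR, exist_ok=True)

Client.setup_dir()  # callable directly on the class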

View File

@@ -26,7 +26,7 @@ class ChunkerConfig(BaseConfig):
        if self.min_chunk_size >= self.chunk_size:
            raise ValueError(f"min_chunk_size {min_chunk_size} should be less than chunk_size {chunk_size}")
        if self.min_chunk_size < self.chunk_overlap:
-            logging.warn(
+            logging.warning(
                f"min_chunk_size {min_chunk_size} should be greater than chunk_overlap {chunk_overlap}, otherwise it is redundant."  # noqa:E501
            )
@@ -35,7 +35,8 @@ class ChunkerConfig(BaseConfig):
        else:
            self.length_function = length_function if length_function else len
-    def load_func(self, dotpath: str):
+    @staticmethod
+    def load_func(dotpath: str):
        if "." not in dotpath:
            return getattr(builtins, dotpath)
        else:
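load_func resolves a dotted path string such as "len" or "mymodule.my_len" to a callable; the hunk cuts off before the dotted branch, but the general pattern is roughly the following sketch (not the exact embedchain body):

import builtins
from importlib import import_module

def load_func(dotpath: str):
    # Bare names ("len") are looked up on builtins; dotted names are imported.
    if "." not in dotpath:
        return getattr(builtins, dotpath)
    module_path, attr = dotpath.rsplit(".", 1)
    return getattr(import_module(module_path), attr)

print(load_func("len")("abc"))                        # 3
print(load_func("os.path.basename")("/tmp/x.txt"))    # x.txt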

View File

@@ -10,12 +10,12 @@ class CacheSimilarityEvalConfig(BaseConfig):
    This is the evaluator to compare two embeddings according to their distance computed in embedding retrieval stage.
    In the retrieval stage, `search_result` is the distance used for approximate nearest neighbor search and have been
    put into `cache_dict`. `max_distance` is used to bound this distance to make it between [0-`max_distance`].
-    `positive` is used to indicate this distance is directly proportional to the similarity of two entites.
-    If `positive` is set `False`, `max_distance` will be used to substract this distance to get the final score.
+    `positive` is used to indicate this distance is directly proportional to the similarity of two entities.
+    If `positive` is set `False`, `max_distance` will be used to subtract this distance to get the final score.
    :param max_distance: the bound of maximum distance.
    :type max_distance: float
-    :param positive: if the larger distance indicates more similar of two entities, It is True. Otherwise it is False.
+    :param positive: if the larger distance indicates more similar of two entities, It is True. Otherwise, it is False.
    :type positive: bool
    """
@@ -29,6 +29,7 @@ class CacheSimilarityEvalConfig(BaseConfig):
        self.max_distance = max_distance
        self.positive = positive
+    @staticmethod
    def from_config(config: Optional[Dict[str, Any]]):
        if config is None:
            return CacheSimilarityEvalConfig()
@@ -63,6 +64,7 @@ class CacheInitConfig(BaseConfig):
        self.similarity_threshold = similarity_threshold
        self.auto_flush = auto_flush
+    @staticmethod
    def from_config(config: Optional[Dict[str, Any]]):
        if config is None:
            return CacheInitConfig()
@@ -83,6 +85,7 @@ class CacheConfig(BaseConfig):
        self.similarity_eval_config = similarity_eval_config
        self.init_config = init_config
+    @staticmethod
    def from_config(config: Optional[Dict[str, Any]]):
        if config is None:
            return CacheConfig()

View File

@@ -155,24 +155,26 @@ class BaseLlmConfig(BaseConfig):
        self.stream = stream
        self.where = where
-    def validate_prompt(self, prompt: Template) -> bool:
+    @staticmethod
+    def validate_prompt(prompt: Template) -> Optional[re.Match[str]]:
        """
        validate the prompt
        :param prompt: the prompt to validate
        :type prompt: Template
        :return: valid (true) or invalid (false)
-        :rtype: bool
+        :rtype: Optional[re.Match[str]]
        """
        return re.search(query_re, prompt.template) and re.search(context_re, prompt.template)
-    def _validate_prompt_history(self, prompt: Template) -> bool:
+    @staticmethod
+    def _validate_prompt_history(prompt: Template) -> Optional[re.Match[str]]:
        """
        validate the prompt with history
        :param prompt: the prompt to validate
        :type prompt: Template
        :return: valid (true) or invalid (false)
-        :rtype: bool
+        :rtype: Optional[re.Match[str]]
        """
        return re.search(history_re, prompt.template)
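The new return annotation reflects that re.search yields a match object or None rather than a bool; callers only rely on its truthiness. A tiny self-contained illustration (the $context/$query placeholders here are stand-ins for the ones the config actually checks):

import re
from string import Template

context_re = re.compile(r"\$context")
query_re = re.compile(r"\$query")

def validate_prompt(prompt: Template):
    # Each search returns a re.Match on success or None on failure, so the
    # combined expression is truthy only when both placeholders are present.
    return re.search(query_re, prompt.template) and re.search(context_re, prompt.template)

print(bool(validate_prompt(Template("Use $context to answer $query"))))  # True
print(bool(validate_prompt(Template("Answer $query"))))                  # False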

View File

@@ -7,8 +7,8 @@ from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class QdrantDBConfig(BaseVectorDbConfig):
    """
-    Config to initialize an qdrant client.
-    :param url. qdrant url or list of nodes url to be used for connection
+    Config to initialize a qdrant client.
+    :param: url. qdrant url or list of nodes url to be used for connection
    """
    def __init__(

View File

@@ -26,7 +26,7 @@ class ZillizDBConfig(BaseVectorDbConfig):
        :param uri: Cluster endpoint obtained from the Zilliz Console, defaults to None
        :type uri: Optional[str], optional
        :param token: API Key, if a Serverless Cluster, username:password, if a Dedicated Cluster, defaults to None
-        :type port: Optional[str], optional
+        :type token: Optional[str], optional
        """
        self.uri = uri or os.environ.get("ZILLIZ_CLOUD_URI")
        if not self.uri:

View File

@@ -34,7 +34,8 @@ class DataFormatter(JSONSerializable):
        self.loader = self._get_loader(data_type=data_type, config=config.loader, loader=loader)
        self.chunker = self._get_chunker(data_type=data_type, config=config.chunker, chunker=chunker)
-    def _lazy_load(self, module_path: str):
+    @staticmethod
+    def _lazy_load(module_path: str):
        module_path, class_name = module_path.rsplit(".", 1)
        module = import_module(module_path)
        return getattr(module, class_name)

View File

@@ -7,9 +7,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
from dotenv import load_dotenv
from langchain.docstore.document import Document
-from embedchain.cache import (adapt, get_gptcache_session,
-                              gptcache_data_convert,
-                              gptcache_update_cache_callback)
+from embedchain.cache import adapt, get_gptcache_session, gptcache_data_convert, gptcache_update_cache_callback
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
from embedchain.config.base_app_config import BaseAppConfig
@@ -19,8 +17,7 @@ from embedchain.embedder.base import BaseEmbedder
from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.llm.base import BaseLlm
from embedchain.loaders.base_loader import BaseLoader
-from embedchain.models.data_type import (DataType, DirectDataType,
-                                         IndirectDataType, SpecialDataType)
+from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
from embedchain.telemetry.posthog import AnonymousTelemetry
from embedchain.utils.misc import detect_datatype, is_valid_json_string
from embedchain.vectordb.base import BaseVectorDB
@@ -84,7 +81,7 @@ class EmbedChain(JSONSerializable):
        # Attributes that aren't subclass related.
        self.user_asks = []
-        self.chunker: ChunkerConfig = None
+        self.chunker: Optional[ChunkerConfig] = None
        # Send anonymous telemetry
        self._telemetry_props = {"class": self.__class__.__name__}
        self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
@@ -290,7 +287,7 @@ class EmbedChain(JSONSerializable):
            # Or it's different, then it will be added as a new text.
            return None
        elif chunker.data_type.value in [item.value for item in IndirectDataType]:
-            # These types have a indirect source reference
+            # These types have an indirect source reference
            # As long as the reference is the same, they can be updated.
            where = {"url": src}
            if chunker.data_type == DataType.JSON and is_valid_json_string(src):
@@ -442,10 +439,11 @@ class EmbedChain(JSONSerializable):
        )
        count_new_chunks = self.db.count() - chunks_before_addition
-        print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
+        print(f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}")
        return list(documents), metadatas, ids, count_new_chunks
-    def _format_result(self, results):
+    @staticmethod
+    def _format_result(results):
        return [
            (Document(page_content=result[0], metadata=result[1] or {}), result[2])
            for result in zip(

View File

@@ -15,8 +15,8 @@ class EmbeddingFunc(EmbeddingFunction):
    def __init__(self, embedding_fn: Callable[[list[str]], list[str]]):
        self.embedding_fn = embedding_fn
-    def __call__(self, input: Embeddable) -> Embeddings:
-        return self.embedding_fn(input)
+    def __call__(self, input_: Embeddable) -> Embeddings:
+        return self.embedding_fn(input_)
class BaseEmbedder:
@@ -29,7 +29,7 @@ class BaseEmbedder:
    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        """
-        Intialize the embedder class.
+        Initialize the embedder class.
        :param config: embedder configuration option class, defaults to None
        :type config: Optional[BaseEmbedderConfig], optional

View File

@@ -13,11 +13,11 @@ class GoogleAIEmbeddingFunction(EmbeddingFunction):
        super().__init__()
        self.config = config or GoogleAIEmbedderConfig()
-    def __call__(self, input: str) -> Embeddings:
+    def __call__(self, input_: str) -> Embeddings:
        model = self.config.model
        title = self.config.title
        task_type = self.config.task_type
-        embeddings = genai.embed_content(model=model, content=input, task_type=task_type, title=title)
+        embeddings = genai.embed_content(model=model, content=input_, task_type=task_type, title=title)
        return embeddings["embedding"]

View File

@@ -42,7 +42,7 @@ class JSONSerializable:
    A class to represent a JSON serializable object.
    This class provides methods to serialize and deserialize objects,
-    as well as save serialized objects to a file and load them back.
+    as well as to save serialized objects to a file and load them back.
    """
    _deserializable_classes = set()  # Contains classes that are whitelisted for deserialization.

View File

@@ -4,9 +4,7 @@ from typing import Any, Dict, Generator, List, Optional
from langchain.schema import BaseMessage as LCBaseMessage
from embedchain.config import BaseLlmConfig
-from embedchain.config.llm.base import (DEFAULT_PROMPT,
-                                        DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
-                                        DOCS_SITE_PROMPT_TEMPLATE)
+from embedchain.config.llm.base import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE, DOCS_SITE_PROMPT_TEMPLATE
from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.memory.base import ChatHistory
from embedchain.memory.message import ChatMessage
@@ -76,7 +74,7 @@ class BaseLlm(JSONSerializable):
        :return: The prompt
        :rtype: str
        """
-        context_string = (" | ").join(contexts)
+        context_string = " | ".join(contexts)
        web_search_result = kwargs.get("web_search_result", "")
        if web_search_result:
            context_string = self._append_search_and_context(context_string, web_search_result)
@@ -110,7 +108,8 @@ class BaseLlm(JSONSerializable):
            prompt = self.config.prompt.substitute(context=context_string, query=input_query)
        return prompt
-    def _append_search_and_context(self, context: str, web_search_result: str) -> str:
+    @staticmethod
+    def _append_search_and_context(context: str, web_search_result: str) -> str:
        """Append web search context to existing context
        :param context: Existing context
@@ -134,7 +133,8 @@ class BaseLlm(JSONSerializable):
""" """
return self.get_llm_model_answer(prompt) return self.get_llm_model_answer(prompt)
def access_search_and_get_results(self, input_query: str): @staticmethod
def access_search_and_get_results(input_query: str):
""" """
Search the internet for additional context Search the internet for additional context
@@ -153,7 +153,8 @@ class BaseLlm(JSONSerializable):
logging.info(f"Access search to get answers for {input_query}") logging.info(f"Access search to get answers for {input_query}")
return search.run(input_query) return search.run(input_query)
def _stream_response(self, answer: Any) -> Generator[Any, Any, None]: @staticmethod
def _stream_response(answer: Any) -> Generator[Any, Any, None]:
"""Generator to be used as streaming response """Generator to be used as streaming response
:param answer: Answer chunk from llm :param answer: Answer chunk from llm

View File

@@ -44,7 +44,7 @@ class GoogleLlm(BaseLlm):
"temperature": self.config.temperature or 0.5, "temperature": self.config.temperature or 0.5,
} }
if self.config.top_p >= 0.0 and self.config.top_p <= 1.0: if 0.0 <= self.config.top_p <= 1.0:
generation_config_params["top_p"] = self.config.top_p generation_config_params["top_p"] = self.config.top_p
else: else:
raise ValueError("`top_p` must be > 0.0 and < 1.0") raise ValueError("`top_p` must be > 0.0 and < 1.0")

View File

@@ -48,7 +48,7 @@ class HuggingFaceLlm(BaseLlm):
"max_new_tokens": config.max_tokens, "max_new_tokens": config.max_tokens,
} }
if config.top_p > 0.0 and config.top_p < 1.0: if 0.0 < config.top_p < 1.0:
model_kwargs["top_p"] = config.top_p model_kwargs["top_p"] = config.top_p
else: else:
raise ValueError("`top_p` must be > 0.0 and < 1.0") raise ValueError("`top_p` must be > 0.0 and < 1.0")

View File

@@ -20,7 +20,8 @@ class OllamaLlm(BaseLlm):
    def get_llm_model_answer(self, prompt):
        return self._get_answer(prompt=prompt, config=self.config)
-    def _get_answer(self, prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
+    @staticmethod
+    def _get_answer(prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
        callback_manager = [StreamingStdOutCallbackHandler()] if config.stream else [StdOutCallbackHandler()]
        llm = Ollama(

View File

@@ -5,7 +5,7 @@ class BaseLoader(JSONSerializable):
    def __init__(self):
        pass
-    def load_data():
+    def load_data(self, url):
        """
        Implemented by child classes
        """

View File

@@ -32,7 +32,7 @@ class DirectoryLoader(BaseLoader):
        doc_id = hashlib.sha256((str(data_list) + str(directory_path)).encode()).hexdigest()
        for error in self.errors:
-            logging.warn(error)
+            logging.warning(error)
        return {"doc_id": doc_id, "data": data_list}

View File

@@ -49,7 +49,8 @@ class DocsSiteLoader(BaseLoader):
        urls = [link for link in self.visited_links if urlparse(link).netloc == urlparse(url).netloc]
        return urls
-    def _load_data_from_url(self, url):
+    @staticmethod
+    def _load_data_from_url(url: str) -> list:
        response = requests.get(url)
        if response.status_code != 200:
            logging.info(f"Failed to fetch the website: {response.status_code}")

View File

@@ -18,7 +18,7 @@ VALID_SEARCH_TYPES = set(["code", "repo", "pr", "issue", "discussion"])
class GithubLoader(BaseLoader):
-    """Load data from github search query."""
+    """Load data from GitHub search query."""
    def __init__(self, config: Optional[Dict[str, Any]] = None):
        super().__init__()
@@ -48,7 +48,7 @@ class GithubLoader(BaseLoader):
        self.client = None
    def _github_search_code(self, query: str):
-        """Search github code."""
+        """Search GitHub code."""
        data = []
        results = self.client.search_code(query)
        for result in tqdm(results, total=results.totalCount, desc="Loading code files from github"):
@@ -66,7 +66,8 @@ class GithubLoader(BaseLoader):
        )
        return data
-    def _get_github_repo_data(self, repo_url: str):
+    @staticmethod
+    def _get_github_repo_data(repo_url: str):
        local_hash = hashlib.sha256(repo_url.encode()).hexdigest()
        local_path = f"/tmp/{local_hash}"
        data = []
@@ -121,14 +122,14 @@ class GithubLoader(BaseLoader):
        return data
-    def _github_search_repo(self, query: str):
-        """Search github repo."""
+    def _github_search_repo(self, query: str) -> list[dict]:
+        """Search GitHub repo."""
        data = []
        logging.info(f"Searching github repos with query: {query}")
        results = self.client.search_repositories(query)
        # Add repo urls and descriptions
        urls = list(map(lambda x: x.html_url, results))
-        discriptions = list(map(lambda x: x.description, results))
+        descriptions = list(map(lambda x: x.description, results))
        data.append(
            {
                "content": clean_string(desc),
@@ -136,7 +137,7 @@ class GithubLoader(BaseLoader):
"url": url, "url": url,
}, },
} }
for url, desc in zip(urls, discriptions) for url, desc in zip(urls, descriptions)
) )
# Add repo contents # Add repo contents
@@ -146,8 +147,8 @@ class GithubLoader(BaseLoader):
data = self._get_github_repo_data(clone_url) data = self._get_github_repo_data(clone_url)
return data return data
def _github_search_issues_and_pr(self, query: str, type: str): def _github_search_issues_and_pr(self, query: str, type: str) -> list[dict]:
"""Search github issues and PRs.""" """Search GitHub issues and PRs."""
data = [] data = []
query = f"{query} is:{type}" query = f"{query} is:{type}"
@@ -161,7 +162,7 @@ class GithubLoader(BaseLoader):
            title = result.title
            body = result.body
            if not body:
-                logging.warn(f"Skipping issue because empty content for: {url}")
+                logging.warning(f"Skipping issue because empty content for: {url}")
                continue
            labels = " ".join([label.name for label in result.labels])
            issue_comments = result.get_comments()
@@ -186,7 +187,7 @@ class GithubLoader(BaseLoader):
    # need to test more for discussion
    def _github_search_discussions(self, query: str):
-        """Search github discussions."""
+        """Search GitHub discussions."""
        data = []
        query = f"{query} is:discussion"
@@ -202,7 +203,7 @@ class GithubLoader(BaseLoader):
            title = discussion.title
            body = discussion.body
            if not body:
-                logging.warn(f"Skipping discussion because empty content for: {url}")
+                logging.warning(f"Skipping discussion because empty content for: {url}")
                continue
            comments = []
            comments_created_at = []
@@ -233,11 +234,14 @@ class GithubLoader(BaseLoader):
            data = self._github_search_issues_and_pr(query, search_type)
        elif search_type == "discussion":
            raise ValueError("GithubLoader does not support searching discussions yet.")
+        else:
+            raise NotImplementedError(f"{search_type} not supported")
        return data
-    def _get_valid_github_query(self, query: str):
-        """Check if query is valid and return search types and valid github query."""
+    @staticmethod
+    def _get_valid_github_query(query: str):
+        """Check if query is valid and return search types and valid GitHub query."""
        query_terms = shlex.split(query)
        # query must provide repo to load data from
        if len(query_terms) < 1 or "repo:" not in query:
@@ -273,7 +277,7 @@ class GithubLoader(BaseLoader):
        return types, query
    def load_data(self, search_query: str, max_results: int = 1000):
-        """Load data from github search query."""
+        """Load data from GitHub search query."""
        if not self.client:
            raise ValueError(
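For orientation, _get_valid_github_query parses a search string such as "repo:owner/name type:code,issue some keywords": the hunk shows it requires a repo: term, and the file's VALID_SEARCH_TYPES constant limits the accepted search types. A simplified sketch of that shape (not the loader's exact rules):

import shlex

VALID_SEARCH_TYPES = {"code", "repo", "pr", "issue", "discussion"}

def get_valid_github_query(query: str):
    query_terms = shlex.split(query)
    # A query must at least pin down the repository to load from.
    if len(query_terms) < 1 or "repo:" not in query:
        raise ValueError("Query must contain a `repo:` term, e.g. 'repo:owner/name type:code'")
    types, keywords = set(), []
    for term in query_terms:
        if term.startswith("type:"):
            requested = set(term[len("type:"):].split(","))
            unknown = requested - VALID_SEARCH_TYPES
            if unknown:
                raise ValueError(f"Unsupported search types: {unknown}")
            types |= requested
        else:
            keywords.append(term)
    return types or {"code"}, " ".join(keywords)

print(get_valid_github_query("repo:embedchain/embedchain type:code,issue chunker"))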

View File

@@ -20,7 +20,8 @@ class ImageLoader(BaseLoader):
        self.api_key = api_key or os.environ["OPENAI_API_KEY"]
        self.client = OpenAI(api_key=self.api_key)
-    def _encode_image(self, image_path: str):
+    @staticmethod
+    def _encode_image(image_path: str):
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")

View File

@@ -15,7 +15,8 @@ class JSONReader:
"""Initialize the JSONReader.""" """Initialize the JSONReader."""
pass pass
def load_data(self, json_data: Union[Dict, str]) -> List[str]: @staticmethod
def load_data(json_data: Union[Dict, str]) -> List[str]:
"""Load data from a JSON structure. """Load data from a JSON structure.
Args: Args:

View File

@@ -39,7 +39,8 @@ class MySQLLoader(BaseLoader):
                Refer `https://docs.embedchain.ai/data-sources/mysql`.",
            )
-    def _check_query(self, query):
+    @staticmethod
+    def _check_query(query):
        if not isinstance(query, str):
            raise ValueError(
                f"Invalid mysql query: {query}",

View File

@@ -24,7 +24,6 @@ class PostgresLoader(BaseLoader):
                Run `pip install --upgrade 'embedchain[postgres]'`"
            ) from e
-        config_info = ""
        if "url" in config:
            config_info = config.get("url")
        else:
@@ -37,7 +36,8 @@ class PostgresLoader(BaseLoader):
        self.connection = psycopg.connect(conninfo=config_info)
        self.cursor = self.connection.cursor()
-    def _check_query(self, query):
+    @staticmethod
+    def _check_query(query):
        if not isinstance(query, str):
            raise ValueError(
                f"Invalid postgres query: {query}. Provide the valid source to add from postgres, make sure you are following `https://docs.embedchain.ai/data-sources/postgres`",  # noqa:E501

View File

@@ -56,7 +56,8 @@ class SlackLoader(BaseLoader):
        )
        logging.info("Slack Loader setup successful!")
-    def _check_query(self, query):
+    @staticmethod
+    def _check_query(query):
        if not isinstance(query, str):
            raise ValueError(
                f"Invalid query passed to Slack loader, found: {query}. Check `https://docs.embedchain.ai/data-sources/slack` to learn more."  # noqa:E501

View File

@@ -8,7 +8,7 @@ from embedchain.utils.misc import clean_string
@register_deserializable
class UnstructuredLoader(BaseLoader):
    def load_data(self, url):
-        """Load data from a Unstructured file."""
+        """Load data from an Unstructured file."""
        try:
            from langchain.document_loaders import UnstructuredFileLoader
        except ImportError:

View File

@@ -21,7 +21,7 @@ class WebPageLoader(BaseLoader):
    _session = requests.Session()
    def load_data(self, url):
-        """Load data from a web page using a shared requests session."""
+        """Load data from a web page using a shared requests' session."""
        response = self._session.get(url, timeout=30)
        response.raise_for_status()
        data = response.content
@@ -40,7 +40,8 @@ class WebPageLoader(BaseLoader):
            ],
        }
-    def _get_clean_content(self, html, url) -> str:
+    @staticmethod
+    def _get_clean_content(html, url) -> str:
        soup = BeautifulSoup(html, "html.parser")
        original_size = len(str(soup.get_text()))
@@ -60,8 +61,8 @@ class WebPageLoader(BaseLoader):
            tag.decompose()
        ids_to_exclude = ["sidebar", "main-navigation", "menu-main-menu"]
-        for id in ids_to_exclude:
-            tags = soup.find_all(id=id)
+        for id_ in ids_to_exclude:
+            tags = soup.find_all(id=id_)
            for tag in tags:
                tag.decompose()
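The renamed loop variable avoids shadowing the builtin id. For context, the surrounding cleanup pattern (strip boilerplate tags and known navigation ids, then extract text) looks roughly like this standalone sketch (the tag and id lists are illustrative):

from bs4 import BeautifulSoup

def get_clean_text(html: str) -> str:
    soup = BeautifulSoup(html, "html.parser")
    # Remove whole elements that rarely carry page content.
    for tag in soup(["nav", "aside", "form", "header", "footer", "script", "style", "noscript"]):
        tag.decompose()
    # Remove elements whose id marks them as navigation or boilerplate.
    ids_to_exclude = ["sidebar", "main-navigation", "menu-main-menu"]
    for id_ in ids_to_exclude:
        for tag in soup.find_all(id=id_):
            tag.decompose()
    return " ".join(soup.get_text().split())

print(get_clean_text("<html><nav>menu</nav><p>Actual article text.</p></html>"))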

View File

@@ -113,10 +113,12 @@ class ChatHistory:
        count = self.cursor.fetchone()[0]
        return count
-    def _serialize_json(self, metadata: Dict[str, Any]):
+    @staticmethod
+    def _serialize_json(metadata: Dict[str, Any]):
        return json.dumps(metadata)
-    def _deserialize_json(self, metadata: str):
+    @staticmethod
+    def _deserialize_json(metadata: str):
        return json.loads(metadata)
    def close_connection(self):

View File

@@ -54,7 +54,7 @@ class ChatMessage(JSONSerializable):
        if self.human_message:
            logging.info(
                "Human message already exists in the chat message,\
-                overwritting it with new message."
+                overwriting it with new message."
            )
        self.human_message = BaseMessage(content=message, created_by="human", metadata=metadata)
@@ -63,7 +63,7 @@ class ChatMessage(JSONSerializable):
        if self.ai_message:
            logging.info(
                "AI message already exists in the chat message,\
-                overwritting it with new message."
+                overwriting it with new message."
            )
        self.ai_message = BaseMessage(content=message, created_by="ai", metadata=metadata)

View File

@@ -7,7 +7,7 @@ def merge_metadata_dict(left: Optional[Dict[str, Any]], right: Optional[Dict[str
    Args:
        left (Dict[str, Any]): metadata of human message
-        right (Dict[str, Any]): metadata of ai message
+        right (Dict[str, Any]): metadata of AI message
    Returns:
        Dict[str, Any]: combined metadata dict with dedup

View File

@@ -19,7 +19,7 @@ from embedchain.utils.misc import detect_datatype
logging.basicConfig(level=logging.WARN)
-# Setup the user directory if doesn't exist already
+# Set up the user directory if it doesn't exist already
Client.setup_dir()
@@ -130,12 +130,14 @@ class OpenAIAssistant:
        messages = self._client.beta.threads.messages.list(thread_id=self.thread_id, order="desc")
        return list(messages)
-    def _format_message(self, thread_message):
+    @staticmethod
+    def _format_message(thread_message):
        thread_message = cast(ThreadMessage, thread_message)
        content = [c.text.value for c in thread_message.content if isinstance(c, MessageContentText)]
        return " ".join(content)
-    def _save_temp_data(self, data, source):
+    @staticmethod
+    def _save_temp_data(data, source):
        special_chars_pattern = r'[\\/:*?"<>|&=% ]+'
        sanitized_source = re.sub(special_chars_pattern, "_", source)[:256]
        temp_dir = tempfile.mkdtemp()

View File

@@ -38,7 +38,8 @@ class AnonymousTelemetry:
        posthog_logger = logging.getLogger("posthog")
        posthog_logger.disabled = True
-    def _get_user_id(self):
+    @staticmethod
+    def _get_user_id():
        os.makedirs(CONFIG_DIR, exist_ok=True)
        if os.path.exists(CONFIG_FILE):
            with open(CONFIG_FILE, "r") as f:
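_get_user_id follows the common pattern of persisting a random anonymous id in a small config file on first run so later runs report the same id; a minimal sketch under that assumption (paths and the key name are illustrative, not necessarily embedchain's):

import json
import os
import uuid

CONFIG_DIR = os.path.expanduser("~/.example_app")  # illustrative location
CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")

def get_user_id() -> str:
    os.makedirs(CONFIG_DIR, exist_ok=True)
    if os.path.exists(CONFIG_FILE):
        with open(CONFIG_FILE, "r") as f:
            data = json.load(f)
        if "user_id" in data:
            return data["user_id"]
    # First run: mint an id and persist it for subsequent runs.
    user_id = str(uuid.uuid4())
    with open(CONFIG_FILE, "w") as f:
        json.dump({"user_id": user_id}, f)
    return user_id

print(get_user_id())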

View File

@@ -201,8 +201,7 @@ def detect_datatype(source: Any) -> DataType:
    formatted_source = format_source(str(source), 30)
    if url:
-        from langchain.document_loaders.youtube import \
-            ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
+        from langchain.document_loaders.youtube import ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
        if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
            logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
@@ -345,7 +344,7 @@ def detect_datatype(source: Any) -> DataType:
        return DataType.TEXT_FILE
    # If the source is a valid file, that's not detectable as a type, an error is raised.
-    # It does not fallback to text.
+    # It does not fall back to text.
    raise ValueError(
        "Source points to a valid file, but based on the filename, no `data_type` can be detected. Please be aware, that not all data_types allow conventional file references, some require the use of the `file URI scheme`. Please refer to the embedchain documentation (https://docs.embedchain.ai/advanced/data_types#remote-data-types)."  # noqa: E501
    )

View File

@@ -49,7 +49,7 @@ class BaseVectorDB(JSONSerializable):
        raise NotImplementedError
    def query(self):
-        """Query contents from vector data base based on vector similarity"""
+        """Query contents from vector database based on vector similarity"""
        raise NotImplementedError
    def count(self) -> int:

View File

@@ -75,7 +75,8 @@ class ChromaDB(BaseVectorDB):
"""Called during initialization""" """Called during initialization"""
return self.client return self.client
def _generate_where_clause(self, where: Dict[str, any]) -> str: @staticmethod
def _generate_where_clause(where: Dict[str, any]) -> Dict[str, any]:
# If only one filter is supplied, return it as is # If only one filter is supplied, return it as is
# (no need to wrap in $and based on chroma docs) # (no need to wrap in $and based on chroma docs)
if len(where.keys()) <= 1: if len(where.keys()) <= 1:
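The corrected annotation reflects that this helper returns a Chroma `where` dict rather than a string. When more than one filter is supplied, the usual pattern is to wrap the clauses in an explicit $and, roughly as in this sketch (a simplified stand-in, not the exact embedchain body):

from typing import Any, Dict

def generate_where_clause(where: Dict[str, Any]) -> Dict[str, Any]:
    # A single filter can be passed to Chroma as-is.
    if len(where) <= 1:
        return where
    # Multiple filters must be combined explicitly with $and.
    return {"$and": [{key: value} for key, value in where.items()]}

print(generate_where_clause({"app_id": "demo"}))
print(generate_where_clause({"app_id": "demo", "url": "https://example.com"}))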
@@ -160,7 +161,8 @@ class ChromaDB(BaseVectorDB):
                ids=ids[i : i + self.BATCH_SIZE],
            )
-    def _format_result(self, results: QueryResult) -> list[tuple[Document, float]]:
+    @staticmethod
+    def _format_result(results: QueryResult) -> list[tuple[Document, float]]:
        """
        Format Chroma results

View File

@@ -88,7 +88,7 @@ class ElasticsearchDB(BaseVectorDB):
""" """
Get existing doc ids present in vector database Get existing doc ids present in vector database
:param ids: _list of doc ids to check for existance :param ids: _list of doc ids to check for existence
:type ids: List[str] :type ids: List[str]
:param where: to filter data :param where: to filter data
:type where: Dict[str, any] :type where: Dict[str, any]
@@ -161,7 +161,7 @@ class ElasticsearchDB(BaseVectorDB):
        **kwargs: Optional[Dict[str, Any]],
    ) -> Union[List[Tuple[str, Dict]], List[str]]:
        """
-        query contents from vector data base based on vector similarity
+        query contents from vector database based on vector similarity
        :param input_query: list of query string
        :type input_query: List[str]

View File

@@ -163,7 +163,7 @@ class OpenSearchDB(BaseVectorDB):
        **kwargs: Optional[Dict[str, Any]],
    ) -> Union[List[Tuple[str, Dict]], List[str]]:
        """
-        query contents from vector data base based on vector similarity
+        query contents from vector database based on vector similarity
        :param input_query: list of query string
        :type input_query: List[str]

View File

@@ -305,7 +305,8 @@ class WeaviateDB(BaseVectorDB):
""" """
return f"{self.config.collection_name}_{self.embedder.vector_dimension}".capitalize() return f"{self.config.collection_name}_{self.embedder.vector_dimension}".capitalize()
def _query_with_cursor(self, query, cursor): @staticmethod
def _query_with_cursor(query, cursor):
if cursor is not None: if cursor is not None:
query.with_after(cursor) query.with_after(cursor)
results = query.do() results = query.do()

View File

@@ -6,8 +6,7 @@ from embedchain.helpers.json_serializable import register_deserializable
from embedchain.vectordb.base import BaseVectorDB
try:
-    from pymilvus import (Collection, CollectionSchema, DataType, FieldSchema,
-                          MilvusClient, connections, utility)
+    from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, MilvusClient, connections, utility
except ImportError:
    raise ImportError(
        "Zilliz requires extra dependencies. Install with `pip install --upgrade embedchain[milvus]`"
@@ -97,10 +96,10 @@ class ZillizVectorDB(BaseVectorDB):
        if ids is None or len(ids) == 0 or self.collection.num_entities == 0:
            return {"ids": []}
-        if not (self.collection.is_empty):
-            filter = f"id in {ids}"
+        if not self.collection.is_empty:
+            filter_ = f"id in {ids}"
            results = self.client.query(
-                collection_name=self.config.collection_name, filter=filter, output_fields=["id"]
+                collection_name=self.config.collection_name, filter=filter_, output_fields=["id"]
            )
            results = [res["id"] for res in results]
@@ -134,7 +133,7 @@ class ZillizVectorDB(BaseVectorDB):
        **kwargs: Optional[Dict[str, Any]],
    ) -> Union[List[Tuple[str, Dict]], List[str]]:
        """
-        Query contents from vector data base based on vector similarity
+        Query contents from vector database based on vector similarity
        :param input_query: list of query string
        :type input_query: List[str]

View File

@@ -69,7 +69,8 @@ class TestTextChunker:
class MockLoader:
-    def load_data(self, src):
+    @staticmethod
+    def load_data(src) -> dict:
        """
        Mock loader that returns a list of data dictionaries.
        Adjust this method to return different data for testing.