Code Formatting (#1828)

This commit is contained in:
Dev Khant
2024-09-07 22:39:28 +05:30
committed by GitHub
parent 6a54d27286
commit a972d2fb07
66 changed files with 208 additions and 138 deletions

View File

@@ -3,7 +3,7 @@ repos:
hooks: hooks:
- id: ruff - id: ruff
name: Ruff name: Ruff
entry: ruff entry: ruff check
language: system language: system
types: [python] types: [python]
args: [--fix] args: [--fix]

View File

@@ -1,6 +1,8 @@
# This example shows how to use vector config to use QDRANT CLOUD # This example shows how to use vector config to use QDRANT CLOUD
import os import os
from dotenv import load_dotenv from dotenv import load_dotenv
from mem0 import Memory from mem0 import Memory
# Loading OpenAI API Key # Loading OpenAI API Key

View File

@@ -9,7 +9,6 @@ import requests
import yaml import yaml
from tqdm import tqdm from tqdm import tqdm
from mem0 import Memory
from embedchain.cache import ( from embedchain.cache import (
Config, Config,
ExactMatchEvaluation, ExactMatchEvaluation,
@@ -26,7 +25,11 @@ from embedchain.embedchain import EmbedChain
from embedchain.embedder.base import BaseEmbedder from embedchain.embedder.base import BaseEmbedder
from embedchain.embedder.openai import OpenAIEmbedder from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.evaluation.base import BaseMetric from embedchain.evaluation.base import BaseMetric
from embedchain.evaluation.metrics import AnswerRelevance, ContextRelevance, Groundedness from embedchain.evaluation.metrics import (
AnswerRelevance,
ContextRelevance,
Groundedness,
)
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
from embedchain.helpers.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm from embedchain.llm.base import BaseLlm
@@ -36,6 +39,7 @@ from embedchain.utils.evaluation import EvalData, EvalMetric
from embedchain.utils.misc import validate_config from embedchain.utils.misc import validate_config
from embedchain.vectordb.base import BaseVectorDB from embedchain.vectordb.base import BaseVectorDB
from embedchain.vectordb.chroma import ChromaDB from embedchain.vectordb.chroma import ChromaDB
from mem0 import Memory
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -3,8 +3,10 @@ from typing import Any
from embedchain import App from embedchain import App
from embedchain.config import AddConfig, AppConfig, BaseLlmConfig from embedchain.config import AddConfig, AppConfig, BaseLlmConfig
from embedchain.embedder.openai import OpenAIEmbedder from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.helpers.json_serializable import (JSONSerializable, from embedchain.helpers.json_serializable import (
register_deserializable) JSONSerializable,
register_deserializable,
)
from embedchain.llm.openai import OpenAILlm from embedchain.llm.openai import OpenAILlm
from embedchain.vectordb.chroma import ChromaDB from embedchain.vectordb.chroma import ChromaDB

View File

@@ -9,10 +9,12 @@ from gptcache.manager import get_data_manager
from gptcache.manager.scalar_data.base import Answer from gptcache.manager.scalar_data.base import Answer
from gptcache.manager.scalar_data.base import DataType as CacheDataType from gptcache.manager.scalar_data.base import DataType as CacheDataType
from gptcache.session import Session from gptcache.session import Session
from gptcache.similarity_evaluation.distance import \ from gptcache.similarity_evaluation.distance import ( # noqa: F401
SearchDistanceEvaluation # noqa: F401 SearchDistanceEvaluation,
from gptcache.similarity_evaluation.exact_match import \ )
ExactMatchEvaluation # noqa: F401 from gptcache.similarity_evaluation.exact_match import ( # noqa: F401
ExactMatchEvaluation,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -14,13 +14,21 @@ import requests
from rich.console import Console from rich.console import Console
from embedchain.telemetry.posthog import AnonymousTelemetry from embedchain.telemetry.posthog import AnonymousTelemetry
from embedchain.utils.cli import (deploy_fly, deploy_gradio_app, from embedchain.utils.cli import (
deploy_hf_spaces, deploy_modal, deploy_fly,
deploy_render, deploy_streamlit, deploy_gradio_app,
get_pkg_path_from_name, setup_fly_io_app, deploy_hf_spaces,
setup_gradio_app, setup_hf_app, deploy_modal,
setup_modal_com_app, setup_render_com_app, deploy_render,
setup_streamlit_io_app) deploy_streamlit,
get_pkg_path_from_name,
setup_fly_io_app,
setup_gradio_app,
setup_hf_app,
setup_modal_com_app,
setup_render_com_app,
setup_streamlit_io_app,
)
console = Console() console = Console()
api_process = None api_process = None

View File

@@ -8,8 +8,8 @@ from .embedder.base import BaseEmbedderConfig
from .embedder.base import BaseEmbedderConfig as EmbedderConfig from .embedder.base import BaseEmbedderConfig as EmbedderConfig
from .embedder.ollama import OllamaEmbedderConfig from .embedder.ollama import OllamaEmbedderConfig
from .llm.base import BaseLlmConfig from .llm.base import BaseLlmConfig
from .mem0_config import Mem0Config
from .vector_db.chroma import ChromaDbConfig from .vector_db.chroma import ChromaDbConfig
from .vector_db.elasticsearch import ElasticsearchDBConfig from .vector_db.elasticsearch import ElasticsearchDBConfig
from .vector_db.opensearch import OpenSearchDBConfig from .vector_db.opensearch import OpenSearchDBConfig
from .vector_db.zilliz import ZillizDBConfig from .vector_db.zilliz import ZillizDBConfig
from .mem0_config import Mem0Config

View File

@@ -1,2 +1,5 @@
from .base import (AnswerRelevanceConfig, ContextRelevanceConfig, # noqa: F401 from .base import ( # noqa: F401
GroundednessConfig) AnswerRelevanceConfig,
ContextRelevanceConfig,
GroundednessConfig,
)

View File

@@ -1,9 +1,9 @@
import json import json
import logging import logging
import re import re
from string import Template
from typing import Any, Mapping, Optional, Dict, Union
from pathlib import Path from pathlib import Path
from string import Template
from typing import Any, Dict, Mapping, Optional, Union
import httpx import httpx

View File

@@ -6,7 +6,12 @@ from typing import Any, Optional, Union
from dotenv import load_dotenv from dotenv import load_dotenv
from langchain.docstore.document import Document from langchain.docstore.document import Document
from embedchain.cache import adapt, get_gptcache_session, gptcache_data_convert, gptcache_update_cache_callback from embedchain.cache import (
adapt,
get_gptcache_session,
gptcache_data_convert,
gptcache_update_cache_callback,
)
from embedchain.chunkers.base_chunker import BaseChunker from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
from embedchain.config.base_app_config import BaseAppConfig from embedchain.config.base_app_config import BaseAppConfig
@@ -16,7 +21,12 @@ from embedchain.embedder.base import BaseEmbedder
from embedchain.helpers.json_serializable import JSONSerializable from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.llm.base import BaseLlm from embedchain.llm.base import BaseLlm
from embedchain.loaders.base_loader import BaseLoader from embedchain.loaders.base_loader import BaseLoader
from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType from embedchain.models.data_type import (
DataType,
DirectDataType,
IndirectDataType,
SpecialDataType,
)
from embedchain.utils.misc import detect_datatype, is_valid_json_string from embedchain.utils.misc import detect_datatype, is_valid_json_string
from embedchain.vectordb.base import BaseVectorDB from embedchain.vectordb.base import BaseVectorDB

View File

@@ -1,18 +1,18 @@
import os import os
from typing import Optional, Union from typing import Optional, Union
from chromadb import EmbeddingFunction, Embeddings
from embedchain.config import BaseEmbedderConfig from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder from embedchain.embedder.base import BaseEmbedder
from chromadb import EmbeddingFunction, Embeddings
class ClarifaiEmbeddingFunction(EmbeddingFunction): class ClarifaiEmbeddingFunction(EmbeddingFunction):
def __init__(self, config: BaseEmbedderConfig) -> None: def __init__(self, config: BaseEmbedderConfig) -> None:
super().__init__() super().__init__()
try: try:
from clarifai.client.model import Model
from clarifai.client.input import Inputs from clarifai.client.input import Inputs
from clarifai.client.model import Model
except ModuleNotFoundError: except ModuleNotFoundError:
raise ModuleNotFoundError( raise ModuleNotFoundError(
"The required dependencies for ClarifaiEmbeddingFunction are not installed." "The required dependencies for ClarifaiEmbeddingFunction are not installed."

View File

@@ -9,7 +9,9 @@ class GPT4AllEmbedder(BaseEmbedder):
def __init__(self, config: Optional[BaseEmbedderConfig] = None): def __init__(self, config: Optional[BaseEmbedderConfig] = None):
super().__init__(config=config) super().__init__(config=config)
from langchain_community.embeddings import GPT4AllEmbeddings as LangchainGPT4AllEmbeddings from langchain_community.embeddings import (
GPT4AllEmbeddings as LangchainGPT4AllEmbeddings,
)
model_name = self.config.model or "all-MiniLM-L6-v2-f16.gguf" model_name = self.config.model or "all-MiniLM-L6-v2-f16.gguf"
gpt4all_kwargs = {'allow_download': 'True'} gpt4all_kwargs = {'allow_download': 'True'}

View File

@@ -3,7 +3,6 @@ from typing import Optional
from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction
from embedchain.config import BaseEmbedderConfig from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder from embedchain.embedder.base import BaseEmbedder
from embedchain.models import VectorDimensions from embedchain.models import VectorDimensions

View File

@@ -45,8 +45,9 @@ class AWSBedrockLlm(BaseLlm):
} }
if config.stream: if config.stream:
from langchain.callbacks.streaming_stdout import \ from langchain.callbacks.streaming_stdout import (
StreamingStdOutCallbackHandler StreamingStdOutCallbackHandler,
)
kwargs["streaming"] = True kwargs["streaming"] = True
kwargs["callbacks"] = [StreamingStdOutCallbackHandler()] kwargs["callbacks"] = [StreamingStdOutCallbackHandler()]

View File

@@ -5,7 +5,6 @@ from typing import Any, Optional
from langchain.schema import BaseMessage as LCBaseMessage from langchain.schema import BaseMessage as LCBaseMessage
from embedchain.constants import SQLITE_PATH
from embedchain.config import BaseLlmConfig from embedchain.config import BaseLlmConfig
from embedchain.config.llm.base import ( from embedchain.config.llm.base import (
DEFAULT_PROMPT, DEFAULT_PROMPT,
@@ -13,6 +12,7 @@ from embedchain.config.llm.base import (
DEFAULT_PROMPT_WITH_MEM0_MEMORY_TEMPLATE, DEFAULT_PROMPT_WITH_MEM0_MEMORY_TEMPLATE,
DOCS_SITE_PROMPT_TEMPLATE, DOCS_SITE_PROMPT_TEMPLATE,
) )
from embedchain.constants import SQLITE_PATH
from embedchain.core.db.database import init_db, setup_engine from embedchain.core.db.database import init_db, setup_engine
from embedchain.helpers.json_serializable import JSONSerializable from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.memory.base import ChatHistory from embedchain.memory.base import ChatHistory

View File

@@ -26,8 +26,7 @@ class GPT4ALLLlm(BaseLlm):
@staticmethod @staticmethod
def _get_instance(model): def _get_instance(model):
try: try:
from langchain_community.llms.gpt4all import \ from langchain_community.llms.gpt4all import GPT4All as LangchainGPT4All
GPT4All as LangchainGPT4All
except ModuleNotFoundError: except ModuleNotFoundError:
raise ModuleNotFoundError( raise ModuleNotFoundError(
"The GPT4All python package is not installed. Please install it with `pip install --upgrade embedchain[opensource]`" # noqa E501 "The GPT4All python package is not installed. Please install it with `pip install --upgrade embedchain[opensource]`" # noqa E501

View File

@@ -35,8 +35,9 @@ class JinaLlm(BaseLlm):
if config.top_p: if config.top_p:
kwargs["model_kwargs"]["top_p"] = config.top_p kwargs["model_kwargs"]["top_p"] = config.top_p
if config.stream: if config.stream:
from langchain.callbacks.streaming_stdout import \ from langchain.callbacks.streaming_stdout import (
StreamingStdOutCallbackHandler StreamingStdOutCallbackHandler,
)
chat = JinaChat(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()]) chat = JinaChat(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()])
else: else:

View File

@@ -1,6 +1,7 @@
import hashlib import hashlib
from langchain_community.document_loaders import PyPDFLoader from langchain_community.document_loaders import PyPDFLoader
from embedchain.helpers.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string from embedchain.utils.misc import clean_string

View File

@@ -28,7 +28,9 @@ class RSSFeedLoader(BaseLoader):
@staticmethod @staticmethod
def get_rss_content(url: str): def get_rss_content(url: str):
try: try:
from langchain_community.document_loaders import RSSFeedLoader as LangchainRSSFeedLoader from langchain_community.document_loaders import (
RSSFeedLoader as LangchainRSSFeedLoader,
)
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"""RSSFeedLoader file requires extra dependencies. """RSSFeedLoader file requires extra dependencies.

View File

@@ -11,8 +11,7 @@ class UnstructuredLoader(BaseLoader):
"""Load data from an Unstructured file.""" """Load data from an Unstructured file."""
try: try:
import unstructured # noqa: F401 import unstructured # noqa: F401
from langchain_community.document_loaders import \ from langchain_community.document_loaders import UnstructuredFileLoader
UnstructuredFileLoader
except ImportError: except ImportError:
raise ImportError( raise ImportError(
'Unstructured file requires extra dependencies. Install with `pip install "unstructured[local-inference, all-docs]"`' # noqa: E501 'Unstructured file requires extra dependencies. Install with `pip install "unstructured[local-inference, all-docs]"`' # noqa: E501

View File

@@ -6,8 +6,15 @@ from embedchain.helpers.json_serializable import register_deserializable
from embedchain.vectordb.base import BaseVectorDB from embedchain.vectordb.base import BaseVectorDB
try: try:
from pymilvus import (Collection, CollectionSchema, DataType, FieldSchema, from pymilvus import (
MilvusClient, connections, utility) Collection,
CollectionSchema,
DataType,
FieldSchema,
MilvusClient,
connections,
utility,
)
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"Zilliz requires extra dependencies. Install with `pip install --upgrade embedchain[milvus]`" "Zilliz requires extra dependencies. Install with `pip install --upgrade embedchain[milvus]`"

View File

@@ -8,8 +8,7 @@ import streamlit as st
from embedchain import App from embedchain import App
from embedchain.config import BaseLlmConfig from embedchain.config import BaseLlmConfig
from embedchain.helpers.callbacks import (StreamingStdOutCallbackHandlerYield, from embedchain.helpers.callbacks import StreamingStdOutCallbackHandlerYield, generate
generate)
def embedchain_bot(db_path, api_key): def embedchain_bot(db_path, api_key):

View File

@@ -8,8 +8,7 @@ import streamlit as st
from embedchain import App from embedchain import App
from embedchain.config import BaseLlmConfig from embedchain.config import BaseLlmConfig
from embedchain.helpers.callbacks import (StreamingStdOutCallbackHandlerYield, from embedchain.helpers.callbacks import StreamingStdOutCallbackHandlerYield, generate
generate)
@st.cache_resource @st.cache_resource

View File

@@ -4,8 +4,7 @@ import streamlit as st
from embedchain import App from embedchain import App
from embedchain.config import BaseLlmConfig from embedchain.config import BaseLlmConfig
from embedchain.helpers.callbacks import (StreamingStdOutCallbackHandlerYield, from embedchain.helpers.callbacks import StreamingStdOutCallbackHandlerYield, generate
generate)
@st.cache_resource @st.cache_resource

View File

@@ -1,4 +1,4 @@
from unittest.mock import patch, Mock from unittest.mock import Mock, patch
import httpx import httpx

View File

@@ -1,5 +1,6 @@
from unittest.mock import patch from unittest.mock import patch
from embedchain.config import BaseEmbedderConfig from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.huggingface import HuggingFaceEmbedder from embedchain.embedder.huggingface import HuggingFaceEmbedder

View File

@@ -4,8 +4,10 @@ from string import Template
from embedchain import App from embedchain import App
from embedchain.config import AppConfig, BaseLlmConfig from embedchain.config import AppConfig, BaseLlmConfig
from embedchain.helpers.json_serializable import (JSONSerializable, from embedchain.helpers.json_serializable import (
register_deserializable) JSONSerializable,
register_deserializable,
)
class TestJsonSerializable(unittest.TestCase): class TestJsonSerializable(unittest.TestCase):

View File

@@ -1,7 +1,9 @@
from unittest import mock from unittest import mock
import pytest import pytest
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def mock_alembic_command_upgrade(): def mock_alembic_command_upgrade():
with mock.patch("alembic.command.upgrade"): with mock.patch("alembic.command.upgrade"):

View File

@@ -1,4 +1,4 @@
from unittest.mock import Mock, MagicMock, patch from unittest.mock import MagicMock, Mock, patch
import httpx import httpx
import pytest import pytest

View File

@@ -1,8 +1,8 @@
import pytest import pytest
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from embedchain.config import BaseLlmConfig from embedchain.config import BaseLlmConfig
from embedchain.llm.ollama import OllamaLlm from embedchain.llm.ollama import OllamaLlm
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
@pytest.fixture @pytest.fixture

View File

@@ -1,5 +1,9 @@
from embedchain.models.data_type import (DataType, DirectDataType, from embedchain.models.data_type import (
IndirectDataType, SpecialDataType) DataType,
DirectDataType,
IndirectDataType,
SpecialDataType,
)
def test_subclass_types_in_data_type(): def test_subclass_types_in_data_type():

View File

@@ -2,5 +2,5 @@ import importlib.metadata
__version__ = importlib.metadata.version("mem0ai") __version__ = importlib.metadata.version("mem0ai")
from mem0.memory.main import Memory # noqa
from mem0.client.main import MemoryClient # noqa from mem0.client.main import MemoryClient # noqa
from mem0.memory.main import Memory # noqa

View File

@@ -1,12 +1,13 @@
import os import os
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from mem0.memory.setup import mem0_dir
from mem0.vector_stores.configs import VectorStoreConfig
from mem0.llms.configs import LlmConfig
from mem0.embeddings.configs import EmbedderConfig from mem0.embeddings.configs import EmbedderConfig
from mem0.graphs.configs import GraphStoreConfig from mem0.graphs.configs import GraphStoreConfig
from mem0.llms.configs import LlmConfig
from mem0.memory.setup import mem0_dir
from mem0.vector_stores.configs import VectorStoreConfig
class MemoryItem(BaseModel): class MemoryItem(BaseModel):

View File

@@ -1,9 +1,10 @@
from abc import ABC from abc import ABC
from mem0.configs.base import AzureConfig from typing import Dict, Optional, Union
from typing import Optional, Union, Dict
import httpx import httpx
from mem0.configs.base import AzureConfig
class BaseEmbedderConfig(ABC): class BaseEmbedderConfig(ABC):
""" """

View File

@@ -1,9 +1,10 @@
from abc import ABC from abc import ABC
from mem0.configs.base import AzureConfig from typing import Dict, Optional, Union
from typing import Optional, Union, Dict
import httpx import httpx
from mem0.configs.base import AzureConfig
class BaseLlmConfig(ABC): class BaseLlmConfig(ABC):
""" """

View File

@@ -1,6 +1,6 @@
import subprocess import subprocess
import sys import sys
from typing import Optional, ClassVar, Dict, Any from typing import Any, ClassVar, Dict, Optional
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator

View File

@@ -1,4 +1,4 @@
from typing import Optional, Dict, Any from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator

View File

@@ -1,5 +1,6 @@
from typing import Any, ClassVar, Dict, Optional
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator
from typing import Optional, ClassVar, Dict, Any
class QdrantConfig(BaseModel): class QdrantConfig(BaseModel):

View File

@@ -1,5 +1,5 @@
from typing import Optional
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional
from mem0.configs.embeddings.base import BaseEmbedderConfig from mem0.configs.embeddings.base import BaseEmbedderConfig

View File

@@ -1,7 +1,10 @@
from typing import Optional from typing import Optional
from pydantic import BaseModel, Field, field_validator, model_validator from pydantic import BaseModel, Field, field_validator, model_validator
from mem0.llms.configs import LlmConfig from mem0.llms.configs import LlmConfig
class Neo4jConfig(BaseModel): class Neo4jConfig(BaseModel):
url: Optional[str] = Field(None, description="Host address for the graph database") url: Optional[str] = Field(None, description="Host address for the graph database")
username: Optional[str] = Field(None, description="Username for the graph database") username: Optional[str] = Field(None, description="Username for the graph database")

View File

@@ -1,7 +1,4 @@
import subprocess
import sys
import os import os
import json
from typing import Dict, List, Optional from typing import Dict, List, Optional
try: try:
@@ -9,8 +6,8 @@ try:
except ImportError: except ImportError:
raise ImportError("The 'anthropic' library is required. Please install it using 'pip install anthropic'.") raise ImportError("The 'anthropic' library is required. Please install it using 'pip install anthropic'.")
from mem0.llms.base import LLMBase
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase
class AnthropicLLM(LLMBase): class AnthropicLLM(LLMBase):

View File

@@ -1,16 +1,14 @@
import subprocess
import sys
import os
import json import json
from typing import Dict, List, Optional, Any import os
from typing import Any, Dict, List, Optional
try: try:
import boto3 import boto3
except ImportError: except ImportError:
raise ImportError("The 'boto3' library is required. Please install it using 'pip install boto3'.") raise ImportError("The 'boto3' library is required. Please install it using 'pip install boto3'.")
from mem0.llms.base import LLMBase
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase
class AWSBedrockLLM(LLMBase): class AWSBedrockLLM(LLMBase):

View File

@@ -1,11 +1,11 @@
import os
import json import json
import os
from typing import Dict, List, Optional from typing import Dict, List, Optional
from openai import AzureOpenAI from openai import AzureOpenAI
from mem0.llms.base import LLMBase
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase
class AzureOpenAILLM(LLMBase): class AzureOpenAILLM(LLMBase):

View File

@@ -1,5 +1,5 @@
from typing import Optional
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig

View File

@@ -1,7 +1,5 @@
import subprocess
import sys
import os
import json import json
import os
from typing import Dict, List, Optional from typing import Dict, List, Optional
try: try:
@@ -9,8 +7,8 @@ try:
except ImportError: except ImportError:
raise ImportError("The 'groq' library is required. Please install it using 'pip install groq'.") raise ImportError("The 'groq' library is required. Please install it using 'pip install groq'.")
from mem0.llms.base import LLMBase
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase
class GroqLLM(LLMBase): class GroqLLM(LLMBase):

View File

@@ -1,5 +1,3 @@
import subprocess
import sys
import json import json
from typing import Dict, List, Optional from typing import Dict, List, Optional
@@ -8,8 +6,8 @@ try:
except ImportError: except ImportError:
raise ImportError("The 'litellm' library is required. Please install it using 'pip install litellm'.") raise ImportError("The 'litellm' library is required. Please install it using 'pip install litellm'.")
from mem0.llms.base import LLMBase
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase
class LiteLLM(LLMBase): class LiteLLM(LLMBase):

View File

@@ -1,5 +1,3 @@
import subprocess
import sys
from typing import Dict, List, Optional from typing import Dict, List, Optional
try: try:
@@ -7,8 +5,8 @@ try:
except ImportError: except ImportError:
raise ImportError("The 'ollama' library is required. Please install it using 'pip install ollama'.") raise ImportError("The 'ollama' library is required. Please install it using 'pip install ollama'.")
from mem0.llms.base import LLMBase
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase
class OllamaLLM(LLMBase): class OllamaLLM(LLMBase):

View File

@@ -1,11 +1,11 @@
import os
import json import json
import os
from typing import Dict, List, Optional from typing import Dict, List, Optional
from openai import OpenAI from openai import OpenAI
from mem0.llms.base import LLMBase
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase
class OpenAILLM(LLMBase): class OpenAILLM(LLMBase):

View File

@@ -1,10 +1,11 @@
import os, json import json
import os
from typing import Dict, List, Optional from typing import Dict, List, Optional
from openai import OpenAI from openai import OpenAI
from mem0.llms.base import LLMBase
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase
class OpenAIStructuredLLM(LLMBase): class OpenAIStructuredLLM(LLMBase):

View File

@@ -1,7 +1,5 @@
import subprocess
import sys
import os
import json import json
import os
from typing import Dict, List, Optional from typing import Dict, List, Optional
try: try:
@@ -9,8 +7,8 @@ try:
except ImportError: except ImportError:
raise ImportError("The 'together' library is required. Please install it using 'pip install together'.") raise ImportError("The 'together' library is required. Please install it using 'pip install together'.")
from mem0.llms.base import LLMBase
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase
class TogetherLLM(LLMBase): class TogetherLLM(LLMBase):

View File

@@ -1,10 +1,17 @@
import json
import logging import logging
from langchain_community.graphs import Neo4jGraph from langchain_community.graphs import Neo4jGraph
from rank_bm25 import BM25Okapi from rank_bm25 import BM25Okapi
from mem0.utils.factory import LlmFactory, EmbedderFactory
from mem0.graphs.utils import get_update_memory_messages, EXTRACT_ENTITIES_PROMPT from mem0.graphs.tools import (
from mem0.graphs.tools import UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL, ADD_MESSAGE_TOOL, SEARCH_TOOL ADD_MEMORY_TOOL_GRAPH,
ADD_MESSAGE_TOOL,
NOOP_TOOL,
SEARCH_TOOL,
UPDATE_MEMORY_TOOL_GRAPH,
)
from mem0.graphs.utils import EXTRACT_ENTITIES_PROMPT, get_update_memory_messages
from mem0.utils.factory import EmbedderFactory, LlmFactory
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -1,22 +1,24 @@
import logging import concurrent
import hashlib import hashlib
import uuid
import pytz
import json import json
import logging
import threading
import uuid
import warnings
from datetime import datetime from datetime import datetime
from typing import Any, Dict from typing import Any, Dict
import warnings
import pytz
from pydantic import ValidationError from pydantic import ValidationError
from mem0.configs.base import MemoryConfig, MemoryItem
from mem0.configs.prompts import get_update_memory_messages
from mem0.memory.base import MemoryBase from mem0.memory.base import MemoryBase
from mem0.memory.setup import setup_config from mem0.memory.setup import setup_config
from mem0.memory.storage import SQLiteManager from mem0.memory.storage import SQLiteManager
from mem0.memory.telemetry import capture_event from mem0.memory.telemetry import capture_event
from mem0.memory.utils import get_fact_retrieval_messages, parse_messages from mem0.memory.utils import get_fact_retrieval_messages, parse_messages
from mem0.configs.prompts import get_update_memory_messages from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory
from mem0.utils.factory import LlmFactory, EmbedderFactory, VectorStoreFactory
from mem0.configs.base import MemoryItem, MemoryConfig
import threading
import concurrent
# Setup user config # Setup user config
setup_config() setup_config()
@@ -154,7 +156,7 @@ class Memory(MemoryBase):
logging.info(resp) logging.info(resp)
try: try:
if resp["event"] == "ADD": if resp["event"] == "ADD":
memory_id = self._create_memory(data=resp["text"], metadata=metadata) self._create_memory(data=resp["text"], metadata=metadata)
elif resp["event"] == "UPDATE": elif resp["event"] == "UPDATE":
self._update_memory(memory_id=resp["id"], data=resp["text"], metadata=metadata) self._update_memory(memory_id=resp["id"], data=resp["text"], metadata=metadata)
elif resp["event"] == "DELETE": elif resp["event"] == "DELETE":
@@ -175,7 +177,7 @@ class Memory(MemoryBase):
else: else:
self.graph.user_id = "USER" self.graph.user_id = "USER"
data = "\n".join([msg["content"] for msg in messages if "content" in msg and msg["role"] != "system"]) data = "\n".join([msg["content"] for msg in messages if "content" in msg and msg["role"] != "system"])
added_entities = self.graph.add(data, filters) self.graph.add(data, filters)
def get(self, memory_id): def get(self, memory_id):
""" """

View File

@@ -1,3 +1,4 @@
import logging
import platform import platform
import sys import sys
@@ -5,6 +6,8 @@ from posthog import Posthog
from mem0.memory.setup import get_user_id, setup_config from mem0.memory.setup import get_user_id, setup_config
logging.getLogger('posthog').setLevel(logging.CRITICAL + 1)
logging.getLogger('urllib3').setLevel(logging.CRITICAL + 1)
class AnonymousTelemetry: class AnonymousTelemetry:
def __init__(self, project_api_key, host): def __init__(self, project_api_key, host):

View File

@@ -1,9 +1,10 @@
import logging import logging
import subprocess import subprocess
import sys import sys
import httpx
from typing import Optional, List, Union
import threading import threading
from typing import List, Optional, Union
import httpx
try: try:
import litellm import litellm
@@ -20,9 +21,9 @@ except ImportError:
raise ImportError("The required 'litellm' library is not installed.") raise ImportError("The required 'litellm' library is not installed.")
sys.exit(1) sys.exit(1)
from mem0.memory.telemetry import capture_client_event
from mem0 import Memory, MemoryClient from mem0 import Memory, MemoryClient
from mem0.configs.prompts import MEMORY_ANSWER_PROMPT from mem0.configs.prompts import MEMORY_ANSWER_PROMPT
from mem0.memory.telemetry import capture_client_event
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -1,7 +1,7 @@
import importlib import importlib
from mem0.configs.llms.base import BaseLlmConfig
from mem0.configs.embeddings.base import BaseEmbedderConfig from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.configs.llms.base import BaseLlmConfig
def load_class(class_type): def load_class(class_type):

View File

@@ -1,7 +1,5 @@
import subprocess
import sys
import logging import logging
from typing import Optional, List, Dict from typing import Dict, List, Optional
from pydantic import BaseModel from pydantic import BaseModel

View File

@@ -1,4 +1,5 @@
from typing import Optional, Dict from typing import Dict, Optional
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator

View File

@@ -1,8 +1,7 @@
import subprocess
import sys
import json import json
import logging import logging
from typing import Optional, List from typing import List, Optional
from pydantic import BaseModel from pydantic import BaseModel
try: try:

View File

@@ -1,6 +1,6 @@
import logging
import os import os
import shutil import shutil
import logging
from qdrant_client import QdrantClient from qdrant_client import QdrantClient
from qdrant_client.models import ( from qdrant_client.models import (

View File

@@ -1,10 +1,11 @@
from unittest.mock import Mock, patch
import httpx import httpx
import pytest import pytest
from unittest.mock import Mock, patch
from mem0.llms.azure_openai import AzureOpenAILLM
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.azure_openai import AzureOpenAILLM
MODEL = "gpt-4o" # or your custom deployment name MODEL = "gpt-4o" # or your custom deployment name
TEMPERATURE = 0.7 TEMPERATURE = 0.7

View File

@@ -1,7 +1,10 @@
import pytest
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from mem0.llms.groq import GroqLLM
import pytest
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.groq import GroqLLM
@pytest.fixture @pytest.fixture
def mock_groq_client(): def mock_groq_client():

View File

@@ -1,8 +1,10 @@
import pytest
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from mem0.llms import litellm import pytest
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms import litellm
@pytest.fixture @pytest.fixture
def mock_litellm(): def mock_litellm():

View File

@@ -1,9 +1,12 @@
import pytest
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from mem0.llms.ollama import OllamaLLM
import pytest
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.ollama import OllamaLLM
from mem0.llms.utils.tools import ADD_MEMORY_TOOL from mem0.llms.utils.tools import ADD_MEMORY_TOOL
@pytest.fixture @pytest.fixture
def mock_ollama_client(): def mock_ollama_client():
with patch('mem0.llms.ollama.Client') as mock_ollama: with patch('mem0.llms.ollama.Client') as mock_ollama:

View File

@@ -1,7 +1,10 @@
import pytest
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from mem0.llms.openai import OpenAILLM
import pytest
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.openai import OpenAILLM
@pytest.fixture @pytest.fixture
def mock_openai_client(): def mock_openai_client():

View File

@@ -1,7 +1,10 @@
import pytest
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from mem0.llms.together import TogetherLLM
import pytest
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.together import TogetherLLM
@pytest.fixture @pytest.fixture
def mock_together_client(): def mock_together_client():

View File

@@ -1,10 +1,11 @@
import pytest
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from mem0.configs.prompts import MEMORY_ANSWER_PROMPT import pytest
from mem0 import Memory, MemoryClient from mem0 import Memory, MemoryClient
from mem0.proxy.main import Mem0 from mem0.configs.prompts import MEMORY_ANSWER_PROMPT
from mem0.proxy.main import Chat, Completions from mem0.proxy.main import Chat, Completions, Mem0
@pytest.fixture @pytest.fixture
def mock_memory_client(): def mock_memory_client():