Code Formatting (#1828)
This commit is contained in:
@@ -3,7 +3,7 @@ repos:
|
||||
hooks:
|
||||
- id: ruff
|
||||
name: Ruff
|
||||
entry: ruff
|
||||
entry: ruff check
|
||||
language: system
|
||||
types: [python]
|
||||
args: [--fix]
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# This example shows how to use vector config to use QDRANT CLOUD
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from mem0 import Memory
|
||||
|
||||
# Loading OpenAI API Key
|
||||
|
||||
@@ -9,7 +9,6 @@ import requests
|
||||
import yaml
|
||||
from tqdm import tqdm
|
||||
|
||||
from mem0 import Memory
|
||||
from embedchain.cache import (
|
||||
Config,
|
||||
ExactMatchEvaluation,
|
||||
@@ -26,7 +25,11 @@ from embedchain.embedchain import EmbedChain
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.embedder.openai import OpenAIEmbedder
|
||||
from embedchain.evaluation.base import BaseMetric
|
||||
from embedchain.evaluation.metrics import AnswerRelevance, ContextRelevance, Groundedness
|
||||
from embedchain.evaluation.metrics import (
|
||||
AnswerRelevance,
|
||||
ContextRelevance,
|
||||
Groundedness,
|
||||
)
|
||||
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
@@ -36,6 +39,7 @@ from embedchain.utils.evaluation import EvalData, EvalMetric
|
||||
from embedchain.utils.misc import validate_config
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
from embedchain.vectordb.chroma import ChromaDB
|
||||
from mem0 import Memory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -3,8 +3,10 @@ from typing import Any
|
||||
from embedchain import App
|
||||
from embedchain.config import AddConfig, AppConfig, BaseLlmConfig
|
||||
from embedchain.embedder.openai import OpenAIEmbedder
|
||||
from embedchain.helpers.json_serializable import (JSONSerializable,
|
||||
register_deserializable)
|
||||
from embedchain.helpers.json_serializable import (
|
||||
JSONSerializable,
|
||||
register_deserializable,
|
||||
)
|
||||
from embedchain.llm.openai import OpenAILlm
|
||||
from embedchain.vectordb.chroma import ChromaDB
|
||||
|
||||
|
||||
@@ -9,10 +9,12 @@ from gptcache.manager import get_data_manager
|
||||
from gptcache.manager.scalar_data.base import Answer
|
||||
from gptcache.manager.scalar_data.base import DataType as CacheDataType
|
||||
from gptcache.session import Session
|
||||
from gptcache.similarity_evaluation.distance import \
|
||||
SearchDistanceEvaluation # noqa: F401
|
||||
from gptcache.similarity_evaluation.exact_match import \
|
||||
ExactMatchEvaluation # noqa: F401
|
||||
from gptcache.similarity_evaluation.distance import ( # noqa: F401
|
||||
SearchDistanceEvaluation,
|
||||
)
|
||||
from gptcache.similarity_evaluation.exact_match import ( # noqa: F401
|
||||
ExactMatchEvaluation,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -14,13 +14,21 @@ import requests
|
||||
from rich.console import Console
|
||||
|
||||
from embedchain.telemetry.posthog import AnonymousTelemetry
|
||||
from embedchain.utils.cli import (deploy_fly, deploy_gradio_app,
|
||||
deploy_hf_spaces, deploy_modal,
|
||||
deploy_render, deploy_streamlit,
|
||||
get_pkg_path_from_name, setup_fly_io_app,
|
||||
setup_gradio_app, setup_hf_app,
|
||||
setup_modal_com_app, setup_render_com_app,
|
||||
setup_streamlit_io_app)
|
||||
from embedchain.utils.cli import (
|
||||
deploy_fly,
|
||||
deploy_gradio_app,
|
||||
deploy_hf_spaces,
|
||||
deploy_modal,
|
||||
deploy_render,
|
||||
deploy_streamlit,
|
||||
get_pkg_path_from_name,
|
||||
setup_fly_io_app,
|
||||
setup_gradio_app,
|
||||
setup_hf_app,
|
||||
setup_modal_com_app,
|
||||
setup_render_com_app,
|
||||
setup_streamlit_io_app,
|
||||
)
|
||||
|
||||
console = Console()
|
||||
api_process = None
|
||||
|
||||
@@ -8,8 +8,8 @@ from .embedder.base import BaseEmbedderConfig
|
||||
from .embedder.base import BaseEmbedderConfig as EmbedderConfig
|
||||
from .embedder.ollama import OllamaEmbedderConfig
|
||||
from .llm.base import BaseLlmConfig
|
||||
from .mem0_config import Mem0Config
|
||||
from .vector_db.chroma import ChromaDbConfig
|
||||
from .vector_db.elasticsearch import ElasticsearchDBConfig
|
||||
from .vector_db.opensearch import OpenSearchDBConfig
|
||||
from .vector_db.zilliz import ZillizDBConfig
|
||||
from .mem0_config import Mem0Config
|
||||
|
||||
@@ -1,2 +1,5 @@
|
||||
from .base import (AnswerRelevanceConfig, ContextRelevanceConfig, # noqa: F401
|
||||
GroundednessConfig)
|
||||
from .base import ( # noqa: F401
|
||||
AnswerRelevanceConfig,
|
||||
ContextRelevanceConfig,
|
||||
GroundednessConfig,
|
||||
)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from string import Template
|
||||
from typing import Any, Mapping, Optional, Dict, Union
|
||||
from pathlib import Path
|
||||
from string import Template
|
||||
from typing import Any, Dict, Mapping, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
@@ -6,7 +6,12 @@ from typing import Any, Optional, Union
|
||||
from dotenv import load_dotenv
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
from embedchain.cache import adapt, get_gptcache_session, gptcache_data_convert, gptcache_update_cache_callback
|
||||
from embedchain.cache import (
|
||||
adapt,
|
||||
get_gptcache_session,
|
||||
gptcache_data_convert,
|
||||
gptcache_update_cache_callback,
|
||||
)
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
|
||||
from embedchain.config.base_app_config import BaseAppConfig
|
||||
@@ -16,7 +21,12 @@ from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.helpers.json_serializable import JSONSerializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
|
||||
from embedchain.models.data_type import (
|
||||
DataType,
|
||||
DirectDataType,
|
||||
IndirectDataType,
|
||||
SpecialDataType,
|
||||
)
|
||||
from embedchain.utils.misc import detect_datatype, is_valid_json_string
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
|
||||
from chromadb import EmbeddingFunction, Embeddings
|
||||
|
||||
from embedchain.config import BaseEmbedderConfig
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
|
||||
from chromadb import EmbeddingFunction, Embeddings
|
||||
|
||||
|
||||
class ClarifaiEmbeddingFunction(EmbeddingFunction):
|
||||
def __init__(self, config: BaseEmbedderConfig) -> None:
|
||||
super().__init__()
|
||||
try:
|
||||
from clarifai.client.model import Model
|
||||
from clarifai.client.input import Inputs
|
||||
from clarifai.client.model import Model
|
||||
except ModuleNotFoundError:
|
||||
raise ModuleNotFoundError(
|
||||
"The required dependencies for ClarifaiEmbeddingFunction are not installed."
|
||||
|
||||
@@ -9,7 +9,9 @@ class GPT4AllEmbedder(BaseEmbedder):
|
||||
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
|
||||
super().__init__(config=config)
|
||||
|
||||
from langchain_community.embeddings import GPT4AllEmbeddings as LangchainGPT4AllEmbeddings
|
||||
from langchain_community.embeddings import (
|
||||
GPT4AllEmbeddings as LangchainGPT4AllEmbeddings,
|
||||
)
|
||||
|
||||
model_name = self.config.model or "all-MiniLM-L6-v2-f16.gguf"
|
||||
gpt4all_kwargs = {'allow_download': 'True'}
|
||||
|
||||
@@ -3,7 +3,6 @@ from typing import Optional
|
||||
|
||||
from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction
|
||||
|
||||
|
||||
from embedchain.config import BaseEmbedderConfig
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.models import VectorDimensions
|
||||
|
||||
@@ -45,8 +45,9 @@ class AWSBedrockLlm(BaseLlm):
|
||||
}
|
||||
|
||||
if config.stream:
|
||||
from langchain.callbacks.streaming_stdout import \
|
||||
StreamingStdOutCallbackHandler
|
||||
from langchain.callbacks.streaming_stdout import (
|
||||
StreamingStdOutCallbackHandler,
|
||||
)
|
||||
|
||||
kwargs["streaming"] = True
|
||||
kwargs["callbacks"] = [StreamingStdOutCallbackHandler()]
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import Any, Optional
|
||||
|
||||
from langchain.schema import BaseMessage as LCBaseMessage
|
||||
|
||||
from embedchain.constants import SQLITE_PATH
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.config.llm.base import (
|
||||
DEFAULT_PROMPT,
|
||||
@@ -13,6 +12,7 @@ from embedchain.config.llm.base import (
|
||||
DEFAULT_PROMPT_WITH_MEM0_MEMORY_TEMPLATE,
|
||||
DOCS_SITE_PROMPT_TEMPLATE,
|
||||
)
|
||||
from embedchain.constants import SQLITE_PATH
|
||||
from embedchain.core.db.database import init_db, setup_engine
|
||||
from embedchain.helpers.json_serializable import JSONSerializable
|
||||
from embedchain.memory.base import ChatHistory
|
||||
|
||||
@@ -26,8 +26,7 @@ class GPT4ALLLlm(BaseLlm):
|
||||
@staticmethod
|
||||
def _get_instance(model):
|
||||
try:
|
||||
from langchain_community.llms.gpt4all import \
|
||||
GPT4All as LangchainGPT4All
|
||||
from langchain_community.llms.gpt4all import GPT4All as LangchainGPT4All
|
||||
except ModuleNotFoundError:
|
||||
raise ModuleNotFoundError(
|
||||
"The GPT4All python package is not installed. Please install it with `pip install --upgrade embedchain[opensource]`" # noqa E501
|
||||
|
||||
@@ -35,8 +35,9 @@ class JinaLlm(BaseLlm):
|
||||
if config.top_p:
|
||||
kwargs["model_kwargs"]["top_p"] = config.top_p
|
||||
if config.stream:
|
||||
from langchain.callbacks.streaming_stdout import \
|
||||
StreamingStdOutCallbackHandler
|
||||
from langchain.callbacks.streaming_stdout import (
|
||||
StreamingStdOutCallbackHandler,
|
||||
)
|
||||
|
||||
chat = JinaChat(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()])
|
||||
else:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import hashlib
|
||||
|
||||
from langchain_community.document_loaders import PyPDFLoader
|
||||
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.utils.misc import clean_string
|
||||
|
||||
@@ -28,7 +28,9 @@ class RSSFeedLoader(BaseLoader):
|
||||
@staticmethod
|
||||
def get_rss_content(url: str):
|
||||
try:
|
||||
from langchain_community.document_loaders import RSSFeedLoader as LangchainRSSFeedLoader
|
||||
from langchain_community.document_loaders import (
|
||||
RSSFeedLoader as LangchainRSSFeedLoader,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"""RSSFeedLoader file requires extra dependencies.
|
||||
|
||||
@@ -11,8 +11,7 @@ class UnstructuredLoader(BaseLoader):
|
||||
"""Load data from an Unstructured file."""
|
||||
try:
|
||||
import unstructured # noqa: F401
|
||||
from langchain_community.document_loaders import \
|
||||
UnstructuredFileLoader
|
||||
from langchain_community.document_loaders import UnstructuredFileLoader
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
'Unstructured file requires extra dependencies. Install with `pip install "unstructured[local-inference, all-docs]"`' # noqa: E501
|
||||
|
||||
@@ -6,8 +6,15 @@ from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
|
||||
try:
|
||||
from pymilvus import (Collection, CollectionSchema, DataType, FieldSchema,
|
||||
MilvusClient, connections, utility)
|
||||
from pymilvus import (
|
||||
Collection,
|
||||
CollectionSchema,
|
||||
DataType,
|
||||
FieldSchema,
|
||||
MilvusClient,
|
||||
connections,
|
||||
utility,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Zilliz requires extra dependencies. Install with `pip install --upgrade embedchain[milvus]`"
|
||||
|
||||
@@ -8,8 +8,7 @@ import streamlit as st
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helpers.callbacks import (StreamingStdOutCallbackHandlerYield,
|
||||
generate)
|
||||
from embedchain.helpers.callbacks import StreamingStdOutCallbackHandlerYield, generate
|
||||
|
||||
|
||||
def embedchain_bot(db_path, api_key):
|
||||
|
||||
@@ -8,8 +8,7 @@ import streamlit as st
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helpers.callbacks import (StreamingStdOutCallbackHandlerYield,
|
||||
generate)
|
||||
from embedchain.helpers.callbacks import StreamingStdOutCallbackHandlerYield, generate
|
||||
|
||||
|
||||
@st.cache_resource
|
||||
|
||||
@@ -4,8 +4,7 @@ import streamlit as st
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helpers.callbacks import (StreamingStdOutCallbackHandlerYield,
|
||||
generate)
|
||||
from embedchain.helpers.callbacks import StreamingStdOutCallbackHandlerYield, generate
|
||||
|
||||
|
||||
@st.cache_resource
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from unittest.mock import patch, Mock
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from embedchain.config import BaseEmbedderConfig
|
||||
from embedchain.embedder.huggingface import HuggingFaceEmbedder
|
||||
|
||||
|
||||
@@ -4,8 +4,10 @@ from string import Template
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig, BaseLlmConfig
|
||||
from embedchain.helpers.json_serializable import (JSONSerializable,
|
||||
register_deserializable)
|
||||
from embedchain.helpers.json_serializable import (
|
||||
JSONSerializable,
|
||||
register_deserializable,
|
||||
)
|
||||
|
||||
|
||||
class TestJsonSerializable(unittest.TestCase):
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_alembic_command_upgrade():
|
||||
with mock.patch("alembic.command.upgrade"):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from unittest.mock import Mock, MagicMock, patch
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import pytest
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.llm.ollama import OllamaLlm
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
from embedchain.models.data_type import (DataType, DirectDataType,
|
||||
IndirectDataType, SpecialDataType)
|
||||
from embedchain.models.data_type import (
|
||||
DataType,
|
||||
DirectDataType,
|
||||
IndirectDataType,
|
||||
SpecialDataType,
|
||||
)
|
||||
|
||||
|
||||
def test_subclass_types_in_data_type():
|
||||
|
||||
@@ -2,5 +2,5 @@ import importlib.metadata
|
||||
|
||||
__version__ = importlib.metadata.version("mem0ai")
|
||||
|
||||
from mem0.memory.main import Memory # noqa
|
||||
from mem0.client.main import MemoryClient # noqa
|
||||
from mem0.memory.main import Memory # noqa
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from mem0.memory.setup import mem0_dir
|
||||
from mem0.vector_stores.configs import VectorStoreConfig
|
||||
from mem0.llms.configs import LlmConfig
|
||||
from mem0.embeddings.configs import EmbedderConfig
|
||||
from mem0.graphs.configs import GraphStoreConfig
|
||||
from mem0.llms.configs import LlmConfig
|
||||
from mem0.memory.setup import mem0_dir
|
||||
from mem0.vector_stores.configs import VectorStoreConfig
|
||||
|
||||
|
||||
class MemoryItem(BaseModel):
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from abc import ABC
|
||||
from mem0.configs.base import AzureConfig
|
||||
from typing import Optional, Union, Dict
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from mem0.configs.base import AzureConfig
|
||||
|
||||
|
||||
class BaseEmbedderConfig(ABC):
|
||||
"""
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from abc import ABC
|
||||
from mem0.configs.base import AzureConfig
|
||||
from typing import Optional, Union, Dict
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from mem0.configs.base import AzureConfig
|
||||
|
||||
|
||||
class BaseLlmConfig(ABC):
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Optional, ClassVar, Dict, Any
|
||||
from typing import Any, ClassVar, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any, ClassVar, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing import Optional, ClassVar, Dict, Any
|
||||
|
||||
|
||||
class QdrantConfig(BaseModel):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from typing import Optional
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from mem0.configs.embeddings.base import BaseEmbedderConfig
|
||||
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from mem0.llms.configs import LlmConfig
|
||||
|
||||
|
||||
class Neo4jConfig(BaseModel):
|
||||
url: Optional[str] = Field(None, description="Host address for the graph database")
|
||||
username: Optional[str] = Field(None, description="Username for the graph database")
|
||||
|
||||
@@ -1,7 +1,4 @@
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
try:
|
||||
@@ -9,8 +6,8 @@ try:
|
||||
except ImportError:
|
||||
raise ImportError("The 'anthropic' library is required. Please install it using 'pip install anthropic'.")
|
||||
|
||||
from mem0.llms.base import LLMBase
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
|
||||
|
||||
class AnthropicLLM(LLMBase):
|
||||
|
||||
@@ -1,16 +1,14 @@
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
from typing import Dict, List, Optional, Any
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
try:
|
||||
import boto3
|
||||
except ImportError:
|
||||
raise ImportError("The 'boto3' library is required. Please install it using 'pip install boto3'.")
|
||||
|
||||
from mem0.llms.base import LLMBase
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
|
||||
|
||||
class AWSBedrockLLM(LLMBase):
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from openai import AzureOpenAI
|
||||
|
||||
from mem0.llms.base import LLMBase
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
|
||||
|
||||
class AzureOpenAILLM(LLMBase):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from typing import Optional
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
try:
|
||||
@@ -9,8 +7,8 @@ try:
|
||||
except ImportError:
|
||||
raise ImportError("The 'groq' library is required. Please install it using 'pip install groq'.")
|
||||
|
||||
from mem0.llms.base import LLMBase
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
|
||||
|
||||
class GroqLLM(LLMBase):
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import subprocess
|
||||
import sys
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
@@ -8,8 +6,8 @@ try:
|
||||
except ImportError:
|
||||
raise ImportError("The 'litellm' library is required. Please install it using 'pip install litellm'.")
|
||||
|
||||
from mem0.llms.base import LLMBase
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
|
||||
|
||||
class LiteLLM(LLMBase):
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
try:
|
||||
@@ -7,8 +5,8 @@ try:
|
||||
except ImportError:
|
||||
raise ImportError("The 'ollama' library is required. Please install it using 'pip install ollama'.")
|
||||
|
||||
from mem0.llms.base import LLMBase
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
|
||||
|
||||
class OllamaLLM(LLMBase):
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from mem0.llms.base import LLMBase
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
|
||||
|
||||
class OpenAILLM(LLMBase):
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import os, json
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from mem0.llms.base import LLMBase
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
|
||||
|
||||
class OpenAIStructuredLLM(LLMBase):
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
try:
|
||||
@@ -9,8 +7,8 @@ try:
|
||||
except ImportError:
|
||||
raise ImportError("The 'together' library is required. Please install it using 'pip install together'.")
|
||||
|
||||
from mem0.llms.base import LLMBase
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
|
||||
|
||||
class TogetherLLM(LLMBase):
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from langchain_community.graphs import Neo4jGraph
|
||||
from rank_bm25 import BM25Okapi
|
||||
from mem0.utils.factory import LlmFactory, EmbedderFactory
|
||||
from mem0.graphs.utils import get_update_memory_messages, EXTRACT_ENTITIES_PROMPT
|
||||
from mem0.graphs.tools import UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL, ADD_MESSAGE_TOOL, SEARCH_TOOL
|
||||
|
||||
from mem0.graphs.tools import (
|
||||
ADD_MEMORY_TOOL_GRAPH,
|
||||
ADD_MESSAGE_TOOL,
|
||||
NOOP_TOOL,
|
||||
SEARCH_TOOL,
|
||||
UPDATE_MEMORY_TOOL_GRAPH,
|
||||
)
|
||||
from mem0.graphs.utils import EXTRACT_ENTITIES_PROMPT, get_update_memory_messages
|
||||
from mem0.utils.factory import EmbedderFactory, LlmFactory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,22 +1,24 @@
|
||||
import logging
|
||||
import concurrent
|
||||
import hashlib
|
||||
import uuid
|
||||
import pytz
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict
|
||||
import warnings
|
||||
|
||||
import pytz
|
||||
from pydantic import ValidationError
|
||||
|
||||
from mem0.configs.base import MemoryConfig, MemoryItem
|
||||
from mem0.configs.prompts import get_update_memory_messages
|
||||
from mem0.memory.base import MemoryBase
|
||||
from mem0.memory.setup import setup_config
|
||||
from mem0.memory.storage import SQLiteManager
|
||||
from mem0.memory.telemetry import capture_event
|
||||
from mem0.memory.utils import get_fact_retrieval_messages, parse_messages
|
||||
from mem0.configs.prompts import get_update_memory_messages
|
||||
from mem0.utils.factory import LlmFactory, EmbedderFactory, VectorStoreFactory
|
||||
from mem0.configs.base import MemoryItem, MemoryConfig
|
||||
import threading
|
||||
import concurrent
|
||||
from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory
|
||||
|
||||
# Setup user config
|
||||
setup_config()
|
||||
@@ -154,7 +156,7 @@ class Memory(MemoryBase):
|
||||
logging.info(resp)
|
||||
try:
|
||||
if resp["event"] == "ADD":
|
||||
memory_id = self._create_memory(data=resp["text"], metadata=metadata)
|
||||
self._create_memory(data=resp["text"], metadata=metadata)
|
||||
elif resp["event"] == "UPDATE":
|
||||
self._update_memory(memory_id=resp["id"], data=resp["text"], metadata=metadata)
|
||||
elif resp["event"] == "DELETE":
|
||||
@@ -175,7 +177,7 @@ class Memory(MemoryBase):
|
||||
else:
|
||||
self.graph.user_id = "USER"
|
||||
data = "\n".join([msg["content"] for msg in messages if "content" in msg and msg["role"] != "system"])
|
||||
added_entities = self.graph.add(data, filters)
|
||||
self.graph.add(data, filters)
|
||||
|
||||
def get(self, memory_id):
|
||||
"""
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import platform
|
||||
import sys
|
||||
|
||||
@@ -5,6 +6,8 @@ from posthog import Posthog
|
||||
|
||||
from mem0.memory.setup import get_user_id, setup_config
|
||||
|
||||
logging.getLogger('posthog').setLevel(logging.CRITICAL + 1)
|
||||
logging.getLogger('urllib3').setLevel(logging.CRITICAL + 1)
|
||||
|
||||
class AnonymousTelemetry:
|
||||
def __init__(self, project_api_key, host):
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import logging
|
||||
import subprocess
|
||||
import sys
|
||||
import httpx
|
||||
from typing import Optional, List, Union
|
||||
import threading
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
try:
|
||||
import litellm
|
||||
@@ -20,9 +21,9 @@ except ImportError:
|
||||
raise ImportError("The required 'litellm' library is not installed.")
|
||||
sys.exit(1)
|
||||
|
||||
from mem0.memory.telemetry import capture_client_event
|
||||
from mem0 import Memory, MemoryClient
|
||||
from mem0.configs.prompts import MEMORY_ANSWER_PROMPT
|
||||
from mem0.memory.telemetry import capture_client_event
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import importlib
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.configs.embeddings.base import BaseEmbedderConfig
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
|
||||
|
||||
def load_class(class_type):
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import subprocess
|
||||
import sys
|
||||
import logging
|
||||
from typing import Optional, List, Dict
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Optional, Dict
|
||||
from typing import Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import subprocess
|
||||
import sys
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
try:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import logging
|
||||
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import (
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from mem0.llms.azure_openai import AzureOpenAILLM
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.azure_openai import AzureOpenAILLM
|
||||
|
||||
MODEL = "gpt-4o" # or your custom deployment name
|
||||
TEMPERATURE = 0.7
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from mem0.llms.groq import GroqLLM
|
||||
|
||||
import pytest
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.groq import GroqLLM
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_groq_client():
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from mem0.llms import litellm
|
||||
import pytest
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms import litellm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_litellm():
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from mem0.llms.ollama import OllamaLLM
|
||||
|
||||
import pytest
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.ollama import OllamaLLM
|
||||
from mem0.llms.utils.tools import ADD_MEMORY_TOOL
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ollama_client():
|
||||
with patch('mem0.llms.ollama.Client') as mock_ollama:
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from mem0.llms.openai import OpenAILLM
|
||||
|
||||
import pytest
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.openai import OpenAILLM
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_client():
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from mem0.llms.together import TogetherLLM
|
||||
|
||||
import pytest
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.together import TogetherLLM
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_together_client():
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from mem0.configs.prompts import MEMORY_ANSWER_PROMPT
|
||||
import pytest
|
||||
|
||||
from mem0 import Memory, MemoryClient
|
||||
from mem0.proxy.main import Mem0
|
||||
from mem0.proxy.main import Chat, Completions
|
||||
from mem0.configs.prompts import MEMORY_ANSWER_PROMPT
|
||||
from mem0.proxy.main import Chat, Completions, Mem0
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_memory_client():
|
||||
|
||||
Reference in New Issue
Block a user