Code Formatting (#1828)
@@ -3,7 +3,7 @@ repos:
     hooks:
       - id: ruff
         name: Ruff
-        entry: ruff
+        entry: ruff check
         language: system
         types: [python]
         args: [--fix]
@@ -1,6 +1,8 @@
 # This example shows how to use vector config to use QDRANT CLOUD
 import os
+
 from dotenv import load_dotenv
+
 from mem0 import Memory

 # Loading OpenAI API Key
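The blank lines added here follow the import-grouping convention that ruff's isort-style rules enforce: standard-library imports first, then third-party packages, then first-party code, each group separated by one blank line and sorted alphabetically within the group. A minimal sketch of the resulting layout, using names that appear elsewhere in this diff (it assumes those packages are installed):

# Standard library
import os
from typing import Optional

# Third-party packages
import httpx
from pydantic import BaseModel

# First-party / local code
from mem0 import Memory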
@@ -9,7 +9,6 @@ import requests
 import yaml
 from tqdm import tqdm

-from mem0 import Memory
 from embedchain.cache import (
     Config,
     ExactMatchEvaluation,
@@ -26,7 +25,11 @@ from embedchain.embedchain import EmbedChain
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.embedder.openai import OpenAIEmbedder
 from embedchain.evaluation.base import BaseMetric
-from embedchain.evaluation.metrics import AnswerRelevance, ContextRelevance, Groundedness
+from embedchain.evaluation.metrics import (
+    AnswerRelevance,
+    ContextRelevance,
+    Groundedness,
+)
 from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
 from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
@@ -36,6 +39,7 @@ from embedchain.utils.evaluation import EvalData, EvalMetric
 from embedchain.utils.misc import validate_config
 from embedchain.vectordb.base import BaseVectorDB
 from embedchain.vectordb.chroma import ChromaDB
+from mem0 import Memory

 logger = logging.getLogger(__name__)

@@ -3,8 +3,10 @@ from typing import Any
 from embedchain import App
 from embedchain.config import AddConfig, AppConfig, BaseLlmConfig
 from embedchain.embedder.openai import OpenAIEmbedder
-from embedchain.helpers.json_serializable import (JSONSerializable,
-                                                  register_deserializable)
+from embedchain.helpers.json_serializable import (
+    JSONSerializable,
+    register_deserializable,
+)
 from embedchain.llm.openai import OpenAILlm
 from embedchain.vectordb.chroma import ChromaDB

@@ -9,10 +9,12 @@ from gptcache.manager import get_data_manager
 from gptcache.manager.scalar_data.base import Answer
 from gptcache.manager.scalar_data.base import DataType as CacheDataType
 from gptcache.session import Session
-from gptcache.similarity_evaluation.distance import \
-    SearchDistanceEvaluation  # noqa: F401
-from gptcache.similarity_evaluation.exact_match import \
-    ExactMatchEvaluation  # noqa: F401
+from gptcache.similarity_evaluation.distance import (  # noqa: F401
+    SearchDistanceEvaluation,
+)
+from gptcache.similarity_evaluation.exact_match import (  # noqa: F401
+    ExactMatchEvaluation,
+)

 logger = logging.getLogger(__name__)

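Many hunks in this commit replace backslash-continued or paren-aligned imports with the parenthesized, one-name-per-line form. A minimal sketch of the two styles, shown with a standard-library module purely for illustration:

# Old style being removed: backslash continuation
from collections import \
    OrderedDict

# New style adopted throughout this commit: explicit parentheses,
# one name per line, and a trailing comma
from collections import (
    OrderedDict,
    defaultdict,
)

The trailing comma is what keeps a formatter from collapsing the list back onto one line, so adding or removing a name later shows up as a one-line diff.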
@@ -14,13 +14,21 @@ import requests
 from rich.console import Console

 from embedchain.telemetry.posthog import AnonymousTelemetry
-from embedchain.utils.cli import (deploy_fly, deploy_gradio_app,
-                                  deploy_hf_spaces, deploy_modal,
-                                  deploy_render, deploy_streamlit,
-                                  get_pkg_path_from_name, setup_fly_io_app,
-                                  setup_gradio_app, setup_hf_app,
-                                  setup_modal_com_app, setup_render_com_app,
-                                  setup_streamlit_io_app)
+from embedchain.utils.cli import (
+    deploy_fly,
+    deploy_gradio_app,
+    deploy_hf_spaces,
+    deploy_modal,
+    deploy_render,
+    deploy_streamlit,
+    get_pkg_path_from_name,
+    setup_fly_io_app,
+    setup_gradio_app,
+    setup_hf_app,
+    setup_modal_com_app,
+    setup_render_com_app,
+    setup_streamlit_io_app,
+)

 console = Console()
 api_process = None
@@ -8,8 +8,8 @@ from .embedder.base import BaseEmbedderConfig
 from .embedder.base import BaseEmbedderConfig as EmbedderConfig
 from .embedder.ollama import OllamaEmbedderConfig
 from .llm.base import BaseLlmConfig
+from .mem0_config import Mem0Config
 from .vector_db.chroma import ChromaDbConfig
 from .vector_db.elasticsearch import ElasticsearchDBConfig
 from .vector_db.opensearch import OpenSearchDBConfig
 from .vector_db.zilliz import ZillizDBConfig
-from .mem0_config import Mem0Config
@@ -1,2 +1,5 @@
-from .base import (AnswerRelevanceConfig, ContextRelevanceConfig,  # noqa: F401
-                   GroundednessConfig)
+from .base import (  # noqa: F401
+    AnswerRelevanceConfig,
+    ContextRelevanceConfig,
+    GroundednessConfig,
+)
@@ -1,9 +1,9 @@
 import json
 import logging
 import re
-from string import Template
-from typing import Any, Mapping, Optional, Dict, Union
 from pathlib import Path
+from string import Template
+from typing import Any, Dict, Mapping, Optional, Union

 import httpx

@@ -6,7 +6,12 @@ from typing import Any, Optional, Union
 from dotenv import load_dotenv
 from langchain.docstore.document import Document

-from embedchain.cache import adapt, get_gptcache_session, gptcache_data_convert, gptcache_update_cache_callback
+from embedchain.cache import (
+    adapt,
+    get_gptcache_session,
+    gptcache_data_convert,
+    gptcache_update_cache_callback,
+)
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
 from embedchain.config.base_app_config import BaseAppConfig
@@ -16,7 +21,12 @@ from embedchain.embedder.base import BaseEmbedder
 from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.llm.base import BaseLlm
 from embedchain.loaders.base_loader import BaseLoader
-from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
+from embedchain.models.data_type import (
+    DataType,
+    DirectDataType,
+    IndirectDataType,
+    SpecialDataType,
+)
 from embedchain.utils.misc import detect_datatype, is_valid_json_string
 from embedchain.vectordb.base import BaseVectorDB

@@ -1,18 +1,18 @@
 import os
 from typing import Optional, Union

+from chromadb import EmbeddingFunction, Embeddings
+
 from embedchain.config import BaseEmbedderConfig
 from embedchain.embedder.base import BaseEmbedder

-from chromadb import EmbeddingFunction, Embeddings
-

 class ClarifaiEmbeddingFunction(EmbeddingFunction):
     def __init__(self, config: BaseEmbedderConfig) -> None:
         super().__init__()
         try:
-            from clarifai.client.model import Model
             from clarifai.client.input import Inputs
+            from clarifai.client.model import Model
         except ModuleNotFoundError:
             raise ModuleNotFoundError(
                 "The required dependencies for ClarifaiEmbeddingFunction are not installed."
@@ -9,7 +9,9 @@ class GPT4AllEmbedder(BaseEmbedder):
     def __init__(self, config: Optional[BaseEmbedderConfig] = None):
         super().__init__(config=config)

-        from langchain_community.embeddings import GPT4AllEmbeddings as LangchainGPT4AllEmbeddings
+        from langchain_community.embeddings import (
+            GPT4AllEmbeddings as LangchainGPT4AllEmbeddings,
+        )

         model_name = self.config.model or "all-MiniLM-L6-v2-f16.gguf"
         gpt4all_kwargs = {'allow_download': 'True'}
@@ -3,7 +3,6 @@ from typing import Optional

 from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction

-
 from embedchain.config import BaseEmbedderConfig
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.models import VectorDimensions
@@ -45,8 +45,9 @@ class AWSBedrockLlm(BaseLlm):
         }

         if config.stream:
-            from langchain.callbacks.streaming_stdout import \
-                StreamingStdOutCallbackHandler
+            from langchain.callbacks.streaming_stdout import (
+                StreamingStdOutCallbackHandler,
+            )

             kwargs["streaming"] = True
             kwargs["callbacks"] = [StreamingStdOutCallbackHandler()]
@@ -5,7 +5,6 @@ from typing import Any, Optional

 from langchain.schema import BaseMessage as LCBaseMessage

-from embedchain.constants import SQLITE_PATH
 from embedchain.config import BaseLlmConfig
 from embedchain.config.llm.base import (
     DEFAULT_PROMPT,
@@ -13,6 +12,7 @@ from embedchain.config.llm.base import (
     DEFAULT_PROMPT_WITH_MEM0_MEMORY_TEMPLATE,
     DOCS_SITE_PROMPT_TEMPLATE,
 )
+from embedchain.constants import SQLITE_PATH
 from embedchain.core.db.database import init_db, setup_engine
 from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.memory.base import ChatHistory
@@ -26,8 +26,7 @@ class GPT4ALLLlm(BaseLlm):
     @staticmethod
     def _get_instance(model):
         try:
-            from langchain_community.llms.gpt4all import \
-                GPT4All as LangchainGPT4All
+            from langchain_community.llms.gpt4all import GPT4All as LangchainGPT4All
         except ModuleNotFoundError:
             raise ModuleNotFoundError(
                 "The GPT4All python package is not installed. Please install it with `pip install --upgrade embedchain[opensource]`"  # noqa E501
@@ -35,8 +35,9 @@ class JinaLlm(BaseLlm):
         if config.top_p:
             kwargs["model_kwargs"]["top_p"] = config.top_p
         if config.stream:
-            from langchain.callbacks.streaming_stdout import \
-                StreamingStdOutCallbackHandler
+            from langchain.callbacks.streaming_stdout import (
+                StreamingStdOutCallbackHandler,
+            )

             chat = JinaChat(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()])
         else:
@@ -1,6 +1,7 @@
 import hashlib
+
 from langchain_community.document_loaders import PyPDFLoader

 from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.utils.misc import clean_string
@@ -28,7 +28,9 @@ class RSSFeedLoader(BaseLoader):
     @staticmethod
     def get_rss_content(url: str):
         try:
-            from langchain_community.document_loaders import RSSFeedLoader as LangchainRSSFeedLoader
+            from langchain_community.document_loaders import (
+                RSSFeedLoader as LangchainRSSFeedLoader,
+            )
         except ImportError:
             raise ImportError(
                 """RSSFeedLoader file requires extra dependencies.
@@ -11,8 +11,7 @@ class UnstructuredLoader(BaseLoader):
         """Load data from an Unstructured file."""
         try:
             import unstructured  # noqa: F401
-            from langchain_community.document_loaders import \
-                UnstructuredFileLoader
+            from langchain_community.document_loaders import UnstructuredFileLoader
         except ImportError:
             raise ImportError(
                 'Unstructured file requires extra dependencies. Install with `pip install "unstructured[local-inference, all-docs]"`'  # noqa: E501
@@ -6,8 +6,15 @@ from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.vectordb.base import BaseVectorDB

 try:
-    from pymilvus import (Collection, CollectionSchema, DataType, FieldSchema,
-                          MilvusClient, connections, utility)
+    from pymilvus import (
+        Collection,
+        CollectionSchema,
+        DataType,
+        FieldSchema,
+        MilvusClient,
+        connections,
+        utility,
+    )
 except ImportError:
     raise ImportError(
         "Zilliz requires extra dependencies. Install with `pip install --upgrade embedchain[milvus]`"
@@ -8,8 +8,7 @@ import streamlit as st

 from embedchain import App
 from embedchain.config import BaseLlmConfig
-from embedchain.helpers.callbacks import (StreamingStdOutCallbackHandlerYield,
-                                          generate)
+from embedchain.helpers.callbacks import StreamingStdOutCallbackHandlerYield, generate


 def embedchain_bot(db_path, api_key):
@@ -8,8 +8,7 @@ import streamlit as st

 from embedchain import App
 from embedchain.config import BaseLlmConfig
-from embedchain.helpers.callbacks import (StreamingStdOutCallbackHandlerYield,
-                                          generate)
+from embedchain.helpers.callbacks import StreamingStdOutCallbackHandlerYield, generate


 @st.cache_resource
@@ -4,8 +4,7 @@ import streamlit as st

 from embedchain import App
 from embedchain.config import BaseLlmConfig
-from embedchain.helpers.callbacks import (StreamingStdOutCallbackHandlerYield,
-                                          generate)
+from embedchain.helpers.callbacks import StreamingStdOutCallbackHandlerYield, generate


 @st.cache_resource
@@ -1,4 +1,4 @@
-from unittest.mock import patch, Mock
+from unittest.mock import Mock, patch

 import httpx

@@ -1,5 +1,6 @@

 from unittest.mock import patch
+
 from embedchain.config import BaseEmbedderConfig
 from embedchain.embedder.huggingface import HuggingFaceEmbedder

@@ -4,8 +4,10 @@ from string import Template

 from embedchain import App
 from embedchain.config import AppConfig, BaseLlmConfig
-from embedchain.helpers.json_serializable import (JSONSerializable,
-                                                  register_deserializable)
+from embedchain.helpers.json_serializable import (
+    JSONSerializable,
+    register_deserializable,
+)


 class TestJsonSerializable(unittest.TestCase):
@@ -1,7 +1,9 @@

 from unittest import mock
+
 import pytest

+
 @pytest.fixture(autouse=True)
 def mock_alembic_command_upgrade():
     with mock.patch("alembic.command.upgrade"):
@@ -1,4 +1,4 @@
-from unittest.mock import Mock, MagicMock, patch
+from unittest.mock import MagicMock, Mock, patch

 import httpx
 import pytest
@@ -1,8 +1,8 @@
 import pytest
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

 from embedchain.config import BaseLlmConfig
 from embedchain.llm.ollama import OllamaLlm
-from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler


 @pytest.fixture
@@ -1,5 +1,9 @@
-from embedchain.models.data_type import (DataType, DirectDataType,
-                                         IndirectDataType, SpecialDataType)
+from embedchain.models.data_type import (
+    DataType,
+    DirectDataType,
+    IndirectDataType,
+    SpecialDataType,
+)


 def test_subclass_types_in_data_type():
@@ -2,5 +2,5 @@ import importlib.metadata

 __version__ = importlib.metadata.version("mem0ai")

-from mem0.memory.main import Memory  # noqa
 from mem0.client.main import MemoryClient  # noqa
+from mem0.memory.main import Memory  # noqa
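Note that the reordered lines keep their `# noqa` markers. In an `__init__.py` these imports exist only to re-export names, so without the suppression a linter would report them as unused (the flake8/ruff F401 rule) and an autofix could delete them. A hypothetical sketch of the pattern, not the actual mem0 package layout:

# mypkg/__init__.py -- illustrative only
# The imports exist so callers can write `from mypkg import Memory`;
# the "# noqa" comments mark them as intentional re-exports (F401).
from mypkg.client import MemoryClient  # noqa
from mypkg.memory import Memory  # noqa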
@@ -1,12 +1,13 @@
 import os
 from typing import Any, Dict, Optional

 from pydantic import BaseModel, Field

-from mem0.memory.setup import mem0_dir
-from mem0.vector_stores.configs import VectorStoreConfig
-from mem0.llms.configs import LlmConfig
 from mem0.embeddings.configs import EmbedderConfig
 from mem0.graphs.configs import GraphStoreConfig
+from mem0.llms.configs import LlmConfig
+from mem0.memory.setup import mem0_dir
+from mem0.vector_stores.configs import VectorStoreConfig

+
 class MemoryItem(BaseModel):
@@ -1,9 +1,10 @@
 from abc import ABC
-from mem0.configs.base import AzureConfig
-from typing import Optional, Union, Dict
+from typing import Dict, Optional, Union

 import httpx

+from mem0.configs.base import AzureConfig
+

 class BaseEmbedderConfig(ABC):
     """
@@ -1,9 +1,10 @@
 from abc import ABC
-from mem0.configs.base import AzureConfig
-from typing import Optional, Union, Dict
+from typing import Dict, Optional, Union

 import httpx

+from mem0.configs.base import AzureConfig
+

 class BaseLlmConfig(ABC):
     """
@@ -1,6 +1,6 @@
 import subprocess
 import sys
-from typing import Optional, ClassVar, Dict, Any
+from typing import Any, ClassVar, Dict, Optional

 from pydantic import BaseModel, Field, model_validator

@@ -1,4 +1,4 @@
-from typing import Optional, Dict, Any
+from typing import Any, Dict, Optional

 from pydantic import BaseModel, Field, model_validator

@@ -1,5 +1,6 @@
+from typing import Any, ClassVar, Dict, Optional
+
 from pydantic import BaseModel, Field, model_validator
-from typing import Optional, ClassVar, Dict, Any


 class QdrantConfig(BaseModel):
@@ -1,5 +1,5 @@
-from typing import Optional
 from abc import ABC, abstractmethod
+from typing import Optional

 from mem0.configs.embeddings.base import BaseEmbedderConfig

@@ -1,7 +1,10 @@
 from typing import Optional
+
 from pydantic import BaseModel, Field, field_validator, model_validator
+
 from mem0.llms.configs import LlmConfig

+
 class Neo4jConfig(BaseModel):
     url: Optional[str] = Field(None, description="Host address for the graph database")
     username: Optional[str] = Field(None, description="Username for the graph database")
@@ -1,7 +1,4 @@
-import subprocess
-import sys
 import os
-import json
 from typing import Dict, List, Optional

 try:
@@ -9,8 +6,8 @@ try:
 except ImportError:
     raise ImportError("The 'anthropic' library is required. Please install it using 'pip install anthropic'.")

-from mem0.llms.base import LLMBase
 from mem0.configs.llms.base import BaseLlmConfig
+from mem0.llms.base import LLMBase


 class AnthropicLLM(LLMBase):
@@ -1,16 +1,14 @@
-import subprocess
-import sys
-import os
 import json
-from typing import Dict, List, Optional, Any
+import os
+from typing import Any, Dict, List, Optional

 try:
     import boto3
 except ImportError:
     raise ImportError("The 'boto3' library is required. Please install it using 'pip install boto3'.")

-from mem0.llms.base import LLMBase
 from mem0.configs.llms.base import BaseLlmConfig
+from mem0.llms.base import LLMBase


 class AWSBedrockLLM(LLMBase):
@@ -1,11 +1,11 @@
-import os
 import json
+import os
 from typing import Dict, List, Optional

 from openai import AzureOpenAI

-from mem0.llms.base import LLMBase
 from mem0.configs.llms.base import BaseLlmConfig
+from mem0.llms.base import LLMBase


 class AzureOpenAILLM(LLMBase):
@@ -1,5 +1,5 @@
-from typing import Optional
 from abc import ABC, abstractmethod
+from typing import Optional

 from mem0.configs.llms.base import BaseLlmConfig

@@ -1,7 +1,5 @@
-import subprocess
-import sys
-import os
 import json
+import os
 from typing import Dict, List, Optional

 try:
@@ -9,8 +7,8 @@ try:
 except ImportError:
     raise ImportError("The 'groq' library is required. Please install it using 'pip install groq'.")

-from mem0.llms.base import LLMBase
 from mem0.configs.llms.base import BaseLlmConfig
+from mem0.llms.base import LLMBase


 class GroqLLM(LLMBase):
@@ -1,5 +1,3 @@
-import subprocess
-import sys
 import json
 from typing import Dict, List, Optional

@@ -8,8 +6,8 @@ try:
 except ImportError:
     raise ImportError("The 'litellm' library is required. Please install it using 'pip install litellm'.")

-from mem0.llms.base import LLMBase
 from mem0.configs.llms.base import BaseLlmConfig
+from mem0.llms.base import LLMBase


 class LiteLLM(LLMBase):
@@ -1,5 +1,3 @@
-import subprocess
-import sys
 from typing import Dict, List, Optional

 try:
@@ -7,8 +5,8 @@ try:
 except ImportError:
     raise ImportError("The 'ollama' library is required. Please install it using 'pip install ollama'.")

-from mem0.llms.base import LLMBase
 from mem0.configs.llms.base import BaseLlmConfig
+from mem0.llms.base import LLMBase


 class OllamaLLM(LLMBase):
@@ -1,11 +1,11 @@
-import os
 import json
+import os
 from typing import Dict, List, Optional

 from openai import OpenAI

-from mem0.llms.base import LLMBase
 from mem0.configs.llms.base import BaseLlmConfig
+from mem0.llms.base import LLMBase


 class OpenAILLM(LLMBase):
@@ -1,10 +1,11 @@
-import os, json
+import json
+import os
 from typing import Dict, List, Optional

 from openai import OpenAI

-from mem0.llms.base import LLMBase
 from mem0.configs.llms.base import BaseLlmConfig
+from mem0.llms.base import LLMBase


 class OpenAIStructuredLLM(LLMBase):
@@ -1,7 +1,5 @@
-import subprocess
-import sys
-import os
 import json
+import os
 from typing import Dict, List, Optional

 try:
@@ -9,8 +7,8 @@ try:
 except ImportError:
     raise ImportError("The 'together' library is required. Please install it using 'pip install together'.")

-from mem0.llms.base import LLMBase
 from mem0.configs.llms.base import BaseLlmConfig
+from mem0.llms.base import LLMBase


 class TogetherLLM(LLMBase):
@@ -1,10 +1,17 @@
-import json
 import logging
+
 from langchain_community.graphs import Neo4jGraph
 from rank_bm25 import BM25Okapi
-from mem0.utils.factory import LlmFactory, EmbedderFactory
-from mem0.graphs.utils import get_update_memory_messages, EXTRACT_ENTITIES_PROMPT
-from mem0.graphs.tools import UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL, ADD_MESSAGE_TOOL, SEARCH_TOOL
+
+from mem0.graphs.tools import (
+    ADD_MEMORY_TOOL_GRAPH,
+    ADD_MESSAGE_TOOL,
+    NOOP_TOOL,
+    SEARCH_TOOL,
+    UPDATE_MEMORY_TOOL_GRAPH,
+)
+from mem0.graphs.utils import EXTRACT_ENTITIES_PROMPT, get_update_memory_messages
+from mem0.utils.factory import EmbedderFactory, LlmFactory

 logger = logging.getLogger(__name__)

@@ -1,22 +1,24 @@
-import logging
+import concurrent
 import hashlib
-import uuid
-import pytz
 import json
+import logging
+import threading
+import uuid
+import warnings
 from datetime import datetime
 from typing import Any, Dict
-import warnings
+
+import pytz
 from pydantic import ValidationError
+
+from mem0.configs.base import MemoryConfig, MemoryItem
+from mem0.configs.prompts import get_update_memory_messages
 from mem0.memory.base import MemoryBase
 from mem0.memory.setup import setup_config
 from mem0.memory.storage import SQLiteManager
 from mem0.memory.telemetry import capture_event
 from mem0.memory.utils import get_fact_retrieval_messages, parse_messages
-from mem0.configs.prompts import get_update_memory_messages
-from mem0.utils.factory import LlmFactory, EmbedderFactory, VectorStoreFactory
-from mem0.configs.base import MemoryItem, MemoryConfig
-import threading
-import concurrent
+from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory

 # Setup user config
 setup_config()
@@ -154,7 +156,7 @@ class Memory(MemoryBase):
             logging.info(resp)
             try:
                 if resp["event"] == "ADD":
-                    memory_id = self._create_memory(data=resp["text"], metadata=metadata)
+                    self._create_memory(data=resp["text"], metadata=metadata)
                 elif resp["event"] == "UPDATE":
                     self._update_memory(memory_id=resp["id"], data=resp["text"], metadata=metadata)
                 elif resp["event"] == "DELETE":
@@ -175,7 +177,7 @@ class Memory(MemoryBase):
         else:
             self.graph.user_id = "USER"
         data = "\n".join([msg["content"] for msg in messages if "content" in msg and msg["role"] != "system"])
-        added_entities = self.graph.add(data, filters)
+        self.graph.add(data, filters)

     def get(self, memory_id):
         """
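The two one-line changes above drop assignments (`memory_id = ...`, `added_entities = ...`) whose values were never read afterwards, the pattern linters report as an unused local variable (ruff's F841 rule). A small self-contained illustration with made-up names, not code from mem0:

import logging


def _create_record(records: list, text: str) -> int:
    # Hypothetical helper: stores the text and returns its index.
    records.append(text)
    return len(records) - 1


def create_and_log(records: list, text: str) -> None:
    # Before this kind of cleanup, the return value was bound to a name
    # that nothing used, e.g.:
    #     record_id = _create_record(records, text)   # flagged (F841)
    # After: call the helper only for its side effect.
    _create_record(records, text)
    logging.info("stored %r", text)


store: list = []
create_and_log(store, "example")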
@@ -1,3 +1,4 @@
+import logging
 import platform
 import sys

@@ -5,6 +6,8 @@ from posthog import Posthog

 from mem0.memory.setup import get_user_id, setup_config

+logging.getLogger('posthog').setLevel(logging.CRITICAL + 1)
+logging.getLogger('urllib3').setLevel(logging.CRITICAL + 1)

 class AnonymousTelemetry:
     def __init__(self, project_api_key, host):
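The two added lines silence the posthog and urllib3 loggers by raising their level above CRITICAL (numeric value 50), the highest standard level, so no record from those libraries passes the threshold while the application's own logging is unaffected. A minimal stand-alone sketch of the same technique, using a hypothetical logger name:

import logging

logging.basicConfig(level=logging.DEBUG)

noisy = logging.getLogger("some.noisy.dependency")  # hypothetical name
noisy.setLevel(logging.CRITICAL + 1)  # 51: even CRITICAL (50) records are dropped

noisy.error("this message is suppressed")
logging.getLogger(__name__).info("application logging still works")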
@@ -1,9 +1,10 @@
 import logging
 import subprocess
 import sys
-import httpx
-from typing import Optional, List, Union
 import threading
+from typing import List, Optional, Union
+
+import httpx

 try:
     import litellm
@@ -20,9 +21,9 @@ except ImportError:
     raise ImportError("The required 'litellm' library is not installed.")
     sys.exit(1)

-from mem0.memory.telemetry import capture_client_event
 from mem0 import Memory, MemoryClient
 from mem0.configs.prompts import MEMORY_ANSWER_PROMPT
+from mem0.memory.telemetry import capture_client_event

 logger = logging.getLogger(__name__)

@@ -1,7 +1,7 @@
 import importlib

-from mem0.configs.llms.base import BaseLlmConfig
 from mem0.configs.embeddings.base import BaseEmbedderConfig
+from mem0.configs.llms.base import BaseLlmConfig


 def load_class(class_type):
@@ -1,7 +1,5 @@
-import subprocess
-import sys
 import logging
-from typing import Optional, List, Dict
+from typing import Dict, List, Optional

 from pydantic import BaseModel

@@ -1,4 +1,5 @@
-from typing import Optional, Dict
+from typing import Dict, Optional

 from pydantic import BaseModel, Field, model_validator

+
@@ -1,8 +1,7 @@
-import subprocess
-import sys
 import json
 import logging
-from typing import Optional, List
+from typing import List, Optional
+
 from pydantic import BaseModel

 try:
@@ -1,6 +1,6 @@
+import logging
 import os
 import shutil
-import logging

 from qdrant_client import QdrantClient
 from qdrant_client.models import (
@@ -1,10 +1,11 @@

+from unittest.mock import Mock, patch
+
 import httpx
 import pytest
-from unittest.mock import Mock, patch

-from mem0.llms.azure_openai import AzureOpenAILLM
 from mem0.configs.llms.base import BaseLlmConfig
+from mem0.llms.azure_openai import AzureOpenAILLM

 MODEL = "gpt-4o" # or your custom deployment name
 TEMPERATURE = 0.7
@@ -1,7 +1,10 @@
-import pytest
 from unittest.mock import Mock, patch
-from mem0.llms.groq import GroqLLM
+
+import pytest

 from mem0.configs.llms.base import BaseLlmConfig
+from mem0.llms.groq import GroqLLM


 @pytest.fixture
 def mock_groq_client():
@@ -1,8 +1,10 @@
-import pytest
 from unittest.mock import Mock, patch

-from mem0.llms import litellm
+import pytest

 from mem0.configs.llms.base import BaseLlmConfig
+from mem0.llms import litellm


 @pytest.fixture
 def mock_litellm():
@@ -1,9 +1,12 @@
-import pytest
 from unittest.mock import Mock, patch
-from mem0.llms.ollama import OllamaLLM
+
+import pytest

 from mem0.configs.llms.base import BaseLlmConfig
+from mem0.llms.ollama import OllamaLLM
 from mem0.llms.utils.tools import ADD_MEMORY_TOOL


 @pytest.fixture
 def mock_ollama_client():
     with patch('mem0.llms.ollama.Client') as mock_ollama:
@@ -1,7 +1,10 @@
-import pytest
 from unittest.mock import Mock, patch
-from mem0.llms.openai import OpenAILLM
+
+import pytest

 from mem0.configs.llms.base import BaseLlmConfig
+from mem0.llms.openai import OpenAILLM


 @pytest.fixture
 def mock_openai_client():
@@ -1,7 +1,10 @@
-import pytest
 from unittest.mock import Mock, patch
-from mem0.llms.together import TogetherLLM
+
+import pytest

 from mem0.configs.llms.base import BaseLlmConfig
+from mem0.llms.together import TogetherLLM


 @pytest.fixture
 def mock_together_client():
@@ -1,10 +1,11 @@
-import pytest
 from unittest.mock import Mock, patch

-from mem0.configs.prompts import MEMORY_ANSWER_PROMPT
+import pytest

 from mem0 import Memory, MemoryClient
-from mem0.proxy.main import Mem0
-from mem0.proxy.main import Chat, Completions
+from mem0.configs.prompts import MEMORY_ANSWER_PROMPT
+from mem0.proxy.main import Chat, Completions, Mem0


 @pytest.fixture
 def mock_memory_client():