From 39e5cbfacc4e487773fe98fba0718e2ae5a5c901 Mon Sep 17 00:00:00 2001
From: Dev Khant
Date: Mon, 7 Apr 2025 11:28:30 +0530
Subject: [PATCH] Support for langchain LLMs (#2506)

---
 docs/components/llms/config.mdx           |   1 +
 docs/components/llms/models/langchain.mdx |  72 ++++++++
 docs/components/llms/overview.mdx         |   1 +
 docs/docs.json                            |   3 +-
 mem0/configs/llms/base.py                 |   7 +
 mem0/llms/configs.py                      |   1 +
 mem0/llms/langchain.py                    | 208 ++++++++++++++++++++++
 mem0/utils/factory.py                     |   1 +
 tests/llms/test_langchain.py              | 100 +++++++++++
 9 files changed, 393 insertions(+), 1 deletion(-)
 create mode 100644 docs/components/llms/models/langchain.mdx
 create mode 100644 mem0/llms/langchain.py
 create mode 100644 tests/llms/test_langchain.py

diff --git a/docs/components/llms/config.mdx b/docs/components/llms/config.mdx
index 57c8f70c..77b75229 100644
--- a/docs/components/llms/config.mdx
+++ b/docs/components/llms/config.mdx
@@ -109,6 +109,7 @@ Here's a comprehensive list of all parameters that can be used across different
 | `deepseek_base_url` | Base URL for DeepSeek API | DeepSeek |
 | `xai_base_url` | Base URL for XAI API | XAI |
 | `lmstudio_base_url` | Base URL for LM Studio API | LM Studio |
+| `langchain_provider` | Provider for Langchain | Langchain |
 | Parameter | Description | Provider |
diff --git a/docs/components/llms/models/langchain.mdx b/docs/components/llms/models/langchain.mdx
new file mode 100644
index 00000000..5cc6566a
--- /dev/null
+++ b/docs/components/llms/models/langchain.mdx
@@ -0,0 +1,72 @@
+---
+title: LangChain
+---
+
+Mem0 supports LangChain as a provider to access a wide range of LLM models. LangChain is a framework for developing applications powered by language models, making it easy to integrate various LLM providers through a consistent interface.
+
+For a complete list of available chat models supported by LangChain, refer to the [LangChain Chat Models documentation](https://python.langchain.com/docs/integrations/chat).
+
+## Usage
+
+```python Python
+import os
+from mem0 import Memory
+
+# Set the environment variables required by your chosen LangChain provider
+# For example, if using OpenAI through LangChain:
+os.environ["OPENAI_API_KEY"] = "your-api-key"
+
+config = {
+    "llm": {
+        "provider": "langchain",
+        "config": {
+            "langchain_provider": "OpenAI",
+            "model": "gpt-4o",
+            "temperature": 0.2,
+            "max_tokens": 2000,
+        }
+    }
+}
+
+m = Memory.from_config(config)
+messages = [
+    {"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
+    {"role": "assistant", "content": "How about a thriller movie? They can be quite engaging."},
+    {"role": "user", "content": "I'm not a big fan of thriller movies but I love sci-fi movies."},
+    {"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."}
+]
+m.add(messages, user_id="alice", metadata={"category": "movies"})
+```
+
+## Supported LangChain Providers
+
+LangChain supports a wide range of LLM providers, including:
+
+- OpenAI (`ChatOpenAI`)
+- Anthropic (`ChatAnthropic`)
+- Google (`ChatGoogleGenerativeAI`, `ChatGooglePalm`)
+- Mistral (`ChatMistralAI`)
+- Ollama (`ChatOllama`)
+- Azure OpenAI (`AzureChatOpenAI`)
+- HuggingFace (`ChatHuggingFace`)
+- And many more
+
+You can specify any supported provider in the `langchain_provider` parameter. For a complete and up-to-date list of available providers, refer to the [LangChain Chat Models documentation](https://python.langchain.com/docs/integrations/chat).
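+
+The same pattern works for any provider in this list: only `langchain_provider`, the model name, and the provider's credentials change. For example, a Groq-backed configuration might look like the sketch below (the model name is illustrative — substitute one enabled for your account, and install `langchain-groq` first):
+
+```python Python
+import os
+from mem0 import Memory
+
+# Set the Groq API key in the environment, mirroring the OpenAI example above
+os.environ["GROQ_API_KEY"] = "your-api-key"
+
+config = {
+    "llm": {
+        "provider": "langchain",
+        "config": {
+            "langchain_provider": "Groq",
+            "model": "llama-3.1-8b-instant",  # illustrative model name
+            "temperature": 0.2,
+            "max_tokens": 2000,
+        }
+    }
+}
+
+m = Memory.from_config(config)
+messages = [{"role": "user", "content": "I love sci-fi movies."}]
+m.add(messages, user_id="alice", metadata={"category": "movies"})
+```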
+ +## Provider-Specific Configuration + +When using LangChain as a provider, you'll need to: + +1. Set the appropriate environment variables for your chosen LLM provider +2. Specify the LangChain provider class name in the `langchain_provider` parameter +3. Include any additional configuration parameters required by the specific provider + + + Make sure to install the necessary LangChain packages and any provider-specific dependencies. + + +## Config + +All available parameters for the `langchain` config are present in [Master List of All Params in Config](../config). diff --git a/docs/components/llms/overview.mdx b/docs/components/llms/overview.mdx index 1c2d8828..a719bc72 100644 --- a/docs/components/llms/overview.mdx +++ b/docs/components/llms/overview.mdx @@ -33,6 +33,7 @@ To view all supported llms, visit the [Supported LLMs](./models). + ## Structured vs Unstructured Outputs diff --git a/docs/docs.json b/docs/docs.json index 8ffa66b0..7e092213 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -111,7 +111,8 @@ "components/llms/models/gemini", "components/llms/models/deepseek", "components/llms/models/xAI", - "components/llms/models/lmstudio" + "components/llms/models/lmstudio", + "components/llms/models/langchain" ] } ] diff --git a/mem0/configs/llms/base.py b/mem0/configs/llms/base.py index e71cfdbc..b78d44fe 100644 --- a/mem0/configs/llms/base.py +++ b/mem0/configs/llms/base.py @@ -41,6 +41,8 @@ class BaseLlmConfig(ABC): xai_base_url: Optional[str] = None, # LM Studio specific lmstudio_base_url: Optional[str] = "http://localhost:1234/v1", + # Langchain specific + langchain_provider: Optional[str] = None, ): """ Initializes a configuration class instance for the LLM. @@ -87,6 +89,8 @@ class BaseLlmConfig(ABC): :type xai_base_url: Optional[str], optional :param lmstudio_base_url: LM Studio base URL to be use, defaults to "http://localhost:1234/v1" :type lmstudio_base_url: Optional[str], optional + :param langchain_provider: Langchain provider to be use, defaults to None + :type langchain_provider: Optional[str], optional """ self.model = model @@ -123,3 +127,6 @@ class BaseLlmConfig(ABC): # LM Studio specific self.lmstudio_base_url = lmstudio_base_url + + # Langchain specific + self.langchain_provider = langchain_provider diff --git a/mem0/llms/configs.py b/mem0/llms/configs.py index e94ef43f..615fc37a 100644 --- a/mem0/llms/configs.py +++ b/mem0/llms/configs.py @@ -25,6 +25,7 @@ class LlmConfig(BaseModel): "deepseek", "xai", "lmstudio", + "langchain", ): return v else: diff --git a/mem0/llms/langchain.py b/mem0/llms/langchain.py new file mode 100644 index 00000000..788af96b --- /dev/null +++ b/mem0/llms/langchain.py @@ -0,0 +1,208 @@ +from typing import Dict, List, Optional +import enum + +from mem0.configs.llms.base import BaseLlmConfig +from mem0.llms.base import LLMBase + +# Default import for langchain_community +try: + from langchain_community import chat_models +except ImportError: + raise ImportError("langchain_community not found. 
Please install it with `pip install langchain-community`") + +# Provider-specific package mapping +PROVIDER_PACKAGES = { + # "Anthropic": "langchain_anthropic", # Special handling for Anthropic with Pydantic v2 + "MistralAI": "langchain_mistralai", + "Fireworks": "langchain_fireworks", + "AzureOpenAI": "langchain_openai", + "OpenAI": "langchain_openai", + "Together": "langchain_together", + "VertexAI": "langchain_google_vertexai", + "GoogleAI": "langchain_google_genai", + "Groq": "langchain_groq", + "Cohere": "langchain_cohere", + "Bedrock": "langchain_aws", + "HuggingFace": "langchain_huggingface", + "NVIDIA": "langchain_nvidia_ai_endpoints", + "Ollama": "langchain_ollama", + "AI21": "langchain_ai21", + "Upstage": "langchain_upstage", + "Databricks": "databricks_langchain", + "Watsonx": "langchain_ibm", + "xAI": "langchain_xai", + "Perplexity": "langchain_perplexity", +} + + +class LangchainProvider(enum.Enum): + Abso = "ChatAbso" + AI21 = "ChatAI21" + Alibaba = "ChatAlibabaCloud" + Anthropic = "ChatAnthropic" + Anyscale = "ChatAnyscale" + AzureAIChatCompletionsModel = "AzureAIChatCompletionsModel" + AzureOpenAI = "AzureChatOpenAI" + AzureMLEndpoint = "ChatAzureMLEndpoint" + Baichuan = "ChatBaichuan" + Qianfan = "ChatQianfan" + Bedrock = "ChatBedrock" + Cerebras = "ChatCerebras" + CloudflareWorkersAI = "ChatCloudflareWorkersAI" + Cohere = "ChatCohere" + ContextualAI = "ChatContextualAI" + Coze = "ChatCoze" + Dappier = "ChatDappier" + Databricks = "ChatDatabricks" + DeepInfra = "ChatDeepInfra" + DeepSeek = "ChatDeepSeek" + EdenAI = "ChatEdenAI" + EverlyAI = "ChatEverlyAI" + Fireworks = "ChatFireworks" + Friendli = "ChatFriendli" + GigaChat = "ChatGigaChat" + Goodfire = "ChatGoodfire" + GoogleAI = "ChatGoogleAI" + VertexAI = "VertexAI" + GPTRouter = "ChatGPTRouter" + Groq = "ChatGroq" + HuggingFace = "ChatHuggingFace" + Watsonx = "ChatWatsonx" + Jina = "ChatJina" + Kinetica = "ChatKinetica" + Konko = "ChatKonko" + LiteLLM = "ChatLiteLLM" + LiteLLMRouter = "ChatLiteLLMRouter" + Llama2Chat = "Llama2Chat" + LlamaAPI = "ChatLlamaAPI" + LlamaEdge = "ChatLlamaEdge" + LlamaCpp = "ChatLlamaCpp" + Maritalk = "ChatMaritalk" + MiniMax = "ChatMiniMax" + MistralAI = "ChatMistralAI" + MLX = "ChatMLX" + ModelScope = "ChatModelScope" + Moonshot = "ChatMoonshot" + Naver = "ChatNaver" + Netmind = "ChatNetmind" + NVIDIA = "ChatNVIDIA" + OCIModelDeployment = "ChatOCIModelDeployment" + OCIGenAI = "ChatOCIGenAI" + OctoAI = "ChatOctoAI" + Ollama = "ChatOllama" + OpenAI = "ChatOpenAI" + Outlines = "ChatOutlines" + Perplexity = "ChatPerplexity" + Pipeshift = "ChatPipeshift" + PredictionGuard = "ChatPredictionGuard" + PremAI = "ChatPremAI" + PromptLayerOpenAI = "PromptLayerChatOpenAI" + QwQ = "ChatQwQ" + Reka = "ChatReka" + RunPod = "ChatRunPod" + SambaNovaCloud = "ChatSambaNovaCloud" + SambaStudio = "ChatSambaStudio" + SeekrFlow = "ChatSeekrFlow" + SnowflakeCortex = "ChatSnowflakeCortex" + Solar = "ChatSolar" + SparkLLM = "ChatSparkLLM" + Nebula = "ChatNebula" + Hunyuan = "ChatHunyuan" + Together = "ChatTogether" + TongyiQwen = "ChatTongyiQwen" + Upstage = "ChatUpstage" + Vectara = "ChatVectara" + VLLM = "ChatVLLM" + VolcEngine = "ChatVolcEngine" + Writer = "ChatWriter" + xAI = "ChatXAI" + Xinference = "ChatXinference" + Yandex = "ChatYandex" + Yi = "ChatYi" + Yuan2 = "ChatYuan2" + ZhipuAI = "ChatZhipuAI" + + +class LangchainLLM(LLMBase): + def __init__(self, config: Optional[BaseLlmConfig] = None): + super().__init__(config) + + provider = self.config.langchain_provider + if provider not in 
LangchainProvider.__members__: + raise ValueError(f"Invalid provider: {provider}") + model_name = LangchainProvider[provider].value + + try: + # Check if this provider needs a specialized package + if provider in PROVIDER_PACKAGES: + package_name = PROVIDER_PACKAGES[provider] + try: + # Import the model class directly from the package + module_path = f"{package_name}" + model_class = __import__(module_path, fromlist=[model_name]) + model_class = getattr(model_class, model_name) + except ImportError: + raise ImportError( + f"Package {package_name} not found. " f"Please install it with `pip install {package_name}`" + ) + except AttributeError: + raise ImportError(f"Model {model_name} not found in {package_name}") + else: + # Use the default langchain_community module + if not hasattr(chat_models, model_name): + raise ImportError(f"Provider {provider} not found in langchain_community.chat_models") + + model_class = getattr(chat_models, model_name) + + # Initialize the model with relevant config parameters + self.langchain_model = model_class( + model=self.config.model, + temperature=self.config.temperature, + max_tokens=self.config.max_tokens, + api_key=self.config.api_key, + ) + except (ImportError, AttributeError, ValueError) as e: + raise ImportError(f"Error setting up langchain model for provider {provider}: {str(e)}") + + def generate_response( + self, + messages: List[Dict[str, str]], + response_format=None, + tools: Optional[List[Dict]] = None, + tool_choice: str = "auto", + ): + """ + Generate a response based on the given messages using langchain_community. + + Args: + messages (list): List of message dicts containing 'role' and 'content'. + response_format (str or object, optional): Format of the response. Not used in Langchain. + tools (list, optional): List of tools that the model can call. Not used in Langchain. + tool_choice (str, optional): Tool choice method. Not used in Langchain. + + Returns: + str: The generated response. 
+ """ + try: + # Convert the messages to LangChain's tuple format + langchain_messages = [] + for message in messages: + role = message["role"] + content = message["content"] + + if role == "system": + langchain_messages.append(("system", content)) + elif role == "user": + langchain_messages.append(("human", content)) + elif role == "assistant": + langchain_messages.append(("ai", content)) + + if not langchain_messages: + raise ValueError("No valid messages found in the messages list") + + ai_message = self.langchain_model.invoke(langchain_messages) + + return ai_message.content + + except Exception as e: + raise Exception(f"Error generating response using langchain model: {str(e)}") diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py index 0bb76797..31c2f3c6 100644 --- a/mem0/utils/factory.py +++ b/mem0/utils/factory.py @@ -26,6 +26,7 @@ class LlmFactory: "deepseek": "mem0.llms.deepseek.DeepSeekLLM", "xai": "mem0.llms.xai.XAILLM", "lmstudio": "mem0.llms.lmstudio.LMStudioLLM", + "langchain": "mem0.llms.langchain.LangchainLLM", } @classmethod diff --git a/tests/llms/test_langchain.py b/tests/llms/test_langchain.py new file mode 100644 index 00000000..67b00e6d --- /dev/null +++ b/tests/llms/test_langchain.py @@ -0,0 +1,100 @@ +from unittest.mock import Mock, patch +import pytest + +from mem0.configs.llms.base import BaseLlmConfig +from mem0.llms.langchain import LangchainLLM + + +@pytest.fixture +def mock_langchain_model(): + """Mock a Langchain model for testing.""" + with patch("langchain_openai.ChatOpenAI") as mock_chat_model: + mock_model = Mock() + mock_model.invoke.return_value = Mock(content="This is a test response") + mock_chat_model.return_value = mock_model + yield mock_model + + +def test_langchain_initialization(): + """Test that LangchainLLM initializes correctly with a valid provider.""" + with patch("langchain_openai.ChatOpenAI") as mock_chat_model: + # Setup the mock model + mock_model = Mock() + mock_chat_model.return_value = mock_model + + # Create a config with OpenAI provider + config = BaseLlmConfig( + model="gpt-3.5-turbo", + temperature=0.7, + max_tokens=100, + api_key="test-api-key", + langchain_provider="OpenAI" + ) + + # Initialize the LangchainLLM + llm = LangchainLLM(config) + + # Verify the model was initialized with correct parameters + mock_chat_model.assert_called_once_with( + model="gpt-3.5-turbo", + temperature=0.7, + max_tokens=100, + api_key="test-api-key" + ) + + assert llm.langchain_model == mock_model + + +def test_generate_response(mock_langchain_model): + """Test that generate_response correctly processes messages and returns a response.""" + # Create a config with OpenAI provider + config = BaseLlmConfig( + model="gpt-3.5-turbo", + temperature=0.7, + max_tokens=100, + api_key="test-api-key", + langchain_provider="OpenAI" + ) + + # Initialize the LangchainLLM + with patch("langchain_openai.ChatOpenAI", return_value=mock_langchain_model): + llm = LangchainLLM(config) + + # Create test messages + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm doing well! How can I help you?"}, + {"role": "user", "content": "Tell me a joke."} + ] + + # Get response + response = llm.generate_response(messages) + + # Verify the correct message format was passed to the model + expected_langchain_messages = [ + ("system", "You are a helpful assistant."), + ("human", "Hello, how are you?"), + ("ai", "I'm doing well! 
How can I help you?"), + ("human", "Tell me a joke.") + ] + + mock_langchain_model.invoke.assert_called_once() + # Extract the first argument of the first call + actual_messages = mock_langchain_model.invoke.call_args[0][0] + assert actual_messages == expected_langchain_messages + assert response == "This is a test response" + + +def test_invalid_provider(): + """Test that LangchainLLM raises an error with an invalid provider.""" + config = BaseLlmConfig( + model="test-model", + temperature=0.7, + max_tokens=100, + api_key="test-api-key", + langchain_provider="InvalidProvider" + ) + + with pytest.raises(ValueError, match="Invalid provider: InvalidProvider"): + LangchainLLM(config)