Support for langchain LLMs (#2506)

This commit is contained in:
Dev Khant
2025-04-07 11:28:30 +05:30
committed by GitHub
parent d30c78c5eb
commit 39e5cbfacc
9 changed files with 393 additions and 1 deletions

View File

@@ -109,6 +109,7 @@ Here's a comprehensive list of all parameters that can be used across different
| `deepseek_base_url` | Base URL for DeepSeek API | DeepSeek |
| `xai_base_url` | Base URL for XAI API | XAI |
| `lmstudio_base_url` | Base URL for LM Studio API | LM Studio |
| `langchain_provider` | Provider for Langchain | Langchain |
</Tab>
<Tab title="TypeScript">
| Parameter | Description | Provider |

View File

@@ -0,0 +1,72 @@
---
title: LangChain
---
Mem0 supports LangChain as a provider to access a wide range of LLM models. LangChain is a framework for developing applications powered by language models, making it easy to integrate various LLM providers through a consistent interface.
For a complete list of available chat models supported by LangChain, refer to the [LangChain Chat Models documentation](https://python.langchain.com/docs/integrations/chat).
## Usage
<CodeGroup>
```python Python
import os
from mem0 import Memory
# Set necessary environment variables for your chosen LangChain provider
# For example, if using OpenAI through LangChain:
os.environ["OPENAI_API_KEY"] = "your-api-key"
config = {
"llm": {
"provider": "langchain",
"config": {
"langchain_provider": "OpenAI",
"model": "gpt-4o",
"temperature": 0.2,
"max_tokens": 2000,
}
}
}
m = Memory.from_config(config)
messages = [
{"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
{"role": "assistant", "content": "How about a thriller movies? They can be quite engaging."},
{"role": "user", "content": "I'm not a big fan of thriller movies but I love sci-fi movies."},
{"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."}
]
m.add(messages, user_id="alice", metadata={"category": "movies"})
```
</CodeGroup>
## Supported LangChain Providers
LangChain supports a wide range of LLM providers, including:
- OpenAI (`ChatOpenAI`)
- Anthropic (`ChatAnthropic`)
- Google (`ChatGoogleGenerativeAI`, `ChatGooglePalm`)
- Mistral (`ChatMistralAI`)
- Ollama (`ChatOllama`)
- Azure OpenAI (`AzureChatOpenAI`)
- HuggingFace (`HuggingFaceChatEndpoint`)
- And many more
You can specify any supported provider in the `langchain_provider` parameter. For a complete and up-to-date list of available providers, refer to the [LangChain Chat Models documentation](https://python.langchain.com/docs/integrations/chat).
## Provider-Specific Configuration
When using LangChain as a provider, you'll need to:
1. Set the appropriate environment variables for your chosen LLM provider
2. Specify the LangChain provider class name in the `langchain_provider` parameter
3. Include any additional configuration parameters required by the specific provider
<Note>
Make sure to install the necessary LangChain packages and any provider-specific dependencies.
</Note>
## Config
All available parameters for the `langchain` config are present in [Master List of All Params in Config](../config).

View File

@@ -33,6 +33,7 @@ To view all supported llms, visit the [Supported LLMs](./models).
<Card title="DeepSeek" href="/components/llms/models/deepseek" />
<Card title="xAI" href="/components/llms/models/xAI" />
<Card title="LM Studio" href="/components/llms/models/lmstudio" />
<Card title="Langchain" href="/components/llms/models/langchain" />
</CardGroup>
## Structured vs Unstructured Outputs

View File

@@ -111,7 +111,8 @@
"components/llms/models/gemini",
"components/llms/models/deepseek",
"components/llms/models/xAI",
"components/llms/models/lmstudio"
"components/llms/models/lmstudio",
"components/llms/models/langchain"
]
}
]

View File

@@ -41,6 +41,8 @@ class BaseLlmConfig(ABC):
xai_base_url: Optional[str] = None,
# LM Studio specific
lmstudio_base_url: Optional[str] = "http://localhost:1234/v1",
# Langchain specific
langchain_provider: Optional[str] = None,
):
"""
Initializes a configuration class instance for the LLM.
@@ -87,6 +89,8 @@ class BaseLlmConfig(ABC):
:type xai_base_url: Optional[str], optional
:param lmstudio_base_url: LM Studio base URL to be use, defaults to "http://localhost:1234/v1"
:type lmstudio_base_url: Optional[str], optional
:param langchain_provider: Langchain provider to be use, defaults to None
:type langchain_provider: Optional[str], optional
"""
self.model = model
@@ -123,3 +127,6 @@ class BaseLlmConfig(ABC):
# LM Studio specific
self.lmstudio_base_url = lmstudio_base_url
# Langchain specific
self.langchain_provider = langchain_provider

View File

@@ -25,6 +25,7 @@ class LlmConfig(BaseModel):
"deepseek",
"xai",
"lmstudio",
"langchain",
):
return v
else:

208
mem0/llms/langchain.py Normal file
View File

@@ -0,0 +1,208 @@
from typing import Dict, List, Optional
import enum
from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase
# Default import for langchain_community
try:
from langchain_community import chat_models
except ImportError:
raise ImportError("langchain_community not found. Please install it with `pip install langchain-community`")
# Provider-specific package mapping
PROVIDER_PACKAGES = {
# "Anthropic": "langchain_anthropic", # Special handling for Anthropic with Pydantic v2
"MistralAI": "langchain_mistralai",
"Fireworks": "langchain_fireworks",
"AzureOpenAI": "langchain_openai",
"OpenAI": "langchain_openai",
"Together": "langchain_together",
"VertexAI": "langchain_google_vertexai",
"GoogleAI": "langchain_google_genai",
"Groq": "langchain_groq",
"Cohere": "langchain_cohere",
"Bedrock": "langchain_aws",
"HuggingFace": "langchain_huggingface",
"NVIDIA": "langchain_nvidia_ai_endpoints",
"Ollama": "langchain_ollama",
"AI21": "langchain_ai21",
"Upstage": "langchain_upstage",
"Databricks": "databricks_langchain",
"Watsonx": "langchain_ibm",
"xAI": "langchain_xai",
"Perplexity": "langchain_perplexity",
}
class LangchainProvider(enum.Enum):
Abso = "ChatAbso"
AI21 = "ChatAI21"
Alibaba = "ChatAlibabaCloud"
Anthropic = "ChatAnthropic"
Anyscale = "ChatAnyscale"
AzureAIChatCompletionsModel = "AzureAIChatCompletionsModel"
AzureOpenAI = "AzureChatOpenAI"
AzureMLEndpoint = "ChatAzureMLEndpoint"
Baichuan = "ChatBaichuan"
Qianfan = "ChatQianfan"
Bedrock = "ChatBedrock"
Cerebras = "ChatCerebras"
CloudflareWorkersAI = "ChatCloudflareWorkersAI"
Cohere = "ChatCohere"
ContextualAI = "ChatContextualAI"
Coze = "ChatCoze"
Dappier = "ChatDappier"
Databricks = "ChatDatabricks"
DeepInfra = "ChatDeepInfra"
DeepSeek = "ChatDeepSeek"
EdenAI = "ChatEdenAI"
EverlyAI = "ChatEverlyAI"
Fireworks = "ChatFireworks"
Friendli = "ChatFriendli"
GigaChat = "ChatGigaChat"
Goodfire = "ChatGoodfire"
GoogleAI = "ChatGoogleAI"
VertexAI = "VertexAI"
GPTRouter = "ChatGPTRouter"
Groq = "ChatGroq"
HuggingFace = "ChatHuggingFace"
Watsonx = "ChatWatsonx"
Jina = "ChatJina"
Kinetica = "ChatKinetica"
Konko = "ChatKonko"
LiteLLM = "ChatLiteLLM"
LiteLLMRouter = "ChatLiteLLMRouter"
Llama2Chat = "Llama2Chat"
LlamaAPI = "ChatLlamaAPI"
LlamaEdge = "ChatLlamaEdge"
LlamaCpp = "ChatLlamaCpp"
Maritalk = "ChatMaritalk"
MiniMax = "ChatMiniMax"
MistralAI = "ChatMistralAI"
MLX = "ChatMLX"
ModelScope = "ChatModelScope"
Moonshot = "ChatMoonshot"
Naver = "ChatNaver"
Netmind = "ChatNetmind"
NVIDIA = "ChatNVIDIA"
OCIModelDeployment = "ChatOCIModelDeployment"
OCIGenAI = "ChatOCIGenAI"
OctoAI = "ChatOctoAI"
Ollama = "ChatOllama"
OpenAI = "ChatOpenAI"
Outlines = "ChatOutlines"
Perplexity = "ChatPerplexity"
Pipeshift = "ChatPipeshift"
PredictionGuard = "ChatPredictionGuard"
PremAI = "ChatPremAI"
PromptLayerOpenAI = "PromptLayerChatOpenAI"
QwQ = "ChatQwQ"
Reka = "ChatReka"
RunPod = "ChatRunPod"
SambaNovaCloud = "ChatSambaNovaCloud"
SambaStudio = "ChatSambaStudio"
SeekrFlow = "ChatSeekrFlow"
SnowflakeCortex = "ChatSnowflakeCortex"
Solar = "ChatSolar"
SparkLLM = "ChatSparkLLM"
Nebula = "ChatNebula"
Hunyuan = "ChatHunyuan"
Together = "ChatTogether"
TongyiQwen = "ChatTongyiQwen"
Upstage = "ChatUpstage"
Vectara = "ChatVectara"
VLLM = "ChatVLLM"
VolcEngine = "ChatVolcEngine"
Writer = "ChatWriter"
xAI = "ChatXAI"
Xinference = "ChatXinference"
Yandex = "ChatYandex"
Yi = "ChatYi"
Yuan2 = "ChatYuan2"
ZhipuAI = "ChatZhipuAI"
class LangchainLLM(LLMBase):
def __init__(self, config: Optional[BaseLlmConfig] = None):
super().__init__(config)
provider = self.config.langchain_provider
if provider not in LangchainProvider.__members__:
raise ValueError(f"Invalid provider: {provider}")
model_name = LangchainProvider[provider].value
try:
# Check if this provider needs a specialized package
if provider in PROVIDER_PACKAGES:
package_name = PROVIDER_PACKAGES[provider]
try:
# Import the model class directly from the package
module_path = f"{package_name}"
model_class = __import__(module_path, fromlist=[model_name])
model_class = getattr(model_class, model_name)
except ImportError:
raise ImportError(
f"Package {package_name} not found. " f"Please install it with `pip install {package_name}`"
)
except AttributeError:
raise ImportError(f"Model {model_name} not found in {package_name}")
else:
# Use the default langchain_community module
if not hasattr(chat_models, model_name):
raise ImportError(f"Provider {provider} not found in langchain_community.chat_models")
model_class = getattr(chat_models, model_name)
# Initialize the model with relevant config parameters
self.langchain_model = model_class(
model=self.config.model,
temperature=self.config.temperature,
max_tokens=self.config.max_tokens,
api_key=self.config.api_key,
)
except (ImportError, AttributeError, ValueError) as e:
raise ImportError(f"Error setting up langchain model for provider {provider}: {str(e)}")
def generate_response(
self,
messages: List[Dict[str, str]],
response_format=None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
):
"""
Generate a response based on the given messages using langchain_community.
Args:
messages (list): List of message dicts containing 'role' and 'content'.
response_format (str or object, optional): Format of the response. Not used in Langchain.
tools (list, optional): List of tools that the model can call. Not used in Langchain.
tool_choice (str, optional): Tool choice method. Not used in Langchain.
Returns:
str: The generated response.
"""
try:
# Convert the messages to LangChain's tuple format
langchain_messages = []
for message in messages:
role = message["role"]
content = message["content"]
if role == "system":
langchain_messages.append(("system", content))
elif role == "user":
langchain_messages.append(("human", content))
elif role == "assistant":
langchain_messages.append(("ai", content))
if not langchain_messages:
raise ValueError("No valid messages found in the messages list")
ai_message = self.langchain_model.invoke(langchain_messages)
return ai_message.content
except Exception as e:
raise Exception(f"Error generating response using langchain model: {str(e)}")

View File

@@ -26,6 +26,7 @@ class LlmFactory:
"deepseek": "mem0.llms.deepseek.DeepSeekLLM",
"xai": "mem0.llms.xai.XAILLM",
"lmstudio": "mem0.llms.lmstudio.LMStudioLLM",
"langchain": "mem0.llms.langchain.LangchainLLM",
}
@classmethod

View File

@@ -0,0 +1,100 @@
from unittest.mock import Mock, patch
import pytest
from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.langchain import LangchainLLM
@pytest.fixture
def mock_langchain_model():
"""Mock a Langchain model for testing."""
with patch("langchain_openai.ChatOpenAI") as mock_chat_model:
mock_model = Mock()
mock_model.invoke.return_value = Mock(content="This is a test response")
mock_chat_model.return_value = mock_model
yield mock_model
def test_langchain_initialization():
"""Test that LangchainLLM initializes correctly with a valid provider."""
with patch("langchain_openai.ChatOpenAI") as mock_chat_model:
# Setup the mock model
mock_model = Mock()
mock_chat_model.return_value = mock_model
# Create a config with OpenAI provider
config = BaseLlmConfig(
model="gpt-3.5-turbo",
temperature=0.7,
max_tokens=100,
api_key="test-api-key",
langchain_provider="OpenAI"
)
# Initialize the LangchainLLM
llm = LangchainLLM(config)
# Verify the model was initialized with correct parameters
mock_chat_model.assert_called_once_with(
model="gpt-3.5-turbo",
temperature=0.7,
max_tokens=100,
api_key="test-api-key"
)
assert llm.langchain_model == mock_model
def test_generate_response(mock_langchain_model):
"""Test that generate_response correctly processes messages and returns a response."""
# Create a config with OpenAI provider
config = BaseLlmConfig(
model="gpt-3.5-turbo",
temperature=0.7,
max_tokens=100,
api_key="test-api-key",
langchain_provider="OpenAI"
)
# Initialize the LangchainLLM
with patch("langchain_openai.ChatOpenAI", return_value=mock_langchain_model):
llm = LangchainLLM(config)
# Create test messages
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing well! How can I help you?"},
{"role": "user", "content": "Tell me a joke."}
]
# Get response
response = llm.generate_response(messages)
# Verify the correct message format was passed to the model
expected_langchain_messages = [
("system", "You are a helpful assistant."),
("human", "Hello, how are you?"),
("ai", "I'm doing well! How can I help you?"),
("human", "Tell me a joke.")
]
mock_langchain_model.invoke.assert_called_once()
# Extract the first argument of the first call
actual_messages = mock_langchain_model.invoke.call_args[0][0]
assert actual_messages == expected_langchain_messages
assert response == "This is a test response"
def test_invalid_provider():
"""Test that LangchainLLM raises an error with an invalid provider."""
config = BaseLlmConfig(
model="test-model",
temperature=0.7,
max_tokens=100,
api_key="test-api-key",
langchain_provider="InvalidProvider"
)
with pytest.raises(ValueError, match="Invalid provider: InvalidProvider"):
LangchainLLM(config)