Support for langchain LLMs (#2506)
@@ -109,6 +109,7 @@ Here's a comprehensive list of all parameters that can be used across different
| `deepseek_base_url` | Base URL for DeepSeek API | DeepSeek |
| `xai_base_url` | Base URL for XAI API | XAI |
| `lmstudio_base_url` | Base URL for LM Studio API | LM Studio |
| `langchain_provider` | Provider for Langchain | Langchain |
</Tab>

<Tab title="TypeScript">

| Parameter | Description | Provider |

docs/components/llms/models/langchain.mdx Normal file
@@ -0,0 +1,72 @@
---
title: LangChain
---

Mem0 supports LangChain as a provider to access a wide range of LLM models. LangChain is a framework for developing applications powered by language models, making it easy to integrate various LLM providers through a consistent interface.

For a complete list of available chat models supported by LangChain, refer to the [LangChain Chat Models documentation](https://python.langchain.com/docs/integrations/chat).

## Usage

<CodeGroup>
```python Python
import os
from mem0 import Memory

# Set the environment variables required by your chosen LangChain provider
# For example, if using OpenAI through LangChain:
os.environ["OPENAI_API_KEY"] = "your-api-key"

config = {
    "llm": {
        "provider": "langchain",
        "config": {
            "langchain_provider": "OpenAI",
            "model": "gpt-4o",
            "temperature": 0.2,
            "max_tokens": 2000,
        }
    }
}

m = Memory.from_config(config)
messages = [
    {"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
    {"role": "assistant", "content": "How about thriller movies? They can be quite engaging."},
    {"role": "user", "content": "I'm not a big fan of thriller movies, but I love sci-fi movies."},
    {"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."}
]
m.add(messages, user_id="alice", metadata={"category": "movies"})
```
</CodeGroup>

## Supported LangChain Providers

LangChain supports a wide range of LLM providers, including:

- OpenAI (`ChatOpenAI`)
- Anthropic (`ChatAnthropic`)
- Google (`ChatGoogleGenerativeAI`, `ChatGooglePalm`)
- Mistral (`ChatMistralAI`)
- Ollama (`ChatOllama`)
- Azure OpenAI (`AzureChatOpenAI`)
- HuggingFace (`ChatHuggingFace`)
- And many more

You can specify any supported provider in the `langchain_provider` parameter. For a complete and up-to-date list of available providers, refer to the [LangChain Chat Models documentation](https://python.langchain.com/docs/integrations/chat).
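
The same config shape works for the other providers in the list above. As a rough sketch, a Groq-backed configuration might look like the following, assuming the `langchain-groq` package is installed, `GROQ_API_KEY` is set, and the model name is swapped for one available on your account:

```python
import os
from mem0 import Memory

# Sketch only: routes Mem0 through LangChain's Groq integration.
# Requires `pip install langchain-groq` and a GROQ_API_KEY in the environment.
config = {
    "llm": {
        "provider": "langchain",
        "config": {
            "langchain_provider": "Groq",              # maps to ChatGroq
            "model": "llama-3.3-70b-versatile",        # example model name
            "temperature": 0.2,
            "max_tokens": 2000,
            "api_key": os.environ["GROQ_API_KEY"],
        }
    }
}

m = Memory.from_config(config)
```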

## Provider-Specific Configuration

When using LangChain as a provider, you'll need to:

1. Set the appropriate environment variables for your chosen LLM provider
2. Specify the LangChain provider name (for example, `OpenAI`) in the `langchain_provider` parameter
3. Include any additional configuration parameters required by the specific provider

<Note>
Make sure to install the necessary LangChain packages and any provider-specific dependencies.
</Note>

## Config

All available parameters for the `langchain` config are present in [Master List of All Params in Config](../config).

@@ -33,6 +33,7 @@ To view all supported llms, visit the [Supported LLMs](./models).
<Card title="DeepSeek" href="/components/llms/models/deepseek" />
<Card title="xAI" href="/components/llms/models/xAI" />
<Card title="LM Studio" href="/components/llms/models/lmstudio" />
<Card title="Langchain" href="/components/llms/models/langchain" />
</CardGroup>

## Structured vs Unstructured Outputs

@@ -111,7 +111,8 @@
"components/llms/models/gemini",
"components/llms/models/deepseek",
"components/llms/models/xAI",
"components/llms/models/lmstudio"
"components/llms/models/lmstudio",
"components/llms/models/langchain"
]
}
]

@@ -41,6 +41,8 @@ class BaseLlmConfig(ABC):
        xai_base_url: Optional[str] = None,
        # LM Studio specific
        lmstudio_base_url: Optional[str] = "http://localhost:1234/v1",
        # Langchain specific
        langchain_provider: Optional[str] = None,
    ):
        """
        Initializes a configuration class instance for the LLM.
@@ -87,6 +89,8 @@ class BaseLlmConfig(ABC):
        :type xai_base_url: Optional[str], optional
        :param lmstudio_base_url: LM Studio base URL to be used, defaults to "http://localhost:1234/v1"
        :type lmstudio_base_url: Optional[str], optional
        :param langchain_provider: Langchain provider to be used, defaults to None
        :type langchain_provider: Optional[str], optional
        """

        self.model = model
@@ -123,3 +127,6 @@ class BaseLlmConfig(ABC):

        # LM Studio specific
        self.lmstudio_base_url = lmstudio_base_url

        # Langchain specific
        self.langchain_provider = langchain_provider

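Taken together, the new keyword threads through the constructor like the existing provider-specific options and is stored on the instance by the assignment above. A minimal illustrative sketch, mirroring how the test file added further below constructs the config:

```python
from mem0.configs.llms.base import BaseLlmConfig

# Illustrative only: langchain_provider is accepted like any other LLM option.
config = BaseLlmConfig(
    model="gpt-4o",
    temperature=0.2,
    max_tokens=2000,
    api_key="your-api-key",
    langchain_provider="OpenAI",
)
print(config.langchain_provider)  # -> "OpenAI"
print(config.lmstudio_base_url)   # -> "http://localhost:1234/v1" (default)
```
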
@@ -25,6 +25,7 @@ class LlmConfig(BaseModel):
            "deepseek",
            "xai",
            "lmstudio",
            "langchain",
        ):
            return v
        else:

mem0/llms/langchain.py Normal file
@@ -0,0 +1,208 @@
from typing import Dict, List, Optional
import enum

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase

# Default import for langchain_community
try:
    from langchain_community import chat_models
except ImportError:
    raise ImportError("langchain_community not found. Please install it with `pip install langchain-community`")

# Provider-specific package mapping
PROVIDER_PACKAGES = {
    # "Anthropic": "langchain_anthropic",  # Special handling for Anthropic with Pydantic v2
    "MistralAI": "langchain_mistralai",
    "Fireworks": "langchain_fireworks",
    "AzureOpenAI": "langchain_openai",
    "OpenAI": "langchain_openai",
    "Together": "langchain_together",
    "VertexAI": "langchain_google_vertexai",
    "GoogleAI": "langchain_google_genai",
    "Groq": "langchain_groq",
    "Cohere": "langchain_cohere",
    "Bedrock": "langchain_aws",
    "HuggingFace": "langchain_huggingface",
    "NVIDIA": "langchain_nvidia_ai_endpoints",
    "Ollama": "langchain_ollama",
    "AI21": "langchain_ai21",
    "Upstage": "langchain_upstage",
    "Databricks": "databricks_langchain",
    "Watsonx": "langchain_ibm",
    "xAI": "langchain_xai",
    "Perplexity": "langchain_perplexity",
}


class LangchainProvider(enum.Enum):
    Abso = "ChatAbso"
    AI21 = "ChatAI21"
    Alibaba = "ChatAlibabaCloud"
    Anthropic = "ChatAnthropic"
    Anyscale = "ChatAnyscale"
    AzureAIChatCompletionsModel = "AzureAIChatCompletionsModel"
    AzureOpenAI = "AzureChatOpenAI"
    AzureMLEndpoint = "ChatAzureMLEndpoint"
    Baichuan = "ChatBaichuan"
    Qianfan = "ChatQianfan"
    Bedrock = "ChatBedrock"
    Cerebras = "ChatCerebras"
    CloudflareWorkersAI = "ChatCloudflareWorkersAI"
    Cohere = "ChatCohere"
    ContextualAI = "ChatContextualAI"
    Coze = "ChatCoze"
    Dappier = "ChatDappier"
    Databricks = "ChatDatabricks"
    DeepInfra = "ChatDeepInfra"
    DeepSeek = "ChatDeepSeek"
    EdenAI = "ChatEdenAI"
    EverlyAI = "ChatEverlyAI"
    Fireworks = "ChatFireworks"
    Friendli = "ChatFriendli"
    GigaChat = "ChatGigaChat"
    Goodfire = "ChatGoodfire"
    GoogleAI = "ChatGoogleAI"
    VertexAI = "VertexAI"
    GPTRouter = "ChatGPTRouter"
    Groq = "ChatGroq"
    HuggingFace = "ChatHuggingFace"
    Watsonx = "ChatWatsonx"
    Jina = "ChatJina"
    Kinetica = "ChatKinetica"
    Konko = "ChatKonko"
    LiteLLM = "ChatLiteLLM"
    LiteLLMRouter = "ChatLiteLLMRouter"
    Llama2Chat = "Llama2Chat"
    LlamaAPI = "ChatLlamaAPI"
    LlamaEdge = "ChatLlamaEdge"
    LlamaCpp = "ChatLlamaCpp"
    Maritalk = "ChatMaritalk"
    MiniMax = "ChatMiniMax"
    MistralAI = "ChatMistralAI"
    MLX = "ChatMLX"
    ModelScope = "ChatModelScope"
    Moonshot = "ChatMoonshot"
    Naver = "ChatNaver"
    Netmind = "ChatNetmind"
    NVIDIA = "ChatNVIDIA"
    OCIModelDeployment = "ChatOCIModelDeployment"
    OCIGenAI = "ChatOCIGenAI"
    OctoAI = "ChatOctoAI"
    Ollama = "ChatOllama"
    OpenAI = "ChatOpenAI"
    Outlines = "ChatOutlines"
    Perplexity = "ChatPerplexity"
    Pipeshift = "ChatPipeshift"
    PredictionGuard = "ChatPredictionGuard"
    PremAI = "ChatPremAI"
    PromptLayerOpenAI = "PromptLayerChatOpenAI"
    QwQ = "ChatQwQ"
    Reka = "ChatReka"
    RunPod = "ChatRunPod"
    SambaNovaCloud = "ChatSambaNovaCloud"
    SambaStudio = "ChatSambaStudio"
    SeekrFlow = "ChatSeekrFlow"
    SnowflakeCortex = "ChatSnowflakeCortex"
    Solar = "ChatSolar"
    SparkLLM = "ChatSparkLLM"
    Nebula = "ChatNebula"
    Hunyuan = "ChatHunyuan"
    Together = "ChatTogether"
    TongyiQwen = "ChatTongyiQwen"
    Upstage = "ChatUpstage"
    Vectara = "ChatVectara"
    VLLM = "ChatVLLM"
    VolcEngine = "ChatVolcEngine"
    Writer = "ChatWriter"
    xAI = "ChatXAI"
    Xinference = "ChatXinference"
    Yandex = "ChatYandex"
    Yi = "ChatYi"
    Yuan2 = "ChatYuan2"
    ZhipuAI = "ChatZhipuAI"


class LangchainLLM(LLMBase):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        super().__init__(config)

        provider = self.config.langchain_provider
        if provider not in LangchainProvider.__members__:
            raise ValueError(f"Invalid provider: {provider}")
        model_name = LangchainProvider[provider].value

        try:
            # Check if this provider needs a specialized package
            if provider in PROVIDER_PACKAGES:
                package_name = PROVIDER_PACKAGES[provider]
                try:
                    # Import the model class directly from the package
                    module_path = f"{package_name}"
                    model_class = __import__(module_path, fromlist=[model_name])
                    model_class = getattr(model_class, model_name)
                except ImportError:
                    raise ImportError(
                        f"Package {package_name} not found. Please install it with `pip install {package_name}`"
                    )
                except AttributeError:
                    raise ImportError(f"Model {model_name} not found in {package_name}")
            else:
                # Use the default langchain_community module
                if not hasattr(chat_models, model_name):
                    raise ImportError(f"Provider {provider} not found in langchain_community.chat_models")

                model_class = getattr(chat_models, model_name)

            # Initialize the model with relevant config parameters
            self.langchain_model = model_class(
                model=self.config.model,
                temperature=self.config.temperature,
                max_tokens=self.config.max_tokens,
                api_key=self.config.api_key,
            )
        except (ImportError, AttributeError, ValueError) as e:
            raise ImportError(f"Error setting up langchain model for provider {provider}: {str(e)}")

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        response_format=None,
        tools: Optional[List[Dict]] = None,
        tool_choice: str = "auto",
    ):
        """
        Generate a response based on the given messages using langchain_community.

        Args:
            messages (list): List of message dicts containing 'role' and 'content'.
            response_format (str or object, optional): Format of the response. Not used in Langchain.
            tools (list, optional): List of tools that the model can call. Not used in Langchain.
            tool_choice (str, optional): Tool choice method. Not used in Langchain.

        Returns:
            str: The generated response.
        """
        try:
            # Convert the messages to LangChain's tuple format
            langchain_messages = []
            for message in messages:
                role = message["role"]
                content = message["content"]

                if role == "system":
                    langchain_messages.append(("system", content))
                elif role == "user":
                    langchain_messages.append(("human", content))
                elif role == "assistant":
                    langchain_messages.append(("ai", content))

            if not langchain_messages:
                raise ValueError("No valid messages found in the messages list")

            ai_message = self.langchain_model.invoke(langchain_messages)

            return ai_message.content

        except Exception as e:
            raise Exception(f"Error generating response using langchain model: {str(e)}")
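Outside of the `Memory` flow, the class can also be driven directly, which is essentially what the tests below do against a mock. A rough sketch with a placeholder API key and an example model name:

```python
from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.langchain import LangchainLLM

# Sketch of direct usage; requires `langchain-openai` and a real OpenAI API key.
config = BaseLlmConfig(
    model="gpt-4o-mini",          # example model name
    temperature=0.2,
    max_tokens=200,
    api_key="your-api-key",       # placeholder
    langchain_provider="OpenAI",
)
llm = LangchainLLM(config)

reply = llm.generate_response(
    [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Say hello in one short sentence."},
    ]
)
print(reply)
```
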
@@ -26,6 +26,7 @@ class LlmFactory:
        "deepseek": "mem0.llms.deepseek.DeepSeekLLM",
        "xai": "mem0.llms.xai.XAILLM",
        "lmstudio": "mem0.llms.lmstudio.LMStudioLLM",
        "langchain": "mem0.llms.langchain.LangchainLLM",
    }

    @classmethod
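With this mapping in place, the `"langchain"` provider string resolves to `LangchainLLM` through the factory. A sketch, assuming the factory lives at `mem0.utils.factory` and keeps its `create(provider_name, config)` classmethod (neither is shown in this diff):

```python
from mem0.utils.factory import LlmFactory  # assumed location of LlmFactory

# The config dict mirrors the "llm.config" block from the docs example above.
llm = LlmFactory.create(
    "langchain",
    {
        "langchain_provider": "OpenAI",
        "model": "gpt-4o",
        "temperature": 0.2,
        "max_tokens": 2000,
        "api_key": "your-api-key",
    },
)
print(type(llm).__name__)  # -> "LangchainLLM"
```
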
tests/llms/test_langchain.py Normal file
@@ -0,0 +1,100 @@
from unittest.mock import Mock, patch
import pytest

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.langchain import LangchainLLM


@pytest.fixture
def mock_langchain_model():
    """Mock a Langchain model for testing."""
    with patch("langchain_openai.ChatOpenAI") as mock_chat_model:
        mock_model = Mock()
        mock_model.invoke.return_value = Mock(content="This is a test response")
        mock_chat_model.return_value = mock_model
        yield mock_model


def test_langchain_initialization():
    """Test that LangchainLLM initializes correctly with a valid provider."""
    with patch("langchain_openai.ChatOpenAI") as mock_chat_model:
        # Setup the mock model
        mock_model = Mock()
        mock_chat_model.return_value = mock_model

        # Create a config with OpenAI provider
        config = BaseLlmConfig(
            model="gpt-3.5-turbo",
            temperature=0.7,
            max_tokens=100,
            api_key="test-api-key",
            langchain_provider="OpenAI",
        )

        # Initialize the LangchainLLM
        llm = LangchainLLM(config)

        # Verify the model was initialized with correct parameters
        mock_chat_model.assert_called_once_with(
            model="gpt-3.5-turbo",
            temperature=0.7,
            max_tokens=100,
            api_key="test-api-key",
        )

        assert llm.langchain_model == mock_model


def test_generate_response(mock_langchain_model):
    """Test that generate_response correctly processes messages and returns a response."""
    # Create a config with OpenAI provider
    config = BaseLlmConfig(
        model="gpt-3.5-turbo",
        temperature=0.7,
        max_tokens=100,
        api_key="test-api-key",
        langchain_provider="OpenAI",
    )

    # Initialize the LangchainLLM
    with patch("langchain_openai.ChatOpenAI", return_value=mock_langchain_model):
        llm = LangchainLLM(config)

    # Create test messages
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello, how are you?"},
        {"role": "assistant", "content": "I'm doing well! How can I help you?"},
        {"role": "user", "content": "Tell me a joke."},
    ]

    # Get response
    response = llm.generate_response(messages)

    # Verify the correct message format was passed to the model
    expected_langchain_messages = [
        ("system", "You are a helpful assistant."),
        ("human", "Hello, how are you?"),
        ("ai", "I'm doing well! How can I help you?"),
        ("human", "Tell me a joke."),
    ]

    mock_langchain_model.invoke.assert_called_once()
    # Extract the first argument of the first call
    actual_messages = mock_langchain_model.invoke.call_args[0][0]
    assert actual_messages == expected_langchain_messages
    assert response == "This is a test response"


def test_invalid_provider():
    """Test that LangchainLLM raises an error with an invalid provider."""
    config = BaseLlmConfig(
        model="test-model",
        temperature=0.7,
        max_tokens=100,
        api_key="test-api-key",
        langchain_provider="InvalidProvider",
    )

    with pytest.raises(ValueError, match="Invalid provider: InvalidProvider"):
        LangchainLLM(config)