Add langchain embedding, update langchain LLM and version bump -> 0.1.84 (#2510)

Dev Khant
2025-04-07 15:27:26 +05:30
committed by GitHub
parent 5509066925
commit 9dfa9b4412
14 changed files with 266 additions and 253 deletions

View File

@@ -6,6 +6,15 @@ mode: "wide"
<Tabs>
<Tab title="Python">
<Update label="2025-04-07" description="v0.1.84">
**New Features:**
- **Langchain Embedder:** Added Langchain embedder integration
**Improvements:**
- **Langchain LLM:** Updated Langchain LLM integration to directly pass the Langchain object LLM
</Update>
<Update label="2025-04-07" description="v0.1.83">
**Bug Fixes:**

View File

@@ -0,0 +1,120 @@
---
title: LangChain
---
Mem0 supports LangChain as a provider to access a wide range of embedding models. LangChain is a framework for developing applications powered by language models, making it easy to integrate various embedding providers through a consistent interface.
For a complete list of available embedding models supported by LangChain, refer to the [LangChain Text Embedding documentation](https://python.langchain.com/docs/integrations/text_embedding/).
## Usage
<CodeGroup>
```python Python
import os
from mem0 import Memory
from langchain_openai import OpenAIEmbeddings

# Set necessary environment variables for your chosen LangChain provider
os.environ["OPENAI_API_KEY"] = "your-api-key"

# Initialize a LangChain embeddings model directly
openai_embeddings = OpenAIEmbeddings(
    model="text-embedding-3-small",
    dimensions=1536
)

# Pass the initialized model to the config
config = {
    "embedder": {
        "provider": "langchain",
        "config": {
            "model": openai_embeddings
        }
    }
}

m = Memory.from_config(config)

messages = [
    {"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
    {"role": "assistant", "content": "How about a thriller movie? They can be quite engaging."},
    {"role": "user", "content": "I'm not a big fan of thriller movies but I love sci-fi movies."},
    {"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."}
]
m.add(messages, user_id="alice", metadata={"category": "movies"})
```
</CodeGroup>
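Once memories are stored, the same `Memory` instance can retrieve them; the query is embedded with the same LangChain embedder. A minimal sketch, assuming the config above (the exact shape of the return value can differ between mem0 API versions):
```python
# Query the memories stored above; the query text is embedded via the LangChain embedder
results = m.search("What kind of movies does Alice like?", user_id="alice")
print(results)
```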
## Supported LangChain Embedding Providers
LangChain supports a wide range of embedding providers, including:
- OpenAI (`OpenAIEmbeddings`)
- Cohere (`CohereEmbeddings`)
- Google (`VertexAIEmbeddings`)
- Hugging Face (`HuggingFaceEmbeddings`)
- Sentence Transformers (`HuggingFaceEmbeddings`)
- Azure OpenAI (`AzureOpenAIEmbeddings`)
- Ollama (`OllamaEmbeddings`)
- Together (`TogetherEmbeddings`)
- And many more
You can use any of these model instances directly in your configuration. For a complete and up-to-date list of available embedding providers, refer to the [LangChain Text Embedding documentation](https://python.langchain.com/docs/integrations/text_embedding/).
## Provider-Specific Configuration
When using LangChain as an embedder provider, you'll need to:
1. Set the appropriate environment variables for your chosen embedding provider
2. Import and initialize the specific model class you want to use
3. Pass the initialized model instance to the config
### Examples with Different Providers
#### HuggingFace Embeddings
```python
from langchain_huggingface import HuggingFaceEmbeddings

# Initialize a HuggingFace embeddings model
hf_embeddings = HuggingFaceEmbeddings(
    model_name="BAAI/bge-small-en-v1.5",
    encode_kwargs={"normalize_embeddings": True}
)

config = {
    "embedder": {
        "provider": "langchain",
        "config": {
            "model": hf_embeddings
        }
    }
}
```
#### Ollama Embeddings
```python
from langchain_ollama import OllamaEmbeddings

# Initialize an Ollama embeddings model
ollama_embeddings = OllamaEmbeddings(
    model="nomic-embed-text"
)

config = {
    "embedder": {
        "provider": "langchain",
        "config": {
            "model": ollama_embeddings
        }
    }
}
```
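#### Azure OpenAI Embeddings
Azure OpenAI follows the same pattern. The deployment name, API version, and endpoint below are placeholders, and the exact constructor arguments may vary with your `langchain_openai` version:
```python
from langchain_openai import AzureOpenAIEmbeddings

# Placeholder deployment/endpoint values; adjust to your Azure resource
azure_embeddings = AzureOpenAIEmbeddings(
    azure_deployment="your-embedding-deployment",
    openai_api_version="2024-02-01",
    azure_endpoint="https://your-resource.openai.azure.com/"
)

config = {
    "embedder": {
        "provider": "langchain",
        "config": {
            "model": azure_embeddings
        }
    }
}
```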
<Note>
Make sure to install the necessary LangChain packages and any provider-specific dependencies.
</Note>
## Config
All available parameters for the `langchain` embedder config are present in [Master List of All Params in Config](../config).
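As a fuller sketch, the LangChain embedder can be combined with the LangChain LLM updated in this release in a single config; this assumes `langchain_openai` is installed and `OPENAI_API_KEY` is set:
```python
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from mem0 import Memory

# Sketch: LangChain chat model and embedder both passed as instances
config = {
    "llm": {
        "provider": "langchain",
        "config": {"model": ChatOpenAI(model="gpt-4o")}
    },
    "embedder": {
        "provider": "langchain",
        "config": {"model": OpenAIEmbeddings(model="text-embedding-3-small")}
    }
}
m = Memory.from_config(config)
```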

View File

@@ -23,6 +23,7 @@ See the list of supported embedders below.
<Card title="Vertex AI" href="/components/embedders/models/vertexai"></Card> <Card title="Vertex AI" href="/components/embedders/models/vertexai"></Card>
<Card title="Together" href="/components/embedders/models/together"></Card> <Card title="Together" href="/components/embedders/models/together"></Card>
<Card title="LM Studio" href="/components/embedders/models/lmstudio"></Card> <Card title="LM Studio" href="/components/embedders/models/lmstudio"></Card>
<Card title="Langchain" href="/components/embedders/models/langchain"></Card>
</CardGroup> </CardGroup>
## Usage ## Usage

View File

@@ -109,7 +109,6 @@ Here's a comprehensive list of all parameters that can be used across different
| `deepseek_base_url` | Base URL for DeepSeek API | DeepSeek |
| `xai_base_url` | Base URL for XAI API | XAI |
| `lmstudio_base_url` | Base URL for LM Studio API | LM Studio |
| `langchain_provider` | Provider for Langchain | Langchain |
</Tab>
<Tab title="TypeScript">
| Parameter | Description | Provider |

View File

@@ -12,19 +12,24 @@ For a complete list of available chat models supported by LangChain, refer to th
```python Python
import os
from mem0 import Memory
from langchain_openai import ChatOpenAI

# Set necessary environment variables for your chosen LangChain provider
# For example, if using OpenAI through LangChain:
os.environ["OPENAI_API_KEY"] = "your-api-key"

# Initialize a LangChain model directly
openai_model = ChatOpenAI(
    model="gpt-4o",
    temperature=0.2,
    max_tokens=2000
)

# Pass the initialized model to the config
config = {
    "llm": {
        "provider": "langchain",
        "config": {
            "langchain_provider": "OpenAI",
            "model": "gpt-4o",
            "temperature": 0.2,
            "max_tokens": 2000,
        }
    }
}
config = {
    "llm": {
        "provider": "langchain",
        "config": {
            "model": openai_model
        }
    }
}
@@ -53,15 +58,15 @@ LangChain supports a wide range of LLM providers, including:
- HuggingFace (`HuggingFaceChatEndpoint`)
- And many more

You can specify any supported provider in the `langchain_provider` parameter. For a complete and up-to-date list of available providers, refer to the [LangChain Chat Models documentation](https://python.langchain.com/docs/integrations/chat).
You can use any of these model instances directly in your configuration. For a complete and up-to-date list of available providers, refer to the [LangChain Chat Models documentation](https://python.langchain.com/docs/integrations/chat).

## Provider-Specific Configuration

When using LangChain as a provider, you'll need to:
1. Set the appropriate environment variables for your chosen LLM provider
2. Specify the LangChain provider class name in the `langchain_provider` parameter
3. Include any additional configuration parameters required by the specific provider
2. Import and initialize the specific model class you want to use
3. Pass the initialized model instance to the config

<Note>
Make sure to install the necessary LangChain packages and any provider-specific dependencies.
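To illustrate the new pattern with a provider other than OpenAI, a sketch assuming `langchain_anthropic` is installed and `ANTHROPIC_API_KEY` is set:
```python
from langchain_anthropic import ChatAnthropic
from mem0 import Memory

# Sketch: any LangChain chat model instance can be passed the same way
claude_model = ChatAnthropic(model="claude-3-5-sonnet-20240620", temperature=0.2)

config = {
    "llm": {
        "provider": "langchain",
        "config": {"model": claude_model}
    }
}
m = Memory.from_config(config)
```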

View File

@@ -161,7 +161,8 @@
"components/embedders/models/vertexai", "components/embedders/models/vertexai",
"components/embedders/models/gemini", "components/embedders/models/gemini",
"components/embedders/models/lmstudio", "components/embedders/models/lmstudio",
"components/embedders/models/together" "components/embedders/models/together",
"components/embedders/models/langchain"
] ]
} }
] ]

View File

@@ -13,7 +13,7 @@ class BaseLlmConfig(ABC):
    def __init__(
        self,
        model: Optional[str] = None,
        model: Optional[Union[str, Dict]] = None,
        temperature: float = 0.1,
        api_key: Optional[str] = None,
        max_tokens: int = 2000,
@@ -41,8 +41,6 @@ class BaseLlmConfig(ABC):
        xai_base_url: Optional[str] = None,
        # LM Studio specific
        lmstudio_base_url: Optional[str] = "http://localhost:1234/v1",
        # Langchain specific
        langchain_provider: Optional[str] = None,
    ):
        """
        Initializes a configuration class instance for the LLM.
@@ -89,8 +87,6 @@ class BaseLlmConfig(ABC):
        :type xai_base_url: Optional[str], optional
        :param lmstudio_base_url: LM Studio base URL to be use, defaults to "http://localhost:1234/v1"
        :type lmstudio_base_url: Optional[str], optional
        :param langchain_provider: Langchain provider to be use, defaults to None
        :type langchain_provider: Optional[str], optional
        """
        self.model = model
@@ -127,6 +123,3 @@ class BaseLlmConfig(ABC):
        # LM Studio specific
        self.lmstudio_base_url = lmstudio_base_url
        # Langchain specific
        self.langchain_provider = langchain_provider

View File

@@ -22,6 +22,7 @@ class EmbedderConfig(BaseModel):
"vertexai", "vertexai",
"together", "together",
"lmstudio", "lmstudio",
"langchain",
]: ]:
return v return v
else: else:

View File

@@ -0,0 +1,36 @@
import os
from typing import Literal, Optional

from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.base import EmbeddingBase

try:
    from langchain.embeddings.base import Embeddings
except ImportError:
    raise ImportError("langchain is not installed. Please install it using `pip install langchain`")


class LangchainEmbedding(EmbeddingBase):
    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config)

        if self.config.model is None:
            raise ValueError("`model` parameter is required")

        if not isinstance(self.config.model, Embeddings):
            raise ValueError("`model` must be an instance of Embeddings")

        self.langchain_model = self.config.model

    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Get the embedding for the given text using Langchain.

        Args:
            text (str): The text to embed.
            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.

        Returns:
            list: The embedding vector.
        """
        return self.langchain_model.embed_query(text)
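For reference, a minimal sketch of exercising `LangchainEmbedding` directly, assuming `langchain_openai` is installed and `OPENAI_API_KEY` is set; in normal use the `EmbedderFactory` constructs it from the `embedder` section of the config instead:
```python
from langchain_openai import OpenAIEmbeddings

from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.langchain import LangchainEmbedding

# Wrap a LangChain embeddings instance in the mem0 embedder config
config = BaseEmbedderConfig(model=OpenAIEmbeddings(model="text-embedding-3-small"))
embedder = LangchainEmbedding(config)

vector = embedder.embed("I love sci-fi movies")  # list of floats from embed_query
```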

View File

@@ -1,174 +1,25 @@
from typing import Dict, List, Optional
import enum

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase

# Default import for langchain_community
try:
    from langchain_community import chat_models
except ImportError:
    raise ImportError("langchain_community not found. Please install it with `pip install langchain-community`")

try:
    from langchain.chat_models.base import BaseChatModel
except ImportError:
    raise ImportError("langchain is not installed. Please install it using `pip install langchain`")
# Provider-specific package mapping
PROVIDER_PACKAGES = {
"Anthropic": "langchain_anthropic",
"MistralAI": "langchain_mistralai",
"Fireworks": "langchain_fireworks",
"AzureOpenAI": "langchain_openai",
"OpenAI": "langchain_openai",
"Together": "langchain_together",
"VertexAI": "langchain_google_vertexai",
"GoogleAI": "langchain_google_genai",
"Groq": "langchain_groq",
"Cohere": "langchain_cohere",
"Bedrock": "langchain_aws",
"HuggingFace": "langchain_huggingface",
"NVIDIA": "langchain_nvidia_ai_endpoints",
"Ollama": "langchain_ollama",
"AI21": "langchain_ai21",
"Upstage": "langchain_upstage",
"Databricks": "databricks_langchain",
"Watsonx": "langchain_ibm",
"xAI": "langchain_xai",
"Perplexity": "langchain_perplexity",
}
class LangchainProvider(enum.Enum):
Abso = "ChatAbso"
AI21 = "ChatAI21"
Alibaba = "ChatAlibabaCloud"
Anthropic = "ChatAnthropic"
Anyscale = "ChatAnyscale"
AzureAIChatCompletionsModel = "AzureAIChatCompletionsModel"
AzureOpenAI = "AzureChatOpenAI"
AzureMLEndpoint = "ChatAzureMLEndpoint"
Baichuan = "ChatBaichuan"
Qianfan = "ChatQianfan"
Bedrock = "ChatBedrock"
Cerebras = "ChatCerebras"
CloudflareWorkersAI = "ChatCloudflareWorkersAI"
Cohere = "ChatCohere"
ContextualAI = "ChatContextualAI"
Coze = "ChatCoze"
Dappier = "ChatDappier"
Databricks = "ChatDatabricks"
DeepInfra = "ChatDeepInfra"
DeepSeek = "ChatDeepSeek"
EdenAI = "ChatEdenAI"
EverlyAI = "ChatEverlyAI"
Fireworks = "ChatFireworks"
Friendli = "ChatFriendli"
GigaChat = "ChatGigaChat"
Goodfire = "ChatGoodfire"
GoogleAI = "ChatGoogleAI"
VertexAI = "VertexAI"
GPTRouter = "ChatGPTRouter"
Groq = "ChatGroq"
HuggingFace = "ChatHuggingFace"
Watsonx = "ChatWatsonx"
Jina = "ChatJina"
Kinetica = "ChatKinetica"
Konko = "ChatKonko"
LiteLLM = "ChatLiteLLM"
LiteLLMRouter = "ChatLiteLLMRouter"
Llama2Chat = "Llama2Chat"
LlamaAPI = "ChatLlamaAPI"
LlamaEdge = "ChatLlamaEdge"
LlamaCpp = "ChatLlamaCpp"
Maritalk = "ChatMaritalk"
MiniMax = "ChatMiniMax"
MistralAI = "ChatMistralAI"
MLX = "ChatMLX"
ModelScope = "ChatModelScope"
Moonshot = "ChatMoonshot"
Naver = "ChatNaver"
Netmind = "ChatNetmind"
NVIDIA = "ChatNVIDIA"
OCIModelDeployment = "ChatOCIModelDeployment"
OCIGenAI = "ChatOCIGenAI"
OctoAI = "ChatOctoAI"
Ollama = "ChatOllama"
OpenAI = "ChatOpenAI"
Outlines = "ChatOutlines"
Perplexity = "ChatPerplexity"
Pipeshift = "ChatPipeshift"
PredictionGuard = "ChatPredictionGuard"
PremAI = "ChatPremAI"
PromptLayerOpenAI = "PromptLayerChatOpenAI"
QwQ = "ChatQwQ"
Reka = "ChatReka"
RunPod = "ChatRunPod"
SambaNovaCloud = "ChatSambaNovaCloud"
SambaStudio = "ChatSambaStudio"
SeekrFlow = "ChatSeekrFlow"
SnowflakeCortex = "ChatSnowflakeCortex"
Solar = "ChatSolar"
SparkLLM = "ChatSparkLLM"
Nebula = "ChatNebula"
Hunyuan = "ChatHunyuan"
Together = "ChatTogether"
TongyiQwen = "ChatTongyiQwen"
Upstage = "ChatUpstage"
Vectara = "ChatVectara"
VLLM = "ChatVLLM"
VolcEngine = "ChatVolcEngine"
Writer = "ChatWriter"
xAI = "ChatXAI"
Xinference = "ChatXinference"
Yandex = "ChatYandex"
Yi = "ChatYi"
Yuan2 = "ChatYuan2"
ZhipuAI = "ChatZhipuAI"
class LangchainLLM(LLMBase):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        super().__init__(config)

        provider = self.config.langchain_provider
        if provider not in LangchainProvider.__members__:
            raise ValueError(f"Invalid provider: {provider}")

        model_name = LangchainProvider[provider].value

        try:
            # Check if this provider needs a specialized package
            if provider in PROVIDER_PACKAGES:
                if provider == "Anthropic":  # Special handling for Anthropic with Pydantic v2
                    try:
                        from langchain_anthropic import ChatAnthropic
                        model_class = ChatAnthropic
                    except ImportError:
                        raise ImportError("langchain_anthropic not found. Please install it with `pip install langchain-anthropic`")
                else:
                    package_name = PROVIDER_PACKAGES[provider]
                    try:
                        # Import the model class directly from the package
                        module_path = f"{package_name}"
                        model_class = __import__(module_path, fromlist=[model_name])
                        model_class = getattr(model_class, model_name)
                    except ImportError:
                        raise ImportError(
                            f"Package {package_name} not found. " f"Please install it with `pip install {package_name}`"
                        )
                    except AttributeError:
                        raise ImportError(f"Model {model_name} not found in {package_name}")
            else:
                # Use the default langchain_community module
                if not hasattr(chat_models, model_name):
                    raise ImportError(f"Provider {provider} not found in langchain_community.chat_models")
                model_class = getattr(chat_models, model_name)

            # Initialize the model with relevant config parameters
            self.langchain_model = model_class(
                model=self.config.model,
                temperature=self.config.temperature,
                max_tokens=self.config.max_tokens
            )
        except (ImportError, AttributeError, ValueError) as e:
            raise ImportError(f"Error setting up langchain model for provider {provider}: {str(e)}")

class LangchainLLM(LLMBase):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        super().__init__(config)

        if self.config.model is None:
            raise ValueError("`model` parameter is required")

        if not isinstance(self.config.model, BaseChatModel):
            raise ValueError("`model` must be an instance of BaseChatModel")

        self.langchain_model = self.config.model

    def generate_response(
        self,

View File

@@ -623,14 +623,13 @@ class Memory(MemoryBase):
        capture_event("mem0._create_memory", self, {"memory_id": memory_id})
        return memory_id

    def _create_procedural_memory(self, messages, metadata=None, llm=None, prompt=None):
    def _create_procedural_memory(self, messages, metadata=None, prompt=None):
        """
        Create a procedural memory

        Args:
            messages (list): List of messages to create a procedural memory from.
            metadata (dict): Metadata to create a procedural memory from.
            llm (BaseChatModel, optional): LLM class to use for generating procedural memories. Defaults to None. Useful when user is using LangChain ChatModel.
            prompt (str, optional): Prompt to use for the procedural memory creation. Defaults to None.
        """
        try:
@@ -650,12 +649,7 @@ class Memory(MemoryBase):
        ]

        try:
            if llm is not None:
                parsed_messages = convert_to_messages(parsed_messages)
                response = llm.invoke(input=parsed_messages)
                procedural_memory = response.content
            else:
                procedural_memory = self.llm.generate_response(messages=parsed_messages)
            procedural_memory = self.llm.generate_response(messages=parsed_messages)
        except Exception as e:
            logger.error(f"Error generating procedural memory summary: {e}")
            raise
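Because the `llm` argument was removed here, a LangChain chat model now reaches procedural memory generation through the regular LLM config instead. A sketch assuming `langchain_openai` is installed; the `memory_type` and `agent_id` arguments reflect how procedural memories are requested in recent mem0 releases, so treat them as an assumption if your version differs:
```python
from langchain_openai import ChatOpenAI
from mem0 import Memory

config = {
    "llm": {
        "provider": "langchain",
        "config": {"model": ChatOpenAI(model="gpt-4o")}
    }
}
m = Memory.from_config(config)

# Procedural memory summaries are now produced by the configured LLM (self.llm) internally
m.add(
    messages=[{"role": "user", "content": "Step 1: open the dashboard and export the report."}],
    user_id="alice",
    agent_id="support-agent",
    memory_type="procedural_memory",
)
```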

View File

@@ -50,6 +50,7 @@ class EmbedderFactory:
"vertexai": "mem0.embeddings.vertexai.VertexAIEmbedding", "vertexai": "mem0.embeddings.vertexai.VertexAIEmbedding",
"together": "mem0.embeddings.together.TogetherEmbedding", "together": "mem0.embeddings.together.TogetherEmbedding",
"lmstudio": "mem0.embeddings.lmstudio.LMStudioEmbedding", "lmstudio": "mem0.embeddings.lmstudio.LMStudioEmbedding",
"langchain": "mem0.embeddings.langchain.LangchainEmbedding",
} }
@classmethod @classmethod

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "mem0ai"
version = "0.1.83"
version = "0.1.84"
description = "Long-term memory for AI Agents"
authors = ["Mem0 <founders@mem0.ai>"]
exclude = [

View File

@@ -4,97 +4,99 @@ import pytest
from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.langchain import LangchainLLM

# Add the import for BaseChatModel
try:
    from langchain.chat_models.base import BaseChatModel
except ImportError:
    from unittest.mock import MagicMock
    BaseChatModel = MagicMock


@pytest.fixture
def mock_langchain_model():
    """Mock a Langchain model for testing."""
    with patch("langchain_openai.ChatOpenAI") as mock_chat_model:
        mock_model = Mock()
        mock_model.invoke.return_value = Mock(content="This is a test response")
        mock_chat_model.return_value = mock_model
        yield mock_model
    mock_model = Mock(spec=BaseChatModel)
    mock_model.invoke.return_value = Mock(content="This is a test response")
    return mock_model


def test_langchain_initialization():
    """Test that LangchainLLM initializes correctly with a valid provider."""
    with patch("langchain_openai.ChatOpenAI") as mock_chat_model:
        # Setup the mock model
        mock_model = Mock()
        mock_chat_model.return_value = mock_model

        # Create a config with OpenAI provider
        config = BaseLlmConfig(
            model="gpt-3.5-turbo",
            temperature=0.7,
            max_tokens=100,
            api_key="test-api-key",
            langchain_provider="OpenAI"
        )

        # Initialize the LangchainLLM
        llm = LangchainLLM(config)

        # Verify the model was initialized with correct parameters
        mock_chat_model.assert_called_once_with(
            model="gpt-3.5-turbo",
            temperature=0.7,
            max_tokens=100,
            api_key="test-api-key"
        )
        assert llm.langchain_model == mock_model


def test_langchain_initialization(mock_langchain_model):
    """Test that LangchainLLM initializes correctly with a valid model."""
    # Create a config with the model instance directly
    config = BaseLlmConfig(
        model=mock_langchain_model,
        temperature=0.7,
        max_tokens=100,
        api_key="test-api-key"
    )

    # Initialize the LangchainLLM
    llm = LangchainLLM(config)

    # Verify the model was correctly assigned
    assert llm.langchain_model == mock_langchain_model


def test_generate_response(mock_langchain_model):
    """Test that generate_response correctly processes messages and returns a response."""
    # Create a config with OpenAI provider
    # Create a config with the model instance
    config = BaseLlmConfig(
        model="gpt-3.5-turbo",
        model=mock_langchain_model,
        temperature=0.7,
        max_tokens=100,
        api_key="test-api-key",
        langchain_provider="OpenAI"
        api_key="test-api-key"
    )

    # Initialize the LangchainLLM
    with patch("langchain_openai.ChatOpenAI", return_value=mock_langchain_model):
        llm = LangchainLLM(config)
    llm = LangchainLLM(config)

    # Create test messages
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello, how are you?"},
        {"role": "assistant", "content": "I'm doing well! How can I help you?"},
        {"role": "user", "content": "Tell me a joke."}
    ]

    # Get response
    response = llm.generate_response(messages)

    # Verify the correct message format was passed to the model
    expected_langchain_messages = [
        ("system", "You are a helpful assistant."),
        ("human", "Hello, how are you?"),
        ("ai", "I'm doing well! How can I help you?"),
        ("human", "Tell me a joke.")
    ]

    mock_langchain_model.invoke.assert_called_once()
    # Extract the first argument of the first call
    actual_messages = mock_langchain_model.invoke.call_args[0][0]
    assert actual_messages == expected_langchain_messages
    assert response == "This is a test response"


def test_invalid_provider():
def test_invalid_model():
    """Test that LangchainLLM raises an error with an invalid provider."""
    """Test that LangchainLLM raises an error with an invalid model."""
    config = BaseLlmConfig(
        model="test-model",
        model="not-a-valid-model-instance",
        temperature=0.7,
        max_tokens=100,
        api_key="test-api-key",
        langchain_provider="InvalidProvider"
        api_key="test-api-key"
    )

    with pytest.raises(ValueError, match="Invalid provider: InvalidProvider"):
    with pytest.raises(ValueError, match="`model` must be an instance of BaseChatModel"):
        LangchainLLM(config)


def test_missing_model():
    """Test that LangchainLLM raises an error when model is None."""
    config = BaseLlmConfig(
        model=None,
        temperature=0.7,
        max_tokens=100,
        api_key="test-api-key"
    )

    with pytest.raises(ValueError, match="`model` parameter is required"):
        LangchainLLM(config)