From 9dfa9b441243ba9c595e75ad285968a3cdbf146e Mon Sep 17 00:00:00 2001
From: Dev Khant
Date: Mon, 7 Apr 2025 15:27:26 +0530
Subject: [PATCH] Add langchain embedding, update langchain LLM and version bump -> 0.1.84 (#2510)

---
 docs/changelog/overview.mdx | 9 +
 .../components/embedders/models/langchain.mdx | 120 +++++++++++++
 docs/components/embedders/overview.mdx | 1 +
 docs/components/llms/config.mdx | 1 -
 docs/components/llms/models/langchain.mdx | 21 ++-
 docs/docs.json | 3 +-
 mem0/configs/llms/base.py | 9 +-
 mem0/embeddings/configs.py | 1 +
 mem0/embeddings/langchain.py | 36 ++++
 mem0/llms/langchain.py | 163 +-----------------
 mem0/memory/main.py | 10 +-
 mem0/utils/factory.py | 1 +
 pyproject.toml | 2 +-
 tests/llms/test_langchain.py | 142 +++++++--------
 14 files changed, 266 insertions(+), 253 deletions(-)
 create mode 100644 docs/components/embedders/models/langchain.mdx
 create mode 100644 mem0/embeddings/langchain.py

diff --git a/docs/changelog/overview.mdx b/docs/changelog/overview.mdx
index 96e708b3..5969a178 100644
--- a/docs/changelog/overview.mdx
+++ b/docs/changelog/overview.mdx
@@ -6,6 +6,15 @@ mode: "wide"
+
+
+**New Features:**
+- **Langchain Embedder:** Added Langchain embedder integration
+
+**Improvements:**
+- **Langchain LLM:** Updated the Langchain LLM integration to accept a Langchain LLM object passed directly in the config
+
+
 **Bug Fixes:**

diff --git a/docs/components/embedders/models/langchain.mdx b/docs/components/embedders/models/langchain.mdx
new file mode 100644
index 00000000..00dc4c77
--- /dev/null
+++ b/docs/components/embedders/models/langchain.mdx
@@ -0,0 +1,120 @@
+---
+title: LangChain
+---
+
+Mem0 supports LangChain as a provider to access a wide range of embedding models. LangChain is a framework for developing applications powered by language models, making it easy to integrate various embedding providers through a consistent interface.
+
+For a complete list of available embedding models supported by LangChain, refer to the [LangChain Text Embedding documentation](https://python.langchain.com/docs/integrations/text_embedding/).
+
+## Usage
+
+
+```python Python
+import os
+from mem0 import Memory
+from langchain_openai import OpenAIEmbeddings
+
+# Set necessary environment variables for your chosen LangChain provider
+os.environ["OPENAI_API_KEY"] = "your-api-key"
+
+# Initialize a LangChain embeddings model directly
+openai_embeddings = OpenAIEmbeddings(
+    model="text-embedding-3-small",
+    dimensions=1536
+)
+
+# Pass the initialized model to the config
+config = {
+    "embedder": {
+        "provider": "langchain",
+        "config": {
+            "model": openai_embeddings
+        }
+    }
+}
+
+m = Memory.from_config(config)
+messages = [
+    {"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
+    {"role": "assistant", "content": "How about a thriller movie? They can be quite engaging."},
+    {"role": "user", "content": "I'm not a big fan of thriller movies but I love sci-fi movies."},
+    {"role": "assistant", "content": "Got it! 
I'll avoid thriller recommendations and suggest sci-fi movies in the future."} +] +m.add(messages, user_id="alice", metadata={"category": "movies"}) +``` + + +## Supported LangChain Embedding Providers + +LangChain supports a wide range of embedding providers, including: + +- OpenAI (`OpenAIEmbeddings`) +- Cohere (`CohereEmbeddings`) +- Google (`VertexAIEmbeddings`) +- Hugging Face (`HuggingFaceEmbeddings`) +- Sentence Transformers (`HuggingFaceEmbeddings`) +- Azure OpenAI (`AzureOpenAIEmbeddings`) +- Ollama (`OllamaEmbeddings`) +- Together (`TogetherEmbeddings`) +- And many more + +You can use any of these model instances directly in your configuration. For a complete and up-to-date list of available embedding providers, refer to the [LangChain Text Embedding documentation](https://python.langchain.com/docs/integrations/text_embedding/). + +## Provider-Specific Configuration + +When using LangChain as an embedder provider, you'll need to: + +1. Set the appropriate environment variables for your chosen embedding provider +2. Import and initialize the specific model class you want to use +3. Pass the initialized model instance to the config + +### Examples with Different Providers + +#### HuggingFace Embeddings + +```python +from langchain_huggingface import HuggingFaceEmbeddings + +# Initialize a HuggingFace embeddings model +hf_embeddings = HuggingFaceEmbeddings( + model_name="BAAI/bge-small-en-v1.5", + encode_kwargs={"normalize_embeddings": True} +) + +config = { + "embedder": { + "provider": "langchain", + "config": { + "model": hf_embeddings + } + } +} +``` + +#### Ollama Embeddings + +```python +from langchain_ollama import OllamaEmbeddings + +# Initialize an Ollama embeddings model +ollama_embeddings = OllamaEmbeddings( + model="nomic-embed-text" +) + +config = { + "embedder": { + "provider": "langchain", + "config": { + "model": ollama_embeddings + } + } +} +``` + + + Make sure to install the necessary LangChain packages and any provider-specific dependencies. + + +## Config + +All available parameters for the `langchain` embedder config are present in [Master List of All Params in Config](../config). diff --git a/docs/components/embedders/overview.mdx b/docs/components/embedders/overview.mdx index 7630e731..f98f2db4 100644 --- a/docs/components/embedders/overview.mdx +++ b/docs/components/embedders/overview.mdx @@ -23,6 +23,7 @@ See the list of supported embedders below. 
+ ## Usage diff --git a/docs/components/llms/config.mdx b/docs/components/llms/config.mdx index 77b75229..57c8f70c 100644 --- a/docs/components/llms/config.mdx +++ b/docs/components/llms/config.mdx @@ -109,7 +109,6 @@ Here's a comprehensive list of all parameters that can be used across different | `deepseek_base_url` | Base URL for DeepSeek API | DeepSeek | | `xai_base_url` | Base URL for XAI API | XAI | | `lmstudio_base_url` | Base URL for LM Studio API | LM Studio | - | `langchain_provider` | Provider for Langchain | Langchain | | Parameter | Description | Provider | diff --git a/docs/components/llms/models/langchain.mdx b/docs/components/llms/models/langchain.mdx index 5cc6566a..3198ca83 100644 --- a/docs/components/llms/models/langchain.mdx +++ b/docs/components/llms/models/langchain.mdx @@ -12,19 +12,24 @@ For a complete list of available chat models supported by LangChain, refer to th ```python Python import os from mem0 import Memory +from langchain_openai import ChatOpenAI # Set necessary environment variables for your chosen LangChain provider -# For example, if using OpenAI through LangChain: os.environ["OPENAI_API_KEY"] = "your-api-key" +# Initialize a LangChain model directly +openai_model = ChatOpenAI( + model="gpt-4o", + temperature=0.2, + max_tokens=2000 +) + +# Pass the initialized model to the config config = { "llm": { "provider": "langchain", "config": { - "langchain_provider": "OpenAI", - "model": "gpt-4o", - "temperature": 0.2, - "max_tokens": 2000, + "model": openai_model } } } @@ -53,15 +58,15 @@ LangChain supports a wide range of LLM providers, including: - HuggingFace (`HuggingFaceChatEndpoint`) - And many more -You can specify any supported provider in the `langchain_provider` parameter. For a complete and up-to-date list of available providers, refer to the [LangChain Chat Models documentation](https://python.langchain.com/docs/integrations/chat). +You can use any of these model instances directly in your configuration. For a complete and up-to-date list of available providers, refer to the [LangChain Chat Models documentation](https://python.langchain.com/docs/integrations/chat). ## Provider-Specific Configuration When using LangChain as a provider, you'll need to: 1. Set the appropriate environment variables for your chosen LLM provider -2. Specify the LangChain provider class name in the `langchain_provider` parameter -3. Include any additional configuration parameters required by the specific provider +2. Import and initialize the specific model class you want to use +3. Pass the initialized model instance to the config Make sure to install the necessary LangChain packages and any provider-specific dependencies. 
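The three setup steps above (set environment variables, initialize the model class, pass the instance to the config) apply to any LangChain chat model, not just `ChatOpenAI`. As a quick illustration with a second provider, here is a minimal sketch assuming `langchain-ollama` is installed and a local Ollama server is running; the model name is only an example and is not prescribed by this patch.

```python
from mem0 import Memory
from langchain_ollama import ChatOllama

# Any BaseChatModel subclass can be passed directly; the model name is illustrative
local_model = ChatOllama(
    model="llama3.1",
    temperature=0.2
)

config = {
    "llm": {
        "provider": "langchain",
        "config": {
            "model": local_model
        }
    }
}

m = Memory.from_config(config)
```

The embedder is configured separately (see the LangChain embedder page added earlier in this patch); the default embedder is used unless one is set in the config.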
diff --git a/docs/docs.json b/docs/docs.json index 7e092213..2ca34fef 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -161,7 +161,8 @@ "components/embedders/models/vertexai", "components/embedders/models/gemini", "components/embedders/models/lmstudio", - "components/embedders/models/together" + "components/embedders/models/together", + "components/embedders/models/langchain" ] } ] diff --git a/mem0/configs/llms/base.py b/mem0/configs/llms/base.py index b78d44fe..6f062eca 100644 --- a/mem0/configs/llms/base.py +++ b/mem0/configs/llms/base.py @@ -13,7 +13,7 @@ class BaseLlmConfig(ABC): def __init__( self, - model: Optional[str] = None, + model: Optional[Union[str, Dict]] = None, temperature: float = 0.1, api_key: Optional[str] = None, max_tokens: int = 2000, @@ -41,8 +41,6 @@ class BaseLlmConfig(ABC): xai_base_url: Optional[str] = None, # LM Studio specific lmstudio_base_url: Optional[str] = "http://localhost:1234/v1", - # Langchain specific - langchain_provider: Optional[str] = None, ): """ Initializes a configuration class instance for the LLM. @@ -89,8 +87,6 @@ class BaseLlmConfig(ABC): :type xai_base_url: Optional[str], optional :param lmstudio_base_url: LM Studio base URL to be use, defaults to "http://localhost:1234/v1" :type lmstudio_base_url: Optional[str], optional - :param langchain_provider: Langchain provider to be use, defaults to None - :type langchain_provider: Optional[str], optional """ self.model = model @@ -127,6 +123,3 @@ class BaseLlmConfig(ABC): # LM Studio specific self.lmstudio_base_url = lmstudio_base_url - - # Langchain specific - self.langchain_provider = langchain_provider diff --git a/mem0/embeddings/configs.py b/mem0/embeddings/configs.py index 37c2cae1..9b1be04a 100644 --- a/mem0/embeddings/configs.py +++ b/mem0/embeddings/configs.py @@ -22,6 +22,7 @@ class EmbedderConfig(BaseModel): "vertexai", "together", "lmstudio", + "langchain", ]: return v else: diff --git a/mem0/embeddings/langchain.py b/mem0/embeddings/langchain.py new file mode 100644 index 00000000..f95f83de --- /dev/null +++ b/mem0/embeddings/langchain.py @@ -0,0 +1,36 @@ +import os +from typing import Literal, Optional + +from mem0.configs.embeddings.base import BaseEmbedderConfig +from mem0.embeddings.base import EmbeddingBase + +try: + from langchain.embeddings.base import Embeddings +except ImportError: + raise ImportError("langchain is not installed. Please install it using `pip install langchain`") + + +class LangchainEmbedding(EmbeddingBase): + def __init__(self, config: Optional[BaseEmbedderConfig] = None): + super().__init__(config) + + if self.config.model is None: + raise ValueError("`model` parameter is required") + + if not isinstance(self.config.model, Embeddings): + raise ValueError("`model` must be an instance of Embeddings") + + self.langchain_model = self.config.model + + def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None): + """ + Get the embedding for the given text using Langchain. + + Args: + text (str): The text to embed. + memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None. + Returns: + list: The embedding vector. 
+ """ + + return self.langchain_model.embed_query(text) diff --git a/mem0/llms/langchain.py b/mem0/llms/langchain.py index 88836ab0..3e722d60 100644 --- a/mem0/llms/langchain.py +++ b/mem0/llms/langchain.py @@ -1,174 +1,25 @@ from typing import Dict, List, Optional -import enum from mem0.configs.llms.base import BaseLlmConfig from mem0.llms.base import LLMBase -# Default import for langchain_community try: - from langchain_community import chat_models + from langchain.chat_models.base import BaseChatModel except ImportError: - raise ImportError("langchain_community not found. Please install it with `pip install langchain-community`") - -# Provider-specific package mapping -PROVIDER_PACKAGES = { - "Anthropic": "langchain_anthropic", - "MistralAI": "langchain_mistralai", - "Fireworks": "langchain_fireworks", - "AzureOpenAI": "langchain_openai", - "OpenAI": "langchain_openai", - "Together": "langchain_together", - "VertexAI": "langchain_google_vertexai", - "GoogleAI": "langchain_google_genai", - "Groq": "langchain_groq", - "Cohere": "langchain_cohere", - "Bedrock": "langchain_aws", - "HuggingFace": "langchain_huggingface", - "NVIDIA": "langchain_nvidia_ai_endpoints", - "Ollama": "langchain_ollama", - "AI21": "langchain_ai21", - "Upstage": "langchain_upstage", - "Databricks": "databricks_langchain", - "Watsonx": "langchain_ibm", - "xAI": "langchain_xai", - "Perplexity": "langchain_perplexity", -} - - -class LangchainProvider(enum.Enum): - Abso = "ChatAbso" - AI21 = "ChatAI21" - Alibaba = "ChatAlibabaCloud" - Anthropic = "ChatAnthropic" - Anyscale = "ChatAnyscale" - AzureAIChatCompletionsModel = "AzureAIChatCompletionsModel" - AzureOpenAI = "AzureChatOpenAI" - AzureMLEndpoint = "ChatAzureMLEndpoint" - Baichuan = "ChatBaichuan" - Qianfan = "ChatQianfan" - Bedrock = "ChatBedrock" - Cerebras = "ChatCerebras" - CloudflareWorkersAI = "ChatCloudflareWorkersAI" - Cohere = "ChatCohere" - ContextualAI = "ChatContextualAI" - Coze = "ChatCoze" - Dappier = "ChatDappier" - Databricks = "ChatDatabricks" - DeepInfra = "ChatDeepInfra" - DeepSeek = "ChatDeepSeek" - EdenAI = "ChatEdenAI" - EverlyAI = "ChatEverlyAI" - Fireworks = "ChatFireworks" - Friendli = "ChatFriendli" - GigaChat = "ChatGigaChat" - Goodfire = "ChatGoodfire" - GoogleAI = "ChatGoogleAI" - VertexAI = "VertexAI" - GPTRouter = "ChatGPTRouter" - Groq = "ChatGroq" - HuggingFace = "ChatHuggingFace" - Watsonx = "ChatWatsonx" - Jina = "ChatJina" - Kinetica = "ChatKinetica" - Konko = "ChatKonko" - LiteLLM = "ChatLiteLLM" - LiteLLMRouter = "ChatLiteLLMRouter" - Llama2Chat = "Llama2Chat" - LlamaAPI = "ChatLlamaAPI" - LlamaEdge = "ChatLlamaEdge" - LlamaCpp = "ChatLlamaCpp" - Maritalk = "ChatMaritalk" - MiniMax = "ChatMiniMax" - MistralAI = "ChatMistralAI" - MLX = "ChatMLX" - ModelScope = "ChatModelScope" - Moonshot = "ChatMoonshot" - Naver = "ChatNaver" - Netmind = "ChatNetmind" - NVIDIA = "ChatNVIDIA" - OCIModelDeployment = "ChatOCIModelDeployment" - OCIGenAI = "ChatOCIGenAI" - OctoAI = "ChatOctoAI" - Ollama = "ChatOllama" - OpenAI = "ChatOpenAI" - Outlines = "ChatOutlines" - Perplexity = "ChatPerplexity" - Pipeshift = "ChatPipeshift" - PredictionGuard = "ChatPredictionGuard" - PremAI = "ChatPremAI" - PromptLayerOpenAI = "PromptLayerChatOpenAI" - QwQ = "ChatQwQ" - Reka = "ChatReka" - RunPod = "ChatRunPod" - SambaNovaCloud = "ChatSambaNovaCloud" - SambaStudio = "ChatSambaStudio" - SeekrFlow = "ChatSeekrFlow" - SnowflakeCortex = "ChatSnowflakeCortex" - Solar = "ChatSolar" - SparkLLM = "ChatSparkLLM" - Nebula = "ChatNebula" - Hunyuan = "ChatHunyuan" - 
Together = "ChatTogether" - TongyiQwen = "ChatTongyiQwen" - Upstage = "ChatUpstage" - Vectara = "ChatVectara" - VLLM = "ChatVLLM" - VolcEngine = "ChatVolcEngine" - Writer = "ChatWriter" - xAI = "ChatXAI" - Xinference = "ChatXinference" - Yandex = "ChatYandex" - Yi = "ChatYi" - Yuan2 = "ChatYuan2" - ZhipuAI = "ChatZhipuAI" + raise ImportError("langchain is not installed. Please install it using `pip install langchain`") class LangchainLLM(LLMBase): def __init__(self, config: Optional[BaseLlmConfig] = None): super().__init__(config) - provider = self.config.langchain_provider - if provider not in LangchainProvider.__members__: - raise ValueError(f"Invalid provider: {provider}") - model_name = LangchainProvider[provider].value + if self.config.model is None: + raise ValueError("`model` parameter is required") - try: - # Check if this provider needs a specialized package - if provider in PROVIDER_PACKAGES: - if provider == "Anthropic": # Special handling for Anthropic with Pydantic v2 - try: - from langchain_anthropic import ChatAnthropic - model_class = ChatAnthropic - except ImportError: - raise ImportError("langchain_anthropic not found. Please install it with `pip install langchain-anthropic`") - else: - package_name = PROVIDER_PACKAGES[provider] - try: - # Import the model class directly from the package - module_path = f"{package_name}" - model_class = __import__(module_path, fromlist=[model_name]) - model_class = getattr(model_class, model_name) - except ImportError: - raise ImportError( - f"Package {package_name} not found. " f"Please install it with `pip install {package_name}`" - ) - except AttributeError: - raise ImportError(f"Model {model_name} not found in {package_name}") - else: - # Use the default langchain_community module - if not hasattr(chat_models, model_name): - raise ImportError(f"Provider {provider} not found in langchain_community.chat_models") + if not isinstance(self.config.model, BaseChatModel): + raise ValueError("`model` must be an instance of BaseChatModel") - model_class = getattr(chat_models, model_name) - - # Initialize the model with relevant config parameters - self.langchain_model = model_class( - model=self.config.model, - temperature=self.config.temperature, - max_tokens=self.config.max_tokens - ) - except (ImportError, AttributeError, ValueError) as e: - raise ImportError(f"Error setting up langchain model for provider {provider}: {str(e)}") + self.langchain_model = self.config.model def generate_response( self, diff --git a/mem0/memory/main.py b/mem0/memory/main.py index b287d2be..e7a106af 100644 --- a/mem0/memory/main.py +++ b/mem0/memory/main.py @@ -623,14 +623,13 @@ class Memory(MemoryBase): capture_event("mem0._create_memory", self, {"memory_id": memory_id}) return memory_id - def _create_procedural_memory(self, messages, metadata=None, llm=None, prompt=None): + def _create_procedural_memory(self, messages, metadata=None, prompt=None): """ Create a procedural memory Args: messages (list): List of messages to create a procedural memory from. metadata (dict): Metadata to create a procedural memory from. - llm (BaseChatModel, optional): LLM class to use for generating procedural memories. Defaults to None. Useful when user is using LangChain ChatModel. prompt (str, optional): Prompt to use for the procedural memory creation. Defaults to None. 
""" try: @@ -650,12 +649,7 @@ class Memory(MemoryBase): ] try: - if llm is not None: - parsed_messages = convert_to_messages(parsed_messages) - response = llm.invoke(input=parsed_messages) - procedural_memory = response.content - else: - procedural_memory = self.llm.generate_response(messages=parsed_messages) + procedural_memory = self.llm.generate_response(messages=parsed_messages) except Exception as e: logger.error(f"Error generating procedural memory summary: {e}") raise diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py index 31c2f3c6..9e10dca1 100644 --- a/mem0/utils/factory.py +++ b/mem0/utils/factory.py @@ -50,6 +50,7 @@ class EmbedderFactory: "vertexai": "mem0.embeddings.vertexai.VertexAIEmbedding", "together": "mem0.embeddings.together.TogetherEmbedding", "lmstudio": "mem0.embeddings.lmstudio.LMStudioEmbedding", + "langchain": "mem0.embeddings.langchain.LangchainEmbedding", } @classmethod diff --git a/pyproject.toml b/pyproject.toml index 297436ec..0965a653 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "mem0ai" -version = "0.1.83" +version = "0.1.84" description = "Long-term memory for AI Agents" authors = ["Mem0 "] exclude = [ diff --git a/tests/llms/test_langchain.py b/tests/llms/test_langchain.py index 67b00e6d..3d1aa2ad 100644 --- a/tests/llms/test_langchain.py +++ b/tests/llms/test_langchain.py @@ -4,97 +4,99 @@ import pytest from mem0.configs.llms.base import BaseLlmConfig from mem0.llms.langchain import LangchainLLM +# Add the import for BaseChatModel +try: + from langchain.chat_models.base import BaseChatModel +except ImportError: + from unittest.mock import MagicMock + BaseChatModel = MagicMock + @pytest.fixture def mock_langchain_model(): """Mock a Langchain model for testing.""" - with patch("langchain_openai.ChatOpenAI") as mock_chat_model: - mock_model = Mock() - mock_model.invoke.return_value = Mock(content="This is a test response") - mock_chat_model.return_value = mock_model - yield mock_model + mock_model = Mock(spec=BaseChatModel) + mock_model.invoke.return_value = Mock(content="This is a test response") + return mock_model -def test_langchain_initialization(): - """Test that LangchainLLM initializes correctly with a valid provider.""" - with patch("langchain_openai.ChatOpenAI") as mock_chat_model: - # Setup the mock model - mock_model = Mock() - mock_chat_model.return_value = mock_model - - # Create a config with OpenAI provider - config = BaseLlmConfig( - model="gpt-3.5-turbo", - temperature=0.7, - max_tokens=100, - api_key="test-api-key", - langchain_provider="OpenAI" - ) - - # Initialize the LangchainLLM - llm = LangchainLLM(config) - - # Verify the model was initialized with correct parameters - mock_chat_model.assert_called_once_with( - model="gpt-3.5-turbo", - temperature=0.7, - max_tokens=100, - api_key="test-api-key" - ) - - assert llm.langchain_model == mock_model +def test_langchain_initialization(mock_langchain_model): + """Test that LangchainLLM initializes correctly with a valid model.""" + # Create a config with the model instance directly + config = BaseLlmConfig( + model=mock_langchain_model, + temperature=0.7, + max_tokens=100, + api_key="test-api-key" + ) + + # Initialize the LangchainLLM + llm = LangchainLLM(config) + + # Verify the model was correctly assigned + assert llm.langchain_model == mock_langchain_model def test_generate_response(mock_langchain_model): """Test that generate_response correctly processes messages and returns a response.""" - # Create a config with OpenAI provider + # 
Create a config with the model instance config = BaseLlmConfig( - model="gpt-3.5-turbo", + model=mock_langchain_model, temperature=0.7, max_tokens=100, - api_key="test-api-key", - langchain_provider="OpenAI" + api_key="test-api-key" ) # Initialize the LangchainLLM - with patch("langchain_openai.ChatOpenAI", return_value=mock_langchain_model): - llm = LangchainLLM(config) - - # Create test messages - messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello, how are you?"}, - {"role": "assistant", "content": "I'm doing well! How can I help you?"}, - {"role": "user", "content": "Tell me a joke."} - ] - - # Get response - response = llm.generate_response(messages) - - # Verify the correct message format was passed to the model - expected_langchain_messages = [ - ("system", "You are a helpful assistant."), - ("human", "Hello, how are you?"), - ("ai", "I'm doing well! How can I help you?"), - ("human", "Tell me a joke.") - ] - - mock_langchain_model.invoke.assert_called_once() - # Extract the first argument of the first call - actual_messages = mock_langchain_model.invoke.call_args[0][0] - assert actual_messages == expected_langchain_messages - assert response == "This is a test response" + llm = LangchainLLM(config) + + # Create test messages + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm doing well! How can I help you?"}, + {"role": "user", "content": "Tell me a joke."} + ] + + # Get response + response = llm.generate_response(messages) + + # Verify the correct message format was passed to the model + expected_langchain_messages = [ + ("system", "You are a helpful assistant."), + ("human", "Hello, how are you?"), + ("ai", "I'm doing well! How can I help you?"), + ("human", "Tell me a joke.") + ] + + mock_langchain_model.invoke.assert_called_once() + # Extract the first argument of the first call + actual_messages = mock_langchain_model.invoke.call_args[0][0] + assert actual_messages == expected_langchain_messages + assert response == "This is a test response" -def test_invalid_provider(): - """Test that LangchainLLM raises an error with an invalid provider.""" +def test_invalid_model(): + """Test that LangchainLLM raises an error with an invalid model.""" config = BaseLlmConfig( - model="test-model", + model="not-a-valid-model-instance", temperature=0.7, max_tokens=100, - api_key="test-api-key", - langchain_provider="InvalidProvider" + api_key="test-api-key" ) - with pytest.raises(ValueError, match="Invalid provider: InvalidProvider"): + with pytest.raises(ValueError, match="`model` must be an instance of BaseChatModel"): + LangchainLLM(config) + + +def test_missing_model(): + """Test that LangchainLLM raises an error when model is None.""" + config = BaseLlmConfig( + model=None, + temperature=0.7, + max_tokens=100, + api_key="test-api-key" + ) + + with pytest.raises(ValueError, match="`model` parameter is required"): LangchainLLM(config)
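Taken together, the LLM and embedder changes in this patch follow the same pattern: an initialized LangChain object is passed as `model` in the config. A minimal sketch of the combined shape, assuming `langchain-openai` is installed and `OPENAI_API_KEY` is set (the class choices and model names are illustrative):

```python
import os

from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from mem0 import Memory

os.environ["OPENAI_API_KEY"] = "your-api-key"

# LangchainLLM checks for a BaseChatModel instance and LangchainEmbedding
# checks for an Embeddings instance; both are passed directly as `model`.
config = {
    "llm": {
        "provider": "langchain",
        "config": {"model": ChatOpenAI(model="gpt-4o", temperature=0.2)},
    },
    "embedder": {
        "provider": "langchain",
        "config": {"model": OpenAIEmbeddings(model="text-embedding-3-small")},
    },
}

m = Memory.from_config(config)
m.add("I prefer sci-fi movies over thrillers.", user_id="alice")
```

This mirrors the separate examples in the two documentation pages touched by this patch; only the combination is shown here.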