Add langchain embedding, update langchain LLM and version bump -> 0.1.84 (#2510)
This commit is contained in:
@@ -6,6 +6,15 @@ mode: "wide"
|
|||||||
<Tabs>
|
<Tabs>
|
||||||
<Tab title="Python">
|
<Tab title="Python">
|
||||||
|
|
||||||
|
<Update label="2025-04-07" description="v0.1.84">
|
||||||
|
|
||||||
|
**New Features:**
|
||||||
|
- **Langchain Embedder:** Added Langchain embedder integration
|
||||||
|
|
||||||
|
**Improvements:**
|
||||||
|
- **Langchain LLM:** Updated Langchain LLM integration to directly pass the Langchain object LLM
|
||||||
|
</Update>
|
||||||
|
|
||||||
<Update label="2025-04-07" description="v0.1.83">
|
<Update label="2025-04-07" description="v0.1.83">
|
||||||
|
|
||||||
**Bug Fixes:**
|
**Bug Fixes:**
|
||||||
|
|||||||
120
docs/components/embedders/models/langchain.mdx
Normal file
120
docs/components/embedders/models/langchain.mdx
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
---
|
||||||
|
title: LangChain
|
||||||
|
---
|
||||||
|
|
||||||
|
Mem0 supports LangChain as a provider to access a wide range of embedding models. LangChain is a framework for developing applications powered by language models, making it easy to integrate various embedding providers through a consistent interface.
|
||||||
|
|
||||||
|
For a complete list of available embedding models supported by LangChain, refer to the [LangChain Text Embedding documentation](https://python.langchain.com/docs/integrations/text_embedding/).
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
<CodeGroup>
|
||||||
|
```python Python
|
||||||
|
import os
|
||||||
|
from mem0 import Memory
|
||||||
|
from langchain_openai import OpenAIEmbeddings
|
||||||
|
|
||||||
|
# Set necessary environment variables for your chosen LangChain provider
|
||||||
|
os.environ["OPENAI_API_KEY"] = "your-api-key"
|
||||||
|
|
||||||
|
# Initialize a LangChain embeddings model directly
|
||||||
|
openai_embeddings = OpenAIEmbeddings(
|
||||||
|
model="text-embedding-3-small",
|
||||||
|
dimensions=1536
|
||||||
|
)
|
||||||
|
|
||||||
|
# Pass the initialized model to the config
|
||||||
|
config = {
|
||||||
|
"embedder": {
|
||||||
|
"provider": "langchain",
|
||||||
|
"config": {
|
||||||
|
"model": openai_embeddings
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m = Memory.from_config(config)
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
|
||||||
|
{"role": "assistant", "content": "How about a thriller movies? They can be quite engaging."},
|
||||||
|
{"role": "user", "content": "I'm not a big fan of thriller movies but I love sci-fi movies."},
|
||||||
|
{"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."}
|
||||||
|
]
|
||||||
|
m.add(messages, user_id="alice", metadata={"category": "movies"})
|
||||||
|
```
|
||||||
|
</CodeGroup>
|
||||||
|
|
||||||
|
## Supported LangChain Embedding Providers
|
||||||
|
|
||||||
|
LangChain supports a wide range of embedding providers, including:
|
||||||
|
|
||||||
|
- OpenAI (`OpenAIEmbeddings`)
|
||||||
|
- Cohere (`CohereEmbeddings`)
|
||||||
|
- Google (`VertexAIEmbeddings`)
|
||||||
|
- Hugging Face (`HuggingFaceEmbeddings`)
|
||||||
|
- Sentence Transformers (`HuggingFaceEmbeddings`)
|
||||||
|
- Azure OpenAI (`AzureOpenAIEmbeddings`)
|
||||||
|
- Ollama (`OllamaEmbeddings`)
|
||||||
|
- Together (`TogetherEmbeddings`)
|
||||||
|
- And many more
|
||||||
|
|
||||||
|
You can use any of these model instances directly in your configuration. For a complete and up-to-date list of available embedding providers, refer to the [LangChain Text Embedding documentation](https://python.langchain.com/docs/integrations/text_embedding/).
|
||||||
|
|
||||||
|
## Provider-Specific Configuration
|
||||||
|
|
||||||
|
When using LangChain as an embedder provider, you'll need to:
|
||||||
|
|
||||||
|
1. Set the appropriate environment variables for your chosen embedding provider
|
||||||
|
2. Import and initialize the specific model class you want to use
|
||||||
|
3. Pass the initialized model instance to the config
|
||||||
|
|
||||||
|
### Examples with Different Providers
|
||||||
|
|
||||||
|
#### HuggingFace Embeddings
|
||||||
|
|
||||||
|
```python
|
||||||
|
from langchain_huggingface import HuggingFaceEmbeddings
|
||||||
|
|
||||||
|
# Initialize a HuggingFace embeddings model
|
||||||
|
hf_embeddings = HuggingFaceEmbeddings(
|
||||||
|
model_name="BAAI/bge-small-en-v1.5",
|
||||||
|
encode_kwargs={"normalize_embeddings": True}
|
||||||
|
)
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"embedder": {
|
||||||
|
"provider": "langchain",
|
||||||
|
"config": {
|
||||||
|
"model": hf_embeddings
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Ollama Embeddings
|
||||||
|
|
||||||
|
```python
|
||||||
|
from langchain_ollama import OllamaEmbeddings
|
||||||
|
|
||||||
|
# Initialize an Ollama embeddings model
|
||||||
|
ollama_embeddings = OllamaEmbeddings(
|
||||||
|
model="nomic-embed-text"
|
||||||
|
)
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"embedder": {
|
||||||
|
"provider": "langchain",
|
||||||
|
"config": {
|
||||||
|
"model": ollama_embeddings
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
<Note>
|
||||||
|
Make sure to install the necessary LangChain packages and any provider-specific dependencies.
|
||||||
|
</Note>
|
||||||
|
|
||||||
|
## Config
|
||||||
|
|
||||||
|
All available parameters for the `langchain` embedder config are present in [Master List of All Params in Config](../config).
|
||||||
@@ -23,6 +23,7 @@ See the list of supported embedders below.
|
|||||||
<Card title="Vertex AI" href="/components/embedders/models/vertexai"></Card>
|
<Card title="Vertex AI" href="/components/embedders/models/vertexai"></Card>
|
||||||
<Card title="Together" href="/components/embedders/models/together"></Card>
|
<Card title="Together" href="/components/embedders/models/together"></Card>
|
||||||
<Card title="LM Studio" href="/components/embedders/models/lmstudio"></Card>
|
<Card title="LM Studio" href="/components/embedders/models/lmstudio"></Card>
|
||||||
|
<Card title="Langchain" href="/components/embedders/models/langchain"></Card>
|
||||||
</CardGroup>
|
</CardGroup>
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|||||||
@@ -109,7 +109,6 @@ Here's a comprehensive list of all parameters that can be used across different
|
|||||||
| `deepseek_base_url` | Base URL for DeepSeek API | DeepSeek |
|
| `deepseek_base_url` | Base URL for DeepSeek API | DeepSeek |
|
||||||
| `xai_base_url` | Base URL for XAI API | XAI |
|
| `xai_base_url` | Base URL for XAI API | XAI |
|
||||||
| `lmstudio_base_url` | Base URL for LM Studio API | LM Studio |
|
| `lmstudio_base_url` | Base URL for LM Studio API | LM Studio |
|
||||||
| `langchain_provider` | Provider for Langchain | Langchain |
|
|
||||||
</Tab>
|
</Tab>
|
||||||
<Tab title="TypeScript">
|
<Tab title="TypeScript">
|
||||||
| Parameter | Description | Provider |
|
| Parameter | Description | Provider |
|
||||||
|
|||||||
@@ -12,19 +12,24 @@ For a complete list of available chat models supported by LangChain, refer to th
|
|||||||
```python Python
|
```python Python
|
||||||
import os
|
import os
|
||||||
from mem0 import Memory
|
from mem0 import Memory
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
# Set necessary environment variables for your chosen LangChain provider
|
# Set necessary environment variables for your chosen LangChain provider
|
||||||
# For example, if using OpenAI through LangChain:
|
|
||||||
os.environ["OPENAI_API_KEY"] = "your-api-key"
|
os.environ["OPENAI_API_KEY"] = "your-api-key"
|
||||||
|
|
||||||
|
# Initialize a LangChain model directly
|
||||||
|
openai_model = ChatOpenAI(
|
||||||
|
model="gpt-4o",
|
||||||
|
temperature=0.2,
|
||||||
|
max_tokens=2000
|
||||||
|
)
|
||||||
|
|
||||||
|
# Pass the initialized model to the config
|
||||||
config = {
|
config = {
|
||||||
"llm": {
|
"llm": {
|
||||||
"provider": "langchain",
|
"provider": "langchain",
|
||||||
"config": {
|
"config": {
|
||||||
"langchain_provider": "OpenAI",
|
"model": openai_model
|
||||||
"model": "gpt-4o",
|
|
||||||
"temperature": 0.2,
|
|
||||||
"max_tokens": 2000,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -53,15 +58,15 @@ LangChain supports a wide range of LLM providers, including:
|
|||||||
- HuggingFace (`HuggingFaceChatEndpoint`)
|
- HuggingFace (`HuggingFaceChatEndpoint`)
|
||||||
- And many more
|
- And many more
|
||||||
|
|
||||||
You can specify any supported provider in the `langchain_provider` parameter. For a complete and up-to-date list of available providers, refer to the [LangChain Chat Models documentation](https://python.langchain.com/docs/integrations/chat).
|
You can use any of these model instances directly in your configuration. For a complete and up-to-date list of available providers, refer to the [LangChain Chat Models documentation](https://python.langchain.com/docs/integrations/chat).
|
||||||
|
|
||||||
## Provider-Specific Configuration
|
## Provider-Specific Configuration
|
||||||
|
|
||||||
When using LangChain as a provider, you'll need to:
|
When using LangChain as a provider, you'll need to:
|
||||||
|
|
||||||
1. Set the appropriate environment variables for your chosen LLM provider
|
1. Set the appropriate environment variables for your chosen LLM provider
|
||||||
2. Specify the LangChain provider class name in the `langchain_provider` parameter
|
2. Import and initialize the specific model class you want to use
|
||||||
3. Include any additional configuration parameters required by the specific provider
|
3. Pass the initialized model instance to the config
|
||||||
|
|
||||||
<Note>
|
<Note>
|
||||||
Make sure to install the necessary LangChain packages and any provider-specific dependencies.
|
Make sure to install the necessary LangChain packages and any provider-specific dependencies.
|
||||||
|
|||||||
@@ -161,7 +161,8 @@
|
|||||||
"components/embedders/models/vertexai",
|
"components/embedders/models/vertexai",
|
||||||
"components/embedders/models/gemini",
|
"components/embedders/models/gemini",
|
||||||
"components/embedders/models/lmstudio",
|
"components/embedders/models/lmstudio",
|
||||||
"components/embedders/models/together"
|
"components/embedders/models/together",
|
||||||
|
"components/embedders/models/langchain"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ class BaseLlmConfig(ABC):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: Optional[str] = None,
|
model: Optional[Union[str, Dict]] = None,
|
||||||
temperature: float = 0.1,
|
temperature: float = 0.1,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
max_tokens: int = 2000,
|
max_tokens: int = 2000,
|
||||||
@@ -41,8 +41,6 @@ class BaseLlmConfig(ABC):
|
|||||||
xai_base_url: Optional[str] = None,
|
xai_base_url: Optional[str] = None,
|
||||||
# LM Studio specific
|
# LM Studio specific
|
||||||
lmstudio_base_url: Optional[str] = "http://localhost:1234/v1",
|
lmstudio_base_url: Optional[str] = "http://localhost:1234/v1",
|
||||||
# Langchain specific
|
|
||||||
langchain_provider: Optional[str] = None,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initializes a configuration class instance for the LLM.
|
Initializes a configuration class instance for the LLM.
|
||||||
@@ -89,8 +87,6 @@ class BaseLlmConfig(ABC):
|
|||||||
:type xai_base_url: Optional[str], optional
|
:type xai_base_url: Optional[str], optional
|
||||||
:param lmstudio_base_url: LM Studio base URL to be use, defaults to "http://localhost:1234/v1"
|
:param lmstudio_base_url: LM Studio base URL to be use, defaults to "http://localhost:1234/v1"
|
||||||
:type lmstudio_base_url: Optional[str], optional
|
:type lmstudio_base_url: Optional[str], optional
|
||||||
:param langchain_provider: Langchain provider to be use, defaults to None
|
|
||||||
:type langchain_provider: Optional[str], optional
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.model = model
|
self.model = model
|
||||||
@@ -127,6 +123,3 @@ class BaseLlmConfig(ABC):
|
|||||||
|
|
||||||
# LM Studio specific
|
# LM Studio specific
|
||||||
self.lmstudio_base_url = lmstudio_base_url
|
self.lmstudio_base_url = lmstudio_base_url
|
||||||
|
|
||||||
# Langchain specific
|
|
||||||
self.langchain_provider = langchain_provider
|
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ class EmbedderConfig(BaseModel):
|
|||||||
"vertexai",
|
"vertexai",
|
||||||
"together",
|
"together",
|
||||||
"lmstudio",
|
"lmstudio",
|
||||||
|
"langchain",
|
||||||
]:
|
]:
|
||||||
return v
|
return v
|
||||||
else:
|
else:
|
||||||
|
|||||||
36
mem0/embeddings/langchain.py
Normal file
36
mem0/embeddings/langchain.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
import os
|
||||||
|
from typing import Literal, Optional
|
||||||
|
|
||||||
|
from mem0.configs.embeddings.base import BaseEmbedderConfig
|
||||||
|
from mem0.embeddings.base import EmbeddingBase
|
||||||
|
|
||||||
|
try:
|
||||||
|
from langchain.embeddings.base import Embeddings
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("langchain is not installed. Please install it using `pip install langchain`")
|
||||||
|
|
||||||
|
|
||||||
|
class LangchainEmbedding(EmbeddingBase):
|
||||||
|
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
if self.config.model is None:
|
||||||
|
raise ValueError("`model` parameter is required")
|
||||||
|
|
||||||
|
if not isinstance(self.config.model, Embeddings):
|
||||||
|
raise ValueError("`model` must be an instance of Embeddings")
|
||||||
|
|
||||||
|
self.langchain_model = self.config.model
|
||||||
|
|
||||||
|
def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
|
||||||
|
"""
|
||||||
|
Get the embedding for the given text using Langchain.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text (str): The text to embed.
|
||||||
|
memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
|
||||||
|
Returns:
|
||||||
|
list: The embedding vector.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return self.langchain_model.embed_query(text)
|
||||||
@@ -1,174 +1,25 @@
|
|||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
import enum
|
|
||||||
|
|
||||||
from mem0.configs.llms.base import BaseLlmConfig
|
from mem0.configs.llms.base import BaseLlmConfig
|
||||||
from mem0.llms.base import LLMBase
|
from mem0.llms.base import LLMBase
|
||||||
|
|
||||||
# Default import for langchain_community
|
|
||||||
try:
|
try:
|
||||||
from langchain_community import chat_models
|
from langchain.chat_models.base import BaseChatModel
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError("langchain_community not found. Please install it with `pip install langchain-community`")
|
raise ImportError("langchain is not installed. Please install it using `pip install langchain`")
|
||||||
|
|
||||||
# Provider-specific package mapping
|
|
||||||
PROVIDER_PACKAGES = {
|
|
||||||
"Anthropic": "langchain_anthropic",
|
|
||||||
"MistralAI": "langchain_mistralai",
|
|
||||||
"Fireworks": "langchain_fireworks",
|
|
||||||
"AzureOpenAI": "langchain_openai",
|
|
||||||
"OpenAI": "langchain_openai",
|
|
||||||
"Together": "langchain_together",
|
|
||||||
"VertexAI": "langchain_google_vertexai",
|
|
||||||
"GoogleAI": "langchain_google_genai",
|
|
||||||
"Groq": "langchain_groq",
|
|
||||||
"Cohere": "langchain_cohere",
|
|
||||||
"Bedrock": "langchain_aws",
|
|
||||||
"HuggingFace": "langchain_huggingface",
|
|
||||||
"NVIDIA": "langchain_nvidia_ai_endpoints",
|
|
||||||
"Ollama": "langchain_ollama",
|
|
||||||
"AI21": "langchain_ai21",
|
|
||||||
"Upstage": "langchain_upstage",
|
|
||||||
"Databricks": "databricks_langchain",
|
|
||||||
"Watsonx": "langchain_ibm",
|
|
||||||
"xAI": "langchain_xai",
|
|
||||||
"Perplexity": "langchain_perplexity",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class LangchainProvider(enum.Enum):
|
|
||||||
Abso = "ChatAbso"
|
|
||||||
AI21 = "ChatAI21"
|
|
||||||
Alibaba = "ChatAlibabaCloud"
|
|
||||||
Anthropic = "ChatAnthropic"
|
|
||||||
Anyscale = "ChatAnyscale"
|
|
||||||
AzureAIChatCompletionsModel = "AzureAIChatCompletionsModel"
|
|
||||||
AzureOpenAI = "AzureChatOpenAI"
|
|
||||||
AzureMLEndpoint = "ChatAzureMLEndpoint"
|
|
||||||
Baichuan = "ChatBaichuan"
|
|
||||||
Qianfan = "ChatQianfan"
|
|
||||||
Bedrock = "ChatBedrock"
|
|
||||||
Cerebras = "ChatCerebras"
|
|
||||||
CloudflareWorkersAI = "ChatCloudflareWorkersAI"
|
|
||||||
Cohere = "ChatCohere"
|
|
||||||
ContextualAI = "ChatContextualAI"
|
|
||||||
Coze = "ChatCoze"
|
|
||||||
Dappier = "ChatDappier"
|
|
||||||
Databricks = "ChatDatabricks"
|
|
||||||
DeepInfra = "ChatDeepInfra"
|
|
||||||
DeepSeek = "ChatDeepSeek"
|
|
||||||
EdenAI = "ChatEdenAI"
|
|
||||||
EverlyAI = "ChatEverlyAI"
|
|
||||||
Fireworks = "ChatFireworks"
|
|
||||||
Friendli = "ChatFriendli"
|
|
||||||
GigaChat = "ChatGigaChat"
|
|
||||||
Goodfire = "ChatGoodfire"
|
|
||||||
GoogleAI = "ChatGoogleAI"
|
|
||||||
VertexAI = "VertexAI"
|
|
||||||
GPTRouter = "ChatGPTRouter"
|
|
||||||
Groq = "ChatGroq"
|
|
||||||
HuggingFace = "ChatHuggingFace"
|
|
||||||
Watsonx = "ChatWatsonx"
|
|
||||||
Jina = "ChatJina"
|
|
||||||
Kinetica = "ChatKinetica"
|
|
||||||
Konko = "ChatKonko"
|
|
||||||
LiteLLM = "ChatLiteLLM"
|
|
||||||
LiteLLMRouter = "ChatLiteLLMRouter"
|
|
||||||
Llama2Chat = "Llama2Chat"
|
|
||||||
LlamaAPI = "ChatLlamaAPI"
|
|
||||||
LlamaEdge = "ChatLlamaEdge"
|
|
||||||
LlamaCpp = "ChatLlamaCpp"
|
|
||||||
Maritalk = "ChatMaritalk"
|
|
||||||
MiniMax = "ChatMiniMax"
|
|
||||||
MistralAI = "ChatMistralAI"
|
|
||||||
MLX = "ChatMLX"
|
|
||||||
ModelScope = "ChatModelScope"
|
|
||||||
Moonshot = "ChatMoonshot"
|
|
||||||
Naver = "ChatNaver"
|
|
||||||
Netmind = "ChatNetmind"
|
|
||||||
NVIDIA = "ChatNVIDIA"
|
|
||||||
OCIModelDeployment = "ChatOCIModelDeployment"
|
|
||||||
OCIGenAI = "ChatOCIGenAI"
|
|
||||||
OctoAI = "ChatOctoAI"
|
|
||||||
Ollama = "ChatOllama"
|
|
||||||
OpenAI = "ChatOpenAI"
|
|
||||||
Outlines = "ChatOutlines"
|
|
||||||
Perplexity = "ChatPerplexity"
|
|
||||||
Pipeshift = "ChatPipeshift"
|
|
||||||
PredictionGuard = "ChatPredictionGuard"
|
|
||||||
PremAI = "ChatPremAI"
|
|
||||||
PromptLayerOpenAI = "PromptLayerChatOpenAI"
|
|
||||||
QwQ = "ChatQwQ"
|
|
||||||
Reka = "ChatReka"
|
|
||||||
RunPod = "ChatRunPod"
|
|
||||||
SambaNovaCloud = "ChatSambaNovaCloud"
|
|
||||||
SambaStudio = "ChatSambaStudio"
|
|
||||||
SeekrFlow = "ChatSeekrFlow"
|
|
||||||
SnowflakeCortex = "ChatSnowflakeCortex"
|
|
||||||
Solar = "ChatSolar"
|
|
||||||
SparkLLM = "ChatSparkLLM"
|
|
||||||
Nebula = "ChatNebula"
|
|
||||||
Hunyuan = "ChatHunyuan"
|
|
||||||
Together = "ChatTogether"
|
|
||||||
TongyiQwen = "ChatTongyiQwen"
|
|
||||||
Upstage = "ChatUpstage"
|
|
||||||
Vectara = "ChatVectara"
|
|
||||||
VLLM = "ChatVLLM"
|
|
||||||
VolcEngine = "ChatVolcEngine"
|
|
||||||
Writer = "ChatWriter"
|
|
||||||
xAI = "ChatXAI"
|
|
||||||
Xinference = "ChatXinference"
|
|
||||||
Yandex = "ChatYandex"
|
|
||||||
Yi = "ChatYi"
|
|
||||||
Yuan2 = "ChatYuan2"
|
|
||||||
ZhipuAI = "ChatZhipuAI"
|
|
||||||
|
|
||||||
|
|
||||||
class LangchainLLM(LLMBase):
|
class LangchainLLM(LLMBase):
|
||||||
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
provider = self.config.langchain_provider
|
if self.config.model is None:
|
||||||
if provider not in LangchainProvider.__members__:
|
raise ValueError("`model` parameter is required")
|
||||||
raise ValueError(f"Invalid provider: {provider}")
|
|
||||||
model_name = LangchainProvider[provider].value
|
|
||||||
|
|
||||||
try:
|
if not isinstance(self.config.model, BaseChatModel):
|
||||||
# Check if this provider needs a specialized package
|
raise ValueError("`model` must be an instance of BaseChatModel")
|
||||||
if provider in PROVIDER_PACKAGES:
|
|
||||||
if provider == "Anthropic": # Special handling for Anthropic with Pydantic v2
|
|
||||||
try:
|
|
||||||
from langchain_anthropic import ChatAnthropic
|
|
||||||
model_class = ChatAnthropic
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError("langchain_anthropic not found. Please install it with `pip install langchain-anthropic`")
|
|
||||||
else:
|
|
||||||
package_name = PROVIDER_PACKAGES[provider]
|
|
||||||
try:
|
|
||||||
# Import the model class directly from the package
|
|
||||||
module_path = f"{package_name}"
|
|
||||||
model_class = __import__(module_path, fromlist=[model_name])
|
|
||||||
model_class = getattr(model_class, model_name)
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError(
|
|
||||||
f"Package {package_name} not found. " f"Please install it with `pip install {package_name}`"
|
|
||||||
)
|
|
||||||
except AttributeError:
|
|
||||||
raise ImportError(f"Model {model_name} not found in {package_name}")
|
|
||||||
else:
|
|
||||||
# Use the default langchain_community module
|
|
||||||
if not hasattr(chat_models, model_name):
|
|
||||||
raise ImportError(f"Provider {provider} not found in langchain_community.chat_models")
|
|
||||||
|
|
||||||
model_class = getattr(chat_models, model_name)
|
self.langchain_model = self.config.model
|
||||||
|
|
||||||
# Initialize the model with relevant config parameters
|
|
||||||
self.langchain_model = model_class(
|
|
||||||
model=self.config.model,
|
|
||||||
temperature=self.config.temperature,
|
|
||||||
max_tokens=self.config.max_tokens
|
|
||||||
)
|
|
||||||
except (ImportError, AttributeError, ValueError) as e:
|
|
||||||
raise ImportError(f"Error setting up langchain model for provider {provider}: {str(e)}")
|
|
||||||
|
|
||||||
def generate_response(
|
def generate_response(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -623,14 +623,13 @@ class Memory(MemoryBase):
|
|||||||
capture_event("mem0._create_memory", self, {"memory_id": memory_id})
|
capture_event("mem0._create_memory", self, {"memory_id": memory_id})
|
||||||
return memory_id
|
return memory_id
|
||||||
|
|
||||||
def _create_procedural_memory(self, messages, metadata=None, llm=None, prompt=None):
|
def _create_procedural_memory(self, messages, metadata=None, prompt=None):
|
||||||
"""
|
"""
|
||||||
Create a procedural memory
|
Create a procedural memory
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages (list): List of messages to create a procedural memory from.
|
messages (list): List of messages to create a procedural memory from.
|
||||||
metadata (dict): Metadata to create a procedural memory from.
|
metadata (dict): Metadata to create a procedural memory from.
|
||||||
llm (BaseChatModel, optional): LLM class to use for generating procedural memories. Defaults to None. Useful when user is using LangChain ChatModel.
|
|
||||||
prompt (str, optional): Prompt to use for the procedural memory creation. Defaults to None.
|
prompt (str, optional): Prompt to use for the procedural memory creation. Defaults to None.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
@@ -650,12 +649,7 @@ class Memory(MemoryBase):
|
|||||||
]
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if llm is not None:
|
procedural_memory = self.llm.generate_response(messages=parsed_messages)
|
||||||
parsed_messages = convert_to_messages(parsed_messages)
|
|
||||||
response = llm.invoke(input=parsed_messages)
|
|
||||||
procedural_memory = response.content
|
|
||||||
else:
|
|
||||||
procedural_memory = self.llm.generate_response(messages=parsed_messages)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error generating procedural memory summary: {e}")
|
logger.error(f"Error generating procedural memory summary: {e}")
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -50,6 +50,7 @@ class EmbedderFactory:
|
|||||||
"vertexai": "mem0.embeddings.vertexai.VertexAIEmbedding",
|
"vertexai": "mem0.embeddings.vertexai.VertexAIEmbedding",
|
||||||
"together": "mem0.embeddings.together.TogetherEmbedding",
|
"together": "mem0.embeddings.together.TogetherEmbedding",
|
||||||
"lmstudio": "mem0.embeddings.lmstudio.LMStudioEmbedding",
|
"lmstudio": "mem0.embeddings.lmstudio.LMStudioEmbedding",
|
||||||
|
"langchain": "mem0.embeddings.langchain.LangchainEmbedding",
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "mem0ai"
|
name = "mem0ai"
|
||||||
version = "0.1.83"
|
version = "0.1.84"
|
||||||
description = "Long-term memory for AI Agents"
|
description = "Long-term memory for AI Agents"
|
||||||
authors = ["Mem0 <founders@mem0.ai>"]
|
authors = ["Mem0 <founders@mem0.ai>"]
|
||||||
exclude = [
|
exclude = [
|
||||||
|
|||||||
@@ -4,97 +4,99 @@ import pytest
|
|||||||
from mem0.configs.llms.base import BaseLlmConfig
|
from mem0.configs.llms.base import BaseLlmConfig
|
||||||
from mem0.llms.langchain import LangchainLLM
|
from mem0.llms.langchain import LangchainLLM
|
||||||
|
|
||||||
|
# Add the import for BaseChatModel
|
||||||
|
try:
|
||||||
|
from langchain.chat_models.base import BaseChatModel
|
||||||
|
except ImportError:
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
BaseChatModel = MagicMock
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_langchain_model():
|
def mock_langchain_model():
|
||||||
"""Mock a Langchain model for testing."""
|
"""Mock a Langchain model for testing."""
|
||||||
with patch("langchain_openai.ChatOpenAI") as mock_chat_model:
|
mock_model = Mock(spec=BaseChatModel)
|
||||||
mock_model = Mock()
|
mock_model.invoke.return_value = Mock(content="This is a test response")
|
||||||
mock_model.invoke.return_value = Mock(content="This is a test response")
|
return mock_model
|
||||||
mock_chat_model.return_value = mock_model
|
|
||||||
yield mock_model
|
|
||||||
|
|
||||||
|
|
||||||
def test_langchain_initialization():
|
def test_langchain_initialization(mock_langchain_model):
|
||||||
"""Test that LangchainLLM initializes correctly with a valid provider."""
|
"""Test that LangchainLLM initializes correctly with a valid model."""
|
||||||
with patch("langchain_openai.ChatOpenAI") as mock_chat_model:
|
# Create a config with the model instance directly
|
||||||
# Setup the mock model
|
config = BaseLlmConfig(
|
||||||
mock_model = Mock()
|
model=mock_langchain_model,
|
||||||
mock_chat_model.return_value = mock_model
|
temperature=0.7,
|
||||||
|
max_tokens=100,
|
||||||
# Create a config with OpenAI provider
|
api_key="test-api-key"
|
||||||
config = BaseLlmConfig(
|
)
|
||||||
model="gpt-3.5-turbo",
|
|
||||||
temperature=0.7,
|
# Initialize the LangchainLLM
|
||||||
max_tokens=100,
|
llm = LangchainLLM(config)
|
||||||
api_key="test-api-key",
|
|
||||||
langchain_provider="OpenAI"
|
# Verify the model was correctly assigned
|
||||||
)
|
assert llm.langchain_model == mock_langchain_model
|
||||||
|
|
||||||
# Initialize the LangchainLLM
|
|
||||||
llm = LangchainLLM(config)
|
|
||||||
|
|
||||||
# Verify the model was initialized with correct parameters
|
|
||||||
mock_chat_model.assert_called_once_with(
|
|
||||||
model="gpt-3.5-turbo",
|
|
||||||
temperature=0.7,
|
|
||||||
max_tokens=100,
|
|
||||||
api_key="test-api-key"
|
|
||||||
)
|
|
||||||
|
|
||||||
assert llm.langchain_model == mock_model
|
|
||||||
|
|
||||||
|
|
||||||
def test_generate_response(mock_langchain_model):
|
def test_generate_response(mock_langchain_model):
|
||||||
"""Test that generate_response correctly processes messages and returns a response."""
|
"""Test that generate_response correctly processes messages and returns a response."""
|
||||||
# Create a config with OpenAI provider
|
# Create a config with the model instance
|
||||||
config = BaseLlmConfig(
|
config = BaseLlmConfig(
|
||||||
model="gpt-3.5-turbo",
|
model=mock_langchain_model,
|
||||||
temperature=0.7,
|
temperature=0.7,
|
||||||
max_tokens=100,
|
max_tokens=100,
|
||||||
api_key="test-api-key",
|
api_key="test-api-key"
|
||||||
langchain_provider="OpenAI"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the LangchainLLM
|
# Initialize the LangchainLLM
|
||||||
with patch("langchain_openai.ChatOpenAI", return_value=mock_langchain_model):
|
llm = LangchainLLM(config)
|
||||||
llm = LangchainLLM(config)
|
|
||||||
|
# Create test messages
|
||||||
# Create test messages
|
messages = [
|
||||||
messages = [
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
{"role": "system", "content": "You are a helpful assistant."},
|
{"role": "user", "content": "Hello, how are you?"},
|
||||||
{"role": "user", "content": "Hello, how are you?"},
|
{"role": "assistant", "content": "I'm doing well! How can I help you?"},
|
||||||
{"role": "assistant", "content": "I'm doing well! How can I help you?"},
|
{"role": "user", "content": "Tell me a joke."}
|
||||||
{"role": "user", "content": "Tell me a joke."}
|
]
|
||||||
]
|
|
||||||
|
# Get response
|
||||||
# Get response
|
response = llm.generate_response(messages)
|
||||||
response = llm.generate_response(messages)
|
|
||||||
|
# Verify the correct message format was passed to the model
|
||||||
# Verify the correct message format was passed to the model
|
expected_langchain_messages = [
|
||||||
expected_langchain_messages = [
|
("system", "You are a helpful assistant."),
|
||||||
("system", "You are a helpful assistant."),
|
("human", "Hello, how are you?"),
|
||||||
("human", "Hello, how are you?"),
|
("ai", "I'm doing well! How can I help you?"),
|
||||||
("ai", "I'm doing well! How can I help you?"),
|
("human", "Tell me a joke.")
|
||||||
("human", "Tell me a joke.")
|
]
|
||||||
]
|
|
||||||
|
mock_langchain_model.invoke.assert_called_once()
|
||||||
mock_langchain_model.invoke.assert_called_once()
|
# Extract the first argument of the first call
|
||||||
# Extract the first argument of the first call
|
actual_messages = mock_langchain_model.invoke.call_args[0][0]
|
||||||
actual_messages = mock_langchain_model.invoke.call_args[0][0]
|
assert actual_messages == expected_langchain_messages
|
||||||
assert actual_messages == expected_langchain_messages
|
assert response == "This is a test response"
|
||||||
assert response == "This is a test response"
|
|
||||||
|
|
||||||
|
|
||||||
def test_invalid_provider():
|
def test_invalid_model():
|
||||||
"""Test that LangchainLLM raises an error with an invalid provider."""
|
"""Test that LangchainLLM raises an error with an invalid model."""
|
||||||
config = BaseLlmConfig(
|
config = BaseLlmConfig(
|
||||||
model="test-model",
|
model="not-a-valid-model-instance",
|
||||||
temperature=0.7,
|
temperature=0.7,
|
||||||
max_tokens=100,
|
max_tokens=100,
|
||||||
api_key="test-api-key",
|
api_key="test-api-key"
|
||||||
langchain_provider="InvalidProvider"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Invalid provider: InvalidProvider"):
|
with pytest.raises(ValueError, match="`model` must be an instance of BaseChatModel"):
|
||||||
|
LangchainLLM(config)
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_model():
|
||||||
|
"""Test that LangchainLLM raises an error when model is None."""
|
||||||
|
config = BaseLlmConfig(
|
||||||
|
model=None,
|
||||||
|
temperature=0.7,
|
||||||
|
max_tokens=100,
|
||||||
|
api_key="test-api-key"
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="`model` parameter is required"):
|
||||||
LangchainLLM(config)
|
LangchainLLM(config)
|
||||||
|
|||||||
Reference in New Issue
Block a user