Add langchain embedding, update langchain LLM and version bump -> 0.1.84 (#2510)

Dev Khant
2025-04-07 15:27:26 +05:30
committed by GitHub
parent 5509066925
commit 9dfa9b4412
14 changed files with 266 additions and 253 deletions

View File

@@ -13,7 +13,7 @@ class BaseLlmConfig(ABC):
     def __init__(
         self,
-        model: Optional[str] = None,
+        model: Optional[Union[str, Dict]] = None,
         temperature: float = 0.1,
         api_key: Optional[str] = None,
         max_tokens: int = 2000,
@@ -41,8 +41,6 @@ class BaseLlmConfig(ABC):
         xai_base_url: Optional[str] = None,
         # LM Studio specific
         lmstudio_base_url: Optional[str] = "http://localhost:1234/v1",
-        # Langchain specific
-        langchain_provider: Optional[str] = None,
     ):
"""
Initializes a configuration class instance for the LLM.
@@ -89,8 +87,6 @@ class BaseLlmConfig(ABC):
         :type xai_base_url: Optional[str], optional
         :param lmstudio_base_url: LM Studio base URL to be used, defaults to "http://localhost:1234/v1"
         :type lmstudio_base_url: Optional[str], optional
-        :param langchain_provider: Langchain provider to be used, defaults to None
-        :type langchain_provider: Optional[str], optional
         """
         self.model = model
@@ -127,6 +123,3 @@ class BaseLlmConfig(ABC):
         # LM Studio specific
         self.lmstudio_base_url = lmstudio_base_url
-        # Langchain specific
-        self.langchain_provider = langchain_provider
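
With `langchain_provider` removed, the Langchain integration is selected by the provider key alone and `model` carries the Langchain object itself, which is why its type is widened above. A minimal config sketch under that assumption (the chat model and its name are illustrative; requires `langchain-openai`):

    from langchain_openai import ChatOpenAI

    # Hypothetical mem0 config dict: `model` is the chat-model instance,
    # not a model-name string.
    llm_config = {
        "provider": "langchain",
        "config": {
            "model": ChatOpenAI(model="gpt-4o-mini"),
            "temperature": 0.1,
            "max_tokens": 2000,
        },
    }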

View File

@@ -22,6 +22,7 @@ class EmbedderConfig(BaseModel):
"vertexai",
"together",
"lmstudio",
"langchain",
]:
return v
else:

View File

@@ -0,0 +1,36 @@
+import os
+from typing import Literal, Optional
+
+from mem0.configs.embeddings.base import BaseEmbedderConfig
+from mem0.embeddings.base import EmbeddingBase
+
+try:
+    from langchain.embeddings.base import Embeddings
+except ImportError:
+    raise ImportError("langchain is not installed. Please install it using `pip install langchain`")
+
+
+class LangchainEmbedding(EmbeddingBase):
+    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
+        super().__init__(config)
+
+        if self.config.model is None:
+            raise ValueError("`model` parameter is required")
+
+        if not isinstance(self.config.model, Embeddings):
+            raise ValueError("`model` must be an instance of Embeddings")
+
+        self.langchain_model = self.config.model
+
+    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
+        """
+        Get the embedding for the given text using Langchain.
+
+        Args:
+            text (str): The text to embed.
+            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
+
+        Returns:
+            list: The embedding vector.
+        """
+        return self.langchain_model.embed_query(text)
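
A usage sketch for the new wrapper, assuming `BaseEmbedderConfig` accepts the instance through its `model` field; `OpenAIEmbeddings` and the model name are illustrative and require `langchain-openai`:

    from langchain_openai import OpenAIEmbeddings  # any Langchain `Embeddings` subclass

    from mem0.configs.embeddings.base import BaseEmbedderConfig
    from mem0.embeddings.langchain import LangchainEmbedding

    # The wrapper only validates the instance and delegates to embed_query().
    config = BaseEmbedderConfig(model=OpenAIEmbeddings(model="text-embedding-3-small"))
    embedder = LangchainEmbedding(config)
    vector = embedder.embed("What is my favourite food?")  # -> list[float]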

View File

@@ -1,174 +1,25 @@
 from typing import Dict, List, Optional
-import enum
 
 from mem0.configs.llms.base import BaseLlmConfig
 from mem0.llms.base import LLMBase
 
-# Default import for langchain_community
 try:
-    from langchain_community import chat_models
     from langchain.chat_models.base import BaseChatModel
 except ImportError:
-    raise ImportError("langchain_community not found. Please install it with `pip install langchain-community`")
-
-# Provider-specific package mapping
-PROVIDER_PACKAGES = {
-    "Anthropic": "langchain_anthropic",
-    "MistralAI": "langchain_mistralai",
-    "Fireworks": "langchain_fireworks",
-    "AzureOpenAI": "langchain_openai",
-    "OpenAI": "langchain_openai",
-    "Together": "langchain_together",
-    "VertexAI": "langchain_google_vertexai",
-    "GoogleAI": "langchain_google_genai",
-    "Groq": "langchain_groq",
-    "Cohere": "langchain_cohere",
-    "Bedrock": "langchain_aws",
-    "HuggingFace": "langchain_huggingface",
-    "NVIDIA": "langchain_nvidia_ai_endpoints",
-    "Ollama": "langchain_ollama",
-    "AI21": "langchain_ai21",
-    "Upstage": "langchain_upstage",
-    "Databricks": "databricks_langchain",
-    "Watsonx": "langchain_ibm",
-    "xAI": "langchain_xai",
-    "Perplexity": "langchain_perplexity",
-}
-
-class LangchainProvider(enum.Enum):
-    Abso = "ChatAbso"
-    AI21 = "ChatAI21"
-    Alibaba = "ChatAlibabaCloud"
-    Anthropic = "ChatAnthropic"
-    Anyscale = "ChatAnyscale"
-    AzureAIChatCompletionsModel = "AzureAIChatCompletionsModel"
-    AzureOpenAI = "AzureChatOpenAI"
-    AzureMLEndpoint = "ChatAzureMLEndpoint"
-    Baichuan = "ChatBaichuan"
-    Qianfan = "ChatQianfan"
-    Bedrock = "ChatBedrock"
-    Cerebras = "ChatCerebras"
-    CloudflareWorkersAI = "ChatCloudflareWorkersAI"
-    Cohere = "ChatCohere"
-    ContextualAI = "ChatContextualAI"
-    Coze = "ChatCoze"
-    Dappier = "ChatDappier"
-    Databricks = "ChatDatabricks"
-    DeepInfra = "ChatDeepInfra"
-    DeepSeek = "ChatDeepSeek"
-    EdenAI = "ChatEdenAI"
-    EverlyAI = "ChatEverlyAI"
-    Fireworks = "ChatFireworks"
-    Friendli = "ChatFriendli"
-    GigaChat = "ChatGigaChat"
-    Goodfire = "ChatGoodfire"
-    GoogleAI = "ChatGoogleAI"
-    VertexAI = "VertexAI"
-    GPTRouter = "ChatGPTRouter"
-    Groq = "ChatGroq"
-    HuggingFace = "ChatHuggingFace"
-    Watsonx = "ChatWatsonx"
-    Jina = "ChatJina"
-    Kinetica = "ChatKinetica"
-    Konko = "ChatKonko"
-    LiteLLM = "ChatLiteLLM"
-    LiteLLMRouter = "ChatLiteLLMRouter"
-    Llama2Chat = "Llama2Chat"
-    LlamaAPI = "ChatLlamaAPI"
-    LlamaEdge = "ChatLlamaEdge"
-    LlamaCpp = "ChatLlamaCpp"
-    Maritalk = "ChatMaritalk"
-    MiniMax = "ChatMiniMax"
-    MistralAI = "ChatMistralAI"
-    MLX = "ChatMLX"
-    ModelScope = "ChatModelScope"
-    Moonshot = "ChatMoonshot"
-    Naver = "ChatNaver"
-    Netmind = "ChatNetmind"
-    NVIDIA = "ChatNVIDIA"
-    OCIModelDeployment = "ChatOCIModelDeployment"
-    OCIGenAI = "ChatOCIGenAI"
-    OctoAI = "ChatOctoAI"
-    Ollama = "ChatOllama"
-    OpenAI = "ChatOpenAI"
-    Outlines = "ChatOutlines"
-    Perplexity = "ChatPerplexity"
-    Pipeshift = "ChatPipeshift"
-    PredictionGuard = "ChatPredictionGuard"
-    PremAI = "ChatPremAI"
-    PromptLayerOpenAI = "PromptLayerChatOpenAI"
-    QwQ = "ChatQwQ"
-    Reka = "ChatReka"
-    RunPod = "ChatRunPod"
-    SambaNovaCloud = "ChatSambaNovaCloud"
-    SambaStudio = "ChatSambaStudio"
-    SeekrFlow = "ChatSeekrFlow"
-    SnowflakeCortex = "ChatSnowflakeCortex"
-    Solar = "ChatSolar"
-    SparkLLM = "ChatSparkLLM"
-    Nebula = "ChatNebula"
-    Hunyuan = "ChatHunyuan"
-    Together = "ChatTogether"
-    TongyiQwen = "ChatTongyiQwen"
-    Upstage = "ChatUpstage"
-    Vectara = "ChatVectara"
-    VLLM = "ChatVLLM"
-    VolcEngine = "ChatVolcEngine"
-    Writer = "ChatWriter"
-    xAI = "ChatXAI"
-    Xinference = "ChatXinference"
-    Yandex = "ChatYandex"
-    Yi = "ChatYi"
-    Yuan2 = "ChatYuan2"
-    ZhipuAI = "ChatZhipuAI"
+    raise ImportError("langchain is not installed. Please install it using `pip install langchain`")
 
 
 class LangchainLLM(LLMBase):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
         super().__init__(config)
-        provider = self.config.langchain_provider
-        if provider not in LangchainProvider.__members__:
-            raise ValueError(f"Invalid provider: {provider}")
-        model_name = LangchainProvider[provider].value
 
         if self.config.model is None:
             raise ValueError("`model` parameter is required")
-        try:
-            # Check if this provider needs a specialized package
-            if provider in PROVIDER_PACKAGES:
-                if provider == "Anthropic":  # Special handling for Anthropic with Pydantic v2
-                    try:
-                        from langchain_anthropic import ChatAnthropic
-
-                        model_class = ChatAnthropic
-                    except ImportError:
-                        raise ImportError("langchain_anthropic not found. Please install it with `pip install langchain-anthropic`")
-                else:
-                    package_name = PROVIDER_PACKAGES[provider]
-                    try:
-                        # Import the model class directly from the package
-                        module_path = f"{package_name}"
-                        model_class = __import__(module_path, fromlist=[model_name])
-                        model_class = getattr(model_class, model_name)
-                    except ImportError:
-                        raise ImportError(
-                            f"Package {package_name} not found. Please install it with `pip install {package_name}`"
-                        )
-                    except AttributeError:
-                        raise ImportError(f"Model {model_name} not found in {package_name}")
-            else:
-                # Use the default langchain_community module
-                if not hasattr(chat_models, model_name):
-                    raise ImportError(f"Provider {provider} not found in langchain_community.chat_models")
-                model_class = getattr(chat_models, model_name)
-
-            # Initialize the model with relevant config parameters
-            self.langchain_model = model_class(
-                model=self.config.model,
-                temperature=self.config.temperature,
-                max_tokens=self.config.max_tokens
-            )
-        except (ImportError, AttributeError, ValueError) as e:
-            raise ImportError(f"Error setting up langchain model for provider {provider}: {str(e)}")
+
+        if not isinstance(self.config.model, BaseChatModel):
+            raise ValueError("`model` must be an instance of BaseChatModel")
+
+        self.langchain_model = self.config.model
 
     def generate_response(
         self,
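
A usage sketch for the slimmed-down wrapper, assuming `BaseLlmConfig` can be instantiated directly here (it declares no abstract methods in the hunk above); `ChatAnthropic` and the model name are illustrative and require `langchain-anthropic`:

    from langchain_anthropic import ChatAnthropic  # any `BaseChatModel` subclass

    from mem0.configs.llms.base import BaseLlmConfig
    from mem0.llms.langchain import LangchainLLM

    # No provider enum or per-provider package lookup any more: the
    # already-constructed Langchain instance is the provider.
    llm = LangchainLLM(BaseLlmConfig(model=ChatAnthropic(model="claude-3-5-haiku-latest")))
    reply = llm.generate_response(messages=[{"role": "user", "content": "Hello"}])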

View File

@@ -623,14 +623,13 @@ class Memory(MemoryBase):
         capture_event("mem0._create_memory", self, {"memory_id": memory_id})
         return memory_id
 
-    def _create_procedural_memory(self, messages, metadata=None, llm=None, prompt=None):
+    def _create_procedural_memory(self, messages, metadata=None, prompt=None):
         """
         Create a procedural memory
 
         Args:
             messages (list): List of messages to create a procedural memory from.
             metadata (dict): Metadata to create a procedural memory from.
-            llm (BaseChatModel, optional): LLM class to use for generating procedural memories. Defaults to None. Useful when user is using LangChain ChatModel.
             prompt (str, optional): Prompt to use for the procedural memory creation. Defaults to None.
         """
         try:
@@ -650,12 +649,7 @@ class Memory(MemoryBase):
         ]
 
         try:
-            if llm is not None:
-                parsed_messages = convert_to_messages(parsed_messages)
-                response = llm.invoke(input=parsed_messages)
-                procedural_memory = response.content
-            else:
-                procedural_memory = self.llm.generate_response(messages=parsed_messages)
+            procedural_memory = self.llm.generate_response(messages=parsed_messages)
         except Exception as e:
             logger.error(f"Error generating procedural memory summary: {e}")
             raise
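
With the per-call `llm=` override gone, a Langchain chat model reaches procedural-memory generation only via the LLM configured on the `Memory` instance. A hedged sketch (config keys assumed from the factory wiring; model name, ids, and the `memory_type` value are illustrative):

    from langchain_openai import ChatOpenAI

    from mem0 import Memory

    m = Memory.from_config({
        "llm": {"provider": "langchain", "config": {"model": ChatOpenAI(model="gpt-4o-mini")}},
    })
    # self.llm.generate_response(...) above now runs on the configured model.
    m.add(
        messages=[{"role": "user", "content": "First build the image, then run the deploy script."}],
        agent_id="deploy-agent",
        memory_type="procedural_memory",
    )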

View File

@@ -50,6 +50,7 @@ class EmbedderFactory:
"vertexai": "mem0.embeddings.vertexai.VertexAIEmbedding",
"together": "mem0.embeddings.together.TogetherEmbedding",
"lmstudio": "mem0.embeddings.lmstudio.LMStudioEmbedding",
"langchain": "mem0.embeddings.langchain.LangchainEmbedding",
}
@classmethod
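
End-to-end, the new factory entry lets a config choose Langchain on the embedder side as well. A sketch assuming the standard `Memory.from_config` path (embedding class and model name are illustrative; requires `langchain-huggingface`):

    from langchain_huggingface import HuggingFaceEmbeddings

    from mem0 import Memory

    m = Memory.from_config({
        "embedder": {
            # Resolved by EmbedderFactory to mem0.embeddings.langchain.LangchainEmbedding
            "provider": "langchain",
            "config": {"model": HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")},
        },
    })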