Support for langchain LLMs (#2506)

Dev Khant
2025-04-07 11:28:30 +05:30
committed by GitHub
parent d30c78c5eb
commit 39e5cbfacc
9 changed files with 393 additions and 1 deletion


@@ -25,6 +25,7 @@ class LlmConfig(BaseModel):
             "deepseek",
             "xai",
             "lmstudio",
+            "langchain",
         ):
             return v
         else:
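
With "langchain" accepted as an LLM provider, a mem0 config can route generation through any LangChain chat model. A minimal sketch with hypothetical values; the `langchain_provider` key mirrors the field read by `mem0/llms/langchain.py` below, and the rest of the wiring lives in this commit's other changed files:

    from mem0 import Memory

    config = {
        "llm": {
            "provider": "langchain",  # newly accepted by the validator above
            "config": {
                "langchain_provider": "OpenAI",  # resolves to ChatOpenAI (see enum below)
                "model": "gpt-4o-mini",
                "temperature": 0.1,
                "max_tokens": 1000,
                "api_key": "sk-...",
            },
        }
    }

    memory = Memory.from_config(config)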

mem0/llms/langchain.py Normal file

@@ -0,0 +1,208 @@
import enum
from typing import Dict, List, Optional

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase

# Default import for langchain_community
try:
    from langchain_community import chat_models
except ImportError:
    raise ImportError(
        "langchain_community not found. Please install it with `pip install langchain-community`"
    )

# Provider-specific package mapping
PROVIDER_PACKAGES = {
    # "Anthropic": "langchain_anthropic",  # Special handling for Anthropic with Pydantic v2
    "MistralAI": "langchain_mistralai",
    "Fireworks": "langchain_fireworks",
    "AzureOpenAI": "langchain_openai",
    "OpenAI": "langchain_openai",
    "Together": "langchain_together",
    "VertexAI": "langchain_google_vertexai",
    "GoogleAI": "langchain_google_genai",
    "Groq": "langchain_groq",
    "Cohere": "langchain_cohere",
    "Bedrock": "langchain_aws",
    "HuggingFace": "langchain_huggingface",
    "NVIDIA": "langchain_nvidia_ai_endpoints",
    "Ollama": "langchain_ollama",
    "AI21": "langchain_ai21",
    "Upstage": "langchain_upstage",
    "Databricks": "databricks_langchain",
    "Watsonx": "langchain_ibm",
    "xAI": "langchain_xai",
    "Perplexity": "langchain_perplexity",
}


class LangchainProvider(enum.Enum):
    Abso = "ChatAbso"
    AI21 = "ChatAI21"
    Alibaba = "ChatAlibabaCloud"
    Anthropic = "ChatAnthropic"
    Anyscale = "ChatAnyscale"
    AzureAIChatCompletionsModel = "AzureAIChatCompletionsModel"
    AzureOpenAI = "AzureChatOpenAI"
    AzureMLEndpoint = "ChatAzureMLEndpoint"
    Baichuan = "ChatBaichuan"
    Qianfan = "ChatQianfan"
    Bedrock = "ChatBedrock"
    Cerebras = "ChatCerebras"
    CloudflareWorkersAI = "ChatCloudflareWorkersAI"
    Cohere = "ChatCohere"
    ContextualAI = "ChatContextualAI"
    Coze = "ChatCoze"
    Dappier = "ChatDappier"
    Databricks = "ChatDatabricks"
    DeepInfra = "ChatDeepInfra"
    DeepSeek = "ChatDeepSeek"
    EdenAI = "ChatEdenAI"
    EverlyAI = "ChatEverlyAI"
    Fireworks = "ChatFireworks"
    Friendli = "ChatFriendli"
    GigaChat = "ChatGigaChat"
    Goodfire = "ChatGoodfire"
    GoogleAI = "ChatGoogleAI"
    VertexAI = "VertexAI"
    GPTRouter = "ChatGPTRouter"
    Groq = "ChatGroq"
    HuggingFace = "ChatHuggingFace"
    Watsonx = "ChatWatsonx"
    Jina = "ChatJina"
    Kinetica = "ChatKinetica"
    Konko = "ChatKonko"
    LiteLLM = "ChatLiteLLM"
    LiteLLMRouter = "ChatLiteLLMRouter"
    Llama2Chat = "Llama2Chat"
    LlamaAPI = "ChatLlamaAPI"
    LlamaEdge = "ChatLlamaEdge"
    LlamaCpp = "ChatLlamaCpp"
    Maritalk = "ChatMaritalk"
    MiniMax = "ChatMiniMax"
    MistralAI = "ChatMistralAI"
    MLX = "ChatMLX"
    ModelScope = "ChatModelScope"
    Moonshot = "ChatMoonshot"
    Naver = "ChatNaver"
    Netmind = "ChatNetmind"
    NVIDIA = "ChatNVIDIA"
    OCIModelDeployment = "ChatOCIModelDeployment"
    OCIGenAI = "ChatOCIGenAI"
    OctoAI = "ChatOctoAI"
    Ollama = "ChatOllama"
    OpenAI = "ChatOpenAI"
    Outlines = "ChatOutlines"
    Perplexity = "ChatPerplexity"
    Pipeshift = "ChatPipeshift"
    PredictionGuard = "ChatPredictionGuard"
    PremAI = "ChatPremAI"
    PromptLayerOpenAI = "PromptLayerChatOpenAI"
    QwQ = "ChatQwQ"
    Reka = "ChatReka"
    RunPod = "ChatRunPod"
    SambaNovaCloud = "ChatSambaNovaCloud"
    SambaStudio = "ChatSambaStudio"
    SeekrFlow = "ChatSeekrFlow"
    SnowflakeCortex = "ChatSnowflakeCortex"
    Solar = "ChatSolar"
    SparkLLM = "ChatSparkLLM"
    Nebula = "ChatNebula"
    Hunyuan = "ChatHunyuan"
    Together = "ChatTogether"
    TongyiQwen = "ChatTongyiQwen"
    Upstage = "ChatUpstage"
    Vectara = "ChatVectara"
    VLLM = "ChatVLLM"
    VolcEngine = "ChatVolcEngine"
    Writer = "ChatWriter"
    xAI = "ChatXAI"
    Xinference = "ChatXinference"
    Yandex = "ChatYandex"
    Yi = "ChatYi"
    Yuan2 = "ChatYuan2"
    ZhipuAI = "ChatZhipuAI"
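
# Illustration (not part of the committed file): the enum and the mapping above
# are combined at runtime roughly as follows, assuming importlib semantics match
# the __import__ call in LangchainLLM.__init__ below:
#
#     import importlib
#     provider = "MistralAI"
#     model_name = LangchainProvider[provider].value  # "ChatMistralAI"
#     pkg = PROVIDER_PACKAGES.get(provider, "langchain_community.chat_models")
#     model_class = getattr(importlib.import_module(pkg), model_name)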


class LangchainLLM(LLMBase):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        super().__init__(config)

        provider = self.config.langchain_provider
        if provider not in LangchainProvider.__members__:
            raise ValueError(f"Invalid provider: {provider}")
        model_name = LangchainProvider[provider].value

        try:
            # Check if this provider needs a specialized package
            if provider in PROVIDER_PACKAGES:
                package_name = PROVIDER_PACKAGES[provider]
                try:
                    # Import the model class directly from the provider package
                    module = __import__(package_name, fromlist=[model_name])
                    model_class = getattr(module, model_name)
                except ImportError:
                    raise ImportError(
                        f"Package {package_name} not found. "
                        f"Please install it with `pip install {package_name}`"
                    )
                except AttributeError:
                    raise ImportError(f"Model {model_name} not found in {package_name}")
            else:
                # Use the default langchain_community module
                if not hasattr(chat_models, model_name):
                    raise ImportError(f"Provider {provider} not found in langchain_community.chat_models")
                model_class = getattr(chat_models, model_name)

            # Initialize the model with relevant config parameters
            self.langchain_model = model_class(
                model=self.config.model,
                temperature=self.config.temperature,
                max_tokens=self.config.max_tokens,
                api_key=self.config.api_key,
            )
        except (ImportError, AttributeError, ValueError) as e:
            raise ImportError(f"Error setting up langchain model for provider {provider}: {str(e)}")

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        response_format=None,
        tools: Optional[List[Dict]] = None,
        tool_choice: str = "auto",
    ):
        """
        Generate a response based on the given messages using the configured LangChain chat model.

        Args:
            messages (list): List of message dicts containing 'role' and 'content'.
            response_format (str or object, optional): Format of the response. Not used in Langchain.
            tools (list, optional): List of tools that the model can call. Not used in Langchain.
            tool_choice (str, optional): Tool choice method. Not used in Langchain.

        Returns:
            str: The generated response.
        """
        try:
            # Convert the messages to LangChain's (role, content) tuple format
            langchain_messages = []
            for message in messages:
                role = message["role"]
                content = message["content"]

                if role == "system":
                    langchain_messages.append(("system", content))
                elif role == "user":
                    langchain_messages.append(("human", content))
                elif role == "assistant":
                    langchain_messages.append(("ai", content))

            if not langchain_messages:
                raise ValueError("No valid messages found in the messages list")

            ai_message = self.langchain_model.invoke(langchain_messages)
            return ai_message.content
        except Exception as e:
            raise Exception(f"Error generating response using langchain model: {str(e)}")
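
Taken together, a hypothetical end-to-end use of the new class, assuming `langchain-openai` is installed and that `BaseLlmConfig` carries the `langchain_provider` field added elsewhere in this commit's nine changed files:

    from mem0.configs.llms.base import BaseLlmConfig
    from mem0.llms.langchain import LangchainLLM

    config = BaseLlmConfig(
        model="gpt-4o-mini",
        temperature=0.1,
        max_tokens=512,
        api_key="sk-...",
    )
    config.langchain_provider = "OpenAI"  # must be a LangchainProvider member name

    llm = LangchainLLM(config)
    answer = llm.generate_response(
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Say hello in one word."},
        ]
    )
    print(answer)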