Add support for sarvam-m model (#2802)
@@ -39,6 +39,8 @@ class BaseLlmConfig(ABC):
        deepseek_base_url: Optional[str] = None,
        # XAI specific
        xai_base_url: Optional[str] = None,
        # Sarvam specific
        sarvam_base_url: Optional[str] = "https://api.sarvam.ai/v1",
        # LM Studio specific
        lmstudio_base_url: Optional[str] = "http://localhost:1234/v1",
        # AWS Bedrock specific
@@ -89,6 +91,8 @@ class BaseLlmConfig(ABC):
        :type deepseek_base_url: Optional[str], optional
        :param xai_base_url: XAI base URL to be used, defaults to None
        :type xai_base_url: Optional[str], optional
        :param sarvam_base_url: Sarvam base URL to be used, defaults to "https://api.sarvam.ai/v1"
        :type sarvam_base_url: Optional[str], optional
        :param lmstudio_base_url: LM Studio base URL to be used, defaults to "http://localhost:1234/v1"
        :type lmstudio_base_url: Optional[str], optional
        """
@@ -125,6 +129,9 @@ class BaseLlmConfig(ABC):
        # XAI specific
        self.xai_base_url = xai_base_url

        # Sarvam specific
        self.sarvam_base_url = sarvam_base_url

        # LM Studio specific
        self.lmstudio_base_url = lmstudio_base_url

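With these fields in place, callers can point the client at a different endpoint. A minimal sketch (not part of this commit), assuming BaseLlmConfig accepts these keyword arguments directly, as the signature above suggests:

# Sketch: overriding the default Sarvam endpoint via the config object.
config = BaseLlmConfig(
    model="sarvam-m",
    temperature=0.1,
    sarvam_base_url="https://api.sarvam.ai/v1",  # override if routing through a proxy
)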
@@ -24,6 +24,7 @@ class LlmConfig(BaseModel):
            "gemini",
            "deepseek",
            "xai",
            "sarvam",
            "lmstudio",
            "langchain",
        ):
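Once "sarvam" passes the provider validator, the usual dict-based config can select it. A sketch, assuming mem0's standard Memory.from_config entry point and the same config shape used by the other providers in this list:

# Sketch: selecting the new provider through mem0's dict-based config.
from mem0 import Memory

memory = Memory.from_config({
    "llm": {
        "provider": "sarvam",
        "config": {"model": "sarvam-m", "temperature": 0.1},
    }
})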
mem0/llms/sarvam.py (new file, 100 lines)
@@ -0,0 +1,100 @@
import os
from typing import Dict, List, Optional

import requests

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase


class SarvamLLM(LLMBase):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        super().__init__(config)

        # Set default model if not provided
        if not self.config.model:
            self.config.model = "sarvam-m"

        # Get API key from config or environment variable
        self.api_key = self.config.api_key or os.getenv("SARVAM_API_KEY")

        if not self.api_key:
            raise ValueError(
                "Sarvam API key is required. Set SARVAM_API_KEY environment variable "
                "or provide api_key in config."
            )

        # Set base URL: config value, then environment, then default
        self.base_url = (
            getattr(self.config, "sarvam_base_url", None)
            or os.getenv("SARVAM_API_BASE")
            or "https://api.sarvam.ai/v1"
        )

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        response_format=None,
    ) -> str:
        """
        Generate a response based on the given messages using Sarvam-M.

        Args:
            messages (list): List of message dicts containing 'role' and 'content'.
            response_format (str or object, optional): Format of the response.
                Currently not used by the Sarvam API.

        Returns:
            str: The generated response.
        """
        url = f"{self.base_url}/chat/completions"

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        # Prepare the request payload
        params = {
            "messages": messages,
            "model": self.config.model if isinstance(self.config.model, str) else "sarvam-m",
        }

        # Add standard parameters that already exist in BaseLlmConfig
        if self.config.temperature is not None:
            params["temperature"] = self.config.temperature

        if self.config.max_tokens is not None:
            params["max_tokens"] = self.config.max_tokens

        if self.config.top_p is not None:
            params["top_p"] = self.config.top_p

        # Handle Sarvam-specific parameters if the model is passed as a dict
        if isinstance(self.config.model, dict):
            # Extract the model name
            params["model"] = self.config.model.get("name", "sarvam-m")

            # Copy through Sarvam-specific sampling and penalty parameters
            sarvam_specific_params = [
                "reasoning_effort", "frequency_penalty", "presence_penalty",
                "seed", "stop", "n",
            ]

            for param in sarvam_specific_params:
                if param in self.config.model:
                    params[param] = self.config.model[param]

        try:
            response = requests.post(url, headers=headers, json=params, timeout=30)
            response.raise_for_status()

            result = response.json()

            if "choices" in result and len(result["choices"]) > 0:
                return result["choices"][0]["message"]["content"]
            else:
                raise ValueError("No response choices found in Sarvam API response")

        except requests.exceptions.RequestException as e:
            raise RuntimeError(f"Sarvam API request failed: {e}") from e
        except KeyError as e:
            raise ValueError(f"Unexpected response format from Sarvam API: {e}") from e
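Usage sketch for the class above (not part of the commit). The dict form of model carries the Sarvam-specific knobs that the loop above copies into the payload; if BaseLlmConfig cannot be instantiated directly in your version, construct the LLM through the factory shown below instead:

# Sketch: calling SarvamLLM directly with the dict form of `model`.
import os

os.environ["SARVAM_API_KEY"] = "..."  # or pass api_key in the config

llm = SarvamLLM(BaseLlmConfig(model={"name": "sarvam-m", "frequency_penalty": 0.2, "n": 1}))
print(llm.generate_response([{"role": "user", "content": "Namaste! What can you do?"}]))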
@@ -27,6 +27,7 @@ class LlmFactory:
        "gemini": "mem0.llms.gemini.GeminiLLM",
        "deepseek": "mem0.llms.deepseek.DeepSeekLLM",
        "xai": "mem0.llms.xai.XAILLM",
        "sarvam": "mem0.llms.sarvam.SarvamLLM",
        "lmstudio": "mem0.llms.lmstudio.LMStudioLLM",
        "langchain": "mem0.llms.langchain.LangchainLLM",
    }
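End-to-end through the factory. A sketch assuming LlmFactory.create takes the provider key and a plain config dict, as it does for the other entries in this mapping:

# Sketch: resolving SarvamLLM through the provider registry above.
from mem0.utils.factory import LlmFactory

llm = LlmFactory.create("sarvam", {"model": "sarvam-m", "max_tokens": 256})
answer = llm.generate_response([{"role": "user", "content": "Summarize Sarvam-M in one line."}])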