From 1ba9c71f540ae475bc6436d346b12128686bad3a Mon Sep 17 00:00:00 2001
From: Antaripa Saha
Date: Mon, 26 May 2025 23:19:37 +0530
Subject: [PATCH] Add support for sarvam-m model (#2802)

---
 docs/components/llms/config.mdx        |   6 ++
 docs/components/llms/models/sarvam.mdx |  75 +++++++++++++++++++
 docs/components/llms/overview.mdx      |   1 +
 docs/docs.json                         |   1 +
 mem0/configs/llms/base.py              |   7 ++
 mem0/llms/configs.py                   |   1 +
 mem0/llms/sarvam.py                    | 100 +++++++++++++++++++++++++
 mem0/utils/factory.py                  |   1 +
 8 files changed, 192 insertions(+)
 create mode 100644 docs/components/llms/models/sarvam.mdx
 create mode 100644 mem0/llms/sarvam.py

diff --git a/docs/components/llms/config.mdx b/docs/components/llms/config.mdx
index bb13cbcb..0bb1cf16 100644
--- a/docs/components/llms/config.mdx
+++ b/docs/components/llms/config.mdx
@@ -110,6 +110,12 @@ Here's a comprehensive list of all parameters that can be used across different
 | `azure_kwargs` | Azure LLM args for initialization | AzureOpenAI |
 | `deepseek_base_url` | Base URL for DeepSeek API | DeepSeek |
 | `xai_base_url` | Base URL for XAI API | XAI |
+| `sarvam_base_url` | Base URL for Sarvam API | Sarvam |
+| `reasoning_effort` | Reasoning level (low, medium, high) | Sarvam |
+| `frequency_penalty` | Penalize frequent tokens (-2.0 to 2.0) | Sarvam |
+| `presence_penalty` | Penalize existing tokens (-2.0 to 2.0) | Sarvam |
+| `seed` | Seed for deterministic sampling | Sarvam |
+| `stop` | Stop sequences (max 4) | Sarvam |
 | `lmstudio_base_url` | Base URL for LM Studio API | LM Studio |
diff --git a/docs/components/llms/models/sarvam.mdx b/docs/components/llms/models/sarvam.mdx
new file mode 100644
index 00000000..daaa0363
--- /dev/null
+++ b/docs/components/llms/models/sarvam.mdx
@@ -0,0 +1,75 @@
+---
+title: Sarvam AI
+---
+
+**Sarvam AI** is an Indian AI company developing language models with a focus on Indian languages and cultural context. Their latest model, **Sarvam-M**, is designed to understand and generate content in multiple Indian languages while maintaining high performance in English.
+
+To use Sarvam AI's models, set the `SARVAM_API_KEY` environment variable, which you can get from their [platform](https://dashboard.sarvam.ai/).
+
+## Usage
+
+```python
+import os
+from mem0 import Memory
+
+os.environ["OPENAI_API_KEY"] = "your-api-key"  # used for the embedding model
+os.environ["SARVAM_API_KEY"] = "your-api-key"
+
+config = {
+    "llm": {
+        "provider": "sarvam",
+        "config": {
+            "model": "sarvam-m",
+            "temperature": 0.7,
+        }
+    }
+}
+
+m = Memory.from_config(config)
+messages = [
+    {"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
+    {"role": "assistant", "content": "How about a thriller movie? They can be quite engaging."},
+    {"role": "user", "content": "I'm not a big fan of thriller movies, but I love sci-fi movies."},
+    {"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."}
+]
+m.add(messages, user_id="alex")
+```
+
+## Advanced Usage with Sarvam-Specific Features
+
+```python
+import os
+from mem0 import Memory
+
+config = {
+    "llm": {
+        "provider": "sarvam",
+        "config": {
+            "model": {
+                "name": "sarvam-m",
+                "reasoning_effort": "high",  # enable advanced reasoning
+                "frequency_penalty": 0.1,    # reduce repetition
+                "seed": 42                   # for deterministic outputs
+            },
+            "temperature": 0.3,
+            "max_tokens": 2000,
+            "api_key": "your-sarvam-api-key"
+        }
+    }
+}
+
+m = Memory.from_config(config)
+
+# Example with a Hindi conversation
+messages = [
+    # "I want to open a joint account with SBI."
+    {"role": "user", "content": "मैं SBI में joint account खोलना चाहता हूँ।"},
+    # "You'll need some documents to open a joint account with SBI. Would you like to know which ones?"
+    {"role": "assistant", "content": "SBI में joint account खोलने के लिए आपको कुछ documents की जरूरत होगी। क्या आप जानना चाहते हैं कि कौन से documents चाहिए?"}
+]
+m.add(messages, user_id="rajesh", metadata={"language": "hindi", "topic": "banking"})
+```
+
+## Config
+
+All available parameters for the `sarvam` config are listed in the [Master List of All Params in Config](../config).
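The dict-style `model` in the advanced example above is flattened into a single chat-completions request body by the `SarvamLLM` class added later in this patch. A minimal sketch of the payload that config should produce (illustrative only, not part of the patch; values mirror the example):

```python
# Illustrative only: the request body SarvamLLM assembles from the advanced
# config above, merging standard BaseLlmConfig fields with the model dict.
payload = {
    "model": "sarvam-m",          # from model["name"]
    "messages": [{"role": "user", "content": "Namaste!"}],
    "temperature": 0.3,           # standard BaseLlmConfig parameter
    "max_tokens": 2000,           # standard BaseLlmConfig parameter
    "reasoning_effort": "high",   # Sarvam-specific, forwarded from the model dict
    "frequency_penalty": 0.1,     # Sarvam-specific
    "seed": 42,                   # Sarvam-specific
}
```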
diff --git a/docs/components/llms/overview.mdx b/docs/components/llms/overview.mdx
index e3114bb7..f5249ad0 100644
--- a/docs/components/llms/overview.mdx
+++ b/docs/components/llms/overview.mdx
@@ -34,6 +34,7 @@ To view all supported llms, visit the [Supported LLMs](./models).
+
diff --git a/docs/docs.json b/docs/docs.json
index d3a97d34..65ae9590 100644
--- a/docs/docs.json
+++ b/docs/docs.json
@@ -115,6 +115,7 @@
             "components/llms/models/gemini",
             "components/llms/models/deepseek",
             "components/llms/models/xAI",
+            "components/llms/models/sarvam",
             "components/llms/models/lmstudio",
             "components/llms/models/langchain"
           ]
diff --git a/mem0/configs/llms/base.py b/mem0/configs/llms/base.py
index 983d9f8e..271de2ef 100644
--- a/mem0/configs/llms/base.py
+++ b/mem0/configs/llms/base.py
@@ -39,6 +39,8 @@ class BaseLlmConfig(ABC):
         deepseek_base_url: Optional[str] = None,
         # XAI specific
         xai_base_url: Optional[str] = None,
+        # Sarvam specific
+        sarvam_base_url: Optional[str] = "https://api.sarvam.ai/v1",
         # LM Studio specific
         lmstudio_base_url: Optional[str] = "http://localhost:1234/v1",
         # AWS Bedrock specific
@@ -89,6 +91,8 @@
         :type deepseek_base_url: Optional[str], optional
         :param xai_base_url: XAI base URL to be use, defaults to None
         :type xai_base_url: Optional[str], optional
+        :param sarvam_base_url: Sarvam base URL to be used, defaults to "https://api.sarvam.ai/v1"
+        :type sarvam_base_url: Optional[str], optional
         :param lmstudio_base_url: LM Studio base URL to be use, defaults to "http://localhost:1234/v1"
         :type lmstudio_base_url: Optional[str], optional
         """
@@ -125,6 +129,9 @@
         # XAI specific
         self.xai_base_url = xai_base_url
 
+        # Sarvam specific
+        self.sarvam_base_url = sarvam_base_url
+
         # LM Studio specific
         self.lmstudio_base_url = lmstudio_base_url
diff --git a/mem0/llms/configs.py b/mem0/llms/configs.py
index 615fc37a..68ec661e 100644
--- a/mem0/llms/configs.py
+++ b/mem0/llms/configs.py
@@ -24,6 +24,7 @@ class LlmConfig(BaseModel):
             "gemini",
             "deepseek",
             "xai",
+            "sarvam",
             "lmstudio",
             "langchain",
         ):
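The `SarvamLLM` class in the next file resolves its endpoint with a fixed precedence: the explicit config value, then the `SARVAM_API_BASE` environment variable, then the default added to `BaseLlmConfig` above. A standalone sketch of that resolution order (mirrors the implementation below; `cfg` stands in for a config object):

```python
import os

DEFAULT_SARVAM_BASE_URL = "https://api.sarvam.ai/v1"  # default set in BaseLlmConfig above

def resolve_base_url(cfg) -> str:
    """Precedence: cfg.sarvam_base_url > SARVAM_API_BASE env var > default."""
    return (
        getattr(cfg, "sarvam_base_url", None)
        or os.getenv("SARVAM_API_BASE")
        or DEFAULT_SARVAM_BASE_URL
    )
```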
diff --git a/mem0/llms/sarvam.py b/mem0/llms/sarvam.py
new file mode 100644
index 00000000..16b72b0d
--- /dev/null
+++ b/mem0/llms/sarvam.py
@@ -0,0 +1,100 @@
+import os
+from typing import Dict, List, Optional
+
+import requests
+
+from mem0.configs.llms.base import BaseLlmConfig
+from mem0.llms.base import LLMBase
+
+
+class SarvamLLM(LLMBase):
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        super().__init__(config)
+
+        # Set default model if not provided
+        if not self.config.model:
+            self.config.model = "sarvam-m"
+
+        # Get API key from config or environment variable
+        self.api_key = self.config.api_key or os.getenv("SARVAM_API_KEY")
+        if not self.api_key:
+            raise ValueError(
+                "Sarvam API key is required. Set SARVAM_API_KEY environment variable "
+                "or provide api_key in config."
+            )
+
+        # Set base URL: config value, then environment variable, then default
+        self.base_url = (
+            getattr(self.config, "sarvam_base_url", None)
+            or os.getenv("SARVAM_API_BASE")
+            or "https://api.sarvam.ai/v1"
+        )
+
+    def generate_response(
+        self,
+        messages: List[Dict[str, str]],
+        response_format=None,
+    ) -> str:
+        """
+        Generate a response based on the given messages using Sarvam-M.
+
+        Args:
+            messages (list): List of message dicts containing 'role' and 'content'.
+            response_format (str or object, optional): Format of the response.
+                Currently not used by the Sarvam API.
+
+        Returns:
+            str: The generated response.
+        """
+        url = f"{self.base_url}/chat/completions"
+        headers = {
+            "Authorization": f"Bearer {self.api_key}",
+            "Content-Type": "application/json",
+        }
+
+        # Prepare the request payload
+        params = {
+            "messages": messages,
+            "model": self.config.model if isinstance(self.config.model, str) else "sarvam-m",
+        }
+
+        # Add standard parameters that already exist in BaseLlmConfig
+        if self.config.temperature is not None:
+            params["temperature"] = self.config.temperature
+        if self.config.max_tokens is not None:
+            params["max_tokens"] = self.config.max_tokens
+        if self.config.top_p is not None:
+            params["top_p"] = self.config.top_p
+
+        # Handle Sarvam-specific parameters if the model is passed as a dict
+        if isinstance(self.config.model, dict):
+            # Extract the model name
+            params["model"] = self.config.model.get("name", "sarvam-m")
+
+            # Forward Sarvam-specific parameters
+            sarvam_specific_params = [
+                "reasoning_effort", "frequency_penalty", "presence_penalty",
+                "seed", "stop", "n",
+            ]
+            for param in sarvam_specific_params:
+                if param in self.config.model:
+                    params[param] = self.config.model[param]
+
+        try:
+            response = requests.post(url, headers=headers, json=params, timeout=30)
+            response.raise_for_status()
+            result = response.json()
+
+            if "choices" in result and len(result["choices"]) > 0:
+                return result["choices"][0]["message"]["content"]
+            raise ValueError("No response choices found in Sarvam API response")
+        except requests.exceptions.RequestException as e:
+            raise RuntimeError(f"Sarvam API request failed: {e}")
+        except KeyError as e:
+            raise ValueError(f"Unexpected response format from Sarvam API: {e}")
\ No newline at end of file
diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py
index 93ebbd8a..d137e273 100644
--- a/mem0/utils/factory.py
+++ b/mem0/utils/factory.py
@@ -27,6 +27,7 @@ class LlmFactory:
         "gemini": "mem0.llms.gemini.GeminiLLM",
         "deepseek": "mem0.llms.deepseek.DeepSeekLLM",
         "xai": "mem0.llms.xai.XAILLM",
+        "sarvam": "mem0.llms.sarvam.SarvamLLM",
         "lmstudio": "mem0.llms.lmstudio.LMStudioLLM",
         "langchain": "mem0.llms.langchain.LangchainLLM",
     }
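With the factory entry registered, the provider can also be exercised directly, not just through `Memory.from_config`. A quick smoke-test sketch (not part of the patch; assumes `BaseLlmConfig` is directly instantiable, as the factory does for other providers, and a valid `SARVAM_API_KEY`):

```python
import os

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.sarvam import SarvamLLM

os.environ["SARVAM_API_KEY"] = "your-sarvam-api-key"  # placeholder key

# Plain string model name; SarvamLLM falls back to "sarvam-m" if unset
config = BaseLlmConfig(model="sarvam-m", temperature=0.7, max_tokens=256)

llm = SarvamLLM(config)
reply = llm.generate_response(
    messages=[{"role": "user", "content": "Suggest three sci-fi movies."}]
)
print(reply)
```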