From e9bc4cdc958627fbc3e02083a14a6f3eafca8c21 Mon Sep 17 00:00:00 2001
From: Dev Khant
Date: Wed, 26 Feb 2025 13:34:01 +0530
Subject: [PATCH] Add Grok Support (#2260)

---
 docs/components/llms/config.mdx     |  2 +-
 docs/components/llms/models/xai.mdx | 31 +++++++++++++++++++
 docs/components/llms/overview.mdx   |  1 +
 docs/docs.json                      |  4 ++-
 mem0/configs/llms/base.py           | 11 +++++--
 mem0/llms/configs.py                |  1 +
 mem0/llms/xai.py                    | 48 +++++++++++++++++++++++++++++
 mem0/utils/factory.py               |  1 +
 8 files changed, 95 insertions(+), 4 deletions(-)
 create mode 100644 docs/components/llms/models/xai.mdx
 create mode 100644 mem0/llms/xai.py

diff --git a/docs/components/llms/config.mdx b/docs/components/llms/config.mdx
index 226d6868..f3caa654 100644
--- a/docs/components/llms/config.mdx
+++ b/docs/components/llms/config.mdx
@@ -75,7 +75,7 @@ Here's the table based on the provided parameters:
 | `openai_base_url` | Base URL for OpenAI API | OpenAI |
 | `azure_kwargs` | Azure LLM args for initialization | AzureOpenAI |
 | `deepseek_base_url` | Base URL for DeepSeek API | DeepSeek |
-
+| `xai_base_url` | Base URL for XAI API | XAI |

 ## Supported LLMs

diff --git a/docs/components/llms/models/xai.mdx b/docs/components/llms/models/xai.mdx
new file mode 100644
index 00000000..bd8ea60c
--- /dev/null
+++ b/docs/components/llms/models/xai.mdx
@@ -0,0 +1,31 @@
+[XAI](https://x.ai/) is an AI company founded by Elon Musk that develops large language models, including Grok. Grok is trained on real-time data from X (formerly Twitter) and aims to provide accurate, up-to-date responses with a touch of wit and humor.
+
+To use LLMs from XAI, go to their [platform](https://console.x.ai) and get an API key. Set it as the `XAI_API_KEY` environment variable, as shown in the example below.
+
+## Usage
+
+```python
+import os
+from mem0 import Memory
+
+os.environ["OPENAI_API_KEY"] = "your-api-key"  # used for the embedding model
+os.environ["XAI_API_KEY"] = "your-api-key"
+
+config = {
+    "llm": {
+        "provider": "xai",
+        "config": {
+            "model": "grok-2-latest",
+            "temperature": 0.1,
+            "max_tokens": 1000,
+        }
+    }
+}
+
+m = Memory.from_config(config)
+m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"})
+```
+
+## Config
+
+All available parameters for the `xai` config are listed in the [Master List of All Params in Config](../config).
\ No newline at end of file
diff --git a/docs/components/llms/overview.mdx b/docs/components/llms/overview.mdx
index 4489bc0f..abfe4024 100644
--- a/docs/components/llms/overview.mdx
+++ b/docs/components/llms/overview.mdx
@@ -27,6 +27,7 @@ To view all supported llms, visit the [Supported LLMs](./models).
+
 ## Structured vs Unstructured Outputs
diff --git a/docs/docs.json b/docs/docs.json
index d4d92435..d6e73a87 100644
--- a/docs/docs.json
+++ b/docs/docs.json
@@ -102,7 +102,9 @@
         "components/llms/models/mistral_AI",
         "components/llms/models/google_AI",
         "components/llms/models/aws_bedrock",
-        "components/llms/models/gemini"
+        "components/llms/models/gemini",
+        "components/llms/models/deepseek",
+        "components/llms/models/xai"
       ]
     }
   ]
diff --git a/mem0/configs/llms/base.py b/mem0/configs/llms/base.py
index 78ad13b1..36e32257 100644
--- a/mem0/configs/llms/base.py
+++ b/mem0/configs/llms/base.py
@@ -14,10 +14,10 @@ class BaseLlmConfig(ABC):
     def __init__(
         self,
         model: Optional[str] = None,
-        temperature: float = 0,
+        temperature: float = 0.1,
         api_key: Optional[str] = None,
         max_tokens: int = 3000,
-        top_p: float = 0,
+        top_p: float = 0.1,
         top_k: int = 1,
         # Openrouter specific
         models: Optional[list[str]] = None,
@@ -35,6 +35,8 @@ class BaseLlmConfig(ABC):
         http_client_proxies: Optional[Union[Dict, str]] = None,
         # DeepSeek specific
         deepseek_base_url: Optional[str] = None,
+        # XAI specific
+        xai_base_url: Optional[str] = None,
     ):
         """
         Initializes a configuration class instance for the LLM.
@@ -73,6 +75,8 @@ class BaseLlmConfig(ABC):
         :type http_client_proxies: Optional[Dict | str], optional
         :param deepseek_base_url: DeepSeek base URL to be used, defaults to None
         :type deepseek_base_url: Optional[str], optional
+        :param xai_base_url: XAI base URL to be used, defaults to None
+        :type xai_base_url: Optional[str], optional
         """

         self.model = model
@@ -101,3 +105,6 @@ class BaseLlmConfig(ABC):

         # AzureOpenAI specific
         self.azure_kwargs = AzureConfig(**azure_kwargs) or {}
+
+        # XAI specific
+        self.xai_base_url = xai_base_url
diff --git a/mem0/llms/configs.py b/mem0/llms/configs.py
index 5a806903..4c095e96 100644
--- a/mem0/llms/configs.py
+++ b/mem0/llms/configs.py
@@ -23,6 +23,7 @@ class LlmConfig(BaseModel):
             "azure_openai_structured",
             "gemini",
             "deepseek",
+            "xai",
         ):
             return v
         else:
diff --git a/mem0/llms/xai.py b/mem0/llms/xai.py
new file mode 100644
index 00000000..60210e11
--- /dev/null
+++ b/mem0/llms/xai.py
@@ -0,0 +1,48 @@
+import os
+from typing import Dict, List, Optional
+
+from openai import OpenAI
+
+from mem0.configs.llms.base import BaseLlmConfig
+from mem0.llms.base import LLMBase
+
+
+class XAILLM(LLMBase):
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        super().__init__(config)
+
+        if not self.config.model:
+            self.config.model = "grok-2-latest"
+
+        api_key = self.config.api_key or os.getenv("XAI_API_KEY")
+        base_url = self.config.xai_base_url or os.getenv("XAI_API_BASE") or "https://api.x.ai/v1"
+        self.client = OpenAI(api_key=api_key, base_url=base_url)
+
+    def generate_response(
+        self,
+        messages: List[Dict[str, str]],
+        response_format=None
+    ):
+        """
+        Generate a response based on the given messages using XAI.
+
+        Args:
+            messages (list): List of message dicts containing 'role' and 'content'.
+            response_format (str or object, optional): Format of the response. Defaults to None.
+
+        Returns:
+            str: The generated response.
+ """ + params = { + "model": self.config.model, + "messages": messages, + "temperature": self.config.temperature, + "max_tokens": self.config.max_tokens, + "top_p": self.config.top_p, + } + + if response_format: + params["response_format"] = response_format + + response = self.client.chat.completions.create(**params) + return response.choices[0].message.content diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py index 0505f461..82af19d5 100644 --- a/mem0/utils/factory.py +++ b/mem0/utils/factory.py @@ -24,6 +24,7 @@ class LlmFactory: "azure_openai_structured": "mem0.llms.azure_openai_structured.AzureOpenAIStructuredLLM", "gemini": "mem0.llms.gemini.GeminiLLM", "deepseek": "mem0.llms.deepseek.DeepSeekLLM", + "xai": "mem0.llms.xai.XAILLM", } @classmethod