Add LM Studio support (#2425)
mem0/configs/embeddings/base.py
@@ -30,6 +30,8 @@ class BaseEmbedderConfig(ABC):
         memory_add_embedding_type: Optional[str] = None,
         memory_update_embedding_type: Optional[str] = None,
         memory_search_embedding_type: Optional[str] = None,
+        # LM Studio specific
+        lmstudio_base_url: Optional[str] = "http://localhost:1234/v1",
     ):
         """
         Initializes a configuration class instance for the Embeddings.
@@ -58,6 +60,8 @@ class BaseEmbedderConfig(ABC):
         :type memory_update_embedding_type: Optional[str], optional
         :param memory_search_embedding_type: The type of embedding to use for the search memory action, defaults to None
         :type memory_search_embedding_type: Optional[str], optional
+        :param lmstudio_base_url: LM Studio base URL to be used, defaults to "http://localhost:1234/v1"
+        :type lmstudio_base_url: Optional[str], optional
         """

         self.model = model
@@ -82,3 +86,6 @@ class BaseEmbedderConfig(ABC):
         self.memory_add_embedding_type = memory_add_embedding_type
         self.memory_update_embedding_type = memory_update_embedding_type
         self.memory_search_embedding_type = memory_search_embedding_type
+
+        # LM Studio specific
+        self.lmstudio_base_url = lmstudio_base_url
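For orientation, a minimal sketch of constructing the embedder config with the new field. The keyword name and its default come from the diff above; treating BaseEmbedderConfig as directly instantiable (its ABC base declares no abstract methods) is an assumption, as are the example values.

    from mem0.configs.embeddings.base import BaseEmbedderConfig

    # Hypothetical: point the embedder at an LM Studio server on a non-default port.
    config = BaseEmbedderConfig(
        model="nomic-ai/nomic-embed-text-v1.5-GGUF/nomic-embed-text-v1.5.f16.gguf",
        lmstudio_base_url="http://localhost:5678/v1",  # overrides the default "http://localhost:1234/v1"
    )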
mem0/configs/llms/base.py
@@ -39,6 +39,8 @@ class BaseLlmConfig(ABC):
         deepseek_base_url: Optional[str] = None,
         # XAI specific
         xai_base_url: Optional[str] = None,
+        # LM Studio specific
+        lmstudio_base_url: Optional[str] = "http://localhost:1234/v1",
     ):
         """
         Initializes a configuration class instance for the LLM.
@@ -83,6 +85,8 @@ class BaseLlmConfig(ABC):
         :type deepseek_base_url: Optional[str], optional
         :param xai_base_url: XAI base URL to be used, defaults to None
         :type xai_base_url: Optional[str], optional
+        :param lmstudio_base_url: LM Studio base URL to be used, defaults to "http://localhost:1234/v1"
+        :type lmstudio_base_url: Optional[str], optional
         """

         self.model = model
@@ -116,3 +120,6 @@ class BaseLlmConfig(ABC):
 
         # XAI specific
         self.xai_base_url = xai_base_url
+
+        # LM Studio specific
+        self.lmstudio_base_url = lmstudio_base_url
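The LLM config mirrors the embedder change. A hedged sketch with illustrative values, again assuming direct construction is allowed (temperature, max_tokens, and top_p are existing BaseLlmConfig fields):

    from mem0.configs.llms.base import BaseLlmConfig

    # lmstudio_base_url falls back to "http://localhost:1234/v1" when omitted.
    llm_config = BaseLlmConfig(
        model="lmstudio-community/Meta-Llama-3.1-70B-Instruct-GGUF/Meta-Llama-3.1-70B-Instruct-IQ2_M.gguf",
        temperature=0.1,
        lmstudio_base_url="http://192.168.1.50:1234/v1",  # hypothetical remote LM Studio host
    )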
@@ -13,7 +13,7 @@ class EmbedderConfig(BaseModel):
     @field_validator("config")
     def validate_config(cls, v, values):
         provider = values.data.get("provider")
-        if provider in ["openai", "ollama", "huggingface", "azure_openai", "gemini", "vertexai", "together"]:
+        if provider in ["openai", "ollama", "huggingface", "azure_openai", "gemini", "vertexai", "together", "lmstudio"]:
             return v
         else:
             raise ValueError(f"Unsupported embedding provider: {provider}")
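What the widened validator now accepts, as a sketch; the import path is an assumption, and since this is a pydantic field validator, the ValueError surfaces wrapped in a ValidationError:

    from mem0.configs.base import EmbedderConfig  # import path assumed

    EmbedderConfig(provider="lmstudio", config={})   # passes after this change
    EmbedderConfig(provider="nonsense", config={})   # still fails: "Unsupported embedding provider"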
mem0/embeddings/lmstudio.py (new file)
@@ -0,0 +1,33 @@
+from typing import Literal, Optional
+
+from openai import OpenAI
+
+from mem0.configs.embeddings.base import BaseEmbedderConfig
+from mem0.embeddings.base import EmbeddingBase
+
+
+class LMStudioEmbedding(EmbeddingBase):
+    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
+        super().__init__(config)
+
+        self.config.model = self.config.model or "nomic-ai/nomic-embed-text-v1.5-GGUF/nomic-embed-text-v1.5.f16.gguf"
+        self.config.embedding_dims = self.config.embedding_dims or 1536
+        self.config.api_key = self.config.api_key or "lm-studio"
+
+        self.client = OpenAI(base_url=self.config.lmstudio_base_url, api_key=self.config.api_key)
+
+    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
+        """
+        Get the embedding for the given text using LM Studio.
+        Args:
+            text (str): The text to embed.
+            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
+        Returns:
+            list: The embedding vector.
+        """
+        text = text.replace("\n", " ")
+        return (
+            self.client.embeddings.create(input=[text], model=self.config.model)
+            .data[0]
+            .embedding
+        )
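Assuming an LM Studio server is running on the default port with an embedding model loaded, direct usage would look roughly like this (texts and values are illustrative):

    from mem0.configs.embeddings.base import BaseEmbedderConfig
    from mem0.embeddings.lmstudio import LMStudioEmbedding

    # With an empty config, the fallbacks above apply: default model path,
    # embedding_dims=1536, api_key="lm-studio", base URL http://localhost:1234/v1.
    embedder = LMStudioEmbedding(BaseEmbedderConfig())
    vector = embedder.embed("I prefer aisle seats on long flights.")
    print(len(vector))  # actual dimensionality depends on the loaded model

The dummy "lm-studio" API key exists because LM Studio's local server does not authenticate, while the OpenAI client still requires a non-empty key.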
@@ -24,6 +24,7 @@ class LlmConfig(BaseModel):
             "gemini",
             "deepseek",
             "xai",
+            "lmstudio",
         ):
             return v
         else:
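The LLM-side validator gets the same one-line widening, so, under the same pydantic usage assumption as above:

    LlmConfig(provider="lmstudio", config={"lmstudio_base_url": "http://localhost:1234/v1"})  # now validates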
mem0/llms/lmstudio.py (new file)
@@ -0,0 +1,48 @@
+from typing import Dict, List, Optional
+
+from openai import OpenAI
+
+from mem0.configs.llms.base import BaseLlmConfig
+from mem0.llms.base import LLMBase
+
+
+class LMStudioLLM(LLMBase):
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        super().__init__(config)
+
+        self.config.model = self.config.model or "lmstudio-community/Meta-Llama-3.1-70B-Instruct-GGUF/Meta-Llama-3.1-70B-Instruct-IQ2_M.gguf"
+        self.config.api_key = self.config.api_key or "lm-studio"
+
+        self.client = OpenAI(base_url=self.config.lmstudio_base_url, api_key=self.config.api_key)
+
+    def generate_response(
+        self,
+        messages: List[Dict[str, str]],
+        response_format: dict = {"type": "json_object"},
+        tools: Optional[List[Dict]] = None,
+        tool_choice: str = "auto",
+    ):
+        """
+        Generate a response based on the given messages using LM Studio.
+
+        Args:
+            messages (list): List of message dicts containing 'role' and 'content'.
+            response_format (dict, optional): Format of the response. Defaults to {"type": "json_object"}.
+            tools (list, optional): List of tools that the model can call. Defaults to None.
+            tool_choice (str, optional): Tool choice method. Defaults to "auto".
+
+        Returns:
+            str: The generated response.
+        """
+        params = {
+            "model": self.config.model,
+            "messages": messages,
+            "temperature": self.config.temperature,
+            "max_tokens": self.config.max_tokens,
+            "top_p": self.config.top_p,
+        }
+        if response_format:
+            params["response_format"] = response_format
+
+        response = self.client.chat.completions.create(**params)
+        return response.choices[0].message.content
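A hedged usage sketch, assuming a chat model is loaded in LM Studio and the server is on the default port:

    from mem0.configs.llms.base import BaseLlmConfig
    from mem0.llms.lmstudio import LMStudioLLM

    llm = LMStudioLLM(BaseLlmConfig(temperature=0.0, max_tokens=256))
    reply = llm.generate_response(
        messages=[
            {"role": "system", "content": "Reply in JSON."},
            {"role": "user", "content": "Extract facts: I live in Berlin and love cycling."},
        ]
    )
    print(reply)

Two things worth noting about the implementation: response_format defaults to {"type": "json_object"}, so the loaded model must support JSON-mode output (pass response_format=None for plain text), and tools/tool_choice are accepted but not yet forwarded to the API in this initial version.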
@@ -25,6 +25,7 @@ class LlmFactory:
         "gemini": "mem0.llms.gemini.GeminiLLM",
         "deepseek": "mem0.llms.deepseek.DeepSeekLLM",
         "xai": "mem0.llms.xai.XAILLM",
+        "lmstudio": "mem0.llms.lmstudio.LMStudioLLM",
     }

     @classmethod
@@ -47,6 +48,7 @@ class EmbedderFactory:
         "gemini": "mem0.embeddings.gemini.GoogleGenAIEmbedding",
         "vertexai": "mem0.embeddings.vertexai.VertexAIEmbedding",
         "together": "mem0.embeddings.together.TogetherEmbedding",
+        "lmstudio": "mem0.embeddings.lmstudio.LMStudioEmbedding",
     }

     @classmethod
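With both factories registered, LM Studio can be selected end to end through mem0's dict-based configuration. A minimal sketch; Memory.from_config is mem0's usual entry point, but the exact keys shown are an assumption:

    from mem0 import Memory

    config = {
        "llm": {"provider": "lmstudio", "config": {"temperature": 0.1}},
        "embedder": {"provider": "lmstudio", "config": {}},
    }
    m = Memory.from_config(config)
    m.add("I am a vegetarian and allergic to nuts.", user_id="alice")  # assumed public API shape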