Feature/vllm support (#2981)
@@ -58,6 +58,7 @@ config = {

m = Memory.from_config(config)
m.add("Your text here", user_id="user", metadata={"category": "example"})

```

```typescript TypeScript
@@ -76,6 +77,7 @@ const config = {
const memory = new Memory(config);
await memory.add("Your text here", { userId: "user123", metadata: { category: "example" } });
```

</CodeGroup>

## Why is Config Needed?
109 docs/components/llms/models/vllm.mdx Normal file
@@ -0,0 +1,109 @@
---
title: vLLM
---

<Snippet file="paper-release.mdx" />

[vLLM](https://docs.vllm.ai/) is a high-performance inference engine for large language models that provides significant performance improvements for local inference. It's designed to maximize throughput and memory efficiency for serving LLMs.

## Prerequisites

1. **Install vLLM**:

```bash
pip install vllm
```

2. **Start vLLM server**:

```bash
# For testing with a small model
vllm serve microsoft/DialoGPT-medium --port 8000

# For production with a larger model (requires GPU)
vllm serve Qwen/Qwen2.5-32B-Instruct --port 8000
```

## Usage

```python
import os
from mem0 import Memory

os.environ["OPENAI_API_KEY"] = "your-api-key"  # used for embedding model

config = {
    "llm": {
        "provider": "vllm",
        "config": {
            "model": "Qwen/Qwen2.5-32B-Instruct",
            "vllm_base_url": "http://localhost:8000/v1",
            "temperature": 0.1,
            "max_tokens": 2000,
        }
    }
}

m = Memory.from_config(config)
messages = [
    {"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
    {"role": "assistant", "content": "How about thriller movies? They can be quite engaging."},
    {"role": "user", "content": "I'm not a big fan of thrillers, but I love sci-fi movies."},
    {"role": "assistant", "content": "Got it! I'll avoid thrillers and suggest sci-fi movies instead."}
]
m.add(messages, user_id="alice", metadata={"category": "movies"})
```
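Once memories are stored, the same `Memory` instance can retrieve them. A minimal sketch, mirroring the `memory.search` call in the example script added by this PR (the exact return shape may differ across mem0 versions):

```python
# Continuing from the config above: search the memories stored for "alice".
related = m.search(query="What kind of movies does Alice enjoy?", user_id="alice")
for item in related:
    print(item["memory"])
```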
## Configuration Parameters

| Parameter       | Description                       | Default                       | Environment Variable |
| --------------- | --------------------------------- | ----------------------------- | -------------------- |
| `model`         | Model name running on vLLM server | `"Qwen/Qwen2.5-32B-Instruct"` | -                    |
| `vllm_base_url` | vLLM server URL                   | `"http://localhost:8000/v1"`  | `VLLM_BASE_URL`      |
| `api_key`       | API key (dummy for local)         | `"vllm-api-key"`              | `VLLM_API_KEY`       |
| `temperature`   | Sampling temperature              | `0.1`                         | -                    |
| `max_tokens`    | Maximum tokens to generate        | `2000`                        | -                    |

## Environment Variables

You can set these environment variables instead of specifying them in config:

```bash
export VLLM_BASE_URL="http://localhost:8000/v1"
export VLLM_API_KEY="your-vllm-api-key"
export OPENAI_API_KEY="your-openai-api-key"  # for embeddings
```
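When these variables are set, `vllm_base_url` and `api_key` can be omitted from the config; the provider falls back to `VLLM_BASE_URL` and `VLLM_API_KEY` (see `mem0/llms/vllm.py` in this PR). A minimal sketch under that assumption:

```python
# Assumes VLLM_BASE_URL and VLLM_API_KEY are exported as shown above.
config = {
    "llm": {
        "provider": "vllm",
        "config": {
            "model": "Qwen/Qwen2.5-32B-Instruct",
            # vllm_base_url and api_key are picked up from the environment
        }
    }
}
```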
## Benefits

- **High Performance**: 2-24x faster inference than standard implementations
- **Memory Efficient**: Optimized memory usage with PagedAttention
- **Local Deployment**: Keep your data private and reduce API costs
- **Easy Integration**: Drop-in replacement for other LLM providers
- **Flexible**: Works with any model supported by vLLM

## Troubleshooting

1. **Server not responding**: Make sure the vLLM server is running

```bash
curl http://localhost:8000/health
```

2. **404 errors**: Ensure the base URL is in the correct format

```python
"vllm_base_url": "http://localhost:8000/v1"  # Note the /v1
```

3. **Model not found**: Check that the model name matches what the server is serving
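One way to check, assuming the default OpenAI-compatible endpoint, is to list the models the server is actually serving:

```bash
curl http://localhost:8000/v1/models
```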
4. **Out of memory**: Try smaller models or reduce `max_model_len`

```bash
vllm serve Qwen/Qwen2.5-32B-Instruct --max-model-len 4096
```

## Config

All available parameters for the `vllm` config are present in [Master List of All Params in Config](../config).
@@ -117,7 +117,8 @@
"components/llms/models/xAI",
"components/llms/models/sarvam",
"components/llms/models/lmstudio",
"components/llms/models/langchain"
"components/llms/models/langchain",
"components/llms/models/vllm"
]
}
]
144 examples/misc/vllm_example.py Normal file
@@ -0,0 +1,144 @@
"""
Example of using vLLM with mem0 for high-performance memory operations.

SETUP INSTRUCTIONS:
1. Install vLLM:
   pip install vllm

2. Start vLLM server (in a separate terminal):
   vllm serve microsoft/DialoGPT-small --port 8000

   Wait for the message: "Uvicorn running on http://0.0.0.0:8000"
   (Small model: ~500MB download, much faster!)

3. Verify server is running:
   curl http://localhost:8000/health

4. Run this example:
   python examples/misc/vllm_example.py

Optional environment variables:
   export VLLM_BASE_URL="http://localhost:8000/v1"
   export VLLM_API_KEY="vllm-api-key"
"""

from mem0 import Memory

# Configuration for vLLM integration
config = {
    "llm": {
        "provider": "vllm",
        "config": {
            "model": "Qwen/Qwen2.5-32B-Instruct",
            "vllm_base_url": "http://localhost:8000/v1",
            "api_key": "vllm-api-key",
            "temperature": 0.7,
            "max_tokens": 100,
        }
    },
    "embedder": {
        "provider": "openai",
        "config": {
            "model": "text-embedding-3-small"
        }
    },
    "vector_store": {
        "provider": "qdrant",
        "config": {
            "collection_name": "vllm_memories",
            "host": "localhost",
            "port": 6333
        }
    }
}


def main():
    """
    Demonstrate vLLM integration with mem0
    """
    print("--> Initializing mem0 with vLLM...")

    # Initialize memory with vLLM
    memory = Memory.from_config(config)

    print("--> Memory initialized successfully!")

    # Example conversations to store
    conversations = [
        {
            "messages": [
                {"role": "user", "content": "I love playing chess on weekends"},
                {"role": "assistant", "content": "That's great! Chess is an excellent strategic game that helps improve critical thinking."}
            ],
            "user_id": "user_123"
        },
        {
            "messages": [
                {"role": "user", "content": "I'm learning Python programming"},
                {"role": "assistant", "content": "Python is a fantastic language for beginners! What specific areas are you focusing on?"}
            ],
            "user_id": "user_123"
        },
        {
            "messages": [
                {"role": "user", "content": "I prefer working late at night, I'm more productive then"},
                {"role": "assistant", "content": "Many people find they're more creative and focused during nighttime hours. It's important to maintain a consistent schedule that works for you."}
            ],
            "user_id": "user_123"
        }
    ]

    print("\n--> Adding memories using vLLM...")

    # Add memories - now powered by vLLM's high-performance inference
    for i, conversation in enumerate(conversations, 1):
        result = memory.add(
            messages=conversation["messages"],
            user_id=conversation["user_id"]
        )
        print(f"Memory {i} added: {result}")

    print("\n🔍 Searching memories...")

    # Search memories - vLLM will process the search and memory operations
    search_queries = [
        "What does the user like to do on weekends?",
        "What is the user learning?",
        "When is the user most productive?"
    ]

    for query in search_queries:
        print(f"\nQuery: {query}")
        memories = memory.search(
            query=query,
            user_id="user_123"
        )

        for memory_item in memories:
            print(f" - {memory_item['memory']}")

    print("\n--> Getting all memories for user...")
    all_memories = memory.get_all(user_id="user_123")
    print(f"Total memories stored: {len(all_memories)}")

    for memory_item in all_memories:
        print(f" - {memory_item['memory']}")

    print("\n--> vLLM integration demo completed successfully!")
    print("\nBenefits of using vLLM:")
    print(" -> 2.7x higher throughput compared to standard implementations")
    print(" -> 5x faster time-per-output-token")
    print(" -> Efficient memory usage with PagedAttention")
    print(" -> Simple configuration, same as other providers")


if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        print(f"=> Error: {e}")
        print("\nTroubleshooting:")
        print("1. Make sure vLLM server is running: vllm serve microsoft/DialoGPT-small --port 8000")
        print("2. Check if the model is downloaded and accessible")
        print("3. Verify the base URL and port configuration")
        print("4. Ensure you have the required dependencies installed")
@@ -44,6 +44,8 @@ class BaseLlmConfig(ABC):
        # LM Studio specific
        lmstudio_base_url: Optional[str] = "http://localhost:1234/v1",
        lmstudio_response_format: dict = None,
        # vLLM specific
        vllm_base_url: Optional[str] = "http://localhost:8000/v1",
        # AWS Bedrock specific
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
@@ -98,6 +100,8 @@ class BaseLlmConfig(ABC):
        :type lmstudio_base_url: Optional[str], optional
        :param lmstudio_response_format: LM Studio response format to be used, defaults to None
        :type lmstudio_response_format: Optional[Dict], optional
        :param vllm_base_url: vLLM base URL to be used, defaults to "http://localhost:8000/v1"
        :type vllm_base_url: Optional[str], optional
        """

        self.model = model
@@ -139,6 +143,9 @@ class BaseLlmConfig(ABC):
        self.lmstudio_base_url = lmstudio_base_url
        self.lmstudio_response_format = lmstudio_response_format

        # vLLM specific
        self.vllm_base_url = vllm_base_url

        # AWS Bedrock specific
        self.aws_access_key_id = aws_access_key_id
        self.aws_secret_access_key = aws_secret_access_key
@@ -26,6 +26,7 @@ class LlmConfig(BaseModel):
            "xai",
            "sarvam",
            "lmstudio",
            "vllm",
            "langchain",
        ):
            return v
84 mem0/llms/vllm.py Normal file
@@ -0,0 +1,84 @@
import json
import os
from typing import Dict, List, Optional

from openai import OpenAI

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase


class VllmLLM(LLMBase):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        super().__init__(config)

        if not self.config.model:
            self.config.model = "Qwen/Qwen2.5-32B-Instruct"

        self.config.api_key = self.config.api_key or os.getenv("VLLM_API_KEY") or "vllm-api-key"
        base_url = self.config.vllm_base_url or os.getenv("VLLM_BASE_URL")

        self.client = OpenAI(base_url=base_url, api_key=self.config.api_key)

    def _parse_response(self, response, tools):
        """
        Process the response based on whether tools are used or not.

        Args:
            response: The raw response from the API.
            tools: The list of tools provided in the request.

        Returns:
            str or dict: The processed response.
        """
        if tools:
            processed_response = {
                "content": response.choices[0].message.content,
                "tool_calls": [],
            }

            if response.choices[0].message.tool_calls:
                for tool_call in response.choices[0].message.tool_calls:
                    processed_response["tool_calls"].append({
                        "name": tool_call.function.name,
                        "arguments": json.loads(tool_call.function.arguments),
                    })

            return processed_response
        else:
            return response.choices[0].message.content

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        response_format=None,
        tools: Optional[List[Dict]] = None,
        tool_choice: str = "auto",
    ):
        """
        Generate a response based on the given messages using vLLM.

        Args:
            messages (list): List of message dicts containing 'role' and 'content'.
            response_format (str or object, optional): Format of the response. Defaults to "text".
            tools (list, optional): List of tools that the model can call. Defaults to None.
            tool_choice (str, optional): Tool choice method. Defaults to "auto".

        Returns:
            str: The generated response.
        """
        params = {
            "model": self.config.model,
            "messages": messages,
            "temperature": self.config.temperature,
            "max_tokens": self.config.max_tokens,
            "top_p": self.config.top_p,
        }

        if response_format:
            params["response_format"] = response_format

        if tools:
            params["tools"] = tools
            params["tool_choice"] = tool_choice

        response = self.client.chat.completions.create(**params)
        return self._parse_response(response, tools)
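For reference, a minimal sketch of driving the new provider class directly, outside the `Memory` pipeline. This is not part of the diff; it assumes a vLLM server is already listening on the configured base URL and uses only the constructor and `generate_response` signatures shown above:

```python
from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.vllm import VllmLLM

# vllm_base_url may be omitted; VllmLLM falls back to the VLLM_BASE_URL env var.
config = BaseLlmConfig(
    model="Qwen/Qwen2.5-32B-Instruct",
    temperature=0.1,
    max_tokens=200,
    vllm_base_url="http://localhost:8000/v1",
)
llm = VllmLLM(config)
print(llm.generate_response([{"role": "user", "content": "Say hello in one sentence."}]))
```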
@@ -29,6 +29,7 @@ class LlmFactory:
        "xai": "mem0.llms.xai.XAILLM",
        "sarvam": "mem0.llms.sarvam.SarvamLLM",
        "lmstudio": "mem0.llms.lmstudio.LMStudioLLM",
        "vllm": "mem0.llms.vllm.VllmLLM",
        "langchain": "mem0.llms.langchain.LangchainLLM",
    }
0 openmemory/run.sh Executable file → Normal file
80 tests/llms/test_vllm.py Normal file
@@ -0,0 +1,80 @@
from unittest.mock import Mock, patch

import pytest

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.vllm import VllmLLM


@pytest.fixture
def mock_vllm_client():
    with patch("mem0.llms.vllm.OpenAI") as mock_openai:
        mock_client = Mock()
        mock_openai.return_value = mock_client
        yield mock_client


def test_generate_response_without_tools(mock_vllm_client):
    config = BaseLlmConfig(model="Qwen/Qwen2.5-32B-Instruct", temperature=0.7, max_tokens=100, top_p=1.0)
    llm = VllmLLM(config)
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello, how are you?"},
    ]

    mock_response = Mock()
    mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))]
    mock_vllm_client.chat.completions.create.return_value = mock_response

    response = llm.generate_response(messages)

    mock_vllm_client.chat.completions.create.assert_called_once_with(
        model="Qwen/Qwen2.5-32B-Instruct", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0
    )
    assert response == "I'm doing well, thank you for asking!"


def test_generate_response_with_tools(mock_vllm_client):
    config = BaseLlmConfig(model="Qwen/Qwen2.5-32B-Instruct", temperature=0.7, max_tokens=100, top_p=1.0)
    llm = VllmLLM(config)
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Add a new memory: Today is a sunny day."},
    ]
    tools = [
        {
            "type": "function",
            "function": {
                "name": "add_memory",
                "description": "Add a memory",
                "parameters": {
                    "type": "object",
                    "properties": {"data": {"type": "string", "description": "Data to add to memory"}},
                    "required": ["data"],
                },
            },
        }
    ]

    mock_response = Mock()
    mock_message = Mock()
    mock_message.content = "I've added the memory for you."

    mock_tool_call = Mock()
    mock_tool_call.function.name = "add_memory"
    mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}'

    mock_message.tool_calls = [mock_tool_call]
    mock_response.choices = [Mock(message=mock_message)]
    mock_vllm_client.chat.completions.create.return_value = mock_response

    response = llm.generate_response(messages, tools=tools)

    mock_vllm_client.chat.completions.create.assert_called_once_with(
        model="Qwen/Qwen2.5-32B-Instruct", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0, tools=tools, tool_choice="auto"
    )

    assert response["content"] == "I've added the memory for you."
    assert len(response["tool_calls"]) == 1
    assert response["tool_calls"][0]["name"] == "add_memory"
    assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}