Add Groq Support (#1481)

Dev Khant
2024-07-16 23:33:28 +05:30
committed by GitHub
parent 80f145fceb
commit 19637804b3
11 changed files with 369 additions and 6 deletions

66
docs/llms.mdx Normal file

@@ -0,0 +1,66 @@
---
title: 🤖 Large language models (LLMs)
---
## Overview
Mem0 includes built-in support for several popular large language models. Memory operations use the LLM you configure, so you can choose the model that best fits your needs.
<CardGroup cols={4}>
<Card title="OpenAI" href="#openai"></Card>
<Card title="Groq" href="#groq"></Card>
</CardGroup>
## OpenAI
To use OpenAI's LLM models, set the `OPENAI_API_KEY` environment variable. You can obtain an API key from the [OpenAI Platform](https://platform.openai.com/account/api-keys).
Once you have obtained the key, you can use it like this:
```python
import os
from mem0 import Memory

os.environ['OPENAI_API_KEY'] = 'xxx'

config = {
    "llm": {
        "provider": "openai",
        "config": {
            "model": "gpt-4o",
            "temperature": 0.2,
            "max_tokens": 1500,
        }
    }
}

m = Memory.from_config(config)
m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"})
```
## Groq
[Groq](https://groq.com/) is the creator of the world's first Language Processing Unit (LPU), and its LPU Inference Engine delivers exceptionally fast inference for AI workloads.
To use LLMs from Groq, get an API key from their [platform](https://console.groq.com/keys) and set it as the `GROQ_API_KEY` environment variable, as shown in the example below.
```python
import os
from mem0 import Memory

os.environ['GROQ_API_KEY'] = 'xxx'

config = {
    "llm": {
        "provider": "groq",
        "config": {
            "model": "mixtral-8x7b-32768",
            "temperature": 0.1,
            "max_tokens": 1000,
        }
    }
}

m = Memory.from_config(config)
m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"})
```

docs/mint.json

@@ -53,6 +53,12 @@
        "quickstart"
      ]
    },
    {
      "group": "LLMs",
      "pages": [
        "llms"
      ]
    },
    {
      "group": "💡 Examples",
      "pages": [

22
mem0/embeddings/configs.py Normal file

@@ -0,0 +1,22 @@
from typing import Optional

from pydantic import BaseModel, Field, field_validator


class EmbedderConfig(BaseModel):
    provider: str = Field(
        description="Provider of the embedding model (e.g., 'ollama', 'openai')",
        default="openai",
    )
    config: Optional[dict] = Field(
        description="Configuration for the specific embedding model", default=None
    )

    @field_validator("config")
    def validate_config(cls, v, values):
        provider = values.data.get("provider")
        if provider in ["openai", "ollama"]:
            return v
        else:
            raise ValueError(f"Unsupported embedding provider: {provider}")

21
mem0/llms/configs.py Normal file

@@ -0,0 +1,21 @@
from typing import Optional

from pydantic import BaseModel, Field, field_validator


class LlmConfig(BaseModel):
    provider: str = Field(
        description="Provider of the LLM (e.g., 'ollama', 'openai')", default="openai"
    )
    config: Optional[dict] = Field(
        description="Configuration for the specific LLM", default=None
    )

    @field_validator("config")
    def validate_config(cls, v, values):
        provider = values.data.get("provider")
        if provider in ["openai", "ollama", "groq"]:
            return v
        else:
            raise ValueError(f"Unsupported LLM provider: {provider}")
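Both config models gate which providers are accepted. A minimal sketch of the validation behavior (hypothetical usage, not part of this commit); note that since `config` defaults to `None`, pydantic only runs the validator when a `config` value is actually supplied:

```python
from pydantic import ValidationError

from mem0.embeddings.configs import EmbedderConfig
from mem0.llms.configs import LlmConfig

LlmConfig(provider="groq", config={"model": "llama3-70b-8192"})  # accepted
EmbedderConfig(provider="openai", config={})                     # accepted

try:
    LlmConfig(provider="mistral", config={})  # not in the allow-list
except ValidationError as e:
    print(e)  # "Unsupported LLM provider: mistral"
```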

40
mem0/llms/groq.py Normal file

@@ -0,0 +1,40 @@
from typing import Dict, List, Optional

from groq import Groq

from mem0.llms.base import LLMBase


class GroqLLM(LLMBase):
    def __init__(self, model="llama3-70b-8192"):
        self.client = Groq()
        self.model = model

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        response_format=None,
        tools: Optional[List[Dict]] = None,
        tool_choice: str = "auto",
    ):
        """
        Generate a response based on the given messages using Groq.

        Args:
            messages (list): List of message dicts containing 'role' and 'content'.
            response_format (str or object, optional): Format of the response. Defaults to None.
            tools (list, optional): List of tools that the model can call. Defaults to None.
            tool_choice (str, optional): Tool choice method. Defaults to "auto".

        Returns:
            The full chat completion response from Groq.
        """
        params = {"model": self.model, "messages": messages}
        if response_format:
            params["response_format"] = response_format
        if tools:
            params["tools"] = tools
            params["tool_choice"] = tool_choice

        response = self.client.chat.completions.create(**params)
        return response
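For reference, a minimal sketch of calling the new class directly (hypothetical usage; `Groq()` reads `GROQ_API_KEY` from the environment, and `generate_response` returns the full chat completion object, not a plain string):

```python
import os

from mem0.llms.groq import GroqLLM

os.environ["GROQ_API_KEY"] = "xxx"  # replace with a real key

llm = GroqLLM(model="llama3-70b-8192")
response = llm.generate_response(
    [{"role": "user", "content": "Reply with one word: hello"}]
)
# The text of the completion lives on the first choice.
print(response.choices[0].message.content)
```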

mem0/memory/main.py

@@ -7,8 +7,6 @@ from typing import Any, Dict, Optional
 from pydantic import BaseModel, Field, ValidationError
-from mem0.embeddings.openai import OpenAIEmbedding
-from mem0.llms.openai import OpenAILLM
 from mem0.llms.utils.tools import (
     ADD_MEMORY_TOOL,
     DELETE_MEMORY_TOOL,
@@ -21,7 +19,10 @@ from mem0.memory.storage import SQLiteManager
 from mem0.memory.telemetry import capture_event
 from mem0.memory.utils import get_update_memory_messages
 from mem0.vector_stores.configs import VectorStoreConfig
+from mem0.llms.configs import LlmConfig
+from mem0.embeddings.configs import EmbedderConfig
 from mem0.vector_stores.qdrant import Qdrant
+from mem0.utils.factory import LlmFactory, EmbedderFactory

 # Setup user config
 setup_config()
@@ -44,6 +45,14 @@ class MemoryConfig(BaseModel):
         description="Configuration for the vector store",
         default_factory=VectorStoreConfig,
     )
+    llm: LlmConfig = Field(
+        description="Configuration for the language model",
+        default_factory=LlmConfig,
+    )
+    embedder: EmbedderConfig = Field(
+        description="Configuration for the embedding model",
+        default_factory=EmbedderConfig,
+    )
     history_db_path: str = Field(
         description="Path to the history database",
         default=os.path.join(mem0_dir, "history.db"),
@@ -57,7 +66,7 @@ class MemoryConfig(BaseModel):
 class Memory(MemoryBase):
     def __init__(self, config: MemoryConfig = MemoryConfig()):
         self.config = config
-        self.embedding_model = OpenAIEmbedding()
+        self.embedding_model = EmbedderFactory.create(self.config.embedder.provider)
         # Initialize the appropriate vector store based on the configuration
         vector_store_config = self.config.vector_store.config
         if self.config.vector_store.provider == "qdrant":
@@ -73,7 +82,7 @@ class Memory(MemoryBase):
                 f"Unsupported vector store type: {self.config.vector_store_type}"
             )
-        self.llm = OpenAILLM()
+        self.llm = LlmFactory.create(self.config.llm.provider)
         self.db = SQLiteManager(self.config.history_db_path)
         self.collection_name = self.config.collection_name
         self.vector_store.create_col(
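Since both new fields use `default_factory`, an unconfigured `MemoryConfig` still resolves to OpenAI for both the LLM and the embedder, preserving the previous hard-coded behavior as the default. A quick sketch (hypothetical usage, assuming the class lives in `mem0.memory.main`):

```python
from mem0.memory.main import MemoryConfig  # assumed module path

cfg = MemoryConfig()
print(cfg.llm.provider)       # "openai"
print(cfg.embedder.provider)  # "openai"

cfg = MemoryConfig(llm={"provider": "groq"})  # pydantic coerces the dict
print(cfg.llm.provider)       # "groq"
```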

41
mem0/utils/factory.py Normal file

@@ -0,0 +1,41 @@
import importlib
def load_class(class_type):
module_path, class_name = class_type.rsplit(".", 1)
module = importlib.import_module(module_path)
return getattr(module, class_name)
class LlmFactory:
provider_to_class = {
"ollama": "mem0.llms.ollama.py.OllamaLLM",
"openai": "mem0.llms.openai.OpenAILLM",
"groq": "mem0.llms.groq.GroqLLM"
}
@classmethod
def create(cls, provider_name):
class_type = cls.provider_to_class.get(provider_name)
if class_type:
llm_instance = load_class(class_type)()
return llm_instance
else:
raise ValueError(f"Unsupported Llm provider: {provider_name}")
class EmbedderFactory:
provider_to_class = {
"openai": "mem0.embeddings.openai.OpenAIEmbedding",
"ollama": "mem0.embeddings.ollama.OllamaEmbedding",
"huggingface": "mem0.embeddings.huggingface.HuggingFaceEmbedding"
}
@classmethod
def create(cls, provider_name):
class_type = cls.provider_to_class.get(provider_name)
if class_type:
embedder_instance = load_class(class_type)()
return embedder_instance
else:
raise ValueError(f"Unsupported Embedder provider: {provider_name}")
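The factories map a provider name to a dotted class path and import it lazily via `importlib`, so a provider's module is only imported when that provider is actually selected. A usage sketch (hypothetical, not part of the commit; instantiating `GroqLLM` requires `GROQ_API_KEY` in the environment):

```python
from mem0.utils.factory import EmbedderFactory, LlmFactory

llm = LlmFactory.create("groq")              # imports mem0.llms.groq, returns GroqLLM()
embedder = EmbedderFactory.create("openai")  # returns OpenAIEmbedding()

LlmFactory.create("unknown")                 # raises ValueError: Unsupported Llm provider
```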

23
poetry.lock generated

@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.

 [[package]]
 name = "annotated-types"
@@ -370,6 +370,25 @@ files = [
 [package.extras]
 tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich"]

+[[package]]
+name = "groq"
+version = "0.9.0"
+description = "The official Python library for the groq API"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "groq-0.9.0-py3-none-any.whl", hash = "sha256:d0e46f4ad645504672bb09c8100af3ced3a7db0d5119dc13e4aca535fc455874"},
+    {file = "groq-0.9.0.tar.gz", hash = "sha256:130ed5e35d3acfaab46b9e7a078eeaebf91052f4a9d71f86f87fb319b5fec332"},
+]
+
+[package.dependencies]
+anyio = ">=3.5.0,<5"
+distro = ">=1.7.0,<2"
+httpx = ">=0.23.0,<1"
+pydantic = ">=1.9.0,<3"
+sniffio = "*"
+typing-extensions = ">=4.7,<5"
+
 [[package]]
 name = "grpcio"
 version = "1.64.1"
@@ -1707,4 +1726,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.8"
-content-hash = "5138c101a58db8dbddcb640545a5b2b4fc482f9e555008d117e315ae292d7697"
+content-hash = "7216c3479e9bce779f99016825bfb726399ffb0ac5f942ac73b899fc373efd37"

pyproject.toml

@@ -20,6 +20,7 @@ qdrant-client = "^1.9.1"
pydantic = "^2.7.3"
openai = "^1.33.0"
posthog = "^3.5.0"
groq = "^0.9.0"
[tool.poetry.group.test.dependencies]

69
tests/llms/test_groq.py Normal file

@@ -0,0 +1,69 @@
import pytest
from unittest.mock import Mock, patch
from mem0.llms.groq import GroqLLM
@pytest.fixture
def mock_groq_client():
with patch('mem0.llms.groq.Groq') as mock_groq:
mock_client = Mock()
mock_groq.return_value = mock_client
yield mock_client
def test_generate_response_without_tools(mock_groq_client):
llm = GroqLLM()
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"}
]
mock_response = Mock()
mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))]
mock_groq_client.chat.completions.create.return_value = mock_response
response = llm.generate_response(messages)
mock_groq_client.chat.completions.create.assert_called_once_with(
model="llama3-70b-8192",
messages=messages
)
assert response.choices[0].message.content == "I'm doing well, thank you for asking!"
def test_generate_response_with_tools(mock_groq_client):
llm = GroqLLM()
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Add a new memory: Today is a sunny day."}
]
tools = [
{
"type": "function",
"function": {
"name": "add_memory",
"description": "Add a memory",
"parameters": {
"type": "object",
"properties": {
"data": {"type": "string", "description": "Data to add to memory"}
},
"required": ["data"],
},
},
}
]
mock_response = Mock()
mock_response.choices = [Mock(message=Mock(content="Memory added successfully."))]
mock_groq_client.chat.completions.create.return_value = mock_response
response = llm.generate_response(messages, tools=tools)
mock_groq_client.chat.completions.create.assert_called_once_with(
model="llama3-70b-8192",
messages=messages,
tools=tools,
tool_choice="auto"
)
assert response.choices[0].message.content == "Memory added successfully."

69
tests/llms/test_openai.py Normal file

@@ -0,0 +1,69 @@
import pytest
from unittest.mock import Mock, patch
from mem0.llms.openai import OpenAILLM
@pytest.fixture
def mock_groq_client():
with patch('mem0.llms.openai.OpenAI') as mock_groq:
mock_client = Mock()
mock_groq.return_value = mock_client
yield mock_client
def test_generate_response_without_tools(mock_groq_client):
llm = OpenAILLM()
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"}
]
mock_response = Mock()
mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))]
mock_groq_client.chat.completions.create.return_value = mock_response
response = llm.generate_response(messages)
mock_groq_client.chat.completions.create.assert_called_once_with(
model="gpt-4o",
messages=messages
)
assert response.choices[0].message.content == "I'm doing well, thank you for asking!"
def test_generate_response_with_tools(mock_groq_client):
llm = OpenAILLM()
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Add a new memory: Today is a sunny day."}
]
tools = [
{
"type": "function",
"function": {
"name": "add_memory",
"description": "Add a memory",
"parameters": {
"type": "object",
"properties": {
"data": {"type": "string", "description": "Data to add to memory"}
},
"required": ["data"],
},
},
}
]
mock_response = Mock()
mock_response.choices = [Mock(message=Mock(content="Memory added successfully."))]
mock_groq_client.chat.completions.create.return_value = mock_response
response = llm.generate_response(messages, tools=tools)
mock_groq_client.chat.completions.create.assert_called_once_with(
model="gpt-4o",
messages=messages,
tools=tools,
tool_choice="auto"
)
assert response.choices[0].message.content == "Memory added successfully."