From 19637804b3ced945bccf9a82ebf6098c3b7e23c3 Mon Sep 17 00:00:00 2001
From: Dev Khant
Date: Tue, 16 Jul 2024 23:33:28 +0530
Subject: [PATCH] Add Groq Support (#1481)

---
 docs/llms.mdx              | 66 ++++++++++++++++++++++++++++++++++++
 docs/mint.json             |  6 ++++
 mem0/embeddings/configs.py | 22 ++++++++++++
 mem0/llms/configs.py       | 21 ++++++++++++
 mem0/llms/groq.py          | 40 ++++++++++++++++++++++
 mem0/memory/main.py        | 17 +++++++---
 mem0/utils/factory.py      | 41 ++++++++++++++++++++++
 poetry.lock                | 23 +++++++++++--
 pyproject.toml             |  1 +
 tests/llms/test_groq.py    | 69 ++++++++++++++++++++++++++++++++++++++
 tests/llms/test_openai.py  | 69 ++++++++++++++++++++++++++++++++++++++
 11 files changed, 369 insertions(+), 6 deletions(-)
 create mode 100644 docs/llms.mdx
 create mode 100644 mem0/embeddings/configs.py
 create mode 100644 mem0/llms/configs.py
 create mode 100644 mem0/llms/groq.py
 create mode 100644 mem0/utils/factory.py
 create mode 100644 tests/llms/test_groq.py
 create mode 100644 tests/llms/test_openai.py

diff --git a/docs/llms.mdx b/docs/llms.mdx
new file mode 100644
index 00000000..0d55de40
--- /dev/null
+++ b/docs/llms.mdx
@@ -0,0 +1,66 @@
+---
+title: 🤖 Large language models (LLMs)
+---
+
+## Overview
+
+Mem0 includes built-in support for various popular large language models. Memory can be configured to use the LLM of your choice, so you can pick the provider and model that best fit your use case.
+
+
+
+
+
+
+## OpenAI
+
+To use OpenAI models, set the `OPENAI_API_KEY` environment variable. You can obtain the key from the [OpenAI Platform](https://platform.openai.com/account/api-keys).
+
+Once you have obtained the key, you can use it like this:
+
+```python
+import os
+from mem0 import Memory
+
+os.environ['OPENAI_API_KEY'] = 'xxx'
+
+config = {
+    "llm": {
+        "provider": "openai",
+        "config": {
+            "model": "gpt-4o",
+            "temperature": 0.2,
+            "max_tokens": 1500,
+        }
+    }
+}
+
+m = Memory.from_config(config)
+m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"})
+```
+
+## Groq
+
+[Groq](https://groq.com/) is the creator of the world's first Language Processing Unit (LPU), delivering exceptionally fast inference for AI workloads running on its LPU Inference Engine.
+
+To use LLMs from Groq, get an API key from their [platform](https://console.groq.com/keys) and set it as the `GROQ_API_KEY` environment variable, as shown in the example below.
+
+```python
+import os
+from mem0 import Memory
+
+os.environ['GROQ_API_KEY'] = 'xxx'
+
+config = {
+    "llm": {
+        "provider": "groq",
+        "config": {
+            "model": "mixtral-8x7b-32768",
+            "temperature": 0.1,
+            "max_tokens": 1000,
+        }
+    }
+}
+
+m = Memory.from_config(config)
+m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"})
+```
diff --git a/docs/mint.json b/docs/mint.json
index 98b18475..03ca08c1 100644
--- a/docs/mint.json
+++ b/docs/mint.json
@@ -53,6 +53,12 @@
         "quickstart"
       ]
     },
+    {
+      "group": "LLMs",
+      "pages": [
+        "llms"
+      ]
+    },
     {
       "group": "💡 Examples",
       "pages": [
diff --git a/mem0/embeddings/configs.py b/mem0/embeddings/configs.py
new file mode 100644
index 00000000..d92e7e99
--- /dev/null
+++ b/mem0/embeddings/configs.py
@@ -0,0 +1,22 @@
+from typing import Optional
+
+from pydantic import BaseModel, Field, field_validator
+
+
+class EmbedderConfig(BaseModel):
+    provider: str = Field(
+        description="Provider of the embedding model (e.g., 'ollama', 'openai')",
+        default="openai",
+    )
+    config: Optional[dict] = Field(
+        description="Configuration for the specific embedding model", default=None
+    )
+
+    @field_validator("config")
+    def validate_config(cls, v, values):
+        provider = values.data.get("provider")
+        if provider in ["openai", "ollama"]:
+            return v
+        else:
+            raise ValueError(f"Unsupported embedding provider: {provider}")
+    
\ No newline at end of file
diff --git a/mem0/llms/configs.py b/mem0/llms/configs.py
new file mode 100644
index 00000000..b28a8a5e
--- /dev/null
+++ b/mem0/llms/configs.py
@@ -0,0 +1,21 @@
+from typing import Optional
+
+from pydantic import BaseModel, Field, field_validator
+
+
+class LlmConfig(BaseModel):
+    provider: str = Field(
+        description="Provider of the LLM (e.g., 'ollama', 'openai')", default="openai"
+    )
+    config: Optional[dict] = Field(
+        description="Configuration for the specific LLM", default=None
+    )
+
+    @field_validator("config")
+    def validate_config(cls, v, values):
+        provider = values.data.get("provider")
+        if provider in ["openai", "ollama", "groq"]:
+            return v
+        else:
+            raise ValueError(f"Unsupported LLM provider: {provider}")
+    
\ No newline at end of file
diff --git a/mem0/llms/groq.py b/mem0/llms/groq.py
new file mode 100644
index 00000000..9d662899
--- /dev/null
+++ b/mem0/llms/groq.py
@@ -0,0 +1,40 @@
+from typing import Dict, List, Optional
+
+from groq import Groq
+
+from mem0.llms.base import LLMBase
+
+
+class GroqLLM(LLMBase):
+    def __init__(self, model="llama3-70b-8192"):
+        self.client = Groq()
+        self.model = model
+
+    def generate_response(
+        self,
+        messages: List[Dict[str, str]],
+        response_format=None,
+        tools: Optional[List[Dict]] = None,
+        tool_choice: str = "auto",
+    ):
+        """
+        Generate a response based on the given messages using Groq.
+
+        Args:
+            messages (list): List of message dicts containing 'role' and 'content'.
+            response_format (str or object, optional): Format of the response. Defaults to None.
+            tools (list, optional): List of tools that the model can call. Defaults to None.
+            tool_choice (str, optional): Tool choice method. Defaults to "auto".
+
+        Returns:
+            The raw chat completion response returned by the Groq client.
+        """
+        params = {"model": self.model, "messages": messages}
+        if response_format:
+            params["response_format"] = response_format
+        if tools:
+            params["tools"] = tools
+            params["tool_choice"] = tool_choice
+
+        response = self.client.chat.completions.create(**params)
+        return response
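To make the new wrapper concrete, here is a minimal sketch of calling `GroqLLM` directly (illustrative only, not part of the patch; it assumes a valid `GROQ_API_KEY` in the environment):

```python
import os

from mem0.llms.groq import GroqLLM

os.environ["GROQ_API_KEY"] = "xxx"  # assumption: a valid Groq API key

llm = GroqLLM()  # defaults to the llama3-70b-8192 model
response = llm.generate_response(
    [{"role": "user", "content": "Say hello in one sentence."}]
)

# generate_response returns the raw chat completion object,
# so the text lives at response.choices[0].message.content
print(response.choices[0].message.content)
```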
+ """ + params = {"model": self.model, "messages": messages} + if response_format: + params["response_format"] = response_format + if tools: + params["tools"] = tools + params["tool_choice"] = tool_choice + + response = self.client.chat.completions.create(**params) + return response diff --git a/mem0/memory/main.py b/mem0/memory/main.py index 0bb89c64..89e1e0ae 100644 --- a/mem0/memory/main.py +++ b/mem0/memory/main.py @@ -7,8 +7,6 @@ from typing import Any, Dict, Optional from pydantic import BaseModel, Field, ValidationError -from mem0.embeddings.openai import OpenAIEmbedding -from mem0.llms.openai import OpenAILLM from mem0.llms.utils.tools import ( ADD_MEMORY_TOOL, DELETE_MEMORY_TOOL, @@ -21,7 +19,10 @@ from mem0.memory.storage import SQLiteManager from mem0.memory.telemetry import capture_event from mem0.memory.utils import get_update_memory_messages from mem0.vector_stores.configs import VectorStoreConfig +from mem0.llms.configs import LlmConfig +from mem0.embeddings.configs import EmbedderConfig from mem0.vector_stores.qdrant import Qdrant +from mem0.utils.factory import LlmFactory, EmbedderFactory # Setup user config setup_config() @@ -44,6 +45,14 @@ class MemoryConfig(BaseModel): description="Configuration for the vector store", default_factory=VectorStoreConfig, ) + llm: LlmConfig = Field( + description="Configuration for the language model", + default_factory=LlmConfig, + ) + embedder: EmbedderConfig = Field( + description="Configuration for the embedding model", + default_factory=EmbedderConfig, + ) history_db_path: str = Field( description="Path to the history database", default=os.path.join(mem0_dir, "history.db"), @@ -57,7 +66,7 @@ class MemoryConfig(BaseModel): class Memory(MemoryBase): def __init__(self, config: MemoryConfig = MemoryConfig()): self.config = config - self.embedding_model = OpenAIEmbedding() + self.embedding_model = EmbedderFactory.create(self.config.embedder.provider) # Initialize the appropriate vector store based on the configuration vector_store_config = self.config.vector_store.config if self.config.vector_store.provider == "qdrant": @@ -73,7 +82,7 @@ class Memory(MemoryBase): f"Unsupported vector store type: {self.config.vector_store_type}" ) - self.llm = OpenAILLM() + self.llm = LlmFactory.create(self.config.llm.provider) self.db = SQLiteManager(self.config.history_db_path) self.collection_name = self.config.collection_name self.vector_store.create_col( diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py new file mode 100644 index 00000000..15fa0092 --- /dev/null +++ b/mem0/utils/factory.py @@ -0,0 +1,41 @@ +import importlib + + +def load_class(class_type): + module_path, class_name = class_type.rsplit(".", 1) + module = importlib.import_module(module_path) + return getattr(module, class_name) + + +class LlmFactory: + provider_to_class = { + "ollama": "mem0.llms.ollama.py.OllamaLLM", + "openai": "mem0.llms.openai.OpenAILLM", + "groq": "mem0.llms.groq.GroqLLM" + } + + @classmethod + def create(cls, provider_name): + class_type = cls.provider_to_class.get(provider_name) + if class_type: + llm_instance = load_class(class_type)() + return llm_instance + else: + raise ValueError(f"Unsupported Llm provider: {provider_name}") + +class EmbedderFactory: + provider_to_class = { + "openai": "mem0.embeddings.openai.OpenAIEmbedding", + "ollama": "mem0.embeddings.ollama.OllamaEmbedding", + "huggingface": "mem0.embeddings.huggingface.HuggingFaceEmbedding" + } + + @classmethod + def create(cls, provider_name): + class_type = 
diff --git a/poetry.lock b/poetry.lock
index 568429da..bccaf58d 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
 
 [[package]]
 name = "annotated-types"
@@ -370,6 +370,25 @@ files = [
 [package.extras]
 tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich"]
 
+[[package]]
+name = "groq"
+version = "0.9.0"
+description = "The official Python library for the groq API"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "groq-0.9.0-py3-none-any.whl", hash = "sha256:d0e46f4ad645504672bb09c8100af3ced3a7db0d5119dc13e4aca535fc455874"},
+    {file = "groq-0.9.0.tar.gz", hash = "sha256:130ed5e35d3acfaab46b9e7a078eeaebf91052f4a9d71f86f87fb319b5fec332"},
+]
+
+[package.dependencies]
+anyio = ">=3.5.0,<5"
+distro = ">=1.7.0,<2"
+httpx = ">=0.23.0,<1"
+pydantic = ">=1.9.0,<3"
+sniffio = "*"
+typing-extensions = ">=4.7,<5"
+
 [[package]]
 name = "grpcio"
 version = "1.64.1"
@@ -1707,4 +1726,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", 
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.8"
-content-hash = "5138c101a58db8dbddcb640545a5b2b4fc482f9e555008d117e315ae292d7697"
+content-hash = "7216c3479e9bce779f99016825bfb726399ffb0ac5f942ac73b899fc373efd37"
diff --git a/pyproject.toml b/pyproject.toml
index 8ad51594..4f976e41 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,6 +20,7 @@ qdrant-client = "^1.9.1"
 pydantic = "^2.7.3"
 openai = "^1.33.0"
 posthog = "^3.5.0"
+groq = "^0.9.0"
 
 
 [tool.poetry.group.test.dependencies]
diff --git a/tests/llms/test_groq.py b/tests/llms/test_groq.py
new file mode 100644
index 00000000..6c6ede3e
--- /dev/null
+++ b/tests/llms/test_groq.py
@@ -0,0 +1,69 @@
+import pytest
+from unittest.mock import Mock, patch
+from mem0.llms.groq import GroqLLM
+
+@pytest.fixture
+def mock_groq_client():
+    with patch('mem0.llms.groq.Groq') as mock_groq:
+        mock_client = Mock()
+        mock_groq.return_value = mock_client
+        yield mock_client
+
+
+def test_generate_response_without_tools(mock_groq_client):
+    llm = GroqLLM()
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Hello, how are you?"}
+    ]
+
+    mock_response = Mock()
+    mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))]
+    mock_groq_client.chat.completions.create.return_value = mock_response
+
+    response = llm.generate_response(messages)
+
+    mock_groq_client.chat.completions.create.assert_called_once_with(
+        model="llama3-70b-8192",
+        messages=messages
+    )
+    assert response.choices[0].message.content == "I'm doing well, thank you for asking!"
+
+
+def test_generate_response_with_tools(mock_groq_client):
+    llm = GroqLLM()
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Add a new memory: Today is a sunny day."}
+    ]
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "add_memory",
+                "description": "Add a memory",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "data": {"type": "string", "description": "Data to add to memory"}
+                    },
+                    "required": ["data"],
+                },
+            },
+        }
+    ]
+
+    mock_response = Mock()
+    mock_response.choices = [Mock(message=Mock(content="Memory added successfully."))]
+    mock_groq_client.chat.completions.create.return_value = mock_response
+
+    response = llm.generate_response(messages, tools=tools)
+
+    mock_groq_client.chat.completions.create.assert_called_once_with(
+        model="llama3-70b-8192",
+        messages=messages,
+        tools=tools,
+        tool_choice="auto"
+    )
+    assert response.choices[0].message.content == "Memory added successfully."
+    
\ No newline at end of file
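The factories themselves ship without tests in this patch; a hypothetical sketch of what a minimal one could look like (file name and placement are assumptions, not part of the change):

```python
import pytest

from mem0.utils.factory import EmbedderFactory, LlmFactory


def test_llm_factory_rejects_unknown_provider():
    # LlmFactory.create raises for providers missing from provider_to_class
    with pytest.raises(ValueError, match="Unsupported Llm provider"):
        LlmFactory.create("unknown")


def test_embedder_factory_rejects_unknown_provider():
    with pytest.raises(ValueError, match="Unsupported Embedder provider"):
        EmbedderFactory.create("unknown")
```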
diff --git a/tests/llms/test_openai.py b/tests/llms/test_openai.py
new file mode 100644
index 00000000..796f87b1
--- /dev/null
+++ b/tests/llms/test_openai.py
@@ -0,0 +1,69 @@
+import pytest
+from unittest.mock import Mock, patch
+from mem0.llms.openai import OpenAILLM
+
+@pytest.fixture
+def mock_openai_client():
+    with patch('mem0.llms.openai.OpenAI') as mock_openai:
+        mock_client = Mock()
+        mock_openai.return_value = mock_client
+        yield mock_client
+
+
+def test_generate_response_without_tools(mock_openai_client):
+    llm = OpenAILLM()
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Hello, how are you?"}
+    ]
+
+    mock_response = Mock()
+    mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))]
+    mock_openai_client.chat.completions.create.return_value = mock_response
+
+    response = llm.generate_response(messages)
+
+    mock_openai_client.chat.completions.create.assert_called_once_with(
+        model="gpt-4o",
+        messages=messages
+    )
+    assert response.choices[0].message.content == "I'm doing well, thank you for asking!"
+
+
+def test_generate_response_with_tools(mock_openai_client):
+    llm = OpenAILLM()
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Add a new memory: Today is a sunny day."}
+    ]
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "add_memory",
+                "description": "Add a memory",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "data": {"type": "string", "description": "Data to add to memory"}
+                    },
+                    "required": ["data"],
+                },
+            },
+        }
+    ]
+
+    mock_response = Mock()
+    mock_response.choices = [Mock(message=Mock(content="Memory added successfully."))]
+    mock_openai_client.chat.completions.create.return_value = mock_response
+
+    response = llm.generate_response(messages, tools=tools)
+
+    mock_openai_client.chat.completions.create.assert_called_once_with(
+        model="gpt-4o",
+        messages=messages,
+        tools=tools,
+        tool_choice="auto"
+    )
+    assert response.choices[0].message.content == "Memory added successfully."
+    
\ No newline at end of file
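One usage note for this change: `EmbedderConfig` still defaults to the `openai` provider, so a Groq-backed `Memory` needs both API keys available. A minimal sketch (illustrative, not part of the patch; keys are placeholders):

```python
import os

from mem0 import Memory

# Assumption: both keys are valid. LLM calls go to Groq, while embeddings
# still use the default OpenAI provider from EmbedderConfig.
os.environ["GROQ_API_KEY"] = "xxx"
os.environ["OPENAI_API_KEY"] = "xxx"

m = Memory.from_config({"llm": {"provider": "groq"}})
m.add("Likes to play cricket on weekends", user_id="alice")
```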