Add embedder docs and config changes (#1684)
This commit is contained in:
64
docs/components/embedders.mdx
Normal file
64
docs/components/embedders.mdx
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
---
|
||||||
|
title: 🧩 Embedding models
|
||||||
|
---
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
Mem0 offers support for various embedding models, allowing users to choose the one that best suits their needs.
|
||||||
|
|
||||||
|
<CardGroup cols={3}>
|
||||||
|
<Card title="OpenAI" href="#openai"></Card>
|
||||||
|
<Card title="Ollama" href="#ollama"></Card>
|
||||||
|
</CardGroup>
|
||||||
|
|
||||||
|
> When using `Qdrant` as a vector database, ensure you update the `embedding_model_dims` to match the dimensions of the embedding model you are using.
|
||||||
|
|
||||||
|
## OpenAI
|
||||||
|
|
||||||
|
To use OpenAI embedding models, set the `OPENAI_API_KEY` environment variable. You can obtain the OpenAI API key from the [OpenAI Platform](https://platform.openai.com/account/api-keys).
|
||||||
|
|
||||||
|
Example of how to select the desired embedding model:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
from mem0 import Memory
|
||||||
|
|
||||||
|
os.environ["OPENAI_API_KEY"] = "your_api_key"
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"embedder": {
|
||||||
|
"provider": "openai",
|
||||||
|
"config": {
|
||||||
|
"model": "text-embedding-3-large"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m = Memory.from_config(config)
|
||||||
|
m.add("I'm visiting Paris", user_id="john")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Ollama
|
||||||
|
|
||||||
|
You can use embedding models from Ollama to run Mem0 locally.
|
||||||
|
|
||||||
|
Here's how to select it:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
from mem0 import Memory
|
||||||
|
|
||||||
|
os.environ["OPENAI_API_KEY"] = "your_api_key"
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"embedder": {
|
||||||
|
"provider": "ollama",
|
||||||
|
"config": {
|
||||||
|
"model": "mxbai-embed-large"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m = Memory.from_config(config)
|
||||||
|
m.add("I'm visiting Paris", user_id="john")
|
||||||
|
```
|
||||||
@@ -69,6 +69,10 @@
|
|||||||
"group": "Vector Database",
|
"group": "Vector Database",
|
||||||
"pages": ["components/vectordb"]
|
"pages": ["components/vectordb"]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"group": "Embedding Models",
|
||||||
|
"pages": ["components/embedders"]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"group": "Features",
|
"group": "Features",
|
||||||
"pages": ["features/openai_compatibility"]
|
"pages": ["features/openai_compatibility"]
|
||||||
|
|||||||
@@ -9,9 +9,12 @@ class BaseEmbedderConfig(ABC):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: Optional[str] = None,
|
model: Optional[str] = None,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
embedding_dims: Optional[int] = None,
|
embedding_dims: Optional[int] = None,
|
||||||
|
|
||||||
# Ollama specific
|
# Ollama specific
|
||||||
base_url: Optional[str] = None,
|
ollama_base_url: Optional[str] = None,
|
||||||
|
|
||||||
# Huggingface specific
|
# Huggingface specific
|
||||||
model_kwargs: Optional[dict] = None
|
model_kwargs: Optional[dict] = None
|
||||||
):
|
):
|
||||||
@@ -20,20 +23,23 @@ class BaseEmbedderConfig(ABC):
|
|||||||
|
|
||||||
:param model: Embedding model to use, defaults to None
|
:param model: Embedding model to use, defaults to None
|
||||||
:type model: Optional[str], optional
|
:type model: Optional[str], optional
|
||||||
|
:param api_key: API key to be used, defaults to None
|
||||||
|
:type api_key: Optional[str], optional
|
||||||
:param embedding_dims: The number of dimensions in the embedding, defaults to None
|
:param embedding_dims: The number of dimensions in the embedding, defaults to None
|
||||||
:type embedding_dims: Optional[int], optional
|
:type embedding_dims: Optional[int], optional
|
||||||
:param base_url: Base URL for the Ollama API, defaults to None
|
:param ollama_base_url: Base URL for the Ollama API, defaults to None
|
||||||
:type base_url: Optional[str], optional
|
:type ollama_base_url: Optional[str], optional
|
||||||
:param model_kwargs: key-value arguments for the huggingface embedding model, defaults a dict inside init
|
:param model_kwargs: key-value arguments for the huggingface embedding model, defaults a dict inside init
|
||||||
:type model_kwargs: Optional[Dict[str, Any]], defaults a dict inside init
|
:type model_kwargs: Optional[Dict[str, Any]], defaults a dict inside init
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.api_key = api_key
|
||||||
self.embedding_dims = embedding_dims
|
self.embedding_dims = embedding_dims
|
||||||
|
|
||||||
# Ollama specific
|
# Ollama specific
|
||||||
self.base_url = base_url
|
self.ollama_base_url = ollama_base_url
|
||||||
|
|
||||||
# Huggingface specific
|
# Huggingface specific
|
||||||
self.model_kwargs = model_kwargs or {}
|
self.model_kwargs = model_kwargs or {}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ class BaseLlmConfig(ABC):
|
|||||||
self,
|
self,
|
||||||
model: Optional[str] = None,
|
model: Optional[str] = None,
|
||||||
temperature: float = 0,
|
temperature: float = 0,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
max_tokens: int = 3000,
|
max_tokens: int = 3000,
|
||||||
top_p: float = 0,
|
top_p: float = 0,
|
||||||
top_k: int = 1,
|
top_k: int = 1,
|
||||||
@@ -32,6 +33,8 @@ class BaseLlmConfig(ABC):
|
|||||||
:param temperature: Controls the randomness of the model's output.
|
:param temperature: Controls the randomness of the model's output.
|
||||||
Higher values (closer to 1) make output more random, lower values make it more deterministic, defaults to 0
|
Higher values (closer to 1) make output more random, lower values make it more deterministic, defaults to 0
|
||||||
:type temperature: float, optional
|
:type temperature: float, optional
|
||||||
|
:param api_key: OpenAI API key to be used, defaults to None
|
||||||
|
:type api_key: Optional[str], optional
|
||||||
:param max_tokens: Controls how many tokens are generated, defaults to 3000
|
:param max_tokens: Controls how many tokens are generated, defaults to 3000
|
||||||
:type max_tokens: int, optional
|
:type max_tokens: int, optional
|
||||||
:param top_p: Controls the diversity of words. Higher values (closer to 1) make word selection more diverse,
|
:param top_p: Controls the diversity of words. Higher values (closer to 1) make word selection more diverse,
|
||||||
@@ -39,15 +42,15 @@ class BaseLlmConfig(ABC):
|
|||||||
:type top_p: float, optional
|
:type top_p: float, optional
|
||||||
:param top_k: Controls the diversity of words. Higher values make word selection more diverse, defaults to 0
|
:param top_k: Controls the diversity of words. Higher values make word selection more diverse, defaults to 0
|
||||||
:type top_k: int, optional
|
:type top_k: int, optional
|
||||||
:param models: Controls the Openrouter models used, defaults to None
|
:param models: Openrouter models to use, defaults to None
|
||||||
:type models: Optional[list[str]], optional
|
:type models: Optional[list[str]], optional
|
||||||
:param route: Controls the Openrouter route used, defaults to "fallback"
|
:param route: Openrouter route to be used, defaults to "fallback"
|
||||||
:type route: Optional[str], optional
|
:type route: Optional[str], optional
|
||||||
:param openrouter_base_url: Controls the Openrouter base URL used, defaults to "https://openrouter.ai/api/v1"
|
:param openrouter_base_url: Openrouter base URL to be used, defaults to "https://openrouter.ai/api/v1"
|
||||||
:type openrouter_base_url: Optional[str], optional
|
:type openrouter_base_url: Optional[str], optional
|
||||||
:param site_url: Controls the Openrouter site URL used, defaults to None
|
:param site_url: Openrouter site URL to use, defaults to None
|
||||||
:type site_url: Optional[str], optional
|
:type site_url: Optional[str], optional
|
||||||
:param app_name: Controls the Openrouter app name used, defaults to None
|
:param app_name: Openrouter app name to use, defaults to None
|
||||||
:type app_name: Optional[str], optional
|
:type app_name: Optional[str], optional
|
||||||
:param ollama_base_url: The base URL of the LLM, defaults to None
|
:param ollama_base_url: The base URL of the LLM, defaults to None
|
||||||
:type ollama_base_url: Optional[str], optional
|
:type ollama_base_url: Optional[str], optional
|
||||||
@@ -55,6 +58,7 @@ class BaseLlmConfig(ABC):
|
|||||||
|
|
||||||
self.model = model
|
self.model = model
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
|
self.api_key = api_key
|
||||||
self.max_tokens = max_tokens
|
self.max_tokens = max_tokens
|
||||||
self.top_p = top_p
|
self.top_p = top_p
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Optional,ClassVar
|
from typing import Optional, ClassVar, Dict, Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
@@ -21,6 +21,17 @@ class ChromaDbConfig(BaseModel):
|
|||||||
if not path and not (host and port):
|
if not path and not (host and port):
|
||||||
raise ValueError("Either 'host' and 'port' or 'path' must be provided.")
|
raise ValueError("Either 'host' and 'port' or 'path' must be provided.")
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
allowed_fields = set(cls.model_fields.keys())
|
||||||
|
input_fields = set(values.keys())
|
||||||
|
extra_fields = input_fields - allowed_fields
|
||||||
|
if extra_fields:
|
||||||
|
raise ValueError(f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}")
|
||||||
|
return values
|
||||||
|
|
||||||
class Config:
|
model_config = {
|
||||||
arbitrary_types_allowed = True
|
"arbitrary_types_allowed": True,
|
||||||
|
}
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
from typing import Optional,ClassVar
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
from typing import Optional, ClassVar, Dict, Any
|
||||||
|
|
||||||
class QdrantConfig(BaseModel):
|
class QdrantConfig(BaseModel):
|
||||||
from qdrant_client import QdrantClient
|
from qdrant_client import QdrantClient
|
||||||
@@ -14,10 +13,11 @@ class QdrantConfig(BaseModel):
|
|||||||
path: Optional[str] = Field("/tmp/qdrant", description="Path for local Qdrant database")
|
path: Optional[str] = Field("/tmp/qdrant", description="Path for local Qdrant database")
|
||||||
url: Optional[str] = Field(None, description="Full URL for Qdrant server")
|
url: Optional[str] = Field(None, description="Full URL for Qdrant server")
|
||||||
api_key: Optional[str] = Field(None, description="API key for Qdrant server")
|
api_key: Optional[str] = Field(None, description="API key for Qdrant server")
|
||||||
on_disk: Optional[bool]= Field(False, description="Enables persistant storage")
|
on_disk: Optional[bool] = Field(False, description="Enables persistent storage")
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
def check_host_port_or_path(cls, values):
|
@classmethod
|
||||||
|
def check_host_port_or_path(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
host, port, path, url, api_key = (
|
host, port, path, url, api_key = (
|
||||||
values.get("host"),
|
values.get("host"),
|
||||||
values.get("port"),
|
values.get("port"),
|
||||||
@@ -31,5 +31,16 @@ class QdrantConfig(BaseModel):
|
|||||||
)
|
)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
class Config:
|
@model_validator(mode="before")
|
||||||
arbitrary_types_allowed = True
|
@classmethod
|
||||||
|
def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
allowed_fields = set(cls.model_fields.keys())
|
||||||
|
input_fields = set(values.keys())
|
||||||
|
extra_fields = input_fields - allowed_fields
|
||||||
|
if extra_fields:
|
||||||
|
raise ValueError(f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}")
|
||||||
|
return values
|
||||||
|
|
||||||
|
model_config = {
|
||||||
|
"arbitrary_types_allowed": True,
|
||||||
|
}
|
||||||
@@ -18,7 +18,7 @@ class OllamaEmbedding(EmbeddingBase):
|
|||||||
if not self.config.embedding_dims:
|
if not self.config.embedding_dims:
|
||||||
self.config.embedding_dims=512
|
self.config.embedding_dims=512
|
||||||
|
|
||||||
self.client = Client(host=self.config.base_url)
|
self.client = Client(host=self.config.ollama_base_url)
|
||||||
self._ensure_model_exists()
|
self._ensure_model_exists()
|
||||||
|
|
||||||
def _ensure_model_exists(self):
|
def _ensure_model_exists(self):
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import os
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
@@ -9,12 +10,11 @@ class OpenAIEmbedding(EmbeddingBase):
|
|||||||
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
|
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
if not self.config.model:
|
self.config.model = self.config.model or "text-embedding-3-small"
|
||||||
self.config.model="text-embedding-3-small"
|
self.config.embedding_dims = self.config.embedding_dims or 1536
|
||||||
if not self.config.embedding_dims:
|
|
||||||
self.config.embedding_dims=1536
|
|
||||||
|
|
||||||
self.client = OpenAI()
|
api_key = os.getenv("OPENAI_API_KEY") or self.config.api_key
|
||||||
|
self.client = OpenAI(api_key=api_key)
|
||||||
|
|
||||||
def embed(self, text):
|
def embed(self, text):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -14,10 +14,11 @@ class OpenAILLM(LLMBase):
|
|||||||
if not self.config.model:
|
if not self.config.model:
|
||||||
self.config.model="gpt-4o"
|
self.config.model="gpt-4o"
|
||||||
|
|
||||||
if os.environ.get("OPENROUTER_API_KEY"):
|
if os.environ.get("OPENROUTER_API_KEY"): # Use OpenRouter
|
||||||
self.client = OpenAI(api_key=os.environ.get("OPENROUTER_API_KEY"), base_url=self.config.openrouter_base_url)
|
self.client = OpenAI(api_key=os.environ.get("OPENROUTER_API_KEY"), base_url=self.config.openrouter_base_url)
|
||||||
else:
|
else:
|
||||||
self.client = OpenAI()
|
api_key = os.getenv("OPENAI_API_KEY") or self.config.api_key
|
||||||
|
self.client = OpenAI(api_key=api_key)
|
||||||
|
|
||||||
def _parse_response(self, response, tools):
|
def _parse_response(self, response, tools):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -374,6 +374,7 @@ class Memory(MemoryBase):
|
|||||||
|
|
||||||
new_metadata = metadata or {}
|
new_metadata = metadata or {}
|
||||||
new_metadata["data"] = data
|
new_metadata["data"] = data
|
||||||
|
new_metadata["hash"] = existing_memory.payload.get("hash")
|
||||||
new_metadata["created_at"] = existing_memory.payload.get("created_at")
|
new_metadata["created_at"] = existing_memory.payload.get("created_at")
|
||||||
new_metadata["updated_at"] = datetime.now(pytz.timezone('US/Pacific')).isoformat()
|
new_metadata["updated_at"] = datetime.now(pytz.timezone('US/Pacific')).isoformat()
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from pydantic import BaseModel, Field, model_validator
|
|||||||
|
|
||||||
class VectorStoreConfig(BaseModel):
|
class VectorStoreConfig(BaseModel):
|
||||||
provider: str = Field(
|
provider: str = Field(
|
||||||
description="Provider of the vector store (e.g., 'qdrant', 'chromadb')",
|
description="Provider of the vector store (e.g., 'qdrant', 'chroma')",
|
||||||
default="qdrant",
|
default="qdrant",
|
||||||
)
|
)
|
||||||
config: Optional[Dict] = Field(
|
config: Optional[Dict] = Field(
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "mem0ai"
|
name = "mem0ai"
|
||||||
version = "0.0.15"
|
version = "0.0.16"
|
||||||
description = "Long-term memory for AI Agents"
|
description = "Long-term memory for AI Agents"
|
||||||
authors = ["Mem0 <founders@mem0.ai>"]
|
authors = ["Mem0 <founders@mem0.ai>"]
|
||||||
exclude = [
|
exclude = [
|
||||||
|
|||||||
Reference in New Issue
Block a user