Add embedder docs and config changes (#1684)
docs/components/embedders.mdx (new file)
@@ -0,0 +1,64 @@
---
title: 🧩 Embedding models
---

## Overview

Mem0 offers support for various embedding models, allowing users to choose the one that best suits their needs.

<CardGroup cols={3}>
  <Card title="OpenAI" href="#openai"></Card>
  <Card title="Ollama" href="#ollama"></Card>
</CardGroup>

> When using `Qdrant` as a vector database, ensure you update the `embedding_model_dims` to match the dimensions of the embedding model you are using.
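
For example, `text-embedding-3-large` produces 3072-dimensional vectors, so a matching setup might look like the following sketch (the `vector_store` keys mirror the `VectorStoreConfig` shape touched later in this commit; treat the exact layout as an assumption):

```python
config = {
    "vector_store": {
        "provider": "qdrant",
        "config": {
            "embedding_model_dims": 3072,  # must match the embedder's output dimensions
        },
    },
    "embedder": {
        "provider": "openai",
        "config": {
            "model": "text-embedding-3-large"
        },
    },
}
```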
## OpenAI

To use OpenAI embedding models, set the `OPENAI_API_KEY` environment variable. You can obtain the OpenAI API key from the [OpenAI Platform](https://platform.openai.com/account/api-keys).

Example of how to select the desired embedding model:

```python
import os
from mem0 import Memory

os.environ["OPENAI_API_KEY"] = "your_api_key"

config = {
    "embedder": {
        "provider": "openai",
        "config": {
            "model": "text-embedding-3-large"
        }
    }
}

m = Memory.from_config(config)
m.add("I'm visiting Paris", user_id="john")
```

## Ollama

You can use embedding models from Ollama to run Mem0 locally. Note that `OPENAI_API_KEY` is still set below because the default LLM remains OpenAI; only the embedder is switched to Ollama.

Here's how to select it:

```python
import os
from mem0 import Memory

os.environ["OPENAI_API_KEY"] = "your_api_key"

config = {
    "embedder": {
        "provider": "ollama",
        "config": {
            "model": "mxbai-embed-large"
        }
    }
}

m = Memory.from_config(config)
m.add("I'm visiting Paris", user_id="john")
```
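
This commit also adds an `ollama_base_url` option to the embedder config, so a non-default Ollama host can be targeted directly. A minimal sketch, assuming the provider `config` keys map onto the new `BaseEmbedderConfig` fields (the URL is a placeholder; 11434 is Ollama's usual default port):

```python
config = {
    "embedder": {
        "provider": "ollama",
        "config": {
            "model": "mxbai-embed-large",
            "ollama_base_url": "http://localhost:11434"  # placeholder; point at your Ollama server
        }
    }
}
```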
@@ -69,6 +69,10 @@
       "group": "Vector Database",
       "pages": ["components/vectordb"]
     },
+    {
+      "group": "Embedding Models",
+      "pages": ["components/embedders"]
+    },
     {
       "group": "Features",
       "pages": ["features/openai_compatibility"]
@@ -9,9 +9,12 @@ class BaseEmbedderConfig(ABC):
     def __init__(
         self,
         model: Optional[str] = None,
         api_key: Optional[str] = None,
         embedding_dims: Optional[int] = None,

         # Ollama specific
         base_url: Optional[str] = None,
         ollama_base_url: Optional[str] = None,

         # Huggingface specific
         model_kwargs: Optional[dict] = None
     ):
@@ -20,20 +23,23 @@ class BaseEmbedderConfig(ABC):
         :param model: Embedding model to use, defaults to None
         :type model: Optional[str], optional
         :param api_key: API key to be used, defaults to None
         :type api_key: Optional[str], optional
         :param embedding_dims: The number of dimensions in the embedding, defaults to None
         :type embedding_dims: Optional[int], optional
         :param base_url: Base URL for the Ollama API, defaults to None
         :type base_url: Optional[str], optional
         :param ollama_base_url: Base URL for the Ollama API, defaults to None
         :type ollama_base_url: Optional[str], optional
         :param model_kwargs: Key-value arguments for the Huggingface embedding model, defaults to a dict inside init
         :type model_kwargs: Optional[Dict[str, Any]], optional
         """

         self.model = model
         self.api_key = api_key
         self.embedding_dims = embedding_dims

         # Ollama specific
         self.base_url = base_url
         self.ollama_base_url = ollama_base_url

         # Huggingface specific
         self.model_kwargs = model_kwargs or {}
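
Taken together, a direct construction of the updated config might look like this sketch (the import path, model names, and kwargs are illustrative, not defaults from this commit):

```python
from mem0.configs.embeddings.base import BaseEmbedderConfig  # assumed import path

# Huggingface-style config using the new model_kwargs passthrough
hf_cfg = BaseEmbedderConfig(
    model="sentence-transformers/all-MiniLM-L6-v2",  # illustrative model name
    embedding_dims=384,                              # that model's output size
    model_kwargs={"device": "cpu"},                  # forwarded to the huggingface embedding model
)

# Ollama-style config using the new ollama_base_url field
ollama_cfg = BaseEmbedderConfig(
    model="mxbai-embed-large",
    ollama_base_url="http://localhost:11434",  # assumed default Ollama address
)
```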
@@ -10,6 +10,7 @@ class BaseLlmConfig(ABC):
         self,
         model: Optional[str] = None,
         temperature: float = 0,
+        api_key: Optional[str] = None,
         max_tokens: int = 3000,
         top_p: float = 0,
         top_k: int = 1,
@@ -32,6 +33,8 @@ class BaseLlmConfig(ABC):
         :param temperature: Controls the randomness of the model's output.
             Higher values (closer to 1) make output more random, lower values make it more deterministic, defaults to 0
         :type temperature: float, optional
+        :param api_key: OpenAI API key to be used, defaults to None
+        :type api_key: Optional[str], optional
         :param max_tokens: Controls how many tokens are generated, defaults to 3000
         :type max_tokens: int, optional
         :param top_p: Controls the diversity of words. Higher values (closer to 1) make word selection more diverse,
@@ -39,15 +42,15 @@ class BaseLlmConfig(ABC):
         :type top_p: float, optional
         :param top_k: Controls the diversity of words. Higher values make word selection more diverse, defaults to 1
         :type top_k: int, optional
-        :param models: Controls the Openrouter models used, defaults to None
+        :param models: Openrouter models to use, defaults to None
         :type models: Optional[list[str]], optional
-        :param route: Controls the Openrouter route used, defaults to "fallback"
+        :param route: Openrouter route to be used, defaults to "fallback"
         :type route: Optional[str], optional
-        :param openrouter_base_url: Controls the Openrouter base URL used, defaults to "https://openrouter.ai/api/v1"
+        :param openrouter_base_url: Openrouter base URL to be used, defaults to "https://openrouter.ai/api/v1"
         :type openrouter_base_url: Optional[str], optional
-        :param site_url: Controls the Openrouter site URL used, defaults to None
+        :param site_url: Openrouter site URL to use, defaults to None
         :type site_url: Optional[str], optional
-        :param app_name: Controls the Openrouter app name used, defaults to None
+        :param app_name: Openrouter app name to use, defaults to None
         :type app_name: Optional[str], optional
         :param ollama_base_url: The base URL of the LLM, defaults to None
         :type ollama_base_url: Optional[str], optional
@@ -55,6 +58,7 @@ class BaseLlmConfig(ABC):

         self.model = model
         self.temperature = temperature
+        self.api_key = api_key
         self.max_tokens = max_tokens
         self.top_p = top_p
         self.top_k = top_k
@@ -1,4 +1,4 @@
-from typing import Optional,ClassVar
+from typing import Optional, ClassVar, Dict, Any

 from pydantic import BaseModel, Field, model_validator
@@ -21,6 +21,17 @@ class ChromaDbConfig(BaseModel):
         if not path and not (host and port):
             raise ValueError("Either 'host' and 'port' or 'path' must be provided.")
         return values

+    @model_validator(mode="before")
+    @classmethod
+    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        allowed_fields = set(cls.model_fields.keys())
+        input_fields = set(values.keys())
+        extra_fields = input_fields - allowed_fields
+        if extra_fields:
+            raise ValueError(f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}")
+        return values
+
-    class Config:
-        arbitrary_types_allowed = True
+    model_config = {
+        "arbitrary_types_allowed": True,
+    }
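
The new `validate_extra_fields` hook runs before field parsing and compares input keys against `model_fields`, so unknown keys fail immediately instead of being silently ignored. A standalone sketch of the same pattern, independent of mem0, just to show the mechanics:

```python
from typing import Any, Dict, Optional

from pydantic import BaseModel, model_validator


class StrictConfig(BaseModel):
    host: Optional[str] = None
    port: Optional[int] = None

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        # Reject any key that is not a declared model field
        extra = set(values) - set(cls.model_fields)
        if extra:
            raise ValueError(f"Extra fields not allowed: {', '.join(extra)}")
        return values


StrictConfig(host="localhost", port=8000)  # ok
StrictConfig(host="localhost", prot=8000)  # pydantic raises a ValidationError naming 'prot'
```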
@@ -1,6 +1,5 @@
-from typing import Optional,ClassVar
-
 from pydantic import BaseModel, Field, model_validator
+from typing import Optional, ClassVar, Dict, Any

 class QdrantConfig(BaseModel):
     from qdrant_client import QdrantClient
@@ -14,10 +13,11 @@ class QdrantConfig(BaseModel):
     path: Optional[str] = Field("/tmp/qdrant", description="Path for local Qdrant database")
     url: Optional[str] = Field(None, description="Full URL for Qdrant server")
     api_key: Optional[str] = Field(None, description="API key for Qdrant server")
-    on_disk: Optional[bool]= Field(False, description="Enables persistant storage")
+    on_disk: Optional[bool] = Field(False, description="Enables persistent storage")

     @model_validator(mode="before")
-    def check_host_port_or_path(cls, values):
+    @classmethod
+    def check_host_port_or_path(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         host, port, path, url, api_key = (
             values.get("host"),
             values.get("port"),
@@ -31,5 +31,16 @@ class QdrantConfig(BaseModel):
         )
         return values

-    class Config:
-        arbitrary_types_allowed = True
+    @model_validator(mode="before")
+    @classmethod
+    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        allowed_fields = set(cls.model_fields.keys())
+        input_fields = set(values.keys())
+        extra_fields = input_fields - allowed_fields
+        if extra_fields:
+            raise ValueError(f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}")
+        return values
+
+    model_config = {
+        "arbitrary_types_allowed": True,
+    }
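
With the corrected `on_disk` flag and the strict validation in place, building a local Qdrant config might look like this sketch (the import path is an assumption):

```python
from mem0.configs.vector_stores.qdrant import QdrantConfig  # assumed import path

# Local mode: no host/port needed, path defaults to /tmp/qdrant
cfg = QdrantConfig(path="/tmp/qdrant", on_disk=True)  # persist vectors to disk

# Misspelled field names now fail fast instead of being silently dropped
try:
    QdrantConfig(path="/tmp/qdrant", on_disc=True)  # typo in field name
except ValueError as e:
    print(e)  # message lists the offending field and the allowed ones
```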
@@ -18,7 +18,7 @@ class OllamaEmbedding(EmbeddingBase):
         if not self.config.embedding_dims:
             self.config.embedding_dims=512

-        self.client = Client(host=self.config.base_url)
+        self.client = Client(host=self.config.ollama_base_url)
         self._ensure_model_exists()

     def _ensure_model_exists(self):
@@ -1,3 +1,4 @@
+import os
 from typing import Optional

 from openai import OpenAI
@@ -9,12 +10,11 @@ class OpenAIEmbedding(EmbeddingBase):
     def __init__(self, config: Optional[BaseEmbedderConfig] = None):
         super().__init__(config)

-        if not self.config.model:
-            self.config.model="text-embedding-3-small"
-        if not self.config.embedding_dims:
-            self.config.embedding_dims=1536
+        self.config.model = self.config.model or "text-embedding-3-small"
+        self.config.embedding_dims = self.config.embedding_dims or 1536

-        self.client = OpenAI()
+        api_key = os.getenv("OPENAI_API_KEY") or self.config.api_key
+        self.client = OpenAI(api_key=api_key)

     def embed(self, text):
         """
@@ -14,10 +14,11 @@ class OpenAILLM(LLMBase):
         if not self.config.model:
             self.config.model="gpt-4o"

-        if os.environ.get("OPENROUTER_API_KEY"):
+        if os.environ.get("OPENROUTER_API_KEY"):  # Use OpenRouter
             self.client = OpenAI(api_key=os.environ.get("OPENROUTER_API_KEY"), base_url=self.config.openrouter_base_url)
         else:
-            self.client = OpenAI()
+            api_key = os.getenv("OPENAI_API_KEY") or self.config.api_key
+            self.client = OpenAI(api_key=api_key)

     def _parse_response(self, response, tools):
         """
@@ -374,6 +374,7 @@ class Memory(MemoryBase):

         new_metadata = metadata or {}
         new_metadata["data"] = data
         new_metadata["hash"] = existing_memory.payload.get("hash")
         new_metadata["created_at"] = existing_memory.payload.get("created_at")
         new_metadata["updated_at"] = datetime.now(pytz.timezone('US/Pacific')).isoformat()
@@ -3,7 +3,7 @@ from pydantic import BaseModel, Field, model_validator

 class VectorStoreConfig(BaseModel):
     provider: str = Field(
-        description="Provider of the vector store (e.g., 'qdrant', 'chromadb')",
+        description="Provider of the vector store (e.g., 'qdrant', 'chroma')",
         default="qdrant",
     )
     config: Optional[Dict] = Field(
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "mem0ai"
-version = "0.0.15"
+version = "0.0.16"
 description = "Long-term memory for AI Agents"
 authors = ["Mem0 <founders@mem0.ai>"]
 exclude = [