[Redis]: Vector database added. (#2032)

44  docs/components/vectordbs/dbs/redis.mdx  Normal file
@@ -0,0 +1,44 @@
[Redis](https://redis.io/) is a scalable, real-time database that can store, search, and analyze vector data.

### Installation

```bash
pip install redis redisvl
```

Run Redis Stack using Docker:

```bash
docker run -d --name redis-stack -p 6379:6379 -p 8001:8001 redis/redis-stack:latest
```

### Usage

```python
import os
from mem0 import Memory

os.environ["OPENAI_API_KEY"] = "sk-xx"

config = {
    "vector_store": {
        "provider": "redis",
        "config": {
            "collection_name": "mem0",
            "embedding_model_dims": 1536,
            "redis_url": "redis://localhost:6379"
        }
    }
}

m = Memory.from_config(config)
m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"})
```
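The snippet above only writes a memory. As a quick sanity check, the same `Memory` instance can read it back through mem0's standard retrieval API; a minimal sketch, assuming the config above and a running Redis Stack (this snippet is not part of the diff):

```python
# Retrieve memories stored in the Redis-backed vector store above.
related = m.search("What does alice do on weekends?", user_id="alice")  # semantic search
all_memories = m.get_all(user_id="alice")  # list everything for this user
print(related)
print(all_memories)
```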

### Config

Let's see the available parameters for the `redis` config:

| Parameter | Description | Default Value |
| --- | --- | --- |
| `collection_name` | The name of the collection to store the vectors | `mem0` |
| `embedding_model_dims` | Dimensions of the embedding model | `1536` |
| `redis_url` | The URL of the Redis server | `None` |
```diff
@@ -14,6 +14,7 @@ See the list of supported vector databases below.
   <Card title="Pgvector" href="/components/vectordbs/dbs/pgvector"></Card>
   <Card title="Milvus" href="/components/vectordbs/dbs/milvus"></Card>
   <Card title="Azure AI Search" href="/components/vectordbs/dbs/azure_ai_search"></Card>
+  <Card title="Redis" href="/components/vectordbs/dbs/redis"></Card>
 </CardGroup>

 ## Usage
```
```diff
@@ -111,7 +111,8 @@
         "components/vectordbs/dbs/chroma",
         "components/vectordbs/dbs/pgvector",
         "components/vectordbs/dbs/milvus",
-        "components/vectordbs/dbs/azure_ai_search"
+        "components/vectordbs/dbs/azure_ai_search",
+        "components/vectordbs/dbs/redis"
       ]
     }
   ]
```
26  mem0/configs/vector_stores/redis.py  Normal file
@@ -0,0 +1,26 @@

```python
from typing import Any, Dict

from pydantic import BaseModel, Field, model_validator


# TODO: Upgrade to latest pydantic version
class RedisDBConfig(BaseModel):
    redis_url: str = Field(..., description="Redis URL")
    collection_name: str = Field("mem0", description="Collection name")
    embedding_model_dims: int = Field(1536, description="Embedding model dimensions")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values

    model_config = {
        "arbitrary_types_allowed": True,
    }
```
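The `mode="before"` validator above rejects unknown keys before the model is parsed. A hedged sketch of the resulting behavior (field names and the error text come from the class above; the `port` key is an arbitrary example of an unsupported field):

```python
# Illustration only; not part of this diff.
from mem0.configs.vector_stores.redis import RedisDBConfig

ok = RedisDBConfig(redis_url="redis://localhost:6379")  # defaults: collection "mem0", 1536 dims

try:
    RedisDBConfig(redis_url="redis://localhost:6379", port=6379)  # unknown key
except ValueError as e:
    # Pydantic wraps the validator's ValueError in a ValidationError (a ValueError subclass),
    # so the "Extra fields not allowed: port. ..." message surfaces here.
    print(e)
```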
```diff
@@ -65,6 +65,7 @@ class VectorStoreFactory:
         "pgvector": "mem0.vector_stores.pgvector.PGVector",
         "milvus": "mem0.vector_stores.milvus.MilvusDB",
         "azure_ai_search": "mem0.vector_stores.azure_ai_search.AzureAISearch",
+        "redis": "mem0.vector_stores.redis.RedisDB",
     }

     @classmethod
```
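The factory stores each provider's class as a dotted import path, so the Redis backend is only imported when it is actually requested. The factory's own resolution helper is not shown in this diff; as an assumption, a minimal sketch of the usual importlib pattern for resolving such a string:

```python
import importlib


def load_class(dotted_path: str):
    # Split "mem0.vector_stores.redis.RedisDB" into module path and class name,
    # import the module lazily, then fetch the class attribute.
    module_path, _, class_name = dotted_path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


RedisDB = load_class("mem0.vector_stores.redis.RedisDB")
```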
```diff
@@ -16,6 +16,7 @@ class VectorStoreConfig(BaseModel):
         "pgvector": "PGVectorConfig",
         "milvus": "MilvusDBConfig",
         "azure_ai_search": "AzureAISearchConfig",
+        "redis": "RedisDBConfig",
     }

     @model_validator(mode="after")
```
236  mem0/vector_stores/redis.py  Normal file
@@ -0,0 +1,236 @@

```python
import json
import logging
from datetime import datetime
from functools import reduce

import numpy as np
import pytz
import redis
from redis.commands.search.query import Query
from redisvl.index import SearchIndex
from redisvl.query import VectorQuery
from redisvl.query.filter import Tag

from mem0.vector_stores.base import VectorStoreBase

logger = logging.getLogger(__name__)

# TODO: Improve these fields; they may not be ideal from Redis's perspective. Might do away with them.
DEFAULT_FIELDS = [
    {"name": "memory_id", "type": "tag"},
    {"name": "hash", "type": "tag"},
    {"name": "agent_id", "type": "tag"},
    {"name": "run_id", "type": "tag"},
    {"name": "user_id", "type": "tag"},
    {"name": "memory", "type": "text"},
    {"name": "metadata", "type": "text"},
    # TODO: Although this field is numeric, it also accepts strings.
    {"name": "created_at", "type": "numeric"},
    {"name": "updated_at", "type": "numeric"},
    {
        "name": "embedding",
        "type": "vector",
        "attrs": {"distance_metric": "cosine", "algorithm": "flat", "datatype": "float32"},
    },
]

excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}


class MemoryResult:
    def __init__(self, id: str, payload: dict, score: float = None):
        self.id = id
        self.payload = payload
        self.score = score


class RedisDB(VectorStoreBase):
    def __init__(
        self,
        redis_url: str,
        collection_name: str,
        embedding_model_dims: int,
    ):
        """
        Initialize the Redis vector store.

        Args:
            redis_url (str): Redis URL.
            collection_name (str): Collection name.
            embedding_model_dims (int): Embedding model dimensions.
        """
        index_schema = {
            "name": collection_name,
            "prefix": f"mem0:{collection_name}",
        }

        fields = DEFAULT_FIELDS.copy()
        fields[-1]["attrs"]["dims"] = embedding_model_dims

        self.schema = {"index": index_schema, "fields": fields}

        self.client = redis.Redis.from_url(redis_url)
        self.index = SearchIndex.from_dict(self.schema)
        self.index.set_client(self.client)
        self.index.create(overwrite=True)

    # TODO: Implement multi-index support.
    def create_col(self, name, vector_size, distance):
        raise NotImplementedError("Collection/Index creation not supported yet.")

    def insert(self, vectors: list, payloads: list = None, ids: list = None):
        data = []
        for vector, payload, id in zip(vectors, payloads, ids):
            # Start with required fields
            entry = {
                "memory_id": id,
                "hash": payload["hash"],
                "memory": payload["data"],
                "created_at": int(datetime.fromisoformat(payload["created_at"]).timestamp()),
                "embedding": np.array(vector, dtype=np.float32).tobytes(),
            }

            # Conditionally add optional fields
            for field in ["agent_id", "run_id", "user_id"]:
                if field in payload:
                    entry[field] = payload[field]

            # Add metadata excluding specific keys
            entry["metadata"] = json.dumps({k: v for k, v in payload.items() if k not in excluded_keys})

            data.append(entry)
        self.index.load(data, id_field="memory_id")

    def search(self, query: list, limit: int = 5, filters: dict = None):
        conditions = [Tag(key) == value for key, value in filters.items() if value is not None]
        filter = reduce(lambda x, y: x & y, conditions)

        v = VectorQuery(
            vector=np.array(query, dtype=np.float32).tobytes(),
            vector_field_name="embedding",
            return_fields=["memory_id", "hash", "agent_id", "run_id", "user_id", "memory", "metadata", "created_at"],
            filter_expression=filter,
            num_results=limit,
        )

        results = self.index.query(v)

        return [
            MemoryResult(
                id=result["memory_id"],
                score=result["vector_distance"],
                payload={
                    "hash": result["hash"],
                    "data": result["memory"],
                    "created_at": datetime.fromtimestamp(
                        int(result["created_at"]), tz=pytz.timezone("US/Pacific")
                    ).isoformat(timespec="microseconds"),
                    **(
                        {
                            "updated_at": datetime.fromtimestamp(
                                int(result["updated_at"]), tz=pytz.timezone("US/Pacific")
                            ).isoformat(timespec="microseconds")
                        }
                        if "updated_at" in result
                        else {}
                    ),
                    **{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result},
                    **{k: v for k, v in json.loads(result["metadata"]).items()},
                },
            )
            for result in results
        ]

    def delete(self, vector_id):
        self.index.drop_keys(f"{self.schema['index']['prefix']}:{vector_id}")

    def update(self, vector_id=None, vector=None, payload=None):
        data = {
            "memory_id": vector_id,
            "hash": payload["hash"],
            "memory": payload["data"],
            "created_at": int(datetime.fromisoformat(payload["created_at"]).timestamp()),
            "updated_at": int(datetime.fromisoformat(payload["updated_at"]).timestamp()),
            "embedding": np.array(vector, dtype=np.float32).tobytes(),
        }

        for field in ["agent_id", "run_id", "user_id"]:
            if field in payload:
                data[field] = payload[field]

        data["metadata"] = json.dumps({k: v for k, v in payload.items() if k not in excluded_keys})
        self.index.load(data=[data], keys=[f"{self.schema['index']['prefix']}:{vector_id}"], id_field="memory_id")

    def get(self, vector_id):
        result = self.index.fetch(vector_id)
        payload = {
            "hash": result["hash"],
            "data": result["memory"],
            "created_at": datetime.fromtimestamp(int(result["created_at"]), tz=pytz.timezone("US/Pacific")).isoformat(
                timespec="microseconds"
            ),
            **(
                {
                    "updated_at": datetime.fromtimestamp(
                        int(result["updated_at"]), tz=pytz.timezone("US/Pacific")
                    ).isoformat(timespec="microseconds")
                }
                if "updated_at" in result
                else {}
            ),
            **{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result},
            **{k: v for k, v in json.loads(result["metadata"]).items()},
        }

        return MemoryResult(id=result["memory_id"], payload=payload)

    def list_cols(self):
        return self.index.listall()

    def delete_col(self):
        self.index.delete()

    def col_info(self, name):
        return self.index.info()

    def list(self, filters: dict = None, limit: int = None) -> list:
        """
        List the most recently created memories in the vector store.
        """
        conditions = [Tag(key) == value for key, value in filters.items() if value is not None]
        filter = reduce(lambda x, y: x & y, conditions)
        query = Query(str(filter)).sort_by("created_at", asc=False)
        if limit is not None:
            query = Query(str(filter)).sort_by("created_at", asc=False).paging(0, limit)

        results = self.index.search(query)
        return [
            [
                MemoryResult(
                    id=result["memory_id"],
                    payload={
                        "hash": result["hash"],
                        "data": result["memory"],
                        "created_at": datetime.fromtimestamp(
                            int(result["created_at"]), tz=pytz.timezone("US/Pacific")
                        ).isoformat(timespec="microseconds"),
                        **(
                            {
                                "updated_at": datetime.fromtimestamp(
                                    int(result["updated_at"]), tz=pytz.timezone("US/Pacific")
                                ).isoformat(timespec="microseconds")
                            }
                            if result.__dict__.get("updated_at")
                            else {}
                        ),
                        **{
                            field: result[field]
                            for field in ["agent_id", "run_id", "user_id"]
                            if field in result.__dict__
                        },
                        **{k: v for k, v in json.loads(result["metadata"]).items()},
                    },
                )
                for result in results.docs
            ]
        ]
```
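Not part of the commit, but a hedged smoke-test sketch of `RedisDB` used directly, exercising only methods defined above against a local Redis Stack. The payload shape mirrors what `insert` expects; the hash, ids, and 4-dimensional vectors are dummy values:

```python
# Hypothetical smoke test; assumes redis-stack is running on localhost:6379.
from datetime import datetime, timezone

from mem0.vector_stores.redis import RedisDB

store = RedisDB(redis_url="redis://localhost:6379", collection_name="mem0", embedding_model_dims=4)

store.insert(
    vectors=[[0.1, 0.2, 0.3, 0.4]],
    payloads=[{
        "hash": "abc123",  # dummy content hash
        "data": "Likes to play cricket on weekends",
        "created_at": datetime.now(timezone.utc).isoformat(),
        "user_id": "alice",
    }],
    ids=["mem-1"],
)

# search() requires at least one filter, since it reduces the conditions with "&".
hits = store.search(query=[0.1, 0.2, 0.3, 0.4], limit=1, filters={"user_id": "alice"})
print(hits[0].id, hits[0].payload["data"])
```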