[Redis]: Vector database added. (#2032)

docs/components/vectordbs/dbs/redis.mdx (new file, 44 lines)
@@ -0,0 +1,44 @@
[Redis](https://redis.io/) is a scalable, real-time database that can store, search, and analyze vector data.

### Installation

```bash
pip install redis redisvl
```

Run Redis Stack using Docker:

```bash
docker run -d --name redis-stack -p 6379:6379 -p 8001:8001 redis/redis-stack:latest
```

### Usage

```python
import os
from mem0 import Memory

os.environ["OPENAI_API_KEY"] = "sk-xx"

config = {
    "vector_store": {
        "provider": "redis",
        "config": {
            "collection_name": "mem0",
            "embedding_model_dims": 1536,
            "redis_url": "redis://localhost:6379"
        }
    }
}

m = Memory.from_config(config)
m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"})
```

### Config

Let's see the available parameters for the `redis` config:

| Parameter | Description | Default Value |
| --- | --- | --- |
| `collection_name` | The name of the collection to store the vectors | `mem0` |
| `embedding_model_dims` | Dimensions of the embedding model | `1536` |
| `redis_url` | The URL of the Redis server | `None` |
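If the server requires authentication, the credentials go into the URL itself, following the standard Redis URL form `redis://[username]:[password]@host:port/db`. A minimal variant of the config above, assuming a password-protected instance (the password is a placeholder):

```python
config = {
    "vector_store": {
        "provider": "redis",
        "config": {
            "collection_name": "mem0",
            "embedding_model_dims": 1536,
            # Credentials embedded in the standard Redis URL syntax
            "redis_url": "redis://:my-secret-password@localhost:6379",
        }
    }
}
```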

@@ -14,6 +14,7 @@ See the list of supported vector databases below.
   <Card title="Pgvector" href="/components/vectordbs/dbs/pgvector"></Card>
   <Card title="Milvus" href="/components/vectordbs/dbs/milvus"></Card>
   <Card title="Azure AI Search" href="/components/vectordbs/dbs/azure_ai_search"></Card>
+  <Card title="Redis" href="/components/vectordbs/dbs/redis"></Card>
 </CardGroup>

 ## Usage

@@ -111,7 +111,8 @@
         "components/vectordbs/dbs/chroma",
         "components/vectordbs/dbs/pgvector",
         "components/vectordbs/dbs/milvus",
-        "components/vectordbs/dbs/azure_ai_search"
+        "components/vectordbs/dbs/azure_ai_search",
+        "components/vectordbs/dbs/redis"
       ]
     }
   ]

mem0/configs/vector_stores/redis.py (new file, 26 lines)
@@ -0,0 +1,26 @@
from typing import Any, Dict

from pydantic import BaseModel, Field, model_validator


# TODO: Upgrade to latest pydantic version
class RedisDBConfig(BaseModel):
    redis_url: str = Field(..., description="Redis URL")
    collection_name: str = Field("mem0", description="Collection name")
    embedding_model_dims: int = Field(1536, description="Embedding model dimensions")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values

    model_config = {
        "arbitrary_types_allowed": True,
    }
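Because the validator runs in `mode="before"`, misspelled or unsupported keys fail fast instead of being silently dropped. A quick sketch of the expected behavior (the `pool_size` key is made up for illustration):

```python
from pydantic import ValidationError

from mem0.configs.vector_stores.redis import RedisDBConfig

cfg = RedisDBConfig(redis_url="redis://localhost:6379")
print(cfg.collection_name, cfg.embedding_model_dims)  # mem0 1536

try:
    # pool_size is not a declared field, so the before-validator rejects it
    RedisDBConfig(redis_url="redis://localhost:6379", pool_size=10)
except ValidationError as exc:
    print(exc)  # "Extra fields not allowed: pool_size. ..."
```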
@@ -65,6 +65,7 @@ class VectorStoreFactory:
         "pgvector": "mem0.vector_stores.pgvector.PGVector",
         "milvus": "mem0.vector_stores.milvus.MilvusDB",
         "azure_ai_search": "mem0.vector_stores.azure_ai_search.AzureAISearch",
+        "redis": "mem0.vector_stores.redis.RedisDB",
     }

     @classmethod
@@ -16,6 +16,7 @@ class VectorStoreConfig(BaseModel):
         "pgvector": "PGVectorConfig",
         "milvus": "MilvusDBConfig",
         "azure_ai_search": "AzureAISearchConfig",
+        "redis": "RedisDBConfig",
     }

     @model_validator(mode="after")
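Together, these two map entries wire the `"provider": "redis"` key end to end: the config map selects `RedisDBConfig` for validation, and the factory resolves the dotted path to the store class. A rough sketch of that lookup, assuming the factory imports lazily from the dotted path (the helper below is illustrative, not the actual factory code):

```python
import importlib

# Mirrors the provider map entry registered in the hunk above.
provider_to_class = {"redis": "mem0.vector_stores.redis.RedisDB"}

def load_vector_store_class(provider: str):
    # Split "mem0.vector_stores.redis.RedisDB" into module path and class name.
    module_path, class_name = provider_to_class[provider].rsplit(".", 1)
    return getattr(importlib.import_module(module_path), class_name)

RedisDB = load_vector_store_class("redis")
```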

mem0/vector_stores/redis.py (new file, 236 lines)
@@ -0,0 +1,236 @@
import json
import logging
from datetime import datetime
from functools import reduce

import numpy as np
import pytz
import redis
from redis.commands.search.query import Query
from redisvl.index import SearchIndex
from redisvl.query import VectorQuery
from redisvl.query.filter import Tag

from mem0.vector_stores.base import VectorStoreBase

logger = logging.getLogger(__name__)

# TODO: Improve; these are not the best fields from Redis's perspective. Might do away with them.
DEFAULT_FIELDS = [
    {"name": "memory_id", "type": "tag"},
    {"name": "hash", "type": "tag"},
    {"name": "agent_id", "type": "tag"},
    {"name": "run_id", "type": "tag"},
    {"name": "user_id", "type": "tag"},
    {"name": "memory", "type": "text"},
    {"name": "metadata", "type": "text"},
    # TODO: The field is numeric but also accepts strings.
    {"name": "created_at", "type": "numeric"},
    {"name": "updated_at", "type": "numeric"},
    {
        "name": "embedding",
        "type": "vector",
        "attrs": {"distance_metric": "cosine", "algorithm": "flat", "datatype": "float32"},
    },
]

# Payload keys stored as dedicated index fields rather than inside the JSON metadata blob.
excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}


class MemoryResult:
    def __init__(self, id: str, payload: dict, score: float = None):
        self.id = id
        self.payload = payload
        self.score = score

class RedisDB(VectorStoreBase):
    def __init__(
        self,
        redis_url: str,
        collection_name: str,
        embedding_model_dims: int,
    ):
        """
        Initialize the Redis vector store.

        Args:
            redis_url (str): Redis URL.
            collection_name (str): Collection name.
            embedding_model_dims (int): Embedding model dimensions.
        """
        index_schema = {
            "name": collection_name,
            "prefix": f"mem0:{collection_name}",
        }

        # NOTE: .copy() is shallow, so writing "dims" below also mutates the
        # shared module-level DEFAULT_FIELDS entry.
        fields = DEFAULT_FIELDS.copy()
        fields[-1]["attrs"]["dims"] = embedding_model_dims

        self.schema = {"index": index_schema, "fields": fields}

        self.client = redis.Redis.from_url(redis_url)
        self.index = SearchIndex.from_dict(self.schema)
        self.index.set_client(self.client)
        self.index.create(overwrite=True)

    # TODO: Implement multi-index support.
    def create_col(self, name, vector_size, distance):
        raise NotImplementedError("Collection/Index creation not supported yet.")

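For reference, the store can also be constructed directly with the same three parameters the config exposes; note that `__init__` calls `index.create(overwrite=True)`, so the index definition is recreated each time the class is instantiated. A minimal sketch, assuming a local Redis Stack instance:

```python
store = RedisDB(
    redis_url="redis://localhost:6379",
    collection_name="mem0",
    embedding_model_dims=1536,
)
```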
    def insert(self, vectors: list, payloads: list = None, ids: list = None):
        data = []
        for vector, payload, id in zip(vectors, payloads, ids):
            # Start with required fields
            entry = {
                "memory_id": id,
                "hash": payload["hash"],
                "memory": payload["data"],
                "created_at": int(datetime.fromisoformat(payload["created_at"]).timestamp()),
                "embedding": np.array(vector, dtype=np.float32).tobytes(),
            }

            # Conditionally add optional fields
            for field in ["agent_id", "run_id", "user_id"]:
                if field in payload:
                    entry[field] = payload[field]

            # Add metadata excluding specific keys
            entry["metadata"] = json.dumps({k: v for k, v in payload.items() if k not in excluded_keys})

            data.append(entry)
        self.index.load(data, id_field="memory_id")

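The payload contract here is implicit: `hash`, `data`, and an ISO-format `created_at` are required, the three id keys are optional tag fields, and every remaining key lands in the JSON `metadata` blob. An illustrative call against the `store` instance sketched earlier (all values made up):

```python
store.insert(
    vectors=[[0.1] * 1536],
    ids=["3f8b2c5e-0000-0000-0000-000000000000"],
    payloads=[{
        "hash": "9c3d1a7f",                           # required
        "data": "Likes to play cricket on weekends",  # required; stored as "memory"
        "created_at": "2024-01-01T00:00:00-08:00",    # required; ISO format
        "user_id": "alice",                           # optional tag field
        "category": "hobbies",                        # everything else -> metadata JSON
    }],
)
```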
    def search(self, query: list, limit: int = 5, filters: dict = None):
        # NOTE: assumes a non-empty filters dict; reduce() raises on an empty
        # condition list, and filters=None raises on .items().
        conditions = [Tag(key) == value for key, value in filters.items() if value is not None]
        filter = reduce(lambda x, y: x & y, conditions)

        v = VectorQuery(
            vector=np.array(query, dtype=np.float32).tobytes(),
            vector_field_name="embedding",
            return_fields=["memory_id", "hash", "agent_id", "run_id", "user_id", "memory", "metadata", "created_at"],
            filter_expression=filter,
            num_results=limit,
        )

        results = self.index.query(v)

        return [
            MemoryResult(
                id=result["memory_id"],
                score=result["vector_distance"],
                payload={
                    "hash": result["hash"],
                    "data": result["memory"],
                    "created_at": datetime.fromtimestamp(
                        int(result["created_at"]), tz=pytz.timezone("US/Pacific")
                    ).isoformat(timespec="microseconds"),
                    **(
                        {
                            "updated_at": datetime.fromtimestamp(
                                int(result["updated_at"]), tz=pytz.timezone("US/Pacific")
                            ).isoformat(timespec="microseconds")
                        }
                        if "updated_at" in result
                        else {}
                    ),
                    **{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result},
                    **{k: v for k, v in json.loads(result["metadata"]).items()},
                },
            )
            for result in results
        ]

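Scores come back as vector distances, and with the cosine metric configured in the schema above, smaller means closer. A usage sketch with an assumed 1536-dim query embedding:

```python
hits = store.search(query=[0.1] * 1536, limit=5, filters={"user_id": "alice"})
for hit in hits:
    # hit.score is the cosine vector distance; lower is a better match
    print(hit.id, hit.score, hit.payload["data"])
```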
    def delete(self, vector_id):
        self.index.drop_keys(f"{self.schema['index']['prefix']}:{vector_id}")

    def update(self, vector_id=None, vector=None, payload=None):
        data = {
            "memory_id": vector_id,
            "hash": payload["hash"],
            "memory": payload["data"],
            "created_at": int(datetime.fromisoformat(payload["created_at"]).timestamp()),
            "updated_at": int(datetime.fromisoformat(payload["updated_at"]).timestamp()),
            "embedding": np.array(vector, dtype=np.float32).tobytes(),
        }

        for field in ["agent_id", "run_id", "user_id"]:
            if field in payload:
                data[field] = payload[field]

        data["metadata"] = json.dumps({k: v for k, v in payload.items() if k not in excluded_keys})
        self.index.load(data=[data], keys=[f"{self.schema['index']['prefix']}:{vector_id}"], id_field="memory_id")

    def get(self, vector_id):
        result = self.index.fetch(vector_id)
        payload = {
            "hash": result["hash"],
            "data": result["memory"],
            "created_at": datetime.fromtimestamp(int(result["created_at"]), tz=pytz.timezone("US/Pacific")).isoformat(
                timespec="microseconds"
            ),
            **(
                {
                    "updated_at": datetime.fromtimestamp(
                        int(result["updated_at"]), tz=pytz.timezone("US/Pacific")
                    ).isoformat(timespec="microseconds")
                }
                if "updated_at" in result
                else {}
            ),
            **{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result},
            **{k: v for k, v in json.loads(result["metadata"]).items()},
        }

        return MemoryResult(id=result["memory_id"], payload=payload)

    def list_cols(self):
        return self.index.listall()

    def delete_col(self):
        self.index.delete()

    def col_info(self, name):
        return self.index.info()

    def list(self, filters: dict = None, limit: int = None) -> list:
        """
        List the most recently created memories from the vector store.
        """
        # NOTE: like search(), this assumes a non-empty filters dict.
        conditions = [Tag(key) == value for key, value in filters.items() if value is not None]
        filter = reduce(lambda x, y: x & y, conditions)
        query = Query(str(filter)).sort_by("created_at", asc=False)
        if limit is not None:
            query = query.paging(0, limit)

        results = self.index.search(query)
        # NOTE: the results are wrapped in an outer list, so callers index [0].
        return [
            [
                MemoryResult(
                    id=result["memory_id"],
                    payload={
                        "hash": result["hash"],
                        "data": result["memory"],
                        "created_at": datetime.fromtimestamp(
                            int(result["created_at"]), tz=pytz.timezone("US/Pacific")
                        ).isoformat(timespec="microseconds"),
                        **(
                            {
                                "updated_at": datetime.fromtimestamp(
                                    int(result["updated_at"]), tz=pytz.timezone("US/Pacific")
                                ).isoformat(timespec="microseconds")
                            }
                            if result.__dict__.get("updated_at")
                            else {}
                        ),
                        **{
                            field: result[field]
                            for field in ["agent_id", "run_id", "user_id"]
                            if field in result.__dict__
                        },
                        **{k: v for k, v in json.loads(result["metadata"]).items()},
                    },
                )
                for result in results.docs
            ]
        ]
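Because of the outer-list wrapping noted above, a caller unwraps the first element to reach the `MemoryResult` objects; a usage sketch with the assumed `store` instance:

```python
batches = store.list(filters={"user_id": "alice"}, limit=100)
for mem in batches[0]:  # list() returns [[MemoryResult, ...]]
    print(mem.id, mem.payload["data"])
```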