[Redis]: Vector database added. (#2032)

Author: Mayank
Date: 2024-11-20 17:12:16 +05:30
Committed by: GitHub
Parent: 13374a12e9
Commit: 5ab09ffd5a
7 changed files with 311 additions and 1 deletion


@@ -0,0 +1,26 @@
from typing import Any, Dict

from pydantic import BaseModel, Field, model_validator


# TODO: Upgrade to latest pydantic version
class RedisDBConfig(BaseModel):
    redis_url: str = Field(..., description="Redis URL")
    collection_name: str = Field("mem0", description="Collection name")
    embedding_model_dims: int = Field(1536, description="Embedding model dimensions")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. "
                f"Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values

    model_config = {
        "arbitrary_types_allowed": True,
    }
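
For reference, a minimal sketch of how this config behaves, assuming RedisDBConfig is imported from wherever this file lands (the diff does not show its module path); values are illustrative:

config = RedisDBConfig(redis_url="redis://localhost:6379")
print(config.collection_name)        # "mem0" (default)
print(config.embedding_model_dims)   # 1536 (default)

# Unknown keys are rejected by validate_extra_fields above
# (pydantic's ValidationError subclasses ValueError):
try:
    RedisDBConfig(redis_url="redis://localhost:6379", port=6380)
except ValueError as exc:
    print(exc)  # Extra fields not allowed: port. ...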


@@ -65,6 +65,7 @@ class VectorStoreFactory:
"pgvector": "mem0.vector_stores.pgvector.PGVector",
"milvus": "mem0.vector_stores.milvus.MilvusDB",
"azure_ai_search": "mem0.vector_stores.azure_ai_search.AzureAISearch",
"redis": "mem0.vector_stores.redis.RedisDB",
}
@classmethod


@@ -16,6 +16,7 @@ class VectorStoreConfig(BaseModel):
"pgvector": "PGVectorConfig",
"milvus": "MilvusDBConfig",
"azure_ai_search": "AzureAISearchConfig",
"redis": "RedisDBConfig",
}
@model_validator(mode="after")
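
With both registrations in place, the backend is selectable by provider name. A minimal sketch following mem0's existing vector_store config shape (values are illustrative):

from mem0 import Memory

config = {
    "vector_store": {
        "provider": "redis",
        "config": {
            "redis_url": "redis://localhost:6379",
            "collection_name": "mem0",
            "embedding_model_dims": 1536,
        },
    },
}

m = Memory.from_config(config)
m.add("I prefer dark mode", user_id="alice")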

mem0/vector_stores/redis.py

@@ -0,0 +1,236 @@
import json
import logging
from datetime import datetime
from functools import reduce

import numpy as np
import pytz
import redis
from redis.commands.search.query import Query
from redisvl.index import SearchIndex
from redisvl.query import VectorQuery
from redisvl.query.filter import Tag

from mem0.vector_stores.base import VectorStoreBase

logger = logging.getLogger(__name__)

# TODO: Revisit these fields; they are not ideal from Redis's perspective and might be removed.
DEFAULT_FIELDS = [
    {"name": "memory_id", "type": "tag"},
    {"name": "hash", "type": "tag"},
    {"name": "agent_id", "type": "tag"},
    {"name": "run_id", "type": "tag"},
    {"name": "user_id", "type": "tag"},
    {"name": "memory", "type": "text"},
    {"name": "metadata", "type": "text"},
    # TODO: The field is numeric, but Redis also accepts strings here.
    {"name": "created_at", "type": "numeric"},
    {"name": "updated_at", "type": "numeric"},
    {
        "name": "embedding",
        "type": "vector",
        "attrs": {"distance_metric": "cosine", "algorithm": "flat", "datatype": "float32"},
    },
]

excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}


class MemoryResult:
    def __init__(self, id: str, payload: dict, score: float = None):
        self.id = id
        self.payload = payload
        self.score = score


class RedisDB(VectorStoreBase):
    def __init__(
        self,
        redis_url: str,
        collection_name: str,
        embedding_model_dims: int,
    ):
        """
        Initialize the Redis vector store.

        Args:
            redis_url (str): Redis URL.
            collection_name (str): Collection name.
            embedding_model_dims (int): Embedding model dimensions.
        """
        index_schema = {
            "name": collection_name,
            "prefix": f"mem0:{collection_name}",
        }
        fields = DEFAULT_FIELDS.copy()
        fields[-1]["attrs"]["dims"] = embedding_model_dims

        self.schema = {"index": index_schema, "fields": fields}
        self.client = redis.Redis.from_url(redis_url)
        self.index = SearchIndex.from_dict(self.schema)
        self.index.set_client(self.client)
        self.index.create(overwrite=True)

    # TODO: Implement multi-index support.
    def create_col(self, name, vector_size, distance):
        raise NotImplementedError("Collection/Index creation not supported yet.")

    def insert(self, vectors: list, payloads: list = None, ids: list = None):
        data = []
        for vector, payload, id in zip(vectors, payloads, ids):
            # Start with required fields
            entry = {
                "memory_id": id,
                "hash": payload["hash"],
                "memory": payload["data"],
                "created_at": int(datetime.fromisoformat(payload["created_at"]).timestamp()),
                "embedding": np.array(vector, dtype=np.float32).tobytes(),
            }

            # Conditionally add optional fields
            for field in ["agent_id", "run_id", "user_id"]:
                if field in payload:
                    entry[field] = payload[field]

            # Add metadata excluding specific keys
            entry["metadata"] = json.dumps({k: v for k, v in payload.items() if k not in excluded_keys})

            data.append(entry)
        self.index.load(data, id_field="memory_id")
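
    # For illustration only (not part of this commit): the payload shape
    # insert() expects, inferred from the code above. "hash", "data", and an
    # ISO-8601 "created_at" are required; agent_id/run_id/user_id are optional
    # tag fields; any remaining key is serialized into the "metadata" blob.
    #
    # payload = {
    #     "hash": "a1b2c3",                          # content hash
    #     "data": "User prefers dark mode",          # the memory text
    #     "created_at": "2024-11-20T17:12:16+05:30", # parsed via fromisoformat
    #     "user_id": "alice",                        # optional tag field
    #     "category": "preferences",                 # lands in "metadata"
    # }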
    def search(self, query: list, limit: int = 5, filters: dict = None):
        # Guard against missing/empty filters so reduce() has something to fold.
        conditions = [Tag(key) == value for key, value in (filters or {}).items() if value is not None]
        filter_expr = reduce(lambda x, y: x & y, conditions) if conditions else None

        v = VectorQuery(
            vector=np.array(query, dtype=np.float32).tobytes(),
            vector_field_name="embedding",
            return_fields=["memory_id", "hash", "agent_id", "run_id", "user_id", "memory", "metadata", "created_at"],
            filter_expression=filter_expr,
            num_results=limit,
        )

        results = self.index.query(v)
        return [
            MemoryResult(
                id=result["memory_id"],
                score=result["vector_distance"],
                payload={
                    "hash": result["hash"],
                    "data": result["memory"],
                    "created_at": datetime.fromtimestamp(
                        int(result["created_at"]), tz=pytz.timezone("US/Pacific")
                    ).isoformat(timespec="microseconds"),
                    **(
                        {
                            "updated_at": datetime.fromtimestamp(
                                int(result["updated_at"]), tz=pytz.timezone("US/Pacific")
                            ).isoformat(timespec="microseconds")
                        }
                        if "updated_at" in result
                        else {}
                    ),
                    **{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result},
                    **{k: v for k, v in json.loads(result["metadata"]).items()},
                },
            )
            for result in results
        ]

    def delete(self, vector_id):
        self.index.drop_keys(f"{self.schema['index']['prefix']}:{vector_id}")

    def update(self, vector_id=None, vector=None, payload=None):
        data = {
            "memory_id": vector_id,
            "hash": payload["hash"],
            "memory": payload["data"],
            "created_at": int(datetime.fromisoformat(payload["created_at"]).timestamp()),
            "updated_at": int(datetime.fromisoformat(payload["updated_at"]).timestamp()),
            "embedding": np.array(vector, dtype=np.float32).tobytes(),
        }
        for field in ["agent_id", "run_id", "user_id"]:
            if field in payload:
                data[field] = payload[field]
        data["metadata"] = json.dumps({k: v for k, v in payload.items() if k not in excluded_keys})
        self.index.load(data=[data], keys=[f"{self.schema['index']['prefix']}:{vector_id}"], id_field="memory_id")

    def get(self, vector_id):
        result = self.index.fetch(vector_id)
        payload = {
            "hash": result["hash"],
            "data": result["memory"],
            "created_at": datetime.fromtimestamp(int(result["created_at"]), tz=pytz.timezone("US/Pacific")).isoformat(
                timespec="microseconds"
            ),
            **(
                {
                    "updated_at": datetime.fromtimestamp(
                        int(result["updated_at"]), tz=pytz.timezone("US/Pacific")
                    ).isoformat(timespec="microseconds")
                }
                if "updated_at" in result
                else {}
            ),
            **{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result},
            **{k: v for k, v in json.loads(result["metadata"]).items()},
        }
        return MemoryResult(id=result["memory_id"], payload=payload)

    def list_cols(self):
        return self.index.listall()

    def delete_col(self):
        self.index.delete()

    def col_info(self, name):
        return self.index.info()

    def list(self, filters: dict = None, limit: int = None) -> list:
        """
        List all recently created memories from the vector store.
        """
        conditions = [Tag(key) == value for key, value in (filters or {}).items() if value is not None]
        filter_expr = reduce(lambda x, y: x & y, conditions) if conditions else None

        # Fall back to a match-all query when no filters are given.
        query = Query(str(filter_expr) if filter_expr is not None else "*").sort_by("created_at", asc=False)
        if limit is not None:
            query = query.paging(0, limit)

        results = self.index.search(query)
        return [
            [
                MemoryResult(
                    id=result["memory_id"],
                    payload={
                        "hash": result["hash"],
                        "data": result["memory"],
                        "created_at": datetime.fromtimestamp(
                            int(result["created_at"]), tz=pytz.timezone("US/Pacific")
                        ).isoformat(timespec="microseconds"),
                        **(
                            {
                                "updated_at": datetime.fromtimestamp(
                                    int(result["updated_at"]), tz=pytz.timezone("US/Pacific")
                                ).isoformat(timespec="microseconds")
                            }
                            if result.__dict__.get("updated_at")
                            else {}
                        ),
                        **{
                            field: result[field]
                            for field in ["agent_id", "run_id", "user_id"]
                            if field in result.__dict__
                        },
                        **{k: v for k, v in json.loads(result["metadata"]).items()},
                    },
                )
                for result in results.docs
            ]
        ]
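
Finally, a short end-to-end sketch of driving RedisDB directly (illustrative values, not from the commit; assumes a Redis Stack server with the search module, since the schema uses vector and tag fields):

import hashlib
from datetime import datetime, timezone

store = RedisDB(
    redis_url="redis://localhost:6379",
    collection_name="demo",
    embedding_model_dims=4,  # tiny dimension purely for illustration
)

text = "User prefers dark mode"
payload = {
    "hash": hashlib.md5(text.encode()).hexdigest(),
    "data": text,
    "created_at": datetime.now(timezone.utc).isoformat(),
    "user_id": "alice",
}
store.insert([[0.1, 0.2, 0.3, 0.4]], payloads=[payload], ids=["mem-1"])

hits = store.search([0.1, 0.2, 0.3, 0.4], limit=1, filters={"user_id": "alice"})
print(hits[0].id, hits[0].score, hits[0].payload["data"])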