Feature: milvus db integration (#1821)
This commit is contained in:
@@ -6,7 +6,7 @@ Config in mem0 is a dictionary that specifies the settings for your vector datab
|
|||||||
|
|
||||||
The config is defined as a Python dictionary with two main keys:
|
The config is defined as a Python dictionary with two main keys:
|
||||||
- `vector_store`: Specifies the vector database provider and its configuration
|
- `vector_store`: Specifies the vector database provider and its configuration
|
||||||
- `provider`: The name of the vector database (e.g., "chroma", "pgvector", "qdrant")
|
- `provider`: The name of the vector database (e.g., "chroma", "pgvector", "qdrant", "milvus")
|
||||||
- `config`: A nested dictionary containing provider-specific settings
|
- `config`: A nested dictionary containing provider-specific settings
|
||||||
|
|
||||||
## How to Use Config
|
## How to Use Config
|
||||||
|
|||||||
35
docs/components/vectordbs/dbs/milvus.mdx
Normal file
35
docs/components/vectordbs/dbs/milvus.mdx
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
[Milvus](https://milvus.io/) is an open-source vector database that suits AI applications of every size, from running a demo chatbot in a Jupyter notebook to building web-scale search that serves billions of users.
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
from mem0 import Memory

# Route mem0's vector storage to a Milvus (or Zilliz) server.
config = {
    "vector_store": {
        "provider": "milvus",
        "config": {
            "collection_name": "test",
            # Integer vector dimension (1536 is the documented default);
            # presumably this must match the embedding model in use.
            "embedding_model_dims": 1536,
            # Full URI of the server, scheme and port included.
            "url": "http://localhost:19530",
            # Placeholder token; replace it with your own Zilliz token.
            "token": "8e4b8ca8cf2c67",
        }
    }
}

m = Memory.from_config(config)
m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Config
|
||||||
|
|
||||||
|
Here are the parameters available for configuring Milvus:
|
||||||
|
|
||||||
|
| Parameter | Description | Default Value |
|
||||||
|
| --- | --- | --- |
|
||||||
|
| `url` | Full URL/URI of the Milvus/Zilliz server | `http://localhost:19530` |
|
||||||
|
| `token` | Token for the Zilliz server; not needed for a local setup. | `None` |
|
||||||
|
| `collection_name` | The name of the collection | `mem0` |
|
||||||
|
| `embedding_model_dims` | Dimensions of the embedding model | `1536` |
|
||||||
|
| `metric_type` | Metric type for similarity search | `L2` |
|
||||||
41
mem0/configs/vector_stores/milvus.py
Normal file
41
mem0/configs/vector_stores/milvus.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
from enum import Enum
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
|
|
||||||
|
class MetricType(str, Enum):
    """Similarity metric names understood by a Milvus / Zilliz server."""

    L2 = "L2"
    IP = "IP"
    COSINE = "COSINE"
    HAMMING = "HAMMING"
    JACCARD = "JACCARD"

    def __str__(self) -> str:
        # Render the bare metric name rather than "MetricType.<NAME>".
        return f"{self.value}"
|
||||||
|
|
||||||
|
|
||||||
|
class MilvusDBConfig(BaseModel):
    """Connection and collection settings for the Milvus / Zilliz vector store.

    Unknown keys are rejected up front by ``validate_extra_fields`` so that a
    misspelled option fails loudly instead of being silently ignored.
    """

    url: str = Field("http://localhost:19530", description="Full URL for Milvus/Zilliz server")
    # Optional because the default is None; a plain ``str`` annotation would
    # reject an explicit ``token=None``.
    token: Optional[str] = Field(None, description="Token for Zilliz server / local setup defaults to None.")
    collection_name: str = Field("mem0", description="Name of the collection")
    embedding_model_dims: int = Field(1536, description="Dimensions of the embedding model")
    metric_type: str = Field("L2", description="Metric type for similarity search")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any input key that is not a declared field.

        Args:
            values: Raw input mapping, before field validation.

        Returns:
            The unchanged input mapping when every key is a known field.

        Raises:
            ValueError: If one or more unknown keys are present.
        """
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            # Sort both lists so the message is stable from run to run.
            raise ValueError(
                f"Extra fields not allowed: {', '.join(sorted(extra_fields))}. Please input only the following fields: {', '.join(sorted(allowed_fields))}"
            )
        return values

    model_config = {
        "arbitrary_types_allowed": True,
    }
|
||||||
@@ -5,7 +5,6 @@ ADD_MEMORY_TOOL = {
|
|||||||
"function": {
|
"function": {
|
||||||
"name": "add_memory",
|
"name": "add_memory",
|
||||||
"description": "Add a memory",
|
"description": "Add a memory",
|
||||||
"strict": True,
|
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
@@ -22,7 +21,6 @@ UPDATE_MEMORY_TOOL = {
|
|||||||
"function": {
|
"function": {
|
||||||
"name": "update_memory",
|
"name": "update_memory",
|
||||||
"description": "Update memory provided ID and data",
|
"description": "Update memory provided ID and data",
|
||||||
"strict": True,
|
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
@@ -46,7 +44,6 @@ DELETE_MEMORY_TOOL = {
|
|||||||
"function": {
|
"function": {
|
||||||
"name": "delete_memory",
|
"name": "delete_memory",
|
||||||
"description": "Delete memory by memory_id",
|
"description": "Delete memory by memory_id",
|
||||||
"strict": True,
|
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
|
|||||||
@@ -59,6 +59,7 @@ class VectorStoreFactory:
|
|||||||
"qdrant": "mem0.vector_stores.qdrant.Qdrant",
|
"qdrant": "mem0.vector_stores.qdrant.Qdrant",
|
||||||
"chroma": "mem0.vector_stores.chroma.ChromaDB",
|
"chroma": "mem0.vector_stores.chroma.ChromaDB",
|
||||||
"pgvector": "mem0.vector_stores.pgvector.PGVector",
|
"pgvector": "mem0.vector_stores.pgvector.PGVector",
|
||||||
|
"milvus": "mem0.vector_stores.milvus.MilvusDB"
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ class VectorStoreConfig(BaseModel):
|
|||||||
"qdrant": "QdrantConfig",
|
"qdrant": "QdrantConfig",
|
||||||
"chroma": "ChromaDbConfig",
|
"chroma": "ChromaDbConfig",
|
||||||
"pgvector": "PGVectorConfig",
|
"pgvector": "PGVectorConfig",
|
||||||
|
"milvus" : "MilvusDBConfig"
|
||||||
}
|
}
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
|
|||||||
239
mem0/vector_stores/milvus.py
Normal file
239
mem0/vector_stores/milvus.py
Normal file
@@ -0,0 +1,239 @@
|
|||||||
|
import logging
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Optional, Dict
|
||||||
|
from mem0.vector_stores.base import VectorStoreBase
|
||||||
|
from mem0.configs.vector_stores.milvus import MetricType
|
||||||
|
|
||||||
|
try:
|
||||||
|
import pymilvus
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("The 'pymilvus' library is required. Please install it using 'pip install pymilvus'.")
|
||||||
|
|
||||||
|
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OutputData(BaseModel):
    """Record shape produced by the Milvus store's search, get and list calls."""

    # Memory id.
    id: Optional[str]
    # Distance reported for a search hit; callers pass None when there is none.
    score: Optional[float]
    # Metadata stored alongside the vector.
    payload: Optional[Dict]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class MilvusDB(VectorStoreBase):
    """Vector store backed by a Milvus / Zilliz collection.

    Every record has three fields: ``id`` (VARCHAR primary key), ``vectors``
    (FLOAT_VECTOR) and ``metadata`` (JSON payload).
    """

    def __init__(self, url: str, token: str, collection_name: str, embedding_model_dims: int, metric_type: MetricType) -> None:
        """Connect to the server and make sure the collection exists.

        Args:
            url (str): Full URL for Milvus/Zilliz server.
            token (str): Token/api_key for Zilliz server; may be None for a local setup.
            collection_name (str): Name of the collection.
            embedding_model_dims (int): Dimensions of the embedding model.
            metric_type (MetricType): Metric type for similarity search.
        """
        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.metric_type = metric_type

        self.client = MilvusClient(uri=url, token=token)

        self.create_col(
            collection_name=self.collection_name,
            vector_size=self.embedding_model_dims,
            metric_type=self.metric_type,
        )

    def create_col(
        self, collection_name: str, vector_size: int, metric_type: MetricType = MetricType.COSINE
    ) -> None:
        """Create the collection with an AUTOINDEX vector index, unless it already exists.

        Args:
            collection_name (str): Name of the collection.
            vector_size (int): Dimensions of the embedding model.
            metric_type (MetricType, optional): Metric type for similarity search.
                Defaults to MetricType.COSINE.
        """
        if self.client.has_collection(collection_name):
            logger.info(f"Collection {collection_name} already exists. Skipping creation.")
            return

        fields = [
            FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=512),
            FieldSchema(name="vectors", dtype=DataType.FLOAT_VECTOR, dim=vector_size),
            FieldSchema(name="metadata", dtype=DataType.JSON),
        ]
        schema = CollectionSchema(fields, enable_dynamic_field=True)

        # NOTE(review): "nlist" looks like an IVF-style parameter; confirm it
        # has any effect with AUTOINDEX.
        index = self.client.prepare_index_params(
            field_name="vectors",
            metric_type=metric_type,
            index_type="AUTOINDEX",
            index_name="vector_index",
            params={"nlist": 128},
        )

        self.client.create_collection(collection_name=collection_name, schema=schema, index_params=index)

    def insert(self, ids, vectors, payloads, **kwargs):
        """Insert vectors into the collection.

        Args:
            ids (List[str]): IDs corresponding to the vectors.
            vectors (List[List[float]]): Vectors to insert.
            payloads (List[Dict]): Payloads corresponding to the vectors.
            **kwargs: Extra keyword arguments forwarded to ``MilvusClient.insert``.
        """
        rows = [
            {"id": idx, "vectors": embedding, "metadata": metadata}
            for idx, embedding, metadata in zip(ids, vectors, payloads)
        ]
        if not rows:
            # Nothing to write; skip the network call.
            return
        # One batched request instead of one round trip per row.
        self.client.insert(collection_name=self.collection_name, data=rows, **kwargs)

    def _create_filter(self, filters: dict):
        """Build a filter expression over the ``metadata`` JSON field.

        Each key/value pair becomes an equality test and the tests are joined
        with ``and``.

        Args:
            filters (dict): Mapping of metadata key to required value
                (e.g. user_id, agent_id, run_id).

        Returns:
            str: Formatted filter expression.
        """
        operands = []
        for key, value in filters.items():
            if isinstance(value, str):
                # Escape backslashes and double quotes so the value cannot
                # terminate the string literal and alter the expression.
                escaped = value.replace("\\", "\\\\").replace('"', '\\"')
                operands.append(f'(metadata["{key}"] == "{escaped}")')
            else:
                operands.append(f'(metadata["{key}"] == {value})')

        return " and ".join(operands)

    def _parse_output(self, data: list):
        """Convert raw search hits into OutputData records.

        Args:
            data (list): Hits for a single query; each hit is read with
                ``.get`` for "id", "distance" and "entity".

        Returns:
            List[OutputData]: Parsed output data.
        """
        return [
            OutputData(
                id=hit.get("id"),
                score=hit.get("distance"),
                payload=hit.get("entity", {}).get("metadata"),
            )
            for hit in data
        ]

    def search(self, query: list, limit: int = 5, filters: dict = None) -> list:
        """Search for similar vectors.

        Args:
            query (List[float]): Query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (Dict, optional): Filters to apply to the search. Defaults to None.

        Returns:
            list: Search results as OutputData records.
        """
        query_filter = self._create_filter(filters) if filters else None
        hits = self.client.search(
            collection_name=self.collection_name,
            data=[query],
            limit=limit,
            filter=query_filter,
            output_fields=["*"],
        )
        # A single query vector is sent, so only the first hit list is read;
        # guard against an empty response.
        return self._parse_output(data=hits[0] if hits else [])

    def delete(self, vector_id):
        """Delete a vector by ID.

        Args:
            vector_id (str): ID of the vector to delete.
        """
        self.client.delete(collection_name=self.collection_name, ids=vector_id)

    def update(self, vector_id=None, vector=None, payload=None):
        """Update a vector and its payload via upsert.

        Args:
            vector_id (str): ID of the vector to update.
            vector (List[float], optional): Updated vector.
            payload (Dict, optional): Updated payload.
        """
        schema = {"id": vector_id, "vectors": vector, "metadata": payload}
        self.client.upsert(collection_name=self.collection_name, data=schema)

    def get(self, vector_id):
        """Retrieve a vector by ID.

        Args:
            vector_id (str): ID of the vector to retrieve.

        Returns:
            Optional[OutputData]: The stored record, or None when no record
            with that ID exists.
        """
        result = self.client.get(collection_name=self.collection_name, ids=vector_id)
        if not result:
            # Unknown ID: report "not found" instead of raising IndexError.
            return None
        record = result[0]
        return OutputData(id=record.get("id", None), score=None, payload=record.get("metadata", None))

    def list_cols(self):
        """List all collections.

        Returns:
            List[str]: List of collection names.
        """
        return self.client.list_collections()

    def delete_col(self):
        """Delete this store's collection."""
        return self.client.drop_collection(collection_name=self.collection_name)

    def col_info(self):
        """Get information about this store's collection.

        Returns:
            Dict[str, Any]: Collection statistics as reported by the client.
        """
        return self.client.get_collection_stats(collection_name=self.collection_name)

    def list(self, filters: dict = None, limit: int = 100) -> list:
        """List vectors in the collection.

        Args:
            filters (Dict, optional): Filters to apply to the list.
            limit (int, optional): Number of vectors to return. Defaults to 100.

        Returns:
            list: A one-element list whose only item is the list of
            OutputData records.
        """
        query_filter = self._create_filter(filters) if filters else None
        result = self.client.query(
            collection_name=self.collection_name,
            filter=query_filter,
            limit=limit,
        )
        memories = [
            OutputData(id=data.get("id"), score=None, payload=data.get("metadata"))
            for data in result
        ]
        # NOTE(review): the extra wrapping list is kept as-is; presumably
        # callers index [0] on the result. Confirm before flattening.
        return [memories]
|
||||||
4
poetry.lock
generated
4
poetry.lock
generated
@@ -1,4 +1,4 @@
|
|||||||
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
|
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "aiohappyeyeballs"
|
name = "aiohappyeyeballs"
|
||||||
@@ -1966,4 +1966,4 @@ multidict = ">=4.0"
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = ">=3.9,<4.0"
|
python-versions = ">=3.9,<4.0"
|
||||||
content-hash = "458055aee51b5e75c8f189fc1b0fbd238b9bb0d8a8becced0bd62a6a59d8d428"
|
content-hash = "5a74dacc8f9b1b40bb9d53fbbdcb0a95f5d05d55ffd9d61af870ca8a731954b4"
|
||||||
|
|||||||
@@ -35,8 +35,6 @@ isort = "^5.13.2"
|
|||||||
pytest = "^8.2.2"
|
pytest = "^8.2.2"
|
||||||
|
|
||||||
|
|
||||||
[tool.poetry.group.optional.dependencies]
|
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry-core"]
|
requires = ["poetry-core"]
|
||||||
build-backend = "poetry.core.masonry.api"
|
build-backend = "poetry.core.masonry.api"
|
||||||
|
|||||||
Reference in New Issue
Block a user