Feature: milvus db integration (#1821)
This commit is contained in:
@@ -6,7 +6,7 @@ Config in mem0 is a dictionary that specifies the settings for your vector datab
|
||||
|
||||
The config is defined as a Python dictionary with two main keys:
|
||||
- `vector_store`: Specifies the vector database provider and its configuration
|
||||
- `provider`: The name of the vector database (e.g., "chroma", "pgvector", "qdrant")
|
||||
- `provider`: The name of the vector database (e.g., "chroma", "pgvector", "qdrant", "milvus")
|
||||
- `config`: A nested dictionary containing provider-specific settings
|
||||
|
||||
## How to Use Config
|
||||
|
||||
35
docs/components/vectordbs/dbs/milvus.mdx
Normal file
35
docs/components/vectordbs/dbs/milvus.mdx
Normal file
@@ -0,0 +1,35 @@
|
||||
[Milvus](https://milvus.io/) is an open-source vector database that suits AI applications of every size, from running a demo chatbot in a Jupyter notebook to building web-scale search that serves billions of users.
|
||||
|
||||
### Usage
|
||||
|
||||
```python
|
||||
import os
|
||||
from mem0 import Memory
|
||||
|
||||
config = {
|
||||
"vector_store": {
|
||||
"provider": "milvus",
|
||||
"config": {
|
||||
"collection_name": "test",
|
||||
"embedding_model_dims": 1536,
|
||||
"url": "127.0.0.1",
|
||||
"token": "8e4b8ca8cf2c67",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m = Memory.from_config(config)
|
||||
m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"})
|
||||
```
|
||||
|
||||
### Config
|
||||
|
||||
Here are the parameters available for configuring the Milvus database:
|
||||
|
||||
| Parameter | Description | Default Value |
|
||||
| --- | --- | --- |
|
||||
| `url` | Full URL/Uri for Milvus/Zilliz server | `http://localhost:19530` |
|
||||
| `token` | Token for Zilliz server / for local setup defaults to None. | `None` |
|
||||
| `collection_name` | The name of the collection | `mem0` |
|
||||
| `embedding_model_dims` | Dimensions of the embedding model | `1536` |
|
||||
| `metric_type` | Metric type for similarity search | `L2` |
|
||||
41
mem0/configs/vector_stores/milvus.py
Normal file
41
mem0/configs/vector_stores/milvus.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from enum import Enum
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class MetricType(str, Enum):
    """Similarity-metric identifiers accepted by the Milvus / Zilliz server."""

    L2 = "L2"
    IP = "IP"
    COSINE = "COSINE"
    HAMMING = "HAMMING"
    JACCARD = "JACCARD"

    def __str__(self) -> str:
        # Render as the raw metric name (e.g. "L2"), not "MetricType.L2".
        return str(self.value)
|
||||
|
||||
|
||||
class MilvusDBConfig(BaseModel):
    """Connection/collection settings for the Milvus / Zilliz vector store.

    Unknown keys are rejected by ``validate_extra_fields`` so configuration
    typos fail loudly instead of being silently ignored.
    """

    url: str = Field("http://localhost:19530", description="Full URL for Milvus/Zilliz server")
    # Fix: default is None, so the annotation must be Optional[str], not bare str.
    token: Optional[str] = Field(None, description="Token for Zilliz server / local setup defaults to None.")
    collection_name: str = Field("mem0", description="Name of the collection")
    embedding_model_dims: int = Field(1536, description="Dimensions of the embedding model")
    metric_type: str = Field("L2", description="Metric type for similarity search")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Raise ValueError if the input dict contains keys outside the declared fields."""
        allowed_fields = set(cls.model_fields.keys())
        extra_fields = set(values.keys()) - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values

    model_config = {
        "arbitrary_types_allowed": True,
    }
|
||||
@@ -5,7 +5,6 @@ ADD_MEMORY_TOOL = {
|
||||
"function": {
|
||||
"name": "add_memory",
|
||||
"description": "Add a memory",
|
||||
"strict": True,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -22,7 +21,6 @@ UPDATE_MEMORY_TOOL = {
|
||||
"function": {
|
||||
"name": "update_memory",
|
||||
"description": "Update memory provided ID and data",
|
||||
"strict": True,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -46,7 +44,6 @@ DELETE_MEMORY_TOOL = {
|
||||
"function": {
|
||||
"name": "delete_memory",
|
||||
"description": "Delete memory by memory_id",
|
||||
"strict": True,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
||||
@@ -59,6 +59,7 @@ class VectorStoreFactory:
|
||||
"qdrant": "mem0.vector_stores.qdrant.Qdrant",
|
||||
"chroma": "mem0.vector_stores.chroma.ChromaDB",
|
||||
"pgvector": "mem0.vector_stores.pgvector.PGVector",
|
||||
"milvus": "mem0.vector_stores.milvus.MilvusDB"
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -16,6 +16,7 @@ class VectorStoreConfig(BaseModel):
|
||||
"qdrant": "QdrantConfig",
|
||||
"chroma": "ChromaDbConfig",
|
||||
"pgvector": "PGVectorConfig",
|
||||
"milvus" : "MilvusDBConfig"
|
||||
}
|
||||
|
||||
@model_validator(mode="after")
|
||||
|
||||
239
mem0/vector_stores/milvus.py
Normal file
239
mem0/vector_stores/milvus.py
Normal file
@@ -0,0 +1,239 @@
|
||||
import logging
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, Dict
|
||||
from mem0.vector_stores.base import VectorStoreBase
|
||||
from mem0.configs.vector_stores.milvus import MetricType
|
||||
|
||||
try:
|
||||
import pymilvus
|
||||
except ImportError:
|
||||
raise ImportError("The 'pymilvus' library is required. Please install it using 'pip install pymilvus'.")
|
||||
|
||||
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OutputData(BaseModel):
|
||||
id: Optional[str] # memory id
|
||||
score: Optional[float] # distance
|
||||
payload: Optional[Dict] # metadata
|
||||
|
||||
|
||||
|
||||
class MilvusDB(VectorStoreBase):
    """mem0 vector-store backend over a Milvus / Zilliz server via pymilvus MilvusClient."""

    def __init__(self, url: str, token: str, collection_name: str, embedding_model_dims: int, metric_type: MetricType) -> None:
        """Initialize the MilvusDB database.

        Args:
            url (str): Full URL for Milvus/Zilliz server.
            token (str): Token/api_key for Zilliz server / for local setup defaults to None.
            collection_name (str): Name of the collection (defaults to mem0).
            embedding_model_dims (int): Dimensions of the embedding model (defaults to 1536).
            metric_type (MetricType): Metric type for similarity search (defaults to L2).
        """
        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.metric_type = metric_type

        self.client = MilvusClient(uri=url, token=token)

        # Idempotent: create_col skips creation if the collection already exists.
        self.create_col(
            collection_name=self.collection_name,
            vector_size=self.embedding_model_dims,
            metric_type=self.metric_type,
        )

    def create_col(
        self, collection_name: str, vector_size: int, metric_type: MetricType = MetricType.COSINE
    ) -> None:
        """Create a new collection with index_type AUTOINDEX.

        Args:
            collection_name (str): Name of the collection (defaults to mem0).
            vector_size (int): Dimensions of the embedding model (defaults to 1536).
            metric_type (MetricType, optional): Metric type for similarity search. Defaults to MetricType.COSINE.
        """
        if self.client.has_collection(collection_name):
            logger.info(f"Collection {collection_name} already exists. Skipping creation.")
            return

        fields = [
            FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=512),
            FieldSchema(name="vectors", dtype=DataType.FLOAT_VECTOR, dim=vector_size),
            FieldSchema(name="metadata", dtype=DataType.JSON),
        ]

        schema = CollectionSchema(fields, enable_dynamic_field=True)

        index = self.client.prepare_index_params(
            field_name="vectors",
            metric_type=metric_type,
            index_type="AUTOINDEX",
            index_name="vector_index",
            params={"nlist": 128},
        )

        self.client.create_collection(collection_name=collection_name, schema=schema, index_params=index)

    def insert(self, ids, vectors, payloads, **kwargs):
        """Insert vectors into a collection.

        Args:
            ids (List[str]): List of IDs corresponding to vectors.
            vectors (List[List[float]]): List of vectors to insert.
            payloads (List[Dict]): List of payloads corresponding to vectors.
            **kwargs: Extra options forwarded to ``MilvusClient.insert``.
        """
        # One insert call per row; extra rows in the longest list are dropped by zip.
        for idx, embedding, metadata in zip(ids, vectors, payloads):
            data = {"id": idx, "vectors": embedding, "metadata": metadata}
            self.client.insert(collection_name=self.collection_name, data=data, **kwargs)

    def _create_filter(self, filters: dict) -> str:
        """Build a Milvus boolean filter expression over the JSON ``metadata`` field.

        Args:
            filters (dict): filters [user_id, agent_id, run_id]

        Returns:
            str: formatted filter expression, e.g. '(metadata["user_id"] == "alice")'.
        """
        operands = []
        for key, value in filters.items():
            # String values need quoting inside the expression; other types are emitted raw.
            if isinstance(value, str):
                operands.append(f'(metadata["{key}"] == "{value}")')
            else:
                operands.append(f'(metadata["{key}"] == {value})')

        return " and ".join(operands)

    def _parse_output(self, data: list):
        """Convert raw Milvus search hits into OutputData records.

        Args:
            data (list): One hit list as returned by ``MilvusClient.search``.

        Returns:
            List[OutputData]: Parsed output data.
        """
        memory = []

        for value in data:
            uid, score, metadata = (
                value.get("id"),
                value.get("distance"),
                value.get("entity", {}).get("metadata"),
            )

            memory.append(OutputData(id=uid, score=score, payload=metadata))

        return memory

    def search(self, query: list, limit: int = 5, filters: dict = None) -> list:
        """Search for similar vectors.

        Args:
            query (List[float]): Query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (Dict, optional): Filters to apply to the search. Defaults to None.

        Returns:
            list: Search results as OutputData objects.
        """
        query_filter = self._create_filter(filters) if filters else None
        hits = self.client.search(
            collection_name=self.collection_name,
            data=[query],
            limit=limit,
            filter=query_filter,
            output_fields=["*"],
        )
        # Single query vector was sent, so only the first hit list is relevant.
        return self._parse_output(data=hits[0])

    def delete(self, vector_id):
        """Delete a vector by ID.

        Args:
            vector_id (str): ID of the vector to delete.
        """
        self.client.delete(collection_name=self.collection_name, ids=vector_id)

    def update(self, vector_id=None, vector=None, payload=None):
        """Update a vector and its payload (upsert semantics).

        Args:
            vector_id (str): ID of the vector to update.
            vector (List[float], optional): Updated vector.
            payload (Dict, optional): Updated payload.
        """
        schema = {"id": vector_id, "vectors": vector, "metadata": payload}
        self.client.upsert(collection_name=self.collection_name, data=schema)

    def get(self, vector_id):
        """Retrieve a vector by ID.

        Args:
            vector_id (str): ID of the vector to retrieve.

        Returns:
            OutputData: Retrieved vector (score is None; get is not a similarity query).

        Raises:
            IndexError: if no row with ``vector_id`` exists (result list is empty).
        """
        result = self.client.get(collection_name=self.collection_name, ids=vector_id)
        return OutputData(id=result[0].get("id", None), score=None, payload=result[0].get("metadata", None))

    def list_cols(self):
        """List all collections.

        Returns:
            List[str]: List of collection names.
        """
        return self.client.list_collections()

    def delete_col(self):
        """Delete this store's collection."""
        return self.client.drop_collection(collection_name=self.collection_name)

    def col_info(self):
        """Get information about this store's collection.

        Returns:
            Dict[str, Any]: Collection statistics.
        """
        return self.client.get_collection_stats(collection_name=self.collection_name)

    def list(self, filters: dict = None, limit: int = 100) -> list:
        """List all vectors in a collection.

        Args:
            filters (Dict, optional): Filters to apply to the list.
            limit (int, optional): Number of vectors to return. Defaults to 100.

        Returns:
            list: A single-element list wrapping the List[OutputData].
        """
        query_filter = self._create_filter(filters) if filters else None
        result = self.client.query(
            collection_name=self.collection_name,
            filter=query_filter,
            limit=limit,
        )
        memories = []
        for data in result:
            memories.append(OutputData(id=data.get("id"), score=None, payload=data.get("metadata")))
        # NOTE(review): wrapped in an outer list — appears to match the other
        # mem0 vector stores' list() return shape; confirm against base class.
        return [memories]
|
||||
4
poetry.lock
generated
4
poetry.lock
generated
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "aiohappyeyeballs"
|
||||
@@ -1966,4 +1966,4 @@ multidict = ">=4.0"
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.9,<4.0"
|
||||
content-hash = "458055aee51b5e75c8f189fc1b0fbd238b9bb0d8a8becced0bd62a6a59d8d428"
|
||||
content-hash = "5a74dacc8f9b1b40bb9d53fbbdcb0a95f5d05d55ffd9d61af870ca8a731954b4"
|
||||
|
||||
@@ -35,8 +35,6 @@ isort = "^5.13.2"
|
||||
pytest = "^8.2.2"
|
||||
|
||||
|
||||
[tool.poetry.group.optional.dependencies]
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
Reference in New Issue
Block a user