Feature: milvus db integration (#1821)

This commit is contained in:
k10
2024-09-10 22:06:50 +05:30
committed by GitHub
parent 5b9b65c395
commit 3bd49b57cc
9 changed files with 320 additions and 8 deletions

View File

@@ -6,7 +6,7 @@ Config in mem0 is a dictionary that specifies the settings for your vector datab
The config is defined as a Python dictionary with two main keys:
- `vector_store`: Specifies the vector database provider and its configuration
- `provider`: The name of the vector database (e.g., "chroma", "pgvector", "qdrant")
- `provider`: The name of the vector database (e.g., "chroma", "pgvector", "qdrant", "milvus")
- `config`: A nested dictionary containing provider-specific settings
## How to Use Config

View File

@@ -0,0 +1,35 @@
[Milvus](https://milvus.io/) is an open-source vector database that suits AI applications of every size, from running a demo chatbot in a Jupyter notebook to building web-scale search that serves billions of users.
### Usage
```python
import os
from mem0 import Memory
config = {
"vector_store": {
"provider": "milvus",
"config": {
"collection_name": "test",
"embedding_model_dims": "123",
"url": "127.0.0.1",
"token": "8e4b8ca8cf2c67",
}
}
}
m = Memory.from_config(config)
m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"})
```
### Config
Here are the parameters available for configuring the Milvus Database:
| Parameter | Description | Default Value |
| --- | --- | --- |
| `url` | Full URL/URI for the Milvus/Zilliz server | `http://localhost:19530` |
| `token` | Token for Zilliz server / for local setup defaults to None. | `None` |
| `collection_name` | The name of the collection | `mem0` |
| `embedding_model_dims` | Dimensions of the embedding model | `1536` |
| `metric_type` | Metric type for similarity search | `L2` |

View File

@@ -0,0 +1,41 @@
from enum import Enum
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field, model_validator
class MetricType(str, Enum):
    """Similarity metrics accepted by a Milvus/Zilliz server.

    Subclassing ``str`` lets members be passed directly wherever pymilvus
    expects a plain metric-name string.
    """

    L2 = "L2"
    IP = "IP"
    COSINE = "COSINE"
    HAMMING = "HAMMING"
    JACCARD = "JACCARD"

    def __str__(self) -> str:
        # Render just the raw metric name (e.g. "L2"), not "MetricType.L2".
        return str(self.value)
class MilvusDBConfig(BaseModel):
    """Connection/collection settings for the Milvus vector-store provider."""

    url: str = Field("http://localhost:19530", description="Full URL for Milvus/Zilliz server")
    # Fix: the default is None, so the annotation must be Optional[str] — a bare
    # `str` field with a None default is type-incorrect (local Milvus needs no token).
    token: Optional[str] = Field(None, description="Token for Zilliz server / local setup defaults to None.")
    collection_name: str = Field("mem0", description="Name of the collection")
    embedding_model_dims: int = Field(1536, description="Dimensions of the embedding model")
    metric_type: str = Field("L2", description="Metric type for similarity search")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject unknown config keys so typos fail fast instead of being silently dropped."""
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values

    model_config = {
        "arbitrary_types_allowed": True,
    }

View File

@@ -5,7 +5,6 @@ ADD_MEMORY_TOOL = {
"function": {
"name": "add_memory",
"description": "Add a memory",
"strict": True,
"parameters": {
"type": "object",
"properties": {
@@ -22,7 +21,6 @@ UPDATE_MEMORY_TOOL = {
"function": {
"name": "update_memory",
"description": "Update memory provided ID and data",
"strict": True,
"parameters": {
"type": "object",
"properties": {
@@ -46,7 +44,6 @@ DELETE_MEMORY_TOOL = {
"function": {
"name": "delete_memory",
"description": "Delete memory by memory_id",
"strict": True,
"parameters": {
"type": "object",
"properties": {

View File

@@ -59,6 +59,7 @@ class VectorStoreFactory:
"qdrant": "mem0.vector_stores.qdrant.Qdrant",
"chroma": "mem0.vector_stores.chroma.ChromaDB",
"pgvector": "mem0.vector_stores.pgvector.PGVector",
"milvus": "mem0.vector_stores.milvus.MilvusDB"
}
@classmethod

View File

@@ -16,6 +16,7 @@ class VectorStoreConfig(BaseModel):
"qdrant": "QdrantConfig",
"chroma": "ChromaDbConfig",
"pgvector": "PGVectorConfig",
"milvus" : "MilvusDBConfig"
}
@model_validator(mode="after")

View File

@@ -0,0 +1,239 @@
import logging
from pydantic import BaseModel
from typing import Optional, Dict
from mem0.vector_stores.base import VectorStoreBase
from mem0.configs.vector_stores.milvus import MetricType
try:
import pymilvus
except ImportError:
raise ImportError("The 'pymilvus' library is required. Please install it using 'pip install pymilvus'.")
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
logger = logging.getLogger(__name__)
class OutputData(BaseModel):
    """Normalized record returned by MilvusDB helpers (search/get/list)."""

    id: Optional[str]  # memory id
    score: Optional[float]  # distance (populated by search; None for get/list)
    payload: Optional[Dict]  # metadata
class MilvusDB(VectorStoreBase):
    """mem0 vector store backed by a Milvus/Zilliz server via pymilvus ``MilvusClient``."""

    def __init__(
        self,
        url: str,
        token: str,
        collection_name: str,
        embedding_model_dims: int,
        metric_type: MetricType,
    ) -> None:
        """Initialize the MilvusDB database.

        Args:
            url (str): Full URL for Milvus/Zilliz server.
            token (str): Token/api_key for Zilliz server / for local setup defaults to None.
            collection_name (str): Name of the collection (defaults to mem0).
            embedding_model_dims (int): Dimensions of the embedding model (defaults to 1536).
            metric_type (MetricType): Metric type for similarity search (defaults to L2).
        """
        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.metric_type = metric_type
        self.client = MilvusClient(uri=url, token=token)
        # Create the collection eagerly so every other method can assume it exists.
        self.create_col(
            collection_name=self.collection_name,
            vector_size=self.embedding_model_dims,
            metric_type=self.metric_type,
        )

    def create_col(
        self, collection_name: str, vector_size: int, metric_type: MetricType = MetricType.COSINE
    ) -> None:
        """Create a new collection with index_type AUTOINDEX (no-op if it already exists).

        Args:
            collection_name (str): Name of the collection (defaults to mem0).
            vector_size (int): Dimensions of the embedding model (defaults to 1536).
                (Fix: was annotated ``str``; it is passed as ``dim`` and must be an int.)
            metric_type (MetricType, optional): Metric type for similarity search.
                Defaults to MetricType.COSINE.
        """
        if self.client.has_collection(collection_name):
            logger.info(f"Collection {collection_name} already exists. Skipping creation.")
        else:
            fields = [
                FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=512),
                FieldSchema(name="vectors", dtype=DataType.FLOAT_VECTOR, dim=vector_size),
                FieldSchema(name="metadata", dtype=DataType.JSON),
            ]
            schema = CollectionSchema(fields, enable_dynamic_field=True)
            index = self.client.prepare_index_params(
                field_name="vectors",
                metric_type=metric_type,
                index_type="AUTOINDEX",
                index_name="vector_index",
                params={"nlist": 128},
            )
            self.client.create_collection(collection_name=collection_name, schema=schema, index_params=index)

    def insert(self, ids, vectors, payloads, **kwargs: Optional[dict]):
        """Insert vectors into a collection.

        Args:
            ids (List[str]): List of IDs corresponding to vectors.
            vectors (List[List[float]]): List of vectors to insert.
            payloads (List[Dict]): List of payloads corresponding to vectors.
        """
        # NOTE(review): rows are inserted one call per row; a single batched
        # client.insert would cut round-trips — confirm API shape before changing.
        for idx, embedding, metadata in zip(ids, vectors, payloads):
            data = {"id": idx, "vectors": embedding, "metadata": metadata}
            self.client.insert(collection_name=self.collection_name, data=data, **kwargs)

    def _create_filter(self, filters: dict) -> str:
        """Prepare filters for efficient query.

        Args:
            filters (dict): filters [user_id, agent_id, run_id]

        Returns:
            str: formatted filter expression, e.g. '(metadata["user_id"] == "alice")'.
        """
        operands = []
        for key, value in filters.items():
            # String values need quoting inside the Milvus boolean expression.
            if isinstance(value, str):
                operands.append(f'(metadata["{key}"] == "{value}")')
            else:
                operands.append(f'(metadata["{key}"] == {value})')
        return " and ".join(operands)

    def _parse_output(self, data: list):
        """
        Parse raw Milvus hit dicts into OutputData records.

        Args:
            data (list): Raw hits from client.search.

        Returns:
            List[OutputData]: Parsed output data.
        """
        memory = []
        for value in data:
            uid, score, metadata = (
                value.get("id"),
                value.get("distance"),
                value.get("entity", {}).get("metadata"),
            )
            memory_obj = OutputData(id=uid, score=score, payload=metadata)
            memory.append(memory_obj)
        return memory

    def search(self, query: list, limit: int = 5, filters: dict = None) -> list:
        """
        Search for similar vectors.

        Args:
            query (List[float]): Query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (Dict, optional): Filters to apply to the search. Defaults to None.

        Returns:
            list: Search results as List[OutputData].
        """
        query_filter = self._create_filter(filters) if filters else None
        hits = self.client.search(
            collection_name=self.collection_name,
            data=[query],
            limit=limit,
            filter=query_filter,
            output_fields=["*"],
        )
        # Single query vector in, so only the first (and only) hit list matters.
        result = self._parse_output(data=hits[0])
        return result

    def delete(self, vector_id):
        """
        Delete a vector by ID.

        Args:
            vector_id (str): ID of the vector to delete.
        """
        self.client.delete(collection_name=self.collection_name, ids=vector_id)

    def update(self, vector_id=None, vector=None, payload=None):
        """
        Update (upsert) a vector and its payload.

        Args:
            vector_id (str): ID of the vector to update.
            vector (List[float], optional): Updated vector.
            payload (Dict, optional): Updated payload.
        """
        schema = {"id": vector_id, "vectors": vector, "metadata": payload}
        self.client.upsert(collection_name=self.collection_name, data=schema)

    def get(self, vector_id):
        """
        Retrieve a vector by ID.

        Args:
            vector_id (str): ID of the vector to retrieve.

        Returns:
            OutputData: Retrieved vector (score is always None — no search was run).
        """
        # NOTE(review): result[0] raises IndexError when the id is absent —
        # confirm whether callers rely on that or expect None.
        result = self.client.get(collection_name=self.collection_name, ids=vector_id)
        output = OutputData(id=result[0].get("id", None), score=None, payload=result[0].get("metadata", None))
        return output

    def list_cols(self):
        """
        List all collections.

        Returns:
            List[str]: List of collection names.
        """
        return self.client.list_collections()

    def delete_col(self):
        """Delete this instance's collection."""
        return self.client.drop_collection(collection_name=self.collection_name)

    def col_info(self):
        """
        Get information about this instance's collection.

        Returns:
            Dict[str, Any]: Collection information.
        """
        return self.client.get_collection_stats(collection_name=self.collection_name)

    def list(self, filters: dict = None, limit: int = 100) -> list:
        """
        List all vectors in a collection.

        Args:
            filters (Dict, optional): Filters to apply to the list.
            limit (int, optional): Number of vectors to return. Defaults to 100.

        Returns:
            List[List[OutputData]]: The memories wrapped in an outer single-element
            list (shape preserved for existing callers).
        """
        query_filter = self._create_filter(filters) if filters else None
        result = self.client.query(
            collection_name=self.collection_name,
            filter=query_filter,
            limit=limit,
        )
        memories = []
        for data in result:
            obj = OutputData(id=data.get("id"), score=None, payload=data.get("metadata"))
            memories.append(obj)
        return [memories]

4
poetry.lock generated
View File

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
[[package]]
name = "aiohappyeyeballs"
@@ -1966,4 +1966,4 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<4.0"
content-hash = "458055aee51b5e75c8f189fc1b0fbd238b9bb0d8a8becced0bd62a6a59d8d428"
content-hash = "5a74dacc8f9b1b40bb9d53fbbdcb0a95f5d05d55ffd9d61af870ca8a731954b4"

View File

@@ -35,8 +35,6 @@ isort = "^5.13.2"
pytest = "^8.2.2"
[tool.poetry.group.optional.dependencies]
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"