diff --git a/docs/components/vectordbs/config.mdx b/docs/components/vectordbs/config.mdx index fe0f1fcd..7ddff7c9 100644 --- a/docs/components/vectordbs/config.mdx +++ b/docs/components/vectordbs/config.mdx @@ -6,7 +6,7 @@ Config in mem0 is a dictionary that specifies the settings for your vector datab The config is defined as a Python dictionary with two main keys: - `vector_store`: Specifies the vector database provider and its configuration - - `provider`: The name of the vector database (e.g., "chroma", "pgvector", "qdrant") + - `provider`: The name of the vector database (e.g., "chroma", "pgvector", "qdrant", "milvus") - `config`: A nested dictionary containing provider-specific settings ## How to Use Config diff --git a/docs/components/vectordbs/dbs/milvus.mdx b/docs/components/vectordbs/dbs/milvus.mdx new file mode 100644 index 00000000..12193f46 --- /dev/null +++ b/docs/components/vectordbs/dbs/milvus.mdx @@ -0,0 +1,35 @@ +[Milvus](https://milvus.io/) Milvus is an open-source vector database that suits AI applications of every size from running a demo chatbot in Jupyter notebook to building web-scale search that serves billions of users. + +### Usage + +```python +import os +from mem0 import Memory + +config = { + "vector_store": { + "provider": "milvus", + "config": { + "collection_name": "test", + "embedding_model_dims": "123", + "url": "127.0.0.1", + "token": "8e4b8ca8cf2c67", + } + } +} + +m = Memory.from_config(config) +m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"}) +``` + +### Config + +Here's the parameters available for configuring Milvus Database: + +| Parameter | Description | Default Value | +| --- | --- | --- | +| `url` | Full URL/Uri for Milvus/Zilliz server | `http://localhost:19530` | +| `token` | Token for Zilliz server / for local setup defaults to None. | `None` | +| `collection_name` | The name of the collection | `mem0` | +| `embedding_model_dims` | Dimensions of the embedding model | `1536` | +| `metric_type` | Metric type for similarity search | `L2` | diff --git a/mem0/configs/vector_stores/milvus.py b/mem0/configs/vector_stores/milvus.py new file mode 100644 index 00000000..1e433df1 --- /dev/null +++ b/mem0/configs/vector_stores/milvus.py @@ -0,0 +1,41 @@ +from enum import Enum +from typing import Dict, Any +from pydantic import BaseModel, model_validator, Field + + +class MetricType(str, Enum): + """ + Metric Constant for milvus/ zilliz server. + """ + def __str__(self) -> str: + return str(self.value) + + L2 = "L2" + IP = "IP" + COSINE = "COSINE" + HAMMING = "HAMMING" + JACCARD = "JACCARD" + + +class MilvusDBConfig(BaseModel): + url: str = Field("http://localhost:19530", description="Full URL for Milvus/Zilliz server") + token: str = Field(None, description="Token for Zilliz server / local setup defaults to None.") + collection_name: str = Field("mem0", description="Name of the collection") + embedding_model_dims: int = Field(1536, description="Dimensions of the embedding model") + metric_type: str = Field("L2", description="Metric type for similarity search") + + @model_validator(mode="before") + @classmethod + def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]: + allowed_fields = set(cls.model_fields.keys()) + input_fields = set(values.keys()) + extra_fields = input_fields - allowed_fields + if extra_fields: + raise ValueError( + f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}" + ) + return values + + model_config = { + "arbitrary_types_allowed": True, + } \ No newline at end of file diff --git a/mem0/llms/utils/tools.py b/mem0/llms/utils/tools.py index fb4ff4a2..64f93145 100644 --- a/mem0/llms/utils/tools.py +++ b/mem0/llms/utils/tools.py @@ -5,7 +5,6 @@ ADD_MEMORY_TOOL = { "function": { "name": "add_memory", "description": "Add a memory", - "strict": True, "parameters": { "type": "object", "properties": { @@ -22,7 +21,6 @@ UPDATE_MEMORY_TOOL = { "function": { "name": "update_memory", "description": "Update memory provided ID and data", - "strict": True, "parameters": { "type": "object", "properties": { @@ -46,7 +44,6 @@ DELETE_MEMORY_TOOL = { "function": { "name": "delete_memory", "description": "Delete memory by memory_id", - "strict": True, "parameters": { "type": "object", "properties": { diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py index bdcc1806..7047febb 100644 --- a/mem0/utils/factory.py +++ b/mem0/utils/factory.py @@ -59,6 +59,7 @@ class VectorStoreFactory: "qdrant": "mem0.vector_stores.qdrant.Qdrant", "chroma": "mem0.vector_stores.chroma.ChromaDB", "pgvector": "mem0.vector_stores.pgvector.PGVector", + "milvus": "mem0.vector_stores.milvus.MilvusDB" } @classmethod diff --git a/mem0/vector_stores/configs.py b/mem0/vector_stores/configs.py index 2f052da5..d4cd6b13 100644 --- a/mem0/vector_stores/configs.py +++ b/mem0/vector_stores/configs.py @@ -16,6 +16,7 @@ class VectorStoreConfig(BaseModel): "qdrant": "QdrantConfig", "chroma": "ChromaDbConfig", "pgvector": "PGVectorConfig", + "milvus" : "MilvusDBConfig" } @model_validator(mode="after") diff --git a/mem0/vector_stores/milvus.py b/mem0/vector_stores/milvus.py new file mode 100644 index 00000000..eeed7ac6 --- /dev/null +++ b/mem0/vector_stores/milvus.py @@ -0,0 +1,239 @@ +import logging +from pydantic import BaseModel +from typing import Optional, Dict +from mem0.vector_stores.base import VectorStoreBase +from mem0.configs.vector_stores.milvus import MetricType + +try: + import pymilvus +except ImportError: + raise ImportError("The 'pymilvus' library is required. Please install it using 'pip install pymilvus'.") + +from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType + +logger = logging.getLogger(__name__) + + +class OutputData(BaseModel): + id: Optional[str] # memory id + score: Optional[float] # distance + payload: Optional[Dict] # metadata + + + +class MilvusDB(VectorStoreBase): + def __init__(self, url: str, token: str, collection_name: str, embedding_model_dims: int, metric_type: MetricType) -> None: + """Initialize the MilvusDB database. + + Args: + url (str): Full URL for Milvus/Zilliz server. + token (str): Token/api_key for Zilliz server / for local setup defaults to None. + collection_name (str): Name of the collection (defaults to mem0). + embedding_model_dims (int): Dimensions of the embedding model (defaults to 1536). + metric_type (MetricType): Metric type for similarity search (defaults to L2). + """ + + self.collection_name = collection_name + self.embedding_model_dims = embedding_model_dims + self.metric_type = metric_type + + self.client = MilvusClient(uri=url,token=token) + + self.create_col( + collection_name=self.collection_name, + vector_size=self.embedding_model_dims, + metric_type=self.metric_type + ) + + + def create_col( + self, collection_name : str, vector_size : str, metric_type : MetricType = MetricType.COSINE + ) -> None: + """Create a new collection with index_type AUTOINDEX. + + Args: + collection_name (str): Name of the collection (defaults to mem0). + vector_size (str): Dimensions of the embedding model (defaults to 1536). + metric_type (MetricType, optional): etric type for similarity search. Defaults to MetricType.COSINE. + """ + + if self.client.has_collection(collection_name): + logger.info(f"Collection {collection_name} already exists. Skipping creation.") + else: + fields = [ + FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=512), + FieldSchema(name="vectors", dtype=DataType.FLOAT_VECTOR, dim=vector_size), + FieldSchema(name="metadata", dtype=DataType.JSON), + ] + + schema = CollectionSchema(fields, enable_dynamic_field=True) + + index = self.client.prepare_index_params( + field_name="vectors", + metric_type=metric_type, + index_type="AUTOINDEX", + index_name="vector_index", + params={ "nlist": 128 } + ) + + self.client.create_collection(collection_name=collection_name, schema=schema, index_params=index) + + + def insert(self, ids, vectors, payloads, **kwargs: Optional[dict[str, any]]): + """Insert vectors into a collection. + + Args: + vectors (List[List[float]]): List of vectors to insert. + payloads (List[Dict], optional): List of payloads corresponding to vectors. + ids (List[str], optional): List of IDs corresponding to vectors. + """ + for idx, embedding, metadata in zip(ids, vectors, payloads): + data = {"id": idx, "vectors": embedding, "metadata": metadata} + self.client.insert(collection_name=self.collection_name, data=data, **kwargs) + + + def _create_filter(self, filters: dict): + """Prepare filters for efficient query. + + Args: + filters (dict): filters [user_id, agent_id, run_id] + + Returns: + str: formated filter. + """ + operands = [] + for key, value in filters.items(): + if isinstance(value, str): + operands.append(f'(metadata["{key}"] == "{value}")') + else: + operands.append(f'(metadata["{key}"] == {value})') + + return " and ".join(operands) + + + def _parse_output(self, data: list): + """ + Parse the output data. + + Args: + data (Dict): Output data. + + Returns: + List[OutputData]: Parsed output data. + """ + memory = [] + + for value in data: + uid, score, metadata = ( + value.get("id"), + value.get("distance"), + value.get("entity",{}).get("metadata") + ) + + memory_obj = OutputData(id=uid, score=score, payload=metadata) + memory.append(memory_obj) + + return memory + + + def search(self, query: list, limit: int = 5, filters: dict = None) -> list: + """ + Search for similar vectors. + + Args: + query (List[float]): Query vector. + limit (int, optional): Number of results to return. Defaults to 5. + filters (Dict, optional): Filters to apply to the search. Defaults to None. + + Returns: + list: Search results. + """ + query_filter = self._create_filter(filters) if filters else None + hits = self.client.search( + collection_name=self.collection_name, + data=[query], limit=limit, filter=query_filter, + output_fields=["*"] + ) + result = self._parse_output(data=hits[0]) + + return result + + def delete(self, vector_id): + """ + Delete a vector by ID. + + Args: + vector_id (str): ID of the vector to delete. + """ + self.client.delete(collection_name=self.collection_name, ids=vector_id) + + + def update(self, vector_id=None, vector=None, payload=None): + """ + Update a vector and its payload. + + Args: + vector_id (str): ID of the vector to update. + vector (List[float], optional): Updated vector. + payload (Dict, optional): Updated payload. + """ + schema = {"id" : vector_id, "vectors": vector, "metadata" : payload} + self.client.upsert(collection_name=self.collection_name, data=schema) + + def get(self, vector_id): + """ + Retrieve a vector by ID. + + Args: + vector_id (str): ID of the vector to retrieve. + + Returns: + OutputData: Retrieved vector. + """ + result = self.client.get(collection_name=self.collection_name, ids=vector_id) + output = OutputData(id=result[0].get("id", None), score=None, payload=result[0].get("metadata", None)) + return output + + def list_cols(self): + """ + List all collections. + + Returns: + List[str]: List of collection names. + """ + return self.client.list_collections() + + def delete_col(self): + """Delete a collection.""" + return self.client.drop_collection(collection_name=self.collection_name) + + def col_info(self): + """ + Get information about a collection. + + Returns: + Dict[str, Any]: Collection information. + """ + return self.client.get_collection_stats(collection_name=self.collection_name) + + def list(self, filters: dict = None, limit: int = 100) -> list: + """ + List all vectors in a collection. + + Args: + filters (Dict, optional): Filters to apply to the list. + limit (int, optional): Number of vectors to return. Defaults to 100. + + Returns: + List[OutputData]: List of vectors. + """ + query_filter = self._create_filter(filters) if filters else None + result = self.client.query( + collection_name=self.collection_name, + filter=query_filter, + limit=limit) + memories = [] + for data in result: + obj = OutputData(id=data.get("id"), score=None, payload=data.get("metadata")) + memories.append(obj) + return [memories] \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index eab715d4..5e9aedb4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -1966,4 +1966,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "458055aee51b5e75c8f189fc1b0fbd238b9bb0d8a8becced0bd62a6a59d8d428" +content-hash = "5a74dacc8f9b1b40bb9d53fbbdcb0a95f5d05d55ffd9d61af870ca8a731954b4" diff --git a/pyproject.toml b/pyproject.toml index c5ab0054..128d34dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,8 +35,6 @@ isort = "^5.13.2" pytest = "^8.2.2" -[tool.poetry.group.optional.dependencies] - [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api"