diff --git a/docs/components/vectordbs/config.mdx b/docs/components/vectordbs/config.mdx index 7ddff7c9..5169781e 100644 --- a/docs/components/vectordbs/config.mdx +++ b/docs/components/vectordbs/config.mdx @@ -6,7 +6,7 @@ Config in mem0 is a dictionary that specifies the settings for your vector datab The config is defined as a Python dictionary with two main keys: - `vector_store`: Specifies the vector database provider and its configuration - - `provider`: The name of the vector database (e.g., "chroma", "pgvector", "qdrant", "milvus") + - `provider`: The name of the vector database (e.g., "chroma", "pgvector", "qdrant", "milvus","azure_ai_search") - `config`: A nested dictionary containing provider-specific settings ## How to Use Config diff --git a/docs/components/vectordbs/dbs/azure_ai_search.mdx b/docs/components/vectordbs/dbs/azure_ai_search.mdx new file mode 100644 index 00000000..5bb2a952 --- /dev/null +++ b/docs/components/vectordbs/dbs/azure_ai_search.mdx @@ -0,0 +1,38 @@ +[Azure AI Search](https://learn.microsoft.com/en-us/azure/search/search-what-is-azure-search/) (formerly known as "Azure Cognitive Search") provides secure information retrieval at scale over user-owned content in traditional and generative AI search applications. + +### Usage + +```python +import os +from mem0 import Memory + +os.environ["OPENAI_API_KEY"] = "sk-xx" #this key is used for embedding purpose + +config = { + "vector_store": { + "provider": "azure_ai_search", + "config": { + "service_name": "ai-search-test", + "api_key": "*****", + "collection_name": "mem0", + "embedding_model_dims": 1536 , + "use_compression": False + } + } +} + +m = Memory.from_config(config) +m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"}) +``` + +### Config + +Let's see the available parameters for the `qdrant` config: +service_name (str): Azure Cognitive Search service name. +| Parameter | Description | Default Value | +| --- | --- | --- | +| `service_name` | Azure AI Search service name | `None` | +| `api_key` | API key of the Azure AI Search service | `None` | +| `collection_name` | The name of the collection/index to store the vectors, it will be created automatically if not exist | `mem0` | +| `embedding_model_dims` | Dimensions of the embedding model | `1536` | +| `use_compression` | Use scalar quantization vector compression | False | \ No newline at end of file diff --git a/docs/components/vectordbs/overview.mdx b/docs/components/vectordbs/overview.mdx index 85856309..36c1e9c3 100644 --- a/docs/components/vectordbs/overview.mdx +++ b/docs/components/vectordbs/overview.mdx @@ -12,6 +12,7 @@ See the list of supported vector databases below. + ## Usage diff --git a/docs/mint.json b/docs/mint.json index efb11f5b..4a15ce86 100644 --- a/docs/mint.json +++ b/docs/mint.json @@ -110,7 +110,8 @@ "components/vectordbs/dbs/chroma", "components/vectordbs/dbs/pgvector", "components/vectordbs/dbs/qdrant", - "components/vectordbs/dbs/milvus" + "components/vectordbs/dbs/milvus", + "components/vectordbs/dbs/azure_ai_search" ] } ] diff --git a/mem0/configs/vector_stores/azure_ai_search.py b/mem0/configs/vector_stores/azure_ai_search.py new file mode 100644 index 00000000..5619b300 --- /dev/null +++ b/mem0/configs/vector_stores/azure_ai_search.py @@ -0,0 +1,27 @@ +from typing import Any, Dict + +from pydantic import BaseModel, Field, model_validator + + +class AzureAISearchConfig(BaseModel): + collection_name: str = Field("mem0", description="Name of the collection") + service_name: str = Field(None, description="Azure Cognitive Search service name") + api_key: str = Field(None, description="API key for the Azure Cognitive Search service") + embedding_model_dims: int = Field(None, description="Dimension of the embedding vector") + use_compression: bool = Field(False, description="Whether to use scalar quantization vector compression.") + + @model_validator(mode="before") + @classmethod + def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]: + allowed_fields = set(cls.model_fields.keys()) + input_fields = set(values.keys()) + extra_fields = input_fields - allowed_fields + if extra_fields: + raise ValueError( + f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}" + ) + return values + + model_config = { + "arbitrary_types_allowed": True, + } diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py index b7606c69..56d2cc2c 100644 --- a/mem0/utils/factory.py +++ b/mem0/utils/factory.py @@ -63,6 +63,7 @@ class VectorStoreFactory: "chroma": "mem0.vector_stores.chroma.ChromaDB", "pgvector": "mem0.vector_stores.pgvector.PGVector", "milvus": "mem0.vector_stores.milvus.MilvusDB", + "azure_ai_search": "mem0.vector_stores.azure_ai_search.AzureAISearch", } @classmethod diff --git a/mem0/vector_stores/azure_ai_search.py b/mem0/vector_stores/azure_ai_search.py new file mode 100644 index 00000000..1c68c4f7 --- /dev/null +++ b/mem0/vector_stores/azure_ai_search.py @@ -0,0 +1,227 @@ +import json +import logging +from typing import List, Optional + +from pydantic import BaseModel + +from mem0.vector_stores.base import VectorStoreBase + +try: + from azure.core.credentials import AzureKeyCredential + from azure.core.exceptions import ResourceNotFoundError + from azure.search.documents import SearchClient + from azure.search.documents.indexes import SearchIndexClient + from azure.search.documents.indexes.models import ( + HnswAlgorithmConfiguration, + ScalarQuantizationCompression, + SearchField, + SearchFieldDataType, + SearchIndex, + SimpleField, + VectorSearch, + VectorSearchProfile, + ) + from azure.search.documents.models import VectorizedQuery +except ImportError: + raise ImportError( + "The 'azure-search-documents' library is required. Please install it using 'pip install azure-search-documents==11.5.1'." + ) + +logger = logging.getLogger(__name__) + + +class OutputData(BaseModel): + id: Optional[str] + score: Optional[float] + payload: Optional[dict] + + +class AzureAISearch(VectorStoreBase): + def __init__(self, service_name, collection_name, api_key, embedding_model_dims, use_compression): + """Initialize the Azure Cognitive Search vector store. + + Args: + service_name (str): Azure Cognitive Search service name. + collection_name (str): Index name. + api_key (str): API key for the Azure Cognitive Search service. + embedding_model_dims (int): Dimension of the embedding vector. + use_compression (bool): Use scalar quantization vector compression + """ + self.index_name = collection_name + self.collection_name = collection_name + self.embedding_model_dims = embedding_model_dims + self.use_compression = use_compression + self.search_client = SearchClient( + endpoint=f"https://{service_name}.search.windows.net", + index_name=self.index_name, + credential=AzureKeyCredential(api_key), + ) + self.index_client = SearchIndexClient( + endpoint=f"https://{service_name}.search.windows.net", credential=AzureKeyCredential(api_key) + ) + self.create_col() # create the collection / index + + def create_col(self): + """Create a new index in Azure Cognitive Search.""" + vector_dimensions = self.embedding_model_dims # Set this to the number of dimensions in your vector + + if self.use_compression: + vector_type = "Collection(Edm.Half)" + compression_name = "myCompression" + compression_configurations = [ScalarQuantizationCompression(compression_name=compression_name)] + else: + vector_type = "Collection(Edm.Single)" + compression_name = None + compression_configurations = [] + + fields = [ + SimpleField(name="id", type=SearchFieldDataType.String, key=True), + SearchField( + name="vector", + type=vector_type, + searchable=True, + vector_search_dimensions=vector_dimensions, + vector_search_profile_name="my-vector-config", + ), + SimpleField(name="payload", type=SearchFieldDataType.String, searchable=True), + ] + + vector_search = VectorSearch( + profiles=[ + VectorSearchProfile(name="my-vector-config", algorithm_configuration_name="my-algorithms-config") + ], + algorithms=[HnswAlgorithmConfiguration(name="my-algorithms-config")], + compressions=compression_configurations, + ) + index = SearchIndex(name=self.index_name, fields=fields, vector_search=vector_search) + self.index_client.create_or_update_index(index) + + def insert(self, vectors, payloads=None, ids=None): + """Insert vectors into the index. + + Args: + vectors (List[List[float]]): List of vectors to insert. + payloads (List[Dict], optional): List of payloads corresponding to vectors. + ids (List[str], optional): List of IDs corresponding to vectors. + """ + logger.info(f"Inserting {len(vectors)} vectors into index {self.index_name}") + documents = [ + {"id": id, "vector": vector, "payload": json.dumps(payload)} + for id, vector, payload in zip(ids, vectors, payloads) + ] + self.search_client.upload_documents(documents) + + def search(self, query, limit=5, filters=None): + """Search for similar vectors. + + Args: + query (List[float]): Query vectors. + limit (int, optional): Number of results to return. Defaults to 5. + filters (Dict, optional): Filters to apply to the search. Defaults to None. + + Returns: + list: Search results. + """ + + vector_query = VectorizedQuery(vector=query, k_nearest_neighbors=limit, fields="vector") + search_results = self.search_client.search(vector_queries=[vector_query], top=limit) + + results = [] + for result in search_results: + payload = json.loads(result["payload"]) + if filters: + for key, value in filters.items(): + if key not in payload or payload[key] != value: + continue + results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload)) + return results + + def delete(self, vector_id): + """Delete a vector by ID. + + Args: + vector_id (str): ID of the vector to delete. + """ + self.search_client.delete_documents(documents=[{"id": vector_id}]) + + def update(self, vector_id, vector=None, payload=None): + """Update a vector and its payload. + + Args: + vector_id (str): ID of the vector to update. + vector (List[float], optional): Updated vector. + payload (Dict, optional): Updated payload. + """ + document = {"id": vector_id} + if vector: + document["vector"] = vector + if payload: + document["payload"] = json.dumps(payload) + self.search_client.merge_or_upload_documents(documents=[document]) + + def get(self, vector_id) -> OutputData: + """Retrieve a vector by ID. + + Args: + vector_id (str): ID of the vector to retrieve. + + Returns: + OutputData: Retrieved vector. + """ + try: + result = self.search_client.get_document(key=vector_id) + except ResourceNotFoundError: + return None + return OutputData(id=result["id"], score=None, payload=json.loads(result["payload"])) + + def list_cols(self) -> List[str]: + """List all collections (indexes). + + Returns: + List[str]: List of index names. + """ + indexes = self.index_client.list_indexes() + return [index.name for index in indexes] + + def delete_col(self): + """Delete the index.""" + self.index_client.delete_index(self.index_name) + + def col_info(self): + """Get information about the index. + + Returns: + Dict[str, Any]: Index information. + """ + index = self.index_client.get_index(self.index_name) + return {"name": index.name, "fields": index.fields} + + def list(self, filters=None, limit=100): + """List all vectors in the index. + + Args: + filters (Dict, optional): Filters to apply to the list. + limit (int, optional): Number of vectors to return. Defaults to 100. + + Returns: + List[OutputData]: List of vectors. + """ + search_results = self.search_client.search(search_text="*", top=limit) + results = [] + for result in search_results: + payload = json.loads(result["payload"]) + include_result = True + if filters: + for key, value in filters.items(): + if (key not in payload) or (payload[key] != filters[key]): + include_result = False + break + if include_result: + results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload)) + + return [results] + + def __del__(self): + """Close the search client when the object is deleted.""" + self.search_client.close() + self.index_client.close() diff --git a/mem0/vector_stores/configs.py b/mem0/vector_stores/configs.py index 65e55a53..c76e3a11 100644 --- a/mem0/vector_stores/configs.py +++ b/mem0/vector_stores/configs.py @@ -15,6 +15,7 @@ class VectorStoreConfig(BaseModel): "chroma": "ChromaDbConfig", "pgvector": "PGVectorConfig", "milvus": "MilvusDBConfig", + "azure_ai_search": "AzureAISearchConfig", } @model_validator(mode="after")