diff --git a/docs/components/vectordbs/config.mdx b/docs/components/vectordbs/config.mdx
index 7ddff7c9..5169781e 100644
--- a/docs/components/vectordbs/config.mdx
+++ b/docs/components/vectordbs/config.mdx
@@ -6,7 +6,7 @@ Config in mem0 is a dictionary that specifies the settings for your vector datab
The config is defined as a Python dictionary with two main keys:
- `vector_store`: Specifies the vector database provider and its configuration
- - `provider`: The name of the vector database (e.g., "chroma", "pgvector", "qdrant", "milvus")
+ - `provider`: The name of the vector database (e.g., "chroma", "pgvector", "qdrant", "milvus", "azure_ai_search")
- `config`: A nested dictionary containing provider-specific settings
## How to Use Config
diff --git a/docs/components/vectordbs/dbs/azure_ai_search.mdx b/docs/components/vectordbs/dbs/azure_ai_search.mdx
new file mode 100644
index 00000000..5bb2a952
--- /dev/null
+++ b/docs/components/vectordbs/dbs/azure_ai_search.mdx
@@ -0,0 +1,38 @@
+[Azure AI Search](https://learn.microsoft.com/en-us/azure/search/search-what-is-azure-search/) (formerly known as "Azure Cognitive Search") provides secure information retrieval at scale over user-owned content in traditional and generative AI search applications.
+
+### Usage
+
+```python
+import os
+from mem0 import Memory
+
+os.environ["OPENAI_API_KEY"] = "sk-xx"  # this key is used for embeddings
+
+config = {
+ "vector_store": {
+ "provider": "azure_ai_search",
+ "config": {
+ "service_name": "ai-search-test",
+ "api_key": "*****",
+ "collection_name": "mem0",
+            "embedding_model_dims": 1536,
+ "use_compression": False
+ }
+ }
+}
+
+m = Memory.from_config(config)
+m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"})
+```
+
+### Config
+
+Let's see the available parameters for the `azure_ai_search` config:
+
+| Parameter | Description | Default Value |
+| --- | --- | --- |
+| `service_name` | Azure AI Search service name | `None` |
+| `api_key` | API key of the Azure AI Search service | `None` |
+| `collection_name` | The name of the collection/index to store the vectors, it will be created automatically if not exist | `mem0` |
+| `embedding_model_dims` | Dimensions of the embedding model | `1536` |
+| `use_compression` | Use scalar quantization vector compression | `False` |
\ No newline at end of file
diff --git a/docs/components/vectordbs/overview.mdx b/docs/components/vectordbs/overview.mdx
index 85856309..36c1e9c3 100644
--- a/docs/components/vectordbs/overview.mdx
+++ b/docs/components/vectordbs/overview.mdx
@@ -12,6 +12,7 @@ See the list of supported vector databases below.
+
## Usage
diff --git a/docs/mint.json b/docs/mint.json
index efb11f5b..4a15ce86 100644
--- a/docs/mint.json
+++ b/docs/mint.json
@@ -110,7 +110,8 @@
"components/vectordbs/dbs/chroma",
"components/vectordbs/dbs/pgvector",
"components/vectordbs/dbs/qdrant",
- "components/vectordbs/dbs/milvus"
+ "components/vectordbs/dbs/milvus",
+ "components/vectordbs/dbs/azure_ai_search"
]
}
]
diff --git a/mem0/configs/vector_stores/azure_ai_search.py b/mem0/configs/vector_stores/azure_ai_search.py
new file mode 100644
index 00000000..5619b300
--- /dev/null
+++ b/mem0/configs/vector_stores/azure_ai_search.py
@@ -0,0 +1,27 @@
+from typing import Any, Dict, Optional
+
+from pydantic import BaseModel, Field, model_validator
+
+
+class AzureAISearchConfig(BaseModel):
+    collection_name: str = Field("mem0", description="Name of the collection")
+    service_name: Optional[str] = Field(None, description="Azure Cognitive Search service name")
+    api_key: Optional[str] = Field(None, description="API key for the Azure Cognitive Search service")
+    embedding_model_dims: Optional[int] = Field(None, description="Dimension of the embedding vector")
+    use_compression: bool = Field(False, description="Whether to use scalar quantization vector compression.")
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        # Reject unknown keys up front so config typos fail loudly instead of being dropped.
+        allowed_fields = set(cls.model_fields.keys())
+        input_fields = set(values.keys())
+        extra_fields = input_fields - allowed_fields
+        if extra_fields:
+            raise ValueError(
+                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
+            )
+        return values
+
+    model_config = {
+        "arbitrary_types_allowed": True,
+    }
diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py
index b7606c69..56d2cc2c 100644
--- a/mem0/utils/factory.py
+++ b/mem0/utils/factory.py
@@ -63,6 +63,7 @@ class VectorStoreFactory:
"chroma": "mem0.vector_stores.chroma.ChromaDB",
"pgvector": "mem0.vector_stores.pgvector.PGVector",
"milvus": "mem0.vector_stores.milvus.MilvusDB",
+ "azure_ai_search": "mem0.vector_stores.azure_ai_search.AzureAISearch",
}
@classmethod
diff --git a/mem0/vector_stores/azure_ai_search.py b/mem0/vector_stores/azure_ai_search.py
new file mode 100644
index 00000000..1c68c4f7
--- /dev/null
+++ b/mem0/vector_stores/azure_ai_search.py
@@ -0,0 +1,227 @@
+import json
+import logging
+from typing import List, Optional
+
+from pydantic import BaseModel
+
+from mem0.vector_stores.base import VectorStoreBase
+
+try:
+ from azure.core.credentials import AzureKeyCredential
+ from azure.core.exceptions import ResourceNotFoundError
+ from azure.search.documents import SearchClient
+ from azure.search.documents.indexes import SearchIndexClient
+ from azure.search.documents.indexes.models import (
+ HnswAlgorithmConfiguration,
+ ScalarQuantizationCompression,
+ SearchField,
+ SearchFieldDataType,
+ SearchIndex,
+ SimpleField,
+ VectorSearch,
+ VectorSearchProfile,
+ )
+ from azure.search.documents.models import VectorizedQuery
+except ImportError:
+ raise ImportError(
+ "The 'azure-search-documents' library is required. Please install it using 'pip install azure-search-documents==11.5.1'."
+ )
+
+logger = logging.getLogger(__name__)
+
+
+class OutputData(BaseModel):
+    id: Optional[str]  # document key in the search index
+    score: Optional[float]  # "@search.score" from a query; None for direct get()
+    payload: Optional[dict]  # JSON-decoded payload stored alongside the vector
+
+
+class AzureAISearch(VectorStoreBase):
+    def __init__(self, service_name, collection_name, api_key, embedding_model_dims, use_compression):
+        """Initialize the Azure Cognitive Search vector store.
+
+        Args:
+            service_name (str): Azure Cognitive Search service name (subdomain of *.search.windows.net).
+            collection_name (str): Index name; created on construction if it does not exist.
+            api_key (str): API key for the Azure Cognitive Search service.
+            embedding_model_dims (int): Dimension of the embedding vector.
+            use_compression (bool): Use scalar quantization vector compression.
+        """
+        self.index_name = collection_name
+        self.collection_name = collection_name  # same value as index_name; kept for the common vector-store interface
+        self.embedding_model_dims = embedding_model_dims
+        self.use_compression = use_compression
+        self.search_client = SearchClient(
+            endpoint=f"https://{service_name}.search.windows.net",
+            index_name=self.index_name,
+            credential=AzureKeyCredential(api_key),
+        )
+        self.index_client = SearchIndexClient(
+            endpoint=f"https://{service_name}.search.windows.net", credential=AzureKeyCredential(api_key)
+        )
+        self.create_col()  # create the collection / index up front so inserts can assume it exists
+
+    def create_col(self):
+        """Create (or update) the index with an id / vector / payload schema."""
+        vector_dimensions = self.embedding_model_dims  # dimensionality of the stored vectors
+
+        if self.use_compression:
+            vector_type = "Collection(Edm.Half)"  # half-precision storage paired with scalar quantization
+            compression_name = "myCompression"
+            compression_configurations = [ScalarQuantizationCompression(compression_name=compression_name)]
+        else:
+            vector_type = "Collection(Edm.Single)"  # full-precision floats, no compression
+            compression_name = None
+            compression_configurations = []
+
+        fields = [
+            SimpleField(name="id", type=SearchFieldDataType.String, key=True),
+            SearchField(
+                name="vector",
+                type=vector_type,
+                searchable=True,
+                vector_search_dimensions=vector_dimensions,
+                vector_search_profile_name="my-vector-config",
+            ),
+            SimpleField(name="payload", type=SearchFieldDataType.String, searchable=True),  # JSON string; see insert()
+        ]
+
+        vector_search = VectorSearch(
+            profiles=[
+                VectorSearchProfile(name="my-vector-config", algorithm_configuration_name="my-algorithms-config")
+            ],
+            algorithms=[HnswAlgorithmConfiguration(name="my-algorithms-config")],
+            compressions=compression_configurations,
+        )
+        index = SearchIndex(name=self.index_name, fields=fields, vector_search=vector_search)
+        self.index_client.create_or_update_index(index)
+
+    def insert(self, vectors, payloads=None, ids=None):
+        """Insert vectors into the index.
+
+        Args:
+            vectors (List[List[float]]): List of vectors to insert.
+            payloads (List[Dict], optional): List of payloads corresponding to vectors.
+            ids (List[str], optional): List of IDs corresponding to vectors.
+        """
+        logger.info(f"Inserting {len(vectors)} vectors into index {self.index_name}")
+        documents = [  # payload is serialized to a JSON string field; doc_id avoids shadowing builtin id()
+            {"id": doc_id, "vector": vector, "payload": json.dumps(payload)}
+            for doc_id, vector, payload in zip(ids, vectors, payloads)
+        ]
+        self.search_client.upload_documents(documents)
+
+    def search(self, query, limit=5, filters=None):
+        """Search for similar vectors.
+
+        Args:
+            query (List[float]): Query vector.
+            limit (int, optional): Number of results to return. Defaults to 5.
+            filters (Dict, optional): Exact-match filters applied to the stored payload. Defaults to None.
+
+        Returns:
+            list: Search results as OutputData objects.
+        """
+
+        vector_query = VectorizedQuery(vector=query, k_nearest_neighbors=limit, fields="vector")
+        search_results = self.search_client.search(vector_queries=[vector_query], top=limit)
+
+        results = []
+        for result in search_results:
+            payload = json.loads(result["payload"])
+            # Skip results failing any filter. The previous `continue` only affected the
+            # inner loop over filters.items(), so filters were silently ignored.
+            if filters and any(key not in payload or payload[key] != value for key, value in filters.items()):
+                continue
+            results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
+        return results
+
+    def delete(self, vector_id):
+        """Delete a vector by ID.
+
+        Args:
+            vector_id (str): ID of the vector to delete.
+        """
+        self.search_client.delete_documents(documents=[{"id": vector_id}])  # key-based delete
+
+    def update(self, vector_id, vector=None, payload=None):
+        """Update a vector and/or its payload via merge-or-upload.
+
+        Args:
+            vector_id (str): ID of the vector to update.
+            vector (List[float], optional): Updated vector.
+            payload (Dict, optional): Updated payload.
+        """
+        document = {"id": vector_id}
+        if vector:  # NOTE(review): truthiness check also skips an empty vector/payload — confirm intended
+            document["vector"] = vector
+        if payload:
+            document["payload"] = json.dumps(payload)
+        self.search_client.merge_or_upload_documents(documents=[document])
+
+    def get(self, vector_id) -> OutputData:
+        """Retrieve a vector by ID.
+
+        Args:
+            vector_id (str): ID of the vector to retrieve.
+
+        Returns:
+            OutputData: Retrieved vector (score is None), or None if the document does not exist.
+        """
+        try:
+            result = self.search_client.get_document(key=vector_id)
+        except ResourceNotFoundError:
+            return None
+        return OutputData(id=result["id"], score=None, payload=json.loads(result["payload"]))
+
+    def list_cols(self) -> List[str]:
+        """List all collections (indexes).
+
+        Returns:
+            List[str]: List of index names.
+        """
+        indexes = self.index_client.list_indexes()  # service-level enumeration, not limited to this index
+        return [index.name for index in indexes]
+
+    def delete_col(self):
+        """Delete the index backing this vector store."""
+        self.index_client.delete_index(self.index_name)
+
+    def col_info(self):
+        """Get information about the index.
+
+        Returns:
+            Dict[str, Any]: Index name and its field definitions.
+        """
+        index = self.index_client.get_index(self.index_name)
+        return {"name": index.name, "fields": index.fields}
+
+    def list(self, filters=None, limit=100):
+        """List all vectors in the index.
+
+        Args:
+            filters (Dict, optional): Exact-match filters applied to the stored payload.
+            limit (int, optional): Number of vectors to return. Defaults to 100.
+
+        Returns:
+            List[OutputData]: List of vectors, wrapped in an outer single-element list.
+        """
+        search_results = self.search_client.search(search_text="*", top=limit)
+        results = []
+        for result in search_results:
+            payload = json.loads(result["payload"])
+            include_result = True
+            if filters:
+                for key, value in filters.items():
+                    if (key not in payload) or (payload[key] != value):  # use unpacked value, not filters[key]
+                        include_result = False
+                        break
+            if include_result:
+                results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
+
+        return [results]
+
+    def __del__(self):
+        """Close the search clients when the object is garbage-collected."""
+        self.search_client.close()  # best-effort cleanup; __del__ timing is not guaranteed
+        self.index_client.close()
diff --git a/mem0/vector_stores/configs.py b/mem0/vector_stores/configs.py
index 65e55a53..c76e3a11 100644
--- a/mem0/vector_stores/configs.py
+++ b/mem0/vector_stores/configs.py
@@ -15,6 +15,7 @@ class VectorStoreConfig(BaseModel):
"chroma": "ChromaDbConfig",
"pgvector": "PGVectorConfig",
"milvus": "MilvusDBConfig",
+ "azure_ai_search": "AzureAISearchConfig",
}
@model_validator(mode="after")