Feature - Support Azure AI Search as a Vector DB (#1967)
Co-authored-by: Sidney Phoon <sidneyphoon17@gmail.com>
@@ -6,7 +6,7 @@ Config in mem0 is a dictionary that specifies the settings for your vector datab
 The config is defined as a Python dictionary with two main keys:
 - `vector_store`: Specifies the vector database provider and its configuration
-  - `provider`: The name of the vector database (e.g., "chroma", "pgvector", "qdrant", "milvus")
+  - `provider`: The name of the vector database (e.g., "chroma", "pgvector", "qdrant", "milvus", "azure_ai_search")
 - `config`: A nested dictionary containing provider-specific settings

 ## How to Use Config

New file: docs/components/vectordbs/dbs/azure_ai_search.mdx (38 lines)
@@ -0,0 +1,38 @@
[Azure AI Search](https://learn.microsoft.com/en-us/azure/search/search-what-is-azure-search/) (formerly known as "Azure Cognitive Search") provides secure information retrieval at scale over user-owned content in traditional and generative AI search applications.

### Usage
```python
import os
from mem0 import Memory

os.environ["OPENAI_API_KEY"] = "sk-xx"  # this key is used for the embedding model

config = {
    "vector_store": {
        "provider": "azure_ai_search",
        "config": {
            "service_name": "ai-search-test",
            "api_key": "*****",
            "collection_name": "mem0",
            "embedding_model_dims": 1536,
            "use_compression": False
        }
    }
}

m = Memory.from_config(config)
m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"})
```

### Config

Here are the available parameters for the `azure_ai_search` config:

| Parameter | Description | Default Value |
| --- | --- | --- |
| `service_name` | Azure AI Search service name | `None` |
| `api_key` | API key of the Azure AI Search service | `None` |
| `collection_name` | Name of the collection/index in which the vectors are stored; it is created automatically if it does not exist | `mem0` |
| `embedding_model_dims` | Dimensions of the embedding model | `1536` |
| `use_compression` | Use scalar quantization vector compression | `False` |
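
Once memories have been added, they are queried through the regular mem0 interface; Azure AI Search only acts as the storage and retrieval backend. A minimal sketch continuing the example above, assuming the standard `Memory.search` and `Memory.get_all` helpers (the query text is illustrative):

```python
# Retrieve memories relevant to a query; the vectors are searched in the Azure AI Search index.
related = m.search("What does Alice do on weekends?", user_id="alice")

# List everything stored for this user.
all_memories = m.get_all(user_id="alice")

print(related)
print(all_memories)
```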
@@ -12,6 +12,7 @@ See the list of supported vector databases below.
 <Card title="Qdrant" href="/components/vectordbs/dbs/qdrant"></Card>
 <Card title="Chroma" href="/components/vectordbs/dbs/chroma"></Card>
 <Card title="Pgvector" href="/components/vectordbs/dbs/pgvector"></Card>
+<Card title="Azure AI Search" href="/components/vectordbs/dbs/azure_ai_search"></Card>
 </CardGroup>

 ## Usage

@@ -110,7 +110,8 @@
       "components/vectordbs/dbs/chroma",
       "components/vectordbs/dbs/pgvector",
       "components/vectordbs/dbs/qdrant",
-      "components/vectordbs/dbs/milvus"
+      "components/vectordbs/dbs/milvus",
+      "components/vectordbs/dbs/azure_ai_search"
     ]
   }
 ]
27
mem0/configs/vector_stores/azure_ai_search.py
Normal file
27
mem0/configs/vector_stores/azure_ai_search.py
Normal file
@@ -0,0 +1,27 @@
from typing import Any, Dict

from pydantic import BaseModel, Field, model_validator


class AzureAISearchConfig(BaseModel):
    collection_name: str = Field("mem0", description="Name of the collection")
    service_name: str = Field(None, description="Azure Cognitive Search service name")
    api_key: str = Field(None, description="API key for the Azure Cognitive Search service")
    embedding_model_dims: int = Field(None, description="Dimension of the embedding vector")
    use_compression: bool = Field(False, description="Whether to use scalar quantization vector compression.")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values

    model_config = {
        "arbitrary_types_allowed": True,
    }
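
The `model_validator` above runs before field parsing and rejects any keys that are not declared on the model. A minimal sketch of that behaviour, using placeholder values (the stray `endpoint` key is intentional):

```python
from mem0.configs.vector_stores.azure_ai_search import AzureAISearchConfig

# All keys are declared fields, so this validates.
cfg = AzureAISearchConfig(service_name="ai-search-test", api_key="*****", embedding_model_dims=1536)

# An undeclared key is caught by validate_extra_fields and surfaces as a
# pydantic validation error listing the allowed fields.
try:
    AzureAISearchConfig(service_name="ai-search-test", api_key="*****", endpoint="https://example.net")
except ValueError as err:
    print(err)
```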
@@ -63,6 +63,7 @@ class VectorStoreFactory:
         "chroma": "mem0.vector_stores.chroma.ChromaDB",
         "pgvector": "mem0.vector_stores.pgvector.PGVector",
         "milvus": "mem0.vector_stores.milvus.MilvusDB",
+        "azure_ai_search": "mem0.vector_stores.azure_ai_search.AzureAISearch",
     }

     @classmethod
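
For context, the value registered above is a dotted import path. A generic sketch of how such a path can be resolved to a class at runtime (the common pattern, not necessarily mem0's exact factory code):

```python
import importlib


def load_class(dotted_path: str):
    """Resolve a 'package.module.ClassName' string to the class object."""
    module_path, class_name = dotted_path.rsplit(".", 1)
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


# Resolves to the AzureAISearch class added in this commit.
AzureAISearchCls = load_class("mem0.vector_stores.azure_ai_search.AzureAISearch")
```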

New file: mem0/vector_stores/azure_ai_search.py (227 lines)
@@ -0,0 +1,227 @@
import json
import logging
from typing import List, Optional

from pydantic import BaseModel

from mem0.vector_stores.base import VectorStoreBase

try:
    from azure.core.credentials import AzureKeyCredential
    from azure.core.exceptions import ResourceNotFoundError
    from azure.search.documents import SearchClient
    from azure.search.documents.indexes import SearchIndexClient
    from azure.search.documents.indexes.models import (
        HnswAlgorithmConfiguration,
        ScalarQuantizationCompression,
        SearchField,
        SearchFieldDataType,
        SearchIndex,
        SimpleField,
        VectorSearch,
        VectorSearchProfile,
    )
    from azure.search.documents.models import VectorizedQuery
except ImportError:
    raise ImportError(
        "The 'azure-search-documents' library is required. Please install it using 'pip install azure-search-documents==11.5.1'."
    )

logger = logging.getLogger(__name__)


class OutputData(BaseModel):
    id: Optional[str]
    score: Optional[float]
    payload: Optional[dict]

class AzureAISearch(VectorStoreBase):
    def __init__(self, service_name, collection_name, api_key, embedding_model_dims, use_compression):
        """Initialize the Azure Cognitive Search vector store.

        Args:
            service_name (str): Azure Cognitive Search service name.
            collection_name (str): Index name.
            api_key (str): API key for the Azure Cognitive Search service.
            embedding_model_dims (int): Dimension of the embedding vector.
            use_compression (bool): Use scalar quantization vector compression.
        """
        self.index_name = collection_name
        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.use_compression = use_compression
        self.search_client = SearchClient(
            endpoint=f"https://{service_name}.search.windows.net",
            index_name=self.index_name,
            credential=AzureKeyCredential(api_key),
        )
        self.index_client = SearchIndexClient(
            endpoint=f"https://{service_name}.search.windows.net", credential=AzureKeyCredential(api_key)
        )
        self.create_col()  # create the collection / index

    def create_col(self):
        """Create a new index in Azure Cognitive Search."""
        vector_dimensions = self.embedding_model_dims  # Set this to the number of dimensions in your vector

        if self.use_compression:
            vector_type = "Collection(Edm.Half)"
            compression_name = "myCompression"
            compression_configurations = [ScalarQuantizationCompression(compression_name=compression_name)]
        else:
            vector_type = "Collection(Edm.Single)"
            compression_name = None
            compression_configurations = []

        fields = [
            SimpleField(name="id", type=SearchFieldDataType.String, key=True),
            SearchField(
                name="vector",
                type=vector_type,
                searchable=True,
                vector_search_dimensions=vector_dimensions,
                vector_search_profile_name="my-vector-config",
            ),
            SimpleField(name="payload", type=SearchFieldDataType.String, searchable=True),
        ]

        vector_search = VectorSearch(
            profiles=[
                VectorSearchProfile(name="my-vector-config", algorithm_configuration_name="my-algorithms-config")
            ],
            algorithms=[HnswAlgorithmConfiguration(name="my-algorithms-config")],
            compressions=compression_configurations,
        )
        index = SearchIndex(name=self.index_name, fields=fields, vector_search=vector_search)
        self.index_client.create_or_update_index(index)

    def insert(self, vectors, payloads=None, ids=None):
        """Insert vectors into the index.

        Args:
            vectors (List[List[float]]): List of vectors to insert.
            payloads (List[Dict], optional): List of payloads corresponding to vectors.
            ids (List[str], optional): List of IDs corresponding to vectors.
        """
        logger.info(f"Inserting {len(vectors)} vectors into index {self.index_name}")
        documents = [
            {"id": id, "vector": vector, "payload": json.dumps(payload)}
            for id, vector, payload in zip(ids, vectors, payloads)
        ]
        self.search_client.upload_documents(documents)

    def search(self, query, limit=5, filters=None):
        """Search for similar vectors.

        Args:
            query (List[float]): Query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (Dict, optional): Filters to apply to the search. Defaults to None.

        Returns:
            list: Search results.
        """

        vector_query = VectorizedQuery(vector=query, k_nearest_neighbors=limit, fields="vector")
        search_results = self.search_client.search(vector_queries=[vector_query], top=limit)

        results = []
        for result in search_results:
            payload = json.loads(result["payload"])
            if filters:
                # Skip results whose payload does not satisfy every filter key/value pair.
                if any(key not in payload or payload[key] != value for key, value in filters.items()):
                    continue
            results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
        return results

    def delete(self, vector_id):
        """Delete a vector by ID.

        Args:
            vector_id (str): ID of the vector to delete.
        """
        self.search_client.delete_documents(documents=[{"id": vector_id}])

    def update(self, vector_id, vector=None, payload=None):
        """Update a vector and its payload.

        Args:
            vector_id (str): ID of the vector to update.
            vector (List[float], optional): Updated vector.
            payload (Dict, optional): Updated payload.
        """
        document = {"id": vector_id}
        if vector:
            document["vector"] = vector
        if payload:
            document["payload"] = json.dumps(payload)
        self.search_client.merge_or_upload_documents(documents=[document])

    def get(self, vector_id) -> OutputData:
        """Retrieve a vector by ID.

        Args:
            vector_id (str): ID of the vector to retrieve.

        Returns:
            OutputData: Retrieved vector.
        """
        try:
            result = self.search_client.get_document(key=vector_id)
        except ResourceNotFoundError:
            return None
        return OutputData(id=result["id"], score=None, payload=json.loads(result["payload"]))

    def list_cols(self) -> List[str]:
        """List all collections (indexes).

        Returns:
            List[str]: List of index names.
        """
        indexes = self.index_client.list_indexes()
        return [index.name for index in indexes]

    def delete_col(self):
        """Delete the index."""
        self.index_client.delete_index(self.index_name)

    def col_info(self):
        """Get information about the index.

        Returns:
            Dict[str, Any]: Index information.
        """
        index = self.index_client.get_index(self.index_name)
        return {"name": index.name, "fields": index.fields}

    def list(self, filters=None, limit=100):
        """List all vectors in the index.

        Args:
            filters (Dict, optional): Filters to apply to the list.
            limit (int, optional): Number of vectors to return. Defaults to 100.

        Returns:
            List[OutputData]: List of vectors.
        """
        search_results = self.search_client.search(search_text="*", top=limit)
        results = []
        for result in search_results:
            payload = json.loads(result["payload"])
            include_result = True
            if filters:
                for key, value in filters.items():
                    if (key not in payload) or (payload[key] != value):
                        include_result = False
                        break
            if include_result:
                results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))

        return [results]

    def __del__(self):
        """Close the search client when the object is deleted."""
        self.search_client.close()
        self.index_client.close()
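
A minimal sketch of exercising this class directly, assuming a reachable Azure AI Search service and a valid admin key; the service name, key, IDs, and 4-dimensional vectors are placeholders:

```python
from mem0.vector_stores.azure_ai_search import AzureAISearch

# __init__ calls create_col(), so the index is provisioned on construction.
store = AzureAISearch(
    service_name="ai-search-test",
    collection_name="mem0",
    api_key="*****",
    embedding_model_dims=4,
    use_compression=False,
)

# Insert a toy vector with a JSON-serializable payload.
store.insert(vectors=[[0.1, 0.2, 0.3, 0.4]], payloads=[{"user_id": "alice"}], ids=["mem-1"])

# Nearest-neighbour search, filtered on a payload key after retrieval.
for hit in store.search(query=[0.1, 0.2, 0.3, 0.4], limit=3, filters={"user_id": "alice"}):
    print(hit.id, hit.score, hit.payload)
```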
@@ -15,6 +15,7 @@ class VectorStoreConfig(BaseModel):
         "chroma": "ChromaDbConfig",
         "pgvector": "PGVectorConfig",
         "milvus": "MilvusDBConfig",
+        "azure_ai_search": "AzureAISearchConfig",
     }

     @model_validator(mode="after")