improvement(OSS): Fix AOSS and AWS Bedrock LLM (#2697)

Co-authored-by: Prateek Chhikara <prateekchhikara24@gmail.com>
Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
This commit is contained in:
Saket Aryan
2025-05-16 04:49:29 +05:30
committed by GitHub
parent 267e5b13ea
commit 5c67a5e6bc
14 changed files with 502 additions and 127 deletions

View File

@@ -25,7 +25,7 @@ from mem0 import Memory
os.environ["OPENAI_API_KEY"] = "your-openai-api-key"
# AWS credentials
os.environ["AWS_REGION"] = "us-east-1"
os.environ["AWS_REGION"] = "us-west-2"
os.environ["AWS_ACCESS_KEY_ID"] = "your-access-key"
os.environ["AWS_SECRET_ACCESS_KEY"] = "your-secret-key"
@@ -33,7 +33,7 @@ config = {
"embedder": {
"provider": "aws_bedrock",
"config": {
"model": "amazon.titan-embed-text-v1"
"model": "amazon.titan-embed-text-v2:0"
}
}
}
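Note that Titan Text Embeddings V2 returns 1024-dimensional vectors by default (V1 returned 1536), so `embedding_model_dims` in any paired vector store config has to move to 1024 along with this model change. A minimal sketch of the underlying Bedrock call, assuming default AWS credentials are already configured:

```python
import json
import boto3

# Direct Titan v2 embedding call via the Bedrock runtime (request/response
# field names "inputText" and "embedding" per the Titan API).
client = boto3.client("bedrock-runtime", region_name="us-west-2")
response = client.invoke_model(
    modelId="amazon.titan-embed-text-v2:0",
    body=json.dumps({"inputText": "hello world"}),
    accept="application/json",
    contentType="application/json",
)
embedding = json.loads(response["body"].read())["embedding"]
print(len(embedding))  # 1024 by default for Titan v2
```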

View File

@@ -15,16 +15,15 @@ title: AWS Bedrock
import os
from mem0 import Memory
os.environ["OPENAI_API_KEY"] = "your-api-key" # used for embedding model
-os.environ['AWS_REGION'] = 'us-east-1'
-os.environ["AWS_ACCESS_KEY"] = "xx"
+os.environ['AWS_REGION'] = 'us-west-2'
+os.environ["AWS_ACCESS_KEY_ID"] = "xx"
os.environ["AWS_SECRET_ACCESS_KEY"] = "xx"
config = {
"llm": {
"provider": "aws_bedrock",
"config": {
"model": "arn:aws:bedrock:us-east-1:123456789012:model/your-model-name",
"model": "anthropic.claude-3-5-haiku-20241022-v1:0",
"temperature": 0.2,
"max_tokens": 2000,
}
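With this config, usage follows the usual Memory flow; a short sketch (user id and message are illustrative):

```python
m = Memory.from_config(config)
m.add(
    [{"role": "user", "content": "I prefer sci-fi over thrillers."}],
    user_id="alice",
)
```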

View File

@@ -1,59 +1,75 @@
-[OpenSearch](https://opensearch.org/) is an open-source, enterprise-grade search and observability suite that brings order to unstructured data at scale. OpenSearch supports k-NN (k-Nearest Neighbors) and allows you to store and retrieve high-dimensional vector embeddings efficiently.
+[OpenSearch](https://opensearch.org/) is an enterprise-grade search and observability suite that brings order to unstructured data at scale. OpenSearch supports k-NN (k-Nearest Neighbors) and allows you to store and retrieve high-dimensional vector embeddings efficiently.
### Installation
OpenSearch support requires additional dependencies. Install them with:
```bash
-pip install opensearch>=2.8.0
+pip install opensearch-py
```
+### Prerequisites
+Before using OpenSearch with Mem0, you need to set up a collection in AWS OpenSearch Service.
+#### AWS OpenSearch Service
+You can create a collection through the AWS Console:
+- Navigate to [OpenSearch Service Console](https://console.aws.amazon.com/aos/home)
+- Click "Create collection"
+- Select "Serverless collection" and then enable "Vector search" capabilities
+- Once created, note the endpoint URL (host) for your configuration
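Alternatively, the collection can be created programmatically; a hedged sketch using boto3 (the collection name is illustrative, and matching encryption, network, and data-access policies must already exist before the collection becomes usable):

```python
import boto3

# Create an OpenSearch Serverless vector-search collection.
aoss = boto3.client("opensearchserverless", region_name="us-west-2")
aoss.create_collection(name="mem0", type="VECTORSEARCH")
```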
### Usage
```python
import os
from mem0 import Memory
+import boto3
+from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth
os.environ["OPENAI_API_KEY"] = "sk-xx"
+# For AWS OpenSearch Service with IAM authentication
+region = 'us-west-2'
+service = 'aoss'
+credentials = boto3.Session().get_credentials()
+auth = AWSV4SignerAuth(credentials, region, service)
config = {
"vector_store": {
"provider": "opensearch",
"config": {
"collection_name": "mem0",
"host": "localhost",
"port": 9200,
"embedding_model_dims": 1536
"host": "your-domain.us-west-2.aoss.amazonaws.com",
"port": 443,
"http_auth": auth,
"embedding_model_dims": 1024,
"connection_class": RequestsHttpConnection,
"pool_maxsize": 20,
"use_ssl": True,
"verify_certs": True
}
}
}
```
### Add Memories
```python
m = Memory.from_config(config)
messages = [
{"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
{"role": "assistant", "content": "How about a thriller movies? They can be quite engaging."},
{"role": "user", "content": "Im not a big fan of thriller movies but I love sci-fi movies."},
{"role": "user", "content": "I'm not a big fan of thriller movies but I love sci-fi movies."},
{"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."}
]
m.add(messages, user_id="alice", metadata={"category": "movies"})
```
-### Config
+### Search Memories
-Let's see the available parameters for the `opensearch` config:
-| Parameter | Description | Default Value |
-| ---------------------- | -------------------------------------------------- | ------------- |
-| `collection_name` | The name of the index to store the vectors | `mem0` |
-| `embedding_model_dims` | Dimensions of the embedding model | `1536` |
-| `host` | The host where the OpenSearch server is running | `localhost` |
-| `port` | The port where the OpenSearch server is running | `9200` |
-| `api_key` | API key for authentication | `None` |
-| `user` | Username for basic authentication | `None` |
-| `password` | Password for basic authentication | `None` |
-| `verify_certs` | Whether to verify SSL certificates | `False` |
-| `auto_create_index` | Whether to automatically create the index | `True` |
-| `use_ssl` | Whether to use SSL for connection | `False` |
+```python
+results = m.search("What kind of movies does Alice like?", user_id="alice")
+```
### Features

View File

@@ -200,6 +200,7 @@
"icon": "lightbulb",
"pages": [
"examples",
"examples/aws_example",
"examples/mem0-demo",
"examples/ai_companion_js",
"examples/eliza_os",

View File

@@ -0,0 +1,120 @@
---
title: AWS Bedrock and AOSS
---
<Snippet file="paper-release.mdx" />
This example demonstrates how to configure and use the `mem0ai` SDK with **AWS Bedrock** and **Amazon OpenSearch Serverless (AOSS)** for persistent memory capabilities in Python.
## Installation
Install the required dependencies:
```bash
pip install mem0ai boto3 opensearch-py
```
## Environment Setup
Set your AWS environment variables:
```python
import os
# Set these in your environment or notebook
os.environ['AWS_REGION'] = 'us-west-2'
os.environ['AWS_ACCESS_KEY_ID'] = 'AK00000000000000000'
os.environ['AWS_SECRET_ACCESS_KEY'] = 'AS00000000000000000'
# Confirm they are set
print(os.environ['AWS_REGION'])
print(os.environ['AWS_ACCESS_KEY_ID'])
print(os.environ['AWS_SECRET_ACCESS_KEY'])
```
## Configuration and Usage
This sets up Mem0 with AWS Bedrock for embeddings and LLM, and OpenSearch as the vector store.
```python
import boto3
from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth
from mem0.memory.main import Memory
region = 'us-west-2'
service = 'aoss'
credentials = boto3.Session().get_credentials()
auth = AWSV4SignerAuth(credentials, region, service)
config = {
"embedder": {
"provider": "aws_bedrock",
"config": {
"model": "amazon.titan-embed-text-v2:0"
}
},
"llm": {
"provider": "aws_bedrock",
"config": {
"model": "anthropic.claude-3-5-haiku-20241022-v1:0",
"temperature": 0.1,
"max_tokens": 2000
}
},
"vector_store": {
"provider": "opensearch",
"config": {
"collection_name": "mem0",
"host": "your-opensearch-domain.us-west-2.es.amazonaws.com",
"port": 443,
"http_auth": auth,
"embedding_model_dims": 1024,
"connection_class": RequestsHttpConnection,
"pool_maxsize": 20,
"use_ssl": True,
"verify_certs": True
}
}
}
# Initialize memory system
m = Memory.from_config(config)
```
## Usage
#### Add a memory:
```python
messages = [
{"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
{"role": "assistant", "content": "How about a thriller movies? They can be quite engaging."},
{"role": "user", "content": "I'm not a big fan of thriller movies but I love sci-fi movies."},
{"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."}
]
# Store inferred memories (default behavior)
result = m.add(messages, user_id="alice", metadata={"category": "movie_recommendations"})
```
#### Search a memory:
```python
query = "What kind of movies does Alice like?"
relevant_memories = m.search(query, user_id="alice")
```
#### Get all memories:
```python
all_memories = m.get_all(user_id="alice")
```
#### Get a specific memory:
```python
memory = m.get(memory_id)  # memory_id comes from a previous add() or search() result
```
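To close the loop, retrieved memories can be folded into the next prompt. A minimal sketch (the helper name and prompt format are illustrative, not part of the SDK):

```python
def build_prompt(m, question: str, user_id: str) -> str:
    # Hypothetical helper: fetch relevant memories and prepend them as context.
    results = m.search(question, user_id=user_id)
    # Depending on SDK version, search returns a list or a {"results": [...]} dict.
    memories = results.get("results", results) if isinstance(results, dict) else results
    facts = "\n".join(f"- {r['memory']}" for r in memories)
    return f"Known facts about the user:\n{facts}\n\nQuestion: {question}"
```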
---
## Conclusion
With Mem0 and AWS services like Bedrock and OpenSearch, you can build intelligent AI companions that remember, adapt, and personalize their responses over time. This makes them ideal for long-term assistants, tutors, or support bots with persistent memory and natural conversation abilities.

View File

@@ -33,6 +33,10 @@ class BaseEmbedderConfig(ABC):
memory_search_embedding_type: Optional[str] = None,
# LM Studio specific
lmstudio_base_url: Optional[str] = "http://localhost:1234/v1",
+# AWS Bedrock specific
+aws_access_key_id: Optional[str] = None,
+aws_secret_access_key: Optional[str] = None,
+aws_region: Optional[str] = "us-west-2",
):
"""
Initializes a configuration class instance for the Embeddings.
@@ -92,3 +96,8 @@ class BaseEmbedderConfig(ABC):
# LM Studio specific
self.lmstudio_base_url = lmstudio_base_url
+# AWS Bedrock specific
+self.aws_access_key_id = aws_access_key_id
+self.aws_secret_access_key = aws_secret_access_key
+self.aws_region = aws_region
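These keys can then be passed through the normal config dict instead of environment variables; a sketch, assuming the embedder factory forwards extra keys to `BaseEmbedderConfig` (credential values are placeholders):

```python
config = {
    "embedder": {
        "provider": "aws_bedrock",
        "config": {
            "model": "amazon.titan-embed-text-v2:0",
            "aws_region": "us-west-2",
            # Optional: omit these to fall back to the AWS_* environment variables.
            "aws_access_key_id": "AKIA...",
            "aws_secret_access_key": "...",
        },
    }
}
```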

View File

@@ -41,6 +41,10 @@ class BaseLlmConfig(ABC):
xai_base_url: Optional[str] = None,
# LM Studio specific
lmstudio_base_url: Optional[str] = "http://localhost:1234/v1",
+# AWS Bedrock specific
+aws_access_key_id: Optional[str] = None,
+aws_secret_access_key: Optional[str] = None,
+aws_region: Optional[str] = "us-west-2",
):
"""
Initializes a configuration class instance for the LLM.
@@ -123,3 +127,8 @@ class BaseLlmConfig(ABC):
# LM Studio specific
self.lmstudio_base_url = lmstudio_base_url
+# AWS Bedrock specific
+self.aws_access_key_id = aws_access_key_id
+self.aws_secret_access_key = aws_secret_access_key
+self.aws_region = aws_region
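The LLM side mirrors this: config values take precedence over the `AWS_*` environment variables in the `aws_bedrock` LLM class further below. A short sketch with placeholder values:

```python
config = {
    "llm": {
        "provider": "aws_bedrock",
        "config": {
            "model": "anthropic.claude-3-5-haiku-20241022-v1:0",
            "aws_region": "us-west-2",  # falls back to AWS_REGION, then "us-west-2"
        },
    }
}
```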

View File

@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union, Type
from pydantic import BaseModel, Field, model_validator
@@ -7,14 +7,33 @@ class OpenSearchConfig(BaseModel):
collection_name: str = Field("mem0", description="Name of the index")
host: str = Field("localhost", description="OpenSearch host")
port: int = Field(9200, description="OpenSearch port")
-user: Optional[str] = Field(None, description="Username for authentication")
-password: Optional[str] = Field(None, description="Password for authentication")
-api_key: Optional[str] = Field(None, description="API key for authentication (if applicable)")
-embedding_model_dims: int = Field(1536, description="Dimension of the embedding vector")
-verify_certs: bool = Field(False, description="Verify SSL certificates (default False for OpenSearch)")
-use_ssl: bool = Field(False, description="Use SSL for connection (default False for OpenSearch)")
-auto_create_index: bool = Field(True, description="Automatically create index during initialization")
-http_auth: Optional[object] = Field(None, description="HTTP authentication method / AWS SigV4")
+user: Optional[str] = Field(
+None, description="Username for authentication"
+)
+password: Optional[str] = Field(
+None, description="Password for authentication"
+)
+api_key: Optional[str] = Field(
+None, description="API key for authentication (if applicable)"
+)
+embedding_model_dims: int = Field(
+1536, description="Dimension of the embedding vector"
+)
+verify_certs: bool = Field(
+False, description="Verify SSL certificates (default False for OpenSearch)"
+)
+use_ssl: bool = Field(
+False, description="Use SSL for connection (default False for OpenSearch)"
+)
+http_auth: Optional[object] = Field(
+None, description="HTTP authentication method / AWS SigV4"
+)
+connection_class: Optional[Union[str, Type]] = Field(
+"RequestsHttpConnection", description="Connection class for OpenSearch"
+)
+pool_maxsize: int = Field(
+20, description="Maximum number of connections in the pool"
+)
@model_validator(mode="before")
@classmethod
@@ -22,11 +41,7 @@ class OpenSearchConfig(BaseModel):
# Check if host is provided
if not values.get("host"):
raise ValueError("Host must be provided for OpenSearch")
-# Authentication: Either API key or user/password must be provided
-if not any([values.get("api_key"), (values.get("user") and values.get("password")), values.get("http_auth")]):
-raise ValueError("Either api_key or user/password must be provided for OpenSearch authentication")
return values
@model_validator(mode="before")
@@ -37,6 +52,7 @@ class OpenSearchConfig(BaseModel):
extra_fields = input_fields - allowed_fields
if extra_fields:
raise ValueError(
f"Extra fields not allowed: {', '.join(extra_fields)}. " f"Allowed fields: {', '.join(allowed_fields)}"
f"Extra fields not allowed: {', '.join(extra_fields)}. "
f"Allowed fields: {', '.join(allowed_fields)}"
)
return values
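Instantiating the model directly shows the new fields in use; a sketch assuming the module path `mem0.configs.vector_stores.opensearch` and an AWS SigV4 signer for `http_auth`:

```python
import boto3
from opensearchpy import AWSV4SignerAuth, RequestsHttpConnection

from mem0.configs.vector_stores.opensearch import OpenSearchConfig

auth = AWSV4SignerAuth(boto3.Session().get_credentials(), "us-west-2", "aoss")
cfg = OpenSearchConfig(
    collection_name="mem0",
    host="your-domain.us-west-2.aoss.amazonaws.com",
    port=443,
    http_auth=auth,
    embedding_model_dims=1024,
    connection_class=RequestsHttpConnection,
    use_ssl=True,
    verify_certs=True,
)
```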

View File

@@ -1,4 +1,5 @@
import json
+import os
from typing import Literal, Optional
try:
@@ -22,7 +23,26 @@ class AWSBedrockEmbedding(EmbeddingBase):
super().__init__(config)
self.config.model = self.config.model or "amazon.titan-embed-text-v1"
-self.client = boto3.client("bedrock-runtime")
+# Get AWS config from environment variables or use defaults
+aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", "")
+aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "")
+aws_region = os.environ.get("AWS_REGION", "us-west-2")
+# Check if AWS config is provided in the config
+if hasattr(self.config, "aws_access_key_id"):
+aws_access_key = self.config.aws_access_key_id
+if hasattr(self.config, "aws_secret_access_key"):
+aws_secret_key = self.config.aws_secret_access_key
+if hasattr(self.config, "aws_region"):
+aws_region = self.config.aws_region
+self.client = boto3.client(
+"bedrock-runtime",
+region_name=aws_region,
+aws_access_key_id=aws_access_key if aws_access_key else None,
+aws_secret_access_key=aws_secret_key if aws_secret_key else None,
+)
def _normalize_vector(self, embeddings):
"""Normalize the embedding to a unit vector."""

View File

@@ -1,4 +1,6 @@
import json
+import os
+import re
from typing import Any, Dict, List, Optional
try:
@@ -9,6 +11,14 @@ except ImportError:
from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase
PROVIDERS = ["ai21", "amazon", "anthropic", "cohere", "meta", "mistral", "stability", "writer"]
def extract_provider(model: str) -> str:
for provider in PROVIDERS:
if re.search(rf"\b{re.escape(provider)}\b", model):
return provider
raise ValueError(f"Unknown provider in model: {model}")
class AWSBedrockLLM(LLMBase):
def __init__(self, config: Optional[BaseLlmConfig] = None):
@@ -16,7 +26,27 @@ class AWSBedrockLLM(LLMBase):
if not self.config.model:
self.config.model = "anthropic.claude-3-5-sonnet-20240620-v1:0"
-self.client = boto3.client("bedrock-runtime")
+# Get AWS config from environment variables or use defaults
+aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", "")
+aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "")
+aws_region = os.environ.get("AWS_REGION", "us-west-2")
+# Check if AWS config is provided in the config
+if hasattr(self.config, "aws_access_key_id"):
+aws_access_key = self.config.aws_access_key_id
+if hasattr(self.config, "aws_secret_access_key"):
+aws_secret_key = self.config.aws_secret_access_key
+if hasattr(self.config, "aws_region"):
+aws_region = self.config.aws_region
+self.client = boto3.client(
+"bedrock-runtime",
+region_name=aws_region,
+aws_access_key_id=aws_access_key if aws_access_key else None,
+aws_secret_access_key=aws_secret_key if aws_secret_key else None,
+)
self.model_kwargs = {
"temperature": self.config.temperature,
"max_tokens_to_sample": self.config.max_tokens,
@@ -34,13 +64,14 @@ class AWSBedrockLLM(LLMBase):
Returns:
str: A formatted string combining all messages, structured with roles capitalized and separated by newlines.
"""
formatted_messages = []
for message in messages:
role = message["role"].capitalize()
content = message["content"]
formatted_messages.append(f"\n\n{role}: {content}")
return "".join(formatted_messages) + "\n\nAssistant:"
return "\n\nHuman: " + "".join(formatted_messages) + "\n\nAssistant:"
def _parse_response(self, response, tools) -> str:
"""
@@ -68,8 +99,9 @@ class AWSBedrockLLM(LLMBase):
return processed_response
-response_body = json.loads(response["body"].read().decode())
-return response_body.get("completion", "")
+response_body = response.get("body").read().decode()
+response_json = json.loads(response_body)
+return response_json.get("content", [{"text": ""}])[0].get("text", "")
def _prepare_input(
self,
@@ -113,9 +145,9 @@ class AWSBedrockLLM(LLMBase):
input_body = {
"inputText": prompt,
"textGenerationConfig": {
"maxTokenCount": model_kwargs.get("max_tokens_to_sample"),
"topP": model_kwargs.get("top_p"),
"temperature": model_kwargs.get("temperature"),
"maxTokenCount": self.model_kwargs["max_tokens_to_sample"] or self.model_kwargs["max_tokens"] or 5000,
"topP": self.model_kwargs["top_p"] or 0.9,
"temperature": self.model_kwargs["temperature"] or 0.1,
},
}
input_body["textGenerationConfig"] = {
@@ -206,15 +238,40 @@ class AWSBedrockLLM(LLMBase):
else:
# Use invoke_model method when no tools are provided
prompt = self._format_messages(messages)
-provider = self.model.split(".")[0]
-input_body = self._prepare_input(provider, self.config.model, prompt, **self.model_kwargs)
+provider = extract_provider(self.config.model)
+input_body = self._prepare_input(provider, self.config.model, prompt, model_kwargs=self.model_kwargs)
body = json.dumps(input_body)
-response = self.client.invoke_model(
-body=body,
-modelId=self.model,
-accept="application/json",
-contentType="application/json",
+if provider == "anthropic" or provider == "deepseek":
+input_body = {
+"messages": [
+{
+"role": "user",
+"content": [{"type": "text", "text": prompt}]
+}
+],
+"max_tokens": self.model_kwargs["max_tokens_to_sample"] or self.model_kwargs["max_tokens"] or 5000,
+"temperature": self.model_kwargs["temperature"] or 0.1,
+"top_p": self.model_kwargs["top_p"] or 0.9,
+"anthropic_version": "bedrock-2023-05-31",
+}
+body = json.dumps(input_body)
+response = self.client.invoke_model(
+body=body,
+modelId=self.config.model,
+accept="application/json",
+contentType="application/json",
+)
+else:
+response = self.client.invoke_model(
+body=body,
+modelId=self.config.model,
+accept="application/json",
+contentType="application/json",
+)
return self._parse_response(response, tools)
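The word-boundary regex matters because Bedrock model ids are not always `provider.model`. A quick illustration of `extract_provider` (the second id is a hypothetical cross-region inference profile):

```python
from mem0.llms.aws_bedrock import extract_provider

extract_provider("anthropic.claude-3-5-haiku-20241022-v1:0")
# -> "anthropic"
extract_provider("us.anthropic.claude-3-5-sonnet-20240620-v1:0")
# -> "anthropic"; the old self.model.split(".")[0] would have returned "us"
extract_provider("amazon.titan-text-express-v1")
# -> "amazon"
```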

View File

@@ -69,17 +69,14 @@ class Memory(MemoryBase):
self.enable_graph = True
else:
self.graph = None
self.config.vector_store.config.collection_name = "mem0migrations"
if self.config.vector_store.provider in ["faiss", "qdrant"]:
provider_path = f"migrations_{self.config.vector_store.provider}"
self.config.vector_store.config.path = os.path.join(mem0_dir, provider_path)
os.makedirs(self.config.vector_store.config.path, exist_ok=True)
-self._telemetry_vector_store = VectorStoreFactory.create(
-self.config.vector_store.provider, self.config.vector_store.config
-)
capture_event("mem0.init", self, {"sync_type": "sync"})
@classmethod

View File

@@ -38,7 +38,7 @@ def get_or_create_user_id(vector_store):
# Try to get existing user_id from vector store
try:
-existing = vector_store.get(vector_id=VECTOR_ID)
+existing = vector_store.get(vector_id=user_id)
if existing and hasattr(existing, "payload") and existing.payload and "user_id" in existing.payload:
return existing.payload["user_id"]
except Exception:
@@ -48,7 +48,7 @@ def get_or_create_user_id(vector_store):
try:
dims = getattr(vector_store, "embedding_model_dims", 1536)
vector_store.insert(
-vectors=[[0.0] * dims], payloads=[{"user_id": user_id, "type": "user_identity"}], ids=[VECTOR_ID]
+vectors=[[0.1] * dims], payloads=[{"user_id": user_id, "type": "user_identity"}], ids=[user_id]
)
except Exception:
pass

View File

@@ -1,5 +1,6 @@
import logging
from typing import Any, Dict, List, Optional
+import time
try:
from opensearchpy import OpenSearch, RequestsHttpConnection
@@ -34,28 +35,26 @@ class OpenSearchDB(VectorStoreBase):
use_ssl=config.use_ssl,
verify_certs=config.verify_certs,
connection_class=RequestsHttpConnection,
pool_maxsize=20
)
self.collection_name = config.collection_name
self.embedding_model_dims = config.embedding_model_dims
-# Create index only if auto_create_index is True
-if config.auto_create_index:
-self.create_index()
+self.create_col(self.collection_name, self.embedding_model_dims)
def create_index(self) -> None:
"""Create OpenSearch index with proper mappings if it doesn't exist."""
index_settings = {
"settings": {
"index": {"number_of_replicas": 1, "number_of_shards": 5, "refresh_interval": "1s", "knn": True}
"index": {"number_of_replicas": 1, "number_of_shards": 5, "refresh_interval": "10s", "knn": True}
},
"mappings": {
"properties": {
"text": {"type": "text"},
"vector": {
"vector_field": {
"type": "knn_vector",
"dimension": self.embedding_model_dims,
"method": {"engine": "lucene", "name": "hnsw", "space_type": "cosinesimil"},
"method": {"engine": "nmslib", "name": "hnsw", "space_type": "cosinesimil"},
},
"metadata": {"type": "object", "properties": {"user_id": {"type": "keyword"}}},
}
@@ -71,12 +70,15 @@ class OpenSearchDB(VectorStoreBase):
def create_col(self, name: str, vector_size: int) -> None:
"""Create a new collection (index in OpenSearch)."""
index_settings = {
"settings": {
"index.knn": True
},
"mappings": {
"properties": {
"vector": {
"vector_field": {
"type": "knn_vector",
"dimension": vector_size,
"method": {"engine": "lucene", "name": "hnsw", "space_type": "cosinesimil"},
"method": {"engine": "nmslib", "name": "hnsw", "space_type": "cosinesimil"},
},
"payload": {"type": "object"},
"id": {"type": "keyword"},
@@ -88,6 +90,24 @@ class OpenSearchDB(VectorStoreBase):
self.client.indices.create(index=name, body=index_settings)
logger.info(f"Created index {name}")
+# Wait for index to be ready
+max_retries = 60 # 60 seconds timeout
+retry_count = 0
+while retry_count < max_retries:
+try:
+# Check if index is ready by attempting a simple search
+self.client.search(index=name, body={"query": {"match_all": {}}})
+logger.info(f"Index {name} is ready")
+time.sleep(1)
+return
+except Exception:
+retry_count += 1
+if retry_count == max_retries:
+raise TimeoutError(
+f"Index {name} creation timed out after {max_retries} seconds"
+)
+time.sleep(0.5)
def insert(
self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None
) -> List[OutputData]:
@@ -98,74 +118,161 @@ class OpenSearchDB(VectorStoreBase):
if payloads is None:
payloads = [{} for _ in range(len(vectors))]
-actions = []
for i, (vec, id_) in enumerate(zip(vectors, ids)):
-action = {
-"_index": self.collection_name,
-"_id": id_,
-"_source": {
-"vector": vec,
-"metadata": payloads[i], # Store metadata in the metadata field
-},
+body = {
+"vector_field": vec,
+"payload": payloads[i],
+"id": id_,
}
-actions.append(action)
-bulk(self.client, actions)
+self.client.index(index=self.collection_name, body=body)
results = []
for i, id_ in enumerate(ids):
results.append(OutputData(id=id_, score=1.0, payload=payloads[i]))
return results
def search(
self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
) -> List[OutputData]:
"""Search for similar vectors using OpenSearch k-NN search with pre-filtering."""
search_query = {
"size": limit,
"query": {
"knn": {
"vector": {
"vector": vectors,
"k": limit,
}
"""Search for similar vectors using OpenSearch k-NN search with optional filters."""
# Base KNN query
knn_query = {
"knn": {
"vector_field": {
"vector": vectors,
"k": limit * 2,
}
},
}
}
# Start building the full query
query_body = {
"size": limit * 2,
"query": None
}
# Prepare filter conditions if applicable
filter_clauses = []
if filters:
filter_conditions = [{"term": {f"metadata.{key}": value}} for key, value in filters.items()]
search_query["query"]["knn"]["vector"]["filter"] = {"bool": {"filter": filter_conditions}}
for key in ["user_id", "run_id", "agent_id"]:
value = filters.get(key)
if value:
filter_clauses.append({
"term": {f"payload.{key}.keyword": value}
})
response = self.client.search(index=self.collection_name, body=search_query)
# Combine knn with filters if needed
if filter_clauses:
query_body["query"] = {
"bool": {
"must": knn_query,
"filter": filter_clauses
}
}
else:
query_body["query"] = knn_query
# Execute search
response = self.client.search(index=self.collection_name, body=query_body)
hits = response["hits"]["hits"]
results = [
OutputData(id=hit["_id"], score=hit["_score"], payload=hit["_source"].get("metadata", {}))
for hit in response["hits"]["hits"]
OutputData(
id=hit["_source"].get("id"),
score=hit["_score"],
payload=hit["_source"].get("payload", {})
)
for hit in hits
]
return results
def delete(self, vector_id: str) -> None:
"""Delete a vector by ID."""
self.client.delete(index=self.collection_name, id=vector_id)
"""Delete a vector by custom ID."""
# First, find the document by custom ID
search_query = {
"query": {
"term": {
"id": vector_id
}
}
}
response = self.client.search(index=self.collection_name, body=search_query)
hits = response.get("hits", {}).get("hits", [])
if not hits:
return
opensearch_id = hits[0]["_id"]
# Delete using the actual document ID
self.client.delete(index=self.collection_name, id=opensearch_id)
def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict] = None) -> None:
"""Update a vector and its payload."""
"""Update a vector and its payload using the custom 'id' field."""
# First, find the document by custom ID
search_query = {
"query": {
"term": {
"id": vector_id
}
}
}
response = self.client.search(index=self.collection_name, body=search_query)
hits = response.get("hits", {}).get("hits", [])
if not hits:
return
opensearch_id = hits[0]["_id"] # The actual document ID in OpenSearch
# Prepare updated fields
doc = {}
if vector is not None:
doc["vector"] = vector
doc["vector_field"] = vector
if payload is not None:
doc["metadata"] = payload
doc["payload"] = payload
if doc:
+try:
+response = self.client.update(index=self.collection_name, id=opensearch_id, body={"doc": doc})
+except Exception:
+pass
-self.client.update(index=self.collection_name, id=vector_id, body={"doc": doc})
def get(self, vector_id: str) -> Optional[OutputData]:
"""Retrieve a vector by ID."""
try:
-response = self.client.get(index=self.collection_name, id=vector_id)
-return OutputData(id=response["_id"], score=1.0, payload=response["_source"].get("metadata", {}))
+# First check if index exists
+if not self.client.indices.exists(index=self.collection_name):
+logger.info(f"Index {self.collection_name} does not exist, creating it...")
+self.create_col(self.collection_name, self.embedding_model_dims)
+return None
+search_query = {
+"query": {
+"term": {
+"id": vector_id
+}
+}
+}
+response = self.client.search(index=self.collection_name, body=search_query)
+hits = response["hits"]["hits"]
+if not hits:
+return None
+return OutputData(
+id=hits[0]["_source"].get("id"),
+score=1.0,
+payload=hits[0]["_source"].get("payload", {})
+)
except Exception as e:
-logger.error(f"Error retrieving vector {vector_id}: {e}")
+logger.error(f"Error retrieving vector {vector_id}: {str(e)}")
return None
def list_cols(self) -> List[str]:
@@ -180,28 +287,52 @@ class OpenSearchDB(VectorStoreBase):
"""Get information about a collection (index)."""
return self.client.indices.get(index=name)
-def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[List[OutputData]]:
-"""List all memories."""
-query = {"query": {"match_all": {}}}
+def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[OutputData]:
-if filters:
-query["query"] = {
-"bool": {"must": [{"term": {f"metadata.{key}": value}} for key, value in filters.items()]}
+try:
+"""List all memories with optional filters."""
+query: Dict = {
+"query": {
+"match_all": {}
+}
+}
-if limit:
-query["size"] = limit
+filter_clauses = []
+if filters:
+for key in ["user_id", "run_id", "agent_id"]:
+value = filters.get(key)
+if value:
+filter_clauses.append({
+"term": {f"payload.{key}.keyword": value}
+})
+if filter_clauses:
+query["query"] = {
+"bool": {
+"filter": filter_clauses
+}
+}
+if limit:
+query["size"] = limit
+response = self.client.search(index=self.collection_name, body=query)
+hits = response["hits"]["hits"]
+return [[
+OutputData(
+id=hit["_source"].get("id"),
+score=1.0,
+payload=hit["_source"].get("payload", {})
+)
+for hit in hits
+]]
+except Exception:
+return []
-response = self.client.search(index=self.collection_name, body=query)
-return [
-[
-OutputData(id=hit["_id"], score=1.0, payload=hit["_source"].get("metadata", {}))
-for hit in response["hits"]["hits"]
-]
-]
def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...")
self.delete_col()
-self.create_index()
+self.create_col(self.collection_name, self.embedding_model_dims)
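For reference, the request body that the rewritten `search` sends combines the k-NN clause with term filters on keyword subfields; roughly (values illustrative):

```python
query_body = {
    "size": 10,  # limit * 2 in the code above
    "query": {
        "bool": {
            "must": {"knn": {"vector_field": {"vector": [0.1] * 1024, "k": 10}}},
            "filter": [{"term": {"payload.user_id.keyword": "alice"}}],
        }
    },
}
```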

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "mem0ai"
version = "0.1.98"
version = "0.1.99"
description = "Long-term memory for AI Agents"
authors = ["Mem0 <founders@mem0.ai>"]
exclude = [