diff --git a/docs/components/embedders/models/aws_bedrock.mdx b/docs/components/embedders/models/aws_bedrock.mdx index 470963ef..389fa655 100644 --- a/docs/components/embedders/models/aws_bedrock.mdx +++ b/docs/components/embedders/models/aws_bedrock.mdx @@ -25,7 +25,7 @@ from mem0 import Memory os.environ["OPENAI_API_KEY"] = "your-openai-api-key" # AWS credentials -os.environ["AWS_REGION"] = "us-east-1" +os.environ["AWS_REGION"] = "us-west-2" os.environ["AWS_ACCESS_KEY_ID"] = "your-access-key" os.environ["AWS_SECRET_ACCESS_KEY"] = "your-secret-key" @@ -33,7 +33,7 @@ config = { "embedder": { "provider": "aws_bedrock", "config": { - "model": "amazon.titan-embed-text-v1" + "model": "amazon.titan-embed-text-v2:0" } } } diff --git a/docs/components/llms/models/aws_bedrock.mdx b/docs/components/llms/models/aws_bedrock.mdx index 5561e698..f0ff0a75 100644 --- a/docs/components/llms/models/aws_bedrock.mdx +++ b/docs/components/llms/models/aws_bedrock.mdx @@ -15,16 +15,15 @@ title: AWS Bedrock import os from mem0 import Memory -os.environ["OPENAI_API_KEY"] = "your-api-key" # used for embedding model -os.environ['AWS_REGION'] = 'us-east-1' -os.environ["AWS_ACCESS_KEY"] = "xx" +os.environ['AWS_REGION'] = 'us-west-2' +os.environ["AWS_ACCESS_KEY_ID"] = "xx" os.environ["AWS_SECRET_ACCESS_KEY"] = "xx" config = { "llm": { "provider": "aws_bedrock", "config": { - "model": "arn:aws:bedrock:us-east-1:123456789012:model/your-model-name", + "model": "anthropic.claude-3-5-haiku-20241022-v1:0", "temperature": 0.2, "max_tokens": 2000, } diff --git a/docs/components/vectordbs/dbs/opensearch.mdx b/docs/components/vectordbs/dbs/opensearch.mdx index e5d8f7de..4c0a7290 100644 --- a/docs/components/vectordbs/dbs/opensearch.mdx +++ b/docs/components/vectordbs/dbs/opensearch.mdx @@ -1,59 +1,75 @@ -[OpenSearch](https://opensearch.org/) is an open-source, enterprise-grade search and observability suite that brings order to unstructured data at scale. OpenSearch supports k-NN (k-Nearest Neighbors) and allows you to store and retrieve high-dimensional vector embeddings efficiently. +[OpenSearch](https://opensearch.org/) is an enterprise-grade search and observability suite that brings order to unstructured data at scale. OpenSearch supports k-NN (k-Nearest Neighbors) and allows you to store and retrieve high-dimensional vector embeddings efficiently. ### Installation OpenSearch support requires additional dependencies. Install them with: ```bash -pip install opensearch>=2.8.0 +pip install opensearch-py ``` +### Prerequisites + +Before using OpenSearch with Mem0, you need to set up a collection in AWS OpenSearch Service. 
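+
+If you prefer to script this step, the collection can also be created with boto3. A minimal sketch, assuming the `opensearchserverless` API and that the required encryption, network, and data-access policies already exist in your account (the collection name is illustrative):
+
+```python
+import boto3
+
+# Create a serverless vector-search collection; reuse this name as
+# `collection_name` in the Mem0 config below.
+aoss = boto3.client("opensearchserverless", region_name="us-west-2")
+aoss.create_collection(name="mem0", type="VECTORSEARCH")
+```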
+ +#### AWS OpenSearch Service +You can create a collection through the AWS Console: +- Navigate to [OpenSearch Service Console](https://console.aws.amazon.com/aos/home) +- Click "Create collection" +- Select "Serverless collection" and then enable "Vector search" capabilities +- Once created, note the endpoint URL (host) for your configuration + + ### Usage ```python import os from mem0 import Memory +import boto3 +from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth -os.environ["OPENAI_API_KEY"] = "sk-xx" +# For AWS OpenSearch Service with IAM authentication +region = 'us-west-2' +service = 'aoss' +credentials = boto3.Session().get_credentials() +auth = AWSV4SignerAuth(credentials, region, service) config = { "vector_store": { "provider": "opensearch", "config": { "collection_name": "mem0", - "host": "localhost", - "port": 9200, - "embedding_model_dims": 1536 + "host": "your-domain.us-west-2.aoss.amazonaws.com", + "port": 443, + "http_auth": auth, + "embedding_model_dims": 1024, + "connection_class": RequestsHttpConnection, + "pool_maxsize": 20, + "use_ssl": True, + "verify_certs": True } } } +``` +### Add Memories + +```python m = Memory.from_config(config) messages = [ {"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"}, {"role": "assistant", "content": "How about a thriller movies? They can be quite engaging."}, - {"role": "user", "content": "I’m not a big fan of thriller movies but I love sci-fi movies."}, + {"role": "user", "content": "I'm not a big fan of thriller movies but I love sci-fi movies."}, {"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."} ] m.add(messages, user_id="alice", metadata={"category": "movies"}) ``` -### Config +### Search Memories -Let's see the available parameters for the `opensearch` config: - -| Parameter | Description | Default Value | -| ---------------------- | -------------------------------------------------- | ------------- | -| `collection_name` | The name of the index to store the vectors | `mem0` | -| `embedding_model_dims` | Dimensions of the embedding model | `1536` | -| `host` | The host where the OpenSearch server is running | `localhost` | -| `port` | The port where the OpenSearch server is running | `9200` | -| `api_key` | API key for authentication | `None` | -| `user` | Username for basic authentication | `None` | -| `password` | Password for basic authentication | `None` | -| `verify_certs` | Whether to verify SSL certificates | `False` | -| `auto_create_index` | Whether to automatically create the index | `True` | -| `use_ssl` | Whether to use SSL for connection | `False` | +```python +results = m.search("What kind of movies does Alice like?", user_id="alice") +``` ### Features diff --git a/docs/docs.json b/docs/docs.json index e93cfff1..e6891007 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -200,6 +200,7 @@ "icon": "lightbulb", "pages": [ "examples", + "examples/aws_example", "examples/mem0-demo", "examples/ai_companion_js", "examples/eliza_os", diff --git a/docs/examples/aws_example.mdx b/docs/examples/aws_example.mdx new file mode 100644 index 00000000..ce7b5f42 --- /dev/null +++ b/docs/examples/aws_example.mdx @@ -0,0 +1,120 @@ +--- +title: AWS Bedrock and AOSS +--- + + + +This example demonstrates how to configure and use the `mem0ai` SDK with **AWS Bedrock** and **OpenSearch Service (AOSS)** for persistent memory capabilities in Python. 
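+
+Before wiring everything together, it can help to confirm that your account can actually reach Bedrock. A quick sanity check (a sketch; assumes Bedrock model access has been granted in `us-west-2`):
+
+```python
+import boto3
+
+# List the foundation models visible to this account in the region
+bedrock = boto3.client("bedrock", region_name="us-west-2")
+model_ids = [m["modelId"] for m in bedrock.list_foundation_models()["modelSummaries"]]
+print("amazon.titan-embed-text-v2:0" in model_ids)
+```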
+
+## Installation
+
+Install the required dependencies:
+
+```bash
+pip install mem0ai boto3 opensearch-py
+```
+
+## Environment Setup
+
+Set your AWS environment variables:
+
+```python
+import os
+
+# Set these in your environment or notebook
+os.environ['AWS_REGION'] = 'us-west-2'
+os.environ['AWS_ACCESS_KEY_ID'] = 'AK00000000000000000'
+os.environ['AWS_SECRET_ACCESS_KEY'] = 'AS00000000000000000'
+
+# Confirm the region is set (avoid printing credentials)
+print(os.environ['AWS_REGION'])
+```
+
+## Configuration
+
+This sets up Mem0 with AWS Bedrock for the embedder and LLM, and OpenSearch Serverless as the vector store.
+
+```python
+import boto3
+from opensearchpy import RequestsHttpConnection, AWSV4SignerAuth
+from mem0 import Memory
+
+region = 'us-west-2'
+service = 'aoss'
+credentials = boto3.Session().get_credentials()
+auth = AWSV4SignerAuth(credentials, region, service)
+
+config = {
+    "embedder": {
+        "provider": "aws_bedrock",
+        "config": {
+            "model": "amazon.titan-embed-text-v2:0"
+        }
+    },
+    "llm": {
+        "provider": "aws_bedrock",
+        "config": {
+            "model": "anthropic.claude-3-5-haiku-20241022-v1:0",
+            "temperature": 0.1,
+            "max_tokens": 2000
+        }
+    },
+    "vector_store": {
+        "provider": "opensearch",
+        "config": {
+            "collection_name": "mem0",
+            "host": "your-domain.us-west-2.aoss.amazonaws.com",
+            "port": 443,
+            "http_auth": auth,
+            "embedding_model_dims": 1024,
+            "connection_class": RequestsHttpConnection,
+            "pool_maxsize": 20,
+            "use_ssl": True,
+            "verify_certs": True
+        }
+    }
+}
+
+# Initialize memory system
+m = Memory.from_config(config)
+```
+
+## Usage
+
+#### Add a memory:
+
+```python
+messages = [
+    {"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
+    {"role": "assistant", "content": "How about a thriller movie? Thrillers can be quite engaging."},
+    {"role": "user", "content": "I'm not a big fan of thriller movies but I love sci-fi movies."},
+    {"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."}
+]
+
+# Store inferred memories (default behavior)
+result = m.add(messages, user_id="alice", metadata={"category": "movie_recommendations"})
+```
+
+#### Search a memory:
+
+```python
+relevant_memories = m.search("What kind of movies does Alice like?", user_id="alice")
+```
+
+#### Get all memories:
+
+```python
+all_memories = m.get_all(user_id="alice")
+```
+
+#### Get a specific memory:
+
+```python
+# Use an ID returned by a previous add() or search() call
+memory = m.get(memory_id)
+```
+
+---
+
+## Conclusion
+
+With Mem0 and AWS services like Bedrock and OpenSearch, you can build intelligent AI companions that remember, adapt, and personalize their responses over time. This makes them ideal for long-term assistants, tutors, and support bots that need persistent memory and natural conversation abilities.
diff --git a/mem0/configs/embeddings/base.py b/mem0/configs/embeddings/base.py
index 23a9c6f2..de8a19b9 100644
--- a/mem0/configs/embeddings/base.py
+++ b/mem0/configs/embeddings/base.py
@@ -33,6 +33,10 @@ class BaseEmbedderConfig(ABC):
         memory_search_embedding_type: Optional[str] = None,
         # LM Studio specific
         lmstudio_base_url: Optional[str] = "http://localhost:1234/v1",
+        # AWS Bedrock specific
+        aws_access_key_id: Optional[str] = None,
+        aws_secret_access_key: Optional[str] = None,
+        aws_region: Optional[str] = "us-west-2",
     ):
         """
         Initializes a configuration class instance for the Embeddings.
@@ -92,3 +96,8 @@ class BaseEmbedderConfig(ABC):
 
         # LM Studio specific
         self.lmstudio_base_url = lmstudio_base_url
+
+        # AWS Bedrock specific
+        self.aws_access_key_id = aws_access_key_id
+        self.aws_secret_access_key = aws_secret_access_key
+        self.aws_region = aws_region
diff --git a/mem0/configs/llms/base.py b/mem0/configs/llms/base.py
index 6f062eca..983d9f8e 100644
--- a/mem0/configs/llms/base.py
+++ b/mem0/configs/llms/base.py
@@ -41,6 +41,10 @@ class BaseLlmConfig(ABC):
         xai_base_url: Optional[str] = None,
         # LM Studio specific
         lmstudio_base_url: Optional[str] = "http://localhost:1234/v1",
+        # AWS Bedrock specific
+        aws_access_key_id: Optional[str] = None,
+        aws_secret_access_key: Optional[str] = None,
+        aws_region: Optional[str] = "us-west-2",
     ):
         """
         Initializes a configuration class instance for the LLM.
@@ -123,3 +127,8 @@ class BaseLlmConfig(ABC):
 
         # LM Studio specific
         self.lmstudio_base_url = lmstudio_base_url
+
+        # AWS Bedrock specific
+        self.aws_access_key_id = aws_access_key_id
+        self.aws_secret_access_key = aws_secret_access_key
+        self.aws_region = aws_region
diff --git a/mem0/configs/vector_stores/opensearch.py b/mem0/configs/vector_stores/opensearch.py
index 3a240061..8f158277 100644
--- a/mem0/configs/vector_stores/opensearch.py
+++ b/mem0/configs/vector_stores/opensearch.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union, Type
 
 from pydantic import BaseModel, Field, model_validator
 
@@ -7,14 +7,33 @@ class OpenSearchConfig(BaseModel):
     collection_name: str = Field("mem0", description="Name of the index")
     host: str = Field("localhost", description="OpenSearch host")
     port: int = Field(9200, description="OpenSearch port")
-    user: Optional[str] = Field(None, description="Username for authentication")
-    password: Optional[str] = Field(None, description="Password for authentication")
-    api_key: Optional[str] = Field(None, description="API key for authentication (if applicable)")
-    embedding_model_dims: int = Field(1536, description="Dimension of the embedding vector")
-    verify_certs: bool = Field(False, description="Verify SSL certificates (default False for OpenSearch)")
-    use_ssl: bool = Field(False, description="Use SSL for connection (default False for OpenSearch)")
-    auto_create_index: bool = Field(True, description="Automatically create index during initialization")
-    http_auth: Optional[object] = Field(None, description="HTTP authentication method / AWS SigV4")
+    user: Optional[str] = Field(
+        None, description="Username for authentication"
+    )
+    password: Optional[str] = Field(
+        None, description="Password for authentication"
+    )
+    api_key: Optional[str] = Field(
+        None, description="API key for authentication (if applicable)"
+    )
+    embedding_model_dims: int = Field(
+        1536, description="Dimension of the embedding vector"
+    )
+    verify_certs: bool = Field(
+        False, description="Verify SSL certificates (default False for OpenSearch)"
+    )
+    use_ssl: bool = Field(
+        False, description="Use SSL for connection (default False for OpenSearch)"
+    )
+    http_auth: Optional[object] = Field(
+        None, description="HTTP authentication method / AWS SigV4"
+    )
+    connection_class: Optional[Union[str, Type]] = Field(
+        "RequestsHttpConnection", description="Connection class for OpenSearch"
+    )
+    pool_maxsize: int = Field(
+        20, description="Maximum number of connections in the pool"
+    )
 
     @model_validator(mode="before")
     @classmethod
@@ -22,11 +41,7 @@ class OpenSearchConfig(BaseModel):
         # Check if host is provided
         if not values.get("host"):
             raise ValueError("Host must be provided for OpenSearch")
-
-        # Authentication: Either API key or user/password must be provided
-        if not any([values.get("api_key"), (values.get("user") and values.get("password")), values.get("http_auth")]):
-            raise ValueError("Either api_key or user/password must be provided for OpenSearch authentication")
-
+
         return values
 
     @model_validator(mode="before")
@@ -37,6 +52,7 @@ class OpenSearchConfig(BaseModel):
         extra_fields = input_fields - allowed_fields
         if extra_fields:
             raise ValueError(
-                f"Extra fields not allowed: {', '.join(extra_fields)}. " f"Allowed fields: {', '.join(allowed_fields)}"
+                f"Extra fields not allowed: {', '.join(extra_fields)}. "
+                f"Allowed fields: {', '.join(allowed_fields)}"
             )
         return values
diff --git a/mem0/embeddings/aws_bedrock.py b/mem0/embeddings/aws_bedrock.py
index 2fcf6df3..10116511 100644
--- a/mem0/embeddings/aws_bedrock.py
+++ b/mem0/embeddings/aws_bedrock.py
@@ -1,4 +1,5 @@
 import json
+import os
 from typing import Literal, Optional
 
 try:
@@ -22,7 +23,26 @@ class AWSBedrockEmbedding(EmbeddingBase):
         super().__init__(config)
 
         self.config.model = self.config.model or "amazon.titan-embed-text-v1"
-        self.client = boto3.client("bedrock-runtime")
+
+        # Get AWS config from environment variables or use defaults
+        aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", "")
+        aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "")
+        aws_region = os.environ.get("AWS_REGION", "us-west-2")
+
+        # Override with values from the config, but only when they are actually set
+        if getattr(self.config, "aws_access_key_id", None):
+            aws_access_key = self.config.aws_access_key_id
+        if getattr(self.config, "aws_secret_access_key", None):
+            aws_secret_key = self.config.aws_secret_access_key
+        if getattr(self.config, "aws_region", None):
+            aws_region = self.config.aws_region
+
+        self.client = boto3.client(
+            "bedrock-runtime",
+            region_name=aws_region,
+            aws_access_key_id=aws_access_key if aws_access_key else None,
+            aws_secret_access_key=aws_secret_key if aws_secret_key else None,
+        )
 
     def _normalize_vector(self, embeddings):
         """Normalize the embedding to a unit vector."""
diff --git a/mem0/llms/aws_bedrock.py b/mem0/llms/aws_bedrock.py
index 8d8bb01d..adf03762 100644
--- a/mem0/llms/aws_bedrock.py
+++ b/mem0/llms/aws_bedrock.py
@@ -1,4 +1,6 @@
 import json
+import os
+import re
 from typing import Any, Dict, List, Optional
 
 try:
@@ -9,6 +11,14 @@ except ImportError:
 from mem0.configs.llms.base import BaseLlmConfig
 from mem0.llms.base import LLMBase
 
+PROVIDERS = ["ai21", "amazon", "anthropic", "cohere", "meta", "mistral", "stability", "writer"]
+
+
+def extract_provider(model: str) -> str:
+    for provider in PROVIDERS:
+        if re.search(rf"\b{re.escape(provider)}\b", model):
+            return provider
+    raise ValueError(f"Unknown provider in model: {model}")
+
 
 class AWSBedrockLLM(LLMBase):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
@@ -16,7 +26,27 @@ class AWSBedrockLLM(LLMBase):
         super().__init__(config)
 
         if not self.config.model:
             self.config.model = "anthropic.claude-3-5-sonnet-20240620-v1:0"
-        self.client = boto3.client("bedrock-runtime")
+
+        # Get AWS config from environment variables or use defaults
+        aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", "")
+        aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "")
+        aws_region = os.environ.get("AWS_REGION", "us-west-2")
+
+        # Override with values from the config, but only when they are actually set
+        if getattr(self.config, "aws_access_key_id", None):
+            aws_access_key = self.config.aws_access_key_id
+        if getattr(self.config, "aws_secret_access_key", None):
+            aws_secret_key = self.config.aws_secret_access_key
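+        # If neither the config nor the environment supplies credentials, the
+        # boto3 client created below receives None values and falls back to its
+        # default credential chain (shared config files, instance or role credentials).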
+        if getattr(self.config, "aws_region", None):
+            aws_region = self.config.aws_region
+
+        self.client = boto3.client(
+            "bedrock-runtime",
+            region_name=aws_region,
+            aws_access_key_id=aws_access_key if aws_access_key else None,
+            aws_secret_access_key=aws_secret_key if aws_secret_key else None,
+        )
+
         self.model_kwargs = {
             "temperature": self.config.temperature,
             "max_tokens_to_sample": self.config.max_tokens,
@@ -34,13 +64,14 @@ class AWSBedrockLLM(LLMBase):
         Returns:
             str: A formatted string combining all messages, structured with roles capitalized and separated by newlines.
         """
+
         formatted_messages = []
         for message in messages:
             role = message["role"].capitalize()
             content = message["content"]
             formatted_messages.append(f"\n\n{role}: {content}")
 
-        return "".join(formatted_messages) + "\n\nAssistant:"
+        return "\n\nHuman: " + "".join(formatted_messages) + "\n\nAssistant:"
 
     def _parse_response(self, response, tools) -> str:
         """
@@ -68,8 +99,9 @@ class AWSBedrockLLM(LLMBase):
 
             return processed_response
 
-        response_body = json.loads(response["body"].read().decode())
-        return response_body.get("completion", "")
+        response_body = response.get("body").read().decode()
+        response_json = json.loads(response_body)
+        return (response_json.get("content") or [{"text": ""}])[0].get("text", "")
 
     def _prepare_input(
         self,
@@ -113,9 +145,9 @@ class AWSBedrockLLM(LLMBase):
         input_body = {
             "inputText": prompt,
             "textGenerationConfig": {
-                "maxTokenCount": model_kwargs.get("max_tokens_to_sample"),
-                "topP": model_kwargs.get("top_p"),
-                "temperature": model_kwargs.get("temperature"),
+                "maxTokenCount": model_kwargs.get("max_tokens_to_sample") or 5000,
+                "topP": model_kwargs.get("top_p") or 0.9,
+                "temperature": model_kwargs.get("temperature") or 0.1,
             },
         }
         input_body["textGenerationConfig"] = {
@@ -206,15 +238,26 @@ class AWSBedrockLLM(LLMBase):
         else:
             # Use invoke_model method when no tools are provided
             prompt = self._format_messages(messages)
-            provider = self.model.split(".")[0]
-            input_body = self._prepare_input(provider, self.config.model, prompt, **self.model_kwargs)
+            provider = extract_provider(self.config.model)
+
+            if provider == "anthropic" or provider == "deepseek":
+                # Claude-style models on Bedrock expect the messages API format
+                input_body = {
+                    "messages": [{"role": "user", "content": [{"type": "text", "text": prompt}]}],
+                    "max_tokens": self.model_kwargs.get("max_tokens_to_sample") or 5000,
+                    "temperature": self.model_kwargs.get("temperature") or 0.1,
+                    "top_p": self.model_kwargs.get("top_p") or 0.9,
+                    "anthropic_version": "bedrock-2023-05-31",
+                }
+            else:
+                input_body = self._prepare_input(provider, self.config.model, prompt, model_kwargs=self.model_kwargs)
             body = json.dumps(input_body)
 
-            response = self.client.invoke_model(
-                body=body,
-                modelId=self.model,
-                accept="application/json",
-                contentType="application/json",
+            response = self.client.invoke_model(
+                body=body,
+                modelId=self.config.model,
+                accept="application/json",
+                contentType="application/json",
             )
 
         return self._parse_response(response, tools)
diff --git a/mem0/memory/main.py b/mem0/memory/main.py
index fd02cf57..8ec74f8c 100644
--- a/mem0/memory/main.py
+++ b/mem0/memory/main.py
@@ -69,17 +69,14 @@ class Memory(MemoryBase):
             self.enable_graph = True
         else:
             self.graph = None
-        self.config.vector_store.config.collection_name = "mem0migrations"
         if self.config.vector_store.provider in ["faiss", "qdrant"]:
             provider_path = f"migrations_{self.config.vector_store.provider}"
             self.config.vector_store.config.path = os.path.join(mem0_dir, provider_path)
             os.makedirs(self.config.vector_store.config.path, exist_ok=True)
-
         self._telemetry_vector_store = VectorStoreFactory.create(
             self.config.vector_store.provider, self.config.vector_store.config
         )
-
         capture_event("mem0.init", self, {"sync_type": "sync"})
 
     @classmethod
diff --git a/mem0/memory/setup.py b/mem0/memory/setup.py
index 3527de37..13864179 100644
--- a/mem0/memory/setup.py
+++ b/mem0/memory/setup.py
@@ -38,7 +38,7 @@ def get_or_create_user_id(vector_store):
 
     # Try to get existing user_id from vector store
     try:
-        existing = vector_store.get(vector_id=VECTOR_ID)
+        existing = vector_store.get(vector_id=user_id)
         if existing and hasattr(existing, "payload") and existing.payload and "user_id" in existing.payload:
             return existing.payload["user_id"]
     except Exception:
@@ -48,7 +48,7 @@ def get_or_create_user_id(vector_store):
     try:
         dims = getattr(vector_store, "embedding_model_dims", 1536)
         vector_store.insert(
-            vectors=[[0.0] * dims], payloads=[{"user_id": user_id, "type": "user_identity"}], ids=[VECTOR_ID]
+            vectors=[[0.1] * dims], payloads=[{"user_id": user_id, "type": "user_identity"}], ids=[user_id]
         )
     except Exception:
         pass
diff --git a/mem0/vector_stores/opensearch.py b/mem0/vector_stores/opensearch.py
index 72d39976..21ee6377 100644
--- a/mem0/vector_stores/opensearch.py
+++ b/mem0/vector_stores/opensearch.py
@@ -1,5 +1,6 @@
 import logging
 from typing import Any, Dict, List, Optional
+import time
 
 try:
     from opensearchpy import OpenSearch, RequestsHttpConnection
@@ -34,28 +35,26 @@ class OpenSearchDB(VectorStoreBase):
             use_ssl=config.use_ssl,
             verify_certs=config.verify_certs,
             connection_class=RequestsHttpConnection,
+            pool_maxsize=config.pool_maxsize,
         )
         self.collection_name = config.collection_name
         self.embedding_model_dims = config.embedding_model_dims
-
-        # Create index only if auto_create_index is True
-        if config.auto_create_index:
-            self.create_index()
+        self.create_col(self.collection_name, self.embedding_model_dims)
 
     def create_index(self) -> None:
         """Create OpenSearch index with proper mappings if it doesn't exist."""
         index_settings = {
             "settings": {
-                "index": {"number_of_replicas": 1, "number_of_shards": 5, "refresh_interval": "1s", "knn": True}
+                "index": {"number_of_replicas": 1, "number_of_shards": 5, "refresh_interval": "10s", "knn": True}
             },
             "mappings": {
                 "properties": {
                     "text": {"type": "text"},
-                    "vector": {
+                    "vector_field": {
                         "type": "knn_vector",
                         "dimension": self.embedding_model_dims,
-                        "method": {"engine": "lucene", "name": "hnsw", "space_type": "cosinesimil"},
+                        "method": {"engine": "nmslib", "name": "hnsw", "space_type": "cosinesimil"},
                     },
                     "metadata": {"type": "object", "properties": {"user_id": {"type": "keyword"}}},
                 }
@@ -71,12 +70,15 @@ class OpenSearchDB(VectorStoreBase):
     def create_col(self, name: str, vector_size: int) -> None:
         """Create a new collection (index in OpenSearch)."""
         index_settings = {
+            "settings": {
+                "index.knn": True
+            },
             "mappings": {
                 "properties": {
-                    "vector": {
+                    "vector_field": {
                         "type": "knn_vector",
                         "dimension": vector_size,
-                        "method": {"engine": "lucene", "name": "hnsw", "space_type": "cosinesimil"},
+                        "method": {"engine": "nmslib", "name": "hnsw", "space_type": "cosinesimil"},
                     },
                     "payload": {"type": "object"},
                     "id": {"type": "keyword"},
@@ -88,6 +90,24 @@ class OpenSearchDB(VectorStoreBase):
             self.client.indices.create(index=name, body=index_settings)
             logger.info(f"Created index {name}")
 
+        # Wait for the index to be ready before first use
+        max_retries = 60  # ~30 seconds total with a 0.5s delay between attempts
+        retry_count = 0
+        while retry_count < max_retries:
+            try:
+                # Check if the index is ready by attempting a simple search
+                self.client.search(index=name, body={"query": {"match_all": {}}})
+                logger.info(f"Index {name} is ready")
+                time.sleep(1)  # brief settle delay before returning
+                return
+            except Exception:
+                retry_count += 1
+                if retry_count == max_retries:
+                    raise TimeoutError(
+                        f"Index {name} was not ready after {max_retries} attempts"
+                    )
+                time.sleep(0.5)
+
     def insert(
         self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None
    ) -> List[OutputData]:
@@ -98,74 +118,161 @@ class OpenSearchDB(VectorStoreBase):
         if payloads is None:
             payloads = [{} for _ in range(len(vectors))]
 
-        actions = []
         for i, (vec, id_) in enumerate(zip(vectors, ids)):
-            action = {
-                "_index": self.collection_name,
-                "_id": id_,
-                "_source": {
-                    "vector": vec,
-                    "metadata": payloads[i],  # Store metadata in the metadata field
-                },
+            body = {
+                "vector_field": vec,
+                "payload": payloads[i],
+                "id": id_,
             }
-            actions.append(action)
-
-        bulk(self.client, actions)
+            self.client.index(index=self.collection_name, body=body)
 
         results = []
         for i, id_ in enumerate(ids):
             results.append(OutputData(id=id_, score=1.0, payload=payloads[i]))
+        return results
 
     def search(
         self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
     ) -> List[OutputData]:
-        """Search for similar vectors using OpenSearch k-NN search with pre-filtering."""
-        search_query = {
-            "size": limit,
-            "query": {
-                "knn": {
-                    "vector": {
-                        "vector": vectors,
-                        "k": limit,
-                    }
-                }
-            },
-        }
+        """Search for similar vectors using OpenSearch k-NN search with optional filters."""
+
+        # Base KNN query; over-fetch so post-filtering can still fill the limit
+        knn_query = {
+            "knn": {
+                "vector_field": {
+                    "vector": vectors,
+                    "k": limit * 2,
+                }
+            }
+        }
 
+        # Start building the full query
+        query_body = {
+            "size": limit * 2,
+            "query": None
+        }
+
+        # Prepare filter conditions if applicable
+        filter_clauses = []
         if filters:
-            filter_conditions = [{"term": {f"metadata.{key}": value}} for key, value in filters.items()]
-            search_query["query"]["knn"]["vector"]["filter"] = {"bool": {"filter": filter_conditions}}
+            for key in ["user_id", "run_id", "agent_id"]:
+                value = filters.get(key)
+                if value:
+                    filter_clauses.append({
+                        "term": {f"payload.{key}.keyword": value}
+                    })
 
-        response = self.client.search(index=self.collection_name, body=search_query)
+        # Combine knn with filters if needed
+        if filter_clauses:
+            query_body["query"] = {
+                "bool": {
+                    "must": knn_query,
+                    "filter": filter_clauses
+                }
+            }
+        else:
+            query_body["query"] = knn_query
 
+        # Execute search
+        response = self.client.search(index=self.collection_name, body=query_body)
+
+        # Trim the over-fetched candidates back down to the requested limit
+        hits = response["hits"]["hits"][:limit]
         results = [
-            OutputData(id=hit["_id"], score=hit["_score"], payload=hit["_source"].get("metadata", {}))
-            for hit in response["hits"]["hits"]
+            OutputData(
+                id=hit["_source"].get("id"),
+                score=hit["_score"],
+                payload=hit["_source"].get("payload", {})
+            )
+            for hit in hits
         ]
         return results
 
     def delete(self, vector_id: str) -> None:
-        """Delete a vector by ID."""
-        self.client.delete(index=self.collection_name, id=vector_id)
+        """Delete a vector by custom ID."""
+        # First, find the document by custom ID
+        search_query = {
+            "query": {
+                "term": {
+                    "id": vector_id
+                }
+            }
+        }
+
+        response = self.client.search(index=self.collection_name, body=search_query)
+        hits = response.get("hits", {}).get("hits", [])
+
+        if not hits:
+            return
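+        # Documents are indexed with auto-generated OpenSearch _ids, so the
+        # internal _id has to be resolved from the search hit before deleting.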
+
+        opensearch_id = hits[0]["_id"]
+
+        # Delete using the actual document ID
+        self.client.delete(index=self.collection_name, id=opensearch_id)
+
     def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict] = None) -> None:
-        """Update a vector and its payload."""
+        """Update a vector and its payload using the custom 'id' field."""
+
+        # First, find the document by custom ID
+        search_query = {
+            "query": {
+                "term": {
+                    "id": vector_id
+                }
+            }
+        }
+
+        response = self.client.search(index=self.collection_name, body=search_query)
+        hits = response.get("hits", {}).get("hits", [])
+
+        if not hits:
+            return
+
+        opensearch_id = hits[0]["_id"]  # The actual document ID in OpenSearch
+
+        # Prepare updated fields
         doc = {}
         if vector is not None:
-            doc["vector"] = vector
+            doc["vector_field"] = vector
         if payload is not None:
-            doc["metadata"] = payload
+            doc["payload"] = payload
+
+        if doc:
+            try:
+                self.client.update(index=self.collection_name, id=opensearch_id, body={"doc": doc})
+            except Exception as e:
+                logger.error(f"Error updating vector {vector_id}: {e}")
 
-        self.client.update(index=self.collection_name, id=vector_id, body={"doc": doc})
 
     def get(self, vector_id: str) -> Optional[OutputData]:
         """Retrieve a vector by ID."""
         try:
-            response = self.client.get(index=self.collection_name, id=vector_id)
-            return OutputData(id=response["_id"], score=1.0, payload=response["_source"].get("metadata", {}))
+            # First check if the index exists
+            if not self.client.indices.exists(index=self.collection_name):
+                logger.info(f"Index {self.collection_name} does not exist, creating it...")
+                self.create_col(self.collection_name, self.embedding_model_dims)
+                return None
+
+            search_query = {
+                "query": {
+                    "term": {
+                        "id": vector_id
+                    }
+                }
+            }
+            response = self.client.search(index=self.collection_name, body=search_query)
+
+            hits = response["hits"]["hits"]
+
+            if not hits:
+                return None
+
+            return OutputData(
+                id=hits[0]["_source"].get("id"),
+                score=1.0,
+                payload=hits[0]["_source"].get("payload", {})
+            )
         except Exception as e:
-            logger.error(f"Error retrieving vector {vector_id}: {e}")
+            logger.error(f"Error retrieving vector {vector_id}: {str(e)}")
             return None
 
     def list_cols(self) -> List[str]:
@@ -180,28 +287,52 @@ class OpenSearchDB(VectorStoreBase):
         """Get information about a collection (index)."""
         return self.client.indices.get(index=name)
 
-    def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[List[OutputData]]:
-        """List all memories."""
-        query = {"query": {"match_all": {}}}
+    def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[List[OutputData]]:
+        """List all memories with optional filters."""
+        try:
+            query: Dict = {
+                "query": {
+                    "match_all": {}
+                }
+            }
 
-        if filters:
-            query["query"] = {
-                "bool": {"must": [{"term": {f"metadata.{key}": value}} for key, value in filters.items()]}
-            }
+            filter_clauses = []
+            if filters:
+                for key in ["user_id", "run_id", "agent_id"]:
+                    value = filters.get(key)
+                    if value:
+                        filter_clauses.append({
+                            "term": {f"payload.{key}.keyword": value}
+                        })
+
+            if filter_clauses:
+                query["query"] = {
+                    "bool": {
+                        "filter": filter_clauses
+                    }
+                }
 
-        if limit:
-            query["size"] = limit
+            if limit:
+                query["size"] = limit
 
-        response = self.client.search(index=self.collection_name, body=query)
-        return [
-            [
-                OutputData(id=hit["_id"], score=1.0, payload=hit["_source"].get("metadata", {}))
-                for hit in response["hits"]["hits"]
-            ]
-        ]
+            response = self.client.search(index=self.collection_name, body=query)
+            hits = response["hits"]["hits"]
+
+            return [[
+                OutputData(
+                    id=hit["_source"].get("id"),
+                    score=1.0,
+                    payload=hit["_source"].get("payload", {})
+                )
+                for hit in hits
+            ]]
+        except Exception as e:
+            logger.error(f"Error listing vectors: {e}")
+            return [[]]
 
     def reset(self):
         """Reset the index by deleting and recreating it."""
         logger.warning(f"Resetting index {self.collection_name}...")
         self.delete_col()
-        self.create_index()
+        self.create_col(self.collection_name, self.embedding_model_dims)
diff --git a/pyproject.toml b/pyproject.toml
index 0483733f..21a21f9b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "mem0ai"
-version = "0.1.98"
+version = "0.1.99"
 description = "Long-term memory for AI Agents"
 authors = ["Mem0 "]
 exclude = [