improvement(OSS): Fix AOSS and AWS BedRock LLM (#2697)

Co-authored-by: Prateek Chhikara <prateekchhikara24@gmail.com>
Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
@@ -25,7 +25,7 @@ from mem0 import Memory
 os.environ["OPENAI_API_KEY"] = "your-openai-api-key"

 # AWS credentials
-os.environ["AWS_REGION"] = "us-east-1"
+os.environ["AWS_REGION"] = "us-west-2"
 os.environ["AWS_ACCESS_KEY_ID"] = "your-access-key"
 os.environ["AWS_SECRET_ACCESS_KEY"] = "your-secret-key"

@@ -33,7 +33,7 @@ config = {
     "embedder": {
         "provider": "aws_bedrock",
         "config": {
-            "model": "amazon.titan-embed-text-v1"
+            "model": "amazon.titan-embed-text-v2:0"
         }
     }
 }
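Note on this embedder change: Titan Text Embeddings V2 returns 1024-dimensional vectors by default, while V1 returned 1536, so a vector store configured for the old model needs its dimension updated to match. A minimal sketch, using the provider and field names from elsewhere in this commit:

```python
# The vector-store dimension must match the embedder output:
# amazon.titan-embed-text-v2:0 -> 1024 dims (default); v1 -> 1536 dims.
config = {
    "embedder": {
        "provider": "aws_bedrock",
        "config": {"model": "amazon.titan-embed-text-v2:0"},
    },
    "vector_store": {
        "provider": "opensearch",
        "config": {"collection_name": "mem0", "embedding_model_dims": 1024},
    },
}
```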
@@ -15,16 +15,15 @@ title: AWS Bedrock
 import os
 from mem0 import Memory

-os.environ["OPENAI_API_KEY"] = "your-api-key" # used for embedding model
-os.environ['AWS_REGION'] = 'us-east-1'
-os.environ["AWS_ACCESS_KEY"] = "xx"
+os.environ['AWS_REGION'] = 'us-west-2'
+os.environ["AWS_ACCESS_KEY_ID"] = "xx"
 os.environ["AWS_SECRET_ACCESS_KEY"] = "xx"

 config = {
     "llm": {
         "provider": "aws_bedrock",
         "config": {
-            "model": "arn:aws:bedrock:us-east-1:123456789012:model/your-model-name",
+            "model": "anthropic.claude-3-5-haiku-20241022-v1:0",
             "temperature": 0.2,
             "max_tokens": 2000,
         }
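With this docs change the example no longer needs an OpenAI key; credentials come from the standard AWS environment variables, and the config can be handed straight to Mem0. A sketch reusing the values from the hunk above:

```python
from mem0 import Memory

# Assumes AWS_REGION / AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY are set,
# as in the docs snippet above.
config = {
    "llm": {
        "provider": "aws_bedrock",
        "config": {
            "model": "anthropic.claude-3-5-haiku-20241022-v1:0",
            "temperature": 0.2,
            "max_tokens": 2000,
        }
    }
}
m = Memory.from_config(config)
m.add("I prefer sci-fi movies.", user_id="alice")
```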
@@ -1,59 +1,75 @@
-[OpenSearch](https://opensearch.org/) is an open-source, enterprise-grade search and observability suite that brings order to unstructured data at scale. OpenSearch supports k-NN (k-Nearest Neighbors) and allows you to store and retrieve high-dimensional vector embeddings efficiently.
+[OpenSearch](https://opensearch.org/) is an enterprise-grade search and observability suite that brings order to unstructured data at scale. OpenSearch supports k-NN (k-Nearest Neighbors) and allows you to store and retrieve high-dimensional vector embeddings efficiently.

 ### Installation

 OpenSearch support requires additional dependencies. Install them with:

 ```bash
-pip install opensearch>=2.8.0
+pip install opensearch-py
 ```

+### Prerequisites
+
+Before using OpenSearch with Mem0, you need to set up a collection in AWS OpenSearch Service.
+
+#### AWS OpenSearch Service
+You can create a collection through the AWS Console:
+- Navigate to [OpenSearch Service Console](https://console.aws.amazon.com/aos/home)
+- Click "Create collection"
+- Select "Serverless collection" and then enable "Vector search" capabilities
+- Once created, note the endpoint URL (host) for your configuration
+
 ### Usage

 ```python
 import os
 from mem0 import Memory
+import boto3
+from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth

-os.environ["OPENAI_API_KEY"] = "sk-xx"
+# For AWS OpenSearch Service with IAM authentication
+region = 'us-west-2'
+service = 'aoss'
+credentials = boto3.Session().get_credentials()
+auth = AWSV4SignerAuth(credentials, region, service)

 config = {
     "vector_store": {
         "provider": "opensearch",
         "config": {
             "collection_name": "mem0",
-            "host": "localhost",
-            "port": 9200,
-            "embedding_model_dims": 1536
+            "host": "your-domain.us-west-2.aoss.amazonaws.com",
+            "port": 443,
+            "http_auth": auth,
+            "embedding_model_dims": 1024,
+            "connection_class": RequestsHttpConnection,
+            "pool_maxsize": 20,
+            "use_ssl": True,
+            "verify_certs": True
         }
     }
 }
+```

+### Add Memories
+
+```python
 m = Memory.from_config(config)
 messages = [
     {"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
     {"role": "assistant", "content": "How about a thriller movies? They can be quite engaging."},
-    {"role": "user", "content": "I’m not a big fan of thriller movies but I love sci-fi movies."},
+    {"role": "user", "content": "I'm not a big fan of thriller movies but I love sci-fi movies."},
     {"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."}
 ]
 m.add(messages, user_id="alice", metadata={"category": "movies"})
 ```

-### Config
+### Search Memories

-Let's see the available parameters for the `opensearch` config:
-
-| Parameter | Description | Default Value |
-| ---------------------- | -------------------------------------------------- | ------------- |
-| `collection_name` | The name of the index to store the vectors | `mem0` |
-| `embedding_model_dims` | Dimensions of the embedding model | `1536` |
-| `host` | The host where the OpenSearch server is running | `localhost` |
-| `port` | The port where the OpenSearch server is running | `9200` |
-| `api_key` | API key for authentication | `None` |
-| `user` | Username for basic authentication | `None` |
-| `password` | Password for basic authentication | `None` |
-| `verify_certs` | Whether to verify SSL certificates | `False` |
-| `auto_create_index` | Whether to automatically create the index | `True` |
-| `use_ssl` | Whether to use SSL for connection | `False` |
+```python
+results = m.search("What kind of movies does Alice like?", user_id="alice")
+```

 ### Features
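A quick way to sanity-check the SigV4 setup before handing the config to Mem0; this is a sketch, the host is a placeholder, and it assumes your IAM identity has a matching AOSS data access policy:

```python
import boto3
from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth

credentials = boto3.Session().get_credentials()
auth = AWSV4SignerAuth(credentials, "us-west-2", "aoss")

client = OpenSearch(
    hosts=[{"host": "your-domain.us-west-2.aoss.amazonaws.com", "port": 443}],
    http_auth=auth,
    use_ssl=True,
    verify_certs=True,
    connection_class=RequestsHttpConnection,
)
# Plain opensearch-py call; True once Mem0 has created the index.
print(client.indices.exists(index="mem0"))
```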
@@ -200,6 +200,7 @@
       "icon": "lightbulb",
       "pages": [
         "examples",
+        "examples/aws_example",
        "examples/mem0-demo",
        "examples/ai_companion_js",
        "examples/eliza_os",
docs/examples/aws_example.mdx (new file, 120 lines)
@@ -0,0 +1,120 @@
+---
+title: AWS Bedrock and AOSS
+---
+
+<Snippet file="paper-release.mdx" />
+
+This example demonstrates how to configure and use the `mem0ai` SDK with **AWS Bedrock** and **OpenSearch Service (AOSS)** for persistent memory capabilities in Python.
+
+## Installation
+
+Install the required dependencies:
+
+```bash
+pip install mem0ai boto3 opensearch-py
+```
+
+## Environment Setup
+
+Set your AWS environment variables:
+
+```python
+import os
+
+# Set these in your environment or notebook
+os.environ['AWS_REGION'] = 'us-west-2'
+os.environ['AWS_ACCESS_KEY_ID'] = 'AK00000000000000000'
+os.environ['AWS_SECRET_ACCESS_KEY'] = 'AS00000000000000000'
+
+# Confirm they are set
+print(os.environ['AWS_REGION'])
+print(os.environ['AWS_ACCESS_KEY_ID'])
+print(os.environ['AWS_SECRET_ACCESS_KEY'])
+```
+
+## Configuration and Usage
+
+This sets up Mem0 with AWS Bedrock for embeddings and LLM, and OpenSearch as the vector store.
+
+```python
+import boto3
+from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth
+from mem0.memory.main import Memory
+
+region = 'us-west-2'
+service = 'aoss'
+credentials = boto3.Session().get_credentials()
+auth = AWSV4SignerAuth(credentials, region, service)
+
+config = {
+    "embedder": {
+        "provider": "aws_bedrock",
+        "config": {
+            "model": "amazon.titan-embed-text-v2:0"
+        }
+    },
+    "llm": {
+        "provider": "aws_bedrock",
+        "config": {
+            "model": "anthropic.claude-3-5-haiku-20241022-v1:0",
+            "temperature": 0.1,
+            "max_tokens": 2000
+        }
+    },
+    "vector_store": {
+        "provider": "opensearch",
+        "config": {
+            "collection_name": "mem0",
+            "host": "your-opensearch-domain.us-west-2.es.amazonaws.com",
+            "port": 443,
+            "http_auth": auth,
+            "embedding_model_dims": 1024,
+            "connection_class": RequestsHttpConnection,
+            "pool_maxsize": 20,
+            "use_ssl": True,
+            "verify_certs": True
+        }
+    }
+}
+
+# Initialize memory system
+m = Memory.from_config(config)
+```
+
+## Usage
+
+#### Add a memory:
+
+```python
+messages = [
+    {"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
+    {"role": "assistant", "content": "How about a thriller movies? They can be quite engaging."},
+    {"role": "user", "content": "I'm not a big fan of thriller movies but I love sci-fi movies."},
+    {"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."}
+]
+
+# Store inferred memories (default behavior)
+result = m.add(messages, user_id="alice", metadata={"category": "movie_recommendations"})
+```
+
+#### Search a memory:
+```python
+relevant_memories = m.search(query, user_id="alice")
+```
+
+#### Get all memories:
+```python
+all_memories = m.get_all(user_id="alice")
+```
+
+#### Get a specific memory:
+```python
+memory = m.get(memory_id)
+```
+
+---
+
+## Conclusion
+
+With Mem0 and AWS services like Bedrock and OpenSearch, you can build intelligent AI companions that remember, adapt, and personalize their responses over time. This makes them ideal for long-term assistants, tutors, or support bots with persistent memory and natural conversation abilities.
@@ -33,6 +33,10 @@ class BaseEmbedderConfig(ABC):
         memory_search_embedding_type: Optional[str] = None,
         # LM Studio specific
         lmstudio_base_url: Optional[str] = "http://localhost:1234/v1",
+        # AWS Bedrock specific
+        aws_access_key_id: Optional[str] = None,
+        aws_secret_access_key: Optional[str] = None,
+        aws_region: Optional[str] = "us-west-2",
     ):
         """
         Initializes a configuration class instance for the Embeddings.
@@ -92,3 +96,8 @@ class BaseEmbedderConfig(ABC):

         # LM Studio specific
         self.lmstudio_base_url = lmstudio_base_url
+
+        # AWS Bedrock specific
+        self.aws_access_key_id = aws_access_key_id
+        self.aws_secret_access_key = aws_secret_access_key
+        self.aws_region = aws_region
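These new fields suggest the AWS credentials can be supplied through the usual dict config as well as through environment variables. A hedged sketch (whether your mem0 version actually forwards these keys depends on the factory plumbing, so treat the field names as assumptions to verify):

```python
# Hypothetical dict-config equivalent of the new constructor parameters.
config = {
    "embedder": {
        "provider": "aws_bedrock",
        "config": {
            "model": "amazon.titan-embed-text-v2:0",
            "aws_access_key_id": "AK...",       # falls back to AWS_ACCESS_KEY_ID
            "aws_secret_access_key": "AS...",   # falls back to AWS_SECRET_ACCESS_KEY
            "aws_region": "us-west-2",          # default per this commit
        },
    }
}
```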
@@ -41,6 +41,10 @@ class BaseLlmConfig(ABC):
         xai_base_url: Optional[str] = None,
         # LM Studio specific
         lmstudio_base_url: Optional[str] = "http://localhost:1234/v1",
+        # AWS Bedrock specific
+        aws_access_key_id: Optional[str] = None,
+        aws_secret_access_key: Optional[str] = None,
+        aws_region: Optional[str] = "us-west-2",
     ):
         """
         Initializes a configuration class instance for the LLM.
@@ -123,3 +127,8 @@ class BaseLlmConfig(ABC):

         # LM Studio specific
         self.lmstudio_base_url = lmstudio_base_url
+
+        # AWS Bedrock specific
+        self.aws_access_key_id = aws_access_key_id
+        self.aws_secret_access_key = aws_secret_access_key
+        self.aws_region = aws_region
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union, Type

 from pydantic import BaseModel, Field, model_validator

@@ -7,14 +7,33 @@ class OpenSearchConfig(BaseModel):
     collection_name: str = Field("mem0", description="Name of the index")
     host: str = Field("localhost", description="OpenSearch host")
     port: int = Field(9200, description="OpenSearch port")
-    user: Optional[str] = Field(None, description="Username for authentication")
-    password: Optional[str] = Field(None, description="Password for authentication")
-    api_key: Optional[str] = Field(None, description="API key for authentication (if applicable)")
-    embedding_model_dims: int = Field(1536, description="Dimension of the embedding vector")
-    verify_certs: bool = Field(False, description="Verify SSL certificates (default False for OpenSearch)")
-    use_ssl: bool = Field(False, description="Use SSL for connection (default False for OpenSearch)")
-    auto_create_index: bool = Field(True, description="Automatically create index during initialization")
-    http_auth: Optional[object] = Field(None, description="HTTP authentication method / AWS SigV4")
+    user: Optional[str] = Field(
+        None, description="Username for authentication"
+    )
+    password: Optional[str] = Field(
+        None, description="Password for authentication"
+    )
+    api_key: Optional[str] = Field(
+        None, description="API key for authentication (if applicable)"
+    )
+    embedding_model_dims: int = Field(
+        1536, description="Dimension of the embedding vector"
+    )
+    verify_certs: bool = Field(
+        False, description="Verify SSL certificates (default False for OpenSearch)"
+    )
+    use_ssl: bool = Field(
+        False, description="Use SSL for connection (default False for OpenSearch)"
+    )
+    http_auth: Optional[object] = Field(
+        None, description="HTTP authentication method / AWS SigV4"
+    )
+    connection_class: Optional[Union[str, Type]] = Field(
+        "RequestsHttpConnection", description="Connection class for OpenSearch"
+    )
+    pool_maxsize: int = Field(
+        20, description="Maximum number of connections in the pool"
+    )

     @model_validator(mode="before")
     @classmethod
@@ -22,11 +41,7 @@ class OpenSearchConfig(BaseModel):
         # Check if host is provided
         if not values.get("host"):
             raise ValueError("Host must be provided for OpenSearch")

-        # Authentication: Either API key or user/password must be provided
-        if not any([values.get("api_key"), (values.get("user") and values.get("password")), values.get("http_auth")]):
-            raise ValueError("Either api_key or user/password must be provided for OpenSearch authentication")
-
         return values

     @model_validator(mode="before")
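With this validator removed, a config that carries neither `api_key` nor `user`/`password` (for example SigV4 via `http_auth`, or a local dev cluster) now validates; only `host` remains mandatory. A small sketch (the import path is assumed from mem0's usual layout):

```python
from mem0.configs.vector_stores.opensearch import OpenSearchConfig  # path assumed

# No api_key/user/password needed any more.
cfg = OpenSearchConfig(host="localhost", port=9200)
print(cfg.pool_maxsize)  # 20, the new default added above
```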
@@ -37,6 +52,7 @@ class OpenSearchConfig(BaseModel):
         extra_fields = input_fields - allowed_fields
         if extra_fields:
             raise ValueError(
-                f"Extra fields not allowed: {', '.join(extra_fields)}. " f"Allowed fields: {', '.join(allowed_fields)}"
+                f"Extra fields not allowed: {', '.join(extra_fields)}. "
+                f"Allowed fields: {', '.join(allowed_fields)}"
             )
         return values
@@ -1,4 +1,5 @@
 import json
+import os
 from typing import Literal, Optional

 try:
@@ -22,7 +23,26 @@ class AWSBedrockEmbedding(EmbeddingBase):
         super().__init__(config)

         self.config.model = self.config.model or "amazon.titan-embed-text-v1"
-        self.client = boto3.client("bedrock-runtime")
+
+        # Get AWS config from environment variables or use defaults
+        aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", "")
+        aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "")
+        aws_region = os.environ.get("AWS_REGION", "us-west-2")
+
+        # Check if AWS config is provided in the config
+        if hasattr(self.config, "aws_access_key_id"):
+            aws_access_key = self.config.aws_access_key_id
+        if hasattr(self.config, "aws_secret_access_key"):
+            aws_secret_key = self.config.aws_secret_access_key
+        if hasattr(self.config, "aws_region"):
+            aws_region = self.config.aws_region
+
+        self.client = boto3.client(
+            "bedrock-runtime",
+            region_name=aws_region,
+            aws_access_key_id=aws_access_key if aws_access_key else None,
+            aws_secret_access_key=aws_secret_key if aws_secret_key else None,
+        )

     def _normalize_vector(self, embeddings):
         """Normalize the embedding to a unit vector."""
@@ -1,4 +1,6 @@
 import json
+import os
+import re
 from typing import Any, Dict, List, Optional

 try:
@@ -9,6 +11,14 @@ except ImportError:
 from mem0.configs.llms.base import BaseLlmConfig
 from mem0.llms.base import LLMBase

+PROVIDERS = ["ai21", "amazon", "anthropic", "cohere", "meta", "mistral", "stability", "writer"]
+
+
+def extract_provider(model: str) -> str:
+    for provider in PROVIDERS:
+        if re.search(rf"\b{re.escape(provider)}\b", model):
+            return provider
+    raise ValueError(f"Unknown provider in model: {model}")
+
+
 class AWSBedrockLLM(LLMBase):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
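Why a word-boundary regex instead of the old `self.model.split(".")[0]` (replaced in a later hunk): Bedrock model identifiers do not always start with the provider token, for example cross-region inference profiles carry a `us.`/`eu.` prefix, and the old docs even showed a full ARN. A usage sketch of the function above:

```python
extract_provider("anthropic.claude-3-5-haiku-20241022-v1:0")     # -> "anthropic"
extract_provider("us.anthropic.claude-3-5-haiku-20241022-v1:0")  # -> "anthropic" (split(".") would give "us")
extract_provider("my-custom-model")                              # raises ValueError
```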
@@ -16,7 +26,27 @@ class AWSBedrockLLM(LLMBase):

         if not self.config.model:
             self.config.model = "anthropic.claude-3-5-sonnet-20240620-v1:0"
-        self.client = boto3.client("bedrock-runtime")
+
+        # Get AWS config from environment variables or use defaults
+        aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", "")
+        aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "")
+        aws_region = os.environ.get("AWS_REGION", "us-west-2")
+
+        # Check if AWS config is provided in the config
+        if hasattr(self.config, "aws_access_key_id"):
+            aws_access_key = self.config.aws_access_key_id
+        if hasattr(self.config, "aws_secret_access_key"):
+            aws_secret_key = self.config.aws_secret_access_key
+        if hasattr(self.config, "aws_region"):
+            aws_region = self.config.aws_region
+
+        self.client = boto3.client(
+            "bedrock-runtime",
+            region_name=aws_region,
+            aws_access_key_id=aws_access_key if aws_access_key else None,
+            aws_secret_access_key=aws_secret_key if aws_secret_key else None,
+        )
+
         self.model_kwargs = {
             "temperature": self.config.temperature,
             "max_tokens_to_sample": self.config.max_tokens,
@@ -34,13 +64,14 @@ class AWSBedrockLLM(LLMBase):
         Returns:
             str: A formatted string combining all messages, structured with roles capitalized and separated by newlines.
         """

         formatted_messages = []
         for message in messages:
             role = message["role"].capitalize()
             content = message["content"]
             formatted_messages.append(f"\n\n{role}: {content}")

-        return "".join(formatted_messages) + "\n\nAssistant:"
+        return "\n\nHuman: " + "".join(formatted_messages) + "\n\nAssistant:"

     def _parse_response(self, response, tools) -> str:
         """
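The prepended `\n\nHuman: ` matches Anthropic's text-completion prompt convention, where the prompt must begin with a Human turn and end with `\n\nAssistant:`. A worked sketch of the return value:

```python
messages = [{"role": "user", "content": "Hi"}]
formatted = "\n\nHuman: " + "".join(
    f"\n\n{m['role'].capitalize()}: {m['content']}" for m in messages
) + "\n\nAssistant:"
print(repr(formatted))  # '\n\nHuman: \n\nUser: Hi\n\nAssistant:'
```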
@@ -68,8 +99,9 @@ class AWSBedrockLLM(LLMBase):

             return processed_response

-        response_body = json.loads(response["body"].read().decode())
-        return response_body.get("completion", "")
+        response_body = response.get("body").read().decode()
+        response_json = json.loads(response_body)
+        return response_json.get("content", [{"text": ""}])[0].get("text", "")

     def _prepare_input(
         self,
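The reworked parser targets the Messages-API response shape (a `content` list of text blocks) instead of the legacy text-completion `completion` field. A sketch of the two shapes:

```python
legacy_body = {"completion": "Hello!"}                             # what the old parser expected
messages_body = {"content": [{"type": "text", "text": "Hello!"}]}  # what the new parser expects

text = messages_body.get("content", [{"text": ""}])[0].get("text", "")
assert text == "Hello!"
```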
@@ -113,9 +145,9 @@ class AWSBedrockLLM(LLMBase):
         input_body = {
             "inputText": prompt,
             "textGenerationConfig": {
-                "maxTokenCount": model_kwargs.get("max_tokens_to_sample"),
-                "topP": model_kwargs.get("top_p"),
-                "temperature": model_kwargs.get("temperature"),
+                "maxTokenCount": self.model_kwargs["max_tokens_to_sample"] or self.model_kwargs["max_tokens"] or 5000,
+                "topP": self.model_kwargs["top_p"] or 0.9,
+                "temperature": self.model_kwargs["temperature"] or 0.1,
             },
         }
         input_body["textGenerationConfig"] = {
@@ -206,15 +238,40 @@ class AWSBedrockLLM(LLMBase):
         else:
             # Use invoke_model method when no tools are provided
             prompt = self._format_messages(messages)
-            provider = self.model.split(".")[0]
-            input_body = self._prepare_input(provider, self.config.model, prompt, **self.model_kwargs)
+            provider = extract_provider(self.config.model)
+            input_body = self._prepare_input(provider, self.config.model, prompt, model_kwargs=self.model_kwargs)
             body = json.dumps(input_body)

-            response = self.client.invoke_model(
-                body=body,
-                modelId=self.model,
-                accept="application/json",
-                contentType="application/json",
-            )
+            if provider == "anthropic" or provider == "deepseek":
+                input_body = {
+                    "messages": [
+                        {
+                            "role": "user",
+                            "content": [{"type": "text", "text": prompt}]
+                        }
+                    ],
+                    "max_tokens": self.model_kwargs["max_tokens_to_sample"] or self.model_kwargs["max_tokens"] or 5000,
+                    "temperature": self.model_kwargs["temperature"] or 0.1,
+                    "top_p": self.model_kwargs["top_p"] or 0.9,
+                    "anthropic_version": "bedrock-2023-05-31",
+                }
+
+                body = json.dumps(input_body)
+
+                response = self.client.invoke_model(
+                    body=body,
+                    modelId=self.config.model,
+                    accept="application/json",
+                    contentType="application/json",
+                )
+            else:
+                response = self.client.invoke_model(
+                    body=body,
+                    modelId=self.config.model,
+                    accept="application/json",
+                    contentType="application/json",
+                )

         return self._parse_response(response, tools)
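For the anthropic/deepseek branch, the body prepared by `_prepare_input` is discarded and rebuilt in Messages-API form; `"bedrock-2023-05-31"` is the `anthropic_version` string Bedrock requires for that API. A sketch of the final request body:

```python
import json

prompt = "\n\nHuman: \n\nUser: Hi\n\nAssistant:"
body = json.dumps({
    "messages": [{"role": "user", "content": [{"type": "text", "text": prompt}]}],
    "max_tokens": 2000,
    "temperature": 0.1,
    "top_p": 0.9,
    "anthropic_version": "bedrock-2023-05-31",  # required by Bedrock's Messages API
})
```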
@@ -69,17 +69,14 @@ class Memory(MemoryBase):
             self.enable_graph = True
         else:
             self.graph = None

         self.config.vector_store.config.collection_name = "mem0migrations"
         if self.config.vector_store.provider in ["faiss", "qdrant"]:
             provider_path = f"migrations_{self.config.vector_store.provider}"
             self.config.vector_store.config.path = os.path.join(mem0_dir, provider_path)
             os.makedirs(self.config.vector_store.config.path, exist_ok=True)

         self._telemetry_vector_store = VectorStoreFactory.create(
             self.config.vector_store.provider, self.config.vector_store.config
         )

         capture_event("mem0.init", self, {"sync_type": "sync"})

     @classmethod
@@ -38,7 +38,7 @@ def get_or_create_user_id(vector_store):

     # Try to get existing user_id from vector store
     try:
-        existing = vector_store.get(vector_id=VECTOR_ID)
+        existing = vector_store.get(vector_id=user_id)
         if existing and hasattr(existing, "payload") and existing.payload and "user_id" in existing.payload:
             return existing.payload["user_id"]
     except Exception:
@@ -48,7 +48,7 @@ def get_or_create_user_id(vector_store):
     try:
         dims = getattr(vector_store, "embedding_model_dims", 1536)
         vector_store.insert(
-            vectors=[[0.0] * dims], payloads=[{"user_id": user_id, "type": "user_identity"}], ids=[VECTOR_ID]
+            vectors=[[0.1] * dims], payloads=[{"user_id": user_id, "type": "user_identity"}], ids=[user_id]
         )
     except Exception:
         pass
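The switch from `[0.0] * dims` to `[0.1] * dims` is likely because cosine similarity is undefined for the all-zero vector, so k-NN indexes configured with `cosinesimil` can reject it; any constant non-zero vector serves for this identity record:

```python
dims = 1536
identity_vector = [0.1] * dims  # non-zero, so cosine similarity is well-defined
```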
@@ -1,5 +1,6 @@
 import logging
 from typing import Any, Dict, List, Optional
+import time

 try:
     from opensearchpy import OpenSearch, RequestsHttpConnection
@@ -34,28 +35,26 @@ class OpenSearchDB(VectorStoreBase):
             use_ssl=config.use_ssl,
             verify_certs=config.verify_certs,
             connection_class=RequestsHttpConnection,
+            pool_maxsize=20
         )

         self.collection_name = config.collection_name
         self.embedding_model_dims = config.embedding_model_dims
-        # Create index only if auto_create_index is True
-        if config.auto_create_index:
-            self.create_index()
+        self.create_col(self.collection_name, self.embedding_model_dims)

     def create_index(self) -> None:
         """Create OpenSearch index with proper mappings if it doesn't exist."""
         index_settings = {
             "settings": {
-                "index": {"number_of_replicas": 1, "number_of_shards": 5, "refresh_interval": "1s", "knn": True}
+                "index": {"number_of_replicas": 1, "number_of_shards": 5, "refresh_interval": "10s", "knn": True}
             },
             "mappings": {
                 "properties": {
                     "text": {"type": "text"},
-                    "vector": {
+                    "vector_field": {
                         "type": "knn_vector",
                         "dimension": self.embedding_model_dims,
-                        "method": {"engine": "lucene", "name": "hnsw", "space_type": "cosinesimil"},
+                        "method": {"engine": "nmslib", "name": "hnsw", "space_type": "cosinesimil"},
                     },
                     "metadata": {"type": "object", "properties": {"user_id": {"type": "keyword"}}},
                 }
@@ -71,12 +70,15 @@ class OpenSearchDB(VectorStoreBase):
     def create_col(self, name: str, vector_size: int) -> None:
         """Create a new collection (index in OpenSearch)."""
         index_settings = {
+            "settings": {
+                "index.knn": True
+            },
             "mappings": {
                 "properties": {
-                    "vector": {
+                    "vector_field": {
                         "type": "knn_vector",
                         "dimension": vector_size,
-                        "method": {"engine": "lucene", "name": "hnsw", "space_type": "cosinesimil"},
+                        "method": {"engine": "nmslib", "name": "hnsw", "space_type": "cosinesimil"},
                     },
                     "payload": {"type": "object"},
                     "id": {"type": "keyword"},
@@ -88,6 +90,24 @@ class OpenSearchDB(VectorStoreBase):
             self.client.indices.create(index=name, body=index_settings)
             logger.info(f"Created index {name}")

+            # Wait for index to be ready
+            max_retries = 60  # 60 seconds timeout
+            retry_count = 0
+            while retry_count < max_retries:
+                try:
+                    # Check if index is ready by attempting a simple search
+                    self.client.search(index=name, body={"query": {"match_all": {}}})
+                    logger.info(f"Index {name} is ready")
+                    time.sleep(1)
+                    return
+                except Exception:
+                    retry_count += 1
+                    if retry_count == max_retries:
+                        raise TimeoutError(
+                            f"Index {name} creation timed out after {max_retries} seconds"
+                        )
+                    time.sleep(0.5)
+
     def insert(
         self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None
     ) -> List[OutputData]:
@@ -98,74 +118,161 @@ class OpenSearchDB(VectorStoreBase):
         if payloads is None:
             payloads = [{} for _ in range(len(vectors))]

-        actions = []
         for i, (vec, id_) in enumerate(zip(vectors, ids)):
-            action = {
-                "_index": self.collection_name,
-                "_id": id_,
-                "_source": {
-                    "vector": vec,
-                    "metadata": payloads[i],  # Store metadata in the metadata field
-                },
+            body = {
+                "vector_field": vec,
+                "payload": payloads[i],
+                "id": id_,
             }
-            actions.append(action)
-
-        bulk(self.client, actions)
+            self.client.index(index=self.collection_name, body=body)

         results = []
-        for i, id_ in enumerate(ids):
-            results.append(OutputData(id=id_, score=1.0, payload=payloads[i]))
         return results

     def search(
         self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
     ) -> List[OutputData]:
-        """Search for similar vectors using OpenSearch k-NN search with pre-filtering."""
-        search_query = {
-            "size": limit,
-            "query": {
-                "knn": {
-                    "vector": {
-                        "vector": vectors,
-                        "k": limit,
-                    }
+        """Search for similar vectors using OpenSearch k-NN search with optional filters."""
+        # Base KNN query
+        knn_query = {
+            "knn": {
+                "vector_field": {
+                    "vector": vectors,
+                    "k": limit * 2,
                 }
-            },
+            }
         }
+
+        # Start building the full query
+        query_body = {
+            "size": limit * 2,
+            "query": None
+        }
+
+        # Prepare filter conditions if applicable
+        filter_clauses = []
         if filters:
-            filter_conditions = [{"term": {f"metadata.{key}": value}} for key, value in filters.items()]
-            search_query["query"]["knn"]["vector"]["filter"] = {"bool": {"filter": filter_conditions}}
+            for key in ["user_id", "run_id", "agent_id"]:
+                value = filters.get(key)
+                if value:
+                    filter_clauses.append({
+                        "term": {f"payload.{key}.keyword": value}
+                    })

-        response = self.client.search(index=self.collection_name, body=search_query)
+        # Combine knn with filters if needed
+        if filter_clauses:
+            query_body["query"] = {
+                "bool": {
+                    "must": knn_query,
+                    "filter": filter_clauses
+                }
+            }
+        else:
+            query_body["query"] = knn_query
+
+        # Execute search
+        response = self.client.search(index=self.collection_name, body=query_body)
+
+        hits = response["hits"]["hits"]
         results = [
-            OutputData(id=hit["_id"], score=hit["_score"], payload=hit["_source"].get("metadata", {}))
-            for hit in response["hits"]["hits"]
+            OutputData(
+                id=hit["_source"].get("id"),
+                score=hit["_score"],
+                payload=hit["_source"].get("payload", {})
+            )
+            for hit in hits
         ]
         return results

     def delete(self, vector_id: str) -> None:
-        """Delete a vector by ID."""
-        self.client.delete(index=self.collection_name, id=vector_id)
+        """Delete a vector by custom ID."""
+        # First, find the document by custom ID
+        search_query = {
+            "query": {
+                "term": {
+                    "id": vector_id
+                }
+            }
+        }
+
+        response = self.client.search(index=self.collection_name, body=search_query)
+        hits = response.get("hits", {}).get("hits", [])
+
+        if not hits:
+            return
+
+        opensearch_id = hits[0]["_id"]
+
+        # Delete using the actual document ID
+        self.client.delete(index=self.collection_name, id=opensearch_id)

     def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict] = None) -> None:
-        """Update a vector and its payload."""
+        """Update a vector and its payload using the custom 'id' field."""
+
+        # First, find the document by custom ID
+        search_query = {
+            "query": {
+                "term": {
+                    "id": vector_id
+                }
+            }
+        }
+
+        response = self.client.search(index=self.collection_name, body=search_query)
+        hits = response.get("hits", {}).get("hits", [])
+
+        if not hits:
+            return
+
+        opensearch_id = hits[0]["_id"]  # The actual document ID in OpenSearch
+
+        # Prepare updated fields
         doc = {}
         if vector is not None:
-            doc["vector"] = vector
+            doc["vector_field"] = vector
         if payload is not None:
-            doc["metadata"] = payload
+            doc["payload"] = payload

-        self.client.update(index=self.collection_name, id=vector_id, body={"doc": doc})
+        if doc:
+            try:
+                response = self.client.update(index=self.collection_name, id=opensearch_id, body={"doc": doc})
+            except Exception:
+                pass

     def get(self, vector_id: str) -> Optional[OutputData]:
         """Retrieve a vector by ID."""
         try:
-            response = self.client.get(index=self.collection_name, id=vector_id)
-            return OutputData(id=response["_id"], score=1.0, payload=response["_source"].get("metadata", {}))
+            # First check if index exists
+            if not self.client.indices.exists(index=self.collection_name):
+                logger.info(f"Index {self.collection_name} does not exist, creating it...")
+                self.create_col(self.collection_name, self.embedding_model_dims)
+                return None
+
+            search_query = {
+                "query": {
+                    "term": {
+                        "id": vector_id
+                    }
+                }
+            }
+            response = self.client.search(index=self.collection_name, body=search_query)
+
+            hits = response["hits"]["hits"]
+
+            if not hits:
+                return None
+
+            return OutputData(
+                id=hits[0]["_source"].get("id"),
+                score=1.0,
+                payload=hits[0]["_source"].get("payload", {})
+            )
         except Exception as e:
-            logger.error(f"Error retrieving vector {vector_id}: {e}")
+            logger.error(f"Error retrieving vector {vector_id}: {str(e)}")
             return None

     def list_cols(self) -> List[str]:
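All three methods above now resolve documents through a term query on the stored `id` field and then act on the hit's `_id`. This indirection is presumably needed because OpenSearch Serverless auto-generates document `_id`s, so a caller-supplied memory id can no longer double as the document id. The shared lookup pattern, as a hypothetical helper:

```python
# Hypothetical helper illustrating the lookup the methods share.
def find_opensearch_id(client, index: str, memory_id: str):
    response = client.search(index=index, body={"query": {"term": {"id": memory_id}}})
    hits = response.get("hits", {}).get("hits", [])
    # hits[0]["_id"] is the real document id used for delete/update calls.
    return hits[0]["_id"] if hits else None
```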
@@ -180,28 +287,52 @@ class OpenSearchDB(VectorStoreBase):
         """Get information about a collection (index)."""
         return self.client.indices.get(index=name)

-    def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[List[OutputData]]:
-        """List all memories."""
-        query = {"query": {"match_all": {}}}
-
-        if filters:
-            query["query"] = {
-                "bool": {"must": [{"term": {f"metadata.{key}": value}} for key, value in filters.items()]}
-            }
-        if limit:
-            query["size"] = limit
-
-        response = self.client.search(index=self.collection_name, body=query)
-        return [
-            [
-                OutputData(id=hit["_id"], score=1.0, payload=hit["_source"].get("metadata", {}))
-                for hit in response["hits"]["hits"]
-            ]
-        ]
+    def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[OutputData]:
+        try:
+            """List all memories with optional filters."""
+            query: Dict = {
+                "query": {
+                    "match_all": {}
+                }
+            }
+
+            filter_clauses = []
+            if filters:
+                for key in ["user_id", "run_id", "agent_id"]:
+                    value = filters.get(key)
+                    if value:
+                        filter_clauses.append({
+                            "term": {f"payload.{key}.keyword": value}
+                        })
+
+            if filter_clauses:
+                query["query"] = {
+                    "bool": {
+                        "filter": filter_clauses
+                    }
+                }
+
+            if limit:
+                query["size"] = limit
+
+            response = self.client.search(index=self.collection_name, body=query)
+            hits = response["hits"]["hits"]
+
+            return [[
+                OutputData(
+                    id=hit["_source"].get("id"),
+                    score=1.0,
+                    payload=hit["_source"].get("payload", {})
+                )
+                for hit in hits
+            ]]
+        except Exception:
+            return []

     def reset(self):
         """Reset the index by deleting and recreating it."""
         logger.warning(f"Resetting index {self.collection_name}...")
         self.delete_col()
-        self.create_index()
+        self.create_col(self.collection_name, self.embedding_model_dims)
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "mem0ai"
-version = "0.1.98"
+version = "0.1.99"
 description = "Long-term memory for AI Agents"
 authors = ["Mem0 <founders@mem0.ai>"]
 exclude = [