improvement(OSS): Fix AOSS and AWS Bedrock LLM (#2697)

Co-authored-by: Prateek Chhikara <prateekchhikara24@gmail.com>
Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
Saket Aryan
2025-05-16 04:49:29 +05:30
committed by GitHub
parent 267e5b13ea
commit 5c67a5e6bc
14 changed files with 502 additions and 127 deletions

View File

@@ -33,6 +33,10 @@ class BaseEmbedderConfig(ABC):
memory_search_embedding_type: Optional[str] = None,
# LM Studio specific
lmstudio_base_url: Optional[str] = "http://localhost:1234/v1",
# AWS Bedrock specific
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_region: Optional[str] = "us-west-2",
):
"""
Initializes a configuration class instance for the Embeddings.
@@ -92,3 +96,8 @@ class BaseEmbedderConfig(ABC):
# LM Studio specific
self.lmstudio_base_url = lmstudio_base_url
# AWS Bedrock specific
self.aws_access_key_id = aws_access_key_id
self.aws_secret_access_key = aws_secret_access_key
self.aws_region = aws_region
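
Taken together, these new fields let callers pass AWS credentials through the embedder config instead of relying solely on environment variables. A minimal sketch of the relevant config section (the "aws_bedrock" provider key and the placeholder credentials are assumptions for illustration, not part of this diff):

embedder_config = {
    "provider": "aws_bedrock",
    "config": {
        "model": "amazon.titan-embed-text-v1",
        "aws_access_key_id": "AKIA...",        # hypothetical placeholder
        "aws_secret_access_key": "secret...",  # hypothetical placeholder
        "aws_region": "us-west-2",             # matches the new default
    },
}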

View File

@@ -41,6 +41,10 @@ class BaseLlmConfig(ABC):
xai_base_url: Optional[str] = None,
# LM Studio specific
lmstudio_base_url: Optional[str] = "http://localhost:1234/v1",
# AWS Bedrock specific
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_region: Optional[str] = "us-west-2",
):
"""
Initializes a configuration class instance for the LLM.
@@ -123,3 +127,8 @@ class BaseLlmConfig(ABC):
# LM Studio specific
self.lmstudio_base_url = lmstudio_base_url
# AWS Bedrock specific
self.aws_access_key_id = aws_access_key_id
self.aws_secret_access_key = aws_secret_access_key
self.aws_region = aws_region
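
The LLM config mirrors the embedder change. A hedged sketch of the corresponding LLM section (provider key assumed; keys left out fall back to the AWS_* environment variables, as the client-construction code below shows):

llm_config = {
    "provider": "aws_bedrock",
    "config": {
        "model": "anthropic.claude-3-5-sonnet-20240620-v1:0",
        "aws_region": "us-west-2",
        # aws_access_key_id / aws_secret_access_key omitted: boto3 falls back
        # to the environment or an instance profile
    },
}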

View File

@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union, Type
from pydantic import BaseModel, Field, model_validator
@@ -7,14 +7,33 @@ class OpenSearchConfig(BaseModel):
collection_name: str = Field("mem0", description="Name of the index")
host: str = Field("localhost", description="OpenSearch host")
port: int = Field(9200, description="OpenSearch port")
-    user: Optional[str] = Field(None, description="Username for authentication")
-    password: Optional[str] = Field(None, description="Password for authentication")
-    api_key: Optional[str] = Field(None, description="API key for authentication (if applicable)")
-    embedding_model_dims: int = Field(1536, description="Dimension of the embedding vector")
-    verify_certs: bool = Field(False, description="Verify SSL certificates (default False for OpenSearch)")
-    use_ssl: bool = Field(False, description="Use SSL for connection (default False for OpenSearch)")
-    auto_create_index: bool = Field(True, description="Automatically create index during initialization")
-    http_auth: Optional[object] = Field(None, description="HTTP authentication method / AWS SigV4")
+    user: Optional[str] = Field(
+        None, description="Username for authentication"
+    )
+    password: Optional[str] = Field(
+        None, description="Password for authentication"
+    )
+    api_key: Optional[str] = Field(
+        None, description="API key for authentication (if applicable)"
+    )
+    embedding_model_dims: int = Field(
+        1536, description="Dimension of the embedding vector"
+    )
+    verify_certs: bool = Field(
+        False, description="Verify SSL certificates (default False for OpenSearch)"
+    )
+    use_ssl: bool = Field(
+        False, description="Use SSL for connection (default False for OpenSearch)"
+    )
+    http_auth: Optional[object] = Field(
+        None, description="HTTP authentication method / AWS SigV4"
+    )
+    connection_class: Optional[Union[str, Type]] = Field(
+        "RequestsHttpConnection", description="Connection class for OpenSearch"
+    )
+    pool_maxsize: int = Field(
+        20, description="Maximum number of connections in the pool"
+    )
@model_validator(mode="before")
@classmethod
@@ -22,11 +41,7 @@ class OpenSearchConfig(BaseModel):
# Check if host is provided
if not values.get("host"):
raise ValueError("Host must be provided for OpenSearch")
-        # Authentication: Either API key or user/password must be provided
-        if not any([values.get("api_key"), (values.get("user") and values.get("password")), values.get("http_auth")]):
-            raise ValueError("Either api_key or user/password must be provided for OpenSearch authentication")
return values
@model_validator(mode="before")
@@ -37,6 +52,7 @@ class OpenSearchConfig(BaseModel):
extra_fields = input_fields - allowed_fields
if extra_fields:
raise ValueError(
f"Extra fields not allowed: {', '.join(extra_fields)}. " f"Allowed fields: {', '.join(allowed_fields)}"
f"Extra fields not allowed: {', '.join(extra_fields)}. "
f"Allowed fields: {', '.join(allowed_fields)}"
)
return values
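
For Amazon OpenSearch Serverless (AOSS), authentication typically flows through the new http_auth field rather than user/password, which is presumably why the hard auth requirement above was dropped. A sketch using opensearch-py's SigV4 signer (the endpoint is a hypothetical placeholder; "aoss" is the service name for Serverless collections, and boto3 must be able to resolve credentials):

import boto3
from opensearchpy import AWSV4SignerAuth

credentials = boto3.Session().get_credentials()
auth = AWSV4SignerAuth(credentials, "us-west-2", "aoss")

config = OpenSearchConfig(
    collection_name="mem0",
    host="your-collection-id.us-west-2.aoss.amazonaws.com",  # hypothetical endpoint
    port=443,
    use_ssl=True,
    verify_certs=True,
    http_auth=auth,
    embedding_model_dims=1536,
)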

View File

@@ -1,4 +1,5 @@
import json
import os
from typing import Literal, Optional
try:
@@ -22,7 +23,26 @@ class AWSBedrockEmbedding(EmbeddingBase):
super().__init__(config)
self.config.model = self.config.model or "amazon.titan-embed-text-v1"
self.client = boto3.client("bedrock-runtime")
# Get AWS config from environment variables or use defaults
aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", "")
aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "")
aws_region = os.environ.get("AWS_REGION", "us-west-2")
# Check if AWS config is provided in the config
if hasattr(self.config, "aws_access_key_id"):
aws_access_key = self.config.aws_access_key_id
if hasattr(self.config, "aws_secret_access_key"):
aws_secret_key = self.config.aws_secret_access_key
if hasattr(self.config, "aws_region"):
aws_region = self.config.aws_region
self.client = boto3.client(
"bedrock-runtime",
region_name=aws_region,
aws_access_key_id=aws_access_key if aws_access_key else None,
aws_secret_access_key=aws_secret_key if aws_secret_key else None,
)
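
Note that hasattr is always true here, since the config class now defines all three attributes; a config value left at None therefore overwrites the environment value read above, and the `if aws_access_key else None` guards then hand credential resolution back to boto3's own chain (which reads the same AWS_* variables itself). A more direct alternative would be an or-chain; a sketch, not part of this commit:

aws_access_key = self.config.aws_access_key_id or os.environ.get("AWS_ACCESS_KEY_ID")
aws_secret_key = self.config.aws_secret_access_key or os.environ.get("AWS_SECRET_ACCESS_KEY")
aws_region = self.config.aws_region or os.environ.get("AWS_REGION", "us-west-2")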
def _normalize_vector(self, embeddings):
"""Normalize the embedding to a unit vector."""

View File

@@ -1,4 +1,6 @@
import json
import os
import re
from typing import Any, Dict, List, Optional
try:
@@ -9,6 +11,14 @@ except ImportError:
from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase
PROVIDERS = ["ai21", "amazon", "anthropic", "cohere", "meta", "mistral", "stability", "writer"]
def extract_provider(model: str) -> str:
for provider in PROVIDERS:
if re.search(rf"\b{re.escape(provider)}\b", model):
return provider
raise ValueError(f"Unknown provider in model: {model}")
class AWSBedrockLLM(LLMBase):
def __init__(self, config: Optional[BaseLlmConfig] = None):
@@ -16,7 +26,27 @@ class AWSBedrockLLM(LLMBase):
if not self.config.model:
self.config.model = "anthropic.claude-3-5-sonnet-20240620-v1:0"
self.client = boto3.client("bedrock-runtime")
# Get AWS config from environment variables or use defaults
aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", "")
aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "")
aws_region = os.environ.get("AWS_REGION", "us-west-2")
# Check if AWS config is provided in the config
if hasattr(self.config, "aws_access_key_id"):
aws_access_key = self.config.aws_access_key_id
if hasattr(self.config, "aws_secret_access_key"):
aws_secret_key = self.config.aws_secret_access_key
if hasattr(self.config, "aws_region"):
aws_region = self.config.aws_region
self.client = boto3.client(
"bedrock-runtime",
region_name=aws_region,
aws_access_key_id=aws_access_key if aws_access_key else None,
aws_secret_access_key=aws_secret_key if aws_secret_key else None,
)
self.model_kwargs = {
"temperature": self.config.temperature,
"max_tokens_to_sample": self.config.max_tokens,
@@ -34,13 +64,14 @@ class AWSBedrockLLM(LLMBase):
Returns:
str: A formatted string combining all messages, structured with roles capitalized and separated by newlines.
"""
formatted_messages = []
for message in messages:
role = message["role"].capitalize()
content = message["content"]
formatted_messages.append(f"\n\n{role}: {content}")
return "".join(formatted_messages) + "\n\nAssistant:"
return "\n\nHuman: " + "".join(formatted_messages) + "\n\nAssistant:"
def _parse_response(self, response, tools) -> str:
"""
@@ -68,8 +99,9 @@ class AWSBedrockLLM(LLMBase):
return processed_response
response_body = json.loads(response["body"].read().decode())
return response_body.get("completion", "")
response_body = response.get("body").read().decode()
response_json = json.loads(response_body)
return response_json.get("content", [{"text": ""}])[0].get("text", "")
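
The parser now expects the Anthropic messages-API body shape ({"content": [{"type": "text", "text": ...}]}) instead of the legacy {"completion": ...}. A self-contained sketch of the new path, using an in-memory stand-in for the Bedrock response stream:

import io
import json

raw = json.dumps({"content": [{"type": "text", "text": "Hello!"}]}).encode()
response = {"body": io.BytesIO(raw)}  # stand-in for a Bedrock runtime response

body = json.loads(response["body"].read().decode())
text = body.get("content", [{"text": ""}])[0].get("text", "")
assert text == "Hello!"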
def _prepare_input(
self,
@@ -113,9 +145,9 @@ class AWSBedrockLLM(LLMBase):
input_body = {
"inputText": prompt,
"textGenerationConfig": {
"maxTokenCount": model_kwargs.get("max_tokens_to_sample"),
"topP": model_kwargs.get("top_p"),
"temperature": model_kwargs.get("temperature"),
"maxTokenCount": self.model_kwargs["max_tokens_to_sample"] or self.model_kwargs["max_tokens"] or 5000,
"topP": self.model_kwargs["top_p"] or 0.9,
"temperature": self.model_kwargs["temperature"] or 0.1,
},
}
input_body["textGenerationConfig"] = {
@@ -206,15 +238,40 @@ class AWSBedrockLLM(LLMBase):
else:
# Use invoke_model method when no tools are provided
prompt = self._format_messages(messages)
provider = self.model.split(".")[0]
input_body = self._prepare_input(provider, self.config.model, prompt, **self.model_kwargs)
provider = extract_provider(self.config.model)
input_body = self._prepare_input(provider, self.config.model, prompt, model_kwargs=self.model_kwargs)
body = json.dumps(input_body)
response = self.client.invoke_model(
body=body,
modelId=self.model,
accept="application/json",
contentType="application/json",
if provider == "anthropic" or provider == "deepseek":
input_body = {
"messages": [
{
"role": "user",
"content": [{"type": "text", "text": prompt}]
}
],
"max_tokens": self.model_kwargs["max_tokens_to_sample"] or self.model_kwargs["max_tokens"] or 5000,
"temperature": self.model_kwargs["temperature"] or 0.1,
"top_p": self.model_kwargs["top_p"] or 0.9,
"anthropic_version": "bedrock-2023-05-31",
}
body = json.dumps(input_body)
response = self.client.invoke_model(
body=body,
modelId=self.config.model,
accept="application/json",
contentType="application/json",
)
else:
response = self.client.invoke_model(
body=body,
modelId=self.config.model,
accept="application/json",
contentType="application/json",
)
return self._parse_response(response, tools)
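
End to end, the invoke path now looks like this from the caller's side. A usage sketch (the module paths are assumptions, and AWS credentials are expected to be resolvable from the environment):

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.aws_bedrock import AWSBedrockLLM  # assumed module path

config = BaseLlmConfig(
    model="anthropic.claude-3-5-sonnet-20240620-v1:0",
    aws_region="us-west-2",
)
llm = AWSBedrockLLM(config)
answer = llm.generate_response([{"role": "user", "content": "Say hi."}])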

View File

@@ -69,17 +69,14 @@ class Memory(MemoryBase):
self.enable_graph = True
else:
self.graph = None
self.config.vector_store.config.collection_name = "mem0migrations"
if self.config.vector_store.provider in ["faiss", "qdrant"]:
provider_path = f"migrations_{self.config.vector_store.provider}"
self.config.vector_store.config.path = os.path.join(mem0_dir, provider_path)
os.makedirs(self.config.vector_store.config.path, exist_ok=True)
self._telemetry_vector_store = VectorStoreFactory.create(
self.config.vector_store.provider, self.config.vector_store.config
)
capture_event("mem0.init", self, {"sync_type": "sync"})
@classmethod

View File

@@ -38,7 +38,7 @@ def get_or_create_user_id(vector_store):
# Try to get existing user_id from vector store
try:
-        existing = vector_store.get(vector_id=VECTOR_ID)
+        existing = vector_store.get(vector_id=user_id)
if existing and hasattr(existing, "payload") and existing.payload and "user_id" in existing.payload:
return existing.payload["user_id"]
except Exception:
@@ -48,7 +48,7 @@ def get_or_create_user_id(vector_store):
try:
dims = getattr(vector_store, "embedding_model_dims", 1536)
vector_store.insert(
-            vectors=[[0.0] * dims], payloads=[{"user_id": user_id, "type": "user_identity"}], ids=[VECTOR_ID]
+            vectors=[[0.1] * dims], payloads=[{"user_id": user_id, "type": "user_identity"}], ids=[user_id]
)
except Exception:
pass
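
Two things change here: the identity vector is stored under the generated user_id rather than a fixed VECTOR_ID, and the placeholder vector is 0.1 instead of 0.0. The latter is presumably because an all-zeros vector has zero norm, so cosine similarity against it is undefined and some stores reject it outright:

import math

v = [0.0] * 4
norm = math.sqrt(sum(x * x for x in v))  # 0.0 -> cosine similarity is undefined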

View File

@@ -1,5 +1,6 @@
import logging
from typing import Any, Dict, List, Optional
import time
try:
from opensearchpy import OpenSearch, RequestsHttpConnection
@@ -34,28 +35,26 @@ class OpenSearchDB(VectorStoreBase):
use_ssl=config.use_ssl,
verify_certs=config.verify_certs,
connection_class=RequestsHttpConnection,
pool_maxsize=20
)
self.collection_name = config.collection_name
self.embedding_model_dims = config.embedding_model_dims
-        # Create index only if auto_create_index is True
-        if config.auto_create_index:
-            self.create_index()
+        self.create_col(self.collection_name, self.embedding_model_dims)
def create_index(self) -> None:
"""Create OpenSearch index with proper mappings if it doesn't exist."""
index_settings = {
"settings": {
"index": {"number_of_replicas": 1, "number_of_shards": 5, "refresh_interval": "1s", "knn": True}
"index": {"number_of_replicas": 1, "number_of_shards": 5, "refresh_interval": "10s", "knn": True}
},
"mappings": {
"properties": {
"text": {"type": "text"},
"vector": {
"vector_field": {
"type": "knn_vector",
"dimension": self.embedding_model_dims,
"method": {"engine": "lucene", "name": "hnsw", "space_type": "cosinesimil"},
"method": {"engine": "nmslib", "name": "hnsw", "space_type": "cosinesimil"},
},
"metadata": {"type": "object", "properties": {"user_id": {"type": "keyword"}}},
}
@@ -71,12 +70,15 @@ class OpenSearchDB(VectorStoreBase):
def create_col(self, name: str, vector_size: int) -> None:
"""Create a new collection (index in OpenSearch)."""
index_settings = {
"settings": {
"index.knn": True
},
"mappings": {
"properties": {
"vector": {
"vector_field": {
"type": "knn_vector",
"dimension": vector_size,
"method": {"engine": "lucene", "name": "hnsw", "space_type": "cosinesimil"},
"method": {"engine": "nmslib", "name": "hnsw", "space_type": "cosinesimil"},
},
"payload": {"type": "object"},
"id": {"type": "keyword"},
@@ -88,6 +90,24 @@ class OpenSearchDB(VectorStoreBase):
self.client.indices.create(index=name, body=index_settings)
logger.info(f"Created index {name}")
# Wait for index to be ready
max_retries = 60 # 60 seconds timeout
retry_count = 0
while retry_count < max_retries:
try:
# Check if index is ready by attempting a simple search
self.client.search(index=name, body={"query": {"match_all": {}}})
logger.info(f"Index {name} is ready")
time.sleep(1)
return
except Exception:
retry_count += 1
if retry_count == max_retries:
raise TimeoutError(
f"Index {name} creation timed out after {max_retries} seconds"
)
time.sleep(0.5)
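
The polling loop exists because index creation is asynchronous and searching a not-yet-ready index raises. On a classic OpenSearch domain the same wait can be expressed with the cluster-health API; a sketch, and note it does not apply to Serverless collections, which do not expose _cluster/health:

# Wait up to 60s for the new index to reach at least yellow status.
self.client.cluster.health(index=name, wait_for_status="yellow", timeout="60s")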
def insert(
self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None
) -> List[OutputData]:
@@ -98,74 +118,161 @@ class OpenSearchDB(VectorStoreBase):
if payloads is None:
payloads = [{} for _ in range(len(vectors))]
-        actions = []
         for i, (vec, id_) in enumerate(zip(vectors, ids)):
-            action = {
-                "_index": self.collection_name,
-                "_id": id_,
-                "_source": {
-                    "vector": vec,
-                    "metadata": payloads[i],  # Store metadata in the metadata field
-                },
-            }
-            actions.append(action)
-        bulk(self.client, actions)
+            body = {
+                "vector_field": vec,
+                "payload": payloads[i],
+                "id": id_,
+            }
+            self.client.index(index=self.collection_name, body=body)
results = []
for i, id_ in enumerate(ids):
results.append(OutputData(id=id_, score=1.0, payload=payloads[i]))
return results
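
Documents are now indexed one at a time with a server-assigned _id, while the caller's id is stored in the custom "id" field that search/get/update/delete below resolve against. A usage sketch with hypothetical values:

db.insert(  # db: an OpenSearchDB instance
    vectors=[[0.1] * 1536],
    payloads=[{"user_id": "alice", "data": "likes green tea"}],
    ids=["mem-1"],
)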
def search(
self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
) -> List[OutputData]:
"""Search for similar vectors using OpenSearch k-NN search with pre-filtering."""
search_query = {
"size": limit,
"query": {
"knn": {
"vector": {
"vector": vectors,
"k": limit,
}
"""Search for similar vectors using OpenSearch k-NN search with optional filters."""
# Base KNN query
knn_query = {
"knn": {
"vector_field": {
"vector": vectors,
"k": limit * 2,
}
},
}
}
# Start building the full query
query_body = {
"size": limit * 2,
"query": None
}
# Prepare filter conditions if applicable
filter_clauses = []
if filters:
filter_conditions = [{"term": {f"metadata.{key}": value}} for key, value in filters.items()]
search_query["query"]["knn"]["vector"]["filter"] = {"bool": {"filter": filter_conditions}}
for key in ["user_id", "run_id", "agent_id"]:
value = filters.get(key)
if value:
filter_clauses.append({
"term": {f"payload.{key}.keyword": value}
})
response = self.client.search(index=self.collection_name, body=search_query)
# Combine knn with filters if needed
if filter_clauses:
query_body["query"] = {
"bool": {
"must": knn_query,
"filter": filter_clauses
}
}
else:
query_body["query"] = knn_query
# Execute search
response = self.client.search(index=self.collection_name, body=query_body)
hits = response["hits"]["hits"]
results = [
OutputData(id=hit["_id"], score=hit["_score"], payload=hit["_source"].get("metadata", {}))
for hit in response["hits"]["hits"]
OutputData(
id=hit["_source"].get("id"),
score=hit["_score"],
payload=hit["_source"].get("payload", {})
)
for hit in hits
]
return results
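
The k and size values are doubled, presumably to over-fetch candidates before the bool filter trims them; note the result list is not cut back down to limit afterwards. A usage sketch with a hypothetical query embedding:

hits = db.search(
    query="",              # unused by this implementation
    vectors=[0.1] * 1536,  # hypothetical query embedding
    limit=5,
    filters={"user_id": "alice"},
)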
def delete(self, vector_id: str) -> None:
"""Delete a vector by ID."""
self.client.delete(index=self.collection_name, id=vector_id)
"""Delete a vector by custom ID."""
# First, find the document by custom ID
search_query = {
"query": {
"term": {
"id": vector_id
}
}
}
response = self.client.search(index=self.collection_name, body=search_query)
hits = response.get("hits", {}).get("hits", [])
if not hits:
return
opensearch_id = hits[0]["_id"]
# Delete using the actual document ID
self.client.delete(index=self.collection_name, id=opensearch_id)
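
Because _id is now server-assigned, deletion is a two-step lookup: find the document whose custom "id" field matches, then delete by its real _id. Where the target supports it, delete_by_query collapses this into one round trip; a sketch, not what the commit does:

self.client.delete_by_query(
    index=self.collection_name,
    body={"query": {"term": {"id": vector_id}}},
)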
def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict] = None) -> None:
"""Update a vector and its payload."""
"""Update a vector and its payload using the custom 'id' field."""
# First, find the document by custom ID
search_query = {
"query": {
"term": {
"id": vector_id
}
}
}
response = self.client.search(index=self.collection_name, body=search_query)
hits = response.get("hits", {}).get("hits", [])
if not hits:
return
opensearch_id = hits[0]["_id"] # The actual document ID in OpenSearch
# Prepare updated fields
doc = {}
if vector is not None:
doc["vector"] = vector
doc["vector_field"] = vector
if payload is not None:
doc["metadata"] = payload
doc["payload"] = payload
if doc:
try:
response = self.client.update(index=self.collection_name, id=opensearch_id, body={"doc": doc})
except Exception:
pass
self.client.update(index=self.collection_name, id=vector_id, body={"doc": doc})
def get(self, vector_id: str) -> Optional[OutputData]:
"""Retrieve a vector by ID."""
try:
-            response = self.client.get(index=self.collection_name, id=vector_id)
-            return OutputData(id=response["_id"], score=1.0, payload=response["_source"].get("metadata", {}))
+            # First check if index exists
+            if not self.client.indices.exists(index=self.collection_name):
+                logger.info(f"Index {self.collection_name} does not exist, creating it...")
+                self.create_col(self.collection_name, self.embedding_model_dims)
+                return None
+            search_query = {
+                "query": {
+                    "term": {
+                        "id": vector_id
+                    }
+                }
+            }
+            response = self.client.search(index=self.collection_name, body=search_query)
+            hits = response["hits"]["hits"]
+            if not hits:
+                return None
+            return OutputData(
+                id=hits[0]["_source"].get("id"),
+                score=1.0,
+                payload=hits[0]["_source"].get("payload", {}),
+            )
         except Exception as e:
-            logger.error(f"Error retrieving vector {vector_id}: {e}")
+            logger.error(f"Error retrieving vector {vector_id}: {str(e)}")
return None
def list_cols(self) -> List[str]:
@@ -180,28 +287,52 @@ class OpenSearchDB(VectorStoreBase):
"""Get information about a collection (index)."""
return self.client.indices.get(index=name)
-    def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[List[OutputData]]:
-        """List all memories."""
-        query = {"query": {"match_all": {}}}
-        if filters:
-            query["query"] = {
-                "bool": {"must": [{"term": {f"metadata.{key}": value}} for key, value in filters.items()]}
-            }
-        if limit:
-            query["size"] = limit
-        response = self.client.search(index=self.collection_name, body=query)
-        return [
-            [
-                OutputData(id=hit["_id"], score=1.0, payload=hit["_source"].get("metadata", {}))
-                for hit in response["hits"]["hits"]
-            ]
-        ]
+    def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[OutputData]:
+        try:
+            """List all memories with optional filters."""
+            query: Dict = {
+                "query": {
+                    "match_all": {}
+                }
+            }
+            filter_clauses = []
+            if filters:
+                for key in ["user_id", "run_id", "agent_id"]:
+                    value = filters.get(key)
+                    if value:
+                        filter_clauses.append({
+                            "term": {f"payload.{key}.keyword": value}
+                        })
+            if filter_clauses:
+                query["query"] = {
+                    "bool": {
+                        "filter": filter_clauses
+                    }
+                }
+            if limit:
+                query["size"] = limit
+            response = self.client.search(index=self.collection_name, body=query)
+            hits = response["hits"]["hits"]
+            return [[
+                OutputData(
+                    id=hit["_source"].get("id"),
+                    score=1.0,
+                    payload=hit["_source"].get("payload", {})
+                )
+                for hit in hits
+            ]]
+        except Exception:
+            return []
def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...")
self.delete_col()
-        self.create_index()
+        self.create_col(self.collection_name, self.embedding_model_dims)