improvement(OSS): Fix AOSS and AWS BedRock LLM (#2697)
Co-authored-by: Prateek Chhikara <prateekchhikara24@gmail.com>
Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
@@ -33,6 +33,10 @@ class BaseEmbedderConfig(ABC):
        memory_search_embedding_type: Optional[str] = None,
        # LM Studio specific
        lmstudio_base_url: Optional[str] = "http://localhost:1234/v1",
        # AWS Bedrock specific
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_region: Optional[str] = "us-west-2",
    ):
        """
        Initializes a configuration class instance for the Embeddings.
@@ -92,3 +96,8 @@ class BaseEmbedderConfig(ABC):

        # LM Studio specific
        self.lmstudio_base_url = lmstudio_base_url

        # AWS Bedrock specific
        self.aws_access_key_id = aws_access_key_id
        self.aws_secret_access_key = aws_secret_access_key
        self.aws_region = aws_region

@@ -41,6 +41,10 @@ class BaseLlmConfig(ABC):
        xai_base_url: Optional[str] = None,
        # LM Studio specific
        lmstudio_base_url: Optional[str] = "http://localhost:1234/v1",
        # AWS Bedrock specific
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_region: Optional[str] = "us-west-2",
    ):
        """
        Initializes a configuration class instance for the LLM.
@@ -123,3 +127,8 @@ class BaseLlmConfig(ABC):

        # LM Studio specific
        self.lmstudio_base_url = lmstudio_base_url

        # AWS Bedrock specific
        self.aws_access_key_id = aws_access_key_id
        self.aws_secret_access_key = aws_secret_access_key
        self.aws_region = aws_region

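Note (not part of the diff): a minimal sketch of how the new AWS fields might be passed through mem0's dict-based config. The "aws_bedrock" provider key and the Memory.from_config entry point are assumptions here, not shown in this change set.

# Hypothetical usage sketch: wiring the new credential fields into a config dict.
from mem0 import Memory

config = {
    "llm": {
        "provider": "aws_bedrock",  # assumed provider key
        "config": {
            "model": "anthropic.claude-3-5-sonnet-20240620-v1:0",
            "aws_access_key_id": "AKIA...",        # falls back to env vars if omitted
            "aws_secret_access_key": "...",
            "aws_region": "us-west-2",
        },
    },
}
memory = Memory.from_config(config)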
@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union, Type

from pydantic import BaseModel, Field, model_validator

@@ -7,14 +7,33 @@ class OpenSearchConfig(BaseModel):
    collection_name: str = Field("mem0", description="Name of the index")
    host: str = Field("localhost", description="OpenSearch host")
    port: int = Field(9200, description="OpenSearch port")
    user: Optional[str] = Field(None, description="Username for authentication")
    password: Optional[str] = Field(None, description="Password for authentication")
    api_key: Optional[str] = Field(None, description="API key for authentication (if applicable)")
    embedding_model_dims: int = Field(1536, description="Dimension of the embedding vector")
    verify_certs: bool = Field(False, description="Verify SSL certificates (default False for OpenSearch)")
    use_ssl: bool = Field(False, description="Use SSL for connection (default False for OpenSearch)")
    auto_create_index: bool = Field(True, description="Automatically create index during initialization")
    http_auth: Optional[object] = Field(None, description="HTTP authentication method / AWS SigV4")
    user: Optional[str] = Field(
        None, description="Username for authentication"
    )
    password: Optional[str] = Field(
        None, description="Password for authentication"
    )
    api_key: Optional[str] = Field(
        None, description="API key for authentication (if applicable)"
    )
    embedding_model_dims: int = Field(
        1536, description="Dimension of the embedding vector"
    )
    verify_certs: bool = Field(
        False, description="Verify SSL certificates (default False for OpenSearch)"
    )
    use_ssl: bool = Field(
        False, description="Use SSL for connection (default False for OpenSearch)"
    )
    http_auth: Optional[object] = Field(
        None, description="HTTP authentication method / AWS SigV4"
    )
    connection_class: Optional[Union[str, Type]] = Field(
        "RequestsHttpConnection", description="Connection class for OpenSearch"
    )
    pool_maxsize: int = Field(
        20, description="Maximum number of connections in the pool"
    )

    @model_validator(mode="before")
    @classmethod
@@ -22,11 +41,7 @@ class OpenSearchConfig(BaseModel):
        # Check if host is provided
        if not values.get("host"):
            raise ValueError("Host must be provided for OpenSearch")

        # Authentication: Either API key or user/password must be provided
        if not any([values.get("api_key"), (values.get("user") and values.get("password")), values.get("http_auth")]):
            raise ValueError("Either api_key or user/password must be provided for OpenSearch authentication")


        return values

    @model_validator(mode="before")
@@ -37,6 +52,7 @@ class OpenSearchConfig(BaseModel):
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. " f"Allowed fields: {', '.join(allowed_fields)}"
                f"Extra fields not allowed: {', '.join(extra_fields)}. "
                f"Allowed fields: {', '.join(allowed_fields)}"
            )
        return values

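Note (illustrative, not part of the diff): the http_auth field described as "HTTP authentication method / AWS SigV4" can carry a ready-made signer for Amazon OpenSearch Serverless. A minimal sketch, assuming the requests_aws4auth package and locally configured boto3 credentials; the "aoss" service name is an assumption for serverless collections.

# Hypothetical sketch: building a SigV4 signer to pass as http_auth.
import boto3
from requests_aws4auth import AWS4Auth

session = boto3.Session()
credentials = session.get_credentials()  # assumes credentials are configured locally
awsauth = AWS4Auth(
    credentials.access_key,
    credentials.secret_key,
    session.region_name or "us-west-2",
    "aoss",  # use "es" for a managed OpenSearch domain
    session_token=credentials.token,
)

opensearch_config = {
    "host": "xxxxx.us-west-2.aoss.amazonaws.com",  # placeholder endpoint
    "port": 443,
    "http_auth": awsauth,
    "use_ssl": True,
    "verify_certs": True,
    "embedding_model_dims": 1536,
}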
@@ -1,4 +1,5 @@
import json
import os
from typing import Literal, Optional

try:
@@ -22,7 +23,26 @@ class AWSBedrockEmbedding(EmbeddingBase):
        super().__init__(config)

        self.config.model = self.config.model or "amazon.titan-embed-text-v1"
        self.client = boto3.client("bedrock-runtime")

        # Get AWS config from environment variables or use defaults
        aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", "")
        aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "")
        aws_region = os.environ.get("AWS_REGION", "us-west-2")

        # Check if AWS config is provided in the config
        if hasattr(self.config, "aws_access_key_id"):
            aws_access_key = self.config.aws_access_key_id
        if hasattr(self.config, "aws_secret_access_key"):
            aws_secret_key = self.config.aws_secret_access_key
        if hasattr(self.config, "aws_region"):
            aws_region = self.config.aws_region

        self.client = boto3.client(
            "bedrock-runtime",
            region_name=aws_region,
            aws_access_key_id=aws_access_key if aws_access_key else None,
            aws_secret_access_key=aws_secret_key if aws_secret_key else None,
        )

    def _normalize_vector(self, embeddings):
        """Normalize the embedding to a unit vector."""

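Side note (illustrative, not part of the diff): because the keyword arguments collapse to None when the environment variables are empty, boto3 falls back to its default credential chain (env vars, shared config, instance or task roles).

import boto3

# Passing None for the credential kwargs is equivalent to omitting them,
# so boto3 resolves credentials from its usual chain.
client = boto3.client(
    "bedrock-runtime",
    region_name="us-west-2",
    aws_access_key_id=None,
    aws_secret_access_key=None,
)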
@@ -1,4 +1,6 @@
import json
import os
import re
from typing import Any, Dict, List, Optional

try:
@@ -9,6 +11,14 @@ except ImportError:
from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase

PROVIDERS = ["ai21", "amazon", "anthropic", "cohere", "meta", "mistral", "stability", "writer"]


def extract_provider(model: str) -> str:
    for provider in PROVIDERS:
        if re.search(rf"\b{re.escape(provider)}\b", model):
            return provider
    raise ValueError(f"Unknown provider in model: {model}")

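Because extract_provider matches the provider token anywhere in the model ID, it also handles prefixed IDs such as cross-region inference profiles. Expected behaviour of the helper above (illustrative):

# Assumes the PROVIDERS list and extract_provider defined in this diff.
assert extract_provider("anthropic.claude-3-5-sonnet-20240620-v1:0") == "anthropic"
assert extract_provider("us.meta.llama3-1-70b-instruct-v1:0") == "meta"
# An unrecognised model ID raises ValueError("Unknown provider in model: ...").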
class AWSBedrockLLM(LLMBase):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
@@ -16,7 +26,27 @@ class AWSBedrockLLM(LLMBase):

        if not self.config.model:
            self.config.model = "anthropic.claude-3-5-sonnet-20240620-v1:0"
        self.client = boto3.client("bedrock-runtime")

        # Get AWS config from environment variables or use defaults
        aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", "")
        aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "")
        aws_region = os.environ.get("AWS_REGION", "us-west-2")

        # Check if AWS config is provided in the config
        if hasattr(self.config, "aws_access_key_id"):
            aws_access_key = self.config.aws_access_key_id
        if hasattr(self.config, "aws_secret_access_key"):
            aws_secret_key = self.config.aws_secret_access_key
        if hasattr(self.config, "aws_region"):
            aws_region = self.config.aws_region

        self.client = boto3.client(
            "bedrock-runtime",
            region_name=aws_region,
            aws_access_key_id=aws_access_key if aws_access_key else None,
            aws_secret_access_key=aws_secret_key if aws_secret_key else None,
        )

        self.model_kwargs = {
            "temperature": self.config.temperature,
            "max_tokens_to_sample": self.config.max_tokens,
@@ -34,13 +64,14 @@ class AWSBedrockLLM(LLMBase):
        Returns:
            str: A formatted string combining all messages, structured with roles capitalized and separated by newlines.
        """

        formatted_messages = []
        for message in messages:
            role = message["role"].capitalize()
            content = message["content"]
            formatted_messages.append(f"\n\n{role}: {content}")

        return "".join(formatted_messages) + "\n\nAssistant:"
        return "\n\nHuman: " + "".join(formatted_messages) + "\n\nAssistant:"

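The only behavioural change in this hunk is the prepended "\n\nHuman: ", presumably to satisfy the Human/Assistant turn framing that Bedrock's Anthropic-style text prompts expect. Illustrative output for a single user turn (not part of the diff):

# Mirrors the updated _format_messages logic.
messages = [{"role": "user", "content": "Hi"}]
formatted = "\n\nHuman: " + "".join(
    f"\n\n{m['role'].capitalize()}: {m['content']}" for m in messages
) + "\n\nAssistant:"
# formatted == "\n\nHuman: \n\nUser: Hi\n\nAssistant:"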
    def _parse_response(self, response, tools) -> str:
        """
@@ -68,8 +99,9 @@ class AWSBedrockLLM(LLMBase):

            return processed_response

        response_body = json.loads(response["body"].read().decode())
        return response_body.get("completion", "")
        response_body = response.get("body").read().decode()
        response_json = json.loads(response_body)
        return response_json.get("content", [{"text": ""}])[0].get("text", "")

    def _prepare_input(
        self,
@@ -113,9 +145,9 @@ class AWSBedrockLLM(LLMBase):
            input_body = {
                "inputText": prompt,
                "textGenerationConfig": {
                    "maxTokenCount": model_kwargs.get("max_tokens_to_sample"),
                    "topP": model_kwargs.get("top_p"),
                    "temperature": model_kwargs.get("temperature"),
                    "maxTokenCount": self.model_kwargs["max_tokens_to_sample"] or self.model_kwargs["max_tokens"] or 5000,
                    "topP": self.model_kwargs["top_p"] or 0.9,
                    "temperature": self.model_kwargs["temperature"] or 0.1,
                },
            }
            input_body["textGenerationConfig"] = {
@@ -206,15 +238,40 @@ class AWSBedrockLLM(LLMBase):
        else:
            # Use invoke_model method when no tools are provided
            prompt = self._format_messages(messages)
            provider = self.model.split(".")[0]
            input_body = self._prepare_input(provider, self.config.model, prompt, **self.model_kwargs)
            provider = extract_provider(self.config.model)
            input_body = self._prepare_input(provider, self.config.model, prompt, model_kwargs=self.model_kwargs)
            body = json.dumps(input_body)

            response = self.client.invoke_model(
                body=body,
                modelId=self.model,
                accept="application/json",
                contentType="application/json",
            if provider == "anthropic" or provider == "deepseek":

                input_body = {
                    "messages": [
                        {
                            "role": "user",
                            "content": [{"type": "text", "text": prompt}]
                        }
                    ],
                    "max_tokens": self.model_kwargs["max_tokens_to_sample"] or self.model_kwargs["max_tokens"] or 5000,
                    "temperature": self.model_kwargs["temperature"] or 0.1,
                    "top_p": self.model_kwargs["top_p"] or 0.9,
                    "anthropic_version": "bedrock-2023-05-31",
                }

                body = json.dumps(input_body)


                response = self.client.invoke_model(
                    body=body,
                    modelId=self.config.model,
                    accept="application/json",
                    contentType="application/json",
                )
            else:
                response = self.client.invoke_model(
                    body=body,
                    modelId=self.config.model,
                    accept="application/json",
                    contentType="application/json",
                )

        return self._parse_response(response, tools)

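Note (illustrative, not part of the diff): the reworked _parse_response above reads a Messages-API style body ("content" is a list of blocks with a "text" field) instead of the old "completion" string. A small self-contained check of that parsing path:

# The body shape here is illustrative of an Anthropic Messages-API result on Bedrock.
example_body = {"content": [{"type": "text", "text": "generated answer"}], "stop_reason": "end_turn"}
text = example_body.get("content", [{"text": ""}])[0].get("text", "")
assert text == "generated answer"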
@@ -69,17 +69,14 @@ class Memory(MemoryBase):
            self.enable_graph = True
        else:
            self.graph = None

        self.config.vector_store.config.collection_name = "mem0migrations"
        if self.config.vector_store.provider in ["faiss", "qdrant"]:
            provider_path = f"migrations_{self.config.vector_store.provider}"
            self.config.vector_store.config.path = os.path.join(mem0_dir, provider_path)
            os.makedirs(self.config.vector_store.config.path, exist_ok=True)

        self._telemetry_vector_store = VectorStoreFactory.create(
            self.config.vector_store.provider, self.config.vector_store.config
        )

        capture_event("mem0.init", self, {"sync_type": "sync"})

    @classmethod

@@ -38,7 +38,7 @@ def get_or_create_user_id(vector_store):

    # Try to get existing user_id from vector store
    try:
        existing = vector_store.get(vector_id=VECTOR_ID)
        existing = vector_store.get(vector_id=user_id)
        if existing and hasattr(existing, "payload") and existing.payload and "user_id" in existing.payload:
            return existing.payload["user_id"]
    except Exception:
@@ -48,7 +48,7 @@ def get_or_create_user_id(vector_store):
    try:
        dims = getattr(vector_store, "embedding_model_dims", 1536)
        vector_store.insert(
            vectors=[[0.0] * dims], payloads=[{"user_id": user_id, "type": "user_identity"}], ids=[VECTOR_ID]
            vectors=[[0.1] * dims], payloads=[{"user_id": user_id, "type": "user_identity"}], ids=[user_id]
        )
    except Exception:
        pass

@@ -1,5 +1,6 @@
import logging
from typing import Any, Dict, List, Optional
import time

try:
    from opensearchpy import OpenSearch, RequestsHttpConnection
@@ -34,28 +35,26 @@ class OpenSearchDB(VectorStoreBase):
            use_ssl=config.use_ssl,
            verify_certs=config.verify_certs,
            connection_class=RequestsHttpConnection,
            pool_maxsize=20
        )

        self.collection_name = config.collection_name
        self.embedding_model_dims = config.embedding_model_dims

        # Create index only if auto_create_index is True
        if config.auto_create_index:
            self.create_index()
            self.create_col(self.collection_name, self.embedding_model_dims)

    def create_index(self) -> None:
        """Create OpenSearch index with proper mappings if it doesn't exist."""
        index_settings = {
            "settings": {
                "index": {"number_of_replicas": 1, "number_of_shards": 5, "refresh_interval": "1s", "knn": True}
                "index": {"number_of_replicas": 1, "number_of_shards": 5, "refresh_interval": "10s", "knn": True}
            },
            "mappings": {
                "properties": {
                    "text": {"type": "text"},
                    "vector": {
                    "vector_field": {
                        "type": "knn_vector",
                        "dimension": self.embedding_model_dims,
                        "method": {"engine": "lucene", "name": "hnsw", "space_type": "cosinesimil"},
                        "method": {"engine": "nmslib", "name": "hnsw", "space_type": "cosinesimil"},
                    },
                    "metadata": {"type": "object", "properties": {"user_id": {"type": "keyword"}}},
                }
@@ -71,12 +70,15 @@ class OpenSearchDB(VectorStoreBase):
    def create_col(self, name: str, vector_size: int) -> None:
        """Create a new collection (index in OpenSearch)."""
        index_settings = {
            "settings": {
                "index.knn": True
            },
            "mappings": {
                "properties": {
                    "vector": {
                    "vector_field": {
                        "type": "knn_vector",
                        "dimension": vector_size,
                        "method": {"engine": "lucene", "name": "hnsw", "space_type": "cosinesimil"},
                        "method": {"engine": "nmslib", "name": "hnsw", "space_type": "cosinesimil"},
                    },
                    "payload": {"type": "object"},
                    "id": {"type": "keyword"},
@@ -88,6 +90,24 @@ class OpenSearchDB(VectorStoreBase):
        self.client.indices.create(index=name, body=index_settings)
        logger.info(f"Created index {name}")

        # Wait for index to be ready
        max_retries = 60  # 60 seconds timeout
        retry_count = 0
        while retry_count < max_retries:
            try:
                # Check if index is ready by attempting a simple search
                self.client.search(index=name, body={"query": {"match_all": {}}})
                logger.info(f"Index {name} is ready")
                time.sleep(1)
                return
            except Exception:
                retry_count += 1
                if retry_count == max_retries:
                    raise TimeoutError(
                        f"Index {name} creation timed out after {max_retries} seconds"
                    )
                time.sleep(0.5)

    def insert(
        self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None
    ) -> List[OutputData]:
@@ -98,74 +118,161 @@ class OpenSearchDB(VectorStoreBase):
        if payloads is None:
            payloads = [{} for _ in range(len(vectors))]

        actions = []
        for i, (vec, id_) in enumerate(zip(vectors, ids)):
            action = {
                "_index": self.collection_name,
                "_id": id_,
                "_source": {
                    "vector": vec,
                    "metadata": payloads[i],  # Store metadata in the metadata field
                },
            body = {
                "vector_field": vec,
                "payload": payloads[i],
                "id": id_,
            }
            actions.append(action)

        bulk(self.client, actions)
            self.client.index(index=self.collection_name, body=body)

        results = []
        for i, id_ in enumerate(ids):
            results.append(OutputData(id=id_, score=1.0, payload=payloads[i]))

        return results

    def search(
        self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
    ) -> List[OutputData]:
        """Search for similar vectors using OpenSearch k-NN search with pre-filtering."""
        search_query = {
            "size": limit,
            "query": {
                "knn": {
                    "vector": {
                        "vector": vectors,
                        "k": limit,
                    }
        """Search for similar vectors using OpenSearch k-NN search with optional filters."""

        # Base KNN query
        knn_query = {
            "knn": {
                "vector_field": {
                    "vector": vectors,
                    "k": limit * 2,
                }
            },
        }
        }

        # Start building the full query
        query_body = {
            "size": limit * 2,
            "query": None
        }

        # Prepare filter conditions if applicable
        filter_clauses = []
        if filters:
            filter_conditions = [{"term": {f"metadata.{key}": value}} for key, value in filters.items()]
            search_query["query"]["knn"]["vector"]["filter"] = {"bool": {"filter": filter_conditions}}
            for key in ["user_id", "run_id", "agent_id"]:
                value = filters.get(key)
                if value:
                    filter_clauses.append({
                        "term": {f"payload.{key}.keyword": value}
                    })

        response = self.client.search(index=self.collection_name, body=search_query)
        # Combine knn with filters if needed
        if filter_clauses:
            query_body["query"] = {
                "bool": {
                    "must": knn_query,
                    "filter": filter_clauses
                }
            }
        else:
            query_body["query"] = knn_query

        # Execute search
        response = self.client.search(index=self.collection_name, body=query_body)

        hits = response["hits"]["hits"]
        results = [
            OutputData(id=hit["_id"], score=hit["_score"], payload=hit["_source"].get("metadata", {}))
            for hit in response["hits"]["hits"]
            OutputData(
                id=hit["_source"].get("id"),
                score=hit["_score"],
                payload=hit["_source"].get("payload", {})
            )
            for hit in hits
        ]
        return results

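With a user_id filter, the rewritten search() above ends up sending a query body along these lines (illustrative, mirroring the construction in the diff; size and k become limit * 2):

# Example body for search(query, vectors=[0.1, 0.2], limit=5, filters={"user_id": "alice"})
example_query_body = {
    "size": 10,
    "query": {
        "bool": {
            "must": {"knn": {"vector_field": {"vector": [0.1, 0.2], "k": 10}}},
            "filter": [{"term": {"payload.user_id.keyword": "alice"}}],
        }
    },
}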
    def delete(self, vector_id: str) -> None:
        """Delete a vector by ID."""
        self.client.delete(index=self.collection_name, id=vector_id)
        """Delete a vector by custom ID."""
        # First, find the document by custom ID
        search_query = {
            "query": {
                "term": {
                    "id": vector_id
                }
            }
        }

        response = self.client.search(index=self.collection_name, body=search_query)
        hits = response.get("hits", {}).get("hits", [])

        if not hits:
            return

        opensearch_id = hits[0]["_id"]

        # Delete using the actual document ID
        self.client.delete(index=self.collection_name, id=opensearch_id)


    def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict] = None) -> None:
        """Update a vector and its payload."""
        """Update a vector and its payload using the custom 'id' field."""

        # First, find the document by custom ID
        search_query = {
            "query": {
                "term": {
                    "id": vector_id
                }
            }
        }

        response = self.client.search(index=self.collection_name, body=search_query)
        hits = response.get("hits", {}).get("hits", [])

        if not hits:
            return

        opensearch_id = hits[0]["_id"]  # The actual document ID in OpenSearch

        # Prepare updated fields
        doc = {}
        if vector is not None:
            doc["vector"] = vector
            doc["vector_field"] = vector
        if payload is not None:
            doc["metadata"] = payload
            doc["payload"] = payload

        if doc:
            try:
                response = self.client.update(index=self.collection_name, id=opensearch_id, body={"doc": doc})
            except Exception:
                pass

        self.client.update(index=self.collection_name, id=vector_id, body={"doc": doc})

    def get(self, vector_id: str) -> Optional[OutputData]:
        """Retrieve a vector by ID."""
        try:
            response = self.client.get(index=self.collection_name, id=vector_id)
            return OutputData(id=response["_id"], score=1.0, payload=response["_source"].get("metadata", {}))
            # First check if index exists
            if not self.client.indices.exists(index=self.collection_name):
                logger.info(f"Index {self.collection_name} does not exist, creating it...")
                self.create_col(self.collection_name, self.embedding_model_dims)
                return None

            search_query = {
                "query": {
                    "term": {
                        "id": vector_id
                    }
                }
            }
            response = self.client.search(index=self.collection_name, body=search_query)

            hits = response["hits"]["hits"]

            if not hits:
                return None

            return OutputData(
                id=hits[0]["_source"].get("id"),
                score=1.0,
                payload=hits[0]["_source"].get("payload", {})
            )
        except Exception as e:
            logger.error(f"Error retrieving vector {vector_id}: {e}")
            logger.error(f"Error retrieving vector {vector_id}: {str(e)}")
            return None

    def list_cols(self) -> List[str]:
@@ -180,28 +287,52 @@ class OpenSearchDB(VectorStoreBase):
        """Get information about a collection (index)."""
        return self.client.indices.get(index=name)

    def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[List[OutputData]]:
        """List all memories."""
        query = {"query": {"match_all": {}}}
    def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[OutputData]:

        if filters:
            query["query"] = {
                "bool": {"must": [{"term": {f"metadata.{key}": value}} for key, value in filters.items()]}
        try:
            """List all memories with optional filters."""
            query: Dict = {
                "query": {
                    "match_all": {}
                }
            }

        if limit:
            query["size"] = limit
            filter_clauses = []
            if filters:
                for key in ["user_id", "run_id", "agent_id"]:
                    value = filters.get(key)
                    if value:
                        filter_clauses.append({
                            "term": {f"payload.{key}.keyword": value}
                        })

            if filter_clauses:
                query["query"] = {
                    "bool": {
                        "filter": filter_clauses
                    }
                }

            if limit:
                query["size"] = limit

            response = self.client.search(index=self.collection_name, body=query)
            hits = response["hits"]["hits"]

            return [[
                OutputData(
                    id=hit["_source"].get("id"),
                    score=1.0,
                    payload=hit["_source"].get("payload", {})
                )
                for hit in hits
            ]]
        except Exception:
            return []


        response = self.client.search(index=self.collection_name, body=query)
        return [
            [
                OutputData(id=hit["_id"], score=1.0, payload=hit["_source"].get("metadata", {}))
                for hit in response["hits"]["hits"]
            ]
        ]

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.collection_name}...")
        self.delete_col()
        self.create_index()
        self.create_col(self.collection_name, self.embedding_model_dims)
