Formatting (#2750)

This commit is contained in:
Dev Khant
2025-05-22 01:17:29 +05:30
committed by GitHub
parent dff91154a7
commit d85fcda037
71 changed files with 1391 additions and 1823 deletions

View File

@@ -1,6 +1,6 @@
import logging
from typing import Any, Dict, List, Optional
import time
from typing import Any, Dict, List, Optional
try:
from opensearchpy import OpenSearch, RequestsHttpConnection
@@ -34,7 +34,7 @@ class OpenSearchDB(VectorStoreBase):
use_ssl=config.use_ssl,
verify_certs=config.verify_certs,
connection_class=RequestsHttpConnection,
pool_maxsize=20
pool_maxsize=20,
)
self.collection_name = config.collection_name
@@ -69,9 +69,7 @@ class OpenSearchDB(VectorStoreBase):
def create_col(self, name: str, vector_size: int) -> None:
"""Create a new collection (index in OpenSearch)."""
index_settings = {
"settings": {
"index.knn": True
},
"settings": {"index.knn": True},
"mappings": {
"properties": {
"vector_field": {
@@ -82,7 +80,7 @@ class OpenSearchDB(VectorStoreBase):
"payload": {"type": "object"},
"id": {"type": "keyword"},
}
}
},
}
if not self.client.indices.exists(index=name):
@@ -102,9 +100,7 @@ class OpenSearchDB(VectorStoreBase):
except Exception:
retry_count += 1
if retry_count == max_retries:
raise TimeoutError(
f"Index {name} creation timed out after {max_retries} seconds"
)
raise TimeoutError(f"Index {name} creation timed out after {max_retries} seconds")
time.sleep(0.5)
def insert(
@@ -145,10 +141,7 @@ class OpenSearchDB(VectorStoreBase):
}
# Start building the full query
query_body = {
"size": limit * 2,
"query": None
}
query_body = {"size": limit * 2, "query": None}
# Prepare filter conditions if applicable
filter_clauses = []
@@ -156,18 +149,11 @@ class OpenSearchDB(VectorStoreBase):
for key in ["user_id", "run_id", "agent_id"]:
value = filters.get(key)
if value:
filter_clauses.append({
"term": {f"payload.{key}.keyword": value}
})
filter_clauses.append({"term": {f"payload.{key}.keyword": value}})
# Combine knn with filters if needed
if filter_clauses:
query_body["query"] = {
"bool": {
"must": knn_query,
"filter": filter_clauses
}
}
query_body["query"] = {"bool": {"must": knn_query, "filter": filter_clauses}}
else:
query_body["query"] = knn_query
@@ -176,11 +162,7 @@ class OpenSearchDB(VectorStoreBase):
hits = response["hits"]["hits"]
results = [
OutputData(
id=hit["_source"].get("id"),
score=hit["_score"],
payload=hit["_source"].get("payload", {})
)
OutputData(id=hit["_source"].get("id"), score=hit["_score"], payload=hit["_source"].get("payload", {}))
for hit in hits
]
return results
@@ -188,13 +170,7 @@ class OpenSearchDB(VectorStoreBase):
def delete(self, vector_id: str) -> None:
"""Delete a vector by custom ID."""
# First, find the document by custom ID
search_query = {
"query": {
"term": {
"id": vector_id
}
}
}
search_query = {"query": {"term": {"id": vector_id}}}
response = self.client.search(index=self.collection_name, body=search_query)
hits = response.get("hits", {}).get("hits", [])
@@ -207,18 +183,11 @@ class OpenSearchDB(VectorStoreBase):
# Delete using the actual document ID
self.client.delete(index=self.collection_name, id=opensearch_id)
def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict] = None) -> None:
"""Update a vector and its payload using the custom 'id' field."""
# First, find the document by custom ID
search_query = {
"query": {
"term": {
"id": vector_id
}
}
}
search_query = {"query": {"term": {"id": vector_id}}}
response = self.client.search(index=self.collection_name, body=search_query)
hits = response.get("hits", {}).get("hits", [])
@@ -241,7 +210,6 @@ class OpenSearchDB(VectorStoreBase):
except Exception:
pass
def get(self, vector_id: str) -> Optional[OutputData]:
"""Retrieve a vector by ID."""
try:
@@ -251,13 +219,7 @@ class OpenSearchDB(VectorStoreBase):
self.create_col(self.collection_name, self.embedding_model_dims)
return None
search_query = {
"query": {
"term": {
"id": vector_id
}
}
}
search_query = {"query": {"term": {"id": vector_id}}}
response = self.client.search(index=self.collection_name, body=search_query)
hits = response["hits"]["hits"]
@@ -265,11 +227,7 @@ class OpenSearchDB(VectorStoreBase):
if not hits:
return None
return OutputData(
id=hits[0]["_source"].get("id"),
score=1.0,
payload=hits[0]["_source"].get("payload", {})
)
return OutputData(id=hits[0]["_source"].get("id"), score=1.0, payload=hits[0]["_source"].get("payload", {}))
except Exception as e:
logger.error(f"Error retrieving vector {vector_id}: {str(e)}")
return None
@@ -287,30 +245,19 @@ class OpenSearchDB(VectorStoreBase):
return self.client.indices.get(index=name)
def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[OutputData]:
try:
"""List all memories with optional filters."""
query: Dict = {
"query": {
"match_all": {}
}
}
query: Dict = {"query": {"match_all": {}}}
filter_clauses = []
if filters:
for key in ["user_id", "run_id", "agent_id"]:
value = filters.get(key)
if value:
filter_clauses.append({
"term": {f"payload.{key}.keyword": value}
})
filter_clauses.append({"term": {f"payload.{key}.keyword": value}})
if filter_clauses:
query["query"] = {
"bool": {
"filter": filter_clauses
}
}
query["query"] = {"bool": {"filter": filter_clauses}}
if limit:
query["size"] = limit
@@ -318,18 +265,15 @@ class OpenSearchDB(VectorStoreBase):
response = self.client.search(index=self.collection_name, body=query)
hits = response["hits"]["hits"]
return [[
OutputData(
id=hit["_source"].get("id"),
score=1.0,
payload=hit["_source"].get("payload", {})
)
for hit in hits
]]
return [
[
OutputData(id=hit["_source"].get("id"), score=1.0, payload=hit["_source"].get("payload", {}))
for hit in hits
]
]
except Exception:
return []
def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...")